fix redis loader for hop

This commit is contained in:
ginuerzh 2024-05-08 21:26:15 +08:00
parent 40f709880d
commit 754b2fdeac

View File

@ -1,7 +1,6 @@
package hop package hop
import ( import (
"bytes"
"context" "context"
"encoding/json" "encoding/json"
"io" "io"
@ -181,7 +180,11 @@ func (p *chainHop) filterByHost(host string, nodes ...*chain.Node) (filters []*c
if node == nil { if node == nil {
continue continue
} }
vhost := node.Options().Host
var vhost string
if filter := node.Options().Filter; filter != nil {
vhost = filter.Host
}
if vhost == "" { // backup node if vhost == "" { // backup node
if !found { if !found {
filters = append(filters, node) filters = append(filters, node)
@ -216,14 +219,18 @@ func (p *chainHop) filterByProtocol(protocol string, nodes ...*chain.Node) (filt
continue continue
} }
if node.Options().Protocol == "" { var prot string
if filter := node.Options().Filter; filter != nil {
prot = filter.Protocol
}
if prot == "" {
if !found { if !found {
filters = append(filters, node) filters = append(filters, node)
} }
continue continue
} }
if node.Options().Protocol == protocol { if prot == protocol {
if !found { if !found {
filters = nil filters = nil
} }
@ -244,19 +251,31 @@ func (p *chainHop) filterByPath(path string, nodes ...*chain.Node) (filters []*c
p.options.logger.Debugf("filter by path: %s", path) p.options.logger.Debugf("filter by path: %s", path)
sort.SliceStable(nodes, func(i, j int) bool { sort.SliceStable(nodes, func(i, j int) bool {
return len(nodes[i].Options().Path) > len(nodes[j].Options().Path) filter1 := nodes[i].Options().Filter
if filter1 == nil {
return false
}
filter2 := nodes[j].Options().Filter
if filter2 == nil {
return true
}
return len(filter1.Path) > len(filter2.Path)
}) })
found := false found := false
for _, node := range nodes { for _, node := range nodes {
if node.Options().Path == "" { var pathFilter string
if filter := node.Options().Filter; filter != nil {
pathFilter = filter.Path
}
if pathFilter == "" {
if !found { if !found {
filters = append(filters, node) filters = append(filters, node)
} }
continue continue
} }
if strings.HasPrefix(path, node.Options().Path) { if strings.HasPrefix(path, pathFilter) {
if !found { if !found {
filters = nil filters = nil
} }
@ -308,33 +327,30 @@ func (p *chainHop) reload(ctx context.Context) (err error) {
} }
func (p *chainHop) load(ctx context.Context) (nodes []*chain.Node, err error) { func (p *chainHop) load(ctx context.Context) (nodes []*chain.Node, err error) {
if p.options.fileLoader != nil { if loader := p.options.fileLoader; loader != nil {
r, er := p.options.fileLoader.Load(ctx) r, er := loader.Load(ctx)
if er != nil { if er != nil {
p.options.logger.Warnf("file loader: %v", er) p.options.logger.Warnf("file loader: %v", er)
} }
nodes, _ = p.parseNode(r) nodes, _ = p.parseNode(r)
} }
if p.options.redisLoader != nil { if loader := p.options.redisLoader; loader != nil {
if lister, ok := p.options.redisLoader.(loader.Lister); ok { r, er := loader.Load(ctx)
list, er := lister.List(ctx)
if er != nil { if er != nil {
p.options.logger.Warnf("redis loader: %v", er) p.options.logger.Warnf("redis loader: %v", er)
} }
for _, s := range list { ns, _ := p.parseNode(r)
nl, _ := p.parseNode(bytes.NewReader([]byte(s))) nodes = append(nodes, ns...)
nodes = append(nodes, nl...)
} }
}
} if loader := p.options.httpLoader; loader != nil {
if p.options.httpLoader != nil { r, er := loader.Load(ctx)
r, er := p.options.httpLoader.Load(ctx)
if er != nil { if er != nil {
p.options.logger.Warnf("http loader: %v", er) p.options.logger.Warnf("http loader: %v", er)
} }
if node, _ := p.parseNode(r); node != nil { if ns, _ := p.parseNode(r); ns != nil {
nodes = append(nodes, node...) nodes = append(nodes, ns...)
} }
} }
@ -342,6 +358,10 @@ func (p *chainHop) load(ctx context.Context) (nodes []*chain.Node, err error) {
} }
func (p *chainHop) parseNode(r io.Reader) ([]*chain.Node, error) { func (p *chainHop) parseNode(r io.Reader) ([]*chain.Node, error) {
if r == nil {
return nil, nil
}
var ncs []*config.NodeConfig var ncs []*config.NodeConfig
if err := json.NewDecoder(r).Decode(&ncs); err != nil { if err := json.NewDecoder(r).Decode(&ncs); err != nil {
return nil, err return nil, err