diff --git a/hop/hop.go b/hop/hop.go index 689bb30..c98623b 100644 --- a/hop/hop.go +++ b/hop/hop.go @@ -1,7 +1,6 @@ package hop import ( - "bytes" "context" "encoding/json" "io" @@ -181,7 +180,11 @@ func (p *chainHop) filterByHost(host string, nodes ...*chain.Node) (filters []*c if node == nil { continue } - vhost := node.Options().Host + + var vhost string + if filter := node.Options().Filter; filter != nil { + vhost = filter.Host + } if vhost == "" { // backup node if !found { filters = append(filters, node) @@ -216,14 +219,18 @@ func (p *chainHop) filterByProtocol(protocol string, nodes ...*chain.Node) (filt continue } - if node.Options().Protocol == "" { + var prot string + if filter := node.Options().Filter; filter != nil { + prot = filter.Protocol + } + if prot == "" { if !found { filters = append(filters, node) } continue } - if node.Options().Protocol == protocol { + if prot == protocol { if !found { filters = nil } @@ -244,19 +251,31 @@ func (p *chainHop) filterByPath(path string, nodes ...*chain.Node) (filters []*c p.options.logger.Debugf("filter by path: %s", path) sort.SliceStable(nodes, func(i, j int) bool { - return len(nodes[i].Options().Path) > len(nodes[j].Options().Path) + filter1 := nodes[i].Options().Filter + if filter1 == nil { + return false + } + filter2 := nodes[j].Options().Filter + if filter2 == nil { + return true + } + return len(filter1.Path) > len(filter2.Path) }) found := false for _, node := range nodes { - if node.Options().Path == "" { + var pathFilter string + if filter := node.Options().Filter; filter != nil { + pathFilter = filter.Path + } + if pathFilter == "" { if !found { filters = append(filters, node) } continue } - if strings.HasPrefix(path, node.Options().Path) { + if strings.HasPrefix(path, pathFilter) { if !found { filters = nil } @@ -308,33 +327,30 @@ func (p *chainHop) reload(ctx context.Context) (err error) { } func (p *chainHop) load(ctx context.Context) (nodes []*chain.Node, err error) { - if p.options.fileLoader != nil { - r, er := p.options.fileLoader.Load(ctx) + if loader := p.options.fileLoader; loader != nil { + r, er := loader.Load(ctx) if er != nil { p.options.logger.Warnf("file loader: %v", er) } nodes, _ = p.parseNode(r) } - if p.options.redisLoader != nil { - if lister, ok := p.options.redisLoader.(loader.Lister); ok { - list, er := lister.List(ctx) - if er != nil { - p.options.logger.Warnf("redis loader: %v", er) - } - for _, s := range list { - nl, _ := p.parseNode(bytes.NewReader([]byte(s))) - nodes = append(nodes, nl...) - } + if loader := p.options.redisLoader; loader != nil { + r, er := loader.Load(ctx) + if er != nil { + p.options.logger.Warnf("redis loader: %v", er) } + ns, _ := p.parseNode(r) + nodes = append(nodes, ns...) } - if p.options.httpLoader != nil { - r, er := p.options.httpLoader.Load(ctx) + + if loader := p.options.httpLoader; loader != nil { + r, er := loader.Load(ctx) if er != nil { p.options.logger.Warnf("http loader: %v", er) } - if node, _ := p.parseNode(r); node != nil { - nodes = append(nodes, node...) + if ns, _ := p.parseNode(r); ns != nil { + nodes = append(nodes, ns...) } } @@ -342,6 +358,10 @@ func (p *chainHop) load(ctx context.Context) (nodes []*chain.Node, err error) { } func (p *chainHop) parseNode(r io.Reader) ([]*chain.Node, error) { + if r == nil { + return nil, nil + } + var ncs []*config.NodeConfig if err := json.NewDecoder(r).Decode(&ncs); err != nil { return nil, err