From c4e9b354844e149a90a7ba36fdc9f332acea3825 Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Sun, 30 Jan 2022 00:43:36 +0800 Subject: [PATCH] add resolver & hosts for chain --- pkg/chain/resovle.go | 44 ++++++++++++++++++ pkg/chain/route.go | 103 ++++++++++++++++++++++++------------------- pkg/chain/router.go | 39 +++------------- 3 files changed, 107 insertions(+), 79 deletions(-) create mode 100644 pkg/chain/resovle.go diff --git a/pkg/chain/resovle.go b/pkg/chain/resovle.go new file mode 100644 index 0000000..2a736b9 --- /dev/null +++ b/pkg/chain/resovle.go @@ -0,0 +1,44 @@ +package chain + +import ( + "context" + "fmt" + "net" + + "github.com/go-gost/gost/pkg/hosts" + "github.com/go-gost/gost/pkg/logger" + "github.com/go-gost/gost/pkg/resolver" +) + +func resolve(ctx context.Context, addr string, resolver resolver.Resolver, hosts hosts.HostMapper, log logger.Logger) (string, error) { + if addr == "" { + return addr, nil + } + + host, port, err := net.SplitHostPort(addr) + if err != nil { + return "", err + } + if host == "" { + return addr, nil + } + + if hosts != nil { + if ips, _ := hosts.Lookup("ip", host); len(ips) > 0 { + log.Debugf("hit host mapper: %s -> %s", host, ips) + return net.JoinHostPort(ips[0].String(), port), nil + } + } + + if resolver != nil { + ips, err := resolver.Resolve(ctx, host) + if err != nil { + log.Error(err) + } + if len(ips) == 0 { + return "", fmt.Errorf("resolver: domain %s does not exists", host) + } + return net.JoinHostPort(ips[0].String(), port), nil + } + return addr, nil +} diff --git a/pkg/chain/route.go b/pkg/chain/route.go index 21ac9d6..ff3b89b 100644 --- a/pkg/chain/route.go +++ b/pkg/chain/route.go @@ -16,57 +16,14 @@ var ( ) type route struct { - nodes []*Node + nodes []*Node + logger logger.Logger } func (r *route) AddNode(node *Node) { r.nodes = append(r.nodes, node) } -func (r *route) connect(ctx context.Context) (conn net.Conn, err error) { - if r.IsEmpty() { - return nil, ErrEmptyRoute - } - - node := r.nodes[0] - cc, err := node.Transport.Dial(ctx, r.nodes[0].Addr) - if err != nil { - node.Marker.Mark() - return - } - - cn, err := node.Transport.Handshake(ctx, cc) - if err != nil { - cc.Close() - node.Marker.Mark() - return - } - node.Marker.Reset() - - preNode := node - for _, node := range r.nodes[1:] { - cc, err = preNode.Transport.Connect(ctx, cn, "tcp", node.Addr) - if err != nil { - cn.Close() - node.Marker.Mark() - return - } - cc, err = node.Transport.Handshake(ctx, cc) - if err != nil { - cn.Close() - node.Marker.Mark() - return - } - node.Marker.Reset() - - cn = cc - preNode = node - } - - conn = cn - return -} - func (r *route) Dial(ctx context.Context, network, address string) (net.Conn, error) { if r.IsEmpty() { return r.dialDirect(ctx, network, address) @@ -117,6 +74,62 @@ func (r *route) Bind(ctx context.Context, network, address string, opts ...conne return ln, nil } +func (r *route) connect(ctx context.Context) (conn net.Conn, err error) { + if r.IsEmpty() { + return nil, ErrEmptyRoute + } + + node := r.nodes[0] + + addr, err := resolve(ctx, node.Addr, node.Resolver, node.Hosts, r.logger) + if err != nil { + node.Marker.Mark() + return + } + cc, err := node.Transport.Dial(ctx, addr) + if err != nil { + node.Marker.Mark() + return + } + + cn, err := node.Transport.Handshake(ctx, cc) + if err != nil { + cc.Close() + node.Marker.Mark() + return + } + node.Marker.Reset() + + preNode := node + for _, node := range r.nodes[1:] { + addr, err = resolve(ctx, node.Addr, node.Resolver, node.Hosts, r.logger) + if err != nil { + cn.Close() + node.Marker.Mark() + return + } + cc, err = preNode.Transport.Connect(ctx, cn, "tcp", addr) + if err != nil { + cn.Close() + node.Marker.Mark() + return + } + cc, err = node.Transport.Handshake(ctx, cc) + if err != nil { + cn.Close() + node.Marker.Mark() + return + } + node.Marker.Reset() + + cn = cc + preNode = node + } + + conn = cn + return +} + func (r *route) IsEmpty() bool { return r == nil || len(r.nodes) == 0 } diff --git a/pkg/chain/router.go b/pkg/chain/router.go index c780018..21d48a8 100644 --- a/pkg/chain/router.go +++ b/pkg/chain/router.go @@ -52,12 +52,16 @@ func (r *Router) dial(ctx context.Context, network, address string) (conn net.Co r.Logger.Debugf("route(retry=%d) %s", i, buf.String()) } - address, err = r.resolve(ctx, address) + address, err = resolve(ctx, address, r.Resolver, r.Hosts, r.Logger) if err != nil { r.Logger.Error(err) break } + if route != nil { + route.logger = r.Logger + } + conn, err = route.Dial(ctx, network, address) if err == nil { break @@ -68,39 +72,6 @@ func (r *Router) dial(ctx context.Context, network, address string) (conn net.Co return } -func (r *Router) resolve(ctx context.Context, addr string) (string, error) { - if addr == "" { - return addr, nil - } - - host, port, err := net.SplitHostPort(addr) - if err != nil { - return "", err - } - if host == "" { - return addr, nil - } - - if r.Hosts != nil { - if ips, _ := r.Hosts.Lookup("ip", host); len(ips) > 0 { - r.Logger.Debugf("hit host mapper: %s -> %s", host, ips) - return net.JoinHostPort(ips[0].String(), port), nil - } - } - - if r.Resolver != nil { - ips, err := r.Resolver.Resolve(ctx, host) - if err != nil { - r.Logger.Error(err) - } - if len(ips) == 0 { - return "", fmt.Errorf("resolver: domain %s does not exists", host) - } - return net.JoinHostPort(ips[0].String(), port), nil - } - return addr, nil -} - func (r *Router) Bind(ctx context.Context, network, address string, opts ...connector.BindOption) (ln net.Listener, err error) { count := r.Retries + 1 if count <= 0 {