From 44acdc92153af3fa434553674d7e0ba5c6cea1aa Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Wed, 26 Jan 2022 18:29:48 +0800 Subject: [PATCH] improve config --- cmd/gost/cmd.go | 82 ++++++++++++++++++++------- cmd/gost/config.go | 36 +++++++++--- gost.yml | 6 +- pkg/chain/chain.go | 10 ++-- pkg/chain/node.go | 44 +++----------- pkg/chain/route.go | 28 ++++----- pkg/chain/router.go | 4 +- pkg/chain/selector.go | 6 +- pkg/config/config.go | 11 +++- pkg/handler/forward/local/handler.go | 16 +++--- pkg/handler/forward/remote/handler.go | 14 ++--- pkg/handler/relay/forward.go | 14 ++--- pkg/handler/tap/handler.go | 4 +- pkg/handler/tun/handler.go | 4 +- 14 files changed, 160 insertions(+), 119 deletions(-) diff --git a/cmd/gost/cmd.go b/cmd/gost/cmd.go index 01523cb..136a917 100644 --- a/cmd/gost/cmd.go +++ b/cmd/gost/cmd.go @@ -52,7 +52,26 @@ func buildConfigFromCmd(services, nodes stringList) (*config.Config, error) { } nodeConfig.Name = "node-0" + var nodes []*config.NodeConfig + for _, host := range strings.Split(nodeConfig.Addr, ",") { + if host == "" { + continue + } + nodeCfg := &config.NodeConfig{} + *nodeCfg = *nodeConfig + nodeCfg.Name = fmt.Sprintf("node-%d", len(nodes)) + nodeCfg.Addr = host + nodes = append(nodes, nodeCfg) + } + md := metadata.MapMetadata(nodeConfig.Connector.Metadata) + + hopConfig := &config.HopConfig{ + Name: fmt.Sprintf("hop-%d", i), + Selector: parseSelector(md), + Nodes: nodes, + } + if v := metadata.GetString(md, "bypass"); v != "" { bypassCfg := &config.BypassConfig{ Name: fmt.Sprintf("bypass-%d", len(cfg.Bypasses)), @@ -67,28 +86,52 @@ func buildConfigFromCmd(services, nodes stringList) (*config.Config, error) { } bypassCfg.Matchers = append(bypassCfg.Matchers, s) } - nodeConfig.Bypass = bypassCfg.Name + hopConfig.Bypass = bypassCfg.Name cfg.Bypasses = append(cfg.Bypasses, bypassCfg) md.Del("bypass") } - - var nodes []*config.NodeConfig - for _, host := range strings.Split(nodeConfig.Addr, ",") { - if host == "" { - continue + if v := metadata.GetString(md, "resolver"); v != "" { + resolverCfg := &config.ResolverConfig{ + Name: fmt.Sprintf("resolver-%d", len(cfg.Resolvers)), } - nodeCfg := &config.NodeConfig{} - *nodeCfg = *nodeConfig - nodeCfg.Name = fmt.Sprintf("node-%d", len(nodes)) - nodeCfg.Addr = host - nodes = append(nodes, nodeCfg) + for _, rs := range strings.Split(v, ",") { + if rs == "" { + continue + } + resolverCfg.Nameservers = append( + resolverCfg.Nameservers, + config.NameserverConfig{ + Addr: rs, + }, + ) + } + hopConfig.Resolver = resolverCfg.Name + cfg.Resolvers = append(cfg.Resolvers, resolverCfg) + md.Del("resolver") + } + if v := metadata.GetString(md, "hosts"); v != "" { + hostsCfg := &config.HostsConfig{ + Name: fmt.Sprintf("hosts-%d", len(cfg.Hosts)), + } + for _, s := range strings.Split(v, ",") { + ss := strings.SplitN(s, ":", 2) + if len(ss) != 2 { + continue + } + hostsCfg.Mappings = append( + hostsCfg.Mappings, + config.HostMappingConfig{ + Hostname: ss[0], + IP: ss[1], + }, + ) + } + hopConfig.Hosts = hostsCfg.Name + cfg.Hosts = append(cfg.Hosts, hostsCfg) + md.Del("hosts") } - chain.Hops = append(chain.Hops, &config.HopConfig{ - Name: fmt.Sprintf("hop-%d", i), - Selector: parseSelector(md), - Nodes: nodes, - }) + chain.Hops = append(chain.Hops, hopConfig) } for i, svc := range services { @@ -126,7 +169,7 @@ func buildConfigFromCmd(services, nodes stringList) (*config.Config, error) { } bypassCfg.Matchers = append(bypassCfg.Matchers, s) } - service.Handler.Bypass = bypassCfg.Name + service.Bypass = bypassCfg.Name cfg.Bypasses = append(cfg.Bypasses, bypassCfg) md.Del("bypass") } @@ -145,7 +188,7 @@ func buildConfigFromCmd(services, nodes stringList) (*config.Config, error) { }, ) } - service.Handler.Resolver = resolverCfg.Name + service.Resolver = resolverCfg.Name cfg.Resolvers = append(cfg.Resolvers, resolverCfg) md.Del("resolver") } @@ -166,11 +209,10 @@ func buildConfigFromCmd(services, nodes stringList) (*config.Config, error) { }, ) } - service.Handler.Hosts = hostsCfg.Name + service.Hosts = hostsCfg.Name cfg.Hosts = append(cfg.Hosts, hostsCfg) md.Del("hosts") } - } return cfg, nil diff --git a/cmd/gost/config.go b/cmd/gost/config.go index 15e5559..44254e9 100644 --- a/cmd/gost/config.go +++ b/cmd/gost/config.go @@ -119,12 +119,12 @@ func buildService(cfg *config.Config) (services []*service.Service) { } h := registry.GetHandler(svc.Handler.Type)( + handler.AuthsOption(authsFromConfig(svc.Handler.Auths...)...), handler.RetriesOption(svc.Handler.Retries), handler.ChainOption(chains[svc.Handler.Chain]), - handler.ResolverOption(resolvers[svc.Handler.Resolver]), - handler.HostsOption(hosts[svc.Handler.Hosts]), - handler.BypassOption(bypasses[svc.Handler.Bypass]), - handler.AuthsOption(authsFromConfig(svc.Handler.Auths...)...), + handler.BypassOption(bypasses[svc.Bypass]), + handler.ResolverOption(resolvers[svc.Resolver]), + handler.HostsOption(hosts[svc.Hosts]), handler.TLSConfigOption(tlsConfig), handler.LoggerOption(handlerLogger), ) @@ -252,9 +252,25 @@ func chainFromConfig(cfg *config.ChainConfig) *chain.Chain { WithDialer(d). WithAddr(v.Addr) - node := chain.NewNode(v.Name, v.Addr). - WithTransport(tr). - WithBypass(bypasses[v.Bypass]) + if v.Bypass == "" { + v.Bypass = hop.Bypass + } + if v.Resolver == "" { + v.Resolver = hop.Resolver + } + if v.Hosts == "" { + v.Hosts = hop.Hosts + } + + node := &chain.Node{ + Name: v.Name, + Addr: v.Addr, + Transport: tr, + Bypass: bypasses[v.Bypass], + Resolver: resolvers[v.Resolver], + Hosts: hosts[v.Hosts], + Marker: &chain.FailMarker{}, + } group.AddNode(node) } @@ -277,7 +293,11 @@ func forwarderFromConfig(cfg *config.ForwarderConfig) *chain.NodeGroup { group := &chain.NodeGroup{} for _, target := range cfg.Targets { if v := strings.TrimSpace(target); v != "" { - group.AddNode(chain.NewNode(target, target)) + group.AddNode(&chain.Node{ + Name: target, + Addr: target, + Marker: &chain.FailMarker{}, + }) } } return group.WithSelector(selectorFromConfig(cfg.Selector)) diff --git a/gost.yml b/gost.yml index c795eba..71a3989 100644 --- a/gost.yml +++ b/gost.yml @@ -6,10 +6,10 @@ log: services: - name: http+tcp addr: ":28000" + # bypass: bypass01 handler: type: http chain: chain01 - # bypass: bypass01 metadata: proxyAgent: "gost/3.0" auths: @@ -23,10 +23,10 @@ services: keepAlive: 15s - name: ss addr: ":28338" + # bypass: bypass01 handler: type: ss # chain: chain01 - # bypass: bypass01 metadata: method: chacha20-ietf password: gost @@ -39,10 +39,10 @@ services: keepAlive: 15s - name: socks5 addr: ":21080" + # bypass: bypass01 handler: type: socks5 # chain: chain-ss - # bypass: bypass01 metadata: auths: - gost:gost diff --git a/pkg/chain/chain.go b/pkg/chain/chain.go index 515f200..57275dc 100644 --- a/pkg/chain/chain.go +++ b/pkg/chain/chain.go @@ -23,15 +23,15 @@ func (c *Chain) GetRouteFor(network, address string) (r *route) { if node == nil { return } - if node.bypass != nil && node.bypass.Contains(address) { + if node.Bypass != nil && node.Bypass.Contains(address) { break } - if node.transport.Multiplex() { - tr := node.transport.Copy(). + if node.Transport.Multiplex() { + tr := node.Transport.Copy(). WithRoute(r) - node = node.Copy(). - WithTransport(tr) + node = node.Copy() + node.Transport = tr r = &route{} } diff --git a/pkg/chain/node.go b/pkg/chain/node.go index ff42843..9972396 100644 --- a/pkg/chain/node.go +++ b/pkg/chain/node.go @@ -5,44 +5,18 @@ import ( "time" "github.com/go-gost/gost/pkg/bypass" + "github.com/go-gost/gost/pkg/hosts" + "github.com/go-gost/gost/pkg/resolver" ) type Node struct { - name string - addr string - transport *Transport - bypass bypass.Bypass - marker *FailMarker -} - -func NewNode(name, addr string) *Node { - return &Node{ - name: name, - addr: addr, - marker: &FailMarker{}, - } -} - -func (node *Node) Name() string { - return node.name -} - -func (node *Node) Addr() string { - return node.addr -} - -func (node *Node) Marker() *FailMarker { - return node.marker -} - -func (node *Node) WithTransport(tr *Transport) *Node { - node.transport = tr - return node -} - -func (node *Node) WithBypass(bp bypass.Bypass) *Node { - node.bypass = bp - return node + Name string + Addr string + Transport *Transport + Bypass bypass.Bypass + Resolver resolver.Resolver + Hosts hosts.HostMapper + Marker *FailMarker } func (node *Node) Copy() *Node { diff --git a/pkg/chain/route.go b/pkg/chain/route.go index ef6b5f1..21ac9d6 100644 --- a/pkg/chain/route.go +++ b/pkg/chain/route.go @@ -29,35 +29,35 @@ func (r *route) connect(ctx context.Context) (conn net.Conn, err error) { } node := r.nodes[0] - cc, err := node.transport.Dial(ctx, r.nodes[0].Addr()) + cc, err := node.Transport.Dial(ctx, r.nodes[0].Addr) if err != nil { - node.Marker().Mark() + node.Marker.Mark() return } - cn, err := node.transport.Handshake(ctx, cc) + cn, err := node.Transport.Handshake(ctx, cc) if err != nil { cc.Close() - node.Marker().Mark() + node.Marker.Mark() return } - node.Marker().Reset() + node.Marker.Reset() preNode := node for _, node := range r.nodes[1:] { - cc, err = preNode.transport.Connect(ctx, cn, "tcp", node.Addr()) + cc, err = preNode.Transport.Connect(ctx, cn, "tcp", node.Addr) if err != nil { cn.Close() - node.Marker().Mark() + node.Marker.Mark() return } - cc, err = node.transport.Handshake(ctx, cc) + cc, err = node.Transport.Handshake(ctx, cc) if err != nil { cn.Close() - node.Marker().Mark() + node.Marker.Mark() return } - node.Marker().Reset() + node.Marker.Reset() cn = cc preNode = node @@ -77,7 +77,7 @@ func (r *route) Dial(ctx context.Context, network, address string) (net.Conn, er return nil, err } - cc, err := r.Last().transport.Connect(ctx, conn, network, address) + cc, err := r.Last().Transport.Connect(ctx, conn, network, address) if err != nil { conn.Close() return nil, err @@ -108,7 +108,7 @@ func (r *route) Bind(ctx context.Context, network, address string, opts ...conne return nil, err } - ln, err := r.Last().transport.Bind(ctx, conn, network, address, opts...) + ln, err := r.Last().Transport.Bind(ctx, conn, network, address, opts...) if err != nil { conn.Close() return nil, err @@ -134,8 +134,8 @@ func (r *route) Path() (path []*Node) { } for _, node := range r.nodes { - if node.transport != nil && node.transport.route != nil { - path = append(path, node.transport.route.Path()...) + if node.Transport != nil && node.Transport.route != nil { + path = append(path, node.Transport.route.Path()...) } path = append(path, node) } diff --git a/pkg/chain/router.go b/pkg/chain/router.go index 0c266bf..c780018 100644 --- a/pkg/chain/router.go +++ b/pkg/chain/router.go @@ -46,7 +46,7 @@ func (r *Router) dial(ctx context.Context, network, address string) (conn net.Co if r.Logger.IsLevelEnabled(logger.DebugLevel) { buf := bytes.Buffer{} for _, node := range route.Path() { - fmt.Fprintf(&buf, "%s@%s > ", node.Name(), node.Addr()) + fmt.Fprintf(&buf, "%s@%s > ", node.Name, node.Addr) } fmt.Fprintf(&buf, "%s", address) r.Logger.Debugf("route(retry=%d) %s", i, buf.String()) @@ -114,7 +114,7 @@ func (r *Router) Bind(ctx context.Context, network, address string, opts ...conn if r.Logger.IsLevelEnabled(logger.DebugLevel) { buf := bytes.Buffer{} for _, node := range route.Path() { - fmt.Fprintf(&buf, "%s@%s > ", node.Name(), node.Addr()) + fmt.Fprintf(&buf, "%s@%s > ", node.Name, node.Addr) } fmt.Fprintf(&buf, "%s", address) r.Logger.Debugf("route(retry=%d) %s", i, buf.String()) diff --git a/pkg/chain/selector.go b/pkg/chain/selector.go index 5c96338..2ad0a03 100644 --- a/pkg/chain/selector.go +++ b/pkg/chain/selector.go @@ -141,8 +141,8 @@ func (f *failFilter) Filter(nodes ...*Node) []*Node { } var nl []*Node for _, node := range nodes { - if node.Marker().FailCount() < int64(maxFails) || - time.Since(time.Unix(node.Marker().FailTime(), 0)) >= failTimeout { + if node.Marker.FailCount() < int64(maxFails) || + time.Since(time.Unix(node.Marker.FailTime(), 0)) >= failTimeout { nl = append(nl, node) } } @@ -161,7 +161,7 @@ func InvalidFilter() Filter { func (f *invalidFilter) Filter(nodes ...*Node) []*Node { var nl []*Node for _, node := range nodes { - _, sport, _ := net.SplitHostPort(node.Addr()) + _, sport, _ := net.SplitHostPort(node.Addr) if port, _ := strconv.Atoi(sport); port > 0 { nl = append(nl, node) } diff --git a/pkg/config/config.go b/pkg/config/config.go index 95fe4cf..cd683d7 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -93,9 +93,6 @@ type HandlerConfig struct { Type string Retries int `yaml:",omitempty"` Chain string `yaml:",omitempty"` - Bypass string `yaml:",omitempty"` - Resolver string `yaml:",omitempty"` - Hosts string `yaml:",omitempty"` Auths []*AuthConfig `yaml:",omitempty"` TLS *TLSConfig `yaml:",omitempty"` Metadata map[string]interface{} `yaml:",omitempty"` @@ -123,6 +120,9 @@ type ConnectorConfig struct { type ServiceConfig struct { Name string Addr string `yaml:",omitempty"` + Bypass string `yaml:",omitempty"` + Resolver string `yaml:",omitempty"` + Hosts string `yaml:",omitempty"` Handler *HandlerConfig `yaml:",omitempty"` Listener *ListenerConfig `yaml:",omitempty"` Forwarder *ForwarderConfig `yaml:",omitempty"` @@ -137,6 +137,9 @@ type ChainConfig struct { type HopConfig struct { Name string Selector *SelectorConfig `yaml:",omitempty"` + Bypass string `yaml:",omitempty"` + Resolver string `yaml:",omitempty"` + Hosts string `yaml:",omitempty"` Nodes []*NodeConfig } @@ -144,6 +147,8 @@ type NodeConfig struct { Name string Addr string `yaml:",omitempty"` Bypass string `yaml:",omitempty"` + Resolver string `yaml:",omitempty"` + Hosts string `yaml:",omitempty"` Connector *ConnectorConfig `yaml:",omitempty"` Dialer *DialerConfig `yaml:",omitempty"` } diff --git a/pkg/handler/forward/local/handler.go b/pkg/handler/forward/local/handler.go index 657df1a..8e9d93f 100644 --- a/pkg/handler/forward/local/handler.go +++ b/pkg/handler/forward/local/handler.go @@ -43,7 +43,7 @@ func (h *forwardHandler) Init(md md.Metadata) (err error) { if h.group == nil { // dummy node used by relay connector. - h.group = chain.NewNodeGroup(chain.NewNode("dummy", ":0")) + h.group = chain.NewNodeGroup(&chain.Node{Name: "dummy", Addr: ":0"}) } h.router = &chain.Router{ @@ -90,26 +90,26 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn) { } log = log.WithFields(map[string]interface{}{ - "dst": fmt.Sprintf("%s/%s", target.Addr(), network), + "dst": fmt.Sprintf("%s/%s", target.Addr, network), }) - log.Infof("%s >> %s", conn.RemoteAddr(), target.Addr()) + log.Infof("%s >> %s", conn.RemoteAddr(), target.Addr) - cc, err := h.router.Dial(ctx, network, target.Addr()) + cc, err := h.router.Dial(ctx, network, target.Addr) if err != nil { log.Error(err) // TODO: the router itself may be failed due to the failed node in the router, // the dead marker may be a wrong operation. - target.Marker().Mark() + target.Marker.Mark() return } defer cc.Close() - target.Marker().Reset() + target.Marker.Reset() t := time.Now() - log.Infof("%s <-> %s", conn.RemoteAddr(), target.Addr()) + log.Infof("%s <-> %s", conn.RemoteAddr(), target.Addr) handler.Transport(conn, cc) log.WithFields(map[string]interface{}{ "duration": time.Since(t), - }).Infof("%s >-< %s", conn.RemoteAddr(), target.Addr()) + }).Infof("%s >-< %s", conn.RemoteAddr(), target.Addr) } diff --git a/pkg/handler/forward/remote/handler.go b/pkg/handler/forward/remote/handler.go index 80c22c4..a08b45d 100644 --- a/pkg/handler/forward/remote/handler.go +++ b/pkg/handler/forward/remote/handler.go @@ -84,26 +84,26 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn) { } log = log.WithFields(map[string]interface{}{ - "dst": fmt.Sprintf("%s/%s", target.Addr(), network), + "dst": fmt.Sprintf("%s/%s", target.Addr, network), }) - log.Infof("%s >> %s", conn.RemoteAddr(), target.Addr()) + log.Infof("%s >> %s", conn.RemoteAddr(), target.Addr) - cc, err := h.router.Dial(ctx, network, target.Addr()) + cc, err := h.router.Dial(ctx, network, target.Addr) if err != nil { log.Error(err) // TODO: the router itself may be failed due to the failed node in the router, // the dead marker may be a wrong operation. - target.Marker().Mark() + target.Marker.Mark() return } defer cc.Close() - target.Marker().Reset() + target.Marker.Reset() t := time.Now() - log.Infof("%s <-> %s", conn.RemoteAddr(), target.Addr()) + log.Infof("%s <-> %s", conn.RemoteAddr(), target.Addr) handler.Transport(conn, cc) log.WithFields(map[string]interface{}{ "duration": time.Since(t), - }).Infof("%s >-< %s", conn.RemoteAddr(), target.Addr()) + }).Infof("%s >-< %s", conn.RemoteAddr(), target.Addr) } diff --git a/pkg/handler/relay/forward.go b/pkg/handler/relay/forward.go index 11f0b3d..79cc98e 100644 --- a/pkg/handler/relay/forward.go +++ b/pkg/handler/relay/forward.go @@ -25,17 +25,17 @@ func (h *relayHandler) handleForward(ctx context.Context, conn net.Conn, network } log = log.WithFields(map[string]interface{}{ - "dst": fmt.Sprintf("%s/%s", target.Addr(), network), + "dst": fmt.Sprintf("%s/%s", target.Addr, network), "cmd": "forward", }) - log.Infof("%s >> %s", conn.RemoteAddr(), target.Addr()) + log.Infof("%s >> %s", conn.RemoteAddr(), target.Addr) - cc, err := h.router.Dial(ctx, network, target.Addr()) + cc, err := h.router.Dial(ctx, network, target.Addr) if err != nil { // TODO: the router itself may be failed due to the failed node in the router, // the dead marker may be a wrong operation. - target.Marker().Mark() + target.Marker.Mark() resp.Status = relay.StatusHostUnreachable resp.WriteTo(conn) @@ -44,7 +44,7 @@ func (h *relayHandler) handleForward(ctx context.Context, conn net.Conn, network return } defer cc.Close() - target.Marker().Reset() + target.Marker.Reset() if h.md.noDelay { if _, err := resp.WriteTo(conn); err != nil { @@ -79,9 +79,9 @@ func (h *relayHandler) handleForward(ctx context.Context, conn net.Conn, network } t := time.Now() - log.Infof("%s <-> %s", conn.RemoteAddr(), target.Addr()) + log.Infof("%s <-> %s", conn.RemoteAddr(), target.Addr) handler.Transport(conn, cc) log.WithFields(map[string]interface{}{ "duration": time.Since(t), - }).Infof("%s >-< %s", conn.RemoteAddr(), target.Addr()) + }).Infof("%s >-< %s", conn.RemoteAddr(), target.Addr) } diff --git a/pkg/handler/tap/handler.go b/pkg/handler/tap/handler.go index 44f9068..dc7b405 100644 --- a/pkg/handler/tap/handler.go +++ b/pkg/handler/tap/handler.go @@ -109,7 +109,7 @@ func (h *tapHandler) Handle(ctx context.Context, conn net.Conn) { target := h.group.Next() if target != nil { - raddr, err = net.ResolveUDPAddr(network, target.Addr()) + raddr, err = net.ResolveUDPAddr(network, target.Addr) if err != nil { log.Error(err) return @@ -117,7 +117,7 @@ func (h *tapHandler) Handle(ctx context.Context, conn net.Conn) { log = log.WithFields(map[string]interface{}{ "dst": fmt.Sprintf("%s/%s", raddr.String(), raddr.Network()), }) - log.Infof("%s >> %s", conn.RemoteAddr(), target.Addr()) + log.Infof("%s >> %s", conn.RemoteAddr(), target.Addr) } h.handleLoop(ctx, conn, raddr, cc.Config(), log) diff --git a/pkg/handler/tun/handler.go b/pkg/handler/tun/handler.go index c87e0b0..4b6b2f3 100644 --- a/pkg/handler/tun/handler.go +++ b/pkg/handler/tun/handler.go @@ -112,7 +112,7 @@ func (h *tunHandler) Handle(ctx context.Context, conn net.Conn) { target := h.group.Next() if target != nil { - raddr, err = net.ResolveUDPAddr(network, target.Addr()) + raddr, err = net.ResolveUDPAddr(network, target.Addr) if err != nil { log.Error(err) return @@ -120,7 +120,7 @@ func (h *tunHandler) Handle(ctx context.Context, conn net.Conn) { log = log.WithFields(map[string]interface{}{ "dst": fmt.Sprintf("%s/%s", raddr.String(), raddr.Network()), }) - log.Infof("%s >> %s", conn.RemoteAddr(), target.Addr()) + log.Infof("%s >> %s", conn.RemoteAddr(), target.Addr) } h.handleLoop(ctx, conn, raddr, cc.Config(), log)