diff --git a/cmd/gost/config.go b/cmd/gost/config.go index 729e75b..b2e8137 100644 --- a/cmd/gost/config.go +++ b/cmd/gost/config.go @@ -3,6 +3,7 @@ package main import ( "io" "os" + "strings" "github.com/go-gost/gost/pkg/bypass" "github.com/go-gost/gost/pkg/chain" @@ -62,6 +63,11 @@ func buildService(cfg *config.Config) (services []*service.Service) { handler.BypassOption(bypasses[svc.Bypass]), handler.LoggerOption(handlerLogger), ) + + if forwarder, ok := h.(handler.Forwarder); ok { + forwarder.Forward(forwarderFromConfig(svc.Forwarder)) + } + if err := h.Init(metadata.MapMetadata(svc.Handler.Metadata)); err != nil { handlerLogger.Fatal("init: ", err) } @@ -85,7 +91,7 @@ func chainFromConfig(cfg *config.ChainConfig) *chain.Chain { c := &chain.Chain{} - selector := selectorFromConfig(cfg.LB) + selector := selectorFromConfig(cfg.Selector) for _, hop := range cfg.Hops { group := &chain.NodeGroup{} for _, v := range hop.Nodes { @@ -127,7 +133,7 @@ func chainFromConfig(cfg *config.ChainConfig) *chain.Chain { } sel := selector - if s := selectorFromConfig(hop.LB); s != nil { + if s := selectorFromConfig(hop.Selector); s != nil { sel = s } group.WithSelector(sel) @@ -162,7 +168,7 @@ func logFromConfig(cfg *config.LogConfig) logger.Logger { return logger.NewLogger(opts...) } -func selectorFromConfig(cfg *config.LoadbalancingConfig) chain.Selector { +func selectorFromConfig(cfg *config.SelectorConfig) chain.Selector { if cfg == nil { return nil } @@ -173,7 +179,7 @@ func selectorFromConfig(cfg *config.LoadbalancingConfig) chain.Selector { strategy = chain.RoundRobinStrategy() case "random": strategy = chain.RandomStrategy() - case "fifio": + case "fifo": strategy = chain.FIFOStrategy() default: strategy = chain.RoundRobinStrategy() @@ -190,6 +196,19 @@ func bypassFromConfig(cfg *config.BypassConfig) bypass.Bypass { if cfg == nil { return nil } - return bypass.NewBypassPatterns(cfg.Reverse, cfg.Matchers...) } + +func forwarderFromConfig(cfg *config.ForwarderConfig) *chain.NodeGroup { + if cfg == nil { + return nil + } + + group := &chain.NodeGroup{} + for _, target := range cfg.Targets { + if v := strings.TrimSpace(target); v != "" { + group.AddNode(chain.NewNode(target, target)) + } + } + return group.WithSelector(selectorFromConfig(cfg.Selector)) +} diff --git a/cmd/gost/gost.yml b/cmd/gost/gost.yml index ac6cfcd..4709e92 100644 --- a/cmd/gost/gost.yml +++ b/cmd/gost/gost.yml @@ -10,7 +10,7 @@ profiling: services: - name: http+tcp url: "http://gost:gost@:8000" - addr: ":8000" + addr: ":28000" handler: type: http metadata: @@ -27,38 +27,27 @@ services: keepAlive: 15s chain: chain01 # bypass: bypass01 -- name: ss+tcp +- name: ss url: "ss://chacha20:gost@:8000" - addr: ":8338" + addr: ":28338" handler: type: ss metadata: - method: AES-256-GCM + method: chacha20-ietf password: gost readTimeout: 5s retry: 3 + udp: true + bufferSize: 4096 listener: type: tcp metadata: keepAlive: 15s - chain: chain01 + # chain: chain01 # bypass: bypass01 -- name: ssu - url: "ss://chacha20:gost@:8000" - addr: ":8388" - handler: - type: ssu - metadata: - # method: AES-256-GCM - # password: gost - readTimeout: 5s - retry: 3 - listener: - type: tcp - # chain: chain-ssu -- name: socks5+tcp +- name: socks5 url: "socks5://gost:gost@:1080" - addr: ":1080" + addr: ":21080" handler: type: socks5 metadata: @@ -72,11 +61,11 @@ services: type: tcp metadata: keepAlive: 15s - chain: chain-socks5 + chain: chain-ss # bypass: bypass01 - name: socks5+tcp url: "socks5://gost:gost@:1080" - addr: ":11080" + addr: ":21081" handler: type: socks5 metadata: @@ -90,18 +79,40 @@ services: type: tcp metadata: keepAlive: 15s +- name: forward + url: "socks5://gost:gost@:1080" + addr: ":10053" + forwarder: + targets: + - 192.168.8.8:53 + - 192.168.8.1:53 + - 1.1.1.1:53 + selector: + strategy: fifo + maxFails: 1 + failTimeout: 30s + handler: + type: forward + metadata: + readTimeout: 5s + retry: 3 + listener: + type: udp + metadata: + keepAlive: 15s + chain: chain-ss chains: - name: chain01 - # chain level load balancing - lb: + # chain level selector + selector: strategy: round maxFails: 1 failTimeout: 30s hops: - name: hop01 - # hop level load balancing - lb: + # hop level selector + selector: strategy: round maxFails: 1 failTimeout: 30s @@ -131,8 +142,8 @@ chains: type: tcp metadata: {} - name: hop02 - # hop level load balancing - lb: + # hop level selector + selector: strategy: round maxFails: 1 failTimeout: 30s @@ -179,19 +190,25 @@ chains: dialer: type: tcp metadata: {} -- name: chain-ssu +- name: chain-ss hops: - name: hop01 nodes: - name: node01 - addr: ":8339" + addr: ":28338" url: "http://gost:gost@:8081" # bypass: bypass01 connector: - type: ssu - metadata: {} + type: ss + metadata: + method: chacha20-ietf + password: gost + readTimeout: 5s + nodelay: true + udp: true + bufferSize: 4096 dialer: - type: udp + type: tcp metadata: {} bypasses: diff --git a/cmd/gost/register.go b/cmd/gost/register.go index 095e445..a2ca242 100644 --- a/cmd/gost/register.go +++ b/cmd/gost/register.go @@ -6,18 +6,17 @@ import ( _ "github.com/go-gost/gost/pkg/connector/socks/v4" _ "github.com/go-gost/gost/pkg/connector/socks/v5" _ "github.com/go-gost/gost/pkg/connector/ss" - _ "github.com/go-gost/gost/pkg/connector/ssu" // Register dialers _ "github.com/go-gost/gost/pkg/dialer/tcp" _ "github.com/go-gost/gost/pkg/dialer/udp" // Register handlers + _ "github.com/go-gost/gost/pkg/handler/forward/local" _ "github.com/go-gost/gost/pkg/handler/http" _ "github.com/go-gost/gost/pkg/handler/socks/v4" _ "github.com/go-gost/gost/pkg/handler/socks/v5" _ "github.com/go-gost/gost/pkg/handler/ss" - _ "github.com/go-gost/gost/pkg/handler/ssu" // Register listeners _ "github.com/go-gost/gost/pkg/listener/ftcp" diff --git a/pkg/chain/node.go b/pkg/chain/node.go index ecb5002..d4c011f 100644 --- a/pkg/chain/node.go +++ b/pkg/chain/node.go @@ -12,14 +12,14 @@ type Node struct { addr string transport *Transport bypass bypass.Bypass - marker *failMarker + marker *FailMarker } func NewNode(name, addr string) *Node { return &Node{ name: name, addr: addr, - marker: &failMarker{}, + marker: &FailMarker{}, } } @@ -31,6 +31,10 @@ func (node *Node) Addr() string { return node.addr } +func (node *Node) Marker() *FailMarker { + return node.marker +} + func (node *Node) WithTransport(tr *Transport) *Node { node.transport = tr return node @@ -80,13 +84,13 @@ func (g *NodeGroup) Next() *Node { return selector.Select(g.nodes...) } -type failMarker struct { +type FailMarker struct { failTime int64 failCount uint32 mux sync.RWMutex } -func (m *failMarker) FailTime() int64 { +func (m *FailMarker) FailTime() int64 { if m == nil { return 0 } @@ -97,7 +101,7 @@ func (m *failMarker) FailTime() int64 { return m.failTime } -func (m *failMarker) FailCount() uint32 { +func (m *FailMarker) FailCount() uint32 { if m == nil { return 0 } @@ -108,7 +112,7 @@ func (m *failMarker) FailCount() uint32 { return m.failCount } -func (m *failMarker) Mark() { +func (m *FailMarker) Mark() { if m == nil { return } @@ -120,7 +124,7 @@ func (m *failMarker) Mark() { m.failCount++ } -func (m *failMarker) Reset() { +func (m *FailMarker) Reset() { if m == nil { return } diff --git a/pkg/chain/route.go b/pkg/chain/route.go index d752fd1..db4cc19 100644 --- a/pkg/chain/route.go +++ b/pkg/chain/route.go @@ -26,33 +26,33 @@ func (r *Route) Connect(ctx context.Context) (conn net.Conn, err error) { node := r.nodes[0] cc, err := node.transport.Dial(ctx, r.nodes[0].Addr()) if err != nil { - node.marker.Mark() + node.Marker().Mark() return } cn, err := node.transport.Handshake(ctx, cc) if err != nil { cc.Close() - node.marker.Mark() + node.Marker().Mark() return } - node.marker.Reset() + node.Marker().Reset() preNode := node for _, node := range r.nodes[1:] { cc, err = preNode.transport.Connect(ctx, cn, "tcp", node.Addr()) if err != nil { cn.Close() - node.marker.Mark() + node.Marker().Mark() return } cc, err = node.transport.Handshake(ctx, cc) if err != nil { cn.Close() - node.marker.Mark() + node.Marker().Mark() return } - node.marker.Reset() + node.Marker().Reset() cn = cc preNode = node @@ -89,7 +89,7 @@ func (r *Route) dialDirect(ctx context.Context, network, address string) (net.Co default: } - d := &net.Dialer{} + d := net.Dialer{} return d.DialContext(ctx, network, address) } diff --git a/pkg/chain/selector.go b/pkg/chain/selector.go index d6335ab..8c259db 100644 --- a/pkg/chain/selector.go +++ b/pkg/chain/selector.go @@ -15,10 +15,6 @@ const ( DefaultFailTimeout = 30 * time.Second ) -var ( - defaultSelector Selector = NewSelector(nil) -) - type Selector interface { Select(nodes ...*Node) *Node } @@ -145,8 +141,8 @@ func (f *failFilter) Filter(nodes ...*Node) []*Node { } var nl []*Node for _, node := range nodes { - if node.marker.FailCount() < uint32(maxFails) || - time.Since(time.Unix(node.marker.FailTime(), 0)) >= failTimeout { + if node.Marker().FailCount() < uint32(maxFails) || + time.Since(time.Unix(node.Marker().FailTime(), 0)) >= failTimeout { nl = append(nl, node) } } diff --git a/pkg/config/config.go b/pkg/config/config.go index ec37197..a4c72b9 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -29,7 +29,7 @@ type ProfilingConfig struct { Enabled bool } -type LoadbalancingConfig struct { +type SelectorConfig struct { Strategy string MaxFails int FailTimeout time.Duration @@ -50,6 +50,11 @@ type HandlerConfig struct { Metadata map[string]interface{} } +type ForwarderConfig struct { + Targets []string + Selector *SelectorConfig +} + type DialerConfig struct { Type string Metadata map[string]interface{} @@ -61,25 +66,26 @@ type ConnectorConfig struct { } type ServiceConfig struct { - Name string - URL string - Addr string - Listener *ListenerConfig - Handler *HandlerConfig - Chain string - Bypass string + Name string + URL string + Addr string + Listener *ListenerConfig + Handler *HandlerConfig + Forwarder *ForwarderConfig + Chain string + Bypass string } type ChainConfig struct { - Name string - LB *LoadbalancingConfig - Hops []HopConfig + Name string + Selector *SelectorConfig + Hops []HopConfig } type HopConfig struct { - Name string - LB *LoadbalancingConfig - Nodes []NodeConfig + Name string + Selector *SelectorConfig + Nodes []NodeConfig } type NodeConfig struct { diff --git a/pkg/connector/http/connector.go b/pkg/connector/http/connector.go index bfcb54b..27de560 100644 --- a/pkg/connector/http/connector.go +++ b/pkg/connector/http/connector.go @@ -49,6 +49,7 @@ func (c *httpConnector) Connect(ctx context.Context, conn net.Conn, network, add "network": network, "address": address, }) + c.logger.Infof("connect: %s/%s", address, network) switch network { case "tcp", "tcp4", "tcp6": @@ -71,8 +72,6 @@ func (c *httpConnector) Connect(ctx context.Context, conn net.Conn, network, add } req.Header.Set("Proxy-Connection", "keep-alive") - c.logger.Infof("connect: ", address) - if user := c.md.User; user != nil { u := user.Username() p, _ := user.Password() diff --git a/pkg/connector/socks/v4/connector.go b/pkg/connector/socks/v4/connector.go index fa02303..512d6bd 100644 --- a/pkg/connector/socks/v4/connector.go +++ b/pkg/connector/socks/v4/connector.go @@ -2,6 +2,7 @@ package v4 import ( "context" + "errors" "fmt" "net" "net/url" @@ -47,6 +48,7 @@ func (c *socks4Connector) Connect(ctx context.Context, conn net.Conn, network, a "network": network, "address": address, }) + c.logger.Infof("connect: %s/%s", address, network) switch network { case "tcp", "tcp4", "tcp6": @@ -56,8 +58,6 @@ func (c *socks4Connector) Connect(ctx context.Context, conn net.Conn, network, a return nil, err } - c.logger.Info("connect: ", address) - var addr *gosocks4.Addr if c.md.disable4a { @@ -107,7 +107,9 @@ func (c *socks4Connector) Connect(ctx context.Context, conn net.Conn, network, a c.logger.Debug(reply) if reply.Code != gosocks4.Granted { - return nil, fmt.Errorf("error: %d", reply.Code) + err = errors.New("host unreachable") + c.logger.Error(err) + return nil, err } return conn, nil diff --git a/pkg/connector/socks/v5/connector.go b/pkg/connector/socks/v5/connector.go index b4d946b..dc38f04 100644 --- a/pkg/connector/socks/v5/connector.go +++ b/pkg/connector/socks/v5/connector.go @@ -67,6 +67,7 @@ func (c *socks5Connector) Init(md md.Metadata) (err error) { return } +// Handshake implements connector.Handshaker. func (c *socks5Connector) Handshake(ctx context.Context, conn net.Conn) (net.Conn, error) { c.logger = c.logger.WithFields(map[string]interface{}{ "remote": conn.RemoteAddr().String(), @@ -92,17 +93,18 @@ func (c *socks5Connector) Connect(ctx context.Context, conn net.Conn, network, a "network": network, "address": address, }) + c.logger.Infof("connect: %s/%s", address, network) switch network { + case "udp", "udp4", "udp6": + return c.connectUDP(ctx, conn, network, address) case "tcp", "tcp4", "tcp6": default: - err := fmt.Errorf("network %s unsupported, should be tcp, tcp4 or tcp6", network) + err := fmt.Errorf("network %s unsupported", network) c.logger.Error(err) return nil, err } - c.logger.Info("connect: ", address) - addr := gosocks5.Addr{} if err := addr.ParseFrom(address); err != nil { c.logger.Error(err) @@ -129,12 +131,48 @@ func (c *socks5Connector) Connect(ctx context.Context, conn net.Conn, network, a c.logger.Debug(reply) if reply.Rep != gosocks5.Succeeded { - return nil, errors.New("service unavailable") + err = errors.New("host unreachable") + c.logger.Error(err) + return nil, err } return conn, nil } +func (c *socks5Connector) connectUDP(ctx context.Context, conn net.Conn, network, address string) (net.Conn, error) { + addr, err := net.ResolveUDPAddr(network, address) + if err != nil { + c.logger.Error(err) + return nil, err + } + + req := gosocks5.NewRequest(socks.CmdUDPTun, nil) + if err := req.Write(conn); err != nil { + c.logger.Error(err) + return nil, err + } + c.logger.Debug(req) + + reply, err := gosocks5.ReadReply(conn) + if err != nil { + c.logger.Error(err) + return nil, err + } + c.logger.Debug(reply) + + if reply.Rep != gosocks5.Succeeded { + return nil, errors.New("get socks5 UDP tunnel failure") + } + + baddr, err := net.ResolveUDPAddr("udp", reply.Addr.String()) + if err != nil { + return nil, err + } + c.logger.Debugf("associate on %s OK", baddr) + + return socks.UDPTunClientConn(conn, addr), nil +} + func (c *socks5Connector) parseMetadata(md md.Metadata) (err error) { if v := md.GetString(auth); v != "" { ss := strings.SplitN(v, ":", 2) diff --git a/pkg/connector/ss/connector.go b/pkg/connector/ss/connector.go index 92dd8f0..8081cda 100644 --- a/pkg/connector/ss/connector.go +++ b/pkg/connector/ss/connector.go @@ -2,6 +2,7 @@ package ss import ( "context" + "errors" "fmt" "net" "time" @@ -9,6 +10,7 @@ import ( "github.com/go-gost/gosocks5" "github.com/go-gost/gost/pkg/connector" "github.com/go-gost/gost/pkg/internal/bufpool" + "github.com/go-gost/gost/pkg/internal/utils/socks" "github.com/go-gost/gost/pkg/internal/utils/ss" "github.com/go-gost/gost/pkg/logger" md "github.com/go-gost/gost/pkg/metadata" @@ -46,15 +48,23 @@ func (c *ssConnector) Connect(ctx context.Context, conn net.Conn, network, addre "network": network, "address": address, }) + c.logger.Infof("connect: %s/%s", address, network) switch network { case "tcp", "tcp4", "tcp6": + case "udp", "udp4", "udp6": + if c.md.enableUDP { + return c.connectUDP(ctx, conn, network, address) + } else { + err := errors.New("UDP relay is disabled") + c.logger.Error(err) + return nil, err + } default: - err := fmt.Errorf("network %s unsupported, should be tcp, tcp4 or tcp6", network) + err := fmt.Errorf("network %s unsupported", network) c.logger.Error(err) return nil, err } - c.logger.Infof("connect: ", address) addr := gosocks5.Addr{} if err := addr.ParseFrom(address); err != nil { @@ -94,18 +104,28 @@ func (c *ssConnector) Connect(ctx context.Context, conn net.Conn, network, addre return sc, nil } -func (c *ssConnector) parseMetadata(md md.Metadata) (err error) { - c.md.cipher, err = ss.ShadowCipher( - md.GetString(method), - md.GetString(password), - md.GetString(key), - ) - if err != nil { - return +func (c *ssConnector) connectUDP(ctx context.Context, conn net.Conn, network, address string) (net.Conn, error) { + if c.md.connectTimeout > 0 { + conn.SetDeadline(time.Now().Add(c.md.connectTimeout)) + defer conn.SetDeadline(time.Time{}) } - c.md.connectTimeout = md.GetDuration(connectTimeout) - c.md.noDelay = md.GetBool(noDelay) + taddr, _ := net.ResolveUDPAddr(network, address) + if taddr == nil { + taddr = &net.UDPAddr{} + } - return + pc, ok := conn.(net.PacketConn) + if ok { + if c.md.cipher != nil { + pc = c.md.cipher.PacketConn(pc) + } + + return ss.UDPClientConn(pc, conn.RemoteAddr(), taddr, c.md.udpBufferSize), nil + } + + if c.md.cipher != nil { + conn = ss.ShadowConn(c.md.cipher.StreamConn(conn), nil) + } + return socks.UDPTunClientConn(conn, taddr), nil } diff --git a/pkg/connector/ss/metadata.go b/pkg/connector/ss/metadata.go index 1b01380..415d624 100644 --- a/pkg/connector/ss/metadata.go +++ b/pkg/connector/ss/metadata.go @@ -3,19 +3,53 @@ package ss import ( "time" + "github.com/go-gost/gost/pkg/internal/utils/ss" + md "github.com/go-gost/gost/pkg/metadata" "github.com/shadowsocks/go-shadowsocks2/core" ) -const ( - method = "method" - password = "password" - key = "key" - connectTimeout = "timeout" - noDelay = "noDelay" -) - type metadata struct { cipher core.Cipher connectTimeout time.Duration noDelay bool + enableUDP bool + udpBufferSize int +} + +func (c *ssConnector) parseMetadata(md md.Metadata) (err error) { + const ( + method = "method" + password = "password" + key = "key" + connectTimeout = "timeout" + noDelay = "noDelay" + enableUDP = "udp" // enable UDP relay + udpBufferSize = "udpBufferSize" // udp buffer size + ) + + c.md.cipher, err = ss.ShadowCipher( + md.GetString(method), + md.GetString(password), + md.GetString(key), + ) + if err != nil { + return + } + + c.md.connectTimeout = md.GetDuration(connectTimeout) + c.md.noDelay = md.GetBool(noDelay) + c.md.enableUDP = md.GetBool(enableUDP) + + if c.md.udpBufferSize > 0 { + if c.md.udpBufferSize < 512 { + c.md.udpBufferSize = 512 + } + if c.md.udpBufferSize > 65*1024 { + c.md.udpBufferSize = 65 * 1024 + } + } else { + c.md.udpBufferSize = 4096 + } + + return } diff --git a/pkg/connector/ssu/connector.go b/pkg/connector/ssu/connector.go deleted file mode 100644 index 69dcf3e..0000000 --- a/pkg/connector/ssu/connector.go +++ /dev/null @@ -1,105 +0,0 @@ -package ssu - -import ( - "context" - "fmt" - "net" - "time" - - "github.com/go-gost/gost/pkg/connector" - "github.com/go-gost/gost/pkg/internal/utils/socks" - "github.com/go-gost/gost/pkg/internal/utils/ss" - "github.com/go-gost/gost/pkg/logger" - md "github.com/go-gost/gost/pkg/metadata" - "github.com/go-gost/gost/pkg/registry" -) - -func init() { - registry.RegiserConnector("ssu", NewConnector) -} - -type ssuConnector struct { - md metadata - logger logger.Logger -} - -func NewConnector(opts ...connector.Option) connector.Connector { - options := &connector.Options{} - for _, opt := range opts { - opt(options) - } - - return &ssuConnector{ - logger: options.Logger, - } -} - -func (c *ssuConnector) Init(md md.Metadata) (err error) { - return c.parseMetadata(md) -} - -func (c *ssuConnector) Connect(ctx context.Context, conn net.Conn, network, address string, opts ...connector.ConnectOption) (net.Conn, error) { - c.logger = c.logger.WithFields(map[string]interface{}{ - "remote": conn.RemoteAddr().String(), - "local": conn.LocalAddr().String(), - "network": network, - "address": address, - }) - - switch network { - case "udp", "udp4", "udp6": - default: - err := fmt.Errorf("network %s unsupported, should be udp, udp4 or udp6", network) - c.logger.Error(err) - return nil, err - } - - c.logger.Info("connect: ", address) - - if c.md.connectTimeout > 0 { - conn.SetDeadline(time.Now().Add(c.md.connectTimeout)) - defer conn.SetDeadline(time.Time{}) - } - - taddr, _ := net.ResolveUDPAddr(network, address) - if taddr == nil { - taddr = &net.UDPAddr{} - } - - pc, ok := conn.(net.PacketConn) - if ok { - if c.md.cipher != nil { - pc = c.md.cipher.PacketConn(pc) - } - - return ss.UDPClientConn(pc, conn.RemoteAddr(), taddr, c.md.bufferSize), nil - } - - return socks.UDPTunClientConn(conn, taddr), nil -} - -func (c *ssuConnector) parseMetadata(md md.Metadata) (err error) { - c.md.cipher, err = ss.ShadowCipher( - md.GetString(method), - md.GetString(password), - md.GetString(key), - ) - if err != nil { - return - } - - c.md.connectTimeout = md.GetDuration(connectTimeout) - c.md.bufferSize = md.GetInt(bufferSize) - if c.md.bufferSize > 0 { - if c.md.bufferSize < 512 { - c.md.bufferSize = 512 - } - if c.md.bufferSize > 65*1024 { - c.md.bufferSize = 65 * 1024 - } - } else { - c.md.bufferSize = 4096 - } - - return -} diff --git a/pkg/connector/ssu/metadata.go b/pkg/connector/ssu/metadata.go deleted file mode 100644 index 037f611..0000000 --- a/pkg/connector/ssu/metadata.go +++ /dev/null @@ -1,21 +0,0 @@ -package ssu - -import ( - "time" - - "github.com/shadowsocks/go-shadowsocks2/core" -) - -const ( - method = "method" - password = "password" - key = "key" - connectTimeout = "timeout" - bufferSize = "bufferSize" -) - -type metadata struct { - cipher core.Cipher - connectTimeout time.Duration - bufferSize int -} diff --git a/pkg/handler/forward/local/handler.go b/pkg/handler/forward/local/handler.go new file mode 100644 index 0000000..c68768e --- /dev/null +++ b/pkg/handler/forward/local/handler.go @@ -0,0 +1,113 @@ +package local + +import ( + "context" + "net" + "time" + + "github.com/go-gost/gost/pkg/bypass" + "github.com/go-gost/gost/pkg/chain" + "github.com/go-gost/gost/pkg/handler" + "github.com/go-gost/gost/pkg/logger" + md "github.com/go-gost/gost/pkg/metadata" + "github.com/go-gost/gost/pkg/registry" +) + +func init() { + registry.RegisterHandler("forward", NewHandler) +} + +type localForwardHandler struct { + group *chain.NodeGroup + chain *chain.Chain + bypass bypass.Bypass + logger logger.Logger + md metadata +} + +func NewHandler(opts ...handler.Option) handler.Handler { + options := &handler.Options{} + for _, opt := range opts { + opt(options) + } + + return &localForwardHandler{ + chain: options.Chain, + bypass: options.Bypass, + logger: options.Logger, + } +} + +func (h *localForwardHandler) Init(md md.Metadata) (err error) { + return h.parseMetadata(md) +} + +// Forward implements handler.Forwarder. +func (h *localForwardHandler) Forward(group *chain.NodeGroup) { + h.group = group +} + +func (h *localForwardHandler) Handle(ctx context.Context, conn net.Conn) { + defer conn.Close() + + start := time.Now() + h.logger = h.logger.WithFields(map[string]interface{}{ + "remote": conn.RemoteAddr().String(), + "local": conn.LocalAddr().String(), + }) + + h.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr()) + defer func() { + h.logger.WithFields(map[string]interface{}{ + "duration": time.Since(start), + }).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr()) + }() + + target := h.group.Next() + if target == nil { + h.logger.Error("no target available") + return + } + + h.logger = h.logger.WithFields(map[string]interface{}{ + "dst": target.Addr(), + }) + + h.logger.Infof("%s >> %s", conn.RemoteAddr(), target.Addr()) + + r := (&handler.Router{}). + WithChain(h.chain). + WithRetry(h.md.retryCount). + WithLogger(h.logger) + + network := "tcp" + if _, ok := conn.(net.PacketConn); ok { + network = "udp" + } + + cc, err := r.Dial(ctx, network, target.Addr()) + if err != nil { + h.logger.Error(err) + // TODO: the router itself may be failed due to the failed node in the router, + // the dead marker may be a wrong operation. + target.Marker().Mark() + return + } + defer cc.Close() + target.Marker().Reset() + + t := time.Now() + h.logger.Infof("%s <-> %s", conn.RemoteAddr(), target.Addr()) + handler.Transport(conn, cc) + h.logger. + WithFields(map[string]interface{}{ + "duration": time.Since(t), + }). + Infof("%s >-< %s", conn.RemoteAddr(), target.Addr()) +} + +func (h *localForwardHandler) parseMetadata(md md.Metadata) (err error) { + h.md.readTimeout = md.GetDuration(readTimeout) + h.md.retryCount = md.GetInt(retryCount) + return +} diff --git a/pkg/handler/forward/local/metadata.go b/pkg/handler/forward/local/metadata.go new file mode 100644 index 0000000..617a55b --- /dev/null +++ b/pkg/handler/forward/local/metadata.go @@ -0,0 +1,15 @@ +package local + +import ( + "time" +) + +const ( + readTimeout = "readTimeout" + retryCount = "retry" +) + +type metadata struct { + readTimeout time.Duration + retryCount int +} diff --git a/pkg/handler/handler.go b/pkg/handler/handler.go index 610c868..1016d91 100644 --- a/pkg/handler/handler.go +++ b/pkg/handler/handler.go @@ -4,6 +4,7 @@ import ( "context" "net" + "github.com/go-gost/gost/pkg/chain" "github.com/go-gost/gost/pkg/metadata" ) @@ -11,3 +12,7 @@ type Handler interface { Init(metadata.Metadata) error Handle(context.Context, net.Conn) } + +type Forwarder interface { + Forward(*chain.NodeGroup) +} diff --git a/pkg/handler/router.go b/pkg/handler/router.go index 94cb0ac..4474e7b 100644 --- a/pkg/handler/router.go +++ b/pkg/handler/router.go @@ -36,6 +36,7 @@ func (r *Router) Dial(ctx context.Context, network, address string) (conn net.Co if count <= 0 { count = 1 } + r.logger.Debugf("dial: %s/%s", address, network) for i := 0; i < count; i++ { route := r.chain.GetRouteFor(network, address) diff --git a/pkg/handler/socks/v5/handler.go b/pkg/handler/socks/v5/handler.go index 3015712..324801f 100644 --- a/pkg/handler/socks/v5/handler.go +++ b/pkg/handler/socks/v5/handler.go @@ -90,13 +90,29 @@ func (h *socks5Handler) Handle(ctx context.Context, conn net.Conn) { case gosocks5.CmdConnect: h.handleConnect(ctx, conn, req.Addr.String()) case gosocks5.CmdBind: - h.handleBind(ctx, conn, req) + if h.md.enableBind { + h.handleBind(ctx, conn, req) + } else { + h.logger.Error("BIND is diabled") + } case socks.CmdMuxBind: - h.handleMuxBind(ctx, conn, req) + if h.md.enableBind { + h.handleMuxBind(ctx, conn, req) + } else { + h.logger.Error("BIND is diabled") + } case gosocks5.CmdUdp: - h.handleUDP(ctx, conn, req) + if h.md.enableUDP { + h.handleUDP(ctx, conn, req) + } else { + h.logger.Error("UDP relay is diabled") + } case socks.CmdUDPTun: - h.handleUDPTun(ctx, conn, req) + if h.md.enableUDP { + h.handleUDPTun(ctx, conn, req) + } else { + h.logger.Error("UDP relay is diabled") + } default: h.logger.Errorf("unknown cmd: %d", req.Cmd) resp := gosocks5.NewReply(gosocks5.CmdUnsupported, nil) diff --git a/pkg/handler/socks/v5/metadata.go b/pkg/handler/socks/v5/metadata.go index 55e8979..2a1f8f5 100644 --- a/pkg/handler/socks/v5/metadata.go +++ b/pkg/handler/socks/v5/metadata.go @@ -10,18 +10,6 @@ import ( md "github.com/go-gost/gost/pkg/metadata" ) -const ( - certFile = "certFile" - keyFile = "keyFile" - caFile = "caFile" - authsKey = "auths" - readTimeout = "readTimeout" - timeout = "timeout" - retryCount = "retry" - noTLS = "notls" - udpBufferSize = "udpBufferSize" -) - type metadata struct { tlsConfig *tls.Config authenticator auth.Authenticator @@ -29,10 +17,26 @@ type metadata struct { readTimeout time.Duration retryCount int noTLS bool + enableBind bool + enableUDP bool udpBufferSize int } func (h *socks5Handler) parseMetadata(md md.Metadata) error { + const ( + certFile = "certFile" + keyFile = "keyFile" + caFile = "caFile" + authsKey = "auths" + readTimeout = "readTimeout" + timeout = "timeout" + retryCount = "retry" + noTLS = "notls" + enableBind = "bind" + enableUDP = "udp" + udpBufferSize = "udpBufferSize" + ) + var err error h.md.tlsConfig, err = util_tls.LoadTLSConfig( md.GetString(certFile), @@ -62,6 +66,8 @@ func (h *socks5Handler) parseMetadata(md md.Metadata) error { h.md.timeout = md.GetDuration(timeout) h.md.retryCount = md.GetInt(retryCount) h.md.noTLS = md.GetBool(noTLS) + h.md.enableBind = md.GetBool(enableBind) + h.md.enableUDP = md.GetBool(enableUDP) h.md.udpBufferSize = md.GetInt(udpBufferSize) if h.md.udpBufferSize > 0 { diff --git a/pkg/handler/ss/handler.go b/pkg/handler/ss/handler.go index 21bedc3..a2f255a 100644 --- a/pkg/handler/ss/handler.go +++ b/pkg/handler/ss/handler.go @@ -1,6 +1,7 @@ package ss import ( + "bufio" "context" "io" "io/ioutil" @@ -61,24 +62,55 @@ func (h *ssHandler) Handle(ctx context.Context, conn net.Conn) { }).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr()) }() - sc := conn + // standard UDP relay. + if pc, ok := conn.(net.PacketConn); ok { + if h.md.enableUDP { + h.handleUDP(ctx, conn.RemoteAddr(), pc) + return + } else { + h.logger.Error("UDP relay is diabled") + } + + return + } + if h.md.cipher != nil { - sc = ss.ShadowConn(h.md.cipher.StreamConn(conn), nil) + conn = ss.ShadowConn(h.md.cipher.StreamConn(conn), nil) } if h.md.readTimeout > 0 { - sc.SetReadDeadline(time.Now().Add(h.md.readTimeout)) + conn.SetReadDeadline(time.Now().Add(h.md.readTimeout)) } - addr := &gosocks5.Addr{} - _, err := addr.ReadFrom(sc) + br := bufio.NewReader(conn) + data, err := br.Peek(3) if err != nil { h.logger.Error(err) h.discard(conn) return } + conn.SetReadDeadline(time.Time{}) - sc.SetReadDeadline(time.Time{}) + conn = handler.NewBufferReaderConn(conn, br) + if data[2] == 0xff { + if h.md.enableUDP { + // UDP-over-TCP relay + h.handleUDPTun(ctx, conn) + } else { + h.logger.Error("UDP relay is diabled") + } + return + } + + // standard TCP. + addr := &gosocks5.Addr{} + if _, err = addr.ReadFrom(conn); err != nil { + h.logger.Error(err) + h.discard(conn) + return + } + + conn.SetReadDeadline(time.Time{}) h.logger = h.logger.WithFields(map[string]interface{}{ "dst": addr.String(), @@ -103,7 +135,7 @@ func (h *ssHandler) Handle(ctx context.Context, conn net.Conn) { t := time.Now() h.logger.Infof("%s <-> %s", conn.RemoteAddr(), addr) - handler.Transport(sc, cc) + handler.Transport(conn, cc) h.logger. WithFields(map[string]interface{}{ "duration": time.Since(t), @@ -114,18 +146,3 @@ func (h *ssHandler) Handle(ctx context.Context, conn net.Conn) { func (h *ssHandler) discard(conn net.Conn) { io.Copy(ioutil.Discard, conn) } - -func (h *ssHandler) parseMetadata(md md.Metadata) (err error) { - h.md.cipher, err = ss.ShadowCipher( - md.GetString(method), - md.GetString(password), - md.GetString(key), - ) - if err != nil { - return - } - - h.md.readTimeout = md.GetDuration(readTimeout) - h.md.retryCount = md.GetInt(retryCount) - return -} diff --git a/pkg/handler/ss/metadata.go b/pkg/handler/ss/metadata.go index a3e8a1e..e31e9bc 100644 --- a/pkg/handler/ss/metadata.go +++ b/pkg/handler/ss/metadata.go @@ -3,19 +3,53 @@ package ss import ( "time" + "github.com/go-gost/gost/pkg/internal/utils/ss" + md "github.com/go-gost/gost/pkg/metadata" "github.com/shadowsocks/go-shadowsocks2/core" ) -const ( - method = "method" - password = "password" - key = "key" - readTimeout = "readTimeout" - retryCount = "retry" -) - type metadata struct { cipher core.Cipher readTimeout time.Duration retryCount int + bufferSize int + enableUDP bool +} + +func (h *ssHandler) parseMetadata(md md.Metadata) (err error) { + const ( + method = "method" + password = "password" + key = "key" + readTimeout = "readTimeout" + retryCount = "retry" + enableUDP = "udp" + bufferSize = "bufferSize" + ) + + h.md.cipher, err = ss.ShadowCipher( + md.GetString(method), + md.GetString(password), + md.GetString(key), + ) + if err != nil { + return + } + + h.md.readTimeout = md.GetDuration(readTimeout) + h.md.retryCount = md.GetInt(retryCount) + h.md.enableUDP = md.GetBool(enableUDP) + + h.md.bufferSize = md.GetInt(bufferSize) + if h.md.bufferSize > 0 { + if h.md.bufferSize < 512 { + h.md.bufferSize = 512 // min buffer size + } + if h.md.bufferSize > 65*1024 { + h.md.bufferSize = 65 * 1024 // max buffer size + } + } else { + h.md.bufferSize = 4096 // default buffer size + } + return } diff --git a/pkg/handler/ssu/handler.go b/pkg/handler/ss/udp.go similarity index 60% rename from pkg/handler/ssu/handler.go rename to pkg/handler/ss/udp.go index ee80c2b..1553a1c 100644 --- a/pkg/handler/ssu/handler.go +++ b/pkg/handler/ss/udp.go @@ -1,64 +1,21 @@ -package ssu +package ss import ( "context" "net" "time" - "github.com/go-gost/gost/pkg/bypass" - "github.com/go-gost/gost/pkg/chain" "github.com/go-gost/gost/pkg/handler" "github.com/go-gost/gost/pkg/internal/bufpool" "github.com/go-gost/gost/pkg/internal/utils/socks" "github.com/go-gost/gost/pkg/internal/utils/ss" - "github.com/go-gost/gost/pkg/logger" - md "github.com/go-gost/gost/pkg/metadata" - "github.com/go-gost/gost/pkg/registry" ) -func init() { - registry.RegisterHandler("ssu", NewHandler) -} - -type ssuHandler struct { - chain *chain.Chain - bypass bypass.Bypass - logger logger.Logger - md metadata -} - -func NewHandler(opts ...handler.Option) handler.Handler { - options := &handler.Options{} - for _, opt := range opts { - opt(options) +func (h *ssHandler) handleUDP(ctx context.Context, raddr net.Addr, conn net.PacketConn) { + if h.md.cipher != nil { + conn = h.md.cipher.PacketConn(conn) } - return &ssuHandler{ - chain: options.Chain, - bypass: options.Bypass, - logger: options.Logger, - } -} - -func (h *ssuHandler) Init(md md.Metadata) (err error) { - return h.parseMetadata(md) -} - -func (h *ssuHandler) Handle(ctx context.Context, conn net.Conn) { - defer conn.Close() - - start := time.Now() - h.logger = h.logger.WithFields(map[string]interface{}{ - "remote": conn.RemoteAddr().String(), - "local": conn.LocalAddr().String(), - }) - h.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr()) - defer func() { - h.logger.WithFields(map[string]interface{}{ - "duration": time.Since(start), - }).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr()) - }() - // obtain a udp connection r := (&handler.Router{}). WithChain(h.chain). @@ -81,28 +38,40 @@ func (h *ssuHandler) Handle(ctx context.Context, conn net.Conn) { "bind": cc.LocalAddr().String(), }) h.logger.Infof("bind on %s OK", cc.LocalAddr().String()) + t := time.Now() + h.logger.Infof("%s <-> %s", raddr, cc.LocalAddr()) + h.relayPacket( + ss.UDPServerConn(conn, raddr, h.md.bufferSize), + cc, + ) + h.logger. + WithFields(map[string]interface{}{"duration": time.Since(t)}). + Infof("%s >-< %s", raddr, cc.LocalAddr()) +} - pc, ok := conn.(net.PacketConn) - if ok { - if h.md.cipher != nil { - pc = h.md.cipher.PacketConn(pc) - } - - t := time.Now() - h.logger.Infof("%s <-> %s", conn.RemoteAddr(), cc.LocalAddr()) - h.relayPacket( - ss.UDPServerConn(pc, conn.RemoteAddr(), h.md.bufferSize), - cc, - ) - h.logger. - WithFields(map[string]interface{}{"duration": time.Since(t)}). - Infof("%s >-< %s", conn.RemoteAddr(), cc.LocalAddr()) +func (h *ssHandler) handleUDPTun(ctx context.Context, conn net.Conn) { + // obtain a udp connection + r := (&handler.Router{}). + WithChain(h.chain). + WithRetry(h.md.retryCount). + WithLogger(h.logger) + c, err := r.Dial(ctx, "udp", "") + if err != nil { + h.logger.Error(err) return } - if h.md.cipher != nil { - conn = ss.ShadowConn(h.md.cipher.StreamConn(conn), nil) + cc, ok := c.(net.PacketConn) + if !ok { + h.logger.Errorf("%s: not a packet connection") + return } + defer cc.Close() + + h.logger = h.logger.WithFields(map[string]interface{}{ + "bind": cc.LocalAddr().String(), + }) + h.logger.Infof("bind on %s OK", cc.LocalAddr().String()) t := time.Now() h.logger.Infof("%s <-> %s", conn.RemoteAddr(), cc.LocalAddr()) @@ -112,7 +81,7 @@ func (h *ssuHandler) Handle(ctx context.Context, conn net.Conn) { Infof("%s >-< %s", conn.RemoteAddr(), cc.LocalAddr()) } -func (h *ssuHandler) relayPacket(pc1, pc2 net.PacketConn) (err error) { +func (h *ssHandler) relayPacket(pc1, pc2 net.PacketConn) (err error) { bufSize := h.md.bufferSize errc := make(chan error, 2) @@ -183,7 +152,7 @@ func (h *ssuHandler) relayPacket(pc1, pc2 net.PacketConn) (err error) { return <-errc } -func (h *ssuHandler) tunnelUDP(tunnel, c net.PacketConn) (err error) { +func (h *ssHandler) tunnelUDP(tunnel, c net.PacketConn) (err error) { bufSize := h.md.bufferSize errc := make(chan error, 2) @@ -255,31 +224,3 @@ func (h *ssuHandler) tunnelUDP(tunnel, c net.PacketConn) (err error) { return <-errc } - -func (h *ssuHandler) parseMetadata(md md.Metadata) (err error) { - h.md.cipher, err = ss.ShadowCipher( - md.GetString(method), - md.GetString(password), - md.GetString(key), - ) - if err != nil { - return - } - - h.md.readTimeout = md.GetDuration(readTimeout) - h.md.retryCount = md.GetInt(retryCount) - - h.md.bufferSize = md.GetInt(bufferSize) - if h.md.bufferSize > 0 { - if h.md.bufferSize < 512 { - h.md.bufferSize = 512 // min buffer size - } - if h.md.bufferSize > 65*1024 { - h.md.bufferSize = 65 * 1024 // max buffer size - } - } else { - h.md.bufferSize = 4096 // default buffer size - } - - return -} diff --git a/pkg/handler/ssu/metadata.go b/pkg/handler/ssu/metadata.go deleted file mode 100644 index c8f1df2..0000000 --- a/pkg/handler/ssu/metadata.go +++ /dev/null @@ -1,23 +0,0 @@ -package ssu - -import ( - "time" - - "github.com/shadowsocks/go-shadowsocks2/core" -) - -const ( - method = "method" - password = "password" - key = "key" - readTimeout = "readTimeout" - retryCount = "retry" - bufferSize = "bufferSize" -) - -type metadata struct { - cipher core.Cipher - readTimeout time.Duration - retryCount int - bufferSize int -} diff --git a/pkg/handler/transport.go b/pkg/handler/transport.go index 4e17b7e..89266ea 100644 --- a/pkg/handler/transport.go +++ b/pkg/handler/transport.go @@ -1,7 +1,9 @@ package handler import ( + "bufio" "io" + "net" "github.com/go-gost/gost/pkg/internal/bufpool" ) @@ -30,3 +32,19 @@ func copyBuffer(dst io.Writer, src io.Reader) error { _, err := io.CopyBuffer(dst, src, buf) return err } + +type bufferReaderConn struct { + net.Conn + br *bufio.Reader +} + +func NewBufferReaderConn(conn net.Conn, br *bufio.Reader) net.Conn { + return &bufferReaderConn{ + Conn: conn, + br: br, + } +} + +func (c *bufferReaderConn) Read(b []byte) (int, error) { + return c.br.Read(b) +} diff --git a/pkg/internal/utils/socks/conn.go b/pkg/internal/utils/socks/conn.go index fc0ee6c..ba1adc2 100644 --- a/pkg/internal/utils/socks/conn.go +++ b/pkg/internal/utils/socks/conn.go @@ -76,6 +76,7 @@ func (c *UDPTunConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { Data: b, } dgram.Header.Rsv = uint16(len(dgram.Data)) + dgram.Header.Frag = 0xff // UDP tun relay flag, used by shadowsocks _, err = dgram.WriteTo(c.Conn) n = len(b)