diff --git a/cmd/gost/config.go b/cmd/gost/config.go index 7ff0e6d..8c9f6c5 100644 --- a/cmd/gost/config.go +++ b/cmd/gost/config.go @@ -46,7 +46,7 @@ func buildService(cfg *config.Config) (services []*service.Service) { listener.LoggerOption(listenerLogger), ) if err := ln.Init(metadata.MapMetadata(svc.Listener.Metadata)); err != nil { - listenerLogger.Fatal("init:", err) + listenerLogger.Fatal("init: ", err) } handlerLogger := log.WithFields(map[string]interface{}{ @@ -61,7 +61,7 @@ func buildService(cfg *config.Config) (services []*service.Service) { handler.LoggerOption(handlerLogger), ) if err := h.Init(metadata.MapMetadata(svc.Handler.Metadata)); err != nil { - handlerLogger.Fatal("init:", err) + handlerLogger.Fatal("init: ", err) } s := (&service.Service{}). @@ -95,7 +95,7 @@ func chainFromConfig(cfg *config.ChainConfig) *chain.Chain { connector.LoggerOption(connectorLogger), ) if err := cr.Init(metadata.MapMetadata(v.Connector.Metadata)); err != nil { - connectorLogger.Fatal("init:", err) + connectorLogger.Fatal("init: ", err) } dialerLogger := log.WithFields(map[string]interface{}{ @@ -108,7 +108,7 @@ func chainFromConfig(cfg *config.ChainConfig) *chain.Chain { dialer.LoggerOption(dialerLogger), ) if err := d.Init(metadata.MapMetadata(v.Dialer.Metadata)); err != nil { - dialerLogger.Fatal("init:", err) + dialerLogger.Fatal("init: ", err) } tr := (&chain.Transport{}). @@ -165,22 +165,19 @@ func selectorFromConfig(cfg *config.LoadbalancingConfig) chain.Selector { var strategy chain.Strategy switch cfg.Strategy { case "round": - strategy = &chain.RoundRobinStrategy{} + strategy = chain.RoundRobinStrategy() case "random": - strategy = &chain.RandomStrategy{} + strategy = chain.RandomStrategy() case "fifio": - strategy = &chain.FIFOStrategy{} + strategy = chain.FIFOStrategy() default: - strategy = &chain.RoundRobinStrategy{} + strategy = chain.RoundRobinStrategy() } return chain.NewSelector( strategy, - &chain.InvalidFilter{}, - &chain.FailFilter{ - MaxFails: cfg.MaxFails, - FailTimeout: cfg.FailTimeout, - }, + chain.InvalidFilter(), + chain.FailFilter(cfg.MaxFails, cfg.FailTimeout), ) } diff --git a/cmd/gost/gost.yml b/cmd/gost/gost.yml index 4c7542c..e2adadd 100644 --- a/cmd/gost/gost.yml +++ b/cmd/gost/gost.yml @@ -31,7 +31,6 @@ services: metadata: method: AES-256-GCM password: gost - key: gost readTimeout: 5s retry: 3 listener: @@ -40,6 +39,23 @@ services: keepAlive: 15s chain: chain01 # bypass: bypass01 +- name: socks5+tcp + url: "socks5://gost:gost@:1080" + addr: ":1080" + handler: + type: socks5 + metadata: + auths: + - gost:gost + readTimeout: 5s + retry: 3 + notls: true + listener: + type: tcp + metadata: + keepAlive: 15s + chain: chain-socks4 + # bypass: bypass01 chains: - name: chain01 @@ -99,6 +115,20 @@ chains: dialer: type: tcp metadata: {} +- name: chain-socks4 + hops: + - name: hop01 + nodes: + - name: node01 + addr: ":8081" + url: "http://gost:gost@:8081" + # bypass: bypass01 + connector: + type: socks4 + metadata: {} + dialer: + type: tcp + metadata: {} bypasses: - name: bypass01 diff --git a/cmd/gost/register.go b/cmd/gost/register.go index 6f16f25..f779d44 100644 --- a/cmd/gost/register.go +++ b/cmd/gost/register.go @@ -3,6 +3,7 @@ package main import ( // Register connectors _ "github.com/go-gost/gost/pkg/connector/http" + _ "github.com/go-gost/gost/pkg/connector/socks/v4" _ "github.com/go-gost/gost/pkg/connector/ss" // Register dialers @@ -10,6 +11,8 @@ import ( // Register handlers _ "github.com/go-gost/gost/pkg/handler/http" + _ "github.com/go-gost/gost/pkg/handler/socks/v4" + _ "github.com/go-gost/gost/pkg/handler/socks/v5" _ "github.com/go-gost/gost/pkg/handler/ss" _ "github.com/go-gost/gost/pkg/handler/ssu" diff --git a/go.mod b/go.mod index 0b2a31c..8dcd962 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da // indirect github.com/coreos/go-iptables v0.5.0 // indirect github.com/ginuerzh/tls-dissector v0.0.2-0.20201202075250-98fa925912da + github.com/go-gost/gosocks4 v0.0.1 github.com/go-gost/gosocks5 v0.3.0 github.com/gobwas/glob v0.2.3 github.com/golang/snappy v0.0.3 diff --git a/go.sum b/go.sum index cd51b87..1825123 100644 --- a/go.sum +++ b/go.sum @@ -109,6 +109,8 @@ github.com/go-errors/errors v1.0.1/go.mod h1:f4zRHt4oKfwPJE5k8C9vpYG+aDHdBFUsgrm github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= +github.com/go-gost/gosocks4 v0.0.1 h1:+k1sec8HlELuQV7rWftIkmy8UijzUt2I6t+iMPlGB2s= +github.com/go-gost/gosocks4 v0.0.1/go.mod h1:3B6L47HbU/qugDg4JnoFPHgJXE43Inz8Bah1QaN9qCc= github.com/go-gost/gosocks5 v0.3.0 h1:Hkmp9YDRBSCJd7xywW6dBPT6B9aQTkuWd+3WCheJiJA= github.com/go-gost/gosocks5 v0.3.0/go.mod h1:1G6I7HP7VFVxveGkoK8mnprnJqSqJjdcASKsdUn4Pp4= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 h1:p104kn46Q8WdvHunIJ9dAyjPVtrBPhSr3KT2yUst43I= diff --git a/pkg/chain/chain.go b/pkg/chain/chain.go index b6a07c3..ea48c2e 100644 --- a/pkg/chain/chain.go +++ b/pkg/chain/chain.go @@ -8,7 +8,11 @@ func (c *Chain) AddNodeGroup(group *NodeGroup) { c.groups = append(c.groups, group) } -func (c *Chain) GetRouteFor(addr string) (r *Route) { +func (c *Chain) GetRoute() (r *Route) { + return c.GetRouteFor("tcp", "") +} + +func (c *Chain) GetRouteFor(network, address string) (r *Route) { if c == nil || len(c.groups) == 0 { return } @@ -19,7 +23,7 @@ func (c *Chain) GetRouteFor(addr string) (r *Route) { if node == nil { return } - if node.bypass != nil && node.bypass.Contains(addr) { + if node.bypass != nil && node.bypass.Contains(address) { break } @@ -33,3 +37,7 @@ func (c *Chain) GetRouteFor(addr string) (r *Route) { } return r } + +func (c *Chain) IsEmpty() bool { + return c == nil || len(c.groups) == 0 +} diff --git a/pkg/chain/route.go b/pkg/chain/route.go index 6b02392..d752fd1 100644 --- a/pkg/chain/route.go +++ b/pkg/chain/route.go @@ -6,6 +6,10 @@ import ( "net" ) +var ( + ErrEmptyRoute = errors.New("empty route") +) + type Route struct { nodes []*Node } @@ -16,7 +20,7 @@ func (r *Route) AddNode(node *Node) { func (r *Route) Connect(ctx context.Context) (conn net.Conn, err error) { if r.IsEmpty() { - return nil, errors.New("empty route") + return nil, ErrEmptyRoute } node := r.nodes[0] diff --git a/pkg/chain/selector.go b/pkg/chain/selector.go index bb39b9e..d6335ab 100644 --- a/pkg/chain/selector.go +++ b/pkg/chain/selector.go @@ -49,13 +49,17 @@ type Strategy interface { Apply(nodes ...*Node) *Node } -// RoundStrategy is a strategy for node selector. -// The node will be selected by round-robin algorithm. -type RoundRobinStrategy struct { +type roundRobinStrategy struct { counter uint64 } -func (s *RoundRobinStrategy) Apply(nodes ...*Node) *Node { +// RoundRobinStrategy is a strategy for node selector. +// The node will be selected by round-robin algorithm. +func RoundRobinStrategy() Strategy { + return &roundRobinStrategy{} +} + +func (s *roundRobinStrategy) Apply(nodes ...*Node) *Node { if len(nodes) == 0 { return nil } @@ -64,23 +68,20 @@ func (s *RoundRobinStrategy) Apply(nodes ...*Node) *Node { return nodes[int(n%uint64(len(nodes)))] } -// RandomStrategy is a strategy for node selector. -// The node will be selected randomly. -type RandomStrategy struct { - Seed int64 +type randomStrategy struct { rand *rand.Rand - once sync.Once mux sync.Mutex } -func (s *RandomStrategy) Apply(nodes ...*Node) *Node { - s.once.Do(func() { - seed := s.Seed - if seed == 0 { - seed = time.Now().UnixNano() - } - s.rand = rand.New(rand.NewSource(seed)) - }) +// RandomStrategy is a strategy for node selector. +// The node will be selected randomly. +func RandomStrategy() Strategy { + return &randomStrategy{ + rand: rand.New(rand.NewSource(time.Now().UnixNano())), + } +} + +func (s *randomStrategy) Apply(nodes ...*Node) *Node { if len(nodes) == 0 { return nil } @@ -93,13 +94,17 @@ func (s *RandomStrategy) Apply(nodes ...*Node) *Node { return nodes[r%len(nodes)] } +type fifoStrategy struct{} + // FIFOStrategy is a strategy for node selector. // The node will be selected from first to last, // and will stick to the selected node until it is failed. -type FIFOStrategy struct{} +func FIFOStrategy() Strategy { + return &fifoStrategy{} +} // Apply applies the fifo strategy for the nodes. -func (s *FIFOStrategy) Apply(nodes ...*Node) *Node { +func (s *fifoStrategy) Apply(nodes ...*Node) *Node { if len(nodes) == 0 { return nil } @@ -110,20 +115,27 @@ type Filter interface { Filter(nodes ...*Node) []*Node } +type failFilter struct { + maxFails int + failTimeout time.Duration +} + // FailFilter filters the dead node. // A node is marked as dead if its failed count is greater than MaxFails. -type FailFilter struct { - MaxFails int - FailTimeout time.Duration +func FailFilter(maxFails int, timeout time.Duration) Filter { + return &failFilter{ + maxFails: maxFails, + failTimeout: timeout, + } } // Filter filters dead nodes. -func (f *FailFilter) Filter(nodes ...*Node) []*Node { - maxFails := f.MaxFails +func (f *failFilter) Filter(nodes ...*Node) []*Node { + maxFails := f.maxFails if maxFails == 0 { maxFails = DefaultMaxFails } - failTimeout := f.FailTimeout + failTimeout := f.failTimeout if failTimeout == 0 { failTimeout = DefaultFailTimeout } @@ -141,12 +153,16 @@ func (f *FailFilter) Filter(nodes ...*Node) []*Node { return nl } +type invalidFilter struct{} + // InvalidFilter filters the invalid node. // A node is invalid if its port is invalid (negative or zero value). -type InvalidFilter struct{} +func InvalidFilter() Filter { + return &invalidFilter{} +} // Filter filters invalid nodes. -func (f *InvalidFilter) Filter(nodes ...*Node) []*Node { +func (f *invalidFilter) Filter(nodes ...*Node) []*Node { var nl []*Node for _, node := range nodes { _, sport, _ := net.SplitHostPort(node.Addr()) diff --git a/pkg/connector/socks/v4/connector.go b/pkg/connector/socks/v4/connector.go new file mode 100644 index 0000000..97eb862 --- /dev/null +++ b/pkg/connector/socks/v4/connector.go @@ -0,0 +1,119 @@ +package v4 + +import ( + "context" + "fmt" + "net" + "net/url" + "strconv" + "time" + + "github.com/go-gost/gosocks4" + "github.com/go-gost/gost/pkg/connector" + "github.com/go-gost/gost/pkg/logger" + md "github.com/go-gost/gost/pkg/metadata" + "github.com/go-gost/gost/pkg/registry" +) + +func init() { + registry.RegiserConnector("socks4", NewConnector) + registry.RegiserConnector("socks4a", NewConnector) +} + +type socks4Connector struct { + md metadata + logger logger.Logger +} + +func NewConnector(opts ...connector.Option) connector.Connector { + options := &connector.Options{} + for _, opt := range opts { + opt(options) + } + + return &socks4Connector{ + logger: options.Logger, + } +} + +func (c *socks4Connector) Init(md md.Metadata) (err error) { + return c.parseMetadata(md) +} + +func (c *socks4Connector) Connect(ctx context.Context, conn net.Conn, network, address string, opts ...connector.ConnectOption) (net.Conn, error) { + c.logger = c.logger.WithFields(map[string]interface{}{ + "remote": conn.RemoteAddr().String(), + "local": conn.LocalAddr().String(), + "target": address, + }) + c.logger.Info("connect: ", address) + + var addr *gosocks4.Addr + + if c.md.disable4a { + taddr, err := net.ResolveTCPAddr("tcp4", address) + if err != nil { + c.logger.Error("resolve: ", err) + return nil, err + } + if len(taddr.IP) == 0 { + taddr.IP = net.IPv4zero + } + addr = &gosocks4.Addr{ + Type: gosocks4.AddrIPv4, + Host: taddr.IP.String(), + Port: uint16(taddr.Port), + } + } else { + host, port, err := net.SplitHostPort(address) + if err != nil { + return nil, err + } + p, _ := strconv.Atoi(port) + addr = &gosocks4.Addr{ + Type: gosocks4.AddrDomain, + Host: host, + Port: uint16(p), + } + } + + if c.md.connectTimeout > 0 { + conn.SetDeadline(time.Now().Add(c.md.connectTimeout)) + } + defer conn.SetDeadline(time.Time{}) + + req := gosocks4.NewRequest(gosocks4.CmdConnect, addr, nil) + if err := req.Write(conn); err != nil { + c.logger.Error(err) + return nil, err + } + if c.logger.IsLevelEnabled(logger.DebugLevel) { + c.logger.Debug(req) + } + + reply, err := gosocks4.ReadReply(conn) + if err != nil { + c.logger.Error(err) + return nil, err + } + + if c.logger.IsLevelEnabled(logger.DebugLevel) { + c.logger.Debug(reply) + } + + if reply.Code != gosocks4.Granted { + return nil, fmt.Errorf("error: %d", reply.Code) + } + + return conn, nil +} + +func (c *socks4Connector) parseMetadata(md md.Metadata) (err error) { + if v := md.GetString(auth); v != "" { + c.md.User = url.User(v) + } + c.md.connectTimeout = md.GetDuration(connectTimeout) + c.md.disable4a = md.GetBool(disable4a) + + return +} diff --git a/pkg/connector/socks/v4/metadata.go b/pkg/connector/socks/v4/metadata.go new file mode 100644 index 0000000..a73efed --- /dev/null +++ b/pkg/connector/socks/v4/metadata.go @@ -0,0 +1,18 @@ +package v4 + +import ( + "net/url" + "time" +) + +const ( + connectTimeout = "timeout" + auth = "auth" + disable4a = "disable4a" +) + +type metadata struct { + connectTimeout time.Duration + User *url.Userinfo + disable4a bool +} diff --git a/pkg/connector/socks/v5/connector.go b/pkg/connector/socks/v5/connector.go new file mode 100644 index 0000000..7b87378 --- /dev/null +++ b/pkg/connector/socks/v5/connector.go @@ -0,0 +1,116 @@ +package v5 + +import ( + "context" + "fmt" + "net" + "net/url" + "strconv" + "time" + + "github.com/go-gost/gosocks4" + "github.com/go-gost/gost/pkg/connector" + "github.com/go-gost/gost/pkg/logger" + md "github.com/go-gost/gost/pkg/metadata" + "github.com/go-gost/gost/pkg/registry" +) + +func init() { + registry.RegiserConnector("socks4", NewConnector) + registry.RegiserConnector("socks4a", NewConnector) +} + +type socks4Connector struct { + md metadata + logger logger.Logger +} + +func NewConnector(opts ...connector.Option) connector.Connector { + options := &connector.Options{} + for _, opt := range opts { + opt(options) + } + + return &socks4Connector{ + logger: options.Logger, + } +} + +func (c *socks4Connector) Init(md md.Metadata) (err error) { + return c.parseMetadata(md) +} + +func (c *socks4Connector) Connect(ctx context.Context, conn net.Conn, network, address string, opts ...connector.ConnectOption) (net.Conn, error) { + c.logger = c.logger.WithFields(map[string]interface{}{ + "remote": conn.RemoteAddr().String(), + "local": conn.LocalAddr().String(), + "target": address, + }) + c.logger.Infof("connect: ", address) + + var addr *gosocks4.Addr + + if c.md.disable4a { + taddr, err := net.ResolveTCPAddr("tcp4", address) + if err != nil { + c.logger.Error("resolve: ", err) + return nil, err + } + if len(taddr.IP) == 0 { + taddr.IP = net.IPv4zero + } + addr = &gosocks4.Addr{ + Type: gosocks4.AddrIPv4, + Host: taddr.IP.String(), + Port: uint16(taddr.Port), + } + } else { + host, port, err := net.SplitHostPort(address) + if err != nil { + return nil, err + } + p, _ := strconv.Atoi(port) + addr = &gosocks4.Addr{ + Type: gosocks4.AddrDomain, + Host: host, + Port: uint16(p), + } + } + conn.SetDeadline(time.Now().Add(c.md.connectTimeout)) + defer conn.SetDeadline(time.Time{}) + + req := gosocks4.NewRequest(gosocks4.CmdConnect, addr, nil) + if err := req.Write(conn); err != nil { + c.logger.Error(err) + return nil, err + } + if c.logger.IsLevelEnabled(logger.DebugLevel) { + c.logger.Debug(req) + } + + reply, err := gosocks4.ReadReply(conn) + if err != nil { + c.logger.Error(err) + return nil, err + } + + if c.logger.IsLevelEnabled(logger.DebugLevel) { + c.logger.Debug(reply) + } + + if reply.Code != gosocks4.Granted { + return nil, fmt.Errorf("error: %d", reply.Code) + } + + return conn, nil +} + +func (c *socks4Connector) parseMetadata(md md.Metadata) (err error) { + if v := md.GetString(auth); v != "" { + c.md.User = url.User(v) + } + c.md.connectTimeout = md.GetDuration(connectTimeout) + c.md.disable4a = md.GetBool(disable4a) + + return +} diff --git a/pkg/connector/socks/v5/metadata.go b/pkg/connector/socks/v5/metadata.go new file mode 100644 index 0000000..ffb0bbc --- /dev/null +++ b/pkg/connector/socks/v5/metadata.go @@ -0,0 +1,18 @@ +package v5 + +import ( + "net/url" + "time" +) + +const ( + connectTimeout = "timeout" + auth = "auth" + disable4a = "disable4a" +) + +type metadata struct { + connectTimeout time.Duration + User *url.Userinfo + disable4a bool +} diff --git a/pkg/connector/ss/connector.go b/pkg/connector/ss/connector.go index 8755970..1b51ddf 100644 --- a/pkg/connector/ss/connector.go +++ b/pkg/connector/ss/connector.go @@ -18,7 +18,7 @@ func init() { registry.RegiserConnector("ss", NewConnector) } -type Connector struct { +type ssConnector struct { md metadata logger logger.Logger } @@ -29,16 +29,16 @@ func NewConnector(opts ...connector.Option) connector.Connector { opt(options) } - return &Connector{ + return &ssConnector{ logger: options.Logger, } } -func (c *Connector) Init(md md.Metadata) (err error) { +func (c *ssConnector) Init(md md.Metadata) (err error) { return c.parseMetadata(md) } -func (c *Connector) Connect(ctx context.Context, conn net.Conn, network, address string, opts ...connector.ConnectOption) (net.Conn, error) { +func (c *ssConnector) Connect(ctx context.Context, conn net.Conn, network, address string, opts ...connector.ConnectOption) (net.Conn, error) { c.logger = c.logger.WithFields(map[string]interface{}{ "remote": conn.RemoteAddr().String(), "local": conn.LocalAddr().String(), @@ -60,8 +60,10 @@ func (c *Connector) Connect(ctx context.Context, conn net.Conn, network, address return nil, err } - conn.SetDeadline(time.Now().Add(c.md.connectTimeout)) - defer conn.SetDeadline(time.Time{}) + if c.md.connectTimeout > 0 { + conn.SetDeadline(time.Now().Add(c.md.connectTimeout)) + defer conn.SetDeadline(time.Time{}) + } if c.md.cipher != nil { conn = c.md.cipher.StreamConn(conn) @@ -82,7 +84,7 @@ func (c *Connector) Connect(ctx context.Context, conn net.Conn, network, address return sc, nil } -func (c *Connector) parseMetadata(md md.Metadata) (err error) { +func (c *ssConnector) parseMetadata(md md.Metadata) (err error) { c.md.cipher, err = utils.ShadowCipher( md.GetString(method), md.GetString(password), diff --git a/pkg/handler/http/handler.go b/pkg/handler/http/handler.go index d912965..5816382 100644 --- a/pkg/handler/http/handler.go +++ b/pkg/handler/http/handler.go @@ -2,12 +2,10 @@ package http import ( "bufio" - "bytes" "context" "encoding/base64" "encoding/binary" "errors" - "fmt" "hash/crc32" "net" "net/http" @@ -134,13 +132,13 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt } req.Header.Del("X-Gost-Target") - host := req.Host - if _, port, _ := net.SplitHostPort(host); port == "" { - host = net.JoinHostPort(host, "80") + addr := req.Host + if _, port, _ := net.SplitHostPort(addr); port == "" { + addr = net.JoinHostPort(addr, "80") } fields := map[string]interface{}{ - "dst": host, + "dst": addr, } if u, _, _ := h.basicProxyAuth(req.Header.Get("Proxy-Authorization")); u != "" { fields["user"] = u @@ -151,7 +149,7 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt dump, _ := httputil.DumpRequest(req, false) h.logger.Debug(string(dump)) } - h.logger.Infof("%s > %s", conn.RemoteAddr(), host) + h.logger.Infof("%s >> %s", conn.RemoteAddr(), addr) resp := &http.Response{ ProtoMajor: 1, @@ -179,14 +177,14 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt } */ - if h.bypass != nil && h.bypass.Contains(host) { + if h.bypass != nil && h.bypass.Contains(addr) { resp.StatusCode = http.StatusForbidden if h.logger.IsLevelEnabled(logger.DebugLevel) { dump, _ := httputil.DumpResponse(resp, false) h.logger.Debug(string(dump)) } - h.logger.Info("bypass: ", host) + h.logger.Info("bypass: ", addr) resp.Write(conn) return @@ -211,7 +209,11 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt req.Header.Del("Proxy-Authorization") - cc, err := h.dial(ctx, host) + r := (&handler.Router{}). + WithChain(h.chain). + WithRetry(h.md.retryCount). + WithLogger(h.logger) + cc, err := r.Dial(ctx, "tcp", addr) if err != nil { resp.StatusCode = http.StatusServiceUnavailable resp.Write(conn) @@ -244,50 +246,9 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt } } - h.logger.Infof("%s <> %s", conn.RemoteAddr(), host) + h.logger.Infof("%s <-> %s", conn.RemoteAddr(), addr) handler.Transport(conn, cc) - h.logger.Infof("%s >< %s", conn.RemoteAddr(), host) -} - -func (h *httpHandler) dial(ctx context.Context, addr string) (conn net.Conn, err error) { - count := h.md.retryCount + 1 - if count <= 0 { - count = 1 - } - - for i := 0; i < count; i++ { - route := h.chain.GetRouteFor(addr) - - if h.logger.IsLevelEnabled(logger.DebugLevel) { - buf := bytes.Buffer{} - for _, node := range route.Path() { - fmt.Fprintf(&buf, "%s@%s > ", node.Name(), node.Addr()) - } - fmt.Fprintf(&buf, "%s", addr) - h.logger.Debugf("route(retry=%d): %s", i, buf.String()) - } - - /* - // forward http request - lastNode := route.LastNode() - if req.Method != http.MethodConnect && lastNode.Protocol == "http" { - err = h.forwardRequest(conn, req, route) - if err == nil { - return - } - log.Logf("[http] %s -> %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err) - continue - } - */ - - conn, err = route.Dial(ctx, "tcp", addr) - if err == nil { - break - } - h.logger.Errorf("route(retry=%d): %s", i, err) - } - - return + h.logger.Infof("%s >-< %s", conn.RemoteAddr(), addr) } func (h *httpHandler) decodeServerName(s string) (string, error) { diff --git a/pkg/handler/http/metadata.go b/pkg/handler/http/metadata.go index d28806e..321fdfb 100644 --- a/pkg/handler/http/metadata.go +++ b/pkg/handler/http/metadata.go @@ -3,7 +3,6 @@ package http import "github.com/go-gost/gost/pkg/auth" const ( - addrKey = "addr" proxyAgentKey = "proxyAgent" authsKey = "auths" probeResistKey = "probeResist" @@ -12,7 +11,6 @@ const ( ) type metadata struct { - addr string authenticator auth.Authenticator proxyAgent string retryCount int diff --git a/pkg/handler/router.go b/pkg/handler/router.go new file mode 100644 index 0000000..94cb0ac --- /dev/null +++ b/pkg/handler/router.go @@ -0,0 +1,87 @@ +package handler + +import ( + "bytes" + "context" + "fmt" + "net" + + "github.com/go-gost/gost/pkg/chain" + "github.com/go-gost/gost/pkg/logger" +) + +type Router struct { + chain *chain.Chain + retries int + logger logger.Logger +} + +func (r *Router) WithChain(chain *chain.Chain) *Router { + r.chain = chain + return r +} + +func (r *Router) WithRetry(retries int) *Router { + r.retries = retries + return r +} + +func (r *Router) WithLogger(logger logger.Logger) *Router { + r.logger = logger + return r +} + +func (r *Router) Dial(ctx context.Context, network, address string) (conn net.Conn, err error) { + count := r.retries + 1 + if count <= 0 { + count = 1 + } + + for i := 0; i < count; i++ { + route := r.chain.GetRouteFor(network, address) + + if r.logger.IsLevelEnabled(logger.DebugLevel) { + buf := bytes.Buffer{} + for _, node := range route.Path() { + fmt.Fprintf(&buf, "%s@%s > ", node.Name(), node.Addr()) + } + fmt.Fprintf(&buf, "%s", address) + r.logger.Debugf("route(retry=%d): %s", i, buf.String()) + } + + conn, err = route.Dial(ctx, network, address) + if err == nil { + break + } + r.logger.Errorf("route(retry=%d): %s", i, err) + } + + return +} + +func (r *Router) Connect(ctx context.Context) (conn net.Conn, err error) { + count := r.retries + 1 + if count <= 0 { + count = 1 + } + + for i := 0; i < count; i++ { + route := r.chain.GetRoute() + + if r.logger.IsLevelEnabled(logger.DebugLevel) { + buf := bytes.Buffer{} + for _, node := range route.Path() { + fmt.Fprintf(&buf, "%s@%s > ", node.Name(), node.Addr()) + } + r.logger.Debugf("route(retry=%d): %s", i, buf.String()) + } + + conn, err = route.Connect(ctx) + if err == nil { + break + } + r.logger.Errorf("route(retry=%d): %s", i, err) + } + + return +} diff --git a/pkg/handler/socks/v4/handler.go b/pkg/handler/socks/v4/handler.go new file mode 100644 index 0000000..6588d32 --- /dev/null +++ b/pkg/handler/socks/v4/handler.go @@ -0,0 +1,165 @@ +package v4 + +import ( + "context" + "net" + "time" + + "github.com/go-gost/gosocks4" + "github.com/go-gost/gost/pkg/auth" + "github.com/go-gost/gost/pkg/bypass" + "github.com/go-gost/gost/pkg/chain" + "github.com/go-gost/gost/pkg/handler" + "github.com/go-gost/gost/pkg/logger" + md "github.com/go-gost/gost/pkg/metadata" + "github.com/go-gost/gost/pkg/registry" +) + +func init() { + registry.RegisterHandler("socks4", NewHandler) + registry.RegisterHandler("socks4a", NewHandler) +} + +type socks4Handler struct { + chain *chain.Chain + bypass bypass.Bypass + logger logger.Logger + md metadata +} + +func NewHandler(opts ...handler.Option) handler.Handler { + options := &handler.Options{} + for _, opt := range opts { + opt(options) + } + + return &socks4Handler{ + chain: options.Chain, + bypass: options.Bypass, + logger: options.Logger, + } +} + +func (h *socks4Handler) Init(md md.Metadata) (err error) { + return h.parseMetadata(md) +} + +func (h *socks4Handler) Handle(ctx context.Context, conn net.Conn) { + defer conn.Close() + + start := time.Now() + + h.logger = h.logger.WithFields(map[string]interface{}{ + "remote": conn.RemoteAddr().String(), + "local": conn.LocalAddr().String(), + }) + + h.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr()) + defer func() { + h.logger.WithFields(map[string]interface{}{ + "duration": time.Since(start), + }).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr()) + }() + + if h.md.readTimeout > 0 { + conn.SetReadDeadline(time.Now().Add(h.md.readTimeout)) + } + + req, err := gosocks4.ReadRequest(conn) + if err != nil { + h.logger.Error(err) + return + } + conn.SetReadDeadline(time.Time{}) + + if h.logger.IsLevelEnabled(logger.DebugLevel) { + h.logger.Debug(req) + } + + if h.md.authenticator != nil && + !h.md.authenticator.Authenticate(string(req.Userid), "") { + resp := gosocks4.NewReply(gosocks4.RejectedUserid, nil) + resp.Write(conn) + if h.logger.IsLevelEnabled(logger.DebugLevel) { + h.logger.Debug(resp) + } + return + } + + switch req.Cmd { + case gosocks4.CmdConnect: + h.handleConnect(ctx, conn, req) + case gosocks4.CmdBind: + h.handleBind(ctx, conn, req) + default: + h.logger.Errorf("unknown cmd: %d", req.Cmd) + } +} + +func (h *socks4Handler) handleConnect(ctx context.Context, conn net.Conn, req *gosocks4.Request) { + addr := req.Addr.String() + + h.logger = h.logger.WithFields(map[string]interface{}{ + "dst": addr, + }) + h.logger.Infof("%s >> %s", conn.RemoteAddr(), addr) + + if h.bypass != nil && h.bypass.Contains(addr) { + resp := gosocks4.NewReply(gosocks4.Rejected, nil) + resp.Write(conn) + if h.logger.IsLevelEnabled(logger.DebugLevel) { + h.logger.Debug(resp) + } + h.logger.Info("bypass: ", addr) + return + } + + r := (&handler.Router{}). + WithChain(h.chain). + WithRetry(h.md.retryCount). + WithLogger(h.logger) + cc, err := r.Dial(ctx, "tcp", addr) + if err != nil { + resp := gosocks4.NewReply(gosocks4.Failed, nil) + resp.Write(conn) + if h.logger.IsLevelEnabled(logger.DebugLevel) { + h.logger.Debug(resp) + } + return + } + + defer cc.Close() + + resp := gosocks4.NewReply(gosocks4.Granted, nil) + if err := resp.Write(conn); err != nil { + h.logger.Error(err) + return + } + if h.logger.IsLevelEnabled(logger.DebugLevel) { + h.logger.Debug(resp) + } + + h.logger.Infof("%s <-> %s", conn.RemoteAddr(), addr) + handler.Transport(conn, cc) + h.logger.Infof("%s >-< %s", conn.RemoteAddr(), addr) +} + +func (h *socks4Handler) handleBind(ctx context.Context, conn net.Conn, req *gosocks4.Request) { + // TODO: bind +} + +func (h *socks4Handler) parseMetadata(md md.Metadata) (err error) { + if v, _ := md.Get(authsKey).([]interface{}); len(v) > 0 { + authenticator := auth.NewLocalAuthenticator(nil) + for _, auth := range v { + if v, _ := auth.(string); v != "" { + authenticator.Add(v, "") + } + } + h.md.authenticator = authenticator + } + + h.md.readTimeout = md.GetDuration(readTimeout) + h.md.retryCount = md.GetInt(retryCount) + return +} diff --git a/pkg/handler/socks/v4/metadata.go b/pkg/handler/socks/v4/metadata.go new file mode 100644 index 0000000..a32b5d7 --- /dev/null +++ b/pkg/handler/socks/v4/metadata.go @@ -0,0 +1,19 @@ +package v4 + +import ( + "time" + + "github.com/go-gost/gost/pkg/auth" +) + +const ( + authsKey = "auths" + readTimeout = "readTimeout" + retryCount = "retry" +) + +type metadata struct { + authenticator auth.Authenticator + readTimeout time.Duration + retryCount int +} diff --git a/pkg/handler/socks/v5/bind.go b/pkg/handler/socks/v5/bind.go new file mode 100644 index 0000000..f5088cf --- /dev/null +++ b/pkg/handler/socks/v5/bind.go @@ -0,0 +1,160 @@ +package v5 + +import ( + "context" + "net" + + "github.com/go-gost/gosocks5" + "github.com/go-gost/gost/pkg/handler" + "github.com/go-gost/gost/pkg/logger" +) + +func (h *socks5Handler) handleBind(ctx context.Context, conn net.Conn, req *gosocks5.Request) { + addr := req.Addr.String() + + h.logger = h.logger.WithFields(map[string]interface{}{ + "dst": addr, + "cmd": "bind", + }) + + h.logger.Infof("%s >> %s", conn.RemoteAddr(), addr) + + if h.chain.IsEmpty() { + h.bindLocal(ctx, conn, addr) + return + } + + r := (&handler.Router{}). + WithChain(h.chain). + WithRetry(h.md.retryCount). + WithLogger(h.logger) + cc, err := r.Connect(ctx) + if err != nil { + resp := gosocks5.NewReply(gosocks5.Failure, nil) + resp.Write(conn) + if h.logger.IsLevelEnabled(logger.DebugLevel) { + h.logger.Debug(resp) + } + return + } + defer cc.Close() + + if err := req.Write(cc); err != nil { + h.logger.Error(err) + resp := gosocks5.NewReply(gosocks5.NetUnreachable, nil) + resp.Write(conn) + if h.logger.IsLevelEnabled(logger.DebugLevel) { + h.logger.Debug(resp) + } + return + } + + h.logger.Infof("%s <-> %s", conn.RemoteAddr(), addr) + handler.Transport(conn, cc) + h.logger.Infof("%s >-< %s", conn.RemoteAddr(), addr) +} + +func (h *socks5Handler) bindLocal(ctx context.Context, conn net.Conn, addr string) { + bindAddr, _ := net.ResolveTCPAddr("tcp", addr) + ln, err := net.ListenTCP("tcp", bindAddr) // strict mode: if the port already in use, it will return error + if err != nil { + h.logger.Error(err) + reply := gosocks5.NewReply(gosocks5.Failure, nil) + if err := reply.Write(conn); err != nil { + h.logger.Error(err) + } + if h.logger.IsLevelEnabled(logger.DebugLevel) { + h.logger.Debug(reply.String()) + } + return + } + + socksAddr, err := gosocks5.NewAddr(ln.Addr().String()) + if err != nil { + h.logger.Warn(err) + socksAddr = &gosocks5.Addr{ + Type: gosocks5.AddrIPv4, + } + } + + // Issue: may not reachable when host has multi-interface + socksAddr.Host, _, _ = net.SplitHostPort(conn.LocalAddr().String()) + reply := gosocks5.NewReply(gosocks5.Succeeded, socksAddr) + if err := reply.Write(conn); err != nil { + h.logger.Error(err) + ln.Close() + return + } + if h.logger.IsLevelEnabled(logger.DebugLevel) { + h.logger.Debug(reply.String()) + } + h.logger.Infof("bind on: %s OK", socksAddr.String()) + + h.serveBind(ctx, conn, ln) +} + +func (h *socks5Handler) serveBind(ctx context.Context, conn net.Conn, ln net.Listener) { + var rc net.Conn + accept := func() <-chan error { + errc := make(chan error, 1) + + go func() { + defer close(errc) + defer ln.Close() + + c, err := ln.Accept() + if err != nil { + errc <- err + } + rc = c + }() + + return errc + } + + pc1, pc2 := net.Pipe() + pipe := func() <-chan error { + errc := make(chan error, 1) + + go func() { + defer close(errc) + defer pc1.Close() + + errc <- handler.Transport(conn, pc1) + }() + + return errc + } + + defer pc2.Close() + + select { + case err := <-accept(): + if err != nil { + h.logger.Error(err) + return + } + defer rc.Close() + + raddr, _ := gosocks5.NewAddr(rc.RemoteAddr().String()) + reply := gosocks5.NewReply(gosocks5.Succeeded, raddr) + if err := reply.Write(pc2); err != nil { + h.logger.Error(err) + } + if h.logger.IsLevelEnabled(logger.DebugLevel) { + h.logger.Debug(reply.String()) + } + h.logger.Infof("PEER %s ACCEPTED", raddr.String()) + + h.logger.Infof("%s <-> %s", conn.RemoteAddr(), raddr.String()) + handler.Transport(pc2, rc) + h.logger.Infof("%s >-< %s", conn.RemoteAddr(), raddr.String()) + + case err := <-pipe(): + if err != nil { + h.logger.Error(err) + } + ln.Close() + return + } +} diff --git a/pkg/handler/socks/v5/connect.go b/pkg/handler/socks/v5/connect.go new file mode 100644 index 0000000..86f4542 --- /dev/null +++ b/pkg/handler/socks/v5/connect.go @@ -0,0 +1,57 @@ +package v5 + +import ( + "context" + "net" + + "github.com/go-gost/gosocks5" + "github.com/go-gost/gost/pkg/handler" + "github.com/go-gost/gost/pkg/logger" +) + +func (h *socks5Handler) handleConnect(ctx context.Context, conn net.Conn, addr string) { + h.logger = h.logger.WithFields(map[string]interface{}{ + "dst": addr, + "cmd": "connect", + }) + h.logger.Infof("%s >> %s", conn.RemoteAddr(), addr) + + if h.bypass != nil && h.bypass.Contains(addr) { + resp := gosocks5.NewReply(gosocks5.NotAllowed, nil) + resp.Write(conn) + if h.logger.IsLevelEnabled(logger.DebugLevel) { + h.logger.Debug(resp) + } + h.logger.Info("bypass: ", addr) + return + } + + r := (&handler.Router{}). + WithChain(h.chain). + WithRetry(h.md.retryCount). + WithLogger(h.logger) + cc, err := r.Dial(ctx, "tcp", addr) + if err != nil { + resp := gosocks5.NewReply(gosocks5.NetUnreachable, nil) + resp.Write(conn) + if h.logger.IsLevelEnabled(logger.DebugLevel) { + h.logger.Debug(resp) + } + return + } + + defer cc.Close() + + resp := gosocks5.NewReply(gosocks5.Succeeded, nil) + if err := resp.Write(conn); err != nil { + h.logger.Error(err) + return + } + if h.logger.IsLevelEnabled(logger.DebugLevel) { + h.logger.Debug(resp) + } + + h.logger.Infof("%s <-> %s", conn.RemoteAddr(), addr) + handler.Transport(conn, cc) + h.logger.Infof("%s >-< %s", conn.RemoteAddr(), addr) +} diff --git a/pkg/handler/socks/v5/handler.go b/pkg/handler/socks/v5/handler.go new file mode 100644 index 0000000..5a65269 --- /dev/null +++ b/pkg/handler/socks/v5/handler.go @@ -0,0 +1,125 @@ +package v5 + +import ( + "context" + "net" + "time" + + "github.com/go-gost/gosocks5" + "github.com/go-gost/gost/pkg/bypass" + "github.com/go-gost/gost/pkg/chain" + "github.com/go-gost/gost/pkg/handler" + "github.com/go-gost/gost/pkg/logger" + md "github.com/go-gost/gost/pkg/metadata" + "github.com/go-gost/gost/pkg/registry" +) + +const ( + // MethodTLS is an extended SOCKS5 method with tls encryption support. + MethodTLS uint8 = 0x80 + // MethodTLSAuth is an extended SOCKS5 method with tls encryption and authentication support. + MethodTLSAuth uint8 = 0x82 + // MethodMux is an extended SOCKS5 method for stream multiplexing. + MethodMux = 0x88 +) + +const ( + // CmdMuxBind is an extended SOCKS5 request CMD for + // multiplexing transport with the binding server. + CmdMuxBind uint8 = 0xF2 + // CmdUDPTun is an extended SOCKS5 request CMD for UDP over TCP. + CmdUDPTun uint8 = 0xF3 +) + +func init() { + registry.RegisterHandler("socks5", NewHandler) + registry.RegisterHandler("socks", NewHandler) +} + +type socks5Handler struct { + selector gosocks5.Selector + chain *chain.Chain + bypass bypass.Bypass + logger logger.Logger + md metadata +} + +func NewHandler(opts ...handler.Option) handler.Handler { + options := &handler.Options{} + for _, opt := range opts { + opt(options) + } + + return &socks5Handler{ + chain: options.Chain, + bypass: options.Bypass, + logger: options.Logger, + } +} + +func (h *socks5Handler) Init(md md.Metadata) (err error) { + if err := h.parseMetadata(md); err != nil { + return err + } + + h.selector = &serverSelector{ + Authenticator: h.md.authenticator, + TLSConfig: h.md.tlsConfig, + logger: h.logger, + noTLS: h.md.noTLS, + } + + return +} + +func (h *socks5Handler) Handle(ctx context.Context, conn net.Conn) { + defer conn.Close() + + start := time.Now() + + h.logger = h.logger.WithFields(map[string]interface{}{ + "remote": conn.RemoteAddr().String(), + "local": conn.LocalAddr().String(), + }) + + h.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr()) + defer func() { + h.logger.WithFields(map[string]interface{}{ + "duration": time.Since(start), + }).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr()) + }() + + if h.md.readTimeout > 0 { + conn.SetReadDeadline(time.Now().Add(h.md.readTimeout)) + } + + conn = gosocks5.ServerConn(conn, h.selector) + req, err := gosocks5.ReadRequest(conn) + if err != nil { + h.logger.Error(err) + return + } + conn.SetReadDeadline(time.Time{}) + + if h.logger.IsLevelEnabled(logger.DebugLevel) { + h.logger.Debug(req) + } + + switch req.Cmd { + case gosocks5.CmdConnect: + h.handleConnect(ctx, conn, req.Addr.String()) + case gosocks5.CmdBind: + h.handleBind(ctx, conn, req) + case CmdMuxBind: + case gosocks5.CmdUdp: + case CmdUDPTun: + default: + h.logger.Errorf("unknown cmd: %d", req.Cmd) + resp := gosocks5.NewReply(gosocks5.CmdUnsupported, nil) + resp.Write(conn) + if h.logger.IsLevelEnabled(logger.DebugLevel) { + h.logger.Debug(resp) + } + return + } +} diff --git a/pkg/handler/socks/v5/metadata.go b/pkg/handler/socks/v5/metadata.go new file mode 100644 index 0000000..f3425ed --- /dev/null +++ b/pkg/handler/socks/v5/metadata.go @@ -0,0 +1,62 @@ +package v5 + +import ( + "crypto/tls" + "strings" + "time" + + "github.com/go-gost/gost/pkg/auth" + "github.com/go-gost/gost/pkg/internal/utils" + md "github.com/go-gost/gost/pkg/metadata" +) + +const ( + certFile = "certFile" + keyFile = "keyFile" + caFile = "caFile" + authsKey = "auths" + readTimeout = "readTimeout" + retryCount = "retry" + noTLS = "notls" +) + +type metadata struct { + tlsConfig *tls.Config + authenticator auth.Authenticator + readTimeout time.Duration + retryCount int + noTLS bool +} + +func (h *socks5Handler) parseMetadata(md md.Metadata) error { + var err error + h.md.tlsConfig, err = utils.LoadTLSConfig( + md.GetString(certFile), + md.GetString(keyFile), + md.GetString(caFile), + ) + if err != nil { + h.logger.Warn("parse tls config: ", err) + } + + if v, _ := md.Get(authsKey).([]interface{}); len(v) > 0 { + authenticator := auth.NewLocalAuthenticator(nil) + for _, auth := range v { + if s, _ := auth.(string); s != "" { + ss := strings.SplitN(s, ":", 2) + if len(ss) == 1 { + authenticator.Add(ss[0], "") + } else { + authenticator.Add(ss[0], ss[1]) + } + } + } + h.md.authenticator = authenticator + } + + h.md.readTimeout = md.GetDuration(readTimeout) + h.md.retryCount = md.GetInt(retryCount) + h.md.noTLS = md.GetBool(noTLS) + + return nil +} diff --git a/pkg/handler/socks/v5/selector.go b/pkg/handler/socks/v5/selector.go new file mode 100644 index 0000000..cc8feb5 --- /dev/null +++ b/pkg/handler/socks/v5/selector.go @@ -0,0 +1,97 @@ +package v5 + +import ( + "crypto/tls" + "net" + + "github.com/go-gost/gosocks5" + "github.com/go-gost/gost/pkg/auth" + "github.com/go-gost/gost/pkg/logger" +) + +type serverSelector struct { + methods []uint8 + Authenticator auth.Authenticator + TLSConfig *tls.Config + logger logger.Logger + noTLS bool +} + +func (selector *serverSelector) Methods() []uint8 { + return selector.methods +} + +func (s *serverSelector) Select(methods ...uint8) (method uint8) { + if s.logger.IsLevelEnabled(logger.DebugLevel) { + s.logger.Debugf("%d %d %v", gosocks5.Ver5, len(methods), methods) + } + method = gosocks5.MethodNoAuth + for _, m := range methods { + if m == MethodTLS && !s.noTLS { + method = m + break + } + } + + // when Authenticator is set, auth is mandatory + if s.Authenticator != nil { + if method == gosocks5.MethodNoAuth { + method = gosocks5.MethodUserPass + } + if method == MethodTLS && !s.noTLS { + method = MethodTLSAuth + } + } + + return +} + +func (s *serverSelector) OnSelected(method uint8, conn net.Conn) (net.Conn, error) { + if s.logger.IsLevelEnabled(logger.DebugLevel) { + s.logger.Debugf("%d %d", gosocks5.Ver5, method) + } + switch method { + case MethodTLS: + conn = tls.Server(conn, s.TLSConfig) + + case gosocks5.MethodUserPass, MethodTLSAuth: + if method == MethodTLSAuth { + conn = tls.Server(conn, s.TLSConfig) + } + + req, err := gosocks5.ReadUserPassRequest(conn) + if err != nil { + s.logger.Error(err) + return nil, err + } + if s.logger.IsLevelEnabled(logger.DebugLevel) { + s.logger.Debug(req.String()) + } + + if s.Authenticator != nil && + !s.Authenticator.Authenticate(req.Username, req.Password) { + resp := gosocks5.NewUserPassResponse(gosocks5.UserPassVer, gosocks5.Failure) + if err := resp.Write(conn); err != nil { + s.logger.Error(err) + return nil, err + } + if s.logger.IsLevelEnabled(logger.DebugLevel) { + s.logger.Info(resp.String()) + } + return nil, gosocks5.ErrAuthFailure + } + + resp := gosocks5.NewUserPassResponse(gosocks5.UserPassVer, gosocks5.Succeeded) + if err := resp.Write(conn); err != nil { + s.logger.Error(err) + return nil, err + } + if s.logger.IsLevelEnabled(logger.DebugLevel) { + s.logger.Debug(resp.String()) + } + case gosocks5.MethodNoAcceptable: + return nil, gosocks5.ErrBadMethod + } + + return conn, nil +} diff --git a/pkg/handler/ss/handler.go b/pkg/handler/ss/handler.go index 5ddd81d..af0ff95 100644 --- a/pkg/handler/ss/handler.go +++ b/pkg/handler/ss/handler.go @@ -1,9 +1,7 @@ package ss import ( - "bytes" "context" - "fmt" "io" "io/ioutil" "net" @@ -87,22 +85,26 @@ func (h *ssHandler) Handle(ctx context.Context, conn net.Conn) { "dst": addr.String(), }) - h.logger.Infof("%s > %s", conn.RemoteAddr(), addr) + h.logger.Infof("%s >> %s", conn.RemoteAddr(), addr) if h.bypass != nil && h.bypass.Contains(addr.String()) { h.logger.Info("bypass: ", addr.String()) return } - cc, err := h.dial(ctx, addr.String()) + r := (&handler.Router{}). + WithChain(h.chain). + WithRetry(h.md.retryCount). + WithLogger(h.logger) + cc, err := r.Dial(ctx, "tcp", addr.String()) if err != nil { return } defer cc.Close() - h.logger.Infof("%s <> %s", conn.RemoteAddr(), addr) + h.logger.Infof("%s <-> %s", conn.RemoteAddr(), addr) handler.Transport(sc, cc) - h.logger.Infof("%s >< %s", conn.RemoteAddr(), addr) + h.logger.Infof("%s >-< %s", conn.RemoteAddr(), addr) } func (h *ssHandler) discard(conn net.Conn) { @@ -123,31 +125,3 @@ func (h *ssHandler) parseMetadata(md md.Metadata) (err error) { h.md.retryCount = md.GetInt(retryCount) return } - -func (h *ssHandler) dial(ctx context.Context, addr string) (conn net.Conn, err error) { - count := h.md.retryCount + 1 - if count <= 0 { - count = 1 - } - - for i := 0; i < count; i++ { - route := h.chain.GetRouteFor(addr) - - if h.logger.IsLevelEnabled(logger.DebugLevel) { - buf := bytes.Buffer{} - for _, node := range route.Path() { - fmt.Fprintf(&buf, "%s@%s > ", node.Name(), node.Addr()) - } - fmt.Fprintf(&buf, "%s", addr) - h.logger.Debugf("route(retry=%d): %s", i, buf.String()) - } - - conn, err = route.Dial(ctx, "tcp", addr) - if err == nil { - break - } - h.logger.Errorf("route(retry=%d): %s", i, err) - } - - return -} diff --git a/pkg/handler/ssu/handler.go b/pkg/handler/ssu/handler.go index 1182f4d..0826965 100644 --- a/pkg/handler/ssu/handler.go +++ b/pkg/handler/ssu/handler.go @@ -1,9 +1,12 @@ -package ss +package ssu import ( "context" "net" + "time" + "github.com/go-gost/gost/pkg/bypass" + "github.com/go-gost/gost/pkg/chain" "github.com/go-gost/gost/pkg/handler" "github.com/go-gost/gost/pkg/logger" md "github.com/go-gost/gost/pkg/metadata" @@ -17,6 +20,8 @@ func init() { } type ssuHandler struct { + chain *chain.Chain + bypass bypass.Bypass logger logger.Logger md metadata } @@ -28,6 +33,8 @@ func NewHandler(opts ...handler.Option) handler.Handler { } return &ssuHandler{ + chain: options.Chain, + bypass: options.Bypass, logger: options.Logger, } } @@ -38,6 +45,21 @@ func (h *ssuHandler) Init(md md.Metadata) (err error) { func (h *ssuHandler) Handle(ctx context.Context, conn net.Conn) { defer conn.Close() + + start := time.Now() + + h.logger = h.logger.WithFields(map[string]interface{}{ + "remote": conn.RemoteAddr().String(), + "local": conn.LocalAddr().String(), + }) + + h.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr()) + defer func() { + h.logger.WithFields(map[string]interface{}{ + "duration": time.Since(start), + }).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr()) + }() + } func (h *ssuHandler) parseMetadata(md md.Metadata) (err error) { @@ -51,6 +73,7 @@ func (h *ssuHandler) parseMetadata(md md.Metadata) (err error) { } h.md.readTimeout = md.GetDuration(readTimeout) + h.md.retryCount = md.GetInt(retryCount) return } diff --git a/pkg/handler/ssu/metadata.go b/pkg/handler/ssu/metadata.go index 76d49f0..89412af 100644 --- a/pkg/handler/ssu/metadata.go +++ b/pkg/handler/ssu/metadata.go @@ -1,4 +1,4 @@ -package ss +package ssu import ( "time" @@ -11,9 +11,11 @@ const ( password = "password" key = "key" readTimeout = "readTimeout" + retryCount = "retry" ) type metadata struct { cipher core.Cipher readTimeout time.Duration + retryCount int } diff --git a/pkg/listener/obfs/http/listener.go b/pkg/listener/obfs/http/listener.go index 793dd1b..d1a97e3 100644 --- a/pkg/listener/obfs/http/listener.go +++ b/pkg/listener/obfs/http/listener.go @@ -11,7 +11,7 @@ import ( ) func init() { - registry.RegisterListener("obfs-http", NewListener) + registry.RegisterListener("ohttp", NewListener) } type obfsListener struct { diff --git a/pkg/listener/obfs/tls/listener.go b/pkg/listener/obfs/tls/listener.go index 046f903..20a2973 100644 --- a/pkg/listener/obfs/tls/listener.go +++ b/pkg/listener/obfs/tls/listener.go @@ -11,7 +11,7 @@ import ( ) func init() { - registry.RegisterListener("obfs-tls", NewListener) + registry.RegisterListener("otls", NewListener) } type obfsListener struct { diff --git a/pkg/logger/gost_logger.go b/pkg/logger/gost_logger.go index 4783065..28c125b 100644 --- a/pkg/logger/gost_logger.go +++ b/pkg/logger/gost_logger.go @@ -62,11 +62,13 @@ func (l *logger) Errorf(format string, args ...interface{}) { // Fatal logs a message at level Fatal then the process will exit with status set to 1. func (l *logger) Fatal(args ...interface{}) { l.log(logrus.FatalLevel, args...) + l.logger.Logger.Exit(1) } // Fatalf logs a message at level Fatal then the process will exit with status set to 1. func (l *logger) Fatalf(format string, args ...interface{}) { l.logf(logrus.FatalLevel, format, args...) + l.logger.Logger.Exit(1) } func (l *logger) GetLevel() LogLevel {