diff --git a/go.mod b/go.mod index b993fc2..956ef36 100644 --- a/go.mod +++ b/go.mod @@ -8,7 +8,7 @@ require ( github.com/ginuerzh/tls-dissector v0.0.2-0.20201202075250-98fa925912da github.com/go-gost/gosocks4 v0.0.1 github.com/go-gost/gosocks5 v0.3.1-0.20211109033403-d894d75b7f09 - github.com/go-gost/relay v0.1.1-0.20211122150329-54ee406ea49d + github.com/go-gost/relay v0.1.1-0.20211123134818-8ef7fd81ffd7 github.com/gobwas/glob v0.2.3 github.com/golang/snappy v0.0.3 github.com/google/gopacket v1.1.19 // indirect diff --git a/go.sum b/go.sum index b011a3a..c3b2c9f 100644 --- a/go.sum +++ b/go.sum @@ -115,6 +115,8 @@ github.com/go-gost/gosocks5 v0.3.1-0.20211109033403-d894d75b7f09 h1:A95M6UWcfZgO github.com/go-gost/gosocks5 v0.3.1-0.20211109033403-d894d75b7f09/go.mod h1:1G6I7HP7VFVxveGkoK8mnprnJqSqJjdcASKsdUn4Pp4= github.com/go-gost/relay v0.1.1-0.20211122150329-54ee406ea49d h1:rzGVzkSvxuDZg8PoYmOR+tvcAg9Dr8whgV19kzuO4YA= github.com/go-gost/relay v0.1.1-0.20211122150329-54ee406ea49d/go.mod h1:lcX+23LCQ3khIeASBo+tJ/WbwXFO32/N5YN6ucuYTG8= +github.com/go-gost/relay v0.1.1-0.20211123134818-8ef7fd81ffd7 h1:itaaJhQJ19kUXEB4Igb0EbY8m+1Py2AaNNSBds/9gk4= +github.com/go-gost/relay v0.1.1-0.20211123134818-8ef7fd81ffd7/go.mod h1:lcX+23LCQ3khIeASBo+tJ/WbwXFO32/N5YN6ucuYTG8= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 h1:p104kn46Q8WdvHunIJ9dAyjPVtrBPhSr3KT2yUst43I= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= github.com/gobwas/glob v0.2.3 h1:A4xDbljILXROh+kObIiy5kIaPYD8e96x1tgBhUI5J+Y= diff --git a/pkg/handler/forward/local/handler.go b/pkg/handler/forward/local/handler.go index 6278c76..c39feb9 100644 --- a/pkg/handler/forward/local/handler.go +++ b/pkg/handler/forward/local/handler.go @@ -2,6 +2,7 @@ package forward import ( "context" + "fmt" "net" "time" @@ -42,7 +43,7 @@ func (h *forwardHandler) Init(md md.Metadata) (err error) { return h.parseMetadata(md) } -// implements chain.Chainable interface +// WithChain implements chain.Chainable interface func (h *forwardHandler) WithChain(chain *chain.Chain) { h.chain = chain } @@ -74,8 +75,13 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn) { return } + network := "tcp" + if _, ok := conn.(net.PacketConn); ok { + network = "udp" + } + h.logger = h.logger.WithFields(map[string]interface{}{ - "dst": target.Addr(), + "dst": fmt.Sprintf("%s/%s", target.Addr(), network), }) h.logger.Infof("%s >> %s", conn.RemoteAddr(), target.Addr()) @@ -85,11 +91,6 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn) { WithRetry(h.md.retryCount). WithLogger(h.logger) - network := "tcp" - if _, ok := conn.(net.PacketConn); ok { - network = "udp" - } - cc, err := r.Dial(ctx, network, target.Addr()) if err != nil { h.logger.Error(err) diff --git a/pkg/handler/forward/remote/handler.go b/pkg/handler/forward/remote/handler.go index f8a07b1..af118da 100644 --- a/pkg/handler/forward/remote/handler.go +++ b/pkg/handler/forward/remote/handler.go @@ -2,6 +2,7 @@ package forward import ( "context" + "fmt" "net" "time" @@ -68,8 +69,13 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn) { return } + network := "tcp" + if _, ok := conn.(net.PacketConn); ok { + network = "udp" + } + h.logger = h.logger.WithFields(map[string]interface{}{ - "dst": target.Addr(), + "dst": fmt.Sprintf("%s/%s", target.Addr(), network), }) h.logger.Infof("%s >> %s", conn.RemoteAddr(), target.Addr()) @@ -79,11 +85,6 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn) { WithRetry(h.md.retryCount). WithLogger(h.logger) - network := "tcp" - if _, ok := conn.(net.PacketConn); ok { - network = "udp" - } - cc, err := r.Dial(ctx, network, target.Addr()) if err != nil { h.logger.Error(err) diff --git a/pkg/handler/relay/bind.go b/pkg/handler/relay/bind.go new file mode 100644 index 0000000..8602f89 --- /dev/null +++ b/pkg/handler/relay/bind.go @@ -0,0 +1,259 @@ +package relay + +import ( + "context" + "fmt" + "net" + "time" + + "github.com/go-gost/gost/pkg/common/bufpool" + "github.com/go-gost/gost/pkg/common/util/mux" + "github.com/go-gost/gost/pkg/common/util/socks" + "github.com/go-gost/gost/pkg/handler" + "github.com/go-gost/relay" +) + +func (h *relayHandler) handleBind(ctx context.Context, conn net.Conn, network, address string) { + h.logger = h.logger.WithFields(map[string]interface{}{ + "dst": fmt.Sprintf("%s/%s", address, network), + "cmd": "bind", + }) + + h.logger.Infof("%s >> %s", conn.RemoteAddr(), address) + + resp := relay.Response{ + Version: relay.Version1, + Status: relay.StatusOK, + } + + if !h.md.enableBind { + resp.Status = relay.StatusForbidden + resp.WriteTo(conn) + h.logger.Error("BIND is diabled") + return + } + + if network == "tcp" { + h.bindTCP(ctx, conn, network, address) + } else { + h.bindUDP(ctx, conn, network, address) + } +} + +func (h *relayHandler) bindTCP(ctx context.Context, conn net.Conn, network, address string) { + resp := relay.Response{ + Version: relay.Version1, + Status: relay.StatusOK, + } + + ln, err := net.Listen(network, address) // strict mode: if the port already in use, it will return error + if err != nil { + h.logger.Error(err) + resp.Status = relay.StatusServiceUnavailable + resp.WriteTo(conn) + return + } + + af := &relay.AddrFeature{} + err = af.ParseFrom(ln.Addr().String()) + if err != nil { + h.logger.Warn(err) + } + + // Issue: may not reachable when host has multi-interface + af.Host, _, _ = net.SplitHostPort(conn.LocalAddr().String()) + af.AType = relay.AddrIPv4 + resp.Features = append(resp.Features, af) + if _, err := resp.WriteTo(conn); err != nil { + h.logger.Error(err) + ln.Close() + return + } + + h.logger = h.logger.WithFields(map[string]interface{}{ + "bind": fmt.Sprintf("%s/%s", ln.Addr(), ln.Addr().Network()), + }) + h.logger.Debugf("bind on %s OK", ln.Addr()) + + h.serveTCPBind(ctx, conn, ln) +} + +func (h *relayHandler) bindUDP(ctx context.Context, conn net.Conn, network, address string) { + resp := relay.Response{ + Version: relay.Version1, + Status: relay.StatusOK, + } + + bindAddr, _ := net.ResolveUDPAddr(network, address) + pc, err := net.ListenUDP(network, bindAddr) + if err != nil { + h.logger.Error(err) + return + } + defer pc.Close() + + af := &relay.AddrFeature{} + err = af.ParseFrom(pc.LocalAddr().String()) + if err != nil { + h.logger.Warn(err) + } + + // Issue: may not reachable when host has multi-interface + af.Host, _, _ = net.SplitHostPort(conn.LocalAddr().String()) + af.AType = relay.AddrIPv4 + resp.Features = append(resp.Features, af) + if _, err := resp.WriteTo(conn); err != nil { + h.logger.Error(err) + return + } + + h.logger = h.logger.WithFields(map[string]interface{}{ + "bind": pc.LocalAddr().String(), + }) + h.logger.Debugf("bind on %s OK", pc.LocalAddr()) + + t := time.Now() + h.logger.Infof("%s <-> %s", conn.RemoteAddr(), pc.LocalAddr()) + h.tunnelServerUDP( + socks.UDPTunServerConn(conn), + pc, + ) + h.logger. + WithFields(map[string]interface{}{ + "duration": time.Since(t), + }). + Infof("%s >-< %s", conn.RemoteAddr(), pc.LocalAddr()) +} + +func (h *relayHandler) serveTCPBind(ctx context.Context, conn net.Conn, ln net.Listener) { + // Upgrade connection to multiplex stream. + session, err := mux.ClientSession(conn) + if err != nil { + h.logger.Error(err) + return + } + defer session.Close() + + go func() { + defer ln.Close() + for { + conn, err := session.Accept() + if err != nil { + h.logger.Error(err) + return + } + conn.Close() // we do not handle incoming connections. + } + }() + + for { + rc, err := ln.Accept() + if err != nil { + h.logger.Error(err) + return + } + h.logger.Debugf("peer %s accepted", rc.RemoteAddr()) + + go func(c net.Conn) { + defer c.Close() + + sc, err := session.GetConn() + if err != nil { + h.logger.Error(err) + return + } + defer sc.Close() + + af := &relay.AddrFeature{} + af.ParseFrom(c.RemoteAddr().String()) + resp := relay.Response{ + Version: relay.Version1, + Status: relay.StatusOK, + Features: []relay.Feature{af}, + } + if _, err := resp.WriteTo(sc); err != nil { + h.logger.Error(err) + return + } + + t := time.Now() + h.logger.Infof("%s <-> %s", conn.RemoteAddr(), c.RemoteAddr().String()) + handler.Transport(sc, c) + h.logger. + WithFields(map[string]interface{}{"duration": time.Since(t)}). + Infof("%s >-< %s", conn.RemoteAddr(), c.RemoteAddr().String()) + }(rc) + } +} + +func (h *relayHandler) tunnelServerUDP(tunnel, c net.PacketConn) (err error) { + bufSize := h.md.udpBufferSize + errc := make(chan error, 2) + + go func() { + for { + err := func() error { + b := bufpool.Get(bufSize) + defer bufpool.Put(b) + + n, raddr, err := tunnel.ReadFrom(b) + if err != nil { + return err + } + + if h.bypass != nil && h.bypass.Contains(raddr.String()) { + h.logger.Warn("bypass: ", raddr) + return nil + } + + if _, err := c.WriteTo(b[:n], raddr); err != nil { + return err + } + + h.logger.Debugf("%s >>> %s data: %d", + c.LocalAddr(), raddr, n) + + return nil + }() + + if err != nil { + errc <- err + return + } + } + }() + + go func() { + for { + err := func() error { + b := bufpool.Get(bufSize) + defer bufpool.Put(b) + + n, raddr, err := c.ReadFrom(b) + if err != nil { + return err + } + + if h.bypass != nil && h.bypass.Contains(raddr.String()) { + h.logger.Warn("bypass: ", raddr) + return nil + } + + if _, err := tunnel.WriteTo(b[:n], raddr); err != nil { + return err + } + h.logger.Debugf("%s <<< %s data: %d", + c.LocalAddr(), raddr, n) + + return nil + }() + + if err != nil { + errc <- err + return + } + } + }() + + return <-errc +} diff --git a/pkg/handler/relay/proxy.go b/pkg/handler/relay/connect.go similarity index 75% rename from pkg/handler/relay/proxy.go rename to pkg/handler/relay/connect.go index 1887da7..15de5b6 100644 --- a/pkg/handler/relay/proxy.go +++ b/pkg/handler/relay/connect.go @@ -2,6 +2,7 @@ package relay import ( "context" + "fmt" "net" "time" @@ -11,7 +12,12 @@ import ( "github.com/go-gost/relay" ) -func (h *relayHandler) handleProxy(ctx context.Context, conn net.Conn, network, address string) { +func (h *relayHandler) handleConnect(ctx context.Context, conn net.Conn, network, address string) { + h.logger = h.logger.WithFields(map[string]interface{}{ + "dst": fmt.Sprintf("%s/%s", address, network), + "cmd": "connect", + }) + h.logger.Infof("%s >> %s", conn.RemoteAddr(), address) resp := relay.Response{ @@ -19,6 +25,13 @@ func (h *relayHandler) handleProxy(ctx context.Context, conn net.Conn, network, Status: relay.StatusOK, } + if address == "" { + resp.Status = relay.StatusBadRequest + resp.WriteTo(conn) + h.logger.Error("target not specified") + return + } + if h.bypass != nil && h.bypass.Contains(address) { h.logger.Info("bypass: ", address) resp.Status = relay.StatusForbidden diff --git a/pkg/handler/relay/forward.go b/pkg/handler/relay/forward.go index 34686ee..1ed55b6 100644 --- a/pkg/handler/relay/forward.go +++ b/pkg/handler/relay/forward.go @@ -2,6 +2,7 @@ package relay import ( "context" + "fmt" "net" "time" @@ -17,7 +18,7 @@ func (h *relayHandler) handleForward(ctx context.Context, conn net.Conn, network } h.logger = h.logger.WithFields(map[string]interface{}{ - "dst": target.Addr(), + "dst": fmt.Sprintf("%s/%s", target.Addr(), network), }) h.logger.Infof("%s >> %s", conn.RemoteAddr(), target.Addr()) diff --git a/pkg/handler/relay/handler.go b/pkg/handler/relay/handler.go index b116c96..e00c832 100644 --- a/pkg/handler/relay/handler.go +++ b/pkg/handler/relay/handler.go @@ -43,7 +43,7 @@ func (h *relayHandler) Init(md md.Metadata) (err error) { return h.parseMetadata(md) } -// implements chain.Chainable interface +// WithChain implements chain.Chainable interface func (h *relayHandler) WithChain(chain *chain.Chain) { h.chain = chain } @@ -87,7 +87,7 @@ func (h *relayHandler) Handle(ctx context.Context, conn net.Conn) { } var user, pass string - var target string + var address string for _, f := range req.Features { if f.Type() == relay.FeatureUserAuth { feature := f.(*relay.UserAuthFeature) @@ -95,16 +95,13 @@ func (h *relayHandler) Handle(ctx context.Context, conn net.Conn) { } if f.Type() == relay.FeatureAddr { feature := f.(*relay.AddrFeature) - target = net.JoinHostPort(feature.Host, strconv.Itoa(int(feature.Port))) + address = net.JoinHostPort(feature.Host, strconv.Itoa(int(feature.Port))) } } if user != "" { h.logger = h.logger.WithFields(map[string]interface{}{"user": user}) } - if target != "" { - h.logger = h.logger.WithFields(map[string]interface{}{"dst": target}) - } resp := relay.Response{ Version: relay.Version1, @@ -123,7 +120,7 @@ func (h *relayHandler) Handle(ctx context.Context, conn net.Conn) { } if h.group != nil { - if target != "" { + if address != "" { resp.Status = relay.StatusForbidden resp.WriteTo(conn) h.logger.Error("forbidden") @@ -134,13 +131,11 @@ func (h *relayHandler) Handle(ctx context.Context, conn net.Conn) { return } - if target == "" { - resp.Status = relay.StatusBadRequest - resp.WriteTo(conn) - h.logger.Error("target not specified") - return + switch req.Flags & relay.CmdMask { + case relay.CONNECT: + h.handleConnect(ctx, conn, network, address) + case relay.BIND: + h.handleBind(ctx, conn, network, address) + case relay.ASSOCIATE: } - - // proxy mode - h.handleProxy(ctx, conn, network, target) } diff --git a/pkg/handler/relay/metadata.go b/pkg/handler/relay/metadata.go index 4a7b5e0..46a408c 100644 --- a/pkg/handler/relay/metadata.go +++ b/pkg/handler/relay/metadata.go @@ -12,13 +12,17 @@ type metadata struct { authenticator auth.Authenticator readTimeout time.Duration retryCount int + enableBind bool + udpBufferSize int } func (h *relayHandler) parseMetadata(md md.Metadata) (err error) { const ( - users = "users" - readTimeout = "readTimeout" - retryCount = "retry" + users = "users" + readTimeout = "readTimeout" + retryCount = "retry" + enableBind = "bind" + udpBufferSize = "udpBufferSize" ) if v, _ := md.Get(users).([]interface{}); len(v) > 0 { @@ -37,5 +41,17 @@ func (h *relayHandler) parseMetadata(md md.Metadata) (err error) { } h.md.readTimeout = md.GetDuration(readTimeout) h.md.retryCount = md.GetInt(retryCount) + h.md.enableBind = md.GetBool(enableBind) + h.md.udpBufferSize = md.GetInt(udpBufferSize) + if h.md.udpBufferSize > 0 { + if h.md.udpBufferSize < 512 { + h.md.udpBufferSize = 512 // min buffer size + } + if h.md.udpBufferSize > 65*1024 { + h.md.udpBufferSize = 65 * 1024 // max buffer size + } + } else { + h.md.udpBufferSize = 1024 // default buffer size + } return } diff --git a/pkg/handler/socks/v5/bind.go b/pkg/handler/socks/v5/bind.go index f72010d..3eed1f6 100644 --- a/pkg/handler/socks/v5/bind.go +++ b/pkg/handler/socks/v5/bind.go @@ -2,23 +2,21 @@ package v5 import ( "context" + "fmt" "net" "time" "github.com/go-gost/gosocks5" - "github.com/go-gost/gost/pkg/chain" "github.com/go-gost/gost/pkg/handler" ) -func (h *socks5Handler) handleBind(ctx context.Context, conn net.Conn, req *gosocks5.Request) { - addr := req.Addr.String() - +func (h *socks5Handler) handleBind(ctx context.Context, conn net.Conn, network, address string) { h.logger = h.logger.WithFields(map[string]interface{}{ - "dst": addr, + "dst": fmt.Sprintf("%s/%s", address, network), "cmd": "bind", }) - h.logger.Infof("%s >> %s", conn.RemoteAddr(), addr) + h.logger.Infof("%s >> %s", conn.RemoteAddr(), address) if !h.md.enableBind { reply := gosocks5.NewReply(gosocks5.NotAllowed, nil) @@ -28,41 +26,12 @@ func (h *socks5Handler) handleBind(ctx context.Context, conn net.Conn, req *goso return } - if h.chain.IsEmpty() { - h.bindLocal(ctx, conn, addr) - return - } - - r := (&chain.Router{}). - WithChain(h.chain). - WithRetry(h.md.retryCount). - WithLogger(h.logger) - cc, err := r.Connect(ctx) - if err != nil { - resp := gosocks5.NewReply(gosocks5.Failure, nil) - resp.Write(conn) - h.logger.Debug(resp) - return - } - defer cc.Close() - - // forward request - if err := req.Write(cc); err != nil { - h.logger.Error(err) - resp := gosocks5.NewReply(gosocks5.NetUnreachable, nil) - resp.Write(conn) - h.logger.Debug(resp) - return - } - - h.logger.Infof("%s <-> %s", conn.RemoteAddr(), addr) - handler.Transport(conn, cc) - h.logger.Infof("%s >-< %s", conn.RemoteAddr(), addr) + // BIND does not support chain. + h.bindLocal(ctx, conn, network, address) } -func (h *socks5Handler) bindLocal(ctx context.Context, conn net.Conn, addr string) { - bindAddr, _ := net.ResolveTCPAddr("tcp", addr) - ln, err := net.ListenTCP("tcp", bindAddr) // strict mode: if the port already in use, it will return error +func (h *socks5Handler) bindLocal(ctx context.Context, conn net.Conn, network, address string) { + ln, err := net.Listen(network, address) // strict mode: if the port already in use, it will return error if err != nil { h.logger.Error(err) reply := gosocks5.NewReply(gosocks5.Failure, nil) @@ -90,9 +59,10 @@ func (h *socks5Handler) bindLocal(ctx context.Context, conn net.Conn, addr strin h.logger.Debug(reply) h.logger = h.logger.WithFields(map[string]interface{}{ - "bind": socksAddr.String(), + "bind": fmt.Sprintf("%s/%s", ln.Addr(), ln.Addr().Network()), }) - h.logger.Debugf("bind on %s OK", &socksAddr) + + h.logger.Debugf("bind on %s OK", ln.Addr()) h.serveBind(ctx, conn, ln) } diff --git a/pkg/handler/socks/v5/connect.go b/pkg/handler/socks/v5/connect.go index 1723360..fce9572 100644 --- a/pkg/handler/socks/v5/connect.go +++ b/pkg/handler/socks/v5/connect.go @@ -2,6 +2,7 @@ package v5 import ( "context" + "fmt" "net" "time" @@ -10,18 +11,18 @@ import ( "github.com/go-gost/gost/pkg/handler" ) -func (h *socks5Handler) handleConnect(ctx context.Context, conn net.Conn, addr string) { +func (h *socks5Handler) handleConnect(ctx context.Context, conn net.Conn, network, address string) { h.logger = h.logger.WithFields(map[string]interface{}{ - "dst": addr, + "dst": fmt.Sprintf("%s/%s", address, network), "cmd": "connect", }) - h.logger.Infof("%s >> %s", conn.RemoteAddr(), addr) + h.logger.Infof("%s >> %s", conn.RemoteAddr(), address) - if h.bypass != nil && h.bypass.Contains(addr) { + if h.bypass != nil && h.bypass.Contains(address) { resp := gosocks5.NewReply(gosocks5.NotAllowed, nil) resp.Write(conn) h.logger.Debug(resp) - h.logger.Info("bypass: ", addr) + h.logger.Info("bypass: ", address) return } @@ -29,7 +30,7 @@ func (h *socks5Handler) handleConnect(ctx context.Context, conn net.Conn, addr s WithChain(h.chain). WithRetry(h.md.retryCount). WithLogger(h.logger) - cc, err := r.Dial(ctx, "tcp", addr) + cc, err := r.Dial(ctx, network, address) if err != nil { resp := gosocks5.NewReply(gosocks5.NetUnreachable, nil) resp.Write(conn) @@ -47,11 +48,11 @@ func (h *socks5Handler) handleConnect(ctx context.Context, conn net.Conn, addr s h.logger.Debug(resp) t := time.Now() - h.logger.Infof("%s <-> %s", conn.RemoteAddr(), addr) + h.logger.Infof("%s <-> %s", conn.RemoteAddr(), address) handler.Transport(conn, cc) h.logger. WithFields(map[string]interface{}{ "duration": time.Since(t), }). - Infof("%s >-< %s", conn.RemoteAddr(), addr) + Infof("%s >-< %s", conn.RemoteAddr(), address) } diff --git a/pkg/handler/socks/v5/handler.go b/pkg/handler/socks/v5/handler.go index ffedf32..d1d4dd8 100644 --- a/pkg/handler/socks/v5/handler.go +++ b/pkg/handler/socks/v5/handler.go @@ -90,17 +90,19 @@ func (h *socks5Handler) Handle(ctx context.Context, conn net.Conn) { h.logger.Debug(req) conn.SetReadDeadline(time.Time{}) + address := req.Addr.String() + switch req.Cmd { case gosocks5.CmdConnect: - h.handleConnect(ctx, conn, req.Addr.String()) + h.handleConnect(ctx, conn, "tcp", address) case gosocks5.CmdBind: - h.handleBind(ctx, conn, req) + h.handleBind(ctx, conn, "tcp", address) case socks.CmdMuxBind: - h.handleMuxBind(ctx, conn, req) + h.handleMuxBind(ctx, conn, "tcp", address) case gosocks5.CmdUdp: - h.handleUDP(ctx, conn, req) + h.handleUDP(ctx, conn) case socks.CmdUDPTun: - h.handleUDPTun(ctx, conn, req) + h.handleUDPTun(ctx, conn, "udp", address) default: h.logger.Errorf("unknown cmd: %d", req.Cmd) resp := gosocks5.NewReply(gosocks5.CmdUnsupported, nil) diff --git a/pkg/handler/socks/v5/mbind.go b/pkg/handler/socks/v5/mbind.go index e2424ff..bb1fd98 100644 --- a/pkg/handler/socks/v5/mbind.go +++ b/pkg/handler/socks/v5/mbind.go @@ -2,24 +2,22 @@ package v5 import ( "context" + "fmt" "net" "time" "github.com/go-gost/gosocks5" - "github.com/go-gost/gost/pkg/chain" "github.com/go-gost/gost/pkg/common/util/mux" "github.com/go-gost/gost/pkg/handler" ) -func (h *socks5Handler) handleMuxBind(ctx context.Context, conn net.Conn, req *gosocks5.Request) { - addr := req.Addr.String() - +func (h *socks5Handler) handleMuxBind(ctx context.Context, conn net.Conn, network, address string) { h.logger = h.logger.WithFields(map[string]interface{}{ - "dst": addr, + "dst": fmt.Sprintf("%s/%s", address, network), "cmd": "mbind", }) - h.logger.Infof("%s >> %s", conn.RemoteAddr(), addr) + h.logger.Infof("%s >> %s", conn.RemoteAddr(), address) if !h.md.enableBind { reply := gosocks5.NewReply(gosocks5.NotAllowed, nil) @@ -29,46 +27,11 @@ func (h *socks5Handler) handleMuxBind(ctx context.Context, conn net.Conn, req *g return } - if h.chain.IsEmpty() { - h.muxBindLocal(ctx, conn, addr) - return - } - - r := (&chain.Router{}). - WithChain(h.chain). - WithRetry(h.md.retryCount). - WithLogger(h.logger) - cc, err := r.Connect(ctx) - if err != nil { - resp := gosocks5.NewReply(gosocks5.Failure, nil) - resp.Write(conn) - h.logger.Debug(resp) - return - } - defer cc.Close() - - // forward request - if err := req.Write(cc); err != nil { - h.logger.Error(err) - resp := gosocks5.NewReply(gosocks5.NetUnreachable, nil) - resp.Write(conn) - h.logger.Debug(resp) - return - } - - t := time.Now() - h.logger.Infof("%s <-> %s", conn.RemoteAddr(), addr) - handler.Transport(conn, cc) - h.logger. - WithFields(map[string]interface{}{ - "duration": time.Since(t), - }). - Infof("%s >-< %s", conn.RemoteAddr(), addr) + h.muxBindLocal(ctx, conn, network, address) } -func (h *socks5Handler) muxBindLocal(ctx context.Context, conn net.Conn, addr string) { - bindAddr, _ := net.ResolveTCPAddr("tcp", addr) - ln, err := net.ListenTCP("tcp", bindAddr) // strict mode: if the port already in use, it will return error +func (h *socks5Handler) muxBindLocal(ctx context.Context, conn net.Conn, network, address string) { + ln, err := net.Listen(network, address) // strict mode: if the port already in use, it will return error if err != nil { h.logger.Error(err) reply := gosocks5.NewReply(gosocks5.Failure, nil) @@ -80,7 +43,7 @@ func (h *socks5Handler) muxBindLocal(ctx context.Context, conn net.Conn, addr st } socksAddr := gosocks5.Addr{} - socksAddr.ParseFrom(ln.Addr().String()) + err = socksAddr.ParseFrom(ln.Addr().String()) if err != nil { h.logger.Warn(err) } @@ -97,9 +60,10 @@ func (h *socks5Handler) muxBindLocal(ctx context.Context, conn net.Conn, addr st h.logger.Debug(reply) h.logger = h.logger.WithFields(map[string]interface{}{ - "bind": socksAddr.String(), + "bind": fmt.Sprintf("%s/%s", ln.Addr(), ln.Addr().Network()), }) - h.logger.Debugf("bind on %s OK", &socksAddr) + + h.logger.Debugf("bind on %s OK", ln.Addr()) h.serveMuxBind(ctx, conn, ln) } diff --git a/pkg/handler/socks/v5/metadata.go b/pkg/handler/socks/v5/metadata.go index bc04dbe..82343d0 100644 --- a/pkg/handler/socks/v5/metadata.go +++ b/pkg/handler/socks/v5/metadata.go @@ -80,7 +80,7 @@ func (h *socks5Handler) parseMetadata(md md.Metadata) error { h.md.udpBufferSize = 65 * 1024 // max buffer size } } else { - h.md.udpBufferSize = 4096 // default buffer size + h.md.udpBufferSize = 1024 // default buffer size } h.md.compatibilityMode = md.GetBool(compatibilityMode) diff --git a/pkg/handler/socks/v5/udp.go b/pkg/handler/socks/v5/udp.go index e2821e8..725b157 100644 --- a/pkg/handler/socks/v5/udp.go +++ b/pkg/handler/socks/v5/udp.go @@ -3,6 +3,7 @@ package v5 import ( "context" "errors" + "fmt" "io" "io/ioutil" "net" @@ -14,7 +15,7 @@ import ( "github.com/go-gost/gost/pkg/common/util/socks" ) -func (h *socks5Handler) handleUDP(ctx context.Context, conn net.Conn, req *gosocks5.Request) { +func (h *socks5Handler) handleUDP(ctx context.Context, conn net.Conn) { h.logger = h.logger.WithFields(map[string]interface{}{ "cmd": "udp", }) @@ -49,9 +50,9 @@ func (h *socks5Handler) handleUDP(ctx context.Context, conn net.Conn, req *gosoc h.logger.Debug(reply) h.logger = h.logger.WithFields(map[string]interface{}{ - "bind": saddr.String(), + "bind": fmt.Sprintf("%s/%s", relay.LocalAddr(), relay.LocalAddr().Network()), }) - h.logger.Debugf("bind on %s OK", &saddr) + h.logger.Debugf("bind on %s OK", relay.LocalAddr()) if h.chain.IsEmpty() { // serve as standard socks5 udp relay. diff --git a/pkg/handler/socks/v5/udp_tun.go b/pkg/handler/socks/v5/udp_tun.go index 58223bf..ce0a72c 100644 --- a/pkg/handler/socks/v5/udp_tun.go +++ b/pkg/handler/socks/v5/udp_tun.go @@ -2,17 +2,16 @@ package v5 import ( "context" + "fmt" "net" "time" "github.com/go-gost/gosocks5" - "github.com/go-gost/gost/pkg/chain" "github.com/go-gost/gost/pkg/common/bufpool" "github.com/go-gost/gost/pkg/common/util/socks" - "github.com/go-gost/gost/pkg/handler" ) -func (h *socks5Handler) handleUDPTun(ctx context.Context, conn net.Conn, req *gosocks5.Request) { +func (h *socks5Handler) handleUDPTun(ctx context.Context, conn net.Conn, network, address string) { h.logger = h.logger.WithFields(map[string]interface{}{ "cmd": "udp-tun", }) @@ -25,76 +24,41 @@ func (h *socks5Handler) handleUDPTun(ctx context.Context, conn net.Conn, req *go return } - if h.chain.IsEmpty() { - addr := req.Addr.String() - - bindAddr, _ := net.ResolveUDPAddr("udp", addr) - relay, err := net.ListenUDP("udp", bindAddr) - if err != nil { - h.logger.Error(err) - return - } - defer relay.Close() - - saddr, _ := gosocks5.NewAddr(relay.LocalAddr().String()) - saddr.Host, _, _ = net.SplitHostPort(conn.LocalAddr().String()) - saddr.Type = 0 - reply := gosocks5.NewReply(gosocks5.Succeeded, saddr) - if err := reply.Write(conn); err != nil { - h.logger.Error(err) - return - } - h.logger.Debug(reply) - - h.logger = h.logger.WithFields(map[string]interface{}{ - "bind": saddr.String(), - }) - - t := time.Now() - h.logger.Infof("%s <-> %s", conn.RemoteAddr(), saddr) - h.tunnelServerUDP( - socks.UDPTunServerConn(conn), - relay, - ) - h.logger. - WithFields(map[string]interface{}{ - "duration": time.Since(t), - }). - Infof("%s >-< %s", conn.RemoteAddr(), saddr) - - return - } - - r := (&chain.Router{}). - WithChain(h.chain). - WithRetry(h.md.retryCount). - WithLogger(h.logger) - cc, err := r.Connect(ctx) + bindAddr, _ := net.ResolveUDPAddr(network, address) + pc, err := net.ListenUDP(network, bindAddr) if err != nil { h.logger.Error(err) - reply := gosocks5.NewReply(gosocks5.Failure, nil) - reply.Write(conn) - h.logger.Debug(reply) return } - defer cc.Close() + defer pc.Close() - // forward request - if err := req.Write(cc); err != nil { + saddr, _ := gosocks5.NewAddr(pc.LocalAddr().String()) + saddr.Host, _, _ = net.SplitHostPort(conn.LocalAddr().String()) + saddr.Type = 0 + reply := gosocks5.NewReply(gosocks5.Succeeded, saddr) + if err := reply.Write(conn); err != nil { h.logger.Error(err) - reply := gosocks5.NewReply(gosocks5.Failure, nil) - reply.Write(conn) - h.logger.Debug(reply) + return } + h.logger.Debug(reply) + + h.logger = h.logger.WithFields(map[string]interface{}{ + "bind": fmt.Sprintf("%s/%s", pc.LocalAddr(), pc.LocalAddr().Network()), + }) + + h.logger.Debugf("bind on %s OK", pc.LocalAddr()) t := time.Now() - h.logger.Infof("%s <-> %s", conn.RemoteAddr(), cc.RemoteAddr()) - handler.Transport(conn, cc) + h.logger.Infof("%s <-> %s", conn.RemoteAddr(), pc.LocalAddr()) + h.tunnelServerUDP( + socks.UDPTunServerConn(conn), + pc, + ) h.logger. WithFields(map[string]interface{}{ "duration": time.Since(t), }). - Infof("%s >-< %s", conn.RemoteAddr(), cc.RemoteAddr()) + Infof("%s >-< %s", conn.RemoteAddr(), pc.LocalAddr()) } func (h *socks5Handler) tunnelServerUDP(tunnel, c net.PacketConn) (err error) { diff --git a/pkg/handler/ss/udp/handler.go b/pkg/handler/ss/udp/handler.go index 7486b7f..f9e670d 100644 --- a/pkg/handler/ss/udp/handler.go +++ b/pkg/handler/ss/udp/handler.go @@ -43,7 +43,7 @@ func (h *ssuHandler) Init(md md.Metadata) (err error) { return h.parseMetadata(md) } -// implements chain.Chainable interface +// WithChain implements chain.Chainable interface func (h *ssuHandler) WithChain(chain *chain.Chain) { h.chain = chain } diff --git a/pkg/handler/ss/udp/metadata.go b/pkg/handler/ss/udp/metadata.go index e08ef40..49c5cf1 100644 --- a/pkg/handler/ss/udp/metadata.go +++ b/pkg/handler/ss/udp/metadata.go @@ -55,7 +55,7 @@ func (h *ssuHandler) parseMetadata(md md.Metadata) (err error) { h.md.bufferSize = 65 * 1024 // max buffer size } } else { - h.md.bufferSize = 4096 // default buffer size + h.md.bufferSize = 1024 // default buffer size } return }