diff --git a/connector/socks/v5/conn.go b/connector/socks/v5/conn.go index d11f3b5..3735ca8 100644 --- a/connector/socks/v5/conn.go +++ b/connector/socks/v5/conn.go @@ -1,6 +1,14 @@ package v5 -import "net" +import ( + "bytes" + "net" + "time" + + "github.com/go-gost/core/common/bufpool" + "github.com/go-gost/core/logger" + "github.com/go-gost/gosocks5" +) type bindConn struct { net.Conn @@ -15,3 +23,103 @@ func (c *bindConn) LocalAddr() net.Addr { func (c *bindConn) RemoteAddr() net.Addr { return c.remoteAddr } + +type udpRelayConn struct { + udpConn *net.UDPConn + tcpConn net.Conn + taddr net.Addr + bufferSize int + logger logger.Logger +} + +func (c *udpRelayConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { + buf := bufpool.Get(c.bufferSize) + defer bufpool.Put(buf) + + nn, err := c.udpConn.Read(*buf) + if err != nil { + return + } + + socksAddr := gosocks5.Addr{} + header := gosocks5.UDPHeader{ + Addr: &socksAddr, + } + dgram := gosocks5.UDPDatagram{ + Header: &header, + } + _, err = dgram.ReadFrom(bytes.NewReader((*buf)[:nn])) + if err != nil { + return + } + + n = copy(b, dgram.Data) + addr, err = net.ResolveUDPAddr("udp", header.Addr.String()) + + return +} + +func (c *udpRelayConn) Read(b []byte) (n int, err error) { + n, _, err = c.ReadFrom(b) + return +} + +func (c *udpRelayConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { + socksAddr := gosocks5.Addr{} + if err = socksAddr.ParseFrom(addr.String()); err != nil { + return + } + + header := gosocks5.UDPHeader{ + Addr: &socksAddr, + } + dgram := gosocks5.UDPDatagram{ + Header: &header, + Data: b, + } + + buf := bufpool.Get(c.bufferSize) + defer bufpool.Put(buf) + + nn, err := dgram.WriteTo(bytes.NewBuffer((*buf)[:0])) + if err != nil { + return + } + if nn > int64(len(*buf)) { + nn = int64(len(*buf)) + } + + _, err = c.udpConn.Write((*buf)[:nn]) + n = len(b) + + return +} + +func (c *udpRelayConn) Write(b []byte) (n int, err error) { + return c.WriteTo(b, c.taddr) +} + +func (c *udpRelayConn) RemoteAddr() net.Addr { + return c.taddr +} + +func (c *udpRelayConn) LocalAddr() net.Addr { + return c.udpConn.LocalAddr() +} + +func (c *udpRelayConn) Close() error { + c.udpConn.Close() + return c.tcpConn.Close() +} + +func (c *udpRelayConn) SetDeadline(t time.Time) error { + return c.udpConn.SetDeadline(t) +} + +func (c *udpRelayConn) SetReadDeadline(t time.Time) error { + return c.udpConn.SetReadDeadline(t) +} + +func (c *udpRelayConn) SetWriteDeadline(t time.Time) error { + return c.udpConn.SetWriteDeadline(t) +} diff --git a/connector/socks/v5/connector.go b/connector/socks/v5/connector.go index 377ad97..d22be40 100644 --- a/connector/socks/v5/connector.go +++ b/connector/socks/v5/connector.go @@ -151,6 +151,10 @@ func (c *socks5Connector) connectUDP(ctx context.Context, conn net.Conn, network return nil, err } + if c.md.relay == "udp" { + return c.relayUDP(ctx, conn, addr, log) + } + req := gosocks5.NewRequest(socks.CmdUDPTun, nil) if err := req.Write(conn); err != nil { log.Error(err) @@ -171,3 +175,36 @@ func (c *socks5Connector) connectUDP(ctx context.Context, conn net.Conn, network return socks.UDPTunClientConn(conn, addr), nil } + +func (c *socks5Connector) relayUDP(ctx context.Context, conn net.Conn, addr net.Addr, log logger.Logger) (net.Conn, error) { + req := gosocks5.NewRequest(gosocks5.CmdUdp, nil) + if err := req.Write(conn); err != nil { + log.Error(err) + return nil, err + } + log.Debug(req) + + reply, err := gosocks5.ReadReply(conn) + if err != nil { + log.Error(err) + return nil, err + } + log.Debug(reply) + + if reply.Rep != gosocks5.Succeeded { + return nil, errors.New("get socks5 UDP tunnel failure") + } + + cc, err := (&net.Dialer{}).DialContext(ctx, "udp", reply.Addr.String()) + if err != nil { + return nil, err + } + + return &udpRelayConn{ + udpConn: cc.(*net.UDPConn), + tcpConn: conn, + taddr: addr, + bufferSize: c.md.udpBufferSize, + logger: log, + }, nil +} diff --git a/connector/socks/v5/metadata.go b/connector/socks/v5/metadata.go index e02f06b..4fc17b6 100644 --- a/connector/socks/v5/metadata.go +++ b/connector/socks/v5/metadata.go @@ -7,19 +7,32 @@ import ( mdx "github.com/go-gost/x/metadata" ) +const ( + defaultUDPBufferSize = 4096 +) + type metadata struct { connectTimeout time.Duration noTLS bool + relay string + udpBufferSize int } func (c *socks5Connector) parseMetadata(md mdata.Metadata) (err error) { const ( connectTimeout = "timeout" noTLS = "notls" + relay = "relay" + udpBufferSize = "udpBufferSize" ) c.md.connectTimeout = mdx.GetDuration(md, connectTimeout) c.md.noTLS = mdx.GetBool(md, noTLS) + c.md.relay = mdx.GetString(md, relay) + c.md.udpBufferSize = mdx.GetInt(md, udpBufferSize) + if c.md.udpBufferSize <= 0 { + c.md.udpBufferSize = defaultUDPBufferSize + } return }