From 1a1ee384b70164cea973f5bc6d084c3379c4ddf8 Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Tue, 16 Nov 2021 18:58:47 +0800 Subject: [PATCH] update socks5 conn --- pkg/common/util/socks/conn.go | 48 +++++++++++++++++------------------ pkg/connector/ss/connector.go | 14 +++++----- pkg/handler/socks/v5/udp.go | 6 ++--- pkg/listener/rudp/listener.go | 2 +- 4 files changed, 34 insertions(+), 36 deletions(-) diff --git a/pkg/common/util/socks/conn.go b/pkg/common/util/socks/conn.go index 06ef82b..12e977d 100644 --- a/pkg/common/util/socks/conn.go +++ b/pkg/common/util/socks/conn.go @@ -8,33 +8,31 @@ import ( "github.com/go-gost/gost/pkg/common/bufpool" ) -var ( - _ net.PacketConn = (*UDPTunConn)(nil) - _ net.Conn = (*UDPTunConn)(nil) - - _ net.PacketConn = (*UDPConn)(nil) - _ net.Conn = (*UDPConn)(nil) -) - -type UDPTunConn struct { +type udpTunConn struct { net.Conn taddr net.Addr } -func UDPTunClientConn(c net.Conn, targetAddr net.Addr) *UDPTunConn { - return &UDPTunConn{ +func UDPTunClientConn(c net.Conn, targetAddr net.Addr) net.Conn { + return &udpTunConn{ Conn: c, taddr: targetAddr, } } -func UDPTunServerConn(c net.Conn) *UDPTunConn { - return &UDPTunConn{ +func UDPTunClientPacketConn(c net.Conn) net.PacketConn { + return &udpTunConn{ Conn: c, } } -func (c *UDPTunConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { +func UDPTunServerConn(c net.Conn) net.PacketConn { + return &udpTunConn{ + Conn: c, + } +} + +func (c *udpTunConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { socksAddr := gosocks5.Addr{} header := gosocks5.UDPHeader{ Addr: &socksAddr, @@ -57,12 +55,12 @@ func (c *UDPTunConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { return } -func (c *UDPTunConn) Read(b []byte) (n int, err error) { +func (c *udpTunConn) Read(b []byte) (n int, err error) { n, _, err = c.ReadFrom(b) return } -func (c *UDPTunConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { +func (c *udpTunConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { socksAddr := gosocks5.Addr{} if err = socksAddr.ParseFrom(addr.String()); err != nil { return @@ -83,7 +81,7 @@ func (c *UDPTunConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { return } -func (c *UDPTunConn) Write(b []byte) (n int, err error) { +func (c *udpTunConn) Write(b []byte) (n int, err error) { return c.WriteTo(b, c.taddr) } @@ -91,15 +89,15 @@ var ( DefaultBufferSize = 4096 ) -type UDPConn struct { +type udpConn struct { net.PacketConn raddr net.Addr taddr net.Addr bufferSize int } -func NewUDPConn(c net.PacketConn, bufferSize int) *UDPConn { - return &UDPConn{ +func UDPConn(c net.PacketConn, bufferSize int) net.PacketConn { + return &udpConn{ PacketConn: c, bufferSize: bufferSize, } @@ -108,7 +106,7 @@ func NewUDPConn(c net.PacketConn, bufferSize int) *UDPConn { // ReadFrom reads an UDP datagram. // NOTE: for server side, // the returned addr is the target address the client want to relay to. -func (c *UDPConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { +func (c *udpConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { rbuf := bufpool.Get(c.bufferSize) defer bufpool.Put(rbuf) @@ -131,12 +129,12 @@ func (c *UDPConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { return } -func (c *UDPConn) Read(b []byte) (n int, err error) { +func (c *udpConn) Read(b []byte) (n int, err error) { n, _, err = c.ReadFrom(b) return } -func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { +func (c *udpConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { wbuf := bufpool.Get(c.bufferSize) defer bufpool.Put(wbuf) @@ -165,10 +163,10 @@ func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { return } -func (c *UDPConn) Write(b []byte) (n int, err error) { +func (c *udpConn) Write(b []byte) (n int, err error) { return c.WriteTo(b, c.taddr) } -func (c *UDPConn) RemoteAddr() net.Addr { +func (c *udpConn) RemoteAddr() net.Addr { return c.raddr } diff --git a/pkg/connector/ss/connector.go b/pkg/connector/ss/connector.go index 803650f..28c2128 100644 --- a/pkg/connector/ss/connector.go +++ b/pkg/connector/ss/connector.go @@ -53,13 +53,7 @@ func (c *ssConnector) Connect(ctx context.Context, conn net.Conn, network, addre switch network { case "tcp", "tcp4", "tcp6": case "udp", "udp4", "udp6": - if c.md.enableUDP { - return c.connectUDP(ctx, conn, network, address) - } else { - err := errors.New("UDP relay is disabled") - c.logger.Error(err) - return nil, err - } + return c.connectUDP(ctx, conn, network, address) default: err := fmt.Errorf("network %s unsupported", network) c.logger.Error(err) @@ -105,6 +99,12 @@ func (c *ssConnector) Connect(ctx context.Context, conn net.Conn, network, addre } func (c *ssConnector) connectUDP(ctx context.Context, conn net.Conn, network, address string) (net.Conn, error) { + if c.md.enableUDP { + err := errors.New("UDP relay is disabled") + c.logger.Error(err) + return nil, err + } + if c.md.connectTimeout > 0 { conn.SetDeadline(time.Now().Add(c.md.connectTimeout)) defer conn.SetDeadline(time.Time{}) diff --git a/pkg/handler/socks/v5/udp.go b/pkg/handler/socks/v5/udp.go index 27a96af..e2821e8 100644 --- a/pkg/handler/socks/v5/udp.go +++ b/pkg/handler/socks/v5/udp.go @@ -63,7 +63,7 @@ func (h *socks5Handler) handleUDP(ctx context.Context, conn net.Conn, req *gosoc defer peer.Close() go h.relayUDP( - socks.NewUDPConn(relay, h.md.udpBufferSize), + socks.UDPConn(relay, h.md.udpBufferSize), peer, ) } else { @@ -75,8 +75,8 @@ func (h *socks5Handler) handleUDP(ctx context.Context, conn net.Conn, req *gosoc defer tun.Close() go h.tunnelClientUDP( - socks.NewUDPConn(relay, h.md.udpBufferSize), - socks.UDPTunClientConn(tun, nil), + socks.UDPConn(relay, h.md.udpBufferSize), + socks.UDPTunClientPacketConn(tun), ) } diff --git a/pkg/listener/rudp/listener.go b/pkg/listener/rudp/listener.go index 6105fc9..6995a03 100644 --- a/pkg/listener/rudp/listener.go +++ b/pkg/listener/rudp/listener.go @@ -193,7 +193,7 @@ func (l *rudpListener) initUDPTunnel(conn net.Conn) (net.PacketConn, error) { } l.logger.Debugf("bind on %s OK", baddr) - return socks.UDPTunClientConn(conn, nil), nil + return socks.UDPTunClientPacketConn(conn), nil } func (l *rudpListener) getConn(conn net.PacketConn, raddr net.Addr) *udp.Conn {