update socks5 conn

This commit is contained in:
ginuerzh 2021-11-16 18:58:47 +08:00
parent 83dacf67d5
commit 1a1ee384b7
4 changed files with 34 additions and 36 deletions

View File

@ -8,33 +8,31 @@ import (
"github.com/go-gost/gost/pkg/common/bufpool" "github.com/go-gost/gost/pkg/common/bufpool"
) )
var ( type udpTunConn struct {
_ net.PacketConn = (*UDPTunConn)(nil)
_ net.Conn = (*UDPTunConn)(nil)
_ net.PacketConn = (*UDPConn)(nil)
_ net.Conn = (*UDPConn)(nil)
)
type UDPTunConn struct {
net.Conn net.Conn
taddr net.Addr taddr net.Addr
} }
func UDPTunClientConn(c net.Conn, targetAddr net.Addr) *UDPTunConn { func UDPTunClientConn(c net.Conn, targetAddr net.Addr) net.Conn {
return &UDPTunConn{ return &udpTunConn{
Conn: c, Conn: c,
taddr: targetAddr, taddr: targetAddr,
} }
} }
func UDPTunServerConn(c net.Conn) *UDPTunConn { func UDPTunClientPacketConn(c net.Conn) net.PacketConn {
return &UDPTunConn{ return &udpTunConn{
Conn: c, Conn: c,
} }
} }
func (c *UDPTunConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { func UDPTunServerConn(c net.Conn) net.PacketConn {
return &udpTunConn{
Conn: c,
}
}
func (c *udpTunConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
socksAddr := gosocks5.Addr{} socksAddr := gosocks5.Addr{}
header := gosocks5.UDPHeader{ header := gosocks5.UDPHeader{
Addr: &socksAddr, Addr: &socksAddr,
@ -57,12 +55,12 @@ func (c *UDPTunConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
return return
} }
func (c *UDPTunConn) Read(b []byte) (n int, err error) { func (c *udpTunConn) Read(b []byte) (n int, err error) {
n, _, err = c.ReadFrom(b) n, _, err = c.ReadFrom(b)
return return
} }
func (c *UDPTunConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { func (c *udpTunConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
socksAddr := gosocks5.Addr{} socksAddr := gosocks5.Addr{}
if err = socksAddr.ParseFrom(addr.String()); err != nil { if err = socksAddr.ParseFrom(addr.String()); err != nil {
return return
@ -83,7 +81,7 @@ func (c *UDPTunConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
return return
} }
func (c *UDPTunConn) Write(b []byte) (n int, err error) { func (c *udpTunConn) Write(b []byte) (n int, err error) {
return c.WriteTo(b, c.taddr) return c.WriteTo(b, c.taddr)
} }
@ -91,15 +89,15 @@ var (
DefaultBufferSize = 4096 DefaultBufferSize = 4096
) )
type UDPConn struct { type udpConn struct {
net.PacketConn net.PacketConn
raddr net.Addr raddr net.Addr
taddr net.Addr taddr net.Addr
bufferSize int bufferSize int
} }
func NewUDPConn(c net.PacketConn, bufferSize int) *UDPConn { func UDPConn(c net.PacketConn, bufferSize int) net.PacketConn {
return &UDPConn{ return &udpConn{
PacketConn: c, PacketConn: c,
bufferSize: bufferSize, bufferSize: bufferSize,
} }
@ -108,7 +106,7 @@ func NewUDPConn(c net.PacketConn, bufferSize int) *UDPConn {
// ReadFrom reads an UDP datagram. // ReadFrom reads an UDP datagram.
// NOTE: for server side, // NOTE: for server side,
// the returned addr is the target address the client want to relay to. // the returned addr is the target address the client want to relay to.
func (c *UDPConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { func (c *udpConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
rbuf := bufpool.Get(c.bufferSize) rbuf := bufpool.Get(c.bufferSize)
defer bufpool.Put(rbuf) defer bufpool.Put(rbuf)
@ -131,12 +129,12 @@ func (c *UDPConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
return return
} }
func (c *UDPConn) Read(b []byte) (n int, err error) { func (c *udpConn) Read(b []byte) (n int, err error) {
n, _, err = c.ReadFrom(b) n, _, err = c.ReadFrom(b)
return return
} }
func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { func (c *udpConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
wbuf := bufpool.Get(c.bufferSize) wbuf := bufpool.Get(c.bufferSize)
defer bufpool.Put(wbuf) defer bufpool.Put(wbuf)
@ -165,10 +163,10 @@ func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
return return
} }
func (c *UDPConn) Write(b []byte) (n int, err error) { func (c *udpConn) Write(b []byte) (n int, err error) {
return c.WriteTo(b, c.taddr) return c.WriteTo(b, c.taddr)
} }
func (c *UDPConn) RemoteAddr() net.Addr { func (c *udpConn) RemoteAddr() net.Addr {
return c.raddr return c.raddr
} }

View File

@ -53,13 +53,7 @@ func (c *ssConnector) Connect(ctx context.Context, conn net.Conn, network, addre
switch network { switch network {
case "tcp", "tcp4", "tcp6": case "tcp", "tcp4", "tcp6":
case "udp", "udp4", "udp6": case "udp", "udp4", "udp6":
if c.md.enableUDP {
return c.connectUDP(ctx, conn, network, address) return c.connectUDP(ctx, conn, network, address)
} else {
err := errors.New("UDP relay is disabled")
c.logger.Error(err)
return nil, err
}
default: default:
err := fmt.Errorf("network %s unsupported", network) err := fmt.Errorf("network %s unsupported", network)
c.logger.Error(err) c.logger.Error(err)
@ -105,6 +99,12 @@ func (c *ssConnector) Connect(ctx context.Context, conn net.Conn, network, addre
} }
func (c *ssConnector) connectUDP(ctx context.Context, conn net.Conn, network, address string) (net.Conn, error) { func (c *ssConnector) connectUDP(ctx context.Context, conn net.Conn, network, address string) (net.Conn, error) {
if c.md.enableUDP {
err := errors.New("UDP relay is disabled")
c.logger.Error(err)
return nil, err
}
if c.md.connectTimeout > 0 { if c.md.connectTimeout > 0 {
conn.SetDeadline(time.Now().Add(c.md.connectTimeout)) conn.SetDeadline(time.Now().Add(c.md.connectTimeout))
defer conn.SetDeadline(time.Time{}) defer conn.SetDeadline(time.Time{})

View File

@ -63,7 +63,7 @@ func (h *socks5Handler) handleUDP(ctx context.Context, conn net.Conn, req *gosoc
defer peer.Close() defer peer.Close()
go h.relayUDP( go h.relayUDP(
socks.NewUDPConn(relay, h.md.udpBufferSize), socks.UDPConn(relay, h.md.udpBufferSize),
peer, peer,
) )
} else { } else {
@ -75,8 +75,8 @@ func (h *socks5Handler) handleUDP(ctx context.Context, conn net.Conn, req *gosoc
defer tun.Close() defer tun.Close()
go h.tunnelClientUDP( go h.tunnelClientUDP(
socks.NewUDPConn(relay, h.md.udpBufferSize), socks.UDPConn(relay, h.md.udpBufferSize),
socks.UDPTunClientConn(tun, nil), socks.UDPTunClientPacketConn(tun),
) )
} }

View File

@ -193,7 +193,7 @@ func (l *rudpListener) initUDPTunnel(conn net.Conn) (net.PacketConn, error) {
} }
l.logger.Debugf("bind on %s OK", baddr) l.logger.Debugf("bind on %s OK", baddr)
return socks.UDPTunClientConn(conn, nil), nil return socks.UDPTunClientPacketConn(conn), nil
} }
func (l *rudpListener) getConn(conn net.PacketConn, raddr net.Addr) *udp.Conn { func (l *rudpListener) getConn(conn net.PacketConn, raddr net.Addr) *udp.Conn {