add ssu connector

This commit is contained in:
ginuerzh 2021-11-09 23:34:19 +08:00
parent 92dc87830f
commit cae199dbd9
29 changed files with 1031 additions and 678 deletions

View File

@ -45,7 +45,7 @@ services:
# bypass: bypass01 # bypass: bypass01
- name: ssu - name: ssu
url: "ss://chacha20:gost@:8000" url: "ss://chacha20:gost@:8000"
addr: ":8338" addr: ":8388"
handler: handler:
type: ssu type: ssu
metadata: metadata:
@ -54,7 +54,8 @@ services:
readTimeout: 5s readTimeout: 5s
retry: 3 retry: 3
listener: listener:
type: udp type: tcp
# chain: chain-ssu
- name: socks5+tcp - name: socks5+tcp
url: "socks5://gost:gost@:1080" url: "socks5://gost:gost@:1080"
addr: ":1080" addr: ":1080"
@ -71,7 +72,7 @@ services:
type: tcp type: tcp
metadata: metadata:
keepAlive: 15s keepAlive: 15s
# chain: chain-socks5 chain: chain-socks5
# bypass: bypass01 # bypass: bypass01
- name: socks5+tcp - name: socks5+tcp
url: "socks5://gost:gost@:1080" url: "socks5://gost:gost@:1080"
@ -178,6 +179,20 @@ chains:
dialer: dialer:
type: tcp type: tcp
metadata: {} metadata: {}
- name: chain-ssu
hops:
- name: hop01
nodes:
- name: node01
addr: ":8339"
url: "http://gost:gost@:8081"
# bypass: bypass01
connector:
type: ssu
metadata: {}
dialer:
type: udp
metadata: {}
bypasses: bypasses:
- name: bypass01 - name: bypass01

View File

@ -6,9 +6,11 @@ import (
_ "github.com/go-gost/gost/pkg/connector/socks/v4" _ "github.com/go-gost/gost/pkg/connector/socks/v4"
_ "github.com/go-gost/gost/pkg/connector/socks/v5" _ "github.com/go-gost/gost/pkg/connector/socks/v5"
_ "github.com/go-gost/gost/pkg/connector/ss" _ "github.com/go-gost/gost/pkg/connector/ss"
_ "github.com/go-gost/gost/pkg/connector/ssu"
// Register dialers // Register dialers
_ "github.com/go-gost/gost/pkg/dialer/tcp" _ "github.com/go-gost/gost/pkg/dialer/tcp"
_ "github.com/go-gost/gost/pkg/dialer/udp"
// Register handlers // Register handlers
_ "github.com/go-gost/gost/pkg/handler/http" _ "github.com/go-gost/gost/pkg/handler/http"

2
go.mod
View File

@ -7,7 +7,7 @@ require (
github.com/coreos/go-iptables v0.5.0 // indirect github.com/coreos/go-iptables v0.5.0 // indirect
github.com/ginuerzh/tls-dissector v0.0.2-0.20201202075250-98fa925912da github.com/ginuerzh/tls-dissector v0.0.2-0.20201202075250-98fa925912da
github.com/go-gost/gosocks4 v0.0.1 github.com/go-gost/gosocks4 v0.0.1
github.com/go-gost/gosocks5 v0.3.1-0.20211108125245-019dfd6b3aea github.com/go-gost/gosocks5 v0.3.1-0.20211109033403-d894d75b7f09
github.com/gobwas/glob v0.2.3 github.com/gobwas/glob v0.2.3
github.com/golang/snappy v0.0.3 github.com/golang/snappy v0.0.3
github.com/google/gopacket v1.1.19 // indirect github.com/google/gopacket v1.1.19 // indirect

2
go.sum
View File

@ -125,6 +125,8 @@ github.com/go-gost/gosocks5 v0.3.1-0.20211108032632-bbfd2de9a32d h1:mjoFToMUWNN0
github.com/go-gost/gosocks5 v0.3.1-0.20211108032632-bbfd2de9a32d/go.mod h1:1G6I7HP7VFVxveGkoK8mnprnJqSqJjdcASKsdUn4Pp4= github.com/go-gost/gosocks5 v0.3.1-0.20211108032632-bbfd2de9a32d/go.mod h1:1G6I7HP7VFVxveGkoK8mnprnJqSqJjdcASKsdUn4Pp4=
github.com/go-gost/gosocks5 v0.3.1-0.20211108125245-019dfd6b3aea h1:mrm6bMpdxBvInvBuDbUaAQWV60r/PaByLIG9fQJEEIc= github.com/go-gost/gosocks5 v0.3.1-0.20211108125245-019dfd6b3aea h1:mrm6bMpdxBvInvBuDbUaAQWV60r/PaByLIG9fQJEEIc=
github.com/go-gost/gosocks5 v0.3.1-0.20211108125245-019dfd6b3aea/go.mod h1:1G6I7HP7VFVxveGkoK8mnprnJqSqJjdcASKsdUn4Pp4= github.com/go-gost/gosocks5 v0.3.1-0.20211108125245-019dfd6b3aea/go.mod h1:1G6I7HP7VFVxveGkoK8mnprnJqSqJjdcASKsdUn4Pp4=
github.com/go-gost/gosocks5 v0.3.1-0.20211109033403-d894d75b7f09 h1:A95M6UWcfZgOuJkQ7QLfG0Hs5peWIUSysCDNz4pfe04=
github.com/go-gost/gosocks5 v0.3.1-0.20211109033403-d894d75b7f09/go.mod h1:1G6I7HP7VFVxveGkoK8mnprnJqSqJjdcASKsdUn4Pp4=
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 h1:p104kn46Q8WdvHunIJ9dAyjPVtrBPhSr3KT2yUst43I= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 h1:p104kn46Q8WdvHunIJ9dAyjPVtrBPhSr3KT2yUst43I=
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE=
github.com/gobwas/glob v0.2.3 h1:A4xDbljILXROh+kObIiy5kIaPYD8e96x1tgBhUI5J+Y= github.com/gobwas/glob v0.2.3 h1:A4xDbljILXROh+kObIiy5kIaPYD8e96x1tgBhUI5J+Y=

View File

@ -43,6 +43,21 @@ func (c *httpConnector) Init(md md.Metadata) (err error) {
} }
func (c *httpConnector) Connect(ctx context.Context, conn net.Conn, network, address string, opts ...connector.ConnectOption) (net.Conn, error) { func (c *httpConnector) Connect(ctx context.Context, conn net.Conn, network, address string, opts ...connector.ConnectOption) (net.Conn, error) {
c.logger = c.logger.WithFields(map[string]interface{}{
"local": conn.LocalAddr().String(),
"remote": conn.RemoteAddr().String(),
"network": network,
"address": address,
})
switch network {
case "tcp", "tcp4", "tcp6":
default:
err := fmt.Errorf("network %s unsupported, should be tcp, tcp4 or tcp6", network)
c.logger.Error(err)
return nil, err
}
req := &http.Request{ req := &http.Request{
Method: http.MethodConnect, Method: http.MethodConnect,
URL: &url.URL{Host: address}, URL: &url.URL{Host: address},
@ -56,11 +71,6 @@ func (c *httpConnector) Connect(ctx context.Context, conn net.Conn, network, add
} }
req.Header.Set("Proxy-Connection", "keep-alive") req.Header.Set("Proxy-Connection", "keep-alive")
c.logger = c.logger.WithFields(map[string]interface{}{
"local": conn.LocalAddr().String(),
"remote": conn.RemoteAddr().String(),
"target": address,
})
c.logger.Infof("connect: ", address) c.logger.Infof("connect: ", address)
if user := c.md.User; user != nil { if user := c.md.User; user != nil {

View File

@ -42,10 +42,20 @@ func (c *socks4Connector) Init(md md.Metadata) (err error) {
func (c *socks4Connector) Connect(ctx context.Context, conn net.Conn, network, address string, opts ...connector.ConnectOption) (net.Conn, error) { func (c *socks4Connector) Connect(ctx context.Context, conn net.Conn, network, address string, opts ...connector.ConnectOption) (net.Conn, error) {
c.logger = c.logger.WithFields(map[string]interface{}{ c.logger = c.logger.WithFields(map[string]interface{}{
"remote": conn.RemoteAddr().String(), "remote": conn.RemoteAddr().String(),
"local": conn.LocalAddr().String(), "local": conn.LocalAddr().String(),
"target": address, "network": network,
"address": address,
}) })
switch network {
case "tcp", "tcp4", "tcp6":
default:
err := fmt.Errorf("network %s unsupported, should be tcp, tcp4 or tcp6", network)
c.logger.Error(err)
return nil, err
}
c.logger.Info("connect: ", address) c.logger.Info("connect: ", address)
var addr *gosocks4.Addr var addr *gosocks4.Addr
@ -87,19 +97,14 @@ func (c *socks4Connector) Connect(ctx context.Context, conn net.Conn, network, a
c.logger.Error(err) c.logger.Error(err)
return nil, err return nil, err
} }
if c.logger.IsLevelEnabled(logger.DebugLevel) { c.logger.Debug(req)
c.logger.Debug(req)
}
reply, err := gosocks4.ReadReply(conn) reply, err := gosocks4.ReadReply(conn)
if err != nil { if err != nil {
c.logger.Error(err) c.logger.Error(err)
return nil, err return nil, err
} }
c.logger.Debug(reply)
if c.logger.IsLevelEnabled(logger.DebugLevel) {
c.logger.Debug(reply)
}
if reply.Code != gosocks4.Granted { if reply.Code != gosocks4.Granted {
return nil, fmt.Errorf("error: %d", reply.Code) return nil, fmt.Errorf("error: %d", reply.Code)

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt"
"net" "net"
"net/url" "net/url"
"strings" "strings"
@ -79,6 +80,7 @@ func (c *socks5Connector) Handshake(ctx context.Context, conn net.Conn) (net.Con
cc := gosocks5.ClientConn(conn, c.selector) cc := gosocks5.ClientConn(conn, c.selector)
if err := cc.Handleshake(); err != nil { if err := cc.Handleshake(); err != nil {
c.logger.Error(err)
return nil, err return nil, err
} }
@ -87,12 +89,22 @@ func (c *socks5Connector) Handshake(ctx context.Context, conn net.Conn) (net.Con
func (c *socks5Connector) Connect(ctx context.Context, conn net.Conn, network, address string, opts ...connector.ConnectOption) (net.Conn, error) { func (c *socks5Connector) Connect(ctx context.Context, conn net.Conn, network, address string, opts ...connector.ConnectOption) (net.Conn, error) {
c.logger = c.logger.WithFields(map[string]interface{}{ c.logger = c.logger.WithFields(map[string]interface{}{
"target": address, "network": network,
"address": address,
}) })
switch network {
case "tcp", "tcp4", "tcp6":
default:
err := fmt.Errorf("network %s unsupported, should be tcp, tcp4 or tcp6", network)
c.logger.Error(err)
return nil, err
}
c.logger.Info("connect: ", address) c.logger.Info("connect: ", address)
addr, err := gosocks5.NewAddr(address) addr := gosocks5.Addr{}
if err != nil { if err := addr.ParseFrom(address); err != nil {
c.logger.Error(err) c.logger.Error(err)
return nil, err return nil, err
} }
@ -102,25 +114,19 @@ func (c *socks5Connector) Connect(ctx context.Context, conn net.Conn, network, a
defer conn.SetDeadline(time.Time{}) defer conn.SetDeadline(time.Time{})
} }
req := gosocks5.NewRequest(gosocks5.CmdConnect, addr) req := gosocks5.NewRequest(gosocks5.CmdConnect, &addr)
if err := req.Write(conn); err != nil { if err := req.Write(conn); err != nil {
c.logger.Error(err) c.logger.Error(err)
return nil, err return nil, err
} }
c.logger.Debug(req)
if c.logger.IsLevelEnabled(logger.DebugLevel) {
c.logger.Debug(req)
}
reply, err := gosocks5.ReadReply(conn) reply, err := gosocks5.ReadReply(conn)
if err != nil { if err != nil {
c.logger.Error(err) c.logger.Error(err)
return nil, err return nil, err
} }
c.logger.Debug(reply)
if c.logger.IsLevelEnabled(logger.DebugLevel) {
c.logger.Debug(reply)
}
if reply.Rep != gosocks5.Succeeded { if reply.Rep != gosocks5.Succeeded {
return nil, errors.New("service unavailable") return nil, errors.New("service unavailable")

View File

@ -18,9 +18,7 @@ type clientSelector struct {
} }
func (s *clientSelector) Methods() []uint8 { func (s *clientSelector) Methods() []uint8 {
if s.logger.IsLevelEnabled(logger.DebugLevel) { s.logger.Debug("methods: ", s.methods)
s.logger.Debug("methods: ", s.methods)
}
return s.methods return s.methods
} }
@ -33,9 +31,7 @@ func (s *clientSelector) Select(methods ...uint8) (method uint8) {
} }
func (s *clientSelector) OnSelected(method uint8, conn net.Conn) (net.Conn, error) { func (s *clientSelector) OnSelected(method uint8, conn net.Conn) (net.Conn, error) {
if s.logger.IsLevelEnabled(logger.DebugLevel) { s.logger.Debug("method selected: ", method)
s.logger.Debug("method selected: ", method)
}
switch method { switch method {
case socks.MethodTLS: case socks.MethodTLS:
@ -57,18 +53,14 @@ func (s *clientSelector) OnSelected(method uint8, conn net.Conn) (net.Conn, erro
s.logger.Error(err) s.logger.Error(err)
return nil, err return nil, err
} }
if s.logger.IsLevelEnabled(logger.DebugLevel) { s.logger.Debug(req)
s.logger.Debug(req)
}
resp, err := gosocks5.ReadUserPassResponse(conn) resp, err := gosocks5.ReadUserPassResponse(conn)
if err != nil { if err != nil {
s.logger.Error(err) s.logger.Error(err)
return nil, err return nil, err
} }
if s.logger.IsLevelEnabled(logger.DebugLevel) { s.logger.Debug(resp)
s.logger.Debug(resp)
}
if resp.Status != gosocks5.Succeeded { if resp.Status != gosocks5.Succeeded {
return nil, gosocks5.ErrAuthFailure return nil, gosocks5.ErrAuthFailure

View File

@ -2,6 +2,7 @@ package ss
import ( import (
"context" "context"
"fmt"
"net" "net"
"time" "time"
@ -40,21 +41,30 @@ func (c *ssConnector) Init(md md.Metadata) (err error) {
func (c *ssConnector) Connect(ctx context.Context, conn net.Conn, network, address string, opts ...connector.ConnectOption) (net.Conn, error) { func (c *ssConnector) Connect(ctx context.Context, conn net.Conn, network, address string, opts ...connector.ConnectOption) (net.Conn, error) {
c.logger = c.logger.WithFields(map[string]interface{}{ c.logger = c.logger.WithFields(map[string]interface{}{
"remote": conn.RemoteAddr().String(), "remote": conn.RemoteAddr().String(),
"local": conn.LocalAddr().String(), "local": conn.LocalAddr().String(),
"target": address, "network": network,
"address": address,
}) })
switch network {
case "tcp", "tcp4", "tcp6":
default:
err := fmt.Errorf("network %s unsupported, should be tcp, tcp4 or tcp6", network)
c.logger.Error(err)
return nil, err
}
c.logger.Infof("connect: ", address) c.logger.Infof("connect: ", address)
socksAddr, err := gosocks5.NewAddr(address) addr := gosocks5.Addr{}
if err != nil { if err := addr.ParseFrom(address); err != nil {
c.logger.Error("parse addr: ", err) c.logger.Error(err)
return nil, err return nil, err
} }
rawaddr := bufpool.Get(512) rawaddr := bufpool.Get(512)
defer bufpool.Put(rawaddr) defer bufpool.Put(rawaddr)
n, err := socksAddr.Encode(rawaddr) n, err := addr.Encode(rawaddr)
if err != nil { if err != nil {
c.logger.Error("encoding addr: ", err) c.logger.Error("encoding addr: ", err)
return nil, err return nil, err

View File

@ -0,0 +1,105 @@
package ssu
import (
"context"
"fmt"
"net"
"time"
"github.com/go-gost/gost/pkg/connector"
"github.com/go-gost/gost/pkg/internal/utils/socks"
"github.com/go-gost/gost/pkg/internal/utils/ss"
"github.com/go-gost/gost/pkg/logger"
md "github.com/go-gost/gost/pkg/metadata"
"github.com/go-gost/gost/pkg/registry"
)
func init() {
registry.RegiserConnector("ssu", NewConnector)
}
type ssuConnector struct {
md metadata
logger logger.Logger
}
func NewConnector(opts ...connector.Option) connector.Connector {
options := &connector.Options{}
for _, opt := range opts {
opt(options)
}
return &ssuConnector{
logger: options.Logger,
}
}
func (c *ssuConnector) Init(md md.Metadata) (err error) {
return c.parseMetadata(md)
}
func (c *ssuConnector) Connect(ctx context.Context, conn net.Conn, network, address string, opts ...connector.ConnectOption) (net.Conn, error) {
c.logger = c.logger.WithFields(map[string]interface{}{
"remote": conn.RemoteAddr().String(),
"local": conn.LocalAddr().String(),
"network": network,
"address": address,
})
switch network {
case "udp", "udp4", "udp6":
default:
err := fmt.Errorf("network %s unsupported, should be udp, udp4 or udp6", network)
c.logger.Error(err)
return nil, err
}
c.logger.Info("connect: ", address)
if c.md.connectTimeout > 0 {
conn.SetDeadline(time.Now().Add(c.md.connectTimeout))
defer conn.SetDeadline(time.Time{})
}
taddr, _ := net.ResolveUDPAddr(network, address)
if taddr == nil {
taddr = &net.UDPAddr{}
}
pc, ok := conn.(net.PacketConn)
if ok {
if c.md.cipher != nil {
pc = c.md.cipher.PacketConn(pc)
}
return ss.UDPClientConn(pc, conn.RemoteAddr(), taddr, c.md.bufferSize), nil
}
return socks.UDPTunClientConn(conn, taddr), nil
}
func (c *ssuConnector) parseMetadata(md md.Metadata) (err error) {
c.md.cipher, err = ss.ShadowCipher(
md.GetString(method),
md.GetString(password),
md.GetString(key),
)
if err != nil {
return
}
c.md.connectTimeout = md.GetDuration(connectTimeout)
c.md.bufferSize = md.GetInt(bufferSize)
if c.md.bufferSize > 0 {
if c.md.bufferSize < 512 {
c.md.bufferSize = 512
}
if c.md.bufferSize > 65*1024 {
c.md.bufferSize = 65 * 1024
}
} else {
c.md.bufferSize = 4096
}
return
}

View File

@ -0,0 +1,21 @@
package ssu
import (
"time"
"github.com/shadowsocks/go-shadowsocks2/core"
)
const (
method = "method"
password = "password"
key = "key"
connectTimeout = "timeout"
bufferSize = "bufferSize"
)
type metadata struct {
cipher core.Cipher
connectTimeout time.Duration
bufferSize int
}

View File

@ -46,12 +46,10 @@ func (d *tcpDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOp
if err != nil { if err != nil {
d.logger.Error(err) d.logger.Error(err)
} else { } else {
if d.logger.IsLevelEnabled(logger.DebugLevel) { d.logger.WithFields(map[string]interface{}{
d.logger.WithFields(map[string]interface{}{ "src": conn.LocalAddr().String(),
"src": conn.LocalAddr().String(), "dst": addr,
"dst": addr, }).Debug("dial with dial func")
}).Debug("dial with dial func")
}
} }
return conn, err return conn, err
} }
@ -61,12 +59,10 @@ func (d *tcpDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOp
if err != nil { if err != nil {
d.logger.Error(err) d.logger.Error(err)
} else { } else {
if d.logger.IsLevelEnabled(logger.DebugLevel) { d.logger.WithFields(map[string]interface{}{
d.logger.WithFields(map[string]interface{}{ "src": conn.LocalAddr().String(),
"src": conn.LocalAddr().String(), "dst": addr,
"dst": addr, }).Debug("dial direct")
}).Debug("dial direct")
}
} }
return conn, err return conn, err
} }

17
pkg/dialer/udp/conn.go Normal file
View File

@ -0,0 +1,17 @@
package udp
import "net"
type conn struct {
*net.UDPConn
}
func (c *conn) WriteTo(b []byte, addr net.Addr) (int, error) {
return c.UDPConn.Write(b)
}
func (c *conn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
n, err = c.UDPConn.Read(b)
addr = c.RemoteAddr()
return
}

54
pkg/dialer/udp/dialer.go Normal file
View File

@ -0,0 +1,54 @@
package udp
import (
"context"
"net"
"github.com/go-gost/gost/pkg/dialer"
"github.com/go-gost/gost/pkg/logger"
md "github.com/go-gost/gost/pkg/metadata"
"github.com/go-gost/gost/pkg/registry"
)
func init() {
registry.RegisterDialer("udp", NewDialer)
}
type udpDialer struct {
md metadata
logger logger.Logger
}
func NewDialer(opts ...dialer.Option) dialer.Dialer {
options := &dialer.Options{}
for _, opt := range opts {
opt(options)
}
return &udpDialer{
logger: options.Logger,
}
}
func (d *udpDialer) Init(md md.Metadata) (err error) {
return d.parseMetadata(md)
}
func (d *udpDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOption) (net.Conn, error) {
taddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}
c, err := net.DialUDP("udp", nil, taddr)
if err != nil {
return nil, err
}
return &conn{
UDPConn: c,
}, nil
}
func (d *udpDialer) parseMetadata(md md.Metadata) (err error) {
return
}

View File

@ -0,0 +1,15 @@
package udp
import "time"
const (
dialTimeout = "dialTimeout"
)
const (
defaultDialTimeout = 5 * time.Second
)
type metadata struct {
dialTimeout time.Duration
}

View File

@ -70,19 +70,15 @@ func (h *socks4Handler) Handle(ctx context.Context, conn net.Conn) {
h.logger.Error(err) h.logger.Error(err)
return return
} }
conn.SetReadDeadline(time.Time{}) h.logger.Debug(req)
if h.logger.IsLevelEnabled(logger.DebugLevel) { conn.SetReadDeadline(time.Time{})
h.logger.Debug(req)
}
if h.md.authenticator != nil && if h.md.authenticator != nil &&
!h.md.authenticator.Authenticate(string(req.Userid), "") { !h.md.authenticator.Authenticate(string(req.Userid), "") {
resp := gosocks4.NewReply(gosocks4.RejectedUserid, nil) resp := gosocks4.NewReply(gosocks4.RejectedUserid, nil)
resp.Write(conn) resp.Write(conn)
if h.logger.IsLevelEnabled(logger.DebugLevel) { h.logger.Debug(resp)
h.logger.Debug(resp)
}
return return
} }
@ -107,9 +103,7 @@ func (h *socks4Handler) handleConnect(ctx context.Context, conn net.Conn, req *g
if h.bypass != nil && h.bypass.Contains(addr) { if h.bypass != nil && h.bypass.Contains(addr) {
resp := gosocks4.NewReply(gosocks4.Rejected, nil) resp := gosocks4.NewReply(gosocks4.Rejected, nil)
resp.Write(conn) resp.Write(conn)
if h.logger.IsLevelEnabled(logger.DebugLevel) { h.logger.Debug(resp)
h.logger.Debug(resp)
}
h.logger.Info("bypass: ", addr) h.logger.Info("bypass: ", addr)
return return
} }
@ -122,9 +116,7 @@ func (h *socks4Handler) handleConnect(ctx context.Context, conn net.Conn, req *g
if err != nil { if err != nil {
resp := gosocks4.NewReply(gosocks4.Failed, nil) resp := gosocks4.NewReply(gosocks4.Failed, nil)
resp.Write(conn) resp.Write(conn)
if h.logger.IsLevelEnabled(logger.DebugLevel) { h.logger.Debug(resp)
h.logger.Debug(resp)
}
return return
} }
@ -135,9 +127,7 @@ func (h *socks4Handler) handleConnect(ctx context.Context, conn net.Conn, req *g
h.logger.Error(err) h.logger.Error(err)
return return
} }
if h.logger.IsLevelEnabled(logger.DebugLevel) { h.logger.Debug(resp)
h.logger.Debug(resp)
}
t := time.Now() t := time.Now()
h.logger.Infof("%s <-> %s", conn.RemoteAddr(), addr) h.logger.Infof("%s <-> %s", conn.RemoteAddr(), addr)

View File

@ -7,7 +7,6 @@ import (
"github.com/go-gost/gosocks5" "github.com/go-gost/gosocks5"
"github.com/go-gost/gost/pkg/handler" "github.com/go-gost/gost/pkg/handler"
"github.com/go-gost/gost/pkg/logger"
) )
func (h *socks5Handler) handleBind(ctx context.Context, conn net.Conn, req *gosocks5.Request) { func (h *socks5Handler) handleBind(ctx context.Context, conn net.Conn, req *gosocks5.Request) {
@ -33,9 +32,7 @@ func (h *socks5Handler) handleBind(ctx context.Context, conn net.Conn, req *goso
if err != nil { if err != nil {
resp := gosocks5.NewReply(gosocks5.Failure, nil) resp := gosocks5.NewReply(gosocks5.Failure, nil)
resp.Write(conn) resp.Write(conn)
if h.logger.IsLevelEnabled(logger.DebugLevel) { h.logger.Debug(resp)
h.logger.Debug(resp)
}
return return
} }
defer cc.Close() defer cc.Close()
@ -45,9 +42,7 @@ func (h *socks5Handler) handleBind(ctx context.Context, conn net.Conn, req *goso
h.logger.Error(err) h.logger.Error(err)
resp := gosocks5.NewReply(gosocks5.NetUnreachable, nil) resp := gosocks5.NewReply(gosocks5.NetUnreachable, nil)
resp.Write(conn) resp.Write(conn)
if h.logger.IsLevelEnabled(logger.DebugLevel) { h.logger.Debug(resp)
h.logger.Debug(resp)
}
return return
} }
@ -65,32 +60,25 @@ func (h *socks5Handler) bindLocal(ctx context.Context, conn net.Conn, addr strin
if err := reply.Write(conn); err != nil { if err := reply.Write(conn); err != nil {
h.logger.Error(err) h.logger.Error(err)
} }
if h.logger.IsLevelEnabled(logger.DebugLevel) { h.logger.Debug(reply)
h.logger.Debug(reply.String())
}
return return
} }
socksAddr, err := gosocks5.NewAddr(ln.Addr().String()) socksAddr := gosocks5.Addr{}
if err != nil { if err := socksAddr.ParseFrom(ln.Addr().String()); err != nil {
h.logger.Warn(err) h.logger.Warn(err)
socksAddr = &gosocks5.Addr{
Type: gosocks5.AddrIPv4,
}
} }
// Issue: may not reachable when host has multi-interface // Issue: may not reachable when host has multi-interface
socksAddr.Host, _, _ = net.SplitHostPort(conn.LocalAddr().String()) socksAddr.Host, _, _ = net.SplitHostPort(conn.LocalAddr().String())
socksAddr.Type = 0 socksAddr.Type = 0
reply := gosocks5.NewReply(gosocks5.Succeeded, socksAddr) reply := gosocks5.NewReply(gosocks5.Succeeded, &socksAddr)
if err := reply.Write(conn); err != nil { if err := reply.Write(conn); err != nil {
h.logger.Error(err) h.logger.Error(err)
ln.Close() ln.Close()
return return
} }
if h.logger.IsLevelEnabled(logger.DebugLevel) { h.logger.Debug(reply)
h.logger.Debug(reply.String())
}
h.logger = h.logger.WithFields(map[string]interface{}{ h.logger = h.logger.WithFields(map[string]interface{}{
"bind": socksAddr.String(), "bind": socksAddr.String(),
@ -143,14 +131,13 @@ func (h *socks5Handler) serveBind(ctx context.Context, conn net.Conn, ln net.Lis
} }
defer rc.Close() defer rc.Close()
raddr, _ := gosocks5.NewAddr(rc.RemoteAddr().String()) raddr := gosocks5.Addr{}
reply := gosocks5.NewReply(gosocks5.Succeeded, raddr) raddr.ParseFrom(rc.RemoteAddr().String())
reply := gosocks5.NewReply(gosocks5.Succeeded, &raddr)
if err := reply.Write(pc2); err != nil { if err := reply.Write(pc2); err != nil {
h.logger.Error(err) h.logger.Error(err)
} }
if h.logger.IsLevelEnabled(logger.DebugLevel) { h.logger.Debug(reply)
h.logger.Debug(reply.String())
}
h.logger.Infof("peer accepted: %s", raddr.String()) h.logger.Infof("peer accepted: %s", raddr.String())
start := time.Now() start := time.Now()

View File

@ -7,7 +7,6 @@ import (
"github.com/go-gost/gosocks5" "github.com/go-gost/gosocks5"
"github.com/go-gost/gost/pkg/handler" "github.com/go-gost/gost/pkg/handler"
"github.com/go-gost/gost/pkg/logger"
) )
func (h *socks5Handler) handleConnect(ctx context.Context, conn net.Conn, addr string) { func (h *socks5Handler) handleConnect(ctx context.Context, conn net.Conn, addr string) {
@ -20,9 +19,7 @@ func (h *socks5Handler) handleConnect(ctx context.Context, conn net.Conn, addr s
if h.bypass != nil && h.bypass.Contains(addr) { if h.bypass != nil && h.bypass.Contains(addr) {
resp := gosocks5.NewReply(gosocks5.NotAllowed, nil) resp := gosocks5.NewReply(gosocks5.NotAllowed, nil)
resp.Write(conn) resp.Write(conn)
if h.logger.IsLevelEnabled(logger.DebugLevel) { h.logger.Debug(resp)
h.logger.Debug(resp)
}
h.logger.Info("bypass: ", addr) h.logger.Info("bypass: ", addr)
return return
} }
@ -35,9 +32,7 @@ func (h *socks5Handler) handleConnect(ctx context.Context, conn net.Conn, addr s
if err != nil { if err != nil {
resp := gosocks5.NewReply(gosocks5.NetUnreachable, nil) resp := gosocks5.NewReply(gosocks5.NetUnreachable, nil)
resp.Write(conn) resp.Write(conn)
if h.logger.IsLevelEnabled(logger.DebugLevel) { h.logger.Debug(resp)
h.logger.Debug(resp)
}
return return
} }
@ -48,9 +43,7 @@ func (h *socks5Handler) handleConnect(ctx context.Context, conn net.Conn, addr s
h.logger.Error(err) h.logger.Error(err)
return return
} }
if h.logger.IsLevelEnabled(logger.DebugLevel) { h.logger.Debug(resp)
h.logger.Debug(resp)
}
t := time.Now() t := time.Now()
h.logger.Infof("%s <-> %s", conn.RemoteAddr(), addr) h.logger.Infof("%s <-> %s", conn.RemoteAddr(), addr)

View File

@ -83,12 +83,9 @@ func (h *socks5Handler) Handle(ctx context.Context, conn net.Conn) {
h.logger.Error(err) h.logger.Error(err)
return return
} }
h.logger.Debug(req)
conn.SetReadDeadline(time.Time{}) conn.SetReadDeadline(time.Time{})
if h.logger.IsLevelEnabled(logger.DebugLevel) {
h.logger.Debug(req)
}
switch req.Cmd { switch req.Cmd {
case gosocks5.CmdConnect: case gosocks5.CmdConnect:
h.handleConnect(ctx, conn, req.Addr.String()) h.handleConnect(ctx, conn, req.Addr.String())
@ -104,9 +101,7 @@ func (h *socks5Handler) Handle(ctx context.Context, conn net.Conn) {
h.logger.Errorf("unknown cmd: %d", req.Cmd) h.logger.Errorf("unknown cmd: %d", req.Cmd)
resp := gosocks5.NewReply(gosocks5.CmdUnsupported, nil) resp := gosocks5.NewReply(gosocks5.CmdUnsupported, nil)
resp.Write(conn) resp.Write(conn)
if h.logger.IsLevelEnabled(logger.DebugLevel) { h.logger.Debug(resp)
h.logger.Debug(resp)
}
return return
} }
} }

View File

@ -8,7 +8,6 @@ import (
"github.com/go-gost/gosocks5" "github.com/go-gost/gosocks5"
"github.com/go-gost/gost/pkg/handler" "github.com/go-gost/gost/pkg/handler"
"github.com/go-gost/gost/pkg/internal/utils/mux" "github.com/go-gost/gost/pkg/internal/utils/mux"
"github.com/go-gost/gost/pkg/logger"
) )
func (h *socks5Handler) handleMuxBind(ctx context.Context, conn net.Conn, req *gosocks5.Request) { func (h *socks5Handler) handleMuxBind(ctx context.Context, conn net.Conn, req *gosocks5.Request) {
@ -34,9 +33,7 @@ func (h *socks5Handler) handleMuxBind(ctx context.Context, conn net.Conn, req *g
if err != nil { if err != nil {
resp := gosocks5.NewReply(gosocks5.Failure, nil) resp := gosocks5.NewReply(gosocks5.Failure, nil)
resp.Write(conn) resp.Write(conn)
if h.logger.IsLevelEnabled(logger.DebugLevel) { h.logger.Debug(resp)
h.logger.Debug(resp)
}
return return
} }
defer cc.Close() defer cc.Close()
@ -46,9 +43,7 @@ func (h *socks5Handler) handleMuxBind(ctx context.Context, conn net.Conn, req *g
h.logger.Error(err) h.logger.Error(err)
resp := gosocks5.NewReply(gosocks5.NetUnreachable, nil) resp := gosocks5.NewReply(gosocks5.NetUnreachable, nil)
resp.Write(conn) resp.Write(conn)
if h.logger.IsLevelEnabled(logger.DebugLevel) { h.logger.Debug(resp)
h.logger.Debug(resp)
}
return return
} }
@ -71,32 +66,26 @@ func (h *socks5Handler) muxBindLocal(ctx context.Context, conn net.Conn, addr st
if err := reply.Write(conn); err != nil { if err := reply.Write(conn); err != nil {
h.logger.Error(err) h.logger.Error(err)
} }
if h.logger.IsLevelEnabled(logger.DebugLevel) { h.logger.Debug(reply)
h.logger.Debug(reply.String())
}
return return
} }
socksAddr, err := gosocks5.NewAddr(ln.Addr().String()) socksAddr := gosocks5.Addr{}
socksAddr.ParseFrom(ln.Addr().String())
if err != nil { if err != nil {
h.logger.Warn(err) h.logger.Warn(err)
socksAddr = &gosocks5.Addr{
Type: gosocks5.AddrIPv4,
}
} }
// Issue: may not reachable when host has multi-interface // Issue: may not reachable when host has multi-interface
socksAddr.Host, _, _ = net.SplitHostPort(conn.LocalAddr().String()) socksAddr.Host, _, _ = net.SplitHostPort(conn.LocalAddr().String())
socksAddr.Type = 0 socksAddr.Type = 0
reply := gosocks5.NewReply(gosocks5.Succeeded, socksAddr) reply := gosocks5.NewReply(gosocks5.Succeeded, &socksAddr)
if err := reply.Write(conn); err != nil { if err := reply.Write(conn); err != nil {
h.logger.Error(err) h.logger.Error(err)
ln.Close() ln.Close()
return return
} }
if h.logger.IsLevelEnabled(logger.DebugLevel) { h.logger.Debug(reply)
h.logger.Debug(reply.String())
}
h.logger = h.logger.WithFields(map[string]interface{}{ h.logger = h.logger.WithFields(map[string]interface{}{
"bind": socksAddr.String(), "bind": socksAddr.String(),

View File

@ -23,9 +23,7 @@ func (selector *serverSelector) Methods() []uint8 {
} }
func (s *serverSelector) Select(methods ...uint8) (method uint8) { func (s *serverSelector) Select(methods ...uint8) (method uint8) {
if s.logger.IsLevelEnabled(logger.DebugLevel) { s.logger.Debugf("%d %d %v", gosocks5.Ver5, len(methods), methods)
s.logger.Debugf("%d %d %v", gosocks5.Ver5, len(methods), methods)
}
method = gosocks5.MethodNoAuth method = gosocks5.MethodNoAuth
for _, m := range methods { for _, m := range methods {
if m == socks.MethodTLS && !s.noTLS { if m == socks.MethodTLS && !s.noTLS {
@ -48,9 +46,7 @@ func (s *serverSelector) Select(methods ...uint8) (method uint8) {
} }
func (s *serverSelector) OnSelected(method uint8, conn net.Conn) (net.Conn, error) { func (s *serverSelector) OnSelected(method uint8, conn net.Conn) (net.Conn, error) {
if s.logger.IsLevelEnabled(logger.DebugLevel) { s.logger.Debugf("%d %d", gosocks5.Ver5, method)
s.logger.Debugf("%d %d", gosocks5.Ver5, method)
}
switch method { switch method {
case socks.MethodTLS: case socks.MethodTLS:
conn = tls.Server(conn, s.TLSConfig) conn = tls.Server(conn, s.TLSConfig)
@ -65,9 +61,7 @@ func (s *serverSelector) OnSelected(method uint8, conn net.Conn) (net.Conn, erro
s.logger.Error(err) s.logger.Error(err)
return nil, err return nil, err
} }
if s.logger.IsLevelEnabled(logger.DebugLevel) { s.logger.Debug(req)
s.logger.Debug(req.String())
}
if s.Authenticator != nil && if s.Authenticator != nil &&
!s.Authenticator.Authenticate(req.Username, req.Password) { !s.Authenticator.Authenticate(req.Username, req.Password) {
@ -76,9 +70,8 @@ func (s *serverSelector) OnSelected(method uint8, conn net.Conn) (net.Conn, erro
s.logger.Error(err) s.logger.Error(err)
return nil, err return nil, err
} }
if s.logger.IsLevelEnabled(logger.DebugLevel) { s.logger.Info(resp)
s.logger.Info(resp.String())
}
return nil, gosocks5.ErrAuthFailure return nil, gosocks5.ErrAuthFailure
} }
@ -87,9 +80,8 @@ func (s *serverSelector) OnSelected(method uint8, conn net.Conn) (net.Conn, erro
s.logger.Error(err) s.logger.Error(err)
return nil, err return nil, err
} }
if s.logger.IsLevelEnabled(logger.DebugLevel) { s.logger.Debug(resp)
s.logger.Debug(resp.String())
}
case gosocks5.MethodNoAcceptable: case gosocks5.MethodNoAcceptable:
return nil, gosocks5.ErrBadMethod return nil, gosocks5.ErrBadMethod
} }

View File

@ -1,7 +1,6 @@
package v5 package v5
import ( import (
"bytes"
"context" "context"
"errors" "errors"
"io" "io"
@ -13,7 +12,6 @@ import (
"github.com/go-gost/gost/pkg/handler" "github.com/go-gost/gost/pkg/handler"
"github.com/go-gost/gost/pkg/internal/bufpool" "github.com/go-gost/gost/pkg/internal/bufpool"
"github.com/go-gost/gost/pkg/internal/utils/socks" "github.com/go-gost/gost/pkg/internal/utils/socks"
"github.com/go-gost/gost/pkg/logger"
) )
func (h *socks5Handler) handleUDP(ctx context.Context, conn net.Conn, req *gosocks5.Request) { func (h *socks5Handler) handleUDP(ctx context.Context, conn net.Conn, req *gosocks5.Request) {
@ -26,27 +24,21 @@ func (h *socks5Handler) handleUDP(ctx context.Context, conn net.Conn, req *gosoc
h.logger.Error(err) h.logger.Error(err)
reply := gosocks5.NewReply(gosocks5.Failure, nil) reply := gosocks5.NewReply(gosocks5.Failure, nil)
reply.Write(conn) reply.Write(conn)
if h.logger.IsLevelEnabled(logger.DebugLevel) { h.logger.Debug(reply)
h.logger.Debug(reply)
}
return return
} }
defer relay.Close() defer relay.Close()
saddr, _ := gosocks5.NewAddr(relay.LocalAddr().String()) saddr := gosocks5.Addr{}
if saddr == nil { saddr.ParseFrom(relay.LocalAddr().String())
saddr = &gosocks5.Addr{}
}
saddr.Type = 0 saddr.Type = 0
saddr.Host, _, _ = net.SplitHostPort(conn.LocalAddr().String()) // replace the IP to the out-going interface's saddr.Host, _, _ = net.SplitHostPort(conn.LocalAddr().String()) // replace the IP to the out-going interface's
reply := gosocks5.NewReply(gosocks5.Succeeded, saddr) reply := gosocks5.NewReply(gosocks5.Succeeded, &saddr)
if err := reply.Write(conn); err != nil { if err := reply.Write(conn); err != nil {
h.logger.Error(err) h.logger.Error(err)
return return
} }
if h.logger.IsLevelEnabled(logger.DebugLevel) { h.logger.Debug(reply)
h.logger.Debug(reply)
}
h.logger = h.logger.WithFields(map[string]interface{}{ h.logger = h.logger.WithFields(map[string]interface{}{
"bind": saddr.String(), "bind": saddr.String(),
@ -62,7 +54,10 @@ func (h *socks5Handler) handleUDP(ctx context.Context, conn net.Conn, req *gosoc
} }
defer peer.Close() defer peer.Close()
go h.relayUDP(relay, peer) go h.relayUDP(
socks.NewUDPConn(relay, h.md.udpBufferSize),
peer,
)
} else { } else {
tun, err := h.getUDPTun(ctx) tun, err := h.getUDPTun(ctx)
if err != nil { if err != nil {
@ -71,15 +66,18 @@ func (h *socks5Handler) handleUDP(ctx context.Context, conn net.Conn, req *gosoc
} }
defer tun.Close() defer tun.Close()
go h.tunnelClientUDP(relay, tun) go h.tunnelClientUDP(
socks.NewUDPConn(relay, h.md.udpBufferSize),
socks.UDPTunClientConn(tun, nil),
)
} }
t := time.Now() t := time.Now()
h.logger.Infof("%s <-> %s", conn.RemoteAddr(), saddr) h.logger.Infof("%s <-> %s", conn.RemoteAddr(), &saddr)
io.Copy(ioutil.Discard, conn) io.Copy(ioutil.Discard, conn)
h.logger. h.logger.
WithFields(map[string]interface{}{"duration": time.Since(t)}). WithFields(map[string]interface{}{"duration": time.Since(t)}).
Infof("%s >-< %s", conn.RemoteAddr(), saddr) Infof("%s >-< %s", conn.RemoteAddr(), &saddr)
} }
func (h *socks5Handler) getUDPTun(ctx context.Context) (conn net.Conn, err error) { func (h *socks5Handler) getUDPTun(ctx context.Context) (conn net.Conn, err error) {
@ -108,17 +106,13 @@ func (h *socks5Handler) getUDPTun(ctx context.Context) (conn net.Conn, err error
if err = req.Write(conn); err != nil { if err = req.Write(conn); err != nil {
return return
} }
if h.logger.IsLevelEnabled(logger.DebugLevel) { h.logger.Debug(req)
h.logger.Debug(req)
}
reply, err := gosocks5.ReadReply(conn) reply, err := gosocks5.ReadReply(conn)
if err != nil { if err != nil {
return return
} }
if h.logger.IsLevelEnabled(logger.DebugLevel) { h.logger.Debug(reply)
h.logger.Debug(reply)
}
if reply.Rep != gosocks5.Succeeded { if reply.Rep != gosocks5.Succeeded {
err = errors.New("UDP associate failed") err = errors.New("UDP associate failed")
@ -128,119 +122,72 @@ func (h *socks5Handler) getUDPTun(ctx context.Context) (conn net.Conn, err error
return return
} }
func (h *socks5Handler) tunnelClientUDP(c net.PacketConn, tunnel net.Conn) (err error) { func (h *socks5Handler) tunnelClientUDP(c, tun net.PacketConn) (err error) {
bufSize := h.md.udpBufferSize bufSize := h.md.udpBufferSize
errc := make(chan error, 2) errc := make(chan error, 2)
var clientAddr net.Addr
go func() { go func() {
b := bufpool.Get(bufSize)
defer bufpool.Put(b)
for { for {
n, laddr, err := c.ReadFrom(b) err := func() error {
b := bufpool.Get(bufSize)
defer bufpool.Put(b)
n, raddr, err := c.ReadFrom(b)
if err != nil {
return err
}
if h.bypass != nil && h.bypass.Contains(raddr.String()) {
h.logger.Warn("bypass: ", raddr)
return nil
}
if _, err := tun.WriteTo(b[:n], raddr); err != nil {
return err
}
h.logger.Debugf("%s >>> %s data: %d",
tun.LocalAddr(), raddr, n)
return nil
}()
if err != nil { if err != nil {
errc <- err errc <- err
return return
} }
if clientAddr == nil {
clientAddr = laddr
}
var addr gosocks5.Addr
header := gosocks5.UDPHeader{
Addr: &addr,
}
hlen, err := header.ReadFrom(bytes.NewReader(b[:n]))
if err != nil {
errc <- err
return
}
raddr, err := net.ResolveUDPAddr("udp", addr.String())
if err != nil {
continue // drop silently
}
if h.bypass != nil && h.bypass.Contains(raddr.String()) {
h.logger.Warn("bypass: ", raddr)
continue // bypass
}
dgram := gosocks5.UDPDatagram{
Header: &header,
Data: b[hlen:n],
}
dgram.Header.Rsv = uint16(len(dgram.Data))
if _, err := dgram.WriteTo(tunnel); err != nil {
errc <- err
return
}
if h.logger.IsLevelEnabled(logger.DebugLevel) {
h.logger.Debugf("%s >>> %s: %v data: %d",
clientAddr, raddr, b[:hlen], len(dgram.Data))
}
} }
}() }()
go func() { go func() {
b := bufpool.Get(bufSize)
defer bufpool.Put(b)
const dataPos = 262
for { for {
addr := gosocks5.Addr{} err := func() error {
header := gosocks5.UDPHeader{ b := bufpool.Get(bufSize)
Addr: &addr, defer bufpool.Put(b)
}
n, raddr, err := tun.ReadFrom(b)
if err != nil {
return err
}
if h.bypass != nil && h.bypass.Contains(raddr.String()) {
h.logger.Warn("bypass: ", raddr)
return nil
}
if _, err := c.WriteTo(b[:n], raddr); err != nil {
return err
}
h.logger.Debugf("%s <<< %s data: %d",
tun.LocalAddr(), raddr, n)
return nil
}()
data := b[dataPos:]
dgram := gosocks5.UDPDatagram{
Header: &header,
Data: data,
}
_, err := dgram.ReadFrom(tunnel)
if err != nil { if err != nil {
errc <- err errc <- err
return return
} }
// NOTE: the dgram.Data may be reallocated if the provided buffer is too short,
// we drop it for simplicity. As this occurs, you should enlarge the buffer size.
if len(dgram.Data) > len(data) {
h.logger.Warnf("buffer too short, dropped")
continue
}
// pipe from tunnel to relay
if clientAddr == nil {
h.logger.Warnf("ignore unexpected peer from %s", addr)
continue
}
raddr := addr.String()
if h.bypass != nil && h.bypass.Contains(raddr) {
h.logger.Warn("bypass: ", raddr)
continue // bypass
}
addrLen := addr.Length()
addr.Encode(b[dataPos-addrLen : dataPos])
hlen := addrLen + 3
if _, err := c.WriteTo(b[dataPos-hlen:dataPos+len(dgram.Data)], clientAddr); err != nil {
errc <- err
return
}
if h.logger.IsLevelEnabled(logger.DebugLevel) {
h.logger.Debugf("%s <<< %s: %v data: %d",
clientAddr, addr.String(), b[dataPos-hlen:dataPos], len(dgram.Data))
}
} }
}() }()
@ -251,91 +198,69 @@ func (h *socks5Handler) relayUDP(c, peer net.PacketConn) (err error) {
bufSize := h.md.udpBufferSize bufSize := h.md.udpBufferSize
errc := make(chan error, 2) errc := make(chan error, 2)
var clientAddr net.Addr
go func() { go func() {
b := bufpool.Get(bufSize)
defer bufpool.Put(b)
for { for {
n, laddr, err := c.ReadFrom(b) err := func() error {
b := bufpool.Get(bufSize)
defer bufpool.Put(b)
n, raddr, err := c.ReadFrom(b)
if err != nil {
return err
}
if h.bypass != nil && h.bypass.Contains(raddr.String()) {
h.logger.Warn("bypass: ", raddr)
return nil
}
if _, err := peer.WriteTo(b[:n], raddr); err != nil {
return err
}
h.logger.Debugf("%s >>> %s data: %d",
peer.LocalAddr(), raddr, n)
return nil
}()
if err != nil { if err != nil {
errc <- err errc <- err
return return
} }
if clientAddr == nil {
clientAddr = laddr
}
var addr gosocks5.Addr
header := gosocks5.UDPHeader{
Addr: &addr,
}
hlen, err := header.ReadFrom(bytes.NewReader(b[:n]))
if err != nil {
errc <- err
return
}
raddr, err := net.ResolveUDPAddr("udp", addr.String())
if err != nil {
continue // drop silently
}
if h.bypass != nil && h.bypass.Contains(raddr.String()) {
h.logger.Warn("bypass: ", raddr)
continue // bypass
}
data := b[hlen:n]
if _, err := peer.WriteTo(data, raddr); err != nil {
errc <- err
return
}
if h.logger.IsLevelEnabled(logger.DebugLevel) {
h.logger.Debugf("%s >>> %s: %v data: %d",
clientAddr, raddr, b[:hlen], len(data))
}
} }
}() }()
go func() { go func() {
b := bufpool.Get(bufSize)
defer bufpool.Put(b)
const dataPos = 262
for { for {
n, raddr, err := peer.ReadFrom(b[dataPos:]) err := func() error {
b := bufpool.Get(bufSize)
defer bufpool.Put(b)
n, raddr, err := peer.ReadFrom(b)
if err != nil {
return err
}
if h.bypass != nil && h.bypass.Contains(raddr.String()) {
h.logger.Warn("bypass: ", raddr)
return nil
}
if _, err := c.WriteTo(b[:n], raddr); err != nil {
return err
}
h.logger.Debugf("%s <<< %s data: %d",
peer.LocalAddr(), raddr, n)
return nil
}()
if err != nil { if err != nil {
errc <- err errc <- err
return return
} }
if clientAddr == nil {
continue
}
if h.bypass != nil && h.bypass.Contains(raddr.String()) {
h.logger.Warn("bypass: ", raddr)
continue // bypass
}
socksAddr, _ := gosocks5.NewAddr(raddr.String())
if socksAddr == nil {
socksAddr = &gosocks5.Addr{}
}
addrLen := socksAddr.Length()
socksAddr.Encode(b[dataPos-addrLen : dataPos])
hlen := addrLen + 3
if _, err := c.WriteTo(b[dataPos-hlen:dataPos+n], clientAddr); err != nil {
errc <- err
return
}
if h.logger.IsLevelEnabled(logger.DebugLevel) {
h.logger.Debugf("%s <<< %s: %v data: %d",
clientAddr, raddr, b[dataPos-hlen:dataPos], n)
}
} }
}() }()

View File

@ -8,7 +8,7 @@ import (
"github.com/go-gost/gosocks5" "github.com/go-gost/gosocks5"
"github.com/go-gost/gost/pkg/handler" "github.com/go-gost/gost/pkg/handler"
"github.com/go-gost/gost/pkg/internal/bufpool" "github.com/go-gost/gost/pkg/internal/bufpool"
"github.com/go-gost/gost/pkg/logger" "github.com/go-gost/gost/pkg/internal/utils/socks"
) )
func (h *socks5Handler) handleUDPTun(ctx context.Context, conn net.Conn, req *gosocks5.Request) { func (h *socks5Handler) handleUDPTun(ctx context.Context, conn net.Conn, req *gosocks5.Request) {
@ -35,9 +35,7 @@ func (h *socks5Handler) handleUDPTun(ctx context.Context, conn net.Conn, req *go
h.logger.Error(err) h.logger.Error(err)
return return
} }
if h.logger.IsLevelEnabled(logger.DebugLevel) { h.logger.Debug(reply)
h.logger.Debug(reply)
}
h.logger = h.logger.WithFields(map[string]interface{}{ h.logger = h.logger.WithFields(map[string]interface{}{
"bind": saddr.String(), "bind": saddr.String(),
@ -45,7 +43,10 @@ func (h *socks5Handler) handleUDPTun(ctx context.Context, conn net.Conn, req *go
t := time.Now() t := time.Now()
h.logger.Infof("%s <-> %s", conn.RemoteAddr(), saddr) h.logger.Infof("%s <-> %s", conn.RemoteAddr(), saddr)
h.tunnelServerUDP(conn, relay) h.tunnelServerUDP(
socks.UDPTunServerConn(conn),
relay,
)
h.logger. h.logger.
WithFields(map[string]interface{}{ WithFields(map[string]interface{}{
"duration": time.Since(t), "duration": time.Since(t),
@ -64,9 +65,7 @@ func (h *socks5Handler) handleUDPTun(ctx context.Context, conn net.Conn, req *go
h.logger.Error(err) h.logger.Error(err)
reply := gosocks5.NewReply(gosocks5.Failure, nil) reply := gosocks5.NewReply(gosocks5.Failure, nil)
reply.Write(conn) reply.Write(conn)
if h.logger.IsLevelEnabled(logger.DebugLevel) { h.logger.Debug(reply)
h.logger.Debug(reply)
}
return return
} }
defer cc.Close() defer cc.Close()
@ -76,9 +75,7 @@ func (h *socks5Handler) handleUDPTun(ctx context.Context, conn net.Conn, req *go
h.logger.Error(err) h.logger.Error(err)
reply := gosocks5.NewReply(gosocks5.Failure, nil) reply := gosocks5.NewReply(gosocks5.Failure, nil)
reply.Write(conn) reply.Write(conn)
if h.logger.IsLevelEnabled(logger.DebugLevel) { h.logger.Debug(reply)
h.logger.Debug(reply)
}
} }
t := time.Now() t := time.Now()
@ -91,97 +88,72 @@ func (h *socks5Handler) handleUDPTun(ctx context.Context, conn net.Conn, req *go
Infof("%s >-< %s", conn.RemoteAddr(), cc.RemoteAddr()) Infof("%s >-< %s", conn.RemoteAddr(), cc.RemoteAddr())
} }
func (h *socks5Handler) tunnelServerUDP(tunnel net.Conn, c net.PacketConn) (err error) { func (h *socks5Handler) tunnelServerUDP(tunnel, c net.PacketConn) (err error) {
bufSize := h.md.udpBufferSize bufSize := h.md.udpBufferSize
errc := make(chan error, 2) errc := make(chan error, 2)
go func() { go func() {
b := bufpool.Get(bufSize)
defer bufpool.Put(b)
const dataPos = 262
for { for {
addr := gosocks5.Addr{} err := func() error {
header := gosocks5.UDPHeader{ b := bufpool.Get(bufSize)
Addr: &addr, defer bufpool.Put(b)
}
n, raddr, err := tunnel.ReadFrom(b)
if err != nil {
return err
}
if h.bypass != nil && h.bypass.Contains(raddr.String()) {
h.logger.Warn("bypass: ", raddr)
return nil
}
if _, err := c.WriteTo(b[:n], raddr); err != nil {
return err
}
h.logger.Debugf("%s >>> %s data: %d",
c.LocalAddr(), raddr, n)
return nil
}()
data := b[dataPos:]
dgram := gosocks5.UDPDatagram{
Header: &header,
Data: data,
}
_, err := dgram.ReadFrom(tunnel)
if err != nil { if err != nil {
errc <- err errc <- err
return return
} }
// NOTE: the dgram.Data may be reallocated if the provided buffer is too short,
// we drop it for simplicity. As this occurs, you should enlarge the buffer size.
if len(dgram.Data) > len(data) {
h.logger.Warnf("buffer too short, dropped")
continue
}
raddr, err := net.ResolveUDPAddr("udp", addr.String())
if err != nil {
continue // drop silently
}
if h.bypass != nil && h.bypass.Contains(raddr.String()) {
h.logger.Warn("bypass: ", raddr.String())
continue // bypass
}
if _, err := c.WriteTo(dgram.Data, raddr); err != nil {
errc <- err
return
}
if h.logger.IsLevelEnabled(logger.DebugLevel) {
h.logger.Debugf("%s >>> %s: %v data: %d",
tunnel.RemoteAddr(), raddr, header.String(), len(dgram.Data))
}
} }
}() }()
go func() { go func() {
b := bufpool.Get(bufSize)
defer bufpool.Put(b)
for { for {
n, raddr, err := c.ReadFrom(b) err := func() error {
b := bufpool.Get(bufSize)
defer bufpool.Put(b)
n, raddr, err := c.ReadFrom(b)
if err != nil {
return err
}
if h.bypass != nil && h.bypass.Contains(raddr.String()) {
h.logger.Warn("bypass: ", raddr)
return nil
}
if _, err := tunnel.WriteTo(b[:n], raddr); err != nil {
return err
}
h.logger.Debugf("%s <<< %s data: %d",
c.LocalAddr(), raddr, n)
return nil
}()
if err != nil { if err != nil {
errc <- err errc <- err
return return
} }
if h.bypass != nil && h.bypass.Contains(raddr.String()) {
h.logger.Warn("bypass: ", raddr.String())
continue // bypass
}
addr, _ := gosocks5.NewAddr(raddr.String())
if addr == nil {
addr = &gosocks5.Addr{}
}
header := gosocks5.UDPHeader{
Rsv: uint16(n),
Addr: addr,
}
dgram := gosocks5.UDPDatagram{
Header: &header,
Data: b[:n],
}
if _, err := dgram.WriteTo(tunnel); err != nil {
errc <- err
return
}
if h.logger.IsLevelEnabled(logger.DebugLevel) {
h.logger.Debugf("%s <<< %s: %v data: %d",
tunnel.RemoteAddr(), raddr, header.String(), len(dgram.Data))
}
} }
}() }()

View File

@ -1,16 +1,15 @@
package ssu package ssu
import ( import (
"bytes"
"context" "context"
"net" "net"
"time" "time"
"github.com/go-gost/gosocks5"
"github.com/go-gost/gost/pkg/bypass" "github.com/go-gost/gost/pkg/bypass"
"github.com/go-gost/gost/pkg/chain" "github.com/go-gost/gost/pkg/chain"
"github.com/go-gost/gost/pkg/handler" "github.com/go-gost/gost/pkg/handler"
"github.com/go-gost/gost/pkg/internal/bufpool" "github.com/go-gost/gost/pkg/internal/bufpool"
"github.com/go-gost/gost/pkg/internal/utils/socks"
"github.com/go-gost/gost/pkg/internal/utils/ss" "github.com/go-gost/gost/pkg/internal/utils/ss"
"github.com/go-gost/gost/pkg/logger" "github.com/go-gost/gost/pkg/logger"
md "github.com/go-gost/gost/pkg/metadata" md "github.com/go-gost/gost/pkg/metadata"
@ -91,7 +90,10 @@ func (h *ssuHandler) Handle(ctx context.Context, conn net.Conn) {
t := time.Now() t := time.Now()
h.logger.Infof("%s <-> %s", conn.RemoteAddr(), cc.LocalAddr()) h.logger.Infof("%s <-> %s", conn.RemoteAddr(), cc.LocalAddr())
h.relayPacket(pc, cc) h.relayPacket(
ss.UDPServerConn(pc, conn.RemoteAddr(), h.md.bufferSize),
cc,
)
h.logger. h.logger.
WithFields(map[string]interface{}{"duration": time.Since(t)}). WithFields(map[string]interface{}{"duration": time.Since(t)}).
Infof("%s >-< %s", conn.RemoteAddr(), cc.LocalAddr()) Infof("%s >-< %s", conn.RemoteAddr(), cc.LocalAddr())
@ -104,7 +106,7 @@ func (h *ssuHandler) Handle(ctx context.Context, conn net.Conn) {
t := time.Now() t := time.Now()
h.logger.Infof("%s <-> %s", conn.RemoteAddr(), cc.LocalAddr()) h.logger.Infof("%s <-> %s", conn.RemoteAddr(), cc.LocalAddr())
h.tunnelUDP(conn, cc) h.tunnelUDP(socks.UDPTunServerConn(conn), cc)
h.logger. h.logger.
WithFields(map[string]interface{}{"duration": time.Since(t)}). WithFields(map[string]interface{}{"duration": time.Since(t)}).
Infof("%s >-< %s", conn.RemoteAddr(), cc.LocalAddr()) Infof("%s >-< %s", conn.RemoteAddr(), cc.LocalAddr())
@ -112,47 +114,30 @@ func (h *ssuHandler) Handle(ctx context.Context, conn net.Conn) {
func (h *ssuHandler) relayPacket(pc1, pc2 net.PacketConn) (err error) { func (h *ssuHandler) relayPacket(pc1, pc2 net.PacketConn) (err error) {
bufSize := h.md.bufferSize bufSize := h.md.bufferSize
errc := make(chan error, 2) errc := make(chan error, 2)
var clientAddr net.Addr
go func() { go func() {
b := bufpool.Get(bufSize)
defer bufpool.Put(b)
for { for {
err := func() error { err := func() error {
b := bufpool.Get(bufSize)
defer bufpool.Put(b)
n, addr, err := pc1.ReadFrom(b) n, addr, err := pc1.ReadFrom(b)
if err != nil { if err != nil {
return err return err
} }
if clientAddr == nil {
clientAddr = addr
}
rb := bytes.NewBuffer(b[:n]) if h.bypass != nil && h.bypass.Contains(addr.String()) {
saddr := gosocks5.Addr{} h.logger.Warn("bypass: ", addr)
if _, err := saddr.ReadFrom(rb); err != nil {
return err
}
taddr, err := net.ResolveUDPAddr("udp", saddr.String())
if err != nil {
return err
}
if h.bypass != nil && h.bypass.Contains(taddr.String()) {
h.logger.Warn("bypass: ", taddr)
return nil return nil
} }
if _, err = pc2.WriteTo(rb.Bytes(), taddr); err != nil { if _, err = pc2.WriteTo(b[:n], addr); err != nil {
return err return err
} }
if h.logger.IsLevelEnabled(logger.DebugLevel) { h.logger.Debugf("%s >>> %s data: %d",
h.logger.Debugf("%s >>> %s: %v, data: %d", pc2.LocalAddr(), addr, n)
addr, taddr, saddr.String(), rb.Len())
}
return nil return nil
}() }()
@ -164,41 +149,27 @@ func (h *ssuHandler) relayPacket(pc1, pc2 net.PacketConn) (err error) {
}() }()
go func() { go func() {
b := bufpool.Get(bufSize)
defer bufpool.Put(b)
const dataPos = 259
for { for {
err := func() error { err := func() error {
n, raddr, err := pc2.ReadFrom(b[dataPos:]) b := bufpool.Get(bufSize)
defer bufpool.Put(b)
n, raddr, err := pc2.ReadFrom(b)
if err != nil { if err != nil {
return err return err
} }
if clientAddr == nil {
return nil
}
if h.bypass != nil && h.bypass.Contains(raddr.String()) { if h.bypass != nil && h.bypass.Contains(raddr.String()) {
h.logger.Warn("bypass: ", raddr) h.logger.Warn("bypass: ", raddr)
return nil return nil
} }
socksAddr, _ := gosocks5.NewAddr(raddr.String()) if _, err = pc1.WriteTo(b[:n], raddr); err != nil {
if socksAddr == nil {
socksAddr = &gosocks5.Addr{}
}
addrLen := socksAddr.Length()
socksAddr.Encode(b[dataPos-addrLen : dataPos])
if _, err = pc1.WriteTo(b[dataPos-addrLen:dataPos+n], clientAddr); err != nil {
return err return err
} }
if h.logger.IsLevelEnabled(logger.DebugLevel) { h.logger.Debugf("%s <<< %s data: %d",
h.logger.Debugf("%s <<< %s: %v data: %d", pc2.LocalAddr(), raddr, n)
clientAddr, raddr, b[dataPos-addrLen:dataPos], n)
}
return nil return nil
}() }()
@ -212,7 +183,7 @@ func (h *ssuHandler) relayPacket(pc1, pc2 net.PacketConn) (err error) {
return <-errc return <-errc
} }
func (h *ssuHandler) tunnelUDP(tunnel net.Conn, c net.PacketConn) (err error) { func (h *ssuHandler) tunnelUDP(tunnel, c net.PacketConn) (err error) {
bufSize := h.md.bufferSize bufSize := h.md.bufferSize
errc := make(chan error, 2) errc := make(chan error, 2)
@ -220,49 +191,32 @@ func (h *ssuHandler) tunnelUDP(tunnel net.Conn, c net.PacketConn) (err error) {
b := bufpool.Get(bufSize) b := bufpool.Get(bufSize)
defer bufpool.Put(b) defer bufpool.Put(b)
const dataPos = 262
for { for {
addr := gosocks5.Addr{} err := func() error {
header := gosocks5.UDPHeader{ n, addr, err := tunnel.ReadFrom(b)
Addr: &addr, if err != nil {
} return err
}
if h.bypass != nil && h.bypass.Contains(addr.String()) {
h.logger.Warn("bypass: ", addr.String())
return nil // bypass
}
if _, err := c.WriteTo(b[:n], addr); err != nil {
return err
}
h.logger.Debugf("%s >>> %s data: %d",
c.LocalAddr(), addr, n)
return nil
}()
data := b[dataPos:]
dgram := gosocks5.UDPDatagram{
Header: &header,
Data: data,
}
_, err := dgram.ReadFrom(tunnel)
if err != nil { if err != nil {
errc <- err errc <- err
return return
} }
// NOTE: the dgram.Data may be reallocated if the provided buffer is too short,
// we drop it for simplicity. As this occurs, you should enlarge the buffer size.
if len(dgram.Data) > len(data) {
h.logger.Warnf("buffer too short, dropped")
continue
}
raddr, err := net.ResolveUDPAddr("udp", addr.String())
if err != nil {
continue // drop silently
}
if h.bypass != nil && h.bypass.Contains(raddr.String()) {
h.logger.Warn("bypass: ", raddr.String())
continue // bypass
}
if _, err := c.WriteTo(dgram.Data, raddr); err != nil {
errc <- err
return
}
if h.logger.IsLevelEnabled(logger.DebugLevel) {
h.logger.Debugf("%s >>> %s: %v data: %d",
tunnel.RemoteAddr(), raddr, header.String(), len(dgram.Data))
}
} }
}() }()
@ -271,38 +225,31 @@ func (h *ssuHandler) tunnelUDP(tunnel net.Conn, c net.PacketConn) (err error) {
defer bufpool.Put(b) defer bufpool.Put(b)
for { for {
n, raddr, err := c.ReadFrom(b) err := func() error {
n, raddr, err := c.ReadFrom(b)
if err != nil {
return err
}
if h.bypass != nil && h.bypass.Contains(raddr.String()) {
h.logger.Warn("bypass: ", raddr.String())
return nil // bypass
}
if _, err := tunnel.WriteTo(b[:n], raddr); err != nil {
return err
}
h.logger.Debugf("%s <<< %s data: %d",
c.LocalAddr(), raddr, n)
return nil
}()
if err != nil { if err != nil {
errc <- err errc <- err
return return
} }
if h.bypass != nil && h.bypass.Contains(raddr.String()) {
h.logger.Warn("bypass: ", raddr.String())
continue // bypass
}
addr, _ := gosocks5.NewAddr(raddr.String())
if addr == nil {
addr = &gosocks5.Addr{}
}
header := gosocks5.UDPHeader{
Rsv: uint16(n),
Addr: addr,
}
dgram := gosocks5.UDPDatagram{
Header: &header,
Data: b[:n],
}
if _, err := dgram.WriteTo(tunnel); err != nil {
errc <- err
return
}
if h.logger.IsLevelEnabled(logger.DebugLevel) {
h.logger.Debugf("%s <<< %s: %v data: %d",
tunnel.RemoteAddr(), raddr, header.String(), len(dgram.Data))
}
} }
}() }()

View File

@ -0,0 +1,173 @@
package socks
import (
"bytes"
"net"
"github.com/go-gost/gosocks5"
"github.com/go-gost/gost/pkg/internal/bufpool"
)
var (
_ net.PacketConn = (*UDPTunConn)(nil)
_ net.Conn = (*UDPTunConn)(nil)
_ net.PacketConn = (*UDPConn)(nil)
_ net.Conn = (*UDPConn)(nil)
)
type UDPTunConn struct {
net.Conn
taddr net.Addr
}
func UDPTunClientConn(c net.Conn, targetAddr net.Addr) *UDPTunConn {
return &UDPTunConn{
Conn: c,
taddr: targetAddr,
}
}
func UDPTunServerConn(c net.Conn) *UDPTunConn {
return &UDPTunConn{
Conn: c,
}
}
func (c *UDPTunConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
socksAddr := gosocks5.Addr{}
header := gosocks5.UDPHeader{
Addr: &socksAddr,
}
dgram := gosocks5.UDPDatagram{
Header: &header,
Data: b,
}
_, err = dgram.ReadFrom(c.Conn)
if err != nil {
return
}
n = len(dgram.Data)
if n > len(b) {
n = copy(b, dgram.Data)
}
addr, err = net.ResolveUDPAddr("udp", socksAddr.String())
return
}
func (c *UDPTunConn) Read(b []byte) (n int, err error) {
n, _, err = c.ReadFrom(b)
return
}
func (c *UDPTunConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
socksAddr := gosocks5.Addr{}
if err = socksAddr.ParseFrom(addr.String()); err != nil {
return
}
header := gosocks5.UDPHeader{
Addr: &socksAddr,
}
dgram := gosocks5.UDPDatagram{
Header: &header,
Data: b,
}
dgram.Header.Rsv = uint16(len(dgram.Data))
_, err = dgram.WriteTo(c.Conn)
n = len(b)
return
}
func (c *UDPTunConn) Write(b []byte) (n int, err error) {
return c.WriteTo(b, c.taddr)
}
var (
DefaultBufferSize = 4096
)
type UDPConn struct {
net.PacketConn
raddr net.Addr
taddr net.Addr
bufferSize int
}
func NewUDPConn(c net.PacketConn, bufferSize int) *UDPConn {
return &UDPConn{
PacketConn: c,
bufferSize: bufferSize,
}
}
// ReadFrom reads an UDP datagram.
// NOTE: for server side,
// the returned addr is the target address the client want to relay to.
func (c *UDPConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
rbuf := bufpool.Get(c.bufferSize)
defer bufpool.Put(rbuf)
n, c.raddr, err = c.PacketConn.ReadFrom(rbuf)
if err != nil {
return
}
socksAddr := gosocks5.Addr{}
header := gosocks5.UDPHeader{
Addr: &socksAddr,
}
hlen, err := header.ReadFrom(bytes.NewReader(rbuf[:n]))
if err != nil {
return
}
n = copy(b, rbuf[hlen:n])
addr, err = net.ResolveUDPAddr("udp", socksAddr.String())
return
}
func (c *UDPConn) Read(b []byte) (n int, err error) {
n, _, err = c.ReadFrom(b)
return
}
func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
wbuf := bufpool.Get(c.bufferSize)
defer bufpool.Put(wbuf)
socksAddr := gosocks5.Addr{}
if err = socksAddr.ParseFrom(addr.String()); err != nil {
return
}
header := gosocks5.UDPHeader{
Addr: &socksAddr,
}
dgram := gosocks5.UDPDatagram{
Header: &header,
Data: b,
}
buf := bytes.NewBuffer(wbuf[:0])
_, err = dgram.WriteTo(buf)
if err != nil {
return
}
_, err = c.PacketConn.WriteTo(buf.Bytes(), c.raddr)
n = len(b)
return
}
func (c *UDPConn) Write(b []byte) (n int, err error) {
return c.WriteTo(b, c.taddr)
}
func (c *UDPConn) RemoteAddr() net.Addr {
return c.raddr
}

View File

@ -0,0 +1,96 @@
package ss
import (
"bytes"
"net"
"github.com/go-gost/gosocks5"
"github.com/go-gost/gost/pkg/internal/bufpool"
)
var (
DefaultBufferSize = 4096
)
var (
_ net.PacketConn = (*UDPConn)(nil)
_ net.Conn = (*UDPConn)(nil)
)
type UDPConn struct {
net.PacketConn
raddr net.Addr
taddr net.Addr
bufferSize int
}
func UDPClientConn(c net.PacketConn, remoteAddr, targetAddr net.Addr, bufferSize int) *UDPConn {
return &UDPConn{
PacketConn: c,
raddr: remoteAddr,
taddr: targetAddr,
bufferSize: bufferSize,
}
}
func UDPServerConn(c net.PacketConn, remoteAddr net.Addr, bufferSize int) *UDPConn {
return &UDPConn{
PacketConn: c,
raddr: remoteAddr,
bufferSize: bufferSize,
}
}
func (c *UDPConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
rbuf := bufpool.Get(c.bufferSize)
defer bufpool.Put(rbuf)
n, _, err = c.PacketConn.ReadFrom(rbuf)
if err != nil {
return
}
saddr := gosocks5.Addr{}
addrLen, err := saddr.ReadFrom(bytes.NewReader(rbuf[:n]))
if err != nil {
return
}
n = copy(b, rbuf[addrLen:n])
addr, err = net.ResolveUDPAddr("udp", saddr.String())
return
}
func (c *UDPConn) Read(b []byte) (n int, err error) {
n, _, err = c.ReadFrom(b)
return
}
func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
wbuf := bufpool.Get(c.bufferSize)
defer bufpool.Put(wbuf)
socksAddr := gosocks5.Addr{}
if err = socksAddr.ParseFrom(addr.String()); err != nil {
return
}
addrLen, err := socksAddr.Encode(wbuf)
if err != nil {
return
}
n = copy(wbuf[addrLen:], b)
_, err = c.PacketConn.WriteTo(wbuf[:addrLen+n], c.raddr)
return
}
func (c *UDPConn) Write(b []byte) (n int, err error) {
return c.WriteTo(b, c.taddr)
}
func (c *UDPConn) RemoteAddr() net.Addr {
return c.raddr
}

View File

@ -6,109 +6,184 @@ import (
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/go-gost/gost/pkg/internal/bufpool"
"github.com/go-gost/gost/pkg/logger"
) )
// serverConn is a server side connection for UDP client peer, it implements net.Conn and net.PacketConn. // conn is a server side connection for UDP client peer, it implements net.Conn and net.PacketConn.
type serverConn struct { type conn struct {
net.PacketConn net.PacketConn
raddr net.Addr remoteAddr net.Addr
rc chan []byte // data receive queue rc chan []byte // data receive queue
fresh int32 idle int32
closed chan struct{} closed chan struct{}
closeMutex sync.Mutex closeMutex sync.Mutex
config *serverConnConfig
} }
type serverConnConfig struct { func newConn(c net.PacketConn, raddr net.Addr, queue int) *conn {
ttl time.Duration return &conn{
qsize int PacketConn: c,
onClose func() remoteAddr: raddr,
} rc: make(chan []byte, queue),
func newServerConn(conn net.PacketConn, raddr net.Addr, cfg *serverConnConfig) *serverConn {
if conn == nil || raddr == nil {
return nil
}
if cfg == nil {
cfg = &serverConnConfig{}
}
c := &serverConn{
PacketConn: conn,
raddr: raddr,
rc: make(chan []byte, cfg.qsize),
closed: make(chan struct{}), closed: make(chan struct{}),
config: cfg,
} }
go c.ttlWait()
return c
} }
func (c *serverConn) send(b []byte) error { func (c *conn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
select { select {
case c.rc <- b: case bb := <-c.rc:
return nil n = copy(b, bb)
default: c.SetIdle(false)
return errors.New("queue is full") bufpool.Put(bb)
case <-c.closed:
err = net.ErrClosed
return
} }
addr = c.remoteAddr
return
} }
func (c *serverConn) Read(b []byte) (n int, err error) { func (c *conn) Read(b []byte) (n int, err error) {
n, _, err = c.ReadFrom(b) n, _, err = c.ReadFrom(b)
return return
} }
func (c *serverConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { func (c *conn) Write(b []byte) (n int, err error) {
select { return c.WriteTo(b, c.remoteAddr)
case bb := <-c.rc:
n = copy(b, bb)
atomic.StoreInt32(&c.fresh, 1)
case <-c.closed:
err = errors.New("read from closed connection")
return
}
addr = c.raddr
return
} }
func (c *serverConn) Write(b []byte) (n int, err error) { func (c *conn) Close() error {
return c.WriteTo(b, c.raddr)
}
func (c *serverConn) Close() error {
c.closeMutex.Lock() c.closeMutex.Lock()
defer c.closeMutex.Unlock() defer c.closeMutex.Unlock()
select { select {
case <-c.closed: case <-c.closed:
return errors.New("connection is closed")
default: default:
if c.config.onClose != nil {
c.config.onClose()
}
close(c.closed) close(c.closed)
} }
return nil return nil
} }
func (c *serverConn) RemoteAddr() net.Addr { func (c *conn) RemoteAddr() net.Addr {
return c.raddr return c.remoteAddr
} }
func (c *serverConn) ttlWait() { func (c *conn) IsIdle() bool {
ticker := time.NewTicker(c.config.ttl) return atomic.LoadInt32(&c.idle) > 0
}
func (c *conn) SetIdle(idle bool) {
v := int32(0)
if idle {
v = 1
}
atomic.StoreInt32(&c.idle, v)
}
func (c *conn) Queue(b []byte) error {
select {
case c.rc <- b:
return nil
case <-c.closed:
return net.ErrClosed
default:
return errors.New("recv queue is full")
}
}
type connPool struct {
m sync.Map
ttl time.Duration
closed chan struct{}
logger logger.Logger
}
func newConnPool(ttl time.Duration) *connPool {
p := &connPool{
ttl: ttl,
closed: make(chan struct{}),
}
go p.idleCheck()
return p
}
func (p *connPool) WithLogger(logger logger.Logger) *connPool {
p.logger = logger
return p
}
func (p *connPool) Get(key interface{}) (c *conn, ok bool) {
v, ok := p.m.Load(key)
if ok {
c, ok = v.(*conn)
}
return
}
func (p *connPool) Set(key interface{}, c *conn) {
p.m.Store(key, c)
}
func (p *connPool) Delete(key interface{}) {
p.m.Delete(key)
}
func (p *connPool) Close() {
select {
case <-p.closed:
return
default:
}
close(p.closed)
p.m.Range(func(k, v interface{}) bool {
if c, ok := v.(*conn); ok && c != nil {
c.Close()
}
return true
})
}
func (p *connPool) idleCheck() {
ticker := time.NewTicker(p.ttl)
defer ticker.Stop() defer ticker.Stop()
for { for {
select { select {
case <-ticker.C: case <-ticker.C:
if !atomic.CompareAndSwapInt32(&c.fresh, 1, 0) { size := 0
c.Close() idles := 0
return p.m.Range(func(key, value interface{}) bool {
c, ok := value.(*conn)
if !ok || c == nil {
p.Delete(key)
return true
}
size++
if c.IsIdle() {
idles++
p.Delete(key)
c.Close()
return true
}
c.SetIdle(true)
return true
})
if idles > 0 {
p.logger.Debugf("connection pool: size=%d, idle=%d", size, idles)
} }
case <-c.closed: case <-p.closed:
return return
} }
} }

View File

@ -2,9 +2,8 @@ package udp
import ( import (
"net" "net"
"sync"
"sync/atomic"
"github.com/go-gost/gost/pkg/internal/bufpool"
"github.com/go-gost/gost/pkg/listener" "github.com/go-gost/gost/pkg/listener"
"github.com/go-gost/gost/pkg/logger" "github.com/go-gost/gost/pkg/logger"
md "github.com/go-gost/gost/pkg/metadata" md "github.com/go-gost/gost/pkg/metadata"
@ -16,13 +15,14 @@ func init() {
} }
type udpListener struct { type udpListener struct {
addr string addr string
md metadata md metadata
conn net.PacketConn conn net.PacketConn
connChan chan net.Conn connChan chan net.Conn
errChan chan error errChan chan error
connPool connPool closeChan chan struct{}
logger logger.Logger connPool *connPool
logger logger.Logger
} }
func NewListener(opts ...listener.Option) listener.Listener { func NewListener(opts ...listener.Option) listener.Listener {
@ -31,8 +31,10 @@ func NewListener(opts ...listener.Option) listener.Listener {
opt(options) opt(options)
} }
return &udpListener{ return &udpListener{
addr: options.Addr, addr: options.Addr,
logger: options.Logger, errChan: make(chan error, 1),
closeChan: make(chan struct{}),
logger: options.Logger,
} }
} }
@ -46,15 +48,13 @@ func (l *udpListener) Init(md md.Metadata) (err error) {
return return
} }
var conn net.PacketConn l.conn, err = net.ListenUDP("udp", laddr)
conn, err = net.ListenUDP("udp", laddr)
if err != nil { if err != nil {
return return
} }
l.conn = conn
l.connChan = make(chan net.Conn, l.md.connQueueSize) l.connChan = make(chan net.Conn, l.md.connQueueSize)
l.errChan = make(chan error, 1) l.connPool = newConnPool(l.md.ttl).WithLogger(l.logger)
go l.listenLoop() go l.listenLoop()
@ -74,12 +74,14 @@ func (l *udpListener) Accept() (conn net.Conn, err error) {
} }
func (l *udpListener) Close() error { func (l *udpListener) Close() error {
err := l.conn.Close() select {
l.connPool.Range(func(k interface{}, v *serverConn) bool { case <-l.closeChan:
v.Close() return nil
return true default:
}) close(l.closeChan)
return err l.connPool.Close()
return l.conn.Close()
}
} }
func (l *udpListener) Addr() net.Addr { func (l *udpListener) Addr() net.Addr {
@ -88,43 +90,43 @@ func (l *udpListener) Addr() net.Addr {
func (l *udpListener) listenLoop() { func (l *udpListener) listenLoop() {
for { for {
b := make([]byte, l.md.readBufferSize) b := bufpool.Get(l.md.readBufferSize)
n, raddr, err := l.conn.ReadFrom(b) n, raddr, err := l.conn.ReadFrom(b)
if err != nil { if err != nil {
l.logger.Error("accept:", err)
l.errChan <- err l.errChan <- err
close(l.errChan) close(l.errChan)
return return
} }
conn, ok := l.connPool.Get(raddr.String()) c := l.getConn(raddr)
if !ok { if c == nil {
conn = newServerConn(l.conn, raddr, bufpool.Put(b)
&serverConnConfig{ continue
ttl: l.md.ttl,
qsize: l.md.readQueueSize,
onClose: func() {
l.connPool.Delete(raddr.String())
},
})
select {
case l.connChan <- conn:
l.connPool.Set(raddr.String(), conn)
default:
conn.Close()
l.logger.Error("connection queue is full")
}
} }
if err := conn.send(b[:n]); err != nil { if err := c.Queue(b[:n]); err != nil {
l.logger.Warn("data discarded:", err) l.logger.Warn("data discarded: ", err)
} }
l.logger.Debug("recv", n)
} }
} }
func (l *udpListener) getConn(addr net.Addr) *conn {
c, ok := l.connPool.Get(addr.String())
if !ok {
c = newConn(l.conn, addr, l.md.readQueueSize)
select {
case l.connChan <- c:
l.connPool.Set(addr.String(), c)
default:
c.Close()
l.logger.Warnf("connection queue is full, client %s discarded", addr.String())
return nil
}
}
return c
}
func (l *udpListener) parseMetadata(md md.Metadata) (err error) { func (l *udpListener) parseMetadata(md md.Metadata) (err error) {
l.md.ttl = md.GetDuration(ttl) l.md.ttl = md.GetDuration(ttl)
if l.md.ttl <= 0 { if l.md.ttl <= 0 {
@ -147,36 +149,3 @@ func (l *udpListener) parseMetadata(md md.Metadata) (err error) {
return return
} }
type connPool struct {
size int64
m sync.Map
}
func (p *connPool) Get(key interface{}) (conn *serverConn, ok bool) {
v, ok := p.m.Load(key)
if ok {
conn, ok = v.(*serverConn)
}
return
}
func (p *connPool) Set(key interface{}, conn *serverConn) {
p.m.Store(key, conn)
atomic.AddInt64(&p.size, 1)
}
func (p *connPool) Delete(key interface{}) {
p.m.Delete(key)
atomic.AddInt64(&p.size, -1)
}
func (p *connPool) Range(f func(key interface{}, value *serverConn) bool) {
p.m.Range(func(k, v interface{}) bool {
return f(k, v.(*serverConn))
})
}
func (p *connPool) Size() int64 {
return atomic.LoadInt64(&p.size)
}

View File

@ -4,7 +4,7 @@ import "time"
const ( const (
defaultTTL = 60 * time.Second defaultTTL = 60 * time.Second
defaultReadBufferSize = 1024 defaultReadBufferSize = 4096
defaultReadQueueSize = 128 defaultReadQueueSize = 128
defaultConnQueueSize = 128 defaultConnQueueSize = 128
) )