add ssu connector

This commit is contained in:
ginuerzh 2021-11-09 23:34:19 +08:00
parent 92dc87830f
commit cae199dbd9
29 changed files with 1031 additions and 678 deletions

View File

@ -45,7 +45,7 @@ services:
# bypass: bypass01
- name: ssu
url: "ss://chacha20:gost@:8000"
addr: ":8338"
addr: ":8388"
handler:
type: ssu
metadata:
@ -54,7 +54,8 @@ services:
readTimeout: 5s
retry: 3
listener:
type: udp
type: tcp
# chain: chain-ssu
- name: socks5+tcp
url: "socks5://gost:gost@:1080"
addr: ":1080"
@ -71,7 +72,7 @@ services:
type: tcp
metadata:
keepAlive: 15s
# chain: chain-socks5
chain: chain-socks5
# bypass: bypass01
- name: socks5+tcp
url: "socks5://gost:gost@:1080"
@ -178,6 +179,20 @@ chains:
dialer:
type: tcp
metadata: {}
- name: chain-ssu
hops:
- name: hop01
nodes:
- name: node01
addr: ":8339"
url: "http://gost:gost@:8081"
# bypass: bypass01
connector:
type: ssu
metadata: {}
dialer:
type: udp
metadata: {}
bypasses:
- name: bypass01

View File

@ -6,9 +6,11 @@ import (
_ "github.com/go-gost/gost/pkg/connector/socks/v4"
_ "github.com/go-gost/gost/pkg/connector/socks/v5"
_ "github.com/go-gost/gost/pkg/connector/ss"
_ "github.com/go-gost/gost/pkg/connector/ssu"
// Register dialers
_ "github.com/go-gost/gost/pkg/dialer/tcp"
_ "github.com/go-gost/gost/pkg/dialer/udp"
// Register handlers
_ "github.com/go-gost/gost/pkg/handler/http"

2
go.mod
View File

@ -7,7 +7,7 @@ require (
github.com/coreos/go-iptables v0.5.0 // indirect
github.com/ginuerzh/tls-dissector v0.0.2-0.20201202075250-98fa925912da
github.com/go-gost/gosocks4 v0.0.1
github.com/go-gost/gosocks5 v0.3.1-0.20211108125245-019dfd6b3aea
github.com/go-gost/gosocks5 v0.3.1-0.20211109033403-d894d75b7f09
github.com/gobwas/glob v0.2.3
github.com/golang/snappy v0.0.3
github.com/google/gopacket v1.1.19 // indirect

2
go.sum
View File

@ -125,6 +125,8 @@ github.com/go-gost/gosocks5 v0.3.1-0.20211108032632-bbfd2de9a32d h1:mjoFToMUWNN0
github.com/go-gost/gosocks5 v0.3.1-0.20211108032632-bbfd2de9a32d/go.mod h1:1G6I7HP7VFVxveGkoK8mnprnJqSqJjdcASKsdUn4Pp4=
github.com/go-gost/gosocks5 v0.3.1-0.20211108125245-019dfd6b3aea h1:mrm6bMpdxBvInvBuDbUaAQWV60r/PaByLIG9fQJEEIc=
github.com/go-gost/gosocks5 v0.3.1-0.20211108125245-019dfd6b3aea/go.mod h1:1G6I7HP7VFVxveGkoK8mnprnJqSqJjdcASKsdUn4Pp4=
github.com/go-gost/gosocks5 v0.3.1-0.20211109033403-d894d75b7f09 h1:A95M6UWcfZgOuJkQ7QLfG0Hs5peWIUSysCDNz4pfe04=
github.com/go-gost/gosocks5 v0.3.1-0.20211109033403-d894d75b7f09/go.mod h1:1G6I7HP7VFVxveGkoK8mnprnJqSqJjdcASKsdUn4Pp4=
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 h1:p104kn46Q8WdvHunIJ9dAyjPVtrBPhSr3KT2yUst43I=
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE=
github.com/gobwas/glob v0.2.3 h1:A4xDbljILXROh+kObIiy5kIaPYD8e96x1tgBhUI5J+Y=

View File

@ -43,6 +43,21 @@ func (c *httpConnector) Init(md md.Metadata) (err error) {
}
func (c *httpConnector) Connect(ctx context.Context, conn net.Conn, network, address string, opts ...connector.ConnectOption) (net.Conn, error) {
c.logger = c.logger.WithFields(map[string]interface{}{
"local": conn.LocalAddr().String(),
"remote": conn.RemoteAddr().String(),
"network": network,
"address": address,
})
switch network {
case "tcp", "tcp4", "tcp6":
default:
err := fmt.Errorf("network %s unsupported, should be tcp, tcp4 or tcp6", network)
c.logger.Error(err)
return nil, err
}
req := &http.Request{
Method: http.MethodConnect,
URL: &url.URL{Host: address},
@ -56,11 +71,6 @@ func (c *httpConnector) Connect(ctx context.Context, conn net.Conn, network, add
}
req.Header.Set("Proxy-Connection", "keep-alive")
c.logger = c.logger.WithFields(map[string]interface{}{
"local": conn.LocalAddr().String(),
"remote": conn.RemoteAddr().String(),
"target": address,
})
c.logger.Infof("connect: ", address)
if user := c.md.User; user != nil {

View File

@ -42,10 +42,20 @@ func (c *socks4Connector) Init(md md.Metadata) (err error) {
func (c *socks4Connector) Connect(ctx context.Context, conn net.Conn, network, address string, opts ...connector.ConnectOption) (net.Conn, error) {
c.logger = c.logger.WithFields(map[string]interface{}{
"remote": conn.RemoteAddr().String(),
"local": conn.LocalAddr().String(),
"target": address,
"remote": conn.RemoteAddr().String(),
"local": conn.LocalAddr().String(),
"network": network,
"address": address,
})
switch network {
case "tcp", "tcp4", "tcp6":
default:
err := fmt.Errorf("network %s unsupported, should be tcp, tcp4 or tcp6", network)
c.logger.Error(err)
return nil, err
}
c.logger.Info("connect: ", address)
var addr *gosocks4.Addr
@ -87,19 +97,14 @@ func (c *socks4Connector) Connect(ctx context.Context, conn net.Conn, network, a
c.logger.Error(err)
return nil, err
}
if c.logger.IsLevelEnabled(logger.DebugLevel) {
c.logger.Debug(req)
}
c.logger.Debug(req)
reply, err := gosocks4.ReadReply(conn)
if err != nil {
c.logger.Error(err)
return nil, err
}
if c.logger.IsLevelEnabled(logger.DebugLevel) {
c.logger.Debug(reply)
}
c.logger.Debug(reply)
if reply.Code != gosocks4.Granted {
return nil, fmt.Errorf("error: %d", reply.Code)

View File

@ -4,6 +4,7 @@ import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"net/url"
"strings"
@ -79,6 +80,7 @@ func (c *socks5Connector) Handshake(ctx context.Context, conn net.Conn) (net.Con
cc := gosocks5.ClientConn(conn, c.selector)
if err := cc.Handleshake(); err != nil {
c.logger.Error(err)
return nil, err
}
@ -87,12 +89,22 @@ func (c *socks5Connector) Handshake(ctx context.Context, conn net.Conn) (net.Con
func (c *socks5Connector) Connect(ctx context.Context, conn net.Conn, network, address string, opts ...connector.ConnectOption) (net.Conn, error) {
c.logger = c.logger.WithFields(map[string]interface{}{
"target": address,
"network": network,
"address": address,
})
switch network {
case "tcp", "tcp4", "tcp6":
default:
err := fmt.Errorf("network %s unsupported, should be tcp, tcp4 or tcp6", network)
c.logger.Error(err)
return nil, err
}
c.logger.Info("connect: ", address)
addr, err := gosocks5.NewAddr(address)
if err != nil {
addr := gosocks5.Addr{}
if err := addr.ParseFrom(address); err != nil {
c.logger.Error(err)
return nil, err
}
@ -102,25 +114,19 @@ func (c *socks5Connector) Connect(ctx context.Context, conn net.Conn, network, a
defer conn.SetDeadline(time.Time{})
}
req := gosocks5.NewRequest(gosocks5.CmdConnect, addr)
req := gosocks5.NewRequest(gosocks5.CmdConnect, &addr)
if err := req.Write(conn); err != nil {
c.logger.Error(err)
return nil, err
}
if c.logger.IsLevelEnabled(logger.DebugLevel) {
c.logger.Debug(req)
}
c.logger.Debug(req)
reply, err := gosocks5.ReadReply(conn)
if err != nil {
c.logger.Error(err)
return nil, err
}
if c.logger.IsLevelEnabled(logger.DebugLevel) {
c.logger.Debug(reply)
}
c.logger.Debug(reply)
if reply.Rep != gosocks5.Succeeded {
return nil, errors.New("service unavailable")

View File

@ -18,9 +18,7 @@ type clientSelector struct {
}
func (s *clientSelector) Methods() []uint8 {
if s.logger.IsLevelEnabled(logger.DebugLevel) {
s.logger.Debug("methods: ", s.methods)
}
s.logger.Debug("methods: ", s.methods)
return s.methods
}
@ -33,9 +31,7 @@ func (s *clientSelector) Select(methods ...uint8) (method uint8) {
}
func (s *clientSelector) OnSelected(method uint8, conn net.Conn) (net.Conn, error) {
if s.logger.IsLevelEnabled(logger.DebugLevel) {
s.logger.Debug("method selected: ", method)
}
s.logger.Debug("method selected: ", method)
switch method {
case socks.MethodTLS:
@ -57,18 +53,14 @@ func (s *clientSelector) OnSelected(method uint8, conn net.Conn) (net.Conn, erro
s.logger.Error(err)
return nil, err
}
if s.logger.IsLevelEnabled(logger.DebugLevel) {
s.logger.Debug(req)
}
s.logger.Debug(req)
resp, err := gosocks5.ReadUserPassResponse(conn)
if err != nil {
s.logger.Error(err)
return nil, err
}
if s.logger.IsLevelEnabled(logger.DebugLevel) {
s.logger.Debug(resp)
}
s.logger.Debug(resp)
if resp.Status != gosocks5.Succeeded {
return nil, gosocks5.ErrAuthFailure

View File

@ -2,6 +2,7 @@ package ss
import (
"context"
"fmt"
"net"
"time"
@ -40,21 +41,30 @@ func (c *ssConnector) Init(md md.Metadata) (err error) {
func (c *ssConnector) Connect(ctx context.Context, conn net.Conn, network, address string, opts ...connector.ConnectOption) (net.Conn, error) {
c.logger = c.logger.WithFields(map[string]interface{}{
"remote": conn.RemoteAddr().String(),
"local": conn.LocalAddr().String(),
"target": address,
"remote": conn.RemoteAddr().String(),
"local": conn.LocalAddr().String(),
"network": network,
"address": address,
})
switch network {
case "tcp", "tcp4", "tcp6":
default:
err := fmt.Errorf("network %s unsupported, should be tcp, tcp4 or tcp6", network)
c.logger.Error(err)
return nil, err
}
c.logger.Infof("connect: ", address)
socksAddr, err := gosocks5.NewAddr(address)
if err != nil {
c.logger.Error("parse addr: ", err)
addr := gosocks5.Addr{}
if err := addr.ParseFrom(address); err != nil {
c.logger.Error(err)
return nil, err
}
rawaddr := bufpool.Get(512)
defer bufpool.Put(rawaddr)
n, err := socksAddr.Encode(rawaddr)
n, err := addr.Encode(rawaddr)
if err != nil {
c.logger.Error("encoding addr: ", err)
return nil, err

View File

@ -0,0 +1,105 @@
package ssu
import (
"context"
"fmt"
"net"
"time"
"github.com/go-gost/gost/pkg/connector"
"github.com/go-gost/gost/pkg/internal/utils/socks"
"github.com/go-gost/gost/pkg/internal/utils/ss"
"github.com/go-gost/gost/pkg/logger"
md "github.com/go-gost/gost/pkg/metadata"
"github.com/go-gost/gost/pkg/registry"
)
func init() {
registry.RegiserConnector("ssu", NewConnector)
}
type ssuConnector struct {
md metadata
logger logger.Logger
}
func NewConnector(opts ...connector.Option) connector.Connector {
options := &connector.Options{}
for _, opt := range opts {
opt(options)
}
return &ssuConnector{
logger: options.Logger,
}
}
func (c *ssuConnector) Init(md md.Metadata) (err error) {
return c.parseMetadata(md)
}
func (c *ssuConnector) Connect(ctx context.Context, conn net.Conn, network, address string, opts ...connector.ConnectOption) (net.Conn, error) {
c.logger = c.logger.WithFields(map[string]interface{}{
"remote": conn.RemoteAddr().String(),
"local": conn.LocalAddr().String(),
"network": network,
"address": address,
})
switch network {
case "udp", "udp4", "udp6":
default:
err := fmt.Errorf("network %s unsupported, should be udp, udp4 or udp6", network)
c.logger.Error(err)
return nil, err
}
c.logger.Info("connect: ", address)
if c.md.connectTimeout > 0 {
conn.SetDeadline(time.Now().Add(c.md.connectTimeout))
defer conn.SetDeadline(time.Time{})
}
taddr, _ := net.ResolveUDPAddr(network, address)
if taddr == nil {
taddr = &net.UDPAddr{}
}
pc, ok := conn.(net.PacketConn)
if ok {
if c.md.cipher != nil {
pc = c.md.cipher.PacketConn(pc)
}
return ss.UDPClientConn(pc, conn.RemoteAddr(), taddr, c.md.bufferSize), nil
}
return socks.UDPTunClientConn(conn, taddr), nil
}
func (c *ssuConnector) parseMetadata(md md.Metadata) (err error) {
c.md.cipher, err = ss.ShadowCipher(
md.GetString(method),
md.GetString(password),
md.GetString(key),
)
if err != nil {
return
}
c.md.connectTimeout = md.GetDuration(connectTimeout)
c.md.bufferSize = md.GetInt(bufferSize)
if c.md.bufferSize > 0 {
if c.md.bufferSize < 512 {
c.md.bufferSize = 512
}
if c.md.bufferSize > 65*1024 {
c.md.bufferSize = 65 * 1024
}
} else {
c.md.bufferSize = 4096
}
return
}

View File

@ -0,0 +1,21 @@
package ssu
import (
"time"
"github.com/shadowsocks/go-shadowsocks2/core"
)
const (
method = "method"
password = "password"
key = "key"
connectTimeout = "timeout"
bufferSize = "bufferSize"
)
type metadata struct {
cipher core.Cipher
connectTimeout time.Duration
bufferSize int
}

View File

@ -46,12 +46,10 @@ func (d *tcpDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOp
if err != nil {
d.logger.Error(err)
} else {
if d.logger.IsLevelEnabled(logger.DebugLevel) {
d.logger.WithFields(map[string]interface{}{
"src": conn.LocalAddr().String(),
"dst": addr,
}).Debug("dial with dial func")
}
d.logger.WithFields(map[string]interface{}{
"src": conn.LocalAddr().String(),
"dst": addr,
}).Debug("dial with dial func")
}
return conn, err
}
@ -61,12 +59,10 @@ func (d *tcpDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOp
if err != nil {
d.logger.Error(err)
} else {
if d.logger.IsLevelEnabled(logger.DebugLevel) {
d.logger.WithFields(map[string]interface{}{
"src": conn.LocalAddr().String(),
"dst": addr,
}).Debug("dial direct")
}
d.logger.WithFields(map[string]interface{}{
"src": conn.LocalAddr().String(),
"dst": addr,
}).Debug("dial direct")
}
return conn, err
}

17
pkg/dialer/udp/conn.go Normal file
View File

@ -0,0 +1,17 @@
package udp
import "net"
type conn struct {
*net.UDPConn
}
func (c *conn) WriteTo(b []byte, addr net.Addr) (int, error) {
return c.UDPConn.Write(b)
}
func (c *conn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
n, err = c.UDPConn.Read(b)
addr = c.RemoteAddr()
return
}

54
pkg/dialer/udp/dialer.go Normal file
View File

@ -0,0 +1,54 @@
package udp
import (
"context"
"net"
"github.com/go-gost/gost/pkg/dialer"
"github.com/go-gost/gost/pkg/logger"
md "github.com/go-gost/gost/pkg/metadata"
"github.com/go-gost/gost/pkg/registry"
)
func init() {
registry.RegisterDialer("udp", NewDialer)
}
type udpDialer struct {
md metadata
logger logger.Logger
}
func NewDialer(opts ...dialer.Option) dialer.Dialer {
options := &dialer.Options{}
for _, opt := range opts {
opt(options)
}
return &udpDialer{
logger: options.Logger,
}
}
func (d *udpDialer) Init(md md.Metadata) (err error) {
return d.parseMetadata(md)
}
func (d *udpDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOption) (net.Conn, error) {
taddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}
c, err := net.DialUDP("udp", nil, taddr)
if err != nil {
return nil, err
}
return &conn{
UDPConn: c,
}, nil
}
func (d *udpDialer) parseMetadata(md md.Metadata) (err error) {
return
}

View File

@ -0,0 +1,15 @@
package udp
import "time"
const (
dialTimeout = "dialTimeout"
)
const (
defaultDialTimeout = 5 * time.Second
)
type metadata struct {
dialTimeout time.Duration
}

View File

@ -70,19 +70,15 @@ func (h *socks4Handler) Handle(ctx context.Context, conn net.Conn) {
h.logger.Error(err)
return
}
conn.SetReadDeadline(time.Time{})
h.logger.Debug(req)
if h.logger.IsLevelEnabled(logger.DebugLevel) {
h.logger.Debug(req)
}
conn.SetReadDeadline(time.Time{})
if h.md.authenticator != nil &&
!h.md.authenticator.Authenticate(string(req.Userid), "") {
resp := gosocks4.NewReply(gosocks4.RejectedUserid, nil)
resp.Write(conn)
if h.logger.IsLevelEnabled(logger.DebugLevel) {
h.logger.Debug(resp)
}
h.logger.Debug(resp)
return
}
@ -107,9 +103,7 @@ func (h *socks4Handler) handleConnect(ctx context.Context, conn net.Conn, req *g
if h.bypass != nil && h.bypass.Contains(addr) {
resp := gosocks4.NewReply(gosocks4.Rejected, nil)
resp.Write(conn)
if h.logger.IsLevelEnabled(logger.DebugLevel) {
h.logger.Debug(resp)
}
h.logger.Debug(resp)
h.logger.Info("bypass: ", addr)
return
}
@ -122,9 +116,7 @@ func (h *socks4Handler) handleConnect(ctx context.Context, conn net.Conn, req *g
if err != nil {
resp := gosocks4.NewReply(gosocks4.Failed, nil)
resp.Write(conn)
if h.logger.IsLevelEnabled(logger.DebugLevel) {
h.logger.Debug(resp)
}
h.logger.Debug(resp)
return
}
@ -135,9 +127,7 @@ func (h *socks4Handler) handleConnect(ctx context.Context, conn net.Conn, req *g
h.logger.Error(err)
return
}
if h.logger.IsLevelEnabled(logger.DebugLevel) {
h.logger.Debug(resp)
}
h.logger.Debug(resp)
t := time.Now()
h.logger.Infof("%s <-> %s", conn.RemoteAddr(), addr)

View File

@ -7,7 +7,6 @@ import (
"github.com/go-gost/gosocks5"
"github.com/go-gost/gost/pkg/handler"
"github.com/go-gost/gost/pkg/logger"
)
func (h *socks5Handler) handleBind(ctx context.Context, conn net.Conn, req *gosocks5.Request) {
@ -33,9 +32,7 @@ func (h *socks5Handler) handleBind(ctx context.Context, conn net.Conn, req *goso
if err != nil {
resp := gosocks5.NewReply(gosocks5.Failure, nil)
resp.Write(conn)
if h.logger.IsLevelEnabled(logger.DebugLevel) {
h.logger.Debug(resp)
}
h.logger.Debug(resp)
return
}
defer cc.Close()
@ -45,9 +42,7 @@ func (h *socks5Handler) handleBind(ctx context.Context, conn net.Conn, req *goso
h.logger.Error(err)
resp := gosocks5.NewReply(gosocks5.NetUnreachable, nil)
resp.Write(conn)
if h.logger.IsLevelEnabled(logger.DebugLevel) {
h.logger.Debug(resp)
}
h.logger.Debug(resp)
return
}
@ -65,32 +60,25 @@ func (h *socks5Handler) bindLocal(ctx context.Context, conn net.Conn, addr strin
if err := reply.Write(conn); err != nil {
h.logger.Error(err)
}
if h.logger.IsLevelEnabled(logger.DebugLevel) {
h.logger.Debug(reply.String())
}
h.logger.Debug(reply)
return
}
socksAddr, err := gosocks5.NewAddr(ln.Addr().String())
if err != nil {
socksAddr := gosocks5.Addr{}
if err := socksAddr.ParseFrom(ln.Addr().String()); err != nil {
h.logger.Warn(err)
socksAddr = &gosocks5.Addr{
Type: gosocks5.AddrIPv4,
}
}
// Issue: may not reachable when host has multi-interface
socksAddr.Host, _, _ = net.SplitHostPort(conn.LocalAddr().String())
socksAddr.Type = 0
reply := gosocks5.NewReply(gosocks5.Succeeded, socksAddr)
reply := gosocks5.NewReply(gosocks5.Succeeded, &socksAddr)
if err := reply.Write(conn); err != nil {
h.logger.Error(err)
ln.Close()
return
}
if h.logger.IsLevelEnabled(logger.DebugLevel) {
h.logger.Debug(reply.String())
}
h.logger.Debug(reply)
h.logger = h.logger.WithFields(map[string]interface{}{
"bind": socksAddr.String(),
@ -143,14 +131,13 @@ func (h *socks5Handler) serveBind(ctx context.Context, conn net.Conn, ln net.Lis
}
defer rc.Close()
raddr, _ := gosocks5.NewAddr(rc.RemoteAddr().String())
reply := gosocks5.NewReply(gosocks5.Succeeded, raddr)
raddr := gosocks5.Addr{}
raddr.ParseFrom(rc.RemoteAddr().String())
reply := gosocks5.NewReply(gosocks5.Succeeded, &raddr)
if err := reply.Write(pc2); err != nil {
h.logger.Error(err)
}
if h.logger.IsLevelEnabled(logger.DebugLevel) {
h.logger.Debug(reply.String())
}
h.logger.Debug(reply)
h.logger.Infof("peer accepted: %s", raddr.String())
start := time.Now()

View File

@ -7,7 +7,6 @@ import (
"github.com/go-gost/gosocks5"
"github.com/go-gost/gost/pkg/handler"
"github.com/go-gost/gost/pkg/logger"
)
func (h *socks5Handler) handleConnect(ctx context.Context, conn net.Conn, addr string) {
@ -20,9 +19,7 @@ func (h *socks5Handler) handleConnect(ctx context.Context, conn net.Conn, addr s
if h.bypass != nil && h.bypass.Contains(addr) {
resp := gosocks5.NewReply(gosocks5.NotAllowed, nil)
resp.Write(conn)
if h.logger.IsLevelEnabled(logger.DebugLevel) {
h.logger.Debug(resp)
}
h.logger.Debug(resp)
h.logger.Info("bypass: ", addr)
return
}
@ -35,9 +32,7 @@ func (h *socks5Handler) handleConnect(ctx context.Context, conn net.Conn, addr s
if err != nil {
resp := gosocks5.NewReply(gosocks5.NetUnreachable, nil)
resp.Write(conn)
if h.logger.IsLevelEnabled(logger.DebugLevel) {
h.logger.Debug(resp)
}
h.logger.Debug(resp)
return
}
@ -48,9 +43,7 @@ func (h *socks5Handler) handleConnect(ctx context.Context, conn net.Conn, addr s
h.logger.Error(err)
return
}
if h.logger.IsLevelEnabled(logger.DebugLevel) {
h.logger.Debug(resp)
}
h.logger.Debug(resp)
t := time.Now()
h.logger.Infof("%s <-> %s", conn.RemoteAddr(), addr)

View File

@ -83,12 +83,9 @@ func (h *socks5Handler) Handle(ctx context.Context, conn net.Conn) {
h.logger.Error(err)
return
}
h.logger.Debug(req)
conn.SetReadDeadline(time.Time{})
if h.logger.IsLevelEnabled(logger.DebugLevel) {
h.logger.Debug(req)
}
switch req.Cmd {
case gosocks5.CmdConnect:
h.handleConnect(ctx, conn, req.Addr.String())
@ -104,9 +101,7 @@ func (h *socks5Handler) Handle(ctx context.Context, conn net.Conn) {
h.logger.Errorf("unknown cmd: %d", req.Cmd)
resp := gosocks5.NewReply(gosocks5.CmdUnsupported, nil)
resp.Write(conn)
if h.logger.IsLevelEnabled(logger.DebugLevel) {
h.logger.Debug(resp)
}
h.logger.Debug(resp)
return
}
}

View File

@ -8,7 +8,6 @@ import (
"github.com/go-gost/gosocks5"
"github.com/go-gost/gost/pkg/handler"
"github.com/go-gost/gost/pkg/internal/utils/mux"
"github.com/go-gost/gost/pkg/logger"
)
func (h *socks5Handler) handleMuxBind(ctx context.Context, conn net.Conn, req *gosocks5.Request) {
@ -34,9 +33,7 @@ func (h *socks5Handler) handleMuxBind(ctx context.Context, conn net.Conn, req *g
if err != nil {
resp := gosocks5.NewReply(gosocks5.Failure, nil)
resp.Write(conn)
if h.logger.IsLevelEnabled(logger.DebugLevel) {
h.logger.Debug(resp)
}
h.logger.Debug(resp)
return
}
defer cc.Close()
@ -46,9 +43,7 @@ func (h *socks5Handler) handleMuxBind(ctx context.Context, conn net.Conn, req *g
h.logger.Error(err)
resp := gosocks5.NewReply(gosocks5.NetUnreachable, nil)
resp.Write(conn)
if h.logger.IsLevelEnabled(logger.DebugLevel) {
h.logger.Debug(resp)
}
h.logger.Debug(resp)
return
}
@ -71,32 +66,26 @@ func (h *socks5Handler) muxBindLocal(ctx context.Context, conn net.Conn, addr st
if err := reply.Write(conn); err != nil {
h.logger.Error(err)
}
if h.logger.IsLevelEnabled(logger.DebugLevel) {
h.logger.Debug(reply.String())
}
h.logger.Debug(reply)
return
}
socksAddr, err := gosocks5.NewAddr(ln.Addr().String())
socksAddr := gosocks5.Addr{}
socksAddr.ParseFrom(ln.Addr().String())
if err != nil {
h.logger.Warn(err)
socksAddr = &gosocks5.Addr{
Type: gosocks5.AddrIPv4,
}
}
// Issue: may not reachable when host has multi-interface
socksAddr.Host, _, _ = net.SplitHostPort(conn.LocalAddr().String())
socksAddr.Type = 0
reply := gosocks5.NewReply(gosocks5.Succeeded, socksAddr)
reply := gosocks5.NewReply(gosocks5.Succeeded, &socksAddr)
if err := reply.Write(conn); err != nil {
h.logger.Error(err)
ln.Close()
return
}
if h.logger.IsLevelEnabled(logger.DebugLevel) {
h.logger.Debug(reply.String())
}
h.logger.Debug(reply)
h.logger = h.logger.WithFields(map[string]interface{}{
"bind": socksAddr.String(),

View File

@ -23,9 +23,7 @@ func (selector *serverSelector) Methods() []uint8 {
}
func (s *serverSelector) Select(methods ...uint8) (method uint8) {
if s.logger.IsLevelEnabled(logger.DebugLevel) {
s.logger.Debugf("%d %d %v", gosocks5.Ver5, len(methods), methods)
}
s.logger.Debugf("%d %d %v", gosocks5.Ver5, len(methods), methods)
method = gosocks5.MethodNoAuth
for _, m := range methods {
if m == socks.MethodTLS && !s.noTLS {
@ -48,9 +46,7 @@ func (s *serverSelector) Select(methods ...uint8) (method uint8) {
}
func (s *serverSelector) OnSelected(method uint8, conn net.Conn) (net.Conn, error) {
if s.logger.IsLevelEnabled(logger.DebugLevel) {
s.logger.Debugf("%d %d", gosocks5.Ver5, method)
}
s.logger.Debugf("%d %d", gosocks5.Ver5, method)
switch method {
case socks.MethodTLS:
conn = tls.Server(conn, s.TLSConfig)
@ -65,9 +61,7 @@ func (s *serverSelector) OnSelected(method uint8, conn net.Conn) (net.Conn, erro
s.logger.Error(err)
return nil, err
}
if s.logger.IsLevelEnabled(logger.DebugLevel) {
s.logger.Debug(req.String())
}
s.logger.Debug(req)
if s.Authenticator != nil &&
!s.Authenticator.Authenticate(req.Username, req.Password) {
@ -76,9 +70,8 @@ func (s *serverSelector) OnSelected(method uint8, conn net.Conn) (net.Conn, erro
s.logger.Error(err)
return nil, err
}
if s.logger.IsLevelEnabled(logger.DebugLevel) {
s.logger.Info(resp.String())
}
s.logger.Info(resp)
return nil, gosocks5.ErrAuthFailure
}
@ -87,9 +80,8 @@ func (s *serverSelector) OnSelected(method uint8, conn net.Conn) (net.Conn, erro
s.logger.Error(err)
return nil, err
}
if s.logger.IsLevelEnabled(logger.DebugLevel) {
s.logger.Debug(resp.String())
}
s.logger.Debug(resp)
case gosocks5.MethodNoAcceptable:
return nil, gosocks5.ErrBadMethod
}

View File

@ -1,7 +1,6 @@
package v5
import (
"bytes"
"context"
"errors"
"io"
@ -13,7 +12,6 @@ import (
"github.com/go-gost/gost/pkg/handler"
"github.com/go-gost/gost/pkg/internal/bufpool"
"github.com/go-gost/gost/pkg/internal/utils/socks"
"github.com/go-gost/gost/pkg/logger"
)
func (h *socks5Handler) handleUDP(ctx context.Context, conn net.Conn, req *gosocks5.Request) {
@ -26,27 +24,21 @@ func (h *socks5Handler) handleUDP(ctx context.Context, conn net.Conn, req *gosoc
h.logger.Error(err)
reply := gosocks5.NewReply(gosocks5.Failure, nil)
reply.Write(conn)
if h.logger.IsLevelEnabled(logger.DebugLevel) {
h.logger.Debug(reply)
}
h.logger.Debug(reply)
return
}
defer relay.Close()
saddr, _ := gosocks5.NewAddr(relay.LocalAddr().String())
if saddr == nil {
saddr = &gosocks5.Addr{}
}
saddr := gosocks5.Addr{}
saddr.ParseFrom(relay.LocalAddr().String())
saddr.Type = 0
saddr.Host, _, _ = net.SplitHostPort(conn.LocalAddr().String()) // replace the IP to the out-going interface's
reply := gosocks5.NewReply(gosocks5.Succeeded, saddr)
reply := gosocks5.NewReply(gosocks5.Succeeded, &saddr)
if err := reply.Write(conn); err != nil {
h.logger.Error(err)
return
}
if h.logger.IsLevelEnabled(logger.DebugLevel) {
h.logger.Debug(reply)
}
h.logger.Debug(reply)
h.logger = h.logger.WithFields(map[string]interface{}{
"bind": saddr.String(),
@ -62,7 +54,10 @@ func (h *socks5Handler) handleUDP(ctx context.Context, conn net.Conn, req *gosoc
}
defer peer.Close()
go h.relayUDP(relay, peer)
go h.relayUDP(
socks.NewUDPConn(relay, h.md.udpBufferSize),
peer,
)
} else {
tun, err := h.getUDPTun(ctx)
if err != nil {
@ -71,15 +66,18 @@ func (h *socks5Handler) handleUDP(ctx context.Context, conn net.Conn, req *gosoc
}
defer tun.Close()
go h.tunnelClientUDP(relay, tun)
go h.tunnelClientUDP(
socks.NewUDPConn(relay, h.md.udpBufferSize),
socks.UDPTunClientConn(tun, nil),
)
}
t := time.Now()
h.logger.Infof("%s <-> %s", conn.RemoteAddr(), saddr)
h.logger.Infof("%s <-> %s", conn.RemoteAddr(), &saddr)
io.Copy(ioutil.Discard, conn)
h.logger.
WithFields(map[string]interface{}{"duration": time.Since(t)}).
Infof("%s >-< %s", conn.RemoteAddr(), saddr)
Infof("%s >-< %s", conn.RemoteAddr(), &saddr)
}
func (h *socks5Handler) getUDPTun(ctx context.Context) (conn net.Conn, err error) {
@ -108,17 +106,13 @@ func (h *socks5Handler) getUDPTun(ctx context.Context) (conn net.Conn, err error
if err = req.Write(conn); err != nil {
return
}
if h.logger.IsLevelEnabled(logger.DebugLevel) {
h.logger.Debug(req)
}
h.logger.Debug(req)
reply, err := gosocks5.ReadReply(conn)
if err != nil {
return
}
if h.logger.IsLevelEnabled(logger.DebugLevel) {
h.logger.Debug(reply)
}
h.logger.Debug(reply)
if reply.Rep != gosocks5.Succeeded {
err = errors.New("UDP associate failed")
@ -128,119 +122,72 @@ func (h *socks5Handler) getUDPTun(ctx context.Context) (conn net.Conn, err error
return
}
func (h *socks5Handler) tunnelClientUDP(c net.PacketConn, tunnel net.Conn) (err error) {
func (h *socks5Handler) tunnelClientUDP(c, tun net.PacketConn) (err error) {
bufSize := h.md.udpBufferSize
errc := make(chan error, 2)
var clientAddr net.Addr
go func() {
b := bufpool.Get(bufSize)
defer bufpool.Put(b)
for {
n, laddr, err := c.ReadFrom(b)
err := func() error {
b := bufpool.Get(bufSize)
defer bufpool.Put(b)
n, raddr, err := c.ReadFrom(b)
if err != nil {
return err
}
if h.bypass != nil && h.bypass.Contains(raddr.String()) {
h.logger.Warn("bypass: ", raddr)
return nil
}
if _, err := tun.WriteTo(b[:n], raddr); err != nil {
return err
}
h.logger.Debugf("%s >>> %s data: %d",
tun.LocalAddr(), raddr, n)
return nil
}()
if err != nil {
errc <- err
return
}
if clientAddr == nil {
clientAddr = laddr
}
var addr gosocks5.Addr
header := gosocks5.UDPHeader{
Addr: &addr,
}
hlen, err := header.ReadFrom(bytes.NewReader(b[:n]))
if err != nil {
errc <- err
return
}
raddr, err := net.ResolveUDPAddr("udp", addr.String())
if err != nil {
continue // drop silently
}
if h.bypass != nil && h.bypass.Contains(raddr.String()) {
h.logger.Warn("bypass: ", raddr)
continue // bypass
}
dgram := gosocks5.UDPDatagram{
Header: &header,
Data: b[hlen:n],
}
dgram.Header.Rsv = uint16(len(dgram.Data))
if _, err := dgram.WriteTo(tunnel); err != nil {
errc <- err
return
}
if h.logger.IsLevelEnabled(logger.DebugLevel) {
h.logger.Debugf("%s >>> %s: %v data: %d",
clientAddr, raddr, b[:hlen], len(dgram.Data))
}
}
}()
go func() {
b := bufpool.Get(bufSize)
defer bufpool.Put(b)
const dataPos = 262
for {
addr := gosocks5.Addr{}
header := gosocks5.UDPHeader{
Addr: &addr,
}
err := func() error {
b := bufpool.Get(bufSize)
defer bufpool.Put(b)
n, raddr, err := tun.ReadFrom(b)
if err != nil {
return err
}
if h.bypass != nil && h.bypass.Contains(raddr.String()) {
h.logger.Warn("bypass: ", raddr)
return nil
}
if _, err := c.WriteTo(b[:n], raddr); err != nil {
return err
}
h.logger.Debugf("%s <<< %s data: %d",
tun.LocalAddr(), raddr, n)
return nil
}()
data := b[dataPos:]
dgram := gosocks5.UDPDatagram{
Header: &header,
Data: data,
}
_, err := dgram.ReadFrom(tunnel)
if err != nil {
errc <- err
return
}
// NOTE: the dgram.Data may be reallocated if the provided buffer is too short,
// we drop it for simplicity. As this occurs, you should enlarge the buffer size.
if len(dgram.Data) > len(data) {
h.logger.Warnf("buffer too short, dropped")
continue
}
// pipe from tunnel to relay
if clientAddr == nil {
h.logger.Warnf("ignore unexpected peer from %s", addr)
continue
}
raddr := addr.String()
if h.bypass != nil && h.bypass.Contains(raddr) {
h.logger.Warn("bypass: ", raddr)
continue // bypass
}
addrLen := addr.Length()
addr.Encode(b[dataPos-addrLen : dataPos])
hlen := addrLen + 3
if _, err := c.WriteTo(b[dataPos-hlen:dataPos+len(dgram.Data)], clientAddr); err != nil {
errc <- err
return
}
if h.logger.IsLevelEnabled(logger.DebugLevel) {
h.logger.Debugf("%s <<< %s: %v data: %d",
clientAddr, addr.String(), b[dataPos-hlen:dataPos], len(dgram.Data))
}
}
}()
@ -251,91 +198,69 @@ func (h *socks5Handler) relayUDP(c, peer net.PacketConn) (err error) {
bufSize := h.md.udpBufferSize
errc := make(chan error, 2)
var clientAddr net.Addr
go func() {
b := bufpool.Get(bufSize)
defer bufpool.Put(b)
for {
n, laddr, err := c.ReadFrom(b)
err := func() error {
b := bufpool.Get(bufSize)
defer bufpool.Put(b)
n, raddr, err := c.ReadFrom(b)
if err != nil {
return err
}
if h.bypass != nil && h.bypass.Contains(raddr.String()) {
h.logger.Warn("bypass: ", raddr)
return nil
}
if _, err := peer.WriteTo(b[:n], raddr); err != nil {
return err
}
h.logger.Debugf("%s >>> %s data: %d",
peer.LocalAddr(), raddr, n)
return nil
}()
if err != nil {
errc <- err
return
}
if clientAddr == nil {
clientAddr = laddr
}
var addr gosocks5.Addr
header := gosocks5.UDPHeader{
Addr: &addr,
}
hlen, err := header.ReadFrom(bytes.NewReader(b[:n]))
if err != nil {
errc <- err
return
}
raddr, err := net.ResolveUDPAddr("udp", addr.String())
if err != nil {
continue // drop silently
}
if h.bypass != nil && h.bypass.Contains(raddr.String()) {
h.logger.Warn("bypass: ", raddr)
continue // bypass
}
data := b[hlen:n]
if _, err := peer.WriteTo(data, raddr); err != nil {
errc <- err
return
}
if h.logger.IsLevelEnabled(logger.DebugLevel) {
h.logger.Debugf("%s >>> %s: %v data: %d",
clientAddr, raddr, b[:hlen], len(data))
}
}
}()
go func() {
b := bufpool.Get(bufSize)
defer bufpool.Put(b)
const dataPos = 262
for {
n, raddr, err := peer.ReadFrom(b[dataPos:])
err := func() error {
b := bufpool.Get(bufSize)
defer bufpool.Put(b)
n, raddr, err := peer.ReadFrom(b)
if err != nil {
return err
}
if h.bypass != nil && h.bypass.Contains(raddr.String()) {
h.logger.Warn("bypass: ", raddr)
return nil
}
if _, err := c.WriteTo(b[:n], raddr); err != nil {
return err
}
h.logger.Debugf("%s <<< %s data: %d",
peer.LocalAddr(), raddr, n)
return nil
}()
if err != nil {
errc <- err
return
}
if clientAddr == nil {
continue
}
if h.bypass != nil && h.bypass.Contains(raddr.String()) {
h.logger.Warn("bypass: ", raddr)
continue // bypass
}
socksAddr, _ := gosocks5.NewAddr(raddr.String())
if socksAddr == nil {
socksAddr = &gosocks5.Addr{}
}
addrLen := socksAddr.Length()
socksAddr.Encode(b[dataPos-addrLen : dataPos])
hlen := addrLen + 3
if _, err := c.WriteTo(b[dataPos-hlen:dataPos+n], clientAddr); err != nil {
errc <- err
return
}
if h.logger.IsLevelEnabled(logger.DebugLevel) {
h.logger.Debugf("%s <<< %s: %v data: %d",
clientAddr, raddr, b[dataPos-hlen:dataPos], n)
}
}
}()

View File

@ -8,7 +8,7 @@ import (
"github.com/go-gost/gosocks5"
"github.com/go-gost/gost/pkg/handler"
"github.com/go-gost/gost/pkg/internal/bufpool"
"github.com/go-gost/gost/pkg/logger"
"github.com/go-gost/gost/pkg/internal/utils/socks"
)
func (h *socks5Handler) handleUDPTun(ctx context.Context, conn net.Conn, req *gosocks5.Request) {
@ -35,9 +35,7 @@ func (h *socks5Handler) handleUDPTun(ctx context.Context, conn net.Conn, req *go
h.logger.Error(err)
return
}
if h.logger.IsLevelEnabled(logger.DebugLevel) {
h.logger.Debug(reply)
}
h.logger.Debug(reply)
h.logger = h.logger.WithFields(map[string]interface{}{
"bind": saddr.String(),
@ -45,7 +43,10 @@ func (h *socks5Handler) handleUDPTun(ctx context.Context, conn net.Conn, req *go
t := time.Now()
h.logger.Infof("%s <-> %s", conn.RemoteAddr(), saddr)
h.tunnelServerUDP(conn, relay)
h.tunnelServerUDP(
socks.UDPTunServerConn(conn),
relay,
)
h.logger.
WithFields(map[string]interface{}{
"duration": time.Since(t),
@ -64,9 +65,7 @@ func (h *socks5Handler) handleUDPTun(ctx context.Context, conn net.Conn, req *go
h.logger.Error(err)
reply := gosocks5.NewReply(gosocks5.Failure, nil)
reply.Write(conn)
if h.logger.IsLevelEnabled(logger.DebugLevel) {
h.logger.Debug(reply)
}
h.logger.Debug(reply)
return
}
defer cc.Close()
@ -76,9 +75,7 @@ func (h *socks5Handler) handleUDPTun(ctx context.Context, conn net.Conn, req *go
h.logger.Error(err)
reply := gosocks5.NewReply(gosocks5.Failure, nil)
reply.Write(conn)
if h.logger.IsLevelEnabled(logger.DebugLevel) {
h.logger.Debug(reply)
}
h.logger.Debug(reply)
}
t := time.Now()
@ -91,97 +88,72 @@ func (h *socks5Handler) handleUDPTun(ctx context.Context, conn net.Conn, req *go
Infof("%s >-< %s", conn.RemoteAddr(), cc.RemoteAddr())
}
func (h *socks5Handler) tunnelServerUDP(tunnel net.Conn, c net.PacketConn) (err error) {
func (h *socks5Handler) tunnelServerUDP(tunnel, c net.PacketConn) (err error) {
bufSize := h.md.udpBufferSize
errc := make(chan error, 2)
go func() {
b := bufpool.Get(bufSize)
defer bufpool.Put(b)
const dataPos = 262
for {
addr := gosocks5.Addr{}
header := gosocks5.UDPHeader{
Addr: &addr,
}
err := func() error {
b := bufpool.Get(bufSize)
defer bufpool.Put(b)
n, raddr, err := tunnel.ReadFrom(b)
if err != nil {
return err
}
if h.bypass != nil && h.bypass.Contains(raddr.String()) {
h.logger.Warn("bypass: ", raddr)
return nil
}
if _, err := c.WriteTo(b[:n], raddr); err != nil {
return err
}
h.logger.Debugf("%s >>> %s data: %d",
c.LocalAddr(), raddr, n)
return nil
}()
data := b[dataPos:]
dgram := gosocks5.UDPDatagram{
Header: &header,
Data: data,
}
_, err := dgram.ReadFrom(tunnel)
if err != nil {
errc <- err
return
}
// NOTE: the dgram.Data may be reallocated if the provided buffer is too short,
// we drop it for simplicity. As this occurs, you should enlarge the buffer size.
if len(dgram.Data) > len(data) {
h.logger.Warnf("buffer too short, dropped")
continue
}
raddr, err := net.ResolveUDPAddr("udp", addr.String())
if err != nil {
continue // drop silently
}
if h.bypass != nil && h.bypass.Contains(raddr.String()) {
h.logger.Warn("bypass: ", raddr.String())
continue // bypass
}
if _, err := c.WriteTo(dgram.Data, raddr); err != nil {
errc <- err
return
}
if h.logger.IsLevelEnabled(logger.DebugLevel) {
h.logger.Debugf("%s >>> %s: %v data: %d",
tunnel.RemoteAddr(), raddr, header.String(), len(dgram.Data))
}
}
}()
go func() {
b := bufpool.Get(bufSize)
defer bufpool.Put(b)
for {
n, raddr, err := c.ReadFrom(b)
err := func() error {
b := bufpool.Get(bufSize)
defer bufpool.Put(b)
n, raddr, err := c.ReadFrom(b)
if err != nil {
return err
}
if h.bypass != nil && h.bypass.Contains(raddr.String()) {
h.logger.Warn("bypass: ", raddr)
return nil
}
if _, err := tunnel.WriteTo(b[:n], raddr); err != nil {
return err
}
h.logger.Debugf("%s <<< %s data: %d",
c.LocalAddr(), raddr, n)
return nil
}()
if err != nil {
errc <- err
return
}
if h.bypass != nil && h.bypass.Contains(raddr.String()) {
h.logger.Warn("bypass: ", raddr.String())
continue // bypass
}
addr, _ := gosocks5.NewAddr(raddr.String())
if addr == nil {
addr = &gosocks5.Addr{}
}
header := gosocks5.UDPHeader{
Rsv: uint16(n),
Addr: addr,
}
dgram := gosocks5.UDPDatagram{
Header: &header,
Data: b[:n],
}
if _, err := dgram.WriteTo(tunnel); err != nil {
errc <- err
return
}
if h.logger.IsLevelEnabled(logger.DebugLevel) {
h.logger.Debugf("%s <<< %s: %v data: %d",
tunnel.RemoteAddr(), raddr, header.String(), len(dgram.Data))
}
}
}()

View File

@ -1,16 +1,15 @@
package ssu
import (
"bytes"
"context"
"net"
"time"
"github.com/go-gost/gosocks5"
"github.com/go-gost/gost/pkg/bypass"
"github.com/go-gost/gost/pkg/chain"
"github.com/go-gost/gost/pkg/handler"
"github.com/go-gost/gost/pkg/internal/bufpool"
"github.com/go-gost/gost/pkg/internal/utils/socks"
"github.com/go-gost/gost/pkg/internal/utils/ss"
"github.com/go-gost/gost/pkg/logger"
md "github.com/go-gost/gost/pkg/metadata"
@ -91,7 +90,10 @@ func (h *ssuHandler) Handle(ctx context.Context, conn net.Conn) {
t := time.Now()
h.logger.Infof("%s <-> %s", conn.RemoteAddr(), cc.LocalAddr())
h.relayPacket(pc, cc)
h.relayPacket(
ss.UDPServerConn(pc, conn.RemoteAddr(), h.md.bufferSize),
cc,
)
h.logger.
WithFields(map[string]interface{}{"duration": time.Since(t)}).
Infof("%s >-< %s", conn.RemoteAddr(), cc.LocalAddr())
@ -104,7 +106,7 @@ func (h *ssuHandler) Handle(ctx context.Context, conn net.Conn) {
t := time.Now()
h.logger.Infof("%s <-> %s", conn.RemoteAddr(), cc.LocalAddr())
h.tunnelUDP(conn, cc)
h.tunnelUDP(socks.UDPTunServerConn(conn), cc)
h.logger.
WithFields(map[string]interface{}{"duration": time.Since(t)}).
Infof("%s >-< %s", conn.RemoteAddr(), cc.LocalAddr())
@ -112,47 +114,30 @@ func (h *ssuHandler) Handle(ctx context.Context, conn net.Conn) {
func (h *ssuHandler) relayPacket(pc1, pc2 net.PacketConn) (err error) {
bufSize := h.md.bufferSize
errc := make(chan error, 2)
var clientAddr net.Addr
go func() {
b := bufpool.Get(bufSize)
defer bufpool.Put(b)
for {
err := func() error {
b := bufpool.Get(bufSize)
defer bufpool.Put(b)
n, addr, err := pc1.ReadFrom(b)
if err != nil {
return err
}
if clientAddr == nil {
clientAddr = addr
}
rb := bytes.NewBuffer(b[:n])
saddr := gosocks5.Addr{}
if _, err := saddr.ReadFrom(rb); err != nil {
return err
}
taddr, err := net.ResolveUDPAddr("udp", saddr.String())
if err != nil {
return err
}
if h.bypass != nil && h.bypass.Contains(taddr.String()) {
h.logger.Warn("bypass: ", taddr)
if h.bypass != nil && h.bypass.Contains(addr.String()) {
h.logger.Warn("bypass: ", addr)
return nil
}
if _, err = pc2.WriteTo(rb.Bytes(), taddr); err != nil {
if _, err = pc2.WriteTo(b[:n], addr); err != nil {
return err
}
if h.logger.IsLevelEnabled(logger.DebugLevel) {
h.logger.Debugf("%s >>> %s: %v, data: %d",
addr, taddr, saddr.String(), rb.Len())
}
h.logger.Debugf("%s >>> %s data: %d",
pc2.LocalAddr(), addr, n)
return nil
}()
@ -164,41 +149,27 @@ func (h *ssuHandler) relayPacket(pc1, pc2 net.PacketConn) (err error) {
}()
go func() {
b := bufpool.Get(bufSize)
defer bufpool.Put(b)
const dataPos = 259
for {
err := func() error {
n, raddr, err := pc2.ReadFrom(b[dataPos:])
b := bufpool.Get(bufSize)
defer bufpool.Put(b)
n, raddr, err := pc2.ReadFrom(b)
if err != nil {
return err
}
if clientAddr == nil {
return nil
}
if h.bypass != nil && h.bypass.Contains(raddr.String()) {
h.logger.Warn("bypass: ", raddr)
return nil
}
socksAddr, _ := gosocks5.NewAddr(raddr.String())
if socksAddr == nil {
socksAddr = &gosocks5.Addr{}
}
addrLen := socksAddr.Length()
socksAddr.Encode(b[dataPos-addrLen : dataPos])
if _, err = pc1.WriteTo(b[dataPos-addrLen:dataPos+n], clientAddr); err != nil {
if _, err = pc1.WriteTo(b[:n], raddr); err != nil {
return err
}
if h.logger.IsLevelEnabled(logger.DebugLevel) {
h.logger.Debugf("%s <<< %s: %v data: %d",
clientAddr, raddr, b[dataPos-addrLen:dataPos], n)
}
h.logger.Debugf("%s <<< %s data: %d",
pc2.LocalAddr(), raddr, n)
return nil
}()
@ -212,7 +183,7 @@ func (h *ssuHandler) relayPacket(pc1, pc2 net.PacketConn) (err error) {
return <-errc
}
func (h *ssuHandler) tunnelUDP(tunnel net.Conn, c net.PacketConn) (err error) {
func (h *ssuHandler) tunnelUDP(tunnel, c net.PacketConn) (err error) {
bufSize := h.md.bufferSize
errc := make(chan error, 2)
@ -220,49 +191,32 @@ func (h *ssuHandler) tunnelUDP(tunnel net.Conn, c net.PacketConn) (err error) {
b := bufpool.Get(bufSize)
defer bufpool.Put(b)
const dataPos = 262
for {
addr := gosocks5.Addr{}
header := gosocks5.UDPHeader{
Addr: &addr,
}
err := func() error {
n, addr, err := tunnel.ReadFrom(b)
if err != nil {
return err
}
if h.bypass != nil && h.bypass.Contains(addr.String()) {
h.logger.Warn("bypass: ", addr.String())
return nil // bypass
}
if _, err := c.WriteTo(b[:n], addr); err != nil {
return err
}
h.logger.Debugf("%s >>> %s data: %d",
c.LocalAddr(), addr, n)
return nil
}()
data := b[dataPos:]
dgram := gosocks5.UDPDatagram{
Header: &header,
Data: data,
}
_, err := dgram.ReadFrom(tunnel)
if err != nil {
errc <- err
return
}
// NOTE: the dgram.Data may be reallocated if the provided buffer is too short,
// we drop it for simplicity. As this occurs, you should enlarge the buffer size.
if len(dgram.Data) > len(data) {
h.logger.Warnf("buffer too short, dropped")
continue
}
raddr, err := net.ResolveUDPAddr("udp", addr.String())
if err != nil {
continue // drop silently
}
if h.bypass != nil && h.bypass.Contains(raddr.String()) {
h.logger.Warn("bypass: ", raddr.String())
continue // bypass
}
if _, err := c.WriteTo(dgram.Data, raddr); err != nil {
errc <- err
return
}
if h.logger.IsLevelEnabled(logger.DebugLevel) {
h.logger.Debugf("%s >>> %s: %v data: %d",
tunnel.RemoteAddr(), raddr, header.String(), len(dgram.Data))
}
}
}()
@ -271,38 +225,31 @@ func (h *ssuHandler) tunnelUDP(tunnel net.Conn, c net.PacketConn) (err error) {
defer bufpool.Put(b)
for {
n, raddr, err := c.ReadFrom(b)
err := func() error {
n, raddr, err := c.ReadFrom(b)
if err != nil {
return err
}
if h.bypass != nil && h.bypass.Contains(raddr.String()) {
h.logger.Warn("bypass: ", raddr.String())
return nil // bypass
}
if _, err := tunnel.WriteTo(b[:n], raddr); err != nil {
return err
}
h.logger.Debugf("%s <<< %s data: %d",
c.LocalAddr(), raddr, n)
return nil
}()
if err != nil {
errc <- err
return
}
if h.bypass != nil && h.bypass.Contains(raddr.String()) {
h.logger.Warn("bypass: ", raddr.String())
continue // bypass
}
addr, _ := gosocks5.NewAddr(raddr.String())
if addr == nil {
addr = &gosocks5.Addr{}
}
header := gosocks5.UDPHeader{
Rsv: uint16(n),
Addr: addr,
}
dgram := gosocks5.UDPDatagram{
Header: &header,
Data: b[:n],
}
if _, err := dgram.WriteTo(tunnel); err != nil {
errc <- err
return
}
if h.logger.IsLevelEnabled(logger.DebugLevel) {
h.logger.Debugf("%s <<< %s: %v data: %d",
tunnel.RemoteAddr(), raddr, header.String(), len(dgram.Data))
}
}
}()

View File

@ -0,0 +1,173 @@
package socks
import (
"bytes"
"net"
"github.com/go-gost/gosocks5"
"github.com/go-gost/gost/pkg/internal/bufpool"
)
var (
_ net.PacketConn = (*UDPTunConn)(nil)
_ net.Conn = (*UDPTunConn)(nil)
_ net.PacketConn = (*UDPConn)(nil)
_ net.Conn = (*UDPConn)(nil)
)
type UDPTunConn struct {
net.Conn
taddr net.Addr
}
func UDPTunClientConn(c net.Conn, targetAddr net.Addr) *UDPTunConn {
return &UDPTunConn{
Conn: c,
taddr: targetAddr,
}
}
func UDPTunServerConn(c net.Conn) *UDPTunConn {
return &UDPTunConn{
Conn: c,
}
}
func (c *UDPTunConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
socksAddr := gosocks5.Addr{}
header := gosocks5.UDPHeader{
Addr: &socksAddr,
}
dgram := gosocks5.UDPDatagram{
Header: &header,
Data: b,
}
_, err = dgram.ReadFrom(c.Conn)
if err != nil {
return
}
n = len(dgram.Data)
if n > len(b) {
n = copy(b, dgram.Data)
}
addr, err = net.ResolveUDPAddr("udp", socksAddr.String())
return
}
func (c *UDPTunConn) Read(b []byte) (n int, err error) {
n, _, err = c.ReadFrom(b)
return
}
func (c *UDPTunConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
socksAddr := gosocks5.Addr{}
if err = socksAddr.ParseFrom(addr.String()); err != nil {
return
}
header := gosocks5.UDPHeader{
Addr: &socksAddr,
}
dgram := gosocks5.UDPDatagram{
Header: &header,
Data: b,
}
dgram.Header.Rsv = uint16(len(dgram.Data))
_, err = dgram.WriteTo(c.Conn)
n = len(b)
return
}
func (c *UDPTunConn) Write(b []byte) (n int, err error) {
return c.WriteTo(b, c.taddr)
}
var (
DefaultBufferSize = 4096
)
type UDPConn struct {
net.PacketConn
raddr net.Addr
taddr net.Addr
bufferSize int
}
func NewUDPConn(c net.PacketConn, bufferSize int) *UDPConn {
return &UDPConn{
PacketConn: c,
bufferSize: bufferSize,
}
}
// ReadFrom reads an UDP datagram.
// NOTE: for server side,
// the returned addr is the target address the client want to relay to.
func (c *UDPConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
rbuf := bufpool.Get(c.bufferSize)
defer bufpool.Put(rbuf)
n, c.raddr, err = c.PacketConn.ReadFrom(rbuf)
if err != nil {
return
}
socksAddr := gosocks5.Addr{}
header := gosocks5.UDPHeader{
Addr: &socksAddr,
}
hlen, err := header.ReadFrom(bytes.NewReader(rbuf[:n]))
if err != nil {
return
}
n = copy(b, rbuf[hlen:n])
addr, err = net.ResolveUDPAddr("udp", socksAddr.String())
return
}
func (c *UDPConn) Read(b []byte) (n int, err error) {
n, _, err = c.ReadFrom(b)
return
}
func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
wbuf := bufpool.Get(c.bufferSize)
defer bufpool.Put(wbuf)
socksAddr := gosocks5.Addr{}
if err = socksAddr.ParseFrom(addr.String()); err != nil {
return
}
header := gosocks5.UDPHeader{
Addr: &socksAddr,
}
dgram := gosocks5.UDPDatagram{
Header: &header,
Data: b,
}
buf := bytes.NewBuffer(wbuf[:0])
_, err = dgram.WriteTo(buf)
if err != nil {
return
}
_, err = c.PacketConn.WriteTo(buf.Bytes(), c.raddr)
n = len(b)
return
}
func (c *UDPConn) Write(b []byte) (n int, err error) {
return c.WriteTo(b, c.taddr)
}
func (c *UDPConn) RemoteAddr() net.Addr {
return c.raddr
}

View File

@ -0,0 +1,96 @@
package ss
import (
"bytes"
"net"
"github.com/go-gost/gosocks5"
"github.com/go-gost/gost/pkg/internal/bufpool"
)
var (
DefaultBufferSize = 4096
)
var (
_ net.PacketConn = (*UDPConn)(nil)
_ net.Conn = (*UDPConn)(nil)
)
type UDPConn struct {
net.PacketConn
raddr net.Addr
taddr net.Addr
bufferSize int
}
func UDPClientConn(c net.PacketConn, remoteAddr, targetAddr net.Addr, bufferSize int) *UDPConn {
return &UDPConn{
PacketConn: c,
raddr: remoteAddr,
taddr: targetAddr,
bufferSize: bufferSize,
}
}
func UDPServerConn(c net.PacketConn, remoteAddr net.Addr, bufferSize int) *UDPConn {
return &UDPConn{
PacketConn: c,
raddr: remoteAddr,
bufferSize: bufferSize,
}
}
func (c *UDPConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
rbuf := bufpool.Get(c.bufferSize)
defer bufpool.Put(rbuf)
n, _, err = c.PacketConn.ReadFrom(rbuf)
if err != nil {
return
}
saddr := gosocks5.Addr{}
addrLen, err := saddr.ReadFrom(bytes.NewReader(rbuf[:n]))
if err != nil {
return
}
n = copy(b, rbuf[addrLen:n])
addr, err = net.ResolveUDPAddr("udp", saddr.String())
return
}
func (c *UDPConn) Read(b []byte) (n int, err error) {
n, _, err = c.ReadFrom(b)
return
}
func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
wbuf := bufpool.Get(c.bufferSize)
defer bufpool.Put(wbuf)
socksAddr := gosocks5.Addr{}
if err = socksAddr.ParseFrom(addr.String()); err != nil {
return
}
addrLen, err := socksAddr.Encode(wbuf)
if err != nil {
return
}
n = copy(wbuf[addrLen:], b)
_, err = c.PacketConn.WriteTo(wbuf[:addrLen+n], c.raddr)
return
}
func (c *UDPConn) Write(b []byte) (n int, err error) {
return c.WriteTo(b, c.taddr)
}
func (c *UDPConn) RemoteAddr() net.Addr {
return c.raddr
}

View File

@ -6,109 +6,184 @@ import (
"sync"
"sync/atomic"
"time"
"github.com/go-gost/gost/pkg/internal/bufpool"
"github.com/go-gost/gost/pkg/logger"
)
// serverConn is a server side connection for UDP client peer, it implements net.Conn and net.PacketConn.
type serverConn struct {
// conn is a server side connection for UDP client peer, it implements net.Conn and net.PacketConn.
type conn struct {
net.PacketConn
raddr net.Addr
remoteAddr net.Addr
rc chan []byte // data receive queue
fresh int32
idle int32
closed chan struct{}
closeMutex sync.Mutex
config *serverConnConfig
}
type serverConnConfig struct {
ttl time.Duration
qsize int
onClose func()
}
func newServerConn(conn net.PacketConn, raddr net.Addr, cfg *serverConnConfig) *serverConn {
if conn == nil || raddr == nil {
return nil
}
if cfg == nil {
cfg = &serverConnConfig{}
}
c := &serverConn{
PacketConn: conn,
raddr: raddr,
rc: make(chan []byte, cfg.qsize),
func newConn(c net.PacketConn, raddr net.Addr, queue int) *conn {
return &conn{
PacketConn: c,
remoteAddr: raddr,
rc: make(chan []byte, queue),
closed: make(chan struct{}),
config: cfg,
}
go c.ttlWait()
return c
}
func (c *serverConn) send(b []byte) error {
func (c *conn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
select {
case c.rc <- b:
return nil
default:
return errors.New("queue is full")
case bb := <-c.rc:
n = copy(b, bb)
c.SetIdle(false)
bufpool.Put(bb)
case <-c.closed:
err = net.ErrClosed
return
}
addr = c.remoteAddr
return
}
func (c *serverConn) Read(b []byte) (n int, err error) {
func (c *conn) Read(b []byte) (n int, err error) {
n, _, err = c.ReadFrom(b)
return
}
func (c *serverConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
select {
case bb := <-c.rc:
n = copy(b, bb)
atomic.StoreInt32(&c.fresh, 1)
case <-c.closed:
err = errors.New("read from closed connection")
return
}
addr = c.raddr
return
func (c *conn) Write(b []byte) (n int, err error) {
return c.WriteTo(b, c.remoteAddr)
}
func (c *serverConn) Write(b []byte) (n int, err error) {
return c.WriteTo(b, c.raddr)
}
func (c *serverConn) Close() error {
func (c *conn) Close() error {
c.closeMutex.Lock()
defer c.closeMutex.Unlock()
select {
case <-c.closed:
return errors.New("connection is closed")
default:
if c.config.onClose != nil {
c.config.onClose()
}
close(c.closed)
}
return nil
}
func (c *serverConn) RemoteAddr() net.Addr {
return c.raddr
func (c *conn) RemoteAddr() net.Addr {
return c.remoteAddr
}
func (c *serverConn) ttlWait() {
ticker := time.NewTicker(c.config.ttl)
func (c *conn) IsIdle() bool {
return atomic.LoadInt32(&c.idle) > 0
}
func (c *conn) SetIdle(idle bool) {
v := int32(0)
if idle {
v = 1
}
atomic.StoreInt32(&c.idle, v)
}
func (c *conn) Queue(b []byte) error {
select {
case c.rc <- b:
return nil
case <-c.closed:
return net.ErrClosed
default:
return errors.New("recv queue is full")
}
}
type connPool struct {
m sync.Map
ttl time.Duration
closed chan struct{}
logger logger.Logger
}
func newConnPool(ttl time.Duration) *connPool {
p := &connPool{
ttl: ttl,
closed: make(chan struct{}),
}
go p.idleCheck()
return p
}
func (p *connPool) WithLogger(logger logger.Logger) *connPool {
p.logger = logger
return p
}
func (p *connPool) Get(key interface{}) (c *conn, ok bool) {
v, ok := p.m.Load(key)
if ok {
c, ok = v.(*conn)
}
return
}
func (p *connPool) Set(key interface{}, c *conn) {
p.m.Store(key, c)
}
func (p *connPool) Delete(key interface{}) {
p.m.Delete(key)
}
func (p *connPool) Close() {
select {
case <-p.closed:
return
default:
}
close(p.closed)
p.m.Range(func(k, v interface{}) bool {
if c, ok := v.(*conn); ok && c != nil {
c.Close()
}
return true
})
}
func (p *connPool) idleCheck() {
ticker := time.NewTicker(p.ttl)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if !atomic.CompareAndSwapInt32(&c.fresh, 1, 0) {
c.Close()
return
size := 0
idles := 0
p.m.Range(func(key, value interface{}) bool {
c, ok := value.(*conn)
if !ok || c == nil {
p.Delete(key)
return true
}
size++
if c.IsIdle() {
idles++
p.Delete(key)
c.Close()
return true
}
c.SetIdle(true)
return true
})
if idles > 0 {
p.logger.Debugf("connection pool: size=%d, idle=%d", size, idles)
}
case <-c.closed:
case <-p.closed:
return
}
}

View File

@ -2,9 +2,8 @@ package udp
import (
"net"
"sync"
"sync/atomic"
"github.com/go-gost/gost/pkg/internal/bufpool"
"github.com/go-gost/gost/pkg/listener"
"github.com/go-gost/gost/pkg/logger"
md "github.com/go-gost/gost/pkg/metadata"
@ -16,13 +15,14 @@ func init() {
}
type udpListener struct {
addr string
md metadata
conn net.PacketConn
connChan chan net.Conn
errChan chan error
connPool connPool
logger logger.Logger
addr string
md metadata
conn net.PacketConn
connChan chan net.Conn
errChan chan error
closeChan chan struct{}
connPool *connPool
logger logger.Logger
}
func NewListener(opts ...listener.Option) listener.Listener {
@ -31,8 +31,10 @@ func NewListener(opts ...listener.Option) listener.Listener {
opt(options)
}
return &udpListener{
addr: options.Addr,
logger: options.Logger,
addr: options.Addr,
errChan: make(chan error, 1),
closeChan: make(chan struct{}),
logger: options.Logger,
}
}
@ -46,15 +48,13 @@ func (l *udpListener) Init(md md.Metadata) (err error) {
return
}
var conn net.PacketConn
conn, err = net.ListenUDP("udp", laddr)
l.conn, err = net.ListenUDP("udp", laddr)
if err != nil {
return
}
l.conn = conn
l.connChan = make(chan net.Conn, l.md.connQueueSize)
l.errChan = make(chan error, 1)
l.connPool = newConnPool(l.md.ttl).WithLogger(l.logger)
go l.listenLoop()
@ -74,12 +74,14 @@ func (l *udpListener) Accept() (conn net.Conn, err error) {
}
func (l *udpListener) Close() error {
err := l.conn.Close()
l.connPool.Range(func(k interface{}, v *serverConn) bool {
v.Close()
return true
})
return err
select {
case <-l.closeChan:
return nil
default:
close(l.closeChan)
l.connPool.Close()
return l.conn.Close()
}
}
func (l *udpListener) Addr() net.Addr {
@ -88,43 +90,43 @@ func (l *udpListener) Addr() net.Addr {
func (l *udpListener) listenLoop() {
for {
b := make([]byte, l.md.readBufferSize)
b := bufpool.Get(l.md.readBufferSize)
n, raddr, err := l.conn.ReadFrom(b)
if err != nil {
l.logger.Error("accept:", err)
l.errChan <- err
close(l.errChan)
return
}
conn, ok := l.connPool.Get(raddr.String())
if !ok {
conn = newServerConn(l.conn, raddr,
&serverConnConfig{
ttl: l.md.ttl,
qsize: l.md.readQueueSize,
onClose: func() {
l.connPool.Delete(raddr.String())
},
})
select {
case l.connChan <- conn:
l.connPool.Set(raddr.String(), conn)
default:
conn.Close()
l.logger.Error("connection queue is full")
}
c := l.getConn(raddr)
if c == nil {
bufpool.Put(b)
continue
}
if err := conn.send(b[:n]); err != nil {
l.logger.Warn("data discarded:", err)
if err := c.Queue(b[:n]); err != nil {
l.logger.Warn("data discarded: ", err)
}
l.logger.Debug("recv", n)
}
}
func (l *udpListener) getConn(addr net.Addr) *conn {
c, ok := l.connPool.Get(addr.String())
if !ok {
c = newConn(l.conn, addr, l.md.readQueueSize)
select {
case l.connChan <- c:
l.connPool.Set(addr.String(), c)
default:
c.Close()
l.logger.Warnf("connection queue is full, client %s discarded", addr.String())
return nil
}
}
return c
}
func (l *udpListener) parseMetadata(md md.Metadata) (err error) {
l.md.ttl = md.GetDuration(ttl)
if l.md.ttl <= 0 {
@ -147,36 +149,3 @@ func (l *udpListener) parseMetadata(md md.Metadata) (err error) {
return
}
type connPool struct {
size int64
m sync.Map
}
func (p *connPool) Get(key interface{}) (conn *serverConn, ok bool) {
v, ok := p.m.Load(key)
if ok {
conn, ok = v.(*serverConn)
}
return
}
func (p *connPool) Set(key interface{}, conn *serverConn) {
p.m.Store(key, conn)
atomic.AddInt64(&p.size, 1)
}
func (p *connPool) Delete(key interface{}) {
p.m.Delete(key)
atomic.AddInt64(&p.size, -1)
}
func (p *connPool) Range(f func(key interface{}, value *serverConn) bool) {
p.m.Range(func(k, v interface{}) bool {
return f(k, v.(*serverConn))
})
}
func (p *connPool) Size() int64 {
return atomic.LoadInt64(&p.size)
}

View File

@ -4,7 +4,7 @@ import "time"
const (
defaultTTL = 60 * time.Second
defaultReadBufferSize = 1024
defaultReadBufferSize = 4096
defaultReadQueueSize = 128
defaultConnQueueSize = 128
)