From c5df25e84d86cf8c5435374d575be15ff8684097 Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Fri, 19 Nov 2021 15:48:03 +0800 Subject: [PATCH] separate ssu from ss --- cmd/gost/config.go | 227 ---------------------------- cmd/gost/register.go | 2 + pkg/connector/ss/connector.go | 36 ----- pkg/connector/ss/metadata.go | 16 -- pkg/connector/ss/udp/connector.go | 84 +++++++++++ pkg/connector/ss/udp/metadata.go | 49 +++++++ pkg/handler/ss/handler.go | 26 +--- pkg/handler/ss/metadata.go | 16 -- pkg/handler/ss/udp.go | 236 ------------------------------ pkg/handler/ss/udp/handler.go | 173 ++++++++++++++++++++++ pkg/handler/ss/udp/metadata.go | 52 +++++++ 11 files changed, 361 insertions(+), 556 deletions(-) delete mode 100644 cmd/gost/config.go create mode 100644 pkg/connector/ss/udp/connector.go create mode 100644 pkg/connector/ss/udp/metadata.go delete mode 100644 pkg/handler/ss/udp.go create mode 100644 pkg/handler/ss/udp/handler.go create mode 100644 pkg/handler/ss/udp/metadata.go diff --git a/cmd/gost/config.go b/cmd/gost/config.go deleted file mode 100644 index 446db13..0000000 --- a/cmd/gost/config.go +++ /dev/null @@ -1,227 +0,0 @@ -package main - -import ( - "io" - "os" - "strings" - - "github.com/go-gost/gost/pkg/bypass" - "github.com/go-gost/gost/pkg/chain" - "github.com/go-gost/gost/pkg/config" - "github.com/go-gost/gost/pkg/connector" - "github.com/go-gost/gost/pkg/dialer" - "github.com/go-gost/gost/pkg/handler" - "github.com/go-gost/gost/pkg/listener" - "github.com/go-gost/gost/pkg/logger" - "github.com/go-gost/gost/pkg/metadata" - "github.com/go-gost/gost/pkg/registry" - "github.com/go-gost/gost/pkg/service" -) - -var ( - chains = make(map[string]*chain.Chain) - bypasses = make(map[string]bypass.Bypass) -) - -func buildService(cfg *config.Config) (services []*service.Service) { - if cfg == nil || len(cfg.Services) == 0 { - return - } - - for _, bypassCfg := range cfg.Bypasses { - bypasses[bypassCfg.Name] = bypassFromConfig(bypassCfg) - } - - for _, chainCfg := range cfg.Chains { - chains[chainCfg.Name] = chainFromConfig(chainCfg) - } - - for _, svc := range cfg.Services { - serviceLogger := log.WithFields(map[string]interface{}{ - "service": svc.Name, - }) - - listenerLogger := serviceLogger.WithFields(map[string]interface{}{ - "kind": "listener", - "type": svc.Listener.Type, - }) - ln := registry.GetListener(svc.Listener.Type)( - listener.AddrOption(svc.Addr), - listener.LoggerOption(listenerLogger), - ) - - cln, chainable := ln.(listener.Chainable) - if chainable { - cln.Chain(chains[svc.Chain]) - } - - if err := ln.Init(metadata.MapMetadata(svc.Listener.Metadata)); err != nil { - listenerLogger.Fatal("init: ", err) - } - - handlerLogger := serviceLogger.WithFields(map[string]interface{}{ - "kind": "handler", - "type": svc.Handler.Type, - }) - - h := registry.GetHandler(svc.Handler.Type)( - handler.ChainOption(chains[svc.Chain]), - handler.BypassOption(bypasses[svc.Bypass]), - handler.LoggerOption(handlerLogger), - ) - - if forwarder, ok := h.(handler.Forwarder); ok { - chain := chains[svc.Chain] - if chainable { - chain = nil - } - forwarder.Forward(forwarderFromConfig(svc.Forwarder), chain) - } - - if err := h.Init(metadata.MapMetadata(svc.Handler.Metadata)); err != nil { - handlerLogger.Fatal("init: ", err) - } - - s := (&service.Service{}). - WithListener(ln). - WithHandler(h). - WithLogger(serviceLogger) - services = append(services, s) - - serviceLogger.Info("listening on: ", s.Addr()) - } - - return -} - -func chainFromConfig(cfg *config.ChainConfig) *chain.Chain { - if cfg == nil { - return nil - } - - c := &chain.Chain{} - - selector := selectorFromConfig(cfg.Selector) - for _, hop := range cfg.Hops { - group := &chain.NodeGroup{} - for _, v := range hop.Nodes { - - connectorLogger := log.WithFields(map[string]interface{}{ - "kind": "connector", - "type": v.Connector.Type, - "hop": hop.Name, - "node": v.Name, - }) - cr := registry.GetConnector(v.Connector.Type)( - connector.LoggerOption(connectorLogger), - ) - if err := cr.Init(metadata.MapMetadata(v.Connector.Metadata)); err != nil { - connectorLogger.Fatal("init: ", err) - } - - dialerLogger := log.WithFields(map[string]interface{}{ - "kind": "dialer", - "type": v.Dialer.Type, - "hop": hop.Name, - "node": v.Name, - }) - d := registry.GetDialer(v.Dialer.Type)( - dialer.LoggerOption(dialerLogger), - ) - if err := d.Init(metadata.MapMetadata(v.Dialer.Metadata)); err != nil { - dialerLogger.Fatal("init: ", err) - } - - tr := (&chain.Transport{}). - WithConnector(cr). - WithDialer(d) - - node := chain.NewNode(v.Name, v.Addr). - WithTransport(tr). - WithBypass(bypasses[v.Bypass]) - group.AddNode(node) - } - - sel := selector - if s := selectorFromConfig(hop.Selector); s != nil { - sel = s - } - group.WithSelector(sel) - c.AddNodeGroup(group) - } - - return c -} - -func logFromConfig(cfg *config.LogConfig) logger.Logger { - if cfg == nil { - cfg = &config.LogConfig{} - } - opts := []logger.LoggerOption{ - logger.FormatLoggerOption(logger.LogFormat(cfg.Format)), - logger.LevelLoggerOption(logger.LogLevel(cfg.Level)), - } - - var out io.Writer = os.Stderr - switch cfg.Output { - case "stdout", "": - out = os.Stdout - case "stderr": - out = os.Stderr - default: - f, err := os.OpenFile(cfg.Output, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666) - if err != nil { - log.Warnf("log", err) - } else { - out = f - } - } - opts = append(opts, logger.OutputLoggerOption(out)) - - return logger.NewLogger(opts...) -} - -func selectorFromConfig(cfg *config.SelectorConfig) chain.Selector { - if cfg == nil { - return nil - } - - var strategy chain.Strategy - switch cfg.Strategy { - case "round": - strategy = chain.RoundRobinStrategy() - case "random": - strategy = chain.RandomStrategy() - case "fifo": - strategy = chain.FIFOStrategy() - default: - strategy = chain.RoundRobinStrategy() - } - - return chain.NewSelector( - strategy, - chain.InvalidFilter(), - chain.FailFilter(cfg.MaxFails, cfg.FailTimeout), - ) -} - -func bypassFromConfig(cfg *config.BypassConfig) bypass.Bypass { - if cfg == nil { - return nil - } - return bypass.NewBypassPatterns(cfg.Reverse, cfg.Matchers...) -} - -func forwarderFromConfig(cfg *config.ForwarderConfig) *chain.NodeGroup { - if cfg == nil || len(cfg.Targets) == 0 { - return nil - } - - group := &chain.NodeGroup{} - for _, target := range cfg.Targets { - if v := strings.TrimSpace(target); v != "" { - group.AddNode(chain.NewNode(target, target)) - } - } - return group.WithSelector(selectorFromConfig(cfg.Selector)) -} diff --git a/cmd/gost/register.go b/cmd/gost/register.go index 94baa70..4ecde3c 100644 --- a/cmd/gost/register.go +++ b/cmd/gost/register.go @@ -6,6 +6,7 @@ import ( _ "github.com/go-gost/gost/pkg/connector/socks/v4" _ "github.com/go-gost/gost/pkg/connector/socks/v5" _ "github.com/go-gost/gost/pkg/connector/ss" + _ "github.com/go-gost/gost/pkg/connector/ss/udp" // Register dialers _ "github.com/go-gost/gost/pkg/dialer/tcp" @@ -18,6 +19,7 @@ import ( _ "github.com/go-gost/gost/pkg/handler/socks/v4" _ "github.com/go-gost/gost/pkg/handler/socks/v5" _ "github.com/go-gost/gost/pkg/handler/ss" + _ "github.com/go-gost/gost/pkg/handler/ss/udp" // Register listeners _ "github.com/go-gost/gost/pkg/listener/ftcp" diff --git a/pkg/connector/ss/connector.go b/pkg/connector/ss/connector.go index e1407ac..8e04b71 100644 --- a/pkg/connector/ss/connector.go +++ b/pkg/connector/ss/connector.go @@ -2,14 +2,12 @@ package ss import ( "context" - "errors" "fmt" "net" "time" "github.com/go-gost/gosocks5" "github.com/go-gost/gost/pkg/common/bufpool" - "github.com/go-gost/gost/pkg/common/util/socks" "github.com/go-gost/gost/pkg/common/util/ss" "github.com/go-gost/gost/pkg/connector" "github.com/go-gost/gost/pkg/logger" @@ -57,8 +55,6 @@ func (c *ssConnector) Connect(ctx context.Context, conn net.Conn, network, addre c.logger.Error(err) return nil, err } - case "udp", "udp4", "udp6": - return c.connectUDP(ctx, conn, network, address) default: err := fmt.Errorf("network %s is unsupported", network) c.logger.Error(err) @@ -102,35 +98,3 @@ func (c *ssConnector) Connect(ctx context.Context, conn net.Conn, network, addre return sc, nil } - -func (c *ssConnector) connectUDP(ctx context.Context, conn net.Conn, network, address string) (net.Conn, error) { - if !c.md.enableUDP { - err := errors.New("UDP relay is disabled") - c.logger.Error(err) - return nil, err - } - - if c.md.connectTimeout > 0 { - conn.SetDeadline(time.Now().Add(c.md.connectTimeout)) - defer conn.SetDeadline(time.Time{}) - } - - taddr, _ := net.ResolveUDPAddr(network, address) - if taddr == nil { - taddr = &net.UDPAddr{} - } - - pc, ok := conn.(net.PacketConn) - if ok { - if c.md.cipher != nil { - pc = c.md.cipher.PacketConn(pc) - } - - return ss.UDPClientConn(pc, conn.RemoteAddr(), taddr, c.md.udpBufferSize), nil - } - - if c.md.cipher != nil { - conn = ss.ShadowConn(c.md.cipher.StreamConn(conn), nil) - } - return socks.UDPTunClientConn(conn, taddr), nil -} diff --git a/pkg/connector/ss/metadata.go b/pkg/connector/ss/metadata.go index 3825f23..9d583bc 100644 --- a/pkg/connector/ss/metadata.go +++ b/pkg/connector/ss/metadata.go @@ -12,8 +12,6 @@ type metadata struct { cipher core.Cipher connectTimeout time.Duration noDelay bool - enableUDP bool - udpBufferSize int } func (c *ssConnector) parseMetadata(md md.Metadata) (err error) { @@ -23,8 +21,6 @@ func (c *ssConnector) parseMetadata(md md.Metadata) (err error) { key = "key" connectTimeout = "timeout" noDelay = "nodelay" - enableUDP = "udp" // enable UDP relay - udpBufferSize = "udpBufferSize" // udp buffer size ) c.md.cipher, err = ss.ShadowCipher( @@ -38,18 +34,6 @@ func (c *ssConnector) parseMetadata(md md.Metadata) (err error) { c.md.connectTimeout = md.GetDuration(connectTimeout) c.md.noDelay = md.GetBool(noDelay) - c.md.enableUDP = md.GetBool(enableUDP) - - if c.md.udpBufferSize > 0 { - if c.md.udpBufferSize < 512 { - c.md.udpBufferSize = 512 - } - if c.md.udpBufferSize > 65*1024 { - c.md.udpBufferSize = 65 * 1024 - } - } else { - c.md.udpBufferSize = 4096 - } return } diff --git a/pkg/connector/ss/udp/connector.go b/pkg/connector/ss/udp/connector.go new file mode 100644 index 0000000..e909ae9 --- /dev/null +++ b/pkg/connector/ss/udp/connector.go @@ -0,0 +1,84 @@ +package ss + +import ( + "context" + "fmt" + "net" + "time" + + "github.com/go-gost/gost/pkg/common/util/socks" + "github.com/go-gost/gost/pkg/common/util/ss" + "github.com/go-gost/gost/pkg/connector" + "github.com/go-gost/gost/pkg/logger" + md "github.com/go-gost/gost/pkg/metadata" + "github.com/go-gost/gost/pkg/registry" +) + +func init() { + registry.RegiserConnector("ssu", NewConnector) +} + +type ssuConnector struct { + md metadata + logger logger.Logger +} + +func NewConnector(opts ...connector.Option) connector.Connector { + options := &connector.Options{} + for _, opt := range opts { + opt(options) + } + + return &ssuConnector{ + logger: options.Logger, + } +} + +func (c *ssuConnector) Init(md md.Metadata) (err error) { + return c.parseMetadata(md) +} + +func (c *ssuConnector) Connect(ctx context.Context, conn net.Conn, network, address string, opts ...connector.ConnectOption) (net.Conn, error) { + c.logger = c.logger.WithFields(map[string]interface{}{ + "remote": conn.RemoteAddr().String(), + "local": conn.LocalAddr().String(), + "network": network, + "address": address, + }) + c.logger.Infof("connect: %s/%s", address, network) + + switch network { + case "udp", "udp4", "udp6": + default: + err := fmt.Errorf("network %s is unsupported", network) + c.logger.Error(err) + return nil, err + } + + if c.md.connectTimeout > 0 { + conn.SetDeadline(time.Now().Add(c.md.connectTimeout)) + defer conn.SetDeadline(time.Time{}) + } + + taddr, _ := net.ResolveUDPAddr(network, address) + if taddr == nil { + taddr = &net.UDPAddr{} + } + + pc, ok := conn.(net.PacketConn) + if ok { + if c.md.cipher != nil { + pc = c.md.cipher.PacketConn(pc) + } + + // standard UDP relay + return ss.UDPClientConn(pc, conn.RemoteAddr(), taddr, c.md.udpBufferSize), nil + } + + if c.md.cipher != nil { + conn = ss.ShadowConn(c.md.cipher.StreamConn(conn), nil) + } + + // UDP over TCP + return socks.UDPTunClientConn(conn, taddr), nil +} diff --git a/pkg/connector/ss/udp/metadata.go b/pkg/connector/ss/udp/metadata.go new file mode 100644 index 0000000..ae69291 --- /dev/null +++ b/pkg/connector/ss/udp/metadata.go @@ -0,0 +1,49 @@ +package ss + +import ( + "time" + + "github.com/go-gost/gost/pkg/common/util/ss" + md "github.com/go-gost/gost/pkg/metadata" + "github.com/shadowsocks/go-shadowsocks2/core" +) + +type metadata struct { + cipher core.Cipher + connectTimeout time.Duration + udpBufferSize int +} + +func (c *ssuConnector) parseMetadata(md md.Metadata) (err error) { + const ( + method = "method" + password = "password" + key = "key" + connectTimeout = "timeout" + udpBufferSize = "udpBufferSize" // udp buffer size + ) + + c.md.cipher, err = ss.ShadowCipher( + md.GetString(method), + md.GetString(password), + md.GetString(key), + ) + if err != nil { + return + } + + c.md.connectTimeout = md.GetDuration(connectTimeout) + + if c.md.udpBufferSize > 0 { + if c.md.udpBufferSize < 512 { + c.md.udpBufferSize = 512 + } + if c.md.udpBufferSize > 65*1024 { + c.md.udpBufferSize = 65 * 1024 + } + } else { + c.md.udpBufferSize = 4096 + } + + return +} diff --git a/pkg/handler/ss/handler.go b/pkg/handler/ss/handler.go index 8759094..f7e92c6 100644 --- a/pkg/handler/ss/handler.go +++ b/pkg/handler/ss/handler.go @@ -1,7 +1,6 @@ package ss import ( - "bufio" "context" "io" "io/ioutil" @@ -62,12 +61,6 @@ func (h *ssHandler) Handle(ctx context.Context, conn net.Conn) { }).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr()) }() - // standard UDP relay. - if pc, ok := conn.(net.PacketConn); ok { - h.handleUDP(ctx, pc, conn.RemoteAddr()) - return - } - if h.md.cipher != nil { conn = ss.ShadowConn(h.md.cipher.StreamConn(conn), nil) } @@ -76,25 +69,8 @@ func (h *ssHandler) Handle(ctx context.Context, conn net.Conn) { conn.SetReadDeadline(time.Now().Add(h.md.readTimeout)) } - br := bufio.NewReader(conn) - data, err := br.Peek(3) - conn.SetReadDeadline(time.Time{}) - if err != nil { - h.logger.Error(err) - h.discard(conn) - return - } - - conn = handler.NewBufferReaderConn(conn, br) - if data[2] == 0xff { - // UDP-over-TCP relay - h.handleUDPTun(ctx, conn) - return - } - - // standard TCP. addr := &gosocks5.Addr{} - if _, err = addr.ReadFrom(conn); err != nil { + if _, err := addr.ReadFrom(conn); err != nil { h.logger.Error(err) h.discard(conn) return diff --git a/pkg/handler/ss/metadata.go b/pkg/handler/ss/metadata.go index 697db63..f00fd96 100644 --- a/pkg/handler/ss/metadata.go +++ b/pkg/handler/ss/metadata.go @@ -12,8 +12,6 @@ type metadata struct { cipher core.Cipher readTimeout time.Duration retryCount int - bufferSize int - enableUDP bool } func (h *ssHandler) parseMetadata(md md.Metadata) (err error) { @@ -23,8 +21,6 @@ func (h *ssHandler) parseMetadata(md md.Metadata) (err error) { key = "key" readTimeout = "readTimeout" retryCount = "retry" - enableUDP = "udp" - bufferSize = "bufferSize" ) h.md.cipher, err = ss.ShadowCipher( @@ -38,18 +34,6 @@ func (h *ssHandler) parseMetadata(md md.Metadata) (err error) { h.md.readTimeout = md.GetDuration(readTimeout) h.md.retryCount = md.GetInt(retryCount) - h.md.enableUDP = md.GetBool(enableUDP) - h.md.bufferSize = md.GetInt(bufferSize) - if h.md.bufferSize > 0 { - if h.md.bufferSize < 512 { - h.md.bufferSize = 512 // min buffer size - } - if h.md.bufferSize > 65*1024 { - h.md.bufferSize = 65 * 1024 // max buffer size - } - } else { - h.md.bufferSize = 4096 // default buffer size - } return } diff --git a/pkg/handler/ss/udp.go b/pkg/handler/ss/udp.go deleted file mode 100644 index a75d65a..0000000 --- a/pkg/handler/ss/udp.go +++ /dev/null @@ -1,236 +0,0 @@ -package ss - -import ( - "context" - "net" - "time" - - "github.com/go-gost/gost/pkg/chain" - "github.com/go-gost/gost/pkg/common/bufpool" - "github.com/go-gost/gost/pkg/common/util/socks" - "github.com/go-gost/gost/pkg/common/util/ss" -) - -func (h *ssHandler) handleUDP(ctx context.Context, conn net.PacketConn, raddr net.Addr) { - if !h.md.enableUDP { - h.logger.Error("UDP relay is diabled") - return - } - - if h.md.cipher != nil { - conn = h.md.cipher.PacketConn(conn) - } - - // obtain a udp connection - r := (&chain.Router{}). - WithChain(h.chain). - WithRetry(h.md.retryCount). - WithLogger(h.logger) - c, err := r.Dial(ctx, "udp", "") - if err != nil { - h.logger.Error(err) - return - } - - cc, ok := c.(net.PacketConn) - if !ok { - h.logger.Errorf("%s: not a packet connection") - return - } - defer cc.Close() - - h.logger = h.logger.WithFields(map[string]interface{}{ - "bind": cc.LocalAddr().String(), - }) - h.logger.Infof("bind on %s OK", cc.LocalAddr().String()) - t := time.Now() - h.logger.Infof("%s <-> %s", raddr, cc.LocalAddr()) - h.relayPacket( - ss.UDPServerConn(conn, raddr, h.md.bufferSize), - cc, - ) - h.logger. - WithFields(map[string]interface{}{"duration": time.Since(t)}). - Infof("%s >-< %s", raddr, cc.LocalAddr()) -} - -func (h *ssHandler) handleUDPTun(ctx context.Context, conn net.Conn) { - if !h.md.enableUDP { - h.logger.Error("UDP relay is diabled") - return - } - - // obtain a udp connection - r := (&chain.Router{}). - WithChain(h.chain). - WithRetry(h.md.retryCount). - WithLogger(h.logger) - c, err := r.Dial(ctx, "udp", "") - if err != nil { - h.logger.Error(err) - return - } - - cc, ok := c.(net.PacketConn) - if !ok { - h.logger.Errorf("%s: not a packet connection") - return - } - defer cc.Close() - - h.logger = h.logger.WithFields(map[string]interface{}{ - "bind": cc.LocalAddr().String(), - }) - h.logger.Infof("bind on %s OK", cc.LocalAddr().String()) - - t := time.Now() - h.logger.Infof("%s <-> %s", conn.RemoteAddr(), cc.LocalAddr()) - h.tunnelUDP(socks.UDPTunServerConn(conn), cc) - h.logger. - WithFields(map[string]interface{}{"duration": time.Since(t)}). - Infof("%s >-< %s", conn.RemoteAddr(), cc.LocalAddr()) -} - -func (h *ssHandler) relayPacket(pc1, pc2 net.PacketConn) (err error) { - bufSize := h.md.bufferSize - errc := make(chan error, 2) - - go func() { - for { - err := func() error { - b := bufpool.Get(bufSize) - defer bufpool.Put(b) - - n, addr, err := pc1.ReadFrom(b) - if err != nil { - return err - } - - if h.bypass != nil && h.bypass.Contains(addr.String()) { - h.logger.Warn("bypass: ", addr) - return nil - } - - if _, err = pc2.WriteTo(b[:n], addr); err != nil { - return err - } - - h.logger.Debugf("%s >>> %s data: %d", - pc2.LocalAddr(), addr, n) - return nil - }() - - if err != nil { - errc <- err - return - } - } - }() - - go func() { - for { - err := func() error { - b := bufpool.Get(bufSize) - defer bufpool.Put(b) - - n, raddr, err := pc2.ReadFrom(b) - if err != nil { - return err - } - - if h.bypass != nil && h.bypass.Contains(raddr.String()) { - h.logger.Warn("bypass: ", raddr) - return nil - } - - if _, err = pc1.WriteTo(b[:n], raddr); err != nil { - return err - } - - h.logger.Debugf("%s <<< %s data: %d", - pc2.LocalAddr(), raddr, n) - return nil - }() - - if err != nil { - errc <- err - return - } - } - }() - - return <-errc -} - -func (h *ssHandler) tunnelUDP(tunnel, c net.PacketConn) (err error) { - bufSize := h.md.bufferSize - errc := make(chan error, 2) - - go func() { - b := bufpool.Get(bufSize) - defer bufpool.Put(b) - - for { - err := func() error { - n, addr, err := tunnel.ReadFrom(b) - if err != nil { - return err - } - - if h.bypass != nil && h.bypass.Contains(addr.String()) { - h.logger.Warn("bypass: ", addr.String()) - return nil // bypass - } - - if _, err := c.WriteTo(b[:n], addr); err != nil { - return err - } - - h.logger.Debugf("%s >>> %s data: %d", - c.LocalAddr(), addr, n) - - return nil - }() - - if err != nil { - errc <- err - return - } - } - }() - - go func() { - b := bufpool.Get(bufSize) - defer bufpool.Put(b) - - for { - err := func() error { - n, raddr, err := c.ReadFrom(b) - if err != nil { - return err - } - - if h.bypass != nil && h.bypass.Contains(raddr.String()) { - h.logger.Warn("bypass: ", raddr.String()) - return nil // bypass - } - - if _, err := tunnel.WriteTo(b[:n], raddr); err != nil { - return err - } - - h.logger.Debugf("%s <<< %s data: %d", - c.LocalAddr(), raddr, n) - - return nil - }() - - if err != nil { - errc <- err - return - } - } - }() - - return <-errc -} diff --git a/pkg/handler/ss/udp/handler.go b/pkg/handler/ss/udp/handler.go new file mode 100644 index 0000000..0ed846a --- /dev/null +++ b/pkg/handler/ss/udp/handler.go @@ -0,0 +1,173 @@ +package ss + +import ( + "context" + "net" + "time" + + "github.com/go-gost/gost/pkg/bypass" + "github.com/go-gost/gost/pkg/chain" + "github.com/go-gost/gost/pkg/common/bufpool" + "github.com/go-gost/gost/pkg/common/util/socks" + "github.com/go-gost/gost/pkg/common/util/ss" + "github.com/go-gost/gost/pkg/handler" + "github.com/go-gost/gost/pkg/logger" + md "github.com/go-gost/gost/pkg/metadata" + "github.com/go-gost/gost/pkg/registry" +) + +func init() { + registry.RegisterHandler("ssu", NewHandler) +} + +type ssuHandler struct { + chain *chain.Chain + bypass bypass.Bypass + logger logger.Logger + md metadata +} + +func NewHandler(opts ...handler.Option) handler.Handler { + options := &handler.Options{} + for _, opt := range opts { + opt(options) + } + + return &ssuHandler{ + chain: options.Chain, + bypass: options.Bypass, + logger: options.Logger, + } +} + +func (h *ssuHandler) Init(md md.Metadata) (err error) { + return h.parseMetadata(md) +} + +func (h *ssuHandler) Handle(ctx context.Context, conn net.Conn) { + defer conn.Close() + + start := time.Now() + h.logger = h.logger.WithFields(map[string]interface{}{ + "remote": conn.RemoteAddr().String(), + "local": conn.LocalAddr().String(), + }) + + h.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr()) + defer func() { + h.logger.WithFields(map[string]interface{}{ + "duration": time.Since(start), + }).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr()) + }() + + pc, ok := conn.(net.PacketConn) + if ok { + if h.md.cipher != nil { + pc = h.md.cipher.PacketConn(pc) + } + // standard UDP relay. + pc = ss.UDPServerConn(pc, conn.RemoteAddr(), h.md.bufferSize) + } else { + if h.md.cipher != nil { + conn = ss.ShadowConn(h.md.cipher.StreamConn(conn), nil) + } + // UDP over TCP + pc = socks.UDPTunServerConn(conn) + } + + // obtain a udp connection + r := (&chain.Router{}). + WithChain(h.chain). + WithRetry(h.md.retryCount). + WithLogger(h.logger) + c, err := r.Dial(ctx, "udp", "") + if err != nil { + h.logger.Error(err) + return + } + + cc, ok := c.(net.PacketConn) + if !ok { + h.logger.Errorf("%s: not a packet connection") + return + } + defer cc.Close() + + t := time.Now() + h.logger.Infof("%s <-> %s", conn.RemoteAddr(), cc.LocalAddr()) + h.relayPacket(pc, cc) + h.logger. + WithFields(map[string]interface{}{"duration": time.Since(t)}). + Infof("%s >-< %s", conn.RemoteAddr(), cc.LocalAddr()) +} + +func (h *ssuHandler) relayPacket(pc1, pc2 net.PacketConn) (err error) { + bufSize := h.md.bufferSize + errc := make(chan error, 2) + + go func() { + for { + err := func() error { + b := bufpool.Get(bufSize) + defer bufpool.Put(b) + + n, addr, err := pc1.ReadFrom(b) + if err != nil { + return err + } + + if h.bypass != nil && h.bypass.Contains(addr.String()) { + h.logger.Warn("bypass: ", addr) + return nil + } + + if _, err = pc2.WriteTo(b[:n], addr); err != nil { + return err + } + + h.logger.Debugf("%s >>> %s data: %d", + pc2.LocalAddr(), addr, n) + return nil + }() + + if err != nil { + errc <- err + return + } + } + }() + + go func() { + for { + err := func() error { + b := bufpool.Get(bufSize) + defer bufpool.Put(b) + + n, raddr, err := pc2.ReadFrom(b) + if err != nil { + return err + } + + if h.bypass != nil && h.bypass.Contains(raddr.String()) { + h.logger.Warn("bypass: ", raddr) + return nil + } + + if _, err = pc1.WriteTo(b[:n], raddr); err != nil { + return err + } + + h.logger.Debugf("%s <<< %s data: %d", + pc2.LocalAddr(), raddr, n) + return nil + }() + + if err != nil { + errc <- err + return + } + } + }() + + return <-errc +} diff --git a/pkg/handler/ss/udp/metadata.go b/pkg/handler/ss/udp/metadata.go new file mode 100644 index 0000000..f1842f6 --- /dev/null +++ b/pkg/handler/ss/udp/metadata.go @@ -0,0 +1,52 @@ +package ss + +import ( + "time" + + "github.com/go-gost/gost/pkg/common/util/ss" + md "github.com/go-gost/gost/pkg/metadata" + "github.com/shadowsocks/go-shadowsocks2/core" +) + +type metadata struct { + cipher core.Cipher + readTimeout time.Duration + retryCount int + bufferSize int +} + +func (h *ssuHandler) parseMetadata(md md.Metadata) (err error) { + const ( + method = "method" + password = "password" + key = "key" + readTimeout = "readTimeout" + retryCount = "retry" + bufferSize = "bufferSize" + ) + + h.md.cipher, err = ss.ShadowCipher( + md.GetString(method), + md.GetString(password), + md.GetString(key), + ) + if err != nil { + return + } + + h.md.readTimeout = md.GetDuration(readTimeout) + h.md.retryCount = md.GetInt(retryCount) + + h.md.bufferSize = md.GetInt(bufferSize) + if h.md.bufferSize > 0 { + if h.md.bufferSize < 512 { + h.md.bufferSize = 512 // min buffer size + } + if h.md.bufferSize > 65*1024 { + h.md.bufferSize = 65 * 1024 // max buffer size + } + } else { + h.md.bufferSize = 4096 // default buffer size + } + return +}