diff --git a/cmd/gost/config.go b/cmd/gost/config.go index 36f2e57..64e32ec 100644 --- a/cmd/gost/config.go +++ b/cmd/gost/config.go @@ -42,11 +42,11 @@ func buildService(cfg *config.Config) (services []*service.Service) { "service": svc.Name, "listener": svc.Listener.Type, "handler": svc.Handler.Type, + "chain": svc.Chain, }) listenerLogger := serviceLogger.WithFields(map[string]interface{}{ "kind": "listener", - "type": svc.Listener.Type, }) ln := registry.GetListener(svc.Listener.Type)( listener.AddrOption(svc.Addr), @@ -63,7 +63,6 @@ func buildService(cfg *config.Config) (services []*service.Service) { handlerLogger := serviceLogger.WithFields(map[string]interface{}{ "kind": "handler", - "type": svc.Handler.Type, }) h := registry.GetHandler(svc.Handler.Type)( @@ -89,7 +88,7 @@ func buildService(cfg *config.Config) (services []*service.Service) { WithLogger(serviceLogger) services = append(services, s) - serviceLogger.Infof("listening on: %s/%s", s.Addr().String(), s.Addr().Network()) + serviceLogger.Infof("listening on %s/%s", s.Addr().String(), s.Addr().Network()) } return @@ -100,18 +99,22 @@ func chainFromConfig(cfg *config.ChainConfig) *chain.Chain { return nil } - c := &chain.Chain{} + chainLogger := log.WithFields(map[string]interface{}{ + "kind": "chain", + "chain": cfg.Name, + }) + c := &chain.Chain{} selector := selectorFromConfig(cfg.Selector) for _, hop := range cfg.Hops { group := &chain.NodeGroup{} for _, v := range hop.Nodes { - - connectorLogger := log.WithFields(map[string]interface{}{ - "kind": "connector", - "type": v.Connector.Type, - "hop": hop.Name, - "node": v.Name, + connectorLogger := chainLogger.WithFields(map[string]interface{}{ + "kind": "connector", + "connector": v.Connector.Type, + "dialer": v.Dialer.Type, + "hop": hop.Name, + "node": v.Name, }) cr := registry.GetConnector(v.Connector.Type)( connector.LoggerOption(connectorLogger), @@ -120,11 +123,12 @@ func chainFromConfig(cfg *config.ChainConfig) *chain.Chain { connectorLogger.Fatal("init: ", err) } - dialerLogger := log.WithFields(map[string]interface{}{ - "kind": "dialer", - "type": v.Dialer.Type, - "hop": hop.Name, - "node": v.Name, + dialerLogger := chainLogger.WithFields(map[string]interface{}{ + "kind": "dialer", + "connector": v.Connector.Type, + "dialer": v.Dialer.Type, + "hop": hop.Name, + "node": v.Name, }) d := registry.GetDialer(v.Dialer.Type)( dialer.LoggerOption(dialerLogger), diff --git a/cmd/gost/norm.go b/cmd/gost/norm.go index 49d1d72..1783508 100644 --- a/cmd/gost/norm.go +++ b/cmd/gost/norm.go @@ -109,7 +109,7 @@ func normChain(chain *config.ChainConfig) { } } if u.User != nil { - md["user"] = []interface{}{u.User.String()} + md["user"] = u.User.String() } node.Addr = u.Host diff --git a/cmd/gost/register.go b/cmd/gost/register.go index c642a7b..82f2022 100644 --- a/cmd/gost/register.go +++ b/cmd/gost/register.go @@ -4,6 +4,7 @@ import ( // Register connectors _ "github.com/go-gost/gost/pkg/connector/forward" _ "github.com/go-gost/gost/pkg/connector/http" + _ "github.com/go-gost/gost/pkg/connector/relay" _ "github.com/go-gost/gost/pkg/connector/socks/v4" _ "github.com/go-gost/gost/pkg/connector/socks/v5" _ "github.com/go-gost/gost/pkg/connector/ss" diff --git a/pkg/chain/route.go b/pkg/chain/route.go index 6460525..3f09685 100644 --- a/pkg/chain/route.go +++ b/pkg/chain/route.go @@ -23,7 +23,7 @@ func (r *Route) AddNode(node *Node) { r.nodes = append(r.nodes, node) } -func (r *Route) Connect(ctx context.Context) (conn net.Conn, err error) { +func (r *Route) connect(ctx context.Context) (conn net.Conn, err error) { if r.IsEmpty() { return nil, ErrEmptyRoute } @@ -72,7 +72,7 @@ func (r *Route) Dial(ctx context.Context, network, address string) (net.Conn, er return r.dialDirect(ctx, network, address) } - conn, err := r.Connect(ctx) + conn, err := r.connect(ctx) if err != nil { return nil, err } @@ -103,7 +103,7 @@ func (r *Route) Bind(ctx context.Context, network, address string, opts ...conne return r.bindLocal(ctx, network, address, opts...) } - conn, err := r.Connect(ctx) + conn, err := r.connect(ctx) if err != nil { return nil, err } diff --git a/pkg/chain/router.go b/pkg/chain/router.go index 3401df2..f6158d3 100644 --- a/pkg/chain/router.go +++ b/pkg/chain/router.go @@ -36,7 +36,7 @@ func (r *Router) Dial(ctx context.Context, network, address string) (conn net.Co if count <= 0 { count = 1 } - r.logger.Debugf("dial: %s/%s", address, network) + r.logger.Debugf("dial %s/%s", address, network) for i := 0; i < count; i++ { route := r.chain.GetRouteFor(network, address) @@ -47,41 +47,14 @@ func (r *Router) Dial(ctx context.Context, network, address string) (conn net.Co fmt.Fprintf(&buf, "%s@%s > ", node.Name(), node.Addr()) } fmt.Fprintf(&buf, "%s", address) - r.logger.Debugf("route(retry=%d): %s", i, buf.String()) + r.logger.Debugf("route(retry=%d) %s", i, buf.String()) } conn, err = route.Dial(ctx, network, address) if err == nil { break } - r.logger.Errorf("route(retry=%d): %s", i, err) - } - - return -} - -func (r *Router) Connect(ctx context.Context) (conn net.Conn, err error) { - count := r.retries + 1 - if count <= 0 { - count = 1 - } - - for i := 0; i < count; i++ { - route := r.chain.GetRoute() - - if r.logger.IsLevelEnabled(logger.DebugLevel) { - buf := bytes.Buffer{} - for _, node := range route.Path() { - fmt.Fprintf(&buf, "%s@%s > ", node.Name(), node.Addr()) - } - r.logger.Debugf("route(retry=%d): %s", i, buf.String()) - } - - conn, err = route.Connect(ctx) - if err == nil { - break - } - r.logger.Errorf("route(retry=%d): %s", i, err) + r.logger.Errorf("route(retry=%d) %s", i, err) } return @@ -92,7 +65,7 @@ func (r *Router) Bind(ctx context.Context, network, address string, opts ...conn if count <= 0 { count = 1 } - r.logger.Debugf("bind: %s/%s", address, network) + r.logger.Debugf("bind on %s/%s", address, network) for i := 0; i < count; i++ { route := r.chain.GetRouteFor(network, address) @@ -103,14 +76,14 @@ func (r *Router) Bind(ctx context.Context, network, address string, opts ...conn fmt.Fprintf(&buf, "%s@%s > ", node.Name(), node.Addr()) } fmt.Fprintf(&buf, "%s", address) - r.logger.Debugf("route(retry=%d): %s", i, buf.String()) + r.logger.Debugf("route(retry=%d) %s", i, buf.String()) } ln, err = route.Bind(ctx, network, address, opts...) if err == nil { break } - r.logger.Errorf("route(retry=%d): %s", i, err) + r.logger.Errorf("route(retry=%d) %s", i, err) } return diff --git a/pkg/chain/transport.go b/pkg/chain/transport.go index d74c0af..c71e07e 100644 --- a/pkg/chain/transport.go +++ b/pkg/chain/transport.go @@ -49,8 +49,12 @@ func (tr *Transport) dialOptions() []dialer.DialOption { } func (tr *Transport) Handshake(ctx context.Context, conn net.Conn) (net.Conn, error) { + var err error if hs, ok := tr.dialer.(dialer.Handshaker); ok { - return hs.Handshake(ctx, conn) + conn, err = hs.Handshake(ctx, conn) + if err != nil { + return nil, err + } } if hs, ok := tr.connector.(connector.Handshaker); ok { return hs.Handshake(ctx, conn) diff --git a/pkg/common/util/relay/conn.go b/pkg/common/util/relay/conn.go deleted file mode 100644 index 51477b4..0000000 --- a/pkg/common/util/relay/conn.go +++ /dev/null @@ -1,58 +0,0 @@ -package relay - -import ( - "encoding/binary" - "errors" - "io" - "math" - "net" -) - -type packetConn struct { - net.Conn -} - -func UDPTunConn(conn net.Conn) net.Conn { - return &packetConn{ - Conn: conn, - } -} - -func (c *packetConn) Read(b []byte) (n int, err error) { - var bb [2]byte - _, err = io.ReadFull(c.Conn, bb[:]) - if err != nil { - return - } - - dlen := int(binary.BigEndian.Uint16(bb[:])) - if len(b) >= dlen { - return io.ReadFull(c.Conn, b[:dlen]) - } - buf := make([]byte, dlen) - _, err = io.ReadFull(c.Conn, buf) - n = copy(b, buf) - - return -} - -func (c *packetConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { - n, err = c.Read(b) - addr = c.Conn.RemoteAddr() - return -} - -func (c *packetConn) Write(b []byte) (n int, err error) { - if len(b) > math.MaxUint16 { - err = errors.New("write: data maximum exceeded") - return - } - - var bb [2]byte - binary.BigEndian.PutUint16(bb[:2], uint16(len(b))) - _, err = c.Conn.Write(bb[:]) - if err != nil { - return - } - return c.Conn.Write(b) -} diff --git a/pkg/common/util/ss/ss.go b/pkg/common/util/ss/ss.go index 1a26bb3..ca9b81e 100644 --- a/pkg/common/util/ss/ss.go +++ b/pkg/common/util/ss/ss.go @@ -51,8 +51,7 @@ func (c *shadowConn) Write(b []byte) (n int, err error) { n = len(b) // force byte length consistent if c.wbuf.Len() > 0 { c.wbuf.Write(b) // append the data to the cached header - _, err = c.Conn.Write(c.wbuf.Bytes()) - c.wbuf.Reset() + _, err = c.wbuf.WriteTo(c.Conn) return } _, err = c.Conn.Write(b) diff --git a/pkg/connector/forward/connector.go b/pkg/connector/forward/connector.go index 6b2f3b5..730ee7c 100644 --- a/pkg/connector/forward/connector.go +++ b/pkg/connector/forward/connector.go @@ -40,7 +40,7 @@ func (c *forwardConnector) Connect(ctx context.Context, conn net.Conn, network, "network": network, "address": address, }) - c.logger.Infof("connect: %s/%s", address, network) + c.logger.Infof("connect %s/%s", address, network) return conn, nil } diff --git a/pkg/connector/http/connector.go b/pkg/connector/http/connector.go index 38e9d8a..14e645b 100644 --- a/pkg/connector/http/connector.go +++ b/pkg/connector/http/connector.go @@ -48,7 +48,7 @@ func (c *httpConnector) Connect(ctx context.Context, conn net.Conn, network, add "network": network, "address": address, }) - c.logger.Infof("connect: %s/%s", address, network) + c.logger.Infof("connect %s/%s", address, network) switch network { case "tcp", "tcp4", "tcp6": diff --git a/pkg/connector/relay/bind.go b/pkg/connector/relay/bind.go new file mode 100644 index 0000000..924697a --- /dev/null +++ b/pkg/connector/relay/bind.go @@ -0,0 +1,131 @@ +package relay + +import ( + "context" + "fmt" + "net" + "strconv" + + "github.com/go-gost/gost/pkg/common/util/mux" + "github.com/go-gost/gost/pkg/common/util/socks" + "github.com/go-gost/gost/pkg/common/util/udp" + "github.com/go-gost/gost/pkg/connector" + "github.com/go-gost/relay" +) + +// Bind implements connector.Binder. +func (c *relayConnector) Bind(ctx context.Context, conn net.Conn, network, address string, opts ...connector.BindOption) (net.Listener, error) { + c.logger = c.logger.WithFields(map[string]interface{}{ + "network": network, + "address": address, + }) + c.logger.Infof("bind on %s/%s", address, network) + + options := connector.BindOptions{} + for _, opt := range opts { + opt(&options) + } + + switch network { + case "tcp", "tcp4", "tcp6": + return c.bindTCP(ctx, conn, network, address) + case "udp", "udp4", "udp6": + return c.bindUDP(ctx, conn, network, address, &options) + default: + err := fmt.Errorf("network %s is unsupported", network) + c.logger.Error(err) + return nil, err + } +} + +func (c *relayConnector) bindTCP(ctx context.Context, conn net.Conn, network, address string) (net.Listener, error) { + laddr, err := c.bind(conn, relay.BIND, network, address) + if err != nil { + return nil, err + } + + session, err := mux.ServerSession(conn) + if err != nil { + return nil, err + } + + return &tcpListener{ + addr: laddr, + session: session, + logger: c.logger, + }, nil +} + +func (c *relayConnector) bindUDP(ctx context.Context, conn net.Conn, network, address string, opts *connector.BindOptions) (net.Listener, error) { + laddr, err := c.bind(conn, relay.FUDP|relay.BIND, network, address) + if err != nil { + return nil, err + } + + ln := udp.NewListener( + socks.UDPTunClientPacketConn(conn), + laddr, + opts.Backlog, + opts.UDPDataQueueSize, opts.UDPDataBufferSize, + opts.UDPConnTTL, + c.logger) + + return ln, nil +} + +func (c *relayConnector) bind(conn net.Conn, cmd uint8, network, address string) (net.Addr, error) { + req := relay.Request{ + Version: relay.Version1, + Flags: cmd, + } + + if c.md.user != nil { + pwd, _ := c.md.user.Password() + req.Features = append(req.Features, &relay.UserAuthFeature{ + Username: c.md.user.Username(), + Password: pwd, + }) + } + fa := &relay.AddrFeature{} + fa.ParseFrom(address) + req.Features = append(req.Features, fa) + if _, err := req.WriteTo(conn); err != nil { + return nil, err + } + + // first reply, bind status + resp := relay.Response{} + if _, err := resp.ReadFrom(conn); err != nil { + return nil, err + } + + if resp.Status != relay.StatusOK { + return nil, fmt.Errorf("bind on %s/%s failed", address, network) + } + + var addr string + for _, f := range resp.Features { + if f.Type() == relay.FeatureAddr { + if fa, ok := f.(*relay.AddrFeature); ok { + addr = net.JoinHostPort(fa.Host, strconv.Itoa(int(fa.Port))) + } + } + } + + var baddr net.Addr + var err error + switch network { + case "tcp", "tcp4", "tcp6": + baddr, err = net.ResolveTCPAddr(network, addr) + case "udp", "udp4", "udp6": + baddr, err = net.ResolveUDPAddr(network, addr) + default: + err = fmt.Errorf("unknown network %s", network) + } + if err != nil { + return nil, err + } + c.logger.Debugf("bind on %s/%s OK", baddr, baddr.Network()) + + return baddr, nil +} diff --git a/pkg/connector/relay/conn.go b/pkg/connector/relay/conn.go index 6388171..c3354e5 100644 --- a/pkg/connector/relay/conn.go +++ b/pkg/connector/relay/conn.go @@ -6,45 +6,53 @@ import ( "errors" "fmt" "io" + "math" "net" "sync" - "github.com/go-gost/gost/pkg/logger" "github.com/go-gost/relay" ) -type conn struct { +type tcpConn struct { net.Conn - udp bool - wbuf bytes.Buffer - once sync.Once - headerSent bool - logger logger.Logger + wbuf bytes.Buffer + once sync.Once } -func (c *conn) Read(b []byte) (n int, err error) { +func (c *tcpConn) Read(b []byte) (n int, err error) { c.once.Do(func() { - resp := relay.Response{} - _, err = resp.ReadFrom(c.Conn) - if err != nil { - return - } - if resp.Version != relay.Version1 { - err = relay.ErrBadVersion - return - } - if resp.Status != relay.StatusOK { - err = fmt.Errorf("status %d", resp.Status) - return - } + err = readResponse(c.Conn) }) if err != nil { return } + return c.Conn.Read(b) +} - if !c.udp { - return c.Conn.Read(b) +func (c *tcpConn) Write(b []byte) (n int, err error) { + n = len(b) // force byte length consistent + if c.wbuf.Len() > 0 { + c.wbuf.Write(b) // append the data to the cached header + _, err = c.wbuf.WriteTo(c.Conn) + return + } + _, err = c.Conn.Write(b) + return +} + +type udpConn struct { + net.Conn + wbuf bytes.Buffer + once sync.Once +} + +func (c *udpConn) Read(b []byte) (n int, err error) { + c.once.Do(func() { + err = readResponse(c.Conn) + }) + if err != nil { + return } var bb [2]byte @@ -52,6 +60,7 @@ func (c *conn) Read(b []byte) (n int, err error) { if err != nil { return } + dlen := int(binary.BigEndian.Uint16(bb[:])) if len(b) >= dlen { return io.ReadFull(c.Conn, b[:dlen]) @@ -59,59 +68,64 @@ func (c *conn) Read(b []byte) (n int, err error) { buf := make([]byte, dlen) _, err = io.ReadFull(c.Conn, buf) n = copy(b, buf) + return } -func (c *conn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { - n, err = c.Read(b) - addr = c.Conn.RemoteAddr() - return -} - -func (c *conn) Write(b []byte) (n int, err error) { - if len(b) > 0xFFFF { +func (c *udpConn) Write(b []byte) (n int, err error) { + if len(b) > math.MaxUint16 { err = errors.New("write: data maximum exceeded") return } - n = len(b) // force byte length consistent + + n = len(b) if c.wbuf.Len() > 0 { - if c.udp { - var bb [2]byte - binary.BigEndian.PutUint16(bb[:2], uint16(len(b))) - c.wbuf.Write(bb[:]) - c.headerSent = true - } + var bb [2]byte + binary.BigEndian.PutUint16(bb[:], uint16(len(b))) + c.wbuf.Write(bb[:]) c.wbuf.Write(b) // append the data to the cached header - // _, err = c.Conn.Write(c.wbuf.Bytes()) - // c.wbuf.Reset() _, err = c.wbuf.WriteTo(c.Conn) return } - if !c.udp { - return c.Conn.Write(b) - } - if !c.headerSent { - c.headerSent = true - b2 := make([]byte, len(b)+2) - copy(b2, b) - _, err = c.Conn.Write(b2) + var bb [2]byte + binary.BigEndian.PutUint16(bb[:], uint16(len(b))) + _, err = c.Conn.Write(bb[:]) + if err != nil { return } - nsize := 2 + len(b) - var buf []byte - if nsize <= mediumBufferSize { - buf = mPool.Get().([]byte) - defer mPool.Put(buf) - } else { - buf = make([]byte, nsize) - } - binary.BigEndian.PutUint16(buf[:2], uint16(len(b))) - n = copy(buf[2:], b) - _, err = c.Conn.Write(buf[:nsize]) - return + return c.Conn.Write(b) } -func (c *relayConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { - return c.Write(b) +func readResponse(r io.Reader) (err error) { + resp := relay.Response{} + _, err = resp.ReadFrom(r) + if err != nil { + return + } + + if resp.Version != relay.Version1 { + err = relay.ErrBadVersion + return + } + + if resp.Status != relay.StatusOK { + err = fmt.Errorf("status %d", resp.Status) + return + } + return nil +} + +type bindConn struct { + net.Conn + localAddr net.Addr + remoteAddr net.Addr +} + +func (c *bindConn) LocalAddr() net.Addr { + return c.localAddr +} + +func (c *bindConn) RemoteAddr() net.Addr { + return c.remoteAddr } diff --git a/pkg/connector/relay/connector.go b/pkg/connector/relay/connector.go index 7f9788d..3cb8fee 100644 --- a/pkg/connector/relay/connector.go +++ b/pkg/connector/relay/connector.go @@ -2,9 +2,11 @@ package relay import ( "context" + "fmt" "net" "time" + "github.com/go-gost/gost/pkg/common/util/socks" "github.com/go-gost/gost/pkg/connector" "github.com/go-gost/gost/pkg/logger" md "github.com/go-gost/gost/pkg/metadata" @@ -43,23 +45,30 @@ func (c *relayConnector) Connect(ctx context.Context, conn net.Conn, network, ad "network": network, "address": address, }) - c.logger.Infof("connect: %s/%s", address, network) + c.logger.Infof("connect %s/%s", address, network) if c.md.connectTimeout > 0 { conn.SetDeadline(time.Now().Add(c.md.connectTimeout)) defer conn.SetDeadline(time.Time{}) } - var udpMode bool - if network == "udp" || network == "udp4" || network == "udp6" { - udpMode = true - } - req := relay.Request{ Version: relay.Version1, + Flags: relay.CONNECT, } - if udpMode { + if network == "udp" || network == "udp4" || network == "udp6" { req.Flags |= relay.FUDP + + // UDP association + if address == "" { + baddr, err := c.bind(conn, relay.FUDP|relay.BIND, network, address) + if err != nil { + return nil, err + } + c.logger.Debugf("associate on %s OK", baddr) + + return socks.UDPTunClientConn(conn, nil), nil + } } if c.md.user != nil { @@ -76,7 +85,43 @@ func (c *relayConnector) Connect(ctx context.Context, conn net.Conn, network, ad return nil, err } - req.Features = append(req.Features, af) + // forward mode if port is 0. + if af.Port > 0 { + req.Features = append(req.Features, af) + } + } + + if c.md.noDelay { + if _, err := req.WriteTo(conn); err != nil { + return nil, err + } + } + + switch network { + case "tcp", "tcp4", "tcp6": + cc := &tcpConn{ + Conn: conn, + } + if !c.md.noDelay { + if _, err := req.WriteTo(&cc.wbuf); err != nil { + return nil, err + } + } + conn = cc + case "udp", "udp4", "udp6": + cc := &udpConn{ + Conn: conn, + } + if !c.md.noDelay { + if _, err := req.WriteTo(&cc.wbuf); err != nil { + return nil, err + } + } + conn = cc + default: + err := fmt.Errorf("network %s is unsupported", network) + c.logger.Error(err) + return nil, err } return conn, nil diff --git a/pkg/connector/relay/listener.go b/pkg/connector/relay/listener.go new file mode 100644 index 0000000..6721fe9 --- /dev/null +++ b/pkg/connector/relay/listener.go @@ -0,0 +1,73 @@ +package relay + +import ( + "fmt" + "net" + "strconv" + + "github.com/go-gost/gost/pkg/common/util/mux" + "github.com/go-gost/gost/pkg/logger" + "github.com/go-gost/relay" +) + +type tcpListener struct { + addr net.Addr + session *mux.Session + logger logger.Logger +} + +func (p *tcpListener) Accept() (net.Conn, error) { + cc, err := p.session.Accept() + if err != nil { + return nil, err + } + + conn, err := p.getPeerConn(cc) + if err != nil { + cc.Close() + return nil, err + } + + return conn, nil +} + +func (p *tcpListener) getPeerConn(conn net.Conn) (net.Conn, error) { + // second reply, peer connected + resp := relay.Response{} + if _, err := resp.ReadFrom(conn); err != nil { + return nil, err + } + + if resp.Status != relay.StatusOK { + err := fmt.Errorf("peer connect failed") + return nil, err + } + + var address string + for _, f := range resp.Features { + if f.Type() == relay.FeatureAddr { + if fa, ok := f.(*relay.AddrFeature); ok { + address = net.JoinHostPort(fa.Host, strconv.Itoa(int(fa.Port))) + } + } + } + + raddr, err := net.ResolveTCPAddr("tcp", address) + if err != nil { + return nil, err + } + + return &bindConn{ + Conn: conn, + localAddr: p.addr, + remoteAddr: raddr, + }, nil +} + +func (p *tcpListener) Addr() net.Addr { + return p.addr +} + +func (p *tcpListener) Close() error { + return p.session.Close() +} diff --git a/pkg/connector/relay/metadata.go b/pkg/connector/relay/metadata.go index 58e0a8b..76e64d9 100644 --- a/pkg/connector/relay/metadata.go +++ b/pkg/connector/relay/metadata.go @@ -11,14 +11,14 @@ import ( type metadata struct { connectTimeout time.Duration user *url.Userinfo - nodelay bool + noDelay bool } func (c *relayConnector) parseMetadata(md md.Metadata) (err error) { const ( user = "user" connectTimeout = "connectTimeout" - nodelay = "nodelay" + noDelay = "nodelay" ) if v := md.GetString(user); v != "" { @@ -30,7 +30,7 @@ func (c *relayConnector) parseMetadata(md md.Metadata) (err error) { } } c.md.connectTimeout = md.GetDuration(connectTimeout) - c.md.nodelay = md.GetBool(nodelay) + c.md.noDelay = md.GetBool(noDelay) return } diff --git a/pkg/connector/socks/v4/connector.go b/pkg/connector/socks/v4/connector.go index f079167..e18eb12 100644 --- a/pkg/connector/socks/v4/connector.go +++ b/pkg/connector/socks/v4/connector.go @@ -47,7 +47,7 @@ func (c *socks4Connector) Connect(ctx context.Context, conn net.Conn, network, a "network": network, "address": address, }) - c.logger.Infof("connect: %s/%s", address, network) + c.logger.Infof("connect %s/%s", address, network) switch network { case "tcp", "tcp4", "tcp6": diff --git a/pkg/connector/socks/v5/connector.go b/pkg/connector/socks/v5/connector.go index 2b31999..b12e651 100644 --- a/pkg/connector/socks/v5/connector.go +++ b/pkg/connector/socks/v5/connector.go @@ -91,7 +91,12 @@ func (c *socks5Connector) Connect(ctx context.Context, conn net.Conn, network, a "network": network, "address": address, }) - c.logger.Infof("connect: %s/%s", address, network) + c.logger.Infof("connect %s/%s", address, network) + + if c.md.connectTimeout > 0 { + conn.SetDeadline(time.Now().Add(c.md.connectTimeout)) + defer conn.SetDeadline(time.Time{}) + } switch network { case "udp", "udp4", "udp6": @@ -114,11 +119,6 @@ func (c *socks5Connector) Connect(ctx context.Context, conn net.Conn, network, a return nil, err } - if c.md.connectTimeout > 0 { - conn.SetDeadline(time.Now().Add(c.md.connectTimeout)) - defer conn.SetDeadline(time.Time{}) - } - req := gosocks5.NewRequest(gosocks5.CmdConnect, &addr) if err := req.Write(conn); err != nil { c.logger.Error(err) diff --git a/pkg/connector/ss/connector.go b/pkg/connector/ss/connector.go index 8e04b71..efb3375 100644 --- a/pkg/connector/ss/connector.go +++ b/pkg/connector/ss/connector.go @@ -46,7 +46,7 @@ func (c *ssConnector) Connect(ctx context.Context, conn net.Conn, network, addre "network": network, "address": address, }) - c.logger.Infof("connect: %s/%s", address, network) + c.logger.Infof("connect %s/%s", address, network) switch network { case "tcp", "tcp4", "tcp6": diff --git a/pkg/connector/ss/udp/connector.go b/pkg/connector/ss/udp/connector.go index e909ae9..cadc3f2 100644 --- a/pkg/connector/ss/udp/connector.go +++ b/pkg/connector/ss/udp/connector.go @@ -45,7 +45,7 @@ func (c *ssuConnector) Connect(ctx context.Context, conn net.Conn, network, addr "network": network, "address": address, }) - c.logger.Infof("connect: %s/%s", address, network) + c.logger.Infof("connect %s/%s", address, network) switch network { case "udp", "udp4", "udp6": diff --git a/pkg/handler/auto/handler.go b/pkg/handler/auto/handler.go index 8b964be..731de23 100644 --- a/pkg/handler/auto/handler.go +++ b/pkg/handler/auto/handler.go @@ -4,16 +4,19 @@ import ( "bufio" "context" "net" + "time" "github.com/go-gost/gosocks4" "github.com/go-gost/gosocks5" "github.com/go-gost/gost/pkg/handler" http_handler "github.com/go-gost/gost/pkg/handler/http" + relay_handler "github.com/go-gost/gost/pkg/handler/relay" socks4_handler "github.com/go-gost/gost/pkg/handler/socks/v4" socks5_handler "github.com/go-gost/gost/pkg/handler/socks/v5" "github.com/go-gost/gost/pkg/logger" md "github.com/go-gost/gost/pkg/metadata" "github.com/go-gost/gost/pkg/registry" + "github.com/go-gost/relay" ) func init() { @@ -24,6 +27,7 @@ type autoHandler struct { httpHandler handler.Handler socks4Handler handler.Handler socks5Handler handler.Handler + relayHandler handler.Handler log logger.Logger } @@ -53,6 +57,10 @@ func NewHandler(opts ...handler.Option) handler.Handler { v = append(opts, handler.LoggerOption(log.WithFields(map[string]interface{}{"type": "socks5"}))) h.socks5Handler = socks5_handler.NewHandler(v...) + + v = append(opts, + handler.LoggerOption(log.WithFields(map[string]interface{}{"type": "relay"}))) + h.relayHandler = relay_handler.NewHandler(v...) return h } @@ -66,6 +74,9 @@ func (h *autoHandler) Init(md md.Metadata) error { if err := h.socks5Handler.Init(md); err != nil { return err } + if err := h.relayHandler.Init(md); err != nil { + return err + } return nil } @@ -75,6 +86,14 @@ func (h *autoHandler) Handle(ctx context.Context, conn net.Conn) { "local": conn.LocalAddr().String(), }) + start := time.Now() + h.log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr()) + defer func() { + h.log.WithFields(map[string]interface{}{ + "duration": time.Since(start), + }).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr()) + }() + br := bufio.NewReader(conn) b, err := br.Peek(1) if err != nil { @@ -89,6 +108,8 @@ func (h *autoHandler) Handle(ctx context.Context, conn net.Conn) { h.socks4Handler.Handle(ctx, cc) case gosocks5.Ver5: // socks5 h.socks5Handler.Handle(ctx, cc) + case relay.Version1: // relay + h.relayHandler.Handle(ctx, cc) default: // http h.httpHandler.Handle(ctx, cc) } diff --git a/pkg/handler/forward/local/handler.go b/pkg/handler/forward/local/handler.go index c39feb9..ad6a57a 100644 --- a/pkg/handler/forward/local/handler.go +++ b/pkg/handler/forward/local/handler.go @@ -17,6 +17,7 @@ import ( func init() { registry.RegisterHandler("tcp", NewHandler) registry.RegisterHandler("udp", NewHandler) + registry.RegisterHandler("forward", NewHandler) } type forwardHandler struct { @@ -40,7 +41,15 @@ func NewHandler(opts ...handler.Option) handler.Handler { } func (h *forwardHandler) Init(md md.Metadata) (err error) { - return h.parseMetadata(md) + if err = h.parseMetadata(md); err != nil { + return + } + + if h.group == nil { + // dummy node used by relay connector. + h.group = chain.NewNodeGroup(chain.NewNode("dummy", ":0")) + } + return nil } // WithChain implements chain.Chainable interface diff --git a/pkg/handler/relay/conn.go b/pkg/handler/relay/conn.go new file mode 100644 index 0000000..350ed0a --- /dev/null +++ b/pkg/handler/relay/conn.go @@ -0,0 +1,81 @@ +package relay + +import ( + "bytes" + "encoding/binary" + "errors" + "io" + "math" + "net" +) + +type tcpConn struct { + net.Conn + wbuf bytes.Buffer +} + +func (c *tcpConn) Read(b []byte) (n int, err error) { + if err != nil { + return + } + return c.Conn.Read(b) +} + +func (c *tcpConn) Write(b []byte) (n int, err error) { + n = len(b) // force byte length consistent + if c.wbuf.Len() > 0 { + c.wbuf.Write(b) // append the data to the cached header + _, err = c.wbuf.WriteTo(c.Conn) + return + } + _, err = c.Conn.Write(b) + return +} + +type udpConn struct { + net.Conn + wbuf bytes.Buffer +} + +func (c *udpConn) Read(b []byte) (n int, err error) { + var bb [2]byte + _, err = io.ReadFull(c.Conn, bb[:]) + if err != nil { + return + } + + dlen := int(binary.BigEndian.Uint16(bb[:])) + if len(b) >= dlen { + return io.ReadFull(c.Conn, b[:dlen]) + } + buf := make([]byte, dlen) + _, err = io.ReadFull(c.Conn, buf) + n = copy(b, buf) + + return +} + +func (c *udpConn) Write(b []byte) (n int, err error) { + if len(b) > math.MaxUint16 { + err = errors.New("write: data maximum exceeded") + return + } + + n = len(b) + if c.wbuf.Len() > 0 { + var bb [2]byte + binary.BigEndian.PutUint16(bb[:], uint16(len(b))) + c.wbuf.Write(bb[:]) + c.wbuf.Write(b) // append the data to the cached header + _, err = c.wbuf.WriteTo(c.Conn) + return + } + + var bb [2]byte + binary.BigEndian.PutUint16(bb[:], uint16(len(b))) + _, err = c.Conn.Write(bb[:]) + if err != nil { + return + } + return c.Conn.Write(b) +} diff --git a/pkg/handler/relay/connect.go b/pkg/handler/relay/connect.go index 15de5b6..648cadb 100644 --- a/pkg/handler/relay/connect.go +++ b/pkg/handler/relay/connect.go @@ -7,7 +7,6 @@ import ( "time" "github.com/go-gost/gost/pkg/chain" - util_relay "github.com/go-gost/gost/pkg/common/util/relay" "github.com/go-gost/gost/pkg/handler" "github.com/go-gost/relay" ) @@ -51,12 +50,36 @@ func (h *relayHandler) handleConnect(ctx context.Context, conn net.Conn, network } defer cc.Close() - if _, err := resp.WriteTo(conn); err != nil { - h.logger.Error(err) + if h.md.noDelay { + if _, err := resp.WriteTo(conn); err != nil { + h.logger.Error(err) + return + } } - if network == "udp" { - conn = util_relay.UDPTunConn(conn) + switch network { + case "udp", "udp4", "udp6": + rc := &udpConn{ + Conn: conn, + } + if !h.md.noDelay { + // cache the header + if _, err := resp.WriteTo(&rc.wbuf); err != nil { + return + } + } + conn = rc + default: + rc := &tcpConn{ + Conn: conn, + } + if !h.md.noDelay { + // cache the header + if _, err := resp.WriteTo(&rc.wbuf); err != nil { + return + } + } + conn = rc } t := time.Now() diff --git a/pkg/handler/relay/forward.go b/pkg/handler/relay/forward.go index 1ed55b6..96389a5 100644 --- a/pkg/handler/relay/forward.go +++ b/pkg/handler/relay/forward.go @@ -8,11 +8,18 @@ import ( "github.com/go-gost/gost/pkg/chain" "github.com/go-gost/gost/pkg/handler" + "github.com/go-gost/relay" ) func (h *relayHandler) handleForward(ctx context.Context, conn net.Conn, network string) { + resp := relay.Response{ + Version: relay.Version1, + Status: relay.StatusOK, + } target := h.group.Next() if target == nil { + resp.Status = relay.StatusServiceUnavailable + resp.WriteTo(conn) h.logger.Error("no target available") return } @@ -30,15 +37,51 @@ func (h *relayHandler) handleForward(ctx context.Context, conn net.Conn, network cc, err := r.Dial(ctx, network, target.Addr()) if err != nil { - h.logger.Error(err) // TODO: the router itself may be failed due to the failed node in the router, // the dead marker may be a wrong operation. target.Marker().Mark() + + resp.Status = relay.StatusHostUnreachable + resp.WriteTo(conn) + h.logger.Error(err) + return } defer cc.Close() target.Marker().Reset() + if h.md.noDelay { + if _, err := resp.WriteTo(conn); err != nil { + h.logger.Error(err) + return + } + } + + switch network { + case "udp", "udp4", "udp6": + rc := &udpConn{ + Conn: conn, + } + if !h.md.noDelay { + // cache the header + if _, err := resp.WriteTo(&rc.wbuf); err != nil { + return + } + } + conn = rc + default: + rc := &tcpConn{ + Conn: conn, + } + if !h.md.noDelay { + // cache the header + if _, err := resp.WriteTo(&rc.wbuf); err != nil { + return + } + } + conn = rc + } + t := time.Now() h.logger.Infof("%s <-> %s", conn.RemoteAddr(), target.Addr()) handler.Transport(conn, cc) diff --git a/pkg/handler/relay/handler.go b/pkg/handler/relay/handler.go index e00c832..1ad1152 100644 --- a/pkg/handler/relay/handler.go +++ b/pkg/handler/relay/handler.go @@ -123,7 +123,7 @@ func (h *relayHandler) Handle(ctx context.Context, conn net.Conn) { if address != "" { resp.Status = relay.StatusForbidden resp.WriteTo(conn) - h.logger.Error("forbidden") + h.logger.Error("forward mode, connect is forbidden") return } // forward mode @@ -132,7 +132,7 @@ func (h *relayHandler) Handle(ctx context.Context, conn net.Conn) { } switch req.Flags & relay.CmdMask { - case relay.CONNECT: + case 0, relay.CONNECT: h.handleConnect(ctx, conn, network, address) case relay.BIND: h.handleBind(ctx, conn, network, address) diff --git a/pkg/handler/relay/metadata.go b/pkg/handler/relay/metadata.go index 46a408c..3fa1b71 100644 --- a/pkg/handler/relay/metadata.go +++ b/pkg/handler/relay/metadata.go @@ -14,6 +14,7 @@ type metadata struct { retryCount int enableBind bool udpBufferSize int + noDelay bool } func (h *relayHandler) parseMetadata(md md.Metadata) (err error) { @@ -23,6 +24,7 @@ func (h *relayHandler) parseMetadata(md md.Metadata) (err error) { retryCount = "retry" enableBind = "bind" udpBufferSize = "udpBufferSize" + noDelay = "nodelay" ) if v, _ := md.Get(users).([]interface{}); len(v) > 0 { @@ -42,6 +44,7 @@ func (h *relayHandler) parseMetadata(md md.Metadata) (err error) { h.md.readTimeout = md.GetDuration(readTimeout) h.md.retryCount = md.GetInt(retryCount) h.md.enableBind = md.GetBool(enableBind) + h.md.noDelay = md.GetBool(noDelay) h.md.udpBufferSize = md.GetInt(udpBufferSize) if h.md.udpBufferSize > 0 { if h.md.udpBufferSize < 512 { diff --git a/pkg/handler/socks/v5/udp.go b/pkg/handler/socks/v5/udp.go index 725b157..409fee3 100644 --- a/pkg/handler/socks/v5/udp.go +++ b/pkg/handler/socks/v5/udp.go @@ -2,7 +2,6 @@ package v5 import ( "context" - "errors" "fmt" "io" "io/ioutil" @@ -10,7 +9,6 @@ import ( "time" "github.com/go-gost/gosocks5" - "github.com/go-gost/gost/pkg/chain" "github.com/go-gost/gost/pkg/common/bufpool" "github.com/go-gost/gost/pkg/common/util/socks" ) @@ -54,153 +52,24 @@ func (h *socks5Handler) handleUDP(ctx context.Context, conn net.Conn) { }) h.logger.Debugf("bind on %s OK", relay.LocalAddr()) - if h.chain.IsEmpty() { - // serve as standard socks5 udp relay. - peer, err := net.ListenUDP("udp", nil) - if err != nil { - h.logger.Error(err) - return - } - defer peer.Close() - - go h.relayUDP( - socks.UDPConn(relay, h.md.udpBufferSize), - peer, - ) - } else { - tun, err := h.getUDPTun(ctx) - if err != nil { - h.logger.Error(err) - return - } - defer tun.Close() - - go h.tunnelClientUDP( - socks.UDPConn(relay, h.md.udpBufferSize), - socks.UDPTunClientPacketConn(tun), - ) + peer, err := net.ListenUDP("udp", nil) + if err != nil { + h.logger.Error(err) + return } + defer peer.Close() + + go h.relayUDP( + socks.UDPConn(relay, h.md.udpBufferSize), + peer, + ) t := time.Now() - h.logger.Infof("%s <-> %s", conn.RemoteAddr(), &saddr) + h.logger.Infof("%s <-> %s", conn.RemoteAddr(), relay.LocalAddr()) io.Copy(ioutil.Discard, conn) h.logger. WithFields(map[string]interface{}{"duration": time.Since(t)}). - Infof("%s >-< %s", conn.RemoteAddr(), &saddr) -} - -func (h *socks5Handler) getUDPTun(ctx context.Context) (conn net.Conn, err error) { - r := (&chain.Router{}). - WithChain(h.chain). - WithRetry(h.md.retryCount). - WithLogger(h.logger) - conn, err = r.Connect(ctx) - if err != nil { - return nil, err - } - - defer func() { - if err != nil { - conn.Close() - conn = nil - } - }() - - if h.md.timeout > 0 { - conn.SetDeadline(time.Now().Add(h.md.timeout)) - defer conn.SetDeadline(time.Time{}) - } - - req := gosocks5.NewRequest(socks.CmdUDPTun, nil) - if err = req.Write(conn); err != nil { - return - } - h.logger.Debug(req) - - reply, err := gosocks5.ReadReply(conn) - if err != nil { - return - } - h.logger.Debug(reply) - - if reply.Rep != gosocks5.Succeeded { - err = errors.New("UDP associate failed") - return - } - - return -} - -func (h *socks5Handler) tunnelClientUDP(c, tun net.PacketConn) (err error) { - bufSize := h.md.udpBufferSize - errc := make(chan error, 2) - - go func() { - for { - err := func() error { - b := bufpool.Get(bufSize) - defer bufpool.Put(b) - - n, raddr, err := c.ReadFrom(b) - if err != nil { - return err - } - - if h.bypass != nil && h.bypass.Contains(raddr.String()) { - h.logger.Warn("bypass: ", raddr) - return nil - } - - if _, err := tun.WriteTo(b[:n], raddr); err != nil { - return err - } - - h.logger.Debugf("%s >>> %s data: %d", - tun.LocalAddr(), raddr, n) - - return nil - }() - - if err != nil { - errc <- err - return - } - } - }() - - go func() { - for { - err := func() error { - b := bufpool.Get(bufSize) - defer bufpool.Put(b) - - n, raddr, err := tun.ReadFrom(b) - if err != nil { - return err - } - if h.bypass != nil && h.bypass.Contains(raddr.String()) { - h.logger.Warn("bypass: ", raddr) - return nil - } - - if _, err := c.WriteTo(b[:n], raddr); err != nil { - return err - } - - h.logger.Debugf("%s <<< %s data: %d", - tun.LocalAddr(), raddr, n) - - return nil - }() - - if err != nil { - errc <- err - return - } - } - }() - - return <-errc + Infof("%s >-< %s", conn.RemoteAddr(), relay.LocalAddr()) } func (h *socks5Handler) relayUDP(c, peer net.PacketConn) (err error) { diff --git a/pkg/handler/ss/udp/handler.go b/pkg/handler/ss/udp/handler.go index f9e670d..bd736b5 100644 --- a/pkg/handler/ss/udp/handler.go +++ b/pkg/handler/ss/udp/handler.go @@ -84,18 +84,18 @@ func (h *ssuHandler) Handle(ctx context.Context, conn net.Conn) { WithChain(h.chain). WithRetry(h.md.retryCount). WithLogger(h.logger) - c, err := r.Dial(ctx, "udp", "") + c, err := r.Dial(ctx, "udp", "") // UDP association if err != nil { h.logger.Error(err) return } + defer c.Close() cc, ok := c.(net.PacketConn) if !ok { - h.logger.Errorf("%s: not a packet connection") + h.logger.Errorf("wrong connection type") return } - defer cc.Close() t := time.Now() h.logger.Infof("%s <-> %s", conn.RemoteAddr(), cc.LocalAddr()) diff --git a/pkg/registry/registry.go b/pkg/registry/registry.go index b96909d..b1cb609 100644 --- a/pkg/registry/registry.go +++ b/pkg/registry/registry.go @@ -5,6 +5,7 @@ import ( "github.com/go-gost/gost/pkg/dialer" "github.com/go-gost/gost/pkg/handler" "github.com/go-gost/gost/pkg/listener" + "github.com/go-gost/gost/pkg/logger" ) type NewListener func(opts ...listener.Option) listener.Listener @@ -20,6 +21,9 @@ var ( ) func RegisterListener(name string, newf NewListener) { + if listeners[name] != nil { + logger.Default().Fatalf("register duplicate listener: %s", name) + } listeners[name] = newf } @@ -28,6 +32,9 @@ func GetListener(name string) NewListener { } func RegisterHandler(name string, newf NewHandler) { + if handlers[name] != nil { + logger.Default().Fatalf("register duplicate handler: %s", name) + } handlers[name] = newf } @@ -36,6 +43,9 @@ func GetHandler(name string) NewHandler { } func RegisterDialer(name string, newf NewDialer) { + if dialers[name] != nil { + logger.Default().Fatalf("register duplicate dialer: %s", name) + } dialers[name] = newf } @@ -44,6 +54,9 @@ func GetDialer(name string) NewDialer { } func RegiserConnector(name string, newf NewConnector) { + if connectors[name] != nil { + logger.Default().Fatalf("register duplicate connector: %s", name) + } connectors[name] = newf }