diff --git a/connector/relay/conn.go b/connector/relay/conn.go index 384aaa0..2ce2263 100644 --- a/connector/relay/conn.go +++ b/connector/relay/conn.go @@ -17,13 +17,15 @@ import ( type tcpConn struct { net.Conn - wbuf bytes.Buffer + wbuf *bytes.Buffer once sync.Once } func (c *tcpConn) Read(b []byte) (n int, err error) { c.once.Do(func() { - err = readResponse(c.Conn) + if c.wbuf != nil { + err = readResponse(c.Conn) + } }) if err != nil { @@ -34,7 +36,7 @@ func (c *tcpConn) Read(b []byte) (n int, err error) { func (c *tcpConn) Write(b []byte) (n int, err error) { n = len(b) // force byte length consistent - if c.wbuf.Len() > 0 { + if c.wbuf != nil && c.wbuf.Len() > 0 { c.wbuf.Write(b) // append the data to the cached header _, err = c.Conn.Write(c.wbuf.Bytes()) c.wbuf.Reset() @@ -46,13 +48,15 @@ func (c *tcpConn) Write(b []byte) (n int, err error) { type udpConn struct { net.Conn - wbuf bytes.Buffer + wbuf *bytes.Buffer once sync.Once } func (c *udpConn) Read(b []byte) (n int, err error) { c.once.Do(func() { - err = readResponse(c.Conn) + if c.wbuf != nil { + err = readResponse(c.Conn) + } }) if err != nil { return @@ -84,7 +88,7 @@ func (c *udpConn) Write(b []byte) (n int, err error) { } n = len(b) - if c.wbuf.Len() > 0 { + if c.wbuf != nil && c.wbuf.Len() > 0 { var bb [2]byte binary.BigEndian.PutUint16(bb[:], uint16(len(b))) c.wbuf.Write(bb[:]) diff --git a/connector/relay/connector.go b/connector/relay/connector.go index 5a8c904..9e41ca2 100644 --- a/connector/relay/connector.go +++ b/connector/relay/connector.go @@ -1,6 +1,7 @@ package relay import ( + "bytes" "context" "fmt" "net" @@ -121,8 +122,9 @@ func (c *relayConnector) Connect(ctx context.Context, conn net.Conn, network, ad if !c.md.noDelay { cc := &tcpConn{ Conn: conn, + wbuf: &bytes.Buffer{}, } - if _, err := req.WriteTo(&cc.wbuf); err != nil { + if _, err := req.WriteTo(cc.wbuf); err != nil { return nil, err } conn = cc @@ -132,7 +134,8 @@ func (c *relayConnector) Connect(ctx context.Context, conn net.Conn, network, ad Conn: conn, } if !c.md.noDelay { - if _, err := req.WriteTo(&cc.wbuf); err != nil { + cc.wbuf = &bytes.Buffer{} + if _, err := req.WriteTo(cc.wbuf); err != nil { return nil, err } } diff --git a/connector/tunnel/bind.go b/connector/tunnel/bind.go index 58c2168..ed2a093 100644 --- a/connector/tunnel/bind.go +++ b/connector/tunnel/bind.go @@ -53,13 +53,16 @@ func (c *tunnelConnector) initTunnel(conn net.Conn, network, address string) (ad } af := &relay.AddrFeature{} - af.ParseFrom(address) + af.ParseFrom(conn.LocalAddr().String()) // src address + req.Features = append(req.Features, af) - req.Features = append(req.Features, af, - &relay.TunnelFeature{ - ID: c.md.tunnelID.ID(), - }, - ) + af = &relay.AddrFeature{} + af.ParseFrom(address) + req.Features = append(req.Features, af) // dst address + + req.Features = append(req.Features, &relay.TunnelFeature{ + ID: c.md.tunnelID.ID(), + }) if _, err = req.WriteTo(conn); err != nil { return } diff --git a/connector/tunnel/conn.go b/connector/tunnel/conn.go index 0929981..59cd0c1 100644 --- a/connector/tunnel/conn.go +++ b/connector/tunnel/conn.go @@ -17,13 +17,15 @@ import ( type tcpConn struct { net.Conn - wbuf bytes.Buffer + wbuf *bytes.Buffer once sync.Once } func (c *tcpConn) Read(b []byte) (n int, err error) { c.once.Do(func() { - err = readResponse(c.Conn) + if c.wbuf != nil { + err = readResponse(c.Conn) + } }) if err != nil { @@ -34,7 +36,7 @@ func (c *tcpConn) Read(b []byte) (n int, err error) { func (c *tcpConn) Write(b []byte) (n int, err error) { n = len(b) // force byte length consistent - if c.wbuf.Len() > 0 { + if c.wbuf != nil && c.wbuf.Len() > 0 { c.wbuf.Write(b) // append the data to the cached header _, err = c.Conn.Write(c.wbuf.Bytes()) c.wbuf.Reset() @@ -46,13 +48,15 @@ func (c *tcpConn) Write(b []byte) (n int, err error) { type udpConn struct { net.Conn - wbuf bytes.Buffer + wbuf *bytes.Buffer once sync.Once } func (c *udpConn) Read(b []byte) (n int, err error) { c.once.Do(func() { - err = readResponse(c.Conn) + if c.wbuf != nil { + err = readResponse(c.Conn) + } }) if err != nil { return @@ -84,7 +88,7 @@ func (c *udpConn) Write(b []byte) (n int, err error) { } n = len(b) - if c.wbuf.Len() > 0 { + if c.wbuf != nil && c.wbuf.Len() > 0 { var bb [2]byte binary.BigEndian.PutUint16(bb[:], uint16(len(b))) c.wbuf.Write(bb[:]) diff --git a/connector/tunnel/connector.go b/connector/tunnel/connector.go index 960ad3f..824cf28 100644 --- a/connector/tunnel/connector.go +++ b/connector/tunnel/connector.go @@ -1,6 +1,7 @@ package tunnel import ( + "bytes" "context" "fmt" "net" @@ -9,6 +10,7 @@ import ( "github.com/go-gost/core/connector" md "github.com/go-gost/core/metadata" "github.com/go-gost/relay" + auth_util "github.com/go-gost/x/internal/util/auth" "github.com/go-gost/x/registry" ) @@ -55,6 +57,14 @@ func (c *tunnelConnector) Connect(ctx context.Context, conn net.Conn, network, a Cmd: relay.CmdConnect, } + switch network { + case "udp", "udp4", "udp6": + req.Cmd |= relay.FUDP + req.Features = append(req.Features, &relay.NetworkFeature{ + Network: relay.NetworkUDP, + }) + } + if c.options.Auth != nil { pwd, _ := c.options.Auth.Password() req.Features = append(req.Features, &relay.UserAuthFeature{ @@ -63,33 +73,54 @@ func (c *tunnelConnector) Connect(ctx context.Context, conn net.Conn, network, a }) } - if address != "" { - af := &relay.AddrFeature{} - if err := af.ParseFrom(address); err != nil { - return nil, err - } - req.Features = append(req.Features, af) + srcAddr := conn.LocalAddr().String() + if v := auth_util.ClientAddrFromContext(ctx); v != "" { + srcAddr = string(v) } + af := &relay.AddrFeature{} + af.ParseFrom(srcAddr) + req.Features = append(req.Features, af) // src address + + af = &relay.AddrFeature{} + af.ParseFrom(address) + req.Features = append(req.Features, af) // dst address + req.Features = append(req.Features, &relay.TunnelFeature{ ID: c.md.tunnelID.ID(), }) - switch network { - case "tcp", "tcp4", "tcp6", "unix", "serial": - cc := &tcpConn{ - Conn: conn, - } - if _, err := req.WriteTo(&cc.wbuf); err != nil { + if c.md.noDelay { + if _, err := req.WriteTo(conn); err != nil { return nil, err } - conn = cc + // drain the response + if err := readResponse(conn); err != nil { + return nil, err + } + } + + switch network { + case "tcp", "tcp4", "tcp6": + if !c.md.noDelay { + cc := &tcpConn{ + Conn: conn, + wbuf: &bytes.Buffer{}, + } + if _, err := req.WriteTo(cc.wbuf); err != nil { + return nil, err + } + conn = cc + } case "udp", "udp4", "udp6": cc := &udpConn{ Conn: conn, } - if _, err := req.WriteTo(&cc.wbuf); err != nil { - return nil, err + if !c.md.noDelay { + cc.wbuf = &bytes.Buffer{} + if _, err := req.WriteTo(cc.wbuf); err != nil { + return nil, err + } } conn = cc default: diff --git a/connector/tunnel/metadata.go b/connector/tunnel/metadata.go index cd879df..7a56786 100644 --- a/connector/tunnel/metadata.go +++ b/connector/tunnel/metadata.go @@ -17,15 +17,16 @@ var ( type metadata struct { connectTimeout time.Duration tunnelID relay.TunnelID + noDelay bool } func (c *tunnelConnector) parseMetadata(md mdata.Metadata) (err error) { const ( connectTimeout = "connectTimeout" - noDelay = "nodelay" ) c.md.connectTimeout = mdutil.GetDuration(md, connectTimeout) + c.md.noDelay = mdutil.GetBool(md, "nodelay") if s := mdutil.GetString(md, "tunnelID", "tunnel.id"); s != "" { uuid, err := uuid.Parse(s) diff --git a/go.mod b/go.mod index 0d4b1b6..e266b19 100644 --- a/go.mod +++ b/go.mod @@ -30,7 +30,6 @@ require ( github.com/sirupsen/logrus v1.8.1 github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 github.com/spf13/viper v1.14.0 - github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07 github.com/vishvananda/netlink v1.1.0 github.com/xtaci/kcp-go/v5 v5.6.1 github.com/xtaci/smux v1.5.24 diff --git a/go.sum b/go.sum index 86744f3..b30dcfe 100644 --- a/go.sum +++ b/go.sum @@ -332,8 +332,6 @@ github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gt github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/subosito/gotenv v1.4.1 h1:jyEFiXpy21Wm81FBN71l9VoMMV8H8jG+qIK3GCpY6Qs= github.com/subosito/gotenv v1.4.1/go.mod h1:ayKnFf/c6rvx/2iiLrJUk1e6plDbT3edrFNGqEflhK0= -github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07 h1:UyzmZLoiDWMRywV4DUYb9Fbt8uiOSooupjTq10vpvnU= -github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA= github.com/templexxx/cpu v0.0.1/go.mod h1:w7Tb+7qgcAlIyX4NhLuDKt78AHA5SzPmq0Wj6HiEnnk= github.com/templexxx/cpu v0.0.7 h1:pUEZn8JBy/w5yzdYWgx+0m0xL9uk6j4K91C5kOViAzo= github.com/templexxx/cpu v0.0.7/go.mod h1:w7Tb+7qgcAlIyX4NhLuDKt78AHA5SzPmq0Wj6HiEnnk= diff --git a/handler/forward/local/handler.go b/handler/forward/local/handler.go index c8b03e2..a760b24 100644 --- a/handler/forward/local/handler.go +++ b/handler/forward/local/handler.go @@ -10,7 +10,6 @@ import ( "net" "net/http" "net/http/httputil" - "sync" "time" "github.com/go-gost/core/chain" @@ -181,7 +180,6 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, log logger.Logger) (err error) { br := bufio.NewReader(rw) - var connPool sync.Map for { resp := &http.Response{ @@ -230,45 +228,28 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, log l ctx = auth_util.ContextWithID(ctx, auth_util.ID(id)) } - var cc net.Conn - if v, ok := connPool.Load(target); ok { - cc = v.(net.Conn) - log.Debugf("connection to node %s(%s) found in pool", target.Name, target.Addr) - } - if cc == nil { - cc, err = h.router.Dial(ctx, "tcp", target.Addr) - if err != nil { - // TODO: the router itself may be failed due to the failed node in the router, - // the dead marker may be a wrong operation. - if marker := target.Marker(); marker != nil { - marker.Mark() - } - log.Warnf("connect to node %s(%s) failed: %v", target.Name, target.Addr, err) - return resp.Write(rw) - } + cc, err := h.router.Dial(ctx, "tcp", target.Addr) + if err != nil { + // TODO: the router itself may be failed due to the failed node in the router, + // the dead marker may be a wrong operation. if marker := target.Marker(); marker != nil { - marker.Reset() + marker.Mark() } + log.Warnf("connect to node %s(%s) failed: %v", target.Name, target.Addr, err) + return resp.Write(rw) + } + if marker := target.Marker(); marker != nil { + marker.Reset() + } + defer cc.Close() - if tlsSettings := target.Options().TLS; tlsSettings != nil { - cc = tls.Client(cc, &tls.Config{ - ServerName: tlsSettings.ServerName, - InsecureSkipVerify: !tlsSettings.Secure, - }) - } + log.Debugf("new connection to node %s(%s)", target.Name, target.Addr) - connPool.Store(target, cc) - log.Debugf("new connection to node %s(%s)", target.Name, target.Addr) - - go func() { - defer cc.Close() - err := xnet.CopyBuffer(rw, cc, 8192) - if err != nil { - resp.Write(rw) - } - log.Debugf("close connection to node %s(%s), reason: %v", target.Name, target.Addr, err) - connPool.Delete(target) - }() + if tlsSettings := target.Options().TLS; tlsSettings != nil { + cc = tls.Client(cc, &tls.Config{ + ServerName: tlsSettings.ServerName, + InsecureSkipVerify: !tlsSettings.Secure, + }) } if httpSettings := target.Options().HTTP; httpSettings != nil { @@ -285,21 +266,18 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, log l log.Trace(string(dump)) } if err := req.Write(cc); err != nil { - log.Warnf("send request to node %s(%s) failed: %v", target.Name, target.Addr, err) + log.Warnf("send request to node %s(%s): %v", target.Name, target.Addr, err) return resp.Write(rw) } - if req.Header.Get("Upgrade") == "websocket" { - err := xnet.CopyBuffer(cc, br, 8192) - if err == nil { - err = io.EOF - } - return err + res, err := http.ReadResponse(bufio.NewReader(cc), req) + if err != nil { + log.Warnf("read response from node %s(%s): %v", target.Name, target.Addr, err) + return resp.Write(rw) } + defer res.Body.Close() - // cc.SetReadDeadline(time.Now().Add(10 * time.Second)) - - return nil + return res.Write(rw) }() if err != nil { // log.Error(err) @@ -307,13 +285,6 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, log l } } - connPool.Range(func(key, value any) bool { - if value != nil { - value.(net.Conn).Close() - } - return true - }) - return } diff --git a/handler/forward/remote/handler.go b/handler/forward/remote/handler.go index 80da46a..dec8343 100644 --- a/handler/forward/remote/handler.go +++ b/handler/forward/remote/handler.go @@ -10,6 +10,7 @@ import ( "net" "net/http" "net/http/httputil" + "strconv" "sync" "time" @@ -94,6 +95,8 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand ctx = auth_util.ContextWithClientAddr(ctx, auth_util.ClientAddr(conn.RemoteAddr().String())) + localAddr := convertAddr(conn.LocalAddr()) + var rw io.ReadWriter = conn var host string var protocol string @@ -108,7 +111,7 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand } } if protocol == forward.ProtoHTTP { - h.handleHTTP(ctx, rw, conn.RemoteAddr(), conn.LocalAddr(), log) + h.handleHTTP(ctx, rw, conn.RemoteAddr(), localAddr, log) return nil } @@ -166,12 +169,7 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand marker.Reset() } - if dst, ok := conn.LocalAddr().(*net.TCPAddr); ok { - if dst.IP.Equal(net.IPv6zero) { - dst.IP = net.IPv4zero - } - } - cc = proxyproto.WrapClientConn(h.md.proxyProtocol, conn.RemoteAddr(), conn.LocalAddr(), cc) + cc = proxyproto.WrapClientConn(h.md.proxyProtocol, conn.RemoteAddr(), localAddr, cc) t := time.Now() log.Infof("%s <-> %s", conn.RemoteAddr(), target.Addr) @@ -334,3 +332,27 @@ func (h *forwardHandler) checkRateLimit(addr net.Addr) bool { return true } + +func convertAddr(addr net.Addr) net.Addr { + host, sp, _ := net.SplitHostPort(addr.String()) + ip := net.ParseIP(host) + port, _ := strconv.Atoi(sp) + + if ip == nil || ip.Equal(net.IPv6zero) { + ip = net.IPv4zero + } + + switch addr.Network() { + case "tcp", "tcp4", "tcp6": + return &net.TCPAddr{ + IP: ip, + Port: port, + } + + default: + return &net.UDPAddr{ + IP: ip, + Port: port, + } + } +} diff --git a/handler/redirect/tcp/handler.go b/handler/redirect/tcp/handler.go index af57d5d..c88ab8d 100644 --- a/handler/redirect/tcp/handler.go +++ b/handler/redirect/tcp/handler.go @@ -116,7 +116,7 @@ func (h *redirectHandler) Handle(ctx context.Context, conn net.Conn, opts ...han // try to sniff HTTP traffic if isHTTP(string(hdr[:])) { - return h.handleHTTP(ctx, rw, conn.RemoteAddr(), log) + return h.handleHTTP(ctx, rw, conn.RemoteAddr(), dstAddr, log) } } @@ -144,7 +144,7 @@ func (h *redirectHandler) Handle(ctx context.Context, conn net.Conn, opts ...han return nil } -func (h *redirectHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, raddr net.Addr, log logger.Logger) error { +func (h *redirectHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, raddr, dstAddr net.Addr, log logger.Logger) error { req, err := http.ReadRequest(bufio.NewReader(rw)) if err != nil { return err @@ -171,7 +171,14 @@ func (h *redirectHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, radd cc, err := h.router.Dial(ctx, "tcp", host) if err != nil { log.Error(err) - return err + } + + if cc == nil { + cc, err = h.router.Dial(ctx, "tcp", dstAddr.String()) + if err != nil { + log.Error(err) + return err + } } defer cc.Close() @@ -216,9 +223,10 @@ func (h *redirectHandler) handleHTTPS(ctx context.Context, rw io.ReadWriter, rad log.Error(err) return err } - if host == "" { - host = dstAddr.String() - } else { + + var cc io.ReadWriteCloser + + if host != "" { if _, _, err := net.SplitHostPort(host); err != nil { _, port, _ := net.SplitHostPort(dstAddr.String()) if port == "" { @@ -226,21 +234,27 @@ func (h *redirectHandler) handleHTTPS(ctx context.Context, rw io.ReadWriter, rad } host = net.JoinHostPort(host, port) } + log = log.WithFields(map[string]any{ + "host": host, + }) + + if h.options.Bypass != nil && h.options.Bypass.Contains(ctx, "tcp", host) { + log.Debug("bypass: ", host) + return nil + } + + cc, err = h.router.Dial(ctx, "tcp", host) + if err != nil { + log.Error(err) + } } - log = log.WithFields(map[string]any{ - "host": host, - }) - - if h.options.Bypass != nil && h.options.Bypass.Contains(ctx, "tcp", host) { - log.Debug("bypass: ", host) - return nil - } - - cc, err := h.router.Dial(ctx, "tcp", host) - if err != nil { - log.Error(err) - return err + if cc == nil { + cc, err = h.router.Dial(ctx, "tcp", dstAddr.String()) + if err != nil { + log.Error(err) + return err + } } defer cc.Close() diff --git a/handler/relay/connect.go b/handler/relay/connect.go index 8735399..18dcae9 100644 --- a/handler/relay/connect.go +++ b/handler/relay/connect.go @@ -150,7 +150,7 @@ func (h *relayHandler) handleConnectTunnel(ctx context.Context, conn net.Conn, n return err } - cc, _, err := getTunnelConn(network, h.pool, tid, 3, log) + cc, _, err := getTunnelConn(network, h.pool, tunnelID, 3, log) if err != nil { resp.Status = relay.StatusServiceUnavailable resp.WriteTo(conn) diff --git a/handler/tunnel/connect.go b/handler/tunnel/connect.go index edcffe2..0157381 100644 --- a/handler/tunnel/connect.go +++ b/handler/tunnel/connect.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "net" - "strconv" "time" "github.com/go-gost/core/logger" @@ -12,37 +11,54 @@ import ( xnet "github.com/go-gost/x/internal/net" ) -func (h *tunnelHandler) handleConnect(ctx context.Context, conn net.Conn, network, address string, tunnelID relay.TunnelID, log logger.Logger) error { +func (h *tunnelHandler) handleConnect(ctx context.Context, conn net.Conn, network, srcAddr string, dstAddr string, tunnelID relay.TunnelID, log logger.Logger) error { log = log.WithFields(map[string]any{ - "dst": fmt.Sprintf("%s/%s", address, network), + "dst": fmt.Sprintf("%s/%s", dstAddr, network), "cmd": "connect", "tunnel": tunnelID.String(), }) - log.Debugf("%s >> %s/%s", conn.RemoteAddr(), address, network) - resp := relay.Response{ Version: relay.Version1, Status: relay.StatusOK, } - host, sp, _ := net.SplitHostPort(address) - - if h.options.Bypass != nil && h.options.Bypass.Contains(ctx, network, address) { - log.Debug("bypass: ", address) + if h.options.Bypass != nil && h.options.Bypass.Contains(ctx, network, dstAddr) { + log.Debug("bypass: ", dstAddr) resp.Status = relay.StatusForbidden _, err := resp.WriteTo(conn) return err } + host, _, _ := net.SplitHostPort(dstAddr) + var tid relay.TunnelID - if ingress := h.md.ingress; ingress != nil { + if ingress := h.md.ingress; ingress != nil && host != "" { tid = parseTunnelID(ingress.Get(ctx, host)) } - // client is not an public entrypoint. - if h.md.entryPointID.IsZero() || !tunnelID.Equal(h.md.entryPointID) { - if !tid.Equal(tunnelID) && !h.md.directTunnel { + // client is a public entrypoint. + if tunnelID.Equal(h.md.entryPointID) && !h.md.entryPointID.IsZero() { + if tid.IsZero() { + resp.Status = relay.StatusNetworkUnreachable + resp.WriteTo(conn) + err := fmt.Errorf("no route to host %s", host) + log.Error(err) + return err + } + + if tid.IsPrivate() { + resp.Status = relay.StatusHostUnreachable + resp.WriteTo(conn) + err := fmt.Errorf("access denied: tunnel %s is private for host %s", tunnelID, host) + log.Error(err) + return err + } + } else { + // direct routing + if h.md.directTunnel { + tid = tunnelID + } else if !tid.Equal(tunnelID) { resp.Status = relay.StatusHostUnreachable resp.WriteTo(conn) err := fmt.Errorf("no route to host %s", host) @@ -62,36 +78,35 @@ func (h *tunnelHandler) handleConnect(ctx context.Context, conn net.Conn, networ log.Debugf("%s >> %s", conn.RemoteAddr(), cc.RemoteAddr()) - rc := &tcpConn{ - Conn: conn, - } - // cache the header - if _, err := resp.WriteTo(&rc.wbuf); err != nil { - return err - } - conn = rc - - var features []relay.Feature - af := &relay.AddrFeature{} // source/visitor address - af.ParseFrom(conn.RemoteAddr().String()) - features = append(features, af) - - if host != "" { - port, _ := strconv.Atoi(sp) - // target host - af = &relay.AddrFeature{ - AType: relay.AddrDomain, - Host: host, - Port: uint16(port), + if h.md.noDelay { + if _, err := resp.WriteTo(conn); err != nil { + log.Error(err) + return err } - features = append(features, af) + } else { + rc := &tcpConn{ + Conn: conn, + } + // cache the header + if _, err := resp.WriteTo(&rc.wbuf); err != nil { + return err + } + conn = rc } resp = relay.Response{ - Version: relay.Version1, - Status: relay.StatusOK, - Features: features, + Version: relay.Version1, + Status: relay.StatusOK, } + + af := &relay.AddrFeature{} + af.ParseFrom(srcAddr) + resp.Features = append(resp.Features, af) // src address + + af = &relay.AddrFeature{} + af.ParseFrom(dstAddr) + resp.Features = append(resp.Features, af) // dst address + resp.WriteTo(cc) t := time.Now() diff --git a/handler/tunnel/handler.go b/handler/tunnel/handler.go index a8681a9..40e71af 100644 --- a/handler/tunnel/handler.go +++ b/handler/tunnel/handler.go @@ -118,7 +118,7 @@ func (h *tunnelHandler) Handle(ctx context.Context, conn net.Conn, opts ...handl } var user, pass string - var address string + var srcAddr, dstAddr string var networkID relay.NetworkID var tunnelID relay.TunnelID for _, f := range req.Features { @@ -129,7 +129,12 @@ func (h *tunnelHandler) Handle(ctx context.Context, conn net.Conn, opts ...handl } case relay.FeatureAddr: if feature, _ := f.(*relay.AddrFeature); feature != nil { - address = net.JoinHostPort(feature.Host, strconv.Itoa(int(feature.Port))) + v := net.JoinHostPort(feature.Host, strconv.Itoa(int(feature.Port))) + if srcAddr != "" { + dstAddr = v + } else { + srcAddr = v + } } case relay.FeatureTunnel: if feature, _ := f.(*relay.TunnelFeature); feature != nil { @@ -170,9 +175,12 @@ func (h *tunnelHandler) Handle(ctx context.Context, conn net.Conn, opts ...handl switch req.Cmd & relay.CmdMask { case relay.CmdConnect: defer conn.Close() - return h.handleConnect(ctx, conn, network, address, tunnelID, log) + + log.Debugf("connect: %s >> %s/%s", srcAddr, dstAddr, network) + return h.handleConnect(ctx, conn, network, srcAddr, dstAddr, tunnelID, log) case relay.CmdBind: - return h.handleBind(ctx, conn, network, address, tunnelID, log) + log.Debugf("bind: %s >> %s/%s", srcAddr, dstAddr, network) + return h.handleBind(ctx, conn, network, dstAddr, tunnelID, log) default: resp.Status = relay.StatusBadRequest resp.WriteTo(conn) diff --git a/handler/tunnel/metadata.go b/handler/tunnel/metadata.go index bd24bdd..4ed5206 100644 --- a/handler/tunnel/metadata.go +++ b/handler/tunnel/metadata.go @@ -15,6 +15,7 @@ import ( type metadata struct { readTimeout time.Duration + noDelay bool hash string directTunnel bool entryPointID relay.TunnelID @@ -22,18 +23,13 @@ type metadata struct { } func (h *tunnelHandler) parseMetadata(md mdata.Metadata) (err error) { - const ( - readTimeout = "readTimeout" - entryPointID = "entrypoint.id" - hash = "hash" - ) + h.md.readTimeout = mdutil.GetDuration(md, "readTimeout") + h.md.noDelay = mdutil.GetBool(md, "nodelay") - h.md.readTimeout = mdutil.GetDuration(md, readTimeout) - - h.md.hash = mdutil.GetString(md, hash) + h.md.hash = mdutil.GetString(md, "hash") h.md.directTunnel = mdutil.GetBool(md, "tunnel.direct") - h.md.entryPointID = parseTunnelID(mdutil.GetString(md, entryPointID)) + h.md.entryPointID = parseTunnelID(mdutil.GetString(md, "entrypoint.id")) h.md.ingress = registry.IngressRegistry().Get(mdutil.GetString(md, "ingress")) if h.md.ingress == nil { diff --git a/handler/tunnel/tunnel.go b/handler/tunnel/tunnel.go index e287efa..a45392a 100644 --- a/handler/tunnel/tunnel.go +++ b/handler/tunnel/tunnel.go @@ -177,6 +177,11 @@ func parseTunnelID(s string) (tid relay.TunnelID) { } func getTunnelConn(network string, pool *ConnectorPool, tid relay.TunnelID, retry int, log logger.Logger) (conn net.Conn, cid relay.ConnectorID, err error) { + if tid.IsZero() { + err = ErrTunnelID + return + } + if retry <= 0 { retry = 1 } diff --git a/internal/util/forward/forward.go b/internal/util/forward/forward.go index 42a4b93..fe459c6 100644 --- a/internal/util/forward/forward.go +++ b/internal/util/forward/forward.go @@ -38,6 +38,10 @@ func Sniffing(ctx context.Context, rdw io.ReadWriter) (rw io.ReadWriter, host st if err == nil { host = r.Host protocol = ProtoHTTP + + if r.Header.Get("Upgrade") == "websocket" { + protocol = ProtoWebsocket + } return } } @@ -89,9 +93,10 @@ func isHTTP(s string) bool { } const ( - ProtoHTTP = "http" - ProtoTLS = "tls" - ProtoSSHv2 = "SSH-2" + ProtoHTTP = "http" + ProtoWebsocket = "ws" + ProtoTLS = "tls" + ProtoSSHv2 = "SSH-2" ) func sniffProtocol(hdr []byte) string {