diff --git a/handler/forward/local/handler.go b/handler/forward/local/handler.go index 5ec1809..08b2f47 100644 --- a/handler/forward/local/handler.go +++ b/handler/forward/local/handler.go @@ -93,8 +93,14 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand var host string var protocol string if network == "tcp" && h.md.sniffing { + if h.md.sniffingTimeout > 0 { + conn.SetReadDeadline(time.Now().Add(h.md.sniffingTimeout)) + } rw, host, protocol, _ = forward.Sniffing(ctx, conn) log.Debugf("sniffing: host=%s, protocol=%s", host, protocol) + if h.md.sniffingTimeout > 0 { + conn.SetReadDeadline(time.Time{}) + } } if protocol == forward.ProtoHTTP { @@ -152,11 +158,11 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand } t := time.Now() - log.Debugf("%s <-> %s", conn.RemoteAddr(), target.Addr) + log.Infof("%s <-> %s", conn.RemoteAddr(), target.Addr) xnet.Transport(rw, cc) log.WithFields(map[string]any{ "duration": time.Since(t), - }).Debugf("%s >-< %s", conn.RemoteAddr(), target.Addr) + }).Infof("%s >-< %s", conn.RemoteAddr(), target.Addr) return nil } diff --git a/handler/forward/local/metadata.go b/handler/forward/local/metadata.go index 461585e..7b24fab 100644 --- a/handler/forward/local/metadata.go +++ b/handler/forward/local/metadata.go @@ -8,8 +8,9 @@ import ( ) type metadata struct { - readTimeout time.Duration - sniffing bool + readTimeout time.Duration + sniffing bool + sniffingTimeout time.Duration } func (h *forwardHandler) parseMetadata(md mdata.Metadata) (err error) { @@ -20,5 +21,6 @@ func (h *forwardHandler) parseMetadata(md mdata.Metadata) (err error) { h.md.readTimeout = mdutil.GetDuration(md, readTimeout) h.md.sniffing = mdutil.GetBool(md, sniffing) + h.md.sniffingTimeout = mdutil.GetDuration(md, "sniffing.timeout") return } diff --git a/handler/forward/remote/handler.go b/handler/forward/remote/handler.go index a8f8869..d9297a4 100644 --- a/handler/forward/remote/handler.go +++ b/handler/forward/remote/handler.go @@ -93,8 +93,14 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand var host string var protocol string if network == "tcp" && h.md.sniffing { + if h.md.sniffingTimeout > 0 { + conn.SetReadDeadline(time.Now().Add(h.md.sniffingTimeout)) + } rw, host, protocol, _ = forward.Sniffing(ctx, conn) log.Debugf("sniffing: host=%s, protocol=%s", host, protocol) + if h.md.sniffingTimeout > 0 { + conn.SetReadDeadline(time.Time{}) + } } if protocol == forward.ProtoHTTP { h.handleHTTP(ctx, rw, log) @@ -148,11 +154,11 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand } t := time.Now() - log.Debugf("%s <-> %s", conn.RemoteAddr(), target.Addr) + log.Infof("%s <-> %s", conn.RemoteAddr(), target.Addr) xnet.Transport(rw, cc) log.WithFields(map[string]any{ "duration": time.Since(t), - }).Debugf("%s >-< %s", conn.RemoteAddr(), target.Addr) + }).Infof("%s >-< %s", conn.RemoteAddr(), target.Addr) return nil } diff --git a/handler/forward/remote/metadata.go b/handler/forward/remote/metadata.go index 41ec68f..ea903d8 100644 --- a/handler/forward/remote/metadata.go +++ b/handler/forward/remote/metadata.go @@ -8,8 +8,9 @@ import ( ) type metadata struct { - readTimeout time.Duration - sniffing bool + readTimeout time.Duration + sniffing bool + sniffingTimeout time.Duration } func (h *forwardHandler) parseMetadata(md mdata.Metadata) (err error) { @@ -20,5 +21,6 @@ func (h *forwardHandler) parseMetadata(md mdata.Metadata) (err error) { h.md.readTimeout = mdutil.GetDuration(md, readTimeout) h.md.sniffing = mdutil.GetBool(md, sniffing) + h.md.sniffingTimeout = mdutil.GetDuration(md, "sniffing.timeout") return } diff --git a/handler/http/handler.go b/handler/http/handler.go index e5ea6ee..9980dbc 100644 --- a/handler/http/handler.go +++ b/handler/http/handler.go @@ -218,11 +218,11 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt } start := time.Now() - log.Debugf("%s <-> %s", conn.RemoteAddr(), addr) + log.Infof("%s <-> %s", conn.RemoteAddr(), addr) netpkg.Transport(conn, cc) log.WithFields(map[string]any{ "duration": time.Since(start), - }).Debugf("%s >-< %s", conn.RemoteAddr(), addr) + }).Infof("%s >-< %s", conn.RemoteAddr(), addr) return nil } diff --git a/handler/http2/handler.go b/handler/http2/handler.go index 6e4fafe..6b1bc19 100644 --- a/handler/http2/handler.go +++ b/handler/http2/handler.go @@ -191,21 +191,21 @@ func (h *http2Handler) roundTrip(ctx context.Context, w http.ResponseWriter, req defer conn.Close() start := time.Now() - log.Debugf("%s <-> %s", conn.RemoteAddr(), addr) + log.Infof("%s <-> %s", conn.RemoteAddr(), addr) netpkg.Transport(conn, cc) log.WithFields(map[string]any{ "duration": time.Since(start), - }).Debugf("%s >-< %s", conn.RemoteAddr(), addr) + }).Infof("%s >-< %s", conn.RemoteAddr(), addr) return nil } start := time.Now() - log.Debugf("%s <-> %s", req.RemoteAddr, addr) + log.Infof("%s <-> %s", req.RemoteAddr, addr) netpkg.Transport(xio.NewReadWriter(req.Body, flushWriter{w}), cc) log.WithFields(map[string]any{ "duration": time.Since(start), - }).Debugf("%s >-< %s", req.RemoteAddr, addr) + }).Infof("%s >-< %s", req.RemoteAddr, addr) return nil } diff --git a/handler/redirect/tcp/handler.go b/handler/redirect/tcp/handler.go index 604162e..d51e73b 100644 --- a/handler/redirect/tcp/handler.go +++ b/handler/redirect/tcp/handler.go @@ -98,10 +98,16 @@ func (h *redirectHandler) Handle(ctx context.Context, conn net.Conn, opts ...han var rw io.ReadWriter = conn if h.md.sniffing { + if h.md.sniffingTimeout > 0 { + conn.SetReadDeadline(time.Now().Add(h.md.sniffingTimeout)) + } // try to sniff TLS traffic var hdr [dissector.RecordHeaderLen]byte - _, err := io.ReadFull(rw, hdr[:]) - rw = xio.NewReadWriter(io.MultiReader(bytes.NewReader(hdr[:]), rw), rw) + n, err := io.ReadFull(rw, hdr[:]) + if h.md.sniffingTimeout > 0 { + conn.SetReadDeadline(time.Time{}) + } + rw = xio.NewReadWriter(io.MultiReader(bytes.NewReader(hdr[:n]), rw), rw) if err == nil && hdr[0] == dissector.Handshake && binary.BigEndian.Uint16(hdr[1:3]) == tls.VersionTLS10 { @@ -129,11 +135,11 @@ func (h *redirectHandler) Handle(ctx context.Context, conn net.Conn, opts ...han defer cc.Close() t := time.Now() - log.Debugf("%s <-> %s", conn.RemoteAddr(), dstAddr) + log.Infof("%s <-> %s", conn.RemoteAddr(), dstAddr) netpkg.Transport(rw, cc) log.WithFields(map[string]any{ "duration": time.Since(t), - }).Debugf("%s >-< %s", conn.RemoteAddr(), dstAddr) + }).Infof("%s >-< %s", conn.RemoteAddr(), dstAddr) return nil } @@ -170,11 +176,11 @@ func (h *redirectHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, radd defer cc.Close() t := time.Now() - log.Debugf("%s <-> %s", raddr, host) + log.Infof("%s <-> %s", raddr, host) defer func() { log.WithFields(map[string]any{ "duration": time.Since(t), - }).Debugf("%s >-< %s", raddr, host) + }).Infof("%s >-< %s", raddr, host) }() if err := req.Write(cc); err != nil { @@ -239,11 +245,11 @@ func (h *redirectHandler) handleHTTPS(ctx context.Context, rw io.ReadWriter, rad defer cc.Close() t := time.Now() - log.Debugf("%s <-> %s", raddr, host) + log.Infof("%s <-> %s", raddr, host) netpkg.Transport(xio.NewReadWriter(io.MultiReader(buf, rw), rw), cc) log.WithFields(map[string]any{ "duration": time.Since(t), - }).Debugf("%s >-< %s", raddr, host) + }).Infof("%s >-< %s", raddr, host) return nil } diff --git a/handler/redirect/tcp/metadata.go b/handler/redirect/tcp/metadata.go index cc557f5..1e319d7 100644 --- a/handler/redirect/tcp/metadata.go +++ b/handler/redirect/tcp/metadata.go @@ -1,21 +1,25 @@ package redirect import ( + "time" + mdata "github.com/go-gost/core/metadata" mdutil "github.com/go-gost/core/metadata/util" ) type metadata struct { - sniffing bool - tproxy bool + tproxy bool + sniffing bool + sniffingTimeout time.Duration } func (h *redirectHandler) parseMetadata(md mdata.Metadata) (err error) { const ( - sniffing = "sniffing" tproxy = "tproxy" + sniffing = "sniffing" ) - h.md.sniffing = mdutil.GetBool(md, sniffing) h.md.tproxy = mdutil.GetBool(md, tproxy) + h.md.sniffing = mdutil.GetBool(md, sniffing) + h.md.sniffingTimeout = mdutil.GetDuration(md, "sniffing.timeout") return } diff --git a/handler/redirect/udp/handler.go b/handler/redirect/udp/handler.go index 2507d11..99bb623 100644 --- a/handler/redirect/udp/handler.go +++ b/handler/redirect/udp/handler.go @@ -88,11 +88,11 @@ func (h *redirectHandler) Handle(ctx context.Context, conn net.Conn, opts ...han defer cc.Close() t := time.Now() - log.Debugf("%s <-> %s", conn.RemoteAddr(), dstAddr) + log.Infof("%s <-> %s", conn.RemoteAddr(), dstAddr) netpkg.Transport(conn, cc) log.WithFields(map[string]any{ "duration": time.Since(t), - }).Debugf("%s >-< %s", conn.RemoteAddr(), dstAddr) + }).Infof("%s >-< %s", conn.RemoteAddr(), dstAddr) return nil } diff --git a/handler/relay/connect.go b/handler/relay/connect.go index 4738dc5..04efd3d 100644 --- a/handler/relay/connect.go +++ b/handler/relay/connect.go @@ -87,11 +87,11 @@ func (h *relayHandler) handleConnect(ctx context.Context, conn net.Conn, network } t := time.Now() - log.Debugf("%s <-> %s", conn.RemoteAddr(), address) + log.Infof("%s <-> %s", conn.RemoteAddr(), address) xnet.Transport(conn, cc) log.WithFields(map[string]any{ "duration": time.Since(t), - }).Debugf("%s >-< %s", conn.RemoteAddr(), address) + }).Infof("%s >-< %s", conn.RemoteAddr(), address) return nil } diff --git a/handler/sni/handler.go b/handler/sni/handler.go index 4a8da2e..d0edb04 100644 --- a/handler/sni/handler.go +++ b/handler/sni/handler.go @@ -133,11 +133,11 @@ func (h *sniHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, raddr net defer cc.Close() t := time.Now() - log.Debugf("%s <-> %s", raddr, host) + log.Infof("%s <-> %s", raddr, host) defer func() { log.WithFields(map[string]any{ "duration": time.Since(t), - }).Debugf("%s >-< %s", raddr, host) + }).Infof("%s >-< %s", raddr, host) }() if err := req.Write(cc); err != nil { @@ -201,11 +201,11 @@ func (h *sniHandler) handleHTTPS(ctx context.Context, rw io.ReadWriter, raddr ne defer cc.Close() t := time.Now() - log.Debugf("%s <-> %s", raddr, host) + log.Infof("%s <-> %s", raddr, host) netpkg.Transport(xio.NewReadWriter(io.MultiReader(buf, rw), rw), cc) log.WithFields(map[string]any{ "duration": time.Since(t), - }).Debugf("%s >-< %s", raddr, host) + }).Infof("%s >-< %s", raddr, host) return nil } diff --git a/handler/socks/v4/handler.go b/handler/socks/v4/handler.go index 8613d0f..42412d3 100644 --- a/handler/socks/v4/handler.go +++ b/handler/socks/v4/handler.go @@ -147,11 +147,11 @@ func (h *socks4Handler) handleConnect(ctx context.Context, conn net.Conn, req *g } t := time.Now() - log.Debugf("%s <-> %s", conn.RemoteAddr(), addr) + log.Infof("%s <-> %s", conn.RemoteAddr(), addr) netpkg.Transport(conn, cc) log.WithFields(map[string]any{ "duration": time.Since(t), - }).Debugf("%s >-< %s", conn.RemoteAddr(), addr) + }).Infof("%s >-< %s", conn.RemoteAddr(), addr) return nil } diff --git a/handler/socks/v5/connect.go b/handler/socks/v5/connect.go index e22c89c..da6b9d9 100644 --- a/handler/socks/v5/connect.go +++ b/handler/socks/v5/connect.go @@ -49,11 +49,11 @@ func (h *socks5Handler) handleConnect(ctx context.Context, conn net.Conn, networ } t := time.Now() - log.Debugf("%s <-> %s", conn.RemoteAddr(), address) + log.Infof("%s <-> %s", conn.RemoteAddr(), address) netpkg.Transport(conn, cc) log.WithFields(map[string]any{ "duration": time.Since(t), - }).Debugf("%s >-< %s", conn.RemoteAddr(), address) + }).Infof("%s >-< %s", conn.RemoteAddr(), address) return nil } diff --git a/handler/ss/handler.go b/handler/ss/handler.go index 61eab35..4ee10c7 100644 --- a/handler/ss/handler.go +++ b/handler/ss/handler.go @@ -119,11 +119,11 @@ func (h *ssHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.H defer cc.Close() t := time.Now() - log.Debugf("%s <-> %s", conn.RemoteAddr(), addr) + log.Infof("%s <-> %s", conn.RemoteAddr(), addr) netpkg.Transport(conn, cc) log.WithFields(map[string]any{ "duration": time.Since(t), - }).Debugf("%s >-< %s", conn.RemoteAddr(), addr) + }).Infof("%s >-< %s", conn.RemoteAddr(), addr) return nil } diff --git a/handler/ss/udp/handler.go b/handler/ss/udp/handler.go index aea90f5..a6332a4 100644 --- a/handler/ss/udp/handler.go +++ b/handler/ss/udp/handler.go @@ -112,10 +112,10 @@ func (h *ssuHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler. } t := time.Now() - log.Debugf("%s <-> %s", conn.LocalAddr(), cc.LocalAddr()) + log.Infof("%s <-> %s", conn.LocalAddr(), cc.LocalAddr()) h.relayPacket(pc, cc, log) log.WithFields(map[string]any{"duration": time.Since(t)}). - Debugf("%s >-< %s", conn.LocalAddr(), cc.LocalAddr()) + Infof("%s >-< %s", conn.LocalAddr(), cc.LocalAddr()) return nil } diff --git a/handler/sshd/handler.go b/handler/sshd/handler.go index 0e550cc..1aa9d25 100644 --- a/handler/sshd/handler.go +++ b/handler/sshd/handler.go @@ -104,11 +104,11 @@ func (h *forwardHandler) handleDirectForward(ctx context.Context, conn *sshd_uti defer cc.Close() t := time.Now() - log.Debugf("%s <-> %s", cc.LocalAddr(), targetAddr) + log.Infof("%s <-> %s", cc.LocalAddr(), targetAddr) netpkg.Transport(conn, cc) log.WithFields(map[string]any{ "duration": time.Since(t), - }).Debugf("%s >-< %s", cc.LocalAddr(), targetAddr) + }).Infof("%s >-< %s", cc.LocalAddr(), targetAddr) return nil } @@ -212,11 +212,11 @@ func (h *forwardHandler) handleRemoteForward(ctx context.Context, conn *sshd_uti }() tm := time.Now() - log.Debugf("%s <-> %s", conn.RemoteAddr(), addr) + log.Infof("%s <-> %s", conn.RemoteAddr(), addr) <-conn.Done() log.WithFields(map[string]any{ "duration": time.Since(tm), - }).Debugf("%s >-< %s", conn.RemoteAddr(), addr) + }).Infof("%s >-< %s", conn.RemoteAddr(), addr) return nil } diff --git a/internal/util/forward/forward.go b/internal/util/forward/forward.go index 57d4a9d..42a4b93 100644 --- a/internal/util/forward/forward.go +++ b/internal/util/forward/forward.go @@ -19,8 +19,8 @@ func Sniffing(ctx context.Context, rdw io.ReadWriter) (rw io.ReadWriter, host st // try to sniff TLS traffic var hdr [dissector.RecordHeaderLen]byte - _, err = io.ReadFull(rw, hdr[:]) - rw = xio.NewReadWriter(io.MultiReader(bytes.NewReader(hdr[:]), rw), rw) + n, err := io.ReadFull(rw, hdr[:]) + rw = xio.NewReadWriter(io.MultiReader(bytes.NewReader(hdr[:n]), rw), rw) if err == nil && hdr[0] == dissector.Handshake && binary.BigEndian.Uint16(hdr[1:3]) == tls.VersionTLS10 {