add timeout for sniffing

This commit is contained in:
ginuerzh
2023-05-21 15:47:51 +08:00
parent 30c705ffe5
commit 46db8480fa
17 changed files with 74 additions and 48 deletions

View File

@ -93,8 +93,14 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand
var host string
var protocol string
if network == "tcp" && h.md.sniffing {
if h.md.sniffingTimeout > 0 {
conn.SetReadDeadline(time.Now().Add(h.md.sniffingTimeout))
}
rw, host, protocol, _ = forward.Sniffing(ctx, conn)
log.Debugf("sniffing: host=%s, protocol=%s", host, protocol)
if h.md.sniffingTimeout > 0 {
conn.SetReadDeadline(time.Time{})
}
}
if protocol == forward.ProtoHTTP {
@ -152,11 +158,11 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand
}
t := time.Now()
log.Debugf("%s <-> %s", conn.RemoteAddr(), target.Addr)
log.Infof("%s <-> %s", conn.RemoteAddr(), target.Addr)
xnet.Transport(rw, cc)
log.WithFields(map[string]any{
"duration": time.Since(t),
}).Debugf("%s >-< %s", conn.RemoteAddr(), target.Addr)
}).Infof("%s >-< %s", conn.RemoteAddr(), target.Addr)
return nil
}

View File

@ -8,8 +8,9 @@ import (
)
type metadata struct {
readTimeout time.Duration
sniffing bool
readTimeout time.Duration
sniffing bool
sniffingTimeout time.Duration
}
func (h *forwardHandler) parseMetadata(md mdata.Metadata) (err error) {
@ -20,5 +21,6 @@ func (h *forwardHandler) parseMetadata(md mdata.Metadata) (err error) {
h.md.readTimeout = mdutil.GetDuration(md, readTimeout)
h.md.sniffing = mdutil.GetBool(md, sniffing)
h.md.sniffingTimeout = mdutil.GetDuration(md, "sniffing.timeout")
return
}

View File

@ -93,8 +93,14 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand
var host string
var protocol string
if network == "tcp" && h.md.sniffing {
if h.md.sniffingTimeout > 0 {
conn.SetReadDeadline(time.Now().Add(h.md.sniffingTimeout))
}
rw, host, protocol, _ = forward.Sniffing(ctx, conn)
log.Debugf("sniffing: host=%s, protocol=%s", host, protocol)
if h.md.sniffingTimeout > 0 {
conn.SetReadDeadline(time.Time{})
}
}
if protocol == forward.ProtoHTTP {
h.handleHTTP(ctx, rw, log)
@ -148,11 +154,11 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand
}
t := time.Now()
log.Debugf("%s <-> %s", conn.RemoteAddr(), target.Addr)
log.Infof("%s <-> %s", conn.RemoteAddr(), target.Addr)
xnet.Transport(rw, cc)
log.WithFields(map[string]any{
"duration": time.Since(t),
}).Debugf("%s >-< %s", conn.RemoteAddr(), target.Addr)
}).Infof("%s >-< %s", conn.RemoteAddr(), target.Addr)
return nil
}

View File

@ -8,8 +8,9 @@ import (
)
type metadata struct {
readTimeout time.Duration
sniffing bool
readTimeout time.Duration
sniffing bool
sniffingTimeout time.Duration
}
func (h *forwardHandler) parseMetadata(md mdata.Metadata) (err error) {
@ -20,5 +21,6 @@ func (h *forwardHandler) parseMetadata(md mdata.Metadata) (err error) {
h.md.readTimeout = mdutil.GetDuration(md, readTimeout)
h.md.sniffing = mdutil.GetBool(md, sniffing)
h.md.sniffingTimeout = mdutil.GetDuration(md, "sniffing.timeout")
return
}

View File

@ -218,11 +218,11 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt
}
start := time.Now()
log.Debugf("%s <-> %s", conn.RemoteAddr(), addr)
log.Infof("%s <-> %s", conn.RemoteAddr(), addr)
netpkg.Transport(conn, cc)
log.WithFields(map[string]any{
"duration": time.Since(start),
}).Debugf("%s >-< %s", conn.RemoteAddr(), addr)
}).Infof("%s >-< %s", conn.RemoteAddr(), addr)
return nil
}

View File

@ -191,21 +191,21 @@ func (h *http2Handler) roundTrip(ctx context.Context, w http.ResponseWriter, req
defer conn.Close()
start := time.Now()
log.Debugf("%s <-> %s", conn.RemoteAddr(), addr)
log.Infof("%s <-> %s", conn.RemoteAddr(), addr)
netpkg.Transport(conn, cc)
log.WithFields(map[string]any{
"duration": time.Since(start),
}).Debugf("%s >-< %s", conn.RemoteAddr(), addr)
}).Infof("%s >-< %s", conn.RemoteAddr(), addr)
return nil
}
start := time.Now()
log.Debugf("%s <-> %s", req.RemoteAddr, addr)
log.Infof("%s <-> %s", req.RemoteAddr, addr)
netpkg.Transport(xio.NewReadWriter(req.Body, flushWriter{w}), cc)
log.WithFields(map[string]any{
"duration": time.Since(start),
}).Debugf("%s >-< %s", req.RemoteAddr, addr)
}).Infof("%s >-< %s", req.RemoteAddr, addr)
return nil
}

View File

@ -98,10 +98,16 @@ func (h *redirectHandler) Handle(ctx context.Context, conn net.Conn, opts ...han
var rw io.ReadWriter = conn
if h.md.sniffing {
if h.md.sniffingTimeout > 0 {
conn.SetReadDeadline(time.Now().Add(h.md.sniffingTimeout))
}
// try to sniff TLS traffic
var hdr [dissector.RecordHeaderLen]byte
_, err := io.ReadFull(rw, hdr[:])
rw = xio.NewReadWriter(io.MultiReader(bytes.NewReader(hdr[:]), rw), rw)
n, err := io.ReadFull(rw, hdr[:])
if h.md.sniffingTimeout > 0 {
conn.SetReadDeadline(time.Time{})
}
rw = xio.NewReadWriter(io.MultiReader(bytes.NewReader(hdr[:n]), rw), rw)
if err == nil &&
hdr[0] == dissector.Handshake &&
binary.BigEndian.Uint16(hdr[1:3]) == tls.VersionTLS10 {
@ -129,11 +135,11 @@ func (h *redirectHandler) Handle(ctx context.Context, conn net.Conn, opts ...han
defer cc.Close()
t := time.Now()
log.Debugf("%s <-> %s", conn.RemoteAddr(), dstAddr)
log.Infof("%s <-> %s", conn.RemoteAddr(), dstAddr)
netpkg.Transport(rw, cc)
log.WithFields(map[string]any{
"duration": time.Since(t),
}).Debugf("%s >-< %s", conn.RemoteAddr(), dstAddr)
}).Infof("%s >-< %s", conn.RemoteAddr(), dstAddr)
return nil
}
@ -170,11 +176,11 @@ func (h *redirectHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, radd
defer cc.Close()
t := time.Now()
log.Debugf("%s <-> %s", raddr, host)
log.Infof("%s <-> %s", raddr, host)
defer func() {
log.WithFields(map[string]any{
"duration": time.Since(t),
}).Debugf("%s >-< %s", raddr, host)
}).Infof("%s >-< %s", raddr, host)
}()
if err := req.Write(cc); err != nil {
@ -239,11 +245,11 @@ func (h *redirectHandler) handleHTTPS(ctx context.Context, rw io.ReadWriter, rad
defer cc.Close()
t := time.Now()
log.Debugf("%s <-> %s", raddr, host)
log.Infof("%s <-> %s", raddr, host)
netpkg.Transport(xio.NewReadWriter(io.MultiReader(buf, rw), rw), cc)
log.WithFields(map[string]any{
"duration": time.Since(t),
}).Debugf("%s >-< %s", raddr, host)
}).Infof("%s >-< %s", raddr, host)
return nil
}

View File

@ -1,21 +1,25 @@
package redirect
import (
"time"
mdata "github.com/go-gost/core/metadata"
mdutil "github.com/go-gost/core/metadata/util"
)
type metadata struct {
sniffing bool
tproxy bool
tproxy bool
sniffing bool
sniffingTimeout time.Duration
}
func (h *redirectHandler) parseMetadata(md mdata.Metadata) (err error) {
const (
sniffing = "sniffing"
tproxy = "tproxy"
sniffing = "sniffing"
)
h.md.sniffing = mdutil.GetBool(md, sniffing)
h.md.tproxy = mdutil.GetBool(md, tproxy)
h.md.sniffing = mdutil.GetBool(md, sniffing)
h.md.sniffingTimeout = mdutil.GetDuration(md, "sniffing.timeout")
return
}

View File

@ -88,11 +88,11 @@ func (h *redirectHandler) Handle(ctx context.Context, conn net.Conn, opts ...han
defer cc.Close()
t := time.Now()
log.Debugf("%s <-> %s", conn.RemoteAddr(), dstAddr)
log.Infof("%s <-> %s", conn.RemoteAddr(), dstAddr)
netpkg.Transport(conn, cc)
log.WithFields(map[string]any{
"duration": time.Since(t),
}).Debugf("%s >-< %s", conn.RemoteAddr(), dstAddr)
}).Infof("%s >-< %s", conn.RemoteAddr(), dstAddr)
return nil
}

View File

@ -87,11 +87,11 @@ func (h *relayHandler) handleConnect(ctx context.Context, conn net.Conn, network
}
t := time.Now()
log.Debugf("%s <-> %s", conn.RemoteAddr(), address)
log.Infof("%s <-> %s", conn.RemoteAddr(), address)
xnet.Transport(conn, cc)
log.WithFields(map[string]any{
"duration": time.Since(t),
}).Debugf("%s >-< %s", conn.RemoteAddr(), address)
}).Infof("%s >-< %s", conn.RemoteAddr(), address)
return nil
}

View File

@ -133,11 +133,11 @@ func (h *sniHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, raddr net
defer cc.Close()
t := time.Now()
log.Debugf("%s <-> %s", raddr, host)
log.Infof("%s <-> %s", raddr, host)
defer func() {
log.WithFields(map[string]any{
"duration": time.Since(t),
}).Debugf("%s >-< %s", raddr, host)
}).Infof("%s >-< %s", raddr, host)
}()
if err := req.Write(cc); err != nil {
@ -201,11 +201,11 @@ func (h *sniHandler) handleHTTPS(ctx context.Context, rw io.ReadWriter, raddr ne
defer cc.Close()
t := time.Now()
log.Debugf("%s <-> %s", raddr, host)
log.Infof("%s <-> %s", raddr, host)
netpkg.Transport(xio.NewReadWriter(io.MultiReader(buf, rw), rw), cc)
log.WithFields(map[string]any{
"duration": time.Since(t),
}).Debugf("%s >-< %s", raddr, host)
}).Infof("%s >-< %s", raddr, host)
return nil
}

View File

@ -147,11 +147,11 @@ func (h *socks4Handler) handleConnect(ctx context.Context, conn net.Conn, req *g
}
t := time.Now()
log.Debugf("%s <-> %s", conn.RemoteAddr(), addr)
log.Infof("%s <-> %s", conn.RemoteAddr(), addr)
netpkg.Transport(conn, cc)
log.WithFields(map[string]any{
"duration": time.Since(t),
}).Debugf("%s >-< %s", conn.RemoteAddr(), addr)
}).Infof("%s >-< %s", conn.RemoteAddr(), addr)
return nil
}

View File

@ -49,11 +49,11 @@ func (h *socks5Handler) handleConnect(ctx context.Context, conn net.Conn, networ
}
t := time.Now()
log.Debugf("%s <-> %s", conn.RemoteAddr(), address)
log.Infof("%s <-> %s", conn.RemoteAddr(), address)
netpkg.Transport(conn, cc)
log.WithFields(map[string]any{
"duration": time.Since(t),
}).Debugf("%s >-< %s", conn.RemoteAddr(), address)
}).Infof("%s >-< %s", conn.RemoteAddr(), address)
return nil
}

View File

@ -119,11 +119,11 @@ func (h *ssHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.H
defer cc.Close()
t := time.Now()
log.Debugf("%s <-> %s", conn.RemoteAddr(), addr)
log.Infof("%s <-> %s", conn.RemoteAddr(), addr)
netpkg.Transport(conn, cc)
log.WithFields(map[string]any{
"duration": time.Since(t),
}).Debugf("%s >-< %s", conn.RemoteAddr(), addr)
}).Infof("%s >-< %s", conn.RemoteAddr(), addr)
return nil
}

View File

@ -112,10 +112,10 @@ func (h *ssuHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.
}
t := time.Now()
log.Debugf("%s <-> %s", conn.LocalAddr(), cc.LocalAddr())
log.Infof("%s <-> %s", conn.LocalAddr(), cc.LocalAddr())
h.relayPacket(pc, cc, log)
log.WithFields(map[string]any{"duration": time.Since(t)}).
Debugf("%s >-< %s", conn.LocalAddr(), cc.LocalAddr())
Infof("%s >-< %s", conn.LocalAddr(), cc.LocalAddr())
return nil
}

View File

@ -104,11 +104,11 @@ func (h *forwardHandler) handleDirectForward(ctx context.Context, conn *sshd_uti
defer cc.Close()
t := time.Now()
log.Debugf("%s <-> %s", cc.LocalAddr(), targetAddr)
log.Infof("%s <-> %s", cc.LocalAddr(), targetAddr)
netpkg.Transport(conn, cc)
log.WithFields(map[string]any{
"duration": time.Since(t),
}).Debugf("%s >-< %s", cc.LocalAddr(), targetAddr)
}).Infof("%s >-< %s", cc.LocalAddr(), targetAddr)
return nil
}
@ -212,11 +212,11 @@ func (h *forwardHandler) handleRemoteForward(ctx context.Context, conn *sshd_uti
}()
tm := time.Now()
log.Debugf("%s <-> %s", conn.RemoteAddr(), addr)
log.Infof("%s <-> %s", conn.RemoteAddr(), addr)
<-conn.Done()
log.WithFields(map[string]any{
"duration": time.Since(tm),
}).Debugf("%s >-< %s", conn.RemoteAddr(), addr)
}).Infof("%s >-< %s", conn.RemoteAddr(), addr)
return nil
}