add timeout for sniffing

This commit is contained in:
ginuerzh
2023-05-21 15:47:51 +08:00
parent 30c705ffe5
commit 46db8480fa
17 changed files with 74 additions and 48 deletions

View File

@ -98,10 +98,16 @@ func (h *redirectHandler) Handle(ctx context.Context, conn net.Conn, opts ...han
var rw io.ReadWriter = conn
if h.md.sniffing {
if h.md.sniffingTimeout > 0 {
conn.SetReadDeadline(time.Now().Add(h.md.sniffingTimeout))
}
// try to sniff TLS traffic
var hdr [dissector.RecordHeaderLen]byte
_, err := io.ReadFull(rw, hdr[:])
rw = xio.NewReadWriter(io.MultiReader(bytes.NewReader(hdr[:]), rw), rw)
n, err := io.ReadFull(rw, hdr[:])
if h.md.sniffingTimeout > 0 {
conn.SetReadDeadline(time.Time{})
}
rw = xio.NewReadWriter(io.MultiReader(bytes.NewReader(hdr[:n]), rw), rw)
if err == nil &&
hdr[0] == dissector.Handshake &&
binary.BigEndian.Uint16(hdr[1:3]) == tls.VersionTLS10 {
@ -129,11 +135,11 @@ func (h *redirectHandler) Handle(ctx context.Context, conn net.Conn, opts ...han
defer cc.Close()
t := time.Now()
log.Debugf("%s <-> %s", conn.RemoteAddr(), dstAddr)
log.Infof("%s <-> %s", conn.RemoteAddr(), dstAddr)
netpkg.Transport(rw, cc)
log.WithFields(map[string]any{
"duration": time.Since(t),
}).Debugf("%s >-< %s", conn.RemoteAddr(), dstAddr)
}).Infof("%s >-< %s", conn.RemoteAddr(), dstAddr)
return nil
}
@ -170,11 +176,11 @@ func (h *redirectHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, radd
defer cc.Close()
t := time.Now()
log.Debugf("%s <-> %s", raddr, host)
log.Infof("%s <-> %s", raddr, host)
defer func() {
log.WithFields(map[string]any{
"duration": time.Since(t),
}).Debugf("%s >-< %s", raddr, host)
}).Infof("%s >-< %s", raddr, host)
}()
if err := req.Write(cc); err != nil {
@ -239,11 +245,11 @@ func (h *redirectHandler) handleHTTPS(ctx context.Context, rw io.ReadWriter, rad
defer cc.Close()
t := time.Now()
log.Debugf("%s <-> %s", raddr, host)
log.Infof("%s <-> %s", raddr, host)
netpkg.Transport(xio.NewReadWriter(io.MultiReader(buf, rw), rw), cc)
log.WithFields(map[string]any{
"duration": time.Since(t),
}).Debugf("%s >-< %s", raddr, host)
}).Infof("%s >-< %s", raddr, host)
return nil
}

View File

@ -1,21 +1,25 @@
package redirect
import (
"time"
mdata "github.com/go-gost/core/metadata"
mdutil "github.com/go-gost/core/metadata/util"
)
type metadata struct {
sniffing bool
tproxy bool
tproxy bool
sniffing bool
sniffingTimeout time.Duration
}
func (h *redirectHandler) parseMetadata(md mdata.Metadata) (err error) {
const (
sniffing = "sniffing"
tproxy = "tproxy"
sniffing = "sniffing"
)
h.md.sniffing = mdutil.GetBool(md, sniffing)
h.md.tproxy = mdutil.GetBool(md, tproxy)
h.md.sniffing = mdutil.GetBool(md, sniffing)
h.md.sniffingTimeout = mdutil.GetDuration(md, "sniffing.timeout")
return
}