add timeout for sniffing

This commit is contained in:
ginuerzh
2023-05-21 15:47:51 +08:00
parent 30c705ffe5
commit 46db8480fa
17 changed files with 74 additions and 48 deletions

View File

@ -93,8 +93,14 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand
var host string
var protocol string
if network == "tcp" && h.md.sniffing {
if h.md.sniffingTimeout > 0 {
conn.SetReadDeadline(time.Now().Add(h.md.sniffingTimeout))
}
rw, host, protocol, _ = forward.Sniffing(ctx, conn)
log.Debugf("sniffing: host=%s, protocol=%s", host, protocol)
if h.md.sniffingTimeout > 0 {
conn.SetReadDeadline(time.Time{})
}
}
if protocol == forward.ProtoHTTP {
@ -152,11 +158,11 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand
}
t := time.Now()
log.Debugf("%s <-> %s", conn.RemoteAddr(), target.Addr)
log.Infof("%s <-> %s", conn.RemoteAddr(), target.Addr)
xnet.Transport(rw, cc)
log.WithFields(map[string]any{
"duration": time.Since(t),
}).Debugf("%s >-< %s", conn.RemoteAddr(), target.Addr)
}).Infof("%s >-< %s", conn.RemoteAddr(), target.Addr)
return nil
}

View File

@ -8,8 +8,9 @@ import (
)
type metadata struct {
readTimeout time.Duration
sniffing bool
readTimeout time.Duration
sniffing bool
sniffingTimeout time.Duration
}
func (h *forwardHandler) parseMetadata(md mdata.Metadata) (err error) {
@ -20,5 +21,6 @@ func (h *forwardHandler) parseMetadata(md mdata.Metadata) (err error) {
h.md.readTimeout = mdutil.GetDuration(md, readTimeout)
h.md.sniffing = mdutil.GetBool(md, sniffing)
h.md.sniffingTimeout = mdutil.GetDuration(md, "sniffing.timeout")
return
}