update forward handler

This commit is contained in:
ginuerzh
2023-10-16 23:16:47 +08:00
parent 5ab729b166
commit 5dfbb59f8a
17 changed files with 253 additions and 174 deletions

View File

@ -116,7 +116,7 @@ func (h *redirectHandler) Handle(ctx context.Context, conn net.Conn, opts ...han
// try to sniff HTTP traffic
if isHTTP(string(hdr[:])) {
return h.handleHTTP(ctx, rw, conn.RemoteAddr(), log)
return h.handleHTTP(ctx, rw, conn.RemoteAddr(), dstAddr, log)
}
}
@ -144,7 +144,7 @@ func (h *redirectHandler) Handle(ctx context.Context, conn net.Conn, opts ...han
return nil
}
func (h *redirectHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, raddr net.Addr, log logger.Logger) error {
func (h *redirectHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, raddr, dstAddr net.Addr, log logger.Logger) error {
req, err := http.ReadRequest(bufio.NewReader(rw))
if err != nil {
return err
@ -171,7 +171,14 @@ func (h *redirectHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, radd
cc, err := h.router.Dial(ctx, "tcp", host)
if err != nil {
log.Error(err)
return err
}
if cc == nil {
cc, err = h.router.Dial(ctx, "tcp", dstAddr.String())
if err != nil {
log.Error(err)
return err
}
}
defer cc.Close()
@ -216,9 +223,10 @@ func (h *redirectHandler) handleHTTPS(ctx context.Context, rw io.ReadWriter, rad
log.Error(err)
return err
}
if host == "" {
host = dstAddr.String()
} else {
var cc io.ReadWriteCloser
if host != "" {
if _, _, err := net.SplitHostPort(host); err != nil {
_, port, _ := net.SplitHostPort(dstAddr.String())
if port == "" {
@ -226,21 +234,27 @@ func (h *redirectHandler) handleHTTPS(ctx context.Context, rw io.ReadWriter, rad
}
host = net.JoinHostPort(host, port)
}
log = log.WithFields(map[string]any{
"host": host,
})
if h.options.Bypass != nil && h.options.Bypass.Contains(ctx, "tcp", host) {
log.Debug("bypass: ", host)
return nil
}
cc, err = h.router.Dial(ctx, "tcp", host)
if err != nil {
log.Error(err)
}
}
log = log.WithFields(map[string]any{
"host": host,
})
if h.options.Bypass != nil && h.options.Bypass.Contains(ctx, "tcp", host) {
log.Debug("bypass: ", host)
return nil
}
cc, err := h.router.Dial(ctx, "tcp", host)
if err != nil {
log.Error(err)
return err
if cc == nil {
cc, err = h.router.Dial(ctx, "tcp", dstAddr.String())
if err != nil {
log.Error(err)
return err
}
}
defer cc.Close()