From dca2a79c54d25d38af8d6753d3281fea2f29f4e4 Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Thu, 14 Apr 2022 14:53:13 +0800 Subject: [PATCH] update sni handler --- handler/http/handler.go | 2 +- handler/http/metadata.go | 3 - handler/sni/conn.go | 18 +---- handler/sni/handler.go | 163 +++++++++++++++++++++++---------------- 4 files changed, 102 insertions(+), 84 deletions(-) diff --git a/handler/http/handler.go b/handler/http/handler.go index 88ca5c4..87a7494 100644 --- a/handler/http/handler.go +++ b/handler/http/handler.go @@ -84,7 +84,7 @@ func (h *httpHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler } func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *http.Request, log logger.Logger) error { - if h.md.sni && !req.URL.IsAbs() && govalidator.IsDNSName(req.Host) { + if !req.URL.IsAbs() && govalidator.IsDNSName(req.Host) { req.URL.Scheme = "http" } diff --git a/handler/http/metadata.go b/handler/http/metadata.go index 1509796..d99a695 100644 --- a/handler/http/metadata.go +++ b/handler/http/metadata.go @@ -10,7 +10,6 @@ import ( type metadata struct { probeResistance *probeResistance - sni bool enableUDP bool header http.Header } @@ -20,7 +19,6 @@ func (h *httpHandler) parseMetadata(md mdata.Metadata) error { header = "header" probeResistKey = "probeResistance" knock = "knock" - sni = "sni" enableUDP = "udp" ) @@ -41,7 +39,6 @@ func (h *httpHandler) parseMetadata(md mdata.Metadata) error { } } } - h.md.sni = mdx.GetBool(md, sni) h.md.enableUDP = mdx.GetBool(md, enableUDP) return nil diff --git a/handler/sni/conn.go b/handler/sni/conn.go index 22d1035..0a6e09d 100644 --- a/handler/sni/conn.go +++ b/handler/sni/conn.go @@ -1,20 +1,10 @@ package sni import ( - "net" + "io" ) -type cacheConn struct { - net.Conn - buf []byte -} - -func (c *cacheConn) Read(b []byte) (n int, err error) { - if len(c.buf) > 0 { - n = copy(b, c.buf) - c.buf = c.buf[n:] - return - } - - return c.Conn.Read(b) +type readWriter struct { + io.Reader + io.Writer } diff --git a/handler/sni/handler.go b/handler/sni/handler.go index 918b49e..e66b2d5 100644 --- a/handler/sni/handler.go +++ b/handler/sni/handler.go @@ -1,19 +1,23 @@ package sni import ( + "bufio" "bytes" "context" + "crypto/tls" "encoding/base64" "encoding/binary" "errors" "hash/crc32" "io" "net" + "net/http" + "net/http/httputil" "time" "github.com/go-gost/core/chain" - "github.com/go-gost/core/common/bufpool" "github.com/go-gost/core/handler" + "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" dissector "github.com/go-gost/tls-dissector" netpkg "github.com/go-gost/x/internal/net" @@ -25,10 +29,9 @@ func init() { } type sniHandler struct { - httpHandler handler.Handler - router *chain.Router - md metadata - options handler.Options + router *chain.Router + md metadata + options handler.Options } func NewHandler(opts ...handler.Option) handler.Handler { @@ -41,12 +44,6 @@ func NewHandler(opts ...handler.Option) handler.Handler { options: options, } - if f := registry.HandlerRegistry().Get("http"); f != nil { - v := append(opts, - handler.LoggerOption(h.options.Logger.WithFields(map[string]any{"type": "http"}))) - h.httpHandler = f(v...) - } - return h } @@ -54,14 +51,6 @@ func (h *sniHandler) Init(md md.Metadata) (err error) { if err = h.parseMetadata(md); err != nil { return } - if h.httpHandler != nil { - if md != nil { - md.Set("sni", true) - } - if err = h.httpHandler.Init(md); err != nil { - return - } - } h.router = h.options.Router if h.router == nil { @@ -93,69 +82,121 @@ func (h *sniHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler. return err } - if hdr[0] != dissector.Handshake { - // We assume it is an HTTP request - conn = &cacheConn{ - Conn: conn, - buf: hdr[:], - } - - if h.httpHandler != nil { - return h.httpHandler.Handle(ctx, conn) - } - return nil + rw := &readWriter{ + Reader: io.MultiReader(bytes.NewReader(hdr[:]), conn), + Writer: conn, } - - length := binary.BigEndian.Uint16(hdr[3:5]) - - buf := bufpool.Get(int(length) + dissector.RecordHeaderLen) - defer bufpool.Put(buf) - if _, err := io.ReadFull(conn, (*buf)[dissector.RecordHeaderLen:]); err != nil { - log.Error(err) - return err + if hdr[0] == dissector.Handshake && + binary.BigEndian.Uint16(hdr[1:3]) == tls.VersionTLS10 { + return h.handleHTTPS(ctx, rw, conn.RemoteAddr(), log) } - copy(*buf, hdr[:]) + return h.handleHTTP(ctx, rw, conn.RemoteAddr(), log) +} - opaque, host, err := h.decodeHost(bytes.NewReader(*buf)) +func (h *sniHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, raddr net.Addr, log logger.Logger) error { + req, err := http.ReadRequest(bufio.NewReader(rw)) if err != nil { - log.Error(err) return err } - target := net.JoinHostPort(host, "443") + if log.IsLevelEnabled(logger.DebugLevel) { + dump, _ := httputil.DumpRequest(req, false) + log.Debug(string(dump)) + } + + host := req.Host + if _, _, err := net.SplitHostPort(host); err != nil { + host = net.JoinHostPort(host, "80") + } log = log.WithFields(map[string]any{ - "dst": target, + "host": host, }) - log.Infof("%s >> %s", conn.RemoteAddr(), target) - if h.options.Bypass != nil && h.options.Bypass.Contains(target) { - log.Info("bypass: ", target) + if h.options.Bypass != nil && h.options.Bypass.Contains(host) { + log.Info("bypass: ", host) return nil } - cc, err := h.router.Dial(ctx, "tcp", target) + cc, err := h.router.Dial(ctx, "tcp", host) if err != nil { log.Error(err) return err } defer cc.Close() - if _, err := cc.Write(opaque); err != nil { + t := time.Now() + log.Infof("%s <-> %s", raddr, host) + defer func() { + log.WithFields(map[string]any{ + "duration": time.Since(t), + }).Infof("%s >-< %s", raddr, host) + }() + + if err := req.Write(cc); err != nil { log.Error(err) return err } - t := time.Now() - log.Infof("%s <-> %s", conn.RemoteAddr(), target) - netpkg.Transport(conn, cc) - log.WithFields(map[string]any{ - "duration": time.Since(t), - }).Infof("%s >-< %s", conn.RemoteAddr(), target) + resp, err := http.ReadResponse(bufio.NewReader(cc), req) + if err != nil { + log.Error(err) + return err + } + defer resp.Body.Close() + + if log.IsLevelEnabled(logger.DebugLevel) { + dump, _ := httputil.DumpResponse(resp, false) + log.Debug(string(dump)) + } + + return resp.Write(rw) - return nil } -func (h *sniHandler) decodeHost(r io.Reader) (opaque []byte, host string, err error) { +func (h *sniHandler) handleHTTPS(ctx context.Context, rw io.ReadWriter, raddr net.Addr, log logger.Logger) error { + buf := new(bytes.Buffer) + host, err := h.decodeHost(io.TeeReader(rw, buf)) + if err != nil { + log.Error(err) + return err + } + + if _, _, err := net.SplitHostPort(host); err != nil { + host = net.JoinHostPort(host, "443") + } + + log = log.WithFields(map[string]any{ + "dst": host, + }) + log.Infof("%s >> %s", raddr, host) + + if h.options.Bypass != nil && h.options.Bypass.Contains(host) { + log.Info("bypass: ", host) + return nil + } + + cc, err := h.router.Dial(ctx, "tcp", host) + if err != nil { + log.Error(err) + return err + } + defer cc.Close() + + t := time.Now() + log.Infof("%s <-> %s", raddr, host) + netpkg.Transport(&readWriter{ + Reader: io.MultiReader(buf, rw), + Writer: rw, + }, cc) + log.WithFields(map[string]any{ + "duration": time.Since(t), + }).Infof("%s >-< %s", raddr, host) + + return nil + +} + +func (h *sniHandler) decodeHost(r io.Reader) (host string, err error) { record, err := dissector.ReadRecord(r) if err != nil { return @@ -190,16 +231,6 @@ func (h *sniHandler) decodeHost(r io.Reader) (opaque []byte, host string, err er } } - record.Opaque, err = clientHello.Encode() - if err != nil { - return - } - - buf := &bytes.Buffer{} - if _, err = record.WriteTo(buf); err != nil { - return - } - opaque = buf.Bytes() return }