diff --git a/config/config.go b/config/config.go index 6bb586f..d5f2d63 100644 --- a/config/config.go +++ b/config/config.go @@ -376,6 +376,13 @@ type HTTPURLRewriteConfig struct { Replacement string } +type HTTPBodyRewriteConfig struct { + // filter by MIME types + Type string + Match string + Replacement string +} + type NodeFilterConfig struct { Host string `yaml:",omitempty" json:"host,omitempty"` Protocol string `yaml:",omitempty" json:"protocol,omitempty"` @@ -383,10 +390,16 @@ type NodeFilterConfig struct { } type HTTPNodeConfig struct { - Host string `yaml:",omitempty" json:"host,omitempty"` - Header map[string]string `yaml:",omitempty" json:"header,omitempty"` + // rewrite host header + Host string `yaml:",omitempty" json:"host,omitempty"` + // additional request header + Header map[string]string `yaml:",omitempty" json:"header,omitempty"` + // rewrite URL Rewrite []HTTPURLRewriteConfig `yaml:",omitempty" json:"rewrite,omitempty"` - Auth *AuthConfig `yaml:",omitempty" json:"auth,omitempty"` + // rewrite response body + RewriteBody []HTTPBodyRewriteConfig `yaml:"rewriteBody,omitempty" json:"rewriteBody,omitempty"` + // HTTP basic auth + Auth *AuthConfig `yaml:",omitempty" json:"auth,omitempty"` } type TLSNodeConfig struct { diff --git a/config/parsing/node/parse.go b/config/parsing/node/parse.go index 41fd58a..f99a305 100644 --- a/config/parsing/node/parse.go +++ b/config/parsing/node/parse.go @@ -193,12 +193,21 @@ func ParseNode(hop string, cfg *config.NodeConfig, log logger.Logger) (*chain.No } for _, v := range cfg.HTTP.Rewrite { if pattern, _ := regexp.Compile(v.Match); pattern != nil { - settings.Rewrite = append(settings.Rewrite, chain.HTTPURLRewriteSetting{ + settings.RewriteURL = append(settings.RewriteURL, chain.HTTPURLRewriteSetting{ Pattern: pattern, Replacement: v.Replacement, }) } } + for _, v := range cfg.HTTP.RewriteBody { + if pattern, _ := regexp.Compile(v.Match); pattern != nil { + settings.RewriteBody = append(settings.RewriteBody, chain.HTTPBodyRewriteSettings{ + Type: v.Type, + Pattern: pattern, + Replacement: []byte(v.Replacement), + }) + } + } opts = append(opts, chain.HTTPNodeOption(settings)) } diff --git a/go.mod b/go.mod index d65431a..cdb6f13 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,7 @@ require ( github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d github.com/gin-contrib/cors v1.6.0 github.com/gin-gonic/gin v1.9.1 - github.com/go-gost/core v0.1.0 + github.com/go-gost/core v0.1.1 github.com/go-gost/gosocks4 v0.0.1 github.com/go-gost/gosocks5 v0.4.2 github.com/go-gost/plugin v0.0.0-20240103125338-9c84e29cb81a diff --git a/go.sum b/go.sum index c61997e..763f689 100644 --- a/go.sum +++ b/go.sum @@ -53,8 +53,8 @@ github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg= github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU= -github.com/go-gost/core v0.1.0 h1:LJJc8PIlRflE8ZIpxls+wYX1e8OGB0nUKJYh8HevM4U= -github.com/go-gost/core v0.1.0/go.mod h1:WGI43jOka7FAsSAwi/fSMaqxdR+E339ycb4NBGlFr6A= +github.com/go-gost/core v0.1.1 h1:8joR9KJYBvpurNu3i0zqN9orQthVzOjhtT4STumwNF0= +github.com/go-gost/core v0.1.1/go.mod h1:WGI43jOka7FAsSAwi/fSMaqxdR+E339ycb4NBGlFr6A= github.com/go-gost/gosocks4 v0.0.1 h1:+k1sec8HlELuQV7rWftIkmy8UijzUt2I6t+iMPlGB2s= github.com/go-gost/gosocks4 v0.0.1/go.mod h1:3B6L47HbU/qugDg4JnoFPHgJXE43Inz8Bah1QaN9qCc= github.com/go-gost/gosocks5 v0.4.2 h1:IianxHTkACPqCwiOAT3MHoMdSUl+SEPSRu1ikawC1Pc= diff --git a/handler/forward/local/handler.go b/handler/forward/local/handler.go index f076d2b..eedbeb7 100644 --- a/handler/forward/local/handler.go +++ b/handler/forward/local/handler.go @@ -2,6 +2,7 @@ package local import ( "bufio" + "bytes" "context" "crypto/tls" "errors" @@ -105,7 +106,7 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand } if protocol == forward.ProtoHTTP { - h.handleHTTP(ctx, rw, conn.RemoteAddr(), log) + h.handleHTTP(ctx, xio.NewReadWriteCloser(rw, rw, conn), conn.RemoteAddr(), log) return nil } @@ -176,11 +177,12 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand return nil } -func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remoteAddr net.Addr, log logger.Logger) (err error) { +func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriteCloser, remoteAddr net.Addr, log logger.Logger) (err error) { br := bufio.NewReader(rw) - var cc net.Conn for { + var cc net.Conn + resp := &http.Response{ ProtoMajor: 1, ProtoMinor: 1, @@ -241,6 +243,7 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot }) log.Debugf("find node for host %s -> %s(%s)", req.Host, target.Name, target.Addr) + var bodyRewrites []chain.HTTPBodyRewriteSettings if httpSettings := target.Options().HTTP; httpSettings != nil { if auther := httpSettings.Auther; auther != nil { username, password, _ := req.BasicAuth() @@ -261,7 +264,7 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot req.Header.Set(k, v) } - for _, re := range httpSettings.Rewrite { + for _, re := range httpSettings.RewriteURL { if re.Pattern.MatchString(req.URL.Path) { if s := re.Pattern.ReplaceAllString(req.URL.Path, re.Replacement); s != "" { req.URL.Path = s @@ -269,6 +272,8 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot } } } + + bodyRewrites = httpSettings.RewriteBody } cc, err = h.options.Router.Dial(ctx, "tcp", target.Addr) @@ -329,7 +334,18 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot log.Trace(string(dump)) } + if res.Close { + defer rw.Close() + } + + if err := h.rewriteBody(res, bodyRewrites...); err != nil { + rw.Close() + log.Errorf("rewrite body: %v", err) + return + } + if err = res.Write(rw); err != nil { + rw.Close() log.Errorf("write response from node %s(%s): %v", target.Name, target.Addr, err) } }() @@ -348,6 +364,54 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot return } +func (h *forwardHandler) rewriteBody(resp *http.Response, rewrites ...chain.HTTPBodyRewriteSettings) error { + if resp == nil || len(rewrites) == 0 || resp.ContentLength <= 0 { + return nil + } + + if encoding := resp.Header.Get("Content-Encoding"); encoding != "" { + return nil + } + + body, err := drainBody(resp.Body) + if err != nil || body == nil { + return err + } + + contentType, _, _ := strings.Cut(resp.Header.Get("Content-Type"), ";") + for _, rewrite := range rewrites { + rewriteType := rewrite.Type + if rewriteType == "" { + rewriteType = "text/html" + } + if rewriteType != "*" && !strings.Contains(rewriteType, contentType) { + continue + } + + body = rewrite.Pattern.ReplaceAll(body, rewrite.Replacement) + } + + resp.Body = io.NopCloser(bytes.NewReader(body)) + resp.ContentLength = int64(len(body)) + + return nil +} + +func drainBody(b io.ReadCloser) (body []byte, err error) { + if b == nil || b == http.NoBody { + // No copying needed. Preserve the magic sentinel meaning of NoBody. + return nil, nil + } + var buf bytes.Buffer + if _, err = buf.ReadFrom(b); err != nil { + return nil, err + } + if err = b.Close(); err != nil { + return nil, err + } + return buf.Bytes(), nil +} + func (h *forwardHandler) checkRateLimit(addr net.Addr) bool { if h.options.RateLimiter == nil { return true diff --git a/handler/forward/remote/handler.go b/handler/forward/remote/handler.go index 82fad11..0231df0 100644 --- a/handler/forward/remote/handler.go +++ b/handler/forward/remote/handler.go @@ -2,6 +2,7 @@ package remote import ( "bufio" + "bytes" "context" "crypto/tls" "errors" @@ -107,7 +108,7 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand } } if protocol == forward.ProtoHTTP { - h.handleHTTP(ctx, rw, conn.RemoteAddr(), localAddr, log) + h.handleHTTP(ctx, xio.NewReadWriteCloser(rw, rw, conn), conn.RemoteAddr(), localAddr, log) return nil } @@ -177,11 +178,11 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand return nil } -func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remoteAddr net.Addr, localAddr net.Addr, log logger.Logger) (err error) { +func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriteCloser, remoteAddr net.Addr, localAddr net.Addr, log logger.Logger) (err error) { br := bufio.NewReader(rw) - var cc net.Conn for { + var cc net.Conn resp := &http.Response{ ProtoMajor: 1, ProtoMinor: 1, @@ -242,6 +243,7 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot }) log.Debugf("find node for host %s -> %s(%s)", req.Host, target.Name, target.Addr) + var bodyRewrites []chain.HTTPBodyRewriteSettings if httpSettings := target.Options().HTTP; httpSettings != nil { if auther := httpSettings.Auther; auther != nil { username, password, _ := req.BasicAuth() @@ -261,7 +263,7 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot req.Header.Set(k, v) } - for _, re := range httpSettings.Rewrite { + for _, re := range httpSettings.RewriteURL { if re.Pattern.MatchString(req.URL.Path) { if s := re.Pattern.ReplaceAllString(req.URL.Path, re.Replacement); s != "" { req.URL.Path = s @@ -269,6 +271,8 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot } } } + + bodyRewrites = httpSettings.RewriteBody } cc, err = h.options.Router.Dial(ctx, "tcp", target.Addr) @@ -331,7 +335,18 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot log.Trace(string(dump)) } + if res.Close { + defer rw.Close() + } + + if err := h.rewriteBody(res, bodyRewrites...); err != nil { + rw.Close() + log.Errorf("rewrite body: %v", err) + return + } + if err = res.Write(rw); err != nil { + rw.Close() log.Errorf("write response from node %s(%s): %v", target.Name, target.Addr, err) } }() @@ -350,6 +365,50 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot return } +func (h *forwardHandler) rewriteBody(resp *http.Response, rewrites ...chain.HTTPBodyRewriteSettings) error { + if resp == nil || len(rewrites) == 0 || resp.ContentLength <= 0 { + return nil + } + + body, err := drainBody(resp.Body) + if err != nil || body == nil { + return err + } + + contentType, _, _ := strings.Cut(resp.Header.Get("Content-Type"), ";") + for _, rewrite := range rewrites { + rewriteType := rewrite.Type + if rewriteType == "" { + rewriteType = "text/html" + } + if rewriteType != "*" && !strings.Contains(rewriteType, contentType) { + continue + } + + body = rewrite.Pattern.ReplaceAll(body, rewrite.Replacement) + } + + resp.Body = io.NopCloser(bytes.NewReader(body)) + resp.ContentLength = int64(len(body)) + + return nil +} + +func drainBody(b io.ReadCloser) (body []byte, err error) { + if b == nil || b == http.NoBody { + // No copying needed. Preserve the magic sentinel meaning of NoBody. + return nil, nil + } + var buf bytes.Buffer + if _, err = buf.ReadFrom(b); err != nil { + return nil, err + } + if err = b.Close(); err != nil { + return nil, err + } + return buf.Bytes(), nil +} + func (h *forwardHandler) checkRateLimit(addr net.Addr) bool { if h.options.RateLimiter == nil { return true diff --git a/handler/http/handler.go b/handler/http/handler.go index 0829274..acc6c85 100644 --- a/handler/http/handler.go +++ b/handler/http/handler.go @@ -279,6 +279,9 @@ func (h *httpHandler) handleProxy(rw, cc io.ReadWriter, req *http.Request, log l err := func() error { req, err := http.ReadRequest(bufio.NewReader(rw)) if err != nil { + if err == io.EOF { + return nil + } return err } diff --git a/handler/socks/v5/handler.go b/handler/socks/v5/handler.go index ece0b83..f24358c 100644 --- a/handler/socks/v5/handler.go +++ b/handler/socks/v5/handler.go @@ -101,6 +101,7 @@ func (h *socks5Handler) Handle(ctx context.Context, conn net.Conn, opts ...handl if clientID := sc.ID(); clientID != "" { ctx = ctxvalue.ContextWithClientID(ctx, ctxvalue.ClientID(clientID)) + log = log.WithFields(map[string]any{"user": clientID}) } conn = sc diff --git a/handler/tunnel/entrypoint.go b/handler/tunnel/entrypoint.go index b582431..75cd9ed 100644 --- a/handler/tunnel/entrypoint.go +++ b/handler/tunnel/entrypoint.go @@ -123,13 +123,15 @@ func (ep *entrypoint) handle(ctx context.Context, conn net.Conn) error { timeout: 15 * time.Second, log: log, } - cc, node, cid, err := d.Dial(ctx, "tcp", tunnelID.String()) + c, node, cid, err := d.Dial(ctx, "tcp", tunnelID.String()) if err != nil { log.Error(err) return resp.Write(conn) } log.Debugf("new connection to tunnel: %s, connector: %s", tunnelID, cid) + cc = c + host := req.Host if h, _, _ := net.SplitHostPort(host); h == "" { host = net.JoinHostPort(strings.Trim(host, "[]"), "80") @@ -149,17 +151,26 @@ func (ep *entrypoint) handle(ctx context.Context, conn net.Conn) error { Version: relay.Version1, Status: relay.StatusOK, Features: features, - }).WriteTo(cc) + }).WriteTo(c) } - if err := req.Write(cc); err != nil { - cc.Close() + // HTTP/1.0 + if req.ProtoMajor == 1 && req.ProtoMinor == 0 { + if strings.ToLower(req.Header.Get("Connection")) == "keep-alive" { + req.Header.Del("Connection") + } else { + req.Header.Set("Connection", "close") + } + } + + if err := req.Write(c); err != nil { + c.Close() log.Errorf("send request: %v", err) return resp.Write(conn) } if req.Header.Get("Upgrade") == "websocket" { - err := xnet.Transport(cc, xio.NewReadWriter(br, conn)) + err := xnet.Transport(c, xio.NewReadWriter(br, conn)) if err == nil { err = io.EOF } @@ -167,7 +178,7 @@ func (ep *entrypoint) handle(ctx context.Context, conn net.Conn) error { } go func() { - defer cc.Close() + defer c.Close() t := time.Now() log.Debugf("%s <-> %s", remoteAddr, host) @@ -178,7 +189,7 @@ func (ep *entrypoint) handle(ctx context.Context, conn net.Conn) error { }).Debugf("%s >-< %s", remoteAddr, host) }() - res, err := http.ReadResponse(bufio.NewReader(cc), req) + res, err := http.ReadResponse(bufio.NewReader(c), req) if err != nil { log.Errorf("read response: %v", err) resp.Write(conn) @@ -190,7 +201,21 @@ func (ep *entrypoint) handle(ctx context.Context, conn net.Conn) error { log.Trace(string(dump)) } + if res.Close { + defer conn.Close() + } + + // HTTP/1.0 + if req.ProtoMajor == 1 && req.ProtoMinor == 0 { + if !res.Close { + res.Header.Set("Connection", "keep-alive") + } + res.ProtoMajor = req.ProtoMajor + res.ProtoMinor = req.ProtoMinor + } + if err = res.Write(conn); err != nil { + conn.Close() log.Errorf("write response: %v", err) } }() diff --git a/handler/tunnel/handler.go b/handler/tunnel/handler.go index 90d4b37..b177520 100644 --- a/handler/tunnel/handler.go +++ b/handler/tunnel/handler.go @@ -264,6 +264,10 @@ func (h *tunnelHandler) Handle(ctx context.Context, conn net.Conn, opts ...handl // Close implements io.Closer interface. func (h *tunnelHandler) Close() error { + if h.epSvc != nil { + h.epSvc.Close() + } + h.pool.Close() return nil } diff --git a/handler/tunnel/tunnel.go b/handler/tunnel/tunnel.go index 7d3d432..ebc9041 100644 --- a/handler/tunnel/tunnel.go +++ b/handler/tunnel/tunnel.go @@ -66,6 +66,14 @@ func (c *Connector) Session() *mux.Session { return c.s } +func (c *Connector) Close() error { + if c == nil || c.s == nil { + return nil + } + + return c.s.Close() +} + type Tunnel struct { node string id relay.TunnelID @@ -75,7 +83,7 @@ type Tunnel struct { mu sync.RWMutex sd sd.SD ttl time.Duration - rw *selector.RandomWeighted[*Connector] + // rw *selector.RandomWeighted[*Connector] } func NewTunnel(node string, tid relay.TunnelID, ttl time.Duration) *Tunnel { @@ -85,7 +93,7 @@ func NewTunnel(node string, tid relay.TunnelID, ttl time.Duration) *Tunnel { t: time.Now(), close: make(chan struct{}), ttl: ttl, - rw: selector.NewRandomWeighted[*Connector](), + // rw: selector.NewRandomWeighted[*Connector](), } if t.ttl <= 0 { t.ttl = defaultTTL @@ -117,8 +125,14 @@ func (t *Tunnel) GetConnector(network string) *Connector { t.mu.RLock() defer t.mu.RUnlock() - rw := t.rw - rw.Reset() + // rw := t.rw + // rw.Reset() + + if len(t.connectors) == 1 { + return t.connectors[0] + } + + rw := selector.NewRandomWeighted[*Connector]() found := false for _, c := range t.connectors { @@ -147,6 +161,22 @@ func (t *Tunnel) GetConnector(network string) *Connector { return rw.Next() } +func (t *Tunnel) Close() error { + t.mu.Lock() + defer t.mu.Unlock() + + select { + case <-t.close: + default: + for _, c := range t.connectors { + c.Close() + } + close(t.close) + } + + return nil +} + func (t *Tunnel) CloseOnIdle() bool { t.mu.RLock() defer t.mu.RUnlock() @@ -256,6 +286,22 @@ func (p *ConnectorPool) Get(network string, tid string) *Connector { return t.GetConnector(network) } +func (p *ConnectorPool) Close() error { + if p == nil { + return nil + } + + p.mu.Lock() + defer p.mu.Unlock() + + for k, v := range p.tunnels { + v.Close() + delete(p.tunnels, k) + } + + return nil +} + func (p *ConnectorPool) closeIdles() { ticker := time.NewTicker(1 * time.Hour) defer ticker.Stop() diff --git a/ingress/ingress.go b/ingress/ingress.go index 96027dc..e6487af 100644 --- a/ingress/ingress.go +++ b/ingress/ingress.go @@ -245,7 +245,7 @@ func (ing *localIngress) GetRule(ctx context.Context, host string, opts ...ingre } if ep != nil { - ing.options.logger.Debugf("ingress: %s -> %s", host, ep) + ing.options.logger.Debugf("ingress: %s -> %s:%s", host, ep.Hostname, ep.Endpoint) } return ep diff --git a/internal/io/io.go b/internal/io/io.go index d27bc0f..87c72f8 100644 --- a/internal/io/io.go +++ b/internal/io/io.go @@ -13,3 +13,17 @@ func NewReadWriter(r io.Reader, w io.Writer) io.ReadWriter { Writer: w, } } + +type readWriteCloser struct { + io.Reader + io.Writer + io.Closer +} + +func NewReadWriteCloser(r io.Reader, w io.Writer, c io.Closer) io.ReadWriteCloser { + return &readWriteCloser{ + Reader: r, + Writer: w, + Closer: c, + } +} diff --git a/internal/net/udp/listener.go b/internal/net/udp/listener.go index b88b7ce..affeaa4 100644 --- a/internal/net/udp/listener.go +++ b/internal/net/udp/listener.go @@ -178,12 +178,15 @@ func (c *conn) Read(b []byte) (n int, err error) { return } -func (c *conn) Write(b []byte) (n int, err error) { - n, err = c.WriteTo(b, c.remoteAddr) +func (c *conn) WriteTo(b []byte, addr net.Addr) (n int, err error) { if !c.keepAlive { - c.Close() + defer c.Close() } - return + return c.PacketConn.WriteTo(b, addr) +} + +func (c *conn) Write(b []byte) (n int, err error) { + return c.WriteTo(b, c.remoteAddr) } func (c *conn) Close() error {