add http body rewrite for forward handler
This commit is contained in:
parent
c0a80400d2
commit
3656ba9315
@ -376,6 +376,13 @@ type HTTPURLRewriteConfig struct {
|
|||||||
Replacement string
|
Replacement string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type HTTPBodyRewriteConfig struct {
|
||||||
|
// filter by MIME types
|
||||||
|
Type string
|
||||||
|
Match string
|
||||||
|
Replacement string
|
||||||
|
}
|
||||||
|
|
||||||
type NodeFilterConfig struct {
|
type NodeFilterConfig struct {
|
||||||
Host string `yaml:",omitempty" json:"host,omitempty"`
|
Host string `yaml:",omitempty" json:"host,omitempty"`
|
||||||
Protocol string `yaml:",omitempty" json:"protocol,omitempty"`
|
Protocol string `yaml:",omitempty" json:"protocol,omitempty"`
|
||||||
@ -383,10 +390,16 @@ type NodeFilterConfig struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type HTTPNodeConfig struct {
|
type HTTPNodeConfig struct {
|
||||||
Host string `yaml:",omitempty" json:"host,omitempty"`
|
// rewrite host header
|
||||||
Header map[string]string `yaml:",omitempty" json:"header,omitempty"`
|
Host string `yaml:",omitempty" json:"host,omitempty"`
|
||||||
|
// additional request header
|
||||||
|
Header map[string]string `yaml:",omitempty" json:"header,omitempty"`
|
||||||
|
// rewrite URL
|
||||||
Rewrite []HTTPURLRewriteConfig `yaml:",omitempty" json:"rewrite,omitempty"`
|
Rewrite []HTTPURLRewriteConfig `yaml:",omitempty" json:"rewrite,omitempty"`
|
||||||
Auth *AuthConfig `yaml:",omitempty" json:"auth,omitempty"`
|
// rewrite response body
|
||||||
|
RewriteBody []HTTPBodyRewriteConfig `yaml:"rewriteBody,omitempty" json:"rewriteBody,omitempty"`
|
||||||
|
// HTTP basic auth
|
||||||
|
Auth *AuthConfig `yaml:",omitempty" json:"auth,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type TLSNodeConfig struct {
|
type TLSNodeConfig struct {
|
||||||
|
@ -193,12 +193,21 @@ func ParseNode(hop string, cfg *config.NodeConfig, log logger.Logger) (*chain.No
|
|||||||
}
|
}
|
||||||
for _, v := range cfg.HTTP.Rewrite {
|
for _, v := range cfg.HTTP.Rewrite {
|
||||||
if pattern, _ := regexp.Compile(v.Match); pattern != nil {
|
if pattern, _ := regexp.Compile(v.Match); pattern != nil {
|
||||||
settings.Rewrite = append(settings.Rewrite, chain.HTTPURLRewriteSetting{
|
settings.RewriteURL = append(settings.RewriteURL, chain.HTTPURLRewriteSetting{
|
||||||
Pattern: pattern,
|
Pattern: pattern,
|
||||||
Replacement: v.Replacement,
|
Replacement: v.Replacement,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
for _, v := range cfg.HTTP.RewriteBody {
|
||||||
|
if pattern, _ := regexp.Compile(v.Match); pattern != nil {
|
||||||
|
settings.RewriteBody = append(settings.RewriteBody, chain.HTTPBodyRewriteSettings{
|
||||||
|
Type: v.Type,
|
||||||
|
Pattern: pattern,
|
||||||
|
Replacement: []byte(v.Replacement),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
opts = append(opts, chain.HTTPNodeOption(settings))
|
opts = append(opts, chain.HTTPNodeOption(settings))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
2
go.mod
2
go.mod
@ -9,7 +9,7 @@ require (
|
|||||||
github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d
|
github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d
|
||||||
github.com/gin-contrib/cors v1.6.0
|
github.com/gin-contrib/cors v1.6.0
|
||||||
github.com/gin-gonic/gin v1.9.1
|
github.com/gin-gonic/gin v1.9.1
|
||||||
github.com/go-gost/core v0.1.0
|
github.com/go-gost/core v0.1.1
|
||||||
github.com/go-gost/gosocks4 v0.0.1
|
github.com/go-gost/gosocks4 v0.0.1
|
||||||
github.com/go-gost/gosocks5 v0.4.2
|
github.com/go-gost/gosocks5 v0.4.2
|
||||||
github.com/go-gost/plugin v0.0.0-20240103125338-9c84e29cb81a
|
github.com/go-gost/plugin v0.0.0-20240103125338-9c84e29cb81a
|
||||||
|
4
go.sum
4
go.sum
@ -53,8 +53,8 @@ github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE
|
|||||||
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
|
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
|
||||||
github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg=
|
github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg=
|
||||||
github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU=
|
github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU=
|
||||||
github.com/go-gost/core v0.1.0 h1:LJJc8PIlRflE8ZIpxls+wYX1e8OGB0nUKJYh8HevM4U=
|
github.com/go-gost/core v0.1.1 h1:8joR9KJYBvpurNu3i0zqN9orQthVzOjhtT4STumwNF0=
|
||||||
github.com/go-gost/core v0.1.0/go.mod h1:WGI43jOka7FAsSAwi/fSMaqxdR+E339ycb4NBGlFr6A=
|
github.com/go-gost/core v0.1.1/go.mod h1:WGI43jOka7FAsSAwi/fSMaqxdR+E339ycb4NBGlFr6A=
|
||||||
github.com/go-gost/gosocks4 v0.0.1 h1:+k1sec8HlELuQV7rWftIkmy8UijzUt2I6t+iMPlGB2s=
|
github.com/go-gost/gosocks4 v0.0.1 h1:+k1sec8HlELuQV7rWftIkmy8UijzUt2I6t+iMPlGB2s=
|
||||||
github.com/go-gost/gosocks4 v0.0.1/go.mod h1:3B6L47HbU/qugDg4JnoFPHgJXE43Inz8Bah1QaN9qCc=
|
github.com/go-gost/gosocks4 v0.0.1/go.mod h1:3B6L47HbU/qugDg4JnoFPHgJXE43Inz8Bah1QaN9qCc=
|
||||||
github.com/go-gost/gosocks5 v0.4.2 h1:IianxHTkACPqCwiOAT3MHoMdSUl+SEPSRu1ikawC1Pc=
|
github.com/go-gost/gosocks5 v0.4.2 h1:IianxHTkACPqCwiOAT3MHoMdSUl+SEPSRu1ikawC1Pc=
|
||||||
|
@ -2,6 +2,7 @@ package local
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
@ -179,8 +180,9 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand
|
|||||||
func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remoteAddr net.Addr, log logger.Logger) (err error) {
|
func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remoteAddr net.Addr, log logger.Logger) (err error) {
|
||||||
br := bufio.NewReader(rw)
|
br := bufio.NewReader(rw)
|
||||||
|
|
||||||
var cc net.Conn
|
|
||||||
for {
|
for {
|
||||||
|
var cc net.Conn
|
||||||
|
|
||||||
resp := &http.Response{
|
resp := &http.Response{
|
||||||
ProtoMajor: 1,
|
ProtoMajor: 1,
|
||||||
ProtoMinor: 1,
|
ProtoMinor: 1,
|
||||||
@ -241,6 +243,7 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot
|
|||||||
})
|
})
|
||||||
log.Debugf("find node for host %s -> %s(%s)", req.Host, target.Name, target.Addr)
|
log.Debugf("find node for host %s -> %s(%s)", req.Host, target.Name, target.Addr)
|
||||||
|
|
||||||
|
var bodyRewrites []chain.HTTPBodyRewriteSettings
|
||||||
if httpSettings := target.Options().HTTP; httpSettings != nil {
|
if httpSettings := target.Options().HTTP; httpSettings != nil {
|
||||||
if auther := httpSettings.Auther; auther != nil {
|
if auther := httpSettings.Auther; auther != nil {
|
||||||
username, password, _ := req.BasicAuth()
|
username, password, _ := req.BasicAuth()
|
||||||
@ -261,7 +264,7 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot
|
|||||||
req.Header.Set(k, v)
|
req.Header.Set(k, v)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, re := range httpSettings.Rewrite {
|
for _, re := range httpSettings.RewriteURL {
|
||||||
if re.Pattern.MatchString(req.URL.Path) {
|
if re.Pattern.MatchString(req.URL.Path) {
|
||||||
if s := re.Pattern.ReplaceAllString(req.URL.Path, re.Replacement); s != "" {
|
if s := re.Pattern.ReplaceAllString(req.URL.Path, re.Replacement); s != "" {
|
||||||
req.URL.Path = s
|
req.URL.Path = s
|
||||||
@ -269,6 +272,8 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bodyRewrites = httpSettings.RewriteBody
|
||||||
}
|
}
|
||||||
|
|
||||||
cc, err = h.options.Router.Dial(ctx, "tcp", target.Addr)
|
cc, err = h.options.Router.Dial(ctx, "tcp", target.Addr)
|
||||||
@ -329,6 +334,11 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot
|
|||||||
log.Trace(string(dump))
|
log.Trace(string(dump))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := h.rewriteBody(res, bodyRewrites...); err != nil {
|
||||||
|
log.Errorf("rewrite body: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if err = res.Write(rw); err != nil {
|
if err = res.Write(rw); err != nil {
|
||||||
log.Errorf("write response from node %s(%s): %v", target.Name, target.Addr, err)
|
log.Errorf("write response from node %s(%s): %v", target.Name, target.Addr, err)
|
||||||
}
|
}
|
||||||
@ -348,6 +358,54 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *forwardHandler) rewriteBody(resp *http.Response, rewrites ...chain.HTTPBodyRewriteSettings) error {
|
||||||
|
if resp == nil || len(rewrites) == 0 || resp.ContentLength <= 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if encoding := resp.Header.Get("Content-Encoding"); encoding != "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := drainBody(resp.Body)
|
||||||
|
if err != nil || body == nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
contentType, _, _ := strings.Cut(resp.Header.Get("Content-Type"), ";")
|
||||||
|
for _, rewrite := range rewrites {
|
||||||
|
rewriteType := rewrite.Type
|
||||||
|
if rewriteType == "" {
|
||||||
|
rewriteType = "text/html"
|
||||||
|
}
|
||||||
|
if rewriteType != "*" && !strings.Contains(rewriteType, contentType) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
body = rewrite.Pattern.ReplaceAll(body, rewrite.Replacement)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.Body = io.NopCloser(bytes.NewReader(body))
|
||||||
|
resp.ContentLength = int64(len(body))
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func drainBody(b io.ReadCloser) (body []byte, err error) {
|
||||||
|
if b == nil || b == http.NoBody {
|
||||||
|
// No copying needed. Preserve the magic sentinel meaning of NoBody.
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
var buf bytes.Buffer
|
||||||
|
if _, err = buf.ReadFrom(b); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err = b.Close(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return buf.Bytes(), nil
|
||||||
|
}
|
||||||
|
|
||||||
func (h *forwardHandler) checkRateLimit(addr net.Addr) bool {
|
func (h *forwardHandler) checkRateLimit(addr net.Addr) bool {
|
||||||
if h.options.RateLimiter == nil {
|
if h.options.RateLimiter == nil {
|
||||||
return true
|
return true
|
||||||
|
@ -2,6 +2,7 @@ package remote
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
@ -179,9 +180,9 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand
|
|||||||
|
|
||||||
func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remoteAddr net.Addr, localAddr net.Addr, log logger.Logger) (err error) {
|
func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remoteAddr net.Addr, localAddr net.Addr, log logger.Logger) (err error) {
|
||||||
br := bufio.NewReader(rw)
|
br := bufio.NewReader(rw)
|
||||||
var cc net.Conn
|
|
||||||
|
|
||||||
for {
|
for {
|
||||||
|
var cc net.Conn
|
||||||
resp := &http.Response{
|
resp := &http.Response{
|
||||||
ProtoMajor: 1,
|
ProtoMajor: 1,
|
||||||
ProtoMinor: 1,
|
ProtoMinor: 1,
|
||||||
@ -242,6 +243,7 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot
|
|||||||
})
|
})
|
||||||
log.Debugf("find node for host %s -> %s(%s)", req.Host, target.Name, target.Addr)
|
log.Debugf("find node for host %s -> %s(%s)", req.Host, target.Name, target.Addr)
|
||||||
|
|
||||||
|
var bodyRewrites []chain.HTTPBodyRewriteSettings
|
||||||
if httpSettings := target.Options().HTTP; httpSettings != nil {
|
if httpSettings := target.Options().HTTP; httpSettings != nil {
|
||||||
if auther := httpSettings.Auther; auther != nil {
|
if auther := httpSettings.Auther; auther != nil {
|
||||||
username, password, _ := req.BasicAuth()
|
username, password, _ := req.BasicAuth()
|
||||||
@ -261,7 +263,7 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot
|
|||||||
req.Header.Set(k, v)
|
req.Header.Set(k, v)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, re := range httpSettings.Rewrite {
|
for _, re := range httpSettings.RewriteURL {
|
||||||
if re.Pattern.MatchString(req.URL.Path) {
|
if re.Pattern.MatchString(req.URL.Path) {
|
||||||
if s := re.Pattern.ReplaceAllString(req.URL.Path, re.Replacement); s != "" {
|
if s := re.Pattern.ReplaceAllString(req.URL.Path, re.Replacement); s != "" {
|
||||||
req.URL.Path = s
|
req.URL.Path = s
|
||||||
@ -269,6 +271,8 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bodyRewrites = httpSettings.RewriteBody
|
||||||
}
|
}
|
||||||
|
|
||||||
cc, err = h.options.Router.Dial(ctx, "tcp", target.Addr)
|
cc, err = h.options.Router.Dial(ctx, "tcp", target.Addr)
|
||||||
@ -331,6 +335,11 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot
|
|||||||
log.Trace(string(dump))
|
log.Trace(string(dump))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := h.rewriteBody(res, bodyRewrites...); err != nil {
|
||||||
|
log.Errorf("rewrite body: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if err = res.Write(rw); err != nil {
|
if err = res.Write(rw); err != nil {
|
||||||
log.Errorf("write response from node %s(%s): %v", target.Name, target.Addr, err)
|
log.Errorf("write response from node %s(%s): %v", target.Name, target.Addr, err)
|
||||||
}
|
}
|
||||||
@ -350,6 +359,50 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *forwardHandler) rewriteBody(resp *http.Response, rewrites ...chain.HTTPBodyRewriteSettings) error {
|
||||||
|
if resp == nil || len(rewrites) == 0 || resp.ContentLength <= 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := drainBody(resp.Body)
|
||||||
|
if err != nil || body == nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
contentType, _, _ := strings.Cut(resp.Header.Get("Content-Type"), ";")
|
||||||
|
for _, rewrite := range rewrites {
|
||||||
|
rewriteType := rewrite.Type
|
||||||
|
if rewriteType == "" {
|
||||||
|
rewriteType = "text/html"
|
||||||
|
}
|
||||||
|
if rewriteType != "*" && !strings.Contains(rewriteType, contentType) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
body = rewrite.Pattern.ReplaceAll(body, rewrite.Replacement)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.Body = io.NopCloser(bytes.NewReader(body))
|
||||||
|
resp.ContentLength = int64(len(body))
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func drainBody(b io.ReadCloser) (body []byte, err error) {
|
||||||
|
if b == nil || b == http.NoBody {
|
||||||
|
// No copying needed. Preserve the magic sentinel meaning of NoBody.
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
var buf bytes.Buffer
|
||||||
|
if _, err = buf.ReadFrom(b); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err = b.Close(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return buf.Bytes(), nil
|
||||||
|
}
|
||||||
|
|
||||||
func (h *forwardHandler) checkRateLimit(addr net.Addr) bool {
|
func (h *forwardHandler) checkRateLimit(addr net.Addr) bool {
|
||||||
if h.options.RateLimiter == nil {
|
if h.options.RateLimiter == nil {
|
||||||
return true
|
return true
|
||||||
|
Loading…
Reference in New Issue
Block a user