add http body rewrite for forward handler
This commit is contained in:
@ -2,6 +2,7 @@ package local
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
@ -179,8 +180,9 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand
|
||||
func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remoteAddr net.Addr, log logger.Logger) (err error) {
|
||||
br := bufio.NewReader(rw)
|
||||
|
||||
var cc net.Conn
|
||||
for {
|
||||
var cc net.Conn
|
||||
|
||||
resp := &http.Response{
|
||||
ProtoMajor: 1,
|
||||
ProtoMinor: 1,
|
||||
@ -241,6 +243,7 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot
|
||||
})
|
||||
log.Debugf("find node for host %s -> %s(%s)", req.Host, target.Name, target.Addr)
|
||||
|
||||
var bodyRewrites []chain.HTTPBodyRewriteSettings
|
||||
if httpSettings := target.Options().HTTP; httpSettings != nil {
|
||||
if auther := httpSettings.Auther; auther != nil {
|
||||
username, password, _ := req.BasicAuth()
|
||||
@ -261,7 +264,7 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
|
||||
for _, re := range httpSettings.Rewrite {
|
||||
for _, re := range httpSettings.RewriteURL {
|
||||
if re.Pattern.MatchString(req.URL.Path) {
|
||||
if s := re.Pattern.ReplaceAllString(req.URL.Path, re.Replacement); s != "" {
|
||||
req.URL.Path = s
|
||||
@ -269,6 +272,8 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bodyRewrites = httpSettings.RewriteBody
|
||||
}
|
||||
|
||||
cc, err = h.options.Router.Dial(ctx, "tcp", target.Addr)
|
||||
@ -329,6 +334,11 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot
|
||||
log.Trace(string(dump))
|
||||
}
|
||||
|
||||
if err := h.rewriteBody(res, bodyRewrites...); err != nil {
|
||||
log.Errorf("rewrite body: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if err = res.Write(rw); err != nil {
|
||||
log.Errorf("write response from node %s(%s): %v", target.Name, target.Addr, err)
|
||||
}
|
||||
@ -348,6 +358,54 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot
|
||||
return
|
||||
}
|
||||
|
||||
func (h *forwardHandler) rewriteBody(resp *http.Response, rewrites ...chain.HTTPBodyRewriteSettings) error {
|
||||
if resp == nil || len(rewrites) == 0 || resp.ContentLength <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if encoding := resp.Header.Get("Content-Encoding"); encoding != "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
body, err := drainBody(resp.Body)
|
||||
if err != nil || body == nil {
|
||||
return err
|
||||
}
|
||||
|
||||
contentType, _, _ := strings.Cut(resp.Header.Get("Content-Type"), ";")
|
||||
for _, rewrite := range rewrites {
|
||||
rewriteType := rewrite.Type
|
||||
if rewriteType == "" {
|
||||
rewriteType = "text/html"
|
||||
}
|
||||
if rewriteType != "*" && !strings.Contains(rewriteType, contentType) {
|
||||
continue
|
||||
}
|
||||
|
||||
body = rewrite.Pattern.ReplaceAll(body, rewrite.Replacement)
|
||||
}
|
||||
|
||||
resp.Body = io.NopCloser(bytes.NewReader(body))
|
||||
resp.ContentLength = int64(len(body))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func drainBody(b io.ReadCloser) (body []byte, err error) {
|
||||
if b == nil || b == http.NoBody {
|
||||
// No copying needed. Preserve the magic sentinel meaning of NoBody.
|
||||
return nil, nil
|
||||
}
|
||||
var buf bytes.Buffer
|
||||
if _, err = buf.ReadFrom(b); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = b.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
func (h *forwardHandler) checkRateLimit(addr net.Addr) bool {
|
||||
if h.options.RateLimiter == nil {
|
||||
return true
|
||||
|
@ -2,6 +2,7 @@ package remote
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
@ -179,9 +180,9 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand
|
||||
|
||||
func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remoteAddr net.Addr, localAddr net.Addr, log logger.Logger) (err error) {
|
||||
br := bufio.NewReader(rw)
|
||||
var cc net.Conn
|
||||
|
||||
for {
|
||||
var cc net.Conn
|
||||
resp := &http.Response{
|
||||
ProtoMajor: 1,
|
||||
ProtoMinor: 1,
|
||||
@ -242,6 +243,7 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot
|
||||
})
|
||||
log.Debugf("find node for host %s -> %s(%s)", req.Host, target.Name, target.Addr)
|
||||
|
||||
var bodyRewrites []chain.HTTPBodyRewriteSettings
|
||||
if httpSettings := target.Options().HTTP; httpSettings != nil {
|
||||
if auther := httpSettings.Auther; auther != nil {
|
||||
username, password, _ := req.BasicAuth()
|
||||
@ -261,7 +263,7 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
|
||||
for _, re := range httpSettings.Rewrite {
|
||||
for _, re := range httpSettings.RewriteURL {
|
||||
if re.Pattern.MatchString(req.URL.Path) {
|
||||
if s := re.Pattern.ReplaceAllString(req.URL.Path, re.Replacement); s != "" {
|
||||
req.URL.Path = s
|
||||
@ -269,6 +271,8 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bodyRewrites = httpSettings.RewriteBody
|
||||
}
|
||||
|
||||
cc, err = h.options.Router.Dial(ctx, "tcp", target.Addr)
|
||||
@ -331,6 +335,11 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot
|
||||
log.Trace(string(dump))
|
||||
}
|
||||
|
||||
if err := h.rewriteBody(res, bodyRewrites...); err != nil {
|
||||
log.Errorf("rewrite body: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if err = res.Write(rw); err != nil {
|
||||
log.Errorf("write response from node %s(%s): %v", target.Name, target.Addr, err)
|
||||
}
|
||||
@ -350,6 +359,50 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot
|
||||
return
|
||||
}
|
||||
|
||||
func (h *forwardHandler) rewriteBody(resp *http.Response, rewrites ...chain.HTTPBodyRewriteSettings) error {
|
||||
if resp == nil || len(rewrites) == 0 || resp.ContentLength <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
body, err := drainBody(resp.Body)
|
||||
if err != nil || body == nil {
|
||||
return err
|
||||
}
|
||||
|
||||
contentType, _, _ := strings.Cut(resp.Header.Get("Content-Type"), ";")
|
||||
for _, rewrite := range rewrites {
|
||||
rewriteType := rewrite.Type
|
||||
if rewriteType == "" {
|
||||
rewriteType = "text/html"
|
||||
}
|
||||
if rewriteType != "*" && !strings.Contains(rewriteType, contentType) {
|
||||
continue
|
||||
}
|
||||
|
||||
body = rewrite.Pattern.ReplaceAll(body, rewrite.Replacement)
|
||||
}
|
||||
|
||||
resp.Body = io.NopCloser(bytes.NewReader(body))
|
||||
resp.ContentLength = int64(len(body))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func drainBody(b io.ReadCloser) (body []byte, err error) {
|
||||
if b == nil || b == http.NoBody {
|
||||
// No copying needed. Preserve the magic sentinel meaning of NoBody.
|
||||
return nil, nil
|
||||
}
|
||||
var buf bytes.Buffer
|
||||
if _, err = buf.ReadFrom(b); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = b.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
func (h *forwardHandler) checkRateLimit(addr net.Addr) bool {
|
||||
if h.options.RateLimiter == nil {
|
||||
return true
|
||||
|
Reference in New Issue
Block a user