Compare commits
8 Commits
f73960ad36
...
490e6b40f5
Author | SHA1 | Date | |
---|---|---|---|
|
490e6b40f5 | ||
|
bc0d6953bc | ||
|
22e522e933 | ||
|
5e8a8a4b4d | ||
|
fa16373d66 | ||
|
1a776dc759 | ||
|
12ef82e41f | ||
|
3656ba9315 |
@ -78,7 +78,7 @@ func (*defaultRoute) Bind(ctx context.Context, network, address string, opts ...
|
|||||||
ReadQueueSize: options.UDPDataQueueSize,
|
ReadQueueSize: options.UDPDataQueueSize,
|
||||||
ReadBufferSize: options.UDPDataBufferSize,
|
ReadBufferSize: options.UDPDataBufferSize,
|
||||||
TTL: options.UDPConnTTL,
|
TTL: options.UDPConnTTL,
|
||||||
KeepAlive: true,
|
Keepalive: true,
|
||||||
Logger: logger,
|
Logger: logger,
|
||||||
})
|
})
|
||||||
return ln, err
|
return ln, err
|
||||||
|
@ -42,12 +42,6 @@ func (r *Router) Options() *chain.RouterOptions {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *Router) Dial(ctx context.Context, network, address string) (conn net.Conn, err error) {
|
func (r *Router) Dial(ctx context.Context, network, address string) (conn net.Conn, err error) {
|
||||||
if r.options.Timeout > 0 {
|
|
||||||
var cancel context.CancelFunc
|
|
||||||
ctx, cancel = context.WithTimeout(ctx, r.options.Timeout)
|
|
||||||
defer cancel()
|
|
||||||
}
|
|
||||||
|
|
||||||
host := address
|
host := address
|
||||||
if h, _, _ := net.SplitHostPort(address); h != "" {
|
if h, _, _ := net.SplitHostPort(address); h != "" {
|
||||||
host = h
|
host = h
|
||||||
@ -93,6 +87,13 @@ func (r *Router) dial(ctx context.Context, network, address string) (conn net.Co
|
|||||||
r.options.Logger.Debugf("dial %s/%s", address, network)
|
r.options.Logger.Debugf("dial %s/%s", address, network)
|
||||||
|
|
||||||
for i := 0; i < count; i++ {
|
for i := 0; i < count; i++ {
|
||||||
|
ctx := ctx
|
||||||
|
if r.options.Timeout > 0 {
|
||||||
|
var cancel context.CancelFunc
|
||||||
|
ctx, cancel = context.WithTimeout(ctx, r.options.Timeout)
|
||||||
|
defer cancel()
|
||||||
|
}
|
||||||
|
|
||||||
var ipAddr string
|
var ipAddr string
|
||||||
ipAddr, err = xnet.Resolve(ctx, "ip", address, r.options.Resolver, r.options.HostMapper, r.options.Logger)
|
ipAddr, err = xnet.Resolve(ctx, "ip", address, r.options.Resolver, r.options.HostMapper, r.options.Logger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -133,12 +134,6 @@ func (r *Router) dial(ctx context.Context, network, address string) (conn net.Co
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *Router) Bind(ctx context.Context, network, address string, opts ...chain.BindOption) (ln net.Listener, err error) {
|
func (r *Router) Bind(ctx context.Context, network, address string, opts ...chain.BindOption) (ln net.Listener, err error) {
|
||||||
if r.options.Timeout > 0 {
|
|
||||||
var cancel context.CancelFunc
|
|
||||||
ctx, cancel = context.WithTimeout(ctx, r.options.Timeout)
|
|
||||||
defer cancel()
|
|
||||||
}
|
|
||||||
|
|
||||||
count := r.options.Retries + 1
|
count := r.options.Retries + 1
|
||||||
if count <= 0 {
|
if count <= 0 {
|
||||||
count = 1
|
count = 1
|
||||||
@ -146,6 +141,13 @@ func (r *Router) Bind(ctx context.Context, network, address string, opts ...chai
|
|||||||
r.options.Logger.Debugf("bind on %s/%s", address, network)
|
r.options.Logger.Debugf("bind on %s/%s", address, network)
|
||||||
|
|
||||||
for i := 0; i < count; i++ {
|
for i := 0; i < count; i++ {
|
||||||
|
ctx := ctx
|
||||||
|
if r.options.Timeout > 0 {
|
||||||
|
var cancel context.CancelFunc
|
||||||
|
ctx, cancel = context.WithTimeout(ctx, r.options.Timeout)
|
||||||
|
defer cancel()
|
||||||
|
}
|
||||||
|
|
||||||
var route chain.Route
|
var route chain.Route
|
||||||
if r.options.Chain != nil {
|
if r.options.Chain != nil {
|
||||||
route = r.options.Chain.Route(ctx, network, address)
|
route = r.options.Chain.Route(ctx, network, address)
|
||||||
|
@ -376,6 +376,13 @@ type HTTPURLRewriteConfig struct {
|
|||||||
Replacement string
|
Replacement string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type HTTPBodyRewriteConfig struct {
|
||||||
|
// filter by MIME types
|
||||||
|
Type string
|
||||||
|
Match string
|
||||||
|
Replacement string
|
||||||
|
}
|
||||||
|
|
||||||
type NodeFilterConfig struct {
|
type NodeFilterConfig struct {
|
||||||
Host string `yaml:",omitempty" json:"host,omitempty"`
|
Host string `yaml:",omitempty" json:"host,omitempty"`
|
||||||
Protocol string `yaml:",omitempty" json:"protocol,omitempty"`
|
Protocol string `yaml:",omitempty" json:"protocol,omitempty"`
|
||||||
@ -383,10 +390,16 @@ type NodeFilterConfig struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type HTTPNodeConfig struct {
|
type HTTPNodeConfig struct {
|
||||||
Host string `yaml:",omitempty" json:"host,omitempty"`
|
// rewrite host header
|
||||||
Header map[string]string `yaml:",omitempty" json:"header,omitempty"`
|
Host string `yaml:",omitempty" json:"host,omitempty"`
|
||||||
|
// additional request header
|
||||||
|
Header map[string]string `yaml:",omitempty" json:"header,omitempty"`
|
||||||
|
// rewrite URL
|
||||||
Rewrite []HTTPURLRewriteConfig `yaml:",omitempty" json:"rewrite,omitempty"`
|
Rewrite []HTTPURLRewriteConfig `yaml:",omitempty" json:"rewrite,omitempty"`
|
||||||
Auth *AuthConfig `yaml:",omitempty" json:"auth,omitempty"`
|
// rewrite response body
|
||||||
|
RewriteBody []HTTPBodyRewriteConfig `yaml:"rewriteBody,omitempty" json:"rewriteBody,omitempty"`
|
||||||
|
// HTTP basic auth
|
||||||
|
Auth *AuthConfig `yaml:",omitempty" json:"auth,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type TLSNodeConfig struct {
|
type TLSNodeConfig struct {
|
||||||
|
@ -193,12 +193,21 @@ func ParseNode(hop string, cfg *config.NodeConfig, log logger.Logger) (*chain.No
|
|||||||
}
|
}
|
||||||
for _, v := range cfg.HTTP.Rewrite {
|
for _, v := range cfg.HTTP.Rewrite {
|
||||||
if pattern, _ := regexp.Compile(v.Match); pattern != nil {
|
if pattern, _ := regexp.Compile(v.Match); pattern != nil {
|
||||||
settings.Rewrite = append(settings.Rewrite, chain.HTTPURLRewriteSetting{
|
settings.RewriteURL = append(settings.RewriteURL, chain.HTTPURLRewriteSetting{
|
||||||
Pattern: pattern,
|
Pattern: pattern,
|
||||||
Replacement: v.Replacement,
|
Replacement: v.Replacement,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
for _, v := range cfg.HTTP.RewriteBody {
|
||||||
|
if pattern, _ := regexp.Compile(v.Match); pattern != nil {
|
||||||
|
settings.RewriteBody = append(settings.RewriteBody, chain.HTTPBodyRewriteSettings{
|
||||||
|
Type: v.Type,
|
||||||
|
Pattern: pattern,
|
||||||
|
Replacement: []byte(v.Replacement),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
opts = append(opts, chain.HTTPNodeOption(settings))
|
opts = append(opts, chain.HTTPNodeOption(settings))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -73,7 +73,7 @@ func (c *relayConnector) bindUDP(ctx context.Context, conn net.Conn, network, ad
|
|||||||
ReadQueueSize: opts.UDPDataQueueSize,
|
ReadQueueSize: opts.UDPDataQueueSize,
|
||||||
ReadBufferSize: opts.UDPDataBufferSize,
|
ReadBufferSize: opts.UDPDataBufferSize,
|
||||||
TTL: opts.UDPConnTTL,
|
TTL: opts.UDPConnTTL,
|
||||||
KeepAlive: true,
|
Keepalive: true,
|
||||||
Logger: log,
|
Logger: log,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -87,7 +87,7 @@ func (c *socks5Connector) bindUDP(ctx context.Context, conn net.Conn, network, a
|
|||||||
ReadQueueSize: opts.UDPDataQueueSize,
|
ReadQueueSize: opts.UDPDataQueueSize,
|
||||||
ReadBufferSize: opts.UDPDataBufferSize,
|
ReadBufferSize: opts.UDPDataBufferSize,
|
||||||
TTL: opts.UDPConnTTL,
|
TTL: opts.UDPConnTTL,
|
||||||
KeepAlive: true,
|
Keepalive: true,
|
||||||
Logger: log,
|
Logger: log,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
2
go.mod
2
go.mod
@ -9,7 +9,7 @@ require (
|
|||||||
github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d
|
github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d
|
||||||
github.com/gin-contrib/cors v1.6.0
|
github.com/gin-contrib/cors v1.6.0
|
||||||
github.com/gin-gonic/gin v1.9.1
|
github.com/gin-gonic/gin v1.9.1
|
||||||
github.com/go-gost/core v0.1.0
|
github.com/go-gost/core v0.1.1
|
||||||
github.com/go-gost/gosocks4 v0.0.1
|
github.com/go-gost/gosocks4 v0.0.1
|
||||||
github.com/go-gost/gosocks5 v0.4.2
|
github.com/go-gost/gosocks5 v0.4.2
|
||||||
github.com/go-gost/plugin v0.0.0-20240103125338-9c84e29cb81a
|
github.com/go-gost/plugin v0.0.0-20240103125338-9c84e29cb81a
|
||||||
|
4
go.sum
4
go.sum
@ -53,8 +53,8 @@ github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE
|
|||||||
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
|
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
|
||||||
github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg=
|
github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg=
|
||||||
github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU=
|
github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU=
|
||||||
github.com/go-gost/core v0.1.0 h1:LJJc8PIlRflE8ZIpxls+wYX1e8OGB0nUKJYh8HevM4U=
|
github.com/go-gost/core v0.1.1 h1:8joR9KJYBvpurNu3i0zqN9orQthVzOjhtT4STumwNF0=
|
||||||
github.com/go-gost/core v0.1.0/go.mod h1:WGI43jOka7FAsSAwi/fSMaqxdR+E339ycb4NBGlFr6A=
|
github.com/go-gost/core v0.1.1/go.mod h1:WGI43jOka7FAsSAwi/fSMaqxdR+E339ycb4NBGlFr6A=
|
||||||
github.com/go-gost/gosocks4 v0.0.1 h1:+k1sec8HlELuQV7rWftIkmy8UijzUt2I6t+iMPlGB2s=
|
github.com/go-gost/gosocks4 v0.0.1 h1:+k1sec8HlELuQV7rWftIkmy8UijzUt2I6t+iMPlGB2s=
|
||||||
github.com/go-gost/gosocks4 v0.0.1/go.mod h1:3B6L47HbU/qugDg4JnoFPHgJXE43Inz8Bah1QaN9qCc=
|
github.com/go-gost/gosocks4 v0.0.1/go.mod h1:3B6L47HbU/qugDg4JnoFPHgJXE43Inz8Bah1QaN9qCc=
|
||||||
github.com/go-gost/gosocks5 v0.4.2 h1:IianxHTkACPqCwiOAT3MHoMdSUl+SEPSRu1ikawC1Pc=
|
github.com/go-gost/gosocks5 v0.4.2 h1:IianxHTkACPqCwiOAT3MHoMdSUl+SEPSRu1ikawC1Pc=
|
||||||
|
@ -2,6 +2,7 @@ package local
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
@ -105,7 +106,7 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand
|
|||||||
}
|
}
|
||||||
|
|
||||||
if protocol == forward.ProtoHTTP {
|
if protocol == forward.ProtoHTTP {
|
||||||
h.handleHTTP(ctx, rw, conn.RemoteAddr(), log)
|
h.handleHTTP(ctx, xio.NewReadWriteCloser(rw, rw, conn), conn.RemoteAddr(), log)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -176,11 +177,12 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remoteAddr net.Addr, log logger.Logger) (err error) {
|
func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriteCloser, remoteAddr net.Addr, log logger.Logger) (err error) {
|
||||||
br := bufio.NewReader(rw)
|
br := bufio.NewReader(rw)
|
||||||
|
|
||||||
var cc net.Conn
|
|
||||||
for {
|
for {
|
||||||
|
var cc net.Conn
|
||||||
|
|
||||||
resp := &http.Response{
|
resp := &http.Response{
|
||||||
ProtoMajor: 1,
|
ProtoMajor: 1,
|
||||||
ProtoMinor: 1,
|
ProtoMinor: 1,
|
||||||
@ -241,6 +243,7 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot
|
|||||||
})
|
})
|
||||||
log.Debugf("find node for host %s -> %s(%s)", req.Host, target.Name, target.Addr)
|
log.Debugf("find node for host %s -> %s(%s)", req.Host, target.Name, target.Addr)
|
||||||
|
|
||||||
|
var bodyRewrites []chain.HTTPBodyRewriteSettings
|
||||||
if httpSettings := target.Options().HTTP; httpSettings != nil {
|
if httpSettings := target.Options().HTTP; httpSettings != nil {
|
||||||
if auther := httpSettings.Auther; auther != nil {
|
if auther := httpSettings.Auther; auther != nil {
|
||||||
username, password, _ := req.BasicAuth()
|
username, password, _ := req.BasicAuth()
|
||||||
@ -261,7 +264,7 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot
|
|||||||
req.Header.Set(k, v)
|
req.Header.Set(k, v)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, re := range httpSettings.Rewrite {
|
for _, re := range httpSettings.RewriteURL {
|
||||||
if re.Pattern.MatchString(req.URL.Path) {
|
if re.Pattern.MatchString(req.URL.Path) {
|
||||||
if s := re.Pattern.ReplaceAllString(req.URL.Path, re.Replacement); s != "" {
|
if s := re.Pattern.ReplaceAllString(req.URL.Path, re.Replacement); s != "" {
|
||||||
req.URL.Path = s
|
req.URL.Path = s
|
||||||
@ -269,6 +272,8 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bodyRewrites = httpSettings.RewriteBody
|
||||||
}
|
}
|
||||||
|
|
||||||
cc, err = h.options.Router.Dial(ctx, "tcp", target.Addr)
|
cc, err = h.options.Router.Dial(ctx, "tcp", target.Addr)
|
||||||
@ -329,7 +334,18 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot
|
|||||||
log.Trace(string(dump))
|
log.Trace(string(dump))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if res.Close {
|
||||||
|
defer rw.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.rewriteBody(res, bodyRewrites...); err != nil {
|
||||||
|
rw.Close()
|
||||||
|
log.Errorf("rewrite body: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if err = res.Write(rw); err != nil {
|
if err = res.Write(rw); err != nil {
|
||||||
|
rw.Close()
|
||||||
log.Errorf("write response from node %s(%s): %v", target.Name, target.Addr, err)
|
log.Errorf("write response from node %s(%s): %v", target.Name, target.Addr, err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@ -348,6 +364,54 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *forwardHandler) rewriteBody(resp *http.Response, rewrites ...chain.HTTPBodyRewriteSettings) error {
|
||||||
|
if resp == nil || len(rewrites) == 0 || resp.ContentLength <= 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if encoding := resp.Header.Get("Content-Encoding"); encoding != "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := drainBody(resp.Body)
|
||||||
|
if err != nil || body == nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
contentType, _, _ := strings.Cut(resp.Header.Get("Content-Type"), ";")
|
||||||
|
for _, rewrite := range rewrites {
|
||||||
|
rewriteType := rewrite.Type
|
||||||
|
if rewriteType == "" {
|
||||||
|
rewriteType = "text/html"
|
||||||
|
}
|
||||||
|
if rewriteType != "*" && !strings.Contains(rewriteType, contentType) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
body = rewrite.Pattern.ReplaceAll(body, rewrite.Replacement)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.Body = io.NopCloser(bytes.NewReader(body))
|
||||||
|
resp.ContentLength = int64(len(body))
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func drainBody(b io.ReadCloser) (body []byte, err error) {
|
||||||
|
if b == nil || b == http.NoBody {
|
||||||
|
// No copying needed. Preserve the magic sentinel meaning of NoBody.
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
var buf bytes.Buffer
|
||||||
|
if _, err = buf.ReadFrom(b); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err = b.Close(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return buf.Bytes(), nil
|
||||||
|
}
|
||||||
|
|
||||||
func (h *forwardHandler) checkRateLimit(addr net.Addr) bool {
|
func (h *forwardHandler) checkRateLimit(addr net.Addr) bool {
|
||||||
if h.options.RateLimiter == nil {
|
if h.options.RateLimiter == nil {
|
||||||
return true
|
return true
|
||||||
|
@ -2,6 +2,7 @@ package remote
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
@ -107,7 +108,7 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if protocol == forward.ProtoHTTP {
|
if protocol == forward.ProtoHTTP {
|
||||||
h.handleHTTP(ctx, rw, conn.RemoteAddr(), localAddr, log)
|
h.handleHTTP(ctx, xio.NewReadWriteCloser(rw, rw, conn), conn.RemoteAddr(), localAddr, log)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -177,11 +178,11 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remoteAddr net.Addr, localAddr net.Addr, log logger.Logger) (err error) {
|
func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriteCloser, remoteAddr net.Addr, localAddr net.Addr, log logger.Logger) (err error) {
|
||||||
br := bufio.NewReader(rw)
|
br := bufio.NewReader(rw)
|
||||||
var cc net.Conn
|
|
||||||
|
|
||||||
for {
|
for {
|
||||||
|
var cc net.Conn
|
||||||
resp := &http.Response{
|
resp := &http.Response{
|
||||||
ProtoMajor: 1,
|
ProtoMajor: 1,
|
||||||
ProtoMinor: 1,
|
ProtoMinor: 1,
|
||||||
@ -242,6 +243,7 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot
|
|||||||
})
|
})
|
||||||
log.Debugf("find node for host %s -> %s(%s)", req.Host, target.Name, target.Addr)
|
log.Debugf("find node for host %s -> %s(%s)", req.Host, target.Name, target.Addr)
|
||||||
|
|
||||||
|
var bodyRewrites []chain.HTTPBodyRewriteSettings
|
||||||
if httpSettings := target.Options().HTTP; httpSettings != nil {
|
if httpSettings := target.Options().HTTP; httpSettings != nil {
|
||||||
if auther := httpSettings.Auther; auther != nil {
|
if auther := httpSettings.Auther; auther != nil {
|
||||||
username, password, _ := req.BasicAuth()
|
username, password, _ := req.BasicAuth()
|
||||||
@ -261,7 +263,7 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot
|
|||||||
req.Header.Set(k, v)
|
req.Header.Set(k, v)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, re := range httpSettings.Rewrite {
|
for _, re := range httpSettings.RewriteURL {
|
||||||
if re.Pattern.MatchString(req.URL.Path) {
|
if re.Pattern.MatchString(req.URL.Path) {
|
||||||
if s := re.Pattern.ReplaceAllString(req.URL.Path, re.Replacement); s != "" {
|
if s := re.Pattern.ReplaceAllString(req.URL.Path, re.Replacement); s != "" {
|
||||||
req.URL.Path = s
|
req.URL.Path = s
|
||||||
@ -269,6 +271,8 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bodyRewrites = httpSettings.RewriteBody
|
||||||
}
|
}
|
||||||
|
|
||||||
cc, err = h.options.Router.Dial(ctx, "tcp", target.Addr)
|
cc, err = h.options.Router.Dial(ctx, "tcp", target.Addr)
|
||||||
@ -331,7 +335,18 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot
|
|||||||
log.Trace(string(dump))
|
log.Trace(string(dump))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if res.Close {
|
||||||
|
defer rw.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.rewriteBody(res, bodyRewrites...); err != nil {
|
||||||
|
rw.Close()
|
||||||
|
log.Errorf("rewrite body: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if err = res.Write(rw); err != nil {
|
if err = res.Write(rw); err != nil {
|
||||||
|
rw.Close()
|
||||||
log.Errorf("write response from node %s(%s): %v", target.Name, target.Addr, err)
|
log.Errorf("write response from node %s(%s): %v", target.Name, target.Addr, err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@ -350,6 +365,50 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *forwardHandler) rewriteBody(resp *http.Response, rewrites ...chain.HTTPBodyRewriteSettings) error {
|
||||||
|
if resp == nil || len(rewrites) == 0 || resp.ContentLength <= 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := drainBody(resp.Body)
|
||||||
|
if err != nil || body == nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
contentType, _, _ := strings.Cut(resp.Header.Get("Content-Type"), ";")
|
||||||
|
for _, rewrite := range rewrites {
|
||||||
|
rewriteType := rewrite.Type
|
||||||
|
if rewriteType == "" {
|
||||||
|
rewriteType = "text/html"
|
||||||
|
}
|
||||||
|
if rewriteType != "*" && !strings.Contains(rewriteType, contentType) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
body = rewrite.Pattern.ReplaceAll(body, rewrite.Replacement)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.Body = io.NopCloser(bytes.NewReader(body))
|
||||||
|
resp.ContentLength = int64(len(body))
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func drainBody(b io.ReadCloser) (body []byte, err error) {
|
||||||
|
if b == nil || b == http.NoBody {
|
||||||
|
// No copying needed. Preserve the magic sentinel meaning of NoBody.
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
var buf bytes.Buffer
|
||||||
|
if _, err = buf.ReadFrom(b); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err = b.Close(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return buf.Bytes(), nil
|
||||||
|
}
|
||||||
|
|
||||||
func (h *forwardHandler) checkRateLimit(addr net.Addr) bool {
|
func (h *forwardHandler) checkRateLimit(addr net.Addr) bool {
|
||||||
if h.options.RateLimiter == nil {
|
if h.options.RateLimiter == nil {
|
||||||
return true
|
return true
|
||||||
|
@ -24,6 +24,7 @@ import (
|
|||||||
md "github.com/go-gost/core/metadata"
|
md "github.com/go-gost/core/metadata"
|
||||||
"github.com/go-gost/core/observer/stats"
|
"github.com/go-gost/core/observer/stats"
|
||||||
ctxvalue "github.com/go-gost/x/ctx"
|
ctxvalue "github.com/go-gost/x/ctx"
|
||||||
|
xio "github.com/go-gost/x/internal/io"
|
||||||
netpkg "github.com/go-gost/x/internal/net"
|
netpkg "github.com/go-gost/x/internal/net"
|
||||||
stats_util "github.com/go-gost/x/internal/util/stats"
|
stats_util "github.com/go-gost/x/internal/util/stats"
|
||||||
traffic_wrapper "github.com/go-gost/x/limiter/traffic/wrapper"
|
traffic_wrapper "github.com/go-gost/x/limiter/traffic/wrapper"
|
||||||
@ -236,7 +237,7 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt
|
|||||||
}
|
}
|
||||||
|
|
||||||
if req.Method != http.MethodConnect {
|
if req.Method != http.MethodConnect {
|
||||||
return h.handleProxy(rw, cc, req, log)
|
return h.handleProxy(xio.NewReadWriteCloser(rw, rw, conn), cc, req, log)
|
||||||
}
|
}
|
||||||
|
|
||||||
resp.StatusCode = http.StatusOK
|
resp.StatusCode = http.StatusOK
|
||||||
@ -261,47 +262,92 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *httpHandler) handleProxy(rw, cc io.ReadWriter, req *http.Request, log logger.Logger) (err error) {
|
func (h *httpHandler) handleProxy(rw io.ReadWriteCloser, cc io.ReadWriter, req *http.Request, log logger.Logger) (err error) {
|
||||||
req.Header.Del("Proxy-Connection")
|
roundTrip := func(req *http.Request) error {
|
||||||
|
if req == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
if err = req.Write(cc); err != nil {
|
resp := &http.Response{
|
||||||
log.Error(err)
|
ProtoMajor: req.ProtoMajor,
|
||||||
return err
|
ProtoMinor: req.ProtoMinor,
|
||||||
}
|
Header: http.Header{},
|
||||||
|
StatusCode: http.StatusServiceUnavailable,
|
||||||
|
}
|
||||||
|
|
||||||
ch := make(chan error, 1)
|
// HTTP/1.0
|
||||||
|
if req.ProtoMajor == 1 && req.ProtoMinor == 0 {
|
||||||
|
if strings.ToLower(req.Header.Get("Connection")) == "keep-alive" {
|
||||||
|
req.Header.Del("Connection")
|
||||||
|
} else {
|
||||||
|
req.Header.Set("Connection", "close")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
go func() {
|
req.Header.Del("Proxy-Connection")
|
||||||
ch <- netpkg.CopyBuffer(rw, cc, 32*1024)
|
|
||||||
}()
|
|
||||||
|
|
||||||
for {
|
if err = req.Write(cc); err != nil {
|
||||||
err := func() error {
|
resp.Write(rw)
|
||||||
req, err := http.ReadRequest(bufio.NewReader(rw))
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
res, err := http.ReadResponse(bufio.NewReader(cc), req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
h.options.Logger.Errorf("read response: %v", err)
|
||||||
|
resp.Write(rw)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if log.IsLevelEnabled(logger.TraceLevel) {
|
if log.IsLevelEnabled(logger.TraceLevel) {
|
||||||
dump, _ := httputil.DumpRequest(req, false)
|
dump, _ := httputil.DumpResponse(res, false)
|
||||||
log.Trace(string(dump))
|
log.Trace(string(dump))
|
||||||
}
|
}
|
||||||
|
|
||||||
req.Header.Del("Proxy-Connection")
|
if res.Close {
|
||||||
|
defer rw.Close()
|
||||||
if err = req.Write(cc); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
}()
|
|
||||||
ch <- err
|
|
||||||
|
|
||||||
if err != nil {
|
// HTTP/1.0
|
||||||
break
|
if req.ProtoMajor == 1 && req.ProtoMinor == 0 {
|
||||||
}
|
if !res.Close {
|
||||||
|
res.Header.Set("Connection", "keep-alive")
|
||||||
|
}
|
||||||
|
res.ProtoMajor = req.ProtoMajor
|
||||||
|
res.ProtoMinor = req.ProtoMinor
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = res.Write(rw); err != nil {
|
||||||
|
rw.Close()
|
||||||
|
log.Errorf("write response: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return <-ch
|
if err = roundTrip(req); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
req, err := http.ReadRequest(bufio.NewReader(rw))
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if log.IsLevelEnabled(logger.TraceLevel) {
|
||||||
|
dump, _ := httputil.DumpRequest(req, false)
|
||||||
|
log.Trace(string(dump))
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = roundTrip(req); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *httpHandler) decodeServerName(s string) (string, error) {
|
func (h *httpHandler) decodeServerName(s string) (string, error) {
|
||||||
|
@ -101,6 +101,7 @@ func (h *socks5Handler) Handle(ctx context.Context, conn net.Conn, opts ...handl
|
|||||||
|
|
||||||
if clientID := sc.ID(); clientID != "" {
|
if clientID := sc.ID(); clientID != "" {
|
||||||
ctx = ctxvalue.ContextWithClientID(ctx, ctxvalue.ClientID(clientID))
|
ctx = ctxvalue.ContextWithClientID(ctx, ctxvalue.ClientID(clientID))
|
||||||
|
log = log.WithFields(map[string]any{"user": clientID})
|
||||||
}
|
}
|
||||||
|
|
||||||
conn = sc
|
conn = sc
|
||||||
|
@ -83,6 +83,9 @@ func (ep *entrypoint) handle(ctx context.Context, conn net.Conn) error {
|
|||||||
log.Trace(string(dump))
|
log.Trace(string(dump))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
resp.ProtoMajor = req.ProtoMajor
|
||||||
|
resp.ProtoMinor = req.ProtoMinor
|
||||||
|
|
||||||
var tunnelID relay.TunnelID
|
var tunnelID relay.TunnelID
|
||||||
if ep.ingress != nil {
|
if ep.ingress != nil {
|
||||||
if rule := ep.ingress.GetRule(ctx, req.Host); rule != nil {
|
if rule := ep.ingress.GetRule(ctx, req.Host); rule != nil {
|
||||||
@ -123,13 +126,15 @@ func (ep *entrypoint) handle(ctx context.Context, conn net.Conn) error {
|
|||||||
timeout: 15 * time.Second,
|
timeout: 15 * time.Second,
|
||||||
log: log,
|
log: log,
|
||||||
}
|
}
|
||||||
cc, node, cid, err := d.Dial(ctx, "tcp", tunnelID.String())
|
c, node, cid, err := d.Dial(ctx, "tcp", tunnelID.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
log.Error(err)
|
||||||
return resp.Write(conn)
|
return resp.Write(conn)
|
||||||
}
|
}
|
||||||
log.Debugf("new connection to tunnel: %s, connector: %s", tunnelID, cid)
|
log.Debugf("new connection to tunnel: %s, connector: %s", tunnelID, cid)
|
||||||
|
|
||||||
|
cc = c
|
||||||
|
|
||||||
host := req.Host
|
host := req.Host
|
||||||
if h, _, _ := net.SplitHostPort(host); h == "" {
|
if h, _, _ := net.SplitHostPort(host); h == "" {
|
||||||
host = net.JoinHostPort(strings.Trim(host, "[]"), "80")
|
host = net.JoinHostPort(strings.Trim(host, "[]"), "80")
|
||||||
@ -149,17 +154,26 @@ func (ep *entrypoint) handle(ctx context.Context, conn net.Conn) error {
|
|||||||
Version: relay.Version1,
|
Version: relay.Version1,
|
||||||
Status: relay.StatusOK,
|
Status: relay.StatusOK,
|
||||||
Features: features,
|
Features: features,
|
||||||
}).WriteTo(cc)
|
}).WriteTo(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := req.Write(cc); err != nil {
|
// HTTP/1.0
|
||||||
cc.Close()
|
if req.ProtoMajor == 1 && req.ProtoMinor == 0 {
|
||||||
|
if strings.ToLower(req.Header.Get("Connection")) == "keep-alive" {
|
||||||
|
req.Header.Del("Connection")
|
||||||
|
} else {
|
||||||
|
req.Header.Set("Connection", "close")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := req.Write(c); err != nil {
|
||||||
|
c.Close()
|
||||||
log.Errorf("send request: %v", err)
|
log.Errorf("send request: %v", err)
|
||||||
return resp.Write(conn)
|
return resp.Write(conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.Header.Get("Upgrade") == "websocket" {
|
if req.Header.Get("Upgrade") == "websocket" {
|
||||||
err := xnet.Transport(cc, xio.NewReadWriter(br, conn))
|
err := xnet.Transport(c, xio.NewReadWriter(br, conn))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
err = io.EOF
|
err = io.EOF
|
||||||
}
|
}
|
||||||
@ -167,7 +181,7 @@ func (ep *entrypoint) handle(ctx context.Context, conn net.Conn) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
defer cc.Close()
|
defer c.Close()
|
||||||
|
|
||||||
t := time.Now()
|
t := time.Now()
|
||||||
log.Debugf("%s <-> %s", remoteAddr, host)
|
log.Debugf("%s <-> %s", remoteAddr, host)
|
||||||
@ -178,7 +192,7 @@ func (ep *entrypoint) handle(ctx context.Context, conn net.Conn) error {
|
|||||||
}).Debugf("%s >-< %s", remoteAddr, host)
|
}).Debugf("%s >-< %s", remoteAddr, host)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
res, err := http.ReadResponse(bufio.NewReader(cc), req)
|
res, err := http.ReadResponse(bufio.NewReader(c), req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("read response: %v", err)
|
log.Errorf("read response: %v", err)
|
||||||
resp.Write(conn)
|
resp.Write(conn)
|
||||||
@ -190,7 +204,21 @@ func (ep *entrypoint) handle(ctx context.Context, conn net.Conn) error {
|
|||||||
log.Trace(string(dump))
|
log.Trace(string(dump))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if res.Close {
|
||||||
|
defer conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// HTTP/1.0
|
||||||
|
if req.ProtoMajor == 1 && req.ProtoMinor == 0 {
|
||||||
|
if !res.Close {
|
||||||
|
res.Header.Set("Connection", "keep-alive")
|
||||||
|
}
|
||||||
|
res.ProtoMajor = req.ProtoMajor
|
||||||
|
res.ProtoMinor = req.ProtoMinor
|
||||||
|
}
|
||||||
|
|
||||||
if err = res.Write(conn); err != nil {
|
if err = res.Write(conn); err != nil {
|
||||||
|
conn.Close()
|
||||||
log.Errorf("write response: %v", err)
|
log.Errorf("write response: %v", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
@ -264,6 +264,10 @@ func (h *tunnelHandler) Handle(ctx context.Context, conn net.Conn, opts ...handl
|
|||||||
|
|
||||||
// Close implements io.Closer interface.
|
// Close implements io.Closer interface.
|
||||||
func (h *tunnelHandler) Close() error {
|
func (h *tunnelHandler) Close() error {
|
||||||
|
if h.epSvc != nil {
|
||||||
|
h.epSvc.Close()
|
||||||
|
}
|
||||||
|
h.pool.Close()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -66,6 +66,14 @@ func (c *Connector) Session() *mux.Session {
|
|||||||
return c.s
|
return c.s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Connector) Close() error {
|
||||||
|
if c == nil || c.s == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.s.Close()
|
||||||
|
}
|
||||||
|
|
||||||
type Tunnel struct {
|
type Tunnel struct {
|
||||||
node string
|
node string
|
||||||
id relay.TunnelID
|
id relay.TunnelID
|
||||||
@ -75,7 +83,7 @@ type Tunnel struct {
|
|||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
sd sd.SD
|
sd sd.SD
|
||||||
ttl time.Duration
|
ttl time.Duration
|
||||||
rw *selector.RandomWeighted[*Connector]
|
// rw *selector.RandomWeighted[*Connector]
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTunnel(node string, tid relay.TunnelID, ttl time.Duration) *Tunnel {
|
func NewTunnel(node string, tid relay.TunnelID, ttl time.Duration) *Tunnel {
|
||||||
@ -85,7 +93,7 @@ func NewTunnel(node string, tid relay.TunnelID, ttl time.Duration) *Tunnel {
|
|||||||
t: time.Now(),
|
t: time.Now(),
|
||||||
close: make(chan struct{}),
|
close: make(chan struct{}),
|
||||||
ttl: ttl,
|
ttl: ttl,
|
||||||
rw: selector.NewRandomWeighted[*Connector](),
|
// rw: selector.NewRandomWeighted[*Connector](),
|
||||||
}
|
}
|
||||||
if t.ttl <= 0 {
|
if t.ttl <= 0 {
|
||||||
t.ttl = defaultTTL
|
t.ttl = defaultTTL
|
||||||
@ -117,8 +125,14 @@ func (t *Tunnel) GetConnector(network string) *Connector {
|
|||||||
t.mu.RLock()
|
t.mu.RLock()
|
||||||
defer t.mu.RUnlock()
|
defer t.mu.RUnlock()
|
||||||
|
|
||||||
rw := t.rw
|
// rw := t.rw
|
||||||
rw.Reset()
|
// rw.Reset()
|
||||||
|
|
||||||
|
if len(t.connectors) == 1 {
|
||||||
|
return t.connectors[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
rw := selector.NewRandomWeighted[*Connector]()
|
||||||
|
|
||||||
found := false
|
found := false
|
||||||
for _, c := range t.connectors {
|
for _, c := range t.connectors {
|
||||||
@ -147,6 +161,22 @@ func (t *Tunnel) GetConnector(network string) *Connector {
|
|||||||
return rw.Next()
|
return rw.Next()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *Tunnel) Close() error {
|
||||||
|
t.mu.Lock()
|
||||||
|
defer t.mu.Unlock()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-t.close:
|
||||||
|
default:
|
||||||
|
for _, c := range t.connectors {
|
||||||
|
c.Close()
|
||||||
|
}
|
||||||
|
close(t.close)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (t *Tunnel) CloseOnIdle() bool {
|
func (t *Tunnel) CloseOnIdle() bool {
|
||||||
t.mu.RLock()
|
t.mu.RLock()
|
||||||
defer t.mu.RUnlock()
|
defer t.mu.RUnlock()
|
||||||
@ -256,6 +286,22 @@ func (p *ConnectorPool) Get(network string, tid string) *Connector {
|
|||||||
return t.GetConnector(network)
|
return t.GetConnector(network)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *ConnectorPool) Close() error {
|
||||||
|
if p == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
p.mu.Lock()
|
||||||
|
defer p.mu.Unlock()
|
||||||
|
|
||||||
|
for k, v := range p.tunnels {
|
||||||
|
v.Close()
|
||||||
|
delete(p.tunnels, k)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (p *ConnectorPool) closeIdles() {
|
func (p *ConnectorPool) closeIdles() {
|
||||||
ticker := time.NewTicker(1 * time.Hour)
|
ticker := time.NewTicker(1 * time.Hour)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
@ -245,7 +245,7 @@ func (ing *localIngress) GetRule(ctx context.Context, host string, opts ...ingre
|
|||||||
}
|
}
|
||||||
|
|
||||||
if ep != nil {
|
if ep != nil {
|
||||||
ing.options.logger.Debugf("ingress: %s -> %s", host, ep)
|
ing.options.logger.Debugf("ingress: %s -> %s:%s", host, ep.Hostname, ep.Endpoint)
|
||||||
}
|
}
|
||||||
|
|
||||||
return ep
|
return ep
|
||||||
|
@ -13,3 +13,17 @@ func NewReadWriter(r io.Reader, w io.Writer) io.ReadWriter {
|
|||||||
Writer: w,
|
Writer: w,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type readWriteCloser struct {
|
||||||
|
io.Reader
|
||||||
|
io.Writer
|
||||||
|
io.Closer
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewReadWriteCloser(r io.Reader, w io.Writer, c io.Closer) io.ReadWriteCloser {
|
||||||
|
return &readWriteCloser{
|
||||||
|
Reader: r,
|
||||||
|
Writer: w,
|
||||||
|
Closer: c,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -17,17 +17,16 @@ type ListenConfig struct {
|
|||||||
ReadQueueSize int
|
ReadQueueSize int
|
||||||
ReadBufferSize int
|
ReadBufferSize int
|
||||||
TTL time.Duration
|
TTL time.Duration
|
||||||
KeepAlive bool
|
Keepalive bool
|
||||||
Logger logger.Logger
|
Logger logger.Logger
|
||||||
}
|
}
|
||||||
type listener struct {
|
type listener struct {
|
||||||
conn net.PacketConn
|
conn net.PacketConn
|
||||||
cqueue chan net.Conn
|
cqueue chan net.Conn
|
||||||
connPool *connPool
|
connPool *connPool
|
||||||
// mux sync.Mutex
|
closed chan struct{}
|
||||||
closed chan struct{}
|
errChan chan error
|
||||||
errChan chan error
|
config *ListenConfig
|
||||||
config *ListenConfig
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewListener(conn net.PacketConn, cfg *ListenConfig) net.Listener {
|
func NewListener(conn net.PacketConn, cfg *ListenConfig) net.Listener {
|
||||||
@ -42,9 +41,7 @@ func NewListener(conn net.PacketConn, cfg *ListenConfig) net.Listener {
|
|||||||
errChan: make(chan error, 1),
|
errChan: make(chan error, 1),
|
||||||
config: cfg,
|
config: cfg,
|
||||||
}
|
}
|
||||||
if cfg.KeepAlive {
|
ln.connPool = newConnPool(cfg.TTL).WithLogger(cfg.Logger)
|
||||||
ln.connPool = newConnPool(cfg.TTL).WithLogger(cfg.Logger)
|
|
||||||
}
|
|
||||||
go ln.listenLoop()
|
go ln.listenLoop()
|
||||||
|
|
||||||
return ln
|
return ln
|
||||||
@ -113,15 +110,12 @@ func (ln *listener) Close() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (ln *listener) getConn(raddr net.Addr) *conn {
|
func (ln *listener) getConn(raddr net.Addr) *conn {
|
||||||
// ln.mux.Lock()
|
|
||||||
// defer ln.mux.Unlock()
|
|
||||||
|
|
||||||
c, ok := ln.connPool.Get(raddr.String())
|
c, ok := ln.connPool.Get(raddr.String())
|
||||||
if ok {
|
if ok && !c.isClosed() {
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
c = newConn(ln.conn, ln.Addr(), raddr, ln.config.ReadQueueSize, ln.config.KeepAlive)
|
c = newConn(ln.conn, ln.Addr(), raddr, ln.config.ReadQueueSize, ln.config.Keepalive)
|
||||||
select {
|
select {
|
||||||
case ln.cqueue <- c:
|
case ln.cqueue <- c:
|
||||||
ln.connPool.Set(raddr.String(), c)
|
ln.connPool.Set(raddr.String(), c)
|
||||||
@ -142,17 +136,17 @@ type conn struct {
|
|||||||
idle int32 // indicate the connection is idle
|
idle int32 // indicate the connection is idle
|
||||||
closed chan struct{}
|
closed chan struct{}
|
||||||
closeMutex sync.Mutex
|
closeMutex sync.Mutex
|
||||||
keepAlive bool
|
keepalive bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func newConn(c net.PacketConn, laddr, remoteAddr net.Addr, queueSize int, keepAlive bool) *conn {
|
func newConn(c net.PacketConn, laddr, remoteAddr net.Addr, queueSize int, keepalive bool) *conn {
|
||||||
return &conn{
|
return &conn{
|
||||||
PacketConn: c,
|
PacketConn: c,
|
||||||
localAddr: laddr,
|
localAddr: laddr,
|
||||||
remoteAddr: remoteAddr,
|
remoteAddr: remoteAddr,
|
||||||
rc: make(chan []byte, queueSize),
|
rc: make(chan []byte, queueSize),
|
||||||
closed: make(chan struct{}),
|
closed: make(chan struct{}),
|
||||||
keepAlive: keepAlive,
|
keepalive: keepalive,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -178,12 +172,15 @@ func (c *conn) Read(b []byte) (n int, err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *conn) Write(b []byte) (n int, err error) {
|
func (c *conn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
|
||||||
n, err = c.WriteTo(b, c.remoteAddr)
|
if !c.keepalive {
|
||||||
if !c.keepAlive {
|
defer c.Close()
|
||||||
c.Close()
|
|
||||||
}
|
}
|
||||||
return
|
return c.PacketConn.WriteTo(b, addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *conn) Write(b []byte) (n int, err error) {
|
||||||
|
return c.WriteTo(b, c.remoteAddr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *conn) Close() error {
|
func (c *conn) Close() error {
|
||||||
@ -198,6 +195,15 @@ func (c *conn) Close() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *conn) isClosed() bool {
|
||||||
|
select {
|
||||||
|
case <-c.closed:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (c *conn) LocalAddr() net.Addr {
|
func (c *conn) LocalAddr() net.Addr {
|
||||||
return c.localAddr
|
return c.localAddr
|
||||||
}
|
}
|
||||||
|
@ -64,7 +64,7 @@ func (l *ftcpListener) Init(md md.Metadata) (err error) {
|
|||||||
ReadQueueSize: l.md.readQueueSize,
|
ReadQueueSize: l.md.readQueueSize,
|
||||||
ReadBufferSize: l.md.readBufferSize,
|
ReadBufferSize: l.md.readBufferSize,
|
||||||
TTL: l.md.ttl,
|
TTL: l.md.ttl,
|
||||||
KeepAlive: true,
|
Keepalive: true,
|
||||||
Logger: l.logger,
|
Logger: l.logger,
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
|
@ -9,7 +9,7 @@ import (
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
defaultTTL = 5 * time.Second
|
defaultTTL = 5 * time.Second
|
||||||
defaultReadBufferSize = 1024
|
defaultReadBufferSize = 8192
|
||||||
defaultReadQueueSize = 1024
|
defaultReadQueueSize = 1024
|
||||||
defaultBacklog = 128
|
defaultBacklog = 128
|
||||||
)
|
)
|
||||||
|
@ -65,7 +65,7 @@ func (l *udpListener) Init(md md.Metadata) (err error) {
|
|||||||
Backlog: l.md.backlog,
|
Backlog: l.md.backlog,
|
||||||
ReadQueueSize: l.md.readQueueSize,
|
ReadQueueSize: l.md.readQueueSize,
|
||||||
ReadBufferSize: l.md.readBufferSize,
|
ReadBufferSize: l.md.readBufferSize,
|
||||||
KeepAlive: l.md.keepalive,
|
Keepalive: l.md.keepalive,
|
||||||
TTL: l.md.ttl,
|
TTL: l.md.ttl,
|
||||||
Logger: l.logger,
|
Logger: l.logger,
|
||||||
})
|
})
|
||||||
|
@ -9,7 +9,7 @@ import (
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
defaultTTL = 5 * time.Second
|
defaultTTL = 5 * time.Second
|
||||||
defaultReadBufferSize = 1024
|
defaultReadBufferSize = 8192
|
||||||
defaultReadQueueSize = 128
|
defaultReadQueueSize = 128
|
||||||
defaultBacklog = 128
|
defaultBacklog = 128
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user