From 1759c95e78041d07e0cb930b26601be2112f3c88 Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Tue, 17 Oct 2023 21:55:07 +0800 Subject: [PATCH] add mux config --- connector/relay/bind.go | 4 +- connector/relay/metadata.go | 19 ++++ connector/socks/v5/bind.go | 2 +- connector/tunnel/bind.go | 2 +- dialer/mtcp/conn.go | 38 ++++++++ dialer/mtcp/dialer.go | 131 ++++++++++++++++++++++++++++ dialer/mtcp/metadata.go | 30 +++++++ handler/file/handler.go | 133 ++++++++++++++++++++++++++++ handler/file/metadata.go | 15 ++++ handler/forward/local/handler.go | 6 +- handler/forward/remote/handler.go | 6 +- handler/metrics/handler.go | 134 +++++++++++++++++++++++++++++ handler/metrics/metadata.go | 22 +++++ handler/relay/bind.go | 4 +- handler/relay/entrypoint.go | 3 +- handler/socks/v5/mbind.go | 2 +- handler/tunnel/bind.go | 2 +- handler/tunnel/metadata.go | 19 ++++ internal/util/mux/mux.go | 65 +++++++++++++- limiter/traffic/traffic.go | 16 ++-- listener/mtcp/listener.go | 138 ++++++++++++++++++++++++++++++ listener/mtcp/metadata.go | 37 ++++++++ 22 files changed, 805 insertions(+), 23 deletions(-) create mode 100644 dialer/mtcp/conn.go create mode 100644 dialer/mtcp/dialer.go create mode 100644 dialer/mtcp/metadata.go create mode 100644 handler/file/handler.go create mode 100644 handler/file/metadata.go create mode 100644 handler/metrics/handler.go create mode 100644 handler/metrics/metadata.go create mode 100644 listener/mtcp/listener.go create mode 100644 listener/mtcp/metadata.go diff --git a/connector/relay/bind.go b/connector/relay/bind.go index e703378..1e6b04c 100644 --- a/connector/relay/bind.go +++ b/connector/relay/bind.go @@ -50,7 +50,7 @@ func (c *relayConnector) bindTunnel(ctx context.Context, conn net.Conn, network, } log.Infof("create tunnel on %s/%s OK, tunnel=%s, connector=%s", addr, network, c.md.tunnelID.String(), cid) - session, err := mux.ServerSession(conn) + session, err := mux.ServerSession(conn, c.md.muxCfg) if err != nil { return nil, err } @@ -130,7 +130,7 @@ func (c *relayConnector) bindTCP(ctx context.Context, conn net.Conn, network, ad } log.Debugf("bind on %s/%s OK", laddr, laddr.Network()) - session, err := mux.ServerSession(conn) + session, err := mux.ServerSession(conn, c.md.muxCfg) if err != nil { return nil, err } diff --git a/connector/relay/metadata.go b/connector/relay/metadata.go index 78fc681..0ce4f14 100644 --- a/connector/relay/metadata.go +++ b/connector/relay/metadata.go @@ -6,13 +6,19 @@ import ( mdata "github.com/go-gost/core/metadata" mdutil "github.com/go-gost/core/metadata/util" "github.com/go-gost/relay" + "github.com/go-gost/x/internal/util/mux" "github.com/google/uuid" ) +const ( + defaultMuxVersion = 2 +) + type metadata struct { connectTimeout time.Duration noDelay bool tunnelID relay.TunnelID + muxCfg *mux.Config } func (c *relayConnector) parseMetadata(md mdata.Metadata) (err error) { @@ -32,5 +38,18 @@ func (c *relayConnector) parseMetadata(md mdata.Metadata) (err error) { c.md.tunnelID = relay.NewTunnelID(uuid[:]) } + c.md.muxCfg = &mux.Config{ + Version: mdutil.GetInt(md, "mux.version"), + KeepAliveInterval: mdutil.GetDuration(md, "mux.keepaliveInterval"), + KeepAliveDisabled: mdutil.GetBool(md, "mux.keepaliveDisabled"), + KeepAliveTimeout: mdutil.GetDuration(md, "mux.keepaliveTimeout"), + MaxFrameSize: mdutil.GetInt(md, "mux.maxFrameSize"), + MaxReceiveBuffer: mdutil.GetInt(md, "mux.maxReceiveBuffer"), + MaxStreamBuffer: mdutil.GetInt(md, "mux.maxStreamBuffer"), + } + if c.md.muxCfg.Version == 0 { + c.md.muxCfg.Version = defaultMuxVersion + } + return } diff --git a/connector/socks/v5/bind.go b/connector/socks/v5/bind.go index aa4cb39..6922880 100644 --- a/connector/socks/v5/bind.go +++ b/connector/socks/v5/bind.go @@ -62,7 +62,7 @@ func (c *socks5Connector) muxBindTCP(ctx context.Context, conn net.Conn, network return nil, err } - session, err := mux.ServerSession(conn) + session, err := mux.ServerSession(conn, nil) if err != nil { return nil, err } diff --git a/connector/tunnel/bind.go b/connector/tunnel/bind.go index ed2a093..0f6f9e8 100644 --- a/connector/tunnel/bind.go +++ b/connector/tunnel/bind.go @@ -21,7 +21,7 @@ func (c *tunnelConnector) Bind(ctx context.Context, conn net.Conn, network, addr } log.Infof("create tunnel on %s/%s OK, tunnel=%s, connector=%s", addr, network, c.md.tunnelID.String(), cid) - session, err := mux.ServerSession(conn) + session, err := mux.ServerSession(conn, nil) if err != nil { return nil, err } diff --git a/dialer/mtcp/conn.go b/dialer/mtcp/conn.go new file mode 100644 index 0000000..491449b --- /dev/null +++ b/dialer/mtcp/conn.go @@ -0,0 +1,38 @@ +package mtcp + +import ( + "net" + + "github.com/go-gost/x/internal/util/mux" +) + +type muxSession struct { + conn net.Conn + session *mux.Session +} + +func (session *muxSession) GetConn() (net.Conn, error) { + return session.session.GetConn() +} + +func (session *muxSession) Accept() (net.Conn, error) { + return session.session.Accept() +} + +func (session *muxSession) Close() error { + if session.session == nil { + return nil + } + return session.session.Close() +} + +func (session *muxSession) IsClosed() bool { + if session.session == nil { + return true + } + return session.session.IsClosed() +} + +func (session *muxSession) NumStreams() int { + return session.session.NumStreams() +} diff --git a/dialer/mtcp/dialer.go b/dialer/mtcp/dialer.go new file mode 100644 index 0000000..a9d1831 --- /dev/null +++ b/dialer/mtcp/dialer.go @@ -0,0 +1,131 @@ +package mtcp + +import ( + "context" + "errors" + "net" + "sync" + "time" + + "github.com/go-gost/core/dialer" + "github.com/go-gost/core/logger" + md "github.com/go-gost/core/metadata" + "github.com/go-gost/x/internal/util/mux" + "github.com/go-gost/x/registry" +) + +func init() { + registry.DialerRegistry().Register("mtcp", NewDialer) +} + +type mtcpDialer struct { + sessions map[string]*muxSession + sessionMutex sync.Mutex + logger logger.Logger + md metadata + options dialer.Options +} + +func NewDialer(opts ...dialer.Option) dialer.Dialer { + options := dialer.Options{} + for _, opt := range opts { + opt(&options) + } + + return &mtcpDialer{ + sessions: make(map[string]*muxSession), + logger: options.Logger, + options: options, + } +} + +func (d *mtcpDialer) Init(md md.Metadata) (err error) { + if err = d.parseMetadata(md); err != nil { + return + } + + return nil +} + +// Multiplex implements dialer.Multiplexer interface. +func (d *mtcpDialer) Multiplex() bool { + return true +} + +func (d *mtcpDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOption) (conn net.Conn, err error) { + d.sessionMutex.Lock() + defer d.sessionMutex.Unlock() + + session, ok := d.sessions[addr] + if session != nil && session.IsClosed() { + delete(d.sessions, addr) // session is dead + ok = false + } + if !ok { + var options dialer.DialOptions + for _, opt := range opts { + opt(&options) + } + + conn, err = options.NetDialer.Dial(ctx, "tcp", addr) + if err != nil { + return + } + + session = &muxSession{conn: conn} + d.sessions[addr] = session + } + + return session.conn, err +} + +// Handshake implements dialer.Handshaker +func (d *mtcpDialer) Handshake(ctx context.Context, conn net.Conn, options ...dialer.HandshakeOption) (net.Conn, error) { + opts := &dialer.HandshakeOptions{} + for _, option := range options { + option(opts) + } + + d.sessionMutex.Lock() + defer d.sessionMutex.Unlock() + + if d.md.handshakeTimeout > 0 { + conn.SetDeadline(time.Now().Add(d.md.handshakeTimeout)) + defer conn.SetDeadline(time.Time{}) + } + + session, ok := d.sessions[opts.Addr] + if session != nil && session.conn != conn { + conn.Close() + return nil, errors.New("mtls: unrecognized connection") + } + + if !ok || session.session == nil { + s, err := d.initSession(ctx, conn) + if err != nil { + d.logger.Error(err) + conn.Close() + delete(d.sessions, opts.Addr) + return nil, err + } + session = s + d.sessions[opts.Addr] = session + } + cc, err := session.GetConn() + if err != nil { + session.Close() + delete(d.sessions, opts.Addr) + return nil, err + } + + return cc, nil +} + +func (d *mtcpDialer) initSession(ctx context.Context, conn net.Conn) (*muxSession, error) { + // stream multiplex + session, err := mux.ClientSession(conn, d.md.muxCfg) + if err != nil { + return nil, err + } + return &muxSession{conn: conn, session: session}, nil +} diff --git a/dialer/mtcp/metadata.go b/dialer/mtcp/metadata.go new file mode 100644 index 0000000..abc885c --- /dev/null +++ b/dialer/mtcp/metadata.go @@ -0,0 +1,30 @@ +package mtcp + +import ( + "time" + + mdata "github.com/go-gost/core/metadata" + mdutil "github.com/go-gost/core/metadata/util" + "github.com/go-gost/x/internal/util/mux" +) + +type metadata struct { + handshakeTimeout time.Duration + muxCfg *mux.Config +} + +func (d *mtcpDialer) parseMetadata(md mdata.Metadata) (err error) { + d.md.handshakeTimeout = mdutil.GetDuration(md, "handshakeTimeout") + + d.md.muxCfg = &mux.Config{ + Version: mdutil.GetInt(md, "mux.version"), + KeepAliveInterval: mdutil.GetDuration(md, "mux.keepaliveInterval"), + KeepAliveDisabled: mdutil.GetBool(md, "mux.keepaliveDisabled"), + KeepAliveTimeout: mdutil.GetDuration(md, "mux.keepaliveTimeout"), + MaxFrameSize: mdutil.GetInt(md, "mux.maxFrameSize"), + MaxReceiveBuffer: mdutil.GetInt(md, "mux.maxReceiveBuffer"), + MaxStreamBuffer: mdutil.GetInt(md, "mux.maxStreamBuffer"), + } + + return +} diff --git a/handler/file/handler.go b/handler/file/handler.go new file mode 100644 index 0000000..2c69b3f --- /dev/null +++ b/handler/file/handler.go @@ -0,0 +1,133 @@ +package file + +import ( + "context" + "net" + "net/http" + "sync" + "time" + + "github.com/go-gost/core/handler" + md "github.com/go-gost/core/metadata" + "github.com/go-gost/x/registry" +) + +func init() { + registry.HandlerRegistry().Register("file", NewHandler) +} + +type fileHandler struct { + handler http.Handler + server *http.Server + ln *singleConnListener + md metadata + options handler.Options +} + +func NewHandler(opts ...handler.Option) handler.Handler { + options := handler.Options{} + for _, opt := range opts { + opt(&options) + } + + return &fileHandler{ + options: options, + } +} + +func (h *fileHandler) Init(md md.Metadata) (err error) { + if err = h.parseMetadata(md); err != nil { + return + } + + h.handler = http.FileServer(http.Dir(h.md.dir)) + h.server = &http.Server{ + Handler: http.HandlerFunc(h.handleFunc), + } + + h.ln = &singleConnListener{ + conn: make(chan net.Conn), + done: make(chan struct{}), + } + go h.server.Serve(h.ln) + + return +} + +func (h *fileHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.HandleOption) error { + h.options.Logger.WithFields(map[string]any{ + "remote": conn.RemoteAddr().String(), + "local": conn.LocalAddr().String(), + }).Infof("%s - %s", conn.RemoteAddr(), conn.LocalAddr()) + + h.ln.send(conn) + + return nil +} + +func (h *fileHandler) Close() error { + return h.server.Close() +} + +func (h *fileHandler) handleFunc(w http.ResponseWriter, r *http.Request) { + if auther := h.options.Auther; auther != nil { + u, p, _ := r.BasicAuth() + if _, ok := auther.Authenticate(r.Context(), u, p); !ok { + w.WriteHeader(http.StatusUnauthorized) + return + } + } + + log := h.options.Logger + start := time.Now() + + h.handler.ServeHTTP(w, r) + + log = log.WithFields(map[string]any{ + "remote": r.RemoteAddr, + "duration": time.Since(start), + }) + log.Infof("%s %s", r.Method, r.RequestURI) +} + +type singleConnListener struct { + conn chan net.Conn + addr net.Addr + done chan struct{} + mu sync.Mutex +} + +func (l *singleConnListener) Accept() (net.Conn, error) { + select { + case conn := <-l.conn: + return conn, nil + + case <-l.done: + return nil, net.ErrClosed + } +} + +func (l *singleConnListener) Close() error { + l.mu.Lock() + defer l.mu.Unlock() + + select { + case <-l.done: + default: + close(l.done) + } + + return nil +} + +func (l *singleConnListener) Addr() net.Addr { + return l.addr +} + +func (l *singleConnListener) send(conn net.Conn) { + select { + case l.conn <- conn: + case <-l.done: + return + } +} diff --git a/handler/file/metadata.go b/handler/file/metadata.go new file mode 100644 index 0000000..7ea4c45 --- /dev/null +++ b/handler/file/metadata.go @@ -0,0 +1,15 @@ +package file + +import ( + mdata "github.com/go-gost/core/metadata" + mdutil "github.com/go-gost/core/metadata/util" +) + +type metadata struct { + dir string +} + +func (h *fileHandler) parseMetadata(md mdata.Metadata) (err error) { + h.md.dir = mdutil.GetString(md, "file.dir", "dir") + return +} diff --git a/handler/forward/local/handler.go b/handler/forward/local/handler.go index 847f2e3..c8fa60e 100644 --- a/handler/forward/local/handler.go +++ b/handler/forward/local/handler.go @@ -276,7 +276,11 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, log l log.Warnf("read response from node %s(%s): %v", target.Name, target.Addr, err) return resp.Write(rw) } - defer res.Body.Close() + + if log.IsLevelEnabled(logger.TraceLevel) { + dump, _ := httputil.DumpResponse(res, false) + log.Trace(string(dump)) + } return res.Write(rw) }() diff --git a/handler/forward/remote/handler.go b/handler/forward/remote/handler.go index 233c57a..1191a23 100644 --- a/handler/forward/remote/handler.go +++ b/handler/forward/remote/handler.go @@ -281,7 +281,11 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, remot log.Warnf("read response from node %s(%s): %v", target.Name, target.Addr, err) return resp.Write(rw) } - defer res.Body.Close() + + if log.IsLevelEnabled(logger.TraceLevel) { + dump, _ := httputil.DumpResponse(res, false) + log.Trace(string(dump)) + } return res.Write(rw) }() diff --git a/handler/metrics/handler.go b/handler/metrics/handler.go new file mode 100644 index 0000000..47f07a3 --- /dev/null +++ b/handler/metrics/handler.go @@ -0,0 +1,134 @@ +package file + +import ( + "context" + "net" + "net/http" + "sync" + "time" + + "github.com/go-gost/core/handler" + md "github.com/go-gost/core/metadata" + xmetrics "github.com/go-gost/x/metrics" + "github.com/go-gost/x/registry" + "github.com/prometheus/client_golang/prometheus/promhttp" +) + +func init() { + registry.HandlerRegistry().Register("metrics", NewHandler) +} + +type metricsHandler struct { + handler http.Handler + server *http.Server + ln *singleConnListener + md metadata + options handler.Options +} + +func NewHandler(opts ...handler.Option) handler.Handler { + options := handler.Options{} + for _, opt := range opts { + opt(&options) + } + + return &metricsHandler{ + options: options, + } +} + +func (h *metricsHandler) Init(md md.Metadata) (err error) { + if err = h.parseMetadata(md); err != nil { + return + } + + xmetrics.Init(xmetrics.NewMetrics()) + h.handler = promhttp.Handler() + + mux := http.NewServeMux() + mux.Handle(h.md.path, http.HandlerFunc(h.handleFunc)) + h.server = &http.Server{ + Handler: mux, + } + + h.ln = &singleConnListener{ + conn: make(chan net.Conn), + done: make(chan struct{}), + } + go h.server.Serve(h.ln) + + return +} + +func (h *metricsHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.HandleOption) error { + h.ln.send(conn) + + return nil +} + +func (h *metricsHandler) Close() error { + return h.server.Close() +} + +func (h *metricsHandler) handleFunc(w http.ResponseWriter, r *http.Request) { + if auther := h.options.Auther; auther != nil { + u, p, _ := r.BasicAuth() + if _, ok := auther.Authenticate(r.Context(), u, p); !ok { + w.WriteHeader(http.StatusUnauthorized) + return + } + } + + log := h.options.Logger + start := time.Now() + + h.handler.ServeHTTP(w, r) + + log = log.WithFields(map[string]any{ + "remote": r.RemoteAddr, + "duration": time.Since(start), + }) + log.Debugf("%s %s", r.Method, r.RequestURI) +} + +type singleConnListener struct { + conn chan net.Conn + addr net.Addr + done chan struct{} + mu sync.Mutex +} + +func (l *singleConnListener) Accept() (net.Conn, error) { + select { + case conn := <-l.conn: + return conn, nil + + case <-l.done: + return nil, net.ErrClosed + } +} + +func (l *singleConnListener) Close() error { + l.mu.Lock() + defer l.mu.Unlock() + + select { + case <-l.done: + default: + close(l.done) + } + + return nil +} + +func (l *singleConnListener) Addr() net.Addr { + return l.addr +} + +func (l *singleConnListener) send(conn net.Conn) { + select { + case l.conn <- conn: + case <-l.done: + return + } +} diff --git a/handler/metrics/metadata.go b/handler/metrics/metadata.go new file mode 100644 index 0000000..b9da2d8 --- /dev/null +++ b/handler/metrics/metadata.go @@ -0,0 +1,22 @@ +package file + +import ( + mdata "github.com/go-gost/core/metadata" + mdutil "github.com/go-gost/core/metadata/util" +) + +const ( + DefaultPath = "/metrics" +) + +type metadata struct { + path string +} + +func (h *metricsHandler) parseMetadata(md mdata.Metadata) (err error) { + h.md.path = mdutil.GetString(md, "metrics.path", "path") + if h.md.path == "" { + h.md.path = DefaultPath + } + return +} diff --git a/handler/relay/bind.go b/handler/relay/bind.go index e9e4fb3..805206d 100644 --- a/handler/relay/bind.go +++ b/handler/relay/bind.go @@ -81,7 +81,7 @@ func (h *relayHandler) bindTCP(ctx context.Context, conn net.Conn, network, addr } // Upgrade connection to multiplex session. - session, err := mux.ClientSession(conn) + session, err := mux.ClientSession(conn, nil) if err != nil { log.Error(err) return err @@ -219,7 +219,7 @@ func (h *relayHandler) handleBindTunnel(ctx context.Context, conn net.Conn, netw resp.WriteTo(conn) // Upgrade connection to multiplex session. - session, err := mux.ClientSession(conn) + session, err := mux.ClientSession(conn, nil) if err != nil { return } diff --git a/handler/relay/entrypoint.go b/handler/relay/entrypoint.go index c77c67c..839f7c6 100644 --- a/handler/relay/entrypoint.go +++ b/handler/relay/entrypoint.go @@ -17,7 +17,6 @@ import ( md "github.com/go-gost/core/metadata" "github.com/go-gost/relay" admission "github.com/go-gost/x/admission/wrapper" - netpkg "github.com/go-gost/x/internal/net" xnet "github.com/go-gost/x/internal/net" "github.com/go-gost/x/internal/net/proxyproto" "github.com/go-gost/x/internal/util/forward" @@ -126,7 +125,7 @@ func (h *tcpHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler. t := time.Now() log.Debugf("%s <-> %s", conn.RemoteAddr(), cc.RemoteAddr()) - netpkg.Transport(conn, cc) + xnet.Transport(conn, cc) log.WithFields(map[string]any{"duration": time.Since(t)}). Debugf("%s >-< %s", conn.RemoteAddr(), cc.RemoteAddr()) return nil diff --git a/handler/socks/v5/mbind.go b/handler/socks/v5/mbind.go index 53fce3d..a5ac3d3 100644 --- a/handler/socks/v5/mbind.go +++ b/handler/socks/v5/mbind.go @@ -70,7 +70,7 @@ func (h *socks5Handler) muxBindLocal(ctx context.Context, conn net.Conn, network func (h *socks5Handler) serveMuxBind(ctx context.Context, conn net.Conn, ln net.Listener, log logger.Logger) error { // Upgrade connection to multiplex stream. - session, err := mux.ClientSession(conn) + session, err := mux.ClientSession(conn, nil) if err != nil { log.Error(err) return err diff --git a/handler/tunnel/bind.go b/handler/tunnel/bind.go index d9de83a..fc0d6dc 100644 --- a/handler/tunnel/bind.go +++ b/handler/tunnel/bind.go @@ -48,7 +48,7 @@ func (h *tunnelHandler) handleBind(ctx context.Context, conn net.Conn, network, resp.WriteTo(conn) // Upgrade connection to multiplex session. - session, err := mux.ClientSession(conn) + session, err := mux.ClientSession(conn, h.md.muxCfg) if err != nil { return } diff --git a/handler/tunnel/metadata.go b/handler/tunnel/metadata.go index 4ed5206..e94b395 100644 --- a/handler/tunnel/metadata.go +++ b/handler/tunnel/metadata.go @@ -10,9 +10,14 @@ import ( mdutil "github.com/go-gost/core/metadata/util" "github.com/go-gost/relay" xingress "github.com/go-gost/x/ingress" + "github.com/go-gost/x/internal/util/mux" "github.com/go-gost/x/registry" ) +const ( + defaultMuxVersion = 2 +) + type metadata struct { readTimeout time.Duration noDelay bool @@ -20,6 +25,7 @@ type metadata struct { directTunnel bool entryPointID relay.TunnelID ingress ingress.Ingress + muxCfg *mux.Config } func (h *tunnelHandler) parseMetadata(md mdata.Metadata) (err error) { @@ -54,5 +60,18 @@ func (h *tunnelHandler) parseMetadata(md mdata.Metadata) (err error) { } } + h.md.muxCfg = &mux.Config{ + Version: mdutil.GetInt(md, "mux.version"), + KeepAliveInterval: mdutil.GetDuration(md, "mux.keepaliveInterval"), + KeepAliveDisabled: mdutil.GetBool(md, "mux.keepaliveDisabled"), + KeepAliveTimeout: mdutil.GetDuration(md, "mux.keepaliveTimeout"), + MaxFrameSize: mdutil.GetInt(md, "mux.maxFrameSize"), + MaxReceiveBuffer: mdutil.GetInt(md, "mux.maxReceiveBuffer"), + MaxStreamBuffer: mdutil.GetInt(md, "mux.maxStreamBuffer"), + } + if h.md.muxCfg.Version == 0 { + h.md.muxCfg.Version = defaultMuxVersion + } + return } diff --git a/internal/util/mux/mux.go b/internal/util/mux/mux.go index 4fae020..3aaf483 100644 --- a/internal/util/mux/mux.go +++ b/internal/util/mux/mux.go @@ -2,17 +2,74 @@ package mux import ( "net" + "time" smux "github.com/xtaci/smux" ) +type Config struct { + // SMUX Protocol version, support 1,2 + Version int + + // Disabled keepalive + KeepAliveDisabled bool + + // KeepAliveInterval is how often to send a NOP command to the remote + KeepAliveInterval time.Duration + + // KeepAliveTimeout is how long the session + // will be closed if no data has arrived + KeepAliveTimeout time.Duration + + // MaxFrameSize is used to control the maximum + // frame size to sent to the remote + MaxFrameSize int + + // MaxReceiveBuffer is used to control the maximum + // number of data in the buffer pool + MaxReceiveBuffer int + + // MaxStreamBuffer is used to control the maximum + // number of data per stream + MaxStreamBuffer int +} + +func convertConfig(cfg *Config) *smux.Config { + smuxCfg := smux.DefaultConfig() + if cfg == nil { + return smuxCfg + } + + if cfg.Version > 0 { + smuxCfg.Version = cfg.Version + } + smuxCfg.KeepAliveDisabled = cfg.KeepAliveDisabled + if cfg.KeepAliveInterval > 0 { + smuxCfg.KeepAliveInterval = cfg.KeepAliveInterval + } + if cfg.KeepAliveTimeout > 0 { + smuxCfg.KeepAliveTimeout = cfg.KeepAliveTimeout + } + if cfg.MaxFrameSize > 0 { + smuxCfg.MaxFrameSize = cfg.MaxFrameSize + } + if cfg.MaxReceiveBuffer > 0 { + smuxCfg.MaxReceiveBuffer = cfg.MaxReceiveBuffer + } + if cfg.MaxStreamBuffer > 0 { + smuxCfg.MaxStreamBuffer = cfg.MaxStreamBuffer + } + + return smuxCfg +} + type Session struct { conn net.Conn session *smux.Session } -func ClientSession(conn net.Conn) (*Session, error) { - s, err := smux.Client(conn, smux.DefaultConfig()) +func ClientSession(conn net.Conn, cfg *Config) (*Session, error) { + s, err := smux.Client(conn, convertConfig(cfg)) if err != nil { return nil, err } @@ -22,8 +79,8 @@ func ClientSession(conn net.Conn) (*Session, error) { }, nil } -func ServerSession(conn net.Conn) (*Session, error) { - s, err := smux.Server(conn, smux.DefaultConfig()) +func ServerSession(conn net.Conn, cfg *Config) (*Session, error) { + s, err := smux.Server(conn, convertConfig(cfg)) if err != nil { return nil, err } diff --git a/limiter/traffic/traffic.go b/limiter/traffic/traffic.go index 95c29a7..12b16bb 100644 --- a/limiter/traffic/traffic.go +++ b/limiter/traffic/traffic.go @@ -82,13 +82,15 @@ type limitValue struct { type trafficLimiter struct { generators sync.Map cidrGenerators cidranger.Ranger - connInLimits *cache.Cache - connOutLimits *cache.Cache - inLimits *cache.Cache - outLimits *cache.Cache - mu sync.RWMutex - cancelFunc context.CancelFunc - options options + // connection level in/out limits + connInLimits *cache.Cache + connOutLimits *cache.Cache + // service level in/out limits + inLimits *cache.Cache + outLimits *cache.Cache + mu sync.RWMutex + cancelFunc context.CancelFunc + options options } func NewTrafficLimiter(opts ...Option) limiter.TrafficLimiter { diff --git a/listener/mtcp/listener.go b/listener/mtcp/listener.go new file mode 100644 index 0000000..b483910 --- /dev/null +++ b/listener/mtcp/listener.go @@ -0,0 +1,138 @@ +package mtcp + +import ( + "context" + "net" + "time" + + "github.com/go-gost/core/listener" + "github.com/go-gost/core/logger" + md "github.com/go-gost/core/metadata" + admission "github.com/go-gost/x/admission/wrapper" + xnet "github.com/go-gost/x/internal/net" + "github.com/go-gost/x/internal/net/proxyproto" + "github.com/go-gost/x/internal/util/mux" + climiter "github.com/go-gost/x/limiter/conn/wrapper" + limiter "github.com/go-gost/x/limiter/traffic/wrapper" + metrics "github.com/go-gost/x/metrics/wrapper" + "github.com/go-gost/x/registry" +) + +func init() { + registry.ListenerRegistry().Register("mtcp", NewListener) +} + +type mtcpListener struct { + ln net.Listener + cqueue chan net.Conn + errChan chan error + logger logger.Logger + md metadata + options listener.Options +} + +func NewListener(opts ...listener.Option) listener.Listener { + options := listener.Options{} + for _, opt := range opts { + opt(&options) + } + return &mtcpListener{ + logger: options.Logger, + options: options, + } +} + +func (l *mtcpListener) Init(md md.Metadata) (err error) { + if err = l.parseMetadata(md); err != nil { + return + } + + network := "tcp" + if xnet.IsIPv4(l.options.Addr) { + network = "tcp4" + } + + lc := net.ListenConfig{} + if l.md.mptcp { + lc.SetMultipathTCP(true) + l.logger.Debugf("mptcp enabled: %v", lc.MultipathTCP()) + } + ln, err := lc.Listen(context.Background(), network, l.options.Addr) + if err != nil { + return + } + + l.logger.Debugf("pp: %d", l.options.ProxyProtocol) + + ln = proxyproto.WrapListener(l.options.ProxyProtocol, ln, 10*time.Second) + ln = metrics.WrapListener(l.options.Service, ln) + ln = admission.WrapListener(l.options.Admission, ln) + ln = limiter.WrapListener(l.options.TrafficLimiter, ln) + ln = climiter.WrapListener(l.options.ConnLimiter, ln) + l.ln = ln + + l.cqueue = make(chan net.Conn, l.md.backlog) + l.errChan = make(chan error, 1) + + go l.listenLoop() + + return +} + +func (l *mtcpListener) Addr() net.Addr { + return l.ln.Addr() +} + +func (l *mtcpListener) Close() error { + return l.ln.Close() +} + +func (l *mtcpListener) Accept() (conn net.Conn, err error) { + var ok bool + select { + case conn = <-l.cqueue: + case err, ok = <-l.errChan: + if !ok { + err = listener.ErrClosed + } + } + return +} + +func (l *mtcpListener) listenLoop() { + for { + conn, err := l.ln.Accept() + if err != nil { + l.errChan <- err + close(l.errChan) + return + } + go l.mux(conn) + } +} + +func (l *mtcpListener) mux(conn net.Conn) { + defer conn.Close() + + session, err := mux.ServerSession(conn, l.md.muxCfg) + if err != nil { + l.logger.Error(err) + return + } + defer session.Close() + + for { + stream, err := session.Accept() + if err != nil { + l.logger.Error("accept stream: ", err) + return + } + + select { + case l.cqueue <- stream: + default: + stream.Close() + l.logger.Warnf("connection queue is full, client %s discarded", stream.RemoteAddr()) + } + } +} diff --git a/listener/mtcp/metadata.go b/listener/mtcp/metadata.go new file mode 100644 index 0000000..b93221e --- /dev/null +++ b/listener/mtcp/metadata.go @@ -0,0 +1,37 @@ +package mtcp + +import ( + md "github.com/go-gost/core/metadata" + mdutil "github.com/go-gost/core/metadata/util" + "github.com/go-gost/x/internal/util/mux" +) + +const ( + defaultBacklog = 128 +) + +type metadata struct { + mptcp bool + muxCfg *mux.Config + backlog int +} + +func (l *mtcpListener) parseMetadata(md md.Metadata) (err error) { + l.md.mptcp = mdutil.GetBool(md, "mptcp") + + l.md.muxCfg = &mux.Config{ + Version: mdutil.GetInt(md, "mux.version"), + KeepAliveInterval: mdutil.GetDuration(md, "mux.keepaliveInterval"), + KeepAliveDisabled: mdutil.GetBool(md, "mux.keepaliveDisabled"), + KeepAliveTimeout: mdutil.GetDuration(md, "mux.keepaliveTimeout"), + MaxFrameSize: mdutil.GetInt(md, "mux.maxFrameSize"), + MaxReceiveBuffer: mdutil.GetInt(md, "mux.maxReceiveBuffer"), + MaxStreamBuffer: mdutil.GetInt(md, "mux.maxStreamBuffer"), + } + + l.md.backlog = mdutil.GetInt(md, "backlog") + if l.md.backlog <= 0 { + l.md.backlog = defaultBacklog + } + return +}