diff --git a/admission/admission.go b/admission/admission.go index 9088214..35d5a80 100644 --- a/admission/admission.go +++ b/admission/admission.go @@ -13,7 +13,6 @@ import ( "github.com/go-gost/core/logger" "github.com/go-gost/x/internal/loader" "github.com/go-gost/x/internal/matcher" - "google.golang.org/grpc" ) type options struct { @@ -22,7 +21,6 @@ type options struct { fileLoader loader.Loader redisLoader loader.Loader httpLoader loader.Loader - client *grpc.ClientConn period time.Duration logger logger.Logger } @@ -65,12 +63,6 @@ func HTTPLoaderOption(httpLoader loader.Loader) Option { } } -func PluginConnOption(c *grpc.ClientConn) Option { - return func(opts *options) { - opts.client = c - } -} - func LoggerOption(logger logger.Logger) Option { return func(opts *options) { opts.logger = logger diff --git a/admission/plugin.go b/admission/plugin.go index e99d492..81540ee 100644 --- a/admission/plugin.go +++ b/admission/plugin.go @@ -2,37 +2,36 @@ package admission import ( "context" + "io" admission_pkg "github.com/go-gost/core/admission" + "github.com/go-gost/core/logger" "github.com/go-gost/plugin/admission/proto" - xlogger "github.com/go-gost/x/logger" + "google.golang.org/grpc" ) -type pluginAdmission struct { - client proto.AdmissionClient - options options +type grpcPluginAdmission struct { + conn grpc.ClientConnInterface + client proto.AdmissionClient + log logger.Logger } -// NewPluginAdmission creates a plugin admission. -func NewPluginAdmission(opts ...Option) admission_pkg.Admission { - var options options - for _, opt := range opts { - opt(&options) +// NewGRPCPluginAdmission creates an Admission plugin based on gRPC. +func NewGRPCPluginAdmission(name string, conn grpc.ClientConnInterface) admission_pkg.Admission { + p := &grpcPluginAdmission{ + conn: conn, + log: logger.Default().WithFields(map[string]any{ + "kind": "admission", + "admission": name, + }), } - if options.logger == nil { - options.logger = xlogger.Nop() - } - - p := &pluginAdmission{ - options: options, - } - if options.client != nil { - p.client = proto.NewAdmissionClient(options.client) + if conn != nil { + p.client = proto.NewAdmissionClient(conn) } return p } -func (p *pluginAdmission) Admit(ctx context.Context, addr string) bool { +func (p *grpcPluginAdmission) Admit(ctx context.Context, addr string) bool { if p.client == nil { return false } @@ -42,15 +41,15 @@ func (p *pluginAdmission) Admit(ctx context.Context, addr string) bool { Addr: addr, }) if err != nil { - p.options.logger.Error(err) + p.log.Error(err) return false } return r.Ok } -func (p *pluginAdmission) Close() error { - if p.options.client != nil { - return p.options.client.Close() +func (p *grpcPluginAdmission) Close() error { + if closer, ok := p.conn.(io.Closer); ok { + return closer.Close() } return nil } diff --git a/api/middleware.go b/api/middleware.go index 924a614..5b9006a 100644 --- a/api/middleware.go +++ b/api/middleware.go @@ -35,7 +35,7 @@ func mwBasicAuth(auther auth.Authenticator) gin.HandlerFunc { return } u, p, _ := c.Request.BasicAuth() - if !auther.Authenticate(c, u, p) { + if ok, _ := auther.Authenticate(c, u, p); !ok { c.AbortWithStatus(http.StatusUnauthorized) } } diff --git a/auth/auth.go b/auth/auth.go index d6ba5d6..cfc9e4b 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -12,7 +12,6 @@ import ( "github.com/go-gost/core/logger" "github.com/go-gost/x/internal/loader" xlogger "github.com/go-gost/x/logger" - "google.golang.org/grpc" ) type options struct { @@ -21,7 +20,6 @@ type options struct { redisLoader loader.Loader httpLoader loader.Loader period time.Duration - client *grpc.ClientConn logger logger.Logger } @@ -57,12 +55,6 @@ func HTTPLoaderOption(httpLoader loader.Loader) Option { } } -func PluginConnOption(c *grpc.ClientConn) Option { - return func(opts *options) { - opts.client = c - } -} - func LoggerOption(logger logger.Logger) Option { return func(opts *options) { opts.logger = logger @@ -105,20 +97,20 @@ func NewAuthenticator(opts ...Option) auth.Authenticator { } // Authenticate checks the validity of the provided user-password pair. -func (p *authenticator) Authenticate(ctx context.Context, user, password string) bool { +func (p *authenticator) Authenticate(ctx context.Context, user, password string) (bool, string) { if p == nil { - return true + return true, "" } p.mu.RLock() defer p.mu.RUnlock() if len(p.kvs) == 0 { - return true + return false, "" } v, ok := p.kvs[user] - return ok && (v == "" || password == v) + return ok && (v == "" || password == v), "" } func (p *authenticator) periodReload(ctx context.Context) error { diff --git a/auth/plugin.go b/auth/plugin.go index 77070cd..66d0f52 100644 --- a/auth/plugin.go +++ b/auth/plugin.go @@ -2,40 +2,40 @@ package auth import ( "context" + "io" "github.com/go-gost/core/auth" + "github.com/go-gost/core/logger" "github.com/go-gost/plugin/auth/proto" - xlogger "github.com/go-gost/x/logger" + "google.golang.org/grpc" ) -type pluginAuthenticator struct { - client proto.AuthenticatorClient - options options +type grpcPluginAuthenticator struct { + conn grpc.ClientConnInterface + client proto.AuthenticatorClient + log logger.Logger } -// NewPluginAuthenticator creates an Authenticator that authenticates client by plugin. -func NewPluginAuthenticator(opts ...Option) auth.Authenticator { - var options options - for _, opt := range opts { - opt(&options) - } - if options.logger == nil { - options.logger = xlogger.Nop() +// NewGRPCPluginAuthenticator creates an Authenticator plugin based on gRPC. +func NewGRPCPluginAuthenticator(name string, conn grpc.ClientConnInterface) auth.Authenticator { + p := &grpcPluginAuthenticator{ + conn: conn, + log: logger.Default().WithFields(map[string]any{ + "kind": "auther", + "auther": name, + }), } - p := &pluginAuthenticator{ - options: options, - } - if options.client != nil { - p.client = proto.NewAuthenticatorClient(options.client) + if conn != nil { + p.client = proto.NewAuthenticatorClient(conn) } return p } // Authenticate checks the validity of the provided user-password pair. -func (p *pluginAuthenticator) Authenticate(ctx context.Context, user, password string) bool { +func (p *grpcPluginAuthenticator) Authenticate(ctx context.Context, user, password string) (bool, string) { if p.client == nil { - return false + return false, "" } r, err := p.client.Authenticate(ctx, @@ -44,15 +44,15 @@ func (p *pluginAuthenticator) Authenticate(ctx context.Context, user, password s Password: password, }) if err != nil { - p.options.logger.Error(err) - return false + p.log.Error(err) + return false, "" } - return r.Ok + return r.Ok, r.Id } -func (p *pluginAuthenticator) Close() error { - if p.options.client != nil { - return p.options.client.Close() +func (p *grpcPluginAuthenticator) Close() error { + if closer, ok := p.conn.(io.Closer); ok { + return closer.Close() } return nil } diff --git a/bypass/plugin.go b/bypass/plugin.go index 8393410..dc80934 100644 --- a/bypass/plugin.go +++ b/bypass/plugin.go @@ -2,55 +2,56 @@ package bypass import ( "context" + "io" bypass_pkg "github.com/go-gost/core/bypass" + "github.com/go-gost/core/logger" "github.com/go-gost/plugin/bypass/proto" - xlogger "github.com/go-gost/x/logger" + auth_util "github.com/go-gost/x/internal/util/auth" + "google.golang.org/grpc" ) -type pluginBypass struct { - client proto.BypassClient - options options +type grpcPluginBypass struct { + conn grpc.ClientConnInterface + client proto.BypassClient + log logger.Logger } -// NewPluginBypass creates a plugin bypass. -func NewPluginBypass(opts ...Option) bypass_pkg.Bypass { - var options options - for _, opt := range opts { - opt(&options) +// NewGRPCPluginBypass creates a Bypass plugin based on gRPC. +func NewGRPCPluginBypass(name string, conn grpc.ClientConnInterface) bypass_pkg.Bypass { + p := &grpcPluginBypass{ + conn: conn, + log: logger.Default().WithFields(map[string]any{ + "kind": "bypass", + "bypass": name, + }), } - if options.logger == nil { - options.logger = xlogger.Nop() - } - - p := &pluginBypass{ - options: options, - } - if options.client != nil { - p.client = proto.NewBypassClient(options.client) + if conn != nil { + p.client = proto.NewBypassClient(conn) } return p } -func (p *pluginBypass) Contains(ctx context.Context, addr string) bool { +func (p *grpcPluginBypass) Contains(ctx context.Context, addr string) bool { if p.client == nil { - return false + return true } r, err := p.client.Bypass(ctx, &proto.BypassRequest{ - Addr: addr, + Addr: addr, + Client: string(auth_util.IDFromContext(ctx)), }) if err != nil { - p.options.logger.Error(err) - return false + p.log.Error(err) + return true } return r.Ok } -func (p *pluginBypass) Close() error { - if p.options.client != nil { - return p.options.client.Close() +func (p *grpcPluginBypass) Close() error { + if closer, ok := p.conn.(io.Closer); ok { + return closer.Close() } return nil } diff --git a/config/parsing/parse.go b/config/parsing/parse.go index 22fcd2a..ab09845 100644 --- a/config/parsing/parse.go +++ b/config/parsing/parse.go @@ -65,13 +65,7 @@ func ParseAuther(cfg *config.AutherConfig) auth.Authenticator { if err != nil { logger.Default().Error(err) } - return auth_impl.NewPluginAuthenticator( - auth_impl.PluginConnOption(c), - auth_impl.LoggerOption(logger.Default().WithFields(map[string]any{ - "kind": "auther", - "auther": cfg.Name, - })), - ) + return auth_impl.NewGRPCPluginAuthenticator(cfg.Name, c) } m := make(map[string]string) @@ -199,13 +193,7 @@ func ParseAdmission(cfg *config.AdmissionConfig) admission.Admission { if err != nil { logger.Default().Error(err) } - return admission_impl.NewPluginAdmission( - admission_impl.PluginConnOption(c), - admission_impl.LoggerOption(logger.Default().WithFields(map[string]any{ - "kind": "admission", - "admission": cfg.Name, - })), - ) + return admission_impl.NewGRPCPluginAdmission(cfg.Name, c) } opts := []admission_impl.Option{ @@ -248,13 +236,7 @@ func ParseBypass(cfg *config.BypassConfig) bypass.Bypass { if err != nil { logger.Default().Error(err) } - return bypass_impl.NewPluginBypass( - bypass_impl.PluginConnOption(c), - bypass_impl.LoggerOption(logger.Default().WithFields(map[string]any{ - "kind": "bypass", - "bypass": cfg.Name, - })), - ) + return bypass_impl.NewGRPCPluginBypass(cfg.Name, c) } opts := []bypass_impl.Option{ @@ -298,13 +280,7 @@ func ParseResolver(cfg *config.ResolverConfig) (resolver.Resolver, error) { logger.Default().Error(err) return nil, err } - return resolver_impl.NewPluginResolver( - resolver_impl.PluginConnOption(c), - resolver_impl.LoggerOption(logger.Default().WithFields(map[string]any{ - "kind": "resolver", - "resolver": cfg.Name, - })), - ) + return resolver_impl.NewGRPCPluginResolver(cfg.Name, c) } var nameservers []resolver_impl.NameServer @@ -341,13 +317,7 @@ func ParseHosts(cfg *config.HostsConfig) hosts.HostMapper { if err != nil { logger.Default().Error(err) } - return xhosts.NewPluginHostMapper( - xhosts.PluginConnOption(c), - xhosts.LoggerOption(logger.Default().WithFields(map[string]any{ - "kind": "hosts", - "hosts": cfg.Name, - })), - ) + return xhosts.NewGRPCPluginHostMapper(cfg.Name, c) } var mappings []xhosts.Mapping @@ -413,13 +383,7 @@ func ParseIngress(cfg *config.IngressConfig) ingress.Ingress { if err != nil { logger.Default().Error(err) } - return xingress.NewPluginIngress( - xingress.PluginConnOption(c), - xingress.LoggerOption(logger.Default().WithFields(map[string]any{ - "kind": "ingress", - "ingress": cfg.Name, - })), - ) + return xingress.NewGRPCPluginIngress(cfg.Name, c) } var rules []xingress.Rule @@ -481,13 +445,7 @@ func ParseRecorder(cfg *config.RecorderConfig) (r recorder.Recorder) { if err != nil { logger.Default().Error(err) } - return xrecorder.NewPluginRecorder( - xrecorder.PluginConnOption(c), - xrecorder.LoggerOption(logger.Default().WithFields(map[string]any{ - "kind": "recorder", - "recorder": cfg.Name, - })), - ) + return xrecorder.NewGRPCPluginRecorder(cfg.Name, c) } if cfg.File != nil && cfg.File.Path != "" { diff --git a/connector/socks/v5/selector.go b/connector/socks/v5/selector.go index af0d96e..b7076ad 100644 --- a/connector/socks/v5/selector.go +++ b/connector/socks/v5/selector.go @@ -30,12 +30,13 @@ func (s *clientSelector) Select(methods ...uint8) (method uint8) { return } -func (s *clientSelector) OnSelected(method uint8, conn net.Conn) (net.Conn, error) { +func (s *clientSelector) OnSelected(method uint8, conn net.Conn) (string, net.Conn, error) { s.logger.Debug("method selected: ", method) switch method { case socks.MethodTLS: conn = tls.Client(conn, s.TLSConfig) + return "", conn, nil case gosocks5.MethodUserPass, socks.MethodTLSAuth: if method == socks.MethodTLSAuth { @@ -52,22 +53,25 @@ func (s *clientSelector) OnSelected(method uint8, conn net.Conn) (net.Conn, erro s.logger.Trace(req) if err := req.Write(conn); err != nil { s.logger.Error(err) - return nil, err + return "", nil, err } resp, err := gosocks5.ReadUserPassResponse(conn) if err != nil { s.logger.Error(err) - return nil, err + return "", nil, err } s.logger.Trace(resp) if resp.Status != gosocks5.Succeeded { - return nil, gosocks5.ErrAuthFailure + return "", nil, gosocks5.ErrAuthFailure } + return "", conn, nil + case gosocks5.MethodNoAcceptable: - return nil, gosocks5.ErrBadMethod + return "", nil, gosocks5.ErrBadMethod + default: + return "", nil, gosocks5.ErrBadFormat } - return conn, nil } diff --git a/go.mod b/go.mod index 2033857..e98240d 100644 --- a/go.mod +++ b/go.mod @@ -7,9 +7,9 @@ require ( github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d github.com/gin-contrib/cors v1.3.1 github.com/gin-gonic/gin v1.9.1 - github.com/go-gost/core v0.0.0-20230918131208-c258a114c40b + github.com/go-gost/core v0.0.0-20230919141921-a1419ec2f4d1 github.com/go-gost/gosocks4 v0.0.1 - github.com/go-gost/gosocks5 v0.3.1-0.20211109033403-d894d75b7f09 + github.com/go-gost/gosocks5 v0.4.0 github.com/go-gost/plugin v0.0.0-20230418123101-d221a4ec9a98 github.com/go-gost/relay v0.4.1-0.20230916134211-828f314ddfe7 github.com/go-gost/tls-dissector v0.0.2-0.20220408131628-aac992c27451 @@ -86,7 +86,7 @@ require ( github.com/pion/transport/v2 v2.0.2 // indirect github.com/pion/udp/v2 v2.0.1 // indirect github.com/pkg/errors v0.9.1 // indirect - github.com/prometheus/client_model v0.3.0 // indirect + github.com/prometheus/client_model v0.4.0 // indirect github.com/prometheus/common v0.37.0 // indirect github.com/prometheus/procfs v0.8.0 // indirect github.com/quic-go/qpack v0.4.0 // indirect diff --git a/go.sum b/go.sum index 4a576ba..049efb8 100644 --- a/go.sum +++ b/go.sum @@ -101,10 +101,12 @@ github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2 github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-gost/core v0.0.0-20230918131208-c258a114c40b h1:kqALaNXbbYyKFlcLj3ODsuvzplRxypnJOhMINSiM8sk= github.com/go-gost/core v0.0.0-20230918131208-c258a114c40b/go.mod h1:db6lBY+DkC3ct4OJfclsKnQwQmcv1B9NnMnpI2MNUwY= +github.com/go-gost/core v0.0.0-20230919141921-a1419ec2f4d1 h1:tV5Ra3bBU5R9Mcg9lGzMQeVeLcnCFEEiE6UNnx6F46k= +github.com/go-gost/core v0.0.0-20230919141921-a1419ec2f4d1/go.mod h1:ndkgWVYRLwupVaFFWv8ML1Nr8tD3xhHK245PLpUDg4E= github.com/go-gost/gosocks4 v0.0.1 h1:+k1sec8HlELuQV7rWftIkmy8UijzUt2I6t+iMPlGB2s= github.com/go-gost/gosocks4 v0.0.1/go.mod h1:3B6L47HbU/qugDg4JnoFPHgJXE43Inz8Bah1QaN9qCc= -github.com/go-gost/gosocks5 v0.3.1-0.20211109033403-d894d75b7f09 h1:A95M6UWcfZgOuJkQ7QLfG0Hs5peWIUSysCDNz4pfe04= -github.com/go-gost/gosocks5 v0.3.1-0.20211109033403-d894d75b7f09/go.mod h1:1G6I7HP7VFVxveGkoK8mnprnJqSqJjdcASKsdUn4Pp4= +github.com/go-gost/gosocks5 v0.4.0 h1:EIrOEkpJez4gwHrMa33frA+hHXJyevjp47thpMQsJzI= +github.com/go-gost/gosocks5 v0.4.0/go.mod h1:1G6I7HP7VFVxveGkoK8mnprnJqSqJjdcASKsdUn4Pp4= github.com/go-gost/plugin v0.0.0-20230418123101-d221a4ec9a98 h1:dOtNcxZbMDwtowa8b91nK2JcTL1lG0EIv0sXqSbvTc4= github.com/go-gost/plugin v0.0.0-20230418123101-d221a4ec9a98/go.mod h1:IGQawP0E9B36VZ0AfDOpBK23bW4rOSiHtnU7mtafpAM= github.com/go-gost/relay v0.4.1-0.20230916134211-828f314ddfe7 h1:qAG1OyjvdA5h221CfFSS3J359V3d2E7dJWyP29QoDSI= @@ -311,8 +313,8 @@ github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1: github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= -github.com/prometheus/client_model v0.3.0 h1:UBgGFHqYdG/TPFD1B1ogZywDqEkwp3fBMvqdiQ7Xew4= -github.com/prometheus/client_model v0.3.0/go.mod h1:LDGWKZIo7rky3hgvBe+caln+Dr3dPggB5dvjtD7w9+w= +github.com/prometheus/client_model v0.4.0 h1:5lQXD3cAg1OXBf4Wq03gTrXHeaV0TQvGfUooCfx1yqY= +github.com/prometheus/client_model v0.4.0/go.mod h1:oMQmHW1/JoDwqLtg57MGgP/Fb1CJEYF2imWWhWtMkYU= github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= github.com/prometheus/common v0.10.0/go.mod h1:Tlit/dnDKsSWFlCLTWaA1cyBgKHSMdTB80sz/V91rCo= github.com/prometheus/common v0.26.0/go.mod h1:M7rCNAaPfAosfx8veZJCuw84e35h3Cfd9VFqTh1DIvc= diff --git a/handler/forward/local/handler.go b/handler/forward/local/handler.go index 08b2f47..445c129 100644 --- a/handler/forward/local/handler.go +++ b/handler/forward/local/handler.go @@ -18,6 +18,7 @@ import ( "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" xnet "github.com/go-gost/x/internal/net" + auth_util "github.com/go-gost/x/internal/util/auth" "github.com/go-gost/x/internal/util/forward" "github.com/go-gost/x/registry" ) @@ -208,12 +209,14 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, log l if auther := target.Options().Auther; auther != nil { username, password, _ := req.BasicAuth() - if !auther.Authenticate(ctx, username, password) { + ok, id := auther.Authenticate(ctx, username, password) + if !ok { resp.StatusCode = http.StatusUnauthorized resp.Header.Set("WWW-Authenticate", "Basic") log.Warnf("node %s(%s) 401 unauthorized", target.Name, target.Addr) return resp.Write(rw) } + ctx = auth_util.ContextWithID(ctx, auth_util.ID(id)) } var cc net.Conn diff --git a/handler/forward/remote/handler.go b/handler/forward/remote/handler.go index d9297a4..a6bce36 100644 --- a/handler/forward/remote/handler.go +++ b/handler/forward/remote/handler.go @@ -19,6 +19,7 @@ import ( mdata "github.com/go-gost/core/metadata" mdutil "github.com/go-gost/core/metadata/util" xnet "github.com/go-gost/x/internal/net" + auth_util "github.com/go-gost/x/internal/util/auth" "github.com/go-gost/x/internal/util/forward" "github.com/go-gost/x/registry" ) @@ -205,12 +206,14 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, log l if auther := target.Options().Auther; auther != nil { username, password, _ := req.BasicAuth() - if !auther.Authenticate(ctx, username, password) { + ok, id := auther.Authenticate(ctx, username, password) + if !ok { resp.StatusCode = http.StatusUnauthorized resp.Header.Set("WWW-Authenticate", "Basic") log.Warnf("node %s(%s) 401 unauthorized", target.Name, target.Addr) return resp.Write(rw) } + ctx = auth_util.ContextWithID(ctx, auth_util.ID(id)) } var cc net.Conn diff --git a/handler/http/handler.go b/handler/http/handler.go index 9980dbc..28a4ab7 100644 --- a/handler/http/handler.go +++ b/handler/http/handler.go @@ -22,6 +22,7 @@ import ( "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" netpkg "github.com/go-gost/x/internal/net" + auth_util "github.com/go-gost/x/internal/util/auth" sx "github.com/go-gost/x/internal/util/selector" "github.com/go-gost/x/registry" ) @@ -145,6 +146,12 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt resp.Header = http.Header{} } + ok, id := h.authenticate(ctx, conn, req, resp, log) + if !ok { + return nil + } + ctx = auth_util.ContextWithID(ctx, auth_util.ID(id)) + if h.options.Bypass != nil && h.options.Bypass.Contains(ctx, addr) { resp.StatusCode = http.StatusForbidden @@ -157,10 +164,6 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt return resp.Write(conn) } - if !h.authenticate(ctx, conn, req, resp, log) { - return nil - } - if network == "udp" { return h.handleUDP(ctx, conn, log) } @@ -266,10 +269,13 @@ func (h *httpHandler) basicProxyAuth(proxyAuth string, log logger.Logger) (usern return cs[:s], cs[s+1:], true } -func (h *httpHandler) authenticate(ctx context.Context, conn net.Conn, req *http.Request, resp *http.Response, log logger.Logger) (ok bool) { +func (h *httpHandler) authenticate(ctx context.Context, conn net.Conn, req *http.Request, resp *http.Response, log logger.Logger) (ok bool, token string) { u, p, _ := h.basicProxyAuth(req.Header.Get("Proxy-Authorization"), log) - if h.options.Auther == nil || h.options.Auther.Authenticate(ctx, u, p) { - return true + if h.options.Auther == nil { + return true, "" + } + if ok, token = h.options.Auther.Authenticate(ctx, u, p); ok { + return } pr := h.md.probeResistance diff --git a/handler/http2/handler.go b/handler/http2/handler.go index 960067c..9aaf4bf 100644 --- a/handler/http2/handler.go +++ b/handler/http2/handler.go @@ -24,6 +24,7 @@ import ( md "github.com/go-gost/core/metadata" xio "github.com/go-gost/x/internal/io" netpkg "github.com/go-gost/x/internal/net" + auth_util "github.com/go-gost/x/internal/util/auth" sx "github.com/go-gost/x/internal/util/selector" "github.com/go-gost/x/registry" ) @@ -138,12 +139,6 @@ func (h *http2Handler) roundTrip(ctx context.Context, w http.ResponseWriter, req w.Header().Set(k, h.md.header.Get(k)) } - if h.options.Bypass != nil && h.options.Bypass.Contains(ctx, addr) { - w.WriteHeader(http.StatusForbidden) - log.Debug("bypass: ", addr) - return nil - } - resp := &http.Response{ ProtoMajor: 2, ProtoMinor: 0, @@ -151,7 +146,15 @@ func (h *http2Handler) roundTrip(ctx context.Context, w http.ResponseWriter, req Body: io.NopCloser(bytes.NewReader([]byte{})), } - if !h.authenticate(ctx, w, req, resp, log) { + ok, id := h.authenticate(ctx, w, req, resp, log) + if !ok { + return nil + } + ctx = auth_util.ContextWithID(ctx, auth_util.ID(id)) + + if h.options.Bypass != nil && h.options.Bypass.Contains(ctx, addr) { + w.WriteHeader(http.StatusForbidden) + log.Debug("bypass: ", addr) return nil } @@ -251,10 +254,13 @@ func (h *http2Handler) basicProxyAuth(proxyAuth string) (username, password stri return cs[:s], cs[s+1:], true } -func (h *http2Handler) authenticate(ctx context.Context, w http.ResponseWriter, r *http.Request, resp *http.Response, log logger.Logger) (ok bool) { +func (h *http2Handler) authenticate(ctx context.Context, w http.ResponseWriter, r *http.Request, resp *http.Response, log logger.Logger) (ok bool, token string) { u, p, _ := h.basicProxyAuth(r.Header.Get("Proxy-Authorization")) - if h.options.Auther == nil || h.options.Auther.Authenticate(ctx, u, p) { - return true + if h.options.Auther == nil { + return true, "" + } + if ok, token = h.options.Auther.Authenticate(ctx, u, p); ok { + return } pr := h.md.probeResistance diff --git a/handler/relay/handler.go b/handler/relay/handler.go index 6bb596d..c4946e7 100644 --- a/handler/relay/handler.go +++ b/handler/relay/handler.go @@ -15,6 +15,7 @@ import ( "github.com/go-gost/core/service" "github.com/go-gost/relay" xnet "github.com/go-gost/x/internal/net" + auth_util "github.com/go-gost/x/internal/util/auth" "github.com/go-gost/x/registry" xservice "github.com/go-gost/x/service" ) @@ -200,11 +201,14 @@ func (h *relayHandler) Handle(ctx context.Context, conn net.Conn, opts ...handle log = log.WithFields(map[string]any{"user": user}) } - if h.options.Auther != nil && - !h.options.Auther.Authenticate(ctx, user, pass) { - resp.Status = relay.StatusUnauthorized - resp.WriteTo(conn) - return ErrUnauthorized + if h.options.Auther != nil { + ok, id := h.options.Auther.Authenticate(ctx, user, pass) + if !ok { + resp.Status = relay.StatusUnauthorized + resp.WriteTo(conn) + return ErrUnauthorized + } + ctx = auth_util.ContextWithID(ctx, auth_util.ID(id)) } network := networkID.String() diff --git a/handler/socks/v4/handler.go b/handler/socks/v4/handler.go index 42412d3..0163920 100644 --- a/handler/socks/v4/handler.go +++ b/handler/socks/v4/handler.go @@ -12,6 +12,7 @@ import ( md "github.com/go-gost/core/metadata" "github.com/go-gost/gosocks4" netpkg "github.com/go-gost/x/internal/net" + auth_util "github.com/go-gost/x/internal/util/auth" sx "github.com/go-gost/x/internal/util/selector" "github.com/go-gost/x/registry" ) @@ -90,11 +91,14 @@ func (h *socks4Handler) Handle(ctx context.Context, conn net.Conn, opts ...handl conn.SetReadDeadline(time.Time{}) - if h.options.Auther != nil && - !h.options.Auther.Authenticate(ctx, string(req.Userid), "") { - resp := gosocks4.NewReply(gosocks4.RejectedUserid, nil) - log.Trace(resp) - return resp.Write(conn) + if h.options.Auther != nil { + ok, id := h.options.Auther.Authenticate(ctx, string(req.Userid), "") + if !ok { + resp := gosocks4.NewReply(gosocks4.RejectedUserid, nil) + log.Trace(resp) + return resp.Write(conn) + } + ctx = auth_util.ContextWithID(ctx, auth_util.ID(id)) } switch req.Cmd { diff --git a/handler/socks/v5/handler.go b/handler/socks/v5/handler.go index a6597c8..690ad24 100644 --- a/handler/socks/v5/handler.go +++ b/handler/socks/v5/handler.go @@ -10,6 +10,7 @@ import ( "github.com/go-gost/core/handler" md "github.com/go-gost/core/metadata" "github.com/go-gost/gosocks5" + auth_util "github.com/go-gost/x/internal/util/auth" "github.com/go-gost/x/internal/util/socks" "github.com/go-gost/x/registry" ) @@ -86,13 +87,17 @@ func (h *socks5Handler) Handle(ctx context.Context, conn net.Conn, opts ...handl conn.SetReadDeadline(time.Now().Add(h.md.readTimeout)) } - conn = gosocks5.ServerConn(conn, h.selector) - req, err := gosocks5.ReadRequest(conn) + sc := gosocks5.ServerConn(conn, h.selector) + req, err := gosocks5.ReadRequest(sc) if err != nil { log.Error(err) return err } log.Trace(req) + + ctx = auth_util.ContextWithID(ctx, auth_util.ID(sc.ID())) + + conn = sc conn.SetReadDeadline(time.Time{}) address := req.Addr.String() diff --git a/handler/socks/v5/selector.go b/handler/socks/v5/selector.go index 9949e83..8f6b271 100644 --- a/handler/socks/v5/selector.go +++ b/handler/socks/v5/selector.go @@ -46,11 +46,12 @@ func (s *serverSelector) Select(methods ...uint8) (method uint8) { return } -func (s *serverSelector) OnSelected(method uint8, conn net.Conn) (net.Conn, error) { +func (s *serverSelector) OnSelected(method uint8, conn net.Conn) (string, net.Conn, error) { s.logger.Debugf("%d %d", gosocks5.Ver5, method) switch method { case socks.MethodTLS: conn = tls.Server(conn, s.TLSConfig) + return "", conn, nil case gosocks5.MethodUserPass, socks.MethodTLSAuth: if method == socks.MethodTLSAuth { @@ -60,32 +61,37 @@ func (s *serverSelector) OnSelected(method uint8, conn net.Conn) (net.Conn, erro req, err := gosocks5.ReadUserPassRequest(conn) if err != nil { s.logger.Error(err) - return nil, err + return "", nil, err } s.logger.Trace(req) - if s.Authenticator != nil && - !s.Authenticator.Authenticate(context.Background(), req.Username, req.Password) { - resp := gosocks5.NewUserPassResponse(gosocks5.UserPassVer, gosocks5.Failure) - if err := resp.Write(conn); err != nil { - s.logger.Error(err) - return nil, err - } - s.logger.Info(resp) + var id string + if s.Authenticator != nil { + var ok bool + ok, id = s.Authenticator.Authenticate(context.Background(), req.Username, req.Password) + if !ok { + resp := gosocks5.NewUserPassResponse(gosocks5.UserPassVer, gosocks5.Failure) + if err := resp.Write(conn); err != nil { + s.logger.Error(err) + return "", nil, err + } + s.logger.Info(resp) - return nil, gosocks5.ErrAuthFailure + return "", nil, gosocks5.ErrAuthFailure + } } resp := gosocks5.NewUserPassResponse(gosocks5.UserPassVer, gosocks5.Succeeded) s.logger.Trace(resp) if err := resp.Write(conn); err != nil { s.logger.Error(err) - return nil, err + return "", nil, err } + return id, conn, nil case gosocks5.MethodNoAcceptable: - return nil, gosocks5.ErrBadMethod + return "", nil, gosocks5.ErrBadMethod + default: + return "", nil, gosocks5.ErrBadFormat } - - return conn, nil } diff --git a/handler/tun/server.go b/handler/tun/server.go index db22b16..caa22b4 100644 --- a/handler/tun/server.go +++ b/handler/tun/server.go @@ -135,7 +135,7 @@ func (h *tunHandler) transportServer(ctx context.Context, tun io.ReadWriter, con ok := true key := bytes.TrimRight((*b)[4:20], "\x00") for _, ip := range peerIPs { - if ok = auther.Authenticate(ctx, ip.String(), string(key)); !ok { + if ok, _ = auther.Authenticate(ctx, ip.String(), string(key)); !ok { break } } diff --git a/hosts/plugin.go b/hosts/plugin.go index 56f2e21..e414652 100644 --- a/hosts/plugin.go +++ b/hosts/plugin.go @@ -2,39 +2,39 @@ package hosts import ( "context" + "io" "net" "github.com/go-gost/core/hosts" + "github.com/go-gost/core/logger" "github.com/go-gost/plugin/hosts/proto" - xlogger "github.com/go-gost/x/logger" + auth_util "github.com/go-gost/x/internal/util/auth" + "google.golang.org/grpc" ) -type pluginHostMapper struct { - client proto.HostMapperClient - options options +type grpcPluginHostMapper struct { + conn grpc.ClientConnInterface + client proto.HostMapperClient + log logger.Logger } -// NewPluginHostMapper creates a plugin HostMapper. -func NewPluginHostMapper(opts ...Option) hosts.HostMapper { - var options options - for _, opt := range opts { - opt(&options) +// NewGRPCPluginHostMapper creates a HostMapper plugin based on gRPC. +func NewGRPCPluginHostMapper(name string, conn grpc.ClientConnInterface) hosts.HostMapper { + p := &grpcPluginHostMapper{ + conn: conn, + log: logger.Default().WithFields(map[string]any{ + "kind": "hosts", + "hosts": name, + }), } - if options.logger == nil { - options.logger = xlogger.Nop() - } - - p := &pluginHostMapper{ - options: options, - } - if options.client != nil { - p.client = proto.NewHostMapperClient(options.client) + if conn != nil { + p.client = proto.NewHostMapperClient(conn) } return p } -func (p *pluginHostMapper) Lookup(ctx context.Context, network, host string) (ips []net.IP, ok bool) { - p.options.logger.Debugf("lookup %s/%s", host, network) +func (p *grpcPluginHostMapper) Lookup(ctx context.Context, network, host string) (ips []net.IP, ok bool) { + p.log.Debugf("lookup %s/%s", host, network) if p.client == nil { return @@ -44,9 +44,10 @@ func (p *pluginHostMapper) Lookup(ctx context.Context, network, host string) (ip &proto.LookupRequest{ Network: network, Host: host, + Client: string(auth_util.IDFromContext(ctx)), }) if err != nil { - p.options.logger.Error(err) + p.log.Error(err) return } for _, s := range r.Ips { @@ -58,9 +59,9 @@ func (p *pluginHostMapper) Lookup(ctx context.Context, network, host string) (ip return } -func (p *pluginHostMapper) Close() error { - if p.options.client != nil { - return p.options.client.Close() +func (p *grpcPluginHostMapper) Close() error { + if closer, ok := p.conn.(io.Closer); ok { + return closer.Close() } return nil } diff --git a/ingress/plugin.go b/ingress/plugin.go index 4f86166..6c1d125 100644 --- a/ingress/plugin.go +++ b/ingress/plugin.go @@ -2,37 +2,36 @@ package ingress import ( "context" + "io" ingress_pkg "github.com/go-gost/core/ingress" + "github.com/go-gost/core/logger" "github.com/go-gost/plugin/ingress/proto" - xlogger "github.com/go-gost/x/logger" + "google.golang.org/grpc" ) -type pluginIngress struct { - client proto.IngressClient - options options +type grpcPluginIngress struct { + conn grpc.ClientConnInterface + client proto.IngressClient + log logger.Logger } -// NewPluginIngress creates a plugin ingress. -func NewPluginIngress(opts ...Option) ingress_pkg.Ingress { - var options options - for _, opt := range opts { - opt(&options) +// NewGRPCPluginIngress creates a ingress plugin based on gRPC. +func NewGRPCPluginIngress(name string, conn grpc.ClientConnInterface) ingress_pkg.Ingress { + p := &grpcPluginIngress{ + conn: conn, + log: logger.Default().WithFields(map[string]any{ + "kind": "ingress", + "ingress": name, + }), } - if options.logger == nil { - options.logger = xlogger.Nop() - } - - p := &pluginIngress{ - options: options, - } - if options.client != nil { - p.client = proto.NewIngressClient(options.client) + if conn != nil { + p.client = proto.NewIngressClient(conn) } return p } -func (p *pluginIngress) Get(ctx context.Context, host string) string { +func (p *grpcPluginIngress) Get(ctx context.Context, host string) string { if p.client == nil { return "" } @@ -42,15 +41,15 @@ func (p *pluginIngress) Get(ctx context.Context, host string) string { Host: host, }) if err != nil { - p.options.logger.Error(err) + p.log.Error(err) return "" } return r.GetEndpoint() } -func (p *pluginIngress) Close() error { - if p.options.client != nil { - return p.options.client.Close() +func (p *grpcPluginIngress) Close() error { + if closer, ok := p.conn.(io.Closer); ok { + return closer.Close() } return nil } diff --git a/internal/util/auth/key.go b/internal/util/auth/key.go new file mode 100644 index 0000000..96c95b7 --- /dev/null +++ b/internal/util/auth/key.go @@ -0,0 +1,22 @@ +package auth + +import ( + "context" +) + +type idKey struct{} + +type ID string + +var ( + clientIDKey = &idKey{} +) + +func ContextWithID(ctx context.Context, id ID) context.Context { + return context.WithValue(ctx, clientIDKey, id) +} + +func IDFromContext(ctx context.Context) ID { + v, _ := ctx.Value(clientIDKey).(ID) + return v +} diff --git a/internal/util/ssh/ssh.go b/internal/util/ssh/ssh.go index a52b235..55af884 100644 --- a/internal/util/ssh/ssh.go +++ b/internal/util/ssh/ssh.go @@ -27,7 +27,7 @@ func PasswordCallback(au auth.Authenticator) PasswordCallbackFunc { return nil } return func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) { - if au.Authenticate(context.Background(), conn.User(), string(password)) { + if ok, _ := au.Authenticate(context.Background(), conn.User(), string(password)); ok { return nil, nil } return nil, fmt.Errorf("password rejected for %s", conn.User()) diff --git a/recorder/plugin.go b/recorder/plugin.go index 0f1ab11..a473d00 100644 --- a/recorder/plugin.go +++ b/recorder/plugin.go @@ -2,58 +2,36 @@ package recorder import ( "context" + "io" "github.com/go-gost/core/logger" "github.com/go-gost/core/recorder" "github.com/go-gost/plugin/recorder/proto" - xlogger "github.com/go-gost/x/logger" "google.golang.org/grpc" ) -type pluginOptions struct { - client *grpc.ClientConn - logger logger.Logger +type grpcPluginRecorder struct { + conn grpc.ClientConnInterface + client proto.RecorderClient + log logger.Logger } -type PluginOption func(opts *pluginOptions) - -func PluginConnOption(c *grpc.ClientConn) PluginOption { - return func(opts *pluginOptions) { - opts.client = c +// NewGRPCPluginRecorder creates a plugin recorder. +func NewGRPCPluginRecorder(name string, conn grpc.ClientConnInterface) recorder.Recorder { + p := &grpcPluginRecorder{ + conn: conn, + log: logger.Default().WithFields(map[string]any{ + "kind": "recorder", + "recorder": name, + }), } -} - -func LoggerOption(logger logger.Logger) PluginOption { - return func(opts *pluginOptions) { - opts.logger = logger - } -} - -type pluginRecorder struct { - client proto.RecorderClient - options pluginOptions -} - -// NewPluginRecorder creates a plugin recorder. -func NewPluginRecorder(opts ...PluginOption) recorder.Recorder { - var options pluginOptions - for _, opt := range opts { - opt(&options) - } - if options.logger == nil { - options.logger = xlogger.Nop() - } - - p := &pluginRecorder{ - options: options, - } - if options.client != nil { - p.client = proto.NewRecorderClient(options.client) + if conn != nil { + p.client = proto.NewRecorderClient(conn) } return p } -func (p *pluginRecorder) Record(ctx context.Context, b []byte) error { +func (p *grpcPluginRecorder) Record(ctx context.Context, b []byte) error { if p.client == nil { return nil } @@ -63,15 +41,15 @@ func (p *pluginRecorder) Record(ctx context.Context, b []byte) error { Data: b, }) if err != nil { - p.options.logger.Error(err) + p.log.Error(err) return err } return nil } -func (p *pluginRecorder) Close() error { - if p.options.client != nil { - return p.options.client.Close() +func (p *grpcPluginRecorder) Close() error { + if closer, ok := p.conn.(io.Closer); ok { + return closer.Close() } return nil } diff --git a/registry/auther.go b/registry/auther.go index 79ad04e..d1d5918 100644 --- a/registry/auther.go +++ b/registry/auther.go @@ -30,10 +30,10 @@ type autherWrapper struct { r *autherRegistry } -func (w *autherWrapper) Authenticate(ctx context.Context, user, password string) bool { +func (w *autherWrapper) Authenticate(ctx context.Context, user, password string) (bool, string) { v := w.r.get(w.name) if v == nil { - return true + return true, "" } return v.Authenticate(ctx, user, password) } diff --git a/resolver/plugin.go b/resolver/plugin.go index d77a744..4ecae23 100644 --- a/resolver/plugin.go +++ b/resolver/plugin.go @@ -2,39 +2,39 @@ package resolver import ( "context" + "io" "net" + "github.com/go-gost/core/logger" resolver_pkg "github.com/go-gost/core/resolver" "github.com/go-gost/plugin/resolver/proto" - xlogger "github.com/go-gost/x/logger" + auth_util "github.com/go-gost/x/internal/util/auth" + "google.golang.org/grpc" ) -type pluginResolver struct { - client proto.ResolverClient - options options +type grpcPluginResolver struct { + conn grpc.ClientConnInterface + client proto.ResolverClient + log logger.Logger } -// NewPluginResolver creates a plugin Resolver. -func NewPluginResolver(opts ...Option) (resolver_pkg.Resolver, error) { - var options options - for _, opt := range opts { - opt(&options) +// NewGRPCPluginResolver creates a Resolver plugin based on gRPC. +func NewGRPCPluginResolver(name string, conn grpc.ClientConnInterface) (resolver_pkg.Resolver, error) { + p := &grpcPluginResolver{ + conn: conn, + log: logger.Default().WithFields(map[string]any{ + "kind": "resolver", + "resolver": name, + }), } - if options.logger == nil { - options.logger = xlogger.Nop() - } - - p := &pluginResolver{ - options: options, - } - if options.client != nil { - p.client = proto.NewResolverClient(options.client) + if conn != nil { + p.client = proto.NewResolverClient(conn) } return p, nil } -func (p *pluginResolver) Resolve(ctx context.Context, network, host string) (ips []net.IP, err error) { - p.options.logger.Debugf("resolve %s/%s", host, network) +func (p *grpcPluginResolver) Resolve(ctx context.Context, network, host string) (ips []net.IP, err error) { + p.log.Debugf("resolve %s/%s", host, network) if p.client == nil { return @@ -44,9 +44,10 @@ func (p *pluginResolver) Resolve(ctx context.Context, network, host string) (ips &proto.ResolveRequest{ Network: network, Host: host, + Client: string(auth_util.IDFromContext(ctx)), }) if err != nil { - p.options.logger.Error(err) + p.log.Error(err) return } for _, s := range r.Ips { @@ -57,9 +58,9 @@ func (p *pluginResolver) Resolve(ctx context.Context, network, host string) (ips return } -func (p *pluginResolver) Close() error { - if p.options.client != nil { - return p.options.client.Close() +func (p *grpcPluginResolver) Close() error { + if closer, ok := p.conn.(io.Closer); ok { + return closer.Close() } return nil }