From a8a6bbc3a3ba3ef8759f29da2422effb7dda1a2d Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Sat, 12 Feb 2022 00:33:20 +0800 Subject: [PATCH] add auther config --- cmd/gost/cmd.go | 21 +++++----- cmd/gost/config.go | 8 ++++ pkg/common/util/auth/auth.go | 24 ----------- pkg/config/config.go | 31 +++++++++++---- pkg/config/parsing/chain.go | 24 +---------- pkg/config/parsing/parse.go | 47 +++++++++++++++++++++- pkg/config/parsing/service.go | 29 ++++++-------- pkg/connector/http/connector.go | 4 +- pkg/connector/http2/connector.go | 4 +- pkg/connector/option.go | 6 +-- pkg/connector/relay/bind.go | 6 +-- pkg/connector/relay/connector.go | 9 ++--- pkg/connector/socks/v4/connector.go | 7 +--- pkg/connector/socks/v5/connector.go | 2 +- pkg/connector/ss/connector.go | 6 +-- pkg/connector/ss/udp/connector.go | 6 +-- pkg/dialer/option.go | 6 +-- pkg/dialer/sshd/dialer.go | 9 ++--- pkg/handler/http/handler.go | 12 ++---- pkg/handler/http2/handler.go | 12 ++---- pkg/handler/option.go | 13 ++++-- pkg/handler/relay/handler.go | 14 +++---- pkg/handler/socks/v4/handler.go | 14 +++---- pkg/handler/socks/v5/handler.go | 3 +- pkg/handler/ss/handler.go | 6 +-- pkg/handler/ss/udp/handler.go | 6 +-- pkg/handler/tap/handler.go | 6 +-- pkg/handler/tun/handler.go | 6 +-- pkg/listener/option.go | 14 +++++-- pkg/listener/ssh/listener.go | 6 +-- pkg/listener/sshd/listener.go | 6 +-- pkg/registry/auther.go | 62 +++++++++++++++++++++++++++++ pkg/registry/bypass.go | 3 ++ pkg/registry/chain.go | 3 ++ pkg/registry/hosts.go | 3 ++ pkg/registry/resolver.go | 3 ++ pkg/registry/service.go | 3 ++ 37 files changed, 261 insertions(+), 183 deletions(-) delete mode 100644 pkg/common/util/auth/auth.go create mode 100644 pkg/registry/auther.go diff --git a/cmd/gost/cmd.go b/cmd/gost/cmd.go index 27b603c..901c933 100644 --- a/cmd/gost/cmd.go +++ b/cmd/gost/cmd.go @@ -107,7 +107,7 @@ func buildConfigFromCmd(services, nodes stringList) (*config.Config, error) { } resolverCfg.Nameservers = append( resolverCfg.Nameservers, - config.NameserverConfig{ + &config.NameserverConfig{ Addr: rs, }, ) @@ -127,7 +127,7 @@ func buildConfigFromCmd(services, nodes stringList) (*config.Config, error) { } hostsCfg.Mappings = append( hostsCfg.Mappings, - config.HostMappingConfig{ + &config.HostMappingConfig{ Hostname: ss[0], IP: ss[1], }, @@ -194,7 +194,7 @@ func buildConfigFromCmd(services, nodes stringList) (*config.Config, error) { } resolverCfg.Nameservers = append( resolverCfg.Nameservers, - config.NameserverConfig{ + &config.NameserverConfig{ Addr: rs, }, ) @@ -214,7 +214,7 @@ func buildConfigFromCmd(services, nodes stringList) (*config.Config, error) { } hostsCfg.Mappings = append( hostsCfg.Mappings, - config.HostMappingConfig{ + &config.HostMappingConfig{ Hostname: ss[0], IP: ss[1], }, @@ -271,13 +271,12 @@ func buildServiceConfig(url *url.URL) (*config.ServiceConfig, error) { } } - var auths []*config.AuthConfig + var auth *config.AuthConfig if url.User != nil { - auth := &config.AuthConfig{ + auth = &config.AuthConfig{ Username: url.User.Username(), } auth.Password, _ = url.User.Password() - auths = append(auths, auth) } md := metadata.MapMetadata{} @@ -292,7 +291,7 @@ func buildServiceConfig(url *url.URL) (*config.ServiceConfig, error) { if err != nil { return nil, err } - auths = append(auths, au) + auth = au } md.Del("auth") @@ -319,7 +318,7 @@ func buildServiceConfig(url *url.URL) (*config.ServiceConfig, error) { svc.Handler = &config.HandlerConfig{ Type: handler, - Auths: auths, + Auth: auth, Metadata: md, } svc.Listener = &config.ListenerConfig{ @@ -329,10 +328,10 @@ func buildServiceConfig(url *url.URL) (*config.ServiceConfig, error) { } if svc.Handler.Type == "sshd" { - svc.Handler.Auths = nil + svc.Handler.Auth = nil } if svc.Listener.Type == "sshd" { - svc.Listener.Auths = auths + svc.Listener.Auth = auth } return svc, nil diff --git a/cmd/gost/config.go b/cmd/gost/config.go index 9b2d6b1..fad9956 100644 --- a/cmd/gost/config.go +++ b/cmd/gost/config.go @@ -16,6 +16,14 @@ func buildService(cfg *config.Config) (services []*service.Service) { return } + for _, autherCfg := range cfg.Authers { + if auther := parsing.ParseAuther(autherCfg); auther != nil { + if err := registry.Auther().Register(autherCfg.Name, auther); err != nil { + log.Fatal(err) + } + } + } + for _, bypassCfg := range cfg.Bypasses { if bp := parsing.ParseBypass(bypassCfg); bp != nil { if err := registry.Bypass().Register(bypassCfg.Name, bp); err != nil { diff --git a/pkg/common/util/auth/auth.go b/pkg/common/util/auth/auth.go deleted file mode 100644 index 09bb4ef..0000000 --- a/pkg/common/util/auth/auth.go +++ /dev/null @@ -1,24 +0,0 @@ -package auth - -import ( - "net/url" - - "github.com/go-gost/gost/pkg/auth" -) - -func AuthFromUsers(users ...*url.Userinfo) auth.Authenticator { - kvs := make(map[string]string) - for _, v := range users { - if v == nil || v.Username() == "" { - continue - } - kvs[v.Username()], _ = v.Password() - } - - var authenticator auth.Authenticator - if len(kvs) > 0 { - authenticator = auth.NewMapAuthenticator(kvs) - } - - return authenticator -} diff --git a/pkg/config/config.go b/pkg/config/config.go index 2e00cc3..21fdc9a 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -67,6 +67,14 @@ type TLSConfig struct { ServerName string `yaml:"serverName,omitempty" json:"serverName,omitempty"` } +type AutherConfig struct { + Name string `json:"name"` + // inline, file, redis, etc. + Type string `yaml:",omitempty" json:"type,omitempty"` + Auths []*AuthConfig `yaml:",omitempty" json:"auths"` + // File string `yaml:",omitempty" json:"file"` +} + type AuthConfig struct { Username string `json:"username"` Password string `yaml:",omitempty" json:"password,omitempty"` @@ -79,7 +87,9 @@ type SelectorConfig struct { } type BypassConfig struct { - Name string `json:"name"` + Name string `json:"name"` + // inline, file, etc. + Type string `yaml:",omitempty" json:"type,omitempty"` Reverse bool `yaml:",omitempty" json:"reverse,omitempty"` Matchers []string `json:"matchers"` } @@ -95,8 +105,10 @@ type NameserverConfig struct { } type ResolverConfig struct { - Name string `json:"name"` - Nameservers []NameserverConfig `json:"nameservers"` + Name string `json:"name"` + // inline, file, etc. + Type string `yaml:",omitempty" json:"type,omitempty"` + Nameservers []*NameserverConfig `json:"nameservers"` } type HostMappingConfig struct { @@ -106,14 +118,17 @@ type HostMappingConfig struct { } type HostsConfig struct { - Name string `json:"name"` - Mappings []HostMappingConfig `json:"mappings"` + Name string `json:"name"` + // inline, file, etc. + Type string `yaml:",omitempty" json:"type,omitempty"` + Mappings []*HostMappingConfig `json:"mappings"` } type ListenerConfig struct { Type string `json:"type"` Chain string `yaml:",omitempty" json:"chain,omitempty"` - Auths []*AuthConfig `yaml:",omitempty" json:"auths,omitempty"` + Auther string `yaml:",omitempty" json:"auther,omitempty"` + Auth *AuthConfig `yaml:",omitempty" json:"auth,omitempty"` TLS *TLSConfig `yaml:",omitempty" json:"tls,omitempty"` Metadata map[string]interface{} `yaml:",omitempty" json:"metadata,omitempty"` } @@ -122,7 +137,8 @@ type HandlerConfig struct { Type string `json:"type"` Retries int `yaml:",omitempty" json:"retries,omitempty"` Chain string `yaml:",omitempty" json:"chain,omitempty"` - Auths []*AuthConfig `yaml:",omitempty" json:"auths,omitempty"` + Auther string `yaml:",omitempty" json:"auther,omitempty"` + Auth *AuthConfig `yaml:",omitempty" json:"auth,omitempty"` TLS *TLSConfig `yaml:",omitempty" json:"tls,omitempty"` Metadata map[string]interface{} `yaml:",omitempty" json:"metadata,omitempty"` } @@ -185,6 +201,7 @@ type NodeConfig struct { type Config struct { Services []*ServiceConfig `json:"services"` Chains []*ChainConfig `yaml:",omitempty" json:"chains,omitempty"` + Authers []*AutherConfig `yaml:",omitempty" json:"authers,omitempty"` Bypasses []*BypassConfig `yaml:",omitempty" json:"bypasses,omitempty"` Resolvers []*ResolverConfig `yaml:",omitempty" json:"resolvers,omitempty"` Hosts []*HostsConfig `yaml:",omitempty" json:"hosts,omitempty"` diff --git a/pkg/config/parsing/chain.go b/pkg/config/parsing/chain.go index 051e742..e73b3b9 100644 --- a/pkg/config/parsing/chain.go +++ b/pkg/config/parsing/chain.go @@ -1,8 +1,6 @@ package parsing import ( - "net/url" - "github.com/go-gost/gost/pkg/chain" tls_util "github.com/go-gost/gost/pkg/common/util/tls" "github.com/go-gost/gost/pkg/config" @@ -39,15 +37,6 @@ func ParseChain(cfg *config.ChainConfig) (chain.Chainer, error) { "kind": "connector", }) - var user *url.Userinfo - if auth := v.Connector.Auth; auth != nil && auth.Username != "" { - if auth.Password == "" { - user = url.User(auth.Username) - } else { - user = url.UserPassword(auth.Username, auth.Password) - } - } - tlsCfg := v.Connector.TLS if tlsCfg == nil { tlsCfg = &config.TLSConfig{} @@ -61,7 +50,7 @@ func ParseChain(cfg *config.ChainConfig) (chain.Chainer, error) { } cr := registry.GetConnector(v.Connector.Type)( - connector.UserOption(user), + connector.AuthOption(parseAuth(v.Connector.Auth)), connector.TLSConfigOption(tlsConfig), connector.LoggerOption(connectorLogger), ) @@ -78,15 +67,6 @@ func ParseChain(cfg *config.ChainConfig) (chain.Chainer, error) { "kind": "dialer", }) - user = nil - if auth := v.Dialer.Auth; auth != nil && auth.Username != "" { - if auth.Password == "" { - user = url.User(auth.Username) - } else { - user = url.UserPassword(auth.Username, auth.Password) - } - } - tlsCfg = v.Dialer.TLS if tlsCfg == nil { tlsCfg = &config.TLSConfig{} @@ -100,7 +80,7 @@ func ParseChain(cfg *config.ChainConfig) (chain.Chainer, error) { } d := registry.GetDialer(v.Dialer.Type)( - dialer.UserOption(user), + dialer.AuthOption(parseAuth(v.Dialer.Auth)), dialer.TLSConfigOption(tlsConfig), dialer.LoggerOption(dialerLogger), ) diff --git a/pkg/config/parsing/parse.go b/pkg/config/parsing/parse.go index b23d505..9c6b919 100644 --- a/pkg/config/parsing/parse.go +++ b/pkg/config/parsing/parse.go @@ -2,16 +2,59 @@ package parsing import ( "net" + "net/url" + "github.com/go-gost/gost/pkg/auth" "github.com/go-gost/gost/pkg/bypass" "github.com/go-gost/gost/pkg/chain" "github.com/go-gost/gost/pkg/config" hostspkg "github.com/go-gost/gost/pkg/hosts" "github.com/go-gost/gost/pkg/logger" + "github.com/go-gost/gost/pkg/registry" "github.com/go-gost/gost/pkg/resolver" resolver_impl "github.com/go-gost/gost/pkg/resolver/impl" ) +func ParseAuther(cfg *config.AutherConfig) auth.Authenticator { + if cfg == nil { + return nil + } + + m := make(map[string]string) + + for _, user := range cfg.Auths { + if user.Username == "" { + continue + } + m[user.Username] = user.Password + } + + if len(m) == 0 { + return nil + } + return auth.NewMapAuthenticator(m) +} + +func autherFromAuth(au *config.AuthConfig) auth.Authenticator { + if au == nil || au.Username == "" { + return nil + } + return auth.NewMapAuthenticator(map[string]string{ + au.Username: au.Password, + }) +} + +func parseAuth(cfg *config.AuthConfig) *url.Userinfo { + if cfg == nil || cfg.Username == "" { + return nil + } + + if cfg.Password == "" { + return url.User(cfg.Username) + } + return url.UserPassword(cfg.Username, cfg.Password) +} + func parseSelector(cfg *config.SelectorConfig) chain.Selector { if cfg == nil { return nil @@ -57,8 +100,8 @@ func ParseResolver(cfg *config.ResolverConfig) (resolver.Resolver, error) { var nameservers []resolver_impl.NameServer for _, server := range cfg.Nameservers { nameservers = append(nameservers, resolver_impl.NameServer{ - Addr: server.Addr, - // Chain: chains[server.Chain], + Addr: server.Addr, + Chain: registry.Chain().Get(server.Chain), TTL: server.TTL, Timeout: server.Timeout, ClientIP: net.ParseIP(server.ClientIP), diff --git a/pkg/config/parsing/service.go b/pkg/config/parsing/service.go index 5c25e04..71664ae 100644 --- a/pkg/config/parsing/service.go +++ b/pkg/config/parsing/service.go @@ -1,7 +1,6 @@ package parsing import ( - "net/url" "strings" "github.com/go-gost/gost/pkg/chain" @@ -48,10 +47,16 @@ func ParseService(cfg *config.ServiceConfig) (*service.Service, error) { return nil, err } + auther := autherFromAuth(cfg.Listener.Auth) + if cfg.Listener.Auther != "" { + auther = registry.Auther().Get(cfg.Listener.Auther) + } + ln := registry.GetListener(cfg.Listener.Type)( listener.AddrOption(cfg.Addr), listener.ChainOption(registry.Chain().Get(cfg.Listener.Chain)), - listener.AuthsOption(parseAuths(cfg.Listener.Auths...)...), + listener.AutherOption(auther), + listener.AuthOption(parseAuth(cfg.Listener.Auth)), listener.TLSConfigOption(tlsConfig), listener.LoggerOption(listenerLogger), ) @@ -79,8 +84,13 @@ func ParseService(cfg *config.ServiceConfig) (*service.Service, error) { return nil, err } + auther = autherFromAuth(cfg.Handler.Auth) + if cfg.Handler.Auther != "" { + auther = registry.Auther().Get(cfg.Handler.Auther) + } h := registry.GetHandler(cfg.Handler.Type)( - handler.AuthsOption(parseAuths(cfg.Handler.Auths...)...), + handler.AutherOption(auther), + handler.AuthOption(parseAuth(cfg.Handler.Auth)), handler.RetriesOption(cfg.Handler.Retries), handler.ChainOption(registry.Chain().Get(cfg.Handler.Chain)), handler.BypassOption(registry.Bypass().Get(cfg.Bypass)), @@ -111,19 +121,6 @@ func ParseService(cfg *config.ServiceConfig) (*service.Service, error) { return s, nil } -func parseAuths(cfgs ...*config.AuthConfig) []*url.Userinfo { - var auths []*url.Userinfo - - for _, cfg := range cfgs { - if cfg == nil || cfg.Username == "" { - continue - } - auths = append(auths, url.UserPassword(cfg.Username, cfg.Password)) - } - - return auths -} - func parseForwarder(cfg *config.ForwarderConfig) *chain.NodeGroup { if cfg == nil || len(cfg.Targets) == 0 { return nil diff --git a/pkg/connector/http/connector.go b/pkg/connector/http/connector.go index 965c66e..8ea1102 100644 --- a/pkg/connector/http/connector.go +++ b/pkg/connector/http/connector.go @@ -23,7 +23,6 @@ func init() { } type httpConnector struct { - user *url.Userinfo md metadata options connector.Options } @@ -35,7 +34,6 @@ func NewConnector(opts ...connector.Option) connector.Connector { } return &httpConnector{ - user: options.User, options: options, } } @@ -67,7 +65,7 @@ func (c *httpConnector) Connect(ctx context.Context, conn net.Conn, network, add } req.Header.Set("Proxy-Connection", "keep-alive") - if user := c.user; user != nil { + if user := c.options.Auth; user != nil { u := user.Username() p, _ := user.Password() req.Header.Set("Proxy-Authorization", diff --git a/pkg/connector/http2/connector.go b/pkg/connector/http2/connector.go index 1d40d59..7fb1cc4 100644 --- a/pkg/connector/http2/connector.go +++ b/pkg/connector/http2/connector.go @@ -24,7 +24,6 @@ func init() { } type http2Connector struct { - user *url.Userinfo md metadata options connector.Options } @@ -36,7 +35,6 @@ func NewConnector(opts ...connector.Option) connector.Connector { } return &http2Connector{ - user: options.User, options: options, } } @@ -76,7 +74,7 @@ func (c *http2Connector) Connect(ctx context.Context, conn net.Conn, network, ad req.Header.Set("User-Agent", c.md.UserAgent) } - if user := c.user; user != nil { + if user := c.options.Auth; user != nil { u := user.Username() p, _ := user.Password() req.Header.Set("Proxy-Authorization", diff --git a/pkg/connector/option.go b/pkg/connector/option.go index d9014fa..68e5066 100644 --- a/pkg/connector/option.go +++ b/pkg/connector/option.go @@ -9,16 +9,16 @@ import ( ) type Options struct { - User *url.Userinfo + Auth *url.Userinfo TLSConfig *tls.Config Logger logger.Logger } type Option func(opts *Options) -func UserOption(user *url.Userinfo) Option { +func AuthOption(auth *url.Userinfo) Option { return func(opts *Options) { - opts.User = user + opts.Auth = auth } } diff --git a/pkg/connector/relay/bind.go b/pkg/connector/relay/bind.go index dde2c71..c4a0fb9 100644 --- a/pkg/connector/relay/bind.go +++ b/pkg/connector/relay/bind.go @@ -82,10 +82,10 @@ func (c *relayConnector) bind(conn net.Conn, cmd uint8, network, address string) Flags: cmd, } - if c.user != nil { - pwd, _ := c.user.Password() + if c.options.Auth != nil { + pwd, _ := c.options.Auth.Password() req.Features = append(req.Features, &relay.UserAuthFeature{ - Username: c.user.Username(), + Username: c.options.Auth.Username(), Password: pwd, }) } diff --git a/pkg/connector/relay/connector.go b/pkg/connector/relay/connector.go index c39447a..4ff4949 100644 --- a/pkg/connector/relay/connector.go +++ b/pkg/connector/relay/connector.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "net" - "net/url" "time" "github.com/go-gost/gost/pkg/common/util/socks" @@ -19,7 +18,6 @@ func init() { } type relayConnector struct { - user *url.Userinfo md metadata options connector.Options } @@ -31,7 +29,6 @@ func NewConnector(opts ...connector.Option) connector.Connector { } return &relayConnector{ - user: options.User, options: options, } } @@ -73,10 +70,10 @@ func (c *relayConnector) Connect(ctx context.Context, conn net.Conn, network, ad } } - if c.user != nil { - pwd, _ := c.user.Password() + if c.options.Auth != nil { + pwd, _ := c.options.Auth.Password() req.Features = append(req.Features, &relay.UserAuthFeature{ - Username: c.user.Username(), + Username: c.options.Auth.Username(), Password: pwd, }) } diff --git a/pkg/connector/socks/v4/connector.go b/pkg/connector/socks/v4/connector.go index 02791ad..dc633b0 100644 --- a/pkg/connector/socks/v4/connector.go +++ b/pkg/connector/socks/v4/connector.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" "net" - "net/url" "strconv" "time" @@ -21,7 +20,6 @@ func init() { } type socks4Connector struct { - user *url.Userinfo md metadata options connector.Options } @@ -33,7 +31,6 @@ func NewConnector(opts ...connector.Option) connector.Connector { } return &socks4Connector{ - user: options.User, options: options, } } @@ -99,8 +96,8 @@ func (c *socks4Connector) Connect(ctx context.Context, conn net.Conn, network, a } var userid []byte - if c.user != nil && c.user.Username() != "" { - userid = []byte(c.user.Username()) + if c.options.Auth != nil { + userid = []byte(c.options.Auth.Username()) } req := gosocks4.NewRequest(gosocks4.CmdConnect, addr, userid) if err := req.Write(conn); err != nil { diff --git a/pkg/connector/socks/v5/connector.go b/pkg/connector/socks/v5/connector.go index 993b6e1..64796f0 100644 --- a/pkg/connector/socks/v5/connector.go +++ b/pkg/connector/socks/v5/connector.go @@ -48,7 +48,7 @@ func (c *socks5Connector) Init(md md.Metadata) (err error) { gosocks5.MethodNoAuth, gosocks5.MethodUserPass, }, - User: c.options.User, + User: c.options.Auth, TLSConfig: c.options.TLSConfig, logger: c.options.Logger, } diff --git a/pkg/connector/ss/connector.go b/pkg/connector/ss/connector.go index eec970c..b0124bf 100644 --- a/pkg/connector/ss/connector.go +++ b/pkg/connector/ss/connector.go @@ -41,9 +41,9 @@ func (c *ssConnector) Init(md md.Metadata) (err error) { return } - if c.options.User != nil { - method := c.options.User.Username() - password, _ := c.options.User.Password() + if c.options.Auth != nil { + method := c.options.Auth.Username() + password, _ := c.options.Auth.Password() c.cipher, err = ss.ShadowCipher(method, password, c.md.key) } diff --git a/pkg/connector/ss/udp/connector.go b/pkg/connector/ss/udp/connector.go index 61d0ac0..7cd9b87 100644 --- a/pkg/connector/ss/udp/connector.go +++ b/pkg/connector/ss/udp/connector.go @@ -40,9 +40,9 @@ func (c *ssuConnector) Init(md md.Metadata) (err error) { return } - if c.options.User != nil { - method := c.options.User.Username() - password, _ := c.options.User.Password() + if c.options.Auth != nil { + method := c.options.Auth.Username() + password, _ := c.options.Auth.Password() c.cipher, err = ss.ShadowCipher(method, password, c.md.key) } diff --git a/pkg/dialer/option.go b/pkg/dialer/option.go index 061dce8..a0131d3 100644 --- a/pkg/dialer/option.go +++ b/pkg/dialer/option.go @@ -10,16 +10,16 @@ import ( ) type Options struct { - User *url.Userinfo + Auth *url.Userinfo TLSConfig *tls.Config Logger logger.Logger } type Option func(opts *Options) -func UserOption(user *url.Userinfo) Option { +func AuthOption(auth *url.Userinfo) Option { return func(opts *Options) { - opts.User = user + opts.Auth = auth } } diff --git a/pkg/dialer/sshd/dialer.go b/pkg/dialer/sshd/dialer.go index a064e94..62edb90 100644 --- a/pkg/dialer/sshd/dialer.go +++ b/pkg/dialer/sshd/dialer.go @@ -4,7 +4,6 @@ import ( "context" "errors" "net" - "net/url" "sync" "time" @@ -20,7 +19,6 @@ func init() { } type sshdDialer struct { - user *url.Userinfo sessions map[string]*sshSession sessionMutex sync.Mutex md metadata @@ -34,7 +32,6 @@ func NewDialer(opts ...dialer.Option) dialer.Dialer { } return &sshdDialer{ - user: options.User, sessions: make(map[string]*sshSession), options: options, } @@ -167,9 +164,9 @@ func (d *sshdDialer) initSession(ctx context.Context, addr string, conn net.Conn // Timeout: timeout, HostKeyCallback: ssh.InsecureIgnoreHostKey(), } - if d.user != nil { - config.User = d.user.Username() - if password, _ := d.user.Password(); password != "" { + if d.options.Auth != nil { + config.User = d.options.Auth.Username() + if password, _ := d.options.Auth.Password(); password != "" { config.Auth = []ssh.AuthMethod{ ssh.Password(password), } diff --git a/pkg/handler/http/handler.go b/pkg/handler/http/handler.go index c3bbb49..9d9a13d 100644 --- a/pkg/handler/http/handler.go +++ b/pkg/handler/http/handler.go @@ -16,9 +16,7 @@ import ( "time" "github.com/asaskevich/govalidator" - "github.com/go-gost/gost/pkg/auth" "github.com/go-gost/gost/pkg/chain" - auth_util "github.com/go-gost/gost/pkg/common/util/auth" "github.com/go-gost/gost/pkg/handler" "github.com/go-gost/gost/pkg/logger" md "github.com/go-gost/gost/pkg/metadata" @@ -30,10 +28,9 @@ func init() { } type httpHandler struct { - router *chain.Router - authenticator auth.Authenticator - md metadata - options handler.Options + router *chain.Router + md metadata + options handler.Options } func NewHandler(opts ...handler.Option) handler.Handler { @@ -52,7 +49,6 @@ func (h *httpHandler) Init(md md.Metadata) error { return err } - h.authenticator = auth_util.AuthFromUsers(h.options.Auths...) h.router = &chain.Router{ Retries: h.options.Retries, Chain: h.options.Chain, @@ -266,7 +262,7 @@ func (h *httpHandler) basicProxyAuth(proxyAuth string, log logger.Logger) (usern func (h *httpHandler) authenticate(conn net.Conn, req *http.Request, resp *http.Response, log logger.Logger) (ok bool) { u, p, _ := h.basicProxyAuth(req.Header.Get("Proxy-Authorization"), log) - if h.authenticator == nil || h.authenticator.Authenticate(u, p) { + if h.options.Auther == nil || h.options.Auther.Authenticate(u, p) { return true } diff --git a/pkg/handler/http2/handler.go b/pkg/handler/http2/handler.go index 6600272..fe20942 100644 --- a/pkg/handler/http2/handler.go +++ b/pkg/handler/http2/handler.go @@ -18,9 +18,7 @@ import ( "strings" "time" - "github.com/go-gost/gost/pkg/auth" "github.com/go-gost/gost/pkg/chain" - auth_util "github.com/go-gost/gost/pkg/common/util/auth" "github.com/go-gost/gost/pkg/handler" http2_util "github.com/go-gost/gost/pkg/internal/util/http2" "github.com/go-gost/gost/pkg/logger" @@ -33,10 +31,9 @@ func init() { } type http2Handler struct { - router *chain.Router - authenticator auth.Authenticator - md metadata - options handler.Options + router *chain.Router + md metadata + options handler.Options } func NewHandler(opts ...handler.Option) handler.Handler { @@ -55,7 +52,6 @@ func (h *http2Handler) Init(md md.Metadata) error { return err } - h.authenticator = auth_util.AuthFromUsers(h.options.Auths...) h.router = &chain.Router{ Retries: h.options.Retries, Chain: h.options.Chain, @@ -239,7 +235,7 @@ func (h *http2Handler) basicProxyAuth(proxyAuth string) (username, password stri func (h *http2Handler) authenticate(w http.ResponseWriter, r *http.Request, resp *http.Response, log logger.Logger) (ok bool) { u, p, _ := h.basicProxyAuth(r.Header.Get("Proxy-Authorization")) - if h.authenticator == nil || h.authenticator.Authenticate(u, p) { + if h.options.Auther == nil || h.options.Auther.Authenticate(u, p) { return true } diff --git a/pkg/handler/option.go b/pkg/handler/option.go index 2d7a75b..cedce4c 100644 --- a/pkg/handler/option.go +++ b/pkg/handler/option.go @@ -4,6 +4,7 @@ import ( "crypto/tls" "net/url" + "github.com/go-gost/gost/pkg/auth" "github.com/go-gost/gost/pkg/bypass" "github.com/go-gost/gost/pkg/chain" "github.com/go-gost/gost/pkg/hosts" @@ -17,7 +18,8 @@ type Options struct { Resolver resolver.Resolver Hosts hosts.HostMapper Bypass bypass.Bypass - Auths []*url.Userinfo + Auth *url.Userinfo + Auther auth.Authenticator TLSConfig *tls.Config Logger logger.Logger } @@ -54,9 +56,14 @@ func BypassOption(bypass bypass.Bypass) Option { } } -func AuthsOption(auths ...*url.Userinfo) Option { +func AuthOption(auth *url.Userinfo) Option { return func(opts *Options) { - opts.Auths = auths + opts.Auth = auth + } +} +func AutherOption(auther auth.Authenticator) Option { + return func(opts *Options) { + opts.Auther = auther } } diff --git a/pkg/handler/relay/handler.go b/pkg/handler/relay/handler.go index 7c8a158..14a4269 100644 --- a/pkg/handler/relay/handler.go +++ b/pkg/handler/relay/handler.go @@ -6,9 +6,7 @@ import ( "strconv" "time" - "github.com/go-gost/gost/pkg/auth" "github.com/go-gost/gost/pkg/chain" - auth_util "github.com/go-gost/gost/pkg/common/util/auth" "github.com/go-gost/gost/pkg/handler" md "github.com/go-gost/gost/pkg/metadata" "github.com/go-gost/gost/pkg/registry" @@ -20,11 +18,10 @@ func init() { } type relayHandler struct { - group *chain.NodeGroup - router *chain.Router - authenticator auth.Authenticator - md metadata - options handler.Options + group *chain.NodeGroup + router *chain.Router + md metadata + options handler.Options } func NewHandler(opts ...handler.Option) handler.Handler { @@ -43,7 +40,6 @@ func (h *relayHandler) Init(md md.Metadata) (err error) { return err } - h.authenticator = auth_util.AuthFromUsers(h.options.Auths...) h.router = &chain.Router{ Retries: h.options.Retries, Chain: h.options.Chain, @@ -113,7 +109,7 @@ func (h *relayHandler) Handle(ctx context.Context, conn net.Conn) { Version: relay.Version1, Status: relay.StatusOK, } - if h.authenticator != nil && !h.authenticator.Authenticate(user, pass) { + if h.options.Auther != nil && !h.options.Auther.Authenticate(user, pass) { resp.Status = relay.StatusUnauthorized resp.WriteTo(conn) log.Error("unauthorized") diff --git a/pkg/handler/socks/v4/handler.go b/pkg/handler/socks/v4/handler.go index 6bd6e61..54cea8c 100644 --- a/pkg/handler/socks/v4/handler.go +++ b/pkg/handler/socks/v4/handler.go @@ -6,9 +6,7 @@ import ( "time" "github.com/go-gost/gosocks4" - "github.com/go-gost/gost/pkg/auth" "github.com/go-gost/gost/pkg/chain" - auth_util "github.com/go-gost/gost/pkg/common/util/auth" "github.com/go-gost/gost/pkg/handler" "github.com/go-gost/gost/pkg/logger" md "github.com/go-gost/gost/pkg/metadata" @@ -21,10 +19,9 @@ func init() { } type socks4Handler struct { - router *chain.Router - authenticator auth.Authenticator - md metadata - options handler.Options + router *chain.Router + md metadata + options handler.Options } func NewHandler(opts ...handler.Option) handler.Handler { @@ -43,7 +40,6 @@ func (h *socks4Handler) Init(md md.Metadata) (err error) { return err } - h.authenticator = auth_util.AuthFromUsers(h.options.Auths...) h.router = &chain.Router{ Retries: h.options.Retries, Chain: h.options.Chain, @@ -85,8 +81,8 @@ func (h *socks4Handler) Handle(ctx context.Context, conn net.Conn) { conn.SetReadDeadline(time.Time{}) - if h.authenticator != nil && - !h.authenticator.Authenticate(string(req.Userid), "") { + if h.options.Auther != nil && + !h.options.Auther.Authenticate(string(req.Userid), "") { resp := gosocks4.NewReply(gosocks4.RejectedUserid, nil) resp.Write(conn) log.Debug(resp) diff --git a/pkg/handler/socks/v5/handler.go b/pkg/handler/socks/v5/handler.go index 34d2f08..1327fd6 100644 --- a/pkg/handler/socks/v5/handler.go +++ b/pkg/handler/socks/v5/handler.go @@ -7,7 +7,6 @@ import ( "github.com/go-gost/gosocks5" "github.com/go-gost/gost/pkg/chain" - auth_util "github.com/go-gost/gost/pkg/common/util/auth" "github.com/go-gost/gost/pkg/common/util/socks" "github.com/go-gost/gost/pkg/handler" md "github.com/go-gost/gost/pkg/metadata" @@ -51,7 +50,7 @@ func (h *socks5Handler) Init(md md.Metadata) (err error) { } h.selector = &serverSelector{ - Authenticator: auth_util.AuthFromUsers(h.options.Auths...), + Authenticator: h.options.Auther, TLSConfig: h.options.TLSConfig, logger: h.options.Logger, noTLS: h.md.noTLS, diff --git a/pkg/handler/ss/handler.go b/pkg/handler/ss/handler.go index 337d920..b329376 100644 --- a/pkg/handler/ss/handler.go +++ b/pkg/handler/ss/handler.go @@ -42,9 +42,9 @@ func (h *ssHandler) Init(md md.Metadata) (err error) { if err = h.parseMetadata(md); err != nil { return } - if len(h.options.Auths) > 0 { - method := h.options.Auths[0].Username() - password, _ := h.options.Auths[0].Password() + if h.options.Auth != nil { + method := h.options.Auth.Username() + password, _ := h.options.Auth.Password() h.cipher, err = ss.ShadowCipher(method, password, h.md.key) if err != nil { return diff --git a/pkg/handler/ss/udp/handler.go b/pkg/handler/ss/udp/handler.go index 2e6498d..4b5ce15 100644 --- a/pkg/handler/ss/udp/handler.go +++ b/pkg/handler/ss/udp/handler.go @@ -43,9 +43,9 @@ func (h *ssuHandler) Init(md md.Metadata) (err error) { return } - if len(h.options.Auths) > 0 { - method := h.options.Auths[0].Username() - password, _ := h.options.Auths[0].Password() + if h.options.Auth != nil { + method := h.options.Auth.Username() + password, _ := h.options.Auth.Password() h.cipher, err = ss.ShadowCipher(method, password, h.md.key) if err != nil { return diff --git a/pkg/handler/tap/handler.go b/pkg/handler/tap/handler.go index dc7b405..3850763 100644 --- a/pkg/handler/tap/handler.go +++ b/pkg/handler/tap/handler.go @@ -54,9 +54,9 @@ func (h *tapHandler) Init(md md.Metadata) (err error) { return } - if len(h.options.Auths) > 0 { - method := h.options.Auths[0].Username() - password, _ := h.options.Auths[0].Password() + if h.options.Auth != nil { + method := h.options.Auth.Username() + password, _ := h.options.Auth.Password() h.cipher, err = ss.ShadowCipher(method, password, h.md.key) if err != nil { return diff --git a/pkg/handler/tun/handler.go b/pkg/handler/tun/handler.go index 4b6b2f3..088c15c 100644 --- a/pkg/handler/tun/handler.go +++ b/pkg/handler/tun/handler.go @@ -56,9 +56,9 @@ func (h *tunHandler) Init(md md.Metadata) (err error) { return } - if len(h.options.Auths) > 0 { - method := h.options.Auths[0].Username() - password, _ := h.options.Auths[0].Password() + if h.options.Auth != nil { + method := h.options.Auth.Username() + password, _ := h.options.Auth.Password() h.cipher, err = ss.ShadowCipher(method, password, h.md.key) if err != nil { return diff --git a/pkg/listener/option.go b/pkg/listener/option.go index aa9cd6d..353effd 100644 --- a/pkg/listener/option.go +++ b/pkg/listener/option.go @@ -4,13 +4,15 @@ import ( "crypto/tls" "net/url" + "github.com/go-gost/gost/pkg/auth" "github.com/go-gost/gost/pkg/chain" "github.com/go-gost/gost/pkg/logger" ) type Options struct { Addr string - Auths []*url.Userinfo + Auther auth.Authenticator + Auth *url.Userinfo TLSConfig *tls.Config Chain chain.Chainer Logger logger.Logger @@ -24,9 +26,15 @@ func AddrOption(addr string) Option { } } -func AuthsOption(auths ...*url.Userinfo) Option { +func AutherOption(auther auth.Authenticator) Option { return func(opts *Options) { - opts.Auths = auths + opts.Auther = auther + } +} + +func AuthOption(auth *url.Userinfo) Option { + return func(opts *Options) { + opts.Auth = auth } } diff --git a/pkg/listener/ssh/listener.go b/pkg/listener/ssh/listener.go index 2ddc352..eb82314 100644 --- a/pkg/listener/ssh/listener.go +++ b/pkg/listener/ssh/listener.go @@ -5,7 +5,6 @@ import ( "net" "time" - auth_util "github.com/go-gost/gost/pkg/common/util/auth" ssh_util "github.com/go-gost/gost/pkg/internal/util/ssh" "github.com/go-gost/gost/pkg/listener" "github.com/go-gost/gost/pkg/logger" @@ -53,13 +52,12 @@ func (l *sshListener) Init(md md.Metadata) (err error) { l.Listener = ln - authenticator := auth_util.AuthFromUsers(l.options.Auths...) config := &ssh.ServerConfig{ - PasswordCallback: ssh_util.PasswordCallback(authenticator), + PasswordCallback: ssh_util.PasswordCallback(l.options.Auther), PublicKeyCallback: ssh_util.PublicKeyCallback(l.md.authorizedKeys), } config.AddHostKey(l.md.signer) - if authenticator == nil && len(l.md.authorizedKeys) == 0 { + if l.options.Auther == nil && len(l.md.authorizedKeys) == 0 { config.NoClientAuth = true } diff --git a/pkg/listener/sshd/listener.go b/pkg/listener/sshd/listener.go index 90f9aba..3e3a091 100644 --- a/pkg/listener/sshd/listener.go +++ b/pkg/listener/sshd/listener.go @@ -7,7 +7,6 @@ import ( "strconv" "time" - auth_util "github.com/go-gost/gost/pkg/common/util/auth" ssh_util "github.com/go-gost/gost/pkg/internal/util/ssh" sshd_util "github.com/go-gost/gost/pkg/internal/util/sshd" "github.com/go-gost/gost/pkg/listener" @@ -62,13 +61,12 @@ func (l *sshdListener) Init(md md.Metadata) (err error) { l.Listener = ln - authenticator := auth_util.AuthFromUsers(l.options.Auths...) config := &ssh.ServerConfig{ - PasswordCallback: ssh_util.PasswordCallback(authenticator), + PasswordCallback: ssh_util.PasswordCallback(l.options.Auther), PublicKeyCallback: ssh_util.PublicKeyCallback(l.md.authorizedKeys), } config.AddHostKey(l.md.signer) - if authenticator == nil && len(l.md.authorizedKeys) == 0 { + if l.options.Auther == nil && len(l.md.authorizedKeys) == 0 { config.NoClientAuth = true } diff --git a/pkg/registry/auther.go b/pkg/registry/auther.go new file mode 100644 index 0000000..2281d72 --- /dev/null +++ b/pkg/registry/auther.go @@ -0,0 +1,62 @@ +package registry + +import ( + "sync" + + "github.com/go-gost/gost/pkg/auth" +) + +var ( + autherReg = &autherRegistry{} +) + +func Auther() *autherRegistry { + return autherReg +} + +type autherRegistry struct { + m sync.Map +} + +func (r *autherRegistry) Register(name string, auth auth.Authenticator) error { + if _, loaded := r.m.LoadOrStore(name, auth); loaded { + return ErrDup + } + + return nil +} + +func (r *autherRegistry) Unregister(name string) { + r.m.Delete(name) +} + +func (r *autherRegistry) IsRegistered(name string) bool { + _, ok := r.m.Load(name) + return ok +} + +func (r *autherRegistry) Get(name string) auth.Authenticator { + if name == "" { + return nil + } + return &autherWrapper{name: name} +} + +func (r *autherRegistry) get(name string) auth.Authenticator { + if v, ok := r.m.Load(name); ok { + return v.(auth.Authenticator) + } + return nil +} + +type autherWrapper struct { + name string +} + +func (w *autherWrapper) Authenticate(user, password string) bool { + v := autherReg.get(w.name) + if v == nil { + return true + } + return v.Authenticate(user, password) +} diff --git a/pkg/registry/bypass.go b/pkg/registry/bypass.go index 06cee2a..a77da73 100644 --- a/pkg/registry/bypass.go +++ b/pkg/registry/bypass.go @@ -36,6 +36,9 @@ func (r *bypassRegistry) IsRegistered(name string) bool { } func (r *bypassRegistry) Get(name string) bypass.Bypass { + if name == "" { + return nil + } return &bypassWrapper{name: name} } diff --git a/pkg/registry/chain.go b/pkg/registry/chain.go index 53096bb..99b74ce 100644 --- a/pkg/registry/chain.go +++ b/pkg/registry/chain.go @@ -36,6 +36,9 @@ func (r *chainRegistry) IsRegistered(name string) bool { } func (r *chainRegistry) Get(name string) chain.Chainer { + if name == "" { + return nil + } return &chainWrapper{name: name} } diff --git a/pkg/registry/hosts.go b/pkg/registry/hosts.go index 3aea9ae..6692609 100644 --- a/pkg/registry/hosts.go +++ b/pkg/registry/hosts.go @@ -37,6 +37,9 @@ func (r *hostsRegistry) IsRegistered(name string) bool { } func (r *hostsRegistry) Get(name string) hosts.HostMapper { + if name == "" { + return nil + } return &hostsWrapper{name: name} } diff --git a/pkg/registry/resolver.go b/pkg/registry/resolver.go index ad6c089..6035590 100644 --- a/pkg/registry/resolver.go +++ b/pkg/registry/resolver.go @@ -38,6 +38,9 @@ func (r *resolverRegistry) IsRegistered(name string) bool { } func (r *resolverRegistry) Get(name string) resolver.Resolver { + if name == "" { + return nil + } return &resolverWrapper{name: name} } diff --git a/pkg/registry/service.go b/pkg/registry/service.go index 7c45f34..5b4da05 100644 --- a/pkg/registry/service.go +++ b/pkg/registry/service.go @@ -36,6 +36,9 @@ func (r *serviceRegistry) IsRegistered(name string) bool { } func (r *serviceRegistry) Get(name string) *service.Service { + if name == "" { + return nil + } v, ok := r.m.Load(name) if !ok { return nil