From 566e93001063034daf03a2ede9a5529088eb2876 Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Mon, 3 Jan 2022 23:45:49 +0800 Subject: [PATCH] update config --- cmd/gost/cmd.go | 171 ++++++++++++++++++++++++---- cmd/gost/config.go | 68 ++++++++--- cmd/gost/main.go | 2 - cmd/gost/norm.go | 139 ---------------------- gost.yml | 30 ++--- pkg/auth/auth.go | 15 +-- pkg/config/config.go | 32 ++++-- pkg/connector/http/connector.go | 4 +- pkg/connector/http/metadata.go | 12 -- pkg/connector/http2/connector.go | 4 +- pkg/connector/http2/metadata.go | 13 --- pkg/connector/option.go | 8 ++ pkg/connector/relay/bind.go | 6 +- pkg/connector/relay/connector.go | 9 +- pkg/connector/relay/metadata.go | 12 -- pkg/connector/socks/v4/connector.go | 9 +- pkg/connector/socks/v4/metadata.go | 6 - pkg/connector/socks/v5/connector.go | 5 +- pkg/connector/socks/v5/metadata.go | 13 --- pkg/connector/ss/connector.go | 21 +++- pkg/connector/ss/metadata.go | 21 +--- pkg/connector/ss/udp/connector.go | 25 +++- pkg/connector/ss/udp/metadata.go | 21 +--- pkg/dialer/forward/ssh/dialer.go | 9 +- pkg/dialer/forward/ssh/metadata.go | 13 --- pkg/dialer/option.go | 8 ++ pkg/handler/forward/ssh/handler.go | 16 +-- pkg/handler/forward/ssh/metadata.go | 17 --- pkg/handler/http/handler.go | 12 +- pkg/handler/http/metadata.go | 24 +--- pkg/handler/http2/handler.go | 12 +- pkg/handler/http2/metadata.go | 24 +--- pkg/handler/option.go | 16 ++- pkg/handler/relay/handler.go | 14 ++- pkg/handler/relay/metadata.go | 17 --- pkg/handler/socks/v4/handler.go | 14 ++- pkg/handler/socks/v4/metadata.go | 15 +-- pkg/handler/socks/v5/handler.go | 14 ++- pkg/handler/socks/v5/metadata.go | 17 --- pkg/listener/option.go | 12 +- pkg/listener/ssh/listener.go | 16 +-- pkg/listener/ssh/metadata.go | 17 --- 42 files changed, 412 insertions(+), 521 deletions(-) delete mode 100644 cmd/gost/norm.go diff --git a/cmd/gost/cmd.go b/cmd/gost/cmd.go index 5b7f54b..2b6b8a2 100644 --- a/cmd/gost/cmd.go +++ b/cmd/gost/cmd.go @@ -7,11 +7,12 @@ import ( "strings" "github.com/go-gost/gost/pkg/config" + "github.com/go-gost/gost/pkg/registry" ) var ( - ErrInvalidService = errors.New("invalid service") - ErrInvalidNode = errors.New("invalid node") + ErrInvalidCmd = errors.New("invalid cmd") + ErrInvalidNode = errors.New("invalid node") ) type stringList []string @@ -36,32 +37,36 @@ func buildConfigFromCmd(services, nodes stringList) (*config.Config, error) { } for i, node := range nodes { - url, err := checkCmd(node) + url, err := normCmd(node) if err != nil { return nil, err } + + nodeConfig, err := buildNodeConfig(url) + if err != nil { + return nil, err + } + nodeConfig.Name = "node-0" + chain.Hops = append(chain.Hops, &config.HopConfig{ - Name: fmt.Sprintf("hop-%d", i), - Nodes: []*config.NodeConfig{ - { - Name: "node-0", - URL: url, - }, - }, + Name: fmt.Sprintf("hop-%d", i), + Nodes: []*config.NodeConfig{nodeConfig}, }) } for i, svc := range services { - url, err := checkCmd(svc) + url, err := normCmd(svc) if err != nil { return nil, err } - service := &config.ServiceConfig{ - Name: fmt.Sprintf("service-%d", i), - URL: url, + + service, err := buildServiceConfig(url) + if err != nil { + return nil, err } + service.Name = fmt.Sprintf("service-%d", i) if chain != nil { - service.Chain = chain.Name + service.Handler.Chain = chain.Name } cfg.Services = append(cfg.Services, service) } @@ -69,20 +74,140 @@ func buildConfigFromCmd(services, nodes stringList) (*config.Config, error) { return cfg, nil } -func checkCmd(s string) (string, error) { +func buildServiceConfig(url *url.URL) (*config.ServiceConfig, error) { + var handler, listener string + schemes := strings.Split(url.Scheme, "+") + if len(schemes) == 1 { + handler = schemes[0] + listener = schemes[0] + } + if len(schemes) == 2 { + handler = schemes[0] + listener = schemes[1] + } + + svc := &config.ServiceConfig{ + Addr: url.Host, + } + + if h := registry.GetHandler(handler); h == nil { + handler = "auto" + } + if ln := registry.GetListener(listener); ln == nil { + listener = "tcp" + if handler == "ssu" { + listener = "udp" + } + } + + if remotes := strings.Trim(url.EscapedPath(), "/"); remotes != "" { + svc.Forwarder = &config.ForwarderConfig{ + Targets: strings.Split(remotes, ","), + } + if handler != "relay" { + if listener == "tcp" || listener == "udp" || + listener == "rtcp" || listener == "rudp" || + listener == "tun" || listener == "tap" { + handler = listener + } else { + handler = "tcp" + } + } + } + + md := make(map[string]interface{}) + for k, v := range url.Query() { + if len(v) > 0 { + md[k] = v[0] + } + } + + var auths []config.AuthConfig + if url.User != nil { + auth := config.AuthConfig{ + Username: url.User.Username(), + } + auth.Password, _ = url.User.Password() + auths = append(auths, auth) + } + + svc.Handler = &config.HandlerConfig{ + Type: handler, + Auths: auths, + Metadata: md, + } + svc.Listener = &config.ListenerConfig{ + Type: listener, + Metadata: md, + } + + return svc, nil +} + +func buildNodeConfig(url *url.URL) (*config.NodeConfig, error) { + var connector, dialer string + schemes := strings.Split(url.Scheme, "+") + if len(schemes) == 1 { + connector = schemes[0] + dialer = schemes[0] + } + if len(schemes) == 2 { + connector = schemes[0] + dialer = schemes[1] + } + + node := &config.NodeConfig{ + Addr: url.Host, + } + + if c := registry.GetConnector(connector); c == nil { + connector = "http" + } + if d := registry.GetDialer(dialer); d == nil { + dialer = "tcp" + if connector == "ssu" { + dialer = "udp" + } + } + + md := make(map[string]interface{}) + for k, v := range url.Query() { + if len(v) > 0 { + md[k] = v[0] + } + } + md["serverName"] = url.Host + + var auth *config.AuthConfig + if url.User != nil { + auth = &config.AuthConfig{ + Username: url.User.Username(), + } + auth.Password, _ = url.User.Password() + } + + node.Connector = &config.ConnectorConfig{ + Type: connector, + Auth: auth, + Metadata: md, + } + node.Dialer = &config.DialerConfig{ + Type: dialer, + Metadata: md, + } + + return node, nil +} + +func normCmd(s string) (*url.URL, error) { s = strings.TrimSpace(s) if s == "" { - return "", ErrInvalidService + return nil, ErrInvalidCmd } if !strings.Contains(s, "://") { s = "auto://" + s } - u, err := url.Parse(s) - if err != nil { - return "", err - } - - return u.String(), nil + return url.Parse(s) } diff --git a/cmd/gost/config.go b/cmd/gost/config.go index d6ba5b8..54c305c 100644 --- a/cmd/gost/config.go +++ b/cmd/gost/config.go @@ -3,9 +3,11 @@ package main import ( "io" "net" + "net/url" "os" "strings" + "github.com/go-gost/gost/pkg/auth" "github.com/go-gost/gost/pkg/bypass" "github.com/go-gost/gost/pkg/chain" "github.com/go-gost/gost/pkg/config" @@ -54,12 +56,14 @@ func buildService(cfg *config.Config) (services []*service.Service) { } for _, svc := range cfg.Services { + if svc.Listener == nil || svc.Handler == nil { + continue + } serviceLogger := log.WithFields(map[string]interface{}{ "kind": "service", "service": svc.Name, "listener": svc.Listener.Type, "handler": svc.Handler.Type, - "chain": svc.Chain, }) listenerLogger := serviceLogger.WithFields(map[string]interface{}{ @@ -67,11 +71,12 @@ func buildService(cfg *config.Config) (services []*service.Service) { }) ln := registry.GetListener(svc.Listener.Type)( listener.AddrOption(svc.Addr), + listener.AuthenticatorOption(authFromConfig(svc.Listener.Auths...)), listener.LoggerOption(listenerLogger), ) if chainable, ok := ln.(chain.Chainable); ok { - chainable.WithChain(chains[svc.Chain]) + chainable.WithChain(chains[svc.Listener.Chain]) } if svc.Listener.Metadata == nil { @@ -86,12 +91,13 @@ func buildService(cfg *config.Config) (services []*service.Service) { }) h := registry.GetHandler(svc.Handler.Type)( - handler.BypassOption(bypasses[svc.Bypass]), handler.LoggerOption(handlerLogger), + handler.BypassOption(bypasses[svc.Handler.Bypass]), + handler.AuthenticatorOption(authFromConfig(svc.Handler.Auths...)), handler.RouterOption(&chain.Router{ - Chain: chains[svc.Chain], - Resolver: resolvers[svc.Resolver], - Hosts: hosts[svc.Hosts], + Chain: chains[svc.Handler.Chain], + Resolver: resolvers[svc.Handler.Resolver], + Hosts: hosts[svc.Handler.Hosts], Logger: handlerLogger, }), ) @@ -134,14 +140,27 @@ func chainFromConfig(cfg *config.ChainConfig) *chain.Chain { for _, hop := range cfg.Hops { group := &chain.NodeGroup{} for _, v := range hop.Nodes { - connectorLogger := chainLogger.WithFields(map[string]interface{}{ - "kind": "connector", + nodeLogger := chainLogger.WithFields(map[string]interface{}{ + "kind": "node", "connector": v.Connector.Type, "dialer": v.Dialer.Type, "hop": hop.Name, "node": v.Name, }) + connectorLogger := nodeLogger.WithFields(map[string]interface{}{ + "kind": "connector", + }) + + var connectorUser *url.Userinfo + if auth := v.Connector.Auth; auth != nil && auth.Username != "" { + if auth.Password == "" { + connectorUser = url.User(auth.Username) + } else { + connectorUser = url.UserPassword(auth.Username, auth.Password) + } + } cr := registry.GetConnector(v.Connector.Type)( + connector.UserOption(connectorUser), connector.LoggerOption(connectorLogger), ) @@ -152,14 +171,20 @@ func chainFromConfig(cfg *config.ChainConfig) *chain.Chain { connectorLogger.Fatal("init: ", err) } - dialerLogger := chainLogger.WithFields(map[string]interface{}{ - "kind": "dialer", - "connector": v.Connector.Type, - "dialer": v.Dialer.Type, - "hop": hop.Name, - "node": v.Name, + dialerLogger := nodeLogger.WithFields(map[string]interface{}{ + "kind": "dialer", }) + + var dialerUser *url.Userinfo + if auth := v.Dialer.Auth; auth != nil && auth.Username != "" { + if auth.Password == "" { + dialerUser = url.User(auth.Username) + } else { + dialerUser = url.UserPassword(auth.Username, auth.Password) + } + } d := registry.GetDialer(v.Dialer.Type)( + dialer.UserOption(dialerUser), dialer.LoggerOption(dialerLogger), ) @@ -305,3 +330,18 @@ func hostsFromConfig(cfg *config.HostsConfig) hostspkg.HostMapper { } return hosts } + +func authFromConfig(cfgs ...config.AuthConfig) auth.Authenticator { + auths := make(map[string]string) + for _, cfg := range cfgs { + if cfg.Username == "" { + continue + } + auths[cfg.Username] = cfg.Password + } + if len(auths) > 0 { + return auth.NewMapAuthenticator(auths) + } + + return nil +} diff --git a/cmd/gost/main.go b/cmd/gost/main.go index d6d9464..d1e6c43 100644 --- a/cmd/gost/main.go +++ b/cmd/gost/main.go @@ -63,8 +63,6 @@ func main() { log.Fatal(err) } - normConfig(cfg) - log = logFromConfig(cfg.Log) if outputCfgFile != "" { diff --git a/cmd/gost/norm.go b/cmd/gost/norm.go deleted file mode 100644 index 0e16f12..0000000 --- a/cmd/gost/norm.go +++ /dev/null @@ -1,139 +0,0 @@ -package main - -import ( - "net/url" - "strings" - - "github.com/go-gost/gost/pkg/config" - "github.com/go-gost/gost/pkg/registry" -) - -// normConfig normalizes the config. -func normConfig(cfg *config.Config) { - for _, svc := range cfg.Services { - normService(svc) - } - for _, chain := range cfg.Chains { - normChain(chain) - } -} - -func normService(svc *config.ServiceConfig) { - if svc.URL == "" { - return - } - - u, _ := url.Parse(svc.URL) - - var handler, listener string - schemes := strings.Split(u.Scheme, "+") - if len(schemes) == 1 { - handler = schemes[0] - listener = schemes[0] - } - if len(schemes) == 2 { - handler = schemes[0] - listener = schemes[1] - } - - md := make(map[string]interface{}) - for k, v := range u.Query() { - if len(v) > 0 { - md[k] = v[0] - } - } - if u.User != nil { - md["users"] = []interface{}{u.User.String()} - } - - svc.Addr = u.Host - - if h := registry.GetHandler(handler); h == nil { - handler = "auto" - } - if ln := registry.GetListener(listener); ln == nil { - listener = "tcp" - if handler == "ssu" { - listener = "udp" - } - } - - if remotes := strings.Trim(u.EscapedPath(), "/"); remotes != "" { - svc.Forwarder = &config.ForwarderConfig{ - Targets: strings.Split(remotes, ","), - } - if handler != "relay" { - if listener == "tcp" || listener == "udp" || - listener == "rtcp" || listener == "rudp" || - listener == "tun" || listener == "tap" { - handler = listener - } else { - handler = "tcp" - } - } - } - - svc.Handler = &config.HandlerConfig{ - Type: handler, - Metadata: md, - } - svc.Listener = &config.ListenerConfig{ - Type: listener, - Metadata: md, - } -} - -func normChain(chain *config.ChainConfig) { - for _, hop := range chain.Hops { - for _, node := range hop.Nodes { - if node.URL == "" { - continue - } - - u, _ := url.Parse(node.URL) - - var connector, dialer string - schemes := strings.Split(u.Scheme, "+") - if len(schemes) == 1 { - connector = schemes[0] - dialer = schemes[0] - } - if len(schemes) == 2 { - connector = schemes[0] - dialer = schemes[1] - } - - md := make(map[string]interface{}) - for k, v := range u.Query() { - if len(v) > 0 { - md[k] = v[0] - } - } - if u.User != nil { - md["user"] = u.User.String() - } - md["serverName"] = u.Host - - node.Addr = u.Host - - if c := registry.GetConnector(connector); c == nil { - connector = "http" - } - if d := registry.GetDialer(dialer); d == nil { - dialer = "tcp" - if connector == "ssu" { - dialer = "udp" - } - } - - node.Connector = &config.ConnectorConfig{ - Type: connector, - Metadata: md, - } - node.Dialer = &config.DialerConfig{ - Type: dialer, - Metadata: md, - } - } - } -} diff --git a/gost.yml b/gost.yml index c2d8835..e86a2c2 100644 --- a/gost.yml +++ b/gost.yml @@ -7,13 +7,12 @@ services: - name: http+tcp url: "http://gost:gost@:8000" addr: ":28000" - chain: chain01 - # bypass: bypass01 handler: type: http + chain: chain01 + # bypass: bypass01 metadata: proxyAgent: "gost/3.0" - retry: 3 auths: - user1:pass1 - user2:pass2 @@ -26,10 +25,10 @@ services: - name: ss url: "ss://chacha20:gost@:8000" addr: ":28338" - # chain: chain01 - # bypass: bypass01 handler: type: ss + # chain: chain01 + # bypass: bypass01 metadata: method: chacha20-ietf password: gost @@ -43,10 +42,10 @@ services: - name: socks5 url: "socks5://gost:gost@:1080" addr: ":21080" - # chain: chain-ss - # bypass: bypass01 handler: type: socks5 + # chain: chain-ss + # bypass: bypass01 metadata: auths: - gost:gost @@ -112,7 +111,6 @@ services: - name: rtcp addr: ":28100" - # chain: chain-socks5 forwarder: targets: - 192.168.8.8:80 @@ -122,6 +120,7 @@ services: readTimeout: 5s listener: type: rtcp + # chain: chain-socks5 metadata: keepAlive: 15s mux: true @@ -318,21 +317,6 @@ hosts: - bar - baz -probeResistance: -- name: pr-code404 - type: code - value: 404 - knock: www.example.com -- name: pr-web - type: web - value: http://example.com/page.html -- name: pr-host - type: host - value: example.com:80 -- name: pr-file - type: file - value: /path/to/file - profiling: addr: ":6060" enabled: true diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index d77ee19..5c2428c 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -6,22 +6,22 @@ type Authenticator interface { } // LocalAuthenticator is an Authenticator that authenticates client by local key-value pairs. -type LocalAuthenticator struct { +type MapAuthenticator struct { kvs map[string]string } -// NewLocalAuthenticator creates an Authenticator that authenticates client by local infos. -func NewLocalAuthenticator(kvs map[string]string) *LocalAuthenticator { +// NewMapAuthenticator creates an Authenticator that authenticates client by local infos. +func NewMapAuthenticator(kvs map[string]string) *MapAuthenticator { if kvs == nil { kvs = make(map[string]string) } - return &LocalAuthenticator{ + return &MapAuthenticator{ kvs: kvs, } } // Authenticate checks the validity of the provided user-password pair. -func (au *LocalAuthenticator) Authenticate(user, password string) bool { +func (au *MapAuthenticator) Authenticate(user, password string) bool { if au == nil { return true } @@ -33,8 +33,3 @@ func (au *LocalAuthenticator) Authenticate(user, password string) bool { v, ok := au.kvs[user] return ok && (v == "" || password == v) } - -// Add adds a key-value pair to the Authenticator. -func (au *LocalAuthenticator) Add(k, v string) { - au.kvs[k] = v -} diff --git a/pkg/config/config.go b/pkg/config/config.go index 4d3826e..6f0484a 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -36,6 +36,11 @@ type TLSConfig struct { CA string } +type AuthConfig struct { + Username string + Password string +} + type SelectorConfig struct { Strategy string MaxFails int @@ -76,11 +81,18 @@ type HostsConfig struct { type ListenerConfig struct { Type string + Chain string `yaml:",omitempty"` + Auths []AuthConfig `yaml:",omitempty"` Metadata map[string]interface{} `yaml:",omitempty"` } type HandlerConfig struct { Type string + Chain string `yaml:",omitempty"` + Bypass string `yaml:",omitempty"` + Resolver string `yaml:",omitempty"` + Hosts string `yaml:",omitempty"` + Auths []AuthConfig `yaml:",omitempty"` Metadata map[string]interface{} `yaml:",omitempty"` } @@ -91,24 +103,21 @@ type ForwarderConfig struct { type DialerConfig struct { Type string + Auth *AuthConfig `yaml:",omitempty"` Metadata map[string]interface{} `yaml:",omitempty"` } type ConnectorConfig struct { Type string + Auth *AuthConfig `yaml:",omitempty"` Metadata map[string]interface{} `yaml:",omitempty"` } type ServiceConfig struct { Name string - URL string `yaml:",omitempty"` Addr string `yaml:",omitempty"` - Chain string `yaml:",omitempty"` - Bypass string `yaml:",omitempty"` - Resolver string `yaml:",omitempty"` - Hosts string `yaml:",omitempty"` - Listener *ListenerConfig `yaml:",omitempty"` Handler *HandlerConfig `yaml:",omitempty"` + Listener *ListenerConfig `yaml:",omitempty"` Forwarder *ForwarderConfig `yaml:",omitempty"` } @@ -126,7 +135,6 @@ type HopConfig struct { type NodeConfig struct { Name string - URL string `yaml:",omitempty"` Addr string `yaml:",omitempty"` Dialer *DialerConfig `yaml:",omitempty"` Connector *ConnectorConfig `yaml:",omitempty"` @@ -134,14 +142,14 @@ type NodeConfig struct { } type Config struct { - Log *LogConfig `yaml:",omitempty"` - Profiling *ProfilingConfig `yaml:",omitempty"` - TLS *TLSConfig `yaml:",omitempty"` + Services []*ServiceConfig + Chains []*ChainConfig `yaml:",omitempty"` Bypasses []*BypassConfig `yaml:",omitempty"` Resolvers []*ResolverConfig `yaml:",omitempty"` Hosts []*HostsConfig `yaml:",omitempty"` - Chains []*ChainConfig `yaml:",omitempty"` - Services []*ServiceConfig + TLS *TLSConfig `yaml:",omitempty"` + Log *LogConfig `yaml:",omitempty"` + Profiling *ProfilingConfig `yaml:",omitempty"` } func (c *Config) Load() error { diff --git a/pkg/connector/http/connector.go b/pkg/connector/http/connector.go index 264fe2a..def67be 100644 --- a/pkg/connector/http/connector.go +++ b/pkg/connector/http/connector.go @@ -23,6 +23,7 @@ func init() { } type httpConnector struct { + user *url.Userinfo md metadata logger logger.Logger } @@ -34,6 +35,7 @@ func NewConnector(opts ...connector.Option) connector.Connector { } return &httpConnector{ + user: options.User, logger: options.Logger, } } @@ -65,7 +67,7 @@ func (c *httpConnector) Connect(ctx context.Context, conn net.Conn, network, add } req.Header.Set("Proxy-Connection", "keep-alive") - if user := c.md.User; user != nil { + if user := c.user; user != nil { u := user.Username() p, _ := user.Password() req.Header.Set("Proxy-Authorization", diff --git a/pkg/connector/http/metadata.go b/pkg/connector/http/metadata.go index 507c4f4..3208989 100644 --- a/pkg/connector/http/metadata.go +++ b/pkg/connector/http/metadata.go @@ -2,8 +2,6 @@ package http import ( "net/http" - "net/url" - "strings" "time" mdata "github.com/go-gost/gost/pkg/metadata" @@ -11,7 +9,6 @@ import ( type metadata struct { connectTimeout time.Duration - User *url.Userinfo header http.Header } @@ -24,15 +21,6 @@ func (c *httpConnector) parseMetadata(md mdata.Metadata) (err error) { c.md.connectTimeout = mdata.GetDuration(md, connectTimeout) - if v := mdata.GetString(md, user); v != "" { - ss := strings.SplitN(v, ":", 2) - if len(ss) == 1 { - c.md.User = url.User(ss[0]) - } else { - c.md.User = url.UserPassword(ss[0], ss[1]) - } - } - if mm := mdata.GetStringMapString(md, header); len(mm) > 0 { hd := http.Header{} for k, v := range mm { diff --git a/pkg/connector/http2/connector.go b/pkg/connector/http2/connector.go index 162c6d4..e00ea66 100644 --- a/pkg/connector/http2/connector.go +++ b/pkg/connector/http2/connector.go @@ -24,6 +24,7 @@ func init() { } type http2Connector struct { + user *url.Userinfo md metadata logger logger.Logger } @@ -35,6 +36,7 @@ func NewConnector(opts ...connector.Option) connector.Connector { } return &http2Connector{ + user: options.User, logger: options.Logger, } } @@ -74,7 +76,7 @@ func (c *http2Connector) Connect(ctx context.Context, conn net.Conn, network, ad req.Header.Set("User-Agent", c.md.UserAgent) } - if user := c.md.User; user != nil { + if user := c.user; user != nil { u := user.Username() p, _ := user.Password() req.Header.Set("Proxy-Authorization", diff --git a/pkg/connector/http2/metadata.go b/pkg/connector/http2/metadata.go index 494c233..5e3396c 100644 --- a/pkg/connector/http2/metadata.go +++ b/pkg/connector/http2/metadata.go @@ -1,8 +1,6 @@ package http2 import ( - "net/url" - "strings" "time" mdata "github.com/go-gost/gost/pkg/metadata" @@ -15,14 +13,12 @@ const ( type metadata struct { connectTimeout time.Duration UserAgent string - User *url.Userinfo } func (c *http2Connector) parseMetadata(md mdata.Metadata) (err error) { const ( connectTimeout = "timeout" userAgent = "userAgent" - user = "user" ) c.md.connectTimeout = mdata.GetDuration(md, connectTimeout) @@ -31,14 +27,5 @@ func (c *http2Connector) parseMetadata(md mdata.Metadata) (err error) { c.md.UserAgent = defaultUserAgent } - if v := mdata.GetString(md, user); v != "" { - ss := strings.SplitN(v, ":", 2) - if len(ss) == 1 { - c.md.User = url.User(ss[0]) - } else { - c.md.User = url.UserPassword(ss[0], ss[1]) - } - } - return } diff --git a/pkg/connector/option.go b/pkg/connector/option.go index 1832b5a..f1d377f 100644 --- a/pkg/connector/option.go +++ b/pkg/connector/option.go @@ -1,17 +1,25 @@ package connector import ( + "net/url" "time" "github.com/go-gost/gost/pkg/logger" ) type Options struct { + User *url.Userinfo Logger logger.Logger } type Option func(opts *Options) +func UserOption(user *url.Userinfo) Option { + return func(opts *Options) { + opts.User = user + } +} + func LoggerOption(logger logger.Logger) Option { return func(opts *Options) { opts.Logger = logger diff --git a/pkg/connector/relay/bind.go b/pkg/connector/relay/bind.go index 924697a..33b223c 100644 --- a/pkg/connector/relay/bind.go +++ b/pkg/connector/relay/bind.go @@ -79,10 +79,10 @@ func (c *relayConnector) bind(conn net.Conn, cmd uint8, network, address string) Flags: cmd, } - if c.md.user != nil { - pwd, _ := c.md.user.Password() + if c.user != nil { + pwd, _ := c.user.Password() req.Features = append(req.Features, &relay.UserAuthFeature{ - Username: c.md.user.Username(), + Username: c.user.Username(), Password: pwd, }) } diff --git a/pkg/connector/relay/connector.go b/pkg/connector/relay/connector.go index 3cb8fee..74e547d 100644 --- a/pkg/connector/relay/connector.go +++ b/pkg/connector/relay/connector.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net" + "net/url" "time" "github.com/go-gost/gost/pkg/common/util/socks" @@ -19,6 +20,7 @@ func init() { } type relayConnector struct { + user *url.Userinfo logger logger.Logger md metadata } @@ -30,6 +32,7 @@ func NewConnector(opts ...connector.Option) connector.Connector { } return &relayConnector{ + user: options.User, logger: options.Logger, } } @@ -71,10 +74,10 @@ func (c *relayConnector) Connect(ctx context.Context, conn net.Conn, network, ad } } - if c.md.user != nil { - pwd, _ := c.md.user.Password() + if c.user != nil { + pwd, _ := c.user.Password() req.Features = append(req.Features, &relay.UserAuthFeature{ - Username: c.md.user.Username(), + Username: c.user.Username(), Password: pwd, }) } diff --git a/pkg/connector/relay/metadata.go b/pkg/connector/relay/metadata.go index fb66030..fcf4a73 100644 --- a/pkg/connector/relay/metadata.go +++ b/pkg/connector/relay/metadata.go @@ -1,8 +1,6 @@ package relay import ( - "net/url" - "strings" "time" mdata "github.com/go-gost/gost/pkg/metadata" @@ -10,25 +8,15 @@ import ( type metadata struct { connectTimeout time.Duration - user *url.Userinfo noDelay bool } func (c *relayConnector) parseMetadata(md mdata.Metadata) (err error) { const ( - user = "user" connectTimeout = "connectTimeout" noDelay = "nodelay" ) - if v := mdata.GetString(md, user); v != "" { - ss := strings.SplitN(v, ":", 2) - if len(ss) == 1 { - c.md.user = url.User(ss[0]) - } else { - c.md.user = url.UserPassword(ss[0], ss[1]) - } - } c.md.connectTimeout = mdata.GetDuration(md, connectTimeout) c.md.noDelay = mdata.GetBool(md, noDelay) diff --git a/pkg/connector/socks/v4/connector.go b/pkg/connector/socks/v4/connector.go index e18eb12..b80abd0 100644 --- a/pkg/connector/socks/v4/connector.go +++ b/pkg/connector/socks/v4/connector.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "net" + "net/url" "strconv" "time" @@ -21,6 +22,7 @@ func init() { } type socks4Connector struct { + user *url.Userinfo md metadata logger logger.Logger } @@ -32,6 +34,7 @@ func NewConnector(opts ...connector.Option) connector.Connector { } return &socks4Connector{ + user: options.User, logger: options.Logger, } } @@ -96,7 +99,11 @@ func (c *socks4Connector) Connect(ctx context.Context, conn net.Conn, network, a defer conn.SetDeadline(time.Time{}) } - req := gosocks4.NewRequest(gosocks4.CmdConnect, addr, nil) + var userid []byte + if c.user != nil && c.user.Username() != "" { + userid = []byte(c.user.Username()) + } + req := gosocks4.NewRequest(gosocks4.CmdConnect, addr, userid) if err := req.Write(conn); err != nil { c.logger.Error(err) return nil, err diff --git a/pkg/connector/socks/v4/metadata.go b/pkg/connector/socks/v4/metadata.go index 54fb242..4bc4b2a 100644 --- a/pkg/connector/socks/v4/metadata.go +++ b/pkg/connector/socks/v4/metadata.go @@ -1,7 +1,6 @@ package v4 import ( - "net/url" "time" mdata "github.com/go-gost/gost/pkg/metadata" @@ -9,20 +8,15 @@ import ( type metadata struct { connectTimeout time.Duration - User *url.Userinfo disable4a bool } func (c *socks4Connector) parseMetadata(md mdata.Metadata) (err error) { const ( connectTimeout = "timeout" - user = "user" disable4a = "disable4a" ) - if v := mdata.GetString(md, user); v != "" { - c.md.User = url.User(v) - } c.md.connectTimeout = mdata.GetDuration(md, connectTimeout) c.md.disable4a = mdata.GetBool(md, disable4a) diff --git a/pkg/connector/socks/v5/connector.go b/pkg/connector/socks/v5/connector.go index de5b5b2..db20e8e 100644 --- a/pkg/connector/socks/v5/connector.go +++ b/pkg/connector/socks/v5/connector.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "net" + "net/url" "time" "github.com/go-gost/gosocks5" @@ -23,6 +24,7 @@ func init() { type socks5Connector struct { selector gosocks5.Selector + user *url.Userinfo logger logger.Logger md metadata } @@ -34,6 +36,7 @@ func NewConnector(opts ...connector.Option) connector.Connector { } return &socks5Connector{ + user: options.User, logger: options.Logger, } } @@ -49,7 +52,7 @@ func (c *socks5Connector) Init(md md.Metadata) (err error) { gosocks5.MethodUserPass, }, logger: c.logger, - User: c.md.User, + User: c.user, TLSConfig: c.md.tlsConfig, } if !c.md.noTLS { diff --git a/pkg/connector/socks/v5/metadata.go b/pkg/connector/socks/v5/metadata.go index 259137a..0ef0903 100644 --- a/pkg/connector/socks/v5/metadata.go +++ b/pkg/connector/socks/v5/metadata.go @@ -2,8 +2,6 @@ package v5 import ( "crypto/tls" - "net/url" - "strings" "time" mdata "github.com/go-gost/gost/pkg/metadata" @@ -11,7 +9,6 @@ import ( type metadata struct { connectTimeout time.Duration - User *url.Userinfo tlsConfig *tls.Config noTLS bool } @@ -19,19 +16,9 @@ type metadata struct { func (c *socks5Connector) parseMetadata(md mdata.Metadata) (err error) { const ( connectTimeout = "timeout" - user = "user" noTLS = "notls" ) - if v := mdata.GetString(md, user); v != "" { - ss := strings.SplitN(v, ":", 2) - if len(ss) == 1 { - c.md.User = url.User(ss[0]) - } else { - c.md.User = url.UserPassword(ss[0], ss[1]) - } - } - c.md.connectTimeout = mdata.GetDuration(md, connectTimeout) c.md.noTLS = mdata.GetBool(md, noTLS) diff --git a/pkg/connector/ss/connector.go b/pkg/connector/ss/connector.go index efb3375..b956861 100644 --- a/pkg/connector/ss/connector.go +++ b/pkg/connector/ss/connector.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net" + "net/url" "time" "github.com/go-gost/gosocks5" @@ -13,6 +14,7 @@ import ( "github.com/go-gost/gost/pkg/logger" md "github.com/go-gost/gost/pkg/metadata" "github.com/go-gost/gost/pkg/registry" + "github.com/shadowsocks/go-shadowsocks2/core" ) func init() { @@ -20,6 +22,8 @@ func init() { } type ssConnector struct { + user *url.Userinfo + cipher core.Cipher md metadata logger logger.Logger } @@ -31,12 +35,23 @@ func NewConnector(opts ...connector.Option) connector.Connector { } return &ssConnector{ + user: options.User, logger: options.Logger, } } func (c *ssConnector) Init(md md.Metadata) (err error) { - return c.parseMetadata(md) + if err = c.parseMetadata(md); err != nil { + return + } + + if c.user != nil { + method := c.user.Username() + password, _ := c.user.Password() + c.cipher, err = ss.ShadowCipher(method, password, c.md.key) + } + + return } func (c *ssConnector) Connect(ctx context.Context, conn net.Conn, network, address string, opts ...connector.ConnectOption) (net.Conn, error) { @@ -80,8 +95,8 @@ func (c *ssConnector) Connect(ctx context.Context, conn net.Conn, network, addre defer conn.SetDeadline(time.Time{}) } - if c.md.cipher != nil { - conn = c.md.cipher.StreamConn(conn) + if c.cipher != nil { + conn = c.cipher.StreamConn(conn) } var sc net.Conn diff --git a/pkg/connector/ss/metadata.go b/pkg/connector/ss/metadata.go index 052f544..b986103 100644 --- a/pkg/connector/ss/metadata.go +++ b/pkg/connector/ss/metadata.go @@ -1,42 +1,25 @@ package ss import ( - "strings" "time" - "github.com/go-gost/gost/pkg/common/util/ss" mdata "github.com/go-gost/gost/pkg/metadata" - "github.com/shadowsocks/go-shadowsocks2/core" ) type metadata struct { - cipher core.Cipher + key string connectTimeout time.Duration noDelay bool } func (c *ssConnector) parseMetadata(md mdata.Metadata) (err error) { const ( - user = "user" key = "key" connectTimeout = "timeout" noDelay = "nodelay" ) - var method, password string - if v := mdata.GetString(md, user); v != "" { - ss := strings.SplitN(v, ":", 2) - if len(ss) == 1 { - method = ss[0] - } else { - method, password = ss[0], ss[1] - } - } - c.md.cipher, err = ss.ShadowCipher(method, password, mdata.GetString(md, key)) - if err != nil { - return - } - + c.md.key = mdata.GetString(md, key) c.md.connectTimeout = mdata.GetDuration(md, connectTimeout) c.md.noDelay = mdata.GetBool(md, noDelay) diff --git a/pkg/connector/ss/udp/connector.go b/pkg/connector/ss/udp/connector.go index 5d72ecc..e45bbe6 100644 --- a/pkg/connector/ss/udp/connector.go +++ b/pkg/connector/ss/udp/connector.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net" + "net/url" "time" "github.com/go-gost/gost/pkg/common/util/socks" @@ -12,6 +13,7 @@ import ( "github.com/go-gost/gost/pkg/logger" md "github.com/go-gost/gost/pkg/metadata" "github.com/go-gost/gost/pkg/registry" + "github.com/shadowsocks/go-shadowsocks2/core" ) func init() { @@ -19,6 +21,8 @@ func init() { } type ssuConnector struct { + user *url.Userinfo + cipher core.Cipher md metadata logger logger.Logger } @@ -30,12 +34,23 @@ func NewConnector(opts ...connector.Option) connector.Connector { } return &ssuConnector{ + user: options.User, logger: options.Logger, } } func (c *ssuConnector) Init(md md.Metadata) (err error) { - return c.parseMetadata(md) + if err = c.parseMetadata(md); err != nil { + return + } + + if c.user != nil { + method := c.user.Username() + password, _ := c.user.Password() + c.cipher, err = ss.ShadowCipher(method, password, c.md.key) + } + + return } func (c *ssuConnector) Connect(ctx context.Context, conn net.Conn, network, address string, opts ...connector.ConnectOption) (net.Conn, error) { @@ -67,16 +82,16 @@ func (c *ssuConnector) Connect(ctx context.Context, conn net.Conn, network, addr pc, ok := conn.(net.PacketConn) if ok { - if c.md.cipher != nil { - pc = c.md.cipher.PacketConn(pc) + if c.cipher != nil { + pc = c.cipher.PacketConn(pc) } // standard UDP relay return ss.UDPClientConn(pc, conn.RemoteAddr(), taddr, c.md.bufferSize), nil } - if c.md.cipher != nil { - conn = ss.ShadowConn(c.md.cipher.StreamConn(conn), nil) + if c.cipher != nil { + conn = ss.ShadowConn(c.cipher.StreamConn(conn), nil) } // UDP over TCP diff --git a/pkg/connector/ss/udp/metadata.go b/pkg/connector/ss/udp/metadata.go index 7291552..46ea479 100644 --- a/pkg/connector/ss/udp/metadata.go +++ b/pkg/connector/ss/udp/metadata.go @@ -2,42 +2,25 @@ package ss import ( "math" - "strings" "time" - "github.com/go-gost/gost/pkg/common/util/ss" mdata "github.com/go-gost/gost/pkg/metadata" - "github.com/shadowsocks/go-shadowsocks2/core" ) type metadata struct { - cipher core.Cipher + key string connectTimeout time.Duration bufferSize int } func (c *ssuConnector) parseMetadata(md mdata.Metadata) (err error) { const ( - user = "user" key = "key" connectTimeout = "timeout" bufferSize = "bufferSize" // udp buffer size ) - var method, password string - if v := mdata.GetString(md, user); v != "" { - ss := strings.SplitN(v, ":", 2) - if len(ss) == 1 { - method = ss[0] - } else { - method, password = ss[0], ss[1] - } - } - c.md.cipher, err = ss.ShadowCipher(method, password, mdata.GetString(md, key)) - if err != nil { - return - } - + c.md.key = mdata.GetString(md, key) c.md.connectTimeout = mdata.GetDuration(md, connectTimeout) if bs := mdata.GetInt(md, bufferSize); bs > 0 { diff --git a/pkg/dialer/forward/ssh/dialer.go b/pkg/dialer/forward/ssh/dialer.go index c4c3891..01c5d69 100644 --- a/pkg/dialer/forward/ssh/dialer.go +++ b/pkg/dialer/forward/ssh/dialer.go @@ -4,6 +4,7 @@ import ( "context" "errors" "net" + "net/url" "sync" "time" @@ -20,6 +21,7 @@ func init() { } type forwardDialer struct { + user *url.Userinfo sessions map[string]*sshSession sessionMutex sync.Mutex logger logger.Logger @@ -33,6 +35,7 @@ func NewDialer(opts ...dialer.Option) dialer.Dialer { } return &forwardDialer{ + user: options.User, sessions: make(map[string]*sshSession), logger: options.Logger, } @@ -161,9 +164,9 @@ func (d *forwardDialer) initSession(ctx context.Context, addr string, conn net.C // Timeout: timeout, HostKeyCallback: ssh.InsecureIgnoreHostKey(), } - if d.md.user != nil { - config.User = d.md.user.Username() - if password, _ := d.md.user.Password(); password != "" { + if d.user != nil { + config.User = d.user.Username() + if password, _ := d.user.Password(); password != "" { config.Auth = []ssh.AuthMethod{ ssh.Password(password), } diff --git a/pkg/dialer/forward/ssh/metadata.go b/pkg/dialer/forward/ssh/metadata.go index c589f8a..d1ef591 100644 --- a/pkg/dialer/forward/ssh/metadata.go +++ b/pkg/dialer/forward/ssh/metadata.go @@ -2,8 +2,6 @@ package ssh import ( "io/ioutil" - "net/url" - "strings" "time" mdata "github.com/go-gost/gost/pkg/metadata" @@ -12,27 +10,16 @@ import ( type metadata struct { handshakeTimeout time.Duration - user *url.Userinfo signer ssh.Signer } func (d *forwardDialer) parseMetadata(md mdata.Metadata) (err error) { const ( handshakeTimeout = "handshakeTimeout" - user = "user" privateKeyFile = "privateKeyFile" passphrase = "passphrase" ) - if v := mdata.GetString(md, user); v != "" { - ss := strings.SplitN(v, ":", 2) - if len(ss) == 1 { - d.md.user = url.User(ss[0]) - } else { - d.md.user = url.UserPassword(ss[0], ss[1]) - } - } - if key := mdata.GetString(md, privateKeyFile); key != "" { data, err := ioutil.ReadFile(key) if err != nil { diff --git a/pkg/dialer/option.go b/pkg/dialer/option.go index ceee723..5e97cd0 100644 --- a/pkg/dialer/option.go +++ b/pkg/dialer/option.go @@ -3,16 +3,24 @@ package dialer import ( "context" "net" + "net/url" "github.com/go-gost/gost/pkg/logger" ) type Options struct { + User *url.Userinfo Logger logger.Logger } type Option func(opts *Options) +func UserOption(user *url.Userinfo) Option { + return func(opts *Options) { + opts.User = user + } +} + func LoggerOption(logger logger.Logger) Option { return func(opts *Options) { opts.Logger = logger diff --git a/pkg/handler/forward/ssh/handler.go b/pkg/handler/forward/ssh/handler.go index c06d9bc..8846763 100644 --- a/pkg/handler/forward/ssh/handler.go +++ b/pkg/handler/forward/ssh/handler.go @@ -8,6 +8,7 @@ import ( "strconv" "time" + "github.com/go-gost/gost/pkg/auth" "github.com/go-gost/gost/pkg/bypass" "github.com/go-gost/gost/pkg/chain" "github.com/go-gost/gost/pkg/handler" @@ -31,11 +32,12 @@ func init() { } type forwardHandler struct { - bypass bypass.Bypass - config *ssh.ServerConfig - router *chain.Router - logger logger.Logger - md metadata + bypass bypass.Bypass + config *ssh.ServerConfig + router *chain.Router + authenticator auth.Authenticator + logger logger.Logger + md metadata } func NewHandler(opts ...handler.Option) handler.Handler { @@ -57,13 +59,13 @@ func (h *forwardHandler) Init(md md.Metadata) (err error) { } config := &ssh.ServerConfig{ - PasswordCallback: ssh_util.PasswordCallback(h.md.authenticator), + PasswordCallback: ssh_util.PasswordCallback(h.authenticator), PublicKeyCallback: ssh_util.PublicKeyCallback(h.md.authorizedKeys), } config.AddHostKey(h.md.signer) - if h.md.authenticator == nil && len(h.md.authorizedKeys) == 0 { + if h.authenticator == nil && len(h.md.authorizedKeys) == 0 { config.NoClientAuth = true } diff --git a/pkg/handler/forward/ssh/metadata.go b/pkg/handler/forward/ssh/metadata.go index bf98f57..9f8a4bd 100644 --- a/pkg/handler/forward/ssh/metadata.go +++ b/pkg/handler/forward/ssh/metadata.go @@ -2,9 +2,7 @@ package ssh import ( "io/ioutil" - "strings" - "github.com/go-gost/gost/pkg/auth" tls_util "github.com/go-gost/gost/pkg/common/util/tls" ssh_util "github.com/go-gost/gost/pkg/internal/util/ssh" mdata "github.com/go-gost/gost/pkg/metadata" @@ -12,32 +10,17 @@ import ( ) type metadata struct { - authenticator auth.Authenticator signer ssh.Signer authorizedKeys map[string]bool } func (h *forwardHandler) parseMetadata(md mdata.Metadata) (err error) { const ( - users = "users" authorizedKeys = "authorizedKeys" privateKeyFile = "privateKeyFile" passphrase = "passphrase" ) - if auths := mdata.GetStrings(md, users); len(auths) > 0 { - authenticator := auth.NewLocalAuthenticator(nil) - for _, auth := range auths { - ss := strings.SplitN(auth, ":", 2) - if len(ss) == 1 { - authenticator.Add(ss[0], "") - } else { - authenticator.Add(ss[0], ss[1]) - } - } - h.md.authenticator = authenticator - } - if key := mdata.GetString(md, privateKeyFile); key != "" { data, err := ioutil.ReadFile(key) if err != nil { diff --git a/pkg/handler/http/handler.go b/pkg/handler/http/handler.go index 4a580ac..1f732a2 100644 --- a/pkg/handler/http/handler.go +++ b/pkg/handler/http/handler.go @@ -16,6 +16,7 @@ import ( "time" "github.com/asaskevich/govalidator" + "github.com/go-gost/gost/pkg/auth" "github.com/go-gost/gost/pkg/bypass" "github.com/go-gost/gost/pkg/chain" "github.com/go-gost/gost/pkg/handler" @@ -29,10 +30,11 @@ func init() { } type httpHandler struct { - bypass bypass.Bypass - router *chain.Router - logger logger.Logger - md metadata + bypass bypass.Bypass + router *chain.Router + authenticator auth.Authenticator + logger logger.Logger + md metadata } func NewHandler(opts ...handler.Option) handler.Handler { @@ -260,7 +262,7 @@ func (h *httpHandler) basicProxyAuth(proxyAuth string) (username, password strin func (h *httpHandler) authenticate(conn net.Conn, req *http.Request, resp *http.Response) (ok bool) { u, p, _ := h.basicProxyAuth(req.Header.Get("Proxy-Authorization")) - if h.md.authenticator == nil || h.md.authenticator.Authenticate(u, p) { + if h.authenticator == nil || h.authenticator.Authenticate(u, p) { return true } diff --git a/pkg/handler/http/metadata.go b/pkg/handler/http/metadata.go index f6aecd2..58b5893 100644 --- a/pkg/handler/http/metadata.go +++ b/pkg/handler/http/metadata.go @@ -4,41 +4,25 @@ import ( "net/http" "strings" - "github.com/go-gost/gost/pkg/auth" mdata "github.com/go-gost/gost/pkg/metadata" ) type metadata struct { - authenticator auth.Authenticator - probeResist *probeResist - sni bool - enableUDP bool - header http.Header + probeResist *probeResist + sni bool + enableUDP bool + header http.Header } func (h *httpHandler) parseMetadata(md mdata.Metadata) error { const ( header = "header" - users = "users" probeResistKey = "probeResist" knock = "knock" sni = "sni" enableUDP = "udp" ) - if auths := mdata.GetStrings(md, users); len(auths) > 0 { - authenticator := auth.NewLocalAuthenticator(nil) - for _, auth := range auths { - ss := strings.SplitN(auth, ":", 2) - if len(ss) == 1 { - authenticator.Add(ss[0], "") - } else { - authenticator.Add(ss[0], ss[1]) - } - } - h.md.authenticator = authenticator - } - if m := mdata.GetStringMapString(md, header); len(m) > 0 { hd := http.Header{} for k, v := range m { diff --git a/pkg/handler/http2/handler.go b/pkg/handler/http2/handler.go index 315aefa..1e06e03 100644 --- a/pkg/handler/http2/handler.go +++ b/pkg/handler/http2/handler.go @@ -15,6 +15,7 @@ import ( "time" "github.com/asaskevich/govalidator" + "github.com/go-gost/gost/pkg/auth" "github.com/go-gost/gost/pkg/bypass" "github.com/go-gost/gost/pkg/chain" "github.com/go-gost/gost/pkg/handler" @@ -29,10 +30,11 @@ func init() { } type http2Handler struct { - bypass bypass.Bypass - router *chain.Router - logger logger.Logger - md metadata + bypass bypass.Bypass + router *chain.Router + authenticator auth.Authenticator + logger logger.Logger + md metadata } func NewHandler(opts ...handler.Option) handler.Handler { @@ -392,7 +394,7 @@ func (h *http2Handler) basicProxyAuth(proxyAuth string) (username, password stri func (h *http2Handler) authenticate(conn net.Conn, req *http.Request, resp *http.Response) (ok bool) { u, p, _ := h.basicProxyAuth(req.Header.Get("Proxy-Authorization")) - if h.md.authenticator == nil || h.md.authenticator.Authenticate(u, p) { + if h.authenticator == nil || h.authenticator.Authenticate(u, p) { return true } diff --git a/pkg/handler/http2/metadata.go b/pkg/handler/http2/metadata.go index 5093a78..63cccd7 100644 --- a/pkg/handler/http2/metadata.go +++ b/pkg/handler/http2/metadata.go @@ -3,22 +3,19 @@ package http2 import ( "strings" - "github.com/go-gost/gost/pkg/auth" mdata "github.com/go-gost/gost/pkg/metadata" ) type metadata struct { - authenticator auth.Authenticator - proxyAgent string - probeResist *probeResist - sni bool - enableUDP bool + proxyAgent string + probeResist *probeResist + sni bool + enableUDP bool } func (h *http2Handler) parseMetadata(md mdata.Metadata) error { const ( proxyAgent = "proxyAgent" - users = "users" probeResistKey = "probeResist" knock = "knock" sni = "sni" @@ -27,19 +24,6 @@ func (h *http2Handler) parseMetadata(md mdata.Metadata) error { h.md.proxyAgent = mdata.GetString(md, proxyAgent) - if auths := mdata.GetStrings(md, users); len(auths) > 0 { - authenticator := auth.NewLocalAuthenticator(nil) - for _, auth := range auths { - ss := strings.SplitN(auth, ":", 2) - if len(ss) == 1 { - authenticator.Add(ss[0], "") - } else { - authenticator.Add(ss[0], ss[1]) - } - } - h.md.authenticator = authenticator - } - if v := mdata.GetString(md, probeResistKey); v != "" { if ss := strings.SplitN(v, ":", 2); len(ss) == 2 { h.md.probeResist = &probeResist{ diff --git a/pkg/handler/option.go b/pkg/handler/option.go index 2ef067e..f3eb8ec 100644 --- a/pkg/handler/option.go +++ b/pkg/handler/option.go @@ -1,6 +1,7 @@ package handler import ( + "github.com/go-gost/gost/pkg/auth" "github.com/go-gost/gost/pkg/bypass" "github.com/go-gost/gost/pkg/chain" "github.com/go-gost/gost/pkg/logger" @@ -8,10 +9,11 @@ import ( ) type Options struct { - Router *chain.Router - Bypass bypass.Bypass - Resolver resolver.Resolver - Logger logger.Logger + Router *chain.Router + Bypass bypass.Bypass + Resolver resolver.Resolver + Authenticator auth.Authenticator + Logger logger.Logger } type Option func(opts *Options) @@ -28,6 +30,12 @@ func BypassOption(bypass bypass.Bypass) Option { } } +func AuthenticatorOption(auth auth.Authenticator) Option { + return func(opts *Options) { + opts.Authenticator = auth + } +} + func LoggerOption(logger logger.Logger) Option { return func(opts *Options) { opts.Logger = logger diff --git a/pkg/handler/relay/handler.go b/pkg/handler/relay/handler.go index 475bc98..f4cad73 100644 --- a/pkg/handler/relay/handler.go +++ b/pkg/handler/relay/handler.go @@ -6,6 +6,7 @@ import ( "strconv" "time" + "github.com/go-gost/gost/pkg/auth" "github.com/go-gost/gost/pkg/bypass" "github.com/go-gost/gost/pkg/chain" "github.com/go-gost/gost/pkg/handler" @@ -20,11 +21,12 @@ func init() { } type relayHandler struct { - group *chain.NodeGroup - bypass bypass.Bypass - router *chain.Router - logger logger.Logger - md metadata + group *chain.NodeGroup + bypass bypass.Bypass + router *chain.Router + authenticator auth.Authenticator + logger logger.Logger + md metadata } func NewHandler(opts ...handler.Option) handler.Handler { @@ -107,7 +109,7 @@ func (h *relayHandler) Handle(ctx context.Context, conn net.Conn) { Version: relay.Version1, Status: relay.StatusOK, } - if h.md.authenticator != nil && !h.md.authenticator.Authenticate(user, pass) { + if h.authenticator != nil && !h.authenticator.Authenticate(user, pass) { resp.Status = relay.StatusUnauthorized resp.WriteTo(conn) h.logger.Error("unauthorized") diff --git a/pkg/handler/relay/metadata.go b/pkg/handler/relay/metadata.go index 3564a0f..d516726 100644 --- a/pkg/handler/relay/metadata.go +++ b/pkg/handler/relay/metadata.go @@ -2,15 +2,12 @@ package relay import ( "math" - "strings" "time" - "github.com/go-gost/gost/pkg/auth" mdata "github.com/go-gost/gost/pkg/metadata" ) type metadata struct { - authenticator auth.Authenticator readTimeout time.Duration enableBind bool udpBufferSize int @@ -19,26 +16,12 @@ type metadata struct { func (h *relayHandler) parseMetadata(md mdata.Metadata) (err error) { const ( - users = "users" readTimeout = "readTimeout" enableBind = "bind" udpBufferSize = "udpBufferSize" noDelay = "nodelay" ) - if auths := mdata.GetStrings(md, users); len(auths) > 0 { - authenticator := auth.NewLocalAuthenticator(nil) - for _, auth := range auths { - ss := strings.SplitN(auth, ":", 2) - if len(ss) == 1 { - authenticator.Add(ss[0], "") - } else { - authenticator.Add(ss[0], ss[1]) - } - } - h.md.authenticator = authenticator - } - h.md.readTimeout = mdata.GetDuration(md, readTimeout) h.md.enableBind = mdata.GetBool(md, enableBind) h.md.noDelay = mdata.GetBool(md, noDelay) diff --git a/pkg/handler/socks/v4/handler.go b/pkg/handler/socks/v4/handler.go index db46b2c..8332835 100644 --- a/pkg/handler/socks/v4/handler.go +++ b/pkg/handler/socks/v4/handler.go @@ -6,6 +6,7 @@ import ( "time" "github.com/go-gost/gosocks4" + "github.com/go-gost/gost/pkg/auth" "github.com/go-gost/gost/pkg/bypass" "github.com/go-gost/gost/pkg/chain" "github.com/go-gost/gost/pkg/handler" @@ -20,10 +21,11 @@ func init() { } type socks4Handler struct { - bypass bypass.Bypass - router *chain.Router - logger logger.Logger - md metadata + bypass bypass.Bypass + router *chain.Router + authenticator auth.Authenticator + logger logger.Logger + md metadata } func NewHandler(opts ...handler.Option) handler.Handler { @@ -77,8 +79,8 @@ func (h *socks4Handler) Handle(ctx context.Context, conn net.Conn) { conn.SetReadDeadline(time.Time{}) - if h.md.authenticator != nil && - !h.md.authenticator.Authenticate(string(req.Userid), "") { + if h.authenticator != nil && + !h.authenticator.Authenticate(string(req.Userid), "") { resp := gosocks4.NewReply(gosocks4.RejectedUserid, nil) resp.Write(conn) h.logger.Debug(resp) diff --git a/pkg/handler/socks/v4/metadata.go b/pkg/handler/socks/v4/metadata.go index d6ab966..29eec7f 100644 --- a/pkg/handler/socks/v4/metadata.go +++ b/pkg/handler/socks/v4/metadata.go @@ -3,31 +3,18 @@ package v4 import ( "time" - "github.com/go-gost/gost/pkg/auth" mdata "github.com/go-gost/gost/pkg/metadata" ) type metadata struct { - authenticator auth.Authenticator - readTimeout time.Duration + readTimeout time.Duration } func (h *socks4Handler) parseMetadata(md mdata.Metadata) (err error) { const ( - users = "users" readTimeout = "readTimeout" ) - if auths := mdata.GetStrings(md, users); len(auths) > 0 { - authenticator := auth.NewLocalAuthenticator(nil) - for _, auth := range auths { - if auth != "" { - authenticator.Add(auth, "") - } - } - h.md.authenticator = authenticator - } - h.md.readTimeout = mdata.GetDuration(md, readTimeout) return } diff --git a/pkg/handler/socks/v5/handler.go b/pkg/handler/socks/v5/handler.go index 84980c3..6cca33c 100644 --- a/pkg/handler/socks/v5/handler.go +++ b/pkg/handler/socks/v5/handler.go @@ -6,6 +6,7 @@ import ( "time" "github.com/go-gost/gosocks5" + "github.com/go-gost/gost/pkg/auth" "github.com/go-gost/gost/pkg/bypass" "github.com/go-gost/gost/pkg/chain" "github.com/go-gost/gost/pkg/common/util/socks" @@ -21,11 +22,12 @@ func init() { } type socks5Handler struct { - selector gosocks5.Selector - bypass bypass.Bypass - router *chain.Router - logger logger.Logger - md metadata + selector gosocks5.Selector + bypass bypass.Bypass + router *chain.Router + authenticator auth.Authenticator + logger logger.Logger + md metadata } func NewHandler(opts ...handler.Option) handler.Handler { @@ -47,7 +49,7 @@ func (h *socks5Handler) Init(md md.Metadata) (err error) { } h.selector = &serverSelector{ - Authenticator: h.md.authenticator, + Authenticator: h.authenticator, TLSConfig: h.md.tlsConfig, logger: h.logger, noTLS: h.md.noTLS, diff --git a/pkg/handler/socks/v5/metadata.go b/pkg/handler/socks/v5/metadata.go index a49c939..22b098a 100644 --- a/pkg/handler/socks/v5/metadata.go +++ b/pkg/handler/socks/v5/metadata.go @@ -3,17 +3,14 @@ package v5 import ( "crypto/tls" "math" - "strings" "time" - "github.com/go-gost/gost/pkg/auth" tls_util "github.com/go-gost/gost/pkg/common/util/tls" mdata "github.com/go-gost/gost/pkg/metadata" ) type metadata struct { tlsConfig *tls.Config - authenticator auth.Authenticator timeout time.Duration readTimeout time.Duration noTLS bool @@ -28,7 +25,6 @@ func (h *socks5Handler) parseMetadata(md mdata.Metadata) (err error) { certFile = "certFile" keyFile = "keyFile" caFile = "caFile" - users = "users" readTimeout = "readTimeout" timeout = "timeout" noTLS = "notls" @@ -47,19 +43,6 @@ func (h *socks5Handler) parseMetadata(md mdata.Metadata) (err error) { return } - if auths := mdata.GetStrings(md, users); len(auths) > 0 { - authenticator := auth.NewLocalAuthenticator(nil) - for _, auth := range auths { - ss := strings.SplitN(auth, ":", 2) - if len(ss) == 1 { - authenticator.Add(ss[0], "") - } else { - authenticator.Add(ss[0], ss[1]) - } - } - h.md.authenticator = authenticator - } - h.md.readTimeout = mdata.GetDuration(md, readTimeout) h.md.timeout = mdata.GetDuration(md, timeout) h.md.noTLS = mdata.GetBool(md, noTLS) diff --git a/pkg/listener/option.go b/pkg/listener/option.go index ceb6b88..1f25865 100644 --- a/pkg/listener/option.go +++ b/pkg/listener/option.go @@ -1,12 +1,14 @@ package listener import ( + "github.com/go-gost/gost/pkg/auth" "github.com/go-gost/gost/pkg/logger" ) type Options struct { - Addr string - Logger logger.Logger + Addr string + Authenticator auth.Authenticator + Logger logger.Logger } type Option func(opts *Options) @@ -17,6 +19,12 @@ func AddrOption(addr string) Option { } } +func AuthenticatorOption(auth auth.Authenticator) Option { + return func(opts *Options) { + opts.Authenticator = auth + } +} + func LoggerOption(logger logger.Logger) Option { return func(opts *Options) { opts.Logger = logger diff --git a/pkg/listener/ssh/listener.go b/pkg/listener/ssh/listener.go index e57c67b..5ba1ff9 100644 --- a/pkg/listener/ssh/listener.go +++ b/pkg/listener/ssh/listener.go @@ -4,6 +4,7 @@ import ( "fmt" "net" + "github.com/go-gost/gost/pkg/auth" ssh_util "github.com/go-gost/gost/pkg/internal/util/ssh" "github.com/go-gost/gost/pkg/listener" "github.com/go-gost/gost/pkg/logger" @@ -19,11 +20,12 @@ func init() { type sshListener struct { addr string net.Listener - config *ssh.ServerConfig - cqueue chan net.Conn - errChan chan error - logger logger.Logger - md metadata + config *ssh.ServerConfig + authenticator auth.Authenticator + cqueue chan net.Conn + errChan chan error + logger logger.Logger + md metadata } func NewListener(opts ...listener.Option) listener.Listener { @@ -50,13 +52,13 @@ func (l *sshListener) Init(md md.Metadata) (err error) { l.Listener = ln config := &ssh.ServerConfig{ - PasswordCallback: ssh_util.PasswordCallback(l.md.authenticator), + PasswordCallback: ssh_util.PasswordCallback(l.authenticator), PublicKeyCallback: ssh_util.PublicKeyCallback(l.md.authorizedKeys), } config.AddHostKey(l.md.signer) - if l.md.authenticator == nil && len(l.md.authorizedKeys) == 0 { + if l.authenticator == nil && len(l.md.authorizedKeys) == 0 { config.NoClientAuth = true } diff --git a/pkg/listener/ssh/metadata.go b/pkg/listener/ssh/metadata.go index 8ecf72f..8917143 100644 --- a/pkg/listener/ssh/metadata.go +++ b/pkg/listener/ssh/metadata.go @@ -2,9 +2,7 @@ package ssh import ( "io/ioutil" - "strings" - "github.com/go-gost/gost/pkg/auth" tls_util "github.com/go-gost/gost/pkg/common/util/tls" ssh_util "github.com/go-gost/gost/pkg/internal/util/ssh" mdata "github.com/go-gost/gost/pkg/metadata" @@ -16,7 +14,6 @@ const ( ) type metadata struct { - authenticator auth.Authenticator signer ssh.Signer authorizedKeys map[string]bool backlog int @@ -24,26 +21,12 @@ type metadata struct { func (l *sshListener) parseMetadata(md mdata.Metadata) (err error) { const ( - users = "users" authorizedKeys = "authorizedKeys" privateKeyFile = "privateKeyFile" passphrase = "passphrase" backlog = "backlog" ) - if auths := mdata.GetStrings(md, users); len(auths) > 0 { - authenticator := auth.NewLocalAuthenticator(nil) - for _, auth := range auths { - ss := strings.SplitN(auth, ":", 2) - if len(ss) == 1 { - authenticator.Add(ss[0], "") - } else { - authenticator.Add(ss[0], ss[1]) - } - } - l.md.authenticator = authenticator - } - if key := mdata.GetString(md, privateKeyFile); key != "" { data, err := ioutil.ReadFile(key) if err != nil {