From 3b48c4acfb5cf123cfc374b8df00d4059e38bd6c Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Wed, 5 Jan 2022 00:02:55 +0800 Subject: [PATCH] add tls config option --- cmd/gost/cmd.go | 119 ++++++++++++++++++++++++---- cmd/gost/config.go | 71 ++++++++++++++--- cmd/gost/tls.go | 8 ++ gost.yml | 6 +- pkg/config/config.go | 16 ++-- pkg/connector/option.go | 12 ++- pkg/connector/socks/v5/connector.go | 2 +- pkg/connector/socks/v5/metadata.go | 2 - pkg/dialer/http2/dialer.go | 10 ++- pkg/dialer/http2/h2/dialer.go | 17 ++-- pkg/dialer/http2/h2/metadata.go | 31 ++------ pkg/dialer/http2/metadata.go | 25 ------ pkg/dialer/option.go | 12 ++- pkg/dialer/quic/dialer.go | 8 +- pkg/dialer/quic/metadata.go | 22 ----- pkg/dialer/tls/dialer.go | 14 ++-- pkg/dialer/tls/metadata.go | 22 ----- pkg/dialer/tls/mux/dialer.go | 8 +- pkg/dialer/tls/mux/metadata.go | 21 ----- pkg/dialer/ws/dialer.go | 15 ++-- pkg/dialer/ws/metadata.go | 32 ++------ pkg/dialer/ws/mux/dialer.go | 17 ++-- pkg/dialer/ws/mux/metadata.go | 32 ++------ pkg/handler/option.go | 22 +++-- pkg/handler/socks/v5/handler.go | 2 +- pkg/handler/socks/v5/metadata.go | 15 ---- pkg/listener/dns/listener.go | 26 +++--- pkg/listener/dns/metadata.go | 15 ---- pkg/listener/http2/h2/listener.go | 28 +++---- pkg/listener/http2/h2/metadata.go | 24 +----- pkg/listener/http2/listener.go | 18 ++--- pkg/listener/http2/metadata.go | 15 ---- pkg/listener/option.go | 14 +++- pkg/listener/quic/listener.go | 14 ++-- pkg/listener/quic/metadata.go | 16 ---- pkg/listener/tls/listener.go | 18 ++--- pkg/listener/tls/metadata.go | 19 ----- pkg/listener/tls/mux/listener.go | 14 ++-- pkg/listener/tls/mux/metadata.go | 17 ---- pkg/listener/ws/listener.go | 24 +++--- pkg/listener/ws/metadata.go | 20 +---- pkg/listener/ws/mux/listener.go | 26 +++--- pkg/listener/ws/mux/metadata.go | 22 +---- 43 files changed, 395 insertions(+), 496 deletions(-) diff --git a/cmd/gost/cmd.go b/cmd/gost/cmd.go index 2b6b8a2..8df18d9 100644 --- a/cmd/gost/cmd.go +++ b/cmd/gost/cmd.go @@ -1,6 +1,7 @@ package main import ( + "encoding/base64" "errors" "fmt" "net/url" @@ -115,21 +116,46 @@ func buildServiceConfig(url *url.URL) (*config.ServiceConfig, error) { } } + var auths []*config.AuthConfig + if url.User != nil { + auth := &config.AuthConfig{ + Username: url.User.Username(), + } + auth.Password, _ = url.User.Password() + auths = append(auths, auth) + } + md := make(map[string]interface{}) for k, v := range url.Query() { if len(v) > 0 { md[k] = v[0] } } - - var auths []config.AuthConfig - if url.User != nil { - auth := config.AuthConfig{ - Username: url.User.Username(), + if sauth := md["auth"]; sauth != nil { + if sa, _ := sauth.(string); sa != "" { + au, err := parseAuthFromCmd(sa) + if err != nil { + return nil, err + } + auths = append(auths, au) } - auth.Password, _ = url.User.Password() - auths = append(auths, auth) } + delete(md, "auth") + + var tlsConfig *config.TLSConfig + if certs := md["cert"]; certs != nil { + cert, _ := certs.(string) + key, _ := md["key"].(string) + ca, _ := md["ca"].(string) + tlsConfig = &config.TLSConfig{ + Cert: cert, + Key: key, + CA: ca, + } + } + delete(md, "cert") + delete(md, "key") + delete(md, "ca") svc.Handler = &config.HandlerConfig{ Type: handler, @@ -138,6 +164,7 @@ func buildServiceConfig(url *url.URL) (*config.ServiceConfig, error) { } svc.Listener = &config.ListenerConfig{ Type: listener, + TLS: tlsConfig, Metadata: md, } @@ -170,14 +197,6 @@ func buildNodeConfig(url *url.URL) (*config.NodeConfig, error) { } } - md := make(map[string]interface{}) - for k, v := range url.Query() { - if len(v) > 0 { - md[k] = v[0] - } - } - md["serverName"] = url.Host - var auth *config.AuthConfig if url.User != nil { auth = &config.AuthConfig{ @@ -186,6 +205,46 @@ func buildNodeConfig(url *url.URL) (*config.NodeConfig, error) { auth.Password, _ = url.User.Password() } + md := make(map[string]interface{}) + for k, v := range url.Query() { + if len(v) > 0 { + md[k] = v[0] + } + } + md["serverName"] = url.Host + + if sauth := md["auth"]; sauth != nil && auth == nil { + if sa, _ := sauth.(string); sa != "" { + au, err := parseAuthFromCmd(sa) + if err != nil { + return nil, err + } + auth = au + } + } + delete(md, "auth") + + var tlsConfig *config.TLSConfig + if certs := md["cert"]; certs != nil { + cert, _ := certs.(string) + key, _ := md["key"].(string) + ca, _ := md["ca"].(string) + secure, _ := md["secure"].(bool) + serverName, _ := md["serverName"].(string) + tlsConfig = &config.TLSConfig{ + Cert: cert, + Key: key, + CA: ca, + Secure: secure, + ServerName: serverName, + } + } + delete(md, "cert") + delete(md, "key") + delete(md, "ca") + delete(md, "secure") + delete(md, "serverName") + node.Connector = &config.ConnectorConfig{ Type: connector, Auth: auth, @@ -193,6 +252,7 @@ func buildNodeConfig(url *url.URL) (*config.NodeConfig, error) { } node.Dialer = &config.DialerConfig{ Type: dialer, + TLS: tlsConfig, Metadata: md, } @@ -209,5 +269,32 @@ func normCmd(s string) (*url.URL, error) { s = "auto://" + s } - return url.Parse(s) + url, err := url.Parse(s) + if err != nil { + return nil, err + } + if url.Scheme == "https" { + url.Scheme = "http+tls" + } + + return url, nil +} + +func parseAuthFromCmd(sa string) (*config.AuthConfig, error) { + v, err := base64.StdEncoding.DecodeString(sa) + if err != nil { + return nil, err + } + cs := string(v) + n := strings.IndexByte(cs, ':') + if n < 0 { + return &config.AuthConfig{ + Username: cs, + }, nil + } + + return &config.AuthConfig{ + Username: cs[:n], + Password: cs[n+1:], + }, nil } diff --git a/cmd/gost/config.go b/cmd/gost/config.go index 0155c55..bb93020 100644 --- a/cmd/gost/config.go +++ b/cmd/gost/config.go @@ -1,6 +1,7 @@ package main import ( + "crypto/tls" "io" "net" "net/url" @@ -68,9 +69,23 @@ func buildService(cfg *config.Config) (services []*service.Service) { listenerLogger := serviceLogger.WithFields(map[string]interface{}{ "kind": "listener", }) + + var tlsConfig *tls.Config + var err error + + tlsCfg := svc.Listener.TLS + if tlsCfg == nil { + tlsCfg = &config.TLSConfig{} + } + tlsConfig, err = loadServerTLSConfig(tlsCfg) + if err != nil { + log.Fatal(err) + } + ln := registry.GetListener(svc.Listener.Type)( listener.AddrOption(svc.Addr), listener.AuthsOption(authsFromConfig(svc.Listener.Auths...)...), + listener.TLSConfigOption(tlsConfig), listener.LoggerOption(listenerLogger), ) @@ -89,6 +104,16 @@ func buildService(cfg *config.Config) (services []*service.Service) { "kind": "handler", }) + tlsConfig = nil + tlsCfg = svc.Handler.TLS + if tlsCfg == nil { + tlsCfg = &config.TLSConfig{} + } + tlsConfig, err = loadServerTLSConfig(tlsCfg) + if err != nil { + log.Fatal(err) + } + h := registry.GetHandler(svc.Handler.Type)( handler.RetriesOption(svc.Handler.Retries), handler.ChainOption(chains[svc.Handler.Chain]), @@ -96,6 +121,7 @@ func buildService(cfg *config.Config) (services []*service.Service) { handler.HostsOption(hosts[svc.Handler.Hosts]), handler.BypassOption(bypasses[svc.Handler.Bypass]), handler.AuthsOption(authsFromConfig(svc.Handler.Auths...)...), + handler.TLSConfigOption(tlsConfig), handler.LoggerOption(handlerLogger), ) @@ -148,16 +174,29 @@ func chainFromConfig(cfg *config.ChainConfig) *chain.Chain { "kind": "connector", }) - var connectorUser *url.Userinfo + var user *url.Userinfo if auth := v.Connector.Auth; auth != nil && auth.Username != "" { if auth.Password == "" { - connectorUser = url.User(auth.Username) + user = url.User(auth.Username) } else { - connectorUser = url.UserPassword(auth.Username, auth.Password) + user = url.UserPassword(auth.Username, auth.Password) } } + + var tlsConfig *tls.Config + var err error + tlsCfg := v.Connector.TLS + if tlsCfg == nil { + tlsCfg = &config.TLSConfig{} + } + tlsConfig, err = loadClientTLSConfig(tlsCfg) + if err != nil { + log.Fatal(err) + } + cr := registry.GetConnector(v.Connector.Type)( - connector.UserOption(connectorUser), + connector.UserOption(user), + connector.TLSConfigOption(tlsConfig), connector.LoggerOption(connectorLogger), ) @@ -172,16 +211,28 @@ func chainFromConfig(cfg *config.ChainConfig) *chain.Chain { "kind": "dialer", }) - var dialerUser *url.Userinfo + user = nil if auth := v.Dialer.Auth; auth != nil && auth.Username != "" { if auth.Password == "" { - dialerUser = url.User(auth.Username) + user = url.User(auth.Username) } else { - dialerUser = url.UserPassword(auth.Username, auth.Password) + user = url.UserPassword(auth.Username, auth.Password) } } + + tlsConfig = nil + tlsCfg = v.Dialer.TLS + if tlsCfg == nil { + tlsCfg = &config.TLSConfig{} + } + tlsConfig, err = loadClientTLSConfig(tlsCfg) + if err != nil { + log.Fatal(err) + } + d := registry.GetDialer(v.Dialer.Type)( - dialer.UserOption(dialerUser), + dialer.UserOption(user), + dialer.TLSConfigOption(tlsConfig), dialer.LoggerOption(dialerLogger), ) @@ -328,11 +379,11 @@ func hostsFromConfig(cfg *config.HostsConfig) hostspkg.HostMapper { return hosts } -func authsFromConfig(cfgs ...config.AuthConfig) []*url.Userinfo { +func authsFromConfig(cfgs ...*config.AuthConfig) []*url.Userinfo { var auths []*url.Userinfo for _, cfg := range cfgs { - if cfg.Username == "" { + if cfg == nil || cfg.Username == "" { continue } auths = append(auths, url.UserPassword(cfg.Username, cfg.Password)) diff --git a/cmd/gost/tls.go b/cmd/gost/tls.go index 53dcc57..df18c0a 100644 --- a/cmd/gost/tls.go +++ b/cmd/gost/tls.go @@ -14,6 +14,14 @@ import ( "github.com/go-gost/gost/pkg/config" ) +func loadServerTLSConfig(cfg *config.TLSConfig) (*tls.Config, error) { + return tls_util.LoadServerConfig(cfg.Cert, cfg.Key, cfg.CA) +} + +func loadClientTLSConfig(cfg *config.TLSConfig) (*tls.Config, error) { + return tls_util.LoadClientConfig(cfg.Cert, cfg.Key, cfg.CA, cfg.Secure, cfg.ServerName) +} + func buildDefaultTLSConfig(cfg *config.TLSConfig) { if cfg == nil { cfg = &config.TLSConfig{ diff --git a/gost.yml b/gost.yml index e86a2c2..26678fc 100644 --- a/gost.yml +++ b/gost.yml @@ -283,9 +283,9 @@ bypasses: # http://www.iana.org/assignments/multicast-addresses/multicast-addresses.xhtml - 224.0.0.0/4 # RFC5771: Multicast/Reserved -# tls: -# cert: "cert.pem" -# key: "key.pem" +tls: + cert: "cert.pem" + key: "key.pem" # ca: "root.ca" resolvers: diff --git a/pkg/config/config.go b/pkg/config/config.go index 4c61dd9..8d6d286 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -31,9 +31,11 @@ type ProfilingConfig struct { } type TLSConfig struct { - Cert string - Key string - CA string + Cert string + Key string + CA string `yaml:",omitempty"` + Secure bool `yaml:",omitempty"` + ServerName string `yaml:",omitempty"` } type AuthConfig struct { @@ -82,7 +84,8 @@ type HostsConfig struct { type ListenerConfig struct { Type string Chain string `yaml:",omitempty"` - Auths []AuthConfig `yaml:",omitempty"` + Auths []*AuthConfig `yaml:",omitempty"` + TLS *TLSConfig `yaml:",omitempty"` Metadata map[string]interface{} `yaml:",omitempty"` } @@ -93,7 +96,8 @@ type HandlerConfig struct { Bypass string `yaml:",omitempty"` Resolver string `yaml:",omitempty"` Hosts string `yaml:",omitempty"` - Auths []AuthConfig `yaml:",omitempty"` + Auths []*AuthConfig `yaml:",omitempty"` + TLS *TLSConfig `yaml:",omitempty"` Metadata map[string]interface{} `yaml:",omitempty"` } @@ -105,12 +109,14 @@ type ForwarderConfig struct { type DialerConfig struct { Type string Auth *AuthConfig `yaml:",omitempty"` + TLS *TLSConfig `yaml:",omitempty"` Metadata map[string]interface{} `yaml:",omitempty"` } type ConnectorConfig struct { Type string Auth *AuthConfig `yaml:",omitempty"` + TLS *TLSConfig `yaml:",omitempty"` Metadata map[string]interface{} `yaml:",omitempty"` } diff --git a/pkg/connector/option.go b/pkg/connector/option.go index f1d377f..d9014fa 100644 --- a/pkg/connector/option.go +++ b/pkg/connector/option.go @@ -1,6 +1,7 @@ package connector import ( + "crypto/tls" "net/url" "time" @@ -8,8 +9,9 @@ import ( ) type Options struct { - User *url.Userinfo - Logger logger.Logger + User *url.Userinfo + TLSConfig *tls.Config + Logger logger.Logger } type Option func(opts *Options) @@ -20,6 +22,12 @@ func UserOption(user *url.Userinfo) Option { } } +func TLSConfigOption(tlsConfig *tls.Config) Option { + return func(opts *Options) { + opts.TLSConfig = tlsConfig + } +} + func LoggerOption(logger logger.Logger) Option { return func(opts *Options) { opts.Logger = logger diff --git a/pkg/connector/socks/v5/connector.go b/pkg/connector/socks/v5/connector.go index 1a4abf9..5b91b71 100644 --- a/pkg/connector/socks/v5/connector.go +++ b/pkg/connector/socks/v5/connector.go @@ -53,7 +53,7 @@ func (c *socks5Connector) Init(md md.Metadata) (err error) { }, logger: c.logger, User: c.options.User, - TLSConfig: c.md.tlsConfig, + TLSConfig: c.options.TLSConfig, } if !c.md.noTLS { selector.methods = append(selector.methods, socks.MethodTLS) diff --git a/pkg/connector/socks/v5/metadata.go b/pkg/connector/socks/v5/metadata.go index 0ef0903..2f59a8e 100644 --- a/pkg/connector/socks/v5/metadata.go +++ b/pkg/connector/socks/v5/metadata.go @@ -1,7 +1,6 @@ package v5 import ( - "crypto/tls" "time" mdata "github.com/go-gost/gost/pkg/metadata" @@ -9,7 +8,6 @@ import ( type metadata struct { connectTimeout time.Duration - tlsConfig *tls.Config noTLS bool } diff --git a/pkg/dialer/http2/dialer.go b/pkg/dialer/http2/dialer.go index 9ec9d60..2790948 100644 --- a/pkg/dialer/http2/dialer.go +++ b/pkg/dialer/http2/dialer.go @@ -19,21 +19,23 @@ func init() { } type http2Dialer struct { - md metadata clients map[string]*http.Client clientMutex sync.Mutex logger logger.Logger + md metadata + options dialer.Options } func NewDialer(opts ...dialer.Option) dialer.Dialer { - options := &dialer.Options{} + options := dialer.Options{} for _, opt := range opts { - opt(options) + opt(&options) } return &http2Dialer{ clients: make(map[string]*http.Client), logger: options.Logger, + options: options, } } @@ -69,7 +71,7 @@ func (d *http2Dialer) Dial(ctx context.Context, address string, opts ...dialer.D if !ok { client = &http.Client{ Transport: &http.Transport{ - TLSClientConfig: d.md.tlsConfig, + TLSClientConfig: d.options.TLSConfig, DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { return d.dial(ctx, network, addr, options) }, diff --git a/pkg/dialer/http2/h2/dialer.go b/pkg/dialer/http2/h2/dialer.go index 2f4df9c..a2a7ac6 100644 --- a/pkg/dialer/http2/h2/dialer.go +++ b/pkg/dialer/http2/h2/dialer.go @@ -27,33 +27,36 @@ func init() { type h2Dialer struct { clients map[string]*http.Client clientMutex sync.Mutex + h2c bool logger logger.Logger md metadata - h2c bool + options dialer.Options } func NewDialer(opts ...dialer.Option) dialer.Dialer { - options := &dialer.Options{} + options := dialer.Options{} for _, opt := range opts { - opt(options) + opt(&options) } return &h2Dialer{ + h2c: true, clients: make(map[string]*http.Client), logger: options.Logger, - h2c: true, + options: options, } } func NewTLSDialer(opts ...dialer.Option) dialer.Dialer { - options := &dialer.Options{} + options := dialer.Options{} for _, opt := range opts { - opt(options) + opt(&options) } return &h2Dialer{ clients: make(map[string]*http.Client), logger: options.Logger, + options: options, } } @@ -95,7 +98,7 @@ func (d *h2Dialer) Dial(ctx context.Context, address string, opts ...dialer.Dial } } else { client.Transport = &http.Transport{ - TLSClientConfig: d.md.tlsConfig, + TLSClientConfig: d.options.TLSConfig, DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { return d.dial(ctx, network, addr, options) }, diff --git a/pkg/dialer/http2/h2/metadata.go b/pkg/dialer/http2/h2/metadata.go index 9731c4b..4dc1430 100644 --- a/pkg/dialer/http2/h2/metadata.go +++ b/pkg/dialer/http2/h2/metadata.go @@ -1,42 +1,21 @@ package h2 import ( - "crypto/tls" - "net" - - tls_util "github.com/go-gost/gost/pkg/common/util/tls" mdata "github.com/go-gost/gost/pkg/metadata" ) type metadata struct { - path string - host string - tlsConfig *tls.Config + host string + path string } func (d *h2Dialer) parseMetadata(md mdata.Metadata) (err error) { const ( - certFile = "certFile" - keyFile = "keyFile" - caFile = "caFile" - secure = "secure" - serverName = "serverName" - path = "path" - ) - - d.md.host = mdata.GetString(md, serverName) - sn, _, _ := net.SplitHostPort(d.md.host) - if sn == "" { - sn = "localhost" - } - d.md.tlsConfig, err = tls_util.LoadClientConfig( - mdata.GetString(md, certFile), - mdata.GetString(md, keyFile), - mdata.GetString(md, caFile), - mdata.GetBool(md, secure), - sn, + host = "host" + path = "path" ) + d.md.host = mdata.GetString(md, host) d.md.path = mdata.GetString(md, path) return diff --git a/pkg/dialer/http2/metadata.go b/pkg/dialer/http2/metadata.go index a26dfb0..7dd0c6b 100644 --- a/pkg/dialer/http2/metadata.go +++ b/pkg/dialer/http2/metadata.go @@ -1,37 +1,12 @@ package http2 import ( - "crypto/tls" - "net" - - tls_util "github.com/go-gost/gost/pkg/common/util/tls" mdata "github.com/go-gost/gost/pkg/metadata" ) type metadata struct { - tlsConfig *tls.Config } func (d *http2Dialer) parseMetadata(md mdata.Metadata) (err error) { - const ( - certFile = "certFile" - keyFile = "keyFile" - caFile = "caFile" - secure = "secure" - serverName = "serverName" - ) - - sn, _, _ := net.SplitHostPort(mdata.GetString(md, serverName)) - if sn == "" { - sn = "localhost" - } - d.md.tlsConfig, err = tls_util.LoadClientConfig( - mdata.GetString(md, certFile), - mdata.GetString(md, keyFile), - mdata.GetString(md, caFile), - mdata.GetBool(md, secure), - sn, - ) - return } diff --git a/pkg/dialer/option.go b/pkg/dialer/option.go index 5e97cd0..ab92635 100644 --- a/pkg/dialer/option.go +++ b/pkg/dialer/option.go @@ -2,6 +2,7 @@ package dialer import ( "context" + "crypto/tls" "net" "net/url" @@ -9,8 +10,9 @@ import ( ) type Options struct { - User *url.Userinfo - Logger logger.Logger + User *url.Userinfo + TLSConfig *tls.Config + Logger logger.Logger } type Option func(opts *Options) @@ -21,6 +23,12 @@ func UserOption(user *url.Userinfo) Option { } } +func TLSConfigOption(tlsConfig *tls.Config) Option { + return func(opts *Options) { + opts.TLSConfig = tlsConfig + } +} + func LoggerOption(logger logger.Logger) Option { return func(opts *Options) { opts.Logger = logger diff --git a/pkg/dialer/quic/dialer.go b/pkg/dialer/quic/dialer.go index a1ef700..163406f 100644 --- a/pkg/dialer/quic/dialer.go +++ b/pkg/dialer/quic/dialer.go @@ -24,17 +24,19 @@ type quicDialer struct { sessionMutex sync.Mutex logger logger.Logger md metadata + options dialer.Options } func NewDialer(opts ...dialer.Option) dialer.Dialer { - options := &dialer.Options{} + options := dialer.Options{} for _, opt := range opts { - opt(options) + opt(&options) } return &quicDialer{ sessions: make(map[string]*quicSession), logger: options.Logger, + options: options, } } @@ -141,7 +143,7 @@ func (d *quicDialer) initSession(ctx context.Context, addr string, conn net.Conn }, } - tlsCfg := d.md.tlsConfig + tlsCfg := d.options.TLSConfig tlsCfg.NextProtos = []string{"http/3", "quic/v1"} session, err := quic.DialContext(ctx, pc, udpAddr, addr, tlsCfg, quicConfig) diff --git a/pkg/dialer/quic/metadata.go b/pkg/dialer/quic/metadata.go index 141754b..85e7f14 100644 --- a/pkg/dialer/quic/metadata.go +++ b/pkg/dialer/quic/metadata.go @@ -1,11 +1,8 @@ package quic import ( - "crypto/tls" - "net" "time" - tls_util "github.com/go-gost/gost/pkg/common/util/tls" mdata "github.com/go-gost/gost/pkg/metadata" ) @@ -15,7 +12,6 @@ type metadata struct { handshakeTimeout time.Duration cipherKey []byte - tlsConfig *tls.Config } func (d *quicDialer) parseMetadata(md mdata.Metadata) (err error) { @@ -24,12 +20,6 @@ func (d *quicDialer) parseMetadata(md mdata.Metadata) (err error) { handshakeTimeout = "handshakeTimeout" maxIdleTimeout = "maxIdleTimeout" - certFile = "certFile" - keyFile = "keyFile" - caFile = "caFile" - secure = "secure" - serverName = "serverName" - cipherKey = "cipherKey" ) @@ -39,18 +29,6 @@ func (d *quicDialer) parseMetadata(md mdata.Metadata) (err error) { d.md.cipherKey = []byte(key) } - sn, _, _ := net.SplitHostPort(mdata.GetString(md, serverName)) - if sn == "" { - sn = "localhost" - } - d.md.tlsConfig, err = tls_util.LoadClientConfig( - mdata.GetString(md, certFile), - mdata.GetString(md, keyFile), - mdata.GetString(md, caFile), - mdata.GetBool(md, secure), - sn, - ) - d.md.keepAlive = mdata.GetBool(md, keepAlive) d.md.handshakeTimeout = mdata.GetDuration(md, handshakeTimeout) d.md.maxIdleTimeout = mdata.GetDuration(md, maxIdleTimeout) diff --git a/pkg/dialer/tls/dialer.go b/pkg/dialer/tls/dialer.go index 9ed9483..c9689f2 100644 --- a/pkg/dialer/tls/dialer.go +++ b/pkg/dialer/tls/dialer.go @@ -17,18 +17,20 @@ func init() { } type tlsDialer struct { - md metadata - logger logger.Logger + md metadata + logger logger.Logger + options dialer.Options } func NewDialer(opts ...dialer.Option) dialer.Dialer { - options := &dialer.Options{} + options := dialer.Options{} for _, opt := range opts { - opt(options) + opt(&options) } return &tlsDialer{ - logger: options.Logger, + logger: options.Logger, + options: options, } } @@ -57,7 +59,7 @@ func (d *tlsDialer) Handshake(ctx context.Context, conn net.Conn, options ...dia defer conn.SetDeadline(time.Time{}) } - tlsConn := tls.Client(conn, d.md.tlsConfig) + tlsConn := tls.Client(conn, d.options.TLSConfig) if err := tlsConn.HandshakeContext(ctx); err != nil { conn.Close() return nil, err diff --git a/pkg/dialer/tls/metadata.go b/pkg/dialer/tls/metadata.go index 11ab968..2584971 100644 --- a/pkg/dialer/tls/metadata.go +++ b/pkg/dialer/tls/metadata.go @@ -1,42 +1,20 @@ package tls import ( - "crypto/tls" - "net" "time" - tls_util "github.com/go-gost/gost/pkg/common/util/tls" mdata "github.com/go-gost/gost/pkg/metadata" ) type metadata struct { - tlsConfig *tls.Config handshakeTimeout time.Duration } func (d *tlsDialer) parseMetadata(md mdata.Metadata) (err error) { const ( - certFile = "certFile" - keyFile = "keyFile" - caFile = "caFile" - secure = "secure" - serverName = "serverName" - handshakeTimeout = "handshakeTimeout" ) - sn, _, _ := net.SplitHostPort(mdata.GetString(md, serverName)) - if sn == "" { - sn = "localhost" - } - d.md.tlsConfig, err = tls_util.LoadClientConfig( - mdata.GetString(md, certFile), - mdata.GetString(md, keyFile), - mdata.GetString(md, caFile), - mdata.GetBool(md, secure), - sn, - ) - d.md.handshakeTimeout = mdata.GetDuration(md, handshakeTimeout) return diff --git a/pkg/dialer/tls/mux/dialer.go b/pkg/dialer/tls/mux/dialer.go index 499f577..94d67cc 100644 --- a/pkg/dialer/tls/mux/dialer.go +++ b/pkg/dialer/tls/mux/dialer.go @@ -24,17 +24,19 @@ type mtlsDialer struct { sessionMutex sync.Mutex logger logger.Logger md metadata + options dialer.Options } func NewDialer(opts ...dialer.Option) dialer.Dialer { - options := &dialer.Options{} + options := dialer.Options{} for _, opt := range opts { - opt(options) + opt(&options) } return &mtlsDialer{ sessions: make(map[string]*muxSession), logger: options.Logger, + options: options, } } @@ -149,7 +151,7 @@ func (d *mtlsDialer) dial(ctx context.Context, network, addr string, opts *diale } func (d *mtlsDialer) initSession(ctx context.Context, conn net.Conn) (*muxSession, error) { - tlsConn := tls.Client(conn, d.md.tlsConfig) + tlsConn := tls.Client(conn, d.options.TLSConfig) if err := tlsConn.HandshakeContext(ctx); err != nil { return nil, err } diff --git a/pkg/dialer/tls/mux/metadata.go b/pkg/dialer/tls/mux/metadata.go index 75ac50b..7808170 100644 --- a/pkg/dialer/tls/mux/metadata.go +++ b/pkg/dialer/tls/mux/metadata.go @@ -1,16 +1,12 @@ package mux import ( - "crypto/tls" - "net" "time" - tls_util "github.com/go-gost/gost/pkg/common/util/tls" mdata "github.com/go-gost/gost/pkg/metadata" ) type metadata struct { - tlsConfig *tls.Config handshakeTimeout time.Duration muxKeepAliveDisabled bool @@ -23,12 +19,6 @@ type metadata struct { func (d *mtlsDialer) parseMetadata(md mdata.Metadata) (err error) { const ( - certFile = "certFile" - keyFile = "keyFile" - caFile = "caFile" - secure = "secure" - serverName = "serverName" - handshakeTimeout = "handshakeTimeout" muxKeepAliveDisabled = "muxKeepAliveDisabled" @@ -39,17 +29,6 @@ func (d *mtlsDialer) parseMetadata(md mdata.Metadata) (err error) { muxMaxStreamBuffer = "muxMaxStreamBuffer" ) - sn, _, _ := net.SplitHostPort(mdata.GetString(md, serverName)) - if sn == "" { - sn = "localhost" - } - d.md.tlsConfig, err = tls_util.LoadClientConfig( - mdata.GetString(md, certFile), - mdata.GetString(md, keyFile), - mdata.GetString(md, caFile), - mdata.GetBool(md, secure), - sn, - ) d.md.handshakeTimeout = mdata.GetDuration(md, handshakeTimeout) d.md.muxKeepAliveDisabled = mdata.GetBool(md, muxKeepAliveDisabled) diff --git a/pkg/dialer/ws/dialer.go b/pkg/dialer/ws/dialer.go index c424d47..d046302 100644 --- a/pkg/dialer/ws/dialer.go +++ b/pkg/dialer/ws/dialer.go @@ -23,28 +23,31 @@ type wsDialer struct { tlsEnabled bool logger logger.Logger md metadata + options dialer.Options } func NewDialer(opts ...dialer.Option) dialer.Dialer { - options := &dialer.Options{} + options := dialer.Options{} for _, opt := range opts { - opt(options) + opt(&options) } return &wsDialer{ - logger: options.Logger, + logger: options.Logger, + options: options, } } func NewTLSDialer(opts ...dialer.Option) dialer.Dialer { - options := &dialer.Options{} + options := dialer.Options{} for _, opt := range opts { - opt(options) + opt(&options) } return &wsDialer{ tlsEnabled: true, logger: options.Logger, + options: options, } } @@ -96,7 +99,7 @@ func (d *wsDialer) Handshake(ctx context.Context, conn net.Conn, options ...dial url := url.URL{Scheme: "ws", Host: host, Path: d.md.path} if d.tlsEnabled { url.Scheme = "wss" - dialer.TLSClientConfig = d.md.tlsConfig + dialer.TLSClientConfig = d.options.TLSConfig } c, resp, err := dialer.Dial(url.String(), d.md.header) diff --git a/pkg/dialer/ws/metadata.go b/pkg/dialer/ws/metadata.go index e7e0002..e793f38 100644 --- a/pkg/dialer/ws/metadata.go +++ b/pkg/dialer/ws/metadata.go @@ -1,12 +1,9 @@ package ws import ( - "crypto/tls" - "net" "net/http" "time" - tls_util "github.com/go-gost/gost/pkg/common/util/tls" mdata "github.com/go-gost/gost/pkg/metadata" ) @@ -15,9 +12,8 @@ const ( ) type metadata struct { - path string - host string - tlsConfig *tls.Config + host string + path string handshakeTimeout time.Duration readHeaderTimeout time.Duration @@ -30,14 +26,8 @@ type metadata struct { func (d *wsDialer) parseMetadata(md mdata.Metadata) (err error) { const ( - path = "path" host = "host" - - certFile = "certFile" - keyFile = "keyFile" - caFile = "caFile" - secure = "secure" - serverName = "serverName" + path = "path" handshakeTimeout = "handshakeTimeout" readHeaderTimeout = "readHeaderTimeout" @@ -48,25 +38,13 @@ func (d *wsDialer) parseMetadata(md mdata.Metadata) (err error) { header = "header" ) + d.md.host = mdata.GetString(md, host) + d.md.path = mdata.GetString(md, path) if d.md.path == "" { d.md.path = defaultPath } - d.md.host = mdata.GetString(md, host) - - sn, _, _ := net.SplitHostPort(mdata.GetString(md, serverName)) - if sn == "" { - sn = "localhost" - } - d.md.tlsConfig, err = tls_util.LoadClientConfig( - mdata.GetString(md, certFile), - mdata.GetString(md, keyFile), - mdata.GetString(md, caFile), - mdata.GetBool(md, secure), - sn, - ) - d.md.handshakeTimeout = mdata.GetDuration(md, handshakeTimeout) d.md.readHeaderTimeout = mdata.GetDuration(md, readHeaderTimeout) d.md.readBufferSize = mdata.GetInt(md, readBufferSize) diff --git a/pkg/dialer/ws/mux/dialer.go b/pkg/dialer/ws/mux/dialer.go index a230ed9..0cd8d70 100644 --- a/pkg/dialer/ws/mux/dialer.go +++ b/pkg/dialer/ws/mux/dialer.go @@ -25,33 +25,36 @@ func init() { type mwsDialer struct { sessions map[string]*muxSession sessionMutex sync.Mutex + tlsEnabled bool logger logger.Logger md metadata - tlsEnabled bool + options dialer.Options } func NewDialer(opts ...dialer.Option) dialer.Dialer { - options := &dialer.Options{} + options := dialer.Options{} for _, opt := range opts { - opt(options) + opt(&options) } return &mwsDialer{ sessions: make(map[string]*muxSession), logger: options.Logger, + options: options, } } func NewTLSDialer(opts ...dialer.Option) dialer.Dialer { - options := &dialer.Options{} + options := dialer.Options{} for _, opt := range opts { - opt(options) + opt(&options) } return &mwsDialer{ + tlsEnabled: true, sessions: make(map[string]*muxSession), logger: options.Logger, - tlsEnabled: true, + options: options, } } func (d *mwsDialer) Init(md md.Metadata) (err error) { @@ -182,7 +185,7 @@ func (d *mwsDialer) initSession(ctx context.Context, host string, conn net.Conn) url := url.URL{Scheme: "ws", Host: host, Path: d.md.path} if d.tlsEnabled { url.Scheme = "wss" - dialer.TLSClientConfig = d.md.tlsConfig + dialer.TLSClientConfig = d.options.TLSConfig } c, resp, err := dialer.Dial(url.String(), d.md.header) diff --git a/pkg/dialer/ws/mux/metadata.go b/pkg/dialer/ws/mux/metadata.go index 56922ce..4a168cd 100644 --- a/pkg/dialer/ws/mux/metadata.go +++ b/pkg/dialer/ws/mux/metadata.go @@ -1,12 +1,9 @@ package mux import ( - "crypto/tls" - "net" "net/http" "time" - tls_util "github.com/go-gost/gost/pkg/common/util/tls" mdata "github.com/go-gost/gost/pkg/metadata" ) @@ -15,9 +12,8 @@ const ( ) type metadata struct { - path string - host string - tlsConfig *tls.Config + host string + path string handshakeTimeout time.Duration readHeaderTimeout time.Duration @@ -37,14 +33,8 @@ type metadata struct { func (d *mwsDialer) parseMetadata(md mdata.Metadata) (err error) { const ( - path = "path" host = "host" - - certFile = "certFile" - keyFile = "keyFile" - caFile = "caFile" - secure = "secure" - serverName = "serverName" + path = "path" handshakeTimeout = "handshakeTimeout" readHeaderTimeout = "readHeaderTimeout" @@ -62,25 +52,13 @@ func (d *mwsDialer) parseMetadata(md mdata.Metadata) (err error) { muxMaxStreamBuffer = "muxMaxStreamBuffer" ) + d.md.host = mdata.GetString(md, host) + d.md.path = mdata.GetString(md, path) if d.md.path == "" { d.md.path = defaultPath } - d.md.host = mdata.GetString(md, host) - - sn, _, _ := net.SplitHostPort(mdata.GetString(md, serverName)) - if sn == "" { - sn = "localhost" - } - d.md.tlsConfig, err = tls_util.LoadClientConfig( - mdata.GetString(md, certFile), - mdata.GetString(md, keyFile), - mdata.GetString(md, caFile), - mdata.GetBool(md, secure), - sn, - ) - d.md.muxKeepAliveDisabled = mdata.GetBool(md, muxKeepAliveDisabled) d.md.muxKeepAliveInterval = mdata.GetDuration(md, muxKeepAliveInterval) d.md.muxKeepAliveTimeout = mdata.GetDuration(md, muxKeepAliveTimeout) diff --git a/pkg/handler/option.go b/pkg/handler/option.go index 501f6a4..5462d4a 100644 --- a/pkg/handler/option.go +++ b/pkg/handler/option.go @@ -1,6 +1,7 @@ package handler import ( + "crypto/tls" "net/url" "github.com/go-gost/gost/pkg/bypass" @@ -11,13 +12,14 @@ import ( ) type Options struct { - Retries int - Chain *chain.Chain - Resolver resolver.Resolver - Hosts hosts.HostMapper - Bypass bypass.Bypass - Auths []*url.Userinfo - Logger logger.Logger + Retries int + Chain *chain.Chain + Resolver resolver.Resolver + Hosts hosts.HostMapper + Bypass bypass.Bypass + Auths []*url.Userinfo + TLSConfig *tls.Config + Logger logger.Logger } type Option func(opts *Options) @@ -58,6 +60,12 @@ func AuthsOption(auths ...*url.Userinfo) Option { } } +func TLSConfigOption(tlsConfig *tls.Config) Option { + return func(opts *Options) { + opts.TLSConfig = tlsConfig + } +} + func LoggerOption(logger logger.Logger) Option { return func(opts *Options) { opts.Logger = logger diff --git a/pkg/handler/socks/v5/handler.go b/pkg/handler/socks/v5/handler.go index 71f5a89..fb3e78e 100644 --- a/pkg/handler/socks/v5/handler.go +++ b/pkg/handler/socks/v5/handler.go @@ -55,7 +55,7 @@ func (h *socks5Handler) Init(md md.Metadata) (err error) { h.selector = &serverSelector{ Authenticator: auth_util.AuthFromUsers(h.options.Auths...), - TLSConfig: h.md.tlsConfig, + TLSConfig: h.options.TLSConfig, logger: h.logger, noTLS: h.md.noTLS, } diff --git a/pkg/handler/socks/v5/metadata.go b/pkg/handler/socks/v5/metadata.go index 22b098a..427d44f 100644 --- a/pkg/handler/socks/v5/metadata.go +++ b/pkg/handler/socks/v5/metadata.go @@ -1,16 +1,13 @@ package v5 import ( - "crypto/tls" "math" "time" - tls_util "github.com/go-gost/gost/pkg/common/util/tls" mdata "github.com/go-gost/gost/pkg/metadata" ) type metadata struct { - tlsConfig *tls.Config timeout time.Duration readTimeout time.Duration noTLS bool @@ -22,9 +19,6 @@ type metadata struct { func (h *socks5Handler) parseMetadata(md mdata.Metadata) (err error) { const ( - certFile = "certFile" - keyFile = "keyFile" - caFile = "caFile" readTimeout = "readTimeout" timeout = "timeout" noTLS = "notls" @@ -34,15 +28,6 @@ func (h *socks5Handler) parseMetadata(md mdata.Metadata) (err error) { compatibilityMode = "comp" ) - h.md.tlsConfig, err = tls_util.LoadServerConfig( - mdata.GetString(md, certFile), - mdata.GetString(md, keyFile), - mdata.GetString(md, caFile), - ) - if err != nil { - return - } - h.md.readTimeout = mdata.GetDuration(md, readTimeout) h.md.timeout = mdata.GetDuration(md, timeout) h.md.noTLS = mdata.GetBool(md, noTLS) diff --git a/pkg/listener/dns/listener.go b/pkg/listener/dns/listener.go index 912bcf8..19c3cc2 100644 --- a/pkg/listener/dns/listener.go +++ b/pkg/listener/dns/listener.go @@ -21,23 +21,23 @@ func init() { } type dnsListener struct { - saddr string addr net.Addr server Server cqueue chan net.Conn errChan chan error logger logger.Logger md metadata + options listener.Options } func NewListener(opts ...listener.Option) listener.Listener { - options := &listener.Options{} + options := listener.Options{} for _, opt := range opts { - opt(options) + opt(&options) } return &dnsListener{ - saddr: options.Addr, - logger: options.Logger, + logger: options.Logger, + options: options, } } @@ -46,7 +46,7 @@ func (l *dnsListener) Init(md md.Metadata) (err error) { return } - l.addr, err = net.ResolveTCPAddr("tcp", l.saddr) + l.addr, err = net.ResolveTCPAddr("tcp", l.options.Addr) if err != nil { return err } @@ -55,7 +55,7 @@ func (l *dnsListener) Init(md md.Metadata) (err error) { case "tcp": l.server = &dns.Server{ Net: "tcp", - Addr: l.saddr, + Addr: l.options.Addr, Handler: l, ReadTimeout: l.md.readTimeout, WriteTimeout: l.md.writeTimeout, @@ -63,16 +63,16 @@ func (l *dnsListener) Init(md md.Metadata) (err error) { case "tls": l.server = &dns.Server{ Net: "tcp-tls", - Addr: l.saddr, + Addr: l.options.Addr, Handler: l, - TLSConfig: l.md.tlsConfig, + TLSConfig: l.options.TLSConfig, ReadTimeout: l.md.readTimeout, WriteTimeout: l.md.writeTimeout, } case "https": l.server = &dohServer{ - addr: l.saddr, - tlsConfig: l.md.tlsConfig, + addr: l.options.Addr, + tlsConfig: l.options.TLSConfig, server: &http.Server{ Handler: l, ReadTimeout: l.md.readTimeout, @@ -80,10 +80,10 @@ func (l *dnsListener) Init(md md.Metadata) (err error) { }, } default: - l.addr, err = net.ResolveUDPAddr("udp", l.saddr) + l.addr, err = net.ResolveUDPAddr("udp", l.options.Addr) l.server = &dns.Server{ Net: "udp", - Addr: l.saddr, + Addr: l.options.Addr, Handler: l, UDPSize: l.md.readBufferSize, ReadTimeout: l.md.readTimeout, diff --git a/pkg/listener/dns/metadata.go b/pkg/listener/dns/metadata.go index 4269924..187af6d 100644 --- a/pkg/listener/dns/metadata.go +++ b/pkg/listener/dns/metadata.go @@ -1,10 +1,8 @@ package dns import ( - "crypto/tls" "time" - tls_util "github.com/go-gost/gost/pkg/common/util/tls" mdata "github.com/go-gost/gost/pkg/metadata" ) @@ -17,7 +15,6 @@ type metadata struct { readBufferSize int readTimeout time.Duration writeTimeout time.Duration - tlsConfig *tls.Config backlog int } @@ -26,24 +23,12 @@ func (l *dnsListener) parseMetadata(md mdata.Metadata) (err error) { mode = "mode" readBufferSize = "readBufferSize" - certFile = "certFile" - keyFile = "keyFile" - caFile = "caFile" - backlog = "backlog" ) l.md.mode = mdata.GetString(md, mode) l.md.readBufferSize = mdata.GetInt(md, readBufferSize) - l.md.tlsConfig, err = tls_util.LoadServerConfig( - mdata.GetString(md, certFile), - mdata.GetString(md, keyFile), - mdata.GetString(md, caFile), - ) - if err != nil { - return - } l.md.backlog = mdata.GetInt(md, backlog) if l.md.backlog <= 0 { l.md.backlog = defaultBacklog diff --git a/pkg/listener/http2/h2/listener.go b/pkg/listener/http2/h2/listener.go index 92bd4eb..8e2a408 100644 --- a/pkg/listener/http2/h2/listener.go +++ b/pkg/listener/http2/h2/listener.go @@ -22,35 +22,35 @@ func init() { type h2Listener struct { server *http.Server - saddr string addr net.Addr cqueue chan net.Conn errChan chan error logger logger.Logger md metadata h2c bool + options listener.Options } func NewListener(opts ...listener.Option) listener.Listener { - options := &listener.Options{} + options := listener.Options{} for _, opt := range opts { - opt(options) + opt(&options) } return &h2Listener{ - saddr: options.Addr, - logger: options.Logger, - h2c: true, + h2c: true, + logger: options.Logger, + options: options, } } func NewTLSListener(opts ...listener.Option) listener.Listener { - options := &listener.Options{} + options := listener.Options{} for _, opt := range opts { - opt(options) + opt(&options) } return &h2Listener{ - saddr: options.Addr, - logger: options.Logger, + logger: options.Logger, + options: options, } } @@ -60,10 +60,10 @@ func (l *h2Listener) Init(md md.Metadata) (err error) { } l.server = &http.Server{ - Addr: l.saddr, + Addr: l.options.Addr, } - ln, err := net.Listen("tcp", l.saddr) + ln, err := net.Listen("tcp", l.options.Addr) if err != nil { return err } @@ -74,12 +74,12 @@ func (l *h2Listener) Init(md md.Metadata) (err error) { http.HandlerFunc(l.handleFunc), &http2.Server{}) } else { l.server.Handler = http.HandlerFunc(l.handleFunc) - l.server.TLSConfig = l.md.tlsConfig + l.server.TLSConfig = l.options.TLSConfig if err := http2.ConfigureServer(l.server, nil); err != nil { ln.Close() return err } - ln = tls.NewListener(ln, l.md.tlsConfig) + ln = tls.NewListener(ln, l.options.TLSConfig) } l.cqueue = make(chan net.Conn, l.md.backlog) diff --git a/pkg/listener/http2/h2/metadata.go b/pkg/listener/http2/h2/metadata.go index 3eca516..ca11e16 100644 --- a/pkg/listener/http2/h2/metadata.go +++ b/pkg/listener/http2/h2/metadata.go @@ -1,9 +1,6 @@ package h2 import ( - "crypto/tls" - - tls_util "github.com/go-gost/gost/pkg/common/util/tls" mdata "github.com/go-gost/gost/pkg/metadata" ) @@ -12,29 +9,16 @@ const ( ) type metadata struct { - path string - tlsConfig *tls.Config - backlog int + path string + backlog int } func (l *h2Listener) parseMetadata(md mdata.Metadata) (err error) { const ( - path = "path" - certFile = "certFile" - keyFile = "keyFile" - caFile = "caFile" - backlog = "backlog" + path = "path" + backlog = "backlog" ) - l.md.tlsConfig, err = tls_util.LoadServerConfig( - mdata.GetString(md, certFile), - mdata.GetString(md, keyFile), - mdata.GetString(md, caFile), - ) - if err != nil { - return - } - l.md.backlog = mdata.GetInt(md, backlog) if l.md.backlog <= 0 { l.md.backlog = defaultBacklog diff --git a/pkg/listener/http2/listener.go b/pkg/listener/http2/listener.go index f5f81ac..9a6caec 100644 --- a/pkg/listener/http2/listener.go +++ b/pkg/listener/http2/listener.go @@ -20,22 +20,22 @@ func init() { type http2Listener struct { server *http.Server - saddr string addr net.Addr cqueue chan net.Conn errChan chan error logger logger.Logger md metadata + options listener.Options } func NewListener(opts ...listener.Option) listener.Listener { - options := &listener.Options{} + options := listener.Options{} for _, opt := range opts { - opt(options) + opt(&options) } return &http2Listener{ - saddr: options.Addr, - logger: options.Logger, + logger: options.Logger, + options: options, } } @@ -45,15 +45,15 @@ func (l *http2Listener) Init(md md.Metadata) (err error) { } l.server = &http.Server{ - Addr: l.saddr, + Addr: l.options.Addr, Handler: http.HandlerFunc(l.handleFunc), - TLSConfig: l.md.tlsConfig, + TLSConfig: l.options.TLSConfig, } if err := http2.ConfigureServer(l.server, nil); err != nil { return err } - ln, err := net.Listen("tcp", l.saddr) + ln, err := net.Listen("tcp", l.options.Addr) if err != nil { return err } @@ -63,7 +63,7 @@ func (l *http2Listener) Init(md md.Metadata) (err error) { &util.TCPKeepAliveListener{ TCPListener: ln.(*net.TCPListener), }, - l.md.tlsConfig, + l.options.TLSConfig, ) l.cqueue = make(chan net.Conn, l.md.backlog) diff --git a/pkg/listener/http2/metadata.go b/pkg/listener/http2/metadata.go index ecfede0..ae6e817 100644 --- a/pkg/listener/http2/metadata.go +++ b/pkg/listener/http2/metadata.go @@ -1,11 +1,9 @@ package http2 import ( - "crypto/tls" "net/http" "time" - tls_util "github.com/go-gost/gost/pkg/common/util/tls" mdata "github.com/go-gost/gost/pkg/metadata" ) @@ -15,7 +13,6 @@ const ( type metadata struct { path string - tlsConfig *tls.Config handshakeTimeout time.Duration readHeaderTimeout time.Duration readBufferSize int @@ -28,9 +25,6 @@ type metadata struct { func (l *http2Listener) parseMetadata(md mdata.Metadata) (err error) { const ( path = "path" - certFile = "certFile" - keyFile = "keyFile" - caFile = "caFile" handshakeTimeout = "handshakeTimeout" readHeaderTimeout = "readHeaderTimeout" readBufferSize = "readBufferSize" @@ -38,15 +32,6 @@ func (l *http2Listener) parseMetadata(md mdata.Metadata) (err error) { backlog = "backlog" ) - l.md.tlsConfig, err = tls_util.LoadServerConfig( - mdata.GetString(md, certFile), - mdata.GetString(md, keyFile), - mdata.GetString(md, caFile), - ) - if err != nil { - return - } - l.md.backlog = mdata.GetInt(md, backlog) if l.md.backlog <= 0 { l.md.backlog = defaultBacklog diff --git a/pkg/listener/option.go b/pkg/listener/option.go index 828aa75..5ccc37d 100644 --- a/pkg/listener/option.go +++ b/pkg/listener/option.go @@ -1,15 +1,17 @@ package listener import ( + "crypto/tls" "net/url" "github.com/go-gost/gost/pkg/logger" ) type Options struct { - Addr string - Auths []*url.Userinfo - Logger logger.Logger + Addr string + Auths []*url.Userinfo + TLSConfig *tls.Config + Logger logger.Logger } type Option func(opts *Options) @@ -26,6 +28,12 @@ func AuthsOption(auths ...*url.Userinfo) Option { } } +func TLSConfigOption(tlsConfig *tls.Config) Option { + return func(opts *Options) { + opts.TLSConfig = tlsConfig + } +} + func LoggerOption(logger logger.Logger) Option { return func(opts *Options) { opts.Logger = logger diff --git a/pkg/listener/quic/listener.go b/pkg/listener/quic/listener.go index 741c9a5..0ca6fbc 100644 --- a/pkg/listener/quic/listener.go +++ b/pkg/listener/quic/listener.go @@ -17,22 +17,22 @@ func init() { } type quicListener struct { - addr string ln quic.Listener cqueue chan net.Conn errChan chan error logger logger.Logger md metadata + options listener.Options } func NewListener(opts ...listener.Option) listener.Listener { - options := &listener.Options{} + options := listener.Options{} for _, opt := range opts { - opt(options) + opt(&options) } return &quicListener{ - addr: options.Addr, - logger: options.Logger, + logger: options.Logger, + options: options, } } @@ -41,7 +41,7 @@ func (l *quicListener) Init(md md.Metadata) (err error) { return } - laddr, err := net.ResolveUDPAddr("udp", l.addr) + laddr, err := net.ResolveUDPAddr("udp", l.options.Addr) if err != nil { return } @@ -67,7 +67,7 @@ func (l *quicListener) Init(md md.Metadata) (err error) { }, } - tlsCfg := l.md.tlsConfig + tlsCfg := l.options.TLSConfig tlsCfg.NextProtos = []string{"http/3", "quic/v1"} ln, err := quic.Listen(conn, tlsCfg, config) diff --git a/pkg/listener/quic/metadata.go b/pkg/listener/quic/metadata.go index 3d43ede..a86789f 100644 --- a/pkg/listener/quic/metadata.go +++ b/pkg/listener/quic/metadata.go @@ -1,10 +1,8 @@ package quic import ( - "crypto/tls" "time" - tls_util "github.com/go-gost/gost/pkg/common/util/tls" mdata "github.com/go-gost/gost/pkg/metadata" ) @@ -17,7 +15,6 @@ type metadata struct { handshakeTimeout time.Duration maxIdleTimeout time.Duration - tlsConfig *tls.Config cipherKey []byte backlog int } @@ -28,23 +25,10 @@ func (l *quicListener) parseMetadata(md mdata.Metadata) (err error) { handshakeTimeout = "handshakeTimeout" maxIdleTimeout = "maxIdleTimeout" - certFile = "certFile" - keyFile = "keyFile" - caFile = "caFile" - backlog = "backlog" cipherKey = "cipherKey" ) - l.md.tlsConfig, err = tls_util.LoadServerConfig( - mdata.GetString(md, certFile), - mdata.GetString(md, keyFile), - mdata.GetString(md, caFile), - ) - - if err != nil { - return - } l.md.backlog = mdata.GetInt(md, backlog) if l.md.backlog <= 0 { l.md.backlog = defaultBacklog diff --git a/pkg/listener/tls/listener.go b/pkg/listener/tls/listener.go index c56dd27..93f7971 100644 --- a/pkg/listener/tls/listener.go +++ b/pkg/listener/tls/listener.go @@ -15,20 +15,20 @@ func init() { } type tlsListener struct { - addr string net.Listener - logger logger.Logger - md metadata + logger logger.Logger + md metadata + options listener.Options } func NewListener(opts ...listener.Option) listener.Listener { - options := &listener.Options{} + options := listener.Options{} for _, opt := range opts { - opt(options) + opt(&options) } return &tlsListener{ - addr: options.Addr, - logger: options.Logger, + logger: options.Logger, + options: options, } } @@ -37,12 +37,12 @@ func (l *tlsListener) Init(md md.Metadata) (err error) { return } - ln, err := net.Listen("tcp", l.addr) + ln, err := net.Listen("tcp", l.options.Addr) if err != nil { return } - l.Listener = tls.NewListener(ln, l.md.tlsConfig) + l.Listener = tls.NewListener(ln, l.options.TLSConfig) return } diff --git a/pkg/listener/tls/metadata.go b/pkg/listener/tls/metadata.go index d5067c2..2047776 100644 --- a/pkg/listener/tls/metadata.go +++ b/pkg/listener/tls/metadata.go @@ -1,31 +1,12 @@ package tls import ( - "crypto/tls" - - tls_util "github.com/go-gost/gost/pkg/common/util/tls" mdata "github.com/go-gost/gost/pkg/metadata" ) type metadata struct { - tlsConfig *tls.Config } func (l *tlsListener) parseMetadata(md mdata.Metadata) (err error) { - const ( - certFile = "certFile" - keyFile = "keyFile" - caFile = "caFile" - ) - - l.md.tlsConfig, err = tls_util.LoadServerConfig( - mdata.GetString(md, certFile), - mdata.GetString(md, keyFile), - mdata.GetString(md, caFile), - ) - if err != nil { - return - } - return } diff --git a/pkg/listener/tls/mux/listener.go b/pkg/listener/tls/mux/listener.go index 01f1cc7..4261346 100644 --- a/pkg/listener/tls/mux/listener.go +++ b/pkg/listener/tls/mux/listener.go @@ -16,22 +16,22 @@ func init() { } type mtlsListener struct { - addr string net.Listener cqueue chan net.Conn errChan chan error logger logger.Logger md metadata + options listener.Options } func NewListener(opts ...listener.Option) listener.Listener { - options := &listener.Options{} + options := listener.Options{} for _, opt := range opts { - opt(options) + opt(&options) } return &mtlsListener{ - addr: options.Addr, - logger: options.Logger, + logger: options.Logger, + options: options, } } @@ -40,11 +40,11 @@ func (l *mtlsListener) Init(md md.Metadata) (err error) { return } - ln, err := net.Listen("tcp", l.addr) + ln, err := net.Listen("tcp", l.options.Addr) if err != nil { return } - l.Listener = tls.NewListener(ln, l.md.tlsConfig) + l.Listener = tls.NewListener(ln, l.options.TLSConfig) l.cqueue = make(chan net.Conn, l.md.backlog) l.errChan = make(chan error, 1) diff --git a/pkg/listener/tls/mux/metadata.go b/pkg/listener/tls/mux/metadata.go index 9a2119a..dad3c0f 100644 --- a/pkg/listener/tls/mux/metadata.go +++ b/pkg/listener/tls/mux/metadata.go @@ -1,10 +1,8 @@ package mux import ( - "crypto/tls" "time" - tls_util "github.com/go-gost/gost/pkg/common/util/tls" mdata "github.com/go-gost/gost/pkg/metadata" ) @@ -13,8 +11,6 @@ const ( ) type metadata struct { - tlsConfig *tls.Config - muxKeepAliveDisabled bool muxKeepAliveInterval time.Duration muxKeepAliveTimeout time.Duration @@ -27,10 +23,6 @@ type metadata struct { func (l *mtlsListener) parseMetadata(md mdata.Metadata) (err error) { const ( - certFile = "certFile" - keyFile = "keyFile" - caFile = "caFile" - backlog = "backlog" muxKeepAliveDisabled = "muxKeepAliveDisabled" @@ -41,15 +33,6 @@ func (l *mtlsListener) parseMetadata(md mdata.Metadata) (err error) { muxMaxStreamBuffer = "muxMaxStreamBuffer" ) - l.md.tlsConfig, err = tls_util.LoadServerConfig( - mdata.GetString(md, certFile), - mdata.GetString(md, keyFile), - mdata.GetString(md, caFile), - ) - if err != nil { - return - } - l.md.backlog = mdata.GetInt(md, backlog) if l.md.backlog <= 0 { l.md.backlog = defaultBacklog diff --git a/pkg/listener/ws/listener.go b/pkg/listener/ws/listener.go index 4c7c68c..165954a 100644 --- a/pkg/listener/ws/listener.go +++ b/pkg/listener/ws/listener.go @@ -20,7 +20,6 @@ func init() { } type wsListener struct { - saddr string addr net.Addr upgrader *websocket.Upgrader srv *http.Server @@ -29,28 +28,29 @@ type wsListener struct { errChan chan error logger logger.Logger md metadata + options listener.Options } func NewListener(opts ...listener.Option) listener.Listener { - options := &listener.Options{} + options := listener.Options{} for _, opt := range opts { - opt(options) + opt(&options) } return &wsListener{ - saddr: options.Addr, - logger: options.Logger, + logger: options.Logger, + options: options, } } func NewTLSListener(opts ...listener.Option) listener.Listener { - options := &listener.Options{} + options := listener.Options{} for _, opt := range opts { - opt(options) + opt(&options) } return &wsListener{ - saddr: options.Addr, - logger: options.Logger, tlsEnabled: true, + logger: options.Logger, + options: options, } } @@ -70,7 +70,7 @@ func (l *wsListener) Init(md md.Metadata) (err error) { mux := http.NewServeMux() mux.Handle(l.md.path, http.HandlerFunc(l.upgrade)) l.srv = &http.Server{ - Addr: l.saddr, + Addr: l.options.Addr, Handler: mux, ReadHeaderTimeout: l.md.readHeaderTimeout, } @@ -78,12 +78,12 @@ func (l *wsListener) Init(md md.Metadata) (err error) { l.cqueue = make(chan net.Conn, l.md.backlog) l.errChan = make(chan error, 1) - ln, err := net.Listen("tcp", l.saddr) + ln, err := net.Listen("tcp", l.options.Addr) if err != nil { return } if l.tlsEnabled { - ln = tls.NewListener(ln, l.md.tlsConfig) + ln = tls.NewListener(ln, l.options.TLSConfig) } l.addr = ln.Addr() diff --git a/pkg/listener/ws/metadata.go b/pkg/listener/ws/metadata.go index f45219a..a4c79c9 100644 --- a/pkg/listener/ws/metadata.go +++ b/pkg/listener/ws/metadata.go @@ -1,11 +1,9 @@ package ws import ( - "crypto/tls" "net/http" "time" - tls_util "github.com/go-gost/gost/pkg/common/util/tls" mdata "github.com/go-gost/gost/pkg/metadata" ) @@ -15,9 +13,8 @@ const ( ) type metadata struct { - path string - backlog int - tlsConfig *tls.Config + path string + backlog int handshakeTimeout time.Duration readHeaderTimeout time.Duration @@ -30,10 +27,6 @@ type metadata struct { func (l *wsListener) parseMetadata(md mdata.Metadata) (err error) { const ( - certFile = "certFile" - keyFile = "keyFile" - caFile = "caFile" - path = "path" backlog = "backlog" @@ -46,15 +39,6 @@ func (l *wsListener) parseMetadata(md mdata.Metadata) (err error) { header = "header" ) - l.md.tlsConfig, err = tls_util.LoadServerConfig( - mdata.GetString(md, certFile), - mdata.GetString(md, keyFile), - mdata.GetString(md, caFile), - ) - if err != nil { - return - } - l.md.path = mdata.GetString(md, path) if l.md.path == "" { l.md.path = defaultPath diff --git a/pkg/listener/ws/mux/listener.go b/pkg/listener/ws/mux/listener.go index 0d90a13..bad8739 100644 --- a/pkg/listener/ws/mux/listener.go +++ b/pkg/listener/ws/mux/listener.go @@ -21,37 +21,37 @@ func init() { } type mwsListener struct { - saddr string addr net.Addr upgrader *websocket.Upgrader srv *http.Server cqueue chan net.Conn errChan chan error + tlsEnabled bool logger logger.Logger md metadata - tlsEnabled bool + options listener.Options } func NewListener(opts ...listener.Option) listener.Listener { - options := &listener.Options{} + options := listener.Options{} for _, opt := range opts { - opt(options) + opt(&options) } return &mwsListener{ - saddr: options.Addr, - logger: options.Logger, + logger: options.Logger, + options: options, } } func NewTLSListener(opts ...listener.Option) listener.Listener { - options := &listener.Options{} + options := listener.Options{} for _, opt := range opts { - opt(options) + opt(&options) } return &mwsListener{ - saddr: options.Addr, - logger: options.Logger, tlsEnabled: true, + logger: options.Logger, + options: options, } } @@ -75,7 +75,7 @@ func (l *mwsListener) Init(md md.Metadata) (err error) { mux := http.NewServeMux() mux.Handle(path, http.HandlerFunc(l.upgrade)) l.srv = &http.Server{ - Addr: l.saddr, + Addr: l.options.Addr, Handler: mux, ReadHeaderTimeout: l.md.readHeaderTimeout, } @@ -83,12 +83,12 @@ func (l *mwsListener) Init(md md.Metadata) (err error) { l.cqueue = make(chan net.Conn, l.md.backlog) l.errChan = make(chan error, 1) - ln, err := net.Listen("tcp", l.saddr) + ln, err := net.Listen("tcp", l.options.Addr) if err != nil { return } if l.tlsEnabled { - ln = tls.NewListener(ln, l.md.tlsConfig) + ln = tls.NewListener(ln, l.options.TLSConfig) } l.addr = ln.Addr() diff --git a/pkg/listener/ws/mux/metadata.go b/pkg/listener/ws/mux/metadata.go index e17f2ab..33e4a16 100644 --- a/pkg/listener/ws/mux/metadata.go +++ b/pkg/listener/ws/mux/metadata.go @@ -1,11 +1,9 @@ package mux import ( - "crypto/tls" "net/http" "time" - tls_util "github.com/go-gost/gost/pkg/common/util/tls" mdata "github.com/go-gost/gost/pkg/metadata" ) @@ -15,10 +13,9 @@ const ( ) type metadata struct { - path string - backlog int - tlsConfig *tls.Config - header http.Header + path string + backlog int + header http.Header handshakeTimeout time.Duration readHeaderTimeout time.Duration @@ -40,10 +37,6 @@ func (l *mwsListener) parseMetadata(md mdata.Metadata) (err error) { backlog = "backlog" header = "header" - certFile = "certFile" - keyFile = "keyFile" - caFile = "caFile" - handshakeTimeout = "handshakeTimeout" readHeaderTimeout = "readHeaderTimeout" readBufferSize = "readBufferSize" @@ -58,15 +51,6 @@ func (l *mwsListener) parseMetadata(md mdata.Metadata) (err error) { muxMaxStreamBuffer = "muxMaxStreamBuffer" ) - l.md.tlsConfig, err = tls_util.LoadServerConfig( - mdata.GetString(md, certFile), - mdata.GetString(md, keyFile), - mdata.GetString(md, caFile), - ) - if err != nil { - return - } - l.md.path = mdata.GetString(md, path) if l.md.path == "" { l.md.path = defaultPath