diff --git a/cmd/gost/cmd.go b/cmd/gost/cmd.go index 8df18d9..7d0b5c1 100644 --- a/cmd/gost/cmd.go +++ b/cmd/gost/cmd.go @@ -8,6 +8,7 @@ import ( "strings" "github.com/go-gost/gost/pkg/config" + "github.com/go-gost/gost/pkg/metadata" "github.com/go-gost/gost/pkg/registry" ) @@ -67,7 +68,11 @@ func buildConfigFromCmd(services, nodes stringList) (*config.Config, error) { } service.Name = fmt.Sprintf("service-%d", i) if chain != nil { - service.Handler.Chain = chain.Name + if service.Listener.Type == "rtcp" || service.Listener.Type == "rudp" { + service.Listener.Chain = chain.Name + } else { + service.Handler.Chain = chain.Name + } } cfg.Services = append(cfg.Services, service) } @@ -125,37 +130,30 @@ func buildServiceConfig(url *url.URL) (*config.ServiceConfig, error) { auths = append(auths, auth) } - md := make(map[string]interface{}) + md := metadata.MapMetadata{} for k, v := range url.Query() { if len(v) > 0 { md[k] = v[0] } } - if sauth := md["auth"]; sauth != nil { - if sa, _ := sauth.(string); sa != "" { - au, err := parseAuthFromCmd(sa) - if err != nil { - return nil, err - } - auths = append(auths, au) - } - } - delete(md, "auth") - var tlsConfig *config.TLSConfig - if certs := md["cert"]; certs != nil { - cert, _ := certs.(string) - key, _ := md["key"].(string) - ca, _ := md["ca"].(string) - tlsConfig = &config.TLSConfig{ - Cert: cert, - Key: key, - CA: ca, + if sa := metadata.GetString(md, "auth"); sa != "" { + au, err := parseAuthFromCmd(sa) + if err != nil { + return nil, err } + auths = append(auths, au) } - delete(md, "cert") - delete(md, "key") - delete(md, "ca") + md.Del("auth") + + tlsConfig := &config.TLSConfig{ + Cert: metadata.GetString(md, "cert"), + Key: metadata.GetString(md, "key"), + CA: metadata.GetString(md, "ca"), + } + md.Del("cert") + md.Del("key") + md.Del("ca") svc.Handler = &config.HandlerConfig{ Type: handler, @@ -205,45 +203,33 @@ func buildNodeConfig(url *url.URL) (*config.NodeConfig, error) { auth.Password, _ = url.User.Password() } - md := make(map[string]interface{}) + md := metadata.MapMetadata{} for k, v := range url.Query() { if len(v) > 0 { md[k] = v[0] } } - md["serverName"] = url.Host - if sauth := md["auth"]; sauth != nil && auth == nil { - if sa, _ := sauth.(string); sa != "" { - au, err := parseAuthFromCmd(sa) - if err != nil { - return nil, err - } - auth = au + if sauth := metadata.GetString(md, "auth"); sauth != "" && auth == nil { + au, err := parseAuthFromCmd(sauth) + if err != nil { + return nil, err } + auth = au } - delete(md, "auth") + md.Del("auth") - var tlsConfig *config.TLSConfig - if certs := md["cert"]; certs != nil { - cert, _ := certs.(string) - key, _ := md["key"].(string) - ca, _ := md["ca"].(string) - secure, _ := md["secure"].(bool) - serverName, _ := md["serverName"].(string) - tlsConfig = &config.TLSConfig{ - Cert: cert, - Key: key, - CA: ca, - Secure: secure, - ServerName: serverName, - } + tlsConfig := &config.TLSConfig{ + CA: metadata.GetString(md, "ca"), + Secure: metadata.GetBool(md, "secure"), + ServerName: metadata.GetString(md, "serverName"), } - delete(md, "cert") - delete(md, "key") - delete(md, "ca") - delete(md, "secure") - delete(md, "serverName") + if tlsConfig.ServerName == "" { + tlsConfig.ServerName = url.Hostname() + } + md.Del("ca") + md.Del("secure") + md.Del("serverName") node.Connector = &config.ConnectorConfig{ Type: connector, diff --git a/cmd/gost/config.go b/cmd/gost/config.go index bb93020..011bf0c 100644 --- a/cmd/gost/config.go +++ b/cmd/gost/config.go @@ -56,8 +56,15 @@ func buildService(cfg *config.Config) (services []*service.Service) { } for _, svc := range cfg.Services { - if svc.Listener == nil || svc.Handler == nil { - continue + if svc.Listener == nil { + svc.Listener = &config.ListenerConfig{ + Type: "tcp", + } + } + if svc.Handler == nil { + svc.Handler = &config.HandlerConfig{ + Type: "auto", + } } serviceLogger := log.WithFields(map[string]interface{}{ "kind": "service", @@ -89,10 +96,6 @@ func buildService(cfg *config.Config) (services []*service.Service) { listener.LoggerOption(listenerLogger), ) - if chainable, ok := ln.(chain.Chainable); ok { - chainable.WithChain(chains[svc.Listener.Chain]) - } - if svc.Listener.Metadata == nil { svc.Listener.Metadata = make(map[string]interface{}) } @@ -292,9 +295,9 @@ func logFromConfig(cfg *config.LogConfig) logger.Logger { switch cfg.Output { case "none": return logger.Nop() - case "stdout", "": + case "stdout": out = os.Stdout - case "stderr": + case "stderr", "": out = os.Stderr default: f, err := os.OpenFile(cfg.Output, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666) diff --git a/pkg/chain/chain.go b/pkg/chain/chain.go index dfdc88a..515f200 100644 --- a/pkg/chain/chain.go +++ b/pkg/chain/chain.go @@ -1,9 +1,5 @@ package chain -type Chainable interface { - WithChain(chain *Chain) -} - type Chain struct { groups []*NodeGroup } diff --git a/pkg/config/config.go b/pkg/config/config.go index 8d6d286..1f718da 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -31,11 +31,11 @@ type ProfilingConfig struct { } type TLSConfig struct { - Cert string - Key string + Cert string `yaml:",omitempty"` + Key string `yaml:",omitempty"` CA string `yaml:",omitempty"` Secure bool `yaml:",omitempty"` - ServerName string `yaml:",omitempty"` + ServerName string `yaml:"serverName,omitempty"` } type AuthConfig struct { @@ -59,7 +59,7 @@ type NameserverConfig struct { Addr string Chain string Prefer string - ClientIP string + ClientIP string `yaml:"clientIP"` Hostname string TTL time.Duration Timeout time.Duration diff --git a/pkg/listener/dns/metadata.go b/pkg/listener/dns/metadata.go index 187af6d..218f0d7 100644 --- a/pkg/listener/dns/metadata.go +++ b/pkg/listener/dns/metadata.go @@ -20,14 +20,17 @@ type metadata struct { func (l *dnsListener) parseMetadata(md mdata.Metadata) (err error) { const ( + backlog = "backlog" mode = "mode" readBufferSize = "readBufferSize" - - backlog = "backlog" + readTimeout = "readTimeout" + writeTimeout = "writeTimeout" ) l.md.mode = mdata.GetString(md, mode) l.md.readBufferSize = mdata.GetInt(md, readBufferSize) + l.md.readTimeout = mdata.GetDuration(md, readTimeout) + l.md.writeTimeout = mdata.GetDuration(md, writeTimeout) l.md.backlog = mdata.GetInt(md, backlog) if l.md.backlog <= 0 { diff --git a/pkg/listener/http2/metadata.go b/pkg/listener/http2/metadata.go index ae6e817..472aab7 100644 --- a/pkg/listener/http2/metadata.go +++ b/pkg/listener/http2/metadata.go @@ -1,9 +1,6 @@ package http2 import ( - "net/http" - "time" - mdata "github.com/go-gost/gost/pkg/metadata" ) @@ -12,24 +9,12 @@ const ( ) type metadata struct { - path string - handshakeTimeout time.Duration - readHeaderTimeout time.Duration - readBufferSize int - writeBufferSize int - enableCompression bool - responseHeader http.Header - backlog int + backlog int } func (l *http2Listener) parseMetadata(md mdata.Metadata) (err error) { const ( - path = "path" - handshakeTimeout = "handshakeTimeout" - readHeaderTimeout = "readHeaderTimeout" - readBufferSize = "readBufferSize" - writeBufferSize = "writeBufferSize" - backlog = "backlog" + backlog = "backlog" ) l.md.backlog = mdata.GetInt(md, backlog) diff --git a/pkg/listener/option.go b/pkg/listener/option.go index 5ccc37d..ce810e4 100644 --- a/pkg/listener/option.go +++ b/pkg/listener/option.go @@ -4,6 +4,7 @@ import ( "crypto/tls" "net/url" + "github.com/go-gost/gost/pkg/chain" "github.com/go-gost/gost/pkg/logger" ) @@ -11,6 +12,7 @@ type Options struct { Addr string Auths []*url.Userinfo TLSConfig *tls.Config + Chain *chain.Chain Logger logger.Logger } @@ -34,6 +36,12 @@ func TLSConfigOption(tlsConfig *tls.Config) Option { } } +func ChainOption(chain *chain.Chain) Option { + return func(opts *Options) { + opts.Chain = chain + } +} + func LoggerOption(logger logger.Logger) Option { return func(opts *Options) { opts.Logger = logger diff --git a/pkg/listener/rtcp/listener.go b/pkg/listener/rtcp/listener.go index 55ffd54..0f7db34 100644 --- a/pkg/listener/rtcp/listener.go +++ b/pkg/listener/rtcp/listener.go @@ -17,47 +17,42 @@ func init() { } type rtcpListener struct { - addr string - laddr net.Addr - chain *chain.Chain - ln net.Listener - md metadata - router *chain.Router - logger logger.Logger - closed chan struct{} + laddr net.Addr + ln net.Listener + md metadata + router *chain.Router + logger logger.Logger + closed chan struct{} + options listener.Options } func NewListener(opts ...listener.Option) listener.Listener { - options := &listener.Options{} + options := listener.Options{} for _, opt := range opts { - opt(options) + opt(&options) } return &rtcpListener{ - addr: options.Addr, - closed: make(chan struct{}), - router: &chain.Router{ - Logger: options.Logger, - }, - logger: options.Logger, + closed: make(chan struct{}), + logger: options.Logger, + options: options, } } -// implements chain.Chainable interface -func (l *rtcpListener) WithChain(chain *chain.Chain) { - l.router.Chain = chain -} - func (l *rtcpListener) Init(md md.Metadata) (err error) { if err = l.parseMetadata(md); err != nil { return } - laddr, err := net.ResolveTCPAddr("tcp", l.addr) + laddr, err := net.ResolveTCPAddr("tcp", l.options.Addr) if err != nil { return } l.laddr = laddr + l.router = &chain.Router{ + Chain: l.options.Chain, + Logger: l.logger, + } return } diff --git a/pkg/listener/rtcp/metadata.go b/pkg/listener/rtcp/metadata.go index 4a723ec..544626e 100644 --- a/pkg/listener/rtcp/metadata.go +++ b/pkg/listener/rtcp/metadata.go @@ -12,24 +12,8 @@ const ( ) type metadata struct { - enableMux bool - backlog int - retryCount int } func (l *rtcpListener) parseMetadata(md mdata.Metadata) (err error) { - const ( - enableMux = "mux" - backlog = "backlog" - retryCount = "retry" - ) - - l.md.enableMux = mdata.GetBool(md, enableMux) - l.md.retryCount = mdata.GetInt(md, retryCount) - - l.md.backlog = mdata.GetInt(md, backlog) - if l.md.backlog <= 0 { - l.md.backlog = defaultBacklog - } return } diff --git a/pkg/listener/rudp/listener.go b/pkg/listener/rudp/listener.go index 7d03e65..e604f39 100644 --- a/pkg/listener/rudp/listener.go +++ b/pkg/listener/rudp/listener.go @@ -17,47 +17,42 @@ func init() { } type rudpListener struct { - addr string - laddr *net.UDPAddr - chain *chain.Chain - ln net.Listener - md metadata - router *chain.Router - logger logger.Logger - closed chan struct{} + laddr *net.UDPAddr + ln net.Listener + router *chain.Router + closed chan struct{} + logger logger.Logger + md metadata + options listener.Options } func NewListener(opts ...listener.Option) listener.Listener { - options := &listener.Options{} + options := listener.Options{} for _, opt := range opts { - opt(options) + opt(&options) } return &rudpListener{ - addr: options.Addr, - closed: make(chan struct{}), - router: &chain.Router{ - Logger: options.Logger, - }, - logger: options.Logger, + closed: make(chan struct{}), + logger: options.Logger, + options: options, } } -// implements chain.Chainable interface -func (l *rudpListener) WithChain(chain *chain.Chain) { - l.router.Chain = chain -} - func (l *rudpListener) Init(md md.Metadata) (err error) { if err = l.parseMetadata(md); err != nil { return } - laddr, err := net.ResolveUDPAddr("udp", l.addr) + laddr, err := net.ResolveUDPAddr("udp", l.options.Addr) if err != nil { return } l.laddr = laddr + l.router = &chain.Router{ + Chain: l.options.Chain, + Logger: l.logger, + } return } diff --git a/pkg/listener/tun/listener.go b/pkg/listener/tun/listener.go index 683bb82..26da643 100644 --- a/pkg/listener/tun/listener.go +++ b/pkg/listener/tun/listener.go @@ -15,22 +15,22 @@ func init() { } type tunListener struct { - saddr string - addr net.Addr - cqueue chan net.Conn - closed chan struct{} - logger logger.Logger - md metadata + addr net.Addr + cqueue chan net.Conn + closed chan struct{} + logger logger.Logger + md metadata + options listener.Options } func NewListener(opts ...listener.Option) listener.Listener { - options := &listener.Options{} + options := listener.Options{} for _, opt := range opts { - opt(options) + opt(&options) } return &tunListener{ - saddr: options.Addr, - logger: options.Logger, + logger: options.Logger, + options: options, } } @@ -39,7 +39,7 @@ func (l *tunListener) Init(md md.Metadata) (err error) { return } - l.addr, err = net.ResolveUDPAddr("udp", l.saddr) + l.addr, err = net.ResolveUDPAddr("udp", l.options.Addr) if err != nil { return } diff --git a/pkg/metadata/metadata.go b/pkg/metadata/metadata.go index 93f9fee..d4a12d3 100644 --- a/pkg/metadata/metadata.go +++ b/pkg/metadata/metadata.go @@ -30,6 +30,10 @@ func (m MapMetadata) Get(key string) interface{} { return nil } +func (m MapMetadata) Del(key string) { + delete(m, key) +} + func GetBool(md Metadata, key string) (v bool) { if md == nil || !md.IsExists(key) { return