diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go new file mode 100644 index 0000000..d77ee19 --- /dev/null +++ b/pkg/auth/auth.go @@ -0,0 +1,40 @@ +package auth + +// Authenticator is an interface for user authentication. +type Authenticator interface { + Authenticate(user, password string) bool +} + +// LocalAuthenticator is an Authenticator that authenticates client by local key-value pairs. +type LocalAuthenticator struct { + kvs map[string]string +} + +// NewLocalAuthenticator creates an Authenticator that authenticates client by local infos. +func NewLocalAuthenticator(kvs map[string]string) *LocalAuthenticator { + if kvs == nil { + kvs = make(map[string]string) + } + return &LocalAuthenticator{ + kvs: kvs, + } +} + +// Authenticate checks the validity of the provided user-password pair. +func (au *LocalAuthenticator) Authenticate(user, password string) bool { + if au == nil { + return true + } + + if len(au.kvs) == 0 { + return true + } + + v, ok := au.kvs[user] + return ok && (v == "" || password == v) +} + +// Add adds a key-value pair to the Authenticator. +func (au *LocalAuthenticator) Add(k, v string) { + au.kvs[k] = v +} diff --git a/pkg/components/connector/connector.go b/pkg/components/connector/connector.go index b75d918..118c0ce 100644 --- a/pkg/components/connector/connector.go +++ b/pkg/components/connector/connector.go @@ -3,10 +3,12 @@ package connector import ( "context" "net" + + "github.com/go-gost/gost/pkg/components/metadata" ) // Connector is responsible for connecting to the destination address. type Connector interface { - Init(Metadata) error + Init(metadata.Metadata) error Connect(ctx context.Context, conn net.Conn, network, address string, opts ...ConnectOption) (net.Conn, error) } diff --git a/pkg/components/connector/http/connector.go b/pkg/components/connector/http/connector.go index d21494b..040a8eb 100644 --- a/pkg/components/connector/http/connector.go +++ b/pkg/components/connector/http/connector.go @@ -9,8 +9,10 @@ import ( "net" "net/http" "net/url" + "strings" "github.com/go-gost/gost/pkg/components/connector" + md "github.com/go-gost/gost/pkg/components/metadata" "github.com/go-gost/gost/pkg/logger" "github.com/go-gost/gost/pkg/registry" ) @@ -35,12 +37,8 @@ func NewConnector(opts ...connector.Option) connector.Connector { } } -func (c *Connector) Init(md connector.Metadata) (err error) { - c.md, err = c.parseMetadata(md) - if err != nil { - return - } - return nil +func (c *Connector) Init(md md.Metadata) (err error) { + return c.parseMetadata(md) } func (c *Connector) Connect(ctx context.Context, conn net.Conn, network, address string, opts ...connector.ConnectOption) (net.Conn, error) { @@ -83,17 +81,19 @@ func (c *Connector) Connect(ctx context.Context, conn net.Conn, network, address return conn, nil } -func (c *Connector) parseMetadata(md connector.Metadata) (m metadata, err error) { - if md == nil { - md = connector.Metadata{} - } - m.UserAgent = md[userAgent] - if m.UserAgent == "" { - m.UserAgent = defaultUserAgent +func (c *Connector) parseMetadata(md md.Metadata) (err error) { + c.md.UserAgent, _ = md.Get(userAgent).(string) + if c.md.UserAgent == "" { + c.md.UserAgent = defaultUserAgent } - if v, ok := md[username]; ok { - m.User = url.UserPassword(v, md[password]) + if v := md.GetString(auth); v != "" { + ss := strings.SplitN(v, ":", 2) + if len(ss) == 1 { + c.md.User = url.User(ss[0]) + } else { + c.md.User = url.UserPassword(ss[0], ss[1]) + } } return diff --git a/pkg/components/connector/http/metadata.go b/pkg/components/connector/http/metadata.go index cfe883f..ab4a782 100644 --- a/pkg/components/connector/http/metadata.go +++ b/pkg/components/connector/http/metadata.go @@ -4,8 +4,7 @@ import "net/url" const ( userAgent = "userAgent" - username = "username" - password = "password" + auth = "auth" ) const ( diff --git a/pkg/components/connector/metadata.go b/pkg/components/connector/metadata.go deleted file mode 100644 index a2901bb..0000000 --- a/pkg/components/connector/metadata.go +++ /dev/null @@ -1,3 +0,0 @@ -package connector - -type Metadata map[string]string diff --git a/pkg/components/connector/ss/connector.go b/pkg/components/connector/ss/connector.go index 374ef02..0436e8d 100644 --- a/pkg/components/connector/ss/connector.go +++ b/pkg/components/connector/ss/connector.go @@ -5,6 +5,7 @@ import ( "net" "github.com/go-gost/gost/pkg/components/connector" + md "github.com/go-gost/gost/pkg/components/metadata" "github.com/go-gost/gost/pkg/logger" "github.com/go-gost/gost/pkg/registry" ) @@ -29,13 +30,8 @@ func NewConnector(opts ...connector.Option) connector.Connector { } } -func (c *Connector) Init(md connector.Metadata) (err error) { - c.md, err = c.parseMetadata(md) - if err != nil { - return - } - - return nil +func (c *Connector) Init(md md.Metadata) (err error) { + return c.parseMetadata(md) } func (c *Connector) Connect(ctx context.Context, conn net.Conn, network, address string, opts ...connector.ConnectOption) (net.Conn, error) { @@ -43,13 +39,9 @@ func (c *Connector) Connect(ctx context.Context, conn net.Conn, network, address return conn, nil } -func (c *Connector) parseMetadata(md connector.Metadata) (m metadata, err error) { - if md == nil { - md = connector.Metadata{} - } - - m.method = md[method] - m.password = md[password] +func (c *Connector) parseMetadata(md md.Metadata) (err error) { + c.md.method = md.GetString(method) + c.md.password = md.GetString(password) return } diff --git a/pkg/components/dialer/dialer.go b/pkg/components/dialer/dialer.go index 2394ddc..cab257d 100644 --- a/pkg/components/dialer/dialer.go +++ b/pkg/components/dialer/dialer.go @@ -3,11 +3,13 @@ package dialer import ( "context" "net" + + "github.com/go-gost/gost/pkg/components/metadata" ) // Transporter is responsible for dialing to the proxy server. type Dialer interface { - Init(Metadata) error + Init(metadata.Metadata) error Dial(ctx context.Context, addr string, opts ...DialOption) (net.Conn, error) } diff --git a/pkg/components/dialer/metadata.go b/pkg/components/dialer/metadata.go deleted file mode 100644 index 755bb8a..0000000 --- a/pkg/components/dialer/metadata.go +++ /dev/null @@ -1,3 +0,0 @@ -package dialer - -type Metadata map[string]string diff --git a/pkg/components/dialer/tcp/dialer.go b/pkg/components/dialer/tcp/dialer.go index eae2a91..6845319 100644 --- a/pkg/components/dialer/tcp/dialer.go +++ b/pkg/components/dialer/tcp/dialer.go @@ -5,6 +5,7 @@ import ( "net" "github.com/go-gost/gost/pkg/components/dialer" + md "github.com/go-gost/gost/pkg/components/metadata" "github.com/go-gost/gost/pkg/logger" "github.com/go-gost/gost/pkg/registry" ) @@ -29,12 +30,8 @@ func NewDialer(opts ...dialer.Option) dialer.Dialer { } } -func (d *Dialer) Init(md dialer.Metadata) (err error) { - d.md, err = d.parseMetadata(md) - if err != nil { - return - } - return nil +func (d *Dialer) Init(md md.Metadata) (err error) { + return d.parseMetadata(md) } func (d *Dialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOption) (net.Conn, error) { @@ -52,6 +49,6 @@ func (d *Dialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOptio return netd.DialContext(ctx, "tcp", addr) } -func (d *Dialer) parseMetadata(md dialer.Metadata) (m metadata, err error) { +func (d *Dialer) parseMetadata(md md.Metadata) (err error) { return } diff --git a/pkg/components/handler/handler.go b/pkg/components/handler/handler.go index 0c802c4..93c1cb1 100644 --- a/pkg/components/handler/handler.go +++ b/pkg/components/handler/handler.go @@ -3,9 +3,11 @@ package handler import ( "context" "net" + + "github.com/go-gost/gost/pkg/components/metadata" ) type Handler interface { - Init(Metadata) error + Init(metadata.Metadata) error Handle(context.Context, net.Conn) } diff --git a/pkg/components/handler/http/handler.go b/pkg/components/handler/http/handler.go index 06348ec..181fb0a 100644 --- a/pkg/components/handler/http/handler.go +++ b/pkg/components/handler/http/handler.go @@ -5,9 +5,13 @@ import ( "context" "net" "net/http" + "net/http/httputil" + "strings" + "github.com/go-gost/gost/pkg/auth" "github.com/go-gost/gost/pkg/chain" "github.com/go-gost/gost/pkg/components/handler" + md "github.com/go-gost/gost/pkg/components/metadata" "github.com/go-gost/gost/pkg/logger" "github.com/go-gost/gost/pkg/registry" ) @@ -34,19 +38,41 @@ func NewHandler(opts ...handler.Option) handler.Handler { } } -func (h *Handler) Init(md handler.Metadata) error { +func (h *Handler) Init(md md.Metadata) error { + return h.parseMetadata(md) +} + +func (h *Handler) parseMetadata(md md.Metadata) error { + h.md.proxyAgent = md.GetString(proxyAgent) + + if v, _ := md.Get(auths).([]interface{}); len(v) > 0 { + authenticator := auth.NewLocalAuthenticator(nil) + for _, auth := range v { + if s, _ := auth.(string); s != "" { + ss := strings.SplitN(s, ":", 2) + if len(ss) == 1 { + authenticator.Add(ss[0], "") + } else { + authenticator.Add(ss[0], ss[1]) + } + } + } + h.md.authenticator = authenticator + } return nil } func (h *Handler) Handle(ctx context.Context, conn net.Conn) { defer conn.Close() + h.logger = h.logger.WithFields(map[string]interface{}{ + "src": conn.RemoteAddr(), + "local": conn.LocalAddr(), + }) + req, err := http.ReadRequest(bufio.NewReader(conn)) if err != nil { - h.logger.WithFields(map[string]interface{}{ - "src": conn.RemoteAddr(), - "local": conn.LocalAddr(), - }).Error(err) + h.logger.Error(err) return } defer req.Body.Close() @@ -73,6 +99,14 @@ func (h *Handler) handleRequest(ctx context.Context, conn net.Conn, req *http.Re host = net.JoinHostPort(host, "80") } + h.logger = h.logger.WithFields(map[string]interface{}{ + "dst": host, + }) + + if h.logger.IsLevelEnabled(logger.DebugLevel) { + dump, _ := httputil.DumpRequest(req, false) + h.logger.Debug(string(dump)) + } /* u, _, _ := basicProxyAuth(req.Header.Get("Proxy-Authorization")) if u != "" { diff --git a/pkg/components/handler/http/metadata.go b/pkg/components/handler/http/metadata.go index 7893903..e4086fb 100644 --- a/pkg/components/handler/http/metadata.go +++ b/pkg/components/handler/http/metadata.go @@ -1,7 +1,16 @@ package http +import "github.com/go-gost/gost/pkg/auth" + +const ( + addr = "addr" + proxyAgent = "proxyAgent" + auths = "auths" +) + type metadata struct { - addr string - proxyAgent string - retryCount int + addr string + authenticator auth.Authenticator + proxyAgent string + retryCount int } diff --git a/pkg/components/handler/metadata.go b/pkg/components/handler/metadata.go deleted file mode 100644 index 4f37725..0000000 --- a/pkg/components/handler/metadata.go +++ /dev/null @@ -1,3 +0,0 @@ -package handler - -type Metadata map[string]string diff --git a/pkg/components/handler/ss/handler.go b/pkg/components/handler/ss/handler.go index bcdbdfa..bc4cd28 100644 --- a/pkg/components/handler/ss/handler.go +++ b/pkg/components/handler/ss/handler.go @@ -8,6 +8,7 @@ import ( "github.com/go-gost/gosocks5" "github.com/go-gost/gost/pkg/components/handler" + md "github.com/go-gost/gost/pkg/components/metadata" "github.com/go-gost/gost/pkg/logger" "github.com/go-gost/gost/pkg/registry" "github.com/shadowsocks/go-shadowsocks2/core" @@ -34,12 +35,8 @@ func NewHandler(opts ...handler.Option) handler.Handler { } } -func (h *Handler) Init(md handler.Metadata) (err error) { - h.md, err = h.parseMetadata(md) - if err != nil { - return - } - return nil +func (h *Handler) Init(md md.Metadata) (err error) { + return h.parseMetadata(md) } func (h *Handler) Handle(ctx context.Context, conn net.Conn) { @@ -74,14 +71,17 @@ func (h *Handler) Handle(ctx context.Context, conn net.Conn) { handler.Transport(conn, cc) } -func (h *Handler) parseMetadata(md handler.Metadata) (m metadata, err error) { - m.cipher, err = h.initCipher(md[method], md[password], md[key]) +func (h *Handler) parseMetadata(md md.Metadata) (err error) { + h.md.cipher, err = h.initCipher( + md.GetString(method), + md.GetString(password), + md.GetString(key), + ) if err != nil { return } - if v, ok := md[readTimeout]; ok { - m.readTimeout, _ = time.ParseDuration(v) - } + + h.md.readTimeout = md.GetDuration(readTimeout) return } diff --git a/pkg/components/handler/ssu/handler.go b/pkg/components/handler/ssu/handler.go index fcb70fc..694a6cb 100644 --- a/pkg/components/handler/ssu/handler.go +++ b/pkg/components/handler/ssu/handler.go @@ -3,9 +3,9 @@ package ss import ( "context" "net" - "time" "github.com/go-gost/gost/pkg/components/handler" + md "github.com/go-gost/gost/pkg/components/metadata" "github.com/go-gost/gost/pkg/logger" "github.com/go-gost/gost/pkg/registry" "github.com/shadowsocks/go-shadowsocks2/core" @@ -32,26 +32,26 @@ func NewHandler(opts ...handler.Option) handler.Handler { } } -func (h *Handler) Init(md handler.Metadata) (err error) { - h.md, err = h.parseMetadata(md) - if err != nil { - return - } - return nil +func (h *Handler) Init(md md.Metadata) (err error) { + return h.parseMetadata(md) } func (h *Handler) Handle(ctx context.Context, conn net.Conn) { defer conn.Close() } -func (h *Handler) parseMetadata(md handler.Metadata) (m metadata, err error) { - m.cipher, err = h.initCipher(md[method], md[password], md[key]) +func (h *Handler) parseMetadata(md md.Metadata) (err error) { + h.md.cipher, err = h.initCipher( + md.GetString(method), + md.GetString(password), + md.GetString(key), + ) if err != nil { return } - if v, ok := md[readTimeout]; ok { - m.readTimeout, _ = time.ParseDuration(v) - } + + h.md.readTimeout = md.GetDuration(readTimeout) + return } diff --git a/pkg/components/listener/ftcp/listener.go b/pkg/components/listener/ftcp/listener.go index 0fd0925..cd7f69d 100644 --- a/pkg/components/listener/ftcp/listener.go +++ b/pkg/components/listener/ftcp/listener.go @@ -1,12 +1,12 @@ package ftcp import ( - "errors" "net" "sync" "sync/atomic" "github.com/go-gost/gost/pkg/components/listener" + md "github.com/go-gost/gost/pkg/components/metadata" "github.com/go-gost/gost/pkg/logger" "github.com/go-gost/gost/pkg/registry" "github.com/xtaci/tcpraw" @@ -17,6 +17,7 @@ func init() { } type Listener struct { + addr string md metadata conn net.PacketConn connChan chan net.Conn @@ -31,13 +32,13 @@ func NewListener(opts ...listener.Option) listener.Listener { opt(options) } return &Listener{ + addr: options.Addr, logger: options.Logger, } } -func (l *Listener) Init(md listener.Metadata) (err error) { - l.md, err = l.parseMetadata(md) - if err != nil { +func (l *Listener) Init(md md.Metadata) (err error) { + if err = l.parseMetadata(md); err != nil { return } @@ -118,14 +119,7 @@ func (l *Listener) listenLoop() { } } -func (l *Listener) parseMetadata(md listener.Metadata) (m metadata, err error) { - if val, ok := md[addr]; ok { - m.addr = val - } else { - err = errors.New("missing address") - return - } - +func (l *Listener) parseMetadata(md md.Metadata) (err error) { return } diff --git a/pkg/components/listener/ftcp/metadata.go b/pkg/components/listener/ftcp/metadata.go index 38cedc9..694c856 100644 --- a/pkg/components/listener/ftcp/metadata.go +++ b/pkg/components/listener/ftcp/metadata.go @@ -14,8 +14,7 @@ const ( ) type metadata struct { - addr string - ttl time.Duration + ttl time.Duration readBufferSize int readQueueSize int diff --git a/pkg/components/listener/http2/h2/listener.go b/pkg/components/listener/http2/h2/listener.go index 22efc45..89d1161 100644 --- a/pkg/components/listener/http2/h2/listener.go +++ b/pkg/components/listener/http2/h2/listener.go @@ -9,6 +9,7 @@ import ( "github.com/go-gost/gost/pkg/components/internal/utils" "github.com/go-gost/gost/pkg/components/listener" + md "github.com/go-gost/gost/pkg/components/metadata" "github.com/go-gost/gost/pkg/logger" "github.com/go-gost/gost/pkg/registry" "golang.org/x/net/http2" @@ -19,6 +20,7 @@ func init() { } type Listener struct { + addr string net.Listener md metadata server *http2.Server @@ -33,17 +35,17 @@ func NewListener(opts ...listener.Option) listener.Listener { opt(options) } return &Listener{ + addr: options.Addr, logger: options.Logger, } } -func (l *Listener) Init(md listener.Metadata) (err error) { - l.md, err = l.parseMetadata(md) - if err != nil { +func (l *Listener) Init(md md.Metadata) (err error) { + if err = l.parseMetadata(md); err != nil { return } - ln, err := net.Listen("tcp", l.md.addr) + ln, err := net.Listen("tcp", l.addr) if err != nil { return } @@ -170,15 +172,12 @@ func (l *Listener) upgrade(w http.ResponseWriter, r *http.Request) (*conn, error }, nil } -func (l *Listener) parseMetadata(md listener.Metadata) (m metadata, err error) { - if val, ok := md[addr]; ok { - m.addr = val - } else { - err = errors.New("missing address") - return - } - - m.tlsConfig, err = utils.LoadTLSConfig(md[certFile], md[keyFile], md[caFile]) +func (l *Listener) parseMetadata(md md.Metadata) (err error) { + l.md.tlsConfig, err = utils.LoadTLSConfig( + md.GetString(certFile), + md.GetString(keyFile), + md.GetString(caFile), + ) if err != nil { return } diff --git a/pkg/components/listener/http2/h2/metadata.go b/pkg/components/listener/http2/h2/metadata.go index 30d90f9..4fa94fc 100644 --- a/pkg/components/listener/http2/h2/metadata.go +++ b/pkg/components/listener/http2/h2/metadata.go @@ -7,7 +7,6 @@ import ( ) const ( - addr = "addr" path = "path" certFile = "certFile" keyFile = "keyFile" @@ -24,7 +23,6 @@ const ( ) type metadata struct { - addr string path string tlsConfig *tls.Config handshakeTimeout time.Duration diff --git a/pkg/components/listener/http2/listener.go b/pkg/components/listener/http2/listener.go index c927845..f8c0db4 100644 --- a/pkg/components/listener/http2/listener.go +++ b/pkg/components/listener/http2/listener.go @@ -2,12 +2,12 @@ package http2 import ( "crypto/tls" - "errors" "net" "net/http" "github.com/go-gost/gost/pkg/components/internal/utils" "github.com/go-gost/gost/pkg/components/listener" + md "github.com/go-gost/gost/pkg/components/metadata" "github.com/go-gost/gost/pkg/logger" "github.com/go-gost/gost/pkg/registry" "golang.org/x/net/http2" @@ -18,6 +18,7 @@ func init() { } type Listener struct { + saddr string md metadata server *http.Server addr net.Addr @@ -32,18 +33,18 @@ func NewListener(opts ...listener.Option) listener.Listener { opt(options) } return &Listener{ + saddr: options.Addr, logger: options.Logger, } } -func (l *Listener) Init(md listener.Metadata) (err error) { - l.md, err = l.parseMetadata(md) - if err != nil { +func (l *Listener) Init(md md.Metadata) (err error) { + if err = l.parseMetadata(md); err != nil { return } l.server = &http.Server{ - Addr: l.md.addr, + Addr: l.saddr, Handler: http.HandlerFunc(l.handleFunc), TLSConfig: l.md.tlsConfig, } @@ -51,7 +52,7 @@ func (l *Listener) Init(md listener.Metadata) (err error) { return err } - ln, err := net.Listen("tcp", addr) + ln, err := net.Listen("tcp", l.saddr) if err != nil { return err } @@ -124,15 +125,12 @@ func (l *Listener) handleFunc(w http.ResponseWriter, r *http.Request) { <-conn.closed } -func (l *Listener) parseMetadata(md listener.Metadata) (m metadata, err error) { - if val, ok := md[addr]; ok { - m.addr = val - } else { - err = errors.New("missing address") - return - } - - m.tlsConfig, err = utils.LoadTLSConfig(md[certFile], md[keyFile], md[caFile]) +func (l *Listener) parseMetadata(md md.Metadata) (err error) { + l.md.tlsConfig, err = utils.LoadTLSConfig( + md.GetString(certFile), + md.GetString(keyFile), + md.GetString(caFile), + ) if err != nil { return } diff --git a/pkg/components/listener/http2/metadata.go b/pkg/components/listener/http2/metadata.go index 88d8fd2..d53f11b 100644 --- a/pkg/components/listener/http2/metadata.go +++ b/pkg/components/listener/http2/metadata.go @@ -7,7 +7,6 @@ import ( ) const ( - addr = "addr" path = "path" certFile = "certFile" keyFile = "keyFile" @@ -24,7 +23,6 @@ const ( ) type metadata struct { - addr string path string tlsConfig *tls.Config handshakeTimeout time.Duration diff --git a/pkg/components/listener/kcp/listener.go b/pkg/components/listener/kcp/listener.go index a80f434..f65b477 100644 --- a/pkg/components/listener/kcp/listener.go +++ b/pkg/components/listener/kcp/listener.go @@ -1,12 +1,12 @@ package kcp import ( - "errors" "net" "time" "github.com/go-gost/gost/pkg/components/internal/utils" "github.com/go-gost/gost/pkg/components/listener" + md "github.com/go-gost/gost/pkg/components/metadata" "github.com/go-gost/gost/pkg/logger" "github.com/go-gost/gost/pkg/registry" "github.com/xtaci/kcp-go/v5" @@ -19,6 +19,7 @@ func init() { } type Listener struct { + addr string md metadata ln *kcp.Listener connChan chan net.Conn @@ -32,13 +33,13 @@ func NewListener(opts ...listener.Option) listener.Listener { opt(options) } return &Listener{ + addr: options.Addr, logger: options.Logger, } } -func (l *Listener) Init(md listener.Metadata) (err error) { - l.md, err = l.parseMetadata(md) - if err != nil { +func (l *Listener) Init(md md.Metadata) (err error) { + if err = l.parseMetadata(md); err != nil { return } @@ -52,14 +53,14 @@ func (l *Listener) Init(md listener.Metadata) (err error) { if config.TCP { var conn net.PacketConn - conn, err = tcpraw.Listen("tcp", addr) + conn, err = tcpraw.Listen("tcp", l.addr) if err != nil { return } ln, err = kcp.ServeConn( blockCrypt(config.Key, config.Crypt, Salt), config.DataShard, config.ParityShard, conn) } else { - ln, err = kcp.ListenWithOptions(addr, + ln, err = kcp.ListenWithOptions(l.addr, blockCrypt(config.Key, config.Crypt, Salt), config.DataShard, config.ParityShard) } if err != nil { @@ -168,13 +169,6 @@ func (l *Listener) mux(conn net.Conn) { } } -func (l *Listener) parseMetadata(md listener.Metadata) (m metadata, err error) { - if val, ok := md[addr]; ok { - m.addr = val - } else { - err = errors.New("missing address") - return - } - +func (l *Listener) parseMetadata(md md.Metadata) (err error) { return } diff --git a/pkg/components/listener/kcp/metadata.go b/pkg/components/listener/kcp/metadata.go index 12bc2f7..977c735 100644 --- a/pkg/components/listener/kcp/metadata.go +++ b/pkg/components/listener/kcp/metadata.go @@ -1,8 +1,6 @@ package kcp const ( - addr = "addr" - connQueueSize = "connQueueSize" ) @@ -11,7 +9,6 @@ const ( ) type metadata struct { - addr string config *Config connQueueSize int diff --git a/pkg/components/listener/listener.go b/pkg/components/listener/listener.go index ea87901..ca83627 100644 --- a/pkg/components/listener/listener.go +++ b/pkg/components/listener/listener.go @@ -3,6 +3,8 @@ package listener import ( "errors" "net" + + "github.com/go-gost/gost/pkg/components/metadata" ) var ( @@ -11,7 +13,7 @@ var ( // Listener is a server listener, just like a net.Listener. type Listener interface { - Init(Metadata) error + Init(metadata.Metadata) error net.Listener } diff --git a/pkg/components/listener/metadata.go b/pkg/components/listener/metadata.go deleted file mode 100644 index 6138db0..0000000 --- a/pkg/components/listener/metadata.go +++ /dev/null @@ -1,3 +0,0 @@ -package listener - -type Metadata map[string]string diff --git a/pkg/components/listener/obfs/http/listener.go b/pkg/components/listener/obfs/http/listener.go index 12457c5..e2661bc 100644 --- a/pkg/components/listener/obfs/http/listener.go +++ b/pkg/components/listener/obfs/http/listener.go @@ -1,13 +1,11 @@ package http import ( - "errors" "net" - "strconv" - "time" "github.com/go-gost/gost/pkg/components/internal/utils" "github.com/go-gost/gost/pkg/components/listener" + md "github.com/go-gost/gost/pkg/components/metadata" "github.com/go-gost/gost/pkg/logger" "github.com/go-gost/gost/pkg/registry" ) @@ -17,7 +15,8 @@ func init() { } type Listener struct { - md metadata + addr string + md metadata net.Listener logger logger.Logger } @@ -28,17 +27,17 @@ func NewListener(opts ...listener.Option) listener.Listener { opt(options) } return &Listener{ + addr: options.Addr, logger: options.Logger, } } -func (l *Listener) Init(md listener.Metadata) (err error) { - l.md, err = l.parseMetadata(md) - if err != nil { +func (l *Listener) Init(md md.Metadata) (err error) { + if err = l.parseMetadata(md); err != nil { return } - laddr, err := net.ResolveTCPAddr("tcp", l.md.addr) + laddr, err := net.ResolveTCPAddr("tcp", l.addr) if err != nil { return } @@ -68,22 +67,9 @@ func (l *Listener) Accept() (net.Conn, error) { return &conn{Conn: c}, nil } -func (l *Listener) parseMetadata(md listener.Metadata) (m metadata, err error) { - if val, ok := md[addr]; ok { - m.addr = val - } else { - err = errors.New("missing address") - return - } - - m.keepAlive = true - if val, ok := md[keepAlive]; ok { - m.keepAlive, _ = strconv.ParseBool(val) - } - - if val, ok := md[keepAlivePeriod]; ok { - m.keepAlivePeriod, _ = time.ParseDuration(val) - } +func (l *Listener) parseMetadata(md md.Metadata) (err error) { + l.md.keepAlive = md.GetBool(keepAlive) + l.md.keepAlivePeriod = md.GetDuration(keepAlivePeriod) return } diff --git a/pkg/components/listener/obfs/http/metadata.go b/pkg/components/listener/obfs/http/metadata.go index a96b8dd..76079f6 100644 --- a/pkg/components/listener/obfs/http/metadata.go +++ b/pkg/components/listener/obfs/http/metadata.go @@ -3,7 +3,6 @@ package http import "time" const ( - addr = "addr" keepAlive = "keepAlive" keepAlivePeriod = "keepAlivePeriod" ) @@ -13,7 +12,6 @@ const ( ) type metadata struct { - addr string keepAlive bool keepAlivePeriod time.Duration } diff --git a/pkg/components/listener/obfs/tls/listener.go b/pkg/components/listener/obfs/tls/listener.go index e751540..c58c773 100644 --- a/pkg/components/listener/obfs/tls/listener.go +++ b/pkg/components/listener/obfs/tls/listener.go @@ -1,13 +1,11 @@ package tls import ( - "errors" "net" - "strconv" - "time" "github.com/go-gost/gost/pkg/components/internal/utils" "github.com/go-gost/gost/pkg/components/listener" + md "github.com/go-gost/gost/pkg/components/metadata" "github.com/go-gost/gost/pkg/logger" "github.com/go-gost/gost/pkg/registry" ) @@ -17,7 +15,8 @@ func init() { } type Listener struct { - md metadata + addr string + md metadata net.Listener logger logger.Logger } @@ -28,17 +27,17 @@ func NewListener(opts ...listener.Option) listener.Listener { opt(options) } return &Listener{ + addr: options.Addr, logger: options.Logger, } } -func (l *Listener) Init(md listener.Metadata) (err error) { - l.md, err = l.parseMetadata(md) - if err != nil { +func (l *Listener) Init(md md.Metadata) (err error) { + if err = l.parseMetadata(md); err != nil { return } - laddr, err := net.ResolveTCPAddr("tcp", l.md.addr) + laddr, err := net.ResolveTCPAddr("tcp", l.addr) if err != nil { return } @@ -68,22 +67,9 @@ func (l *Listener) Accept() (net.Conn, error) { return &conn{Conn: c}, nil } -func (l *Listener) parseMetadata(md listener.Metadata) (m metadata, err error) { - if val, ok := md[addr]; ok { - m.addr = val - } else { - err = errors.New("missing address") - return - } - - m.keepAlive = true - if val, ok := md[keepAlive]; ok { - m.keepAlive, _ = strconv.ParseBool(val) - } - - if val, ok := md[keepAlivePeriod]; ok { - m.keepAlivePeriod, _ = time.ParseDuration(val) - } +func (l *Listener) parseMetadata(md md.Metadata) (err error) { + l.md.keepAlive = md.GetBool(keepAlive) + l.md.keepAlivePeriod = md.GetDuration(keepAlivePeriod) return } diff --git a/pkg/components/listener/obfs/tls/metadata.go b/pkg/components/listener/obfs/tls/metadata.go index 5e01136..74d5b44 100644 --- a/pkg/components/listener/obfs/tls/metadata.go +++ b/pkg/components/listener/obfs/tls/metadata.go @@ -3,7 +3,6 @@ package tls import "time" const ( - addr = "addr" keepAlive = "keepAlive" keepAlivePeriod = "keepAlivePeriod" ) @@ -13,7 +12,6 @@ const ( ) type metadata struct { - addr string keepAlive bool keepAlivePeriod time.Duration } diff --git a/pkg/components/listener/quic/listener.go b/pkg/components/listener/quic/listener.go index 24dd5f8..95f44d8 100644 --- a/pkg/components/listener/quic/listener.go +++ b/pkg/components/listener/quic/listener.go @@ -2,11 +2,11 @@ package quic import ( "context" - "errors" "net" "github.com/go-gost/gost/pkg/components/internal/utils" "github.com/go-gost/gost/pkg/components/listener" + md "github.com/go-gost/gost/pkg/components/metadata" "github.com/go-gost/gost/pkg/logger" "github.com/go-gost/gost/pkg/registry" "github.com/lucas-clemente/quic-go" @@ -17,6 +17,7 @@ func init() { } type Listener struct { + addr string md metadata ln quic.Listener connChan chan net.Conn @@ -30,17 +31,17 @@ func NewListener(opts ...listener.Option) listener.Listener { opt(options) } return &Listener{ + addr: options.Addr, logger: options.Logger, } } -func (l *Listener) Init(md listener.Metadata) (err error) { - l.md, err = l.parseMetadata(md) - if err != nil { +func (l *Listener) Init(md md.Metadata) (err error) { + if err = l.parseMetadata(md); err != nil { return } - laddr, err := net.ResolveUDPAddr("udp", l.md.addr) + laddr, err := net.ResolveUDPAddr("udp", l.addr) if err != nil { return } @@ -131,13 +132,7 @@ func (l *Listener) mux(ctx context.Context, session quic.Session) { } } -func (l *Listener) parseMetadata(md listener.Metadata) (m metadata, err error) { - if val, ok := md[addr]; ok { - m.addr = val - } else { - err = errors.New("missing address") - return - } +func (l *Listener) parseMetadata(md md.Metadata) (err error) { return } diff --git a/pkg/components/listener/quic/metadata.go b/pkg/components/listener/quic/metadata.go index 9b41637..3b84e80 100644 --- a/pkg/components/listener/quic/metadata.go +++ b/pkg/components/listener/quic/metadata.go @@ -6,8 +6,6 @@ import ( ) const ( - addr = "addr" - certFile = "certFile" keyFile = "keyFile" caFile = "caFile" @@ -21,7 +19,6 @@ const ( ) type metadata struct { - addr string tlsConfig *tls.Config keepAlive bool HandshakeTimeout time.Duration diff --git a/pkg/components/listener/tcp/listener.go b/pkg/components/listener/tcp/listener.go index 0669d42..cbd6f66 100644 --- a/pkg/components/listener/tcp/listener.go +++ b/pkg/components/listener/tcp/listener.go @@ -2,11 +2,10 @@ package tcp import ( "net" - "strconv" - "time" "github.com/go-gost/gost/pkg/components/internal/utils" "github.com/go-gost/gost/pkg/components/listener" + md "github.com/go-gost/gost/pkg/components/metadata" "github.com/go-gost/gost/pkg/logger" "github.com/go-gost/gost/pkg/registry" ) @@ -33,9 +32,8 @@ func NewListener(opts ...listener.Option) listener.Listener { } } -func (l *Listener) Init(md listener.Metadata) (err error) { - l.md, err = l.parseMetadata(md) - if err != nil { +func (l *Listener) Init(md md.Metadata) (err error) { + if err = l.parseMetadata(md); err != nil { return } @@ -57,18 +55,13 @@ func (l *Listener) Init(md listener.Metadata) (err error) { } l.Listener = ln + l.logger.Info("listening on:", l.Listener.Addr()) return } -func (l *Listener) parseMetadata(md listener.Metadata) (m metadata, err error) { - m.keepAlive = true - if val, ok := md[keepAlive]; ok { - m.keepAlive, _ = strconv.ParseBool(val) - } - - if val, ok := md[keepAlivePeriod]; ok { - m.keepAlivePeriod, _ = time.ParseDuration(val) - } +func (l *Listener) parseMetadata(md md.Metadata) (err error) { + l.md.keepAlive = md.GetBool(keepAlive) + l.md.keepAlivePeriod = md.GetDuration(keepAlivePeriod) return } diff --git a/pkg/components/listener/tls/listener.go b/pkg/components/listener/tls/listener.go index 171915b..60bbf32 100644 --- a/pkg/components/listener/tls/listener.go +++ b/pkg/components/listener/tls/listener.go @@ -2,12 +2,11 @@ package tls import ( "crypto/tls" - "errors" "net" - "time" "github.com/go-gost/gost/pkg/components/internal/utils" "github.com/go-gost/gost/pkg/components/listener" + md "github.com/go-gost/gost/pkg/components/metadata" "github.com/go-gost/gost/pkg/logger" "github.com/go-gost/gost/pkg/registry" ) @@ -17,7 +16,8 @@ func init() { } type Listener struct { - md metadata + addr string + md metadata net.Listener logger logger.Logger } @@ -28,17 +28,17 @@ func NewListener(opts ...listener.Option) listener.Listener { opt(options) } return &Listener{ + addr: options.Addr, logger: options.Logger, } } -func (l *Listener) Init(md listener.Metadata) (err error) { - l.md, err = l.parseMetadata(md) - if err != nil { +func (l *Listener) Init(md md.Metadata) (err error) { + if err = l.parseMetadata(md); err != nil { return } - ln, err := net.Listen("tcp", l.md.addr) + ln, err := net.Listen("tcp", l.addr) if err != nil { return } @@ -55,22 +55,16 @@ func (l *Listener) Init(md listener.Metadata) (err error) { return } -func (l *Listener) parseMetadata(md listener.Metadata) (m metadata, err error) { - if val, ok := md[addr]; ok { - m.addr = val - } else { - err = errors.New("missing address") - return - } - - m.tlsConfig, err = utils.LoadTLSConfig(md[certFile], md[keyFile], md[caFile]) +func (l *Listener) parseMetadata(md md.Metadata) (err error) { + l.md.tlsConfig, err = utils.LoadTLSConfig( + md.GetString(certFile), + md.GetString(keyFile), + md.GetString(caFile), + ) if err != nil { return } - if val, ok := md[keepAlivePeriod]; ok { - m.keepAlivePeriod, _ = time.ParseDuration(val) - } - + l.md.keepAlivePeriod = md.GetDuration(keepAlivePeriod) return } diff --git a/pkg/components/listener/tls/metadata.go b/pkg/components/listener/tls/metadata.go index 6b9d92d..d544654 100644 --- a/pkg/components/listener/tls/metadata.go +++ b/pkg/components/listener/tls/metadata.go @@ -6,7 +6,6 @@ import ( ) const ( - addr = "addr" certFile = "certFile" keyFile = "keyFile" caFile = "caFile" @@ -14,7 +13,6 @@ const ( ) type metadata struct { - addr string tlsConfig *tls.Config keepAlivePeriod time.Duration } diff --git a/pkg/components/listener/tls/mux/listener.go b/pkg/components/listener/tls/mux/listener.go index 3ae0db1..747ed44 100644 --- a/pkg/components/listener/tls/mux/listener.go +++ b/pkg/components/listener/tls/mux/listener.go @@ -2,11 +2,11 @@ package mux import ( "crypto/tls" - "errors" "net" "github.com/go-gost/gost/pkg/components/internal/utils" "github.com/go-gost/gost/pkg/components/listener" + md "github.com/go-gost/gost/pkg/components/metadata" "github.com/go-gost/gost/pkg/logger" "github.com/go-gost/gost/pkg/registry" "github.com/xtaci/smux" @@ -17,7 +17,8 @@ func init() { } type Listener struct { - md metadata + addr string + md metadata net.Listener connChan chan net.Conn errChan chan error @@ -30,17 +31,17 @@ func NewListener(opts ...listener.Option) listener.Listener { opt(options) } return &Listener{ + addr: options.Addr, logger: options.Logger, } } -func (l *Listener) Init(md listener.Metadata) (err error) { - l.md, err = l.parseMetadata(md) - if err != nil { +func (l *Listener) Init(md md.Metadata) (err error) { + if err = l.parseMetadata(md); err != nil { return } - ln, err := net.Listen("tcp", l.md.addr) + ln, err := net.Listen("tcp", l.addr) if err != nil { return } @@ -125,15 +126,12 @@ func (l *Listener) Accept() (conn net.Conn, err error) { return } -func (l *Listener) parseMetadata(md listener.Metadata) (m metadata, err error) { - if val, ok := md[addr]; ok { - m.addr = val - } else { - err = errors.New("missing address") - return - } - - m.tlsConfig, err = utils.LoadTLSConfig(md[certFile], md[keyFile], md[caFile]) +func (l *Listener) parseMetadata(md md.Metadata) (err error) { + l.md.tlsConfig, err = utils.LoadTLSConfig( + md.GetString(certFile), + md.GetString(keyFile), + md.GetString(caFile), + ) if err != nil { return } diff --git a/pkg/components/listener/tls/mux/metadata.go b/pkg/components/listener/tls/mux/metadata.go index 7c65039..a5d0cec 100644 --- a/pkg/components/listener/tls/mux/metadata.go +++ b/pkg/components/listener/tls/mux/metadata.go @@ -6,7 +6,6 @@ import ( ) const ( - addr = "addr" certFile = "certFile" keyFile = "keyFile" caFile = "caFile" @@ -24,7 +23,6 @@ const ( ) type metadata struct { - addr string tlsConfig *tls.Config muxKeepAliveDisabled bool diff --git a/pkg/components/listener/udp/listener.go b/pkg/components/listener/udp/listener.go index 67c75f6..1181c01 100644 --- a/pkg/components/listener/udp/listener.go +++ b/pkg/components/listener/udp/listener.go @@ -1,12 +1,12 @@ package udp import ( - "errors" "net" "sync" "sync/atomic" "github.com/go-gost/gost/pkg/components/listener" + md "github.com/go-gost/gost/pkg/components/metadata" "github.com/go-gost/gost/pkg/logger" "github.com/go-gost/gost/pkg/registry" ) @@ -16,6 +16,7 @@ func init() { } type Listener struct { + addr string md metadata conn net.PacketConn connChan chan net.Conn @@ -30,17 +31,17 @@ func NewListener(opts ...listener.Option) listener.Listener { opt(options) } return &Listener{ + addr: options.Addr, logger: options.Logger, } } -func (l *Listener) Init(md listener.Metadata) (err error) { - l.md, err = l.parseMetadata(md) - if err != nil { +func (l *Listener) Init(md md.Metadata) (err error) { + if err = l.parseMetadata(md); err != nil { return } - laddr, err := net.ResolveUDPAddr("udp", l.md.addr) + laddr, err := net.ResolveUDPAddr("udp", l.addr) if err != nil { return } @@ -124,14 +125,7 @@ func (l *Listener) listenLoop() { } } -func (l *Listener) parseMetadata(md listener.Metadata) (m metadata, err error) { - if val, ok := md[addr]; ok { - m.addr = val - } else { - err = errors.New("missing address") - return - } - +func (l *Listener) parseMetadata(md md.Metadata) (err error) { return } diff --git a/pkg/components/listener/udp/metadata.go b/pkg/components/listener/udp/metadata.go index 44f7035..ec01921 100644 --- a/pkg/components/listener/udp/metadata.go +++ b/pkg/components/listener/udp/metadata.go @@ -9,13 +9,8 @@ const ( defaultConnQueueSize = 128 ) -const ( - addr = "addr" -) - type metadata struct { - addr string - ttl time.Duration + ttl time.Duration readBufferSize int readQueueSize int diff --git a/pkg/components/listener/ws/listener.go b/pkg/components/listener/ws/listener.go index 1481a6d..43cbc4d 100644 --- a/pkg/components/listener/ws/listener.go +++ b/pkg/components/listener/ws/listener.go @@ -2,12 +2,12 @@ package ws import ( "crypto/tls" - "errors" "net" "net/http" "github.com/go-gost/gost/pkg/components/internal/utils" "github.com/go-gost/gost/pkg/components/listener" + md "github.com/go-gost/gost/pkg/components/metadata" "github.com/go-gost/gost/pkg/logger" "github.com/go-gost/gost/pkg/registry" "github.com/gorilla/websocket" @@ -19,6 +19,7 @@ func init() { } type Listener struct { + saddr string md metadata addr net.Addr upgrader *websocket.Upgrader @@ -34,13 +35,13 @@ func NewListener(opts ...listener.Option) listener.Listener { opt(options) } return &Listener{ + saddr: options.Addr, logger: options.Logger, } } -func (l *Listener) Init(md listener.Metadata) (err error) { - l.md, err = l.parseMetadata(md) - if err != nil { +func (l *Listener) Init(md md.Metadata) (err error) { + if err = l.parseMetadata(md); err != nil { return } @@ -59,7 +60,7 @@ func (l *Listener) Init(md listener.Metadata) (err error) { mux := http.NewServeMux() mux.Handle(path, http.HandlerFunc(l.upgrade)) l.srv = &http.Server{ - Addr: l.md.addr, + Addr: l.saddr, TLSConfig: l.md.tlsConfig, Handler: mux, ReadHeaderTimeout: l.md.readHeaderTimeout, @@ -72,7 +73,7 @@ func (l *Listener) Init(md listener.Metadata) (err error) { l.connChan = make(chan net.Conn, queueSize) l.errChan = make(chan error, 1) - ln, err := net.Listen("tcp", l.md.addr) + ln, err := net.Listen("tcp", l.saddr) if err != nil { return } @@ -113,15 +114,12 @@ func (l *Listener) Addr() net.Addr { return l.addr } -func (l *Listener) parseMetadata(md listener.Metadata) (m metadata, err error) { - if val, ok := md[addr]; ok { - m.addr = val - } else { - err = errors.New("missing address") - return - } - - m.tlsConfig, err = utils.LoadTLSConfig(md[certFile], md[keyFile], md[caFile]) +func (l *Listener) parseMetadata(md md.Metadata) (err error) { + l.md.tlsConfig, err = utils.LoadTLSConfig( + md.GetString(certFile), + md.GetString(keyFile), + md.GetString(caFile), + ) if err != nil { return } diff --git a/pkg/components/listener/ws/metadata.go b/pkg/components/listener/ws/metadata.go index 017195a..95f81e1 100644 --- a/pkg/components/listener/ws/metadata.go +++ b/pkg/components/listener/ws/metadata.go @@ -7,7 +7,6 @@ import ( ) const ( - addr = "addr" path = "path" certFile = "certFile" keyFile = "keyFile" @@ -27,7 +26,6 @@ const ( ) type metadata struct { - addr string path string tlsConfig *tls.Config handshakeTimeout time.Duration diff --git a/pkg/components/listener/ws/mux/listener.go b/pkg/components/listener/ws/mux/listener.go index ae9ff96..2b99c17 100644 --- a/pkg/components/listener/ws/mux/listener.go +++ b/pkg/components/listener/ws/mux/listener.go @@ -2,12 +2,12 @@ package mux import ( "crypto/tls" - "errors" "net" "net/http" "github.com/go-gost/gost/pkg/components/internal/utils" "github.com/go-gost/gost/pkg/components/listener" + md "github.com/go-gost/gost/pkg/components/metadata" "github.com/go-gost/gost/pkg/logger" "github.com/go-gost/gost/pkg/registry" "github.com/gorilla/websocket" @@ -20,6 +20,7 @@ func init() { } type Listener struct { + saddr string md metadata addr net.Addr upgrader *websocket.Upgrader @@ -39,9 +40,8 @@ func NewListener(opts ...listener.Option) listener.Listener { } } -func (l *Listener) Init(md listener.Metadata) (err error) { - l.md, err = l.parseMetadata(md) - if err != nil { +func (l *Listener) Init(md md.Metadata) (err error) { + if err = l.parseMetadata(md); err != nil { return } @@ -60,7 +60,7 @@ func (l *Listener) Init(md listener.Metadata) (err error) { mux := http.NewServeMux() mux.Handle(path, http.HandlerFunc(l.upgrade)) l.srv = &http.Server{ - Addr: l.md.addr, + Addr: l.saddr, TLSConfig: l.md.tlsConfig, Handler: mux, ReadHeaderTimeout: l.md.readHeaderTimeout, @@ -69,7 +69,7 @@ func (l *Listener) Init(md listener.Metadata) (err error) { l.connChan = make(chan net.Conn, l.md.connQueueSize) l.errChan = make(chan error, 1) - ln, err := net.Listen("tcp", l.md.addr) + ln, err := net.Listen("tcp", l.saddr) if err != nil { return } @@ -110,15 +110,12 @@ func (l *Listener) Addr() net.Addr { return l.addr } -func (l *Listener) parseMetadata(md listener.Metadata) (m metadata, err error) { - if val, ok := md[addr]; ok { - m.addr = val - } else { - err = errors.New("missing address") - return - } - - m.tlsConfig, err = utils.LoadTLSConfig(md[certFile], md[keyFile], md[caFile]) +func (l *Listener) parseMetadata(md md.Metadata) (err error) { + l.md.tlsConfig, err = utils.LoadTLSConfig( + md.GetString(certFile), + md.GetString(keyFile), + md.GetString(caFile), + ) if err != nil { return } diff --git a/pkg/components/listener/ws/mux/metadata.go b/pkg/components/listener/ws/mux/metadata.go index 89d89c0..8233ec3 100644 --- a/pkg/components/listener/ws/mux/metadata.go +++ b/pkg/components/listener/ws/mux/metadata.go @@ -7,7 +7,6 @@ import ( ) const ( - addr = "addr" path = "path" certFile = "certFile" keyFile = "keyFile" @@ -34,7 +33,6 @@ const ( ) type metadata struct { - addr string path string tlsConfig *tls.Config handshakeTimeout time.Duration diff --git a/pkg/components/metadata/metadata.go b/pkg/components/metadata/metadata.go new file mode 100644 index 0000000..698f36f --- /dev/null +++ b/pkg/components/metadata/metadata.go @@ -0,0 +1,61 @@ +package metadata + +import "time" + +type Metadata interface { + Get(key string) interface{} + GetBool(key string) bool + GetInt(key string) int + GetFloat(key string) float64 + GetString(key string) string + GetDuration(key string) time.Duration +} + +type MapMetadata map[string]interface{} + +func (m MapMetadata) Get(key string) interface{} { + if m != nil { + return m[key] + } + return nil +} + +func (m MapMetadata) GetBool(key string) (v bool) { + if m != nil { + v, _ = m[key].(bool) + } + return +} + +func (m MapMetadata) GetInt(key string) (v int) { + if m != nil { + v, _ = m[key].(int) + } + return +} + +func (m MapMetadata) GetFloat(key string) (v float64) { + if m != nil { + v, _ = m[key].(float64) + } + return +} + +func (m MapMetadata) GetString(key string) (v string) { + if m != nil { + v, _ = m[key].(string) + } + return +} + +func (m MapMetadata) GetDuration(key string) (v time.Duration) { + if m != nil { + switch vv := m[key].(type) { + case int: + return time.Duration(vv) * time.Second + case string: + v, _ = time.ParseDuration(vv) + } + } + return +} diff --git a/pkg/config/config.go b/pkg/config/config.go index e8d1a20..6602f83 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -18,6 +18,7 @@ func init() { } type LogConfig struct { + Output string Level string Format string } @@ -29,22 +30,22 @@ type LoadbalancingConfig struct { type ListenerConfig struct { Type string - Metadata map[string]string + Metadata map[string]interface{} } type HandlerConfig struct { Type string - Metadata map[string]string + Metadata map[string]interface{} } type DialerConfig struct { Type string - Metadata map[string]string + Metadata map[string]interface{} } type ConnectorConfig struct { Type string - Metadata map[string]string + Metadata map[string]interface{} } type ServiceConfig struct { diff --git a/pkg/logger/gost_logger.go b/pkg/logger/gost_logger.go index 5598298..4783065 100644 --- a/pkg/logger/gost_logger.go +++ b/pkg/logger/gost_logger.go @@ -1,43 +1,17 @@ package logger import ( - "os" + "fmt" + "path/filepath" + "runtime" "github.com/sirupsen/logrus" ) -var ( - _ Logger = (*logger)(nil) -) - type logger struct { logger *logrus.Entry } -func newLogger(name string) *logger { - l := logrus.New() - l.SetOutput(os.Stdout) - - gl := &logger{ - logger: l.WithFields(logrus.Fields{ - logFieldScope: name, - }), - } - - return gl -} - -// EnableJSONOutput enables JSON formatted output log. -func (l *logger) EnableJSONOutput(enabled bool) { - l.logger.Logger.SetFormatter(&logrus.JSONFormatter{}) -} - -// SetOutputLevel sets log output level -func (l *logger) SetLevel(level LogLevel) { - lvl, _ := logrus.ParseLevel(string(level)) - l.logger.Logger.SetLevel(lvl) -} - // WithFields adds new fields to log. func (l *logger) WithFields(fields map[string]interface{}) Logger { return &logger{ @@ -47,50 +21,85 @@ func (l *logger) WithFields(fields map[string]interface{}) Logger { // Debug logs a message at level Debug. func (l *logger) Debug(args ...interface{}) { - l.logger.Debug(args...) + l.log(logrus.DebugLevel, args...) } // Debugf logs a message at level Debug. func (l *logger) Debugf(format string, args ...interface{}) { - l.logger.Debugf(format, args...) + l.logf(logrus.DebugLevel, format, args...) } // Info logs a message at level Info. func (l *logger) Info(args ...interface{}) { - l.logger.Info(args...) + l.log(logrus.InfoLevel, args...) } // Infof logs a message at level Info. func (l *logger) Infof(format string, args ...interface{}) { - l.logger.Infof(format, args...) + l.logf(logrus.InfoLevel, format, args...) } // Warn logs a message at level Warn. func (l *logger) Warn(args ...interface{}) { - l.logger.Warn(args...) + l.log(logrus.WarnLevel, args...) } // Warnf logs a message at level Warn. func (l *logger) Warnf(format string, args ...interface{}) { - l.logger.Warnf(format, args...) + l.logf(logrus.WarnLevel, format, args...) } // Error logs a message at level Error. func (l *logger) Error(args ...interface{}) { - l.logger.Error(args...) + l.log(logrus.ErrorLevel, args...) } // Errorf logs a message at level Error. func (l *logger) Errorf(format string, args ...interface{}) { - l.logger.Errorf(format, args...) + l.logf(logrus.ErrorLevel, format, args...) } // Fatal logs a message at level Fatal then the process will exit with status set to 1. func (l *logger) Fatal(args ...interface{}) { - l.logger.Fatal(args...) + l.log(logrus.FatalLevel, args...) } // Fatalf logs a message at level Fatal then the process will exit with status set to 1. func (l *logger) Fatalf(format string, args ...interface{}) { - l.logger.Fatalf(format, args...) + l.logf(logrus.FatalLevel, format, args...) +} + +func (l *logger) GetLevel() LogLevel { + return LogLevel(l.logger.Logger.GetLevel().String()) +} + +func (l *logger) IsLevelEnabled(level LogLevel) bool { + lvl, _ := logrus.ParseLevel(string(level)) + return l.logger.Logger.IsLevelEnabled(lvl) +} + +func (l *logger) log(level logrus.Level, args ...interface{}) { + lg := l.logger + if l.logger.Logger.IsLevelEnabled(logrus.DebugLevel) { + lg = lg.WithField("caller", l.caller(3)) + } + lg.Log(level, args...) +} + +func (l *logger) logf(level logrus.Level, format string, args ...interface{}) { + lg := l.logger + if l.logger.Logger.IsLevelEnabled(logrus.DebugLevel) { + lg = lg.WithField("caller", l.caller(3)) + } + lg.Logf(level, format, args...) +} + +func (l *logger) caller(skip int) string { + _, file, line, ok := runtime.Caller(skip) + if !ok { + file = "" + } else { + file = filepath.Join(filepath.Base(filepath.Dir(file)), filepath.Base(file)) + } + return fmt.Sprintf("%s:%d", file, line) } diff --git a/pkg/logger/logger.go b/pkg/logger/logger.go index c1e05b4..3a44f2b 100644 --- a/pkg/logger/logger.go +++ b/pkg/logger/logger.go @@ -1,9 +1,17 @@ package logger -import "sync" +import ( + "io" + + "github.com/sirupsen/logrus" +) + +// LogFormat is format type +type LogFormat string const ( - logFieldScope = "scope" + TextFormat LogFormat = "text" + JSONFormat LogFormat = "json" ) // LogLevel is Logger Level type @@ -22,14 +30,7 @@ const ( FatalLevel LogLevel = "fatal" ) -var ( - globalLoggers = make(map[string]Logger) - globalLoggersLock sync.RWMutex -) - type Logger interface { - EnableJSONOutput(enabled bool) - SetLevel(level LogLevel) WithFields(map[string]interface{}) Logger Debug(args ...interface{}) Debugf(format string, args ...interface{}) @@ -41,17 +42,65 @@ type Logger interface { Errorf(format string, args ...interface{}) Fatal(args ...interface{}) Fatalf(format string, args ...interface{}) + GetLevel() LogLevel + IsLevelEnabled(level LogLevel) bool } -func NewLogger(name string) Logger { - globalLoggersLock.Lock() - defer globalLoggersLock.Unlock() +type LoggerOptions struct { + Output io.Writer + Format LogFormat + Level LogLevel +} - logger, ok := globalLoggers[name] - if !ok { - logger = newLogger(name) - globalLoggers[name] = logger +type LoggerOption func(opts *LoggerOptions) + +func OutputLoggerOption(out io.Writer) LoggerOption { + return func(opts *LoggerOptions) { + opts.Output = out + } +} + +func FormatLoggerOption(format LogFormat) LoggerOption { + return func(opts *LoggerOptions) { + opts.Format = format + } +} + +func LevelLoggerOption(level LogLevel) LoggerOption { + return func(opts *LoggerOptions) { + opts.Level = level + } +} + +func NewLogger(opts ...LoggerOption) Logger { + var options LoggerOptions + for _, opt := range opts { + opt(&options) } - return logger + log := logrus.New() + if options.Output != nil { + log.SetOutput(options.Output) + } + + switch options.Format { + case JSONFormat: + log.SetFormatter(&logrus.JSONFormatter{}) + default: + log.SetFormatter(&logrus.TextFormatter{ + FullTimestamp: true, + }) + } + + switch options.Level { + case DebugLevel, InfoLevel, WarnLevel, ErrorLevel, FatalLevel: + lvl, _ := logrus.ParseLevel(string(options.Level)) + log.SetLevel(lvl) + default: + log.SetLevel(logrus.InfoLevel) + } + + return &logger{ + logger: logrus.NewEntry(log), + } }