From d6f8ec5116eb8bd7472c0381645bf964299ed27d Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Mon, 11 Apr 2022 00:03:04 +0800 Subject: [PATCH] add file and redis loader --- admission/admission.go | 194 ++++++++++++++++++++---- auth/auth.go | 200 ++++++++++++++++++++++++- bypass/bypass.go | 177 ++++++++++++++++++++-- config/config.go | 43 ++++-- config/parsing/parse.go | 76 ++++++++-- connector/relay/connector.go | 12 +- go.mod | 2 + go.sum | 6 +- handler/relay/connect.go | 8 +- internal/loader/file.go | 30 ++++ internal/loader/loader.go | 11 ++ internal/loader/redis.go | 124 +++++++++++++++ internal/{util => }/matcher/matcher.go | 0 registry/registry.go | 8 +- 14 files changed, 805 insertions(+), 86 deletions(-) create mode 100644 internal/loader/file.go create mode 100644 internal/loader/loader.go create mode 100644 internal/loader/redis.go rename internal/{util => }/matcher/matcher.go (100%) diff --git a/admission/admission.go b/admission/admission.go index 32cbb64..19817c9 100644 --- a/admission/admission.go +++ b/admission/admission.go @@ -1,19 +1,61 @@ package admission import ( + "bufio" + "context" + "io" "net" + "strings" + "sync" + "time" admission_pkg "github.com/go-gost/core/admission" "github.com/go-gost/core/logger" - "github.com/go-gost/x/internal/util/matcher" + "github.com/go-gost/x/internal/loader" + "github.com/go-gost/x/internal/matcher" ) type options struct { - logger logger.Logger + reverse bool + matchers []string + fileLoader loader.Loader + redisLoader loader.Loader + period time.Duration + logger logger.Logger } type Option func(opts *options) +func ReverseOption(reverse bool) Option { + return func(opts *options) { + opts.reverse = reverse + } +} + +func MatchersOption(matchers []string) Option { + return func(opts *options) { + opts.matchers = matchers + } +} + +func ReloadPeriodOption(period time.Duration) Option { + return func(opts *options) { + opts.period = period + } +} + +func FileLoaderOption(fileLoader loader.Loader) Option { + return func(opts *options) { + opts.fileLoader = fileLoader + } +} + +func RedisLoaderOption(redisLoader loader.Loader) Option { + return func(opts *options) { + opts.redisLoader = redisLoader + } +} + func LoggerOption(logger logger.Logger) Option { return func(opts *options) { opts.logger = logger @@ -23,18 +65,81 @@ func LoggerOption(logger logger.Logger) Option { type admission struct { ipMatcher matcher.Matcher cidrMatcher matcher.Matcher - reversed bool + mu sync.RWMutex + cancelFunc context.CancelFunc options options } -// NewAdmissionPatterns creates and initializes a new Admission using matcher patterns as its match rules. +// NewAdmission creates and initializes a new Admission using matcher patterns as its match rules. // The rules will be reversed if the reverse is true. -func NewAdmission(reversed bool, patterns []string, opts ...Option) admission_pkg.Admission { +func NewAdmission(opts ...Option) admission_pkg.Admission { var options options for _, opt := range opts { opt(&options) } + ctx, cancel := context.WithCancel(context.TODO()) + p := &admission{ + cancelFunc: cancel, + options: options, + } + + if err := p.reload(ctx); err != nil { + options.logger.Warnf("reload: %v", err) + } + if p.options.period > 0 { + go p.periodReload(ctx) + } + + return p +} + +func (p *admission) Admit(addr string) bool { + if addr == "" || p == nil { + return false + } + + // try to strip the port + if host, _, _ := net.SplitHostPort(addr); host != "" { + addr = host + } + + matched := p.matched(addr) + + b := !p.options.reverse && matched || + p.options.reverse && !matched + return b +} + +func (p *admission) periodReload(ctx context.Context) error { + period := p.options.period + if period < time.Second { + period = time.Second + } + ticker := time.NewTicker(period) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + if err := p.reload(ctx); err != nil { + p.options.logger.Warnf("reload: %v", err) + // return err + } + p.options.logger.Debugf("admission reload done") + case <-ctx.Done(): + return ctx.Err() + } + } +} + +func (p *admission) reload(ctx context.Context) error { + v, err := p.load(ctx) + if err != nil { + return err + } + patterns := append(p.options.matchers, v...) + var ips []net.IP var inets []*net.IPNet for _, pattern := range patterns { @@ -47,36 +152,75 @@ func NewAdmission(reversed bool, patterns []string, opts ...Option) admission_pk continue } } - return &admission{ - reversed: reversed, - options: options, - ipMatcher: matcher.IPMatcher(ips), - cidrMatcher: matcher.CIDRMatcher(inets), - } + + p.mu.Lock() + defer p.mu.Unlock() + + p.ipMatcher = matcher.IPMatcher(ips) + p.cidrMatcher = matcher.CIDRMatcher(inets) + + return nil } -func (p *admission) Admit(addr string) bool { - if addr == "" || p == nil { - p.options.logger.Debugf("admission: %v is denied", addr) - return false +func (p *admission) load(ctx context.Context) (patterns []string, err error) { + if p.options.fileLoader != nil { + r, er := p.options.fileLoader.Load(ctx) + if er != nil { + p.options.logger.Warnf("file loader: %v", er) + } + if v, _ := p.parsePatterns(r); v != nil { + patterns = append(patterns, v...) + } + } + if p.options.redisLoader != nil { + r, er := p.options.redisLoader.Load(ctx) + if er != nil { + p.options.logger.Warnf("redis loader: %v", er) + } + if v, _ := p.parsePatterns(r); v != nil { + patterns = append(patterns, v...) + } } - // try to strip the port - if host, _, _ := net.SplitHostPort(addr); host != "" { - addr = host + return +} + +func (p *admission) parsePatterns(r io.Reader) (patterns []string, err error) { + if r == nil { + return } - matched := p.matched(addr) - - b := !p.reversed && matched || - p.reversed && !matched - if !b { - p.options.logger.Debugf("admission: %v is denied", addr) + scanner := bufio.NewScanner(r) + for scanner.Scan() { + line := scanner.Text() + if n := strings.IndexByte(line, '#'); n >= 0 { + line = line[:n] + } + line = strings.TrimSpace(line) + if line != "" { + patterns = append(patterns, line) + } } - return b + + err = scanner.Err() + return } func (p *admission) matched(addr string) bool { + p.mu.RLock() + defer p.mu.RUnlock() + return p.ipMatcher.Match(addr) || p.cidrMatcher.Match(addr) } + +func (p *admission) Close() error { + p.cancelFunc() + if p.options.fileLoader != nil { + p.options.fileLoader.Close() + } + if p.options.redisLoader != nil { + p.options.redisLoader.Close() + } + return nil +} diff --git a/auth/auth.go b/auth/auth.go index e95f287..b8a3d9e 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -1,27 +1,213 @@ package auth import ( + "bufio" + "context" + "io" + "strings" + "sync" + "time" + "github.com/go-gost/core/auth" + "github.com/go-gost/core/logger" + "github.com/go-gost/x/internal/loader" ) +type options struct { + auths map[string]string + fileLoader loader.Loader + redisLoader loader.Loader + period time.Duration + logger logger.Logger +} + +type Option func(opts *options) + +func AuthsPeriodOption(auths map[string]string) Option { + return func(opts *options) { + opts.auths = auths + } +} + +func ReloadPeriodOption(period time.Duration) Option { + return func(opts *options) { + opts.period = period + } +} + +func FileLoaderOption(fileLoader loader.Loader) Option { + return func(opts *options) { + opts.fileLoader = fileLoader + } +} + +func RedisLoaderOption(redisLoader loader.Loader) Option { + return func(opts *options) { + opts.redisLoader = redisLoader + } +} + +func LoggerOption(logger logger.Logger) Option { + return func(opts *options) { + opts.logger = logger + } +} + // authenticator is an Authenticator that authenticates client by key-value pairs. type authenticator struct { - kvs map[string]string + kvs map[string]string + mu sync.RWMutex + cancelFunc context.CancelFunc + options options } // NewAuthenticator creates an Authenticator that authenticates client by pre-defined user mapping. -func NewAuthenticator(kvs map[string]string) auth.Authenticator { - return &authenticator{ - kvs: kvs, +func NewAuthenticator(opts ...Option) auth.Authenticator { + var options options + for _, opt := range opts { + opt(&options) } + + ctx, cancel := context.WithCancel(context.TODO()) + p := &authenticator{ + kvs: make(map[string]string), + cancelFunc: cancel, + options: options, + } + + if err := p.reload(ctx); err != nil { + options.logger.Warnf("reload: %v", err) + } + if p.options.period > 0 { + go p.periodReload(ctx) + } + + return p } // Authenticate checks the validity of the provided user-password pair. -func (au *authenticator) Authenticate(user, password string) bool { - if au == nil || len(au.kvs) == 0 { +func (p *authenticator) Authenticate(user, password string) bool { + if p == nil || len(p.kvs) == 0 { return true } - v, ok := au.kvs[user] + p.mu.RLock() + defer p.mu.RUnlock() + + v, ok := p.kvs[user] return ok && (v == "" || password == v) } + +func (p *authenticator) periodReload(ctx context.Context) error { + period := p.options.period + if period < time.Second { + period = time.Second + } + ticker := time.NewTicker(period) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + if err := p.reload(ctx); err != nil { + p.options.logger.Warnf("reload: %v", err) + // return err + } + p.options.logger.Debugf("auther reload done") + case <-ctx.Done(): + return ctx.Err() + } + } +} + +func (p *authenticator) reload(ctx context.Context) error { + kvs := make(map[string]string) + for k, v := range p.options.auths { + kvs[k] = v + } + + m, err := p.load(ctx) + if err != nil { + return err + } + for k, v := range m { + kvs[k] = v + } + + p.mu.Lock() + defer p.mu.Unlock() + + p.kvs = kvs + + return nil +} + +func (p *authenticator) load(ctx context.Context) (m map[string]string, err error) { + m = make(map[string]string) + + if p.options.fileLoader != nil { + r, er := p.options.fileLoader.Load(ctx) + if er != nil { + p.options.logger.Warnf("file loader: %v", er) + } + if auths, _ := p.parseAuths(r); auths != nil { + for k, v := range auths { + m[k] = v + } + } + } + if p.options.redisLoader != nil { + r, er := p.options.redisLoader.Load(ctx) + if er != nil { + p.options.logger.Warnf("redis loader: %v", er) + } + if auths, _ := p.parseAuths(r); auths != nil { + for k, v := range auths { + m[k] = v + } + } + } + + return +} + +func (p *authenticator) parseAuths(r io.Reader) (auths map[string]string, err error) { + if r == nil { + return + } + + auths = make(map[string]string) + + scanner := bufio.NewScanner(r) + for scanner.Scan() { + line := scanner.Text() + if n := strings.IndexByte(line, '#'); n >= 0 { + line = line[:n] + } + sp := strings.SplitN(strings.TrimSpace(line), " ", 2) + if len(sp) == 1 { + if k := strings.TrimSpace(sp[0]); k != "" { + auths[k] = "" + } + } + if len(sp) == 2 { + if k := strings.TrimSpace(sp[0]); k != "" { + auths[k] = strings.TrimSpace(sp[1]) + } + } + } + + err = scanner.Err() + return +} + +func (p *authenticator) Close() error { + p.cancelFunc() + if p.options.fileLoader != nil { + p.options.fileLoader.Close() + } + if p.options.redisLoader != nil { + p.options.redisLoader.Close() + } + return nil +} diff --git a/bypass/bypass.go b/bypass/bypass.go index 45a7450..c60a545 100644 --- a/bypass/bypass.go +++ b/bypass/bypass.go @@ -1,20 +1,61 @@ package bypass import ( + "bufio" + "context" + "io" "net" "strings" + "sync" + "time" bypass_pkg "github.com/go-gost/core/bypass" "github.com/go-gost/core/logger" - "github.com/go-gost/x/internal/util/matcher" + "github.com/go-gost/x/internal/loader" + "github.com/go-gost/x/internal/matcher" ) type options struct { - logger logger.Logger + reverse bool + matchers []string + fileLoader loader.Loader + redisLoader loader.Loader + period time.Duration + logger logger.Logger } type Option func(opts *options) +func ReverseOption(reverse bool) Option { + return func(opts *options) { + opts.reverse = reverse + } +} + +func MatchersOption(matchers []string) Option { + return func(opts *options) { + opts.matchers = matchers + } +} + +func ReloadPeriodOption(period time.Duration) Option { + return func(opts *options) { + opts.period = period + } +} + +func FileLoaderOption(fileLoader loader.Loader) Option { + return func(opts *options) { + opts.fileLoader = fileLoader + } +} + +func RedisLoaderOption(redisLoader loader.Loader) Option { + return func(opts *options) { + opts.redisLoader = redisLoader + } +} + func LoggerOption(logger logger.Logger) Option { return func(opts *options) { opts.logger = logger @@ -26,18 +67,65 @@ type bypass struct { cidrMatcher matcher.Matcher domainMatcher matcher.Matcher wildcardMatcher matcher.Matcher - reversed bool + mu sync.RWMutex + cancelFunc context.CancelFunc options options } -// NewBypassPatterns creates and initializes a new Bypass using matcher patterns as its match rules. -// The rules will be reversed if the reverse is true. -func NewBypass(reversed bool, patterns []string, opts ...Option) bypass_pkg.Bypass { +// NewBypass creates and initializes a new Bypass. +// The rules will be reversed if the reverse option is true. +func NewBypass(opts ...Option) bypass_pkg.Bypass { var options options for _, opt := range opts { opt(&options) } + ctx, cancel := context.WithCancel(context.TODO()) + + bp := &bypass{ + cancelFunc: cancel, + options: options, + } + + if err := bp.reload(ctx); err != nil { + options.logger.Warnf("reload: %v", err) + } + if bp.options.period > 0 { + go bp.periodReload(ctx) + } + + return bp +} + +func (bp *bypass) periodReload(ctx context.Context) error { + period := bp.options.period + if period < time.Second { + period = time.Second + } + ticker := time.NewTicker(period) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + if err := bp.reload(ctx); err != nil { + bp.options.logger.Warnf("reload: %v", err) + // return err + } + bp.options.logger.Debugf("bypass reload done") + case <-ctx.Done(): + return ctx.Err() + } + } +} + +func (bp *bypass) reload(ctx context.Context) error { + v, err := bp.load(ctx) + if err != nil { + return err + } + patterns := append(bp.options.matchers, v...) + var ips []net.IP var inets []*net.IPNet var domains []string @@ -56,16 +144,61 @@ func NewBypass(reversed bool, patterns []string, opts ...Option) bypass_pkg.Bypa continue } domains = append(domains, pattern) + } + bp.mu.Lock() + defer bp.mu.Unlock() + + bp.ipMatcher = matcher.IPMatcher(ips) + bp.cidrMatcher = matcher.CIDRMatcher(inets) + bp.domainMatcher = matcher.DomainMatcher(domains) + bp.wildcardMatcher = matcher.WildcardMatcher(wildcards) + + return nil +} + +func (bp *bypass) load(ctx context.Context) (patterns []string, err error) { + if bp.options.fileLoader != nil { + r, er := bp.options.fileLoader.Load(ctx) + if er != nil { + bp.options.logger.Warnf("file loader: %v", er) + } + if v, _ := bp.parsePatterns(r); v != nil { + patterns = append(patterns, v...) + } } - return &bypass{ - reversed: reversed, - options: options, - ipMatcher: matcher.IPMatcher(ips), - cidrMatcher: matcher.CIDRMatcher(inets), - domainMatcher: matcher.DomainMatcher(domains), - wildcardMatcher: matcher.WildcardMatcher(wildcards), + if bp.options.redisLoader != nil { + r, er := bp.options.redisLoader.Load(ctx) + if er != nil { + bp.options.logger.Warnf("redis loader: %v", er) + } + if v, _ := bp.parsePatterns(r); v != nil { + patterns = append(patterns, v...) + } } + + return +} + +func (bp *bypass) parsePatterns(r io.Reader) (patterns []string, err error) { + if r == nil { + return + } + + scanner := bufio.NewScanner(r) + for scanner.Scan() { + line := scanner.Text() + if n := strings.IndexByte(line, '#'); n >= 0 { + line = line[:n] + } + line = strings.TrimSpace(line) + if line != "" { + patterns = append(patterns, line) + } + } + + err = scanner.Err() + return } func (bp *bypass) Contains(addr string) bool { @@ -80,8 +213,8 @@ func (bp *bypass) Contains(addr string) bool { matched := bp.matched(addr) - b := !bp.reversed && matched || - bp.reversed && !matched + b := !bp.options.reverse && matched || + bp.options.reverse && !matched if b { bp.options.logger.Debugf("bypass: %s", addr) } @@ -89,6 +222,9 @@ func (bp *bypass) Contains(addr string) bool { } func (bp *bypass) matched(addr string) bool { + bp.mu.RLock() + defer bp.mu.RUnlock() + if ip := net.ParseIP(addr); ip != nil { return bp.ipMatcher.Match(addr) || bp.cidrMatcher.Match(addr) @@ -97,3 +233,14 @@ func (bp *bypass) matched(addr string) bool { return bp.domainMatcher.Match(addr) || bp.wildcardMatcher.Match(addr) } + +func (bp *bypass) Close() error { + bp.cancelFunc() + if bp.options.fileLoader != nil { + bp.options.fileLoader.Close() + } + if bp.options.redisLoader != nil { + bp.options.redisLoader.Close() + } + return nil +} diff --git a/config/config.go b/config/config.go index 2172281..2a061e2 100644 --- a/config/config.go +++ b/config/config.go @@ -74,11 +74,11 @@ type TLSConfig struct { } type AutherConfig struct { - Name string `json:"name"` - // inline, file, redis, etc. - Type string `yaml:",omitempty" json:"type,omitempty"` - Auths []*AuthConfig `yaml:",omitempty" json:"auths"` - // File string `yaml:",omitempty" json:"file"` + Name string `json:"name"` + Auths []*AuthConfig `yaml:",omitempty" json:"auths"` + Reload time.Duration `yaml:",omitempty" json:"reload,omitempty"` + File *FileLoader `yaml:",omitempty" json:"file,omitempty"` + Redis *RedisLoader `yaml:",omitempty" json:"redis,omitempty"` } type AuthConfig struct { @@ -93,19 +93,32 @@ type SelectorConfig struct { } type AdmissionConfig struct { - Name string `json:"name"` - // inline, file, etc. - Type string `yaml:",omitempty" json:"type,omitempty"` - Reverse bool `yaml:",omitempty" json:"reverse,omitempty"` - Matchers []string `json:"matchers"` + Name string `json:"name"` + Reverse bool `yaml:",omitempty" json:"reverse,omitempty"` + Matchers []string `json:"matchers"` + Reload time.Duration `yaml:",omitempty" json:"reload,omitempty"` + File *FileLoader `yaml:",omitempty" json:"file,omitempty"` + Redis *RedisLoader `yaml:",omitempty" json:"redis,omitempty"` } type BypassConfig struct { - Name string `json:"name"` - // inline, file, etc. - Type string `yaml:",omitempty" json:"type,omitempty"` - Reverse bool `yaml:",omitempty" json:"reverse,omitempty"` - Matchers []string `json:"matchers"` + Name string `json:"name"` + Reverse bool `yaml:",omitempty" json:"reverse,omitempty"` + Matchers []string `json:"matchers"` + Reload time.Duration `yaml:",omitempty" json:"reload,omitempty"` + File *FileLoader `yaml:",omitempty" json:"file,omitempty"` + Redis *RedisLoader `yaml:",omitempty" json:"redis,omitempty"` +} + +type FileLoader struct { + Path string `json:"path"` +} + +type RedisLoader struct { + Addr string `yaml:",omitempty" json:"addr,omitempty"` + DB int `yaml:",omitempty" json:"db,omitempty"` + Password string `yaml:",omitempty" json:"password,omitempty"` + Key string `yaml:",omitempty" json:"key,omitempty"` } type NameserverConfig struct { diff --git a/config/parsing/parse.go b/config/parsing/parse.go index 139b681..221f2cc 100644 --- a/config/parsing/parse.go +++ b/config/parsing/parse.go @@ -16,6 +16,7 @@ import ( bypass_impl "github.com/go-gost/x/bypass" "github.com/go-gost/x/config" hosts_impl "github.com/go-gost/x/hosts" + "github.com/go-gost/x/internal/loader" "github.com/go-gost/x/registry" resolver_impl "github.com/go-gost/x/resolver" ) @@ -34,19 +35,39 @@ func ParseAuther(cfg *config.AutherConfig) auth.Authenticator { m[user.Username] = user.Password } - if len(m) == 0 { - return nil + opts := []auth_impl.Option{ + auth_impl.AuthsPeriodOption(m), + auth_impl.ReloadPeriodOption(cfg.Reload), + auth_impl.LoggerOption(logger.Default().WithFields(map[string]any{ + "kind": "auther", + "auther": cfg.Name, + })), } - return auth_impl.NewAuthenticator(m) + if cfg.File != nil && cfg.File.Path != "" { + opts = append(opts, auth_impl.FileLoaderOption(loader.FileLoader(cfg.File.Path))) + } + if cfg.Redis != nil && cfg.Redis.Addr != "" { + opts = append(opts, auth_impl.RedisLoaderOption(loader.RedisHashLoader( + cfg.Redis.Addr, + loader.DBRedisLoaderOption(cfg.Redis.DB), + loader.PasswordRedisLoaderOption(cfg.Redis.Password), + loader.KeyRedisLoaderOption(cfg.Redis.Key), + ))) + } + + return auth_impl.NewAuthenticator(opts...) } func ParseAutherFromAuth(au *config.AuthConfig) auth.Authenticator { if au == nil || au.Username == "" { return nil } - return auth_impl.NewAuthenticator(map[string]string{ - au.Username: au.Password, - }) + return auth_impl.NewAuthenticator( + auth_impl.AuthsPeriodOption( + map[string]string{ + au.Username: au.Password, + }, + )) } func parseAuth(cfg *config.AuthConfig) *url.Userinfo { @@ -88,28 +109,55 @@ func ParseAdmission(cfg *config.AdmissionConfig) admission.Admission { if cfg == nil { return nil } - return admission_impl.NewAdmission( - cfg.Reverse, - cfg.Matchers, + opts := []admission_impl.Option{ + admission_impl.MatchersOption(cfg.Matchers), + admission_impl.ReverseOption(cfg.Reverse), + admission_impl.ReloadPeriodOption(cfg.Reload), admission_impl.LoggerOption(logger.Default().WithFields(map[string]any{ "kind": "admission", "admission": cfg.Name, })), - ) + } + if cfg.File != nil && cfg.File.Path != "" { + opts = append(opts, admission_impl.FileLoaderOption(loader.FileLoader(cfg.File.Path))) + } + if cfg.Redis != nil && cfg.Redis.Addr != "" { + opts = append(opts, admission_impl.RedisLoaderOption(loader.RedisSetLoader( + cfg.Redis.Addr, + loader.DBRedisLoaderOption(cfg.Redis.DB), + loader.PasswordRedisLoaderOption(cfg.Redis.Password), + loader.KeyRedisLoaderOption(cfg.Redis.Key), + ))) + } + return admission_impl.NewAdmission(opts...) } func ParseBypass(cfg *config.BypassConfig) bypass.Bypass { if cfg == nil { return nil } - return bypass_impl.NewBypass( - cfg.Reverse, - cfg.Matchers, + + opts := []bypass_impl.Option{ + bypass_impl.MatchersOption(cfg.Matchers), + bypass_impl.ReverseOption(cfg.Reverse), + bypass_impl.ReloadPeriodOption(cfg.Reload), bypass_impl.LoggerOption(logger.Default().WithFields(map[string]any{ "kind": "bypass", "bypass": cfg.Name, })), - ) + } + if cfg.File != nil && cfg.File.Path != "" { + opts = append(opts, bypass_impl.FileLoaderOption(loader.FileLoader(cfg.File.Path))) + } + if cfg.Redis != nil && cfg.Redis.Addr != "" { + opts = append(opts, bypass_impl.RedisLoaderOption(loader.RedisSetLoader( + cfg.Redis.Addr, + loader.DBRedisLoaderOption(cfg.Redis.DB), + loader.PasswordRedisLoaderOption(cfg.Redis.Password), + loader.KeyRedisLoaderOption(cfg.Redis.Key), + ))) + } + return bypass_impl.NewBypass(opts...) } func ParseResolver(cfg *config.ResolverConfig) (resolver.Resolver, error) { diff --git a/connector/relay/connector.go b/connector/relay/connector.go index 69bf481..491ed5f 100644 --- a/connector/relay/connector.go +++ b/connector/relay/connector.go @@ -94,19 +94,23 @@ func (c *relayConnector) Connect(ctx context.Context, conn net.Conn, network, ad if _, err := req.WriteTo(conn); err != nil { return nil, err } + // drain the response + if err := readResponse(conn); err != nil { + return nil, err + } } switch network { case "tcp", "tcp4", "tcp6": - cc := &tcpConn{ - Conn: conn, - } if !c.md.noDelay { + cc := &tcpConn{ + Conn: conn, + } if _, err := req.WriteTo(&cc.wbuf); err != nil { return nil, err } + conn = cc } - conn = cc case "udp", "udp4", "udp6": cc := &udpConn{ Conn: conn, diff --git a/go.mod b/go.mod index cd60eb8..35287e9 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( github.com/go-gost/gosocks5 v0.3.1-0.20211109033403-d894d75b7f09 github.com/go-gost/relay v0.1.1-0.20211123134818-8ef7fd81ffd7 github.com/go-gost/tls-dissector v0.0.2-0.20220408131628-aac992c27451 + github.com/go-redis/redis/v8 v8.11.5 github.com/gobwas/glob v0.2.3 github.com/golang/snappy v0.0.4 github.com/gorilla/websocket v1.5.0 @@ -42,6 +43,7 @@ require ( github.com/cespare/xxhash/v2 v2.1.2 // indirect github.com/cheekybits/genny v1.0.0 // indirect github.com/coreos/go-iptables v0.5.0 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/fsnotify/fsnotify v1.5.1 // indirect github.com/gin-contrib/sse v0.1.0 // indirect github.com/go-playground/locales v0.14.0 // indirect diff --git a/go.sum b/go.sum index d82b241..8fd9e46 100644 --- a/go.sum +++ b/go.sum @@ -89,6 +89,8 @@ github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ3 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/docker/libcontainer v2.2.1+incompatible h1:++SbbkCw+X8vAd4j2gOCzZ2Nn7s2xFALTf7LZKmM1/0= github.com/docker/libcontainer v2.2.1+incompatible/go.mod h1:osvj61pYsqhNCMLGX31xr7klUBhHb/ZBuXS0o1Fvwbw= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= @@ -148,6 +150,8 @@ github.com/go-playground/universal-translator v0.18.0/go.mod h1:UvRDBj+xPUEGrFYl github.com/go-playground/validator/v10 v10.4.1/go.mod h1:nlOn6nFhuKACm19sB/8EGNn9GlaMV7XkbRSipzJ0Ii4= github.com/go-playground/validator/v10 v10.10.1 h1:uA0+amWMiglNZKZ9FJRKUAe9U3RX91eVn1JYXMWt7ig= github.com/go-playground/validator/v10 v10.10.1/go.mod h1:i+3WkQ1FvaUjjxh1kSvIA4dMGDBiPU55YFDl0WbKdWU= +github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI= +github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 h1:p104kn46Q8WdvHunIJ9dAyjPVtrBPhSr3KT2yUst43I= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= @@ -327,8 +331,8 @@ github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU= github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo= -github.com/onsi/gomega v1.13.0 h1:7lLHu94wT9Ij0o6EWWclhu0aOh32VxhkwEJvzuWPeak= github.com/onsi/gomega v1.13.0/go.mod h1:lRk9szgn8TxENtWd0Tp4c3wjlRfMTMH27I+3Je41yGY= +github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE= github.com/openzipkin/zipkin-go v0.1.1/go.mod h1:NtoC/o8u3JlF1lSlyPNswIbeQH9bJTmOf0Erfk+hxe8= github.com/pelletier/go-toml v1.9.4 h1:tjENF6MfZAg8e4ZmZTeWaWiT2vXtsoO6+iuOjFhECwM= github.com/pelletier/go-toml v1.9.4/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c= diff --git a/handler/relay/connect.go b/handler/relay/connect.go index 4492bcb..94a20d7 100644 --- a/handler/relay/connect.go +++ b/handler/relay/connect.go @@ -68,16 +68,16 @@ func (h *relayHandler) handleConnect(ctx context.Context, conn net.Conn, network } conn = rc default: - rc := &tcpConn{ - Conn: conn, - } if !h.md.noDelay { + rc := &tcpConn{ + Conn: conn, + } // cache the header if _, err := resp.WriteTo(&rc.wbuf); err != nil { return err } + conn = rc } - conn = rc } t := time.Now() diff --git a/internal/loader/file.go b/internal/loader/file.go new file mode 100644 index 0000000..fecfc72 --- /dev/null +++ b/internal/loader/file.go @@ -0,0 +1,30 @@ +package loader + +import ( + "bytes" + "context" + "io" + "os" +) + +type fileLoader struct { + filename string +} + +func FileLoader(filename string) Loader { + return &fileLoader{ + filename: filename, + } +} + +func (l *fileLoader) Load(ctx context.Context) (io.Reader, error) { + data, err := os.ReadFile(l.filename) + if err != nil { + return nil, err + } + return bytes.NewReader(data), nil +} + +func (l *fileLoader) Close() error { + return nil +} diff --git a/internal/loader/loader.go b/internal/loader/loader.go new file mode 100644 index 0000000..2ae8add --- /dev/null +++ b/internal/loader/loader.go @@ -0,0 +1,11 @@ +package loader + +import ( + "context" + "io" +) + +type Loader interface { + Load(context.Context) (io.Reader, error) + Close() error +} diff --git a/internal/loader/redis.go b/internal/loader/redis.go new file mode 100644 index 0000000..6f4c1d5 --- /dev/null +++ b/internal/loader/redis.go @@ -0,0 +1,124 @@ +package loader + +import ( + "bytes" + "context" + "fmt" + "io" + "strings" + + "github.com/go-redis/redis/v8" +) + +const ( + DefaultRedisKey = "gost" +) + +type redisLoaderOptions struct { + db int + password string + key string +} + +type RedisLoaderOption func(opts *redisLoaderOptions) + +func DBRedisLoaderOption(db int) RedisLoaderOption { + return func(opts *redisLoaderOptions) { + opts.db = db + } +} + +func PasswordRedisLoaderOption(password string) RedisLoaderOption { + return func(opts *redisLoaderOptions) { + opts.password = password + } +} + +func KeyRedisLoaderOption(key string) RedisLoaderOption { + return func(opts *redisLoaderOptions) { + opts.key = key + } +} + +type redisSetLoader struct { + client *redis.Client + key string +} + +// RedisSetLoader loads values from redis set. +func RedisSetLoader(addr string, opts ...RedisLoaderOption) Loader { + var options redisLoaderOptions + for _, opt := range opts { + opt(&options) + } + + key := options.key + if key == "" { + key = DefaultRedisKey + } + + return &redisSetLoader{ + client: redis.NewClient(&redis.Options{ + Addr: addr, + Password: options.password, + DB: options.db, + }), + key: key, + } +} + +func (p *redisSetLoader) Load(ctx context.Context) (io.Reader, error) { + v, err := p.client.SMembers(ctx, p.key).Result() + if err != nil { + return nil, err + } + return bytes.NewReader([]byte(strings.Join(v, "\n"))), nil +} + +func (p *redisSetLoader) Close() error { + return p.client.Close() +} + +type redisHashLoader struct { + client *redis.Client + key string +} + +// RedisHashLoader loads values from redis hash. +func RedisHashLoader(addr string, opts ...RedisLoaderOption) Loader { + var options redisLoaderOptions + for _, opt := range opts { + opt(&options) + } + + key := options.key + if key == "" { + key = DefaultRedisKey + } + + return &redisHashLoader{ + client: redis.NewClient(&redis.Options{ + Addr: addr, + Password: options.password, + DB: options.db, + }), + key: key, + } +} + +func (p *redisHashLoader) Load(ctx context.Context) (io.Reader, error) { + m, err := p.client.HGetAll(ctx, p.key).Result() + if err != nil { + return nil, err + } + + var b strings.Builder + for k, v := range m { + fmt.Fprintf(&b, "%s %s\n", k, v) + } + return bytes.NewBufferString(b.String()), nil +} + +func (p *redisHashLoader) Close() error { + return p.client.Close() +} diff --git a/internal/util/matcher/matcher.go b/internal/matcher/matcher.go similarity index 100% rename from internal/util/matcher/matcher.go rename to internal/matcher/matcher.go diff --git a/registry/registry.go b/registry/registry.go index affd3d2..877ba1e 100644 --- a/registry/registry.go +++ b/registry/registry.go @@ -2,6 +2,7 @@ package registry import ( "errors" + "io" "sync" "github.com/go-gost/core/admission" @@ -55,7 +56,12 @@ func (r *registry) Register(name string, v any) error { } func (r *registry) Unregister(name string) { - r.m.Delete(name) + if v, ok := r.m.Load(name); ok { + if closer, ok := v.(io.Closer); ok { + closer.Close() + } + r.m.Delete(name) + } } func (r *registry) IsRegistered(name string) bool {