diff --git a/admission/admission.go b/admission/admission.go index 19817c9..3ac4b52 100644 --- a/admission/admission.go +++ b/admission/admission.go @@ -164,21 +164,41 @@ func (p *admission) reload(ctx context.Context) error { func (p *admission) load(ctx context.Context) (patterns []string, err error) { if p.options.fileLoader != nil { - r, er := p.options.fileLoader.Load(ctx) - if er != nil { - p.options.logger.Warnf("file loader: %v", er) - } - if v, _ := p.parsePatterns(r); v != nil { - patterns = append(patterns, v...) + if lister, ok := p.options.fileLoader.(loader.Lister); ok { + list, er := lister.List(ctx) + if er != nil { + p.options.logger.Warnf("file loader: %v", er) + } + for _, s := range list { + if line := p.parseLine(s); line != "" { + patterns = append(patterns, line) + } + } + } else { + r, er := p.options.fileLoader.Load(ctx) + if er != nil { + p.options.logger.Warnf("file loader: %v", er) + } + if v, _ := p.parsePatterns(r); v != nil { + patterns = append(patterns, v...) + } } } if p.options.redisLoader != nil { - r, er := p.options.redisLoader.Load(ctx) - if er != nil { - p.options.logger.Warnf("redis loader: %v", er) - } - if v, _ := p.parsePatterns(r); v != nil { - patterns = append(patterns, v...) + if lister, ok := p.options.redisLoader.(loader.Lister); ok { + list, er := lister.List(ctx) + if er != nil { + p.options.logger.Warnf("redis loader: %v", er) + } + patterns = append(patterns, list...) + } else { + r, er := p.options.redisLoader.Load(ctx) + if er != nil { + p.options.logger.Warnf("redis loader: %v", er) + } + if v, _ := p.parsePatterns(r); v != nil { + patterns = append(patterns, v...) + } } } @@ -192,12 +212,7 @@ func (p *admission) parsePatterns(r io.Reader) (patterns []string, err error) { scanner := bufio.NewScanner(r) for scanner.Scan() { - line := scanner.Text() - if n := strings.IndexByte(line, '#'); n >= 0 { - line = line[:n] - } - line = strings.TrimSpace(line) - if line != "" { + if line := p.parseLine(scanner.Text()); line != "" { patterns = append(patterns, line) } } @@ -206,6 +221,13 @@ func (p *admission) parsePatterns(r io.Reader) (patterns []string, err error) { return } +func (p *admission) parseLine(s string) string { + if n := strings.IndexByte(s, '#'); n >= 0 { + s = s[:n] + } + return strings.TrimSpace(s) +} + func (p *admission) matched(addr string) bool { p.mu.RLock() defer p.mu.RUnlock() diff --git a/auth/auth.go b/auth/auth.go index 83a3cc8..1a2a5c3 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -143,25 +143,41 @@ func (p *authenticator) load(ctx context.Context) (m map[string]string, err erro m = make(map[string]string) if p.options.fileLoader != nil { - r, er := p.options.fileLoader.Load(ctx) - if er != nil { - p.options.logger.Warnf("file loader: %v", er) - } - if auths, _ := p.parseAuths(r); auths != nil { - for k, v := range auths { - m[k] = v + if mapper, ok := p.options.fileLoader.(loader.Mapper); ok { + auths, er := mapper.Map(ctx) + if er != nil { + p.options.logger.Warnf("file loader: %v", er) + } + m = auths + } else { + r, er := p.options.fileLoader.Load(ctx) + if er != nil { + p.options.logger.Warnf("file loader: %v", er) + } + if auths, _ := p.parseAuths(r); auths != nil { + m = auths } } } if p.options.redisLoader != nil { - r, er := p.options.redisLoader.Load(ctx) - if er != nil { - p.options.logger.Warnf("redis loader: %v", er) - } - if auths, _ := p.parseAuths(r); auths != nil { + if mapper, ok := p.options.fileLoader.(loader.Mapper); ok { + auths, er := mapper.Map(ctx) + if er != nil { + p.options.logger.Warnf("file loader: %v", er) + } for k, v := range auths { m[k] = v } + } else { + r, er := p.options.redisLoader.Load(ctx) + if er != nil { + p.options.logger.Warnf("redis loader: %v", er) + } + if auths, _ := p.parseAuths(r); auths != nil { + for k, v := range auths { + m[k] = v + } + } } } diff --git a/bypass/bypass.go b/bypass/bypass.go index c60a545..4369354 100644 --- a/bypass/bypass.go +++ b/bypass/bypass.go @@ -159,21 +159,41 @@ func (bp *bypass) reload(ctx context.Context) error { func (bp *bypass) load(ctx context.Context) (patterns []string, err error) { if bp.options.fileLoader != nil { - r, er := bp.options.fileLoader.Load(ctx) - if er != nil { - bp.options.logger.Warnf("file loader: %v", er) - } - if v, _ := bp.parsePatterns(r); v != nil { - patterns = append(patterns, v...) + if lister, ok := bp.options.fileLoader.(loader.Lister); ok { + list, er := lister.List(ctx) + if er != nil { + bp.options.logger.Warnf("file loader: %v", er) + } + for _, s := range list { + if line := bp.parseLine(s); line != "" { + patterns = append(patterns, line) + } + } + } else { + r, er := bp.options.fileLoader.Load(ctx) + if er != nil { + bp.options.logger.Warnf("file loader: %v", er) + } + if v, _ := bp.parsePatterns(r); v != nil { + patterns = append(patterns, v...) + } } } if bp.options.redisLoader != nil { - r, er := bp.options.redisLoader.Load(ctx) - if er != nil { - bp.options.logger.Warnf("redis loader: %v", er) - } - if v, _ := bp.parsePatterns(r); v != nil { - patterns = append(patterns, v...) + if lister, ok := bp.options.redisLoader.(loader.Lister); ok { + list, er := lister.List(ctx) + if er != nil { + bp.options.logger.Warnf("redis loader: %v", er) + } + patterns = append(patterns, list...) + } else { + r, er := bp.options.redisLoader.Load(ctx) + if er != nil { + bp.options.logger.Warnf("redis loader: %v", er) + } + if v, _ := bp.parsePatterns(r); v != nil { + patterns = append(patterns, v...) + } } } @@ -187,12 +207,7 @@ func (bp *bypass) parsePatterns(r io.Reader) (patterns []string, err error) { scanner := bufio.NewScanner(r) for scanner.Scan() { - line := scanner.Text() - if n := strings.IndexByte(line, '#'); n >= 0 { - line = line[:n] - } - line = strings.TrimSpace(line) - if line != "" { + if line := bp.parseLine(scanner.Text()); line != "" { patterns = append(patterns, line) } } @@ -221,6 +236,13 @@ func (bp *bypass) Contains(addr string) bool { return b } +func (bp *bypass) parseLine(s string) string { + if n := strings.IndexByte(s, '#'); n >= 0 { + s = s[:n] + } + return strings.TrimSpace(s) +} + func (bp *bypass) matched(addr string) bool { bp.mu.RLock() defer bp.mu.RUnlock() diff --git a/config/parsing/tls.go b/config/parsing/tls.go index a576b4c..4a1bb93 100644 --- a/config/parsing/tls.go +++ b/config/parsing/tls.go @@ -38,9 +38,9 @@ func BuildDefaultTLSConfig(cfg *config.TLSConfig) { tlsConfig = &tls.Config{ Certificates: []tls.Certificate{cert}, } - log.Warn("load TLS certificate files failed, use random generated certificate") + log.Warn("load global TLS certificate files failed, use random generated certificate") } else { - log.Info("load TLS certificate files OK") + log.Info("load global TLS certificate files OK") } defaultTLSConfig = tlsConfig } diff --git a/go.sum b/go.sum index 7274ca0..801806d 100644 --- a/go.sum +++ b/go.sum @@ -121,8 +121,6 @@ github.com/go-errors/errors v1.0.1/go.mod h1:f4zRHt4oKfwPJE5k8C9vpYG+aDHdBFUsgrm github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= -github.com/go-gost/core v0.0.0-20220411145302-03988fee0b9a h1:GYdVcUmEoFVzkPNDavMsDnGA6AM+jgORE4n3/tIiKqI= -github.com/go-gost/core v0.0.0-20220411145302-03988fee0b9a/go.mod h1:bHVbCS9da6XtKNYMkMUVcck5UqDDUkyC37erVfs4GXQ= github.com/go-gost/core v0.0.0-20220413143512-acee88323487 h1:boiBuK2m2jImLYkOlvey8bIrEl7TKM5ZbU+wNmX5oyg= github.com/go-gost/core v0.0.0-20220413143512-acee88323487/go.mod h1:bHVbCS9da6XtKNYMkMUVcck5UqDDUkyC37erVfs4GXQ= github.com/go-gost/gosocks4 v0.0.1 h1:+k1sec8HlELuQV7rWftIkmy8UijzUt2I6t+iMPlGB2s= diff --git a/hosts/hosts.go b/hosts/hosts.go index cd493ea..d6e3b57 100644 --- a/hosts/hosts.go +++ b/hosts/hosts.go @@ -218,20 +218,40 @@ func (h *Hosts) reload(ctx context.Context) (err error) { func (h *Hosts) load(ctx context.Context) (mappings []Mapping, err error) { if h.options.fileLoader != nil { - r, er := h.options.fileLoader.Load(ctx) - if er != nil { - h.options.logger.Warnf("file loader: %v", er) + if lister, ok := h.options.fileLoader.(loader.Lister); ok { + list, er := lister.List(ctx) + if er != nil { + h.options.logger.Warnf("file loader: %v", er) + } + for _, s := range list { + mappings = append(mappings, h.parseLine(s)...) + } + } else { + r, er := h.options.fileLoader.Load(ctx) + if er != nil { + h.options.logger.Warnf("file loader: %v", er) + } + mappings, _ = h.parseMapping(r) } - mappings, _ = h.parseMapping(r) } if h.options.redisLoader != nil { - r, er := h.options.redisLoader.Load(ctx) - if er != nil { - h.options.logger.Warnf("redis loader: %v", er) - } - if m, _ := h.parseMapping(r); m != nil { - mappings = append(mappings, m...) + if lister, ok := h.options.redisLoader.(loader.Lister); ok { + list, er := lister.List(ctx) + if er != nil { + h.options.logger.Warnf("redis loader: %v", er) + } + for _, s := range list { + mappings = append(mappings, h.parseLine(s)...) + } + } else { + r, er := h.options.redisLoader.Load(ctx) + if er != nil { + h.options.logger.Warnf("redis loader: %v", er) + } + if m, _ := h.parseMapping(r); m != nil { + mappings = append(mappings, m...) + } } } @@ -245,35 +265,39 @@ func (h *Hosts) parseMapping(r io.Reader) (mappings []Mapping, err error) { scanner := bufio.NewScanner(r) for scanner.Scan() { - line := strings.Replace(scanner.Text(), "\t", " ", -1) - line = strings.TrimSpace(line) - if n := strings.IndexByte(line, '#'); n >= 0 { - line = line[:n] - } - var sp []string - for _, s := range strings.Split(line, " ") { - if s = strings.TrimSpace(s); s != "" { - sp = append(sp, s) - } - } - if len(sp) < 2 { - continue // invalid lines are ignored - } + mappings = append(mappings, h.parseLine(scanner.Text())...) + } + err = scanner.Err() + return +} - ip := net.ParseIP(sp[0]) - if ip == nil { - continue // invalid IP addresses are ignored - } - - for _, v := range sp[1:] { - mappings = append(mappings, Mapping{ - Hostname: v, - IP: ip, - }) +func (h *Hosts) parseLine(s string) (mappings []Mapping) { + line := strings.Replace(s, "\t", " ", -1) + line = strings.TrimSpace(line) + if n := strings.IndexByte(line, '#'); n >= 0 { + line = line[:n] + } + var sp []string + for _, s := range strings.Split(line, " ") { + if s = strings.TrimSpace(s); s != "" { + sp = append(sp, s) } } + if len(sp) < 2 { + return // invalid lines are ignored + } - err = scanner.Err() + ip := net.ParseIP(sp[0]) + if ip == nil { + return // invalid IP addresses are ignored + } + + for _, v := range sp[1:] { + mappings = append(mappings, Mapping{ + Hostname: v, + IP: ip, + }) + } return } diff --git a/internal/loader/file.go b/internal/loader/file.go index bd8be11..7ba7349 100644 --- a/internal/loader/file.go +++ b/internal/loader/file.go @@ -1,6 +1,7 @@ package loader import ( + "bufio" "bytes" "context" "io" @@ -26,6 +27,23 @@ func (l *fileLoader) Load(ctx context.Context) (io.Reader, error) { return bytes.NewReader(data), nil } +// List implements Lister interface{} +func (l *fileLoader) List(ctx context.Context) (list []string, err error) { + f, err := os.Open(l.filename) + if err != nil { + return + } + defer f.Close() + + scanner := bufio.NewScanner(f) + for scanner.Scan() { + list = append(list, scanner.Text()) + } + err = scanner.Err() + + return +} + func (l *fileLoader) Close() error { return nil } diff --git a/internal/loader/loader.go b/internal/loader/loader.go index 2ae8add..cd882f5 100644 --- a/internal/loader/loader.go +++ b/internal/loader/loader.go @@ -9,3 +9,11 @@ type Loader interface { Load(context.Context) (io.Reader, error) Close() error } + +type Lister interface { + List(ctx context.Context) ([]string, error) +} + +type Mapper interface { + Map(ctx context.Context) (map[string]string, error) +} diff --git a/internal/loader/redis.go b/internal/loader/redis.go index a57abfd..daced47 100644 --- a/internal/loader/redis.go +++ b/internal/loader/redis.go @@ -68,13 +68,18 @@ func RedisSetLoader(addr string, opts ...RedisLoaderOption) Loader { } func (p *redisSetLoader) Load(ctx context.Context) (io.Reader, error) { - v, err := p.client.SMembers(ctx, p.key).Result() + v, err := p.List(ctx) if err != nil { return nil, err } return bytes.NewReader([]byte(strings.Join(v, "\n"))), nil } +// List implements Lister interface{} +func (p *redisSetLoader) List(ctx context.Context) ([]string, error) { + return p.client.SMembers(ctx, p.key).Result() +} + func (p *redisSetLoader) Close() error { return p.client.Close() } @@ -107,13 +112,18 @@ func RedisListLoader(addr string, opts ...RedisLoaderOption) Loader { } func (p *redisListLoader) Load(ctx context.Context) (io.Reader, error) { - v, err := p.client.LRange(ctx, p.key, 0, -1).Result() + v, err := p.List(ctx) if err != nil { return nil, err } return bytes.NewReader([]byte(strings.Join(v, "\n"))), nil } +// List implements Lister interface{} +func (p *redisListLoader) List(ctx context.Context) ([]string, error) { + return p.client.LRange(ctx, p.key, 0, -1).Result() +} + func (p *redisListLoader) Close() error { return p.client.Close() } @@ -146,7 +156,7 @@ func RedisHashLoader(addr string, opts ...RedisLoaderOption) Loader { } func (p *redisHashLoader) Load(ctx context.Context) (io.Reader, error) { - m, err := p.client.HGetAll(ctx, p.key).Result() + m, err := p.Map(ctx) if err != nil { return nil, err } @@ -158,6 +168,25 @@ func (p *redisHashLoader) Load(ctx context.Context) (io.Reader, error) { return bytes.NewBufferString(b.String()), nil } +// List implements Lister interface{} +func (p *redisHashLoader) List(ctx context.Context) (list []string, err error) { + m, err := p.Map(ctx) + if err != nil { + return + } + + for k, v := range m { + list = append(list, fmt.Sprintf("%s %s", k, v)) + } + + return +} + +// Map implements Mapper interface{} +func (p *redisHashLoader) Map(ctx context.Context) (map[string]string, error) { + return p.client.HGetAll(ctx, p.key).Result() +} + func (p *redisHashLoader) Close() error { return p.client.Close() }