diff --git a/auth/auth.go b/auth/auth.go index 4373b00..83a3cc8 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -120,16 +120,13 @@ func (p *authenticator) periodReload(ctx context.Context) error { } } -func (p *authenticator) reload(ctx context.Context) error { +func (p *authenticator) reload(ctx context.Context) (err error) { kvs := make(map[string]string) for k, v := range p.options.auths { kvs[k] = v } m, err := p.load(ctx) - if err != nil { - return err - } for k, v := range m { kvs[k] = v } @@ -139,7 +136,7 @@ func (p *authenticator) reload(ctx context.Context) error { p.kvs = kvs - return nil + return } func (p *authenticator) load(ctx context.Context) (m map[string]string, err error) { @@ -180,7 +177,8 @@ func (p *authenticator) parseAuths(r io.Reader) (auths map[string]string, err er scanner := bufio.NewScanner(r) for scanner.Scan() { - line := strings.TrimSpace(scanner.Text()) + line := strings.Replace(scanner.Text(), "\t", " ", -1) + line = strings.TrimSpace(line) if n := strings.IndexByte(line, '#'); n == 0 { continue } diff --git a/config/config.go b/config/config.go index dc1ffd6..5bdf01a 100644 --- a/config/config.go +++ b/config/config.go @@ -119,6 +119,7 @@ type RedisLoader struct { DB int `yaml:",omitempty" json:"db,omitempty"` Password string `yaml:",omitempty" json:"password,omitempty"` Key string `yaml:",omitempty" json:"key,omitempty"` + Type string `yaml:",omitempty" json:"type,omitempty"` } type NameserverConfig struct { @@ -145,6 +146,9 @@ type HostMappingConfig struct { type HostsConfig struct { Name string `json:"name"` Mappings []*HostMappingConfig `json:"mappings"` + Reload time.Duration `yaml:",omitempty" json:"reload,omitempty"` + File *FileLoader `yaml:",omitempty" json:"file,omitempty"` + Redis *RedisLoader `yaml:",omitempty" json:"redis,omitempty"` } type RecorderConfig struct { diff --git a/config/parsing/parse.go b/config/parsing/parse.go index f138484..832f636 100644 --- a/config/parsing/parse.go +++ b/config/parsing/parse.go @@ -191,27 +191,55 @@ func ParseResolver(cfg *config.ResolverConfig) (resolver.Resolver, error) { } func ParseHosts(cfg *config.HostsConfig) hosts.HostMapper { - if cfg == nil || len(cfg.Mappings) == 0 { + if cfg == nil { return nil } - hosts := hosts_impl.NewHosts() - hosts.Logger = logger.Default().WithFields(map[string]any{ - "kind": "hosts", - "hosts": cfg.Name, - }) - for _, host := range cfg.Mappings { - if host.IP == "" || host.Hostname == "" { + var mappings []hosts_impl.Mapping + for _, mapping := range cfg.Mappings { + if mapping.IP == "" || mapping.Hostname == "" { continue } - ip := net.ParseIP(host.IP) + ip := net.ParseIP(mapping.IP) if ip == nil { continue } - hosts.Map(ip, host.Hostname, host.Aliases...) + mappings = append(mappings, hosts_impl.Mapping{ + Hostname: mapping.Hostname, + IP: ip, + }) } - return hosts + opts := []hosts_impl.Option{ + hosts_impl.MappingsOption(mappings), + hosts_impl.ReloadPeriodOption(cfg.Reload), + hosts_impl.LoggerOption(logger.Default().WithFields(map[string]any{ + "kind": "hosts", + "hosts": cfg.Name, + })), + } + if cfg.File != nil && cfg.File.Path != "" { + opts = append(opts, hosts_impl.FileLoaderOption(loader.FileLoader(cfg.File.Path))) + } + if cfg.Redis != nil && cfg.Redis.Addr != "" { + switch cfg.Redis.Type { + case "list": // redis list + opts = append(opts, hosts_impl.RedisLoaderOption(loader.RedisListLoader( + cfg.Redis.Addr, + loader.DBRedisLoaderOption(cfg.Redis.DB), + loader.PasswordRedisLoaderOption(cfg.Redis.Password), + loader.KeyRedisLoaderOption(cfg.Redis.Key), + ))) + default: // redis set + opts = append(opts, hosts_impl.RedisLoaderOption(loader.RedisSetLoader( + cfg.Redis.Addr, + loader.DBRedisLoaderOption(cfg.Redis.DB), + loader.PasswordRedisLoaderOption(cfg.Redis.Password), + loader.KeyRedisLoaderOption(cfg.Redis.Key), + ))) + } + } + return hosts_impl.NewHostMapper(opts...) } func ParseRecorder(cfg *config.RecorderConfig) (r recorder.Recorder) { diff --git a/hosts/hosts.go b/hosts/hosts.go index a314528..cd493ea 100644 --- a/hosts/hosts.go +++ b/hosts/hosts.go @@ -1,16 +1,62 @@ package hosts import ( + "bufio" + "context" + "io" "net" "strings" "sync" + "time" + "github.com/go-gost/core/hosts" "github.com/go-gost/core/logger" + "github.com/go-gost/x/internal/loader" ) -type hostMapping struct { - IPs []net.IP +type Mapping struct { Hostname string + IP net.IP +} + +type options struct { + mappings []Mapping + fileLoader loader.Loader + redisLoader loader.Loader + period time.Duration + logger logger.Logger +} + +type Option func(opts *options) + +func MappingsOption(mappings []Mapping) Option { + return func(opts *options) { + opts.mappings = mappings + } +} + +func ReloadPeriodOption(period time.Duration) Option { + return func(opts *options) { + opts.period = period + } +} + +func FileLoaderOption(fileLoader loader.Loader) Option { + return func(opts *options) { + opts.fileLoader = fileLoader + } +} + +func RedisLoaderOption(redisLoader loader.Loader) Option { + return func(opts *options) { + opts.redisLoader = redisLoader + } +} + +func LoggerOption(logger logger.Logger) Option { + return func(opts *options) { + opts.logger = logger + } } // Hosts is a static table lookup for hostnames. @@ -19,57 +65,51 @@ type hostMapping struct { // Fields of the entry are separated by any number of blanks and/or tab characters. // Text from a "#" character until the end of the line is a comment, and is ignored. type Hosts struct { - mappings sync.Map - Logger logger.Logger + mappings map[string][]net.IP + mu sync.RWMutex + cancelFunc context.CancelFunc + options options } -func NewHosts() *Hosts { - return &Hosts{} -} - -// Map maps ip to hostname or aliases. -func (h *Hosts) Map(ip net.IP, hostname string, aliases ...string) { - if hostname == "" { - return +func NewHostMapper(opts ...Option) hosts.HostMapper { + var options options + for _, opt := range opts { + opt(&options) } - v, _ := h.mappings.Load(hostname) - m, _ := v.(*hostMapping) - if m == nil { - m = &hostMapping{ - IPs: []net.IP{ip}, - Hostname: hostname, - } - } else { - m.IPs = append(m.IPs, ip) + ctx, cancel := context.WithCancel(context.TODO()) + p := &Hosts{ + mappings: make(map[string][]net.IP), + cancelFunc: cancel, + options: options, } - h.mappings.Store(hostname, m) - for _, alias := range aliases { - // indirect mapping from alias to hostname - if alias != "" { - h.mappings.Store(alias, &hostMapping{ - Hostname: hostname, - }) - } + if err := p.reload(ctx); err != nil { + options.logger.Warnf("reload: %v", err) } + if p.options.period > 0 { + go p.periodReload(ctx) + } + + return p } // Lookup searches the IP address corresponds to the given network and host from the host table. // The network should be 'ip', 'ip4' or 'ip6', default network is 'ip'. // the host should be a hostname (example.org) or a hostname with dot prefix (.example.org). func (h *Hosts) Lookup(network, host string) (ips []net.IP, ok bool) { - m := h.lookup(host) - if m == nil { - m = h.lookup("." + host) + h.options.logger.Debugf("lookup %s/%s", host, network) + ips = h.lookup(host) + if ips == nil { + ips = h.lookup("." + host) } - if m == nil { + if ips == nil { s := host for { if index := strings.IndexByte(s, '.'); index > 0 { - m = h.lookup(s[index:]) + ips = h.lookup(s[index:]) s = s[index+1:] - if m == nil { + if ips == nil { continue } } @@ -77,51 +117,173 @@ func (h *Hosts) Lookup(network, host string) (ips []net.IP, ok bool) { } } - if m == nil { + if ips == nil { return } - // hostname alias - if !strings.HasPrefix(m.Hostname, ".") && host != m.Hostname { - m = h.lookup(m.Hostname) - if m == nil { - return - } - } - switch network { case "ip4": - for _, ip := range m.IPs { + var v []net.IP + for _, ip := range ips { if ip = ip.To4(); ip != nil { - ips = append(ips, ip) + v = append(v, ip) } } + ips = v case "ip6": - for _, ip := range m.IPs { + var v []net.IP + for _, ip := range ips { if ip.To4() == nil { - ips = append(ips, ip) + v = append(v, ip) } } + ips = v default: - ips = m.IPs } if len(ips) > 0 { - h.Logger.Debugf("host mapper: %s -> %s", host, ips) + h.options.logger.Debugf("host mapper: %s/%s -> %s", host, network, ips) } return } -func (h *Hosts) lookup(host string) *hostMapping { - if h == nil || host == "" { +func (h *Hosts) lookup(host string) []net.IP { + if h == nil || len(h.mappings) == 0 { return nil } - v, ok := h.mappings.Load(host) - if !ok { - return nil - } - m, _ := v.(*hostMapping) - return m + h.mu.RLock() + defer h.mu.RUnlock() + + return h.mappings[host] +} + +func (h *Hosts) periodReload(ctx context.Context) error { + period := h.options.period + if period < time.Second { + period = time.Second + } + ticker := time.NewTicker(period) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + if err := h.reload(ctx); err != nil { + h.options.logger.Warnf("reload: %v", err) + // return err + } + h.options.logger.Debugf("hosts reload done") + case <-ctx.Done(): + return ctx.Err() + } + } +} + +func (h *Hosts) reload(ctx context.Context) (err error) { + mappings := make(map[string][]net.IP) + + mapf := func(hostname string, ip net.IP) { + ips := mappings[hostname] + found := false + for i := range ips { + if ip.Equal(ips[i]) { + found = true + break + } + } + if !found { + ips = append(ips, ip) + } + mappings[hostname] = ips + } + + for _, mapping := range h.options.mappings { + mapf(mapping.Hostname, mapping.IP) + } + + m, err := h.load(ctx) + for i := range m { + mapf(m[i].Hostname, m[i].IP) + } + + h.mu.Lock() + defer h.mu.Unlock() + + h.mappings = mappings + + return +} + +func (h *Hosts) load(ctx context.Context) (mappings []Mapping, err error) { + if h.options.fileLoader != nil { + r, er := h.options.fileLoader.Load(ctx) + if er != nil { + h.options.logger.Warnf("file loader: %v", er) + } + mappings, _ = h.parseMapping(r) + } + + if h.options.redisLoader != nil { + r, er := h.options.redisLoader.Load(ctx) + if er != nil { + h.options.logger.Warnf("redis loader: %v", er) + } + if m, _ := h.parseMapping(r); m != nil { + mappings = append(mappings, m...) + } + } + + return +} + +func (h *Hosts) parseMapping(r io.Reader) (mappings []Mapping, err error) { + if r == nil { + return + } + + scanner := bufio.NewScanner(r) + for scanner.Scan() { + line := strings.Replace(scanner.Text(), "\t", " ", -1) + line = strings.TrimSpace(line) + if n := strings.IndexByte(line, '#'); n >= 0 { + line = line[:n] + } + var sp []string + for _, s := range strings.Split(line, " ") { + if s = strings.TrimSpace(s); s != "" { + sp = append(sp, s) + } + } + if len(sp) < 2 { + continue // invalid lines are ignored + } + + ip := net.ParseIP(sp[0]) + if ip == nil { + continue // invalid IP addresses are ignored + } + + for _, v := range sp[1:] { + mappings = append(mappings, Mapping{ + Hostname: v, + IP: ip, + }) + } + } + + err = scanner.Err() + return +} + +func (h *Hosts) Close() error { + h.cancelFunc() + if h.options.fileLoader != nil { + h.options.fileLoader.Close() + } + if h.options.redisLoader != nil { + h.options.redisLoader.Close() + } + return nil } diff --git a/internal/loader/redis.go b/internal/loader/redis.go index 38dac60..a57abfd 100644 --- a/internal/loader/redis.go +++ b/internal/loader/redis.go @@ -79,6 +79,45 @@ func (p *redisSetLoader) Close() error { return p.client.Close() } +type redisListLoader struct { + client *redis.Client + key string +} + +// RedisListLoader loads data from redis list. +func RedisListLoader(addr string, opts ...RedisLoaderOption) Loader { + var options redisLoaderOptions + for _, opt := range opts { + opt(&options) + } + + key := options.key + if key == "" { + key = DefaultRedisKey + } + + return &redisListLoader{ + client: redis.NewClient(&redis.Options{ + Addr: addr, + Password: options.password, + DB: options.db, + }), + key: key, + } +} + +func (p *redisListLoader) Load(ctx context.Context) (io.Reader, error) { + v, err := p.client.LRange(ctx, p.key, 0, -1).Result() + if err != nil { + return nil, err + } + return bytes.NewReader([]byte(strings.Join(v, "\n"))), nil +} + +func (p *redisListLoader) Close() error { + return p.client.Close() +} + type redisHashLoader struct { client *redis.Client key string