From f35711705612f0b766d566aa102307b744a2698e Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Wed, 19 Jan 2022 16:33:27 +0800 Subject: [PATCH] add hosts support for dns --- cmd/gost/cmd.go | 10 +-- cmd/gost/config.go | 17 +++++- gost.yml | 2 +- pkg/bypass/bypass.go | 32 ++++++++-- pkg/chain/router.go | 6 +- pkg/config/config.go | 6 +- pkg/handler/dns/handler.go | 85 ++++++++++++++++++++++++-- pkg/hosts/hosts.go | 94 ++++++++++++++++++++++------- pkg/internal/util/resolver/cache.go | 2 +- pkg/metadata/metadata.go | 5 +- pkg/resolver/impl/resolver.go | 6 +- 11 files changed, 212 insertions(+), 53 deletions(-) diff --git a/cmd/gost/cmd.go b/cmd/gost/cmd.go index b1c0d53..b1f7f36 100644 --- a/cmd/gost/cmd.go +++ b/cmd/gost/cmd.go @@ -164,11 +164,11 @@ func buildConfigFromCmd(services, nodes stringList) (*config.Config, error) { if len(ss) != 2 { continue } - hostsCfg.Entries = append( - hostsCfg.Entries, - config.HostConfig{ - IP: ss[0], - Hostname: ss[1], + hostsCfg.Mappings = append( + hostsCfg.Mappings, + config.HostMappingConfig{ + Hostname: ss[0], + IP: ss[1], }, ) } diff --git a/cmd/gost/config.go b/cmd/gost/config.go index 2e0e78c..d8fdaa6 100644 --- a/cmd/gost/config.go +++ b/cmd/gost/config.go @@ -340,7 +340,14 @@ func bypassFromConfig(cfg *config.BypassConfig) bypass.Bypass { if cfg == nil { return nil } - return bypass.NewBypassPatterns(cfg.Reverse, cfg.Matchers...) + return bypass.NewBypassPatterns( + cfg.Reverse, + cfg.Matchers, + bypass.LoggerBypassOption(log.WithFields(map[string]interface{}{ + "kind": "bypass", + "bypass": cfg.Name, + })), + ) } func resolverFromConfig(cfg *config.ResolverConfig) (resolver.Resolver, error) { @@ -371,12 +378,16 @@ func resolverFromConfig(cfg *config.ResolverConfig) (resolver.Resolver, error) { } func hostsFromConfig(cfg *config.HostsConfig) hostspkg.HostMapper { - if cfg == nil || len(cfg.Entries) == 0 { + if cfg == nil || len(cfg.Mappings) == 0 { return nil } hosts := hostspkg.NewHosts() + hosts.Logger = log.WithFields(map[string]interface{}{ + "kind": "hosts", + "hosts": cfg.Name, + }) - for _, host := range cfg.Entries { + for _, host := range cfg.Mappings { if host.IP == "" || host.Hostname == "" { continue } diff --git a/gost.yml b/gost.yml index 106d83a..c795eba 100644 --- a/gost.yml +++ b/gost.yml @@ -295,7 +295,7 @@ resolvers: hosts: - name: hosts-0 - entries: + mappings: - ip: 127.0.0.1 hostname: localhost - ip: 192.168.1.10 diff --git a/pkg/bypass/bypass.go b/pkg/bypass/bypass.go index 719c85d..142b714 100644 --- a/pkg/bypass/bypass.go +++ b/pkg/bypass/bypass.go @@ -5,6 +5,7 @@ import ( "strconv" "strings" + "github.com/go-gost/gost/pkg/logger" glob "github.com/gobwas/glob" ) @@ -105,30 +106,48 @@ type Bypass interface { Contains(addr string) bool } +type bypassOptions struct { + logger logger.Logger +} + +type BypassOption func(opts *bypassOptions) + +func LoggerBypassOption(logger logger.Logger) BypassOption { + return func(opts *bypassOptions) { + opts.logger = logger + } +} + type bypass struct { matchers []Matcher reversed bool + options bypassOptions } // NewBypass creates and initializes a new Bypass using matchers as its match rules. // The rules will be reversed if the reversed is true. -func NewBypass(reversed bool, matchers ...Matcher) Bypass { +func NewBypass(reversed bool, matchers []Matcher, opts ...BypassOption) Bypass { + options := bypassOptions{} + for _, opt := range opts { + opt(&options) + } return &bypass{ matchers: matchers, reversed: reversed, + options: options, } } // NewBypassPatterns creates and initializes a new Bypass using matcher patterns as its match rules. // The rules will be reversed if the reverse is true. -func NewBypassPatterns(reversed bool, patterns ...string) Bypass { +func NewBypassPatterns(reversed bool, patterns []string, opts ...BypassOption) Bypass { var matchers []Matcher for _, pattern := range patterns { if m := NewMatcher(pattern); m != nil { matchers = append(matchers, m) } } - return NewBypass(reversed, matchers...) + return NewBypass(reversed, matchers, opts...) } func (bp *bypass) Contains(addr string) bool { @@ -153,6 +172,11 @@ func (bp *bypass) Contains(addr string) bool { break } } - return !bp.reversed && matched || + + b := !bp.reversed && matched || bp.reversed && !matched + if b { + bp.options.logger.Debugf("bypass: %s", addr) + } + return b } diff --git a/pkg/chain/router.go b/pkg/chain/router.go index 1556b0a..3061bee 100644 --- a/pkg/chain/router.go +++ b/pkg/chain/router.go @@ -79,9 +79,9 @@ func (r *Router) resolve(ctx context.Context, addr string) (string, error) { } if r.Hosts != nil { - if ip := r.Hosts.Lookup(host); ip != nil { - r.Logger.Debugf("hit hosts: %s -> %s", host, ip) - return net.JoinHostPort(ip.String(), port), nil + if ips, _ := r.Hosts.Lookup("ip", host); len(ips) > 0 { + r.Logger.Debugf("hit host mapper: %s -> %s", host, ips) + return net.JoinHostPort(ips[0].String(), port), nil } } diff --git a/pkg/config/config.go b/pkg/config/config.go index d542e66..95fe4cf 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -70,15 +70,15 @@ type ResolverConfig struct { Nameservers []NameserverConfig } -type HostConfig struct { +type HostMappingConfig struct { IP string Hostname string Aliases []string `yaml:",omitempty"` } type HostsConfig struct { - Name string - Entries []HostConfig + Name string + Mappings []HostMappingConfig } type ListenerConfig struct { diff --git a/pkg/handler/dns/handler.go b/pkg/handler/dns/handler.go index 002df53..8d3e09b 100644 --- a/pkg/handler/dns/handler.go +++ b/pkg/handler/dns/handler.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "errors" + "fmt" "net" "strconv" "strings" @@ -49,6 +50,8 @@ func NewHandler(opts ...handler.Option) handler.Handler { } func (h *dnsHandler) Init(md md.Metadata) (err error) { + h.logger = h.options.Logger + if err = h.parseMetadata(md); err != nil { return } @@ -58,10 +61,9 @@ func (h *dnsHandler) Init(md md.Metadata) (err error) { Retries: h.options.Retries, Chain: h.options.Chain, Resolver: h.options.Resolver, - Hosts: h.options.Hosts, - Logger: h.options.Logger, + // Hosts: h.options.Hosts, + Logger: h.options.Logger, } - h.logger = h.options.Logger for _, server := range h.md.dns { server = strings.TrimSpace(server) @@ -127,6 +129,7 @@ func (h *dnsHandler) Handle(ctx context.Context, conn net.Conn) { if err != nil { return } + defer bufpool.Put(&reply) if _, err = conn.Write(reply); err != nil { h.logger.Error(err) @@ -153,14 +156,31 @@ func (h *dnsHandler) exchange(ctx context.Context, msg []byte) ([]byte, error) { } var mr *dns.Msg - // cache only for single question message. + + if h.logger.IsLevelEnabled(logger.DebugLevel) { + defer func() { + if mr != nil { + h.logger.Debug(mr.String()) + } + }() + } + + mr = h.lookupHosts(&mq) + if mr != nil { + b := bufpool.Get(4096) + return mr.PackBuffer(*b) + } + + // only cache for single question message. if len(mq.Question) == 1 { key := resolver_util.NewCacheKey(&mq.Question[0]) mr = h.cache.Load(key) if mr != nil { h.logger.Debugf("exchange message %d (cached): %s", mq.Id, mq.Question[0].String()) mr.Id = mq.Id - return mr.Pack() + + b := bufpool.Get(4096) + return mr.PackBuffer(*b) } defer func() { @@ -170,7 +190,10 @@ func (h *dnsHandler) exchange(ctx context.Context, msg []byte) ([]byte, error) { }() } - query, err := mq.Pack() + b := bufpool.Get(4096) + defer bufpool.Put(b) + + query, err := mq.PackBuffer(*b) if err != nil { h.logger.Error(err) return nil, err @@ -204,6 +227,56 @@ func (h *dnsHandler) exchange(ctx context.Context, msg []byte) ([]byte, error) { return reply, nil } +// lookup host mapper +func (h *dnsHandler) lookupHosts(r *dns.Msg) (m *dns.Msg) { + if h.options.Hosts == nil || + r.Question[0].Qclass != dns.ClassINET || + (r.Question[0].Qtype != dns.TypeA && r.Question[0].Qtype != dns.TypeAAAA) { + return nil + } + + m = &dns.Msg{} + m.SetReply(r) + + host := strings.TrimSuffix(r.Question[0].Name, ".") + + switch r.Question[0].Qtype { + case dns.TypeA: + ips, _ := h.options.Hosts.Lookup("ip4", host) + if len(ips) == 0 { + return nil + } + h.logger.Debugf("hit host mapper: %s -> %s", host, ips) + + for _, ip := range ips { + rr, err := dns.NewRR(fmt.Sprintf("%s IN A %s\n", r.Question[0].Name, ip.String())) + if err != nil { + h.logger.Error(err) + return nil + } + m.Answer = append(m.Answer, rr) + } + + case dns.TypeAAAA: + ips, _ := h.options.Hosts.Lookup("ip6", host) + if len(ips) == 0 { + return nil + } + h.logger.Debugf("hit host mapper: %s -> %s", host, ips) + + for _, ip := range ips { + rr, err := dns.NewRR(fmt.Sprintf("%s IN AAAA %s\n", r.Question[0].Name, ip.String())) + if err != nil { + h.logger.Error(err) + return nil + } + m.Answer = append(m.Answer, rr) + } + } + + return +} + func (h *dnsHandler) dumpMsgHeader(m *dns.Msg) string { buf := new(bytes.Buffer) buf.WriteString(m.MsgHdr.String() + " ") diff --git a/pkg/hosts/hosts.go b/pkg/hosts/hosts.go index 983afbc..4eb5c98 100644 --- a/pkg/hosts/hosts.go +++ b/pkg/hosts/hosts.go @@ -2,17 +2,19 @@ package hosts import ( "net" + "sync" + + "github.com/go-gost/gost/pkg/logger" ) // HostMapper is a mapping from hostname to IP. type HostMapper interface { - Lookup(host string) net.IP + Lookup(network, host string) ([]net.IP, bool) } -type host struct { - IP net.IP +type hostMapping struct { + IPs []net.IP Hostname string - Aliases []string } // Hosts is a static table lookup for hostnames. @@ -21,7 +23,8 @@ type host struct { // Fields of the entry are separated by any number of blanks and/or tab characters. // Text from a "#" character until the end of the line is a comment, and is ignored. type Hosts struct { - mappings []host + mappings sync.Map + Logger logger.Logger } func NewHosts() *Hosts { @@ -30,30 +33,77 @@ func NewHosts() *Hosts { // Map maps ip to hostname or aliases. func (h *Hosts) Map(ip net.IP, hostname string, aliases ...string) { - h.mappings = append(h.mappings, host{ - IP: ip, - Hostname: hostname, - Aliases: aliases, - }) + if hostname == "" { + return + } + + v, _ := h.mappings.Load(hostname) + m, _ := v.(*hostMapping) + if m == nil { + m = &hostMapping{ + IPs: []net.IP{ip}, + Hostname: hostname, + } + } else { + m.IPs = append(m.IPs, ip) + } + h.mappings.Store(hostname, m) + + for _, alias := range aliases { + // indirect mapping from alias to hostname + if alias != "" { + h.mappings.Store(alias, &hostMapping{ + Hostname: hostname, + }) + } + } } -// Lookup searches the IP address corresponds to the given host from the host table. -func (h *Hosts) Lookup(host string) (ip net.IP) { +// Lookup searches the IP address corresponds to the given network and host from the host table. +// The network should be 'ip', 'ip4' or 'ip6', default network is 'ip'. +func (h *Hosts) Lookup(network, host string) (ips []net.IP, ok bool) { if h == nil || host == "" { return } - for _, h := range h.mappings { - if h.Hostname == host { - ip = h.IP - break - } - for _, alias := range h.Aliases { - if alias == host { - ip = h.IP - break - } + v, ok := h.mappings.Load(host) + if !ok { + return + } + m, _ := v.(*hostMapping) + if m == nil { + return + } + + // hostname alias + if host != m.Hostname { + v, _ = h.mappings.Load(m.Hostname) + m, _ = v.(*hostMapping) + if m == nil { + return } } + + switch network { + case "ip4": + for _, ip := range m.IPs { + if ip = ip.To4(); ip != nil { + ips = append(ips, ip) + } + } + case "ip6": + for _, ip := range m.IPs { + if ip.To4() == nil { + ips = append(ips, ip) + } + } + default: + ips = m.IPs + } + + if len(ips) > 0 { + h.Logger.Debugf("host mapper: %s -> %s", host, ips) + } + return } diff --git a/pkg/internal/util/resolver/cache.go b/pkg/internal/util/resolver/cache.go index ac38466..0865138 100644 --- a/pkg/internal/util/resolver/cache.go +++ b/pkg/internal/util/resolver/cache.go @@ -56,7 +56,7 @@ func (c *Cache) Load(key CacheKey) *dns.Msg { return nil } - c.logger.Debugf("resolver cache hit: %s", key) + c.logger.Debugf("hit resolver cache: %s", key) return item.msg.Copy() } diff --git a/pkg/metadata/metadata.go b/pkg/metadata/metadata.go index d4a12d3..9ce835b 100644 --- a/pkg/metadata/metadata.go +++ b/pkg/metadata/metadata.go @@ -105,7 +105,10 @@ func GetString(md Metadata, key string) (v string) { } func GetStrings(md Metadata, key string) (ss []string) { - if v, _ := md.Get(key).([]interface{}); len(v) > 0 { + switch v := md.Get(key).(type) { + case []string: + ss = v + case []interface{}: for _, vv := range v { if s, ok := vv.(string); ok { ss = append(ss, s) diff --git a/pkg/resolver/impl/resolver.go b/pkg/resolver/impl/resolver.go index 0ca4b47..74ef118 100644 --- a/pkg/resolver/impl/resolver.go +++ b/pkg/resolver/impl/resolver.go @@ -48,7 +48,6 @@ type resolver struct { servers []NameServer cache *resolver_util.Cache options resolverOptions - logger logger.Logger } func NewResolver(nameservers []NameServer, opts ...ResolverOption) (resolverpkg.Resolver, error) { @@ -87,7 +86,6 @@ func NewResolver(nameservers []NameServer, opts ...ResolverOption) (resolverpkg. servers: servers, cache: cache, options: options, - logger: options.logger, }, nil } @@ -104,11 +102,11 @@ func (r *resolver) Resolve(ctx context.Context, host string) (ips []net.IP, err for _, server := range r.servers { ips, err = r.resolve(ctx, &server, host) if err != nil { - r.logger.Error(err) + r.options.logger.Error(err) continue } - r.logger.Debugf("resolve %s via %s: %v", host, server.exchanger.String(), ips) + r.options.logger.Debugf("resolve %s via %s: %v", host, server.exchanger.String(), ips) if len(ips) > 0 { break