update resolver

This commit is contained in:
ginuerzh
2023-11-08 20:51:43 +08:00
parent 6980fcfb19
commit a697bf2e35
4 changed files with 84 additions and 4 deletions

View File

@ -198,6 +198,8 @@ type NameserverConfig struct {
Hostname string `yaml:",omitempty" json:"hostname,omitempty"`
TTL time.Duration `yaml:",omitempty" json:"ttl,omitempty"`
Timeout time.Duration `yaml:",omitempty" json:"timeout,omitempty"`
Async bool `yaml:",omitempty" json:"async,omitempty"`
Only string `yaml:",omitempty" json:"only,omitempty"`
}
type ResolverConfig struct {

View File

@ -52,6 +52,8 @@ func ParseResolver(cfg *config.ResolverConfig) (resolver.Resolver, error) {
ClientIP: net.ParseIP(server.ClientIP),
Prefer: server.Prefer,
Hostname: server.Hostname,
Async: server.Async,
Only: server.Only,
})
}

View File

@ -66,7 +66,7 @@ func (c *Cache) Load(key CacheKey) (msg *dns.Msg, ttl time.Duration) {
}
ttl = item.ttl - time.Since(item.ts)
c.logger.Debugf("hit resolver cache: %s, ttl: %v", key, ttl)
c.logger.Debugf("resolver cache hit: %s, ttl: %v", key, ttl)
return
}

View File

@ -22,6 +22,8 @@ type NameServer struct {
ClientIP net.IP
Prefer string
Hostname string // for TLS handshake verification
Async bool
Only string
exchanger exchanger.Exchanger
}
@ -79,6 +81,16 @@ func NewResolver(nameservers []NameServer, opts ...Option) (resolver.Resolver, e
}
server.exchanger = ex
switch server.Only {
case "ip4", "ipv4", "ip6", "ipv6":
server.Prefer = server.Only
default:
server.Only = ""
}
if server.TTL < 0 {
server.Async = false
}
servers = append(servers, server)
}
cache := resolver_util.NewCache().
@ -102,7 +114,11 @@ func (r *localResolver) Resolve(ctx context.Context, network, host string, opts
}
for _, server := range r.servers {
ips, err = r.resolve(ctx, &server, host)
if server.Async {
ips, err = r.resolveAsync(ctx, &server, host)
} else {
ips, err = r.resolve(ctx, &server, host)
}
if err != nil {
r.options.logger.Error(err)
continue
@ -124,18 +140,70 @@ func (r *localResolver) resolve(ctx context.Context, server *NameServer, host st
}
if server.Prefer == "ipv6" { // prefer ipv6
if ips, err = r.resolve6(ctx, server, host); len(ips) > 0 {
if ips, err = r.resolve6(ctx, server, host); len(ips) > 0 || server.Only == "ipv6" {
return
}
return r.resolve4(ctx, server, host)
}
if ips, err = r.resolve4(ctx, server, host); len(ips) > 0 {
if ips, err = r.resolve4(ctx, server, host); len(ips) > 0 || server.Only == "ipv4" {
return
}
return r.resolve6(ctx, server, host)
}
func (r *localResolver) resolveAsync(ctx context.Context, server *NameServer, host string) (ips []net.IP, err error) {
ips, ttl, ok := r.lookupCache(ctx, server, host)
if !ok {
return r.resolve(ctx, server, host)
}
if ttl <= 0 {
r.options.logger.Debugf("async resolve %s via %s", host, server.exchanger.String())
go r.resolve(ctx, server, host)
}
return
}
func (r *localResolver) lookupCache(ctx context.Context, server *NameServer, host string) (ips []net.IP, ttl time.Duration, ok bool) {
lookup := func(t uint16, host string) (ips []net.IP, ttl time.Duration, ok bool) {
mq := dns.Msg{}
mq.SetQuestion(dns.Fqdn(host), t)
mr, ttl := r.cache.Load(resolver_util.NewCacheKey(&mq.Question[0]))
if mr == nil {
return
}
ok = true
for _, ans := range mr.Answer {
if ar, _ := ans.(*dns.AAAA); ar != nil {
ips = append(ips, ar.AAAA)
}
if ar, _ := ans.(*dns.A); ar != nil {
ips = append(ips, ar.A)
}
}
return
}
if server.Prefer == "ipv6" {
ips, ttl, ok = lookup(dns.TypeAAAA, host)
if len(ips) > 0 || server.Only == "ipv6" {
return
}
ips, ttl, ok = lookup(dns.TypeA, host)
return
}
ips, ttl, ok = lookup(dns.TypeA, host)
if len(ips) > 0 || server.Only == "ipv4" {
return
}
return lookup(dns.TypeAAAA, host)
}
func (r *localResolver) resolve4(ctx context.Context, server *NameServer, host string) (ips []net.IP, err error) {
mq := dns.Msg{}
mq.SetQuestion(dns.Fqdn(host), dns.TypeA)
@ -149,6 +217,10 @@ func (r *localResolver) resolve6(ctx context.Context, server *NameServer, host s
}
func (r *localResolver) resolveIPs(ctx context.Context, server *NameServer, mq *dns.Msg) (ips []net.IP, err error) {
if r.options.logger.IsLevelEnabled(logger.TraceLevel) {
r.options.logger.Trace(mq.String())
}
key := resolver_util.NewCacheKey(&mq.Question[0])
mr, ttl := r.cache.Load(key)
if ttl <= 0 {
@ -158,6 +230,10 @@ func (r *localResolver) resolveIPs(ctx context.Context, server *NameServer, mq *
return
}
r.cache.Store(key, mr, server.TTL)
if r.options.logger.IsLevelEnabled(logger.TraceLevel) {
r.options.logger.Trace(mr.String())
}
}
for _, ans := range mr.Answer {