update resolver
This commit is contained in:
@ -22,6 +22,8 @@ type NameServer struct {
|
||||
ClientIP net.IP
|
||||
Prefer string
|
||||
Hostname string // for TLS handshake verification
|
||||
Async bool
|
||||
Only string
|
||||
exchanger exchanger.Exchanger
|
||||
}
|
||||
|
||||
@ -79,6 +81,16 @@ func NewResolver(nameservers []NameServer, opts ...Option) (resolver.Resolver, e
|
||||
}
|
||||
|
||||
server.exchanger = ex
|
||||
|
||||
switch server.Only {
|
||||
case "ip4", "ipv4", "ip6", "ipv6":
|
||||
server.Prefer = server.Only
|
||||
default:
|
||||
server.Only = ""
|
||||
}
|
||||
if server.TTL < 0 {
|
||||
server.Async = false
|
||||
}
|
||||
servers = append(servers, server)
|
||||
}
|
||||
cache := resolver_util.NewCache().
|
||||
@ -102,7 +114,11 @@ func (r *localResolver) Resolve(ctx context.Context, network, host string, opts
|
||||
}
|
||||
|
||||
for _, server := range r.servers {
|
||||
ips, err = r.resolve(ctx, &server, host)
|
||||
if server.Async {
|
||||
ips, err = r.resolveAsync(ctx, &server, host)
|
||||
} else {
|
||||
ips, err = r.resolve(ctx, &server, host)
|
||||
}
|
||||
if err != nil {
|
||||
r.options.logger.Error(err)
|
||||
continue
|
||||
@ -124,18 +140,70 @@ func (r *localResolver) resolve(ctx context.Context, server *NameServer, host st
|
||||
}
|
||||
|
||||
if server.Prefer == "ipv6" { // prefer ipv6
|
||||
if ips, err = r.resolve6(ctx, server, host); len(ips) > 0 {
|
||||
if ips, err = r.resolve6(ctx, server, host); len(ips) > 0 || server.Only == "ipv6" {
|
||||
return
|
||||
}
|
||||
return r.resolve4(ctx, server, host)
|
||||
}
|
||||
|
||||
if ips, err = r.resolve4(ctx, server, host); len(ips) > 0 {
|
||||
if ips, err = r.resolve4(ctx, server, host); len(ips) > 0 || server.Only == "ipv4" {
|
||||
return
|
||||
}
|
||||
return r.resolve6(ctx, server, host)
|
||||
}
|
||||
|
||||
func (r *localResolver) resolveAsync(ctx context.Context, server *NameServer, host string) (ips []net.IP, err error) {
|
||||
ips, ttl, ok := r.lookupCache(ctx, server, host)
|
||||
if !ok {
|
||||
return r.resolve(ctx, server, host)
|
||||
}
|
||||
|
||||
if ttl <= 0 {
|
||||
r.options.logger.Debugf("async resolve %s via %s", host, server.exchanger.String())
|
||||
go r.resolve(ctx, server, host)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (r *localResolver) lookupCache(ctx context.Context, server *NameServer, host string) (ips []net.IP, ttl time.Duration, ok bool) {
|
||||
lookup := func(t uint16, host string) (ips []net.IP, ttl time.Duration, ok bool) {
|
||||
mq := dns.Msg{}
|
||||
mq.SetQuestion(dns.Fqdn(host), t)
|
||||
mr, ttl := r.cache.Load(resolver_util.NewCacheKey(&mq.Question[0]))
|
||||
if mr == nil {
|
||||
return
|
||||
}
|
||||
|
||||
ok = true
|
||||
|
||||
for _, ans := range mr.Answer {
|
||||
if ar, _ := ans.(*dns.AAAA); ar != nil {
|
||||
ips = append(ips, ar.AAAA)
|
||||
}
|
||||
if ar, _ := ans.(*dns.A); ar != nil {
|
||||
ips = append(ips, ar.A)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if server.Prefer == "ipv6" {
|
||||
ips, ttl, ok = lookup(dns.TypeAAAA, host)
|
||||
if len(ips) > 0 || server.Only == "ipv6" {
|
||||
return
|
||||
}
|
||||
|
||||
ips, ttl, ok = lookup(dns.TypeA, host)
|
||||
return
|
||||
}
|
||||
|
||||
ips, ttl, ok = lookup(dns.TypeA, host)
|
||||
if len(ips) > 0 || server.Only == "ipv4" {
|
||||
return
|
||||
}
|
||||
return lookup(dns.TypeAAAA, host)
|
||||
}
|
||||
|
||||
func (r *localResolver) resolve4(ctx context.Context, server *NameServer, host string) (ips []net.IP, err error) {
|
||||
mq := dns.Msg{}
|
||||
mq.SetQuestion(dns.Fqdn(host), dns.TypeA)
|
||||
@ -149,6 +217,10 @@ func (r *localResolver) resolve6(ctx context.Context, server *NameServer, host s
|
||||
}
|
||||
|
||||
func (r *localResolver) resolveIPs(ctx context.Context, server *NameServer, mq *dns.Msg) (ips []net.IP, err error) {
|
||||
if r.options.logger.IsLevelEnabled(logger.TraceLevel) {
|
||||
r.options.logger.Trace(mq.String())
|
||||
}
|
||||
|
||||
key := resolver_util.NewCacheKey(&mq.Question[0])
|
||||
mr, ttl := r.cache.Load(key)
|
||||
if ttl <= 0 {
|
||||
@ -158,6 +230,10 @@ func (r *localResolver) resolveIPs(ctx context.Context, server *NameServer, mq *
|
||||
return
|
||||
}
|
||||
r.cache.Store(key, mr, server.TTL)
|
||||
|
||||
if r.options.logger.IsLevelEnabled(logger.TraceLevel) {
|
||||
r.options.logger.Trace(mr.String())
|
||||
}
|
||||
}
|
||||
|
||||
for _, ans := range mr.Answer {
|
||||
|
Reference in New Issue
Block a user