From de5ce1e1cadc91834584e2d4e09abc8b63adb2cc Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Fri, 14 Apr 2023 18:50:33 +0800 Subject: [PATCH] add async option for dns handler --- handler/dns/handler.go | 53 ++++++++++++++++++++------------- handler/dns/metadata.go | 3 ++ internal/util/resolver/cache.go | 49 +++++++++++++++++++++++------- resolver/resolver.go | 4 +-- 4 files changed, 76 insertions(+), 33 deletions(-) diff --git a/handler/dns/handler.go b/handler/dns/handler.go index d223355..2dd1bbd 100644 --- a/handler/dns/handler.go +++ b/handler/dns/handler.go @@ -142,7 +142,7 @@ func (h *dnsHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler. return err } - reply, err := h.exchange(ctx, (*b)[:n], log) + reply, err := h.request(ctx, (*b)[:n], log) if err != nil { return err } @@ -167,7 +167,7 @@ func (h *dnsHandler) checkRateLimit(addr net.Addr) bool { return true } -func (h *dnsHandler) exchange(ctx context.Context, msg []byte, log logger.Logger) ([]byte, error) { +func (h *dnsHandler) request(ctx context.Context, msg []byte, log logger.Logger) ([]byte, error) { mq := dns.Msg{} if err := mq.Unpack(msg); err != nil { log.Error(err) @@ -210,50 +210,63 @@ func (h *dnsHandler) exchange(ctx context.Context, msg []byte, log logger.Logger // only cache for single question message. if len(mq.Question) == 1 { - key := resolver_util.NewCacheKey(&mq.Question[0]) - mr = h.cache.Load(key) + var ttl time.Duration + mr, ttl = h.cache.Load(resolver_util.NewCacheKey(&mq.Question[0])) if mr != nil { - log.Debugf("exchange message %d (cached): %s", mq.Id, mq.Question[0].String()) mr.Id = mq.Id - - b := bufpool.Get(h.md.bufferSize) - return mr.PackBuffer(*b) - } - - defer func() { - if mr != nil { - h.cache.Store(key, mr, h.md.ttl) + if int32(ttl.Seconds()) > 0 { + log.Debugf("message %d (cached): %s", mq.Id, mq.Question[0].String()) + b := bufpool.Get(h.md.bufferSize) + return mr.PackBuffer(*b) } - }() + } } + if mr != nil && h.md.async { + b := bufpool.Get(h.md.bufferSize) + reply, err := mr.PackBuffer(*b) + if err != nil { + return nil, err + } + h.cache.RefreshTTL(resolver_util.NewCacheKey(&mq.Question[0])) + + log.Debugf("exchange message %d (async): %s", mq.Id, mq.Question[0].String()) + go h.exchange(ctx, &mq) + return reply, nil + } + + log.Debugf("exchange message %d: %s", mq.Id, mq.Question[0].String()) + return h.exchange(ctx, &mq) +} + +func (h *dnsHandler) exchange(ctx context.Context, mq *dns.Msg) ([]byte, error) { b := bufpool.Get(h.md.bufferSize) defer bufpool.Put(b) query, err := mq.PackBuffer(*b) if err != nil { - log.Error(err) return nil, err } ex := h.selectExchanger(ctx, strings.Trim(mq.Question[0].Name, ".")) if ex == nil { - err := fmt.Errorf("exchange not found for %s", mq.Question[0].Name) - log.Error(err) + err = fmt.Errorf("exchange not found for %s", mq.Question[0].Name) return nil, err } reply, err := ex.Exchange(ctx, query) if err != nil { - log.Error(err) return nil, err } - mr = &dns.Msg{} + mr := &dns.Msg{} if err = mr.Unpack(reply); err != nil { - log.Error(err) return nil, err } + if len(mq.Question) == 1 { + key := resolver_util.NewCacheKey(&mq.Question[0]) + h.cache.Store(key, mr, h.md.ttl) + } return reply, nil } diff --git a/handler/dns/metadata.go b/handler/dns/metadata.go index fc746c4..90ba3b2 100644 --- a/handler/dns/metadata.go +++ b/handler/dns/metadata.go @@ -21,6 +21,7 @@ type metadata struct { // nameservers dns []string bufferSize int + async bool } func (h *dnsHandler) parseMetadata(md mdata.Metadata) (err error) { @@ -31,6 +32,7 @@ func (h *dnsHandler) parseMetadata(md mdata.Metadata) (err error) { clientIP = "clientIP" dns = "dns" bufferSize = "bufferSize" + async = "async" ) h.md.readTimeout = mdutil.GetDuration(md, readTimeout) @@ -48,6 +50,7 @@ func (h *dnsHandler) parseMetadata(md mdata.Metadata) (err error) { if h.md.bufferSize <= 0 { h.md.bufferSize = defaultBufferSize } + h.md.async = mdutil.GetBool(md, async) return } diff --git a/internal/util/resolver/cache.go b/internal/util/resolver/cache.go index 177fe2f..eb910c2 100644 --- a/internal/util/resolver/cache.go +++ b/internal/util/resolver/cache.go @@ -9,6 +9,10 @@ import ( "github.com/miekg/dns" ) +const ( + defaultTTL = 60 * time.Second +) + type CacheKey string // NewCacheKey generates resolver cache key from question of dns query. @@ -40,25 +44,31 @@ func (c *Cache) WithLogger(logger logger.Logger) *Cache { return c } -func (c *Cache) Load(key CacheKey) *dns.Msg { +func (c *Cache) Load(key CacheKey) (msg *dns.Msg, ttl time.Duration) { v, ok := c.m.Load(key) if !ok { - return nil + return } item, ok := v.(*cacheItem) if !ok { - return nil + return } - if time.Since(item.ts) > item.ttl { - c.m.Delete(key) - return nil + msg = item.msg.Copy() + for i := range msg.Answer { + d := uint32(time.Since(item.ts).Seconds()) + if msg.Answer[i].Header().Ttl > d { + msg.Answer[i].Header().Ttl -= d + } else { + msg.Answer[i].Header().Ttl = 1 + } } + ttl = item.ttl - time.Since(item.ts) - c.logger.Debugf("hit resolver cache: %s", key) + c.logger.Debugf("hit resolver cache: %s, ttl: %v", key, ttl) - return item.msg.Copy() + return } func (c *Cache) Store(key CacheKey, mr *dns.Msg, ttl time.Duration) { @@ -73,9 +83,13 @@ func (c *Cache) Store(key CacheKey, mr *dns.Msg, ttl time.Duration) { ttl = v } } - } - if ttl == 0 { - ttl = 30 * time.Second + if ttl == 0 { + ttl = defaultTTL + } + } else { + for i := range mr.Answer { + mr.Answer[i].Header().Ttl = uint32(ttl.Seconds()) + } } c.m.Store(key, &cacheItem{ @@ -86,3 +100,16 @@ func (c *Cache) Store(key CacheKey, mr *dns.Msg, ttl time.Duration) { c.logger.Debugf("resolver cache store: %s, ttl: %v", key, ttl) } + +func (c *Cache) RefreshTTL(key CacheKey) { + v, ok := c.m.Load(key) + if !ok { + return + } + + item, ok := v.(*cacheItem) + if !ok { + return + } + item.ts = time.Now() +} diff --git a/resolver/resolver.go b/resolver/resolver.go index 536fa58..8d08cfc 100644 --- a/resolver/resolver.go +++ b/resolver/resolver.go @@ -150,8 +150,8 @@ func (r *resolver) resolve6(ctx context.Context, server *NameServer, host string func (r *resolver) resolveIPs(ctx context.Context, server *NameServer, mq *dns.Msg) (ips []net.IP, err error) { key := resolver_util.NewCacheKey(&mq.Question[0]) - mr := r.cache.Load(key) - if mr == nil { + mr, ttl := r.cache.Load(key) + if ttl <= 0 { resolver_util.AddSubnetOpt(mq, server.ClientIP) mr, err = r.exchange(ctx, server.exchanger, mq) if err != nil {