add async option for dns handler

This commit is contained in:
ginuerzh
2023-04-14 18:50:33 +08:00
parent 18fa84b51f
commit de5ce1e1ca
4 changed files with 76 additions and 33 deletions

View File

@ -142,7 +142,7 @@ func (h *dnsHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.
return err
}
reply, err := h.exchange(ctx, (*b)[:n], log)
reply, err := h.request(ctx, (*b)[:n], log)
if err != nil {
return err
}
@ -167,7 +167,7 @@ func (h *dnsHandler) checkRateLimit(addr net.Addr) bool {
return true
}
func (h *dnsHandler) exchange(ctx context.Context, msg []byte, log logger.Logger) ([]byte, error) {
func (h *dnsHandler) request(ctx context.Context, msg []byte, log logger.Logger) ([]byte, error) {
mq := dns.Msg{}
if err := mq.Unpack(msg); err != nil {
log.Error(err)
@ -210,50 +210,63 @@ func (h *dnsHandler) exchange(ctx context.Context, msg []byte, log logger.Logger
// only cache for single question message.
if len(mq.Question) == 1 {
key := resolver_util.NewCacheKey(&mq.Question[0])
mr = h.cache.Load(key)
var ttl time.Duration
mr, ttl = h.cache.Load(resolver_util.NewCacheKey(&mq.Question[0]))
if mr != nil {
log.Debugf("exchange message %d (cached): %s", mq.Id, mq.Question[0].String())
mr.Id = mq.Id
b := bufpool.Get(h.md.bufferSize)
return mr.PackBuffer(*b)
}
defer func() {
if mr != nil {
h.cache.Store(key, mr, h.md.ttl)
if int32(ttl.Seconds()) > 0 {
log.Debugf("message %d (cached): %s", mq.Id, mq.Question[0].String())
b := bufpool.Get(h.md.bufferSize)
return mr.PackBuffer(*b)
}
}()
}
}
if mr != nil && h.md.async {
b := bufpool.Get(h.md.bufferSize)
reply, err := mr.PackBuffer(*b)
if err != nil {
return nil, err
}
h.cache.RefreshTTL(resolver_util.NewCacheKey(&mq.Question[0]))
log.Debugf("exchange message %d (async): %s", mq.Id, mq.Question[0].String())
go h.exchange(ctx, &mq)
return reply, nil
}
log.Debugf("exchange message %d: %s", mq.Id, mq.Question[0].String())
return h.exchange(ctx, &mq)
}
func (h *dnsHandler) exchange(ctx context.Context, mq *dns.Msg) ([]byte, error) {
b := bufpool.Get(h.md.bufferSize)
defer bufpool.Put(b)
query, err := mq.PackBuffer(*b)
if err != nil {
log.Error(err)
return nil, err
}
ex := h.selectExchanger(ctx, strings.Trim(mq.Question[0].Name, "."))
if ex == nil {
err := fmt.Errorf("exchange not found for %s", mq.Question[0].Name)
log.Error(err)
err = fmt.Errorf("exchange not found for %s", mq.Question[0].Name)
return nil, err
}
reply, err := ex.Exchange(ctx, query)
if err != nil {
log.Error(err)
return nil, err
}
mr = &dns.Msg{}
mr := &dns.Msg{}
if err = mr.Unpack(reply); err != nil {
log.Error(err)
return nil, err
}
if len(mq.Question) == 1 {
key := resolver_util.NewCacheKey(&mq.Question[0])
h.cache.Store(key, mr, h.md.ttl)
}
return reply, nil
}

View File

@ -21,6 +21,7 @@ type metadata struct {
// nameservers
dns []string
bufferSize int
async bool
}
func (h *dnsHandler) parseMetadata(md mdata.Metadata) (err error) {
@ -31,6 +32,7 @@ func (h *dnsHandler) parseMetadata(md mdata.Metadata) (err error) {
clientIP = "clientIP"
dns = "dns"
bufferSize = "bufferSize"
async = "async"
)
h.md.readTimeout = mdutil.GetDuration(md, readTimeout)
@ -48,6 +50,7 @@ func (h *dnsHandler) parseMetadata(md mdata.Metadata) (err error) {
if h.md.bufferSize <= 0 {
h.md.bufferSize = defaultBufferSize
}
h.md.async = mdutil.GetBool(md, async)
return
}