add async option for dns handler
This commit is contained in:
parent
18fa84b51f
commit
de5ce1e1ca
@ -142,7 +142,7 @@ func (h *dnsHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
reply, err := h.exchange(ctx, (*b)[:n], log)
|
reply, err := h.request(ctx, (*b)[:n], log)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -167,7 +167,7 @@ func (h *dnsHandler) checkRateLimit(addr net.Addr) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *dnsHandler) exchange(ctx context.Context, msg []byte, log logger.Logger) ([]byte, error) {
|
func (h *dnsHandler) request(ctx context.Context, msg []byte, log logger.Logger) ([]byte, error) {
|
||||||
mq := dns.Msg{}
|
mq := dns.Msg{}
|
||||||
if err := mq.Unpack(msg); err != nil {
|
if err := mq.Unpack(msg); err != nil {
|
||||||
log.Error(err)
|
log.Error(err)
|
||||||
@ -210,50 +210,63 @@ func (h *dnsHandler) exchange(ctx context.Context, msg []byte, log logger.Logger
|
|||||||
|
|
||||||
// only cache for single question message.
|
// only cache for single question message.
|
||||||
if len(mq.Question) == 1 {
|
if len(mq.Question) == 1 {
|
||||||
key := resolver_util.NewCacheKey(&mq.Question[0])
|
var ttl time.Duration
|
||||||
mr = h.cache.Load(key)
|
mr, ttl = h.cache.Load(resolver_util.NewCacheKey(&mq.Question[0]))
|
||||||
if mr != nil {
|
if mr != nil {
|
||||||
log.Debugf("exchange message %d (cached): %s", mq.Id, mq.Question[0].String())
|
|
||||||
mr.Id = mq.Id
|
mr.Id = mq.Id
|
||||||
|
if int32(ttl.Seconds()) > 0 {
|
||||||
|
log.Debugf("message %d (cached): %s", mq.Id, mq.Question[0].String())
|
||||||
b := bufpool.Get(h.md.bufferSize)
|
b := bufpool.Get(h.md.bufferSize)
|
||||||
return mr.PackBuffer(*b)
|
return mr.PackBuffer(*b)
|
||||||
}
|
}
|
||||||
|
|
||||||
defer func() {
|
|
||||||
if mr != nil {
|
|
||||||
h.cache.Store(key, mr, h.md.ttl)
|
|
||||||
}
|
}
|
||||||
}()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if mr != nil && h.md.async {
|
||||||
|
b := bufpool.Get(h.md.bufferSize)
|
||||||
|
reply, err := mr.PackBuffer(*b)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
h.cache.RefreshTTL(resolver_util.NewCacheKey(&mq.Question[0]))
|
||||||
|
|
||||||
|
log.Debugf("exchange message %d (async): %s", mq.Id, mq.Question[0].String())
|
||||||
|
go h.exchange(ctx, &mq)
|
||||||
|
return reply, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("exchange message %d: %s", mq.Id, mq.Question[0].String())
|
||||||
|
return h.exchange(ctx, &mq)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *dnsHandler) exchange(ctx context.Context, mq *dns.Msg) ([]byte, error) {
|
||||||
b := bufpool.Get(h.md.bufferSize)
|
b := bufpool.Get(h.md.bufferSize)
|
||||||
defer bufpool.Put(b)
|
defer bufpool.Put(b)
|
||||||
|
|
||||||
query, err := mq.PackBuffer(*b)
|
query, err := mq.PackBuffer(*b)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
ex := h.selectExchanger(ctx, strings.Trim(mq.Question[0].Name, "."))
|
ex := h.selectExchanger(ctx, strings.Trim(mq.Question[0].Name, "."))
|
||||||
if ex == nil {
|
if ex == nil {
|
||||||
err := fmt.Errorf("exchange not found for %s", mq.Question[0].Name)
|
err = fmt.Errorf("exchange not found for %s", mq.Question[0].Name)
|
||||||
log.Error(err)
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
reply, err := ex.Exchange(ctx, query)
|
reply, err := ex.Exchange(ctx, query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
mr = &dns.Msg{}
|
mr := &dns.Msg{}
|
||||||
if err = mr.Unpack(reply); err != nil {
|
if err = mr.Unpack(reply); err != nil {
|
||||||
log.Error(err)
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if len(mq.Question) == 1 {
|
||||||
|
key := resolver_util.NewCacheKey(&mq.Question[0])
|
||||||
|
h.cache.Store(key, mr, h.md.ttl)
|
||||||
|
}
|
||||||
|
|
||||||
return reply, nil
|
return reply, nil
|
||||||
}
|
}
|
||||||
|
@ -21,6 +21,7 @@ type metadata struct {
|
|||||||
// nameservers
|
// nameservers
|
||||||
dns []string
|
dns []string
|
||||||
bufferSize int
|
bufferSize int
|
||||||
|
async bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *dnsHandler) parseMetadata(md mdata.Metadata) (err error) {
|
func (h *dnsHandler) parseMetadata(md mdata.Metadata) (err error) {
|
||||||
@ -31,6 +32,7 @@ func (h *dnsHandler) parseMetadata(md mdata.Metadata) (err error) {
|
|||||||
clientIP = "clientIP"
|
clientIP = "clientIP"
|
||||||
dns = "dns"
|
dns = "dns"
|
||||||
bufferSize = "bufferSize"
|
bufferSize = "bufferSize"
|
||||||
|
async = "async"
|
||||||
)
|
)
|
||||||
|
|
||||||
h.md.readTimeout = mdutil.GetDuration(md, readTimeout)
|
h.md.readTimeout = mdutil.GetDuration(md, readTimeout)
|
||||||
@ -48,6 +50,7 @@ func (h *dnsHandler) parseMetadata(md mdata.Metadata) (err error) {
|
|||||||
if h.md.bufferSize <= 0 {
|
if h.md.bufferSize <= 0 {
|
||||||
h.md.bufferSize = defaultBufferSize
|
h.md.bufferSize = defaultBufferSize
|
||||||
}
|
}
|
||||||
|
h.md.async = mdutil.GetBool(md, async)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -9,6 +9,10 @@ import (
|
|||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultTTL = 60 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
type CacheKey string
|
type CacheKey string
|
||||||
|
|
||||||
// NewCacheKey generates resolver cache key from question of dns query.
|
// NewCacheKey generates resolver cache key from question of dns query.
|
||||||
@ -40,25 +44,31 @@ func (c *Cache) WithLogger(logger logger.Logger) *Cache {
|
|||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Cache) Load(key CacheKey) *dns.Msg {
|
func (c *Cache) Load(key CacheKey) (msg *dns.Msg, ttl time.Duration) {
|
||||||
v, ok := c.m.Load(key)
|
v, ok := c.m.Load(key)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
item, ok := v.(*cacheItem)
|
item, ok := v.(*cacheItem)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if time.Since(item.ts) > item.ttl {
|
msg = item.msg.Copy()
|
||||||
c.m.Delete(key)
|
for i := range msg.Answer {
|
||||||
return nil
|
d := uint32(time.Since(item.ts).Seconds())
|
||||||
|
if msg.Answer[i].Header().Ttl > d {
|
||||||
|
msg.Answer[i].Header().Ttl -= d
|
||||||
|
} else {
|
||||||
|
msg.Answer[i].Header().Ttl = 1
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
ttl = item.ttl - time.Since(item.ts)
|
||||||
|
|
||||||
c.logger.Debugf("hit resolver cache: %s", key)
|
c.logger.Debugf("hit resolver cache: %s, ttl: %v", key, ttl)
|
||||||
|
|
||||||
return item.msg.Copy()
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Cache) Store(key CacheKey, mr *dns.Msg, ttl time.Duration) {
|
func (c *Cache) Store(key CacheKey, mr *dns.Msg, ttl time.Duration) {
|
||||||
@ -73,9 +83,13 @@ func (c *Cache) Store(key CacheKey, mr *dns.Msg, ttl time.Duration) {
|
|||||||
ttl = v
|
ttl = v
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
if ttl == 0 {
|
if ttl == 0 {
|
||||||
ttl = 30 * time.Second
|
ttl = defaultTTL
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for i := range mr.Answer {
|
||||||
|
mr.Answer[i].Header().Ttl = uint32(ttl.Seconds())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
c.m.Store(key, &cacheItem{
|
c.m.Store(key, &cacheItem{
|
||||||
@ -86,3 +100,16 @@ func (c *Cache) Store(key CacheKey, mr *dns.Msg, ttl time.Duration) {
|
|||||||
|
|
||||||
c.logger.Debugf("resolver cache store: %s, ttl: %v", key, ttl)
|
c.logger.Debugf("resolver cache store: %s, ttl: %v", key, ttl)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Cache) RefreshTTL(key CacheKey) {
|
||||||
|
v, ok := c.m.Load(key)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
item, ok := v.(*cacheItem)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
item.ts = time.Now()
|
||||||
|
}
|
||||||
|
@ -150,8 +150,8 @@ func (r *resolver) resolve6(ctx context.Context, server *NameServer, host string
|
|||||||
|
|
||||||
func (r *resolver) resolveIPs(ctx context.Context, server *NameServer, mq *dns.Msg) (ips []net.IP, err error) {
|
func (r *resolver) resolveIPs(ctx context.Context, server *NameServer, mq *dns.Msg) (ips []net.IP, err error) {
|
||||||
key := resolver_util.NewCacheKey(&mq.Question[0])
|
key := resolver_util.NewCacheKey(&mq.Question[0])
|
||||||
mr := r.cache.Load(key)
|
mr, ttl := r.cache.Load(key)
|
||||||
if mr == nil {
|
if ttl <= 0 {
|
||||||
resolver_util.AddSubnetOpt(mq, server.ClientIP)
|
resolver_util.AddSubnetOpt(mq, server.ClientIP)
|
||||||
mr, err = r.exchange(ctx, server.exchanger, mq)
|
mr, err = r.exchange(ctx, server.exchanger, mq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
Loading…
Reference in New Issue
Block a user