add hosts support for dns
This commit is contained in:
@ -4,6 +4,7 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
@ -49,6 +50,8 @@ func NewHandler(opts ...handler.Option) handler.Handler {
|
||||
}
|
||||
|
||||
func (h *dnsHandler) Init(md md.Metadata) (err error) {
|
||||
h.logger = h.options.Logger
|
||||
|
||||
if err = h.parseMetadata(md); err != nil {
|
||||
return
|
||||
}
|
||||
@ -58,10 +61,9 @@ func (h *dnsHandler) Init(md md.Metadata) (err error) {
|
||||
Retries: h.options.Retries,
|
||||
Chain: h.options.Chain,
|
||||
Resolver: h.options.Resolver,
|
||||
Hosts: h.options.Hosts,
|
||||
Logger: h.options.Logger,
|
||||
// Hosts: h.options.Hosts,
|
||||
Logger: h.options.Logger,
|
||||
}
|
||||
h.logger = h.options.Logger
|
||||
|
||||
for _, server := range h.md.dns {
|
||||
server = strings.TrimSpace(server)
|
||||
@ -127,6 +129,7 @@ func (h *dnsHandler) Handle(ctx context.Context, conn net.Conn) {
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer bufpool.Put(&reply)
|
||||
|
||||
if _, err = conn.Write(reply); err != nil {
|
||||
h.logger.Error(err)
|
||||
@ -153,14 +156,31 @@ func (h *dnsHandler) exchange(ctx context.Context, msg []byte) ([]byte, error) {
|
||||
}
|
||||
|
||||
var mr *dns.Msg
|
||||
// cache only for single question message.
|
||||
|
||||
if h.logger.IsLevelEnabled(logger.DebugLevel) {
|
||||
defer func() {
|
||||
if mr != nil {
|
||||
h.logger.Debug(mr.String())
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
mr = h.lookupHosts(&mq)
|
||||
if mr != nil {
|
||||
b := bufpool.Get(4096)
|
||||
return mr.PackBuffer(*b)
|
||||
}
|
||||
|
||||
// only cache for single question message.
|
||||
if len(mq.Question) == 1 {
|
||||
key := resolver_util.NewCacheKey(&mq.Question[0])
|
||||
mr = h.cache.Load(key)
|
||||
if mr != nil {
|
||||
h.logger.Debugf("exchange message %d (cached): %s", mq.Id, mq.Question[0].String())
|
||||
mr.Id = mq.Id
|
||||
return mr.Pack()
|
||||
|
||||
b := bufpool.Get(4096)
|
||||
return mr.PackBuffer(*b)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
@ -170,7 +190,10 @@ func (h *dnsHandler) exchange(ctx context.Context, msg []byte) ([]byte, error) {
|
||||
}()
|
||||
}
|
||||
|
||||
query, err := mq.Pack()
|
||||
b := bufpool.Get(4096)
|
||||
defer bufpool.Put(b)
|
||||
|
||||
query, err := mq.PackBuffer(*b)
|
||||
if err != nil {
|
||||
h.logger.Error(err)
|
||||
return nil, err
|
||||
@ -204,6 +227,56 @@ func (h *dnsHandler) exchange(ctx context.Context, msg []byte) ([]byte, error) {
|
||||
return reply, nil
|
||||
}
|
||||
|
||||
// lookup host mapper
|
||||
func (h *dnsHandler) lookupHosts(r *dns.Msg) (m *dns.Msg) {
|
||||
if h.options.Hosts == nil ||
|
||||
r.Question[0].Qclass != dns.ClassINET ||
|
||||
(r.Question[0].Qtype != dns.TypeA && r.Question[0].Qtype != dns.TypeAAAA) {
|
||||
return nil
|
||||
}
|
||||
|
||||
m = &dns.Msg{}
|
||||
m.SetReply(r)
|
||||
|
||||
host := strings.TrimSuffix(r.Question[0].Name, ".")
|
||||
|
||||
switch r.Question[0].Qtype {
|
||||
case dns.TypeA:
|
||||
ips, _ := h.options.Hosts.Lookup("ip4", host)
|
||||
if len(ips) == 0 {
|
||||
return nil
|
||||
}
|
||||
h.logger.Debugf("hit host mapper: %s -> %s", host, ips)
|
||||
|
||||
for _, ip := range ips {
|
||||
rr, err := dns.NewRR(fmt.Sprintf("%s IN A %s\n", r.Question[0].Name, ip.String()))
|
||||
if err != nil {
|
||||
h.logger.Error(err)
|
||||
return nil
|
||||
}
|
||||
m.Answer = append(m.Answer, rr)
|
||||
}
|
||||
|
||||
case dns.TypeAAAA:
|
||||
ips, _ := h.options.Hosts.Lookup("ip6", host)
|
||||
if len(ips) == 0 {
|
||||
return nil
|
||||
}
|
||||
h.logger.Debugf("hit host mapper: %s -> %s", host, ips)
|
||||
|
||||
for _, ip := range ips {
|
||||
rr, err := dns.NewRR(fmt.Sprintf("%s IN AAAA %s\n", r.Question[0].Name, ip.String()))
|
||||
if err != nil {
|
||||
h.logger.Error(err)
|
||||
return nil
|
||||
}
|
||||
m.Answer = append(m.Answer, rr)
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (h *dnsHandler) dumpMsgHeader(m *dns.Msg) string {
|
||||
buf := new(bytes.Buffer)
|
||||
buf.WriteString(m.MsgHdr.String() + " ")
|
||||
|
Reference in New Issue
Block a user