add sshd listener
This commit is contained in:
@ -33,7 +33,6 @@ type dnsHandler struct {
|
||||
exchangers []exchanger.Exchanger
|
||||
cache *resolver_util.Cache
|
||||
router *chain.Router
|
||||
logger logger.Logger
|
||||
md metadata
|
||||
options handler.Options
|
||||
}
|
||||
@ -50,19 +49,18 @@ func NewHandler(opts ...handler.Option) handler.Handler {
|
||||
}
|
||||
|
||||
func (h *dnsHandler) Init(md md.Metadata) (err error) {
|
||||
h.logger = h.options.Logger
|
||||
|
||||
if err = h.parseMetadata(md); err != nil {
|
||||
return
|
||||
}
|
||||
log := h.options.Logger
|
||||
|
||||
h.cache = resolver_util.NewCache().WithLogger(h.options.Logger)
|
||||
h.cache = resolver_util.NewCache().WithLogger(log)
|
||||
h.router = &chain.Router{
|
||||
Retries: h.options.Retries,
|
||||
Chain: h.options.Chain,
|
||||
Resolver: h.options.Resolver,
|
||||
// Hosts: h.options.Hosts,
|
||||
Logger: h.options.Logger,
|
||||
Logger: log,
|
||||
}
|
||||
|
||||
for _, server := range h.md.dns {
|
||||
@ -74,10 +72,10 @@ func (h *dnsHandler) Init(md md.Metadata) (err error) {
|
||||
server,
|
||||
exchanger.RouterOption(h.router),
|
||||
exchanger.TimeoutOption(h.md.timeout),
|
||||
exchanger.LoggerOption(h.logger),
|
||||
exchanger.LoggerOption(log),
|
||||
)
|
||||
if err != nil {
|
||||
h.logger.Warnf("parse %s: %v", server, err)
|
||||
log.Warnf("parse %s: %v", server, err)
|
||||
continue
|
||||
}
|
||||
h.exchangers = append(h.exchangers, ex)
|
||||
@ -87,9 +85,9 @@ func (h *dnsHandler) Init(md md.Metadata) (err error) {
|
||||
defaultNameserver,
|
||||
exchanger.RouterOption(h.router),
|
||||
exchanger.TimeoutOption(h.md.timeout),
|
||||
exchanger.LoggerOption(h.logger),
|
||||
exchanger.LoggerOption(log),
|
||||
)
|
||||
h.logger.Warnf("resolver not found, default to %s", defaultNameserver)
|
||||
log.Warnf("resolver not found, default to %s", defaultNameserver)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -103,14 +101,14 @@ func (h *dnsHandler) Handle(ctx context.Context, conn net.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
start := time.Now()
|
||||
h.logger = h.logger.WithFields(map[string]interface{}{
|
||||
log := h.options.Logger.WithFields(map[string]interface{}{
|
||||
"remote": conn.RemoteAddr().String(),
|
||||
"local": conn.LocalAddr().String(),
|
||||
})
|
||||
|
||||
h.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
defer func() {
|
||||
h.logger.WithFields(map[string]interface{}{
|
||||
log.WithFields(map[string]interface{}{
|
||||
"duration": time.Since(start),
|
||||
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
}()
|
||||
@ -120,26 +118,25 @@ func (h *dnsHandler) Handle(ctx context.Context, conn net.Conn) {
|
||||
|
||||
n, err := conn.Read(*b)
|
||||
if err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
h.logger.Info("read data: ", n)
|
||||
|
||||
reply, err := h.exchange(ctx, (*b)[:n])
|
||||
reply, err := h.exchange(ctx, (*b)[:n], log)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer bufpool.Put(&reply)
|
||||
|
||||
if _, err = conn.Write(reply); err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *dnsHandler) exchange(ctx context.Context, msg []byte) ([]byte, error) {
|
||||
func (h *dnsHandler) exchange(ctx context.Context, msg []byte, log logger.Logger) ([]byte, error) {
|
||||
mq := dns.Msg{}
|
||||
if err := mq.Unpack(msg); err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -149,23 +146,23 @@ func (h *dnsHandler) exchange(ctx context.Context, msg []byte) ([]byte, error) {
|
||||
|
||||
resolver_util.AddSubnetOpt(&mq, h.md.clientIP)
|
||||
|
||||
if h.logger.IsLevelEnabled(logger.DebugLevel) {
|
||||
h.logger.Debug(mq.String())
|
||||
if log.IsLevelEnabled(logger.DebugLevel) {
|
||||
log.Debug(mq.String())
|
||||
} else {
|
||||
h.logger.Info(h.dumpMsgHeader(&mq))
|
||||
log.Info(h.dumpMsgHeader(&mq))
|
||||
}
|
||||
|
||||
var mr *dns.Msg
|
||||
|
||||
if h.logger.IsLevelEnabled(logger.DebugLevel) {
|
||||
if log.IsLevelEnabled(logger.DebugLevel) {
|
||||
defer func() {
|
||||
if mr != nil {
|
||||
h.logger.Debug(mr.String())
|
||||
log.Debug(mr.String())
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
mr = h.lookupHosts(&mq)
|
||||
mr = h.lookupHosts(&mq, log)
|
||||
if mr != nil {
|
||||
b := bufpool.Get(4096)
|
||||
return mr.PackBuffer(*b)
|
||||
@ -176,7 +173,7 @@ func (h *dnsHandler) exchange(ctx context.Context, msg []byte) ([]byte, error) {
|
||||
key := resolver_util.NewCacheKey(&mq.Question[0])
|
||||
mr = h.cache.Load(key)
|
||||
if mr != nil {
|
||||
h.logger.Debugf("exchange message %d (cached): %s", mq.Id, mq.Question[0].String())
|
||||
log.Debugf("exchange message %d (cached): %s", mq.Id, mq.Question[0].String())
|
||||
mr.Id = mq.Id
|
||||
|
||||
b := bufpool.Get(4096)
|
||||
@ -195,18 +192,18 @@ func (h *dnsHandler) exchange(ctx context.Context, msg []byte) ([]byte, error) {
|
||||
|
||||
query, err := mq.PackBuffer(*b)
|
||||
if err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var reply []byte
|
||||
for _, ex := range h.exchangers {
|
||||
h.logger.Infof("exchange message %d via %s: %s", mq.Id, ex.String(), mq.Question[0].String())
|
||||
log.Infof("exchange message %d via %s: %s", mq.Id, ex.String(), mq.Question[0].String())
|
||||
reply, err = ex.Exchange(ctx, query)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -214,21 +211,21 @@ func (h *dnsHandler) exchange(ctx context.Context, msg []byte) ([]byte, error) {
|
||||
|
||||
mr = &dns.Msg{}
|
||||
if err = mr.Unpack(reply); err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if h.logger.IsLevelEnabled(logger.DebugLevel) {
|
||||
h.logger.Debug(mr.String())
|
||||
if log.IsLevelEnabled(logger.DebugLevel) {
|
||||
log.Debug(mr.String())
|
||||
} else {
|
||||
h.logger.Info(h.dumpMsgHeader(mr))
|
||||
log.Info(h.dumpMsgHeader(mr))
|
||||
}
|
||||
|
||||
return reply, nil
|
||||
}
|
||||
|
||||
// lookup host mapper
|
||||
func (h *dnsHandler) lookupHosts(r *dns.Msg) (m *dns.Msg) {
|
||||
func (h *dnsHandler) lookupHosts(r *dns.Msg, log logger.Logger) (m *dns.Msg) {
|
||||
if h.options.Hosts == nil ||
|
||||
r.Question[0].Qclass != dns.ClassINET ||
|
||||
(r.Question[0].Qtype != dns.TypeA && r.Question[0].Qtype != dns.TypeAAAA) {
|
||||
@ -246,12 +243,12 @@ func (h *dnsHandler) lookupHosts(r *dns.Msg) (m *dns.Msg) {
|
||||
if len(ips) == 0 {
|
||||
return nil
|
||||
}
|
||||
h.logger.Debugf("hit host mapper: %s -> %s", host, ips)
|
||||
log.Debugf("hit host mapper: %s -> %s", host, ips)
|
||||
|
||||
for _, ip := range ips {
|
||||
rr, err := dns.NewRR(fmt.Sprintf("%s IN A %s\n", r.Question[0].Name, ip.String()))
|
||||
if err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
return nil
|
||||
}
|
||||
m.Answer = append(m.Answer, rr)
|
||||
@ -262,12 +259,12 @@ func (h *dnsHandler) lookupHosts(r *dns.Msg) (m *dns.Msg) {
|
||||
if len(ips) == 0 {
|
||||
return nil
|
||||
}
|
||||
h.logger.Debugf("hit host mapper: %s -> %s", host, ips)
|
||||
log.Debugf("hit host mapper: %s -> %s", host, ips)
|
||||
|
||||
for _, ip := range ips {
|
||||
rr, err := dns.NewRR(fmt.Sprintf("%s IN AAAA %s\n", r.Question[0].Name, ip.String()))
|
||||
if err != nil {
|
||||
h.logger.Error(err)
|
||||
log.Error(err)
|
||||
return nil
|
||||
}
|
||||
m.Answer = append(m.Answer, rr)
|
||||
|
Reference in New Issue
Block a user