add sshd listener

This commit is contained in:
ginuerzh
2022-01-26 15:53:33 +08:00
parent a134026e76
commit 04dfc8c4c3
39 changed files with 1101 additions and 848 deletions

View File

@ -33,7 +33,6 @@ type dnsHandler struct {
exchangers []exchanger.Exchanger
cache *resolver_util.Cache
router *chain.Router
logger logger.Logger
md metadata
options handler.Options
}
@ -50,19 +49,18 @@ func NewHandler(opts ...handler.Option) handler.Handler {
}
func (h *dnsHandler) Init(md md.Metadata) (err error) {
h.logger = h.options.Logger
if err = h.parseMetadata(md); err != nil {
return
}
log := h.options.Logger
h.cache = resolver_util.NewCache().WithLogger(h.options.Logger)
h.cache = resolver_util.NewCache().WithLogger(log)
h.router = &chain.Router{
Retries: h.options.Retries,
Chain: h.options.Chain,
Resolver: h.options.Resolver,
// Hosts: h.options.Hosts,
Logger: h.options.Logger,
Logger: log,
}
for _, server := range h.md.dns {
@ -74,10 +72,10 @@ func (h *dnsHandler) Init(md md.Metadata) (err error) {
server,
exchanger.RouterOption(h.router),
exchanger.TimeoutOption(h.md.timeout),
exchanger.LoggerOption(h.logger),
exchanger.LoggerOption(log),
)
if err != nil {
h.logger.Warnf("parse %s: %v", server, err)
log.Warnf("parse %s: %v", server, err)
continue
}
h.exchangers = append(h.exchangers, ex)
@ -87,9 +85,9 @@ func (h *dnsHandler) Init(md md.Metadata) (err error) {
defaultNameserver,
exchanger.RouterOption(h.router),
exchanger.TimeoutOption(h.md.timeout),
exchanger.LoggerOption(h.logger),
exchanger.LoggerOption(log),
)
h.logger.Warnf("resolver not found, default to %s", defaultNameserver)
log.Warnf("resolver not found, default to %s", defaultNameserver)
if err != nil {
return err
}
@ -103,14 +101,14 @@ func (h *dnsHandler) Handle(ctx context.Context, conn net.Conn) {
defer conn.Close()
start := time.Now()
h.logger = h.logger.WithFields(map[string]interface{}{
log := h.options.Logger.WithFields(map[string]interface{}{
"remote": conn.RemoteAddr().String(),
"local": conn.LocalAddr().String(),
})
h.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
defer func() {
h.logger.WithFields(map[string]interface{}{
log.WithFields(map[string]interface{}{
"duration": time.Since(start),
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
}()
@ -120,26 +118,25 @@ func (h *dnsHandler) Handle(ctx context.Context, conn net.Conn) {
n, err := conn.Read(*b)
if err != nil {
h.logger.Error(err)
log.Error(err)
return
}
h.logger.Info("read data: ", n)
reply, err := h.exchange(ctx, (*b)[:n])
reply, err := h.exchange(ctx, (*b)[:n], log)
if err != nil {
return
}
defer bufpool.Put(&reply)
if _, err = conn.Write(reply); err != nil {
h.logger.Error(err)
log.Error(err)
}
}
func (h *dnsHandler) exchange(ctx context.Context, msg []byte) ([]byte, error) {
func (h *dnsHandler) exchange(ctx context.Context, msg []byte, log logger.Logger) ([]byte, error) {
mq := dns.Msg{}
if err := mq.Unpack(msg); err != nil {
h.logger.Error(err)
log.Error(err)
return nil, err
}
@ -149,23 +146,23 @@ func (h *dnsHandler) exchange(ctx context.Context, msg []byte) ([]byte, error) {
resolver_util.AddSubnetOpt(&mq, h.md.clientIP)
if h.logger.IsLevelEnabled(logger.DebugLevel) {
h.logger.Debug(mq.String())
if log.IsLevelEnabled(logger.DebugLevel) {
log.Debug(mq.String())
} else {
h.logger.Info(h.dumpMsgHeader(&mq))
log.Info(h.dumpMsgHeader(&mq))
}
var mr *dns.Msg
if h.logger.IsLevelEnabled(logger.DebugLevel) {
if log.IsLevelEnabled(logger.DebugLevel) {
defer func() {
if mr != nil {
h.logger.Debug(mr.String())
log.Debug(mr.String())
}
}()
}
mr = h.lookupHosts(&mq)
mr = h.lookupHosts(&mq, log)
if mr != nil {
b := bufpool.Get(4096)
return mr.PackBuffer(*b)
@ -176,7 +173,7 @@ func (h *dnsHandler) exchange(ctx context.Context, msg []byte) ([]byte, error) {
key := resolver_util.NewCacheKey(&mq.Question[0])
mr = h.cache.Load(key)
if mr != nil {
h.logger.Debugf("exchange message %d (cached): %s", mq.Id, mq.Question[0].String())
log.Debugf("exchange message %d (cached): %s", mq.Id, mq.Question[0].String())
mr.Id = mq.Id
b := bufpool.Get(4096)
@ -195,18 +192,18 @@ func (h *dnsHandler) exchange(ctx context.Context, msg []byte) ([]byte, error) {
query, err := mq.PackBuffer(*b)
if err != nil {
h.logger.Error(err)
log.Error(err)
return nil, err
}
var reply []byte
for _, ex := range h.exchangers {
h.logger.Infof("exchange message %d via %s: %s", mq.Id, ex.String(), mq.Question[0].String())
log.Infof("exchange message %d via %s: %s", mq.Id, ex.String(), mq.Question[0].String())
reply, err = ex.Exchange(ctx, query)
if err == nil {
break
}
h.logger.Error(err)
log.Error(err)
}
if err != nil {
return nil, err
@ -214,21 +211,21 @@ func (h *dnsHandler) exchange(ctx context.Context, msg []byte) ([]byte, error) {
mr = &dns.Msg{}
if err = mr.Unpack(reply); err != nil {
h.logger.Error(err)
log.Error(err)
return nil, err
}
if h.logger.IsLevelEnabled(logger.DebugLevel) {
h.logger.Debug(mr.String())
if log.IsLevelEnabled(logger.DebugLevel) {
log.Debug(mr.String())
} else {
h.logger.Info(h.dumpMsgHeader(mr))
log.Info(h.dumpMsgHeader(mr))
}
return reply, nil
}
// lookup host mapper
func (h *dnsHandler) lookupHosts(r *dns.Msg) (m *dns.Msg) {
func (h *dnsHandler) lookupHosts(r *dns.Msg, log logger.Logger) (m *dns.Msg) {
if h.options.Hosts == nil ||
r.Question[0].Qclass != dns.ClassINET ||
(r.Question[0].Qtype != dns.TypeA && r.Question[0].Qtype != dns.TypeAAAA) {
@ -246,12 +243,12 @@ func (h *dnsHandler) lookupHosts(r *dns.Msg) (m *dns.Msg) {
if len(ips) == 0 {
return nil
}
h.logger.Debugf("hit host mapper: %s -> %s", host, ips)
log.Debugf("hit host mapper: %s -> %s", host, ips)
for _, ip := range ips {
rr, err := dns.NewRR(fmt.Sprintf("%s IN A %s\n", r.Question[0].Name, ip.String()))
if err != nil {
h.logger.Error(err)
log.Error(err)
return nil
}
m.Answer = append(m.Answer, rr)
@ -262,12 +259,12 @@ func (h *dnsHandler) lookupHosts(r *dns.Msg) (m *dns.Msg) {
if len(ips) == 0 {
return nil
}
h.logger.Debugf("hit host mapper: %s -> %s", host, ips)
log.Debugf("hit host mapper: %s -> %s", host, ips)
for _, ip := range ips {
rr, err := dns.NewRR(fmt.Sprintf("%s IN AAAA %s\n", r.Question[0].Name, ip.String()))
if err != nil {
h.logger.Error(err)
log.Error(err)
return nil
}
m.Answer = append(m.Answer, rr)