add resolver

This commit is contained in:
ginuerzh
2021-12-30 19:12:42 +08:00
parent 9b3d7e1110
commit a6671a468e
8 changed files with 343 additions and 50 deletions

View File

@ -6,12 +6,14 @@ import (
"errors"
"net"
"strconv"
"strings"
"time"
"github.com/go-gost/gost/pkg/bypass"
"github.com/go-gost/gost/pkg/chain"
"github.com/go-gost/gost/pkg/common/bufpool"
"github.com/go-gost/gost/pkg/handler"
resolver_util "github.com/go-gost/gost/pkg/internal/util/resolver"
"github.com/go-gost/gost/pkg/logger"
md "github.com/go-gost/gost/pkg/metadata"
"github.com/go-gost/gost/pkg/registry"
@ -27,6 +29,7 @@ type dnsHandler struct {
chain *chain.Chain
bypass bypass.Bypass
exchangers []exchanger.Exchanger
cache *resolver_util.Cache
logger logger.Logger
md metadata
}
@ -37,8 +40,11 @@ func NewHandler(opts ...handler.Option) handler.Handler {
opt(options)
}
cache := resolver_util.NewCache().WithLogger(options.Logger)
return &dnsHandler{
bypass: options.Bypass,
cache: cache,
logger: options.Logger,
}
}
@ -49,9 +55,14 @@ func (h *dnsHandler) Init(md md.Metadata) (err error) {
}
for _, server := range h.md.servers {
server = strings.TrimSpace(server)
if server == "" {
continue
}
ex, err := exchanger.NewExchanger(
server,
exchanger.ChainOption(h.chain),
exchanger.TimeoutOption(h.md.timeout),
exchanger.LoggerOption(h.logger),
)
if err != nil {
@ -61,14 +72,18 @@ func (h *dnsHandler) Init(md md.Metadata) (err error) {
h.exchangers = append(h.exchangers, ex)
}
if len(h.exchangers) == 0 {
ex, _ := exchanger.NewExchanger(
"udp://127.0.0.53:53",
addr := "udp://127.0.0.1:53"
ex, err := exchanger.NewExchanger(
addr,
exchanger.ChainOption(h.chain),
exchanger.TimeoutOption(h.md.timeout),
exchanger.LoggerOption(h.logger),
)
if ex != nil {
h.exchangers = append(h.exchangers, ex)
h.logger.Warnf("resolver not found, default to %s", addr)
if err != nil {
return err
}
h.exchangers = append(h.exchangers, ex)
}
return
}
@ -106,7 +121,6 @@ func (h *dnsHandler) Handle(ctx context.Context, conn net.Conn) {
reply, err := h.exchange(ctx, b[:n])
if err != nil {
h.logger.Error(err)
return
}
@ -118,6 +132,7 @@ func (h *dnsHandler) Handle(ctx context.Context, conn net.Conn) {
func (h *dnsHandler) exchange(ctx context.Context, msg []byte) ([]byte, error) {
mq := dns.Msg{}
if err := mq.Unpack(msg); err != nil {
h.logger.Error(err)
return nil, err
}
@ -125,6 +140,8 @@ func (h *dnsHandler) exchange(ctx context.Context, msg []byte) ([]byte, error) {
return nil, errors.New("msg: empty question")
}
resolver_util.AddSubnetOpt(&mq, h.md.clientIP)
if h.logger.IsLevelEnabled(logger.DebugLevel) {
h.logger.Debug(mq.String())
} else {
@ -132,26 +149,22 @@ func (h *dnsHandler) exchange(ctx context.Context, msg []byte) ([]byte, error) {
}
var mr *dns.Msg
// Only cache for single question.
/*
if len(mq.Question) == 1 {
key := newResolverCacheKey(&mq.Question[0])
mr = r.cache.loadCache(key)
if mr != nil {
log.Logf("[dns] exchange message %d (cached): %s", mq.Id, mq.Question[0].String())
mr.Id = mq.Id
return mr.Pack()
}
defer func() {
if mr != nil {
r.cache.storeCache(key, mr, r.TTL())
}
}()
// cache only for single question message.
if len(mq.Question) == 1 {
key := resolver_util.NewCacheKey(&mq.Question[0])
mr = h.cache.Load(key)
if mr != nil {
h.logger.Debugf("exchange message %d (cached): %s", mq.Id, mq.Question[0].String())
mr.Id = mq.Id
return mr.Pack()
}
*/
// r.addSubnetOpt(mq)
defer func() {
if mr != nil {
h.cache.Store(key, mr, h.md.ttl)
}
}()
}
query, err := mq.Pack()
if err != nil {
@ -169,7 +182,6 @@ func (h *dnsHandler) exchange(ctx context.Context, msg []byte) ([]byte, error) {
h.logger.Error(err)
}
if err != nil {
h.logger.Error(err)
return nil, err
}