add nodes field for forwarder

This commit is contained in:
ginuerzh
2022-08-25 21:35:37 +08:00
parent 498a425656
commit d043ad94e7
6 changed files with 138 additions and 98 deletions

View File

@ -29,7 +29,8 @@ func init() {
}
type dnsHandler struct {
exchangers []exchanger.Exchanger
group *chain.NodeGroup
exchangers map[string]exchanger.Exchanger
cache *resolver_util.Cache
router *chain.Router
hosts hosts.HostMapper
@ -44,7 +45,8 @@ func NewHandler(opts ...handler.Option) handler.Handler {
}
return &dnsHandler{
options: options,
options: options,
exchangers: make(map[string]exchanger.Exchanger),
}
}
@ -62,23 +64,38 @@ func (h *dnsHandler) Init(md md.Metadata) (err error) {
}
h.hosts = h.router.Hosts()
for _, server := range h.md.dns {
server = strings.TrimSpace(server)
if server == "" {
if h.group == nil {
h.group = &chain.NodeGroup{}
for i, addr := range h.md.dns {
addr = strings.TrimSpace(addr)
if addr == "" {
continue
}
h.group.AddNode(&chain.Node{
Name: fmt.Sprintf("target-%d", i),
Addr: addr,
Marker: &chain.FailMarker{},
})
}
}
for _, node := range h.group.Nodes() {
addr := strings.TrimSpace(node.Addr)
if addr == "" {
continue
}
ex, err := exchanger.NewExchanger(
server,
addr,
exchanger.RouterOption(h.router),
exchanger.TimeoutOption(h.md.timeout),
exchanger.LoggerOption(log),
)
if err != nil {
log.Warnf("parse %s: %v", server, err)
log.Warnf("parse %s: %v", addr, err)
continue
}
h.exchangers = append(h.exchangers, ex)
h.exchangers[node.Name] = ex
}
if len(h.exchangers) == 0 {
ex, err := exchanger.NewExchanger(
defaultNameserver,
@ -90,12 +107,17 @@ func (h *dnsHandler) Init(md md.Metadata) (err error) {
if err != nil {
return err
}
h.exchangers = append(h.exchangers, ex)
h.exchangers["default"] = ex
}
return
}
// Forward implements handler.Forwarder.
func (h *dnsHandler) Forward(group *chain.NodeGroup) {
h.group = group
}
func (h *dnsHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.HandleOption) error {
defer conn.Close()
@ -152,7 +174,6 @@ func (h *dnsHandler) exchange(ctx context.Context, msg []byte, log logger.Logger
}
var mr *dns.Msg
if log.IsLevelEnabled(logger.TraceLevel) {
defer func() {
if mr != nil {
@ -161,6 +182,15 @@ func (h *dnsHandler) exchange(ctx context.Context, msg []byte, log logger.Logger
}()
}
if h.options.Bypass != nil && mq.Question[0].Qclass == dns.ClassINET {
if h.options.Bypass.Contains(strings.Trim(mq.Question[0].Name, ".")) {
log.Debug("bypass: ", mq.Question[0].Name)
mr = (&dns.Msg{}).SetReply(&mq)
b := bufpool.Get(h.md.bufferSize)
return mr.PackBuffer(*b)
}
}
mr = h.lookupHosts(&mq, log)
if mr != nil {
b := bufpool.Get(h.md.bufferSize)
@ -195,16 +225,16 @@ func (h *dnsHandler) exchange(ctx context.Context, msg []byte, log logger.Logger
return nil, err
}
var reply []byte
for _, ex := range h.exchangers {
log.Debugf("exchange message %d via %s: %s", mq.Id, ex.String(), mq.Question[0].String())
reply, err = ex.Exchange(ctx, query)
if err == nil {
break
}
ex := h.selectExchanger(strings.Trim(mq.Question[0].Name, "."))
if ex == nil {
err := fmt.Errorf("exchange not found for %s", mq.Question[0].Name)
log.Error(err)
return nil, err
}
reply, err := ex.Exchange(ctx, query)
if err != nil {
log.Error(err)
return nil, err
}
@ -266,3 +296,15 @@ func (h *dnsHandler) lookupHosts(r *dns.Msg, log logger.Logger) (m *dns.Msg) {
return
}
func (h *dnsHandler) selectExchanger(addr string) exchanger.Exchanger {
if h.group == nil {
return nil
}
node := h.group.FilterAddr(addr).Next()
if node == nil {
return nil
}
return h.exchangers[node.Name]
}