add nodes field for forwarder
This commit is contained in:
@ -29,7 +29,8 @@ func init() {
|
||||
}
|
||||
|
||||
type dnsHandler struct {
|
||||
exchangers []exchanger.Exchanger
|
||||
group *chain.NodeGroup
|
||||
exchangers map[string]exchanger.Exchanger
|
||||
cache *resolver_util.Cache
|
||||
router *chain.Router
|
||||
hosts hosts.HostMapper
|
||||
@ -44,7 +45,8 @@ func NewHandler(opts ...handler.Option) handler.Handler {
|
||||
}
|
||||
|
||||
return &dnsHandler{
|
||||
options: options,
|
||||
options: options,
|
||||
exchangers: make(map[string]exchanger.Exchanger),
|
||||
}
|
||||
}
|
||||
|
||||
@ -62,23 +64,38 @@ func (h *dnsHandler) Init(md md.Metadata) (err error) {
|
||||
}
|
||||
h.hosts = h.router.Hosts()
|
||||
|
||||
for _, server := range h.md.dns {
|
||||
server = strings.TrimSpace(server)
|
||||
if server == "" {
|
||||
if h.group == nil {
|
||||
h.group = &chain.NodeGroup{}
|
||||
for i, addr := range h.md.dns {
|
||||
addr = strings.TrimSpace(addr)
|
||||
if addr == "" {
|
||||
continue
|
||||
}
|
||||
h.group.AddNode(&chain.Node{
|
||||
Name: fmt.Sprintf("target-%d", i),
|
||||
Addr: addr,
|
||||
Marker: &chain.FailMarker{},
|
||||
})
|
||||
}
|
||||
}
|
||||
for _, node := range h.group.Nodes() {
|
||||
addr := strings.TrimSpace(node.Addr)
|
||||
if addr == "" {
|
||||
continue
|
||||
}
|
||||
ex, err := exchanger.NewExchanger(
|
||||
server,
|
||||
addr,
|
||||
exchanger.RouterOption(h.router),
|
||||
exchanger.TimeoutOption(h.md.timeout),
|
||||
exchanger.LoggerOption(log),
|
||||
)
|
||||
if err != nil {
|
||||
log.Warnf("parse %s: %v", server, err)
|
||||
log.Warnf("parse %s: %v", addr, err)
|
||||
continue
|
||||
}
|
||||
h.exchangers = append(h.exchangers, ex)
|
||||
h.exchangers[node.Name] = ex
|
||||
}
|
||||
|
||||
if len(h.exchangers) == 0 {
|
||||
ex, err := exchanger.NewExchanger(
|
||||
defaultNameserver,
|
||||
@ -90,12 +107,17 @@ func (h *dnsHandler) Init(md md.Metadata) (err error) {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
h.exchangers = append(h.exchangers, ex)
|
||||
h.exchangers["default"] = ex
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Forward implements handler.Forwarder.
|
||||
func (h *dnsHandler) Forward(group *chain.NodeGroup) {
|
||||
h.group = group
|
||||
}
|
||||
|
||||
func (h *dnsHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.HandleOption) error {
|
||||
defer conn.Close()
|
||||
|
||||
@ -152,7 +174,6 @@ func (h *dnsHandler) exchange(ctx context.Context, msg []byte, log logger.Logger
|
||||
}
|
||||
|
||||
var mr *dns.Msg
|
||||
|
||||
if log.IsLevelEnabled(logger.TraceLevel) {
|
||||
defer func() {
|
||||
if mr != nil {
|
||||
@ -161,6 +182,15 @@ func (h *dnsHandler) exchange(ctx context.Context, msg []byte, log logger.Logger
|
||||
}()
|
||||
}
|
||||
|
||||
if h.options.Bypass != nil && mq.Question[0].Qclass == dns.ClassINET {
|
||||
if h.options.Bypass.Contains(strings.Trim(mq.Question[0].Name, ".")) {
|
||||
log.Debug("bypass: ", mq.Question[0].Name)
|
||||
mr = (&dns.Msg{}).SetReply(&mq)
|
||||
b := bufpool.Get(h.md.bufferSize)
|
||||
return mr.PackBuffer(*b)
|
||||
}
|
||||
}
|
||||
|
||||
mr = h.lookupHosts(&mq, log)
|
||||
if mr != nil {
|
||||
b := bufpool.Get(h.md.bufferSize)
|
||||
@ -195,16 +225,16 @@ func (h *dnsHandler) exchange(ctx context.Context, msg []byte, log logger.Logger
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var reply []byte
|
||||
for _, ex := range h.exchangers {
|
||||
log.Debugf("exchange message %d via %s: %s", mq.Id, ex.String(), mq.Question[0].String())
|
||||
reply, err = ex.Exchange(ctx, query)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
ex := h.selectExchanger(strings.Trim(mq.Question[0].Name, "."))
|
||||
if ex == nil {
|
||||
err := fmt.Errorf("exchange not found for %s", mq.Question[0].Name)
|
||||
log.Error(err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
reply, err := ex.Exchange(ctx, query)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -266,3 +296,15 @@ func (h *dnsHandler) lookupHosts(r *dns.Msg, log logger.Logger) (m *dns.Msg) {
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (h *dnsHandler) selectExchanger(addr string) exchanger.Exchanger {
|
||||
if h.group == nil {
|
||||
return nil
|
||||
}
|
||||
node := h.group.FilterAddr(addr).Next()
|
||||
if node == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return h.exchangers[node.Name]
|
||||
}
|
||||
|
Reference in New Issue
Block a user