add resolver
This commit is contained in:
parent
9b3d7e1110
commit
a6671a468e
9
gost.yml
9
gost.yml
@ -14,16 +14,17 @@ profiling:
|
|||||||
|
|
||||||
resolvers:
|
resolvers:
|
||||||
- name: resolver-0
|
- name: resolver-0
|
||||||
|
nameservers:
|
||||||
|
- addr: udp://8.8.8.8:53
|
||||||
|
chain: chain-0
|
||||||
ttl: 60s
|
ttl: 60s
|
||||||
prefer: ipv4
|
prefer: ipv4
|
||||||
clientIP: 1.2.3.4
|
clientIP: 1.2.3.4
|
||||||
nameServers:
|
timeout: 3s
|
||||||
- addr: udp://8.8.8.8:53
|
|
||||||
timeout: 5s
|
|
||||||
- addr: tcp://1.1.1.1:53
|
- addr: tcp://1.1.1.1:53
|
||||||
- addr: tls://1.1.1.1:853
|
- addr: tls://1.1.1.1:853
|
||||||
- addr: https://1.0.0.1/dns-query
|
- addr: https://1.0.0.1/dns-query
|
||||||
domain: cloudflare-dns.com
|
hostname: cloudflare-dns.com
|
||||||
|
|
||||||
services:
|
services:
|
||||||
- name: http+tcp
|
- name: http+tcp
|
||||||
|
@ -6,12 +6,14 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/go-gost/gost/pkg/bypass"
|
"github.com/go-gost/gost/pkg/bypass"
|
||||||
"github.com/go-gost/gost/pkg/chain"
|
"github.com/go-gost/gost/pkg/chain"
|
||||||
"github.com/go-gost/gost/pkg/common/bufpool"
|
"github.com/go-gost/gost/pkg/common/bufpool"
|
||||||
"github.com/go-gost/gost/pkg/handler"
|
"github.com/go-gost/gost/pkg/handler"
|
||||||
|
resolver_util "github.com/go-gost/gost/pkg/internal/util/resolver"
|
||||||
"github.com/go-gost/gost/pkg/logger"
|
"github.com/go-gost/gost/pkg/logger"
|
||||||
md "github.com/go-gost/gost/pkg/metadata"
|
md "github.com/go-gost/gost/pkg/metadata"
|
||||||
"github.com/go-gost/gost/pkg/registry"
|
"github.com/go-gost/gost/pkg/registry"
|
||||||
@ -27,6 +29,7 @@ type dnsHandler struct {
|
|||||||
chain *chain.Chain
|
chain *chain.Chain
|
||||||
bypass bypass.Bypass
|
bypass bypass.Bypass
|
||||||
exchangers []exchanger.Exchanger
|
exchangers []exchanger.Exchanger
|
||||||
|
cache *resolver_util.Cache
|
||||||
logger logger.Logger
|
logger logger.Logger
|
||||||
md metadata
|
md metadata
|
||||||
}
|
}
|
||||||
@ -37,8 +40,11 @@ func NewHandler(opts ...handler.Option) handler.Handler {
|
|||||||
opt(options)
|
opt(options)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
cache := resolver_util.NewCache().WithLogger(options.Logger)
|
||||||
|
|
||||||
return &dnsHandler{
|
return &dnsHandler{
|
||||||
bypass: options.Bypass,
|
bypass: options.Bypass,
|
||||||
|
cache: cache,
|
||||||
logger: options.Logger,
|
logger: options.Logger,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -49,9 +55,14 @@ func (h *dnsHandler) Init(md md.Metadata) (err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, server := range h.md.servers {
|
for _, server := range h.md.servers {
|
||||||
|
server = strings.TrimSpace(server)
|
||||||
|
if server == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
ex, err := exchanger.NewExchanger(
|
ex, err := exchanger.NewExchanger(
|
||||||
server,
|
server,
|
||||||
exchanger.ChainOption(h.chain),
|
exchanger.ChainOption(h.chain),
|
||||||
|
exchanger.TimeoutOption(h.md.timeout),
|
||||||
exchanger.LoggerOption(h.logger),
|
exchanger.LoggerOption(h.logger),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -61,14 +72,18 @@ func (h *dnsHandler) Init(md md.Metadata) (err error) {
|
|||||||
h.exchangers = append(h.exchangers, ex)
|
h.exchangers = append(h.exchangers, ex)
|
||||||
}
|
}
|
||||||
if len(h.exchangers) == 0 {
|
if len(h.exchangers) == 0 {
|
||||||
ex, _ := exchanger.NewExchanger(
|
addr := "udp://127.0.0.1:53"
|
||||||
"udp://127.0.0.53:53",
|
ex, err := exchanger.NewExchanger(
|
||||||
|
addr,
|
||||||
exchanger.ChainOption(h.chain),
|
exchanger.ChainOption(h.chain),
|
||||||
|
exchanger.TimeoutOption(h.md.timeout),
|
||||||
exchanger.LoggerOption(h.logger),
|
exchanger.LoggerOption(h.logger),
|
||||||
)
|
)
|
||||||
if ex != nil {
|
h.logger.Warnf("resolver not found, default to %s", addr)
|
||||||
h.exchangers = append(h.exchangers, ex)
|
if err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
h.exchangers = append(h.exchangers, ex)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -106,7 +121,6 @@ func (h *dnsHandler) Handle(ctx context.Context, conn net.Conn) {
|
|||||||
|
|
||||||
reply, err := h.exchange(ctx, b[:n])
|
reply, err := h.exchange(ctx, b[:n])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.logger.Error(err)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -118,6 +132,7 @@ func (h *dnsHandler) Handle(ctx context.Context, conn net.Conn) {
|
|||||||
func (h *dnsHandler) exchange(ctx context.Context, msg []byte) ([]byte, error) {
|
func (h *dnsHandler) exchange(ctx context.Context, msg []byte) ([]byte, error) {
|
||||||
mq := dns.Msg{}
|
mq := dns.Msg{}
|
||||||
if err := mq.Unpack(msg); err != nil {
|
if err := mq.Unpack(msg); err != nil {
|
||||||
|
h.logger.Error(err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -125,6 +140,8 @@ func (h *dnsHandler) exchange(ctx context.Context, msg []byte) ([]byte, error) {
|
|||||||
return nil, errors.New("msg: empty question")
|
return nil, errors.New("msg: empty question")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
resolver_util.AddSubnetOpt(&mq, h.md.clientIP)
|
||||||
|
|
||||||
if h.logger.IsLevelEnabled(logger.DebugLevel) {
|
if h.logger.IsLevelEnabled(logger.DebugLevel) {
|
||||||
h.logger.Debug(mq.String())
|
h.logger.Debug(mq.String())
|
||||||
} else {
|
} else {
|
||||||
@ -132,26 +149,22 @@ func (h *dnsHandler) exchange(ctx context.Context, msg []byte) ([]byte, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var mr *dns.Msg
|
var mr *dns.Msg
|
||||||
// Only cache for single question.
|
// cache only for single question message.
|
||||||
/*
|
|
||||||
if len(mq.Question) == 1 {
|
if len(mq.Question) == 1 {
|
||||||
key := newResolverCacheKey(&mq.Question[0])
|
key := resolver_util.NewCacheKey(&mq.Question[0])
|
||||||
mr = r.cache.loadCache(key)
|
mr = h.cache.Load(key)
|
||||||
if mr != nil {
|
if mr != nil {
|
||||||
log.Logf("[dns] exchange message %d (cached): %s", mq.Id, mq.Question[0].String())
|
h.logger.Debugf("exchange message %d (cached): %s", mq.Id, mq.Question[0].String())
|
||||||
mr.Id = mq.Id
|
mr.Id = mq.Id
|
||||||
return mr.Pack()
|
return mr.Pack()
|
||||||
}
|
}
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if mr != nil {
|
if mr != nil {
|
||||||
r.cache.storeCache(key, mr, r.TTL())
|
h.cache.Store(key, mr, h.md.ttl)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
*/
|
|
||||||
|
|
||||||
// r.addSubnetOpt(mq)
|
|
||||||
|
|
||||||
query, err := mq.Pack()
|
query, err := mq.Pack()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -169,7 +182,6 @@ func (h *dnsHandler) exchange(ctx context.Context, msg []byte) ([]byte, error) {
|
|||||||
h.logger.Error(err)
|
h.logger.Error(err)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.logger.Error(err)
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
package dns
|
package dns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
mdata "github.com/go-gost/gost/pkg/metadata"
|
mdata "github.com/go-gost/gost/pkg/metadata"
|
||||||
@ -11,10 +13,10 @@ type metadata struct {
|
|||||||
retryCount int
|
retryCount int
|
||||||
ttl time.Duration
|
ttl time.Duration
|
||||||
timeout time.Duration
|
timeout time.Duration
|
||||||
prefer string
|
clientIP net.IP
|
||||||
clientIP string
|
|
||||||
// nameservers
|
// nameservers
|
||||||
servers []string
|
servers []string
|
||||||
|
dns []string // compatible with v2
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *dnsHandler) parseMetadata(md mdata.Metadata) (err error) {
|
func (h *dnsHandler) parseMetadata(md mdata.Metadata) (err error) {
|
||||||
@ -23,9 +25,9 @@ func (h *dnsHandler) parseMetadata(md mdata.Metadata) (err error) {
|
|||||||
retryCount = "retry"
|
retryCount = "retry"
|
||||||
ttl = "ttl"
|
ttl = "ttl"
|
||||||
timeout = "timeout"
|
timeout = "timeout"
|
||||||
prefer = "prefer"
|
|
||||||
clientIP = "clientIP"
|
clientIP = "clientIP"
|
||||||
servers = "servers"
|
servers = "servers"
|
||||||
|
dns = "dns"
|
||||||
)
|
)
|
||||||
|
|
||||||
h.md.readTimeout = mdata.GetDuration(md, readTimeout)
|
h.md.readTimeout = mdata.GetDuration(md, readTimeout)
|
||||||
@ -35,9 +37,15 @@ func (h *dnsHandler) parseMetadata(md mdata.Metadata) (err error) {
|
|||||||
if h.md.timeout <= 0 {
|
if h.md.timeout <= 0 {
|
||||||
h.md.timeout = 5 * time.Second
|
h.md.timeout = 5 * time.Second
|
||||||
}
|
}
|
||||||
h.md.prefer = mdata.GetString(md, prefer)
|
sip := mdata.GetString(md, clientIP)
|
||||||
h.md.clientIP = mdata.GetString(md, clientIP)
|
if sip != "" {
|
||||||
|
h.md.clientIP = net.ParseIP(sip)
|
||||||
|
}
|
||||||
h.md.servers = mdata.GetStrings(md, servers)
|
h.md.servers = mdata.GetStrings(md, servers)
|
||||||
|
h.md.dns = strings.Split(mdata.GetString(md, dns), ",")
|
||||||
|
if len(h.md.dns) > 0 {
|
||||||
|
h.md.servers = append(h.md.servers, h.md.dns...)
|
||||||
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
86
pkg/internal/util/resolver/cache.go
Normal file
86
pkg/internal/util/resolver/cache.go
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
package resolver
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-gost/gost/pkg/logger"
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
)
|
||||||
|
|
||||||
|
type CacheKey string
|
||||||
|
|
||||||
|
// NewCacheKey generates resolver cache key from question of dns query.
|
||||||
|
func NewCacheKey(q *dns.Question) CacheKey {
|
||||||
|
if q == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
key := fmt.Sprintf("%s%s.%s", q.Name, dns.Class(q.Qclass).String(), dns.Type(q.Qtype).String())
|
||||||
|
return CacheKey(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
type cacheItem struct {
|
||||||
|
msg *dns.Msg
|
||||||
|
ts time.Time
|
||||||
|
ttl time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
type Cache struct {
|
||||||
|
m sync.Map
|
||||||
|
logger logger.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewCache() *Cache {
|
||||||
|
return &Cache{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Cache) WithLogger(logger logger.Logger) *Cache {
|
||||||
|
c.logger = logger
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Cache) Load(key CacheKey) *dns.Msg {
|
||||||
|
v, ok := c.m.Load(key)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
item, ok := v.(*cacheItem)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
elapsed := time.Since(item.ts)
|
||||||
|
if item.ttl > 0 {
|
||||||
|
if elapsed > item.ttl {
|
||||||
|
c.m.Delete(key)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for _, rr := range item.msg.Answer {
|
||||||
|
if elapsed > time.Duration(rr.Header().Ttl)*time.Second {
|
||||||
|
c.m.Delete(key)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.logger.Debugf("resolver cache hit %s", key)
|
||||||
|
|
||||||
|
return item.msg.Copy()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Cache) Store(key CacheKey, mr *dns.Msg, ttl time.Duration) {
|
||||||
|
if key == "" || mr == nil || ttl < 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.m.Store(key, &cacheItem{
|
||||||
|
msg: mr.Copy(),
|
||||||
|
ts: time.Now(),
|
||||||
|
ttl: ttl,
|
||||||
|
})
|
||||||
|
|
||||||
|
c.logger.Debugf("resolver cache store %s", key)
|
||||||
|
}
|
30
pkg/internal/util/resolver/resolver.go
Normal file
30
pkg/internal/util/resolver/resolver.go
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
package resolver
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
)
|
||||||
|
|
||||||
|
func AddSubnetOpt(m *dns.Msg, ip net.IP) {
|
||||||
|
if m == nil || ip == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
opt := new(dns.OPT)
|
||||||
|
opt.Hdr.Name = "."
|
||||||
|
opt.Hdr.Rrtype = dns.TypeOPT
|
||||||
|
e := new(dns.EDNS0_SUBNET)
|
||||||
|
e.Code = dns.EDNS0SUBNET
|
||||||
|
if ip := ip.To4(); ip != nil {
|
||||||
|
e.Family = 1
|
||||||
|
e.SourceNetmask = 24
|
||||||
|
e.Address = ip
|
||||||
|
} else {
|
||||||
|
e.Family = 2
|
||||||
|
e.SourceNetmask = 128
|
||||||
|
e.Address = ip.To16()
|
||||||
|
}
|
||||||
|
opt.Option = append(opt.Option, e)
|
||||||
|
m.Extra = append(m.Extra, opt)
|
||||||
|
}
|
@ -107,7 +107,7 @@ func NewExchanger(addr string, opts ...Option) (Exchanger, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
ex.network = "tcp"
|
ex.network = "tcp"
|
||||||
case "doh":
|
case "https":
|
||||||
ex.addr = addr
|
ex.addr = addr
|
||||||
if ex.options.tlsConfig == nil {
|
if ex.options.tlsConfig == nil {
|
||||||
ex.options.tlsConfig = &tls.Config{
|
ex.options.tlsConfig = &tls.Config{
|
||||||
@ -134,7 +134,7 @@ func NewExchanger(addr string, opts ...Option) (Exchanger, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (ex *exchanger) Exchange(ctx context.Context, msg []byte) ([]byte, error) {
|
func (ex *exchanger) Exchange(ctx context.Context, msg []byte) ([]byte, error) {
|
||||||
if ex.network == "doh" {
|
if ex.network == "https" {
|
||||||
return ex.dohExchange(ctx, msg)
|
return ex.dohExchange(ctx, msg)
|
||||||
}
|
}
|
||||||
return ex.exchange(ctx, msg)
|
return ex.exchange(ctx, msg)
|
||||||
|
@ -1,13 +0,0 @@
|
|||||||
package resolver
|
|
||||||
|
|
||||||
import (
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
type NameServer struct {
|
|
||||||
Addr string
|
|
||||||
Protocol string
|
|
||||||
Hostname string // for TLS handshake verification
|
|
||||||
Exchanger Exchanger
|
|
||||||
Timeout time.Duration
|
|
||||||
}
|
|
@ -3,9 +3,178 @@ package resolver
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net"
|
"net"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-gost/gost/pkg/chain"
|
||||||
|
resolver_util "github.com/go-gost/gost/pkg/internal/util/resolver"
|
||||||
|
"github.com/go-gost/gost/pkg/logger"
|
||||||
|
"github.com/go-gost/gost/pkg/resolver/exchanger"
|
||||||
|
"github.com/miekg/dns"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Resolver interface {
|
type Resolver interface {
|
||||||
// Resolve returns a slice of the host's IPv4 and IPv6 addresses.
|
// Resolve returns a slice of the host's IPv4 and IPv6 addresses.
|
||||||
Resolve(ctx context.Context, host string) ([]net.IP, error)
|
Resolve(ctx context.Context, host string) ([]net.IP, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type NameServer struct {
|
||||||
|
Addr string
|
||||||
|
Chain *chain.Chain
|
||||||
|
TTL time.Duration
|
||||||
|
Timeout time.Duration
|
||||||
|
ClientIP net.IP
|
||||||
|
Prefer string
|
||||||
|
Hostname string // for TLS handshake verification
|
||||||
|
exchanger exchanger.Exchanger
|
||||||
|
}
|
||||||
|
|
||||||
|
type resolverOptions struct {
|
||||||
|
domain string
|
||||||
|
logger logger.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
type ResolverOption func(opts *resolverOptions)
|
||||||
|
|
||||||
|
func DomainResolverOption(domain string) ResolverOption {
|
||||||
|
return func(opts *resolverOptions) {
|
||||||
|
opts.domain = domain
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func LoggerResolverOption(logger logger.Logger) ResolverOption {
|
||||||
|
return func(opts *resolverOptions) {
|
||||||
|
opts.logger = logger
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type resolver struct {
|
||||||
|
servers []NameServer
|
||||||
|
cache *resolver_util.Cache
|
||||||
|
options resolverOptions
|
||||||
|
logger logger.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewResolver(nameservers []NameServer, opts ...ResolverOption) (Resolver, error) {
|
||||||
|
options := resolverOptions{}
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(&options)
|
||||||
|
}
|
||||||
|
|
||||||
|
var servers []NameServer
|
||||||
|
for _, server := range nameservers {
|
||||||
|
addr := strings.TrimSpace(server.Addr)
|
||||||
|
if addr == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
ex, err := exchanger.NewExchanger(
|
||||||
|
addr,
|
||||||
|
exchanger.ChainOption(server.Chain),
|
||||||
|
exchanger.TimeoutOption(server.Timeout),
|
||||||
|
exchanger.LoggerOption(options.logger),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
options.logger.Warnf("parse %s: %v", server, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
server.exchanger = ex
|
||||||
|
servers = append(servers, server)
|
||||||
|
}
|
||||||
|
cache := resolver_util.NewCache().
|
||||||
|
WithLogger(options.logger)
|
||||||
|
|
||||||
|
return &resolver{
|
||||||
|
servers: servers,
|
||||||
|
cache: cache,
|
||||||
|
options: options,
|
||||||
|
logger: options.logger,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *resolver) Resolve(ctx context.Context, host string) (ips []net.IP, err error) {
|
||||||
|
if ip := net.ParseIP(host); ip != nil {
|
||||||
|
return []net.IP{ip}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.options.domain != "" &&
|
||||||
|
!strings.Contains(host, ".") {
|
||||||
|
host = host + "." + r.options.domain
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, server := range r.servers {
|
||||||
|
ips, err = r.resolve(ctx, &server, host)
|
||||||
|
if err != nil {
|
||||||
|
r.logger.Error(err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
r.logger.Debugf("resolve %s via %s: %v", host, server.exchanger.String(), ips)
|
||||||
|
|
||||||
|
if len(ips) > 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *resolver) resolve(ctx context.Context, server *NameServer, host string) (ips []net.IP, err error) {
|
||||||
|
if server == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if server.Prefer == "ipv6" { // prefer ipv6
|
||||||
|
mq := dns.Msg{}
|
||||||
|
mq.SetQuestion(dns.Fqdn(host), dns.TypeAAAA)
|
||||||
|
ips, err = r.resolveIPs(ctx, server, &mq)
|
||||||
|
if err != nil || len(ips) > 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// fallback to ipv4
|
||||||
|
mq := dns.Msg{}
|
||||||
|
mq.SetQuestion(dns.Fqdn(host), dns.TypeA)
|
||||||
|
return r.resolveIPs(ctx, server, &mq)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *resolver) resolveIPs(ctx context.Context, server *NameServer, mq *dns.Msg) (ips []net.IP, err error) {
|
||||||
|
key := resolver_util.NewCacheKey(&mq.Question[0])
|
||||||
|
mr := r.cache.Load(key)
|
||||||
|
if mr == nil {
|
||||||
|
resolver_util.AddSubnetOpt(mq, server.ClientIP)
|
||||||
|
mr, err = r.exchange(ctx, server.exchanger, mq)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
r.cache.Store(key, mr, server.TTL)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, ans := range mr.Answer {
|
||||||
|
if ar, _ := ans.(*dns.AAAA); ar != nil {
|
||||||
|
ips = append(ips, ar.AAAA)
|
||||||
|
}
|
||||||
|
if ar, _ := ans.(*dns.A); ar != nil {
|
||||||
|
ips = append(ips, ar.A)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *resolver) exchange(ctx context.Context, ex exchanger.Exchanger, mq *dns.Msg) (mr *dns.Msg, err error) {
|
||||||
|
query, err := mq.Pack()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
reply, err := ex.Exchange(ctx, query)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
mr = &dns.Msg{}
|
||||||
|
err = mr.Unpack(reply)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user