add resolver

This commit is contained in:
ginuerzh 2021-12-30 19:12:42 +08:00
parent 9b3d7e1110
commit a6671a468e
8 changed files with 343 additions and 50 deletions

View File

@ -14,16 +14,17 @@ profiling:
resolvers:
- name: resolver-0
nameservers:
- addr: udp://8.8.8.8:53
chain: chain-0
ttl: 60s
prefer: ipv4
clientIP: 1.2.3.4
nameServers:
- addr: udp://8.8.8.8:53
timeout: 5s
timeout: 3s
- addr: tcp://1.1.1.1:53
- addr: tls://1.1.1.1:853
- addr: https://1.0.0.1/dns-query
domain: cloudflare-dns.com
hostname: cloudflare-dns.com
services:
- name: http+tcp

View File

@ -6,12 +6,14 @@ import (
"errors"
"net"
"strconv"
"strings"
"time"
"github.com/go-gost/gost/pkg/bypass"
"github.com/go-gost/gost/pkg/chain"
"github.com/go-gost/gost/pkg/common/bufpool"
"github.com/go-gost/gost/pkg/handler"
resolver_util "github.com/go-gost/gost/pkg/internal/util/resolver"
"github.com/go-gost/gost/pkg/logger"
md "github.com/go-gost/gost/pkg/metadata"
"github.com/go-gost/gost/pkg/registry"
@ -27,6 +29,7 @@ type dnsHandler struct {
chain *chain.Chain
bypass bypass.Bypass
exchangers []exchanger.Exchanger
cache *resolver_util.Cache
logger logger.Logger
md metadata
}
@ -37,8 +40,11 @@ func NewHandler(opts ...handler.Option) handler.Handler {
opt(options)
}
cache := resolver_util.NewCache().WithLogger(options.Logger)
return &dnsHandler{
bypass: options.Bypass,
cache: cache,
logger: options.Logger,
}
}
@ -49,9 +55,14 @@ func (h *dnsHandler) Init(md md.Metadata) (err error) {
}
for _, server := range h.md.servers {
server = strings.TrimSpace(server)
if server == "" {
continue
}
ex, err := exchanger.NewExchanger(
server,
exchanger.ChainOption(h.chain),
exchanger.TimeoutOption(h.md.timeout),
exchanger.LoggerOption(h.logger),
)
if err != nil {
@ -61,14 +72,18 @@ func (h *dnsHandler) Init(md md.Metadata) (err error) {
h.exchangers = append(h.exchangers, ex)
}
if len(h.exchangers) == 0 {
ex, _ := exchanger.NewExchanger(
"udp://127.0.0.53:53",
addr := "udp://127.0.0.1:53"
ex, err := exchanger.NewExchanger(
addr,
exchanger.ChainOption(h.chain),
exchanger.TimeoutOption(h.md.timeout),
exchanger.LoggerOption(h.logger),
)
if ex != nil {
h.exchangers = append(h.exchangers, ex)
h.logger.Warnf("resolver not found, default to %s", addr)
if err != nil {
return err
}
h.exchangers = append(h.exchangers, ex)
}
return
}
@ -106,7 +121,6 @@ func (h *dnsHandler) Handle(ctx context.Context, conn net.Conn) {
reply, err := h.exchange(ctx, b[:n])
if err != nil {
h.logger.Error(err)
return
}
@ -118,6 +132,7 @@ func (h *dnsHandler) Handle(ctx context.Context, conn net.Conn) {
func (h *dnsHandler) exchange(ctx context.Context, msg []byte) ([]byte, error) {
mq := dns.Msg{}
if err := mq.Unpack(msg); err != nil {
h.logger.Error(err)
return nil, err
}
@ -125,6 +140,8 @@ func (h *dnsHandler) exchange(ctx context.Context, msg []byte) ([]byte, error) {
return nil, errors.New("msg: empty question")
}
resolver_util.AddSubnetOpt(&mq, h.md.clientIP)
if h.logger.IsLevelEnabled(logger.DebugLevel) {
h.logger.Debug(mq.String())
} else {
@ -132,26 +149,22 @@ func (h *dnsHandler) exchange(ctx context.Context, msg []byte) ([]byte, error) {
}
var mr *dns.Msg
// Only cache for single question.
/*
// cache only for single question message.
if len(mq.Question) == 1 {
key := newResolverCacheKey(&mq.Question[0])
mr = r.cache.loadCache(key)
key := resolver_util.NewCacheKey(&mq.Question[0])
mr = h.cache.Load(key)
if mr != nil {
log.Logf("[dns] exchange message %d (cached): %s", mq.Id, mq.Question[0].String())
h.logger.Debugf("exchange message %d (cached): %s", mq.Id, mq.Question[0].String())
mr.Id = mq.Id
return mr.Pack()
}
defer func() {
if mr != nil {
r.cache.storeCache(key, mr, r.TTL())
h.cache.Store(key, mr, h.md.ttl)
}
}()
}
*/
// r.addSubnetOpt(mq)
query, err := mq.Pack()
if err != nil {
@ -169,7 +182,6 @@ func (h *dnsHandler) exchange(ctx context.Context, msg []byte) ([]byte, error) {
h.logger.Error(err)
}
if err != nil {
h.logger.Error(err)
return nil, err
}

View File

@ -1,6 +1,8 @@
package dns
import (
"net"
"strings"
"time"
mdata "github.com/go-gost/gost/pkg/metadata"
@ -11,10 +13,10 @@ type metadata struct {
retryCount int
ttl time.Duration
timeout time.Duration
prefer string
clientIP string
clientIP net.IP
// nameservers
servers []string
dns []string // compatible with v2
}
func (h *dnsHandler) parseMetadata(md mdata.Metadata) (err error) {
@ -23,9 +25,9 @@ func (h *dnsHandler) parseMetadata(md mdata.Metadata) (err error) {
retryCount = "retry"
ttl = "ttl"
timeout = "timeout"
prefer = "prefer"
clientIP = "clientIP"
servers = "servers"
dns = "dns"
)
h.md.readTimeout = mdata.GetDuration(md, readTimeout)
@ -35,9 +37,15 @@ func (h *dnsHandler) parseMetadata(md mdata.Metadata) (err error) {
if h.md.timeout <= 0 {
h.md.timeout = 5 * time.Second
}
h.md.prefer = mdata.GetString(md, prefer)
h.md.clientIP = mdata.GetString(md, clientIP)
sip := mdata.GetString(md, clientIP)
if sip != "" {
h.md.clientIP = net.ParseIP(sip)
}
h.md.servers = mdata.GetStrings(md, servers)
h.md.dns = strings.Split(mdata.GetString(md, dns), ",")
if len(h.md.dns) > 0 {
h.md.servers = append(h.md.servers, h.md.dns...)
}
return
}

View File

@ -0,0 +1,86 @@
package resolver
import (
"fmt"
"sync"
"time"
"github.com/go-gost/gost/pkg/logger"
"github.com/miekg/dns"
)
type CacheKey string
// NewCacheKey generates resolver cache key from question of dns query.
func NewCacheKey(q *dns.Question) CacheKey {
if q == nil {
return ""
}
key := fmt.Sprintf("%s%s.%s", q.Name, dns.Class(q.Qclass).String(), dns.Type(q.Qtype).String())
return CacheKey(key)
}
type cacheItem struct {
msg *dns.Msg
ts time.Time
ttl time.Duration
}
type Cache struct {
m sync.Map
logger logger.Logger
}
func NewCache() *Cache {
return &Cache{}
}
func (c *Cache) WithLogger(logger logger.Logger) *Cache {
c.logger = logger
return c
}
func (c *Cache) Load(key CacheKey) *dns.Msg {
v, ok := c.m.Load(key)
if !ok {
return nil
}
item, ok := v.(*cacheItem)
if !ok {
return nil
}
elapsed := time.Since(item.ts)
if item.ttl > 0 {
if elapsed > item.ttl {
c.m.Delete(key)
return nil
}
} else {
for _, rr := range item.msg.Answer {
if elapsed > time.Duration(rr.Header().Ttl)*time.Second {
c.m.Delete(key)
return nil
}
}
}
c.logger.Debugf("resolver cache hit %s", key)
return item.msg.Copy()
}
func (c *Cache) Store(key CacheKey, mr *dns.Msg, ttl time.Duration) {
if key == "" || mr == nil || ttl < 0 {
return
}
c.m.Store(key, &cacheItem{
msg: mr.Copy(),
ts: time.Now(),
ttl: ttl,
})
c.logger.Debugf("resolver cache store %s", key)
}

View File

@ -0,0 +1,30 @@
package resolver
import (
"net"
"github.com/miekg/dns"
)
func AddSubnetOpt(m *dns.Msg, ip net.IP) {
if m == nil || ip == nil {
return
}
opt := new(dns.OPT)
opt.Hdr.Name = "."
opt.Hdr.Rrtype = dns.TypeOPT
e := new(dns.EDNS0_SUBNET)
e.Code = dns.EDNS0SUBNET
if ip := ip.To4(); ip != nil {
e.Family = 1
e.SourceNetmask = 24
e.Address = ip
} else {
e.Family = 2
e.SourceNetmask = 128
e.Address = ip.To16()
}
opt.Option = append(opt.Option, e)
m.Extra = append(m.Extra, opt)
}

View File

@ -107,7 +107,7 @@ func NewExchanger(addr string, opts ...Option) (Exchanger, error) {
}
}
ex.network = "tcp"
case "doh":
case "https":
ex.addr = addr
if ex.options.tlsConfig == nil {
ex.options.tlsConfig = &tls.Config{
@ -134,7 +134,7 @@ func NewExchanger(addr string, opts ...Option) (Exchanger, error) {
}
func (ex *exchanger) Exchange(ctx context.Context, msg []byte) ([]byte, error) {
if ex.network == "doh" {
if ex.network == "https" {
return ex.dohExchange(ctx, msg)
}
return ex.exchange(ctx, msg)

View File

@ -1,13 +0,0 @@
package resolver
import (
"time"
)
type NameServer struct {
Addr string
Protocol string
Hostname string // for TLS handshake verification
Exchanger Exchanger
Timeout time.Duration
}

View File

@ -3,9 +3,178 @@ package resolver
import (
"context"
"net"
"strings"
"time"
"github.com/go-gost/gost/pkg/chain"
resolver_util "github.com/go-gost/gost/pkg/internal/util/resolver"
"github.com/go-gost/gost/pkg/logger"
"github.com/go-gost/gost/pkg/resolver/exchanger"
"github.com/miekg/dns"
)
type Resolver interface {
// Resolve returns a slice of the host's IPv4 and IPv6 addresses.
Resolve(ctx context.Context, host string) ([]net.IP, error)
}
type NameServer struct {
Addr string
Chain *chain.Chain
TTL time.Duration
Timeout time.Duration
ClientIP net.IP
Prefer string
Hostname string // for TLS handshake verification
exchanger exchanger.Exchanger
}
type resolverOptions struct {
domain string
logger logger.Logger
}
type ResolverOption func(opts *resolverOptions)
func DomainResolverOption(domain string) ResolverOption {
return func(opts *resolverOptions) {
opts.domain = domain
}
}
func LoggerResolverOption(logger logger.Logger) ResolverOption {
return func(opts *resolverOptions) {
opts.logger = logger
}
}
type resolver struct {
servers []NameServer
cache *resolver_util.Cache
options resolverOptions
logger logger.Logger
}
func NewResolver(nameservers []NameServer, opts ...ResolverOption) (Resolver, error) {
options := resolverOptions{}
for _, opt := range opts {
opt(&options)
}
var servers []NameServer
for _, server := range nameservers {
addr := strings.TrimSpace(server.Addr)
if addr == "" {
continue
}
ex, err := exchanger.NewExchanger(
addr,
exchanger.ChainOption(server.Chain),
exchanger.TimeoutOption(server.Timeout),
exchanger.LoggerOption(options.logger),
)
if err != nil {
options.logger.Warnf("parse %s: %v", server, err)
continue
}
server.exchanger = ex
servers = append(servers, server)
}
cache := resolver_util.NewCache().
WithLogger(options.logger)
return &resolver{
servers: servers,
cache: cache,
options: options,
logger: options.logger,
}, nil
}
func (r *resolver) Resolve(ctx context.Context, host string) (ips []net.IP, err error) {
if ip := net.ParseIP(host); ip != nil {
return []net.IP{ip}, nil
}
if r.options.domain != "" &&
!strings.Contains(host, ".") {
host = host + "." + r.options.domain
}
for _, server := range r.servers {
ips, err = r.resolve(ctx, &server, host)
if err != nil {
r.logger.Error(err)
continue
}
r.logger.Debugf("resolve %s via %s: %v", host, server.exchanger.String(), ips)
if len(ips) > 0 {
break
}
}
return
}
func (r *resolver) resolve(ctx context.Context, server *NameServer, host string) (ips []net.IP, err error) {
if server == nil {
return
}
if server.Prefer == "ipv6" { // prefer ipv6
mq := dns.Msg{}
mq.SetQuestion(dns.Fqdn(host), dns.TypeAAAA)
ips, err = r.resolveIPs(ctx, server, &mq)
if err != nil || len(ips) > 0 {
return
}
}
// fallback to ipv4
mq := dns.Msg{}
mq.SetQuestion(dns.Fqdn(host), dns.TypeA)
return r.resolveIPs(ctx, server, &mq)
}
func (r *resolver) resolveIPs(ctx context.Context, server *NameServer, mq *dns.Msg) (ips []net.IP, err error) {
key := resolver_util.NewCacheKey(&mq.Question[0])
mr := r.cache.Load(key)
if mr == nil {
resolver_util.AddSubnetOpt(mq, server.ClientIP)
mr, err = r.exchange(ctx, server.exchanger, mq)
if err != nil {
return
}
r.cache.Store(key, mr, server.TTL)
}
for _, ans := range mr.Answer {
if ar, _ := ans.(*dns.AAAA); ar != nil {
ips = append(ips, ar.AAAA)
}
if ar, _ := ans.(*dns.A); ar != nil {
ips = append(ips, ar.A)
}
}
return
}
func (r *resolver) exchange(ctx context.Context, ex exchanger.Exchanger, mq *dns.Msg) (mr *dns.Msg, err error) {
query, err := mq.Pack()
if err != nil {
return
}
reply, err := ex.Exchange(ctx, query)
if err != nil {
return
}
mr = &dns.Msg{}
err = mr.Unpack(reply)
return
}