915 lines
19 KiB
Go
915 lines
19 KiB
Go
package gost
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"context"
|
|
"crypto/tls"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"io/ioutil"
|
|
"net"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/go-log/log"
|
|
"github.com/miekg/dns"
|
|
)
|
|
|
|
// DefaultResolverTimeout is the default timeout for name resolution.
// resolver.Init uses it when no positive timeout has been configured.
var DefaultResolverTimeout = 5 * time.Second
|
|
|
|
type nameServerOptions struct {
|
|
timeout time.Duration
|
|
chain *Chain
|
|
}
|
|
|
|
// NameServerOption allows a common way to set name server options.
|
|
type NameServerOption func(*nameServerOptions)
|
|
|
|
// TimeoutNameServerOption sets the timeout for name server.
|
|
func TimeoutNameServerOption(timeout time.Duration) NameServerOption {
|
|
return func(opts *nameServerOptions) {
|
|
opts.timeout = timeout
|
|
}
|
|
}
|
|
|
|
// ChainNameServerOption sets the chain for name server.
|
|
func ChainNameServerOption(chain *Chain) NameServerOption {
|
|
return func(opts *nameServerOptions) {
|
|
opts.chain = chain
|
|
}
|
|
}
|
|
|
|
// NameServer is a name server.
|
|
// Currently supported protocol: TCP, UDP and TLS.
|
|
type NameServer struct {
|
|
Addr string
|
|
Protocol string
|
|
Hostname string // for TLS handshake verification
|
|
exchanger Exchanger
|
|
options nameServerOptions
|
|
}
|
|
|
|
// Init initializes the name server.
|
|
func (ns *NameServer) Init(opts ...NameServerOption) error {
|
|
for _, opt := range opts {
|
|
opt(&ns.options)
|
|
}
|
|
|
|
options := []ExchangerOption{
|
|
TimeoutExchangerOption(ns.options.timeout),
|
|
}
|
|
protocol := strings.ToLower(ns.Protocol)
|
|
switch protocol {
|
|
case "tcp", "tcp-chain":
|
|
if protocol == "tcp-chain" {
|
|
options = append(options, ChainExchangerOption(ns.options.chain))
|
|
}
|
|
ns.exchanger = NewDNSTCPExchanger(ns.Addr, options...)
|
|
case "tls", "tls-chain":
|
|
if protocol == "tls-chain" {
|
|
options = append(options, ChainExchangerOption(ns.options.chain))
|
|
}
|
|
cfg := &tls.Config{
|
|
ServerName: ns.Hostname,
|
|
}
|
|
if cfg.ServerName == "" {
|
|
cfg.InsecureSkipVerify = true
|
|
}
|
|
ns.exchanger = NewDoTExchanger(ns.Addr, cfg, options...)
|
|
case "https", "https-chain":
|
|
if protocol == "https-chain" {
|
|
options = append(options, ChainExchangerOption(ns.options.chain))
|
|
}
|
|
u, err := url.Parse(ns.Addr)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
u.Scheme = "https"
|
|
cfg := &tls.Config{ServerName: ns.Hostname}
|
|
if cfg.ServerName == "" {
|
|
cfg.InsecureSkipVerify = true
|
|
}
|
|
ns.exchanger = NewDoHExchanger(u, cfg, options...)
|
|
case "udp", "udp-chain":
|
|
fallthrough
|
|
default:
|
|
if protocol == "udp-chain" {
|
|
options = append(options, ChainExchangerOption(ns.options.chain))
|
|
}
|
|
ns.exchanger = NewDNSExchanger(ns.Addr, options...)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (ns *NameServer) String() string {
|
|
addr := ns.Addr
|
|
prot := ns.Protocol
|
|
if prot == "" {
|
|
prot = "udp"
|
|
}
|
|
return fmt.Sprintf("%s/%s", addr, prot)
|
|
}
|
|
|
|
type resolverOptions struct {
|
|
chain *Chain
|
|
timeout time.Duration
|
|
ttl time.Duration
|
|
prefer string
|
|
srcIP net.IP
|
|
}
|
|
|
|
// ResolverOption allows a common way to set Resolver options.
|
|
type ResolverOption func(*resolverOptions)
|
|
|
|
// ChainResolverOption sets the chain for Resolver.
|
|
func ChainResolverOption(chain *Chain) ResolverOption {
|
|
return func(opts *resolverOptions) {
|
|
opts.chain = chain
|
|
}
|
|
}
|
|
|
|
// TimeoutResolverOption sets the timeout for Resolver.
|
|
func TimeoutResolverOption(timeout time.Duration) ResolverOption {
|
|
return func(opts *resolverOptions) {
|
|
opts.timeout = timeout
|
|
}
|
|
}
|
|
|
|
// TTLResolverOption sets the timeout for Resolver.
|
|
func TTLResolverOption(ttl time.Duration) ResolverOption {
|
|
return func(opts *resolverOptions) {
|
|
opts.ttl = ttl
|
|
}
|
|
}
|
|
|
|
// PreferResolverOption sets the prefer for Resolver.
|
|
func PreferResolverOption(prefer string) ResolverOption {
|
|
return func(opts *resolverOptions) {
|
|
opts.prefer = prefer
|
|
}
|
|
}
|
|
|
|
// SrcIPResolverOption sets the source IP for Resolver.
|
|
func SrcIPResolverOption(ip net.IP) ResolverOption {
|
|
return func(opts *resolverOptions) {
|
|
opts.srcIP = ip
|
|
}
|
|
}
|
|
|
|
// Resolver is a name resolver for domain name.
|
|
// It contains a list of name servers.
|
|
type Resolver interface {
|
|
// Init initializes the Resolver instance.
|
|
Init(opts ...ResolverOption) error
|
|
// Resolve returns a slice of that host's IPv4 and IPv6 addresses.
|
|
Resolve(host string) ([]net.IP, error)
|
|
// Exchange performs a synchronous query,
|
|
// It sends the message query and waits for a reply.
|
|
Exchange(ctx context.Context, query []byte) (reply []byte, err error)
|
|
}
|
|
|
|
// ReloadResolver is resolover that support live reloading.
|
|
type ReloadResolver interface {
|
|
Resolver
|
|
Reloader
|
|
Stoppable
|
|
}
|
|
|
|
type resolver struct {
|
|
servers []NameServer
|
|
ttl time.Duration
|
|
timeout time.Duration
|
|
period time.Duration
|
|
domain string
|
|
cache *resolverCache
|
|
stopped chan struct{}
|
|
mux sync.RWMutex
|
|
prefer string // ipv4 or ipv6
|
|
srcIP net.IP // for edns0 subnet option
|
|
options resolverOptions
|
|
}
|
|
|
|
// NewResolver create a new Resolver with the given name servers and resolution timeout.
|
|
func NewResolver(ttl time.Duration, servers ...NameServer) ReloadResolver {
|
|
r := newResolver(ttl, servers...)
|
|
return r
|
|
}
|
|
|
|
func newResolver(ttl time.Duration, servers ...NameServer) *resolver {
|
|
return &resolver{
|
|
servers: servers,
|
|
cache: newResolverCache(ttl),
|
|
stopped: make(chan struct{}),
|
|
}
|
|
}
|
|
|
|
func (r *resolver) Init(opts ...ResolverOption) error {
|
|
if r == nil {
|
|
return nil
|
|
}
|
|
|
|
r.mux.Lock()
|
|
defer r.mux.Unlock()
|
|
|
|
for _, opt := range opts {
|
|
opt(&r.options)
|
|
}
|
|
|
|
timeout := r.timeout
|
|
if r.options.timeout != 0 {
|
|
timeout = r.options.timeout
|
|
}
|
|
if timeout <= 0 {
|
|
timeout = DefaultResolverTimeout
|
|
}
|
|
|
|
if r.options.ttl != 0 {
|
|
r.ttl = r.options.ttl
|
|
}
|
|
if r.options.prefer != "" {
|
|
r.prefer = r.options.prefer
|
|
}
|
|
if r.options.srcIP != nil {
|
|
r.srcIP = r.options.srcIP
|
|
}
|
|
|
|
var nss []NameServer
|
|
for _, ns := range r.servers {
|
|
if err := ns.Init( // init all name servers
|
|
ChainNameServerOption(r.options.chain),
|
|
TimeoutNameServerOption(timeout),
|
|
); err != nil {
|
|
continue // ignore invalid name servers
|
|
}
|
|
nss = append(nss, ns)
|
|
}
|
|
|
|
r.servers = nss
|
|
|
|
return nil
|
|
}
|
|
|
|
func (r *resolver) copyServers() []NameServer {
|
|
r.mux.RLock()
|
|
defer r.mux.RUnlock()
|
|
|
|
servers := make([]NameServer, len(r.servers))
|
|
for i := range r.servers {
|
|
servers[i] = r.servers[i]
|
|
}
|
|
|
|
return servers
|
|
}
|
|
|
|
func (r *resolver) Resolve(host string) (ips []net.IP, err error) {
|
|
r.mux.RLock()
|
|
domain := r.domain
|
|
r.mux.RUnlock()
|
|
|
|
if ip := net.ParseIP(host); ip != nil {
|
|
return []net.IP{ip}, nil
|
|
}
|
|
|
|
if !strings.Contains(host, ".") && domain != "" {
|
|
host = host + "." + domain
|
|
}
|
|
|
|
ctx := context.Background()
|
|
for _, ns := range r.copyServers() {
|
|
ips, err = r.resolve(ctx, ns.exchanger, host)
|
|
if err != nil {
|
|
log.Logf("[resolver] %s via %s : %s", host, ns.String(), err)
|
|
continue
|
|
}
|
|
|
|
if Debug {
|
|
log.Logf("[resolver] %s via %s %v", host, ns.String(), ips)
|
|
}
|
|
if len(ips) > 0 {
|
|
break
|
|
}
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
func (r *resolver) resolve(ctx context.Context, ex Exchanger, host string) (ips []net.IP, err error) {
|
|
if ex == nil {
|
|
return
|
|
}
|
|
|
|
r.mux.RLock()
|
|
prefer := r.prefer
|
|
r.mux.RUnlock()
|
|
|
|
if prefer == "ipv6" { // prefer ipv6
|
|
mq := &dns.Msg{}
|
|
mq.SetQuestion(dns.Fqdn(host), dns.TypeAAAA)
|
|
ips, err = r.resolveIPs(ctx, ex, mq)
|
|
if err != nil || len(ips) > 0 {
|
|
return
|
|
}
|
|
}
|
|
|
|
mq := &dns.Msg{}
|
|
mq.SetQuestion(dns.Fqdn(host), dns.TypeA)
|
|
return r.resolveIPs(ctx, ex, mq)
|
|
}
|
|
|
|
func (r *resolver) resolveIPs(ctx context.Context, ex Exchanger, mq *dns.Msg) (ips []net.IP, err error) {
|
|
key := newResolverCacheKey(&mq.Question[0])
|
|
mr := r.cache.loadCache(key)
|
|
if mr == nil {
|
|
r.addSubnetOpt(mq)
|
|
mr, err = r.exchangeMsg(ctx, ex, mq)
|
|
if err != nil {
|
|
return
|
|
}
|
|
r.cache.storeCache(key, mr, r.TTL())
|
|
}
|
|
|
|
for _, ans := range mr.Answer {
|
|
if ar, _ := ans.(*dns.AAAA); ar != nil {
|
|
ips = append(ips, ar.AAAA)
|
|
}
|
|
if ar, _ := ans.(*dns.A); ar != nil {
|
|
ips = append(ips, ar.A)
|
|
}
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
func (r *resolver) addSubnetOpt(m *dns.Msg) {
|
|
if m == nil || r.srcIP == nil {
|
|
return
|
|
}
|
|
opt := new(dns.OPT)
|
|
opt.Hdr.Name = "."
|
|
opt.Hdr.Rrtype = dns.TypeOPT
|
|
e := new(dns.EDNS0_SUBNET)
|
|
e.Code = dns.EDNS0SUBNET
|
|
if ip := r.srcIP.To4(); ip != nil {
|
|
e.Family = 1
|
|
e.SourceNetmask = 32
|
|
e.Address = ip.To4()
|
|
} else {
|
|
e.Family = 2
|
|
e.SourceNetmask = 128
|
|
e.Address = r.srcIP
|
|
}
|
|
opt.Option = append(opt.Option, e)
|
|
m.Extra = append(m.Extra, opt)
|
|
}
|
|
|
|
func (r *resolver) Exchange(ctx context.Context, query []byte) (reply []byte, err error) {
|
|
mq := &dns.Msg{}
|
|
if err = mq.Unpack(query); err != nil {
|
|
return
|
|
}
|
|
|
|
if len(mq.Question) == 0 {
|
|
return nil, errors.New("empty question")
|
|
}
|
|
|
|
var mr *dns.Msg
|
|
// Only cache for single question.
|
|
if len(mq.Question) == 1 {
|
|
key := newResolverCacheKey(&mq.Question[0])
|
|
mr = r.cache.loadCache(key)
|
|
if mr != nil {
|
|
log.Logf("[dns] exchange message %d (cached): %s", mq.Id, mq.Question[0].String())
|
|
mr.Id = mq.Id
|
|
return mr.Pack()
|
|
}
|
|
|
|
defer func() {
|
|
if mr != nil {
|
|
r.cache.storeCache(key, mr, r.TTL())
|
|
}
|
|
}()
|
|
}
|
|
|
|
r.addSubnetOpt(mq)
|
|
|
|
for _, ns := range r.copyServers() {
|
|
log.Logf("[dns] exchange message %d via %s: %s", mq.Id, ns.String(), mq.Question[0].String())
|
|
mr, err = r.exchangeMsg(ctx, ns.exchanger, mq)
|
|
if err == nil {
|
|
break
|
|
}
|
|
log.Logf("[dns] exchange message %d via %s: %s", mq.Id, ns.String(), err)
|
|
}
|
|
if err != nil {
|
|
return
|
|
}
|
|
return mr.Pack()
|
|
}
|
|
|
|
func (r *resolver) exchangeMsg(ctx context.Context, ex Exchanger, mq *dns.Msg) (mr *dns.Msg, err error) {
|
|
query, err := mq.Pack()
|
|
if err != nil {
|
|
return
|
|
}
|
|
reply, err := ex.Exchange(ctx, query)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
mr = &dns.Msg{}
|
|
err = mr.Unpack(reply)
|
|
|
|
return
|
|
}
|
|
|
|
func (r *resolver) TTL() time.Duration {
|
|
r.mux.RLock()
|
|
defer r.mux.RUnlock()
|
|
return r.ttl
|
|
}
|
|
|
|
func (r *resolver) Reload(rd io.Reader) error {
|
|
var ttl, timeout, period time.Duration
|
|
var domain, prefer string
|
|
var srcIP net.IP
|
|
var nss []NameServer
|
|
|
|
if rd == nil || r.Stopped() {
|
|
return nil
|
|
}
|
|
|
|
scanner := bufio.NewScanner(rd)
|
|
for scanner.Scan() {
|
|
line := scanner.Text()
|
|
ss := splitLine(line)
|
|
if len(ss) == 0 {
|
|
continue
|
|
}
|
|
|
|
switch ss[0] {
|
|
case "timeout": // timeout option
|
|
if len(ss) > 1 {
|
|
timeout, _ = time.ParseDuration(ss[1])
|
|
}
|
|
case "ttl": // ttl option
|
|
if len(ss) > 1 {
|
|
ttl, _ = time.ParseDuration(ss[1])
|
|
}
|
|
case "reload": // reload option
|
|
if len(ss) > 1 {
|
|
period, _ = time.ParseDuration(ss[1])
|
|
}
|
|
case "domain":
|
|
if len(ss) > 1 {
|
|
domain = ss[1]
|
|
}
|
|
case "search", "sortlist", "options": // we don't support these features in /etc/resolv.conf
|
|
case "prefer":
|
|
if len(ss) > 1 {
|
|
prefer = strings.ToLower(ss[1])
|
|
}
|
|
case "ip":
|
|
if len(ss) > 1 {
|
|
srcIP = net.ParseIP(ss[1])
|
|
}
|
|
case "nameserver": // nameserver option, compatible with /etc/resolv.conf
|
|
if len(ss) <= 1 {
|
|
break
|
|
}
|
|
ss = ss[1:]
|
|
fallthrough
|
|
default:
|
|
var ns NameServer
|
|
switch len(ss) {
|
|
case 0:
|
|
break
|
|
case 1:
|
|
ns.Addr = ss[0]
|
|
case 2:
|
|
ns.Addr = ss[0]
|
|
ns.Protocol = ss[1]
|
|
default:
|
|
ns.Addr = ss[0]
|
|
ns.Protocol = ss[1]
|
|
ns.Hostname = ss[2]
|
|
}
|
|
|
|
if strings.HasPrefix(ns.Addr, "https") && ns.Protocol == "" {
|
|
ns.Protocol = "https"
|
|
}
|
|
nss = append(nss, ns)
|
|
}
|
|
}
|
|
|
|
if err := scanner.Err(); err != nil {
|
|
return err
|
|
}
|
|
|
|
r.mux.Lock()
|
|
r.ttl = ttl
|
|
r.timeout = timeout
|
|
r.domain = domain
|
|
r.period = period
|
|
r.prefer = prefer
|
|
r.srcIP = srcIP
|
|
r.servers = nss
|
|
r.mux.Unlock()
|
|
|
|
r.Init()
|
|
|
|
return nil
|
|
}
|
|
|
|
func (r *resolver) Period() time.Duration {
|
|
if r.Stopped() {
|
|
return -1
|
|
}
|
|
|
|
r.mux.RLock()
|
|
defer r.mux.RUnlock()
|
|
|
|
return r.period
|
|
}
|
|
|
|
// Stop stops reloading.
|
|
func (r *resolver) Stop() {
|
|
select {
|
|
case <-r.stopped:
|
|
default:
|
|
close(r.stopped)
|
|
}
|
|
}
|
|
|
|
// Stopped checks whether the reloader is stopped.
|
|
func (r *resolver) Stopped() bool {
|
|
select {
|
|
case <-r.stopped:
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
}
|
|
|
|
func (r *resolver) String() string {
|
|
if r == nil {
|
|
return ""
|
|
}
|
|
|
|
r.mux.RLock()
|
|
defer r.mux.RUnlock()
|
|
|
|
b := &bytes.Buffer{}
|
|
fmt.Fprintf(b, "TTL %v\n", r.ttl)
|
|
fmt.Fprintf(b, "Reload %v\n", r.period)
|
|
fmt.Fprintf(b, "Domain %v\n", r.domain)
|
|
for i := range r.servers {
|
|
fmt.Fprintln(b, r.servers[i])
|
|
}
|
|
return b.String()
|
|
}
|
|
|
|
// resolverCacheKey identifies a cached reply; see newResolverCacheKey for
// how it is derived from a DNS question.
type resolverCacheKey string
|
|
|
|
// newResolverCacheKey generates resolver cache key from question of dns query.
|
|
func newResolverCacheKey(q *dns.Question) resolverCacheKey {
|
|
if q == nil {
|
|
return ""
|
|
}
|
|
key := fmt.Sprintf("%s%s.%s", q.Name, dns.Class(q.Qclass).String(), dns.Type(q.Qtype).String())
|
|
return resolverCacheKey(key)
|
|
}
|
|
|
|
type resolverCacheItem struct {
|
|
mr *dns.Msg
|
|
ts int64
|
|
ttl time.Duration
|
|
}
|
|
|
|
// resolverCache caches DNS replies. Entries are stored in m under a
// resolverCacheKey as *resolverCacheItem values.
type resolverCache struct {
	m sync.Map
}
|
|
|
|
func newResolverCache(ttl time.Duration) *resolverCache {
|
|
return &resolverCache{}
|
|
}
|
|
|
|
func (rc *resolverCache) loadCache(key resolverCacheKey) *dns.Msg {
|
|
v, ok := rc.m.Load(key)
|
|
if !ok {
|
|
return nil
|
|
}
|
|
|
|
item, ok := v.(*resolverCacheItem)
|
|
if !ok {
|
|
return nil
|
|
}
|
|
|
|
elapsed := time.Since(time.Unix(item.ts, 0))
|
|
if item.ttl > 0 && elapsed > item.ttl {
|
|
rc.m.Delete(key)
|
|
return nil
|
|
}
|
|
for _, rr := range item.mr.Answer {
|
|
if elapsed > time.Duration(rr.Header().Ttl)*time.Second {
|
|
rc.m.Delete(key)
|
|
return nil
|
|
}
|
|
}
|
|
|
|
if Debug {
|
|
log.Logf("[resolver] cache hit %s", key)
|
|
}
|
|
|
|
return item.mr.Copy()
|
|
}
|
|
|
|
func (rc *resolverCache) storeCache(key resolverCacheKey, mr *dns.Msg, ttl time.Duration) {
|
|
if key == "" || mr == nil || ttl < 0 {
|
|
return
|
|
}
|
|
|
|
rc.m.Store(key, &resolverCacheItem{
|
|
mr: mr.Copy(),
|
|
ts: time.Now().Unix(),
|
|
ttl: ttl,
|
|
})
|
|
if Debug {
|
|
log.Logf("[resolver] cache store %s", key)
|
|
}
|
|
}
|
|
|
|
// Exchanger performs synchronous DNS queries: Exchange sends the packed
// message query and returns the packed reply.
type Exchanger interface {
	Exchange(ctx context.Context, query []byte) ([]byte, error)
}
|
|
|
|
type exchangerOptions struct {
|
|
chain *Chain
|
|
timeout time.Duration
|
|
}
|
|
|
|
// ExchangerOption allows a common way to set Exchanger options.
|
|
type ExchangerOption func(opts *exchangerOptions)
|
|
|
|
// ChainExchangerOption sets the chain for Exchanger.
|
|
func ChainExchangerOption(chain *Chain) ExchangerOption {
|
|
return func(opts *exchangerOptions) {
|
|
opts.chain = chain
|
|
}
|
|
}
|
|
|
|
// TimeoutExchangerOption sets the timeout for Exchanger.
|
|
func TimeoutExchangerOption(timeout time.Duration) ExchangerOption {
|
|
return func(opts *exchangerOptions) {
|
|
opts.timeout = timeout
|
|
}
|
|
}
|
|
|
|
type dnsExchanger struct {
|
|
addr string
|
|
options exchangerOptions
|
|
}
|
|
|
|
// NewDNSExchanger creates a DNS over UDP Exchanger
|
|
func NewDNSExchanger(addr string, opts ...ExchangerOption) Exchanger {
|
|
var options exchangerOptions
|
|
for _, opt := range opts {
|
|
opt(&options)
|
|
}
|
|
|
|
if _, port, _ := net.SplitHostPort(addr); port == "" {
|
|
addr = net.JoinHostPort(addr, "53")
|
|
}
|
|
|
|
return &dnsExchanger{
|
|
addr: addr,
|
|
options: options,
|
|
}
|
|
}
|
|
|
|
func (ex *dnsExchanger) Exchange(ctx context.Context, query []byte) ([]byte, error) {
|
|
t := time.Now()
|
|
c, err := ex.options.chain.DialContext(ctx,
|
|
"udp", ex.addr,
|
|
TimeoutChainOption(ex.options.timeout),
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
c.SetDeadline(time.Now().Add(ex.options.timeout - time.Since(t)))
|
|
defer c.Close()
|
|
|
|
conn := &dns.Conn{
|
|
Conn: c,
|
|
}
|
|
if _, err = conn.Write(query); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
mr, err := conn.ReadMsg()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return mr.Pack()
|
|
}
|
|
|
|
type dnsTCPExchanger struct {
|
|
addr string
|
|
options exchangerOptions
|
|
}
|
|
|
|
// NewDNSTCPExchanger creates a DNS over TCP Exchanger
|
|
func NewDNSTCPExchanger(addr string, opts ...ExchangerOption) Exchanger {
|
|
var options exchangerOptions
|
|
for _, opt := range opts {
|
|
opt(&options)
|
|
}
|
|
|
|
if _, port, _ := net.SplitHostPort(addr); port == "" {
|
|
addr = net.JoinHostPort(addr, "53")
|
|
}
|
|
|
|
return &dnsTCPExchanger{
|
|
addr: addr,
|
|
options: options,
|
|
}
|
|
}
|
|
|
|
func (ex *dnsTCPExchanger) Exchange(ctx context.Context, query []byte) ([]byte, error) {
|
|
t := time.Now()
|
|
c, err := ex.options.chain.DialContext(ctx,
|
|
"tcp", ex.addr,
|
|
TimeoutChainOption(ex.options.timeout),
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
c.SetDeadline(time.Now().Add(ex.options.timeout - time.Since(t)))
|
|
defer c.Close()
|
|
|
|
conn := &dns.Conn{
|
|
Conn: c,
|
|
}
|
|
if _, err = conn.Write(query); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
mr, err := conn.ReadMsg()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return mr.Pack()
|
|
}
|
|
|
|
type dotExchanger struct {
|
|
addr string
|
|
tlsConfig *tls.Config
|
|
options exchangerOptions
|
|
}
|
|
|
|
// NewDoTExchanger creates a DNS over TLS Exchanger
|
|
func NewDoTExchanger(addr string, tlsConfig *tls.Config, opts ...ExchangerOption) Exchanger {
|
|
var options exchangerOptions
|
|
for _, opt := range opts {
|
|
opt(&options)
|
|
}
|
|
|
|
if _, port, _ := net.SplitHostPort(addr); port == "" {
|
|
addr = net.JoinHostPort(addr, "53")
|
|
}
|
|
|
|
if tlsConfig == nil {
|
|
tlsConfig = &tls.Config{
|
|
InsecureSkipVerify: true,
|
|
}
|
|
}
|
|
return &dotExchanger{
|
|
addr: addr,
|
|
tlsConfig: tlsConfig,
|
|
options: options,
|
|
}
|
|
}
|
|
|
|
func (ex *dotExchanger) dial(ctx context.Context, network, address string) (conn net.Conn, err error) {
|
|
conn, err = ex.options.chain.DialContext(ctx,
|
|
network, address,
|
|
TimeoutChainOption(ex.options.timeout),
|
|
)
|
|
if err != nil {
|
|
return
|
|
}
|
|
conn = tls.Client(conn, ex.tlsConfig)
|
|
|
|
return
|
|
}
|
|
|
|
func (ex *dotExchanger) Exchange(ctx context.Context, query []byte) ([]byte, error) {
|
|
t := time.Now()
|
|
c, err := ex.dial(ctx, "tcp", ex.addr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
c.SetDeadline(time.Now().Add(ex.options.timeout - time.Since(t)))
|
|
defer c.Close()
|
|
|
|
conn := &dns.Conn{
|
|
Conn: c,
|
|
}
|
|
if _, err = conn.Write(query); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
mr, err := conn.ReadMsg()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return mr.Pack()
|
|
}
|
|
|
|
type dohExchanger struct {
|
|
endpoint *url.URL
|
|
client *http.Client
|
|
options exchangerOptions
|
|
}
|
|
|
|
// NewDoHExchanger creates a DNS over HTTPS Exchanger
|
|
func NewDoHExchanger(urlStr *url.URL, tlsConfig *tls.Config, opts ...ExchangerOption) Exchanger {
|
|
var options exchangerOptions
|
|
for _, opt := range opts {
|
|
opt(&options)
|
|
}
|
|
ex := &dohExchanger{
|
|
endpoint: urlStr,
|
|
options: options,
|
|
}
|
|
|
|
ex.client = &http.Client{
|
|
Timeout: options.timeout,
|
|
Transport: &http.Transport{
|
|
// Proxy: ProxyFromEnvironment,
|
|
TLSClientConfig: tlsConfig,
|
|
ForceAttemptHTTP2: true,
|
|
MaxIdleConns: 100,
|
|
IdleConnTimeout: 90 * time.Second,
|
|
TLSHandshakeTimeout: options.timeout,
|
|
ExpectContinueTimeout: 1 * time.Second,
|
|
DialContext: ex.dialContext,
|
|
},
|
|
}
|
|
|
|
return ex
|
|
}
|
|
|
|
func (ex *dohExchanger) dialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
|
return ex.options.chain.DialContext(ctx,
|
|
network, address,
|
|
TimeoutChainOption(ex.options.timeout),
|
|
)
|
|
}
|
|
|
|
func (ex *dohExchanger) Exchange(ctx context.Context, query []byte) ([]byte, error) {
|
|
req, err := http.NewRequestWithContext(ctx, "POST", ex.endpoint.String(), bytes.NewBuffer(query))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create an HTTPS request: %s", err)
|
|
}
|
|
|
|
// req.Header.Add("Content-Type", "application/dns-udpwireformat")
|
|
req.Header.Add("Content-Type", "application/dns-message")
|
|
req.Host = ex.endpoint.Hostname()
|
|
|
|
client := ex.client
|
|
if client == nil {
|
|
client = http.DefaultClient
|
|
}
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to perform an HTTPS request: %s", err)
|
|
}
|
|
|
|
// Check response status code
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode != http.StatusOK {
|
|
return nil, fmt.Errorf("returned status code %d", resp.StatusCode)
|
|
}
|
|
|
|
// Read wireformat response from the body
|
|
buf, err := ioutil.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to read the response body: %s", err)
|
|
}
|
|
|
|
return buf, nil
|
|
}
|