add dns handler
This commit is contained in:
@ -32,6 +32,7 @@ import (
|
|||||||
|
|
||||||
// Register handlers
|
// Register handlers
|
||||||
_ "github.com/go-gost/gost/pkg/handler/auto"
|
_ "github.com/go-gost/gost/pkg/handler/auto"
|
||||||
|
_ "github.com/go-gost/gost/pkg/handler/dns"
|
||||||
_ "github.com/go-gost/gost/pkg/handler/forward/local"
|
_ "github.com/go-gost/gost/pkg/handler/forward/local"
|
||||||
_ "github.com/go-gost/gost/pkg/handler/forward/remote"
|
_ "github.com/go-gost/gost/pkg/handler/forward/remote"
|
||||||
_ "github.com/go-gost/gost/pkg/handler/forward/ssh"
|
_ "github.com/go-gost/gost/pkg/handler/forward/ssh"
|
||||||
|
2
go.mod
2
go.mod
@ -26,7 +26,7 @@ require (
|
|||||||
github.com/magiconair/properties v1.8.5 // indirect
|
github.com/magiconair/properties v1.8.5 // indirect
|
||||||
github.com/marten-seemann/qtls-go1-16 v0.1.4 // indirect
|
github.com/marten-seemann/qtls-go1-16 v0.1.4 // indirect
|
||||||
github.com/marten-seemann/qtls-go1-17 v0.1.0 // indirect
|
github.com/marten-seemann/qtls-go1-17 v0.1.0 // indirect
|
||||||
github.com/miekg/dns v1.1.44
|
github.com/miekg/dns v1.1.45
|
||||||
github.com/milosgajdos/tenus v0.0.3
|
github.com/milosgajdos/tenus v0.0.3
|
||||||
github.com/mitchellh/mapstructure v1.4.2 // indirect
|
github.com/mitchellh/mapstructure v1.4.2 // indirect
|
||||||
github.com/mmcloughlin/avo v0.0.0-20200803215136-443f81d77104 // indirect
|
github.com/mmcloughlin/avo v0.0.0-20200803215136-443f81d77104 // indirect
|
||||||
|
2
go.sum
2
go.sum
@ -285,6 +285,8 @@ github.com/miekg/dns v1.1.26 h1:gPxPSwALAeHJSjarOs00QjVdV9QoBvc1D2ujQUr5BzU=
|
|||||||
github.com/miekg/dns v1.1.26/go.mod h1:bPDLeHnStXmXAq1m/Ch/hvfNHr14JKNPMBo3VZKjuso=
|
github.com/miekg/dns v1.1.26/go.mod h1:bPDLeHnStXmXAq1m/Ch/hvfNHr14JKNPMBo3VZKjuso=
|
||||||
github.com/miekg/dns v1.1.44 h1:4rpqcegYPVkvIeVhITrKP1sRR3KjfRc1nrOPMUZmLyc=
|
github.com/miekg/dns v1.1.44 h1:4rpqcegYPVkvIeVhITrKP1sRR3KjfRc1nrOPMUZmLyc=
|
||||||
github.com/miekg/dns v1.1.44/go.mod h1:e3IlAVfNqAllflbibAZEWOXOQ+Ynzk/dDozDxY7XnME=
|
github.com/miekg/dns v1.1.44/go.mod h1:e3IlAVfNqAllflbibAZEWOXOQ+Ynzk/dDozDxY7XnME=
|
||||||
|
github.com/miekg/dns v1.1.45 h1:g5fRIhm9nx7g8osrAvgb16QJfmyMsyOCb+J7LSv+Qzk=
|
||||||
|
github.com/miekg/dns v1.1.45/go.mod h1:e3IlAVfNqAllflbibAZEWOXOQ+Ynzk/dDozDxY7XnME=
|
||||||
github.com/milosgajdos/tenus v0.0.3 h1:jmaJzwaY1DUyYVD0lM4U+uvP2kkEg1VahDqRFxIkVBE=
|
github.com/milosgajdos/tenus v0.0.3 h1:jmaJzwaY1DUyYVD0lM4U+uvP2kkEg1VahDqRFxIkVBE=
|
||||||
github.com/milosgajdos/tenus v0.0.3/go.mod h1:eIjx29vNeDOYWJuCnaHY2r4fq5egetV26ry3on7p8qY=
|
github.com/milosgajdos/tenus v0.0.3/go.mod h1:eIjx29vNeDOYWJuCnaHY2r4fq5egetV26ry3on7p8qY=
|
||||||
github.com/mitchellh/cli v1.1.0/go.mod h1:xcISNoH86gajksDmfB23e/pu+B+GeFRMYmoHXxx3xhI=
|
github.com/mitchellh/cli v1.1.0/go.mod h1:xcISNoH86gajksDmfB23e/pu+B+GeFRMYmoHXxx3xhI=
|
||||||
|
13
gost.yml
13
gost.yml
@ -12,6 +12,19 @@ profiling:
|
|||||||
# key: "key.pem"
|
# key: "key.pem"
|
||||||
# ca: "root.ca"
|
# ca: "root.ca"
|
||||||
|
|
||||||
|
resolvers:
|
||||||
|
- name: resolver-0
|
||||||
|
ttl: 60s
|
||||||
|
prefer: ipv4
|
||||||
|
clientIP: 1.2.3.4
|
||||||
|
nameServers:
|
||||||
|
- addr: udp://8.8.8.8:53
|
||||||
|
timeout: 5s
|
||||||
|
- addr: tcp://1.1.1.1:53
|
||||||
|
- addr: tls://1.1.1.1:853
|
||||||
|
- addr: https://1.0.0.1/dns-query
|
||||||
|
domain: cloudflare-dns.com
|
||||||
|
|
||||||
services:
|
services:
|
||||||
- name: http+tcp
|
- name: http+tcp
|
||||||
url: "http://gost:gost@:8000"
|
url: "http://gost:gost@:8000"
|
||||||
|
@ -32,6 +32,19 @@ func (r *Router) WithLogger(logger logger.Logger) *Router {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *Router) Dial(ctx context.Context, network, address string) (conn net.Conn, err error) {
|
func (r *Router) Dial(ctx context.Context, network, address string) (conn net.Conn, err error) {
|
||||||
|
conn, err = r.dial(ctx, network, address)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if network == "udp" || network == "udp4" || network == "udp6" {
|
||||||
|
if _, ok := conn.(net.PacketConn); !ok {
|
||||||
|
return &packetConn{conn}, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Router) dial(ctx context.Context, network, address string) (conn net.Conn, err error) {
|
||||||
count := r.retries + 1
|
count := r.retries + 1
|
||||||
if count <= 0 {
|
if count <= 0 {
|
||||||
count = 1
|
count = 1
|
||||||
@ -88,3 +101,17 @@ func (r *Router) Bind(ctx context.Context, network, address string, opts ...conn
|
|||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type packetConn struct {
|
||||||
|
net.Conn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *packetConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
|
||||||
|
n, err = c.Read(b)
|
||||||
|
addr = c.Conn.RemoteAddr()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *packetConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
|
||||||
|
return c.Write(b)
|
||||||
|
}
|
||||||
|
199
pkg/handler/dns/handler.go
Normal file
199
pkg/handler/dns/handler.go
Normal file
@ -0,0 +1,199 @@
|
|||||||
|
package dns
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-gost/gost/pkg/bypass"
|
||||||
|
"github.com/go-gost/gost/pkg/chain"
|
||||||
|
"github.com/go-gost/gost/pkg/common/bufpool"
|
||||||
|
"github.com/go-gost/gost/pkg/handler"
|
||||||
|
"github.com/go-gost/gost/pkg/logger"
|
||||||
|
md "github.com/go-gost/gost/pkg/metadata"
|
||||||
|
"github.com/go-gost/gost/pkg/registry"
|
||||||
|
"github.com/go-gost/gost/pkg/resolver/exchanger"
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
registry.RegisterHandler("dns", NewHandler)
|
||||||
|
}
|
||||||
|
|
||||||
|
type dnsHandler struct {
|
||||||
|
chain *chain.Chain
|
||||||
|
bypass bypass.Bypass
|
||||||
|
exchangers []exchanger.Exchanger
|
||||||
|
logger logger.Logger
|
||||||
|
md metadata
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewHandler(opts ...handler.Option) handler.Handler {
|
||||||
|
options := &handler.Options{}
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(options)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &dnsHandler{
|
||||||
|
bypass: options.Bypass,
|
||||||
|
logger: options.Logger,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *dnsHandler) Init(md md.Metadata) (err error) {
|
||||||
|
if err = h.parseMetadata(md); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, server := range h.md.servers {
|
||||||
|
ex, err := exchanger.NewExchanger(
|
||||||
|
server,
|
||||||
|
exchanger.ChainOption(h.chain),
|
||||||
|
exchanger.LoggerOption(h.logger),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Warnf("parse %s: %v", server, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
h.exchangers = append(h.exchangers, ex)
|
||||||
|
}
|
||||||
|
if len(h.exchangers) == 0 {
|
||||||
|
ex, _ := exchanger.NewExchanger(
|
||||||
|
"udp://127.0.0.53:53",
|
||||||
|
exchanger.ChainOption(h.chain),
|
||||||
|
exchanger.LoggerOption(h.logger),
|
||||||
|
)
|
||||||
|
if ex != nil {
|
||||||
|
h.exchangers = append(h.exchangers, ex)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// implements chain.Chainable interface
|
||||||
|
func (h *dnsHandler) WithChain(chain *chain.Chain) {
|
||||||
|
h.chain = chain
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *dnsHandler) Handle(ctx context.Context, conn net.Conn) {
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
h.logger = h.logger.WithFields(map[string]interface{}{
|
||||||
|
"remote": conn.RemoteAddr().String(),
|
||||||
|
"local": conn.LocalAddr().String(),
|
||||||
|
})
|
||||||
|
|
||||||
|
h.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||||
|
defer func() {
|
||||||
|
h.logger.WithFields(map[string]interface{}{
|
||||||
|
"duration": time.Since(start),
|
||||||
|
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||||
|
}()
|
||||||
|
|
||||||
|
b := bufpool.Get(4096)
|
||||||
|
defer bufpool.Put(b)
|
||||||
|
|
||||||
|
n, err := conn.Read(b)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.logger.Info("read data: ", n)
|
||||||
|
|
||||||
|
reply, err := h.exchange(ctx, b[:n])
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err = conn.Write(reply); err != nil {
|
||||||
|
h.logger.Error(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *dnsHandler) exchange(ctx context.Context, msg []byte) ([]byte, error) {
|
||||||
|
mq := dns.Msg{}
|
||||||
|
if err := mq.Unpack(msg); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(mq.Question) == 0 {
|
||||||
|
return nil, errors.New("msg: empty question")
|
||||||
|
}
|
||||||
|
|
||||||
|
if h.logger.IsLevelEnabled(logger.DebugLevel) {
|
||||||
|
h.logger.Debug(mq.String())
|
||||||
|
} else {
|
||||||
|
h.logger.Info(h.dumpMsgHeader(&mq))
|
||||||
|
}
|
||||||
|
|
||||||
|
var mr *dns.Msg
|
||||||
|
// Only cache for single question.
|
||||||
|
/*
|
||||||
|
if len(mq.Question) == 1 {
|
||||||
|
key := newResolverCacheKey(&mq.Question[0])
|
||||||
|
mr = r.cache.loadCache(key)
|
||||||
|
if mr != nil {
|
||||||
|
log.Logf("[dns] exchange message %d (cached): %s", mq.Id, mq.Question[0].String())
|
||||||
|
mr.Id = mq.Id
|
||||||
|
return mr.Pack()
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if mr != nil {
|
||||||
|
r.cache.storeCache(key, mr, r.TTL())
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
|
// r.addSubnetOpt(mq)
|
||||||
|
|
||||||
|
query, err := mq.Pack()
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Error(err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var reply []byte
|
||||||
|
for _, ex := range h.exchangers {
|
||||||
|
h.logger.Infof("exchange message %d via %s: %s", mq.Id, ex.String(), mq.Question[0].String())
|
||||||
|
reply, err = ex.Exchange(ctx, query)
|
||||||
|
if err == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
h.logger.Error(err)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Error(err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
mr = &dns.Msg{}
|
||||||
|
if err = mr.Unpack(reply); err != nil {
|
||||||
|
h.logger.Error(err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if h.logger.IsLevelEnabled(logger.DebugLevel) {
|
||||||
|
h.logger.Debug(mr.String())
|
||||||
|
} else {
|
||||||
|
h.logger.Info(h.dumpMsgHeader(mr))
|
||||||
|
}
|
||||||
|
|
||||||
|
return reply, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *dnsHandler) dumpMsgHeader(m *dns.Msg) string {
|
||||||
|
buf := new(bytes.Buffer)
|
||||||
|
buf.WriteString(m.MsgHdr.String() + " ")
|
||||||
|
buf.WriteString("QUERY: " + strconv.Itoa(len(m.Question)) + ", ")
|
||||||
|
buf.WriteString("ANSWER: " + strconv.Itoa(len(m.Answer)) + ", ")
|
||||||
|
buf.WriteString("AUTHORITY: " + strconv.Itoa(len(m.Ns)) + ", ")
|
||||||
|
buf.WriteString("ADDITIONAL: " + strconv.Itoa(len(m.Extra)))
|
||||||
|
return buf.String()
|
||||||
|
}
|
43
pkg/handler/dns/metadata.go
Normal file
43
pkg/handler/dns/metadata.go
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
package dns
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
mdata "github.com/go-gost/gost/pkg/metadata"
|
||||||
|
)
|
||||||
|
|
||||||
|
type metadata struct {
|
||||||
|
readTimeout time.Duration
|
||||||
|
retryCount int
|
||||||
|
ttl time.Duration
|
||||||
|
timeout time.Duration
|
||||||
|
prefer string
|
||||||
|
clientIP string
|
||||||
|
// nameservers
|
||||||
|
servers []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *dnsHandler) parseMetadata(md mdata.Metadata) (err error) {
|
||||||
|
const (
|
||||||
|
readTimeout = "readTimeout"
|
||||||
|
retryCount = "retry"
|
||||||
|
ttl = "ttl"
|
||||||
|
timeout = "timeout"
|
||||||
|
prefer = "prefer"
|
||||||
|
clientIP = "clientIP"
|
||||||
|
servers = "servers"
|
||||||
|
)
|
||||||
|
|
||||||
|
h.md.readTimeout = mdata.GetDuration(md, readTimeout)
|
||||||
|
h.md.retryCount = mdata.GetInt(md, retryCount)
|
||||||
|
h.md.ttl = mdata.GetDuration(md, ttl)
|
||||||
|
h.md.timeout = mdata.GetDuration(md, timeout)
|
||||||
|
if h.md.timeout <= 0 {
|
||||||
|
h.md.timeout = 5 * time.Second
|
||||||
|
}
|
||||||
|
h.md.prefer = mdata.GetString(md, prefer)
|
||||||
|
h.md.clientIP = mdata.GetString(md, clientIP)
|
||||||
|
h.md.servers = mdata.GetStrings(md, servers)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
@ -1,17 +0,0 @@
|
|||||||
package tap
|
|
||||||
|
|
||||||
import "net"
|
|
||||||
|
|
||||||
type packetConn struct {
|
|
||||||
net.Conn
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *packetConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
|
|
||||||
n, err = c.Read(b)
|
|
||||||
addr = c.Conn.RemoteAddr()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *packetConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
|
|
||||||
return c.Write(b)
|
|
||||||
}
|
|
@ -2,6 +2,7 @@ package tap
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
@ -122,7 +123,12 @@ func (h *tapHandler) handleLoop(ctx context.Context, conn net.Conn, addr net.Add
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
pc = &packetConn{cc}
|
|
||||||
|
var ok bool
|
||||||
|
pc, ok = cc.(net.PacketConn)
|
||||||
|
if !ok {
|
||||||
|
return errors.New("invalid connection")
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
if h.md.tcpMode {
|
if h.md.tcpMode {
|
||||||
if addr != nil {
|
if addr != nil {
|
||||||
|
@ -1,17 +0,0 @@
|
|||||||
package tun
|
|
||||||
|
|
||||||
import "net"
|
|
||||||
|
|
||||||
type packetConn struct {
|
|
||||||
net.Conn
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *packetConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
|
|
||||||
n, err = c.Read(b)
|
|
||||||
addr = c.Conn.RemoteAddr()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *packetConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
|
|
||||||
return c.Write(b)
|
|
||||||
}
|
|
@ -2,6 +2,7 @@ package tun
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
@ -124,7 +125,12 @@ func (h *tunHandler) handleLoop(ctx context.Context, conn net.Conn, addr net.Add
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
pc = &packetConn{cc}
|
|
||||||
|
var ok bool
|
||||||
|
pc, ok = cc.(net.PacketConn)
|
||||||
|
if !ok {
|
||||||
|
return errors.New("invalid connnection")
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
if h.md.tcpMode {
|
if h.md.tcpMode {
|
||||||
if addr != nil {
|
if addr != nil {
|
||||||
|
@ -60,8 +60,9 @@ func (c *serverConn) Read(b []byte) (n int, err error) {
|
|||||||
case <-c.closed:
|
case <-c.closed:
|
||||||
err = io.ErrClosedPipe
|
err = io.ErrClosedPipe
|
||||||
return
|
return
|
||||||
|
default:
|
||||||
|
return c.r.Read(b)
|
||||||
}
|
}
|
||||||
return c.r.Read(b)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *serverConn) Write(b []byte) (n int, err error) {
|
func (c *serverConn) Write(b []byte) (n int, err error) {
|
||||||
@ -69,8 +70,9 @@ func (c *serverConn) Write(b []byte) (n int, err error) {
|
|||||||
case <-c.closed:
|
case <-c.closed:
|
||||||
err = io.ErrClosedPipe
|
err = io.ErrClosedPipe
|
||||||
return
|
return
|
||||||
|
default:
|
||||||
|
return c.w.Write(b)
|
||||||
}
|
}
|
||||||
return c.w.Write(b)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *serverConn) Close() error {
|
func (c *serverConn) Close() error {
|
||||||
|
213
pkg/resolver/exchanger/exchanger.go
Normal file
213
pkg/resolver/exchanger/exchanger.go
Normal file
@ -0,0 +1,213 @@
|
|||||||
|
package exchanger
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-gost/gost/pkg/chain"
|
||||||
|
"github.com/go-gost/gost/pkg/logger"
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Options struct {
|
||||||
|
chain *chain.Chain
|
||||||
|
tlsConfig *tls.Config
|
||||||
|
timeout time.Duration
|
||||||
|
logger logger.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
// Option allows a common way to set Exchanger options.
|
||||||
|
type Option func(opts *Options)
|
||||||
|
|
||||||
|
// ChainOption sets the chain for Exchanger.
|
||||||
|
func ChainOption(chain *chain.Chain) Option {
|
||||||
|
return func(opts *Options) {
|
||||||
|
opts.chain = chain
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TLSConfigOption sets the TLS config for Exchanger.
|
||||||
|
func TLSConfigOption(cfg *tls.Config) Option {
|
||||||
|
return func(opts *Options) {
|
||||||
|
opts.tlsConfig = cfg
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoggerOption sets the logger for Exchanger.
|
||||||
|
func LoggerOption(logger logger.Logger) Option {
|
||||||
|
return func(opts *Options) {
|
||||||
|
opts.logger = logger
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TimeoutOption sets the timeout for Exchanger.
|
||||||
|
func TimeoutOption(timeout time.Duration) Option {
|
||||||
|
return func(opts *Options) {
|
||||||
|
opts.timeout = timeout
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exchanger is an interface for DNS synchronous query.
|
||||||
|
type Exchanger interface {
|
||||||
|
Exchange(ctx context.Context, msg []byte) ([]byte, error)
|
||||||
|
String() string
|
||||||
|
}
|
||||||
|
|
||||||
|
type exchanger struct {
|
||||||
|
network string
|
||||||
|
addr string
|
||||||
|
rawAddr string
|
||||||
|
router *chain.Router
|
||||||
|
client *http.Client
|
||||||
|
options Options
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewExchanger create an Exchanger.
|
||||||
|
func NewExchanger(addr string, opts ...Option) (Exchanger, error) {
|
||||||
|
var options Options
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(&options)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(addr, "://") {
|
||||||
|
addr = "udp://" + addr
|
||||||
|
}
|
||||||
|
u, err := url.Parse(addr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
ex := &exchanger{
|
||||||
|
network: u.Scheme,
|
||||||
|
addr: u.Host,
|
||||||
|
rawAddr: addr,
|
||||||
|
options: options,
|
||||||
|
}
|
||||||
|
ex.router = (&chain.Router{}).
|
||||||
|
WithChain(options.chain).
|
||||||
|
WithLogger(options.logger)
|
||||||
|
if _, port, _ := net.SplitHostPort(ex.addr); port == "" {
|
||||||
|
ex.addr = net.JoinHostPort(ex.addr, "53")
|
||||||
|
}
|
||||||
|
|
||||||
|
switch ex.network {
|
||||||
|
case "tcp":
|
||||||
|
case "dot", "tls":
|
||||||
|
if ex.options.tlsConfig == nil {
|
||||||
|
ex.options.tlsConfig = &tls.Config{
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ex.network = "tcp"
|
||||||
|
case "doh":
|
||||||
|
ex.addr = addr
|
||||||
|
if ex.options.tlsConfig == nil {
|
||||||
|
ex.options.tlsConfig = &tls.Config{
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ex.client = &http.Client{
|
||||||
|
Timeout: options.timeout,
|
||||||
|
Transport: &http.Transport{
|
||||||
|
TLSClientConfig: options.tlsConfig,
|
||||||
|
ForceAttemptHTTP2: true,
|
||||||
|
MaxIdleConns: 100,
|
||||||
|
IdleConnTimeout: 90 * time.Second,
|
||||||
|
TLSHandshakeTimeout: options.timeout,
|
||||||
|
ExpectContinueTimeout: 1 * time.Second,
|
||||||
|
DialContext: ex.dial,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
ex.network = "udp"
|
||||||
|
}
|
||||||
|
|
||||||
|
return ex, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ex *exchanger) Exchange(ctx context.Context, msg []byte) ([]byte, error) {
|
||||||
|
if ex.network == "doh" {
|
||||||
|
return ex.dohExchange(ctx, msg)
|
||||||
|
}
|
||||||
|
return ex.exchange(ctx, msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ex *exchanger) dohExchange(ctx context.Context, msg []byte) ([]byte, error) {
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "POST", ex.addr, bytes.NewBuffer(msg))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create an HTTPS request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// req.Header.Add("Content-Type", "application/dns-udpwireformat")
|
||||||
|
req.Header.Add("Content-Type", "application/dns-message")
|
||||||
|
|
||||||
|
client := ex.client
|
||||||
|
if client == nil {
|
||||||
|
client = http.DefaultClient
|
||||||
|
}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to perform an HTTPS request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check response status code
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("returned status code %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read wireformat response from the body
|
||||||
|
buf, err := ioutil.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read the response body: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return buf, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ex *exchanger) exchange(ctx context.Context, msg []byte) ([]byte, error) {
|
||||||
|
if ex.options.timeout > 0 {
|
||||||
|
var cancel context.CancelFunc
|
||||||
|
ctx, cancel = context.WithTimeout(ctx, ex.options.timeout)
|
||||||
|
defer cancel()
|
||||||
|
}
|
||||||
|
|
||||||
|
c, err := ex.dial(ctx, ex.network, ex.addr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer c.Close()
|
||||||
|
|
||||||
|
if ex.options.tlsConfig != nil {
|
||||||
|
c = tls.Client(c, ex.options.tlsConfig)
|
||||||
|
}
|
||||||
|
|
||||||
|
conn := &dns.Conn{Conn: c}
|
||||||
|
|
||||||
|
if _, err = conn.Write(msg); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
mr, err := conn.ReadMsg()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return mr.Pack()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ex *exchanger) dial(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
|
return ex.router.Dial(ctx, network, address)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ex *exchanger) String() string {
|
||||||
|
return ex.rawAddr
|
||||||
|
}
|
13
pkg/resolver/ns.go
Normal file
13
pkg/resolver/ns.go
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
package resolver
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type NameServer struct {
|
||||||
|
Addr string
|
||||||
|
Protocol string
|
||||||
|
Hostname string // for TLS handshake verification
|
||||||
|
Exchanger Exchanger
|
||||||
|
Timeout time.Duration
|
||||||
|
}
|
11
pkg/resolver/resolver.go
Normal file
11
pkg/resolver/resolver.go
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
package resolver
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Resolver interface {
|
||||||
|
// Resolve returns a slice of the host's IPv4 and IPv6 addresses.
|
||||||
|
Resolve(ctx context.Context, host string) ([]net.IP, error)
|
||||||
|
}
|
Reference in New Issue
Block a user