add dns handler

This commit is contained in:
ginuerzh 2021-12-29 23:45:58 +08:00
parent 8600ee7c5d
commit 9b3d7e1110
15 changed files with 541 additions and 39 deletions

View File

@ -32,6 +32,7 @@ import (
// Register handlers
_ "github.com/go-gost/gost/pkg/handler/auto"
_ "github.com/go-gost/gost/pkg/handler/dns"
_ "github.com/go-gost/gost/pkg/handler/forward/local"
_ "github.com/go-gost/gost/pkg/handler/forward/remote"
_ "github.com/go-gost/gost/pkg/handler/forward/ssh"

2
go.mod
View File

@ -26,7 +26,7 @@ require (
github.com/magiconair/properties v1.8.5 // indirect
github.com/marten-seemann/qtls-go1-16 v0.1.4 // indirect
github.com/marten-seemann/qtls-go1-17 v0.1.0 // indirect
github.com/miekg/dns v1.1.44
github.com/miekg/dns v1.1.45
github.com/milosgajdos/tenus v0.0.3
github.com/mitchellh/mapstructure v1.4.2 // indirect
github.com/mmcloughlin/avo v0.0.0-20200803215136-443f81d77104 // indirect

2
go.sum
View File

@ -285,6 +285,8 @@ github.com/miekg/dns v1.1.26 h1:gPxPSwALAeHJSjarOs00QjVdV9QoBvc1D2ujQUr5BzU=
github.com/miekg/dns v1.1.26/go.mod h1:bPDLeHnStXmXAq1m/Ch/hvfNHr14JKNPMBo3VZKjuso=
github.com/miekg/dns v1.1.44 h1:4rpqcegYPVkvIeVhITrKP1sRR3KjfRc1nrOPMUZmLyc=
github.com/miekg/dns v1.1.44/go.mod h1:e3IlAVfNqAllflbibAZEWOXOQ+Ynzk/dDozDxY7XnME=
github.com/miekg/dns v1.1.45 h1:g5fRIhm9nx7g8osrAvgb16QJfmyMsyOCb+J7LSv+Qzk=
github.com/miekg/dns v1.1.45/go.mod h1:e3IlAVfNqAllflbibAZEWOXOQ+Ynzk/dDozDxY7XnME=
github.com/milosgajdos/tenus v0.0.3 h1:jmaJzwaY1DUyYVD0lM4U+uvP2kkEg1VahDqRFxIkVBE=
github.com/milosgajdos/tenus v0.0.3/go.mod h1:eIjx29vNeDOYWJuCnaHY2r4fq5egetV26ry3on7p8qY=
github.com/mitchellh/cli v1.1.0/go.mod h1:xcISNoH86gajksDmfB23e/pu+B+GeFRMYmoHXxx3xhI=

View File

@ -12,6 +12,19 @@ profiling:
# key: "key.pem"
# ca: "root.ca"
resolvers:
- name: resolver-0
ttl: 60s
prefer: ipv4
clientIP: 1.2.3.4
nameServers:
- addr: udp://8.8.8.8:53
timeout: 5s
- addr: tcp://1.1.1.1:53
- addr: tls://1.1.1.1:853
- addr: https://1.0.0.1/dns-query
domain: cloudflare-dns.com
services:
- name: http+tcp
url: "http://gost:gost@:8000"

View File

@ -32,6 +32,19 @@ func (r *Router) WithLogger(logger logger.Logger) *Router {
}
func (r *Router) Dial(ctx context.Context, network, address string) (conn net.Conn, err error) {
conn, err = r.dial(ctx, network, address)
if err != nil {
return
}
if network == "udp" || network == "udp4" || network == "udp6" {
if _, ok := conn.(net.PacketConn); !ok {
return &packetConn{conn}, nil
}
}
return
}
func (r *Router) dial(ctx context.Context, network, address string) (conn net.Conn, err error) {
count := r.retries + 1
if count <= 0 {
count = 1
@ -88,3 +101,17 @@ func (r *Router) Bind(ctx context.Context, network, address string, opts ...conn
return
}
type packetConn struct {
net.Conn
}
func (c *packetConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
n, err = c.Read(b)
addr = c.Conn.RemoteAddr()
return
}
func (c *packetConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
return c.Write(b)
}

199
pkg/handler/dns/handler.go Normal file
View File

@ -0,0 +1,199 @@
package dns
import (
"bytes"
"context"
"errors"
"net"
"strconv"
"time"
"github.com/go-gost/gost/pkg/bypass"
"github.com/go-gost/gost/pkg/chain"
"github.com/go-gost/gost/pkg/common/bufpool"
"github.com/go-gost/gost/pkg/handler"
"github.com/go-gost/gost/pkg/logger"
md "github.com/go-gost/gost/pkg/metadata"
"github.com/go-gost/gost/pkg/registry"
"github.com/go-gost/gost/pkg/resolver/exchanger"
"github.com/miekg/dns"
)
func init() {
registry.RegisterHandler("dns", NewHandler)
}
type dnsHandler struct {
chain *chain.Chain
bypass bypass.Bypass
exchangers []exchanger.Exchanger
logger logger.Logger
md metadata
}
func NewHandler(opts ...handler.Option) handler.Handler {
options := &handler.Options{}
for _, opt := range opts {
opt(options)
}
return &dnsHandler{
bypass: options.Bypass,
logger: options.Logger,
}
}
func (h *dnsHandler) Init(md md.Metadata) (err error) {
if err = h.parseMetadata(md); err != nil {
return
}
for _, server := range h.md.servers {
ex, err := exchanger.NewExchanger(
server,
exchanger.ChainOption(h.chain),
exchanger.LoggerOption(h.logger),
)
if err != nil {
h.logger.Warnf("parse %s: %v", server, err)
continue
}
h.exchangers = append(h.exchangers, ex)
}
if len(h.exchangers) == 0 {
ex, _ := exchanger.NewExchanger(
"udp://127.0.0.53:53",
exchanger.ChainOption(h.chain),
exchanger.LoggerOption(h.logger),
)
if ex != nil {
h.exchangers = append(h.exchangers, ex)
}
}
return
}
// implements chain.Chainable interface
func (h *dnsHandler) WithChain(chain *chain.Chain) {
h.chain = chain
}
func (h *dnsHandler) Handle(ctx context.Context, conn net.Conn) {
defer conn.Close()
start := time.Now()
h.logger = h.logger.WithFields(map[string]interface{}{
"remote": conn.RemoteAddr().String(),
"local": conn.LocalAddr().String(),
})
h.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
defer func() {
h.logger.WithFields(map[string]interface{}{
"duration": time.Since(start),
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
}()
b := bufpool.Get(4096)
defer bufpool.Put(b)
n, err := conn.Read(b)
if err != nil {
h.logger.Error(err)
return
}
h.logger.Info("read data: ", n)
reply, err := h.exchange(ctx, b[:n])
if err != nil {
h.logger.Error(err)
return
}
if _, err = conn.Write(reply); err != nil {
h.logger.Error(err)
}
}
func (h *dnsHandler) exchange(ctx context.Context, msg []byte) ([]byte, error) {
mq := dns.Msg{}
if err := mq.Unpack(msg); err != nil {
return nil, err
}
if len(mq.Question) == 0 {
return nil, errors.New("msg: empty question")
}
if h.logger.IsLevelEnabled(logger.DebugLevel) {
h.logger.Debug(mq.String())
} else {
h.logger.Info(h.dumpMsgHeader(&mq))
}
var mr *dns.Msg
// Only cache for single question.
/*
if len(mq.Question) == 1 {
key := newResolverCacheKey(&mq.Question[0])
mr = r.cache.loadCache(key)
if mr != nil {
log.Logf("[dns] exchange message %d (cached): %s", mq.Id, mq.Question[0].String())
mr.Id = mq.Id
return mr.Pack()
}
defer func() {
if mr != nil {
r.cache.storeCache(key, mr, r.TTL())
}
}()
}
*/
// r.addSubnetOpt(mq)
query, err := mq.Pack()
if err != nil {
h.logger.Error(err)
return nil, err
}
var reply []byte
for _, ex := range h.exchangers {
h.logger.Infof("exchange message %d via %s: %s", mq.Id, ex.String(), mq.Question[0].String())
reply, err = ex.Exchange(ctx, query)
if err == nil {
break
}
h.logger.Error(err)
}
if err != nil {
h.logger.Error(err)
return nil, err
}
mr = &dns.Msg{}
if err = mr.Unpack(reply); err != nil {
h.logger.Error(err)
return nil, err
}
if h.logger.IsLevelEnabled(logger.DebugLevel) {
h.logger.Debug(mr.String())
} else {
h.logger.Info(h.dumpMsgHeader(mr))
}
return reply, nil
}
func (h *dnsHandler) dumpMsgHeader(m *dns.Msg) string {
buf := new(bytes.Buffer)
buf.WriteString(m.MsgHdr.String() + " ")
buf.WriteString("QUERY: " + strconv.Itoa(len(m.Question)) + ", ")
buf.WriteString("ANSWER: " + strconv.Itoa(len(m.Answer)) + ", ")
buf.WriteString("AUTHORITY: " + strconv.Itoa(len(m.Ns)) + ", ")
buf.WriteString("ADDITIONAL: " + strconv.Itoa(len(m.Extra)))
return buf.String()
}

View File

@ -0,0 +1,43 @@
package dns
import (
"time"
mdata "github.com/go-gost/gost/pkg/metadata"
)
type metadata struct {
readTimeout time.Duration
retryCount int
ttl time.Duration
timeout time.Duration
prefer string
clientIP string
// nameservers
servers []string
}
func (h *dnsHandler) parseMetadata(md mdata.Metadata) (err error) {
const (
readTimeout = "readTimeout"
retryCount = "retry"
ttl = "ttl"
timeout = "timeout"
prefer = "prefer"
clientIP = "clientIP"
servers = "servers"
)
h.md.readTimeout = mdata.GetDuration(md, readTimeout)
h.md.retryCount = mdata.GetInt(md, retryCount)
h.md.ttl = mdata.GetDuration(md, ttl)
h.md.timeout = mdata.GetDuration(md, timeout)
if h.md.timeout <= 0 {
h.md.timeout = 5 * time.Second
}
h.md.prefer = mdata.GetString(md, prefer)
h.md.clientIP = mdata.GetString(md, clientIP)
h.md.servers = mdata.GetStrings(md, servers)
return
}

View File

@ -1,17 +0,0 @@
package tap
import "net"
type packetConn struct {
net.Conn
}
func (c *packetConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
n, err = c.Read(b)
addr = c.Conn.RemoteAddr()
return
}
func (c *packetConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
return c.Write(b)
}

View File

@ -2,6 +2,7 @@ package tap
import (
"context"
"errors"
"fmt"
"io"
"net"
@ -122,7 +123,12 @@ func (h *tapHandler) handleLoop(ctx context.Context, conn net.Conn, addr net.Add
if err != nil {
return err
}
pc = &packetConn{cc}
var ok bool
pc, ok = cc.(net.PacketConn)
if !ok {
return errors.New("invalid connection")
}
} else {
if h.md.tcpMode {
if addr != nil {

View File

@ -1,17 +0,0 @@
package tun
import "net"
type packetConn struct {
net.Conn
}
func (c *packetConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
n, err = c.Read(b)
addr = c.Conn.RemoteAddr()
return
}
func (c *packetConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
return c.Write(b)
}

View File

@ -2,6 +2,7 @@ package tun
import (
"context"
"errors"
"fmt"
"io"
"net"
@ -124,7 +125,12 @@ func (h *tunHandler) handleLoop(ctx context.Context, conn net.Conn, addr net.Add
if err != nil {
return err
}
pc = &packetConn{cc}
var ok bool
pc, ok = cc.(net.PacketConn)
if !ok {
return errors.New("invalid connnection")
}
} else {
if h.md.tcpMode {
if addr != nil {

View File

@ -60,8 +60,9 @@ func (c *serverConn) Read(b []byte) (n int, err error) {
case <-c.closed:
err = io.ErrClosedPipe
return
default:
return c.r.Read(b)
}
return c.r.Read(b)
}
func (c *serverConn) Write(b []byte) (n int, err error) {
@ -69,8 +70,9 @@ func (c *serverConn) Write(b []byte) (n int, err error) {
case <-c.closed:
err = io.ErrClosedPipe
return
default:
return c.w.Write(b)
}
return c.w.Write(b)
}
func (c *serverConn) Close() error {

View File

@ -0,0 +1,213 @@
package exchanger
import (
"bytes"
"context"
"crypto/tls"
"fmt"
"io/ioutil"
"net"
"net/http"
"net/url"
"strings"
"time"
"github.com/go-gost/gost/pkg/chain"
"github.com/go-gost/gost/pkg/logger"
"github.com/miekg/dns"
)
type Options struct {
chain *chain.Chain
tlsConfig *tls.Config
timeout time.Duration
logger logger.Logger
}
// Option allows a common way to set Exchanger options.
type Option func(opts *Options)
// ChainOption sets the chain for Exchanger.
func ChainOption(chain *chain.Chain) Option {
return func(opts *Options) {
opts.chain = chain
}
}
// TLSConfigOption sets the TLS config for Exchanger.
func TLSConfigOption(cfg *tls.Config) Option {
return func(opts *Options) {
opts.tlsConfig = cfg
}
}
// LoggerOption sets the logger for Exchanger.
func LoggerOption(logger logger.Logger) Option {
return func(opts *Options) {
opts.logger = logger
}
}
// TimeoutOption sets the timeout for Exchanger.
func TimeoutOption(timeout time.Duration) Option {
return func(opts *Options) {
opts.timeout = timeout
}
}
// Exchanger is an interface for DNS synchronous query.
type Exchanger interface {
Exchange(ctx context.Context, msg []byte) ([]byte, error)
String() string
}
type exchanger struct {
network string
addr string
rawAddr string
router *chain.Router
client *http.Client
options Options
}
// NewExchanger create an Exchanger.
func NewExchanger(addr string, opts ...Option) (Exchanger, error) {
var options Options
for _, opt := range opts {
opt(&options)
}
if !strings.Contains(addr, "://") {
addr = "udp://" + addr
}
u, err := url.Parse(addr)
if err != nil {
return nil, err
}
ex := &exchanger{
network: u.Scheme,
addr: u.Host,
rawAddr: addr,
options: options,
}
ex.router = (&chain.Router{}).
WithChain(options.chain).
WithLogger(options.logger)
if _, port, _ := net.SplitHostPort(ex.addr); port == "" {
ex.addr = net.JoinHostPort(ex.addr, "53")
}
switch ex.network {
case "tcp":
case "dot", "tls":
if ex.options.tlsConfig == nil {
ex.options.tlsConfig = &tls.Config{
InsecureSkipVerify: true,
}
}
ex.network = "tcp"
case "doh":
ex.addr = addr
if ex.options.tlsConfig == nil {
ex.options.tlsConfig = &tls.Config{
InsecureSkipVerify: true,
}
}
ex.client = &http.Client{
Timeout: options.timeout,
Transport: &http.Transport{
TLSClientConfig: options.tlsConfig,
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: options.timeout,
ExpectContinueTimeout: 1 * time.Second,
DialContext: ex.dial,
},
}
default:
ex.network = "udp"
}
return ex, nil
}
func (ex *exchanger) Exchange(ctx context.Context, msg []byte) ([]byte, error) {
if ex.network == "doh" {
return ex.dohExchange(ctx, msg)
}
return ex.exchange(ctx, msg)
}
func (ex *exchanger) dohExchange(ctx context.Context, msg []byte) ([]byte, error) {
req, err := http.NewRequestWithContext(ctx, "POST", ex.addr, bytes.NewBuffer(msg))
if err != nil {
return nil, fmt.Errorf("failed to create an HTTPS request: %w", err)
}
// req.Header.Add("Content-Type", "application/dns-udpwireformat")
req.Header.Add("Content-Type", "application/dns-message")
client := ex.client
if client == nil {
client = http.DefaultClient
}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to perform an HTTPS request: %w", err)
}
// Check response status code
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("returned status code %d", resp.StatusCode)
}
// Read wireformat response from the body
buf, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read the response body: %w", err)
}
return buf, nil
}
func (ex *exchanger) exchange(ctx context.Context, msg []byte) ([]byte, error) {
if ex.options.timeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, ex.options.timeout)
defer cancel()
}
c, err := ex.dial(ctx, ex.network, ex.addr)
if err != nil {
return nil, err
}
defer c.Close()
if ex.options.tlsConfig != nil {
c = tls.Client(c, ex.options.tlsConfig)
}
conn := &dns.Conn{Conn: c}
if _, err = conn.Write(msg); err != nil {
return nil, err
}
mr, err := conn.ReadMsg()
if err != nil {
return nil, err
}
return mr.Pack()
}
func (ex *exchanger) dial(ctx context.Context, network, address string) (net.Conn, error) {
return ex.router.Dial(ctx, network, address)
}
func (ex *exchanger) String() string {
return ex.rawAddr
}

13
pkg/resolver/ns.go Normal file
View File

@ -0,0 +1,13 @@
package resolver
import (
"time"
)
type NameServer struct {
Addr string
Protocol string
Hostname string // for TLS handshake verification
Exchanger Exchanger
Timeout time.Duration
}

11
pkg/resolver/resolver.go Normal file
View File

@ -0,0 +1,11 @@
package resolver
import (
"context"
"net"
)
type Resolver interface {
// Resolve returns a slice of the host's IPv4 and IPv6 addresses.
Resolve(ctx context.Context, host string) ([]net.IP, error)
}