228 lines
4.9 KiB
Go
228 lines
4.9 KiB
Go
package exchanger
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/tls"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/go-gost/core/chain"
|
|
"github.com/go-gost/core/logger"
|
|
xchain "github.com/go-gost/x/chain"
|
|
"github.com/miekg/dns"
|
|
)
|
|
|
|
type Options struct {
|
|
router chain.Router
|
|
tlsConfig *tls.Config
|
|
timeout time.Duration
|
|
logger logger.Logger
|
|
}
|
|
|
|
// Option allows a common way to set Exchanger options.
|
|
type Option func(opts *Options)
|
|
|
|
// RouterOption sets the router for Exchanger.
|
|
func RouterOption(router chain.Router) Option {
|
|
return func(opts *Options) {
|
|
opts.router = router
|
|
}
|
|
}
|
|
|
|
// TLSConfigOption sets the TLS config for Exchanger.
|
|
func TLSConfigOption(cfg *tls.Config) Option {
|
|
return func(opts *Options) {
|
|
opts.tlsConfig = cfg
|
|
}
|
|
}
|
|
|
|
// LoggerOption sets the logger for Exchanger.
|
|
func LoggerOption(logger logger.Logger) Option {
|
|
return func(opts *Options) {
|
|
opts.logger = logger
|
|
}
|
|
}
|
|
|
|
// TimeoutOption sets the timeout for Exchanger.
|
|
func TimeoutOption(timeout time.Duration) Option {
|
|
return func(opts *Options) {
|
|
opts.timeout = timeout
|
|
}
|
|
}
|
|
|
|
// Exchanger performs synchronous DNS queries against a single server.
type Exchanger interface {
	// Exchange sends the query msg to the server and returns the reply.
	// The implementation in this file transmits msg unmodified as a packed
	// DNS message and returns the reply in the same packed form.
	Exchange(ctx context.Context, msg []byte) ([]byte, error)
	// String returns a textual identification of the exchanger; the
	// implementation in this file returns the server address.
	String() string
}
|
|
|
|
type exchanger struct {
|
|
network string
|
|
addr string
|
|
rawAddr string
|
|
router chain.Router
|
|
client *http.Client
|
|
options Options
|
|
}
|
|
|
|
// NewExchanger create an Exchanger.
|
|
// The addr should be URL-like format,
|
|
// e.g. udp://1.1.1.1:53, tls://1.1.1.1:853, https://1.0.0.1/dns-query
|
|
func NewExchanger(addr string, opts ...Option) (Exchanger, error) {
|
|
var options Options
|
|
for _, opt := range opts {
|
|
opt(&options)
|
|
}
|
|
|
|
if !strings.Contains(addr, "://") {
|
|
addr = "udp://" + addr
|
|
}
|
|
u, err := url.Parse(addr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if options.timeout <= 0 {
|
|
options.timeout = 5 * time.Second
|
|
}
|
|
|
|
ex := &exchanger{
|
|
network: u.Scheme,
|
|
addr: u.Host,
|
|
rawAddr: addr,
|
|
router: options.router,
|
|
options: options,
|
|
}
|
|
if _, port, _ := net.SplitHostPort(ex.addr); port == "" {
|
|
ex.addr = net.JoinHostPort(ex.addr, "53")
|
|
}
|
|
if ex.router == nil {
|
|
ex.router = xchain.NewRouter(chain.LoggerRouterOption(options.logger))
|
|
}
|
|
|
|
switch ex.network {
|
|
case "tcp":
|
|
case "dot", "tls":
|
|
if ex.options.tlsConfig == nil {
|
|
ex.options.tlsConfig = &tls.Config{
|
|
InsecureSkipVerify: true,
|
|
}
|
|
}
|
|
ex.network = "tcp"
|
|
case "https":
|
|
ex.addr = addr
|
|
if ex.options.tlsConfig == nil {
|
|
ex.options.tlsConfig = &tls.Config{
|
|
InsecureSkipVerify: true,
|
|
}
|
|
}
|
|
ex.client = &http.Client{
|
|
Timeout: options.timeout,
|
|
Transport: &http.Transport{
|
|
TLSClientConfig: options.tlsConfig,
|
|
ForceAttemptHTTP2: true,
|
|
MaxIdleConns: 100,
|
|
IdleConnTimeout: 90 * time.Second,
|
|
TLSHandshakeTimeout: options.timeout,
|
|
ExpectContinueTimeout: 1 * time.Second,
|
|
DialContext: ex.dial,
|
|
},
|
|
}
|
|
default:
|
|
ex.network = "udp"
|
|
}
|
|
|
|
return ex, nil
|
|
}
|
|
|
|
func (ex *exchanger) Exchange(ctx context.Context, msg []byte) ([]byte, error) {
|
|
if ex.network == "https" {
|
|
return ex.dohExchange(ctx, msg)
|
|
}
|
|
return ex.exchange(ctx, msg)
|
|
}
|
|
|
|
func (ex *exchanger) dohExchange(ctx context.Context, msg []byte) ([]byte, error) {
|
|
req, err := http.NewRequestWithContext(ctx, "POST", ex.addr, bytes.NewBuffer(msg))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create an HTTPS request: %w", err)
|
|
}
|
|
|
|
// req.Header.Add("Content-Type", "application/dns-udpwireformat")
|
|
req.Header.Add("Content-Type", "application/dns-message")
|
|
|
|
client := ex.client
|
|
if client == nil {
|
|
client = http.DefaultClient
|
|
}
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to perform an HTTPS request: %w", err)
|
|
}
|
|
|
|
// Check response status code
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode != http.StatusOK {
|
|
return nil, fmt.Errorf("returned status code %d", resp.StatusCode)
|
|
}
|
|
|
|
// Read wireformat response from the body
|
|
buf, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to read the response body: %w", err)
|
|
}
|
|
|
|
return buf, nil
|
|
}
|
|
|
|
func (ex *exchanger) exchange(ctx context.Context, msg []byte) ([]byte, error) {
|
|
if ex.options.timeout > 0 {
|
|
var cancel context.CancelFunc
|
|
ctx, cancel = context.WithTimeout(ctx, ex.options.timeout)
|
|
defer cancel()
|
|
}
|
|
|
|
c, err := ex.dial(ctx, ex.network, ex.addr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer c.Close()
|
|
|
|
if ex.options.tlsConfig != nil {
|
|
c = tls.Client(c, ex.options.tlsConfig)
|
|
}
|
|
if ex.options.timeout > 0 {
|
|
c.SetDeadline(time.Now().Add(ex.options.timeout))
|
|
}
|
|
|
|
conn := &dns.Conn{
|
|
UDPSize: 1024,
|
|
Conn: c,
|
|
}
|
|
|
|
if _, err = conn.Write(msg); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
mr, err := conn.ReadMsg()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return mr.Pack()
|
|
}
|
|
|
|
func (ex *exchanger) dial(ctx context.Context, network, address string) (net.Conn, error) {
|
|
return ex.router.Dial(ctx, network, address)
|
|
}
|
|
|
|
func (ex *exchanger) String() string {
|
|
return ex.rawAddr
|
|
}
|