// x/resolver/exchanger/exchanger.go
// 2024-07-08 22:38:21 +08:00
//
// 228 lines
// 4.9 KiB
// Go

package exchanger
import (
"bytes"
"context"
"crypto/tls"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strings"
"time"
"github.com/go-gost/core/chain"
"github.com/go-gost/core/logger"
xchain "github.com/go-gost/x/chain"
"github.com/miekg/dns"
)
// Options holds the settings collected from Option functions and consumed
// by NewExchanger.
type Options struct {
	// router dials upstream connections; NewExchanger creates a default
	// router when it is nil.
	router chain.Router
	// tlsConfig, when non-nil, configures TLS for upstream connections.
	tlsConfig *tls.Config
	// timeout bounds a single exchange; NewExchanger replaces non-positive
	// values with 5 seconds.
	timeout time.Duration
	// logger is handed to the default router built by NewExchanger.
	logger logger.Logger
}
// Option allows a common way to set Exchanger options.
// Each Option mutates the Options value that NewExchanger builds before
// creating the Exchanger.
type Option func(opts *Options)
// RouterOption returns an Option that makes the Exchanger dial its upstream
// server through router instead of a default router.
func RouterOption(router chain.Router) Option {
	setRouter := func(o *Options) {
		o.router = router
	}
	return setRouter
}
// TLSConfigOption returns an Option that supplies the TLS configuration the
// Exchanger uses for its TLS connections. The pointer is stored as-is; it is
// not cloned.
func TLSConfigOption(cfg *tls.Config) Option {
	setTLS := func(o *Options) {
		o.tlsConfig = cfg
	}
	return setTLS
}
// LoggerOption returns an Option that sets the logger handed to the default
// router NewExchanger creates when no router is supplied.
func LoggerOption(logger logger.Logger) Option {
	setLogger := func(o *Options) {
		o.logger = logger
	}
	return setLogger
}
// TimeoutOption returns an Option that sets the per-exchange timeout.
// NewExchanger substitutes 5 seconds for zero or negative values.
func TimeoutOption(timeout time.Duration) Option {
	setTimeout := func(o *Options) {
		o.timeout = timeout
	}
	return setTimeout
}
// Exchanger is an interface for DNS synchronous query.
type Exchanger interface {
	// Exchange sends one DNS query message and blocks until the reply
	// arrives or ctx / the configured timeout ends the attempt. In this
	// file's implementation both msg and the result are wire-format DNS
	// messages.
	Exchange(ctx context.Context, msg []byte) ([]byte, error)
	// String identifies the upstream server.
	String() string
}
// exchanger is the Exchanger implementation returned by NewExchanger.
type exchanger struct {
	// network selects the transport: "udp", "tcp" (also used for dot/tls)
	// or "https".
	network string
	// addr is host:port for udp/tcp, or the full request URL for https.
	addr string
	// rawAddr is the address passed to NewExchanger ("udp://" is prepended
	// when it had no scheme); String returns it.
	rawAddr string
	// router dials every upstream connection, including DoH connections.
	router chain.Router
	// client is the HTTP client for https; nil for other networks.
	client *http.Client
	// options holds the resolved settings (defaults already applied).
	options Options
}
// NewExchanger creates an Exchanger for the upstream DNS server at addr.
//
// addr should be in a URL-like format, e.g. udp://1.1.1.1:53,
// tls://1.1.1.1:853, https://1.0.0.1/dns-query. An address without "://" is
// treated as UDP. Recognised schemes:
//   - "tcp": plain DNS over TCP.
//   - "dot", "tls": DNS over TLS; when no TLS config is supplied, one with
//     InsecureSkipVerify set is used.
//   - "https": DNS over HTTPS; the whole addr is used as the request URL.
//   - anything else (including "udp"): plain DNS over UDP.
//
// For non-HTTPS schemes a missing port defaults to 853 for dot/tls and 53
// otherwise. A non-positive timeout defaults to 5 seconds, and a nil router
// is replaced by a default router that uses the configured logger. The only
// error returned is a failure to parse addr.
func NewExchanger(addr string, opts ...Option) (Exchanger, error) {
	var options Options
	for _, opt := range opts {
		opt(&options)
	}

	if !strings.Contains(addr, "://") {
		addr = "udp://" + addr
	}
	u, err := url.Parse(addr)
	if err != nil {
		return nil, err
	}

	if options.timeout <= 0 {
		options.timeout = 5 * time.Second
	}

	ex := &exchanger{
		network: u.Scheme,
		addr:    u.Host,
		rawAddr: addr,
		router:  options.router,
		options: options,
	}

	defaultPort := "53"
	if ex.network == "dot" || ex.network == "tls" {
		// DNS over TLS is served on port 853 (RFC 7858), not 53.
		defaultPort = "853"
	}
	if host, port, splitErr := net.SplitHostPort(ex.addr); splitErr != nil {
		// No port at all. Drop IPv6 brackets first: JoinHostPort adds its
		// own, so "[::1]" would otherwise become "[[::1]]:53".
		host = strings.TrimSuffix(strings.TrimPrefix(ex.addr, "["), "]")
		ex.addr = net.JoinHostPort(host, defaultPort)
	} else if port == "" {
		// Explicitly empty port such as "1.1.1.1:".
		ex.addr = net.JoinHostPort(host, defaultPort)
	}

	if ex.router == nil {
		ex.router = xchain.NewRouter(chain.LoggerRouterOption(options.logger))
	}

	switch ex.network {
	case "tcp":
		// Plain TCP: nothing to adjust.
	case "dot", "tls":
		if ex.options.tlsConfig == nil {
			ex.options.tlsConfig = &tls.Config{
				InsecureSkipVerify: true,
			}
		}
		// TLS is layered on top of a TCP connection in exchange.
		ex.network = "tcp"
	case "https":
		// The request URL is the full address, not host:port.
		ex.addr = addr
		// Only a caller-supplied TLS config is used here. With nil,
		// http.Transport verifies the server certificate against the system
		// roots. Unlike dot/tls, no InsecureSkipVerify default is installed.
		// The earlier code assigned one to ex.options but never passed it to
		// the transport, so this keeps the existing runtime behaviour.
		ex.client = &http.Client{
			Timeout: ex.options.timeout,
			Transport: &http.Transport{
				TLSClientConfig:       ex.options.tlsConfig,
				ForceAttemptHTTP2:     true,
				MaxIdleConns:          100,
				IdleConnTimeout:       90 * time.Second,
				TLSHandshakeTimeout:   ex.options.timeout,
				ExpectContinueTimeout: 1 * time.Second,
				// Route DoH connections through the same router as other transports.
				DialContext: ex.dial,
			},
		}
	default:
		ex.network = "udp"
	}

	return ex, nil
}
// Exchange sends msg to the upstream server and returns its reply. It uses
// DNS over HTTPS when the network is "https", and a direct UDP/TCP(/TLS)
// exchange otherwise.
func (ex *exchanger) Exchange(ctx context.Context, msg []byte) ([]byte, error) {
	switch ex.network {
	case "https":
		return ex.dohExchange(ctx, msg)
	default:
		return ex.exchange(ctx, msg)
	}
}
// dohExchange sends msg as a DNS-over-HTTPS POST request (RFC 8484) to the
// URL in ex.addr and returns the raw response body.
//
// It uses ex.client, or http.DefaultClient when that is nil. An error is
// returned when the request cannot be built or sent, when the server does
// not answer 200 OK, or when the body cannot be read or is larger than a
// DNS message can be.
func (ex *exchanger) dohExchange(ctx context.Context, msg []byte) ([]byte, error) {
	// A DNS message can never exceed 65535 bytes (16-bit length on the wire),
	// so anything larger is a broken or hostile server.
	const maxDNSMessageSize = 65535

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, ex.addr, bytes.NewReader(msg))
	if err != nil {
		return nil, fmt.Errorf("failed to create an HTTPS request: %w", err)
	}
	req.Header.Set("Content-Type", "application/dns-message")
	req.Header.Set("Accept", "application/dns-message")

	client := ex.client
	if client == nil {
		client = http.DefaultClient
	}
	resp, err := client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to perform an HTTPS request: %w", err)
	}
	defer resp.Body.Close()

	// Check response status code.
	if resp.StatusCode != http.StatusOK {
		// Drain a bounded amount so the transport can reuse the connection.
		_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, maxDNSMessageSize))
		return nil, fmt.Errorf("returned status code %d", resp.StatusCode)
	}

	// Read the wire-format response, one byte past the limit so an oversized
	// body is detected instead of silently truncated.
	buf, err := io.ReadAll(io.LimitReader(resp.Body, maxDNSMessageSize+1))
	if err != nil {
		return nil, fmt.Errorf("failed to read the response body: %w", err)
	}
	if len(buf) > maxDNSMessageSize {
		return nil, fmt.Errorf("response body exceeds %d bytes", maxDNSMessageSize)
	}
	return buf, nil
}
// exchange performs one DNS query over a connection dialed through the
// router: plain UDP/TCP, wrapped in TLS when a TLS config is set.
//
// msg is written through a dns.Conn. The reply is parsed with ReadMsg and
// re-encoded with Pack, so the returned bytes are the re-packed reply.
// Dial, TLS handshake, write and read all share one deadline: the earlier of
// the caller's ctx deadline and the configured timeout.
func (ex *exchanger) exchange(ctx context.Context, msg []byte) ([]byte, error) {
	if ex.options.timeout > 0 {
		var cancel context.CancelFunc
		ctx, cancel = context.WithTimeout(ctx, ex.options.timeout)
		defer cancel()
	}

	c, err := ex.dial(ctx, ex.network, ex.addr)
	if err != nil {
		return nil, err
	}
	// Bound to the raw connection (c is reassigned below); closing it also
	// ends any TLS session layered on top.
	defer c.Close()

	if ex.options.tlsConfig != nil {
		c = tls.Client(c, ex.options.tlsConfig)
	}

	// Apply the context deadline to the connection, so an earlier caller
	// deadline is honoured too. Any SetDeadline error is ignored: the
	// deadline is a best-effort bound, as it was before.
	if deadline, ok := ctx.Deadline(); ok {
		_ = c.SetDeadline(deadline)
	}

	conn := &dns.Conn{
		// UDP read buffer size. A 1024-byte buffer can truncate larger UDP
		// replies (e.g. EDNS0 answers up to 4096 bytes), making ReadMsg fail.
		UDPSize: dns.DefaultMsgSize,
		Conn:    c,
	}
	if _, err = conn.Write(msg); err != nil {
		return nil, err
	}
	reply, err := conn.ReadMsg()
	if err != nil {
		return nil, err
	}
	return reply.Pack()
}
// dial opens a connection to address over network using ex.router. It is
// also installed as the DoH transport's DialContext, so HTTPS traffic goes
// through the same router.
func (ex *exchanger) dial(ctx context.Context, network, address string) (net.Conn, error) {
	r := ex.router
	conn, err := r.Dial(ctx, network, address)
	return conn, err
}
// String returns the upstream address as given to NewExchanger, including
// the "udp://" prefix added when no scheme was present.
func (ex *exchanger) String() string {
	raw := ex.rawAddr
	return raw
}