diff --git a/pkg/chain/route.go b/pkg/chain/route.go index 5df063b..2591747 100644 --- a/pkg/chain/route.go +++ b/pkg/chain/route.go @@ -7,9 +7,9 @@ import ( "net" "time" + "github.com/go-gost/gost/pkg/common/net/dialer" "github.com/go-gost/gost/pkg/common/util/udp" "github.com/go-gost/gost/pkg/connector" - "github.com/go-gost/gost/pkg/dialer" "github.com/go-gost/gost/pkg/logger" ) diff --git a/pkg/chain/transport.go b/pkg/chain/transport.go index 2734b7f..755922c 100644 --- a/pkg/chain/transport.go +++ b/pkg/chain/transport.go @@ -5,6 +5,7 @@ import ( "net" "time" + net_dialer "github.com/go-gost/gost/pkg/common/net/dialer" "github.com/go-gost/gost/pkg/connector" "github.com/go-gost/gost/pkg/dialer" ) @@ -39,7 +40,7 @@ func (tr *Transport) WithConnector(connector connector.Connector) *Transport { } func (tr *Transport) Dial(ctx context.Context, addr string) (net.Conn, error) { - netd := &dialer.NetDialer{ + netd := &net_dialer.NetDialer{ Interface: tr.ifceName, Timeout: 30 * time.Second, } diff --git a/pkg/common/net/dialer/dialer.go b/pkg/common/net/dialer/dialer.go new file mode 100644 index 0000000..be61bed --- /dev/null +++ b/pkg/common/net/dialer/dialer.go @@ -0,0 +1,144 @@ +package dialer + +import ( + "context" + "fmt" + "net" + "syscall" + "time" + + "github.com/go-gost/gost/pkg/logger" +) + +var ( + DefaultNetDialer = &NetDialer{ + Timeout: 30 * time.Second, + } +) + +type NetDialer struct { + Interface string + Timeout time.Duration + DialFunc func(ctx context.Context, network, addr string) (net.Conn, error) + Logger logger.Logger +} + +func (d *NetDialer) Dial(ctx context.Context, network, addr string) (net.Conn, error) { + if d == nil { + d = DefaultNetDialer + } + log := d.Logger + if log == nil { + log = logger.Default() + } + + ifceName, ifAddr, err := parseInterfaceAddr(d.Interface, network) + if err != nil { + return nil, err + } + if d.DialFunc != nil { + return d.DialFunc(ctx, network, addr) + } + logger.Default().Infof("interface: %s %s/%s", ifceName, ifAddr, network) + + switch network { + case "udp", "udp4", "udp6": + if addr == "" { + var laddr *net.UDPAddr + if ifAddr != nil { + laddr, _ = ifAddr.(*net.UDPAddr) + } + + return net.ListenUDP(network, laddr) + } + case "tcp", "tcp4", "tcp6": + default: + return nil, fmt.Errorf("dial: unsupported network %s", network) + } + netd := net.Dialer{ + Timeout: d.Timeout, + LocalAddr: ifAddr, + Control: func(network, address string, c syscall.RawConn) error { + var cerr error + err := c.Control(func(fd uintptr) { + cerr = bindDevice(fd, ifceName) + }) + if err != nil { + return err + } + if cerr != nil { + return cerr + } + return nil + }, + } + return netd.DialContext(ctx, network, addr) +} + +func parseInterfaceAddr(ifceName, network string) (ifce string, addr net.Addr, err error) { + if ifceName == "" { + return + } + + ip := net.ParseIP(ifceName) + if ip == nil { + var ife *net.Interface + ife, err = net.InterfaceByName(ifceName) + if err != nil { + return + } + var addrs []net.Addr + addrs, err = ife.Addrs() + if err != nil { + return + } + if len(addrs) == 0 { + err = fmt.Errorf("addr not found for interface %s", ifceName) + return + } + ip = addrs[0].(*net.IPNet).IP + ifce = ifceName + } else { + ifce, err = findInterfaceByIP(ip) + if err != nil { + return + } + } + + port := 0 + switch network { + case "tcp", "tcp4", "tcp6": + addr = &net.TCPAddr{IP: ip, Port: port} + return + case "udp", "udp4", "udp6": + addr = &net.UDPAddr{IP: ip, Port: port} + return + default: + addr = &net.IPAddr{IP: ip} + return + } +} + +func findInterfaceByIP(ip net.IP) (string, error) { + ifces, err := net.Interfaces() + if err != nil { + return "", err + } + for _, ifce := range ifces { + addrs, _ := ifce.Addrs() + if len(addrs) == 0 { + continue + } + for _, addr := range addrs { + ipAddr, _ := addr.(*net.IPNet) + if ipAddr == nil { + continue + } + // logger.Default().Infof("%s-%s", ipAddr, ip) + if ipAddr.IP.Equal(ip) { + return ifce.Name, nil + } + } + } + return "", nil +} diff --git a/pkg/common/net/dialer/dialer_linux.go b/pkg/common/net/dialer/dialer_linux.go new file mode 100644 index 0000000..cf7f0aa --- /dev/null +++ b/pkg/common/net/dialer/dialer_linux.go @@ -0,0 +1,14 @@ +package dialer + +import ( + "golang.org/x/sys/unix" +) + +func bindDevice(fd uintptr, ifceName string) error { + // unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEADDR, 1) + // unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1) + if ifceName == "" { + return nil + } + return unix.BindToDevice(int(fd), ifceName) +} diff --git a/pkg/common/net/dialer/dialer_other.go b/pkg/common/net/dialer/dialer_other.go new file mode 100644 index 0000000..6ad6b21 --- /dev/null +++ b/pkg/common/net/dialer/dialer_other.go @@ -0,0 +1,7 @@ +//go:build !linux + +package dialer + +func bindDevice(fd uintptr, ifceName string) error { + return nil +} diff --git a/pkg/dialer/grpc/dialer.go b/pkg/dialer/grpc/dialer.go index f1871cd..d93b5bb 100644 --- a/pkg/dialer/grpc/dialer.go +++ b/pkg/dialer/grpc/dialer.go @@ -75,11 +75,7 @@ func (d *grpcDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialO grpcOpts := []grpc.DialOption{ // grpc.WithBlock(), grpc.WithContextDialer(func(c context.Context, s string) (net.Conn, error) { - netd := options.NetDialer - if netd == nil { - netd = dialer.DefaultNetDialer - } - return netd.Dial(c, "tcp", s) + return options.NetDialer.Dial(c, "tcp", s) }), grpc.WithAuthority(host), grpc.WithConnectParams(grpc.ConnectParams{ diff --git a/pkg/dialer/http2/dialer.go b/pkg/dialer/http2/dialer.go index 735012a..d89b308 100644 --- a/pkg/dialer/http2/dialer.go +++ b/pkg/dialer/http2/dialer.go @@ -7,6 +7,7 @@ import ( "sync" "time" + net_dialer "github.com/go-gost/gost/pkg/common/net/dialer" "github.com/go-gost/gost/pkg/dialer" http2_util "github.com/go-gost/gost/pkg/internal/util/http2" "github.com/go-gost/gost/pkg/logger" @@ -75,7 +76,7 @@ func (d *http2Dialer) Dial(ctx context.Context, address string, opts ...dialer.D DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { netd := options.NetDialer if netd == nil { - netd = dialer.DefaultNetDialer + netd = net_dialer.DefaultNetDialer } return netd.Dial(ctx, network, addr) }, diff --git a/pkg/dialer/http2/h2/dialer.go b/pkg/dialer/http2/h2/dialer.go index c6ba8b8..c1cff6e 100644 --- a/pkg/dialer/http2/h2/dialer.go +++ b/pkg/dialer/http2/h2/dialer.go @@ -93,22 +93,14 @@ func (d *h2Dialer) Dial(ctx context.Context, address string, opts ...dialer.Dial if d.h2c { client.Transport = &http2.Transport{ DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { - netd := options.NetDialer - if netd == nil { - netd = dialer.DefaultNetDialer - } - return netd.Dial(ctx, network, addr) + return options.NetDialer.Dial(ctx, network, addr) }, } } else { client.Transport = &http.Transport{ TLSClientConfig: d.options.TLSConfig, DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - netd := options.NetDialer - if netd == nil { - netd = dialer.DefaultNetDialer - } - return netd.Dial(ctx, network, addr) + return options.NetDialer.Dial(ctx, network, addr) }, ForceAttemptHTTP2: true, MaxIdleConns: 100, diff --git a/pkg/dialer/http3/dialer.go b/pkg/dialer/http3/dialer.go index dbd0a0a..9149862 100644 --- a/pkg/dialer/http3/dialer.go +++ b/pkg/dialer/http3/dialer.go @@ -2,14 +2,16 @@ package http3 import ( "context" + "crypto/tls" "net" "net/http" + "sync" "github.com/go-gost/gost/pkg/dialer" pht_util "github.com/go-gost/gost/pkg/internal/util/pht" - "github.com/go-gost/gost/pkg/logger" md "github.com/go-gost/gost/pkg/metadata" "github.com/go-gost/gost/pkg/registry" + "github.com/lucas-clemente/quic-go" "github.com/lucas-clemente/quic-go/http3" ) @@ -19,10 +21,10 @@ func init() { } type http3Dialer struct { - client *pht_util.Client - md metadata - logger logger.Logger - options dialer.Options + clients map[string]*pht_util.Client + clientMutex sync.Mutex + md metadata + options dialer.Options } func NewDialer(opts ...dialer.Option) dialer.Dialer { @@ -32,7 +34,7 @@ func NewDialer(opts ...dialer.Option) dialer.Dialer { } return &http3Dialer{ - logger: options.Logger, + clients: make(map[string]*pht_util.Client), options: options, } } @@ -42,23 +44,64 @@ func (d *http3Dialer) Init(md md.Metadata) (err error) { return } - tr := &http3.RoundTripper{ - TLSClientConfig: d.options.TLSConfig, - } - d.client = &pht_util.Client{ - Client: &http.Client{ - // Timeout: 60 * time.Second, - Transport: tr, - }, - AuthorizePath: d.md.authorizePath, - PushPath: d.md.pushPath, - PullPath: d.md.pullPath, - TLSEnabled: true, - Logger: d.options.Logger, - } return nil } func (d *http3Dialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOption) (net.Conn, error) { - return d.client.Dial(ctx, addr) + d.clientMutex.Lock() + defer d.clientMutex.Unlock() + + client, ok := d.clients[addr] + if !ok { + var options dialer.DialOptions + for _, opt := range opts { + opt(&options) + } + + host := d.md.host + if host == "" { + host = options.Host + } + if h, _, _ := net.SplitHostPort(host); h != "" { + host = h + } + + client = &pht_util.Client{ + Host: host, + Client: &http.Client{ + // Timeout: 60 * time.Second, + Transport: &http3.RoundTripper{ + TLSClientConfig: d.options.TLSConfig, + Dial: func(network, adr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlySession, error) { + // d.options.Logger.Infof("dial: %s/%s, %s", addr, network, host) + udpAddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return nil, err + } + + udpConn, err := options.NetDialer.Dial(context.Background(), "udp", "") + if err != nil { + return nil, err + } + + return quic.DialEarly(udpConn.(net.PacketConn), udpAddr, host, tlsCfg, cfg) + }, + }, + }, + AuthorizePath: d.md.authorizePath, + PushPath: d.md.pushPath, + PullPath: d.md.pullPath, + TLSEnabled: true, + Logger: d.options.Logger, + } + + d.clients[addr] = client + } + + return client.Dial(ctx, addr) +} + +// Multiplex implements dialer.Multiplexer interface. +func (d *http3Dialer) Multiplex() bool { + return true } diff --git a/pkg/dialer/http3/metadata.go b/pkg/dialer/http3/metadata.go index 615ade8..04f034a 100644 --- a/pkg/dialer/http3/metadata.go +++ b/pkg/dialer/http3/metadata.go @@ -23,6 +23,7 @@ type metadata struct { authorizePath string pushPath string pullPath string + host string } func (d *http3Dialer) parseMetadata(md mdata.Metadata) (err error) { @@ -30,6 +31,7 @@ func (d *http3Dialer) parseMetadata(md mdata.Metadata) (err error) { authorizePath = "authorizePath" pushPath = "pushPath" pullPath = "pullPath" + host = "host" ) d.md.authorizePath = mdata.GetString(md, authorizePath) @@ -44,5 +46,7 @@ func (d *http3Dialer) parseMetadata(md mdata.Metadata) (err error) { if !strings.HasPrefix(d.md.pullPath, "/") { d.md.pullPath = defaultPullPath } + + d.md.host = mdata.GetString(md, host) return } diff --git a/pkg/dialer/kcp/dialer.go b/pkg/dialer/kcp/dialer.go index 1667ab5..d53025e 100644 --- a/pkg/dialer/kcp/dialer.go +++ b/pkg/dialer/kcp/dialer.go @@ -85,11 +85,7 @@ func (d *kcpDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOp PacketConn: pc, } } else { - netd := options.NetDialer - if netd == nil { - netd = dialer.DefaultNetDialer - } - conn, err = netd.Dial(ctx, "udp", addr) + conn, err = options.NetDialer.Dial(ctx, "udp", addr) if err != nil { return nil, err } diff --git a/pkg/dialer/net.go b/pkg/dialer/net.go deleted file mode 100644 index 482a8ea..0000000 --- a/pkg/dialer/net.go +++ /dev/null @@ -1,84 +0,0 @@ -package dialer - -import ( - "context" - "fmt" - "net" - "time" - - "github.com/go-gost/gost/pkg/logger" -) - -var ( - DefaultNetDialer = &NetDialer{ - Timeout: 30 * time.Second, - } -) - -type NetDialer struct { - Interface string - Timeout time.Duration - DialFunc func(ctx context.Context, network, addr string) (net.Conn, error) -} - -func (d *NetDialer) Dial(ctx context.Context, network, addr string) (net.Conn, error) { - ifAddr, err := parseInterfaceAddr(d.Interface, network) - if err != nil { - return nil, err - } - if d.DialFunc != nil { - return d.DialFunc(ctx, network, addr) - } - logger.Default().Infof("interface: %s %s %v", d.Interface, network, ifAddr) - - switch network { - case "udp", "udp4", "udp6": - if addr == "" { - var laddr *net.UDPAddr - if ifAddr != nil { - laddr, _ = ifAddr.(*net.UDPAddr) - } - - return net.ListenUDP(network, laddr) - } - case "tcp", "tcp4", "tcp6": - default: - return nil, fmt.Errorf("dial: unsupported network %s", network) - } - netd := net.Dialer{ - Timeout: d.Timeout, - LocalAddr: ifAddr, - } - return netd.DialContext(ctx, network, addr) -} - -func parseInterfaceAddr(ifceName, network string) (net.Addr, error) { - if ifceName == "" { - return nil, nil - } - - ip := net.ParseIP(ifceName) - if ip == nil { - ifce, err := net.InterfaceByName(ifceName) - if err != nil { - return nil, err - } - addrs, err := ifce.Addrs() - if err != nil { - return nil, err - } - if len(addrs) == 0 { - return nil, fmt.Errorf("addr not found for interface %s", ifceName) - } - ip = addrs[0].(*net.IPNet).IP - } - - switch network { - case "tcp", "tcp4", "tcp6": - return &net.TCPAddr{IP: ip}, nil - case "udp", "udp4", "udp6": - return &net.UDPAddr{IP: ip}, nil - default: - return &net.IPAddr{IP: ip}, nil - } -} diff --git a/pkg/dialer/obfs/http/dialer.go b/pkg/dialer/obfs/http/dialer.go index d2f4550..dd6fa78 100644 --- a/pkg/dialer/obfs/http/dialer.go +++ b/pkg/dialer/obfs/http/dialer.go @@ -40,11 +40,7 @@ func (d *obfsHTTPDialer) Dial(ctx context.Context, addr string, opts ...dialer.D opt(options) } - netd := options.NetDialer - if netd == nil { - netd = dialer.DefaultNetDialer - } - conn, err := netd.Dial(ctx, "tcp", addr) + conn, err := options.NetDialer.Dial(ctx, "tcp", addr) if err != nil { d.logger.Error(err) } diff --git a/pkg/dialer/obfs/tls/dialer.go b/pkg/dialer/obfs/tls/dialer.go index 47c7207..0778225 100644 --- a/pkg/dialer/obfs/tls/dialer.go +++ b/pkg/dialer/obfs/tls/dialer.go @@ -40,11 +40,7 @@ func (d *obfsTLSDialer) Dial(ctx context.Context, addr string, opts ...dialer.Di opt(options) } - netd := options.NetDialer - if netd == nil { - netd = dialer.DefaultNetDialer - } - conn, err := netd.Dial(ctx, "tcp", addr) + conn, err := options.NetDialer.Dial(ctx, "tcp", addr) if err != nil { d.logger.Error(err) } diff --git a/pkg/dialer/option.go b/pkg/dialer/option.go index 3cdb880..f766851 100644 --- a/pkg/dialer/option.go +++ b/pkg/dialer/option.go @@ -4,6 +4,7 @@ import ( "crypto/tls" "net/url" + "github.com/go-gost/gost/pkg/common/net/dialer" "github.com/go-gost/gost/pkg/logger" ) @@ -35,7 +36,7 @@ func LoggerOption(logger logger.Logger) Option { type DialOptions struct { Host string - NetDialer *NetDialer + NetDialer *dialer.NetDialer } type DialOption func(opts *DialOptions) @@ -46,7 +47,7 @@ func HostDialOption(host string) DialOption { } } -func NetDialerDialOption(netd *NetDialer) DialOption { +func NetDialerDialOption(netd *dialer.NetDialer) DialOption { return func(opts *DialOptions) { opts.NetDialer = netd } diff --git a/pkg/dialer/pht/dialer.go b/pkg/dialer/pht/dialer.go index 10f19a4..4e4d617 100644 --- a/pkg/dialer/pht/dialer.go +++ b/pkg/dialer/pht/dialer.go @@ -4,6 +4,7 @@ import ( "context" "net" "net/http" + "sync" "time" "github.com/go-gost/gost/pkg/dialer" @@ -19,11 +20,12 @@ func init() { } type phtDialer struct { - tlsEnabled bool - client *pht_util.Client - md metadata - logger logger.Logger - options dialer.Options + clients map[string]*pht_util.Client + clientMutex sync.Mutex + tlsEnabled bool + md metadata + logger logger.Logger + options dialer.Options } func NewDialer(opts ...dialer.Option) dialer.Dialer { @@ -33,7 +35,7 @@ func NewDialer(opts ...dialer.Option) dialer.Dialer { } return &phtDialer{ - logger: options.Logger, + clients: make(map[string]*pht_util.Client), options: options, } } @@ -46,7 +48,7 @@ func NewTLSDialer(opts ...dialer.Option) dialer.Dialer { return &phtDialer{ tlsEnabled: true, - logger: options.Logger, + clients: make(map[string]*pht_util.Client), options: options, } } @@ -56,36 +58,57 @@ func (d *phtDialer) Init(md md.Metadata) (err error) { return } - tr := &http.Transport{ - // Proxy: http.ProxyFromEnvironment, - DialContext: (&net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - }).DialContext, - ForceAttemptHTTP2: true, - MaxIdleConns: 100, - IdleConnTimeout: 90 * time.Second, - TLSHandshakeTimeout: 10 * time.Second, - ExpectContinueTimeout: 1 * time.Second, - } - if d.tlsEnabled { - tr.TLSClientConfig = d.options.TLSConfig - } - - d.client = &pht_util.Client{ - Client: &http.Client{ - // Timeout: 60 * time.Second, - Transport: tr, - }, - AuthorizePath: d.md.authorizePath, - PushPath: d.md.pushPath, - PullPath: d.md.pullPath, - TLSEnabled: d.tlsEnabled, - Logger: d.options.Logger, - } return nil } func (d *phtDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOption) (net.Conn, error) { - return d.client.Dial(ctx, addr) + d.clientMutex.Lock() + defer d.clientMutex.Unlock() + + client, ok := d.clients[addr] + if !ok { + var options dialer.DialOptions + for _, opt := range opts { + opt(&options) + } + + host := d.md.host + if host == "" { + host = options.Host + } + if h, _, _ := net.SplitHostPort(host); h != "" { + host = h + } + + tr := &http.Transport{ + // Proxy: http.ProxyFromEnvironment, + DialContext: func(ctx context.Context, network, adr string) (net.Conn, error) { + return options.NetDialer.Dial(ctx, network, addr) + }, + ForceAttemptHTTP2: true, + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + } + if d.tlsEnabled { + tr.TLSClientConfig = d.options.TLSConfig + } + + client = &pht_util.Client{ + Host: host, + Client: &http.Client{ + // Timeout: 60 * time.Second, + Transport: tr, + }, + AuthorizePath: d.md.authorizePath, + PushPath: d.md.pushPath, + PullPath: d.md.pullPath, + TLSEnabled: d.tlsEnabled, + Logger: d.options.Logger, + } + d.clients[addr] = client + } + + return client.Dial(ctx, addr) } diff --git a/pkg/dialer/pht/metadata.go b/pkg/dialer/pht/metadata.go index df1df02..adc7f47 100644 --- a/pkg/dialer/pht/metadata.go +++ b/pkg/dialer/pht/metadata.go @@ -23,6 +23,7 @@ type metadata struct { authorizePath string pushPath string pullPath string + host string } func (d *phtDialer) parseMetadata(md mdata.Metadata) (err error) { @@ -30,6 +31,7 @@ func (d *phtDialer) parseMetadata(md mdata.Metadata) (err error) { authorizePath = "authorizePath" pushPath = "pushPath" pullPath = "pullPath" + host = "host" ) d.md.authorizePath = mdata.GetString(md, authorizePath) @@ -44,5 +46,7 @@ func (d *phtDialer) parseMetadata(md mdata.Metadata) (err error) { if !strings.HasPrefix(d.md.pullPath, "/") { d.md.pullPath = defaultPullPath } + + d.md.host = mdata.GetString(md, host) return } diff --git a/pkg/dialer/quic/dialer.go b/pkg/dialer/quic/dialer.go index d3cce09..3c55dfd 100644 --- a/pkg/dialer/quic/dialer.go +++ b/pkg/dialer/quic/dialer.go @@ -64,11 +64,14 @@ func (d *quicDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialO opt(options) } - netd := options.NetDialer - if netd == nil { - netd = dialer.DefaultNetDialer + host := d.md.host + if host == "" { + host = options.Host } - conn, err = netd.Dial(ctx, "udp", "") + if h, _, _ := net.SplitHostPort(host); h != "" { + host = h + } + conn, err = options.NetDialer.Dial(ctx, "udp", "") if err != nil { return nil, err } diff --git a/pkg/dialer/quic/metadata.go b/pkg/dialer/quic/metadata.go index 85e7f14..f202052 100644 --- a/pkg/dialer/quic/metadata.go +++ b/pkg/dialer/quic/metadata.go @@ -12,6 +12,7 @@ type metadata struct { handshakeTimeout time.Duration cipherKey []byte + host string } func (d *quicDialer) parseMetadata(md mdata.Metadata) (err error) { @@ -21,6 +22,7 @@ func (d *quicDialer) parseMetadata(md mdata.Metadata) (err error) { maxIdleTimeout = "maxIdleTimeout" cipherKey = "cipherKey" + host = "host" ) d.md.handshakeTimeout = mdata.GetDuration(md, handshakeTimeout) @@ -32,5 +34,7 @@ func (d *quicDialer) parseMetadata(md mdata.Metadata) (err error) { d.md.keepAlive = mdata.GetBool(md, keepAlive) d.md.handshakeTimeout = mdata.GetDuration(md, handshakeTimeout) d.md.maxIdleTimeout = mdata.GetDuration(md, maxIdleTimeout) + + d.md.host = mdata.GetString(md, host) return } diff --git a/pkg/dialer/ssh/dialer.go b/pkg/dialer/ssh/dialer.go index 30c2a3a..87dc124 100644 --- a/pkg/dialer/ssh/dialer.go +++ b/pkg/dialer/ssh/dialer.go @@ -66,11 +66,7 @@ func (d *sshDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOp opt(&options) } - netd := options.NetDialer - if netd == nil { - netd = dialer.DefaultNetDialer - } - conn, err = netd.Dial(ctx, "tcp", addr) + conn, err = options.NetDialer.Dial(ctx, "tcp", addr) if err != nil { return } diff --git a/pkg/dialer/sshd/dialer.go b/pkg/dialer/sshd/dialer.go index c45d438..735a60d 100644 --- a/pkg/dialer/sshd/dialer.go +++ b/pkg/dialer/sshd/dialer.go @@ -65,11 +65,7 @@ func (d *sshdDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialO opt(&options) } - netd := options.NetDialer - if netd == nil { - netd = dialer.DefaultNetDialer - } - conn, err = netd.Dial(ctx, "tcp", addr) + conn, err = options.NetDialer.Dial(ctx, "tcp", addr) if err != nil { return } diff --git a/pkg/dialer/tcp/dialer.go b/pkg/dialer/tcp/dialer.go index 28be9e9..d3d1326 100644 --- a/pkg/dialer/tcp/dialer.go +++ b/pkg/dialer/tcp/dialer.go @@ -40,11 +40,7 @@ func (d *tcpDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOp opt(&options) } - netd := options.NetDialer - if netd == nil { - netd = dialer.DefaultNetDialer - } - conn, err := netd.Dial(ctx, "tcp", addr) + conn, err := options.NetDialer.Dial(ctx, "tcp", addr) if err != nil { d.logger.Error(err) } diff --git a/pkg/dialer/tls/dialer.go b/pkg/dialer/tls/dialer.go index 0e63a2c..f2ee19e 100644 --- a/pkg/dialer/tls/dialer.go +++ b/pkg/dialer/tls/dialer.go @@ -44,11 +44,7 @@ func (d *tlsDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOp opt(&options) } - netd := options.NetDialer - if netd == nil { - netd = dialer.DefaultNetDialer - } - conn, err := netd.Dial(ctx, "tcp", addr) + conn, err := options.NetDialer.Dial(ctx, "tcp", addr) if err != nil { d.logger.Error(err) } diff --git a/pkg/dialer/tls/mux/dialer.go b/pkg/dialer/tls/mux/dialer.go index 9879740..db76c40 100644 --- a/pkg/dialer/tls/mux/dialer.go +++ b/pkg/dialer/tls/mux/dialer.go @@ -68,11 +68,7 @@ func (d *mtlsDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialO opt(&options) } - netd := options.NetDialer - if netd == nil { - netd = dialer.DefaultNetDialer - } - conn, err = netd.Dial(ctx, "tcp", addr) + conn, err = options.NetDialer.Dial(ctx, "tcp", addr) if err != nil { return } diff --git a/pkg/dialer/udp/dialer.go b/pkg/dialer/udp/dialer.go index 4c84896..f3a6fb2 100644 --- a/pkg/dialer/udp/dialer.go +++ b/pkg/dialer/udp/dialer.go @@ -40,11 +40,7 @@ func (d *udpDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOp opt(&options) } - netd := options.NetDialer - if netd == nil { - netd = dialer.DefaultNetDialer - } - c, err := netd.Dial(ctx, "udp", addr) + c, err := options.NetDialer.Dial(ctx, "udp", addr) if err != nil { return nil, err } diff --git a/pkg/dialer/ws/dialer.go b/pkg/dialer/ws/dialer.go index ab6cce1..edf47be 100644 --- a/pkg/dialer/ws/dialer.go +++ b/pkg/dialer/ws/dialer.go @@ -61,11 +61,7 @@ func (d *wsDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOpt opt(&options) } - netd := options.NetDialer - if netd == nil { - netd = dialer.DefaultNetDialer - } - conn, err := netd.Dial(ctx, "tcp", addr) + conn, err := options.NetDialer.Dial(ctx, "tcp", addr) if err != nil { d.logger.Error(err) } diff --git a/pkg/dialer/ws/mux/dialer.go b/pkg/dialer/ws/mux/dialer.go index e965e17..f13933b 100644 --- a/pkg/dialer/ws/mux/dialer.go +++ b/pkg/dialer/ws/mux/dialer.go @@ -85,11 +85,7 @@ func (d *mwsDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOp opt(&options) } - netd := options.NetDialer - if netd == nil { - netd = dialer.DefaultNetDialer - } - conn, err = netd.Dial(ctx, "tcp", addr) + conn, err = options.NetDialer.Dial(ctx, "tcp", addr) if err != nil { return } diff --git a/pkg/internal/util/pht/client.go b/pkg/internal/util/pht/client.go index b708235..556c47a 100644 --- a/pkg/internal/util/pht/client.go +++ b/pkg/internal/util/pht/client.go @@ -8,12 +8,14 @@ import ( "net" "net/http" "net/http/httputil" + "strconv" "strings" "github.com/go-gost/gost/pkg/logger" ) type Client struct { + Host string Client *http.Client AuthorizePath string PushPath string @@ -23,6 +25,16 @@ type Client struct { } func (c *Client) Dial(ctx context.Context, addr string) (net.Conn, error) { + raddr, err := net.ResolveTCPAddr("tcp", addr) + if err != nil { + c.Logger.Error(err) + return nil, err + } + + if c.Host != "" { + addr = net.JoinHostPort(c.Host, strconv.Itoa(raddr.Port)) + } + token, err := c.authorize(ctx, addr) if err != nil { c.Logger.Error(err) @@ -30,13 +42,13 @@ func (c *Client) Dial(ctx context.Context, addr string) (net.Conn, error) { } cn := &clientConn{ - client: c.Client, - rxc: make(chan []byte, 128), - closed: make(chan struct{}), - localAddr: &net.TCPAddr{}, - logger: c.Logger, + client: c.Client, + rxc: make(chan []byte, 128), + closed: make(chan struct{}), + localAddr: &net.TCPAddr{}, + remoteAddr: raddr, + logger: c.Logger, } - cn.remoteAddr, _ = net.ResolveTCPAddr("tcp", addr) scheme := "http" if c.TLSEnabled { diff --git a/pkg/internal/util/pht/conn.go b/pkg/internal/util/pht/conn.go index 4247e05..818bc89 100644 --- a/pkg/internal/util/pht/conn.go +++ b/pkg/internal/util/pht/conn.go @@ -53,6 +53,7 @@ func (c *clientConn) Write(b []byte) (n int, err error) { if err != nil { return } + if c.logger.IsLevelEnabled(logger.DebugLevel) { dump, _ := httputil.DumpRequest(r, false) c.logger.Debug(string(dump)) @@ -87,6 +88,7 @@ func (c *clientConn) readLoop() { if err != nil { return err } + if c.logger.IsLevelEnabled(logger.DebugLevel) { dump, _ := httputil.DumpRequest(r, false) c.logger.Debug(string(dump))