diff --git a/chain/route.go b/chain/route.go index 73dbcd4..7052ec0 100644 --- a/chain/route.go +++ b/chain/route.go @@ -71,7 +71,9 @@ func (r *route) Dial(ctx context.Context, network, address string, opts ...DialO cc, err := r.GetNode(r.Len()-1).transport.Connect(ctx, conn, network, address) if err != nil { - conn.Close() + if conn != nil { + conn.Close() + } return nil, err } return cc, nil diff --git a/chain/transport.go b/chain/transport.go index 15a4c3f..a469510 100644 --- a/chain/transport.go +++ b/chain/transport.go @@ -17,6 +17,7 @@ type Transport struct { route Route dialer dialer.Dialer connector connector.Connector + timeout time.Duration } func (tr *Transport) Copy() *Transport { @@ -45,10 +46,15 @@ func (tr *Transport) WithConnector(connector connector.Connector) *Transport { return tr } +func (tr *Transport) WithTimeout(d time.Duration) *Transport { + tr.timeout = d + return tr +} + func (tr *Transport) Dial(ctx context.Context, addr string) (net.Conn, error) { netd := &net_dialer.NetDialer{ Interface: tr.ifceName, - Timeout: 15 * time.Second, + Timeout: tr.timeout, } if tr.sockOpts != nil { netd.Mark = tr.sockOpts.Mark @@ -81,7 +87,16 @@ func (tr *Transport) Handshake(ctx context.Context, conn net.Conn) (net.Conn, er } func (tr *Transport) Connect(ctx context.Context, conn net.Conn, network, address string) (net.Conn, error) { - return tr.connector.Connect(ctx, conn, network, address) + netd := &net_dialer.NetDialer{ + Interface: tr.ifceName, + Timeout: tr.timeout, + } + if tr.sockOpts != nil { + netd.Mark = tr.sockOpts.Mark + } + return tr.connector.Connect(ctx, conn, network, address, + connector.NetDialerConnectOption(netd), + ) } func (tr *Transport) Bind(ctx context.Context, conn net.Conn, network, address string, opts ...connector.BindOption) (net.Listener, error) { diff --git a/common/net/dialer/dialer.go b/common/net/dialer/dialer.go index 738f902..cb7e7aa 100644 --- a/common/net/dialer/dialer.go +++ b/common/net/dialer/dialer.go @@ -46,7 +46,9 @@ func (d *NetDialer) Dial(ctx context.Context, network, addr string) (net.Conn, e if d.DialFunc != nil { return d.DialFunc(ctx, network, addr) } - log.Debugf("interface: %s %v/%s", ifceName, ifAddr, network) + if ifceName != "" { + log.Debugf("interface: %s %v/%s", ifceName, ifAddr, network) + } switch network { case "udp", "udp4", "udp6": diff --git a/connector/option.go b/connector/option.go index 3a3ce06..ef1b8e3 100644 --- a/connector/option.go +++ b/connector/option.go @@ -5,6 +5,7 @@ import ( "net/url" "time" + "github.com/go-gost/core/common/net/dialer" "github.com/go-gost/core/logger" ) @@ -35,10 +36,17 @@ func LoggerOption(logger logger.Logger) Option { } type ConnectOptions struct { + NetDialer *dialer.NetDialer } type ConnectOption func(opts *ConnectOptions) +func NetDialerConnectOption(netd *dialer.NetDialer) ConnectOption { + return func(opts *ConnectOptions) { + opts.NetDialer = netd + } +} + type BindOptions struct { Mux bool Backlog int