From 4e831b95e8cce5237022bb94e6603c196a1f3e4f Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Tue, 25 Jun 2024 20:37:08 +0800 Subject: [PATCH] fix timeout for router --- chain/route.go | 8 -------- chain/router.go | 17 ++++++++++++++++- chain/transport.go | 10 ---------- common/net/dialer/dialer.go | 16 ++-------------- listener/option.go | 13 +++++++------ 5 files changed, 25 insertions(+), 39 deletions(-) diff --git a/chain/route.go b/chain/route.go index 0e09203..8a4dbba 100644 --- a/chain/route.go +++ b/chain/route.go @@ -36,7 +36,6 @@ func (*route) Dial(ctx context.Context, network, address string, opts ...DialOpt } netd := dialer.NetDialer{ - Timeout: options.Timeout, Interface: options.Interface, Netns: options.Netns, Logger: options.Logger, @@ -94,7 +93,6 @@ func (r *route) Nodes() []*Node { } type DialOptions struct { - Timeout time.Duration Interface string Netns string SockOpts *SockOpts @@ -103,12 +101,6 @@ type DialOptions struct { type DialOption func(opts *DialOptions) -func TimeoutDialOption(d time.Duration) DialOption { - return func(opts *DialOptions) { - opts.Timeout = d - } -} - func InterfaceDialOption(ifName string) DialOption { return func(opts *DialOptions) { opts.Interface = ifName diff --git a/chain/router.go b/chain/router.go index 5552525..9189ce5 100644 --- a/chain/router.go +++ b/chain/router.go @@ -103,6 +103,10 @@ func NewRouter(opts ...RouterOption) *Router { opt(&r.options) } } + if r.options.Timeout == 0 { + r.options.Timeout = 15 * time.Second + } + if r.options.Logger == nil { r.options.Logger = logger.Default().WithFields(map[string]any{"kind": "router"}) } @@ -117,6 +121,12 @@ func (r *Router) Options() *RouterOptions { } func (r *Router) Dial(ctx context.Context, network, address string) (conn net.Conn, err error) { + if r.options.Timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, r.options.Timeout) + defer cancel() + } + host := address if h, _, _ := net.SplitHostPort(address); h != "" { host = h @@ -191,7 +201,6 @@ func (r *Router) dial(ctx context.Context, network, address string) (conn net.Co NetnsDialOption(r.options.Netns), SockOptsDialOption(r.options.SockOpts), LoggerDialOption(r.options.Logger), - TimeoutDialOption(r.options.Timeout), ) if err == nil { break @@ -203,6 +212,12 @@ func (r *Router) dial(ctx context.Context, network, address string) (conn net.Co } func (r *Router) Bind(ctx context.Context, network, address string, opts ...BindOption) (ln net.Listener, err error) { + if r.options.Timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, r.options.Timeout) + defer cancel() + } + count := r.options.Retries + 1 if count <= 0 { count = 1 diff --git a/chain/transport.go b/chain/transport.go index d68e94f..f0a1562 100644 --- a/chain/transport.go +++ b/chain/transport.go @@ -3,7 +3,6 @@ package chain import ( "context" "net" - "time" net_dialer "github.com/go-gost/core/common/net/dialer" "github.com/go-gost/core/connector" @@ -16,7 +15,6 @@ type TransportOptions struct { Netns string SockOpts *SockOpts Route Route - Timeout time.Duration } type TransportOption func(*TransportOptions) @@ -51,12 +49,6 @@ func RouteTransportOption(route Route) TransportOption { } } -func TimeoutTransportOption(timeout time.Duration) TransportOption { - return func(o *TransportOptions) { - o.Timeout = timeout - } -} - type Transport struct { dialer dialer.Dialer connector connector.Connector @@ -81,7 +73,6 @@ func (tr *Transport) Dial(ctx context.Context, addr string) (net.Conn, error) { netd := &net_dialer.NetDialer{ Interface: tr.options.IfceName, Netns: tr.options.Netns, - Timeout: tr.options.Timeout, } if tr.options.SockOpts != nil { netd.Mark = tr.options.SockOpts.Mark @@ -117,7 +108,6 @@ func (tr *Transport) Connect(ctx context.Context, conn net.Conn, network, addres netd := &net_dialer.NetDialer{ Interface: tr.options.IfceName, Netns: tr.options.Netns, - Timeout: tr.options.Timeout, } if tr.options.SockOpts != nil { netd.Mark = tr.options.SockOpts.Mark diff --git a/common/net/dialer/dialer.go b/common/net/dialer/dialer.go index 0b63ff2..e04dddd 100644 --- a/common/net/dialer/dialer.go +++ b/common/net/dialer/dialer.go @@ -26,7 +26,6 @@ type NetDialer struct { Interface string Netns string Mark int - Timeout time.Duration DialFunc func(ctx context.Context, network, addr string) (net.Conn, error) Logger logger.Logger } @@ -62,11 +61,6 @@ func (d *NetDialer) Dial(ctx context.Context, network, addr string) (conn net.Co } } - timeout := d.Timeout - if timeout <= 0 { - timeout = DefaultTimeout - } - if d.DialFunc != nil { return d.DialFunc(ctx, network, addr) } @@ -78,7 +72,6 @@ func (d *NetDialer) Dial(ctx context.Context, network, addr string) (conn net.Co default: } - deadline := time.Now().Add(timeout) ifces := strings.Split(d.Interface, ",") for _, ifce := range ifces { strict := strings.HasSuffix(ifce, "!") @@ -91,7 +84,7 @@ func (d *NetDialer) Dial(ctx context.Context, network, addr string) (conn net.Co } for _, ifAddr := range ifAddrs { - conn, err = d.dialOnce(ctx, network, addr, ifceName, ifAddr, deadline, log) + conn, err = d.dialOnce(ctx, network, addr, ifceName, ifAddr, log) if err == nil { return } @@ -103,17 +96,13 @@ func (d *NetDialer) Dial(ctx context.Context, network, addr string) (conn net.Co !strings.Contains(err.Error(), "mismatched local address type") { return } - - if time.Until(deadline) < 0 { - return - } } } return } -func (d *NetDialer) dialOnce(ctx context.Context, network, addr, ifceName string, ifAddr net.Addr, deadline time.Time, log logger.Logger) (net.Conn, error) { +func (d *NetDialer) dialOnce(ctx context.Context, network, addr, ifceName string, ifAddr net.Addr, log logger.Logger) (net.Conn, error) { if ifceName != "" { log.Debugf("interface: %s %v/%s", ifceName, ifAddr, network) } @@ -157,7 +146,6 @@ func (d *NetDialer) dialOnce(ctx context.Context, network, addr, ifceName string return nil, fmt.Errorf("dial: unsupported network %s", network) } netd := net.Dialer{ - Deadline: deadline, LocalAddr: ifAddr, Control: func(network, address string, c syscall.RawConn) error { return c.Control(func(fd uintptr) { diff --git a/listener/option.go b/listener/option.go index 3a0d721..48c9fa1 100644 --- a/listener/option.go +++ b/listener/option.go @@ -27,6 +27,7 @@ type Options struct { Service string ProxyProtocol int Netns string + Router *chain.Router } type Option func(opts *Options) @@ -73,12 +74,6 @@ func ConnLimiterOption(limiter conn.ConnLimiter) Option { } } -func ChainOption(chain chain.Chainer) Option { - return func(opts *Options) { - opts.Chain = chain - } -} - func StatsOption(stats *stats.Stats) Option { return func(opts *Options) { opts.Stats = stats @@ -108,3 +103,9 @@ func NetnsOption(netns string) Option { opts.Netns = netns } } + +func RouterOption(router *chain.Router) Option { + return func(opts *Options) { + opts.Router = router + } +}