fix timeout for router

This commit is contained in:
ginuerzh 2024-06-25 20:37:08 +08:00
parent ca340b1bf1
commit 4e831b95e8
5 changed files with 25 additions and 39 deletions

View File

@ -36,7 +36,6 @@ func (*route) Dial(ctx context.Context, network, address string, opts ...DialOpt
} }
netd := dialer.NetDialer{ netd := dialer.NetDialer{
Timeout: options.Timeout,
Interface: options.Interface, Interface: options.Interface,
Netns: options.Netns, Netns: options.Netns,
Logger: options.Logger, Logger: options.Logger,
@ -94,7 +93,6 @@ func (r *route) Nodes() []*Node {
} }
type DialOptions struct { type DialOptions struct {
Timeout time.Duration
Interface string Interface string
Netns string Netns string
SockOpts *SockOpts SockOpts *SockOpts
@ -103,12 +101,6 @@ type DialOptions struct {
type DialOption func(opts *DialOptions) type DialOption func(opts *DialOptions)
func TimeoutDialOption(d time.Duration) DialOption {
return func(opts *DialOptions) {
opts.Timeout = d
}
}
func InterfaceDialOption(ifName string) DialOption { func InterfaceDialOption(ifName string) DialOption {
return func(opts *DialOptions) { return func(opts *DialOptions) {
opts.Interface = ifName opts.Interface = ifName

View File

@ -103,6 +103,10 @@ func NewRouter(opts ...RouterOption) *Router {
opt(&r.options) opt(&r.options)
} }
} }
if r.options.Timeout == 0 {
r.options.Timeout = 15 * time.Second
}
if r.options.Logger == nil { if r.options.Logger == nil {
r.options.Logger = logger.Default().WithFields(map[string]any{"kind": "router"}) r.options.Logger = logger.Default().WithFields(map[string]any{"kind": "router"})
} }
@ -117,6 +121,12 @@ func (r *Router) Options() *RouterOptions {
} }
func (r *Router) Dial(ctx context.Context, network, address string) (conn net.Conn, err error) { func (r *Router) Dial(ctx context.Context, network, address string) (conn net.Conn, err error) {
if r.options.Timeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, r.options.Timeout)
defer cancel()
}
host := address host := address
if h, _, _ := net.SplitHostPort(address); h != "" { if h, _, _ := net.SplitHostPort(address); h != "" {
host = h host = h
@ -191,7 +201,6 @@ func (r *Router) dial(ctx context.Context, network, address string) (conn net.Co
NetnsDialOption(r.options.Netns), NetnsDialOption(r.options.Netns),
SockOptsDialOption(r.options.SockOpts), SockOptsDialOption(r.options.SockOpts),
LoggerDialOption(r.options.Logger), LoggerDialOption(r.options.Logger),
TimeoutDialOption(r.options.Timeout),
) )
if err == nil { if err == nil {
break break
@ -203,6 +212,12 @@ func (r *Router) dial(ctx context.Context, network, address string) (conn net.Co
} }
func (r *Router) Bind(ctx context.Context, network, address string, opts ...BindOption) (ln net.Listener, err error) { func (r *Router) Bind(ctx context.Context, network, address string, opts ...BindOption) (ln net.Listener, err error) {
if r.options.Timeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, r.options.Timeout)
defer cancel()
}
count := r.options.Retries + 1 count := r.options.Retries + 1
if count <= 0 { if count <= 0 {
count = 1 count = 1

View File

@ -3,7 +3,6 @@ package chain
import ( import (
"context" "context"
"net" "net"
"time"
net_dialer "github.com/go-gost/core/common/net/dialer" net_dialer "github.com/go-gost/core/common/net/dialer"
"github.com/go-gost/core/connector" "github.com/go-gost/core/connector"
@ -16,7 +15,6 @@ type TransportOptions struct {
Netns string Netns string
SockOpts *SockOpts SockOpts *SockOpts
Route Route Route Route
Timeout time.Duration
} }
type TransportOption func(*TransportOptions) type TransportOption func(*TransportOptions)
@ -51,12 +49,6 @@ func RouteTransportOption(route Route) TransportOption {
} }
} }
func TimeoutTransportOption(timeout time.Duration) TransportOption {
return func(o *TransportOptions) {
o.Timeout = timeout
}
}
type Transport struct { type Transport struct {
dialer dialer.Dialer dialer dialer.Dialer
connector connector.Connector connector connector.Connector
@ -81,7 +73,6 @@ func (tr *Transport) Dial(ctx context.Context, addr string) (net.Conn, error) {
netd := &net_dialer.NetDialer{ netd := &net_dialer.NetDialer{
Interface: tr.options.IfceName, Interface: tr.options.IfceName,
Netns: tr.options.Netns, Netns: tr.options.Netns,
Timeout: tr.options.Timeout,
} }
if tr.options.SockOpts != nil { if tr.options.SockOpts != nil {
netd.Mark = tr.options.SockOpts.Mark netd.Mark = tr.options.SockOpts.Mark
@ -117,7 +108,6 @@ func (tr *Transport) Connect(ctx context.Context, conn net.Conn, network, addres
netd := &net_dialer.NetDialer{ netd := &net_dialer.NetDialer{
Interface: tr.options.IfceName, Interface: tr.options.IfceName,
Netns: tr.options.Netns, Netns: tr.options.Netns,
Timeout: tr.options.Timeout,
} }
if tr.options.SockOpts != nil { if tr.options.SockOpts != nil {
netd.Mark = tr.options.SockOpts.Mark netd.Mark = tr.options.SockOpts.Mark

View File

@ -26,7 +26,6 @@ type NetDialer struct {
Interface string Interface string
Netns string Netns string
Mark int Mark int
Timeout time.Duration
DialFunc func(ctx context.Context, network, addr string) (net.Conn, error) DialFunc func(ctx context.Context, network, addr string) (net.Conn, error)
Logger logger.Logger Logger logger.Logger
} }
@ -62,11 +61,6 @@ func (d *NetDialer) Dial(ctx context.Context, network, addr string) (conn net.Co
} }
} }
timeout := d.Timeout
if timeout <= 0 {
timeout = DefaultTimeout
}
if d.DialFunc != nil { if d.DialFunc != nil {
return d.DialFunc(ctx, network, addr) return d.DialFunc(ctx, network, addr)
} }
@ -78,7 +72,6 @@ func (d *NetDialer) Dial(ctx context.Context, network, addr string) (conn net.Co
default: default:
} }
deadline := time.Now().Add(timeout)
ifces := strings.Split(d.Interface, ",") ifces := strings.Split(d.Interface, ",")
for _, ifce := range ifces { for _, ifce := range ifces {
strict := strings.HasSuffix(ifce, "!") strict := strings.HasSuffix(ifce, "!")
@ -91,7 +84,7 @@ func (d *NetDialer) Dial(ctx context.Context, network, addr string) (conn net.Co
} }
for _, ifAddr := range ifAddrs { for _, ifAddr := range ifAddrs {
conn, err = d.dialOnce(ctx, network, addr, ifceName, ifAddr, deadline, log) conn, err = d.dialOnce(ctx, network, addr, ifceName, ifAddr, log)
if err == nil { if err == nil {
return return
} }
@ -103,17 +96,13 @@ func (d *NetDialer) Dial(ctx context.Context, network, addr string) (conn net.Co
!strings.Contains(err.Error(), "mismatched local address type") { !strings.Contains(err.Error(), "mismatched local address type") {
return return
} }
if time.Until(deadline) < 0 {
return
}
} }
} }
return return
} }
func (d *NetDialer) dialOnce(ctx context.Context, network, addr, ifceName string, ifAddr net.Addr, deadline time.Time, log logger.Logger) (net.Conn, error) { func (d *NetDialer) dialOnce(ctx context.Context, network, addr, ifceName string, ifAddr net.Addr, log logger.Logger) (net.Conn, error) {
if ifceName != "" { if ifceName != "" {
log.Debugf("interface: %s %v/%s", ifceName, ifAddr, network) log.Debugf("interface: %s %v/%s", ifceName, ifAddr, network)
} }
@ -157,7 +146,6 @@ func (d *NetDialer) dialOnce(ctx context.Context, network, addr, ifceName string
return nil, fmt.Errorf("dial: unsupported network %s", network) return nil, fmt.Errorf("dial: unsupported network %s", network)
} }
netd := net.Dialer{ netd := net.Dialer{
Deadline: deadline,
LocalAddr: ifAddr, LocalAddr: ifAddr,
Control: func(network, address string, c syscall.RawConn) error { Control: func(network, address string, c syscall.RawConn) error {
return c.Control(func(fd uintptr) { return c.Control(func(fd uintptr) {

View File

@ -27,6 +27,7 @@ type Options struct {
Service string Service string
ProxyProtocol int ProxyProtocol int
Netns string Netns string
Router *chain.Router
} }
type Option func(opts *Options) type Option func(opts *Options)
@ -73,12 +74,6 @@ func ConnLimiterOption(limiter conn.ConnLimiter) Option {
} }
} }
func ChainOption(chain chain.Chainer) Option {
return func(opts *Options) {
opts.Chain = chain
}
}
func StatsOption(stats *stats.Stats) Option { func StatsOption(stats *stats.Stats) Option {
return func(opts *Options) { return func(opts *Options) {
opts.Stats = stats opts.Stats = stats
@ -108,3 +103,9 @@ func NetnsOption(netns string) Option {
opts.Netns = netns opts.Netns = netns
} }
} }
func RouterOption(router *chain.Router) Option {
return func(opts *Options) {
opts.Router = router
}
}