From ba87494cef4bfab14ee635d4adf8185082fa2b77 Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Mon, 19 Dec 2022 19:34:13 +0800 Subject: [PATCH] add Set for traffic Limiter --- common/net/dialer/dialer.go | 17 +++++++++-------- limiter/traffic/limiter.go | 1 + 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/common/net/dialer/dialer.go b/common/net/dialer/dialer.go index 36fe0c7..8bcf6b1 100644 --- a/common/net/dialer/dialer.go +++ b/common/net/dialer/dialer.go @@ -26,15 +26,16 @@ type NetDialer struct { Timeout time.Duration DialFunc func(ctx context.Context, network, addr string) (net.Conn, error) Logger logger.Logger - deadline time.Time } func (d *NetDialer) Dial(ctx context.Context, network, addr string) (conn net.Conn, err error) { if d == nil { d = DefaultNetDialer } - if d.Timeout <= 0 { - d.Timeout = DefaultTimeout + + timeout := d.Timeout + if timeout <= 0 { + timeout = DefaultTimeout } if d.DialFunc != nil { @@ -46,8 +47,8 @@ func (d *NetDialer) Dial(ctx context.Context, network, addr string) (conn net.Co log = logger.Default() } + deadline := time.Now().Add(timeout) ifces := strings.Split(d.Interface, ",") - d.deadline = time.Now().Add(d.Timeout) for _, ifce := range ifces { strict := strings.HasSuffix(ifce, "!") ifce = strings.TrimSuffix(ifce, "!") @@ -59,7 +60,7 @@ func (d *NetDialer) Dial(ctx context.Context, network, addr string) (conn net.Co } for _, ifAddr := range ifAddrs { - conn, err = d.dialOnce(ctx, network, addr, ifceName, ifAddr, log) + conn, err = d.dialOnce(ctx, network, addr, ifceName, ifAddr, deadline, log) if err == nil { return } @@ -72,7 +73,7 @@ func (d *NetDialer) Dial(ctx context.Context, network, addr string) (conn net.Co return } - if time.Until(d.deadline) < 0 { + if time.Until(deadline) < 0 { return } } @@ -81,7 +82,7 @@ func (d *NetDialer) Dial(ctx context.Context, network, addr string) (conn net.Co return } -func (d *NetDialer) dialOnce(ctx context.Context, network, addr, ifceName string, ifAddr net.Addr, log logger.Logger) (net.Conn, error) { +func (d *NetDialer) dialOnce(ctx context.Context, network, addr, ifceName string, ifAddr net.Addr, deadline time.Time, log logger.Logger) (net.Conn, error) { if ifceName != "" { log.Debugf("interface: %s %v/%s", ifceName, ifAddr, network) } @@ -125,7 +126,7 @@ func (d *NetDialer) dialOnce(ctx context.Context, network, addr, ifceName string return nil, fmt.Errorf("dial: unsupported network %s", network) } netd := net.Dialer{ - Deadline: d.deadline, + Deadline: deadline, LocalAddr: ifAddr, Control: func(network, address string, c syscall.RawConn) error { return c.Control(func(fd uintptr) { diff --git a/limiter/traffic/limiter.go b/limiter/traffic/limiter.go index b439b58..523fab9 100644 --- a/limiter/traffic/limiter.go +++ b/limiter/traffic/limiter.go @@ -7,6 +7,7 @@ type Limiter interface { // the returned value is less or equal to n. Wait(ctx context.Context, n int) int Limit() int + Set(n int) } type TrafficLimiter interface {