add traffic limiter for proxy handler
This commit is contained in:
235
limiter/traffic/plugin.go
Normal file
235
limiter/traffic/plugin.go
Normal file
@ -0,0 +1,235 @@
|
||||
package traffic
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-gost/core/limiter/traffic"
|
||||
"github.com/go-gost/core/logger"
|
||||
"github.com/go-gost/plugin/limiter/traffic/proto"
|
||||
"github.com/go-gost/x/internal/plugin"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
type grpcPlugin struct {
|
||||
conn grpc.ClientConnInterface
|
||||
client proto.LimiterClient
|
||||
log logger.Logger
|
||||
}
|
||||
|
||||
// NewGRPCPlugin creates a traffic limiter plugin based on gRPC.
|
||||
func NewGRPCPlugin(name string, addr string, opts ...plugin.Option) traffic.TrafficLimiter {
|
||||
var options plugin.Options
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
|
||||
log := logger.Default().WithFields(map[string]any{
|
||||
"kind": "limiter",
|
||||
"limiter": name,
|
||||
})
|
||||
conn, err := plugin.NewGRPCConn(addr, &options)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
|
||||
p := &grpcPlugin{
|
||||
conn: conn,
|
||||
log: log,
|
||||
}
|
||||
if conn != nil {
|
||||
p.client = proto.NewLimiterClient(conn)
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *grpcPlugin) In(ctx context.Context, key string, opts ...traffic.Option) traffic.Limiter {
|
||||
if p.client == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var options traffic.Options
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
|
||||
r, err := p.client.Limit(ctx,
|
||||
&proto.LimitRequest{
|
||||
Network: options.Network,
|
||||
Addr: options.Addr,
|
||||
Client: options.Client,
|
||||
Src: options.Src,
|
||||
})
|
||||
if err != nil {
|
||||
p.log.Error(err)
|
||||
return nil
|
||||
}
|
||||
|
||||
return NewLimiter(int(r.In))
|
||||
}
|
||||
|
||||
func (p *grpcPlugin) Out(ctx context.Context, key string, opts ...traffic.Option) traffic.Limiter {
|
||||
if p.client == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var options traffic.Options
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
|
||||
r, err := p.client.Limit(ctx,
|
||||
&proto.LimitRequest{
|
||||
Network: options.Network,
|
||||
Addr: options.Addr,
|
||||
Client: options.Client,
|
||||
Src: options.Src,
|
||||
})
|
||||
if err != nil {
|
||||
p.log.Error(err)
|
||||
return nil
|
||||
}
|
||||
|
||||
return NewLimiter(int(r.Out))
|
||||
}
|
||||
|
||||
func (p *grpcPlugin) Close() error {
|
||||
if closer, ok := p.conn.(io.Closer); ok {
|
||||
return closer.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type httpPluginRequest struct {
|
||||
Network string `json:"network"`
|
||||
Addr string `json:"addr"`
|
||||
Client string `json:"client"`
|
||||
Src string `json:"src"`
|
||||
}
|
||||
|
||||
type httpPluginResponse struct {
|
||||
In int64 `json:"in"`
|
||||
Out int64 `json:"out"`
|
||||
}
|
||||
|
||||
type httpPlugin struct {
|
||||
url string
|
||||
client *http.Client
|
||||
header http.Header
|
||||
log logger.Logger
|
||||
}
|
||||
|
||||
// NewHTTPPlugin creates a traffic limiter plugin based on HTTP.
|
||||
func NewHTTPPlugin(name string, url string, opts ...plugin.Option) traffic.TrafficLimiter {
|
||||
var options plugin.Options
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
|
||||
return &httpPlugin{
|
||||
url: url,
|
||||
client: plugin.NewHTTPClient(&options),
|
||||
header: options.Header,
|
||||
log: logger.Default().WithFields(map[string]any{
|
||||
"kind": "limiter",
|
||||
"limiter": name,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *httpPlugin) In(ctx context.Context, key string, opts ...traffic.Option) traffic.Limiter {
|
||||
if p.client == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var options traffic.Options
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
|
||||
rb := httpPluginRequest{
|
||||
Network: options.Network,
|
||||
Addr: options.Addr,
|
||||
Client: options.Client,
|
||||
Src: options.Src,
|
||||
}
|
||||
v, err := json.Marshal(&rb)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, p.url, bytes.NewReader(v))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if p.header != nil {
|
||||
req.Header = p.header.Clone()
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := p.client.Do(req)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil
|
||||
}
|
||||
|
||||
res := httpPluginResponse{}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&res); err != nil {
|
||||
return nil
|
||||
}
|
||||
return NewLimiter(int(res.In))
|
||||
}
|
||||
|
||||
func (p *httpPlugin) Out(ctx context.Context, key string, opts ...traffic.Option) traffic.Limiter {
|
||||
if p.client == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var options traffic.Options
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
|
||||
rb := httpPluginRequest{
|
||||
Network: options.Network,
|
||||
Addr: options.Addr,
|
||||
Client: options.Client,
|
||||
Src: options.Src,
|
||||
}
|
||||
v, err := json.Marshal(&rb)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, p.url, bytes.NewReader(v))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if p.header != nil {
|
||||
req.Header = p.header.Clone()
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := p.client.Do(req)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil
|
||||
}
|
||||
|
||||
res := httpPluginResponse{}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&res); err != nil {
|
||||
return nil
|
||||
}
|
||||
return NewLimiter(int(res.Out))
|
||||
}
|
@ -121,7 +121,7 @@ func NewTrafficLimiter(opts ...Option) limiter.TrafficLimiter {
|
||||
|
||||
// In obtains a traffic input limiter based on key.
|
||||
// The key should be client connection address.
|
||||
func (l *trafficLimiter) In(key string) limiter.Limiter {
|
||||
func (l *trafficLimiter) In(ctx context.Context, key string, opts ...limiter.Option) limiter.Limiter {
|
||||
var lims []limiter.Limiter
|
||||
|
||||
// service level limiter
|
||||
@ -185,7 +185,7 @@ func (l *trafficLimiter) In(key string) limiter.Limiter {
|
||||
|
||||
// Out obtains a traffic output limiter based on key.
|
||||
// The key should be client connection address.
|
||||
func (l *trafficLimiter) Out(key string) limiter.Limiter {
|
||||
func (l *trafficLimiter) Out(ctx context.Context, key string, opts ...limiter.Option) limiter.Limiter {
|
||||
var lims []limiter.Limiter
|
||||
|
||||
// service level limiter
|
||||
|
@ -26,8 +26,8 @@ type serverConn struct {
|
||||
rbuf bytes.Buffer
|
||||
limiter limiter.TrafficLimiter
|
||||
limiterIn limiter.Limiter
|
||||
expIn int64
|
||||
limiterOut limiter.Limiter
|
||||
expIn int64
|
||||
expOut int64
|
||||
}
|
||||
|
||||
@ -35,34 +35,39 @@ func WrapConn(limiter limiter.TrafficLimiter, c net.Conn) net.Conn {
|
||||
if limiter == nil {
|
||||
return c
|
||||
}
|
||||
|
||||
return &serverConn{
|
||||
Conn: c,
|
||||
limiter: limiter,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *serverConn) getInLimiter(addr net.Addr) limiter.Limiter {
|
||||
func (c *serverConn) getInLimiter() limiter.Limiter {
|
||||
now := time.Now().UnixNano()
|
||||
// cache the limiter for 60s
|
||||
if c.limiter != nil && time.Duration(now-c.expIn) > 60*time.Second {
|
||||
c.limiterIn = c.limiter.In(addr.String())
|
||||
if lim := c.limiter.In(context.Background(), c.RemoteAddr().String()); lim != nil {
|
||||
c.limiterIn = lim
|
||||
}
|
||||
c.expIn = now
|
||||
}
|
||||
return c.limiterIn
|
||||
}
|
||||
|
||||
func (c *serverConn) getOutLimiter(addr net.Addr) limiter.Limiter {
|
||||
func (c *serverConn) getOutLimiter() limiter.Limiter {
|
||||
now := time.Now().UnixNano()
|
||||
// cache the limiter for 60s
|
||||
if c.limiter != nil && time.Duration(now-c.expOut) > 60*time.Second {
|
||||
c.limiterOut = c.limiter.Out(addr.String())
|
||||
if lim := c.limiter.Out(context.Background(), c.RemoteAddr().String()); lim != nil {
|
||||
c.limiterOut = lim
|
||||
}
|
||||
c.expOut = now
|
||||
}
|
||||
return c.limiterOut
|
||||
}
|
||||
|
||||
func (c *serverConn) Read(b []byte) (n int, err error) {
|
||||
limiter := c.getInLimiter(c.RemoteAddr())
|
||||
limiter := c.getInLimiter()
|
||||
if limiter == nil {
|
||||
return c.Conn.Read(b)
|
||||
}
|
||||
@ -92,7 +97,7 @@ func (c *serverConn) Read(b []byte) (n int, err error) {
|
||||
}
|
||||
|
||||
func (c *serverConn) Write(b []byte) (n int, err error) {
|
||||
limiter := c.getOutLimiter(c.RemoteAddr())
|
||||
limiter := c.getOutLimiter()
|
||||
if limiter == nil {
|
||||
return c.Conn.Write(b)
|
||||
}
|
||||
@ -163,7 +168,7 @@ func (c *packetConn) getInLimiter(addr net.Addr) limiter.Limiter {
|
||||
return lim
|
||||
}
|
||||
|
||||
lim = c.limiter.In(addr.String())
|
||||
lim = c.limiter.In(context.Background(), addr.String())
|
||||
c.inLimits.Set(addr.String(), lim, 0)
|
||||
|
||||
return lim
|
||||
@ -187,7 +192,7 @@ func (c *packetConn) getOutLimiter(addr net.Addr) limiter.Limiter {
|
||||
return lim
|
||||
}
|
||||
|
||||
lim = c.limiter.Out(addr.String())
|
||||
lim = c.limiter.Out(context.Background(), addr.String())
|
||||
c.outLimits.Set(addr.String(), lim, 0)
|
||||
|
||||
return lim
|
||||
@ -266,7 +271,7 @@ func (c *udpConn) getInLimiter(addr net.Addr) limiter.Limiter {
|
||||
return lim
|
||||
}
|
||||
|
||||
lim = c.limiter.In(addr.String())
|
||||
lim = c.limiter.In(context.Background(), addr.String())
|
||||
c.inLimits.Set(addr.String(), lim, 0)
|
||||
|
||||
return lim
|
||||
@ -290,7 +295,7 @@ func (c *udpConn) getOutLimiter(addr net.Addr) limiter.Limiter {
|
||||
return lim
|
||||
}
|
||||
|
||||
lim = c.limiter.Out(addr.String())
|
||||
lim = c.limiter.Out(context.Background(), addr.String())
|
||||
c.outLimits.Set(addr.String(), lim, 0)
|
||||
|
||||
return lim
|
||||
|
109
limiter/traffic/wrapper/io.go
Normal file
109
limiter/traffic/wrapper/io.go
Normal file
@ -0,0 +1,109 @@
|
||||
package wrapper
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/go-gost/core/limiter/traffic"
|
||||
limiter "github.com/go-gost/core/limiter/traffic"
|
||||
)
|
||||
|
||||
// readWriter is an io.ReadWriter with traffic limiter supported.
|
||||
type readWriter struct {
|
||||
io.ReadWriter
|
||||
rbuf bytes.Buffer
|
||||
limiter limiter.TrafficLimiter
|
||||
limiterIn limiter.Limiter
|
||||
limiterOut limiter.Limiter
|
||||
expIn int64
|
||||
expOut int64
|
||||
opts []traffic.Option
|
||||
key string
|
||||
}
|
||||
|
||||
func WrapReadWriter(limiter limiter.TrafficLimiter, rw io.ReadWriter, key string, opts ...traffic.Option) io.ReadWriter {
|
||||
if limiter == nil {
|
||||
return rw
|
||||
}
|
||||
|
||||
return &readWriter{
|
||||
ReadWriter: rw,
|
||||
limiter: limiter,
|
||||
opts: opts,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *readWriter) getInLimiter() limiter.Limiter {
|
||||
now := time.Now().UnixNano()
|
||||
// cache the limiter for 60s
|
||||
if p.limiter != nil && time.Duration(now-p.expIn) > 60*time.Second {
|
||||
if lim := p.limiter.In(context.Background(), p.key, p.opts...); lim != nil {
|
||||
p.limiterIn = lim
|
||||
}
|
||||
p.expIn = now
|
||||
}
|
||||
return p.limiterIn
|
||||
}
|
||||
|
||||
func (p *readWriter) getOutLimiter() limiter.Limiter {
|
||||
now := time.Now().UnixNano()
|
||||
// cache the limiter for 60s
|
||||
if p.limiter != nil && time.Duration(now-p.expOut) > 60*time.Second {
|
||||
if lim := p.limiter.Out(context.Background(), p.key, p.opts...); lim != nil {
|
||||
p.limiterOut = lim
|
||||
}
|
||||
p.expOut = now
|
||||
}
|
||||
return p.limiterOut
|
||||
}
|
||||
|
||||
func (p *readWriter) Read(b []byte) (n int, err error) {
|
||||
limiter := p.getInLimiter()
|
||||
if limiter == nil {
|
||||
return p.ReadWriter.Read(b)
|
||||
}
|
||||
|
||||
if p.rbuf.Len() > 0 {
|
||||
burst := len(b)
|
||||
if p.rbuf.Len() < burst {
|
||||
burst = p.rbuf.Len()
|
||||
}
|
||||
lim := limiter.Wait(context.Background(), burst)
|
||||
return p.rbuf.Read(b[:lim])
|
||||
}
|
||||
|
||||
nn, err := p.ReadWriter.Read(b)
|
||||
if err != nil {
|
||||
return nn, err
|
||||
}
|
||||
|
||||
n = limiter.Wait(context.Background(), nn)
|
||||
if n < nn {
|
||||
if _, err = p.rbuf.Write(b[n:nn]); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (p *readWriter) Write(b []byte) (n int, err error) {
|
||||
limiter := p.getOutLimiter()
|
||||
if limiter == nil {
|
||||
return p.ReadWriter.Write(b)
|
||||
}
|
||||
|
||||
nn := 0
|
||||
for len(b) > 0 {
|
||||
nn, err = p.ReadWriter.Write(b[:limiter.Wait(context.Background(), len(b))])
|
||||
n += nn
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
b = b[nn:]
|
||||
}
|
||||
|
||||
return
|
||||
}
|
Reference in New Issue
Block a user