add traffic limiter for proxy handler
This commit is contained in:
@ -15,7 +15,7 @@ import (
|
||||
"github.com/go-gost/core/metrics"
|
||||
"github.com/go-gost/core/recorder"
|
||||
"github.com/go-gost/core/service"
|
||||
sx "github.com/go-gost/x/internal/util/selector"
|
||||
ctxvalue "github.com/go-gost/x/internal/ctx"
|
||||
xmetrics "github.com/go-gost/x/metrics"
|
||||
"github.com/rs/xid"
|
||||
)
|
||||
@ -145,20 +145,26 @@ func (s *defaultService) Serve() error {
|
||||
}
|
||||
tempDelay = 0
|
||||
|
||||
host := conn.RemoteAddr().String()
|
||||
if h, _, _ := net.SplitHostPort(host); h != "" {
|
||||
host = h
|
||||
clientAddr := conn.RemoteAddr().String()
|
||||
clientIP := clientAddr
|
||||
if h, _, _ := net.SplitHostPort(clientAddr); h != "" {
|
||||
clientIP = h
|
||||
}
|
||||
|
||||
ctx := ctxvalue.ContextWithSid(context.Background(), ctxvalue.Sid(xid.New().String()))
|
||||
ctx = ctxvalue.ContextWithClientAddr(ctx, ctxvalue.ClientAddr(clientAddr))
|
||||
ctx = ctxvalue.ContextWithHash(ctx, &ctxvalue.Hash{Source: clientIP})
|
||||
|
||||
for _, rec := range s.options.recorders {
|
||||
if rec.Record == recorder.RecorderServiceClientAddress {
|
||||
if err := rec.Recorder.Record(context.Background(), []byte(host)); err != nil {
|
||||
if err := rec.Recorder.Record(ctx, []byte(clientIP)); err != nil {
|
||||
s.options.logger.Errorf("record %s: %v", rec.Record, err)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
if s.options.admission != nil &&
|
||||
!s.options.admission.Admit(context.Background(), conn.RemoteAddr().String()) {
|
||||
!s.options.admission.Admit(ctx, conn.RemoteAddr().String()) {
|
||||
conn.Close()
|
||||
s.options.logger.Debugf("admission: %s is denied", conn.RemoteAddr())
|
||||
continue
|
||||
@ -166,12 +172,12 @@ func (s *defaultService) Serve() error {
|
||||
|
||||
go func() {
|
||||
if v := xmetrics.GetCounter(xmetrics.MetricServiceRequestsCounter,
|
||||
metrics.Labels{"service": s.name, "client": host}); v != nil {
|
||||
metrics.Labels{"service": s.name, "client": clientIP}); v != nil {
|
||||
v.Inc()
|
||||
}
|
||||
|
||||
if v := xmetrics.GetGauge(xmetrics.MetricServiceRequestsInFlightGauge,
|
||||
metrics.Labels{"service": s.name, "client": host}); v != nil {
|
||||
metrics.Labels{"service": s.name, "client": clientIP}); v != nil {
|
||||
v.Inc()
|
||||
defer v.Dec()
|
||||
}
|
||||
@ -184,13 +190,10 @@ func (s *defaultService) Serve() error {
|
||||
}()
|
||||
}
|
||||
|
||||
ctx := sx.ContextWithHash(context.Background(), &sx.Hash{Source: host})
|
||||
ctx = ContextWithSid(ctx, xid.New().String())
|
||||
|
||||
if err := s.handler.Handle(ctx, conn); err != nil {
|
||||
s.options.logger.Error(err)
|
||||
if v := xmetrics.GetCounter(xmetrics.MetricServiceHandlerErrorsCounter,
|
||||
metrics.Labels{"service": s.name, "client": host}); v != nil {
|
||||
metrics.Labels{"service": s.name, "client": clientIP}); v != nil {
|
||||
v.Inc()
|
||||
}
|
||||
}
|
||||
@ -211,18 +214,3 @@ func (s *defaultService) execCmds(phase string, cmds []string) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type sidKey struct{}
|
||||
|
||||
var (
|
||||
ssid sidKey
|
||||
)
|
||||
|
||||
func ContextWithSid(ctx context.Context, sid string) context.Context {
|
||||
return context.WithValue(ctx, ssid, sid)
|
||||
}
|
||||
|
||||
func SidFromContext(ctx context.Context) string {
|
||||
v, _ := ctx.Value(ssid).(string)
|
||||
return v
|
||||
}
|
||||
|
Reference in New Issue
Block a user