add traffic limiter for proxy handler

This commit is contained in:
ginuerzh
2023-11-18 18:28:09 +08:00
parent 330631fd79
commit 88cc6ff4d5
38 changed files with 633 additions and 200 deletions

View File

@ -15,7 +15,7 @@ import (
"github.com/go-gost/core/metrics"
"github.com/go-gost/core/recorder"
"github.com/go-gost/core/service"
sx "github.com/go-gost/x/internal/util/selector"
ctxvalue "github.com/go-gost/x/internal/ctx"
xmetrics "github.com/go-gost/x/metrics"
"github.com/rs/xid"
)
@ -145,20 +145,26 @@ func (s *defaultService) Serve() error {
}
tempDelay = 0
host := conn.RemoteAddr().String()
if h, _, _ := net.SplitHostPort(host); h != "" {
host = h
clientAddr := conn.RemoteAddr().String()
clientIP := clientAddr
if h, _, _ := net.SplitHostPort(clientAddr); h != "" {
clientIP = h
}
ctx := ctxvalue.ContextWithSid(context.Background(), ctxvalue.Sid(xid.New().String()))
ctx = ctxvalue.ContextWithClientAddr(ctx, ctxvalue.ClientAddr(clientAddr))
ctx = ctxvalue.ContextWithHash(ctx, &ctxvalue.Hash{Source: clientIP})
for _, rec := range s.options.recorders {
if rec.Record == recorder.RecorderServiceClientAddress {
if err := rec.Recorder.Record(context.Background(), []byte(host)); err != nil {
if err := rec.Recorder.Record(ctx, []byte(clientIP)); err != nil {
s.options.logger.Errorf("record %s: %v", rec.Record, err)
}
break
}
}
if s.options.admission != nil &&
!s.options.admission.Admit(context.Background(), conn.RemoteAddr().String()) {
!s.options.admission.Admit(ctx, conn.RemoteAddr().String()) {
conn.Close()
s.options.logger.Debugf("admission: %s is denied", conn.RemoteAddr())
continue
@ -166,12 +172,12 @@ func (s *defaultService) Serve() error {
go func() {
if v := xmetrics.GetCounter(xmetrics.MetricServiceRequestsCounter,
metrics.Labels{"service": s.name, "client": host}); v != nil {
metrics.Labels{"service": s.name, "client": clientIP}); v != nil {
v.Inc()
}
if v := xmetrics.GetGauge(xmetrics.MetricServiceRequestsInFlightGauge,
metrics.Labels{"service": s.name, "client": host}); v != nil {
metrics.Labels{"service": s.name, "client": clientIP}); v != nil {
v.Inc()
defer v.Dec()
}
@ -184,13 +190,10 @@ func (s *defaultService) Serve() error {
}()
}
ctx := sx.ContextWithHash(context.Background(), &sx.Hash{Source: host})
ctx = ContextWithSid(ctx, xid.New().String())
if err := s.handler.Handle(ctx, conn); err != nil {
s.options.logger.Error(err)
if v := xmetrics.GetCounter(xmetrics.MetricServiceHandlerErrorsCounter,
metrics.Labels{"service": s.name, "client": host}); v != nil {
metrics.Labels{"service": s.name, "client": clientIP}); v != nil {
v.Inc()
}
}
@ -211,18 +214,3 @@ func (s *defaultService) execCmds(phase string, cmds []string) {
}
}
}
type sidKey struct{}
var (
ssid sidKey
)
func ContextWithSid(ctx context.Context, sid string) context.Context {
return context.WithValue(ctx, ssid, sid)
}
func SidFromContext(ctx context.Context) string {
v, _ := ctx.Value(ssid).(string)
return v
}