diff --git a/chain/route.go b/chain/route.go index 9d1f117..4e176f1 100644 --- a/chain/route.go +++ b/chain/route.go @@ -10,6 +10,7 @@ import ( "github.com/go-gost/core/logger" "github.com/go-gost/core/metrics" "github.com/go-gost/core/selector" + xmetrics "github.com/go-gost/x/metrics" ) type RouteOptions struct { @@ -124,7 +125,7 @@ func (r *route) connect(ctx context.Context, logger logger.Logger) (conn net.Con if marker != nil { marker.Mark() } - if v := metrics.GetCounter(metrics.MetricChainErrorsCounter, + if v := xmetrics.GetCounter(xmetrics.MetricChainErrorsCounter, metrics.Labels{"chain": name, "node": node.Name}); v != nil { v.Inc() } @@ -171,7 +172,7 @@ func (r *route) connect(ctx context.Context, logger logger.Logger) (conn net.Con if cn, _ := r.options.Chain.(chainNamer); cn != nil { name = cn.Name() } - if v := metrics.GetObserver(metrics.MetricNodeConnectDurationObserver, + if v := xmetrics.GetObserver(xmetrics.MetricNodeConnectDurationObserver, metrics.Labels{"chain": name, "node": node.Name}); v != nil { v.Observe(time.Since(start).Seconds()) } diff --git a/config/parsing/parse.go b/config/parsing/parse.go index 6340e3a..2c2a301 100644 --- a/config/parsing/parse.go +++ b/config/parsing/parse.go @@ -35,6 +35,7 @@ const ( mdKeyProxyProtocol = "proxyProtocol" mdKeyInterface = "interface" mdKeySoMark = "so_mark" + mdKeyHash = "hash" ) func ParseAuther(cfg *config.AutherConfig) auth.Authenticator { @@ -120,6 +121,8 @@ func parseChainSelector(cfg *config.SelectorConfig) selector.Selector[chain.Chai strategy = xs.RandomStrategy[chain.Chainer]() case "fifo", "ha": strategy = xs.FIFOStrategy[chain.Chainer]() + case "hash": + strategy = xs.HashStrategy[chain.Chainer]() default: strategy = xs.RoundRobinStrategy[chain.Chainer]() } @@ -143,6 +146,8 @@ func parseNodeSelector(cfg *config.SelectorConfig) selector.Selector[*chain.Node strategy = xs.RandomStrategy[*chain.Node]() case "fifo", "ha": strategy = xs.FIFOStrategy[*chain.Node]() + case "hash": + strategy = xs.HashStrategy[*chain.Node]() default: strategy = xs.RoundRobinStrategy[*chain.Node]() } diff --git a/config/parsing/service.go b/config/parsing/service.go index 1183af7..1552b12 100644 --- a/config/parsing/service.go +++ b/config/parsing/service.go @@ -20,6 +20,7 @@ import ( tls_util "github.com/go-gost/x/internal/util/tls" "github.com/go-gost/x/metadata" "github.com/go-gost/x/registry" + xservice "github.com/go-gost/x/service" ) func ParseService(cfg *config.ServiceConfig) (service.Service, error) { @@ -200,9 +201,9 @@ func ParseService(cfg *config.ServiceConfig) (service.Service, error) { return nil, err } - s := service.NewService(cfg.Name, ln, h, - service.AdmissionOption(admission.AdmissionGroup(admissions...)), - service.LoggerOption(serviceLogger), + s := xservice.NewService(cfg.Name, ln, h, + xservice.AdmissionOption(admission.AdmissionGroup(admissions...)), + xservice.LoggerOption(serviceLogger), ) serviceLogger.Infof("listening on %s/%s", s.Addr().String(), s.Addr().Network()) diff --git a/go.mod b/go.mod index 81e06dd..2b45c7a 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d github.com/gin-contrib/cors v1.3.1 github.com/gin-gonic/gin v1.7.7 - github.com/go-gost/core v0.0.0-20220914115321-50d443049f3b + github.com/go-gost/core v0.0.0-20220928034632-6e7a8f461903 github.com/go-gost/gosocks4 v0.0.1 github.com/go-gost/gosocks5 v0.3.1-0.20211109033403-d894d75b7f09 github.com/go-gost/relay v0.1.1-0.20211123134818-8ef7fd81ffd7 diff --git a/go.sum b/go.sum index ae5c28f..e754954 100644 --- a/go.sum +++ b/go.sum @@ -98,10 +98,8 @@ github.com/gin-gonic/gin v1.7.7/go.mod h1:axIBovoeJpVj8S3BwE0uPMTeReE4+AfFtqpqaZ github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= -github.com/go-gost/core v0.0.0-20220914115321-50d443049f3b h1:fWUPYFp0W/6GEhL0wrURGPQN2AQHhf4IZKiALJJOJh8= -github.com/go-gost/core v0.0.0-20220914115321-50d443049f3b/go.mod h1:bHVbCS9da6XtKNYMkMUVcck5UqDDUkyC37erVfs4GXQ= -github.com/go-gost/core v0.0.0-20220920034830-41ff9835a66d h1:UFn21xIJgWE/te12rzQA7Ymwbo+MaxOcp38K41L+Yck= -github.com/go-gost/core v0.0.0-20220920034830-41ff9835a66d/go.mod h1:bHVbCS9da6XtKNYMkMUVcck5UqDDUkyC37erVfs4GXQ= +github.com/go-gost/core v0.0.0-20220928034632-6e7a8f461903 h1:Ye6Ns0+Ms63vC+nbe9sBgBDTr+l+ukPX18SvEJuWXUw= +github.com/go-gost/core v0.0.0-20220928034632-6e7a8f461903/go.mod h1:bHVbCS9da6XtKNYMkMUVcck5UqDDUkyC37erVfs4GXQ= github.com/go-gost/gosocks4 v0.0.1 h1:+k1sec8HlELuQV7rWftIkmy8UijzUt2I6t+iMPlGB2s= github.com/go-gost/gosocks4 v0.0.1/go.mod h1:3B6L47HbU/qugDg4JnoFPHgJXE43Inz8Bah1QaN9qCc= github.com/go-gost/gosocks5 v0.3.1-0.20211109033403-d894d75b7f09 h1:A95M6UWcfZgOuJkQ7QLfG0Hs5peWIUSysCDNz4pfe04= diff --git a/handler/http/handler.go b/handler/http/handler.go index 64fc9a3..acbce83 100644 --- a/handler/http/handler.go +++ b/handler/http/handler.go @@ -21,6 +21,7 @@ import ( "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" netpkg "github.com/go-gost/x/internal/net" + sx "github.com/go-gost/x/internal/util/selector" "github.com/go-gost/x/registry" ) @@ -177,6 +178,11 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt req.Header.Del("Proxy-Authorization") + switch h.md.hash { + case "host": + ctx = sx.ContextWithHash(ctx, &sx.Hash{Source: addr}) + } + cc, err := h.router.Dial(ctx, network, addr) if err != nil { resp.StatusCode = http.StatusServiceUnavailable diff --git a/handler/http/metadata.go b/handler/http/metadata.go index 5fed552..7df1c5d 100644 --- a/handler/http/metadata.go +++ b/handler/http/metadata.go @@ -12,6 +12,7 @@ type metadata struct { probeResistance *probeResistance enableUDP bool header http.Header + hash string } func (h *httpHandler) parseMetadata(md mdata.Metadata) error { @@ -21,6 +22,7 @@ func (h *httpHandler) parseMetadata(md mdata.Metadata) error { probeResistKeyX = "probe_resist" knock = "knock" enableUDP = "udp" + hash = "hash" ) if m := mdutil.GetStringMapString(md, header); len(m) > 0 { @@ -45,6 +47,7 @@ func (h *httpHandler) parseMetadata(md mdata.Metadata) error { } } h.md.enableUDP = mdutil.GetBool(md, enableUDP) + h.md.hash = mdutil.GetString(md, hash) return nil } diff --git a/handler/http2/handler.go b/handler/http2/handler.go index 6babdd4..98db7bf 100644 --- a/handler/http2/handler.go +++ b/handler/http2/handler.go @@ -23,6 +23,7 @@ import ( "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" netpkg "github.com/go-gost/x/internal/net" + sx "github.com/go-gost/x/internal/util/selector" "github.com/go-gost/x/registry" ) @@ -157,6 +158,11 @@ func (h *http2Handler) roundTrip(ctx context.Context, w http.ResponseWriter, req req.Header.Del("Proxy-Authorization") req.Header.Del("Proxy-Connection") + switch h.md.hash { + case "host": + ctx = sx.ContextWithHash(ctx, &sx.Hash{Source: addr}) + } + cc, err := h.router.Dial(ctx, "tcp", addr) if err != nil { log.Error(err) diff --git a/handler/http2/metadata.go b/handler/http2/metadata.go index 1b560af..2dfa4f5 100644 --- a/handler/http2/metadata.go +++ b/handler/http2/metadata.go @@ -11,6 +11,7 @@ import ( type metadata struct { probeResistance *probeResistance header http.Header + hash string } func (h *http2Handler) parseMetadata(md mdata.Metadata) error { @@ -19,6 +20,7 @@ func (h *http2Handler) parseMetadata(md mdata.Metadata) error { probeResistKey = "probeResistance" probeResistKeyX = "probe_resist" knock = "knock" + hash = "hash" ) if m := mdutil.GetStringMapString(md, header); len(m) > 0 { @@ -42,6 +44,7 @@ func (h *http2Handler) parseMetadata(md mdata.Metadata) error { } } } + h.md.hash = mdutil.GetString(md, hash) return nil } diff --git a/handler/relay/connect.go b/handler/relay/connect.go index 52dffe1..45f82c6 100644 --- a/handler/relay/connect.go +++ b/handler/relay/connect.go @@ -10,6 +10,7 @@ import ( "github.com/go-gost/core/logger" "github.com/go-gost/relay" netpkg "github.com/go-gost/x/internal/net" + sx "github.com/go-gost/x/internal/util/selector" ) func (h *relayHandler) handleConnect(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) error { @@ -40,6 +41,10 @@ func (h *relayHandler) handleConnect(ctx context.Context, conn net.Conn, network return err } + switch h.md.hash { + case "host": + ctx = sx.ContextWithHash(ctx, &sx.Hash{Source: address}) + } cc, err := h.router.Dial(ctx, network, address) if err != nil { resp.Status = relay.StatusNetworkUnreachable diff --git a/handler/relay/metadata.go b/handler/relay/metadata.go index b91b37e..27fbae7 100644 --- a/handler/relay/metadata.go +++ b/handler/relay/metadata.go @@ -13,6 +13,7 @@ type metadata struct { enableBind bool udpBufferSize int noDelay bool + hash string } func (h *relayHandler) parseMetadata(md mdata.Metadata) (err error) { @@ -21,6 +22,7 @@ func (h *relayHandler) parseMetadata(md mdata.Metadata) (err error) { enableBind = "bind" udpBufferSize = "udpBufferSize" noDelay = "nodelay" + hash = "hash" ) h.md.readTimeout = mdutil.GetDuration(md, readTimeout) @@ -32,5 +34,7 @@ func (h *relayHandler) parseMetadata(md mdata.Metadata) (err error) { } else { h.md.udpBufferSize = 1500 } + + h.md.hash = mdutil.GetString(md, hash) return } diff --git a/handler/sni/handler.go b/handler/sni/handler.go index 74b91ed..abf8828 100644 --- a/handler/sni/handler.go +++ b/handler/sni/handler.go @@ -21,6 +21,7 @@ import ( md "github.com/go-gost/core/metadata" dissector "github.com/go-gost/tls-dissector" netpkg "github.com/go-gost/x/internal/net" + sx "github.com/go-gost/x/internal/util/selector" "github.com/go-gost/x/registry" ) @@ -121,6 +122,11 @@ func (h *sniHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, raddr net return nil } + switch h.md.hash { + case "host": + ctx = sx.ContextWithHash(ctx, &sx.Hash{Source: host}) + } + cc, err := h.router.Dial(ctx, "tcp", host) if err != nil { log.Error(err) @@ -179,6 +185,11 @@ func (h *sniHandler) handleHTTPS(ctx context.Context, rw io.ReadWriter, raddr ne return nil } + switch h.md.hash { + case "host": + ctx = sx.ContextWithHash(ctx, &sx.Hash{Source: host}) + } + cc, err := h.router.Dial(ctx, "tcp", host) if err != nil { log.Error(err) diff --git a/handler/sni/metadata.go b/handler/sni/metadata.go index 4c383fb..5626263 100644 --- a/handler/sni/metadata.go +++ b/handler/sni/metadata.go @@ -9,13 +9,16 @@ import ( type metadata struct { readTimeout time.Duration + hash string } func (h *sniHandler) parseMetadata(md mdata.Metadata) (err error) { const ( readTimeout = "readTimeout" + hash = "hash" ) h.md.readTimeout = mdutil.GetDuration(md, readTimeout) + h.md.hash = mdutil.GetString(md, hash) return } diff --git a/handler/socks/v4/handler.go b/handler/socks/v4/handler.go index b7b6050..67af4ba 100644 --- a/handler/socks/v4/handler.go +++ b/handler/socks/v4/handler.go @@ -12,6 +12,7 @@ import ( md "github.com/go-gost/core/metadata" "github.com/go-gost/gosocks4" netpkg "github.com/go-gost/x/internal/net" + sx "github.com/go-gost/x/internal/util/selector" "github.com/go-gost/x/registry" ) @@ -123,6 +124,11 @@ func (h *socks4Handler) handleConnect(ctx context.Context, conn net.Conn, req *g return resp.Write(conn) } + switch h.md.hash { + case "host": + ctx = sx.ContextWithHash(ctx, &sx.Hash{Source: addr}) + } + cc, err := h.router.Dial(ctx, "tcp", addr) if err != nil { resp := gosocks4.NewReply(gosocks4.Failed, nil) diff --git a/handler/socks/v4/metadata.go b/handler/socks/v4/metadata.go index d6c2f4f..21ad759 100644 --- a/handler/socks/v4/metadata.go +++ b/handler/socks/v4/metadata.go @@ -9,13 +9,16 @@ import ( type metadata struct { readTimeout time.Duration + hash string } func (h *socks4Handler) parseMetadata(md mdata.Metadata) (err error) { const ( readTimeout = "readTimeout" + hash = "hash" ) h.md.readTimeout = mdutil.GetDuration(md, readTimeout) + h.md.hash = mdutil.GetString(md, hash) return } diff --git a/handler/socks/v5/connect.go b/handler/socks/v5/connect.go index 40a9c46..f3e8bc5 100644 --- a/handler/socks/v5/connect.go +++ b/handler/socks/v5/connect.go @@ -9,6 +9,7 @@ import ( "github.com/go-gost/core/logger" "github.com/go-gost/gosocks5" netpkg "github.com/go-gost/x/internal/net" + sx "github.com/go-gost/x/internal/util/selector" ) func (h *socks5Handler) handleConnect(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) error { @@ -25,6 +26,11 @@ func (h *socks5Handler) handleConnect(ctx context.Context, conn net.Conn, networ return resp.Write(conn) } + switch h.md.hash { + case "host": + ctx = sx.ContextWithHash(ctx, &sx.Hash{Source: address}) + } + cc, err := h.router.Dial(ctx, network, address) if err != nil { resp := gosocks5.NewReply(gosocks5.NetUnreachable, nil) diff --git a/handler/socks/v5/metadata.go b/handler/socks/v5/metadata.go index b674522..523cff9 100644 --- a/handler/socks/v5/metadata.go +++ b/handler/socks/v5/metadata.go @@ -15,6 +15,7 @@ type metadata struct { enableUDP bool udpBufferSize int compatibilityMode bool + hash string } func (h *socks5Handler) parseMetadata(md mdata.Metadata) (err error) { @@ -25,6 +26,7 @@ func (h *socks5Handler) parseMetadata(md mdata.Metadata) (err error) { enableUDP = "udp" udpBufferSize = "udpBufferSize" compatibilityMode = "comp" + hash = "hash" ) h.md.readTimeout = mdutil.GetDuration(md, readTimeout) @@ -39,6 +41,7 @@ func (h *socks5Handler) parseMetadata(md mdata.Metadata) (err error) { } h.md.compatibilityMode = mdutil.GetBool(md, compatibilityMode) + h.md.hash = mdutil.GetString(md, hash) return nil } diff --git a/handler/ss/handler.go b/handler/ss/handler.go index fd206bc..5b156da 100644 --- a/handler/ss/handler.go +++ b/handler/ss/handler.go @@ -12,6 +12,7 @@ import ( md "github.com/go-gost/core/metadata" "github.com/go-gost/gosocks5" netpkg "github.com/go-gost/x/internal/net" + sx "github.com/go-gost/x/internal/util/selector" "github.com/go-gost/x/internal/util/ss" "github.com/go-gost/x/registry" "github.com/shadowsocks/go-shadowsocks2/core" @@ -106,6 +107,11 @@ func (h *ssHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.H return nil } + switch h.md.hash { + case "host": + ctx = sx.ContextWithHash(ctx, &sx.Hash{Source: addr.String()}) + } + cc, err := h.router.Dial(ctx, "tcp", addr.String()) if err != nil { return err diff --git a/handler/ss/metadata.go b/handler/ss/metadata.go index 4b2f154..06bda19 100644 --- a/handler/ss/metadata.go +++ b/handler/ss/metadata.go @@ -10,16 +10,19 @@ import ( type metadata struct { key string readTimeout time.Duration + hash string } func (h *ssHandler) parseMetadata(md mdata.Metadata) (err error) { const ( key = "key" readTimeout = "readTimeout" + hash = "hash" ) h.md.key = mdutil.GetString(md, key) h.md.readTimeout = mdutil.GetDuration(md, readTimeout) + h.md.hash = mdutil.GetString(md, hash) return } diff --git a/internal/util/selector/key.go b/internal/util/selector/key.go index d90fa8e..292eac3 100644 --- a/internal/util/selector/key.go +++ b/internal/util/selector/key.go @@ -8,7 +8,6 @@ type hashKey struct{} type Hash struct { Source string - Value int } var ( diff --git a/listener/grpc/listener.go b/listener/grpc/listener.go index b69af97..4e63660 100644 --- a/listener/grpc/listener.go +++ b/listener/grpc/listener.go @@ -58,10 +58,10 @@ func (l *grpcListener) Init(md md.Metadata) (err error) { return } ln = metrics.WrapListener(l.options.Service, ln) + ln = proxyproto.WrapListener(l.options.ProxyProtocol, ln, 10*time.Second) ln = admission.WrapListener(l.options.Admission, ln) ln = limiter.WrapListener(l.options.TrafficLimiter, ln) ln = climiter.WrapListener(l.options.ConnLimiter, ln) - ln = proxyproto.WrapListener(l.options.ProxyProtocol, ln, 10*time.Second) var opts []grpc.ServerOption if !l.md.insecure { diff --git a/listener/http2/h2/listener.go b/listener/http2/h2/listener.go index 8e2ddaa..f0a1d5f 100644 --- a/listener/http2/h2/listener.go +++ b/listener/http2/h2/listener.go @@ -80,10 +80,10 @@ func (l *h2Listener) Init(md md.Metadata) (err error) { } l.addr = ln.Addr() ln = metrics.WrapListener(l.options.Service, ln) + ln = proxyproto.WrapListener(l.options.ProxyProtocol, ln, 10*time.Second) ln = admission.WrapListener(l.options.Admission, ln) ln = limiter.WrapListener(l.options.TrafficLimiter, ln) ln = climiter.WrapListener(l.options.ConnLimiter, ln) - ln = proxyproto.WrapListener(l.options.ProxyProtocol, ln, 10*time.Second) if l.h2c { l.server.Handler = h2c.NewHandler( diff --git a/listener/http2/listener.go b/listener/http2/listener.go index 16e043c..666c3f9 100644 --- a/listener/http2/listener.go +++ b/listener/http2/listener.go @@ -69,10 +69,10 @@ func (l *http2Listener) Init(md md.Metadata) (err error) { } l.addr = ln.Addr() ln = metrics.WrapListener(l.options.Service, ln) + ln = proxyproto.WrapListener(l.options.ProxyProtocol, ln, 10*time.Second) ln = admission.WrapListener(l.options.Admission, ln) ln = limiter.WrapListener(l.options.TrafficLimiter, ln) ln = climiter.WrapListener(l.options.ConnLimiter, ln) - ln = proxyproto.WrapListener(l.options.ProxyProtocol, ln, 10*time.Second) ln = tls.NewListener( ln, diff --git a/listener/mtls/listener.go b/listener/mtls/listener.go index 9d3d21e..ed5d1aa 100644 --- a/listener/mtls/listener.go +++ b/listener/mtls/listener.go @@ -57,10 +57,10 @@ func (l *mtlsListener) Init(md md.Metadata) (err error) { } ln = metrics.WrapListener(l.options.Service, ln) + ln = proxyproto.WrapListener(l.options.ProxyProtocol, ln, 10*time.Second) ln = admission.WrapListener(l.options.Admission, ln) ln = limiter.WrapListener(l.options.TrafficLimiter, ln) ln = climiter.WrapListener(l.options.ConnLimiter, ln) - ln = proxyproto.WrapListener(l.options.ProxyProtocol, ln, 10*time.Second) l.Listener = tls.NewListener(ln, l.options.TLSConfig) l.cqueue = make(chan net.Conn, l.md.backlog) diff --git a/listener/mws/listener.go b/listener/mws/listener.go index 40a4008..3823f3c 100644 --- a/listener/mws/listener.go +++ b/listener/mws/listener.go @@ -99,10 +99,10 @@ func (l *mwsListener) Init(md md.Metadata) (err error) { return } ln = metrics.WrapListener(l.options.Service, ln) + ln = proxyproto.WrapListener(l.options.ProxyProtocol, ln, 10*time.Second) ln = admission.WrapListener(l.options.Admission, ln) ln = limiter.WrapListener(l.options.TrafficLimiter, ln) ln = climiter.WrapListener(l.options.ConnLimiter, ln) - ln = proxyproto.WrapListener(l.options.ProxyProtocol, ln, 10*time.Second) if l.tlsEnabled { ln = tls.NewListener(ln, l.options.TLSConfig) diff --git a/listener/obfs/http/listener.go b/listener/obfs/http/listener.go index 9772049..5bbe2d9 100644 --- a/listener/obfs/http/listener.go +++ b/listener/obfs/http/listener.go @@ -53,10 +53,10 @@ func (l *obfsListener) Init(md md.Metadata) (err error) { return } ln = metrics.WrapListener(l.options.Service, ln) + ln = proxyproto.WrapListener(l.options.ProxyProtocol, ln, 10*time.Second) ln = admission.WrapListener(l.options.Admission, ln) ln = limiter.WrapListener(l.options.TrafficLimiter, ln) ln = climiter.WrapListener(l.options.ConnLimiter, ln) - ln = proxyproto.WrapListener(l.options.ProxyProtocol, ln, 10*time.Second) l.Listener = ln return diff --git a/listener/obfs/tls/listener.go b/listener/obfs/tls/listener.go index 1cba1e9..aac7c64 100644 --- a/listener/obfs/tls/listener.go +++ b/listener/obfs/tls/listener.go @@ -52,10 +52,10 @@ func (l *obfsListener) Init(md md.Metadata) (err error) { return } ln = metrics.WrapListener(l.options.Service, ln) + ln = proxyproto.WrapListener(l.options.ProxyProtocol, ln, 10*time.Second) ln = admission.WrapListener(l.options.Admission, ln) ln = limiter.WrapListener(l.options.TrafficLimiter, ln) ln = climiter.WrapListener(l.options.ConnLimiter, ln) - ln = proxyproto.WrapListener(l.options.ProxyProtocol, ln, 10*time.Second) l.Listener = ln return diff --git a/listener/redirect/tcp/listener.go b/listener/redirect/tcp/listener.go index 9b0523f..c036cb9 100644 --- a/listener/redirect/tcp/listener.go +++ b/listener/redirect/tcp/listener.go @@ -60,10 +60,10 @@ func (l *redirectListener) Init(md md.Metadata) (err error) { } ln = metrics.WrapListener(l.options.Service, ln) + ln = proxyproto.WrapListener(l.options.ProxyProtocol, ln, 10*time.Second) ln = admission.WrapListener(l.options.Admission, ln) ln = limiter.WrapListener(l.options.TrafficLimiter, ln) ln = climiter.WrapListener(l.options.ConnLimiter, ln) - ln = proxyproto.WrapListener(l.options.ProxyProtocol, ln, 10*time.Second) l.ln = ln return } diff --git a/listener/ssh/listener.go b/listener/ssh/listener.go index 9367baa..ad0c006 100644 --- a/listener/ssh/listener.go +++ b/listener/ssh/listener.go @@ -59,10 +59,10 @@ func (l *sshListener) Init(md md.Metadata) (err error) { } ln = metrics.WrapListener(l.options.Service, ln) + ln = proxyproto.WrapListener(l.options.ProxyProtocol, ln, 10*time.Second) ln = admission.WrapListener(l.options.Admission, ln) ln = limiter.WrapListener(l.options.TrafficLimiter, ln) ln = climiter.WrapListener(l.options.ConnLimiter, ln) - ln = proxyproto.WrapListener(l.options.ProxyProtocol, ln, 10*time.Second) l.Listener = ln config := &ssh.ServerConfig{ diff --git a/listener/sshd/listener.go b/listener/sshd/listener.go index c532f76..23ea87f 100644 --- a/listener/sshd/listener.go +++ b/listener/sshd/listener.go @@ -68,10 +68,10 @@ func (l *sshdListener) Init(md md.Metadata) (err error) { } ln = metrics.WrapListener(l.options.Service, ln) + ln = proxyproto.WrapListener(l.options.ProxyProtocol, ln, 10*time.Second) ln = admission.WrapListener(l.options.Admission, ln) ln = limiter.WrapListener(l.options.TrafficLimiter, ln) ln = climiter.WrapListener(l.options.ConnLimiter, ln) - ln = proxyproto.WrapListener(l.options.ProxyProtocol, ln, 10*time.Second) l.Listener = ln config := &ssh.ServerConfig{ diff --git a/listener/tcp/listener.go b/listener/tcp/listener.go index 638064b..c91c6e2 100644 --- a/listener/tcp/listener.go +++ b/listener/tcp/listener.go @@ -55,10 +55,10 @@ func (l *tcpListener) Init(md md.Metadata) (err error) { l.logger.Debugf("pp: %d", l.options.ProxyProtocol) ln = metrics.WrapListener(l.options.Service, ln) + ln = proxyproto.WrapListener(l.options.ProxyProtocol, ln, 10*time.Second) ln = admission.WrapListener(l.options.Admission, ln) ln = limiter.WrapListener(l.options.TrafficLimiter, ln) ln = climiter.WrapListener(l.options.ConnLimiter, ln) - ln = proxyproto.WrapListener(l.options.ProxyProtocol, ln, 10*time.Second) l.ln = ln return diff --git a/listener/tls/listener.go b/listener/tls/listener.go index 259cc1c..82712e0 100644 --- a/listener/tls/listener.go +++ b/listener/tls/listener.go @@ -53,10 +53,10 @@ func (l *tlsListener) Init(md md.Metadata) (err error) { return } ln = metrics.WrapListener(l.options.Service, ln) + ln = proxyproto.WrapListener(l.options.ProxyProtocol, ln, 10*time.Second) ln = admission.WrapListener(l.options.Admission, ln) ln = limiter.WrapListener(l.options.TrafficLimiter, ln) ln = climiter.WrapListener(l.options.ConnLimiter, ln) - ln = proxyproto.WrapListener(l.options.ProxyProtocol, ln, 10*time.Second) l.ln = tls.NewListener(ln, l.options.TLSConfig) diff --git a/listener/ws/listener.go b/listener/ws/listener.go index 5db2269..a07c4d1 100644 --- a/listener/ws/listener.go +++ b/listener/ws/listener.go @@ -94,10 +94,10 @@ func (l *wsListener) Init(md md.Metadata) (err error) { return } ln = metrics.WrapListener(l.options.Service, ln) + ln = proxyproto.WrapListener(l.options.ProxyProtocol, ln, 10*time.Second) ln = admission.WrapListener(l.options.Admission, ln) ln = limiter.WrapListener(l.options.TrafficLimiter, ln) ln = climiter.WrapListener(l.options.ConnLimiter, ln) - ln = proxyproto.WrapListener(l.options.ProxyProtocol, ln, 10*time.Second) if l.tlsEnabled { ln = tls.NewListener(ln, l.options.TLSConfig) diff --git a/metrics/metrics.go b/metrics/metrics.go index 1324ed9..b12dd72 100644 --- a/metrics/metrics.go +++ b/metrics/metrics.go @@ -1,135 +1,52 @@ package metrics -import ( - "os" +import "github.com/go-gost/core/metrics" - "github.com/go-gost/core/metrics" - "github.com/prometheus/client_golang/prometheus" +const ( + // Number of services. Labels: host. + MetricServicesGauge metrics.MetricName = "gost_services" + // Total service requests. Labels: host, service. + MetricServiceRequestsCounter metrics.MetricName = "gost_service_requests_total" + // Number of in-flight requests. Labels: host, service. + MetricServiceRequestsInFlightGauge metrics.MetricName = "gost_service_requests_in_flight" + // Request duration historgram. Labels: host, service. + MetricServiceRequestsDurationObserver metrics.MetricName = "gost_service_request_duration_seconds" + // Total service input data transfer size in bytes. Labels: host, service. + MetricServiceTransferInputBytesCounter metrics.MetricName = "gost_service_transfer_input_bytes_total" + // Total service output data transfer size in bytes. Labels: host, service. + MetricServiceTransferOutputBytesCounter metrics.MetricName = "gost_service_transfer_output_bytes_total" + // Chain node connect duration histogram. Labels: host, chain, node. + MetricNodeConnectDurationObserver metrics.MetricName = "gost_chain_node_connect_duration_seconds" + // Total service handler errors. Labels: host, service. + MetricServiceHandlerErrorsCounter metrics.MetricName = "gost_service_handler_errors_total" + // Total chain connect errors. Labels: host, chain, node. + MetricChainErrorsCounter metrics.MetricName = "gost_chain_errors_total" ) -type promMetrics struct { - host string - gauges map[metrics.MetricName]*prometheus.GaugeVec - counters map[metrics.MetricName]*prometheus.CounterVec - histograms map[metrics.MetricName]*prometheus.HistogramVec +var ( + global metrics.Metrics = Noop() +) + +func Init(m metrics.Metrics) { + if m != nil { + global = m + } else { + global = Noop() + } } -func NewMetrics() metrics.Metrics { - host, _ := os.Hostname() - m := &promMetrics{ - host: host, - gauges: map[metrics.MetricName]*prometheus.GaugeVec{ - metrics.MetricServicesGauge: prometheus.NewGaugeVec( - prometheus.GaugeOpts{ - Name: string(metrics.MetricServicesGauge), - Help: "Current number of services", - }, - []string{"host"}), - metrics.MetricServiceRequestsInFlightGauge: prometheus.NewGaugeVec( - prometheus.GaugeOpts{ - Name: string(metrics.MetricServiceRequestsInFlightGauge), - Help: "Current in-flight requests", - }, - []string{"host", "service"}), - }, - counters: map[metrics.MetricName]*prometheus.CounterVec{ - metrics.MetricServiceRequestsCounter: prometheus.NewCounterVec( - prometheus.CounterOpts{ - Name: string(metrics.MetricServiceRequestsCounter), - Help: "Total number of requests", - }, - []string{"host", "service"}), - metrics.MetricServiceTransferInputBytesCounter: prometheus.NewCounterVec( - prometheus.CounterOpts{ - Name: string(metrics.MetricServiceTransferInputBytesCounter), - Help: "Total service input data transfer size in bytes", - }, - []string{"host", "service"}), - metrics.MetricServiceTransferOutputBytesCounter: prometheus.NewCounterVec( - prometheus.CounterOpts{ - Name: string(metrics.MetricServiceTransferOutputBytesCounter), - Help: "Total service output data transfer size in bytes", - }, - []string{"host", "service"}), - metrics.MetricServiceHandlerErrorsCounter: prometheus.NewCounterVec( - prometheus.CounterOpts{ - Name: string(metrics.MetricServiceHandlerErrorsCounter), - Help: "Total service handler errors", - }, - []string{"host", "service"}), - metrics.MetricChainErrorsCounter: prometheus.NewCounterVec( - prometheus.CounterOpts{ - Name: string(metrics.MetricChainErrorsCounter), - Help: "Total chain errors", - }, - []string{"host", "chain", "node"}), - }, - histograms: map[metrics.MetricName]*prometheus.HistogramVec{ - metrics.MetricServiceRequestsDurationObserver: prometheus.NewHistogramVec( - prometheus.HistogramOpts{ - Name: string(metrics.MetricServiceRequestsDurationObserver), - Help: "Distribution of request latencies", - Buckets: []float64{ - .005, .01, .025, .05, .1, .25, .5, 1, 2.5, 5, 10, 15, 30, 60, - }, - }, - []string{"host", "service"}), - metrics.MetricNodeConnectDurationObserver: prometheus.NewHistogramVec( - prometheus.HistogramOpts{ - Name: string(metrics.MetricNodeConnectDurationObserver), - Help: "Distribution of chain node connect latencies", - Buckets: []float64{ - .01, .05, .1, .25, .5, 1, 1.5, 2, 5, 10, 15, 30, 60, - }, - }, - []string{"host", "chain", "node"}), - }, - } - for k := range m.gauges { - prometheus.MustRegister(m.gauges[k]) - } - for k := range m.counters { - prometheus.MustRegister(m.counters[k]) - } - for k := range m.histograms { - prometheus.MustRegister(m.histograms[k]) - } - - return m +func IsEnabled() bool { + return global != Noop() } -func (m *promMetrics) Gauge(name metrics.MetricName, labels metrics.Labels) metrics.Gauge { - v, ok := m.gauges[name] - if !ok { - return nil - } - if labels == nil { - labels = metrics.Labels{} - } - labels["host"] = m.host - return v.With(prometheus.Labels(labels)) +func GetCounter(name metrics.MetricName, labels metrics.Labels) metrics.Counter { + return global.Counter(name, labels) } -func (m *promMetrics) Counter(name metrics.MetricName, labels metrics.Labels) metrics.Counter { - v, ok := m.counters[name] - if !ok { - return nil - } - if labels == nil { - labels = metrics.Labels{} - } - labels["host"] = m.host - return v.With(prometheus.Labels(labels)) +func GetGauge(name metrics.MetricName, labels metrics.Labels) metrics.Gauge { + return global.Gauge(name, labels) } -func (m *promMetrics) Observer(name metrics.MetricName, labels metrics.Labels) metrics.Observer { - v, ok := m.histograms[name] - if !ok { - return nil - } - if labels == nil { - labels = metrics.Labels{} - } - labels["host"] = m.host - return v.With(prometheus.Labels(labels)) +func GetObserver(name metrics.MetricName, labels metrics.Labels) metrics.Observer { + return global.Observer(name, labels) } diff --git a/metrics/noop.go b/metrics/noop.go new file mode 100644 index 0000000..4ea35d1 --- /dev/null +++ b/metrics/noop.go @@ -0,0 +1,45 @@ +package metrics + +import "github.com/go-gost/core/metrics" + +var ( + nopGauge = &noopGauge{} + nopCounter = &noopCounter{} + nopObserver = &noopObserver{} + + noop metrics.Metrics = &noopMetrics{} +) + +type noopMetrics struct{} + +func Noop() metrics.Metrics { + return noop +} + +func (m *noopMetrics) Counter(name metrics.MetricName, labels metrics.Labels) metrics.Counter { + return nopCounter +} + +func (m *noopMetrics) Gauge(name metrics.MetricName, labels metrics.Labels) metrics.Gauge { + return nopGauge +} + +func (m *noopMetrics) Observer(name metrics.MetricName, labels metrics.Labels) metrics.Observer { + return nopObserver +} + +type noopGauge struct{} + +func (*noopGauge) Inc() {} +func (*noopGauge) Dec() {} +func (*noopGauge) Add(v float64) {} +func (*noopGauge) Set(v float64) {} + +type noopCounter struct{} + +func (*noopCounter) Inc() {} +func (*noopCounter) Add(v float64) {} + +type noopObserver struct{} + +func (*noopObserver) Observe(v float64) {} diff --git a/metrics/prom.go b/metrics/prom.go new file mode 100644 index 0000000..413380f --- /dev/null +++ b/metrics/prom.go @@ -0,0 +1,135 @@ +package metrics + +import ( + "os" + + "github.com/go-gost/core/metrics" + "github.com/prometheus/client_golang/prometheus" +) + +type promMetrics struct { + host string + gauges map[metrics.MetricName]*prometheus.GaugeVec + counters map[metrics.MetricName]*prometheus.CounterVec + histograms map[metrics.MetricName]*prometheus.HistogramVec +} + +func NewMetrics() metrics.Metrics { + host, _ := os.Hostname() + m := &promMetrics{ + host: host, + gauges: map[metrics.MetricName]*prometheus.GaugeVec{ + MetricServicesGauge: prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Name: string(MetricServicesGauge), + Help: "Current number of services", + }, + []string{"host"}), + MetricServiceRequestsInFlightGauge: prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Name: string(MetricServiceRequestsInFlightGauge), + Help: "Current in-flight requests", + }, + []string{"host", "service"}), + }, + counters: map[metrics.MetricName]*prometheus.CounterVec{ + MetricServiceRequestsCounter: prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: string(MetricServiceRequestsCounter), + Help: "Total number of requests", + }, + []string{"host", "service"}), + MetricServiceTransferInputBytesCounter: prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: string(MetricServiceTransferInputBytesCounter), + Help: "Total service input data transfer size in bytes", + }, + []string{"host", "service"}), + MetricServiceTransferOutputBytesCounter: prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: string(MetricServiceTransferOutputBytesCounter), + Help: "Total service output data transfer size in bytes", + }, + []string{"host", "service"}), + MetricServiceHandlerErrorsCounter: prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: string(MetricServiceHandlerErrorsCounter), + Help: "Total service handler errors", + }, + []string{"host", "service"}), + MetricChainErrorsCounter: prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: string(MetricChainErrorsCounter), + Help: "Total chain errors", + }, + []string{"host", "chain", "node"}), + }, + histograms: map[metrics.MetricName]*prometheus.HistogramVec{ + MetricServiceRequestsDurationObserver: prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Name: string(MetricServiceRequestsDurationObserver), + Help: "Distribution of request latencies", + Buckets: []float64{ + .005, .01, .025, .05, .1, .25, .5, 1, 2.5, 5, 10, 15, 30, 60, + }, + }, + []string{"host", "service"}), + MetricNodeConnectDurationObserver: prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Name: string(MetricNodeConnectDurationObserver), + Help: "Distribution of chain node connect latencies", + Buckets: []float64{ + .01, .05, .1, .25, .5, 1, 1.5, 2, 5, 10, 15, 30, 60, + }, + }, + []string{"host", "chain", "node"}), + }, + } + for k := range m.gauges { + prometheus.MustRegister(m.gauges[k]) + } + for k := range m.counters { + prometheus.MustRegister(m.counters[k]) + } + for k := range m.histograms { + prometheus.MustRegister(m.histograms[k]) + } + + return m +} + +func (m *promMetrics) Gauge(name metrics.MetricName, labels metrics.Labels) metrics.Gauge { + v, ok := m.gauges[name] + if !ok { + return nil + } + if labels == nil { + labels = metrics.Labels{} + } + labels["host"] = m.host + return v.With(prometheus.Labels(labels)) +} + +func (m *promMetrics) Counter(name metrics.MetricName, labels metrics.Labels) metrics.Counter { + v, ok := m.counters[name] + if !ok { + return nil + } + if labels == nil { + labels = metrics.Labels{} + } + labels["host"] = m.host + return v.With(prometheus.Labels(labels)) +} + +func (m *promMetrics) Observer(name metrics.MetricName, labels metrics.Labels) metrics.Observer { + v, ok := m.histograms[name] + if !ok { + return nil + } + if labels == nil { + labels = metrics.Labels{} + } + labels["host"] = m.host + return v.With(prometheus.Labels(labels)) +} diff --git a/metrics/service/service.go b/metrics/service/service.go index 5eba0bb..aa90f80 100644 --- a/metrics/service/service.go +++ b/metrics/service/service.go @@ -4,6 +4,7 @@ import ( "net" "net/http" + "github.com/go-gost/core/service" "github.com/prometheus/client_golang/prometheus/promhttp" ) @@ -23,12 +24,12 @@ func PathOption(path string) Option { } } -type Service struct { +type metricService struct { s *http.Server ln net.Listener } -func NewService(addr string, opts ...Option) (*Service, error) { +func NewService(addr string, opts ...Option) (service.Service, error) { ln, err := net.Listen("tcp", addr) if err != nil { return nil, err @@ -44,7 +45,7 @@ func NewService(addr string, opts ...Option) (*Service, error) { mux := http.NewServeMux() mux.Handle(options.path, promhttp.Handler()) - return &Service{ + return &metricService{ s: &http.Server{ Handler: mux, }, @@ -52,14 +53,14 @@ func NewService(addr string, opts ...Option) (*Service, error) { }, nil } -func (s *Service) Serve() error { +func (s *metricService) Serve() error { return s.s.Serve(s.ln) } -func (s *Service) Addr() net.Addr { +func (s *metricService) Addr() net.Addr { return s.ln.Addr() } -func (s *Service) Close() error { +func (s *metricService) Close() error { return s.s.Close() } diff --git a/metrics/wrapper/conn.go b/metrics/wrapper/conn.go index 2a2ace7..dd9bdc7 100644 --- a/metrics/wrapper/conn.go +++ b/metrics/wrapper/conn.go @@ -9,6 +9,7 @@ import ( "github.com/go-gost/core/metrics" xnet "github.com/go-gost/x/internal/net" "github.com/go-gost/x/internal/net/udp" + xmetrics "github.com/go-gost/x/metrics" ) var ( @@ -22,7 +23,7 @@ type serverConn struct { } func WrapConn(service string, c net.Conn) net.Conn { - if !metrics.IsEnabled() { + if !xmetrics.IsEnabled() { return c } return &serverConn{ @@ -33,8 +34,8 @@ func WrapConn(service string, c net.Conn) net.Conn { func (c *serverConn) Read(b []byte) (n int, err error) { n, err = c.Conn.Read(b) - if counter := metrics.GetCounter( - metrics.MetricServiceTransferInputBytesCounter, + if counter := xmetrics.GetCounter( + xmetrics.MetricServiceTransferInputBytesCounter, metrics.Labels{ "service": c.service, }); counter != nil { @@ -45,8 +46,8 @@ func (c *serverConn) Read(b []byte) (n int, err error) { func (c *serverConn) Write(b []byte) (n int, err error) { n, err = c.Conn.Write(b) - if counter := metrics.GetCounter( - metrics.MetricServiceTransferOutputBytesCounter, + if counter := xmetrics.GetCounter( + xmetrics.MetricServiceTransferOutputBytesCounter, metrics.Labels{ "service": c.service, }); counter != nil { @@ -70,7 +71,7 @@ type packetConn struct { } func WrapPacketConn(service string, pc net.PacketConn) net.PacketConn { - if !metrics.IsEnabled() { + if !xmetrics.IsEnabled() { return pc } return &packetConn{ @@ -81,8 +82,8 @@ func WrapPacketConn(service string, pc net.PacketConn) net.PacketConn { func (c *packetConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { n, addr, err = c.PacketConn.ReadFrom(p) - if counter := metrics.GetCounter( - metrics.MetricServiceTransferInputBytesCounter, + if counter := xmetrics.GetCounter( + xmetrics.MetricServiceTransferInputBytesCounter, metrics.Labels{ "service": c.service, }); counter != nil { @@ -93,8 +94,8 @@ func (c *packetConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { func (c *packetConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { n, err = c.PacketConn.WriteTo(p, addr) - if counter := metrics.GetCounter( - metrics.MetricServiceTransferOutputBytesCounter, + if counter := xmetrics.GetCounter( + xmetrics.MetricServiceTransferOutputBytesCounter, metrics.Labels{ "service": c.service, }); counter != nil { @@ -139,8 +140,8 @@ func (c *udpConn) SetWriteBuffer(n int) error { func (c *udpConn) Read(b []byte) (n int, err error) { if nc, ok := c.PacketConn.(io.Reader); ok { n, err = nc.Read(b) - if counter := metrics.GetCounter( - metrics.MetricServiceTransferInputBytesCounter, + if counter := xmetrics.GetCounter( + xmetrics.MetricServiceTransferInputBytesCounter, metrics.Labels{ "service": c.service, }); counter != nil { @@ -154,8 +155,8 @@ func (c *udpConn) Read(b []byte) (n int, err error) { func (c *udpConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { n, addr, err = c.PacketConn.ReadFrom(p) - if counter := metrics.GetCounter( - metrics.MetricServiceTransferInputBytesCounter, + if counter := xmetrics.GetCounter( + xmetrics.MetricServiceTransferInputBytesCounter, metrics.Labels{ "service": c.service, }); counter != nil { @@ -167,8 +168,8 @@ func (c *udpConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { func (c *udpConn) ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error) { if nc, ok := c.PacketConn.(udp.ReadUDP); ok { n, addr, err = nc.ReadFromUDP(b) - if counter := metrics.GetCounter( - metrics.MetricServiceTransferInputBytesCounter, + if counter := xmetrics.GetCounter( + xmetrics.MetricServiceTransferInputBytesCounter, metrics.Labels{ "service": c.service, }); counter != nil { @@ -183,8 +184,8 @@ func (c *udpConn) ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error) { func (c *udpConn) ReadMsgUDP(b, oob []byte) (n, oobn, flags int, addr *net.UDPAddr, err error) { if nc, ok := c.PacketConn.(udp.ReadUDP); ok { n, oobn, flags, addr, err = nc.ReadMsgUDP(b, oob) - if counter := metrics.GetCounter( - metrics.MetricServiceTransferInputBytesCounter, + if counter := xmetrics.GetCounter( + xmetrics.MetricServiceTransferInputBytesCounter, metrics.Labels{ "service": c.service, }); counter != nil { @@ -199,8 +200,8 @@ func (c *udpConn) ReadMsgUDP(b, oob []byte) (n, oobn, flags int, addr *net.UDPAd func (c *udpConn) Write(b []byte) (n int, err error) { if nc, ok := c.PacketConn.(io.Writer); ok { n, err = nc.Write(b) - if counter := metrics.GetCounter( - metrics.MetricServiceTransferOutputBytesCounter, + if counter := xmetrics.GetCounter( + xmetrics.MetricServiceTransferOutputBytesCounter, metrics.Labels{ "service": c.service, }); counter != nil { @@ -214,8 +215,8 @@ func (c *udpConn) Write(b []byte) (n int, err error) { func (c *udpConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { n, err = c.PacketConn.WriteTo(p, addr) - if counter := metrics.GetCounter( - metrics.MetricServiceTransferOutputBytesCounter, + if counter := xmetrics.GetCounter( + xmetrics.MetricServiceTransferOutputBytesCounter, metrics.Labels{ "service": c.service, }); counter != nil { @@ -227,8 +228,8 @@ func (c *udpConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { func (c *udpConn) WriteToUDP(b []byte, addr *net.UDPAddr) (n int, err error) { if nc, ok := c.PacketConn.(udp.WriteUDP); ok { n, err = nc.WriteToUDP(b, addr) - if counter := metrics.GetCounter( - metrics.MetricServiceTransferOutputBytesCounter, + if counter := xmetrics.GetCounter( + xmetrics.MetricServiceTransferOutputBytesCounter, metrics.Labels{ "service": c.service, }); counter != nil { @@ -243,8 +244,8 @@ func (c *udpConn) WriteToUDP(b []byte, addr *net.UDPAddr) (n int, err error) { func (c *udpConn) WriteMsgUDP(b, oob []byte, addr *net.UDPAddr) (n, oobn int, err error) { if nc, ok := c.PacketConn.(udp.WriteUDP); ok { n, oobn, err = nc.WriteMsgUDP(b, oob, addr) - if counter := metrics.GetCounter( - metrics.MetricServiceTransferOutputBytesCounter, + if counter := xmetrics.GetCounter( + xmetrics.MetricServiceTransferOutputBytesCounter, metrics.Labels{ "service": c.service, }); counter != nil { diff --git a/metrics/wrapper/listener.go b/metrics/wrapper/listener.go index f663cb8..eb339fa 100644 --- a/metrics/wrapper/listener.go +++ b/metrics/wrapper/listener.go @@ -3,7 +3,7 @@ package wrapper import ( "net" - "github.com/go-gost/core/metrics" + xmetrics "github.com/go-gost/x/metrics" ) type listener struct { @@ -12,7 +12,7 @@ type listener struct { } func WrapListener(service string, ln net.Listener) net.Listener { - if !metrics.IsEnabled() { + if !xmetrics.IsEnabled() { return ln } diff --git a/selector/strategy.go b/selector/strategy.go index bba51f1..f8b9e11 100644 --- a/selector/strategy.go +++ b/selector/strategy.go @@ -2,11 +2,13 @@ package selector import ( "context" + "hash/crc32" "math/rand" "sync" "sync/atomic" "time" + "github.com/go-gost/core/logger" "github.com/go-gost/core/metadata" mdutil "github.com/go-gost/core/metadata/util" "github.com/go-gost/core/selector" @@ -101,7 +103,9 @@ func (s *hashStrategy[T]) Apply(ctx context.Context, vs ...T) (v T) { return } if h := sx.HashFromContext(ctx); h != nil { - return vs[h.Value%len(vs)] + value := uint64(crc32.ChecksumIEEE([]byte(h.Source))) + logger.Default().Tracef("hash %s %d", h.Source, value) + return vs[value%uint64(len(vs))] } s.mu.Lock() diff --git a/service/service.go b/service/service.go new file mode 100644 index 0000000..725994d --- /dev/null +++ b/service/service.go @@ -0,0 +1,169 @@ +package service + +import ( + "context" + "hash/crc32" + "net" + "time" + + "github.com/go-gost/core/admission" + "github.com/go-gost/core/handler" + "github.com/go-gost/core/listener" + "github.com/go-gost/core/logger" + "github.com/go-gost/core/metrics" + "github.com/go-gost/core/recorder" + "github.com/go-gost/core/service" + sx "github.com/go-gost/x/internal/util/selector" + xmetrics "github.com/go-gost/x/metrics" +) + +type options struct { + admission admission.Admission + recorders []recorder.RecorderObject + logger logger.Logger +} + +type Option func(opts *options) + +func AdmissionOption(admission admission.Admission) Option { + return func(opts *options) { + opts.admission = admission + } +} + +func RecordersOption(recorders ...recorder.RecorderObject) Option { + return func(opts *options) { + opts.recorders = recorders + } +} + +func LoggerOption(logger logger.Logger) Option { + return func(opts *options) { + opts.logger = logger + } +} + +type defaultService struct { + name string + listener listener.Listener + handler handler.Handler + options options +} + +func NewService(name string, ln listener.Listener, h handler.Handler, opts ...Option) service.Service { + var options options + for _, opt := range opts { + opt(&options) + } + return &defaultService{ + name: name, + listener: ln, + handler: h, + options: options, + } +} + +func (s *defaultService) Addr() net.Addr { + return s.listener.Addr() +} + +func (s *defaultService) Close() error { + return s.listener.Close() +} + +func (s *defaultService) Serve() error { + if v := xmetrics.GetGauge( + xmetrics.MetricServicesGauge, + metrics.Labels{}); v != nil { + v.Inc() + defer v.Dec() + } + + var tempDelay time.Duration + for { + conn, e := s.listener.Accept() + if e != nil { + // TODO: remove Temporary checking + if ne, ok := e.(net.Error); ok && ne.Temporary() { + if tempDelay == 0 { + tempDelay = 1 * time.Second + } else { + tempDelay *= 2 + } + if max := 5 * time.Second; tempDelay > max { + tempDelay = max + } + s.options.logger.Warnf("accept: %v, retrying in %v", e, tempDelay) + time.Sleep(tempDelay) + continue + } + s.options.logger.Errorf("accept: %v", e) + return e + } + tempDelay = 0 + + for _, rec := range s.options.recorders { + host := conn.RemoteAddr().String() + if h, _, _ := net.SplitHostPort(host); h != "" { + host = h + } + if rec.Record == recorder.RecorderServiceClientAddress { + if err := rec.Recorder.Record(context.Background(), []byte(host)); err != nil { + s.options.logger.Errorf("record %s: %v", rec.Record, err) + } + } + } + if s.options.admission != nil && + !s.options.admission.Admit(conn.RemoteAddr().String()) { + conn.Close() + s.options.logger.Debugf("admission: %s is denied", conn.RemoteAddr()) + continue + } + + go func() { + if v := xmetrics.GetCounter(xmetrics.MetricServiceRequestsCounter, + metrics.Labels{"service": s.name}); v != nil { + v.Inc() + } + + if v := xmetrics.GetGauge(xmetrics.MetricServiceRequestsInFlightGauge, + metrics.Labels{"service": s.name}); v != nil { + v.Inc() + defer v.Dec() + } + + start := time.Now() + if v := xmetrics.GetObserver(xmetrics.MetricServiceRequestsDurationObserver, + metrics.Labels{"service": s.name}); v != nil { + defer func() { + v.Observe(float64(time.Since(start).Seconds())) + }() + } + + ctx := sx.ContextWithHash(context.Background(), &sx.Hash{ + Source: conn.RemoteAddr().String(), + Value: ipHash(conn.RemoteAddr()), + }) + + if err := s.handler.Handle(ctx, conn); err != nil { + s.options.logger.Error(err) + if v := xmetrics.GetCounter(xmetrics.MetricServiceHandlerErrorsCounter, + metrics.Labels{"service": s.name}); v != nil { + v.Inc() + } + } + }() + } +} + +func ipHash(addr net.Addr) uint64 { + if addr == nil { + return 0 + } + + host, _, _ := net.SplitHostPort(addr.String()) + if ip := net.ParseIP(host); ip != nil { + return uint64(crc32.ChecksumIEEE(ip.To16())) + } + return 0 +}