diff --git a/registry/admission.go b/registry/admission.go index 5168f59..3c54eb4 100644 --- a/registry/admission.go +++ b/registry/admission.go @@ -20,10 +20,7 @@ func (r *admissionRegistry) Get(name string) admission.Admission { } func (r *admissionRegistry) get(name string) admission.Admission { - if v := r.registry.Get(name); v != nil { - return v.(admission.Admission) - } - return nil + return r.registry.Get(name) } type admissionWrapper struct { diff --git a/registry/auther.go b/registry/auther.go index 826104d..208e59f 100644 --- a/registry/auther.go +++ b/registry/auther.go @@ -20,10 +20,7 @@ func (r *autherRegistry) Get(name string) auth.Authenticator { } func (r *autherRegistry) get(name string) auth.Authenticator { - if v := r.registry.Get(name); v != nil { - return v.(auth.Authenticator) - } - return nil + return r.registry.Get(name) } type autherWrapper struct { diff --git a/registry/bypass.go b/registry/bypass.go index 5699e44..4f48aa4 100644 --- a/registry/bypass.go +++ b/registry/bypass.go @@ -20,10 +20,7 @@ func (r *bypassRegistry) Get(name string) bypass.Bypass { } func (r *bypassRegistry) get(name string) bypass.Bypass { - if v := r.registry.Get(name); v != nil { - return v.(bypass.Bypass) - } - return nil + return r.registry.Get(name) } type bypassWrapper struct { diff --git a/registry/chain.go b/registry/chain.go index b1b27aa..a969f44 100644 --- a/registry/chain.go +++ b/registry/chain.go @@ -24,10 +24,7 @@ func (r *chainRegistry) Get(name string) chain.Chainer { } func (r *chainRegistry) get(name string) chain.Chainer { - if v := r.registry.Get(name); v != nil { - return v.(chain.Chainer) - } - return nil + return r.registry.Get(name) } type chainWrapper struct { diff --git a/registry/hop.go b/registry/hop.go index 52c3ce1..14e4bb4 100644 --- a/registry/hop.go +++ b/registry/hop.go @@ -22,10 +22,7 @@ func (r *hopRegistry) Get(name string) chain.Hop { } func (r *hopRegistry) get(name string) chain.Hop { - if v := r.registry.Get(name); v != nil { - return v.(chain.Hop) - } - return nil + return r.registry.Get(name) } type hopWrapper struct { diff --git a/registry/hosts.go b/registry/hosts.go index ab71310..a4e7093 100644 --- a/registry/hosts.go +++ b/registry/hosts.go @@ -22,10 +22,7 @@ func (r *hostsRegistry) Get(name string) hosts.HostMapper { } func (r *hostsRegistry) get(name string) hosts.HostMapper { - if v := r.registry.Get(name); v != nil { - return v.(hosts.HostMapper) - } - return nil + return r.registry.Get(name) } type hostsWrapper struct { diff --git a/registry/limiter.go b/registry/limiter.go index 218de80..4ade180 100644 --- a/registry/limiter.go +++ b/registry/limiter.go @@ -22,10 +22,7 @@ func (r *trafficLimiterRegistry) Get(name string) traffic.TrafficLimiter { } func (r *trafficLimiterRegistry) get(name string) traffic.TrafficLimiter { - if v := r.registry.Get(name); v != nil { - return v.(traffic.TrafficLimiter) - } - return nil + return r.registry.Get(name) } type trafficLimiterWrapper struct { @@ -65,10 +62,7 @@ func (r *connLimiterRegistry) Get(name string) conn.ConnLimiter { } func (r *connLimiterRegistry) get(name string) conn.ConnLimiter { - if v := r.registry.Get(name); v != nil { - return v.(conn.ConnLimiter) - } - return nil + return r.registry.Get(name) } type connLimiterWrapper struct { @@ -100,10 +94,7 @@ func (r *rateLimiterRegistry) Get(name string) rate.RateLimiter { } func (r *rateLimiterRegistry) get(name string) rate.RateLimiter { - if v := r.registry.Get(name); v != nil { - return v.(rate.RateLimiter) - } - return nil + return r.registry.Get(name) } type rateLimiterWrapper struct { diff --git a/registry/recorder.go b/registry/recorder.go index 0871255..27b28f9 100644 --- a/registry/recorder.go +++ b/registry/recorder.go @@ -22,10 +22,7 @@ func (r *recorderRegistry) Get(name string) recorder.Recorder { } func (r *recorderRegistry) get(name string) recorder.Recorder { - if v := r.registry.Get(name); v != nil { - return v.(recorder.Recorder) - } - return nil + return r.registry.Get(name) } type recorderWrapper struct { diff --git a/registry/registry.go b/registry/registry.go index 206f45d..7978dc4 100644 --- a/registry/registry.go +++ b/registry/registry.go @@ -14,6 +14,7 @@ import ( "github.com/go-gost/core/limiter/rate" "github.com/go-gost/core/limiter/traffic" "github.com/go-gost/core/recorder" + reg "github.com/go-gost/core/registry" "github.com/go-gost/core/resolver" "github.com/go-gost/core/service" ) @@ -23,34 +24,25 @@ var ( ) var ( - listenerReg Registry[NewListener] = new(listenerRegistry) - handlerReg Registry[NewHandler] = new(handlerRegistry) - dialerReg Registry[NewDialer] = new(dialerRegistry) - connectorReg Registry[NewConnector] = new(connectorRegistry) + listenerReg reg.Registry[NewListener] = new(listenerRegistry) + handlerReg reg.Registry[NewHandler] = new(handlerRegistry) + dialerReg reg.Registry[NewDialer] = new(dialerRegistry) + connectorReg reg.Registry[NewConnector] = new(connectorRegistry) + serviceReg reg.Registry[service.Service] = new(serviceRegistry) + chainReg reg.Registry[chain.Chainer] = new(chainRegistry) + hopReg reg.Registry[chain.Hop] = new(hopRegistry) + autherReg reg.Registry[auth.Authenticator] = new(autherRegistry) + admissionReg reg.Registry[admission.Admission] = new(admissionRegistry) + bypassReg reg.Registry[bypass.Bypass] = new(bypassRegistry) + resolverReg reg.Registry[resolver.Resolver] = new(resolverRegistry) + hostsReg reg.Registry[hosts.HostMapper] = new(hostsRegistry) + recorderReg reg.Registry[recorder.Recorder] = new(recorderRegistry) - serviceReg Registry[service.Service] = new(serviceRegistry) - chainReg Registry[chain.Chainer] = new(chainRegistry) - hopReg Registry[chain.Hop] = new(hopRegistry) - autherReg Registry[auth.Authenticator] = new(autherRegistry) - admissionReg Registry[admission.Admission] = new(admissionRegistry) - bypassReg Registry[bypass.Bypass] = new(bypassRegistry) - resolverReg Registry[resolver.Resolver] = new(resolverRegistry) - hostsReg Registry[hosts.HostMapper] = new(hostsRegistry) - recorderReg Registry[recorder.Recorder] = new(recorderRegistry) - - trafficLimiterReg Registry[traffic.TrafficLimiter] = new(trafficLimiterRegistry) - connLimiterReg Registry[conn.ConnLimiter] = new(connLimiterRegistry) - rateLimiterReg Registry[rate.RateLimiter] = new(rateLimiterRegistry) + trafficLimiterReg reg.Registry[traffic.TrafficLimiter] = new(trafficLimiterRegistry) + connLimiterReg reg.Registry[conn.ConnLimiter] = new(connLimiterRegistry) + rateLimiterReg reg.Registry[rate.RateLimiter] = new(rateLimiterRegistry) ) -type Registry[T any] interface { - Register(name string, v T) error - Unregister(name string) - IsRegistered(name string) bool - Get(name string) T - GetAll() map[string]T -} - type registry[T any] struct { m sync.Map } @@ -100,66 +92,66 @@ func (r *registry[T]) GetAll() (m map[string]T) { return } -func ListenerRegistry() Registry[NewListener] { +func ListenerRegistry() reg.Registry[NewListener] { return listenerReg } -func HandlerRegistry() Registry[NewHandler] { +func HandlerRegistry() reg.Registry[NewHandler] { return handlerReg } -func DialerRegistry() Registry[NewDialer] { +func DialerRegistry() reg.Registry[NewDialer] { return dialerReg } -func ConnectorRegistry() Registry[NewConnector] { +func ConnectorRegistry() reg.Registry[NewConnector] { return connectorReg } -func ServiceRegistry() Registry[service.Service] { +func ServiceRegistry() reg.Registry[service.Service] { return serviceReg } -func ChainRegistry() Registry[chain.Chainer] { +func ChainRegistry() reg.Registry[chain.Chainer] { return chainReg } -func HopRegistry() Registry[chain.Hop] { +func HopRegistry() reg.Registry[chain.Hop] { return hopReg } -func AutherRegistry() Registry[auth.Authenticator] { +func AutherRegistry() reg.Registry[auth.Authenticator] { return autherReg } -func AdmissionRegistry() Registry[admission.Admission] { +func AdmissionRegistry() reg.Registry[admission.Admission] { return admissionReg } -func BypassRegistry() Registry[bypass.Bypass] { +func BypassRegistry() reg.Registry[bypass.Bypass] { return bypassReg } -func ResolverRegistry() Registry[resolver.Resolver] { +func ResolverRegistry() reg.Registry[resolver.Resolver] { return resolverReg } -func HostsRegistry() Registry[hosts.HostMapper] { +func HostsRegistry() reg.Registry[hosts.HostMapper] { return hostsReg } -func RecorderRegistry() Registry[recorder.Recorder] { +func RecorderRegistry() reg.Registry[recorder.Recorder] { return recorderReg } -func TrafficLimiterRegistry() Registry[traffic.TrafficLimiter] { +func TrafficLimiterRegistry() reg.Registry[traffic.TrafficLimiter] { return trafficLimiterReg } -func ConnLimiterRegistry() Registry[conn.ConnLimiter] { +func ConnLimiterRegistry() reg.Registry[conn.ConnLimiter] { return connLimiterReg } -func RateLimiterRegistry() Registry[rate.RateLimiter] { +func RateLimiterRegistry() reg.Registry[rate.RateLimiter] { return rateLimiterReg } diff --git a/registry/resolver.go b/registry/resolver.go index 1de77d9..dff7e26 100644 --- a/registry/resolver.go +++ b/registry/resolver.go @@ -23,10 +23,7 @@ func (r *resolverRegistry) Get(name string) resolver.Resolver { } func (r *resolverRegistry) get(name string) resolver.Resolver { - if v := r.registry.Get(name); v != nil { - return v.(resolver.Resolver) - } - return nil + return r.registry.Get(name) } type resolverWrapper struct {