From 9b695bc37416777836bf129457a5995f7b86904a Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Fri, 2 Sep 2022 10:57:40 +0800 Subject: [PATCH] add chain group --- config/config.go | 39 +++++++++++++++---------- config/parsing/chain.go | 27 ++++++++++-------- config/parsing/parse.go | 36 ++++++++++++++++++----- config/parsing/service.go | 47 ++++++++++++++++++------------- handler/dns/handler.go | 6 +--- handler/forward/local/handler.go | 8 ++++-- handler/forward/remote/handler.go | 8 ++++-- handler/relay/forward.go | 8 ++++-- registry/chain.go | 25 +++++++++++++--- registry/registry.go | 18 ++++++------ 10 files changed, 145 insertions(+), 77 deletions(-) diff --git a/config/config.go b/config/config.go index e8326ed..19a25ff 100644 --- a/config/config.go +++ b/config/config.go @@ -185,24 +185,26 @@ type RecorderObject struct { } type ListenerConfig struct { - Type string `json:"type"` - Chain string `yaml:",omitempty" json:"chain,omitempty"` - Auther string `yaml:",omitempty" json:"auther,omitempty"` - Authers []string `yaml:",omitempty" json:"authers,omitempty"` - Auth *AuthConfig `yaml:",omitempty" json:"auth,omitempty"` - TLS *TLSConfig `yaml:",omitempty" json:"tls,omitempty"` - Metadata map[string]any `yaml:",omitempty" json:"metadata,omitempty"` + Type string `json:"type"` + Chain string `yaml:",omitempty" json:"chain,omitempty"` + ChainGroup *ChainGroupConfig `yaml:"chainGroup,omitempty" json:"chainGroup,omitempty"` + Auther string `yaml:",omitempty" json:"auther,omitempty"` + Authers []string `yaml:",omitempty" json:"authers,omitempty"` + Auth *AuthConfig `yaml:",omitempty" json:"auth,omitempty"` + TLS *TLSConfig `yaml:",omitempty" json:"tls,omitempty"` + Metadata map[string]any `yaml:",omitempty" json:"metadata,omitempty"` } type HandlerConfig struct { - Type string `json:"type"` - Retries int `yaml:",omitempty" json:"retries,omitempty"` - Chain string `yaml:",omitempty" json:"chain,omitempty"` - Auther string `yaml:",omitempty" json:"auther,omitempty"` - Authers []string `yaml:",omitempty" json:"authers,omitempty"` - Auth *AuthConfig `yaml:",omitempty" json:"auth,omitempty"` - TLS *TLSConfig `yaml:",omitempty" json:"tls,omitempty"` - Metadata map[string]any `yaml:",omitempty" json:"metadata,omitempty"` + Type string `json:"type"` + Retries int `yaml:",omitempty" json:"retries,omitempty"` + Chain string `yaml:",omitempty" json:"chain,omitempty"` + ChainGroup *ChainGroupConfig `yaml:"chainGroup,omitempty" json:"chainGroup,omitempty"` + Auther string `yaml:",omitempty" json:"auther,omitempty"` + Authers []string `yaml:",omitempty" json:"authers,omitempty"` + Auth *AuthConfig `yaml:",omitempty" json:"auth,omitempty"` + TLS *TLSConfig `yaml:",omitempty" json:"tls,omitempty"` + Metadata map[string]any `yaml:",omitempty" json:"metadata,omitempty"` } type ForwarderConfig struct { @@ -251,6 +253,12 @@ type ChainConfig struct { Name string `json:"name"` Selector *SelectorConfig `yaml:",omitempty" json:"selector,omitempty"` Hops []*HopConfig `json:"hops"` + Metadata map[string]any `yaml:",omitempty", json:"metadata,omitempty"` +} + +type ChainGroupConfig struct { + Chains []string `yaml:",omitempty" json:"chains,omitempty"` + Selector *SelectorConfig `yaml:",omitempty" json:"selector,omitempty"` } type HopConfig struct { @@ -276,6 +284,7 @@ type NodeConfig struct { Hosts string `yaml:",omitempty" json:"hosts,omitempty"` Connector *ConnectorConfig `yaml:",omitempty" json:"connector,omitempty"` Dialer *DialerConfig `yaml:",omitempty" json:"dialer,omitempty"` + Metadata map[string]any `yaml:",omitempty", json:"metadata,omitempty"` } type Config struct { diff --git a/config/parsing/chain.go b/config/parsing/chain.go index 974352e..5a53806 100644 --- a/config/parsing/chain.go +++ b/config/parsing/chain.go @@ -12,7 +12,7 @@ import ( "github.com/go-gost/x/registry" ) -func ParseChain(cfg *config.ChainConfig) (chain.Chainer, error) { +func ParseChain(cfg *config.ChainConfig) (chain.SelectableChainer, error) { if cfg == nil { return nil, nil } @@ -23,7 +23,11 @@ func ParseChain(cfg *config.ChainConfig) (chain.Chainer, error) { }) c := chain.NewChain(cfg.Name) - selector := parseSelector(cfg.Selector) + if cfg.Metadata != nil { + c.WithMetadata(metadata.NewMetadata(cfg.Metadata)) + } + + selector := parseNodeSelector(cfg.Selector) for _, hop := range cfg.Hops { group := &chain.NodeGroup{} for _, v := range hop.Nodes { @@ -121,24 +125,23 @@ func ParseChain(cfg *config.ChainConfig) (chain.Chainer, error) { WithInterface(v.Interface). WithSockOpts(sockOpts) - node := &chain.Node{ - Name: v.Name, - Addr: v.Addr, - Bypass: bypass.BypassList(bypassList(v.Bypass, v.Bypasses...)...), - Resolver: registry.ResolverRegistry().Get(v.Resolver), - Hosts: registry.HostsRegistry().Get(v.Hosts), - Marker: &chain.FailMarker{}, - Transport: tr, + node := chain.NewNode(v.Name, v.Addr). + WithTransport(tr). + WithBypass(bypass.BypassGroup(bypassList(v.Bypass, v.Bypasses...)...)). + WithResolver(registry.ResolverRegistry().Get(v.Resolver)). + WithHostMapper(registry.HostsRegistry().Get(v.Hosts)) + if v.Metadata != nil { + node.WithMetadata(metadata.NewMetadata(v.Metadata)) } group.AddNode(node) } sel := selector - if s := parseSelector(hop.Selector); s != nil { + if s := parseNodeSelector(hop.Selector); s != nil { sel = s } group.WithSelector(sel). - WithBypass(bypass.BypassList(bypassList(hop.Bypass, hop.Bypasses...)...)) + WithBypass(bypass.BypassGroup(bypassList(hop.Bypass, hop.Bypasses...)...)) c.AddNodeGroup(group) } diff --git a/config/parsing/parse.go b/config/parsing/parse.go index 3c4b6d6..44e24b2 100644 --- a/config/parsing/parse.go +++ b/config/parsing/parse.go @@ -83,26 +83,48 @@ func parseAuth(cfg *config.AuthConfig) *url.Userinfo { return url.UserPassword(cfg.Username, cfg.Password) } -func parseSelector(cfg *config.SelectorConfig) chain.Selector { +func parseChainSelector(cfg *config.SelectorConfig) chain.Selector[chain.SelectableChainer] { if cfg == nil { return nil } - var strategy chain.Strategy + var strategy chain.Strategy[chain.SelectableChainer] switch cfg.Strategy { case "round", "rr": - strategy = chain.RoundRobinStrategy() + strategy = chain.RoundRobinStrategy[chain.SelectableChainer]() case "random", "rand": - strategy = chain.RandomStrategy() + strategy = chain.RandomStrategy[chain.SelectableChainer]() case "fifo", "ha": - strategy = chain.FIFOStrategy() + strategy = chain.FIFOStrategy[chain.SelectableChainer]() default: - strategy = chain.RoundRobinStrategy() + strategy = chain.RoundRobinStrategy[chain.SelectableChainer]() + } + return chain.NewSelector( + strategy, + chain.FailFilter[chain.SelectableChainer](cfg.MaxFails, cfg.FailTimeout), + ) +} + +func parseNodeSelector(cfg *config.SelectorConfig) chain.Selector[*chain.Node] { + if cfg == nil { + return nil + } + + var strategy chain.Strategy[*chain.Node] + switch cfg.Strategy { + case "round", "rr": + strategy = chain.RoundRobinStrategy[*chain.Node]() + case "random", "rand": + strategy = chain.RandomStrategy[*chain.Node]() + case "fifo", "ha": + strategy = chain.FIFOStrategy[*chain.Node]() + default: + strategy = chain.RoundRobinStrategy[*chain.Node]() } return chain.NewSelector( strategy, - chain.FailFilter(cfg.MaxFails, cfg.FailTimeout), + chain.FailFilter[*chain.Node](cfg.MaxFails, cfg.FailTimeout), ) } diff --git a/config/parsing/service.go b/config/parsing/service.go index 1b4f0e5..593436d 100644 --- a/config/parsing/service.go +++ b/config/parsing/service.go @@ -65,11 +65,11 @@ func ParseService(cfg *config.ServiceConfig) (service.Service, error) { ln := registry.ListenerRegistry().Get(cfg.Listener.Type)( listener.AddrOption(cfg.Addr), - listener.AutherOption(auth.AuthenticatorList(authers...)), + listener.AutherOption(auth.AuthenticatorGroup(authers...)), listener.AuthOption(parseAuth(cfg.Listener.Auth)), listener.TLSConfigOption(tlsConfig), - listener.AdmissionOption(admission.AdmissionList(admissions...)), - listener.ChainOption(registry.ChainRegistry().Get(cfg.Listener.Chain)), + listener.AdmissionOption(admission.AdmissionGroup(admissions...)), + listener.ChainOption(chainGroup(cfg.Listener.Chain, cfg.Listener.ChainGroup)), listener.LoggerOption(listenerLogger), listener.ServiceOption(cfg.Name), ) @@ -126,7 +126,7 @@ func ParseService(cfg *config.ServiceConfig) (service.Service, error) { // WithTimeout(timeout time.Duration). WithInterface(cfg.Interface). WithSockOpts(sockOpts). - WithChain(registry.ChainRegistry().Get(cfg.Handler.Chain)). + WithChain(chainGroup(cfg.Handler.Chain, cfg.Handler.ChainGroup)). WithResolver(registry.ResolverRegistry().Get(cfg.Resolver)). WithHosts(registry.HostsRegistry().Get(cfg.Hosts)). WithRecorder(recorders...). @@ -134,9 +134,9 @@ func ParseService(cfg *config.ServiceConfig) (service.Service, error) { h := registry.HandlerRegistry().Get(cfg.Handler.Type)( handler.RouterOption(router), - handler.AutherOption(auth.AuthenticatorList(authers...)), + handler.AutherOption(auth.AuthenticatorGroup(authers...)), handler.AuthOption(parseAuth(cfg.Handler.Auth)), - handler.BypassOption(bypass.BypassList(bypassList(cfg.Bypass, cfg.Bypasses...)...)), + handler.BypassOption(bypass.BypassGroup(bypassList(cfg.Bypass, cfg.Bypasses...)...)), handler.TLSConfigOption(tlsConfig), handler.LoggerOption(handlerLogger), ) @@ -154,7 +154,7 @@ func ParseService(cfg *config.ServiceConfig) (service.Service, error) { } s := service.NewService(cfg.Name, ln, h, - service.AdmissionOption(admission.AdmissionList(admissions...)), + service.AdmissionOption(admission.AdmissionGroup(admissions...)), service.LoggerOption(serviceLogger), ) @@ -172,27 +172,19 @@ func parseForwarder(cfg *config.ForwarderConfig) *chain.NodeGroup { if len(cfg.Nodes) > 0 { for _, node := range cfg.Nodes { if node != nil { - group.AddNode(&chain.Node{ - Name: node.Name, - Addr: node.Addr, - Bypass: bypass.BypassList(bypassList(node.Bypass, node.Bypasses...)...), - Marker: &chain.FailMarker{}, - }) + group.AddNode(chain.NewNode(node.Name, node.Addr). + WithBypass(bypass.BypassGroup(bypassList(node.Bypass, node.Bypasses...)...))) } } } else { for _, target := range cfg.Targets { if v := strings.TrimSpace(target); v != "" { - group.AddNode(&chain.Node{ - Name: target, - Addr: target, - Marker: &chain.FailMarker{}, - }) + group.AddNode(chain.NewNode(target, target)) } } } - return group.WithSelector(parseSelector(cfg.Selector)) + return group.WithSelector(parseNodeSelector(cfg.Selector)) } func bypassList(name string, names ...string) []bypass.Bypass { @@ -234,3 +226,20 @@ func admissionList(name string, names ...string) []admission.Admission { return admissions } + +func chainGroup(name string, group *config.ChainGroupConfig) chain.Chainer { + cg := &chain.ChainGroup{} + if c := registry.ChainRegistry().Get(name); c != nil { + cg.Chains = append(cg.Chains, c) + } + if group != nil { + for _, s := range group.Chains { + if c := registry.ChainRegistry().Get(s); c != nil { + cg.Chains = append(cg.Chains, c) + } + } + cg.Selector = parseChainSelector(group.Selector) + } + + return cg +} diff --git a/handler/dns/handler.go b/handler/dns/handler.go index 661f8cf..c993c16 100644 --- a/handler/dns/handler.go +++ b/handler/dns/handler.go @@ -71,11 +71,7 @@ func (h *dnsHandler) Init(md md.Metadata) (err error) { if addr == "" { continue } - h.group.AddNode(&chain.Node{ - Name: fmt.Sprintf("target-%d", i), - Addr: addr, - Marker: &chain.FailMarker{}, - }) + h.group.AddNode(chain.NewNode(fmt.Sprintf("target-%d", i), addr)) } } for _, node := range h.group.Nodes() { diff --git a/handler/forward/local/handler.go b/handler/forward/local/handler.go index 660c888..cebad15 100644 --- a/handler/forward/local/handler.go +++ b/handler/forward/local/handler.go @@ -100,11 +100,15 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand log.Error(err) // TODO: the router itself may be failed due to the failed node in the router, // the dead marker may be a wrong operation. - target.Marker.Mark() + if marker := target.Marker(); marker != nil { + marker.Mark() + } return err } defer cc.Close() - target.Marker.Reset() + if marker := target.Marker(); marker != nil { + marker.Reset() + } t := time.Now() log.Debugf("%s <-> %s", conn.RemoteAddr(), target.Addr) diff --git a/handler/forward/remote/handler.go b/handler/forward/remote/handler.go index e496b0e..3cc76cb 100644 --- a/handler/forward/remote/handler.go +++ b/handler/forward/remote/handler.go @@ -94,11 +94,15 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand log.Error(err) // TODO: the router itself may be failed due to the failed node in the router, // the dead marker may be a wrong operation. - target.Marker.Mark() + if marker := target.Marker(); marker != nil { + marker.Mark() + } return err } defer cc.Close() - target.Marker.Reset() + if marker := target.Marker(); marker != nil { + marker.Reset() + } t := time.Now() log.Debugf("%s <-> %s", conn.RemoteAddr(), target.Addr) diff --git a/handler/relay/forward.go b/handler/relay/forward.go index b374d86..aa60da7 100644 --- a/handler/relay/forward.go +++ b/handler/relay/forward.go @@ -37,7 +37,9 @@ func (h *relayHandler) handleForward(ctx context.Context, conn net.Conn, network if err != nil { // TODO: the router itself may be failed due to the failed node in the router, // the dead marker may be a wrong operation. - target.Marker.Mark() + if marker := target.Marker(); marker != nil { + marker.Mark() + } resp.Status = relay.StatusHostUnreachable resp.WriteTo(conn) @@ -46,7 +48,9 @@ func (h *relayHandler) handleForward(ctx context.Context, conn net.Conn, network return err } defer cc.Close() - target.Marker.Reset() + if marker := target.Marker(); marker != nil { + marker.Reset() + } if h.md.noDelay { if _, err := resp.WriteTo(conn); err != nil { diff --git a/registry/chain.go b/registry/chain.go index 36d2465..c71c7eb 100644 --- a/registry/chain.go +++ b/registry/chain.go @@ -2,26 +2,27 @@ package registry import ( "github.com/go-gost/core/chain" + "github.com/go-gost/core/metadata" ) type chainRegistry struct { registry } -func (r *chainRegistry) Register(name string, v chain.Chainer) error { +func (r *chainRegistry) Register(name string, v chain.SelectableChainer) error { return r.registry.Register(name, v) } -func (r *chainRegistry) Get(name string) chain.Chainer { +func (r *chainRegistry) Get(name string) chain.SelectableChainer { if name != "" { return &chainWrapper{name: name, r: r} } return nil } -func (r *chainRegistry) get(name string) chain.Chainer { +func (r *chainRegistry) get(name string) chain.SelectableChainer { if v := r.registry.Get(name); v != nil { - return v.(chain.Chainer) + return v.(chain.SelectableChainer) } return nil } @@ -31,6 +32,22 @@ type chainWrapper struct { r *chainRegistry } +func (w *chainWrapper) Marker() chain.Marker { + v := w.r.get(w.name) + if v == nil { + return nil + } + return v.Marker() +} + +func (w *chainWrapper) Metadata() metadata.Metadata { + v := w.r.get(w.name) + if v == nil { + return nil + } + return v.Metadata() +} + func (w *chainWrapper) Route(network, address string) *chain.Route { v := w.r.get(w.name) if v == nil { diff --git a/registry/registry.go b/registry/registry.go index f9aa5b4..1ca93e6 100644 --- a/registry/registry.go +++ b/registry/registry.go @@ -25,14 +25,14 @@ var ( dialerReg Registry[NewDialer] = &dialerRegistry{} connectorReg Registry[NewConnector] = &connectorRegistry{} - serviceReg Registry[service.Service] = &serviceRegistry{} - chainReg Registry[chain.Chainer] = &chainRegistry{} - autherReg Registry[auth.Authenticator] = &autherRegistry{} - admissionReg Registry[admission.Admission] = &admissionRegistry{} - bypassReg Registry[bypass.Bypass] = &bypassRegistry{} - resolverReg Registry[resolver.Resolver] = &resolverRegistry{} - hostsReg Registry[hosts.HostMapper] = &hostsRegistry{} - recorderReg Registry[recorder.Recorder] = &recorderRegistry{} + serviceReg Registry[service.Service] = &serviceRegistry{} + chainReg Registry[chain.SelectableChainer] = &chainRegistry{} + autherReg Registry[auth.Authenticator] = &autherRegistry{} + admissionReg Registry[admission.Admission] = &admissionRegistry{} + bypassReg Registry[bypass.Bypass] = &bypassRegistry{} + resolverReg Registry[resolver.Resolver] = &resolverRegistry{} + hostsReg Registry[hosts.HostMapper] = &hostsRegistry{} + recorderReg Registry[recorder.Recorder] = &recorderRegistry{} ) type Registry[T any] interface { @@ -99,7 +99,7 @@ func ServiceRegistry() Registry[service.Service] { return serviceReg } -func ChainRegistry() Registry[chain.Chainer] { +func ChainRegistry() Registry[chain.SelectableChainer] { return chainReg }