diff --git a/config/config.go b/config/config.go index c163eca..e8326ed 100644 --- a/config/config.go +++ b/config/config.go @@ -206,7 +206,9 @@ type HandlerConfig struct { } type ForwarderConfig struct { - Targets []string `json:"targets"` + // DEPRECATED by nodes since beta.4 + Targets []string `yaml:",omitempty" json:"targets,omitempty"` + Nodes []*NodeConfig `json:"nodes"` Selector *SelectorConfig `yaml:",omitempty" json:"selector,omitempty"` } diff --git a/config/parsing/chain.go b/config/parsing/chain.go index d899bfb..974352e 100644 --- a/config/parsing/chain.go +++ b/config/parsing/chain.go @@ -94,12 +94,6 @@ func ParseChain(cfg *config.ChainConfig) (chain.Chainer, error) { return nil, err } - if v.Bypass == "" { - v.Bypass = hop.Bypass - } - if v.Bypasses == nil { - v.Bypasses = hop.Bypasses - } if v.Resolver == "" { v.Resolver = hop.Resolver } @@ -127,20 +121,10 @@ func ParseChain(cfg *config.ChainConfig) (chain.Chainer, error) { WithInterface(v.Interface). WithSockOpts(sockOpts) - var bypasses []bypass.Bypass - if bp := registry.BypassRegistry().Get(v.Bypass); bp != nil { - bypasses = append(bypasses, bp) - } - for _, s := range v.Bypasses { - if bp := registry.BypassRegistry().Get(s); bp != nil { - bypasses = append(bypasses, bp) - } - } - node := &chain.Node{ Name: v.Name, Addr: v.Addr, - Bypass: bypass.BypassList(bypasses...), + Bypass: bypass.BypassList(bypassList(v.Bypass, v.Bypasses...)...), Resolver: registry.ResolverRegistry().Get(v.Resolver), Hosts: registry.HostsRegistry().Get(v.Hosts), Marker: &chain.FailMarker{}, @@ -153,18 +137,8 @@ func ParseChain(cfg *config.ChainConfig) (chain.Chainer, error) { if s := parseSelector(hop.Selector); s != nil { sel = s } - group.WithSelector(sel) - - var bypasses []bypass.Bypass - if bp := registry.BypassRegistry().Get(hop.Bypass); bp != nil { - bypasses = append(bypasses, bp) - } - for _, s := range hop.Bypasses { - if bp := registry.BypassRegistry().Get(s); bp != nil { - bypasses = append(bypasses, bp) - } - } - group.WithBypass(bypass.BypassList(bypasses...)) + group.WithSelector(sel). + WithBypass(bypass.BypassList(bypassList(hop.Bypass, hop.Bypasses...)...)) c.AddNodeGroup(group) } diff --git a/config/parsing/service.go b/config/parsing/service.go index f6a081a..1b4f0e5 100644 --- a/config/parsing/service.go +++ b/config/parsing/service.go @@ -54,30 +54,14 @@ func ParseService(cfg *config.ServiceConfig) (service.Service, error) { tlsConfig = defaultTLSConfig.Clone() } - var authers []auth.Authenticator - if auther := registry.AutherRegistry().Get(cfg.Listener.Auther); auther != nil { - authers = append(authers, auther) - } - for _, s := range cfg.Listener.Authers { - if auther := registry.AutherRegistry().Get(s); auther != nil { - authers = append(authers, auther) - } - } + authers := autherList(cfg.Listener.Auther, cfg.Listener.Authers...) if len(authers) == 0 { if auther := ParseAutherFromAuth(cfg.Listener.Auth); auther != nil { authers = append(authers, auther) } } - var admissions []admission.Admission - if adm := registry.AdmissionRegistry().Get(cfg.Admission); adm != nil { - admissions = append(admissions, adm) - } - for _, s := range cfg.Admissions { - if adm := registry.AdmissionRegistry().Get(s); adm != nil { - admissions = append(admissions, adm) - } - } + admissions := admissionList(cfg.Admission, cfg.Admissions...) ln := registry.ListenerRegistry().Get(cfg.Listener.Type)( listener.AddrOption(cfg.Addr), @@ -116,15 +100,7 @@ func ParseService(cfg *config.ServiceConfig) (service.Service, error) { tlsConfig = defaultTLSConfig.Clone() } - authers = nil - if auther := registry.AutherRegistry().Get(cfg.Handler.Auther); auther != nil { - authers = append(authers, auther) - } - for _, s := range cfg.Handler.Authers { - if auther := registry.AutherRegistry().Get(s); auther != nil { - authers = append(authers, auther) - } - } + authers = autherList(cfg.Handler.Auther, cfg.Handler.Authers...) if len(authers) == 0 { if auther := ParseAutherFromAuth(cfg.Handler.Auth); auther != nil { authers = append(authers, auther) @@ -156,20 +132,11 @@ func ParseService(cfg *config.ServiceConfig) (service.Service, error) { WithRecorder(recorders...). WithLogger(handlerLogger) - var bypasses []bypass.Bypass - if bp := registry.BypassRegistry().Get(cfg.Bypass); bp != nil { - bypasses = append(bypasses, bp) - } - for _, s := range cfg.Bypasses { - if bp := registry.BypassRegistry().Get(s); bp != nil { - bypasses = append(bypasses, bp) - } - } h := registry.HandlerRegistry().Get(cfg.Handler.Type)( handler.RouterOption(router), handler.AutherOption(auth.AuthenticatorList(authers...)), handler.AuthOption(parseAuth(cfg.Handler.Auth)), - handler.BypassOption(bypass.BypassList(bypasses...)), + handler.BypassOption(bypass.BypassList(bypassList(cfg.Bypass, cfg.Bypasses...)...)), handler.TLSConfigOption(tlsConfig), handler.LoggerOption(handlerLogger), ) @@ -196,19 +163,74 @@ func ParseService(cfg *config.ServiceConfig) (service.Service, error) { } func parseForwarder(cfg *config.ForwarderConfig) *chain.NodeGroup { - if cfg == nil || len(cfg.Targets) == 0 { + if cfg == nil || + (len(cfg.Targets) == 0 && len(cfg.Nodes) == 0) { return nil } group := &chain.NodeGroup{} - for _, target := range cfg.Targets { - if v := strings.TrimSpace(target); v != "" { - group.AddNode(&chain.Node{ - Name: target, - Addr: target, - Marker: &chain.FailMarker{}, - }) + if len(cfg.Nodes) > 0 { + for _, node := range cfg.Nodes { + if node != nil { + group.AddNode(&chain.Node{ + Name: node.Name, + Addr: node.Addr, + Bypass: bypass.BypassList(bypassList(node.Bypass, node.Bypasses...)...), + Marker: &chain.FailMarker{}, + }) + } + } + } else { + for _, target := range cfg.Targets { + if v := strings.TrimSpace(target); v != "" { + group.AddNode(&chain.Node{ + Name: target, + Addr: target, + Marker: &chain.FailMarker{}, + }) + } } } + return group.WithSelector(parseSelector(cfg.Selector)) } + +func bypassList(name string, names ...string) []bypass.Bypass { + var bypasses []bypass.Bypass + if bp := registry.BypassRegistry().Get(name); bp != nil { + bypasses = append(bypasses, bp) + } + for _, s := range names { + if bp := registry.BypassRegistry().Get(s); bp != nil { + bypasses = append(bypasses, bp) + } + } + return bypasses +} + +func autherList(name string, names ...string) []auth.Authenticator { + var authers []auth.Authenticator + if auther := registry.AutherRegistry().Get(name); auther != nil { + authers = append(authers, auther) + } + for _, s := range names { + if auther := registry.AutherRegistry().Get(s); auther != nil { + authers = append(authers, auther) + } + } + return authers +} + +func admissionList(name string, names ...string) []admission.Admission { + var admissions []admission.Admission + if adm := registry.AdmissionRegistry().Get(name); adm != nil { + admissions = append(admissions, adm) + } + for _, s := range names { + if adm := registry.AdmissionRegistry().Get(s); adm != nil { + admissions = append(admissions, adm) + } + } + + return admissions +} diff --git a/go.mod b/go.mod index 226cda5..d396819 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d github.com/gin-contrib/cors v1.3.1 github.com/gin-gonic/gin v1.7.7 - github.com/go-gost/core v0.0.0-20220824151220-81bf7b985abe + github.com/go-gost/core v0.0.0-20220825133341-04b4a79b80c2 github.com/go-gost/gosocks4 v0.0.1 github.com/go-gost/gosocks5 v0.3.1-0.20211109033403-d894d75b7f09 github.com/go-gost/relay v0.1.1-0.20211123134818-8ef7fd81ffd7 @@ -16,7 +16,7 @@ require ( github.com/golang/snappy v0.0.4 github.com/gorilla/websocket v1.5.0 github.com/lucas-clemente/quic-go v0.28.1 - github.com/miekg/dns v1.1.47 + github.com/miekg/dns v1.1.50 github.com/prometheus/client_golang v1.12.1 github.com/rs/xid v1.3.0 github.com/shadowsocks/go-shadowsocks2 v0.1.5 diff --git a/go.sum b/go.sum index b61eb84..bdb3afc 100644 --- a/go.sum +++ b/go.sum @@ -119,8 +119,8 @@ github.com/go-errors/errors v1.0.1/go.mod h1:f4zRHt4oKfwPJE5k8C9vpYG+aDHdBFUsgrm github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= -github.com/go-gost/core v0.0.0-20220824151220-81bf7b985abe h1:PqILl/6QEzdWGnhKjOD2ZqxwCGKd1xUl8aS7DrCdsNQ= -github.com/go-gost/core v0.0.0-20220824151220-81bf7b985abe/go.mod h1:bHVbCS9da6XtKNYMkMUVcck5UqDDUkyC37erVfs4GXQ= +github.com/go-gost/core v0.0.0-20220825133341-04b4a79b80c2 h1:pyFxEUs5ln2rvKDZrk9HKNpJiUYxc4OyEVylkjK4glc= +github.com/go-gost/core v0.0.0-20220825133341-04b4a79b80c2/go.mod h1:bHVbCS9da6XtKNYMkMUVcck5UqDDUkyC37erVfs4GXQ= github.com/go-gost/gosocks4 v0.0.1 h1:+k1sec8HlELuQV7rWftIkmy8UijzUt2I6t+iMPlGB2s= github.com/go-gost/gosocks4 v0.0.1/go.mod h1:3B6L47HbU/qugDg4JnoFPHgJXE43Inz8Bah1QaN9qCc= github.com/go-gost/gosocks5 v0.3.1-0.20211109033403-d894d75b7f09 h1:A95M6UWcfZgOuJkQ7QLfG0Hs5peWIUSysCDNz4pfe04= @@ -300,8 +300,8 @@ github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27k github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/microcosm-cc/bluemonday v1.0.1/go.mod h1:hsXNsILzKxV+sX77C5b8FSuKF00vh2OMYv+xgHpAMF4= -github.com/miekg/dns v1.1.47 h1:J9bWiXbqMbnZPcY8Qi2E3EWIBsIm6MZzzJB9VRg5gL8= -github.com/miekg/dns v1.1.47/go.mod h1:e3IlAVfNqAllflbibAZEWOXOQ+Ynzk/dDozDxY7XnME= +github.com/miekg/dns v1.1.50 h1:DQUfb9uc6smULcREF09Uc+/Gd46YWqJd5DbpPE9xkcA= +github.com/miekg/dns v1.1.50/go.mod h1:e3IlAVfNqAllflbibAZEWOXOQ+Ynzk/dDozDxY7XnME= github.com/mitchellh/mapstructure v1.4.3 h1:OVowDSCllw/YjdLkam3/sm7wEtOy59d8ndGgCcyj8cs= github.com/mitchellh/mapstructure v1.4.3/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/mmcloughlin/avo v0.0.0-20200803215136-443f81d77104 h1:ULR/QWMgcgRiZLUjSSJMU+fW+RDMstRdmnDWj9Q+AsA= diff --git a/handler/dns/handler.go b/handler/dns/handler.go index b7aa3c5..661f8cf 100644 --- a/handler/dns/handler.go +++ b/handler/dns/handler.go @@ -29,7 +29,8 @@ func init() { } type dnsHandler struct { - exchangers []exchanger.Exchanger + group *chain.NodeGroup + exchangers map[string]exchanger.Exchanger cache *resolver_util.Cache router *chain.Router hosts hosts.HostMapper @@ -44,7 +45,8 @@ func NewHandler(opts ...handler.Option) handler.Handler { } return &dnsHandler{ - options: options, + options: options, + exchangers: make(map[string]exchanger.Exchanger), } } @@ -62,23 +64,38 @@ func (h *dnsHandler) Init(md md.Metadata) (err error) { } h.hosts = h.router.Hosts() - for _, server := range h.md.dns { - server = strings.TrimSpace(server) - if server == "" { + if h.group == nil { + h.group = &chain.NodeGroup{} + for i, addr := range h.md.dns { + addr = strings.TrimSpace(addr) + if addr == "" { + continue + } + h.group.AddNode(&chain.Node{ + Name: fmt.Sprintf("target-%d", i), + Addr: addr, + Marker: &chain.FailMarker{}, + }) + } + } + for _, node := range h.group.Nodes() { + addr := strings.TrimSpace(node.Addr) + if addr == "" { continue } ex, err := exchanger.NewExchanger( - server, + addr, exchanger.RouterOption(h.router), exchanger.TimeoutOption(h.md.timeout), exchanger.LoggerOption(log), ) if err != nil { - log.Warnf("parse %s: %v", server, err) + log.Warnf("parse %s: %v", addr, err) continue } - h.exchangers = append(h.exchangers, ex) + h.exchangers[node.Name] = ex } + if len(h.exchangers) == 0 { ex, err := exchanger.NewExchanger( defaultNameserver, @@ -90,12 +107,17 @@ func (h *dnsHandler) Init(md md.Metadata) (err error) { if err != nil { return err } - h.exchangers = append(h.exchangers, ex) + h.exchangers["default"] = ex } return } +// Forward implements handler.Forwarder. +func (h *dnsHandler) Forward(group *chain.NodeGroup) { + h.group = group +} + func (h *dnsHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.HandleOption) error { defer conn.Close() @@ -152,7 +174,6 @@ func (h *dnsHandler) exchange(ctx context.Context, msg []byte, log logger.Logger } var mr *dns.Msg - if log.IsLevelEnabled(logger.TraceLevel) { defer func() { if mr != nil { @@ -161,6 +182,15 @@ func (h *dnsHandler) exchange(ctx context.Context, msg []byte, log logger.Logger }() } + if h.options.Bypass != nil && mq.Question[0].Qclass == dns.ClassINET { + if h.options.Bypass.Contains(strings.Trim(mq.Question[0].Name, ".")) { + log.Debug("bypass: ", mq.Question[0].Name) + mr = (&dns.Msg{}).SetReply(&mq) + b := bufpool.Get(h.md.bufferSize) + return mr.PackBuffer(*b) + } + } + mr = h.lookupHosts(&mq, log) if mr != nil { b := bufpool.Get(h.md.bufferSize) @@ -195,16 +225,16 @@ func (h *dnsHandler) exchange(ctx context.Context, msg []byte, log logger.Logger return nil, err } - var reply []byte - for _, ex := range h.exchangers { - log.Debugf("exchange message %d via %s: %s", mq.Id, ex.String(), mq.Question[0].String()) - reply, err = ex.Exchange(ctx, query) - if err == nil { - break - } + ex := h.selectExchanger(strings.Trim(mq.Question[0].Name, ".")) + if ex == nil { + err := fmt.Errorf("exchange not found for %s", mq.Question[0].Name) log.Error(err) + return nil, err } + + reply, err := ex.Exchange(ctx, query) if err != nil { + log.Error(err) return nil, err } @@ -266,3 +296,15 @@ func (h *dnsHandler) lookupHosts(r *dns.Msg, log logger.Logger) (m *dns.Msg) { return } + +func (h *dnsHandler) selectExchanger(addr string) exchanger.Exchanger { + if h.group == nil { + return nil + } + node := h.group.FilterAddr(addr).Next() + if node == nil { + return nil + } + + return h.exchangers[node.Name] +}