add nodes field for forwarder

This commit is contained in:
ginuerzh 2022-08-25 21:35:37 +08:00
parent 498a425656
commit d043ad94e7
6 changed files with 138 additions and 98 deletions

View File

@ -206,7 +206,9 @@ type HandlerConfig struct {
}
type ForwarderConfig struct {
Targets []string `json:"targets"`
// DEPRECATED by nodes since beta.4
Targets []string `yaml:",omitempty" json:"targets,omitempty"`
Nodes []*NodeConfig `json:"nodes"`
Selector *SelectorConfig `yaml:",omitempty" json:"selector,omitempty"`
}

View File

@ -94,12 +94,6 @@ func ParseChain(cfg *config.ChainConfig) (chain.Chainer, error) {
return nil, err
}
if v.Bypass == "" {
v.Bypass = hop.Bypass
}
if v.Bypasses == nil {
v.Bypasses = hop.Bypasses
}
if v.Resolver == "" {
v.Resolver = hop.Resolver
}
@ -127,20 +121,10 @@ func ParseChain(cfg *config.ChainConfig) (chain.Chainer, error) {
WithInterface(v.Interface).
WithSockOpts(sockOpts)
var bypasses []bypass.Bypass
if bp := registry.BypassRegistry().Get(v.Bypass); bp != nil {
bypasses = append(bypasses, bp)
}
for _, s := range v.Bypasses {
if bp := registry.BypassRegistry().Get(s); bp != nil {
bypasses = append(bypasses, bp)
}
}
node := &chain.Node{
Name: v.Name,
Addr: v.Addr,
Bypass: bypass.BypassList(bypasses...),
Bypass: bypass.BypassList(bypassList(v.Bypass, v.Bypasses...)...),
Resolver: registry.ResolverRegistry().Get(v.Resolver),
Hosts: registry.HostsRegistry().Get(v.Hosts),
Marker: &chain.FailMarker{},
@ -153,18 +137,8 @@ func ParseChain(cfg *config.ChainConfig) (chain.Chainer, error) {
if s := parseSelector(hop.Selector); s != nil {
sel = s
}
group.WithSelector(sel)
var bypasses []bypass.Bypass
if bp := registry.BypassRegistry().Get(hop.Bypass); bp != nil {
bypasses = append(bypasses, bp)
}
for _, s := range hop.Bypasses {
if bp := registry.BypassRegistry().Get(s); bp != nil {
bypasses = append(bypasses, bp)
}
}
group.WithBypass(bypass.BypassList(bypasses...))
group.WithSelector(sel).
WithBypass(bypass.BypassList(bypassList(hop.Bypass, hop.Bypasses...)...))
c.AddNodeGroup(group)
}

View File

@ -54,30 +54,14 @@ func ParseService(cfg *config.ServiceConfig) (service.Service, error) {
tlsConfig = defaultTLSConfig.Clone()
}
var authers []auth.Authenticator
if auther := registry.AutherRegistry().Get(cfg.Listener.Auther); auther != nil {
authers = append(authers, auther)
}
for _, s := range cfg.Listener.Authers {
if auther := registry.AutherRegistry().Get(s); auther != nil {
authers = append(authers, auther)
}
}
authers := autherList(cfg.Listener.Auther, cfg.Listener.Authers...)
if len(authers) == 0 {
if auther := ParseAutherFromAuth(cfg.Listener.Auth); auther != nil {
authers = append(authers, auther)
}
}
var admissions []admission.Admission
if adm := registry.AdmissionRegistry().Get(cfg.Admission); adm != nil {
admissions = append(admissions, adm)
}
for _, s := range cfg.Admissions {
if adm := registry.AdmissionRegistry().Get(s); adm != nil {
admissions = append(admissions, adm)
}
}
admissions := admissionList(cfg.Admission, cfg.Admissions...)
ln := registry.ListenerRegistry().Get(cfg.Listener.Type)(
listener.AddrOption(cfg.Addr),
@ -116,15 +100,7 @@ func ParseService(cfg *config.ServiceConfig) (service.Service, error) {
tlsConfig = defaultTLSConfig.Clone()
}
authers = nil
if auther := registry.AutherRegistry().Get(cfg.Handler.Auther); auther != nil {
authers = append(authers, auther)
}
for _, s := range cfg.Handler.Authers {
if auther := registry.AutherRegistry().Get(s); auther != nil {
authers = append(authers, auther)
}
}
authers = autherList(cfg.Handler.Auther, cfg.Handler.Authers...)
if len(authers) == 0 {
if auther := ParseAutherFromAuth(cfg.Handler.Auth); auther != nil {
authers = append(authers, auther)
@ -156,20 +132,11 @@ func ParseService(cfg *config.ServiceConfig) (service.Service, error) {
WithRecorder(recorders...).
WithLogger(handlerLogger)
var bypasses []bypass.Bypass
if bp := registry.BypassRegistry().Get(cfg.Bypass); bp != nil {
bypasses = append(bypasses, bp)
}
for _, s := range cfg.Bypasses {
if bp := registry.BypassRegistry().Get(s); bp != nil {
bypasses = append(bypasses, bp)
}
}
h := registry.HandlerRegistry().Get(cfg.Handler.Type)(
handler.RouterOption(router),
handler.AutherOption(auth.AuthenticatorList(authers...)),
handler.AuthOption(parseAuth(cfg.Handler.Auth)),
handler.BypassOption(bypass.BypassList(bypasses...)),
handler.BypassOption(bypass.BypassList(bypassList(cfg.Bypass, cfg.Bypasses...)...)),
handler.TLSConfigOption(tlsConfig),
handler.LoggerOption(handlerLogger),
)
@ -196,11 +163,24 @@ func ParseService(cfg *config.ServiceConfig) (service.Service, error) {
}
func parseForwarder(cfg *config.ForwarderConfig) *chain.NodeGroup {
if cfg == nil || len(cfg.Targets) == 0 {
if cfg == nil ||
(len(cfg.Targets) == 0 && len(cfg.Nodes) == 0) {
return nil
}
group := &chain.NodeGroup{}
if len(cfg.Nodes) > 0 {
for _, node := range cfg.Nodes {
if node != nil {
group.AddNode(&chain.Node{
Name: node.Name,
Addr: node.Addr,
Bypass: bypass.BypassList(bypassList(node.Bypass, node.Bypasses...)...),
Marker: &chain.FailMarker{},
})
}
}
} else {
for _, target := range cfg.Targets {
if v := strings.TrimSpace(target); v != "" {
group.AddNode(&chain.Node{
@ -210,5 +190,47 @@ func parseForwarder(cfg *config.ForwarderConfig) *chain.NodeGroup {
})
}
}
}
return group.WithSelector(parseSelector(cfg.Selector))
}
func bypassList(name string, names ...string) []bypass.Bypass {
var bypasses []bypass.Bypass
if bp := registry.BypassRegistry().Get(name); bp != nil {
bypasses = append(bypasses, bp)
}
for _, s := range names {
if bp := registry.BypassRegistry().Get(s); bp != nil {
bypasses = append(bypasses, bp)
}
}
return bypasses
}
func autherList(name string, names ...string) []auth.Authenticator {
var authers []auth.Authenticator
if auther := registry.AutherRegistry().Get(name); auther != nil {
authers = append(authers, auther)
}
for _, s := range names {
if auther := registry.AutherRegistry().Get(s); auther != nil {
authers = append(authers, auther)
}
}
return authers
}
func admissionList(name string, names ...string) []admission.Admission {
var admissions []admission.Admission
if adm := registry.AdmissionRegistry().Get(name); adm != nil {
admissions = append(admissions, adm)
}
for _, s := range names {
if adm := registry.AdmissionRegistry().Get(s); adm != nil {
admissions = append(admissions, adm)
}
}
return admissions
}

4
go.mod
View File

@ -6,7 +6,7 @@ require (
github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d
github.com/gin-contrib/cors v1.3.1
github.com/gin-gonic/gin v1.7.7
github.com/go-gost/core v0.0.0-20220824151220-81bf7b985abe
github.com/go-gost/core v0.0.0-20220825133341-04b4a79b80c2
github.com/go-gost/gosocks4 v0.0.1
github.com/go-gost/gosocks5 v0.3.1-0.20211109033403-d894d75b7f09
github.com/go-gost/relay v0.1.1-0.20211123134818-8ef7fd81ffd7
@ -16,7 +16,7 @@ require (
github.com/golang/snappy v0.0.4
github.com/gorilla/websocket v1.5.0
github.com/lucas-clemente/quic-go v0.28.1
github.com/miekg/dns v1.1.47
github.com/miekg/dns v1.1.50
github.com/prometheus/client_golang v1.12.1
github.com/rs/xid v1.3.0
github.com/shadowsocks/go-shadowsocks2 v0.1.5

8
go.sum
View File

@ -119,8 +119,8 @@ github.com/go-errors/errors v1.0.1/go.mod h1:f4zRHt4oKfwPJE5k8C9vpYG+aDHdBFUsgrm
github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU=
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
github.com/go-gost/core v0.0.0-20220824151220-81bf7b985abe h1:PqILl/6QEzdWGnhKjOD2ZqxwCGKd1xUl8aS7DrCdsNQ=
github.com/go-gost/core v0.0.0-20220824151220-81bf7b985abe/go.mod h1:bHVbCS9da6XtKNYMkMUVcck5UqDDUkyC37erVfs4GXQ=
github.com/go-gost/core v0.0.0-20220825133341-04b4a79b80c2 h1:pyFxEUs5ln2rvKDZrk9HKNpJiUYxc4OyEVylkjK4glc=
github.com/go-gost/core v0.0.0-20220825133341-04b4a79b80c2/go.mod h1:bHVbCS9da6XtKNYMkMUVcck5UqDDUkyC37erVfs4GXQ=
github.com/go-gost/gosocks4 v0.0.1 h1:+k1sec8HlELuQV7rWftIkmy8UijzUt2I6t+iMPlGB2s=
github.com/go-gost/gosocks4 v0.0.1/go.mod h1:3B6L47HbU/qugDg4JnoFPHgJXE43Inz8Bah1QaN9qCc=
github.com/go-gost/gosocks5 v0.3.1-0.20211109033403-d894d75b7f09 h1:A95M6UWcfZgOuJkQ7QLfG0Hs5peWIUSysCDNz4pfe04=
@ -300,8 +300,8 @@ github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27k
github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU=
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
github.com/microcosm-cc/bluemonday v1.0.1/go.mod h1:hsXNsILzKxV+sX77C5b8FSuKF00vh2OMYv+xgHpAMF4=
github.com/miekg/dns v1.1.47 h1:J9bWiXbqMbnZPcY8Qi2E3EWIBsIm6MZzzJB9VRg5gL8=
github.com/miekg/dns v1.1.47/go.mod h1:e3IlAVfNqAllflbibAZEWOXOQ+Ynzk/dDozDxY7XnME=
github.com/miekg/dns v1.1.50 h1:DQUfb9uc6smULcREF09Uc+/Gd46YWqJd5DbpPE9xkcA=
github.com/miekg/dns v1.1.50/go.mod h1:e3IlAVfNqAllflbibAZEWOXOQ+Ynzk/dDozDxY7XnME=
github.com/mitchellh/mapstructure v1.4.3 h1:OVowDSCllw/YjdLkam3/sm7wEtOy59d8ndGgCcyj8cs=
github.com/mitchellh/mapstructure v1.4.3/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
github.com/mmcloughlin/avo v0.0.0-20200803215136-443f81d77104 h1:ULR/QWMgcgRiZLUjSSJMU+fW+RDMstRdmnDWj9Q+AsA=

View File

@ -29,7 +29,8 @@ func init() {
}
type dnsHandler struct {
exchangers []exchanger.Exchanger
group *chain.NodeGroup
exchangers map[string]exchanger.Exchanger
cache *resolver_util.Cache
router *chain.Router
hosts hosts.HostMapper
@ -45,6 +46,7 @@ func NewHandler(opts ...handler.Option) handler.Handler {
return &dnsHandler{
options: options,
exchangers: make(map[string]exchanger.Exchanger),
}
}
@ -62,23 +64,38 @@ func (h *dnsHandler) Init(md md.Metadata) (err error) {
}
h.hosts = h.router.Hosts()
for _, server := range h.md.dns {
server = strings.TrimSpace(server)
if server == "" {
if h.group == nil {
h.group = &chain.NodeGroup{}
for i, addr := range h.md.dns {
addr = strings.TrimSpace(addr)
if addr == "" {
continue
}
h.group.AddNode(&chain.Node{
Name: fmt.Sprintf("target-%d", i),
Addr: addr,
Marker: &chain.FailMarker{},
})
}
}
for _, node := range h.group.Nodes() {
addr := strings.TrimSpace(node.Addr)
if addr == "" {
continue
}
ex, err := exchanger.NewExchanger(
server,
addr,
exchanger.RouterOption(h.router),
exchanger.TimeoutOption(h.md.timeout),
exchanger.LoggerOption(log),
)
if err != nil {
log.Warnf("parse %s: %v", server, err)
log.Warnf("parse %s: %v", addr, err)
continue
}
h.exchangers = append(h.exchangers, ex)
h.exchangers[node.Name] = ex
}
if len(h.exchangers) == 0 {
ex, err := exchanger.NewExchanger(
defaultNameserver,
@ -90,12 +107,17 @@ func (h *dnsHandler) Init(md md.Metadata) (err error) {
if err != nil {
return err
}
h.exchangers = append(h.exchangers, ex)
h.exchangers["default"] = ex
}
return
}
// Forward implements handler.Forwarder.
func (h *dnsHandler) Forward(group *chain.NodeGroup) {
h.group = group
}
func (h *dnsHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.HandleOption) error {
defer conn.Close()
@ -152,7 +174,6 @@ func (h *dnsHandler) exchange(ctx context.Context, msg []byte, log logger.Logger
}
var mr *dns.Msg
if log.IsLevelEnabled(logger.TraceLevel) {
defer func() {
if mr != nil {
@ -161,6 +182,15 @@ func (h *dnsHandler) exchange(ctx context.Context, msg []byte, log logger.Logger
}()
}
if h.options.Bypass != nil && mq.Question[0].Qclass == dns.ClassINET {
if h.options.Bypass.Contains(strings.Trim(mq.Question[0].Name, ".")) {
log.Debug("bypass: ", mq.Question[0].Name)
mr = (&dns.Msg{}).SetReply(&mq)
b := bufpool.Get(h.md.bufferSize)
return mr.PackBuffer(*b)
}
}
mr = h.lookupHosts(&mq, log)
if mr != nil {
b := bufpool.Get(h.md.bufferSize)
@ -195,16 +225,16 @@ func (h *dnsHandler) exchange(ctx context.Context, msg []byte, log logger.Logger
return nil, err
}
var reply []byte
for _, ex := range h.exchangers {
log.Debugf("exchange message %d via %s: %s", mq.Id, ex.String(), mq.Question[0].String())
reply, err = ex.Exchange(ctx, query)
if err == nil {
break
}
ex := h.selectExchanger(strings.Trim(mq.Question[0].Name, "."))
if ex == nil {
err := fmt.Errorf("exchange not found for %s", mq.Question[0].Name)
log.Error(err)
return nil, err
}
reply, err := ex.Exchange(ctx, query)
if err != nil {
log.Error(err)
return nil, err
}
@ -266,3 +296,15 @@ func (h *dnsHandler) lookupHosts(r *dns.Msg, log logger.Logger) (m *dns.Msg) {
return
}
func (h *dnsHandler) selectExchanger(addr string) exchanger.Exchanger {
if h.group == nil {
return nil
}
node := h.group.FilterAddr(addr).Next()
if node == nil {
return nil
}
return h.exchangers[node.Name]
}