add hosts support for dns

This commit is contained in:
ginuerzh 2022-01-19 16:33:27 +08:00
parent 1e613fc2e6
commit f357117056
11 changed files with 212 additions and 53 deletions

View File

@ -164,11 +164,11 @@ func buildConfigFromCmd(services, nodes stringList) (*config.Config, error) {
if len(ss) != 2 {
continue
}
hostsCfg.Entries = append(
hostsCfg.Entries,
config.HostConfig{
IP: ss[0],
Hostname: ss[1],
hostsCfg.Mappings = append(
hostsCfg.Mappings,
config.HostMappingConfig{
Hostname: ss[0],
IP: ss[1],
},
)
}

View File

@ -340,7 +340,14 @@ func bypassFromConfig(cfg *config.BypassConfig) bypass.Bypass {
if cfg == nil {
return nil
}
return bypass.NewBypassPatterns(cfg.Reverse, cfg.Matchers...)
return bypass.NewBypassPatterns(
cfg.Reverse,
cfg.Matchers,
bypass.LoggerBypassOption(log.WithFields(map[string]interface{}{
"kind": "bypass",
"bypass": cfg.Name,
})),
)
}
func resolverFromConfig(cfg *config.ResolverConfig) (resolver.Resolver, error) {
@ -371,12 +378,16 @@ func resolverFromConfig(cfg *config.ResolverConfig) (resolver.Resolver, error) {
}
func hostsFromConfig(cfg *config.HostsConfig) hostspkg.HostMapper {
if cfg == nil || len(cfg.Entries) == 0 {
if cfg == nil || len(cfg.Mappings) == 0 {
return nil
}
hosts := hostspkg.NewHosts()
hosts.Logger = log.WithFields(map[string]interface{}{
"kind": "hosts",
"hosts": cfg.Name,
})
for _, host := range cfg.Entries {
for _, host := range cfg.Mappings {
if host.IP == "" || host.Hostname == "" {
continue
}

View File

@ -295,7 +295,7 @@ resolvers:
hosts:
- name: hosts-0
entries:
mappings:
- ip: 127.0.0.1
hostname: localhost
- ip: 192.168.1.10

View File

@ -5,6 +5,7 @@ import (
"strconv"
"strings"
"github.com/go-gost/gost/pkg/logger"
glob "github.com/gobwas/glob"
)
@ -105,30 +106,48 @@ type Bypass interface {
Contains(addr string) bool
}
type bypassOptions struct {
logger logger.Logger
}
type BypassOption func(opts *bypassOptions)
func LoggerBypassOption(logger logger.Logger) BypassOption {
return func(opts *bypassOptions) {
opts.logger = logger
}
}
type bypass struct {
matchers []Matcher
reversed bool
options bypassOptions
}
// NewBypass creates and initializes a new Bypass using matchers as its match rules.
// The rules will be reversed if the reversed is true.
func NewBypass(reversed bool, matchers ...Matcher) Bypass {
func NewBypass(reversed bool, matchers []Matcher, opts ...BypassOption) Bypass {
options := bypassOptions{}
for _, opt := range opts {
opt(&options)
}
return &bypass{
matchers: matchers,
reversed: reversed,
options: options,
}
}
// NewBypassPatterns creates and initializes a new Bypass using matcher patterns as its match rules.
// The rules will be reversed if the reverse is true.
func NewBypassPatterns(reversed bool, patterns ...string) Bypass {
func NewBypassPatterns(reversed bool, patterns []string, opts ...BypassOption) Bypass {
var matchers []Matcher
for _, pattern := range patterns {
if m := NewMatcher(pattern); m != nil {
matchers = append(matchers, m)
}
}
return NewBypass(reversed, matchers...)
return NewBypass(reversed, matchers, opts...)
}
func (bp *bypass) Contains(addr string) bool {
@ -153,6 +172,11 @@ func (bp *bypass) Contains(addr string) bool {
break
}
}
return !bp.reversed && matched ||
b := !bp.reversed && matched ||
bp.reversed && !matched
if b {
bp.options.logger.Debugf("bypass: %s", addr)
}
return b
}

View File

@ -79,9 +79,9 @@ func (r *Router) resolve(ctx context.Context, addr string) (string, error) {
}
if r.Hosts != nil {
if ip := r.Hosts.Lookup(host); ip != nil {
r.Logger.Debugf("hit hosts: %s -> %s", host, ip)
return net.JoinHostPort(ip.String(), port), nil
if ips, _ := r.Hosts.Lookup("ip", host); len(ips) > 0 {
r.Logger.Debugf("hit host mapper: %s -> %s", host, ips)
return net.JoinHostPort(ips[0].String(), port), nil
}
}

View File

@ -70,15 +70,15 @@ type ResolverConfig struct {
Nameservers []NameserverConfig
}
type HostConfig struct {
type HostMappingConfig struct {
IP string
Hostname string
Aliases []string `yaml:",omitempty"`
}
type HostsConfig struct {
Name string
Entries []HostConfig
Name string
Mappings []HostMappingConfig
}
type ListenerConfig struct {

View File

@ -4,6 +4,7 @@ import (
"bytes"
"context"
"errors"
"fmt"
"net"
"strconv"
"strings"
@ -49,6 +50,8 @@ func NewHandler(opts ...handler.Option) handler.Handler {
}
func (h *dnsHandler) Init(md md.Metadata) (err error) {
h.logger = h.options.Logger
if err = h.parseMetadata(md); err != nil {
return
}
@ -58,10 +61,9 @@ func (h *dnsHandler) Init(md md.Metadata) (err error) {
Retries: h.options.Retries,
Chain: h.options.Chain,
Resolver: h.options.Resolver,
Hosts: h.options.Hosts,
Logger: h.options.Logger,
// Hosts: h.options.Hosts,
Logger: h.options.Logger,
}
h.logger = h.options.Logger
for _, server := range h.md.dns {
server = strings.TrimSpace(server)
@ -127,6 +129,7 @@ func (h *dnsHandler) Handle(ctx context.Context, conn net.Conn) {
if err != nil {
return
}
defer bufpool.Put(&reply)
if _, err = conn.Write(reply); err != nil {
h.logger.Error(err)
@ -153,14 +156,31 @@ func (h *dnsHandler) exchange(ctx context.Context, msg []byte) ([]byte, error) {
}
var mr *dns.Msg
// cache only for single question message.
if h.logger.IsLevelEnabled(logger.DebugLevel) {
defer func() {
if mr != nil {
h.logger.Debug(mr.String())
}
}()
}
mr = h.lookupHosts(&mq)
if mr != nil {
b := bufpool.Get(4096)
return mr.PackBuffer(*b)
}
// only cache for single question message.
if len(mq.Question) == 1 {
key := resolver_util.NewCacheKey(&mq.Question[0])
mr = h.cache.Load(key)
if mr != nil {
h.logger.Debugf("exchange message %d (cached): %s", mq.Id, mq.Question[0].String())
mr.Id = mq.Id
return mr.Pack()
b := bufpool.Get(4096)
return mr.PackBuffer(*b)
}
defer func() {
@ -170,7 +190,10 @@ func (h *dnsHandler) exchange(ctx context.Context, msg []byte) ([]byte, error) {
}()
}
query, err := mq.Pack()
b := bufpool.Get(4096)
defer bufpool.Put(b)
query, err := mq.PackBuffer(*b)
if err != nil {
h.logger.Error(err)
return nil, err
@ -204,6 +227,56 @@ func (h *dnsHandler) exchange(ctx context.Context, msg []byte) ([]byte, error) {
return reply, nil
}
// lookup host mapper
func (h *dnsHandler) lookupHosts(r *dns.Msg) (m *dns.Msg) {
if h.options.Hosts == nil ||
r.Question[0].Qclass != dns.ClassINET ||
(r.Question[0].Qtype != dns.TypeA && r.Question[0].Qtype != dns.TypeAAAA) {
return nil
}
m = &dns.Msg{}
m.SetReply(r)
host := strings.TrimSuffix(r.Question[0].Name, ".")
switch r.Question[0].Qtype {
case dns.TypeA:
ips, _ := h.options.Hosts.Lookup("ip4", host)
if len(ips) == 0 {
return nil
}
h.logger.Debugf("hit host mapper: %s -> %s", host, ips)
for _, ip := range ips {
rr, err := dns.NewRR(fmt.Sprintf("%s IN A %s\n", r.Question[0].Name, ip.String()))
if err != nil {
h.logger.Error(err)
return nil
}
m.Answer = append(m.Answer, rr)
}
case dns.TypeAAAA:
ips, _ := h.options.Hosts.Lookup("ip6", host)
if len(ips) == 0 {
return nil
}
h.logger.Debugf("hit host mapper: %s -> %s", host, ips)
for _, ip := range ips {
rr, err := dns.NewRR(fmt.Sprintf("%s IN AAAA %s\n", r.Question[0].Name, ip.String()))
if err != nil {
h.logger.Error(err)
return nil
}
m.Answer = append(m.Answer, rr)
}
}
return
}
func (h *dnsHandler) dumpMsgHeader(m *dns.Msg) string {
buf := new(bytes.Buffer)
buf.WriteString(m.MsgHdr.String() + " ")

View File

@ -2,17 +2,19 @@ package hosts
import (
"net"
"sync"
"github.com/go-gost/gost/pkg/logger"
)
// HostMapper is a mapping from hostname to IP.
type HostMapper interface {
Lookup(host string) net.IP
Lookup(network, host string) ([]net.IP, bool)
}
type host struct {
IP net.IP
type hostMapping struct {
IPs []net.IP
Hostname string
Aliases []string
}
// Hosts is a static table lookup for hostnames.
@ -21,7 +23,8 @@ type host struct {
// Fields of the entry are separated by any number of blanks and/or tab characters.
// Text from a "#" character until the end of the line is a comment, and is ignored.
type Hosts struct {
mappings []host
mappings sync.Map
Logger logger.Logger
}
func NewHosts() *Hosts {
@ -30,30 +33,77 @@ func NewHosts() *Hosts {
// Map maps ip to hostname or aliases.
func (h *Hosts) Map(ip net.IP, hostname string, aliases ...string) {
h.mappings = append(h.mappings, host{
IP: ip,
Hostname: hostname,
Aliases: aliases,
})
if hostname == "" {
return
}
v, _ := h.mappings.Load(hostname)
m, _ := v.(*hostMapping)
if m == nil {
m = &hostMapping{
IPs: []net.IP{ip},
Hostname: hostname,
}
} else {
m.IPs = append(m.IPs, ip)
}
h.mappings.Store(hostname, m)
for _, alias := range aliases {
// indirect mapping from alias to hostname
if alias != "" {
h.mappings.Store(alias, &hostMapping{
Hostname: hostname,
})
}
}
}
// Lookup searches the IP address corresponds to the given host from the host table.
func (h *Hosts) Lookup(host string) (ip net.IP) {
// Lookup searches the IP address corresponds to the given network and host from the host table.
// The network should be 'ip', 'ip4' or 'ip6', default network is 'ip'.
func (h *Hosts) Lookup(network, host string) (ips []net.IP, ok bool) {
if h == nil || host == "" {
return
}
for _, h := range h.mappings {
if h.Hostname == host {
ip = h.IP
break
}
for _, alias := range h.Aliases {
if alias == host {
ip = h.IP
break
}
v, ok := h.mappings.Load(host)
if !ok {
return
}
m, _ := v.(*hostMapping)
if m == nil {
return
}
// hostname alias
if host != m.Hostname {
v, _ = h.mappings.Load(m.Hostname)
m, _ = v.(*hostMapping)
if m == nil {
return
}
}
switch network {
case "ip4":
for _, ip := range m.IPs {
if ip = ip.To4(); ip != nil {
ips = append(ips, ip)
}
}
case "ip6":
for _, ip := range m.IPs {
if ip.To4() == nil {
ips = append(ips, ip)
}
}
default:
ips = m.IPs
}
if len(ips) > 0 {
h.Logger.Debugf("host mapper: %s -> %s", host, ips)
}
return
}

View File

@ -56,7 +56,7 @@ func (c *Cache) Load(key CacheKey) *dns.Msg {
return nil
}
c.logger.Debugf("resolver cache hit: %s", key)
c.logger.Debugf("hit resolver cache: %s", key)
return item.msg.Copy()
}

View File

@ -105,7 +105,10 @@ func GetString(md Metadata, key string) (v string) {
}
func GetStrings(md Metadata, key string) (ss []string) {
if v, _ := md.Get(key).([]interface{}); len(v) > 0 {
switch v := md.Get(key).(type) {
case []string:
ss = v
case []interface{}:
for _, vv := range v {
if s, ok := vv.(string); ok {
ss = append(ss, s)

View File

@ -48,7 +48,6 @@ type resolver struct {
servers []NameServer
cache *resolver_util.Cache
options resolverOptions
logger logger.Logger
}
func NewResolver(nameservers []NameServer, opts ...ResolverOption) (resolverpkg.Resolver, error) {
@ -87,7 +86,6 @@ func NewResolver(nameservers []NameServer, opts ...ResolverOption) (resolverpkg.
servers: servers,
cache: cache,
options: options,
logger: options.logger,
}, nil
}
@ -104,11 +102,11 @@ func (r *resolver) Resolve(ctx context.Context, host string) (ips []net.IP, err
for _, server := range r.servers {
ips, err = r.resolve(ctx, &server, host)
if err != nil {
r.logger.Error(err)
r.options.logger.Error(err)
continue
}
r.logger.Debugf("resolve %s via %s: %v", host, server.exchanger.String(), ips)
r.options.logger.Debugf("resolve %s via %s: %v", host, server.exchanger.String(), ips)
if len(ips) > 0 {
break