diff --git a/cmd/gost/config.go b/cmd/gost/config.go index 6f47815..52ca251 100644 --- a/cmd/gost/config.go +++ b/cmd/gost/config.go @@ -2,6 +2,7 @@ package main import ( "io" + "net" "os" "strings" @@ -11,16 +12,21 @@ import ( "github.com/go-gost/gost/pkg/connector" "github.com/go-gost/gost/pkg/dialer" "github.com/go-gost/gost/pkg/handler" + hostspkg "github.com/go-gost/gost/pkg/hosts" "github.com/go-gost/gost/pkg/listener" "github.com/go-gost/gost/pkg/logger" "github.com/go-gost/gost/pkg/metadata" "github.com/go-gost/gost/pkg/registry" + "github.com/go-gost/gost/pkg/resolver" + resolver_impl "github.com/go-gost/gost/pkg/resolver/impl" "github.com/go-gost/gost/pkg/service" ) var ( - chains = make(map[string]*chain.Chain) - bypasses = make(map[string]bypass.Bypass) + chains = make(map[string]*chain.Chain) + bypasses = make(map[string]bypass.Bypass) + resolvers = make(map[string]resolver.Resolver) + hosts = make(map[string]*hostspkg.Hosts) ) func buildService(cfg *config.Config) (services []*service.Service) { @@ -32,6 +38,17 @@ func buildService(cfg *config.Config) (services []*service.Service) { bypasses[bypassCfg.Name] = bypassFromConfig(bypassCfg) } + for _, resolverCfg := range cfg.Resolvers { + r, err := resolverFromConfig(resolverCfg) + if err != nil { + log.Fatal(err) + } + resolvers[resolverCfg.Name] = r + } + for _, hostsCfg := range cfg.Hosts { + hosts[hostsCfg.Name] = hostsFromConfig(hostsCfg) + } + for _, chainCfg := range cfg.Chains { chains[chainCfg.Name] = chainFromConfig(chainCfg) } @@ -72,8 +89,10 @@ func buildService(cfg *config.Config) (services []*service.Service) { handler.BypassOption(bypasses[svc.Bypass]), handler.LoggerOption(handlerLogger), handler.RouterOption(&chain.Router{ - Chain: chains[svc.Chain], - Logger: handlerLogger, + Chain: chains[svc.Chain], + Resolver: resolvers[svc.Resolver], + Hosts: hosts[svc.Hosts], + Logger: handlerLogger, }), ) @@ -173,6 +192,20 @@ func chainFromConfig(cfg *config.ChainConfig) *chain.Chain { return c } +func forwarderFromConfig(cfg *config.ForwarderConfig) *chain.NodeGroup { + if cfg == nil || len(cfg.Targets) == 0 { + return nil + } + + group := &chain.NodeGroup{} + for _, target := range cfg.Targets { + if v := strings.TrimSpace(target); v != "" { + group.AddNode(chain.NewNode(target, target)) + } + } + return group.WithSelector(selectorFromConfig(cfg.Selector)) +} + func logFromConfig(cfg *config.LogConfig) logger.Logger { if cfg == nil { cfg = &config.LogConfig{} @@ -234,16 +267,41 @@ func bypassFromConfig(cfg *config.BypassConfig) bypass.Bypass { return bypass.NewBypassPatterns(cfg.Reverse, cfg.Matchers...) } -func forwarderFromConfig(cfg *config.ForwarderConfig) *chain.NodeGroup { - if cfg == nil || len(cfg.Targets) == 0 { +func resolverFromConfig(cfg *config.ResolverConfig) (resolver.Resolver, error) { + if cfg == nil { + return nil, nil + } + var nameservers []resolver_impl.NameServer + for _, server := range cfg.Nameservers { + nameservers = append(nameservers, resolver_impl.NameServer{ + Addr: server.Addr, + Chain: chains[server.Chain], + TTL: server.TTL, + Timeout: server.Timeout, + ClientIP: net.ParseIP(server.ClientIP), + Prefer: server.Prefer, + Hostname: server.Hostname, + }) + } + return resolver_impl.NewResolver(nameservers) +} + +func hostsFromConfig(cfg *config.HostsConfig) *hostspkg.Hosts { + if cfg == nil { return nil } + hosts := &hostspkg.Hosts{} - group := &chain.NodeGroup{} - for _, target := range cfg.Targets { - if v := strings.TrimSpace(target); v != "" { - group.AddNode(chain.NewNode(target, target)) + for _, host := range cfg.Entries { + if host.IP == "" || host.Hostname == "" { + continue } + + ip := net.ParseIP(host.IP) + if ip == nil { + continue + } + hosts.AddHost(hostspkg.NewHost(ip, host.Hostname, host.Aliases...)) } - return group.WithSelector(selectorFromConfig(cfg.Selector)) + return hosts } diff --git a/gost.yml b/gost.yml index 189952a..9b81b5d 100644 --- a/gost.yml +++ b/gost.yml @@ -3,29 +3,6 @@ log: level: debug # debug, info, warn, error, fatal format: json # text, json -profiling: - addr: ":6060" - enabled: true - -# tls: -# cert: "cert.pem" -# key: "key.pem" -# ca: "root.ca" - -resolvers: -- name: resolver-0 - nameservers: - - addr: udp://8.8.8.8:53 - chain: chain-0 - ttl: 60s - prefer: ipv4 - clientIP: 1.2.3.4 - timeout: 3s - - addr: tcp://1.1.1.1:53 - - addr: tls://1.1.1.1:853 - - addr: https://1.0.0.1/dns-query - hostname: cloudflare-dns.com - services: - name: http+tcp url: "http://gost:gost@:8000" @@ -95,7 +72,7 @@ services: readTimeout: 5s retry: 3 notls: true - # udpBufferSize: 4096 # range [512, 66560] + # udpBufferSize: 1024 listener: type: tcp metadata: @@ -285,7 +262,7 @@ chains: metadata: {} bypasses: -- name: bypass01 +- name: bypass-0 reverse: false matchers: - .baidu.com @@ -312,4 +289,42 @@ bypasses: # From IANA Multicast Address Space Registry # http://www.iana.org/assignments/multicast-addresses/multicast-addresses.xhtml - - 224.0.0.0/4 # RFC5771: Multicast/Reserved \ No newline at end of file + - 224.0.0.0/4 # RFC5771: Multicast/Reserved + +# tls: +# cert: "cert.pem" +# key: "key.pem" +# ca: "root.ca" + +resolvers: +- name: resolver-0 + nameservers: + - addr: udp://8.8.8.8:53 + chain: chain-0 + ttl: 60s + prefer: ipv4 + clientIP: 1.2.3.4 + timeout: 3s + - addr: tcp://1.1.1.1:53 + - addr: tls://1.1.1.1:853 + - addr: https://1.0.0.1/dns-query + hostname: cloudflare-dns.com + +hosts: +- name: hosts-0 + entries: + - ip: 127.0.0.1 + hostname: localhost + - ip: 192.168.1.10 + hostname: foo.mydomain.org + aliases: + - foo + - ip: 192.168.1.13 + hostname: bar.mydomain.org + aliases: + - bar + - baz + +profiling: + addr: ":6060" + enabled: true diff --git a/pkg/chain/chain.go b/pkg/chain/chain.go index e50e1c1..dfdc88a 100644 --- a/pkg/chain/chain.go +++ b/pkg/chain/chain.go @@ -12,16 +12,16 @@ func (c *Chain) AddNodeGroup(group *NodeGroup) { c.groups = append(c.groups, group) } -func (c *Chain) GetRoute() (r *Route) { +func (c *Chain) GetRoute() (r *route) { return c.GetRouteFor("tcp", "") } -func (c *Chain) GetRouteFor(network, address string) (r *Route) { +func (c *Chain) GetRouteFor(network, address string) (r *route) { if c == nil || len(c.groups) == 0 { return } - r = &Route{} + r = &route{} for _, group := range c.groups { node := group.Next() if node == nil { @@ -36,7 +36,7 @@ func (c *Chain) GetRouteFor(network, address string) (r *Route) { WithRoute(r) node = node.Copy(). WithTransport(tr) - r = &Route{} + r = &route{} } r.AddNode(node) diff --git a/pkg/chain/route.go b/pkg/chain/route.go index 3f09685..ef6b5f1 100644 --- a/pkg/chain/route.go +++ b/pkg/chain/route.go @@ -15,15 +15,15 @@ var ( ErrEmptyRoute = errors.New("empty route") ) -type Route struct { +type route struct { nodes []*Node } -func (r *Route) AddNode(node *Node) { +func (r *route) AddNode(node *Node) { r.nodes = append(r.nodes, node) } -func (r *Route) connect(ctx context.Context) (conn net.Conn, err error) { +func (r *route) connect(ctx context.Context) (conn net.Conn, err error) { if r.IsEmpty() { return nil, ErrEmptyRoute } @@ -67,7 +67,7 @@ func (r *Route) connect(ctx context.Context) (conn net.Conn, err error) { return } -func (r *Route) Dial(ctx context.Context, network, address string) (net.Conn, error) { +func (r *route) Dial(ctx context.Context, network, address string) (net.Conn, error) { if r.IsEmpty() { return r.dialDirect(ctx, network, address) } @@ -85,7 +85,7 @@ func (r *Route) Dial(ctx context.Context, network, address string) (net.Conn, er return cc, nil } -func (r *Route) dialDirect(ctx context.Context, network, address string) (net.Conn, error) { +func (r *route) dialDirect(ctx context.Context, network, address string) (net.Conn, error) { switch network { case "udp", "udp4", "udp6": if address == "" { @@ -98,7 +98,7 @@ func (r *Route) dialDirect(ctx context.Context, network, address string) (net.Co return d.DialContext(ctx, network, address) } -func (r *Route) Bind(ctx context.Context, network, address string, opts ...connector.BindOption) (net.Listener, error) { +func (r *route) Bind(ctx context.Context, network, address string, opts ...connector.BindOption) (net.Listener, error) { if r.IsEmpty() { return r.bindLocal(ctx, network, address, opts...) } @@ -117,18 +117,18 @@ func (r *Route) Bind(ctx context.Context, network, address string, opts ...conne return ln, nil } -func (r *Route) IsEmpty() bool { +func (r *route) IsEmpty() bool { return r == nil || len(r.nodes) == 0 } -func (r *Route) Last() *Node { +func (r *route) Last() *Node { if r.IsEmpty() { return nil } return r.nodes[len(r.nodes)-1] } -func (r *Route) Path() (path []*Node) { +func (r *route) Path() (path []*Node) { if r == nil || len(r.nodes) == 0 { return nil } @@ -142,7 +142,7 @@ func (r *Route) Path() (path []*Node) { return } -func (r *Route) bindLocal(ctx context.Context, network, address string, opts ...connector.BindOption) (net.Listener, error) { +func (r *route) bindLocal(ctx context.Context, network, address string, opts ...connector.BindOption) (net.Listener, error) { options := connector.BindOptions{} for _, opt := range opts { opt(&options) diff --git a/pkg/chain/router.go b/pkg/chain/router.go index 693ca6f..ca02f19 100644 --- a/pkg/chain/router.go +++ b/pkg/chain/router.go @@ -3,11 +3,11 @@ package chain import ( "bytes" "context" - "errors" "fmt" "net" "github.com/go-gost/gost/pkg/connector" + "github.com/go-gost/gost/pkg/hosts" "github.com/go-gost/gost/pkg/logger" "github.com/go-gost/gost/pkg/resolver" ) @@ -15,6 +15,7 @@ import ( type Router struct { Retries int Chain *Chain + Hosts *hosts.Hosts Resolver resolver.Resolver Logger logger.Logger } @@ -77,11 +78,10 @@ func (r *Router) resolve(ctx context.Context, addr string) (string, error) { return "", err } - /* - if ip := hosts.Lookup(host); ip != nil { - return net.JoinHostPort(ip.String(), port) - } - */ + if ip := r.Hosts.Lookup(host); ip != nil { + r.Logger.Debugf("hit hosts: %s -> %s", host, ip) + return net.JoinHostPort(ip.String(), port), nil + } if r.Resolver != nil { ips, err := r.Resolver.Resolve(ctx, host) @@ -89,7 +89,7 @@ func (r *Router) resolve(ctx context.Context, addr string) (string, error) { r.Logger.Error(err) } if len(ips) == 0 { - return "", errors.New("domain not exists") + return "", fmt.Errorf("resolver: domain %s does not exists", host) } return net.JoinHostPort(ips[0].String(), port), nil } diff --git a/pkg/chain/transport.go b/pkg/chain/transport.go index f717022..2d3c12b 100644 --- a/pkg/chain/transport.go +++ b/pkg/chain/transport.go @@ -10,7 +10,7 @@ import ( type Transport struct { addr string - route *Route + route *route dialer dialer.Dialer connector connector.Connector } @@ -82,7 +82,7 @@ func (tr *Transport) Multiplex() bool { return false } -func (tr *Transport) WithRoute(r *Route) *Transport { +func (tr *Transport) WithRoute(r *route) *Transport { tr.route = r return tr } diff --git a/pkg/config/config.go b/pkg/config/config.go index 694bd8a..4d3826e 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -47,6 +47,33 @@ type BypassConfig struct { Reverse bool `yaml:",omitempty"` Matchers []string } + +type NameserverConfig struct { + Addr string + Chain string + Prefer string + ClientIP string + Hostname string + TTL time.Duration + Timeout time.Duration +} + +type ResolverConfig struct { + Name string + Nameservers []NameserverConfig +} + +type HostConfig struct { + IP string + Hostname string + Aliases []string +} + +type HostsConfig struct { + Name string + Entries []HostConfig +} + type ListenerConfig struct { Type string Metadata map[string]interface{} `yaml:",omitempty"` @@ -78,6 +105,8 @@ type ServiceConfig struct { Addr string `yaml:",omitempty"` Chain string `yaml:",omitempty"` Bypass string `yaml:",omitempty"` + Resolver string `yaml:",omitempty"` + Hosts string `yaml:",omitempty"` Listener *ListenerConfig `yaml:",omitempty"` Handler *HandlerConfig `yaml:",omitempty"` Forwarder *ForwarderConfig `yaml:",omitempty"` @@ -105,12 +134,14 @@ type NodeConfig struct { } type Config struct { - Log *LogConfig `yaml:",omitempty"` - Profiling *ProfilingConfig `yaml:",omitempty"` - TLS *TLSConfig `yaml:",omitempty"` + Log *LogConfig `yaml:",omitempty"` + Profiling *ProfilingConfig `yaml:",omitempty"` + TLS *TLSConfig `yaml:",omitempty"` + Bypasses []*BypassConfig `yaml:",omitempty"` + Resolvers []*ResolverConfig `yaml:",omitempty"` + Hosts []*HostsConfig `yaml:",omitempty"` + Chains []*ChainConfig `yaml:",omitempty"` Services []*ServiceConfig - Chains []*ChainConfig `yaml:",omitempty"` - Bypasses []*BypassConfig `yaml:",omitempty"` } func (c *Config) Load() error { diff --git a/pkg/hosts/hosts.go b/pkg/hosts/hosts.go new file mode 100644 index 0000000..893fa4f --- /dev/null +++ b/pkg/hosts/hosts.go @@ -0,0 +1,56 @@ +package hosts + +import ( + "net" +) + +// Host is a static mapping from hostname to IP. +type Host struct { + IP net.IP + Hostname string + Aliases []string +} + +// NewHost creates a Host. +func NewHost(ip net.IP, hostname string, aliases ...string) Host { + return Host{ + IP: ip, + Hostname: hostname, + Aliases: aliases, + } +} + +// Hosts is a static table lookup for hostnames. +// For each host a single line should be present with the following information: +// IP_address canonical_hostname [aliases...] +// Fields of the entry are separated by any number of blanks and/or tab characters. +// Text from a "#" character until the end of the line is a comment, and is ignored. +type Hosts struct { + hosts []Host +} + +// AddHost adds host(s) to the host table. +func (h *Hosts) AddHost(host ...Host) { + h.hosts = append(h.hosts, host...) +} + +// Lookup searches the IP address corresponds to the given host from the host table. +func (h *Hosts) Lookup(host string) (ip net.IP) { + if h == nil || host == "" { + return + } + + for _, h := range h.hosts { + if h.Hostname == host { + ip = h.IP + break + } + for _, alias := range h.Aliases { + if alias == host { + ip = h.IP + break + } + } + } + return +} diff --git a/pkg/resolver/impl/resolver.go b/pkg/resolver/impl/resolver.go index 5458cf3..0ca4b47 100644 --- a/pkg/resolver/impl/resolver.go +++ b/pkg/resolver/impl/resolver.go @@ -65,7 +65,10 @@ func NewResolver(nameservers []NameServer, opts ...ResolverOption) (resolverpkg. } ex, err := exchanger.NewExchanger( addr, - exchanger.ChainOption(server.Chain), + exchanger.RouterOption(&chain.Router{ + Chain: server.Chain, + Logger: options.logger, + }), exchanger.TimeoutOption(server.Timeout), exchanger.LoggerOption(options.logger), )