add hosts

This commit is contained in:
ginuerzh 2021-12-31 11:32:06 +08:00
parent 9769efe33c
commit 4bf754b83b
9 changed files with 229 additions and 66 deletions

View File

@ -2,6 +2,7 @@ package main
import ( import (
"io" "io"
"net"
"os" "os"
"strings" "strings"
@ -11,16 +12,21 @@ import (
"github.com/go-gost/gost/pkg/connector" "github.com/go-gost/gost/pkg/connector"
"github.com/go-gost/gost/pkg/dialer" "github.com/go-gost/gost/pkg/dialer"
"github.com/go-gost/gost/pkg/handler" "github.com/go-gost/gost/pkg/handler"
hostspkg "github.com/go-gost/gost/pkg/hosts"
"github.com/go-gost/gost/pkg/listener" "github.com/go-gost/gost/pkg/listener"
"github.com/go-gost/gost/pkg/logger" "github.com/go-gost/gost/pkg/logger"
"github.com/go-gost/gost/pkg/metadata" "github.com/go-gost/gost/pkg/metadata"
"github.com/go-gost/gost/pkg/registry" "github.com/go-gost/gost/pkg/registry"
"github.com/go-gost/gost/pkg/resolver"
resolver_impl "github.com/go-gost/gost/pkg/resolver/impl"
"github.com/go-gost/gost/pkg/service" "github.com/go-gost/gost/pkg/service"
) )
var ( var (
chains = make(map[string]*chain.Chain) chains = make(map[string]*chain.Chain)
bypasses = make(map[string]bypass.Bypass) bypasses = make(map[string]bypass.Bypass)
resolvers = make(map[string]resolver.Resolver)
hosts = make(map[string]*hostspkg.Hosts)
) )
func buildService(cfg *config.Config) (services []*service.Service) { func buildService(cfg *config.Config) (services []*service.Service) {
@ -32,6 +38,17 @@ func buildService(cfg *config.Config) (services []*service.Service) {
bypasses[bypassCfg.Name] = bypassFromConfig(bypassCfg) bypasses[bypassCfg.Name] = bypassFromConfig(bypassCfg)
} }
for _, resolverCfg := range cfg.Resolvers {
r, err := resolverFromConfig(resolverCfg)
if err != nil {
log.Fatal(err)
}
resolvers[resolverCfg.Name] = r
}
for _, hostsCfg := range cfg.Hosts {
hosts[hostsCfg.Name] = hostsFromConfig(hostsCfg)
}
for _, chainCfg := range cfg.Chains { for _, chainCfg := range cfg.Chains {
chains[chainCfg.Name] = chainFromConfig(chainCfg) chains[chainCfg.Name] = chainFromConfig(chainCfg)
} }
@ -73,6 +90,8 @@ func buildService(cfg *config.Config) (services []*service.Service) {
handler.LoggerOption(handlerLogger), handler.LoggerOption(handlerLogger),
handler.RouterOption(&chain.Router{ handler.RouterOption(&chain.Router{
Chain: chains[svc.Chain], Chain: chains[svc.Chain],
Resolver: resolvers[svc.Resolver],
Hosts: hosts[svc.Hosts],
Logger: handlerLogger, Logger: handlerLogger,
}), }),
) )
@ -173,6 +192,20 @@ func chainFromConfig(cfg *config.ChainConfig) *chain.Chain {
return c return c
} }
func forwarderFromConfig(cfg *config.ForwarderConfig) *chain.NodeGroup {
if cfg == nil || len(cfg.Targets) == 0 {
return nil
}
group := &chain.NodeGroup{}
for _, target := range cfg.Targets {
if v := strings.TrimSpace(target); v != "" {
group.AddNode(chain.NewNode(target, target))
}
}
return group.WithSelector(selectorFromConfig(cfg.Selector))
}
func logFromConfig(cfg *config.LogConfig) logger.Logger { func logFromConfig(cfg *config.LogConfig) logger.Logger {
if cfg == nil { if cfg == nil {
cfg = &config.LogConfig{} cfg = &config.LogConfig{}
@ -234,16 +267,41 @@ func bypassFromConfig(cfg *config.BypassConfig) bypass.Bypass {
return bypass.NewBypassPatterns(cfg.Reverse, cfg.Matchers...) return bypass.NewBypassPatterns(cfg.Reverse, cfg.Matchers...)
} }
func forwarderFromConfig(cfg *config.ForwarderConfig) *chain.NodeGroup { func resolverFromConfig(cfg *config.ResolverConfig) (resolver.Resolver, error) {
if cfg == nil || len(cfg.Targets) == 0 { if cfg == nil {
return nil, nil
}
var nameservers []resolver_impl.NameServer
for _, server := range cfg.Nameservers {
nameservers = append(nameservers, resolver_impl.NameServer{
Addr: server.Addr,
Chain: chains[server.Chain],
TTL: server.TTL,
Timeout: server.Timeout,
ClientIP: net.ParseIP(server.ClientIP),
Prefer: server.Prefer,
Hostname: server.Hostname,
})
}
return resolver_impl.NewResolver(nameservers)
}
func hostsFromConfig(cfg *config.HostsConfig) *hostspkg.Hosts {
if cfg == nil {
return nil return nil
} }
hosts := &hostspkg.Hosts{}
group := &chain.NodeGroup{} for _, host := range cfg.Entries {
for _, target := range cfg.Targets { if host.IP == "" || host.Hostname == "" {
if v := strings.TrimSpace(target); v != "" { continue
group.AddNode(chain.NewNode(target, target))
} }
ip := net.ParseIP(host.IP)
if ip == nil {
continue
} }
return group.WithSelector(selectorFromConfig(cfg.Selector)) hosts.AddHost(hostspkg.NewHost(ip, host.Hostname, host.Aliases...))
}
return hosts
} }

View File

@ -3,29 +3,6 @@ log:
level: debug # debug, info, warn, error, fatal level: debug # debug, info, warn, error, fatal
format: json # text, json format: json # text, json
profiling:
addr: ":6060"
enabled: true
# tls:
# cert: "cert.pem"
# key: "key.pem"
# ca: "root.ca"
resolvers:
- name: resolver-0
nameservers:
- addr: udp://8.8.8.8:53
chain: chain-0
ttl: 60s
prefer: ipv4
clientIP: 1.2.3.4
timeout: 3s
- addr: tcp://1.1.1.1:53
- addr: tls://1.1.1.1:853
- addr: https://1.0.0.1/dns-query
hostname: cloudflare-dns.com
services: services:
- name: http+tcp - name: http+tcp
url: "http://gost:gost@:8000" url: "http://gost:gost@:8000"
@ -95,7 +72,7 @@ services:
readTimeout: 5s readTimeout: 5s
retry: 3 retry: 3
notls: true notls: true
# udpBufferSize: 4096 # range [512, 66560] # udpBufferSize: 1024
listener: listener:
type: tcp type: tcp
metadata: metadata:
@ -285,7 +262,7 @@ chains:
metadata: {} metadata: {}
bypasses: bypasses:
- name: bypass01 - name: bypass-0
reverse: false reverse: false
matchers: matchers:
- .baidu.com - .baidu.com
@ -313,3 +290,41 @@ bypasses:
# From IANA Multicast Address Space Registry # From IANA Multicast Address Space Registry
# http://www.iana.org/assignments/multicast-addresses/multicast-addresses.xhtml # http://www.iana.org/assignments/multicast-addresses/multicast-addresses.xhtml
- 224.0.0.0/4 # RFC5771: Multicast/Reserved - 224.0.0.0/4 # RFC5771: Multicast/Reserved
# tls:
# cert: "cert.pem"
# key: "key.pem"
# ca: "root.ca"
resolvers:
- name: resolver-0
nameservers:
- addr: udp://8.8.8.8:53
chain: chain-0
ttl: 60s
prefer: ipv4
clientIP: 1.2.3.4
timeout: 3s
- addr: tcp://1.1.1.1:53
- addr: tls://1.1.1.1:853
- addr: https://1.0.0.1/dns-query
hostname: cloudflare-dns.com
hosts:
- name: hosts-0
entries:
- ip: 127.0.0.1
hostname: localhost
- ip: 192.168.1.10
hostname: foo.mydomain.org
aliases:
- foo
- ip: 192.168.1.13
hostname: bar.mydomain.org
aliases:
- bar
- baz
profiling:
addr: ":6060"
enabled: true

View File

@ -12,16 +12,16 @@ func (c *Chain) AddNodeGroup(group *NodeGroup) {
c.groups = append(c.groups, group) c.groups = append(c.groups, group)
} }
func (c *Chain) GetRoute() (r *Route) { func (c *Chain) GetRoute() (r *route) {
return c.GetRouteFor("tcp", "") return c.GetRouteFor("tcp", "")
} }
func (c *Chain) GetRouteFor(network, address string) (r *Route) { func (c *Chain) GetRouteFor(network, address string) (r *route) {
if c == nil || len(c.groups) == 0 { if c == nil || len(c.groups) == 0 {
return return
} }
r = &Route{} r = &route{}
for _, group := range c.groups { for _, group := range c.groups {
node := group.Next() node := group.Next()
if node == nil { if node == nil {
@ -36,7 +36,7 @@ func (c *Chain) GetRouteFor(network, address string) (r *Route) {
WithRoute(r) WithRoute(r)
node = node.Copy(). node = node.Copy().
WithTransport(tr) WithTransport(tr)
r = &Route{} r = &route{}
} }
r.AddNode(node) r.AddNode(node)

View File

@ -15,15 +15,15 @@ var (
ErrEmptyRoute = errors.New("empty route") ErrEmptyRoute = errors.New("empty route")
) )
type Route struct { type route struct {
nodes []*Node nodes []*Node
} }
func (r *Route) AddNode(node *Node) { func (r *route) AddNode(node *Node) {
r.nodes = append(r.nodes, node) r.nodes = append(r.nodes, node)
} }
func (r *Route) connect(ctx context.Context) (conn net.Conn, err error) { func (r *route) connect(ctx context.Context) (conn net.Conn, err error) {
if r.IsEmpty() { if r.IsEmpty() {
return nil, ErrEmptyRoute return nil, ErrEmptyRoute
} }
@ -67,7 +67,7 @@ func (r *Route) connect(ctx context.Context) (conn net.Conn, err error) {
return return
} }
func (r *Route) Dial(ctx context.Context, network, address string) (net.Conn, error) { func (r *route) Dial(ctx context.Context, network, address string) (net.Conn, error) {
if r.IsEmpty() { if r.IsEmpty() {
return r.dialDirect(ctx, network, address) return r.dialDirect(ctx, network, address)
} }
@ -85,7 +85,7 @@ func (r *Route) Dial(ctx context.Context, network, address string) (net.Conn, er
return cc, nil return cc, nil
} }
func (r *Route) dialDirect(ctx context.Context, network, address string) (net.Conn, error) { func (r *route) dialDirect(ctx context.Context, network, address string) (net.Conn, error) {
switch network { switch network {
case "udp", "udp4", "udp6": case "udp", "udp4", "udp6":
if address == "" { if address == "" {
@ -98,7 +98,7 @@ func (r *Route) dialDirect(ctx context.Context, network, address string) (net.Co
return d.DialContext(ctx, network, address) return d.DialContext(ctx, network, address)
} }
func (r *Route) Bind(ctx context.Context, network, address string, opts ...connector.BindOption) (net.Listener, error) { func (r *route) Bind(ctx context.Context, network, address string, opts ...connector.BindOption) (net.Listener, error) {
if r.IsEmpty() { if r.IsEmpty() {
return r.bindLocal(ctx, network, address, opts...) return r.bindLocal(ctx, network, address, opts...)
} }
@ -117,18 +117,18 @@ func (r *Route) Bind(ctx context.Context, network, address string, opts ...conne
return ln, nil return ln, nil
} }
func (r *Route) IsEmpty() bool { func (r *route) IsEmpty() bool {
return r == nil || len(r.nodes) == 0 return r == nil || len(r.nodes) == 0
} }
func (r *Route) Last() *Node { func (r *route) Last() *Node {
if r.IsEmpty() { if r.IsEmpty() {
return nil return nil
} }
return r.nodes[len(r.nodes)-1] return r.nodes[len(r.nodes)-1]
} }
func (r *Route) Path() (path []*Node) { func (r *route) Path() (path []*Node) {
if r == nil || len(r.nodes) == 0 { if r == nil || len(r.nodes) == 0 {
return nil return nil
} }
@ -142,7 +142,7 @@ func (r *Route) Path() (path []*Node) {
return return
} }
func (r *Route) bindLocal(ctx context.Context, network, address string, opts ...connector.BindOption) (net.Listener, error) { func (r *route) bindLocal(ctx context.Context, network, address string, opts ...connector.BindOption) (net.Listener, error) {
options := connector.BindOptions{} options := connector.BindOptions{}
for _, opt := range opts { for _, opt := range opts {
opt(&options) opt(&options)

View File

@ -3,11 +3,11 @@ package chain
import ( import (
"bytes" "bytes"
"context" "context"
"errors"
"fmt" "fmt"
"net" "net"
"github.com/go-gost/gost/pkg/connector" "github.com/go-gost/gost/pkg/connector"
"github.com/go-gost/gost/pkg/hosts"
"github.com/go-gost/gost/pkg/logger" "github.com/go-gost/gost/pkg/logger"
"github.com/go-gost/gost/pkg/resolver" "github.com/go-gost/gost/pkg/resolver"
) )
@ -15,6 +15,7 @@ import (
type Router struct { type Router struct {
Retries int Retries int
Chain *Chain Chain *Chain
Hosts *hosts.Hosts
Resolver resolver.Resolver Resolver resolver.Resolver
Logger logger.Logger Logger logger.Logger
} }
@ -77,11 +78,10 @@ func (r *Router) resolve(ctx context.Context, addr string) (string, error) {
return "", err return "", err
} }
/* if ip := r.Hosts.Lookup(host); ip != nil {
if ip := hosts.Lookup(host); ip != nil { r.Logger.Debugf("hit hosts: %s -> %s", host, ip)
return net.JoinHostPort(ip.String(), port) return net.JoinHostPort(ip.String(), port), nil
} }
*/
if r.Resolver != nil { if r.Resolver != nil {
ips, err := r.Resolver.Resolve(ctx, host) ips, err := r.Resolver.Resolve(ctx, host)
@ -89,7 +89,7 @@ func (r *Router) resolve(ctx context.Context, addr string) (string, error) {
r.Logger.Error(err) r.Logger.Error(err)
} }
if len(ips) == 0 { if len(ips) == 0 {
return "", errors.New("domain not exists") return "", fmt.Errorf("resolver: domain %s does not exists", host)
} }
return net.JoinHostPort(ips[0].String(), port), nil return net.JoinHostPort(ips[0].String(), port), nil
} }

View File

@ -10,7 +10,7 @@ import (
type Transport struct { type Transport struct {
addr string addr string
route *Route route *route
dialer dialer.Dialer dialer dialer.Dialer
connector connector.Connector connector connector.Connector
} }
@ -82,7 +82,7 @@ func (tr *Transport) Multiplex() bool {
return false return false
} }
func (tr *Transport) WithRoute(r *Route) *Transport { func (tr *Transport) WithRoute(r *route) *Transport {
tr.route = r tr.route = r
return tr return tr
} }

View File

@ -47,6 +47,33 @@ type BypassConfig struct {
Reverse bool `yaml:",omitempty"` Reverse bool `yaml:",omitempty"`
Matchers []string Matchers []string
} }
type NameserverConfig struct {
Addr string
Chain string
Prefer string
ClientIP string
Hostname string
TTL time.Duration
Timeout time.Duration
}
type ResolverConfig struct {
Name string
Nameservers []NameserverConfig
}
type HostConfig struct {
IP string
Hostname string
Aliases []string
}
type HostsConfig struct {
Name string
Entries []HostConfig
}
type ListenerConfig struct { type ListenerConfig struct {
Type string Type string
Metadata map[string]interface{} `yaml:",omitempty"` Metadata map[string]interface{} `yaml:",omitempty"`
@ -78,6 +105,8 @@ type ServiceConfig struct {
Addr string `yaml:",omitempty"` Addr string `yaml:",omitempty"`
Chain string `yaml:",omitempty"` Chain string `yaml:",omitempty"`
Bypass string `yaml:",omitempty"` Bypass string `yaml:",omitempty"`
Resolver string `yaml:",omitempty"`
Hosts string `yaml:",omitempty"`
Listener *ListenerConfig `yaml:",omitempty"` Listener *ListenerConfig `yaml:",omitempty"`
Handler *HandlerConfig `yaml:",omitempty"` Handler *HandlerConfig `yaml:",omitempty"`
Forwarder *ForwarderConfig `yaml:",omitempty"` Forwarder *ForwarderConfig `yaml:",omitempty"`
@ -108,9 +137,11 @@ type Config struct {
Log *LogConfig `yaml:",omitempty"` Log *LogConfig `yaml:",omitempty"`
Profiling *ProfilingConfig `yaml:",omitempty"` Profiling *ProfilingConfig `yaml:",omitempty"`
TLS *TLSConfig `yaml:",omitempty"` TLS *TLSConfig `yaml:",omitempty"`
Services []*ServiceConfig
Chains []*ChainConfig `yaml:",omitempty"`
Bypasses []*BypassConfig `yaml:",omitempty"` Bypasses []*BypassConfig `yaml:",omitempty"`
Resolvers []*ResolverConfig `yaml:",omitempty"`
Hosts []*HostsConfig `yaml:",omitempty"`
Chains []*ChainConfig `yaml:",omitempty"`
Services []*ServiceConfig
} }
func (c *Config) Load() error { func (c *Config) Load() error {

56
pkg/hosts/hosts.go Normal file
View File

@ -0,0 +1,56 @@
package hosts
import (
"net"
)
// Host is a static mapping from hostname to IP.
type Host struct {
IP net.IP
Hostname string
Aliases []string
}
// NewHost creates a Host.
func NewHost(ip net.IP, hostname string, aliases ...string) Host {
return Host{
IP: ip,
Hostname: hostname,
Aliases: aliases,
}
}
// Hosts is a static table lookup for hostnames.
// For each host a single line should be present with the following information:
// IP_address canonical_hostname [aliases...]
// Fields of the entry are separated by any number of blanks and/or tab characters.
// Text from a "#" character until the end of the line is a comment, and is ignored.
type Hosts struct {
hosts []Host
}
// AddHost adds host(s) to the host table.
func (h *Hosts) AddHost(host ...Host) {
h.hosts = append(h.hosts, host...)
}
// Lookup searches the IP address corresponds to the given host from the host table.
func (h *Hosts) Lookup(host string) (ip net.IP) {
if h == nil || host == "" {
return
}
for _, h := range h.hosts {
if h.Hostname == host {
ip = h.IP
break
}
for _, alias := range h.Aliases {
if alias == host {
ip = h.IP
break
}
}
}
return
}

View File

@ -65,7 +65,10 @@ func NewResolver(nameservers []NameServer, opts ...ResolverOption) (resolverpkg.
} }
ex, err := exchanger.NewExchanger( ex, err := exchanger.NewExchanger(
addr, addr,
exchanger.ChainOption(server.Chain), exchanger.RouterOption(&chain.Router{
Chain: server.Chain,
Logger: options.logger,
}),
exchanger.TimeoutOption(server.Timeout), exchanger.TimeoutOption(server.Timeout),
exchanger.LoggerOption(options.logger), exchanger.LoggerOption(options.logger),
) )