update config parsing

This commit is contained in:
ginuerzh 2022-02-07 22:15:16 +08:00
parent 0983ecc52f
commit 1ec10ff7dd
23 changed files with 846 additions and 534 deletions

View File

@ -1,4 +1,5 @@
FROM --platform=$BUILDPLATFORM golang:1-alpine as builder FROM --platform=$BUILDPLATFORM golang:1-alpine as builder
# FROM --platform=$BUILDPLATFORM golang:1.18-rc-alpine as builder
# Convert TARGETPLATFORM to GOARCH format # Convert TARGETPLATFORM to GOARCH format
# https://github.com/tonistiigi/xx # https://github.com/tonistiigi/xx

View File

@ -1,308 +1,76 @@
package main package main
import ( import (
"crypto/tls"
"io" "io"
"net"
"net/url"
"os" "os"
"strings"
"github.com/go-gost/gost/pkg/bypass"
"github.com/go-gost/gost/pkg/chain"
"github.com/go-gost/gost/pkg/config" "github.com/go-gost/gost/pkg/config"
"github.com/go-gost/gost/pkg/connector" "github.com/go-gost/gost/pkg/config/parsing"
"github.com/go-gost/gost/pkg/dialer"
"github.com/go-gost/gost/pkg/handler"
hostspkg "github.com/go-gost/gost/pkg/hosts"
"github.com/go-gost/gost/pkg/listener"
"github.com/go-gost/gost/pkg/logger" "github.com/go-gost/gost/pkg/logger"
"github.com/go-gost/gost/pkg/metadata"
"github.com/go-gost/gost/pkg/registry" "github.com/go-gost/gost/pkg/registry"
"github.com/go-gost/gost/pkg/resolver"
resolver_impl "github.com/go-gost/gost/pkg/resolver/impl"
"github.com/go-gost/gost/pkg/service" "github.com/go-gost/gost/pkg/service"
) )
var (
chains = make(map[string]*chain.Chain)
bypasses = make(map[string]bypass.Bypass)
resolvers = make(map[string]resolver.Resolver)
hosts = make(map[string]hostspkg.HostMapper)
)
func buildService(cfg *config.Config) (services []*service.Service) { func buildService(cfg *config.Config) (services []*service.Service) {
if cfg == nil || len(cfg.Services) == 0 { if cfg == nil || len(cfg.Services) == 0 {
return return
} }
for _, bypassCfg := range cfg.Bypasses { for _, bypassCfg := range cfg.Bypasses {
bypasses[bypassCfg.Name] = bypassFromConfig(bypassCfg) if bp := parsing.ParseBypass(bypassCfg); bp != nil {
if err := registry.Bypass().Register(bypassCfg.Name, bp); err != nil {
log.Fatal(err)
}
}
} }
for _, resolverCfg := range cfg.Resolvers { for _, resolverCfg := range cfg.Resolvers {
r, err := resolverFromConfig(resolverCfg) r, err := parsing.ParseResolver(resolverCfg)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
resolvers[resolverCfg.Name] = r if r != nil {
if err := registry.Resolver().Register(resolverCfg.Name, r); err != nil {
log.Fatal(err)
}
}
} }
for _, hostsCfg := range cfg.Hosts { for _, hostsCfg := range cfg.Hosts {
hosts[hostsCfg.Name] = hostsFromConfig(hostsCfg) if h := parsing.ParseHosts(hostsCfg); h != nil {
if err := registry.Hosts().Register(hostsCfg.Name, h); err != nil {
log.Fatal(err)
}
}
} }
for _, chainCfg := range cfg.Chains { for _, chainCfg := range cfg.Chains {
chains[chainCfg.Name] = chainFromConfig(chainCfg) c, err := parsing.ParseChain(chainCfg)
if err != nil {
log.Fatal(err)
}
if c != nil {
if err := registry.Chain().Register(chainCfg.Name, c); err != nil {
log.Fatal(err)
}
}
} }
for _, svc := range cfg.Services { for _, svcCfg := range cfg.Services {
if svc.Listener == nil { svc, err := parsing.ParseService(svcCfg)
svc.Listener = &config.ListenerConfig{
Type: "tcp",
}
}
if svc.Handler == nil {
svc.Handler = &config.HandlerConfig{
Type: "auto",
}
}
serviceLogger := log.WithFields(map[string]interface{}{
"kind": "service",
"service": svc.Name,
"listener": svc.Listener.Type,
"handler": svc.Handler.Type,
})
listenerLogger := serviceLogger.WithFields(map[string]interface{}{
"kind": "listener",
})
var tlsConfig *tls.Config
var err error
tlsCfg := svc.Listener.TLS
if tlsCfg == nil {
tlsCfg = &config.TLSConfig{}
}
tlsConfig, err = loadServerTLSConfig(tlsCfg)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
if svc != nil {
ln := registry.GetListener(svc.Listener.Type)( if err := registry.Service().Register(svcCfg.Name, svc); err != nil {
listener.AddrOption(svc.Addr), log.Fatal(err)
listener.ChainOption(chains[svc.Listener.Chain]), }
listener.AuthsOption(authsFromConfig(svc.Listener.Auths...)...),
listener.TLSConfigOption(tlsConfig),
listener.LoggerOption(listenerLogger),
)
if svc.Listener.Metadata == nil {
svc.Listener.Metadata = make(map[string]interface{})
} }
if err := ln.Init(metadata.MapMetadata(svc.Listener.Metadata)); err != nil {
listenerLogger.Fatal("init: ", err)
}
handlerLogger := serviceLogger.WithFields(map[string]interface{}{
"kind": "handler",
})
tlsConfig = nil
tlsCfg = svc.Handler.TLS
if tlsCfg == nil {
tlsCfg = &config.TLSConfig{}
}
tlsConfig, err = loadServerTLSConfig(tlsCfg)
if err != nil {
log.Fatal(err)
}
h := registry.GetHandler(svc.Handler.Type)(
handler.AuthsOption(authsFromConfig(svc.Handler.Auths...)...),
handler.RetriesOption(svc.Handler.Retries),
handler.ChainOption(chains[svc.Handler.Chain]),
handler.BypassOption(bypasses[svc.Bypass]),
handler.ResolverOption(resolvers[svc.Resolver]),
handler.HostsOption(hosts[svc.Hosts]),
handler.TLSConfigOption(tlsConfig),
handler.LoggerOption(handlerLogger),
)
if forwarder, ok := h.(handler.Forwarder); ok {
forwarder.Forward(forwarderFromConfig(svc.Forwarder))
}
if svc.Handler.Metadata == nil {
svc.Handler.Metadata = make(map[string]interface{})
}
if err := h.Init(metadata.MapMetadata(svc.Handler.Metadata)); err != nil {
handlerLogger.Fatal("init: ", err)
}
s := (&service.Service{}).
WithListener(ln).
WithHandler(h).
WithLogger(serviceLogger)
services = append(services, s)
serviceLogger.Infof("listening on %s/%s", s.Addr().String(), s.Addr().Network())
} }
return return
} }
func chainFromConfig(cfg *config.ChainConfig) *chain.Chain {
if cfg == nil {
return nil
}
chainLogger := log.WithFields(map[string]interface{}{
"kind": "chain",
"chain": cfg.Name,
})
c := &chain.Chain{}
selector := selectorFromConfig(cfg.Selector)
for _, hop := range cfg.Hops {
group := &chain.NodeGroup{}
for _, v := range hop.Nodes {
nodeLogger := chainLogger.WithFields(map[string]interface{}{
"kind": "node",
"connector": v.Connector.Type,
"dialer": v.Dialer.Type,
"hop": hop.Name,
"node": v.Name,
})
connectorLogger := nodeLogger.WithFields(map[string]interface{}{
"kind": "connector",
})
var user *url.Userinfo
if auth := v.Connector.Auth; auth != nil && auth.Username != "" {
if auth.Password == "" {
user = url.User(auth.Username)
} else {
user = url.UserPassword(auth.Username, auth.Password)
}
}
var tlsConfig *tls.Config
var err error
tlsCfg := v.Connector.TLS
if tlsCfg == nil {
tlsCfg = &config.TLSConfig{}
}
tlsConfig, err = loadClientTLSConfig(tlsCfg)
if err != nil {
log.Fatal(err)
}
cr := registry.GetConnector(v.Connector.Type)(
connector.UserOption(user),
connector.TLSConfigOption(tlsConfig),
connector.LoggerOption(connectorLogger),
)
if v.Connector.Metadata == nil {
v.Connector.Metadata = make(map[string]interface{})
}
if err := cr.Init(metadata.MapMetadata(v.Connector.Metadata)); err != nil {
connectorLogger.Fatal("init: ", err)
}
dialerLogger := nodeLogger.WithFields(map[string]interface{}{
"kind": "dialer",
})
user = nil
if auth := v.Dialer.Auth; auth != nil && auth.Username != "" {
if auth.Password == "" {
user = url.User(auth.Username)
} else {
user = url.UserPassword(auth.Username, auth.Password)
}
}
tlsConfig = nil
tlsCfg = v.Dialer.TLS
if tlsCfg == nil {
tlsCfg = &config.TLSConfig{}
}
tlsConfig, err = loadClientTLSConfig(tlsCfg)
if err != nil {
log.Fatal(err)
}
d := registry.GetDialer(v.Dialer.Type)(
dialer.UserOption(user),
dialer.TLSConfigOption(tlsConfig),
dialer.LoggerOption(dialerLogger),
)
if v.Dialer.Metadata == nil {
v.Dialer.Metadata = make(map[string]interface{})
}
if err := d.Init(metadata.MapMetadata(v.Dialer.Metadata)); err != nil {
dialerLogger.Fatal("init: ", err)
}
tr := (&chain.Transport{}).
WithConnector(cr).
WithDialer(d).
WithAddr(v.Addr)
if v.Bypass == "" {
v.Bypass = hop.Bypass
}
if v.Resolver == "" {
v.Resolver = hop.Resolver
}
if v.Hosts == "" {
v.Hosts = hop.Hosts
}
node := &chain.Node{
Name: v.Name,
Addr: v.Addr,
Transport: tr,
Bypass: bypasses[v.Bypass],
Resolver: resolvers[v.Resolver],
Hosts: hosts[v.Hosts],
Marker: &chain.FailMarker{},
}
group.AddNode(node)
}
sel := selector
if s := selectorFromConfig(hop.Selector); s != nil {
sel = s
}
group.WithSelector(sel)
c.AddNodeGroup(group)
}
return c
}
func forwarderFromConfig(cfg *config.ForwarderConfig) *chain.NodeGroup {
if cfg == nil || len(cfg.Targets) == 0 {
return nil
}
group := &chain.NodeGroup{}
for _, target := range cfg.Targets {
if v := strings.TrimSpace(target); v != "" {
group.AddNode(&chain.Node{
Name: target,
Addr: target,
Marker: &chain.FailMarker{},
})
}
}
return group.WithSelector(selectorFromConfig(cfg.Selector))
}
func logFromConfig(cfg *config.LogConfig) logger.Logger { func logFromConfig(cfg *config.LogConfig) logger.Logger {
if cfg == nil { if cfg == nil {
cfg = &config.LogConfig{} cfg = &config.LogConfig{}
@ -314,7 +82,7 @@ func logFromConfig(cfg *config.LogConfig) logger.Logger {
var out io.Writer = os.Stderr var out io.Writer = os.Stderr
switch cfg.Output { switch cfg.Output {
case "none": case "none", "null":
return logger.Nop() return logger.Nop()
case "stdout": case "stdout":
out = os.Stdout out = os.Stdout
@ -332,105 +100,3 @@ func logFromConfig(cfg *config.LogConfig) logger.Logger {
return logger.NewLogger(opts...) return logger.NewLogger(opts...)
} }
func selectorFromConfig(cfg *config.SelectorConfig) chain.Selector {
if cfg == nil {
return nil
}
var strategy chain.Strategy
switch cfg.Strategy {
case "round", "rr":
strategy = chain.RoundRobinStrategy()
case "random", "rand":
strategy = chain.RandomStrategy()
case "fifo", "ha":
strategy = chain.FIFOStrategy()
default:
strategy = chain.RoundRobinStrategy()
}
return chain.NewSelector(
strategy,
chain.InvalidFilter(),
chain.FailFilter(cfg.MaxFails, cfg.FailTimeout),
)
}
func bypassFromConfig(cfg *config.BypassConfig) bypass.Bypass {
if cfg == nil {
return nil
}
return bypass.NewBypassPatterns(
cfg.Reverse,
cfg.Matchers,
bypass.LoggerBypassOption(log.WithFields(map[string]interface{}{
"kind": "bypass",
"bypass": cfg.Name,
})),
)
}
func resolverFromConfig(cfg *config.ResolverConfig) (resolver.Resolver, error) {
if cfg == nil {
return nil, nil
}
var nameservers []resolver_impl.NameServer
for _, server := range cfg.Nameservers {
nameservers = append(nameservers, resolver_impl.NameServer{
Addr: server.Addr,
Chain: chains[server.Chain],
TTL: server.TTL,
Timeout: server.Timeout,
ClientIP: net.ParseIP(server.ClientIP),
Prefer: server.Prefer,
Hostname: server.Hostname,
})
}
logger := log.WithFields(map[string]interface{}{
"kind": "resolver",
"resolver": cfg.Name,
})
return resolver_impl.NewResolver(
nameservers,
resolver_impl.LoggerResolverOption(logger),
)
}
func hostsFromConfig(cfg *config.HostsConfig) hostspkg.HostMapper {
if cfg == nil || len(cfg.Mappings) == 0 {
return nil
}
hosts := hostspkg.NewHosts()
hosts.Logger = log.WithFields(map[string]interface{}{
"kind": "hosts",
"hosts": cfg.Name,
})
for _, host := range cfg.Mappings {
if host.IP == "" || host.Hostname == "" {
continue
}
ip := net.ParseIP(host.IP)
if ip == nil {
continue
}
hosts.Map(ip, host.Hostname, host.Aliases...)
}
return hosts
}
func authsFromConfig(cfgs ...*config.AuthConfig) []*url.Userinfo {
var auths []*url.Userinfo
for _, cfg := range cfgs {
if cfg == nil || cfg.Username == "" {
continue
}
auths = append(auths, url.UserPassword(cfg.Username, cfg.Password))
}
return auths
}

View File

@ -3,7 +3,6 @@ package main
import ( import (
"flag" "flag"
"fmt" "fmt"
"io"
"net/http" "net/http"
_ "net/http/pprof" _ "net/http/pprof"
"os" "os"
@ -16,11 +15,11 @@ import (
var ( var (
log = logger.Default() log = logger.Default()
cfgFile string cfgFile string
outputCfgFile string outputFormat string
services stringList services stringList
nodes stringList nodes stringList
debug bool debug bool
) )
func init() { func init() {
@ -31,7 +30,7 @@ func init() {
flag.StringVar(&cfgFile, "C", "", "configure file") flag.StringVar(&cfgFile, "C", "", "configure file")
flag.BoolVar(&printVersion, "V", false, "print version") flag.BoolVar(&printVersion, "V", false, "print version")
flag.BoolVar(&debug, "D", false, "debug mode") flag.BoolVar(&debug, "D", false, "debug mode")
flag.StringVar(&outputCfgFile, "O", "", "write config to FILE") flag.StringVar(&outputFormat, "O", "", "output format, one of yaml|json format")
flag.Parse() flag.Parse()
if printVersion { if printVersion {
@ -65,19 +64,8 @@ func main() {
log = logFromConfig(cfg.Log) log = logFromConfig(cfg.Log)
if outputCfgFile != "" { if outputFormat != "" {
var w io.Writer if err := cfg.Write(os.Stdout, outputFormat); err != nil {
if outputCfgFile == "-" {
w = os.Stdout
} else {
f, err := os.Create(outputCfgFile)
if err != nil {
log.Fatal(err)
}
defer f.Close()
w = f
}
if err := cfg.Write(w); err != nil {
log.Fatal(err) log.Fatal(err)
} }
os.Exit(0) os.Exit(0)

View File

@ -14,14 +14,6 @@ import (
"github.com/go-gost/gost/pkg/config" "github.com/go-gost/gost/pkg/config"
) )
func loadServerTLSConfig(cfg *config.TLSConfig) (*tls.Config, error) {
return tls_util.LoadServerConfig(cfg.CertFile, cfg.KeyFile, cfg.CAFile)
}
func loadClientTLSConfig(cfg *config.TLSConfig) (*tls.Config, error) {
return tls_util.LoadClientConfig(cfg.CertFile, cfg.KeyFile, cfg.CAFile, cfg.Secure, cfg.ServerName)
}
func buildDefaultTLSConfig(cfg *config.TLSConfig) { func buildDefaultTLSConfig(cfg *config.TLSConfig) {
if cfg == nil { if cfg == nil {
cfg = &config.TLSConfig{ cfg = &config.TLSConfig{

View File

@ -1,5 +1,9 @@
package chain package chain
type Chainer interface {
Route(network, address string) *Route
}
type Chain struct { type Chain struct {
groups []*NodeGroup groups []*NodeGroup
} }
@ -8,16 +12,12 @@ func (c *Chain) AddNodeGroup(group *NodeGroup) {
c.groups = append(c.groups, group) c.groups = append(c.groups, group)
} }
func (c *Chain) GetRoute() (r *route) { func (c *Chain) Route(network, address string) (r *Route) {
return c.GetRouteFor("tcp", "")
}
func (c *Chain) GetRouteFor(network, address string) (r *route) {
if c == nil || len(c.groups) == 0 { if c == nil || len(c.groups) == 0 {
return return
} }
r = &route{} r = &Route{}
for _, group := range c.groups { for _, group := range c.groups {
node := group.Next() node := group.Next()
if node == nil { if node == nil {
@ -32,14 +32,10 @@ func (c *Chain) GetRouteFor(network, address string) (r *route) {
WithRoute(r) WithRoute(r)
node = node.Copy() node = node.Copy()
node.Transport = tr node.Transport = tr
r = &route{} r = &Route{}
} }
r.AddNode(node) r.addNode(node)
} }
return r return r
} }
func (c *Chain) IsEmpty() bool {
return c == nil || len(c.groups) == 0
}

View File

@ -10,7 +10,7 @@ import (
"github.com/go-gost/gost/pkg/resolver" "github.com/go-gost/gost/pkg/resolver"
) )
func resolve(ctx context.Context, addr string, resolver resolver.Resolver, hosts hosts.HostMapper, log logger.Logger) (string, error) { func resolve(ctx context.Context, network, addr string, resolver resolver.Resolver, hosts hosts.HostMapper, log logger.Logger) (string, error) {
if addr == "" { if addr == "" {
return addr, nil return addr, nil
} }
@ -24,14 +24,14 @@ func resolve(ctx context.Context, addr string, resolver resolver.Resolver, hosts
} }
if hosts != nil { if hosts != nil {
if ips, _ := hosts.Lookup("ip", host); len(ips) > 0 { if ips, _ := hosts.Lookup(network, host); len(ips) > 0 {
log.Debugf("hit host mapper: %s -> %s", host, ips) log.Debugf("hit host mapper: %s -> %s", host, ips)
return net.JoinHostPort(ips[0].String(), port), nil return net.JoinHostPort(ips[0].String(), port), nil
} }
} }
if resolver != nil { if resolver != nil {
ips, err := resolver.Resolve(ctx, host) ips, err := resolver.Resolve(ctx, network, host)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
} }

View File

@ -15,17 +15,17 @@ var (
ErrEmptyRoute = errors.New("empty route") ErrEmptyRoute = errors.New("empty route")
) )
type route struct { type Route struct {
nodes []*Node nodes []*Node
logger logger.Logger logger logger.Logger
} }
func (r *route) AddNode(node *Node) { func (r *Route) addNode(node *Node) {
r.nodes = append(r.nodes, node) r.nodes = append(r.nodes, node)
} }
func (r *route) Dial(ctx context.Context, network, address string) (net.Conn, error) { func (r *Route) Dial(ctx context.Context, network, address string) (net.Conn, error) {
if r.IsEmpty() { if r.Len() == 0 {
return r.dialDirect(ctx, network, address) return r.dialDirect(ctx, network, address)
} }
@ -34,7 +34,7 @@ func (r *route) Dial(ctx context.Context, network, address string) (net.Conn, er
return nil, err return nil, err
} }
cc, err := r.Last().Transport.Connect(ctx, conn, network, address) cc, err := r.GetNode(r.Len()-1).Transport.Connect(ctx, conn, network, address)
if err != nil { if err != nil {
conn.Close() conn.Close()
return nil, err return nil, err
@ -42,7 +42,7 @@ func (r *route) Dial(ctx context.Context, network, address string) (net.Conn, er
return cc, nil return cc, nil
} }
func (r *route) dialDirect(ctx context.Context, network, address string) (net.Conn, error) { func (r *Route) dialDirect(ctx context.Context, network, address string) (net.Conn, error) {
switch network { switch network {
case "udp", "udp4", "udp6": case "udp", "udp4", "udp6":
if address == "" { if address == "" {
@ -55,8 +55,8 @@ func (r *route) dialDirect(ctx context.Context, network, address string) (net.Co
return d.DialContext(ctx, network, address) return d.DialContext(ctx, network, address)
} }
func (r *route) Bind(ctx context.Context, network, address string, opts ...connector.BindOption) (net.Listener, error) { func (r *Route) Bind(ctx context.Context, network, address string, opts ...connector.BindOption) (net.Listener, error) {
if r.IsEmpty() { if r.Len() == 0 {
return r.bindLocal(ctx, network, address, opts...) return r.bindLocal(ctx, network, address, opts...)
} }
@ -65,7 +65,7 @@ func (r *route) Bind(ctx context.Context, network, address string, opts ...conne
return nil, err return nil, err
} }
ln, err := r.Last().Transport.Bind(ctx, conn, network, address, opts...) ln, err := r.GetNode(r.Len()-1).Transport.Bind(ctx, conn, network, address, opts...)
if err != nil { if err != nil {
conn.Close() conn.Close()
return nil, err return nil, err
@ -74,14 +74,15 @@ func (r *route) Bind(ctx context.Context, network, address string, opts ...conne
return ln, nil return ln, nil
} }
func (r *route) connect(ctx context.Context) (conn net.Conn, err error) { func (r *Route) connect(ctx context.Context) (conn net.Conn, err error) {
if r.IsEmpty() { if r.Len() == 0 {
return nil, ErrEmptyRoute return nil, ErrEmptyRoute
} }
network := "ip"
node := r.nodes[0] node := r.nodes[0]
addr, err := resolve(ctx, node.Addr, node.Resolver, node.Hosts, r.logger) addr, err := resolve(ctx, network, node.Addr, node.Resolver, node.Hosts, r.logger)
if err != nil { if err != nil {
node.Marker.Mark() node.Marker.Mark()
return return
@ -102,7 +103,7 @@ func (r *route) connect(ctx context.Context) (conn net.Conn, err error) {
preNode := node preNode := node
for _, node := range r.nodes[1:] { for _, node := range r.nodes[1:] {
addr, err = resolve(ctx, node.Addr, node.Resolver, node.Hosts, r.logger) addr, err = resolve(ctx, network, node.Addr, node.Resolver, node.Hosts, r.logger)
if err != nil { if err != nil {
cn.Close() cn.Close()
node.Marker.Mark() node.Marker.Mark()
@ -130,18 +131,21 @@ func (r *route) connect(ctx context.Context) (conn net.Conn, err error) {
return return
} }
func (r *route) IsEmpty() bool { func (r *Route) Len() int {
return r == nil || len(r.nodes) == 0 if r == nil {
return 0
}
return len(r.nodes)
} }
func (r *route) Last() *Node { func (r *Route) GetNode(index int) *Node {
if r.IsEmpty() { if r.Len() == 0 || index < 0 || index >= len(r.nodes) {
return nil return nil
} }
return r.nodes[len(r.nodes)-1] return r.nodes[index]
} }
func (r *route) Path() (path []*Node) { func (r *Route) Path() (path []*Node) {
if r == nil || len(r.nodes) == 0 { if r == nil || len(r.nodes) == 0 {
return nil return nil
} }
@ -155,7 +159,7 @@ func (r *route) Path() (path []*Node) {
return return
} }
func (r *route) bindLocal(ctx context.Context, network, address string, opts ...connector.BindOption) (net.Listener, error) { func (r *Route) bindLocal(ctx context.Context, network, address string, opts ...connector.BindOption) (net.Listener, error) {
options := connector.BindOptions{} options := connector.BindOptions{}
for _, opt := range opts { for _, opt := range opts {
opt(&options) opt(&options)

View File

@ -14,7 +14,7 @@ import (
type Router struct { type Router struct {
Retries int Retries int
Chain *Chain Chain Chainer
Hosts hosts.HostMapper Hosts hosts.HostMapper
Resolver resolver.Resolver Resolver resolver.Resolver
Logger logger.Logger Logger logger.Logger
@ -41,7 +41,10 @@ func (r *Router) dial(ctx context.Context, network, address string) (conn net.Co
r.Logger.Debugf("dial %s/%s", address, network) r.Logger.Debugf("dial %s/%s", address, network)
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
route := r.Chain.GetRouteFor(network, address) var route *Route
if r.Chain != nil {
route = r.Chain.Route(network, address)
}
if r.Logger.IsLevelEnabled(logger.DebugLevel) { if r.Logger.IsLevelEnabled(logger.DebugLevel) {
buf := bytes.Buffer{} buf := bytes.Buffer{}
@ -52,7 +55,7 @@ func (r *Router) dial(ctx context.Context, network, address string) (conn net.Co
r.Logger.Debugf("route(retry=%d) %s", i, buf.String()) r.Logger.Debugf("route(retry=%d) %s", i, buf.String())
} }
address, err = resolve(ctx, address, r.Resolver, r.Hosts, r.Logger) address, err = resolve(ctx, "ip", address, r.Resolver, r.Hosts, r.Logger)
if err != nil { if err != nil {
r.Logger.Error(err) r.Logger.Error(err)
break break
@ -80,7 +83,10 @@ func (r *Router) Bind(ctx context.Context, network, address string, opts ...conn
r.Logger.Debugf("bind on %s/%s", address, network) r.Logger.Debugf("bind on %s/%s", address, network)
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
route := r.Chain.GetRouteFor(network, address) var route *Route
if r.Chain != nil {
route = r.Chain.Route(network, address)
}
if r.Logger.IsLevelEnabled(logger.DebugLevel) { if r.Logger.IsLevelEnabled(logger.DebugLevel) {
buf := bytes.Buffer{} buf := bytes.Buffer{}

View File

@ -10,7 +10,7 @@ import (
type Transport struct { type Transport struct {
addr string addr string
route *route route *Route
dialer dialer.Dialer dialer dialer.Dialer
connector connector.Connector connector connector.Connector
} }
@ -39,7 +39,7 @@ func (tr *Transport) dialOptions() []dialer.DialOption {
opts := []dialer.DialOption{ opts := []dialer.DialOption{
dialer.HostDialOption(tr.addr), dialer.HostDialOption(tr.addr),
} }
if !tr.route.IsEmpty() { if tr.route.Len() > 0 {
opts = append(opts, opts = append(opts,
dialer.DialFuncDialOption( dialer.DialFuncDialOption(
func(ctx context.Context, addr string) (net.Conn, error) { func(ctx context.Context, addr string) (net.Conn, error) {
@ -84,7 +84,7 @@ func (tr *Transport) Multiplex() bool {
return false return false
} }
func (tr *Transport) WithRoute(r *route) *Transport { func (tr *Transport) WithRoute(r *Route) *Transport {
tr.route = r tr.route = r
return tr return tr
} }

View File

@ -1,6 +1,7 @@
package config package config
import ( import (
"encoding/json"
"io" "io"
"time" "time"
@ -20,148 +21,148 @@ func init() {
} }
type LogConfig struct { type LogConfig struct {
Output string `yaml:",omitempty"` Output string `yaml:",omitempty" json:"output,omitempty"`
Level string `yaml:",omitempty"` Level string `yaml:",omitempty" json:"level,omitempty"`
Format string `yaml:",omitempty"` Format string `yaml:",omitempty" json:"format,omitempty"`
} }
type ProfilingConfig struct { type ProfilingConfig struct {
Addr string Addr string `json:"addr"`
Enabled bool Enabled bool `json:"enabled"`
} }
type TLSConfig struct { type TLSConfig struct {
CertFile string `yaml:"certFile,omitempty"` CertFile string `yaml:"certFile,omitempty" json:"certFile,omitempty"`
KeyFile string `yaml:"keyFile,omitempty"` KeyFile string `yaml:"keyFile,omitempty" json:"keyFile,omitempty"`
CAFile string `yaml:"caFile,omitempty"` CAFile string `yaml:"caFile,omitempty" json:"caFile,omitempty"`
Secure bool `yaml:",omitempty"` Secure bool `yaml:",omitempty" json:"secure,omitempty"`
ServerName string `yaml:"serverName,omitempty"` ServerName string `yaml:"serverName,omitempty" json:"serverName,omitempty"`
} }
type AuthConfig struct { type AuthConfig struct {
Username string Username string `json:"username"`
Password string Password string `yaml:",omitempty" json:"password,omitempty"`
} }
type SelectorConfig struct { type SelectorConfig struct {
Strategy string Strategy string `json:"strategy"`
MaxFails int `yaml:"maxFails"` MaxFails int `yaml:"maxFails" json:"maxFails"`
FailTimeout time.Duration `yaml:"failTimeout"` FailTimeout time.Duration `yaml:"failTimeout" json:"failTimeout"`
} }
type BypassConfig struct { type BypassConfig struct {
Name string Name string `json:"name"`
Reverse bool `yaml:",omitempty"` Reverse bool `yaml:",omitempty" json:"reverse,omitempty"`
Matchers []string Matchers []string `json:"matchers"`
} }
type NameserverConfig struct { type NameserverConfig struct {
Addr string Addr string `json:"addr"`
Chain string `yaml:",omitempty"` Chain string `yaml:",omitempty" json:"chain,omitempty"`
Prefer string `yaml:",omitempty"` Prefer string `yaml:",omitempty" json:"prefer,omitempty"`
ClientIP string `yaml:"clientIP,omitempty"` ClientIP string `yaml:"clientIP,omitempty" json:"clientIP,omitempty"`
Hostname string `yaml:",omitempty"` Hostname string `yaml:",omitempty" json:"hostname,omitempty"`
TTL time.Duration `yaml:",omitempty"` TTL time.Duration `yaml:",omitempty" json:"ttl,omitempty"`
Timeout time.Duration `yaml:",omitempty"` Timeout time.Duration `yaml:",omitempty" json:"timeout,omitempty"`
} }
type ResolverConfig struct { type ResolverConfig struct {
Name string Name string `json:"name"`
Nameservers []NameserverConfig Nameservers []NameserverConfig `json:"nameservers"`
} }
type HostMappingConfig struct { type HostMappingConfig struct {
IP string IP string `json:"ip"`
Hostname string Hostname string `json:"hostname"`
Aliases []string `yaml:",omitempty"` Aliases []string `yaml:",omitempty" json:"aliases,omitempty"`
} }
type HostsConfig struct { type HostsConfig struct {
Name string Name string `json:"name"`
Mappings []HostMappingConfig Mappings []HostMappingConfig `json:"mappings"`
} }
type ListenerConfig struct { type ListenerConfig struct {
Type string Type string `json:"type"`
Chain string `yaml:",omitempty"` Chain string `yaml:",omitempty" json:"chain,omitempty"`
Auths []*AuthConfig `yaml:",omitempty"` Auths []*AuthConfig `yaml:",omitempty" json:"auths,omitempty"`
TLS *TLSConfig `yaml:",omitempty"` TLS *TLSConfig `yaml:",omitempty" json:"tls,omitempty"`
Metadata map[string]interface{} `yaml:",omitempty"` Metadata map[string]interface{} `yaml:",omitempty" json:"metadata,omitempty"`
} }
type HandlerConfig struct { type HandlerConfig struct {
Type string Type string `json:"type"`
Retries int `yaml:",omitempty"` Retries int `yaml:",omitempty" json:"retries,omitempty"`
Chain string `yaml:",omitempty"` Chain string `yaml:",omitempty" json:"chain,omitempty"`
Auths []*AuthConfig `yaml:",omitempty"` Auths []*AuthConfig `yaml:",omitempty" json:"auths,omitempty"`
TLS *TLSConfig `yaml:",omitempty"` TLS *TLSConfig `yaml:",omitempty" json:"tls,omitempty"`
Metadata map[string]interface{} `yaml:",omitempty"` Metadata map[string]interface{} `yaml:",omitempty" json:"metadata,omitempty"`
} }
type ForwarderConfig struct { type ForwarderConfig struct {
Targets []string Targets []string `json:"targets"`
Selector *SelectorConfig `yaml:",omitempty"` Selector *SelectorConfig `yaml:",omitempty" json:"selector,omitempty"`
} }
type DialerConfig struct { type DialerConfig struct {
Type string Type string `json:"type"`
Auth *AuthConfig `yaml:",omitempty"` Auth *AuthConfig `yaml:",omitempty" json:"auth,omitempty"`
TLS *TLSConfig `yaml:",omitempty"` TLS *TLSConfig `yaml:",omitempty" json:"tls,omitempty"`
Metadata map[string]interface{} `yaml:",omitempty"` Metadata map[string]interface{} `yaml:",omitempty" json:"metadata,omitempty"`
} }
type ConnectorConfig struct { type ConnectorConfig struct {
Type string Type string `json:"type"`
Auth *AuthConfig `yaml:",omitempty"` Auth *AuthConfig `yaml:",omitempty" json:"auth,omitempty"`
TLS *TLSConfig `yaml:",omitempty"` TLS *TLSConfig `yaml:",omitempty" json:"tls,omitempty"`
Metadata map[string]interface{} `yaml:",omitempty"` Metadata map[string]interface{} `yaml:",omitempty" json:"metadata,omitempty"`
} }
type ServiceConfig struct { type ServiceConfig struct {
Name string Name string `json:"name"`
Addr string `yaml:",omitempty"` Addr string `yaml:",omitempty" json:"addr,omitempty"`
Bypass string `yaml:",omitempty"` Bypass string `yaml:",omitempty" json:"bypass,omitempty"`
Resolver string `yaml:",omitempty"` Resolver string `yaml:",omitempty" json:"resolver,omitempty"`
Hosts string `yaml:",omitempty"` Hosts string `yaml:",omitempty" json:"hosts,omitempty"`
Handler *HandlerConfig `yaml:",omitempty"` Handler *HandlerConfig `yaml:",omitempty" json:"handler,omitempty"`
Listener *ListenerConfig `yaml:",omitempty"` Listener *ListenerConfig `yaml:",omitempty" json:"listener,omitempty"`
Forwarder *ForwarderConfig `yaml:",omitempty"` Forwarder *ForwarderConfig `yaml:",omitempty" json:"forwarder,omitempty"`
} }
type ChainConfig struct { type ChainConfig struct {
Name string Name string `json:"name"`
Selector *SelectorConfig `yaml:",omitempty"` Selector *SelectorConfig `yaml:",omitempty" json:"selector,omitempty"`
Hops []*HopConfig Hops []*HopConfig `json:"hops"`
} }
type HopConfig struct { type HopConfig struct {
Name string Name string `json:"name"`
Selector *SelectorConfig `yaml:",omitempty"` Selector *SelectorConfig `yaml:",omitempty" json:"selector,omitempty"`
Bypass string `yaml:",omitempty"` Bypass string `yaml:",omitempty" json:"bypass,omitempty"`
Resolver string `yaml:",omitempty"` Resolver string `yaml:",omitempty" json:"resolver,omitempty"`
Hosts string `yaml:",omitempty"` Hosts string `yaml:",omitempty" json:"hosts,omitempty"`
Nodes []*NodeConfig Nodes []*NodeConfig `json:"nodes"`
} }
type NodeConfig struct { type NodeConfig struct {
Name string Name string `json:"name"`
Addr string `yaml:",omitempty"` Addr string `yaml:",omitempty" json:"addr,omitempty"`
Bypass string `yaml:",omitempty"` Bypass string `yaml:",omitempty" json:"bypass,omitempty"`
Resolver string `yaml:",omitempty"` Resolver string `yaml:",omitempty" json:"resolver,omitempty"`
Hosts string `yaml:",omitempty"` Hosts string `yaml:",omitempty" json:"hosts,omitempty"`
Connector *ConnectorConfig `yaml:",omitempty"` Connector *ConnectorConfig `yaml:",omitempty" json:"connector,omitempty"`
Dialer *DialerConfig `yaml:",omitempty"` Dialer *DialerConfig `yaml:",omitempty" json:"dialer,omitempty"`
} }
type Config struct { type Config struct {
Services []*ServiceConfig Services []*ServiceConfig `json:"services"`
Chains []*ChainConfig `yaml:",omitempty"` Chains []*ChainConfig `yaml:",omitempty" json:"chains,omitempty"`
Bypasses []*BypassConfig `yaml:",omitempty"` Bypasses []*BypassConfig `yaml:",omitempty" json:"bypasses,omitempty"`
Resolvers []*ResolverConfig `yaml:",omitempty"` Resolvers []*ResolverConfig `yaml:",omitempty" json:"resolvers,omitempty"`
Hosts []*HostsConfig `yaml:",omitempty"` Hosts []*HostsConfig `yaml:",omitempty" json:"hosts,omitempty"`
TLS *TLSConfig `yaml:",omitempty"` TLS *TLSConfig `yaml:",omitempty" json:"tls,omitempty"`
Log *LogConfig `yaml:",omitempty"` Log *LogConfig `yaml:",omitempty" json:"log,omitempty"`
Profiling *ProfilingConfig `yaml:",omitempty"` Profiling *ProfilingConfig `yaml:",omitempty" json:"profiling,omitempty"`
} }
func (c *Config) Load() error { func (c *Config) Load() error {
@ -188,9 +189,19 @@ func (c *Config) ReadFile(file string) error {
return v.Unmarshal(c) return v.Unmarshal(c)
} }
func (c *Config) Write(w io.Writer) error { func (c *Config) Write(w io.Writer, format string) error {
enc := yaml.NewEncoder(w) switch format {
defer enc.Close() case "json":
enc := json.NewEncoder(w)
enc.SetIndent("", " ")
enc.Encode(c)
return nil
case "yaml":
fallthrough
default:
enc := yaml.NewEncoder(w)
defer enc.Close()
return enc.Encode(c) return enc.Encode(c)
}
} }

152
pkg/config/parsing/chain.go Normal file
View File

@ -0,0 +1,152 @@
package parsing
import (
"net/url"
"github.com/go-gost/gost/pkg/chain"
tls_util "github.com/go-gost/gost/pkg/common/util/tls"
"github.com/go-gost/gost/pkg/config"
"github.com/go-gost/gost/pkg/connector"
"github.com/go-gost/gost/pkg/dialer"
"github.com/go-gost/gost/pkg/logger"
"github.com/go-gost/gost/pkg/metadata"
"github.com/go-gost/gost/pkg/registry"
)
func ParseChain(cfg *config.ChainConfig) (chain.Chainer, error) {
if cfg == nil {
return nil, nil
}
chainLogger := logger.Default().WithFields(map[string]interface{}{
"kind": "chain",
"chain": cfg.Name,
})
c := &chain.Chain{}
selector := parseSelector(cfg.Selector)
for _, hop := range cfg.Hops {
group := &chain.NodeGroup{}
for _, v := range hop.Nodes {
nodeLogger := chainLogger.WithFields(map[string]interface{}{
"kind": "node",
"connector": v.Connector.Type,
"dialer": v.Dialer.Type,
"hop": hop.Name,
"node": v.Name,
})
connectorLogger := nodeLogger.WithFields(map[string]interface{}{
"kind": "connector",
})
var user *url.Userinfo
if auth := v.Connector.Auth; auth != nil && auth.Username != "" {
if auth.Password == "" {
user = url.User(auth.Username)
} else {
user = url.UserPassword(auth.Username, auth.Password)
}
}
tlsCfg := v.Connector.TLS
if tlsCfg == nil {
tlsCfg = &config.TLSConfig{}
}
tlsConfig, err := tls_util.LoadClientConfig(
tlsCfg.CertFile, tlsCfg.KeyFile, tlsCfg.CAFile,
tlsCfg.Secure, tlsCfg.ServerName)
if err != nil {
chainLogger.Error(err)
return nil, err
}
cr := registry.GetConnector(v.Connector.Type)(
connector.UserOption(user),
connector.TLSConfigOption(tlsConfig),
connector.LoggerOption(connectorLogger),
)
if v.Connector.Metadata == nil {
v.Connector.Metadata = make(map[string]interface{})
}
if err := cr.Init(metadata.MapMetadata(v.Connector.Metadata)); err != nil {
connectorLogger.Error("init: ", err)
return nil, err
}
dialerLogger := nodeLogger.WithFields(map[string]interface{}{
"kind": "dialer",
})
user = nil
if auth := v.Dialer.Auth; auth != nil && auth.Username != "" {
if auth.Password == "" {
user = url.User(auth.Username)
} else {
user = url.UserPassword(auth.Username, auth.Password)
}
}
tlsCfg = v.Dialer.TLS
if tlsCfg == nil {
tlsCfg = &config.TLSConfig{}
}
tlsConfig, err = tls_util.LoadClientConfig(
tlsCfg.CertFile, tlsCfg.KeyFile, tlsCfg.CAFile,
tlsCfg.Secure, tlsCfg.ServerName)
if err != nil {
chainLogger.Error(err)
return nil, err
}
d := registry.GetDialer(v.Dialer.Type)(
dialer.UserOption(user),
dialer.TLSConfigOption(tlsConfig),
dialer.LoggerOption(dialerLogger),
)
if v.Dialer.Metadata == nil {
v.Dialer.Metadata = make(map[string]interface{})
}
if err := d.Init(metadata.MapMetadata(v.Dialer.Metadata)); err != nil {
dialerLogger.Error("init: ", err)
return nil, err
}
tr := (&chain.Transport{}).
WithConnector(cr).
WithDialer(d).
WithAddr(v.Addr)
if v.Bypass == "" {
v.Bypass = hop.Bypass
}
if v.Resolver == "" {
v.Resolver = hop.Resolver
}
if v.Hosts == "" {
v.Hosts = hop.Hosts
}
node := &chain.Node{
Name: v.Name,
Addr: v.Addr,
Transport: tr,
Bypass: registry.Bypass().Get(v.Bypass),
Resolver: registry.Resolver().Get(v.Resolver),
Hosts: registry.Hosts().Get(v.Hosts),
Marker: &chain.FailMarker{},
}
group.AddNode(node)
}
sel := selector
if s := parseSelector(hop.Selector); s != nil {
sel = s
}
group.WithSelector(sel)
c.AddNodeGroup(group)
}
return c, nil
}

103
pkg/config/parsing/parse.go Normal file
View File

@ -0,0 +1,103 @@
package parsing
import (
"net"
"github.com/go-gost/gost/pkg/bypass"
"github.com/go-gost/gost/pkg/chain"
"github.com/go-gost/gost/pkg/config"
hostspkg "github.com/go-gost/gost/pkg/hosts"
"github.com/go-gost/gost/pkg/logger"
"github.com/go-gost/gost/pkg/resolver"
resolver_impl "github.com/go-gost/gost/pkg/resolver/impl"
)
func parseSelector(cfg *config.SelectorConfig) chain.Selector {
if cfg == nil {
return nil
}
var strategy chain.Strategy
switch cfg.Strategy {
case "round", "rr":
strategy = chain.RoundRobinStrategy()
case "random", "rand":
strategy = chain.RandomStrategy()
case "fifo", "ha":
strategy = chain.FIFOStrategy()
default:
strategy = chain.RoundRobinStrategy()
}
return chain.NewSelector(
strategy,
chain.InvalidFilter(),
chain.FailFilter(cfg.MaxFails, cfg.FailTimeout),
)
}
func ParseBypass(cfg *config.BypassConfig) bypass.Bypass {
if cfg == nil {
return nil
}
return bypass.NewBypassPatterns(
cfg.Reverse,
cfg.Matchers,
bypass.LoggerBypassOption(logger.Default().WithFields(map[string]interface{}{
"kind": "bypass",
"bypass": cfg.Name,
})),
)
}
func ParseResolver(cfg *config.ResolverConfig) (resolver.Resolver, error) {
if cfg == nil {
return nil, nil
}
var nameservers []resolver_impl.NameServer
for _, server := range cfg.Nameservers {
nameservers = append(nameservers, resolver_impl.NameServer{
Addr: server.Addr,
// Chain: chains[server.Chain],
TTL: server.TTL,
Timeout: server.Timeout,
ClientIP: net.ParseIP(server.ClientIP),
Prefer: server.Prefer,
Hostname: server.Hostname,
})
}
return resolver_impl.NewResolver(
nameservers,
resolver_impl.LoggerResolverOption(
logger.Default().WithFields(map[string]interface{}{
"kind": "resolver",
"resolver": cfg.Name,
}),
),
)
}
func ParseHosts(cfg *config.HostsConfig) hostspkg.HostMapper {
if cfg == nil || len(cfg.Mappings) == 0 {
return nil
}
hosts := hostspkg.NewHosts()
hosts.Logger = logger.Default().WithFields(map[string]interface{}{
"kind": "hosts",
"hosts": cfg.Name,
})
for _, host := range cfg.Mappings {
if host.IP == "" || host.Hostname == "" {
continue
}
ip := net.ParseIP(host.IP)
if ip == nil {
continue
}
hosts.Map(ip, host.Hostname, host.Aliases...)
}
return hosts
}

View File

@ -0,0 +1,143 @@
package parsing
import (
"net/url"
"strings"
"github.com/go-gost/gost/pkg/chain"
tls_util "github.com/go-gost/gost/pkg/common/util/tls"
"github.com/go-gost/gost/pkg/config"
"github.com/go-gost/gost/pkg/handler"
"github.com/go-gost/gost/pkg/listener"
"github.com/go-gost/gost/pkg/logger"
"github.com/go-gost/gost/pkg/metadata"
"github.com/go-gost/gost/pkg/registry"
"github.com/go-gost/gost/pkg/service"
)
func ParseService(cfg *config.ServiceConfig) (*service.Service, error) {
if cfg.Listener == nil {
cfg.Listener = &config.ListenerConfig{
Type: "tcp",
}
}
if cfg.Handler == nil {
cfg.Handler = &config.HandlerConfig{
Type: "auto",
}
}
serviceLogger := logger.Default().WithFields(map[string]interface{}{
"kind": "service",
"service": cfg.Name,
"listener": cfg.Listener.Type,
"handler": cfg.Handler.Type,
})
listenerLogger := serviceLogger.WithFields(map[string]interface{}{
"kind": "listener",
})
tlsCfg := cfg.Listener.TLS
if tlsCfg == nil {
tlsCfg = &config.TLSConfig{}
}
tlsConfig, err := tls_util.LoadServerConfig(
tlsCfg.CertFile, tlsCfg.KeyFile, tlsCfg.CAFile)
if err != nil {
listenerLogger.Error(err)
return nil, err
}
ln := registry.GetListener(cfg.Listener.Type)(
listener.AddrOption(cfg.Addr),
listener.ChainOption(registry.Chain().Get(cfg.Listener.Chain)),
listener.AuthsOption(parseAuths(cfg.Listener.Auths...)...),
listener.TLSConfigOption(tlsConfig),
listener.LoggerOption(listenerLogger),
)
if cfg.Listener.Metadata == nil {
cfg.Listener.Metadata = make(map[string]interface{})
}
if err := ln.Init(metadata.MapMetadata(cfg.Listener.Metadata)); err != nil {
listenerLogger.Error("init: ", err)
return nil, err
}
handlerLogger := serviceLogger.WithFields(map[string]interface{}{
"kind": "handler",
})
tlsCfg = cfg.Handler.TLS
if tlsCfg == nil {
tlsCfg = &config.TLSConfig{}
}
tlsConfig, err = tls_util.LoadServerConfig(
tlsCfg.CertFile, tlsCfg.KeyFile, tlsCfg.CAFile)
if err != nil {
handlerLogger.Error(err)
return nil, err
}
h := registry.GetHandler(cfg.Handler.Type)(
handler.AuthsOption(parseAuths(cfg.Handler.Auths...)...),
handler.RetriesOption(cfg.Handler.Retries),
handler.ChainOption(registry.Chain().Get(cfg.Handler.Chain)),
handler.BypassOption(registry.Bypass().Get(cfg.Bypass)),
handler.ResolverOption(registry.Resolver().Get(cfg.Resolver)),
handler.HostsOption(registry.Hosts().Get(cfg.Hosts)),
handler.TLSConfigOption(tlsConfig),
handler.LoggerOption(handlerLogger),
)
if forwarder, ok := h.(handler.Forwarder); ok {
forwarder.Forward(parseForwarder(cfg.Forwarder))
}
if cfg.Handler.Metadata == nil {
cfg.Handler.Metadata = make(map[string]interface{})
}
if err := h.Init(metadata.MapMetadata(cfg.Handler.Metadata)); err != nil {
handlerLogger.Error("init: ", err)
return nil, err
}
s := (&service.Service{}).
WithListener(ln).
WithHandler(h).
WithLogger(serviceLogger)
serviceLogger.Infof("listening on %s/%s", s.Addr().String(), s.Addr().Network())
return s, nil
}
func parseAuths(cfgs ...*config.AuthConfig) []*url.Userinfo {
var auths []*url.Userinfo
for _, cfg := range cfgs {
if cfg == nil || cfg.Username == "" {
continue
}
auths = append(auths, url.UserPassword(cfg.Username, cfg.Password))
}
return auths
}
func parseForwarder(cfg *config.ForwarderConfig) *chain.NodeGroup {
if cfg == nil || len(cfg.Targets) == 0 {
return nil
}
group := &chain.NodeGroup{}
for _, target := range cfg.Targets {
if v := strings.TrimSpace(target); v != "" {
group.AddNode(&chain.Node{
Name: target,
Addr: target,
Marker: &chain.FailMarker{},
})
}
}
return group.WithSelector(parseSelector(cfg.Selector))
}

View File

@ -13,7 +13,7 @@ import (
type Options struct { type Options struct {
Retries int Retries int
Chain *chain.Chain Chain chain.Chainer
Resolver resolver.Resolver Resolver resolver.Resolver
Hosts hosts.HostMapper Hosts hosts.HostMapper
Bypass bypass.Bypass Bypass bypass.Bypass
@ -30,7 +30,7 @@ func RetriesOption(retries int) Option {
} }
} }
func ChainOption(chain *chain.Chain) Option { func ChainOption(chain chain.Chainer) Option {
return func(opts *Options) { return func(opts *Options) {
opts.Chain = chain opts.Chain = chain
} }

View File

@ -12,7 +12,7 @@ type Options struct {
Addr string Addr string
Auths []*url.Userinfo Auths []*url.Userinfo
TLSConfig *tls.Config TLSConfig *tls.Config
Chain *chain.Chain Chain chain.Chainer
Logger logger.Logger Logger logger.Logger
} }
@ -36,7 +36,7 @@ func TLSConfigOption(tlsConfig *tls.Config) Option {
} }
} }
func ChainOption(chain *chain.Chain) Option { func ChainOption(chain chain.Chainer) Option {
return func(opts *Options) { return func(opts *Options) {
opts.Chain = chain opts.Chain = chain
} }

50
pkg/registry/bypass.go Normal file
View File

@ -0,0 +1,50 @@
package registry
import (
"sync"
"github.com/go-gost/gost/pkg/bypass"
)
var (
bypassReg = &bypassRegistry{}
)
func Bypass() *bypassRegistry {
return bypassReg
}
type bypassRegistry struct {
m sync.Map
}
func (r *bypassRegistry) Register(name string, bypass bypass.Bypass) error {
if _, loaded := r.m.LoadOrStore(name, bypass); loaded {
return ErrDup
}
return nil
}
func (r *bypassRegistry) Unregister(name string) {
r.m.Delete(name)
}
func (r *bypassRegistry) Get(name string) bypass.Bypass {
if _, ok := r.m.Load(name); !ok {
return nil
}
return &bypassWrapper{name: name}
}
type bypassWrapper struct {
name string
}
func (w *bypassWrapper) Contains(addr string) bool {
bp := bypassReg.Get(w.name)
if bp == nil {
return false
}
return bp.Contains(addr)
}

50
pkg/registry/chain.go Normal file
View File

@ -0,0 +1,50 @@
package registry
import (
"sync"
"github.com/go-gost/gost/pkg/chain"
)
var (
chainReg = &chainRegistry{}
)
func Chain() *chainRegistry {
return chainReg
}
type chainRegistry struct {
m sync.Map
}
func (r *chainRegistry) Register(name string, chain chain.Chainer) error {
if _, loaded := r.m.LoadOrStore(name, chain); loaded {
return ErrDup
}
return nil
}
func (r *chainRegistry) Unregister(name string) {
r.m.Delete(name)
}
func (r *chainRegistry) Get(name string) chain.Chainer {
if _, ok := r.m.Load(name); !ok {
return nil
}
return &chainWrapper{name: name}
}
type chainWrapper struct {
name string
}
func (w *chainWrapper) Route(network, address string) *chain.Route {
v := Chain().Get(w.name)
if v == nil {
return nil
}
return v.Route(network, address)
}

51
pkg/registry/hosts.go Normal file
View File

@ -0,0 +1,51 @@
package registry
import (
"net"
"sync"
"github.com/go-gost/gost/pkg/hosts"
)
var (
hostsReg = &hostsRegistry{}
)
func Hosts() *hostsRegistry {
return hostsReg
}
type hostsRegistry struct {
m sync.Map
}
func (r *hostsRegistry) Register(name string, hosts hosts.HostMapper) error {
if _, loaded := r.m.LoadOrStore(name, hosts); loaded {
return ErrDup
}
return nil
}
func (r *hostsRegistry) Unregister(name string) {
r.m.Delete(name)
}
func (r *hostsRegistry) Get(name string) hosts.HostMapper {
if _, ok := r.m.Load(name); !ok {
return nil
}
return &hostsWrapper{name: name}
}
type hostsWrapper struct {
name string
}
func (w *hostsWrapper) Lookup(network, host string) ([]net.IP, bool) {
v := Hosts().Get(w.name)
if v == nil {
return nil, false
}
return v.Lookup(network, host)
}

View File

@ -1,6 +1,8 @@
package registry package registry
import ( import (
"errors"
"github.com/go-gost/gost/pkg/connector" "github.com/go-gost/gost/pkg/connector"
"github.com/go-gost/gost/pkg/dialer" "github.com/go-gost/gost/pkg/dialer"
"github.com/go-gost/gost/pkg/handler" "github.com/go-gost/gost/pkg/handler"
@ -8,6 +10,11 @@ import (
"github.com/go-gost/gost/pkg/logger" "github.com/go-gost/gost/pkg/logger"
) )
var (
ErrDup = errors.New("registry: duplicate instance")
ErrNotFound = errors.New("registry: instance not found")
)
type NewListener func(opts ...listener.Option) listener.Listener type NewListener func(opts ...listener.Option) listener.Listener
type NewHandler func(opts ...handler.Option) handler.Handler type NewHandler func(opts ...handler.Option) handler.Handler
type NewDialer func(opts ...dialer.Option) dialer.Dialer type NewDialer func(opts ...dialer.Option) dialer.Dialer

52
pkg/registry/resolver.go Normal file
View File

@ -0,0 +1,52 @@
package registry
import (
"context"
"net"
"sync"
"github.com/go-gost/gost/pkg/resolver"
)
var (
resolverReg = &resolverRegistry{}
)
func Resolver() *resolverRegistry {
return resolverReg
}
type resolverRegistry struct {
m sync.Map
}
func (r *resolverRegistry) Register(name string, resolver resolver.Resolver) error {
if _, loaded := r.m.LoadOrStore(name, resolver); loaded {
return ErrDup
}
return nil
}
func (r *resolverRegistry) Unregister(name string) {
r.m.Delete(name)
}
func (r *resolverRegistry) Get(name string) resolver.Resolver {
if _, ok := r.m.Load(name); !ok {
return nil
}
return &resolverWrapper{name: name}
}
type resolverWrapper struct {
name string
}
func (w *resolverWrapper) Resolve(ctx context.Context, network, host string) ([]net.IP, error) {
r := Resolver().Get(w.name)
if r == nil {
return nil, ErrNotFound
}
return r.Resolve(ctx, network, host)
}

39
pkg/registry/service.go Normal file
View File

@ -0,0 +1,39 @@
package registry
import (
"sync"
"github.com/go-gost/gost/pkg/service"
)
var (
svcReg = &serviceRegistry{}
)
func Service() *serviceRegistry {
return svcReg
}
type serviceRegistry struct {
m sync.Map
}
func (r *serviceRegistry) Register(name string, svc *service.Service) error {
if _, loaded := r.m.LoadOrStore(name, svc); loaded {
return ErrDup
}
return nil
}
func (r *serviceRegistry) Unregister(name string) {
r.m.Delete(name)
}
func (r *serviceRegistry) Get(name string) *service.Service {
v, ok := r.m.Load(name)
if !ok {
return nil
}
return v.(*service.Service)
}

View File

@ -16,7 +16,7 @@ import (
type NameServer struct { type NameServer struct {
Addr string Addr string
Chain *chain.Chain Chain chain.Chainer
TTL time.Duration TTL time.Duration
Timeout time.Duration Timeout time.Duration
ClientIP net.IP ClientIP net.IP
@ -89,7 +89,7 @@ func NewResolver(nameservers []NameServer, opts ...ResolverOption) (resolverpkg.
}, nil }, nil
} }
func (r *resolver) Resolve(ctx context.Context, host string) (ips []net.IP, err error) { func (r *resolver) Resolve(ctx context.Context, network, host string) (ips []net.IP, err error) {
if ip := net.ParseIP(host); ip != nil { if ip := net.ParseIP(host); ip != nil {
return []net.IP{ip}, nil return []net.IP{ip}, nil
} }

View File

@ -7,5 +7,6 @@ import (
type Resolver interface { type Resolver interface {
// Resolve returns a slice of the host's IPv4 and IPv6 addresses. // Resolve returns a slice of the host's IPv4 and IPv6 addresses.
Resolve(ctx context.Context, host string) ([]net.IP, error) // The network should be 'ip', 'ip4' or 'ip6', default network is 'ip'.
Resolve(ctx context.Context, network, host string) ([]net.IP, error)
} }