diff --git a/Dockerfile b/Dockerfile index a943757..b3fb06b 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,5 @@ FROM --platform=$BUILDPLATFORM golang:1-alpine as builder +# FROM --platform=$BUILDPLATFORM golang:1.18-rc-alpine as builder # Convert TARGETPLATFORM to GOARCH format # https://github.com/tonistiigi/xx diff --git a/cmd/gost/config.go b/cmd/gost/config.go index 44254e9..c68fca3 100644 --- a/cmd/gost/config.go +++ b/cmd/gost/config.go @@ -1,308 +1,76 @@ package main import ( - "crypto/tls" "io" - "net" - "net/url" "os" - "strings" - "github.com/go-gost/gost/pkg/bypass" - "github.com/go-gost/gost/pkg/chain" "github.com/go-gost/gost/pkg/config" - "github.com/go-gost/gost/pkg/connector" - "github.com/go-gost/gost/pkg/dialer" - "github.com/go-gost/gost/pkg/handler" - hostspkg "github.com/go-gost/gost/pkg/hosts" - "github.com/go-gost/gost/pkg/listener" + "github.com/go-gost/gost/pkg/config/parsing" "github.com/go-gost/gost/pkg/logger" - "github.com/go-gost/gost/pkg/metadata" "github.com/go-gost/gost/pkg/registry" - "github.com/go-gost/gost/pkg/resolver" - resolver_impl "github.com/go-gost/gost/pkg/resolver/impl" "github.com/go-gost/gost/pkg/service" ) -var ( - chains = make(map[string]*chain.Chain) - bypasses = make(map[string]bypass.Bypass) - resolvers = make(map[string]resolver.Resolver) - hosts = make(map[string]hostspkg.HostMapper) -) - func buildService(cfg *config.Config) (services []*service.Service) { if cfg == nil || len(cfg.Services) == 0 { return } for _, bypassCfg := range cfg.Bypasses { - bypasses[bypassCfg.Name] = bypassFromConfig(bypassCfg) + if bp := parsing.ParseBypass(bypassCfg); bp != nil { + if err := registry.Bypass().Register(bypassCfg.Name, bp); err != nil { + log.Fatal(err) + } + } } for _, resolverCfg := range cfg.Resolvers { - r, err := resolverFromConfig(resolverCfg) + r, err := parsing.ParseResolver(resolverCfg) if err != nil { log.Fatal(err) } - resolvers[resolverCfg.Name] = r + if r != nil { + if err := registry.Resolver().Register(resolverCfg.Name, r); err != nil { + log.Fatal(err) + } + } } + for _, hostsCfg := range cfg.Hosts { - hosts[hostsCfg.Name] = hostsFromConfig(hostsCfg) + if h := parsing.ParseHosts(hostsCfg); h != nil { + if err := registry.Hosts().Register(hostsCfg.Name, h); err != nil { + log.Fatal(err) + } + } } for _, chainCfg := range cfg.Chains { - chains[chainCfg.Name] = chainFromConfig(chainCfg) + c, err := parsing.ParseChain(chainCfg) + if err != nil { + log.Fatal(err) + } + if c != nil { + if err := registry.Chain().Register(chainCfg.Name, c); err != nil { + log.Fatal(err) + } + } } - for _, svc := range cfg.Services { - if svc.Listener == nil { - svc.Listener = &config.ListenerConfig{ - Type: "tcp", - } - } - if svc.Handler == nil { - svc.Handler = &config.HandlerConfig{ - Type: "auto", - } - } - serviceLogger := log.WithFields(map[string]interface{}{ - "kind": "service", - "service": svc.Name, - "listener": svc.Listener.Type, - "handler": svc.Handler.Type, - }) - - listenerLogger := serviceLogger.WithFields(map[string]interface{}{ - "kind": "listener", - }) - - var tlsConfig *tls.Config - var err error - - tlsCfg := svc.Listener.TLS - if tlsCfg == nil { - tlsCfg = &config.TLSConfig{} - } - tlsConfig, err = loadServerTLSConfig(tlsCfg) + for _, svcCfg := range cfg.Services { + svc, err := parsing.ParseService(svcCfg) if err != nil { log.Fatal(err) } - - ln := registry.GetListener(svc.Listener.Type)( - listener.AddrOption(svc.Addr), - listener.ChainOption(chains[svc.Listener.Chain]), - listener.AuthsOption(authsFromConfig(svc.Listener.Auths...)...), - listener.TLSConfigOption(tlsConfig), - listener.LoggerOption(listenerLogger), - ) - - if svc.Listener.Metadata == nil { - svc.Listener.Metadata = make(map[string]interface{}) + if svc != nil { + if err := registry.Service().Register(svcCfg.Name, svc); err != nil { + log.Fatal(err) + } } - if err := ln.Init(metadata.MapMetadata(svc.Listener.Metadata)); err != nil { - listenerLogger.Fatal("init: ", err) - } - - handlerLogger := serviceLogger.WithFields(map[string]interface{}{ - "kind": "handler", - }) - - tlsConfig = nil - tlsCfg = svc.Handler.TLS - if tlsCfg == nil { - tlsCfg = &config.TLSConfig{} - } - tlsConfig, err = loadServerTLSConfig(tlsCfg) - if err != nil { - log.Fatal(err) - } - - h := registry.GetHandler(svc.Handler.Type)( - handler.AuthsOption(authsFromConfig(svc.Handler.Auths...)...), - handler.RetriesOption(svc.Handler.Retries), - handler.ChainOption(chains[svc.Handler.Chain]), - handler.BypassOption(bypasses[svc.Bypass]), - handler.ResolverOption(resolvers[svc.Resolver]), - handler.HostsOption(hosts[svc.Hosts]), - handler.TLSConfigOption(tlsConfig), - handler.LoggerOption(handlerLogger), - ) - - if forwarder, ok := h.(handler.Forwarder); ok { - forwarder.Forward(forwarderFromConfig(svc.Forwarder)) - } - - if svc.Handler.Metadata == nil { - svc.Handler.Metadata = make(map[string]interface{}) - } - if err := h.Init(metadata.MapMetadata(svc.Handler.Metadata)); err != nil { - handlerLogger.Fatal("init: ", err) - } - - s := (&service.Service{}). - WithListener(ln). - WithHandler(h). - WithLogger(serviceLogger) - services = append(services, s) - - serviceLogger.Infof("listening on %s/%s", s.Addr().String(), s.Addr().Network()) } return } -func chainFromConfig(cfg *config.ChainConfig) *chain.Chain { - if cfg == nil { - return nil - } - - chainLogger := log.WithFields(map[string]interface{}{ - "kind": "chain", - "chain": cfg.Name, - }) - - c := &chain.Chain{} - selector := selectorFromConfig(cfg.Selector) - for _, hop := range cfg.Hops { - group := &chain.NodeGroup{} - for _, v := range hop.Nodes { - nodeLogger := chainLogger.WithFields(map[string]interface{}{ - "kind": "node", - "connector": v.Connector.Type, - "dialer": v.Dialer.Type, - "hop": hop.Name, - "node": v.Name, - }) - connectorLogger := nodeLogger.WithFields(map[string]interface{}{ - "kind": "connector", - }) - - var user *url.Userinfo - if auth := v.Connector.Auth; auth != nil && auth.Username != "" { - if auth.Password == "" { - user = url.User(auth.Username) - } else { - user = url.UserPassword(auth.Username, auth.Password) - } - } - - var tlsConfig *tls.Config - var err error - tlsCfg := v.Connector.TLS - if tlsCfg == nil { - tlsCfg = &config.TLSConfig{} - } - tlsConfig, err = loadClientTLSConfig(tlsCfg) - if err != nil { - log.Fatal(err) - } - - cr := registry.GetConnector(v.Connector.Type)( - connector.UserOption(user), - connector.TLSConfigOption(tlsConfig), - connector.LoggerOption(connectorLogger), - ) - - if v.Connector.Metadata == nil { - v.Connector.Metadata = make(map[string]interface{}) - } - if err := cr.Init(metadata.MapMetadata(v.Connector.Metadata)); err != nil { - connectorLogger.Fatal("init: ", err) - } - - dialerLogger := nodeLogger.WithFields(map[string]interface{}{ - "kind": "dialer", - }) - - user = nil - if auth := v.Dialer.Auth; auth != nil && auth.Username != "" { - if auth.Password == "" { - user = url.User(auth.Username) - } else { - user = url.UserPassword(auth.Username, auth.Password) - } - } - - tlsConfig = nil - tlsCfg = v.Dialer.TLS - if tlsCfg == nil { - tlsCfg = &config.TLSConfig{} - } - tlsConfig, err = loadClientTLSConfig(tlsCfg) - if err != nil { - log.Fatal(err) - } - - d := registry.GetDialer(v.Dialer.Type)( - dialer.UserOption(user), - dialer.TLSConfigOption(tlsConfig), - dialer.LoggerOption(dialerLogger), - ) - - if v.Dialer.Metadata == nil { - v.Dialer.Metadata = make(map[string]interface{}) - } - if err := d.Init(metadata.MapMetadata(v.Dialer.Metadata)); err != nil { - dialerLogger.Fatal("init: ", err) - } - - tr := (&chain.Transport{}). - WithConnector(cr). - WithDialer(d). - WithAddr(v.Addr) - - if v.Bypass == "" { - v.Bypass = hop.Bypass - } - if v.Resolver == "" { - v.Resolver = hop.Resolver - } - if v.Hosts == "" { - v.Hosts = hop.Hosts - } - - node := &chain.Node{ - Name: v.Name, - Addr: v.Addr, - Transport: tr, - Bypass: bypasses[v.Bypass], - Resolver: resolvers[v.Resolver], - Hosts: hosts[v.Hosts], - Marker: &chain.FailMarker{}, - } - group.AddNode(node) - } - - sel := selector - if s := selectorFromConfig(hop.Selector); s != nil { - sel = s - } - group.WithSelector(sel) - c.AddNodeGroup(group) - } - - return c -} - -func forwarderFromConfig(cfg *config.ForwarderConfig) *chain.NodeGroup { - if cfg == nil || len(cfg.Targets) == 0 { - return nil - } - - group := &chain.NodeGroup{} - for _, target := range cfg.Targets { - if v := strings.TrimSpace(target); v != "" { - group.AddNode(&chain.Node{ - Name: target, - Addr: target, - Marker: &chain.FailMarker{}, - }) - } - } - return group.WithSelector(selectorFromConfig(cfg.Selector)) -} - func logFromConfig(cfg *config.LogConfig) logger.Logger { if cfg == nil { cfg = &config.LogConfig{} @@ -314,7 +82,7 @@ func logFromConfig(cfg *config.LogConfig) logger.Logger { var out io.Writer = os.Stderr switch cfg.Output { - case "none": + case "none", "null": return logger.Nop() case "stdout": out = os.Stdout @@ -332,105 +100,3 @@ func logFromConfig(cfg *config.LogConfig) logger.Logger { return logger.NewLogger(opts...) } - -func selectorFromConfig(cfg *config.SelectorConfig) chain.Selector { - if cfg == nil { - return nil - } - - var strategy chain.Strategy - switch cfg.Strategy { - case "round", "rr": - strategy = chain.RoundRobinStrategy() - case "random", "rand": - strategy = chain.RandomStrategy() - case "fifo", "ha": - strategy = chain.FIFOStrategy() - default: - strategy = chain.RoundRobinStrategy() - } - - return chain.NewSelector( - strategy, - chain.InvalidFilter(), - chain.FailFilter(cfg.MaxFails, cfg.FailTimeout), - ) -} - -func bypassFromConfig(cfg *config.BypassConfig) bypass.Bypass { - if cfg == nil { - return nil - } - return bypass.NewBypassPatterns( - cfg.Reverse, - cfg.Matchers, - bypass.LoggerBypassOption(log.WithFields(map[string]interface{}{ - "kind": "bypass", - "bypass": cfg.Name, - })), - ) -} - -func resolverFromConfig(cfg *config.ResolverConfig) (resolver.Resolver, error) { - if cfg == nil { - return nil, nil - } - var nameservers []resolver_impl.NameServer - for _, server := range cfg.Nameservers { - nameservers = append(nameservers, resolver_impl.NameServer{ - Addr: server.Addr, - Chain: chains[server.Chain], - TTL: server.TTL, - Timeout: server.Timeout, - ClientIP: net.ParseIP(server.ClientIP), - Prefer: server.Prefer, - Hostname: server.Hostname, - }) - } - - logger := log.WithFields(map[string]interface{}{ - "kind": "resolver", - "resolver": cfg.Name, - }) - return resolver_impl.NewResolver( - nameservers, - resolver_impl.LoggerResolverOption(logger), - ) -} - -func hostsFromConfig(cfg *config.HostsConfig) hostspkg.HostMapper { - if cfg == nil || len(cfg.Mappings) == 0 { - return nil - } - hosts := hostspkg.NewHosts() - hosts.Logger = log.WithFields(map[string]interface{}{ - "kind": "hosts", - "hosts": cfg.Name, - }) - - for _, host := range cfg.Mappings { - if host.IP == "" || host.Hostname == "" { - continue - } - - ip := net.ParseIP(host.IP) - if ip == nil { - continue - } - hosts.Map(ip, host.Hostname, host.Aliases...) - } - return hosts -} - -func authsFromConfig(cfgs ...*config.AuthConfig) []*url.Userinfo { - var auths []*url.Userinfo - - for _, cfg := range cfgs { - if cfg == nil || cfg.Username == "" { - continue - } - auths = append(auths, url.UserPassword(cfg.Username, cfg.Password)) - } - - return auths -} diff --git a/cmd/gost/main.go b/cmd/gost/main.go index dbdc428..225b294 100644 --- a/cmd/gost/main.go +++ b/cmd/gost/main.go @@ -3,7 +3,6 @@ package main import ( "flag" "fmt" - "io" "net/http" _ "net/http/pprof" "os" @@ -16,11 +15,11 @@ import ( var ( log = logger.Default() - cfgFile string - outputCfgFile string - services stringList - nodes stringList - debug bool + cfgFile string + outputFormat string + services stringList + nodes stringList + debug bool ) func init() { @@ -31,7 +30,7 @@ func init() { flag.StringVar(&cfgFile, "C", "", "configure file") flag.BoolVar(&printVersion, "V", false, "print version") flag.BoolVar(&debug, "D", false, "debug mode") - flag.StringVar(&outputCfgFile, "O", "", "write config to FILE") + flag.StringVar(&outputFormat, "O", "", "output format, one of yaml|json format") flag.Parse() if printVersion { @@ -65,19 +64,8 @@ func main() { log = logFromConfig(cfg.Log) - if outputCfgFile != "" { - var w io.Writer - if outputCfgFile == "-" { - w = os.Stdout - } else { - f, err := os.Create(outputCfgFile) - if err != nil { - log.Fatal(err) - } - defer f.Close() - w = f - } - if err := cfg.Write(w); err != nil { + if outputFormat != "" { + if err := cfg.Write(os.Stdout, outputFormat); err != nil { log.Fatal(err) } os.Exit(0) diff --git a/cmd/gost/tls.go b/cmd/gost/tls.go index c2c9e24..3f83cb8 100644 --- a/cmd/gost/tls.go +++ b/cmd/gost/tls.go @@ -14,14 +14,6 @@ import ( "github.com/go-gost/gost/pkg/config" ) -func loadServerTLSConfig(cfg *config.TLSConfig) (*tls.Config, error) { - return tls_util.LoadServerConfig(cfg.CertFile, cfg.KeyFile, cfg.CAFile) -} - -func loadClientTLSConfig(cfg *config.TLSConfig) (*tls.Config, error) { - return tls_util.LoadClientConfig(cfg.CertFile, cfg.KeyFile, cfg.CAFile, cfg.Secure, cfg.ServerName) -} - func buildDefaultTLSConfig(cfg *config.TLSConfig) { if cfg == nil { cfg = &config.TLSConfig{ diff --git a/pkg/chain/chain.go b/pkg/chain/chain.go index 57275dc..90c3ba6 100644 --- a/pkg/chain/chain.go +++ b/pkg/chain/chain.go @@ -1,5 +1,9 @@ package chain +type Chainer interface { + Route(network, address string) *Route +} + type Chain struct { groups []*NodeGroup } @@ -8,16 +12,12 @@ func (c *Chain) AddNodeGroup(group *NodeGroup) { c.groups = append(c.groups, group) } -func (c *Chain) GetRoute() (r *route) { - return c.GetRouteFor("tcp", "") -} - -func (c *Chain) GetRouteFor(network, address string) (r *route) { +func (c *Chain) Route(network, address string) (r *Route) { if c == nil || len(c.groups) == 0 { return } - r = &route{} + r = &Route{} for _, group := range c.groups { node := group.Next() if node == nil { @@ -32,14 +32,10 @@ func (c *Chain) GetRouteFor(network, address string) (r *route) { WithRoute(r) node = node.Copy() node.Transport = tr - r = &route{} + r = &Route{} } - r.AddNode(node) + r.addNode(node) } return r } - -func (c *Chain) IsEmpty() bool { - return c == nil || len(c.groups) == 0 -} diff --git a/pkg/chain/resovle.go b/pkg/chain/resovle.go index 2a736b9..bad7573 100644 --- a/pkg/chain/resovle.go +++ b/pkg/chain/resovle.go @@ -10,7 +10,7 @@ import ( "github.com/go-gost/gost/pkg/resolver" ) -func resolve(ctx context.Context, addr string, resolver resolver.Resolver, hosts hosts.HostMapper, log logger.Logger) (string, error) { +func resolve(ctx context.Context, network, addr string, resolver resolver.Resolver, hosts hosts.HostMapper, log logger.Logger) (string, error) { if addr == "" { return addr, nil } @@ -24,14 +24,14 @@ func resolve(ctx context.Context, addr string, resolver resolver.Resolver, hosts } if hosts != nil { - if ips, _ := hosts.Lookup("ip", host); len(ips) > 0 { + if ips, _ := hosts.Lookup(network, host); len(ips) > 0 { log.Debugf("hit host mapper: %s -> %s", host, ips) return net.JoinHostPort(ips[0].String(), port), nil } } if resolver != nil { - ips, err := resolver.Resolve(ctx, host) + ips, err := resolver.Resolve(ctx, network, host) if err != nil { log.Error(err) } diff --git a/pkg/chain/route.go b/pkg/chain/route.go index ff3b89b..16beafe 100644 --- a/pkg/chain/route.go +++ b/pkg/chain/route.go @@ -15,17 +15,17 @@ var ( ErrEmptyRoute = errors.New("empty route") ) -type route struct { +type Route struct { nodes []*Node logger logger.Logger } -func (r *route) AddNode(node *Node) { +func (r *Route) addNode(node *Node) { r.nodes = append(r.nodes, node) } -func (r *route) Dial(ctx context.Context, network, address string) (net.Conn, error) { - if r.IsEmpty() { +func (r *Route) Dial(ctx context.Context, network, address string) (net.Conn, error) { + if r.Len() == 0 { return r.dialDirect(ctx, network, address) } @@ -34,7 +34,7 @@ func (r *route) Dial(ctx context.Context, network, address string) (net.Conn, er return nil, err } - cc, err := r.Last().Transport.Connect(ctx, conn, network, address) + cc, err := r.GetNode(r.Len()-1).Transport.Connect(ctx, conn, network, address) if err != nil { conn.Close() return nil, err @@ -42,7 +42,7 @@ func (r *route) Dial(ctx context.Context, network, address string) (net.Conn, er return cc, nil } -func (r *route) dialDirect(ctx context.Context, network, address string) (net.Conn, error) { +func (r *Route) dialDirect(ctx context.Context, network, address string) (net.Conn, error) { switch network { case "udp", "udp4", "udp6": if address == "" { @@ -55,8 +55,8 @@ func (r *route) dialDirect(ctx context.Context, network, address string) (net.Co return d.DialContext(ctx, network, address) } -func (r *route) Bind(ctx context.Context, network, address string, opts ...connector.BindOption) (net.Listener, error) { - if r.IsEmpty() { +func (r *Route) Bind(ctx context.Context, network, address string, opts ...connector.BindOption) (net.Listener, error) { + if r.Len() == 0 { return r.bindLocal(ctx, network, address, opts...) } @@ -65,7 +65,7 @@ func (r *route) Bind(ctx context.Context, network, address string, opts ...conne return nil, err } - ln, err := r.Last().Transport.Bind(ctx, conn, network, address, opts...) + ln, err := r.GetNode(r.Len()-1).Transport.Bind(ctx, conn, network, address, opts...) if err != nil { conn.Close() return nil, err @@ -74,14 +74,15 @@ func (r *route) Bind(ctx context.Context, network, address string, opts ...conne return ln, nil } -func (r *route) connect(ctx context.Context) (conn net.Conn, err error) { - if r.IsEmpty() { +func (r *Route) connect(ctx context.Context) (conn net.Conn, err error) { + if r.Len() == 0 { return nil, ErrEmptyRoute } + network := "ip" node := r.nodes[0] - addr, err := resolve(ctx, node.Addr, node.Resolver, node.Hosts, r.logger) + addr, err := resolve(ctx, network, node.Addr, node.Resolver, node.Hosts, r.logger) if err != nil { node.Marker.Mark() return @@ -102,7 +103,7 @@ func (r *route) connect(ctx context.Context) (conn net.Conn, err error) { preNode := node for _, node := range r.nodes[1:] { - addr, err = resolve(ctx, node.Addr, node.Resolver, node.Hosts, r.logger) + addr, err = resolve(ctx, network, node.Addr, node.Resolver, node.Hosts, r.logger) if err != nil { cn.Close() node.Marker.Mark() @@ -130,18 +131,21 @@ func (r *route) connect(ctx context.Context) (conn net.Conn, err error) { return } -func (r *route) IsEmpty() bool { - return r == nil || len(r.nodes) == 0 +func (r *Route) Len() int { + if r == nil { + return 0 + } + return len(r.nodes) } -func (r *route) Last() *Node { - if r.IsEmpty() { +func (r *Route) GetNode(index int) *Node { + if r.Len() == 0 || index < 0 || index >= len(r.nodes) { return nil } - return r.nodes[len(r.nodes)-1] + return r.nodes[index] } -func (r *route) Path() (path []*Node) { +func (r *Route) Path() (path []*Node) { if r == nil || len(r.nodes) == 0 { return nil } @@ -155,7 +159,7 @@ func (r *route) Path() (path []*Node) { return } -func (r *route) bindLocal(ctx context.Context, network, address string, opts ...connector.BindOption) (net.Listener, error) { +func (r *Route) bindLocal(ctx context.Context, network, address string, opts ...connector.BindOption) (net.Listener, error) { options := connector.BindOptions{} for _, opt := range opts { opt(&options) diff --git a/pkg/chain/router.go b/pkg/chain/router.go index 21d48a8..da6f716 100644 --- a/pkg/chain/router.go +++ b/pkg/chain/router.go @@ -14,7 +14,7 @@ import ( type Router struct { Retries int - Chain *Chain + Chain Chainer Hosts hosts.HostMapper Resolver resolver.Resolver Logger logger.Logger @@ -41,7 +41,10 @@ func (r *Router) dial(ctx context.Context, network, address string) (conn net.Co r.Logger.Debugf("dial %s/%s", address, network) for i := 0; i < count; i++ { - route := r.Chain.GetRouteFor(network, address) + var route *Route + if r.Chain != nil { + route = r.Chain.Route(network, address) + } if r.Logger.IsLevelEnabled(logger.DebugLevel) { buf := bytes.Buffer{} @@ -52,7 +55,7 @@ func (r *Router) dial(ctx context.Context, network, address string) (conn net.Co r.Logger.Debugf("route(retry=%d) %s", i, buf.String()) } - address, err = resolve(ctx, address, r.Resolver, r.Hosts, r.Logger) + address, err = resolve(ctx, "ip", address, r.Resolver, r.Hosts, r.Logger) if err != nil { r.Logger.Error(err) break @@ -80,7 +83,10 @@ func (r *Router) Bind(ctx context.Context, network, address string, opts ...conn r.Logger.Debugf("bind on %s/%s", address, network) for i := 0; i < count; i++ { - route := r.Chain.GetRouteFor(network, address) + var route *Route + if r.Chain != nil { + route = r.Chain.Route(network, address) + } if r.Logger.IsLevelEnabled(logger.DebugLevel) { buf := bytes.Buffer{} diff --git a/pkg/chain/transport.go b/pkg/chain/transport.go index c7938b9..bf32351 100644 --- a/pkg/chain/transport.go +++ b/pkg/chain/transport.go @@ -10,7 +10,7 @@ import ( type Transport struct { addr string - route *route + route *Route dialer dialer.Dialer connector connector.Connector } @@ -39,7 +39,7 @@ func (tr *Transport) dialOptions() []dialer.DialOption { opts := []dialer.DialOption{ dialer.HostDialOption(tr.addr), } - if !tr.route.IsEmpty() { + if tr.route.Len() > 0 { opts = append(opts, dialer.DialFuncDialOption( func(ctx context.Context, addr string) (net.Conn, error) { @@ -84,7 +84,7 @@ func (tr *Transport) Multiplex() bool { return false } -func (tr *Transport) WithRoute(r *route) *Transport { +func (tr *Transport) WithRoute(r *Route) *Transport { tr.route = r return tr } diff --git a/pkg/config/config.go b/pkg/config/config.go index cd683d7..b8c119f 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -1,6 +1,7 @@ package config import ( + "encoding/json" "io" "time" @@ -20,148 +21,148 @@ func init() { } type LogConfig struct { - Output string `yaml:",omitempty"` - Level string `yaml:",omitempty"` - Format string `yaml:",omitempty"` + Output string `yaml:",omitempty" json:"output,omitempty"` + Level string `yaml:",omitempty" json:"level,omitempty"` + Format string `yaml:",omitempty" json:"format,omitempty"` } type ProfilingConfig struct { - Addr string - Enabled bool + Addr string `json:"addr"` + Enabled bool `json:"enabled"` } type TLSConfig struct { - CertFile string `yaml:"certFile,omitempty"` - KeyFile string `yaml:"keyFile,omitempty"` - CAFile string `yaml:"caFile,omitempty"` - Secure bool `yaml:",omitempty"` - ServerName string `yaml:"serverName,omitempty"` + CertFile string `yaml:"certFile,omitempty" json:"certFile,omitempty"` + KeyFile string `yaml:"keyFile,omitempty" json:"keyFile,omitempty"` + CAFile string `yaml:"caFile,omitempty" json:"caFile,omitempty"` + Secure bool `yaml:",omitempty" json:"secure,omitempty"` + ServerName string `yaml:"serverName,omitempty" json:"serverName,omitempty"` } type AuthConfig struct { - Username string - Password string + Username string `json:"username"` + Password string `yaml:",omitempty" json:"password,omitempty"` } type SelectorConfig struct { - Strategy string - MaxFails int `yaml:"maxFails"` - FailTimeout time.Duration `yaml:"failTimeout"` + Strategy string `json:"strategy"` + MaxFails int `yaml:"maxFails" json:"maxFails"` + FailTimeout time.Duration `yaml:"failTimeout" json:"failTimeout"` } type BypassConfig struct { - Name string - Reverse bool `yaml:",omitempty"` - Matchers []string + Name string `json:"name"` + Reverse bool `yaml:",omitempty" json:"reverse,omitempty"` + Matchers []string `json:"matchers"` } type NameserverConfig struct { - Addr string - Chain string `yaml:",omitempty"` - Prefer string `yaml:",omitempty"` - ClientIP string `yaml:"clientIP,omitempty"` - Hostname string `yaml:",omitempty"` - TTL time.Duration `yaml:",omitempty"` - Timeout time.Duration `yaml:",omitempty"` + Addr string `json:"addr"` + Chain string `yaml:",omitempty" json:"chain,omitempty"` + Prefer string `yaml:",omitempty" json:"prefer,omitempty"` + ClientIP string `yaml:"clientIP,omitempty" json:"clientIP,omitempty"` + Hostname string `yaml:",omitempty" json:"hostname,omitempty"` + TTL time.Duration `yaml:",omitempty" json:"ttl,omitempty"` + Timeout time.Duration `yaml:",omitempty" json:"timeout,omitempty"` } type ResolverConfig struct { - Name string - Nameservers []NameserverConfig + Name string `json:"name"` + Nameservers []NameserverConfig `json:"nameservers"` } type HostMappingConfig struct { - IP string - Hostname string - Aliases []string `yaml:",omitempty"` + IP string `json:"ip"` + Hostname string `json:"hostname"` + Aliases []string `yaml:",omitempty" json:"aliases,omitempty"` } type HostsConfig struct { - Name string - Mappings []HostMappingConfig + Name string `json:"name"` + Mappings []HostMappingConfig `json:"mappings"` } type ListenerConfig struct { - Type string - Chain string `yaml:",omitempty"` - Auths []*AuthConfig `yaml:",omitempty"` - TLS *TLSConfig `yaml:",omitempty"` - Metadata map[string]interface{} `yaml:",omitempty"` + Type string `json:"type"` + Chain string `yaml:",omitempty" json:"chain,omitempty"` + Auths []*AuthConfig `yaml:",omitempty" json:"auths,omitempty"` + TLS *TLSConfig `yaml:",omitempty" json:"tls,omitempty"` + Metadata map[string]interface{} `yaml:",omitempty" json:"metadata,omitempty"` } type HandlerConfig struct { - Type string - Retries int `yaml:",omitempty"` - Chain string `yaml:",omitempty"` - Auths []*AuthConfig `yaml:",omitempty"` - TLS *TLSConfig `yaml:",omitempty"` - Metadata map[string]interface{} `yaml:",omitempty"` + Type string `json:"type"` + Retries int `yaml:",omitempty" json:"retries,omitempty"` + Chain string `yaml:",omitempty" json:"chain,omitempty"` + Auths []*AuthConfig `yaml:",omitempty" json:"auths,omitempty"` + TLS *TLSConfig `yaml:",omitempty" json:"tls,omitempty"` + Metadata map[string]interface{} `yaml:",omitempty" json:"metadata,omitempty"` } type ForwarderConfig struct { - Targets []string - Selector *SelectorConfig `yaml:",omitempty"` + Targets []string `json:"targets"` + Selector *SelectorConfig `yaml:",omitempty" json:"selector,omitempty"` } type DialerConfig struct { - Type string - Auth *AuthConfig `yaml:",omitempty"` - TLS *TLSConfig `yaml:",omitempty"` - Metadata map[string]interface{} `yaml:",omitempty"` + Type string `json:"type"` + Auth *AuthConfig `yaml:",omitempty" json:"auth,omitempty"` + TLS *TLSConfig `yaml:",omitempty" json:"tls,omitempty"` + Metadata map[string]interface{} `yaml:",omitempty" json:"metadata,omitempty"` } type ConnectorConfig struct { - Type string - Auth *AuthConfig `yaml:",omitempty"` - TLS *TLSConfig `yaml:",omitempty"` - Metadata map[string]interface{} `yaml:",omitempty"` + Type string `json:"type"` + Auth *AuthConfig `yaml:",omitempty" json:"auth,omitempty"` + TLS *TLSConfig `yaml:",omitempty" json:"tls,omitempty"` + Metadata map[string]interface{} `yaml:",omitempty" json:"metadata,omitempty"` } type ServiceConfig struct { - Name string - Addr string `yaml:",omitempty"` - Bypass string `yaml:",omitempty"` - Resolver string `yaml:",omitempty"` - Hosts string `yaml:",omitempty"` - Handler *HandlerConfig `yaml:",omitempty"` - Listener *ListenerConfig `yaml:",omitempty"` - Forwarder *ForwarderConfig `yaml:",omitempty"` + Name string `json:"name"` + Addr string `yaml:",omitempty" json:"addr,omitempty"` + Bypass string `yaml:",omitempty" json:"bypass,omitempty"` + Resolver string `yaml:",omitempty" json:"resolver,omitempty"` + Hosts string `yaml:",omitempty" json:"hosts,omitempty"` + Handler *HandlerConfig `yaml:",omitempty" json:"handler,omitempty"` + Listener *ListenerConfig `yaml:",omitempty" json:"listener,omitempty"` + Forwarder *ForwarderConfig `yaml:",omitempty" json:"forwarder,omitempty"` } type ChainConfig struct { - Name string - Selector *SelectorConfig `yaml:",omitempty"` - Hops []*HopConfig + Name string `json:"name"` + Selector *SelectorConfig `yaml:",omitempty" json:"selector,omitempty"` + Hops []*HopConfig `json:"hops"` } type HopConfig struct { - Name string - Selector *SelectorConfig `yaml:",omitempty"` - Bypass string `yaml:",omitempty"` - Resolver string `yaml:",omitempty"` - Hosts string `yaml:",omitempty"` - Nodes []*NodeConfig + Name string `json:"name"` + Selector *SelectorConfig `yaml:",omitempty" json:"selector,omitempty"` + Bypass string `yaml:",omitempty" json:"bypass,omitempty"` + Resolver string `yaml:",omitempty" json:"resolver,omitempty"` + Hosts string `yaml:",omitempty" json:"hosts,omitempty"` + Nodes []*NodeConfig `json:"nodes"` } type NodeConfig struct { - Name string - Addr string `yaml:",omitempty"` - Bypass string `yaml:",omitempty"` - Resolver string `yaml:",omitempty"` - Hosts string `yaml:",omitempty"` - Connector *ConnectorConfig `yaml:",omitempty"` - Dialer *DialerConfig `yaml:",omitempty"` + Name string `json:"name"` + Addr string `yaml:",omitempty" json:"addr,omitempty"` + Bypass string `yaml:",omitempty" json:"bypass,omitempty"` + Resolver string `yaml:",omitempty" json:"resolver,omitempty"` + Hosts string `yaml:",omitempty" json:"hosts,omitempty"` + Connector *ConnectorConfig `yaml:",omitempty" json:"connector,omitempty"` + Dialer *DialerConfig `yaml:",omitempty" json:"dialer,omitempty"` } type Config struct { - Services []*ServiceConfig - Chains []*ChainConfig `yaml:",omitempty"` - Bypasses []*BypassConfig `yaml:",omitempty"` - Resolvers []*ResolverConfig `yaml:",omitempty"` - Hosts []*HostsConfig `yaml:",omitempty"` - TLS *TLSConfig `yaml:",omitempty"` - Log *LogConfig `yaml:",omitempty"` - Profiling *ProfilingConfig `yaml:",omitempty"` + Services []*ServiceConfig `json:"services"` + Chains []*ChainConfig `yaml:",omitempty" json:"chains,omitempty"` + Bypasses []*BypassConfig `yaml:",omitempty" json:"bypasses,omitempty"` + Resolvers []*ResolverConfig `yaml:",omitempty" json:"resolvers,omitempty"` + Hosts []*HostsConfig `yaml:",omitempty" json:"hosts,omitempty"` + TLS *TLSConfig `yaml:",omitempty" json:"tls,omitempty"` + Log *LogConfig `yaml:",omitempty" json:"log,omitempty"` + Profiling *ProfilingConfig `yaml:",omitempty" json:"profiling,omitempty"` } func (c *Config) Load() error { @@ -188,9 +189,19 @@ func (c *Config) ReadFile(file string) error { return v.Unmarshal(c) } -func (c *Config) Write(w io.Writer) error { - enc := yaml.NewEncoder(w) - defer enc.Close() +func (c *Config) Write(w io.Writer, format string) error { + switch format { + case "json": + enc := json.NewEncoder(w) + enc.SetIndent("", " ") + enc.Encode(c) + return nil + case "yaml": + fallthrough + default: + enc := yaml.NewEncoder(w) + defer enc.Close() - return enc.Encode(c) + return enc.Encode(c) + } } diff --git a/pkg/config/parsing/chain.go b/pkg/config/parsing/chain.go new file mode 100644 index 0000000..051e742 --- /dev/null +++ b/pkg/config/parsing/chain.go @@ -0,0 +1,152 @@ +package parsing + +import ( + "net/url" + + "github.com/go-gost/gost/pkg/chain" + tls_util "github.com/go-gost/gost/pkg/common/util/tls" + "github.com/go-gost/gost/pkg/config" + "github.com/go-gost/gost/pkg/connector" + "github.com/go-gost/gost/pkg/dialer" + "github.com/go-gost/gost/pkg/logger" + "github.com/go-gost/gost/pkg/metadata" + "github.com/go-gost/gost/pkg/registry" +) + +func ParseChain(cfg *config.ChainConfig) (chain.Chainer, error) { + if cfg == nil { + return nil, nil + } + + chainLogger := logger.Default().WithFields(map[string]interface{}{ + "kind": "chain", + "chain": cfg.Name, + }) + + c := &chain.Chain{} + selector := parseSelector(cfg.Selector) + for _, hop := range cfg.Hops { + group := &chain.NodeGroup{} + for _, v := range hop.Nodes { + nodeLogger := chainLogger.WithFields(map[string]interface{}{ + "kind": "node", + "connector": v.Connector.Type, + "dialer": v.Dialer.Type, + "hop": hop.Name, + "node": v.Name, + }) + connectorLogger := nodeLogger.WithFields(map[string]interface{}{ + "kind": "connector", + }) + + var user *url.Userinfo + if auth := v.Connector.Auth; auth != nil && auth.Username != "" { + if auth.Password == "" { + user = url.User(auth.Username) + } else { + user = url.UserPassword(auth.Username, auth.Password) + } + } + + tlsCfg := v.Connector.TLS + if tlsCfg == nil { + tlsCfg = &config.TLSConfig{} + } + tlsConfig, err := tls_util.LoadClientConfig( + tlsCfg.CertFile, tlsCfg.KeyFile, tlsCfg.CAFile, + tlsCfg.Secure, tlsCfg.ServerName) + if err != nil { + chainLogger.Error(err) + return nil, err + } + + cr := registry.GetConnector(v.Connector.Type)( + connector.UserOption(user), + connector.TLSConfigOption(tlsConfig), + connector.LoggerOption(connectorLogger), + ) + + if v.Connector.Metadata == nil { + v.Connector.Metadata = make(map[string]interface{}) + } + if err := cr.Init(metadata.MapMetadata(v.Connector.Metadata)); err != nil { + connectorLogger.Error("init: ", err) + return nil, err + } + + dialerLogger := nodeLogger.WithFields(map[string]interface{}{ + "kind": "dialer", + }) + + user = nil + if auth := v.Dialer.Auth; auth != nil && auth.Username != "" { + if auth.Password == "" { + user = url.User(auth.Username) + } else { + user = url.UserPassword(auth.Username, auth.Password) + } + } + + tlsCfg = v.Dialer.TLS + if tlsCfg == nil { + tlsCfg = &config.TLSConfig{} + } + tlsConfig, err = tls_util.LoadClientConfig( + tlsCfg.CertFile, tlsCfg.KeyFile, tlsCfg.CAFile, + tlsCfg.Secure, tlsCfg.ServerName) + if err != nil { + chainLogger.Error(err) + return nil, err + } + + d := registry.GetDialer(v.Dialer.Type)( + dialer.UserOption(user), + dialer.TLSConfigOption(tlsConfig), + dialer.LoggerOption(dialerLogger), + ) + + if v.Dialer.Metadata == nil { + v.Dialer.Metadata = make(map[string]interface{}) + } + if err := d.Init(metadata.MapMetadata(v.Dialer.Metadata)); err != nil { + dialerLogger.Error("init: ", err) + return nil, err + } + + tr := (&chain.Transport{}). + WithConnector(cr). + WithDialer(d). + WithAddr(v.Addr) + + if v.Bypass == "" { + v.Bypass = hop.Bypass + } + if v.Resolver == "" { + v.Resolver = hop.Resolver + } + if v.Hosts == "" { + v.Hosts = hop.Hosts + } + + node := &chain.Node{ + Name: v.Name, + Addr: v.Addr, + Transport: tr, + Bypass: registry.Bypass().Get(v.Bypass), + Resolver: registry.Resolver().Get(v.Resolver), + Hosts: registry.Hosts().Get(v.Hosts), + Marker: &chain.FailMarker{}, + } + group.AddNode(node) + } + + sel := selector + if s := parseSelector(hop.Selector); s != nil { + sel = s + } + group.WithSelector(sel) + c.AddNodeGroup(group) + } + + return c, nil +} diff --git a/pkg/config/parsing/parse.go b/pkg/config/parsing/parse.go new file mode 100644 index 0000000..b23d505 --- /dev/null +++ b/pkg/config/parsing/parse.go @@ -0,0 +1,103 @@ +package parsing + +import ( + "net" + + "github.com/go-gost/gost/pkg/bypass" + "github.com/go-gost/gost/pkg/chain" + "github.com/go-gost/gost/pkg/config" + hostspkg "github.com/go-gost/gost/pkg/hosts" + "github.com/go-gost/gost/pkg/logger" + "github.com/go-gost/gost/pkg/resolver" + resolver_impl "github.com/go-gost/gost/pkg/resolver/impl" +) + +func parseSelector(cfg *config.SelectorConfig) chain.Selector { + if cfg == nil { + return nil + } + + var strategy chain.Strategy + switch cfg.Strategy { + case "round", "rr": + strategy = chain.RoundRobinStrategy() + case "random", "rand": + strategy = chain.RandomStrategy() + case "fifo", "ha": + strategy = chain.FIFOStrategy() + default: + strategy = chain.RoundRobinStrategy() + } + + return chain.NewSelector( + strategy, + chain.InvalidFilter(), + chain.FailFilter(cfg.MaxFails, cfg.FailTimeout), + ) +} + +func ParseBypass(cfg *config.BypassConfig) bypass.Bypass { + if cfg == nil { + return nil + } + return bypass.NewBypassPatterns( + cfg.Reverse, + cfg.Matchers, + bypass.LoggerBypassOption(logger.Default().WithFields(map[string]interface{}{ + "kind": "bypass", + "bypass": cfg.Name, + })), + ) +} + +func ParseResolver(cfg *config.ResolverConfig) (resolver.Resolver, error) { + if cfg == nil { + return nil, nil + } + var nameservers []resolver_impl.NameServer + for _, server := range cfg.Nameservers { + nameservers = append(nameservers, resolver_impl.NameServer{ + Addr: server.Addr, + // Chain: chains[server.Chain], + TTL: server.TTL, + Timeout: server.Timeout, + ClientIP: net.ParseIP(server.ClientIP), + Prefer: server.Prefer, + Hostname: server.Hostname, + }) + } + + return resolver_impl.NewResolver( + nameservers, + resolver_impl.LoggerResolverOption( + logger.Default().WithFields(map[string]interface{}{ + "kind": "resolver", + "resolver": cfg.Name, + }), + ), + ) +} + +func ParseHosts(cfg *config.HostsConfig) hostspkg.HostMapper { + if cfg == nil || len(cfg.Mappings) == 0 { + return nil + } + hosts := hostspkg.NewHosts() + hosts.Logger = logger.Default().WithFields(map[string]interface{}{ + "kind": "hosts", + "hosts": cfg.Name, + }) + + for _, host := range cfg.Mappings { + if host.IP == "" || host.Hostname == "" { + continue + } + + ip := net.ParseIP(host.IP) + if ip == nil { + continue + } + hosts.Map(ip, host.Hostname, host.Aliases...) + } + return hosts +} diff --git a/pkg/config/parsing/service.go b/pkg/config/parsing/service.go new file mode 100644 index 0000000..5c25e04 --- /dev/null +++ b/pkg/config/parsing/service.go @@ -0,0 +1,143 @@ +package parsing + +import ( + "net/url" + "strings" + + "github.com/go-gost/gost/pkg/chain" + tls_util "github.com/go-gost/gost/pkg/common/util/tls" + "github.com/go-gost/gost/pkg/config" + "github.com/go-gost/gost/pkg/handler" + "github.com/go-gost/gost/pkg/listener" + "github.com/go-gost/gost/pkg/logger" + "github.com/go-gost/gost/pkg/metadata" + "github.com/go-gost/gost/pkg/registry" + "github.com/go-gost/gost/pkg/service" +) + +func ParseService(cfg *config.ServiceConfig) (*service.Service, error) { + if cfg.Listener == nil { + cfg.Listener = &config.ListenerConfig{ + Type: "tcp", + } + } + if cfg.Handler == nil { + cfg.Handler = &config.HandlerConfig{ + Type: "auto", + } + } + serviceLogger := logger.Default().WithFields(map[string]interface{}{ + "kind": "service", + "service": cfg.Name, + "listener": cfg.Listener.Type, + "handler": cfg.Handler.Type, + }) + + listenerLogger := serviceLogger.WithFields(map[string]interface{}{ + "kind": "listener", + }) + + tlsCfg := cfg.Listener.TLS + if tlsCfg == nil { + tlsCfg = &config.TLSConfig{} + } + tlsConfig, err := tls_util.LoadServerConfig( + tlsCfg.CertFile, tlsCfg.KeyFile, tlsCfg.CAFile) + if err != nil { + listenerLogger.Error(err) + return nil, err + } + + ln := registry.GetListener(cfg.Listener.Type)( + listener.AddrOption(cfg.Addr), + listener.ChainOption(registry.Chain().Get(cfg.Listener.Chain)), + listener.AuthsOption(parseAuths(cfg.Listener.Auths...)...), + listener.TLSConfigOption(tlsConfig), + listener.LoggerOption(listenerLogger), + ) + + if cfg.Listener.Metadata == nil { + cfg.Listener.Metadata = make(map[string]interface{}) + } + if err := ln.Init(metadata.MapMetadata(cfg.Listener.Metadata)); err != nil { + listenerLogger.Error("init: ", err) + return nil, err + } + + handlerLogger := serviceLogger.WithFields(map[string]interface{}{ + "kind": "handler", + }) + + tlsCfg = cfg.Handler.TLS + if tlsCfg == nil { + tlsCfg = &config.TLSConfig{} + } + tlsConfig, err = tls_util.LoadServerConfig( + tlsCfg.CertFile, tlsCfg.KeyFile, tlsCfg.CAFile) + if err != nil { + handlerLogger.Error(err) + return nil, err + } + + h := registry.GetHandler(cfg.Handler.Type)( + handler.AuthsOption(parseAuths(cfg.Handler.Auths...)...), + handler.RetriesOption(cfg.Handler.Retries), + handler.ChainOption(registry.Chain().Get(cfg.Handler.Chain)), + handler.BypassOption(registry.Bypass().Get(cfg.Bypass)), + handler.ResolverOption(registry.Resolver().Get(cfg.Resolver)), + handler.HostsOption(registry.Hosts().Get(cfg.Hosts)), + handler.TLSConfigOption(tlsConfig), + handler.LoggerOption(handlerLogger), + ) + + if forwarder, ok := h.(handler.Forwarder); ok { + forwarder.Forward(parseForwarder(cfg.Forwarder)) + } + + if cfg.Handler.Metadata == nil { + cfg.Handler.Metadata = make(map[string]interface{}) + } + if err := h.Init(metadata.MapMetadata(cfg.Handler.Metadata)); err != nil { + handlerLogger.Error("init: ", err) + return nil, err + } + + s := (&service.Service{}). + WithListener(ln). + WithHandler(h). + WithLogger(serviceLogger) + + serviceLogger.Infof("listening on %s/%s", s.Addr().String(), s.Addr().Network()) + return s, nil +} + +func parseAuths(cfgs ...*config.AuthConfig) []*url.Userinfo { + var auths []*url.Userinfo + + for _, cfg := range cfgs { + if cfg == nil || cfg.Username == "" { + continue + } + auths = append(auths, url.UserPassword(cfg.Username, cfg.Password)) + } + + return auths +} + +func parseForwarder(cfg *config.ForwarderConfig) *chain.NodeGroup { + if cfg == nil || len(cfg.Targets) == 0 { + return nil + } + + group := &chain.NodeGroup{} + for _, target := range cfg.Targets { + if v := strings.TrimSpace(target); v != "" { + group.AddNode(&chain.Node{ + Name: target, + Addr: target, + Marker: &chain.FailMarker{}, + }) + } + } + return group.WithSelector(parseSelector(cfg.Selector)) +} diff --git a/pkg/handler/option.go b/pkg/handler/option.go index 5462d4a..2d7a75b 100644 --- a/pkg/handler/option.go +++ b/pkg/handler/option.go @@ -13,7 +13,7 @@ import ( type Options struct { Retries int - Chain *chain.Chain + Chain chain.Chainer Resolver resolver.Resolver Hosts hosts.HostMapper Bypass bypass.Bypass @@ -30,7 +30,7 @@ func RetriesOption(retries int) Option { } } -func ChainOption(chain *chain.Chain) Option { +func ChainOption(chain chain.Chainer) Option { return func(opts *Options) { opts.Chain = chain } diff --git a/pkg/listener/option.go b/pkg/listener/option.go index ce810e4..aa9cd6d 100644 --- a/pkg/listener/option.go +++ b/pkg/listener/option.go @@ -12,7 +12,7 @@ type Options struct { Addr string Auths []*url.Userinfo TLSConfig *tls.Config - Chain *chain.Chain + Chain chain.Chainer Logger logger.Logger } @@ -36,7 +36,7 @@ func TLSConfigOption(tlsConfig *tls.Config) Option { } } -func ChainOption(chain *chain.Chain) Option { +func ChainOption(chain chain.Chainer) Option { return func(opts *Options) { opts.Chain = chain } diff --git a/pkg/registry/bypass.go b/pkg/registry/bypass.go new file mode 100644 index 0000000..b44b221 --- /dev/null +++ b/pkg/registry/bypass.go @@ -0,0 +1,50 @@ +package registry + +import ( + "sync" + + "github.com/go-gost/gost/pkg/bypass" +) + +var ( + bypassReg = &bypassRegistry{} +) + +func Bypass() *bypassRegistry { + return bypassReg +} + +type bypassRegistry struct { + m sync.Map +} + +func (r *bypassRegistry) Register(name string, bypass bypass.Bypass) error { + if _, loaded := r.m.LoadOrStore(name, bypass); loaded { + return ErrDup + } + + return nil +} + +func (r *bypassRegistry) Unregister(name string) { + r.m.Delete(name) +} + +func (r *bypassRegistry) Get(name string) bypass.Bypass { + if _, ok := r.m.Load(name); !ok { + return nil + } + return &bypassWrapper{name: name} +} + +type bypassWrapper struct { + name string +} + +func (w *bypassWrapper) Contains(addr string) bool { + bp := bypassReg.Get(w.name) + if bp == nil { + return false + } + return bp.Contains(addr) +} diff --git a/pkg/registry/chain.go b/pkg/registry/chain.go new file mode 100644 index 0000000..77ad3d8 --- /dev/null +++ b/pkg/registry/chain.go @@ -0,0 +1,50 @@ +package registry + +import ( + "sync" + + "github.com/go-gost/gost/pkg/chain" +) + +var ( + chainReg = &chainRegistry{} +) + +func Chain() *chainRegistry { + return chainReg +} + +type chainRegistry struct { + m sync.Map +} + +func (r *chainRegistry) Register(name string, chain chain.Chainer) error { + if _, loaded := r.m.LoadOrStore(name, chain); loaded { + return ErrDup + } + + return nil +} + +func (r *chainRegistry) Unregister(name string) { + r.m.Delete(name) +} + +func (r *chainRegistry) Get(name string) chain.Chainer { + if _, ok := r.m.Load(name); !ok { + return nil + } + return &chainWrapper{name: name} +} + +type chainWrapper struct { + name string +} + +func (w *chainWrapper) Route(network, address string) *chain.Route { + v := Chain().Get(w.name) + if v == nil { + return nil + } + return v.Route(network, address) +} diff --git a/pkg/registry/hosts.go b/pkg/registry/hosts.go new file mode 100644 index 0000000..b4685f8 --- /dev/null +++ b/pkg/registry/hosts.go @@ -0,0 +1,51 @@ +package registry + +import ( + "net" + "sync" + + "github.com/go-gost/gost/pkg/hosts" +) + +var ( + hostsReg = &hostsRegistry{} +) + +func Hosts() *hostsRegistry { + return hostsReg +} + +type hostsRegistry struct { + m sync.Map +} + +func (r *hostsRegistry) Register(name string, hosts hosts.HostMapper) error { + if _, loaded := r.m.LoadOrStore(name, hosts); loaded { + return ErrDup + } + + return nil +} + +func (r *hostsRegistry) Unregister(name string) { + r.m.Delete(name) +} + +func (r *hostsRegistry) Get(name string) hosts.HostMapper { + if _, ok := r.m.Load(name); !ok { + return nil + } + return &hostsWrapper{name: name} +} + +type hostsWrapper struct { + name string +} + +func (w *hostsWrapper) Lookup(network, host string) ([]net.IP, bool) { + v := Hosts().Get(w.name) + if v == nil { + return nil, false + } + return v.Lookup(network, host) +} diff --git a/pkg/registry/registry.go b/pkg/registry/registry.go index b1cb609..b9aa7b0 100644 --- a/pkg/registry/registry.go +++ b/pkg/registry/registry.go @@ -1,6 +1,8 @@ package registry import ( + "errors" + "github.com/go-gost/gost/pkg/connector" "github.com/go-gost/gost/pkg/dialer" "github.com/go-gost/gost/pkg/handler" @@ -8,6 +10,11 @@ import ( "github.com/go-gost/gost/pkg/logger" ) +var ( + ErrDup = errors.New("registry: duplicate instance") + ErrNotFound = errors.New("registry: instance not found") +) + type NewListener func(opts ...listener.Option) listener.Listener type NewHandler func(opts ...handler.Option) handler.Handler type NewDialer func(opts ...dialer.Option) dialer.Dialer diff --git a/pkg/registry/resolver.go b/pkg/registry/resolver.go new file mode 100644 index 0000000..71af7fe --- /dev/null +++ b/pkg/registry/resolver.go @@ -0,0 +1,52 @@ +package registry + +import ( + "context" + "net" + "sync" + + "github.com/go-gost/gost/pkg/resolver" +) + +var ( + resolverReg = &resolverRegistry{} +) + +func Resolver() *resolverRegistry { + return resolverReg +} + +type resolverRegistry struct { + m sync.Map +} + +func (r *resolverRegistry) Register(name string, resolver resolver.Resolver) error { + if _, loaded := r.m.LoadOrStore(name, resolver); loaded { + return ErrDup + } + + return nil +} + +func (r *resolverRegistry) Unregister(name string) { + r.m.Delete(name) +} + +func (r *resolverRegistry) Get(name string) resolver.Resolver { + if _, ok := r.m.Load(name); !ok { + return nil + } + return &resolverWrapper{name: name} +} + +type resolverWrapper struct { + name string +} + +func (w *resolverWrapper) Resolve(ctx context.Context, network, host string) ([]net.IP, error) { + r := Resolver().Get(w.name) + if r == nil { + return nil, ErrNotFound + } + return r.Resolve(ctx, network, host) +} diff --git a/pkg/registry/service.go b/pkg/registry/service.go new file mode 100644 index 0000000..a0a2e2e --- /dev/null +++ b/pkg/registry/service.go @@ -0,0 +1,39 @@ +package registry + +import ( + "sync" + + "github.com/go-gost/gost/pkg/service" +) + +var ( + svcReg = &serviceRegistry{} +) + +func Service() *serviceRegistry { + return svcReg +} + +type serviceRegistry struct { + m sync.Map +} + +func (r *serviceRegistry) Register(name string, svc *service.Service) error { + if _, loaded := r.m.LoadOrStore(name, svc); loaded { + return ErrDup + } + + return nil +} + +func (r *serviceRegistry) Unregister(name string) { + r.m.Delete(name) +} + +func (r *serviceRegistry) Get(name string) *service.Service { + v, ok := r.m.Load(name) + if !ok { + return nil + } + return v.(*service.Service) +} diff --git a/pkg/resolver/impl/resolver.go b/pkg/resolver/impl/resolver.go index 74ef118..7e8096c 100644 --- a/pkg/resolver/impl/resolver.go +++ b/pkg/resolver/impl/resolver.go @@ -16,7 +16,7 @@ import ( type NameServer struct { Addr string - Chain *chain.Chain + Chain chain.Chainer TTL time.Duration Timeout time.Duration ClientIP net.IP @@ -89,7 +89,7 @@ func NewResolver(nameservers []NameServer, opts ...ResolverOption) (resolverpkg. }, nil } -func (r *resolver) Resolve(ctx context.Context, host string) (ips []net.IP, err error) { +func (r *resolver) Resolve(ctx context.Context, network, host string) (ips []net.IP, err error) { if ip := net.ParseIP(host); ip != nil { return []net.IP{ip}, nil } diff --git a/pkg/resolver/resolver.go b/pkg/resolver/resolver.go index 6b6108f..df2d0c6 100644 --- a/pkg/resolver/resolver.go +++ b/pkg/resolver/resolver.go @@ -7,5 +7,6 @@ import ( type Resolver interface { // Resolve returns a slice of the host's IPv4 and IPv6 addresses. - Resolve(ctx context.Context, host string) ([]net.IP, error) + // The network should be 'ip', 'ip4' or 'ip6', default network is 'ip'. + Resolve(ctx context.Context, network, host string) ([]net.IP, error) }