update config parsing
This commit is contained in:
@ -1,308 +1,76 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"io"
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/go-gost/gost/pkg/bypass"
|
||||
"github.com/go-gost/gost/pkg/chain"
|
||||
"github.com/go-gost/gost/pkg/config"
|
||||
"github.com/go-gost/gost/pkg/connector"
|
||||
"github.com/go-gost/gost/pkg/dialer"
|
||||
"github.com/go-gost/gost/pkg/handler"
|
||||
hostspkg "github.com/go-gost/gost/pkg/hosts"
|
||||
"github.com/go-gost/gost/pkg/listener"
|
||||
"github.com/go-gost/gost/pkg/config/parsing"
|
||||
"github.com/go-gost/gost/pkg/logger"
|
||||
"github.com/go-gost/gost/pkg/metadata"
|
||||
"github.com/go-gost/gost/pkg/registry"
|
||||
"github.com/go-gost/gost/pkg/resolver"
|
||||
resolver_impl "github.com/go-gost/gost/pkg/resolver/impl"
|
||||
"github.com/go-gost/gost/pkg/service"
|
||||
)
|
||||
|
||||
var (
|
||||
chains = make(map[string]*chain.Chain)
|
||||
bypasses = make(map[string]bypass.Bypass)
|
||||
resolvers = make(map[string]resolver.Resolver)
|
||||
hosts = make(map[string]hostspkg.HostMapper)
|
||||
)
|
||||
|
||||
func buildService(cfg *config.Config) (services []*service.Service) {
|
||||
if cfg == nil || len(cfg.Services) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
for _, bypassCfg := range cfg.Bypasses {
|
||||
bypasses[bypassCfg.Name] = bypassFromConfig(bypassCfg)
|
||||
if bp := parsing.ParseBypass(bypassCfg); bp != nil {
|
||||
if err := registry.Bypass().Register(bypassCfg.Name, bp); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, resolverCfg := range cfg.Resolvers {
|
||||
r, err := resolverFromConfig(resolverCfg)
|
||||
r, err := parsing.ParseResolver(resolverCfg)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
resolvers[resolverCfg.Name] = r
|
||||
if r != nil {
|
||||
if err := registry.Resolver().Register(resolverCfg.Name, r); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, hostsCfg := range cfg.Hosts {
|
||||
hosts[hostsCfg.Name] = hostsFromConfig(hostsCfg)
|
||||
if h := parsing.ParseHosts(hostsCfg); h != nil {
|
||||
if err := registry.Hosts().Register(hostsCfg.Name, h); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, chainCfg := range cfg.Chains {
|
||||
chains[chainCfg.Name] = chainFromConfig(chainCfg)
|
||||
c, err := parsing.ParseChain(chainCfg)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
if c != nil {
|
||||
if err := registry.Chain().Register(chainCfg.Name, c); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, svc := range cfg.Services {
|
||||
if svc.Listener == nil {
|
||||
svc.Listener = &config.ListenerConfig{
|
||||
Type: "tcp",
|
||||
}
|
||||
}
|
||||
if svc.Handler == nil {
|
||||
svc.Handler = &config.HandlerConfig{
|
||||
Type: "auto",
|
||||
}
|
||||
}
|
||||
serviceLogger := log.WithFields(map[string]interface{}{
|
||||
"kind": "service",
|
||||
"service": svc.Name,
|
||||
"listener": svc.Listener.Type,
|
||||
"handler": svc.Handler.Type,
|
||||
})
|
||||
|
||||
listenerLogger := serviceLogger.WithFields(map[string]interface{}{
|
||||
"kind": "listener",
|
||||
})
|
||||
|
||||
var tlsConfig *tls.Config
|
||||
var err error
|
||||
|
||||
tlsCfg := svc.Listener.TLS
|
||||
if tlsCfg == nil {
|
||||
tlsCfg = &config.TLSConfig{}
|
||||
}
|
||||
tlsConfig, err = loadServerTLSConfig(tlsCfg)
|
||||
for _, svcCfg := range cfg.Services {
|
||||
svc, err := parsing.ParseService(svcCfg)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
ln := registry.GetListener(svc.Listener.Type)(
|
||||
listener.AddrOption(svc.Addr),
|
||||
listener.ChainOption(chains[svc.Listener.Chain]),
|
||||
listener.AuthsOption(authsFromConfig(svc.Listener.Auths...)...),
|
||||
listener.TLSConfigOption(tlsConfig),
|
||||
listener.LoggerOption(listenerLogger),
|
||||
)
|
||||
|
||||
if svc.Listener.Metadata == nil {
|
||||
svc.Listener.Metadata = make(map[string]interface{})
|
||||
if svc != nil {
|
||||
if err := registry.Service().Register(svcCfg.Name, svc); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
if err := ln.Init(metadata.MapMetadata(svc.Listener.Metadata)); err != nil {
|
||||
listenerLogger.Fatal("init: ", err)
|
||||
}
|
||||
|
||||
handlerLogger := serviceLogger.WithFields(map[string]interface{}{
|
||||
"kind": "handler",
|
||||
})
|
||||
|
||||
tlsConfig = nil
|
||||
tlsCfg = svc.Handler.TLS
|
||||
if tlsCfg == nil {
|
||||
tlsCfg = &config.TLSConfig{}
|
||||
}
|
||||
tlsConfig, err = loadServerTLSConfig(tlsCfg)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
h := registry.GetHandler(svc.Handler.Type)(
|
||||
handler.AuthsOption(authsFromConfig(svc.Handler.Auths...)...),
|
||||
handler.RetriesOption(svc.Handler.Retries),
|
||||
handler.ChainOption(chains[svc.Handler.Chain]),
|
||||
handler.BypassOption(bypasses[svc.Bypass]),
|
||||
handler.ResolverOption(resolvers[svc.Resolver]),
|
||||
handler.HostsOption(hosts[svc.Hosts]),
|
||||
handler.TLSConfigOption(tlsConfig),
|
||||
handler.LoggerOption(handlerLogger),
|
||||
)
|
||||
|
||||
if forwarder, ok := h.(handler.Forwarder); ok {
|
||||
forwarder.Forward(forwarderFromConfig(svc.Forwarder))
|
||||
}
|
||||
|
||||
if svc.Handler.Metadata == nil {
|
||||
svc.Handler.Metadata = make(map[string]interface{})
|
||||
}
|
||||
if err := h.Init(metadata.MapMetadata(svc.Handler.Metadata)); err != nil {
|
||||
handlerLogger.Fatal("init: ", err)
|
||||
}
|
||||
|
||||
s := (&service.Service{}).
|
||||
WithListener(ln).
|
||||
WithHandler(h).
|
||||
WithLogger(serviceLogger)
|
||||
services = append(services, s)
|
||||
|
||||
serviceLogger.Infof("listening on %s/%s", s.Addr().String(), s.Addr().Network())
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func chainFromConfig(cfg *config.ChainConfig) *chain.Chain {
|
||||
if cfg == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
chainLogger := log.WithFields(map[string]interface{}{
|
||||
"kind": "chain",
|
||||
"chain": cfg.Name,
|
||||
})
|
||||
|
||||
c := &chain.Chain{}
|
||||
selector := selectorFromConfig(cfg.Selector)
|
||||
for _, hop := range cfg.Hops {
|
||||
group := &chain.NodeGroup{}
|
||||
for _, v := range hop.Nodes {
|
||||
nodeLogger := chainLogger.WithFields(map[string]interface{}{
|
||||
"kind": "node",
|
||||
"connector": v.Connector.Type,
|
||||
"dialer": v.Dialer.Type,
|
||||
"hop": hop.Name,
|
||||
"node": v.Name,
|
||||
})
|
||||
connectorLogger := nodeLogger.WithFields(map[string]interface{}{
|
||||
"kind": "connector",
|
||||
})
|
||||
|
||||
var user *url.Userinfo
|
||||
if auth := v.Connector.Auth; auth != nil && auth.Username != "" {
|
||||
if auth.Password == "" {
|
||||
user = url.User(auth.Username)
|
||||
} else {
|
||||
user = url.UserPassword(auth.Username, auth.Password)
|
||||
}
|
||||
}
|
||||
|
||||
var tlsConfig *tls.Config
|
||||
var err error
|
||||
tlsCfg := v.Connector.TLS
|
||||
if tlsCfg == nil {
|
||||
tlsCfg = &config.TLSConfig{}
|
||||
}
|
||||
tlsConfig, err = loadClientTLSConfig(tlsCfg)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
cr := registry.GetConnector(v.Connector.Type)(
|
||||
connector.UserOption(user),
|
||||
connector.TLSConfigOption(tlsConfig),
|
||||
connector.LoggerOption(connectorLogger),
|
||||
)
|
||||
|
||||
if v.Connector.Metadata == nil {
|
||||
v.Connector.Metadata = make(map[string]interface{})
|
||||
}
|
||||
if err := cr.Init(metadata.MapMetadata(v.Connector.Metadata)); err != nil {
|
||||
connectorLogger.Fatal("init: ", err)
|
||||
}
|
||||
|
||||
dialerLogger := nodeLogger.WithFields(map[string]interface{}{
|
||||
"kind": "dialer",
|
||||
})
|
||||
|
||||
user = nil
|
||||
if auth := v.Dialer.Auth; auth != nil && auth.Username != "" {
|
||||
if auth.Password == "" {
|
||||
user = url.User(auth.Username)
|
||||
} else {
|
||||
user = url.UserPassword(auth.Username, auth.Password)
|
||||
}
|
||||
}
|
||||
|
||||
tlsConfig = nil
|
||||
tlsCfg = v.Dialer.TLS
|
||||
if tlsCfg == nil {
|
||||
tlsCfg = &config.TLSConfig{}
|
||||
}
|
||||
tlsConfig, err = loadClientTLSConfig(tlsCfg)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
d := registry.GetDialer(v.Dialer.Type)(
|
||||
dialer.UserOption(user),
|
||||
dialer.TLSConfigOption(tlsConfig),
|
||||
dialer.LoggerOption(dialerLogger),
|
||||
)
|
||||
|
||||
if v.Dialer.Metadata == nil {
|
||||
v.Dialer.Metadata = make(map[string]interface{})
|
||||
}
|
||||
if err := d.Init(metadata.MapMetadata(v.Dialer.Metadata)); err != nil {
|
||||
dialerLogger.Fatal("init: ", err)
|
||||
}
|
||||
|
||||
tr := (&chain.Transport{}).
|
||||
WithConnector(cr).
|
||||
WithDialer(d).
|
||||
WithAddr(v.Addr)
|
||||
|
||||
if v.Bypass == "" {
|
||||
v.Bypass = hop.Bypass
|
||||
}
|
||||
if v.Resolver == "" {
|
||||
v.Resolver = hop.Resolver
|
||||
}
|
||||
if v.Hosts == "" {
|
||||
v.Hosts = hop.Hosts
|
||||
}
|
||||
|
||||
node := &chain.Node{
|
||||
Name: v.Name,
|
||||
Addr: v.Addr,
|
||||
Transport: tr,
|
||||
Bypass: bypasses[v.Bypass],
|
||||
Resolver: resolvers[v.Resolver],
|
||||
Hosts: hosts[v.Hosts],
|
||||
Marker: &chain.FailMarker{},
|
||||
}
|
||||
group.AddNode(node)
|
||||
}
|
||||
|
||||
sel := selector
|
||||
if s := selectorFromConfig(hop.Selector); s != nil {
|
||||
sel = s
|
||||
}
|
||||
group.WithSelector(sel)
|
||||
c.AddNodeGroup(group)
|
||||
}
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
func forwarderFromConfig(cfg *config.ForwarderConfig) *chain.NodeGroup {
|
||||
if cfg == nil || len(cfg.Targets) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
group := &chain.NodeGroup{}
|
||||
for _, target := range cfg.Targets {
|
||||
if v := strings.TrimSpace(target); v != "" {
|
||||
group.AddNode(&chain.Node{
|
||||
Name: target,
|
||||
Addr: target,
|
||||
Marker: &chain.FailMarker{},
|
||||
})
|
||||
}
|
||||
}
|
||||
return group.WithSelector(selectorFromConfig(cfg.Selector))
|
||||
}
|
||||
|
||||
func logFromConfig(cfg *config.LogConfig) logger.Logger {
|
||||
if cfg == nil {
|
||||
cfg = &config.LogConfig{}
|
||||
@ -314,7 +82,7 @@ func logFromConfig(cfg *config.LogConfig) logger.Logger {
|
||||
|
||||
var out io.Writer = os.Stderr
|
||||
switch cfg.Output {
|
||||
case "none":
|
||||
case "none", "null":
|
||||
return logger.Nop()
|
||||
case "stdout":
|
||||
out = os.Stdout
|
||||
@ -332,105 +100,3 @@ func logFromConfig(cfg *config.LogConfig) logger.Logger {
|
||||
|
||||
return logger.NewLogger(opts...)
|
||||
}
|
||||
|
||||
func selectorFromConfig(cfg *config.SelectorConfig) chain.Selector {
|
||||
if cfg == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var strategy chain.Strategy
|
||||
switch cfg.Strategy {
|
||||
case "round", "rr":
|
||||
strategy = chain.RoundRobinStrategy()
|
||||
case "random", "rand":
|
||||
strategy = chain.RandomStrategy()
|
||||
case "fifo", "ha":
|
||||
strategy = chain.FIFOStrategy()
|
||||
default:
|
||||
strategy = chain.RoundRobinStrategy()
|
||||
}
|
||||
|
||||
return chain.NewSelector(
|
||||
strategy,
|
||||
chain.InvalidFilter(),
|
||||
chain.FailFilter(cfg.MaxFails, cfg.FailTimeout),
|
||||
)
|
||||
}
|
||||
|
||||
func bypassFromConfig(cfg *config.BypassConfig) bypass.Bypass {
|
||||
if cfg == nil {
|
||||
return nil
|
||||
}
|
||||
return bypass.NewBypassPatterns(
|
||||
cfg.Reverse,
|
||||
cfg.Matchers,
|
||||
bypass.LoggerBypassOption(log.WithFields(map[string]interface{}{
|
||||
"kind": "bypass",
|
||||
"bypass": cfg.Name,
|
||||
})),
|
||||
)
|
||||
}
|
||||
|
||||
func resolverFromConfig(cfg *config.ResolverConfig) (resolver.Resolver, error) {
|
||||
if cfg == nil {
|
||||
return nil, nil
|
||||
}
|
||||
var nameservers []resolver_impl.NameServer
|
||||
for _, server := range cfg.Nameservers {
|
||||
nameservers = append(nameservers, resolver_impl.NameServer{
|
||||
Addr: server.Addr,
|
||||
Chain: chains[server.Chain],
|
||||
TTL: server.TTL,
|
||||
Timeout: server.Timeout,
|
||||
ClientIP: net.ParseIP(server.ClientIP),
|
||||
Prefer: server.Prefer,
|
||||
Hostname: server.Hostname,
|
||||
})
|
||||
}
|
||||
|
||||
logger := log.WithFields(map[string]interface{}{
|
||||
"kind": "resolver",
|
||||
"resolver": cfg.Name,
|
||||
})
|
||||
return resolver_impl.NewResolver(
|
||||
nameservers,
|
||||
resolver_impl.LoggerResolverOption(logger),
|
||||
)
|
||||
}
|
||||
|
||||
func hostsFromConfig(cfg *config.HostsConfig) hostspkg.HostMapper {
|
||||
if cfg == nil || len(cfg.Mappings) == 0 {
|
||||
return nil
|
||||
}
|
||||
hosts := hostspkg.NewHosts()
|
||||
hosts.Logger = log.WithFields(map[string]interface{}{
|
||||
"kind": "hosts",
|
||||
"hosts": cfg.Name,
|
||||
})
|
||||
|
||||
for _, host := range cfg.Mappings {
|
||||
if host.IP == "" || host.Hostname == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
ip := net.ParseIP(host.IP)
|
||||
if ip == nil {
|
||||
continue
|
||||
}
|
||||
hosts.Map(ip, host.Hostname, host.Aliases...)
|
||||
}
|
||||
return hosts
|
||||
}
|
||||
|
||||
func authsFromConfig(cfgs ...*config.AuthConfig) []*url.Userinfo {
|
||||
var auths []*url.Userinfo
|
||||
|
||||
for _, cfg := range cfgs {
|
||||
if cfg == nil || cfg.Username == "" {
|
||||
continue
|
||||
}
|
||||
auths = append(auths, url.UserPassword(cfg.Username, cfg.Password))
|
||||
}
|
||||
|
||||
return auths
|
||||
}
|
||||
|
@ -3,7 +3,6 @@ package main
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
_ "net/http/pprof"
|
||||
"os"
|
||||
@ -16,11 +15,11 @@ import (
|
||||
var (
|
||||
log = logger.Default()
|
||||
|
||||
cfgFile string
|
||||
outputCfgFile string
|
||||
services stringList
|
||||
nodes stringList
|
||||
debug bool
|
||||
cfgFile string
|
||||
outputFormat string
|
||||
services stringList
|
||||
nodes stringList
|
||||
debug bool
|
||||
)
|
||||
|
||||
func init() {
|
||||
@ -31,7 +30,7 @@ func init() {
|
||||
flag.StringVar(&cfgFile, "C", "", "configure file")
|
||||
flag.BoolVar(&printVersion, "V", false, "print version")
|
||||
flag.BoolVar(&debug, "D", false, "debug mode")
|
||||
flag.StringVar(&outputCfgFile, "O", "", "write config to FILE")
|
||||
flag.StringVar(&outputFormat, "O", "", "output format, one of yaml|json format")
|
||||
flag.Parse()
|
||||
|
||||
if printVersion {
|
||||
@ -65,19 +64,8 @@ func main() {
|
||||
|
||||
log = logFromConfig(cfg.Log)
|
||||
|
||||
if outputCfgFile != "" {
|
||||
var w io.Writer
|
||||
if outputCfgFile == "-" {
|
||||
w = os.Stdout
|
||||
} else {
|
||||
f, err := os.Create(outputCfgFile)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer f.Close()
|
||||
w = f
|
||||
}
|
||||
if err := cfg.Write(w); err != nil {
|
||||
if outputFormat != "" {
|
||||
if err := cfg.Write(os.Stdout, outputFormat); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
os.Exit(0)
|
||||
|
@ -14,14 +14,6 @@ import (
|
||||
"github.com/go-gost/gost/pkg/config"
|
||||
)
|
||||
|
||||
func loadServerTLSConfig(cfg *config.TLSConfig) (*tls.Config, error) {
|
||||
return tls_util.LoadServerConfig(cfg.CertFile, cfg.KeyFile, cfg.CAFile)
|
||||
}
|
||||
|
||||
func loadClientTLSConfig(cfg *config.TLSConfig) (*tls.Config, error) {
|
||||
return tls_util.LoadClientConfig(cfg.CertFile, cfg.KeyFile, cfg.CAFile, cfg.Secure, cfg.ServerName)
|
||||
}
|
||||
|
||||
func buildDefaultTLSConfig(cfg *config.TLSConfig) {
|
||||
if cfg == nil {
|
||||
cfg = &config.TLSConfig{
|
||||
|
Reference in New Issue
Block a user