410 lines
12 KiB
Go
410 lines
12 KiB
Go
package service
|
|
|
|
import (
|
|
"fmt"
|
|
"runtime"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/go-gost/core/admission"
|
|
"github.com/go-gost/core/auth"
|
|
"github.com/go-gost/core/bypass"
|
|
"github.com/go-gost/core/chain"
|
|
"github.com/go-gost/core/handler"
|
|
"github.com/go-gost/core/hop"
|
|
"github.com/go-gost/core/listener"
|
|
"github.com/go-gost/core/logger"
|
|
mdutil "github.com/go-gost/core/metadata/util"
|
|
"github.com/go-gost/core/observer/stats"
|
|
"github.com/go-gost/core/recorder"
|
|
"github.com/go-gost/core/selector"
|
|
"github.com/go-gost/core/service"
|
|
xchain "github.com/go-gost/x/chain"
|
|
"github.com/go-gost/x/config"
|
|
"github.com/go-gost/x/config/parsing"
|
|
admission_parser "github.com/go-gost/x/config/parsing/admission"
|
|
auth_parser "github.com/go-gost/x/config/parsing/auth"
|
|
bypass_parser "github.com/go-gost/x/config/parsing/bypass"
|
|
hop_parser "github.com/go-gost/x/config/parsing/hop"
|
|
logger_parser "github.com/go-gost/x/config/parsing/logger"
|
|
selector_parser "github.com/go-gost/x/config/parsing/selector"
|
|
xnet "github.com/go-gost/x/internal/net"
|
|
tls_util "github.com/go-gost/x/internal/util/tls"
|
|
"github.com/go-gost/x/metadata"
|
|
"github.com/go-gost/x/registry"
|
|
xservice "github.com/go-gost/x/service"
|
|
"github.com/vishvananda/netns"
|
|
)
|
|
|
|
func ParseService(cfg *config.ServiceConfig) (service.Service, error) {
|
|
if cfg.Listener == nil {
|
|
cfg.Listener = &config.ListenerConfig{}
|
|
}
|
|
if strings.TrimSpace(cfg.Listener.Type) == "" {
|
|
cfg.Listener.Type = "tcp"
|
|
}
|
|
|
|
if cfg.Handler == nil {
|
|
cfg.Handler = &config.HandlerConfig{}
|
|
}
|
|
if strings.TrimSpace(cfg.Handler.Type) == "" {
|
|
cfg.Handler.Type = "auto"
|
|
}
|
|
|
|
log := logger.Default()
|
|
if loggers := logger_parser.List(cfg.Logger, cfg.Loggers...); len(loggers) > 0 {
|
|
log = logger.LoggerGroup(loggers...)
|
|
}
|
|
|
|
serviceLogger := log.WithFields(map[string]any{
|
|
"kind": "service",
|
|
"service": cfg.Name,
|
|
"listener": cfg.Listener.Type,
|
|
"handler": cfg.Handler.Type,
|
|
})
|
|
|
|
tlsCfg := cfg.Listener.TLS
|
|
if tlsCfg == nil {
|
|
tlsCfg = &config.TLSConfig{}
|
|
}
|
|
tlsConfig, err := tls_util.LoadServerConfig(tlsCfg)
|
|
if err != nil {
|
|
serviceLogger.Error(err)
|
|
return nil, err
|
|
}
|
|
if tlsConfig == nil {
|
|
tlsConfig = parsing.DefaultTLSConfig().Clone()
|
|
}
|
|
|
|
authers := auth_parser.List(cfg.Listener.Auther, cfg.Listener.Authers...)
|
|
if len(authers) == 0 {
|
|
if auther := auth_parser.ParseAutherFromAuth(cfg.Listener.Auth); auther != nil {
|
|
authers = append(authers, auther)
|
|
}
|
|
}
|
|
var auther auth.Authenticator
|
|
if len(authers) > 0 {
|
|
auther = auth.AuthenticatorGroup(authers...)
|
|
}
|
|
|
|
admissions := admission_parser.List(cfg.Admission, cfg.Admissions...)
|
|
|
|
var sockOpts *chain.SockOpts
|
|
if cfg.SockOpts != nil {
|
|
sockOpts = &chain.SockOpts{
|
|
Mark: cfg.SockOpts.Mark,
|
|
}
|
|
}
|
|
|
|
var ppv int
|
|
ifce := cfg.Interface
|
|
var preUp, preDown, postUp, postDown []string
|
|
var ignoreChain bool
|
|
var pStats *stats.Stats
|
|
var observePeriod time.Duration
|
|
var netnsIn, netnsOut string
|
|
var dialTimeout time.Duration
|
|
if cfg.Metadata != nil {
|
|
md := metadata.NewMetadata(cfg.Metadata)
|
|
ppv = mdutil.GetInt(md, parsing.MDKeyProxyProtocol)
|
|
if v := mdutil.GetString(md, parsing.MDKeyInterface); v != "" {
|
|
ifce = v
|
|
}
|
|
if v := mdutil.GetInt(md, parsing.MDKeySoMark); v > 0 {
|
|
sockOpts = &chain.SockOpts{
|
|
Mark: v,
|
|
}
|
|
}
|
|
preUp = mdutil.GetStrings(md, parsing.MDKeyPreUp)
|
|
preDown = mdutil.GetStrings(md, parsing.MDKeyPreDown)
|
|
postUp = mdutil.GetStrings(md, parsing.MDKeyPostUp)
|
|
postDown = mdutil.GetStrings(md, parsing.MDKeyPostDown)
|
|
ignoreChain = mdutil.GetBool(md, parsing.MDKeyIgnoreChain)
|
|
|
|
if mdutil.GetBool(md, parsing.MDKeyEnableStats) {
|
|
pStats = &stats.Stats{}
|
|
}
|
|
observePeriod = mdutil.GetDuration(md, "observePeriod")
|
|
netnsIn = mdutil.GetString(md, "netns")
|
|
netnsOut = mdutil.GetString(md, "netns.out")
|
|
dialTimeout = mdutil.GetDuration(md, "dialTimeout")
|
|
}
|
|
|
|
listenerLogger := serviceLogger.WithFields(map[string]any{
|
|
"kind": "listener",
|
|
})
|
|
|
|
routerOpts := []chain.RouterOption{
|
|
chain.TimeoutRouterOption(dialTimeout),
|
|
chain.InterfaceRouterOption(ifce),
|
|
chain.NetnsRouterOption(netnsOut),
|
|
chain.SockOptsRouterOption(sockOpts),
|
|
chain.ResolverRouterOption(registry.ResolverRegistry().Get(cfg.Resolver)),
|
|
chain.HostMapperRouterOption(registry.HostsRegistry().Get(cfg.Hosts)),
|
|
chain.LoggerRouterOption(listenerLogger),
|
|
}
|
|
if !ignoreChain {
|
|
routerOpts = append(routerOpts,
|
|
chain.ChainRouterOption(chainGroup(cfg.Listener.Chain, cfg.Listener.ChainGroup)),
|
|
)
|
|
}
|
|
|
|
listenOpts := []listener.Option{
|
|
listener.AddrOption(cfg.Addr),
|
|
listener.RouterOption(chain.NewRouter(routerOpts...)),
|
|
listener.AutherOption(auther),
|
|
listener.AuthOption(auth_parser.Info(cfg.Listener.Auth)),
|
|
listener.TLSConfigOption(tlsConfig),
|
|
listener.AdmissionOption(admission.AdmissionGroup(admissions...)),
|
|
listener.TrafficLimiterOption(registry.TrafficLimiterRegistry().Get(cfg.Limiter)),
|
|
listener.ConnLimiterOption(registry.ConnLimiterRegistry().Get(cfg.CLimiter)),
|
|
listener.ServiceOption(cfg.Name),
|
|
listener.ProxyProtocolOption(ppv),
|
|
listener.StatsOption(pStats),
|
|
listener.NetnsOption(netnsIn),
|
|
listener.LoggerOption(listenerLogger),
|
|
}
|
|
|
|
if netnsIn != "" {
|
|
runtime.LockOSThread()
|
|
defer runtime.UnlockOSThread()
|
|
|
|
originNs, err := netns.Get()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("netns.Get(): %v", err)
|
|
}
|
|
defer netns.Set(originNs)
|
|
|
|
ns, err := netns.GetFromName(netnsIn)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("netns.GetFromName(%s): %v", netnsIn, err)
|
|
}
|
|
defer ns.Close()
|
|
|
|
if err := netns.Set(ns); err != nil {
|
|
return nil, fmt.Errorf("netns.Set(%s): %v", netnsIn, err)
|
|
}
|
|
}
|
|
|
|
var ln listener.Listener
|
|
if rf := registry.ListenerRegistry().Get(cfg.Listener.Type); rf != nil {
|
|
ln = rf(listenOpts...)
|
|
} else {
|
|
return nil, fmt.Errorf("unknown listener: %s", cfg.Listener.Type)
|
|
}
|
|
|
|
if cfg.Listener.Metadata == nil {
|
|
cfg.Listener.Metadata = make(map[string]any)
|
|
}
|
|
listenerLogger.Debugf("metadata: %v", cfg.Listener.Metadata)
|
|
if err := ln.Init(metadata.NewMetadata(cfg.Listener.Metadata)); err != nil {
|
|
listenerLogger.Error("init: ", err)
|
|
return nil, err
|
|
}
|
|
|
|
handlerLogger := serviceLogger.WithFields(map[string]any{
|
|
"kind": "handler",
|
|
})
|
|
|
|
tlsCfg = cfg.Handler.TLS
|
|
if tlsCfg == nil {
|
|
tlsCfg = &config.TLSConfig{}
|
|
}
|
|
tlsConfig, err = tls_util.LoadServerConfig(tlsCfg)
|
|
if err != nil {
|
|
handlerLogger.Error(err)
|
|
return nil, err
|
|
}
|
|
if tlsConfig == nil {
|
|
tlsConfig = parsing.DefaultTLSConfig().Clone()
|
|
}
|
|
|
|
authers = auth_parser.List(cfg.Handler.Auther, cfg.Handler.Authers...)
|
|
if len(authers) == 0 {
|
|
if auther := auth_parser.ParseAutherFromAuth(cfg.Handler.Auth); auther != nil {
|
|
authers = append(authers, auther)
|
|
}
|
|
}
|
|
|
|
auther = nil
|
|
if len(authers) > 0 {
|
|
auther = auth.AuthenticatorGroup(authers...)
|
|
}
|
|
|
|
var recorders []recorder.RecorderObject
|
|
for _, r := range cfg.Recorders {
|
|
md := metadata.NewMetadata(r.Metadata)
|
|
recorders = append(recorders, recorder.RecorderObject{
|
|
Recorder: registry.RecorderRegistry().Get(r.Name),
|
|
Record: r.Record,
|
|
Options: &recorder.Options{
|
|
Direction: mdutil.GetBool(md, parsing.MDKeyRecorderDirection),
|
|
TimestampFormat: mdutil.GetString(md, parsing.MDKeyRecorderTimestampFormat),
|
|
Hexdump: mdutil.GetBool(md, parsing.MDKeyRecorderHexdump),
|
|
},
|
|
})
|
|
}
|
|
|
|
routerOpts = []chain.RouterOption{
|
|
chain.RetriesRouterOption(cfg.Handler.Retries),
|
|
chain.TimeoutRouterOption(dialTimeout),
|
|
chain.InterfaceRouterOption(ifce),
|
|
chain.NetnsRouterOption(netnsOut),
|
|
chain.SockOptsRouterOption(sockOpts),
|
|
chain.ResolverRouterOption(registry.ResolverRegistry().Get(cfg.Resolver)),
|
|
chain.HostMapperRouterOption(registry.HostsRegistry().Get(cfg.Hosts)),
|
|
chain.RecordersRouterOption(recorders...),
|
|
chain.LoggerRouterOption(handlerLogger),
|
|
}
|
|
if !ignoreChain {
|
|
routerOpts = append(routerOpts,
|
|
chain.ChainRouterOption(chainGroup(cfg.Handler.Chain, cfg.Handler.ChainGroup)),
|
|
)
|
|
}
|
|
|
|
var h handler.Handler
|
|
if rf := registry.HandlerRegistry().Get(cfg.Handler.Type); rf != nil {
|
|
h = rf(
|
|
handler.RouterOption(chain.NewRouter(routerOpts...)),
|
|
handler.AutherOption(auther),
|
|
handler.AuthOption(auth_parser.Info(cfg.Handler.Auth)),
|
|
handler.BypassOption(bypass.BypassGroup(bypass_parser.List(cfg.Bypass, cfg.Bypasses...)...)),
|
|
handler.TLSConfigOption(tlsConfig),
|
|
handler.RateLimiterOption(registry.RateLimiterRegistry().Get(cfg.RLimiter)),
|
|
handler.TrafficLimiterOption(registry.TrafficLimiterRegistry().Get(cfg.Handler.Limiter)),
|
|
handler.ObserverOption(registry.ObserverRegistry().Get(cfg.Handler.Observer)),
|
|
handler.LoggerOption(handlerLogger),
|
|
handler.ServiceOption(cfg.Name),
|
|
handler.NetnsOption(netnsIn),
|
|
)
|
|
} else {
|
|
return nil, fmt.Errorf("unknown handler: %s", cfg.Handler.Type)
|
|
}
|
|
|
|
if forwarder, ok := h.(handler.Forwarder); ok {
|
|
hop, err := parseForwarder(cfg.Forwarder, log)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
forwarder.Forward(hop)
|
|
}
|
|
|
|
if cfg.Handler.Metadata == nil {
|
|
cfg.Handler.Metadata = make(map[string]any)
|
|
}
|
|
handlerLogger.Debugf("metadata: %v", cfg.Handler.Metadata)
|
|
if err := h.Init(metadata.NewMetadata(cfg.Handler.Metadata)); err != nil {
|
|
handlerLogger.Error("init: ", err)
|
|
return nil, err
|
|
}
|
|
|
|
s := xservice.NewService(cfg.Name, ln, h,
|
|
xservice.AdmissionOption(admission.AdmissionGroup(admissions...)),
|
|
xservice.PreUpOption(preUp),
|
|
xservice.PreDownOption(preDown),
|
|
xservice.PostUpOption(postUp),
|
|
xservice.PostDownOption(postDown),
|
|
xservice.RecordersOption(recorders...),
|
|
xservice.StatsOption(pStats),
|
|
xservice.ObserverOption(registry.ObserverRegistry().Get(cfg.Observer)),
|
|
xservice.ObservePeriodOption(observePeriod),
|
|
xservice.LoggerOption(serviceLogger),
|
|
)
|
|
|
|
serviceLogger.Infof("listening on %s/%s", s.Addr().String(), s.Addr().Network())
|
|
return s, nil
|
|
}
|
|
|
|
func parseForwarder(cfg *config.ForwarderConfig, log logger.Logger) (hop.Hop, error) {
|
|
if cfg == nil {
|
|
return nil, nil
|
|
}
|
|
|
|
hopName := cfg.Hop
|
|
if hopName == "" {
|
|
hopName = cfg.Name
|
|
}
|
|
if hopName != "" {
|
|
return registry.HopRegistry().Get(hopName), nil
|
|
}
|
|
|
|
hc := config.HopConfig{
|
|
Name: cfg.Name,
|
|
Selector: cfg.Selector,
|
|
}
|
|
for _, node := range cfg.Nodes {
|
|
if node != nil {
|
|
addrs := xnet.AddrPortRange(node.Addr).Addrs()
|
|
if len(addrs) == 0 {
|
|
addrs = append(addrs, node.Addr)
|
|
}
|
|
for i, addr := range addrs {
|
|
name := node.Name
|
|
if i > 0 {
|
|
name = fmt.Sprintf("%s-%d", node.Name, i)
|
|
}
|
|
|
|
filter := node.Filter
|
|
if filter == nil {
|
|
if node.Protocol != "" || node.Host != "" || node.Path != "" {
|
|
filter = &config.NodeFilterConfig{
|
|
Protocol: node.Protocol,
|
|
Host: node.Host,
|
|
Path: node.Path,
|
|
}
|
|
}
|
|
}
|
|
|
|
httpCfg := node.HTTP
|
|
if node.Auth != nil {
|
|
if httpCfg == nil {
|
|
httpCfg = &config.HTTPNodeConfig{}
|
|
}
|
|
if httpCfg.Auth == nil {
|
|
httpCfg.Auth = node.Auth
|
|
}
|
|
}
|
|
hc.Nodes = append(hc.Nodes, &config.NodeConfig{
|
|
Name: name,
|
|
Addr: addr,
|
|
Network: node.Network,
|
|
Bypass: node.Bypass,
|
|
Bypasses: node.Bypasses,
|
|
Filter: filter,
|
|
HTTP: httpCfg,
|
|
TLS: node.TLS,
|
|
Metadata: node.Metadata,
|
|
})
|
|
}
|
|
}
|
|
}
|
|
return hop_parser.ParseHop(&hc, log)
|
|
}
|
|
|
|
func chainGroup(name string, group *config.ChainGroupConfig) chain.Chainer {
|
|
var chains []chain.Chainer
|
|
var sel selector.Selector[chain.Chainer]
|
|
|
|
if c := registry.ChainRegistry().Get(name); c != nil {
|
|
chains = append(chains, c)
|
|
}
|
|
if group != nil {
|
|
for _, s := range group.Chains {
|
|
if c := registry.ChainRegistry().Get(s); c != nil {
|
|
chains = append(chains, c)
|
|
}
|
|
}
|
|
sel = selector_parser.ParseChainSelector(group.Selector)
|
|
}
|
|
if len(chains) == 0 {
|
|
return nil
|
|
}
|
|
|
|
if sel == nil {
|
|
sel = selector_parser.DefaultChainSelector()
|
|
}
|
|
|
|
return xchain.NewChainGroup(chains...).
|
|
WithSelector(sel)
|
|
}
|