x/config/parsing/service/parse.go
2023-10-09 21:27:35 +08:00

311 lines
8.8 KiB
Go

package service
import (
"fmt"
"github.com/go-gost/core/admission"
"github.com/go-gost/core/auth"
"github.com/go-gost/core/bypass"
"github.com/go-gost/core/chain"
"github.com/go-gost/core/handler"
"github.com/go-gost/core/hop"
"github.com/go-gost/core/listener"
"github.com/go-gost/core/logger"
mdutil "github.com/go-gost/core/metadata/util"
"github.com/go-gost/core/recorder"
"github.com/go-gost/core/selector"
"github.com/go-gost/core/service"
xchain "github.com/go-gost/x/chain"
"github.com/go-gost/x/config"
"github.com/go-gost/x/config/parsing"
admission_parser "github.com/go-gost/x/config/parsing/admission"
auth_parser "github.com/go-gost/x/config/parsing/auth"
bypass_parser "github.com/go-gost/x/config/parsing/bypass"
hop_parser "github.com/go-gost/x/config/parsing/hop"
selector_parser "github.com/go-gost/x/config/parsing/selector"
tls_util "github.com/go-gost/x/internal/util/tls"
"github.com/go-gost/x/metadata"
"github.com/go-gost/x/registry"
xservice "github.com/go-gost/x/service"
)
// ParseService builds a service.Service (listener + handler) from cfg.
//
// A missing listener section defaults to a "tcp" listener and a missing
// handler section defaults to an "auto" handler. cfg is modified in place:
// those defaults, and empty metadata maps for the listener and handler, are
// written back into it.
//
// Service-level metadata (cfg.Metadata) may set the proxy protocol version,
// the outbound interface (overriding cfg.Interface), a socket mark
// (overriding cfg.SockOpts when > 0), pre/post up/down hook commands, and an
// ignore-chain flag which, when true, omits the chain from both the listener
// options and the handler's router.
//
// The listener is created and initialized first, then the handler. If any
// step after a successful listener Init fails, the listener is closed before
// the error is returned, so a failed parse does not keep holding whatever the
// listener acquired in Init.
func ParseService(cfg *config.ServiceConfig) (service.Service, error) {
	if cfg.Listener == nil {
		cfg.Listener = &config.ListenerConfig{
			Type: "tcp",
		}
	}
	if cfg.Handler == nil {
		cfg.Handler = &config.HandlerConfig{
			Type: "auto",
		}
	}
	serviceLogger := logger.Default().WithFields(map[string]any{
		"kind":     "service",
		"service":  cfg.Name,
		"listener": cfg.Listener.Type,
		"handler":  cfg.Handler.Type,
	})

	// ---- listener ----

	listenerLogger := serviceLogger.WithFields(map[string]any{
		"kind": "listener",
	})

	tlsCfg := cfg.Listener.TLS
	if tlsCfg == nil {
		tlsCfg = &config.TLSConfig{}
	}
	tlsConfig, err := tls_util.LoadServerConfig(
		tlsCfg.CertFile, tlsCfg.KeyFile, tlsCfg.CAFile)
	if err != nil {
		listenerLogger.Error(err)
		return nil, err
	}
	if tlsConfig == nil {
		// No certificate configured: fall back to a private copy of the
		// default TLS config so it is not shared/mutated across services.
		tlsConfig = parsing.DefaultTLSConfig().Clone()
	}

	// Named authers take precedence; the inline auth section is used only
	// when no named auther resolves.
	authers := auth_parser.List(cfg.Listener.Auther, cfg.Listener.Authers...)
	if len(authers) == 0 {
		if a := auth_parser.ParseAutherFromAuth(cfg.Listener.Auth); a != nil {
			authers = append(authers, a)
		}
	}
	var auther auth.Authenticator
	if len(authers) > 0 {
		auther = auth.AuthenticatorGroup(authers...)
	}

	admissions := admission_parser.List(cfg.Admission, cfg.Admissions...)

	var sockOpts *chain.SockOpts
	if cfg.SockOpts != nil {
		sockOpts = &chain.SockOpts{
			Mark: cfg.SockOpts.Mark,
		}
	}

	var ppv int
	ifce := cfg.Interface
	var preUp, preDown, postUp, postDown []string
	var ignoreChain bool
	if cfg.Metadata != nil {
		md := metadata.NewMetadata(cfg.Metadata)
		ppv = mdutil.GetInt(md, parsing.MDKeyProxyProtocol)
		if v := mdutil.GetString(md, parsing.MDKeyInterface); v != "" {
			ifce = v
		}
		if v := mdutil.GetInt(md, parsing.MDKeySoMark); v > 0 {
			sockOpts = &chain.SockOpts{
				Mark: v,
			}
		}
		preUp = mdutil.GetStrings(md, parsing.MDKeyPreUp)
		preDown = mdutil.GetStrings(md, parsing.MDKeyPreDown)
		postUp = mdutil.GetStrings(md, parsing.MDKeyPostUp)
		postDown = mdutil.GetStrings(md, parsing.MDKeyPostDown)
		ignoreChain = mdutil.GetBool(md, parsing.MDKeyIgnoreChain)
	}

	listenOpts := []listener.Option{
		listener.AddrOption(cfg.Addr),
		listener.AutherOption(auther),
		listener.AuthOption(auth_parser.Info(cfg.Listener.Auth)),
		listener.TLSConfigOption(tlsConfig),
		listener.AdmissionOption(admission.AdmissionGroup(admissions...)),
		listener.TrafficLimiterOption(registry.TrafficLimiterRegistry().Get(cfg.Limiter)),
		listener.ConnLimiterOption(registry.ConnLimiterRegistry().Get(cfg.CLimiter)),
		listener.LoggerOption(listenerLogger),
		listener.ServiceOption(cfg.Name),
		listener.ProxyProtocolOption(ppv),
	}
	if !ignoreChain {
		listenOpts = append(listenOpts,
			listener.ChainOption(chainGroup(cfg.Listener.Chain, cfg.Listener.ChainGroup)),
		)
	}

	newListener := registry.ListenerRegistry().Get(cfg.Listener.Type)
	if newListener == nil {
		return nil, fmt.Errorf("unregistered listener: %s", cfg.Listener.Type)
	}
	ln := newListener(listenOpts...)

	if cfg.Listener.Metadata == nil {
		cfg.Listener.Metadata = make(map[string]any)
	}
	listenerLogger.Debugf("metadata: %v", cfg.Listener.Metadata)
	if err := ln.Init(metadata.NewMetadata(cfg.Listener.Metadata)); err != nil {
		listenerLogger.Error("init: ", err)
		return nil, err
	}

	// The listener is now initialized and may hold OS resources.
	// NOTE(review): for network listeners Init presumably binds the
	// socket — confirm per listener type.
	// Release them if any later step fails. On success, ownership passes to
	// the returned service.
	succeeded := false
	defer func() {
		if !succeeded {
			// Best-effort cleanup; the error being returned is the one the
			// caller needs, so a Close error is deliberately dropped.
			_ = ln.Close()
		}
	}()

	// ---- handler ----

	handlerLogger := serviceLogger.WithFields(map[string]any{
		"kind": "handler",
	})

	tlsCfg = cfg.Handler.TLS
	if tlsCfg == nil {
		tlsCfg = &config.TLSConfig{}
	}
	tlsConfig, err = tls_util.LoadServerConfig(
		tlsCfg.CertFile, tlsCfg.KeyFile, tlsCfg.CAFile)
	if err != nil {
		handlerLogger.Error(err)
		return nil, err
	}
	if tlsConfig == nil {
		tlsConfig = parsing.DefaultTLSConfig().Clone()
	}

	authers = auth_parser.List(cfg.Handler.Auther, cfg.Handler.Authers...)
	if len(authers) == 0 {
		if a := auth_parser.ParseAutherFromAuth(cfg.Handler.Auth); a != nil {
			authers = append(authers, a)
		}
	}
	auther = nil
	if len(authers) > 0 {
		auther = auth.AuthenticatorGroup(authers...)
	}

	var recorders []recorder.RecorderObject
	for _, r := range cfg.Recorders {
		md := metadata.NewMetadata(r.Metadata)
		recorders = append(recorders, recorder.RecorderObject{
			Recorder: registry.RecorderRegistry().Get(r.Name),
			Record:   r.Record,
			Options: &recorder.Options{
				Direction:       mdutil.GetBool(md, parsing.MDKeyRecorderDirection),
				TimestampFormat: mdutil.GetString(md, parsing.MDKeyRecorderTimestampFormat),
				Hexdump:         mdutil.GetBool(md, parsing.MDKeyRecorderHexdump),
			},
		})
	}

	routerOpts := []chain.RouterOption{
		chain.RetriesRouterOption(cfg.Handler.Retries),
		// chain.TimeoutRouterOption(10*time.Second),
		chain.InterfaceRouterOption(ifce),
		chain.SockOptsRouterOption(sockOpts),
		chain.ResolverRouterOption(registry.ResolverRegistry().Get(cfg.Resolver)),
		chain.HostMapperRouterOption(registry.HostsRegistry().Get(cfg.Hosts)),
		chain.RecordersRouterOption(recorders...),
		chain.LoggerRouterOption(handlerLogger),
	}
	if !ignoreChain {
		routerOpts = append(routerOpts,
			chain.ChainRouterOption(chainGroup(cfg.Handler.Chain, cfg.Handler.ChainGroup)),
		)
	}
	router := chain.NewRouter(routerOpts...)

	newHandler := registry.HandlerRegistry().Get(cfg.Handler.Type)
	if newHandler == nil {
		return nil, fmt.Errorf("unregistered handler: %s", cfg.Handler.Type)
	}
	var h handler.Handler = newHandler(
		handler.RouterOption(router),
		handler.AutherOption(auther),
		handler.AuthOption(auth_parser.Info(cfg.Handler.Auth)),
		handler.BypassOption(bypass.BypassGroup(bypass_parser.List(cfg.Bypass, cfg.Bypasses...)...)),
		handler.TLSConfigOption(tlsConfig),
		handler.RateLimiterOption(registry.RateLimiterRegistry().Get(cfg.RLimiter)),
		handler.LoggerOption(handlerLogger),
		handler.ServiceOption(cfg.Name),
	)

	// Forwarding handlers additionally receive the hop described by the
	// service's forwarder section (nil when the section is absent).
	if forwarder, ok := h.(handler.Forwarder); ok {
		hp, err := parseForwarder(cfg.Forwarder)
		if err != nil {
			return nil, err
		}
		forwarder.Forward(hp)
	}

	if cfg.Handler.Metadata == nil {
		cfg.Handler.Metadata = make(map[string]any)
	}
	handlerLogger.Debugf("metadata: %v", cfg.Handler.Metadata)
	if err := h.Init(metadata.NewMetadata(cfg.Handler.Metadata)); err != nil {
		handlerLogger.Error("init: ", err)
		return nil, err
	}

	// ---- service ----

	s := xservice.NewService(cfg.Name, ln, h,
		xservice.AdmissionOption(admission.AdmissionGroup(admissions...)),
		xservice.PreUpOption(preUp),
		xservice.PreDownOption(preDown),
		xservice.PostUpOption(postUp),
		xservice.PostDownOption(postDown),
		xservice.RecordersOption(recorders...),
		xservice.LoggerOption(serviceLogger),
	)
	succeeded = true

	serviceLogger.Infof("listening on %s/%s", s.Addr().String(), s.Addr().Network())
	return s, nil
}
// parseForwarder turns a service's forwarder section into a hop.Hop.
//
// A nil cfg yields (nil, nil). Each non-nil entry of cfg.Nodes is copied
// into a config.NodeConfig; nil entries are ignored. If at least one node
// remains, an ad-hoc hop named cfg.Name with cfg.Selector is built from those
// nodes through hop_parser.ParseHop. Otherwise the hop registered under
// cfg.Name is returned from the hop registry.
func parseForwarder(cfg *config.ForwarderConfig) (hop.Hop, error) {
	if cfg == nil {
		return nil, nil
	}

	var nodes []*config.NodeConfig
	for _, nc := range cfg.Nodes {
		if nc == nil {
			continue
		}
		nodes = append(nodes, &config.NodeConfig{
			Name:     nc.Name,
			Addr:     nc.Addr,
			Host:     nc.Host,
			Network:  nc.Network,
			Protocol: nc.Protocol,
			Bypass:   nc.Bypass,
			Bypasses: nc.Bypasses,
			HTTP:     nc.HTTP,
			TLS:      nc.TLS,
			Auth:     nc.Auth,
		})
	}

	if len(nodes) == 0 {
		// No inline nodes: refer to a hop defined by name.
		return registry.HopRegistry().Get(cfg.Name), nil
	}

	return hop_parser.ParseHop(&config.HopConfig{
		Name:     cfg.Name,
		Selector: cfg.Selector,
		Nodes:    nodes,
	})
}
// chainGroup resolves the chain called name, plus every chain listed in
// group, through the chain registry and combines them into one chain group.
//
// Names for which the registry returns nil are skipped. If none resolve,
// chainGroup returns nil. When group is non-nil its selector is parsed from
// group.Selector; if that yields nil (or group is nil) the default chain
// selector is used.
func chainGroup(name string, group *config.ChainGroupConfig) chain.Chainer {
	names := []string{name}
	if group != nil {
		names = append(names, group.Chains...)
	}

	var resolved []chain.Chainer
	for _, n := range names {
		if c := registry.ChainRegistry().Get(n); c != nil {
			resolved = append(resolved, c)
		}
	}

	var sel selector.Selector[chain.Chainer]
	if group != nil {
		sel = selector_parser.ParseChainSelector(group.Selector)
	}

	if len(resolved) == 0 {
		return nil
	}
	if sel == nil {
		sel = selector_parser.DefaultChainSelector()
	}

	grp := xchain.NewChainGroup(resolved...)
	return grp.WithSelector(sel)
}