diff --git a/admission/plugin.go b/admission/plugin.go index c9ce744..c2a246d 100644 --- a/admission/plugin.go +++ b/admission/plugin.go @@ -10,18 +10,18 @@ import ( "github.com/go-gost/core/admission" "github.com/go-gost/core/logger" "github.com/go-gost/plugin/admission/proto" - "github.com/go-gost/x/internal/util/plugin" + "github.com/go-gost/x/internal/plugin" "google.golang.org/grpc" ) -type grpcPluginAdmission struct { +type grpcPlugin struct { conn grpc.ClientConnInterface client proto.AdmissionClient log logger.Logger } -// NewGRPCPluginAdmission creates an Admission plugin based on gRPC. -func NewGRPCPluginAdmission(name string, addr string, opts ...plugin.Option) admission.Admission { +// NewGRPCPlugin creates an Admission plugin based on gRPC. +func NewGRPCPlugin(name string, addr string, opts ...plugin.Option) admission.Admission { var options plugin.Options for _, opt := range opts { opt(&options) @@ -36,7 +36,7 @@ func NewGRPCPluginAdmission(name string, addr string, opts ...plugin.Option) adm log.Error(err) } - p := &grpcPluginAdmission{ + p := &grpcPlugin{ conn: conn, log: log, } @@ -46,7 +46,7 @@ func NewGRPCPluginAdmission(name string, addr string, opts ...plugin.Option) adm return p } -func (p *grpcPluginAdmission) Admit(ctx context.Context, addr string) bool { +func (p *grpcPlugin) Admit(ctx context.Context, addr string) bool { if p.client == nil { return false } @@ -62,36 +62,36 @@ func (p *grpcPluginAdmission) Admit(ctx context.Context, addr string) bool { return r.Ok } -func (p *grpcPluginAdmission) Close() error { +func (p *grpcPlugin) Close() error { if closer, ok := p.conn.(io.Closer); ok { return closer.Close() } return nil } -type httpAdmissionRequest struct { +type httpPluginRequest struct { Addr string `json:"addr"` } -type httpAdmissionResponse struct { +type httpPluginResponse struct { OK bool `json:"ok"` } -type httpPluginAdmission struct { +type httpPlugin struct { url string client *http.Client header http.Header log logger.Logger } -// NewHTTPPluginAdmission creates an Admission plugin based on HTTP. -func NewHTTPPluginAdmission(name string, url string, opts ...plugin.Option) admission.Admission { +// NewHTTPPlugin creates an Admission plugin based on HTTP. +func NewHTTPPlugin(name string, url string, opts ...plugin.Option) admission.Admission { var options plugin.Options for _, opt := range opts { opt(&options) } - return &httpPluginAdmission{ + return &httpPlugin{ url: url, client: plugin.NewHTTPClient(&options), header: options.Header, @@ -102,12 +102,12 @@ func NewHTTPPluginAdmission(name string, url string, opts ...plugin.Option) admi } } -func (p *httpPluginAdmission) Admit(ctx context.Context, addr string) (ok bool) { +func (p *httpPlugin) Admit(ctx context.Context, addr string) (ok bool) { if p.client == nil { return } - rb := httpAdmissionRequest{ + rb := httpPluginRequest{ Addr: addr, } v, err := json.Marshal(&rb) @@ -134,7 +134,7 @@ func (p *httpPluginAdmission) Admit(ctx context.Context, addr string) (ok bool) return } - res := httpAdmissionResponse{} + res := httpPluginResponse{} if err := json.NewDecoder(resp.Body).Decode(&res); err != nil { return } diff --git a/api/config_admission.go b/api/config_admission.go index a1254e2..0be8c81 100644 --- a/api/config_admission.go +++ b/api/config_admission.go @@ -5,7 +5,7 @@ import ( "github.com/gin-gonic/gin" "github.com/go-gost/x/config" - "github.com/go-gost/x/config/parsing" + parser "github.com/go-gost/x/config/parsing/admission" "github.com/go-gost/x/registry" ) @@ -40,7 +40,7 @@ func createAdmission(ctx *gin.Context) { return } - v := parsing.ParseAdmission(&req.Data) + v := parser.ParseAdmission(&req.Data) if err := registry.AdmissionRegistry().Register(req.Data.Name, v); err != nil { writeError(ctx, ErrDup) @@ -94,7 +94,7 @@ func updateAdmission(ctx *gin.Context) { req.Data.Name = req.Admission - v := parsing.ParseAdmission(&req.Data) + v := parser.ParseAdmission(&req.Data) registry.AdmissionRegistry().Unregister(req.Admission) diff --git a/api/config_auther.go b/api/config_auther.go index cfaab3c..68264c4 100644 --- a/api/config_auther.go +++ b/api/config_auther.go @@ -5,7 +5,7 @@ import ( "github.com/gin-gonic/gin" "github.com/go-gost/x/config" - "github.com/go-gost/x/config/parsing" + parser "github.com/go-gost/x/config/parsing/auth" "github.com/go-gost/x/registry" ) @@ -40,7 +40,7 @@ func createAuther(ctx *gin.Context) { return } - v := parsing.ParseAuther(&req.Data) + v := parser.ParseAuther(&req.Data) if err := registry.AutherRegistry().Register(req.Data.Name, v); err != nil { writeError(ctx, ErrDup) return @@ -93,7 +93,7 @@ func updateAuther(ctx *gin.Context) { req.Data.Name = req.Auther - v := parsing.ParseAuther(&req.Data) + v := parser.ParseAuther(&req.Data) registry.AutherRegistry().Unregister(req.Auther) if err := registry.AutherRegistry().Register(req.Auther, v); err != nil { diff --git a/api/config_bypass.go b/api/config_bypass.go index 925406f..03aabbc 100644 --- a/api/config_bypass.go +++ b/api/config_bypass.go @@ -5,7 +5,7 @@ import ( "github.com/gin-gonic/gin" "github.com/go-gost/x/config" - "github.com/go-gost/x/config/parsing" + parser "github.com/go-gost/x/config/parsing/bypass" "github.com/go-gost/x/registry" ) @@ -40,7 +40,7 @@ func createBypass(ctx *gin.Context) { return } - v := parsing.ParseBypass(&req.Data) + v := parser.ParseBypass(&req.Data) if err := registry.BypassRegistry().Register(req.Data.Name, v); err != nil { writeError(ctx, ErrDup) @@ -94,7 +94,7 @@ func updateBypass(ctx *gin.Context) { req.Data.Name = req.Bypass - v := parsing.ParseBypass(&req.Data) + v := parser.ParseBypass(&req.Data) registry.BypassRegistry().Unregister(req.Bypass) diff --git a/api/config_chain.go b/api/config_chain.go index 270c8c6..eba9acb 100644 --- a/api/config_chain.go +++ b/api/config_chain.go @@ -5,7 +5,7 @@ import ( "github.com/gin-gonic/gin" "github.com/go-gost/x/config" - "github.com/go-gost/x/config/parsing" + parser "github.com/go-gost/x/config/parsing/chain" "github.com/go-gost/x/registry" ) @@ -40,7 +40,7 @@ func createChain(ctx *gin.Context) { return } - v, err := parsing.ParseChain(&req.Data) + v, err := parser.ParseChain(&req.Data) if err != nil { writeError(ctx, ErrCreate) return @@ -99,7 +99,7 @@ func updateChain(ctx *gin.Context) { req.Data.Name = req.Chain - v, err := parsing.ParseChain(&req.Data) + v, err := parser.ParseChain(&req.Data) if err != nil { writeError(ctx, ErrCreate) return diff --git a/api/config_conn_limiter.go b/api/config_conn_limiter.go index 05ac61a..62e916a 100644 --- a/api/config_conn_limiter.go +++ b/api/config_conn_limiter.go @@ -5,7 +5,7 @@ import ( "github.com/gin-gonic/gin" "github.com/go-gost/x/config" - "github.com/go-gost/x/config/parsing" + parser "github.com/go-gost/x/config/parsing/limiter" "github.com/go-gost/x/registry" ) @@ -40,7 +40,7 @@ func createConnLimiter(ctx *gin.Context) { return } - v := parsing.ParseConnLimiter(&req.Data) + v := parser.ParseConnLimiter(&req.Data) if err := registry.ConnLimiterRegistry().Register(req.Data.Name, v); err != nil { writeError(ctx, ErrDup) @@ -94,7 +94,7 @@ func updateConnLimiter(ctx *gin.Context) { req.Data.Name = req.Limiter - v := parsing.ParseConnLimiter(&req.Data) + v := parser.ParseConnLimiter(&req.Data) registry.ConnLimiterRegistry().Unregister(req.Limiter) diff --git a/api/config_hop.go b/api/config_hop.go index 6307737..0021c45 100644 --- a/api/config_hop.go +++ b/api/config_hop.go @@ -5,7 +5,7 @@ import ( "github.com/gin-gonic/gin" "github.com/go-gost/x/config" - "github.com/go-gost/x/config/parsing" + parser "github.com/go-gost/x/config/parsing/hop" "github.com/go-gost/x/registry" ) @@ -40,7 +40,7 @@ func createHop(ctx *gin.Context) { return } - v, err := parsing.ParseHop(&req.Data) + v, err := parser.ParseHop(&req.Data) if err != nil { writeError(ctx, ErrCreate) return @@ -99,7 +99,7 @@ func updateHop(ctx *gin.Context) { req.Data.Name = req.Hop - v, err := parsing.ParseHop(&req.Data) + v, err := parser.ParseHop(&req.Data) if err != nil { writeError(ctx, ErrCreate) return diff --git a/api/config_hosts.go b/api/config_hosts.go index 5ae51d0..7d65dbb 100644 --- a/api/config_hosts.go +++ b/api/config_hosts.go @@ -5,7 +5,7 @@ import ( "github.com/gin-gonic/gin" "github.com/go-gost/x/config" - "github.com/go-gost/x/config/parsing" + parser "github.com/go-gost/x/config/parsing/hosts" "github.com/go-gost/x/registry" ) @@ -40,7 +40,7 @@ func createHosts(ctx *gin.Context) { return } - v := parsing.ParseHosts(&req.Data) + v := parser.ParseHostMapper(&req.Data) if err := registry.HostsRegistry().Register(req.Data.Name, v); err != nil { writeError(ctx, ErrDup) @@ -94,7 +94,7 @@ func updateHosts(ctx *gin.Context) { req.Data.Name = req.Hosts - v := parsing.ParseHosts(&req.Data) + v := parser.ParseHostMapper(&req.Data) registry.HostsRegistry().Unregister(req.Hosts) diff --git a/api/config_ingress.go b/api/config_ingress.go index cf50f39..5258ed5 100644 --- a/api/config_ingress.go +++ b/api/config_ingress.go @@ -5,7 +5,7 @@ import ( "github.com/gin-gonic/gin" "github.com/go-gost/x/config" - "github.com/go-gost/x/config/parsing" + parser "github.com/go-gost/x/config/parsing/ingress" "github.com/go-gost/x/registry" ) @@ -40,7 +40,7 @@ func createIngress(ctx *gin.Context) { return } - v := parsing.ParseIngress(&req.Data) + v := parser.ParseIngress(&req.Data) if err := registry.IngressRegistry().Register(req.Data.Name, v); err != nil { writeError(ctx, ErrDup) @@ -94,7 +94,7 @@ func updateIngress(ctx *gin.Context) { req.Data.Name = req.Ingress - v := parsing.ParseIngress(&req.Data) + v := parser.ParseIngress(&req.Data) registry.IngressRegistry().Unregister(req.Ingress) diff --git a/api/config_limiter.go b/api/config_limiter.go index 8cbcc27..73ecaa7 100644 --- a/api/config_limiter.go +++ b/api/config_limiter.go @@ -5,7 +5,7 @@ import ( "github.com/gin-gonic/gin" "github.com/go-gost/x/config" - "github.com/go-gost/x/config/parsing" + parser "github.com/go-gost/x/config/parsing/limiter" "github.com/go-gost/x/registry" ) @@ -40,7 +40,7 @@ func createLimiter(ctx *gin.Context) { return } - v := parsing.ParseTrafficLimiter(&req.Data) + v := parser.ParseTrafficLimiter(&req.Data) if err := registry.TrafficLimiterRegistry().Register(req.Data.Name, v); err != nil { writeError(ctx, ErrDup) @@ -94,7 +94,7 @@ func updateLimiter(ctx *gin.Context) { req.Data.Name = req.Limiter - v := parsing.ParseTrafficLimiter(&req.Data) + v := parser.ParseTrafficLimiter(&req.Data) registry.TrafficLimiterRegistry().Unregister(req.Limiter) diff --git a/api/config_rate_limiter.go b/api/config_rate_limiter.go index 984adae..e491e6e 100644 --- a/api/config_rate_limiter.go +++ b/api/config_rate_limiter.go @@ -5,7 +5,7 @@ import ( "github.com/gin-gonic/gin" "github.com/go-gost/x/config" - "github.com/go-gost/x/config/parsing" + parser "github.com/go-gost/x/config/parsing/limiter" "github.com/go-gost/x/registry" ) @@ -40,7 +40,7 @@ func createRateLimiter(ctx *gin.Context) { return } - v := parsing.ParseRateLimiter(&req.Data) + v := parser.ParseRateLimiter(&req.Data) if err := registry.RateLimiterRegistry().Register(req.Data.Name, v); err != nil { writeError(ctx, ErrDup) @@ -94,7 +94,7 @@ func updateRateLimiter(ctx *gin.Context) { req.Data.Name = req.Limiter - v := parsing.ParseRateLimiter(&req.Data) + v := parser.ParseRateLimiter(&req.Data) registry.RateLimiterRegistry().Unregister(req.Limiter) diff --git a/api/config_resolver.go b/api/config_resolver.go index 1bde112..485ce49 100644 --- a/api/config_resolver.go +++ b/api/config_resolver.go @@ -5,7 +5,7 @@ import ( "github.com/gin-gonic/gin" "github.com/go-gost/x/config" - "github.com/go-gost/x/config/parsing" + parser "github.com/go-gost/x/config/parsing/resolver" "github.com/go-gost/x/registry" ) @@ -40,7 +40,7 @@ func createResolver(ctx *gin.Context) { return } - v, err := parsing.ParseResolver(&req.Data) + v, err := parser.ParseResolver(&req.Data) if err != nil { writeError(ctx, ErrCreate) return @@ -98,7 +98,7 @@ func updateResolver(ctx *gin.Context) { req.Data.Name = req.Resolver - v, err := parsing.ParseResolver(&req.Data) + v, err := parser.ParseResolver(&req.Data) if err != nil { writeError(ctx, ErrCreate) return diff --git a/api/config_service.go b/api/config_service.go index ae385bb..f174042 100644 --- a/api/config_service.go +++ b/api/config_service.go @@ -5,7 +5,7 @@ import ( "github.com/gin-gonic/gin" "github.com/go-gost/x/config" - "github.com/go-gost/x/config/parsing" + parser "github.com/go-gost/x/config/parsing/service" "github.com/go-gost/x/registry" ) @@ -45,7 +45,7 @@ func createService(ctx *gin.Context) { return } - svc, err := parsing.ParseService(&req.Data) + svc, err := parser.ParseService(&req.Data) if err != nil { writeError(ctx, ErrCreate) return @@ -108,7 +108,7 @@ func updateService(ctx *gin.Context) { req.Data.Name = req.Service - svc, err := parsing.ParseService(&req.Data) + svc, err := parser.ParseService(&req.Data) if err != nil { writeError(ctx, ErrCreate) return diff --git a/auth/plugin.go b/auth/plugin.go index a8ec582..b53885e 100644 --- a/auth/plugin.go +++ b/auth/plugin.go @@ -10,19 +10,19 @@ import ( "github.com/go-gost/core/auth" "github.com/go-gost/core/logger" "github.com/go-gost/plugin/auth/proto" + "github.com/go-gost/x/internal/plugin" auth_util "github.com/go-gost/x/internal/util/auth" - "github.com/go-gost/x/internal/util/plugin" "google.golang.org/grpc" ) -type grpcPluginAuthenticator struct { +type grpcPlugin struct { conn grpc.ClientConnInterface client proto.AuthenticatorClient log logger.Logger } -// NewGRPCPluginAuthenticator creates an Authenticator plugin based on gRPC. -func NewGRPCPluginAuthenticator(name string, addr string, opts ...plugin.Option) auth.Authenticator { +// NewGRPCPlugin creates an Authenticator plugin based on gRPC. +func NewGRPCPlugin(name string, addr string, opts ...plugin.Option) auth.Authenticator { var options plugin.Options for _, opt := range opts { opt(&options) @@ -37,7 +37,7 @@ func NewGRPCPluginAuthenticator(name string, addr string, opts ...plugin.Option) log.Error(err) } - p := &grpcPluginAuthenticator{ + p := &grpcPlugin{ conn: conn, log: log, } @@ -49,7 +49,7 @@ func NewGRPCPluginAuthenticator(name string, addr string, opts ...plugin.Option) } // Authenticate checks the validity of the provided user-password pair. -func (p *grpcPluginAuthenticator) Authenticate(ctx context.Context, user, password string) (string, bool) { +func (p *grpcPlugin) Authenticate(ctx context.Context, user, password string) (string, bool) { if p.client == nil { return "", false } @@ -67,39 +67,39 @@ func (p *grpcPluginAuthenticator) Authenticate(ctx context.Context, user, passwo return r.Id, r.Ok } -func (p *grpcPluginAuthenticator) Close() error { +func (p *grpcPlugin) Close() error { if closer, ok := p.conn.(io.Closer); ok { return closer.Close() } return nil } -type httpAutherRequest struct { +type httpPluginRequest struct { Username string `json:"username"` Password string `json:"password"` Client string `json:"client"` } -type httpAutherResponse struct { +type httpPluginResponse struct { OK bool `json:"ok"` ID string `json:"id"` } -type httpPluginAuther struct { +type httpPlugin struct { url string client *http.Client header http.Header log logger.Logger } -// NewHTTPPluginAuthenticator creates an Authenticator plugin based on HTTP. -func NewHTTPPluginAuthenticator(name string, url string, opts ...plugin.Option) auth.Authenticator { +// NewHTTPPlugin creates an Authenticator plugin based on HTTP. +func NewHTTPPlugin(name string, url string, opts ...plugin.Option) auth.Authenticator { var options plugin.Options for _, opt := range opts { opt(&options) } - return &httpPluginAuther{ + return &httpPlugin{ url: url, client: plugin.NewHTTPClient(&options), header: options.Header, @@ -110,12 +110,12 @@ func NewHTTPPluginAuthenticator(name string, url string, opts ...plugin.Option) } } -func (p *httpPluginAuther) Authenticate(ctx context.Context, user, password string) (id string, ok bool) { +func (p *httpPlugin) Authenticate(ctx context.Context, user, password string) (id string, ok bool) { if p.client == nil { return } - rb := httpAutherRequest{ + rb := httpPluginRequest{ Username: user, Password: password, Client: string(auth_util.ClientAddrFromContext(ctx)), @@ -144,7 +144,7 @@ func (p *httpPluginAuther) Authenticate(ctx context.Context, user, password stri return } - res := httpAutherResponse{} + res := httpPluginResponse{} if err := json.NewDecoder(resp.Body).Decode(&res); err != nil { return } diff --git a/bypass/bypass.go b/bypass/bypass.go index 6d300b8..0d53e93 100644 --- a/bypass/bypass.go +++ b/bypass/bypass.go @@ -13,7 +13,6 @@ import ( "github.com/go-gost/core/logger" "github.com/go-gost/x/internal/loader" "github.com/go-gost/x/internal/matcher" - "google.golang.org/grpc" ) type options struct { @@ -22,7 +21,6 @@ type options struct { fileLoader loader.Loader redisLoader loader.Loader httpLoader loader.Loader - client *grpc.ClientConn period time.Duration logger logger.Logger } @@ -65,12 +63,6 @@ func HTTPLoaderOption(httpLoader loader.Loader) Option { } } -func PluginConnOption(c *grpc.ClientConn) Option { - return func(opts *options) { - opts.client = c - } -} - func LoggerOption(logger logger.Logger) Option { return func(opts *options) { opts.logger = logger diff --git a/bypass/plugin.go b/bypass/plugin.go index 88dfafe..80cd3c7 100644 --- a/bypass/plugin.go +++ b/bypass/plugin.go @@ -11,18 +11,18 @@ import ( "github.com/go-gost/core/logger" "github.com/go-gost/plugin/bypass/proto" auth_util "github.com/go-gost/x/internal/util/auth" - "github.com/go-gost/x/internal/util/plugin" + "github.com/go-gost/x/internal/plugin" "google.golang.org/grpc" ) -type grpcPluginBypass struct { +type grpcPlugin struct { conn grpc.ClientConnInterface client proto.BypassClient log logger.Logger } -// NewGRPCPluginBypass creates a Bypass plugin based on gRPC. -func NewGRPCPluginBypass(name string, addr string, opts ...plugin.Option) bypass.Bypass { +// NewGRPCPlugin creates a Bypass plugin based on gRPC. +func NewGRPCPlugin(name string, addr string, opts ...plugin.Option) bypass.Bypass { var options plugin.Options for _, opt := range opts { opt(&options) @@ -37,7 +37,7 @@ func NewGRPCPluginBypass(name string, addr string, opts ...plugin.Option) bypass log.Error(err) } - p := &grpcPluginBypass{ + p := &grpcPlugin{ conn: conn, log: log, } @@ -47,7 +47,7 @@ func NewGRPCPluginBypass(name string, addr string, opts ...plugin.Option) bypass return p } -func (p *grpcPluginBypass) Contains(ctx context.Context, addr string) bool { +func (p *grpcPlugin) Contains(ctx context.Context, addr string) bool { if p.client == nil { return true } @@ -64,37 +64,37 @@ func (p *grpcPluginBypass) Contains(ctx context.Context, addr string) bool { return r.Ok } -func (p *grpcPluginBypass) Close() error { +func (p *grpcPlugin) Close() error { if closer, ok := p.conn.(io.Closer); ok { return closer.Close() } return nil } -type httpBypassRequest struct { +type httpPluginRequest struct { Addr string `json:"addr"` Client string `json:"client"` } -type httpBypassResponse struct { +type httpPluginResponse struct { OK bool `json:"ok"` } -type httpPluginBypass struct { +type httpPlugin struct { url string client *http.Client header http.Header log logger.Logger } -// NewHTTPPluginBypass creates an Bypass plugin based on HTTP. -func NewHTTPPluginBypass(name string, url string, opts ...plugin.Option) bypass.Bypass { +// NewHTTPPlugin creates an Bypass plugin based on HTTP. +func NewHTTPPlugin(name string, url string, opts ...plugin.Option) bypass.Bypass { var options plugin.Options for _, opt := range opts { opt(&options) } - return &httpPluginBypass{ + return &httpPlugin{ url: url, client: plugin.NewHTTPClient(&options), header: options.Header, @@ -105,12 +105,12 @@ func NewHTTPPluginBypass(name string, url string, opts ...plugin.Option) bypass. } } -func (p *httpPluginBypass) Contains(ctx context.Context, addr string) (ok bool) { +func (p *httpPlugin) Contains(ctx context.Context, addr string) (ok bool) { if p.client == nil { return } - rb := httpBypassRequest{ + rb := httpPluginRequest{ Addr: addr, Client: string(auth_util.IDFromContext(ctx)), } @@ -138,7 +138,7 @@ func (p *httpPluginBypass) Contains(ctx context.Context, addr string) (ok bool) return } - res := httpBypassResponse{} + res := httpPluginResponse{} if err := json.NewDecoder(resp.Body).Decode(&res); err != nil { return } diff --git a/chain/chain.go b/chain/chain.go index 0ff02dc..f8ac3a1 100644 --- a/chain/chain.go +++ b/chain/chain.go @@ -4,6 +4,7 @@ import ( "context" "github.com/go-gost/core/chain" + "github.com/go-gost/core/hop" "github.com/go-gost/core/logger" "github.com/go-gost/core/metadata" "github.com/go-gost/core/selector" @@ -38,7 +39,7 @@ type chainNamer interface { type Chain struct { name string - hops []chain.Hop + hops []hop.Hop marker selector.Marker metadata metadata.Metadata logger logger.Logger @@ -60,7 +61,7 @@ func NewChain(name string, opts ...ChainOption) *Chain { } } -func (c *Chain) AddHop(hop chain.Hop) { +func (c *Chain) AddHop(hop hop.Hop) { c.hops = append(c.hops, hop) } @@ -84,8 +85,8 @@ func (c *Chain) Route(ctx context.Context, network, address string) chain.Route } rt := NewRoute(ChainRouteOption(c)) - for _, hop := range c.hops { - node := hop.Select(ctx, chain.AddrSelectOption(address)) + for _, h := range c.hops { + node := h.Select(ctx, hop.AddrSelectOption(address)) if node == nil { return rt } diff --git a/chain/hop.go b/chain/hop.go deleted file mode 100644 index 4fe5185..0000000 --- a/chain/hop.go +++ /dev/null @@ -1,138 +0,0 @@ -package chain - -import ( - "context" - "net" - "strings" - - "github.com/go-gost/core/bypass" - "github.com/go-gost/core/chain" - "github.com/go-gost/core/logger" - "github.com/go-gost/core/selector" -) - -type HopOptions struct { - bypass bypass.Bypass - selector selector.Selector[*chain.Node] - logger logger.Logger -} - -type HopOption func(*HopOptions) - -func BypassHopOption(bp bypass.Bypass) HopOption { - return func(o *HopOptions) { - o.bypass = bp - } -} - -func SelectorHopOption(s selector.Selector[*chain.Node]) HopOption { - return func(o *HopOptions) { - o.selector = s - } -} - -func LoggerHopOption(logger logger.Logger) HopOption { - return func(opts *HopOptions) { - opts.logger = logger - } -} - -type chainHop struct { - nodes []*chain.Node - options HopOptions -} - -func NewChainHop(nodes []*chain.Node, opts ...HopOption) chain.Hop { - var options HopOptions - for _, opt := range opts { - if opt != nil { - opt(&options) - } - } - - hop := &chainHop{ - nodes: nodes, - options: options, - } - - return hop -} - -func (p *chainHop) Nodes() []*chain.Node { - return p.nodes -} - -func (p *chainHop) Select(ctx context.Context, opts ...chain.SelectOption) *chain.Node { - var options chain.SelectOptions - for _, opt := range opts { - opt(&options) - } - - if p == nil || len(p.nodes) == 0 { - return nil - } - - // hop level bypass - if p.options.bypass != nil && - p.options.bypass.Contains(ctx, options.Addr) { - return nil - } - - filters := p.nodes - if host := options.Host; host != "" { - filters = nil - if v, _, _ := net.SplitHostPort(host); v != "" { - host = v - } - var nodes []*chain.Node - for _, node := range p.nodes { - if node == nil { - continue - } - vhost := node.Options().Host - if vhost == "" { - nodes = append(nodes, node) - continue - } - if vhost == host || - vhost[0] == '.' && strings.HasSuffix(host, vhost[1:]) { - filters = append(filters, node) - } - } - if len(filters) == 0 { - filters = nodes - } - } else if protocol := options.Protocol; protocol != "" { - filters = nil - for _, node := range p.nodes { - if node == nil { - continue - } - if node.Options().Protocol == protocol { - filters = append(filters, node) - } - } - } - - var nodes []*chain.Node - for _, node := range filters { - if node == nil { - continue - } - // node level bypass - if node.Options().Bypass != nil && - node.Options().Bypass.Contains(ctx, options.Addr) { - continue - } - - nodes = append(nodes, node) - } - if len(nodes) == 0 { - return nil - } - - if s := p.options.selector; s != nil { - return s.Select(ctx, nodes...) - } - return nodes[0] -} diff --git a/config/config.go b/config/config.go index a34b566..1a21e56 100644 --- a/config/config.go +++ b/config/config.go @@ -371,9 +371,7 @@ type ServiceConfig struct { } type ChainConfig struct { - Name string `json:"name"` - // REMOVED since beta.6 - // Selector *SelectorConfig `yaml:",omitempty" json:"selector,omitempty"` + Name string `json:"name"` Hops []*HopConfig `json:"hops"` Metadata map[string]any `yaml:",omitempty" json:"metadata,omitempty"` } @@ -393,6 +391,11 @@ type HopConfig struct { Resolver string `yaml:",omitempty" json:"resolver,omitempty"` Hosts string `yaml:",omitempty" json:"hosts,omitempty"` Nodes []*NodeConfig `yaml:",omitempty" json:"nodes,omitempty"` + Reload time.Duration `yaml:",omitempty" json:"reload,omitempty"` + File *FileLoader `yaml:",omitempty" json:"file,omitempty"` + Redis *RedisLoader `yaml:",omitempty" json:"redis,omitempty"` + HTTP *HTTPLoader `yaml:"http,omitempty" json:"http,omitempty"` + Plugin *PluginConfig `yaml:",omitempty" json:"plugin,omitempty"` } type NodeConfig struct { diff --git a/config/parsing/admission/parse.go b/config/parsing/admission/parse.go new file mode 100644 index 0000000..4de13b7 --- /dev/null +++ b/config/parsing/admission/parse.go @@ -0,0 +1,87 @@ +package admission + +import ( + "crypto/tls" + "strings" + + "github.com/go-gost/core/admission" + "github.com/go-gost/core/logger" + xadmission "github.com/go-gost/x/admission" + "github.com/go-gost/x/config" + "github.com/go-gost/x/internal/loader" + "github.com/go-gost/x/internal/plugin" + "github.com/go-gost/x/registry" +) + +func ParseAdmission(cfg *config.AdmissionConfig) admission.Admission { + if cfg == nil { + return nil + } + + if cfg.Plugin != nil { + var tlsCfg *tls.Config + if cfg.Plugin.TLS != nil { + tlsCfg = &tls.Config{ + ServerName: cfg.Plugin.TLS.ServerName, + InsecureSkipVerify: !cfg.Plugin.TLS.Secure, + } + } + switch strings.ToLower(cfg.Plugin.Type) { + case "http": + return xadmission.NewHTTPPlugin( + cfg.Name, cfg.Plugin.Addr, + plugin.TLSConfigOption(tlsCfg), + plugin.TimeoutOption(cfg.Plugin.Timeout), + ) + default: + return xadmission.NewGRPCPlugin( + cfg.Name, cfg.Plugin.Addr, + plugin.TokenOption(cfg.Plugin.Token), + plugin.TLSConfigOption(tlsCfg), + ) + } + } + + opts := []xadmission.Option{ + xadmission.MatchersOption(cfg.Matchers), + xadmission.WhitelistOption(cfg.Reverse || cfg.Whitelist), + xadmission.ReloadPeriodOption(cfg.Reload), + xadmission.LoggerOption(logger.Default().WithFields(map[string]any{ + "kind": "admission", + "admission": cfg.Name, + })), + } + if cfg.File != nil && cfg.File.Path != "" { + opts = append(opts, xadmission.FileLoaderOption(loader.FileLoader(cfg.File.Path))) + } + if cfg.Redis != nil && cfg.Redis.Addr != "" { + opts = append(opts, xadmission.RedisLoaderOption(loader.RedisSetLoader( + cfg.Redis.Addr, + loader.DBRedisLoaderOption(cfg.Redis.DB), + loader.PasswordRedisLoaderOption(cfg.Redis.Password), + loader.KeyRedisLoaderOption(cfg.Redis.Key), + ))) + } + if cfg.HTTP != nil && cfg.HTTP.URL != "" { + opts = append(opts, xadmission.HTTPLoaderOption(loader.HTTPLoader( + cfg.HTTP.URL, + loader.TimeoutHTTPLoaderOption(cfg.HTTP.Timeout), + ))) + } + + return xadmission.NewAdmission(opts...) +} + +func List(name string, names ...string) []admission.Admission { + var admissions []admission.Admission + if adm := registry.AdmissionRegistry().Get(name); adm != nil { + admissions = append(admissions, adm) + } + for _, s := range names { + if adm := registry.AdmissionRegistry().Get(s); adm != nil { + admissions = append(admissions, adm) + } + } + + return admissions +} diff --git a/config/parsing/auth/parse.go b/config/parsing/auth/parse.go new file mode 100644 index 0000000..651dff0 --- /dev/null +++ b/config/parsing/auth/parse.go @@ -0,0 +1,120 @@ +package auth + +import ( + "crypto/tls" + "net/url" + + "github.com/go-gost/core/auth" + "github.com/go-gost/core/logger" + xauth "github.com/go-gost/x/auth" + "github.com/go-gost/x/config" + "github.com/go-gost/x/internal/loader" + "github.com/go-gost/x/internal/plugin" + "github.com/go-gost/x/registry" +) + +func ParseAuther(cfg *config.AutherConfig) auth.Authenticator { + if cfg == nil { + return nil + } + + if cfg.Plugin != nil { + var tlsCfg *tls.Config + if cfg.Plugin.TLS != nil { + tlsCfg = &tls.Config{ + ServerName: cfg.Plugin.TLS.ServerName, + InsecureSkipVerify: !cfg.Plugin.TLS.Secure, + } + } + switch cfg.Plugin.Type { + case "http": + return xauth.NewHTTPPlugin( + cfg.Name, cfg.Plugin.Addr, + plugin.TLSConfigOption(tlsCfg), + plugin.TimeoutOption(cfg.Plugin.Timeout), + ) + default: + return xauth.NewGRPCPlugin( + cfg.Name, cfg.Plugin.Addr, + plugin.TokenOption(cfg.Plugin.Token), + plugin.TLSConfigOption(tlsCfg), + ) + } + } + + m := make(map[string]string) + + for _, user := range cfg.Auths { + if user.Username == "" { + continue + } + m[user.Username] = user.Password + } + + opts := []xauth.Option{ + xauth.AuthsOption(m), + xauth.ReloadPeriodOption(cfg.Reload), + xauth.LoggerOption(logger.Default().WithFields(map[string]any{ + "kind": "auther", + "auther": cfg.Name, + })), + } + if cfg.File != nil && cfg.File.Path != "" { + opts = append(opts, xauth.FileLoaderOption(loader.FileLoader(cfg.File.Path))) + } + if cfg.Redis != nil && cfg.Redis.Addr != "" { + opts = append(opts, xauth.RedisLoaderOption(loader.RedisHashLoader( + cfg.Redis.Addr, + loader.DBRedisLoaderOption(cfg.Redis.DB), + loader.PasswordRedisLoaderOption(cfg.Redis.Password), + loader.KeyRedisLoaderOption(cfg.Redis.Key), + ))) + } + if cfg.HTTP != nil && cfg.HTTP.URL != "" { + opts = append(opts, xauth.HTTPLoaderOption(loader.HTTPLoader( + cfg.HTTP.URL, + loader.TimeoutHTTPLoaderOption(cfg.HTTP.Timeout), + ))) + } + return xauth.NewAuthenticator(opts...) +} + +func ParseAutherFromAuth(au *config.AuthConfig) auth.Authenticator { + if au == nil || au.Username == "" { + return nil + } + return xauth.NewAuthenticator( + xauth.AuthsOption( + map[string]string{ + au.Username: au.Password, + }, + ), + xauth.LoggerOption(logger.Default().WithFields(map[string]any{ + "kind": "auther", + })), + ) +} + +func Info(cfg *config.AuthConfig) *url.Userinfo { + if cfg == nil || cfg.Username == "" { + return nil + } + + if cfg.Password == "" { + return url.User(cfg.Username) + } + return url.UserPassword(cfg.Username, cfg.Password) +} + +func List(name string, names ...string) []auth.Authenticator { + var authers []auth.Authenticator + if auther := registry.AutherRegistry().Get(name); auther != nil { + authers = append(authers, auther) + } + for _, s := range names { + if auther := registry.AutherRegistry().Get(s); auther != nil { + authers = append(authers, auther) + } + } + return authers +} diff --git a/config/parsing/bypass/parse.go b/config/parsing/bypass/parse.go new file mode 100644 index 0000000..f2b3419 --- /dev/null +++ b/config/parsing/bypass/parse.go @@ -0,0 +1,86 @@ +package bypass + +import ( + "crypto/tls" + "strings" + + "github.com/go-gost/core/bypass" + "github.com/go-gost/core/logger" + xbypass "github.com/go-gost/x/bypass" + "github.com/go-gost/x/config" + "github.com/go-gost/x/internal/loader" + "github.com/go-gost/x/internal/plugin" + "github.com/go-gost/x/registry" +) + +func ParseBypass(cfg *config.BypassConfig) bypass.Bypass { + if cfg == nil { + return nil + } + + if cfg.Plugin != nil { + var tlsCfg *tls.Config + if cfg.Plugin.TLS != nil { + tlsCfg = &tls.Config{ + ServerName: cfg.Plugin.TLS.ServerName, + InsecureSkipVerify: !cfg.Plugin.TLS.Secure, + } + } + switch strings.ToLower(cfg.Plugin.Type) { + case "http": + return xbypass.NewHTTPPlugin( + cfg.Name, cfg.Plugin.Addr, + plugin.TLSConfigOption(tlsCfg), + plugin.TimeoutOption(cfg.Plugin.Timeout), + ) + default: + return xbypass.NewGRPCPlugin( + cfg.Name, cfg.Plugin.Addr, + plugin.TokenOption(cfg.Plugin.Token), + plugin.TLSConfigOption(tlsCfg), + ) + } + } + + opts := []xbypass.Option{ + xbypass.MatchersOption(cfg.Matchers), + xbypass.WhitelistOption(cfg.Reverse || cfg.Whitelist), + xbypass.ReloadPeriodOption(cfg.Reload), + xbypass.LoggerOption(logger.Default().WithFields(map[string]any{ + "kind": "bypass", + "bypass": cfg.Name, + })), + } + if cfg.File != nil && cfg.File.Path != "" { + opts = append(opts, xbypass.FileLoaderOption(loader.FileLoader(cfg.File.Path))) + } + if cfg.Redis != nil && cfg.Redis.Addr != "" { + opts = append(opts, xbypass.RedisLoaderOption(loader.RedisSetLoader( + cfg.Redis.Addr, + loader.DBRedisLoaderOption(cfg.Redis.DB), + loader.PasswordRedisLoaderOption(cfg.Redis.Password), + loader.KeyRedisLoaderOption(cfg.Redis.Key), + ))) + } + if cfg.HTTP != nil && cfg.HTTP.URL != "" { + opts = append(opts, xbypass.HTTPLoaderOption(loader.HTTPLoader( + cfg.HTTP.URL, + loader.TimeoutHTTPLoaderOption(cfg.HTTP.Timeout), + ))) + } + + return xbypass.NewBypass(opts...) +} + +func List(name string, names ...string) []bypass.Bypass { + var bypasses []bypass.Bypass + if bp := registry.BypassRegistry().Get(name); bp != nil { + bypasses = append(bypasses, bp) + } + for _, s := range names { + if bp := registry.BypassRegistry().Get(s); bp != nil { + bypasses = append(bypasses, bp) + } + } + return bypasses +} diff --git a/config/parsing/chain.go b/config/parsing/chain.go deleted file mode 100644 index e45864c..0000000 --- a/config/parsing/chain.go +++ /dev/null @@ -1,270 +0,0 @@ -package parsing - -import ( - "fmt" - "net" - "strings" - "time" - - "github.com/go-gost/core/bypass" - "github.com/go-gost/core/chain" - "github.com/go-gost/core/connector" - "github.com/go-gost/core/dialer" - "github.com/go-gost/core/logger" - "github.com/go-gost/core/metadata" - mdutil "github.com/go-gost/core/metadata/util" - auther "github.com/go-gost/x/auth" - xchain "github.com/go-gost/x/chain" - "github.com/go-gost/x/config" - tls_util "github.com/go-gost/x/internal/util/tls" - mdx "github.com/go-gost/x/metadata" - "github.com/go-gost/x/registry" -) - -func ParseChain(cfg *config.ChainConfig) (chain.Chainer, error) { - if cfg == nil { - return nil, nil - } - - chainLogger := logger.Default().WithFields(map[string]any{ - "kind": "chain", - "chain": cfg.Name, - }) - - var md metadata.Metadata - if cfg.Metadata != nil { - md = mdx.NewMetadata(cfg.Metadata) - } - - c := xchain.NewChain(cfg.Name, - xchain.MetadataChainOption(md), - xchain.LoggerChainOption(chainLogger), - ) - - for _, ch := range cfg.Hops { - var hop chain.Hop - var err error - - if len(ch.Nodes) > 0 { - if hop, err = ParseHop(ch); err != nil { - return nil, err - } - } else { - hop = registry.HopRegistry().Get(ch.Name) - } - if hop != nil { - c.AddHop(hop) - } - } - - return c, nil -} - -func ParseHop(cfg *config.HopConfig) (chain.Hop, error) { - if cfg == nil { - return nil, nil - } - - hopLogger := logger.Default().WithFields(map[string]any{ - "kind": "hop", - "hop": cfg.Name, - }) - - var nodes []*chain.Node - for _, v := range cfg.Nodes { - if v == nil { - continue - } - - if v.Connector == nil { - v.Connector = &config.ConnectorConfig{ - Type: "http", - } - } - - if v.Dialer == nil { - v.Dialer = &config.DialerConfig{ - Type: "tcp", - } - } - - nodeLogger := hopLogger.WithFields(map[string]any{ - "kind": "node", - "node": v.Name, - "connector": v.Connector.Type, - "dialer": v.Dialer.Type, - }) - - serverName, _, _ := net.SplitHostPort(v.Addr) - - tlsCfg := v.Connector.TLS - if tlsCfg == nil { - tlsCfg = &config.TLSConfig{} - } - if tlsCfg.ServerName == "" { - tlsCfg.ServerName = serverName - } - tlsConfig, err := tls_util.LoadClientConfig( - tlsCfg.CertFile, tlsCfg.KeyFile, tlsCfg.CAFile, - tlsCfg.Secure, tlsCfg.ServerName) - if err != nil { - hopLogger.Error(err) - return nil, err - } - - var nm metadata.Metadata - if v.Metadata != nil { - nm = mdx.NewMetadata(v.Metadata) - } - - connectorLogger := nodeLogger.WithFields(map[string]any{ - "kind": "connector", - }) - var cr connector.Connector - if rf := registry.ConnectorRegistry().Get(v.Connector.Type); rf != nil { - cr = rf( - connector.AuthOption(parseAuth(v.Connector.Auth)), - connector.TLSConfigOption(tlsConfig), - connector.LoggerOption(connectorLogger), - ) - } else { - return nil, fmt.Errorf("unregistered connector: %s", v.Connector.Type) - } - - if v.Connector.Metadata == nil { - v.Connector.Metadata = make(map[string]any) - } - if err := cr.Init(mdx.NewMetadata(v.Connector.Metadata)); err != nil { - connectorLogger.Error("init: ", err) - return nil, err - } - - tlsCfg = v.Dialer.TLS - if tlsCfg == nil { - tlsCfg = &config.TLSConfig{} - } - if tlsCfg.ServerName == "" { - tlsCfg.ServerName = serverName - } - tlsConfig, err = tls_util.LoadClientConfig( - tlsCfg.CertFile, tlsCfg.KeyFile, tlsCfg.CAFile, - tlsCfg.Secure, tlsCfg.ServerName) - if err != nil { - hopLogger.Error(err) - return nil, err - } - - var ppv int - if nm != nil { - ppv = mdutil.GetInt(nm, mdKeyProxyProtocol) - } - - dialerLogger := nodeLogger.WithFields(map[string]any{ - "kind": "dialer", - }) - - var d dialer.Dialer - if rf := registry.DialerRegistry().Get(v.Dialer.Type); rf != nil { - d = rf( - dialer.AuthOption(parseAuth(v.Dialer.Auth)), - dialer.TLSConfigOption(tlsConfig), - dialer.LoggerOption(dialerLogger), - dialer.ProxyProtocolOption(ppv), - ) - } else { - return nil, fmt.Errorf("unregistered dialer: %s", v.Dialer.Type) - } - - if v.Dialer.Metadata == nil { - v.Dialer.Metadata = make(map[string]any) - } - if err := d.Init(mdx.NewMetadata(v.Dialer.Metadata)); err != nil { - dialerLogger.Error("init: ", err) - return nil, err - } - - if v.Resolver == "" { - v.Resolver = cfg.Resolver - } - if v.Hosts == "" { - v.Hosts = cfg.Hosts - } - if v.Interface == "" { - v.Interface = cfg.Interface - } - if v.SockOpts == nil { - v.SockOpts = cfg.SockOpts - } - - var sockOpts *chain.SockOpts - if v.SockOpts != nil { - sockOpts = &chain.SockOpts{ - Mark: v.SockOpts.Mark, - } - } - - tr := chain.NewTransport(d, cr, - chain.AddrTransportOption(v.Addr), - chain.InterfaceTransportOption(v.Interface), - chain.SockOptsTransportOption(sockOpts), - chain.TimeoutTransportOption(10*time.Second), - ) - - // convert *.example.com to .example.com - // convert *example.com to example.com - host := v.Host - if strings.HasPrefix(host, "*") { - host = host[1:] - if !strings.HasPrefix(host, ".") { - host = "." + host - } - } - - opts := []chain.NodeOption{ - chain.TransportNodeOption(tr), - chain.BypassNodeOption(bypass.BypassGroup(bypassList(v.Bypass, v.Bypasses...)...)), - chain.ResoloverNodeOption(registry.ResolverRegistry().Get(v.Resolver)), - chain.HostMapperNodeOption(registry.HostsRegistry().Get(v.Hosts)), - chain.MetadataNodeOption(nm), - chain.HostNodeOption(host), - chain.ProtocolNodeOption(v.Protocol), - } - if v.HTTP != nil { - opts = append(opts, chain.HTTPNodeOption(&chain.HTTPNodeSettings{ - Host: v.HTTP.Host, - Header: v.HTTP.Header, - })) - } - if v.TLS != nil { - opts = append(opts, chain.TLSNodeOption(&chain.TLSNodeSettings{ - ServerName: v.TLS.ServerName, - Secure: v.TLS.Secure, - })) - } - if v.Auth != nil { - opts = append(opts, chain.AutherNodeOption( - auther.NewAuthenticator( - auther.AuthsOption(map[string]string{v.Auth.Username: v.Auth.Password}), - auther.LoggerOption(logger.Default().WithFields(map[string]any{ - "kind": "node", - "node": v.Name, - "addr": v.Addr, - "host": v.Host, - "protocol": v.Protocol, - })), - ))) - } - node := chain.NewNode(v.Name, v.Addr, opts...) - nodes = append(nodes, node) - } - - sel := parseNodeSelector(cfg.Selector) - if sel == nil { - sel = defaultNodeSelector() - } - return xchain.NewChainHop(nodes, - xchain.SelectorHopOption(sel), - xchain.BypassHopOption(bypass.BypassGroup(bypassList(cfg.Bypass, cfg.Bypasses...)...)), - xchain.LoggerHopOption(hopLogger), - ), nil -} diff --git a/config/parsing/chain/parse.go b/config/parsing/chain/parse.go new file mode 100644 index 0000000..8a03c12 --- /dev/null +++ b/config/parsing/chain/parse.go @@ -0,0 +1,52 @@ +package chain + +import ( + "github.com/go-gost/core/chain" + "github.com/go-gost/core/hop" + "github.com/go-gost/core/logger" + "github.com/go-gost/core/metadata" + xchain "github.com/go-gost/x/chain" + "github.com/go-gost/x/config" + hop_parser "github.com/go-gost/x/config/parsing/hop" + mdx "github.com/go-gost/x/metadata" + "github.com/go-gost/x/registry" +) + +func ParseChain(cfg *config.ChainConfig) (chain.Chainer, error) { + if cfg == nil { + return nil, nil + } + + chainLogger := logger.Default().WithFields(map[string]any{ + "kind": "chain", + "chain": cfg.Name, + }) + + var md metadata.Metadata + if cfg.Metadata != nil { + md = mdx.NewMetadata(cfg.Metadata) + } + + c := xchain.NewChain(cfg.Name, + xchain.MetadataChainOption(md), + xchain.LoggerChainOption(chainLogger), + ) + + for _, ch := range cfg.Hops { + var hop hop.Hop + var err error + + if ch.Nodes != nil || ch.Plugin != nil { + if hop, err = hop_parser.ParseHop(ch); err != nil { + return nil, err + } + } else { + hop = registry.HopRegistry().Get(ch.Name) + } + if hop != nil { + c.AddHop(hop) + } + } + + return c, nil +} diff --git a/config/parsing/hop/parse.go b/config/parsing/hop/parse.go new file mode 100644 index 0000000..b3b1147 --- /dev/null +++ b/config/parsing/hop/parse.go @@ -0,0 +1,124 @@ +package hop + +import ( + "crypto/tls" + "strings" + + "github.com/go-gost/core/bypass" + "github.com/go-gost/core/chain" + "github.com/go-gost/core/hop" + "github.com/go-gost/core/logger" + "github.com/go-gost/x/config" + bypass_parser "github.com/go-gost/x/config/parsing/bypass" + node_parser "github.com/go-gost/x/config/parsing/node" + selector_parser "github.com/go-gost/x/config/parsing/selector" + xhop "github.com/go-gost/x/hop" + "github.com/go-gost/x/internal/loader" + "github.com/go-gost/x/internal/plugin" +) + +func ParseHop(cfg *config.HopConfig) (hop.Hop, error) { + if cfg == nil { + return nil, nil + } + + if cfg.Plugin != nil { + var tlsCfg *tls.Config + if cfg.Plugin.TLS != nil { + tlsCfg = &tls.Config{ + ServerName: cfg.Plugin.TLS.ServerName, + InsecureSkipVerify: !cfg.Plugin.TLS.Secure, + } + } + switch strings.ToLower(cfg.Plugin.Type) { + case "http": + return xhop.NewHTTPPlugin( + cfg.Name, cfg.Plugin.Addr, + plugin.TLSConfigOption(tlsCfg), + plugin.TimeoutOption(cfg.Plugin.Timeout), + ), nil + default: + return xhop.NewGRPCPlugin( + cfg.Name, cfg.Plugin.Addr, + plugin.TokenOption(cfg.Plugin.Token), + plugin.TLSConfigOption(tlsCfg), + ), nil + } + } + + var nodes []*chain.Node + for _, v := range cfg.Nodes { + if v == nil { + continue + } + + if v.Resolver == "" { + v.Resolver = cfg.Resolver + } + if v.Hosts == "" { + v.Hosts = cfg.Hosts + } + if v.Interface == "" { + v.Interface = cfg.Interface + } + if v.SockOpts == nil { + v.SockOpts = cfg.SockOpts + } + + if v.Connector == nil { + v.Connector = &config.ConnectorConfig{ + Type: "http", + } + } + + if v.Dialer == nil { + v.Dialer = &config.DialerConfig{ + Type: "tcp", + } + } + + node, err := node_parser.ParseNode(cfg.Name, v) + if err != nil { + return nil, err + } + if node != nil { + nodes = append(nodes, node) + } + } + + sel := selector_parser.ParseNodeSelector(cfg.Selector) + if sel == nil { + sel = selector_parser.DefaultNodeSelector() + } + + opts := []xhop.Option{ + xhop.NameOption(cfg.Name), + xhop.NodeOption(nodes...), + xhop.SelectorOption(sel), + xhop.BypassOption(bypass.BypassGroup(bypass_parser.List(cfg.Bypass, cfg.Bypasses...)...)), + xhop.ReloadPeriodOption(cfg.Reload), + xhop.LoggerOption(logger.Default().WithFields(map[string]any{ + "kind": "hop", + "hop": cfg.Name, + })), + } + + if cfg.File != nil && cfg.File.Path != "" { + opts = append(opts, xhop.FileLoaderOption(loader.FileLoader(cfg.File.Path))) + } + if cfg.Redis != nil && cfg.Redis.Addr != "" { + opts = append(opts, xhop.RedisLoaderOption(loader.RedisStringLoader( + cfg.Redis.Addr, + loader.DBRedisLoaderOption(cfg.Redis.DB), + loader.PasswordRedisLoaderOption(cfg.Redis.Password), + loader.KeyRedisLoaderOption(cfg.Redis.Key), + ))) + } + if cfg.HTTP != nil && cfg.HTTP.URL != "" { + opts = append(opts, xhop.HTTPLoaderOption(loader.HTTPLoader( + cfg.HTTP.URL, + loader.TimeoutHTTPLoaderOption(cfg.HTTP.Timeout), + ))) + } + return xhop.NewHop(opts...), nil +} diff --git a/config/parsing/hosts/parse.go b/config/parsing/hosts/parse.go new file mode 100644 index 0000000..7bcac0b --- /dev/null +++ b/config/parsing/hosts/parse.go @@ -0,0 +1,96 @@ +package hosts + +import ( + "crypto/tls" + "net" + "strings" + + "github.com/go-gost/core/hosts" + "github.com/go-gost/core/logger" + "github.com/go-gost/x/config" + xhosts "github.com/go-gost/x/hosts" + "github.com/go-gost/x/internal/loader" + "github.com/go-gost/x/internal/plugin" +) + +func ParseHostMapper(cfg *config.HostsConfig) hosts.HostMapper { + if cfg == nil { + return nil + } + + if cfg.Plugin != nil { + var tlsCfg *tls.Config + if cfg.Plugin.TLS != nil { + tlsCfg = &tls.Config{ + ServerName: cfg.Plugin.TLS.ServerName, + InsecureSkipVerify: !cfg.Plugin.TLS.Secure, + } + } + switch strings.ToLower(cfg.Plugin.Type) { + case "http": + return xhosts.NewHTTPPlugin( + cfg.Name, cfg.Plugin.Addr, + plugin.TLSConfigOption(tlsCfg), + plugin.TimeoutOption(cfg.Plugin.Timeout), + ) + default: + return xhosts.NewGRPCPlugin( + cfg.Name, cfg.Plugin.Addr, + plugin.TokenOption(cfg.Plugin.Token), + plugin.TLSConfigOption(tlsCfg), + ) + } + } + + var mappings []xhosts.Mapping + for _, mapping := range cfg.Mappings { + if mapping.IP == "" || mapping.Hostname == "" { + continue + } + + ip := net.ParseIP(mapping.IP) + if ip == nil { + continue + } + mappings = append(mappings, xhosts.Mapping{ + Hostname: mapping.Hostname, + IP: ip, + }) + } + opts := []xhosts.Option{ + xhosts.MappingsOption(mappings), + xhosts.ReloadPeriodOption(cfg.Reload), + xhosts.LoggerOption(logger.Default().WithFields(map[string]any{ + "kind": "hosts", + "hosts": cfg.Name, + })), + } + if cfg.File != nil && cfg.File.Path != "" { + opts = append(opts, xhosts.FileLoaderOption(loader.FileLoader(cfg.File.Path))) + } + if cfg.Redis != nil && cfg.Redis.Addr != "" { + switch cfg.Redis.Type { + case "list": // redis list + opts = append(opts, xhosts.RedisLoaderOption(loader.RedisListLoader( + cfg.Redis.Addr, + loader.DBRedisLoaderOption(cfg.Redis.DB), + loader.PasswordRedisLoaderOption(cfg.Redis.Password), + loader.KeyRedisLoaderOption(cfg.Redis.Key), + ))) + default: // redis set + opts = append(opts, xhosts.RedisLoaderOption(loader.RedisSetLoader( + cfg.Redis.Addr, + loader.DBRedisLoaderOption(cfg.Redis.DB), + loader.PasswordRedisLoaderOption(cfg.Redis.Password), + loader.KeyRedisLoaderOption(cfg.Redis.Key), + ))) + } + } + if cfg.HTTP != nil && cfg.HTTP.URL != "" { + opts = append(opts, xhosts.HTTPLoaderOption(loader.HTTPLoader( + cfg.HTTP.URL, + loader.TimeoutHTTPLoaderOption(cfg.HTTP.Timeout), + ))) + } + return xhosts.NewHostMapper(opts...) +} diff --git a/config/parsing/ingress/parse.go b/config/parsing/ingress/parse.go new file mode 100644 index 0000000..5c95303 --- /dev/null +++ b/config/parsing/ingress/parse.go @@ -0,0 +1,91 @@ +package ingress + +import ( + "crypto/tls" + "strings" + + "github.com/go-gost/core/ingress" + "github.com/go-gost/core/logger" + "github.com/go-gost/x/config" + xingress "github.com/go-gost/x/ingress" + "github.com/go-gost/x/internal/loader" + "github.com/go-gost/x/internal/plugin" +) + +func ParseIngress(cfg *config.IngressConfig) ingress.Ingress { + if cfg == nil { + return nil + } + + if cfg.Plugin != nil { + var tlsCfg *tls.Config + if cfg.Plugin.TLS != nil { + tlsCfg = &tls.Config{ + ServerName: cfg.Plugin.TLS.ServerName, + InsecureSkipVerify: !cfg.Plugin.TLS.Secure, + } + } + switch strings.ToLower(cfg.Plugin.Type) { + case "http": + return xingress.NewHTTPPlugin( + cfg.Name, cfg.Plugin.Addr, + plugin.TLSConfigOption(tlsCfg), + plugin.TimeoutOption(cfg.Plugin.Timeout), + ) + default: + return xingress.NewGRPCPlugin( + cfg.Name, cfg.Plugin.Addr, + plugin.TokenOption(cfg.Plugin.Token), + plugin.TLSConfigOption(tlsCfg), + ) + } + } + + var rules []xingress.Rule + for _, rule := range cfg.Rules { + if rule.Hostname == "" || rule.Endpoint == "" { + continue + } + + rules = append(rules, xingress.Rule{ + Hostname: rule.Hostname, + Endpoint: rule.Endpoint, + }) + } + opts := []xingress.Option{ + xingress.RulesOption(rules), + xingress.ReloadPeriodOption(cfg.Reload), + xingress.LoggerOption(logger.Default().WithFields(map[string]any{ + "kind": "ingress", + "ingress": cfg.Name, + })), + } + if cfg.File != nil && cfg.File.Path != "" { + opts = append(opts, xingress.FileLoaderOption(loader.FileLoader(cfg.File.Path))) + } + if cfg.Redis != nil && cfg.Redis.Addr != "" { + switch cfg.Redis.Type { + case "set": // redis set + opts = append(opts, xingress.RedisLoaderOption(loader.RedisSetLoader( + cfg.Redis.Addr, + loader.DBRedisLoaderOption(cfg.Redis.DB), + loader.PasswordRedisLoaderOption(cfg.Redis.Password), + loader.KeyRedisLoaderOption(cfg.Redis.Key), + ))) + default: // redis hash + opts = append(opts, xingress.RedisLoaderOption(loader.RedisHashLoader( + cfg.Redis.Addr, + loader.DBRedisLoaderOption(cfg.Redis.DB), + loader.PasswordRedisLoaderOption(cfg.Redis.Password), + loader.KeyRedisLoaderOption(cfg.Redis.Key), + ))) + } + } + if cfg.HTTP != nil && cfg.HTTP.URL != "" { + opts = append(opts, xingress.HTTPLoaderOption(loader.HTTPLoader( + cfg.HTTP.URL, + loader.TimeoutHTTPLoaderOption(cfg.HTTP.Timeout), + ))) + } + return xingress.NewIngress(opts...) +} diff --git a/config/parsing/limiter/parse.go b/config/parsing/limiter/parse.go new file mode 100644 index 0000000..b26b5fd --- /dev/null +++ b/config/parsing/limiter/parse.go @@ -0,0 +1,151 @@ +package limiter + +import ( + "github.com/go-gost/core/limiter/conn" + "github.com/go-gost/core/limiter/rate" + "github.com/go-gost/core/limiter/traffic" + "github.com/go-gost/core/logger" + "github.com/go-gost/x/config" + "github.com/go-gost/x/internal/loader" + xconn "github.com/go-gost/x/limiter/conn" + xrate "github.com/go-gost/x/limiter/rate" + xtraffic "github.com/go-gost/x/limiter/traffic" +) + +func ParseTrafficLimiter(cfg *config.LimiterConfig) (lim traffic.TrafficLimiter) { + if cfg == nil { + return nil + } + + var opts []xtraffic.Option + + if cfg.File != nil && cfg.File.Path != "" { + opts = append(opts, xtraffic.FileLoaderOption(loader.FileLoader(cfg.File.Path))) + } + if cfg.Redis != nil && cfg.Redis.Addr != "" { + switch cfg.Redis.Type { + case "list": // redis list + opts = append(opts, xtraffic.RedisLoaderOption(loader.RedisListLoader( + cfg.Redis.Addr, + loader.DBRedisLoaderOption(cfg.Redis.DB), + loader.PasswordRedisLoaderOption(cfg.Redis.Password), + loader.KeyRedisLoaderOption(cfg.Redis.Key), + ))) + default: // redis set + opts = append(opts, xtraffic.RedisLoaderOption(loader.RedisSetLoader( + cfg.Redis.Addr, + loader.DBRedisLoaderOption(cfg.Redis.DB), + loader.PasswordRedisLoaderOption(cfg.Redis.Password), + loader.KeyRedisLoaderOption(cfg.Redis.Key), + ))) + } + } + if cfg.HTTP != nil && cfg.HTTP.URL != "" { + opts = append(opts, xtraffic.HTTPLoaderOption(loader.HTTPLoader( + cfg.HTTP.URL, + loader.TimeoutHTTPLoaderOption(cfg.HTTP.Timeout), + ))) + } + opts = append(opts, + xtraffic.LimitsOption(cfg.Limits...), + xtraffic.ReloadPeriodOption(cfg.Reload), + xtraffic.LoggerOption(logger.Default().WithFields(map[string]any{ + "kind": "limiter", + "limiter": cfg.Name, + })), + ) + + return xtraffic.NewTrafficLimiter(opts...) +} + +func ParseConnLimiter(cfg *config.LimiterConfig) (lim conn.ConnLimiter) { + if cfg == nil { + return nil + } + + var opts []xconn.Option + + if cfg.File != nil && cfg.File.Path != "" { + opts = append(opts, xconn.FileLoaderOption(loader.FileLoader(cfg.File.Path))) + } + if cfg.Redis != nil && cfg.Redis.Addr != "" { + switch cfg.Redis.Type { + case "list": // redis list + opts = append(opts, xconn.RedisLoaderOption(loader.RedisListLoader( + cfg.Redis.Addr, + loader.DBRedisLoaderOption(cfg.Redis.DB), + loader.PasswordRedisLoaderOption(cfg.Redis.Password), + loader.KeyRedisLoaderOption(cfg.Redis.Key), + ))) + default: // redis set + opts = append(opts, xconn.RedisLoaderOption(loader.RedisSetLoader( + cfg.Redis.Addr, + loader.DBRedisLoaderOption(cfg.Redis.DB), + loader.PasswordRedisLoaderOption(cfg.Redis.Password), + loader.KeyRedisLoaderOption(cfg.Redis.Key), + ))) + } + } + if cfg.HTTP != nil && cfg.HTTP.URL != "" { + opts = append(opts, xconn.HTTPLoaderOption(loader.HTTPLoader( + cfg.HTTP.URL, + loader.TimeoutHTTPLoaderOption(cfg.HTTP.Timeout), + ))) + } + opts = append(opts, + xconn.LimitsOption(cfg.Limits...), + xconn.ReloadPeriodOption(cfg.Reload), + xconn.LoggerOption(logger.Default().WithFields(map[string]any{ + "kind": "limiter", + "limiter": cfg.Name, + })), + ) + + return xconn.NewConnLimiter(opts...) +} + +func ParseRateLimiter(cfg *config.LimiterConfig) (lim rate.RateLimiter) { + if cfg == nil { + return nil + } + + var opts []xrate.Option + + if cfg.File != nil && cfg.File.Path != "" { + opts = append(opts, xrate.FileLoaderOption(loader.FileLoader(cfg.File.Path))) + } + if cfg.Redis != nil && cfg.Redis.Addr != "" { + switch cfg.Redis.Type { + case "list": // redis list + opts = append(opts, xrate.RedisLoaderOption(loader.RedisListLoader( + cfg.Redis.Addr, + loader.DBRedisLoaderOption(cfg.Redis.DB), + loader.PasswordRedisLoaderOption(cfg.Redis.Password), + loader.KeyRedisLoaderOption(cfg.Redis.Key), + ))) + default: // redis set + opts = append(opts, xrate.RedisLoaderOption(loader.RedisSetLoader( + cfg.Redis.Addr, + loader.DBRedisLoaderOption(cfg.Redis.DB), + loader.PasswordRedisLoaderOption(cfg.Redis.Password), + loader.KeyRedisLoaderOption(cfg.Redis.Key), + ))) + } + } + if cfg.HTTP != nil && cfg.HTTP.URL != "" { + opts = append(opts, xrate.HTTPLoaderOption(loader.HTTPLoader( + cfg.HTTP.URL, + loader.TimeoutHTTPLoaderOption(cfg.HTTP.Timeout), + ))) + } + opts = append(opts, + xrate.LimitsOption(cfg.Limits...), + xrate.ReloadPeriodOption(cfg.Reload), + xrate.LoggerOption(logger.Default().WithFields(map[string]any{ + "kind": "limiter", + "limiter": cfg.Name, + })), + ) + + return xrate.NewRateLimiter(opts...) +} diff --git a/config/parsing/node/parse.go b/config/parsing/node/parse.go new file mode 100644 index 0000000..15a0cb3 --- /dev/null +++ b/config/parsing/node/parse.go @@ -0,0 +1,198 @@ +package node + +import ( + "fmt" + "net" + "strings" + "time" + + "github.com/go-gost/core/bypass" + "github.com/go-gost/core/chain" + "github.com/go-gost/core/connector" + "github.com/go-gost/core/dialer" + "github.com/go-gost/core/logger" + "github.com/go-gost/core/metadata" + mdutil "github.com/go-gost/core/metadata/util" + xauth "github.com/go-gost/x/auth" + "github.com/go-gost/x/config" + "github.com/go-gost/x/config/parsing" + auth_parser "github.com/go-gost/x/config/parsing/auth" + bypass_parser "github.com/go-gost/x/config/parsing/bypass" + tls_util "github.com/go-gost/x/internal/util/tls" + mdx "github.com/go-gost/x/metadata" + "github.com/go-gost/x/registry" +) + +func ParseNode(hop string, cfg *config.NodeConfig) (*chain.Node, error) { + if cfg == nil { + return nil, nil + } + + if cfg.Connector == nil { + cfg.Connector = &config.ConnectorConfig{ + Type: "http", + } + } + + if cfg.Dialer == nil { + cfg.Dialer = &config.DialerConfig{ + Type: "tcp", + } + } + + nodeLogger := logger.Default().WithFields(map[string]any{ + "hop": hop, + "kind": "node", + "node": cfg.Name, + "connector": cfg.Connector.Type, + "dialer": cfg.Dialer.Type, + }) + + serverName, _, _ := net.SplitHostPort(cfg.Addr) + + tlsCfg := cfg.Connector.TLS + if tlsCfg == nil { + tlsCfg = &config.TLSConfig{} + } + if tlsCfg.ServerName == "" { + tlsCfg.ServerName = serverName + } + tlsConfig, err := tls_util.LoadClientConfig( + tlsCfg.CertFile, tlsCfg.KeyFile, tlsCfg.CAFile, + tlsCfg.Secure, tlsCfg.ServerName) + if err != nil { + nodeLogger.Error(err) + return nil, err + } + + var nm metadata.Metadata + if cfg.Metadata != nil { + nm = mdx.NewMetadata(cfg.Metadata) + } + + connectorLogger := nodeLogger.WithFields(map[string]any{ + "kind": "connector", + }) + var cr connector.Connector + if rf := registry.ConnectorRegistry().Get(cfg.Connector.Type); rf != nil { + cr = rf( + connector.AuthOption(auth_parser.Info(cfg.Connector.Auth)), + connector.TLSConfigOption(tlsConfig), + connector.LoggerOption(connectorLogger), + ) + } else { + return nil, fmt.Errorf("unregistered connector: %s", cfg.Connector.Type) + } + + if cfg.Connector.Metadata == nil { + cfg.Connector.Metadata = make(map[string]any) + } + if err := cr.Init(mdx.NewMetadata(cfg.Connector.Metadata)); err != nil { + connectorLogger.Error("init: ", err) + return nil, err + } + + tlsCfg = cfg.Dialer.TLS + if tlsCfg == nil { + tlsCfg = &config.TLSConfig{} + } + if tlsCfg.ServerName == "" { + tlsCfg.ServerName = serverName + } + tlsConfig, err = tls_util.LoadClientConfig( + tlsCfg.CertFile, tlsCfg.KeyFile, tlsCfg.CAFile, + tlsCfg.Secure, tlsCfg.ServerName) + if err != nil { + nodeLogger.Error(err) + return nil, err + } + + var ppv int + if nm != nil { + ppv = mdutil.GetInt(nm, parsing.MDKeyProxyProtocol) + } + + dialerLogger := nodeLogger.WithFields(map[string]any{ + "kind": "dialer", + }) + + var d dialer.Dialer + if rf := registry.DialerRegistry().Get(cfg.Dialer.Type); rf != nil { + d = rf( + dialer.AuthOption(auth_parser.Info(cfg.Dialer.Auth)), + dialer.TLSConfigOption(tlsConfig), + dialer.LoggerOption(dialerLogger), + dialer.ProxyProtocolOption(ppv), + ) + } else { + return nil, fmt.Errorf("unregistered dialer: %s", cfg.Dialer.Type) + } + + if cfg.Dialer.Metadata == nil { + cfg.Dialer.Metadata = make(map[string]any) + } + if err := d.Init(mdx.NewMetadata(cfg.Dialer.Metadata)); err != nil { + dialerLogger.Error("init: ", err) + return nil, err + } + + var sockOpts *chain.SockOpts + if cfg.SockOpts != nil { + sockOpts = &chain.SockOpts{ + Mark: cfg.SockOpts.Mark, + } + } + + tr := chain.NewTransport(d, cr, + chain.AddrTransportOption(cfg.Addr), + chain.InterfaceTransportOption(cfg.Interface), + chain.SockOptsTransportOption(sockOpts), + chain.TimeoutTransportOption(10*time.Second), + ) + + // convert *.example.com to .example.com + // convert *example.com to example.com + host := cfg.Host + if strings.HasPrefix(host, "*") { + host = host[1:] + if !strings.HasPrefix(host, ".") { + host = "." + host + } + } + + opts := []chain.NodeOption{ + chain.TransportNodeOption(tr), + chain.BypassNodeOption(bypass.BypassGroup(bypass_parser.List(cfg.Bypass, cfg.Bypasses...)...)), + chain.ResoloverNodeOption(registry.ResolverRegistry().Get(cfg.Resolver)), + chain.HostMapperNodeOption(registry.HostsRegistry().Get(cfg.Hosts)), + chain.MetadataNodeOption(nm), + chain.HostNodeOption(host), + chain.ProtocolNodeOption(cfg.Protocol), + } + if cfg.HTTP != nil { + opts = append(opts, chain.HTTPNodeOption(&chain.HTTPNodeSettings{ + Host: cfg.HTTP.Host, + Header: cfg.HTTP.Header, + })) + } + if cfg.TLS != nil { + opts = append(opts, chain.TLSNodeOption(&chain.TLSNodeSettings{ + ServerName: cfg.TLS.ServerName, + Secure: cfg.TLS.Secure, + })) + } + if cfg.Auth != nil { + opts = append(opts, chain.AutherNodeOption( + xauth.NewAuthenticator( + xauth.AuthsOption(map[string]string{cfg.Auth.Username: cfg.Auth.Password}), + xauth.LoggerOption(logger.Default().WithFields(map[string]any{ + "kind": "node", + "node": cfg.Name, + "addr": cfg.Addr, + "host": cfg.Host, + "protocol": cfg.Protocol, + })), + ))) + } + return chain.NewNode(cfg.Name, cfg.Addr, opts...), nil +} diff --git a/config/parsing/parse.go b/config/parsing/parse.go index 77a4ef5..624426e 100644 --- a/config/parsing/parse.go +++ b/config/parsing/parse.go @@ -1,820 +1,17 @@ package parsing -import ( - "context" - "crypto/tls" - "net" - "net/http" - "net/url" - "strings" - - "github.com/go-gost/core/admission" - "github.com/go-gost/core/auth" - "github.com/go-gost/core/bypass" - "github.com/go-gost/core/chain" - "github.com/go-gost/core/hosts" - "github.com/go-gost/core/ingress" - "github.com/go-gost/core/limiter/conn" - "github.com/go-gost/core/limiter/rate" - "github.com/go-gost/core/limiter/traffic" - "github.com/go-gost/core/logger" - "github.com/go-gost/core/recorder" - "github.com/go-gost/core/resolver" - "github.com/go-gost/core/selector" - admission_impl "github.com/go-gost/x/admission" - auth_impl "github.com/go-gost/x/auth" - bypass_impl "github.com/go-gost/x/bypass" - "github.com/go-gost/x/config" - xhosts "github.com/go-gost/x/hosts" - xingress "github.com/go-gost/x/ingress" - "github.com/go-gost/x/internal/loader" - "github.com/go-gost/x/internal/util/plugin" - xconn "github.com/go-gost/x/limiter/conn" - xrate "github.com/go-gost/x/limiter/rate" - xtraffic "github.com/go-gost/x/limiter/traffic" - xrecorder "github.com/go-gost/x/recorder" - "github.com/go-gost/x/registry" - resolver_impl "github.com/go-gost/x/resolver" - xs "github.com/go-gost/x/selector" - "google.golang.org/grpc" - "google.golang.org/grpc/backoff" - "google.golang.org/grpc/credentials" - "google.golang.org/grpc/credentials/insecure" -) - const ( - mdKeyProxyProtocol = "proxyProtocol" - mdKeyInterface = "interface" - mdKeySoMark = "so_mark" - mdKeyHash = "hash" - mdKeyPreUp = "preUp" - mdKeyPreDown = "preDown" - mdKeyPostUp = "postUp" - mdKeyPostDown = "postDown" - mdKeyIgnoreChain = "ignoreChain" + MDKeyProxyProtocol = "proxyProtocol" + MDKeyInterface = "interface" + MDKeySoMark = "so_mark" + MDKeyHash = "hash" + MDKeyPreUp = "preUp" + MDKeyPreDown = "preDown" + MDKeyPostUp = "postUp" + MDKeyPostDown = "postDown" + MDKeyIgnoreChain = "ignoreChain" - mdKeyRecorderDirection = "direction" - mdKeyRecorderTimestampFormat = "timeStampFormat" - mdKeyRecorderHexdump = "hexdump" + MDKeyRecorderDirection = "direction" + MDKeyRecorderTimestampFormat = "timeStampFormat" + MDKeyRecorderHexdump = "hexdump" ) - -func ParseAuther(cfg *config.AutherConfig) auth.Authenticator { - if cfg == nil { - return nil - } - - if cfg.Plugin != nil { - var tlsCfg *tls.Config - if cfg.Plugin.TLS != nil { - tlsCfg = &tls.Config{ - ServerName: cfg.Plugin.TLS.ServerName, - InsecureSkipVerify: !cfg.Plugin.TLS.Secure, - } - } - switch cfg.Plugin.Type { - case "http": - return auth_impl.NewHTTPPluginAuthenticator( - cfg.Name, cfg.Plugin.Addr, - plugin.TLSConfigOption(tlsCfg), - plugin.TimeoutOption(cfg.Plugin.Timeout), - ) - default: - return auth_impl.NewGRPCPluginAuthenticator( - cfg.Name, cfg.Plugin.Addr, - plugin.TokenOption(cfg.Plugin.Token), - plugin.TLSConfigOption(tlsCfg), - ) - } - } - - m := make(map[string]string) - - for _, user := range cfg.Auths { - if user.Username == "" { - continue - } - m[user.Username] = user.Password - } - - opts := []auth_impl.Option{ - auth_impl.AuthsOption(m), - auth_impl.ReloadPeriodOption(cfg.Reload), - auth_impl.LoggerOption(logger.Default().WithFields(map[string]any{ - "kind": "auther", - "auther": cfg.Name, - })), - } - if cfg.File != nil && cfg.File.Path != "" { - opts = append(opts, auth_impl.FileLoaderOption(loader.FileLoader(cfg.File.Path))) - } - if cfg.Redis != nil && cfg.Redis.Addr != "" { - opts = append(opts, auth_impl.RedisLoaderOption(loader.RedisHashLoader( - cfg.Redis.Addr, - loader.DBRedisLoaderOption(cfg.Redis.DB), - loader.PasswordRedisLoaderOption(cfg.Redis.Password), - loader.KeyRedisLoaderOption(cfg.Redis.Key), - ))) - } - if cfg.HTTP != nil && cfg.HTTP.URL != "" { - opts = append(opts, auth_impl.HTTPLoaderOption(loader.HTTPLoader( - cfg.HTTP.URL, - loader.TimeoutHTTPLoaderOption(cfg.HTTP.Timeout), - ))) - } - return auth_impl.NewAuthenticator(opts...) -} - -func ParseAutherFromAuth(au *config.AuthConfig) auth.Authenticator { - if au == nil || au.Username == "" { - return nil - } - return auth_impl.NewAuthenticator( - auth_impl.AuthsOption( - map[string]string{ - au.Username: au.Password, - }, - ), - auth_impl.LoggerOption(logger.Default().WithFields(map[string]any{ - "kind": "auther", - })), - ) -} - -func parseAuth(cfg *config.AuthConfig) *url.Userinfo { - if cfg == nil || cfg.Username == "" { - return nil - } - - if cfg.Password == "" { - return url.User(cfg.Username) - } - return url.UserPassword(cfg.Username, cfg.Password) -} - -func parseChainSelector(cfg *config.SelectorConfig) selector.Selector[chain.Chainer] { - if cfg == nil { - return nil - } - - var strategy selector.Strategy[chain.Chainer] - switch cfg.Strategy { - case "round", "rr": - strategy = xs.RoundRobinStrategy[chain.Chainer]() - case "random", "rand": - strategy = xs.RandomStrategy[chain.Chainer]() - case "fifo", "ha": - strategy = xs.FIFOStrategy[chain.Chainer]() - case "hash": - strategy = xs.HashStrategy[chain.Chainer]() - default: - strategy = xs.RoundRobinStrategy[chain.Chainer]() - } - return xs.NewSelector( - strategy, - xs.FailFilter[chain.Chainer](cfg.MaxFails, cfg.FailTimeout), - xs.BackupFilter[chain.Chainer](), - ) -} - -func parseNodeSelector(cfg *config.SelectorConfig) selector.Selector[*chain.Node] { - if cfg == nil { - return nil - } - - var strategy selector.Strategy[*chain.Node] - switch cfg.Strategy { - case "round", "rr": - strategy = xs.RoundRobinStrategy[*chain.Node]() - case "random", "rand": - strategy = xs.RandomStrategy[*chain.Node]() - case "fifo", "ha": - strategy = xs.FIFOStrategy[*chain.Node]() - case "hash": - strategy = xs.HashStrategy[*chain.Node]() - default: - strategy = xs.RoundRobinStrategy[*chain.Node]() - } - - return xs.NewSelector( - strategy, - xs.FailFilter[*chain.Node](cfg.MaxFails, cfg.FailTimeout), - xs.BackupFilter[*chain.Node](), - ) -} - -func ParseAdmission(cfg *config.AdmissionConfig) admission.Admission { - if cfg == nil { - return nil - } - - if cfg.Plugin != nil { - var tlsCfg *tls.Config - if cfg.Plugin.TLS != nil { - tlsCfg = &tls.Config{ - ServerName: cfg.Plugin.TLS.ServerName, - InsecureSkipVerify: !cfg.Plugin.TLS.Secure, - } - } - switch strings.ToLower(cfg.Plugin.Type) { - case "http": - return admission_impl.NewHTTPPluginAdmission( - cfg.Name, cfg.Plugin.Addr, - plugin.TLSConfigOption(tlsCfg), - plugin.TimeoutOption(cfg.Plugin.Timeout), - ) - default: - return admission_impl.NewGRPCPluginAdmission( - cfg.Name, cfg.Plugin.Addr, - plugin.TokenOption(cfg.Plugin.Token), - plugin.TLSConfigOption(tlsCfg), - ) - } - } - - opts := []admission_impl.Option{ - admission_impl.MatchersOption(cfg.Matchers), - admission_impl.WhitelistOption(cfg.Reverse || cfg.Whitelist), - admission_impl.ReloadPeriodOption(cfg.Reload), - admission_impl.LoggerOption(logger.Default().WithFields(map[string]any{ - "kind": "admission", - "admission": cfg.Name, - })), - } - if cfg.File != nil && cfg.File.Path != "" { - opts = append(opts, admission_impl.FileLoaderOption(loader.FileLoader(cfg.File.Path))) - } - if cfg.Redis != nil && cfg.Redis.Addr != "" { - opts = append(opts, admission_impl.RedisLoaderOption(loader.RedisSetLoader( - cfg.Redis.Addr, - loader.DBRedisLoaderOption(cfg.Redis.DB), - loader.PasswordRedisLoaderOption(cfg.Redis.Password), - loader.KeyRedisLoaderOption(cfg.Redis.Key), - ))) - } - if cfg.HTTP != nil && cfg.HTTP.URL != "" { - opts = append(opts, admission_impl.HTTPLoaderOption(loader.HTTPLoader( - cfg.HTTP.URL, - loader.TimeoutHTTPLoaderOption(cfg.HTTP.Timeout), - ))) - } - - return admission_impl.NewAdmission(opts...) -} - -func ParseBypass(cfg *config.BypassConfig) bypass.Bypass { - if cfg == nil { - return nil - } - - if cfg.Plugin != nil { - var tlsCfg *tls.Config - if cfg.Plugin.TLS != nil { - tlsCfg = &tls.Config{ - ServerName: cfg.Plugin.TLS.ServerName, - InsecureSkipVerify: !cfg.Plugin.TLS.Secure, - } - } - switch strings.ToLower(cfg.Plugin.Type) { - case "http": - return bypass_impl.NewHTTPPluginBypass( - cfg.Name, cfg.Plugin.Addr, - plugin.TLSConfigOption(tlsCfg), - plugin.TimeoutOption(cfg.Plugin.Timeout), - ) - default: - return bypass_impl.NewGRPCPluginBypass( - cfg.Name, cfg.Plugin.Addr, - plugin.TokenOption(cfg.Plugin.Token), - plugin.TLSConfigOption(tlsCfg), - ) - } - } - - opts := []bypass_impl.Option{ - bypass_impl.MatchersOption(cfg.Matchers), - bypass_impl.WhitelistOption(cfg.Reverse || cfg.Whitelist), - bypass_impl.ReloadPeriodOption(cfg.Reload), - bypass_impl.LoggerOption(logger.Default().WithFields(map[string]any{ - "kind": "bypass", - "bypass": cfg.Name, - })), - } - if cfg.File != nil && cfg.File.Path != "" { - opts = append(opts, bypass_impl.FileLoaderOption(loader.FileLoader(cfg.File.Path))) - } - if cfg.Redis != nil && cfg.Redis.Addr != "" { - opts = append(opts, bypass_impl.RedisLoaderOption(loader.RedisSetLoader( - cfg.Redis.Addr, - loader.DBRedisLoaderOption(cfg.Redis.DB), - loader.PasswordRedisLoaderOption(cfg.Redis.Password), - loader.KeyRedisLoaderOption(cfg.Redis.Key), - ))) - } - if cfg.HTTP != nil && cfg.HTTP.URL != "" { - opts = append(opts, bypass_impl.HTTPLoaderOption(loader.HTTPLoader( - cfg.HTTP.URL, - loader.TimeoutHTTPLoaderOption(cfg.HTTP.Timeout), - ))) - } - - return bypass_impl.NewBypass(opts...) -} - -func ParseResolver(cfg *config.ResolverConfig) (resolver.Resolver, error) { - if cfg == nil { - return nil, nil - } - - if cfg.Plugin != nil { - var tlsCfg *tls.Config - if cfg.Plugin.TLS != nil { - tlsCfg = &tls.Config{ - ServerName: cfg.Plugin.TLS.ServerName, - InsecureSkipVerify: !cfg.Plugin.TLS.Secure, - } - } - switch strings.ToLower(cfg.Plugin.Type) { - case "http": - return resolver_impl.NewHTTPPluginResolver( - cfg.Name, cfg.Plugin.Addr, - plugin.TLSConfigOption(tlsCfg), - plugin.TimeoutOption(cfg.Plugin.Timeout), - ), nil - default: - return resolver_impl.NewGRPCPluginResolver( - cfg.Name, cfg.Plugin.Addr, - plugin.TokenOption(cfg.Plugin.Token), - plugin.TLSConfigOption(tlsCfg), - ) - } - } - - var nameservers []resolver_impl.NameServer - for _, server := range cfg.Nameservers { - nameservers = append(nameservers, resolver_impl.NameServer{ - Addr: server.Addr, - Chain: registry.ChainRegistry().Get(server.Chain), - TTL: server.TTL, - Timeout: server.Timeout, - ClientIP: net.ParseIP(server.ClientIP), - Prefer: server.Prefer, - Hostname: server.Hostname, - }) - } - - return resolver_impl.NewResolver( - nameservers, - resolver_impl.LoggerOption( - logger.Default().WithFields(map[string]any{ - "kind": "resolver", - "resolver": cfg.Name, - }), - ), - ) -} - -func ParseHosts(cfg *config.HostsConfig) hosts.HostMapper { - if cfg == nil { - return nil - } - - if cfg.Plugin != nil { - var tlsCfg *tls.Config - if cfg.Plugin.TLS != nil { - tlsCfg = &tls.Config{ - ServerName: cfg.Plugin.TLS.ServerName, - InsecureSkipVerify: !cfg.Plugin.TLS.Secure, - } - } - switch strings.ToLower(cfg.Plugin.Type) { - case "http": - return xhosts.NewHTTPPluginHostMapper( - cfg.Name, cfg.Plugin.Addr, - plugin.TLSConfigOption(tlsCfg), - plugin.TimeoutOption(cfg.Plugin.Timeout), - ) - default: - return xhosts.NewGRPCPluginHostMapper( - cfg.Name, cfg.Plugin.Addr, - plugin.TokenOption(cfg.Plugin.Token), - plugin.TLSConfigOption(tlsCfg), - ) - } - } - - var mappings []xhosts.Mapping - for _, mapping := range cfg.Mappings { - if mapping.IP == "" || mapping.Hostname == "" { - continue - } - - ip := net.ParseIP(mapping.IP) - if ip == nil { - continue - } - mappings = append(mappings, xhosts.Mapping{ - Hostname: mapping.Hostname, - IP: ip, - }) - } - opts := []xhosts.Option{ - xhosts.MappingsOption(mappings), - xhosts.ReloadPeriodOption(cfg.Reload), - xhosts.LoggerOption(logger.Default().WithFields(map[string]any{ - "kind": "hosts", - "hosts": cfg.Name, - })), - } - if cfg.File != nil && cfg.File.Path != "" { - opts = append(opts, xhosts.FileLoaderOption(loader.FileLoader(cfg.File.Path))) - } - if cfg.Redis != nil && cfg.Redis.Addr != "" { - switch cfg.Redis.Type { - case "list": // redis list - opts = append(opts, xhosts.RedisLoaderOption(loader.RedisListLoader( - cfg.Redis.Addr, - loader.DBRedisLoaderOption(cfg.Redis.DB), - loader.PasswordRedisLoaderOption(cfg.Redis.Password), - loader.KeyRedisLoaderOption(cfg.Redis.Key), - ))) - default: // redis set - opts = append(opts, xhosts.RedisLoaderOption(loader.RedisSetLoader( - cfg.Redis.Addr, - loader.DBRedisLoaderOption(cfg.Redis.DB), - loader.PasswordRedisLoaderOption(cfg.Redis.Password), - loader.KeyRedisLoaderOption(cfg.Redis.Key), - ))) - } - } - if cfg.HTTP != nil && cfg.HTTP.URL != "" { - opts = append(opts, xhosts.HTTPLoaderOption(loader.HTTPLoader( - cfg.HTTP.URL, - loader.TimeoutHTTPLoaderOption(cfg.HTTP.Timeout), - ))) - } - return xhosts.NewHostMapper(opts...) -} - -func ParseIngress(cfg *config.IngressConfig) ingress.Ingress { - if cfg == nil { - return nil - } - - if cfg.Plugin != nil { - var tlsCfg *tls.Config - if cfg.Plugin.TLS != nil { - tlsCfg = &tls.Config{ - ServerName: cfg.Plugin.TLS.ServerName, - InsecureSkipVerify: !cfg.Plugin.TLS.Secure, - } - } - switch strings.ToLower(cfg.Plugin.Type) { - case "http": - return xingress.NewHTTPPluginIngress( - cfg.Name, cfg.Plugin.Addr, - plugin.TLSConfigOption(tlsCfg), - plugin.TimeoutOption(cfg.Plugin.Timeout), - ) - default: - return xingress.NewGRPCPluginIngress( - cfg.Name, cfg.Plugin.Addr, - plugin.TokenOption(cfg.Plugin.Token), - plugin.TLSConfigOption(tlsCfg), - ) - } - } - - var rules []xingress.Rule - for _, rule := range cfg.Rules { - if rule.Hostname == "" || rule.Endpoint == "" { - continue - } - - rules = append(rules, xingress.Rule{ - Hostname: rule.Hostname, - Endpoint: rule.Endpoint, - }) - } - opts := []xingress.Option{ - xingress.RulesOption(rules), - xingress.ReloadPeriodOption(cfg.Reload), - xingress.LoggerOption(logger.Default().WithFields(map[string]any{ - "kind": "ingress", - "ingress": cfg.Name, - })), - } - if cfg.File != nil && cfg.File.Path != "" { - opts = append(opts, xingress.FileLoaderOption(loader.FileLoader(cfg.File.Path))) - } - if cfg.Redis != nil && cfg.Redis.Addr != "" { - switch cfg.Redis.Type { - case "set": // redis set - opts = append(opts, xingress.RedisLoaderOption(loader.RedisSetLoader( - cfg.Redis.Addr, - loader.DBRedisLoaderOption(cfg.Redis.DB), - loader.PasswordRedisLoaderOption(cfg.Redis.Password), - loader.KeyRedisLoaderOption(cfg.Redis.Key), - ))) - default: // redis hash - opts = append(opts, xingress.RedisLoaderOption(loader.RedisHashLoader( - cfg.Redis.Addr, - loader.DBRedisLoaderOption(cfg.Redis.DB), - loader.PasswordRedisLoaderOption(cfg.Redis.Password), - loader.KeyRedisLoaderOption(cfg.Redis.Key), - ))) - } - } - if cfg.HTTP != nil && cfg.HTTP.URL != "" { - opts = append(opts, xingress.HTTPLoaderOption(loader.HTTPLoader( - cfg.HTTP.URL, - loader.TimeoutHTTPLoaderOption(cfg.HTTP.Timeout), - ))) - } - return xingress.NewIngress(opts...) -} - -func ParseRecorder(cfg *config.RecorderConfig) (r recorder.Recorder) { - if cfg == nil { - return nil - } - - if cfg.Plugin != nil { - var tlsCfg *tls.Config - if cfg.Plugin.TLS != nil { - tlsCfg = &tls.Config{ - ServerName: cfg.Plugin.TLS.ServerName, - InsecureSkipVerify: !cfg.Plugin.TLS.Secure, - } - } - switch strings.ToLower(cfg.Plugin.Type) { - case "http": - return xrecorder.NewHTTPPluginRecorder( - cfg.Name, cfg.Plugin.Addr, - plugin.TLSConfigOption(tlsCfg), - plugin.TimeoutOption(cfg.Plugin.Timeout), - ) - default: - return xrecorder.NewGRPCPluginRecorder( - cfg.Name, cfg.Plugin.Addr, - plugin.TokenOption(cfg.Plugin.Token), - plugin.TLSConfigOption(tlsCfg), - ) - } - } - - if cfg.File != nil && cfg.File.Path != "" { - return xrecorder.FileRecorder(cfg.File.Path, - xrecorder.SepRecorderOption(cfg.File.Sep), - ) - } - - if cfg.TCP != nil && cfg.TCP.Addr != "" { - return xrecorder.TCPRecorder(cfg.TCP.Addr, xrecorder.TimeoutTCPRecorderOption(cfg.TCP.Timeout)) - } - - if cfg.HTTP != nil && cfg.HTTP.URL != "" { - return xrecorder.HTTPRecorder(cfg.HTTP.URL, xrecorder.TimeoutHTTPRecorderOption(cfg.HTTP.Timeout)) - } - - if cfg.Redis != nil && - cfg.Redis.Addr != "" && - cfg.Redis.Key != "" { - switch cfg.Redis.Type { - case "list": // redis list - return xrecorder.RedisListRecorder(cfg.Redis.Addr, - xrecorder.DBRedisRecorderOption(cfg.Redis.DB), - xrecorder.KeyRedisRecorderOption(cfg.Redis.Key), - xrecorder.PasswordRedisRecorderOption(cfg.Redis.Password), - ) - case "sset": // sorted set - return xrecorder.RedisSortedSetRecorder(cfg.Redis.Addr, - xrecorder.DBRedisRecorderOption(cfg.Redis.DB), - xrecorder.KeyRedisRecorderOption(cfg.Redis.Key), - xrecorder.PasswordRedisRecorderOption(cfg.Redis.Password), - ) - default: // redis set - return xrecorder.RedisSetRecorder(cfg.Redis.Addr, - xrecorder.DBRedisRecorderOption(cfg.Redis.DB), - xrecorder.KeyRedisRecorderOption(cfg.Redis.Key), - xrecorder.PasswordRedisRecorderOption(cfg.Redis.Password), - ) - } - } - - return -} - -func defaultNodeSelector() selector.Selector[*chain.Node] { - return xs.NewSelector( - xs.RoundRobinStrategy[*chain.Node](), - xs.FailFilter[*chain.Node](xs.DefaultMaxFails, xs.DefaultFailTimeout), - xs.BackupFilter[*chain.Node](), - ) -} - -func defaultChainSelector() selector.Selector[chain.Chainer] { - return xs.NewSelector( - xs.RoundRobinStrategy[chain.Chainer](), - xs.FailFilter[chain.Chainer](xs.DefaultMaxFails, xs.DefaultFailTimeout), - xs.BackupFilter[chain.Chainer](), - ) -} - -func ParseTrafficLimiter(cfg *config.LimiterConfig) (lim traffic.TrafficLimiter) { - if cfg == nil { - return nil - } - - var opts []xtraffic.Option - - if cfg.File != nil && cfg.File.Path != "" { - opts = append(opts, xtraffic.FileLoaderOption(loader.FileLoader(cfg.File.Path))) - } - if cfg.Redis != nil && cfg.Redis.Addr != "" { - switch cfg.Redis.Type { - case "list": // redis list - opts = append(opts, xtraffic.RedisLoaderOption(loader.RedisListLoader( - cfg.Redis.Addr, - loader.DBRedisLoaderOption(cfg.Redis.DB), - loader.PasswordRedisLoaderOption(cfg.Redis.Password), - loader.KeyRedisLoaderOption(cfg.Redis.Key), - ))) - default: // redis set - opts = append(opts, xtraffic.RedisLoaderOption(loader.RedisSetLoader( - cfg.Redis.Addr, - loader.DBRedisLoaderOption(cfg.Redis.DB), - loader.PasswordRedisLoaderOption(cfg.Redis.Password), - loader.KeyRedisLoaderOption(cfg.Redis.Key), - ))) - } - } - if cfg.HTTP != nil && cfg.HTTP.URL != "" { - opts = append(opts, xtraffic.HTTPLoaderOption(loader.HTTPLoader( - cfg.HTTP.URL, - loader.TimeoutHTTPLoaderOption(cfg.HTTP.Timeout), - ))) - } - opts = append(opts, - xtraffic.LimitsOption(cfg.Limits...), - xtraffic.ReloadPeriodOption(cfg.Reload), - xtraffic.LoggerOption(logger.Default().WithFields(map[string]any{ - "kind": "limiter", - "limiter": cfg.Name, - })), - ) - - return xtraffic.NewTrafficLimiter(opts...) -} - -func ParseConnLimiter(cfg *config.LimiterConfig) (lim conn.ConnLimiter) { - if cfg == nil { - return nil - } - - var opts []xconn.Option - - if cfg.File != nil && cfg.File.Path != "" { - opts = append(opts, xconn.FileLoaderOption(loader.FileLoader(cfg.File.Path))) - } - if cfg.Redis != nil && cfg.Redis.Addr != "" { - switch cfg.Redis.Type { - case "list": // redis list - opts = append(opts, xconn.RedisLoaderOption(loader.RedisListLoader( - cfg.Redis.Addr, - loader.DBRedisLoaderOption(cfg.Redis.DB), - loader.PasswordRedisLoaderOption(cfg.Redis.Password), - loader.KeyRedisLoaderOption(cfg.Redis.Key), - ))) - default: // redis set - opts = append(opts, xconn.RedisLoaderOption(loader.RedisSetLoader( - cfg.Redis.Addr, - loader.DBRedisLoaderOption(cfg.Redis.DB), - loader.PasswordRedisLoaderOption(cfg.Redis.Password), - loader.KeyRedisLoaderOption(cfg.Redis.Key), - ))) - } - } - if cfg.HTTP != nil && cfg.HTTP.URL != "" { - opts = append(opts, xconn.HTTPLoaderOption(loader.HTTPLoader( - cfg.HTTP.URL, - loader.TimeoutHTTPLoaderOption(cfg.HTTP.Timeout), - ))) - } - opts = append(opts, - xconn.LimitsOption(cfg.Limits...), - xconn.ReloadPeriodOption(cfg.Reload), - xconn.LoggerOption(logger.Default().WithFields(map[string]any{ - "kind": "limiter", - "limiter": cfg.Name, - })), - ) - - return xconn.NewConnLimiter(opts...) -} - -func ParseRateLimiter(cfg *config.LimiterConfig) (lim rate.RateLimiter) { - if cfg == nil { - return nil - } - - var opts []xrate.Option - - if cfg.File != nil && cfg.File.Path != "" { - opts = append(opts, xrate.FileLoaderOption(loader.FileLoader(cfg.File.Path))) - } - if cfg.Redis != nil && cfg.Redis.Addr != "" { - switch cfg.Redis.Type { - case "list": // redis list - opts = append(opts, xrate.RedisLoaderOption(loader.RedisListLoader( - cfg.Redis.Addr, - loader.DBRedisLoaderOption(cfg.Redis.DB), - loader.PasswordRedisLoaderOption(cfg.Redis.Password), - loader.KeyRedisLoaderOption(cfg.Redis.Key), - ))) - default: // redis set - opts = append(opts, xrate.RedisLoaderOption(loader.RedisSetLoader( - cfg.Redis.Addr, - loader.DBRedisLoaderOption(cfg.Redis.DB), - loader.PasswordRedisLoaderOption(cfg.Redis.Password), - loader.KeyRedisLoaderOption(cfg.Redis.Key), - ))) - } - } - if cfg.HTTP != nil && cfg.HTTP.URL != "" { - opts = append(opts, xrate.HTTPLoaderOption(loader.HTTPLoader( - cfg.HTTP.URL, - loader.TimeoutHTTPLoaderOption(cfg.HTTP.Timeout), - ))) - } - opts = append(opts, - xrate.LimitsOption(cfg.Limits...), - xrate.ReloadPeriodOption(cfg.Reload), - xrate.LoggerOption(logger.Default().WithFields(map[string]any{ - "kind": "limiter", - "limiter": cfg.Name, - })), - ) - - return xrate.NewRateLimiter(opts...) -} - -func newGRPCPluginConn(cfg *config.PluginConfig) (*grpc.ClientConn, error) { - grpcOpts := []grpc.DialOption{ - // grpc.WithBlock(), - grpc.WithConnectParams(grpc.ConnectParams{ - Backoff: backoff.DefaultConfig, - }), - grpc.FailOnNonTempDialError(true), - } - if tlsCfg := cfg.TLS; tlsCfg != nil && tlsCfg.Secure { - grpcOpts = append(grpcOpts, - grpc.WithAuthority(tlsCfg.ServerName), - grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{ - ServerName: tlsCfg.ServerName, - InsecureSkipVerify: !tlsCfg.Secure, - }))) - } else { - grpcOpts = append(grpcOpts, grpc.WithTransportCredentials(insecure.NewCredentials())) - } - if cfg.Token != "" { - grpcOpts = append(grpcOpts, grpc.WithPerRPCCredentials(&rpcCredentials{token: cfg.Token})) - } - return grpc.Dial(cfg.Addr, grpcOpts...) -} - -type rpcCredentials struct { - token string -} - -func (c *rpcCredentials) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) { - return map[string]string{ - "token": c.token, - }, nil -} - -func (c *rpcCredentials) RequireTransportSecurity() bool { - return false -} - -func newHTTPPluginClient(cfg *config.PluginConfig) *http.Client { - if cfg == nil { - return nil - } - - tr := &http.Transport{} - if cfg.TLS != nil { - if cfg.TLS.Secure { - tr.TLSClientConfig = &tls.Config{ - ServerName: cfg.TLS.ServerName, - } - } else { - tr.TLSClientConfig = &tls.Config{ - InsecureSkipVerify: true, - } - } - } - return &http.Client{ - Timeout: cfg.Timeout, - Transport: tr, - } -} diff --git a/config/parsing/recorder/parse.go b/config/parsing/recorder/parse.go new file mode 100644 index 0000000..de6d049 --- /dev/null +++ b/config/parsing/recorder/parse.go @@ -0,0 +1,82 @@ +package recorder + +import ( + "crypto/tls" + "strings" + + "github.com/go-gost/core/recorder" + "github.com/go-gost/x/config" + "github.com/go-gost/x/internal/plugin" + xrecorder "github.com/go-gost/x/recorder" +) + +func ParseRecorder(cfg *config.RecorderConfig) (r recorder.Recorder) { + if cfg == nil { + return nil + } + + if cfg.Plugin != nil { + var tlsCfg *tls.Config + if cfg.Plugin.TLS != nil { + tlsCfg = &tls.Config{ + ServerName: cfg.Plugin.TLS.ServerName, + InsecureSkipVerify: !cfg.Plugin.TLS.Secure, + } + } + switch strings.ToLower(cfg.Plugin.Type) { + case "http": + return xrecorder.NewHTTPPlugin( + cfg.Name, cfg.Plugin.Addr, + plugin.TLSConfigOption(tlsCfg), + plugin.TimeoutOption(cfg.Plugin.Timeout), + ) + default: + return xrecorder.NewGRPCPlugin( + cfg.Name, cfg.Plugin.Addr, + plugin.TokenOption(cfg.Plugin.Token), + plugin.TLSConfigOption(tlsCfg), + ) + } + } + + if cfg.File != nil && cfg.File.Path != "" { + return xrecorder.FileRecorder(cfg.File.Path, + xrecorder.SepRecorderOption(cfg.File.Sep), + ) + } + + if cfg.TCP != nil && cfg.TCP.Addr != "" { + return xrecorder.TCPRecorder(cfg.TCP.Addr, xrecorder.TimeoutTCPRecorderOption(cfg.TCP.Timeout)) + } + + if cfg.HTTP != nil && cfg.HTTP.URL != "" { + return xrecorder.HTTPRecorder(cfg.HTTP.URL, xrecorder.TimeoutHTTPRecorderOption(cfg.HTTP.Timeout)) + } + + if cfg.Redis != nil && + cfg.Redis.Addr != "" && + cfg.Redis.Key != "" { + switch cfg.Redis.Type { + case "list": // redis list + return xrecorder.RedisListRecorder(cfg.Redis.Addr, + xrecorder.DBRedisRecorderOption(cfg.Redis.DB), + xrecorder.KeyRedisRecorderOption(cfg.Redis.Key), + xrecorder.PasswordRedisRecorderOption(cfg.Redis.Password), + ) + case "sset": // sorted set + return xrecorder.RedisSortedSetRecorder(cfg.Redis.Addr, + xrecorder.DBRedisRecorderOption(cfg.Redis.DB), + xrecorder.KeyRedisRecorderOption(cfg.Redis.Key), + xrecorder.PasswordRedisRecorderOption(cfg.Redis.Password), + ) + default: // redis set + return xrecorder.RedisSetRecorder(cfg.Redis.Addr, + xrecorder.DBRedisRecorderOption(cfg.Redis.DB), + xrecorder.KeyRedisRecorderOption(cfg.Redis.Key), + xrecorder.PasswordRedisRecorderOption(cfg.Redis.Password), + ) + } + } + + return +} diff --git a/config/parsing/resolver/parse.go b/config/parsing/resolver/parse.go new file mode 100644 index 0000000..3ed3b53 --- /dev/null +++ b/config/parsing/resolver/parse.go @@ -0,0 +1,67 @@ +package resolver + +import ( + "crypto/tls" + "net" + "strings" + + "github.com/go-gost/core/logger" + "github.com/go-gost/core/resolver" + "github.com/go-gost/x/config" + "github.com/go-gost/x/internal/plugin" + "github.com/go-gost/x/registry" + xresolver "github.com/go-gost/x/resolver" +) + +func ParseResolver(cfg *config.ResolverConfig) (resolver.Resolver, error) { + if cfg == nil { + return nil, nil + } + + if cfg.Plugin != nil { + var tlsCfg *tls.Config + if cfg.Plugin.TLS != nil { + tlsCfg = &tls.Config{ + ServerName: cfg.Plugin.TLS.ServerName, + InsecureSkipVerify: !cfg.Plugin.TLS.Secure, + } + } + switch strings.ToLower(cfg.Plugin.Type) { + case "http": + return xresolver.NewHTTPPlugin( + cfg.Name, cfg.Plugin.Addr, + plugin.TLSConfigOption(tlsCfg), + plugin.TimeoutOption(cfg.Plugin.Timeout), + ), nil + default: + return xresolver.NewGRPCPlugin( + cfg.Name, cfg.Plugin.Addr, + plugin.TokenOption(cfg.Plugin.Token), + plugin.TLSConfigOption(tlsCfg), + ) + } + } + + var nameservers []xresolver.NameServer + for _, server := range cfg.Nameservers { + nameservers = append(nameservers, xresolver.NameServer{ + Addr: server.Addr, + Chain: registry.ChainRegistry().Get(server.Chain), + TTL: server.TTL, + Timeout: server.Timeout, + ClientIP: net.ParseIP(server.ClientIP), + Prefer: server.Prefer, + Hostname: server.Hostname, + }) + } + + return xresolver.NewResolver( + nameservers, + xresolver.LoggerOption( + logger.Default().WithFields(map[string]any{ + "kind": "resolver", + "resolver": cfg.Name, + }), + ), + ) +} diff --git a/config/parsing/selector/parse.go b/config/parsing/selector/parse.go new file mode 100644 index 0000000..b017d6f --- /dev/null +++ b/config/parsing/selector/parse.go @@ -0,0 +1,75 @@ +package selector + +import ( + "github.com/go-gost/core/chain" + "github.com/go-gost/core/selector" + "github.com/go-gost/x/config" + xs "github.com/go-gost/x/selector" +) + +func ParseChainSelector(cfg *config.SelectorConfig) selector.Selector[chain.Chainer] { + if cfg == nil { + return nil + } + + var strategy selector.Strategy[chain.Chainer] + switch cfg.Strategy { + case "round", "rr": + strategy = xs.RoundRobinStrategy[chain.Chainer]() + case "random", "rand": + strategy = xs.RandomStrategy[chain.Chainer]() + case "fifo", "ha": + strategy = xs.FIFOStrategy[chain.Chainer]() + case "hash": + strategy = xs.HashStrategy[chain.Chainer]() + default: + strategy = xs.RoundRobinStrategy[chain.Chainer]() + } + return xs.NewSelector( + strategy, + xs.FailFilter[chain.Chainer](cfg.MaxFails, cfg.FailTimeout), + xs.BackupFilter[chain.Chainer](), + ) +} + +func ParseNodeSelector(cfg *config.SelectorConfig) selector.Selector[*chain.Node] { + if cfg == nil { + return nil + } + + var strategy selector.Strategy[*chain.Node] + switch cfg.Strategy { + case "round", "rr": + strategy = xs.RoundRobinStrategy[*chain.Node]() + case "random", "rand": + strategy = xs.RandomStrategy[*chain.Node]() + case "fifo", "ha": + strategy = xs.FIFOStrategy[*chain.Node]() + case "hash": + strategy = xs.HashStrategy[*chain.Node]() + default: + strategy = xs.RoundRobinStrategy[*chain.Node]() + } + + return xs.NewSelector( + strategy, + xs.FailFilter[*chain.Node](cfg.MaxFails, cfg.FailTimeout), + xs.BackupFilter[*chain.Node](), + ) +} + +func DefaultNodeSelector() selector.Selector[*chain.Node] { + return xs.NewSelector( + xs.RoundRobinStrategy[*chain.Node](), + xs.FailFilter[*chain.Node](xs.DefaultMaxFails, xs.DefaultFailTimeout), + xs.BackupFilter[*chain.Node](), + ) +} + +func DefaultChainSelector() selector.Selector[chain.Chainer] { + return xs.NewSelector( + xs.RoundRobinStrategy[chain.Chainer](), + xs.FailFilter[chain.Chainer](xs.DefaultMaxFails, xs.DefaultFailTimeout), + xs.BackupFilter[chain.Chainer](), + ) +} diff --git a/config/parsing/service.go b/config/parsing/service/parse.go similarity index 74% rename from config/parsing/service.go rename to config/parsing/service/parse.go index 18e9ced..ceb4d15 100644 --- a/config/parsing/service.go +++ b/config/parsing/service/parse.go @@ -1,4 +1,5 @@ -package parsing +package service + import ( "fmt" @@ -8,6 +9,7 @@ import ( "github.com/go-gost/core/auth" "github.com/go-gost/core/bypass" "github.com/go-gost/core/chain" + "github.com/go-gost/core/hop" "github.com/go-gost/core/handler" "github.com/go-gost/core/listener" "github.com/go-gost/core/logger" @@ -17,6 +19,12 @@ import ( "github.com/go-gost/core/service" xchain "github.com/go-gost/x/chain" "github.com/go-gost/x/config" + "github.com/go-gost/x/config/parsing" + auth_parser "github.com/go-gost/x/config/parsing/auth" + hop_parser "github.com/go-gost/x/config/parsing/hop" + bypass_parser "github.com/go-gost/x/config/parsing/bypass" + selector_parser "github.com/go-gost/x/config/parsing/selector" + admission_parser "github.com/go-gost/x/config/parsing/admission" tls_util "github.com/go-gost/x/internal/util/tls" "github.com/go-gost/x/metadata" "github.com/go-gost/x/registry" @@ -56,12 +64,12 @@ func ParseService(cfg *config.ServiceConfig) (service.Service, error) { return nil, err } if tlsConfig == nil { - tlsConfig = defaultTLSConfig.Clone() + tlsConfig = parsing.DefaultTLSConfig().Clone() } - authers := autherList(cfg.Listener.Auther, cfg.Listener.Authers...) + authers := auth_parser.List(cfg.Listener.Auther, cfg.Listener.Authers...) if len(authers) == 0 { - if auther := ParseAutherFromAuth(cfg.Listener.Auth); auther != nil { + if auther := auth_parser.ParseAutherFromAuth(cfg.Listener.Auth); auther != nil { authers = append(authers, auther) } } @@ -70,7 +78,7 @@ func ParseService(cfg *config.ServiceConfig) (service.Service, error) { auther = auth.AuthenticatorGroup(authers...) } - admissions := admissionList(cfg.Admission, cfg.Admissions...) + admissions := admission_parser.List(cfg.Admission, cfg.Admissions...) var sockOpts *chain.SockOpts if cfg.SockOpts != nil { @@ -85,26 +93,26 @@ func ParseService(cfg *config.ServiceConfig) (service.Service, error) { var ignoreChain bool if cfg.Metadata != nil { md := metadata.NewMetadata(cfg.Metadata) - ppv = mdutil.GetInt(md, mdKeyProxyProtocol) - if v := mdutil.GetString(md, mdKeyInterface); v != "" { + ppv = mdutil.GetInt(md, parsing.MDKeyProxyProtocol) + if v := mdutil.GetString(md, parsing.MDKeyInterface); v != "" { ifce = v } - if v := mdutil.GetInt(md, mdKeySoMark); v > 0 { + if v := mdutil.GetInt(md, parsing.MDKeySoMark); v > 0 { sockOpts = &chain.SockOpts{ Mark: v, } } - preUp = mdutil.GetStrings(md, mdKeyPreUp) - preDown = mdutil.GetStrings(md, mdKeyPreDown) - postUp = mdutil.GetStrings(md, mdKeyPostUp) - postDown = mdutil.GetStrings(md, mdKeyPostDown) - ignoreChain = mdutil.GetBool(md, mdKeyIgnoreChain) + preUp = mdutil.GetStrings(md, parsing.MDKeyPreUp) + preDown = mdutil.GetStrings(md, parsing.MDKeyPreDown) + postUp = mdutil.GetStrings(md, parsing.MDKeyPostUp) + postDown = mdutil.GetStrings(md, parsing.MDKeyPostDown) + ignoreChain = mdutil.GetBool(md, parsing.MDKeyIgnoreChain) } listenOpts := []listener.Option{ listener.AddrOption(cfg.Addr), listener.AutherOption(auther), - listener.AuthOption(parseAuth(cfg.Listener.Auth)), + listener.AuthOption(auth_parser.Info(cfg.Listener.Auth)), listener.TLSConfigOption(tlsConfig), listener.AdmissionOption(admission.AdmissionGroup(admissions...)), listener.TrafficLimiterOption(registry.TrafficLimiterRegistry().Get(cfg.Limiter)), @@ -150,12 +158,12 @@ func ParseService(cfg *config.ServiceConfig) (service.Service, error) { return nil, err } if tlsConfig == nil { - tlsConfig = defaultTLSConfig.Clone() + tlsConfig = parsing.DefaultTLSConfig().Clone() } - authers = autherList(cfg.Handler.Auther, cfg.Handler.Authers...) + authers = auth_parser.List(cfg.Handler.Auther, cfg.Handler.Authers...) if len(authers) == 0 { - if auther := ParseAutherFromAuth(cfg.Handler.Auth); auther != nil { + if auther := auth_parser.ParseAutherFromAuth(cfg.Handler.Auth); auther != nil { authers = append(authers, auther) } } @@ -172,9 +180,9 @@ func ParseService(cfg *config.ServiceConfig) (service.Service, error) { Recorder: registry.RecorderRegistry().Get(r.Name), Record: r.Record, Options: &recorder.Options{ - Direction: mdutil.GetBool(md, mdKeyRecorderDirection), - TimestampFormat: mdutil.GetString(md, mdKeyRecorderTimestampFormat), - Hexdump: mdutil.GetBool(md, mdKeyRecorderHexdump), + Direction: mdutil.GetBool(md, parsing.MDKeyRecorderDirection), + TimestampFormat: mdutil.GetString(md, parsing.MDKeyRecorderTimestampFormat), + Hexdump: mdutil.GetBool(md, parsing.MDKeyRecorderHexdump), }, }) } @@ -201,8 +209,8 @@ func ParseService(cfg *config.ServiceConfig) (service.Service, error) { h = rf( handler.RouterOption(router), handler.AutherOption(auther), - handler.AuthOption(parseAuth(cfg.Handler.Auth)), - handler.BypassOption(bypass.BypassGroup(bypassList(cfg.Bypass, cfg.Bypasses...)...)), + handler.AuthOption(auth_parser.Info(cfg.Handler.Auth)), + handler.BypassOption(bypass.BypassGroup(bypass_parser.List(cfg.Bypass, cfg.Bypasses...)...)), handler.TLSConfigOption(tlsConfig), handler.RateLimiterOption(registry.RateLimiterRegistry().Get(cfg.RLimiter)), handler.LoggerOption(handlerLogger), @@ -243,7 +251,7 @@ func ParseService(cfg *config.ServiceConfig) (service.Service, error) { return s, nil } -func parseForwarder(cfg *config.ForwarderConfig) (chain.Hop, error) { +func parseForwarder(cfg *config.ForwarderConfig) (hop.Hop, error) { if cfg == nil { return nil, nil } @@ -284,51 +292,11 @@ func parseForwarder(cfg *config.ForwarderConfig) (chain.Hop, error) { } if len(hc.Nodes) > 0 { - return ParseHop(&hc) + return hop_parser.ParseHop(&hc) } return registry.HopRegistry().Get(hc.Name), nil } -func bypassList(name string, names ...string) []bypass.Bypass { - var bypasses []bypass.Bypass - if bp := registry.BypassRegistry().Get(name); bp != nil { - bypasses = append(bypasses, bp) - } - for _, s := range names { - if bp := registry.BypassRegistry().Get(s); bp != nil { - bypasses = append(bypasses, bp) - } - } - return bypasses -} - -func autherList(name string, names ...string) []auth.Authenticator { - var authers []auth.Authenticator - if auther := registry.AutherRegistry().Get(name); auther != nil { - authers = append(authers, auther) - } - for _, s := range names { - if auther := registry.AutherRegistry().Get(s); auther != nil { - authers = append(authers, auther) - } - } - return authers -} - -func admissionList(name string, names ...string) []admission.Admission { - var admissions []admission.Admission - if adm := registry.AdmissionRegistry().Get(name); adm != nil { - admissions = append(admissions, adm) - } - for _, s := range names { - if adm := registry.AdmissionRegistry().Get(s); adm != nil { - admissions = append(admissions, adm) - } - } - - return admissions -} - func chainGroup(name string, group *config.ChainGroupConfig) chain.Chainer { var chains []chain.Chainer var sel selector.Selector[chain.Chainer] @@ -342,14 +310,14 @@ func chainGroup(name string, group *config.ChainGroupConfig) chain.Chainer { chains = append(chains, c) } } - sel = parseChainSelector(group.Selector) + sel = selector_parser.ParseChainSelector(group.Selector) } if len(chains) == 0 { return nil } if sel == nil { - sel = defaultChainSelector() + sel = selector_parser.DefaultChainSelector() } return xchain.NewChainGroup(chains...). diff --git a/config/parsing/tls.go b/config/parsing/tls.go index b78f718..ac6d474 100644 --- a/config/parsing/tls.go +++ b/config/parsing/tls.go @@ -23,6 +23,10 @@ var ( defaultTLSConfig *tls.Config ) +func DefaultTLSConfig() *tls.Config { + return defaultTLSConfig +} + func BuildDefaultTLSConfig(cfg *config.TLSConfig) { log := logger.Default() diff --git a/go.mod b/go.mod index ad63c03..02db719 100644 --- a/go.mod +++ b/go.mod @@ -7,10 +7,10 @@ require ( github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d github.com/gin-contrib/cors v1.3.1 github.com/gin-gonic/gin v1.9.1 - github.com/go-gost/core v0.0.0-20230920145336-6d0e88635be9 + github.com/go-gost/core v0.0.0-20230928130125-b0bd45c1b862 github.com/go-gost/gosocks4 v0.0.1 github.com/go-gost/gosocks5 v0.4.0 - github.com/go-gost/plugin v0.0.0-20230921115816-47001719099f + github.com/go-gost/plugin v0.0.0-20230928130211-8bc0679b5c15 github.com/go-gost/relay v0.4.1-0.20230916134211-828f314ddfe7 github.com/go-gost/tls-dissector v0.0.2-0.20220408131628-aac992c27451 github.com/go-redis/redis/v8 v8.11.5 diff --git a/go.sum b/go.sum index db9d9e6..3e7f574 100644 --- a/go.sum +++ b/go.sum @@ -100,14 +100,14 @@ github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SU github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= -github.com/go-gost/core v0.0.0-20230920145336-6d0e88635be9 h1:VHka8LcdBJmM7Yv2bjQO5kctF0T9O4E/PVzgkdk0Vdo= -github.com/go-gost/core v0.0.0-20230920145336-6d0e88635be9/go.mod h1:ndkgWVYRLwupVaFFWv8ML1Nr8tD3xhHK245PLpUDg4E= +github.com/go-gost/core v0.0.0-20230928130125-b0bd45c1b862 h1:hbCHyfYE96WZefTBitiL35FCYxHCgEWpS+W/5oCyEXk= +github.com/go-gost/core v0.0.0-20230928130125-b0bd45c1b862/go.mod h1:ndkgWVYRLwupVaFFWv8ML1Nr8tD3xhHK245PLpUDg4E= github.com/go-gost/gosocks4 v0.0.1 h1:+k1sec8HlELuQV7rWftIkmy8UijzUt2I6t+iMPlGB2s= github.com/go-gost/gosocks4 v0.0.1/go.mod h1:3B6L47HbU/qugDg4JnoFPHgJXE43Inz8Bah1QaN9qCc= github.com/go-gost/gosocks5 v0.4.0 h1:EIrOEkpJez4gwHrMa33frA+hHXJyevjp47thpMQsJzI= github.com/go-gost/gosocks5 v0.4.0/go.mod h1:1G6I7HP7VFVxveGkoK8mnprnJqSqJjdcASKsdUn4Pp4= -github.com/go-gost/plugin v0.0.0-20230921115816-47001719099f h1:Z6k8xfvQv8PmrC++wV4BlzVv85iuGtHhL6QSzrF6m5Q= -github.com/go-gost/plugin v0.0.0-20230921115816-47001719099f/go.mod h1:mM/RLNsVy2nz5PiOijuqLYR3LhMzyQ9Kh/p0rXybJoo= +github.com/go-gost/plugin v0.0.0-20230928130211-8bc0679b5c15 h1:SKPbGuJUBKhh4qE2G5juT4PNMrzYH86itiY3TGwvYcs= +github.com/go-gost/plugin v0.0.0-20230928130211-8bc0679b5c15/go.mod h1:mM/RLNsVy2nz5PiOijuqLYR3LhMzyQ9Kh/p0rXybJoo= github.com/go-gost/relay v0.4.1-0.20230916134211-828f314ddfe7 h1:qAG1OyjvdA5h221CfFSS3J359V3d2E7dJWyP29QoDSI= github.com/go-gost/relay v0.4.1-0.20230916134211-828f314ddfe7/go.mod h1:lcX+23LCQ3khIeASBo+tJ/WbwXFO32/N5YN6ucuYTG8= github.com/go-gost/tls-dissector v0.0.2-0.20220408131628-aac992c27451 h1:xj8gUZGYO3nb5+6Bjw9+tsFkA9sYynrOvDvvC4uDV2I= diff --git a/handler/dns/handler.go b/handler/dns/handler.go index a9189bf..e44ef9a 100644 --- a/handler/dns/handler.go +++ b/handler/dns/handler.go @@ -11,10 +11,11 @@ import ( "github.com/go-gost/core/chain" "github.com/go-gost/core/common/bufpool" "github.com/go-gost/core/handler" + "github.com/go-gost/core/hop" "github.com/go-gost/core/hosts" "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" - xchain "github.com/go-gost/x/chain" + xhop "github.com/go-gost/x/hop" resolver_util "github.com/go-gost/x/internal/util/resolver" "github.com/go-gost/x/registry" "github.com/go-gost/x/resolver/exchanger" @@ -30,7 +31,7 @@ func init() { } type dnsHandler struct { - hop chain.Hop + hop hop.Hop exchangers map[string]exchanger.Exchanger cache *resolver_util.Cache router *chain.Router @@ -70,10 +71,14 @@ func (h *dnsHandler) Init(md md.Metadata) (err error) { for i, addr := range h.md.dns { nodes = append(nodes, chain.NewNode(fmt.Sprintf("target-%d", i), addr)) } - h.hop = xchain.NewChainHop(nodes) + h.hop = xhop.NewHop(xhop.NodeOption(nodes...)) } - for _, node := range h.hop.Nodes() { + var nodes []*chain.Node + if nl, ok := h.hop.(hop.NodeList); ok { + nodes = nl.Nodes() + } + for _, node := range nodes { addr := strings.TrimSpace(node.Addr) if addr == "" { continue @@ -109,7 +114,7 @@ func (h *dnsHandler) Init(md md.Metadata) (err error) { } // Forward implements handler.Forwarder. -func (h *dnsHandler) Forward(hop chain.Hop) { +func (h *dnsHandler) Forward(hop hop.Hop) { h.hop = hop } @@ -325,7 +330,7 @@ func (h *dnsHandler) selectExchanger(ctx context.Context, addr string) exchanger if h.hop == nil { return nil } - node := h.hop.Select(ctx, chain.AddrSelectOption(addr)) + node := h.hop.Select(ctx, hop.AddrSelectOption(addr)) if node == nil { return nil } diff --git a/handler/forward/local/handler.go b/handler/forward/local/handler.go index 82fad67..5846eee 100644 --- a/handler/forward/local/handler.go +++ b/handler/forward/local/handler.go @@ -15,6 +15,7 @@ import ( "github.com/go-gost/core/chain" "github.com/go-gost/core/handler" + "github.com/go-gost/core/hop" "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" xnet "github.com/go-gost/x/internal/net" @@ -30,7 +31,7 @@ func init() { } type forwardHandler struct { - hop chain.Hop + hop hop.Hop router *chain.Router md metadata options handler.Options @@ -61,7 +62,7 @@ func (h *forwardHandler) Init(md md.Metadata) (err error) { } // Forward implements handler.Forwarder. -func (h *forwardHandler) Forward(hop chain.Hop) { +func (h *forwardHandler) Forward(hop hop.Hop) { h.hop = hop } @@ -123,8 +124,8 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand } if h.hop != nil { target = h.hop.Select(ctx, - chain.HostSelectOption(host), - chain.ProtocolSelectOption(protocol), + hop.HostSelectOption(host), + hop.ProtocolSelectOption(protocol), ) } if target == nil { @@ -192,8 +193,8 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, log l } if h.hop != nil { target = h.hop.Select(ctx, - chain.HostSelectOption(req.Host), - chain.ProtocolSelectOption(forward.ProtoHTTP), + hop.HostSelectOption(req.Host), + hop.ProtocolSelectOption(forward.ProtoHTTP), ) } if target == nil { diff --git a/handler/forward/remote/handler.go b/handler/forward/remote/handler.go index f1ffcec..0a85fda 100644 --- a/handler/forward/remote/handler.go +++ b/handler/forward/remote/handler.go @@ -15,6 +15,7 @@ import ( "github.com/go-gost/core/chain" "github.com/go-gost/core/handler" + "github.com/go-gost/core/hop" "github.com/go-gost/core/logger" mdata "github.com/go-gost/core/metadata" mdutil "github.com/go-gost/core/metadata/util" @@ -30,7 +31,7 @@ func init() { } type forwardHandler struct { - hop chain.Hop + hop hop.Hop router *chain.Router md metadata options handler.Options @@ -61,7 +62,7 @@ func (h *forwardHandler) Init(md mdata.Metadata) (err error) { } // Forward implements handler.Forwarder. -func (h *forwardHandler) Forward(hop chain.Hop) { +func (h *forwardHandler) Forward(hop hop.Hop) { h.hop = hop } @@ -123,8 +124,8 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand } if h.hop != nil { target = h.hop.Select(ctx, - chain.HostSelectOption(host), - chain.ProtocolSelectOption(protocol), + hop.HostSelectOption(host), + hop.ProtocolSelectOption(protocol), ) } if target == nil { @@ -189,8 +190,8 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, log l } if h.hop != nil { target = h.hop.Select(ctx, - chain.HostSelectOption(req.Host), - chain.ProtocolSelectOption(forward.ProtoHTTP), + hop.HostSelectOption(req.Host), + hop.ProtocolSelectOption(forward.ProtoHTTP), ) } if target == nil { diff --git a/handler/http3/handler.go b/handler/http3/handler.go index 90a11a5..e8033ff 100644 --- a/handler/http3/handler.go +++ b/handler/http3/handler.go @@ -10,6 +10,7 @@ import ( "time" "github.com/go-gost/core/chain" + "github.com/go-gost/core/hop" "github.com/go-gost/core/handler" "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" @@ -22,7 +23,7 @@ func init() { } type http3Handler struct { - hop chain.Hop + hop hop.Hop router *chain.Router md metadata options handler.Options @@ -53,7 +54,7 @@ func (h *http3Handler) Init(md md.Metadata) error { } // Forward implements handler.Forwarder. -func (h *http3Handler) Forward(hop chain.Hop) { +func (h *http3Handler) Forward(hop hop.Hop) { h.hop = hop } @@ -118,7 +119,7 @@ func (h *http3Handler) roundTrip(ctx context.Context, w http.ResponseWriter, req var target *chain.Node if h.hop != nil { - target = h.hop.Select(ctx, chain.HostSelectOption(addr)) + target = h.hop.Select(ctx, hop.HostSelectOption(addr)) } if target == nil { err := errors.New("target not available") diff --git a/handler/relay/handler.go b/handler/relay/handler.go index fdcf218..fed36df 100644 --- a/handler/relay/handler.go +++ b/handler/relay/handler.go @@ -10,6 +10,7 @@ import ( "github.com/go-gost/core/chain" "github.com/go-gost/core/handler" + "github.com/go-gost/core/hop" "github.com/go-gost/core/listener" md "github.com/go-gost/core/metadata" "github.com/go-gost/core/service" @@ -32,7 +33,7 @@ func init() { } type relayHandler struct { - hop chain.Hop + hop hop.Hop router *chain.Router md metadata options handler.Options @@ -124,7 +125,7 @@ func (h *relayHandler) initEntryPoint() (err error) { } // Forward implements handler.Forwarder. -func (h *relayHandler) Forward(hop chain.Hop) { +func (h *relayHandler) Forward(hop hop.Hop) { h.hop = hop } diff --git a/handler/serial/handler.go b/handler/serial/handler.go index 895ccdd..5d60fb1 100644 --- a/handler/serial/handler.go +++ b/handler/serial/handler.go @@ -9,6 +9,7 @@ import ( "github.com/go-gost/core/chain" "github.com/go-gost/core/handler" + "github.com/go-gost/core/hop" "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" "github.com/go-gost/core/recorder" @@ -24,7 +25,7 @@ func init() { } type serialHandler struct { - hop chain.Hop + hop hop.Hop router *chain.Router md metadata options handler.Options @@ -64,7 +65,7 @@ func (h *serialHandler) Init(md md.Metadata) (err error) { } // Forward implements handler.Forwarder. -func (h *serialHandler) Forward(hop chain.Hop) { +func (h *serialHandler) Forward(hop hop.Hop) { h.hop = hop } diff --git a/handler/tap/handler.go b/handler/tap/handler.go index ea1dac3..e73af6d 100644 --- a/handler/tap/handler.go +++ b/handler/tap/handler.go @@ -13,6 +13,7 @@ import ( "github.com/go-gost/core/chain" "github.com/go-gost/core/common/bufpool" "github.com/go-gost/core/handler" + "github.com/go-gost/core/hop" "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" "github.com/go-gost/x/internal/util/ss" @@ -28,7 +29,7 @@ func init() { } type tapHandler struct { - hop chain.Hop + hop hop.Hop routes sync.Map exit chan struct{} cipher core.Cipher @@ -72,7 +73,7 @@ func (h *tapHandler) Init(md md.Metadata) (err error) { } // Forward implements handler.Forwarder. -func (h *tapHandler) Forward(hop chain.Hop) { +func (h *tapHandler) Forward(hop hop.Hop) { h.hop = hop } diff --git a/handler/tun/handler.go b/handler/tun/handler.go index cc67b09..87a14b3 100644 --- a/handler/tun/handler.go +++ b/handler/tun/handler.go @@ -9,6 +9,7 @@ import ( "time" "github.com/go-gost/core/chain" + "github.com/go-gost/core/hop" "github.com/go-gost/core/handler" md "github.com/go-gost/core/metadata" tun_util "github.com/go-gost/x/internal/util/tun" @@ -26,7 +27,7 @@ func init() { } type tunHandler struct { - hop chain.Hop + hop hop.Hop routes sync.Map router *chain.Router md metadata @@ -58,7 +59,7 @@ func (h *tunHandler) Init(md md.Metadata) (err error) { } // Forward implements handler.Forwarder. -func (h *tunHandler) Forward(hop chain.Hop) { +func (h *tunHandler) Forward(hop hop.Hop) { h.hop = hop } diff --git a/handler/unix/handler.go b/handler/unix/handler.go index 4f1bc0e..d03cf56 100644 --- a/handler/unix/handler.go +++ b/handler/unix/handler.go @@ -8,6 +8,7 @@ import ( "time" "github.com/go-gost/core/chain" + "github.com/go-gost/core/hop" "github.com/go-gost/core/handler" "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" @@ -20,7 +21,7 @@ func init() { } type unixHandler struct { - hop chain.Hop + hop hop.Hop router *chain.Router md metadata options handler.Options @@ -51,7 +52,7 @@ func (h *unixHandler) Init(md md.Metadata) (err error) { } // Forward implements handler.Forwarder. -func (h *unixHandler) Forward(hop chain.Hop) { +func (h *unixHandler) Forward(hop hop.Hop) { h.hop = hop } diff --git a/hop/hop.go b/hop/hop.go new file mode 100644 index 0000000..abf44b7 --- /dev/null +++ b/hop/hop.go @@ -0,0 +1,307 @@ +package hop + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net" + "strings" + "sync" + "time" + + "github.com/go-gost/core/bypass" + "github.com/go-gost/core/chain" + "github.com/go-gost/core/hop" + "github.com/go-gost/core/logger" + "github.com/go-gost/core/selector" + "github.com/go-gost/x/config" + node_parser "github.com/go-gost/x/config/parsing/node" + "github.com/go-gost/x/internal/loader" +) + +type options struct { + name string + nodes []*chain.Node + bypass bypass.Bypass + selector selector.Selector[*chain.Node] + fileLoader loader.Loader + redisLoader loader.Loader + httpLoader loader.Loader + period time.Duration + logger logger.Logger +} + +type Option func(*options) + +func NameOption(name string) Option { + return func(o *options) { + o.name = name + } +} + +func NodeOption(nodes ...*chain.Node) Option { + return func(o *options) { + o.nodes = nodes + } +} +func BypassOption(bp bypass.Bypass) Option { + return func(o *options) { + o.bypass = bp + } +} + +func SelectorOption(s selector.Selector[*chain.Node]) Option { + return func(o *options) { + o.selector = s + } +} + +func ReloadPeriodOption(period time.Duration) Option { + return func(opts *options) { + opts.period = period + } +} + +func FileLoaderOption(fileLoader loader.Loader) Option { + return func(opts *options) { + opts.fileLoader = fileLoader + } +} + +func RedisLoaderOption(redisLoader loader.Loader) Option { + return func(opts *options) { + opts.redisLoader = redisLoader + } +} + +func HTTPLoaderOption(httpLoader loader.Loader) Option { + return func(opts *options) { + opts.httpLoader = httpLoader + } +} +func LoggerOption(logger logger.Logger) Option { + return func(opts *options) { + opts.logger = logger + } +} + +type chainHop struct { + nodes []*chain.Node + mu sync.RWMutex + cancelFunc context.CancelFunc + options options +} + +func NewHop(opts ...Option) hop.Hop { + var options options + for _, opt := range opts { + if opt != nil { + opt(&options) + } + } + + ctx, cancel := context.WithCancel(context.TODO()) + p := &chainHop{ + cancelFunc: cancel, + options: options, + } + + if err := p.reload(ctx); err != nil { + options.logger.Warnf("reload: %v", err) + } + if p.options.period > 0 { + go p.periodReload(ctx) + } + + return p +} + +func (p *chainHop) Nodes() []*chain.Node { + if p == nil { + return nil + } + p.mu.RLock() + defer p.mu.RUnlock() + return p.nodes +} + +func (p *chainHop) Select(ctx context.Context, opts ...hop.SelectOption) *chain.Node { + var options hop.SelectOptions + for _, opt := range opts { + opt(&options) + } + + ns := p.Nodes() + if len(ns) == 0 { + return nil + } + + // hop level bypass + if p.options.bypass != nil && + p.options.bypass.Contains(ctx, options.Addr) { + return nil + } + + filters := ns + if host := options.Host; host != "" { + filters = nil + if v, _, _ := net.SplitHostPort(host); v != "" { + host = v + } + var nodes []*chain.Node + for _, node := range ns { + if node == nil { + continue + } + vhost := node.Options().Host + if vhost == "" { + nodes = append(nodes, node) + continue + } + if vhost == host || + vhost[0] == '.' && strings.HasSuffix(host, vhost[1:]) { + filters = append(filters, node) + } + } + if len(filters) == 0 { + filters = nodes + } + } else if protocol := options.Protocol; protocol != "" { + filters = nil + for _, node := range ns { + if node == nil { + continue + } + if node.Options().Protocol == protocol { + filters = append(filters, node) + } + } + } + + var nodes []*chain.Node + for _, node := range filters { + if node == nil { + continue + } + // node level bypass + if node.Options().Bypass != nil && + node.Options().Bypass.Contains(ctx, options.Addr) { + continue + } + + nodes = append(nodes, node) + } + if len(nodes) == 0 { + return nil + } + + if s := p.options.selector; s != nil { + return s.Select(ctx, nodes...) + } + return nodes[0] +} + +func (p *chainHop) periodReload(ctx context.Context) error { + period := p.options.period + if period < time.Second { + period = time.Second + } + ticker := time.NewTicker(period) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + if err := p.reload(ctx); err != nil { + p.options.logger.Warnf("reload: %v", err) + // return err + } + p.options.logger.Debug("hop reload done") + case <-ctx.Done(): + return ctx.Err() + } + } +} + +func (p *chainHop) reload(ctx context.Context) (err error) { + nodes := p.options.nodes + + nl, err := p.load(ctx) + + nodes = append(nodes, nl...) + + p.mu.Lock() + defer p.mu.Unlock() + + p.nodes = nodes + + return +} + +func (p *chainHop) load(ctx context.Context) (nodes []*chain.Node, err error) { + if p.options.fileLoader != nil { + r, er := p.options.fileLoader.Load(ctx) + if er != nil { + p.options.logger.Warnf("file loader: %v", er) + } + nodes, _ = p.parseNode(r) + } + + if p.options.redisLoader != nil { + if lister, ok := p.options.redisLoader.(loader.Lister); ok { + list, er := lister.List(ctx) + if er != nil { + p.options.logger.Warnf("redis loader: %v", er) + } + for _, s := range list { + nl, _ := p.parseNode(bytes.NewReader([]byte(s))) + nodes = append(nodes, nl...) + } + } + } + if p.options.httpLoader != nil { + r, er := p.options.httpLoader.Load(ctx) + if er != nil { + p.options.logger.Warnf("http loader: %v", er) + } + if node, _ := p.parseNode(r); node != nil { + nodes = append(nodes, node...) + } + } + + p.options.logger.Debugf("load items %d", len(nodes)) + return +} + +func (p *chainHop) parseNode(r io.Reader) ([]*chain.Node, error) { + var ncs []*config.NodeConfig + if err := json.NewDecoder(r).Decode(&ncs); err != nil { + return nil, err + } + + var nodes []*chain.Node + for _, nc := range ncs { + if nc == nil { + continue + } + + node, err := node_parser.ParseNode(p.options.name, nc) + if err != nil { + return nodes, err + } + nodes = append(nodes, node) + } + return nodes, nil +} + +func (p *chainHop) Close() error { + p.cancelFunc() + if p.options.fileLoader != nil { + p.options.fileLoader.Close() + } + if p.options.redisLoader != nil { + p.options.redisLoader.Close() + } + return nil +} diff --git a/hop/plugin.go b/hop/plugin.go new file mode 100644 index 0000000..1e782f7 --- /dev/null +++ b/hop/plugin.go @@ -0,0 +1,204 @@ +package hop + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net/http" + + "github.com/go-gost/core/chain" + "github.com/go-gost/core/hop" + "github.com/go-gost/core/logger" + "github.com/go-gost/plugin/hop/proto" + "github.com/go-gost/x/config" + node_parser "github.com/go-gost/x/config/parsing/node" + "github.com/go-gost/x/internal/plugin" + auth_util "github.com/go-gost/x/internal/util/auth" + "google.golang.org/grpc" +) + +type grpcPlugin struct { + name string + conn grpc.ClientConnInterface + client proto.HopClient + log logger.Logger +} + +// NewGRPCPlugin creates a Hop plugin based on gRPC. +func NewGRPCPlugin(name string, addr string, opts ...plugin.Option) hop.Hop{ + var options plugin.Options + for _, opt := range opts { + opt(&options) + } + + log := logger.Default().WithFields(map[string]any{ + "kind": "hop", + "hop": name, + }) + conn, err := plugin.NewGRPCConn(addr, &options) + if err != nil { + log.Error(err) + } + + p := &grpcPlugin{ + name: name, + conn: conn, + log: log, + } + if conn != nil { + p.client = proto.NewHopClient(conn) + } + return p +} + +func (p *grpcPlugin) Select(ctx context.Context, opts ...hop.SelectOption) *chain.Node { + if p.client == nil { + return nil + } + + var options hop.SelectOptions + for _, opt := range opts { + opt(&options) + } + + r, err := p.client.Select(ctx, + &proto.SelectRequest{ + Addr: options.Addr, + Host: options.Host, + Client: string(auth_util.IDFromContext(ctx)), + }) + if err != nil { + p.log.Error(err) + return nil + } + + if r.Node == nil { + return nil + } + + var cfg config.NodeConfig + if err := json.NewDecoder(bytes.NewReader(r.Node)).Decode(&cfg); err != nil { + p.log.Error(err) + return nil + } + + node, err := node_parser.ParseNode(p.name, &cfg) + if err != nil { + p.log.Error(err) + return nil + } + return node +} + +func (p *grpcPlugin) Close() error { + if closer, ok := p.conn.(io.Closer); ok { + return closer.Close() + } + return nil +} + +type httpPluginRequest struct { + Network string `json:"network"` + Addr string `json:"addr"` + Host string `json:"host"` + Client string `json:"client"` +} + +type httpPluginResponse struct { + Node string `json:"node"` +} + +type httpPlugin struct { + name string + url string + client *http.Client + header http.Header + log logger.Logger +} + +// NewHTTPPlugin creates an Hop plugin based on HTTP. +func NewHTTPPlugin(name string, url string, opts ...plugin.Option) hop.Hop{ + var options plugin.Options + for _, opt := range opts { + opt(&options) + } + + return &httpPlugin{ + name: name, + url: url, + client: plugin.NewHTTPClient(&options), + header: options.Header, + log: logger.Default().WithFields(map[string]any{ + "kind": "hop", + "hop": name, + }), + } +} + +func (p *httpPlugin) Select(ctx context.Context, opts ...hop.SelectOption) *chain.Node { + if p.client == nil { + return nil + } + + var options hop.SelectOptions + for _, opt := range opts { + opt(&options) + } + + rb := httpPluginRequest{ + Addr: options.Addr, + Host: options.Host, + Client: string(auth_util.IDFromContext(ctx)), + } + v, err := json.Marshal(&rb) + if err != nil { + p.log.Error(err) + return nil + } + + req, err := http.NewRequest(http.MethodPost, p.url, bytes.NewReader(v)) + if err != nil { + p.log.Error(err) + return nil + } + + if p.header != nil { + req.Header = p.header.Clone() + } + req.Header.Set("Content-Type", "application/json") + resp, err := p.client.Do(req) + if err != nil { + p.log.Error(err) + return nil + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + p.log.Error(resp.Status) + return nil + } + + res := httpPluginResponse{} + if err := json.NewDecoder(resp.Body).Decode(&res); err != nil { + p.log.Error(resp.Status) + return nil + } + + if res.Node == "" { + return nil + } + + var cfg config.NodeConfig + if err := json.NewDecoder(bytes.NewReader([]byte(res.Node))).Decode(&cfg); err != nil { + p.log.Error(err) + return nil + } + + node, err := node_parser.ParseNode(p.name, &cfg) + if err != nil { + p.log.Error(err) + return nil + } + return node +} diff --git a/hosts/hosts.go b/hosts/hosts.go index 66a648e..b06f789 100644 --- a/hosts/hosts.go +++ b/hosts/hosts.go @@ -12,7 +12,6 @@ import ( "github.com/go-gost/core/hosts" "github.com/go-gost/core/logger" "github.com/go-gost/x/internal/loader" - "google.golang.org/grpc" ) type Mapping struct { @@ -25,7 +24,6 @@ type options struct { fileLoader loader.Loader redisLoader loader.Loader httpLoader loader.Loader - client *grpc.ClientConn period time.Duration logger logger.Logger } @@ -62,12 +60,6 @@ func HTTPLoaderOption(httpLoader loader.Loader) Option { } } -func PluginConnOption(c *grpc.ClientConn) Option { - return func(opts *options) { - opts.client = c - } -} - func LoggerOption(logger logger.Logger) Option { return func(opts *options) { opts.logger = logger diff --git a/hosts/plugin.go b/hosts/plugin.go index fcc09c5..e1d56e1 100644 --- a/hosts/plugin.go +++ b/hosts/plugin.go @@ -12,18 +12,18 @@ import ( "github.com/go-gost/core/logger" "github.com/go-gost/plugin/hosts/proto" auth_util "github.com/go-gost/x/internal/util/auth" - "github.com/go-gost/x/internal/util/plugin" + "github.com/go-gost/x/internal/plugin" "google.golang.org/grpc" ) -type grpcPluginHostMapper struct { +type grpcPlugin struct { conn grpc.ClientConnInterface client proto.HostMapperClient log logger.Logger } -// NewGRPCPluginHostMapper creates a HostMapper plugin based on gRPC. -func NewGRPCPluginHostMapper(name string, addr string, opts ...plugin.Option) hosts.HostMapper { +// NewGRPCPlugin creates a HostMapper plugin based on gRPC. +func NewGRPCPlugin(name string, addr string, opts ...plugin.Option) hosts.HostMapper { var options plugin.Options for _, opt := range opts { opt(&options) @@ -37,7 +37,7 @@ func NewGRPCPluginHostMapper(name string, addr string, opts ...plugin.Option) ho if err != nil { log.Error(err) } - p := &grpcPluginHostMapper{ + p := &grpcPlugin{ conn: conn, log: log, } @@ -47,7 +47,7 @@ func NewGRPCPluginHostMapper(name string, addr string, opts ...plugin.Option) ho return p } -func (p *grpcPluginHostMapper) Lookup(ctx context.Context, network, host string) (ips []net.IP, ok bool) { +func (p *grpcPlugin) Lookup(ctx context.Context, network, host string) (ips []net.IP, ok bool) { p.log.Debugf("lookup %s/%s", host, network) if p.client == nil { @@ -73,39 +73,39 @@ func (p *grpcPluginHostMapper) Lookup(ctx context.Context, network, host string) return } -func (p *grpcPluginHostMapper) Close() error { +func (p *grpcPlugin) Close() error { if closer, ok := p.conn.(io.Closer); ok { return closer.Close() } return nil } -type httpHostMapperRequest struct { +type httpPluginRequest struct { Network string `json:"network"` Host string `json:"host"` Client string `json:"client"` } -type httpHostMapperResponse struct { +type httpPluginResponse struct { IPs []string `json:"ips"` OK bool `json:"ok"` } -type httpPluginHostMapper struct { +type httpPlugin struct { url string client *http.Client header http.Header log logger.Logger } -// NewHTTPPluginHostMapper creates an HostMapper plugin based on HTTP. -func NewHTTPPluginHostMapper(name string, url string, opts ...plugin.Option) hosts.HostMapper { +// NewHTTPPlugin creates an HostMapper plugin based on HTTP. +func NewHTTPPlugin(name string, url string, opts ...plugin.Option) hosts.HostMapper { var options plugin.Options for _, opt := range opts { opt(&options) } - return &httpPluginHostMapper{ + return &httpPlugin{ url: url, client: plugin.NewHTTPClient(&options), header: options.Header, @@ -116,14 +116,14 @@ func NewHTTPPluginHostMapper(name string, url string, opts ...plugin.Option) hos } } -func (p *httpPluginHostMapper) Lookup(ctx context.Context, network, host string) (ips []net.IP, ok bool) { +func (p *httpPlugin) Lookup(ctx context.Context, network, host string) (ips []net.IP, ok bool) { p.log.Debugf("lookup %s/%s", host, network) if p.client == nil { return } - rb := httpHostMapperRequest{ + rb := httpPluginRequest{ Network: network, Host: host, Client: string(auth_util.IDFromContext(ctx)), @@ -152,7 +152,7 @@ func (p *httpPluginHostMapper) Lookup(ctx context.Context, network, host string) return } - res := httpHostMapperResponse{} + res := httpPluginResponse{} if err := json.NewDecoder(resp.Body).Decode(&res); err != nil { return } diff --git a/ingress/ingress.go b/ingress/ingress.go index 64170d6..49f4903 100644 --- a/ingress/ingress.go +++ b/ingress/ingress.go @@ -12,7 +12,6 @@ import ( "github.com/go-gost/core/ingress" "github.com/go-gost/core/logger" "github.com/go-gost/x/internal/loader" - "google.golang.org/grpc" ) type Rule struct { @@ -25,7 +24,6 @@ type options struct { fileLoader loader.Loader redisLoader loader.Loader httpLoader loader.Loader - client *grpc.ClientConn period time.Duration logger logger.Logger } @@ -62,12 +60,6 @@ func HTTPLoaderOption(httpLoader loader.Loader) Option { } } -func PluginConnOption(c *grpc.ClientConn) Option { - return func(opts *options) { - opts.client = c - } -} - func LoggerOption(logger logger.Logger) Option { return func(opts *options) { opts.logger = logger diff --git a/ingress/plugin.go b/ingress/plugin.go index 7a05a83..6d0b6d5 100644 --- a/ingress/plugin.go +++ b/ingress/plugin.go @@ -10,18 +10,18 @@ import ( "github.com/go-gost/core/ingress" "github.com/go-gost/core/logger" "github.com/go-gost/plugin/ingress/proto" - "github.com/go-gost/x/internal/util/plugin" + "github.com/go-gost/x/internal/plugin" "google.golang.org/grpc" ) -type grpcPluginIngress struct { +type grpcPlugin struct { conn grpc.ClientConnInterface client proto.IngressClient log logger.Logger } -// NewGRPCPluginIngress creates an Ingress plugin based on gRPC. -func NewGRPCPluginIngress(name string, addr string, opts ...plugin.Option) ingress.Ingress { +// NewGRPCPlugin creates an Ingress plugin based on gRPC. +func NewGRPCPlugin(name string, addr string, opts ...plugin.Option) ingress.Ingress { var options plugin.Options for _, opt := range opts { opt(&options) @@ -36,7 +36,7 @@ func NewGRPCPluginIngress(name string, addr string, opts ...plugin.Option) ingre log.Error(err) } - p := &grpcPluginIngress{ + p := &grpcPlugin{ conn: conn, log: log, } @@ -46,7 +46,7 @@ func NewGRPCPluginIngress(name string, addr string, opts ...plugin.Option) ingre return p } -func (p *grpcPluginIngress) Get(ctx context.Context, host string) string { +func (p *grpcPlugin) Get(ctx context.Context, host string) string { if p.client == nil { return "" } @@ -62,36 +62,36 @@ func (p *grpcPluginIngress) Get(ctx context.Context, host string) string { return r.GetEndpoint() } -func (p *grpcPluginIngress) Close() error { +func (p *grpcPlugin) Close() error { if closer, ok := p.conn.(io.Closer); ok { return closer.Close() } return nil } -type httpIngressRequest struct { +type httpPluginRequest struct { Host string `json:"host"` } -type httpIngressResponse struct { +type httpPluginResponse struct { Endpoint string `json:"endpoint"` } -type httpPluginIngress struct { +type httpPlugin struct { url string client *http.Client header http.Header log logger.Logger } -// NewHTTPPluginIngress creates an Ingress plugin based on HTTP. -func NewHTTPPluginIngress(name string, url string, opts ...plugin.Option) ingress.Ingress { +// NewHTTPPlugin creates an Ingress plugin based on HTTP. +func NewHTTPPlugin(name string, url string, opts ...plugin.Option) ingress.Ingress { var options plugin.Options for _, opt := range opts { opt(&options) } - return &httpPluginIngress{ + return &httpPlugin{ url: url, client: plugin.NewHTTPClient(&options), header: options.Header, @@ -102,12 +102,12 @@ func NewHTTPPluginIngress(name string, url string, opts ...plugin.Option) ingres } } -func (p *httpPluginIngress) Get(ctx context.Context, host string) (endpoint string) { +func (p *httpPlugin) Get(ctx context.Context, host string) (endpoint string) { if p.client == nil { return } - rb := httpIngressRequest{ + rb := httpPluginRequest{ Host: host, } v, err := json.Marshal(&rb) @@ -134,7 +134,7 @@ func (p *httpPluginIngress) Get(ctx context.Context, host string) (endpoint stri return } - res := httpIngressResponse{} + res := httpPluginResponse{} if err := json.NewDecoder(resp.Body).Decode(&res); err != nil { return } diff --git a/internal/loader/redis.go b/internal/loader/redis.go index c19331d..5265a84 100644 --- a/internal/loader/redis.go +++ b/internal/loader/redis.go @@ -40,6 +40,47 @@ func KeyRedisLoaderOption(key string) RedisLoaderOption { } } +type redisStringLoader struct { + client *redis.Client + key string +} + +// RedisStringLoader loads data from redis string. +func RedisStringLoader(addr string, opts ...RedisLoaderOption) Loader { + var options redisLoaderOptions + for _, opt := range opts { + if opt != nil { + opt(&options) + } + } + + key := options.key + if key == "" { + key = DefaultRedisKey + } + + return &redisStringLoader{ + client: redis.NewClient(&redis.Options{ + Addr: addr, + Password: options.password, + DB: options.db, + }), + key: key, + } +} + +func (p *redisStringLoader) Load(ctx context.Context) (io.Reader, error) { + v, err := p.client.Get(ctx, p.key).Bytes() + if err != nil { + return nil, err + } + return bytes.NewReader(v), nil +} + +func (p *redisStringLoader) Close() error { + return p.client.Close() +} + type redisSetLoader struct { client *redis.Client key string diff --git a/internal/util/plugin/plugin.go b/internal/plugin/plugin.go similarity index 100% rename from internal/util/plugin/plugin.go rename to internal/plugin/plugin.go diff --git a/internal/util/pht/server.go b/internal/util/pht/server.go index b6d0d8c..1c28038 100644 --- a/internal/util/pht/server.go +++ b/internal/util/pht/server.go @@ -2,6 +2,7 @@ package pht import ( "bufio" + "context" "crypto/tls" "encoding/base64" "errors" @@ -38,6 +39,7 @@ type serverOptions struct { tlsConfig *tls.Config readBufferSize int readTimeout time.Duration + mptcp bool logger logger.Logger } @@ -81,6 +83,12 @@ func ReadTimeoutServerOption(timeout time.Duration) ServerOption { } } +func MPTCPServerOption(mptcp bool) ServerOption { + return func(opts *serverOptions) { + opts.mptcp = mptcp + } +} + func LoggerServerOption(logger logger.Logger) ServerOption { return func(opts *serverOptions) { opts.logger = logger @@ -187,7 +195,13 @@ func (s *Server) ListenAndServe() error { if xnet.IsIPv4(s.httpServer.Addr) { network = "tcp4" } - ln, err := net.Listen(network, s.httpServer.Addr) + + lc := net.ListenConfig{} + if s.options.mptcp { + lc.SetMultipathTCP(true) + s.options.logger.Debugf("mptcp enabled: %v", lc.MultipathTCP()) + } + ln, err := lc.Listen(context.Background(), network, s.httpServer.Addr) if err != nil { s.options.logger.Error(err) return err diff --git a/listener/grpc/listener.go b/listener/grpc/listener.go index fb3f862..4c1b744 100644 --- a/listener/grpc/listener.go +++ b/listener/grpc/listener.go @@ -1,6 +1,7 @@ package grpc import ( + "context" "net" "time" @@ -54,7 +55,12 @@ func (l *grpcListener) Init(md md.Metadata) (err error) { if xnet.IsIPv4(l.options.Addr) { network = "tcp4" } - ln, err := net.Listen(network, l.options.Addr) + lc := net.ListenConfig{} + if l.md.mptcp { + lc.SetMultipathTCP(true) + l.logger.Debugf("mptcp enabled: %v", lc.MultipathTCP()) + } + ln, err := lc.Listen(context.Background(), network, l.options.Addr) if err != nil { return } diff --git a/listener/grpc/metadata.go b/listener/grpc/metadata.go index 662e1ce..dbf9fc4 100644 --- a/listener/grpc/metadata.go +++ b/listener/grpc/metadata.go @@ -21,6 +21,7 @@ type metadata struct { keepaliveTimeout time.Duration keepalivePermitWithoutStream bool keepaliveMaxConnectionIdle time.Duration + mptcp bool } func (l *grpcListener) parseMetadata(md mdata.Metadata) (err error) { @@ -49,6 +50,7 @@ func (l *grpcListener) parseMetadata(md mdata.Metadata) (err error) { l.md.keepalivePermitWithoutStream = mdutil.GetBool(md, "grpc.keepalive.permitWithoutStream", "keepalive.permitWithoutStream") l.md.keepaliveMaxConnectionIdle = mdutil.GetDuration(md, "grpc.keepalive.maxConnectionIdle", "keepalive.maxConnectionIdle") + l.md.mptcp = mdutil.GetBool(md, "mptcp") } return diff --git a/listener/http2/h2/listener.go b/listener/http2/h2/listener.go index f8cea7d..5e4ab66 100644 --- a/listener/http2/h2/listener.go +++ b/listener/http2/h2/listener.go @@ -1,6 +1,7 @@ package h2 import ( + "context" "crypto/tls" "errors" "net" @@ -74,7 +75,12 @@ func (l *h2Listener) Init(md md.Metadata) (err error) { if xnet.IsIPv4(l.options.Addr) { network = "tcp4" } - ln, err := net.Listen(network, l.options.Addr) + lc := net.ListenConfig{} + if l.md.mptcp { + lc.SetMultipathTCP(true) + l.logger.Debugf("mptcp enabled: %v", lc.MultipathTCP()) + } + ln, err := lc.Listen(context.Background(), network, l.options.Addr) if err != nil { return err } diff --git a/listener/http2/h2/metadata.go b/listener/http2/h2/metadata.go index 0731550..3b0282a 100644 --- a/listener/http2/h2/metadata.go +++ b/listener/http2/h2/metadata.go @@ -12,6 +12,7 @@ const ( type metadata struct { path string backlog int + mptcp bool } func (l *h2Listener) parseMetadata(md mdata.Metadata) (err error) { @@ -26,5 +27,7 @@ func (l *h2Listener) parseMetadata(md mdata.Metadata) (err error) { } l.md.path = mdutil.GetString(md, path) + l.md.mptcp = mdutil.GetBool(md, "mptcp") + return } diff --git a/listener/http2/listener.go b/listener/http2/listener.go index 5870492..abee6fb 100644 --- a/listener/http2/listener.go +++ b/listener/http2/listener.go @@ -1,6 +1,7 @@ package http2 import ( + "context" "crypto/tls" "net" "net/http" @@ -63,7 +64,12 @@ func (l *http2Listener) Init(md md.Metadata) (err error) { if xnet.IsIPv4(l.options.Addr) { network = "tcp4" } - ln, err := net.Listen(network, l.options.Addr) + lc := net.ListenConfig{} + if l.md.mptcp { + lc.SetMultipathTCP(true) + l.logger.Debugf("mptcp enabled: %v", lc.MultipathTCP()) + } + ln, err := lc.Listen(context.Background(), network, l.options.Addr) if err != nil { return err } diff --git a/listener/http2/metadata.go b/listener/http2/metadata.go index 1e1aa28..7184c84 100644 --- a/listener/http2/metadata.go +++ b/listener/http2/metadata.go @@ -11,6 +11,7 @@ const ( type metadata struct { backlog int + mptcp bool } func (l *http2Listener) parseMetadata(md mdata.Metadata) (err error) { @@ -22,5 +23,7 @@ func (l *http2Listener) parseMetadata(md mdata.Metadata) (err error) { if l.md.backlog <= 0 { l.md.backlog = defaultBacklog } + l.md.mptcp = mdutil.GetBool(md, "mptcp") + return } diff --git a/listener/mtls/listener.go b/listener/mtls/listener.go index 9cfe92b..6cdd920 100644 --- a/listener/mtls/listener.go +++ b/listener/mtls/listener.go @@ -1,6 +1,7 @@ package mtls import ( + "context" "crypto/tls" "net" "time" @@ -51,7 +52,13 @@ func (l *mtlsListener) Init(md md.Metadata) (err error) { if xnet.IsIPv4(l.options.Addr) { network = "tcp4" } - ln, err := net.Listen(network, l.options.Addr) + + lc := net.ListenConfig{} + if l.md.mptcp { + lc.SetMultipathTCP(true) + l.logger.Debugf("mptcp enabled: %v", lc.MultipathTCP()) + } + ln, err := lc.Listen(context.Background(), network, l.options.Addr) if err != nil { return } diff --git a/listener/mtls/metadata.go b/listener/mtls/metadata.go index 17a6171..e3b7a7a 100644 --- a/listener/mtls/metadata.go +++ b/listener/mtls/metadata.go @@ -20,6 +20,7 @@ type metadata struct { muxMaxStreamBuffer int backlog int + mptcp bool } func (l *mtlsListener) parseMetadata(md mdata.Metadata) (err error) { @@ -46,5 +47,7 @@ func (l *mtlsListener) parseMetadata(md mdata.Metadata) (err error) { l.md.muxMaxReceiveBuffer = mdutil.GetInt(md, muxMaxReceiveBuffer) l.md.muxMaxStreamBuffer = mdutil.GetInt(md, muxMaxStreamBuffer) + l.md.mptcp = mdutil.GetBool(md, "mptcp") + return } diff --git a/listener/mws/listener.go b/listener/mws/listener.go index 9884ebd..14b2780 100644 --- a/listener/mws/listener.go +++ b/listener/mws/listener.go @@ -1,6 +1,7 @@ package mws import ( + "context" "crypto/tls" "net" "net/http" @@ -94,7 +95,13 @@ func (l *mwsListener) Init(md md.Metadata) (err error) { if xnet.IsIPv4(l.options.Addr) { network = "tcp4" } - ln, err := net.Listen(network, l.options.Addr) + + lc := net.ListenConfig{} + if l.md.mptcp { + lc.SetMultipathTCP(true) + l.logger.Debugf("mptcp enabled: %v", lc.MultipathTCP()) + } + ln, err := lc.Listen(context.Background(), network, l.options.Addr) if err != nil { return } diff --git a/listener/mws/metadata.go b/listener/mws/metadata.go index 69d0abe..402941c 100644 --- a/listener/mws/metadata.go +++ b/listener/mws/metadata.go @@ -30,6 +30,8 @@ type metadata struct { muxMaxFrameSize int muxMaxReceiveBuffer int muxMaxStreamBuffer int + + mptcp bool } func (l *mwsListener) parseMetadata(md mdata.Metadata) (err error) { @@ -82,5 +84,8 @@ func (l *mwsListener) parseMetadata(md mdata.Metadata) (err error) { } l.md.header = hd } + + l.md.mptcp = mdutil.GetBool(md, "mptcp") + return } diff --git a/listener/obfs/http/listener.go b/listener/obfs/http/listener.go index 8872ac7..3111545 100644 --- a/listener/obfs/http/listener.go +++ b/listener/obfs/http/listener.go @@ -1,6 +1,7 @@ package http import ( + "context" "net" "time" @@ -48,7 +49,13 @@ func (l *obfsListener) Init(md md.Metadata) (err error) { if xnet.IsIPv4(l.options.Addr) { network = "tcp4" } - ln, err := net.Listen(network, l.options.Addr) + + lc := net.ListenConfig{} + if l.md.mptcp { + lc.SetMultipathTCP(true) + l.logger.Debugf("mptcp enabled: %v", lc.MultipathTCP()) + } + ln, err := lc.Listen(context.Background(), network, l.options.Addr) if err != nil { return } diff --git a/listener/obfs/http/metadata.go b/listener/obfs/http/metadata.go index 1dba1b8..1c46d25 100644 --- a/listener/obfs/http/metadata.go +++ b/listener/obfs/http/metadata.go @@ -9,6 +9,7 @@ import ( type metadata struct { header http.Header + mptcp bool } func (l *obfsListener) parseMetadata(md mdata.Metadata) (err error) { @@ -23,5 +24,7 @@ func (l *obfsListener) parseMetadata(md mdata.Metadata) (err error) { } l.md.header = hd } + + l.md.mptcp = mdutil.GetBool(md, "mptcp") return } diff --git a/listener/obfs/tls/listener.go b/listener/obfs/tls/listener.go index ec7758a..ff27bb6 100644 --- a/listener/obfs/tls/listener.go +++ b/listener/obfs/tls/listener.go @@ -1,6 +1,7 @@ package tls import ( + "context" "net" "time" @@ -47,7 +48,13 @@ func (l *obfsListener) Init(md md.Metadata) (err error) { if xnet.IsIPv4(l.options.Addr) { network = "tcp4" } - ln, err := net.Listen(network, l.options.Addr) + + lc := net.ListenConfig{} + if l.md.mptcp { + lc.SetMultipathTCP(true) + l.logger.Debugf("mptcp enabled: %v", lc.MultipathTCP()) + } + ln, err := lc.Listen(context.Background(), network, l.options.Addr) if err != nil { return } diff --git a/listener/obfs/tls/metadata.go b/listener/obfs/tls/metadata.go index db6d115..9c8535d 100644 --- a/listener/obfs/tls/metadata.go +++ b/listener/obfs/tls/metadata.go @@ -2,11 +2,14 @@ package tls import ( md "github.com/go-gost/core/metadata" + mdutil "github.com/go-gost/core/metadata/util" ) type metadata struct { + mptcp bool } func (l *obfsListener) parseMetadata(md md.Metadata) (err error) { + l.md.mptcp = mdutil.GetBool(md, "mptcp") return } diff --git a/listener/pht/listener.go b/listener/pht/listener.go index 06ae383..df30bd2 100644 --- a/listener/pht/listener.go +++ b/listener/pht/listener.go @@ -74,6 +74,7 @@ func (l *phtListener) Init(md md.Metadata) (err error) { pht_util.BacklogServerOption(l.md.backlog), pht_util.PathServerOption(l.md.authorizePath, l.md.pushPath, l.md.pullPath), pht_util.LoggerServerOption(l.options.Logger), + pht_util.MPTCPServerOption(l.md.mptcp), ) go func() { diff --git a/listener/pht/metadata.go b/listener/pht/metadata.go index 461f6cd..963feee 100644 --- a/listener/pht/metadata.go +++ b/listener/pht/metadata.go @@ -19,6 +19,7 @@ type metadata struct { pushPath string pullPath string backlog int + mptcp bool } func (l *phtListener) parseMetadata(md mdata.Metadata) (err error) { @@ -48,5 +49,6 @@ func (l *phtListener) parseMetadata(md mdata.Metadata) (err error) { l.md.backlog = defaultBacklog } + l.md.mptcp = mdutil.GetBool(md, "mptcp") return } diff --git a/listener/redirect/tcp/listener.go b/listener/redirect/tcp/listener.go index 47df75e..00de640 100644 --- a/listener/redirect/tcp/listener.go +++ b/listener/redirect/tcp/listener.go @@ -46,13 +46,17 @@ func (l *redirectListener) Init(md md.Metadata) (err error) { return } + network := "tcp" + if xnet.IsIPv4(l.options.Addr) { + network = "tcp4" + } lc := net.ListenConfig{} if l.md.tproxy { lc.Control = l.control } - network := "tcp" - if xnet.IsIPv4(l.options.Addr) { - network = "tcp4" + if l.md.mptcp { + lc.SetMultipathTCP(true) + l.logger.Debugf("mptcp enabled: %v", lc.MultipathTCP()) } ln, err := lc.Listen(context.Background(), network, l.options.Addr) if err != nil { diff --git a/listener/redirect/tcp/metadata.go b/listener/redirect/tcp/metadata.go index 893988f..d06d41f 100644 --- a/listener/redirect/tcp/metadata.go +++ b/listener/redirect/tcp/metadata.go @@ -7,6 +7,7 @@ import ( type metadata struct { tproxy bool + mptcp bool } func (l *redirectListener) parseMetadata(md mdata.Metadata) (err error) { @@ -14,5 +15,6 @@ func (l *redirectListener) parseMetadata(md mdata.Metadata) (err error) { tproxy = "tproxy" ) l.md.tproxy = mdutil.GetBool(md, tproxy) + l.md.mptcp = mdutil.GetBool(md, "mptcp") return } diff --git a/listener/ssh/listener.go b/listener/ssh/listener.go index dafecfe..b5b7879 100644 --- a/listener/ssh/listener.go +++ b/listener/ssh/listener.go @@ -1,6 +1,7 @@ package ssh import ( + "context" "fmt" "net" "time" @@ -53,7 +54,13 @@ func (l *sshListener) Init(md md.Metadata) (err error) { if xnet.IsIPv4(l.options.Addr) { network = "tcp4" } - ln, err := net.Listen(network, l.options.Addr) + + lc := net.ListenConfig{} + if l.md.mptcp { + lc.SetMultipathTCP(true) + l.logger.Debugf("mptcp enabled: %v", lc.MultipathTCP()) + } + ln, err := lc.Listen(context.Background(), network, l.options.Addr) if err != nil { return err } diff --git a/listener/ssh/metadata.go b/listener/ssh/metadata.go index 61044f9..cb9c32e 100644 --- a/listener/ssh/metadata.go +++ b/listener/ssh/metadata.go @@ -17,6 +17,7 @@ type metadata struct { signer ssh.Signer authorizedKeys map[string]bool backlog int + mptcp bool } func (l *sshListener) parseMetadata(md mdata.Metadata) (err error) { @@ -64,5 +65,6 @@ func (l *sshListener) parseMetadata(md mdata.Metadata) (err error) { l.md.backlog = defaultBacklog } + l.md.mptcp = mdutil.GetBool(md, "mptcp") return } diff --git a/listener/sshd/listener.go b/listener/sshd/listener.go index c70b040..ef19a22 100644 --- a/listener/sshd/listener.go +++ b/listener/sshd/listener.go @@ -62,7 +62,13 @@ func (l *sshdListener) Init(md md.Metadata) (err error) { if xnet.IsIPv4(l.options.Addr) { network = "tcp4" } - ln, err := net.Listen(network, l.options.Addr) + + lc := net.ListenConfig{} + if l.md.mptcp { + lc.SetMultipathTCP(true) + l.logger.Debugf("mptcp enabled: %v", lc.MultipathTCP()) + } + ln, err := lc.Listen(context.Background(), network, l.options.Addr) if err != nil { return err } diff --git a/listener/sshd/metadata.go b/listener/sshd/metadata.go index d9644e2..f4d3497 100644 --- a/listener/sshd/metadata.go +++ b/listener/sshd/metadata.go @@ -17,6 +17,7 @@ type metadata struct { signer ssh.Signer authorizedKeys map[string]bool backlog int + mptcp bool } func (l *sshdListener) parseMetadata(md mdata.Metadata) (err error) { @@ -64,5 +65,6 @@ func (l *sshdListener) parseMetadata(md mdata.Metadata) (err error) { l.md.backlog = defaultBacklog } + l.md.mptcp = mdutil.GetBool(md, "mptcp") return } diff --git a/listener/tcp/listener.go b/listener/tcp/listener.go index 451afd0..1e34e00 100644 --- a/listener/tcp/listener.go +++ b/listener/tcp/listener.go @@ -1,6 +1,7 @@ package tcp import ( + "context" "net" "time" @@ -47,7 +48,13 @@ func (l *tcpListener) Init(md md.Metadata) (err error) { if xnet.IsIPv4(l.options.Addr) { network = "tcp4" } - ln, err := net.Listen(network, l.options.Addr) + + lc := net.ListenConfig{} + if l.md.mptcp { + lc.SetMultipathTCP(true) + l.logger.Debugf("mptcp enabled: %v", lc.MultipathTCP()) + } + ln, err := lc.Listen(context.Background(), network, l.options.Addr) if err != nil { return } diff --git a/listener/tcp/metadata.go b/listener/tcp/metadata.go index e1d286a..70df7c3 100644 --- a/listener/tcp/metadata.go +++ b/listener/tcp/metadata.go @@ -2,11 +2,14 @@ package tcp import ( md "github.com/go-gost/core/metadata" + mdutil "github.com/go-gost/core/metadata/util" ) type metadata struct { + mptcp bool } func (l *tcpListener) parseMetadata(md md.Metadata) (err error) { + l.md.mptcp = mdutil.GetBool(md, "mptcp") return } diff --git a/listener/tls/listener.go b/listener/tls/listener.go index 6f1204f..5460e6f 100644 --- a/listener/tls/listener.go +++ b/listener/tls/listener.go @@ -1,6 +1,7 @@ package tls import ( + "context" "crypto/tls" "net" "time" @@ -48,7 +49,13 @@ func (l *tlsListener) Init(md md.Metadata) (err error) { if xnet.IsIPv4(l.options.Addr) { network = "tcp4" } - ln, err := net.Listen(network, l.options.Addr) + + lc := net.ListenConfig{} + if l.md.mptcp { + lc.SetMultipathTCP(true) + l.logger.Debugf("mptcp enabled: %v", lc.MultipathTCP()) + } + ln, err := lc.Listen(context.Background(), network, l.options.Addr) if err != nil { return } diff --git a/listener/tls/metadata.go b/listener/tls/metadata.go index d515844..aff8b01 100644 --- a/listener/tls/metadata.go +++ b/listener/tls/metadata.go @@ -2,11 +2,14 @@ package tls import ( mdata "github.com/go-gost/core/metadata" + mdutil "github.com/go-gost/core/metadata/util" ) type metadata struct { + mptcp bool } func (l *tlsListener) parseMetadata(md mdata.Metadata) (err error) { + l.md.mptcp = mdutil.GetBool(md, "mptcp") return } diff --git a/listener/ws/listener.go b/listener/ws/listener.go index 05f8456..6958251 100644 --- a/listener/ws/listener.go +++ b/listener/ws/listener.go @@ -1,6 +1,7 @@ package ws import ( + "context" "crypto/tls" "net" "net/http" @@ -89,7 +90,13 @@ func (l *wsListener) Init(md md.Metadata) (err error) { if xnet.IsIPv4(l.options.Addr) { network = "tcp4" } - ln, err := net.Listen(network, l.options.Addr) + + lc := net.ListenConfig{} + if l.md.mptcp { + lc.SetMultipathTCP(true) + l.logger.Debugf("mptcp enabled: %v", lc.MultipathTCP()) + } + ln, err := lc.Listen(context.Background(), network, l.options.Addr) if err != nil { return } diff --git a/listener/ws/metadata.go b/listener/ws/metadata.go index 02cfb76..6017606 100644 --- a/listener/ws/metadata.go +++ b/listener/ws/metadata.go @@ -22,8 +22,9 @@ type metadata struct { readBufferSize int writeBufferSize int enableCompression bool + header http.Header - header http.Header + mptcp bool } func (l *wsListener) parseMetadata(md mdata.Metadata) (err error) { @@ -63,5 +64,7 @@ func (l *wsListener) parseMetadata(md mdata.Metadata) (err error) { } l.md.header = hd } + + l.md.mptcp = mdutil.GetBool(md, "mptcp") return } diff --git a/recorder/plugin.go b/recorder/plugin.go index 7004807..456d1d2 100644 --- a/recorder/plugin.go +++ b/recorder/plugin.go @@ -12,18 +12,18 @@ import ( "github.com/go-gost/core/logger" "github.com/go-gost/core/recorder" "github.com/go-gost/plugin/recorder/proto" - "github.com/go-gost/x/internal/util/plugin" + "github.com/go-gost/x/internal/plugin" "google.golang.org/grpc" ) -type grpcPluginRecorder struct { +type grpcPlugin struct { conn grpc.ClientConnInterface client proto.RecorderClient log logger.Logger } -// NewGRPCPluginRecorder creates a Recorder plugin based on gRPC. -func NewGRPCPluginRecorder(name string, addr string, opts ...plugin.Option) recorder.Recorder { +// NewGRPCPlugin creates a Recorder plugin based on gRPC. +func NewGRPCPlugin(name string, addr string, opts ...plugin.Option) recorder.Recorder { var options plugin.Options for _, opt := range opts { opt(&options) @@ -38,7 +38,7 @@ func NewGRPCPluginRecorder(name string, addr string, opts ...plugin.Option) reco log.Error(err) } - p := &grpcPluginRecorder{ + p := &grpcPlugin{ conn: conn, log: log, } @@ -48,7 +48,7 @@ func NewGRPCPluginRecorder(name string, addr string, opts ...plugin.Option) reco return p } -func (p *grpcPluginRecorder) Record(ctx context.Context, b []byte) error { +func (p *grpcPlugin) Record(ctx context.Context, b []byte) error { if p.client == nil { return nil } @@ -64,36 +64,36 @@ func (p *grpcPluginRecorder) Record(ctx context.Context, b []byte) error { return nil } -func (p *grpcPluginRecorder) Close() error { +func (p *grpcPlugin) Close() error { if closer, ok := p.conn.(io.Closer); ok { return closer.Close() } return nil } -type httpRecorderRequest struct { +type httpPluginRequest struct { Data []byte `json:"data"` } -type httpRecorderResponse struct { +type httpPluginResponse struct { OK bool `json:"ok"` } -type httpPluginRecorder struct { +type httpPlugin struct { url string client *http.Client header http.Header log logger.Logger } -// NewHTTPPluginRecorder creates an Recorder plugin based on HTTP. -func NewHTTPPluginRecorder(name string, url string, opts ...plugin.Option) recorder.Recorder { +// NewHTTPPlugin creates an Recorder plugin based on HTTP. +func NewHTTPPlugin(name string, url string, opts ...plugin.Option) recorder.Recorder { var options plugin.Options for _, opt := range opts { opt(&options) } - return &httpPluginRecorder{ + return &httpPlugin{ url: url, client: plugin.NewHTTPClient(&options), header: options.Header, @@ -104,12 +104,12 @@ func NewHTTPPluginRecorder(name string, url string, opts ...plugin.Option) recor } } -func (p *httpPluginRecorder) Record(ctx context.Context, b []byte) error { +func (p *httpPlugin) Record(ctx context.Context, b []byte) error { if len(b) == 0 || p.client == nil { return nil } - rb := httpRecorderRequest{ + rb := httpPluginRequest{ Data: b, } v, err := json.Marshal(&rb) @@ -136,7 +136,7 @@ func (p *httpPluginRecorder) Record(ctx context.Context, b []byte) error { return fmt.Errorf("%s", resp.Status) } - res := httpRecorderResponse{} + res := httpPluginResponse{} if err := json.NewDecoder(resp.Body).Decode(&res); err != nil { return err } diff --git a/registry/hop.go b/registry/hop.go index 14e4bb4..2ae5415 100644 --- a/registry/hop.go +++ b/registry/hop.go @@ -4,24 +4,25 @@ import ( "context" "github.com/go-gost/core/chain" + "github.com/go-gost/core/hop" ) type hopRegistry struct { - registry[chain.Hop] + registry[hop.Hop] } -func (r *hopRegistry) Register(name string, v chain.Hop) error { +func (r *hopRegistry) Register(name string, v hop.Hop) error { return r.registry.Register(name, v) } -func (r *hopRegistry) Get(name string) chain.Hop { +func (r *hopRegistry) Get(name string) hop.Hop { if name != "" { return &hopWrapper{name: name, r: r} } return nil } -func (r *hopRegistry) get(name string) chain.Hop { +func (r *hopRegistry) get(name string) hop.Hop { return r.registry.Get(name) } @@ -35,10 +36,13 @@ func (w *hopWrapper) Nodes() []*chain.Node { if v == nil { return nil } - return v.Nodes() + if nl, ok := v.(hop.NodeList); ok { + return nl.Nodes() + } + return nil } -func (w *hopWrapper) Select(ctx context.Context, opts ...chain.SelectOption) *chain.Node { +func (w *hopWrapper) Select(ctx context.Context, opts ...hop.SelectOption) *chain.Node { v := w.r.get(w.name) if v == nil { return nil diff --git a/registry/registry.go b/registry/registry.go index 2d9e211..ba87bc0 100644 --- a/registry/registry.go +++ b/registry/registry.go @@ -9,6 +9,7 @@ import ( "github.com/go-gost/core/auth" "github.com/go-gost/core/bypass" "github.com/go-gost/core/chain" + "github.com/go-gost/core/hop" "github.com/go-gost/core/hosts" "github.com/go-gost/core/ingress" "github.com/go-gost/core/limiter/conn" @@ -31,7 +32,7 @@ var ( connectorReg reg.Registry[NewConnector] = new(connectorRegistry) serviceReg reg.Registry[service.Service] = new(serviceRegistry) chainReg reg.Registry[chain.Chainer] = new(chainRegistry) - hopReg reg.Registry[chain.Hop] = new(hopRegistry) + hopReg reg.Registry[hop.Hop] = new(hopRegistry) autherReg reg.Registry[auth.Authenticator] = new(autherRegistry) admissionReg reg.Registry[admission.Admission] = new(admissionRegistry) bypassReg reg.Registry[bypass.Bypass] = new(bypassRegistry) @@ -119,7 +120,7 @@ func ChainRegistry() reg.Registry[chain.Chainer] { return chainReg } -func HopRegistry() reg.Registry[chain.Hop] { +func HopRegistry() reg.Registry[hop.Hop] { return hopReg } diff --git a/resolver/plugin.go b/resolver/plugin.go index 9c9bc90..60c02be 100644 --- a/resolver/plugin.go +++ b/resolver/plugin.go @@ -14,18 +14,18 @@ import ( "github.com/go-gost/core/resolver" "github.com/go-gost/plugin/resolver/proto" auth_util "github.com/go-gost/x/internal/util/auth" - "github.com/go-gost/x/internal/util/plugin" + "github.com/go-gost/x/internal/plugin" "google.golang.org/grpc" ) -type grpcPluginResolver struct { +type grpcPlugin struct { conn grpc.ClientConnInterface client proto.ResolverClient log logger.Logger } -// NewGRPCPluginResolver creates a Resolver plugin based on gRPC. -func NewGRPCPluginResolver(name string, addr string, opts ...plugin.Option) (resolver.Resolver, error) { +// NewGRPCPlugin creates a Resolver plugin based on gRPC. +func NewGRPCPlugin(name string, addr string, opts ...plugin.Option) (resolver.Resolver, error) { var options plugin.Options for _, opt := range opts { opt(&options) @@ -39,7 +39,7 @@ func NewGRPCPluginResolver(name string, addr string, opts ...plugin.Option) (res if err != nil { log.Error(err) } - p := &grpcPluginResolver{ + p := &grpcPlugin{ conn: conn, log: log, } @@ -49,7 +49,7 @@ func NewGRPCPluginResolver(name string, addr string, opts ...plugin.Option) (res return p, nil } -func (p *grpcPluginResolver) Resolve(ctx context.Context, network, host string) (ips []net.IP, err error) { +func (p *grpcPlugin) Resolve(ctx context.Context, network, host string) (ips []net.IP, err error) { p.log.Debugf("resolve %s/%s", host, network) if p.client == nil { @@ -74,39 +74,39 @@ func (p *grpcPluginResolver) Resolve(ctx context.Context, network, host string) return } -func (p *grpcPluginResolver) Close() error { +func (p *grpcPlugin) Close() error { if closer, ok := p.conn.(io.Closer); ok { return closer.Close() } return nil } -type httpResolverRequest struct { +type httpPluginRequest struct { Network string `json:"network"` Host string `json:"host"` Client string `json:"client"` } -type httpResolverResponse struct { +type httpPluginResponse struct { IPs []string `json:"ips"` OK bool `json:"ok"` } -type httpPluginResolver struct { +type httpPlugin struct { url string client *http.Client header http.Header log logger.Logger } -// NewHTTPPluginResolver creates an Resolver plugin based on HTTP. -func NewHTTPPluginResolver(name string, url string, opts ...plugin.Option) resolver.Resolver { +// NewHTTPPlugin creates an Resolver plugin based on HTTP. +func NewHTTPPlugin(name string, url string, opts ...plugin.Option) resolver.Resolver { var options plugin.Options for _, opt := range opts { opt(&options) } - return &httpPluginResolver{ + return &httpPlugin{ url: url, client: plugin.NewHTTPClient(&options), header: options.Header, @@ -117,14 +117,14 @@ func NewHTTPPluginResolver(name string, url string, opts ...plugin.Option) resol } } -func (p *httpPluginResolver) Resolve(ctx context.Context, network, host string) (ips []net.IP, err error) { +func (p *httpPlugin) Resolve(ctx context.Context, network, host string) (ips []net.IP, err error) { p.log.Debugf("resolve %s/%s", host, network) if p.client == nil { return } - rb := httpResolverRequest{ + rb := httpPluginRequest{ Network: network, Host: host, Client: string(auth_util.IDFromContext(ctx)), @@ -154,7 +154,7 @@ func (p *httpPluginResolver) Resolve(ctx context.Context, network, host string) return } - res := httpResolverResponse{} + res := httpPluginResponse{} if err = json.NewDecoder(resp.Body).Decode(&res); err != nil { return } diff --git a/resolver/resolver.go b/resolver/resolver.go index 0ea44e0..b6e3aca 100644 --- a/resolver/resolver.go +++ b/resolver/resolver.go @@ -12,7 +12,6 @@ import ( resolver_util "github.com/go-gost/x/internal/util/resolver" "github.com/go-gost/x/resolver/exchanger" "github.com/miekg/dns" - "google.golang.org/grpc" ) type NameServer struct { @@ -28,7 +27,6 @@ type NameServer struct { type options struct { domain string - client *grpc.ClientConn logger logger.Logger } @@ -40,12 +38,6 @@ func DomainOption(domain string) Option { } } -func PluginConnOption(c *grpc.ClientConn) Option { - return func(opts *options) { - opts.client = c - } -} - func LoggerOption(logger logger.Logger) Option { return func(opts *options) { opts.logger = logger