diff --git a/admission/admission.go b/admission/admission.go index 35d5a80..5c9c546 100644 --- a/admission/admission.go +++ b/admission/admission.go @@ -9,7 +9,7 @@ import ( "sync" "time" - admission_pkg "github.com/go-gost/core/admission" + "github.com/go-gost/core/admission" "github.com/go-gost/core/logger" "github.com/go-gost/x/internal/loader" "github.com/go-gost/x/internal/matcher" @@ -69,7 +69,7 @@ func LoggerOption(logger logger.Logger) Option { } } -type admission struct { +type localAdmission struct { ipMatcher matcher.Matcher cidrMatcher matcher.Matcher mu sync.RWMutex @@ -79,14 +79,14 @@ type admission struct { // NewAdmission creates and initializes a new Admission using matcher patterns as its match rules. // The rules will be reversed if the reverse is true. -func NewAdmission(opts ...Option) admission_pkg.Admission { +func NewAdmission(opts ...Option) admission.Admission { var options options for _, opt := range opts { opt(&options) } ctx, cancel := context.WithCancel(context.TODO()) - p := &admission{ + p := &localAdmission{ cancelFunc: cancel, options: options, } @@ -101,7 +101,7 @@ func NewAdmission(opts ...Option) admission_pkg.Admission { return p } -func (p *admission) Admit(ctx context.Context, addr string) bool { +func (p *localAdmission) Admit(ctx context.Context, addr string) bool { if addr == "" || p == nil { return true } @@ -117,7 +117,7 @@ func (p *admission) Admit(ctx context.Context, addr string) bool { p.options.whitelist && matched } -func (p *admission) periodReload(ctx context.Context) error { +func (p *localAdmission) periodReload(ctx context.Context) error { period := p.options.period if period < time.Second { period = time.Second @@ -138,7 +138,7 @@ func (p *admission) periodReload(ctx context.Context) error { } } -func (p *admission) reload(ctx context.Context) error { +func (p *localAdmission) reload(ctx context.Context) error { v, err := p.load(ctx) if err != nil { return err @@ -167,7 +167,7 @@ func (p *admission) reload(ctx context.Context) error { return nil } -func (p *admission) load(ctx context.Context) (patterns []string, err error) { +func (p *localAdmission) load(ctx context.Context) (patterns []string, err error) { if p.options.fileLoader != nil { if lister, ok := p.options.fileLoader.(loader.Lister); ok { list, er := lister.List(ctx) @@ -221,7 +221,7 @@ func (p *admission) load(ctx context.Context) (patterns []string, err error) { return } -func (p *admission) parsePatterns(r io.Reader) (patterns []string, err error) { +func (p *localAdmission) parsePatterns(r io.Reader) (patterns []string, err error) { if r == nil { return } @@ -237,14 +237,14 @@ func (p *admission) parsePatterns(r io.Reader) (patterns []string, err error) { return } -func (p *admission) parseLine(s string) string { +func (p *localAdmission) parseLine(s string) string { if n := strings.IndexByte(s, '#'); n >= 0 { s = s[:n] } return strings.TrimSpace(s) } -func (p *admission) matched(addr string) bool { +func (p *localAdmission) matched(addr string) bool { p.mu.RLock() defer p.mu.RUnlock() @@ -252,7 +252,7 @@ func (p *admission) matched(addr string) bool { p.cidrMatcher.Match(addr) } -func (p *admission) Close() error { +func (p *localAdmission) Close() error { p.cancelFunc() if p.options.fileLoader != nil { p.options.fileLoader.Close() diff --git a/admission/plugin.go b/admission/plugin.go index 81540ee..c9ce744 100644 --- a/admission/plugin.go +++ b/admission/plugin.go @@ -1,12 +1,16 @@ package admission import ( + "bytes" "context" + "encoding/json" "io" + "net/http" - admission_pkg "github.com/go-gost/core/admission" + "github.com/go-gost/core/admission" "github.com/go-gost/core/logger" "github.com/go-gost/plugin/admission/proto" + "github.com/go-gost/x/internal/util/plugin" "google.golang.org/grpc" ) @@ -17,13 +21,24 @@ type grpcPluginAdmission struct { } // NewGRPCPluginAdmission creates an Admission plugin based on gRPC. -func NewGRPCPluginAdmission(name string, conn grpc.ClientConnInterface) admission_pkg.Admission { +func NewGRPCPluginAdmission(name string, addr string, opts ...plugin.Option) admission.Admission { + var options plugin.Options + for _, opt := range opts { + opt(&options) + } + + log := logger.Default().WithFields(map[string]any{ + "kind": "admission", + "admission": name, + }) + conn, err := plugin.NewGRPCConn(addr, &options) + if err != nil { + log.Error(err) + } + p := &grpcPluginAdmission{ conn: conn, - log: logger.Default().WithFields(map[string]any{ - "kind": "admission", - "admission": name, - }), + log: log, } if conn != nil { p.client = proto.NewAdmissionClient(conn) @@ -53,3 +68,75 @@ func (p *grpcPluginAdmission) Close() error { } return nil } + +type httpAdmissionRequest struct { + Addr string `json:"addr"` +} + +type httpAdmissionResponse struct { + OK bool `json:"ok"` +} + +type httpPluginAdmission struct { + url string + client *http.Client + header http.Header + log logger.Logger +} + +// NewHTTPPluginAdmission creates an Admission plugin based on HTTP. +func NewHTTPPluginAdmission(name string, url string, opts ...plugin.Option) admission.Admission { + var options plugin.Options + for _, opt := range opts { + opt(&options) + } + + return &httpPluginAdmission{ + url: url, + client: plugin.NewHTTPClient(&options), + header: options.Header, + log: logger.Default().WithFields(map[string]any{ + "kind": "admission", + "admission": name, + }), + } +} + +func (p *httpPluginAdmission) Admit(ctx context.Context, addr string) (ok bool) { + if p.client == nil { + return + } + + rb := httpAdmissionRequest{ + Addr: addr, + } + v, err := json.Marshal(&rb) + if err != nil { + return + } + + req, err := http.NewRequest(http.MethodPost, p.url, bytes.NewReader(v)) + if err != nil { + return + } + + if p.header != nil { + req.Header = p.header.Clone() + } + req.Header.Set("Content-Type", "application/json") + resp, err := p.client.Do(req) + if err != nil { + return + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return + } + + res := httpAdmissionResponse{} + if err := json.NewDecoder(resp.Body).Decode(&res); err != nil { + return + } + return res.OK +} diff --git a/admission/wrapper/listener.go b/admission/wrapper/listener.go index 1c2d7b1..55d4e40 100644 --- a/admission/wrapper/listener.go +++ b/admission/wrapper/listener.go @@ -5,11 +5,13 @@ import ( "net" "github.com/go-gost/core/admission" + "github.com/go-gost/core/logger" ) type listener struct { net.Listener admission admission.Admission + log logger.Logger } func WrapListener(admission admission.Admission, ln net.Listener) net.Listener { diff --git a/api/middleware.go b/api/middleware.go index 5b9006a..c350b1c 100644 --- a/api/middleware.go +++ b/api/middleware.go @@ -35,7 +35,7 @@ func mwBasicAuth(auther auth.Authenticator) gin.HandlerFunc { return } u, p, _ := c.Request.BasicAuth() - if ok, _ := auther.Authenticate(c, u, p); !ok { + if _, ok := auther.Authenticate(c, u, p); !ok { c.AbortWithStatus(http.StatusUnauthorized) } } diff --git a/auth/auth.go b/auth/auth.go index cfc9e4b..49f7be9 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -97,20 +97,20 @@ func NewAuthenticator(opts ...Option) auth.Authenticator { } // Authenticate checks the validity of the provided user-password pair. -func (p *authenticator) Authenticate(ctx context.Context, user, password string) (bool, string) { +func (p *authenticator) Authenticate(ctx context.Context, user, password string) (string, bool) { if p == nil { - return true, "" + return "", true } p.mu.RLock() defer p.mu.RUnlock() if len(p.kvs) == 0 { - return false, "" + return "", false } v, ok := p.kvs[user] - return ok && (v == "" || password == v), "" + return "", ok && (v == "" || password == v) } func (p *authenticator) periodReload(ctx context.Context) error { diff --git a/auth/plugin.go b/auth/plugin.go index 66d0f52..c7085fd 100644 --- a/auth/plugin.go +++ b/auth/plugin.go @@ -1,12 +1,16 @@ package auth import ( + "bytes" "context" + "encoding/json" "io" + "net/http" "github.com/go-gost/core/auth" "github.com/go-gost/core/logger" "github.com/go-gost/plugin/auth/proto" + "github.com/go-gost/x/internal/util/plugin" "google.golang.org/grpc" ) @@ -17,13 +21,24 @@ type grpcPluginAuthenticator struct { } // NewGRPCPluginAuthenticator creates an Authenticator plugin based on gRPC. -func NewGRPCPluginAuthenticator(name string, conn grpc.ClientConnInterface) auth.Authenticator { +func NewGRPCPluginAuthenticator(name string, addr string, opts ...plugin.Option) auth.Authenticator { + var options plugin.Options + for _, opt := range opts { + opt(&options) + } + + log := logger.Default().WithFields(map[string]any{ + "kind": "auther", + "auther": name, + }) + conn, err := plugin.NewGRPCConn(addr, &options) + if err != nil { + log.Error(err) + } + p := &grpcPluginAuthenticator{ conn: conn, - log: logger.Default().WithFields(map[string]any{ - "kind": "auther", - "auther": name, - }), + log: log, } if conn != nil { @@ -33,9 +48,9 @@ func NewGRPCPluginAuthenticator(name string, conn grpc.ClientConnInterface) auth } // Authenticate checks the validity of the provided user-password pair. -func (p *grpcPluginAuthenticator) Authenticate(ctx context.Context, user, password string) (bool, string) { +func (p *grpcPluginAuthenticator) Authenticate(ctx context.Context, user, password string) (string, bool) { if p.client == nil { - return false, "" + return "", false } r, err := p.client.Authenticate(ctx, @@ -45,9 +60,9 @@ func (p *grpcPluginAuthenticator) Authenticate(ctx context.Context, user, passwo }) if err != nil { p.log.Error(err) - return false, "" + return "", false } - return r.Ok, r.Id + return r.Id, r.Ok } func (p *grpcPluginAuthenticator) Close() error { @@ -56,3 +71,78 @@ func (p *grpcPluginAuthenticator) Close() error { } return nil } + +type httpAutherRequest struct { + Username string `json:"username"` + Password string `json:"password"` +} + +type httpAutherResponse struct { + OK bool `json:"ok"` + ID string `json:"id"` +} + +type httpPluginAuther struct { + url string + client *http.Client + header http.Header + log logger.Logger +} + +// NewHTTPPluginAuthenticator creates an Authenticator plugin based on HTTP. +func NewHTTPPluginAuthenticator(name string, url string, opts ...plugin.Option) auth.Authenticator { + var options plugin.Options + for _, opt := range opts { + opt(&options) + } + + return &httpPluginAuther{ + url: url, + client: plugin.NewHTTPClient(&options), + header: options.Header, + log: logger.Default().WithFields(map[string]any{ + "kind": "auther", + "auther": name, + }), + } +} + +func (p *httpPluginAuther) Authenticate(ctx context.Context, user, password string) (id string, ok bool) { + if p.client == nil { + return + } + + rb := httpAutherRequest{ + Username: user, + Password: password, + } + v, err := json.Marshal(&rb) + if err != nil { + return + } + + req, err := http.NewRequest(http.MethodPost, p.url, bytes.NewReader(v)) + if err != nil { + return + } + + if p.header != nil { + req.Header = p.header.Clone() + } + req.Header.Set("Content-Type", "application/json") + resp, err := p.client.Do(req) + if err != nil { + return + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return + } + + res := httpAutherResponse{} + if err := json.NewDecoder(resp.Body).Decode(&res); err != nil { + return + } + return res.ID, res.OK +} diff --git a/bypass/bypass.go b/bypass/bypass.go index fe7d06d..6d300b8 100644 --- a/bypass/bypass.go +++ b/bypass/bypass.go @@ -9,7 +9,7 @@ import ( "sync" "time" - bypass_pkg "github.com/go-gost/core/bypass" + "github.com/go-gost/core/bypass" "github.com/go-gost/core/logger" "github.com/go-gost/x/internal/loader" "github.com/go-gost/x/internal/matcher" @@ -77,7 +77,7 @@ func LoggerOption(logger logger.Logger) Option { } } -type bypass struct { +type localBypass struct { ipMatcher matcher.Matcher cidrMatcher matcher.Matcher domainMatcher matcher.Matcher @@ -89,7 +89,7 @@ type bypass struct { // NewBypass creates and initializes a new Bypass. // The rules will be reversed if the reverse option is true. -func NewBypass(opts ...Option) bypass_pkg.Bypass { +func NewBypass(opts ...Option) bypass.Bypass { var options options for _, opt := range opts { opt(&options) @@ -97,7 +97,7 @@ func NewBypass(opts ...Option) bypass_pkg.Bypass { ctx, cancel := context.WithCancel(context.TODO()) - bp := &bypass{ + bp := &localBypass{ cancelFunc: cancel, options: options, } @@ -112,7 +112,7 @@ func NewBypass(opts ...Option) bypass_pkg.Bypass { return bp } -func (bp *bypass) periodReload(ctx context.Context) error { +func (bp *localBypass) periodReload(ctx context.Context) error { period := bp.options.period if period < time.Second { period = time.Second @@ -133,7 +133,7 @@ func (bp *bypass) periodReload(ctx context.Context) error { } } -func (bp *bypass) reload(ctx context.Context) error { +func (bp *localBypass) reload(ctx context.Context) error { v, err := bp.load(ctx) if err != nil { return err @@ -171,7 +171,7 @@ func (bp *bypass) reload(ctx context.Context) error { return nil } -func (bp *bypass) load(ctx context.Context) (patterns []string, err error) { +func (bp *localBypass) load(ctx context.Context) (patterns []string, err error) { if bp.options.fileLoader != nil { if lister, ok := bp.options.fileLoader.(loader.Lister); ok { list, er := lister.List(ctx) @@ -224,7 +224,7 @@ func (bp *bypass) load(ctx context.Context) (patterns []string, err error) { return } -func (bp *bypass) parsePatterns(r io.Reader) (patterns []string, err error) { +func (bp *localBypass) parsePatterns(r io.Reader) (patterns []string, err error) { if r == nil { return } @@ -240,7 +240,7 @@ func (bp *bypass) parsePatterns(r io.Reader) (patterns []string, err error) { return } -func (bp *bypass) Contains(ctx context.Context, addr string) bool { +func (bp *localBypass) Contains(ctx context.Context, addr string) bool { if addr == "" || bp == nil { return false } @@ -260,14 +260,14 @@ func (bp *bypass) Contains(ctx context.Context, addr string) bool { return b } -func (bp *bypass) parseLine(s string) string { +func (bp *localBypass) parseLine(s string) string { if n := strings.IndexByte(s, '#'); n >= 0 { s = s[:n] } return strings.TrimSpace(s) } -func (bp *bypass) matched(addr string) bool { +func (bp *localBypass) matched(addr string) bool { bp.mu.RLock() defer bp.mu.RUnlock() @@ -280,7 +280,7 @@ func (bp *bypass) matched(addr string) bool { bp.wildcardMatcher.Match(addr) } -func (bp *bypass) Close() error { +func (bp *localBypass) Close() error { bp.cancelFunc() if bp.options.fileLoader != nil { bp.options.fileLoader.Close() diff --git a/bypass/plugin.go b/bypass/plugin.go index dc80934..88dfafe 100644 --- a/bypass/plugin.go +++ b/bypass/plugin.go @@ -1,13 +1,17 @@ package bypass import ( + "bytes" "context" + "encoding/json" "io" + "net/http" - bypass_pkg "github.com/go-gost/core/bypass" + "github.com/go-gost/core/bypass" "github.com/go-gost/core/logger" "github.com/go-gost/plugin/bypass/proto" auth_util "github.com/go-gost/x/internal/util/auth" + "github.com/go-gost/x/internal/util/plugin" "google.golang.org/grpc" ) @@ -18,13 +22,24 @@ type grpcPluginBypass struct { } // NewGRPCPluginBypass creates a Bypass plugin based on gRPC. -func NewGRPCPluginBypass(name string, conn grpc.ClientConnInterface) bypass_pkg.Bypass { +func NewGRPCPluginBypass(name string, addr string, opts ...plugin.Option) bypass.Bypass { + var options plugin.Options + for _, opt := range opts { + opt(&options) + } + + log := logger.Default().WithFields(map[string]any{ + "kind": "bypass", + "bypass": name, + }) + conn, err := plugin.NewGRPCConn(addr, &options) + if err != nil { + log.Error(err) + } + p := &grpcPluginBypass{ conn: conn, - log: logger.Default().WithFields(map[string]any{ - "kind": "bypass", - "bypass": name, - }), + log: log, } if conn != nil { p.client = proto.NewBypassClient(conn) @@ -55,3 +70,77 @@ func (p *grpcPluginBypass) Close() error { } return nil } + +type httpBypassRequest struct { + Addr string `json:"addr"` + Client string `json:"client"` +} + +type httpBypassResponse struct { + OK bool `json:"ok"` +} + +type httpPluginBypass struct { + url string + client *http.Client + header http.Header + log logger.Logger +} + +// NewHTTPPluginBypass creates an Bypass plugin based on HTTP. +func NewHTTPPluginBypass(name string, url string, opts ...plugin.Option) bypass.Bypass { + var options plugin.Options + for _, opt := range opts { + opt(&options) + } + + return &httpPluginBypass{ + url: url, + client: plugin.NewHTTPClient(&options), + header: options.Header, + log: logger.Default().WithFields(map[string]any{ + "kind": "bypass", + "bypass": name, + }), + } +} + +func (p *httpPluginBypass) Contains(ctx context.Context, addr string) (ok bool) { + if p.client == nil { + return + } + + rb := httpBypassRequest{ + Addr: addr, + Client: string(auth_util.IDFromContext(ctx)), + } + v, err := json.Marshal(&rb) + if err != nil { + return + } + + req, err := http.NewRequest(http.MethodPost, p.url, bytes.NewReader(v)) + if err != nil { + return + } + + if p.header != nil { + req.Header = p.header.Clone() + } + req.Header.Set("Content-Type", "application/json") + resp, err := p.client.Do(req) + if err != nil { + return + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return + } + + res := httpBypassResponse{} + if err := json.NewDecoder(resp.Body).Decode(&res); err != nil { + return + } + return res.OK +} diff --git a/config/config.go b/config/config.go index d7d2f86..a34b566 100644 --- a/config/config.go +++ b/config/config.go @@ -110,9 +110,11 @@ type TLSConfig struct { } type PluginConfig struct { - Addr string `json:"addr"` - TLS *TLSConfig `yaml:",omitempty" json:"tls,omitempty"` - Token string `yaml:",omitempty" json:"token,omitempty"` + Type string `json:"type"` + Addr string `json:"addr"` + TLS *TLSConfig `yaml:",omitempty" json:"tls,omitempty"` + Timeout time.Duration `yaml:",omitempty" json:"timeout,omitempty"` + Token string `yaml:",omitempty" json:"token,omitempty"` } type AutherConfig struct { diff --git a/config/parsing/parse.go b/config/parsing/parse.go index ab09845..77a4ef5 100644 --- a/config/parsing/parse.go +++ b/config/parsing/parse.go @@ -4,7 +4,9 @@ import ( "context" "crypto/tls" "net" + "net/http" "net/url" + "strings" "github.com/go-gost/core/admission" "github.com/go-gost/core/auth" @@ -26,6 +28,7 @@ import ( xhosts "github.com/go-gost/x/hosts" xingress "github.com/go-gost/x/ingress" "github.com/go-gost/x/internal/loader" + "github.com/go-gost/x/internal/util/plugin" xconn "github.com/go-gost/x/limiter/conn" xrate "github.com/go-gost/x/limiter/rate" xtraffic "github.com/go-gost/x/limiter/traffic" @@ -61,11 +64,27 @@ func ParseAuther(cfg *config.AutherConfig) auth.Authenticator { } if cfg.Plugin != nil { - c, err := newPluginConn(cfg.Plugin) - if err != nil { - logger.Default().Error(err) + var tlsCfg *tls.Config + if cfg.Plugin.TLS != nil { + tlsCfg = &tls.Config{ + ServerName: cfg.Plugin.TLS.ServerName, + InsecureSkipVerify: !cfg.Plugin.TLS.Secure, + } + } + switch cfg.Plugin.Type { + case "http": + return auth_impl.NewHTTPPluginAuthenticator( + cfg.Name, cfg.Plugin.Addr, + plugin.TLSConfigOption(tlsCfg), + plugin.TimeoutOption(cfg.Plugin.Timeout), + ) + default: + return auth_impl.NewGRPCPluginAuthenticator( + cfg.Name, cfg.Plugin.Addr, + plugin.TokenOption(cfg.Plugin.Token), + plugin.TLSConfigOption(tlsCfg), + ) } - return auth_impl.NewGRPCPluginAuthenticator(cfg.Name, c) } m := make(map[string]string) @@ -189,11 +208,27 @@ func ParseAdmission(cfg *config.AdmissionConfig) admission.Admission { } if cfg.Plugin != nil { - c, err := newPluginConn(cfg.Plugin) - if err != nil { - logger.Default().Error(err) + var tlsCfg *tls.Config + if cfg.Plugin.TLS != nil { + tlsCfg = &tls.Config{ + ServerName: cfg.Plugin.TLS.ServerName, + InsecureSkipVerify: !cfg.Plugin.TLS.Secure, + } + } + switch strings.ToLower(cfg.Plugin.Type) { + case "http": + return admission_impl.NewHTTPPluginAdmission( + cfg.Name, cfg.Plugin.Addr, + plugin.TLSConfigOption(tlsCfg), + plugin.TimeoutOption(cfg.Plugin.Timeout), + ) + default: + return admission_impl.NewGRPCPluginAdmission( + cfg.Name, cfg.Plugin.Addr, + plugin.TokenOption(cfg.Plugin.Token), + plugin.TLSConfigOption(tlsCfg), + ) } - return admission_impl.NewGRPCPluginAdmission(cfg.Name, c) } opts := []admission_impl.Option{ @@ -232,11 +267,27 @@ func ParseBypass(cfg *config.BypassConfig) bypass.Bypass { } if cfg.Plugin != nil { - c, err := newPluginConn(cfg.Plugin) - if err != nil { - logger.Default().Error(err) + var tlsCfg *tls.Config + if cfg.Plugin.TLS != nil { + tlsCfg = &tls.Config{ + ServerName: cfg.Plugin.TLS.ServerName, + InsecureSkipVerify: !cfg.Plugin.TLS.Secure, + } + } + switch strings.ToLower(cfg.Plugin.Type) { + case "http": + return bypass_impl.NewHTTPPluginBypass( + cfg.Name, cfg.Plugin.Addr, + plugin.TLSConfigOption(tlsCfg), + plugin.TimeoutOption(cfg.Plugin.Timeout), + ) + default: + return bypass_impl.NewGRPCPluginBypass( + cfg.Name, cfg.Plugin.Addr, + plugin.TokenOption(cfg.Plugin.Token), + plugin.TLSConfigOption(tlsCfg), + ) } - return bypass_impl.NewGRPCPluginBypass(cfg.Name, c) } opts := []bypass_impl.Option{ @@ -275,12 +326,27 @@ func ParseResolver(cfg *config.ResolverConfig) (resolver.Resolver, error) { } if cfg.Plugin != nil { - c, err := newPluginConn(cfg.Plugin) - if err != nil { - logger.Default().Error(err) - return nil, err + var tlsCfg *tls.Config + if cfg.Plugin.TLS != nil { + tlsCfg = &tls.Config{ + ServerName: cfg.Plugin.TLS.ServerName, + InsecureSkipVerify: !cfg.Plugin.TLS.Secure, + } + } + switch strings.ToLower(cfg.Plugin.Type) { + case "http": + return resolver_impl.NewHTTPPluginResolver( + cfg.Name, cfg.Plugin.Addr, + plugin.TLSConfigOption(tlsCfg), + plugin.TimeoutOption(cfg.Plugin.Timeout), + ), nil + default: + return resolver_impl.NewGRPCPluginResolver( + cfg.Name, cfg.Plugin.Addr, + plugin.TokenOption(cfg.Plugin.Token), + plugin.TLSConfigOption(tlsCfg), + ) } - return resolver_impl.NewGRPCPluginResolver(cfg.Name, c) } var nameservers []resolver_impl.NameServer @@ -313,11 +379,27 @@ func ParseHosts(cfg *config.HostsConfig) hosts.HostMapper { } if cfg.Plugin != nil { - c, err := newPluginConn(cfg.Plugin) - if err != nil { - logger.Default().Error(err) + var tlsCfg *tls.Config + if cfg.Plugin.TLS != nil { + tlsCfg = &tls.Config{ + ServerName: cfg.Plugin.TLS.ServerName, + InsecureSkipVerify: !cfg.Plugin.TLS.Secure, + } + } + switch strings.ToLower(cfg.Plugin.Type) { + case "http": + return xhosts.NewHTTPPluginHostMapper( + cfg.Name, cfg.Plugin.Addr, + plugin.TLSConfigOption(tlsCfg), + plugin.TimeoutOption(cfg.Plugin.Timeout), + ) + default: + return xhosts.NewGRPCPluginHostMapper( + cfg.Name, cfg.Plugin.Addr, + plugin.TokenOption(cfg.Plugin.Token), + plugin.TLSConfigOption(tlsCfg), + ) } - return xhosts.NewGRPCPluginHostMapper(cfg.Name, c) } var mappings []xhosts.Mapping @@ -379,11 +461,27 @@ func ParseIngress(cfg *config.IngressConfig) ingress.Ingress { } if cfg.Plugin != nil { - c, err := newPluginConn(cfg.Plugin) - if err != nil { - logger.Default().Error(err) + var tlsCfg *tls.Config + if cfg.Plugin.TLS != nil { + tlsCfg = &tls.Config{ + ServerName: cfg.Plugin.TLS.ServerName, + InsecureSkipVerify: !cfg.Plugin.TLS.Secure, + } + } + switch strings.ToLower(cfg.Plugin.Type) { + case "http": + return xingress.NewHTTPPluginIngress( + cfg.Name, cfg.Plugin.Addr, + plugin.TLSConfigOption(tlsCfg), + plugin.TimeoutOption(cfg.Plugin.Timeout), + ) + default: + return xingress.NewGRPCPluginIngress( + cfg.Name, cfg.Plugin.Addr, + plugin.TokenOption(cfg.Plugin.Token), + plugin.TLSConfigOption(tlsCfg), + ) } - return xingress.NewGRPCPluginIngress(cfg.Name, c) } var rules []xingress.Rule @@ -441,11 +539,27 @@ func ParseRecorder(cfg *config.RecorderConfig) (r recorder.Recorder) { } if cfg.Plugin != nil { - c, err := newPluginConn(cfg.Plugin) - if err != nil { - logger.Default().Error(err) + var tlsCfg *tls.Config + if cfg.Plugin.TLS != nil { + tlsCfg = &tls.Config{ + ServerName: cfg.Plugin.TLS.ServerName, + InsecureSkipVerify: !cfg.Plugin.TLS.Secure, + } + } + switch strings.ToLower(cfg.Plugin.Type) { + case "http": + return xrecorder.NewHTTPPluginRecorder( + cfg.Name, cfg.Plugin.Addr, + plugin.TLSConfigOption(tlsCfg), + plugin.TimeoutOption(cfg.Plugin.Timeout), + ) + default: + return xrecorder.NewGRPCPluginRecorder( + cfg.Name, cfg.Plugin.Addr, + plugin.TokenOption(cfg.Plugin.Token), + plugin.TLSConfigOption(tlsCfg), + ) } - return xrecorder.NewGRPCPluginRecorder(cfg.Name, c) } if cfg.File != nil && cfg.File.Path != "" { @@ -644,7 +758,7 @@ func ParseRateLimiter(cfg *config.LimiterConfig) (lim rate.RateLimiter) { return xrate.NewRateLimiter(opts...) } -func newPluginConn(cfg *config.PluginConfig) (*grpc.ClientConn, error) { +func newGRPCPluginConn(cfg *config.PluginConfig) (*grpc.ClientConn, error) { grpcOpts := []grpc.DialOption{ // grpc.WithBlock(), grpc.WithConnectParams(grpc.ConnectParams{ @@ -681,3 +795,26 @@ func (c *rpcCredentials) GetRequestMetadata(ctx context.Context, uri ...string) func (c *rpcCredentials) RequireTransportSecurity() bool { return false } + +func newHTTPPluginClient(cfg *config.PluginConfig) *http.Client { + if cfg == nil { + return nil + } + + tr := &http.Transport{} + if cfg.TLS != nil { + if cfg.TLS.Secure { + tr.TLSClientConfig = &tls.Config{ + ServerName: cfg.TLS.ServerName, + } + } else { + tr.TLSClientConfig = &tls.Config{ + InsecureSkipVerify: true, + } + } + } + return &http.Client{ + Timeout: cfg.Timeout, + Transport: tr, + } +} diff --git a/go.mod b/go.mod index bee5063..ebc0b91 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d github.com/gin-contrib/cors v1.3.1 github.com/gin-gonic/gin v1.9.1 - github.com/go-gost/core v0.0.0-20230919141921-a1419ec2f4d1 + github.com/go-gost/core v0.0.0-20230920145336-6d0e88635be9 github.com/go-gost/gosocks4 v0.0.1 github.com/go-gost/gosocks5 v0.4.0 github.com/go-gost/plugin v0.0.0-20230919143240-0e42c7c67eaa diff --git a/go.sum b/go.sum index 00e8301..2599a42 100644 --- a/go.sum +++ b/go.sum @@ -103,6 +103,8 @@ github.com/go-gost/core v0.0.0-20230918131208-c258a114c40b h1:kqALaNXbbYyKFlcLj3 github.com/go-gost/core v0.0.0-20230918131208-c258a114c40b/go.mod h1:db6lBY+DkC3ct4OJfclsKnQwQmcv1B9NnMnpI2MNUwY= github.com/go-gost/core v0.0.0-20230919141921-a1419ec2f4d1 h1:tV5Ra3bBU5R9Mcg9lGzMQeVeLcnCFEEiE6UNnx6F46k= github.com/go-gost/core v0.0.0-20230919141921-a1419ec2f4d1/go.mod h1:ndkgWVYRLwupVaFFWv8ML1Nr8tD3xhHK245PLpUDg4E= +github.com/go-gost/core v0.0.0-20230920145336-6d0e88635be9 h1:VHka8LcdBJmM7Yv2bjQO5kctF0T9O4E/PVzgkdk0Vdo= +github.com/go-gost/core v0.0.0-20230920145336-6d0e88635be9/go.mod h1:ndkgWVYRLwupVaFFWv8ML1Nr8tD3xhHK245PLpUDg4E= github.com/go-gost/gosocks4 v0.0.1 h1:+k1sec8HlELuQV7rWftIkmy8UijzUt2I6t+iMPlGB2s= github.com/go-gost/gosocks4 v0.0.1/go.mod h1:3B6L47HbU/qugDg4JnoFPHgJXE43Inz8Bah1QaN9qCc= github.com/go-gost/gosocks5 v0.4.0 h1:EIrOEkpJez4gwHrMa33frA+hHXJyevjp47thpMQsJzI= diff --git a/handler/forward/local/handler.go b/handler/forward/local/handler.go index 445c129..18961f4 100644 --- a/handler/forward/local/handler.go +++ b/handler/forward/local/handler.go @@ -209,7 +209,7 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, log l if auther := target.Options().Auther; auther != nil { username, password, _ := req.BasicAuth() - ok, id := auther.Authenticate(ctx, username, password) + id, ok := auther.Authenticate(ctx, username, password) if !ok { resp.StatusCode = http.StatusUnauthorized resp.Header.Set("WWW-Authenticate", "Basic") diff --git a/handler/forward/remote/handler.go b/handler/forward/remote/handler.go index a6bce36..1ae9905 100644 --- a/handler/forward/remote/handler.go +++ b/handler/forward/remote/handler.go @@ -206,7 +206,7 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, log l if auther := target.Options().Auther; auther != nil { username, password, _ := req.BasicAuth() - ok, id := auther.Authenticate(ctx, username, password) + id, ok := auther.Authenticate(ctx, username, password) if !ok { resp.StatusCode = http.StatusUnauthorized resp.Header.Set("WWW-Authenticate", "Basic") diff --git a/handler/http/handler.go b/handler/http/handler.go index 28a4ab7..7187634 100644 --- a/handler/http/handler.go +++ b/handler/http/handler.go @@ -146,7 +146,7 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt resp.Header = http.Header{} } - ok, id := h.authenticate(ctx, conn, req, resp, log) + id, ok := h.authenticate(ctx, conn, req, resp, log) if !ok { return nil } @@ -269,12 +269,12 @@ func (h *httpHandler) basicProxyAuth(proxyAuth string, log logger.Logger) (usern return cs[:s], cs[s+1:], true } -func (h *httpHandler) authenticate(ctx context.Context, conn net.Conn, req *http.Request, resp *http.Response, log logger.Logger) (ok bool, token string) { +func (h *httpHandler) authenticate(ctx context.Context, conn net.Conn, req *http.Request, resp *http.Response, log logger.Logger) (id string, ok bool) { u, p, _ := h.basicProxyAuth(req.Header.Get("Proxy-Authorization"), log) if h.options.Auther == nil { - return true, "" + return "", true } - if ok, token = h.options.Auther.Authenticate(ctx, u, p); ok { + if id, ok = h.options.Auther.Authenticate(ctx, u, p); ok { return } diff --git a/handler/http2/handler.go b/handler/http2/handler.go index 9aaf4bf..5e1ce31 100644 --- a/handler/http2/handler.go +++ b/handler/http2/handler.go @@ -146,7 +146,7 @@ func (h *http2Handler) roundTrip(ctx context.Context, w http.ResponseWriter, req Body: io.NopCloser(bytes.NewReader([]byte{})), } - ok, id := h.authenticate(ctx, w, req, resp, log) + id, ok := h.authenticate(ctx, w, req, resp, log) if !ok { return nil } @@ -254,12 +254,12 @@ func (h *http2Handler) basicProxyAuth(proxyAuth string) (username, password stri return cs[:s], cs[s+1:], true } -func (h *http2Handler) authenticate(ctx context.Context, w http.ResponseWriter, r *http.Request, resp *http.Response, log logger.Logger) (ok bool, token string) { +func (h *http2Handler) authenticate(ctx context.Context, w http.ResponseWriter, r *http.Request, resp *http.Response, log logger.Logger) (id string, ok bool) { u, p, _ := h.basicProxyAuth(r.Header.Get("Proxy-Authorization")) if h.options.Auther == nil { - return true, "" + return "", true } - if ok, token = h.options.Auther.Authenticate(ctx, u, p); ok { + if id, ok = h.options.Auther.Authenticate(ctx, u, p); ok { return } diff --git a/handler/relay/handler.go b/handler/relay/handler.go index c4946e7..d89fd88 100644 --- a/handler/relay/handler.go +++ b/handler/relay/handler.go @@ -202,7 +202,7 @@ func (h *relayHandler) Handle(ctx context.Context, conn net.Conn, opts ...handle } if h.options.Auther != nil { - ok, id := h.options.Auther.Authenticate(ctx, user, pass) + id, ok := h.options.Auther.Authenticate(ctx, user, pass) if !ok { resp.Status = relay.StatusUnauthorized resp.WriteTo(conn) diff --git a/handler/socks/v4/handler.go b/handler/socks/v4/handler.go index 0163920..f7a716c 100644 --- a/handler/socks/v4/handler.go +++ b/handler/socks/v4/handler.go @@ -92,7 +92,7 @@ func (h *socks4Handler) Handle(ctx context.Context, conn net.Conn, opts ...handl conn.SetReadDeadline(time.Time{}) if h.options.Auther != nil { - ok, id := h.options.Auther.Authenticate(ctx, string(req.Userid), "") + id, ok := h.options.Auther.Authenticate(ctx, string(req.Userid), "") if !ok { resp := gosocks4.NewReply(gosocks4.RejectedUserid, nil) log.Trace(resp) diff --git a/handler/socks/v5/selector.go b/handler/socks/v5/selector.go index 8f6b271..0a41cf4 100644 --- a/handler/socks/v5/selector.go +++ b/handler/socks/v5/selector.go @@ -68,7 +68,7 @@ func (s *serverSelector) OnSelected(method uint8, conn net.Conn) (string, net.Co var id string if s.Authenticator != nil { var ok bool - ok, id = s.Authenticator.Authenticate(context.Background(), req.Username, req.Password) + id, ok = s.Authenticator.Authenticate(context.Background(), req.Username, req.Password) if !ok { resp := gosocks5.NewUserPassResponse(gosocks5.UserPassVer, gosocks5.Failure) if err := resp.Write(conn); err != nil { diff --git a/handler/tun/server.go b/handler/tun/server.go index caa22b4..cc30a5b 100644 --- a/handler/tun/server.go +++ b/handler/tun/server.go @@ -135,7 +135,7 @@ func (h *tunHandler) transportServer(ctx context.Context, tun io.ReadWriter, con ok := true key := bytes.TrimRight((*b)[4:20], "\x00") for _, ip := range peerIPs { - if ok, _ = auther.Authenticate(ctx, ip.String(), string(key)); !ok { + if _, ok = auther.Authenticate(ctx, ip.String(), string(key)); !ok { break } } diff --git a/hosts/plugin.go b/hosts/plugin.go index e414652..fcc09c5 100644 --- a/hosts/plugin.go +++ b/hosts/plugin.go @@ -1,14 +1,18 @@ package hosts import ( + "bytes" "context" + "encoding/json" "io" "net" + "net/http" "github.com/go-gost/core/hosts" "github.com/go-gost/core/logger" "github.com/go-gost/plugin/hosts/proto" auth_util "github.com/go-gost/x/internal/util/auth" + "github.com/go-gost/x/internal/util/plugin" "google.golang.org/grpc" ) @@ -19,13 +23,23 @@ type grpcPluginHostMapper struct { } // NewGRPCPluginHostMapper creates a HostMapper plugin based on gRPC. -func NewGRPCPluginHostMapper(name string, conn grpc.ClientConnInterface) hosts.HostMapper { +func NewGRPCPluginHostMapper(name string, addr string, opts ...plugin.Option) hosts.HostMapper { + var options plugin.Options + for _, opt := range opts { + opt(&options) + } + + log := logger.Default().WithFields(map[string]any{ + "kind": "hosts", + "hosts": name, + }) + conn, err := plugin.NewGRPCConn(addr, &options) + if err != nil { + log.Error(err) + } p := &grpcPluginHostMapper{ conn: conn, - log: logger.Default().WithFields(map[string]any{ - "kind": "hosts", - "hosts": name, - }), + log: log, } if conn != nil { p.client = proto.NewHostMapperClient(conn) @@ -65,3 +79,88 @@ func (p *grpcPluginHostMapper) Close() error { } return nil } + +type httpHostMapperRequest struct { + Network string `json:"network"` + Host string `json:"host"` + Client string `json:"client"` +} + +type httpHostMapperResponse struct { + IPs []string `json:"ips"` + OK bool `json:"ok"` +} + +type httpPluginHostMapper struct { + url string + client *http.Client + header http.Header + log logger.Logger +} + +// NewHTTPPluginHostMapper creates an HostMapper plugin based on HTTP. +func NewHTTPPluginHostMapper(name string, url string, opts ...plugin.Option) hosts.HostMapper { + var options plugin.Options + for _, opt := range opts { + opt(&options) + } + + return &httpPluginHostMapper{ + url: url, + client: plugin.NewHTTPClient(&options), + header: options.Header, + log: logger.Default().WithFields(map[string]any{ + "kind": "hosts", + "hosts": name, + }), + } +} + +func (p *httpPluginHostMapper) Lookup(ctx context.Context, network, host string) (ips []net.IP, ok bool) { + p.log.Debugf("lookup %s/%s", host, network) + + if p.client == nil { + return + } + + rb := httpHostMapperRequest{ + Network: network, + Host: host, + Client: string(auth_util.IDFromContext(ctx)), + } + v, err := json.Marshal(&rb) + if err != nil { + return + } + + req, err := http.NewRequest(http.MethodPost, p.url, bytes.NewReader(v)) + if err != nil { + return + } + + if p.header != nil { + req.Header = p.header.Clone() + } + req.Header.Set("Content-Type", "application/json") + resp, err := p.client.Do(req) + if err != nil { + return + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return + } + + res := httpHostMapperResponse{} + if err := json.NewDecoder(resp.Body).Decode(&res); err != nil { + return + } + + for _, s := range res.IPs { + if ip := net.ParseIP(s); ip != nil { + ips = append(ips, ip) + } + } + return ips, res.OK +} diff --git a/ingress/ingress.go b/ingress/ingress.go index 4d19759..64170d6 100644 --- a/ingress/ingress.go +++ b/ingress/ingress.go @@ -9,7 +9,7 @@ import ( "sync" "time" - ingress_pkg "github.com/go-gost/core/ingress" + "github.com/go-gost/core/ingress" "github.com/go-gost/core/logger" "github.com/go-gost/x/internal/loader" "google.golang.org/grpc" @@ -74,7 +74,7 @@ func LoggerOption(logger logger.Logger) Option { } } -type ingress struct { +type localIngress struct { rules map[string]Rule cancelFunc context.CancelFunc options options @@ -82,7 +82,7 @@ type ingress struct { } // NewIngress creates and initializes a new Ingress. -func NewIngress(opts ...Option) ingress_pkg.Ingress { +func NewIngress(opts ...Option) ingress.Ingress { var options options for _, opt := range opts { opt(&options) @@ -90,7 +90,7 @@ func NewIngress(opts ...Option) ingress_pkg.Ingress { ctx, cancel := context.WithCancel(context.TODO()) - ing := &ingress{ + ing := &localIngress{ cancelFunc: cancel, options: options, } @@ -105,7 +105,7 @@ func NewIngress(opts ...Option) ingress_pkg.Ingress { return ing } -func (ing *ingress) periodReload(ctx context.Context) error { +func (ing *localIngress) periodReload(ctx context.Context) error { period := ing.options.period if period < time.Second { period = time.Second @@ -126,7 +126,7 @@ func (ing *ingress) periodReload(ctx context.Context) error { } } -func (ing *ingress) reload(ctx context.Context) error { +func (ing *localIngress) reload(ctx context.Context) error { rules := make(map[string]Rule) fn := func(rule Rule) { @@ -160,7 +160,7 @@ func (ing *ingress) reload(ctx context.Context) error { return nil } -func (ing *ingress) load(ctx context.Context) (rules []Rule, err error) { +func (ing *localIngress) load(ctx context.Context) (rules []Rule, err error) { if ing.options.fileLoader != nil { if lister, ok := ing.options.fileLoader.(loader.Lister); ok { list, er := lister.List(ctx) @@ -211,7 +211,7 @@ func (ing *ingress) load(ctx context.Context) (rules []Rule, err error) { return } -func (ing *ingress) parseRules(r io.Reader) (rules []Rule, err error) { +func (ing *localIngress) parseRules(r io.Reader) (rules []Rule, err error) { if r == nil { return } @@ -227,7 +227,7 @@ func (ing *ingress) parseRules(r io.Reader) (rules []Rule, err error) { return } -func (ing *ingress) Get(ctx context.Context, host string) string { +func (ing *localIngress) Get(ctx context.Context, host string) string { if host == "" || ing == nil { return "" } @@ -267,7 +267,7 @@ func (ing *ingress) Get(ctx context.Context, host string) string { return ep } -func (ing *ingress) lookup(host string) string { +func (ing *localIngress) lookup(host string) string { if ing == nil || len(ing.rules) == 0 { return "" } @@ -278,7 +278,7 @@ func (ing *ingress) lookup(host string) string { return ing.rules[host].Endpoint } -func (ing *ingress) parseLine(s string) (rule Rule) { +func (ing *localIngress) parseLine(s string) (rule Rule) { line := strings.Replace(s, "\t", " ", -1) line = strings.TrimSpace(line) if n := strings.IndexByte(line, '#'); n >= 0 { @@ -300,7 +300,7 @@ func (ing *ingress) parseLine(s string) (rule Rule) { } } -func (ing *ingress) Close() error { +func (ing *localIngress) Close() error { ing.cancelFunc() if ing.options.fileLoader != nil { ing.options.fileLoader.Close() diff --git a/ingress/plugin.go b/ingress/plugin.go index 6c1d125..7a05a83 100644 --- a/ingress/plugin.go +++ b/ingress/plugin.go @@ -1,12 +1,16 @@ package ingress import ( + "bytes" "context" + "encoding/json" "io" + "net/http" - ingress_pkg "github.com/go-gost/core/ingress" + "github.com/go-gost/core/ingress" "github.com/go-gost/core/logger" "github.com/go-gost/plugin/ingress/proto" + "github.com/go-gost/x/internal/util/plugin" "google.golang.org/grpc" ) @@ -16,14 +20,25 @@ type grpcPluginIngress struct { log logger.Logger } -// NewGRPCPluginIngress creates a ingress plugin based on gRPC. -func NewGRPCPluginIngress(name string, conn grpc.ClientConnInterface) ingress_pkg.Ingress { +// NewGRPCPluginIngress creates an Ingress plugin based on gRPC. +func NewGRPCPluginIngress(name string, addr string, opts ...plugin.Option) ingress.Ingress { + var options plugin.Options + for _, opt := range opts { + opt(&options) + } + + log := logger.Default().WithFields(map[string]any{ + "kind": "ingress", + "ingress": name, + }) + conn, err := plugin.NewGRPCConn(addr, &options) + if err != nil { + log.Error(err) + } + p := &grpcPluginIngress{ conn: conn, - log: logger.Default().WithFields(map[string]any{ - "kind": "ingress", - "ingress": name, - }), + log: log, } if conn != nil { p.client = proto.NewIngressClient(conn) @@ -53,3 +68,75 @@ func (p *grpcPluginIngress) Close() error { } return nil } + +type httpIngressRequest struct { + Host string `json:"host"` +} + +type httpIngressResponse struct { + Endpoint string `json:"endpoint"` +} + +type httpPluginIngress struct { + url string + client *http.Client + header http.Header + log logger.Logger +} + +// NewHTTPPluginIngress creates an Ingress plugin based on HTTP. +func NewHTTPPluginIngress(name string, url string, opts ...plugin.Option) ingress.Ingress { + var options plugin.Options + for _, opt := range opts { + opt(&options) + } + + return &httpPluginIngress{ + url: url, + client: plugin.NewHTTPClient(&options), + header: options.Header, + log: logger.Default().WithFields(map[string]any{ + "kind": "ingress", + "ingress": name, + }), + } +} + +func (p *httpPluginIngress) Get(ctx context.Context, host string) (endpoint string) { + if p.client == nil { + return + } + + rb := httpIngressRequest{ + Host: host, + } + v, err := json.Marshal(&rb) + if err != nil { + return + } + + req, err := http.NewRequest(http.MethodPost, p.url, bytes.NewReader(v)) + if err != nil { + return + } + + if p.header != nil { + req.Header = p.header.Clone() + } + req.Header.Set("Content-Type", "application/json") + resp, err := p.client.Do(req) + if err != nil { + return + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return + } + + res := httpIngressResponse{} + if err := json.NewDecoder(resp.Body).Decode(&res); err != nil { + return + } + return res.Endpoint +} diff --git a/internal/util/plugin/plugin.go b/internal/util/plugin/plugin.go new file mode 100644 index 0000000..56a7428 --- /dev/null +++ b/internal/util/plugin/plugin.go @@ -0,0 +1,91 @@ +package plugin + +import ( + "context" + "crypto/tls" + "net/http" + "time" + + "google.golang.org/grpc" + "google.golang.org/grpc/backoff" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/insecure" +) + +type Options struct { + Token string + TLSConfig *tls.Config + Header http.Header + Timeout time.Duration +} + +type Option func(opts *Options) + +func TokenOption(token string) Option { + return func(opts *Options) { + opts.Token = token + } +} + +func TLSConfigOption(cfg *tls.Config) Option { + return func(opts *Options) { + opts.TLSConfig = cfg + } +} + +func HeaderOption(header http.Header) Option { + return func(opts *Options) { + opts.Header = header + } +} + +func TimeoutOption(timeout time.Duration) Option { + return func(opts *Options) { + opts.Timeout = timeout + } +} + +func NewGRPCConn(addr string, opts *Options) (*grpc.ClientConn, error) { + grpcOpts := []grpc.DialOption{ + // grpc.WithBlock(), + grpc.WithConnectParams(grpc.ConnectParams{ + Backoff: backoff.DefaultConfig, + }), + grpc.FailOnNonTempDialError(true), + } + if opts.TLSConfig != nil { + grpcOpts = append(grpcOpts, + grpc.WithAuthority(opts.TLSConfig.ServerName), + grpc.WithTransportCredentials(credentials.NewTLS(opts.TLSConfig)), + ) + } else { + grpcOpts = append(grpcOpts, grpc.WithTransportCredentials(insecure.NewCredentials())) + } + if opts.Token != "" { + grpcOpts = append(grpcOpts, grpc.WithPerRPCCredentials(&rpcCredentials{token: opts.Token})) + } + return grpc.Dial(addr, grpcOpts...) +} + +type rpcCredentials struct { + token string +} + +func (c *rpcCredentials) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) { + return map[string]string{ + "token": c.token, + }, nil +} + +func (c *rpcCredentials) RequireTransportSecurity() bool { + return false +} + +func NewHTTPClient(opts *Options) *http.Client { + return &http.Client{ + Timeout: opts.Timeout, + Transport: &http.Transport{ + TLSClientConfig: opts.TLSConfig, + }, + } +} diff --git a/internal/util/ssh/ssh.go b/internal/util/ssh/ssh.go index 55af884..33e9da3 100644 --- a/internal/util/ssh/ssh.go +++ b/internal/util/ssh/ssh.go @@ -27,7 +27,7 @@ func PasswordCallback(au auth.Authenticator) PasswordCallbackFunc { return nil } return func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) { - if ok, _ := au.Authenticate(context.Background(), conn.User(), string(password)); ok { + if _, ok := au.Authenticate(context.Background(), conn.User(), string(password)); ok { return nil, nil } return nil, fmt.Errorf("password rejected for %s", conn.User()) diff --git a/recorder/plugin.go b/recorder/plugin.go index a473d00..7004807 100644 --- a/recorder/plugin.go +++ b/recorder/plugin.go @@ -1,12 +1,18 @@ package recorder import ( + "bytes" "context" + "encoding/json" + "errors" + "fmt" "io" + "net/http" "github.com/go-gost/core/logger" "github.com/go-gost/core/recorder" "github.com/go-gost/plugin/recorder/proto" + "github.com/go-gost/x/internal/util/plugin" "google.golang.org/grpc" ) @@ -16,14 +22,25 @@ type grpcPluginRecorder struct { log logger.Logger } -// NewGRPCPluginRecorder creates a plugin recorder. -func NewGRPCPluginRecorder(name string, conn grpc.ClientConnInterface) recorder.Recorder { +// NewGRPCPluginRecorder creates a Recorder plugin based on gRPC. +func NewGRPCPluginRecorder(name string, addr string, opts ...plugin.Option) recorder.Recorder { + var options plugin.Options + for _, opt := range opts { + opt(&options) + } + + log := logger.Default().WithFields(map[string]any{ + "kind": "recorder", + "recorder": name, + }) + conn, err := plugin.NewGRPCConn(addr, &options) + if err != nil { + log.Error(err) + } + p := &grpcPluginRecorder{ conn: conn, - log: logger.Default().WithFields(map[string]any{ - "kind": "recorder", - "recorder": name, - }), + log: log, } if conn != nil { p.client = proto.NewRecorderClient(conn) @@ -53,3 +70,79 @@ func (p *grpcPluginRecorder) Close() error { } return nil } + +type httpRecorderRequest struct { + Data []byte `json:"data"` +} + +type httpRecorderResponse struct { + OK bool `json:"ok"` +} + +type httpPluginRecorder struct { + url string + client *http.Client + header http.Header + log logger.Logger +} + +// NewHTTPPluginRecorder creates an Recorder plugin based on HTTP. +func NewHTTPPluginRecorder(name string, url string, opts ...plugin.Option) recorder.Recorder { + var options plugin.Options + for _, opt := range opts { + opt(&options) + } + + return &httpPluginRecorder{ + url: url, + client: plugin.NewHTTPClient(&options), + header: options.Header, + log: logger.Default().WithFields(map[string]any{ + "kind": "recorder", + "recorder": name, + }), + } +} + +func (p *httpPluginRecorder) Record(ctx context.Context, b []byte) error { + if len(b) == 0 || p.client == nil { + return nil + } + + rb := httpRecorderRequest{ + Data: b, + } + v, err := json.Marshal(&rb) + if err != nil { + return err + } + + req, err := http.NewRequest(http.MethodPost, p.url, bytes.NewReader(v)) + if err != nil { + return err + } + + if p.header != nil { + req.Header = p.header.Clone() + } + req.Header.Set("Content-Type", "application/json") + resp, err := p.client.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("%s", resp.Status) + } + + res := httpRecorderResponse{} + if err := json.NewDecoder(resp.Body).Decode(&res); err != nil { + return err + } + + if !res.OK { + return errors.New("record failed") + } + return nil +} diff --git a/registry/auther.go b/registry/auther.go index d1d5918..ae76aff 100644 --- a/registry/auther.go +++ b/registry/auther.go @@ -30,10 +30,10 @@ type autherWrapper struct { r *autherRegistry } -func (w *autherWrapper) Authenticate(ctx context.Context, user, password string) (bool, string) { +func (w *autherWrapper) Authenticate(ctx context.Context, user, password string) (string, bool) { v := w.r.get(w.name) if v == nil { - return true, "" + return "", true } return v.Authenticate(ctx, user, password) } diff --git a/resolver/plugin.go b/resolver/plugin.go index 4ecae23..9c9bc90 100644 --- a/resolver/plugin.go +++ b/resolver/plugin.go @@ -1,14 +1,20 @@ package resolver import ( + "bytes" "context" + "encoding/json" + "errors" + "fmt" "io" "net" + "net/http" "github.com/go-gost/core/logger" - resolver_pkg "github.com/go-gost/core/resolver" + "github.com/go-gost/core/resolver" "github.com/go-gost/plugin/resolver/proto" auth_util "github.com/go-gost/x/internal/util/auth" + "github.com/go-gost/x/internal/util/plugin" "google.golang.org/grpc" ) @@ -19,13 +25,23 @@ type grpcPluginResolver struct { } // NewGRPCPluginResolver creates a Resolver plugin based on gRPC. -func NewGRPCPluginResolver(name string, conn grpc.ClientConnInterface) (resolver_pkg.Resolver, error) { +func NewGRPCPluginResolver(name string, addr string, opts ...plugin.Option) (resolver.Resolver, error) { + var options plugin.Options + for _, opt := range opts { + opt(&options) + } + + log := logger.Default().WithFields(map[string]any{ + "kind": "resolver", + "resolover": name, + }) + conn, err := plugin.NewGRPCConn(addr, &options) + if err != nil { + log.Error(err) + } p := &grpcPluginResolver{ conn: conn, - log: logger.Default().WithFields(map[string]any{ - "kind": "resolver", - "resolver": name, - }), + log: log, } if conn != nil { p.client = proto.NewResolverClient(conn) @@ -64,3 +80,93 @@ func (p *grpcPluginResolver) Close() error { } return nil } + +type httpResolverRequest struct { + Network string `json:"network"` + Host string `json:"host"` + Client string `json:"client"` +} + +type httpResolverResponse struct { + IPs []string `json:"ips"` + OK bool `json:"ok"` +} + +type httpPluginResolver struct { + url string + client *http.Client + header http.Header + log logger.Logger +} + +// NewHTTPPluginResolver creates an Resolver plugin based on HTTP. +func NewHTTPPluginResolver(name string, url string, opts ...plugin.Option) resolver.Resolver { + var options plugin.Options + for _, opt := range opts { + opt(&options) + } + + return &httpPluginResolver{ + url: url, + client: plugin.NewHTTPClient(&options), + header: options.Header, + log: logger.Default().WithFields(map[string]any{ + "kind": "resolver", + "resolver": name, + }), + } +} + +func (p *httpPluginResolver) Resolve(ctx context.Context, network, host string) (ips []net.IP, err error) { + p.log.Debugf("resolve %s/%s", host, network) + + if p.client == nil { + return + } + + rb := httpResolverRequest{ + Network: network, + Host: host, + Client: string(auth_util.IDFromContext(ctx)), + } + v, err := json.Marshal(&rb) + if err != nil { + return + } + + req, err := http.NewRequest(http.MethodPost, p.url, bytes.NewReader(v)) + if err != nil { + return + } + + if p.header != nil { + req.Header = p.header.Clone() + } + req.Header.Set("Content-Type", "application/json") + resp, err := p.client.Do(req) + if err != nil { + return + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + err = fmt.Errorf("%s", resp.Status) + return + } + + res := httpResolverResponse{} + if err = json.NewDecoder(resp.Body).Decode(&res); err != nil { + return + } + + if !res.OK { + return nil, errors.New("resolve failed") + } + + for _, s := range res.IPs { + if ip := net.ParseIP(s); ip != nil { + ips = append(ips, ip) + } + } + return ips, nil +} diff --git a/resolver/resolver.go b/resolver/resolver.go index 548a1c0..0ea44e0 100644 --- a/resolver/resolver.go +++ b/resolver/resolver.go @@ -8,7 +8,7 @@ import ( "github.com/go-gost/core/chain" "github.com/go-gost/core/logger" - resolverpkg "github.com/go-gost/core/resolver" + "github.com/go-gost/core/resolver" resolver_util "github.com/go-gost/x/internal/util/resolver" "github.com/go-gost/x/resolver/exchanger" "github.com/miekg/dns" @@ -52,13 +52,13 @@ func LoggerOption(logger logger.Logger) Option { } } -type resolver struct { +type localResolver struct { servers []NameServer cache *resolver_util.Cache options options } -func NewResolver(nameservers []NameServer, opts ...Option) (resolverpkg.Resolver, error) { +func NewResolver(nameservers []NameServer, opts ...Option) (resolver.Resolver, error) { options := options{} for _, opt := range opts { opt(&options) @@ -92,14 +92,14 @@ func NewResolver(nameservers []NameServer, opts ...Option) (resolverpkg.Resolver cache := resolver_util.NewCache(). WithLogger(options.logger) - return &resolver{ + return &localResolver{ servers: servers, cache: cache, options: options, }, nil } -func (r *resolver) Resolve(ctx context.Context, network, host string) (ips []net.IP, err error) { +func (r *localResolver) Resolve(ctx context.Context, network, host string) (ips []net.IP, err error) { if ip := net.ParseIP(host); ip != nil { return []net.IP{ip}, nil } @@ -126,7 +126,7 @@ func (r *resolver) Resolve(ctx context.Context, network, host string) (ips []net return } -func (r *resolver) resolve(ctx context.Context, server *NameServer, host string) (ips []net.IP, err error) { +func (r *localResolver) resolve(ctx context.Context, server *NameServer, host string) (ips []net.IP, err error) { if server == nil { return } @@ -144,19 +144,19 @@ func (r *resolver) resolve(ctx context.Context, server *NameServer, host string) return r.resolve6(ctx, server, host) } -func (r *resolver) resolve4(ctx context.Context, server *NameServer, host string) (ips []net.IP, err error) { +func (r *localResolver) resolve4(ctx context.Context, server *NameServer, host string) (ips []net.IP, err error) { mq := dns.Msg{} mq.SetQuestion(dns.Fqdn(host), dns.TypeA) return r.resolveIPs(ctx, server, &mq) } -func (r *resolver) resolve6(ctx context.Context, server *NameServer, host string) (ips []net.IP, err error) { +func (r *localResolver) resolve6(ctx context.Context, server *NameServer, host string) (ips []net.IP, err error) { mq := dns.Msg{} mq.SetQuestion(dns.Fqdn(host), dns.TypeAAAA) return r.resolveIPs(ctx, server, &mq) } -func (r *resolver) resolveIPs(ctx context.Context, server *NameServer, mq *dns.Msg) (ips []net.IP, err error) { +func (r *localResolver) resolveIPs(ctx context.Context, server *NameServer, mq *dns.Msg) (ips []net.IP, err error) { key := resolver_util.NewCacheKey(&mq.Question[0]) mr, ttl := r.cache.Load(key) if ttl <= 0 { @@ -180,7 +180,7 @@ func (r *resolver) resolveIPs(ctx context.Context, server *NameServer, mq *dns.M return } -func (r *resolver) exchange(ctx context.Context, ex exchanger.Exchanger, mq *dns.Msg) (mr *dns.Msg, err error) { +func (r *localResolver) exchange(ctx context.Context, ex exchanger.Exchanger, mq *dns.Msg) (mr *dns.Msg, err error) { query, err := mq.Pack() if err != nil { return