diff --git a/cmd/gost/api.yaml b/cmd/gost/api.yaml new file mode 100644 index 0000000..986ad43 --- /dev/null +++ b/cmd/gost/api.yaml @@ -0,0 +1,13 @@ +api: + addr: :18080 + accesslog: true + pathPrefix: /api + auth: + username: gost + password: gost + auther: auther-0 +authers: +- name: auther-0 + auths: + - username: gost1 + password: gost1 diff --git a/cmd/gost/config.go b/cmd/gost/config.go index fad9956..0b4ca5c 100644 --- a/cmd/gost/config.go +++ b/cmd/gost/config.go @@ -4,6 +4,7 @@ import ( "io" "os" + "github.com/go-gost/gost/pkg/api" "github.com/go-gost/gost/pkg/config" "github.com/go-gost/gost/pkg/config/parsing" "github.com/go-gost/gost/pkg/logger" @@ -11,8 +12,8 @@ import ( "github.com/go-gost/gost/pkg/service" ) -func buildService(cfg *config.Config) (services []*service.Service) { - if cfg == nil || len(cfg.Services) == 0 { +func buildService(cfg *config.Config) (services []service.Servicer) { + if cfg == nil { return } @@ -109,3 +110,16 @@ func logFromConfig(cfg *config.LogConfig) logger.Logger { return logger.NewLogger(opts...) } + +func buildAPIServer(cfg *config.APIConfig) (*api.Server, error) { + auther := parsing.ParseAutherFromAuth(cfg.Auth) + if cfg.Auther != "" { + auther = registry.Auther().Get(cfg.Auther) + } + return api.NewServer( + cfg.Addr, + api.PathPrefixOption(cfg.PathPrefix), + api.AccessLogOption(cfg.AccessLog), + api.AutherOption(auther), + ) +} diff --git a/cmd/gost/main.go b/cmd/gost/main.go index 2bc9c96..2573d32 100644 --- a/cmd/gost/main.go +++ b/cmd/gost/main.go @@ -3,13 +3,11 @@ package main import ( "flag" "fmt" - "net" "net/http" _ "net/http/pprof" "os" "runtime" - "github.com/go-gost/gost/pkg/api" "github.com/go-gost/gost/pkg/config" "github.com/go-gost/gost/pkg/logger" ) @@ -93,17 +91,16 @@ func main() { }() } - if cfg.API != nil && cfg.API.Addr != "" { - api.Init(cfg.API) - ln, err := net.Listen("tcp", cfg.API.Addr) + if cfg.API != nil { + s, err := buildAPIServer(cfg.API) if err != nil { log.Fatal(err) } - defer ln.Close() + defer s.Close() go func() { - log.Info("api server on ", ln.Addr()) - log.Fatal(api.Run(ln)) + log.Info("api server on ", s.Addr()) + log.Fatal(s.Serve()) }() } @@ -111,7 +108,7 @@ func main() { services := buildService(cfg) for _, svc := range services { - go svc.Run() + go svc.Serve() } config.SetGlobal(cfg) diff --git a/pkg/api/api.go b/pkg/api/api.go index 4d0de43..9222186 100644 --- a/pkg/api/api.go +++ b/pkg/api/api.go @@ -2,51 +2,20 @@ package api import ( "embed" - "net" "net/http" - "time" - "github.com/gin-contrib/cors" "github.com/gin-gonic/gin" - "github.com/go-gost/gost/pkg/config" - "github.com/go-gost/gost/pkg/logger" ) var ( - apiServer = &http.Server{} - //go:embed swagger.yaml swaggerDoc embed.FS ) -func Init(cfg *config.APIConfig) { - gin.SetMode(gin.ReleaseMode) +func register(r *gin.RouterGroup) { + r.StaticFS("/docs", http.FS(swaggerDoc)) - if cfg == nil { - cfg = &config.APIConfig{} - } - - r := gin.New() - r.Use( - cors.New((cors.Config{ - AllowAllOrigins: true, - AllowMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}, - AllowHeaders: []string{"*"}, - })), - gin.Recovery(), - ) - if cfg.AccessLog { - r.Use(loggerHandler) - } - - router := r.Group("") - if cfg.PathPrefix != "" { - router = router.Group(cfg.PathPrefix) - } - - router.StaticFS("/docs", http.FS(swaggerDoc)) - - config := router.Group("/config") + config := r.Group("/config") { config.GET("", getConfig) @@ -74,30 +43,6 @@ func Init(cfg *config.APIConfig) { config.PUT("/hosts/:hosts", updateHosts) config.DELETE("/hosts/:hosts", deleteHosts) } - - apiServer.Handler = r -} - -func Run(ln net.Listener) error { - return apiServer.Serve(ln) -} - -func loggerHandler(ctx *gin.Context) { - // start time - startTime := time.Now() - // Processing request - ctx.Next() - duration := time.Since(startTime) - - logger.Default().WithFields(map[string]interface{}{ - "kind": "api", - "method": ctx.Request.Method, - "uri": ctx.Request.RequestURI, - "code": ctx.Writer.Status(), - "client": ctx.ClientIP(), - "duration": duration, - }).Infof("| %3d | %13v | %15s | %-7s %s", - ctx.Writer.Status(), duration, ctx.ClientIP(), ctx.Request.Method, ctx.Request.RequestURI) } type Response struct { diff --git a/pkg/api/config_service.go b/pkg/api/config_service.go index 7a37847..80f6870 100644 --- a/pkg/api/config_service.go +++ b/pkg/api/config_service.go @@ -54,7 +54,7 @@ func createService(ctx *gin.Context) { return } - go svc.Run() + go svc.Serve() cfg := config.Global() cfg.Services = append(cfg.Services, &req.Data) @@ -115,7 +115,7 @@ func updateService(ctx *gin.Context) { return } - go svc.Run() + go svc.Serve() cfg := config.Global() for i := range cfg.Services { diff --git a/pkg/api/middleware.go b/pkg/api/middleware.go new file mode 100644 index 0000000..d9c71f6 --- /dev/null +++ b/pkg/api/middleware.go @@ -0,0 +1,42 @@ +package api + +import ( + "net/http" + "time" + + "github.com/gin-gonic/gin" + "github.com/go-gost/gost/pkg/auth" + "github.com/go-gost/gost/pkg/logger" +) + +func mwLogger() gin.HandlerFunc { + return func(ctx *gin.Context) { + // start time + startTime := time.Now() + // Processing request + ctx.Next() + duration := time.Since(startTime) + + logger.Default().WithFields(map[string]interface{}{ + "kind": "api", + "method": ctx.Request.Method, + "uri": ctx.Request.RequestURI, + "code": ctx.Writer.Status(), + "client": ctx.ClientIP(), + "duration": duration, + }).Infof("| %3d | %13v | %15s | %-7s %s", + ctx.Writer.Status(), duration, ctx.ClientIP(), ctx.Request.Method, ctx.Request.RequestURI) + } +} + +func mwBasicAuth(auther auth.Authenticator) gin.HandlerFunc { + return func(c *gin.Context) { + if auther == nil { + return + } + u, p, _ := c.Request.BasicAuth() + if !auther.Authenticate(u, p) { + c.AbortWithStatus(http.StatusUnauthorized) + } + } +} diff --git a/pkg/api/server.go b/pkg/api/server.go new file mode 100644 index 0000000..9cec619 --- /dev/null +++ b/pkg/api/server.go @@ -0,0 +1,96 @@ +package api + +import ( + "net" + "net/http" + + "github.com/gin-contrib/cors" + "github.com/gin-gonic/gin" + "github.com/go-gost/gost/pkg/auth" +) + +type options struct { + accessLog bool + pathPrefix string + auther auth.Authenticator +} + +type Option func(*options) + +func PathPrefixOption(pathPrefix string) Option { + return func(o *options) { + o.pathPrefix = pathPrefix + } +} + +func AccessLogOption(enable bool) Option { + return func(o *options) { + o.accessLog = enable + } +} + +func AutherOption(auther auth.Authenticator) Option { + return func(o *options) { + o.auther = auther + } +} + +type Server struct { + s *http.Server + ln net.Listener +} + +func NewServer(addr string, opts ...Option) (*Server, error) { + ln, err := net.Listen("tcp", addr) + if err != nil { + return nil, err + } + + var options options + for _, opt := range opts { + opt(&options) + } + + gin.SetMode(gin.ReleaseMode) + + r := gin.New() + r.Use( + cors.New((cors.Config{ + AllowAllOrigins: true, + AllowMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}, + AllowHeaders: []string{"*"}, + })), + gin.Recovery(), + ) + if options.accessLog { + r.Use(mwLogger()) + } + if options.auther != nil { + r.Use(mwBasicAuth(options.auther)) + } + + router := r.Group("") + if options.pathPrefix != "" { + router = router.Group(options.pathPrefix) + } + register(router) + + return &Server{ + s: &http.Server{ + Handler: r, + }, + ln: ln, + }, nil +} + +func (s *Server) Serve() error { + return s.s.Serve(s.ln) +} + +func (s *Server) Addr() net.Addr { + return s.ln.Addr() +} + +func (s *Server) Close() error { + return s.s.Close() +} diff --git a/pkg/config/config.go b/pkg/config/config.go index 21fdc9a..41f4f32 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -54,9 +54,11 @@ type ProfilingConfig struct { } type APIConfig struct { - Addr string `json:"addr"` - PathPrefix string `yaml:"pathPrefix,omitempty" json:"pathPrefix,omitempty"` - AccessLog bool `yaml:"accesslog,omitempty" json:"accesslog,omitemtpy"` + Addr string `json:"addr"` + PathPrefix string `yaml:"pathPrefix,omitempty" json:"pathPrefix,omitempty"` + AccessLog bool `yaml:"accesslog,omitempty" json:"accesslog,omitempty"` + Auth *AuthConfig `yaml:",omitempty" json:"auth,omitempty"` + Auther string `yaml:",omitempty" json:"auther,omitempty"` } type TLSConfig struct { diff --git a/pkg/config/parsing/parse.go b/pkg/config/parsing/parse.go index 9c6b919..1235a31 100644 --- a/pkg/config/parsing/parse.go +++ b/pkg/config/parsing/parse.go @@ -35,7 +35,7 @@ func ParseAuther(cfg *config.AutherConfig) auth.Authenticator { return auth.NewMapAuthenticator(m) } -func autherFromAuth(au *config.AuthConfig) auth.Authenticator { +func ParseAutherFromAuth(au *config.AuthConfig) auth.Authenticator { if au == nil || au.Username == "" { return nil } diff --git a/pkg/config/parsing/service.go b/pkg/config/parsing/service.go index 71664ae..5929512 100644 --- a/pkg/config/parsing/service.go +++ b/pkg/config/parsing/service.go @@ -14,7 +14,7 @@ import ( "github.com/go-gost/gost/pkg/service" ) -func ParseService(cfg *config.ServiceConfig) (*service.Service, error) { +func ParseService(cfg *config.ServiceConfig) (service.Servicer, error) { if cfg.Listener == nil { cfg.Listener = &config.ListenerConfig{ Type: "tcp", @@ -47,7 +47,7 @@ func ParseService(cfg *config.ServiceConfig) (*service.Service, error) { return nil, err } - auther := autherFromAuth(cfg.Listener.Auth) + auther := ParseAutherFromAuth(cfg.Listener.Auth) if cfg.Listener.Auther != "" { auther = registry.Auther().Get(cfg.Listener.Auther) } @@ -84,7 +84,7 @@ func ParseService(cfg *config.ServiceConfig) (*service.Service, error) { return nil, err } - auther = autherFromAuth(cfg.Handler.Auth) + auther = ParseAutherFromAuth(cfg.Handler.Auth) if cfg.Handler.Auther != "" { auther = registry.Auther().Get(cfg.Handler.Auther) } diff --git a/pkg/registry/service.go b/pkg/registry/service.go index c4ee266..2889725 100644 --- a/pkg/registry/service.go +++ b/pkg/registry/service.go @@ -18,7 +18,7 @@ type serviceRegistry struct { m sync.Map } -func (r *serviceRegistry) Register(name string, svc *service.Service) error { +func (r *serviceRegistry) Register(name string, svc service.Servicer) error { if name == "" || svc == nil { return nil } @@ -38,7 +38,7 @@ func (r *serviceRegistry) IsRegistered(name string) bool { return ok } -func (r *serviceRegistry) Get(name string) *service.Service { +func (r *serviceRegistry) Get(name string) service.Servicer { if name == "" { return nil } @@ -46,5 +46,5 @@ func (r *serviceRegistry) Get(name string) *service.Service { if !ok { return nil } - return v.(*service.Service) + return v.(service.Servicer) } diff --git a/pkg/service/service.go b/pkg/service/service.go index 3290e8a..d9b7da9 100644 --- a/pkg/service/service.go +++ b/pkg/service/service.go @@ -10,6 +10,12 @@ import ( "github.com/go-gost/gost/pkg/logger" ) +type Servicer interface { + Serve() error + Addr() net.Addr + Close() error +} + type Service struct { listener listener.Listener handler handler.Handler @@ -35,15 +41,11 @@ func (s *Service) Addr() net.Addr { return s.listener.Addr() } -func (s *Service) Run() error { - return s.serve() -} - func (s *Service) Close() error { return s.listener.Close() } -func (s *Service) serve() error { +func (s *Service) Serve() error { var tempDelay time.Duration for { conn, e := s.listener.Accept()