add basic auth for webapi

This commit is contained in:
ginuerzh 2022-02-12 21:05:39 +08:00
parent fdd67a6086
commit f2d806886a
12 changed files with 197 additions and 86 deletions

13
cmd/gost/api.yaml Normal file
View File

@ -0,0 +1,13 @@
api:
addr: :18080
accesslog: true
pathPrefix: /api
auth:
username: gost
password: gost
auther: auther-0
authers:
- name: auther-0
auths:
- username: gost1
password: gost1

View File

@ -4,6 +4,7 @@ import (
"io" "io"
"os" "os"
"github.com/go-gost/gost/pkg/api"
"github.com/go-gost/gost/pkg/config" "github.com/go-gost/gost/pkg/config"
"github.com/go-gost/gost/pkg/config/parsing" "github.com/go-gost/gost/pkg/config/parsing"
"github.com/go-gost/gost/pkg/logger" "github.com/go-gost/gost/pkg/logger"
@ -11,8 +12,8 @@ import (
"github.com/go-gost/gost/pkg/service" "github.com/go-gost/gost/pkg/service"
) )
func buildService(cfg *config.Config) (services []*service.Service) { func buildService(cfg *config.Config) (services []service.Servicer) {
if cfg == nil || len(cfg.Services) == 0 { if cfg == nil {
return return
} }
@ -109,3 +110,16 @@ func logFromConfig(cfg *config.LogConfig) logger.Logger {
return logger.NewLogger(opts...) return logger.NewLogger(opts...)
} }
func buildAPIServer(cfg *config.APIConfig) (*api.Server, error) {
auther := parsing.ParseAutherFromAuth(cfg.Auth)
if cfg.Auther != "" {
auther = registry.Auther().Get(cfg.Auther)
}
return api.NewServer(
cfg.Addr,
api.PathPrefixOption(cfg.PathPrefix),
api.AccessLogOption(cfg.AccessLog),
api.AutherOption(auther),
)
}

View File

@ -3,13 +3,11 @@ package main
import ( import (
"flag" "flag"
"fmt" "fmt"
"net"
"net/http" "net/http"
_ "net/http/pprof" _ "net/http/pprof"
"os" "os"
"runtime" "runtime"
"github.com/go-gost/gost/pkg/api"
"github.com/go-gost/gost/pkg/config" "github.com/go-gost/gost/pkg/config"
"github.com/go-gost/gost/pkg/logger" "github.com/go-gost/gost/pkg/logger"
) )
@ -93,17 +91,16 @@ func main() {
}() }()
} }
if cfg.API != nil && cfg.API.Addr != "" { if cfg.API != nil {
api.Init(cfg.API) s, err := buildAPIServer(cfg.API)
ln, err := net.Listen("tcp", cfg.API.Addr)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
defer ln.Close() defer s.Close()
go func() { go func() {
log.Info("api server on ", ln.Addr()) log.Info("api server on ", s.Addr())
log.Fatal(api.Run(ln)) log.Fatal(s.Serve())
}() }()
} }
@ -111,7 +108,7 @@ func main() {
services := buildService(cfg) services := buildService(cfg)
for _, svc := range services { for _, svc := range services {
go svc.Run() go svc.Serve()
} }
config.SetGlobal(cfg) config.SetGlobal(cfg)

View File

@ -2,51 +2,20 @@ package api
import ( import (
"embed" "embed"
"net"
"net/http" "net/http"
"time"
"github.com/gin-contrib/cors"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/go-gost/gost/pkg/config"
"github.com/go-gost/gost/pkg/logger"
) )
var ( var (
apiServer = &http.Server{}
//go:embed swagger.yaml //go:embed swagger.yaml
swaggerDoc embed.FS swaggerDoc embed.FS
) )
func Init(cfg *config.APIConfig) { func register(r *gin.RouterGroup) {
gin.SetMode(gin.ReleaseMode) r.StaticFS("/docs", http.FS(swaggerDoc))
if cfg == nil { config := r.Group("/config")
cfg = &config.APIConfig{}
}
r := gin.New()
r.Use(
cors.New((cors.Config{
AllowAllOrigins: true,
AllowMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
AllowHeaders: []string{"*"},
})),
gin.Recovery(),
)
if cfg.AccessLog {
r.Use(loggerHandler)
}
router := r.Group("")
if cfg.PathPrefix != "" {
router = router.Group(cfg.PathPrefix)
}
router.StaticFS("/docs", http.FS(swaggerDoc))
config := router.Group("/config")
{ {
config.GET("", getConfig) config.GET("", getConfig)
@ -74,30 +43,6 @@ func Init(cfg *config.APIConfig) {
config.PUT("/hosts/:hosts", updateHosts) config.PUT("/hosts/:hosts", updateHosts)
config.DELETE("/hosts/:hosts", deleteHosts) config.DELETE("/hosts/:hosts", deleteHosts)
} }
apiServer.Handler = r
}
func Run(ln net.Listener) error {
return apiServer.Serve(ln)
}
func loggerHandler(ctx *gin.Context) {
// start time
startTime := time.Now()
// Processing request
ctx.Next()
duration := time.Since(startTime)
logger.Default().WithFields(map[string]interface{}{
"kind": "api",
"method": ctx.Request.Method,
"uri": ctx.Request.RequestURI,
"code": ctx.Writer.Status(),
"client": ctx.ClientIP(),
"duration": duration,
}).Infof("| %3d | %13v | %15s | %-7s %s",
ctx.Writer.Status(), duration, ctx.ClientIP(), ctx.Request.Method, ctx.Request.RequestURI)
} }
type Response struct { type Response struct {

View File

@ -54,7 +54,7 @@ func createService(ctx *gin.Context) {
return return
} }
go svc.Run() go svc.Serve()
cfg := config.Global() cfg := config.Global()
cfg.Services = append(cfg.Services, &req.Data) cfg.Services = append(cfg.Services, &req.Data)
@ -115,7 +115,7 @@ func updateService(ctx *gin.Context) {
return return
} }
go svc.Run() go svc.Serve()
cfg := config.Global() cfg := config.Global()
for i := range cfg.Services { for i := range cfg.Services {

42
pkg/api/middleware.go Normal file
View File

@ -0,0 +1,42 @@
package api
import (
"net/http"
"time"
"github.com/gin-gonic/gin"
"github.com/go-gost/gost/pkg/auth"
"github.com/go-gost/gost/pkg/logger"
)
func mwLogger() gin.HandlerFunc {
return func(ctx *gin.Context) {
// start time
startTime := time.Now()
// Processing request
ctx.Next()
duration := time.Since(startTime)
logger.Default().WithFields(map[string]interface{}{
"kind": "api",
"method": ctx.Request.Method,
"uri": ctx.Request.RequestURI,
"code": ctx.Writer.Status(),
"client": ctx.ClientIP(),
"duration": duration,
}).Infof("| %3d | %13v | %15s | %-7s %s",
ctx.Writer.Status(), duration, ctx.ClientIP(), ctx.Request.Method, ctx.Request.RequestURI)
}
}
func mwBasicAuth(auther auth.Authenticator) gin.HandlerFunc {
return func(c *gin.Context) {
if auther == nil {
return
}
u, p, _ := c.Request.BasicAuth()
if !auther.Authenticate(u, p) {
c.AbortWithStatus(http.StatusUnauthorized)
}
}
}

96
pkg/api/server.go Normal file
View File

@ -0,0 +1,96 @@
package api
import (
"net"
"net/http"
"github.com/gin-contrib/cors"
"github.com/gin-gonic/gin"
"github.com/go-gost/gost/pkg/auth"
)
type options struct {
accessLog bool
pathPrefix string
auther auth.Authenticator
}
type Option func(*options)
func PathPrefixOption(pathPrefix string) Option {
return func(o *options) {
o.pathPrefix = pathPrefix
}
}
func AccessLogOption(enable bool) Option {
return func(o *options) {
o.accessLog = enable
}
}
func AutherOption(auther auth.Authenticator) Option {
return func(o *options) {
o.auther = auther
}
}
type Server struct {
s *http.Server
ln net.Listener
}
func NewServer(addr string, opts ...Option) (*Server, error) {
ln, err := net.Listen("tcp", addr)
if err != nil {
return nil, err
}
var options options
for _, opt := range opts {
opt(&options)
}
gin.SetMode(gin.ReleaseMode)
r := gin.New()
r.Use(
cors.New((cors.Config{
AllowAllOrigins: true,
AllowMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
AllowHeaders: []string{"*"},
})),
gin.Recovery(),
)
if options.accessLog {
r.Use(mwLogger())
}
if options.auther != nil {
r.Use(mwBasicAuth(options.auther))
}
router := r.Group("")
if options.pathPrefix != "" {
router = router.Group(options.pathPrefix)
}
register(router)
return &Server{
s: &http.Server{
Handler: r,
},
ln: ln,
}, nil
}
func (s *Server) Serve() error {
return s.s.Serve(s.ln)
}
func (s *Server) Addr() net.Addr {
return s.ln.Addr()
}
func (s *Server) Close() error {
return s.s.Close()
}

View File

@ -56,7 +56,9 @@ type ProfilingConfig struct {
type APIConfig struct { type APIConfig struct {
Addr string `json:"addr"` Addr string `json:"addr"`
PathPrefix string `yaml:"pathPrefix,omitempty" json:"pathPrefix,omitempty"` PathPrefix string `yaml:"pathPrefix,omitempty" json:"pathPrefix,omitempty"`
AccessLog bool `yaml:"accesslog,omitempty" json:"accesslog,omitemtpy"` AccessLog bool `yaml:"accesslog,omitempty" json:"accesslog,omitempty"`
Auth *AuthConfig `yaml:",omitempty" json:"auth,omitempty"`
Auther string `yaml:",omitempty" json:"auther,omitempty"`
} }
type TLSConfig struct { type TLSConfig struct {

View File

@ -35,7 +35,7 @@ func ParseAuther(cfg *config.AutherConfig) auth.Authenticator {
return auth.NewMapAuthenticator(m) return auth.NewMapAuthenticator(m)
} }
func autherFromAuth(au *config.AuthConfig) auth.Authenticator { func ParseAutherFromAuth(au *config.AuthConfig) auth.Authenticator {
if au == nil || au.Username == "" { if au == nil || au.Username == "" {
return nil return nil
} }

View File

@ -14,7 +14,7 @@ import (
"github.com/go-gost/gost/pkg/service" "github.com/go-gost/gost/pkg/service"
) )
func ParseService(cfg *config.ServiceConfig) (*service.Service, error) { func ParseService(cfg *config.ServiceConfig) (service.Servicer, error) {
if cfg.Listener == nil { if cfg.Listener == nil {
cfg.Listener = &config.ListenerConfig{ cfg.Listener = &config.ListenerConfig{
Type: "tcp", Type: "tcp",
@ -47,7 +47,7 @@ func ParseService(cfg *config.ServiceConfig) (*service.Service, error) {
return nil, err return nil, err
} }
auther := autherFromAuth(cfg.Listener.Auth) auther := ParseAutherFromAuth(cfg.Listener.Auth)
if cfg.Listener.Auther != "" { if cfg.Listener.Auther != "" {
auther = registry.Auther().Get(cfg.Listener.Auther) auther = registry.Auther().Get(cfg.Listener.Auther)
} }
@ -84,7 +84,7 @@ func ParseService(cfg *config.ServiceConfig) (*service.Service, error) {
return nil, err return nil, err
} }
auther = autherFromAuth(cfg.Handler.Auth) auther = ParseAutherFromAuth(cfg.Handler.Auth)
if cfg.Handler.Auther != "" { if cfg.Handler.Auther != "" {
auther = registry.Auther().Get(cfg.Handler.Auther) auther = registry.Auther().Get(cfg.Handler.Auther)
} }

View File

@ -18,7 +18,7 @@ type serviceRegistry struct {
m sync.Map m sync.Map
} }
func (r *serviceRegistry) Register(name string, svc *service.Service) error { func (r *serviceRegistry) Register(name string, svc service.Servicer) error {
if name == "" || svc == nil { if name == "" || svc == nil {
return nil return nil
} }
@ -38,7 +38,7 @@ func (r *serviceRegistry) IsRegistered(name string) bool {
return ok return ok
} }
func (r *serviceRegistry) Get(name string) *service.Service { func (r *serviceRegistry) Get(name string) service.Servicer {
if name == "" { if name == "" {
return nil return nil
} }
@ -46,5 +46,5 @@ func (r *serviceRegistry) Get(name string) *service.Service {
if !ok { if !ok {
return nil return nil
} }
return v.(*service.Service) return v.(service.Servicer)
} }

View File

@ -10,6 +10,12 @@ import (
"github.com/go-gost/gost/pkg/logger" "github.com/go-gost/gost/pkg/logger"
) )
type Servicer interface {
Serve() error
Addr() net.Addr
Close() error
}
type Service struct { type Service struct {
listener listener.Listener listener listener.Listener
handler handler.Handler handler handler.Handler
@ -35,15 +41,11 @@ func (s *Service) Addr() net.Addr {
return s.listener.Addr() return s.listener.Addr()
} }
func (s *Service) Run() error {
return s.serve()
}
func (s *Service) Close() error { func (s *Service) Close() error {
return s.listener.Close() return s.listener.Close()
} }
func (s *Service) serve() error { func (s *Service) Serve() error {
var tempDelay time.Duration var tempDelay time.Duration
for { for {
conn, e := s.listener.Accept() conn, e := s.listener.Accept()