From f847fa533e984ed78e30f0849d4046c08f45f25a Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Sat, 16 Dec 2023 14:28:58 +0800 Subject: [PATCH] fix auth for file handler --- api/api.go | 164 ++++++++++++++++++++++++++++++++ api/config_chain.go | 5 +- api/config_hop.go | 5 +- api/service.go | 157 ------------------------------ config/parsing/chain/parse.go | 6 +- config/parsing/hop/parse.go | 6 +- config/parsing/node/parse.go | 4 +- config/parsing/service/parse.go | 2 +- handler/file/handler.go | 1 + handler/tunnel/bind.go | 19 +++- hop/hop.go | 2 +- hop/plugin.go | 4 +- listener/rtcp/listener.go | 14 ++- listener/rudp/listener.go | 35 +++++-- metadata/metadata.go | 3 + metrics/service/service.go | 17 +++- service/service.go | 20 ++-- 17 files changed, 263 insertions(+), 201 deletions(-) delete mode 100644 api/service.go diff --git a/api/api.go b/api/api.go index 5b7fcce..b132272 100644 --- a/api/api.go +++ b/api/api.go @@ -2,6 +2,13 @@ package api import ( "embed" + "net" + "net/http" + + "github.com/gin-contrib/cors" + "github.com/gin-gonic/gin" + "github.com/go-gost/core/auth" + "github.com/go-gost/core/service" ) var ( @@ -13,3 +20,160 @@ type Response struct { Code int `json:"code,omitempty"` Msg string `json:"msg,omitempty"` } + +type options struct { + accessLog bool + pathPrefix string + auther auth.Authenticator +} + +type Option func(*options) + +func PathPrefixOption(pathPrefix string) Option { + return func(o *options) { + o.pathPrefix = pathPrefix + } +} + +func AccessLogOption(enable bool) Option { + return func(o *options) { + o.accessLog = enable + } +} + +func AutherOption(auther auth.Authenticator) Option { + return func(o *options) { + o.auther = auther + } +} + +type server struct { + s *http.Server + ln net.Listener + cclose chan struct{} +} + +func NewService(addr string, opts ...Option) (service.Service, error) { + ln, err := net.Listen("tcp", addr) + if err != nil { + return nil, err + } + + var options options + for _, opt := range opts { + opt(&options) + } + + gin.SetMode(gin.ReleaseMode) + + r := gin.New() + r.Use( + cors.New((cors.Config{ + AllowAllOrigins: true, + AllowMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}, + AllowHeaders: []string{"*"}, + AllowPrivateNetwork: true, + })), + gin.Recovery(), + ) + if options.accessLog { + r.Use(mwLogger()) + } + + router := r.Group("") + if options.pathPrefix != "" { + router = router.Group(options.pathPrefix) + } + + router.StaticFS("/docs", http.FS(swaggerDoc)) + + config := router.Group("/config") + config.Use(mwBasicAuth(options.auther)) + registerConfig(config) + + return &server{ + s: &http.Server{ + Handler: r, + }, + ln: ln, + cclose: make(chan struct{}), + }, nil +} + +func (s *server) Serve() error { + return s.s.Serve(s.ln) +} + +func (s *server) Addr() net.Addr { + return s.ln.Addr() +} + +func (s *server) Close() error { + return s.s.Close() +} + +func (s *server) IsClosed() bool { + select { + case <-s.cclose: + return true + default: + return false + } +} + +func registerConfig(config *gin.RouterGroup) { + config.GET("", getConfig) + config.POST("", saveConfig) + + config.POST("/services", createService) + config.PUT("/services/:service", updateService) + config.DELETE("/services/:service", deleteService) + + config.POST("/chains", createChain) + config.PUT("/chains/:chain", updateChain) + config.DELETE("/chains/:chain", deleteChain) + + config.POST("/hops", createHop) + config.PUT("/hops/:hop", updateHop) + config.DELETE("/hops/:hop", deleteHop) + + config.POST("/authers", createAuther) + config.PUT("/authers/:auther", updateAuther) + config.DELETE("/authers/:auther", deleteAuther) + + config.POST("/admissions", createAdmission) + config.PUT("/admissions/:admission", updateAdmission) + config.DELETE("/admissions/:admission", deleteAdmission) + + config.POST("/bypasses", createBypass) + config.PUT("/bypasses/:bypass", updateBypass) + config.DELETE("/bypasses/:bypass", deleteBypass) + + config.POST("/resolvers", createResolver) + config.PUT("/resolvers/:resolver", updateResolver) + config.DELETE("/resolvers/:resolver", deleteResolver) + + config.POST("/hosts", createHosts) + config.PUT("/hosts/:hosts", updateHosts) + config.DELETE("/hosts/:hosts", deleteHosts) + + config.POST("/ingresses", createIngress) + config.PUT("/ingresses/:ingress", updateIngress) + config.DELETE("/ingresses/:ingress", deleteIngress) + + config.POST("/routers", createRouter) + config.PUT("/routers/:router", updateRouter) + config.DELETE("/routers/:router", deleteRouter) + + config.POST("/limiters", createLimiter) + config.PUT("/limiters/:limiter", updateLimiter) + config.DELETE("/limiters/:limiter", deleteLimiter) + + config.POST("/climiters", createConnLimiter) + config.PUT("/climiters/:limiter", updateConnLimiter) + config.DELETE("/climiters/:limiter", deleteConnLimiter) + + config.POST("/rlimiters", createRateLimiter) + config.PUT("/rlimiters/:limiter", updateRateLimiter) + config.DELETE("/rlimiters/:limiter", deleteRateLimiter) +} diff --git a/api/config_chain.go b/api/config_chain.go index eba9acb..d5eb33a 100644 --- a/api/config_chain.go +++ b/api/config_chain.go @@ -4,6 +4,7 @@ import ( "net/http" "github.com/gin-gonic/gin" + "github.com/go-gost/core/logger" "github.com/go-gost/x/config" parser "github.com/go-gost/x/config/parsing/chain" "github.com/go-gost/x/registry" @@ -40,7 +41,7 @@ func createChain(ctx *gin.Context) { return } - v, err := parser.ParseChain(&req.Data) + v, err := parser.ParseChain(&req.Data, logger.Default()) if err != nil { writeError(ctx, ErrCreate) return @@ -99,7 +100,7 @@ func updateChain(ctx *gin.Context) { req.Data.Name = req.Chain - v, err := parser.ParseChain(&req.Data) + v, err := parser.ParseChain(&req.Data, logger.Default()) if err != nil { writeError(ctx, ErrCreate) return diff --git a/api/config_hop.go b/api/config_hop.go index 0021c45..01f45f7 100644 --- a/api/config_hop.go +++ b/api/config_hop.go @@ -4,6 +4,7 @@ import ( "net/http" "github.com/gin-gonic/gin" + "github.com/go-gost/core/logger" "github.com/go-gost/x/config" parser "github.com/go-gost/x/config/parsing/hop" "github.com/go-gost/x/registry" @@ -40,7 +41,7 @@ func createHop(ctx *gin.Context) { return } - v, err := parser.ParseHop(&req.Data) + v, err := parser.ParseHop(&req.Data, logger.Default()) if err != nil { writeError(ctx, ErrCreate) return @@ -99,7 +100,7 @@ func updateHop(ctx *gin.Context) { req.Data.Name = req.Hop - v, err := parser.ParseHop(&req.Data) + v, err := parser.ParseHop(&req.Data, logger.Default()) if err != nil { writeError(ctx, ErrCreate) return diff --git a/api/service.go b/api/service.go deleted file mode 100644 index 36a44ec..0000000 --- a/api/service.go +++ /dev/null @@ -1,157 +0,0 @@ -package api - -import ( - "net" - "net/http" - - "github.com/gin-contrib/cors" - "github.com/gin-gonic/gin" - "github.com/go-gost/core/auth" - "github.com/go-gost/core/service" -) - -type options struct { - accessLog bool - pathPrefix string - auther auth.Authenticator -} - -type Option func(*options) - -func PathPrefixOption(pathPrefix string) Option { - return func(o *options) { - o.pathPrefix = pathPrefix - } -} - -func AccessLogOption(enable bool) Option { - return func(o *options) { - o.accessLog = enable - } -} - -func AutherOption(auther auth.Authenticator) Option { - return func(o *options) { - o.auther = auther - } -} - -type server struct { - s *http.Server - ln net.Listener -} - -func NewService(addr string, opts ...Option) (service.Service, error) { - ln, err := net.Listen("tcp", addr) - if err != nil { - return nil, err - } - - var options options - for _, opt := range opts { - opt(&options) - } - - gin.SetMode(gin.ReleaseMode) - - r := gin.New() - r.Use( - cors.New((cors.Config{ - AllowAllOrigins: true, - AllowMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}, - AllowHeaders: []string{"*"}, - AllowPrivateNetwork: true, - })), - gin.Recovery(), - ) - if options.accessLog { - r.Use(mwLogger()) - } - - router := r.Group("") - if options.pathPrefix != "" { - router = router.Group(options.pathPrefix) - } - - router.StaticFS("/docs", http.FS(swaggerDoc)) - - config := router.Group("/config") - config.Use(mwBasicAuth(options.auther)) - registerConfig(config) - - return &server{ - s: &http.Server{ - Handler: r, - }, - ln: ln, - }, nil -} - -func (s *server) Serve() error { - return s.s.Serve(s.ln) -} - -func (s *server) Addr() net.Addr { - return s.ln.Addr() -} - -func (s *server) Close() error { - return s.s.Close() -} - -func registerConfig(config *gin.RouterGroup) { - config.GET("", getConfig) - config.POST("", saveConfig) - - config.POST("/services", createService) - config.PUT("/services/:service", updateService) - config.DELETE("/services/:service", deleteService) - - config.POST("/chains", createChain) - config.PUT("/chains/:chain", updateChain) - config.DELETE("/chains/:chain", deleteChain) - - config.POST("/hops", createHop) - config.PUT("/hops/:hop", updateHop) - config.DELETE("/hops/:hop", deleteHop) - - config.POST("/authers", createAuther) - config.PUT("/authers/:auther", updateAuther) - config.DELETE("/authers/:auther", deleteAuther) - - config.POST("/admissions", createAdmission) - config.PUT("/admissions/:admission", updateAdmission) - config.DELETE("/admissions/:admission", deleteAdmission) - - config.POST("/bypasses", createBypass) - config.PUT("/bypasses/:bypass", updateBypass) - config.DELETE("/bypasses/:bypass", deleteBypass) - - config.POST("/resolvers", createResolver) - config.PUT("/resolvers/:resolver", updateResolver) - config.DELETE("/resolvers/:resolver", deleteResolver) - - config.POST("/hosts", createHosts) - config.PUT("/hosts/:hosts", updateHosts) - config.DELETE("/hosts/:hosts", deleteHosts) - - config.POST("/ingresses", createIngress) - config.PUT("/ingresses/:ingress", updateIngress) - config.DELETE("/ingresses/:ingress", deleteIngress) - - config.POST("/routers", createRouter) - config.PUT("/routers/:router", updateRouter) - config.DELETE("/routers/:router", deleteRouter) - - config.POST("/limiters", createLimiter) - config.PUT("/limiters/:limiter", updateLimiter) - config.DELETE("/limiters/:limiter", deleteLimiter) - - config.POST("/climiters", createConnLimiter) - config.PUT("/climiters/:limiter", updateConnLimiter) - config.DELETE("/climiters/:limiter", deleteConnLimiter) - - config.POST("/rlimiters", createRateLimiter) - config.PUT("/rlimiters/:limiter", updateRateLimiter) - config.DELETE("/rlimiters/:limiter", deleteRateLimiter) -} diff --git a/config/parsing/chain/parse.go b/config/parsing/chain/parse.go index 8a03c12..b801188 100644 --- a/config/parsing/chain/parse.go +++ b/config/parsing/chain/parse.go @@ -12,12 +12,12 @@ import ( "github.com/go-gost/x/registry" ) -func ParseChain(cfg *config.ChainConfig) (chain.Chainer, error) { +func ParseChain(cfg *config.ChainConfig, log logger.Logger) (chain.Chainer, error) { if cfg == nil { return nil, nil } - chainLogger := logger.Default().WithFields(map[string]any{ + chainLogger := log.WithFields(map[string]any{ "kind": "chain", "chain": cfg.Name, }) @@ -37,7 +37,7 @@ func ParseChain(cfg *config.ChainConfig) (chain.Chainer, error) { var err error if ch.Nodes != nil || ch.Plugin != nil { - if hop, err = hop_parser.ParseHop(ch); err != nil { + if hop, err = hop_parser.ParseHop(ch, log); err != nil { return nil, err } } else { diff --git a/config/parsing/hop/parse.go b/config/parsing/hop/parse.go index b3b1147..420e330 100644 --- a/config/parsing/hop/parse.go +++ b/config/parsing/hop/parse.go @@ -17,7 +17,7 @@ import ( "github.com/go-gost/x/internal/plugin" ) -func ParseHop(cfg *config.HopConfig) (hop.Hop, error) { +func ParseHop(cfg *config.HopConfig, log logger.Logger) (hop.Hop, error) { if cfg == nil { return nil, nil } @@ -77,7 +77,7 @@ func ParseHop(cfg *config.HopConfig) (hop.Hop, error) { } } - node, err := node_parser.ParseNode(cfg.Name, v) + node, err := node_parser.ParseNode(cfg.Name, v, log) if err != nil { return nil, err } @@ -97,7 +97,7 @@ func ParseHop(cfg *config.HopConfig) (hop.Hop, error) { xhop.SelectorOption(sel), xhop.BypassOption(bypass.BypassGroup(bypass_parser.List(cfg.Bypass, cfg.Bypasses...)...)), xhop.ReloadPeriodOption(cfg.Reload), - xhop.LoggerOption(logger.Default().WithFields(map[string]any{ + xhop.LoggerOption(log.WithFields(map[string]any{ "kind": "hop", "hop": cfg.Name, })), diff --git a/config/parsing/node/parse.go b/config/parsing/node/parse.go index 1600f2d..d1cc6ef 100644 --- a/config/parsing/node/parse.go +++ b/config/parsing/node/parse.go @@ -23,7 +23,7 @@ import ( "github.com/go-gost/x/registry" ) -func ParseNode(hop string, cfg *config.NodeConfig) (*chain.Node, error) { +func ParseNode(hop string, cfg *config.NodeConfig, log logger.Logger) (*chain.Node, error) { if cfg == nil { return nil, nil } @@ -40,7 +40,7 @@ func ParseNode(hop string, cfg *config.NodeConfig) (*chain.Node, error) { } } - nodeLogger := logger.Default().WithFields(map[string]any{ + nodeLogger := log.WithFields(map[string]any{ "hop": hop, "kind": "node", "node": cfg.Name, diff --git a/config/parsing/service/parse.go b/config/parsing/service/parse.go index 5deb7eb..c319169 100644 --- a/config/parsing/service/parse.go +++ b/config/parsing/service/parse.go @@ -291,7 +291,7 @@ func parseForwarder(cfg *config.ForwarderConfig) (hop.Hop, error) { } } if len(hc.Nodes) > 0 { - return hop_parser.ParseHop(&hc) + return hop_parser.ParseHop(&hc, logger.Default()) } return registry.HopRegistry().Get(hc.Name), nil } diff --git a/handler/file/handler.go b/handler/file/handler.go index 2c69b3f..d28e19e 100644 --- a/handler/file/handler.go +++ b/handler/file/handler.go @@ -73,6 +73,7 @@ func (h *fileHandler) handleFunc(w http.ResponseWriter, r *http.Request) { if auther := h.options.Auther; auther != nil { u, p, _ := r.BasicAuth() if _, ok := auther.Authenticate(r.Context(), u, p); !ok { + w.Header().Set("WWW-Authenticate", "Basic") w.WriteHeader(http.StatusUnauthorized) return } diff --git a/handler/tunnel/bind.go b/handler/tunnel/bind.go index a7ddf89..188691b 100644 --- a/handler/tunnel/bind.go +++ b/handler/tunnel/bind.go @@ -31,12 +31,15 @@ func (h *tunnelHandler) handleBind(ctx context.Context, conn net.Conn, network, connectorID = relay.NewUDPConnectorID(uuid[:]) } + v := md5.Sum([]byte(tunnelID.String())) + endpoint := hex.EncodeToString(v[:8]) + addr := address - if host, port, _ := net.SplitHostPort(addr); host == "" { - v := md5.Sum([]byte(tunnelID.String())) - host = hex.EncodeToString(v[:8]) - addr = net.JoinHostPort(host, port) + host, port, _ := net.SplitHostPort(addr) + if host == "" { + addr = net.JoinHostPort(endpoint, port) } + af := &relay.AddrFeature{} err = af.ParseFrom(addr) if err != nil { @@ -58,9 +61,15 @@ func (h *tunnelHandler) handleBind(ctx context.Context, conn net.Conn, network, h.pool.Add(tunnelID, NewConnector(connectorID, tunnelID, h.id, session, h.md.sd), h.md.tunnelTTL) if h.md.ingress != nil { h.md.ingress.SetRule(ctx, &ingress.Rule{ - Hostname: addr, + Hostname: endpoint, Endpoint: tunnelID.String(), }) + if host != "" { + h.md.ingress.SetRule(ctx, &ingress.Rule{ + Hostname: host, + Endpoint: tunnelID.String(), + }) + } } if h.md.sd != nil { err := h.md.sd.Register(ctx, &sd.Service{ diff --git a/hop/hop.go b/hop/hop.go index c228d2b..dec6392 100644 --- a/hop/hop.go +++ b/hop/hop.go @@ -308,7 +308,7 @@ func (p *chainHop) parseNode(r io.Reader) ([]*chain.Node, error) { continue } - node, err := node_parser.ParseNode(p.options.name, nc) + node, err := node_parser.ParseNode(p.options.name, nc, logger.Default()) if err != nil { return nodes, err } diff --git a/hop/plugin.go b/hop/plugin.go index 519d5f5..9df2f48 100644 --- a/hop/plugin.go +++ b/hop/plugin.go @@ -86,7 +86,7 @@ func (p *grpcPlugin) Select(ctx context.Context, opts ...hop.SelectOption) *chai return nil } - node, err := node_parser.ParseNode(p.name, &cfg) + node, err := node_parser.ParseNode(p.name, &cfg, logger.Default()) if err != nil { p.log.Error(err) return nil @@ -203,7 +203,7 @@ func (p *httpPlugin) Select(ctx context.Context, opts ...hop.SelectOption) *chai return nil } - node, err := node_parser.ParseNode(p.name, &cfg) + node, err := node_parser.ParseNode(p.name, &cfg, logger.Default()) if err != nil { p.log.Error(err) return nil diff --git a/listener/rtcp/listener.go b/listener/rtcp/listener.go index 95ba00a..ff51da6 100644 --- a/listener/rtcp/listener.go +++ b/listener/rtcp/listener.go @@ -89,7 +89,15 @@ func (l *rtcpListener) Accept() (conn net.Conn, err error) { ln = climiter.WrapListener(l.options.ConnLimiter, ln) l.setListener(ln) } - conn, err = l.ln.Accept() + + select { + case <-l.closed: + ln.Close() + return nil, net.ErrClosed + default: + } + + conn, err = ln.Accept() if err != nil { ln.Close() l.setListener(nil) @@ -107,10 +115,8 @@ func (l *rtcpListener) Close() error { case <-l.closed: default: close(l.closed) - ln := l.getListener() - if ln != nil { + if ln := l.getListener(); ln != nil { ln.Close() - // l.ln = nil } } diff --git a/listener/rudp/listener.go b/listener/rudp/listener.go index 7125212..bbd7848 100644 --- a/listener/rudp/listener.go +++ b/listener/rudp/listener.go @@ -3,6 +3,7 @@ package rudp import ( "context" "net" + "sync" "github.com/go-gost/core/chain" "github.com/go-gost/core/listener" @@ -27,6 +28,7 @@ type rudpListener struct { logger logger.Logger md metadata options listener.Options + mu sync.Mutex } func NewListener(opts ...listener.Option) listener.Listener { @@ -72,8 +74,9 @@ func (l *rudpListener) Accept() (conn net.Conn, err error) { default: } - if l.ln == nil { - l.ln, err = l.router.Bind( + ln := l.getListener() + if ln == nil { + ln, err = l.router.Bind( context.Background(), "udp", l.laddr.String(), chain.BacklogBindOption(l.md.backlog), chain.UDPConnTTLBindOption(l.md.ttl), @@ -83,11 +86,20 @@ func (l *rudpListener) Accept() (conn net.Conn, err error) { if err != nil { return nil, listener.NewAcceptError(err) } + l.setListener(ln) } + + select { + case <-l.closed: + ln.Close() + return nil, net.ErrClosed + default: + } + conn, err = l.ln.Accept() if err != nil { l.ln.Close() - l.ln = nil + l.setListener(nil) return nil, listener.NewAcceptError(err) } @@ -109,15 +121,26 @@ func (l *rudpListener) Close() error { case <-l.closed: default: close(l.closed) - if l.ln != nil { - l.ln.Close() - // l.ln = nil + if ln := l.getListener(); ln != nil { + ln.Close() } } return nil } +func (l *rudpListener) setListener(ln net.Listener) { + l.mu.Lock() + defer l.mu.Unlock() + l.ln = ln +} + +func (l *rudpListener) getListener() net.Listener { + l.mu.Lock() + defer l.mu.Unlock() + return l.ln +} + type bindAddr struct { addr string } diff --git a/metadata/metadata.go b/metadata/metadata.go index 599ee0f..0ae3b67 100644 --- a/metadata/metadata.go +++ b/metadata/metadata.go @@ -9,6 +9,9 @@ import ( type mapMetadata map[string]any func NewMetadata(m map[string]any) metadata.Metadata { + if len(m) == 0 { + return nil + } md := make(map[string]any) for k, v := range m { md[strings.ToLower(k)] = v diff --git a/metrics/service/service.go b/metrics/service/service.go index eeaa237..12e0434 100644 --- a/metrics/service/service.go +++ b/metrics/service/service.go @@ -33,8 +33,9 @@ func AutherOption(auther auth.Authenticator) Option { } type metricService struct { - s *http.Server - ln net.Listener + s *http.Server + ln net.Listener + cclose chan struct{} } func NewService(addr string, opts ...Option) (service.Service, error) { @@ -66,7 +67,8 @@ func NewService(addr string, opts ...Option) (service.Service, error) { s: &http.Server{ Handler: mux, }, - ln: ln, + ln: ln, + cclose: make(chan struct{}), }, nil } @@ -81,3 +83,12 @@ func (s *metricService) Addr() net.Addr { func (s *metricService) Close() error { return s.s.Close() } + +func (s *metricService) IsClosed() bool { + select { + case <-s.cclose: + return true + default: + return false + } +} diff --git a/service/service.go b/service/service.go index 38724a7..86bcfc5 100644 --- a/service/service.go +++ b/service/service.go @@ -102,16 +102,6 @@ func (s *defaultService) Addr() net.Addr { return s.listener.Addr() } -func (s *defaultService) Close() error { - s.execCmds("pre-down", s.options.preDown) - defer s.execCmds("post-down", s.options.postDown) - - if closer, ok := s.handler.(io.Closer); ok { - closer.Close() - } - return s.listener.Close() -} - func (s *defaultService) Serve() error { s.execCmds("post-up", s.options.postUp) @@ -201,6 +191,16 @@ func (s *defaultService) Serve() error { } } +func (s *defaultService) Close() error { + s.execCmds("pre-down", s.options.preDown) + defer s.execCmds("post-down", s.options.postDown) + + if closer, ok := s.handler.(io.Closer); ok { + closer.Close() + } + return s.listener.Close() +} + func (s *defaultService) execCmds(phase string, cmds []string) { for _, cmd := range cmds { cmd := strings.TrimSpace(cmd)