fix auth for file handler

This commit is contained in:
ginuerzh 2023-12-16 14:28:58 +08:00
parent b1390dda1c
commit f847fa533e
17 changed files with 263 additions and 201 deletions

View File

@ -2,6 +2,13 @@ package api
import ( import (
"embed" "embed"
"net"
"net/http"
"github.com/gin-contrib/cors"
"github.com/gin-gonic/gin"
"github.com/go-gost/core/auth"
"github.com/go-gost/core/service"
) )
var ( var (
@ -13,3 +20,160 @@ type Response struct {
Code int `json:"code,omitempty"` Code int `json:"code,omitempty"`
Msg string `json:"msg,omitempty"` Msg string `json:"msg,omitempty"`
} }
type options struct {
accessLog bool
pathPrefix string
auther auth.Authenticator
}
type Option func(*options)
func PathPrefixOption(pathPrefix string) Option {
return func(o *options) {
o.pathPrefix = pathPrefix
}
}
func AccessLogOption(enable bool) Option {
return func(o *options) {
o.accessLog = enable
}
}
func AutherOption(auther auth.Authenticator) Option {
return func(o *options) {
o.auther = auther
}
}
type server struct {
s *http.Server
ln net.Listener
cclose chan struct{}
}
func NewService(addr string, opts ...Option) (service.Service, error) {
ln, err := net.Listen("tcp", addr)
if err != nil {
return nil, err
}
var options options
for _, opt := range opts {
opt(&options)
}
gin.SetMode(gin.ReleaseMode)
r := gin.New()
r.Use(
cors.New((cors.Config{
AllowAllOrigins: true,
AllowMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
AllowHeaders: []string{"*"},
AllowPrivateNetwork: true,
})),
gin.Recovery(),
)
if options.accessLog {
r.Use(mwLogger())
}
router := r.Group("")
if options.pathPrefix != "" {
router = router.Group(options.pathPrefix)
}
router.StaticFS("/docs", http.FS(swaggerDoc))
config := router.Group("/config")
config.Use(mwBasicAuth(options.auther))
registerConfig(config)
return &server{
s: &http.Server{
Handler: r,
},
ln: ln,
cclose: make(chan struct{}),
}, nil
}
func (s *server) Serve() error {
return s.s.Serve(s.ln)
}
func (s *server) Addr() net.Addr {
return s.ln.Addr()
}
func (s *server) Close() error {
return s.s.Close()
}
func (s *server) IsClosed() bool {
select {
case <-s.cclose:
return true
default:
return false
}
}
func registerConfig(config *gin.RouterGroup) {
config.GET("", getConfig)
config.POST("", saveConfig)
config.POST("/services", createService)
config.PUT("/services/:service", updateService)
config.DELETE("/services/:service", deleteService)
config.POST("/chains", createChain)
config.PUT("/chains/:chain", updateChain)
config.DELETE("/chains/:chain", deleteChain)
config.POST("/hops", createHop)
config.PUT("/hops/:hop", updateHop)
config.DELETE("/hops/:hop", deleteHop)
config.POST("/authers", createAuther)
config.PUT("/authers/:auther", updateAuther)
config.DELETE("/authers/:auther", deleteAuther)
config.POST("/admissions", createAdmission)
config.PUT("/admissions/:admission", updateAdmission)
config.DELETE("/admissions/:admission", deleteAdmission)
config.POST("/bypasses", createBypass)
config.PUT("/bypasses/:bypass", updateBypass)
config.DELETE("/bypasses/:bypass", deleteBypass)
config.POST("/resolvers", createResolver)
config.PUT("/resolvers/:resolver", updateResolver)
config.DELETE("/resolvers/:resolver", deleteResolver)
config.POST("/hosts", createHosts)
config.PUT("/hosts/:hosts", updateHosts)
config.DELETE("/hosts/:hosts", deleteHosts)
config.POST("/ingresses", createIngress)
config.PUT("/ingresses/:ingress", updateIngress)
config.DELETE("/ingresses/:ingress", deleteIngress)
config.POST("/routers", createRouter)
config.PUT("/routers/:router", updateRouter)
config.DELETE("/routers/:router", deleteRouter)
config.POST("/limiters", createLimiter)
config.PUT("/limiters/:limiter", updateLimiter)
config.DELETE("/limiters/:limiter", deleteLimiter)
config.POST("/climiters", createConnLimiter)
config.PUT("/climiters/:limiter", updateConnLimiter)
config.DELETE("/climiters/:limiter", deleteConnLimiter)
config.POST("/rlimiters", createRateLimiter)
config.PUT("/rlimiters/:limiter", updateRateLimiter)
config.DELETE("/rlimiters/:limiter", deleteRateLimiter)
}

View File

@ -4,6 +4,7 @@ import (
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/go-gost/core/logger"
"github.com/go-gost/x/config" "github.com/go-gost/x/config"
parser "github.com/go-gost/x/config/parsing/chain" parser "github.com/go-gost/x/config/parsing/chain"
"github.com/go-gost/x/registry" "github.com/go-gost/x/registry"
@ -40,7 +41,7 @@ func createChain(ctx *gin.Context) {
return return
} }
v, err := parser.ParseChain(&req.Data) v, err := parser.ParseChain(&req.Data, logger.Default())
if err != nil { if err != nil {
writeError(ctx, ErrCreate) writeError(ctx, ErrCreate)
return return
@ -99,7 +100,7 @@ func updateChain(ctx *gin.Context) {
req.Data.Name = req.Chain req.Data.Name = req.Chain
v, err := parser.ParseChain(&req.Data) v, err := parser.ParseChain(&req.Data, logger.Default())
if err != nil { if err != nil {
writeError(ctx, ErrCreate) writeError(ctx, ErrCreate)
return return

View File

@ -4,6 +4,7 @@ import (
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/go-gost/core/logger"
"github.com/go-gost/x/config" "github.com/go-gost/x/config"
parser "github.com/go-gost/x/config/parsing/hop" parser "github.com/go-gost/x/config/parsing/hop"
"github.com/go-gost/x/registry" "github.com/go-gost/x/registry"
@ -40,7 +41,7 @@ func createHop(ctx *gin.Context) {
return return
} }
v, err := parser.ParseHop(&req.Data) v, err := parser.ParseHop(&req.Data, logger.Default())
if err != nil { if err != nil {
writeError(ctx, ErrCreate) writeError(ctx, ErrCreate)
return return
@ -99,7 +100,7 @@ func updateHop(ctx *gin.Context) {
req.Data.Name = req.Hop req.Data.Name = req.Hop
v, err := parser.ParseHop(&req.Data) v, err := parser.ParseHop(&req.Data, logger.Default())
if err != nil { if err != nil {
writeError(ctx, ErrCreate) writeError(ctx, ErrCreate)
return return

View File

@ -1,157 +0,0 @@
package api
import (
"net"
"net/http"
"github.com/gin-contrib/cors"
"github.com/gin-gonic/gin"
"github.com/go-gost/core/auth"
"github.com/go-gost/core/service"
)
type options struct {
accessLog bool
pathPrefix string
auther auth.Authenticator
}
type Option func(*options)
func PathPrefixOption(pathPrefix string) Option {
return func(o *options) {
o.pathPrefix = pathPrefix
}
}
func AccessLogOption(enable bool) Option {
return func(o *options) {
o.accessLog = enable
}
}
func AutherOption(auther auth.Authenticator) Option {
return func(o *options) {
o.auther = auther
}
}
type server struct {
s *http.Server
ln net.Listener
}
func NewService(addr string, opts ...Option) (service.Service, error) {
ln, err := net.Listen("tcp", addr)
if err != nil {
return nil, err
}
var options options
for _, opt := range opts {
opt(&options)
}
gin.SetMode(gin.ReleaseMode)
r := gin.New()
r.Use(
cors.New((cors.Config{
AllowAllOrigins: true,
AllowMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
AllowHeaders: []string{"*"},
AllowPrivateNetwork: true,
})),
gin.Recovery(),
)
if options.accessLog {
r.Use(mwLogger())
}
router := r.Group("")
if options.pathPrefix != "" {
router = router.Group(options.pathPrefix)
}
router.StaticFS("/docs", http.FS(swaggerDoc))
config := router.Group("/config")
config.Use(mwBasicAuth(options.auther))
registerConfig(config)
return &server{
s: &http.Server{
Handler: r,
},
ln: ln,
}, nil
}
func (s *server) Serve() error {
return s.s.Serve(s.ln)
}
func (s *server) Addr() net.Addr {
return s.ln.Addr()
}
func (s *server) Close() error {
return s.s.Close()
}
func registerConfig(config *gin.RouterGroup) {
config.GET("", getConfig)
config.POST("", saveConfig)
config.POST("/services", createService)
config.PUT("/services/:service", updateService)
config.DELETE("/services/:service", deleteService)
config.POST("/chains", createChain)
config.PUT("/chains/:chain", updateChain)
config.DELETE("/chains/:chain", deleteChain)
config.POST("/hops", createHop)
config.PUT("/hops/:hop", updateHop)
config.DELETE("/hops/:hop", deleteHop)
config.POST("/authers", createAuther)
config.PUT("/authers/:auther", updateAuther)
config.DELETE("/authers/:auther", deleteAuther)
config.POST("/admissions", createAdmission)
config.PUT("/admissions/:admission", updateAdmission)
config.DELETE("/admissions/:admission", deleteAdmission)
config.POST("/bypasses", createBypass)
config.PUT("/bypasses/:bypass", updateBypass)
config.DELETE("/bypasses/:bypass", deleteBypass)
config.POST("/resolvers", createResolver)
config.PUT("/resolvers/:resolver", updateResolver)
config.DELETE("/resolvers/:resolver", deleteResolver)
config.POST("/hosts", createHosts)
config.PUT("/hosts/:hosts", updateHosts)
config.DELETE("/hosts/:hosts", deleteHosts)
config.POST("/ingresses", createIngress)
config.PUT("/ingresses/:ingress", updateIngress)
config.DELETE("/ingresses/:ingress", deleteIngress)
config.POST("/routers", createRouter)
config.PUT("/routers/:router", updateRouter)
config.DELETE("/routers/:router", deleteRouter)
config.POST("/limiters", createLimiter)
config.PUT("/limiters/:limiter", updateLimiter)
config.DELETE("/limiters/:limiter", deleteLimiter)
config.POST("/climiters", createConnLimiter)
config.PUT("/climiters/:limiter", updateConnLimiter)
config.DELETE("/climiters/:limiter", deleteConnLimiter)
config.POST("/rlimiters", createRateLimiter)
config.PUT("/rlimiters/:limiter", updateRateLimiter)
config.DELETE("/rlimiters/:limiter", deleteRateLimiter)
}

View File

@ -12,12 +12,12 @@ import (
"github.com/go-gost/x/registry" "github.com/go-gost/x/registry"
) )
func ParseChain(cfg *config.ChainConfig) (chain.Chainer, error) { func ParseChain(cfg *config.ChainConfig, log logger.Logger) (chain.Chainer, error) {
if cfg == nil { if cfg == nil {
return nil, nil return nil, nil
} }
chainLogger := logger.Default().WithFields(map[string]any{ chainLogger := log.WithFields(map[string]any{
"kind": "chain", "kind": "chain",
"chain": cfg.Name, "chain": cfg.Name,
}) })
@ -37,7 +37,7 @@ func ParseChain(cfg *config.ChainConfig) (chain.Chainer, error) {
var err error var err error
if ch.Nodes != nil || ch.Plugin != nil { if ch.Nodes != nil || ch.Plugin != nil {
if hop, err = hop_parser.ParseHop(ch); err != nil { if hop, err = hop_parser.ParseHop(ch, log); err != nil {
return nil, err return nil, err
} }
} else { } else {

View File

@ -17,7 +17,7 @@ import (
"github.com/go-gost/x/internal/plugin" "github.com/go-gost/x/internal/plugin"
) )
func ParseHop(cfg *config.HopConfig) (hop.Hop, error) { func ParseHop(cfg *config.HopConfig, log logger.Logger) (hop.Hop, error) {
if cfg == nil { if cfg == nil {
return nil, nil return nil, nil
} }
@ -77,7 +77,7 @@ func ParseHop(cfg *config.HopConfig) (hop.Hop, error) {
} }
} }
node, err := node_parser.ParseNode(cfg.Name, v) node, err := node_parser.ParseNode(cfg.Name, v, log)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -97,7 +97,7 @@ func ParseHop(cfg *config.HopConfig) (hop.Hop, error) {
xhop.SelectorOption(sel), xhop.SelectorOption(sel),
xhop.BypassOption(bypass.BypassGroup(bypass_parser.List(cfg.Bypass, cfg.Bypasses...)...)), xhop.BypassOption(bypass.BypassGroup(bypass_parser.List(cfg.Bypass, cfg.Bypasses...)...)),
xhop.ReloadPeriodOption(cfg.Reload), xhop.ReloadPeriodOption(cfg.Reload),
xhop.LoggerOption(logger.Default().WithFields(map[string]any{ xhop.LoggerOption(log.WithFields(map[string]any{
"kind": "hop", "kind": "hop",
"hop": cfg.Name, "hop": cfg.Name,
})), })),

View File

@ -23,7 +23,7 @@ import (
"github.com/go-gost/x/registry" "github.com/go-gost/x/registry"
) )
func ParseNode(hop string, cfg *config.NodeConfig) (*chain.Node, error) { func ParseNode(hop string, cfg *config.NodeConfig, log logger.Logger) (*chain.Node, error) {
if cfg == nil { if cfg == nil {
return nil, nil return nil, nil
} }
@ -40,7 +40,7 @@ func ParseNode(hop string, cfg *config.NodeConfig) (*chain.Node, error) {
} }
} }
nodeLogger := logger.Default().WithFields(map[string]any{ nodeLogger := log.WithFields(map[string]any{
"hop": hop, "hop": hop,
"kind": "node", "kind": "node",
"node": cfg.Name, "node": cfg.Name,

View File

@ -291,7 +291,7 @@ func parseForwarder(cfg *config.ForwarderConfig) (hop.Hop, error) {
} }
} }
if len(hc.Nodes) > 0 { if len(hc.Nodes) > 0 {
return hop_parser.ParseHop(&hc) return hop_parser.ParseHop(&hc, logger.Default())
} }
return registry.HopRegistry().Get(hc.Name), nil return registry.HopRegistry().Get(hc.Name), nil
} }

View File

@ -73,6 +73,7 @@ func (h *fileHandler) handleFunc(w http.ResponseWriter, r *http.Request) {
if auther := h.options.Auther; auther != nil { if auther := h.options.Auther; auther != nil {
u, p, _ := r.BasicAuth() u, p, _ := r.BasicAuth()
if _, ok := auther.Authenticate(r.Context(), u, p); !ok { if _, ok := auther.Authenticate(r.Context(), u, p); !ok {
w.Header().Set("WWW-Authenticate", "Basic")
w.WriteHeader(http.StatusUnauthorized) w.WriteHeader(http.StatusUnauthorized)
return return
} }

View File

@ -31,12 +31,15 @@ func (h *tunnelHandler) handleBind(ctx context.Context, conn net.Conn, network,
connectorID = relay.NewUDPConnectorID(uuid[:]) connectorID = relay.NewUDPConnectorID(uuid[:])
} }
addr := address
if host, port, _ := net.SplitHostPort(addr); host == "" {
v := md5.Sum([]byte(tunnelID.String())) v := md5.Sum([]byte(tunnelID.String()))
host = hex.EncodeToString(v[:8]) endpoint := hex.EncodeToString(v[:8])
addr = net.JoinHostPort(host, port)
addr := address
host, port, _ := net.SplitHostPort(addr)
if host == "" {
addr = net.JoinHostPort(endpoint, port)
} }
af := &relay.AddrFeature{} af := &relay.AddrFeature{}
err = af.ParseFrom(addr) err = af.ParseFrom(addr)
if err != nil { if err != nil {
@ -58,9 +61,15 @@ func (h *tunnelHandler) handleBind(ctx context.Context, conn net.Conn, network,
h.pool.Add(tunnelID, NewConnector(connectorID, tunnelID, h.id, session, h.md.sd), h.md.tunnelTTL) h.pool.Add(tunnelID, NewConnector(connectorID, tunnelID, h.id, session, h.md.sd), h.md.tunnelTTL)
if h.md.ingress != nil { if h.md.ingress != nil {
h.md.ingress.SetRule(ctx, &ingress.Rule{ h.md.ingress.SetRule(ctx, &ingress.Rule{
Hostname: addr, Hostname: endpoint,
Endpoint: tunnelID.String(), Endpoint: tunnelID.String(),
}) })
if host != "" {
h.md.ingress.SetRule(ctx, &ingress.Rule{
Hostname: host,
Endpoint: tunnelID.String(),
})
}
} }
if h.md.sd != nil { if h.md.sd != nil {
err := h.md.sd.Register(ctx, &sd.Service{ err := h.md.sd.Register(ctx, &sd.Service{

View File

@ -308,7 +308,7 @@ func (p *chainHop) parseNode(r io.Reader) ([]*chain.Node, error) {
continue continue
} }
node, err := node_parser.ParseNode(p.options.name, nc) node, err := node_parser.ParseNode(p.options.name, nc, logger.Default())
if err != nil { if err != nil {
return nodes, err return nodes, err
} }

View File

@ -86,7 +86,7 @@ func (p *grpcPlugin) Select(ctx context.Context, opts ...hop.SelectOption) *chai
return nil return nil
} }
node, err := node_parser.ParseNode(p.name, &cfg) node, err := node_parser.ParseNode(p.name, &cfg, logger.Default())
if err != nil { if err != nil {
p.log.Error(err) p.log.Error(err)
return nil return nil
@ -203,7 +203,7 @@ func (p *httpPlugin) Select(ctx context.Context, opts ...hop.SelectOption) *chai
return nil return nil
} }
node, err := node_parser.ParseNode(p.name, &cfg) node, err := node_parser.ParseNode(p.name, &cfg, logger.Default())
if err != nil { if err != nil {
p.log.Error(err) p.log.Error(err)
return nil return nil

View File

@ -89,7 +89,15 @@ func (l *rtcpListener) Accept() (conn net.Conn, err error) {
ln = climiter.WrapListener(l.options.ConnLimiter, ln) ln = climiter.WrapListener(l.options.ConnLimiter, ln)
l.setListener(ln) l.setListener(ln)
} }
conn, err = l.ln.Accept()
select {
case <-l.closed:
ln.Close()
return nil, net.ErrClosed
default:
}
conn, err = ln.Accept()
if err != nil { if err != nil {
ln.Close() ln.Close()
l.setListener(nil) l.setListener(nil)
@ -107,10 +115,8 @@ func (l *rtcpListener) Close() error {
case <-l.closed: case <-l.closed:
default: default:
close(l.closed) close(l.closed)
ln := l.getListener() if ln := l.getListener(); ln != nil {
if ln != nil {
ln.Close() ln.Close()
// l.ln = nil
} }
} }

View File

@ -3,6 +3,7 @@ package rudp
import ( import (
"context" "context"
"net" "net"
"sync"
"github.com/go-gost/core/chain" "github.com/go-gost/core/chain"
"github.com/go-gost/core/listener" "github.com/go-gost/core/listener"
@ -27,6 +28,7 @@ type rudpListener struct {
logger logger.Logger logger logger.Logger
md metadata md metadata
options listener.Options options listener.Options
mu sync.Mutex
} }
func NewListener(opts ...listener.Option) listener.Listener { func NewListener(opts ...listener.Option) listener.Listener {
@ -72,8 +74,9 @@ func (l *rudpListener) Accept() (conn net.Conn, err error) {
default: default:
} }
if l.ln == nil { ln := l.getListener()
l.ln, err = l.router.Bind( if ln == nil {
ln, err = l.router.Bind(
context.Background(), "udp", l.laddr.String(), context.Background(), "udp", l.laddr.String(),
chain.BacklogBindOption(l.md.backlog), chain.BacklogBindOption(l.md.backlog),
chain.UDPConnTTLBindOption(l.md.ttl), chain.UDPConnTTLBindOption(l.md.ttl),
@ -83,11 +86,20 @@ func (l *rudpListener) Accept() (conn net.Conn, err error) {
if err != nil { if err != nil {
return nil, listener.NewAcceptError(err) return nil, listener.NewAcceptError(err)
} }
l.setListener(ln)
} }
select {
case <-l.closed:
ln.Close()
return nil, net.ErrClosed
default:
}
conn, err = l.ln.Accept() conn, err = l.ln.Accept()
if err != nil { if err != nil {
l.ln.Close() l.ln.Close()
l.ln = nil l.setListener(nil)
return nil, listener.NewAcceptError(err) return nil, listener.NewAcceptError(err)
} }
@ -109,15 +121,26 @@ func (l *rudpListener) Close() error {
case <-l.closed: case <-l.closed:
default: default:
close(l.closed) close(l.closed)
if l.ln != nil { if ln := l.getListener(); ln != nil {
l.ln.Close() ln.Close()
// l.ln = nil
} }
} }
return nil return nil
} }
func (l *rudpListener) setListener(ln net.Listener) {
l.mu.Lock()
defer l.mu.Unlock()
l.ln = ln
}
func (l *rudpListener) getListener() net.Listener {
l.mu.Lock()
defer l.mu.Unlock()
return l.ln
}
type bindAddr struct { type bindAddr struct {
addr string addr string
} }

View File

@ -9,6 +9,9 @@ import (
type mapMetadata map[string]any type mapMetadata map[string]any
func NewMetadata(m map[string]any) metadata.Metadata { func NewMetadata(m map[string]any) metadata.Metadata {
if len(m) == 0 {
return nil
}
md := make(map[string]any) md := make(map[string]any)
for k, v := range m { for k, v := range m {
md[strings.ToLower(k)] = v md[strings.ToLower(k)] = v

View File

@ -35,6 +35,7 @@ func AutherOption(auther auth.Authenticator) Option {
type metricService struct { type metricService struct {
s *http.Server s *http.Server
ln net.Listener ln net.Listener
cclose chan struct{}
} }
func NewService(addr string, opts ...Option) (service.Service, error) { func NewService(addr string, opts ...Option) (service.Service, error) {
@ -67,6 +68,7 @@ func NewService(addr string, opts ...Option) (service.Service, error) {
Handler: mux, Handler: mux,
}, },
ln: ln, ln: ln,
cclose: make(chan struct{}),
}, nil }, nil
} }
@ -81,3 +83,12 @@ func (s *metricService) Addr() net.Addr {
func (s *metricService) Close() error { func (s *metricService) Close() error {
return s.s.Close() return s.s.Close()
} }
func (s *metricService) IsClosed() bool {
select {
case <-s.cclose:
return true
default:
return false
}
}

View File

@ -102,16 +102,6 @@ func (s *defaultService) Addr() net.Addr {
return s.listener.Addr() return s.listener.Addr()
} }
func (s *defaultService) Close() error {
s.execCmds("pre-down", s.options.preDown)
defer s.execCmds("post-down", s.options.postDown)
if closer, ok := s.handler.(io.Closer); ok {
closer.Close()
}
return s.listener.Close()
}
func (s *defaultService) Serve() error { func (s *defaultService) Serve() error {
s.execCmds("post-up", s.options.postUp) s.execCmds("post-up", s.options.postUp)
@ -201,6 +191,16 @@ func (s *defaultService) Serve() error {
} }
} }
func (s *defaultService) Close() error {
s.execCmds("pre-down", s.options.preDown)
defer s.execCmds("post-down", s.options.postDown)
if closer, ok := s.handler.(io.Closer); ok {
closer.Close()
}
return s.listener.Close()
}
func (s *defaultService) execCmds(phase string, cmds []string) { func (s *defaultService) execCmds(phase string, cmds []string) {
for _, cmd := range cmds { for _, cmd := range cmds {
cmd := strings.TrimSpace(cmd) cmd := strings.TrimSpace(cmd)