add auth support for forwarder node

This commit is contained in:
ginuerzh
2023-01-31 14:04:28 +08:00
parent 3e35a7b761
commit ebdb77d71f
8 changed files with 75 additions and 25 deletions

View File

@ -11,6 +11,7 @@ import (
"github.com/go-gost/core/auth" "github.com/go-gost/core/auth"
"github.com/go-gost/core/logger" "github.com/go-gost/core/logger"
"github.com/go-gost/x/internal/loader" "github.com/go-gost/x/internal/loader"
xlogger "github.com/go-gost/x/logger"
) )
type options struct { type options struct {
@ -74,6 +75,9 @@ func NewAuthenticator(opts ...Option) auth.Authenticator {
for _, opt := range opts { for _, opt := range opts {
opt(&options) opt(&options)
} }
if options.logger == nil {
options.logger = xlogger.Nop()
}
ctx, cancel := context.WithCancel(context.TODO()) ctx, cancel := context.WithCancel(context.TODO())
p := &authenticator{ p := &authenticator{

View File

@ -288,6 +288,7 @@ type ForwardNodeConfig struct {
Bypasses []string `yaml:",omitempty" json:"bypasses,omitempty"` Bypasses []string `yaml:",omitempty" json:"bypasses,omitempty"`
HTTP *HTTPNodeConfig `yaml:",omitempty" json:"http,omitempty"` HTTP *HTTPNodeConfig `yaml:",omitempty" json:"http,omitempty"`
TLS *TLSNodeConfig `yaml:",omitempty" json:"tls,omitempty"` TLS *TLSNodeConfig `yaml:",omitempty" json:"tls,omitempty"`
Auth *AuthConfig `yaml:",omitempty" json:"auth,omitempty"`
} }
type HTTPNodeConfig struct { type HTTPNodeConfig struct {
@ -382,6 +383,7 @@ type NodeConfig struct {
Metadata map[string]any `yaml:",omitempty" json:"metadata,omitempty"` Metadata map[string]any `yaml:",omitempty" json:"metadata,omitempty"`
HTTP *HTTPNodeConfig `yaml:",omitempty" json:"http,omitempty"` HTTP *HTTPNodeConfig `yaml:",omitempty" json:"http,omitempty"`
TLS *TLSNodeConfig `yaml:",omitempty" json:"tls,omitempty"` TLS *TLSNodeConfig `yaml:",omitempty" json:"tls,omitempty"`
Auth *AuthConfig `yaml:",omitempty" json:"auth,omitempty"`
} }
type Config struct { type Config struct {

View File

@ -12,6 +12,7 @@ import (
"github.com/go-gost/core/logger" "github.com/go-gost/core/logger"
"github.com/go-gost/core/metadata" "github.com/go-gost/core/metadata"
mdutil "github.com/go-gost/core/metadata/util" mdutil "github.com/go-gost/core/metadata/util"
auther "github.com/go-gost/x/auth"
xchain "github.com/go-gost/x/chain" xchain "github.com/go-gost/x/chain"
"github.com/go-gost/x/config" "github.com/go-gost/x/config"
tls_util "github.com/go-gost/x/internal/util/tls" tls_util "github.com/go-gost/x/internal/util/tls"
@ -231,6 +232,19 @@ func ParseHop(cfg *config.HopConfig) (chain.Hop, error) {
Secure: v.TLS.Secure, Secure: v.TLS.Secure,
})) }))
} }
if v.Auth != nil {
opts = append(opts, chain.AutherNodeOption(
auther.NewAuthenticator(
auther.AuthsOption(map[string]string{v.Auth.Username: v.Auth.Password}),
auther.LoggerOption(logger.Default().WithFields(map[string]any{
"kind": "node",
"node": v.Name,
"addr": v.Addr,
"host": v.Host,
"protocol": v.Protocol,
})),
)))
}
node := chain.NewNode(v.Name, v.Addr, opts...) node := chain.NewNode(v.Name, v.Addr, opts...)
nodes = append(nodes, node) nodes = append(nodes, node)
} }

View File

@ -259,6 +259,7 @@ func parseForwarder(cfg *config.ForwarderConfig) (chain.Hop, error) {
Bypasses: node.Bypasses, Bypasses: node.Bypasses,
HTTP: node.HTTP, HTTP: node.HTTP,
TLS: node.TLS, TLS: node.TLS,
Auth: node.Auth,
}, },
) )
} }

2
go.mod
View File

@ -7,7 +7,7 @@ require (
github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d
github.com/gin-contrib/cors v1.3.1 github.com/gin-contrib/cors v1.3.1
github.com/gin-gonic/gin v1.8.2 github.com/gin-gonic/gin v1.8.2
github.com/go-gost/core v0.0.0-20230129124513-1a643651c025 github.com/go-gost/core v0.0.0-20230131053758-ff3b77ac2899
github.com/go-gost/gosocks4 v0.0.1 github.com/go-gost/gosocks4 v0.0.1
github.com/go-gost/gosocks5 v0.3.1-0.20211109033403-d894d75b7f09 github.com/go-gost/gosocks5 v0.3.1-0.20211109033403-d894d75b7f09
github.com/go-gost/relay v0.3.1 github.com/go-gost/relay v0.3.1

4
go.sum
View File

@ -91,8 +91,8 @@ github.com/gin-gonic/gin v1.8.2/go.mod h1:qw5AYuDrzRTnhvusDsrov+fDIxp9Dleuu12h8n
github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU=
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
github.com/go-gost/core v0.0.0-20230129124513-1a643651c025 h1:3UdtINVrlGIulMyTrygas1uYMc/az6VivtzosP8FTmg= github.com/go-gost/core v0.0.0-20230131053758-ff3b77ac2899 h1:Ofa8D9NViX+biyS3oUPgIzAjYD6i71abRGalX/0Den4=
github.com/go-gost/core v0.0.0-20230129124513-1a643651c025/go.mod h1:R08B7BVdhWsYHX8s7wkEBpeKqc4+YFP6bLLFoao0J/A= github.com/go-gost/core v0.0.0-20230131053758-ff3b77ac2899/go.mod h1:R08B7BVdhWsYHX8s7wkEBpeKqc4+YFP6bLLFoao0J/A=
github.com/go-gost/gosocks4 v0.0.1 h1:+k1sec8HlELuQV7rWftIkmy8UijzUt2I6t+iMPlGB2s= github.com/go-gost/gosocks4 v0.0.1 h1:+k1sec8HlELuQV7rWftIkmy8UijzUt2I6t+iMPlGB2s=
github.com/go-gost/gosocks4 v0.0.1/go.mod h1:3B6L47HbU/qugDg4JnoFPHgJXE43Inz8Bah1QaN9qCc= github.com/go-gost/gosocks4 v0.0.1/go.mod h1:3B6L47HbU/qugDg4JnoFPHgJXE43Inz8Bah1QaN9qCc=
github.com/go-gost/gosocks5 v0.3.1-0.20211109033403-d894d75b7f09 h1:A95M6UWcfZgOuJkQ7QLfG0Hs5peWIUSysCDNz4pfe04= github.com/go-gost/gosocks5 v0.3.1-0.20211109033403-d894d75b7f09 h1:A95M6UWcfZgOuJkQ7QLfG0Hs5peWIUSysCDNz4pfe04=

View File

@ -123,7 +123,9 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand
} }
log = log.WithFields(map[string]any{ log = log.WithFields(map[string]any{
"dst": fmt.Sprintf("%s/%s", target.Addr, network), "host": host,
"node": target.Name,
"dst": fmt.Sprintf("%s/%s", target.Addr, network),
}) })
log.Debugf("%s >> %s", conn.RemoteAddr(), target.Addr) log.Debugf("%s >> %s", conn.RemoteAddr(), target.Addr)
@ -157,13 +159,13 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, log l
br := bufio.NewReader(rw) br := bufio.NewReader(rw)
var connPool sync.Map var connPool sync.Map
resp := &http.Response{
ProtoMajor: 1,
ProtoMinor: 1,
StatusCode: http.StatusServiceUnavailable,
}
for { for {
resp := &http.Response{
ProtoMajor: 1,
ProtoMinor: 1,
StatusCode: http.StatusServiceUnavailable,
}
err = func() error { err = func() error {
req, err := http.ReadRequest(br) req, err := http.ReadRequest(br)
if err != nil { if err != nil {
@ -183,11 +185,21 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, log l
} }
log = log.WithFields(map[string]any{ log = log.WithFields(map[string]any{
"dst": target.Addr, "host": req.Host,
"node": target.Name,
"dst": target.Addr,
}) })
log.Debugf("find node for host %s -> %s(%s)", req.Host, target.Name, target.Addr) log.Debugf("find node for host %s -> %s(%s)", req.Host, target.Name, target.Addr)
// log.Debugf("%s >> %s", conn.RemoteAddr(), target.Addr) if auther := target.Options().Auther; auther != nil {
username, password, _ := req.BasicAuth()
if !auther.Authenticate(username, password) {
resp.StatusCode = http.StatusUnauthorized
resp.Header.Set("WWW-Authenticate", "Basic")
log.Warnf("node %s(%s) 401 unauthorized", target.Name, target.Addr)
return resp.Write(rw)
}
}
var cc net.Conn var cc net.Conn
if v, ok := connPool.Load(target); ok { if v, ok := connPool.Load(target); ok {
@ -202,7 +214,7 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, log l
if marker := target.Marker(); marker != nil { if marker := target.Marker(); marker != nil {
marker.Mark() marker.Mark()
} }
log.Warnf("connect to node %s(%s) failed", target.Name, target.Addr) log.Warnf("connect to node %s(%s) failed: %v", target.Name, target.Addr, err)
return resp.Write(rw) return resp.Write(rw)
} }
if marker := target.Marker(); marker != nil { if marker := target.Marker(); marker != nil {
@ -221,7 +233,8 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, log l
go func() { go func() {
defer cc.Close() defer cc.Close()
xnet.CopyBuffer(rw, cc, 8192) err := xnet.CopyBuffer(rw, cc, 8192)
log.Debugf("close connection to node %s(%s), reason: %v", target.Name, target.Addr, err)
connPool.Delete(target) connPool.Delete(target)
}() }()
} }
@ -240,6 +253,7 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, log l
log.Trace(string(dump)) log.Trace(string(dump))
} }
if err := req.Write(cc); err != nil { if err := req.Write(cc); err != nil {
log.Warnf("send request to node %s(%s) failed: %v", target.Name, target.Addr, err)
return resp.Write(rw) return resp.Write(rw)
} }

View File

@ -122,7 +122,9 @@ func (h *forwardHandler) Handle(ctx context.Context, conn net.Conn, opts ...hand
} }
log = log.WithFields(map[string]any{ log = log.WithFields(map[string]any{
"dst": fmt.Sprintf("%s/%s", target.Addr, network), "host": host,
"node": target.Name,
"dst": fmt.Sprintf("%s/%s", target.Addr, network),
}) })
log.Debugf("%s >> %s", conn.RemoteAddr(), target.Addr) log.Debugf("%s >> %s", conn.RemoteAddr(), target.Addr)
@ -156,13 +158,14 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, log l
br := bufio.NewReader(rw) br := bufio.NewReader(rw)
var connPool sync.Map var connPool sync.Map
resp := &http.Response{
ProtoMajor: 1,
ProtoMinor: 1,
StatusCode: http.StatusServiceUnavailable,
}
for { for {
resp := &http.Response{
ProtoMajor: 1,
ProtoMinor: 1,
Header: http.Header{},
StatusCode: http.StatusServiceUnavailable,
}
err = func() error { err = func() error {
req, err := http.ReadRequest(br) req, err := http.ReadRequest(br)
if err != nil { if err != nil {
@ -182,11 +185,21 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, log l
} }
log = log.WithFields(map[string]any{ log = log.WithFields(map[string]any{
"dst": target.Addr, "host": req.Host,
"node": target.Name,
"dst": target.Addr,
}) })
log.Debugf("find node for host %s -> %s(%s)", req.Host, target.Name, target.Addr) log.Debugf("find node for host %s -> %s(%s)", req.Host, target.Name, target.Addr)
// log.Debugf("%s >> %s", conn.RemoteAddr(), target.Addr) if auther := target.Options().Auther; auther != nil {
username, password, _ := req.BasicAuth()
if !auther.Authenticate(username, password) {
resp.StatusCode = http.StatusUnauthorized
resp.Header.Set("WWW-Authenticate", "Basic")
log.Warnf("node %s(%s) 401 unauthorized", target.Name, target.Addr)
return resp.Write(rw)
}
}
var cc net.Conn var cc net.Conn
if v, ok := connPool.Load(target); ok { if v, ok := connPool.Load(target); ok {
@ -201,7 +214,7 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, log l
if marker := target.Marker(); marker != nil { if marker := target.Marker(); marker != nil {
marker.Mark() marker.Mark()
} }
log.Warnf("connect to node %s(%s) failed", target.Name, target.Addr) log.Warnf("connect to node %s(%s) failed: %v", target.Name, target.Addr, err)
return resp.Write(rw) return resp.Write(rw)
} }
if marker := target.Marker(); marker != nil { if marker := target.Marker(); marker != nil {
@ -220,7 +233,8 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, log l
go func() { go func() {
defer cc.Close() defer cc.Close()
xnet.CopyBuffer(rw, cc, 8192) err := xnet.CopyBuffer(rw, cc, 8192)
log.Debugf("close connection to node %s(%s), reason: %v", target.Name, target.Addr, err)
connPool.Delete(target) connPool.Delete(target)
}() }()
} }
@ -239,6 +253,7 @@ func (h *forwardHandler) handleHTTP(ctx context.Context, rw io.ReadWriter, log l
log.Trace(string(dump)) log.Trace(string(dump))
} }
if err := req.Write(cc); err != nil { if err := req.Write(cc); err != nil {
log.Warnf("send request to node %s(%s) failed: %v", target.Name, target.Addr, err)
return resp.Write(rw) return resp.Write(rw)
} }