x/handler/tunnel/handler.go
2024-08-01 20:52:08 +08:00

285 lines
6.3 KiB
Go

package tunnel
import (
"context"
"errors"
"fmt"
"net"
"strconv"
"time"
"github.com/go-gost/core/handler"
"github.com/go-gost/core/listener"
"github.com/go-gost/core/logger"
md "github.com/go-gost/core/metadata"
"github.com/go-gost/core/recorder"
"github.com/go-gost/core/service"
"github.com/go-gost/relay"
ctxvalue "github.com/go-gost/x/ctx"
xnet "github.com/go-gost/x/internal/net"
xrecorder "github.com/go-gost/x/recorder"
"github.com/go-gost/x/registry"
xservice "github.com/go-gost/x/service"
"github.com/google/uuid"
)
// Sentinel errors reported by the tunnel handler.
var (
// ErrBadVersion is returned by Handle when the request version is not
// relay.Version1.
ErrBadVersion = errors.New("bad version")
// ErrUnknownCmd is returned by Handle when the masked request command is
// neither relay.CmdConnect nor relay.CmdBind.
ErrUnknownCmd = errors.New("unknown command")
// ErrTunnelID is returned by Handle when the request carries no tunnel ID
// (the ID is zero).
ErrTunnelID = errors.New("invalid tunnel ID")
// ErrTunnelNotAvailable is not returned by any code in this part of the
// file; presumably used by the connect/bind paths — verify against callers.
ErrTunnelNotAvailable = errors.New("tunnel not available")
// ErrUnauthorized is returned by Handle when the configured Auther rejects
// the supplied credentials.
ErrUnauthorized = errors.New("unauthorized")
// ErrRateLimit is returned by Handle when checkRateLimit refuses the
// connection.
ErrRateLimit = errors.New("rate limiting exceeded")
)
// init registers NewHandler with the handler registry under the name "tunnel".
func init() {
	handlers := registry.HandlerRegistry()
	handlers.Register("tunnel", NewHandler)
}
// tunnelHandler is the handler registered under the name "tunnel". It reads a
// relay request from each accepted connection and dispatches it to the
// connect or bind path.
type tunnelHandler struct {
// id is the random UUID string generated in Init; it is used as the "node"
// log field and handed to the connector pool and the entrypoint.
id string
// options holds the handler options supplied to NewHandler.
options handler.Options
// pool is created in Init by NewConnectorPool and closed in Close.
pool *ConnectorPool
// recorder is the router recorder whose Record equals
// xrecorder.RecorderServiceHandlerTunnel, if one is configured.
recorder recorder.Recorder
// epSvc is the entrypoint service started by initEntrypoint; it stays nil
// when no entrypoint address is configured.
epSvc service.Service
// ep is the entrypoint built in Init and served by epSvc.
ep *entrypoint
// md holds the settings filled in by parseMetadata.
md metadata
// log is the handler logger, tagged with the "node" field in Init.
log logger.Logger
}
// NewHandler returns a tunnel handler configured by applying opts, in order,
// to its options. The handler must be initialised with Init before use.
func NewHandler(opts ...handler.Option) handler.Handler {
	h := &tunnelHandler{}
	for _, apply := range opts {
		apply(&h.options)
	}
	return h
}
// Init prepares the handler for serving.
//
// It parses md, assigns the handler a random UUID as its node ID, derives the
// node-tagged logger, picks the tunnel recorder from the router options (if
// any), creates the connector pool and the entrypoint, and finally starts the
// entrypoint service through initEntrypoint. The first error encountered is
// returned.
func (h *tunnelHandler) Init(md md.Metadata) (err error) {
	if err = h.parseMetadata(md); err != nil {
		return err
	}

	nodeID, err := uuid.NewRandom()
	if err != nil {
		return err
	}
	h.id = nodeID.String()

	h.log = h.options.Logger.WithFields(map[string]any{"node": h.id})

	// Only the first recorder registered for the tunnel handler is used.
	if routerOpts := h.options.Router.Options(); routerOpts != nil {
		for _, ro := range routerOpts.Recorders {
			if ro.Record != xrecorder.RecorderServiceHandlerTunnel {
				continue
			}
			h.recorder = ro.Recorder
			break
		}
	}

	h.pool = NewConnectorPool(h.id, h.md.sd)

	epLog := h.log.WithFields(map[string]any{"kind": "entrypoint"})
	h.ep = &entrypoint{
		node:    h.id,
		pool:    h.pool,
		ingress: h.md.ingress,
		sd:      h.md.sd,
		log:     epLog,
	}

	return h.initEntrypoint()
}
// initEntrypoint starts the public entrypoint service when an entrypoint
// address is configured; with an empty address it does nothing and returns nil.
//
// It listens on the configured address (using "tcp4" when xnet.IsIPv4 reports
// an IPv4 address, "tcp" otherwise), wraps the listener and an
// entrypointHandler into a service stored in h.epSvc, and serves it in a
// background goroutine that runs until the service is closed (see Close).
//
// If any step after the socket has been opened fails, the socket is closed
// again so the address is not left bound by a handler that failed to
// initialise.
func (h *tunnelHandler) initEntrypoint() (err error) {
	if h.md.entryPoint == "" {
		return nil
	}

	network := "tcp"
	if xnet.IsIPv4(h.md.entryPoint) {
		network = "tcp4"
	}

	ln, err := net.Listen(network, h.md.entryPoint)
	if err != nil {
		h.log.Error(err)
		return err
	}
	// Release the listening socket on every later failure path; without this
	// a failing Init below would leak the bound port.
	defer func() {
		if err != nil {
			ln.Close()
		}
	}()

	serviceName := fmt.Sprintf("%s-ep-%s", h.options.Service, ln.Addr())
	log := h.log.WithFields(map[string]any{
		"service":  serviceName,
		"listener": "tcp",
		"handler":  "tunnel-ep",
		"kind":     "service",
	})

	epListener := newTCPListener(ln,
		listener.AddrOption(h.md.entryPoint),
		listener.ServiceOption(serviceName),
		listener.ProxyProtocolOption(h.md.entryPointProxyProtocol),
		listener.LoggerOption(log.WithFields(map[string]any{
			"kind": "listener",
		})),
	)
	if err = epListener.Init(nil); err != nil {
		return err
	}

	epHandler := &entrypointHandler{
		ep: h.ep,
	}
	if err = epHandler.Init(nil); err != nil {
		return err
	}

	h.epSvc = xservice.NewService(
		serviceName, epListener, epHandler,
		xservice.LoggerOption(log),
	)
	go h.epSvc.Serve()
	log.Infof("entrypoint: %s", h.epSvc.Addr())

	return nil
}
// Handle serves one connection accepted by the tunnel service.
//
// Steps, in order: apply the per-client rate limit; read a single relay
// request (bounded by the configured read timeout when it is positive);
// reject any version other than relay.Version1; collect the user, address,
// tunnel and network features; require a non-zero tunnel ID; authenticate
// through the configured Auther, if any; then dispatch on the masked command
// to handleConnect or handleBind.
//
// Protocol-level rejections write a relay response with a failure status
// before returning the matching sentinel error; write errors on that path are
// ignored. The connection is closed here whenever a non-nil error is
// returned, and always for the connect command. For a successful bind the
// connection is left open (handleBind is not visible here; presumably it
// keeps using the connection — verify).
func (h *tunnelHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.HandleOption) (err error) {
	began := time.Now()

	lg := h.log.WithFields(map[string]any{
		"remote": conn.RemoteAddr().String(),
		"local":  conn.LocalAddr().String(),
	})
	lg.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())

	// Reads the named result, so it must run after err has been set by the
	// return statements below. lg is captured by reference and therefore
	// includes the "user" field once it has been attached.
	defer func() {
		if err != nil {
			conn.Close()
		}
		elapsed := map[string]any{
			"duration": time.Since(began),
		}
		lg.WithFields(elapsed).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
	}()

	if allowed := h.checkRateLimit(conn.RemoteAddr()); !allowed {
		return ErrRateLimit
	}

	if timeout := h.md.readTimeout; timeout > 0 {
		conn.SetReadDeadline(time.Now().Add(timeout))
	}

	var req relay.Request
	if _, err = req.ReadFrom(conn); err != nil {
		return err
	}
	// The deadline only guards the initial request read.
	conn.SetReadDeadline(time.Time{})

	var resp relay.Response
	resp.Version = relay.Version1
	resp.Status = relay.StatusOK

	if req.Version != relay.Version1 {
		resp.Status = relay.StatusBadRequest
		resp.WriteTo(conn)
		return ErrBadVersion
	}

	var (
		user, pass       string
		srcAddr, dstAddr string
		tunnelID         relay.TunnelID
	)
	network := "tcp"

	for _, f := range req.Features {
		switch f.Type() {
		case relay.FeatureUserAuth:
			if auth, ok := f.(*relay.UserAuthFeature); ok && auth != nil {
				user, pass = auth.Username, auth.Password
			}
		case relay.FeatureAddr:
			if addr, ok := f.(*relay.AddrFeature); ok && addr != nil {
				hostport := net.JoinHostPort(addr.Host, strconv.Itoa(int(addr.Port)))
				// The first address feature is the source, any later one
				// overwrites the destination.
				if srcAddr == "" {
					srcAddr = hostport
				} else {
					dstAddr = hostport
				}
			}
		case relay.FeatureTunnel:
			if tun, ok := f.(*relay.TunnelFeature); ok && tun != nil {
				tunnelID = tun.ID
			}
		case relay.FeatureNetwork:
			if nw, ok := f.(*relay.NetworkFeature); ok && nw != nil {
				network = nw.Network.String()
			}
		}
	}

	if tunnelID.IsZero() {
		resp.Status = relay.StatusBadRequest
		resp.WriteTo(conn)
		return ErrTunnelID
	}

	if user != "" {
		lg = lg.WithFields(map[string]any{"user": user})
	}

	if auther := h.options.Auther; auther != nil {
		clientID, ok := auther.Authenticate(ctx, user, pass)
		if !ok {
			resp.Status = relay.StatusUnauthorized
			resp.WriteTo(conn)
			return ErrUnauthorized
		}
		ctx = ctxvalue.ContextWithClientID(ctx, ctxvalue.ClientID(clientID))
	}

	switch cmd := req.Cmd & relay.CmdMask; cmd {
	case relay.CmdBind:
		lg.Debugf("bind: %s >> %s/%s", srcAddr, dstAddr, network)
		return h.handleBind(ctx, conn, network, dstAddr, tunnelID, lg)

	case relay.CmdConnect:
		// Connect always releases the client connection on return.
		defer conn.Close()

		lg.Debugf("connect: %s >> %s/%s", srcAddr, dstAddr, network)
		return h.handleConnect(ctx, &req, conn, network, srcAddr, dstAddr, tunnelID, lg)

	default:
		resp.Status = relay.StatusBadRequest
		resp.WriteTo(conn)
		return ErrUnknownCmd
	}
}
// Close implements the io.Closer interface.
//
// It closes the entrypoint service, if one was started, and then the
// connector pool. Either step is skipped when the corresponding field was
// never set, so Close is safe on a handler whose Init failed before the pool
// was created. The results of the underlying Close calls are not propagated;
// Close always returns nil.
func (h *tunnelHandler) Close() error {
	if h.epSvc != nil {
		h.epSvc.Close()
	}
	// Init can return early (metadata or UUID failure) before assigning the
	// pool, so guard it the same way as the entrypoint service.
	if h.pool != nil {
		h.pool.Close()
	}
	return nil
}
// checkRateLimit reports whether a connection from addr may proceed.
//
// It returns true when no rate limiter is configured or when the limiter has
// no entry for the client. Otherwise it consumes one unit from the client's
// limiter and returns the limiter's verdict.
//
// The limiter key is the host part of addr. When addr cannot be split into
// host and port (for example an address without a port), the full address
// string is used as the key so that such clients are still told apart rather
// than all sharing the empty key. A nil addr yields the empty key.
func (h *tunnelHandler) checkRateLimit(addr net.Addr) bool {
	if h.options.RateLimiter == nil {
		return true
	}

	var host string
	if addr != nil {
		raw := addr.String()
		var err error
		if host, _, err = net.SplitHostPort(raw); err != nil {
			host = raw
		}
	}

	if limiter := h.options.RateLimiter.Limiter(host); limiter != nil {
		return limiter.Allow(1)
	}
	return true
}