add tunnel feature for relay

This commit is contained in:
ginuerzh
2023-01-14 13:15:15 +08:00
parent 9b128534a0
commit 82cd924c86
20 changed files with 1000 additions and 45 deletions

View File

@ -3,20 +3,26 @@ package relay
import (
"context"
"errors"
"fmt"
"net"
"strconv"
"time"
"github.com/go-gost/core/chain"
"github.com/go-gost/core/handler"
"github.com/go-gost/core/listener"
md "github.com/go-gost/core/metadata"
"github.com/go-gost/core/service"
"github.com/go-gost/relay"
"github.com/go-gost/x/registry"
xservice "github.com/go-gost/x/service"
)
var (
ErrBadVersion = errors.New("relay: bad version")
ErrUnknownCmd = errors.New("relay: unknown command")
ErrBadVersion = errors.New("relay: bad version")
ErrUnknownCmd = errors.New("relay: unknown command")
ErrUnauthorized = errors.New("relay: unauthorized")
ErrRateLimit = errors.New("relay: rate limiting exceeded")
)
func init() {
@ -28,6 +34,8 @@ type relayHandler struct {
router *chain.Router
md metadata
options handler.Options
ep service.Service
pool *ConnectorPool
}
func NewHandler(opts ...handler.Option) handler.Handler {
@ -38,6 +46,7 @@ func NewHandler(opts ...handler.Option) handler.Handler {
return &relayHandler{
options: options,
pool: NewConnectorPool(),
}
}
@ -51,17 +60,61 @@ func (h *relayHandler) Init(md md.Metadata) (err error) {
h.router = chain.NewRouter(chain.LoggerRouterOption(h.options.Logger))
}
if err = h.initEntryPoint(); err != nil {
return
}
return nil
}
func (h *relayHandler) initEntryPoint() (err error) {
if h.md.entryPoint == "" {
return
}
serviceName := fmt.Sprintf("%s-ep", h.options.Service)
log := h.options.Logger.WithFields(map[string]any{
"service": serviceName,
"listener": "tunnel",
"handler": "tunnel",
})
epListener := NewEntryPointListener(
listener.AddrOption(h.md.entryPoint),
listener.ServiceOption(serviceName),
listener.LoggerOption(log.WithFields(map[string]any{
"kind": "listener",
})),
)
if err = epListener.Init(nil); err != nil {
return
}
epHandler := NewEntryPointHandler(
h.pool,
h.md.ingress,
handler.ServiceOption(serviceName),
handler.LoggerOption(log.WithFields(map[string]any{
"kind": "handler",
})),
)
if err = epHandler.Init(nil); err != nil {
return
}
h.ep = xservice.NewService(
serviceName, epListener, epHandler,
xservice.LoggerOption(log),
)
go h.ep.Serve()
log.Infof("entrypoint: %s", h.ep.Addr())
return
}
// Forward implements handler.Forwarder.
func (h *relayHandler) Forward(hop chain.Hop) {
h.hop = hop
}
func (h *relayHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.HandleOption) error {
defer conn.Close()
func (h *relayHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.HandleOption) (err error) {
start := time.Now()
log := h.options.Logger.WithFields(map[string]any{
"remote": conn.RemoteAddr().String(),
@ -69,14 +122,19 @@ func (h *relayHandler) Handle(ctx context.Context, conn net.Conn, opts ...handle
})
log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
var tunnelID relay.TunnelID
defer func() {
if tunnelID.IsZero() || err != nil {
conn.Close()
}
log.WithFields(map[string]any{
"duration": time.Since(start),
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
}()
if !h.checkRateLimit(conn.RemoteAddr()) {
return nil
return ErrRateLimit
}
if h.md.readTimeout > 0 {
@ -85,28 +143,38 @@ func (h *relayHandler) Handle(ctx context.Context, conn net.Conn, opts ...handle
req := relay.Request{}
if _, err := req.ReadFrom(conn); err != nil {
log.Error(err)
return err
}
conn.SetReadDeadline(time.Time{})
resp := relay.Response{
Version: relay.Version1,
Status: relay.StatusOK,
}
if req.Version != relay.Version1 {
err := ErrBadVersion
log.Error(err)
return err
resp.Status = relay.StatusBadRequest
resp.WriteTo(conn)
return ErrBadVersion
}
var user, pass string
var address string
for _, f := range req.Features {
if f.Type() == relay.FeatureUserAuth {
feature := f.(*relay.UserAuthFeature)
user, pass = feature.Username, feature.Password
}
if f.Type() == relay.FeatureAddr {
feature := f.(*relay.AddrFeature)
address = net.JoinHostPort(feature.Host, strconv.Itoa(int(feature.Port)))
switch f.Type() {
case relay.FeatureUserAuth:
if feature, _ := f.(*relay.UserAuthFeature); feature != nil {
user, pass = feature.Username, feature.Password
}
case relay.FeatureAddr:
if feature, _ := f.(*relay.AddrFeature); feature != nil {
address = net.JoinHostPort(feature.Host, strconv.Itoa(int(feature.Port)))
}
case relay.FeatureTunnel:
if feature, _ := f.(*relay.TunnelFeature); feature != nil {
tunnelID = feature.ID
}
}
}
@ -114,19 +182,15 @@ func (h *relayHandler) Handle(ctx context.Context, conn net.Conn, opts ...handle
log = log.WithFields(map[string]any{"user": user})
}
resp := relay.Response{
Version: relay.Version1,
Status: relay.StatusOK,
}
if h.options.Auther != nil && !h.options.Auther.Authenticate(user, pass) {
if h.options.Auther != nil &&
!h.options.Auther.Authenticate(user, pass) {
resp.Status = relay.StatusUnauthorized
log.Error("unauthorized")
_, err := resp.WriteTo(conn)
return err
resp.WriteTo(conn)
return ErrUnauthorized
}
network := "tcp"
if (req.Flags & relay.FUDP) == relay.FUDP {
if (req.Cmd & relay.FUDP) == relay.FUDP {
network = "udp"
}
@ -141,13 +205,19 @@ func (h *relayHandler) Handle(ctx context.Context, conn net.Conn, opts ...handle
return h.handleForward(ctx, conn, network, log)
}
switch req.Flags & relay.CmdMask {
case 0, relay.CONNECT:
switch req.Cmd & relay.CmdMask {
case 0, relay.CmdConnect:
return h.handleConnect(ctx, conn, network, address, log)
case relay.BIND:
case relay.CmdBind:
if !tunnelID.IsZero() {
return h.handleTunnel(ctx, conn, tunnelID, log)
}
return h.handleBind(ctx, conn, network, address, log)
default:
resp.Status = relay.StatusBadRequest
resp.WriteTo(conn)
return ErrUnknownCmd
}
return ErrUnknownCmd
}
func (h *relayHandler) checkRateLimit(addr net.Addr) bool {