add tunnel feature for relay

This commit is contained in:
ginuerzh
2023-01-14 13:15:15 +08:00
parent 9b128534a0
commit 82cd924c86
20 changed files with 1000 additions and 45 deletions

View File

@ -12,6 +12,7 @@ import (
"github.com/go-gost/x/internal/net/udp"
"github.com/go-gost/x/internal/util/mux"
relay_util "github.com/go-gost/x/internal/util/relay"
"github.com/google/uuid"
)
func (h *relayHandler) handleBind(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) error {
@ -191,3 +192,49 @@ func (h *relayHandler) serveTCPBind(ctx context.Context, conn net.Conn, ln net.L
}(rc)
}
}
func (h *relayHandler) handleTunnel(ctx context.Context, conn net.Conn, tunnelID relay.TunnelID, log logger.Logger) (err error) {
resp := relay.Response{
Version: relay.Version1,
Status: relay.StatusOK,
}
if h.ep == nil {
resp.Status = relay.StatusServiceUnavailable
resp.WriteTo(conn)
return
}
uuid, err := uuid.NewRandom()
if err != nil {
resp.Status = relay.StatusInternalServerError
resp.WriteTo(conn)
return
}
var connectorID relay.ConnectorID
copy(connectorID[:], uuid[:])
af := &relay.AddrFeature{}
err = af.ParseFrom(h.ep.Addr().String())
if err != nil {
log.Warn(err)
}
resp.Features = append(resp.Features, af,
&relay.TunnelFeature{
ID: connectorID,
},
)
resp.WriteTo(conn)
// Upgrade connection to multiplex session.
session, err := mux.ClientSession(conn)
if err != nil {
return
}
h.pool.Add(tunnelID, NewConnector(connectorID, session))
log.Debugf("tunnel %s connector %s established", tunnelID, connectorID)
return
}

170
handler/relay/entrypoint.go Normal file
View File

@ -0,0 +1,170 @@
package relay
import (
"context"
"fmt"
"io"
"net"
"time"
"github.com/go-gost/core/handler"
"github.com/go-gost/core/ingress"
"github.com/go-gost/core/listener"
md "github.com/go-gost/core/metadata"
"github.com/go-gost/relay"
admission "github.com/go-gost/x/admission/wrapper"
xnet "github.com/go-gost/x/internal/net"
"github.com/go-gost/x/internal/net/proxyproto"
"github.com/go-gost/x/internal/util/forward"
climiter "github.com/go-gost/x/limiter/conn/wrapper"
limiter "github.com/go-gost/x/limiter/traffic/wrapper"
metrics "github.com/go-gost/x/metrics/wrapper"
"github.com/google/uuid"
)
type epListener struct {
ln net.Listener
options listener.Options
}
func NewEntryPointListener(opts ...listener.Option) listener.Listener {
options := listener.Options{}
for _, opt := range opts {
opt(&options)
}
return &epListener{
options: options,
}
}
func (l *epListener) Init(md md.Metadata) (err error) {
network := "tcp"
if xnet.IsIPv4(l.options.Addr) {
network = "tcp4"
}
ln, err := net.Listen(network, l.options.Addr)
if err != nil {
return
}
// l.logger.Debugf("pp: %d", l.options.ProxyProtocol)
ln = metrics.WrapListener(l.options.Service, ln)
ln = proxyproto.WrapListener(l.options.ProxyProtocol, ln, 10*time.Second)
ln = admission.WrapListener(l.options.Admission, ln)
ln = limiter.WrapListener(l.options.TrafficLimiter, ln)
ln = climiter.WrapListener(l.options.ConnLimiter, ln)
l.ln = ln
return
}
func (l *epListener) Accept() (conn net.Conn, err error) {
return l.ln.Accept()
}
func (l *epListener) Addr() net.Addr {
return l.ln.Addr()
}
func (l *epListener) Close() error {
return l.ln.Close()
}
type epHandler struct {
pool *ConnectorPool
ingress ingress.Ingress
options handler.Options
}
func NewEntryPointHandler(pool *ConnectorPool, ingress ingress.Ingress, opts ...handler.Option) handler.Handler {
options := handler.Options{}
for _, opt := range opts {
opt(&options)
}
return &epHandler{
pool: pool,
ingress: ingress,
options: options,
}
}
func (h *epHandler) Init(md md.Metadata) (err error) {
return
}
func (h *epHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.HandleOption) error {
defer conn.Close()
start := time.Now()
log := h.options.Logger.WithFields(map[string]any{
"remote": conn.RemoteAddr().String(),
"local": conn.LocalAddr().String(),
})
log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
defer func() {
log.WithFields(map[string]any{
"duration": time.Since(start),
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
}()
var rw io.ReadWriter = conn
var host string
var protocol string
rw, host, protocol, _ = forward.Sniffing(ctx, conn)
h.options.Logger.Debugf("sniffing: host=%s, protocol=%s", host, protocol)
var tunnelID relay.TunnelID
if h.ingress != nil {
v := h.ingress.Get(host)
uuid, _ := uuid.Parse(v)
copy(tunnelID[:], uuid[:])
}
log = log.WithFields(map[string]any{
"tunnel": tunnelID.String(),
})
var cc net.Conn
var err error
for i := 0; i < 3; i++ {
c := h.pool.Get(tunnelID)
if c == nil {
err = fmt.Errorf("tunnel %s not available", tunnelID.String())
break
}
cc, err = c.Session().GetConn()
if err != nil {
log.Error(err)
continue
}
break
}
if err != nil {
log.Error(err)
return err
}
defer cc.Close()
log.Debugf("%s >> %s", conn.RemoteAddr(), cc.RemoteAddr())
af := &relay.AddrFeature{}
af.ParseFrom(conn.RemoteAddr().String())
resp := relay.Response{
Version: relay.Version1,
Status: relay.StatusOK,
Features: []relay.Feature{af},
}
resp.WriteTo(cc)
t := time.Now()
log.Debugf("%s <-> %s", conn.RemoteAddr(), cc.RemoteAddr())
xnet.Transport(rw, cc)
log.WithFields(map[string]any{
"duration": time.Since(t),
}).Debugf("%s >-< %s", conn.RemoteAddr(), cc.RemoteAddr())
return nil
}

View File

@ -3,20 +3,26 @@ package relay
import (
"context"
"errors"
"fmt"
"net"
"strconv"
"time"
"github.com/go-gost/core/chain"
"github.com/go-gost/core/handler"
"github.com/go-gost/core/listener"
md "github.com/go-gost/core/metadata"
"github.com/go-gost/core/service"
"github.com/go-gost/relay"
"github.com/go-gost/x/registry"
xservice "github.com/go-gost/x/service"
)
var (
ErrBadVersion = errors.New("relay: bad version")
ErrUnknownCmd = errors.New("relay: unknown command")
ErrBadVersion = errors.New("relay: bad version")
ErrUnknownCmd = errors.New("relay: unknown command")
ErrUnauthorized = errors.New("relay: unauthorized")
ErrRateLimit = errors.New("relay: rate limiting exceeded")
)
func init() {
@ -28,6 +34,8 @@ type relayHandler struct {
router *chain.Router
md metadata
options handler.Options
ep service.Service
pool *ConnectorPool
}
func NewHandler(opts ...handler.Option) handler.Handler {
@ -38,6 +46,7 @@ func NewHandler(opts ...handler.Option) handler.Handler {
return &relayHandler{
options: options,
pool: NewConnectorPool(),
}
}
@ -51,17 +60,61 @@ func (h *relayHandler) Init(md md.Metadata) (err error) {
h.router = chain.NewRouter(chain.LoggerRouterOption(h.options.Logger))
}
if err = h.initEntryPoint(); err != nil {
return
}
return nil
}
func (h *relayHandler) initEntryPoint() (err error) {
if h.md.entryPoint == "" {
return
}
serviceName := fmt.Sprintf("%s-ep", h.options.Service)
log := h.options.Logger.WithFields(map[string]any{
"service": serviceName,
"listener": "tunnel",
"handler": "tunnel",
})
epListener := NewEntryPointListener(
listener.AddrOption(h.md.entryPoint),
listener.ServiceOption(serviceName),
listener.LoggerOption(log.WithFields(map[string]any{
"kind": "listener",
})),
)
if err = epListener.Init(nil); err != nil {
return
}
epHandler := NewEntryPointHandler(
h.pool,
h.md.ingress,
handler.ServiceOption(serviceName),
handler.LoggerOption(log.WithFields(map[string]any{
"kind": "handler",
})),
)
if err = epHandler.Init(nil); err != nil {
return
}
h.ep = xservice.NewService(
serviceName, epListener, epHandler,
xservice.LoggerOption(log),
)
go h.ep.Serve()
log.Infof("entrypoint: %s", h.ep.Addr())
return
}
// Forward implements handler.Forwarder.
func (h *relayHandler) Forward(hop chain.Hop) {
h.hop = hop
}
func (h *relayHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.HandleOption) error {
defer conn.Close()
func (h *relayHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.HandleOption) (err error) {
start := time.Now()
log := h.options.Logger.WithFields(map[string]any{
"remote": conn.RemoteAddr().String(),
@ -69,14 +122,19 @@ func (h *relayHandler) Handle(ctx context.Context, conn net.Conn, opts ...handle
})
log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr())
var tunnelID relay.TunnelID
defer func() {
if tunnelID.IsZero() || err != nil {
conn.Close()
}
log.WithFields(map[string]any{
"duration": time.Since(start),
}).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr())
}()
if !h.checkRateLimit(conn.RemoteAddr()) {
return nil
return ErrRateLimit
}
if h.md.readTimeout > 0 {
@ -85,28 +143,38 @@ func (h *relayHandler) Handle(ctx context.Context, conn net.Conn, opts ...handle
req := relay.Request{}
if _, err := req.ReadFrom(conn); err != nil {
log.Error(err)
return err
}
conn.SetReadDeadline(time.Time{})
resp := relay.Response{
Version: relay.Version1,
Status: relay.StatusOK,
}
if req.Version != relay.Version1 {
err := ErrBadVersion
log.Error(err)
return err
resp.Status = relay.StatusBadRequest
resp.WriteTo(conn)
return ErrBadVersion
}
var user, pass string
var address string
for _, f := range req.Features {
if f.Type() == relay.FeatureUserAuth {
feature := f.(*relay.UserAuthFeature)
user, pass = feature.Username, feature.Password
}
if f.Type() == relay.FeatureAddr {
feature := f.(*relay.AddrFeature)
address = net.JoinHostPort(feature.Host, strconv.Itoa(int(feature.Port)))
switch f.Type() {
case relay.FeatureUserAuth:
if feature, _ := f.(*relay.UserAuthFeature); feature != nil {
user, pass = feature.Username, feature.Password
}
case relay.FeatureAddr:
if feature, _ := f.(*relay.AddrFeature); feature != nil {
address = net.JoinHostPort(feature.Host, strconv.Itoa(int(feature.Port)))
}
case relay.FeatureTunnel:
if feature, _ := f.(*relay.TunnelFeature); feature != nil {
tunnelID = feature.ID
}
}
}
@ -114,19 +182,15 @@ func (h *relayHandler) Handle(ctx context.Context, conn net.Conn, opts ...handle
log = log.WithFields(map[string]any{"user": user})
}
resp := relay.Response{
Version: relay.Version1,
Status: relay.StatusOK,
}
if h.options.Auther != nil && !h.options.Auther.Authenticate(user, pass) {
if h.options.Auther != nil &&
!h.options.Auther.Authenticate(user, pass) {
resp.Status = relay.StatusUnauthorized
log.Error("unauthorized")
_, err := resp.WriteTo(conn)
return err
resp.WriteTo(conn)
return ErrUnauthorized
}
network := "tcp"
if (req.Flags & relay.FUDP) == relay.FUDP {
if (req.Cmd & relay.FUDP) == relay.FUDP {
network = "udp"
}
@ -141,13 +205,19 @@ func (h *relayHandler) Handle(ctx context.Context, conn net.Conn, opts ...handle
return h.handleForward(ctx, conn, network, log)
}
switch req.Flags & relay.CmdMask {
case 0, relay.CONNECT:
switch req.Cmd & relay.CmdMask {
case 0, relay.CmdConnect:
return h.handleConnect(ctx, conn, network, address, log)
case relay.BIND:
case relay.CmdBind:
if !tunnelID.IsZero() {
return h.handleTunnel(ctx, conn, tunnelID, log)
}
return h.handleBind(ctx, conn, network, address, log)
default:
resp.Status = relay.StatusBadRequest
resp.WriteTo(conn)
return ErrUnknownCmd
}
return ErrUnknownCmd
}
func (h *relayHandler) checkRateLimit(addr net.Addr) bool {

View File

@ -4,8 +4,10 @@ import (
"math"
"time"
"github.com/go-gost/core/ingress"
mdata "github.com/go-gost/core/metadata"
mdutil "github.com/go-gost/core/metadata/util"
"github.com/go-gost/x/registry"
)
type metadata struct {
@ -14,6 +16,8 @@ type metadata struct {
udpBufferSize int
noDelay bool
hash string
entryPoint string
ingress ingress.Ingress
}
func (h *relayHandler) parseMetadata(md mdata.Metadata) (err error) {
@ -23,6 +27,8 @@ func (h *relayHandler) parseMetadata(md mdata.Metadata) (err error) {
udpBufferSize = "udpBufferSize"
noDelay = "nodelay"
hash = "hash"
entryPoint = "entryPoint"
ingress = "ingress"
)
h.md.readTimeout = mdutil.GetDuration(md, readTimeout)
@ -36,5 +42,9 @@ func (h *relayHandler) parseMetadata(md mdata.Metadata) (err error) {
}
h.md.hash = mdutil.GetString(md, hash)
h.md.entryPoint = mdutil.GetString(md, entryPoint)
h.md.ingress = registry.IngressRegistry().Get(mdutil.GetString(md, ingress))
return
}

149
handler/relay/tunnel.go Normal file
View File

@ -0,0 +1,149 @@
package relay
import (
"sync"
"sync/atomic"
"time"
"github.com/go-gost/core/logger"
"github.com/go-gost/relay"
"github.com/go-gost/x/internal/util/mux"
)
type Connector struct {
id relay.ConnectorID
t time.Time
s *mux.Session
}
func NewConnector(id relay.ConnectorID, s *mux.Session) *Connector {
c := &Connector{
id: id,
t: time.Now(),
s: s,
}
go c.accept()
return c
}
func (c *Connector) accept() {
for {
conn, err := c.s.Accept()
if err != nil {
logger.Default().Errorf("connector %s: %v", c.id, err)
c.s.Close()
return
}
conn.Close()
}
}
func (c *Connector) ID() relay.ConnectorID {
return c.id
}
func (c *Connector) Session() *mux.Session {
return c.s
}
type Tunnel struct {
id relay.TunnelID
connectors []*Connector
t time.Time
n uint64
mu sync.RWMutex
}
func NewTunnel(id relay.TunnelID) *Tunnel {
t := &Tunnel{
id: id,
t: time.Now(),
}
go t.clean()
return t
}
func (t *Tunnel) ID() relay.TunnelID {
return t.id
}
func (t *Tunnel) AddConnector(c *Connector) {
if c == nil {
return
}
t.mu.Lock()
defer t.mu.Unlock()
t.connectors = append(t.connectors, c)
}
func (t *Tunnel) GetConnector() *Connector {
t.mu.RLock()
defer t.mu.RUnlock()
if len(t.connectors) == 0 {
return nil
}
n := atomic.AddUint64(&t.n, 1) - 1
return t.connectors[n%uint64(len(t.connectors))]
}
func (t *Tunnel) clean() {
ticker := time.NewTicker(3 * time.Second)
for range ticker.C {
t.mu.Lock()
var connectors []*Connector
for _, c := range t.connectors {
if c.Session().IsClosed() {
logger.Default().Debugf("remove tunnel %s connector %s", t.id, c.id)
continue
}
connectors = append(connectors, c)
}
if len(connectors) != len(t.connectors) {
t.connectors = connectors
}
t.mu.Unlock()
}
}
type ConnectorPool struct {
tunnels map[relay.TunnelID]*Tunnel
mu sync.RWMutex
}
func NewConnectorPool() *ConnectorPool {
return &ConnectorPool{
tunnels: make(map[relay.TunnelID]*Tunnel),
}
}
func (p *ConnectorPool) Add(tid relay.TunnelID, c *Connector) {
p.mu.Lock()
defer p.mu.Unlock()
t := p.tunnels[tid]
if t == nil {
t = NewTunnel(tid)
p.tunnels[tid] = t
}
t.AddConnector(c)
}
func (p *ConnectorPool) Get(tid relay.TunnelID) *Connector {
if p == nil {
return nil
}
p.mu.RLock()
defer p.mu.RUnlock()
t := p.tunnels[tid]
if t == nil {
return nil
}
return t.GetConnector()
}