From 497915f465084457c584823f858da039e1910f07 Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Sun, 15 Oct 2023 15:39:25 +0800 Subject: [PATCH] add tunnel handler and connector --- connector/relay/bind.go | 26 ++-- connector/relay/conn.go | 13 ++ connector/tunnel/bind.go | 95 ++++++++++++++ connector/tunnel/conn.go | 224 ++++++++++++++++++++++++++++++++++ connector/tunnel/connector.go | 102 ++++++++++++++++ connector/tunnel/listener.go | 99 +++++++++++++++ connector/tunnel/metadata.go | 42 +++++++ go.mod | 4 +- go.sum | 8 +- handler/relay/bind.go | 19 +-- handler/relay/connect.go | 15 +-- handler/relay/handler.go | 26 ++-- handler/relay/metadata.go | 3 - handler/tunnel/bind.go | 67 ++++++++++ handler/tunnel/conn.go | 26 ++++ handler/tunnel/connect.go | 105 ++++++++++++++++ handler/tunnel/handler.go | 198 ++++++++++++++++++++++++++++++ handler/tunnel/metadata.go | 62 ++++++++++ handler/tunnel/tunnel.go | 200 ++++++++++++++++++++++++++++++ ingress/ingress.go | 4 + ingress/plugin.go | 66 +++++++++- listener/rtcp/listener.go | 21 +++- recorder/recorder.go | 4 +- registry/ingress.go | 9 ++ 24 files changed, 1375 insertions(+), 63 deletions(-) create mode 100644 connector/tunnel/bind.go create mode 100644 connector/tunnel/conn.go create mode 100644 connector/tunnel/connector.go create mode 100644 connector/tunnel/listener.go create mode 100644 connector/tunnel/metadata.go create mode 100644 handler/tunnel/bind.go create mode 100644 handler/tunnel/conn.go create mode 100644 handler/tunnel/connect.go create mode 100644 handler/tunnel/handler.go create mode 100644 handler/tunnel/metadata.go create mode 100644 handler/tunnel/tunnel.go diff --git a/connector/relay/bind.go b/connector/relay/bind.go index 2cc5751..e703378 100644 --- a/connector/relay/bind.go +++ b/connector/relay/bind.go @@ -17,7 +17,7 @@ import ( // Bind implements connector.Binder. func (c *relayConnector) Bind(ctx context.Context, conn net.Conn, network, address string, opts ...connector.BindOption) (net.Listener, error) { if !c.md.tunnelID.IsZero() { - return c.bindTunnel(ctx, conn, network, c.options.Logger) + return c.bindTunnel(ctx, conn, network, address, c.options.Logger) } log := c.options.Logger.WithFields(map[string]any{ @@ -43,12 +43,12 @@ func (c *relayConnector) Bind(ctx context.Context, conn net.Conn, network, addre } } -func (c *relayConnector) bindTunnel(ctx context.Context, conn net.Conn, network string, log logger.Logger) (net.Listener, error) { - addr, cid, err := c.initTunnel(conn, network) +func (c *relayConnector) bindTunnel(ctx context.Context, conn net.Conn, network, address string, log logger.Logger) (net.Listener, error) { + addr, cid, err := c.initTunnel(conn, network, address) if err != nil { return nil, err } - log.Debugf("create tunnel %s connector %s/%s OK", c.md.tunnelID.String(), cid, network) + log.Infof("create tunnel on %s/%s OK, tunnel=%s, connector=%s", addr, network, c.md.tunnelID.String(), cid) session, err := mux.ServerSession(conn) if err != nil { @@ -63,7 +63,7 @@ func (c *relayConnector) bindTunnel(ctx context.Context, conn net.Conn, network }, nil } -func (c *relayConnector) initTunnel(conn net.Conn, network string) (addr net.Addr, cid relay.ConnectorID, err error) { +func (c *relayConnector) initTunnel(conn net.Conn, network, address string) (addr net.Addr, cid relay.ConnectorID, err error) { req := relay.Request{ Version: relay.Version1, Cmd: relay.CmdBind, @@ -81,9 +81,14 @@ func (c *relayConnector) initTunnel(conn net.Conn, network string) (addr net.Add }) } - req.Features = append(req.Features, &relay.TunnelFeature{ - ID: c.md.tunnelID.ID(), - }) + af := &relay.AddrFeature{} + af.ParseFrom(address) + + req.Features = append(req.Features, af, + &relay.TunnelFeature{ + ID: c.md.tunnelID.ID(), + }, + ) if _, err = req.WriteTo(conn); err != nil { return } @@ -103,7 +108,10 @@ func (c *relayConnector) initTunnel(conn net.Conn, network string) (addr net.Add switch f.Type() { case relay.FeatureAddr: if feature, _ := f.(*relay.AddrFeature); feature != nil { - addr, err = net.ResolveTCPAddr("tcp", net.JoinHostPort(feature.Host, strconv.Itoa(int(feature.Port)))) + addr = &bindAddr{ + network: network, + addr: net.JoinHostPort(feature.Host, strconv.Itoa(int(feature.Port))), + } } case relay.FeatureTunnel: if feature, _ := f.(*relay.TunnelFeature); feature != nil { diff --git a/connector/relay/conn.go b/connector/relay/conn.go index 7a42e37..384aaa0 100644 --- a/connector/relay/conn.go +++ b/connector/relay/conn.go @@ -209,3 +209,16 @@ func (c *bindUDPConn) RemoteAddr() net.Addr { func (c *bindUDPConn) Metadata() mdata.Metadata { return c.md } + +type bindAddr struct { + network string + addr string +} + +func (p *bindAddr) Network() string { + return p.network +} + +func (p *bindAddr) String() string { + return p.addr +} diff --git a/connector/tunnel/bind.go b/connector/tunnel/bind.go new file mode 100644 index 0000000..58c2168 --- /dev/null +++ b/connector/tunnel/bind.go @@ -0,0 +1,95 @@ +package tunnel + +import ( + "context" + "fmt" + "net" + "strconv" + + "github.com/go-gost/core/connector" + "github.com/go-gost/relay" + "github.com/go-gost/x/internal/util/mux" +) + +// Bind implements connector.Binder. +func (c *tunnelConnector) Bind(ctx context.Context, conn net.Conn, network, address string, opts ...connector.BindOption) (net.Listener, error) { + log := c.options.Logger + + addr, cid, err := c.initTunnel(conn, network, address) + if err != nil { + return nil, err + } + log.Infof("create tunnel on %s/%s OK, tunnel=%s, connector=%s", addr, network, c.md.tunnelID.String(), cid) + + session, err := mux.ServerSession(conn) + if err != nil { + return nil, err + } + + return &bindListener{ + network: network, + addr: addr, + session: session, + logger: log, + }, nil +} + +func (c *tunnelConnector) initTunnel(conn net.Conn, network, address string) (addr net.Addr, cid relay.ConnectorID, err error) { + req := relay.Request{ + Version: relay.Version1, + Cmd: relay.CmdBind, + } + + if network == "udp" { + req.Cmd |= relay.FUDP + } + + if c.options.Auth != nil { + pwd, _ := c.options.Auth.Password() + req.Features = append(req.Features, &relay.UserAuthFeature{ + Username: c.options.Auth.Username(), + Password: pwd, + }) + } + + af := &relay.AddrFeature{} + af.ParseFrom(address) + + req.Features = append(req.Features, af, + &relay.TunnelFeature{ + ID: c.md.tunnelID.ID(), + }, + ) + if _, err = req.WriteTo(conn); err != nil { + return + } + + // first reply, bind status + resp := relay.Response{} + if _, err = resp.ReadFrom(conn); err != nil { + return + } + + if resp.Status != relay.StatusOK { + err = fmt.Errorf("%d: create tunnel %s failed", resp.Status, c.md.tunnelID.String()) + return + } + + for _, f := range resp.Features { + switch f.Type() { + case relay.FeatureAddr: + if feature, _ := f.(*relay.AddrFeature); feature != nil { + addr = &bindAddr{ + network: network, + addr: net.JoinHostPort(feature.Host, strconv.Itoa(int(feature.Port))), + } + } + case relay.FeatureTunnel: + if feature, _ := f.(*relay.TunnelFeature); feature != nil { + cid = relay.NewConnectorID(feature.ID[:]) + } + } + } + + return +} diff --git a/connector/tunnel/conn.go b/connector/tunnel/conn.go new file mode 100644 index 0000000..0929981 --- /dev/null +++ b/connector/tunnel/conn.go @@ -0,0 +1,224 @@ +package tunnel + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" + "math" + "net" + "sync" + + "github.com/go-gost/core/common/bufpool" + mdata "github.com/go-gost/core/metadata" + "github.com/go-gost/relay" +) + +type tcpConn struct { + net.Conn + wbuf bytes.Buffer + once sync.Once +} + +func (c *tcpConn) Read(b []byte) (n int, err error) { + c.once.Do(func() { + err = readResponse(c.Conn) + }) + + if err != nil { + return + } + return c.Conn.Read(b) +} + +func (c *tcpConn) Write(b []byte) (n int, err error) { + n = len(b) // force byte length consistent + if c.wbuf.Len() > 0 { + c.wbuf.Write(b) // append the data to the cached header + _, err = c.Conn.Write(c.wbuf.Bytes()) + c.wbuf.Reset() + return + } + _, err = c.Conn.Write(b) + return +} + +type udpConn struct { + net.Conn + wbuf bytes.Buffer + once sync.Once +} + +func (c *udpConn) Read(b []byte) (n int, err error) { + c.once.Do(func() { + err = readResponse(c.Conn) + }) + if err != nil { + return + } + + var bb [2]byte + _, err = io.ReadFull(c.Conn, bb[:]) + if err != nil { + return + } + + dlen := int(binary.BigEndian.Uint16(bb[:])) + if len(b) >= dlen { + return io.ReadFull(c.Conn, b[:dlen]) + } + + buf := bufpool.Get(dlen) + defer bufpool.Put(buf) + _, err = io.ReadFull(c.Conn, *buf) + n = copy(b, *buf) + + return +} + +func (c *udpConn) Write(b []byte) (n int, err error) { + if len(b) > math.MaxUint16 { + err = errors.New("write: data maximum exceeded") + return + } + + n = len(b) + if c.wbuf.Len() > 0 { + var bb [2]byte + binary.BigEndian.PutUint16(bb[:], uint16(len(b))) + c.wbuf.Write(bb[:]) + c.wbuf.Write(b) // append the data to the cached header + _, err = c.wbuf.WriteTo(c.Conn) + return + } + + var bb [2]byte + binary.BigEndian.PutUint16(bb[:], uint16(len(b))) + _, err = c.Conn.Write(bb[:]) + if err != nil { + return + } + return c.Conn.Write(b) +} + +func readResponse(r io.Reader) (err error) { + resp := relay.Response{} + _, err = resp.ReadFrom(r) + if err != nil { + return + } + + if resp.Version != relay.Version1 { + err = relay.ErrBadVersion + return + } + + if resp.Status != relay.StatusOK { + err = fmt.Errorf("status %d", resp.Status) + return + } + return nil +} + +type bindConn struct { + net.Conn + localAddr net.Addr + remoteAddr net.Addr + md mdata.Metadata +} + +func (c *bindConn) LocalAddr() net.Addr { + return c.localAddr +} + +func (c *bindConn) RemoteAddr() net.Addr { + return c.remoteAddr +} + +// Metadata implements metadata.Metadatable interface. +func (c *bindConn) Metadata() mdata.Metadata { + return c.md +} + +type bindUDPConn struct { + net.Conn + localAddr net.Addr + remoteAddr net.Addr + md mdata.Metadata +} + +func (c *bindUDPConn) Read(b []byte) (n int, err error) { + // 2-byte data length header + var bh [2]byte + _, err = io.ReadFull(c.Conn, bh[:]) + if err != nil { + return + } + + dlen := int(binary.BigEndian.Uint16(bh[:])) + if len(b) >= dlen { + n, err = io.ReadFull(c.Conn, b[:dlen]) + return + } + + buf := bufpool.Get(dlen) + defer bufpool.Put(buf) + + _, err = io.ReadFull(c.Conn, *buf) + n = copy(b, *buf) + + return +} + +func (c *bindUDPConn) Write(b []byte) (n int, err error) { + if len(b) > math.MaxUint16 { + err = errors.New("write: data maximum exceeded") + return + } + + // 2-byte data length header + var bh [2]byte + binary.BigEndian.PutUint16(bh[:], uint16(len(b))) + _, err = c.Conn.Write(bh[:]) + if err != nil { + return + } + return c.Conn.Write(b) +} + +func (c *bindUDPConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { + addr = c.remoteAddr + n, err = c.Read(b) + return +} + +func (c *bindUDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { + return c.Write(b) +} + +func (c *bindUDPConn) LocalAddr() net.Addr { + return c.localAddr +} + +func (c *bindUDPConn) RemoteAddr() net.Addr { + return c.remoteAddr +} + +// Metadata implements metadata.Metadatable interface. +func (c *bindUDPConn) Metadata() mdata.Metadata { + return c.md +} + +type bindAddr struct { + network string + addr string +} + +func (p *bindAddr) Network() string { + return p.network +} + +func (p *bindAddr) String() string { + return p.addr +} diff --git a/connector/tunnel/connector.go b/connector/tunnel/connector.go new file mode 100644 index 0000000..960ad3f --- /dev/null +++ b/connector/tunnel/connector.go @@ -0,0 +1,102 @@ +package tunnel + +import ( + "context" + "fmt" + "net" + "time" + + "github.com/go-gost/core/connector" + md "github.com/go-gost/core/metadata" + "github.com/go-gost/relay" + "github.com/go-gost/x/registry" +) + +func init() { + registry.ConnectorRegistry().Register("tunnel", NewConnector) +} + +type tunnelConnector struct { + md metadata + options connector.Options +} + +func NewConnector(opts ...connector.Option) connector.Connector { + options := connector.Options{} + for _, opt := range opts { + opt(&options) + } + + return &tunnelConnector{ + options: options, + } +} + +func (c *tunnelConnector) Init(md md.Metadata) (err error) { + return c.parseMetadata(md) +} + +func (c *tunnelConnector) Connect(ctx context.Context, conn net.Conn, network, address string, opts ...connector.ConnectOption) (net.Conn, error) { + log := c.options.Logger.WithFields(map[string]any{ + "remote": conn.RemoteAddr().String(), + "local": conn.LocalAddr().String(), + "network": network, + "address": address, + }) + log.Debugf("connect %s/%s", address, network) + + if c.md.connectTimeout > 0 { + conn.SetDeadline(time.Now().Add(c.md.connectTimeout)) + defer conn.SetDeadline(time.Time{}) + } + + req := relay.Request{ + Version: relay.Version1, + Cmd: relay.CmdConnect, + } + + if c.options.Auth != nil { + pwd, _ := c.options.Auth.Password() + req.Features = append(req.Features, &relay.UserAuthFeature{ + Username: c.options.Auth.Username(), + Password: pwd, + }) + } + + if address != "" { + af := &relay.AddrFeature{} + if err := af.ParseFrom(address); err != nil { + return nil, err + } + req.Features = append(req.Features, af) + } + + req.Features = append(req.Features, &relay.TunnelFeature{ + ID: c.md.tunnelID.ID(), + }) + + switch network { + case "tcp", "tcp4", "tcp6", "unix", "serial": + cc := &tcpConn{ + Conn: conn, + } + if _, err := req.WriteTo(&cc.wbuf); err != nil { + return nil, err + } + conn = cc + case "udp", "udp4", "udp6": + cc := &udpConn{ + Conn: conn, + } + if _, err := req.WriteTo(&cc.wbuf); err != nil { + return nil, err + } + conn = cc + default: + err := fmt.Errorf("network %s is unsupported", network) + log.Error(err) + return nil, err + } + + return conn, nil +} diff --git a/connector/tunnel/listener.go b/connector/tunnel/listener.go new file mode 100644 index 0000000..5ffb981 --- /dev/null +++ b/connector/tunnel/listener.go @@ -0,0 +1,99 @@ +package tunnel + +import ( + "fmt" + "net" + "strconv" + + "github.com/go-gost/core/logger" + mdata "github.com/go-gost/core/metadata" + "github.com/go-gost/relay" + "github.com/go-gost/x/internal/util/mux" + mdx "github.com/go-gost/x/metadata" +) + +type bindListener struct { + network string + addr net.Addr + session *mux.Session + logger logger.Logger +} + +func (p *bindListener) Accept() (net.Conn, error) { + cc, err := p.session.Accept() + if err != nil { + return nil, err + } + + conn, err := p.getPeerConn(cc) + if err != nil { + cc.Close() + p.logger.Errorf("get peer failed: %s", err) + return nil, err + } + + return conn, nil +} + +func (p *bindListener) getPeerConn(conn net.Conn) (net.Conn, error) { + // second reply, peer connected + resp := relay.Response{} + if _, err := resp.ReadFrom(conn); err != nil { + return nil, err + } + + if resp.Status != relay.StatusOK { + err := fmt.Errorf("peer connect failed") + return nil, err + } + + var address, host string + // the first addr is the client address, the optional second addr is the target host address. + for _, f := range resp.Features { + if f.Type() == relay.FeatureAddr { + if fa, ok := f.(*relay.AddrFeature); ok { + v := net.JoinHostPort(fa.Host, strconv.Itoa(int(fa.Port))) + if address != "" { + host = v + } else { + address = v + } + } + } + } + + raddr, err := net.ResolveTCPAddr("tcp", address) + if err != nil { + return nil, err + } + + var md mdata.Metadata + if host != "" { + md = mdx.NewMetadata(map[string]any{"host": host}) + } + + if p.network == "udp" { + return &bindUDPConn{ + Conn: conn, + localAddr: p.addr, + remoteAddr: raddr, + md: md, + }, nil + } + + cn := &bindConn{ + Conn: conn, + localAddr: p.addr, + remoteAddr: raddr, + md: md, + } + return cn, nil +} + +func (p *bindListener) Addr() net.Addr { + return p.addr +} + +func (p *bindListener) Close() error { + return p.session.Close() +} diff --git a/connector/tunnel/metadata.go b/connector/tunnel/metadata.go new file mode 100644 index 0000000..cd879df --- /dev/null +++ b/connector/tunnel/metadata.go @@ -0,0 +1,42 @@ +package tunnel + +import ( + "errors" + "time" + + mdata "github.com/go-gost/core/metadata" + mdutil "github.com/go-gost/core/metadata/util" + "github.com/go-gost/relay" + "github.com/google/uuid" +) + +var ( + ErrInvalidTunnelID = errors.New("tunnel: invalid tunnel ID") +) + +type metadata struct { + connectTimeout time.Duration + tunnelID relay.TunnelID +} + +func (c *tunnelConnector) parseMetadata(md mdata.Metadata) (err error) { + const ( + connectTimeout = "connectTimeout" + noDelay = "nodelay" + ) + + c.md.connectTimeout = mdutil.GetDuration(md, connectTimeout) + + if s := mdutil.GetString(md, "tunnelID", "tunnel.id"); s != "" { + uuid, err := uuid.Parse(s) + if err != nil { + return err + } + c.md.tunnelID = relay.NewTunnelID(uuid[:]) + } + + if c.md.tunnelID.IsZero() { + return ErrInvalidTunnelID + } + return +} diff --git a/go.mod b/go.mod index 3bc9886..0d4b1b6 100644 --- a/go.mod +++ b/go.mod @@ -7,10 +7,10 @@ require ( github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d github.com/gin-contrib/cors v1.3.1 github.com/gin-gonic/gin v1.9.1 - github.com/go-gost/core v0.0.0-20231009132641-4525630abb98 + github.com/go-gost/core v0.0.0-20231015073540-f08c81460234 github.com/go-gost/gosocks4 v0.0.1 github.com/go-gost/gosocks5 v0.4.0 - github.com/go-gost/plugin v0.0.0-20230930094933-bc86458bf2fb + github.com/go-gost/plugin v0.0.0-20231015073745-fc558b9a3146 github.com/go-gost/relay v0.4.1-0.20230916134211-828f314ddfe7 github.com/go-gost/tls-dissector v0.0.2-0.20220408131628-aac992c27451 github.com/go-redis/redis/v8 v8.11.5 diff --git a/go.sum b/go.sum index fdac77f..86744f3 100644 --- a/go.sum +++ b/go.sum @@ -91,14 +91,14 @@ github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SU github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= -github.com/go-gost/core v0.0.0-20231009132641-4525630abb98 h1:4LJ7HjwW9uOJxhBBdyTpZcy0+TSaj0lkyaqCIXEn1PY= -github.com/go-gost/core v0.0.0-20231009132641-4525630abb98/go.mod h1:ndkgWVYRLwupVaFFWv8ML1Nr8tD3xhHK245PLpUDg4E= +github.com/go-gost/core v0.0.0-20231015073540-f08c81460234 h1:xTuwKXHTVStuO6RTXF1iGr4Y5XMckRjYbisLYO3Mg2Y= +github.com/go-gost/core v0.0.0-20231015073540-f08c81460234/go.mod h1:ndkgWVYRLwupVaFFWv8ML1Nr8tD3xhHK245PLpUDg4E= github.com/go-gost/gosocks4 v0.0.1 h1:+k1sec8HlELuQV7rWftIkmy8UijzUt2I6t+iMPlGB2s= github.com/go-gost/gosocks4 v0.0.1/go.mod h1:3B6L47HbU/qugDg4JnoFPHgJXE43Inz8Bah1QaN9qCc= github.com/go-gost/gosocks5 v0.4.0 h1:EIrOEkpJez4gwHrMa33frA+hHXJyevjp47thpMQsJzI= github.com/go-gost/gosocks5 v0.4.0/go.mod h1:1G6I7HP7VFVxveGkoK8mnprnJqSqJjdcASKsdUn4Pp4= -github.com/go-gost/plugin v0.0.0-20230930094933-bc86458bf2fb h1:pJP1zrNLyKPsDQhL+ITyP2uCaS4Kax9T4ap2dZF3QaM= -github.com/go-gost/plugin v0.0.0-20230930094933-bc86458bf2fb/go.mod h1:mM/RLNsVy2nz5PiOijuqLYR3LhMzyQ9Kh/p0rXybJoo= +github.com/go-gost/plugin v0.0.0-20231015073745-fc558b9a3146 h1:ggHOq9ozazmWAgfUYaCXoNS4vbFNsRa2CSJUAwKdstg= +github.com/go-gost/plugin v0.0.0-20231015073745-fc558b9a3146/go.mod h1:mM/RLNsVy2nz5PiOijuqLYR3LhMzyQ9Kh/p0rXybJoo= github.com/go-gost/relay v0.4.1-0.20230916134211-828f314ddfe7 h1:qAG1OyjvdA5h221CfFSS3J359V3d2E7dJWyP29QoDSI= github.com/go-gost/relay v0.4.1-0.20230916134211-828f314ddfe7/go.mod h1:lcX+23LCQ3khIeASBo+tJ/WbwXFO32/N5YN6ucuYTG8= github.com/go-gost/tls-dissector v0.0.2-0.20220408131628-aac992c27451 h1:xj8gUZGYO3nb5+6Bjw9+tsFkA9sYynrOvDvvC4uDV2I= diff --git a/handler/relay/bind.go b/handler/relay/bind.go index 7d498a8..e9e4fb3 100644 --- a/handler/relay/bind.go +++ b/handler/relay/bind.go @@ -2,6 +2,8 @@ package relay import ( "context" + "crypto/md5" + "encoding/hex" "fmt" "net" "time" @@ -181,7 +183,7 @@ func (h *relayHandler) bindUDP(ctx context.Context, conn net.Conn, network, addr return nil } -func (h *relayHandler) handleBindTunnel(ctx context.Context, conn net.Conn, network string, tunnelID relay.TunnelID, log logger.Logger) (err error) { +func (h *relayHandler) handleBindTunnel(ctx context.Context, conn net.Conn, network, address string, tunnelID relay.TunnelID, log logger.Logger) (err error) { resp := relay.Response{ Version: relay.Version1, Status: relay.StatusOK, @@ -198,9 +200,11 @@ func (h *relayHandler) handleBindTunnel(ctx context.Context, conn net.Conn, netw connectorID = relay.NewUDPConnectorID(uuid[:]) } - addr := ":0" - if h.ep != nil { - addr = h.ep.Addr().String() + addr := address + if host, port, _ := net.SplitHostPort(addr); host == "" { + v := md5.Sum([]byte(tunnelID.String())) + host = hex.EncodeToString(v[:8]) + addr = net.JoinHostPort(host, port) } af := &relay.AddrFeature{} err = af.ParseFrom(addr) @@ -221,10 +225,11 @@ func (h *relayHandler) handleBindTunnel(ctx context.Context, conn net.Conn, netw } h.pool.Add(tunnelID, NewConnector(connectorID, session)) - log.Debugf("tunnel %s connector %s/%s established", tunnelID, connectorID, network) - if h.recorder.Recorder != nil { - h.recorder.Recorder.Record(ctx, tunnelID[:]) + if h.md.ingress != nil { + h.md.ingress.Set(ctx, addr, tunnelID.String()) } + log.Debugf("%s/%s: tunnel=%s, connector=%s established", address, network, tunnelID, connectorID) + return } diff --git a/handler/relay/connect.go b/handler/relay/connect.go index 66daf3f..0936dac 100644 --- a/handler/relay/connect.go +++ b/handler/relay/connect.go @@ -143,15 +143,12 @@ func (h *relayHandler) handleConnectTunnel(ctx context.Context, conn net.Conn, n tid = parseTunnelID(ingress.Get(ctx, host)) } - // client is not an public entrypoint. - if h.md.entryPointID.IsZero() || !tunnelID.Equal(h.md.entryPointID) { - if !tid.Equal(tunnelID) && !h.md.directTunnel { - resp.Status = relay.StatusHostUnreachable - resp.WriteTo(conn) - err := fmt.Errorf("no route to host %s", host) - log.Error(err) - return err - } + if !tid.Equal(tunnelID) && !h.md.directTunnel { + resp.Status = relay.StatusHostUnreachable + resp.WriteTo(conn) + err := fmt.Errorf("no route to host %s", host) + log.Error(err) + return err } cc, _, err := getTunnelConn(network, h.pool, tid, 3, log) diff --git a/handler/relay/handler.go b/handler/relay/handler.go index e30568b..c42e6e5 100644 --- a/handler/relay/handler.go +++ b/handler/relay/handler.go @@ -13,12 +13,10 @@ import ( "github.com/go-gost/core/hop" "github.com/go-gost/core/listener" md "github.com/go-gost/core/metadata" - "github.com/go-gost/core/recorder" "github.com/go-gost/core/service" "github.com/go-gost/relay" xnet "github.com/go-gost/x/internal/net" auth_util "github.com/go-gost/x/internal/util/auth" - xrecorder "github.com/go-gost/x/recorder" "github.com/go-gost/x/registry" xservice "github.com/go-gost/x/service" ) @@ -35,13 +33,12 @@ func init() { } type relayHandler struct { - hop hop.Hop - router *chain.Router - md metadata - options handler.Options - ep service.Service - pool *ConnectorPool - recorder recorder.RecorderObject + hop hop.Hop + router *chain.Router + md metadata + options handler.Options + ep service.Service + pool *ConnectorPool } func NewHandler(opts ...handler.Option) handler.Handler { @@ -66,15 +63,6 @@ func (h *relayHandler) Init(md md.Metadata) (err error) { h.router = chain.NewRouter(chain.LoggerRouterOption(h.options.Logger)) } - if opts := h.router.Options(); opts != nil { - for _, ro := range opts.Recorders { - if ro.Record == xrecorder.RecorderServiceHandlerRelayTunnelEndpoint { - h.recorder = ro - break - } - } - } - if err = h.initEntryPoint(); err != nil { return } @@ -248,7 +236,7 @@ func (h *relayHandler) Handle(ctx context.Context, conn net.Conn, opts ...handle return h.handleConnect(ctx, conn, network, address, log) case relay.CmdBind: if !tunnelID.IsZero() { - return h.handleBindTunnel(ctx, conn, network, tunnelID, log) + return h.handleBindTunnel(ctx, conn, network, address, tunnelID, log) } defer conn.Close() diff --git a/handler/relay/metadata.go b/handler/relay/metadata.go index 25b1817..0267b54 100644 --- a/handler/relay/metadata.go +++ b/handler/relay/metadata.go @@ -9,7 +9,6 @@ import ( "github.com/go-gost/core/logger" mdata "github.com/go-gost/core/metadata" mdutil "github.com/go-gost/core/metadata/util" - "github.com/go-gost/relay" xingress "github.com/go-gost/x/ingress" "github.com/go-gost/x/registry" ) @@ -22,7 +21,6 @@ type metadata struct { hash string directTunnel bool entryPoint string - entryPointID relay.TunnelID entryPointProxyProtocol int ingress ingress.Ingress } @@ -53,7 +51,6 @@ func (h *relayHandler) parseMetadata(md mdata.Metadata) (err error) { h.md.directTunnel = mdutil.GetBool(md, "tunnel.direct") h.md.entryPoint = mdutil.GetString(md, entryPoint) - h.md.entryPointID = parseTunnelID(mdutil.GetString(md, entryPointID)) h.md.entryPointProxyProtocol = mdutil.GetInt(md, entryPointProxyProtocol) h.md.ingress = registry.IngressRegistry().Get(mdutil.GetString(md, "ingress")) diff --git a/handler/tunnel/bind.go b/handler/tunnel/bind.go new file mode 100644 index 0000000..d9de83a --- /dev/null +++ b/handler/tunnel/bind.go @@ -0,0 +1,67 @@ +package tunnel + +import ( + "context" + "crypto/md5" + "encoding/hex" + "net" + + "github.com/go-gost/core/logger" + "github.com/go-gost/relay" + "github.com/go-gost/x/internal/util/mux" + "github.com/google/uuid" +) + +func (h *tunnelHandler) handleBind(ctx context.Context, conn net.Conn, network, address string, tunnelID relay.TunnelID, log logger.Logger) (err error) { + resp := relay.Response{ + Version: relay.Version1, + Status: relay.StatusOK, + } + + uuid, err := uuid.NewRandom() + if err != nil { + resp.Status = relay.StatusInternalServerError + resp.WriteTo(conn) + return + } + connectorID := relay.NewConnectorID(uuid[:]) + if network == "udp" { + connectorID = relay.NewUDPConnectorID(uuid[:]) + } + + addr := address + if host, port, _ := net.SplitHostPort(addr); host == "" { + v := md5.Sum([]byte(tunnelID.String())) + host = hex.EncodeToString(v[:8]) + addr = net.JoinHostPort(host, port) + } + af := &relay.AddrFeature{} + err = af.ParseFrom(addr) + if err != nil { + log.Warn(err) + } + resp.Features = append(resp.Features, af, + &relay.TunnelFeature{ + ID: connectorID.ID(), + }, + ) + resp.WriteTo(conn) + + // Upgrade connection to multiplex session. + session, err := mux.ClientSession(conn) + if err != nil { + return + } + + h.pool.Add(tunnelID, NewConnector(connectorID, session)) + if h.md.ingress != nil { + h.md.ingress.Set(ctx, addr, tunnelID.String()) + } + if h.recorder.Recorder != nil { + h.recorder.Recorder.Record(ctx, tunnelID[:]) + } + + log.Debugf("%s/%s: tunnel=%s, connector=%s established", addr, network, tunnelID, connectorID) + + return +} diff --git a/handler/tunnel/conn.go b/handler/tunnel/conn.go new file mode 100644 index 0000000..91b7666 --- /dev/null +++ b/handler/tunnel/conn.go @@ -0,0 +1,26 @@ +package tunnel + +import ( + "bytes" + "net" +) + +type tcpConn struct { + net.Conn + wbuf bytes.Buffer +} + +func (c *tcpConn) Read(b []byte) (n int, err error) { + return c.Conn.Read(b) +} + +func (c *tcpConn) Write(b []byte) (n int, err error) { + n = len(b) // force byte length consistent + if c.wbuf.Len() > 0 { + c.wbuf.Write(b) // append the data to the cached header + _, err = c.wbuf.WriteTo(c.Conn) + return + } + _, err = c.Conn.Write(b) + return +} diff --git a/handler/tunnel/connect.go b/handler/tunnel/connect.go new file mode 100644 index 0000000..edcffe2 --- /dev/null +++ b/handler/tunnel/connect.go @@ -0,0 +1,105 @@ +package tunnel + +import ( + "context" + "fmt" + "net" + "strconv" + "time" + + "github.com/go-gost/core/logger" + "github.com/go-gost/relay" + xnet "github.com/go-gost/x/internal/net" +) + +func (h *tunnelHandler) handleConnect(ctx context.Context, conn net.Conn, network, address string, tunnelID relay.TunnelID, log logger.Logger) error { + log = log.WithFields(map[string]any{ + "dst": fmt.Sprintf("%s/%s", address, network), + "cmd": "connect", + "tunnel": tunnelID.String(), + }) + + log.Debugf("%s >> %s/%s", conn.RemoteAddr(), address, network) + + resp := relay.Response{ + Version: relay.Version1, + Status: relay.StatusOK, + } + + host, sp, _ := net.SplitHostPort(address) + + if h.options.Bypass != nil && h.options.Bypass.Contains(ctx, network, address) { + log.Debug("bypass: ", address) + resp.Status = relay.StatusForbidden + _, err := resp.WriteTo(conn) + return err + } + + var tid relay.TunnelID + if ingress := h.md.ingress; ingress != nil { + tid = parseTunnelID(ingress.Get(ctx, host)) + } + + // client is not an public entrypoint. + if h.md.entryPointID.IsZero() || !tunnelID.Equal(h.md.entryPointID) { + if !tid.Equal(tunnelID) && !h.md.directTunnel { + resp.Status = relay.StatusHostUnreachable + resp.WriteTo(conn) + err := fmt.Errorf("no route to host %s", host) + log.Error(err) + return err + } + } + + cc, _, err := getTunnelConn(network, h.pool, tid, 3, log) + if err != nil { + resp.Status = relay.StatusServiceUnavailable + resp.WriteTo(conn) + log.Error(err) + return err + } + defer cc.Close() + + log.Debugf("%s >> %s", conn.RemoteAddr(), cc.RemoteAddr()) + + rc := &tcpConn{ + Conn: conn, + } + // cache the header + if _, err := resp.WriteTo(&rc.wbuf); err != nil { + return err + } + conn = rc + + var features []relay.Feature + af := &relay.AddrFeature{} // source/visitor address + af.ParseFrom(conn.RemoteAddr().String()) + features = append(features, af) + + if host != "" { + port, _ := strconv.Atoi(sp) + // target host + af = &relay.AddrFeature{ + AType: relay.AddrDomain, + Host: host, + Port: uint16(port), + } + features = append(features, af) + } + + resp = relay.Response{ + Version: relay.Version1, + Status: relay.StatusOK, + Features: features, + } + resp.WriteTo(cc) + + t := time.Now() + log.Debugf("%s <-> %s", conn.RemoteAddr(), cc.RemoteAddr()) + xnet.Transport(conn, cc) + log.WithFields(map[string]any{ + "duration": time.Since(t), + }).Debugf("%s >-< %s", conn.RemoteAddr(), cc.RemoteAddr()) + + return nil +} diff --git a/handler/tunnel/handler.go b/handler/tunnel/handler.go new file mode 100644 index 0000000..a8681a9 --- /dev/null +++ b/handler/tunnel/handler.go @@ -0,0 +1,198 @@ +package tunnel + +import ( + "context" + "errors" + "net" + "strconv" + "time" + + "github.com/go-gost/core/chain" + "github.com/go-gost/core/handler" + md "github.com/go-gost/core/metadata" + "github.com/go-gost/core/recorder" + "github.com/go-gost/relay" + auth_util "github.com/go-gost/x/internal/util/auth" + xrecorder "github.com/go-gost/x/recorder" + "github.com/go-gost/x/registry" +) + +var ( + ErrBadVersion = errors.New("relay: bad version") + ErrUnknownCmd = errors.New("relay: unknown command") + ErrTunnelID = errors.New("tunnel: invalid tunnel ID") + ErrUnauthorized = errors.New("relay: unauthorized") + ErrRateLimit = errors.New("relay: rate limiting exceeded") +) + +func init() { + registry.HandlerRegistry().Register("tunnel", NewHandler) +} + +type tunnelHandler struct { + router *chain.Router + md metadata + options handler.Options + pool *ConnectorPool + recorder recorder.RecorderObject +} + +func NewHandler(opts ...handler.Option) handler.Handler { + options := handler.Options{} + for _, opt := range opts { + opt(&options) + } + + return &tunnelHandler{ + options: options, + pool: NewConnectorPool(), + } +} + +func (h *tunnelHandler) Init(md md.Metadata) (err error) { + if err := h.parseMetadata(md); err != nil { + return err + } + + h.router = h.options.Router + if h.router == nil { + h.router = chain.NewRouter(chain.LoggerRouterOption(h.options.Logger)) + } + + if opts := h.router.Options(); opts != nil { + for _, ro := range opts.Recorders { + if ro.Record == xrecorder.RecorderServiceHandlerTunnelEndpoint { + h.recorder = ro + break + } + } + } + + return nil +} + +func (h *tunnelHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.HandleOption) (err error) { + start := time.Now() + log := h.options.Logger.WithFields(map[string]any{ + "remote": conn.RemoteAddr().String(), + "local": conn.LocalAddr().String(), + }) + + log.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr()) + + defer func() { + if err != nil { + conn.Close() + } + log.WithFields(map[string]any{ + "duration": time.Since(start), + }).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr()) + }() + + ctx = auth_util.ContextWithClientAddr(ctx, auth_util.ClientAddr(conn.RemoteAddr().String())) + + if !h.checkRateLimit(conn.RemoteAddr()) { + return ErrRateLimit + } + + if h.md.readTimeout > 0 { + conn.SetReadDeadline(time.Now().Add(h.md.readTimeout)) + } + + req := relay.Request{} + if _, err := req.ReadFrom(conn); err != nil { + return err + } + + conn.SetReadDeadline(time.Time{}) + + resp := relay.Response{ + Version: relay.Version1, + Status: relay.StatusOK, + } + + if req.Version != relay.Version1 { + resp.Status = relay.StatusBadRequest + resp.WriteTo(conn) + return ErrBadVersion + } + + var user, pass string + var address string + var networkID relay.NetworkID + var tunnelID relay.TunnelID + for _, f := range req.Features { + switch f.Type() { + case relay.FeatureUserAuth: + if feature, _ := f.(*relay.UserAuthFeature); feature != nil { + user, pass = feature.Username, feature.Password + } + case relay.FeatureAddr: + if feature, _ := f.(*relay.AddrFeature); feature != nil { + address = net.JoinHostPort(feature.Host, strconv.Itoa(int(feature.Port))) + } + case relay.FeatureTunnel: + if feature, _ := f.(*relay.TunnelFeature); feature != nil { + tunnelID = relay.NewTunnelID(feature.ID[:]) + } + case relay.FeatureNetwork: + if feature, _ := f.(*relay.NetworkFeature); feature != nil { + networkID = feature.Network + } + } + } + + if tunnelID.IsZero() { + resp.Status = relay.StatusBadRequest + resp.WriteTo(conn) + return ErrTunnelID + } + + if user != "" { + log = log.WithFields(map[string]any{"user": user}) + } + + if h.options.Auther != nil { + id, ok := h.options.Auther.Authenticate(ctx, user, pass) + if !ok { + resp.Status = relay.StatusUnauthorized + resp.WriteTo(conn) + return ErrUnauthorized + } + ctx = auth_util.ContextWithID(ctx, auth_util.ID(id)) + } + + network := networkID.String() + if (req.Cmd & relay.FUDP) == relay.FUDP { + network = "udp" + } + + switch req.Cmd & relay.CmdMask { + case relay.CmdConnect: + defer conn.Close() + return h.handleConnect(ctx, conn, network, address, tunnelID, log) + case relay.CmdBind: + return h.handleBind(ctx, conn, network, address, tunnelID, log) + default: + resp.Status = relay.StatusBadRequest + resp.WriteTo(conn) + return ErrUnknownCmd + } +} + +// Close implements io.Closer interface. +func (h *tunnelHandler) Close() error { + return nil +} + +func (h *tunnelHandler) checkRateLimit(addr net.Addr) bool { + if h.options.RateLimiter == nil { + return true + } + host, _, _ := net.SplitHostPort(addr.String()) + if limiter := h.options.RateLimiter.Limiter(host); limiter != nil { + return limiter.Allow(1) + } + + return true +} diff --git a/handler/tunnel/metadata.go b/handler/tunnel/metadata.go new file mode 100644 index 0000000..bd24bdd --- /dev/null +++ b/handler/tunnel/metadata.go @@ -0,0 +1,62 @@ +package tunnel + +import ( + "strings" + "time" + + "github.com/go-gost/core/ingress" + "github.com/go-gost/core/logger" + mdata "github.com/go-gost/core/metadata" + mdutil "github.com/go-gost/core/metadata/util" + "github.com/go-gost/relay" + xingress "github.com/go-gost/x/ingress" + "github.com/go-gost/x/registry" +) + +type metadata struct { + readTimeout time.Duration + hash string + directTunnel bool + entryPointID relay.TunnelID + ingress ingress.Ingress +} + +func (h *tunnelHandler) parseMetadata(md mdata.Metadata) (err error) { + const ( + readTimeout = "readTimeout" + entryPointID = "entrypoint.id" + hash = "hash" + ) + + h.md.readTimeout = mdutil.GetDuration(md, readTimeout) + + h.md.hash = mdutil.GetString(md, hash) + + h.md.directTunnel = mdutil.GetBool(md, "tunnel.direct") + h.md.entryPointID = parseTunnelID(mdutil.GetString(md, entryPointID)) + + h.md.ingress = registry.IngressRegistry().Get(mdutil.GetString(md, "ingress")) + if h.md.ingress == nil { + var rules []xingress.Rule + for _, s := range strings.Split(mdutil.GetString(md, "tunnel"), ",") { + ss := strings.SplitN(s, ":", 2) + if len(ss) != 2 { + continue + } + rules = append(rules, xingress.Rule{ + Hostname: ss[0], + Endpoint: ss[1], + }) + } + if len(rules) > 0 { + h.md.ingress = xingress.NewIngress( + xingress.RulesOption(rules), + xingress.LoggerOption(logger.Default().WithFields(map[string]any{ + "kind": "ingress", + })), + ) + } + } + + return +} diff --git a/handler/tunnel/tunnel.go b/handler/tunnel/tunnel.go new file mode 100644 index 0000000..e287efa --- /dev/null +++ b/handler/tunnel/tunnel.go @@ -0,0 +1,200 @@ +package tunnel + +import ( + "fmt" + "net" + "sync" + "sync/atomic" + "time" + + "github.com/go-gost/core/logger" + "github.com/go-gost/relay" + "github.com/go-gost/x/internal/util/mux" + "github.com/google/uuid" +) + +type Connector struct { + id relay.ConnectorID + t time.Time + s *mux.Session +} + +func NewConnector(id relay.ConnectorID, s *mux.Session) *Connector { + c := &Connector{ + id: id, + t: time.Now(), + s: s, + } + go c.accept() + return c +} + +func (c *Connector) accept() { + for { + conn, err := c.s.Accept() + if err != nil { + logger.Default().Errorf("connector %s: %v", c.id, err) + c.s.Close() + return + } + conn.Close() + } +} + +func (c *Connector) ID() relay.ConnectorID { + return c.id +} + +func (c *Connector) Session() *mux.Session { + return c.s +} + +type Tunnel struct { + id relay.TunnelID + connectors []*Connector + t time.Time + n uint64 + mu sync.RWMutex +} + +func NewTunnel(id relay.TunnelID) *Tunnel { + t := &Tunnel{ + id: id, + t: time.Now(), + } + go t.clean() + return t +} + +func (t *Tunnel) ID() relay.TunnelID { + return t.id +} + +func (t *Tunnel) AddConnector(c *Connector) { + if c == nil { + return + } + + t.mu.Lock() + defer t.mu.Unlock() + + t.connectors = append(t.connectors, c) +} + +func (t *Tunnel) GetConnector(network string) *Connector { + t.mu.RLock() + defer t.mu.RUnlock() + + var connectors []*Connector + for _, c := range t.connectors { + if network == "udp" && c.id.IsUDP() || + network != "udp" && !c.id.IsUDP() { + connectors = append(connectors, c) + } + } + if len(connectors) == 0 { + return nil + } + n := atomic.AddUint64(&t.n, 1) - 1 + return connectors[n%uint64(len(connectors))] +} + +func (t *Tunnel) clean() { + ticker := time.NewTicker(30 * time.Second) + for range ticker.C { + t.mu.Lock() + var connectors []*Connector + for _, c := range t.connectors { + if c.Session().IsClosed() { + logger.Default().Debugf("remove tunnel %s connector %s", t.id, c.id) + continue + } + connectors = append(connectors, c) + } + if len(connectors) != len(t.connectors) { + t.connectors = connectors + } + t.mu.Unlock() + } +} + +type ConnectorPool struct { + tunnels map[string]*Tunnel + mu sync.RWMutex +} + +func NewConnectorPool() *ConnectorPool { + return &ConnectorPool{ + tunnels: make(map[string]*Tunnel), + } +} + +func (p *ConnectorPool) Add(tid relay.TunnelID, c *Connector) { + p.mu.Lock() + defer p.mu.Unlock() + + s := tid.String() + + t := p.tunnels[s] + if t == nil { + t = NewTunnel(tid) + p.tunnels[s] = t + } + t.AddConnector(c) +} + +func (p *ConnectorPool) Get(network string, tid relay.TunnelID) *Connector { + if p == nil { + return nil + } + + p.mu.RLock() + defer p.mu.RUnlock() + + t := p.tunnels[tid.String()] + if t == nil { + return nil + } + + return t.GetConnector(network) +} + +func parseTunnelID(s string) (tid relay.TunnelID) { + if s == "" { + return + } + private := false + if s[0] == '$' { + private = true + s = s[1:] + } + uuid, _ := uuid.Parse(s) + + if private { + return relay.NewPrivateTunnelID(uuid[:]) + } + return relay.NewTunnelID(uuid[:]) +} + +func getTunnelConn(network string, pool *ConnectorPool, tid relay.TunnelID, retry int, log logger.Logger) (conn net.Conn, cid relay.ConnectorID, err error) { + if retry <= 0 { + retry = 1 + } + for i := 0; i < retry; i++ { + c := pool.Get(network, tid) + if c == nil { + err = fmt.Errorf("tunnel %s not available", tid.String()) + break + } + + conn, err = c.Session().GetConn() + if err != nil { + log.Error(err) + continue + } + cid = c.id + break + } + + return +} diff --git a/ingress/ingress.go b/ingress/ingress.go index 49f4903..8cb421e 100644 --- a/ingress/ingress.go +++ b/ingress/ingress.go @@ -259,6 +259,10 @@ func (ing *localIngress) Get(ctx context.Context, host string) string { return ep } +func (ing *localIngress) Set(ctx context.Context, host, endpoint string) { + +} + func (ing *localIngress) lookup(host string) string { if ing == nil || len(ing.rules) == 0 { return "" diff --git a/ingress/plugin.go b/ingress/plugin.go index 6d0b6d5..406b72c 100644 --- a/ingress/plugin.go +++ b/ingress/plugin.go @@ -62,6 +62,17 @@ func (p *grpcPlugin) Get(ctx context.Context, host string) string { return r.GetEndpoint() } +func (p *grpcPlugin) Set(ctx context.Context, host, endpoint string) { + if p.client == nil { + return + } + + p.client.Set(ctx, &proto.SetRequest{ + Host: host, + Endpoint: endpoint, + }) +} + func (p *grpcPlugin) Close() error { if closer, ok := p.conn.(io.Closer); ok { return closer.Close() @@ -69,14 +80,22 @@ func (p *grpcPlugin) Close() error { return nil } -type httpPluginRequest struct { +type httpPluginGetRequest struct { Host string `json:"host"` } -type httpPluginResponse struct { +type httpPluginGetResponse struct { Endpoint string `json:"endpoint"` } +type httpPluginSetRequest struct { + Host string `json:"host"` + Endpoint string `json:"endpoint"` +} + +type httpPluginSetResponse struct { +} + type httpPlugin struct { url string client *http.Client @@ -107,7 +126,7 @@ func (p *httpPlugin) Get(ctx context.Context, host string) (endpoint string) { return } - rb := httpPluginRequest{ + rb := httpPluginGetRequest{ Host: host, } v, err := json.Marshal(&rb) @@ -134,9 +153,48 @@ func (p *httpPlugin) Get(ctx context.Context, host string) (endpoint string) { return } - res := httpPluginResponse{} + res := httpPluginGetResponse{} if err := json.NewDecoder(resp.Body).Decode(&res); err != nil { return } return res.Endpoint } + +func (p *httpPlugin) Set(ctx context.Context, host, endpoint string) { + if p.client == nil { + return + } + + rb := httpPluginSetRequest{ + Host: host, + Endpoint: endpoint, + } + v, err := json.Marshal(&rb) + if err != nil { + return + } + + req, err := http.NewRequest(http.MethodPut, p.url, bytes.NewReader(v)) + if err != nil { + return + } + + if p.header != nil { + req.Header = p.header.Clone() + } + req.Header.Set("Content-Type", "application/json") + resp, err := p.client.Do(req) + if err != nil { + return + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return + } + + res := httpPluginSetResponse{} + if err := json.NewDecoder(resp.Body).Decode(&res); err != nil { + return + } +} diff --git a/listener/rtcp/listener.go b/listener/rtcp/listener.go index 340e25e..bbbc7cf 100644 --- a/listener/rtcp/listener.go +++ b/listener/rtcp/listener.go @@ -50,12 +50,13 @@ func (l *rtcpListener) Init(md md.Metadata) (err error) { if xnet.IsIPv4(l.options.Addr) { network = "tcp4" } - laddr, err := net.ResolveTCPAddr(network, l.options.Addr) - if err != nil { - return + if laddr, _ := net.ResolveTCPAddr(network, l.options.Addr); laddr != nil { + l.laddr = laddr + } + if l.laddr == nil { + l.laddr = &bindAddr{addr: l.options.Addr} } - l.laddr = laddr l.router = chain.NewRouter( chain.ChainRouterOption(l.options.Chain), chain.LoggerRouterOption(l.logger), @@ -110,3 +111,15 @@ func (l *rtcpListener) Close() error { return nil } + +type bindAddr struct { + addr string +} + +func (p *bindAddr) Network() string { + return "tcp" +} + +func (p *bindAddr) String() string { + return p.addr +} diff --git a/recorder/recorder.go b/recorder/recorder.go index dc7348a..857a608 100644 --- a/recorder/recorder.go +++ b/recorder/recorder.go @@ -1,6 +1,6 @@ package recorder const ( - RecorderServiceHandlerSerial = "recorder.service.handler.serial" - RecorderServiceHandlerRelayTunnelEndpoint = "recorder.service.handler.relay.tunnel.endpoint" + RecorderServiceHandlerSerial = "recorder.service.handler.serial" + RecorderServiceHandlerTunnelEndpoint = "recorder.service.handler.tunnel.endpoint" ) diff --git a/registry/ingress.go b/registry/ingress.go index 358c579..72afd90 100644 --- a/registry/ingress.go +++ b/registry/ingress.go @@ -37,3 +37,12 @@ func (w *ingressWrapper) Get(ctx context.Context, host string) string { } return v.Get(ctx, host) } + +func (w *ingressWrapper) Set(ctx context.Context, host, endpoint string) { + v := w.r.get(w.name) + if v == nil { + return + } + + v.Set(ctx, host, endpoint) +}