From 68edeb2d59e8b3a5a8e83830abcb4f676ef4f7f2 Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Thu, 2 Nov 2023 20:52:33 +0800 Subject: [PATCH] update sd --- dialer/mws/metadata.go | 32 ++++-------- dialer/ws/metadata.go | 31 ++++------- go.mod | 4 +- go.sum | 19 ++----- handler/tunnel/bind.go | 12 +++-- handler/tunnel/tunnel.go | 13 +++-- internal/util/ws/ws.go | 9 ++++ listener/mws/metadata.go | 28 +++------- listener/ws/metadata.go | 29 +++-------- registry/sd.go | 12 ++--- sd/plugin.go | 110 ++++++++++++++++++--------------------- 11 files changed, 126 insertions(+), 173 deletions(-) diff --git a/dialer/mws/metadata.go b/dialer/mws/metadata.go index 7dc2c29..aaea799 100644 --- a/dialer/mws/metadata.go +++ b/dialer/mws/metadata.go @@ -30,22 +30,8 @@ type metadata struct { } func (d *mwsDialer) parseMetadata(md mdata.Metadata) (err error) { - const ( - host = "host" - path = "path" - - handshakeTimeout = "handshakeTimeout" - readHeaderTimeout = "readHeaderTimeout" - readBufferSize = "readBufferSize" - writeBufferSize = "writeBufferSize" - enableCompression = "enableCompression" - - header = "header" - ) - - d.md.host = mdutil.GetString(md, host) - - d.md.path = mdutil.GetString(md, path) + d.md.host = mdutil.GetString(md, "ws.host", "host") + d.md.path = mdutil.GetString(md, "ws.path", "path") if d.md.path == "" { d.md.path = defaultPath } @@ -60,13 +46,13 @@ func (d *mwsDialer) parseMetadata(md mdata.Metadata) (err error) { MaxStreamBuffer: mdutil.GetInt(md, "mux.maxStreamBuffer"), } - d.md.handshakeTimeout = mdutil.GetDuration(md, handshakeTimeout) - d.md.readHeaderTimeout = mdutil.GetDuration(md, readHeaderTimeout) - d.md.readBufferSize = mdutil.GetInt(md, readBufferSize) - d.md.writeBufferSize = mdutil.GetInt(md, writeBufferSize) - d.md.enableCompression = mdutil.GetBool(md, enableCompression) + d.md.handshakeTimeout = mdutil.GetDuration(md, "ws.handshakeTimeout", "handshakeTimeout") + d.md.readHeaderTimeout = mdutil.GetDuration(md, "ws.readHeaderTimeout", "readHeaderTimeout") + d.md.readBufferSize = mdutil.GetInt(md, "ws.readBufferSize", "readBufferSize") + d.md.writeBufferSize = mdutil.GetInt(md, "ws.writeBufferSize", "writeBufferSize") + d.md.enableCompression = mdutil.GetBool(md, "ws.enableCompression", "enableCompression") - if m := mdutil.GetStringMapString(md, header); len(m) > 0 { + if m := mdutil.GetStringMapString(md, "ws.header", "header"); len(m) > 0 { h := http.Header{} for k, v := range m { h.Add(k, v) @@ -74,7 +60,7 @@ func (d *mwsDialer) parseMetadata(md mdata.Metadata) (err error) { d.md.header = h } - if mdutil.GetBool(md, "keepalive") { + if mdutil.GetBool(md, "ws.keepalive", "keepalive") { d.md.keepaliveInterval = mdutil.GetDuration(md, "ttl", "keepalive.interval") if d.md.keepaliveInterval <= 0 { d.md.keepaliveInterval = defaultKeepalivePeriod diff --git a/dialer/ws/metadata.go b/dialer/ws/metadata.go index 2f23eac..fac5995 100644 --- a/dialer/ws/metadata.go +++ b/dialer/ws/metadata.go @@ -28,33 +28,20 @@ type metadata struct { } func (d *wsDialer) parseMetadata(md mdata.Metadata) (err error) { - const ( - host = "host" - path = "path" + d.md.host = mdutil.GetString(md, "ws.host", "host") - handshakeTimeout = "handshakeTimeout" - readHeaderTimeout = "readHeaderTimeout" - readBufferSize = "readBufferSize" - writeBufferSize = "writeBufferSize" - enableCompression = "enableCompression" - - header = "header" - ) - - d.md.host = mdutil.GetString(md, host) - - d.md.path = mdutil.GetString(md, path) + d.md.path = mdutil.GetString(md, "ws.path", "path") if d.md.path == "" { d.md.path = defaultPath } - d.md.handshakeTimeout = mdutil.GetDuration(md, handshakeTimeout) - d.md.readHeaderTimeout = mdutil.GetDuration(md, readHeaderTimeout) - d.md.readBufferSize = mdutil.GetInt(md, readBufferSize) - d.md.writeBufferSize = mdutil.GetInt(md, writeBufferSize) - d.md.enableCompression = mdutil.GetBool(md, enableCompression) + d.md.handshakeTimeout = mdutil.GetDuration(md, "ws.handshakeTimeout", "handshakeTimeout") + d.md.readHeaderTimeout = mdutil.GetDuration(md, "ws.readHeaderTimeout", "readHeaderTimeout") + d.md.readBufferSize = mdutil.GetInt(md, "ws.readBufferSize", "readBufferSize") + d.md.writeBufferSize = mdutil.GetInt(md, "ws.writeBufferSize", "writeBufferSize") + d.md.enableCompression = mdutil.GetBool(md, "ws.enableCompression", "enableCompression") - if m := mdutil.GetStringMapString(md, header); len(m) > 0 { + if m := mdutil.GetStringMapString(md, "ws.header", "header"); len(m) > 0 { h := http.Header{} for k, v := range m { h.Add(k, v) @@ -62,7 +49,7 @@ func (d *wsDialer) parseMetadata(md mdata.Metadata) (err error) { d.md.header = h } - if mdutil.GetBool(md, "keepalive") { + if mdutil.GetBool(md, "ws.keepalive", "keepalive") { d.md.keepaliveInterval = mdutil.GetDuration(md, "ttl", "keepalive.interval") if d.md.keepaliveInterval <= 0 { d.md.keepaliveInterval = defaultKeepalivePeriod diff --git a/go.mod b/go.mod index b94d6bd..870dc5b 100644 --- a/go.mod +++ b/go.mod @@ -7,10 +7,10 @@ require ( github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d github.com/gin-contrib/cors v1.3.1 github.com/gin-gonic/gin v1.9.1 - github.com/go-gost/core v0.0.0-20231031145651-8835e0e647f9 + github.com/go-gost/core v0.0.0-20231102125025-55d7b2e3129e github.com/go-gost/gosocks4 v0.0.1 github.com/go-gost/gosocks5 v0.4.0 - github.com/go-gost/plugin v0.0.0-20231031145754-4c25027b8b97 + github.com/go-gost/plugin v0.0.0-20231102125124-a1cc7a13e066 github.com/go-gost/relay v0.4.1-0.20230916134211-828f314ddfe7 github.com/go-gost/tls-dissector v0.0.2-0.20220408131628-aac992c27451 github.com/go-redis/redis/v8 v8.11.5 diff --git a/go.sum b/go.sum index ca17d66..5a6756f 100644 --- a/go.sum +++ b/go.sum @@ -91,18 +91,14 @@ github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SU github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= -github.com/go-gost/core v0.0.0-20231027140845-d975ec3c7477 h1:a49XfrB4mgbw7z7oN/WTovx0X7SbxdfoANsEDTy9CqI= -github.com/go-gost/core v0.0.0-20231027140845-d975ec3c7477/go.mod h1:ndkgWVYRLwupVaFFWv8ML1Nr8tD3xhHK245PLpUDg4E= -github.com/go-gost/core v0.0.0-20231031145651-8835e0e647f9 h1:Zab4WYWl/GyhfjkoZ2JqauQlRwLGzsxs8/tHxctYlv4= -github.com/go-gost/core v0.0.0-20231031145651-8835e0e647f9/go.mod h1:ndkgWVYRLwupVaFFWv8ML1Nr8tD3xhHK245PLpUDg4E= +github.com/go-gost/core v0.0.0-20231102125025-55d7b2e3129e h1:rOlfeBOv+1vDMFuS6hgWoD9qpQeDzhdsoiA9v5GEw6c= +github.com/go-gost/core v0.0.0-20231102125025-55d7b2e3129e/go.mod h1:ndkgWVYRLwupVaFFWv8ML1Nr8tD3xhHK245PLpUDg4E= github.com/go-gost/gosocks4 v0.0.1 h1:+k1sec8HlELuQV7rWftIkmy8UijzUt2I6t+iMPlGB2s= github.com/go-gost/gosocks4 v0.0.1/go.mod h1:3B6L47HbU/qugDg4JnoFPHgJXE43Inz8Bah1QaN9qCc= github.com/go-gost/gosocks5 v0.4.0 h1:EIrOEkpJez4gwHrMa33frA+hHXJyevjp47thpMQsJzI= github.com/go-gost/gosocks5 v0.4.0/go.mod h1:1G6I7HP7VFVxveGkoK8mnprnJqSqJjdcASKsdUn4Pp4= -github.com/go-gost/plugin v0.0.0-20231027141311-2cfb0a14b451 h1:sgg7LyK4ZAuQkBfaQxyFpH+xyAfrczDFDtkdRAcUxCE= -github.com/go-gost/plugin v0.0.0-20231027141311-2cfb0a14b451/go.mod h1:mM/RLNsVy2nz5PiOijuqLYR3LhMzyQ9Kh/p0rXybJoo= -github.com/go-gost/plugin v0.0.0-20231031145754-4c25027b8b97 h1:p9dmeWsNwKcbIwwUUumD5a7HlZFODBwnMItBGuJ+P5M= -github.com/go-gost/plugin v0.0.0-20231031145754-4c25027b8b97/go.mod h1:qXr2Zm9Ex2ATqnWuNUzVZqySPMnuIihvblYZt4MlZLw= +github.com/go-gost/plugin v0.0.0-20231102125124-a1cc7a13e066 h1:/pDM9JP9ESSRuAr237yAXB6WiDdjEeulDkaLa9Gw0ss= +github.com/go-gost/plugin v0.0.0-20231102125124-a1cc7a13e066/go.mod h1:qXr2Zm9Ex2ATqnWuNUzVZqySPMnuIihvblYZt4MlZLw= github.com/go-gost/relay v0.4.1-0.20230916134211-828f314ddfe7 h1:qAG1OyjvdA5h221CfFSS3J359V3d2E7dJWyP29QoDSI= github.com/go-gost/relay v0.4.1-0.20230916134211-828f314ddfe7/go.mod h1:lcX+23LCQ3khIeASBo+tJ/WbwXFO32/N5YN6ucuYTG8= github.com/go-gost/tls-dissector v0.0.2-0.20220408131628-aac992c27451 h1:xj8gUZGYO3nb5+6Bjw9+tsFkA9sYynrOvDvvC4uDV2I= @@ -193,9 +189,6 @@ github.com/google/pprof v0.0.0-20230912144702-c363fe2c2ed8 h1:gpptm606MZYGaMHMsB github.com/google/pprof v0.0.0-20230912144702-c363fe2c2ed8/go.mod h1:czg5+yv1E0ZGTi6S6vVK1mke0fV+FaUhNGcd6VRS9Ik= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= -github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/google/uuid v1.3.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.4.0 h1:MtMxsa51/r9yyhkyLsVeVt0B+BGQZzpQiTQ4eHZ8bc4= github.com/google/uuid v1.4.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= @@ -690,8 +683,6 @@ google.golang.org/genproto v0.0.0-20201210142538-e3217bee35cc/go.mod h1:FWY/as6D google.golang.org/genproto v0.0.0-20201214200347-8c77b98c765d/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/genproto v0.0.0-20210108203827-ffc7fda8c3d7/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/genproto v0.0.0-20210226172003-ab064af71705/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= -google.golang.org/genproto/googleapis/rpc v0.0.0-20230711160842-782d3b101e98 h1:bVf09lpb+OJbByTj913DRJioFFAjf/ZGxEz7MajTp2U= -google.golang.org/genproto/googleapis/rpc v0.0.0-20230711160842-782d3b101e98/go.mod h1:TUfxEVdsvPg18p6AslUXFoLdpED4oBnGwyqk3dV1XzM= google.golang.org/genproto/googleapis/rpc v0.0.0-20231016165738-49dd2c1f3d0b h1:ZlWIi1wSK56/8hn4QcBp/j9M7Gt3U/3hZw3mC7vDICo= google.golang.org/genproto/googleapis/rpc v0.0.0-20231016165738-49dd2c1f3d0b/go.mod h1:swOH3j0KzcDDgGUWr+SNpyTen5YrXjS3eyPzFYKc6lc= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= @@ -710,8 +701,6 @@ google.golang.org/grpc v1.31.1/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv2fbc= google.golang.org/grpc v1.34.0/go.mod h1:WotjhfgOW/POjDeRt8vscBtXq+2VjORFy659qA51WJ8= google.golang.org/grpc v1.35.0/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU= -google.golang.org/grpc v1.58.3 h1:BjnpXut1btbtgN/6sp+brB2Kbm2LjNXnidYujAVbSoQ= -google.golang.org/grpc v1.58.3/go.mod h1:tgX3ZQDlNJGU96V6yHh1T/JeoBQ2TXdr43YbYSsCJk0= google.golang.org/grpc v1.59.0 h1:Z5Iec2pjwb+LEOqzpB2MR12/eKFhDPhuqW91O+4bwUk= google.golang.org/grpc v1.59.0/go.mod h1:aUPDwccQo6OTjy7Hct4AfBPD1GptF4fyUjIkQ9YtF98= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= diff --git a/handler/tunnel/bind.go b/handler/tunnel/bind.go index f04eda7..d0ac0a8 100644 --- a/handler/tunnel/bind.go +++ b/handler/tunnel/bind.go @@ -4,10 +4,10 @@ import ( "context" "crypto/md5" "encoding/hex" - "fmt" "net" "github.com/go-gost/core/logger" + "github.com/go-gost/core/sd" "github.com/go-gost/relay" "github.com/go-gost/x/internal/util/mux" "github.com/google/uuid" @@ -58,8 +58,14 @@ func (h *tunnelHandler) handleBind(ctx context.Context, conn net.Conn, network, if h.md.ingress != nil { h.md.ingress.Set(ctx, addr, tunnelID.String()) } - if sd := h.md.sd; sd != nil { - err := sd.Register(ctx, fmt.Sprintf("%s:%s:%s", h.id, tunnelID, connectorID), network, h.md.entryPoint) + if h.md.sd != nil { + err := h.md.sd.Register(ctx, &sd.Service{ + ID: connectorID.String(), + Name: tunnelID.String(), + Node: h.id, + Network: network, + Address: h.md.entryPoint, + }) if err != nil { h.log.Error(err) } diff --git a/handler/tunnel/tunnel.go b/handler/tunnel/tunnel.go index 6103622..e12d1fd 100644 --- a/handler/tunnel/tunnel.go +++ b/handler/tunnel/tunnel.go @@ -2,7 +2,6 @@ package tunnel import ( "context" - "fmt" "sync" "sync/atomic" "time" @@ -144,14 +143,22 @@ func (t *Tunnel) clean() { if c.Session().IsClosed() { logger.Default().Debugf("remove tunnel: %s, connector: %s", t.id, c.id) if t.sd != nil { - t.sd.Deregister(context.Background(), fmt.Sprintf("%s:%s:%s", t.node, t.id, c.id)) + t.sd.Deregister(context.Background(), &sd.Service{ + ID: c.id.String(), + Name: t.id.String(), + Node: t.node, + }) } continue } connectors = append(connectors, c) if t.sd != nil { - t.sd.Renew(context.Background(), fmt.Sprintf("%s:%s:%s", t.node, t.id, c.id)) + t.sd.Renew(context.Background(), &sd.Service{ + ID: c.id.String(), + Name: t.id.String(), + Node: t.node, + }) } } if len(connectors) != len(t.connectors) { diff --git a/internal/util/ws/ws.go b/internal/util/ws/ws.go index da4beaa..0dd6612 100644 --- a/internal/util/ws/ws.go +++ b/internal/util/ws/ws.go @@ -49,8 +49,17 @@ func (c *websocketConn) WriteMessage(messageType int, data []byte) error { } func (c *websocketConn) SetDeadline(t time.Time) error { + c.mux.Lock() + defer c.mux.Unlock() + if err := c.SetReadDeadline(t); err != nil { return err } return c.SetWriteDeadline(t) } + +func (c *websocketConn) SetWriteDeadline(t time.Time) error { + c.mux.Lock() + defer c.mux.Unlock() + return c.Conn.SetWriteDeadline(t) +} diff --git a/listener/mws/metadata.go b/listener/mws/metadata.go index c5d9c57..1761991 100644 --- a/listener/mws/metadata.go +++ b/listener/mws/metadata.go @@ -31,33 +31,21 @@ type metadata struct { } func (l *mwsListener) parseMetadata(md mdata.Metadata) (err error) { - const ( - path = "path" - backlog = "backlog" - header = "header" - - handshakeTimeout = "handshakeTimeout" - readHeaderTimeout = "readHeaderTimeout" - readBufferSize = "readBufferSize" - writeBufferSize = "writeBufferSize" - enableCompression = "enableCompression" - ) - - l.md.path = mdutil.GetString(md, path) + l.md.path = mdutil.GetString(md, "ws.path", "path") if l.md.path == "" { l.md.path = defaultPath } - l.md.backlog = mdutil.GetInt(md, backlog) + l.md.backlog = mdutil.GetInt(md, "ws.backlog", "backlog") if l.md.backlog <= 0 { l.md.backlog = defaultBacklog } - l.md.handshakeTimeout = mdutil.GetDuration(md, handshakeTimeout) - l.md.readHeaderTimeout = mdutil.GetDuration(md, readHeaderTimeout) - l.md.readBufferSize = mdutil.GetInt(md, readBufferSize) - l.md.writeBufferSize = mdutil.GetInt(md, writeBufferSize) - l.md.enableCompression = mdutil.GetBool(md, enableCompression) + l.md.handshakeTimeout = mdutil.GetDuration(md, "ws.handshakeTimeout", "handshakeTimeout") + l.md.readHeaderTimeout = mdutil.GetDuration(md, "ws.readHeaderTimeout", "readHeaderTimeout") + l.md.readBufferSize = mdutil.GetInt(md, "ws.readBufferSize", "readBufferSize") + l.md.writeBufferSize = mdutil.GetInt(md, "ws.writeBufferSize", "writeBufferSize") + l.md.enableCompression = mdutil.GetBool(md, "ws.enableCompression", "enableCompression") l.md.muxCfg = &mux.Config{ Version: mdutil.GetInt(md, "mux.version"), @@ -69,7 +57,7 @@ func (l *mwsListener) parseMetadata(md mdata.Metadata) (err error) { MaxStreamBuffer: mdutil.GetInt(md, "mux.maxStreamBuffer"), } - if mm := mdutil.GetStringMapString(md, header); len(mm) > 0 { + if mm := mdutil.GetStringMapString(md, "ws.header", "header"); len(mm) > 0 { hd := http.Header{} for k, v := range mm { hd.Add(k, v) diff --git a/listener/ws/metadata.go b/listener/ws/metadata.go index 6017606..f10d985 100644 --- a/listener/ws/metadata.go +++ b/listener/ws/metadata.go @@ -28,36 +28,23 @@ type metadata struct { } func (l *wsListener) parseMetadata(md mdata.Metadata) (err error) { - const ( - path = "path" - backlog = "backlog" - - handshakeTimeout = "handshakeTimeout" - readHeaderTimeout = "readHeaderTimeout" - readBufferSize = "readBufferSize" - writeBufferSize = "writeBufferSize" - enableCompression = "enableCompression" - - header = "header" - ) - - l.md.path = mdutil.GetString(md, path) + l.md.path = mdutil.GetString(md, "ws.path", "path") if l.md.path == "" { l.md.path = defaultPath } - l.md.backlog = mdutil.GetInt(md, backlog) + l.md.backlog = mdutil.GetInt(md, "ws.backlog", "backlog") if l.md.backlog <= 0 { l.md.backlog = defaultBacklog } - l.md.handshakeTimeout = mdutil.GetDuration(md, handshakeTimeout) - l.md.readHeaderTimeout = mdutil.GetDuration(md, readHeaderTimeout) - l.md.readBufferSize = mdutil.GetInt(md, readBufferSize) - l.md.writeBufferSize = mdutil.GetInt(md, writeBufferSize) - l.md.enableCompression = mdutil.GetBool(md, enableCompression) + l.md.handshakeTimeout = mdutil.GetDuration(md, "ws.handshakeTimeout", "handshakeTimeout") + l.md.readHeaderTimeout = mdutil.GetDuration(md, "ws.readHeaderTimeout", "readHeaderTimeout") + l.md.readBufferSize = mdutil.GetInt(md, "ws.readBufferSize", "readBufferSize") + l.md.writeBufferSize = mdutil.GetInt(md, "ws.writeBufferSize", "writeBufferSize") + l.md.enableCompression = mdutil.GetBool(md, "ws.enableCompression", "enableCompression") - if mm := mdutil.GetStringMapString(md, header); len(mm) > 0 { + if mm := mdutil.GetStringMapString(md, "ws.header", "header"); len(mm) > 0 { hd := http.Header{} for k, v := range mm { hd.Add(k, v) diff --git a/registry/sd.go b/registry/sd.go index fa73017..775a1e0 100644 --- a/registry/sd.go +++ b/registry/sd.go @@ -30,30 +30,30 @@ type sdWrapper struct { r *sdRegistry } -func (w *sdWrapper) Register(ctx context.Context, name string, network, address string, opts ...sd.Option) error { +func (w *sdWrapper) Register(ctx context.Context, service *sd.Service, opts ...sd.Option) error { v := w.r.get(w.name) if v == nil { return nil } - return v.Register(ctx, name, network, address, opts...) + return v.Register(ctx, service, opts...) } -func (w *sdWrapper) Deregister(ctx context.Context, name string) error { +func (w *sdWrapper) Deregister(ctx context.Context, service *sd.Service) error { v := w.r.get(w.name) if v == nil { return nil } - return v.Deregister(ctx, name) + return v.Deregister(ctx, service) } -func (w *sdWrapper) Renew(ctx context.Context, name string) error { +func (w *sdWrapper) Renew(ctx context.Context, service *sd.Service) error { v := w.r.get(w.name) if v == nil { return nil } - return v.Renew(ctx, name) + return v.Renew(ctx, service) } func (w *sdWrapper) Get(ctx context.Context, name string) ([]*sd.Service, error) { diff --git a/sd/plugin.go b/sd/plugin.go index e66c941..7795c4d 100644 --- a/sd/plugin.go +++ b/sd/plugin.go @@ -47,16 +47,20 @@ func NewGRPCPlugin(name string, addr string, opts ...plugin.Option) sd.SD { return p } -func (p *grpcPlugin) Register(ctx context.Context, name string, network, address string, opts ...sd.Option) error { +func (p *grpcPlugin) Register(ctx context.Context, service *sd.Service, opts ...sd.Option) error { if p.client == nil { return nil } _, err := p.client.Register(ctx, &proto.RegisterRequest{ - Name: name, - Network: network, - Address: address, + Service: &proto.Service{ + Id: service.ID, + Name: service.Name, + Node: service.Node, + Network: service.Network, + Address: service.Address, + }, }) if err != nil { p.log.Error(err) @@ -65,24 +69,36 @@ func (p *grpcPlugin) Register(ctx context.Context, name string, network, address return nil } -func (p *grpcPlugin) Deregister(ctx context.Context, name string) error { +func (p *grpcPlugin) Deregister(ctx context.Context, service *sd.Service) error { if p.client == nil { return nil } _, err := p.client.Deregister(ctx, &proto.DeregisterRequest{ - Name: name, + Service: &proto.Service{ + Id: service.ID, + Name: service.Name, + Node: service.Node, + Network: service.Network, + Address: service.Address, + }, }) return err } -func (p *grpcPlugin) Renew(ctx context.Context, name string) error { +func (p *grpcPlugin) Renew(ctx context.Context, service *sd.Service) error { if p.client == nil { return nil } _, err := p.client.Renew(ctx, &proto.RenewRequest{ - Name: name, + Service: &proto.Service{ + Id: service.ID, + Name: service.Name, + Node: service.Node, + Network: service.Network, + Address: service.Address, + }, }) return err } @@ -121,39 +137,10 @@ func (p *grpcPlugin) Close() error { return nil } -type httpRegisterRequest struct { - Name string `json:"name"` - Network string `json:"network"` - Address string `json:"address"` -} - -type httpRegisterResponse struct { - Ok bool `json:"ok"` -} - -type httpDeregisterRequest struct { - Name string `json:"name"` -} - -type httpDeregisterResponse struct { - Ok bool `json:"ok"` -} - -type httpRenewRequest struct { - Name string `json:"name"` -} - -type httpRenewResponse struct { - Ok bool `json:"ok"` -} - -type httpGetRequest struct { - Name string `json:"name"` -} - type sdService struct { - Node string `json:"node"` + ID string `json:"id"` Name string `json:"name"` + Node string `json:"node"` Network string `json:"network"` Address string `json:"address"` } @@ -187,17 +174,18 @@ func NewHTTPPlugin(name string, url string, opts ...plugin.Option) sd.SD { } } -func (p *httpPlugin) Register(ctx context.Context, name string, network, address string, opts ...sd.Option) error { - if p.client == nil { +func (p *httpPlugin) Register(ctx context.Context, service *sd.Service, opts ...sd.Option) error { + if p.client == nil || service == nil { return nil } - rb := httpRegisterRequest{ - Name: name, - Network: network, - Address: address, - } - v, err := json.Marshal(&rb) + v, err := json.Marshal(sdService{ + ID: service.ID, + Name: service.Name, + Node: service.Node, + Network: service.Network, + Address: service.Address, + }) if err != nil { return err } @@ -224,15 +212,18 @@ func (p *httpPlugin) Register(ctx context.Context, name string, network, address return nil } -func (p *httpPlugin) Deregister(ctx context.Context, name string) error { - if p.client == nil { +func (p *httpPlugin) Deregister(ctx context.Context, service *sd.Service) error { + if p.client == nil || service == nil { return nil } - rb := httpDeregisterRequest{ - Name: name, - } - v, err := json.Marshal(&rb) + v, err := json.Marshal(sdService{ + ID: service.ID, + Name: service.Name, + Node: service.Node, + Network: service.Network, + Address: service.Address, + }) if err != nil { return err } @@ -259,15 +250,18 @@ func (p *httpPlugin) Deregister(ctx context.Context, name string) error { return nil } -func (p *httpPlugin) Renew(ctx context.Context, name string) error { +func (p *httpPlugin) Renew(ctx context.Context, service *sd.Service) error { if p.client == nil { return nil } - rb := httpRenewRequest{ - Name: name, - } - v, err := json.Marshal(&rb) + v, err := json.Marshal(sdService{ + ID: service.ID, + Name: service.Name, + Node: service.Node, + Network: service.Network, + Address: service.Address, + }) if err != nil { return err }