diff --git a/config/config.go b/config/config.go index 7d7055c..4c8258e 100644 --- a/config/config.go +++ b/config/config.go @@ -237,6 +237,11 @@ type IngressConfig struct { Plugin *PluginConfig `yaml:",omitempty" json:"plugin,omitempty"` } +type SDConfig struct { + Name string `json:"name"` + Plugin *PluginConfig `yaml:",omitempty" json:"plugin,omitempty"` +} + type RecorderConfig struct { Name string `json:"name"` File *FileRecorder `yaml:",omitempty" json:"file,omitempty"` @@ -436,6 +441,7 @@ type Config struct { Resolvers []*ResolverConfig `yaml:",omitempty" json:"resolvers,omitempty"` Hosts []*HostsConfig `yaml:",omitempty" json:"hosts,omitempty"` Ingresses []*IngressConfig `yaml:",omitempty" json:"ingresses,omitempty"` + SDs []*SDConfig `yaml:"sds,omitempty" json:"sds,omitempty"` Recorders []*RecorderConfig `yaml:",omitempty" json:"recorders,omitempty"` Limiters []*LimiterConfig `yaml:",omitempty" json:"limiters,omitempty"` CLimiters []*LimiterConfig `yaml:"climiters,omitempty" json:"climiters,omitempty"` diff --git a/config/parsing/sd/parse.go b/config/parsing/sd/parse.go new file mode 100644 index 0000000..dba94f2 --- /dev/null +++ b/config/parsing/sd/parse.go @@ -0,0 +1,39 @@ +package sd + +import ( + "crypto/tls" + "strings" + + "github.com/go-gost/core/sd" + "github.com/go-gost/x/config" + "github.com/go-gost/x/internal/plugin" + xsd "github.com/go-gost/x/sd" +) + +func ParseSD(cfg *config.SDConfig) sd.SD { + if cfg == nil || cfg.Plugin == nil { + return nil + } + + var tlsCfg *tls.Config + if cfg.Plugin.TLS != nil { + tlsCfg = &tls.Config{ + ServerName: cfg.Plugin.TLS.ServerName, + InsecureSkipVerify: !cfg.Plugin.TLS.Secure, + } + } + switch strings.ToLower(cfg.Plugin.Type) { + case "http": + return xsd.NewHTTPPlugin( + cfg.Name, cfg.Plugin.Addr, + plugin.TLSConfigOption(tlsCfg), + plugin.TimeoutOption(cfg.Plugin.Timeout), + ) + default: + return xsd.NewGRPCPlugin( + cfg.Name, cfg.Plugin.Addr, + plugin.TokenOption(cfg.Plugin.Token), + plugin.TLSConfigOption(tlsCfg), + ) + } +} diff --git a/connector/tunnel/bind.go b/connector/tunnel/bind.go index 5a3fad4..9370184 100644 --- a/connector/tunnel/bind.go +++ b/connector/tunnel/bind.go @@ -50,6 +50,9 @@ func (c *tunnelConnector) initTunnel(conn net.Conn, network, address string) (ad if network == "udp" { req.Cmd |= relay.FUDP + req.Features = append(req.Features, &relay.NetworkFeature{ + Network: relay.NetworkUDP, + }) } if c.options.Auth != nil { diff --git a/go.mod b/go.mod index a553961..b94d6bd 100644 --- a/go.mod +++ b/go.mod @@ -7,16 +7,16 @@ require ( github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d github.com/gin-contrib/cors v1.3.1 github.com/gin-gonic/gin v1.9.1 - github.com/go-gost/core v0.0.0-20231027140845-d975ec3c7477 + github.com/go-gost/core v0.0.0-20231031145651-8835e0e647f9 github.com/go-gost/gosocks4 v0.0.1 github.com/go-gost/gosocks5 v0.4.0 - github.com/go-gost/plugin v0.0.0-20231027141311-2cfb0a14b451 + github.com/go-gost/plugin v0.0.0-20231031145754-4c25027b8b97 github.com/go-gost/relay v0.4.1-0.20230916134211-828f314ddfe7 github.com/go-gost/tls-dissector v0.0.2-0.20220408131628-aac992c27451 github.com/go-redis/redis/v8 v8.11.5 github.com/gobwas/glob v0.2.3 github.com/golang/snappy v0.0.4 - github.com/google/uuid v1.3.0 + github.com/google/uuid v1.4.0 github.com/gorilla/websocket v1.5.0 github.com/miekg/dns v1.1.56 github.com/patrickmn/go-cache v2.1.0+incompatible @@ -40,7 +40,7 @@ require ( golang.org/x/sys v0.13.0 golang.org/x/time v0.3.0 golang.zx2c4.com/wireguard v0.0.0-20220703234212-c31a7b1ab478 - google.golang.org/grpc v1.58.3 + google.golang.org/grpc v1.59.0 google.golang.org/protobuf v1.31.0 gopkg.in/yaml.v3 v3.0.1 ) @@ -108,7 +108,7 @@ require ( golang.org/x/text v0.13.0 // indirect golang.org/x/tools v0.13.0 // indirect golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224 // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20230711160842-782d3b101e98 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20231016165738-49dd2c1f3d0b // indirect gopkg.in/ini.v1 v1.67.0 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect ) diff --git a/go.sum b/go.sum index 7aa2e2a..ca17d66 100644 --- a/go.sum +++ b/go.sum @@ -93,12 +93,16 @@ github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2 github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-gost/core v0.0.0-20231027140845-d975ec3c7477 h1:a49XfrB4mgbw7z7oN/WTovx0X7SbxdfoANsEDTy9CqI= github.com/go-gost/core v0.0.0-20231027140845-d975ec3c7477/go.mod h1:ndkgWVYRLwupVaFFWv8ML1Nr8tD3xhHK245PLpUDg4E= +github.com/go-gost/core v0.0.0-20231031145651-8835e0e647f9 h1:Zab4WYWl/GyhfjkoZ2JqauQlRwLGzsxs8/tHxctYlv4= +github.com/go-gost/core v0.0.0-20231031145651-8835e0e647f9/go.mod h1:ndkgWVYRLwupVaFFWv8ML1Nr8tD3xhHK245PLpUDg4E= github.com/go-gost/gosocks4 v0.0.1 h1:+k1sec8HlELuQV7rWftIkmy8UijzUt2I6t+iMPlGB2s= github.com/go-gost/gosocks4 v0.0.1/go.mod h1:3B6L47HbU/qugDg4JnoFPHgJXE43Inz8Bah1QaN9qCc= github.com/go-gost/gosocks5 v0.4.0 h1:EIrOEkpJez4gwHrMa33frA+hHXJyevjp47thpMQsJzI= github.com/go-gost/gosocks5 v0.4.0/go.mod h1:1G6I7HP7VFVxveGkoK8mnprnJqSqJjdcASKsdUn4Pp4= github.com/go-gost/plugin v0.0.0-20231027141311-2cfb0a14b451 h1:sgg7LyK4ZAuQkBfaQxyFpH+xyAfrczDFDtkdRAcUxCE= github.com/go-gost/plugin v0.0.0-20231027141311-2cfb0a14b451/go.mod h1:mM/RLNsVy2nz5PiOijuqLYR3LhMzyQ9Kh/p0rXybJoo= +github.com/go-gost/plugin v0.0.0-20231031145754-4c25027b8b97 h1:p9dmeWsNwKcbIwwUUumD5a7HlZFODBwnMItBGuJ+P5M= +github.com/go-gost/plugin v0.0.0-20231031145754-4c25027b8b97/go.mod h1:qXr2Zm9Ex2ATqnWuNUzVZqySPMnuIihvblYZt4MlZLw= github.com/go-gost/relay v0.4.1-0.20230916134211-828f314ddfe7 h1:qAG1OyjvdA5h221CfFSS3J359V3d2E7dJWyP29QoDSI= github.com/go-gost/relay v0.4.1-0.20230916134211-828f314ddfe7/go.mod h1:lcX+23LCQ3khIeASBo+tJ/WbwXFO32/N5YN6ucuYTG8= github.com/go-gost/tls-dissector v0.0.2-0.20220408131628-aac992c27451 h1:xj8gUZGYO3nb5+6Bjw9+tsFkA9sYynrOvDvvC4uDV2I= @@ -191,6 +195,9 @@ github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm4 github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.3.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.4.0 h1:MtMxsa51/r9yyhkyLsVeVt0B+BGQZzpQiTQ4eHZ8bc4= +github.com/google/uuid v1.4.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= github.com/googleapis/google-cloud-go-testing v0.0.0-20200911160855-bcd43fbb19e8/go.mod h1:dvDLG8qkwmyD9a/MJJN3XJcT3xFxOKAvTZGvuZmac9g= @@ -685,6 +692,8 @@ google.golang.org/genproto v0.0.0-20210108203827-ffc7fda8c3d7/go.mod h1:FWY/as6D google.golang.org/genproto v0.0.0-20210226172003-ab064af71705/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/genproto/googleapis/rpc v0.0.0-20230711160842-782d3b101e98 h1:bVf09lpb+OJbByTj913DRJioFFAjf/ZGxEz7MajTp2U= google.golang.org/genproto/googleapis/rpc v0.0.0-20230711160842-782d3b101e98/go.mod h1:TUfxEVdsvPg18p6AslUXFoLdpED4oBnGwyqk3dV1XzM= +google.golang.org/genproto/googleapis/rpc v0.0.0-20231016165738-49dd2c1f3d0b h1:ZlWIi1wSK56/8hn4QcBp/j9M7Gt3U/3hZw3mC7vDICo= +google.golang.org/genproto/googleapis/rpc v0.0.0-20231016165738-49dd2c1f3d0b/go.mod h1:swOH3j0KzcDDgGUWr+SNpyTen5YrXjS3eyPzFYKc6lc= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= @@ -703,6 +712,8 @@ google.golang.org/grpc v1.34.0/go.mod h1:WotjhfgOW/POjDeRt8vscBtXq+2VjORFy659qA5 google.golang.org/grpc v1.35.0/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU= google.golang.org/grpc v1.58.3 h1:BjnpXut1btbtgN/6sp+brB2Kbm2LjNXnidYujAVbSoQ= google.golang.org/grpc v1.58.3/go.mod h1:tgX3ZQDlNJGU96V6yHh1T/JeoBQ2TXdr43YbYSsCJk0= +google.golang.org/grpc v1.59.0 h1:Z5Iec2pjwb+LEOqzpB2MR12/eKFhDPhuqW91O+4bwUk= +google.golang.org/grpc v1.59.0/go.mod h1:aUPDwccQo6OTjy7Hct4AfBPD1GptF4fyUjIkQ9YtF98= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= diff --git a/handler/tunnel/bind.go b/handler/tunnel/bind.go index 6b0eb01..f04eda7 100644 --- a/handler/tunnel/bind.go +++ b/handler/tunnel/bind.go @@ -8,7 +8,6 @@ import ( "net" "github.com/go-gost/core/logger" - "github.com/go-gost/core/recorder" "github.com/go-gost/relay" "github.com/go-gost/x/internal/util/mux" "github.com/google/uuid" @@ -59,15 +58,11 @@ func (h *tunnelHandler) handleBind(ctx context.Context, conn net.Conn, network, if h.md.ingress != nil { h.md.ingress.Set(ctx, addr, tunnelID.String()) } - if h.recorder != nil { - h.recorder.Record(ctx, - []byte(fmt.Sprintf("%s:%s", tunnelID, connectorID)), - recorder.MetadataReocrdOption(connectorMetadata{ - Op: "add", - Network: network, - Server: conn.LocalAddr().String(), - }), - ) + if sd := h.md.sd; sd != nil { + err := sd.Register(ctx, fmt.Sprintf("%s:%s:%s", h.id, tunnelID, connectorID), network, h.md.entryPoint) + if err != nil { + h.log.Error(err) + } } log.Debugf("%s/%s: tunnel=%s, connector=%s established", addr, network, tunnelID, connectorID) diff --git a/handler/tunnel/connect.go b/handler/tunnel/connect.go index 78ecc39..4b5f020 100644 --- a/handler/tunnel/connect.go +++ b/handler/tunnel/connect.go @@ -11,7 +11,7 @@ import ( xnet "github.com/go-gost/x/internal/net" ) -func (h *tunnelHandler) handleConnect(ctx context.Context, conn net.Conn, network, srcAddr string, dstAddr string, tunnelID relay.TunnelID, log logger.Logger) error { +func (h *tunnelHandler) handleConnect(ctx context.Context, req *relay.Request, conn net.Conn, network, srcAddr string, dstAddr string, tunnelID relay.TunnelID, log logger.Logger) error { log = log.WithFields(map[string]any{ "dst": fmt.Sprintf("%s/%s", dstAddr, network), "cmd": "connect", @@ -33,58 +33,68 @@ func (h *tunnelHandler) handleConnect(ctx context.Context, conn net.Conn, networ host, _, _ := net.SplitHostPort(dstAddr) // client is a public entrypoint. - if tunnelID.Equal(h.md.entryPointID) && !h.md.entryPointID.IsZero() { + if tunnelID.Equal(h.md.entryPointID) { resp.WriteTo(conn) return h.ep.handle(ctx, conn) } - var tid relay.TunnelID - if ingress := h.md.ingress; ingress != nil && host != "" { - tid = parseTunnelID(ingress.Get(ctx, host)) + if !h.md.directTunnel { + var tid relay.TunnelID + if ingress := h.md.ingress; ingress != nil && host != "" { + tid = parseTunnelID(ingress.Get(ctx, host)) + } + if !tid.Equal(tunnelID) { + resp.Status = relay.StatusHostUnreachable + resp.WriteTo(conn) + err := fmt.Errorf("no route to host %s", host) + log.Error(err) + return err + } } - // direct routing - if h.md.directTunnel { - tid = tunnelID - } else if !tid.Equal(tunnelID) { - resp.Status = relay.StatusHostUnreachable - resp.WriteTo(conn) - err := fmt.Errorf("no route to host %s", host) - log.Error(err) - return err + d := Dialer{ + node: h.id, + pool: h.pool, + sd: h.md.sd, + retry: 3, + timeout: 15 * time.Second, + log: log, } - - cc, _, err := getTunnelConn(network, h.pool, tid, 3, log) + cc, node, cid, err := d.Dial(ctx, network, tunnelID.String()) if err != nil { + log.Error(err) resp.Status = relay.StatusServiceUnavailable resp.WriteTo(conn) - log.Error(err) return err } defer cc.Close() - log.Debugf("%s >> %s", conn.RemoteAddr(), cc.RemoteAddr()) + log.Debugf("new connection to tunnel: %s, connector: %s", tunnelID, cid) - if _, err := resp.WriteTo(conn); err != nil { - log.Error(err) - return err + if node == h.id { + if _, err := resp.WriteTo(conn); err != nil { + log.Error(err) + return err + } + + resp = relay.Response{ + Version: relay.Version1, + Status: relay.StatusOK, + } + + af := &relay.AddrFeature{} + af.ParseFrom(srcAddr) + resp.Features = append(resp.Features, af) // src address + + af = &relay.AddrFeature{} + af.ParseFrom(dstAddr) + resp.Features = append(resp.Features, af) // dst address + + resp.WriteTo(cc) + } else { + req.WriteTo(cc) } - resp = relay.Response{ - Version: relay.Version1, - Status: relay.StatusOK, - } - - af := &relay.AddrFeature{} - af.ParseFrom(srcAddr) - resp.Features = append(resp.Features, af) // src address - - af = &relay.AddrFeature{} - af.ParseFrom(dstAddr) - resp.Features = append(resp.Features, af) // dst address - - resp.WriteTo(cc) - t := time.Now() log.Debugf("%s <-> %s", conn.RemoteAddr(), cc.RemoteAddr()) xnet.Transport(conn, cc) diff --git a/handler/tunnel/dialer.go b/handler/tunnel/dialer.go new file mode 100644 index 0000000..b8e14b3 --- /dev/null +++ b/handler/tunnel/dialer.go @@ -0,0 +1,76 @@ +package tunnel + +import ( + "context" + "net" + "time" + + "github.com/go-gost/core/logger" + "github.com/go-gost/core/sd" +) + +type Dialer struct { + node string + pool *ConnectorPool + sd sd.SD + retry int + timeout time.Duration + log logger.Logger +} + +func (d *Dialer) Dial(ctx context.Context, network string, tid string) (conn net.Conn, node string, cid string, err error) { + retry := d.retry + retry = 1 + + for i := 0; i < retry; i++ { + c := d.pool.Get(network, tid) + if c == nil { + break + } + + conn, err = c.Session().GetConn() + if err != nil { + d.log.Error(err) + continue + } + node = d.node + cid = c.id.String() + + break + } + if conn != nil || err != nil { + return + } + + if d.sd == nil { + err = ErrTunnelNotAvailable + return + } + + ss, err := d.sd.Get(ctx, tid) + if err != nil { + return + } + + var service *sd.Service + for _, s := range ss { + d.log.Debugf("%+v", s) + if s.Name != d.node && s.Network == network { + service = s + break + } + } + if service == nil || service.Address == "" { + err = ErrTunnelNotAvailable + return + } + + node = service.Node + cid = service.Name + + dialer := net.Dialer{ + Timeout: d.timeout, + } + conn, err = dialer.DialContext(ctx, network, service.Address) + return +} diff --git a/handler/tunnel/entrypoint.go b/handler/tunnel/entrypoint.go index a6e3f63..d454058 100644 --- a/handler/tunnel/entrypoint.go +++ b/handler/tunnel/entrypoint.go @@ -17,6 +17,7 @@ import ( "github.com/go-gost/core/listener" "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" + "github.com/go-gost/core/sd" "github.com/go-gost/relay" admission "github.com/go-gost/x/admission/wrapper" xio "github.com/go-gost/x/internal/io" @@ -28,8 +29,10 @@ import ( ) type entrypoint struct { + node string pool *ConnectorPool ingress ingress.Ingress + sd sd.SD log logger.Logger } @@ -51,6 +54,14 @@ func (ep *entrypoint) handle(ctx context.Context, conn net.Conn) error { br := bufio.NewReader(conn) + v, err := br.Peek(1) + if err != nil { + return err + } + if v[0] == relay.Version1 { + return ep.handleConnect(ctx, xnet.NewBufferReaderConn(conn, br), log) + } + var cc net.Conn for { resp := &http.Response{ @@ -102,32 +113,42 @@ func (ep *entrypoint) handle(ctx context.Context, conn net.Conn) error { remoteAddr = addr } - cc, cid, err := getTunnelConn("tcp", ep.pool, tunnelID, 3, log) + d := &Dialer{ + node: ep.node, + pool: ep.pool, + sd: ep.sd, + retry: 3, + timeout: 15 * time.Second, + log: log, + } + cc, node, cid, err := d.Dial(ctx, "tcp", tunnelID.String()) if err != nil { log.Error(err) return resp.Write(conn) } - log.Debugf("new connection to tunnel: %s, connector: %s", tunnelID, cid) - var features []relay.Feature - af := &relay.AddrFeature{} - af.ParseFrom(remoteAddr.String()) - features = append(features, af) // src address - host := req.Host if h, _, _ := net.SplitHostPort(host); h == "" { host = net.JoinHostPort(host, "80") } - af = &relay.AddrFeature{} - af.ParseFrom(host) - features = append(features, af) // dst address - (&relay.Response{ - Version: relay.Version1, - Status: relay.StatusOK, - Features: features, - }).WriteTo(cc) + if node == ep.node { + var features []relay.Feature + af := &relay.AddrFeature{} + af.ParseFrom(remoteAddr.String()) + features = append(features, af) // src address + + af = &relay.AddrFeature{} + af.ParseFrom(host) + features = append(features, af) // dst address + + (&relay.Response{ + Version: relay.Version1, + Status: relay.StatusOK, + Features: features, + }).WriteTo(cc) + } if err := req.Write(cc); err != nil { cc.Close() @@ -186,6 +207,90 @@ func (ep *entrypoint) handle(ctx context.Context, conn net.Conn) error { return nil } +func (ep *entrypoint) handleConnect(ctx context.Context, conn net.Conn, log logger.Logger) error { + req := relay.Request{} + if _, err := req.ReadFrom(conn); err != nil { + return err + } + + resp := relay.Response{ + Version: relay.Version1, + Status: relay.StatusOK, + } + + var srcAddr, dstAddr string + network := "tcp" + var tunnelID relay.TunnelID + for _, f := range req.Features { + switch f.Type() { + case relay.FeatureAddr: + if feature, _ := f.(*relay.AddrFeature); feature != nil { + v := net.JoinHostPort(feature.Host, strconv.Itoa(int(feature.Port))) + if srcAddr != "" { + dstAddr = v + } else { + srcAddr = v + } + } + case relay.FeatureTunnel: + if feature, _ := f.(*relay.TunnelFeature); feature != nil { + tunnelID = relay.NewTunnelID(feature.ID[:]) + } + case relay.FeatureNetwork: + if feature, _ := f.(*relay.NetworkFeature); feature != nil { + network = feature.Network.String() + } + } + } + + if tunnelID.IsZero() { + resp.Status = relay.StatusBadRequest + resp.WriteTo(conn) + return ErrTunnelID + } + + d := Dialer{ + pool: ep.pool, + retry: 3, + timeout: 15 * time.Second, + log: log, + } + cc, _, cid, err := d.Dial(ctx, network, tunnelID.String()) + if err != nil { + log.Error(err) + resp.Status = relay.StatusServiceUnavailable + resp.WriteTo(conn) + return err + } + defer cc.Close() + + log.Debugf("new connection to tunnel: %s, connector: %s", tunnelID, cid) + + if _, err := resp.WriteTo(conn); err != nil { + log.Error(err) + return err + } + + af := &relay.AddrFeature{} + af.ParseFrom(srcAddr) + resp.Features = append(resp.Features, af) // src address + + af = &relay.AddrFeature{} + af.ParseFrom(dstAddr) + resp.Features = append(resp.Features, af) // dst address + + resp.WriteTo(cc) + + t := time.Now() + log.Debugf("%s <-> %s", conn.RemoteAddr(), cc.RemoteAddr()) + xnet.Transport(conn, cc) + log.WithFields(map[string]any{ + "duration": time.Since(t), + }).Debugf("%s >-< %s", conn.RemoteAddr(), cc.RemoteAddr()) + + return nil +} + func (ep *entrypoint) getRealClientAddr(req *http.Request, raddr net.Addr) net.Addr { if req == nil { return nil diff --git a/handler/tunnel/handler.go b/handler/tunnel/handler.go index ad8d491..e715264 100644 --- a/handler/tunnel/handler.go +++ b/handler/tunnel/handler.go @@ -8,9 +8,9 @@ import ( "strconv" "time" - "github.com/go-gost/core/chain" "github.com/go-gost/core/handler" "github.com/go-gost/core/listener" + "github.com/go-gost/core/logger" md "github.com/go-gost/core/metadata" "github.com/go-gost/core/recorder" "github.com/go-gost/core/service" @@ -20,14 +20,16 @@ import ( xrecorder "github.com/go-gost/x/recorder" "github.com/go-gost/x/registry" xservice "github.com/go-gost/x/service" + "github.com/google/uuid" ) var ( - ErrBadVersion = errors.New("relay: bad version") - ErrUnknownCmd = errors.New("relay: unknown command") - ErrTunnelID = errors.New("tunnel: invalid tunnel ID") - ErrUnauthorized = errors.New("relay: unauthorized") - ErrRateLimit = errors.New("relay: rate limiting exceeded") + ErrBadVersion = errors.New("bad version") + ErrUnknownCmd = errors.New("unknown command") + ErrTunnelID = errors.New("invalid tunnel ID") + ErrTunnelNotAvailable = errors.New("tunnel not available") + ErrUnauthorized = errors.New("unauthorized") + ErrRateLimit = errors.New("rate limiting exceeded") ) func init() { @@ -35,13 +37,14 @@ func init() { } type tunnelHandler struct { - router *chain.Router - md metadata + id string options handler.Options pool *ConnectorPool recorder recorder.Recorder epSvc service.Service ep *entrypoint + md metadata + log logger.Logger } func NewHandler(opts ...handler.Option) handler.Handler { @@ -60,26 +63,33 @@ func (h *tunnelHandler) Init(md md.Metadata) (err error) { return err } - h.router = h.options.Router - if h.router == nil { - h.router = chain.NewRouter(chain.LoggerRouterOption(h.options.Logger)) + uuid, err := uuid.NewRandom() + if err != nil { + return err } + h.id = uuid.String() - if opts := h.router.Options(); opts != nil { + h.log = h.options.Logger.WithFields(map[string]any{ + "node": h.id, + }) + + if opts := h.options.Router.Options(); opts != nil { for _, ro := range opts.Recorders { - if ro.Record == xrecorder.RecorderServiceHandlerTunnelConnector { + if ro.Record == xrecorder.RecorderServiceHandlerTunnel { h.recorder = ro.Recorder break } } } - h.pool = NewConnectorPool() - h.pool.WithRecorder(h.recorder) + + h.pool = NewConnectorPool(h.id, h.md.sd) h.ep = &entrypoint{ + node: h.id, pool: h.pool, ingress: h.md.ingress, - log: h.options.Logger.WithFields(map[string]any{ + sd: h.md.sd, + log: h.log.WithFields(map[string]any{ "kind": "entrypoint", }), } @@ -102,12 +112,12 @@ func (h *tunnelHandler) initEntrypoint() (err error) { ln, err := net.Listen(network, h.md.entryPoint) if err != nil { - h.options.Logger.Error(err) + h.log.Error(err) return } serviceName := fmt.Sprintf("%s-ep-%s", h.options.Service, ln.Addr()) - log := h.options.Logger.WithFields(map[string]any{ + log := h.log.WithFields(map[string]any{ "service": serviceName, "listener": "tcp", "handler": "tunnel-ep", @@ -143,7 +153,7 @@ func (h *tunnelHandler) initEntrypoint() (err error) { func (h *tunnelHandler) Handle(ctx context.Context, conn net.Conn, opts ...handler.HandleOption) (err error) { start := time.Now() - log := h.options.Logger.WithFields(map[string]any{ + log := h.log.WithFields(map[string]any{ "remote": conn.RemoteAddr().String(), "local": conn.LocalAddr().String(), }) @@ -189,7 +199,7 @@ func (h *tunnelHandler) Handle(ctx context.Context, conn net.Conn, opts ...handl var user, pass string var srcAddr, dstAddr string - var networkID relay.NetworkID + network := "tcp" var tunnelID relay.TunnelID for _, f := range req.Features { switch f.Type() { @@ -212,7 +222,7 @@ func (h *tunnelHandler) Handle(ctx context.Context, conn net.Conn, opts ...handl } case relay.FeatureNetwork: if feature, _ := f.(*relay.NetworkFeature); feature != nil { - networkID = feature.Network + network = feature.Network.String() } } } @@ -237,17 +247,13 @@ func (h *tunnelHandler) Handle(ctx context.Context, conn net.Conn, opts ...handl ctx = auth_util.ContextWithID(ctx, auth_util.ID(id)) } - network := networkID.String() - if (req.Cmd & relay.FUDP) == relay.FUDP { - network = "udp" - } - switch req.Cmd & relay.CmdMask { case relay.CmdConnect: defer conn.Close() log.Debugf("connect: %s >> %s/%s", srcAddr, dstAddr, network) - return h.handleConnect(ctx, conn, network, srcAddr, dstAddr, tunnelID, log) + return h.handleConnect(ctx, &req, conn, network, srcAddr, dstAddr, tunnelID, log) + case relay.CmdBind: log.Debugf("bind: %s >> %s/%s", srcAddr, dstAddr, network) return h.handleBind(ctx, conn, network, dstAddr, tunnelID, log) diff --git a/handler/tunnel/metadata.go b/handler/tunnel/metadata.go index 4d5868e..67d3040 100644 --- a/handler/tunnel/metadata.go +++ b/handler/tunnel/metadata.go @@ -8,6 +8,7 @@ import ( "github.com/go-gost/core/logger" mdata "github.com/go-gost/core/metadata" mdutil "github.com/go-gost/core/metadata/util" + "github.com/go-gost/core/sd" "github.com/go-gost/relay" xingress "github.com/go-gost/x/ingress" "github.com/go-gost/x/internal/util/mux" @@ -21,6 +22,7 @@ type metadata struct { entryPointProxyProtocol int directTunnel bool ingress ingress.Ingress + sd sd.SD muxCfg *mux.Config } @@ -54,6 +56,7 @@ func (h *tunnelHandler) parseMetadata(md mdata.Metadata) (err error) { ) } } + h.md.sd = registry.SDRegistry().Get(mdutil.GetString(md, "sd")) h.md.muxCfg = &mux.Config{ Version: mdutil.GetInt(md, "mux.version"), diff --git a/handler/tunnel/tunnel.go b/handler/tunnel/tunnel.go index 134db6d..6103622 100644 --- a/handler/tunnel/tunnel.go +++ b/handler/tunnel/tunnel.go @@ -3,24 +3,17 @@ package tunnel import ( "context" "fmt" - "net" "sync" "sync/atomic" "time" "github.com/go-gost/core/logger" - "github.com/go-gost/core/recorder" + "github.com/go-gost/core/sd" "github.com/go-gost/relay" "github.com/go-gost/x/internal/util/mux" "github.com/google/uuid" ) -type connectorMetadata struct { - Op string - Network string - Server string -} - type Connector struct { id relay.ConnectorID t time.Time @@ -58,18 +51,20 @@ func (c *Connector) Session() *mux.Session { } type Tunnel struct { + node string id relay.TunnelID connectors []*Connector t time.Time n uint64 close chan struct{} mu sync.RWMutex - recorder recorder.Recorder + sd sd.SD } -func NewTunnel(id relay.TunnelID) *Tunnel { +func NewTunnel(node string, tid relay.TunnelID) *Tunnel { t := &Tunnel{ - id: id, + node: node, + id: tid, t: time.Now(), close: make(chan struct{}), } @@ -77,8 +72,8 @@ func NewTunnel(id relay.TunnelID) *Tunnel { return t } -func (t *Tunnel) WithRecorder(recorder recorder.Recorder) { - t.recorder = recorder +func (t *Tunnel) WithSD(sd sd.SD) { + t.sd = sd } func (t *Tunnel) ID() relay.TunnelID { @@ -142,30 +137,21 @@ func (t *Tunnel) clean() { t.mu.Lock() if len(t.connectors) == 0 { t.mu.Unlock() + break } var connectors []*Connector for _, c := range t.connectors { if c.Session().IsClosed() { logger.Default().Debugf("remove tunnel: %s, connector: %s", t.id, c.id) - if t.recorder != nil { - t.recorder.Record(context.Background(), - []byte(fmt.Sprintf("%s:%s", t.id, c.id)), - recorder.MetadataReocrdOption(connectorMetadata{ - Op: "del", - }), - ) + if t.sd != nil { + t.sd.Deregister(context.Background(), fmt.Sprintf("%s:%s:%s", t.node, t.id, c.id)) } continue } connectors = append(connectors, c) - if t.recorder != nil { - t.recorder.Record(context.Background(), - []byte(fmt.Sprintf("%s:%s", t.id, c.id)), - recorder.MetadataReocrdOption(connectorMetadata{ - Op: "set", - }), - ) + if t.sd != nil { + t.sd.Renew(context.Background(), fmt.Sprintf("%s:%s:%s", t.node, t.id, c.id)) } } if len(connectors) != len(t.connectors) { @@ -179,23 +165,22 @@ func (t *Tunnel) clean() { } type ConnectorPool struct { - tunnels map[string]*Tunnel - mu sync.RWMutex - recorder recorder.Recorder + node string + sd sd.SD + tunnels map[string]*Tunnel + mu sync.RWMutex } -func NewConnectorPool() *ConnectorPool { +func NewConnectorPool(node string, sd sd.SD) *ConnectorPool { p := &ConnectorPool{ + node: node, + sd: sd, tunnels: make(map[string]*Tunnel), } go p.closeIdles() return p } -func (p *ConnectorPool) WithRecorder(recorder recorder.Recorder) { - p.recorder = recorder -} - func (p *ConnectorPool) Add(tid relay.TunnelID, c *Connector) { p.mu.Lock() defer p.mu.Unlock() @@ -204,15 +189,15 @@ func (p *ConnectorPool) Add(tid relay.TunnelID, c *Connector) { t := p.tunnels[s] if t == nil { - t = NewTunnel(tid) - t.WithRecorder(p.recorder) + t = NewTunnel(p.node, tid) + t.WithSD(p.sd) p.tunnels[s] = t } t.AddConnector(c) } -func (p *ConnectorPool) Get(network string, tid relay.TunnelID) *Connector { +func (p *ConnectorPool) Get(network string, tid string) *Connector { if p == nil { return nil } @@ -220,7 +205,7 @@ func (p *ConnectorPool) Get(network string, tid relay.TunnelID) *Connector { p.mu.RLock() defer p.mu.RUnlock() - t := p.tunnels[tid.String()] + t := p.tunnels[tid] if t == nil { return nil } @@ -260,31 +245,3 @@ func parseTunnelID(s string) (tid relay.TunnelID) { } return relay.NewTunnelID(uuid[:]) } - -func getTunnelConn(network string, pool *ConnectorPool, tid relay.TunnelID, retry int, log logger.Logger) (conn net.Conn, cid relay.ConnectorID, err error) { - if tid.IsZero() { - err = ErrTunnelID - return - } - - if retry <= 0 { - retry = 1 - } - for i := 0; i < retry; i++ { - c := pool.Get(network, tid) - if c == nil { - err = fmt.Errorf("tunnel %s not available", tid.String()) - break - } - - conn, err = c.Session().GetConn() - if err != nil { - log.Error(err) - continue - } - cid = c.id - break - } - - return -} diff --git a/ingress/plugin.go b/ingress/plugin.go index 7097a23..9e7f504 100644 --- a/ingress/plugin.go +++ b/ingress/plugin.go @@ -126,23 +126,19 @@ func (p *httpPlugin) Get(ctx context.Context, host string, opts ...ingress.GetOp return } - rb := httpPluginGetRequest{ - Host: host, - } - v, err := json.Marshal(&rb) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, p.url, nil) if err != nil { return } - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, p.url, bytes.NewReader(v)) - if err != nil { - return - } - if p.header != nil { req.Header = p.header.Clone() } req.Header.Set("Content-Type", "application/json") + + q := req.URL.Query() + q.Set("host", host) + req.URL.RawQuery = q.Encode() + resp, err := p.client.Do(req) if err != nil { return diff --git a/recorder/recorder.go b/recorder/recorder.go index 9d1eb8b..303b915 100644 --- a/recorder/recorder.go +++ b/recorder/recorder.go @@ -1,6 +1,6 @@ package recorder const ( - RecorderServiceHandlerSerial = "recorder.service.handler.serial" - RecorderServiceHandlerTunnelConnector = "recorder.service.handler.tunnel.connector" + RecorderServiceHandlerSerial = "recorder.service.handler.serial" + RecorderServiceHandlerTunnel = "recorder.service.handler.tunnel" ) diff --git a/registry/registry.go b/registry/registry.go index ba87bc0..9a3731d 100644 --- a/registry/registry.go +++ b/registry/registry.go @@ -18,6 +18,7 @@ import ( "github.com/go-gost/core/recorder" reg "github.com/go-gost/core/registry" "github.com/go-gost/core/resolver" + "github.com/go-gost/core/sd" "github.com/go-gost/core/service" ) @@ -32,7 +33,7 @@ var ( connectorReg reg.Registry[NewConnector] = new(connectorRegistry) serviceReg reg.Registry[service.Service] = new(serviceRegistry) chainReg reg.Registry[chain.Chainer] = new(chainRegistry) - hopReg reg.Registry[hop.Hop] = new(hopRegistry) + hopReg reg.Registry[hop.Hop] = new(hopRegistry) autherReg reg.Registry[auth.Authenticator] = new(autherRegistry) admissionReg reg.Registry[admission.Admission] = new(admissionRegistry) bypassReg reg.Registry[bypass.Bypass] = new(bypassRegistry) @@ -45,6 +46,7 @@ var ( rateLimiterReg reg.Registry[rate.RateLimiter] = new(rateLimiterRegistry) ingressReg reg.Registry[ingress.Ingress] = new(ingressRegistry) + sdReg reg.Registry[sd.SD] = new(sdRegistry) ) type registry[T any] struct { @@ -163,3 +165,7 @@ func RateLimiterRegistry() reg.Registry[rate.RateLimiter] { func IngressRegistry() reg.Registry[ingress.Ingress] { return ingressReg } + +func SDRegistry() reg.Registry[sd.SD] { + return sdReg +} diff --git a/registry/sd.go b/registry/sd.go new file mode 100644 index 0000000..fa73017 --- /dev/null +++ b/registry/sd.go @@ -0,0 +1,66 @@ +package registry + +import ( + "context" + + "github.com/go-gost/core/sd" +) + +type sdRegistry struct { + registry[sd.SD] +} + +func (r *sdRegistry) Register(name string, v sd.SD) error { + return r.registry.Register(name, v) +} + +func (r *sdRegistry) Get(name string) sd.SD { + if name != "" { + return &sdWrapper{name: name, r: r} + } + return nil +} + +func (r *sdRegistry) get(name string) sd.SD { + return r.registry.Get(name) +} + +type sdWrapper struct { + name string + r *sdRegistry +} + +func (w *sdWrapper) Register(ctx context.Context, name string, network, address string, opts ...sd.Option) error { + v := w.r.get(w.name) + if v == nil { + return nil + } + return v.Register(ctx, name, network, address, opts...) +} + +func (w *sdWrapper) Deregister(ctx context.Context, name string) error { + v := w.r.get(w.name) + if v == nil { + return nil + } + + return v.Deregister(ctx, name) +} + +func (w *sdWrapper) Renew(ctx context.Context, name string) error { + v := w.r.get(w.name) + if v == nil { + return nil + } + + return v.Renew(ctx, name) +} + +func (w *sdWrapper) Get(ctx context.Context, name string) ([]*sd.Service, error) { + v := w.r.get(w.name) + if v == nil { + return nil, nil + } + + return v.Get(ctx, name) +} diff --git a/sd/plugin.go b/sd/plugin.go new file mode 100644 index 0000000..e66c941 --- /dev/null +++ b/sd/plugin.go @@ -0,0 +1,343 @@ +package ingress + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + + "github.com/go-gost/core/logger" + "github.com/go-gost/core/sd" + "github.com/go-gost/plugin/sd/proto" + "github.com/go-gost/x/internal/plugin" + "google.golang.org/grpc" +) + +type grpcPlugin struct { + conn grpc.ClientConnInterface + client proto.SDClient + log logger.Logger +} + +// NewGRPCPlugin creates an SD plugin based on gRPC. +func NewGRPCPlugin(name string, addr string, opts ...plugin.Option) sd.SD { + var options plugin.Options + for _, opt := range opts { + opt(&options) + } + + log := logger.Default().WithFields(map[string]any{ + "kind": "sd", + "sd": name, + }) + conn, err := plugin.NewGRPCConn(addr, &options) + if err != nil { + log.Error(err) + } + + p := &grpcPlugin{ + conn: conn, + log: log, + } + if conn != nil { + p.client = proto.NewSDClient(conn) + } + return p +} + +func (p *grpcPlugin) Register(ctx context.Context, name string, network, address string, opts ...sd.Option) error { + if p.client == nil { + return nil + } + + _, err := p.client.Register(ctx, + &proto.RegisterRequest{ + Name: name, + Network: network, + Address: address, + }) + if err != nil { + p.log.Error(err) + return err + } + return nil +} + +func (p *grpcPlugin) Deregister(ctx context.Context, name string) error { + if p.client == nil { + return nil + } + + _, err := p.client.Deregister(ctx, &proto.DeregisterRequest{ + Name: name, + }) + return err +} + +func (p *grpcPlugin) Renew(ctx context.Context, name string) error { + if p.client == nil { + return nil + } + + _, err := p.client.Renew(ctx, &proto.RenewRequest{ + Name: name, + }) + return err +} + +func (p *grpcPlugin) Get(ctx context.Context, name string) ([]*sd.Service, error) { + if p.client == nil { + return nil, nil + } + + r, err := p.client.Get(ctx, &proto.GetServiceRequest{ + Name: name, + }) + if err != nil { + return nil, err + } + + var services []*sd.Service + for _, v := range r.Services { + if v == nil { + continue + } + services = append(services, &sd.Service{ + Node: v.Node, + Name: v.Name, + Network: v.Network, + Address: v.Address, + }) + } + return services, nil +} + +func (p *grpcPlugin) Close() error { + if closer, ok := p.conn.(io.Closer); ok { + return closer.Close() + } + return nil +} + +type httpRegisterRequest struct { + Name string `json:"name"` + Network string `json:"network"` + Address string `json:"address"` +} + +type httpRegisterResponse struct { + Ok bool `json:"ok"` +} + +type httpDeregisterRequest struct { + Name string `json:"name"` +} + +type httpDeregisterResponse struct { + Ok bool `json:"ok"` +} + +type httpRenewRequest struct { + Name string `json:"name"` +} + +type httpRenewResponse struct { + Ok bool `json:"ok"` +} + +type httpGetRequest struct { + Name string `json:"name"` +} + +type sdService struct { + Node string `json:"node"` + Name string `json:"name"` + Network string `json:"network"` + Address string `json:"address"` +} + +type httpGetResponse struct { + Services []*sdService +} + +type httpPlugin struct { + url string + client *http.Client + header http.Header + log logger.Logger +} + +// NewHTTPPlugin creates an SD plugin based on HTTP. +func NewHTTPPlugin(name string, url string, opts ...plugin.Option) sd.SD { + var options plugin.Options + for _, opt := range opts { + opt(&options) + } + + return &httpPlugin{ + url: url, + client: plugin.NewHTTPClient(&options), + header: options.Header, + log: logger.Default().WithFields(map[string]any{ + "kind": "sd", + "sd": name, + }), + } +} + +func (p *httpPlugin) Register(ctx context.Context, name string, network, address string, opts ...sd.Option) error { + if p.client == nil { + return nil + } + + rb := httpRegisterRequest{ + Name: name, + Network: network, + Address: address, + } + v, err := json.Marshal(&rb) + if err != nil { + return err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, p.url, bytes.NewReader(v)) + if err != nil { + return err + } + + if p.header != nil { + req.Header = p.header.Clone() + } + req.Header.Set("Content-Type", "application/json") + resp, err := p.client.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf(resp.Status) + } + + return nil +} + +func (p *httpPlugin) Deregister(ctx context.Context, name string) error { + if p.client == nil { + return nil + } + + rb := httpDeregisterRequest{ + Name: name, + } + v, err := json.Marshal(&rb) + if err != nil { + return err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodDelete, p.url, bytes.NewReader(v)) + if err != nil { + return err + } + + if p.header != nil { + req.Header = p.header.Clone() + } + req.Header.Set("Content-Type", "application/json") + resp, err := p.client.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf(resp.Status) + } + + return nil +} + +func (p *httpPlugin) Renew(ctx context.Context, name string) error { + if p.client == nil { + return nil + } + + rb := httpRenewRequest{ + Name: name, + } + v, err := json.Marshal(&rb) + if err != nil { + return err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPut, p.url, bytes.NewReader(v)) + if err != nil { + return err + } + + if p.header != nil { + req.Header = p.header.Clone() + } + req.Header.Set("Content-Type", "application/json") + resp, err := p.client.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf(resp.Status) + } + + return nil +} + +func (p *httpPlugin) Get(ctx context.Context, name string) (services []*sd.Service, err error) { + if p.client == nil { + return + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, p.url, nil) + if err != nil { + return + } + + if p.header != nil { + req.Header = p.header.Clone() + } + req.Header.Set("Content-Type", "application/json") + + q := req.URL.Query() + q.Set("name", name) + req.URL.RawQuery = q.Encode() + + resp, err := p.client.Do(req) + if err != nil { + return + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf(resp.Status) + } + + res := &httpGetResponse{} + if err := json.NewDecoder(resp.Body).Decode(&res); err != nil { + return nil, err + } + + for _, v := range res.Services { + if v == nil { + continue + } + services = append(services, &sd.Service{ + Node: v.Node, + Name: v.Name, + Network: v.Network, + Address: v.Address, + }) + } + return +}