From 818463260f78457f2eea1307020943f077eab8c9 Mon Sep 17 00:00:00 2001 From: ginuerzh Date: Wed, 27 Jul 2022 18:01:22 +0800 Subject: [PATCH] add path option support for grpc --- dialer/grpc/dialer.go | 8 +-- dialer/grpc/metadata.go | 3 ++ internal/util/grpc/proto/gost_grpcx.go | 72 ++++++++++++++++++++++++++ listener/grpc/listener.go | 4 +- listener/grpc/metadata.go | 3 ++ 5 files changed, 84 insertions(+), 6 deletions(-) create mode 100644 internal/util/grpc/proto/gost_grpcx.go diff --git a/dialer/grpc/dialer.go b/dialer/grpc/dialer.go index 53ffcc3..e4f7c2e 100644 --- a/dialer/grpc/dialer.go +++ b/dialer/grpc/dialer.go @@ -21,7 +21,7 @@ func init() { } type grpcDialer struct { - clients map[string]pb.GostTunelClient + clients map[string]pb.GostTunelClientX clientMutex sync.Mutex md metadata options dialer.Options @@ -34,7 +34,7 @@ func NewDialer(opts ...dialer.Option) dialer.Dialer { } return &grpcDialer{ - clients: make(map[string]pb.GostTunelClient), + clients: make(map[string]pb.GostTunelClientX), options: options, } } @@ -95,11 +95,11 @@ func (d *grpcDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialO d.options.Logger.Error(err) return nil, err } - client = pb.NewGostTunelClient(cc) + client = pb.NewGostTunelClientX(cc) d.clients[addr] = client } - cli, err := client.Tunnel(ctx) + cli, err := client.TunnelX(ctx, d.md.path) if err != nil { return nil, err } diff --git a/dialer/grpc/metadata.go b/dialer/grpc/metadata.go index 25ad5bc..df7cd60 100644 --- a/dialer/grpc/metadata.go +++ b/dialer/grpc/metadata.go @@ -8,16 +8,19 @@ import ( type metadata struct { insecure bool host string + path string } func (d *grpcDialer) parseMetadata(md mdata.Metadata) (err error) { const ( insecure = "grpcInsecure" host = "host" + path = "path" ) d.md.insecure = mdx.GetBool(md, insecure) d.md.host = mdx.GetString(md, host) + d.md.path = mdx.GetString(md, path) return } diff --git a/internal/util/grpc/proto/gost_grpcx.go b/internal/util/grpc/proto/gost_grpcx.go new file mode 100644 index 0000000..60eb066 --- /dev/null +++ b/internal/util/grpc/proto/gost_grpcx.go @@ -0,0 +1,72 @@ +package proto + +import ( + context "context" + "strings" + + grpc "google.golang.org/grpc" +) + +type GostTunelClientX interface { + TunnelX(ctx context.Context, method string, opts ...grpc.CallOption) (GostTunel_TunnelClient, error) +} +type gostTunelClientX struct { + cc grpc.ClientConnInterface +} + +func NewGostTunelClientX(cc grpc.ClientConnInterface) GostTunelClientX { + return &gostTunelClientX{ + cc: cc, + } +} + +func (c *gostTunelClientX) TunnelX(ctx context.Context, method string, opts ...grpc.CallOption) (GostTunel_TunnelClient, error) { + sd := ServerDesc(method) + method = "/" + sd.ServiceName + "/" + sd.Streams[0].StreamName + stream, err := c.cc.NewStream(ctx, &sd.Streams[0], method, opts...) + if err != nil { + return nil, err + } + x := &gostTunelTunnelClient{stream} + return x, nil +} + +func RegisterGostTunelServerX(s grpc.ServiceRegistrar, srv GostTunelServer, method string) { + sd := ServerDesc(method) + s.RegisterService(&sd, srv) +} + +func ServerDesc(method string) grpc.ServiceDesc { + serviceName, streamName := parsingMethod(method) + + return grpc.ServiceDesc{ + ServiceName: serviceName, + HandlerType: GostTunel_ServiceDesc.HandlerType, + Methods: GostTunel_ServiceDesc.Methods, + Streams: []grpc.StreamDesc{ + { + StreamName: streamName, + Handler: GostTunel_ServiceDesc.Streams[0].Handler, + ServerStreams: GostTunel_ServiceDesc.Streams[0].ServerStreams, + ClientStreams: GostTunel_ServiceDesc.Streams[0].ClientStreams, + }, + }, + Metadata: GostTunel_ServiceDesc.Metadata, + } + +} + +func parsingMethod(method string) (string, string) { + serviceName := GostTunel_ServiceDesc.ServiceName + streamName := GostTunel_ServiceDesc.Streams[0].StreamName + v := strings.SplitN(strings.Trim(method, "/"), "/", 2) + if len(v) == 1 && v[0] != "" { + serviceName = v[0] + } + if len(v) == 2 { + serviceName = v[0] + streamName = strings.Replace(v[1], "/", "-", -1) + } + + return serviceName, streamName +} diff --git a/listener/grpc/listener.go b/listener/grpc/listener.go index b21fadc..5ee75c9 100644 --- a/listener/grpc/listener.go +++ b/listener/grpc/listener.go @@ -66,11 +66,11 @@ func (l *grpcListener) Init(md md.Metadata) (err error) { l.cqueue = make(chan net.Conn, l.md.backlog) l.errChan = make(chan error, 1) - pb.RegisterGostTunelServer(l.server, &server{ + pb.RegisterGostTunelServerX(l.server, &server{ cqueue: l.cqueue, localAddr: l.addr, logger: l.options.Logger, - }) + }, l.md.path) go func() { err := l.server.Serve(ln) diff --git a/listener/grpc/metadata.go b/listener/grpc/metadata.go index 6b9e796..c11d7ab 100644 --- a/listener/grpc/metadata.go +++ b/listener/grpc/metadata.go @@ -12,12 +12,14 @@ const ( type metadata struct { backlog int insecure bool + path string } func (l *grpcListener) parseMetadata(md mdata.Metadata) (err error) { const ( backlog = "backlog" insecure = "grpcInsecure" + path = "path" ) l.md.backlog = mdx.GetInt(md, backlog) @@ -26,5 +28,6 @@ func (l *grpcListener) parseMetadata(md mdata.Metadata) (err error) { } l.md.insecure = mdx.GetBool(md, insecure) + l.md.path = mdx.GetString(md, path) return }