add grpc tunnel
This commit is contained in:
@ -36,7 +36,9 @@ func (tr *Transport) Dial(ctx context.Context, addr string) (net.Conn, error) {
|
||||
}
|
||||
|
||||
func (tr *Transport) dialOptions() []dialer.DialOption {
|
||||
var opts []dialer.DialOption
|
||||
opts := []dialer.DialOption{
|
||||
dialer.HostDialOption(tr.addr),
|
||||
}
|
||||
if !tr.route.IsEmpty() {
|
||||
opts = append(opts,
|
||||
dialer.DialFuncDialOption(
|
||||
|
148
pkg/common/util/grpc/proto/gost.pb.go
Normal file
148
pkg/common/util/grpc/proto/gost.pb.go
Normal file
@ -0,0 +1,148 @@
|
||||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.26.0
|
||||
// protoc v3.12.4
|
||||
// source: gost.proto
|
||||
|
||||
package proto
|
||||
|
||||
import (
|
||||
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
|
||||
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
|
||||
reflect "reflect"
|
||||
sync "sync"
|
||||
)
|
||||
|
||||
const (
|
||||
// Verify that this generated code is sufficiently up-to-date.
|
||||
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
|
||||
// Verify that runtime/protoimpl is sufficiently up-to-date.
|
||||
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
|
||||
)
|
||||
|
||||
type Chunk struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
Data []byte `protobuf:"bytes,1,opt,name=data,proto3" json:"data,omitempty"`
|
||||
}
|
||||
|
||||
func (x *Chunk) Reset() {
|
||||
*x = Chunk{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_gost_proto_msgTypes[0]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *Chunk) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*Chunk) ProtoMessage() {}
|
||||
|
||||
func (x *Chunk) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_gost_proto_msgTypes[0]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use Chunk.ProtoReflect.Descriptor instead.
|
||||
func (*Chunk) Descriptor() ([]byte, []int) {
|
||||
return file_gost_proto_rawDescGZIP(), []int{0}
|
||||
}
|
||||
|
||||
func (x *Chunk) GetData() []byte {
|
||||
if x != nil {
|
||||
return x.Data
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var File_gost_proto protoreflect.FileDescriptor
|
||||
|
||||
var file_gost_proto_rawDesc = []byte{
|
||||
0x0a, 0x0a, 0x67, 0x6f, 0x73, 0x74, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x1b, 0x0a, 0x05,
|
||||
0x43, 0x68, 0x75, 0x6e, 0x6b, 0x12, 0x12, 0x0a, 0x04, 0x64, 0x61, 0x74, 0x61, 0x18, 0x01, 0x20,
|
||||
0x01, 0x28, 0x0c, 0x52, 0x04, 0x64, 0x61, 0x74, 0x61, 0x32, 0x29, 0x0a, 0x09, 0x47, 0x6f, 0x73,
|
||||
0x74, 0x54, 0x75, 0x6e, 0x65, 0x6c, 0x12, 0x1c, 0x0a, 0x06, 0x54, 0x75, 0x6e, 0x6e, 0x65, 0x6c,
|
||||
0x12, 0x06, 0x2e, 0x43, 0x68, 0x75, 0x6e, 0x6b, 0x1a, 0x06, 0x2e, 0x43, 0x68, 0x75, 0x6e, 0x6b,
|
||||
0x28, 0x01, 0x30, 0x01, 0x42, 0x34, 0x5a, 0x32, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63,
|
||||
0x6f, 0x6d, 0x2f, 0x67, 0x6f, 0x2d, 0x67, 0x6f, 0x73, 0x74, 0x2f, 0x67, 0x6f, 0x73, 0x74, 0x2f,
|
||||
0x70, 0x6b, 0x67, 0x2f, 0x63, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, 0x2f, 0x75, 0x74, 0x69, 0x6c, 0x2f,
|
||||
0x67, 0x72, 0x70, 0x63, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74,
|
||||
0x6f, 0x33,
|
||||
}
|
||||
|
||||
var (
|
||||
file_gost_proto_rawDescOnce sync.Once
|
||||
file_gost_proto_rawDescData = file_gost_proto_rawDesc
|
||||
)
|
||||
|
||||
func file_gost_proto_rawDescGZIP() []byte {
|
||||
file_gost_proto_rawDescOnce.Do(func() {
|
||||
file_gost_proto_rawDescData = protoimpl.X.CompressGZIP(file_gost_proto_rawDescData)
|
||||
})
|
||||
return file_gost_proto_rawDescData
|
||||
}
|
||||
|
||||
var file_gost_proto_msgTypes = make([]protoimpl.MessageInfo, 1)
|
||||
var file_gost_proto_goTypes = []interface{}{
|
||||
(*Chunk)(nil), // 0: Chunk
|
||||
}
|
||||
var file_gost_proto_depIdxs = []int32{
|
||||
0, // 0: GostTunel.Tunnel:input_type -> Chunk
|
||||
0, // 1: GostTunel.Tunnel:output_type -> Chunk
|
||||
1, // [1:2] is the sub-list for method output_type
|
||||
0, // [0:1] is the sub-list for method input_type
|
||||
0, // [0:0] is the sub-list for extension type_name
|
||||
0, // [0:0] is the sub-list for extension extendee
|
||||
0, // [0:0] is the sub-list for field type_name
|
||||
}
|
||||
|
||||
func init() { file_gost_proto_init() }
|
||||
func file_gost_proto_init() {
|
||||
if File_gost_proto != nil {
|
||||
return
|
||||
}
|
||||
if !protoimpl.UnsafeEnabled {
|
||||
file_gost_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
|
||||
switch v := v.(*Chunk); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
type x struct{}
|
||||
out := protoimpl.TypeBuilder{
|
||||
File: protoimpl.DescBuilder{
|
||||
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||
RawDescriptor: file_gost_proto_rawDesc,
|
||||
NumEnums: 0,
|
||||
NumMessages: 1,
|
||||
NumExtensions: 0,
|
||||
NumServices: 1,
|
||||
},
|
||||
GoTypes: file_gost_proto_goTypes,
|
||||
DependencyIndexes: file_gost_proto_depIdxs,
|
||||
MessageInfos: file_gost_proto_msgTypes,
|
||||
}.Build()
|
||||
File_gost_proto = out.File
|
||||
file_gost_proto_rawDesc = nil
|
||||
file_gost_proto_goTypes = nil
|
||||
file_gost_proto_depIdxs = nil
|
||||
}
|
10
pkg/common/util/grpc/proto/gost.proto
Normal file
10
pkg/common/util/grpc/proto/gost.proto
Normal file
@ -0,0 +1,10 @@
|
||||
syntax = "proto3";
|
||||
option go_package = "github.com/go-gost/gost/pkg/common/util/grpc/proto";
|
||||
|
||||
message Chunk {
|
||||
bytes data = 1;
|
||||
}
|
||||
|
||||
service GostTunel {
|
||||
rpc Tunnel (stream Chunk) returns (stream Chunk);
|
||||
}
|
133
pkg/common/util/grpc/proto/gost_grpc.pb.go
Normal file
133
pkg/common/util/grpc/proto/gost_grpc.pb.go
Normal file
@ -0,0 +1,133 @@
|
||||
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
|
||||
|
||||
package proto
|
||||
|
||||
import (
|
||||
context "context"
|
||||
grpc "google.golang.org/grpc"
|
||||
codes "google.golang.org/grpc/codes"
|
||||
status "google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
// This is a compile-time assertion to ensure that this generated file
|
||||
// is compatible with the grpc package it is being compiled against.
|
||||
// Requires gRPC-Go v1.32.0 or later.
|
||||
const _ = grpc.SupportPackageIsVersion7
|
||||
|
||||
// GostTunelClient is the client API for GostTunel service.
|
||||
//
|
||||
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
|
||||
type GostTunelClient interface {
|
||||
Tunnel(ctx context.Context, opts ...grpc.CallOption) (GostTunel_TunnelClient, error)
|
||||
}
|
||||
|
||||
type gostTunelClient struct {
|
||||
cc grpc.ClientConnInterface
|
||||
}
|
||||
|
||||
func NewGostTunelClient(cc grpc.ClientConnInterface) GostTunelClient {
|
||||
return &gostTunelClient{cc}
|
||||
}
|
||||
|
||||
func (c *gostTunelClient) Tunnel(ctx context.Context, opts ...grpc.CallOption) (GostTunel_TunnelClient, error) {
|
||||
stream, err := c.cc.NewStream(ctx, &GostTunel_ServiceDesc.Streams[0], "/GostTunel/Tunnel", opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
x := &gostTunelTunnelClient{stream}
|
||||
return x, nil
|
||||
}
|
||||
|
||||
type GostTunel_TunnelClient interface {
|
||||
Send(*Chunk) error
|
||||
Recv() (*Chunk, error)
|
||||
grpc.ClientStream
|
||||
}
|
||||
|
||||
type gostTunelTunnelClient struct {
|
||||
grpc.ClientStream
|
||||
}
|
||||
|
||||
func (x *gostTunelTunnelClient) Send(m *Chunk) error {
|
||||
return x.ClientStream.SendMsg(m)
|
||||
}
|
||||
|
||||
func (x *gostTunelTunnelClient) Recv() (*Chunk, error) {
|
||||
m := new(Chunk)
|
||||
if err := x.ClientStream.RecvMsg(m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// GostTunelServer is the server API for GostTunel service.
|
||||
// All implementations must embed UnimplementedGostTunelServer
|
||||
// for forward compatibility
|
||||
type GostTunelServer interface {
|
||||
Tunnel(GostTunel_TunnelServer) error
|
||||
mustEmbedUnimplementedGostTunelServer()
|
||||
}
|
||||
|
||||
// UnimplementedGostTunelServer must be embedded to have forward compatible implementations.
|
||||
type UnimplementedGostTunelServer struct {
|
||||
}
|
||||
|
||||
func (UnimplementedGostTunelServer) Tunnel(GostTunel_TunnelServer) error {
|
||||
return status.Errorf(codes.Unimplemented, "method Tunnel not implemented")
|
||||
}
|
||||
func (UnimplementedGostTunelServer) mustEmbedUnimplementedGostTunelServer() {}
|
||||
|
||||
// UnsafeGostTunelServer may be embedded to opt out of forward compatibility for this service.
|
||||
// Use of this interface is not recommended, as added methods to GostTunelServer will
|
||||
// result in compilation errors.
|
||||
type UnsafeGostTunelServer interface {
|
||||
mustEmbedUnimplementedGostTunelServer()
|
||||
}
|
||||
|
||||
func RegisterGostTunelServer(s grpc.ServiceRegistrar, srv GostTunelServer) {
|
||||
s.RegisterService(&GostTunel_ServiceDesc, srv)
|
||||
}
|
||||
|
||||
func _GostTunel_Tunnel_Handler(srv interface{}, stream grpc.ServerStream) error {
|
||||
return srv.(GostTunelServer).Tunnel(&gostTunelTunnelServer{stream})
|
||||
}
|
||||
|
||||
type GostTunel_TunnelServer interface {
|
||||
Send(*Chunk) error
|
||||
Recv() (*Chunk, error)
|
||||
grpc.ServerStream
|
||||
}
|
||||
|
||||
type gostTunelTunnelServer struct {
|
||||
grpc.ServerStream
|
||||
}
|
||||
|
||||
func (x *gostTunelTunnelServer) Send(m *Chunk) error {
|
||||
return x.ServerStream.SendMsg(m)
|
||||
}
|
||||
|
||||
func (x *gostTunelTunnelServer) Recv() (*Chunk, error) {
|
||||
m := new(Chunk)
|
||||
if err := x.ServerStream.RecvMsg(m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// GostTunel_ServiceDesc is the grpc.ServiceDesc for GostTunel service.
|
||||
// It's only intended for direct use with grpc.RegisterService,
|
||||
// and not to be introspected or modified (even as a copy)
|
||||
var GostTunel_ServiceDesc = grpc.ServiceDesc{
|
||||
ServiceName: "GostTunel",
|
||||
HandlerType: (*GostTunelServer)(nil),
|
||||
Methods: []grpc.MethodDesc{},
|
||||
Streams: []grpc.StreamDesc{
|
||||
{
|
||||
StreamName: "Tunnel",
|
||||
Handler: _GostTunel_Tunnel_Handler,
|
||||
ServerStreams: true,
|
||||
ClientStreams: true,
|
||||
},
|
||||
},
|
||||
Metadata: "gost.proto",
|
||||
}
|
3
pkg/common/util/grpc/proto/protoc.sh
Executable file
3
pkg/common/util/grpc/proto/protoc.sh
Executable file
@ -0,0 +1,3 @@
|
||||
protoc --go_out=. --go_opt=paths=source_relative \
|
||||
--go-grpc_out=. --go-grpc_opt=paths=source_relative \
|
||||
gost.proto
|
92
pkg/dialer/grpc/conn.go
Normal file
92
pkg/dialer/grpc/conn.go
Normal file
@ -0,0 +1,92 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
pb "github.com/go-gost/gost/pkg/common/util/grpc/proto"
|
||||
)
|
||||
|
||||
type conn struct {
|
||||
c pb.GostTunel_TunnelClient
|
||||
rb []byte
|
||||
localAddr net.Addr
|
||||
remoteAddr net.Addr
|
||||
closed chan struct{}
|
||||
}
|
||||
|
||||
func (c *conn) Read(b []byte) (n int, err error) {
|
||||
select {
|
||||
case <-c.c.Context().Done():
|
||||
err = c.c.Context().Err()
|
||||
return
|
||||
case <-c.closed:
|
||||
err = io.ErrClosedPipe
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
if len(c.rb) == 0 {
|
||||
chunk, err := c.c.Recv()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
c.rb = chunk.Data
|
||||
}
|
||||
|
||||
n = copy(b, c.rb)
|
||||
c.rb = c.rb[n:]
|
||||
return
|
||||
}
|
||||
|
||||
func (c *conn) Write(b []byte) (n int, err error) {
|
||||
select {
|
||||
case <-c.c.Context().Done():
|
||||
err = c.c.Context().Err()
|
||||
return
|
||||
case <-c.closed:
|
||||
err = io.ErrClosedPipe
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
if err = c.c.Send(&pb.Chunk{
|
||||
Data: b,
|
||||
}); err != nil {
|
||||
return
|
||||
}
|
||||
n = len(b)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *conn) Close() error {
|
||||
select {
|
||||
case <-c.closed:
|
||||
default:
|
||||
close(c.closed)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *conn) LocalAddr() net.Addr {
|
||||
return c.localAddr
|
||||
}
|
||||
|
||||
func (c *conn) RemoteAddr() net.Addr {
|
||||
return c.remoteAddr
|
||||
}
|
||||
|
||||
func (c *conn) SetDeadline(t time.Time) error {
|
||||
return &net.OpError{Op: "set", Net: "grpc", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
|
||||
}
|
||||
|
||||
func (c *conn) SetReadDeadline(t time.Time) error {
|
||||
return &net.OpError{Op: "set", Net: "grpc", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
|
||||
}
|
||||
|
||||
func (c *conn) SetWriteDeadline(t time.Time) error {
|
||||
return &net.OpError{Op: "set", Net: "grpc", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
|
||||
}
|
140
pkg/dialer/grpc/dialer.go
Normal file
140
pkg/dialer/grpc/dialer.go
Normal file
@ -0,0 +1,140 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
pb "github.com/go-gost/gost/pkg/common/util/grpc/proto"
|
||||
"github.com/go-gost/gost/pkg/dialer"
|
||||
"github.com/go-gost/gost/pkg/logger"
|
||||
md "github.com/go-gost/gost/pkg/metadata"
|
||||
"github.com/go-gost/gost/pkg/registry"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/backoff"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
)
|
||||
|
||||
func init() {
|
||||
registry.RegisterDialer("grpc", NewDialer)
|
||||
}
|
||||
|
||||
type grpcDialer struct {
|
||||
clients map[string]pb.GostTunelClient
|
||||
clientMutex sync.Mutex
|
||||
logger logger.Logger
|
||||
md metadata
|
||||
options dialer.Options
|
||||
}
|
||||
|
||||
func NewDialer(opts ...dialer.Option) dialer.Dialer {
|
||||
options := dialer.Options{}
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
|
||||
return &grpcDialer{
|
||||
clients: make(map[string]pb.GostTunelClient),
|
||||
logger: options.Logger,
|
||||
options: options,
|
||||
}
|
||||
}
|
||||
|
||||
func (d *grpcDialer) Init(md md.Metadata) (err error) {
|
||||
return d.parseMetadata(md)
|
||||
}
|
||||
|
||||
// Multiplex implements dialer.Multiplexer interface.
|
||||
func (d *grpcDialer) Multiplex() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (d *grpcDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOption) (net.Conn, error) {
|
||||
remoteAddr, err := net.ResolveTCPAddr("tcp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
d.clientMutex.Lock()
|
||||
defer d.clientMutex.Unlock()
|
||||
|
||||
client, ok := d.clients[addr]
|
||||
if !ok {
|
||||
var options dialer.DialOptions
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
|
||||
host := d.md.host
|
||||
if host == "" {
|
||||
host = options.Host
|
||||
}
|
||||
|
||||
grpcOpts := []grpc.DialOption{
|
||||
grpc.WithBlock(),
|
||||
grpc.WithContextDialer(func(c context.Context, s string) (net.Conn, error) {
|
||||
return d.dial(ctx, "tcp", s, &options)
|
||||
}),
|
||||
grpc.WithAuthority(host),
|
||||
grpc.WithConnectParams(grpc.ConnectParams{
|
||||
Backoff: backoff.DefaultConfig,
|
||||
MinConnectTimeout: 10 * time.Second,
|
||||
}),
|
||||
}
|
||||
if !d.md.insecure {
|
||||
grpcOpts = append(grpcOpts, grpc.WithTransportCredentials(credentials.NewTLS(d.options.TLSConfig)))
|
||||
} else {
|
||||
grpcOpts = append(grpcOpts, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
}
|
||||
|
||||
cc, err := grpc.DialContext(ctx, addr, grpcOpts...)
|
||||
if err != nil {
|
||||
d.logger.Error(err)
|
||||
return nil, err
|
||||
}
|
||||
client = pb.NewGostTunelClient(cc)
|
||||
d.clients[addr] = client
|
||||
}
|
||||
|
||||
cli, err := client.Tunnel(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &conn{
|
||||
c: cli,
|
||||
localAddr: &net.TCPAddr{},
|
||||
remoteAddr: remoteAddr,
|
||||
closed: make(chan struct{}),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (d *grpcDialer) dial(ctx context.Context, network, addr string, opts *dialer.DialOptions) (net.Conn, error) {
|
||||
dial := opts.DialFunc
|
||||
if dial != nil {
|
||||
conn, err := dial(ctx, addr)
|
||||
if err != nil {
|
||||
d.logger.Error(err)
|
||||
} else {
|
||||
d.logger.WithFields(map[string]interface{}{
|
||||
"src": conn.LocalAddr().String(),
|
||||
"dst": addr,
|
||||
}).Debug("dial with dial func")
|
||||
}
|
||||
return conn, err
|
||||
}
|
||||
|
||||
var netd net.Dialer
|
||||
conn, err := netd.DialContext(ctx, network, addr)
|
||||
if err != nil {
|
||||
d.logger.Error(err)
|
||||
} else {
|
||||
d.logger.WithFields(map[string]interface{}{
|
||||
"src": conn.LocalAddr().String(),
|
||||
"dst": addr,
|
||||
}).Debugf("dial direct %s/%s", addr, network)
|
||||
}
|
||||
return conn, err
|
||||
}
|
22
pkg/dialer/grpc/metadata.go
Normal file
22
pkg/dialer/grpc/metadata.go
Normal file
@ -0,0 +1,22 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
mdata "github.com/go-gost/gost/pkg/metadata"
|
||||
)
|
||||
|
||||
type metadata struct {
|
||||
insecure bool
|
||||
host string
|
||||
}
|
||||
|
||||
func (d *grpcDialer) parseMetadata(md mdata.Metadata) (err error) {
|
||||
const (
|
||||
insecure = "grpcInsecure"
|
||||
host = "host"
|
||||
)
|
||||
|
||||
d.md.insecure = mdata.GetBool(md, insecure)
|
||||
d.md.host = mdata.GetString(md, host)
|
||||
|
||||
return
|
||||
}
|
@ -53,11 +53,6 @@ func (d *http2Dialer) Multiplex() bool {
|
||||
}
|
||||
|
||||
func (d *http2Dialer) Dial(ctx context.Context, address string, opts ...dialer.DialOption) (net.Conn, error) {
|
||||
options := &dialer.DialOptions{}
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
|
||||
raddr, err := net.ResolveTCPAddr("tcp", address)
|
||||
if err != nil {
|
||||
d.logger.Error(err)
|
||||
@ -69,11 +64,16 @@ func (d *http2Dialer) Dial(ctx context.Context, address string, opts ...dialer.D
|
||||
|
||||
client, ok := d.clients[address]
|
||||
if !ok {
|
||||
options := dialer.DialOptions{}
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
|
||||
client = &http.Client{
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: d.options.TLSConfig,
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return d.dial(ctx, network, addr, options)
|
||||
return d.dial(ctx, network, addr, &options)
|
||||
},
|
||||
ForceAttemptHTTP2: true,
|
||||
MaxIdleConns: 100,
|
||||
|
@ -36,11 +36,18 @@ func LoggerOption(logger logger.Logger) Option {
|
||||
}
|
||||
|
||||
type DialOptions struct {
|
||||
Host string
|
||||
DialFunc func(ctx context.Context, addr string) (net.Conn, error)
|
||||
}
|
||||
|
||||
type DialOption func(opts *DialOptions)
|
||||
|
||||
func HostDialOption(host string) DialOption {
|
||||
return func(opts *DialOptions) {
|
||||
opts.Host = host
|
||||
}
|
||||
}
|
||||
|
||||
func DialFuncDialOption(dialf func(ctx context.Context, addr string) (net.Conn, error)) DialOption {
|
||||
return func(opts *DialOptions) {
|
||||
opts.DialFunc = dialf
|
||||
|
101
pkg/listener/grpc/listener.go
Normal file
101
pkg/listener/grpc/listener.go
Normal file
@ -0,0 +1,101 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
pb "github.com/go-gost/gost/pkg/common/util/grpc/proto"
|
||||
"github.com/go-gost/gost/pkg/listener"
|
||||
"github.com/go-gost/gost/pkg/logger"
|
||||
md "github.com/go-gost/gost/pkg/metadata"
|
||||
"github.com/go-gost/gost/pkg/registry"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
)
|
||||
|
||||
func init() {
|
||||
registry.RegisterListener("grpc", NewListener)
|
||||
}
|
||||
|
||||
type grpcListener struct {
|
||||
addr net.Addr
|
||||
server *grpc.Server
|
||||
cqueue chan net.Conn
|
||||
errChan chan error
|
||||
md metadata
|
||||
logger logger.Logger
|
||||
options listener.Options
|
||||
}
|
||||
|
||||
func NewListener(opts ...listener.Option) listener.Listener {
|
||||
options := listener.Options{}
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
return &grpcListener{
|
||||
logger: options.Logger,
|
||||
options: options,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *grpcListener) Init(md md.Metadata) (err error) {
|
||||
if err = l.parseMetadata(md); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
laddr, err := net.ResolveTCPAddr("tcp", l.options.Addr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ln, err := net.ListenTCP("tcp", laddr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var opts []grpc.ServerOption
|
||||
if !l.md.insecure {
|
||||
opts = append(opts, grpc.Creds(credentials.NewTLS(l.options.TLSConfig)))
|
||||
}
|
||||
|
||||
l.server = grpc.NewServer(opts...)
|
||||
l.addr = ln.Addr()
|
||||
l.cqueue = make(chan net.Conn, l.md.backlog)
|
||||
l.errChan = make(chan error, 1)
|
||||
|
||||
pb.RegisterGostTunelServer(l.server, &server{
|
||||
cqueue: l.cqueue,
|
||||
localAddr: l.addr,
|
||||
logger: l.options.Logger,
|
||||
})
|
||||
|
||||
go func() {
|
||||
err := l.server.Serve(ln)
|
||||
if err != nil {
|
||||
l.errChan <- err
|
||||
}
|
||||
close(l.errChan)
|
||||
}()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (l *grpcListener) Accept() (conn net.Conn, err error) {
|
||||
var ok bool
|
||||
select {
|
||||
case conn = <-l.cqueue:
|
||||
case err, ok = <-l.errChan:
|
||||
if !ok {
|
||||
err = listener.ErrClosed
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (l *grpcListener) Close() error {
|
||||
l.server.Stop()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *grpcListener) Addr() net.Addr {
|
||||
return l.addr
|
||||
}
|
29
pkg/listener/grpc/metadata.go
Normal file
29
pkg/listener/grpc/metadata.go
Normal file
@ -0,0 +1,29 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
mdata "github.com/go-gost/gost/pkg/metadata"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultBacklog = 128
|
||||
)
|
||||
|
||||
type metadata struct {
|
||||
backlog int
|
||||
insecure bool
|
||||
}
|
||||
|
||||
func (l *grpcListener) parseMetadata(md mdata.Metadata) (err error) {
|
||||
const (
|
||||
backlog = "backlog"
|
||||
insecure = "grpcInsecure"
|
||||
)
|
||||
|
||||
l.md.backlog = mdata.GetInt(md, backlog)
|
||||
if l.md.backlog <= 0 {
|
||||
l.md.backlog = defaultBacklog
|
||||
}
|
||||
|
||||
l.md.insecure = mdata.GetBool(md, insecure)
|
||||
return
|
||||
}
|
120
pkg/listener/grpc/server.go
Normal file
120
pkg/listener/grpc/server.go
Normal file
@ -0,0 +1,120 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
pb "github.com/go-gost/gost/pkg/common/util/grpc/proto"
|
||||
"github.com/go-gost/gost/pkg/logger"
|
||||
)
|
||||
|
||||
type server struct {
|
||||
cqueue chan net.Conn
|
||||
localAddr net.Addr
|
||||
pb.UnimplementedGostTunelServer
|
||||
logger logger.Logger
|
||||
}
|
||||
|
||||
func (s *server) Tunnel(srv pb.GostTunel_TunnelServer) error {
|
||||
c := &conn{
|
||||
s: srv,
|
||||
localAddr: s.localAddr,
|
||||
remoteAddr: &net.TCPAddr{},
|
||||
closed: make(chan struct{}),
|
||||
}
|
||||
|
||||
select {
|
||||
case s.cqueue <- c:
|
||||
default:
|
||||
c.Close()
|
||||
s.logger.Warnf("connection queue is full, client discarded")
|
||||
}
|
||||
|
||||
<-c.closed
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type conn struct {
|
||||
s pb.GostTunel_TunnelServer
|
||||
rb []byte
|
||||
localAddr net.Addr
|
||||
remoteAddr net.Addr
|
||||
closed chan struct{}
|
||||
}
|
||||
|
||||
func (c *conn) Read(b []byte) (n int, err error) {
|
||||
select {
|
||||
case <-c.s.Context().Done():
|
||||
err = c.s.Context().Err()
|
||||
return
|
||||
case <-c.closed:
|
||||
err = io.ErrClosedPipe
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
if len(c.rb) == 0 {
|
||||
chunk, err := c.s.Recv()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
c.rb = chunk.Data
|
||||
}
|
||||
|
||||
n = copy(b, c.rb)
|
||||
c.rb = c.rb[n:]
|
||||
return
|
||||
}
|
||||
|
||||
func (c *conn) Write(b []byte) (n int, err error) {
|
||||
select {
|
||||
case <-c.s.Context().Done():
|
||||
err = c.s.Context().Err()
|
||||
return
|
||||
case <-c.closed:
|
||||
err = io.ErrClosedPipe
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
if err = c.s.Send(&pb.Chunk{
|
||||
Data: b,
|
||||
}); err != nil {
|
||||
return
|
||||
}
|
||||
n = len(b)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *conn) Close() error {
|
||||
select {
|
||||
case <-c.closed:
|
||||
default:
|
||||
close(c.closed)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *conn) LocalAddr() net.Addr {
|
||||
return c.localAddr
|
||||
}
|
||||
|
||||
func (c *conn) RemoteAddr() net.Addr {
|
||||
return c.remoteAddr
|
||||
}
|
||||
|
||||
func (c *conn) SetDeadline(t time.Time) error {
|
||||
return &net.OpError{Op: "set", Net: "grpc", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
|
||||
}
|
||||
|
||||
func (c *conn) SetReadDeadline(t time.Time) error {
|
||||
return &net.OpError{Op: "set", Net: "grpc", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
|
||||
}
|
||||
|
||||
func (c *conn) SetWriteDeadline(t time.Time) error {
|
||||
return &net.OpError{Op: "set", Net: "grpc", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
|
||||
}
|
@ -14,20 +14,20 @@ func init() {
|
||||
}
|
||||
|
||||
type tcpListener struct {
|
||||
addr string
|
||||
md metadata
|
||||
net.Listener
|
||||
logger logger.Logger
|
||||
logger logger.Logger
|
||||
md metadata
|
||||
options listener.Options
|
||||
}
|
||||
|
||||
func NewListener(opts ...listener.Option) listener.Listener {
|
||||
options := &listener.Options{}
|
||||
options := listener.Options{}
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
opt(&options)
|
||||
}
|
||||
return &tcpListener{
|
||||
addr: options.Addr,
|
||||
logger: options.Logger,
|
||||
logger: options.Logger,
|
||||
options: options,
|
||||
}
|
||||
}
|
||||
|
||||
@ -36,7 +36,7 @@ func (l *tcpListener) Init(md md.Metadata) (err error) {
|
||||
return
|
||||
}
|
||||
|
||||
laddr, err := net.ResolveTCPAddr("tcp", l.addr)
|
||||
laddr, err := net.ResolveTCPAddr("tcp", l.options.Addr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
Reference in New Issue
Block a user