package grpc import ( "context" "net" "sync" "time" pb "github.com/go-gost/gost/pkg/common/util/grpc/proto" "github.com/go-gost/gost/pkg/dialer" "github.com/go-gost/gost/pkg/logger" md "github.com/go-gost/gost/pkg/metadata" "github.com/go-gost/gost/pkg/registry" "google.golang.org/grpc" "google.golang.org/grpc/backoff" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" ) func init() { registry.RegisterDialer("grpc", NewDialer) } type grpcDialer struct { clients map[string]pb.GostTunelClient clientMutex sync.Mutex logger logger.Logger md metadata options dialer.Options } func NewDialer(opts ...dialer.Option) dialer.Dialer { options := dialer.Options{} for _, opt := range opts { opt(&options) } return &grpcDialer{ clients: make(map[string]pb.GostTunelClient), logger: options.Logger, options: options, } } func (d *grpcDialer) Init(md md.Metadata) (err error) { return d.parseMetadata(md) } // Multiplex implements dialer.Multiplexer interface. func (d *grpcDialer) Multiplex() bool { return true } func (d *grpcDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOption) (net.Conn, error) { remoteAddr, err := net.ResolveTCPAddr("tcp", addr) if err != nil { return nil, err } d.clientMutex.Lock() defer d.clientMutex.Unlock() client, ok := d.clients[addr] if !ok { var options dialer.DialOptions for _, opt := range opts { opt(&options) } host := d.md.host if host == "" { host = options.Host } grpcOpts := []grpc.DialOption{ grpc.WithBlock(), grpc.WithContextDialer(func(c context.Context, s string) (net.Conn, error) { return d.dial(ctx, "tcp", s, &options) }), grpc.WithAuthority(host), grpc.WithConnectParams(grpc.ConnectParams{ Backoff: backoff.DefaultConfig, MinConnectTimeout: 10 * time.Second, }), } if !d.md.insecure { grpcOpts = append(grpcOpts, grpc.WithTransportCredentials(credentials.NewTLS(d.options.TLSConfig))) } else { grpcOpts = append(grpcOpts, grpc.WithTransportCredentials(insecure.NewCredentials())) } cc, err := grpc.DialContext(ctx, addr, grpcOpts...) if err != nil { d.logger.Error(err) return nil, err } client = pb.NewGostTunelClient(cc) d.clients[addr] = client } cli, err := client.Tunnel(ctx) if err != nil { return nil, err } return &conn{ c: cli, localAddr: &net.TCPAddr{}, remoteAddr: remoteAddr, closed: make(chan struct{}), }, nil } func (d *grpcDialer) dial(ctx context.Context, network, addr string, opts *dialer.DialOptions) (net.Conn, error) { dial := opts.DialFunc if dial != nil { conn, err := dial(ctx, addr) if err != nil { d.logger.Error(err) } else { d.logger.WithFields(map[string]interface{}{ "src": conn.LocalAddr().String(), "dst": addr, }).Debug("dial with dial func") } return conn, err } var netd net.Dialer conn, err := netd.DialContext(ctx, network, addr) if err != nil { d.logger.Error(err) } else { d.logger.WithFields(map[string]interface{}{ "src": conn.LocalAddr().String(), "dst": addr, }).Debugf("dial direct %s/%s", addr, network) } return conn, err }