diff --git a/cmd/gost/cmd.go b/cmd/gost/cmd.go index 901c933..f8b80bf 100644 --- a/cmd/gost/cmd.go +++ b/cmd/gost/cmd.go @@ -34,8 +34,15 @@ func buildConfigFromCmd(services, nodes stringList) (*config.Config, error) { if v := os.Getenv("GOST_PROFILING"); v != "" { cfg.Profiling = &config.ProfilingConfig{ - Addr: v, - Enabled: true, + Addr: v, + Enable: true, + } + } + if v := os.Getenv("GOST_METRICS"); v != "" { + cfg.Metrics = &config.MetricsConfig{ + Addr: v, + Path: "/metrics", + Enable: true, } } diff --git a/cmd/gost/main.go b/cmd/gost/main.go index 2573d32..e7b56b6 100644 --- a/cmd/gost/main.go +++ b/cmd/gost/main.go @@ -80,7 +80,7 @@ func main() { os.Exit(0) } - if cfg.Profiling != nil && cfg.Profiling.Enabled { + if cfg.Profiling != nil && cfg.Profiling.Enable { go func() { addr := cfg.Profiling.Addr if addr == "" { diff --git a/cmd/gost/tls.go b/cmd/gost/tls.go index 3f83cb8..55721df 100644 --- a/cmd/gost/tls.go +++ b/cmd/gost/tls.go @@ -72,12 +72,15 @@ func generateKeyPair() (rawCert, rawKey []byte, err error) { notBefore := time.Now() notAfter := notBefore.Add(validFor) serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) - serialNumber, _ := rand.Int(rand.Reader, serialNumberLimit) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + if err != nil { + return + } + template := x509.Certificate{ SerialNumber: serialNumber, Subject: pkix.Name{ Organization: []string{"gost"}, - CommonName: "gost.run", }, NotBefore: notBefore, NotAfter: notAfter, @@ -86,13 +89,19 @@ func generateKeyPair() (rawCert, rawKey []byte, err error) { ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, BasicConstraintsValid: true, } + + template.DNSNames = append(template.DNSNames, "gost.run") derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) if err != nil { return } rawCert = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) - rawKey = pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}) + privBytes, err := x509.MarshalPKCS8PrivateKey(priv) + if err != nil { + return + } + rawKey = pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privBytes}) return } diff --git a/pkg/config/config.go b/pkg/config/config.go index 41f4f32..bb83855 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -49,8 +49,8 @@ type LogConfig struct { } type ProfilingConfig struct { - Addr string `json:"addr"` - Enabled bool `json:"enabled"` + Addr string `json:"addr"` + Enable bool `json:"enable"` } type APIConfig struct { @@ -61,6 +61,12 @@ type APIConfig struct { Auther string `yaml:",omitempty" json:"auther,omitempty"` } +type MetricsConfig struct { + Enable bool `json:"enable"` + Addr string `json:"addr"` + Path string `json:"path"` +} + type TLSConfig struct { CertFile string `yaml:"certFile,omitempty" json:"certFile,omitempty"` KeyFile string `yaml:"keyFile,omitempty" json:"keyFile,omitempty"` @@ -211,6 +217,7 @@ type Config struct { Log *LogConfig `yaml:",omitempty" json:"log,omitempty"` Profiling *ProfilingConfig `yaml:",omitempty" json:"profiling,omitempty"` API *APIConfig `yaml:",omitempty" json:"api,omitempty"` + Metrics *MetricsConfig `yaml:",omitempty" json:"metrics,omitempty"` } func (c *Config) Load() error { diff --git a/pkg/dialer/grpc/dialer.go b/pkg/dialer/grpc/dialer.go index 4b46583..c8bdd50 100644 --- a/pkg/dialer/grpc/dialer.go +++ b/pkg/dialer/grpc/dialer.go @@ -8,7 +8,6 @@ import ( pb "github.com/go-gost/gost/pkg/common/util/grpc/proto" "github.com/go-gost/gost/pkg/dialer" - "github.com/go-gost/gost/pkg/logger" md "github.com/go-gost/gost/pkg/metadata" "github.com/go-gost/gost/pkg/registry" "google.golang.org/grpc" @@ -24,7 +23,6 @@ func init() { type grpcDialer struct { clients map[string]pb.GostTunelClient clientMutex sync.Mutex - logger logger.Logger md metadata options dialer.Options } @@ -37,7 +35,6 @@ func NewDialer(opts ...dialer.Option) dialer.Dialer { return &grpcDialer{ clients: make(map[string]pb.GostTunelClient), - logger: options.Logger, options: options, } } @@ -71,9 +68,12 @@ func (d *grpcDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialO if host == "" { host = options.Host } + if h, _, _ := net.SplitHostPort(host); h != "" { + host = h + } grpcOpts := []grpc.DialOption{ - grpc.WithBlock(), + // grpc.WithBlock(), grpc.WithContextDialer(func(c context.Context, s string) (net.Conn, error) { return d.dial(ctx, "tcp", s, &options) }), @@ -82,6 +82,7 @@ func (d *grpcDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialO Backoff: backoff.DefaultConfig, MinConnectTimeout: 10 * time.Second, }), + grpc.FailOnNonTempDialError(true), } if !d.md.insecure { grpcOpts = append(grpcOpts, grpc.WithTransportCredentials(credentials.NewTLS(d.options.TLSConfig))) @@ -91,7 +92,7 @@ func (d *grpcDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialO cc, err := grpc.DialContext(ctx, addr, grpcOpts...) if err != nil { - d.logger.Error(err) + d.options.Logger.Error(err) return nil, err } client = pb.NewGostTunelClient(cc) @@ -116,9 +117,9 @@ func (d *grpcDialer) dial(ctx context.Context, network, addr string, opts *diale if dial != nil { conn, err := dial(ctx, addr) if err != nil { - d.logger.Error(err) + d.options.Logger.Error(err) } else { - d.logger.WithFields(map[string]interface{}{ + d.options.Logger.WithFields(map[string]interface{}{ "src": conn.LocalAddr().String(), "dst": addr, }).Debug("dial with dial func") @@ -129,9 +130,9 @@ func (d *grpcDialer) dial(ctx context.Context, network, addr string, opts *diale var netd net.Dialer conn, err := netd.DialContext(ctx, network, addr) if err != nil { - d.logger.Error(err) + d.options.Logger.Error(err) } else { - d.logger.WithFields(map[string]interface{}{ + d.options.Logger.WithFields(map[string]interface{}{ "src": conn.LocalAddr().String(), "dst": addr, }).Debugf("dial direct %s/%s", addr, network)