fix issue#10

This commit is contained in:
ginuerzh
2022-02-13 20:40:37 +08:00
parent a8804ea02d
commit edca3e0a55
5 changed files with 41 additions and 17 deletions

View File

@ -34,8 +34,15 @@ func buildConfigFromCmd(services, nodes stringList) (*config.Config, error) {
if v := os.Getenv("GOST_PROFILING"); v != "" { if v := os.Getenv("GOST_PROFILING"); v != "" {
cfg.Profiling = &config.ProfilingConfig{ cfg.Profiling = &config.ProfilingConfig{
Addr: v, Addr: v,
Enabled: true, Enable: true,
}
}
if v := os.Getenv("GOST_METRICS"); v != "" {
cfg.Metrics = &config.MetricsConfig{
Addr: v,
Path: "/metrics",
Enable: true,
} }
} }

View File

@ -80,7 +80,7 @@ func main() {
os.Exit(0) os.Exit(0)
} }
if cfg.Profiling != nil && cfg.Profiling.Enabled { if cfg.Profiling != nil && cfg.Profiling.Enable {
go func() { go func() {
addr := cfg.Profiling.Addr addr := cfg.Profiling.Addr
if addr == "" { if addr == "" {

View File

@ -72,12 +72,15 @@ func generateKeyPair() (rawCert, rawKey []byte, err error) {
notBefore := time.Now() notBefore := time.Now()
notAfter := notBefore.Add(validFor) notAfter := notBefore.Add(validFor)
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, _ := rand.Int(rand.Reader, serialNumberLimit) serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
if err != nil {
return
}
template := x509.Certificate{ template := x509.Certificate{
SerialNumber: serialNumber, SerialNumber: serialNumber,
Subject: pkix.Name{ Subject: pkix.Name{
Organization: []string{"gost"}, Organization: []string{"gost"},
CommonName: "gost.run",
}, },
NotBefore: notBefore, NotBefore: notBefore,
NotAfter: notAfter, NotAfter: notAfter,
@ -86,13 +89,19 @@ func generateKeyPair() (rawCert, rawKey []byte, err error) {
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true, BasicConstraintsValid: true,
} }
template.DNSNames = append(template.DNSNames, "gost.run")
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
if err != nil { if err != nil {
return return
} }
rawCert = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) rawCert = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
rawKey = pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}) privBytes, err := x509.MarshalPKCS8PrivateKey(priv)
if err != nil {
return
}
rawKey = pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privBytes})
return return
} }

View File

@ -49,8 +49,8 @@ type LogConfig struct {
} }
type ProfilingConfig struct { type ProfilingConfig struct {
Addr string `json:"addr"` Addr string `json:"addr"`
Enabled bool `json:"enabled"` Enable bool `json:"enable"`
} }
type APIConfig struct { type APIConfig struct {
@ -61,6 +61,12 @@ type APIConfig struct {
Auther string `yaml:",omitempty" json:"auther,omitempty"` Auther string `yaml:",omitempty" json:"auther,omitempty"`
} }
type MetricsConfig struct {
Enable bool `json:"enable"`
Addr string `json:"addr"`
Path string `json:"path"`
}
type TLSConfig struct { type TLSConfig struct {
CertFile string `yaml:"certFile,omitempty" json:"certFile,omitempty"` CertFile string `yaml:"certFile,omitempty" json:"certFile,omitempty"`
KeyFile string `yaml:"keyFile,omitempty" json:"keyFile,omitempty"` KeyFile string `yaml:"keyFile,omitempty" json:"keyFile,omitempty"`
@ -211,6 +217,7 @@ type Config struct {
Log *LogConfig `yaml:",omitempty" json:"log,omitempty"` Log *LogConfig `yaml:",omitempty" json:"log,omitempty"`
Profiling *ProfilingConfig `yaml:",omitempty" json:"profiling,omitempty"` Profiling *ProfilingConfig `yaml:",omitempty" json:"profiling,omitempty"`
API *APIConfig `yaml:",omitempty" json:"api,omitempty"` API *APIConfig `yaml:",omitempty" json:"api,omitempty"`
Metrics *MetricsConfig `yaml:",omitempty" json:"metrics,omitempty"`
} }
func (c *Config) Load() error { func (c *Config) Load() error {

View File

@ -8,7 +8,6 @@ import (
pb "github.com/go-gost/gost/pkg/common/util/grpc/proto" pb "github.com/go-gost/gost/pkg/common/util/grpc/proto"
"github.com/go-gost/gost/pkg/dialer" "github.com/go-gost/gost/pkg/dialer"
"github.com/go-gost/gost/pkg/logger"
md "github.com/go-gost/gost/pkg/metadata" md "github.com/go-gost/gost/pkg/metadata"
"github.com/go-gost/gost/pkg/registry" "github.com/go-gost/gost/pkg/registry"
"google.golang.org/grpc" "google.golang.org/grpc"
@ -24,7 +23,6 @@ func init() {
type grpcDialer struct { type grpcDialer struct {
clients map[string]pb.GostTunelClient clients map[string]pb.GostTunelClient
clientMutex sync.Mutex clientMutex sync.Mutex
logger logger.Logger
md metadata md metadata
options dialer.Options options dialer.Options
} }
@ -37,7 +35,6 @@ func NewDialer(opts ...dialer.Option) dialer.Dialer {
return &grpcDialer{ return &grpcDialer{
clients: make(map[string]pb.GostTunelClient), clients: make(map[string]pb.GostTunelClient),
logger: options.Logger,
options: options, options: options,
} }
} }
@ -71,9 +68,12 @@ func (d *grpcDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialO
if host == "" { if host == "" {
host = options.Host host = options.Host
} }
if h, _, _ := net.SplitHostPort(host); h != "" {
host = h
}
grpcOpts := []grpc.DialOption{ grpcOpts := []grpc.DialOption{
grpc.WithBlock(), // grpc.WithBlock(),
grpc.WithContextDialer(func(c context.Context, s string) (net.Conn, error) { grpc.WithContextDialer(func(c context.Context, s string) (net.Conn, error) {
return d.dial(ctx, "tcp", s, &options) return d.dial(ctx, "tcp", s, &options)
}), }),
@ -82,6 +82,7 @@ func (d *grpcDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialO
Backoff: backoff.DefaultConfig, Backoff: backoff.DefaultConfig,
MinConnectTimeout: 10 * time.Second, MinConnectTimeout: 10 * time.Second,
}), }),
grpc.FailOnNonTempDialError(true),
} }
if !d.md.insecure { if !d.md.insecure {
grpcOpts = append(grpcOpts, grpc.WithTransportCredentials(credentials.NewTLS(d.options.TLSConfig))) grpcOpts = append(grpcOpts, grpc.WithTransportCredentials(credentials.NewTLS(d.options.TLSConfig)))
@ -91,7 +92,7 @@ func (d *grpcDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialO
cc, err := grpc.DialContext(ctx, addr, grpcOpts...) cc, err := grpc.DialContext(ctx, addr, grpcOpts...)
if err != nil { if err != nil {
d.logger.Error(err) d.options.Logger.Error(err)
return nil, err return nil, err
} }
client = pb.NewGostTunelClient(cc) client = pb.NewGostTunelClient(cc)
@ -116,9 +117,9 @@ func (d *grpcDialer) dial(ctx context.Context, network, addr string, opts *diale
if dial != nil { if dial != nil {
conn, err := dial(ctx, addr) conn, err := dial(ctx, addr)
if err != nil { if err != nil {
d.logger.Error(err) d.options.Logger.Error(err)
} else { } else {
d.logger.WithFields(map[string]interface{}{ d.options.Logger.WithFields(map[string]interface{}{
"src": conn.LocalAddr().String(), "src": conn.LocalAddr().String(),
"dst": addr, "dst": addr,
}).Debug("dial with dial func") }).Debug("dial with dial func")
@ -129,9 +130,9 @@ func (d *grpcDialer) dial(ctx context.Context, network, addr string, opts *diale
var netd net.Dialer var netd net.Dialer
conn, err := netd.DialContext(ctx, network, addr) conn, err := netd.DialContext(ctx, network, addr)
if err != nil { if err != nil {
d.logger.Error(err) d.options.Logger.Error(err)
} else { } else {
d.logger.WithFields(map[string]interface{}{ d.options.Logger.WithFields(map[string]interface{}{
"src": conn.LocalAddr().String(), "src": conn.LocalAddr().String(),
"dst": addr, "dst": addr,
}).Debugf("dial direct %s/%s", addr, network) }).Debugf("dial direct %s/%s", addr, network)