diff --git a/cmd/gost/main.go b/cmd/gost/main.go index b092a25..f863aed 100644 --- a/cmd/gost/main.go +++ b/cmd/gost/main.go @@ -1,6 +1,7 @@ package main import ( + "crypto/tls" "flag" "fmt" "net/http" @@ -8,6 +9,7 @@ import ( "os" "runtime" + tls_util "github.com/go-gost/gost/pkg/common/util/tls" "github.com/go-gost/gost/pkg/config" "github.com/go-gost/gost/pkg/logger" ) @@ -64,6 +66,8 @@ func main() { normConfig(cfg) + log = logFromConfig(cfg.Log) + if outputCfgFile != "" { if err := cfg.WriteFile(outputCfgFile); err != nil { log.Fatal(err) @@ -71,8 +75,6 @@ func main() { os.Exit(0) } - log = logFromConfig(cfg.Log) - if cfg.Profiling != nil && cfg.Profiling.Enabled { go func() { addr := cfg.Profiling.Addr @@ -83,6 +85,31 @@ func main() { log.Fatal(http.ListenAndServe(addr, nil)) }() } + + tlsCfg := cfg.TLS + if tlsCfg == nil { + tlsCfg = &config.TLSConfig{ + Cert: "cert.pem", + Key: "key.pem", + CA: "ca.crt", + } + } + tlsConfig, err := tls_util.LoadTLSConfig(tlsCfg.Cert, tlsCfg.Key, tlsCfg.CA) + if err != nil { + // generate random self-signed certificate. + cert, err := tls_util.GenCertificate() + if err != nil { + log.Fatal(err) + } + tlsConfig = &tls.Config{ + Certificates: []tls.Certificate{cert}, + } + log.Warn("load TLS certificate files failed, use random generated certificate") + } else { + log.Debug("load TLS certificate files OK") + } + tls_util.DefaultConfig = tlsConfig + services := buildService(cfg) for _, svc := range services { go svc.Run() diff --git a/cmd/gost/out.yml b/cmd/gost/out.yml index bda9cb0..2a46c88 100644 --- a/cmd/gost/out.yml +++ b/cmd/gost/out.yml @@ -1,40 +1,27 @@ +log: + level: debug services: - name: service-0 - url: ss://abc:123@:18338/:18338 - addr: :18338 + url: udp://:10053/192.168.8.8:53,192.168.8.1:53 + addr: :10053 chain: chain-0 listener: - type: tcp - metadata: - users: - - abc:123 + type: udp handler: - type: tcp - metadata: - users: - - abc:123 + type: udp forwarder: targets: - - :18338 + - 192.168.8.8:53 + - 192.168.8.1:53 chains: - name: chain-0 hops: - name: hop-0 nodes: - name: node-0 - url: socks://abc:123@:11080?type=abc&key=value - addr: :11080 + url: relay://:8420 + addr: :8420 dialer: type: tcp - metadata: - key: value - type: abc - user: - - abc:123 connector: - type: socks - metadata: - key: value - type: abc - user: - - abc:123 + type: relay diff --git a/cmd/gost/register.go b/cmd/gost/register.go index 82f2022..a35a487 100644 --- a/cmd/gost/register.go +++ b/cmd/gost/register.go @@ -20,6 +20,7 @@ import ( _ "github.com/go-gost/gost/pkg/handler/forward/remote" _ "github.com/go-gost/gost/pkg/handler/http" _ "github.com/go-gost/gost/pkg/handler/relay" + _ "github.com/go-gost/gost/pkg/handler/sni" _ "github.com/go-gost/gost/pkg/handler/socks/v4" _ "github.com/go-gost/gost/pkg/handler/socks/v5" _ "github.com/go-gost/gost/pkg/handler/ss" diff --git a/go.mod b/go.mod index 956ef36..9425e0f 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/go-gost/gosocks4 v0.0.1 github.com/go-gost/gosocks5 v0.3.1-0.20211109033403-d894d75b7f09 github.com/go-gost/relay v0.1.1-0.20211123134818-8ef7fd81ffd7 + github.com/go-gost/tls-dissector v0.0.2-0.20211125135007-2b5d5bd9c07e github.com/gobwas/glob v0.2.3 github.com/golang/snappy v0.0.3 github.com/google/gopacket v1.1.19 // indirect diff --git a/go.sum b/go.sum index c3b2c9f..08621e8 100644 --- a/go.sum +++ b/go.sum @@ -113,10 +113,12 @@ github.com/go-gost/gosocks4 v0.0.1 h1:+k1sec8HlELuQV7rWftIkmy8UijzUt2I6t+iMPlGB2 github.com/go-gost/gosocks4 v0.0.1/go.mod h1:3B6L47HbU/qugDg4JnoFPHgJXE43Inz8Bah1QaN9qCc= github.com/go-gost/gosocks5 v0.3.1-0.20211109033403-d894d75b7f09 h1:A95M6UWcfZgOuJkQ7QLfG0Hs5peWIUSysCDNz4pfe04= github.com/go-gost/gosocks5 v0.3.1-0.20211109033403-d894d75b7f09/go.mod h1:1G6I7HP7VFVxveGkoK8mnprnJqSqJjdcASKsdUn4Pp4= -github.com/go-gost/relay v0.1.1-0.20211122150329-54ee406ea49d h1:rzGVzkSvxuDZg8PoYmOR+tvcAg9Dr8whgV19kzuO4YA= -github.com/go-gost/relay v0.1.1-0.20211122150329-54ee406ea49d/go.mod h1:lcX+23LCQ3khIeASBo+tJ/WbwXFO32/N5YN6ucuYTG8= github.com/go-gost/relay v0.1.1-0.20211123134818-8ef7fd81ffd7 h1:itaaJhQJ19kUXEB4Igb0EbY8m+1Py2AaNNSBds/9gk4= github.com/go-gost/relay v0.1.1-0.20211123134818-8ef7fd81ffd7/go.mod h1:lcX+23LCQ3khIeASBo+tJ/WbwXFO32/N5YN6ucuYTG8= +github.com/go-gost/tls-dissector v0.0.1 h1:cySZTSa7o5aOg/bqZXVFzi3NMudsiLwkzArFoxjUWCY= +github.com/go-gost/tls-dissector v0.0.1/go.mod h1:8CmRTbp7v4Ebd/lewu/Y/4dEJOP9ke6nwumyJ9WlOec= +github.com/go-gost/tls-dissector v0.0.2-0.20211125135007-2b5d5bd9c07e h1:73NGqAs22ey3wJkIYVD/ACEoovuIuOlEzQTEoqrO5+U= +github.com/go-gost/tls-dissector v0.0.2-0.20211125135007-2b5d5bd9c07e/go.mod h1:/9QfdewqmHdaE362Hv5nDaSWLx3pCmtD870d6GaquXs= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 h1:p104kn46Q8WdvHunIJ9dAyjPVtrBPhSr3KT2yUst43I= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= github.com/gobwas/glob v0.2.3 h1:A4xDbljILXROh+kObIiy5kIaPYD8e96x1tgBhUI5J+Y= diff --git a/pkg/common/util/tls/tls.go b/pkg/common/util/tls/tls.go index bbce7bd..3272958 100644 --- a/pkg/common/util/tls/tls.go +++ b/pkg/common/util/tls/tls.go @@ -1,10 +1,21 @@ package tls import ( + "crypto/rand" + "crypto/rsa" "crypto/tls" "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" "errors" "io/ioutil" + "math/big" + "time" +) + +var ( + // DefaultConfig is a default TLS config for global use. + DefaultConfig *tls.Config ) // LoadTLSConfig loads the certificate from cert & key files and optional client CA file. @@ -38,3 +49,48 @@ func loadCA(caFile string) (cp *x509.CertPool, err error) { } return } + +func GenCertificate() (cert tls.Certificate, err error) { + rawCert, rawKey, err := generateKeyPair() + if err != nil { + return + } + return tls.X509KeyPair(rawCert, rawKey) +} + +func generateKeyPair() (rawCert, rawKey []byte, err error) { + // Create private key and self-signed certificate + // Adapted from https://golang.org/src/crypto/tls/generate_cert.go + + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return + } + validFor := time.Hour * 24 * 365 * 10 // ten years + notBefore := time.Now() + notAfter := notBefore.Add(validFor) + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, _ := rand.Int(rand.Reader, serialNumberLimit) + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{"gost"}, + CommonName: "gost.run", + }, + NotBefore: notBefore, + NotAfter: notAfter, + + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) + if err != nil { + return + } + + rawCert = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) + rawKey = pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}) + + return +} diff --git a/pkg/config/config.go b/pkg/config/config.go index cd9a54d..d86c60f 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -31,6 +31,12 @@ type ProfilingConfig struct { Enabled bool } +type TLSConfig struct { + Cert string + Key string + CA string +} + type SelectorConfig struct { Strategy string MaxFails int @@ -59,12 +65,12 @@ type ForwarderConfig struct { type DialerConfig struct { Type string - Metadata map[string]interface{} + Metadata map[string]interface{} `yaml:",omitempty"` } type ConnectorConfig struct { Type string - Metadata map[string]interface{} + Metadata map[string]interface{} `yaml:",omitempty"` } type ServiceConfig struct { @@ -102,6 +108,7 @@ type NodeConfig struct { type Config struct { Log *LogConfig `yaml:",omitempty"` Profiling *ProfilingConfig `yaml:",omitempty"` + TLS *TLSConfig `yaml:",omitempty"` Services []*ServiceConfig Chains []*ChainConfig `yaml:",omitempty"` Bypasses []*BypassConfig `yaml:",omitempty"` diff --git a/pkg/handler/auto/handler.go b/pkg/handler/auto/handler.go index 731de23..e1de705 100644 --- a/pkg/handler/auto/handler.go +++ b/pkg/handler/auto/handler.go @@ -102,16 +102,16 @@ func (h *autoHandler) Handle(ctx context.Context, conn net.Conn) { return } - cc := handler.NewBufferReaderConn(conn, br) + conn = handler.NewBufferReaderConn(conn, br) switch b[0] { case gosocks4.Ver4: // socks4 - h.socks4Handler.Handle(ctx, cc) + h.socks4Handler.Handle(ctx, conn) case gosocks5.Ver5: // socks5 - h.socks5Handler.Handle(ctx, cc) + h.socks5Handler.Handle(ctx, conn) case relay.Version1: // relay - h.relayHandler.Handle(ctx, cc) + h.relayHandler.Handle(ctx, conn) default: // http - h.httpHandler.Handle(ctx, cc) + h.httpHandler.Handle(ctx, conn) } } diff --git a/pkg/handler/sni/handler.go b/pkg/handler/sni/handler.go new file mode 100644 index 0000000..f03c13a --- /dev/null +++ b/pkg/handler/sni/handler.go @@ -0,0 +1,183 @@ +package sni + +import ( + "bufio" + "context" + "encoding/base64" + "encoding/binary" + "errors" + "hash/crc32" + "io" + "net" + "time" + + "github.com/go-gost/gost/pkg/bypass" + "github.com/go-gost/gost/pkg/chain" + "github.com/go-gost/gost/pkg/handler" + http_handler "github.com/go-gost/gost/pkg/handler/http" + "github.com/go-gost/gost/pkg/logger" + md "github.com/go-gost/gost/pkg/metadata" + "github.com/go-gost/gost/pkg/registry" + dissector "github.com/go-gost/tls-dissector" +) + +func init() { + registry.RegisterHandler("sni", NewHandler) +} + +type sniHandler struct { + httpHandler handler.Handler + chain *chain.Chain + bypass bypass.Bypass + logger logger.Logger + md metadata +} + +func NewHandler(opts ...handler.Option) handler.Handler { + options := &handler.Options{} + for _, opt := range opts { + opt(options) + } + + log := options.Logger + if log == nil { + log = logger.Default() + } + + h := &sniHandler{ + bypass: options.Bypass, + logger: log, + } + + v := append(opts, + handler.LoggerOption(log.WithFields(map[string]interface{}{"type": "http"}))) + h.httpHandler = http_handler.NewHandler(v...) + + return h +} + +func (h *sniHandler) Init(md md.Metadata) (err error) { + if err = h.parseMetadata(md); err != nil { + return + } + if err = h.httpHandler.Init(md); err != nil { + return + } + + return nil +} + +// WithChain implements chain.Chainable interface +func (h *sniHandler) WithChain(chain *chain.Chain) { + h.chain = chain +} + +func (h *sniHandler) Handle(ctx context.Context, conn net.Conn) { + defer conn.Close() + + start := time.Now() + h.logger = h.logger.WithFields(map[string]interface{}{ + "remote": conn.RemoteAddr().String(), + "local": conn.LocalAddr().String(), + }) + + h.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr()) + defer func() { + h.logger.WithFields(map[string]interface{}{ + "duration": time.Since(start), + }).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr()) + }() + + br := bufio.NewReader(conn) + hdr, err := br.Peek(dissector.RecordHeaderLen) + if err != nil { + h.logger.Error(err) + return + } + + conn = handler.NewBufferReaderConn(conn, br) + + if hdr[0] != dissector.Handshake { + // We assume it is an HTTP request + h.httpHandler.Handle(ctx, conn) + return + } + + host, err := h.decodeHost(conn) + if err != nil { + h.logger.Error(err) + return + } + target := net.JoinHostPort(host, "443") + + h.logger = h.logger.WithFields(map[string]interface{}{ + "dst": target, + }) + h.logger.Infof("%s >> %s", conn.RemoteAddr(), target) + + if h.bypass != nil && h.bypass.Contains(target) { + h.logger.Info("bypass: ", target) + return + } + + r := (&chain.Router{}). + WithChain(h.chain). + WithRetry(h.md.retryCount). + WithLogger(h.logger) + cc, err := r.Dial(ctx, "tcp", target) + if err != nil { + return + } + defer cc.Close() + + t := time.Now() + h.logger.Infof("%s <-> %s", conn.RemoteAddr(), target) + handler.Transport(conn, cc) + h.logger. + WithFields(map[string]interface{}{ + "duration": time.Since(t), + }). + Infof("%s >-< %s", conn.RemoteAddr(), target) +} + +func (h *sniHandler) decodeHost(r io.Reader) (host string, err error) { + record, err := dissector.ReadRecord(r) + if err != nil { + return + } + clientHello := &dissector.ClientHelloMsg{} + if err = clientHello.Decode(record.Opaque); err != nil { + return + } + + for _, ext := range clientHello.Extensions { + if ext.Type() == 0xFFFE { + b, _ := ext.Encode() + return h.decodeServerName(string(b)) + } + + if ext.Type() == dissector.ExtServerName { + snExtension := ext.(*dissector.ServerNameExtension) + host = snExtension.Name + } + } + return +} + +func (h *sniHandler) decodeServerName(s string) (string, error) { + b, err := base64.RawURLEncoding.DecodeString(s) + if err != nil { + return "", err + } + if len(b) < 4 { + return "", errors.New("invalid name") + } + v, err := base64.RawURLEncoding.DecodeString(string(b[4:])) + if err != nil { + return "", err + } + if crc32.ChecksumIEEE(v) != binary.BigEndian.Uint32(b[:4]) { + return "", errors.New("invalid name") + } + return string(v), nil +} diff --git a/pkg/handler/sni/metadata.go b/pkg/handler/sni/metadata.go new file mode 100644 index 0000000..c38f0a5 --- /dev/null +++ b/pkg/handler/sni/metadata.go @@ -0,0 +1,23 @@ +package sni + +import ( + "time" + + md "github.com/go-gost/gost/pkg/metadata" +) + +type metadata struct { + readTimeout time.Duration + retryCount int +} + +func (h *sniHandler) parseMetadata(md md.Metadata) (err error) { + const ( + readTimeout = "readTimeout" + retryCount = "retry" + ) + + h.md.readTimeout = md.GetDuration(readTimeout) + h.md.retryCount = md.GetInt(retryCount) + return +} diff --git a/pkg/listener/tls/listener.go b/pkg/listener/tls/listener.go index ac9465b..614bed1 100644 --- a/pkg/listener/tls/listener.go +++ b/pkg/listener/tls/listener.go @@ -5,7 +5,6 @@ import ( "net" "github.com/go-gost/gost/pkg/common/util" - tls_util "github.com/go-gost/gost/pkg/common/util/tls" "github.com/go-gost/gost/pkg/listener" "github.com/go-gost/gost/pkg/logger" md "github.com/go-gost/gost/pkg/metadata" @@ -55,17 +54,3 @@ func (l *tlsListener) Init(md md.Metadata) (err error) { l.Listener = ln return } - -func (l *tlsListener) parseMetadata(md md.Metadata) (err error) { - l.md.tlsConfig, err = tls_util.LoadTLSConfig( - md.GetString(certFile), - md.GetString(keyFile), - md.GetString(caFile), - ) - if err != nil { - return - } - - l.md.keepAlivePeriod = md.GetDuration(keepAlivePeriod) - return -} diff --git a/pkg/listener/tls/metadata.go b/pkg/listener/tls/metadata.go index d544654..bf6f2c6 100644 --- a/pkg/listener/tls/metadata.go +++ b/pkg/listener/tls/metadata.go @@ -3,16 +3,39 @@ package tls import ( "crypto/tls" "time" -) -const ( - certFile = "certFile" - keyFile = "keyFile" - caFile = "caFile" - keepAlivePeriod = "keepAlivePeriod" + tls_util "github.com/go-gost/gost/pkg/common/util/tls" + md "github.com/go-gost/gost/pkg/metadata" ) type metadata struct { tlsConfig *tls.Config keepAlivePeriod time.Duration } + +func (l *tlsListener) parseMetadata(md md.Metadata) (err error) { + const ( + certFile = "certFile" + keyFile = "keyFile" + caFile = "caFile" + keepAlivePeriod = "keepAlivePeriod" + ) + + if md.GetString(certFile) != "" || + md.GetString(keyFile) != "" || + md.GetString(caFile) != "" { + l.md.tlsConfig, err = tls_util.LoadTLSConfig( + md.GetString(certFile), + md.GetString(keyFile), + md.GetString(caFile), + ) + if err != nil { + return + } + } else { + l.md.tlsConfig = tls_util.DefaultConfig + } + + l.md.keepAlivePeriod = md.GetDuration(keepAlivePeriod) + return +}