diff --git a/cmd/gost/main.go b/cmd/gost/main.go index f863aed..d6d9464 100644 --- a/cmd/gost/main.go +++ b/cmd/gost/main.go @@ -1,15 +1,14 @@ package main import ( - "crypto/tls" "flag" "fmt" + "io" "net/http" _ "net/http/pprof" "os" "runtime" - tls_util "github.com/go-gost/gost/pkg/common/util/tls" "github.com/go-gost/gost/pkg/config" "github.com/go-gost/gost/pkg/logger" ) @@ -69,7 +68,18 @@ func main() { log = logFromConfig(cfg.Log) if outputCfgFile != "" { - if err := cfg.WriteFile(outputCfgFile); err != nil { + var w io.Writer + if outputCfgFile == "-" { + w = os.Stdout + } else { + f, err := os.Create(outputCfgFile) + if err != nil { + log.Fatal(err) + } + defer f.Close() + w = f + } + if err := cfg.Write(w); err != nil { log.Fatal(err) } os.Exit(0) @@ -86,29 +96,7 @@ func main() { }() } - tlsCfg := cfg.TLS - if tlsCfg == nil { - tlsCfg = &config.TLSConfig{ - Cert: "cert.pem", - Key: "key.pem", - CA: "ca.crt", - } - } - tlsConfig, err := tls_util.LoadTLSConfig(tlsCfg.Cert, tlsCfg.Key, tlsCfg.CA) - if err != nil { - // generate random self-signed certificate. - cert, err := tls_util.GenCertificate() - if err != nil { - log.Fatal(err) - } - tlsConfig = &tls.Config{ - Certificates: []tls.Certificate{cert}, - } - log.Warn("load TLS certificate files failed, use random generated certificate") - } else { - log.Debug("load TLS certificate files OK") - } - tls_util.DefaultConfig = tlsConfig + buildDefaultTLSConfig(cfg.TLS) services := buildService(cfg) for _, svc := range services { diff --git a/cmd/gost/norm.go b/cmd/gost/norm.go index 1783508..3291ed3 100644 --- a/cmd/gost/norm.go +++ b/cmd/gost/norm.go @@ -111,6 +111,7 @@ func normChain(chain *config.ChainConfig) { if u.User != nil { md["user"] = u.User.String() } + md["serverName"] = u.Host node.Addr = u.Host diff --git a/cmd/gost/register.go b/cmd/gost/register.go index 9526f6f..3fa717e 100644 --- a/cmd/gost/register.go +++ b/cmd/gost/register.go @@ -4,6 +4,7 @@ import ( // Register connectors _ "github.com/go-gost/gost/pkg/connector/forward" _ "github.com/go-gost/gost/pkg/connector/http" + _ "github.com/go-gost/gost/pkg/connector/http2" _ "github.com/go-gost/gost/pkg/connector/relay" _ "github.com/go-gost/gost/pkg/connector/sni" _ "github.com/go-gost/gost/pkg/connector/socks/v4" @@ -12,6 +13,8 @@ import ( _ "github.com/go-gost/gost/pkg/connector/ss/udp" // Register dialers + _ "github.com/go-gost/gost/pkg/dialer/ftcp" + _ "github.com/go-gost/gost/pkg/dialer/http2" _ "github.com/go-gost/gost/pkg/dialer/tcp" _ "github.com/go-gost/gost/pkg/dialer/udp" @@ -20,6 +23,7 @@ import ( _ "github.com/go-gost/gost/pkg/handler/forward/local" _ "github.com/go-gost/gost/pkg/handler/forward/remote" _ "github.com/go-gost/gost/pkg/handler/http" + _ "github.com/go-gost/gost/pkg/handler/http2" _ "github.com/go-gost/gost/pkg/handler/redirect" _ "github.com/go-gost/gost/pkg/handler/relay" _ "github.com/go-gost/gost/pkg/handler/sni" diff --git a/cmd/gost/tls.go b/cmd/gost/tls.go new file mode 100644 index 0000000..239d4e0 --- /dev/null +++ b/cmd/gost/tls.go @@ -0,0 +1,98 @@ +package main + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "time" + + tls_util "github.com/go-gost/gost/pkg/common/util/tls" + "github.com/go-gost/gost/pkg/config" +) + +func buildDefaultTLSConfig(cfg *config.TLSConfig) { + if cfg == nil { + cfg = &config.TLSConfig{ + Cert: "cert.pem", + Key: "key.pem", + } + } + + tlsConfig, err := loadConfig(cfg.Cert, cfg.Key) + if err != nil { + // generate random self-signed certificate. + cert, err := genCertificate() + if err != nil { + log.Fatal(err) + } + tlsConfig = &tls.Config{ + Certificates: []tls.Certificate{cert}, + } + log.Warn("load TLS certificate files failed, use random generated certificate") + } else { + log.Debug("load TLS certificate files OK") + } + tls_util.DefaultConfig = tlsConfig +} + +func loadConfig(certFile, keyFile string) (*tls.Config, error) { + cert, err := tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + return nil, err + } + + cfg := &tls.Config{ + Certificates: []tls.Certificate{cert}, + } + + return cfg, nil +} + +func genCertificate() (cert tls.Certificate, err error) { + rawCert, rawKey, err := generateKeyPair() + if err != nil { + return + } + return tls.X509KeyPair(rawCert, rawKey) +} + +func generateKeyPair() (rawCert, rawKey []byte, err error) { + // Create private key and self-signed certificate + // Adapted from https://golang.org/src/crypto/tls/generate_cert.go + + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return + } + validFor := time.Hour * 24 * 365 * 10 // ten years + notBefore := time.Now() + notAfter := notBefore.Add(validFor) + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, _ := rand.Int(rand.Reader, serialNumberLimit) + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{"gost"}, + CommonName: "gost.run", + }, + NotBefore: notBefore, + NotAfter: notAfter, + + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) + if err != nil { + return + } + + rawCert = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) + rawKey = pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}) + + return +} diff --git a/go.mod b/go.mod index ff59bc1..1b838cb 100644 --- a/go.mod +++ b/go.mod @@ -24,7 +24,7 @@ require ( github.com/xtaci/smux v1.5.15 github.com/xtaci/tcpraw v1.2.25 golang.org/x/crypto v0.0.0-20210817164053-32db794688a5 - golang.org/x/net v0.0.0-20210503060351-7fd8e65b6420 + golang.org/x/net v0.0.0-20211209124913-491a49abca63 gopkg.in/yaml.v2 v2.4.0 ) diff --git a/go.sum b/go.sum index 763385d..868264b 100644 --- a/go.sum +++ b/go.sum @@ -523,6 +523,8 @@ golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96b golang.org/x/net v0.0.0-20210428140749-89ef3d95e781/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk= golang.org/x/net v0.0.0-20210503060351-7fd8e65b6420 h1:a8jGStKg0XqKDlKqjLrXn0ioF5MH36pT7Z0BRTqLhbk= golang.org/x/net v0.0.0-20210503060351-7fd8e65b6420/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.0.0-20211209124913-491a49abca63 h1:iocB37TsdFuN6IBRZ+ry36wrkoV51/tl5vOWqkcPGvY= +golang.org/x/net v0.0.0-20211209124913-491a49abca63/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181017192945-9dcd33a902f4/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= diff --git a/pkg/chain/transport.go b/pkg/chain/transport.go index c71e07e..220847b 100644 --- a/pkg/chain/transport.go +++ b/pkg/chain/transport.go @@ -36,7 +36,7 @@ func (tr *Transport) Dial(ctx context.Context, addr string) (net.Conn, error) { func (tr *Transport) dialOptions() []dialer.DialOption { var opts []dialer.DialOption - if tr.route != nil { + if !tr.route.IsEmpty() { opts = append(opts, dialer.DialFuncDialOption( func(ctx context.Context, addr string) (net.Conn, error) { diff --git a/pkg/common/util/tls/tls.go b/pkg/common/util/tls/tls.go index 3272958..c1e6c14 100644 --- a/pkg/common/util/tls/tls.go +++ b/pkg/common/util/tls/tls.go @@ -1,15 +1,11 @@ package tls import ( - "crypto/rand" - "crypto/rsa" "crypto/tls" "crypto/x509" - "crypto/x509/pkix" - "encoding/pem" "errors" "io/ioutil" - "math/big" + "net" "time" ) @@ -18,8 +14,12 @@ var ( DefaultConfig *tls.Config ) -// LoadTLSConfig loads the certificate from cert & key files and optional client CA file. -func LoadTLSConfig(certFile, keyFile, caFile string) (*tls.Config, error) { +// LoadServerConfig loads the certificate from cert & key files and optional client CA file. +func LoadServerConfig(certFile, keyFile, caFile string) (*tls.Config, error) { + if certFile == "" && keyFile == "" { + return DefaultConfig, nil + } + cert, err := tls.LoadX509KeyPair(certFile, keyFile) if err != nil { return nil, err @@ -27,7 +27,11 @@ func LoadTLSConfig(certFile, keyFile, caFile string) (*tls.Config, error) { cfg := &tls.Config{Certificates: []tls.Certificate{cert}} - if pool, _ := loadCA(caFile); pool != nil { + pool, err := loadCA(caFile) + if err != nil { + return nil, err + } + if pool != nil { cfg.ClientCAs = pool cfg.ClientAuth = tls.RequireAndVerifyClientCert } @@ -35,6 +39,58 @@ func LoadTLSConfig(certFile, keyFile, caFile string) (*tls.Config, error) { return cfg, nil } +// LoadClientConfig loads the certificate from cert & key files and optional CA file. +func LoadClientConfig(certFile, keyFile, caFile string, verify bool, serverName string) (*tls.Config, error) { + var cfg *tls.Config + + if certFile == "" && keyFile == "" { + cfg = &tls.Config{} + } else { + cert, err := tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + return nil, err + } + + cfg = &tls.Config{ + Certificates: []tls.Certificate{cert}, + } + } + + rootCAs, err := loadCA(caFile) + if err != nil { + return nil, err + } + + cfg.RootCAs = rootCAs + cfg.ServerName = serverName + cfg.InsecureSkipVerify = !verify + + // If the root ca is given, but skip verify, we verify the certificate manually. + if cfg.RootCAs != nil && !verify { + cfg.VerifyConnection = func(state tls.ConnectionState) error { + opts := x509.VerifyOptions{ + Roots: cfg.RootCAs, + CurrentTime: time.Now(), + DNSName: "", + Intermediates: x509.NewCertPool(), + } + + certs := state.PeerCertificates + for i, cert := range certs { + if i == 0 { + continue + } + opts.Intermediates.AddCert(cert) + } + + _, err := certs[0].Verify(opts) + return err + } + } + + return cfg, nil +} + func loadCA(caFile string) (cp *x509.CertPool, err error) { if caFile == "" { return @@ -50,47 +106,73 @@ func loadCA(caFile string) (cp *x509.CertPool, err error) { return } -func GenCertificate() (cert tls.Certificate, err error) { - rawCert, rawKey, err := generateKeyPair() - if err != nil { - return +// Wrap a net.Conn into a client tls connection, performing any +// additional verification as needed. +// +// As of go 1.3, crypto/tls only supports either doing no certificate +// verification, or doing full verification including of the peer's +// DNS name. For consul, we want to validate that the certificate is +// signed by a known CA, but because consul doesn't use DNS names for +// node names, we don't verify the certificate DNS names. Since go 1.3 +// no longer supports this mode of operation, we have to do it +// manually. +// +// This code is taken from consul: +// https://github.com/hashicorp/consul/blob/master/tlsutil/config.go +func WrapTLSClient(conn net.Conn, tlsConfig *tls.Config, timeout time.Duration) (net.Conn, error) { + var err error + var tlsConn *tls.Conn + + if timeout > 0 { + conn.SetDeadline(time.Now().Add(timeout)) + defer conn.SetDeadline(time.Time{}) } - return tls.X509KeyPair(rawCert, rawKey) -} - -func generateKeyPair() (rawCert, rawKey []byte, err error) { - // Create private key and self-signed certificate - // Adapted from https://golang.org/src/crypto/tls/generate_cert.go - - priv, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - return - } - validFor := time.Hour * 24 * 365 * 10 // ten years - notBefore := time.Now() - notAfter := notBefore.Add(validFor) - serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) - serialNumber, _ := rand.Int(rand.Reader, serialNumberLimit) - template := x509.Certificate{ - SerialNumber: serialNumber, - Subject: pkix.Name{ - Organization: []string{"gost"}, - CommonName: "gost.run", - }, - NotBefore: notBefore, - NotAfter: notAfter, - - KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, - BasicConstraintsValid: true, - } - derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) - if err != nil { - return - } - - rawCert = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) - rawKey = pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}) - - return + + tlsConn = tls.Client(conn, tlsConfig) + + // Otherwise perform handshake, but don't verify the domain + // + // The following is lightly-modified from the doFullHandshake + // method in https://golang.org/src/crypto/tls/handshake_client.go + if err = tlsConn.Handshake(); err != nil { + tlsConn.Close() + return nil, err + } + + // We can do this in `tls.Config.VerifyConnection`, which effective for + // other TLS protocols such as WebSocket. See `route.go:parseChainNode` + /* + // If crypto/tls is doing verification, there's no need to do our own. + if tlsConfig.InsecureSkipVerify == false { + return tlsConn, nil + } + + // Similarly if we use host's CA, we can do full handshake + if tlsConfig.RootCAs == nil { + return tlsConn, nil + } + + opts := x509.VerifyOptions{ + Roots: tlsConfig.RootCAs, + CurrentTime: time.Now(), + DNSName: "", + Intermediates: x509.NewCertPool(), + } + + certs := tlsConn.ConnectionState().PeerCertificates + for i, cert := range certs { + if i == 0 { + continue + } + opts.Intermediates.AddCert(cert) + } + + _, err = certs[0].Verify(opts) + if err != nil { + tlsConn.Close() + return nil, err + } + */ + + return tlsConn, err } diff --git a/pkg/config/config.go b/pkg/config/config.go index d86c60f..694bd8a 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -2,7 +2,6 @@ package config import ( "io" - "os" "time" "github.com/spf13/viper" @@ -138,14 +137,8 @@ func (c *Config) ReadFile(file string) error { return v.Unmarshal(c) } -func (c *Config) WriteFile(file string) error { - f, err := os.Create(file) - if err != nil { - return err - } - defer f.Close() - - enc := yaml.NewEncoder(f) +func (c *Config) Write(w io.Writer) error { + enc := yaml.NewEncoder(w) defer enc.Close() return enc.Encode(c) diff --git a/pkg/connector/http2/conn.go b/pkg/connector/http2/conn.go new file mode 100644 index 0000000..186e775 --- /dev/null +++ b/pkg/connector/http2/conn.go @@ -0,0 +1,54 @@ +package http2 + +import ( + "errors" + "io" + "net" + "time" +) + +// HTTP2 connection, wrapped up just like a net.Conn. +type http2Conn struct { + r io.Reader + w io.Writer + remoteAddr net.Addr + localAddr net.Addr +} + +func (c *http2Conn) Read(b []byte) (n int, err error) { + return c.r.Read(b) +} + +func (c *http2Conn) Write(b []byte) (n int, err error) { + return c.w.Write(b) +} + +func (c *http2Conn) Close() (err error) { + if r, ok := c.r.(io.Closer); ok { + err = r.Close() + } + if w, ok := c.w.(io.Closer); ok { + err = w.Close() + } + return +} + +func (c *http2Conn) LocalAddr() net.Addr { + return c.localAddr +} + +func (c *http2Conn) RemoteAddr() net.Addr { + return c.remoteAddr +} + +func (c *http2Conn) SetDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} + +func (c *http2Conn) SetReadDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} + +func (c *http2Conn) SetWriteDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} diff --git a/pkg/connector/http2/connector.go b/pkg/connector/http2/connector.go new file mode 100644 index 0000000..49ae623 --- /dev/null +++ b/pkg/connector/http2/connector.go @@ -0,0 +1,122 @@ +package http2 + +import ( + "context" + "encoding/base64" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/http/httputil" + "net/url" + "time" + + "github.com/go-gost/gost/pkg/connector" + http2_util "github.com/go-gost/gost/pkg/internal/http2" + "github.com/go-gost/gost/pkg/logger" + md "github.com/go-gost/gost/pkg/metadata" + "github.com/go-gost/gost/pkg/registry" +) + +func init() { + registry.RegiserConnector("http2", NewConnector) +} + +type http2Connector struct { + md metadata + logger logger.Logger +} + +func NewConnector(opts ...connector.Option) connector.Connector { + options := &connector.Options{} + for _, opt := range opts { + opt(options) + } + + return &http2Connector{ + logger: options.Logger, + } +} + +func (c *http2Connector) Init(md md.Metadata) (err error) { + return c.parseMetadata(md) +} + +func (c *http2Connector) Connect(ctx context.Context, conn net.Conn, network, address string, opts ...connector.ConnectOption) (net.Conn, error) { + c.logger = c.logger.WithFields(map[string]interface{}{ + "local": conn.LocalAddr().String(), + "remote": conn.RemoteAddr().String(), + "network": network, + "address": address, + }) + c.logger.Infof("connect %s/%s", address, network) + + cc, ok := conn.(*http2_util.ClientConn) + if !ok { + err := errors.New("wrong connection type") + c.logger.Error(err) + return nil, err + } + + pr, pw := io.Pipe() + req := &http.Request{ + Method: http.MethodConnect, + URL: &url.URL{Scheme: "https", Host: conn.RemoteAddr().String()}, + Host: address, + ProtoMajor: 2, + ProtoMinor: 0, + Proto: "HTTP/2.0", + Header: make(http.Header), + Body: pr, + ContentLength: -1, + } + if c.md.UserAgent != "" { + req.Header.Set("User-Agent", c.md.UserAgent) + } + + if user := c.md.User; user != nil { + u := user.Username() + p, _ := user.Password() + req.Header.Set("Proxy-Authorization", + "Basic "+base64.StdEncoding.EncodeToString([]byte(u+":"+p))) + } + + if c.logger.IsLevelEnabled(logger.DebugLevel) { + dump, _ := httputil.DumpRequest(req, false) + c.logger.Debug(string(dump)) + } + + if c.md.connectTimeout > 0 { + conn.SetDeadline(time.Now().Add(c.md.connectTimeout)) + defer conn.SetDeadline(time.Time{}) + } + + resp, err := cc.Client().Do(req.WithContext(ctx)) + if err != nil { + c.logger.Error(err) + cc.Close() + return nil, err + } + + if c.logger.IsLevelEnabled(logger.DebugLevel) { + dump, _ := httputil.DumpResponse(resp, false) + c.logger.Debug(string(dump)) + } + if resp.StatusCode != http.StatusOK { + resp.Body.Close() + err = fmt.Errorf("%s", resp.Status) + c.logger.Error(err) + return nil, err + } + + hc := &http2Conn{ + r: resp.Body, + w: pw, + localAddr: conn.RemoteAddr(), + } + + hc.remoteAddr, _ = net.ResolveTCPAddr(network, address) + + return hc, nil +} diff --git a/pkg/connector/http2/metadata.go b/pkg/connector/http2/metadata.go new file mode 100644 index 0000000..c675eee --- /dev/null +++ b/pkg/connector/http2/metadata.go @@ -0,0 +1,44 @@ +package http2 + +import ( + "net/url" + "strings" + "time" + + md "github.com/go-gost/gost/pkg/metadata" +) + +const ( + defaultUserAgent = "Chrome/78.0.3904.106" +) + +type metadata struct { + connectTimeout time.Duration + UserAgent string + User *url.Userinfo +} + +func (c *http2Connector) parseMetadata(md md.Metadata) (err error) { + const ( + connectTimeout = "timeout" + userAgent = "userAgent" + user = "user" + ) + + c.md.connectTimeout = md.GetDuration(connectTimeout) + c.md.UserAgent, _ = md.Get(userAgent).(string) + if c.md.UserAgent == "" { + c.md.UserAgent = defaultUserAgent + } + + if v := md.GetString(user); v != "" { + ss := strings.SplitN(v, ":", 2) + if len(ss) == 1 { + c.md.User = url.User(ss[0]) + } else { + c.md.User = url.UserPassword(ss[0], ss[1]) + } + } + + return +} diff --git a/pkg/dialer/ftcp/conn.go b/pkg/dialer/ftcp/conn.go new file mode 100644 index 0000000..b0d7e26 --- /dev/null +++ b/pkg/dialer/ftcp/conn.go @@ -0,0 +1,21 @@ +package ftcp + +import "net" + +type fakeTCPConn struct { + raddr net.Addr + net.PacketConn +} + +func (c *fakeTCPConn) Read(b []byte) (n int, err error) { + n, _, err = c.ReadFrom(b) + return +} + +func (c *fakeTCPConn) Write(b []byte) (n int, err error) { + return c.WriteTo(b, c.raddr) +} + +func (c *fakeTCPConn) RemoteAddr() net.Addr { + return c.raddr +} diff --git a/pkg/dialer/ftcp/dialer.go b/pkg/dialer/ftcp/dialer.go new file mode 100644 index 0000000..5db9032 --- /dev/null +++ b/pkg/dialer/ftcp/dialer.go @@ -0,0 +1,51 @@ +package ftcp + +import ( + "context" + "net" + + "github.com/go-gost/gost/pkg/dialer" + "github.com/go-gost/gost/pkg/logger" + md "github.com/go-gost/gost/pkg/metadata" + "github.com/go-gost/gost/pkg/registry" + "github.com/xtaci/tcpraw" +) + +func init() { + registry.RegisterDialer("ftcp", NewDialer) +} + +type ftcpDialer struct { + md metadata + logger logger.Logger +} + +func NewDialer(opts ...dialer.Option) dialer.Dialer { + options := &dialer.Options{} + for _, opt := range opts { + opt(options) + } + + return &ftcpDialer{ + logger: options.Logger, + } +} + +func (d *ftcpDialer) Init(md md.Metadata) (err error) { + return d.parseMetadata(md) +} + +func (d *ftcpDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOption) (conn net.Conn, err error) { + raddr, er := net.ResolveTCPAddr("tcp", addr) + if er != nil { + return nil, er + } + c, err := tcpraw.Dial("tcp", addr) + if err != nil { + return + } + return &fakeTCPConn{ + raddr: raddr, + PacketConn: c, + }, nil +} diff --git a/pkg/dialer/ftcp/metadata.go b/pkg/dialer/ftcp/metadata.go new file mode 100644 index 0000000..3839b4b --- /dev/null +++ b/pkg/dialer/ftcp/metadata.go @@ -0,0 +1,23 @@ +package ftcp + +import ( + "time" + + md "github.com/go-gost/gost/pkg/metadata" +) + +const ( + dialTimeout = "dialTimeout" +) + +const ( + defaultDialTimeout = 5 * time.Second +) + +type metadata struct { + dialTimeout time.Duration +} + +func (d *ftcpDialer) parseMetadata(md md.Metadata) (err error) { + return +} diff --git a/pkg/dialer/http2/dialer.go b/pkg/dialer/http2/dialer.go new file mode 100644 index 0000000..844e5e9 --- /dev/null +++ b/pkg/dialer/http2/dialer.go @@ -0,0 +1,136 @@ +package http2 + +import ( + "context" + "net" + "net/http" + "sync" + "time" + + "github.com/go-gost/gost/pkg/dialer" + http2_util "github.com/go-gost/gost/pkg/internal/http2" + "github.com/go-gost/gost/pkg/logger" + md "github.com/go-gost/gost/pkg/metadata" + "github.com/go-gost/gost/pkg/registry" +) + +func init() { + registry.RegisterDialer("http2", NewDialer) +} + +type http2Dialer struct { + md metadata + clients map[string]*http.Client + clientMutex sync.Mutex + logger logger.Logger +} + +func NewDialer(opts ...dialer.Option) dialer.Dialer { + options := &dialer.Options{} + for _, opt := range opts { + opt(options) + } + + return &http2Dialer{ + clients: make(map[string]*http.Client), + logger: options.Logger, + } +} + +func (d *http2Dialer) Init(md md.Metadata) (err error) { + if err = d.parseMetadata(md); err != nil { + return + } + + return nil +} + +// IsMultiplex implements dialer.Multiplexer interface. +func (d *http2Dialer) IsMultiplex() bool { + return true +} + +func (d *http2Dialer) Dial(ctx context.Context, address string, opts ...dialer.DialOption) (net.Conn, error) { + options := &dialer.DialOptions{} + for _, opt := range opts { + opt(options) + } + + raddr, err := net.ResolveTCPAddr("tcp", address) + if err != nil { + d.logger.Error(err) + return nil, err + } + + d.clientMutex.Lock() + defer d.clientMutex.Unlock() + + client, ok := d.clients[address] + if !ok { + client = &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: d.md.tlsConfig, + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return d.dial(ctx, network, addr, options) + }, + ForceAttemptHTTP2: true, + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + }, + } + /* + client = &http.Client{ + Transport: &http2.Transport{ + TLSClientConfig: d.md.tlsConfig, + DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { + conn, err := d.dial(ctx, network, addr, options) + if err != nil { + return nil, err + } + return tls_util.WrapTLSClient(conn, cfg, time.Duration(0)) + }, + }, + } + */ + d.clients[address] = client + } + + return http2_util.NewClientConn( + &net.TCPAddr{}, raddr, + client, + func() { + d.clientMutex.Lock() + defer d.clientMutex.Unlock() + delete(d.clients, address) + }), nil +} + +func (d *http2Dialer) dial(ctx context.Context, network, addr string, opts *dialer.DialOptions) (net.Conn, error) { + dial := opts.DialFunc + if dial != nil { + conn, err := dial(ctx, addr) + if err != nil { + d.logger.Error(err) + } else { + d.logger.WithFields(map[string]interface{}{ + "src": conn.LocalAddr().String(), + "dst": addr, + }).Debug("dial with dial func") + } + return conn, err + } + + var netd net.Dialer + conn, err := netd.DialContext(ctx, network, addr) + if err != nil { + d.logger.Error(err) + } else { + d.logger.WithFields(map[string]interface{}{ + "src": conn.LocalAddr().String(), + "dst": addr, + }).Debugf("dial direct %s/%s", addr, network) + } + return conn, err +} diff --git a/pkg/dialer/http2/metadata.go b/pkg/dialer/http2/metadata.go new file mode 100644 index 0000000..4befe1d --- /dev/null +++ b/pkg/dialer/http2/metadata.go @@ -0,0 +1,37 @@ +package http2 + +import ( + "crypto/tls" + "net" + + tls_util "github.com/go-gost/gost/pkg/common/util/tls" + md "github.com/go-gost/gost/pkg/metadata" +) + +type metadata struct { + tlsConfig *tls.Config +} + +func (d *http2Dialer) parseMetadata(md md.Metadata) (err error) { + const ( + certFile = "certFile" + keyFile = "keyFile" + caFile = "caFile" + secure = "secure" + serverName = "serverName" + ) + + sn, _, _ := net.SplitHostPort(md.GetString(serverName)) + if sn == "" { + sn = "localhost" + } + d.md.tlsConfig, err = tls_util.LoadClientConfig( + md.GetString(certFile), + md.GetString(keyFile), + md.GetString(caFile), + md.GetBool(secure), + sn, + ) + + return +} diff --git a/pkg/handler/auto/handler.go b/pkg/handler/auto/handler.go index 5b5b6a0..0280ce4 100644 --- a/pkg/handler/auto/handler.go +++ b/pkg/handler/auto/handler.go @@ -8,6 +8,7 @@ import ( "github.com/go-gost/gosocks4" "github.com/go-gost/gosocks5" + "github.com/go-gost/gost/pkg/chain" "github.com/go-gost/gost/pkg/handler" "github.com/go-gost/gost/pkg/logger" md "github.com/go-gost/gost/pkg/metadata" @@ -20,6 +21,7 @@ func init() { } type autoHandler struct { + chain *chain.Chain httpHandler handler.Handler socks4Handler handler.Handler socks5Handler handler.Handler @@ -66,23 +68,43 @@ func NewHandler(opts ...handler.Option) handler.Handler { return h } +func (h *autoHandler) WithChain(chain *chain.Chain) { + h.chain = chain +} + func (h *autoHandler) Init(md md.Metadata) error { if h.httpHandler != nil { + if chainable, ok := h.httpHandler.(chain.Chainable); ok { + chainable.WithChain(h.chain) + } + if err := h.httpHandler.Init(md); err != nil { return err } } if h.socks4Handler != nil { + if chainable, ok := h.socks4Handler.(chain.Chainable); ok { + chainable.WithChain(h.chain) + } + if err := h.socks4Handler.Init(md); err != nil { return err } } if h.socks5Handler != nil { + if chainable, ok := h.socks5Handler.(chain.Chainable); ok { + chainable.WithChain(h.chain) + } + if err := h.socks5Handler.Init(md); err != nil { return err } } if h.relayHandler != nil { + if chainable, ok := h.relayHandler.(chain.Chainable); ok { + chainable.WithChain(h.chain) + } + if err := h.relayHandler.Init(md); err != nil { return err } diff --git a/pkg/handler/http2/conn.go b/pkg/handler/http2/conn.go new file mode 100644 index 0000000..5454565 --- /dev/null +++ b/pkg/handler/http2/conn.go @@ -0,0 +1,46 @@ +package http2 + +import ( + "errors" + "io" + "net/http" +) + +type readWriter struct { + r io.Reader + w io.Writer +} + +func (rw *readWriter) Read(p []byte) (n int, err error) { + return rw.r.Read(p) +} + +func (rw *readWriter) Write(p []byte) (n int, err error) { + return rw.w.Write(p) +} + +type flushWriter struct { + w io.Writer +} + +func (fw flushWriter) Write(p []byte) (n int, err error) { + defer func() { + if r := recover(); r != nil { + if s, ok := r.(string); ok { + err = errors.New(s) + return + } + err = r.(error) + } + }() + + n, err = fw.w.Write(p) + if err != nil { + // log.Log("flush writer:", err) + return + } + if f, ok := fw.w.(http.Flusher); ok { + f.Flush() + } + return +} diff --git a/pkg/handler/http2/handler.go b/pkg/handler/http2/handler.go new file mode 100644 index 0000000..27bdfc5 --- /dev/null +++ b/pkg/handler/http2/handler.go @@ -0,0 +1,473 @@ +package http2 + +import ( + "context" + "encoding/base64" + "encoding/binary" + "errors" + "hash/crc32" + "net" + "net/http" + "net/http/httputil" + "os" + "strconv" + "strings" + "time" + + "github.com/asaskevich/govalidator" + "github.com/go-gost/gost/pkg/bypass" + "github.com/go-gost/gost/pkg/chain" + "github.com/go-gost/gost/pkg/handler" + http2_util "github.com/go-gost/gost/pkg/internal/http2" + "github.com/go-gost/gost/pkg/logger" + md "github.com/go-gost/gost/pkg/metadata" + "github.com/go-gost/gost/pkg/registry" +) + +func init() { + registry.RegisterHandler("http2", NewHandler) +} + +type http2Handler struct { + chain *chain.Chain + bypass bypass.Bypass + logger logger.Logger + md metadata +} + +func NewHandler(opts ...handler.Option) handler.Handler { + options := &handler.Options{} + for _, opt := range opts { + opt(options) + } + + return &http2Handler{ + bypass: options.Bypass, + logger: options.Logger, + } +} + +func (h *http2Handler) Init(md md.Metadata) error { + return h.parseMetadata(md) +} + +// implements chain.Chainable interface +func (h *http2Handler) WithChain(chain *chain.Chain) { + h.chain = chain +} + +func (h *http2Handler) Handle(ctx context.Context, conn net.Conn) { + defer conn.Close() + + start := time.Now() + h.logger = h.logger.WithFields(map[string]interface{}{ + "remote": conn.RemoteAddr().String(), + "local": conn.LocalAddr().String(), + }) + h.logger.Infof("%s <> %s", conn.RemoteAddr(), conn.LocalAddr()) + defer func() { + h.logger.WithFields(map[string]interface{}{ + "duration": time.Since(start), + }).Infof("%s >< %s", conn.RemoteAddr(), conn.LocalAddr()) + }() + + cc, ok := conn.(*http2_util.ServerConn) + if !ok { + h.logger.Error("wrong connection type") + return + } + h.roundTrip(ctx, cc.Writer(), cc.Request()) +} + +// NOTE: there is an issue (golang/go#43989) will cause the client hangs +// when server returns an non-200 status code, +// May be fixed in go1.18. +func (h *http2Handler) roundTrip(ctx context.Context, w http.ResponseWriter, req *http.Request) { + // Try to get the actual host. + // Compatible with GOST 2.x. + if v := req.Header.Get("Gost-Target"); v != "" { + if h, err := h.decodeServerName(v); err == nil { + req.Host = h + } + } + req.Header.Del("Gost-Target") + + if v := req.Header.Get("X-Gost-Target"); v != "" { + if h, err := h.decodeServerName(v); err == nil { + req.Host = h + } + } + req.Header.Del("X-Gost-Target") + + addr := req.Host + if _, port, _ := net.SplitHostPort(addr); port == "" { + addr = net.JoinHostPort(addr, "80") + } + + fields := map[string]interface{}{ + "dst": addr, + } + if u, _, _ := h.basicProxyAuth(req.Header.Get("Proxy-Authorization")); u != "" { + fields["user"] = u + } + h.logger = h.logger.WithFields(fields) + + if h.logger.IsLevelEnabled(logger.DebugLevel) { + dump, _ := httputil.DumpRequest(req, false) + h.logger.Debug(string(dump)) + } + h.logger.Infof("%s >> %s", req.RemoteAddr, addr) + + if h.md.proxyAgent != "" { + w.Header().Set("Proxy-Agent", h.md.proxyAgent) + } + + /* + if !Can("tcp", host, h.options.Whitelist, h.options.Blacklist) { + log.Logf("[http2] %s - %s : Unauthorized to tcp connect to %s", + r.RemoteAddr, laddr, host) + w.WriteHeader(http.StatusForbidden) + return + } + */ + + if h.bypass != nil && h.bypass.Contains(addr) { + w.WriteHeader(http.StatusForbidden) + h.logger.Info("bypass: ", addr) + return + } + + /* + resp := &http.Response{ + ProtoMajor: 2, + ProtoMinor: 0, + Header: http.Header{}, + Body: ioutil.NopCloser(bytes.NewReader([]byte{})), + } + + if !h.authenticate(w, r, resp) { + return + } + */ + + // delete the proxy related headers. + req.Header.Del("Proxy-Authorization") + req.Header.Del("Proxy-Connection") + + r := (&chain.Router{}). + WithChain(h.chain). + WithRetry(h.md.retryCount). + WithLogger(h.logger) + cc, err := r.Dial(ctx, "tcp", addr) + if err != nil { + h.logger.Error(err) + w.WriteHeader(http.StatusServiceUnavailable) + return + } + defer cc.Close() + + if req.Method == http.MethodConnect { + w.WriteHeader(http.StatusOK) + if fw, ok := w.(http.Flusher); ok { + fw.Flush() + } + + // compatible with HTTP1.x + if hj, ok := w.(http.Hijacker); ok && req.ProtoMajor == 1 { + // we take over the underly connection + conn, _, err := hj.Hijack() + if err != nil { + h.logger.Error(err) + w.WriteHeader(http.StatusInternalServerError) + return + } + defer conn.Close() + + start := time.Now() + h.logger.Infof("%s <-> %s", conn.RemoteAddr(), addr) + handler.Transport(conn, cc) + h.logger. + WithFields(map[string]interface{}{ + "duration": time.Since(start), + }). + Infof("%s >-< %s", conn.RemoteAddr(), addr) + } + + start := time.Now() + h.logger.Infof("%s <-> %s", req.RemoteAddr, addr) + handler.Transport(&readWriter{r: req.Body, w: flushWriter{w}}, cc) + h.logger. + WithFields(map[string]interface{}{ + "duration": time.Since(start), + }). + Infof("%s >-< %s", req.RemoteAddr, addr) + return + } +} + +func (h *http2Handler) handleRequest(ctx context.Context, conn net.Conn, req *http.Request) { + if req == nil { + return + } + + if h.md.sni && !req.URL.IsAbs() && govalidator.IsDNSName(req.Host) { + req.URL.Scheme = "http" + } + + network := req.Header.Get("X-Gost-Protocol") + if network != "udp" { + network = "tcp" + } + + // Try to get the actual host. + // Compatible with GOST 2.x. + if v := req.Header.Get("Gost-Target"); v != "" { + if h, err := h.decodeServerName(v); err == nil { + req.Host = h + } + } + req.Header.Del("Gost-Target") + + if v := req.Header.Get("X-Gost-Target"); v != "" { + if h, err := h.decodeServerName(v); err == nil { + req.Host = h + } + } + req.Header.Del("X-Gost-Target") + + addr := req.Host + if _, port, _ := net.SplitHostPort(addr); port == "" { + addr = net.JoinHostPort(addr, "80") + } + + fields := map[string]interface{}{ + "dst": addr, + } + if u, _, _ := h.basicProxyAuth(req.Header.Get("Proxy-Authorization")); u != "" { + fields["user"] = u + } + h.logger = h.logger.WithFields(fields) + + if h.logger.IsLevelEnabled(logger.DebugLevel) { + dump, _ := httputil.DumpRequest(req, false) + h.logger.Debug(string(dump)) + } + h.logger.Infof("%s >> %s", conn.RemoteAddr(), addr) + + resp := &http.Response{ + ProtoMajor: 1, + ProtoMinor: 1, + Header: http.Header{}, + } + + if h.md.proxyAgent != "" { + resp.Header.Add("Proxy-Agent", h.md.proxyAgent) + } + + /* + if !Can("tcp", host, h.options.Whitelist, h.options.Blacklist) { + log.Logf("[http] %s - %s : Unauthorized to tcp connect to %s", + conn.RemoteAddr(), conn.LocalAddr(), host) + resp.StatusCode = http.StatusForbidden + + if Debug { + dump, _ := httputil.DumpResponse(resp, false) + log.Logf("[http] %s <- %s\n%s", conn.RemoteAddr(), conn.LocalAddr(), string(dump)) + } + + resp.Write(conn) + return + } + */ + + if h.bypass != nil && h.bypass.Contains(addr) { + resp.StatusCode = http.StatusForbidden + + if h.logger.IsLevelEnabled(logger.DebugLevel) { + dump, _ := httputil.DumpResponse(resp, false) + h.logger.Debug(string(dump)) + } + h.logger.Info("bypass: ", addr) + + resp.Write(conn) + return + } + + if !h.authenticate(conn, req, resp) { + return + } + + if req.Method == "PRI" || + (req.Method != http.MethodConnect && req.URL.Scheme != "http") { + resp.StatusCode = http.StatusBadRequest + resp.Write(conn) + + if h.logger.IsLevelEnabled(logger.DebugLevel) { + dump, _ := httputil.DumpResponse(resp, false) + h.logger.Debug(string(dump)) + } + + return + } + + req.Header.Del("Proxy-Authorization") + + r := (&chain.Router{}). + WithChain(h.chain). + WithRetry(h.md.retryCount). + WithLogger(h.logger) + cc, err := r.Dial(ctx, network, addr) + if err != nil { + resp.StatusCode = http.StatusServiceUnavailable + resp.Write(conn) + + if h.logger.IsLevelEnabled(logger.DebugLevel) { + dump, _ := httputil.DumpResponse(resp, false) + h.logger.Debug(string(dump)) + } + return + } + defer cc.Close() + + if req.Method == http.MethodConnect { + resp.StatusCode = http.StatusOK + resp.Status = "200 Connection established" + + if h.logger.IsLevelEnabled(logger.DebugLevel) { + dump, _ := httputil.DumpResponse(resp, false) + h.logger.Debug(string(dump)) + } + if err = resp.Write(conn); err != nil { + h.logger.Error(err) + return + } + } else { + req.Header.Del("Proxy-Connection") + if err = req.Write(cc); err != nil { + h.logger.Error(err) + return + } + } + + start := time.Now() + h.logger.Infof("%s <-> %s", conn.RemoteAddr(), addr) + handler.Transport(conn, cc) + h.logger. + WithFields(map[string]interface{}{ + "duration": time.Since(start), + }). + Infof("%s >-< %s", conn.RemoteAddr(), addr) +} + +func (h *http2Handler) decodeServerName(s string) (string, error) { + b, err := base64.RawURLEncoding.DecodeString(s) + if err != nil { + return "", err + } + if len(b) < 4 { + return "", errors.New("invalid name") + } + v, err := base64.RawURLEncoding.DecodeString(string(b[4:])) + if err != nil { + return "", err + } + if crc32.ChecksumIEEE(v) != binary.BigEndian.Uint32(b[:4]) { + return "", errors.New("invalid name") + } + return string(v), nil +} + +func (h *http2Handler) basicProxyAuth(proxyAuth string) (username, password string, ok bool) { + if proxyAuth == "" { + return + } + + if !strings.HasPrefix(proxyAuth, "Basic ") { + return + } + c, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(proxyAuth, "Basic ")) + if err != nil { + return + } + cs := string(c) + s := strings.IndexByte(cs, ':') + if s < 0 { + return + } + + return cs[:s], cs[s+1:], true +} + +func (h *http2Handler) authenticate(conn net.Conn, req *http.Request, resp *http.Response) (ok bool) { + u, p, _ := h.basicProxyAuth(req.Header.Get("Proxy-Authorization")) + if h.md.authenticator == nil || h.md.authenticator.Authenticate(u, p) { + return true + } + + pr := h.md.probeResist + // probing resistance is enabled, and knocking host is mismatch. + if pr != nil && (pr.Knock == "" || !strings.EqualFold(req.URL.Hostname(), pr.Knock)) { + resp.StatusCode = http.StatusServiceUnavailable // default status code + + switch pr.Type { + case "code": + resp.StatusCode, _ = strconv.Atoi(pr.Value) + case "web": + url := pr.Value + if !strings.HasPrefix(url, "http") { + url = "http://" + url + } + if r, err := http.Get(url); err == nil { + resp = r + defer r.Body.Close() + } + case "host": + cc, err := net.Dial("tcp", pr.Value) + if err == nil { + defer cc.Close() + + req.Write(cc) + handler.Transport(conn, cc) + return + } + case "file": + f, _ := os.Open(pr.Value) + if f != nil { + resp.StatusCode = http.StatusOK + if finfo, _ := f.Stat(); finfo != nil { + resp.ContentLength = finfo.Size() + } + resp.Header.Set("Content-Type", "text/html") + resp.Body = f + } + } + } + + if resp.StatusCode == 0 { + resp.StatusCode = http.StatusProxyAuthRequired + resp.Header.Add("Proxy-Authenticate", "Basic realm=\"gost\"") + if strings.ToLower(req.Header.Get("Proxy-Connection")) == "keep-alive" { + // XXX libcurl will keep sending auth request in same conn + // which we don't supported yet. + resp.Header.Add("Connection", "close") + resp.Header.Add("Proxy-Connection", "close") + } + + h.logger.Info("proxy authentication required") + } else { + resp.Header = http.Header{} + resp.Header.Set("Server", "nginx/1.20.1") + resp.Header.Set("Date", time.Now().Format(http.TimeFormat)) + if resp.StatusCode == http.StatusOK { + resp.Header.Set("Connection", "keep-alive") + } + } + + if h.logger.IsLevelEnabled(logger.DebugLevel) { + dump, _ := httputil.DumpResponse(resp, false) + h.logger.Debug(string(dump)) + } + + resp.Write(conn) + return +} diff --git a/pkg/handler/http2/metadata.go b/pkg/handler/http2/metadata.go new file mode 100644 index 0000000..b2de45d --- /dev/null +++ b/pkg/handler/http2/metadata.go @@ -0,0 +1,67 @@ +package http2 + +import ( + "strings" + + "github.com/go-gost/gost/pkg/auth" + md "github.com/go-gost/gost/pkg/metadata" +) + +type metadata struct { + authenticator auth.Authenticator + proxyAgent string + retryCount int + probeResist *probeResist + sni bool + enableUDP bool +} + +func (h *http2Handler) parseMetadata(md md.Metadata) error { + const ( + proxyAgent = "proxyAgent" + users = "users" + probeResistKey = "probeResist" + knock = "knock" + retryCount = "retry" + sni = "sni" + enableUDP = "udp" + ) + + h.md.proxyAgent = md.GetString(proxyAgent) + + if v, _ := md.Get(users).([]interface{}); len(v) > 0 { + authenticator := auth.NewLocalAuthenticator(nil) + for _, auth := range v { + if s, _ := auth.(string); s != "" { + ss := strings.SplitN(s, ":", 2) + if len(ss) == 1 { + authenticator.Add(ss[0], "") + } else { + authenticator.Add(ss[0], ss[1]) + } + } + } + h.md.authenticator = authenticator + } + + if v := md.GetString(probeResistKey); v != "" { + if ss := strings.SplitN(v, ":", 2); len(ss) == 2 { + h.md.probeResist = &probeResist{ + Type: ss[0], + Value: ss[1], + Knock: md.GetString(knock), + } + } + } + h.md.retryCount = md.GetInt(retryCount) + h.md.sni = md.GetBool(sni) + h.md.enableUDP = md.GetBool(enableUDP) + + return nil +} + +type probeResist struct { + Type string + Value string + Knock string +} diff --git a/pkg/handler/socks/v5/metadata.go b/pkg/handler/socks/v5/metadata.go index 6d10a7a..13aa447 100644 --- a/pkg/handler/socks/v5/metadata.go +++ b/pkg/handler/socks/v5/metadata.go @@ -39,19 +39,13 @@ func (h *socks5Handler) parseMetadata(md md.Metadata) (err error) { compatibilityMode = "comp" ) - if md.GetString(certFile) != "" || - md.GetString(keyFile) != "" || - md.GetString(caFile) != "" { - h.md.tlsConfig, err = tls_util.LoadTLSConfig( - md.GetString(certFile), - md.GetString(keyFile), - md.GetString(caFile), - ) - if err != nil { - return - } - } else { - h.md.tlsConfig = tls_util.DefaultConfig + h.md.tlsConfig, err = tls_util.LoadServerConfig( + md.GetString(certFile), + md.GetString(keyFile), + md.GetString(caFile), + ) + if err != nil { + return } if v, _ := md.Get(users).([]interface{}); len(v) > 0 { diff --git a/pkg/internal/http2/conn.go b/pkg/internal/http2/conn.go new file mode 100644 index 0000000..8751aa4 --- /dev/null +++ b/pkg/internal/http2/conn.go @@ -0,0 +1,134 @@ +package http2 + +import ( + "context" + "errors" + "net" + "net/http" + "time" +) + +// a dummy HTTP2 client conn used by HTTP2 client connector +type ClientConn struct { + localAddr net.Addr + remoteAddr net.Addr + client *http.Client + onClose func() +} + +func NewClientConn(localAddr, remoteAddr net.Addr, client *http.Client, onClose func()) net.Conn { + return &ClientConn{ + localAddr: localAddr, + remoteAddr: remoteAddr, + client: client, + onClose: onClose, + } +} + +func (c *ClientConn) Client() *http.Client { + return c.client +} + +func (c *ClientConn) Close() error { + if c.onClose != nil { + c.onClose() + } + return nil +} + +func (c *ClientConn) Read(b []byte) (n int, err error) { + return 0, &net.OpError{Op: "read", Net: "nop", Source: nil, Addr: nil, Err: errors.New("read not supported")} +} + +func (c *ClientConn) Write(b []byte) (n int, err error) { + return 0, &net.OpError{Op: "write", Net: "nop", Source: nil, Addr: nil, Err: errors.New("write not supported")} +} + +func (c *ClientConn) LocalAddr() net.Addr { + return c.localAddr +} + +func (c *ClientConn) RemoteAddr() net.Addr { + return c.remoteAddr +} + +func (c *ClientConn) SetDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "nop", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} + +func (c *ClientConn) SetReadDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "nop", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} + +func (c *ClientConn) SetWriteDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "nop", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} + +// a dummy HTTP2 server conn used by HTTP2 handler +type ServerConn struct { + r *http.Request + w http.ResponseWriter + cancel context.CancelFunc +} + +func NewServerConn(w http.ResponseWriter, r *http.Request) *ServerConn { + ctx, cancel := context.WithCancel(r.Context()) + + return &ServerConn{ + r: r.Clone(ctx), + w: w, + cancel: cancel, + } +} + +func (c *ServerConn) Done() <-chan struct{} { + return c.r.Context().Done() +} + +func (c *ServerConn) Request() *http.Request { + return c.r +} + +func (c *ServerConn) Writer() http.ResponseWriter { + return c.w +} + +func (c *ServerConn) Read(b []byte) (n int, err error) { + return 0, &net.OpError{Op: "read", Net: "http2", Source: nil, Addr: nil, Err: errors.New("read not supported")} +} + +func (c *ServerConn) Write(b []byte) (n int, err error) { + return 0, &net.OpError{Op: "write", Net: "http2", Source: nil, Addr: nil, Err: errors.New("write not supported")} +} + +func (c *ServerConn) Close() error { + c.cancel() + + select { + case <-c.r.Context().Done(): + default: + } + return nil +} + +func (c *ServerConn) LocalAddr() net.Addr { + addr, _ := net.ResolveTCPAddr("tcp", c.r.Host) + return addr +} + +func (c *ServerConn) RemoteAddr() net.Addr { + addr, _ := net.ResolveTCPAddr("tcp", c.r.RemoteAddr) + return addr +} + +func (c *ServerConn) SetDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} + +func (c *ServerConn) SetReadDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} + +func (c *ServerConn) SetWriteDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} diff --git a/pkg/listener/http2/h2/listener.go b/pkg/listener/http2/h2/listener.go index 4e11550..204b4c6 100644 --- a/pkg/listener/http2/h2/listener.go +++ b/pkg/listener/http2/h2/listener.go @@ -8,7 +8,6 @@ import ( "time" "github.com/go-gost/gost/pkg/common/util" - tls_util "github.com/go-gost/gost/pkg/common/util/tls" "github.com/go-gost/gost/pkg/listener" "github.com/go-gost/gost/pkg/logger" md "github.com/go-gost/gost/pkg/metadata" @@ -172,16 +171,3 @@ func (l *h2Listener) upgrade(w http.ResponseWriter, r *http.Request) (*conn, err closed: make(chan struct{}), }, nil } - -func (l *h2Listener) parseMetadata(md md.Metadata) (err error) { - l.md.tlsConfig, err = tls_util.LoadTLSConfig( - md.GetString(certFile), - md.GetString(keyFile), - md.GetString(caFile), - ) - if err != nil { - return - } - - return -} diff --git a/pkg/listener/http2/h2/metadata.go b/pkg/listener/http2/h2/metadata.go index 4fa94fc..999475b 100644 --- a/pkg/listener/http2/h2/metadata.go +++ b/pkg/listener/http2/h2/metadata.go @@ -4,18 +4,9 @@ import ( "crypto/tls" "net/http" "time" -) -const ( - path = "path" - certFile = "certFile" - keyFile = "keyFile" - caFile = "caFile" - handshakeTimeout = "handshakeTimeout" - readHeaderTimeout = "readHeaderTimeout" - readBufferSize = "readBufferSize" - writeBufferSize = "writeBufferSize" - connQueueSize = "connQueueSize" + tls_util "github.com/go-gost/gost/pkg/common/util/tls" + md "github.com/go-gost/gost/pkg/metadata" ) const ( @@ -34,3 +25,28 @@ type metadata struct { connQueueSize int keepAlivePeriod time.Duration } + +func (l *h2Listener) parseMetadata(md md.Metadata) (err error) { + const ( + path = "path" + certFile = "certFile" + keyFile = "keyFile" + caFile = "caFile" + handshakeTimeout = "handshakeTimeout" + readHeaderTimeout = "readHeaderTimeout" + readBufferSize = "readBufferSize" + writeBufferSize = "writeBufferSize" + connQueueSize = "connQueueSize" + ) + + l.md.tlsConfig, err = tls_util.LoadServerConfig( + md.GetString(certFile), + md.GetString(keyFile), + md.GetString(caFile), + ) + if err != nil { + return + } + + return +} diff --git a/pkg/listener/http2/listener.go b/pkg/listener/http2/listener.go index b4fa135..edae513 100644 --- a/pkg/listener/http2/listener.go +++ b/pkg/listener/http2/listener.go @@ -6,7 +6,7 @@ import ( "net/http" "github.com/go-gost/gost/pkg/common/util" - tls_util "github.com/go-gost/gost/pkg/common/util/tls" + http2_util "github.com/go-gost/gost/pkg/internal/http2" "github.com/go-gost/gost/pkg/listener" "github.com/go-gost/gost/pkg/logger" md "github.com/go-gost/gost/pkg/metadata" @@ -19,13 +19,13 @@ func init() { } type http2Listener struct { - saddr string - md metadata - server *http.Server - addr net.Addr - connChan chan *conn - errChan chan error - logger logger.Logger + saddr string + md metadata + server *http.Server + addr net.Addr + cqueue chan net.Conn + errChan chan error + logger logger.Logger } func NewListener(opts ...listener.Option) listener.Listener { @@ -61,22 +61,17 @@ func (l *http2Listener) Init(md md.Metadata) (err error) { ln = tls.NewListener( &util.TCPKeepAliveListener{ - TCPListener: ln.(*net.TCPListener), - KeepAlivePeriod: l.md.keepAlivePeriod, + TCPListener: ln.(*net.TCPListener), }, l.md.tlsConfig, ) - queueSize := l.md.connQueueSize - if queueSize <= 0 { - queueSize = defaultQueueSize - } - l.connChan = make(chan *conn, queueSize) + l.cqueue = make(chan net.Conn, l.md.backlog) l.errChan = make(chan error, 1) go func() { if err := l.server.Serve(ln); err != nil { - // log.Log("[http2]", err) + l.logger.Error(err) } }() @@ -86,7 +81,7 @@ func (l *http2Listener) Init(md md.Metadata) (err error) { func (l *http2Listener) Accept() (conn net.Conn, err error) { var ok bool select { - case conn = <-l.connChan: + case conn = <-l.cqueue: case err, ok = <-l.errChan: if !ok { err = listener.ErrClosed @@ -111,30 +106,13 @@ func (l *http2Listener) Close() (err error) { } func (l *http2Listener) handleFunc(w http.ResponseWriter, r *http.Request) { - conn := &conn{ - r: r, - w: w, - closed: make(chan struct{}), - } + conn := http2_util.NewServerConn(w, r) select { - case l.connChan <- conn: + case l.cqueue <- conn: default: - // log.Logf("[http2] %s - %s: connection queue is full", r.RemoteAddr, l.server.Addr) + l.logger.Warnf("connection queue is full, client %s discarded", r.RemoteAddr) return } - <-conn.closed -} - -func (l *http2Listener) parseMetadata(md md.Metadata) (err error) { - l.md.tlsConfig, err = tls_util.LoadTLSConfig( - md.GetString(certFile), - md.GetString(keyFile), - md.GetString(caFile), - ) - if err != nil { - return - } - - return + <-conn.Done() } diff --git a/pkg/listener/http2/metadata.go b/pkg/listener/http2/metadata.go index d53f11b..50fca3c 100644 --- a/pkg/listener/http2/metadata.go +++ b/pkg/listener/http2/metadata.go @@ -4,22 +4,13 @@ import ( "crypto/tls" "net/http" "time" + + tls_util "github.com/go-gost/gost/pkg/common/util/tls" + md "github.com/go-gost/gost/pkg/metadata" ) const ( - path = "path" - certFile = "certFile" - keyFile = "keyFile" - caFile = "caFile" - handshakeTimeout = "handshakeTimeout" - readHeaderTimeout = "readHeaderTimeout" - readBufferSize = "readBufferSize" - writeBufferSize = "writeBufferSize" - connQueueSize = "connQueueSize" -) - -const ( - defaultQueueSize = 128 + defaultBacklog = 128 ) type metadata struct { @@ -31,6 +22,34 @@ type metadata struct { writeBufferSize int enableCompression bool responseHeader http.Header - connQueueSize int - keepAlivePeriod time.Duration + backlog int +} + +func (l *http2Listener) parseMetadata(md md.Metadata) (err error) { + const ( + path = "path" + certFile = "certFile" + keyFile = "keyFile" + caFile = "caFile" + handshakeTimeout = "handshakeTimeout" + readHeaderTimeout = "readHeaderTimeout" + readBufferSize = "readBufferSize" + writeBufferSize = "writeBufferSize" + backlog = "backlog" + ) + + l.md.tlsConfig, err = tls_util.LoadServerConfig( + md.GetString(certFile), + md.GetString(keyFile), + md.GetString(caFile), + ) + if err != nil { + return + } + + l.md.backlog = md.GetInt(backlog) + if l.md.backlog <= 0 { + l.md.backlog = defaultBacklog + } + return } diff --git a/pkg/listener/tls/metadata.go b/pkg/listener/tls/metadata.go index bf6f2c6..bb47f52 100644 --- a/pkg/listener/tls/metadata.go +++ b/pkg/listener/tls/metadata.go @@ -21,19 +21,13 @@ func (l *tlsListener) parseMetadata(md md.Metadata) (err error) { keepAlivePeriod = "keepAlivePeriod" ) - if md.GetString(certFile) != "" || - md.GetString(keyFile) != "" || - md.GetString(caFile) != "" { - l.md.tlsConfig, err = tls_util.LoadTLSConfig( - md.GetString(certFile), - md.GetString(keyFile), - md.GetString(caFile), - ) - if err != nil { - return - } - } else { - l.md.tlsConfig = tls_util.DefaultConfig + l.md.tlsConfig, err = tls_util.LoadServerConfig( + md.GetString(certFile), + md.GetString(keyFile), + md.GetString(caFile), + ) + if err != nil { + return } l.md.keepAlivePeriod = md.GetDuration(keepAlivePeriod) diff --git a/pkg/listener/tls/mux/listener.go b/pkg/listener/tls/mux/listener.go index d7e17f6..799006c 100644 --- a/pkg/listener/tls/mux/listener.go +++ b/pkg/listener/tls/mux/listener.go @@ -4,7 +4,6 @@ import ( "crypto/tls" "net" - tls_util "github.com/go-gost/gost/pkg/common/util/tls" "github.com/go-gost/gost/pkg/listener" "github.com/go-gost/gost/pkg/logger" md "github.com/go-gost/gost/pkg/metadata" @@ -125,16 +124,3 @@ func (l *mtlsListener) Accept() (conn net.Conn, err error) { } return } - -func (l *mtlsListener) parseMetadata(md md.Metadata) (err error) { - l.md.tlsConfig, err = tls_util.LoadTLSConfig( - md.GetString(certFile), - md.GetString(keyFile), - md.GetString(caFile), - ) - if err != nil { - return - } - - return -} diff --git a/pkg/listener/tls/mux/metadata.go b/pkg/listener/tls/mux/metadata.go index a5d0cec..fc9e922 100644 --- a/pkg/listener/tls/mux/metadata.go +++ b/pkg/listener/tls/mux/metadata.go @@ -3,19 +3,9 @@ package mux import ( "crypto/tls" "time" -) -const ( - certFile = "certFile" - keyFile = "keyFile" - caFile = "caFile" - - muxKeepAliveDisabled = "muxKeepAliveDisabled" - muxKeepAlivePeriod = "muxKeepAlivePeriod" - muxKeepAliveTimeout = "muxKeepAliveTimeout" - muxMaxFrameSize = "muxMaxFrameSize" - muxMaxReceiveBuffer = "muxMaxReceiveBuffer" - muxMaxStreamBuffer = "muxMaxStreamBuffer" + tls_util "github.com/go-gost/gost/pkg/common/util/tls" + md "github.com/go-gost/gost/pkg/metadata" ) const ( @@ -34,3 +24,29 @@ type metadata struct { connQueueSize int } + +func (l *mtlsListener) parseMetadata(md md.Metadata) (err error) { + const ( + certFile = "certFile" + keyFile = "keyFile" + caFile = "caFile" + + muxKeepAliveDisabled = "muxKeepAliveDisabled" + muxKeepAlivePeriod = "muxKeepAlivePeriod" + muxKeepAliveTimeout = "muxKeepAliveTimeout" + muxMaxFrameSize = "muxMaxFrameSize" + muxMaxReceiveBuffer = "muxMaxReceiveBuffer" + muxMaxStreamBuffer = "muxMaxStreamBuffer" + ) + + l.md.tlsConfig, err = tls_util.LoadServerConfig( + md.GetString(certFile), + md.GetString(keyFile), + md.GetString(caFile), + ) + if err != nil { + return + } + + return +} diff --git a/pkg/listener/ws/metadata.go b/pkg/listener/ws/metadata.go index a2165cb..7e05664 100644 --- a/pkg/listener/ws/metadata.go +++ b/pkg/listener/ws/metadata.go @@ -41,21 +41,13 @@ func (l *wsListener) parseMetadata(md md.Metadata) (err error) { connQueueSize = "connQueueSize" ) - if l.tlsEnabled { - if md.GetString(certFile) != "" || - md.GetString(keyFile) != "" || - md.GetString(caFile) != "" { - l.md.tlsConfig, err = tls_util.LoadTLSConfig( - md.GetString(certFile), - md.GetString(keyFile), - md.GetString(caFile), - ) - if err != nil { - return - } - } else { - l.md.tlsConfig = tls_util.DefaultConfig - } + l.md.tlsConfig, err = tls_util.LoadServerConfig( + md.GetString(certFile), + md.GetString(keyFile), + md.GetString(caFile), + ) + if err != nil { + return } l.md.path = md.GetString(path) diff --git a/pkg/listener/ws/mux/listener.go b/pkg/listener/ws/mux/listener.go index 822c116..bef8254 100644 --- a/pkg/listener/ws/mux/listener.go +++ b/pkg/listener/ws/mux/listener.go @@ -5,7 +5,6 @@ import ( "net" "net/http" - tls_util "github.com/go-gost/gost/pkg/common/util/tls" ws_util "github.com/go-gost/gost/pkg/common/util/ws" "github.com/go-gost/gost/pkg/listener" "github.com/go-gost/gost/pkg/logger" @@ -111,19 +110,6 @@ func (l *mwsListener) Addr() net.Addr { return l.addr } -func (l *mwsListener) parseMetadata(md md.Metadata) (err error) { - l.md.tlsConfig, err = tls_util.LoadTLSConfig( - md.GetString(certFile), - md.GetString(keyFile), - md.GetString(caFile), - ) - if err != nil { - return - } - - return -} - func (l *mwsListener) upgrade(w http.ResponseWriter, r *http.Request) { conn, err := l.upgrader.Upgrade(w, r, l.md.responseHeader) if err != nil { diff --git a/pkg/listener/ws/mux/metadata.go b/pkg/listener/ws/mux/metadata.go index 8233ec3..b8d48e3 100644 --- a/pkg/listener/ws/mux/metadata.go +++ b/pkg/listener/ws/mux/metadata.go @@ -4,27 +4,9 @@ import ( "crypto/tls" "net/http" "time" -) -const ( - path = "path" - certFile = "certFile" - keyFile = "keyFile" - caFile = "caFile" - handshakeTimeout = "handshakeTimeout" - readHeaderTimeout = "readHeaderTimeout" - readBufferSize = "readBufferSize" - writeBufferSize = "writeBufferSize" - enableCompression = "enableCompression" - responseHeader = "responseHeader" - connQueueSize = "connQueueSize" - - muxKeepAliveDisabled = "muxKeepAliveDisabled" - muxKeepAlivePeriod = "muxKeepAlivePeriod" - muxKeepAliveTimeout = "muxKeepAliveTimeout" - muxMaxFrameSize = "muxMaxFrameSize" - muxMaxReceiveBuffer = "muxMaxReceiveBuffer" - muxMaxStreamBuffer = "muxMaxStreamBuffer" + tls_util "github.com/go-gost/gost/pkg/common/util/tls" + md "github.com/go-gost/gost/pkg/metadata" ) const ( @@ -50,3 +32,37 @@ type metadata struct { muxMaxStreamBuffer int connQueueSize int } + +func (l *mwsListener) parseMetadata(md md.Metadata) (err error) { + const ( + path = "path" + certFile = "certFile" + keyFile = "keyFile" + caFile = "caFile" + handshakeTimeout = "handshakeTimeout" + readHeaderTimeout = "readHeaderTimeout" + readBufferSize = "readBufferSize" + writeBufferSize = "writeBufferSize" + enableCompression = "enableCompression" + responseHeader = "responseHeader" + connQueueSize = "connQueueSize" + + muxKeepAliveDisabled = "muxKeepAliveDisabled" + muxKeepAlivePeriod = "muxKeepAlivePeriod" + muxKeepAliveTimeout = "muxKeepAliveTimeout" + muxMaxFrameSize = "muxMaxFrameSize" + muxMaxReceiveBuffer = "muxMaxReceiveBuffer" + muxMaxStreamBuffer = "muxMaxStreamBuffer" + ) + + l.md.tlsConfig, err = tls_util.LoadServerConfig( + md.GetString(certFile), + md.GetString(keyFile), + md.GetString(caFile), + ) + if err != nil { + return + } + + return +}