// Package ws implements the "ws" and "wss" dialers, which carry a connection
// over a WebSocket. The "wss" variant performs its TLS handshake with uTLS
// using the fixed ClientHello returned by newWsSpec.
package ws

import (
	"context"
	"net"
	"net/url"
	"time"

	"github.com/go-gost/core/dialer"
	md "github.com/go-gost/core/metadata"
	ws_util "github.com/go-gost/x/internal/util/ws"
	"github.com/go-gost/x/registry"
	"github.com/gorilla/websocket"
	tls "github.com/refraction-networking/utls"
)

func init() {
	registry.DialerRegistry().Register("ws", NewDialer)
	registry.DialerRegistry().Register("wss", NewTLSDialer)
}

// wsDialer opens a plain TCP connection in Dial and upgrades it to a
// WebSocket in Handshake. tlsEnabled selects the "wss" scheme, in which case
// the TLS layer is negotiated with uTLS.
type wsDialer struct {
	tlsEnabled bool
	md         metadata
	options    dialer.Options
}

// NewDialer creates the "ws" dialer (WebSocket without TLS).
func NewDialer(opts ...dialer.Option) dialer.Dialer {
	return newWsDialer(false, opts)
}

// NewTLSDialer creates the "wss" dialer (WebSocket over a uTLS connection).
func NewTLSDialer(opts ...dialer.Option) dialer.Dialer {
	return newWsDialer(true, opts)
}

// newWsDialer applies opts and returns a wsDialer with the given TLS mode.
func newWsDialer(tlsEnabled bool, opts []dialer.Option) *wsDialer {
	var options dialer.Options
	for _, opt := range opts {
		opt(&options)
	}
	return &wsDialer{
		tlsEnabled: tlsEnabled,
		options:    options,
	}
}

// Init parses the dialer metadata into d.md via parseMetadata (defined
// elsewhere in this package).
func (d *wsDialer) Init(md md.Metadata) (err error) {
	return d.parseMetadata(md)
}

// Dial opens a TCP connection to addr using the NetDialer carried in the dial
// options. A dial error is logged and returned together with the (nil) conn.
func (d *wsDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOption) (net.Conn, error) {
	var options dialer.DialOptions
	for _, opt := range opts {
		opt(&options)
	}

	conn, err := options.NetDialer.Dial(ctx, "tcp", addr)
	if err != nil {
		d.options.Logger.Error(err)
	}
	return conn, err
}

// Handshake implements dialer.Handshaker. It runs the WebSocket opening
// handshake over the already-established conn and returns the resulting
// WebSocket connection wrapped as a net.Conn.
//
// The request URL's host is the "host" metadata value when set, otherwise
// the handshake option Addr; its path is the configured path. For the "wss"
// dialer the TLS layer is negotiated with uTLS over conn using newWsSpec.
// When a keepalive interval is configured, a goroutine sends periodic pings
// and every pong pushes the read deadline out to twice that interval.
func (d *wsDialer) Handshake(ctx context.Context, conn net.Conn, options ...dialer.HandshakeOption) (net.Conn, error) {
	opts := &dialer.HandshakeOptions{}
	for _, option := range options {
		option(opts)
	}

	if d.md.handshakeTimeout > 0 {
		// Bounds reads for the whole handshake, including the uTLS handshake
		// below (client.Handshake takes no context). The deadline is cleared
		// explicitly after DialContext rather than with defer: a deferred reset
		// would run after the keepalive read deadline is installed and erase it.
		conn.SetReadDeadline(time.Now().Add(d.md.handshakeTimeout))
	}

	host := d.md.host
	if host == "" {
		host = opts.Addr
	}

	wsd := websocket.Dialer{
		HandshakeTimeout:  d.md.handshakeTimeout,
		ReadBufferSize:    d.md.readBufferSize,
		WriteBufferSize:   d.md.writeBufferSize,
		EnableCompression: d.md.enableCompression,
		// Always hand back the connection that was already dialed.
		NetDial: func(network, addr string) (net.Conn, error) {
			return conn, nil
		},
	}

	u := url.URL{Scheme: "ws", Host: host, Path: d.md.path}
	if d.tlsEnabled {
		u.Scheme = "wss"
		wsd.TLSClientConfig = d.options.TLSConfig

		tlsConfig := d.options.TLSConfig
		// The TLS handshake is done here with uTLS; gorilla treats the
		// returned conn as already secured.
		wsd.NetDialTLSContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
			uconf := &tls.Config{}
			if tlsConfig != nil {
				uconf.InsecureSkipVerify = tlsConfig.InsecureSkipVerify
				uconf.ServerName = tlsConfig.ServerName
				uconf.ClientAuth = tls.ClientAuthType(tlsConfig.ClientAuth)
				uconf.ClientCAs = tlsConfig.ClientCAs
				uconf.RootCAs = tlsConfig.RootCAs
				for _, cert := range tlsConfig.Certificates {
					uconf.Certificates = append(uconf.Certificates, tls.Certificate{
						Certificate:                 cert.Certificate,
						PrivateKey:                  cert.PrivateKey,
						OCSPStaple:                  cert.OCSPStaple,
						SignedCertificateTimestamps: cert.SignedCertificateTimestamps,
						Leaf:                        cert.Leaf,
					})
				}
			}

			client := tls.UClient(conn, uconf, tls.HelloCustom)
			if err := client.ApplyPreset(newWsSpec()); err != nil {
				return nil, err
			}
			if err := client.Handshake(); err != nil {
				return nil, err
			}
			return client, nil
		}
	}

	// url.URL.String escapes characters such as '?' inside Path; unescaping
	// presumably lets a query string embedded in the configured path reach the
	// server verbatim. PathUnescape (unlike QueryUnescape) keeps '+' intact.
	rawURL := u.String()
	urlStr, err := url.PathUnescape(rawURL)
	if err != nil {
		d.options.Logger.Debugf("[ws] URL PathUnescape error, using escaped URL %s: %v", rawURL, err)
		urlStr = rawURL
	}

	c, resp, err := wsd.DialContext(ctx, urlStr, d.md.header)
	if d.md.handshakeTimeout > 0 {
		conn.SetReadDeadline(time.Time{})
	}
	if err != nil {
		return nil, err
	}
	resp.Body.Close()

	cc := ws_util.Conn(c)

	if interval := d.md.keepaliveInterval; interval > 0 {
		d.options.Logger.Debugf("keepalive is enabled, ttl: %v", interval)
		c.SetReadDeadline(time.Now().Add(interval * 2))
		c.SetPongHandler(func(string) error {
			c.SetReadDeadline(time.Now().Add(interval * 2))
			d.options.Logger.Debugf("pong: set read deadline: %v", interval*2)
			return nil
		})
		go d.keepalive(cc)
	}

	return cc, nil
}

// keepalive sends a WebSocket ping on conn every keepaliveInterval, giving
// each ping write a 10-second deadline. It logs and returns on the first
// failed write (presumably once the connection has been closed).
func (d *wsDialer) keepalive(conn ws_util.WebsocketConn) {
	ticker := time.NewTicker(d.md.keepaliveInterval)
	defer ticker.Stop()

	for range ticker.C {
		d.options.Logger.Debug("send ping")
		conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
		if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil {
			d.options.Logger.Error(err)
			return
		}
		conn.SetWriteDeadline(time.Time{})
	}
}

// newWsSpec returns the uTLS ClientHello used by the "wss" dialer. ALPN
// advertises only http/1.1, since the WebSocket upgrade is an HTTP/1.1
// request. NOTE(review): the extension set (GREASE, cert compression, ALPS,
// BoringSSL-style padding) looks modeled on a Chromium ClientHello — confirm
// which browser version it is meant to mimic before changing it.
func newWsSpec() *tls.ClientHelloSpec {
	return &tls.ClientHelloSpec{
		CipherSuites: []uint16{
			tls.GREASE_PLACEHOLDER,
			tls.TLS_AES_128_GCM_SHA256,
			tls.TLS_AES_256_GCM_SHA384,
			tls.TLS_CHACHA20_POLY1305_SHA256,
			tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
			tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
			tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
			tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
			tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
			tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
			tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
			tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
			tls.TLS_RSA_WITH_AES_128_GCM_SHA256,
			tls.TLS_RSA_WITH_AES_256_GCM_SHA384,
			tls.TLS_RSA_WITH_AES_128_CBC_SHA,
			tls.TLS_RSA_WITH_AES_256_CBC_SHA,
		},
		CompressionMethods: []byte{
			0x00, // compressionNone
		},
		Extensions: []tls.TLSExtension{
			&tls.UtlsGREASEExtension{},
			&tls.SNIExtension{},
			&tls.ExtendedMasterSecretExtension{},
			&tls.RenegotiationInfoExtension{Renegotiation: tls.RenegotiateOnceAsClient},
			&tls.SupportedCurvesExtension{Curves: []tls.CurveID{
				tls.GREASE_PLACEHOLDER,
				tls.X25519,
				tls.CurveP256,
				tls.CurveP384,
			}},
			&tls.SupportedPointsExtension{SupportedPoints: []byte{
				0x00, // pointFormatUncompressed
			}},
			&tls.SessionTicketExtension{},
			&tls.ALPNExtension{AlpnProtocols: []string{"http/1.1"}},
			&tls.StatusRequestExtension{},
			&tls.SignatureAlgorithmsExtension{SupportedSignatureAlgorithms: []tls.SignatureScheme{
				tls.ECDSAWithP256AndSHA256,
				tls.PSSWithSHA256,
				tls.PKCS1WithSHA256,
				tls.ECDSAWithP384AndSHA384,
				tls.PSSWithSHA384,
				tls.PKCS1WithSHA384,
				tls.PSSWithSHA512,
				tls.PKCS1WithSHA512,
			}},
			&tls.SCTExtension{},
			&tls.KeyShareExtension{KeyShares: []tls.KeyShare{
				{Group: tls.CurveID(tls.GREASE_PLACEHOLDER), Data: []byte{0}},
				{Group: tls.X25519},
			}},
			&tls.PSKKeyExchangeModesExtension{Modes: []uint8{
				tls.PskModeDHE,
			}},
			&tls.SupportedVersionsExtension{Versions: []uint16{
				tls.GREASE_PLACEHOLDER,
				tls.VersionTLS13,
				tls.VersionTLS12,
			}},
			&tls.UtlsCompressCertExtension{Algorithms: []tls.CertCompressionAlgo{
				tls.CertCompressionBrotli,
			}},
			&tls.ApplicationSettingsExtension{SupportedProtocols: []string{"h2"}},
			&tls.UtlsGREASEExtension{},
			&tls.UtlsPaddingExtension{GetPaddingLen: tls.BoringPaddingStyle},
		},
	}
}