package ws import ( "context" "github.com/go-gost/x/util" tls "github.com/refraction-networking/utls" "net" "net/url" "time" "github.com/go-gost/core/dialer" md "github.com/go-gost/core/metadata" ws_util "github.com/go-gost/x/internal/util/ws" "github.com/go-gost/x/registry" "github.com/gorilla/websocket" ) func init() { registry.DialerRegistry().Register("ws", NewDialer) registry.DialerRegistry().Register("wss", NewTLSDialer) } type wsDialer struct { tlsEnabled bool md metadata options dialer.Options } func NewDialer(opts ...dialer.Option) dialer.Dialer { options := dialer.Options{} for _, opt := range opts { opt(&options) } return &wsDialer{ options: options, } } func NewTLSDialer(opts ...dialer.Option) dialer.Dialer { options := dialer.Options{} for _, opt := range opts { opt(&options) } return &wsDialer{ tlsEnabled: true, options: options, } } func (d *wsDialer) Init(md md.Metadata) (err error) { return d.parseMetadata(md) } func (d *wsDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOption) (net.Conn, error) { var options dialer.DialOptions for _, opt := range opts { opt(&options) } conn, err := options.NetDialer.Dial(ctx, "tcp", addr) if err != nil { d.options.Logger.Error(err) } return conn, err } // Handshake implements dialer.Handshaker func (d *wsDialer) Handshake(ctx context.Context, conn net.Conn, options ...dialer.HandshakeOption) (net.Conn, error) { opts := &dialer.HandshakeOptions{} for _, option := range options { option(opts) } if d.md.handshakeTimeout > 0 { conn.SetReadDeadline(time.Now().Add(d.md.handshakeTimeout)) defer conn.SetReadDeadline(time.Time{}) } host := d.md.host if host == "" { host = opts.Addr } dialer := websocket.Dialer{ HandshakeTimeout: d.md.handshakeTimeout, ReadBufferSize: d.md.readBufferSize, WriteBufferSize: d.md.writeBufferSize, EnableCompression: d.md.enableCompression, NetDial: func(net, addr string) (net.Conn, error) { return conn, nil }, } urlObj := url.URL{Scheme: "ws", Host: host, Path: d.md.path} if d.tlsEnabled { urlObj.Scheme = "wss" dialer.TLSClientConfig = d.options.TLSConfig tlsConfig := d.options.TLSConfig dialer.NetDialTLSContext = func(ctx context.Context, network, addr string) (net.Conn, error) { utlsConf := &tls.Config{InsecureSkipVerify: tlsConfig.InsecureSkipVerify, ServerName: tlsConfig.ServerName, ClientAuth: tls.ClientAuthType(tlsConfig.ClientAuth), ClientCAs: tlsConfig.ClientCAs, RootCAs: tlsConfig.RootCAs} if len(tlsConfig.Certificates) > 0 { for _, certificate := range tlsConfig.Certificates { utlsConf.Certificates = append(utlsConf.Certificates, tls.Certificate{ Certificate: certificate.Certificate, PrivateKey: certificate.PrivateKey, OCSPStaple: certificate.OCSPStaple, SignedCertificateTimestamps: certificate.SignedCertificateTimestamps, Leaf: certificate.Leaf, }) } } var client *tls.UConn if d.md.useH2 { client = tls.UClient(conn, utlsConf, tls.HelloChrome_Auto) } else { client = tls.UClient(conn, utlsConf, tls.HelloCustom) client.ApplyPreset(util.NewWsSpec()) } err := client.Handshake() if err != nil { return nil, err } return client, nil } } urlStr, errUnescape := url.QueryUnescape(urlObj.String()) if errUnescape != nil { d.options.Logger.Debugf("[ws] URL QueryUnescape Error URL.String() -> %s", urlObj.String()) } c, resp, err := dialer.DialContext(ctx, urlStr, d.md.header) if err != nil { return nil, err } resp.Body.Close() cc := ws_util.Conn(c) if d.md.keepaliveInterval > 0 { d.options.Logger.Debugf("keepalive is enabled, ttl: %v", d.md.keepaliveInterval) c.SetReadDeadline(time.Now().Add(d.md.keepaliveInterval * 2)) c.SetPongHandler(func(string) error { c.SetReadDeadline(time.Now().Add(d.md.keepaliveInterval * 2)) d.options.Logger.Debugf("pong: set read deadline: %v", d.md.keepaliveInterval*2) return nil }) go d.keepalive(cc) } return cc, nil } func (d *wsDialer) keepalive(conn ws_util.WebsocketConn) { ticker := time.NewTicker(d.md.keepaliveInterval) defer ticker.Stop() for range ticker.C { d.options.Logger.Debug("send ping") conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil { // d.options.Logger.Error(err) return } conn.SetWriteDeadline(time.Time{}) } }