diff --git a/.gitignore b/.gitignore index 2072667..07cb34b 100644 --- a/.gitignore +++ b/.gitignore @@ -33,3 +33,5 @@ _testmain.go cmd/gost/gost snap + +*.pem \ No newline at end of file diff --git a/cmd/gost/register.go b/cmd/gost/register.go index 3fa717e..b1a531e 100644 --- a/cmd/gost/register.go +++ b/cmd/gost/register.go @@ -15,6 +15,7 @@ import ( // Register dialers _ "github.com/go-gost/gost/pkg/dialer/ftcp" _ "github.com/go-gost/gost/pkg/dialer/http2" + _ "github.com/go-gost/gost/pkg/dialer/http2/h2" _ "github.com/go-gost/gost/pkg/dialer/tcp" _ "github.com/go-gost/gost/pkg/dialer/udp" diff --git a/cmd/gost/tls.go b/cmd/gost/tls.go index 239d4e0..53dcc57 100644 --- a/cmd/gost/tls.go +++ b/cmd/gost/tls.go @@ -34,7 +34,7 @@ func buildDefaultTLSConfig(cfg *config.TLSConfig) { } log.Warn("load TLS certificate files failed, use random generated certificate") } else { - log.Debug("load TLS certificate files OK") + log.Info("load TLS certificate files OK") } tls_util.DefaultConfig = tlsConfig } diff --git a/gost.yml b/gost.yml index ff6451a..bf725ba 100644 --- a/gost.yml +++ b/gost.yml @@ -7,6 +7,11 @@ profiling: addr: ":6060" enabled: true +# tls: +# cert: "cert.pem" +# key: "key.pem" +# ca: "root.ca" + services: - name: http+tcp url: "http://gost:gost@:8000" diff --git a/pkg/connector/http2/connector.go b/pkg/connector/http2/connector.go index 49ae623..a72461e 100644 --- a/pkg/connector/http2/connector.go +++ b/pkg/connector/http2/connector.go @@ -61,15 +61,14 @@ func (c *http2Connector) Connect(ctx context.Context, conn net.Conn, network, ad pr, pw := io.Pipe() req := &http.Request{ - Method: http.MethodConnect, - URL: &url.URL{Scheme: "https", Host: conn.RemoteAddr().String()}, - Host: address, - ProtoMajor: 2, - ProtoMinor: 0, - Proto: "HTTP/2.0", - Header: make(http.Header), - Body: pr, - ContentLength: -1, + Method: http.MethodConnect, + URL: &url.URL{Scheme: "https", Host: conn.RemoteAddr().String()}, + Host: address, + ProtoMajor: 2, + ProtoMinor: 0, + Header: make(http.Header), + Body: pr, + // ContentLength: -1, } if c.md.UserAgent != "" { req.Header.Set("User-Agent", c.md.UserAgent) diff --git a/pkg/dialer/http2/dialer.go b/pkg/dialer/http2/dialer.go index 844e5e9..b7398e0 100644 --- a/pkg/dialer/http2/dialer.go +++ b/pkg/dialer/http2/dialer.go @@ -80,20 +80,6 @@ func (d *http2Dialer) Dial(ctx context.Context, address string, opts ...dialer.D ExpectContinueTimeout: 1 * time.Second, }, } - /* - client = &http.Client{ - Transport: &http2.Transport{ - TLSClientConfig: d.md.tlsConfig, - DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { - conn, err := d.dial(ctx, network, addr, options) - if err != nil { - return nil, err - } - return tls_util.WrapTLSClient(conn, cfg, time.Duration(0)) - }, - }, - } - */ d.clients[address] = client } diff --git a/pkg/dialer/http2/h2/conn.go b/pkg/dialer/http2/h2/conn.go new file mode 100644 index 0000000..6b786b4 --- /dev/null +++ b/pkg/dialer/http2/h2/conn.go @@ -0,0 +1,54 @@ +package h2 + +import ( + "errors" + "io" + "net" + "time" +) + +// HTTP2 connection, wrapped up just like a net.Conn. +type http2Conn struct { + r io.Reader + w io.Writer + remoteAddr net.Addr + localAddr net.Addr +} + +func (c *http2Conn) Read(b []byte) (n int, err error) { + return c.r.Read(b) +} + +func (c *http2Conn) Write(b []byte) (n int, err error) { + return c.w.Write(b) +} + +func (c *http2Conn) Close() (err error) { + if r, ok := c.r.(io.Closer); ok { + err = r.Close() + } + if w, ok := c.w.(io.Closer); ok { + err = w.Close() + } + return +} + +func (c *http2Conn) LocalAddr() net.Addr { + return c.localAddr +} + +func (c *http2Conn) RemoteAddr() net.Addr { + return c.remoteAddr +} + +func (c *http2Conn) SetDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "h2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} + +func (c *http2Conn) SetReadDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "h2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} + +func (c *http2Conn) SetWriteDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "h2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} diff --git a/pkg/dialer/http2/h2/dialer.go b/pkg/dialer/http2/h2/dialer.go new file mode 100644 index 0000000..1b6e766 --- /dev/null +++ b/pkg/dialer/http2/h2/dialer.go @@ -0,0 +1,190 @@ +package h2 + +import ( + "context" + "crypto/tls" + "errors" + "io" + "net" + "net/http" + "net/http/httputil" + "net/url" + "sync" + "time" + + "github.com/go-gost/gost/pkg/dialer" + "github.com/go-gost/gost/pkg/logger" + md "github.com/go-gost/gost/pkg/metadata" + "github.com/go-gost/gost/pkg/registry" + "golang.org/x/net/http2" +) + +func init() { + registry.RegisterDialer("h2", NewTLSDialer) + registry.RegisterDialer("h2c", NewDialer) +} + +type h2Dialer struct { + clients map[string]*http.Client + clientMutex sync.Mutex + logger logger.Logger + md metadata + h2c bool +} + +func NewDialer(opts ...dialer.Option) dialer.Dialer { + options := &dialer.Options{} + for _, opt := range opts { + opt(options) + } + + return &h2Dialer{ + clients: make(map[string]*http.Client), + logger: options.Logger, + h2c: true, + } +} + +func NewTLSDialer(opts ...dialer.Option) dialer.Dialer { + options := &dialer.Options{} + for _, opt := range opts { + opt(options) + } + + return &h2Dialer{ + clients: make(map[string]*http.Client), + logger: options.Logger, + } +} + +func (d *h2Dialer) Init(md md.Metadata) (err error) { + if err = d.parseMetadata(md); err != nil { + return + } + + return nil +} + +// IsMultiplex implements dialer.Multiplexer interface. +func (d *h2Dialer) IsMultiplex() bool { + return true +} + +func (d *h2Dialer) Dial(ctx context.Context, address string, opts ...dialer.DialOption) (net.Conn, error) { + raddr, err := net.ResolveTCPAddr("tcp", address) + if err != nil { + d.logger.Error(err) + return nil, err + } + + d.clientMutex.Lock() + + client, ok := d.clients[address] + if !ok { + options := &dialer.DialOptions{} + for _, opt := range opts { + opt(options) + } + + client = &http.Client{} + if d.h2c { + client.Transport = &http2.Transport{ + DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { + return d.dial(ctx, network, addr, options) + }, + } + } else { + client.Transport = &http.Transport{ + TLSClientConfig: d.md.tlsConfig, + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return d.dial(ctx, network, addr, options) + }, + ForceAttemptHTTP2: true, + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + } + } + + d.clients[address] = client + } + d.clientMutex.Unlock() + + host := d.md.host + if host == "" { + host = address + } + + pr, pw := io.Pipe() + req := &http.Request{ + Method: http.MethodConnect, + URL: &url.URL{Scheme: "https", Host: host}, + Header: make(http.Header), + ProtoMajor: 2, + ProtoMinor: 0, + Body: pr, + Host: host, + // ContentLength: -1, + } + if d.md.path != "" { + req.Method = http.MethodGet + req.URL.Path = d.md.path + } + + if d.logger.IsLevelEnabled(logger.DebugLevel) { + dump, _ := httputil.DumpRequest(req, false) + d.logger.Debug(string(dump)) + } + + resp, err := client.Do(req) + if err != nil { + return nil, err + } + + if d.logger.IsLevelEnabled(logger.DebugLevel) { + dump, _ := httputil.DumpResponse(resp, false) + d.logger.Debug(string(dump)) + } + + if resp.StatusCode != http.StatusOK { + resp.Body.Close() + return nil, errors.New(resp.Status) + } + + conn := &http2Conn{ + r: resp.Body, + w: pw, + remoteAddr: raddr, + localAddr: &net.TCPAddr{IP: net.IPv4zero, Port: 0}, + } + return conn, nil +} + +func (d *h2Dialer) dial(ctx context.Context, network, addr string, opts *dialer.DialOptions) (net.Conn, error) { + dial := opts.DialFunc + if dial != nil { + conn, err := dial(ctx, addr) + if err != nil { + d.logger.Error(err) + } else { + d.logger.WithFields(map[string]interface{}{ + "src": conn.LocalAddr().String(), + "dst": addr, + }).Debug("dial with dial func") + } + return conn, err + } + + var netd net.Dialer + conn, err := netd.DialContext(ctx, network, addr) + if err != nil { + d.logger.Error(err) + } else { + d.logger.WithFields(map[string]interface{}{ + "src": conn.LocalAddr().String(), + "dst": addr, + }).Debugf("dial direct %s/%s", addr, network) + } + return conn, err +} diff --git a/pkg/dialer/http2/h2/metadata.go b/pkg/dialer/http2/h2/metadata.go new file mode 100644 index 0000000..bb36c71 --- /dev/null +++ b/pkg/dialer/http2/h2/metadata.go @@ -0,0 +1,43 @@ +package h2 + +import ( + "crypto/tls" + "net" + + tls_util "github.com/go-gost/gost/pkg/common/util/tls" + md "github.com/go-gost/gost/pkg/metadata" +) + +type metadata struct { + path string + host string + tlsConfig *tls.Config +} + +func (d *h2Dialer) parseMetadata(md md.Metadata) (err error) { + const ( + certFile = "certFile" + keyFile = "keyFile" + caFile = "caFile" + secure = "secure" + serverName = "serverName" + path = "path" + ) + + d.md.host = md.GetString(serverName) + sn, _, _ := net.SplitHostPort(d.md.host) + if sn == "" { + sn = "localhost" + } + d.md.tlsConfig, err = tls_util.LoadClientConfig( + md.GetString(certFile), + md.GetString(keyFile), + md.GetString(caFile), + md.GetBool(secure), + sn, + ) + + d.md.path = md.GetString(path) + + return +} diff --git a/pkg/listener/http2/h2/listener.go b/pkg/listener/http2/h2/listener.go index 204b4c6..92bd4eb 100644 --- a/pkg/listener/http2/h2/listener.go +++ b/pkg/listener/http2/h2/listener.go @@ -5,28 +5,30 @@ import ( "errors" "net" "net/http" - "time" + "net/http/httputil" - "github.com/go-gost/gost/pkg/common/util" "github.com/go-gost/gost/pkg/listener" "github.com/go-gost/gost/pkg/logger" md "github.com/go-gost/gost/pkg/metadata" "github.com/go-gost/gost/pkg/registry" "golang.org/x/net/http2" + "golang.org/x/net/http2/h2c" ) func init() { - registry.RegisterListener("h2", NewListener) + registry.RegisterListener("h2c", NewListener) + registry.RegisterListener("h2", NewTLSListener) } type h2Listener struct { - addr string - net.Listener - md metadata - server *http2.Server - connChan chan net.Conn - errChan chan error - logger logger.Logger + server *http.Server + saddr string + addr net.Addr + cqueue chan net.Conn + errChan chan error + logger logger.Logger + md metadata + h2c bool } func NewListener(opts ...listener.Option) listener.Listener { @@ -35,7 +37,19 @@ func NewListener(opts ...listener.Option) listener.Listener { opt(options) } return &h2Listener{ - addr: options.Addr, + saddr: options.Addr, + logger: options.Logger, + h2c: true, + } +} + +func NewTLSListener(opts ...listener.Option) listener.Listener { + options := &listener.Options{} + for _, opt := range opts { + opt(options) + } + return &h2Listener{ + saddr: options.Addr, logger: options.Logger, } } @@ -45,36 +59,45 @@ func (l *h2Listener) Init(md md.Metadata) (err error) { return } - ln, err := net.Listen("tcp", l.addr) - if err != nil { - return - } - l.Listener = &util.TCPKeepAliveListener{ - TCPListener: ln.(*net.TCPListener), - KeepAlivePeriod: l.md.keepAlivePeriod, - } - // TODO: tune http2 server config - l.server = &http2.Server{ - // MaxConcurrentStreams: 1000, - PermitProhibitedCipherSuites: true, - IdleTimeout: 5 * time.Minute, + l.server = &http.Server{ + Addr: l.saddr, } - queueSize := l.md.connQueueSize - if queueSize <= 0 { - queueSize = defaultQueueSize + ln, err := net.Listen("tcp", l.saddr) + if err != nil { + return err } - l.connChan = make(chan net.Conn, queueSize) + l.addr = ln.Addr() + + if l.h2c { + l.server.Handler = h2c.NewHandler( + http.HandlerFunc(l.handleFunc), &http2.Server{}) + } else { + l.server.Handler = http.HandlerFunc(l.handleFunc) + l.server.TLSConfig = l.md.tlsConfig + if err := http2.ConfigureServer(l.server, nil); err != nil { + ln.Close() + return err + } + ln = tls.NewListener(ln, l.md.tlsConfig) + } + + l.cqueue = make(chan net.Conn, l.md.backlog) l.errChan = make(chan error, 1) - go l.listenLoop() + go func() { + if err := l.server.Serve(ln); err != nil { + l.logger.Error(err) + } + }() + return } func (l *h2Listener) Accept() (conn net.Conn, err error) { var ok bool select { - case conn = <-l.connChan: + case conn = <-l.cqueue: case err, ok = <-l.errChan: if !ok { err = listener.ErrClosed @@ -83,58 +106,36 @@ func (l *h2Listener) Accept() (conn net.Conn, err error) { return } -func (l *h2Listener) listenLoop() { - for { - conn, err := l.Listener.Accept() - if err != nil { - // log.Log("[http2] accept:", err) - l.errChan <- err - close(l.errChan) - return - } - go l.handleLoop(conn) - } +func (l *h2Listener) Addr() net.Addr { + return l.addr } -func (l *h2Listener) handleLoop(conn net.Conn) { - if l.md.tlsConfig != nil { - tlsConn := tls.Server(conn, l.md.tlsConfig) - // NOTE: HTTP2 server will check the TLS version, - // so we must ensure that the TLS connection is handshake completed. - if err := tlsConn.Handshake(); err != nil { - // log.Logf("[http2] %s - %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err) - return - } - conn = tlsConn +func (l *h2Listener) Close() (err error) { + select { + case <-l.errChan: + default: + err = l.server.Close() + l.errChan <- err + close(l.errChan) } - - opt := http2.ServeConnOpts{ - Handler: http.HandlerFunc(l.handleFunc), - } - l.server.ServeConn(conn, &opt) + return nil } func (l *h2Listener) handleFunc(w http.ResponseWriter, r *http.Request) { - /* - log.Logf("[http2] %s -> %s %s %s %s", - r.RemoteAddr, r.Host, r.Method, r.RequestURI, r.Proto) - if Debug { - dump, _ := httputil.DumpRequest(r, false) - log.Log("[http2]", string(dump)) - } - */ - // w.Header().Set("Proxy-Agent", "gost/"+Version) + if l.logger.IsLevelEnabled(logger.DebugLevel) { + dump, _ := httputil.DumpRequest(r, false) + l.logger.Debug(string(dump)) + } conn, err := l.upgrade(w, r) if err != nil { - // log.Logf("[http2] %s - %s %s %s %s: %s", - // r.RemoteAddr, r.Host, r.Method, r.RequestURI, r.Proto, err) + l.logger.Error(err) return } select { - case l.connChan <- conn: + case l.cqueue <- conn: default: conn.Close() - // log.Logf("[http2] %s - %s: connection queue is full", conn.RemoteAddr(), conn.LocalAddr()) + l.logger.Warnf("connection queue is full, client %s discarded", r.RemoteAddr) } <-conn.closed // NOTE: we need to wait for streaming end, or the connection will be closed @@ -166,7 +167,7 @@ func (l *h2Listener) upgrade(w http.ResponseWriter, r *http.Request) (*conn, err return &conn{ r: r.Body, w: flushWriter{w}, - localAddr: l.Listener.Addr(), + localAddr: l.addr, remoteAddr: remoteAddr, closed: make(chan struct{}), }, nil diff --git a/pkg/listener/http2/h2/metadata.go b/pkg/listener/http2/h2/metadata.go index 999475b..80ffc2c 100644 --- a/pkg/listener/http2/h2/metadata.go +++ b/pkg/listener/http2/h2/metadata.go @@ -2,41 +2,28 @@ package h2 import ( "crypto/tls" - "net/http" - "time" tls_util "github.com/go-gost/gost/pkg/common/util/tls" md "github.com/go-gost/gost/pkg/metadata" ) const ( - defaultQueueSize = 128 + defaultBacklog = 128 ) type metadata struct { - path string - tlsConfig *tls.Config - handshakeTimeout time.Duration - readHeaderTimeout time.Duration - readBufferSize int - writeBufferSize int - enableCompression bool - responseHeader http.Header - connQueueSize int - keepAlivePeriod time.Duration + path string + tlsConfig *tls.Config + backlog int } func (l *h2Listener) parseMetadata(md md.Metadata) (err error) { const ( - path = "path" - certFile = "certFile" - keyFile = "keyFile" - caFile = "caFile" - handshakeTimeout = "handshakeTimeout" - readHeaderTimeout = "readHeaderTimeout" - readBufferSize = "readBufferSize" - writeBufferSize = "writeBufferSize" - connQueueSize = "connQueueSize" + path = "path" + certFile = "certFile" + keyFile = "keyFile" + caFile = "caFile" + backlog = "backlog" ) l.md.tlsConfig, err = tls_util.LoadServerConfig( @@ -48,5 +35,11 @@ func (l *h2Listener) parseMetadata(md md.Metadata) (err error) { return } + l.md.backlog = md.GetInt(backlog) + if l.md.backlog <= 0 { + l.md.backlog = defaultBacklog + } + + l.md.path = md.GetString(path) return } diff --git a/pkg/listener/http2/listener.go b/pkg/listener/http2/listener.go index edae513..0804f76 100644 --- a/pkg/listener/http2/listener.go +++ b/pkg/listener/http2/listener.go @@ -19,13 +19,13 @@ func init() { } type http2Listener struct { - saddr string - md metadata server *http.Server + saddr string addr net.Addr cqueue chan net.Conn errChan chan error logger logger.Logger + md metadata } func NewListener(opts ...listener.Option) listener.Listener {