diff --git a/server/listener/http2/conn.go b/server/listener/http2/conn.go new file mode 100644 index 0000000..c145860 --- /dev/null +++ b/server/listener/http2/conn.go @@ -0,0 +1,54 @@ +package http2 + +import ( + "errors" + "net" + "net/http" + "time" +) + +// a dummy HTTP2 server conn used by HTTP2 handler +type conn struct { + r *http.Request + w http.ResponseWriter + closed chan struct{} +} + +func (c *conn) Read(b []byte) (n int, err error) { + return 0, &net.OpError{Op: "read", Net: "http2", Source: nil, Addr: nil, Err: errors.New("read not supported")} +} + +func (c *conn) Write(b []byte) (n int, err error) { + return 0, &net.OpError{Op: "write", Net: "http2", Source: nil, Addr: nil, Err: errors.New("write not supported")} +} + +func (c *conn) Close() error { + select { + case <-c.closed: + default: + close(c.closed) + } + return nil +} + +func (c *conn) LocalAddr() net.Addr { + addr, _ := net.ResolveTCPAddr("tcp", c.r.Host) + return addr +} + +func (c *conn) RemoteAddr() net.Addr { + addr, _ := net.ResolveTCPAddr("tcp", c.r.RemoteAddr) + return addr +} + +func (c *conn) SetDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} + +func (c *conn) SetReadDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} + +func (c *conn) SetWriteDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} diff --git a/server/listener/http2/h2/conn.go b/server/listener/http2/h2/conn.go new file mode 100644 index 0000000..10f8d4c --- /dev/null +++ b/server/listener/http2/h2/conn.go @@ -0,0 +1,89 @@ +package h2 + +import ( + "errors" + "io" + "net" + "net/http" + "time" +) + +// HTTP2 connection, wrapped up just like a net.Conn +type conn struct { + r io.Reader + w io.Writer + remoteAddr net.Addr + localAddr net.Addr + closed chan struct{} +} + +func (c *conn) Read(b []byte) (n int, err error) { + return c.r.Read(b) +} + +func (c *conn) Write(b []byte) (n int, err error) { + return c.w.Write(b) +} + +func (c *conn) Close() (err error) { + select { + case <-c.closed: + return + default: + close(c.closed) + } + if rc, ok := c.r.(io.Closer); ok { + err = rc.Close() + } + if w, ok := c.w.(io.Closer); ok { + err = w.Close() + } + return +} + +func (c *conn) LocalAddr() net.Addr { + return c.localAddr +} + +func (c *conn) RemoteAddr() net.Addr { + return c.remoteAddr +} + +func (c *conn) SetDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} + +func (c *conn) SetReadDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} + +func (c *conn) SetWriteDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} + +type flushWriter struct { + w io.Writer +} + +func (fw flushWriter) Write(p []byte) (n int, err error) { + defer func() { + if r := recover(); r != nil { + if s, ok := r.(string); ok { + err = errors.New(s) + // log.Log("[http2]", err) + return + } + err = r.(error) + } + }() + + n, err = fw.w.Write(p) + if err != nil { + // log.Log("flush writer:", err) + return + } + if f, ok := fw.w.(http.Flusher); ok { + f.Flush() + } + return +} diff --git a/server/listener/http2/h2/listener.go b/server/listener/http2/h2/listener.go new file mode 100644 index 0000000..b736089 --- /dev/null +++ b/server/listener/http2/h2/listener.go @@ -0,0 +1,186 @@ +package h2 + +import ( + "crypto/tls" + "errors" + "net" + "net/http" + "time" + + "github.com/go-gost/gost/logger" + "github.com/go-gost/gost/server/listener" + "github.com/go-gost/gost/utils" + "golang.org/x/net/http2" +) + +var ( + _ listener.Listener = (*Listener)(nil) +) + +type Listener struct { + net.Listener + md metadata + server *http2.Server + connChan chan net.Conn + errChan chan error + logger logger.Logger +} + +func NewListener(opts ...listener.Option) *Listener { + options := &listener.Options{} + for _, opt := range opts { + opt(options) + } + return &Listener{ + logger: options.Logger, + } +} + +func (l *Listener) Init(md listener.Metadata) (err error) { + l.md, err = l.parseMetadata(md) + if err != nil { + return + } + + ln, err := net.Listen("tcp", l.md.addr) + if err != nil { + return + } + l.Listener = &utils.TCPKeepAliveListener{ + TCPListener: ln.(*net.TCPListener), + KeepAlivePeriod: l.md.keepAlivePeriod, + } + // TODO: tune http2 server config + l.server = &http2.Server{ + // MaxConcurrentStreams: 1000, + PermitProhibitedCipherSuites: true, + IdleTimeout: 5 * time.Minute, + } + + queueSize := l.md.connQueueSize + if queueSize <= 0 { + queueSize = defaultQueueSize + } + l.connChan = make(chan net.Conn, queueSize) + l.errChan = make(chan error, 1) + + go l.listenLoop() + return +} + +func (l *Listener) Accept() (conn net.Conn, err error) { + var ok bool + select { + case conn = <-l.connChan: + case err, ok = <-l.errChan: + if !ok { + err = listener.ErrClosed + } + } + return +} + +func (l *Listener) listenLoop() { + for { + conn, err := l.Listener.Accept() + if err != nil { + // log.Log("[http2] accept:", err) + l.errChan <- err + close(l.errChan) + return + } + go l.handleLoop(conn) + } +} + +func (l *Listener) handleLoop(conn net.Conn) { + if l.md.tlsConfig != nil { + tlsConn := tls.Server(conn, l.md.tlsConfig) + // NOTE: HTTP2 server will check the TLS version, + // so we must ensure that the TLS connection is handshake completed. + if err := tlsConn.Handshake(); err != nil { + // log.Logf("[http2] %s - %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err) + return + } + conn = tlsConn + } + + opt := http2.ServeConnOpts{ + Handler: http.HandlerFunc(l.handleFunc), + } + l.server.ServeConn(conn, &opt) +} + +func (l *Listener) handleFunc(w http.ResponseWriter, r *http.Request) { + /* + log.Logf("[http2] %s -> %s %s %s %s", + r.RemoteAddr, r.Host, r.Method, r.RequestURI, r.Proto) + if Debug { + dump, _ := httputil.DumpRequest(r, false) + log.Log("[http2]", string(dump)) + } + */ + // w.Header().Set("Proxy-Agent", "gost/"+Version) + conn, err := l.upgrade(w, r) + if err != nil { + // log.Logf("[http2] %s - %s %s %s %s: %s", + // r.RemoteAddr, r.Host, r.Method, r.RequestURI, r.Proto, err) + return + } + select { + case l.connChan <- conn: + default: + conn.Close() + // log.Logf("[http2] %s - %s: connection queue is full", conn.RemoteAddr(), conn.LocalAddr()) + } + + <-conn.closed // NOTE: we need to wait for streaming end, or the connection will be closed +} + +func (l *Listener) upgrade(w http.ResponseWriter, r *http.Request) (*conn, error) { + if l.md.path == "" && r.Method != http.MethodConnect { + w.WriteHeader(http.StatusMethodNotAllowed) + return nil, errors.New("method not allowed") + } + + if l.md.path != "" && r.RequestURI != l.md.path { + w.WriteHeader(http.StatusBadRequest) + return nil, errors.New("bad request") + } + + w.WriteHeader(http.StatusOK) + if fw, ok := w.(http.Flusher); ok { + fw.Flush() // write header to client + } + + remoteAddr, _ := net.ResolveTCPAddr("tcp", r.RemoteAddr) + if remoteAddr == nil { + remoteAddr = &net.TCPAddr{ + IP: net.IPv4zero, + Port: 0, + } + } + return &conn{ + r: r.Body, + w: flushWriter{w}, + localAddr: l.Listener.Addr(), + remoteAddr: remoteAddr, + closed: make(chan struct{}), + }, nil +} + +func (l *Listener) parseMetadata(md listener.Metadata) (m metadata, err error) { + if val, ok := md[addr]; ok { + m.addr = val + } else { + err = errors.New("missing address") + return + } + + m.tlsConfig, err = utils.LoadTLSConfig(md[certFile], md[keyFile], md[caFile]) + if err != nil { + return + } + + return +} diff --git a/server/listener/http2/h2/metadata.go b/server/listener/http2/h2/metadata.go new file mode 100644 index 0000000..30d90f9 --- /dev/null +++ b/server/listener/http2/h2/metadata.go @@ -0,0 +1,38 @@ +package h2 + +import ( + "crypto/tls" + "net/http" + "time" +) + +const ( + addr = "addr" + path = "path" + certFile = "certFile" + keyFile = "keyFile" + caFile = "caFile" + handshakeTimeout = "handshakeTimeout" + readHeaderTimeout = "readHeaderTimeout" + readBufferSize = "readBufferSize" + writeBufferSize = "writeBufferSize" + connQueueSize = "connQueueSize" +) + +const ( + defaultQueueSize = 128 +) + +type metadata struct { + addr string + path string + tlsConfig *tls.Config + handshakeTimeout time.Duration + readHeaderTimeout time.Duration + readBufferSize int + writeBufferSize int + enableCompression bool + responseHeader http.Header + connQueueSize int + keepAlivePeriod time.Duration +} diff --git a/server/listener/http2/listener.go b/server/listener/http2/listener.go new file mode 100644 index 0000000..be6cdf6 --- /dev/null +++ b/server/listener/http2/listener.go @@ -0,0 +1,140 @@ +package http2 + +import ( + "crypto/tls" + "errors" + "net" + "net/http" + + "github.com/go-gost/gost/logger" + "github.com/go-gost/gost/server/listener" + "github.com/go-gost/gost/utils" + "golang.org/x/net/http2" +) + +var ( + _ listener.Listener = (*Listener)(nil) +) + +type Listener struct { + md metadata + server *http.Server + addr net.Addr + connChan chan *conn + errChan chan error + logger logger.Logger +} + +func NewListener(opts ...listener.Option) *Listener { + options := &listener.Options{} + for _, opt := range opts { + opt(options) + } + return &Listener{ + logger: options.Logger, + } +} + +func (l *Listener) Init(md listener.Metadata) (err error) { + l.md, err = l.parseMetadata(md) + if err != nil { + return + } + + l.server = &http.Server{ + Addr: l.md.addr, + Handler: http.HandlerFunc(l.handleFunc), + TLSConfig: l.md.tlsConfig, + } + if err := http2.ConfigureServer(l.server, nil); err != nil { + return err + } + + ln, err := net.Listen("tcp", addr) + if err != nil { + return err + } + l.addr = ln.Addr() + + ln = tls.NewListener( + &utils.TCPKeepAliveListener{ + TCPListener: ln.(*net.TCPListener), + KeepAlivePeriod: l.md.keepAlivePeriod, + }, + l.md.tlsConfig, + ) + + queueSize := l.md.connQueueSize + if queueSize <= 0 { + queueSize = defaultQueueSize + } + l.connChan = make(chan *conn, queueSize) + l.errChan = make(chan error, 1) + + go func() { + if err := l.server.Serve(ln); err != nil { + // log.Log("[http2]", err) + } + }() + + return +} + +func (l *Listener) Accept() (conn net.Conn, err error) { + var ok bool + select { + case conn = <-l.connChan: + case err, ok = <-l.errChan: + if !ok { + err = listener.ErrClosed + } + } + return +} + +func (l *Listener) Addr() net.Addr { + return l.addr +} + +func (l *Listener) Close() (err error) { + select { + case <-l.errChan: + default: + err = l.server.Close() + l.errChan <- err + close(l.errChan) + } + return nil +} + +func (l *Listener) handleFunc(w http.ResponseWriter, r *http.Request) { + conn := &conn{ + r: r, + w: w, + closed: make(chan struct{}), + } + select { + case l.connChan <- conn: + default: + // log.Logf("[http2] %s - %s: connection queue is full", r.RemoteAddr, l.server.Addr) + return + } + + <-conn.closed +} + +func (l *Listener) parseMetadata(md listener.Metadata) (m metadata, err error) { + if val, ok := md[addr]; ok { + m.addr = val + } else { + err = errors.New("missing address") + return + } + + m.tlsConfig, err = utils.LoadTLSConfig(md[certFile], md[keyFile], md[caFile]) + if err != nil { + return + } + + return +} diff --git a/server/listener/http2/metadata.go b/server/listener/http2/metadata.go new file mode 100644 index 0000000..88d8fd2 --- /dev/null +++ b/server/listener/http2/metadata.go @@ -0,0 +1,38 @@ +package http2 + +import ( + "crypto/tls" + "net/http" + "time" +) + +const ( + addr = "addr" + path = "path" + certFile = "certFile" + keyFile = "keyFile" + caFile = "caFile" + handshakeTimeout = "handshakeTimeout" + readHeaderTimeout = "readHeaderTimeout" + readBufferSize = "readBufferSize" + writeBufferSize = "writeBufferSize" + connQueueSize = "connQueueSize" +) + +const ( + defaultQueueSize = 128 +) + +type metadata struct { + addr string + path string + tlsConfig *tls.Config + handshakeTimeout time.Duration + readHeaderTimeout time.Duration + readBufferSize int + writeBufferSize int + enableCompression bool + responseHeader http.Header + connQueueSize int + keepAlivePeriod time.Duration +}