add http2 listener

This commit is contained in:
ginuerzh 2021-07-06 09:53:23 +08:00
parent 734419b225
commit 9111b42c04
6 changed files with 545 additions and 0 deletions

View File

@ -0,0 +1,54 @@
package http2
import (
"errors"
"net"
"net/http"
"time"
)
// a dummy HTTP2 server conn used by HTTP2 handler
type conn struct {
r *http.Request
w http.ResponseWriter
closed chan struct{}
}
func (c *conn) Read(b []byte) (n int, err error) {
return 0, &net.OpError{Op: "read", Net: "http2", Source: nil, Addr: nil, Err: errors.New("read not supported")}
}
func (c *conn) Write(b []byte) (n int, err error) {
return 0, &net.OpError{Op: "write", Net: "http2", Source: nil, Addr: nil, Err: errors.New("write not supported")}
}
func (c *conn) Close() error {
select {
case <-c.closed:
default:
close(c.closed)
}
return nil
}
func (c *conn) LocalAddr() net.Addr {
addr, _ := net.ResolveTCPAddr("tcp", c.r.Host)
return addr
}
func (c *conn) RemoteAddr() net.Addr {
addr, _ := net.ResolveTCPAddr("tcp", c.r.RemoteAddr)
return addr
}
func (c *conn) SetDeadline(t time.Time) error {
return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
}
func (c *conn) SetReadDeadline(t time.Time) error {
return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
}
func (c *conn) SetWriteDeadline(t time.Time) error {
return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
}

View File

@ -0,0 +1,89 @@
package h2
import (
"errors"
"io"
"net"
"net/http"
"time"
)
// HTTP2 connection, wrapped up just like a net.Conn
type conn struct {
r io.Reader
w io.Writer
remoteAddr net.Addr
localAddr net.Addr
closed chan struct{}
}
func (c *conn) Read(b []byte) (n int, err error) {
return c.r.Read(b)
}
func (c *conn) Write(b []byte) (n int, err error) {
return c.w.Write(b)
}
func (c *conn) Close() (err error) {
select {
case <-c.closed:
return
default:
close(c.closed)
}
if rc, ok := c.r.(io.Closer); ok {
err = rc.Close()
}
if w, ok := c.w.(io.Closer); ok {
err = w.Close()
}
return
}
func (c *conn) LocalAddr() net.Addr {
return c.localAddr
}
func (c *conn) RemoteAddr() net.Addr {
return c.remoteAddr
}
func (c *conn) SetDeadline(t time.Time) error {
return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
}
func (c *conn) SetReadDeadline(t time.Time) error {
return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
}
func (c *conn) SetWriteDeadline(t time.Time) error {
return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
}
type flushWriter struct {
w io.Writer
}
func (fw flushWriter) Write(p []byte) (n int, err error) {
defer func() {
if r := recover(); r != nil {
if s, ok := r.(string); ok {
err = errors.New(s)
// log.Log("[http2]", err)
return
}
err = r.(error)
}
}()
n, err = fw.w.Write(p)
if err != nil {
// log.Log("flush writer:", err)
return
}
if f, ok := fw.w.(http.Flusher); ok {
f.Flush()
}
return
}

View File

@ -0,0 +1,186 @@
package h2
import (
"crypto/tls"
"errors"
"net"
"net/http"
"time"
"github.com/go-gost/gost/logger"
"github.com/go-gost/gost/server/listener"
"github.com/go-gost/gost/utils"
"golang.org/x/net/http2"
)
var (
_ listener.Listener = (*Listener)(nil)
)
type Listener struct {
net.Listener
md metadata
server *http2.Server
connChan chan net.Conn
errChan chan error
logger logger.Logger
}
func NewListener(opts ...listener.Option) *Listener {
options := &listener.Options{}
for _, opt := range opts {
opt(options)
}
return &Listener{
logger: options.Logger,
}
}
func (l *Listener) Init(md listener.Metadata) (err error) {
l.md, err = l.parseMetadata(md)
if err != nil {
return
}
ln, err := net.Listen("tcp", l.md.addr)
if err != nil {
return
}
l.Listener = &utils.TCPKeepAliveListener{
TCPListener: ln.(*net.TCPListener),
KeepAlivePeriod: l.md.keepAlivePeriod,
}
// TODO: tune http2 server config
l.server = &http2.Server{
// MaxConcurrentStreams: 1000,
PermitProhibitedCipherSuites: true,
IdleTimeout: 5 * time.Minute,
}
queueSize := l.md.connQueueSize
if queueSize <= 0 {
queueSize = defaultQueueSize
}
l.connChan = make(chan net.Conn, queueSize)
l.errChan = make(chan error, 1)
go l.listenLoop()
return
}
func (l *Listener) Accept() (conn net.Conn, err error) {
var ok bool
select {
case conn = <-l.connChan:
case err, ok = <-l.errChan:
if !ok {
err = listener.ErrClosed
}
}
return
}
func (l *Listener) listenLoop() {
for {
conn, err := l.Listener.Accept()
if err != nil {
// log.Log("[http2] accept:", err)
l.errChan <- err
close(l.errChan)
return
}
go l.handleLoop(conn)
}
}
func (l *Listener) handleLoop(conn net.Conn) {
if l.md.tlsConfig != nil {
tlsConn := tls.Server(conn, l.md.tlsConfig)
// NOTE: HTTP2 server will check the TLS version,
// so we must ensure that the TLS connection is handshake completed.
if err := tlsConn.Handshake(); err != nil {
// log.Logf("[http2] %s - %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err)
return
}
conn = tlsConn
}
opt := http2.ServeConnOpts{
Handler: http.HandlerFunc(l.handleFunc),
}
l.server.ServeConn(conn, &opt)
}
func (l *Listener) handleFunc(w http.ResponseWriter, r *http.Request) {
/*
log.Logf("[http2] %s -> %s %s %s %s",
r.RemoteAddr, r.Host, r.Method, r.RequestURI, r.Proto)
if Debug {
dump, _ := httputil.DumpRequest(r, false)
log.Log("[http2]", string(dump))
}
*/
// w.Header().Set("Proxy-Agent", "gost/"+Version)
conn, err := l.upgrade(w, r)
if err != nil {
// log.Logf("[http2] %s - %s %s %s %s: %s",
// r.RemoteAddr, r.Host, r.Method, r.RequestURI, r.Proto, err)
return
}
select {
case l.connChan <- conn:
default:
conn.Close()
// log.Logf("[http2] %s - %s: connection queue is full", conn.RemoteAddr(), conn.LocalAddr())
}
<-conn.closed // NOTE: we need to wait for streaming end, or the connection will be closed
}
func (l *Listener) upgrade(w http.ResponseWriter, r *http.Request) (*conn, error) {
if l.md.path == "" && r.Method != http.MethodConnect {
w.WriteHeader(http.StatusMethodNotAllowed)
return nil, errors.New("method not allowed")
}
if l.md.path != "" && r.RequestURI != l.md.path {
w.WriteHeader(http.StatusBadRequest)
return nil, errors.New("bad request")
}
w.WriteHeader(http.StatusOK)
if fw, ok := w.(http.Flusher); ok {
fw.Flush() // write header to client
}
remoteAddr, _ := net.ResolveTCPAddr("tcp", r.RemoteAddr)
if remoteAddr == nil {
remoteAddr = &net.TCPAddr{
IP: net.IPv4zero,
Port: 0,
}
}
return &conn{
r: r.Body,
w: flushWriter{w},
localAddr: l.Listener.Addr(),
remoteAddr: remoteAddr,
closed: make(chan struct{}),
}, nil
}
func (l *Listener) parseMetadata(md listener.Metadata) (m metadata, err error) {
if val, ok := md[addr]; ok {
m.addr = val
} else {
err = errors.New("missing address")
return
}
m.tlsConfig, err = utils.LoadTLSConfig(md[certFile], md[keyFile], md[caFile])
if err != nil {
return
}
return
}

View File

@ -0,0 +1,38 @@
package h2
import (
"crypto/tls"
"net/http"
"time"
)
const (
addr = "addr"
path = "path"
certFile = "certFile"
keyFile = "keyFile"
caFile = "caFile"
handshakeTimeout = "handshakeTimeout"
readHeaderTimeout = "readHeaderTimeout"
readBufferSize = "readBufferSize"
writeBufferSize = "writeBufferSize"
connQueueSize = "connQueueSize"
)
const (
defaultQueueSize = 128
)
type metadata struct {
addr string
path string
tlsConfig *tls.Config
handshakeTimeout time.Duration
readHeaderTimeout time.Duration
readBufferSize int
writeBufferSize int
enableCompression bool
responseHeader http.Header
connQueueSize int
keepAlivePeriod time.Duration
}

View File

@ -0,0 +1,140 @@
package http2
import (
"crypto/tls"
"errors"
"net"
"net/http"
"github.com/go-gost/gost/logger"
"github.com/go-gost/gost/server/listener"
"github.com/go-gost/gost/utils"
"golang.org/x/net/http2"
)
var (
_ listener.Listener = (*Listener)(nil)
)
type Listener struct {
md metadata
server *http.Server
addr net.Addr
connChan chan *conn
errChan chan error
logger logger.Logger
}
func NewListener(opts ...listener.Option) *Listener {
options := &listener.Options{}
for _, opt := range opts {
opt(options)
}
return &Listener{
logger: options.Logger,
}
}
func (l *Listener) Init(md listener.Metadata) (err error) {
l.md, err = l.parseMetadata(md)
if err != nil {
return
}
l.server = &http.Server{
Addr: l.md.addr,
Handler: http.HandlerFunc(l.handleFunc),
TLSConfig: l.md.tlsConfig,
}
if err := http2.ConfigureServer(l.server, nil); err != nil {
return err
}
ln, err := net.Listen("tcp", addr)
if err != nil {
return err
}
l.addr = ln.Addr()
ln = tls.NewListener(
&utils.TCPKeepAliveListener{
TCPListener: ln.(*net.TCPListener),
KeepAlivePeriod: l.md.keepAlivePeriod,
},
l.md.tlsConfig,
)
queueSize := l.md.connQueueSize
if queueSize <= 0 {
queueSize = defaultQueueSize
}
l.connChan = make(chan *conn, queueSize)
l.errChan = make(chan error, 1)
go func() {
if err := l.server.Serve(ln); err != nil {
// log.Log("[http2]", err)
}
}()
return
}
func (l *Listener) Accept() (conn net.Conn, err error) {
var ok bool
select {
case conn = <-l.connChan:
case err, ok = <-l.errChan:
if !ok {
err = listener.ErrClosed
}
}
return
}
func (l *Listener) Addr() net.Addr {
return l.addr
}
func (l *Listener) Close() (err error) {
select {
case <-l.errChan:
default:
err = l.server.Close()
l.errChan <- err
close(l.errChan)
}
return nil
}
func (l *Listener) handleFunc(w http.ResponseWriter, r *http.Request) {
conn := &conn{
r: r,
w: w,
closed: make(chan struct{}),
}
select {
case l.connChan <- conn:
default:
// log.Logf("[http2] %s - %s: connection queue is full", r.RemoteAddr, l.server.Addr)
return
}
<-conn.closed
}
func (l *Listener) parseMetadata(md listener.Metadata) (m metadata, err error) {
if val, ok := md[addr]; ok {
m.addr = val
} else {
err = errors.New("missing address")
return
}
m.tlsConfig, err = utils.LoadTLSConfig(md[certFile], md[keyFile], md[caFile])
if err != nil {
return
}
return
}

View File

@ -0,0 +1,38 @@
package http2
import (
"crypto/tls"
"net/http"
"time"
)
const (
addr = "addr"
path = "path"
certFile = "certFile"
keyFile = "keyFile"
caFile = "caFile"
handshakeTimeout = "handshakeTimeout"
readHeaderTimeout = "readHeaderTimeout"
readBufferSize = "readBufferSize"
writeBufferSize = "writeBufferSize"
connQueueSize = "connQueueSize"
)
const (
defaultQueueSize = 128
)
type metadata struct {
addr string
path string
tlsConfig *tls.Config
handshakeTimeout time.Duration
readHeaderTimeout time.Duration
readBufferSize int
writeBufferSize int
enableCompression bool
responseHeader http.Header
connQueueSize int
keepAlivePeriod time.Duration
}