add http2 listener
This commit is contained in:
parent
734419b225
commit
9111b42c04
54
server/listener/http2/conn.go
Normal file
54
server/listener/http2/conn.go
Normal file
@ -0,0 +1,54 @@
|
||||
package http2
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// a dummy HTTP2 server conn used by HTTP2 handler
|
||||
type conn struct {
|
||||
r *http.Request
|
||||
w http.ResponseWriter
|
||||
closed chan struct{}
|
||||
}
|
||||
|
||||
func (c *conn) Read(b []byte) (n int, err error) {
|
||||
return 0, &net.OpError{Op: "read", Net: "http2", Source: nil, Addr: nil, Err: errors.New("read not supported")}
|
||||
}
|
||||
|
||||
func (c *conn) Write(b []byte) (n int, err error) {
|
||||
return 0, &net.OpError{Op: "write", Net: "http2", Source: nil, Addr: nil, Err: errors.New("write not supported")}
|
||||
}
|
||||
|
||||
func (c *conn) Close() error {
|
||||
select {
|
||||
case <-c.closed:
|
||||
default:
|
||||
close(c.closed)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *conn) LocalAddr() net.Addr {
|
||||
addr, _ := net.ResolveTCPAddr("tcp", c.r.Host)
|
||||
return addr
|
||||
}
|
||||
|
||||
func (c *conn) RemoteAddr() net.Addr {
|
||||
addr, _ := net.ResolveTCPAddr("tcp", c.r.RemoteAddr)
|
||||
return addr
|
||||
}
|
||||
|
||||
func (c *conn) SetDeadline(t time.Time) error {
|
||||
return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
|
||||
}
|
||||
|
||||
func (c *conn) SetReadDeadline(t time.Time) error {
|
||||
return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
|
||||
}
|
||||
|
||||
func (c *conn) SetWriteDeadline(t time.Time) error {
|
||||
return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
|
||||
}
|
89
server/listener/http2/h2/conn.go
Normal file
89
server/listener/http2/h2/conn.go
Normal file
@ -0,0 +1,89 @@
|
||||
package h2
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// HTTP2 connection, wrapped up just like a net.Conn
|
||||
type conn struct {
|
||||
r io.Reader
|
||||
w io.Writer
|
||||
remoteAddr net.Addr
|
||||
localAddr net.Addr
|
||||
closed chan struct{}
|
||||
}
|
||||
|
||||
func (c *conn) Read(b []byte) (n int, err error) {
|
||||
return c.r.Read(b)
|
||||
}
|
||||
|
||||
func (c *conn) Write(b []byte) (n int, err error) {
|
||||
return c.w.Write(b)
|
||||
}
|
||||
|
||||
func (c *conn) Close() (err error) {
|
||||
select {
|
||||
case <-c.closed:
|
||||
return
|
||||
default:
|
||||
close(c.closed)
|
||||
}
|
||||
if rc, ok := c.r.(io.Closer); ok {
|
||||
err = rc.Close()
|
||||
}
|
||||
if w, ok := c.w.(io.Closer); ok {
|
||||
err = w.Close()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *conn) LocalAddr() net.Addr {
|
||||
return c.localAddr
|
||||
}
|
||||
|
||||
func (c *conn) RemoteAddr() net.Addr {
|
||||
return c.remoteAddr
|
||||
}
|
||||
|
||||
func (c *conn) SetDeadline(t time.Time) error {
|
||||
return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
|
||||
}
|
||||
|
||||
func (c *conn) SetReadDeadline(t time.Time) error {
|
||||
return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
|
||||
}
|
||||
|
||||
func (c *conn) SetWriteDeadline(t time.Time) error {
|
||||
return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
|
||||
}
|
||||
|
||||
type flushWriter struct {
|
||||
w io.Writer
|
||||
}
|
||||
|
||||
func (fw flushWriter) Write(p []byte) (n int, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
if s, ok := r.(string); ok {
|
||||
err = errors.New(s)
|
||||
// log.Log("[http2]", err)
|
||||
return
|
||||
}
|
||||
err = r.(error)
|
||||
}
|
||||
}()
|
||||
|
||||
n, err = fw.w.Write(p)
|
||||
if err != nil {
|
||||
// log.Log("flush writer:", err)
|
||||
return
|
||||
}
|
||||
if f, ok := fw.w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
return
|
||||
}
|
186
server/listener/http2/h2/listener.go
Normal file
186
server/listener/http2/h2/listener.go
Normal file
@ -0,0 +1,186 @@
|
||||
package h2
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/go-gost/gost/logger"
|
||||
"github.com/go-gost/gost/server/listener"
|
||||
"github.com/go-gost/gost/utils"
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
var (
|
||||
_ listener.Listener = (*Listener)(nil)
|
||||
)
|
||||
|
||||
type Listener struct {
|
||||
net.Listener
|
||||
md metadata
|
||||
server *http2.Server
|
||||
connChan chan net.Conn
|
||||
errChan chan error
|
||||
logger logger.Logger
|
||||
}
|
||||
|
||||
func NewListener(opts ...listener.Option) *Listener {
|
||||
options := &listener.Options{}
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
return &Listener{
|
||||
logger: options.Logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Listener) Init(md listener.Metadata) (err error) {
|
||||
l.md, err = l.parseMetadata(md)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ln, err := net.Listen("tcp", l.md.addr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
l.Listener = &utils.TCPKeepAliveListener{
|
||||
TCPListener: ln.(*net.TCPListener),
|
||||
KeepAlivePeriod: l.md.keepAlivePeriod,
|
||||
}
|
||||
// TODO: tune http2 server config
|
||||
l.server = &http2.Server{
|
||||
// MaxConcurrentStreams: 1000,
|
||||
PermitProhibitedCipherSuites: true,
|
||||
IdleTimeout: 5 * time.Minute,
|
||||
}
|
||||
|
||||
queueSize := l.md.connQueueSize
|
||||
if queueSize <= 0 {
|
||||
queueSize = defaultQueueSize
|
||||
}
|
||||
l.connChan = make(chan net.Conn, queueSize)
|
||||
l.errChan = make(chan error, 1)
|
||||
|
||||
go l.listenLoop()
|
||||
return
|
||||
}
|
||||
|
||||
func (l *Listener) Accept() (conn net.Conn, err error) {
|
||||
var ok bool
|
||||
select {
|
||||
case conn = <-l.connChan:
|
||||
case err, ok = <-l.errChan:
|
||||
if !ok {
|
||||
err = listener.ErrClosed
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (l *Listener) listenLoop() {
|
||||
for {
|
||||
conn, err := l.Listener.Accept()
|
||||
if err != nil {
|
||||
// log.Log("[http2] accept:", err)
|
||||
l.errChan <- err
|
||||
close(l.errChan)
|
||||
return
|
||||
}
|
||||
go l.handleLoop(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Listener) handleLoop(conn net.Conn) {
|
||||
if l.md.tlsConfig != nil {
|
||||
tlsConn := tls.Server(conn, l.md.tlsConfig)
|
||||
// NOTE: HTTP2 server will check the TLS version,
|
||||
// so we must ensure that the TLS connection is handshake completed.
|
||||
if err := tlsConn.Handshake(); err != nil {
|
||||
// log.Logf("[http2] %s - %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err)
|
||||
return
|
||||
}
|
||||
conn = tlsConn
|
||||
}
|
||||
|
||||
opt := http2.ServeConnOpts{
|
||||
Handler: http.HandlerFunc(l.handleFunc),
|
||||
}
|
||||
l.server.ServeConn(conn, &opt)
|
||||
}
|
||||
|
||||
func (l *Listener) handleFunc(w http.ResponseWriter, r *http.Request) {
|
||||
/*
|
||||
log.Logf("[http2] %s -> %s %s %s %s",
|
||||
r.RemoteAddr, r.Host, r.Method, r.RequestURI, r.Proto)
|
||||
if Debug {
|
||||
dump, _ := httputil.DumpRequest(r, false)
|
||||
log.Log("[http2]", string(dump))
|
||||
}
|
||||
*/
|
||||
// w.Header().Set("Proxy-Agent", "gost/"+Version)
|
||||
conn, err := l.upgrade(w, r)
|
||||
if err != nil {
|
||||
// log.Logf("[http2] %s - %s %s %s %s: %s",
|
||||
// r.RemoteAddr, r.Host, r.Method, r.RequestURI, r.Proto, err)
|
||||
return
|
||||
}
|
||||
select {
|
||||
case l.connChan <- conn:
|
||||
default:
|
||||
conn.Close()
|
||||
// log.Logf("[http2] %s - %s: connection queue is full", conn.RemoteAddr(), conn.LocalAddr())
|
||||
}
|
||||
|
||||
<-conn.closed // NOTE: we need to wait for streaming end, or the connection will be closed
|
||||
}
|
||||
|
||||
func (l *Listener) upgrade(w http.ResponseWriter, r *http.Request) (*conn, error) {
|
||||
if l.md.path == "" && r.Method != http.MethodConnect {
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
return nil, errors.New("method not allowed")
|
||||
}
|
||||
|
||||
if l.md.path != "" && r.RequestURI != l.md.path {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return nil, errors.New("bad request")
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if fw, ok := w.(http.Flusher); ok {
|
||||
fw.Flush() // write header to client
|
||||
}
|
||||
|
||||
remoteAddr, _ := net.ResolveTCPAddr("tcp", r.RemoteAddr)
|
||||
if remoteAddr == nil {
|
||||
remoteAddr = &net.TCPAddr{
|
||||
IP: net.IPv4zero,
|
||||
Port: 0,
|
||||
}
|
||||
}
|
||||
return &conn{
|
||||
r: r.Body,
|
||||
w: flushWriter{w},
|
||||
localAddr: l.Listener.Addr(),
|
||||
remoteAddr: remoteAddr,
|
||||
closed: make(chan struct{}),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (l *Listener) parseMetadata(md listener.Metadata) (m metadata, err error) {
|
||||
if val, ok := md[addr]; ok {
|
||||
m.addr = val
|
||||
} else {
|
||||
err = errors.New("missing address")
|
||||
return
|
||||
}
|
||||
|
||||
m.tlsConfig, err = utils.LoadTLSConfig(md[certFile], md[keyFile], md[caFile])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
38
server/listener/http2/h2/metadata.go
Normal file
38
server/listener/http2/h2/metadata.go
Normal file
@ -0,0 +1,38 @@
|
||||
package h2
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
addr = "addr"
|
||||
path = "path"
|
||||
certFile = "certFile"
|
||||
keyFile = "keyFile"
|
||||
caFile = "caFile"
|
||||
handshakeTimeout = "handshakeTimeout"
|
||||
readHeaderTimeout = "readHeaderTimeout"
|
||||
readBufferSize = "readBufferSize"
|
||||
writeBufferSize = "writeBufferSize"
|
||||
connQueueSize = "connQueueSize"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultQueueSize = 128
|
||||
)
|
||||
|
||||
type metadata struct {
|
||||
addr string
|
||||
path string
|
||||
tlsConfig *tls.Config
|
||||
handshakeTimeout time.Duration
|
||||
readHeaderTimeout time.Duration
|
||||
readBufferSize int
|
||||
writeBufferSize int
|
||||
enableCompression bool
|
||||
responseHeader http.Header
|
||||
connQueueSize int
|
||||
keepAlivePeriod time.Duration
|
||||
}
|
140
server/listener/http2/listener.go
Normal file
140
server/listener/http2/listener.go
Normal file
@ -0,0 +1,140 @@
|
||||
package http2
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-gost/gost/logger"
|
||||
"github.com/go-gost/gost/server/listener"
|
||||
"github.com/go-gost/gost/utils"
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
var (
|
||||
_ listener.Listener = (*Listener)(nil)
|
||||
)
|
||||
|
||||
type Listener struct {
|
||||
md metadata
|
||||
server *http.Server
|
||||
addr net.Addr
|
||||
connChan chan *conn
|
||||
errChan chan error
|
||||
logger logger.Logger
|
||||
}
|
||||
|
||||
func NewListener(opts ...listener.Option) *Listener {
|
||||
options := &listener.Options{}
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
return &Listener{
|
||||
logger: options.Logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Listener) Init(md listener.Metadata) (err error) {
|
||||
l.md, err = l.parseMetadata(md)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
l.server = &http.Server{
|
||||
Addr: l.md.addr,
|
||||
Handler: http.HandlerFunc(l.handleFunc),
|
||||
TLSConfig: l.md.tlsConfig,
|
||||
}
|
||||
if err := http2.ConfigureServer(l.server, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ln, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
l.addr = ln.Addr()
|
||||
|
||||
ln = tls.NewListener(
|
||||
&utils.TCPKeepAliveListener{
|
||||
TCPListener: ln.(*net.TCPListener),
|
||||
KeepAlivePeriod: l.md.keepAlivePeriod,
|
||||
},
|
||||
l.md.tlsConfig,
|
||||
)
|
||||
|
||||
queueSize := l.md.connQueueSize
|
||||
if queueSize <= 0 {
|
||||
queueSize = defaultQueueSize
|
||||
}
|
||||
l.connChan = make(chan *conn, queueSize)
|
||||
l.errChan = make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
if err := l.server.Serve(ln); err != nil {
|
||||
// log.Log("[http2]", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (l *Listener) Accept() (conn net.Conn, err error) {
|
||||
var ok bool
|
||||
select {
|
||||
case conn = <-l.connChan:
|
||||
case err, ok = <-l.errChan:
|
||||
if !ok {
|
||||
err = listener.ErrClosed
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (l *Listener) Addr() net.Addr {
|
||||
return l.addr
|
||||
}
|
||||
|
||||
func (l *Listener) Close() (err error) {
|
||||
select {
|
||||
case <-l.errChan:
|
||||
default:
|
||||
err = l.server.Close()
|
||||
l.errChan <- err
|
||||
close(l.errChan)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *Listener) handleFunc(w http.ResponseWriter, r *http.Request) {
|
||||
conn := &conn{
|
||||
r: r,
|
||||
w: w,
|
||||
closed: make(chan struct{}),
|
||||
}
|
||||
select {
|
||||
case l.connChan <- conn:
|
||||
default:
|
||||
// log.Logf("[http2] %s - %s: connection queue is full", r.RemoteAddr, l.server.Addr)
|
||||
return
|
||||
}
|
||||
|
||||
<-conn.closed
|
||||
}
|
||||
|
||||
func (l *Listener) parseMetadata(md listener.Metadata) (m metadata, err error) {
|
||||
if val, ok := md[addr]; ok {
|
||||
m.addr = val
|
||||
} else {
|
||||
err = errors.New("missing address")
|
||||
return
|
||||
}
|
||||
|
||||
m.tlsConfig, err = utils.LoadTLSConfig(md[certFile], md[keyFile], md[caFile])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
38
server/listener/http2/metadata.go
Normal file
38
server/listener/http2/metadata.go
Normal file
@ -0,0 +1,38 @@
|
||||
package http2
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
addr = "addr"
|
||||
path = "path"
|
||||
certFile = "certFile"
|
||||
keyFile = "keyFile"
|
||||
caFile = "caFile"
|
||||
handshakeTimeout = "handshakeTimeout"
|
||||
readHeaderTimeout = "readHeaderTimeout"
|
||||
readBufferSize = "readBufferSize"
|
||||
writeBufferSize = "writeBufferSize"
|
||||
connQueueSize = "connQueueSize"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultQueueSize = 128
|
||||
)
|
||||
|
||||
type metadata struct {
|
||||
addr string
|
||||
path string
|
||||
tlsConfig *tls.Config
|
||||
handshakeTimeout time.Duration
|
||||
readHeaderTimeout time.Duration
|
||||
readBufferSize int
|
||||
writeBufferSize int
|
||||
enableCompression bool
|
||||
responseHeader http.Header
|
||||
connQueueSize int
|
||||
keepAlivePeriod time.Duration
|
||||
}
|
Loading…
Reference in New Issue
Block a user