add http2 tunnel

This commit is contained in:
ginuerzh
2021-12-15 15:19:19 +08:00
parent c651743ea2
commit 5bd3c25c65
12 changed files with 392 additions and 118 deletions

View File

@ -5,28 +5,30 @@ import (
"errors"
"net"
"net/http"
"time"
"net/http/httputil"
"github.com/go-gost/gost/pkg/common/util"
"github.com/go-gost/gost/pkg/listener"
"github.com/go-gost/gost/pkg/logger"
md "github.com/go-gost/gost/pkg/metadata"
"github.com/go-gost/gost/pkg/registry"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
)
func init() {
registry.RegisterListener("h2", NewListener)
registry.RegisterListener("h2c", NewListener)
registry.RegisterListener("h2", NewTLSListener)
}
type h2Listener struct {
addr string
net.Listener
md metadata
server *http2.Server
connChan chan net.Conn
errChan chan error
logger logger.Logger
server *http.Server
saddr string
addr net.Addr
cqueue chan net.Conn
errChan chan error
logger logger.Logger
md metadata
h2c bool
}
func NewListener(opts ...listener.Option) listener.Listener {
@ -35,7 +37,19 @@ func NewListener(opts ...listener.Option) listener.Listener {
opt(options)
}
return &h2Listener{
addr: options.Addr,
saddr: options.Addr,
logger: options.Logger,
h2c: true,
}
}
func NewTLSListener(opts ...listener.Option) listener.Listener {
options := &listener.Options{}
for _, opt := range opts {
opt(options)
}
return &h2Listener{
saddr: options.Addr,
logger: options.Logger,
}
}
@ -45,36 +59,45 @@ func (l *h2Listener) Init(md md.Metadata) (err error) {
return
}
ln, err := net.Listen("tcp", l.addr)
if err != nil {
return
}
l.Listener = &util.TCPKeepAliveListener{
TCPListener: ln.(*net.TCPListener),
KeepAlivePeriod: l.md.keepAlivePeriod,
}
// TODO: tune http2 server config
l.server = &http2.Server{
// MaxConcurrentStreams: 1000,
PermitProhibitedCipherSuites: true,
IdleTimeout: 5 * time.Minute,
l.server = &http.Server{
Addr: l.saddr,
}
queueSize := l.md.connQueueSize
if queueSize <= 0 {
queueSize = defaultQueueSize
ln, err := net.Listen("tcp", l.saddr)
if err != nil {
return err
}
l.connChan = make(chan net.Conn, queueSize)
l.addr = ln.Addr()
if l.h2c {
l.server.Handler = h2c.NewHandler(
http.HandlerFunc(l.handleFunc), &http2.Server{})
} else {
l.server.Handler = http.HandlerFunc(l.handleFunc)
l.server.TLSConfig = l.md.tlsConfig
if err := http2.ConfigureServer(l.server, nil); err != nil {
ln.Close()
return err
}
ln = tls.NewListener(ln, l.md.tlsConfig)
}
l.cqueue = make(chan net.Conn, l.md.backlog)
l.errChan = make(chan error, 1)
go l.listenLoop()
go func() {
if err := l.server.Serve(ln); err != nil {
l.logger.Error(err)
}
}()
return
}
func (l *h2Listener) Accept() (conn net.Conn, err error) {
var ok bool
select {
case conn = <-l.connChan:
case conn = <-l.cqueue:
case err, ok = <-l.errChan:
if !ok {
err = listener.ErrClosed
@ -83,58 +106,36 @@ func (l *h2Listener) Accept() (conn net.Conn, err error) {
return
}
func (l *h2Listener) listenLoop() {
for {
conn, err := l.Listener.Accept()
if err != nil {
// log.Log("[http2] accept:", err)
l.errChan <- err
close(l.errChan)
return
}
go l.handleLoop(conn)
}
func (l *h2Listener) Addr() net.Addr {
return l.addr
}
func (l *h2Listener) handleLoop(conn net.Conn) {
if l.md.tlsConfig != nil {
tlsConn := tls.Server(conn, l.md.tlsConfig)
// NOTE: HTTP2 server will check the TLS version,
// so we must ensure that the TLS connection is handshake completed.
if err := tlsConn.Handshake(); err != nil {
// log.Logf("[http2] %s - %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err)
return
}
conn = tlsConn
func (l *h2Listener) Close() (err error) {
select {
case <-l.errChan:
default:
err = l.server.Close()
l.errChan <- err
close(l.errChan)
}
opt := http2.ServeConnOpts{
Handler: http.HandlerFunc(l.handleFunc),
}
l.server.ServeConn(conn, &opt)
return nil
}
func (l *h2Listener) handleFunc(w http.ResponseWriter, r *http.Request) {
/*
log.Logf("[http2] %s -> %s %s %s %s",
r.RemoteAddr, r.Host, r.Method, r.RequestURI, r.Proto)
if Debug {
dump, _ := httputil.DumpRequest(r, false)
log.Log("[http2]", string(dump))
}
*/
// w.Header().Set("Proxy-Agent", "gost/"+Version)
if l.logger.IsLevelEnabled(logger.DebugLevel) {
dump, _ := httputil.DumpRequest(r, false)
l.logger.Debug(string(dump))
}
conn, err := l.upgrade(w, r)
if err != nil {
// log.Logf("[http2] %s - %s %s %s %s: %s",
// r.RemoteAddr, r.Host, r.Method, r.RequestURI, r.Proto, err)
l.logger.Error(err)
return
}
select {
case l.connChan <- conn:
case l.cqueue <- conn:
default:
conn.Close()
// log.Logf("[http2] %s - %s: connection queue is full", conn.RemoteAddr(), conn.LocalAddr())
l.logger.Warnf("connection queue is full, client %s discarded", r.RemoteAddr)
}
<-conn.closed // NOTE: we need to wait for streaming end, or the connection will be closed
@ -166,7 +167,7 @@ func (l *h2Listener) upgrade(w http.ResponseWriter, r *http.Request) (*conn, err
return &conn{
r: r.Body,
w: flushWriter{w},
localAddr: l.Listener.Addr(),
localAddr: l.addr,
remoteAddr: remoteAddr,
closed: make(chan struct{}),
}, nil

View File

@ -2,41 +2,28 @@ package h2
import (
"crypto/tls"
"net/http"
"time"
tls_util "github.com/go-gost/gost/pkg/common/util/tls"
md "github.com/go-gost/gost/pkg/metadata"
)
const (
defaultQueueSize = 128
defaultBacklog = 128
)
type metadata struct {
path string
tlsConfig *tls.Config
handshakeTimeout time.Duration
readHeaderTimeout time.Duration
readBufferSize int
writeBufferSize int
enableCompression bool
responseHeader http.Header
connQueueSize int
keepAlivePeriod time.Duration
path string
tlsConfig *tls.Config
backlog int
}
func (l *h2Listener) parseMetadata(md md.Metadata) (err error) {
const (
path = "path"
certFile = "certFile"
keyFile = "keyFile"
caFile = "caFile"
handshakeTimeout = "handshakeTimeout"
readHeaderTimeout = "readHeaderTimeout"
readBufferSize = "readBufferSize"
writeBufferSize = "writeBufferSize"
connQueueSize = "connQueueSize"
path = "path"
certFile = "certFile"
keyFile = "keyFile"
caFile = "caFile"
backlog = "backlog"
)
l.md.tlsConfig, err = tls_util.LoadServerConfig(
@ -48,5 +35,11 @@ func (l *h2Listener) parseMetadata(md md.Metadata) (err error) {
return
}
l.md.backlog = md.GetInt(backlog)
if l.md.backlog <= 0 {
l.md.backlog = defaultBacklog
}
l.md.path = md.GetString(path)
return
}

View File

@ -19,13 +19,13 @@ func init() {
}
type http2Listener struct {
saddr string
md metadata
server *http.Server
saddr string
addr net.Addr
cqueue chan net.Conn
errChan chan error
logger logger.Logger
md metadata
}
func NewListener(opts ...listener.Option) listener.Listener {