add http2 tunnel

This commit is contained in:
ginuerzh 2021-12-15 15:19:19 +08:00
parent c651743ea2
commit 5bd3c25c65
12 changed files with 392 additions and 118 deletions

2
.gitignore vendored
View File

@ -33,3 +33,5 @@ _testmain.go
cmd/gost/gost cmd/gost/gost
snap snap
*.pem

View File

@ -15,6 +15,7 @@ import (
// Register dialers // Register dialers
_ "github.com/go-gost/gost/pkg/dialer/ftcp" _ "github.com/go-gost/gost/pkg/dialer/ftcp"
_ "github.com/go-gost/gost/pkg/dialer/http2" _ "github.com/go-gost/gost/pkg/dialer/http2"
_ "github.com/go-gost/gost/pkg/dialer/http2/h2"
_ "github.com/go-gost/gost/pkg/dialer/tcp" _ "github.com/go-gost/gost/pkg/dialer/tcp"
_ "github.com/go-gost/gost/pkg/dialer/udp" _ "github.com/go-gost/gost/pkg/dialer/udp"

View File

@ -34,7 +34,7 @@ func buildDefaultTLSConfig(cfg *config.TLSConfig) {
} }
log.Warn("load TLS certificate files failed, use random generated certificate") log.Warn("load TLS certificate files failed, use random generated certificate")
} else { } else {
log.Debug("load TLS certificate files OK") log.Info("load TLS certificate files OK")
} }
tls_util.DefaultConfig = tlsConfig tls_util.DefaultConfig = tlsConfig
} }

View File

@ -7,6 +7,11 @@ profiling:
addr: ":6060" addr: ":6060"
enabled: true enabled: true
# tls:
# cert: "cert.pem"
# key: "key.pem"
# ca: "root.ca"
services: services:
- name: http+tcp - name: http+tcp
url: "http://gost:gost@:8000" url: "http://gost:gost@:8000"

View File

@ -66,10 +66,9 @@ func (c *http2Connector) Connect(ctx context.Context, conn net.Conn, network, ad
Host: address, Host: address,
ProtoMajor: 2, ProtoMajor: 2,
ProtoMinor: 0, ProtoMinor: 0,
Proto: "HTTP/2.0",
Header: make(http.Header), Header: make(http.Header),
Body: pr, Body: pr,
ContentLength: -1, // ContentLength: -1,
} }
if c.md.UserAgent != "" { if c.md.UserAgent != "" {
req.Header.Set("User-Agent", c.md.UserAgent) req.Header.Set("User-Agent", c.md.UserAgent)

View File

@ -80,20 +80,6 @@ func (d *http2Dialer) Dial(ctx context.Context, address string, opts ...dialer.D
ExpectContinueTimeout: 1 * time.Second, ExpectContinueTimeout: 1 * time.Second,
}, },
} }
/*
client = &http.Client{
Transport: &http2.Transport{
TLSClientConfig: d.md.tlsConfig,
DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
conn, err := d.dial(ctx, network, addr, options)
if err != nil {
return nil, err
}
return tls_util.WrapTLSClient(conn, cfg, time.Duration(0))
},
},
}
*/
d.clients[address] = client d.clients[address] = client
} }

View File

@ -0,0 +1,54 @@
package h2
import (
"errors"
"io"
"net"
"time"
)
// HTTP2 connection, wrapped up just like a net.Conn.
type http2Conn struct {
r io.Reader
w io.Writer
remoteAddr net.Addr
localAddr net.Addr
}
func (c *http2Conn) Read(b []byte) (n int, err error) {
return c.r.Read(b)
}
func (c *http2Conn) Write(b []byte) (n int, err error) {
return c.w.Write(b)
}
func (c *http2Conn) Close() (err error) {
if r, ok := c.r.(io.Closer); ok {
err = r.Close()
}
if w, ok := c.w.(io.Closer); ok {
err = w.Close()
}
return
}
func (c *http2Conn) LocalAddr() net.Addr {
return c.localAddr
}
func (c *http2Conn) RemoteAddr() net.Addr {
return c.remoteAddr
}
func (c *http2Conn) SetDeadline(t time.Time) error {
return &net.OpError{Op: "set", Net: "h2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
}
func (c *http2Conn) SetReadDeadline(t time.Time) error {
return &net.OpError{Op: "set", Net: "h2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
}
func (c *http2Conn) SetWriteDeadline(t time.Time) error {
return &net.OpError{Op: "set", Net: "h2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")}
}

View File

@ -0,0 +1,190 @@
package h2
import (
"context"
"crypto/tls"
"errors"
"io"
"net"
"net/http"
"net/http/httputil"
"net/url"
"sync"
"time"
"github.com/go-gost/gost/pkg/dialer"
"github.com/go-gost/gost/pkg/logger"
md "github.com/go-gost/gost/pkg/metadata"
"github.com/go-gost/gost/pkg/registry"
"golang.org/x/net/http2"
)
func init() {
registry.RegisterDialer("h2", NewTLSDialer)
registry.RegisterDialer("h2c", NewDialer)
}
type h2Dialer struct {
clients map[string]*http.Client
clientMutex sync.Mutex
logger logger.Logger
md metadata
h2c bool
}
func NewDialer(opts ...dialer.Option) dialer.Dialer {
options := &dialer.Options{}
for _, opt := range opts {
opt(options)
}
return &h2Dialer{
clients: make(map[string]*http.Client),
logger: options.Logger,
h2c: true,
}
}
func NewTLSDialer(opts ...dialer.Option) dialer.Dialer {
options := &dialer.Options{}
for _, opt := range opts {
opt(options)
}
return &h2Dialer{
clients: make(map[string]*http.Client),
logger: options.Logger,
}
}
func (d *h2Dialer) Init(md md.Metadata) (err error) {
if err = d.parseMetadata(md); err != nil {
return
}
return nil
}
// IsMultiplex implements dialer.Multiplexer interface.
func (d *h2Dialer) IsMultiplex() bool {
return true
}
func (d *h2Dialer) Dial(ctx context.Context, address string, opts ...dialer.DialOption) (net.Conn, error) {
raddr, err := net.ResolveTCPAddr("tcp", address)
if err != nil {
d.logger.Error(err)
return nil, err
}
d.clientMutex.Lock()
client, ok := d.clients[address]
if !ok {
options := &dialer.DialOptions{}
for _, opt := range opts {
opt(options)
}
client = &http.Client{}
if d.h2c {
client.Transport = &http2.Transport{
DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
return d.dial(ctx, network, addr, options)
},
}
} else {
client.Transport = &http.Transport{
TLSClientConfig: d.md.tlsConfig,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return d.dial(ctx, network, addr, options)
},
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
}
d.clients[address] = client
}
d.clientMutex.Unlock()
host := d.md.host
if host == "" {
host = address
}
pr, pw := io.Pipe()
req := &http.Request{
Method: http.MethodConnect,
URL: &url.URL{Scheme: "https", Host: host},
Header: make(http.Header),
ProtoMajor: 2,
ProtoMinor: 0,
Body: pr,
Host: host,
// ContentLength: -1,
}
if d.md.path != "" {
req.Method = http.MethodGet
req.URL.Path = d.md.path
}
if d.logger.IsLevelEnabled(logger.DebugLevel) {
dump, _ := httputil.DumpRequest(req, false)
d.logger.Debug(string(dump))
}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
if d.logger.IsLevelEnabled(logger.DebugLevel) {
dump, _ := httputil.DumpResponse(resp, false)
d.logger.Debug(string(dump))
}
if resp.StatusCode != http.StatusOK {
resp.Body.Close()
return nil, errors.New(resp.Status)
}
conn := &http2Conn{
r: resp.Body,
w: pw,
remoteAddr: raddr,
localAddr: &net.TCPAddr{IP: net.IPv4zero, Port: 0},
}
return conn, nil
}
func (d *h2Dialer) dial(ctx context.Context, network, addr string, opts *dialer.DialOptions) (net.Conn, error) {
dial := opts.DialFunc
if dial != nil {
conn, err := dial(ctx, addr)
if err != nil {
d.logger.Error(err)
} else {
d.logger.WithFields(map[string]interface{}{
"src": conn.LocalAddr().String(),
"dst": addr,
}).Debug("dial with dial func")
}
return conn, err
}
var netd net.Dialer
conn, err := netd.DialContext(ctx, network, addr)
if err != nil {
d.logger.Error(err)
} else {
d.logger.WithFields(map[string]interface{}{
"src": conn.LocalAddr().String(),
"dst": addr,
}).Debugf("dial direct %s/%s", addr, network)
}
return conn, err
}

View File

@ -0,0 +1,43 @@
package h2
import (
"crypto/tls"
"net"
tls_util "github.com/go-gost/gost/pkg/common/util/tls"
md "github.com/go-gost/gost/pkg/metadata"
)
type metadata struct {
path string
host string
tlsConfig *tls.Config
}
func (d *h2Dialer) parseMetadata(md md.Metadata) (err error) {
const (
certFile = "certFile"
keyFile = "keyFile"
caFile = "caFile"
secure = "secure"
serverName = "serverName"
path = "path"
)
d.md.host = md.GetString(serverName)
sn, _, _ := net.SplitHostPort(d.md.host)
if sn == "" {
sn = "localhost"
}
d.md.tlsConfig, err = tls_util.LoadClientConfig(
md.GetString(certFile),
md.GetString(keyFile),
md.GetString(caFile),
md.GetBool(secure),
sn,
)
d.md.path = md.GetString(path)
return
}

View File

@ -5,28 +5,30 @@ import (
"errors" "errors"
"net" "net"
"net/http" "net/http"
"time" "net/http/httputil"
"github.com/go-gost/gost/pkg/common/util"
"github.com/go-gost/gost/pkg/listener" "github.com/go-gost/gost/pkg/listener"
"github.com/go-gost/gost/pkg/logger" "github.com/go-gost/gost/pkg/logger"
md "github.com/go-gost/gost/pkg/metadata" md "github.com/go-gost/gost/pkg/metadata"
"github.com/go-gost/gost/pkg/registry" "github.com/go-gost/gost/pkg/registry"
"golang.org/x/net/http2" "golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
) )
func init() { func init() {
registry.RegisterListener("h2", NewListener) registry.RegisterListener("h2c", NewListener)
registry.RegisterListener("h2", NewTLSListener)
} }
type h2Listener struct { type h2Listener struct {
addr string server *http.Server
net.Listener saddr string
md metadata addr net.Addr
server *http2.Server cqueue chan net.Conn
connChan chan net.Conn
errChan chan error errChan chan error
logger logger.Logger logger logger.Logger
md metadata
h2c bool
} }
func NewListener(opts ...listener.Option) listener.Listener { func NewListener(opts ...listener.Option) listener.Listener {
@ -35,7 +37,19 @@ func NewListener(opts ...listener.Option) listener.Listener {
opt(options) opt(options)
} }
return &h2Listener{ return &h2Listener{
addr: options.Addr, saddr: options.Addr,
logger: options.Logger,
h2c: true,
}
}
func NewTLSListener(opts ...listener.Option) listener.Listener {
options := &listener.Options{}
for _, opt := range opts {
opt(options)
}
return &h2Listener{
saddr: options.Addr,
logger: options.Logger, logger: options.Logger,
} }
} }
@ -45,36 +59,45 @@ func (l *h2Listener) Init(md md.Metadata) (err error) {
return return
} }
ln, err := net.Listen("tcp", l.addr) l.server = &http.Server{
if err != nil { Addr: l.saddr,
return
}
l.Listener = &util.TCPKeepAliveListener{
TCPListener: ln.(*net.TCPListener),
KeepAlivePeriod: l.md.keepAlivePeriod,
}
// TODO: tune http2 server config
l.server = &http2.Server{
// MaxConcurrentStreams: 1000,
PermitProhibitedCipherSuites: true,
IdleTimeout: 5 * time.Minute,
} }
queueSize := l.md.connQueueSize ln, err := net.Listen("tcp", l.saddr)
if queueSize <= 0 { if err != nil {
queueSize = defaultQueueSize return err
} }
l.connChan = make(chan net.Conn, queueSize) l.addr = ln.Addr()
if l.h2c {
l.server.Handler = h2c.NewHandler(
http.HandlerFunc(l.handleFunc), &http2.Server{})
} else {
l.server.Handler = http.HandlerFunc(l.handleFunc)
l.server.TLSConfig = l.md.tlsConfig
if err := http2.ConfigureServer(l.server, nil); err != nil {
ln.Close()
return err
}
ln = tls.NewListener(ln, l.md.tlsConfig)
}
l.cqueue = make(chan net.Conn, l.md.backlog)
l.errChan = make(chan error, 1) l.errChan = make(chan error, 1)
go l.listenLoop() go func() {
if err := l.server.Serve(ln); err != nil {
l.logger.Error(err)
}
}()
return return
} }
func (l *h2Listener) Accept() (conn net.Conn, err error) { func (l *h2Listener) Accept() (conn net.Conn, err error) {
var ok bool var ok bool
select { select {
case conn = <-l.connChan: case conn = <-l.cqueue:
case err, ok = <-l.errChan: case err, ok = <-l.errChan:
if !ok { if !ok {
err = listener.ErrClosed err = listener.ErrClosed
@ -83,58 +106,36 @@ func (l *h2Listener) Accept() (conn net.Conn, err error) {
return return
} }
func (l *h2Listener) listenLoop() { func (l *h2Listener) Addr() net.Addr {
for { return l.addr
conn, err := l.Listener.Accept() }
if err != nil {
// log.Log("[http2] accept:", err) func (l *h2Listener) Close() (err error) {
select {
case <-l.errChan:
default:
err = l.server.Close()
l.errChan <- err l.errChan <- err
close(l.errChan) close(l.errChan)
return
} }
go l.handleLoop(conn) return nil
}
}
func (l *h2Listener) handleLoop(conn net.Conn) {
if l.md.tlsConfig != nil {
tlsConn := tls.Server(conn, l.md.tlsConfig)
// NOTE: HTTP2 server will check the TLS version,
// so we must ensure that the TLS connection is handshake completed.
if err := tlsConn.Handshake(); err != nil {
// log.Logf("[http2] %s - %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err)
return
}
conn = tlsConn
}
opt := http2.ServeConnOpts{
Handler: http.HandlerFunc(l.handleFunc),
}
l.server.ServeConn(conn, &opt)
} }
func (l *h2Listener) handleFunc(w http.ResponseWriter, r *http.Request) { func (l *h2Listener) handleFunc(w http.ResponseWriter, r *http.Request) {
/* if l.logger.IsLevelEnabled(logger.DebugLevel) {
log.Logf("[http2] %s -> %s %s %s %s",
r.RemoteAddr, r.Host, r.Method, r.RequestURI, r.Proto)
if Debug {
dump, _ := httputil.DumpRequest(r, false) dump, _ := httputil.DumpRequest(r, false)
log.Log("[http2]", string(dump)) l.logger.Debug(string(dump))
} }
*/
// w.Header().Set("Proxy-Agent", "gost/"+Version)
conn, err := l.upgrade(w, r) conn, err := l.upgrade(w, r)
if err != nil { if err != nil {
// log.Logf("[http2] %s - %s %s %s %s: %s", l.logger.Error(err)
// r.RemoteAddr, r.Host, r.Method, r.RequestURI, r.Proto, err)
return return
} }
select { select {
case l.connChan <- conn: case l.cqueue <- conn:
default: default:
conn.Close() conn.Close()
// log.Logf("[http2] %s - %s: connection queue is full", conn.RemoteAddr(), conn.LocalAddr()) l.logger.Warnf("connection queue is full, client %s discarded", r.RemoteAddr)
} }
<-conn.closed // NOTE: we need to wait for streaming end, or the connection will be closed <-conn.closed // NOTE: we need to wait for streaming end, or the connection will be closed
@ -166,7 +167,7 @@ func (l *h2Listener) upgrade(w http.ResponseWriter, r *http.Request) (*conn, err
return &conn{ return &conn{
r: r.Body, r: r.Body,
w: flushWriter{w}, w: flushWriter{w},
localAddr: l.Listener.Addr(), localAddr: l.addr,
remoteAddr: remoteAddr, remoteAddr: remoteAddr,
closed: make(chan struct{}), closed: make(chan struct{}),
}, nil }, nil

View File

@ -2,28 +2,19 @@ package h2
import ( import (
"crypto/tls" "crypto/tls"
"net/http"
"time"
tls_util "github.com/go-gost/gost/pkg/common/util/tls" tls_util "github.com/go-gost/gost/pkg/common/util/tls"
md "github.com/go-gost/gost/pkg/metadata" md "github.com/go-gost/gost/pkg/metadata"
) )
const ( const (
defaultQueueSize = 128 defaultBacklog = 128
) )
type metadata struct { type metadata struct {
path string path string
tlsConfig *tls.Config tlsConfig *tls.Config
handshakeTimeout time.Duration backlog int
readHeaderTimeout time.Duration
readBufferSize int
writeBufferSize int
enableCompression bool
responseHeader http.Header
connQueueSize int
keepAlivePeriod time.Duration
} }
func (l *h2Listener) parseMetadata(md md.Metadata) (err error) { func (l *h2Listener) parseMetadata(md md.Metadata) (err error) {
@ -32,11 +23,7 @@ func (l *h2Listener) parseMetadata(md md.Metadata) (err error) {
certFile = "certFile" certFile = "certFile"
keyFile = "keyFile" keyFile = "keyFile"
caFile = "caFile" caFile = "caFile"
handshakeTimeout = "handshakeTimeout" backlog = "backlog"
readHeaderTimeout = "readHeaderTimeout"
readBufferSize = "readBufferSize"
writeBufferSize = "writeBufferSize"
connQueueSize = "connQueueSize"
) )
l.md.tlsConfig, err = tls_util.LoadServerConfig( l.md.tlsConfig, err = tls_util.LoadServerConfig(
@ -48,5 +35,11 @@ func (l *h2Listener) parseMetadata(md md.Metadata) (err error) {
return return
} }
l.md.backlog = md.GetInt(backlog)
if l.md.backlog <= 0 {
l.md.backlog = defaultBacklog
}
l.md.path = md.GetString(path)
return return
} }

View File

@ -19,13 +19,13 @@ func init() {
} }
type http2Listener struct { type http2Listener struct {
saddr string
md metadata
server *http.Server server *http.Server
saddr string
addr net.Addr addr net.Addr
cqueue chan net.Conn cqueue chan net.Conn
errChan chan error errChan chan error
logger logger.Logger logger logger.Logger
md metadata
} }
func NewListener(opts ...listener.Option) listener.Listener { func NewListener(opts ...listener.Option) listener.Listener {