add mws dialer

This commit is contained in:
ginuerzh
2021-12-17 14:05:59 +08:00
parent bfe5eae172
commit 8e31e532e4
23 changed files with 795 additions and 137 deletions

View File

@ -4,8 +4,9 @@ import (
"crypto/tls"
"net"
"net/http"
"net/http/httputil"
ws_util "github.com/go-gost/gost/pkg/common/util/ws"
ws_util "github.com/go-gost/gost/pkg/internal/util/ws"
"github.com/go-gost/gost/pkg/listener"
"github.com/go-gost/gost/pkg/logger"
md "github.com/go-gost/gost/pkg/metadata"
@ -20,14 +21,14 @@ func init() {
type wsListener struct {
saddr string
md metadata
addr net.Addr
upgrader *websocket.Upgrader
srv *http.Server
tlsEnabled bool
connChan chan net.Conn
cqueue chan net.Conn
errChan chan error
logger logger.Logger
md metadata
}
func NewListener(opts ...listener.Option) listener.Listener {
@ -48,8 +49,8 @@ func NewTLSListener(opts ...listener.Option) listener.Listener {
}
return &wsListener{
saddr: options.Addr,
tlsEnabled: true,
logger: options.Logger,
tlsEnabled: true,
}
}
@ -62,35 +63,26 @@ func (l *wsListener) Init(md md.Metadata) (err error) {
HandshakeTimeout: l.md.handshakeTimeout,
ReadBufferSize: l.md.readBufferSize,
WriteBufferSize: l.md.writeBufferSize,
CheckOrigin: func(r *http.Request) bool { return true },
EnableCompression: l.md.enableCompression,
CheckOrigin: func(r *http.Request) bool { return true },
}
path := l.md.path
if path == "" {
path = defaultPath
}
mux := http.NewServeMux()
mux.Handle(path, http.HandlerFunc(l.upgrade))
mux.Handle(l.md.path, http.HandlerFunc(l.upgrade))
l.srv = &http.Server{
Addr: l.saddr,
TLSConfig: l.md.tlsConfig,
Handler: mux,
ReadHeaderTimeout: l.md.readHeaderTimeout,
}
queueSize := l.md.connQueueSize
if queueSize <= 0 {
queueSize = defaultQueueSize
}
l.connChan = make(chan net.Conn, queueSize)
l.cqueue = make(chan net.Conn, l.md.backlog)
l.errChan = make(chan error, 1)
ln, err := net.Listen("tcp", l.saddr)
if err != nil {
return
}
if l.md.tlsConfig != nil {
if l.tlsEnabled {
ln = tls.NewListener(ln, l.md.tlsConfig)
}
@ -110,7 +102,7 @@ func (l *wsListener) Init(md md.Metadata) (err error) {
func (l *wsListener) Accept() (conn net.Conn, err error) {
var ok bool
select {
case conn = <-l.connChan:
case conn = <-l.cqueue:
case err, ok = <-l.errChan:
if !ok {
err = listener.ErrClosed
@ -128,16 +120,25 @@ func (l *wsListener) Addr() net.Addr {
}
func (l *wsListener) upgrade(w http.ResponseWriter, r *http.Request) {
conn, err := l.upgrader.Upgrade(w, r, l.md.responseHeader)
if l.logger.IsLevelEnabled(logger.DebugLevel) {
log := l.logger.WithFields(map[string]interface{}{
"local": l.addr.String(),
"remote": r.RemoteAddr,
})
dump, _ := httputil.DumpRequest(r, false)
log.Debug(string(dump))
}
conn, err := l.upgrader.Upgrade(w, r, l.md.header)
if err != nil {
l.logger.Error(err)
return
}
select {
case l.connChan <- ws_util.WebsocketServerConn(conn):
case l.cqueue <- ws_util.Conn(conn):
default:
conn.Close()
l.logger.Warn("connection queue is full")
l.logger.Warnf("connection queue is full, client %s discarded", conn.RemoteAddr())
}
}

View File

@ -2,6 +2,7 @@ package ws
import (
"crypto/tls"
"fmt"
"net/http"
"time"
@ -10,35 +11,40 @@ import (
)
const (
defaultPath = "/ws"
defaultQueueSize = 128
defaultPath = "/ws"
defaultBacklog = 128
)
type metadata struct {
path string
tlsConfig *tls.Config
path string
backlog int
tlsConfig *tls.Config
handshakeTimeout time.Duration
readHeaderTimeout time.Duration
readBufferSize int
writeBufferSize int
enableCompression bool
responseHeader http.Header
connQueueSize int
header http.Header
}
func (l *wsListener) parseMetadata(md md.Metadata) (err error) {
const (
path = "path"
certFile = "certFile"
keyFile = "keyFile"
caFile = "caFile"
certFile = "certFile"
keyFile = "keyFile"
caFile = "caFile"
path = "path"
backlog = "backlog"
handshakeTimeout = "handshakeTimeout"
readHeaderTimeout = "readHeaderTimeout"
readBufferSize = "readBufferSize"
writeBufferSize = "writeBufferSize"
enableCompression = "enableCompression"
responseHeader = "responseHeader"
connQueueSize = "connQueueSize"
header = "header"
)
l.md.tlsConfig, err = tls_util.LoadServerConfig(
@ -51,15 +57,28 @@ func (l *wsListener) parseMetadata(md md.Metadata) (err error) {
}
l.md.path = md.GetString(path)
l.md.connQueueSize = md.GetInt(connQueueSize)
if l.md.connQueueSize <= 0 {
l.md.connQueueSize = defaultQueueSize
if l.md.path == "" {
l.md.path = defaultPath
}
l.md.enableCompression = md.GetBool(enableCompression)
l.md.readBufferSize = md.GetInt(readBufferSize)
l.md.writeBufferSize = md.GetInt(writeBufferSize)
l.md.backlog = md.GetInt(backlog)
if l.md.backlog <= 0 {
l.md.backlog = defaultBacklog
}
l.md.handshakeTimeout = md.GetDuration(handshakeTimeout)
l.md.readHeaderTimeout = md.GetDuration(readHeaderTimeout)
l.md.readBufferSize = md.GetInt(readBufferSize)
l.md.writeBufferSize = md.GetInt(writeBufferSize)
l.md.enableCompression = md.GetBool(enableCompression)
if mm, _ := md.Get(header).(map[interface{}]interface{}); len(mm) > 0 {
h := http.Header{}
for k, v := range mm {
h.Add(fmt.Sprintf("%v", k), fmt.Sprintf("%v", v))
}
l.md.header = h
}
return
}

View File

@ -4,8 +4,9 @@ import (
"crypto/tls"
"net"
"net/http"
"net/http/httputil"
ws_util "github.com/go-gost/gost/pkg/common/util/ws"
ws_util "github.com/go-gost/gost/pkg/internal/util/ws"
"github.com/go-gost/gost/pkg/listener"
"github.com/go-gost/gost/pkg/logger"
md "github.com/go-gost/gost/pkg/metadata"
@ -16,18 +17,19 @@ import (
func init() {
registry.RegisterListener("mws", NewListener)
registry.RegisterListener("mwss", NewListener)
registry.RegisterListener("mwss", NewTLSListener)
}
type mwsListener struct {
saddr string
md metadata
addr net.Addr
upgrader *websocket.Upgrader
srv *http.Server
connChan chan net.Conn
errChan chan error
logger logger.Logger
saddr string
addr net.Addr
upgrader *websocket.Upgrader
srv *http.Server
cqueue chan net.Conn
errChan chan error
logger logger.Logger
md metadata
tlsEnabled bool
}
func NewListener(opts ...listener.Option) listener.Listener {
@ -36,10 +38,23 @@ func NewListener(opts ...listener.Option) listener.Listener {
opt(options)
}
return &mwsListener{
saddr: options.Addr,
logger: options.Logger,
}
}
func NewTLSListener(opts ...listener.Option) listener.Listener {
options := &listener.Options{}
for _, opt := range opts {
opt(options)
}
return &mwsListener{
saddr: options.Addr,
logger: options.Logger,
tlsEnabled: true,
}
}
func (l *mwsListener) Init(md md.Metadata) (err error) {
if err = l.parseMetadata(md); err != nil {
return
@ -49,8 +64,8 @@ func (l *mwsListener) Init(md md.Metadata) (err error) {
HandshakeTimeout: l.md.handshakeTimeout,
ReadBufferSize: l.md.readBufferSize,
WriteBufferSize: l.md.writeBufferSize,
CheckOrigin: func(r *http.Request) bool { return true },
EnableCompression: l.md.enableCompression,
CheckOrigin: func(r *http.Request) bool { return true },
}
path := l.md.path
@ -61,19 +76,18 @@ func (l *mwsListener) Init(md md.Metadata) (err error) {
mux.Handle(path, http.HandlerFunc(l.upgrade))
l.srv = &http.Server{
Addr: l.saddr,
TLSConfig: l.md.tlsConfig,
Handler: mux,
ReadHeaderTimeout: l.md.readHeaderTimeout,
}
l.connChan = make(chan net.Conn, l.md.connQueueSize)
l.cqueue = make(chan net.Conn, l.md.backlog)
l.errChan = make(chan error, 1)
ln, err := net.Listen("tcp", l.saddr)
if err != nil {
return
}
if l.md.tlsConfig != nil {
if l.tlsEnabled {
ln = tls.NewListener(ln, l.md.tlsConfig)
}
@ -93,7 +107,7 @@ func (l *mwsListener) Init(md md.Metadata) (err error) {
func (l *mwsListener) Accept() (conn net.Conn, err error) {
var ok bool
select {
case conn = <-l.connChan:
case conn = <-l.cqueue:
case err, ok = <-l.errChan:
if !ok {
err = listener.ErrClosed
@ -111,20 +125,31 @@ func (l *mwsListener) Addr() net.Addr {
}
func (l *mwsListener) upgrade(w http.ResponseWriter, r *http.Request) {
conn, err := l.upgrader.Upgrade(w, r, l.md.responseHeader)
if l.logger.IsLevelEnabled(logger.DebugLevel) {
log := l.logger.WithFields(map[string]interface{}{
"local": l.addr.String(),
"remote": r.RemoteAddr,
})
dump, _ := httputil.DumpRequest(r, false)
log.Debug(string(dump))
}
conn, err := l.upgrader.Upgrade(w, r, l.md.header)
if err != nil {
l.logger.Error(err)
return
}
l.mux(ws_util.WebsocketServerConn(conn))
l.mux(ws_util.Conn(conn))
}
func (l *mwsListener) mux(conn net.Conn) {
defer conn.Close()
smuxConfig := smux.DefaultConfig()
smuxConfig.KeepAliveDisabled = l.md.muxKeepAliveDisabled
if l.md.muxKeepAlivePeriod > 0 {
smuxConfig.KeepAliveInterval = l.md.muxKeepAlivePeriod
if l.md.muxKeepAliveInterval > 0 {
smuxConfig.KeepAliveInterval = l.md.muxKeepAliveInterval
}
if l.md.muxKeepAliveTimeout > 0 {
smuxConfig.KeepAliveTimeout = l.md.muxKeepAliveTimeout
@ -148,17 +173,17 @@ func (l *mwsListener) mux(conn net.Conn) {
for {
stream, err := session.AcceptStream()
if err != nil {
l.logger.Error("accept stream:", err)
l.logger.Error("accept stream: ", err)
return
}
select {
case l.connChan <- stream:
case l.cqueue <- stream:
case <-stream.GetDieCh():
stream.Close()
default:
stream.Close()
l.logger.Error("connection queue is full")
l.logger.Warnf("connection queue is full, client %s discarded", stream.RemoteAddr())
}
}
}

View File

@ -2,6 +2,7 @@ package mux
import (
"crypto/tls"
"fmt"
"net/http"
"time"
@ -10,45 +11,48 @@ import (
)
const (
defaultPath = "/ws"
defaultQueueSize = 128
defaultPath = "/ws"
defaultBacklog = 128
)
type metadata struct {
path string
tlsConfig *tls.Config
path string
backlog int
tlsConfig *tls.Config
header http.Header
handshakeTimeout time.Duration
readHeaderTimeout time.Duration
readBufferSize int
writeBufferSize int
enableCompression bool
responseHeader http.Header
muxKeepAliveDisabled bool
muxKeepAlivePeriod time.Duration
muxKeepAliveInterval time.Duration
muxKeepAliveTimeout time.Duration
muxMaxFrameSize int
muxMaxReceiveBuffer int
muxMaxStreamBuffer int
connQueueSize int
}
func (l *mwsListener) parseMetadata(md md.Metadata) (err error) {
const (
path = "path"
certFile = "certFile"
keyFile = "keyFile"
caFile = "caFile"
path = "path"
backlog = "backlog"
header = "header"
certFile = "certFile"
keyFile = "keyFile"
caFile = "caFile"
handshakeTimeout = "handshakeTimeout"
readHeaderTimeout = "readHeaderTimeout"
readBufferSize = "readBufferSize"
writeBufferSize = "writeBufferSize"
enableCompression = "enableCompression"
responseHeader = "responseHeader"
connQueueSize = "connQueueSize"
muxKeepAliveDisabled = "muxKeepAliveDisabled"
muxKeepAlivePeriod = "muxKeepAlivePeriod"
muxKeepAliveInterval = "muxKeepAliveInterval"
muxKeepAliveTimeout = "muxKeepAliveTimeout"
muxMaxFrameSize = "muxMaxFrameSize"
muxMaxReceiveBuffer = "muxMaxReceiveBuffer"
@ -64,5 +68,35 @@ func (l *mwsListener) parseMetadata(md md.Metadata) (err error) {
return
}
l.md.path = md.GetString(path)
if l.md.path == "" {
l.md.path = defaultPath
}
l.md.backlog = md.GetInt(backlog)
if l.md.backlog <= 0 {
l.md.backlog = defaultBacklog
}
l.md.handshakeTimeout = md.GetDuration(handshakeTimeout)
l.md.readHeaderTimeout = md.GetDuration(readHeaderTimeout)
l.md.readBufferSize = md.GetInt(readBufferSize)
l.md.writeBufferSize = md.GetInt(writeBufferSize)
l.md.enableCompression = md.GetBool(enableCompression)
l.md.muxKeepAliveDisabled = md.GetBool(muxKeepAliveDisabled)
l.md.muxKeepAliveInterval = md.GetDuration(muxKeepAliveInterval)
l.md.muxKeepAliveTimeout = md.GetDuration(muxKeepAliveTimeout)
l.md.muxMaxFrameSize = md.GetInt(muxMaxFrameSize)
l.md.muxMaxReceiveBuffer = md.GetInt(muxMaxReceiveBuffer)
l.md.muxMaxStreamBuffer = md.GetInt(muxMaxStreamBuffer)
if mm, _ := md.Get(header).(map[interface{}]interface{}); len(mm) > 0 {
h := http.Header{}
for k, v := range mm {
h.Add(fmt.Sprintf("%v", k), fmt.Sprintf("%v", v))
}
l.md.header = h
}
return
}