add mux support for tls & ws listeners

This commit is contained in:
ginuerzh 2021-04-11 14:33:17 +08:00
parent 4e04e7ed86
commit f6bd34f7a3
9 changed files with 455 additions and 37 deletions

1
go.mod
View File

@ -9,4 +9,5 @@ require (
github.com/shadowsocks/go-shadowsocks2 v0.1.4 github.com/shadowsocks/go-shadowsocks2 v0.1.4
github.com/shadowsocks/shadowsocks-go v0.0.0-20200409064450-3e585ff90601 github.com/shadowsocks/shadowsocks-go v0.0.0-20200409064450-3e585ff90601
github.com/sirupsen/logrus v1.8.1 github.com/sirupsen/logrus v1.8.1
github.com/xtaci/smux v1.5.15
) )

2
go.sum
View File

@ -18,6 +18,8 @@ github.com/sirupsen/logrus v1.8.1 h1:dJKuHgqk1NNQlqoA6BTlM1Wf9DOH3NBjQyu0h9+AZZE
github.com/sirupsen/logrus v1.8.1/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= github.com/sirupsen/logrus v1.8.1/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0=
github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w= github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/xtaci/smux v1.5.15 h1:6hMiXswcleXj5oNfcJc+DXS8Vj36XX2LaX98udog6Kc=
github.com/xtaci/smux v1.5.15/go.mod h1:OMlQbT5vcgl2gb49mFkYo6SMf+zP3rcjcwQz7ZU7IGY=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83 h1:/ZScEX8SfEmUGRHs0gxpqteO5nfNW6axyZbBdw9A12g= golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83 h1:/ZScEX8SfEmUGRHs0gxpqteO5nfNW6axyZbBdw9A12g=
golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I=

View File

@ -0,0 +1,140 @@
package mux
import (
"crypto/tls"
"errors"
"net"
"github.com/go-gost/gost/logger"
"github.com/go-gost/gost/server/listener"
"github.com/go-gost/gost/utils"
"github.com/xtaci/smux"
)
var (
_ listener.Listener = (*Listener)(nil)
)
type Listener struct {
md metadata
net.Listener
connChan chan net.Conn
errChan chan error
logger logger.Logger
}
func NewListener(opts ...listener.Option) *Listener {
options := &listener.Options{}
for _, opt := range opts {
opt(options)
}
return &Listener{
logger: options.Logger,
}
}
func (l *Listener) Init(md listener.Metadata) (err error) {
l.md, err = l.parseMetadata(md)
if err != nil {
return
}
ln, err := net.Listen("tcp", l.md.addr)
if err != nil {
return
}
l.Listener = tls.NewListener(ln, l.md.tlsConfig)
queueSize := l.md.connQueueSize
if queueSize <= 0 {
queueSize = defaultQueueSize
}
l.connChan = make(chan net.Conn, queueSize)
l.errChan = make(chan error, 1)
go l.listenLoop()
return
}
func (l *Listener) listenLoop() {
for {
conn, err := l.Listener.Accept()
if err != nil {
l.errChan <- err
close(l.errChan)
return
}
go l.mux(conn)
}
}
func (l *Listener) mux(conn net.Conn) {
smuxConfig := smux.DefaultConfig()
smuxConfig.KeepAliveDisabled = l.md.muxKeepAliveDisabled
if l.md.muxKeepAlivePeriod > 0 {
smuxConfig.KeepAliveInterval = l.md.muxKeepAlivePeriod
}
if l.md.muxKeepAliveTimeout > 0 {
smuxConfig.KeepAliveTimeout = l.md.muxKeepAliveTimeout
}
if l.md.muxMaxFrameSize > 0 {
smuxConfig.MaxFrameSize = l.md.muxMaxFrameSize
}
if l.md.muxMaxReceiveBuffer > 0 {
smuxConfig.MaxReceiveBuffer = l.md.muxMaxReceiveBuffer
}
if l.md.muxMaxStreamBuffer > 0 {
smuxConfig.MaxStreamBuffer = l.md.muxMaxStreamBuffer
}
session, err := smux.Server(conn, smuxConfig)
if err != nil {
l.logger.Error(err)
return
}
defer session.Close()
for {
stream, err := session.AcceptStream()
if err != nil {
l.logger.Error("accept stream:", err)
return
}
select {
case l.connChan <- stream:
case <-stream.GetDieCh():
default:
stream.Close()
l.logger.Error("connection queue is full")
}
}
}
func (l *Listener) Accept() (conn net.Conn, err error) {
var ok bool
select {
case conn = <-l.connChan:
case err, ok = <-l.errChan:
if !ok {
err = errors.New("accpet on closed listener")
}
}
return
}
func (l *Listener) parseMetadata(md listener.Metadata) (m metadata, err error) {
if val, ok := md[addr]; ok {
m.addr = val
} else {
err = errors.New("missing address")
return
}
m.tlsConfig, err = utils.LoadTLSConfig(md[certFile], md[keyFile], md[caFile])
if err != nil {
return
}
return
}

View File

@ -0,0 +1,38 @@
package mux
import (
"crypto/tls"
"time"
)
const (
addr = "addr"
certFile = "certFile"
keyFile = "keyFile"
caFile = "caFile"
muxKeepAliveDisabled = "muxKeepAliveDisabled"
muxKeepAlivePeriod = "muxKeepAlivePeriod"
muxKeepAliveTimeout = "muxKeepAliveTimeout"
muxMaxFrameSize = "muxMaxFrameSize"
muxMaxReceiveBuffer = "muxMaxReceiveBuffer"
muxMaxStreamBuffer = "muxMaxStreamBuffer"
)
const (
defaultQueueSize = 128
)
type metadata struct {
addr string
tlsConfig *tls.Config
muxKeepAliveDisabled bool
muxKeepAlivePeriod time.Duration
muxKeepAliveTimeout time.Duration
muxMaxFrameSize int
muxMaxReceiveBuffer int
muxMaxStreamBuffer int
connQueueSize int
}

View File

@ -1,11 +1,10 @@
package tcp package ws
import ( import (
"crypto/tls" "crypto/tls"
"errors" "errors"
"net" "net"
"net/http" "net/http"
"time"
"github.com/go-gost/gost/logger" "github.com/go-gost/gost/logger"
"github.com/go-gost/gost/server/listener" "github.com/go-gost/gost/server/listener"
@ -89,12 +88,6 @@ func (l *Listener) Init(md listener.Metadata) (err error) {
close(l.errChan) close(l.errChan)
}() }()
select {
case err = <-l.errChan:
return
case <-time.After(100 * time.Millisecond):
}
return return
} }
@ -138,36 +131,9 @@ func (l *Listener) upgrade(w http.ResponseWriter, r *http.Request) {
} }
select { select {
case l.connChan <- &websocketConn{Conn: conn}: case l.connChan <- utils.WebsocketServerConn(conn):
default: default:
conn.Close() conn.Close()
l.logger.Warn("connection queue is full") l.logger.Warn("connection queue is full")
} }
} }
type websocketConn struct {
*websocket.Conn
rb []byte
}
func (c *websocketConn) Read(b []byte) (n int, err error) {
if len(c.rb) == 0 {
_, c.rb, err = c.ReadMessage()
}
n = copy(b, c.rb)
c.rb = c.rb[n:]
return
}
func (c *websocketConn) Write(b []byte) (n int, err error) {
err = c.WriteMessage(websocket.BinaryMessage, b)
n = len(b)
return
}
func (c *websocketConn) SetDeadline(t time.Time) error {
if err := c.SetReadDeadline(t); err != nil {
return err
}
return c.SetWriteDeadline(t)
}

View File

@ -1,4 +1,4 @@
package tcp package ws
import ( import (
"crypto/tls" "crypto/tls"

View File

@ -0,0 +1,177 @@
package mux
import (
"crypto/tls"
"errors"
"net"
"net/http"
"github.com/go-gost/gost/logger"
"github.com/go-gost/gost/server/listener"
"github.com/go-gost/gost/utils"
"github.com/gorilla/websocket"
"github.com/xtaci/smux"
)
var (
_ listener.Listener = (*Listener)(nil)
)
type Listener struct {
md metadata
addr net.Addr
upgrader *websocket.Upgrader
srv *http.Server
connChan chan net.Conn
errChan chan error
logger logger.Logger
}
func NewListener(opts ...listener.Option) *Listener {
options := &listener.Options{}
for _, opt := range opts {
opt(options)
}
return &Listener{
logger: options.Logger,
}
}
func (l *Listener) Init(md listener.Metadata) (err error) {
l.md, err = l.parseMetadata(md)
if err != nil {
return
}
l.upgrader = &websocket.Upgrader{
HandshakeTimeout: l.md.handshakeTimeout,
ReadBufferSize: l.md.readBufferSize,
WriteBufferSize: l.md.writeBufferSize,
CheckOrigin: func(r *http.Request) bool { return true },
EnableCompression: l.md.enableCompression,
}
path := l.md.path
if path == "" {
path = defaultPath
}
mux := http.NewServeMux()
mux.Handle(path, http.HandlerFunc(l.upgrade))
l.srv = &http.Server{
Addr: l.md.addr,
TLSConfig: l.md.tlsConfig,
Handler: mux,
ReadHeaderTimeout: l.md.readHeaderTimeout,
}
queueSize := l.md.connQueueSize
if queueSize <= 0 {
queueSize = defaultQueueSize
}
l.connChan = make(chan net.Conn, queueSize)
l.errChan = make(chan error, 1)
ln, err := net.Listen("tcp", l.md.addr)
if err != nil {
return
}
if l.md.tlsConfig != nil {
ln = tls.NewListener(ln, l.md.tlsConfig)
}
l.addr = ln.Addr()
go func() {
err := l.srv.Serve(ln)
if err != nil {
l.errChan <- err
}
close(l.errChan)
}()
return
}
func (l *Listener) Accept() (conn net.Conn, err error) {
select {
case conn = <-l.connChan:
case err = <-l.errChan:
}
return
}
func (l *Listener) Close() error {
return l.srv.Close()
}
func (l *Listener) Addr() net.Addr {
return l.addr
}
func (l *Listener) parseMetadata(md listener.Metadata) (m metadata, err error) {
if val, ok := md[addr]; ok {
m.addr = val
} else {
err = errors.New("missing address")
return
}
m.tlsConfig, err = utils.LoadTLSConfig(md[certFile], md[keyFile], md[caFile])
if err != nil {
return
}
return
}
func (l *Listener) upgrade(w http.ResponseWriter, r *http.Request) {
conn, err := l.upgrader.Upgrade(w, r, l.md.responseHeader)
if err != nil {
l.logger.Error(err)
return
}
l.mux(utils.WebsocketServerConn(conn))
}
func (l *Listener) mux(conn net.Conn) {
smuxConfig := smux.DefaultConfig()
smuxConfig.KeepAliveDisabled = l.md.muxKeepAliveDisabled
if l.md.muxKeepAlivePeriod > 0 {
smuxConfig.KeepAliveInterval = l.md.muxKeepAlivePeriod
}
if l.md.muxKeepAliveTimeout > 0 {
smuxConfig.KeepAliveTimeout = l.md.muxKeepAliveTimeout
}
if l.md.muxMaxFrameSize > 0 {
smuxConfig.MaxFrameSize = l.md.muxMaxFrameSize
}
if l.md.muxMaxReceiveBuffer > 0 {
smuxConfig.MaxReceiveBuffer = l.md.muxMaxReceiveBuffer
}
if l.md.muxMaxStreamBuffer > 0 {
smuxConfig.MaxStreamBuffer = l.md.muxMaxStreamBuffer
}
session, err := smux.Server(conn, smuxConfig)
if err != nil {
l.logger.Error(err)
return
}
defer session.Close()
for {
stream, err := session.AcceptStream()
if err != nil {
l.logger.Error("accept stream:", err)
return
}
select {
case l.connChan <- stream:
case <-stream.GetDieCh():
default:
stream.Close()
l.logger.Error("connection queue is full")
}
}
}

View File

@ -0,0 +1,54 @@
package mux
import (
"crypto/tls"
"net/http"
"time"
)
const (
addr = "addr"
path = "path"
certFile = "certFile"
keyFile = "keyFile"
caFile = "caFile"
handshakeTimeout = "handshakeTimeout"
readHeaderTimeout = "readHeaderTimeout"
readBufferSize = "readBufferSize"
writeBufferSize = "writeBufferSize"
enableCompression = "enableCompression"
responseHeader = "responseHeader"
connQueueSize = "connQueueSize"
muxKeepAliveDisabled = "muxKeepAliveDisabled"
muxKeepAlivePeriod = "muxKeepAlivePeriod"
muxKeepAliveTimeout = "muxKeepAliveTimeout"
muxMaxFrameSize = "muxMaxFrameSize"
muxMaxReceiveBuffer = "muxMaxReceiveBuffer"
muxMaxStreamBuffer = "muxMaxStreamBuffer"
)
const (
defaultPath = "/ws"
defaultQueueSize = 128
)
type metadata struct {
addr string
path string
tlsConfig *tls.Config
handshakeTimeout time.Duration
readHeaderTimeout time.Duration
readBufferSize int
writeBufferSize int
enableCompression bool
responseHeader http.Header
muxKeepAliveDisabled bool
muxKeepAlivePeriod time.Duration
muxKeepAliveTimeout time.Duration
muxMaxFrameSize int
muxMaxReceiveBuffer int
muxMaxStreamBuffer int
connQueueSize int
}

View File

@ -1 +1,41 @@
package utils package utils
import (
"net"
"time"
"github.com/gorilla/websocket"
)
type websocketConn struct {
*websocket.Conn
rb []byte
}
func WebsocketServerConn(conn *websocket.Conn) net.Conn {
return &websocketConn{
Conn: conn,
}
}
func (c *websocketConn) Read(b []byte) (n int, err error) {
if len(c.rb) == 0 {
_, c.rb, err = c.ReadMessage()
}
n = copy(b, c.rb)
c.rb = c.rb[n:]
return
}
func (c *websocketConn) Write(b []byte) (n int, err error) {
err = c.WriteMessage(websocket.BinaryMessage, b)
n = len(b)
return
}
func (c *websocketConn) SetDeadline(t time.Time) error {
if err := c.SetReadDeadline(t); err != nil {
return err
}
return c.SetWriteDeadline(t)
}