diff --git a/go.mod b/go.mod index ca2524e..7744f20 100644 --- a/go.mod +++ b/go.mod @@ -9,4 +9,5 @@ require ( github.com/shadowsocks/go-shadowsocks2 v0.1.4 github.com/shadowsocks/shadowsocks-go v0.0.0-20200409064450-3e585ff90601 github.com/sirupsen/logrus v1.8.1 + github.com/xtaci/smux v1.5.15 ) diff --git a/go.sum b/go.sum index dfe6401..6f69bf7 100644 --- a/go.sum +++ b/go.sum @@ -18,6 +18,8 @@ github.com/sirupsen/logrus v1.8.1 h1:dJKuHgqk1NNQlqoA6BTlM1Wf9DOH3NBjQyu0h9+AZZE github.com/sirupsen/logrus v1.8.1/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/xtaci/smux v1.5.15 h1:6hMiXswcleXj5oNfcJc+DXS8Vj36XX2LaX98udog6Kc= +github.com/xtaci/smux v1.5.15/go.mod h1:OMlQbT5vcgl2gb49mFkYo6SMf+zP3rcjcwQz7ZU7IGY= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83 h1:/ZScEX8SfEmUGRHs0gxpqteO5nfNW6axyZbBdw9A12g= golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= diff --git a/server/listener/tls/mux/listener.go b/server/listener/tls/mux/listener.go new file mode 100644 index 0000000..ac33ec8 --- /dev/null +++ b/server/listener/tls/mux/listener.go @@ -0,0 +1,140 @@ +package mux + +import ( + "crypto/tls" + "errors" + "net" + + "github.com/go-gost/gost/logger" + "github.com/go-gost/gost/server/listener" + "github.com/go-gost/gost/utils" + "github.com/xtaci/smux" +) + +var ( + _ listener.Listener = (*Listener)(nil) +) + +type Listener struct { + md metadata + net.Listener + connChan chan net.Conn + errChan chan error + logger logger.Logger +} + +func NewListener(opts ...listener.Option) *Listener { + options := &listener.Options{} + for _, opt := range opts { + opt(options) + } + return &Listener{ + logger: options.Logger, + } +} + +func (l *Listener) Init(md listener.Metadata) (err error) { + l.md, err = l.parseMetadata(md) + if err != nil { + return + } + + ln, err := net.Listen("tcp", l.md.addr) + if err != nil { + return + } + l.Listener = tls.NewListener(ln, l.md.tlsConfig) + + queueSize := l.md.connQueueSize + if queueSize <= 0 { + queueSize = defaultQueueSize + } + l.connChan = make(chan net.Conn, queueSize) + l.errChan = make(chan error, 1) + + go l.listenLoop() + + return +} + +func (l *Listener) listenLoop() { + for { + conn, err := l.Listener.Accept() + if err != nil { + l.errChan <- err + close(l.errChan) + return + } + go l.mux(conn) + } +} + +func (l *Listener) mux(conn net.Conn) { + smuxConfig := smux.DefaultConfig() + smuxConfig.KeepAliveDisabled = l.md.muxKeepAliveDisabled + if l.md.muxKeepAlivePeriod > 0 { + smuxConfig.KeepAliveInterval = l.md.muxKeepAlivePeriod + } + if l.md.muxKeepAliveTimeout > 0 { + smuxConfig.KeepAliveTimeout = l.md.muxKeepAliveTimeout + } + if l.md.muxMaxFrameSize > 0 { + smuxConfig.MaxFrameSize = l.md.muxMaxFrameSize + } + if l.md.muxMaxReceiveBuffer > 0 { + smuxConfig.MaxReceiveBuffer = l.md.muxMaxReceiveBuffer + } + if l.md.muxMaxStreamBuffer > 0 { + smuxConfig.MaxStreamBuffer = l.md.muxMaxStreamBuffer + } + session, err := smux.Server(conn, smuxConfig) + if err != nil { + l.logger.Error(err) + return + } + defer session.Close() + + for { + stream, err := session.AcceptStream() + if err != nil { + l.logger.Error("accept stream:", err) + return + } + + select { + case l.connChan <- stream: + case <-stream.GetDieCh(): + default: + stream.Close() + l.logger.Error("connection queue is full") + } + } +} + +func (l *Listener) Accept() (conn net.Conn, err error) { + var ok bool + select { + case conn = <-l.connChan: + case err, ok = <-l.errChan: + if !ok { + err = errors.New("accpet on closed listener") + } + } + return +} + +func (l *Listener) parseMetadata(md listener.Metadata) (m metadata, err error) { + if val, ok := md[addr]; ok { + m.addr = val + } else { + err = errors.New("missing address") + return + } + + m.tlsConfig, err = utils.LoadTLSConfig(md[certFile], md[keyFile], md[caFile]) + if err != nil { + return + } + + return +} diff --git a/server/listener/tls/mux/metadata.go b/server/listener/tls/mux/metadata.go new file mode 100644 index 0000000..7c65039 --- /dev/null +++ b/server/listener/tls/mux/metadata.go @@ -0,0 +1,38 @@ +package mux + +import ( + "crypto/tls" + "time" +) + +const ( + addr = "addr" + certFile = "certFile" + keyFile = "keyFile" + caFile = "caFile" + + muxKeepAliveDisabled = "muxKeepAliveDisabled" + muxKeepAlivePeriod = "muxKeepAlivePeriod" + muxKeepAliveTimeout = "muxKeepAliveTimeout" + muxMaxFrameSize = "muxMaxFrameSize" + muxMaxReceiveBuffer = "muxMaxReceiveBuffer" + muxMaxStreamBuffer = "muxMaxStreamBuffer" +) + +const ( + defaultQueueSize = 128 +) + +type metadata struct { + addr string + tlsConfig *tls.Config + + muxKeepAliveDisabled bool + muxKeepAlivePeriod time.Duration + muxKeepAliveTimeout time.Duration + muxMaxFrameSize int + muxMaxReceiveBuffer int + muxMaxStreamBuffer int + + connQueueSize int +} diff --git a/server/listener/ws/listener.go b/server/listener/ws/listener.go index e7d5ce8..fcdcfcf 100644 --- a/server/listener/ws/listener.go +++ b/server/listener/ws/listener.go @@ -1,11 +1,10 @@ -package tcp +package ws import ( "crypto/tls" "errors" "net" "net/http" - "time" "github.com/go-gost/gost/logger" "github.com/go-gost/gost/server/listener" @@ -89,12 +88,6 @@ func (l *Listener) Init(md listener.Metadata) (err error) { close(l.errChan) }() - select { - case err = <-l.errChan: - return - case <-time.After(100 * time.Millisecond): - } - return } @@ -138,36 +131,9 @@ func (l *Listener) upgrade(w http.ResponseWriter, r *http.Request) { } select { - case l.connChan <- &websocketConn{Conn: conn}: + case l.connChan <- utils.WebsocketServerConn(conn): default: conn.Close() l.logger.Warn("connection queue is full") } } - -type websocketConn struct { - *websocket.Conn - rb []byte -} - -func (c *websocketConn) Read(b []byte) (n int, err error) { - if len(c.rb) == 0 { - _, c.rb, err = c.ReadMessage() - } - n = copy(b, c.rb) - c.rb = c.rb[n:] - return -} - -func (c *websocketConn) Write(b []byte) (n int, err error) { - err = c.WriteMessage(websocket.BinaryMessage, b) - n = len(b) - return -} - -func (c *websocketConn) SetDeadline(t time.Time) error { - if err := c.SetReadDeadline(t); err != nil { - return err - } - return c.SetWriteDeadline(t) -} diff --git a/server/listener/ws/metadata.go b/server/listener/ws/metadata.go index 257364a..017195a 100644 --- a/server/listener/ws/metadata.go +++ b/server/listener/ws/metadata.go @@ -1,4 +1,4 @@ -package tcp +package ws import ( "crypto/tls" diff --git a/server/listener/ws/mux/listener.go b/server/listener/ws/mux/listener.go new file mode 100644 index 0000000..c18a1cc --- /dev/null +++ b/server/listener/ws/mux/listener.go @@ -0,0 +1,177 @@ +package mux + +import ( + "crypto/tls" + "errors" + "net" + "net/http" + + "github.com/go-gost/gost/logger" + "github.com/go-gost/gost/server/listener" + "github.com/go-gost/gost/utils" + "github.com/gorilla/websocket" + "github.com/xtaci/smux" +) + +var ( + _ listener.Listener = (*Listener)(nil) +) + +type Listener struct { + md metadata + addr net.Addr + upgrader *websocket.Upgrader + srv *http.Server + connChan chan net.Conn + errChan chan error + logger logger.Logger +} + +func NewListener(opts ...listener.Option) *Listener { + options := &listener.Options{} + for _, opt := range opts { + opt(options) + } + return &Listener{ + logger: options.Logger, + } +} + +func (l *Listener) Init(md listener.Metadata) (err error) { + l.md, err = l.parseMetadata(md) + if err != nil { + return + } + + l.upgrader = &websocket.Upgrader{ + HandshakeTimeout: l.md.handshakeTimeout, + ReadBufferSize: l.md.readBufferSize, + WriteBufferSize: l.md.writeBufferSize, + CheckOrigin: func(r *http.Request) bool { return true }, + EnableCompression: l.md.enableCompression, + } + + path := l.md.path + if path == "" { + path = defaultPath + } + mux := http.NewServeMux() + mux.Handle(path, http.HandlerFunc(l.upgrade)) + l.srv = &http.Server{ + Addr: l.md.addr, + TLSConfig: l.md.tlsConfig, + Handler: mux, + ReadHeaderTimeout: l.md.readHeaderTimeout, + } + + queueSize := l.md.connQueueSize + if queueSize <= 0 { + queueSize = defaultQueueSize + } + l.connChan = make(chan net.Conn, queueSize) + l.errChan = make(chan error, 1) + + ln, err := net.Listen("tcp", l.md.addr) + if err != nil { + return + } + if l.md.tlsConfig != nil { + ln = tls.NewListener(ln, l.md.tlsConfig) + } + + l.addr = ln.Addr() + + go func() { + err := l.srv.Serve(ln) + if err != nil { + l.errChan <- err + } + close(l.errChan) + }() + + return +} + +func (l *Listener) Accept() (conn net.Conn, err error) { + select { + case conn = <-l.connChan: + case err = <-l.errChan: + } + return +} + +func (l *Listener) Close() error { + return l.srv.Close() +} + +func (l *Listener) Addr() net.Addr { + return l.addr +} + +func (l *Listener) parseMetadata(md listener.Metadata) (m metadata, err error) { + if val, ok := md[addr]; ok { + m.addr = val + } else { + err = errors.New("missing address") + return + } + + m.tlsConfig, err = utils.LoadTLSConfig(md[certFile], md[keyFile], md[caFile]) + if err != nil { + return + } + + return +} + +func (l *Listener) upgrade(w http.ResponseWriter, r *http.Request) { + conn, err := l.upgrader.Upgrade(w, r, l.md.responseHeader) + if err != nil { + l.logger.Error(err) + return + } + + l.mux(utils.WebsocketServerConn(conn)) +} + +func (l *Listener) mux(conn net.Conn) { + smuxConfig := smux.DefaultConfig() + smuxConfig.KeepAliveDisabled = l.md.muxKeepAliveDisabled + if l.md.muxKeepAlivePeriod > 0 { + smuxConfig.KeepAliveInterval = l.md.muxKeepAlivePeriod + } + if l.md.muxKeepAliveTimeout > 0 { + smuxConfig.KeepAliveTimeout = l.md.muxKeepAliveTimeout + } + if l.md.muxMaxFrameSize > 0 { + smuxConfig.MaxFrameSize = l.md.muxMaxFrameSize + } + if l.md.muxMaxReceiveBuffer > 0 { + smuxConfig.MaxReceiveBuffer = l.md.muxMaxReceiveBuffer + } + if l.md.muxMaxStreamBuffer > 0 { + smuxConfig.MaxStreamBuffer = l.md.muxMaxStreamBuffer + } + session, err := smux.Server(conn, smuxConfig) + if err != nil { + l.logger.Error(err) + return + } + defer session.Close() + + for { + stream, err := session.AcceptStream() + if err != nil { + l.logger.Error("accept stream:", err) + return + } + + select { + case l.connChan <- stream: + case <-stream.GetDieCh(): + default: + stream.Close() + l.logger.Error("connection queue is full") + } + } +} diff --git a/server/listener/ws/mux/metadata.go b/server/listener/ws/mux/metadata.go new file mode 100644 index 0000000..89d89c0 --- /dev/null +++ b/server/listener/ws/mux/metadata.go @@ -0,0 +1,54 @@ +package mux + +import ( + "crypto/tls" + "net/http" + "time" +) + +const ( + addr = "addr" + path = "path" + certFile = "certFile" + keyFile = "keyFile" + caFile = "caFile" + handshakeTimeout = "handshakeTimeout" + readHeaderTimeout = "readHeaderTimeout" + readBufferSize = "readBufferSize" + writeBufferSize = "writeBufferSize" + enableCompression = "enableCompression" + responseHeader = "responseHeader" + connQueueSize = "connQueueSize" + + muxKeepAliveDisabled = "muxKeepAliveDisabled" + muxKeepAlivePeriod = "muxKeepAlivePeriod" + muxKeepAliveTimeout = "muxKeepAliveTimeout" + muxMaxFrameSize = "muxMaxFrameSize" + muxMaxReceiveBuffer = "muxMaxReceiveBuffer" + muxMaxStreamBuffer = "muxMaxStreamBuffer" +) + +const ( + defaultPath = "/ws" + defaultQueueSize = 128 +) + +type metadata struct { + addr string + path string + tlsConfig *tls.Config + handshakeTimeout time.Duration + readHeaderTimeout time.Duration + readBufferSize int + writeBufferSize int + enableCompression bool + responseHeader http.Header + + muxKeepAliveDisabled bool + muxKeepAlivePeriod time.Duration + muxKeepAliveTimeout time.Duration + muxMaxFrameSize int + muxMaxReceiveBuffer int + muxMaxStreamBuffer int + connQueueSize int +} diff --git a/utils/ws.go b/utils/ws.go index d4b585b..5723f5f 100644 --- a/utils/ws.go +++ b/utils/ws.go @@ -1 +1,41 @@ package utils + +import ( + "net" + "time" + + "github.com/gorilla/websocket" +) + +type websocketConn struct { + *websocket.Conn + rb []byte +} + +func WebsocketServerConn(conn *websocket.Conn) net.Conn { + return &websocketConn{ + Conn: conn, + } +} + +func (c *websocketConn) Read(b []byte) (n int, err error) { + if len(c.rb) == 0 { + _, c.rb, err = c.ReadMessage() + } + n = copy(b, c.rb) + c.rb = c.rb[n:] + return +} + +func (c *websocketConn) Write(b []byte) (n int, err error) { + err = c.WriteMessage(websocket.BinaryMessage, b) + n = len(b) + return +} + +func (c *websocketConn) SetDeadline(t time.Time) error { + if err := c.SetReadDeadline(t); err != nil { + return err + } + return c.SetWriteDeadline(t) +}