add mws dialer

This commit is contained in:
ginuerzh 2021-12-17 14:05:59 +08:00
parent bfe5eae172
commit 8e31e532e4
23 changed files with 795 additions and 137 deletions

View File

@ -24,6 +24,8 @@ import (
_ "github.com/go-gost/gost/pkg/dialer/tls"
_ "github.com/go-gost/gost/pkg/dialer/tls/mux"
_ "github.com/go-gost/gost/pkg/dialer/udp"
_ "github.com/go-gost/gost/pkg/dialer/ws"
_ "github.com/go-gost/gost/pkg/dialer/ws/mux"
// Register handlers
_ "github.com/go-gost/gost/pkg/handler/auto"

View File

@ -57,7 +57,11 @@ func (c *httpConnector) Connect(ctx context.Context, conn net.Conn, network, add
Host: address,
ProtoMajor: 1,
ProtoMinor: 1,
Header: make(http.Header),
Header: c.md.header,
}
if req.Header == nil {
req.Header = http.Header{}
}
req.Header.Set("Proxy-Connection", "keep-alive")
@ -68,10 +72,6 @@ func (c *httpConnector) Connect(ctx context.Context, conn net.Conn, network, add
"Basic "+base64.StdEncoding.EncodeToString([]byte(u+":"+p)))
}
for k, v := range c.md.headers {
req.Header.Set(k, v)
}
switch network {
case "tcp", "tcp4", "tcp6":
if _, ok := conn.(net.PacketConn); ok {

View File

@ -2,6 +2,7 @@ package http
import (
"fmt"
"net/http"
"net/url"
"strings"
"time"
@ -12,14 +13,14 @@ import (
type metadata struct {
connectTimeout time.Duration
User *url.Userinfo
headers map[string]string
header http.Header
}
func (c *httpConnector) parseMetadata(md md.Metadata) (err error) {
const (
connectTimeout = "timeout"
user = "user"
headers = "headers"
header = "header"
)
c.md.connectTimeout = md.GetDuration(connectTimeout)
@ -33,12 +34,12 @@ func (c *httpConnector) parseMetadata(md md.Metadata) (err error) {
}
}
if mm, _ := md.Get(headers).(map[interface{}]interface{}); len(mm) > 0 {
m := make(map[string]string)
if mm, _ := md.Get(header).(map[interface{}]interface{}); len(mm) > 0 {
h := http.Header{}
for k, v := range mm {
m[fmt.Sprintf("%v", k)] = fmt.Sprintf("%v", v)
h.Add(fmt.Sprintf("%v", k), fmt.Sprintf("%v", v))
}
c.md.headers = m
c.md.header = h
}
return

View File

@ -23,7 +23,7 @@ type obfsHTTPConn struct {
headerDrained bool
handshaked bool
handshakeMutex sync.Mutex
headers map[string]string
header http.Header
logger logger.Logger
}
@ -50,15 +50,15 @@ func (c *obfsHTTPConn) handshake() (err error) {
ProtoMajor: 1,
ProtoMinor: 1,
URL: &url.URL{Scheme: "http", Host: c.host},
Header: make(http.Header),
Header: c.header,
}
if r.Header == nil {
r.Header = http.Header{}
}
r.Header.Set("Connection", "Upgrade")
r.Header.Set("Upgrade", "websocket")
key, _ := c.generateChallengeKey()
r.Header.Set("Sec-WebSocket-Key", key)
for k, v := range c.headers {
r.Header.Set(k, v)
}
// cache the request header
if err = r.Write(&c.wbuf); err != nil {

View File

@ -56,9 +56,9 @@ func (d *obfsHTTPDialer) Handshake(ctx context.Context, conn net.Conn, options .
}
return &obfsHTTPConn{
Conn: conn,
host: host,
headers: d.md.headers,
logger: d.logger,
Conn: conn,
host: host,
header: d.md.header,
logger: d.logger,
}, nil
}

View File

@ -2,27 +2,28 @@ package http
import (
"fmt"
"net/http"
md "github.com/go-gost/gost/pkg/metadata"
)
type metadata struct {
host string
headers map[string]string
host string
header http.Header
}
func (d *obfsHTTPDialer) parseMetadata(md md.Metadata) (err error) {
const (
headers = "headers"
host = "host"
header = "header"
host = "host"
)
if mm, _ := md.Get(headers).(map[interface{}]interface{}); len(mm) > 0 {
m := make(map[string]string)
if mm, _ := md.Get(header).(map[interface{}]interface{}); len(mm) > 0 {
h := http.Header{}
for k, v := range mm {
m[fmt.Sprintf("%v", k)] = fmt.Sprintf("%v", v)
h.Add(fmt.Sprintf("%v", k), fmt.Sprintf("%v", v))
}
d.md.headers = m
d.md.header = h
}
d.md.host = md.GetString(host)
return

109
pkg/dialer/ws/dialer.go Normal file
View File

@ -0,0 +1,109 @@
package ws
import (
"context"
"net"
"net/url"
"time"
"github.com/go-gost/gost/pkg/dialer"
ws_util "github.com/go-gost/gost/pkg/internal/util/ws"
"github.com/go-gost/gost/pkg/logger"
md "github.com/go-gost/gost/pkg/metadata"
"github.com/go-gost/gost/pkg/registry"
"github.com/gorilla/websocket"
)
func init() {
registry.RegisterDialer("ws", NewDialer)
registry.RegisterDialer("wss", NewTLSDialer)
}
type wsDialer struct {
tlsEnabled bool
logger logger.Logger
md metadata
}
func NewDialer(opts ...dialer.Option) dialer.Dialer {
options := &dialer.Options{}
for _, opt := range opts {
opt(options)
}
return &wsDialer{
logger: options.Logger,
}
}
func NewTLSDialer(opts ...dialer.Option) dialer.Dialer {
options := &dialer.Options{}
for _, opt := range opts {
opt(options)
}
return &wsDialer{
tlsEnabled: true,
logger: options.Logger,
}
}
func (d *wsDialer) Init(md md.Metadata) (err error) {
return d.parseMetadata(md)
}
func (d *wsDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOption) (net.Conn, error) {
var options dialer.DialOptions
for _, opt := range opts {
opt(&options)
}
var netd net.Dialer
conn, err := netd.DialContext(ctx, "tcp", addr)
if err != nil {
d.logger.Error(err)
}
return conn, err
}
// Handshake implements dialer.Handshaker
func (d *wsDialer) Handshake(ctx context.Context, conn net.Conn, options ...dialer.HandshakeOption) (net.Conn, error) {
opts := &dialer.HandshakeOptions{}
for _, option := range options {
option(opts)
}
if d.md.handshakeTimeout > 0 {
conn.SetDeadline(time.Now().Add(d.md.handshakeTimeout))
defer conn.SetDeadline(time.Time{})
}
host := d.md.host
if host == "" {
host = opts.Addr
}
dialer := websocket.Dialer{
HandshakeTimeout: d.md.handshakeTimeout,
ReadBufferSize: d.md.readBufferSize,
WriteBufferSize: d.md.writeBufferSize,
EnableCompression: d.md.enableCompression,
NetDial: func(net, addr string) (net.Conn, error) {
return conn, nil
},
}
url := url.URL{Scheme: "ws", Host: host, Path: d.md.path}
if d.tlsEnabled {
url.Scheme = "wss"
dialer.TLSClientConfig = d.md.tlsConfig
}
c, resp, err := dialer.Dial(url.String(), d.md.header)
if err != nil {
return nil, err
}
resp.Body.Close()
return ws_util.Conn(c), nil
}

86
pkg/dialer/ws/metadata.go Normal file
View File

@ -0,0 +1,86 @@
package ws
import (
"crypto/tls"
"fmt"
"net"
"net/http"
"time"
tls_util "github.com/go-gost/gost/pkg/common/util/tls"
md "github.com/go-gost/gost/pkg/metadata"
)
const (
defaultPath = "/ws"
)
type metadata struct {
path string
host string
tlsConfig *tls.Config
handshakeTimeout time.Duration
readHeaderTimeout time.Duration
readBufferSize int
writeBufferSize int
enableCompression bool
header http.Header
}
func (d *wsDialer) parseMetadata(md md.Metadata) (err error) {
const (
path = "path"
host = "host"
certFile = "certFile"
keyFile = "keyFile"
caFile = "caFile"
secure = "secure"
serverName = "serverName"
handshakeTimeout = "handshakeTimeout"
readHeaderTimeout = "readHeaderTimeout"
readBufferSize = "readBufferSize"
writeBufferSize = "writeBufferSize"
enableCompression = "enableCompression"
header = "header"
)
d.md.path = md.GetString(path)
if d.md.path == "" {
d.md.path = defaultPath
}
d.md.host = md.GetString(host)
sn, _, _ := net.SplitHostPort(md.GetString(serverName))
if sn == "" {
sn = "localhost"
}
d.md.tlsConfig, err = tls_util.LoadClientConfig(
md.GetString(certFile),
md.GetString(keyFile),
md.GetString(caFile),
md.GetBool(secure),
sn,
)
d.md.handshakeTimeout = md.GetDuration(handshakeTimeout)
d.md.readHeaderTimeout = md.GetDuration(readHeaderTimeout)
d.md.readBufferSize = md.GetInt(readBufferSize)
d.md.writeBufferSize = md.GetInt(writeBufferSize)
d.md.enableCompression = md.GetBool(enableCompression)
if mm, _ := md.Get(header).(map[interface{}]interface{}); len(mm) > 0 {
h := http.Header{}
for k, v := range mm {
h.Add(fmt.Sprintf("%v", k), fmt.Sprintf("%v", v))
}
d.md.header = h
}
return
}

38
pkg/dialer/ws/mux/conn.go Normal file
View File

@ -0,0 +1,38 @@
package mux
import (
"net"
"github.com/xtaci/smux"
)
type muxSession struct {
conn net.Conn
session *smux.Session
}
func (session *muxSession) GetConn() (net.Conn, error) {
return session.session.OpenStream()
}
func (session *muxSession) Accept() (net.Conn, error) {
return session.session.AcceptStream()
}
func (session *muxSession) Close() error {
if session.session == nil {
return nil
}
return session.session.Close()
}
func (session *muxSession) IsClosed() bool {
if session.session == nil {
return true
}
return session.session.IsClosed()
}
func (session *muxSession) NumStreams() int {
return session.session.NumStreams()
}

220
pkg/dialer/ws/mux/dialer.go Normal file
View File

@ -0,0 +1,220 @@
package mux
import (
"context"
"errors"
"net"
"net/url"
"sync"
"time"
"github.com/go-gost/gost/pkg/dialer"
ws_util "github.com/go-gost/gost/pkg/internal/util/ws"
"github.com/go-gost/gost/pkg/logger"
md "github.com/go-gost/gost/pkg/metadata"
"github.com/go-gost/gost/pkg/registry"
"github.com/gorilla/websocket"
"github.com/xtaci/smux"
)
func init() {
registry.RegisterDialer("mws", NewDialer)
registry.RegisterDialer("mwss", NewTLSDialer)
}
type mwsDialer struct {
sessions map[string]*muxSession
sessionMutex sync.Mutex
logger logger.Logger
md metadata
tlsEnabled bool
}
func NewDialer(opts ...dialer.Option) dialer.Dialer {
options := &dialer.Options{}
for _, opt := range opts {
opt(options)
}
return &mwsDialer{
sessions: make(map[string]*muxSession),
logger: options.Logger,
}
}
func NewTLSDialer(opts ...dialer.Option) dialer.Dialer {
options := &dialer.Options{}
for _, opt := range opts {
opt(options)
}
return &mwsDialer{
sessions: make(map[string]*muxSession),
logger: options.Logger,
tlsEnabled: true,
}
}
func (d *mwsDialer) Init(md md.Metadata) (err error) {
if err = d.parseMetadata(md); err != nil {
return
}
return nil
}
// Multiplex implements dialer.Multiplexer interface.
func (d *mwsDialer) Multiplex() bool {
return true
}
func (d *mwsDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOption) (conn net.Conn, err error) {
var options dialer.DialOptions
for _, opt := range opts {
opt(&options)
}
d.sessionMutex.Lock()
defer d.sessionMutex.Unlock()
session, ok := d.sessions[addr]
if session != nil && session.IsClosed() {
delete(d.sessions, addr) // session is dead
ok = false
}
if !ok {
conn, err = d.dial(ctx, "tcp", addr, &options)
if err != nil {
return
}
session = &muxSession{conn: conn}
d.sessions[addr] = session
}
return session.conn, err
}
// Handshake implements dialer.Handshaker
func (d *mwsDialer) Handshake(ctx context.Context, conn net.Conn, options ...dialer.HandshakeOption) (net.Conn, error) {
opts := &dialer.HandshakeOptions{}
for _, option := range options {
option(opts)
}
d.sessionMutex.Lock()
defer d.sessionMutex.Unlock()
if d.md.handshakeTimeout > 0 {
conn.SetDeadline(time.Now().Add(d.md.handshakeTimeout))
defer conn.SetDeadline(time.Time{})
}
session, ok := d.sessions[opts.Addr]
if session != nil && session.conn != conn {
conn.Close()
return nil, errors.New("mtls: unrecognized connection")
}
if !ok || session.session == nil {
host := d.md.host
if host == "" {
host = opts.Addr
}
s, err := d.initSession(ctx, host, conn)
if err != nil {
d.logger.Error(err)
conn.Close()
delete(d.sessions, opts.Addr)
return nil, err
}
session = s
d.sessions[opts.Addr] = session
}
cc, err := session.GetConn()
if err != nil {
session.Close()
delete(d.sessions, opts.Addr)
return nil, err
}
return cc, nil
}
func (d *mwsDialer) dial(ctx context.Context, network, addr string, opts *dialer.DialOptions) (net.Conn, error) {
dial := opts.DialFunc
if dial != nil {
conn, err := dial(ctx, addr)
if err != nil {
d.logger.Error(err)
} else {
d.logger.WithFields(map[string]interface{}{
"src": conn.LocalAddr().String(),
"dst": addr,
}).Debug("dial with dial func")
}
return conn, err
}
var netd net.Dialer
conn, err := netd.DialContext(ctx, network, addr)
if err != nil {
d.logger.Error(err)
} else {
d.logger.WithFields(map[string]interface{}{
"src": conn.LocalAddr().String(),
"dst": addr,
}).Debugf("dial direct %s/%s", addr, network)
}
return conn, err
}
func (d *mwsDialer) initSession(ctx context.Context, host string, conn net.Conn) (*muxSession, error) {
dialer := websocket.Dialer{
HandshakeTimeout: d.md.handshakeTimeout,
ReadBufferSize: d.md.readBufferSize,
WriteBufferSize: d.md.writeBufferSize,
EnableCompression: d.md.enableCompression,
NetDial: func(net, addr string) (net.Conn, error) {
return conn, nil
},
}
url := url.URL{Scheme: "ws", Host: host, Path: d.md.path}
if d.tlsEnabled {
url.Scheme = "wss"
dialer.TLSClientConfig = d.md.tlsConfig
}
c, resp, err := dialer.Dial(url.String(), d.md.header)
if err != nil {
return nil, err
}
resp.Body.Close()
conn = ws_util.Conn(c)
// stream multiplex
smuxConfig := smux.DefaultConfig()
smuxConfig.KeepAliveDisabled = d.md.muxKeepAliveDisabled
if d.md.muxKeepAliveInterval > 0 {
smuxConfig.KeepAliveInterval = d.md.muxKeepAliveInterval
}
if d.md.muxKeepAliveTimeout > 0 {
smuxConfig.KeepAliveTimeout = d.md.muxKeepAliveTimeout
}
if d.md.muxMaxFrameSize > 0 {
smuxConfig.MaxFrameSize = d.md.muxMaxFrameSize
}
if d.md.muxMaxReceiveBuffer > 0 {
smuxConfig.MaxReceiveBuffer = d.md.muxMaxReceiveBuffer
}
if d.md.muxMaxStreamBuffer > 0 {
smuxConfig.MaxStreamBuffer = d.md.muxMaxStreamBuffer
}
session, err := smux.Client(conn, smuxConfig)
if err != nil {
return nil, err
}
return &muxSession{conn: conn, session: session}, nil
}

View File

@ -0,0 +1,106 @@
package mux
import (
"crypto/tls"
"fmt"
"net"
"net/http"
"time"
tls_util "github.com/go-gost/gost/pkg/common/util/tls"
md "github.com/go-gost/gost/pkg/metadata"
)
const (
defaultPath = "/ws"
)
type metadata struct {
path string
host string
tlsConfig *tls.Config
handshakeTimeout time.Duration
readHeaderTimeout time.Duration
readBufferSize int
writeBufferSize int
enableCompression bool
muxKeepAliveDisabled bool
muxKeepAliveInterval time.Duration
muxKeepAliveTimeout time.Duration
muxMaxFrameSize int
muxMaxReceiveBuffer int
muxMaxStreamBuffer int
header http.Header
}
func (d *mwsDialer) parseMetadata(md md.Metadata) (err error) {
const (
path = "path"
host = "host"
certFile = "certFile"
keyFile = "keyFile"
caFile = "caFile"
secure = "secure"
serverName = "serverName"
handshakeTimeout = "handshakeTimeout"
readHeaderTimeout = "readHeaderTimeout"
readBufferSize = "readBufferSize"
writeBufferSize = "writeBufferSize"
enableCompression = "enableCompression"
header = "header"
muxKeepAliveDisabled = "muxKeepAliveDisabled"
muxKeepAliveInterval = "muxKeepAliveInterval"
muxKeepAliveTimeout = "muxKeepAliveTimeout"
muxMaxFrameSize = "muxMaxFrameSize"
muxMaxReceiveBuffer = "muxMaxReceiveBuffer"
muxMaxStreamBuffer = "muxMaxStreamBuffer"
)
d.md.path = md.GetString(path)
if d.md.path == "" {
d.md.path = defaultPath
}
d.md.host = md.GetString(host)
sn, _, _ := net.SplitHostPort(md.GetString(serverName))
if sn == "" {
sn = "localhost"
}
d.md.tlsConfig, err = tls_util.LoadClientConfig(
md.GetString(certFile),
md.GetString(keyFile),
md.GetString(caFile),
md.GetBool(secure),
sn,
)
d.md.muxKeepAliveDisabled = md.GetBool(muxKeepAliveDisabled)
d.md.muxKeepAliveInterval = md.GetDuration(muxKeepAliveInterval)
d.md.muxKeepAliveTimeout = md.GetDuration(muxKeepAliveTimeout)
d.md.muxMaxFrameSize = md.GetInt(muxMaxFrameSize)
d.md.muxMaxReceiveBuffer = md.GetInt(muxMaxReceiveBuffer)
d.md.muxMaxStreamBuffer = md.GetInt(muxMaxStreamBuffer)
d.md.handshakeTimeout = md.GetDuration(handshakeTimeout)
d.md.readHeaderTimeout = md.GetDuration(readHeaderTimeout)
d.md.readBufferSize = md.GetInt(readBufferSize)
d.md.writeBufferSize = md.GetInt(writeBufferSize)
d.md.enableCompression = md.GetBool(enableCompression)
if mm, _ := md.Get(header).(map[interface{}]interface{}); len(mm) > 0 {
h := http.Header{}
for k, v := range mm {
h.Add(fmt.Sprintf("%v", k), fmt.Sprintf("%v", v))
}
d.md.header = h
}
return
}

View File

@ -133,11 +133,10 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt
resp := &http.Response{
ProtoMajor: 1,
ProtoMinor: 1,
Header: http.Header{},
Header: h.md.header,
}
for k, v := range h.md.headers {
resp.Header.Set(k, v)
if resp.Header == nil {
resp.Header = http.Header{}
}
/*

View File

@ -2,6 +2,7 @@ package http
import (
"fmt"
"net/http"
"strings"
"github.com/go-gost/gost/pkg/auth"
@ -14,12 +15,12 @@ type metadata struct {
probeResist *probeResist
sni bool
enableUDP bool
headers map[string]string
header http.Header
}
func (h *httpHandler) parseMetadata(md md.Metadata) error {
const (
headers = "headers"
header = "header"
users = "users"
probeResistKey = "probeResist"
knock = "knock"
@ -43,12 +44,12 @@ func (h *httpHandler) parseMetadata(md md.Metadata) error {
h.md.authenticator = authenticator
}
if mm, _ := md.Get(headers).(map[interface{}]interface{}); len(mm) > 0 {
m := make(map[string]string)
if mm, _ := md.Get(header).(map[interface{}]interface{}); len(mm) > 0 {
hd := http.Header{}
for k, v := range mm {
m[fmt.Sprintf("%v", k)] = fmt.Sprintf("%v", v)
hd.Add(fmt.Sprintf("%v", k), fmt.Sprintf("%v", v))
}
h.md.headers = m
h.md.header = hd
}
if v := md.GetString(probeResistKey); v != "" {

View File

@ -21,10 +21,10 @@ func (h *httpHandler) handleUDP(ctx context.Context, conn net.Conn, network, add
resp := &http.Response{
ProtoMajor: 1,
ProtoMinor: 1,
Header: http.Header{},
Header: h.md.header,
}
for k, v := range h.md.headers {
resp.Header.Set(k, v)
if resp.Header == nil {
resp.Header = http.Header{}
}
if !h.md.enableUDP {

View File

@ -12,7 +12,7 @@ type websocketConn struct {
rb []byte
}
func WebsocketServerConn(conn *websocket.Conn) net.Conn {
func Conn(conn *websocket.Conn) net.Conn {
return &websocketConn{
Conn: conn,
}

View File

@ -22,6 +22,7 @@ type obfsHTTPConn struct {
wbuf bytes.Buffer
handshaked bool
handshakeMutex sync.Mutex
header http.Header
logger logger.Logger
}
@ -71,9 +72,11 @@ func (c *obfsHTTPConn) handshake() (err error) {
StatusCode: http.StatusOK,
ProtoMajor: 1,
ProtoMinor: 1,
Header: make(http.Header),
Header: c.header,
}
if resp.Header == nil {
resp.Header = http.Header{}
}
resp.Header.Set("Server", "nginx/1.18.0")
resp.Header.Set("Date", time.Now().Format(time.RFC1123))
if r.Method != http.MethodGet || r.Header.Get("Upgrade") != "websocket" {

View File

@ -57,6 +57,7 @@ func (l *obfsListener) Accept() (net.Conn, error) {
return &obfsHTTPConn{
Conn: c,
header: l.md.header,
logger: l.logger,
}, nil
}

View File

@ -1,17 +1,27 @@
package http
import (
"fmt"
"net/http"
md "github.com/go-gost/gost/pkg/metadata"
)
const (
keepAlive = "keepAlive"
keepAlivePeriod = "keepAlivePeriod"
)
type metadata struct {
header http.Header
}
func (l *obfsListener) parseMetadata(md md.Metadata) (err error) {
const (
header = "header"
)
if mm, _ := md.Get(header).(map[interface{}]interface{}); len(mm) > 0 {
h := http.Header{}
for k, v := range mm {
h.Add(fmt.Sprintf("%v", k), fmt.Sprintf("%v", v))
}
l.md.header = h
}
return
}

View File

@ -54,6 +54,18 @@ func (l *mtlsListener) Init(md md.Metadata) (err error) {
return
}
func (l *mtlsListener) Accept() (conn net.Conn, err error) {
var ok bool
select {
case conn = <-l.cqueue:
case err, ok = <-l.errChan:
if !ok {
err = listener.ErrClosed
}
}
return
}
func (l *mtlsListener) listenLoop() {
for {
conn, err := l.Listener.Accept()
@ -67,6 +79,8 @@ func (l *mtlsListener) listenLoop() {
}
func (l *mtlsListener) mux(conn net.Conn) {
defer conn.Close()
smuxConfig := smux.DefaultConfig()
smuxConfig.KeepAliveDisabled = l.md.muxKeepAliveDisabled
if l.md.muxKeepAliveInterval > 0 {
@ -94,7 +108,7 @@ func (l *mtlsListener) mux(conn net.Conn) {
for {
stream, err := session.AcceptStream()
if err != nil {
l.logger.Error("accept stream:", err)
l.logger.Error("accept stream: ", err)
return
}
@ -108,15 +122,3 @@ func (l *mtlsListener) mux(conn net.Conn) {
}
}
}
func (l *mtlsListener) Accept() (conn net.Conn, err error) {
var ok bool
select {
case conn = <-l.cqueue:
case err, ok = <-l.errChan:
if !ok {
err = listener.ErrClosed
}
}
return
}

View File

@ -4,8 +4,9 @@ import (
"crypto/tls"
"net"
"net/http"
"net/http/httputil"
ws_util "github.com/go-gost/gost/pkg/common/util/ws"
ws_util "github.com/go-gost/gost/pkg/internal/util/ws"
"github.com/go-gost/gost/pkg/listener"
"github.com/go-gost/gost/pkg/logger"
md "github.com/go-gost/gost/pkg/metadata"
@ -20,14 +21,14 @@ func init() {
type wsListener struct {
saddr string
md metadata
addr net.Addr
upgrader *websocket.Upgrader
srv *http.Server
tlsEnabled bool
connChan chan net.Conn
cqueue chan net.Conn
errChan chan error
logger logger.Logger
md metadata
}
func NewListener(opts ...listener.Option) listener.Listener {
@ -48,8 +49,8 @@ func NewTLSListener(opts ...listener.Option) listener.Listener {
}
return &wsListener{
saddr: options.Addr,
tlsEnabled: true,
logger: options.Logger,
tlsEnabled: true,
}
}
@ -62,35 +63,26 @@ func (l *wsListener) Init(md md.Metadata) (err error) {
HandshakeTimeout: l.md.handshakeTimeout,
ReadBufferSize: l.md.readBufferSize,
WriteBufferSize: l.md.writeBufferSize,
CheckOrigin: func(r *http.Request) bool { return true },
EnableCompression: l.md.enableCompression,
CheckOrigin: func(r *http.Request) bool { return true },
}
path := l.md.path
if path == "" {
path = defaultPath
}
mux := http.NewServeMux()
mux.Handle(path, http.HandlerFunc(l.upgrade))
mux.Handle(l.md.path, http.HandlerFunc(l.upgrade))
l.srv = &http.Server{
Addr: l.saddr,
TLSConfig: l.md.tlsConfig,
Handler: mux,
ReadHeaderTimeout: l.md.readHeaderTimeout,
}
queueSize := l.md.connQueueSize
if queueSize <= 0 {
queueSize = defaultQueueSize
}
l.connChan = make(chan net.Conn, queueSize)
l.cqueue = make(chan net.Conn, l.md.backlog)
l.errChan = make(chan error, 1)
ln, err := net.Listen("tcp", l.saddr)
if err != nil {
return
}
if l.md.tlsConfig != nil {
if l.tlsEnabled {
ln = tls.NewListener(ln, l.md.tlsConfig)
}
@ -110,7 +102,7 @@ func (l *wsListener) Init(md md.Metadata) (err error) {
func (l *wsListener) Accept() (conn net.Conn, err error) {
var ok bool
select {
case conn = <-l.connChan:
case conn = <-l.cqueue:
case err, ok = <-l.errChan:
if !ok {
err = listener.ErrClosed
@ -128,16 +120,25 @@ func (l *wsListener) Addr() net.Addr {
}
func (l *wsListener) upgrade(w http.ResponseWriter, r *http.Request) {
conn, err := l.upgrader.Upgrade(w, r, l.md.responseHeader)
if l.logger.IsLevelEnabled(logger.DebugLevel) {
log := l.logger.WithFields(map[string]interface{}{
"local": l.addr.String(),
"remote": r.RemoteAddr,
})
dump, _ := httputil.DumpRequest(r, false)
log.Debug(string(dump))
}
conn, err := l.upgrader.Upgrade(w, r, l.md.header)
if err != nil {
l.logger.Error(err)
return
}
select {
case l.connChan <- ws_util.WebsocketServerConn(conn):
case l.cqueue <- ws_util.Conn(conn):
default:
conn.Close()
l.logger.Warn("connection queue is full")
l.logger.Warnf("connection queue is full, client %s discarded", conn.RemoteAddr())
}
}

View File

@ -2,6 +2,7 @@ package ws
import (
"crypto/tls"
"fmt"
"net/http"
"time"
@ -10,35 +11,40 @@ import (
)
const (
defaultPath = "/ws"
defaultQueueSize = 128
defaultPath = "/ws"
defaultBacklog = 128
)
type metadata struct {
path string
tlsConfig *tls.Config
path string
backlog int
tlsConfig *tls.Config
handshakeTimeout time.Duration
readHeaderTimeout time.Duration
readBufferSize int
writeBufferSize int
enableCompression bool
responseHeader http.Header
connQueueSize int
header http.Header
}
func (l *wsListener) parseMetadata(md md.Metadata) (err error) {
const (
path = "path"
certFile = "certFile"
keyFile = "keyFile"
caFile = "caFile"
certFile = "certFile"
keyFile = "keyFile"
caFile = "caFile"
path = "path"
backlog = "backlog"
handshakeTimeout = "handshakeTimeout"
readHeaderTimeout = "readHeaderTimeout"
readBufferSize = "readBufferSize"
writeBufferSize = "writeBufferSize"
enableCompression = "enableCompression"
responseHeader = "responseHeader"
connQueueSize = "connQueueSize"
header = "header"
)
l.md.tlsConfig, err = tls_util.LoadServerConfig(
@ -51,15 +57,28 @@ func (l *wsListener) parseMetadata(md md.Metadata) (err error) {
}
l.md.path = md.GetString(path)
l.md.connQueueSize = md.GetInt(connQueueSize)
if l.md.connQueueSize <= 0 {
l.md.connQueueSize = defaultQueueSize
if l.md.path == "" {
l.md.path = defaultPath
}
l.md.enableCompression = md.GetBool(enableCompression)
l.md.readBufferSize = md.GetInt(readBufferSize)
l.md.writeBufferSize = md.GetInt(writeBufferSize)
l.md.backlog = md.GetInt(backlog)
if l.md.backlog <= 0 {
l.md.backlog = defaultBacklog
}
l.md.handshakeTimeout = md.GetDuration(handshakeTimeout)
l.md.readHeaderTimeout = md.GetDuration(readHeaderTimeout)
l.md.readBufferSize = md.GetInt(readBufferSize)
l.md.writeBufferSize = md.GetInt(writeBufferSize)
l.md.enableCompression = md.GetBool(enableCompression)
if mm, _ := md.Get(header).(map[interface{}]interface{}); len(mm) > 0 {
h := http.Header{}
for k, v := range mm {
h.Add(fmt.Sprintf("%v", k), fmt.Sprintf("%v", v))
}
l.md.header = h
}
return
}

View File

@ -4,8 +4,9 @@ import (
"crypto/tls"
"net"
"net/http"
"net/http/httputil"
ws_util "github.com/go-gost/gost/pkg/common/util/ws"
ws_util "github.com/go-gost/gost/pkg/internal/util/ws"
"github.com/go-gost/gost/pkg/listener"
"github.com/go-gost/gost/pkg/logger"
md "github.com/go-gost/gost/pkg/metadata"
@ -16,18 +17,19 @@ import (
func init() {
registry.RegisterListener("mws", NewListener)
registry.RegisterListener("mwss", NewListener)
registry.RegisterListener("mwss", NewTLSListener)
}
type mwsListener struct {
saddr string
md metadata
addr net.Addr
upgrader *websocket.Upgrader
srv *http.Server
connChan chan net.Conn
errChan chan error
logger logger.Logger
saddr string
addr net.Addr
upgrader *websocket.Upgrader
srv *http.Server
cqueue chan net.Conn
errChan chan error
logger logger.Logger
md metadata
tlsEnabled bool
}
func NewListener(opts ...listener.Option) listener.Listener {
@ -36,10 +38,23 @@ func NewListener(opts ...listener.Option) listener.Listener {
opt(options)
}
return &mwsListener{
saddr: options.Addr,
logger: options.Logger,
}
}
func NewTLSListener(opts ...listener.Option) listener.Listener {
options := &listener.Options{}
for _, opt := range opts {
opt(options)
}
return &mwsListener{
saddr: options.Addr,
logger: options.Logger,
tlsEnabled: true,
}
}
func (l *mwsListener) Init(md md.Metadata) (err error) {
if err = l.parseMetadata(md); err != nil {
return
@ -49,8 +64,8 @@ func (l *mwsListener) Init(md md.Metadata) (err error) {
HandshakeTimeout: l.md.handshakeTimeout,
ReadBufferSize: l.md.readBufferSize,
WriteBufferSize: l.md.writeBufferSize,
CheckOrigin: func(r *http.Request) bool { return true },
EnableCompression: l.md.enableCompression,
CheckOrigin: func(r *http.Request) bool { return true },
}
path := l.md.path
@ -61,19 +76,18 @@ func (l *mwsListener) Init(md md.Metadata) (err error) {
mux.Handle(path, http.HandlerFunc(l.upgrade))
l.srv = &http.Server{
Addr: l.saddr,
TLSConfig: l.md.tlsConfig,
Handler: mux,
ReadHeaderTimeout: l.md.readHeaderTimeout,
}
l.connChan = make(chan net.Conn, l.md.connQueueSize)
l.cqueue = make(chan net.Conn, l.md.backlog)
l.errChan = make(chan error, 1)
ln, err := net.Listen("tcp", l.saddr)
if err != nil {
return
}
if l.md.tlsConfig != nil {
if l.tlsEnabled {
ln = tls.NewListener(ln, l.md.tlsConfig)
}
@ -93,7 +107,7 @@ func (l *mwsListener) Init(md md.Metadata) (err error) {
func (l *mwsListener) Accept() (conn net.Conn, err error) {
var ok bool
select {
case conn = <-l.connChan:
case conn = <-l.cqueue:
case err, ok = <-l.errChan:
if !ok {
err = listener.ErrClosed
@ -111,20 +125,31 @@ func (l *mwsListener) Addr() net.Addr {
}
func (l *mwsListener) upgrade(w http.ResponseWriter, r *http.Request) {
conn, err := l.upgrader.Upgrade(w, r, l.md.responseHeader)
if l.logger.IsLevelEnabled(logger.DebugLevel) {
log := l.logger.WithFields(map[string]interface{}{
"local": l.addr.String(),
"remote": r.RemoteAddr,
})
dump, _ := httputil.DumpRequest(r, false)
log.Debug(string(dump))
}
conn, err := l.upgrader.Upgrade(w, r, l.md.header)
if err != nil {
l.logger.Error(err)
return
}
l.mux(ws_util.WebsocketServerConn(conn))
l.mux(ws_util.Conn(conn))
}
func (l *mwsListener) mux(conn net.Conn) {
defer conn.Close()
smuxConfig := smux.DefaultConfig()
smuxConfig.KeepAliveDisabled = l.md.muxKeepAliveDisabled
if l.md.muxKeepAlivePeriod > 0 {
smuxConfig.KeepAliveInterval = l.md.muxKeepAlivePeriod
if l.md.muxKeepAliveInterval > 0 {
smuxConfig.KeepAliveInterval = l.md.muxKeepAliveInterval
}
if l.md.muxKeepAliveTimeout > 0 {
smuxConfig.KeepAliveTimeout = l.md.muxKeepAliveTimeout
@ -148,17 +173,17 @@ func (l *mwsListener) mux(conn net.Conn) {
for {
stream, err := session.AcceptStream()
if err != nil {
l.logger.Error("accept stream:", err)
l.logger.Error("accept stream: ", err)
return
}
select {
case l.connChan <- stream:
case l.cqueue <- stream:
case <-stream.GetDieCh():
stream.Close()
default:
stream.Close()
l.logger.Error("connection queue is full")
l.logger.Warnf("connection queue is full, client %s discarded", stream.RemoteAddr())
}
}
}

View File

@ -2,6 +2,7 @@ package mux
import (
"crypto/tls"
"fmt"
"net/http"
"time"
@ -10,45 +11,48 @@ import (
)
const (
defaultPath = "/ws"
defaultQueueSize = 128
defaultPath = "/ws"
defaultBacklog = 128
)
type metadata struct {
path string
tlsConfig *tls.Config
path string
backlog int
tlsConfig *tls.Config
header http.Header
handshakeTimeout time.Duration
readHeaderTimeout time.Duration
readBufferSize int
writeBufferSize int
enableCompression bool
responseHeader http.Header
muxKeepAliveDisabled bool
muxKeepAlivePeriod time.Duration
muxKeepAliveInterval time.Duration
muxKeepAliveTimeout time.Duration
muxMaxFrameSize int
muxMaxReceiveBuffer int
muxMaxStreamBuffer int
connQueueSize int
}
func (l *mwsListener) parseMetadata(md md.Metadata) (err error) {
const (
path = "path"
certFile = "certFile"
keyFile = "keyFile"
caFile = "caFile"
path = "path"
backlog = "backlog"
header = "header"
certFile = "certFile"
keyFile = "keyFile"
caFile = "caFile"
handshakeTimeout = "handshakeTimeout"
readHeaderTimeout = "readHeaderTimeout"
readBufferSize = "readBufferSize"
writeBufferSize = "writeBufferSize"
enableCompression = "enableCompression"
responseHeader = "responseHeader"
connQueueSize = "connQueueSize"
muxKeepAliveDisabled = "muxKeepAliveDisabled"
muxKeepAlivePeriod = "muxKeepAlivePeriod"
muxKeepAliveInterval = "muxKeepAliveInterval"
muxKeepAliveTimeout = "muxKeepAliveTimeout"
muxMaxFrameSize = "muxMaxFrameSize"
muxMaxReceiveBuffer = "muxMaxReceiveBuffer"
@ -64,5 +68,35 @@ func (l *mwsListener) parseMetadata(md md.Metadata) (err error) {
return
}
l.md.path = md.GetString(path)
if l.md.path == "" {
l.md.path = defaultPath
}
l.md.backlog = md.GetInt(backlog)
if l.md.backlog <= 0 {
l.md.backlog = defaultBacklog
}
l.md.handshakeTimeout = md.GetDuration(handshakeTimeout)
l.md.readHeaderTimeout = md.GetDuration(readHeaderTimeout)
l.md.readBufferSize = md.GetInt(readBufferSize)
l.md.writeBufferSize = md.GetInt(writeBufferSize)
l.md.enableCompression = md.GetBool(enableCompression)
l.md.muxKeepAliveDisabled = md.GetBool(muxKeepAliveDisabled)
l.md.muxKeepAliveInterval = md.GetDuration(muxKeepAliveInterval)
l.md.muxKeepAliveTimeout = md.GetDuration(muxKeepAliveTimeout)
l.md.muxMaxFrameSize = md.GetInt(muxMaxFrameSize)
l.md.muxMaxReceiveBuffer = md.GetInt(muxMaxReceiveBuffer)
l.md.muxMaxStreamBuffer = md.GetInt(muxMaxStreamBuffer)
if mm, _ := md.Get(header).(map[interface{}]interface{}); len(mm) > 0 {
h := http.Header{}
for k, v := range mm {
h.Add(fmt.Sprintf("%v", k), fmt.Sprintf("%v", v))
}
l.md.header = h
}
return
}