add mws dialer

This commit is contained in:
ginuerzh 2021-12-17 14:05:59 +08:00
parent bfe5eae172
commit 8e31e532e4
23 changed files with 795 additions and 137 deletions

View File

@ -24,6 +24,8 @@ import (
_ "github.com/go-gost/gost/pkg/dialer/tls" _ "github.com/go-gost/gost/pkg/dialer/tls"
_ "github.com/go-gost/gost/pkg/dialer/tls/mux" _ "github.com/go-gost/gost/pkg/dialer/tls/mux"
_ "github.com/go-gost/gost/pkg/dialer/udp" _ "github.com/go-gost/gost/pkg/dialer/udp"
_ "github.com/go-gost/gost/pkg/dialer/ws"
_ "github.com/go-gost/gost/pkg/dialer/ws/mux"
// Register handlers // Register handlers
_ "github.com/go-gost/gost/pkg/handler/auto" _ "github.com/go-gost/gost/pkg/handler/auto"

View File

@ -57,7 +57,11 @@ func (c *httpConnector) Connect(ctx context.Context, conn net.Conn, network, add
Host: address, Host: address,
ProtoMajor: 1, ProtoMajor: 1,
ProtoMinor: 1, ProtoMinor: 1,
Header: make(http.Header), Header: c.md.header,
}
if req.Header == nil {
req.Header = http.Header{}
} }
req.Header.Set("Proxy-Connection", "keep-alive") req.Header.Set("Proxy-Connection", "keep-alive")
@ -68,10 +72,6 @@ func (c *httpConnector) Connect(ctx context.Context, conn net.Conn, network, add
"Basic "+base64.StdEncoding.EncodeToString([]byte(u+":"+p))) "Basic "+base64.StdEncoding.EncodeToString([]byte(u+":"+p)))
} }
for k, v := range c.md.headers {
req.Header.Set(k, v)
}
switch network { switch network {
case "tcp", "tcp4", "tcp6": case "tcp", "tcp4", "tcp6":
if _, ok := conn.(net.PacketConn); ok { if _, ok := conn.(net.PacketConn); ok {

View File

@ -2,6 +2,7 @@ package http
import ( import (
"fmt" "fmt"
"net/http"
"net/url" "net/url"
"strings" "strings"
"time" "time"
@ -12,14 +13,14 @@ import (
type metadata struct { type metadata struct {
connectTimeout time.Duration connectTimeout time.Duration
User *url.Userinfo User *url.Userinfo
headers map[string]string header http.Header
} }
func (c *httpConnector) parseMetadata(md md.Metadata) (err error) { func (c *httpConnector) parseMetadata(md md.Metadata) (err error) {
const ( const (
connectTimeout = "timeout" connectTimeout = "timeout"
user = "user" user = "user"
headers = "headers" header = "header"
) )
c.md.connectTimeout = md.GetDuration(connectTimeout) c.md.connectTimeout = md.GetDuration(connectTimeout)
@ -33,12 +34,12 @@ func (c *httpConnector) parseMetadata(md md.Metadata) (err error) {
} }
} }
if mm, _ := md.Get(headers).(map[interface{}]interface{}); len(mm) > 0 { if mm, _ := md.Get(header).(map[interface{}]interface{}); len(mm) > 0 {
m := make(map[string]string) h := http.Header{}
for k, v := range mm { for k, v := range mm {
m[fmt.Sprintf("%v", k)] = fmt.Sprintf("%v", v) h.Add(fmt.Sprintf("%v", k), fmt.Sprintf("%v", v))
} }
c.md.headers = m c.md.header = h
} }
return return

View File

@ -23,7 +23,7 @@ type obfsHTTPConn struct {
headerDrained bool headerDrained bool
handshaked bool handshaked bool
handshakeMutex sync.Mutex handshakeMutex sync.Mutex
headers map[string]string header http.Header
logger logger.Logger logger logger.Logger
} }
@ -50,15 +50,15 @@ func (c *obfsHTTPConn) handshake() (err error) {
ProtoMajor: 1, ProtoMajor: 1,
ProtoMinor: 1, ProtoMinor: 1,
URL: &url.URL{Scheme: "http", Host: c.host}, URL: &url.URL{Scheme: "http", Host: c.host},
Header: make(http.Header), Header: c.header,
}
if r.Header == nil {
r.Header = http.Header{}
} }
r.Header.Set("Connection", "Upgrade") r.Header.Set("Connection", "Upgrade")
r.Header.Set("Upgrade", "websocket") r.Header.Set("Upgrade", "websocket")
key, _ := c.generateChallengeKey() key, _ := c.generateChallengeKey()
r.Header.Set("Sec-WebSocket-Key", key) r.Header.Set("Sec-WebSocket-Key", key)
for k, v := range c.headers {
r.Header.Set(k, v)
}
// cache the request header // cache the request header
if err = r.Write(&c.wbuf); err != nil { if err = r.Write(&c.wbuf); err != nil {

View File

@ -58,7 +58,7 @@ func (d *obfsHTTPDialer) Handshake(ctx context.Context, conn net.Conn, options .
return &obfsHTTPConn{ return &obfsHTTPConn{
Conn: conn, Conn: conn,
host: host, host: host,
headers: d.md.headers, header: d.md.header,
logger: d.logger, logger: d.logger,
}, nil }, nil
} }

View File

@ -2,27 +2,28 @@ package http
import ( import (
"fmt" "fmt"
"net/http"
md "github.com/go-gost/gost/pkg/metadata" md "github.com/go-gost/gost/pkg/metadata"
) )
type metadata struct { type metadata struct {
host string host string
headers map[string]string header http.Header
} }
func (d *obfsHTTPDialer) parseMetadata(md md.Metadata) (err error) { func (d *obfsHTTPDialer) parseMetadata(md md.Metadata) (err error) {
const ( const (
headers = "headers" header = "header"
host = "host" host = "host"
) )
if mm, _ := md.Get(headers).(map[interface{}]interface{}); len(mm) > 0 { if mm, _ := md.Get(header).(map[interface{}]interface{}); len(mm) > 0 {
m := make(map[string]string) h := http.Header{}
for k, v := range mm { for k, v := range mm {
m[fmt.Sprintf("%v", k)] = fmt.Sprintf("%v", v) h.Add(fmt.Sprintf("%v", k), fmt.Sprintf("%v", v))
} }
d.md.headers = m d.md.header = h
} }
d.md.host = md.GetString(host) d.md.host = md.GetString(host)
return return

109
pkg/dialer/ws/dialer.go Normal file
View File

@ -0,0 +1,109 @@
package ws
import (
"context"
"net"
"net/url"
"time"
"github.com/go-gost/gost/pkg/dialer"
ws_util "github.com/go-gost/gost/pkg/internal/util/ws"
"github.com/go-gost/gost/pkg/logger"
md "github.com/go-gost/gost/pkg/metadata"
"github.com/go-gost/gost/pkg/registry"
"github.com/gorilla/websocket"
)
func init() {
registry.RegisterDialer("ws", NewDialer)
registry.RegisterDialer("wss", NewTLSDialer)
}
type wsDialer struct {
tlsEnabled bool
logger logger.Logger
md metadata
}
func NewDialer(opts ...dialer.Option) dialer.Dialer {
options := &dialer.Options{}
for _, opt := range opts {
opt(options)
}
return &wsDialer{
logger: options.Logger,
}
}
func NewTLSDialer(opts ...dialer.Option) dialer.Dialer {
options := &dialer.Options{}
for _, opt := range opts {
opt(options)
}
return &wsDialer{
tlsEnabled: true,
logger: options.Logger,
}
}
func (d *wsDialer) Init(md md.Metadata) (err error) {
return d.parseMetadata(md)
}
func (d *wsDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOption) (net.Conn, error) {
var options dialer.DialOptions
for _, opt := range opts {
opt(&options)
}
var netd net.Dialer
conn, err := netd.DialContext(ctx, "tcp", addr)
if err != nil {
d.logger.Error(err)
}
return conn, err
}
// Handshake implements dialer.Handshaker
func (d *wsDialer) Handshake(ctx context.Context, conn net.Conn, options ...dialer.HandshakeOption) (net.Conn, error) {
opts := &dialer.HandshakeOptions{}
for _, option := range options {
option(opts)
}
if d.md.handshakeTimeout > 0 {
conn.SetDeadline(time.Now().Add(d.md.handshakeTimeout))
defer conn.SetDeadline(time.Time{})
}
host := d.md.host
if host == "" {
host = opts.Addr
}
dialer := websocket.Dialer{
HandshakeTimeout: d.md.handshakeTimeout,
ReadBufferSize: d.md.readBufferSize,
WriteBufferSize: d.md.writeBufferSize,
EnableCompression: d.md.enableCompression,
NetDial: func(net, addr string) (net.Conn, error) {
return conn, nil
},
}
url := url.URL{Scheme: "ws", Host: host, Path: d.md.path}
if d.tlsEnabled {
url.Scheme = "wss"
dialer.TLSClientConfig = d.md.tlsConfig
}
c, resp, err := dialer.Dial(url.String(), d.md.header)
if err != nil {
return nil, err
}
resp.Body.Close()
return ws_util.Conn(c), nil
}

86
pkg/dialer/ws/metadata.go Normal file
View File

@ -0,0 +1,86 @@
package ws
import (
"crypto/tls"
"fmt"
"net"
"net/http"
"time"
tls_util "github.com/go-gost/gost/pkg/common/util/tls"
md "github.com/go-gost/gost/pkg/metadata"
)
const (
defaultPath = "/ws"
)
type metadata struct {
path string
host string
tlsConfig *tls.Config
handshakeTimeout time.Duration
readHeaderTimeout time.Duration
readBufferSize int
writeBufferSize int
enableCompression bool
header http.Header
}
func (d *wsDialer) parseMetadata(md md.Metadata) (err error) {
const (
path = "path"
host = "host"
certFile = "certFile"
keyFile = "keyFile"
caFile = "caFile"
secure = "secure"
serverName = "serverName"
handshakeTimeout = "handshakeTimeout"
readHeaderTimeout = "readHeaderTimeout"
readBufferSize = "readBufferSize"
writeBufferSize = "writeBufferSize"
enableCompression = "enableCompression"
header = "header"
)
d.md.path = md.GetString(path)
if d.md.path == "" {
d.md.path = defaultPath
}
d.md.host = md.GetString(host)
sn, _, _ := net.SplitHostPort(md.GetString(serverName))
if sn == "" {
sn = "localhost"
}
d.md.tlsConfig, err = tls_util.LoadClientConfig(
md.GetString(certFile),
md.GetString(keyFile),
md.GetString(caFile),
md.GetBool(secure),
sn,
)
d.md.handshakeTimeout = md.GetDuration(handshakeTimeout)
d.md.readHeaderTimeout = md.GetDuration(readHeaderTimeout)
d.md.readBufferSize = md.GetInt(readBufferSize)
d.md.writeBufferSize = md.GetInt(writeBufferSize)
d.md.enableCompression = md.GetBool(enableCompression)
if mm, _ := md.Get(header).(map[interface{}]interface{}); len(mm) > 0 {
h := http.Header{}
for k, v := range mm {
h.Add(fmt.Sprintf("%v", k), fmt.Sprintf("%v", v))
}
d.md.header = h
}
return
}

38
pkg/dialer/ws/mux/conn.go Normal file
View File

@ -0,0 +1,38 @@
package mux
import (
"net"
"github.com/xtaci/smux"
)
type muxSession struct {
conn net.Conn
session *smux.Session
}
func (session *muxSession) GetConn() (net.Conn, error) {
return session.session.OpenStream()
}
func (session *muxSession) Accept() (net.Conn, error) {
return session.session.AcceptStream()
}
func (session *muxSession) Close() error {
if session.session == nil {
return nil
}
return session.session.Close()
}
func (session *muxSession) IsClosed() bool {
if session.session == nil {
return true
}
return session.session.IsClosed()
}
func (session *muxSession) NumStreams() int {
return session.session.NumStreams()
}

220
pkg/dialer/ws/mux/dialer.go Normal file
View File

@ -0,0 +1,220 @@
package mux
import (
"context"
"errors"
"net"
"net/url"
"sync"
"time"
"github.com/go-gost/gost/pkg/dialer"
ws_util "github.com/go-gost/gost/pkg/internal/util/ws"
"github.com/go-gost/gost/pkg/logger"
md "github.com/go-gost/gost/pkg/metadata"
"github.com/go-gost/gost/pkg/registry"
"github.com/gorilla/websocket"
"github.com/xtaci/smux"
)
func init() {
registry.RegisterDialer("mws", NewDialer)
registry.RegisterDialer("mwss", NewTLSDialer)
}
type mwsDialer struct {
sessions map[string]*muxSession
sessionMutex sync.Mutex
logger logger.Logger
md metadata
tlsEnabled bool
}
func NewDialer(opts ...dialer.Option) dialer.Dialer {
options := &dialer.Options{}
for _, opt := range opts {
opt(options)
}
return &mwsDialer{
sessions: make(map[string]*muxSession),
logger: options.Logger,
}
}
func NewTLSDialer(opts ...dialer.Option) dialer.Dialer {
options := &dialer.Options{}
for _, opt := range opts {
opt(options)
}
return &mwsDialer{
sessions: make(map[string]*muxSession),
logger: options.Logger,
tlsEnabled: true,
}
}
func (d *mwsDialer) Init(md md.Metadata) (err error) {
if err = d.parseMetadata(md); err != nil {
return
}
return nil
}
// Multiplex implements dialer.Multiplexer interface.
func (d *mwsDialer) Multiplex() bool {
return true
}
func (d *mwsDialer) Dial(ctx context.Context, addr string, opts ...dialer.DialOption) (conn net.Conn, err error) {
var options dialer.DialOptions
for _, opt := range opts {
opt(&options)
}
d.sessionMutex.Lock()
defer d.sessionMutex.Unlock()
session, ok := d.sessions[addr]
if session != nil && session.IsClosed() {
delete(d.sessions, addr) // session is dead
ok = false
}
if !ok {
conn, err = d.dial(ctx, "tcp", addr, &options)
if err != nil {
return
}
session = &muxSession{conn: conn}
d.sessions[addr] = session
}
return session.conn, err
}
// Handshake implements dialer.Handshaker
func (d *mwsDialer) Handshake(ctx context.Context, conn net.Conn, options ...dialer.HandshakeOption) (net.Conn, error) {
opts := &dialer.HandshakeOptions{}
for _, option := range options {
option(opts)
}
d.sessionMutex.Lock()
defer d.sessionMutex.Unlock()
if d.md.handshakeTimeout > 0 {
conn.SetDeadline(time.Now().Add(d.md.handshakeTimeout))
defer conn.SetDeadline(time.Time{})
}
session, ok := d.sessions[opts.Addr]
if session != nil && session.conn != conn {
conn.Close()
return nil, errors.New("mtls: unrecognized connection")
}
if !ok || session.session == nil {
host := d.md.host
if host == "" {
host = opts.Addr
}
s, err := d.initSession(ctx, host, conn)
if err != nil {
d.logger.Error(err)
conn.Close()
delete(d.sessions, opts.Addr)
return nil, err
}
session = s
d.sessions[opts.Addr] = session
}
cc, err := session.GetConn()
if err != nil {
session.Close()
delete(d.sessions, opts.Addr)
return nil, err
}
return cc, nil
}
func (d *mwsDialer) dial(ctx context.Context, network, addr string, opts *dialer.DialOptions) (net.Conn, error) {
dial := opts.DialFunc
if dial != nil {
conn, err := dial(ctx, addr)
if err != nil {
d.logger.Error(err)
} else {
d.logger.WithFields(map[string]interface{}{
"src": conn.LocalAddr().String(),
"dst": addr,
}).Debug("dial with dial func")
}
return conn, err
}
var netd net.Dialer
conn, err := netd.DialContext(ctx, network, addr)
if err != nil {
d.logger.Error(err)
} else {
d.logger.WithFields(map[string]interface{}{
"src": conn.LocalAddr().String(),
"dst": addr,
}).Debugf("dial direct %s/%s", addr, network)
}
return conn, err
}
func (d *mwsDialer) initSession(ctx context.Context, host string, conn net.Conn) (*muxSession, error) {
dialer := websocket.Dialer{
HandshakeTimeout: d.md.handshakeTimeout,
ReadBufferSize: d.md.readBufferSize,
WriteBufferSize: d.md.writeBufferSize,
EnableCompression: d.md.enableCompression,
NetDial: func(net, addr string) (net.Conn, error) {
return conn, nil
},
}
url := url.URL{Scheme: "ws", Host: host, Path: d.md.path}
if d.tlsEnabled {
url.Scheme = "wss"
dialer.TLSClientConfig = d.md.tlsConfig
}
c, resp, err := dialer.Dial(url.String(), d.md.header)
if err != nil {
return nil, err
}
resp.Body.Close()
conn = ws_util.Conn(c)
// stream multiplex
smuxConfig := smux.DefaultConfig()
smuxConfig.KeepAliveDisabled = d.md.muxKeepAliveDisabled
if d.md.muxKeepAliveInterval > 0 {
smuxConfig.KeepAliveInterval = d.md.muxKeepAliveInterval
}
if d.md.muxKeepAliveTimeout > 0 {
smuxConfig.KeepAliveTimeout = d.md.muxKeepAliveTimeout
}
if d.md.muxMaxFrameSize > 0 {
smuxConfig.MaxFrameSize = d.md.muxMaxFrameSize
}
if d.md.muxMaxReceiveBuffer > 0 {
smuxConfig.MaxReceiveBuffer = d.md.muxMaxReceiveBuffer
}
if d.md.muxMaxStreamBuffer > 0 {
smuxConfig.MaxStreamBuffer = d.md.muxMaxStreamBuffer
}
session, err := smux.Client(conn, smuxConfig)
if err != nil {
return nil, err
}
return &muxSession{conn: conn, session: session}, nil
}

View File

@ -0,0 +1,106 @@
package mux
import (
"crypto/tls"
"fmt"
"net"
"net/http"
"time"
tls_util "github.com/go-gost/gost/pkg/common/util/tls"
md "github.com/go-gost/gost/pkg/metadata"
)
const (
defaultPath = "/ws"
)
type metadata struct {
path string
host string
tlsConfig *tls.Config
handshakeTimeout time.Duration
readHeaderTimeout time.Duration
readBufferSize int
writeBufferSize int
enableCompression bool
muxKeepAliveDisabled bool
muxKeepAliveInterval time.Duration
muxKeepAliveTimeout time.Duration
muxMaxFrameSize int
muxMaxReceiveBuffer int
muxMaxStreamBuffer int
header http.Header
}
func (d *mwsDialer) parseMetadata(md md.Metadata) (err error) {
const (
path = "path"
host = "host"
certFile = "certFile"
keyFile = "keyFile"
caFile = "caFile"
secure = "secure"
serverName = "serverName"
handshakeTimeout = "handshakeTimeout"
readHeaderTimeout = "readHeaderTimeout"
readBufferSize = "readBufferSize"
writeBufferSize = "writeBufferSize"
enableCompression = "enableCompression"
header = "header"
muxKeepAliveDisabled = "muxKeepAliveDisabled"
muxKeepAliveInterval = "muxKeepAliveInterval"
muxKeepAliveTimeout = "muxKeepAliveTimeout"
muxMaxFrameSize = "muxMaxFrameSize"
muxMaxReceiveBuffer = "muxMaxReceiveBuffer"
muxMaxStreamBuffer = "muxMaxStreamBuffer"
)
d.md.path = md.GetString(path)
if d.md.path == "" {
d.md.path = defaultPath
}
d.md.host = md.GetString(host)
sn, _, _ := net.SplitHostPort(md.GetString(serverName))
if sn == "" {
sn = "localhost"
}
d.md.tlsConfig, err = tls_util.LoadClientConfig(
md.GetString(certFile),
md.GetString(keyFile),
md.GetString(caFile),
md.GetBool(secure),
sn,
)
d.md.muxKeepAliveDisabled = md.GetBool(muxKeepAliveDisabled)
d.md.muxKeepAliveInterval = md.GetDuration(muxKeepAliveInterval)
d.md.muxKeepAliveTimeout = md.GetDuration(muxKeepAliveTimeout)
d.md.muxMaxFrameSize = md.GetInt(muxMaxFrameSize)
d.md.muxMaxReceiveBuffer = md.GetInt(muxMaxReceiveBuffer)
d.md.muxMaxStreamBuffer = md.GetInt(muxMaxStreamBuffer)
d.md.handshakeTimeout = md.GetDuration(handshakeTimeout)
d.md.readHeaderTimeout = md.GetDuration(readHeaderTimeout)
d.md.readBufferSize = md.GetInt(readBufferSize)
d.md.writeBufferSize = md.GetInt(writeBufferSize)
d.md.enableCompression = md.GetBool(enableCompression)
if mm, _ := md.Get(header).(map[interface{}]interface{}); len(mm) > 0 {
h := http.Header{}
for k, v := range mm {
h.Add(fmt.Sprintf("%v", k), fmt.Sprintf("%v", v))
}
d.md.header = h
}
return
}

View File

@ -133,11 +133,10 @@ func (h *httpHandler) handleRequest(ctx context.Context, conn net.Conn, req *htt
resp := &http.Response{ resp := &http.Response{
ProtoMajor: 1, ProtoMajor: 1,
ProtoMinor: 1, ProtoMinor: 1,
Header: http.Header{}, Header: h.md.header,
} }
if resp.Header == nil {
for k, v := range h.md.headers { resp.Header = http.Header{}
resp.Header.Set(k, v)
} }
/* /*

View File

@ -2,6 +2,7 @@ package http
import ( import (
"fmt" "fmt"
"net/http"
"strings" "strings"
"github.com/go-gost/gost/pkg/auth" "github.com/go-gost/gost/pkg/auth"
@ -14,12 +15,12 @@ type metadata struct {
probeResist *probeResist probeResist *probeResist
sni bool sni bool
enableUDP bool enableUDP bool
headers map[string]string header http.Header
} }
func (h *httpHandler) parseMetadata(md md.Metadata) error { func (h *httpHandler) parseMetadata(md md.Metadata) error {
const ( const (
headers = "headers" header = "header"
users = "users" users = "users"
probeResistKey = "probeResist" probeResistKey = "probeResist"
knock = "knock" knock = "knock"
@ -43,12 +44,12 @@ func (h *httpHandler) parseMetadata(md md.Metadata) error {
h.md.authenticator = authenticator h.md.authenticator = authenticator
} }
if mm, _ := md.Get(headers).(map[interface{}]interface{}); len(mm) > 0 { if mm, _ := md.Get(header).(map[interface{}]interface{}); len(mm) > 0 {
m := make(map[string]string) hd := http.Header{}
for k, v := range mm { for k, v := range mm {
m[fmt.Sprintf("%v", k)] = fmt.Sprintf("%v", v) hd.Add(fmt.Sprintf("%v", k), fmt.Sprintf("%v", v))
} }
h.md.headers = m h.md.header = hd
} }
if v := md.GetString(probeResistKey); v != "" { if v := md.GetString(probeResistKey); v != "" {

View File

@ -21,10 +21,10 @@ func (h *httpHandler) handleUDP(ctx context.Context, conn net.Conn, network, add
resp := &http.Response{ resp := &http.Response{
ProtoMajor: 1, ProtoMajor: 1,
ProtoMinor: 1, ProtoMinor: 1,
Header: http.Header{}, Header: h.md.header,
} }
for k, v := range h.md.headers { if resp.Header == nil {
resp.Header.Set(k, v) resp.Header = http.Header{}
} }
if !h.md.enableUDP { if !h.md.enableUDP {

View File

@ -12,7 +12,7 @@ type websocketConn struct {
rb []byte rb []byte
} }
func WebsocketServerConn(conn *websocket.Conn) net.Conn { func Conn(conn *websocket.Conn) net.Conn {
return &websocketConn{ return &websocketConn{
Conn: conn, Conn: conn,
} }

View File

@ -22,6 +22,7 @@ type obfsHTTPConn struct {
wbuf bytes.Buffer wbuf bytes.Buffer
handshaked bool handshaked bool
handshakeMutex sync.Mutex handshakeMutex sync.Mutex
header http.Header
logger logger.Logger logger logger.Logger
} }
@ -71,9 +72,11 @@ func (c *obfsHTTPConn) handshake() (err error) {
StatusCode: http.StatusOK, StatusCode: http.StatusOK,
ProtoMajor: 1, ProtoMajor: 1,
ProtoMinor: 1, ProtoMinor: 1,
Header: make(http.Header), Header: c.header,
}
if resp.Header == nil {
resp.Header = http.Header{}
} }
resp.Header.Set("Server", "nginx/1.18.0")
resp.Header.Set("Date", time.Now().Format(time.RFC1123)) resp.Header.Set("Date", time.Now().Format(time.RFC1123))
if r.Method != http.MethodGet || r.Header.Get("Upgrade") != "websocket" { if r.Method != http.MethodGet || r.Header.Get("Upgrade") != "websocket" {

View File

@ -57,6 +57,7 @@ func (l *obfsListener) Accept() (net.Conn, error) {
return &obfsHTTPConn{ return &obfsHTTPConn{
Conn: c, Conn: c,
header: l.md.header,
logger: l.logger, logger: l.logger,
}, nil }, nil
} }

View File

@ -1,17 +1,27 @@
package http package http
import ( import (
"fmt"
"net/http"
md "github.com/go-gost/gost/pkg/metadata" md "github.com/go-gost/gost/pkg/metadata"
) )
const (
keepAlive = "keepAlive"
keepAlivePeriod = "keepAlivePeriod"
)
type metadata struct { type metadata struct {
header http.Header
} }
func (l *obfsListener) parseMetadata(md md.Metadata) (err error) { func (l *obfsListener) parseMetadata(md md.Metadata) (err error) {
const (
header = "header"
)
if mm, _ := md.Get(header).(map[interface{}]interface{}); len(mm) > 0 {
h := http.Header{}
for k, v := range mm {
h.Add(fmt.Sprintf("%v", k), fmt.Sprintf("%v", v))
}
l.md.header = h
}
return return
} }

View File

@ -54,6 +54,18 @@ func (l *mtlsListener) Init(md md.Metadata) (err error) {
return return
} }
func (l *mtlsListener) Accept() (conn net.Conn, err error) {
var ok bool
select {
case conn = <-l.cqueue:
case err, ok = <-l.errChan:
if !ok {
err = listener.ErrClosed
}
}
return
}
func (l *mtlsListener) listenLoop() { func (l *mtlsListener) listenLoop() {
for { for {
conn, err := l.Listener.Accept() conn, err := l.Listener.Accept()
@ -67,6 +79,8 @@ func (l *mtlsListener) listenLoop() {
} }
func (l *mtlsListener) mux(conn net.Conn) { func (l *mtlsListener) mux(conn net.Conn) {
defer conn.Close()
smuxConfig := smux.DefaultConfig() smuxConfig := smux.DefaultConfig()
smuxConfig.KeepAliveDisabled = l.md.muxKeepAliveDisabled smuxConfig.KeepAliveDisabled = l.md.muxKeepAliveDisabled
if l.md.muxKeepAliveInterval > 0 { if l.md.muxKeepAliveInterval > 0 {
@ -94,7 +108,7 @@ func (l *mtlsListener) mux(conn net.Conn) {
for { for {
stream, err := session.AcceptStream() stream, err := session.AcceptStream()
if err != nil { if err != nil {
l.logger.Error("accept stream:", err) l.logger.Error("accept stream: ", err)
return return
} }
@ -108,15 +122,3 @@ func (l *mtlsListener) mux(conn net.Conn) {
} }
} }
} }
func (l *mtlsListener) Accept() (conn net.Conn, err error) {
var ok bool
select {
case conn = <-l.cqueue:
case err, ok = <-l.errChan:
if !ok {
err = listener.ErrClosed
}
}
return
}

View File

@ -4,8 +4,9 @@ import (
"crypto/tls" "crypto/tls"
"net" "net"
"net/http" "net/http"
"net/http/httputil"
ws_util "github.com/go-gost/gost/pkg/common/util/ws" ws_util "github.com/go-gost/gost/pkg/internal/util/ws"
"github.com/go-gost/gost/pkg/listener" "github.com/go-gost/gost/pkg/listener"
"github.com/go-gost/gost/pkg/logger" "github.com/go-gost/gost/pkg/logger"
md "github.com/go-gost/gost/pkg/metadata" md "github.com/go-gost/gost/pkg/metadata"
@ -20,14 +21,14 @@ func init() {
type wsListener struct { type wsListener struct {
saddr string saddr string
md metadata
addr net.Addr addr net.Addr
upgrader *websocket.Upgrader upgrader *websocket.Upgrader
srv *http.Server srv *http.Server
tlsEnabled bool tlsEnabled bool
connChan chan net.Conn cqueue chan net.Conn
errChan chan error errChan chan error
logger logger.Logger logger logger.Logger
md metadata
} }
func NewListener(opts ...listener.Option) listener.Listener { func NewListener(opts ...listener.Option) listener.Listener {
@ -48,8 +49,8 @@ func NewTLSListener(opts ...listener.Option) listener.Listener {
} }
return &wsListener{ return &wsListener{
saddr: options.Addr, saddr: options.Addr,
tlsEnabled: true,
logger: options.Logger, logger: options.Logger,
tlsEnabled: true,
} }
} }
@ -62,35 +63,26 @@ func (l *wsListener) Init(md md.Metadata) (err error) {
HandshakeTimeout: l.md.handshakeTimeout, HandshakeTimeout: l.md.handshakeTimeout,
ReadBufferSize: l.md.readBufferSize, ReadBufferSize: l.md.readBufferSize,
WriteBufferSize: l.md.writeBufferSize, WriteBufferSize: l.md.writeBufferSize,
CheckOrigin: func(r *http.Request) bool { return true },
EnableCompression: l.md.enableCompression, EnableCompression: l.md.enableCompression,
CheckOrigin: func(r *http.Request) bool { return true },
} }
path := l.md.path
if path == "" {
path = defaultPath
}
mux := http.NewServeMux() mux := http.NewServeMux()
mux.Handle(path, http.HandlerFunc(l.upgrade)) mux.Handle(l.md.path, http.HandlerFunc(l.upgrade))
l.srv = &http.Server{ l.srv = &http.Server{
Addr: l.saddr, Addr: l.saddr,
TLSConfig: l.md.tlsConfig,
Handler: mux, Handler: mux,
ReadHeaderTimeout: l.md.readHeaderTimeout, ReadHeaderTimeout: l.md.readHeaderTimeout,
} }
queueSize := l.md.connQueueSize l.cqueue = make(chan net.Conn, l.md.backlog)
if queueSize <= 0 {
queueSize = defaultQueueSize
}
l.connChan = make(chan net.Conn, queueSize)
l.errChan = make(chan error, 1) l.errChan = make(chan error, 1)
ln, err := net.Listen("tcp", l.saddr) ln, err := net.Listen("tcp", l.saddr)
if err != nil { if err != nil {
return return
} }
if l.md.tlsConfig != nil { if l.tlsEnabled {
ln = tls.NewListener(ln, l.md.tlsConfig) ln = tls.NewListener(ln, l.md.tlsConfig)
} }
@ -110,7 +102,7 @@ func (l *wsListener) Init(md md.Metadata) (err error) {
func (l *wsListener) Accept() (conn net.Conn, err error) { func (l *wsListener) Accept() (conn net.Conn, err error) {
var ok bool var ok bool
select { select {
case conn = <-l.connChan: case conn = <-l.cqueue:
case err, ok = <-l.errChan: case err, ok = <-l.errChan:
if !ok { if !ok {
err = listener.ErrClosed err = listener.ErrClosed
@ -128,16 +120,25 @@ func (l *wsListener) Addr() net.Addr {
} }
func (l *wsListener) upgrade(w http.ResponseWriter, r *http.Request) { func (l *wsListener) upgrade(w http.ResponseWriter, r *http.Request) {
conn, err := l.upgrader.Upgrade(w, r, l.md.responseHeader) if l.logger.IsLevelEnabled(logger.DebugLevel) {
log := l.logger.WithFields(map[string]interface{}{
"local": l.addr.String(),
"remote": r.RemoteAddr,
})
dump, _ := httputil.DumpRequest(r, false)
log.Debug(string(dump))
}
conn, err := l.upgrader.Upgrade(w, r, l.md.header)
if err != nil { if err != nil {
l.logger.Error(err) l.logger.Error(err)
return return
} }
select { select {
case l.connChan <- ws_util.WebsocketServerConn(conn): case l.cqueue <- ws_util.Conn(conn):
default: default:
conn.Close() conn.Close()
l.logger.Warn("connection queue is full") l.logger.Warnf("connection queue is full, client %s discarded", conn.RemoteAddr())
} }
} }

View File

@ -2,6 +2,7 @@ package ws
import ( import (
"crypto/tls" "crypto/tls"
"fmt"
"net/http" "net/http"
"time" "time"
@ -11,34 +12,39 @@ import (
const ( const (
defaultPath = "/ws" defaultPath = "/ws"
defaultQueueSize = 128 defaultBacklog = 128
) )
type metadata struct { type metadata struct {
path string path string
backlog int
tlsConfig *tls.Config tlsConfig *tls.Config
handshakeTimeout time.Duration handshakeTimeout time.Duration
readHeaderTimeout time.Duration readHeaderTimeout time.Duration
readBufferSize int readBufferSize int
writeBufferSize int writeBufferSize int
enableCompression bool enableCompression bool
responseHeader http.Header
connQueueSize int header http.Header
} }
func (l *wsListener) parseMetadata(md md.Metadata) (err error) { func (l *wsListener) parseMetadata(md md.Metadata) (err error) {
const ( const (
path = "path"
certFile = "certFile" certFile = "certFile"
keyFile = "keyFile" keyFile = "keyFile"
caFile = "caFile" caFile = "caFile"
path = "path"
backlog = "backlog"
handshakeTimeout = "handshakeTimeout" handshakeTimeout = "handshakeTimeout"
readHeaderTimeout = "readHeaderTimeout" readHeaderTimeout = "readHeaderTimeout"
readBufferSize = "readBufferSize" readBufferSize = "readBufferSize"
writeBufferSize = "writeBufferSize" writeBufferSize = "writeBufferSize"
enableCompression = "enableCompression" enableCompression = "enableCompression"
responseHeader = "responseHeader"
connQueueSize = "connQueueSize" header = "header"
) )
l.md.tlsConfig, err = tls_util.LoadServerConfig( l.md.tlsConfig, err = tls_util.LoadServerConfig(
@ -51,15 +57,28 @@ func (l *wsListener) parseMetadata(md md.Metadata) (err error) {
} }
l.md.path = md.GetString(path) l.md.path = md.GetString(path)
l.md.connQueueSize = md.GetInt(connQueueSize) if l.md.path == "" {
if l.md.connQueueSize <= 0 { l.md.path = defaultPath
l.md.connQueueSize = defaultQueueSize
} }
l.md.enableCompression = md.GetBool(enableCompression)
l.md.readBufferSize = md.GetInt(readBufferSize) l.md.backlog = md.GetInt(backlog)
l.md.writeBufferSize = md.GetInt(writeBufferSize) if l.md.backlog <= 0 {
l.md.backlog = defaultBacklog
}
l.md.handshakeTimeout = md.GetDuration(handshakeTimeout) l.md.handshakeTimeout = md.GetDuration(handshakeTimeout)
l.md.readHeaderTimeout = md.GetDuration(readHeaderTimeout) l.md.readHeaderTimeout = md.GetDuration(readHeaderTimeout)
l.md.readBufferSize = md.GetInt(readBufferSize)
l.md.writeBufferSize = md.GetInt(writeBufferSize)
l.md.enableCompression = md.GetBool(enableCompression)
if mm, _ := md.Get(header).(map[interface{}]interface{}); len(mm) > 0 {
h := http.Header{}
for k, v := range mm {
h.Add(fmt.Sprintf("%v", k), fmt.Sprintf("%v", v))
}
l.md.header = h
}
return return
} }

View File

@ -4,8 +4,9 @@ import (
"crypto/tls" "crypto/tls"
"net" "net"
"net/http" "net/http"
"net/http/httputil"
ws_util "github.com/go-gost/gost/pkg/common/util/ws" ws_util "github.com/go-gost/gost/pkg/internal/util/ws"
"github.com/go-gost/gost/pkg/listener" "github.com/go-gost/gost/pkg/listener"
"github.com/go-gost/gost/pkg/logger" "github.com/go-gost/gost/pkg/logger"
md "github.com/go-gost/gost/pkg/metadata" md "github.com/go-gost/gost/pkg/metadata"
@ -16,18 +17,19 @@ import (
func init() { func init() {
registry.RegisterListener("mws", NewListener) registry.RegisterListener("mws", NewListener)
registry.RegisterListener("mwss", NewListener) registry.RegisterListener("mwss", NewTLSListener)
} }
type mwsListener struct { type mwsListener struct {
saddr string saddr string
md metadata
addr net.Addr addr net.Addr
upgrader *websocket.Upgrader upgrader *websocket.Upgrader
srv *http.Server srv *http.Server
connChan chan net.Conn cqueue chan net.Conn
errChan chan error errChan chan error
logger logger.Logger logger logger.Logger
md metadata
tlsEnabled bool
} }
func NewListener(opts ...listener.Option) listener.Listener { func NewListener(opts ...listener.Option) listener.Listener {
@ -36,10 +38,23 @@ func NewListener(opts ...listener.Option) listener.Listener {
opt(options) opt(options)
} }
return &mwsListener{ return &mwsListener{
saddr: options.Addr,
logger: options.Logger, logger: options.Logger,
} }
} }
func NewTLSListener(opts ...listener.Option) listener.Listener {
options := &listener.Options{}
for _, opt := range opts {
opt(options)
}
return &mwsListener{
saddr: options.Addr,
logger: options.Logger,
tlsEnabled: true,
}
}
func (l *mwsListener) Init(md md.Metadata) (err error) { func (l *mwsListener) Init(md md.Metadata) (err error) {
if err = l.parseMetadata(md); err != nil { if err = l.parseMetadata(md); err != nil {
return return
@ -49,8 +64,8 @@ func (l *mwsListener) Init(md md.Metadata) (err error) {
HandshakeTimeout: l.md.handshakeTimeout, HandshakeTimeout: l.md.handshakeTimeout,
ReadBufferSize: l.md.readBufferSize, ReadBufferSize: l.md.readBufferSize,
WriteBufferSize: l.md.writeBufferSize, WriteBufferSize: l.md.writeBufferSize,
CheckOrigin: func(r *http.Request) bool { return true },
EnableCompression: l.md.enableCompression, EnableCompression: l.md.enableCompression,
CheckOrigin: func(r *http.Request) bool { return true },
} }
path := l.md.path path := l.md.path
@ -61,19 +76,18 @@ func (l *mwsListener) Init(md md.Metadata) (err error) {
mux.Handle(path, http.HandlerFunc(l.upgrade)) mux.Handle(path, http.HandlerFunc(l.upgrade))
l.srv = &http.Server{ l.srv = &http.Server{
Addr: l.saddr, Addr: l.saddr,
TLSConfig: l.md.tlsConfig,
Handler: mux, Handler: mux,
ReadHeaderTimeout: l.md.readHeaderTimeout, ReadHeaderTimeout: l.md.readHeaderTimeout,
} }
l.connChan = make(chan net.Conn, l.md.connQueueSize) l.cqueue = make(chan net.Conn, l.md.backlog)
l.errChan = make(chan error, 1) l.errChan = make(chan error, 1)
ln, err := net.Listen("tcp", l.saddr) ln, err := net.Listen("tcp", l.saddr)
if err != nil { if err != nil {
return return
} }
if l.md.tlsConfig != nil { if l.tlsEnabled {
ln = tls.NewListener(ln, l.md.tlsConfig) ln = tls.NewListener(ln, l.md.tlsConfig)
} }
@ -93,7 +107,7 @@ func (l *mwsListener) Init(md md.Metadata) (err error) {
func (l *mwsListener) Accept() (conn net.Conn, err error) { func (l *mwsListener) Accept() (conn net.Conn, err error) {
var ok bool var ok bool
select { select {
case conn = <-l.connChan: case conn = <-l.cqueue:
case err, ok = <-l.errChan: case err, ok = <-l.errChan:
if !ok { if !ok {
err = listener.ErrClosed err = listener.ErrClosed
@ -111,20 +125,31 @@ func (l *mwsListener) Addr() net.Addr {
} }
func (l *mwsListener) upgrade(w http.ResponseWriter, r *http.Request) { func (l *mwsListener) upgrade(w http.ResponseWriter, r *http.Request) {
conn, err := l.upgrader.Upgrade(w, r, l.md.responseHeader) if l.logger.IsLevelEnabled(logger.DebugLevel) {
log := l.logger.WithFields(map[string]interface{}{
"local": l.addr.String(),
"remote": r.RemoteAddr,
})
dump, _ := httputil.DumpRequest(r, false)
log.Debug(string(dump))
}
conn, err := l.upgrader.Upgrade(w, r, l.md.header)
if err != nil { if err != nil {
l.logger.Error(err) l.logger.Error(err)
return return
} }
l.mux(ws_util.WebsocketServerConn(conn)) l.mux(ws_util.Conn(conn))
} }
func (l *mwsListener) mux(conn net.Conn) { func (l *mwsListener) mux(conn net.Conn) {
defer conn.Close()
smuxConfig := smux.DefaultConfig() smuxConfig := smux.DefaultConfig()
smuxConfig.KeepAliveDisabled = l.md.muxKeepAliveDisabled smuxConfig.KeepAliveDisabled = l.md.muxKeepAliveDisabled
if l.md.muxKeepAlivePeriod > 0 { if l.md.muxKeepAliveInterval > 0 {
smuxConfig.KeepAliveInterval = l.md.muxKeepAlivePeriod smuxConfig.KeepAliveInterval = l.md.muxKeepAliveInterval
} }
if l.md.muxKeepAliveTimeout > 0 { if l.md.muxKeepAliveTimeout > 0 {
smuxConfig.KeepAliveTimeout = l.md.muxKeepAliveTimeout smuxConfig.KeepAliveTimeout = l.md.muxKeepAliveTimeout
@ -148,17 +173,17 @@ func (l *mwsListener) mux(conn net.Conn) {
for { for {
stream, err := session.AcceptStream() stream, err := session.AcceptStream()
if err != nil { if err != nil {
l.logger.Error("accept stream:", err) l.logger.Error("accept stream: ", err)
return return
} }
select { select {
case l.connChan <- stream: case l.cqueue <- stream:
case <-stream.GetDieCh(): case <-stream.GetDieCh():
stream.Close() stream.Close()
default: default:
stream.Close() stream.Close()
l.logger.Error("connection queue is full") l.logger.Warnf("connection queue is full, client %s discarded", stream.RemoteAddr())
} }
} }
} }

View File

@ -2,6 +2,7 @@ package mux
import ( import (
"crypto/tls" "crypto/tls"
"fmt"
"net/http" "net/http"
"time" "time"
@ -11,44 +12,47 @@ import (
const ( const (
defaultPath = "/ws" defaultPath = "/ws"
defaultQueueSize = 128 defaultBacklog = 128
) )
type metadata struct { type metadata struct {
path string path string
backlog int
tlsConfig *tls.Config tlsConfig *tls.Config
header http.Header
handshakeTimeout time.Duration handshakeTimeout time.Duration
readHeaderTimeout time.Duration readHeaderTimeout time.Duration
readBufferSize int readBufferSize int
writeBufferSize int writeBufferSize int
enableCompression bool enableCompression bool
responseHeader http.Header
muxKeepAliveDisabled bool muxKeepAliveDisabled bool
muxKeepAlivePeriod time.Duration muxKeepAliveInterval time.Duration
muxKeepAliveTimeout time.Duration muxKeepAliveTimeout time.Duration
muxMaxFrameSize int muxMaxFrameSize int
muxMaxReceiveBuffer int muxMaxReceiveBuffer int
muxMaxStreamBuffer int muxMaxStreamBuffer int
connQueueSize int
} }
func (l *mwsListener) parseMetadata(md md.Metadata) (err error) { func (l *mwsListener) parseMetadata(md md.Metadata) (err error) {
const ( const (
path = "path" path = "path"
backlog = "backlog"
header = "header"
certFile = "certFile" certFile = "certFile"
keyFile = "keyFile" keyFile = "keyFile"
caFile = "caFile" caFile = "caFile"
handshakeTimeout = "handshakeTimeout" handshakeTimeout = "handshakeTimeout"
readHeaderTimeout = "readHeaderTimeout" readHeaderTimeout = "readHeaderTimeout"
readBufferSize = "readBufferSize" readBufferSize = "readBufferSize"
writeBufferSize = "writeBufferSize" writeBufferSize = "writeBufferSize"
enableCompression = "enableCompression" enableCompression = "enableCompression"
responseHeader = "responseHeader"
connQueueSize = "connQueueSize"
muxKeepAliveDisabled = "muxKeepAliveDisabled" muxKeepAliveDisabled = "muxKeepAliveDisabled"
muxKeepAlivePeriod = "muxKeepAlivePeriod" muxKeepAliveInterval = "muxKeepAliveInterval"
muxKeepAliveTimeout = "muxKeepAliveTimeout" muxKeepAliveTimeout = "muxKeepAliveTimeout"
muxMaxFrameSize = "muxMaxFrameSize" muxMaxFrameSize = "muxMaxFrameSize"
muxMaxReceiveBuffer = "muxMaxReceiveBuffer" muxMaxReceiveBuffer = "muxMaxReceiveBuffer"
@ -64,5 +68,35 @@ func (l *mwsListener) parseMetadata(md md.Metadata) (err error) {
return return
} }
l.md.path = md.GetString(path)
if l.md.path == "" {
l.md.path = defaultPath
}
l.md.backlog = md.GetInt(backlog)
if l.md.backlog <= 0 {
l.md.backlog = defaultBacklog
}
l.md.handshakeTimeout = md.GetDuration(handshakeTimeout)
l.md.readHeaderTimeout = md.GetDuration(readHeaderTimeout)
l.md.readBufferSize = md.GetInt(readBufferSize)
l.md.writeBufferSize = md.GetInt(writeBufferSize)
l.md.enableCompression = md.GetBool(enableCompression)
l.md.muxKeepAliveDisabled = md.GetBool(muxKeepAliveDisabled)
l.md.muxKeepAliveInterval = md.GetDuration(muxKeepAliveInterval)
l.md.muxKeepAliveTimeout = md.GetDuration(muxKeepAliveTimeout)
l.md.muxMaxFrameSize = md.GetInt(muxMaxFrameSize)
l.md.muxMaxReceiveBuffer = md.GetInt(muxMaxReceiveBuffer)
l.md.muxMaxStreamBuffer = md.GetInt(muxMaxStreamBuffer)
if mm, _ := md.Get(header).(map[interface{}]interface{}); len(mm) > 0 {
h := http.Header{}
for k, v := range mm {
h.Add(fmt.Sprintf("%v", k), fmt.Sprintf("%v", v))
}
l.md.header = h
}
return return
} }