add tls config option

This commit is contained in:
ginuerzh
2022-01-05 00:02:55 +08:00
parent c428b37a36
commit 3b48c4acfb
43 changed files with 395 additions and 496 deletions

View File

@ -23,28 +23,31 @@ type wsDialer struct {
tlsEnabled bool
logger logger.Logger
md metadata
options dialer.Options
}
func NewDialer(opts ...dialer.Option) dialer.Dialer {
options := &dialer.Options{}
options := dialer.Options{}
for _, opt := range opts {
opt(options)
opt(&options)
}
return &wsDialer{
logger: options.Logger,
logger: options.Logger,
options: options,
}
}
func NewTLSDialer(opts ...dialer.Option) dialer.Dialer {
options := &dialer.Options{}
options := dialer.Options{}
for _, opt := range opts {
opt(options)
opt(&options)
}
return &wsDialer{
tlsEnabled: true,
logger: options.Logger,
options: options,
}
}
@ -96,7 +99,7 @@ func (d *wsDialer) Handshake(ctx context.Context, conn net.Conn, options ...dial
url := url.URL{Scheme: "ws", Host: host, Path: d.md.path}
if d.tlsEnabled {
url.Scheme = "wss"
dialer.TLSClientConfig = d.md.tlsConfig
dialer.TLSClientConfig = d.options.TLSConfig
}
c, resp, err := dialer.Dial(url.String(), d.md.header)

View File

@ -1,12 +1,9 @@
package ws
import (
"crypto/tls"
"net"
"net/http"
"time"
tls_util "github.com/go-gost/gost/pkg/common/util/tls"
mdata "github.com/go-gost/gost/pkg/metadata"
)
@ -15,9 +12,8 @@ const (
)
type metadata struct {
path string
host string
tlsConfig *tls.Config
host string
path string
handshakeTimeout time.Duration
readHeaderTimeout time.Duration
@ -30,14 +26,8 @@ type metadata struct {
func (d *wsDialer) parseMetadata(md mdata.Metadata) (err error) {
const (
path = "path"
host = "host"
certFile = "certFile"
keyFile = "keyFile"
caFile = "caFile"
secure = "secure"
serverName = "serverName"
path = "path"
handshakeTimeout = "handshakeTimeout"
readHeaderTimeout = "readHeaderTimeout"
@ -48,25 +38,13 @@ func (d *wsDialer) parseMetadata(md mdata.Metadata) (err error) {
header = "header"
)
d.md.host = mdata.GetString(md, host)
d.md.path = mdata.GetString(md, path)
if d.md.path == "" {
d.md.path = defaultPath
}
d.md.host = mdata.GetString(md, host)
sn, _, _ := net.SplitHostPort(mdata.GetString(md, serverName))
if sn == "" {
sn = "localhost"
}
d.md.tlsConfig, err = tls_util.LoadClientConfig(
mdata.GetString(md, certFile),
mdata.GetString(md, keyFile),
mdata.GetString(md, caFile),
mdata.GetBool(md, secure),
sn,
)
d.md.handshakeTimeout = mdata.GetDuration(md, handshakeTimeout)
d.md.readHeaderTimeout = mdata.GetDuration(md, readHeaderTimeout)
d.md.readBufferSize = mdata.GetInt(md, readBufferSize)

View File

@ -25,33 +25,36 @@ func init() {
type mwsDialer struct {
sessions map[string]*muxSession
sessionMutex sync.Mutex
tlsEnabled bool
logger logger.Logger
md metadata
tlsEnabled bool
options dialer.Options
}
func NewDialer(opts ...dialer.Option) dialer.Dialer {
options := &dialer.Options{}
options := dialer.Options{}
for _, opt := range opts {
opt(options)
opt(&options)
}
return &mwsDialer{
sessions: make(map[string]*muxSession),
logger: options.Logger,
options: options,
}
}
func NewTLSDialer(opts ...dialer.Option) dialer.Dialer {
options := &dialer.Options{}
options := dialer.Options{}
for _, opt := range opts {
opt(options)
opt(&options)
}
return &mwsDialer{
tlsEnabled: true,
sessions: make(map[string]*muxSession),
logger: options.Logger,
tlsEnabled: true,
options: options,
}
}
func (d *mwsDialer) Init(md md.Metadata) (err error) {
@ -182,7 +185,7 @@ func (d *mwsDialer) initSession(ctx context.Context, host string, conn net.Conn)
url := url.URL{Scheme: "ws", Host: host, Path: d.md.path}
if d.tlsEnabled {
url.Scheme = "wss"
dialer.TLSClientConfig = d.md.tlsConfig
dialer.TLSClientConfig = d.options.TLSConfig
}
c, resp, err := dialer.Dial(url.String(), d.md.header)

View File

@ -1,12 +1,9 @@
package mux
import (
"crypto/tls"
"net"
"net/http"
"time"
tls_util "github.com/go-gost/gost/pkg/common/util/tls"
mdata "github.com/go-gost/gost/pkg/metadata"
)
@ -15,9 +12,8 @@ const (
)
type metadata struct {
path string
host string
tlsConfig *tls.Config
host string
path string
handshakeTimeout time.Duration
readHeaderTimeout time.Duration
@ -37,14 +33,8 @@ type metadata struct {
func (d *mwsDialer) parseMetadata(md mdata.Metadata) (err error) {
const (
path = "path"
host = "host"
certFile = "certFile"
keyFile = "keyFile"
caFile = "caFile"
secure = "secure"
serverName = "serverName"
path = "path"
handshakeTimeout = "handshakeTimeout"
readHeaderTimeout = "readHeaderTimeout"
@ -62,25 +52,13 @@ func (d *mwsDialer) parseMetadata(md mdata.Metadata) (err error) {
muxMaxStreamBuffer = "muxMaxStreamBuffer"
)
d.md.host = mdata.GetString(md, host)
d.md.path = mdata.GetString(md, path)
if d.md.path == "" {
d.md.path = defaultPath
}
d.md.host = mdata.GetString(md, host)
sn, _, _ := net.SplitHostPort(mdata.GetString(md, serverName))
if sn == "" {
sn = "localhost"
}
d.md.tlsConfig, err = tls_util.LoadClientConfig(
mdata.GetString(md, certFile),
mdata.GetString(md, keyFile),
mdata.GetString(md, caFile),
mdata.GetBool(md, secure),
sn,
)
d.md.muxKeepAliveDisabled = mdata.GetBool(md, muxKeepAliveDisabled)
d.md.muxKeepAliveInterval = mdata.GetDuration(md, muxKeepAliveInterval)
d.md.muxKeepAliveTimeout = mdata.GetDuration(md, muxKeepAliveTimeout)