add tls config option

This commit is contained in:
ginuerzh
2022-01-05 00:02:55 +08:00
parent c428b37a36
commit 3b48c4acfb
43 changed files with 395 additions and 496 deletions

View File

@ -19,21 +19,23 @@ func init() {
}
type http2Dialer struct {
md metadata
clients map[string]*http.Client
clientMutex sync.Mutex
logger logger.Logger
md metadata
options dialer.Options
}
func NewDialer(opts ...dialer.Option) dialer.Dialer {
options := &dialer.Options{}
options := dialer.Options{}
for _, opt := range opts {
opt(options)
opt(&options)
}
return &http2Dialer{
clients: make(map[string]*http.Client),
logger: options.Logger,
options: options,
}
}
@ -69,7 +71,7 @@ func (d *http2Dialer) Dial(ctx context.Context, address string, opts ...dialer.D
if !ok {
client = &http.Client{
Transport: &http.Transport{
TLSClientConfig: d.md.tlsConfig,
TLSClientConfig: d.options.TLSConfig,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return d.dial(ctx, network, addr, options)
},

View File

@ -27,33 +27,36 @@ func init() {
type h2Dialer struct {
clients map[string]*http.Client
clientMutex sync.Mutex
h2c bool
logger logger.Logger
md metadata
h2c bool
options dialer.Options
}
func NewDialer(opts ...dialer.Option) dialer.Dialer {
options := &dialer.Options{}
options := dialer.Options{}
for _, opt := range opts {
opt(options)
opt(&options)
}
return &h2Dialer{
h2c: true,
clients: make(map[string]*http.Client),
logger: options.Logger,
h2c: true,
options: options,
}
}
func NewTLSDialer(opts ...dialer.Option) dialer.Dialer {
options := &dialer.Options{}
options := dialer.Options{}
for _, opt := range opts {
opt(options)
opt(&options)
}
return &h2Dialer{
clients: make(map[string]*http.Client),
logger: options.Logger,
options: options,
}
}
@ -95,7 +98,7 @@ func (d *h2Dialer) Dial(ctx context.Context, address string, opts ...dialer.Dial
}
} else {
client.Transport = &http.Transport{
TLSClientConfig: d.md.tlsConfig,
TLSClientConfig: d.options.TLSConfig,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return d.dial(ctx, network, addr, options)
},

View File

@ -1,42 +1,21 @@
package h2
import (
"crypto/tls"
"net"
tls_util "github.com/go-gost/gost/pkg/common/util/tls"
mdata "github.com/go-gost/gost/pkg/metadata"
)
type metadata struct {
path string
host string
tlsConfig *tls.Config
host string
path string
}
func (d *h2Dialer) parseMetadata(md mdata.Metadata) (err error) {
const (
certFile = "certFile"
keyFile = "keyFile"
caFile = "caFile"
secure = "secure"
serverName = "serverName"
path = "path"
)
d.md.host = mdata.GetString(md, serverName)
sn, _, _ := net.SplitHostPort(d.md.host)
if sn == "" {
sn = "localhost"
}
d.md.tlsConfig, err = tls_util.LoadClientConfig(
mdata.GetString(md, certFile),
mdata.GetString(md, keyFile),
mdata.GetString(md, caFile),
mdata.GetBool(md, secure),
sn,
host = "host"
path = "path"
)
d.md.host = mdata.GetString(md, host)
d.md.path = mdata.GetString(md, path)
return

View File

@ -1,37 +1,12 @@
package http2
import (
"crypto/tls"
"net"
tls_util "github.com/go-gost/gost/pkg/common/util/tls"
mdata "github.com/go-gost/gost/pkg/metadata"
)
type metadata struct {
tlsConfig *tls.Config
}
func (d *http2Dialer) parseMetadata(md mdata.Metadata) (err error) {
const (
certFile = "certFile"
keyFile = "keyFile"
caFile = "caFile"
secure = "secure"
serverName = "serverName"
)
sn, _, _ := net.SplitHostPort(mdata.GetString(md, serverName))
if sn == "" {
sn = "localhost"
}
d.md.tlsConfig, err = tls_util.LoadClientConfig(
mdata.GetString(md, certFile),
mdata.GetString(md, keyFile),
mdata.GetString(md, caFile),
mdata.GetBool(md, secure),
sn,
)
return
}

View File

@ -2,6 +2,7 @@ package dialer
import (
"context"
"crypto/tls"
"net"
"net/url"
@ -9,8 +10,9 @@ import (
)
type Options struct {
User *url.Userinfo
Logger logger.Logger
User *url.Userinfo
TLSConfig *tls.Config
Logger logger.Logger
}
type Option func(opts *Options)
@ -21,6 +23,12 @@ func UserOption(user *url.Userinfo) Option {
}
}
func TLSConfigOption(tlsConfig *tls.Config) Option {
return func(opts *Options) {
opts.TLSConfig = tlsConfig
}
}
func LoggerOption(logger logger.Logger) Option {
return func(opts *Options) {
opts.Logger = logger

View File

@ -24,17 +24,19 @@ type quicDialer struct {
sessionMutex sync.Mutex
logger logger.Logger
md metadata
options dialer.Options
}
func NewDialer(opts ...dialer.Option) dialer.Dialer {
options := &dialer.Options{}
options := dialer.Options{}
for _, opt := range opts {
opt(options)
opt(&options)
}
return &quicDialer{
sessions: make(map[string]*quicSession),
logger: options.Logger,
options: options,
}
}
@ -141,7 +143,7 @@ func (d *quicDialer) initSession(ctx context.Context, addr string, conn net.Conn
},
}
tlsCfg := d.md.tlsConfig
tlsCfg := d.options.TLSConfig
tlsCfg.NextProtos = []string{"http/3", "quic/v1"}
session, err := quic.DialContext(ctx, pc, udpAddr, addr, tlsCfg, quicConfig)

View File

@ -1,11 +1,8 @@
package quic
import (
"crypto/tls"
"net"
"time"
tls_util "github.com/go-gost/gost/pkg/common/util/tls"
mdata "github.com/go-gost/gost/pkg/metadata"
)
@ -15,7 +12,6 @@ type metadata struct {
handshakeTimeout time.Duration
cipherKey []byte
tlsConfig *tls.Config
}
func (d *quicDialer) parseMetadata(md mdata.Metadata) (err error) {
@ -24,12 +20,6 @@ func (d *quicDialer) parseMetadata(md mdata.Metadata) (err error) {
handshakeTimeout = "handshakeTimeout"
maxIdleTimeout = "maxIdleTimeout"
certFile = "certFile"
keyFile = "keyFile"
caFile = "caFile"
secure = "secure"
serverName = "serverName"
cipherKey = "cipherKey"
)
@ -39,18 +29,6 @@ func (d *quicDialer) parseMetadata(md mdata.Metadata) (err error) {
d.md.cipherKey = []byte(key)
}
sn, _, _ := net.SplitHostPort(mdata.GetString(md, serverName))
if sn == "" {
sn = "localhost"
}
d.md.tlsConfig, err = tls_util.LoadClientConfig(
mdata.GetString(md, certFile),
mdata.GetString(md, keyFile),
mdata.GetString(md, caFile),
mdata.GetBool(md, secure),
sn,
)
d.md.keepAlive = mdata.GetBool(md, keepAlive)
d.md.handshakeTimeout = mdata.GetDuration(md, handshakeTimeout)
d.md.maxIdleTimeout = mdata.GetDuration(md, maxIdleTimeout)

View File

@ -17,18 +17,20 @@ func init() {
}
type tlsDialer struct {
md metadata
logger logger.Logger
md metadata
logger logger.Logger
options dialer.Options
}
func NewDialer(opts ...dialer.Option) dialer.Dialer {
options := &dialer.Options{}
options := dialer.Options{}
for _, opt := range opts {
opt(options)
opt(&options)
}
return &tlsDialer{
logger: options.Logger,
logger: options.Logger,
options: options,
}
}
@ -57,7 +59,7 @@ func (d *tlsDialer) Handshake(ctx context.Context, conn net.Conn, options ...dia
defer conn.SetDeadline(time.Time{})
}
tlsConn := tls.Client(conn, d.md.tlsConfig)
tlsConn := tls.Client(conn, d.options.TLSConfig)
if err := tlsConn.HandshakeContext(ctx); err != nil {
conn.Close()
return nil, err

View File

@ -1,42 +1,20 @@
package tls
import (
"crypto/tls"
"net"
"time"
tls_util "github.com/go-gost/gost/pkg/common/util/tls"
mdata "github.com/go-gost/gost/pkg/metadata"
)
type metadata struct {
tlsConfig *tls.Config
handshakeTimeout time.Duration
}
func (d *tlsDialer) parseMetadata(md mdata.Metadata) (err error) {
const (
certFile = "certFile"
keyFile = "keyFile"
caFile = "caFile"
secure = "secure"
serverName = "serverName"
handshakeTimeout = "handshakeTimeout"
)
sn, _, _ := net.SplitHostPort(mdata.GetString(md, serverName))
if sn == "" {
sn = "localhost"
}
d.md.tlsConfig, err = tls_util.LoadClientConfig(
mdata.GetString(md, certFile),
mdata.GetString(md, keyFile),
mdata.GetString(md, caFile),
mdata.GetBool(md, secure),
sn,
)
d.md.handshakeTimeout = mdata.GetDuration(md, handshakeTimeout)
return

View File

@ -24,17 +24,19 @@ type mtlsDialer struct {
sessionMutex sync.Mutex
logger logger.Logger
md metadata
options dialer.Options
}
func NewDialer(opts ...dialer.Option) dialer.Dialer {
options := &dialer.Options{}
options := dialer.Options{}
for _, opt := range opts {
opt(options)
opt(&options)
}
return &mtlsDialer{
sessions: make(map[string]*muxSession),
logger: options.Logger,
options: options,
}
}
@ -149,7 +151,7 @@ func (d *mtlsDialer) dial(ctx context.Context, network, addr string, opts *diale
}
func (d *mtlsDialer) initSession(ctx context.Context, conn net.Conn) (*muxSession, error) {
tlsConn := tls.Client(conn, d.md.tlsConfig)
tlsConn := tls.Client(conn, d.options.TLSConfig)
if err := tlsConn.HandshakeContext(ctx); err != nil {
return nil, err
}

View File

@ -1,16 +1,12 @@
package mux
import (
"crypto/tls"
"net"
"time"
tls_util "github.com/go-gost/gost/pkg/common/util/tls"
mdata "github.com/go-gost/gost/pkg/metadata"
)
type metadata struct {
tlsConfig *tls.Config
handshakeTimeout time.Duration
muxKeepAliveDisabled bool
@ -23,12 +19,6 @@ type metadata struct {
func (d *mtlsDialer) parseMetadata(md mdata.Metadata) (err error) {
const (
certFile = "certFile"
keyFile = "keyFile"
caFile = "caFile"
secure = "secure"
serverName = "serverName"
handshakeTimeout = "handshakeTimeout"
muxKeepAliveDisabled = "muxKeepAliveDisabled"
@ -39,17 +29,6 @@ func (d *mtlsDialer) parseMetadata(md mdata.Metadata) (err error) {
muxMaxStreamBuffer = "muxMaxStreamBuffer"
)
sn, _, _ := net.SplitHostPort(mdata.GetString(md, serverName))
if sn == "" {
sn = "localhost"
}
d.md.tlsConfig, err = tls_util.LoadClientConfig(
mdata.GetString(md, certFile),
mdata.GetString(md, keyFile),
mdata.GetString(md, caFile),
mdata.GetBool(md, secure),
sn,
)
d.md.handshakeTimeout = mdata.GetDuration(md, handshakeTimeout)
d.md.muxKeepAliveDisabled = mdata.GetBool(md, muxKeepAliveDisabled)

View File

@ -23,28 +23,31 @@ type wsDialer struct {
tlsEnabled bool
logger logger.Logger
md metadata
options dialer.Options
}
func NewDialer(opts ...dialer.Option) dialer.Dialer {
options := &dialer.Options{}
options := dialer.Options{}
for _, opt := range opts {
opt(options)
opt(&options)
}
return &wsDialer{
logger: options.Logger,
logger: options.Logger,
options: options,
}
}
func NewTLSDialer(opts ...dialer.Option) dialer.Dialer {
options := &dialer.Options{}
options := dialer.Options{}
for _, opt := range opts {
opt(options)
opt(&options)
}
return &wsDialer{
tlsEnabled: true,
logger: options.Logger,
options: options,
}
}
@ -96,7 +99,7 @@ func (d *wsDialer) Handshake(ctx context.Context, conn net.Conn, options ...dial
url := url.URL{Scheme: "ws", Host: host, Path: d.md.path}
if d.tlsEnabled {
url.Scheme = "wss"
dialer.TLSClientConfig = d.md.tlsConfig
dialer.TLSClientConfig = d.options.TLSConfig
}
c, resp, err := dialer.Dial(url.String(), d.md.header)

View File

@ -1,12 +1,9 @@
package ws
import (
"crypto/tls"
"net"
"net/http"
"time"
tls_util "github.com/go-gost/gost/pkg/common/util/tls"
mdata "github.com/go-gost/gost/pkg/metadata"
)
@ -15,9 +12,8 @@ const (
)
type metadata struct {
path string
host string
tlsConfig *tls.Config
host string
path string
handshakeTimeout time.Duration
readHeaderTimeout time.Duration
@ -30,14 +26,8 @@ type metadata struct {
func (d *wsDialer) parseMetadata(md mdata.Metadata) (err error) {
const (
path = "path"
host = "host"
certFile = "certFile"
keyFile = "keyFile"
caFile = "caFile"
secure = "secure"
serverName = "serverName"
path = "path"
handshakeTimeout = "handshakeTimeout"
readHeaderTimeout = "readHeaderTimeout"
@ -48,25 +38,13 @@ func (d *wsDialer) parseMetadata(md mdata.Metadata) (err error) {
header = "header"
)
d.md.host = mdata.GetString(md, host)
d.md.path = mdata.GetString(md, path)
if d.md.path == "" {
d.md.path = defaultPath
}
d.md.host = mdata.GetString(md, host)
sn, _, _ := net.SplitHostPort(mdata.GetString(md, serverName))
if sn == "" {
sn = "localhost"
}
d.md.tlsConfig, err = tls_util.LoadClientConfig(
mdata.GetString(md, certFile),
mdata.GetString(md, keyFile),
mdata.GetString(md, caFile),
mdata.GetBool(md, secure),
sn,
)
d.md.handshakeTimeout = mdata.GetDuration(md, handshakeTimeout)
d.md.readHeaderTimeout = mdata.GetDuration(md, readHeaderTimeout)
d.md.readBufferSize = mdata.GetInt(md, readBufferSize)

View File

@ -25,33 +25,36 @@ func init() {
type mwsDialer struct {
sessions map[string]*muxSession
sessionMutex sync.Mutex
tlsEnabled bool
logger logger.Logger
md metadata
tlsEnabled bool
options dialer.Options
}
func NewDialer(opts ...dialer.Option) dialer.Dialer {
options := &dialer.Options{}
options := dialer.Options{}
for _, opt := range opts {
opt(options)
opt(&options)
}
return &mwsDialer{
sessions: make(map[string]*muxSession),
logger: options.Logger,
options: options,
}
}
func NewTLSDialer(opts ...dialer.Option) dialer.Dialer {
options := &dialer.Options{}
options := dialer.Options{}
for _, opt := range opts {
opt(options)
opt(&options)
}
return &mwsDialer{
tlsEnabled: true,
sessions: make(map[string]*muxSession),
logger: options.Logger,
tlsEnabled: true,
options: options,
}
}
func (d *mwsDialer) Init(md md.Metadata) (err error) {
@ -182,7 +185,7 @@ func (d *mwsDialer) initSession(ctx context.Context, host string, conn net.Conn)
url := url.URL{Scheme: "ws", Host: host, Path: d.md.path}
if d.tlsEnabled {
url.Scheme = "wss"
dialer.TLSClientConfig = d.md.tlsConfig
dialer.TLSClientConfig = d.options.TLSConfig
}
c, resp, err := dialer.Dial(url.String(), d.md.header)

View File

@ -1,12 +1,9 @@
package mux
import (
"crypto/tls"
"net"
"net/http"
"time"
tls_util "github.com/go-gost/gost/pkg/common/util/tls"
mdata "github.com/go-gost/gost/pkg/metadata"
)
@ -15,9 +12,8 @@ const (
)
type metadata struct {
path string
host string
tlsConfig *tls.Config
host string
path string
handshakeTimeout time.Duration
readHeaderTimeout time.Duration
@ -37,14 +33,8 @@ type metadata struct {
func (d *mwsDialer) parseMetadata(md mdata.Metadata) (err error) {
const (
path = "path"
host = "host"
certFile = "certFile"
keyFile = "keyFile"
caFile = "caFile"
secure = "secure"
serverName = "serverName"
path = "path"
handshakeTimeout = "handshakeTimeout"
readHeaderTimeout = "readHeaderTimeout"
@ -62,25 +52,13 @@ func (d *mwsDialer) parseMetadata(md mdata.Metadata) (err error) {
muxMaxStreamBuffer = "muxMaxStreamBuffer"
)
d.md.host = mdata.GetString(md, host)
d.md.path = mdata.GetString(md, path)
if d.md.path == "" {
d.md.path = defaultPath
}
d.md.host = mdata.GetString(md, host)
sn, _, _ := net.SplitHostPort(mdata.GetString(md, serverName))
if sn == "" {
sn = "localhost"
}
d.md.tlsConfig, err = tls_util.LoadClientConfig(
mdata.GetString(md, certFile),
mdata.GetString(md, keyFile),
mdata.GetString(md, caFile),
mdata.GetBool(md, secure),
sn,
)
d.md.muxKeepAliveDisabled = mdata.GetBool(md, muxKeepAliveDisabled)
d.md.muxKeepAliveInterval = mdata.GetDuration(md, muxKeepAliveInterval)
d.md.muxKeepAliveTimeout = mdata.GetDuration(md, muxKeepAliveTimeout)