add tls config option

This commit is contained in:
ginuerzh 2022-01-05 00:02:55 +08:00
parent c428b37a36
commit 3b48c4acfb
43 changed files with 395 additions and 496 deletions

View File

@ -1,6 +1,7 @@
package main
import (
"encoding/base64"
"errors"
"fmt"
"net/url"
@ -115,21 +116,46 @@ func buildServiceConfig(url *url.URL) (*config.ServiceConfig, error) {
}
}
var auths []*config.AuthConfig
if url.User != nil {
auth := &config.AuthConfig{
Username: url.User.Username(),
}
auth.Password, _ = url.User.Password()
auths = append(auths, auth)
}
md := make(map[string]interface{})
for k, v := range url.Query() {
if len(v) > 0 {
md[k] = v[0]
}
}
var auths []config.AuthConfig
if url.User != nil {
auth := config.AuthConfig{
Username: url.User.Username(),
if sauth := md["auth"]; sauth != nil {
if sa, _ := sauth.(string); sa != "" {
au, err := parseAuthFromCmd(sa)
if err != nil {
return nil, err
}
auths = append(auths, au)
}
auth.Password, _ = url.User.Password()
auths = append(auths, auth)
}
delete(md, "auth")
var tlsConfig *config.TLSConfig
if certs := md["cert"]; certs != nil {
cert, _ := certs.(string)
key, _ := md["key"].(string)
ca, _ := md["ca"].(string)
tlsConfig = &config.TLSConfig{
Cert: cert,
Key: key,
CA: ca,
}
}
delete(md, "cert")
delete(md, "key")
delete(md, "ca")
svc.Handler = &config.HandlerConfig{
Type: handler,
@ -138,6 +164,7 @@ func buildServiceConfig(url *url.URL) (*config.ServiceConfig, error) {
}
svc.Listener = &config.ListenerConfig{
Type: listener,
TLS: tlsConfig,
Metadata: md,
}
@ -170,14 +197,6 @@ func buildNodeConfig(url *url.URL) (*config.NodeConfig, error) {
}
}
md := make(map[string]interface{})
for k, v := range url.Query() {
if len(v) > 0 {
md[k] = v[0]
}
}
md["serverName"] = url.Host
var auth *config.AuthConfig
if url.User != nil {
auth = &config.AuthConfig{
@ -186,6 +205,46 @@ func buildNodeConfig(url *url.URL) (*config.NodeConfig, error) {
auth.Password, _ = url.User.Password()
}
md := make(map[string]interface{})
for k, v := range url.Query() {
if len(v) > 0 {
md[k] = v[0]
}
}
md["serverName"] = url.Host
if sauth := md["auth"]; sauth != nil && auth == nil {
if sa, _ := sauth.(string); sa != "" {
au, err := parseAuthFromCmd(sa)
if err != nil {
return nil, err
}
auth = au
}
}
delete(md, "auth")
var tlsConfig *config.TLSConfig
if certs := md["cert"]; certs != nil {
cert, _ := certs.(string)
key, _ := md["key"].(string)
ca, _ := md["ca"].(string)
secure, _ := md["secure"].(bool)
serverName, _ := md["serverName"].(string)
tlsConfig = &config.TLSConfig{
Cert: cert,
Key: key,
CA: ca,
Secure: secure,
ServerName: serverName,
}
}
delete(md, "cert")
delete(md, "key")
delete(md, "ca")
delete(md, "secure")
delete(md, "serverName")
node.Connector = &config.ConnectorConfig{
Type: connector,
Auth: auth,
@ -193,6 +252,7 @@ func buildNodeConfig(url *url.URL) (*config.NodeConfig, error) {
}
node.Dialer = &config.DialerConfig{
Type: dialer,
TLS: tlsConfig,
Metadata: md,
}
@ -209,5 +269,32 @@ func normCmd(s string) (*url.URL, error) {
s = "auto://" + s
}
return url.Parse(s)
url, err := url.Parse(s)
if err != nil {
return nil, err
}
if url.Scheme == "https" {
url.Scheme = "http+tls"
}
return url, nil
}
func parseAuthFromCmd(sa string) (*config.AuthConfig, error) {
v, err := base64.StdEncoding.DecodeString(sa)
if err != nil {
return nil, err
}
cs := string(v)
n := strings.IndexByte(cs, ':')
if n < 0 {
return &config.AuthConfig{
Username: cs,
}, nil
}
return &config.AuthConfig{
Username: cs[:n],
Password: cs[n+1:],
}, nil
}

View File

@ -1,6 +1,7 @@
package main
import (
"crypto/tls"
"io"
"net"
"net/url"
@ -68,9 +69,23 @@ func buildService(cfg *config.Config) (services []*service.Service) {
listenerLogger := serviceLogger.WithFields(map[string]interface{}{
"kind": "listener",
})
var tlsConfig *tls.Config
var err error
tlsCfg := svc.Listener.TLS
if tlsCfg == nil {
tlsCfg = &config.TLSConfig{}
}
tlsConfig, err = loadServerTLSConfig(tlsCfg)
if err != nil {
log.Fatal(err)
}
ln := registry.GetListener(svc.Listener.Type)(
listener.AddrOption(svc.Addr),
listener.AuthsOption(authsFromConfig(svc.Listener.Auths...)...),
listener.TLSConfigOption(tlsConfig),
listener.LoggerOption(listenerLogger),
)
@ -89,6 +104,16 @@ func buildService(cfg *config.Config) (services []*service.Service) {
"kind": "handler",
})
tlsConfig = nil
tlsCfg = svc.Handler.TLS
if tlsCfg == nil {
tlsCfg = &config.TLSConfig{}
}
tlsConfig, err = loadServerTLSConfig(tlsCfg)
if err != nil {
log.Fatal(err)
}
h := registry.GetHandler(svc.Handler.Type)(
handler.RetriesOption(svc.Handler.Retries),
handler.ChainOption(chains[svc.Handler.Chain]),
@ -96,6 +121,7 @@ func buildService(cfg *config.Config) (services []*service.Service) {
handler.HostsOption(hosts[svc.Handler.Hosts]),
handler.BypassOption(bypasses[svc.Handler.Bypass]),
handler.AuthsOption(authsFromConfig(svc.Handler.Auths...)...),
handler.TLSConfigOption(tlsConfig),
handler.LoggerOption(handlerLogger),
)
@ -148,16 +174,29 @@ func chainFromConfig(cfg *config.ChainConfig) *chain.Chain {
"kind": "connector",
})
var connectorUser *url.Userinfo
var user *url.Userinfo
if auth := v.Connector.Auth; auth != nil && auth.Username != "" {
if auth.Password == "" {
connectorUser = url.User(auth.Username)
user = url.User(auth.Username)
} else {
connectorUser = url.UserPassword(auth.Username, auth.Password)
user = url.UserPassword(auth.Username, auth.Password)
}
}
var tlsConfig *tls.Config
var err error
tlsCfg := v.Connector.TLS
if tlsCfg == nil {
tlsCfg = &config.TLSConfig{}
}
tlsConfig, err = loadClientTLSConfig(tlsCfg)
if err != nil {
log.Fatal(err)
}
cr := registry.GetConnector(v.Connector.Type)(
connector.UserOption(connectorUser),
connector.UserOption(user),
connector.TLSConfigOption(tlsConfig),
connector.LoggerOption(connectorLogger),
)
@ -172,16 +211,28 @@ func chainFromConfig(cfg *config.ChainConfig) *chain.Chain {
"kind": "dialer",
})
var dialerUser *url.Userinfo
user = nil
if auth := v.Dialer.Auth; auth != nil && auth.Username != "" {
if auth.Password == "" {
dialerUser = url.User(auth.Username)
user = url.User(auth.Username)
} else {
dialerUser = url.UserPassword(auth.Username, auth.Password)
user = url.UserPassword(auth.Username, auth.Password)
}
}
tlsConfig = nil
tlsCfg = v.Dialer.TLS
if tlsCfg == nil {
tlsCfg = &config.TLSConfig{}
}
tlsConfig, err = loadClientTLSConfig(tlsCfg)
if err != nil {
log.Fatal(err)
}
d := registry.GetDialer(v.Dialer.Type)(
dialer.UserOption(dialerUser),
dialer.UserOption(user),
dialer.TLSConfigOption(tlsConfig),
dialer.LoggerOption(dialerLogger),
)
@ -328,11 +379,11 @@ func hostsFromConfig(cfg *config.HostsConfig) hostspkg.HostMapper {
return hosts
}
func authsFromConfig(cfgs ...config.AuthConfig) []*url.Userinfo {
func authsFromConfig(cfgs ...*config.AuthConfig) []*url.Userinfo {
var auths []*url.Userinfo
for _, cfg := range cfgs {
if cfg.Username == "" {
if cfg == nil || cfg.Username == "" {
continue
}
auths = append(auths, url.UserPassword(cfg.Username, cfg.Password))

View File

@ -14,6 +14,14 @@ import (
"github.com/go-gost/gost/pkg/config"
)
func loadServerTLSConfig(cfg *config.TLSConfig) (*tls.Config, error) {
return tls_util.LoadServerConfig(cfg.Cert, cfg.Key, cfg.CA)
}
func loadClientTLSConfig(cfg *config.TLSConfig) (*tls.Config, error) {
return tls_util.LoadClientConfig(cfg.Cert, cfg.Key, cfg.CA, cfg.Secure, cfg.ServerName)
}
func buildDefaultTLSConfig(cfg *config.TLSConfig) {
if cfg == nil {
cfg = &config.TLSConfig{

View File

@ -283,9 +283,9 @@ bypasses:
# http://www.iana.org/assignments/multicast-addresses/multicast-addresses.xhtml
- 224.0.0.0/4 # RFC5771: Multicast/Reserved
# tls:
# cert: "cert.pem"
# key: "key.pem"
tls:
cert: "cert.pem"
key: "key.pem"
# ca: "root.ca"
resolvers:

View File

@ -31,9 +31,11 @@ type ProfilingConfig struct {
}
type TLSConfig struct {
Cert string
Key string
CA string
Cert string
Key string
CA string `yaml:",omitempty"`
Secure bool `yaml:",omitempty"`
ServerName string `yaml:",omitempty"`
}
type AuthConfig struct {
@ -82,7 +84,8 @@ type HostsConfig struct {
type ListenerConfig struct {
Type string
Chain string `yaml:",omitempty"`
Auths []AuthConfig `yaml:",omitempty"`
Auths []*AuthConfig `yaml:",omitempty"`
TLS *TLSConfig `yaml:",omitempty"`
Metadata map[string]interface{} `yaml:",omitempty"`
}
@ -93,7 +96,8 @@ type HandlerConfig struct {
Bypass string `yaml:",omitempty"`
Resolver string `yaml:",omitempty"`
Hosts string `yaml:",omitempty"`
Auths []AuthConfig `yaml:",omitempty"`
Auths []*AuthConfig `yaml:",omitempty"`
TLS *TLSConfig `yaml:",omitempty"`
Metadata map[string]interface{} `yaml:",omitempty"`
}
@ -105,12 +109,14 @@ type ForwarderConfig struct {
type DialerConfig struct {
Type string
Auth *AuthConfig `yaml:",omitempty"`
TLS *TLSConfig `yaml:",omitempty"`
Metadata map[string]interface{} `yaml:",omitempty"`
}
type ConnectorConfig struct {
Type string
Auth *AuthConfig `yaml:",omitempty"`
TLS *TLSConfig `yaml:",omitempty"`
Metadata map[string]interface{} `yaml:",omitempty"`
}

View File

@ -1,6 +1,7 @@
package connector
import (
"crypto/tls"
"net/url"
"time"
@ -8,8 +9,9 @@ import (
)
type Options struct {
User *url.Userinfo
Logger logger.Logger
User *url.Userinfo
TLSConfig *tls.Config
Logger logger.Logger
}
type Option func(opts *Options)
@ -20,6 +22,12 @@ func UserOption(user *url.Userinfo) Option {
}
}
func TLSConfigOption(tlsConfig *tls.Config) Option {
return func(opts *Options) {
opts.TLSConfig = tlsConfig
}
}
func LoggerOption(logger logger.Logger) Option {
return func(opts *Options) {
opts.Logger = logger

View File

@ -53,7 +53,7 @@ func (c *socks5Connector) Init(md md.Metadata) (err error) {
},
logger: c.logger,
User: c.options.User,
TLSConfig: c.md.tlsConfig,
TLSConfig: c.options.TLSConfig,
}
if !c.md.noTLS {
selector.methods = append(selector.methods, socks.MethodTLS)

View File

@ -1,7 +1,6 @@
package v5
import (
"crypto/tls"
"time"
mdata "github.com/go-gost/gost/pkg/metadata"
@ -9,7 +8,6 @@ import (
type metadata struct {
connectTimeout time.Duration
tlsConfig *tls.Config
noTLS bool
}

View File

@ -19,21 +19,23 @@ func init() {
}
type http2Dialer struct {
md metadata
clients map[string]*http.Client
clientMutex sync.Mutex
logger logger.Logger
md metadata
options dialer.Options
}
func NewDialer(opts ...dialer.Option) dialer.Dialer {
options := &dialer.Options{}
options := dialer.Options{}
for _, opt := range opts {
opt(options)
opt(&options)
}
return &http2Dialer{
clients: make(map[string]*http.Client),
logger: options.Logger,
options: options,
}
}
@ -69,7 +71,7 @@ func (d *http2Dialer) Dial(ctx context.Context, address string, opts ...dialer.D
if !ok {
client = &http.Client{
Transport: &http.Transport{
TLSClientConfig: d.md.tlsConfig,
TLSClientConfig: d.options.TLSConfig,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return d.dial(ctx, network, addr, options)
},

View File

@ -27,33 +27,36 @@ func init() {
type h2Dialer struct {
clients map[string]*http.Client
clientMutex sync.Mutex
h2c bool
logger logger.Logger
md metadata
h2c bool
options dialer.Options
}
func NewDialer(opts ...dialer.Option) dialer.Dialer {
options := &dialer.Options{}
options := dialer.Options{}
for _, opt := range opts {
opt(options)
opt(&options)
}
return &h2Dialer{
h2c: true,
clients: make(map[string]*http.Client),
logger: options.Logger,
h2c: true,
options: options,
}
}
func NewTLSDialer(opts ...dialer.Option) dialer.Dialer {
options := &dialer.Options{}
options := dialer.Options{}
for _, opt := range opts {
opt(options)
opt(&options)
}
return &h2Dialer{
clients: make(map[string]*http.Client),
logger: options.Logger,
options: options,
}
}
@ -95,7 +98,7 @@ func (d *h2Dialer) Dial(ctx context.Context, address string, opts ...dialer.Dial
}
} else {
client.Transport = &http.Transport{
TLSClientConfig: d.md.tlsConfig,
TLSClientConfig: d.options.TLSConfig,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return d.dial(ctx, network, addr, options)
},

View File

@ -1,42 +1,21 @@
package h2
import (
"crypto/tls"
"net"
tls_util "github.com/go-gost/gost/pkg/common/util/tls"
mdata "github.com/go-gost/gost/pkg/metadata"
)
type metadata struct {
path string
host string
tlsConfig *tls.Config
host string
path string
}
func (d *h2Dialer) parseMetadata(md mdata.Metadata) (err error) {
const (
certFile = "certFile"
keyFile = "keyFile"
caFile = "caFile"
secure = "secure"
serverName = "serverName"
path = "path"
)
d.md.host = mdata.GetString(md, serverName)
sn, _, _ := net.SplitHostPort(d.md.host)
if sn == "" {
sn = "localhost"
}
d.md.tlsConfig, err = tls_util.LoadClientConfig(
mdata.GetString(md, certFile),
mdata.GetString(md, keyFile),
mdata.GetString(md, caFile),
mdata.GetBool(md, secure),
sn,
host = "host"
path = "path"
)
d.md.host = mdata.GetString(md, host)
d.md.path = mdata.GetString(md, path)
return

View File

@ -1,37 +1,12 @@
package http2
import (
"crypto/tls"
"net"
tls_util "github.com/go-gost/gost/pkg/common/util/tls"
mdata "github.com/go-gost/gost/pkg/metadata"
)
type metadata struct {
tlsConfig *tls.Config
}
func (d *http2Dialer) parseMetadata(md mdata.Metadata) (err error) {
const (
certFile = "certFile"
keyFile = "keyFile"
caFile = "caFile"
secure = "secure"
serverName = "serverName"
)
sn, _, _ := net.SplitHostPort(mdata.GetString(md, serverName))
if sn == "" {
sn = "localhost"
}
d.md.tlsConfig, err = tls_util.LoadClientConfig(
mdata.GetString(md, certFile),
mdata.GetString(md, keyFile),
mdata.GetString(md, caFile),
mdata.GetBool(md, secure),
sn,
)
return
}

View File

@ -2,6 +2,7 @@ package dialer
import (
"context"
"crypto/tls"
"net"
"net/url"
@ -9,8 +10,9 @@ import (
)
type Options struct {
User *url.Userinfo
Logger logger.Logger
User *url.Userinfo
TLSConfig *tls.Config
Logger logger.Logger
}
type Option func(opts *Options)
@ -21,6 +23,12 @@ func UserOption(user *url.Userinfo) Option {
}
}
func TLSConfigOption(tlsConfig *tls.Config) Option {
return func(opts *Options) {
opts.TLSConfig = tlsConfig
}
}
func LoggerOption(logger logger.Logger) Option {
return func(opts *Options) {
opts.Logger = logger

View File

@ -24,17 +24,19 @@ type quicDialer struct {
sessionMutex sync.Mutex
logger logger.Logger
md metadata
options dialer.Options
}
func NewDialer(opts ...dialer.Option) dialer.Dialer {
options := &dialer.Options{}
options := dialer.Options{}
for _, opt := range opts {
opt(options)
opt(&options)
}
return &quicDialer{
sessions: make(map[string]*quicSession),
logger: options.Logger,
options: options,
}
}
@ -141,7 +143,7 @@ func (d *quicDialer) initSession(ctx context.Context, addr string, conn net.Conn
},
}
tlsCfg := d.md.tlsConfig
tlsCfg := d.options.TLSConfig
tlsCfg.NextProtos = []string{"http/3", "quic/v1"}
session, err := quic.DialContext(ctx, pc, udpAddr, addr, tlsCfg, quicConfig)

View File

@ -1,11 +1,8 @@
package quic
import (
"crypto/tls"
"net"
"time"
tls_util "github.com/go-gost/gost/pkg/common/util/tls"
mdata "github.com/go-gost/gost/pkg/metadata"
)
@ -15,7 +12,6 @@ type metadata struct {
handshakeTimeout time.Duration
cipherKey []byte
tlsConfig *tls.Config
}
func (d *quicDialer) parseMetadata(md mdata.Metadata) (err error) {
@ -24,12 +20,6 @@ func (d *quicDialer) parseMetadata(md mdata.Metadata) (err error) {
handshakeTimeout = "handshakeTimeout"
maxIdleTimeout = "maxIdleTimeout"
certFile = "certFile"
keyFile = "keyFile"
caFile = "caFile"
secure = "secure"
serverName = "serverName"
cipherKey = "cipherKey"
)
@ -39,18 +29,6 @@ func (d *quicDialer) parseMetadata(md mdata.Metadata) (err error) {
d.md.cipherKey = []byte(key)
}
sn, _, _ := net.SplitHostPort(mdata.GetString(md, serverName))
if sn == "" {
sn = "localhost"
}
d.md.tlsConfig, err = tls_util.LoadClientConfig(
mdata.GetString(md, certFile),
mdata.GetString(md, keyFile),
mdata.GetString(md, caFile),
mdata.GetBool(md, secure),
sn,
)
d.md.keepAlive = mdata.GetBool(md, keepAlive)
d.md.handshakeTimeout = mdata.GetDuration(md, handshakeTimeout)
d.md.maxIdleTimeout = mdata.GetDuration(md, maxIdleTimeout)

View File

@ -17,18 +17,20 @@ func init() {
}
type tlsDialer struct {
md metadata
logger logger.Logger
md metadata
logger logger.Logger
options dialer.Options
}
func NewDialer(opts ...dialer.Option) dialer.Dialer {
options := &dialer.Options{}
options := dialer.Options{}
for _, opt := range opts {
opt(options)
opt(&options)
}
return &tlsDialer{
logger: options.Logger,
logger: options.Logger,
options: options,
}
}
@ -57,7 +59,7 @@ func (d *tlsDialer) Handshake(ctx context.Context, conn net.Conn, options ...dia
defer conn.SetDeadline(time.Time{})
}
tlsConn := tls.Client(conn, d.md.tlsConfig)
tlsConn := tls.Client(conn, d.options.TLSConfig)
if err := tlsConn.HandshakeContext(ctx); err != nil {
conn.Close()
return nil, err

View File

@ -1,42 +1,20 @@
package tls
import (
"crypto/tls"
"net"
"time"
tls_util "github.com/go-gost/gost/pkg/common/util/tls"
mdata "github.com/go-gost/gost/pkg/metadata"
)
type metadata struct {
tlsConfig *tls.Config
handshakeTimeout time.Duration
}
func (d *tlsDialer) parseMetadata(md mdata.Metadata) (err error) {
const (
certFile = "certFile"
keyFile = "keyFile"
caFile = "caFile"
secure = "secure"
serverName = "serverName"
handshakeTimeout = "handshakeTimeout"
)
sn, _, _ := net.SplitHostPort(mdata.GetString(md, serverName))
if sn == "" {
sn = "localhost"
}
d.md.tlsConfig, err = tls_util.LoadClientConfig(
mdata.GetString(md, certFile),
mdata.GetString(md, keyFile),
mdata.GetString(md, caFile),
mdata.GetBool(md, secure),
sn,
)
d.md.handshakeTimeout = mdata.GetDuration(md, handshakeTimeout)
return

View File

@ -24,17 +24,19 @@ type mtlsDialer struct {
sessionMutex sync.Mutex
logger logger.Logger
md metadata
options dialer.Options
}
func NewDialer(opts ...dialer.Option) dialer.Dialer {
options := &dialer.Options{}
options := dialer.Options{}
for _, opt := range opts {
opt(options)
opt(&options)
}
return &mtlsDialer{
sessions: make(map[string]*muxSession),
logger: options.Logger,
options: options,
}
}
@ -149,7 +151,7 @@ func (d *mtlsDialer) dial(ctx context.Context, network, addr string, opts *diale
}
func (d *mtlsDialer) initSession(ctx context.Context, conn net.Conn) (*muxSession, error) {
tlsConn := tls.Client(conn, d.md.tlsConfig)
tlsConn := tls.Client(conn, d.options.TLSConfig)
if err := tlsConn.HandshakeContext(ctx); err != nil {
return nil, err
}

View File

@ -1,16 +1,12 @@
package mux
import (
"crypto/tls"
"net"
"time"
tls_util "github.com/go-gost/gost/pkg/common/util/tls"
mdata "github.com/go-gost/gost/pkg/metadata"
)
type metadata struct {
tlsConfig *tls.Config
handshakeTimeout time.Duration
muxKeepAliveDisabled bool
@ -23,12 +19,6 @@ type metadata struct {
func (d *mtlsDialer) parseMetadata(md mdata.Metadata) (err error) {
const (
certFile = "certFile"
keyFile = "keyFile"
caFile = "caFile"
secure = "secure"
serverName = "serverName"
handshakeTimeout = "handshakeTimeout"
muxKeepAliveDisabled = "muxKeepAliveDisabled"
@ -39,17 +29,6 @@ func (d *mtlsDialer) parseMetadata(md mdata.Metadata) (err error) {
muxMaxStreamBuffer = "muxMaxStreamBuffer"
)
sn, _, _ := net.SplitHostPort(mdata.GetString(md, serverName))
if sn == "" {
sn = "localhost"
}
d.md.tlsConfig, err = tls_util.LoadClientConfig(
mdata.GetString(md, certFile),
mdata.GetString(md, keyFile),
mdata.GetString(md, caFile),
mdata.GetBool(md, secure),
sn,
)
d.md.handshakeTimeout = mdata.GetDuration(md, handshakeTimeout)
d.md.muxKeepAliveDisabled = mdata.GetBool(md, muxKeepAliveDisabled)

View File

@ -23,28 +23,31 @@ type wsDialer struct {
tlsEnabled bool
logger logger.Logger
md metadata
options dialer.Options
}
func NewDialer(opts ...dialer.Option) dialer.Dialer {
options := &dialer.Options{}
options := dialer.Options{}
for _, opt := range opts {
opt(options)
opt(&options)
}
return &wsDialer{
logger: options.Logger,
logger: options.Logger,
options: options,
}
}
func NewTLSDialer(opts ...dialer.Option) dialer.Dialer {
options := &dialer.Options{}
options := dialer.Options{}
for _, opt := range opts {
opt(options)
opt(&options)
}
return &wsDialer{
tlsEnabled: true,
logger: options.Logger,
options: options,
}
}
@ -96,7 +99,7 @@ func (d *wsDialer) Handshake(ctx context.Context, conn net.Conn, options ...dial
url := url.URL{Scheme: "ws", Host: host, Path: d.md.path}
if d.tlsEnabled {
url.Scheme = "wss"
dialer.TLSClientConfig = d.md.tlsConfig
dialer.TLSClientConfig = d.options.TLSConfig
}
c, resp, err := dialer.Dial(url.String(), d.md.header)

View File

@ -1,12 +1,9 @@
package ws
import (
"crypto/tls"
"net"
"net/http"
"time"
tls_util "github.com/go-gost/gost/pkg/common/util/tls"
mdata "github.com/go-gost/gost/pkg/metadata"
)
@ -15,9 +12,8 @@ const (
)
type metadata struct {
path string
host string
tlsConfig *tls.Config
host string
path string
handshakeTimeout time.Duration
readHeaderTimeout time.Duration
@ -30,14 +26,8 @@ type metadata struct {
func (d *wsDialer) parseMetadata(md mdata.Metadata) (err error) {
const (
path = "path"
host = "host"
certFile = "certFile"
keyFile = "keyFile"
caFile = "caFile"
secure = "secure"
serverName = "serverName"
path = "path"
handshakeTimeout = "handshakeTimeout"
readHeaderTimeout = "readHeaderTimeout"
@ -48,25 +38,13 @@ func (d *wsDialer) parseMetadata(md mdata.Metadata) (err error) {
header = "header"
)
d.md.host = mdata.GetString(md, host)
d.md.path = mdata.GetString(md, path)
if d.md.path == "" {
d.md.path = defaultPath
}
d.md.host = mdata.GetString(md, host)
sn, _, _ := net.SplitHostPort(mdata.GetString(md, serverName))
if sn == "" {
sn = "localhost"
}
d.md.tlsConfig, err = tls_util.LoadClientConfig(
mdata.GetString(md, certFile),
mdata.GetString(md, keyFile),
mdata.GetString(md, caFile),
mdata.GetBool(md, secure),
sn,
)
d.md.handshakeTimeout = mdata.GetDuration(md, handshakeTimeout)
d.md.readHeaderTimeout = mdata.GetDuration(md, readHeaderTimeout)
d.md.readBufferSize = mdata.GetInt(md, readBufferSize)

View File

@ -25,33 +25,36 @@ func init() {
type mwsDialer struct {
sessions map[string]*muxSession
sessionMutex sync.Mutex
tlsEnabled bool
logger logger.Logger
md metadata
tlsEnabled bool
options dialer.Options
}
func NewDialer(opts ...dialer.Option) dialer.Dialer {
options := &dialer.Options{}
options := dialer.Options{}
for _, opt := range opts {
opt(options)
opt(&options)
}
return &mwsDialer{
sessions: make(map[string]*muxSession),
logger: options.Logger,
options: options,
}
}
func NewTLSDialer(opts ...dialer.Option) dialer.Dialer {
options := &dialer.Options{}
options := dialer.Options{}
for _, opt := range opts {
opt(options)
opt(&options)
}
return &mwsDialer{
tlsEnabled: true,
sessions: make(map[string]*muxSession),
logger: options.Logger,
tlsEnabled: true,
options: options,
}
}
func (d *mwsDialer) Init(md md.Metadata) (err error) {
@ -182,7 +185,7 @@ func (d *mwsDialer) initSession(ctx context.Context, host string, conn net.Conn)
url := url.URL{Scheme: "ws", Host: host, Path: d.md.path}
if d.tlsEnabled {
url.Scheme = "wss"
dialer.TLSClientConfig = d.md.tlsConfig
dialer.TLSClientConfig = d.options.TLSConfig
}
c, resp, err := dialer.Dial(url.String(), d.md.header)

View File

@ -1,12 +1,9 @@
package mux
import (
"crypto/tls"
"net"
"net/http"
"time"
tls_util "github.com/go-gost/gost/pkg/common/util/tls"
mdata "github.com/go-gost/gost/pkg/metadata"
)
@ -15,9 +12,8 @@ const (
)
type metadata struct {
path string
host string
tlsConfig *tls.Config
host string
path string
handshakeTimeout time.Duration
readHeaderTimeout time.Duration
@ -37,14 +33,8 @@ type metadata struct {
func (d *mwsDialer) parseMetadata(md mdata.Metadata) (err error) {
const (
path = "path"
host = "host"
certFile = "certFile"
keyFile = "keyFile"
caFile = "caFile"
secure = "secure"
serverName = "serverName"
path = "path"
handshakeTimeout = "handshakeTimeout"
readHeaderTimeout = "readHeaderTimeout"
@ -62,25 +52,13 @@ func (d *mwsDialer) parseMetadata(md mdata.Metadata) (err error) {
muxMaxStreamBuffer = "muxMaxStreamBuffer"
)
d.md.host = mdata.GetString(md, host)
d.md.path = mdata.GetString(md, path)
if d.md.path == "" {
d.md.path = defaultPath
}
d.md.host = mdata.GetString(md, host)
sn, _, _ := net.SplitHostPort(mdata.GetString(md, serverName))
if sn == "" {
sn = "localhost"
}
d.md.tlsConfig, err = tls_util.LoadClientConfig(
mdata.GetString(md, certFile),
mdata.GetString(md, keyFile),
mdata.GetString(md, caFile),
mdata.GetBool(md, secure),
sn,
)
d.md.muxKeepAliveDisabled = mdata.GetBool(md, muxKeepAliveDisabled)
d.md.muxKeepAliveInterval = mdata.GetDuration(md, muxKeepAliveInterval)
d.md.muxKeepAliveTimeout = mdata.GetDuration(md, muxKeepAliveTimeout)

View File

@ -1,6 +1,7 @@
package handler
import (
"crypto/tls"
"net/url"
"github.com/go-gost/gost/pkg/bypass"
@ -11,13 +12,14 @@ import (
)
type Options struct {
Retries int
Chain *chain.Chain
Resolver resolver.Resolver
Hosts hosts.HostMapper
Bypass bypass.Bypass
Auths []*url.Userinfo
Logger logger.Logger
Retries int
Chain *chain.Chain
Resolver resolver.Resolver
Hosts hosts.HostMapper
Bypass bypass.Bypass
Auths []*url.Userinfo
TLSConfig *tls.Config
Logger logger.Logger
}
type Option func(opts *Options)
@ -58,6 +60,12 @@ func AuthsOption(auths ...*url.Userinfo) Option {
}
}
func TLSConfigOption(tlsConfig *tls.Config) Option {
return func(opts *Options) {
opts.TLSConfig = tlsConfig
}
}
func LoggerOption(logger logger.Logger) Option {
return func(opts *Options) {
opts.Logger = logger

View File

@ -55,7 +55,7 @@ func (h *socks5Handler) Init(md md.Metadata) (err error) {
h.selector = &serverSelector{
Authenticator: auth_util.AuthFromUsers(h.options.Auths...),
TLSConfig: h.md.tlsConfig,
TLSConfig: h.options.TLSConfig,
logger: h.logger,
noTLS: h.md.noTLS,
}

View File

@ -1,16 +1,13 @@
package v5
import (
"crypto/tls"
"math"
"time"
tls_util "github.com/go-gost/gost/pkg/common/util/tls"
mdata "github.com/go-gost/gost/pkg/metadata"
)
type metadata struct {
tlsConfig *tls.Config
timeout time.Duration
readTimeout time.Duration
noTLS bool
@ -22,9 +19,6 @@ type metadata struct {
func (h *socks5Handler) parseMetadata(md mdata.Metadata) (err error) {
const (
certFile = "certFile"
keyFile = "keyFile"
caFile = "caFile"
readTimeout = "readTimeout"
timeout = "timeout"
noTLS = "notls"
@ -34,15 +28,6 @@ func (h *socks5Handler) parseMetadata(md mdata.Metadata) (err error) {
compatibilityMode = "comp"
)
h.md.tlsConfig, err = tls_util.LoadServerConfig(
mdata.GetString(md, certFile),
mdata.GetString(md, keyFile),
mdata.GetString(md, caFile),
)
if err != nil {
return
}
h.md.readTimeout = mdata.GetDuration(md, readTimeout)
h.md.timeout = mdata.GetDuration(md, timeout)
h.md.noTLS = mdata.GetBool(md, noTLS)

View File

@ -21,23 +21,23 @@ func init() {
}
type dnsListener struct {
saddr string
addr net.Addr
server Server
cqueue chan net.Conn
errChan chan error
logger logger.Logger
md metadata
options listener.Options
}
func NewListener(opts ...listener.Option) listener.Listener {
options := &listener.Options{}
options := listener.Options{}
for _, opt := range opts {
opt(options)
opt(&options)
}
return &dnsListener{
saddr: options.Addr,
logger: options.Logger,
logger: options.Logger,
options: options,
}
}
@ -46,7 +46,7 @@ func (l *dnsListener) Init(md md.Metadata) (err error) {
return
}
l.addr, err = net.ResolveTCPAddr("tcp", l.saddr)
l.addr, err = net.ResolveTCPAddr("tcp", l.options.Addr)
if err != nil {
return err
}
@ -55,7 +55,7 @@ func (l *dnsListener) Init(md md.Metadata) (err error) {
case "tcp":
l.server = &dns.Server{
Net: "tcp",
Addr: l.saddr,
Addr: l.options.Addr,
Handler: l,
ReadTimeout: l.md.readTimeout,
WriteTimeout: l.md.writeTimeout,
@ -63,16 +63,16 @@ func (l *dnsListener) Init(md md.Metadata) (err error) {
case "tls":
l.server = &dns.Server{
Net: "tcp-tls",
Addr: l.saddr,
Addr: l.options.Addr,
Handler: l,
TLSConfig: l.md.tlsConfig,
TLSConfig: l.options.TLSConfig,
ReadTimeout: l.md.readTimeout,
WriteTimeout: l.md.writeTimeout,
}
case "https":
l.server = &dohServer{
addr: l.saddr,
tlsConfig: l.md.tlsConfig,
addr: l.options.Addr,
tlsConfig: l.options.TLSConfig,
server: &http.Server{
Handler: l,
ReadTimeout: l.md.readTimeout,
@ -80,10 +80,10 @@ func (l *dnsListener) Init(md md.Metadata) (err error) {
},
}
default:
l.addr, err = net.ResolveUDPAddr("udp", l.saddr)
l.addr, err = net.ResolveUDPAddr("udp", l.options.Addr)
l.server = &dns.Server{
Net: "udp",
Addr: l.saddr,
Addr: l.options.Addr,
Handler: l,
UDPSize: l.md.readBufferSize,
ReadTimeout: l.md.readTimeout,

View File

@ -1,10 +1,8 @@
package dns
import (
"crypto/tls"
"time"
tls_util "github.com/go-gost/gost/pkg/common/util/tls"
mdata "github.com/go-gost/gost/pkg/metadata"
)
@ -17,7 +15,6 @@ type metadata struct {
readBufferSize int
readTimeout time.Duration
writeTimeout time.Duration
tlsConfig *tls.Config
backlog int
}
@ -26,24 +23,12 @@ func (l *dnsListener) parseMetadata(md mdata.Metadata) (err error) {
mode = "mode"
readBufferSize = "readBufferSize"
certFile = "certFile"
keyFile = "keyFile"
caFile = "caFile"
backlog = "backlog"
)
l.md.mode = mdata.GetString(md, mode)
l.md.readBufferSize = mdata.GetInt(md, readBufferSize)
l.md.tlsConfig, err = tls_util.LoadServerConfig(
mdata.GetString(md, certFile),
mdata.GetString(md, keyFile),
mdata.GetString(md, caFile),
)
if err != nil {
return
}
l.md.backlog = mdata.GetInt(md, backlog)
if l.md.backlog <= 0 {
l.md.backlog = defaultBacklog

View File

@ -22,35 +22,35 @@ func init() {
type h2Listener struct {
server *http.Server
saddr string
addr net.Addr
cqueue chan net.Conn
errChan chan error
logger logger.Logger
md metadata
h2c bool
options listener.Options
}
func NewListener(opts ...listener.Option) listener.Listener {
options := &listener.Options{}
options := listener.Options{}
for _, opt := range opts {
opt(options)
opt(&options)
}
return &h2Listener{
saddr: options.Addr,
logger: options.Logger,
h2c: true,
h2c: true,
logger: options.Logger,
options: options,
}
}
func NewTLSListener(opts ...listener.Option) listener.Listener {
options := &listener.Options{}
options := listener.Options{}
for _, opt := range opts {
opt(options)
opt(&options)
}
return &h2Listener{
saddr: options.Addr,
logger: options.Logger,
logger: options.Logger,
options: options,
}
}
@ -60,10 +60,10 @@ func (l *h2Listener) Init(md md.Metadata) (err error) {
}
l.server = &http.Server{
Addr: l.saddr,
Addr: l.options.Addr,
}
ln, err := net.Listen("tcp", l.saddr)
ln, err := net.Listen("tcp", l.options.Addr)
if err != nil {
return err
}
@ -74,12 +74,12 @@ func (l *h2Listener) Init(md md.Metadata) (err error) {
http.HandlerFunc(l.handleFunc), &http2.Server{})
} else {
l.server.Handler = http.HandlerFunc(l.handleFunc)
l.server.TLSConfig = l.md.tlsConfig
l.server.TLSConfig = l.options.TLSConfig
if err := http2.ConfigureServer(l.server, nil); err != nil {
ln.Close()
return err
}
ln = tls.NewListener(ln, l.md.tlsConfig)
ln = tls.NewListener(ln, l.options.TLSConfig)
}
l.cqueue = make(chan net.Conn, l.md.backlog)

View File

@ -1,9 +1,6 @@
package h2
import (
"crypto/tls"
tls_util "github.com/go-gost/gost/pkg/common/util/tls"
mdata "github.com/go-gost/gost/pkg/metadata"
)
@ -12,29 +9,16 @@ const (
)
type metadata struct {
path string
tlsConfig *tls.Config
backlog int
path string
backlog int
}
func (l *h2Listener) parseMetadata(md mdata.Metadata) (err error) {
const (
path = "path"
certFile = "certFile"
keyFile = "keyFile"
caFile = "caFile"
backlog = "backlog"
path = "path"
backlog = "backlog"
)
l.md.tlsConfig, err = tls_util.LoadServerConfig(
mdata.GetString(md, certFile),
mdata.GetString(md, keyFile),
mdata.GetString(md, caFile),
)
if err != nil {
return
}
l.md.backlog = mdata.GetInt(md, backlog)
if l.md.backlog <= 0 {
l.md.backlog = defaultBacklog

View File

@ -20,22 +20,22 @@ func init() {
type http2Listener struct {
server *http.Server
saddr string
addr net.Addr
cqueue chan net.Conn
errChan chan error
logger logger.Logger
md metadata
options listener.Options
}
func NewListener(opts ...listener.Option) listener.Listener {
options := &listener.Options{}
options := listener.Options{}
for _, opt := range opts {
opt(options)
opt(&options)
}
return &http2Listener{
saddr: options.Addr,
logger: options.Logger,
logger: options.Logger,
options: options,
}
}
@ -45,15 +45,15 @@ func (l *http2Listener) Init(md md.Metadata) (err error) {
}
l.server = &http.Server{
Addr: l.saddr,
Addr: l.options.Addr,
Handler: http.HandlerFunc(l.handleFunc),
TLSConfig: l.md.tlsConfig,
TLSConfig: l.options.TLSConfig,
}
if err := http2.ConfigureServer(l.server, nil); err != nil {
return err
}
ln, err := net.Listen("tcp", l.saddr)
ln, err := net.Listen("tcp", l.options.Addr)
if err != nil {
return err
}
@ -63,7 +63,7 @@ func (l *http2Listener) Init(md md.Metadata) (err error) {
&util.TCPKeepAliveListener{
TCPListener: ln.(*net.TCPListener),
},
l.md.tlsConfig,
l.options.TLSConfig,
)
l.cqueue = make(chan net.Conn, l.md.backlog)

View File

@ -1,11 +1,9 @@
package http2
import (
"crypto/tls"
"net/http"
"time"
tls_util "github.com/go-gost/gost/pkg/common/util/tls"
mdata "github.com/go-gost/gost/pkg/metadata"
)
@ -15,7 +13,6 @@ const (
type metadata struct {
path string
tlsConfig *tls.Config
handshakeTimeout time.Duration
readHeaderTimeout time.Duration
readBufferSize int
@ -28,9 +25,6 @@ type metadata struct {
func (l *http2Listener) parseMetadata(md mdata.Metadata) (err error) {
const (
path = "path"
certFile = "certFile"
keyFile = "keyFile"
caFile = "caFile"
handshakeTimeout = "handshakeTimeout"
readHeaderTimeout = "readHeaderTimeout"
readBufferSize = "readBufferSize"
@ -38,15 +32,6 @@ func (l *http2Listener) parseMetadata(md mdata.Metadata) (err error) {
backlog = "backlog"
)
l.md.tlsConfig, err = tls_util.LoadServerConfig(
mdata.GetString(md, certFile),
mdata.GetString(md, keyFile),
mdata.GetString(md, caFile),
)
if err != nil {
return
}
l.md.backlog = mdata.GetInt(md, backlog)
if l.md.backlog <= 0 {
l.md.backlog = defaultBacklog

View File

@ -1,15 +1,17 @@
package listener
import (
"crypto/tls"
"net/url"
"github.com/go-gost/gost/pkg/logger"
)
type Options struct {
Addr string
Auths []*url.Userinfo
Logger logger.Logger
Addr string
Auths []*url.Userinfo
TLSConfig *tls.Config
Logger logger.Logger
}
type Option func(opts *Options)
@ -26,6 +28,12 @@ func AuthsOption(auths ...*url.Userinfo) Option {
}
}
func TLSConfigOption(tlsConfig *tls.Config) Option {
return func(opts *Options) {
opts.TLSConfig = tlsConfig
}
}
func LoggerOption(logger logger.Logger) Option {
return func(opts *Options) {
opts.Logger = logger

View File

@ -17,22 +17,22 @@ func init() {
}
type quicListener struct {
addr string
ln quic.Listener
cqueue chan net.Conn
errChan chan error
logger logger.Logger
md metadata
options listener.Options
}
func NewListener(opts ...listener.Option) listener.Listener {
options := &listener.Options{}
options := listener.Options{}
for _, opt := range opts {
opt(options)
opt(&options)
}
return &quicListener{
addr: options.Addr,
logger: options.Logger,
logger: options.Logger,
options: options,
}
}
@ -41,7 +41,7 @@ func (l *quicListener) Init(md md.Metadata) (err error) {
return
}
laddr, err := net.ResolveUDPAddr("udp", l.addr)
laddr, err := net.ResolveUDPAddr("udp", l.options.Addr)
if err != nil {
return
}
@ -67,7 +67,7 @@ func (l *quicListener) Init(md md.Metadata) (err error) {
},
}
tlsCfg := l.md.tlsConfig
tlsCfg := l.options.TLSConfig
tlsCfg.NextProtos = []string{"http/3", "quic/v1"}
ln, err := quic.Listen(conn, tlsCfg, config)

View File

@ -1,10 +1,8 @@
package quic
import (
"crypto/tls"
"time"
tls_util "github.com/go-gost/gost/pkg/common/util/tls"
mdata "github.com/go-gost/gost/pkg/metadata"
)
@ -17,7 +15,6 @@ type metadata struct {
handshakeTimeout time.Duration
maxIdleTimeout time.Duration
tlsConfig *tls.Config
cipherKey []byte
backlog int
}
@ -28,23 +25,10 @@ func (l *quicListener) parseMetadata(md mdata.Metadata) (err error) {
handshakeTimeout = "handshakeTimeout"
maxIdleTimeout = "maxIdleTimeout"
certFile = "certFile"
keyFile = "keyFile"
caFile = "caFile"
backlog = "backlog"
cipherKey = "cipherKey"
)
l.md.tlsConfig, err = tls_util.LoadServerConfig(
mdata.GetString(md, certFile),
mdata.GetString(md, keyFile),
mdata.GetString(md, caFile),
)
if err != nil {
return
}
l.md.backlog = mdata.GetInt(md, backlog)
if l.md.backlog <= 0 {
l.md.backlog = defaultBacklog

View File

@ -15,20 +15,20 @@ func init() {
}
type tlsListener struct {
addr string
net.Listener
logger logger.Logger
md metadata
logger logger.Logger
md metadata
options listener.Options
}
func NewListener(opts ...listener.Option) listener.Listener {
options := &listener.Options{}
options := listener.Options{}
for _, opt := range opts {
opt(options)
opt(&options)
}
return &tlsListener{
addr: options.Addr,
logger: options.Logger,
logger: options.Logger,
options: options,
}
}
@ -37,12 +37,12 @@ func (l *tlsListener) Init(md md.Metadata) (err error) {
return
}
ln, err := net.Listen("tcp", l.addr)
ln, err := net.Listen("tcp", l.options.Addr)
if err != nil {
return
}
l.Listener = tls.NewListener(ln, l.md.tlsConfig)
l.Listener = tls.NewListener(ln, l.options.TLSConfig)
return
}

View File

@ -1,31 +1,12 @@
package tls
import (
"crypto/tls"
tls_util "github.com/go-gost/gost/pkg/common/util/tls"
mdata "github.com/go-gost/gost/pkg/metadata"
)
type metadata struct {
tlsConfig *tls.Config
}
func (l *tlsListener) parseMetadata(md mdata.Metadata) (err error) {
const (
certFile = "certFile"
keyFile = "keyFile"
caFile = "caFile"
)
l.md.tlsConfig, err = tls_util.LoadServerConfig(
mdata.GetString(md, certFile),
mdata.GetString(md, keyFile),
mdata.GetString(md, caFile),
)
if err != nil {
return
}
return
}

View File

@ -16,22 +16,22 @@ func init() {
}
type mtlsListener struct {
addr string
net.Listener
cqueue chan net.Conn
errChan chan error
logger logger.Logger
md metadata
options listener.Options
}
func NewListener(opts ...listener.Option) listener.Listener {
options := &listener.Options{}
options := listener.Options{}
for _, opt := range opts {
opt(options)
opt(&options)
}
return &mtlsListener{
addr: options.Addr,
logger: options.Logger,
logger: options.Logger,
options: options,
}
}
@ -40,11 +40,11 @@ func (l *mtlsListener) Init(md md.Metadata) (err error) {
return
}
ln, err := net.Listen("tcp", l.addr)
ln, err := net.Listen("tcp", l.options.Addr)
if err != nil {
return
}
l.Listener = tls.NewListener(ln, l.md.tlsConfig)
l.Listener = tls.NewListener(ln, l.options.TLSConfig)
l.cqueue = make(chan net.Conn, l.md.backlog)
l.errChan = make(chan error, 1)

View File

@ -1,10 +1,8 @@
package mux
import (
"crypto/tls"
"time"
tls_util "github.com/go-gost/gost/pkg/common/util/tls"
mdata "github.com/go-gost/gost/pkg/metadata"
)
@ -13,8 +11,6 @@ const (
)
type metadata struct {
tlsConfig *tls.Config
muxKeepAliveDisabled bool
muxKeepAliveInterval time.Duration
muxKeepAliveTimeout time.Duration
@ -27,10 +23,6 @@ type metadata struct {
func (l *mtlsListener) parseMetadata(md mdata.Metadata) (err error) {
const (
certFile = "certFile"
keyFile = "keyFile"
caFile = "caFile"
backlog = "backlog"
muxKeepAliveDisabled = "muxKeepAliveDisabled"
@ -41,15 +33,6 @@ func (l *mtlsListener) parseMetadata(md mdata.Metadata) (err error) {
muxMaxStreamBuffer = "muxMaxStreamBuffer"
)
l.md.tlsConfig, err = tls_util.LoadServerConfig(
mdata.GetString(md, certFile),
mdata.GetString(md, keyFile),
mdata.GetString(md, caFile),
)
if err != nil {
return
}
l.md.backlog = mdata.GetInt(md, backlog)
if l.md.backlog <= 0 {
l.md.backlog = defaultBacklog

View File

@ -20,7 +20,6 @@ func init() {
}
type wsListener struct {
saddr string
addr net.Addr
upgrader *websocket.Upgrader
srv *http.Server
@ -29,28 +28,29 @@ type wsListener struct {
errChan chan error
logger logger.Logger
md metadata
options listener.Options
}
func NewListener(opts ...listener.Option) listener.Listener {
options := &listener.Options{}
options := listener.Options{}
for _, opt := range opts {
opt(options)
opt(&options)
}
return &wsListener{
saddr: options.Addr,
logger: options.Logger,
logger: options.Logger,
options: options,
}
}
func NewTLSListener(opts ...listener.Option) listener.Listener {
options := &listener.Options{}
options := listener.Options{}
for _, opt := range opts {
opt(options)
opt(&options)
}
return &wsListener{
saddr: options.Addr,
logger: options.Logger,
tlsEnabled: true,
logger: options.Logger,
options: options,
}
}
@ -70,7 +70,7 @@ func (l *wsListener) Init(md md.Metadata) (err error) {
mux := http.NewServeMux()
mux.Handle(l.md.path, http.HandlerFunc(l.upgrade))
l.srv = &http.Server{
Addr: l.saddr,
Addr: l.options.Addr,
Handler: mux,
ReadHeaderTimeout: l.md.readHeaderTimeout,
}
@ -78,12 +78,12 @@ func (l *wsListener) Init(md md.Metadata) (err error) {
l.cqueue = make(chan net.Conn, l.md.backlog)
l.errChan = make(chan error, 1)
ln, err := net.Listen("tcp", l.saddr)
ln, err := net.Listen("tcp", l.options.Addr)
if err != nil {
return
}
if l.tlsEnabled {
ln = tls.NewListener(ln, l.md.tlsConfig)
ln = tls.NewListener(ln, l.options.TLSConfig)
}
l.addr = ln.Addr()

View File

@ -1,11 +1,9 @@
package ws
import (
"crypto/tls"
"net/http"
"time"
tls_util "github.com/go-gost/gost/pkg/common/util/tls"
mdata "github.com/go-gost/gost/pkg/metadata"
)
@ -15,9 +13,8 @@ const (
)
type metadata struct {
path string
backlog int
tlsConfig *tls.Config
path string
backlog int
handshakeTimeout time.Duration
readHeaderTimeout time.Duration
@ -30,10 +27,6 @@ type metadata struct {
func (l *wsListener) parseMetadata(md mdata.Metadata) (err error) {
const (
certFile = "certFile"
keyFile = "keyFile"
caFile = "caFile"
path = "path"
backlog = "backlog"
@ -46,15 +39,6 @@ func (l *wsListener) parseMetadata(md mdata.Metadata) (err error) {
header = "header"
)
l.md.tlsConfig, err = tls_util.LoadServerConfig(
mdata.GetString(md, certFile),
mdata.GetString(md, keyFile),
mdata.GetString(md, caFile),
)
if err != nil {
return
}
l.md.path = mdata.GetString(md, path)
if l.md.path == "" {
l.md.path = defaultPath

View File

@ -21,37 +21,37 @@ func init() {
}
type mwsListener struct {
saddr string
addr net.Addr
upgrader *websocket.Upgrader
srv *http.Server
cqueue chan net.Conn
errChan chan error
tlsEnabled bool
logger logger.Logger
md metadata
tlsEnabled bool
options listener.Options
}
func NewListener(opts ...listener.Option) listener.Listener {
options := &listener.Options{}
options := listener.Options{}
for _, opt := range opts {
opt(options)
opt(&options)
}
return &mwsListener{
saddr: options.Addr,
logger: options.Logger,
logger: options.Logger,
options: options,
}
}
func NewTLSListener(opts ...listener.Option) listener.Listener {
options := &listener.Options{}
options := listener.Options{}
for _, opt := range opts {
opt(options)
opt(&options)
}
return &mwsListener{
saddr: options.Addr,
logger: options.Logger,
tlsEnabled: true,
logger: options.Logger,
options: options,
}
}
@ -75,7 +75,7 @@ func (l *mwsListener) Init(md md.Metadata) (err error) {
mux := http.NewServeMux()
mux.Handle(path, http.HandlerFunc(l.upgrade))
l.srv = &http.Server{
Addr: l.saddr,
Addr: l.options.Addr,
Handler: mux,
ReadHeaderTimeout: l.md.readHeaderTimeout,
}
@ -83,12 +83,12 @@ func (l *mwsListener) Init(md md.Metadata) (err error) {
l.cqueue = make(chan net.Conn, l.md.backlog)
l.errChan = make(chan error, 1)
ln, err := net.Listen("tcp", l.saddr)
ln, err := net.Listen("tcp", l.options.Addr)
if err != nil {
return
}
if l.tlsEnabled {
ln = tls.NewListener(ln, l.md.tlsConfig)
ln = tls.NewListener(ln, l.options.TLSConfig)
}
l.addr = ln.Addr()

View File

@ -1,11 +1,9 @@
package mux
import (
"crypto/tls"
"net/http"
"time"
tls_util "github.com/go-gost/gost/pkg/common/util/tls"
mdata "github.com/go-gost/gost/pkg/metadata"
)
@ -15,10 +13,9 @@ const (
)
type metadata struct {
path string
backlog int
tlsConfig *tls.Config
header http.Header
path string
backlog int
header http.Header
handshakeTimeout time.Duration
readHeaderTimeout time.Duration
@ -40,10 +37,6 @@ func (l *mwsListener) parseMetadata(md mdata.Metadata) (err error) {
backlog = "backlog"
header = "header"
certFile = "certFile"
keyFile = "keyFile"
caFile = "caFile"
handshakeTimeout = "handshakeTimeout"
readHeaderTimeout = "readHeaderTimeout"
readBufferSize = "readBufferSize"
@ -58,15 +51,6 @@ func (l *mwsListener) parseMetadata(md mdata.Metadata) (err error) {
muxMaxStreamBuffer = "muxMaxStreamBuffer"
)
l.md.tlsConfig, err = tls_util.LoadServerConfig(
mdata.GetString(md, certFile),
mdata.GetString(md, keyFile),
mdata.GetString(md, caFile),
)
if err != nil {
return
}
l.md.path = mdata.GetString(md, path)
if l.md.path == "" {
l.md.path = defaultPath