add tls config option
This commit is contained in:
119
cmd/gost/cmd.go
119
cmd/gost/cmd.go
@ -1,6 +1,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
@ -115,21 +116,46 @@ func buildServiceConfig(url *url.URL) (*config.ServiceConfig, error) {
|
||||
}
|
||||
}
|
||||
|
||||
var auths []*config.AuthConfig
|
||||
if url.User != nil {
|
||||
auth := &config.AuthConfig{
|
||||
Username: url.User.Username(),
|
||||
}
|
||||
auth.Password, _ = url.User.Password()
|
||||
auths = append(auths, auth)
|
||||
}
|
||||
|
||||
md := make(map[string]interface{})
|
||||
for k, v := range url.Query() {
|
||||
if len(v) > 0 {
|
||||
md[k] = v[0]
|
||||
}
|
||||
}
|
||||
|
||||
var auths []config.AuthConfig
|
||||
if url.User != nil {
|
||||
auth := config.AuthConfig{
|
||||
Username: url.User.Username(),
|
||||
if sauth := md["auth"]; sauth != nil {
|
||||
if sa, _ := sauth.(string); sa != "" {
|
||||
au, err := parseAuthFromCmd(sa)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
auths = append(auths, au)
|
||||
}
|
||||
auth.Password, _ = url.User.Password()
|
||||
auths = append(auths, auth)
|
||||
}
|
||||
delete(md, "auth")
|
||||
|
||||
var tlsConfig *config.TLSConfig
|
||||
if certs := md["cert"]; certs != nil {
|
||||
cert, _ := certs.(string)
|
||||
key, _ := md["key"].(string)
|
||||
ca, _ := md["ca"].(string)
|
||||
tlsConfig = &config.TLSConfig{
|
||||
Cert: cert,
|
||||
Key: key,
|
||||
CA: ca,
|
||||
}
|
||||
}
|
||||
delete(md, "cert")
|
||||
delete(md, "key")
|
||||
delete(md, "ca")
|
||||
|
||||
svc.Handler = &config.HandlerConfig{
|
||||
Type: handler,
|
||||
@ -138,6 +164,7 @@ func buildServiceConfig(url *url.URL) (*config.ServiceConfig, error) {
|
||||
}
|
||||
svc.Listener = &config.ListenerConfig{
|
||||
Type: listener,
|
||||
TLS: tlsConfig,
|
||||
Metadata: md,
|
||||
}
|
||||
|
||||
@ -170,14 +197,6 @@ func buildNodeConfig(url *url.URL) (*config.NodeConfig, error) {
|
||||
}
|
||||
}
|
||||
|
||||
md := make(map[string]interface{})
|
||||
for k, v := range url.Query() {
|
||||
if len(v) > 0 {
|
||||
md[k] = v[0]
|
||||
}
|
||||
}
|
||||
md["serverName"] = url.Host
|
||||
|
||||
var auth *config.AuthConfig
|
||||
if url.User != nil {
|
||||
auth = &config.AuthConfig{
|
||||
@ -186,6 +205,46 @@ func buildNodeConfig(url *url.URL) (*config.NodeConfig, error) {
|
||||
auth.Password, _ = url.User.Password()
|
||||
}
|
||||
|
||||
md := make(map[string]interface{})
|
||||
for k, v := range url.Query() {
|
||||
if len(v) > 0 {
|
||||
md[k] = v[0]
|
||||
}
|
||||
}
|
||||
md["serverName"] = url.Host
|
||||
|
||||
if sauth := md["auth"]; sauth != nil && auth == nil {
|
||||
if sa, _ := sauth.(string); sa != "" {
|
||||
au, err := parseAuthFromCmd(sa)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
auth = au
|
||||
}
|
||||
}
|
||||
delete(md, "auth")
|
||||
|
||||
var tlsConfig *config.TLSConfig
|
||||
if certs := md["cert"]; certs != nil {
|
||||
cert, _ := certs.(string)
|
||||
key, _ := md["key"].(string)
|
||||
ca, _ := md["ca"].(string)
|
||||
secure, _ := md["secure"].(bool)
|
||||
serverName, _ := md["serverName"].(string)
|
||||
tlsConfig = &config.TLSConfig{
|
||||
Cert: cert,
|
||||
Key: key,
|
||||
CA: ca,
|
||||
Secure: secure,
|
||||
ServerName: serverName,
|
||||
}
|
||||
}
|
||||
delete(md, "cert")
|
||||
delete(md, "key")
|
||||
delete(md, "ca")
|
||||
delete(md, "secure")
|
||||
delete(md, "serverName")
|
||||
|
||||
node.Connector = &config.ConnectorConfig{
|
||||
Type: connector,
|
||||
Auth: auth,
|
||||
@ -193,6 +252,7 @@ func buildNodeConfig(url *url.URL) (*config.NodeConfig, error) {
|
||||
}
|
||||
node.Dialer = &config.DialerConfig{
|
||||
Type: dialer,
|
||||
TLS: tlsConfig,
|
||||
Metadata: md,
|
||||
}
|
||||
|
||||
@ -209,5 +269,32 @@ func normCmd(s string) (*url.URL, error) {
|
||||
s = "auto://" + s
|
||||
}
|
||||
|
||||
return url.Parse(s)
|
||||
url, err := url.Parse(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if url.Scheme == "https" {
|
||||
url.Scheme = "http+tls"
|
||||
}
|
||||
|
||||
return url, nil
|
||||
}
|
||||
|
||||
func parseAuthFromCmd(sa string) (*config.AuthConfig, error) {
|
||||
v, err := base64.StdEncoding.DecodeString(sa)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cs := string(v)
|
||||
n := strings.IndexByte(cs, ':')
|
||||
if n < 0 {
|
||||
return &config.AuthConfig{
|
||||
Username: cs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &config.AuthConfig{
|
||||
Username: cs[:n],
|
||||
Password: cs[n+1:],
|
||||
}, nil
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"io"
|
||||
"net"
|
||||
"net/url"
|
||||
@ -68,9 +69,23 @@ func buildService(cfg *config.Config) (services []*service.Service) {
|
||||
listenerLogger := serviceLogger.WithFields(map[string]interface{}{
|
||||
"kind": "listener",
|
||||
})
|
||||
|
||||
var tlsConfig *tls.Config
|
||||
var err error
|
||||
|
||||
tlsCfg := svc.Listener.TLS
|
||||
if tlsCfg == nil {
|
||||
tlsCfg = &config.TLSConfig{}
|
||||
}
|
||||
tlsConfig, err = loadServerTLSConfig(tlsCfg)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
ln := registry.GetListener(svc.Listener.Type)(
|
||||
listener.AddrOption(svc.Addr),
|
||||
listener.AuthsOption(authsFromConfig(svc.Listener.Auths...)...),
|
||||
listener.TLSConfigOption(tlsConfig),
|
||||
listener.LoggerOption(listenerLogger),
|
||||
)
|
||||
|
||||
@ -89,6 +104,16 @@ func buildService(cfg *config.Config) (services []*service.Service) {
|
||||
"kind": "handler",
|
||||
})
|
||||
|
||||
tlsConfig = nil
|
||||
tlsCfg = svc.Handler.TLS
|
||||
if tlsCfg == nil {
|
||||
tlsCfg = &config.TLSConfig{}
|
||||
}
|
||||
tlsConfig, err = loadServerTLSConfig(tlsCfg)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
h := registry.GetHandler(svc.Handler.Type)(
|
||||
handler.RetriesOption(svc.Handler.Retries),
|
||||
handler.ChainOption(chains[svc.Handler.Chain]),
|
||||
@ -96,6 +121,7 @@ func buildService(cfg *config.Config) (services []*service.Service) {
|
||||
handler.HostsOption(hosts[svc.Handler.Hosts]),
|
||||
handler.BypassOption(bypasses[svc.Handler.Bypass]),
|
||||
handler.AuthsOption(authsFromConfig(svc.Handler.Auths...)...),
|
||||
handler.TLSConfigOption(tlsConfig),
|
||||
handler.LoggerOption(handlerLogger),
|
||||
)
|
||||
|
||||
@ -148,16 +174,29 @@ func chainFromConfig(cfg *config.ChainConfig) *chain.Chain {
|
||||
"kind": "connector",
|
||||
})
|
||||
|
||||
var connectorUser *url.Userinfo
|
||||
var user *url.Userinfo
|
||||
if auth := v.Connector.Auth; auth != nil && auth.Username != "" {
|
||||
if auth.Password == "" {
|
||||
connectorUser = url.User(auth.Username)
|
||||
user = url.User(auth.Username)
|
||||
} else {
|
||||
connectorUser = url.UserPassword(auth.Username, auth.Password)
|
||||
user = url.UserPassword(auth.Username, auth.Password)
|
||||
}
|
||||
}
|
||||
|
||||
var tlsConfig *tls.Config
|
||||
var err error
|
||||
tlsCfg := v.Connector.TLS
|
||||
if tlsCfg == nil {
|
||||
tlsCfg = &config.TLSConfig{}
|
||||
}
|
||||
tlsConfig, err = loadClientTLSConfig(tlsCfg)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
cr := registry.GetConnector(v.Connector.Type)(
|
||||
connector.UserOption(connectorUser),
|
||||
connector.UserOption(user),
|
||||
connector.TLSConfigOption(tlsConfig),
|
||||
connector.LoggerOption(connectorLogger),
|
||||
)
|
||||
|
||||
@ -172,16 +211,28 @@ func chainFromConfig(cfg *config.ChainConfig) *chain.Chain {
|
||||
"kind": "dialer",
|
||||
})
|
||||
|
||||
var dialerUser *url.Userinfo
|
||||
user = nil
|
||||
if auth := v.Dialer.Auth; auth != nil && auth.Username != "" {
|
||||
if auth.Password == "" {
|
||||
dialerUser = url.User(auth.Username)
|
||||
user = url.User(auth.Username)
|
||||
} else {
|
||||
dialerUser = url.UserPassword(auth.Username, auth.Password)
|
||||
user = url.UserPassword(auth.Username, auth.Password)
|
||||
}
|
||||
}
|
||||
|
||||
tlsConfig = nil
|
||||
tlsCfg = v.Dialer.TLS
|
||||
if tlsCfg == nil {
|
||||
tlsCfg = &config.TLSConfig{}
|
||||
}
|
||||
tlsConfig, err = loadClientTLSConfig(tlsCfg)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
d := registry.GetDialer(v.Dialer.Type)(
|
||||
dialer.UserOption(dialerUser),
|
||||
dialer.UserOption(user),
|
||||
dialer.TLSConfigOption(tlsConfig),
|
||||
dialer.LoggerOption(dialerLogger),
|
||||
)
|
||||
|
||||
@ -328,11 +379,11 @@ func hostsFromConfig(cfg *config.HostsConfig) hostspkg.HostMapper {
|
||||
return hosts
|
||||
}
|
||||
|
||||
func authsFromConfig(cfgs ...config.AuthConfig) []*url.Userinfo {
|
||||
func authsFromConfig(cfgs ...*config.AuthConfig) []*url.Userinfo {
|
||||
var auths []*url.Userinfo
|
||||
|
||||
for _, cfg := range cfgs {
|
||||
if cfg.Username == "" {
|
||||
if cfg == nil || cfg.Username == "" {
|
||||
continue
|
||||
}
|
||||
auths = append(auths, url.UserPassword(cfg.Username, cfg.Password))
|
||||
|
@ -14,6 +14,14 @@ import (
|
||||
"github.com/go-gost/gost/pkg/config"
|
||||
)
|
||||
|
||||
func loadServerTLSConfig(cfg *config.TLSConfig) (*tls.Config, error) {
|
||||
return tls_util.LoadServerConfig(cfg.Cert, cfg.Key, cfg.CA)
|
||||
}
|
||||
|
||||
func loadClientTLSConfig(cfg *config.TLSConfig) (*tls.Config, error) {
|
||||
return tls_util.LoadClientConfig(cfg.Cert, cfg.Key, cfg.CA, cfg.Secure, cfg.ServerName)
|
||||
}
|
||||
|
||||
func buildDefaultTLSConfig(cfg *config.TLSConfig) {
|
||||
if cfg == nil {
|
||||
cfg = &config.TLSConfig{
|
||||
|
Reference in New Issue
Block a user