customize random-generated certificate information

This commit is contained in:
ginuerzh 2022-07-27 16:58:49 +08:00
parent f7be171df5
commit d7b56871a9
4 changed files with 45 additions and 18 deletions

View File

@ -71,6 +71,11 @@ type TLSConfig struct {
CAFile string `yaml:"caFile,omitempty" json:"caFile,omitempty"` CAFile string `yaml:"caFile,omitempty" json:"caFile,omitempty"`
Secure bool `yaml:",omitempty" json:"secure,omitempty"` Secure bool `yaml:",omitempty" json:"secure,omitempty"`
ServerName string `yaml:"serverName,omitempty" json:"serverName,omitempty"` ServerName string `yaml:"serverName,omitempty" json:"serverName,omitempty"`
// for auto-generated default certificate.
Validity time.Duration `yaml:",omitempty" json:"validity,omitempty"`
CommonName string `yaml:"commonName,omitempty" json:"commonName,omitempty"`
Organization string `yaml:",omitempty" json:"organization,omitempty"`
} }
type AutherConfig struct { type AutherConfig struct {

View File

@ -31,7 +31,7 @@ func BuildDefaultTLSConfig(cfg *config.TLSConfig) {
tlsConfig, err := loadConfig(cfg.CertFile, cfg.KeyFile) tlsConfig, err := loadConfig(cfg.CertFile, cfg.KeyFile)
if err != nil { if err != nil {
// generate random self-signed certificate. // generate random self-signed certificate.
cert, err := genCertificate() cert, err := genCertificate(cfg.Validity, cfg.Organization, cfg.CommonName)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
@ -58,15 +58,15 @@ func loadConfig(certFile, keyFile string) (*tls.Config, error) {
return cfg, nil return cfg, nil
} }
func genCertificate() (cert tls.Certificate, err error) { func genCertificate(validity time.Duration, org string, cn string) (cert tls.Certificate, err error) {
rawCert, rawKey, err := generateKeyPair() rawCert, rawKey, err := generateKeyPair(validity, org, cn)
if err != nil { if err != nil {
return return
} }
return tls.X509KeyPair(rawCert, rawKey) return tls.X509KeyPair(rawCert, rawKey)
} }
func generateKeyPair() (rawCert, rawKey []byte, err error) { func generateKeyPair(validity time.Duration, org string, cn string) (rawCert, rawKey []byte, err error) {
// Create private key and self-signed certificate // Create private key and self-signed certificate
// Adapted from https://golang.org/src/crypto/tls/generate_cert.go // Adapted from https://golang.org/src/crypto/tls/generate_cert.go
@ -74,7 +74,18 @@ func generateKeyPair() (rawCert, rawKey []byte, err error) {
if err != nil { if err != nil {
return return
} }
validFor := time.Hour * 24 * 365 * 10 // ten years
if validity <= 0 {
validity = time.Hour * 24 * 365 // one year
}
if org == "" {
org = "GOST"
}
if cn == "" {
cn = "gost.run"
}
validFor := validity
notBefore := time.Now() notBefore := time.Now()
notAfter := notBefore.Add(validFor) notAfter := notBefore.Add(validFor)
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
@ -86,7 +97,8 @@ func generateKeyPair() (rawCert, rawKey []byte, err error) {
template := x509.Certificate{ template := x509.Certificate{
SerialNumber: serialNumber, SerialNumber: serialNumber,
Subject: pkix.Name{ Subject: pkix.Name{
Organization: []string{"gost"}, Organization: []string{org},
CommonName: cn,
}, },
NotBefore: notBefore, NotBefore: notBefore,
NotAfter: notAfter, NotAfter: notAfter,
@ -96,7 +108,7 @@ func generateKeyPair() (rawCert, rawKey []byte, err error) {
BasicConstraintsValid: true, BasicConstraintsValid: true,
} }
template.DNSNames = append(template.DNSNames, "gost.run") template.DNSNames = append(template.DNSNames, cn)
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
if err != nil { if err != nil {
return return

View File

@ -16,10 +16,11 @@ type metadata struct {
func (h *httpHandler) parseMetadata(md mdata.Metadata) error { func (h *httpHandler) parseMetadata(md mdata.Metadata) error {
const ( const (
header = "header" header = "header"
probeResistKey = "probeResistance" probeResistKey = "probeResistance"
knock = "knock" probeResistKeyX = "probe_resist"
enableUDP = "udp" knock = "knock"
enableUDP = "udp"
) )
if m := mdx.GetStringMapString(md, header); len(m) > 0 { if m := mdx.GetStringMapString(md, header); len(m) > 0 {
@ -30,8 +31,12 @@ func (h *httpHandler) parseMetadata(md mdata.Metadata) error {
h.md.header = hd h.md.header = hd
} }
if v := mdx.GetString(md, probeResistKey); v != "" { pr := mdx.GetString(md, probeResistKey)
if ss := strings.SplitN(v, ":", 2); len(ss) == 2 { if pr == "" {
pr = mdx.GetString(md, probeResistKeyX)
}
if pr != "" {
if ss := strings.SplitN(pr, ":", 2); len(ss) == 2 {
h.md.probeResistance = &probeResistance{ h.md.probeResistance = &probeResistance{
Type: ss[0], Type: ss[0],
Value: ss[1], Value: ss[1],

View File

@ -15,9 +15,10 @@ type metadata struct {
func (h *http2Handler) parseMetadata(md mdata.Metadata) error { func (h *http2Handler) parseMetadata(md mdata.Metadata) error {
const ( const (
header = "header" header = "header"
probeResistKey = "probeResistance" probeResistKey = "probeResistance"
knock = "knock" probeResistKeyX = "probe_resist"
knock = "knock"
) )
if m := mdx.GetStringMapString(md, header); len(m) > 0 { if m := mdx.GetStringMapString(md, header); len(m) > 0 {
@ -28,8 +29,12 @@ func (h *http2Handler) parseMetadata(md mdata.Metadata) error {
h.md.header = hd h.md.header = hd
} }
if v := mdx.GetString(md, probeResistKey); v != "" { pr := mdx.GetString(md, probeResistKey)
if ss := strings.SplitN(v, ":", 2); len(ss) == 2 { if pr == "" {
pr = mdx.GetString(md, probeResistKeyX)
}
if pr != "" {
if ss := strings.SplitN(pr, ":", 2); len(ss) == 2 {
h.md.probeResistance = &probeResistance{ h.md.probeResistance = &probeResistance{
Type: ss[0], Type: ss[0],
Value: ss[1], Value: ss[1],