From abee8bddf36a311acd4a2e6e9fc7960525b8f554 Mon Sep 17 00:00:00 2001 From: Darien Raymond Date: Sat, 14 Apr 2018 13:12:50 +0200 Subject: [PATCH] only try issuing new certificate when user provide custom CA --- transport/internet/tls/config.go | 88 ++++++++++++++++++-------------- 1 file changed, 50 insertions(+), 38 deletions(-) diff --git a/transport/internet/tls/config.go b/transport/internet/tls/config.go index 084eed864..355f096e0 100644 --- a/transport/internet/tls/config.go +++ b/transport/internet/tls/config.go @@ -58,6 +58,15 @@ func issueCertificate(rawCA *Certificate, domain string) (*tls.Certificate, erro return &cert, err } +func (c *Config) hasCustomCA() bool { + for _, certificate := range c.Certificate { + if certificate.Usage == Certificate_AUTHORITY_ISSUE { + return true + } + } + return false +} + func (c *Config) GetTLSConfig(opts ...Option) *tls.Config { config := &tls.Config{ ClientSessionCache: globalSessionCache, @@ -74,53 +83,56 @@ func (c *Config) GetTLSConfig(opts ...Option) *tls.Config { config.InsecureSkipVerify = c.AllowInsecure config.Certificates = c.BuildCertificates() config.BuildNameToCertificate() - config.GetCertificate = func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { - domain := hello.ServerName - certExpired := false - if certificate, found := config.NameToCertificate[domain]; found { - if !isCertificateExpired(certificate) { - return certificate, nil + if c.hasCustomCA() { + config.GetCertificate = func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { + domain := hello.ServerName + certExpired := false + if certificate, found := config.NameToCertificate[domain]; found { + if !isCertificateExpired(certificate) { + return certificate, nil + } + certExpired = true } - certExpired = true - } - if certExpired { - newCerts := make([]tls.Certificate, 0, len(config.Certificates)) + if certExpired { + newCerts := make([]tls.Certificate, 0, len(config.Certificates)) - for _, certificate := range config.Certificates { - if !isCertificateExpired(&certificate) { - newCerts = append(newCerts, certificate) + for _, certificate := range config.Certificates { + if !isCertificateExpired(&certificate) { + newCerts = append(newCerts, certificate) + } + } + + config.Certificates = newCerts + } + + var issuedCertificate *tls.Certificate + + // Create a new certificate from existing CA if possible + for _, rawCert := range c.Certificate { + if rawCert.Usage == Certificate_AUTHORITY_ISSUE { + newCert, err := issueCertificate(rawCert, domain) + if err != nil { + newError("failed to issue new certificate for ", domain).Base(err).WriteToLog() + continue + } + + config.Certificates = append(config.Certificates, *newCert) + issuedCertificate = &config.Certificates[len(config.Certificates)-1] + break } } - config.Certificates = newCerts - } - - var issuedCertificate *tls.Certificate - - // Create a new certificate from existing CA if possible - for _, rawCert := range c.Certificate { - if rawCert.Usage == Certificate_AUTHORITY_ISSUE { - newCert, err := issueCertificate(rawCert, domain) - if err != nil { - newError("failed to issue new certificate for ", domain).Base(err).WriteToLog() - continue - } - - config.Certificates = append(config.Certificates, *newCert) - issuedCertificate = &config.Certificates[len(config.Certificates)-1] - break + if issuedCertificate == nil { + return nil, newError("failed to create a new certificate for ", domain) } + + config.BuildNameToCertificate() + + return issuedCertificate, nil } - - if issuedCertificate == nil { - return nil, newError("failed to create a new certificate for ", domain) - } - - config.BuildNameToCertificate() - - return issuedCertificate, nil } + if len(c.ServerName) > 0 { config.ServerName = c.ServerName }