diff --git a/cert/cert.go b/cert/cert.go index 084a5f7..c82e2d5 100644 --- a/cert/cert.go +++ b/cert/cert.go @@ -45,6 +45,10 @@ const ( duration365d = time.Hour * 24 * 365 ) +var ( + ErrStaticCert = errors.New("cannot renew static certificate") +) + // Config contains the basic fields required for creating a certificate type Config struct { CommonName string @@ -119,7 +123,13 @@ func NewSignedCert(cfg Config, key crypto.Signer, caCert *x509.Certificate, caKe if err != nil { return nil, err } - return x509.ParseCertificate(certDERBytes) + + parsedCert, err := x509.ParseCertificate(certDERBytes) + if err == nil { + logrus.Infof("certificate %s signed by %s: notBefore=%s notAfter=%s", + parsedCert.Subject, caCert.Subject, parsedCert.NotBefore, parsedCert.NotAfter) + } + return parsedCert, err } // MakeEllipticPrivateKeyPEM creates an ECDSA private key @@ -271,11 +281,11 @@ func ipsToStrings(ips []net.IP) []string { } // IsCertExpired checks if the certificate about to expire -func IsCertExpired(cert *x509.Certificate) bool { +func IsCertExpired(cert *x509.Certificate, days int) bool { expirationDate := cert.NotAfter - diffDays := expirationDate.Sub(time.Now()).Hours() / 24.0 - if diffDays <= 90 { - logrus.Infof("certificate will expire in %f days", diffDays) + diffDays := time.Until(expirationDate).Hours() / 24.0 + if diffDays <= float64(days) { + logrus.Infof("certificate %s will expire in %f days at %s", cert.Subject, diffDays, cert.NotAfter) return true } return false diff --git a/cert/io.go b/cert/io.go index 5319566..984307f 100644 --- a/cert/io.go +++ b/cert/io.go @@ -34,15 +34,15 @@ func CanReadCertAndKey(certPath, keyPath string) (bool, error) { certReadable := canReadFile(certPath) keyReadable := canReadFile(keyPath) - if certReadable == false && keyReadable == false { + if !certReadable && !keyReadable { return false, nil } - if certReadable == false { + if !certReadable { return false, fmt.Errorf("error reading %s, certificate and key must be supplied as a pair", certPath) } - if keyReadable == false { + if !keyReadable { return false, fmt.Errorf("error reading %s, certificate and key must be supplied as a pair", keyPath) } diff --git a/factory/cert_utils.go b/factory/cert_utils.go index 6b28abe..95405e6 100644 --- a/factory/cert_utils.go +++ b/factory/cert_utils.go @@ -11,6 +11,8 @@ import ( "math/big" "net" "time" + + "github.com/sirupsen/logrus" ) const ( @@ -92,7 +94,12 @@ func NewSignedCert(signer crypto.Signer, caCert *x509.Certificate, caKey crypto. return nil, err } - return x509.ParseCertificate(cert) + parsedCert, err := x509.ParseCertificate(cert) + if err == nil { + logrus.Infof("certificate %s signed by %s: notBefore=%s notAfter=%s", + parsedCert.Subject, caCert.Subject, parsedCert.NotBefore, parsedCert.NotAfter) + } + return parsedCert, err } func ParseCertPEM(pemCerts []byte) (*x509.Certificate, error) { diff --git a/factory/gen.go b/factory/gen.go index ddd5b6d..ff15a93 100644 --- a/factory/gen.go +++ b/factory/gen.go @@ -5,10 +5,10 @@ import ( "crypto/ecdsa" "crypto/elliptic" "crypto/rand" - "crypto/sha256" + "crypto/sha1" "crypto/x509" - "encoding/hex" "encoding/pem" + "fmt" "net" "regexp" "sort" @@ -20,9 +20,9 @@ import ( ) const ( - cnPrefix = "listener.cattle.io/cn-" - Static = "listener.cattle.io/static" - hashKey = "listener.cattle.io/hash" + cnPrefix = "listener.cattle.io/cn-" + Static = "listener.cattle.io/static" + fingerprint = "listener.cattle.io/fingerprint" ) var ( @@ -49,16 +49,14 @@ func cns(secret *v1.Secret) (cns []string) { return } -func collectCNs(secret *v1.Secret) (domains []string, ips []net.IP, hash string, err error) { +func collectCNs(secret *v1.Secret) (domains []string, ips []net.IP, err error) { var ( - cns = cns(secret) - digest = sha256.New() + cns = cns(secret) ) sort.Strings(cns) for _, v := range cns { - digest.Write([]byte(v)) ip := net.ParseIP(v) if ip == nil { domains = append(domains, v) @@ -67,40 +65,61 @@ func collectCNs(secret *v1.Secret) (domains []string, ips []net.IP, hash string, } } - hash = hex.EncodeToString(digest.Sum(nil)) return } +// Merge combines the SAN lists from the target and additional Secrets, and returns a potentially modified Secret, +// along with a bool indicating if the returned Secret has been updated or not. If the two SAN lists alread matched +// and no merging was necessary, but the Secrets' certificate fingerprints differed, the second secret is returned +// and the updated bool is set to true despite neither certificate having actually been modified. This is required +// to support handling certificate renewal within the kubernetes storage provider. func (t *TLS) Merge(target, additional *v1.Secret) (*v1.Secret, bool, error) { - return t.AddCN(target, cns(additional)...) + secret, updated, err := t.AddCN(target, cns(additional)...) + if !updated { + if target.Annotations[fingerprint] != additional.Annotations[fingerprint] { + secret = additional + updated = true + } + } + return secret, updated, err } -func (t *TLS) Refresh(secret *v1.Secret) (*v1.Secret, error) { +// Renew returns a copy of the given certificate that has been re-signed +// to extend the NotAfter date. It is an error to attempt to renew +// a static (user-provided) certificate. +func (t *TLS) Renew(secret *v1.Secret) (*v1.Secret, error) { + if IsStatic(secret) { + return secret, cert.ErrStaticCert + } cns := cns(secret) secret = secret.DeepCopy() secret.Annotations = map[string]string{} - secret, _, err := t.AddCN(secret, cns...) + secret, _, err := t.generateCert(secret, cns...) return secret, err } +// Filter ensures that the CNs are all valid accorting to both internal logic, and any filter callbacks. +// The returned list will contain only approved CN entries. func (t *TLS) Filter(cn ...string) []string { - if t.FilterCN == nil { + if len(cn) == 0 || t.FilterCN == nil { return cn } return t.FilterCN(cn...) } +// AddCN attempts to add a list of CN strings to a given Secret, returning the potentially-modified +// Secret along with a bool indicating whether or not it has been updated. The Secret will not be changed +// if it has an attribute indicating that it is static (aka user-provided), or if no new CNs were added. func (t *TLS) AddCN(secret *v1.Secret, cn ...string) (*v1.Secret, bool, error) { - var ( - err error - ) - cn = t.Filter(cn...) - if !NeedsUpdate(0, secret, cn...) { + if IsStatic(secret) || !NeedsUpdate(0, secret, cn...) { return secret, false, nil } + return t.generateCert(secret, cn...) +} +func (t *TLS) generateCert(secret *v1.Secret, cn ...string) (*v1.Secret, bool, error) { secret = secret.DeepCopy() if secret == nil { secret = &v1.Secret{} @@ -113,7 +132,7 @@ func (t *TLS) AddCN(secret *v1.Secret, cn ...string) (*v1.Secret, bool, error) { return nil, false, err } - domains, ips, hash, err := collectCNs(secret) + domains, ips, err := collectCNs(secret) if err != nil { return nil, false, err } @@ -133,7 +152,7 @@ func (t *TLS) AddCN(secret *v1.Secret, cn ...string) (*v1.Secret, bool, error) { } secret.Data[v1.TLSCertKey] = certBytes secret.Data[v1.TLSPrivateKeyKey] = keyBytes - secret.Annotations[hashKey] = hash + secret.Annotations[fingerprint] = fmt.Sprintf("SHA1=%X", sha1.Sum(newCert.Raw)) return secret, true, nil } @@ -157,15 +176,21 @@ func populateCN(secret *v1.Secret, cn ...string) *v1.Secret { return secret } +// IsStatic returns true if the Secret has an attribute indicating that it contains +// a static (aka user-provided) certificate, which should not be modified. +func IsStatic(secret *v1.Secret) bool { + return secret.Annotations[Static] == "true" +} + +// NeedsUpdate returns true if any of the CNs are not currently present on the +// secret's Certificate, as recorded in the cnPrefix annotations. It will return +// false if all requested CNs are already present, or if maxSANs is non-zero and has +// been exceeded. func NeedsUpdate(maxSANs int, secret *v1.Secret, cn ...string) bool { if secret == nil { return true } - if secret.Annotations[Static] == "true" { - return false - } - for _, cn := range cn { if secret.Annotations[cnPrefix+cn] == "" { if maxSANs > 0 && len(cns(secret)) >= maxSANs { @@ -192,6 +217,7 @@ func getPrivateKey(secret *v1.Secret) (crypto.Signer, error) { return NewPrivateKey() } +// Marshal returns the given cert and key as byte slices. func Marshal(x509Cert *x509.Certificate, privateKey crypto.Signer) ([]byte, []byte, error) { certBlock := pem.Block{ Type: CertificateBlockType, @@ -206,6 +232,7 @@ func Marshal(x509Cert *x509.Certificate, privateKey crypto.Signer) ([]byte, []by return pem.EncodeToMemory(&certBlock), keyBytes, nil } +// NewPrivateKey returnes a new ECDSA key func NewPrivateKey() (crypto.Signer, error) { return ecdsa.GenerateKey(elliptic.P256(), rand.Reader) } diff --git a/listener.go b/listener.go index 54b3fed..3a30e20 100644 --- a/listener.go +++ b/listener.go @@ -11,6 +11,7 @@ import ( "sync" "time" + "github.com/rancher/dynamiclistener/cert" "github.com/rancher/dynamiclistener/factory" "github.com/sirupsen/logrus" v1 "k8s.io/api/core/v1" @@ -22,7 +23,7 @@ type TLSStorage interface { } type TLSFactory interface { - Refresh(secret *v1.Secret) (*v1.Secret, error) + Renew(secret *v1.Secret) (*v1.Secret, error) AddCN(secret *v1.Secret, cn ...string) (*v1.Secret, bool, error) Merge(target *v1.Secret, additional *v1.Secret) (*v1.Secret, bool, error) Filter(cn ...string) []string @@ -152,13 +153,18 @@ type listener struct { func (l *listener) WrapExpiration(days int) net.Listener { ctx, cancel := context.WithCancel(context.Background()) go func() { - time.Sleep(5 * time.Minute) + time.Sleep(30 * time.Second) for { wait := 6 * time.Hour if err := l.checkExpiration(days); err != nil { - logrus.Errorf("failed to check and refresh dynamic cert: %v", err) - wait = 5 + time.Minute + logrus.Errorf("failed to check and renew dynamic cert: %v", err) + // Don't go into short retry loop if we're using a static (user-provided) cert. + // We will still check and print an error every six hours until the user updates the secret with + // a cert that is not about to expire. Hopefully this will prompt them to take action. + if err != cert.ErrStaticCert { + wait = 5 * time.Minute + } } select { case <-ctx.Done(): @@ -191,22 +197,26 @@ func (l *listener) checkExpiration(days int) error { return err } - cert, err := tls.X509KeyPair(secret.Data[v1.TLSCertKey], secret.Data[v1.TLSPrivateKeyKey]) + keyPair, err := tls.X509KeyPair(secret.Data[v1.TLSCertKey], secret.Data[v1.TLSPrivateKeyKey]) if err != nil { return err } - certParsed, err := x509.ParseCertificate(cert.Certificate[0]) + certParsed, err := x509.ParseCertificate(keyPair.Certificate[0]) if err != nil { return err } - if time.Now().UTC().Add(time.Hour * 24 * time.Duration(days)).After(certParsed.NotAfter) { - secret, err := l.factory.Refresh(secret) + if cert.IsCertExpired(certParsed, days) { + secret, err := l.factory.Renew(secret) if err != nil { return err } - return l.storage.Update(secret) + if err := l.storage.Update(secret); err != nil { + return err + } + // clear version to force cert reload + l.version = "" } return nil @@ -304,7 +314,7 @@ func (l *listener) updateCert(cn ...string) error { return err } - if !factory.NeedsUpdate(l.maxSANs, secret, cn...) { + if !factory.IsStatic(secret) && !factory.NeedsUpdate(l.maxSANs, secret, cn...) { return nil } @@ -324,13 +334,6 @@ func (l *listener) updateCert(cn ...string) error { } // clear version to force cert reload l.version = "" - if l.conns != nil { - l.connLock.Lock() - for _, conn := range l.conns { - _ = conn.close() - } - l.connLock.Unlock() - } } return nil @@ -366,6 +369,15 @@ func (l *listener) loadCert() (*tls.Certificate, error) { return nil, err } + // cert has changed, close closeWrapper wrapped connections + if l.conns != nil { + l.connLock.Lock() + for _, conn := range l.conns { + _ = conn.close() + } + l.connLock.Unlock() + } + l.cert = &cert l.version = secret.ResourceVersion return l.cert, nil diff --git a/storage/kubernetes/controller.go b/storage/kubernetes/controller.go index 683ab9c..70d5d7c 100644 --- a/storage/kubernetes/controller.go +++ b/storage/kubernetes/controller.go @@ -156,10 +156,9 @@ func (s *storage) saveInK8s(secret *v1.Secret) (*v1.Secret, error) { if targetSecret.UID == "" { logrus.Infof("Creating new TLS secret for %v (count: %d): %v", targetSecret.Name, len(targetSecret.Annotations)-1, targetSecret.Annotations) return s.secrets.Create(targetSecret) - } else { - logrus.Infof("Updating TLS secret for %v (count: %d): %v", targetSecret.Name, len(targetSecret.Annotations)-1, targetSecret.Annotations) - return s.secrets.Update(targetSecret) } + logrus.Infof("Updating TLS secret for %v (count: %d): %v", targetSecret.Name, len(targetSecret.Annotations)-1, targetSecret.Annotations) + return s.secrets.Update(targetSecret) } func (s *storage) Update(secret *v1.Secret) (err error) { diff --git a/storage/memory/memory.go b/storage/memory/memory.go index c417e30..54f6251 100644 --- a/storage/memory/memory.go +++ b/storage/memory/memory.go @@ -32,13 +32,15 @@ func (m *memory) Get() (*v1.Secret, error) { } func (m *memory) Update(secret *v1.Secret) error { - if m.storage != nil { - if err := m.storage.Update(secret); err != nil { - return err + if m.secret == nil || m.secret.ResourceVersion != secret.ResourceVersion { + if m.storage != nil { + if err := m.storage.Update(secret); err != nil { + return err + } } - } - logrus.Infof("Active TLS secret %s (ver=%s) (count %d): %v", secret.Name, secret.ResourceVersion, len(secret.Annotations)-1, secret.Annotations) - m.secret = secret + logrus.Infof("Active TLS secret %s (ver=%s) (count %d): %v", secret.Name, secret.ResourceVersion, len(secret.Annotations)-1, secret.Annotations) + m.secret = secret + } return nil }