diff --git a/util/certificate/certificate_manager.go b/util/certificate/certificate_manager.go index b4dcb0b8..cda9dfe4 100644 --- a/util/certificate/certificate_manager.go +++ b/util/certificate/certificate_manager.go @@ -203,13 +203,16 @@ type Config struct { // certificate renewal failures. CertificateRenewFailure Counter // Name is an optional string that will be used when writing log output - // or returning errors from manager methods. If not set, SignerName will + // via logger.WithName or returning errors from manager methods. + // + // If not set, SignerName will // be used, if SignerName is not set, if Usages includes client auth the // name will be "client auth", otherwise the value will be "server". Name string - // Logf is an optional function that log output will be sent to from the - // certificate manager. If not set it will use klog.V(2) - Logf func(format string, args ...interface{}) + // Ctx is an optional context. Cancelling it is equivalent to + // calling Stop. A logger is extracted from it if non-nil, otherwise + // klog.Background() is used. + Ctx *context.Context } // Store is responsible for getting and updating the current certificate. @@ -278,30 +281,22 @@ type manager struct { cert *tls.Certificate serverHealth bool + // Context and cancel function for background goroutines. + ctx context.Context + cancel func(err error) + // the clientFn must only be accessed under the clientAccessLock clientAccessLock sync.Mutex clientsetFn ClientsetFunc - stopCh chan struct{} - stopped bool // Set to time.Now but can be stubbed out for testing now func() time.Time - - name string - logf func(format string, args ...interface{}) } // NewManager returns a new certificate manager. A certificate manager is // responsible for being the authoritative source of certificates in the // Kubelet and handling updates due to rotation. 
func NewManager(config *Config) (Manager, error) { - cert, forceRotation, err := getCurrentCertificateOrBootstrap( - config.CertificateStore, - config.BootstrapCertificatePEM, - config.BootstrapKeyPEM) - if err != nil { - return nil, err - } getTemplate := config.GetTemplate if getTemplate == nil { @@ -321,7 +316,6 @@ func NewManager(config *Config) (Manager, error) { getUsages = func(interface{}) []certificates.KeyUsage { return config.Usages } } m := manager{ - stopCh: make(chan struct{}), clientsetFn: config.ClientsetFn, getTemplate: getTemplate, dynamicTemplate: config.GetTemplate != nil, @@ -329,13 +323,12 @@ func NewManager(config *Config) (Manager, error) { requestedCertificateLifetime: config.RequestedCertificateLifetime, getUsages: getUsages, certStore: config.CertificateStore, - cert: cert, - forceRotation: forceRotation, certificateRotation: config.CertificateRotation, certificateRenewFailure: config.CertificateRenewFailure, now: time.Now, } + // Determine the name that is to be included in log output from this manager instance. name := config.Name if len(name) == 0 { name = m.signerName @@ -350,11 +343,35 @@ func NewManager(config *Config) (Manager, error) { } } - m.name = name - m.logf = config.Logf - if m.logf == nil { - m.logf = func(format string, args ...interface{}) { klog.V(2).Infof(format, args...) } + // The name gets included through contextual logging. + logger := klog.Background() + if config.Ctx != nil { + logger = klog.FromContext(*config.Ctx) } + logger = klog.LoggerWithName(logger, name) + + cert, forceRotation, err := getCurrentCertificateOrBootstrap( + logger, + config.CertificateStore, + config.BootstrapCertificatePEM, + config.BootstrapKeyPEM) + if err != nil { + return nil, err + } + m.cert, m.forceRotation = cert, forceRotation + + // cancel will be called by Stop, ctx.Done is our stop channel. 
+ m.ctx, m.cancel = context.WithCancelCause(context.Background()) + if config.Ctx != nil && (*config.Ctx).Done() != nil { + ctx := *config.Ctx + // If we have been passed a context and it has a Done channel, then + // we need to map its cancellation to our Stop method. + go func() { + <-ctx.Done() + m.Stop() + }() + } + m.ctx = klog.NewContext(m.ctx, logger) return &m, nil } @@ -367,7 +384,7 @@ func (m *manager) Current() *tls.Certificate { m.certAccessLock.RLock() defer m.certAccessLock.RUnlock() if m.cert != nil && m.cert.Leaf != nil && m.now().After(m.cert.Leaf.NotAfter) { - m.logf("%s: Current certificate is expired", m.name) + klog.FromContext(m.ctx).V(2).Info("Current certificate is expired") return nil } return m.cert @@ -383,31 +400,35 @@ func (m *manager) ServerHealthy() bool { // Stop terminates the manager. func (m *manager) Stop() { - m.clientAccessLock.Lock() - defer m.clientAccessLock.Unlock() - if m.stopped { - return - } - close(m.stopCh) - m.stopped = true + m.cancel(errors.New("asked to stop")) } // Start will start the background work of rotating the certificates. func (m *manager) Start() { + go m.run() +} + +// run, in contrast to Start, blocks while the manager is running. +// It waits for all goroutines to stop. +func (m *manager) run() { + logger := klog.FromContext(m.ctx) // Certificate rotation depends on access to the API server certificate // signing API, so don't start the certificate manager if we don't have a // client.
if m.clientsetFn == nil { - m.logf("%s: Certificate rotation is not enabled, no connection to the apiserver", m.name) + logger.V(2).Info("Certificate rotation is not enabled, no connection to the apiserver") return } - m.logf("%s: Certificate rotation is enabled", m.name) + logger.V(2).Info("Certificate rotation is enabled") + + var wg sync.WaitGroup + defer wg.Wait() templateChanged := make(chan struct{}) - go wait.Until(func() { - deadline := m.nextRotationDeadline() + rotate := func(ctx context.Context) { + deadline := m.nextRotationDeadline(logger) if sleepInterval := deadline.Sub(m.now()); sleepInterval > 0 { - m.logf("%s: Waiting %v for next certificate rotation", m.name, sleepInterval) + logger.V(2).Info("Waiting for next certificate rotation", "sleep", sleepInterval) timer := time.NewTimer(sleepInterval) defer timer.Stop() @@ -421,7 +442,7 @@ func (m *manager) Start() { // if the template now matches what we last requested, restart the rotation deadline loop return } - m.logf("%s: Certificate template changed, rotating", m.name) + logger.V(2).Info("Certificate template changed, rotating") } } @@ -436,18 +457,24 @@ func (m *manager) Start() { Jitter: 0.1, Steps: 5, } - if err := wait.ExponentialBackoff(backoff, m.rotateCerts); err != nil { - utilruntime.HandleError(fmt.Errorf("%s: Reached backoff limit, still unable to rotate certs: %v", m.name, err)) - wait.PollInfinite(32*time.Second, m.rotateCerts) + if err := wait.ExponentialBackoffWithContext(ctx, backoff, m.rotateCerts); err != nil { + utilruntime.HandleErrorWithContext(ctx, err, "Reached backoff limit, still unable to rotate certs") + wait.PollInfiniteWithContext(ctx, 32*time.Second, m.rotateCerts) } - }, time.Second, m.stopCh) + } + + wg.Add(1) + go func() { + defer wg.Done() + wait.UntilWithContext(m.ctx, rotate, time.Second) + }() if m.dynamicTemplate { - go wait.Until(func() { + template := func(ctx context.Context) { // check if the current template matches what we last requested 
lastRequestCancel, lastRequestTemplate := m.getLastRequest() - if !m.certSatisfiesTemplate() && !reflect.DeepEqual(lastRequestTemplate, m.getTemplate()) { + if !m.certSatisfiesTemplate(logger) && !reflect.DeepEqual(lastRequestTemplate, m.getTemplate()) { // if the template is different, queue up an interrupt of the rotation deadline loop. // if we've requested a CSR that matches the new template by the time the interrupt is handled, the interrupt is disregarded. if lastRequestCancel != nil { @@ -456,14 +483,20 @@ func (m *manager) Start() { } select { case templateChanged <- struct{}{}: - case <-m.stopCh: + case <-ctx.Done(): } } - }, time.Second, m.stopCh) + } + wg.Add(1) + go func() { + defer wg.Done() + wait.UntilWithContext(m.ctx, template, time.Second) + }() } } func getCurrentCertificateOrBootstrap( + logger klog.Logger, store Store, bootstrapCertificatePEM []byte, bootstrapKeyPEM []byte) (cert *tls.Certificate, shouldRotate bool, errResult error) { @@ -494,7 +527,7 @@ func getCurrentCertificateOrBootstrap( certs, err := x509.ParseCertificates(bootstrapCert.Certificate[0]) if err != nil { - return nil, false, fmt.Errorf("unable to parse certificate data: %v", err) + return nil, false, fmt.Errorf("unable to parse certificate data: %w", err) } if len(certs) < 1 { return nil, false, fmt.Errorf("no cert data found") @@ -502,7 +535,7 @@ func getCurrentCertificateOrBootstrap( bootstrapCert.Leaf = certs[0] if _, err := store.Update(bootstrapCertificatePEM, bootstrapKeyPEM); err != nil { - utilruntime.HandleError(fmt.Errorf("unable to set the cert/key pair to the bootstrap certificate: %v", err)) + utilruntime.HandleErrorWithLogger(logger, err, "Unable to set the cert/key pair to the bootstrap certificate") } return &bootstrapCert, true, nil @@ -519,7 +552,7 @@ func (m *manager) getClientset() (clientset.Interface, error) { // Returns true if it changed the cert, false otherwise. Error is only returned in // exceptional cases. 
func (m *manager) RotateCerts() (bool, error) { - return m.rotateCerts() + return m.rotateCerts(m.ctx) } // rotateCerts attempts to request a client cert from the server, wait a reasonable @@ -528,12 +561,13 @@ func (m *manager) RotateCerts() (bool, error) { // This method also keeps track of "server health" by interpreting the responses it gets // from the server on the various calls it makes. // TODO: return errors, have callers handle and log them correctly -func (m *manager) rotateCerts() (bool, error) { - m.logf("%s: Rotating certificates", m.name) +func (m *manager) rotateCerts(ctx context.Context) (bool, error) { + logger := klog.FromContext(ctx) + logger.V(2).Info("Rotating certificates") template, csrPEM, keyPEM, privateKey, err := m.generateCSR() if err != nil { - utilruntime.HandleError(fmt.Errorf("%s: Unable to generate a certificate signing request: %v", m.name, err)) + utilruntime.HandleErrorWithContext(ctx, err, "Unable to generate a certificate signing request") if m.certificateRenewFailure != nil { m.certificateRenewFailure.Inc() } @@ -543,7 +577,7 @@ func (m *manager) rotateCerts() (bool, error) { // request the client each time clientSet, err := m.getClientset() if err != nil { - utilruntime.HandleError(fmt.Errorf("%s: Unable to load a client to request certificates: %v", m.name, err)) + utilruntime.HandleErrorWithContext(ctx, err, "Unable to load a client to request certificates") if m.certificateRenewFailure != nil { m.certificateRenewFailure.Inc() } @@ -557,16 +591,16 @@ func (m *manager) rotateCerts() (bool, error) { usages := getUsages(privateKey) // Call the Certificate Signing Request API to get a certificate for the // new private key - reqName, reqUID, err := csr.RequestCertificate(clientSet, csrPEM, "", m.signerName, m.requestedCertificateLifetime, usages, privateKey) + reqName, reqUID, err := csr.RequestCertificateWithContext(ctx, clientSet, csrPEM, "", m.signerName, m.requestedCertificateLifetime, usages, privateKey) if err != nil { - 
utilruntime.HandleError(fmt.Errorf("%s: Failed while requesting a signed certificate from the control plane: %v", m.name, err)) + utilruntime.HandleErrorWithContext(ctx, err, "Failed while requesting a signed certificate from the control plane") if m.certificateRenewFailure != nil { m.certificateRenewFailure.Inc() } return false, m.updateServerError(err) } - ctx, cancel := context.WithTimeout(context.Background(), certificateWaitTimeout) + ctx, cancel := context.WithTimeout(ctx, certificateWaitTimeout) defer cancel() // Once we've successfully submitted a CSR for this template, record that we did so @@ -576,7 +610,7 @@ func (m *manager) rotateCerts() (bool, error) { // is a remainder after the old design using raw watch wrapped with backoff. crtPEM, err := csr.WaitForCertificate(ctx, clientSet, reqName, reqUID) if err != nil { - utilruntime.HandleError(fmt.Errorf("%s: certificate request was not signed: %v", m.name, err)) + utilruntime.HandleErrorWithContext(ctx, err, "Certificate request was not signed") if m.certificateRenewFailure != nil { m.certificateRenewFailure.Inc() } @@ -585,7 +619,7 @@ func (m *manager) rotateCerts() (bool, error) { cert, err := m.certStore.Update(crtPEM, keyPEM) if err != nil { - utilruntime.HandleError(fmt.Errorf("%s: Unable to store the new cert/key pair: %v", m.name, err)) + utilruntime.HandleErrorWithContext(ctx, err, "Unable to store the new cert/key pair") if m.certificateRenewFailure != nil { m.certificateRenewFailure.Inc() } @@ -606,14 +640,14 @@ func (m *manager) rotateCerts() (bool, error) { // the template will not trigger a renewal. // // Requires certAccessLock to be locked. 
-func (m *manager) certSatisfiesTemplateLocked() bool { +func (m *manager) certSatisfiesTemplateLocked(logger klog.Logger) bool { if m.cert == nil { return false } if template := m.getTemplate(); template != nil { if template.Subject.CommonName != m.cert.Leaf.Subject.CommonName { - m.logf("%s: Current certificate CN (%s) does not match requested CN (%s)", m.name, m.cert.Leaf.Subject.CommonName, template.Subject.CommonName) + logger.V(2).Info("Current certificate CN does not match requested CN", "currentName", m.cert.Leaf.Subject.CommonName, "requestedName", template.Subject.CommonName) return false } @@ -621,7 +655,7 @@ func (m *manager) certSatisfiesTemplateLocked() bool { desiredDNSNames := sets.NewString(template.DNSNames...) missingDNSNames := desiredDNSNames.Difference(currentDNSNames) if len(missingDNSNames) > 0 { - m.logf("%s: Current certificate is missing requested DNS names %v", m.name, missingDNSNames.List()) + logger.V(2).Info("Current certificate is missing requested DNS names", "dnsNames", missingDNSNames.List()) return false } @@ -635,7 +669,7 @@ func (m *manager) certSatisfiesTemplateLocked() bool { } missingIPs := desiredIPs.Difference(currentIPs) if len(missingIPs) > 0 { - m.logf("%s: Current certificate is missing requested IP addresses %v", m.name, missingIPs.List()) + logger.V(2).Info("Current certificate is missing requested IP addresses", "IPs", missingIPs.List()) return false } @@ -643,7 +677,7 @@ func (m *manager) certSatisfiesTemplateLocked() bool { desiredOrgs := sets.NewString(template.Subject.Organization...) 
missingOrgs := desiredOrgs.Difference(currentOrgs) if len(missingOrgs) > 0 { - m.logf("%s: Current certificate is missing requested orgs %v", m.name, missingOrgs.List()) + logger.V(2).Info("Current certificate is missing requested orgs", "orgs", missingOrgs.List()) return false } } @@ -651,16 +685,16 @@ func (m *manager) certSatisfiesTemplateLocked() bool { return true } -func (m *manager) certSatisfiesTemplate() bool { +func (m *manager) certSatisfiesTemplate(logger klog.Logger) bool { m.certAccessLock.RLock() defer m.certAccessLock.RUnlock() - return m.certSatisfiesTemplateLocked() + return m.certSatisfiesTemplateLocked(logger) } // nextRotationDeadline returns a value for the threshold at which the // current certificate should be rotated, 80%+/-10% of the expiration of the // certificate. -func (m *manager) nextRotationDeadline() time.Time { +func (m *manager) nextRotationDeadline(logger klog.Logger) time.Time { // forceRotation is not protected by locks if m.forceRotation { m.forceRotation = false @@ -670,7 +704,7 @@ func (m *manager) nextRotationDeadline() time.Time { m.certAccessLock.RLock() defer m.certAccessLock.RUnlock() - if !m.certSatisfiesTemplateLocked() { + if !m.certSatisfiesTemplateLocked(logger) { return m.now() } @@ -678,7 +712,7 @@ func (m *manager) nextRotationDeadline() time.Time { totalDuration := float64(notAfter.Sub(m.cert.Leaf.NotBefore)) deadline := m.cert.Leaf.NotBefore.Add(jitteryDuration(totalDuration)) - m.logf("%s: Certificate expiration is %v, rotation deadline is %v", m.name, notAfter, deadline) + logger.V(2).Info("Certificate rotation deadline determined", "expiration", notAfter, "deadline", deadline) return deadline } @@ -732,22 +766,22 @@ func (m *manager) generateCSR() (template *x509.CertificateRequest, csrPEM []byt // Generate a new private key. 
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), cryptorand.Reader) if err != nil { - return nil, nil, nil, nil, fmt.Errorf("%s: unable to generate a new private key: %v", m.name, err) + return nil, nil, nil, nil, fmt.Errorf("unable to generate a new private key: %w", err) } der, err := x509.MarshalECPrivateKey(privateKey) if err != nil { - return nil, nil, nil, nil, fmt.Errorf("%s: unable to marshal the new key to DER: %v", m.name, err) + return nil, nil, nil, nil, fmt.Errorf("unable to marshal the new key to DER: %w", err) } keyPEM = pem.EncodeToMemory(&pem.Block{Type: keyutil.ECPrivateKeyBlockType, Bytes: der}) template = m.getTemplate() if template == nil { - return nil, nil, nil, nil, fmt.Errorf("%s: unable to create a csr, no template available", m.name) + return nil, nil, nil, nil, errors.New("unable to create a csr, no template available") } csrPEM, err = cert.MakeCSRFromTemplate(privateKey, template) if err != nil { - return nil, nil, nil, nil, fmt.Errorf("%s: unable to create a csr from the private key: %v", m.name, err) + return nil, nil, nil, nil, fmt.Errorf("unable to create a csr from the private key: %w", err) } return template, csrPEM, keyPEM, privateKey, nil } diff --git a/util/certificate/certificate_manager_test.go b/util/certificate/certificate_manager_test.go index df1fd1d5..867ea0b9 100644 --- a/util/certificate/certificate_manager_test.go +++ b/util/certificate/certificate_manager_test.go @@ -18,15 +18,20 @@ package certificate import ( "bytes" + "context" "crypto/tls" "crypto/x509" "crypto/x509/pkix" + "errors" "fmt" "net" "strings" + "sync" "testing" "time" + "github.com/stretchr/testify/require" + certificatesv1 "k8s.io/api/certificates/v1" certificatesv1beta1 "k8s.io/api/certificates/v1beta1" apierrors "k8s.io/apimachinery/pkg/api/errors" @@ -38,6 +43,8 @@ import ( "k8s.io/client-go/kubernetes/fake" certificatesclient "k8s.io/client-go/kubernetes/typed/certificates/v1beta1" clienttesting "k8s.io/client-go/testing" + "k8s.io/klog/v2" 
+ "k8s.io/klog/v2/ktesting" netutils "k8s.io/utils/net" ) @@ -268,6 +275,7 @@ func TestSetRotationDeadline(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { + logger, ctx := ktesting.NewTestContext(t) m := manager{ cert: &tls.Certificate{ Leaf: &x509.Certificate{ @@ -277,12 +285,12 @@ func TestSetRotationDeadline(t *testing.T) { }, getTemplate: func() *x509.CertificateRequest { return &x509.CertificateRequest{} }, now: func() time.Time { return now }, - logf: t.Logf, + ctx: ctx, } jitteryDuration = func(float64) time.Duration { return time.Duration(float64(tc.notAfter.Sub(tc.notBefore)) * 0.7) } lowerBound := tc.notBefore.Add(time.Duration(float64(tc.notAfter.Sub(tc.notBefore)) * 0.7)) - deadline := m.nextRotationDeadline() + deadline := m.nextRotationDeadline(logger) if !deadline.Equal(lowerBound) { t.Errorf("For notBefore %v, notAfter %v, the rotationDeadline %v should be %v.", @@ -438,6 +446,7 @@ func TestCertSatisfiesTemplate(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { + logger, ctx := ktesting.NewTestContext(t) var tlsCert *tls.Certificate if tc.cert != nil { @@ -450,10 +459,10 @@ func TestCertSatisfiesTemplate(t *testing.T) { cert: tlsCert, getTemplate: func() *x509.CertificateRequest { return tc.template }, now: time.Now, - logf: t.Logf, + ctx: ctx, } - result := m.certSatisfiesTemplate() + result := m.certSatisfiesTemplate(logger) if result != tc.shouldSatisfy { t.Errorf("cert: %+v, template: %+v, certSatisfiesTemplate returned %v, want %v", m.cert, tc.template, result, tc.shouldSatisfy) } @@ -462,6 +471,7 @@ func TestCertSatisfiesTemplate(t *testing.T) { } func TestRotateCertCreateCSRError(t *testing.T) { + _, ctx := ktesting.NewTestContext(t) now := time.Now() m := manager{ cert: &tls.Certificate{ @@ -474,11 +484,11 @@ func TestRotateCertCreateCSRError(t *testing.T) { clientsetFn: func(_ *tls.Certificate) (clientset.Interface, error) { return newClientset(fakeClient{failureType: 
createError}), nil }, - now: func() time.Time { return now }, - logf: t.Logf, + now: func() time.Time { return now }, + ctx: ctx, } - if success, err := m.rotateCerts(); success { + if success, err := m.rotateCerts(ctx); success { t.Errorf("Got success from 'rotateCerts', wanted failure") } else if err != nil { t.Errorf("Got error %v from 'rotateCerts', wanted no error.", err) @@ -486,6 +496,7 @@ func TestRotateCertCreateCSRError(t *testing.T) { } func TestRotateCertWaitingForResultError(t *testing.T) { + _, ctx := ktesting.NewTestContext(t) now := time.Now() m := manager{ cert: &tls.Certificate{ @@ -498,13 +509,13 @@ func TestRotateCertWaitingForResultError(t *testing.T) { clientsetFn: func(_ *tls.Certificate) (clientset.Interface, error) { return newClientset(fakeClient{failureType: watchError}), nil }, - now: func() time.Time { return now }, - logf: t.Logf, + now: func() time.Time { return now }, + ctx: ctx, } defer func(t time.Duration) { certificateWaitTimeout = t }(certificateWaitTimeout) certificateWaitTimeout = 1 * time.Millisecond - if success, err := m.rotateCerts(); success { + if success, err := m.rotateCerts(ctx); success { t.Errorf("Got success from 'rotateCerts', wanted failure.") } else if err != nil { t.Errorf("Got error %v from 'rotateCerts', wanted no error.", err) @@ -610,11 +621,13 @@ func TestGetCurrentCertificateOrBootstrap(t *testing.T) { for _, tc := range testCases { t.Run(tc.description, func(t *testing.T) { + logger, _ := ktesting.NewTestContext(t) store := &fakeStore{ cert: tc.storeCert, } certResult, shouldRotate, err := getCurrentCertificateOrBootstrap( + logger, store, tc.bootstrapCertData, tc.bootstrapKeyData) @@ -717,6 +730,7 @@ func TestInitializeCertificateSigningRequestClient(t *testing.T) { for _, tc := range testCases { t.Run(tc.description, func(t *testing.T) { + _, ctx := ktesting.NewTestContext(t) certificateStore := &fakeStore{ cert: tc.storeCert.certificate, } @@ -744,6 +758,7 @@ func 
TestInitializeCertificateSigningRequestClient(t *testing.T) { certificatePEM: tc.apiCert.certificatePEM, }), nil }, + Ctx: &ctx, }) if err != nil { t.Errorf("Got %v, wanted no error.", err) @@ -764,7 +779,7 @@ func TestInitializeCertificateSigningRequestClient(t *testing.T) { t.Errorf("Expected a '*manager' from 'NewManager'") } else { if m.forceRotation { - if success, err := m.rotateCerts(); !success { + if success, err := m.rotateCerts(ctx); !success { t.Errorf("Got failure from 'rotateCerts', wanted success.") } else if err != nil { t.Errorf("Got error %v, expected none.", err) @@ -832,6 +847,7 @@ func TestInitializeOtherRESTClients(t *testing.T) { for _, tc := range testCases { t.Run(tc.description, func(t *testing.T) { + _, ctx := ktesting.NewTestContext(t) certificateStore := &fakeStore{ cert: tc.storeCert.certificate, } @@ -870,7 +886,7 @@ func TestInitializeOtherRESTClients(t *testing.T) { t.Errorf("Expected a '*manager' from 'NewManager'") } else { if m.forceRotation { - success, err := certificateManager.(*manager).rotateCerts() + success, err := certificateManager.(*manager).rotateCerts(ctx) if err != nil { t.Errorf("Got error %v, expected none.", err) return @@ -977,6 +993,7 @@ func TestServerHealth(t *testing.T) { for _, tc := range testCases { t.Run(tc.description, func(t *testing.T) { + _, ctx := ktesting.NewTestContext(t) certificateStore := &fakeStore{ cert: tc.storeCert.certificate, } @@ -1016,7 +1033,7 @@ func TestServerHealth(t *testing.T) { if _, ok := certificateManager.(*manager); !ok { t.Errorf("Expected a '*manager' from 'NewManager'") } else { - success, err := certificateManager.(*manager).rotateCerts() + success, err := certificateManager.(*manager).rotateCerts(ctx) if err != nil { t.Errorf("Got error %v, expected none.", err) return @@ -1039,6 +1056,7 @@ func TestServerHealth(t *testing.T) { } func TestRotationLogsDuration(t *testing.T) { + _, ctx := ktesting.NewTestContext(t) h := metricMock{} now := time.Now() certIss := now.Add(-2 * 
time.Hour) @@ -1058,9 +1076,9 @@ func TestRotationLogsDuration(t *testing.T) { }, certificateRotation: &h, now: func() time.Time { return now }, - logf: t.Logf, + ctx: ctx, } - ok, err := m.rotateCerts() + ok, err := m.rotateCerts(ctx) if err != nil || !ok { t.Errorf("failed to rotate certs: %v", err) } @@ -1073,6 +1091,101 @@ func TestRotationLogsDuration(t *testing.T) { } +func TestStop(t *testing.T) { + // No certificate yet, will be added while manager runs. + store := &fakeStore{} + m, err := NewManager(&Config{ + GetTemplate: func() *x509.CertificateRequest { + return &x509.CertificateRequest{ + Subject: pkix.Name{ + Organization: []string{"system:nodes"}, + CommonName: "system:node:fake-node-name", + }, + } + }, + Usages: []certificatesv1.KeyUsage{}, + CertificateStore: store, + ClientsetFn: func(_ *tls.Certificate) (clientset.Interface, error) { + return newClientset(fakeClient{ + certificatePEM: apiServerCertData.certificatePEM, + }), nil + }, + }) + require.NoError(t, err, "initialize the certificate manager") + require.Nil(t, m.Current(), "no certificate yet") + + // Run the manager and stop it when the test cleans up. + var wg sync.WaitGroup + defer func() { + t.Log("Waiting for manager to stop...") + wg.Wait() + }() + wg.Add(1) + go func() { + defer wg.Done() + m.(*manager).run() + }() + defer m.Stop() + + require.Eventually(t, func() bool { + return m.Current() != nil + }, 10*time.Second, time.Microsecond, "current certificate") +} + +func TestContext(t *testing.T) { + logger := ktesting.NewLogger(t, ktesting.NewConfig( + ktesting.BufferLogs(true), + ktesting.Verbosity(2), + )) + ctx := klog.NewContext(context.Background(), logger) + ctx, cancel := context.WithCancelCause(ctx) + defer cancel(errors.New("test is done")) + + // No certificate yet, will be added while manager runs. 
+ store := &fakeStore{} + m, err := NewManager(&Config{ + Ctx: &ctx, + GetTemplate: func() *x509.CertificateRequest { + return &x509.CertificateRequest{ + Subject: pkix.Name{ + Organization: []string{"system:nodes"}, + CommonName: "system:node:fake-node-name", + }, + } + }, + Usages: []certificatesv1.KeyUsage{}, + CertificateStore: store, + ClientsetFn: func(_ *tls.Certificate) (clientset.Interface, error) { + return newClientset(fakeClient{ + certificatePEM: apiServerCertData.certificatePEM, + }), nil + }, + }) + require.NoError(t, err, "initialize the certificate manager") + require.Nil(t, m.Current(), "no certificate yet") + + // Run the manager and stop it when the test cleans up. + var wg sync.WaitGroup + defer func() { + t.Log("Waiting for manager to stop...") + wg.Wait() + + // Must be a no-op. + m.Stop() + }() + wg.Add(1) + go func() { + defer wg.Done() + m.(*manager).run() + }() + defer cancel(errors.New("testing context cancellation")) + + require.Eventually(t, func() bool { + return m.Current() != nil + }, 10*time.Second, time.Microsecond, "current certificate") + require.Contains(t, logger.GetSink().(ktesting.Underlier).GetBuffer().String(), "certificate: Rotating certificates", "contextual log output from manager.rotateCerts") +} + type fakeClientFailureType int const ( @@ -1243,10 +1356,14 @@ func (w *fakeWatch) ResultChan() <-chan watch.Event { } type fakeStore struct { - cert *tls.Certificate + mutex sync.Mutex + cert *tls.Certificate } func (s *fakeStore) Current() (*tls.Certificate, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + if s.cert == nil { noKeyErr := NoCertKeyError("") return nil, &noKeyErr @@ -1258,6 +1375,9 @@ func (s *fakeStore) Current() (*tls.Certificate, error) { // pair the 'current' pair, that will be returned by future calls to // Current(). 
func (s *fakeStore) Update(certPEM, keyPEM []byte) (*tls.Certificate, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + // In order to make the mocking work, whenever a cert/key pair is passed in // to be updated in the mock store, assume that the certificate manager // generated the key, and then asked the mock CertificateSigningRequest API diff --git a/util/certificate/certificate_store.go b/util/certificate/certificate_store.go index e7ed58ee..4f9d5db0 100644 --- a/util/certificate/certificate_store.go +++ b/util/certificate/certificate_store.go @@ -38,6 +38,7 @@ const ( ) type fileStore struct { + logger klog.Logger pairNamePrefix string certDirectory string keyDirectory string @@ -67,14 +68,30 @@ type FileStore interface { // updates will be written to the ${certDirectory} directory and // ${certDirectory}/${pairNamePrefix}-current.pem will be created as a soft // link to the currently selected cert/key pair. +// +// Contextual logging: NewFileStoreWithLogger should be used instead of NewFileStore in code which supports contextual logging. func NewFileStore( pairNamePrefix string, certDirectory string, keyDirectory string, certFile string, keyFile string) (FileStore, error) { + return NewFileStoreWithLogger(klog.Background(), pairNamePrefix, certDirectory, keyDirectory, certFile, keyFile) +} + +// NewFileStoreWithLogger is a variant of NewFileStore where the caller is in +// control of logging. All log messages get emitted with logger.Info, so +// pass e.g. logger.V(3) to make logging less verbose. 
+func NewFileStoreWithLogger( + logger klog.Logger, + pairNamePrefix string, + certDirectory string, + keyDirectory string, + certFile string, + keyFile string) (FileStore, error) { s := fileStore{ + logger: logger, pairNamePrefix: pairNamePrefix, certDirectory: certDirectory, keyDirectory: keyDirectory, @@ -127,7 +144,7 @@ func (s *fileStore) Current() (*tls.Certificate, error) { if pairFileExists, err := fileExists(pairFile); err != nil { return nil, err } else if pairFileExists { - klog.Infof("Loading cert/key pair from %q.", pairFile) + s.logger.Info("Loading cert/key pair from a file", "filePath", pairFile) return loadFile(pairFile) } @@ -140,7 +157,7 @@ func (s *fileStore) Current() (*tls.Certificate, error) { return nil, err } if certFileExists && keyFileExists { - klog.Infof("Loading cert/key pair from (%q, %q).", s.certFile, s.keyFile) + s.logger.Info("Loading cert/key pair", "certFile", s.certFile, "keyFile", s.keyFile) return loadX509KeyPair(s.certFile, s.keyFile) } @@ -155,7 +172,7 @@ func (s *fileStore) Current() (*tls.Certificate, error) { return nil, err } if certFileExists && keyFileExists { - klog.Infof("Loading cert/key pair from (%q, %q).", c, k) + s.logger.Info("Loading cert/key pair", "certFile", c, "keyFile", k) return loadX509KeyPair(c, k) } diff --git a/util/certificate/certificate_store_test.go b/util/certificate/certificate_store_test.go index 3d6abaa4..11591520 100644 --- a/util/certificate/certificate_store_test.go +++ b/util/certificate/certificate_store_test.go @@ -235,6 +235,7 @@ func TestUpdateNoRotation(t *testing.T) { t.Fatalf("Unable to create the file %q: %v", certFile, err) } + //nolint:logcheck // Intentionally uses the old API. s, err := NewFileStore(prefix, dir, dir, certFile, keyFile) if err != nil { t.Fatalf("Got %v while creating a new store.", err) @@ -269,6 +270,7 @@ func TestUpdateRotation(t *testing.T) { t.Fatalf("Unable to create the file %q: %v", certFile, err) } + //nolint:logcheck // Intentionally uses the old API. 
s, err := NewFileStore(prefix, dir, dir, certFile, keyFile) if err != nil { t.Fatalf("Got %v while creating a new store.", err) @@ -303,6 +305,7 @@ func TestUpdateTwoCerts(t *testing.T) { t.Fatalf("Unable to create the file %q: %v", certFile, err) } + //nolint:logcheck // Intentionally uses the old API. s, err := NewFileStore(prefix, dir, dir, certFile, keyFile) if err != nil { t.Fatalf("Got %v while creating a new store.", err) @@ -340,6 +343,7 @@ func TestUpdateWithBadCertKeyData(t *testing.T) { t.Fatalf("Unable to create the file %q: %v", certFile, err) } + //nolint:logcheck // Intentionally uses the old API. s, err := NewFileStore(prefix, dir, dir, certFile, keyFile) if err != nil { t.Fatalf("Got %v while creating a new store.", err) @@ -376,6 +380,7 @@ func TestCurrentPairFile(t *testing.T) { t.Fatalf("unable to create a symlink from %q to %q: %v", currentFile, pairFile, err) } + //nolint:logcheck // Intentionally uses the old API. store, err := NewFileStore("kubelet-server", dir, dir, "", "") if err != nil { t.Fatalf("Failed to initialize certificate store: %v", err) @@ -413,6 +418,7 @@ func TestCurrentCertKeyFiles(t *testing.T) { t.Fatalf("Unable to create the file %q: %v", keyFile, err) } + //nolint:logcheck // Intentionally uses the old API. store, err := NewFileStore(prefix, dir, dir, certFile, keyFile) if err != nil { t.Fatalf("Failed to initialize certificate store: %v", err) @@ -450,6 +456,7 @@ func TestCurrentTwoCerts(t *testing.T) { t.Fatalf("Unable to create the file %q: %v", keyFile, err) } + //nolint:logcheck // Intentionally uses the old API. store, err := NewFileStore(prefix, dir, dir, certFile, keyFile) if err != nil { t.Fatalf("Failed to initialize certificate store: %v", err) @@ -481,6 +488,7 @@ func TestCurrentNoFiles(t *testing.T) { } }() + //nolint:logcheck // Intentionally uses the old API. 
store, err := NewFileStore("kubelet-server", dir, dir, "", "") if err != nil { t.Fatalf("Failed to initialize certificate store: %v", err) diff --git a/util/certificate/csr/csr.go b/util/certificate/csr/csr.go index 0390d1c0..a2921ecd 100644 --- a/util/certificate/csr/csr.go +++ b/util/certificate/csr/csr.go @@ -47,7 +47,18 @@ import ( // PEM encoded CSR and send it to API server. An optional requestedDuration may be passed // to set the spec.expirationSeconds field on the CSR to control the lifetime of the issued // certificate. This is not guaranteed as the signer may choose to ignore the request. +// +// Deprecated: use RequestCertificateWithContext instead. func RequestCertificate(client clientset.Interface, csrData []byte, name, signerName string, requestedDuration *time.Duration, usages []certificatesv1.KeyUsage, privateKey interface{}) (reqName string, reqUID types.UID, err error) { + return RequestCertificateWithContext(context.Background(), client, csrData, name, signerName, requestedDuration, usages, privateKey) +} + +// RequestCertificateWithContext will either use an existing (if this process has run +// before but not to completion) or create a certificate signing request using the +// PEM encoded CSR and send it to API server. An optional requestedDuration may be passed +// to set the spec.expirationSeconds field on the CSR to control the lifetime of the issued +// certificate. This is not guaranteed as the signer may choose to ignore the request. +func RequestCertificateWithContext(ctx context.Context, client clientset.Interface, csrData []byte, name, signerName string, requestedDuration *time.Duration, usages []certificatesv1.KeyUsage, privateKey interface{}) (reqName string, reqUID types.UID, err error) { csr := &certificatesv1.CertificateSigningRequest{ // Username, UID, Groups will be injected by API server. 
TypeMeta: metav1.TypeMeta{Kind: "CertificateSigningRequest"}, @@ -67,21 +78,22 @@ func RequestCertificate(client clientset.Interface, csrData []byte, name, signer csr.Spec.ExpirationSeconds = DurationToExpirationSeconds(*requestedDuration) } - reqName, reqUID, err = create(client, csr) + reqName, reqUID, err = create(ctx, client, csr) switch { case err == nil: return reqName, reqUID, err case apierrors.IsAlreadyExists(err) && len(name) > 0: - klog.Infof("csr for this node already exists, reusing") - req, err := get(client, name) + logger := klog.FromContext(ctx) + logger.Info("csr for this node already exists, reusing") + req, err := get(ctx, client, name) if err != nil { return "", "", formatError("cannot retrieve certificate signing request: %v", err) } if err := ensureCompatible(req, csr, privateKey); err != nil { return "", "", fmt.Errorf("retrieved csr is not compatible: %v", err) } - klog.Infof("csr for this node is still valid") + logger.Info("csr for this node is still valid") return req.Name, req.UID, nil default: @@ -97,13 +109,13 @@ func ExpirationSecondsToDuration(expirationSeconds int32) time.Duration { return time.Duration(expirationSeconds) * time.Second } -func get(client clientset.Interface, name string) (*certificatesv1.CertificateSigningRequest, error) { - v1req, v1err := client.CertificatesV1().CertificateSigningRequests().Get(context.TODO(), name, metav1.GetOptions{}) +func get(ctx context.Context, client clientset.Interface, name string) (*certificatesv1.CertificateSigningRequest, error) { + v1req, v1err := client.CertificatesV1().CertificateSigningRequests().Get(ctx, name, metav1.GetOptions{}) if v1err == nil || !apierrors.IsNotFound(v1err) { return v1req, v1err } - v1beta1req, v1beta1err := client.CertificatesV1beta1().CertificateSigningRequests().Get(context.TODO(), name, metav1.GetOptions{}) + v1beta1req, v1beta1err := client.CertificatesV1beta1().CertificateSigningRequests().Get(ctx, name, metav1.GetOptions{}) if v1beta1err != nil { 
return nil, v1beta1err } @@ -123,10 +135,10 @@ func get(client clientset.Interface, name string) (*certificatesv1.CertificateSi return v1req, nil } -func create(client clientset.Interface, csr *certificatesv1.CertificateSigningRequest) (reqName string, reqUID types.UID, err error) { +func create(ctx context.Context, client clientset.Interface, csr *certificatesv1.CertificateSigningRequest) (reqName string, reqUID types.UID, err error) { // only attempt a create via v1 if we specified signerName and usages and are not using the legacy unknown signerName if len(csr.Spec.Usages) > 0 && len(csr.Spec.SignerName) > 0 && csr.Spec.SignerName != "kubernetes.io/legacy-unknown" { - v1req, v1err := client.CertificatesV1().CertificateSigningRequests().Create(context.TODO(), csr, metav1.CreateOptions{}) + v1req, v1err := client.CertificatesV1().CertificateSigningRequests().Create(ctx, csr, metav1.CreateOptions{}) switch { case v1err != nil && apierrors.IsNotFound(v1err): // v1 CSR API was not found, continue to try v1beta1 @@ -154,7 +166,7 @@ func create(client clientset.Interface, csr *certificatesv1.CertificateSigningRe } // create v1beta1 - v1beta1req, v1beta1err := client.CertificatesV1beta1().CertificateSigningRequests().Create(context.TODO(), v1beta1csr, metav1.CreateOptions{}) + v1beta1req, v1beta1err := client.CertificatesV1beta1().CertificateSigningRequests().Create(ctx, v1beta1csr, metav1.CreateOptions{}) if v1beta1err != nil { return "", "", v1beta1err } @@ -164,6 +176,7 @@ func create(client clientset.Interface, csr *certificatesv1.CertificateSigningRe // WaitForCertificate waits for a certificate to be issued until timeout, or returns an error. 
func WaitForCertificate(ctx context.Context, client clientset.Interface, reqName string, reqUID types.UID) (certData []byte, err error) { fieldSelector := fields.OneTermEqualSelector("metadata.name", reqName).String() + logger := klog.FromContext(ctx) var lw *cache.ListWatch var obj runtime.Object @@ -184,7 +197,7 @@ func WaitForCertificate(ctx context.Context, client clientset.Interface, reqName } break } else { - klog.V(2).Infof("error fetching v1 certificate signing request: %v", err) + logger.V(2).Info("Error fetching v1 certificate signing request", "err", err) } // return if we've timed out @@ -208,7 +221,7 @@ func WaitForCertificate(ctx context.Context, client clientset.Interface, reqName } break } else { - klog.V(2).Infof("error fetching v1beta1 certificate signing request: %v", err) + logger.V(2).Info("Error fetching v1beta1 certificate signing request", "err", err) } // return if we've timed out @@ -254,11 +267,11 @@ func WaitForCertificate(ctx context.Context, client clientset.Interface, reqName } if approved { if len(csr.Status.Certificate) > 0 { - klog.V(2).Infof("certificate signing request %s is issued", csr.Name) + logger.V(2).Info("Certificate signing request is issued", "csr", klog.KObj(csr)) issuedCertificate = csr.Status.Certificate return true, nil } - klog.V(2).Infof("certificate signing request %s is approved, waiting to be issued", csr.Name) + logger.V(2).Info("Certificate signing request is approved, waiting to be issued", "csr", klog.KObj(csr)) } case *certificatesv1beta1.CertificateSigningRequest: @@ -279,11 +292,11 @@ func WaitForCertificate(ctx context.Context, client clientset.Interface, reqName } if approved { if len(csr.Status.Certificate) > 0 { - klog.V(2).Infof("certificate signing request %s is issued", csr.Name) + logger.V(2).Info("Certificate signing request is issued", "csr", klog.KObj(csr)) issuedCertificate = csr.Status.Certificate return true, nil } - klog.V(2).Infof("certificate signing request %s is approved, waiting to be 
issued", csr.Name) + logger.V(2).Info("Certificate signing request is approved, waiting to be issued", "csr", klog.KObj(csr)) } default: