client-go certificate: context-aware APIs and logging

For NewManager, the Config struct gets changed (not extended!) so that the
caller can provide a logger instead of just a logging function. Breaking the
API was chosen because it avoids having to maintain two different log calls in
various places (one for printf-style logging, one for structured logging).

RequestCertificateWithContext is an extension. It enables getting rid of
context.TODO calls.

NewFileStoreWithLogger also is an extension.

Kubernetes-commit: f9051901cee8d8ff4aed3db27ff495a706f1a487
This commit is contained in:
Patrick Ohly 2024-12-09 12:45:54 +01:00 committed by Kubernetes Publisher
parent b46275ad75
commit d0f5d55191
5 changed files with 298 additions and 106 deletions

View File

@ -203,13 +203,16 @@ type Config struct {
// certificate renewal failures. // certificate renewal failures.
CertificateRenewFailure Counter CertificateRenewFailure Counter
// Name is an optional string that will be used when writing log output // Name is an optional string that will be used when writing log output
// or returning errors from manager methods. If not set, SignerName will // via logger.WithName or returning errors from manager methods.
//
// If not set, SignerName will
// be used, if SignerName is not set, if Usages includes client auth the // be used, if SignerName is not set, if Usages includes client auth the
// name will be "client auth", otherwise the value will be "server". // name will be "client auth", otherwise the value will be "server".
Name string Name string
// Logf is an optional function that log output will be sent to from the // Ctx is an optional context. Cancelling it is equivalent to
// certificate manager. If not set it will use klog.V(2) // calling Stop. A logger is extracted from it if non-nil, otherwise
Logf func(format string, args ...interface{}) // klog.Background() is used.
Ctx *context.Context
} }
// Store is responsible for getting and updating the current certificate. // Store is responsible for getting and updating the current certificate.
@ -278,30 +281,22 @@ type manager struct {
cert *tls.Certificate cert *tls.Certificate
serverHealth bool serverHealth bool
// Context and cancel function for background goroutines.
ctx context.Context
cancel func(err error)
// the clientFn must only be accessed under the clientAccessLock // the clientFn must only be accessed under the clientAccessLock
clientAccessLock sync.Mutex clientAccessLock sync.Mutex
clientsetFn ClientsetFunc clientsetFn ClientsetFunc
stopCh chan struct{}
stopped bool
// Set to time.Now but can be stubbed out for testing // Set to time.Now but can be stubbed out for testing
now func() time.Time now func() time.Time
name string
logf func(format string, args ...interface{})
} }
// NewManager returns a new certificate manager. A certificate manager is // NewManager returns a new certificate manager. A certificate manager is
// responsible for being the authoritative source of certificates in the // responsible for being the authoritative source of certificates in the
// Kubelet and handling updates due to rotation. // Kubelet and handling updates due to rotation.
func NewManager(config *Config) (Manager, error) { func NewManager(config *Config) (Manager, error) {
cert, forceRotation, err := getCurrentCertificateOrBootstrap(
config.CertificateStore,
config.BootstrapCertificatePEM,
config.BootstrapKeyPEM)
if err != nil {
return nil, err
}
getTemplate := config.GetTemplate getTemplate := config.GetTemplate
if getTemplate == nil { if getTemplate == nil {
@ -321,7 +316,6 @@ func NewManager(config *Config) (Manager, error) {
getUsages = func(interface{}) []certificates.KeyUsage { return config.Usages } getUsages = func(interface{}) []certificates.KeyUsage { return config.Usages }
} }
m := manager{ m := manager{
stopCh: make(chan struct{}),
clientsetFn: config.ClientsetFn, clientsetFn: config.ClientsetFn,
getTemplate: getTemplate, getTemplate: getTemplate,
dynamicTemplate: config.GetTemplate != nil, dynamicTemplate: config.GetTemplate != nil,
@ -329,13 +323,12 @@ func NewManager(config *Config) (Manager, error) {
requestedCertificateLifetime: config.RequestedCertificateLifetime, requestedCertificateLifetime: config.RequestedCertificateLifetime,
getUsages: getUsages, getUsages: getUsages,
certStore: config.CertificateStore, certStore: config.CertificateStore,
cert: cert,
forceRotation: forceRotation,
certificateRotation: config.CertificateRotation, certificateRotation: config.CertificateRotation,
certificateRenewFailure: config.CertificateRenewFailure, certificateRenewFailure: config.CertificateRenewFailure,
now: time.Now, now: time.Now,
} }
// Determine the name that is to be included in log output from this manager instance.
name := config.Name name := config.Name
if len(name) == 0 { if len(name) == 0 {
name = m.signerName name = m.signerName
@ -350,11 +343,35 @@ func NewManager(config *Config) (Manager, error) {
} }
} }
m.name = name // The name gets included through contextual logging.
m.logf = config.Logf logger := klog.Background()
if m.logf == nil { if config.Ctx != nil {
m.logf = func(format string, args ...interface{}) { klog.V(2).Infof(format, args...) } logger = klog.FromContext(*config.Ctx)
} }
logger = klog.LoggerWithName(logger, name)
cert, forceRotation, err := getCurrentCertificateOrBootstrap(
logger,
config.CertificateStore,
config.BootstrapCertificatePEM,
config.BootstrapKeyPEM)
if err != nil {
return nil, err
}
m.cert, m.forceRotation = cert, forceRotation
// cancel will be called by Stop, ctx.Done is our stop channel.
m.ctx, m.cancel = context.WithCancelCause(context.Background())
if config.Ctx != nil && (*config.Ctx).Done() != nil {
ctx := *config.Ctx
// If we have been passed a context and it has a Done channel, then
// we need to map its cancellation to our Done method.
go func() {
<-ctx.Done()
m.Stop()
}()
}
m.ctx = klog.NewContext(m.ctx, logger)
return &m, nil return &m, nil
} }
@ -367,7 +384,7 @@ func (m *manager) Current() *tls.Certificate {
m.certAccessLock.RLock() m.certAccessLock.RLock()
defer m.certAccessLock.RUnlock() defer m.certAccessLock.RUnlock()
if m.cert != nil && m.cert.Leaf != nil && m.now().After(m.cert.Leaf.NotAfter) { if m.cert != nil && m.cert.Leaf != nil && m.now().After(m.cert.Leaf.NotAfter) {
m.logf("%s: Current certificate is expired", m.name) klog.FromContext(m.ctx).V(2).Info("Current certificate is expired")
return nil return nil
} }
return m.cert return m.cert
@ -383,31 +400,35 @@ func (m *manager) ServerHealthy() bool {
// Stop terminates the manager. // Stop terminates the manager.
func (m *manager) Stop() { func (m *manager) Stop() {
m.clientAccessLock.Lock() m.cancel(errors.New("asked to stop"))
defer m.clientAccessLock.Unlock()
if m.stopped {
return
}
close(m.stopCh)
m.stopped = true
} }
// Start will start the background work of rotating the certificates. // Start will start the background work of rotating the certificates.
func (m *manager) Start() { func (m *manager) Start() {
go m.run()
}
// run, in contrast to Start, blocks while the manager is running.
// It waits for all goroutines to stop.
func (m *manager) run() {
logger := klog.FromContext(m.ctx)
// Certificate rotation depends on access to the API server certificate // Certificate rotation depends on access to the API server certificate
// signing API, so don't start the certificate manager if we don't have a // signing API, so don't start the certificate manager if we don't have a
// client. // client.
if m.clientsetFn == nil { if m.clientsetFn == nil {
m.logf("%s: Certificate rotation is not enabled, no connection to the apiserver", m.name) logger.V(2).Info("Certificate rotation is not enabled, no connection to the apiserver")
return return
} }
m.logf("%s: Certificate rotation is enabled", m.name) logger.V(2).Info("Certificate rotation is enabled")
var wg sync.WaitGroup
defer wg.Wait()
templateChanged := make(chan struct{}) templateChanged := make(chan struct{})
go wait.Until(func() { rotate := func(ctx context.Context) {
deadline := m.nextRotationDeadline() deadline := m.nextRotationDeadline(logger)
if sleepInterval := deadline.Sub(m.now()); sleepInterval > 0 { if sleepInterval := deadline.Sub(m.now()); sleepInterval > 0 {
m.logf("%s: Waiting %v for next certificate rotation", m.name, sleepInterval) logger.V(2).Info("Waiting for next certificate rotation", "sleep", sleepInterval)
timer := time.NewTimer(sleepInterval) timer := time.NewTimer(sleepInterval)
defer timer.Stop() defer timer.Stop()
@ -421,7 +442,7 @@ func (m *manager) Start() {
// if the template now matches what we last requested, restart the rotation deadline loop // if the template now matches what we last requested, restart the rotation deadline loop
return return
} }
m.logf("%s: Certificate template changed, rotating", m.name) logger.V(2).Info("Certificate template changed, rotating")
} }
} }
@ -436,18 +457,24 @@ func (m *manager) Start() {
Jitter: 0.1, Jitter: 0.1,
Steps: 5, Steps: 5,
} }
if err := wait.ExponentialBackoff(backoff, m.rotateCerts); err != nil { if err := wait.ExponentialBackoffWithContext(ctx, backoff, m.rotateCerts); err != nil {
utilruntime.HandleError(fmt.Errorf("%s: Reached backoff limit, still unable to rotate certs: %v", m.name, err)) utilruntime.HandleErrorWithContext(ctx, err, "Reached backoff limit, still unable to rotate certs")
wait.PollInfinite(32*time.Second, m.rotateCerts) wait.PollInfiniteWithContext(ctx, 32*time.Second, m.rotateCerts)
} }
}, time.Second, m.stopCh) }
wg.Add(1)
go func() {
defer wg.Done()
wait.UntilWithContext(m.ctx, rotate, time.Second)
}()
if m.dynamicTemplate { if m.dynamicTemplate {
go wait.Until(func() { template := func(ctx context.Context) {
// check if the current template matches what we last requested // check if the current template matches what we last requested
lastRequestCancel, lastRequestTemplate := m.getLastRequest() lastRequestCancel, lastRequestTemplate := m.getLastRequest()
if !m.certSatisfiesTemplate() && !reflect.DeepEqual(lastRequestTemplate, m.getTemplate()) { if !m.certSatisfiesTemplate(logger) && !reflect.DeepEqual(lastRequestTemplate, m.getTemplate()) {
// if the template is different, queue up an interrupt of the rotation deadline loop. // if the template is different, queue up an interrupt of the rotation deadline loop.
// if we've requested a CSR that matches the new template by the time the interrupt is handled, the interrupt is disregarded. // if we've requested a CSR that matches the new template by the time the interrupt is handled, the interrupt is disregarded.
if lastRequestCancel != nil { if lastRequestCancel != nil {
@ -456,14 +483,20 @@ func (m *manager) Start() {
} }
select { select {
case templateChanged <- struct{}{}: case templateChanged <- struct{}{}:
case <-m.stopCh: case <-ctx.Done():
} }
} }
}, time.Second, m.stopCh) }
wg.Add(1)
go func() {
defer wg.Done()
wait.UntilWithContext(m.ctx, template, time.Second)
}()
} }
} }
func getCurrentCertificateOrBootstrap( func getCurrentCertificateOrBootstrap(
logger klog.Logger,
store Store, store Store,
bootstrapCertificatePEM []byte, bootstrapCertificatePEM []byte,
bootstrapKeyPEM []byte) (cert *tls.Certificate, shouldRotate bool, errResult error) { bootstrapKeyPEM []byte) (cert *tls.Certificate, shouldRotate bool, errResult error) {
@ -494,7 +527,7 @@ func getCurrentCertificateOrBootstrap(
certs, err := x509.ParseCertificates(bootstrapCert.Certificate[0]) certs, err := x509.ParseCertificates(bootstrapCert.Certificate[0])
if err != nil { if err != nil {
return nil, false, fmt.Errorf("unable to parse certificate data: %v", err) return nil, false, fmt.Errorf("unable to parse certificate data: %w", err)
} }
if len(certs) < 1 { if len(certs) < 1 {
return nil, false, fmt.Errorf("no cert data found") return nil, false, fmt.Errorf("no cert data found")
@ -502,7 +535,7 @@ func getCurrentCertificateOrBootstrap(
bootstrapCert.Leaf = certs[0] bootstrapCert.Leaf = certs[0]
if _, err := store.Update(bootstrapCertificatePEM, bootstrapKeyPEM); err != nil { if _, err := store.Update(bootstrapCertificatePEM, bootstrapKeyPEM); err != nil {
utilruntime.HandleError(fmt.Errorf("unable to set the cert/key pair to the bootstrap certificate: %v", err)) utilruntime.HandleErrorWithLogger(logger, err, "Unable to set the cert/key pair to the bootstrap certificate")
} }
return &bootstrapCert, true, nil return &bootstrapCert, true, nil
@ -519,7 +552,7 @@ func (m *manager) getClientset() (clientset.Interface, error) {
// Returns true if it changed the cert, false otherwise. Error is only returned in // Returns true if it changed the cert, false otherwise. Error is only returned in
// exceptional cases. // exceptional cases.
func (m *manager) RotateCerts() (bool, error) { func (m *manager) RotateCerts() (bool, error) {
return m.rotateCerts() return m.rotateCerts(m.ctx)
} }
// rotateCerts attempts to request a client cert from the server, wait a reasonable // rotateCerts attempts to request a client cert from the server, wait a reasonable
@ -528,12 +561,13 @@ func (m *manager) RotateCerts() (bool, error) {
// This method also keeps track of "server health" by interpreting the responses it gets // This method also keeps track of "server health" by interpreting the responses it gets
// from the server on the various calls it makes. // from the server on the various calls it makes.
// TODO: return errors, have callers handle and log them correctly // TODO: return errors, have callers handle and log them correctly
func (m *manager) rotateCerts() (bool, error) { func (m *manager) rotateCerts(ctx context.Context) (bool, error) {
m.logf("%s: Rotating certificates", m.name) logger := klog.FromContext(ctx)
logger.V(2).Info("Rotating certificates")
template, csrPEM, keyPEM, privateKey, err := m.generateCSR() template, csrPEM, keyPEM, privateKey, err := m.generateCSR()
if err != nil { if err != nil {
utilruntime.HandleError(fmt.Errorf("%s: Unable to generate a certificate signing request: %v", m.name, err)) utilruntime.HandleErrorWithContext(ctx, err, "Unable to generate a certificate signing request")
if m.certificateRenewFailure != nil { if m.certificateRenewFailure != nil {
m.certificateRenewFailure.Inc() m.certificateRenewFailure.Inc()
} }
@ -543,7 +577,7 @@ func (m *manager) rotateCerts() (bool, error) {
// request the client each time // request the client each time
clientSet, err := m.getClientset() clientSet, err := m.getClientset()
if err != nil { if err != nil {
utilruntime.HandleError(fmt.Errorf("%s: Unable to load a client to request certificates: %v", m.name, err)) utilruntime.HandleErrorWithContext(ctx, err, "Unable to load a client to request certificates")
if m.certificateRenewFailure != nil { if m.certificateRenewFailure != nil {
m.certificateRenewFailure.Inc() m.certificateRenewFailure.Inc()
} }
@ -557,16 +591,16 @@ func (m *manager) rotateCerts() (bool, error) {
usages := getUsages(privateKey) usages := getUsages(privateKey)
// Call the Certificate Signing Request API to get a certificate for the // Call the Certificate Signing Request API to get a certificate for the
// new private key // new private key
reqName, reqUID, err := csr.RequestCertificate(clientSet, csrPEM, "", m.signerName, m.requestedCertificateLifetime, usages, privateKey) reqName, reqUID, err := csr.RequestCertificateWithContext(ctx, clientSet, csrPEM, "", m.signerName, m.requestedCertificateLifetime, usages, privateKey)
if err != nil { if err != nil {
utilruntime.HandleError(fmt.Errorf("%s: Failed while requesting a signed certificate from the control plane: %v", m.name, err)) utilruntime.HandleErrorWithContext(ctx, err, "Failed while requesting a signed certificate from the control plane")
if m.certificateRenewFailure != nil { if m.certificateRenewFailure != nil {
m.certificateRenewFailure.Inc() m.certificateRenewFailure.Inc()
} }
return false, m.updateServerError(err) return false, m.updateServerError(err)
} }
ctx, cancel := context.WithTimeout(context.Background(), certificateWaitTimeout) ctx, cancel := context.WithTimeout(ctx, certificateWaitTimeout)
defer cancel() defer cancel()
// Once we've successfully submitted a CSR for this template, record that we did so // Once we've successfully submitted a CSR for this template, record that we did so
@ -576,7 +610,7 @@ func (m *manager) rotateCerts() (bool, error) {
// is a remainder after the old design using raw watch wrapped with backoff. // is a remainder after the old design using raw watch wrapped with backoff.
crtPEM, err := csr.WaitForCertificate(ctx, clientSet, reqName, reqUID) crtPEM, err := csr.WaitForCertificate(ctx, clientSet, reqName, reqUID)
if err != nil { if err != nil {
utilruntime.HandleError(fmt.Errorf("%s: certificate request was not signed: %v", m.name, err)) utilruntime.HandleErrorWithContext(ctx, err, "Certificate request was not signed")
if m.certificateRenewFailure != nil { if m.certificateRenewFailure != nil {
m.certificateRenewFailure.Inc() m.certificateRenewFailure.Inc()
} }
@ -585,7 +619,7 @@ func (m *manager) rotateCerts() (bool, error) {
cert, err := m.certStore.Update(crtPEM, keyPEM) cert, err := m.certStore.Update(crtPEM, keyPEM)
if err != nil { if err != nil {
utilruntime.HandleError(fmt.Errorf("%s: Unable to store the new cert/key pair: %v", m.name, err)) utilruntime.HandleErrorWithContext(ctx, err, "Unable to store the new cert/key pair")
if m.certificateRenewFailure != nil { if m.certificateRenewFailure != nil {
m.certificateRenewFailure.Inc() m.certificateRenewFailure.Inc()
} }
@ -606,14 +640,14 @@ func (m *manager) rotateCerts() (bool, error) {
// the template will not trigger a renewal. // the template will not trigger a renewal.
// //
// Requires certAccessLock to be locked. // Requires certAccessLock to be locked.
func (m *manager) certSatisfiesTemplateLocked() bool { func (m *manager) certSatisfiesTemplateLocked(logger klog.Logger) bool {
if m.cert == nil { if m.cert == nil {
return false return false
} }
if template := m.getTemplate(); template != nil { if template := m.getTemplate(); template != nil {
if template.Subject.CommonName != m.cert.Leaf.Subject.CommonName { if template.Subject.CommonName != m.cert.Leaf.Subject.CommonName {
m.logf("%s: Current certificate CN (%s) does not match requested CN (%s)", m.name, m.cert.Leaf.Subject.CommonName, template.Subject.CommonName) logger.V(2).Info("Current certificate CN does not match requested CN", "currentName", m.cert.Leaf.Subject.CommonName, "requestedName", template.Subject.CommonName)
return false return false
} }
@ -621,7 +655,7 @@ func (m *manager) certSatisfiesTemplateLocked() bool {
desiredDNSNames := sets.NewString(template.DNSNames...) desiredDNSNames := sets.NewString(template.DNSNames...)
missingDNSNames := desiredDNSNames.Difference(currentDNSNames) missingDNSNames := desiredDNSNames.Difference(currentDNSNames)
if len(missingDNSNames) > 0 { if len(missingDNSNames) > 0 {
m.logf("%s: Current certificate is missing requested DNS names %v", m.name, missingDNSNames.List()) logger.V(2).Info("Current certificate is missing requested DNS names", "dnsNames", missingDNSNames.List())
return false return false
} }
@ -635,7 +669,7 @@ func (m *manager) certSatisfiesTemplateLocked() bool {
} }
missingIPs := desiredIPs.Difference(currentIPs) missingIPs := desiredIPs.Difference(currentIPs)
if len(missingIPs) > 0 { if len(missingIPs) > 0 {
m.logf("%s: Current certificate is missing requested IP addresses %v", m.name, missingIPs.List()) logger.V(2).Info("Current certificate is missing requested IP addresses", "IPs", missingIPs.List())
return false return false
} }
@ -643,7 +677,7 @@ func (m *manager) certSatisfiesTemplateLocked() bool {
desiredOrgs := sets.NewString(template.Subject.Organization...) desiredOrgs := sets.NewString(template.Subject.Organization...)
missingOrgs := desiredOrgs.Difference(currentOrgs) missingOrgs := desiredOrgs.Difference(currentOrgs)
if len(missingOrgs) > 0 { if len(missingOrgs) > 0 {
m.logf("%s: Current certificate is missing requested orgs %v", m.name, missingOrgs.List()) logger.V(2).Info("Current certificate is missing requested orgs", "orgs", missingOrgs.List())
return false return false
} }
} }
@ -651,16 +685,16 @@ func (m *manager) certSatisfiesTemplateLocked() bool {
return true return true
} }
func (m *manager) certSatisfiesTemplate() bool { func (m *manager) certSatisfiesTemplate(logger klog.Logger) bool {
m.certAccessLock.RLock() m.certAccessLock.RLock()
defer m.certAccessLock.RUnlock() defer m.certAccessLock.RUnlock()
return m.certSatisfiesTemplateLocked() return m.certSatisfiesTemplateLocked(logger)
} }
// nextRotationDeadline returns a value for the threshold at which the // nextRotationDeadline returns a value for the threshold at which the
// current certificate should be rotated, 80%+/-10% of the expiration of the // current certificate should be rotated, 80%+/-10% of the expiration of the
// certificate. // certificate.
func (m *manager) nextRotationDeadline() time.Time { func (m *manager) nextRotationDeadline(logger klog.Logger) time.Time {
// forceRotation is not protected by locks // forceRotation is not protected by locks
if m.forceRotation { if m.forceRotation {
m.forceRotation = false m.forceRotation = false
@ -670,7 +704,7 @@ func (m *manager) nextRotationDeadline() time.Time {
m.certAccessLock.RLock() m.certAccessLock.RLock()
defer m.certAccessLock.RUnlock() defer m.certAccessLock.RUnlock()
if !m.certSatisfiesTemplateLocked() { if !m.certSatisfiesTemplateLocked(logger) {
return m.now() return m.now()
} }
@ -678,7 +712,7 @@ func (m *manager) nextRotationDeadline() time.Time {
totalDuration := float64(notAfter.Sub(m.cert.Leaf.NotBefore)) totalDuration := float64(notAfter.Sub(m.cert.Leaf.NotBefore))
deadline := m.cert.Leaf.NotBefore.Add(jitteryDuration(totalDuration)) deadline := m.cert.Leaf.NotBefore.Add(jitteryDuration(totalDuration))
m.logf("%s: Certificate expiration is %v, rotation deadline is %v", m.name, notAfter, deadline) logger.V(2).Info("Certificate rotation deadline determined", "expiration", notAfter, "deadline", deadline)
return deadline return deadline
} }
@ -732,22 +766,22 @@ func (m *manager) generateCSR() (template *x509.CertificateRequest, csrPEM []byt
// Generate a new private key. // Generate a new private key.
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), cryptorand.Reader) privateKey, err := ecdsa.GenerateKey(elliptic.P256(), cryptorand.Reader)
if err != nil { if err != nil {
return nil, nil, nil, nil, fmt.Errorf("%s: unable to generate a new private key: %v", m.name, err) return nil, nil, nil, nil, fmt.Errorf("unable to generate a new private key: %w", err)
} }
der, err := x509.MarshalECPrivateKey(privateKey) der, err := x509.MarshalECPrivateKey(privateKey)
if err != nil { if err != nil {
return nil, nil, nil, nil, fmt.Errorf("%s: unable to marshal the new key to DER: %v", m.name, err) return nil, nil, nil, nil, fmt.Errorf("unable to marshal the new key to DER: %w", err)
} }
keyPEM = pem.EncodeToMemory(&pem.Block{Type: keyutil.ECPrivateKeyBlockType, Bytes: der}) keyPEM = pem.EncodeToMemory(&pem.Block{Type: keyutil.ECPrivateKeyBlockType, Bytes: der})
template = m.getTemplate() template = m.getTemplate()
if template == nil { if template == nil {
return nil, nil, nil, nil, fmt.Errorf("%s: unable to create a csr, no template available", m.name) return nil, nil, nil, nil, errors.New("unable to create a csr, no template available")
} }
csrPEM, err = cert.MakeCSRFromTemplate(privateKey, template) csrPEM, err = cert.MakeCSRFromTemplate(privateKey, template)
if err != nil { if err != nil {
return nil, nil, nil, nil, fmt.Errorf("%s: unable to create a csr from the private key: %v", m.name, err) return nil, nil, nil, nil, fmt.Errorf("unable to create a csr from the private key: %w", err)
} }
return template, csrPEM, keyPEM, privateKey, nil return template, csrPEM, keyPEM, privateKey, nil
} }

View File

@ -18,15 +18,20 @@ package certificate
import ( import (
"bytes" "bytes"
"context"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"crypto/x509/pkix" "crypto/x509/pkix"
"errors"
"fmt" "fmt"
"net" "net"
"strings" "strings"
"sync"
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/require"
certificatesv1 "k8s.io/api/certificates/v1" certificatesv1 "k8s.io/api/certificates/v1"
certificatesv1beta1 "k8s.io/api/certificates/v1beta1" certificatesv1beta1 "k8s.io/api/certificates/v1beta1"
apierrors "k8s.io/apimachinery/pkg/api/errors" apierrors "k8s.io/apimachinery/pkg/api/errors"
@ -38,6 +43,8 @@ import (
"k8s.io/client-go/kubernetes/fake" "k8s.io/client-go/kubernetes/fake"
certificatesclient "k8s.io/client-go/kubernetes/typed/certificates/v1beta1" certificatesclient "k8s.io/client-go/kubernetes/typed/certificates/v1beta1"
clienttesting "k8s.io/client-go/testing" clienttesting "k8s.io/client-go/testing"
"k8s.io/klog/v2"
"k8s.io/klog/v2/ktesting"
netutils "k8s.io/utils/net" netutils "k8s.io/utils/net"
) )
@ -268,6 +275,7 @@ func TestSetRotationDeadline(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
logger, ctx := ktesting.NewTestContext(t)
m := manager{ m := manager{
cert: &tls.Certificate{ cert: &tls.Certificate{
Leaf: &x509.Certificate{ Leaf: &x509.Certificate{
@ -277,12 +285,12 @@ func TestSetRotationDeadline(t *testing.T) {
}, },
getTemplate: func() *x509.CertificateRequest { return &x509.CertificateRequest{} }, getTemplate: func() *x509.CertificateRequest { return &x509.CertificateRequest{} },
now: func() time.Time { return now }, now: func() time.Time { return now },
logf: t.Logf, ctx: ctx,
} }
jitteryDuration = func(float64) time.Duration { return time.Duration(float64(tc.notAfter.Sub(tc.notBefore)) * 0.7) } jitteryDuration = func(float64) time.Duration { return time.Duration(float64(tc.notAfter.Sub(tc.notBefore)) * 0.7) }
lowerBound := tc.notBefore.Add(time.Duration(float64(tc.notAfter.Sub(tc.notBefore)) * 0.7)) lowerBound := tc.notBefore.Add(time.Duration(float64(tc.notAfter.Sub(tc.notBefore)) * 0.7))
deadline := m.nextRotationDeadline() deadline := m.nextRotationDeadline(logger)
if !deadline.Equal(lowerBound) { if !deadline.Equal(lowerBound) {
t.Errorf("For notBefore %v, notAfter %v, the rotationDeadline %v should be %v.", t.Errorf("For notBefore %v, notAfter %v, the rotationDeadline %v should be %v.",
@ -438,6 +446,7 @@ func TestCertSatisfiesTemplate(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
logger, ctx := ktesting.NewTestContext(t)
var tlsCert *tls.Certificate var tlsCert *tls.Certificate
if tc.cert != nil { if tc.cert != nil {
@ -450,10 +459,10 @@ func TestCertSatisfiesTemplate(t *testing.T) {
cert: tlsCert, cert: tlsCert,
getTemplate: func() *x509.CertificateRequest { return tc.template }, getTemplate: func() *x509.CertificateRequest { return tc.template },
now: time.Now, now: time.Now,
logf: t.Logf, ctx: ctx,
} }
result := m.certSatisfiesTemplate() result := m.certSatisfiesTemplate(logger)
if result != tc.shouldSatisfy { if result != tc.shouldSatisfy {
t.Errorf("cert: %+v, template: %+v, certSatisfiesTemplate returned %v, want %v", m.cert, tc.template, result, tc.shouldSatisfy) t.Errorf("cert: %+v, template: %+v, certSatisfiesTemplate returned %v, want %v", m.cert, tc.template, result, tc.shouldSatisfy)
} }
@ -462,6 +471,7 @@ func TestCertSatisfiesTemplate(t *testing.T) {
} }
func TestRotateCertCreateCSRError(t *testing.T) { func TestRotateCertCreateCSRError(t *testing.T) {
_, ctx := ktesting.NewTestContext(t)
now := time.Now() now := time.Now()
m := manager{ m := manager{
cert: &tls.Certificate{ cert: &tls.Certificate{
@ -474,11 +484,11 @@ func TestRotateCertCreateCSRError(t *testing.T) {
clientsetFn: func(_ *tls.Certificate) (clientset.Interface, error) { clientsetFn: func(_ *tls.Certificate) (clientset.Interface, error) {
return newClientset(fakeClient{failureType: createError}), nil return newClientset(fakeClient{failureType: createError}), nil
}, },
now: func() time.Time { return now }, now: func() time.Time { return now },
logf: t.Logf, ctx: ctx,
} }
if success, err := m.rotateCerts(); success { if success, err := m.rotateCerts(ctx); success {
t.Errorf("Got success from 'rotateCerts', wanted failure") t.Errorf("Got success from 'rotateCerts', wanted failure")
} else if err != nil { } else if err != nil {
t.Errorf("Got error %v from 'rotateCerts', wanted no error.", err) t.Errorf("Got error %v from 'rotateCerts', wanted no error.", err)
@ -486,6 +496,7 @@ func TestRotateCertCreateCSRError(t *testing.T) {
} }
func TestRotateCertWaitingForResultError(t *testing.T) { func TestRotateCertWaitingForResultError(t *testing.T) {
_, ctx := ktesting.NewTestContext(t)
now := time.Now() now := time.Now()
m := manager{ m := manager{
cert: &tls.Certificate{ cert: &tls.Certificate{
@ -498,13 +509,13 @@ func TestRotateCertWaitingForResultError(t *testing.T) {
clientsetFn: func(_ *tls.Certificate) (clientset.Interface, error) { clientsetFn: func(_ *tls.Certificate) (clientset.Interface, error) {
return newClientset(fakeClient{failureType: watchError}), nil return newClientset(fakeClient{failureType: watchError}), nil
}, },
now: func() time.Time { return now }, now: func() time.Time { return now },
logf: t.Logf, ctx: ctx,
} }
defer func(t time.Duration) { certificateWaitTimeout = t }(certificateWaitTimeout) defer func(t time.Duration) { certificateWaitTimeout = t }(certificateWaitTimeout)
certificateWaitTimeout = 1 * time.Millisecond certificateWaitTimeout = 1 * time.Millisecond
if success, err := m.rotateCerts(); success { if success, err := m.rotateCerts(ctx); success {
t.Errorf("Got success from 'rotateCerts', wanted failure.") t.Errorf("Got success from 'rotateCerts', wanted failure.")
} else if err != nil { } else if err != nil {
t.Errorf("Got error %v from 'rotateCerts', wanted no error.", err) t.Errorf("Got error %v from 'rotateCerts', wanted no error.", err)
@ -610,11 +621,13 @@ func TestGetCurrentCertificateOrBootstrap(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) { t.Run(tc.description, func(t *testing.T) {
logger, _ := ktesting.NewTestContext(t)
store := &fakeStore{ store := &fakeStore{
cert: tc.storeCert, cert: tc.storeCert,
} }
certResult, shouldRotate, err := getCurrentCertificateOrBootstrap( certResult, shouldRotate, err := getCurrentCertificateOrBootstrap(
logger,
store, store,
tc.bootstrapCertData, tc.bootstrapCertData,
tc.bootstrapKeyData) tc.bootstrapKeyData)
@ -717,6 +730,7 @@ func TestInitializeCertificateSigningRequestClient(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) { t.Run(tc.description, func(t *testing.T) {
_, ctx := ktesting.NewTestContext(t)
certificateStore := &fakeStore{ certificateStore := &fakeStore{
cert: tc.storeCert.certificate, cert: tc.storeCert.certificate,
} }
@ -744,6 +758,7 @@ func TestInitializeCertificateSigningRequestClient(t *testing.T) {
certificatePEM: tc.apiCert.certificatePEM, certificatePEM: tc.apiCert.certificatePEM,
}), nil }), nil
}, },
Ctx: &ctx,
}) })
if err != nil { if err != nil {
t.Errorf("Got %v, wanted no error.", err) t.Errorf("Got %v, wanted no error.", err)
@ -764,7 +779,7 @@ func TestInitializeCertificateSigningRequestClient(t *testing.T) {
t.Errorf("Expected a '*manager' from 'NewManager'") t.Errorf("Expected a '*manager' from 'NewManager'")
} else { } else {
if m.forceRotation { if m.forceRotation {
if success, err := m.rotateCerts(); !success { if success, err := m.rotateCerts(ctx); !success {
t.Errorf("Got failure from 'rotateCerts', wanted success.") t.Errorf("Got failure from 'rotateCerts', wanted success.")
} else if err != nil { } else if err != nil {
t.Errorf("Got error %v, expected none.", err) t.Errorf("Got error %v, expected none.", err)
@ -832,6 +847,7 @@ func TestInitializeOtherRESTClients(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) { t.Run(tc.description, func(t *testing.T) {
_, ctx := ktesting.NewTestContext(t)
certificateStore := &fakeStore{ certificateStore := &fakeStore{
cert: tc.storeCert.certificate, cert: tc.storeCert.certificate,
} }
@ -870,7 +886,7 @@ func TestInitializeOtherRESTClients(t *testing.T) {
t.Errorf("Expected a '*manager' from 'NewManager'") t.Errorf("Expected a '*manager' from 'NewManager'")
} else { } else {
if m.forceRotation { if m.forceRotation {
success, err := certificateManager.(*manager).rotateCerts() success, err := certificateManager.(*manager).rotateCerts(ctx)
if err != nil { if err != nil {
t.Errorf("Got error %v, expected none.", err) t.Errorf("Got error %v, expected none.", err)
return return
@ -977,6 +993,7 @@ func TestServerHealth(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) { t.Run(tc.description, func(t *testing.T) {
_, ctx := ktesting.NewTestContext(t)
certificateStore := &fakeStore{ certificateStore := &fakeStore{
cert: tc.storeCert.certificate, cert: tc.storeCert.certificate,
} }
@ -1016,7 +1033,7 @@ func TestServerHealth(t *testing.T) {
if _, ok := certificateManager.(*manager); !ok { if _, ok := certificateManager.(*manager); !ok {
t.Errorf("Expected a '*manager' from 'NewManager'") t.Errorf("Expected a '*manager' from 'NewManager'")
} else { } else {
success, err := certificateManager.(*manager).rotateCerts() success, err := certificateManager.(*manager).rotateCerts(ctx)
if err != nil { if err != nil {
t.Errorf("Got error %v, expected none.", err) t.Errorf("Got error %v, expected none.", err)
return return
@ -1039,6 +1056,7 @@ func TestServerHealth(t *testing.T) {
} }
func TestRotationLogsDuration(t *testing.T) { func TestRotationLogsDuration(t *testing.T) {
_, ctx := ktesting.NewTestContext(t)
h := metricMock{} h := metricMock{}
now := time.Now() now := time.Now()
certIss := now.Add(-2 * time.Hour) certIss := now.Add(-2 * time.Hour)
@ -1058,9 +1076,9 @@ func TestRotationLogsDuration(t *testing.T) {
}, },
certificateRotation: &h, certificateRotation: &h,
now: func() time.Time { return now }, now: func() time.Time { return now },
logf: t.Logf, ctx: ctx,
} }
ok, err := m.rotateCerts() ok, err := m.rotateCerts(ctx)
if err != nil || !ok { if err != nil || !ok {
t.Errorf("failed to rotate certs: %v", err) t.Errorf("failed to rotate certs: %v", err)
} }
@ -1073,6 +1091,101 @@ func TestRotationLogsDuration(t *testing.T) {
} }
func TestStop(t *testing.T) {
// No certificate yet, will be added while manager runs.
store := &fakeStore{}
m, err := NewManager(&Config{
GetTemplate: func() *x509.CertificateRequest {
return &x509.CertificateRequest{
Subject: pkix.Name{
Organization: []string{"system:nodes"},
CommonName: "system:node:fake-node-name",
},
}
},
Usages: []certificatesv1.KeyUsage{},
CertificateStore: store,
ClientsetFn: func(_ *tls.Certificate) (clientset.Interface, error) {
return newClientset(fakeClient{
certificatePEM: apiServerCertData.certificatePEM,
}), nil
},
})
require.NoError(t, err, "initialize the certificate manager")
require.Nil(t, m.Current(), "no certificate yet")
// Run the manager and stop it when the test cleans up.
var wg sync.WaitGroup
defer func() {
t.Log("Waiting for manager to stop...")
wg.Wait()
}()
wg.Add(1)
go func() {
defer wg.Done()
m.(*manager).run()
}()
defer m.Stop()
require.Eventually(t, func() bool {
return m.Current() != nil
}, 10*time.Second, time.Microsecond, "current certificate")
}
func TestContext(t *testing.T) {
logger := ktesting.NewLogger(t, ktesting.NewConfig(
ktesting.BufferLogs(true),
ktesting.Verbosity(2),
))
ctx := klog.NewContext(context.Background(), logger)
ctx, cancel := context.WithCancelCause(ctx)
defer cancel(errors.New("test is done"))
// No certificate yet, will be added while manager runs.
store := &fakeStore{}
m, err := NewManager(&Config{
Ctx: &ctx,
GetTemplate: func() *x509.CertificateRequest {
return &x509.CertificateRequest{
Subject: pkix.Name{
Organization: []string{"system:nodes"},
CommonName: "system:node:fake-node-name",
},
}
},
Usages: []certificatesv1.KeyUsage{},
CertificateStore: store,
ClientsetFn: func(_ *tls.Certificate) (clientset.Interface, error) {
return newClientset(fakeClient{
certificatePEM: apiServerCertData.certificatePEM,
}), nil
},
})
require.NoError(t, err, "initialize the certificate manager")
require.Nil(t, m.Current(), "no certificate yet")
// Run the manager and stop it when the test cleans up.
var wg sync.WaitGroup
defer func() {
t.Log("Waiting for manager to stop...")
wg.Wait()
// Must be a no-op.
m.Stop()
}()
wg.Add(1)
go func() {
defer wg.Done()
m.(*manager).run()
}()
defer cancel(errors.New("testing context cancellation"))
require.Eventually(t, func() bool {
return m.Current() != nil
}, 10*time.Second, time.Microsecond, "current certificate")
require.Contains(t, logger.GetSink().(ktesting.Underlier).GetBuffer().String(), "certificate: Rotating certificates", "contextual log output from manager.rotateCerts")
}
type fakeClientFailureType int type fakeClientFailureType int
const ( const (
@ -1243,10 +1356,14 @@ func (w *fakeWatch) ResultChan() <-chan watch.Event {
} }
type fakeStore struct { type fakeStore struct {
cert *tls.Certificate mutex sync.Mutex
cert *tls.Certificate
} }
func (s *fakeStore) Current() (*tls.Certificate, error) { func (s *fakeStore) Current() (*tls.Certificate, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.cert == nil { if s.cert == nil {
noKeyErr := NoCertKeyError("") noKeyErr := NoCertKeyError("")
return nil, &noKeyErr return nil, &noKeyErr
@ -1258,6 +1375,9 @@ func (s *fakeStore) Current() (*tls.Certificate, error) {
// pair the 'current' pair, that will be returned by future calls to // pair the 'current' pair, that will be returned by future calls to
// Current(). // Current().
func (s *fakeStore) Update(certPEM, keyPEM []byte) (*tls.Certificate, error) { func (s *fakeStore) Update(certPEM, keyPEM []byte) (*tls.Certificate, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
// In order to make the mocking work, whenever a cert/key pair is passed in // In order to make the mocking work, whenever a cert/key pair is passed in
// to be updated in the mock store, assume that the certificate manager // to be updated in the mock store, assume that the certificate manager
// generated the key, and then asked the mock CertificateSigningRequest API // generated the key, and then asked the mock CertificateSigningRequest API

View File

@ -38,6 +38,7 @@ const (
) )
type fileStore struct { type fileStore struct {
logger klog.Logger
pairNamePrefix string pairNamePrefix string
certDirectory string certDirectory string
keyDirectory string keyDirectory string
@ -67,14 +68,30 @@ type FileStore interface {
// updates will be written to the ${certDirectory} directory and // updates will be written to the ${certDirectory} directory and
// ${certDirectory}/${pairNamePrefix}-current.pem will be created as a soft // ${certDirectory}/${pairNamePrefix}-current.pem will be created as a soft
// link to the currently selected cert/key pair. // link to the currently selected cert/key pair.
//
// Contextual logging: NewFileStoreWithLogger should be used instead of NewFileStore in code which supports contextual logging.
func NewFileStore( func NewFileStore(
pairNamePrefix string, pairNamePrefix string,
certDirectory string, certDirectory string,
keyDirectory string, keyDirectory string,
certFile string, certFile string,
keyFile string) (FileStore, error) { keyFile string) (FileStore, error) {
return NewFileStoreWithLogger(klog.Background(), pairNamePrefix, certDirectory, keyDirectory, certFile, keyFile)
}
// NewFileStoreWithLogger is a variant of NewFileStore where the caller is in
// control of logging. All log messages get emitted with logger.Info, so
// pass e.g. logger.V(3) to make logging less verbose.
func NewFileStoreWithLogger(
logger klog.Logger,
pairNamePrefix string,
certDirectory string,
keyDirectory string,
certFile string,
keyFile string) (FileStore, error) {
s := fileStore{ s := fileStore{
logger: logger,
pairNamePrefix: pairNamePrefix, pairNamePrefix: pairNamePrefix,
certDirectory: certDirectory, certDirectory: certDirectory,
keyDirectory: keyDirectory, keyDirectory: keyDirectory,
@ -127,7 +144,7 @@ func (s *fileStore) Current() (*tls.Certificate, error) {
if pairFileExists, err := fileExists(pairFile); err != nil { if pairFileExists, err := fileExists(pairFile); err != nil {
return nil, err return nil, err
} else if pairFileExists { } else if pairFileExists {
klog.Infof("Loading cert/key pair from %q.", pairFile) s.logger.Info("Loading cert/key pair from a file", "filePath", pairFile)
return loadFile(pairFile) return loadFile(pairFile)
} }
@ -140,7 +157,7 @@ func (s *fileStore) Current() (*tls.Certificate, error) {
return nil, err return nil, err
} }
if certFileExists && keyFileExists { if certFileExists && keyFileExists {
klog.Infof("Loading cert/key pair from (%q, %q).", s.certFile, s.keyFile) s.logger.Info("Loading cert/key pair", "certFile", s.certFile, "keyFile", s.keyFile)
return loadX509KeyPair(s.certFile, s.keyFile) return loadX509KeyPair(s.certFile, s.keyFile)
} }
@ -155,7 +172,7 @@ func (s *fileStore) Current() (*tls.Certificate, error) {
return nil, err return nil, err
} }
if certFileExists && keyFileExists { if certFileExists && keyFileExists {
klog.Infof("Loading cert/key pair from (%q, %q).", c, k) s.logger.Info("Loading cert/key pair", "certFile", c, "keyFile", k)
return loadX509KeyPair(c, k) return loadX509KeyPair(c, k)
} }

View File

@ -235,6 +235,7 @@ func TestUpdateNoRotation(t *testing.T) {
t.Fatalf("Unable to create the file %q: %v", certFile, err) t.Fatalf("Unable to create the file %q: %v", certFile, err)
} }
//nolint:logcheck // Intentionally uses the old API.
s, err := NewFileStore(prefix, dir, dir, certFile, keyFile) s, err := NewFileStore(prefix, dir, dir, certFile, keyFile)
if err != nil { if err != nil {
t.Fatalf("Got %v while creating a new store.", err) t.Fatalf("Got %v while creating a new store.", err)
@ -269,6 +270,7 @@ func TestUpdateRotation(t *testing.T) {
t.Fatalf("Unable to create the file %q: %v", certFile, err) t.Fatalf("Unable to create the file %q: %v", certFile, err)
} }
//nolint:logcheck // Intentionally uses the old API.
s, err := NewFileStore(prefix, dir, dir, certFile, keyFile) s, err := NewFileStore(prefix, dir, dir, certFile, keyFile)
if err != nil { if err != nil {
t.Fatalf("Got %v while creating a new store.", err) t.Fatalf("Got %v while creating a new store.", err)
@ -303,6 +305,7 @@ func TestUpdateTwoCerts(t *testing.T) {
t.Fatalf("Unable to create the file %q: %v", certFile, err) t.Fatalf("Unable to create the file %q: %v", certFile, err)
} }
//nolint:logcheck // Intentionally uses the old API.
s, err := NewFileStore(prefix, dir, dir, certFile, keyFile) s, err := NewFileStore(prefix, dir, dir, certFile, keyFile)
if err != nil { if err != nil {
t.Fatalf("Got %v while creating a new store.", err) t.Fatalf("Got %v while creating a new store.", err)
@ -340,6 +343,7 @@ func TestUpdateWithBadCertKeyData(t *testing.T) {
t.Fatalf("Unable to create the file %q: %v", certFile, err) t.Fatalf("Unable to create the file %q: %v", certFile, err)
} }
//nolint:logcheck // Intentionally uses the old API.
s, err := NewFileStore(prefix, dir, dir, certFile, keyFile) s, err := NewFileStore(prefix, dir, dir, certFile, keyFile)
if err != nil { if err != nil {
t.Fatalf("Got %v while creating a new store.", err) t.Fatalf("Got %v while creating a new store.", err)
@ -376,6 +380,7 @@ func TestCurrentPairFile(t *testing.T) {
t.Fatalf("unable to create a symlink from %q to %q: %v", currentFile, pairFile, err) t.Fatalf("unable to create a symlink from %q to %q: %v", currentFile, pairFile, err)
} }
//nolint:logcheck // Intentionally uses the old API.
store, err := NewFileStore("kubelet-server", dir, dir, "", "") store, err := NewFileStore("kubelet-server", dir, dir, "", "")
if err != nil { if err != nil {
t.Fatalf("Failed to initialize certificate store: %v", err) t.Fatalf("Failed to initialize certificate store: %v", err)
@ -413,6 +418,7 @@ func TestCurrentCertKeyFiles(t *testing.T) {
t.Fatalf("Unable to create the file %q: %v", keyFile, err) t.Fatalf("Unable to create the file %q: %v", keyFile, err)
} }
//nolint:logcheck // Intentionally uses the old API.
store, err := NewFileStore(prefix, dir, dir, certFile, keyFile) store, err := NewFileStore(prefix, dir, dir, certFile, keyFile)
if err != nil { if err != nil {
t.Fatalf("Failed to initialize certificate store: %v", err) t.Fatalf("Failed to initialize certificate store: %v", err)
@ -450,6 +456,7 @@ func TestCurrentTwoCerts(t *testing.T) {
t.Fatalf("Unable to create the file %q: %v", keyFile, err) t.Fatalf("Unable to create the file %q: %v", keyFile, err)
} }
//nolint:logcheck // Intentionally uses the old API.
store, err := NewFileStore(prefix, dir, dir, certFile, keyFile) store, err := NewFileStore(prefix, dir, dir, certFile, keyFile)
if err != nil { if err != nil {
t.Fatalf("Failed to initialize certificate store: %v", err) t.Fatalf("Failed to initialize certificate store: %v", err)
@ -481,6 +488,7 @@ func TestCurrentNoFiles(t *testing.T) {
} }
}() }()
//nolint:logcheck // Intentionally uses the old API.
store, err := NewFileStore("kubelet-server", dir, dir, "", "") store, err := NewFileStore("kubelet-server", dir, dir, "", "")
if err != nil { if err != nil {
t.Fatalf("Failed to initialize certificate store: %v", err) t.Fatalf("Failed to initialize certificate store: %v", err)

View File

@ -47,7 +47,18 @@ import (
// PEM encoded CSR and send it to API server. An optional requestedDuration may be passed // PEM encoded CSR and send it to API server. An optional requestedDuration may be passed
// to set the spec.expirationSeconds field on the CSR to control the lifetime of the issued // to set the spec.expirationSeconds field on the CSR to control the lifetime of the issued
// certificate. This is not guaranteed as the signer may choose to ignore the request. // certificate. This is not guaranteed as the signer may choose to ignore the request.
//
// Deprecated: use RequestCertificateWithContext instead.
func RequestCertificate(client clientset.Interface, csrData []byte, name, signerName string, requestedDuration *time.Duration, usages []certificatesv1.KeyUsage, privateKey interface{}) (reqName string, reqUID types.UID, err error) { func RequestCertificate(client clientset.Interface, csrData []byte, name, signerName string, requestedDuration *time.Duration, usages []certificatesv1.KeyUsage, privateKey interface{}) (reqName string, reqUID types.UID, err error) {
return RequestCertificateWithContext(context.Background(), client, csrData, name, signerName, requestedDuration, usages, privateKey)
}
// RequestCertificateWithContext will either use an existing (if this process has run
// before but not to completion) or create a certificate signing request using the
// PEM encoded CSR and send it to API server. An optional requestedDuration may be passed
// to set the spec.expirationSeconds field on the CSR to control the lifetime of the issued
// certificate. This is not guaranteed as the signer may choose to ignore the request.
func RequestCertificateWithContext(ctx context.Context, client clientset.Interface, csrData []byte, name, signerName string, requestedDuration *time.Duration, usages []certificatesv1.KeyUsage, privateKey interface{}) (reqName string, reqUID types.UID, err error) {
csr := &certificatesv1.CertificateSigningRequest{ csr := &certificatesv1.CertificateSigningRequest{
// Username, UID, Groups will be injected by API server. // Username, UID, Groups will be injected by API server.
TypeMeta: metav1.TypeMeta{Kind: "CertificateSigningRequest"}, TypeMeta: metav1.TypeMeta{Kind: "CertificateSigningRequest"},
@ -67,21 +78,22 @@ func RequestCertificate(client clientset.Interface, csrData []byte, name, signer
csr.Spec.ExpirationSeconds = DurationToExpirationSeconds(*requestedDuration) csr.Spec.ExpirationSeconds = DurationToExpirationSeconds(*requestedDuration)
} }
reqName, reqUID, err = create(client, csr) reqName, reqUID, err = create(ctx, client, csr)
switch { switch {
case err == nil: case err == nil:
return reqName, reqUID, err return reqName, reqUID, err
case apierrors.IsAlreadyExists(err) && len(name) > 0: case apierrors.IsAlreadyExists(err) && len(name) > 0:
klog.Infof("csr for this node already exists, reusing") logger := klog.FromContext(ctx)
req, err := get(client, name) logger.Info("csr for this node already exists, reusing")
req, err := get(ctx, client, name)
if err != nil { if err != nil {
return "", "", formatError("cannot retrieve certificate signing request: %v", err) return "", "", formatError("cannot retrieve certificate signing request: %v", err)
} }
if err := ensureCompatible(req, csr, privateKey); err != nil { if err := ensureCompatible(req, csr, privateKey); err != nil {
return "", "", fmt.Errorf("retrieved csr is not compatible: %v", err) return "", "", fmt.Errorf("retrieved csr is not compatible: %v", err)
} }
klog.Infof("csr for this node is still valid") logger.Info("csr for this node is still valid")
return req.Name, req.UID, nil return req.Name, req.UID, nil
default: default:
@ -97,13 +109,13 @@ func ExpirationSecondsToDuration(expirationSeconds int32) time.Duration {
return time.Duration(expirationSeconds) * time.Second return time.Duration(expirationSeconds) * time.Second
} }
func get(client clientset.Interface, name string) (*certificatesv1.CertificateSigningRequest, error) { func get(ctx context.Context, client clientset.Interface, name string) (*certificatesv1.CertificateSigningRequest, error) {
v1req, v1err := client.CertificatesV1().CertificateSigningRequests().Get(context.TODO(), name, metav1.GetOptions{}) v1req, v1err := client.CertificatesV1().CertificateSigningRequests().Get(ctx, name, metav1.GetOptions{})
if v1err == nil || !apierrors.IsNotFound(v1err) { if v1err == nil || !apierrors.IsNotFound(v1err) {
return v1req, v1err return v1req, v1err
} }
v1beta1req, v1beta1err := client.CertificatesV1beta1().CertificateSigningRequests().Get(context.TODO(), name, metav1.GetOptions{}) v1beta1req, v1beta1err := client.CertificatesV1beta1().CertificateSigningRequests().Get(ctx, name, metav1.GetOptions{})
if v1beta1err != nil { if v1beta1err != nil {
return nil, v1beta1err return nil, v1beta1err
} }
@ -123,10 +135,10 @@ func get(client clientset.Interface, name string) (*certificatesv1.CertificateSi
return v1req, nil return v1req, nil
} }
func create(client clientset.Interface, csr *certificatesv1.CertificateSigningRequest) (reqName string, reqUID types.UID, err error) { func create(ctx context.Context, client clientset.Interface, csr *certificatesv1.CertificateSigningRequest) (reqName string, reqUID types.UID, err error) {
// only attempt a create via v1 if we specified signerName and usages and are not using the legacy unknown signerName // only attempt a create via v1 if we specified signerName and usages and are not using the legacy unknown signerName
if len(csr.Spec.Usages) > 0 && len(csr.Spec.SignerName) > 0 && csr.Spec.SignerName != "kubernetes.io/legacy-unknown" { if len(csr.Spec.Usages) > 0 && len(csr.Spec.SignerName) > 0 && csr.Spec.SignerName != "kubernetes.io/legacy-unknown" {
v1req, v1err := client.CertificatesV1().CertificateSigningRequests().Create(context.TODO(), csr, metav1.CreateOptions{}) v1req, v1err := client.CertificatesV1().CertificateSigningRequests().Create(ctx, csr, metav1.CreateOptions{})
switch { switch {
case v1err != nil && apierrors.IsNotFound(v1err): case v1err != nil && apierrors.IsNotFound(v1err):
// v1 CSR API was not found, continue to try v1beta1 // v1 CSR API was not found, continue to try v1beta1
@ -154,7 +166,7 @@ func create(client clientset.Interface, csr *certificatesv1.CertificateSigningRe
} }
// create v1beta1 // create v1beta1
v1beta1req, v1beta1err := client.CertificatesV1beta1().CertificateSigningRequests().Create(context.TODO(), v1beta1csr, metav1.CreateOptions{}) v1beta1req, v1beta1err := client.CertificatesV1beta1().CertificateSigningRequests().Create(ctx, v1beta1csr, metav1.CreateOptions{})
if v1beta1err != nil { if v1beta1err != nil {
return "", "", v1beta1err return "", "", v1beta1err
} }
@ -164,6 +176,7 @@ func create(client clientset.Interface, csr *certificatesv1.CertificateSigningRe
// WaitForCertificate waits for a certificate to be issued until timeout, or returns an error. // WaitForCertificate waits for a certificate to be issued until timeout, or returns an error.
func WaitForCertificate(ctx context.Context, client clientset.Interface, reqName string, reqUID types.UID) (certData []byte, err error) { func WaitForCertificate(ctx context.Context, client clientset.Interface, reqName string, reqUID types.UID) (certData []byte, err error) {
fieldSelector := fields.OneTermEqualSelector("metadata.name", reqName).String() fieldSelector := fields.OneTermEqualSelector("metadata.name", reqName).String()
logger := klog.FromContext(ctx)
var lw *cache.ListWatch var lw *cache.ListWatch
var obj runtime.Object var obj runtime.Object
@ -184,7 +197,7 @@ func WaitForCertificate(ctx context.Context, client clientset.Interface, reqName
} }
break break
} else { } else {
klog.V(2).Infof("error fetching v1 certificate signing request: %v", err) logger.V(2).Info("Error fetching v1 certificate signing request", "err", err)
} }
// return if we've timed out // return if we've timed out
@ -208,7 +221,7 @@ func WaitForCertificate(ctx context.Context, client clientset.Interface, reqName
} }
break break
} else { } else {
klog.V(2).Infof("error fetching v1beta1 certificate signing request: %v", err) logger.V(2).Info("Error fetching v1beta1 certificate signing request", "err", err)
} }
// return if we've timed out // return if we've timed out
@ -254,11 +267,11 @@ func WaitForCertificate(ctx context.Context, client clientset.Interface, reqName
} }
if approved { if approved {
if len(csr.Status.Certificate) > 0 { if len(csr.Status.Certificate) > 0 {
klog.V(2).Infof("certificate signing request %s is issued", csr.Name) logger.V(2).Info("Certificate signing request is issued", "csr", klog.KObj(csr))
issuedCertificate = csr.Status.Certificate issuedCertificate = csr.Status.Certificate
return true, nil return true, nil
} }
klog.V(2).Infof("certificate signing request %s is approved, waiting to be issued", csr.Name) logger.V(2).Info("Certificate signing request is approved, waiting to be issued", "csr", klog.KObj(csr))
} }
case *certificatesv1beta1.CertificateSigningRequest: case *certificatesv1beta1.CertificateSigningRequest:
@ -279,11 +292,11 @@ func WaitForCertificate(ctx context.Context, client clientset.Interface, reqName
} }
if approved { if approved {
if len(csr.Status.Certificate) > 0 { if len(csr.Status.Certificate) > 0 {
klog.V(2).Infof("certificate signing request %s is issued", csr.Name) logger.V(2).Info("Certificate signing request is issued", "csr", klog.KObj(csr))
issuedCertificate = csr.Status.Certificate issuedCertificate = csr.Status.Certificate
return true, nil return true, nil
} }
klog.V(2).Infof("certificate signing request %s is approved, waiting to be issued", csr.Name) logger.V(2).Info("Certificate signing request is approved, waiting to be issued", "csr", klog.KObj(csr))
} }
default: default: