From b4fde8da81498ff547f8767314adc29550f7b2e3 Mon Sep 17 00:00:00 2001 From: Krzysztof Ostrowski Date: Mon, 13 Feb 2023 16:38:56 +0100 Subject: [PATCH 1/2] add cryptographic wearout for AES GCM transformer Signed-off-by: Krzysztof Ostrowski --- staging/src/k8s.io/kms/encryption/service.go | 81 +++++++++++++++++++- 1 file changed, 80 insertions(+), 1 deletion(-) diff --git a/staging/src/k8s.io/kms/encryption/service.go b/staging/src/k8s.io/kms/encryption/service.go index 89bc3d7dc0a..edb5d135339 100644 --- a/staging/src/k8s.io/kms/encryption/service.go +++ b/staging/src/k8s.io/kms/encryption/service.go @@ -24,6 +24,7 @@ import ( "fmt" "strings" "sync" + "sync/atomic" "time" "k8s.io/apimachinery/pkg/util/wait" @@ -56,6 +57,17 @@ const ( // keyLength is the length of the local KEK in bytes. // This is the same length used for the DEKs generated in kube-apiserver. keyLength = 32 + // keyMaxUsage is the maximum number of times an AES GCM key can be used + // with a random nonce: 2^32. The local KEK is a transformer that hold an + // AES GCM key. It is based on recommendations from + // https://nvlpubs.nist.gov/nistpubs/Legacy/SP/nistspecialpublication800-38d.pdf. + // It is reduced by one to be comparable with a atomic.Uint32. + keyMaxUsage = 1<<32 - 1 + // keySuggestedUsage is a threshold that triggers the rotation of a new local KEK. It means that half + // the number of times a local KEK can be used has been reached. + keySuggestedUsage = 1 << 31 + // keyMaxAge is the maximum age of a local KEK. + keyMaxAge = 7 * 24 * time.Hour ) var _ service.Service = &LocalKEKService{} @@ -73,9 +85,15 @@ type LocalKEKService struct { // transformers is a thread-safe LRU cache which caches decrypted DEKs indexed by their encrypted form. transformers *lru.Cache - remoteKMSResponse *service.EncryptResponse + // localMutex is used to read / write a new localTransformer, localUsage and localExpiry. 
+ localMutex sync.RWMutex + // localUsage is incremented by the getTransformerForEncryption method and initialized / read by the Run method. + localUsage atomic.Uint32 + // localExpiry should be only read and be written by the Run method. + localExpiry time.Time localTransformer value.Transformer localTransformerErr error + remoteKMSResponse *service.EncryptResponse } // NewLocalKEKService is being initialized with a remote KMS service. @@ -91,6 +109,57 @@ func NewLocalKEKService(remoteService service.Service) *LocalKEKService { } } +// Run is locking and expected to be run with a goroutine. The method creates a +// new local KEK when the following thresholds are met: +// - the local KEK is used more often than keySuggestedUsage times or +// - the local KEK is older than a localExpiry. +func (m *LocalKEKService) Run(ctx context.Context) { + wait.UntilWithContext(ctx, func(ctx context.Context) { + if time.Now().After(m.localExpiry) || m.localUsage.Load() > keySuggestedUsage { + uid := fmt.Sprintf("%s:%d", referenceKEKAnnotationKey, time.Now().Unix()) + + key, err := generateKey(keyLength) + if err != nil { + klog.ErrorS(err, "failed to generate local KEK", "uid", uid) + return + } + + block, err := aes.NewCipher(key) + if err != nil { + klog.ErrorS(err, "failed to create cipher block", "uid", uid) + return + } + + transformer := aestransformer.NewGCMTransformer(block) + resp, err := m.remoteKMS.Encrypt(ctx, uid, key) + if err != nil { + klog.ErrorS(err, "failed to encrypt local KEK with remote KMS", "uid", uid) + return + } + if err = validateRemoteKMSResponse(resp); err != nil { + klog.ErrorS(err, "response annotations failed validation", "uid", uid) + return + } + + m.localMutex.Lock() + m.localExpiry = time.Now().Add(keyMaxAge) + m.localUsage = atomic.Uint32{} + m.localTransformer = transformer + m.remoteKMSResponse = copyResponseAndAddLocalKEKAnnotation(resp) + m.localMutex.Unlock() + + m.transformers.Add(base64.StdEncoding.EncodeToString(resp.Ciphertext), 
transformer) + } + + return + }, time.Minute) +} + +// getTransformerForEncryption returns the local KEK as localTransformer, the corresponding +// rmeoteKMSResponse and an potential error. +// On first use, the localTransformer is initialized and the remoteKMSResponse is set. +// On every use the localUsage is incremented by one. +// It is assumed that only one encryption will happen with the returned transformer. func (m *LocalKEKService) getTransformerForEncryption(uid string) (value.Transformer, *service.EncryptResponse, error) { // Check if we have a local KEK // - If exists, use the local KEK for encryption and return @@ -119,12 +188,22 @@ func (m *LocalKEKService) getTransformerForEncryption(uid string) (value.Transfo if err = validateRemoteKMSResponse(resp); err != nil { return false, fmt.Errorf("response annotations failed validation: %w", err) } + m.localExpiry = time.Now().Add(keyMaxAge) + m.localUsage = atomic.Uint32{} m.remoteKMSResponse = copyResponseAndAddLocalKEKAnnotation(resp) m.localTransformer = transformer m.transformers.Add(base64.StdEncoding.EncodeToString(resp.Ciphertext), transformer) return true, nil }) }) + + if counter := m.localUsage.Add(1); counter == keyMaxUsage { + return nil, nil, fmt.Errorf("local KEK has reached maximum usage of %d", keyMaxUsage) + } + + m.localMutex.RLock() + defer m.localMutex.RUnlock() + return m.localTransformer, m.remoteKMSResponse, m.localTransformerErr } From 3bdd5ceae1bfcbb8b6dc141f8837340877dc2a4e Mon Sep 17 00:00:00 2001 From: Anish Ramasekar Date: Mon, 13 Feb 2023 23:31:01 +0000 Subject: [PATCH 2/2] implement local kek generation and rotate based on status Signed-off-by: Anish Ramasekar --- staging/src/k8s.io/kms/encryption/service.go | 393 ++++++++----- .../src/k8s.io/kms/encryption/service_test.go | 516 ++++++++++++++++-- staging/src/k8s.io/kms/go.mod | 1 + staging/src/k8s.io/kms/go.sum | 2 + 4 files changed, 717 insertions(+), 195 deletions(-) diff --git a/staging/src/k8s.io/kms/encryption/service.go 
b/staging/src/k8s.io/kms/encryption/service.go index edb5d135339..864cbafa49f 100644 --- a/staging/src/k8s.io/kms/encryption/service.go +++ b/staging/src/k8s.io/kms/encryption/service.go @@ -17,6 +17,7 @@ limitations under the License. package encryption import ( + "bytes" "context" "crypto/aes" "crypto/rand" @@ -27,14 +28,26 @@ import ( "sync/atomic" "time" + "k8s.io/apimachinery/pkg/util/uuid" "k8s.io/apimachinery/pkg/util/wait" "k8s.io/klog/v2" aestransformer "k8s.io/kms/pkg/encrypt/aes" "k8s.io/kms/pkg/value" "k8s.io/kms/service" + "k8s.io/utils/clock" "k8s.io/utils/lru" ) +// localKEK is a struct that holds the local KEK and the remote KMS response. +type localKEK struct { + encKEK []byte + usage atomic.Uint64 + expiry time.Time + transformer value.Transformer + remoteKMSResponse *service.EncryptResponse + generatedAt time.Time +} + var ( // emptyContext is an empty slice of bytes. This is passed as value.Context to the // GCM transformer. The grpc interface does not provide any additional authenticated data @@ -42,10 +55,6 @@ var ( emptyContext = value.DefaultContext([]byte{}) // errInvalidKMSAnnotationKeySuffix is returned when the annotation key suffix is not allowed. errInvalidKMSAnnotationKeySuffix = fmt.Errorf("annotation keys are not allowed to use %s", referenceSuffix) - - // these are var instead of const so that we can set them during tests - localKEKGenerationPollInterval = 1 * time.Second - localKEKGenerationPollTimeout = 5 * time.Minute ) const ( @@ -54,6 +63,10 @@ const ( referenceKEKAnnotationKey = "encrypted-kek" + referenceSuffix numAnnotations = 1 cacheSize = 1_000 + + // localKEKGenerationPollInterval is the interval at which the local KEK is checked for rotation. + localKEKGenerationPollInterval = 1 * time.Minute + // keyLength is the length of the local KEK in bytes. // This is the same length used for the DEKs generated in kube-apiserver. keyLength = 32 @@ -62,191 +75,154 @@ const ( // AES GCM key. 
It is based on recommendations from // https://nvlpubs.nist.gov/nistpubs/Legacy/SP/nistspecialpublication800-38d.pdf. // It is reduced by one to be comparable with a atomic.Uint32. - keyMaxUsage = 1<<32 - 1 + // We picked an arbitrary number that is less than the max usage of the local KEK. + keyMaxUsage = 1<<22 - 1 // keySuggestedUsage is a threshold that triggers the rotation of a new local KEK. It means that half // the number of times a local KEK can be used has been reached. - keySuggestedUsage = 1 << 31 - // keyMaxAge is the maximum age of a local KEK. + keySuggestedUsage = 1 << 21 + // keyMaxAge is the maximum age of a local KEK. It is not a cryptographic necessity. keyMaxAge = 7 * 24 * time.Hour ) var _ service.Service = &LocalKEKService{} -// LocalKEKService adds an additional KEK layer to reduce calls to the remote -// KMS. -// The local KEK is generated once and stored in the LocalKEKService. This KEK -// is used for all encryption operations. For the decrypt operation, if the encrypted -// local KEK is not found in the cache, the remote KMS is used to decrypt the local KEK. +// LocalKEKService adds an additional KEK layer to reduce calls to the remote KMS. +// The local KEK is generated at startup in a controller loop and stored in the +// LocalKEKService. This KEK is used for all encryption operations. For the decrypt +// operation, if the encrypted local KEK is not found in the cache, the remote KMS +// is used to decrypt the local KEK. type LocalKEKService struct { + mu sync.Mutex // remoteKMS is the remote kms that is used to encrypt and decrypt the local KEKs. - remoteKMS service.Service - remoteOnce sync.Once - + remoteKMS service.Service + // localKEKTracker is an atomic pointer to avoid locks. This is used to store the local KEK. + localKEKTracker atomic.Pointer[localKEK] // transformers is a thread-safe LRU cache which caches decrypted DEKs indexed by their encrypted form. + // The cache is only used for the decrypt operation. 
transformers *lru.Cache + // isReady is an atomic boolean that indicates if the localKEK service is ready for encryption. + isReady atomic.Bool - // localMutex is used to read / write a new localTransformer, localUsage and localExpiry. - localMutex sync.RWMutex - // localUsage is incremented by the getTransformerForEncryption method and initialized / read by the Run method. - localUsage atomic.Uint32 - // localExpiry should be only read and be written by the Run method. - localExpiry time.Time - localTransformer value.Transformer - localTransformerErr error - remoteKMSResponse *service.EncryptResponse + keyMaxUsage uint64 + keySuggestedUsage uint64 + keyMaxAge time.Duration + + pollInterval time.Duration + + clock clock.Clock } // NewLocalKEKService is being initialized with a remote KMS service. -// In the current implementation, the localKEK Service needs to be -// restarted by the caller after security thresholds are met. -// TODO(aramase): handle rotation of local KEKs -// - when the keyID in Status() no longer matches the keyID used during encryption -// - when the local KEK has been used for a certain number of times -func NewLocalKEKService(remoteService service.Service) *LocalKEKService { - return &LocalKEKService{ - remoteKMS: remoteService, - transformers: lru.New(cacheSize), - } +// The local KEK is generated in a controller loop. The local KEK is used for all +// encryption operations. +func NewLocalKEKService(ctx context.Context, remoteService service.Service) *LocalKEKService { + return newLocalKEKService(ctx, remoteService, keyMaxUsage, keySuggestedUsage, keyMaxAge, localKEKGenerationPollInterval, clock.RealClock{}) } -// Run is locking and expected to be run with a goroutine. 
The method creates a -// new local KEK when the following thresholds are met: +func newLocalKEKService(ctx context.Context, remoteService service.Service, maxUsage, suggestedUsage uint64, maxAge, pollInterval time.Duration, clock clock.Clock) *LocalKEKService { + localKEKService := &LocalKEKService{ + remoteKMS: remoteService, + transformers: lru.New(cacheSize), + keyMaxUsage: maxUsage, + keySuggestedUsage: suggestedUsage, + keyMaxAge: maxAge, + pollInterval: pollInterval, + clock: clock, + } + go localKEKService.run(ctx) + return localKEKService +} + +// Run method creates a new local KEK when the following thresholds are met: // - the local KEK is used more often than keySuggestedUsage times or // - the local KEK is older than a localExpiry. -func (m *LocalKEKService) Run(ctx context.Context) { - wait.UntilWithContext(ctx, func(ctx context.Context) { - if time.Now().After(m.localExpiry) || m.localUsage.Load() > keySuggestedUsage { - uid := fmt.Sprintf("%s:%d", referenceKEKAnnotationKey, time.Now().Unix()) +// +// this method starts the controller and blocks until the context is cancelled. 
+func (m *LocalKEKService) run(ctx context.Context) { + // same as wait.UntilWithContext but with a custom clock + wait.BackoffUntil(func() { + lk := m.getLocalKEK() + // this is the first time the local KEK is generated + localKEKNotGenerated := lk.transformer == nil + // the local KEK is used more often than keySuggestedUsage times + localKEKUsageThresholdReached := lk.usage.Load() > m.keySuggestedUsage + // the local KEK is older than the expiry + localKEKExpired := m.clock.Now().After(lk.expiry) - key, err := generateKey(keyLength) - if err != nil { - klog.ErrorS(err, "failed to generate local KEK", "uid", uid) + if localKEKNotGenerated || localKEKUsageThresholdReached || localKEKExpired { + uid := string(uuid.NewUUID()) + err := m.generateLocalKEK(ctx, uid, "") + if err == nil { + m.isReady.Store(true) return } - - block, err := aes.NewCipher(key) - if err != nil { - klog.ErrorS(err, "failed to create cipher block", "uid", uid) - return + klog.V(2).ErrorS(err, "failed to generate local KEK", "uid", uid) + // if the local KEK is expired and we cannot generate a new one, we set + // isReady to false because we can no longer encrypt new data. 
+ if localKEKExpired { + m.isReady.Store(false) } - - transformer := aestransformer.NewGCMTransformer(block) - resp, err := m.remoteKMS.Encrypt(ctx, uid, key) - if err != nil { - klog.ErrorS(err, "failed to encrypt local KEK with remote KMS", "uid", uid) - return - } - if err = validateRemoteKMSResponse(resp); err != nil { - klog.ErrorS(err, "response annotations failed validation", "uid", uid) - return - } - - m.localMutex.Lock() - m.localExpiry = time.Now().Add(keyMaxAge) - m.localUsage = atomic.Uint32{} - m.localTransformer = transformer - m.remoteKMSResponse = copyResponseAndAddLocalKEKAnnotation(resp) - m.localMutex.Unlock() - - m.transformers.Add(base64.StdEncoding.EncodeToString(resp.Ciphertext), transformer) } - - return - }, time.Minute) + }, wait.NewJitteredBackoffManager(m.pollInterval, 0, m.clock), true, ctx.Done()) } // getTransformerForEncryption returns the local KEK as localTransformer, the corresponding -// rmeoteKMSResponse and an potential error. -// On first use, the localTransformer is initialized and the remoteKMSResponse is set. +// remoteKMSResponse and a potential error. // On every use the localUsage is incremented by one. // It is assumed that only one encryption will happen with the returned transformer. func (m *LocalKEKService) getTransformerForEncryption(uid string) (value.Transformer, *service.EncryptResponse, error) { - // Check if we have a local KEK - // - If exists, use the local KEK for encryption and return - // - Not exists, generate local KEK, encrypt with remote KEK, - // store it in cache encrypt the data and return. 
This can be - // expensive but only 1 in N calls will incur this additional latency, - // N being number of times local KEK is reused) - m.remoteOnce.Do(func() { - m.localTransformerErr = wait.PollImmediateWithContext(context.Background(), localKEKGenerationPollInterval, localKEKGenerationPollTimeout, - func(ctx context.Context) (done bool, err error) { - key, err := generateKey(keyLength) - if err != nil { - return false, fmt.Errorf("failed to generate local KEK: %w", err) - } - block, err := aes.NewCipher(key) - if err != nil { - return false, fmt.Errorf("failed to create cipher block: %w", err) - } - transformer := aestransformer.NewGCMTransformer(block) + lk := m.getLocalKEK() + // localKEK is not initialized yet + if lk.transformer == nil { + return nil, nil, fmt.Errorf("local KEK is not initialized") + } - resp, err := m.remoteKMS.Encrypt(ctx, uid, key) - if err != nil { - klog.ErrorS(err, "failed to encrypt local KEK with remote KMS", "uid", uid) - return false, nil - } - if err = validateRemoteKMSResponse(resp); err != nil { - return false, fmt.Errorf("response annotations failed validation: %w", err) - } - m.localExpiry = time.Now().Add(keyMaxAge) - m.localUsage = atomic.Uint32{} - m.remoteKMSResponse = copyResponseAndAddLocalKEKAnnotation(resp) - m.localTransformer = transformer - m.transformers.Add(base64.StdEncoding.EncodeToString(resp.Ciphertext), transformer) - return true, nil - }) - }) + if m.clock.Now().After(lk.expiry) { + return nil, nil, fmt.Errorf("local KEK has expired at %v", lk.expiry) + } - if counter := m.localUsage.Add(1); counter == keyMaxUsage { + if counter := lk.usage.Add(1); counter >= m.keyMaxUsage { return nil, nil, fmt.Errorf("local KEK has reached maximum usage of %d", keyMaxUsage) } - m.localMutex.RLock() - defer m.localMutex.RUnlock() - - return m.localTransformer, m.remoteKMSResponse, m.localTransformerErr -} - -func copyResponseAndAddLocalKEKAnnotation(resp *service.EncryptResponse) *service.EncryptResponse { - annotations := 
make(map[string][]byte, len(resp.Annotations)+numAnnotations) - for s, bytes := range resp.Annotations { - s := s - bytes := bytes - annotations[s] = bytes - } - annotations[referenceKEKAnnotationKey] = resp.Ciphertext - - return &service.EncryptResponse{ - // Ciphertext is not set on purpose - it is different per Encrypt call - KeyID: resp.KeyID, - Annotations: annotations, - } + return lk.transformer, lk.remoteKMSResponse, nil } // Encrypt encrypts the plaintext with the localKEK. func (m *LocalKEKService) Encrypt(ctx context.Context, uid string, pt []byte) (*service.EncryptResponse, error) { transformer, resp, err := m.getTransformerForEncryption(uid) if err != nil { - klog.V(2).InfoS("encrypt plaintext", "uid", uid, "err", err) + klog.V(2).ErrorS(err, "failed to get transformer for encryption", "uid", uid) return nil, err } ct, err := transformer.TransformToStorage(ctx, pt, emptyContext) if err != nil { - klog.V(2).InfoS("encrypt plaintext", "uid", uid, "err", err) + klog.V(2).ErrorS(err, "failed to encrypt data", "uid", uid) return nil, err } return &service.EncryptResponse{ Ciphertext: ct, - KeyID: resp.KeyID, // TODO what about rotation ?? + KeyID: resp.KeyID, Annotations: resp.Annotations, }, nil } +// getTransformerForDecryption returns the transformer for the given encryptedKEK. +// - If the encryptedKEK is the current localKEK, the localKEK is returned. +// - If the encryptedKEK is not the current localKEK, the cache is checked. +// - If the encryptedKEK is not found in the cache, the remote KMS is used to decrypt the encryptedKEK. 
func (m *LocalKEKService) getTransformerForDecryption(ctx context.Context, uid string, req *service.DecryptRequest) (value.Transformer, error) { encKEK := req.Annotations[referenceKEKAnnotationKey] + // check if the key required for decryption is the current local KEK + // that's being used for encryption + lk := m.getLocalKEK() + if lk.transformer != nil && bytes.Equal(lk.encKEK, encKEK) { + return lk.transformer, nil + } + // check if the key required for decryption is already in the cache if _transformer, found := m.transformers.Get(base64.StdEncoding.EncodeToString(encKEK)); found { return _transformer.(value.Transformer), nil } @@ -269,7 +245,7 @@ func (m *LocalKEKService) getTransformerForDecryption(ctx context.Context, uid s // Overwrite the plain key with 0s. copy(key, make([]byte, len(key))) - m.transformers.Add(encKEK, transformer) + m.transformers.Add(base64.StdEncoding.EncodeToString(encKEK), transformer) return transformer, nil } @@ -283,13 +259,13 @@ func (m *LocalKEKService) Decrypt(ctx context.Context, uid string, req *service. transformer, err := m.getTransformerForDecryption(ctx, uid, req) if err != nil { - klog.V(2).InfoS("decrypt ciphertext", "uid", uid, "err", err) + klog.V(2).ErrorS(err, "failed to get transformer for decryption", "uid", uid) return nil, fmt.Errorf("failed to get transformer for decryption: %w", err) } pt, _, err := transformer.TransformFromStorage(ctx, req.Ciphertext, emptyContext) if err != nil { - klog.V(2).InfoS("decrypt ciphertext with pulled key", "uid", uid, "err", err) + klog.V(2).ErrorS(err, "failed to decrypt data", "uid", uid) return nil, err } @@ -298,12 +274,142 @@ func (m *LocalKEKService) Decrypt(ctx context.Context, uid string, req *service. // Status returns the status of the remote KMS. func (m *LocalKEKService) Status(ctx context.Context) (*service.StatusResponse, error) { - // TODO(aramase): the response from the remote KMS is funneled through without any validation/action. 
- // This needs to handle the case when remote KEK has changed. The local KEK needs to be rotated and - // re-encrypted with the new remote KEK. - return m.remoteKMS.Status(ctx) + resp, err := m.remoteKMS.Status(ctx) + if err != nil { + return nil, err + } + if err := validateRemoteKMSStatusResponse(resp); err != nil { + return nil, err + } + + r := copyStatusResponse(resp) + // if the remote KMS KeyID has changed, we need to rotate the local KEK + lk := m.getLocalKEK() + if lk.transformer != nil && r.KeyID != lk.remoteKMSResponse.KeyID { + if err := m.rotateLocalKEK(ctx, r.KeyID); err != nil { + klog.ErrorS(err, "failed to rotate local KEK", "expectedKeyID", r.KeyID, "currentKeyID", lk.remoteKMSResponse.KeyID) + // if rotation fails, we will overwrite the keyID to the one we are currently using + // for encryption as localKEKService is the authoritative source for the keyID. + r.KeyID = lk.remoteKMSResponse.KeyID + // TODO(aramase): we are currently not returning the error if rotation fails. We should + // allow the failed state for an arbitrary time period and return the error if the state + // is not eventually fixed. + } + } + + var aggregateHealthz []string + if r.Healthz != "ok" { + aggregateHealthz = append(aggregateHealthz, r.Healthz) + } + + if !m.isReady.Load() { + // if the localKEKService is not ready, we will set the healthz status to not ready + klog.V(2).InfoS("localKEKService is not ready", "keyID", r.KeyID) + aggregateHealthz = append(aggregateHealthz, "localKEKService is not ready") + } + if len(aggregateHealthz) > 0 { + r.Healthz = strings.Join(aggregateHealthz, "; ") + } + + return r, nil } +// rotateLocalKEK rotates the local KEK by generating a new local KEK and encrypting it with the +// remote KMS. 
+func (m *LocalKEKService) rotateLocalKEK(ctx context.Context, expectedKeyID string) error { + uid := string(uuid.NewUUID()) + if err := m.generateLocalKEK(ctx, uid, expectedKeyID); err != nil { + klog.V(2).ErrorS(err, "failed to generate local KEK as part of rotation", "uid", uid) + return fmt.Errorf("[uid=%s] failed to generate local KEK as part of rotation: %w", uid, err) + } + return nil +} + +// generateLocalKEK generates a new local KEK and encrypts it with the remote KMS. +// if expectedKeyID is not empty, it will check if the keyID returned from the remote KMS matches +// the expected keyID. If the keyID does not match, it will continue using the existing local KEK +// and return an error. +func (m *LocalKEKService) generateLocalKEK(ctx context.Context, uid, expectedKeyID string) error { + m.mu.Lock() + defer m.mu.Unlock() + + lk := m.getLocalKEK() + // if the localKEK was generated in the last pollInterval duration, we will not generate a new + // localKEK. This is to avoid regenerating a new localKEK for queued requests. 
+ if lk.transformer != nil && m.clock.Since(lk.generatedAt) < m.pollInterval { + return nil + } + + key, err := generateKey(keyLength) + if err != nil { + return fmt.Errorf("failed to generate local KEK: %w", err) + } + block, err := aes.NewCipher(key) + if err != nil { + return fmt.Errorf("failed to create cipher block: %w", err) + } + + resp, err := m.remoteKMS.Encrypt(ctx, uid, key) + if err != nil { + return fmt.Errorf("failed to encrypt local KEK: %w", err) + } + if err = validateRemoteKMSEncryptResponse(resp); err != nil { + return fmt.Errorf("invalid response from remote KMS: %w", err) + } + if expectedKeyID != "" && resp.KeyID != expectedKeyID { + return fmt.Errorf("keyID returned from remote KMS does not match expected keyID") + } + + now := m.clock.Now() + m.localKEKTracker.Store(&localKEK{ + encKEK: resp.Ciphertext, + expiry: now.Add(m.keyMaxAge), + usage: atomic.Uint64{}, + transformer: aestransformer.NewGCMTransformer(block), + remoteKMSResponse: copyResponseAndAddLocalKEKAnnotation(resp), + generatedAt: now, + }) + + return nil +} + +func (m *LocalKEKService) getLocalKEK() *localKEK { + lk := m.localKEKTracker.Load() + if lk == nil { + return &localKEK{} + } + return lk +} + +// copyResponseAndAddLocalKEKAnnotation returns a copy of the remoteKMSResponse with the +// referenceKEKAnnotationKey added to the annotations. +func copyResponseAndAddLocalKEKAnnotation(resp *service.EncryptResponse) *service.EncryptResponse { + annotations := make(map[string][]byte, len(resp.Annotations)+numAnnotations) + for s, bytes := range resp.Annotations { + s := s + bytes := bytes + annotations[s] = bytes + } + annotations[referenceKEKAnnotationKey] = resp.Ciphertext + + return &service.EncryptResponse{ + // Ciphertext is not set on purpose - it is different per Encrypt call + KeyID: resp.KeyID, + Annotations: annotations, + } +} + +// copyStatusResponse returns a copy of the remote KMS status response. 
+func copyStatusResponse(resp *service.StatusResponse) *service.StatusResponse { + return &service.StatusResponse{ + Healthz: resp.Healthz, + Version: resp.Version, + KeyID: resp.KeyID, + } +} + +// annotationsWithoutReferenceKeys returns a copy of the annotations without the reference implementation +// annotations. func annotationsWithoutReferenceKeys(annotations map[string][]byte) map[string][]byte { if len(annotations) <= numAnnotations { return nil @@ -320,7 +426,8 @@ func annotationsWithoutReferenceKeys(annotations map[string][]byte) map[string][ return m } -func validateRemoteKMSResponse(resp *service.EncryptResponse) error { +// validateRemoteKMSEncryptResponse validates the EncryptResponse from the remote KMS. +func validateRemoteKMSEncryptResponse(resp *service.EncryptResponse) error { // validate annotations don't contain the reference implementation annotations for k := range resp.Annotations { if strings.HasSuffix(k, referenceSuffix) { @@ -330,6 +437,14 @@ func validateRemoteKMSResponse(resp *service.EncryptResponse) error { return nil } +// validateRemoteKMSStatusResponse validates the StatusResponse from the remote KMS. +func validateRemoteKMSStatusResponse(resp *service.StatusResponse) error { + if len(resp.KeyID) == 0 { + return fmt.Errorf("keyID is empty") + } + return nil +} + // generateKey generates a random key using system randomness. func generateKey(length int) (key []byte, err error) { key = make([]byte, length) diff --git a/staging/src/k8s.io/kms/encryption/service_test.go b/staging/src/k8s.io/kms/encryption/service_test.go index 4a46b27fcba..68343413e2c 100644 --- a/staging/src/k8s.io/kms/encryption/service_test.go +++ b/staging/src/k8s.io/kms/encryption/service_test.go @@ -17,18 +17,24 @@ limitations under the License. 
package encryption import ( + "bytes" "context" "encoding/base64" "errors" "reflect" + "strings" "sync" "testing" "time" + "k8s.io/apimachinery/pkg/util/rand" + "k8s.io/apimachinery/pkg/util/wait" "k8s.io/kms/service" + testingclock "k8s.io/utils/clock/testing" ) func TestCopyResponseAndAddLocalKEKAnnotation(t *testing.T) { + t.Parallel() testCases := []struct { name string input *service.EncryptResponse @@ -89,6 +95,7 @@ func TestCopyResponseAndAddLocalKEKAnnotation(t *testing.T) { for _, tc := range testCases { tc := tc t.Run(tc.name, func(t *testing.T) { + t.Parallel() got := copyResponseAndAddLocalKEKAnnotation(tc.input) if !reflect.DeepEqual(got, tc.want) { t.Errorf("copyResponseAndAddLocalKEKAnnotation(%v) = %v, want %v", tc.input, got, tc.want) @@ -98,6 +105,7 @@ func TestCopyResponseAndAddLocalKEKAnnotation(t *testing.T) { } func TestAnnotationsWithoutReferenceKeys(t *testing.T) { + t.Parallel() testCases := []struct { name string input map[string][]byte @@ -135,6 +143,7 @@ func TestAnnotationsWithoutReferenceKeys(t *testing.T) { for _, tc := range testCases { tc := tc t.Run(tc.name, func(t *testing.T) { + t.Parallel() got := annotationsWithoutReferenceKeys(tc.input) if !reflect.DeepEqual(got, tc.want) { t.Errorf("annotationsWithoutReferenceKeys(%v) = %v, want %v", tc.input, got, tc.want) @@ -143,7 +152,8 @@ func TestAnnotationsWithoutReferenceKeys(t *testing.T) { } } -func TestValidateRemoteKMSResponse(t *testing.T) { +func TestValidateRemoteKMSEncryptResponse(t *testing.T) { + t.Parallel() testCases := []struct { name string input *service.EncryptResponse @@ -178,7 +188,8 @@ func TestValidateRemoteKMSResponse(t *testing.T) { for _, tc := range testCases { tc := tc t.Run(tc.name, func(t *testing.T) { - got := validateRemoteKMSResponse(tc.input) + t.Parallel() + got := validateRemoteKMSEncryptResponse(tc.input) if got != tc.want { - t.Errorf("validateRemoteKMSResponse(%v) = %v, want %v", tc.input, got, tc.want) + t.Errorf("validateRemoteKMSEncryptResponse(%v) = %v, want %v", tc.input, got, tc.want) } @@ -186,19 +197,66 @@ func
TestValidateRemoteKMSResponse(t *testing.T) { } } +func TestValidateRemoteKMSStatusResponse(t *testing.T) { + t.Parallel() + testCases := []struct { + name string + input *service.StatusResponse + wantErr string + }{ + { + name: "keyID is empty", + input: &service.StatusResponse{ + KeyID: "", + }, + wantErr: "keyID is empty", + }, + { + name: "no error", + input: &service.StatusResponse{ + KeyID: "keyID", + }, + wantErr: "", + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + got := validateRemoteKMSStatusResponse(tc.input) + if tc.wantErr != "" { + if got == nil { + t.Fatalf("validateRemoteKMSStatusResponse(%v) = %v, want %v", tc.input, got, tc.wantErr) + } + if !strings.Contains(got.Error(), tc.wantErr) { + t.Errorf("validateRemoteKMSStatusResponse(%v) = %v, want %v", tc.input, got, tc.wantErr) + } + } else { + if got != nil { + t.Errorf("validateRemoteKMSStatusResponse(%v) = %v, want %v", tc.input, got, tc.wantErr) + } + } + }) + } +} + var _ service.Service = &testRemoteService{} type testRemoteService struct { mu sync.Mutex - keyID string - disabled bool + keyID string + disabled bool + encryptCallCount int + decryptCallCount int } func (s *testRemoteService) Encrypt(ctx context.Context, uid string, plaintext []byte) (*service.EncryptResponse, error) { s.mu.Lock() defer s.mu.Unlock() + s.encryptCallCount++ if s.disabled { return nil, errors.New("failed to encrypt") } @@ -215,6 +273,7 @@ func (s *testRemoteService) Decrypt(ctx context.Context, uid string, req *servic s.mu.Lock() defer s.mu.Unlock() + s.decryptCallCount++ if s.disabled { return nil, errors.New("failed to decrypt") } @@ -231,82 +290,88 @@ func (s *testRemoteService) Status(ctx context.Context) (*service.StatusResponse s.mu.Lock() defer s.mu.Unlock() - if s.disabled { - return nil, errors.New("failed to get status") - } - return &service.StatusResponse{ + resp := &service.StatusResponse{ Version: "v2alpha1", Healthz: "ok", KeyID: s.keyID, -
}, nil + } + if s.disabled { + resp.Healthz = "remote KMS is disabled" + } + return resp, nil } func (s *testRemoteService) SetDisabledStatus(disabled bool) { s.mu.Lock() defer s.mu.Unlock() - s.disabled = true + s.disabled = disabled +} + +func (s *testRemoteService) SetKeyID(keyID string) { + s.mu.Lock() + defer s.mu.Unlock() + s.keyID = keyID +} + +func (s *testRemoteService) EncryptCallCount() int { + s.mu.Lock() + defer s.mu.Unlock() + return s.encryptCallCount +} + +func (s *testRemoteService) DecryptCallCount() int { + s.mu.Lock() + defer s.mu.Unlock() + return s.decryptCallCount } func TestEncrypt(t *testing.T) { - remoteKMS := &testRemoteService{keyID: "test-key-id"} - localKEKService := NewLocalKEKService(remoteKMS) - - validateResponse := func(got *service.EncryptResponse, t *testing.T) { - if len(got.Annotations) != 2 { - t.Fatalf("Encrypt() annotations = %v, want 2 annotations", got.Annotations) - } - if _, ok := got.Annotations[referenceKEKAnnotationKey]; !ok { - t.Fatalf("Encrypt() annotations = %v, want %v", got.Annotations, referenceKEKAnnotationKey) - } - if got.KeyID != remoteKMS.keyID { - t.Fatalf("Encrypt() keyID = %v, want %v", got.KeyID, remoteKMS.keyID) - } - if localKEKService.localTransformer == nil { - t.Fatalf("Encrypt() localTransformer = %v, want non-nil", localKEKService.localTransformer) - } - } - ctx := testContext(t) + remoteKMS := &testRemoteService{keyID: "test-key-id"} + localKEKService := NewLocalKEKService(ctx, remoteKMS) + + waitUntilReady(t, localKEKService) + // local KEK is generated and encryption is successful got, err := localKEKService.Encrypt(ctx, "test-uid", []byte("test-plaintext")) if err != nil { t.Fatalf("Encrypt() error = %v", err) } - validateResponse(got, t) + validateEncryptResponse(t, got, remoteKMS.keyID, localKEKService) // local KEK is used for encryption even when remote KMS is failing remoteKMS.SetDisabledStatus(true) if got, err = localKEKService.Encrypt(ctx, "test-uid", []byte("test-plaintext")); err 
!= nil { t.Fatalf("Encrypt() error = %v", err) } - validateResponse(got, t) + validateEncryptResponse(t, got, remoteKMS.keyID, localKEKService) } func TestEncryptError(t *testing.T) { - remoteKMS := &testRemoteService{keyID: "test-key-id"} - localKEKService := NewLocalKEKService(remoteKMS) - ctx := testContext(t) + remoteKMS := &testRemoteService{keyID: "test-key-id"} + localKEKService := NewLocalKEKService(ctx, remoteKMS) - localKEKGenerationPollTimeout = 5 * time.Second // first time local KEK generation fails because of remote KMS remoteKMS.SetDisabledStatus(true) _, err := localKEKService.Encrypt(ctx, "test-uid", []byte("test-plaintext")) if err == nil { t.Fatalf("Encrypt() error = %v, want non-nil", err) } - if localKEKService.localTransformer != nil { - t.Fatalf("Encrypt() localTransformer = %v, want nil", localKEKService.localTransformer) + lk := localKEKService.getLocalKEK() + if lk.transformer != nil { + t.Fatalf("Encrypt() local KEK transformer = %v, want nil", lk.transformer) } remoteKMS.SetDisabledStatus(false) } func TestDecrypt(t *testing.T) { - remoteKMS := &testRemoteService{keyID: "test-key-id"} - localKEKService := NewLocalKEKService(remoteKMS) - ctx := testContext(t) + remoteKMS := &testRemoteService{keyID: "test-key-id"} + localKEKService := NewLocalKEKService(ctx, remoteKMS) + + waitUntilReady(t, localKEKService) // local KEK is generated and encryption/decryption is successful got, err := localKEKService.Encrypt(ctx, "test-uid", []byte("test-plaintext")) @@ -337,10 +402,11 @@ func TestDecrypt(t *testing.T) { } func TestDecryptError(t *testing.T) { - remoteKMS := &testRemoteService{keyID: "test-key-id"} - localKEKService := NewLocalKEKService(remoteKMS) - ctx := testContext(t) + remoteKMS := &testRemoteService{keyID: "test-key-id"} + localKEKService := NewLocalKEKService(ctx, remoteKMS) + + waitUntilReady(t, localKEKService) got, err := localKEKService.Encrypt(ctx, "test-uid", []byte("test-plaintext")) if err != nil { @@ -353,6 +419,10 @@ func
TestDecryptError(t *testing.T) { } // local KEK for decryption not in cache and remote KMS is failing remoteKMS.SetDisabledStatus(true) + lk := localKEKService.localKEKTracker.Load() + lk.transformer = nil + localKEKService.localKEKTracker.Store(lk) + // clear the cache localKEKService.transformers.Clear() if _, err = localKEKService.Decrypt(ctx, "test-uid", decryptRequest); err == nil { @@ -361,30 +431,318 @@ func TestDecryptError(t *testing.T) { } func TestStatus(t *testing.T) { - remoteKMS := &testRemoteService{keyID: "test-key-id"} - localKEKService := NewLocalKEKService(remoteKMS) - ctx := testContext(t) + fakeClock := testingclock.NewFakeClock(time.Now()) + remoteKMS := &testRemoteService{keyID: "test-key-id"} + localKEKService := newLocalKEKService(ctx, remoteKMS, 10, 5, 1*time.Second, 100*time.Millisecond, fakeClock) + + waitUntilReady(t, localKEKService) got, err := localKEKService.Status(ctx) if err != nil { t.Fatalf("Status() error = %v", err) } - if got.Version != "v2alpha1" { - t.Fatalf("Status() version = %v, want %v", got.Version, "v2alpha1") - } - if got.Healthz != "ok" { - t.Fatalf("Status() healthz = %v, want %v", got.Healthz, "ok") - } - if got.KeyID != "test-key-id" { - t.Fatalf("Status() keyID = %v, want %v", got.KeyID, "test-key-id") - } + validateStatusResponse(t, got, "v2alpha1", "ok", "test-key-id") + fakeClock.Step(1 * time.Second) // remote KMS is failing remoteKMS.SetDisabledStatus(true) - if _, err = localKEKService.Status(ctx); err == nil { - t.Fatalf("Status() error = %v, want non-nil", err) + // remote KMS keyID changed but local KEK not rotated because of remote KMS failure + // the keyID in status should be the old keyID + // the error should still be nil + remoteKMS.SetKeyID("test-key-id-2") + + if got, err = localKEKService.Status(ctx); err != nil { + t.Fatalf("Status() error = %v, want nil", err) } + validateStatusResponse(t, got, "v2alpha1", "remote KMS is disabled", "test-key-id") + + fakeClock.Step(1 * time.Second) + // wait 
for local KEK to expire and local KEK service ready to be false + wait.PollImmediate(100*time.Millisecond, wait.ForeverTestTimeout, func() (bool, error) { + return !localKEKService.isReady.Load(), nil + }) + + // status response should include the localKEK unhealthy status + if got, err = localKEKService.Status(ctx); err != nil { + t.Fatalf("Status() error = %v, want nil", err) + } + validateStatusResponse(t, got, "v2alpha1", "remote KMS is disabled; localKEKService is not ready", "test-key-id") + + // remote KMS is functional again, local KEK is rotated + remoteKMS.SetDisabledStatus(false) + fakeClock.Step(1 * time.Second) + waitUntilReady(t, localKEKService) + if got, err = localKEKService.Status(ctx); err != nil { + t.Fatalf("Status() error = %v, want nil", err) + } + validateStatusResponse(t, got, "v2alpha1", "ok", "test-key-id-2") +} + +func TestRotationKeyUsage(t *testing.T) { + ctx := testContext(t) + + var record sync.Map + + fakeClock := testingclock.NewFakeClock(time.Now()) + remoteKMS := &testRemoteService{keyID: "test-key-id"} + localKEKService := newLocalKEKService(ctx, remoteKMS, 10, 5, 1*time.Minute, 100*time.Millisecond, fakeClock) + waitUntilReady(t, localKEKService) + lk := localKEKService.localKEKTracker.Load() + encLocalKEK := lk.encKEK + + // check only single call for Encrypt to remote KMS + if remoteKMS.EncryptCallCount() != 1 { + t.Fatalf("Encrypt() remoteKMS.EncryptCallCount() = %v, want %v", remoteKMS.EncryptCallCount(), 1) + } + + var wg sync.WaitGroup + for i := 0; i < 6; i++ { + wg.Add(1) + go func() { + defer wg.Done() + resp, err := localKEKService.Encrypt(ctx, "test-uid", []byte(rand.String(32))) + if err != nil { + t.Errorf("Encrypt() error = %v", err) + } else if v, ok := resp.Annotations[referenceKEKAnnotationKey]; !ok || !bytes.Equal(v, encLocalKEK) { + t.Errorf("Encrypt() annotations = %v, want %v", resp.Annotations, encLocalKEK) + } else { + record.Store(resp, nil) + } + }() + } + wg.Wait() + + fakeClock.Step(30 * time.Second) + rotated :=
false + // wait for the local KEK to be rotated + wait.PollImmediate(100*time.Millisecond, wait.ForeverTestTimeout, func() (bool, error) { + // local KEK must have been rotated after 5 usages + lk = localKEKService.localKEKTracker.Load() + rotated = !bytes.Equal(lk.encKEK, encLocalKEK) + return rotated, nil + }) + if !rotated { + t.Fatalf("local KEK must have been rotated") + } + if remoteKMS.EncryptCallCount() != 2 { + t.Fatalf("Encrypt() remoteKMS.EncryptCallCount() = %v, want %v", remoteKMS.EncryptCallCount(), 2) + } + + // new local KEK must be used for encryption + for i := 0; i < 5; i++ { + wg.Add(1) + go func() { + defer wg.Done() + resp, err := localKEKService.Encrypt(ctx, "test-uid", []byte(rand.String(32))) + if err != nil { + t.Errorf("Encrypt() error = %v", err) + } else if v, ok := resp.Annotations[referenceKEKAnnotationKey]; !ok || !bytes.Equal(v, lk.encKEK) { + t.Errorf("Encrypt() annotations = %v, want %v", resp.Annotations, lk.encKEK) + } else { + record.Store(resp, nil) + } + }() + } + wg.Wait() + + // check we can decrypt data encrypted with the old and new local KEKs + record.Range(func(key, _ any) bool { + k := key.(*service.EncryptResponse) + decryptRequest := &service.DecryptRequest{ + Ciphertext: k.Ciphertext, + Annotations: k.Annotations, + KeyID: k.KeyID, + } + if _, err := localKEKService.Decrypt(ctx, "test-uid", decryptRequest); err != nil { + t.Fatalf("Decrypt() error = %v", err) + } + return true + }) + + // Out of the 11 calls to Decrypt: + // - 5 should be using the current local KEK + // - 1 out of the 6 should generate a decrypt call to the remote KMS as the local KEK not in cache + // - 5 out of the 6 should use the cached local KEK after 1st decrypt call to the remote KMS + assertCallCount(t, remoteKMS, localKEKService) +} + +func TestRotationKeyExpiry(t *testing.T) { + ctx := testContext(t) + + var record sync.Map + + fakeClock := testingclock.NewFakeClock(time.Now()) + remoteKMS := &testRemoteService{keyID: "test-key-id"} + localKEKService
:= newLocalKEKService(ctx, remoteKMS, 10, 5, 5*time.Second, 100*time.Millisecond, fakeClock) + waitUntilReady(t, localKEKService) + lk := localKEKService.localKEKTracker.Load() + encLocalKEK := lk.encKEK + + var wg sync.WaitGroup + for i := 0; i < 3; i++ { + wg.Add(1) + go func() { + defer wg.Done() + resp, err := localKEKService.Encrypt(ctx, "test-uid", []byte("test-plaintext")) + if err != nil { + t.Errorf("Encrypt() error = %v", err) + } else if v, ok := resp.Annotations[referenceKEKAnnotationKey]; !ok || !bytes.Equal(v, encLocalKEK) { + t.Errorf("Encrypt() annotations = %v, want %v", resp.Annotations, encLocalKEK) + } else { + record.Store(resp, nil) + } + }() + } + wg.Wait() + + // check local KEK has only been used 3 times and still under the suggested usage + if lk.usage.Load() != 3 { + t.Fatalf("local KEK usage = %v, want %v", lk.usage.Load(), 3) + } + + // advance the clock to trigger key expiry + fakeClock.Step(6 * time.Second) + + rotated := false + // wait for the local KEK to be rotated due to key expiry + wait.PollImmediate(100*time.Millisecond, wait.ForeverTestTimeout, func() (bool, error) { + // local KEK must have been rotated after the key max age + t.Log("waiting for local KEK to be rotated") + lk = localKEKService.localKEKTracker.Load() + rotated = !bytes.Equal(lk.encKEK, encLocalKEK) + return rotated, nil + }) + if !rotated { + t.Fatalf("local KEK must have been rotated") + } + + // new local KEK must be used for encryption + for i := 0; i < 5; i++ { + wg.Add(1) + go func() { + defer wg.Done() + resp, err := localKEKService.Encrypt(ctx, "test-uid", []byte("test-plaintext")) + if err != nil { + t.Errorf("Encrypt() error = %v", err) + } else if v, ok := resp.Annotations[referenceKEKAnnotationKey]; !ok || !bytes.Equal(v, lk.encKEK) { + t.Errorf("Encrypt() annotations = %v, want %v", resp.Annotations, lk.encKEK) + } else { + record.Store(resp, nil) + } + }() + } + wg.Wait() + + // check we can decrypt data encrypted with the old and new local KEKs + record.Range(func(key, _
any) bool { + k := key.(*service.EncryptResponse) + decryptRequest := &service.DecryptRequest{ + Ciphertext: k.Ciphertext, + Annotations: k.Annotations, + KeyID: k.KeyID, + } + if _, err := localKEKService.Decrypt(ctx, "test-uid", decryptRequest); err != nil { + t.Fatalf("Decrypt() error = %v", err) + } + return true + }) + + // Out of the 8 calls to Decrypt: + // - 5 should be using the current local KEK + // - 1 out of the 3 should generate a decrypt call to the remote KMS as the local KEK not in cache + // - 2 out of the 3 should use the cached local KEK after 1st decrypt call to the remote KMS + assertCallCount(t, remoteKMS, localKEKService) +} + +func TestRotationRemoteKeyIDChanged(t *testing.T) { + ctx := testContext(t) + + var record sync.Map + + fakeClock := testingclock.NewFakeClock(time.Now()) + remoteKMS := &testRemoteService{keyID: "test-key-id"} + localKEKService := newLocalKEKService(ctx, remoteKMS, 10, 5, 1*time.Minute, 100*time.Millisecond, fakeClock) + waitUntilReady(t, localKEKService) + lk := localKEKService.localKEKTracker.Load() + encLocalKEK := lk.encKEK + + var wg sync.WaitGroup + for i := 0; i < 3; i++ { + wg.Add(1) + go func() { + defer wg.Done() + resp, err := localKEKService.Encrypt(ctx, "test-uid", []byte("test-plaintext")) + if err != nil { + t.Errorf("Encrypt() error = %v", err) + } else if v, ok := resp.Annotations[referenceKEKAnnotationKey]; !ok || !bytes.Equal(v, encLocalKEK) { + t.Errorf("Encrypt() annotations = %v, want %v", resp.Annotations, encLocalKEK) + } else { + record.Store(resp, nil) + } + }() + } + wg.Wait() + + // check local KEK has only been used 3 times and still under the suggested usage + if lk.usage.Load() != 3 { + t.Fatalf("local KEK usage = %v, want %v", lk.usage.Load(), 3) + } + + fakeClock.Step(30 * time.Second) + // change the remote key ID + remoteKMS.SetKeyID("test-key-id-2") + if _, err := localKEKService.Status(ctx); err != nil { + t.Fatalf("Status() error = %v", err) + } + + rotated := false + // wait for the local KEK
to be rotated due to remote key ID change + wait.PollImmediate(100*time.Millisecond, wait.ForeverTestTimeout, func() (bool, error) { + lk = localKEKService.localKEKTracker.Load() + rotated = !bytes.Equal(lk.encKEK, encLocalKEK) + return rotated, nil + }) + if !rotated { + t.Fatalf("local KEK must have been rotated") + } + + // new local KEK must be used for encryption + for i := 0; i < 5; i++ { + wg.Add(1) + go func() { + defer wg.Done() + resp, err := localKEKService.Encrypt(ctx, "test-uid", []byte("test-plaintext")) + if err != nil { + t.Errorf("Encrypt() error = %v", err) + } else if v, ok := resp.Annotations[referenceKEKAnnotationKey]; !ok || !bytes.Equal(v, lk.encKEK) { + t.Errorf("Encrypt() annotations = %v, want %v", resp.Annotations, lk.encKEK) + } else { + record.Store(resp, nil) + } + }() + } + wg.Wait() + + // check we can decrypt data encrypted with the old and new local KEKs + record.Range(func(key, _ any) bool { + k := key.(*service.EncryptResponse) + decryptRequest := &service.DecryptRequest{ + Ciphertext: k.Ciphertext, + Annotations: k.Annotations, + KeyID: k.KeyID, + } + if _, err := localKEKService.Decrypt(ctx, "test-uid", decryptRequest); err != nil { + t.Fatalf("Decrypt() error = %v", err) + } + return true + }) + + // Out of the 8 calls to Decrypt: + // - 5 should be using the current local KEK + // - 1 out of the 3 should generate a decrypt call to the remote KMS as the local KEK not in cache + // - 2 out of the 3 should use the cached local KEK after 1st decrypt call to the remote KMS + assertCallCount(t, remoteKMS, localKEKService) } func testContext(t *testing.T) context.Context { @@ -392,3 +750,51 @@ func testContext(t *testing.T) context.Context { t.Cleanup(cancel) return ctx } + +func waitUntilReady(t *testing.T, s *LocalKEKService) { + t.Helper() + if err := wait.PollImmediate(100*time.Millisecond, wait.ForeverTestTimeout, func() (bool, error) { + return s.isReady.Load(), nil + }); err != nil { + t.Fatalf("local KEK service did not become ready: %v", err) + } +} + +func validateEncryptResponse(t *testing.T, got *service.EncryptResponse,
wantKeyID string, localKEKService *LocalKEKService) { + t.Helper() + if len(got.Annotations) != 2 { + t.Fatalf("Encrypt() annotations = %v, want 2 annotations", got.Annotations) + } + if _, ok := got.Annotations[referenceKEKAnnotationKey]; !ok { + t.Fatalf("Encrypt() annotations = %v, want %v", got.Annotations, referenceKEKAnnotationKey) + } + if got.KeyID != wantKeyID { + t.Fatalf("Encrypt() keyID = %v, want %v", got.KeyID, wantKeyID) + } + if localKEKService.localKEKTracker.Load() == nil { + t.Fatalf("Encrypt() localKEKTracker = %v, want non-nil localKEK", localKEKService.localKEKTracker.Load()) + } +} + +func validateStatusResponse(t *testing.T, got *service.StatusResponse, wantVersion, wantHealthz, wantKeyID string) { + t.Helper() + if got.Version != wantVersion { + t.Fatalf("Status() version = %v, want %v", got.Version, wantVersion) + } + if !strings.EqualFold(got.Healthz, wantHealthz) { + t.Fatalf("Status() healthz = %v, want %v", got.Healthz, wantHealthz) + } + if got.KeyID != wantKeyID { + t.Fatalf("Status() keyID = %v, want %v", got.KeyID, wantKeyID) + } +} + +func assertCallCount(t *testing.T, remoteKMS *testRemoteService, localKEKService *LocalKEKService) { + t.Helper() + if remoteKMS.DecryptCallCount() != 1 { + t.Fatalf("Decrypt() remoteKMS.DecryptCallCount() = %v, want %v", remoteKMS.DecryptCallCount(), 1) + } + if localKEKService.transformers.Len() != 1 { + t.Fatalf("Decrypt() localKEKService.transformers.Len() = %v, want %v", localKEKService.transformers.Len(), 1) + } +} diff --git a/staging/src/k8s.io/kms/go.mod b/staging/src/k8s.io/kms/go.mod index 9dc67bd96a7..c5fea63d185 100644 --- a/staging/src/k8s.io/kms/go.mod +++ b/staging/src/k8s.io/kms/go.mod @@ -15,6 +15,7 @@ require ( require ( github.com/go-logr/logr v1.2.3 // indirect github.com/golang/protobuf v1.5.2 // indirect + github.com/google/uuid v1.3.0 // indirect golang.org/x/net v0.7.0 // indirect golang.org/x/sys v0.5.0 // indirect golang.org/x/text v0.7.0 // indirect diff --git 
a/staging/src/k8s.io/kms/go.sum b/staging/src/k8s.io/kms/go.sum index 551b7d52343..d145bcd927d 100644 --- a/staging/src/k8s.io/kms/go.sum +++ b/staging/src/k8s.io/kms/go.sum @@ -49,6 +49,8 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= +github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=