From 804d8b205272fb61a6391b589683dfd76e48c834 Mon Sep 17 00:00:00 2001 From: Anish Ramasekar Date: Wed, 15 Mar 2023 01:25:48 +0000 Subject: [PATCH] [KMSv2] remove key hierarchy in reference implementation Signed-off-by: Anish Ramasekar --- staging/src/k8s.io/kms/go.mod | 3 +- staging/src/k8s.io/kms/go.sum | 2 - .../k8s.io/kms/internal/plugins/mock/go.mod | 2 - .../k8s.io/kms/internal/plugins/mock/go.sum | 2 - .../kms/internal/plugins/mock/plugin.go | 17 +- .../src/k8s.io/kms/pkg/hierarchy/hierarchy.go | 456 ---------- .../kms/pkg/hierarchy/hierarchy_test.go | 833 ------------------ 7 files changed, 12 insertions(+), 1303 deletions(-) delete mode 100644 staging/src/k8s.io/kms/pkg/hierarchy/hierarchy.go delete mode 100644 staging/src/k8s.io/kms/pkg/hierarchy/hierarchy_test.go diff --git a/staging/src/k8s.io/kms/go.mod b/staging/src/k8s.io/kms/go.mod index 59702d36e75..6a653cdba95 100644 --- a/staging/src/k8s.io/kms/go.mod +++ b/staging/src/k8s.io/kms/go.mod @@ -10,19 +10,18 @@ require ( k8s.io/apimachinery v0.0.0 k8s.io/client-go v0.0.0 k8s.io/klog/v2 v2.90.1 - k8s.io/utils v0.0.0-20230209194617-a36077c30491 ) require ( github.com/go-logr/logr v1.2.3 // indirect github.com/golang/protobuf v1.5.3 // indirect - github.com/google/uuid v1.3.0 // indirect golang.org/x/net v0.8.0 // indirect golang.org/x/sys v0.6.0 // indirect golang.org/x/text v0.8.0 // indirect golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 // indirect google.golang.org/genproto v0.0.0-20220502173005-c8bf987b8c21 // indirect google.golang.org/protobuf v1.28.1 // indirect + k8s.io/utils v0.0.0-20230209194617-a36077c30491 // indirect ) replace ( diff --git a/staging/src/k8s.io/kms/go.sum b/staging/src/k8s.io/kms/go.sum index 781738e59b2..5af750c81c4 100644 --- a/staging/src/k8s.io/kms/go.sum +++ b/staging/src/k8s.io/kms/go.sum @@ -50,8 +50,6 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= -github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= diff --git a/staging/src/k8s.io/kms/internal/plugins/mock/go.mod b/staging/src/k8s.io/kms/internal/plugins/mock/go.mod index 26fe2cc7025..8435d27b4bd 100644 --- a/staging/src/k8s.io/kms/internal/plugins/mock/go.mod +++ b/staging/src/k8s.io/kms/internal/plugins/mock/go.mod @@ -11,7 +11,6 @@ require ( github.com/go-logr/logr v1.2.3 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/golang/protobuf v1.5.3 // indirect - github.com/google/uuid v1.3.0 // indirect golang.org/x/net v0.8.0 // indirect golang.org/x/sys v0.6.0 // indirect golang.org/x/text v0.8.0 // indirect @@ -19,7 +18,6 @@ require ( google.golang.org/genproto v0.0.0-20220502173005-c8bf987b8c21 // indirect google.golang.org/grpc v1.51.0 // indirect google.golang.org/protobuf v1.28.1 // indirect - k8s.io/apimachinery v0.0.0 // indirect k8s.io/client-go v0.0.0 // indirect k8s.io/utils v0.0.0-20230209194617-a36077c30491 // indirect ) diff --git a/staging/src/k8s.io/kms/internal/plugins/mock/go.sum b/staging/src/k8s.io/kms/internal/plugins/mock/go.sum index 781738e59b2..5af750c81c4 100644 --- a/staging/src/k8s.io/kms/internal/plugins/mock/go.sum +++ b/staging/src/k8s.io/kms/internal/plugins/mock/go.sum @@ -50,8 +50,6 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= -github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= diff --git a/staging/src/k8s.io/kms/internal/plugins/mock/plugin.go b/staging/src/k8s.io/kms/internal/plugins/mock/plugin.go index 4919e58a06a..ad9d2903d1a 100644 --- a/staging/src/k8s.io/kms/internal/plugins/mock/plugin.go +++ b/staging/src/k8s.io/kms/internal/plugins/mock/plugin.go @@ -26,7 +26,6 @@ import ( "k8s.io/klog/v2" "k8s.io/kms/internal" - "k8s.io/kms/pkg/hierarchy" "k8s.io/kms/pkg/service" "k8s.io/kms/pkg/util" ) @@ -55,14 +54,20 @@ func main() { grpcService := service.NewGRPCService( addr, *timeout, - hierarchy.NewLocalKEKService(ctx, remoteKMSService), + remoteKMSService, ) klog.InfoS("starting server", "listenAddr", *listenAddr) - if err := grpcService.ListenAndServe(); err != nil { - klog.ErrorS(err, "failed to serve") - os.Exit(1) - } + go func() { + if err := grpcService.ListenAndServe(); err != nil { + klog.ErrorS(err, "failed to serve") + os.Exit(1) + } + }() + + <-ctx.Done() + klog.InfoS("shutting down server") + grpcService.Shutdown() } // withShutdownSignal returns a copy of the parent context that will close if diff --git a/staging/src/k8s.io/kms/pkg/hierarchy/hierarchy.go b/staging/src/k8s.io/kms/pkg/hierarchy/hierarchy.go deleted file mode 100644 index 62a55516eff..00000000000 --- a/staging/src/k8s.io/kms/pkg/hierarchy/hierarchy.go +++ /dev/null @@ -1,456 +0,0 @@ -/* -Copyright 2023 The Kubernetes Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package hierarchy - -import ( - "bytes" - "context" - "crypto/aes" - "crypto/rand" - "encoding/base64" - "fmt" - "strings" - "sync" - "sync/atomic" - "time" - - "k8s.io/apimachinery/pkg/util/uuid" - "k8s.io/apimachinery/pkg/util/wait" - "k8s.io/klog/v2" - aestransformer "k8s.io/kms/pkg/encrypt/aes" - "k8s.io/kms/pkg/service" - "k8s.io/kms/pkg/value" - "k8s.io/utils/clock" - "k8s.io/utils/lru" -) - -// localKEK is a struct that holds the local KEK and the remote KMS response. -type localKEK struct { - encKEK []byte - usage atomic.Uint64 - expiry time.Time - transformer value.Transformer - remoteKMSResponse *service.EncryptResponse - generatedAt time.Time -} - -var ( - // emptyContext is an empty slice of bytes. This is passed as value.Context to the - // GCM transformer. The grpc interface does not provide any additional authenticated data - // to use with AEAD. - emptyContext = value.DefaultContext([]byte{}) - // errInvalidKMSAnnotationKeySuffix is returned when the annotation key suffix is not allowed. - errInvalidKMSAnnotationKeySuffix = fmt.Errorf("annotation keys are not allowed to use %s", referenceSuffix) -) - -const ( - referenceSuffix = ".reference.encryption.k8s.io" - // referenceKEKAnnotationKey is the key used to store the localKEK in the annotations. - referenceKEKAnnotationKey = "encrypted-kek" + referenceSuffix - numAnnotations = 1 - cacheSize = 1_000 - - // localKEKGenerationPollInterval is the interval at which the local KEK is checked for rotation. - localKEKGenerationPollInterval = 1 * time.Minute - - // keyLength is the length of the local KEK in bytes. - // This is the same length used for the DEKs generated in kube-apiserver. - keyLength = 32 - // keyMaxUsage is the maximum number of times an AES GCM key can be used - // with a random nonce: 2^32. The local KEK is a transformer that hold an - // AES GCM key. It is based on recommendations from - // https://nvlpubs.nist.gov/nistpubs/Legacy/SP/nistspecialpublication800-38d.pdf. - // It is reduced by one to be comparable with a atomic.Uint32. - // We picked a arbitrary number that is less than the max usage of the local KEK. - keyMaxUsage = 1<<22 - 1 - // keySuggestedUsage is a threshold that triggers the rotation of a new local KEK. It means that half - // the number of times a local KEK can be used has been reached. - keySuggestedUsage = 1 << 21 - // keyMaxAge is the maximum age of a local KEK. It is not a cryptographic necessity. - keyMaxAge = 7 * 24 * time.Hour -) - -var _ service.Service = &LocalKEKService{} - -// LocalKEKService adds an additional KEK layer to reduce calls to the remote KMS. -// The local KEK is generated at startup in a controller loop and stored in the -// LocalKEKService. This KEK is used for all encryption operations. For the decrypt -// operation, if the encrypted local KEK is not found in the cache, the remote KMS -// is used to decrypt the local KEK. -type LocalKEKService struct { - mu sync.Mutex - // remoteKMS is the remote kms that is used to encrypt and decrypt the local KEKs. - remoteKMS service.Service - // localKEKTracker is a atomic pointer to avoid locks. This is used to store the local KEK. - localKEKTracker atomic.Pointer[localKEK] - // transformers is a thread-safe LRU cache which caches decrypted DEKs indexed by their encrypted form. - // The cache is only used for the decrypt operation. - transformers *lru.Cache - // isReady is an atomic boolean that indicates if the localKEK service is ready for encryption. - isReady atomic.Bool - - keyMaxUsage uint64 - keySuggestedUsage uint64 - keyMaxAge time.Duration - - pollInterval time.Duration - - clock clock.Clock -} - -// NewLocalKEKService is being initialized with a remote KMS service. -// The local KEK is generated in a controller loop. The local KEK is used for all -// encryption operations. -func NewLocalKEKService(ctx context.Context, remoteService service.Service) *LocalKEKService { - return newLocalKEKService(ctx, remoteService, keyMaxUsage, keySuggestedUsage, keyMaxAge, localKEKGenerationPollInterval, clock.RealClock{}) -} - -func newLocalKEKService(ctx context.Context, remoteService service.Service, maxUsage, suggestedUsage uint64, maxAge, pollInterval time.Duration, clock clock.Clock) *LocalKEKService { - localKEKService := &LocalKEKService{ - remoteKMS: remoteService, - transformers: lru.New(cacheSize), - keyMaxUsage: maxUsage, - keySuggestedUsage: suggestedUsage, - keyMaxAge: maxAge, - pollInterval: pollInterval, - clock: clock, - } - go localKEKService.run(ctx) - return localKEKService -} - -// Run method creates a new local KEK when the following thresholds are met: -// - the local KEK is used more often than keySuggestedUsage times or -// - the local KEK is older than a localExpiry. -// -// this method starts the controller and blocks until the context is cancelled. -func (m *LocalKEKService) run(ctx context.Context) { - // same as wait.UntilWithContext but with a custom clock - wait.BackoffUntil(func() { - lk := m.getLocalKEK() - // this is the first time the local KEK is generated - localKEKNotGenerated := lk.transformer == nil - // the local KEK is used more often than keySuggestedUsage times - localKEKUsageThresholdReached := lk.usage.Load() > m.keySuggestedUsage - // the local KEK is older than the expiry - localKEKExpired := m.clock.Now().After(lk.expiry) - - if localKEKNotGenerated || localKEKUsageThresholdReached || localKEKExpired { - uid := string(uuid.NewUUID()) - err := m.generateLocalKEK(ctx, uid, "") - if err == nil { - m.isReady.Store(true) - return - } - klog.V(2).ErrorS(err, "failed to generate local KEK", "uid", uid) - // if the local KEK is expired and we cannot generate a new one, we set - // isReady to false because we can no longer encrypt new data. - if localKEKExpired { - m.isReady.Store(false) - } - } - }, wait.NewJitteredBackoffManager(m.pollInterval, 0, m.clock), true, ctx.Done()) -} - -// getTransformerForEncryption returns the local KEK as localTransformer, the corresponding -// remoteKMSResponse and a potential error. -// On every use the localUsage is incremented by one. -// It is assumed that only one encryption will happen with the returned transformer. -func (m *LocalKEKService) getTransformerForEncryption(uid string) (value.Transformer, *service.EncryptResponse, error) { - lk := m.getLocalKEK() - // localKEK is not initialized yet - if lk.transformer == nil { - return nil, nil, fmt.Errorf("local KEK is not initialized") - } - - if m.clock.Now().After(lk.expiry) { - return nil, nil, fmt.Errorf("local KEK has expired at %v", lk.expiry) - } - - if counter := lk.usage.Add(1); counter >= m.keyMaxUsage { - return nil, nil, fmt.Errorf("local KEK has reached maximum usage of %d", keyMaxUsage) - } - - return lk.transformer, lk.remoteKMSResponse, nil -} - -// Encrypt encrypts the plaintext with the localKEK. -func (m *LocalKEKService) Encrypt(ctx context.Context, uid string, pt []byte) (*service.EncryptResponse, error) { - transformer, resp, err := m.getTransformerForEncryption(uid) - if err != nil { - klog.V(2).ErrorS(err, "failed to get transformer for encryption", "uid", uid) - return nil, err - } - - ct, err := transformer.TransformToStorage(ctx, pt, emptyContext) - if err != nil { - klog.V(2).ErrorS(err, "failed to encrypt data", "uid", uid) - return nil, err - } - - return &service.EncryptResponse{ - Ciphertext: ct, - KeyID: resp.KeyID, - Annotations: resp.Annotations, - }, nil -} - -// getTransformerForDecryption returns the transformer for the given encryptedKEK. -// - If the encryptedKEK is the current localKEK, the localKEK is returned. -// - If the encryptedKEK is not the current localKEK, the cache is checked. -// - If the encryptedKEK is not found in the cache, the remote KMS is used to decrypt the encryptedKEK. -func (m *LocalKEKService) getTransformerForDecryption(ctx context.Context, uid string, req *service.DecryptRequest) (value.Transformer, error) { - encKEK := req.Annotations[referenceKEKAnnotationKey] - - // check if the key required for decryption is the current local KEK - // that's being used for encryption - lk := m.getLocalKEK() - if lk.transformer != nil && bytes.Equal(lk.encKEK, encKEK) { - return lk.transformer, nil - } - // check if the key required for decryption is already in the cache - if _transformer, found := m.transformers.Get(base64.StdEncoding.EncodeToString(encKEK)); found { - return _transformer.(value.Transformer), nil - } - - key, err := m.remoteKMS.Decrypt(ctx, uid, &service.DecryptRequest{ - Ciphertext: encKEK, - KeyID: req.KeyID, - Annotations: annotationsWithoutReferenceKeys(req.Annotations), - }) - if err != nil { - return nil, err - } - - block, err := aes.NewCipher(key) - if err != nil { - return nil, err - } - transformer := aestransformer.NewGCMTransformer(block) - - // Overwrite the plain key with 0s. - copy(key, make([]byte, len(key))) - - m.transformers.Add(base64.StdEncoding.EncodeToString(encKEK), transformer) - - return transformer, nil -} - -// Decrypt attempts to decrypt the ciphertext with the localKEK, a KEK from the -// store, or the remote KMS. -func (m *LocalKEKService) Decrypt(ctx context.Context, uid string, req *service.DecryptRequest) ([]byte, error) { - if _, ok := req.Annotations[referenceKEKAnnotationKey]; !ok { - return nil, fmt.Errorf("unable to find local KEK for request with uid %q", uid) - } - - transformer, err := m.getTransformerForDecryption(ctx, uid, req) - if err != nil { - klog.V(2).ErrorS(err, "failed to get transformer for decryption", "uid", uid) - return nil, fmt.Errorf("failed to get transformer for decryption: %w", err) - } - - pt, _, err := transformer.TransformFromStorage(ctx, req.Ciphertext, emptyContext) - if err != nil { - klog.V(2).ErrorS(err, "failed to decrypt data", "uid", uid) - return nil, err - } - - return pt, nil -} - -// Status returns the status of the remote KMS. -func (m *LocalKEKService) Status(ctx context.Context) (*service.StatusResponse, error) { - resp, err := m.remoteKMS.Status(ctx) - if err != nil { - return nil, err - } - if err := validateRemoteKMSStatusResponse(resp); err != nil { - return nil, err - } - - r := copyStatusResponse(resp) - // if the remote KMS KeyID has changed, we need to rotate the local KEK - lk := m.getLocalKEK() - if lk.transformer != nil && r.KeyID != lk.remoteKMSResponse.KeyID { - if err := m.rotateLocalKEK(ctx, r.KeyID); err != nil { - klog.ErrorS(err, "failed to rotate local KEK", "expectedKeyID", r.KeyID, "currentKeyID", lk.remoteKMSResponse.KeyID) - // if rotation fails, we will overwrite the keyID to the one we are currently using - // for encryption as localKEKService is the authoritative source for the keyID. - r.KeyID = lk.remoteKMSResponse.KeyID - // TODO(aramase): we are currently not returning the error if rotation fails. We should - // allow the failed state for an arbitrary time period and return the error if the state - // is not eventually fixed. - } - } - - var aggregateHealthz []string - if r.Healthz != "ok" { - aggregateHealthz = append(aggregateHealthz, r.Healthz) - } - - if !m.isReady.Load() { - // if the localKEKService is not ready, we will set the healthz status to not ready - klog.V(2).InfoS("localKEKService is not ready", "keyID", r.KeyID) - aggregateHealthz = append(aggregateHealthz, "localKEKService is not ready") - } - if len(aggregateHealthz) > 0 { - r.Healthz = strings.Join(aggregateHealthz, "; ") - } - - return r, nil -} - -// rotateLocalKEK rotates the local KEK by generating a new local KEK and encrypting it with the -// remote KMS. -func (m *LocalKEKService) rotateLocalKEK(ctx context.Context, expectedKeyID string) error { - uid := string(uuid.NewUUID()) - if err := m.generateLocalKEK(ctx, uid, expectedKeyID); err != nil { - klog.V(2).ErrorS(err, "failed to generate local KEK as part of rotation", "uid", uid) - return fmt.Errorf("[uid=%s] failed to generate local KEK as part of rotation: %w", uid, err) - } - return nil -} - -// generateLocalKEK generates a new local KEK and encrypts it with the remote KMS. -// if expectedKeyID is not empty, it will check if the keyID returned from the remote KMS matches -// the expected keyID. If the keyID does not match, it will continue using the existing local KEK -// and return an error. -func (m *LocalKEKService) generateLocalKEK(ctx context.Context, uid, expectedKeyID string) error { - m.mu.Lock() - defer m.mu.Unlock() - - lk := m.getLocalKEK() - // if the localKEK was generated in the last pollInterval duration, we will not generate a new - // localKEK. This is to avoid regenerating a new localKEK for queued requests. - if lk.transformer != nil && m.clock.Since(lk.generatedAt) < m.pollInterval { - return nil - } - - key, err := generateKey(keyLength) - if err != nil { - return fmt.Errorf("failed to generate local KEK: %w", err) - } - block, err := aes.NewCipher(key) - if err != nil { - return fmt.Errorf("failed to create cipher block: %w", err) - } - - resp, err := m.remoteKMS.Encrypt(ctx, uid, key) - if err != nil { - return fmt.Errorf("failed to encrypt local KEK: %w", err) - } - if err = validateRemoteKMSEncryptResponse(resp); err != nil { - return fmt.Errorf("invalid response from remote KMS: %w", err) - } - if expectedKeyID != "" && resp.KeyID != expectedKeyID { - return fmt.Errorf("keyID returned from remote KMS does not match expected keyID") - } - - now := m.clock.Now() - m.localKEKTracker.Store(&localKEK{ - encKEK: resp.Ciphertext, - expiry: now.Add(m.keyMaxAge), - usage: atomic.Uint64{}, - transformer: aestransformer.NewGCMTransformer(block), - remoteKMSResponse: copyResponseAndAddLocalKEKAnnotation(resp), - generatedAt: now, - }) - - return nil -} - -func (m *LocalKEKService) getLocalKEK() *localKEK { - lk := m.localKEKTracker.Load() - if lk == nil { - return &localKEK{} - } - return lk -} - -// copyResponseAndAddLocalKEKAnnotation returns a copy of the remoteKMSResponse with the -// referenceKEKAnnotationKey added to the annotations. -func copyResponseAndAddLocalKEKAnnotation(resp *service.EncryptResponse) *service.EncryptResponse { - annotations := make(map[string][]byte, len(resp.Annotations)+numAnnotations) - for s, bytes := range resp.Annotations { - s := s - bytes := bytes - annotations[s] = bytes - } - annotations[referenceKEKAnnotationKey] = resp.Ciphertext - - return &service.EncryptResponse{ - // Ciphertext is not set on purpose - it is different per Encrypt call - KeyID: resp.KeyID, - Annotations: annotations, - } -} - -// copyStatusResponse returns a copy of the remote KMS status response. -func copyStatusResponse(resp *service.StatusResponse) *service.StatusResponse { - return &service.StatusResponse{ - Healthz: resp.Healthz, - Version: resp.Version, - KeyID: resp.KeyID, - } -} - -// annotationsWithoutReferenceKeys returns a copy of the annotations without the reference implementation -// annotations. -func annotationsWithoutReferenceKeys(annotations map[string][]byte) map[string][]byte { - if len(annotations) <= numAnnotations { - return nil - } - - m := make(map[string][]byte, len(annotations)-numAnnotations) - for k, v := range annotations { - k, v := k, v - if strings.HasSuffix(k, referenceSuffix) { - continue - } - m[k] = v - } - return m -} - -// validateRemoteKMSEncryptResponse validates the EncryptResponse from the remote KMS. -func validateRemoteKMSEncryptResponse(resp *service.EncryptResponse) error { - // validate annotations don't contain the reference implementation annotations - for k := range resp.Annotations { - if strings.HasSuffix(k, referenceSuffix) { - return errInvalidKMSAnnotationKeySuffix - } - } - return nil -} - -// validateRemoteKMSStatusResponse validates the StatusResponse from the remote KMS. -func validateRemoteKMSStatusResponse(resp *service.StatusResponse) error { - if len(resp.KeyID) == 0 { - return fmt.Errorf("keyID is empty") - } - return nil -} - -// generateKey generates a random key using system randomness. -func generateKey(length int) (key []byte, err error) { - key = make([]byte, length) - if _, err = rand.Read(key); err != nil { - return nil, err - } - - return key, nil -} diff --git a/staging/src/k8s.io/kms/pkg/hierarchy/hierarchy_test.go b/staging/src/k8s.io/kms/pkg/hierarchy/hierarchy_test.go deleted file mode 100644 index 7ffdcb29d2f..00000000000 --- a/staging/src/k8s.io/kms/pkg/hierarchy/hierarchy_test.go +++ /dev/null @@ -1,833 +0,0 @@ -/* -Copyright 2023 The Kubernetes Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package hierarchy - -import ( - "bytes" - "context" - "encoding/base64" - "errors" - "reflect" - "strings" - "sync" - "testing" - "time" - - "k8s.io/apimachinery/pkg/util/rand" - "k8s.io/apimachinery/pkg/util/wait" - "k8s.io/kms/pkg/service" - testingclock "k8s.io/utils/clock/testing" -) - -const ( - testAnnotationKey = "version.encryption.remote.io" - testAnnotationKeyVersion = "key-version.encryption.remote.io" -) - -func TestCopyResponseAndAddLocalKEKAnnotation(t *testing.T) { - t.Parallel() - testCases := []struct { - name string - input *service.EncryptResponse - want *service.EncryptResponse - }{ - { - name: "annotations is nil", - input: &service.EncryptResponse{ - Ciphertext: []byte("encryptedLocalKEK"), - KeyID: "keyID", - Annotations: nil, - }, - want: &service.EncryptResponse{ - KeyID: "keyID", - Annotations: map[string][]byte{ - referenceKEKAnnotationKey: []byte("encryptedLocalKEK"), - }, - }, - }, - { - name: "remote KMS sent 1 annotation", - input: &service.EncryptResponse{ - Ciphertext: []byte("encryptedLocalKEK"), - KeyID: "keyID", - Annotations: map[string][]byte{ - testAnnotationKey: []byte("1"), - }, - }, - want: &service.EncryptResponse{ - KeyID: "keyID", - Annotations: map[string][]byte{ - testAnnotationKey: []byte("1"), - referenceKEKAnnotationKey: []byte("encryptedLocalKEK"), - }, - }, - }, - { - name: "remote KMS sent 2 annotations", - input: &service.EncryptResponse{ - Ciphertext: []byte("encryptedLocalKEK"), - KeyID: "keyID", - Annotations: map[string][]byte{ - testAnnotationKey: []byte("1"), - testAnnotationKeyVersion: []byte("2"), - }, - }, - want: &service.EncryptResponse{ - KeyID: "keyID", - Annotations: map[string][]byte{ - testAnnotationKey: []byte("1"), - testAnnotationKeyVersion: []byte("2"), - referenceKEKAnnotationKey: []byte("encryptedLocalKEK"), - }, - }, - }, - } - - for _, tc := range testCases { - tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - got := copyResponseAndAddLocalKEKAnnotation(tc.input) - if !reflect.DeepEqual(got, tc.want) { - t.Errorf("copyResponseAndAddLocalKEKAnnotation(%v) = %v, want %v", tc.input, got, tc.want) - } - }) - } -} - -func TestAnnotationsWithoutReferenceKeys(t *testing.T) { - t.Parallel() - testCases := []struct { - name string - input map[string][]byte - want map[string][]byte - }{ - { - name: "annotations is nil", - input: nil, - want: nil, - }, - { - name: "annotations is empty", - input: map[string][]byte{}, - want: nil, - }, - { - name: "annotations only contains reference keys", - input: map[string][]byte{ - referenceKEKAnnotationKey: []byte("encryptedLocalKEK"), - }, - want: nil, - }, - { - name: "annotations contains 1 reference key and 1 other key", - input: map[string][]byte{ - referenceKEKAnnotationKey: []byte("encryptedLocalKEK"), - testAnnotationKey: []byte("1"), - }, - want: map[string][]byte{ - testAnnotationKey: []byte("1"), - }, - }, - } - - for _, tc := range testCases { - tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - got := annotationsWithoutReferenceKeys(tc.input) - if !reflect.DeepEqual(got, tc.want) { - t.Errorf("annotationsWithoutReferenceKeys(%v) = %v, want %v", tc.input, got, tc.want) - } - }) - } -} - -func TestValidateRemoteKMSEncryptResponse(t *testing.T) { - t.Parallel() - testCases := []struct { - name string - input *service.EncryptResponse - want error - }{ - { - name: "annotations is nil", - input: &service.EncryptResponse{}, - want: nil, - }, - { - name: "annotation key contains reference suffix", - input: &service.EncryptResponse{ - Annotations: map[string][]byte{ - "version.reference.encryption.k8s.io": []byte("1"), - }, - }, - want: errInvalidKMSAnnotationKeySuffix, - }, - { - name: "no annotation key contains reference suffix", - input: &service.EncryptResponse{ - Annotations: map[string][]byte{ - testAnnotationKey: []byte("1"), - testAnnotationKeyVersion: []byte("2"), - }, - }, - want: nil, - }, - } - - for _, tc := range testCases { - tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - got := validateRemoteKMSEncryptResponse(tc.input) - if got != tc.want { - t.Errorf("validateRemoteKMSResponse(%v) = %v, want %v", tc.input, got, tc.want) - } - }) - } -} - -func TestValidateRemoteKMSStatusResponse(t *testing.T) { - t.Parallel() - testCases := []struct { - name string - input *service.StatusResponse - wantErr string - }{ - { - name: "keyID is empty", - input: &service.StatusResponse{ - KeyID: "", - }, - wantErr: "keyID is empty", - }, - { - name: "no error", - input: &service.StatusResponse{ - KeyID: "keyID", - }, - wantErr: "", - }, - } - - for _, tc := range testCases { - tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - got := validateRemoteKMSStatusResponse(tc.input) - if tc.wantErr != "" { - if got == nil { - t.Errorf("validateRemoteKMSStatusResponse(%v) = %v, want %v", tc.input, got, tc.wantErr) - } - if !strings.Contains(got.Error(), tc.wantErr) { - t.Errorf("validateRemoteKMSStatusResponse(%v) = %v, want %v", tc.input, got, tc.wantErr) - } - } else { - if got != nil { - t.Errorf("validateRemoteKMSStatusResponse(%v) = %v, want %v", tc.input, got, tc.wantErr) - } - } - }) - } -} - -var _ service.Service = &testRemoteService{} - -type testRemoteService struct { - mu sync.Mutex - - keyID string - disabled bool - encryptCallCount int - decryptCallCount int -} - -func (s *testRemoteService) Encrypt(ctx context.Context, uid string, plaintext []byte) (*service.EncryptResponse, error) { - s.mu.Lock() - defer s.mu.Unlock() - - s.encryptCallCount++ - if s.disabled { - return nil, errors.New("failed to encrypt") - } - return &service.EncryptResponse{ - KeyID: s.keyID, - Ciphertext: []byte(base64.StdEncoding.EncodeToString(plaintext)), - Annotations: map[string][]byte{ - testAnnotationKey: []byte("1"), - }, - }, nil -} - -func (s *testRemoteService) Decrypt(ctx context.Context, uid string, req *service.DecryptRequest) ([]byte, error) { - s.mu.Lock() - defer s.mu.Unlock() - - s.decryptCallCount++ - if s.disabled { - return nil, errors.New("failed to decrypt") - } - if len(req.Annotations) != 1 { - return nil, errors.New("invalid annotations") - } - if v, ok := req.Annotations[testAnnotationKey]; !ok || string(v) != "1" { - return nil, errors.New("invalid version in annotations") - } - return base64.StdEncoding.DecodeString(string(req.Ciphertext)) -} - -func (s *testRemoteService) Status(ctx context.Context) (*service.StatusResponse, error) { - s.mu.Lock() - defer s.mu.Unlock() - - resp := &service.StatusResponse{ - Version: "v2alpha1", - Healthz: "ok", - KeyID: s.keyID, - } - if s.disabled { - resp.Healthz = "remote KMS is disabled" - } - return resp, nil -} - -func (s *testRemoteService) SetDisabledStatus(disabled bool) { - s.mu.Lock() - defer s.mu.Unlock() - s.disabled = disabled -} - -func (s *testRemoteService) SetKeyID(keyID string) { - s.mu.Lock() - defer s.mu.Unlock() - s.keyID = keyID -} - -func (s *testRemoteService) EncryptCallCount() int { - s.mu.Lock() - defer s.mu.Unlock() - return s.encryptCallCount -} - -func (s *testRemoteService) DecryptCallCount() int { - s.mu.Lock() - defer s.mu.Unlock() - return s.decryptCallCount -} - -func TestEncrypt(t *testing.T) { - ctx := testContext(t) - remoteKMS := &testRemoteService{keyID: "test-key-id"} - localKEKService := NewLocalKEKService(ctx, remoteKMS) - - waitUntilReady(t, localKEKService) - - // local KEK is generated and encryption is successful - got, err := localKEKService.Encrypt(ctx, "test-uid", []byte("test-plaintext")) - if err != nil { - t.Fatalf("Encrypt() error = %v", err) - } - validateEncryptResponse(t, got, remoteKMS.keyID, localKEKService) - - // local KEK is used for encryption even when remote KMS is failing - remoteKMS.SetDisabledStatus(true) - if got, err = localKEKService.Encrypt(ctx, "test-uid", []byte("test-plaintext")); err != nil { - t.Fatalf("Encrypt() error = %v", err) - } - validateEncryptResponse(t, got, remoteKMS.keyID, localKEKService) -} - -func TestEncryptError(t *testing.T) { - ctx := testContext(t) - remoteKMS := &testRemoteService{keyID: "test-key-id"} - localKEKService := NewLocalKEKService(ctx, remoteKMS) - - // first time local KEK generation fails because of remote KMS - remoteKMS.SetDisabledStatus(true) - _, err := localKEKService.Encrypt(ctx, "test-uid", []byte("test-plaintext")) - if err == nil { - t.Fatalf("Encrypt() error = %v, want non-nil", err) - } - lk := localKEKService.getLocalKEK() - if lk.transformer != nil { - t.Fatalf("Encrypt() localKEKTracker = %v, want non-nil localKEK", lk) - } - - remoteKMS.SetDisabledStatus(false) -} - -func TestDecrypt(t *testing.T) { - ctx := testContext(t) - remoteKMS := &testRemoteService{keyID: "test-key-id"} - localKEKService := NewLocalKEKService(ctx, remoteKMS) - - waitUntilReady(t, localKEKService) - - // local KEK is generated and encryption/decryption is successful - got, err := localKEKService.Encrypt(ctx, "test-uid", []byte("test-plaintext")) - if err != nil { - t.Fatalf("Encrypt() error = %v", err) - } - if string(got.Ciphertext) == "test-plaintext" { - t.Fatalf("Encrypt() ciphertext = %v, want it to be encrypted", got.Ciphertext) - } - decryptRequest := &service.DecryptRequest{ - Ciphertext: got.Ciphertext, - Annotations: got.Annotations, - KeyID: got.KeyID, - } - plaintext, err := localKEKService.Decrypt(ctx, "test-uid", decryptRequest) - if err != nil { - t.Fatalf("Decrypt() error = %v", err) - } - if string(plaintext) != "test-plaintext" { - t.Fatalf("Decrypt() plaintext = %v, want %v", string(plaintext), "test-plaintext") - } - - // local KEK is used for decryption even when remote KMS is failing - remoteKMS.SetDisabledStatus(true) - if _, err = localKEKService.Decrypt(ctx, "test-uid", decryptRequest); err != nil { - t.Fatalf("Decrypt() error = %v", err) - } -} - -func TestDecryptError(t *testing.T) { - ctx := testContext(t) - remoteKMS := &testRemoteService{keyID: "test-key-id"} - localKEKService := NewLocalKEKService(ctx, remoteKMS) - - waitUntilReady(t, localKEKService) - - got, err := localKEKService.Encrypt(ctx, "test-uid", []byte("test-plaintext")) - if err != nil { - t.Fatalf("Encrypt() error = %v", err) - } - decryptRequest := &service.DecryptRequest{ - Ciphertext: got.Ciphertext, - Annotations: got.Annotations, - KeyID: got.KeyID, - } - // local KEK for decryption not in cache and remote KMS is failing - remoteKMS.SetDisabledStatus(true) - lk := localKEKService.localKEKTracker.Load() - lk.transformer = nil - localKEKService.localKEKTracker.Store(lk) - - // clear the cache - localKEKService.transformers.Clear() - if _, err = localKEKService.Decrypt(ctx, "test-uid", decryptRequest); err == nil { - t.Fatalf("Decrypt() error = %v, want non-nil", err) - } -} - -func TestStatus(t *testing.T) { - ctx := testContext(t) - fakeClock := testingclock.NewFakeClock(time.Now()) - remoteKMS := &testRemoteService{keyID: "test-key-id"} - localKEKService := newLocalKEKService(ctx, remoteKMS, 10, 5, 1*time.Second, 100*time.Millisecond, fakeClock) - - waitUntilReady(t, localKEKService) - - got, err := localKEKService.Status(ctx) - if err != nil { - t.Fatalf("Status() error = %v", err) - } - validateStatusResponse(t, got, "v2alpha1", "ok", "test-key-id") - - fakeClock.Step(1 * time.Second) - // remote KMS is failing - remoteKMS.SetDisabledStatus(true) - // remote KMS keyID changed but local KEK not rotated because of remote KMS failure - // the keyID in status should be the old keyID - // the error should still be nil - remoteKMS.SetKeyID("test-key-id-2") - - if got, err = localKEKService.Status(ctx); err != nil { - t.Fatalf("Status() error = %v, want nil", err) - } - validateStatusResponse(t, got, "v2alpha1", "remote KMS is disabled", "test-key-id") - - fakeClock.Step(1 * time.Second) - // wait for local KEK to expire and local KEK service ready to be false - wait.PollImmediate(100*time.Millisecond, wait.ForeverTestTimeout, func() (bool, error) { - return !localKEKService.isReady.Load(), nil - }) - - // status response should include the localKEK unhealthy status - if got, err = localKEKService.Status(ctx); err != nil { - t.Fatalf("Status() error = %v, want nil", err) - } - validateStatusResponse(t, got, "v2alpha1", "remote KMS is disabled; localKEKService is not ready", "test-key-id") - - // remote KMS is functional again, local KEK is rotated - remoteKMS.SetDisabledStatus(false) - fakeClock.Step(1 * time.Second) - waitUntilReady(t, localKEKService) - if got, err = localKEKService.Status(ctx); err != nil { - t.Fatalf("Status() error = %v, want nil", err) - } - validateStatusResponse(t, got, "v2alpha1", "ok", "test-key-id-2") -} - -func TestRotationKeyUsage(t *testing.T) { - ctx := testContext(t) - - var record sync.Map - - fakeClock := testingclock.NewFakeClock(time.Now()) - remoteKMS := &testRemoteService{keyID: "test-key-id"} - localKEKService := newLocalKEKService(ctx, remoteKMS, 10, 5, 1*time.Minute, 100*time.Millisecond, fakeClock) - waitUntilReady(t, localKEKService) - lk := localKEKService.localKEKTracker.Load() - encLocalKEK := lk.encKEK - - // check only single call for Encrypt to remote KMS - if remoteKMS.EncryptCallCount() != 1 { - t.Fatalf("Encrypt() remoteKMS.EncryptCallCount() = %v, want %v", remoteKMS.EncryptCallCount(), 1) - } - - var wg sync.WaitGroup - for i := 0; i < 6; i++ { - wg.Add(1) - go func() { - defer wg.Done() - resp, err := localKEKService.Encrypt(ctx, "test-uid", []byte(rand.String(32))) - if err != nil { - t.Errorf("Encrypt() error = %v", err) - return - } - if v, ok := resp.Annotations[referenceKEKAnnotationKey]; !ok || !bytes.Equal(v, encLocalKEK) { - t.Errorf("Encrypt() annotations = %v, want %v", resp.Annotations, encLocalKEK) - return - } - record.Store(resp, nil) - }() - } - wg.Wait() - if t.Failed() { - return - } - - fakeClock.Step(30 * time.Second) - rotated := false - // wait for the local KEK to be rotated - wait.PollImmediate(100*time.Millisecond, wait.ForeverTestTimeout, func() (bool, error) { - // local KEK must have been rotated after 5 usages - lk = localKEKService.localKEKTracker.Load() - rotated = !bytes.Equal(lk.encKEK, encLocalKEK) - return rotated, nil - }) - if !rotated { - t.Fatalf("local KEK must have been rotated") - } - if remoteKMS.EncryptCallCount() != 2 { - t.Fatalf("Encrypt() remoteKMS.EncryptCallCount() = %v, want %v", remoteKMS.EncryptCallCount(), 2) - } - - // new local KEK must be used for encryption - for i := 0; i < 5; i++ { - wg.Add(1) - go func() { - defer wg.Done() - resp, err := localKEKService.Encrypt(ctx, "test-uid", []byte(rand.String(32))) - if err != nil { - t.Errorf("Encrypt() error = %v", err) - return - } - if v, ok := resp.Annotations[referenceKEKAnnotationKey]; !ok || !bytes.Equal(v, lk.encKEK) { - t.Errorf("Encrypt() annotations = %v, want %v", resp.Annotations, lk.encKEK) - return - } - record.Store(resp, nil) - }() - } - wg.Wait() - if t.Failed() { - return - } - - // check we can decrypt data encrypted with the old and new local KEKs - record.Range(func(key, _ any) bool { - k := key.(*service.EncryptResponse) - decryptRequest := &service.DecryptRequest{ - Ciphertext: k.Ciphertext, - Annotations: k.Annotations, - KeyID: k.KeyID, - } - if _, err := localKEKService.Decrypt(ctx, "test-uid", decryptRequest); err != nil { - t.Fatalf("Decrypt() error = %v", err) - } - return true - }) - - // Out of the 11 calls to Decrypt: - // - 5 should be using the current local KEK - // - 1 out of the 6 should generate a decrypt call to the remote KMS as the local KEK not in cache - // - 5 out of the 6 should use the cached local KEK after 1st decrypt call to the remote KMS - assertCallCount(t, remoteKMS, localKEKService) -} - -func TestRotationKeyExpiry(t *testing.T) { - ctx := testContext(t) - - var record sync.Map - - fakeClock := testingclock.NewFakeClock(time.Now()) - remoteKMS := &testRemoteService{keyID: "test-key-id"} - localKEKService := newLocalKEKService(ctx, remoteKMS, 10, 5, 5*time.Second, 100*time.Millisecond, fakeClock) - waitUntilReady(t, localKEKService) - lk := localKEKService.localKEKTracker.Load() - encLocalKEK := lk.encKEK - - var wg sync.WaitGroup - for i := 0; i < 3; i++ { - wg.Add(1) - go func() { - defer wg.Done() - resp, err := localKEKService.Encrypt(ctx, "test-uid", []byte("test-plaintext")) - if err != nil { - t.Errorf("Encrypt() error = %v", err) - return - } - if v, ok := resp.Annotations[referenceKEKAnnotationKey]; !ok || !bytes.Equal(v, encLocalKEK) { - t.Errorf("Encrypt() annotations = %v, want %v", resp.Annotations, encLocalKEK) - return - } - record.Store(resp, nil) - }() - } - wg.Wait() - if t.Failed() { - return - } - - // check local KEK has only been used 3 times and still under the suggested usage - if lk.usage.Load() != 3 { - t.Fatalf("local KEK usage = %v, want %v", lk.usage.Load(), 3) - } - - // advance the clock to trigger key expiry - fakeClock.Step(6 * time.Second) - - rotated := false - // wait for the local KEK to be rotated due to key expiry - wait.PollImmediate(100*time.Millisecond, wait.ForeverTestTimeout, func() (bool, error) { - // local KEK must have been rotated after the key max age - t.Log("waiting for local KEK to be rotated") - lk = localKEKService.localKEKTracker.Load() - rotated = !bytes.Equal(lk.encKEK, encLocalKEK) - return rotated, nil - }) - if !rotated { - t.Fatalf("local KEK must have been rotated") - } - - // new local KEK must be used for encryption - for i := 0; i < 5; i++ { - wg.Add(1) - go func() { - defer wg.Done() - resp, err := localKEKService.Encrypt(ctx, "test-uid", []byte("test-plaintext")) - if err != nil { - t.Errorf("Encrypt() error = %v", err) - return - } - if v, ok := resp.Annotations[referenceKEKAnnotationKey]; !ok || !bytes.Equal(v, lk.encKEK) { - t.Errorf("Encrypt() annotations = %v, want %v", resp.Annotations, lk.encKEK) - return - } - record.Store(resp, nil) - }() - } - wg.Wait() - if t.Failed() { - return - } - - // check we can decrypt data encrypted with the old and new local KEKs - record.Range(func(key, _ any) bool { - k := key.(*service.EncryptResponse) - decryptRequest := &service.DecryptRequest{ - Ciphertext: k.Ciphertext, - Annotations: k.Annotations, - KeyID: k.KeyID, - } - if _, err := localKEKService.Decrypt(ctx, "test-uid", decryptRequest); err != nil { - t.Fatalf("Decrypt() error = %v", err) - } - return true - }) - - // Out of the 8 calls to Decrypt: - // - 5 should be using the current local KEK - // - 1 out of the 3 should generate a decrypt call to the remote KMS as the local KEK not in cache - // - 2 out of the 3 should use the cached local KEK after 1st decrypt call to the remote KMS - assertCallCount(t, remoteKMS, localKEKService) -} - -func TestRotationRemoteKeyIDChanged(t *testing.T) { - ctx := testContext(t) - - var record sync.Map - - fakeClock := testingclock.NewFakeClock(time.Now()) - remoteKMS := &testRemoteService{keyID: "test-key-id"} - localKEKService := newLocalKEKService(ctx, remoteKMS, 10, 5, 1*time.Minute, 100*time.Millisecond, fakeClock) - waitUntilReady(t, localKEKService) - lk := localKEKService.localKEKTracker.Load() - encLocalKEK := lk.encKEK - - var wg sync.WaitGroup - for i := 0; i < 3; i++ { - wg.Add(1) - go func() { - defer wg.Done() - resp, err := localKEKService.Encrypt(ctx, "test-uid", []byte("test-plaintext")) - if err != nil { - t.Errorf("Encrypt() error = %v", err) - return - } - if v, ok := resp.Annotations[referenceKEKAnnotationKey]; !ok || !bytes.Equal(v, encLocalKEK) { - t.Errorf("Encrypt() annotations = %v, want %v", resp.Annotations, encLocalKEK) - return - } - record.Store(resp, nil) - }() - } - wg.Wait() - if t.Failed() { - return - } - - // check local KEK has only been used 3 times and still under the suggested usage - if lk.usage.Load() != 3 { - t.Fatalf("local KEK usage = %v, want %v", lk.usage.Load(), 3) - } - - fakeClock.Step(30 * time.Second) - // change the remote key ID - remoteKMS.SetKeyID("test-key-id-2") - if _, err := localKEKService.Status(ctx); err != nil { - t.Fatalf("Status() error = %v", err) - } - - rotated := false - // wait for the local KEK to be rotated due to remote key ID change - wait.PollImmediate(100*time.Millisecond, wait.ForeverTestTimeout, func() (bool, error) { - lk = localKEKService.localKEKTracker.Load() - rotated = !bytes.Equal(lk.encKEK, encLocalKEK) - return rotated, nil - }) - if !rotated { - t.Fatalf("local KEK must have been rotated") - } - - // new local KEK must be used for encryption - for i := 0; i < 5; i++ { - wg.Add(1) - go func() { - defer wg.Done() - resp, err := localKEKService.Encrypt(ctx, "test-uid", []byte("test-plaintext")) - if err != nil { - t.Errorf("Encrypt() error = %v", err) - return - } - if v, ok := resp.Annotations[referenceKEKAnnotationKey]; !ok || !bytes.Equal(v, lk.encKEK) { - t.Errorf("Encrypt() annotations = %v, want %v", resp.Annotations, lk.encKEK) - return - } - record.Store(resp, nil) - }() - } - wg.Wait() - if t.Failed() { - return - } - - // check we can decrypt data encrypted with the old and new local KEKs - record.Range(func(key, _ any) bool { - k := key.(*service.EncryptResponse) - decryptRequest := &service.DecryptRequest{ - Ciphertext: k.Ciphertext, - Annotations: k.Annotations, - KeyID: k.KeyID, - } - if _, err := localKEKService.Decrypt(ctx, "test-uid", decryptRequest); err != nil { - t.Fatalf("Decrypt() error = %v", err) - } - return true - }) - - // Out of the 8 calls to Decrypt: - // - 5 should be using the current local KEK - // - 1 out of the 3 should generate a decrypt call to the remote KMS as the local KEK not in cache - // - 2 out of the 3 should use the cached local KEK after 1st decrypt call to the remote KMS - assertCallCount(t, remoteKMS, localKEKService) -} - -func testContext(t *testing.T) context.Context { - ctx, cancel := context.WithCancel(context.Background()) - t.Cleanup(cancel) - return ctx -} - -func waitUntilReady(t *testing.T, s *LocalKEKService) { - t.Helper() - wait.PollImmediate(100*time.Millisecond, wait.ForeverTestTimeout, func() (bool, error) { - return s.isReady.Load(), nil - }) -} - -func validateEncryptResponse(t *testing.T, got *service.EncryptResponse, wantKeyID string, localKEKService *LocalKEKService) { - t.Helper() - if len(got.Annotations) != 2 { - t.Fatalf("Encrypt() annotations = %v, want 2 annotations", got.Annotations) - } - if _, ok := got.Annotations[referenceKEKAnnotationKey]; !ok { - t.Fatalf("Encrypt() annotations = %v, want %v", got.Annotations, referenceKEKAnnotationKey) - } - if got.KeyID != wantKeyID { - t.Fatalf("Encrypt() keyID = %v, want %v", got.KeyID, wantKeyID) - } - if localKEKService.localKEKTracker.Load() == nil { - t.Fatalf("Encrypt() localKEKTracker = %v, want non-nil localKEK", localKEKService.localKEKTracker.Load()) - } -} - -func validateStatusResponse(t *testing.T, got *service.StatusResponse, wantVersion, wantHealthz, wantKeyID string) { - t.Helper() - if got.Version != wantVersion { - t.Fatalf("Status() version = %v, want %v", got.Version, wantVersion) - } - if !strings.EqualFold(got.Healthz, wantHealthz) { - t.Fatalf("Status() healthz = %v, want %v", got.Healthz, wantHealthz) - } - if got.KeyID != wantKeyID { - t.Fatalf("Status() keyID = %v, want %v", got.KeyID, wantKeyID) - } -} - -func assertCallCount(t *testing.T, remoteKMS *testRemoteService, localKEKService *LocalKEKService) { - t.Helper() - if remoteKMS.DecryptCallCount() != 1 { - t.Fatalf("Decrypt() remoteKMS.DecryptCallCount() = %v, want %v", remoteKMS.DecryptCallCount(), 1) - } - if localKEKService.transformers.Len() != 1 { - t.Fatalf("Decrypt() localKEKService.transformers.Len() = %v, want %v", localKEKService.transformers.Len(), 1) - } -}