Merge pull request #116630 from aramase/aramase/c/rm_key_hierarchy

[KMSv2] remove key hierarchy in reference implementation
This commit is contained in:
Kubernetes Prow Robot 2023-03-14 22:02:14 -07:00 committed by GitHub
commit 9bb778d48e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 12 additions and 1303 deletions

View File

@ -10,19 +10,18 @@ require (
k8s.io/apimachinery v0.0.0
k8s.io/client-go v0.0.0
k8s.io/klog/v2 v2.90.1
k8s.io/utils v0.0.0-20230209194617-a36077c30491
)
require (
github.com/go-logr/logr v1.2.3 // indirect
github.com/golang/protobuf v1.5.3 // indirect
github.com/google/uuid v1.3.0 // indirect
golang.org/x/net v0.8.0 // indirect
golang.org/x/sys v0.6.0 // indirect
golang.org/x/text v0.8.0 // indirect
golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 // indirect
google.golang.org/genproto v0.0.0-20220502173005-c8bf987b8c21 // indirect
google.golang.org/protobuf v1.28.1 // indirect
k8s.io/utils v0.0.0-20230209194617-a36077c30491 // indirect
)
replace (

View File

@ -50,8 +50,6 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw=
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=

View File

@ -11,7 +11,6 @@ require (
github.com/go-logr/logr v1.2.3 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang/protobuf v1.5.3 // indirect
github.com/google/uuid v1.3.0 // indirect
golang.org/x/net v0.8.0 // indirect
golang.org/x/sys v0.6.0 // indirect
golang.org/x/text v0.8.0 // indirect
@ -19,7 +18,6 @@ require (
google.golang.org/genproto v0.0.0-20220502173005-c8bf987b8c21 // indirect
google.golang.org/grpc v1.51.0 // indirect
google.golang.org/protobuf v1.28.1 // indirect
k8s.io/apimachinery v0.0.0 // indirect
k8s.io/client-go v0.0.0 // indirect
k8s.io/utils v0.0.0-20230209194617-a36077c30491 // indirect
)

View File

@ -50,8 +50,6 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw=
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=

View File

@ -26,7 +26,6 @@ import (
"k8s.io/klog/v2"
"k8s.io/kms/internal"
"k8s.io/kms/pkg/hierarchy"
"k8s.io/kms/pkg/service"
"k8s.io/kms/pkg/util"
)
@ -55,14 +54,20 @@ func main() {
grpcService := service.NewGRPCService(
addr,
*timeout,
hierarchy.NewLocalKEKService(ctx, remoteKMSService),
remoteKMSService,
)
klog.InfoS("starting server", "listenAddr", *listenAddr)
if err := grpcService.ListenAndServe(); err != nil {
klog.ErrorS(err, "failed to serve")
os.Exit(1)
}
go func() {
if err := grpcService.ListenAndServe(); err != nil {
klog.ErrorS(err, "failed to serve")
os.Exit(1)
}
}()
<-ctx.Done()
klog.InfoS("shutting down server")
grpcService.Shutdown()
}
// withShutdownSignal returns a copy of the parent context that will close if

View File

@ -1,456 +0,0 @@
/*
Copyright 2023 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package hierarchy
import (
"bytes"
"context"
"crypto/aes"
"crypto/rand"
"encoding/base64"
"fmt"
"strings"
"sync"
"sync/atomic"
"time"
"k8s.io/apimachinery/pkg/util/uuid"
"k8s.io/apimachinery/pkg/util/wait"
"k8s.io/klog/v2"
aestransformer "k8s.io/kms/pkg/encrypt/aes"
"k8s.io/kms/pkg/service"
"k8s.io/kms/pkg/value"
"k8s.io/utils/clock"
"k8s.io/utils/lru"
)
// localKEK is a struct that holds the local KEK and the remote KMS response.
type localKEK struct {
encKEK []byte
usage atomic.Uint64
expiry time.Time
transformer value.Transformer
remoteKMSResponse *service.EncryptResponse
generatedAt time.Time
}
var (
// emptyContext is an empty slice of bytes. This is passed as value.Context to the
// GCM transformer. The grpc interface does not provide any additional authenticated data
// to use with AEAD.
emptyContext = value.DefaultContext([]byte{})
// errInvalidKMSAnnotationKeySuffix is returned when the annotation key suffix is not allowed.
errInvalidKMSAnnotationKeySuffix = fmt.Errorf("annotation keys are not allowed to use %s", referenceSuffix)
)
const (
referenceSuffix = ".reference.encryption.k8s.io"
// referenceKEKAnnotationKey is the key used to store the localKEK in the annotations.
referenceKEKAnnotationKey = "encrypted-kek" + referenceSuffix
numAnnotations = 1
cacheSize = 1_000
// localKEKGenerationPollInterval is the interval at which the local KEK is checked for rotation.
localKEKGenerationPollInterval = 1 * time.Minute
// keyLength is the length of the local KEK in bytes.
// This is the same length used for the DEKs generated in kube-apiserver.
keyLength = 32
// keyMaxUsage is the maximum number of times an AES GCM key can be used
// with a random nonce: 2^32. The local KEK is a transformer that hold an
// AES GCM key. It is based on recommendations from
// https://nvlpubs.nist.gov/nistpubs/Legacy/SP/nistspecialpublication800-38d.pdf.
// It is reduced by one to be comparable with a atomic.Uint32.
// We picked a arbitrary number that is less than the max usage of the local KEK.
keyMaxUsage = 1<<22 - 1
// keySuggestedUsage is a threshold that triggers the rotation of a new local KEK. It means that half
// the number of times a local KEK can be used has been reached.
keySuggestedUsage = 1 << 21
// keyMaxAge is the maximum age of a local KEK. It is not a cryptographic necessity.
keyMaxAge = 7 * 24 * time.Hour
)
var _ service.Service = &LocalKEKService{}
// LocalKEKService adds an additional KEK layer to reduce calls to the remote KMS.
// The local KEK is generated at startup in a controller loop and stored in the
// LocalKEKService. This KEK is used for all encryption operations. For the decrypt
// operation, if the encrypted local KEK is not found in the cache, the remote KMS
// is used to decrypt the local KEK.
type LocalKEKService struct {
mu sync.Mutex
// remoteKMS is the remote kms that is used to encrypt and decrypt the local KEKs.
remoteKMS service.Service
// localKEKTracker is a atomic pointer to avoid locks. This is used to store the local KEK.
localKEKTracker atomic.Pointer[localKEK]
// transformers is a thread-safe LRU cache which caches decrypted DEKs indexed by their encrypted form.
// The cache is only used for the decrypt operation.
transformers *lru.Cache
// isReady is an atomic boolean that indicates if the localKEK service is ready for encryption.
isReady atomic.Bool
keyMaxUsage uint64
keySuggestedUsage uint64
keyMaxAge time.Duration
pollInterval time.Duration
clock clock.Clock
}
// NewLocalKEKService is being initialized with a remote KMS service.
// The local KEK is generated in a controller loop. The local KEK is used for all
// encryption operations.
func NewLocalKEKService(ctx context.Context, remoteService service.Service) *LocalKEKService {
return newLocalKEKService(ctx, remoteService, keyMaxUsage, keySuggestedUsage, keyMaxAge, localKEKGenerationPollInterval, clock.RealClock{})
}
func newLocalKEKService(ctx context.Context, remoteService service.Service, maxUsage, suggestedUsage uint64, maxAge, pollInterval time.Duration, clock clock.Clock) *LocalKEKService {
localKEKService := &LocalKEKService{
remoteKMS: remoteService,
transformers: lru.New(cacheSize),
keyMaxUsage: maxUsage,
keySuggestedUsage: suggestedUsage,
keyMaxAge: maxAge,
pollInterval: pollInterval,
clock: clock,
}
go localKEKService.run(ctx)
return localKEKService
}
// Run method creates a new local KEK when the following thresholds are met:
// - the local KEK is used more often than keySuggestedUsage times or
// - the local KEK is older than a localExpiry.
//
// this method starts the controller and blocks until the context is cancelled.
func (m *LocalKEKService) run(ctx context.Context) {
// same as wait.UntilWithContext but with a custom clock
wait.BackoffUntil(func() {
lk := m.getLocalKEK()
// this is the first time the local KEK is generated
localKEKNotGenerated := lk.transformer == nil
// the local KEK is used more often than keySuggestedUsage times
localKEKUsageThresholdReached := lk.usage.Load() > m.keySuggestedUsage
// the local KEK is older than the expiry
localKEKExpired := m.clock.Now().After(lk.expiry)
if localKEKNotGenerated || localKEKUsageThresholdReached || localKEKExpired {
uid := string(uuid.NewUUID())
err := m.generateLocalKEK(ctx, uid, "")
if err == nil {
m.isReady.Store(true)
return
}
klog.V(2).ErrorS(err, "failed to generate local KEK", "uid", uid)
// if the local KEK is expired and we cannot generate a new one, we set
// isReady to false because we can no longer encrypt new data.
if localKEKExpired {
m.isReady.Store(false)
}
}
}, wait.NewJitteredBackoffManager(m.pollInterval, 0, m.clock), true, ctx.Done())
}
// getTransformerForEncryption returns the local KEK as localTransformer, the corresponding
// remoteKMSResponse and a potential error.
// On every use the localUsage is incremented by one.
// It is assumed that only one encryption will happen with the returned transformer.
func (m *LocalKEKService) getTransformerForEncryption(uid string) (value.Transformer, *service.EncryptResponse, error) {
lk := m.getLocalKEK()
// localKEK is not initialized yet
if lk.transformer == nil {
return nil, nil, fmt.Errorf("local KEK is not initialized")
}
if m.clock.Now().After(lk.expiry) {
return nil, nil, fmt.Errorf("local KEK has expired at %v", lk.expiry)
}
if counter := lk.usage.Add(1); counter >= m.keyMaxUsage {
return nil, nil, fmt.Errorf("local KEK has reached maximum usage of %d", keyMaxUsage)
}
return lk.transformer, lk.remoteKMSResponse, nil
}
// Encrypt encrypts the plaintext with the localKEK.
func (m *LocalKEKService) Encrypt(ctx context.Context, uid string, pt []byte) (*service.EncryptResponse, error) {
transformer, resp, err := m.getTransformerForEncryption(uid)
if err != nil {
klog.V(2).ErrorS(err, "failed to get transformer for encryption", "uid", uid)
return nil, err
}
ct, err := transformer.TransformToStorage(ctx, pt, emptyContext)
if err != nil {
klog.V(2).ErrorS(err, "failed to encrypt data", "uid", uid)
return nil, err
}
return &service.EncryptResponse{
Ciphertext: ct,
KeyID: resp.KeyID,
Annotations: resp.Annotations,
}, nil
}
// getTransformerForDecryption returns the transformer for the given encryptedKEK.
// - If the encryptedKEK is the current localKEK, the localKEK is returned.
// - If the encryptedKEK is not the current localKEK, the cache is checked.
// - If the encryptedKEK is not found in the cache, the remote KMS is used to decrypt the encryptedKEK.
func (m *LocalKEKService) getTransformerForDecryption(ctx context.Context, uid string, req *service.DecryptRequest) (value.Transformer, error) {
encKEK := req.Annotations[referenceKEKAnnotationKey]
// check if the key required for decryption is the current local KEK
// that's being used for encryption
lk := m.getLocalKEK()
if lk.transformer != nil && bytes.Equal(lk.encKEK, encKEK) {
return lk.transformer, nil
}
// check if the key required for decryption is already in the cache
if _transformer, found := m.transformers.Get(base64.StdEncoding.EncodeToString(encKEK)); found {
return _transformer.(value.Transformer), nil
}
key, err := m.remoteKMS.Decrypt(ctx, uid, &service.DecryptRequest{
Ciphertext: encKEK,
KeyID: req.KeyID,
Annotations: annotationsWithoutReferenceKeys(req.Annotations),
})
if err != nil {
return nil, err
}
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
transformer := aestransformer.NewGCMTransformer(block)
// Overwrite the plain key with 0s.
copy(key, make([]byte, len(key)))
m.transformers.Add(base64.StdEncoding.EncodeToString(encKEK), transformer)
return transformer, nil
}
// Decrypt attempts to decrypt the ciphertext with the localKEK, a KEK from the
// store, or the remote KMS.
func (m *LocalKEKService) Decrypt(ctx context.Context, uid string, req *service.DecryptRequest) ([]byte, error) {
if _, ok := req.Annotations[referenceKEKAnnotationKey]; !ok {
return nil, fmt.Errorf("unable to find local KEK for request with uid %q", uid)
}
transformer, err := m.getTransformerForDecryption(ctx, uid, req)
if err != nil {
klog.V(2).ErrorS(err, "failed to get transformer for decryption", "uid", uid)
return nil, fmt.Errorf("failed to get transformer for decryption: %w", err)
}
pt, _, err := transformer.TransformFromStorage(ctx, req.Ciphertext, emptyContext)
if err != nil {
klog.V(2).ErrorS(err, "failed to decrypt data", "uid", uid)
return nil, err
}
return pt, nil
}
// Status returns the status of the remote KMS.
func (m *LocalKEKService) Status(ctx context.Context) (*service.StatusResponse, error) {
resp, err := m.remoteKMS.Status(ctx)
if err != nil {
return nil, err
}
if err := validateRemoteKMSStatusResponse(resp); err != nil {
return nil, err
}
r := copyStatusResponse(resp)
// if the remote KMS KeyID has changed, we need to rotate the local KEK
lk := m.getLocalKEK()
if lk.transformer != nil && r.KeyID != lk.remoteKMSResponse.KeyID {
if err := m.rotateLocalKEK(ctx, r.KeyID); err != nil {
klog.ErrorS(err, "failed to rotate local KEK", "expectedKeyID", r.KeyID, "currentKeyID", lk.remoteKMSResponse.KeyID)
// if rotation fails, we will overwrite the keyID to the one we are currently using
// for encryption as localKEKService is the authoritative source for the keyID.
r.KeyID = lk.remoteKMSResponse.KeyID
// TODO(aramase): we are currently not returning the error if rotation fails. We should
// allow the failed state for an arbitrary time period and return the error if the state
// is not eventually fixed.
}
}
var aggregateHealthz []string
if r.Healthz != "ok" {
aggregateHealthz = append(aggregateHealthz, r.Healthz)
}
if !m.isReady.Load() {
// if the localKEKService is not ready, we will set the healthz status to not ready
klog.V(2).InfoS("localKEKService is not ready", "keyID", r.KeyID)
aggregateHealthz = append(aggregateHealthz, "localKEKService is not ready")
}
if len(aggregateHealthz) > 0 {
r.Healthz = strings.Join(aggregateHealthz, "; ")
}
return r, nil
}
// rotateLocalKEK rotates the local KEK by generating a new local KEK and encrypting it with the
// remote KMS.
func (m *LocalKEKService) rotateLocalKEK(ctx context.Context, expectedKeyID string) error {
uid := string(uuid.NewUUID())
if err := m.generateLocalKEK(ctx, uid, expectedKeyID); err != nil {
klog.V(2).ErrorS(err, "failed to generate local KEK as part of rotation", "uid", uid)
return fmt.Errorf("[uid=%s] failed to generate local KEK as part of rotation: %w", uid, err)
}
return nil
}
// generateLocalKEK generates a new local KEK and encrypts it with the remote KMS.
// if expectedKeyID is not empty, it will check if the keyID returned from the remote KMS matches
// the expected keyID. If the keyID does not match, it will continue using the existing local KEK
// and return an error.
func (m *LocalKEKService) generateLocalKEK(ctx context.Context, uid, expectedKeyID string) error {
m.mu.Lock()
defer m.mu.Unlock()
lk := m.getLocalKEK()
// if the localKEK was generated in the last pollInterval duration, we will not generate a new
// localKEK. This is to avoid regenerating a new localKEK for queued requests.
if lk.transformer != nil && m.clock.Since(lk.generatedAt) < m.pollInterval {
return nil
}
key, err := generateKey(keyLength)
if err != nil {
return fmt.Errorf("failed to generate local KEK: %w", err)
}
block, err := aes.NewCipher(key)
if err != nil {
return fmt.Errorf("failed to create cipher block: %w", err)
}
resp, err := m.remoteKMS.Encrypt(ctx, uid, key)
if err != nil {
return fmt.Errorf("failed to encrypt local KEK: %w", err)
}
if err = validateRemoteKMSEncryptResponse(resp); err != nil {
return fmt.Errorf("invalid response from remote KMS: %w", err)
}
if expectedKeyID != "" && resp.KeyID != expectedKeyID {
return fmt.Errorf("keyID returned from remote KMS does not match expected keyID")
}
now := m.clock.Now()
m.localKEKTracker.Store(&localKEK{
encKEK: resp.Ciphertext,
expiry: now.Add(m.keyMaxAge),
usage: atomic.Uint64{},
transformer: aestransformer.NewGCMTransformer(block),
remoteKMSResponse: copyResponseAndAddLocalKEKAnnotation(resp),
generatedAt: now,
})
return nil
}
func (m *LocalKEKService) getLocalKEK() *localKEK {
lk := m.localKEKTracker.Load()
if lk == nil {
return &localKEK{}
}
return lk
}
// copyResponseAndAddLocalKEKAnnotation returns a copy of the remoteKMSResponse with the
// referenceKEKAnnotationKey added to the annotations.
func copyResponseAndAddLocalKEKAnnotation(resp *service.EncryptResponse) *service.EncryptResponse {
annotations := make(map[string][]byte, len(resp.Annotations)+numAnnotations)
for s, bytes := range resp.Annotations {
s := s
bytes := bytes
annotations[s] = bytes
}
annotations[referenceKEKAnnotationKey] = resp.Ciphertext
return &service.EncryptResponse{
// Ciphertext is not set on purpose - it is different per Encrypt call
KeyID: resp.KeyID,
Annotations: annotations,
}
}
// copyStatusResponse returns a copy of the remote KMS status response.
func copyStatusResponse(resp *service.StatusResponse) *service.StatusResponse {
return &service.StatusResponse{
Healthz: resp.Healthz,
Version: resp.Version,
KeyID: resp.KeyID,
}
}
// annotationsWithoutReferenceKeys returns a copy of the annotations without the reference implementation
// annotations.
func annotationsWithoutReferenceKeys(annotations map[string][]byte) map[string][]byte {
if len(annotations) <= numAnnotations {
return nil
}
m := make(map[string][]byte, len(annotations)-numAnnotations)
for k, v := range annotations {
k, v := k, v
if strings.HasSuffix(k, referenceSuffix) {
continue
}
m[k] = v
}
return m
}
// validateRemoteKMSEncryptResponse validates the EncryptResponse from the remote KMS.
func validateRemoteKMSEncryptResponse(resp *service.EncryptResponse) error {
// validate annotations don't contain the reference implementation annotations
for k := range resp.Annotations {
if strings.HasSuffix(k, referenceSuffix) {
return errInvalidKMSAnnotationKeySuffix
}
}
return nil
}
// validateRemoteKMSStatusResponse validates the StatusResponse from the remote KMS.
func validateRemoteKMSStatusResponse(resp *service.StatusResponse) error {
if len(resp.KeyID) == 0 {
return fmt.Errorf("keyID is empty")
}
return nil
}
// generateKey generates a random key using system randomness.
func generateKey(length int) (key []byte, err error) {
key = make([]byte, length)
if _, err = rand.Read(key); err != nil {
return nil, err
}
return key, nil
}

View File

@ -1,833 +0,0 @@
/*
Copyright 2023 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package hierarchy
import (
"bytes"
"context"
"encoding/base64"
"errors"
"reflect"
"strings"
"sync"
"testing"
"time"
"k8s.io/apimachinery/pkg/util/rand"
"k8s.io/apimachinery/pkg/util/wait"
"k8s.io/kms/pkg/service"
testingclock "k8s.io/utils/clock/testing"
)
const (
testAnnotationKey = "version.encryption.remote.io"
testAnnotationKeyVersion = "key-version.encryption.remote.io"
)
func TestCopyResponseAndAddLocalKEKAnnotation(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
input *service.EncryptResponse
want *service.EncryptResponse
}{
{
name: "annotations is nil",
input: &service.EncryptResponse{
Ciphertext: []byte("encryptedLocalKEK"),
KeyID: "keyID",
Annotations: nil,
},
want: &service.EncryptResponse{
KeyID: "keyID",
Annotations: map[string][]byte{
referenceKEKAnnotationKey: []byte("encryptedLocalKEK"),
},
},
},
{
name: "remote KMS sent 1 annotation",
input: &service.EncryptResponse{
Ciphertext: []byte("encryptedLocalKEK"),
KeyID: "keyID",
Annotations: map[string][]byte{
testAnnotationKey: []byte("1"),
},
},
want: &service.EncryptResponse{
KeyID: "keyID",
Annotations: map[string][]byte{
testAnnotationKey: []byte("1"),
referenceKEKAnnotationKey: []byte("encryptedLocalKEK"),
},
},
},
{
name: "remote KMS sent 2 annotations",
input: &service.EncryptResponse{
Ciphertext: []byte("encryptedLocalKEK"),
KeyID: "keyID",
Annotations: map[string][]byte{
testAnnotationKey: []byte("1"),
testAnnotationKeyVersion: []byte("2"),
},
},
want: &service.EncryptResponse{
KeyID: "keyID",
Annotations: map[string][]byte{
testAnnotationKey: []byte("1"),
testAnnotationKeyVersion: []byte("2"),
referenceKEKAnnotationKey: []byte("encryptedLocalKEK"),
},
},
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
got := copyResponseAndAddLocalKEKAnnotation(tc.input)
if !reflect.DeepEqual(got, tc.want) {
t.Errorf("copyResponseAndAddLocalKEKAnnotation(%v) = %v, want %v", tc.input, got, tc.want)
}
})
}
}
func TestAnnotationsWithoutReferenceKeys(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
input map[string][]byte
want map[string][]byte
}{
{
name: "annotations is nil",
input: nil,
want: nil,
},
{
name: "annotations is empty",
input: map[string][]byte{},
want: nil,
},
{
name: "annotations only contains reference keys",
input: map[string][]byte{
referenceKEKAnnotationKey: []byte("encryptedLocalKEK"),
},
want: nil,
},
{
name: "annotations contains 1 reference key and 1 other key",
input: map[string][]byte{
referenceKEKAnnotationKey: []byte("encryptedLocalKEK"),
testAnnotationKey: []byte("1"),
},
want: map[string][]byte{
testAnnotationKey: []byte("1"),
},
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
got := annotationsWithoutReferenceKeys(tc.input)
if !reflect.DeepEqual(got, tc.want) {
t.Errorf("annotationsWithoutReferenceKeys(%v) = %v, want %v", tc.input, got, tc.want)
}
})
}
}
func TestValidateRemoteKMSEncryptResponse(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
input *service.EncryptResponse
want error
}{
{
name: "annotations is nil",
input: &service.EncryptResponse{},
want: nil,
},
{
name: "annotation key contains reference suffix",
input: &service.EncryptResponse{
Annotations: map[string][]byte{
"version.reference.encryption.k8s.io": []byte("1"),
},
},
want: errInvalidKMSAnnotationKeySuffix,
},
{
name: "no annotation key contains reference suffix",
input: &service.EncryptResponse{
Annotations: map[string][]byte{
testAnnotationKey: []byte("1"),
testAnnotationKeyVersion: []byte("2"),
},
},
want: nil,
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
got := validateRemoteKMSEncryptResponse(tc.input)
if got != tc.want {
t.Errorf("validateRemoteKMSResponse(%v) = %v, want %v", tc.input, got, tc.want)
}
})
}
}
func TestValidateRemoteKMSStatusResponse(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
input *service.StatusResponse
wantErr string
}{
{
name: "keyID is empty",
input: &service.StatusResponse{
KeyID: "",
},
wantErr: "keyID is empty",
},
{
name: "no error",
input: &service.StatusResponse{
KeyID: "keyID",
},
wantErr: "",
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
got := validateRemoteKMSStatusResponse(tc.input)
if tc.wantErr != "" {
if got == nil {
t.Errorf("validateRemoteKMSStatusResponse(%v) = %v, want %v", tc.input, got, tc.wantErr)
}
if !strings.Contains(got.Error(), tc.wantErr) {
t.Errorf("validateRemoteKMSStatusResponse(%v) = %v, want %v", tc.input, got, tc.wantErr)
}
} else {
if got != nil {
t.Errorf("validateRemoteKMSStatusResponse(%v) = %v, want %v", tc.input, got, tc.wantErr)
}
}
})
}
}
var _ service.Service = &testRemoteService{}
type testRemoteService struct {
mu sync.Mutex
keyID string
disabled bool
encryptCallCount int
decryptCallCount int
}
func (s *testRemoteService) Encrypt(ctx context.Context, uid string, plaintext []byte) (*service.EncryptResponse, error) {
s.mu.Lock()
defer s.mu.Unlock()
s.encryptCallCount++
if s.disabled {
return nil, errors.New("failed to encrypt")
}
return &service.EncryptResponse{
KeyID: s.keyID,
Ciphertext: []byte(base64.StdEncoding.EncodeToString(plaintext)),
Annotations: map[string][]byte{
testAnnotationKey: []byte("1"),
},
}, nil
}
func (s *testRemoteService) Decrypt(ctx context.Context, uid string, req *service.DecryptRequest) ([]byte, error) {
s.mu.Lock()
defer s.mu.Unlock()
s.decryptCallCount++
if s.disabled {
return nil, errors.New("failed to decrypt")
}
if len(req.Annotations) != 1 {
return nil, errors.New("invalid annotations")
}
if v, ok := req.Annotations[testAnnotationKey]; !ok || string(v) != "1" {
return nil, errors.New("invalid version in annotations")
}
return base64.StdEncoding.DecodeString(string(req.Ciphertext))
}
func (s *testRemoteService) Status(ctx context.Context) (*service.StatusResponse, error) {
s.mu.Lock()
defer s.mu.Unlock()
resp := &service.StatusResponse{
Version: "v2alpha1",
Healthz: "ok",
KeyID: s.keyID,
}
if s.disabled {
resp.Healthz = "remote KMS is disabled"
}
return resp, nil
}
func (s *testRemoteService) SetDisabledStatus(disabled bool) {
s.mu.Lock()
defer s.mu.Unlock()
s.disabled = disabled
}
func (s *testRemoteService) SetKeyID(keyID string) {
s.mu.Lock()
defer s.mu.Unlock()
s.keyID = keyID
}
func (s *testRemoteService) EncryptCallCount() int {
s.mu.Lock()
defer s.mu.Unlock()
return s.encryptCallCount
}
func (s *testRemoteService) DecryptCallCount() int {
s.mu.Lock()
defer s.mu.Unlock()
return s.decryptCallCount
}
func TestEncrypt(t *testing.T) {
ctx := testContext(t)
remoteKMS := &testRemoteService{keyID: "test-key-id"}
localKEKService := NewLocalKEKService(ctx, remoteKMS)
waitUntilReady(t, localKEKService)
// local KEK is generated and encryption is successful
got, err := localKEKService.Encrypt(ctx, "test-uid", []byte("test-plaintext"))
if err != nil {
t.Fatalf("Encrypt() error = %v", err)
}
validateEncryptResponse(t, got, remoteKMS.keyID, localKEKService)
// local KEK is used for encryption even when remote KMS is failing
remoteKMS.SetDisabledStatus(true)
if got, err = localKEKService.Encrypt(ctx, "test-uid", []byte("test-plaintext")); err != nil {
t.Fatalf("Encrypt() error = %v", err)
}
validateEncryptResponse(t, got, remoteKMS.keyID, localKEKService)
}
func TestEncryptError(t *testing.T) {
ctx := testContext(t)
remoteKMS := &testRemoteService{keyID: "test-key-id"}
localKEKService := NewLocalKEKService(ctx, remoteKMS)
// first time local KEK generation fails because of remote KMS
remoteKMS.SetDisabledStatus(true)
_, err := localKEKService.Encrypt(ctx, "test-uid", []byte("test-plaintext"))
if err == nil {
t.Fatalf("Encrypt() error = %v, want non-nil", err)
}
lk := localKEKService.getLocalKEK()
if lk.transformer != nil {
t.Fatalf("Encrypt() localKEKTracker = %v, want non-nil localKEK", lk)
}
remoteKMS.SetDisabledStatus(false)
}
func TestDecrypt(t *testing.T) {
ctx := testContext(t)
remoteKMS := &testRemoteService{keyID: "test-key-id"}
localKEKService := NewLocalKEKService(ctx, remoteKMS)
waitUntilReady(t, localKEKService)
// local KEK is generated and encryption/decryption is successful
got, err := localKEKService.Encrypt(ctx, "test-uid", []byte("test-plaintext"))
if err != nil {
t.Fatalf("Encrypt() error = %v", err)
}
if string(got.Ciphertext) == "test-plaintext" {
t.Fatalf("Encrypt() ciphertext = %v, want it to be encrypted", got.Ciphertext)
}
decryptRequest := &service.DecryptRequest{
Ciphertext: got.Ciphertext,
Annotations: got.Annotations,
KeyID: got.KeyID,
}
plaintext, err := localKEKService.Decrypt(ctx, "test-uid", decryptRequest)
if err != nil {
t.Fatalf("Decrypt() error = %v", err)
}
if string(plaintext) != "test-plaintext" {
t.Fatalf("Decrypt() plaintext = %v, want %v", string(plaintext), "test-plaintext")
}
// local KEK is used for decryption even when remote KMS is failing
remoteKMS.SetDisabledStatus(true)
if _, err = localKEKService.Decrypt(ctx, "test-uid", decryptRequest); err != nil {
t.Fatalf("Decrypt() error = %v", err)
}
}
func TestDecryptError(t *testing.T) {
ctx := testContext(t)
remoteKMS := &testRemoteService{keyID: "test-key-id"}
localKEKService := NewLocalKEKService(ctx, remoteKMS)
waitUntilReady(t, localKEKService)
got, err := localKEKService.Encrypt(ctx, "test-uid", []byte("test-plaintext"))
if err != nil {
t.Fatalf("Encrypt() error = %v", err)
}
decryptRequest := &service.DecryptRequest{
Ciphertext: got.Ciphertext,
Annotations: got.Annotations,
KeyID: got.KeyID,
}
// local KEK for decryption not in cache and remote KMS is failing
remoteKMS.SetDisabledStatus(true)
lk := localKEKService.localKEKTracker.Load()
lk.transformer = nil
localKEKService.localKEKTracker.Store(lk)
// clear the cache
localKEKService.transformers.Clear()
if _, err = localKEKService.Decrypt(ctx, "test-uid", decryptRequest); err == nil {
t.Fatalf("Decrypt() error = %v, want non-nil", err)
}
}
func TestStatus(t *testing.T) {
ctx := testContext(t)
fakeClock := testingclock.NewFakeClock(time.Now())
remoteKMS := &testRemoteService{keyID: "test-key-id"}
localKEKService := newLocalKEKService(ctx, remoteKMS, 10, 5, 1*time.Second, 100*time.Millisecond, fakeClock)
waitUntilReady(t, localKEKService)
got, err := localKEKService.Status(ctx)
if err != nil {
t.Fatalf("Status() error = %v", err)
}
validateStatusResponse(t, got, "v2alpha1", "ok", "test-key-id")
fakeClock.Step(1 * time.Second)
// remote KMS is failing
remoteKMS.SetDisabledStatus(true)
// remote KMS keyID changed but local KEK not rotated because of remote KMS failure
// the keyID in status should be the old keyID
// the error should still be nil
remoteKMS.SetKeyID("test-key-id-2")
if got, err = localKEKService.Status(ctx); err != nil {
t.Fatalf("Status() error = %v, want nil", err)
}
validateStatusResponse(t, got, "v2alpha1", "remote KMS is disabled", "test-key-id")
fakeClock.Step(1 * time.Second)
// wait for local KEK to expire and local KEK service ready to be false
wait.PollImmediate(100*time.Millisecond, wait.ForeverTestTimeout, func() (bool, error) {
return !localKEKService.isReady.Load(), nil
})
// status response should include the localKEK unhealthy status
if got, err = localKEKService.Status(ctx); err != nil {
t.Fatalf("Status() error = %v, want nil", err)
}
validateStatusResponse(t, got, "v2alpha1", "remote KMS is disabled; localKEKService is not ready", "test-key-id")
// remote KMS is functional again, local KEK is rotated
remoteKMS.SetDisabledStatus(false)
fakeClock.Step(1 * time.Second)
waitUntilReady(t, localKEKService)
if got, err = localKEKService.Status(ctx); err != nil {
t.Fatalf("Status() error = %v, want nil", err)
}
validateStatusResponse(t, got, "v2alpha1", "ok", "test-key-id-2")
}
func TestRotationKeyUsage(t *testing.T) {
ctx := testContext(t)
var record sync.Map
fakeClock := testingclock.NewFakeClock(time.Now())
remoteKMS := &testRemoteService{keyID: "test-key-id"}
localKEKService := newLocalKEKService(ctx, remoteKMS, 10, 5, 1*time.Minute, 100*time.Millisecond, fakeClock)
waitUntilReady(t, localKEKService)
lk := localKEKService.localKEKTracker.Load()
encLocalKEK := lk.encKEK
// check only single call for Encrypt to remote KMS
if remoteKMS.EncryptCallCount() != 1 {
t.Fatalf("Encrypt() remoteKMS.EncryptCallCount() = %v, want %v", remoteKMS.EncryptCallCount(), 1)
}
var wg sync.WaitGroup
for i := 0; i < 6; i++ {
wg.Add(1)
go func() {
defer wg.Done()
resp, err := localKEKService.Encrypt(ctx, "test-uid", []byte(rand.String(32)))
if err != nil {
t.Errorf("Encrypt() error = %v", err)
return
}
if v, ok := resp.Annotations[referenceKEKAnnotationKey]; !ok || !bytes.Equal(v, encLocalKEK) {
t.Errorf("Encrypt() annotations = %v, want %v", resp.Annotations, encLocalKEK)
return
}
record.Store(resp, nil)
}()
}
wg.Wait()
if t.Failed() {
return
}
fakeClock.Step(30 * time.Second)
rotated := false
// wait for the local KEK to be rotated
wait.PollImmediate(100*time.Millisecond, wait.ForeverTestTimeout, func() (bool, error) {
// local KEK must have been rotated after 5 usages
lk = localKEKService.localKEKTracker.Load()
rotated = !bytes.Equal(lk.encKEK, encLocalKEK)
return rotated, nil
})
if !rotated {
t.Fatalf("local KEK must have been rotated")
}
if remoteKMS.EncryptCallCount() != 2 {
t.Fatalf("Encrypt() remoteKMS.EncryptCallCount() = %v, want %v", remoteKMS.EncryptCallCount(), 2)
}
// new local KEK must be used for encryption
for i := 0; i < 5; i++ {
wg.Add(1)
go func() {
defer wg.Done()
resp, err := localKEKService.Encrypt(ctx, "test-uid", []byte(rand.String(32)))
if err != nil {
t.Errorf("Encrypt() error = %v", err)
return
}
if v, ok := resp.Annotations[referenceKEKAnnotationKey]; !ok || !bytes.Equal(v, lk.encKEK) {
t.Errorf("Encrypt() annotations = %v, want %v", resp.Annotations, lk.encKEK)
return
}
record.Store(resp, nil)
}()
}
wg.Wait()
if t.Failed() {
return
}
// check we can decrypt data encrypted with the old and new local KEKs
record.Range(func(key, _ any) bool {
k := key.(*service.EncryptResponse)
decryptRequest := &service.DecryptRequest{
Ciphertext: k.Ciphertext,
Annotations: k.Annotations,
KeyID: k.KeyID,
}
if _, err := localKEKService.Decrypt(ctx, "test-uid", decryptRequest); err != nil {
t.Fatalf("Decrypt() error = %v", err)
}
return true
})
// Out of the 11 calls to Decrypt:
// - 5 should be using the current local KEK
// - 1 out of the 6 should generate a decrypt call to the remote KMS as the local KEK not in cache
// - 5 out of the 6 should use the cached local KEK after 1st decrypt call to the remote KMS
assertCallCount(t, remoteKMS, localKEKService)
}
func TestRotationKeyExpiry(t *testing.T) {
ctx := testContext(t)
var record sync.Map
fakeClock := testingclock.NewFakeClock(time.Now())
remoteKMS := &testRemoteService{keyID: "test-key-id"}
localKEKService := newLocalKEKService(ctx, remoteKMS, 10, 5, 5*time.Second, 100*time.Millisecond, fakeClock)
waitUntilReady(t, localKEKService)
lk := localKEKService.localKEKTracker.Load()
encLocalKEK := lk.encKEK
var wg sync.WaitGroup
for i := 0; i < 3; i++ {
wg.Add(1)
go func() {
defer wg.Done()
resp, err := localKEKService.Encrypt(ctx, "test-uid", []byte("test-plaintext"))
if err != nil {
t.Errorf("Encrypt() error = %v", err)
return
}
if v, ok := resp.Annotations[referenceKEKAnnotationKey]; !ok || !bytes.Equal(v, encLocalKEK) {
t.Errorf("Encrypt() annotations = %v, want %v", resp.Annotations, encLocalKEK)
return
}
record.Store(resp, nil)
}()
}
wg.Wait()
if t.Failed() {
return
}
// check local KEK has only been used 3 times and still under the suggested usage
if lk.usage.Load() != 3 {
t.Fatalf("local KEK usage = %v, want %v", lk.usage.Load(), 3)
}
// advance the clock to trigger key expiry
fakeClock.Step(6 * time.Second)
rotated := false
// wait for the local KEK to be rotated due to key expiry
wait.PollImmediate(100*time.Millisecond, wait.ForeverTestTimeout, func() (bool, error) {
// local KEK must have been rotated after the key max age
t.Log("waiting for local KEK to be rotated")
lk = localKEKService.localKEKTracker.Load()
rotated = !bytes.Equal(lk.encKEK, encLocalKEK)
return rotated, nil
})
if !rotated {
t.Fatalf("local KEK must have been rotated")
}
// new local KEK must be used for encryption
for i := 0; i < 5; i++ {
wg.Add(1)
go func() {
defer wg.Done()
resp, err := localKEKService.Encrypt(ctx, "test-uid", []byte("test-plaintext"))
if err != nil {
t.Errorf("Encrypt() error = %v", err)
return
}
if v, ok := resp.Annotations[referenceKEKAnnotationKey]; !ok || !bytes.Equal(v, lk.encKEK) {
t.Errorf("Encrypt() annotations = %v, want %v", resp.Annotations, lk.encKEK)
return
}
record.Store(resp, nil)
}()
}
wg.Wait()
if t.Failed() {
return
}
// check we can decrypt data encrypted with the old and new local KEKs
record.Range(func(key, _ any) bool {
k := key.(*service.EncryptResponse)
decryptRequest := &service.DecryptRequest{
Ciphertext: k.Ciphertext,
Annotations: k.Annotations,
KeyID: k.KeyID,
}
if _, err := localKEKService.Decrypt(ctx, "test-uid", decryptRequest); err != nil {
t.Fatalf("Decrypt() error = %v", err)
}
return true
})
// Out of the 8 calls to Decrypt:
// - 5 should be using the current local KEK
// - 1 out of the 3 should generate a decrypt call to the remote KMS as the local KEK not in cache
// - 2 out of the 3 should use the cached local KEK after 1st decrypt call to the remote KMS
assertCallCount(t, remoteKMS, localKEKService)
}
func TestRotationRemoteKeyIDChanged(t *testing.T) {
ctx := testContext(t)
var record sync.Map
fakeClock := testingclock.NewFakeClock(time.Now())
remoteKMS := &testRemoteService{keyID: "test-key-id"}
localKEKService := newLocalKEKService(ctx, remoteKMS, 10, 5, 1*time.Minute, 100*time.Millisecond, fakeClock)
waitUntilReady(t, localKEKService)
lk := localKEKService.localKEKTracker.Load()
encLocalKEK := lk.encKEK
var wg sync.WaitGroup
for i := 0; i < 3; i++ {
wg.Add(1)
go func() {
defer wg.Done()
resp, err := localKEKService.Encrypt(ctx, "test-uid", []byte("test-plaintext"))
if err != nil {
t.Errorf("Encrypt() error = %v", err)
return
}
if v, ok := resp.Annotations[referenceKEKAnnotationKey]; !ok || !bytes.Equal(v, encLocalKEK) {
t.Errorf("Encrypt() annotations = %v, want %v", resp.Annotations, encLocalKEK)
return
}
record.Store(resp, nil)
}()
}
wg.Wait()
if t.Failed() {
return
}
// check local KEK has only been used 3 times and still under the suggested usage
if lk.usage.Load() != 3 {
t.Fatalf("local KEK usage = %v, want %v", lk.usage.Load(), 3)
}
fakeClock.Step(30 * time.Second)
// change the remote key ID
remoteKMS.SetKeyID("test-key-id-2")
if _, err := localKEKService.Status(ctx); err != nil {
t.Fatalf("Status() error = %v", err)
}
rotated := false
// wait for the local KEK to be rotated due to remote key ID change
wait.PollImmediate(100*time.Millisecond, wait.ForeverTestTimeout, func() (bool, error) {
lk = localKEKService.localKEKTracker.Load()
rotated = !bytes.Equal(lk.encKEK, encLocalKEK)
return rotated, nil
})
if !rotated {
t.Fatalf("local KEK must have been rotated")
}
// new local KEK must be used for encryption
for i := 0; i < 5; i++ {
wg.Add(1)
go func() {
defer wg.Done()
resp, err := localKEKService.Encrypt(ctx, "test-uid", []byte("test-plaintext"))
if err != nil {
t.Errorf("Encrypt() error = %v", err)
return
}
if v, ok := resp.Annotations[referenceKEKAnnotationKey]; !ok || !bytes.Equal(v, lk.encKEK) {
t.Errorf("Encrypt() annotations = %v, want %v", resp.Annotations, lk.encKEK)
return
}
record.Store(resp, nil)
}()
}
wg.Wait()
if t.Failed() {
return
}
// check we can decrypt data encrypted with the old and new local KEKs
record.Range(func(key, _ any) bool {
k := key.(*service.EncryptResponse)
decryptRequest := &service.DecryptRequest{
Ciphertext: k.Ciphertext,
Annotations: k.Annotations,
KeyID: k.KeyID,
}
if _, err := localKEKService.Decrypt(ctx, "test-uid", decryptRequest); err != nil {
t.Fatalf("Decrypt() error = %v", err)
}
return true
})
// Out of the 8 calls to Decrypt:
// - 5 should be using the current local KEK
// - 1 out of the 3 should generate a decrypt call to the remote KMS as the local KEK not in cache
// - 2 out of the 3 should use the cached local KEK after 1st decrypt call to the remote KMS
assertCallCount(t, remoteKMS, localKEKService)
}
func testContext(t *testing.T) context.Context {
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
return ctx
}
func waitUntilReady(t *testing.T, s *LocalKEKService) {
t.Helper()
wait.PollImmediate(100*time.Millisecond, wait.ForeverTestTimeout, func() (bool, error) {
return s.isReady.Load(), nil
})
}
func validateEncryptResponse(t *testing.T, got *service.EncryptResponse, wantKeyID string, localKEKService *LocalKEKService) {
t.Helper()
if len(got.Annotations) != 2 {
t.Fatalf("Encrypt() annotations = %v, want 2 annotations", got.Annotations)
}
if _, ok := got.Annotations[referenceKEKAnnotationKey]; !ok {
t.Fatalf("Encrypt() annotations = %v, want %v", got.Annotations, referenceKEKAnnotationKey)
}
if got.KeyID != wantKeyID {
t.Fatalf("Encrypt() keyID = %v, want %v", got.KeyID, wantKeyID)
}
if localKEKService.localKEKTracker.Load() == nil {
t.Fatalf("Encrypt() localKEKTracker = %v, want non-nil localKEK", localKEKService.localKEKTracker.Load())
}
}
func validateStatusResponse(t *testing.T, got *service.StatusResponse, wantVersion, wantHealthz, wantKeyID string) {
t.Helper()
if got.Version != wantVersion {
t.Fatalf("Status() version = %v, want %v", got.Version, wantVersion)
}
if !strings.EqualFold(got.Healthz, wantHealthz) {
t.Fatalf("Status() healthz = %v, want %v", got.Healthz, wantHealthz)
}
if got.KeyID != wantKeyID {
t.Fatalf("Status() keyID = %v, want %v", got.KeyID, wantKeyID)
}
}
func assertCallCount(t *testing.T, remoteKMS *testRemoteService, localKEKService *LocalKEKService) {
t.Helper()
if remoteKMS.DecryptCallCount() != 1 {
t.Fatalf("Decrypt() remoteKMS.DecryptCallCount() = %v, want %v", remoteKMS.DecryptCallCount(), 1)
}
if localKEKService.transformers.Len() != 1 {
t.Fatalf("Decrypt() localKEKService.transformers.Len() = %v, want %v", localKEKService.transformers.Len(), 1)
}
}