From 8eacf09649ac9042c7e998b5c24ac59d68ae7e6c Mon Sep 17 00:00:00 2001 From: Anish Ramasekar Date: Tue, 14 Mar 2023 19:38:30 +0000 Subject: [PATCH] [KMSv2] use encDEK, keyID and annotations to generate cache key It is possible for a KMSv2 plugin to return a static value as Ciphertext and store the actual encrypted DEK in the annotations. In this case, using the encDEK will not work. Instead, we are now using a combination of the encDEK, keyID and annotations to generate the cache key. Signed-off-by: Anish Ramasekar --- .../server/options/encryptionconfig/config.go | 7 +- .../options/encryptionconfig/config_test.go | 3 +- .../value/encrypt/envelope/kmsv2/cache.go | 8 +- .../value/encrypt/envelope/kmsv2/envelope.go | 90 +++++++++- .../encrypt/envelope/kmsv2/envelope_test.go | 169 +++++++++++++++--- 5 files changed, 245 insertions(+), 32 deletions(-) diff --git a/staging/src/k8s.io/apiserver/pkg/server/options/encryptionconfig/config.go b/staging/src/k8s.io/apiserver/pkg/server/options/encryptionconfig/config.go index 205798bb5cc..796cc6b03dc 100644 --- a/staging/src/k8s.io/apiserver/pkg/server/options/encryptionconfig/config.go +++ b/staging/src/k8s.io/apiserver/pkg/server/options/encryptionconfig/config.go @@ -358,7 +358,7 @@ func (h *kmsv2PluginProbe) rotateDEKOnKeyIDChange(ctx context.Context, statusKey return nil } - transformer, resp, errGen := envelopekmsv2.GenerateTransformer(ctx, uid, h.service) + transformer, resp, cacheKey, errGen := envelopekmsv2.GenerateTransformer(ctx, uid, h.service) if resp == nil { resp = &kmsservice.EncryptResponse{} // avoid nil panics @@ -374,6 +374,7 @@ func (h *kmsv2PluginProbe) rotateDEKOnKeyIDChange(ctx context.Context, statusKey Annotations: resp.Annotations, UID: uid, ExpirationTimestamp: expirationTimestamp, + CacheKey: cacheKey, }) klog.V(6).InfoS("successfully rotated DEK", "uid", uid, @@ -410,6 +411,10 @@ func (h *kmsv2PluginProbe) getCurrentState() (envelopekmsv2.State, error) { return envelopekmsv2.State{}, fmt.Errorf("got unexpected zero expirationTimestamp") } + if len(state.CacheKey) == 0 { + return envelopekmsv2.State{}, fmt.Errorf("got unexpected empty cacheKey") + } + return state, nil } diff --git a/staging/src/k8s.io/apiserver/pkg/server/options/encryptionconfig/config_test.go b/staging/src/k8s.io/apiserver/pkg/server/options/encryptionconfig/config_test.go index 9768a58c111..f52a5f42a26 100644 --- a/staging/src/k8s.io/apiserver/pkg/server/options/encryptionconfig/config_test.go +++ b/staging/src/k8s.io/apiserver/pkg/server/options/encryptionconfig/config_test.go @@ -1781,7 +1781,7 @@ func Test_kmsv2PluginProbe_rotateDEKOnKeyIDChange(t *testing.T) { t.Errorf("log mismatch (-want +got):\n%s", diff) } - ignoredFields := sets.NewString("Transformer", "EncryptedDEK", "UID") + ignoredFields := sets.NewString("Transformer", "EncryptedDEK", "UID", "CacheKey") if diff := cmp.Diff(tt.wantState, *h.state.Load(), cmp.FilterPath(func(path cmp.Path) bool { return ignoredFields.Has(path.String()) }, cmp.Ignore()), @@ -1806,6 +1806,7 @@ func validState(keyID string, exp time.Time) envelopekmsv2.State { EncryptedDEK: []byte{1}, KeyID: keyID, ExpirationTimestamp: exp, + CacheKey: []byte{1}, } } diff --git a/staging/src/k8s.io/apiserver/pkg/storage/value/encrypt/envelope/kmsv2/cache.go b/staging/src/k8s.io/apiserver/pkg/storage/value/encrypt/envelope/kmsv2/cache.go index ae2e344e158..3c1fbbf8a36 100644 --- a/staging/src/k8s.io/apiserver/pkg/storage/value/encrypt/envelope/kmsv2/cache.go +++ b/staging/src/k8s.io/apiserver/pkg/storage/value/encrypt/envelope/kmsv2/cache.go @@ -98,5 +98,11 @@ func (c *simpleCache) keyFunc(s []byte) string { // toString performs unholy acts to avoid allocations func toString(b []byte) string { - return *(*string)(unsafe.Pointer(&b)) + // unsafe.SliceData relies on cap whereas we want to rely on len + if len(b) == 0 { + return "" + } + // Copied from go 1.20.1 strings.Builder.String + // https://github.com/golang/go/blob/202a1a57064127c3f19d96df57b9f9586145e21c/src/strings/builder.go#L48 + return unsafe.String(unsafe.SliceData(b), len(b)) } diff --git a/staging/src/k8s.io/apiserver/pkg/storage/value/encrypt/envelope/kmsv2/envelope.go b/staging/src/k8s.io/apiserver/pkg/storage/value/encrypt/envelope/kmsv2/envelope.go index f8b1d8e2d35..4aa4e6b1933 100644 --- a/staging/src/k8s.io/apiserver/pkg/storage/value/encrypt/envelope/kmsv2/envelope.go +++ b/staging/src/k8s.io/apiserver/pkg/storage/value/encrypt/envelope/kmsv2/envelope.go @@ -21,9 +21,12 @@ import ( "context" "crypto/aes" "fmt" + "sort" "time" + "unsafe" "github.com/gogo/protobuf/proto" + "golang.org/x/crypto/cryptobyte" utilerrors "k8s.io/apimachinery/pkg/util/errors" "k8s.io/apimachinery/pkg/util/uuid" @@ -87,6 +90,9 @@ type State struct { UID string ExpirationTimestamp time.Time + + // CacheKey is the key used to cache the DEK in transformer.cache. + CacheKey []byte } func (s *State) ValidateEncryptCapability() error { @@ -137,8 +143,13 @@ func (t *envelopeTransformer) TransformFromStorage(ctx context.Context, data []b return nil, false, err } + encryptedObjectCacheKey, err := generateCacheKey(encryptedObject.EncryptedDEK, encryptedObject.KeyID, encryptedObject.Annotations) + if err != nil { + return nil, false, err + } + // Look up the decrypted DEK from cache first - transformer := t.cache.get(encryptedObject.EncryptedDEK) + transformer := t.cache.get(encryptedObjectCacheKey) // fallback to the envelope service if we do not have the transformer locally if transformer == nil { @@ -159,7 +170,7 @@ func (t *envelopeTransformer) TransformFromStorage(ctx context.Context, data []b return nil, false, fmt.Errorf("failed to decrypt DEK, error: %w", err) } - transformer, err = t.addTransformerForDecryption(encryptedObject.EncryptedDEK, key) + transformer, err = t.addTransformerForDecryption(encryptedObjectCacheKey, key) if err != nil { return nil, false, err } @@ -190,7 +201,7 @@ func (t *envelopeTransformer) TransformToStorage(ctx context.Context, data []byt // this has the side benefit of causing the cache to perform a GC // TODO see if we can do this inside the stateFunc control loop // TODO(aramase): Add metrics for cache fill percentage with custom cache implementation. - t.cache.set(state.EncryptedDEK, state.Transformer) + t.cache.set(state.CacheKey, state.Transformer) requestInfo := getRequestInfoFromContext(ctx) klog.V(6).InfoS("encrypting content using DEK", "uid", state.UID, "key", string(dataCtx.AuthenticatedData()), @@ -216,7 +227,7 @@ func (t *envelopeTransformer) TransformToStorage(ctx context.Context, data []byt } // addTransformerForDecryption inserts a new transformer to the Envelope cache of DEKs for future reads. -func (t *envelopeTransformer) addTransformerForDecryption(encKey []byte, key []byte) (decryptTransformer, error) { +func (t *envelopeTransformer) addTransformerForDecryption(cacheKey []byte, key []byte) (decryptTransformer, error) { block, err := aes.NewCipher(key) if err != nil { return nil, err @@ -228,7 +239,7 @@ func (t *envelopeTransformer) addTransformerForDecryption(encKey []byte, key []b return nil, err } // TODO(aramase): Add metrics for cache fill percentage with custom cache implementation. - t.cache.set(encKey, transformer) + t.cache.set(cacheKey, transformer) return transformer, nil } @@ -254,20 +265,25 @@ func (t *envelopeTransformer) doDecode(originalData []byte) (*kmstypes.Encrypted return o, nil } -func GenerateTransformer(ctx context.Context, uid string, envelopeService kmsservice.Service) (value.Transformer, *kmsservice.EncryptResponse, error) { +func GenerateTransformer(ctx context.Context, uid string, envelopeService kmsservice.Service) (value.Transformer, *kmsservice.EncryptResponse, []byte, error) { transformer, newKey, err := aestransformer.NewGCMTransformerWithUniqueKeyUnsafe() if err != nil { - return nil, nil, err + return nil, nil, nil, err } klog.V(6).InfoS("encrypting content using envelope service", "uid", uid) resp, err := envelopeService.Encrypt(ctx, uid, newKey) if err != nil { - return nil, nil, fmt.Errorf("failed to encrypt DEK, error: %w", err) + return nil, nil, nil, fmt.Errorf("failed to encrypt DEK, error: %w", err) } - return transformer, resp, nil + cacheKey, err := generateCacheKey(resp.Ciphertext, resp.KeyID, resp.Annotations) + if err != nil { + return nil, nil, nil, err + } + + return transformer, resp, cacheKey, nil } func validateEncryptedObject(o *kmstypes.EncryptedObject) error { @@ -339,3 +355,59 @@ func getRequestInfoFromContext(ctx context.Context) *genericapirequest.RequestIn } return &genericapirequest.RequestInfo{} } + +// generateCacheKey returns a key for the cache. +// The key is a concatenation of: +// 1. encryptedDEK +// 2. keyID +// 3. length of annotations +// 4. annotations (sorted by key) - each annotation is a concatenation of: +// a. annotation key +// b. annotation value +func generateCacheKey(encryptedDEK []byte, keyID string, annotations map[string][]byte) ([]byte, error) { + // TODO(aramase): use sync pool buffer to avoid allocations + b := cryptobyte.NewBuilder(nil) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(encryptedDEK) + }) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(toBytes(keyID)) + }) + if len(annotations) == 0 { + return b.Bytes() + } + + // add the length of annotations to the cache key + b.AddUint32(uint32(len(annotations))) + + // Sort the annotations by key. + keys := make([]string, 0, len(annotations)) + for k := range annotations { + k := k + keys = append(keys, k) + } + sort.Strings(keys) + for _, k := range keys { + // The maximum size of annotations is annotationsMaxSize (32 kB) so we can safely + // assume that the length of the key and value will fit in a uint16. + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(toBytes(k)) + }) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(annotations[k]) + }) + } + + return b.Bytes() +} + +// toBytes performs unholy acts to avoid allocations +func toBytes(s string) []byte { + // unsafe.StringData is unspecified for the empty string, so we provide a strict interpretation + if len(s) == 0 { + return nil + } + // Copied from go 1.20.1 os.File.WriteString + // https://github.com/golang/go/blob/202a1a57064127c3f19d96df57b9f9586145e21c/src/os/file.go#L246 + return unsafe.Slice(unsafe.StringData(s), len(s)) +} diff --git a/staging/src/k8s.io/apiserver/pkg/storage/value/encrypt/envelope/kmsv2/envelope_test.go b/staging/src/k8s.io/apiserver/pkg/storage/value/encrypt/envelope/kmsv2/envelope_test.go index c56279b0833..63ec4600de2 100644 --- a/staging/src/k8s.io/apiserver/pkg/storage/value/encrypt/envelope/kmsv2/envelope_test.go +++ b/staging/src/k8s.io/apiserver/pkg/storage/value/encrypt/envelope/kmsv2/envelope_test.go @@ -55,12 +55,15 @@ const ( // testEnvelopeService is a mock Envelope service which can be used to simulate remote Envelope services // for testing of Envelope based encryption providers. type testEnvelopeService struct { - annotations map[string][]byte - disabled bool - keyVersion string + annotations map[string][]byte + disabled bool + keyVersion string + ciphertext []byte + decryptCalls int } func (t *testEnvelopeService) Decrypt(ctx context.Context, uid string, req *kmsservice.DecryptRequest) ([]byte, error) { + t.decryptCalls++ if t.disabled { return nil, fmt.Errorf("Envelope service was disabled") } @@ -88,7 +91,13 @@ func (t *testEnvelopeService) Encrypt(ctx context.Context, uid string, data []by } else { annotations["local-kek.kms.kubernetes.io"] = []byte("encrypted-local-kek") } - return &kmsservice.EncryptResponse{Ciphertext: []byte(base64.StdEncoding.EncodeToString(data)), KeyID: t.keyVersion, Annotations: annotations}, nil + + ciphertext := t.ciphertext + if ciphertext == nil { + ciphertext = []byte(base64.StdEncoding.EncodeToString(data)) + } + + return &kmsservice.EncryptResponse{Ciphertext: ciphertext, KeyID: t.keyVersion, Annotations: annotations}, nil } func (t *testEnvelopeService) Status(ctx context.Context) (*kmsservice.StatusResponse, error) { @@ -106,6 +115,10 @@ func (t *testEnvelopeService) SetAnnotations(annotations map[string][]byte) { t.annotations = annotations } +func (t *testEnvelopeService) SetCiphertext(ciphertext []byte) { + t.ciphertext = ciphertext +} + func (t *testEnvelopeService) Rotate() { i, _ := strconv.Atoi(t.keyVersion) t.keyVersion = strconv.FormatInt(int64(i+1), 10) @@ -124,17 +137,26 @@ func TestEnvelopeCaching(t *testing.T) { cacheTTL time.Duration simulateKMSPluginFailure bool expectedError string + expectedDecryptCalls int }{ { desc: "entry in cache should withstand plugin failure", cacheTTL: 5 * time.Minute, simulateKMSPluginFailure: true, + expectedDecryptCalls: 0, // should not hit KMS plugin }, { desc: "cache entry expired should not withstand plugin failure", cacheTTL: 1 * time.Millisecond, simulateKMSPluginFailure: true, expectedError: "failed to decrypt DEK, error: Envelope service was disabled", + expectedDecryptCalls: 10, // should hit KMS plugin for each read after cache entry expired and fail + }, + { + desc: "cache entry expired should work after cache refresh", + cacheTTL: 1 * time.Millisecond, + simulateKMSPluginFailure: false, + expectedDecryptCalls: 1, // should hit KMS plugin just for the 1st read after cache entry expired }, } @@ -176,30 +198,35 @@ func TestEnvelopeCaching(t *testing.T) { } envelopeService.SetDisabledStatus(tt.simulateKMSPluginFailure) - // Subsequent read for the same data should work fine due to caching. - untransformedData, _, err = transformer.TransformFromStorage(ctx, transformedData, dataCtx) - if tt.expectedError != "" { - if err == nil { - t.Fatalf("expected error: %v, got nil", tt.expectedError) - } - if err.Error() != tt.expectedError { - t.Fatalf("expected error: %v, got: %v", tt.expectedError, err) - } - } else { - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if !bytes.Equal(untransformedData, originalText) { - t.Fatalf("envelopeTransformer transformed data incorrectly. Expected: %v, got %v", originalText, untransformedData) + for i := 0; i < 10; i++ { + // Subsequent reads for the same data should work fine due to caching. + untransformedData, _, err = transformer.TransformFromStorage(ctx, transformedData, dataCtx) + if tt.expectedError != "" { + if err == nil { + t.Fatalf("expected error: %v, got nil", tt.expectedError) + } + if err.Error() != tt.expectedError { + t.Fatalf("expected error: %v, got: %v", tt.expectedError, err) + } + } else { + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !bytes.Equal(untransformedData, originalText) { + t.Fatalf("envelopeTransformer transformed data incorrectly. Expected: %v, got %v", originalText, untransformedData) + } } } + if envelopeService.decryptCalls != tt.expectedDecryptCalls { + t.Fatalf("expected %d decrypt calls, got %d", tt.expectedDecryptCalls, envelopeService.decryptCalls) + } }) } } func testStateFunc(ctx context.Context, envelopeService kmsservice.Service, clock clock.Clock) func() (State, error) { return func() (State, error) { - transformer, resp, errGen := GenerateTransformer(ctx, string(uuid.NewUUID()), envelopeService) + transformer, resp, cacheKey, errGen := GenerateTransformer(ctx, string(uuid.NewUUID()), envelopeService) if errGen != nil { return State{}, errGen } @@ -210,6 +237,7 @@ func testStateFunc(ctx context.Context, envelopeService kmsservice.Service, cloc Annotations: resp.Annotations, UID: "panda", ExpirationTimestamp: clock.Now().Add(time.Hour), + CacheKey: cacheKey, }, nil } } @@ -861,6 +889,107 @@ func TestEnvelopeLogging(t *testing.T) { } } +func TestCacheNotCorrupted(t *testing.T) { + ctx := testContext(t) + + envelopeService := newTestEnvelopeService() + envelopeService.SetAnnotations(map[string][]byte{ + "encrypted-dek.kms.kubernetes.io": []byte("encrypted-dek-0"), + }) + + fakeClock := testingclock.NewFakeClock(time.Now()) + + state, err := testStateFunc(ctx, envelopeService, fakeClock)() + if err != nil { + t.Fatal(err) + } + + transformer := newEnvelopeTransformerWithClock(envelopeService, testProviderName, + func() (State, error) { return state, nil }, + 1*time.Second, fakeClock) + + dataCtx := value.DefaultContext(testContextText) + originalText := []byte(testText) + + transformedData1, err := transformer.TransformToStorage(ctx, originalText, dataCtx) + if err != nil { + t.Fatalf("envelopeTransformer: error while transforming data to storage: %s", err) + } + + // this is to mimic a plugin that sets a static response for ciphertext + // but uses the annotation field to send the actual encrypted DEK. + envelopeService.SetCiphertext(state.EncryptedDEK) + // for this plugin, it indicates a change in the remote key ID as the returned + // encrypted DEK is different. + envelopeService.SetAnnotations(map[string][]byte{ + "encrypted-dek.kms.kubernetes.io": []byte("encrypted-dek-1"), + }) + + state, err = testStateFunc(ctx, envelopeService, fakeClock)() + if err != nil { + t.Fatal(err) + } + + transformer = newEnvelopeTransformerWithClock(envelopeService, testProviderName, + func() (State, error) { return state, nil }, + 1*time.Second, fakeClock) + + transformedData2, err := transformer.TransformToStorage(ctx, originalText, dataCtx) + if err != nil { + t.Fatalf("envelopeTransformer: error while transforming data to storage: %s", err) + } + + if _, _, err := transformer.TransformFromStorage(ctx, transformedData1, dataCtx); err != nil { + t.Fatal(err) + } + if _, _, err := transformer.TransformFromStorage(ctx, transformedData2, dataCtx); err != nil { + t.Fatal(err) + } +} + +func TestGenerateCacheKey(t *testing.T) { + encryptedDEK1 := []byte{1, 2, 3} + keyID1 := "id1" + annotations1 := map[string][]byte{"a": {4, 5}, "b": {6, 7}} + + encryptedDEK2 := []byte{4, 5, 6} + keyID2 := "id2" + annotations2 := map[string][]byte{"x": {9, 10}, "y": {11, 12}} + + // generate all possible combinations of the above + testCases := []struct { + encryptedDEK []byte + keyID string + annotations map[string][]byte + }{ + {encryptedDEK1, keyID1, annotations1}, + {encryptedDEK1, keyID1, annotations2}, + {encryptedDEK1, keyID2, annotations1}, + {encryptedDEK1, keyID2, annotations2}, + {encryptedDEK2, keyID1, annotations1}, + {encryptedDEK2, keyID1, annotations2}, + {encryptedDEK2, keyID2, annotations1}, + {encryptedDEK2, keyID2, annotations2}, + } + + for _, tc := range testCases { + tc := tc + for _, tc2 := range testCases { + tc2 := tc2 + t.Run(fmt.Sprintf("%+v-%+v", tc, tc2), func(t *testing.T) { + key1, err1 := generateCacheKey(tc.encryptedDEK, tc.keyID, tc.annotations) + key2, err2 := generateCacheKey(tc2.encryptedDEK, tc2.keyID, tc2.annotations) + if err1 != nil || err2 != nil { + t.Errorf("generateCacheKey() want err=nil, got err1=%q, err2=%q", errString(err1), errString(err2)) + } + if bytes.Equal(key1, key2) != reflect.DeepEqual(tc, tc2) { + t.Errorf("expected %v, got %v", reflect.DeepEqual(tc, tc2), bytes.Equal(key1, key2)) + } + }) + } + } +} + func errString(err error) string { if err == nil { return ""