kmsv2: implement expire cache with clock

Signed-off-by: Anish Ramasekar <anish.ramasekar@gmail.com>
This commit is contained in:
Anish Ramasekar 2022-09-14 20:01:45 +00:00
parent 674eb36f92
commit 4804baa011
No known key found for this signature in database
GPG Key ID: F1F7F3518F1ECB0C
19 changed files with 206 additions and 136 deletions

View File

@ -109,7 +109,7 @@ type KMSConfiguration struct {
// name is the name of the KMS plugin to be used. // name is the name of the KMS plugin to be used.
Name string Name string
// cachesize is the maximum number of secrets which are cached in memory. The default value is 1000. // cachesize is the maximum number of secrets which are cached in memory. The default value is 1000.
// Set to a negative value to disable caching. // Set to a negative value to disable caching. This field is only allowed for KMS v1 providers.
// +optional // +optional
CacheSize *int32 CacheSize *int32
// endpoint is the gRPC server listening address, for example "unix:///var/run/kms-provider.sock". // endpoint is the gRPC server listening address, for example "unix:///var/run/kms-provider.sock".

View File

@ -39,11 +39,12 @@ func SetDefaults_KMSConfiguration(obj *KMSConfiguration) {
obj.Timeout = defaultTimeout obj.Timeout = defaultTimeout
} }
if obj.CacheSize == nil {
obj.CacheSize = &defaultCacheSize
}
if obj.APIVersion == "" { if obj.APIVersion == "" {
obj.APIVersion = defaultAPIVersion obj.APIVersion = defaultAPIVersion
} }
// cacheSize is relevant only for kms v1
if obj.CacheSize == nil && obj.APIVersion == "v1" {
obj.CacheSize = &defaultCacheSize
}
} }

View File

@ -57,6 +57,7 @@ func TestKMSProviderCacheDefaults(t *testing.T) {
var ( var (
zero int32 = 0 zero int32 = 0
ten int32 = 10 ten int32 = 10
negative int32 = -1
) )
testCases := []struct { testCases := []struct {
@ -79,6 +80,21 @@ func TestKMSProviderCacheDefaults(t *testing.T) {
in: &KMSConfiguration{CacheSize: &ten}, in: &KMSConfiguration{CacheSize: &ten},
want: &KMSConfiguration{Timeout: defaultTimeout, CacheSize: &ten, APIVersion: defaultAPIVersion}, want: &KMSConfiguration{Timeout: defaultTimeout, CacheSize: &ten, APIVersion: defaultAPIVersion},
}, },
{
desc: "negative cache size supplied",
in: &KMSConfiguration{CacheSize: &negative},
want: &KMSConfiguration{Timeout: defaultTimeout, CacheSize: &negative, APIVersion: defaultAPIVersion},
},
{
desc: "cache size not supplied but API version is v2",
in: &KMSConfiguration{APIVersion: "v2"},
want: &KMSConfiguration{Timeout: defaultTimeout, APIVersion: "v2"},
},
{
desc: "cache size not supplied with API version v1",
in: &KMSConfiguration{APIVersion: "v1"},
want: &KMSConfiguration{Timeout: defaultTimeout, CacheSize: &defaultCacheSize, APIVersion: defaultAPIVersion},
},
} }
for _, tt := range testCases { for _, tt := range testCases {
@ -104,8 +120,13 @@ func TestKMSProviderAPIVersionDefaults(t *testing.T) {
}, },
{ {
desc: "apiVersion supplied", desc: "apiVersion supplied",
in: &KMSConfiguration{Timeout: &v1.Duration{Duration: 1 * time.Minute}, APIVersion: "v1"},
want: &KMSConfiguration{Timeout: &v1.Duration{Duration: 1 * time.Minute}, CacheSize: &defaultCacheSize, APIVersion: "v1"},
},
{
desc: "apiVersion v2 supplied, cache size not defaulted",
in: &KMSConfiguration{Timeout: &v1.Duration{Duration: 1 * time.Minute}, APIVersion: "v2"}, in: &KMSConfiguration{Timeout: &v1.Duration{Duration: 1 * time.Minute}, APIVersion: "v2"},
want: &KMSConfiguration{Timeout: &v1.Duration{Duration: 1 * time.Minute}, CacheSize: &defaultCacheSize, APIVersion: "v2"}, want: &KMSConfiguration{Timeout: &v1.Duration{Duration: 1 * time.Minute}, APIVersion: "v2"},
}, },
} }

View File

@ -109,7 +109,7 @@ type KMSConfiguration struct {
// name is the name of the KMS plugin to be used. // name is the name of the KMS plugin to be used.
Name string `json:"name"` Name string `json:"name"`
// cachesize is the maximum number of secrets which are cached in memory. The default value is 1000. // cachesize is the maximum number of secrets which are cached in memory. The default value is 1000.
// Set to a negative value to disable caching. // Set to a negative value to disable caching. This field is only allowed for KMS v1 providers.
// +optional // +optional
CacheSize *int32 `json:"cachesize,omitempty"` CacheSize *int32 `json:"cachesize,omitempty"`
// endpoint is the gRPC server listening address, for example "unix:///var/run/kms-provider.sock". // endpoint is the gRPC server listening address, for example "unix:///var/run/kms-provider.sock".

View File

@ -195,7 +195,13 @@ func validateKMSConfiguration(c *config.KMSConfiguration, fieldPath *field.Path,
func validateKMSCacheSize(c *config.KMSConfiguration, fieldPath *field.Path) field.ErrorList { func validateKMSCacheSize(c *config.KMSConfiguration, fieldPath *field.Path) field.ErrorList {
allErrs := field.ErrorList{} allErrs := field.ErrorList{}
if *c.CacheSize == 0 {
// In defaulting, we set the cache size to the default value only when API version is v1.
// So, for v2 API version, we expect the cache size field to be nil.
if c.APIVersion != "v1" && c.CacheSize != nil {
allErrs = append(allErrs, field.Invalid(fieldPath, *c.CacheSize, "cachesize is not supported in v2"))
}
if c.APIVersion == "v1" && *c.CacheSize == 0 {
allErrs = append(allErrs, field.Invalid(fieldPath, *c.CacheSize, fmt.Sprintf(nonZeroErrFmt, "cachesize"))) allErrs = append(allErrs, field.Invalid(fieldPath, *c.CacheSize, fmt.Sprintf(nonZeroErrFmt, "cachesize")))
} }

View File

@ -185,7 +185,6 @@ func TestStructure(t *testing.T) {
Name: "foo", Name: "foo",
Endpoint: "unix:///tmp/kms-provider-2.socket", Endpoint: "unix:///tmp/kms-provider-2.socket",
Timeout: &metav1.Duration{Duration: 3 * time.Second}, Timeout: &metav1.Duration{Duration: 3 * time.Second},
CacheSize: &cacheSize,
APIVersion: "v2", APIVersion: "v2",
}, },
}, },
@ -210,7 +209,6 @@ func TestStructure(t *testing.T) {
Name: "foo", Name: "foo",
Endpoint: "unix:///tmp/kms-provider-1.socket", Endpoint: "unix:///tmp/kms-provider-1.socket",
Timeout: &metav1.Duration{Duration: 3 * time.Second}, Timeout: &metav1.Duration{Duration: 3 * time.Second},
CacheSize: &cacheSize,
APIVersion: "v2", APIVersion: "v2",
}, },
}, },
@ -219,7 +217,6 @@ func TestStructure(t *testing.T) {
Name: "foo", Name: "foo",
Endpoint: "unix:///tmp/kms-provider-2.socket", Endpoint: "unix:///tmp/kms-provider-2.socket",
Timeout: &metav1.Duration{Duration: 3 * time.Second}, Timeout: &metav1.Duration{Duration: 3 * time.Second},
CacheSize: &cacheSize,
APIVersion: "v2", APIVersion: "v2",
}, },
}, },
@ -244,7 +241,6 @@ func TestStructure(t *testing.T) {
Name: "foo", Name: "foo",
Endpoint: "unix:///tmp/kms-provider-1.socket", Endpoint: "unix:///tmp/kms-provider-1.socket",
Timeout: &metav1.Duration{Duration: 3 * time.Second}, Timeout: &metav1.Duration{Duration: 3 * time.Second},
CacheSize: &cacheSize,
APIVersion: "v2", APIVersion: "v2",
}, },
}, },
@ -258,7 +254,6 @@ func TestStructure(t *testing.T) {
Name: "foo", Name: "foo",
Endpoint: "unix:///tmp/kms-provider-2.socket", Endpoint: "unix:///tmp/kms-provider-2.socket",
Timeout: &metav1.Duration{Duration: 3 * time.Second}, Timeout: &metav1.Duration{Duration: 3 * time.Second},
CacheSize: &cacheSize,
APIVersion: "v2", APIVersion: "v2",
}, },
}, },
@ -297,7 +292,6 @@ func TestStructure(t *testing.T) {
Name: "foo", Name: "foo",
Endpoint: "unix:///tmp/kms-provider-2.socket", Endpoint: "unix:///tmp/kms-provider-2.socket",
Timeout: &metav1.Duration{Duration: 3 * time.Second}, Timeout: &metav1.Duration{Duration: 3 * time.Second},
CacheSize: &cacheSize,
APIVersion: "v2", APIVersion: "v2",
}, },
}, },
@ -539,21 +533,28 @@ func TestKMSProviderCacheSize(t *testing.T) {
}{ }{
{ {
desc: "valid positive cache size", desc: "valid positive cache size",
in: &config.KMSConfiguration{CacheSize: &positiveCacheSize}, in: &config.KMSConfiguration{APIVersion: "v1", CacheSize: &positiveCacheSize},
want: field.ErrorList{}, want: field.ErrorList{},
}, },
{ {
desc: "invalid zero cache size", desc: "invalid zero cache size",
in: &config.KMSConfiguration{CacheSize: &zeroCacheSize}, in: &config.KMSConfiguration{APIVersion: "v1", CacheSize: &zeroCacheSize},
want: field.ErrorList{ want: field.ErrorList{
field.Invalid(cacheField, int32(0), fmt.Sprintf(nonZeroErrFmt, "cachesize")), field.Invalid(cacheField, int32(0), fmt.Sprintf(nonZeroErrFmt, "cachesize")),
}, },
}, },
{ {
desc: "valid negative caches size", desc: "valid negative caches size",
in: &config.KMSConfiguration{CacheSize: &negativeCacheSize}, in: &config.KMSConfiguration{APIVersion: "v1", CacheSize: &negativeCacheSize},
want: field.ErrorList{}, want: field.ErrorList{},
}, },
{
desc: "cache size set with v2 provider",
in: &config.KMSConfiguration{CacheSize: &positiveCacheSize, APIVersion: "v2"},
want: field.ErrorList{
field.Invalid(cacheField, positiveCacheSize, "cachesize is not supported in v2"),
},
},
} }
for _, tt := range testCases { for _, tt := range testCases {

View File

@ -598,7 +598,7 @@ func kmsPrefixTransformer(ctx context.Context, config *apiserverconfig.KMSConfig
// using AES-GCM by default for encrypting data with KMSv2 // using AES-GCM by default for encrypting data with KMSv2
transformer := value.PrefixTransformer{ transformer := value.PrefixTransformer{
Transformer: envelopekmsv2.NewEnvelopeTransformer(envelopeService, probe.getCurrentKeyID, int(*config.CacheSize), aestransformer.NewGCMTransformer), Transformer: envelopekmsv2.NewEnvelopeTransformer(envelopeService, probe.getCurrentKeyID, aestransformer.NewGCMTransformer),
Prefix: []byte(kmsTransformerPrefixV2 + kmsName + ":"), Prefix: []byte(kmsTransformerPrefixV2 + kmsName + ":"),
} }

View File

@ -18,7 +18,6 @@ resources:
apiVersion: v2 apiVersion: v2
name: testproviderv2 name: testproviderv2
endpoint: unix:///tmp/testprovider.sock endpoint: unix:///tmp/testprovider.sock
cachesize: 10
- identity: {} - identity: {}
- secretbox: - secretbox:
keys: keys:

View File

@ -22,7 +22,6 @@ resources:
apiVersion: v2 apiVersion: v2
name: testproviderv2 name: testproviderv2
endpoint: unix:///tmp/testprovider.sock endpoint: unix:///tmp/testprovider.sock
cachesize: 10
- aescbc: - aescbc:
keys: keys:
- name: key1 - name: key1

View File

@ -20,7 +20,6 @@ resources:
apiVersion: v2 apiVersion: v2
name: testproviderv2 name: testproviderv2
endpoint: unix:///tmp/testprovider.sock endpoint: unix:///tmp/testprovider.sock
cachesize: 10
- aescbc: - aescbc:
keys: keys:
- name: key1 - name: key1

View File

@ -12,7 +12,6 @@ resources:
apiVersion: v2 apiVersion: v2
name: testproviderv2 name: testproviderv2
endpoint: unix:///tmp/testprovider.sock endpoint: unix:///tmp/testprovider.sock
cachesize: 10
- secretbox: - secretbox:
keys: keys:
- name: key1 - name: key1

View File

@ -8,7 +8,6 @@ resources:
apiVersion: v2 apiVersion: v2
name: testproviderv2 name: testproviderv2
endpoint: unix:///tmp/testprovider.sock endpoint: unix:///tmp/testprovider.sock
cachesize: 10
- kms: - kms:
name: testprovider name: testprovider
endpoint: unix:///tmp/testprovider.sock endpoint: unix:///tmp/testprovider.sock

View File

@ -22,7 +22,6 @@ resources:
apiVersion: v2 apiVersion: v2
name: testproviderv2 name: testproviderv2
endpoint: unix:///tmp/testprovider.sock endpoint: unix:///tmp/testprovider.sock
cachesize: 10
- identity: {} - identity: {}
- aesgcm: - aesgcm:
keys: keys:

View File

@ -7,12 +7,10 @@ resources:
- kms: - kms:
apiVersion: v2 apiVersion: v2
name: kms-provider-1 name: kms-provider-1
cachesize: 1000
endpoint: unix:///@provider1.sock endpoint: unix:///@provider1.sock
- kms: - kms:
apiVersion: v2 apiVersion: v2
name: kms-provider-2 name: kms-provider-2
cachesize: 1000
endpoint: unix:///@provider2.sock endpoint: unix:///@provider2.sock
- kms: - kms:
apiVersion: v2 apiVersion: v2

View File

@ -0,0 +1,59 @@
/*
Copyright 2023 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Package kmsv2 transforms values for storage at rest using a Envelope v2 provider
package kmsv2
import (
"encoding/base64"
"time"
utilcache "k8s.io/apimachinery/pkg/util/cache"
"k8s.io/apiserver/pkg/storage/value"
"k8s.io/utils/clock"
)
type simpleCache struct {
cache *utilcache.Expiring
ttl time.Duration
}
func newSimpleCache(clock clock.Clock, ttl time.Duration) *simpleCache {
return &simpleCache{
cache: utilcache.NewExpiringWithClock(clock),
ttl: ttl,
}
}
// given a key, return the transformer, or nil if it does not exist in the cache
func (c *simpleCache) get(key []byte) value.Transformer {
record, ok := c.cache.Get(base64.StdEncoding.EncodeToString(key))
if !ok {
return nil
}
return record.(value.Transformer)
}
// set caches the record for the key
func (c *simpleCache) set(key []byte, transformer value.Transformer) {
if len(key) == 0 {
panic("key must not be empty")
}
if transformer == nil {
panic("transformer must not be nil")
}
c.cache.Set(base64.StdEncoding.EncodeToString(key), transformer, c.ttl)
}

View File

@ -0,0 +1,59 @@
/*
Copyright 2023 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Package kmsv2 transforms values for storage at rest using a Envelope v2 provider
package kmsv2
import (
"testing"
"time"
"k8s.io/apiserver/pkg/storage/value"
testingclock "k8s.io/utils/clock/testing"
)
func TestSimpleCacheSetError(t *testing.T) {
fakeClock := testingclock.NewFakeClock(time.Now())
cache := newSimpleCache(fakeClock, time.Second)
tests := []struct {
name string
key []byte
transformer value.Transformer
}{
{
name: "empty key",
key: []byte{},
transformer: nil,
},
{
name: "nil transformer",
key: []byte("key"),
transformer: nil,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
defer func() {
if r := recover(); r == nil {
t.Errorf("The code did not panic")
}
}()
cache.set(test.key, test.transformer)
})
}
}

View File

@ -22,7 +22,6 @@ import (
"crypto/aes" "crypto/aes"
"crypto/cipher" "crypto/cipher"
"crypto/rand" "crypto/rand"
"encoding/base64"
"fmt" "fmt"
"time" "time"
@ -36,7 +35,7 @@ import (
"k8s.io/apiserver/pkg/storage/value/encrypt/envelope/metrics" "k8s.io/apiserver/pkg/storage/value/encrypt/envelope/metrics"
"k8s.io/klog/v2" "k8s.io/klog/v2"
kmsservice "k8s.io/kms/service" kmsservice "k8s.io/kms/service"
"k8s.io/utils/lru" "k8s.io/utils/clock"
) )
const ( const (
@ -48,6 +47,8 @@ const (
keyIDMaxSize = 1 * 1024 // 1 kB keyIDMaxSize = 1 * 1024 // 1 kB
// encryptedDEKMaxSize is the maximum size of the encrypted DEK. // encryptedDEKMaxSize is the maximum size of the encrypted DEK.
encryptedDEKMaxSize = 1 * 1024 // 1 kB encryptedDEKMaxSize = 1 * 1024 // 1 kB
// cacheTTL is the default time-to-live for the cache entry.
cacheTTL = 1 * time.Hour
) )
type KeyIDGetterFunc func(context.Context) (keyID string, err error) type KeyIDGetterFunc func(context.Context) (keyID string, err error)
@ -57,36 +58,25 @@ type envelopeTransformer struct {
keyIDGetter KeyIDGetterFunc keyIDGetter KeyIDGetterFunc
// transformers is a thread-safe LRU cache which caches decrypted DEKs indexed by their encrypted form.
transformers *lru.Cache
// baseTransformerFunc creates a new transformer for encrypting the data with the DEK. // baseTransformerFunc creates a new transformer for encrypting the data with the DEK.
baseTransformerFunc func(cipher.Block) value.Transformer baseTransformerFunc func(cipher.Block) value.Transformer
// cache is a thread-safe expiring lru cache which caches decrypted DEKs indexed by their encrypted form.
cacheSize int cache *simpleCache
cacheEnabled bool
} }
// NewEnvelopeTransformer returns a transformer which implements a KEK-DEK based envelope encryption scheme. // NewEnvelopeTransformer returns a transformer which implements a KEK-DEK based envelope encryption scheme.
// It uses envelopeService to encrypt and decrypt DEKs. Respective DEKs (in encrypted form) are prepended to // It uses envelopeService to encrypt and decrypt DEKs. Respective DEKs (in encrypted form) are prepended to
// the data items they encrypt. A cache (of size cacheSize) is maintained to store the most recently // the data items they encrypt.
// used decrypted DEKs in memory. func NewEnvelopeTransformer(envelopeService kmsservice.Service, keyIDGetter KeyIDGetterFunc, baseTransformerFunc func(cipher.Block) value.Transformer) value.Transformer {
func NewEnvelopeTransformer(envelopeService kmsservice.Service, keyIDGetter KeyIDGetterFunc, cacheSize int, baseTransformerFunc func(cipher.Block) value.Transformer) value.Transformer { return newEnvelopeTransformerWithClock(envelopeService, keyIDGetter, baseTransformerFunc, cacheTTL, clock.RealClock{})
var cache *lru.Cache
if cacheSize > 0 {
// TODO(aramase): Switch to using expiring cache: kubernetes/kubernetes/staging/src/k8s.io/apimachinery/pkg/util/cache/expiring.go.
// It handles scans a lot better, doesn't have to be right sized, and don't have a global lock on reads.
cache = lru.New(cacheSize)
} }
func newEnvelopeTransformerWithClock(envelopeService kmsservice.Service, keyIDGetter KeyIDGetterFunc, baseTransformerFunc func(cipher.Block) value.Transformer, cacheTTL time.Duration, clock clock.Clock) value.Transformer {
return &envelopeTransformer{ return &envelopeTransformer{
envelopeService: envelopeService, envelopeService: envelopeService,
keyIDGetter: keyIDGetter, keyIDGetter: keyIDGetter,
transformers: cache, cache: newSimpleCache(clock, cacheTTL),
baseTransformerFunc: baseTransformerFunc, baseTransformerFunc: baseTransformerFunc,
cacheEnabled: cacheSize > 0,
cacheSize: cacheSize,
} }
} }
@ -101,11 +91,10 @@ func (t *envelopeTransformer) TransformFromStorage(ctx context.Context, data []b
} }
// Look up the decrypted DEK from cache or Envelope. // Look up the decrypted DEK from cache or Envelope.
transformer := t.getTransformer(encryptedObject.EncryptedDEK) transformer := t.cache.get(encryptedObject.EncryptedDEK)
if transformer == nil { if transformer == nil {
if t.cacheEnabled {
value.RecordCacheMiss() value.RecordCacheMiss()
}
uid := string(uuid.NewUUID()) uid := string(uuid.NewUUID())
klog.V(6).InfoS("Decrypting content using envelope service", "uid", uid, "key", string(dataCtx.AuthenticatedData())) klog.V(6).InfoS("Decrypting content using envelope service", "uid", uid, "key", string(dataCtx.AuthenticatedData()))
key, err := t.envelopeService.Decrypt(ctx, uid, &kmsservice.DecryptRequest{ key, err := t.envelopeService.Decrypt(ctx, uid, &kmsservice.DecryptRequest{
@ -189,28 +178,11 @@ func (t *envelopeTransformer) addTransformer(encKey []byte, key []byte) (value.T
return nil, err return nil, err
} }
transformer := t.baseTransformerFunc(block) transformer := t.baseTransformerFunc(block)
// Use base64 of encKey as the key into the cache because hashicorp/golang-lru // TODO(aramase): Add metrics for cache fill percentage with custom cache implementation.
// cannot hash []uint8. t.cache.set(encKey, transformer)
if t.cacheEnabled {
t.transformers.Add(base64.StdEncoding.EncodeToString(encKey), transformer)
metrics.RecordDekCacheFillPercent(float64(t.transformers.Len()) / float64(t.cacheSize))
}
return transformer, nil return transformer, nil
} }
// getTransformer fetches the transformer corresponding to encKey from cache, if it exists.
func (t *envelopeTransformer) getTransformer(encKey []byte) value.Transformer {
if !t.cacheEnabled {
return nil
}
_transformer, found := t.transformers.Get(base64.StdEncoding.EncodeToString(encKey))
if found {
return _transformer.(value.Transformer)
}
return nil
}
// doEncode encodes the EncryptedObject to a byte array. // doEncode encodes the EncryptedObject to a byte array.
func (t *envelopeTransformer) doEncode(request *kmstypes.EncryptedObject) ([]byte, error) { func (t *envelopeTransformer) doEncode(request *kmstypes.EncryptedObject) ([]byte, error) {
if err := validateEncryptedObject(request); err != nil { if err := validateEncryptedObject(request); err != nil {

View File

@ -26,18 +26,20 @@ import (
"strconv" "strconv"
"strings" "strings"
"testing" "testing"
"time"
"k8s.io/apiserver/pkg/storage/value" "k8s.io/apiserver/pkg/storage/value"
aestransformer "k8s.io/apiserver/pkg/storage/value/encrypt/aes" aestransformer "k8s.io/apiserver/pkg/storage/value/encrypt/aes"
kmstypes "k8s.io/apiserver/pkg/storage/value/encrypt/envelope/kmsv2/v2alpha1" kmstypes "k8s.io/apiserver/pkg/storage/value/encrypt/envelope/kmsv2/v2alpha1"
kmsservice "k8s.io/kms/service" kmsservice "k8s.io/kms/service"
testingclock "k8s.io/utils/clock/testing"
) )
const ( const (
testText = "abcdefghijklmnopqrstuvwxyz" testText = "abcdefghijklmnopqrstuvwxyz"
testContextText = "0123456789" testContextText = "0123456789"
testEnvelopeCacheSize = 10
testKeyVersion = "1" testKeyVersion = "1"
testCacheTTL = 10 * time.Second
) )
// testEnvelopeService is a mock Envelope service which can be used to simulate remote Envelope services // testEnvelopeService is a mock Envelope service which can be used to simulate remote Envelope services
@ -109,36 +111,33 @@ func newTestEnvelopeService() *testEnvelopeService {
func TestEnvelopeCaching(t *testing.T) { func TestEnvelopeCaching(t *testing.T) {
testCases := []struct { testCases := []struct {
desc string desc string
cacheSize int cacheTTL time.Duration
simulateKMSPluginFailure bool simulateKMSPluginFailure bool
expectedError string expectedError string
}{ }{
{ {
desc: "positive cache size should withstand plugin failure", desc: "entry in cache should withstand plugin failure",
cacheSize: 1000, cacheTTL: 5 * time.Minute,
simulateKMSPluginFailure: true, simulateKMSPluginFailure: true,
}, },
{ {
desc: "cache disabled size should not withstand plugin failure", desc: "cache entry expired should not withstand plugin failure",
cacheSize: 0, cacheTTL: 1 * time.Millisecond,
simulateKMSPluginFailure: true, simulateKMSPluginFailure: true,
expectedError: "failed to decrypt DEK, error: Envelope service was disabled", expectedError: "failed to decrypt DEK, error: Envelope service was disabled",
}, },
{
desc: "cache disabled, no plugin failure should succeed",
cacheSize: 0,
simulateKMSPluginFailure: false,
},
} }
for _, tt := range testCases { for _, tt := range testCases {
t.Run(tt.desc, func(t *testing.T) { t.Run(tt.desc, func(t *testing.T) {
envelopeService := newTestEnvelopeService() envelopeService := newTestEnvelopeService()
envelopeTransformer := NewEnvelopeTransformer(envelopeService, fakeClock := testingclock.NewFakeClock(time.Now())
envelopeTransformer := newEnvelopeTransformerWithClock(envelopeService,
func(ctx context.Context) (string, error) { func(ctx context.Context) (string, error) {
return "", nil return "", nil
}, },
tt.cacheSize, aestransformer.NewGCMTransformer) aestransformer.NewGCMTransformer, tt.cacheTTL, fakeClock)
ctx := context.Background() ctx := context.Background()
dataCtx := value.DefaultContext([]byte(testContextText)) dataCtx := value.DefaultContext([]byte(testContextText))
originalText := []byte(testText) originalText := []byte(testText)
@ -156,6 +155,8 @@ func TestEnvelopeCaching(t *testing.T) {
} }
envelopeService.SetDisabledStatus(tt.simulateKMSPluginFailure) envelopeService.SetDisabledStatus(tt.simulateKMSPluginFailure)
fakeClock.Step(2 * time.Minute)
// Subsequent read for the same data should work fine due to caching.
untransformedData, _, err = envelopeTransformer.TransformFromStorage(ctx, transformedData, dataCtx) untransformedData, _, err = envelopeTransformer.TransformFromStorage(ctx, transformedData, dataCtx)
if tt.expectedError != "" { if tt.expectedError != "" {
if err == nil { if err == nil {
@ -176,45 +177,6 @@ func TestEnvelopeCaching(t *testing.T) {
} }
} }
// Makes Envelope transformer hit cache limit, throws error if it misbehaves.
func TestEnvelopeCacheLimit(t *testing.T) {
envelopeTransformer := NewEnvelopeTransformer(newTestEnvelopeService(),
func(ctx context.Context) (string, error) {
return "", nil
},
testEnvelopeCacheSize, aestransformer.NewGCMTransformer)
ctx := context.Background()
dataCtx := value.DefaultContext([]byte(testContextText))
transformedOutputs := map[int][]byte{}
// Overwrite lots of entries in the map
for i := 0; i < 2*testEnvelopeCacheSize; i++ {
numberText := []byte(strconv.Itoa(i))
res, err := envelopeTransformer.TransformToStorage(ctx, numberText, dataCtx)
transformedOutputs[i] = res
if err != nil {
t.Fatalf("envelopeTransformer: error while transforming data (%v) to storage: %s", numberText, err)
}
}
// Try reading all the data now, ensuring cache misses don't cause a concern.
for i := 0; i < 2*testEnvelopeCacheSize; i++ {
numberText := []byte(strconv.Itoa(i))
output, _, err := envelopeTransformer.TransformFromStorage(ctx, transformedOutputs[i], dataCtx)
if err != nil {
t.Fatalf("envelopeTransformer: error while transforming data (%v) from storage: %s", transformedOutputs[i], err)
}
if !bytes.Equal(numberText, output) {
t.Fatalf("envelopeTransformer transformed data incorrectly using cache. Expected: %v, got %v", numberText, output)
}
}
}
// Test keyIDGetter as part of envelopeTransformer, throws error if returned err or staleness is incorrect. // Test keyIDGetter as part of envelopeTransformer, throws error if returned err or staleness is incorrect.
func TestEnvelopeTransformerKeyIDGetter(t *testing.T) { func TestEnvelopeTransformerKeyIDGetter(t *testing.T) {
t.Parallel() t.Parallel()
@ -253,7 +215,7 @@ func TestEnvelopeTransformerKeyIDGetter(t *testing.T) {
func(ctx context.Context) (string, error) { func(ctx context.Context) (string, error) {
return tt.testKeyID, tt.testErr return tt.testKeyID, tt.testErr
}, },
0, aestransformer.NewGCMTransformer) aestransformer.NewGCMTransformer)
ctx := context.Background() ctx := context.Background()
dataCtx := value.DefaultContext([]byte(testContextText)) dataCtx := value.DefaultContext([]byte(testContextText))
@ -321,7 +283,7 @@ func TestTransformToStorageError(t *testing.T) {
func(ctx context.Context) (string, error) { func(ctx context.Context) (string, error) {
return "", nil return "", nil
}, },
0, aestransformer.NewGCMTransformer) aestransformer.NewGCMTransformer)
ctx := context.Background() ctx := context.Background()
dataCtx := value.DefaultContext([]byte(testContextText)) dataCtx := value.DefaultContext([]byte(testContextText))

View File

@ -125,7 +125,6 @@ resources:
- kms: - kms:
apiVersion: v2 apiVersion: v2
name: kms-provider name: kms-provider
cachesize: 1000
endpoint: unix:///@kms-provider.sock endpoint: unix:///@kms-provider.sock
` `
@ -229,7 +228,6 @@ resources:
- kms: - kms:
apiVersion: v2 apiVersion: v2
name: kms-provider name: kms-provider
cachesize: 1000
endpoint: unix:///@kms-provider.sock endpoint: unix:///@kms-provider.sock
` `
pluginMock, err := kmsv2mock.NewBase64Plugin("@kms-provider.sock") pluginMock, err := kmsv2mock.NewBase64Plugin("@kms-provider.sock")
@ -442,7 +440,6 @@ resources:
- kms: - kms:
apiVersion: v2 apiVersion: v2
name: kms-provider name: kms-provider
cachesize: 1000
endpoint: unix:///@kms-provider.sock endpoint: unix:///@kms-provider.sock
` `