From b10697c7880848d7ec110fd6b3e67015bbe74fa8 Mon Sep 17 00:00:00 2001 From: Monis Khan Date: Sun, 27 Aug 2023 15:14:04 -0400 Subject: [PATCH] kmsv2: fix race in simpleCache.set when setting cache size metric Signed-off-by: Monis Khan --- .../value/encrypt/envelope/kmsv2/cache.go | 13 +++-- .../encrypt/envelope/kmsv2/cache_test.go | 51 ++++++++++++++++++ .../kmsv2_transformation_test.go | 53 ++++++++++++++++++- 3 files changed, 111 insertions(+), 6 deletions(-) diff --git a/staging/src/k8s.io/apiserver/pkg/storage/value/encrypt/envelope/kmsv2/cache.go b/staging/src/k8s.io/apiserver/pkg/storage/value/encrypt/envelope/kmsv2/cache.go index be7a2a7f1a6..bc7f04b9c6b 100644 --- a/staging/src/k8s.io/apiserver/pkg/storage/value/encrypt/envelope/kmsv2/cache.go +++ b/staging/src/k8s.io/apiserver/pkg/storage/value/encrypt/envelope/kmsv2/cache.go @@ -39,8 +39,10 @@ type simpleCache struct { ttl time.Duration // hashPool is a per cache pool of hash.Hash (to avoid allocations from building the Hash) // SHA-256 is used to prevent collisions - hashPool *sync.Pool - providerName string + hashPool *sync.Pool + providerName string + mu sync.Mutex // guards call to set + recordCacheSize func(providerName string, size int) // for unit tests } func newSimpleCache(clock clock.Clock, ttl time.Duration, providerName string) *simpleCache { @@ -54,7 +56,8 @@ func newSimpleCache(clock clock.Clock, ttl time.Duration, providerName string) * return sha256.New() }, }, - providerName: providerName, + providerName: providerName, + recordCacheSize: metrics.RecordDekSourceCacheSize, } } @@ -69,6 +72,8 @@ func (c *simpleCache) get(key []byte) value.Read { // set caches the record for the key func (c *simpleCache) set(key []byte, transformer value.Read) { + c.mu.Lock() + defer c.mu.Unlock() if len(key) == 0 { panic("key must not be empty") } @@ -77,7 +82,7 @@ func (c *simpleCache) set(key []byte, transformer value.Read) { } c.cache.Set(c.keyFunc(key), transformer, c.ttl) // Add metrics for cache size - metrics.RecordDekSourceCacheSize(c.providerName, c.cache.Len()) + c.recordCacheSize(c.providerName, c.cache.Len()) } // keyFunc generates a string key by hashing the inputs. diff --git a/staging/src/k8s.io/apiserver/pkg/storage/value/encrypt/envelope/kmsv2/cache_test.go b/staging/src/k8s.io/apiserver/pkg/storage/value/encrypt/envelope/kmsv2/cache_test.go index 1f686170eab..b79294d31a1 100644 --- a/staging/src/k8s.io/apiserver/pkg/storage/value/encrypt/envelope/kmsv2/cache_test.go +++ b/staging/src/k8s.io/apiserver/pkg/storage/value/encrypt/envelope/kmsv2/cache_test.go @@ -22,9 +22,11 @@ import ( "crypto/sha256" "fmt" "sync" + "sync/atomic" "testing" "time" + "k8s.io/apimachinery/pkg/util/sets" "k8s.io/apiserver/pkg/storage/value" testingclock "k8s.io/utils/clock/testing" ) @@ -153,3 +155,52 @@ func generateKey(length int) (key []byte, err error) { } return key, nil } + +func TestMetrics(t *testing.T) { + fakeClock := testingclock.NewFakeClock(time.Now()) + cache := newSimpleCache(fakeClock, 5*time.Second, "panda") + var record sync.Map + var cacheSize atomic.Uint64 + cache.recordCacheSize = func(providerName string, size int) { + if providerName != "panda" { + t.Errorf(`expected "panda" as provider name, got %q`, providerName) + } + if _, loaded := record.LoadOrStore(size, nil); loaded { + t.Errorf("detected duplicated cache size metric for %d", size) + } + newSize := uint64(size) + oldSize := cacheSize.Swap(newSize) + if oldSize > newSize { + t.Errorf("cache size decreased from %d to %d", oldSize, newSize) + } + } + transformer := &envelopeTransformer{} + + want := sets.NewInt() + startCh := make(chan struct{}) + wg := sync.WaitGroup{} + for i := 0; i < 100; i++ { + want.Insert(i + 1) + k := fmt.Sprintf("key-%d", i) + wg.Add(1) + go func(key string) { + defer wg.Done() + <-startCh + cache.set([]byte(key), transformer) + }(k) + } + close(startCh) + wg.Wait() + + got := sets.NewInt() + record.Range(func(key, value any) bool { + got.Insert(key.(int)) + if value != nil { + t.Errorf("expected value to be nil but got %v", value) + } + return true + }) + if !want.Equal(got) { + t.Errorf("cache size entries missing values: %v", want.SymmetricDifference(got).List()) + } +} diff --git a/test/integration/controlplane/transformation/kmsv2_transformation_test.go b/test/integration/controlplane/transformation/kmsv2_transformation_test.go index ebdb295e5be..a39bdbdd7de 100644 --- a/test/integration/controlplane/transformation/kmsv2_transformation_test.go +++ b/test/integration/controlplane/transformation/kmsv2_transformation_test.go @@ -27,17 +27,20 @@ import ( "encoding/binary" "fmt" "io" + "regexp" "strings" "testing" "time" "github.com/gogo/protobuf/proto" + "github.com/google/go-cmp/cmp" clientv3 "go.etcd.io/etcd/client/v3" corev1 "k8s.io/api/core/v1" apiextensionsclientset "k8s.io/apiextensions-apiserver/pkg/client/clientset/clientset" "k8s.io/apimachinery/pkg/api/meta" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured/unstructuredscheme" "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/runtime/schema" utilrand "k8s.io/apimachinery/pkg/util/rand" @@ -58,6 +61,7 @@ import ( utilfeature "k8s.io/apiserver/pkg/util/feature" "k8s.io/client-go/dynamic" "k8s.io/client-go/kubernetes" + "k8s.io/client-go/rest" featuregatetesting "k8s.io/component-base/featuregate/testing" "k8s.io/klog/v2" kmsv2api "k8s.io/kms/apis/v2" @@ -201,6 +205,53 @@ resources: } defer test.cleanUp() + ctx := testContext(t) + + // the global metrics registry persists across test runs - reset it here so we can make assertions + copyConfig := rest.CopyConfig(test.kubeAPIServer.ClientConfig) + copyConfig.GroupVersion = &schema.GroupVersion{} + copyConfig.NegotiatedSerializer = unstructuredscheme.NewUnstructuredNegotiatedSerializer() + rc, err := rest.RESTClientFor(copyConfig) + if err != nil { + t.Fatal(err) + } + if err := rc.Delete().AbsPath("/metrics").Do(ctx).Error(); err != nil { + t.Fatal(err) + } + + // assert that the metrics we collect during the test run match expectations + wantMetricStrings := []string{ + `apiserver_envelope_encryption_dek_source_cache_size{provider_name="kms-provider"} 1`, + `apiserver_envelope_encryption_key_id_hash_last_timestamp_seconds{key_id_hash="sha256:6b86b273ff34fce19d6b804eff5a3f5747ada4eaa22f1d49c01e52ddb7875b4b",provider_name="kms-provider",transformation_type="from_storage"} FP`, + `apiserver_envelope_encryption_key_id_hash_last_timestamp_seconds{key_id_hash="sha256:6b86b273ff34fce19d6b804eff5a3f5747ada4eaa22f1d49c01e52ddb7875b4b",provider_name="kms-provider",transformation_type="to_storage"} FP`, + `apiserver_envelope_encryption_key_id_hash_total{key_id_hash="sha256:6b86b273ff34fce19d6b804eff5a3f5747ada4eaa22f1d49c01e52ddb7875b4b",provider_name="kms-provider",transformation_type="from_storage"} 2`, + `apiserver_envelope_encryption_key_id_hash_total{key_id_hash="sha256:6b86b273ff34fce19d6b804eff5a3f5747ada4eaa22f1d49c01e52ddb7875b4b",provider_name="kms-provider",transformation_type="to_storage"} 1`, + } + defer func() { + body, err := rc.Get().AbsPath("/metrics").DoRaw(ctx) + if err != nil { + t.Fatal(err) + } + var gotMetricStrings []string + trimFP := regexp.MustCompile(`(.*)(} \d+\.\d+.*)`) + for _, line := range strings.Split(string(body), "\n") { + if strings.HasPrefix(line, "apiserver_envelope_") { + if strings.HasPrefix(line, "apiserver_envelope_encryption_dek_cache_fill_percent") { + continue // this can be ignored as it is KMS v1 only + } + + if strings.Contains(line, "_seconds") { + line = trimFP.ReplaceAllString(line, `$1`) + "} FP" // ignore floating point metric values + } + + gotMetricStrings = append(gotMetricStrings, line) + } + } + if diff := cmp.Diff(wantMetricStrings, gotMetricStrings); diff != "" { + t.Errorf("unexpected metrics diff (-want +got): %s", diff) + } + }() + test.secret, err = test.createSecret(testSecret, testNamespace) if err != nil { t.Fatalf("Failed to create test secret, error: %v", err) @@ -226,8 +277,6 @@ resources: t.Fatalf("expected secret to be prefixed with %s, but got %s", wantPrefix, rawEnvelope) } - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() ciphertext, err := envelopeData.cipherTextDEKSource() if err != nil { t.Fatalf("failed to get ciphertext DEK/seed from KMSv2 Plugin: %v", err)