Merge pull request #120221 from enj/enj/i/kms_cache_metrics_lock

kmsv2: fix race in simpleCache.set when setting cache size metric
This commit is contained in:
Kubernetes Prow Robot 2023-09-01 10:00:31 -07:00 committed by GitHub
commit a99e377a54
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 111 additions and 6 deletions

View File

@ -39,8 +39,10 @@ type simpleCache struct {
ttl time.Duration ttl time.Duration
// hashPool is a per cache pool of hash.Hash (to avoid allocations from building the Hash) // hashPool is a per cache pool of hash.Hash (to avoid allocations from building the Hash)
// SHA-256 is used to prevent collisions // SHA-256 is used to prevent collisions
hashPool *sync.Pool hashPool *sync.Pool
providerName string providerName string
mu sync.Mutex // guards call to set
recordCacheSize func(providerName string, size int) // for unit tests
} }
func newSimpleCache(clock clock.Clock, ttl time.Duration, providerName string) *simpleCache { func newSimpleCache(clock clock.Clock, ttl time.Duration, providerName string) *simpleCache {
@ -54,7 +56,8 @@ func newSimpleCache(clock clock.Clock, ttl time.Duration, providerName string) *
return sha256.New() return sha256.New()
}, },
}, },
providerName: providerName, providerName: providerName,
recordCacheSize: metrics.RecordDekSourceCacheSize,
} }
} }
@ -69,6 +72,8 @@ func (c *simpleCache) get(key []byte) value.Read {
// set caches the record for the key // set caches the record for the key
func (c *simpleCache) set(key []byte, transformer value.Read) { func (c *simpleCache) set(key []byte, transformer value.Read) {
c.mu.Lock()
defer c.mu.Unlock()
if len(key) == 0 { if len(key) == 0 {
panic("key must not be empty") panic("key must not be empty")
} }
@ -77,7 +82,7 @@ func (c *simpleCache) set(key []byte, transformer value.Read) {
} }
c.cache.Set(c.keyFunc(key), transformer, c.ttl) c.cache.Set(c.keyFunc(key), transformer, c.ttl)
// Add metrics for cache size // Add metrics for cache size
metrics.RecordDekSourceCacheSize(c.providerName, c.cache.Len()) c.recordCacheSize(c.providerName, c.cache.Len())
} }
// keyFunc generates a string key by hashing the inputs. // keyFunc generates a string key by hashing the inputs.

View File

@ -22,9 +22,11 @@ import (
"crypto/sha256" "crypto/sha256"
"fmt" "fmt"
"sync" "sync"
"sync/atomic"
"testing" "testing"
"time" "time"
"k8s.io/apimachinery/pkg/util/sets"
"k8s.io/apiserver/pkg/storage/value" "k8s.io/apiserver/pkg/storage/value"
testingclock "k8s.io/utils/clock/testing" testingclock "k8s.io/utils/clock/testing"
) )
@ -153,3 +155,52 @@ func generateKey(length int) (key []byte, err error) {
} }
return key, nil return key, nil
} }
func TestMetrics(t *testing.T) {
fakeClock := testingclock.NewFakeClock(time.Now())
cache := newSimpleCache(fakeClock, 5*time.Second, "panda")
var record sync.Map
var cacheSize atomic.Uint64
cache.recordCacheSize = func(providerName string, size int) {
if providerName != "panda" {
t.Errorf(`expected "panda" as provider name, got %q`, providerName)
}
if _, loaded := record.LoadOrStore(size, nil); loaded {
t.Errorf("detected duplicated cache size metric for %d", size)
}
newSize := uint64(size)
oldSize := cacheSize.Swap(newSize)
if oldSize > newSize {
t.Errorf("cache size decreased from %d to %d", oldSize, newSize)
}
}
transformer := &envelopeTransformer{}
want := sets.NewInt()
startCh := make(chan struct{})
wg := sync.WaitGroup{}
for i := 0; i < 100; i++ {
want.Insert(i + 1)
k := fmt.Sprintf("key-%d", i)
wg.Add(1)
go func(key string) {
defer wg.Done()
<-startCh
cache.set([]byte(key), transformer)
}(k)
}
close(startCh)
wg.Wait()
got := sets.NewInt()
record.Range(func(key, value any) bool {
got.Insert(key.(int))
if value != nil {
t.Errorf("expected value to be nil but got %v", value)
}
return true
})
if !want.Equal(got) {
t.Errorf("cache size entries missing values: %v", want.SymmetricDifference(got).List())
}
}

View File

@ -27,17 +27,20 @@ import (
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"io" "io"
"regexp"
"strings" "strings"
"testing" "testing"
"time" "time"
"github.com/gogo/protobuf/proto" "github.com/gogo/protobuf/proto"
"github.com/google/go-cmp/cmp"
clientv3 "go.etcd.io/etcd/client/v3" clientv3 "go.etcd.io/etcd/client/v3"
corev1 "k8s.io/api/core/v1" corev1 "k8s.io/api/core/v1"
apiextensionsclientset "k8s.io/apiextensions-apiserver/pkg/client/clientset/clientset" apiextensionsclientset "k8s.io/apiextensions-apiserver/pkg/client/clientset/clientset"
"k8s.io/apimachinery/pkg/api/meta" "k8s.io/apimachinery/pkg/api/meta"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured/unstructuredscheme"
"k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/apimachinery/pkg/runtime/schema"
utilrand "k8s.io/apimachinery/pkg/util/rand" utilrand "k8s.io/apimachinery/pkg/util/rand"
@ -58,6 +61,7 @@ import (
utilfeature "k8s.io/apiserver/pkg/util/feature" utilfeature "k8s.io/apiserver/pkg/util/feature"
"k8s.io/client-go/dynamic" "k8s.io/client-go/dynamic"
"k8s.io/client-go/kubernetes" "k8s.io/client-go/kubernetes"
"k8s.io/client-go/rest"
featuregatetesting "k8s.io/component-base/featuregate/testing" featuregatetesting "k8s.io/component-base/featuregate/testing"
"k8s.io/klog/v2" "k8s.io/klog/v2"
kmsv2api "k8s.io/kms/apis/v2" kmsv2api "k8s.io/kms/apis/v2"
@ -201,6 +205,53 @@ resources:
} }
defer test.cleanUp() defer test.cleanUp()
ctx := testContext(t)
// the global metrics registry persists across test runs - reset it here so we can make assertions
copyConfig := rest.CopyConfig(test.kubeAPIServer.ClientConfig)
copyConfig.GroupVersion = &schema.GroupVersion{}
copyConfig.NegotiatedSerializer = unstructuredscheme.NewUnstructuredNegotiatedSerializer()
rc, err := rest.RESTClientFor(copyConfig)
if err != nil {
t.Fatal(err)
}
if err := rc.Delete().AbsPath("/metrics").Do(ctx).Error(); err != nil {
t.Fatal(err)
}
// assert that the metrics we collect during the test run match expectations
wantMetricStrings := []string{
`apiserver_envelope_encryption_dek_source_cache_size{provider_name="kms-provider"} 1`,
`apiserver_envelope_encryption_key_id_hash_last_timestamp_seconds{key_id_hash="sha256:6b86b273ff34fce19d6b804eff5a3f5747ada4eaa22f1d49c01e52ddb7875b4b",provider_name="kms-provider",transformation_type="from_storage"} FP`,
`apiserver_envelope_encryption_key_id_hash_last_timestamp_seconds{key_id_hash="sha256:6b86b273ff34fce19d6b804eff5a3f5747ada4eaa22f1d49c01e52ddb7875b4b",provider_name="kms-provider",transformation_type="to_storage"} FP`,
`apiserver_envelope_encryption_key_id_hash_total{key_id_hash="sha256:6b86b273ff34fce19d6b804eff5a3f5747ada4eaa22f1d49c01e52ddb7875b4b",provider_name="kms-provider",transformation_type="from_storage"} 2`,
`apiserver_envelope_encryption_key_id_hash_total{key_id_hash="sha256:6b86b273ff34fce19d6b804eff5a3f5747ada4eaa22f1d49c01e52ddb7875b4b",provider_name="kms-provider",transformation_type="to_storage"} 1`,
}
defer func() {
body, err := rc.Get().AbsPath("/metrics").DoRaw(ctx)
if err != nil {
t.Fatal(err)
}
var gotMetricStrings []string
trimFP := regexp.MustCompile(`(.*)(} \d+\.\d+.*)`)
for _, line := range strings.Split(string(body), "\n") {
if strings.HasPrefix(line, "apiserver_envelope_") {
if strings.HasPrefix(line, "apiserver_envelope_encryption_dek_cache_fill_percent") {
continue // this can be ignored as it is KMS v1 only
}
if strings.Contains(line, "_seconds") {
line = trimFP.ReplaceAllString(line, `$1`) + "} FP" // ignore floating point metric values
}
gotMetricStrings = append(gotMetricStrings, line)
}
}
if diff := cmp.Diff(wantMetricStrings, gotMetricStrings); diff != "" {
t.Errorf("unexpected metrics diff (-want +got): %s", diff)
}
}()
test.secret, err = test.createSecret(testSecret, testNamespace) test.secret, err = test.createSecret(testSecret, testNamespace)
if err != nil { if err != nil {
t.Fatalf("Failed to create test secret, error: %v", err) t.Fatalf("Failed to create test secret, error: %v", err)
@ -226,8 +277,6 @@ resources:
t.Fatalf("expected secret to be prefixed with %s, but got %s", wantPrefix, rawEnvelope) t.Fatalf("expected secret to be prefixed with %s, but got %s", wantPrefix, rawEnvelope)
} }
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
ciphertext, err := envelopeData.cipherTextDEKSource() ciphertext, err := envelopeData.cipherTextDEKSource()
if err != nil { if err != nil {
t.Fatalf("failed to get ciphertext DEK/seed from KMSv2 Plugin: %v", err) t.Fatalf("failed to get ciphertext DEK/seed from KMSv2 Plugin: %v", err)