From f72cf5c510cf2cf7b8ee375f5c2ec835e3ed225a Mon Sep 17 00:00:00 2001 From: Anish Ramasekar Date: Fri, 27 Jan 2023 00:19:32 +0000 Subject: [PATCH] [KMSv2] store hash of encrypted DEK as key in cache Signed-off-by: Anish Ramasekar --- .../value/encrypt/envelope/kmsv2/cache.go | 37 ++++++++- .../encrypt/envelope/kmsv2/cache_test.go | 75 +++++++++++++++++++ 2 files changed, 109 insertions(+), 3 deletions(-) diff --git a/staging/src/k8s.io/apiserver/pkg/storage/value/encrypt/envelope/kmsv2/cache.go b/staging/src/k8s.io/apiserver/pkg/storage/value/encrypt/envelope/kmsv2/cache.go index dd82e29270a..d2485c4e4b9 100644 --- a/staging/src/k8s.io/apiserver/pkg/storage/value/encrypt/envelope/kmsv2/cache.go +++ b/staging/src/k8s.io/apiserver/pkg/storage/value/encrypt/envelope/kmsv2/cache.go @@ -18,8 +18,11 @@ limitations under the License. package kmsv2 import ( - "encoding/base64" + "crypto/sha256" + "hash" + "sync" "time" + "unsafe" utilcache "k8s.io/apimachinery/pkg/util/cache" "k8s.io/apiserver/pkg/storage/value" @@ -29,18 +32,26 @@ import ( type simpleCache struct { cache *utilcache.Expiring ttl time.Duration + // hashPool is a per cache pool of hash.Hash (to avoid allocations from building the Hash) + // SHA-256 is used to prevent collisions + hashPool *sync.Pool } func newSimpleCache(clock clock.Clock, ttl time.Duration) *simpleCache { return &simpleCache{ cache: utilcache.NewExpiringWithClock(clock), ttl: ttl, + hashPool: &sync.Pool{ + New: func() interface{} { + return sha256.New() + }, + }, } } // given a key, return the transformer, or nil if it does not exist in the cache func (c *simpleCache) get(key []byte) value.Transformer { - record, ok := c.cache.Get(base64.StdEncoding.EncodeToString(key)) + record, ok := c.cache.Get(c.keyFunc(key)) if !ok { return nil } @@ -55,5 +66,25 @@ func (c *simpleCache) set(key []byte, transformer value.Transformer) { if transformer == nil { panic("transformer must not be nil") } - c.cache.Set(base64.StdEncoding.EncodeToString(key), transformer, c.ttl) + c.cache.Set(c.keyFunc(key), transformer, c.ttl) +} + +// keyFunc generates a string key by hashing the inputs. +// This lowers the memory requirement of the cache. +func (c *simpleCache) keyFunc(s []byte) string { + h := c.hashPool.Get().(hash.Hash) + h.Reset() + + if _, err := h.Write(s); err != nil { + panic(err) // Write() on hash never fails + } + key := toString(h.Sum(nil)) // skip base64 encoding to save an allocation + c.hashPool.Put(h) + + return key +} + +// toString performs unholy acts to avoid allocations +func toString(b []byte) string { + return *(*string)(unsafe.Pointer(&b)) } diff --git a/staging/src/k8s.io/apiserver/pkg/storage/value/encrypt/envelope/kmsv2/cache_test.go b/staging/src/k8s.io/apiserver/pkg/storage/value/encrypt/envelope/kmsv2/cache_test.go index 89c8c32e83d..f629fc68cab 100644 --- a/staging/src/k8s.io/apiserver/pkg/storage/value/encrypt/envelope/kmsv2/cache_test.go +++ b/staging/src/k8s.io/apiserver/pkg/storage/value/encrypt/envelope/kmsv2/cache_test.go @@ -18,6 +18,9 @@ limitations under the License. package kmsv2 import ( + "crypto/sha256" + "fmt" + "sync" "testing" "time" @@ -57,3 +60,75 @@ func TestSimpleCacheSetError(t *testing.T) { }) } } + +func TestKeyFunc(t *testing.T) { + fakeClock := testingclock.NewFakeClock(time.Now()) + cache := newSimpleCache(fakeClock, time.Second) + + t.Run("AllocsPerRun test", func(t *testing.T) { + key, err := generateKey(encryptedDEKMaxSize) // simulate worst case EDEK + if err != nil { + t.Fatal(err) + } + + f := func() { + out := cache.keyFunc(key) + if len(out) != sha256.Size { + t.Errorf("Expected %d bytes, got %d", sha256.Size, len(out)) + } + } + + // prime the key func + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + f() + wg.Done() + }() + } + wg.Wait() + + allocs := testing.AllocsPerRun(100, f) + if allocs > 1 { + t.Errorf("Expected 1 allocations, got %v", allocs) + } + }) +} + +func TestSimpleCache(t *testing.T) { + fakeClock := testingclock.NewFakeClock(time.Now()) + cache := newSimpleCache(fakeClock, 5*time.Second) + envelopeTransformer := &envelopeTransformer{} + + wg := sync.WaitGroup{} + for i := 0; i < 10; i++ { + k := fmt.Sprintf("key-%d", i) + wg.Add(1) + go func(key string) { + defer wg.Done() + cache.set([]byte(key), envelopeTransformer) + }(k) + } + wg.Wait() + + if cache.cache.Len() != 10 { + t.Fatalf("Expected 10 items in the cache, got %v", cache.cache.Len()) + } + + for i := 0; i < 10; i++ { + k := fmt.Sprintf("key-%d", i) + if cache.get([]byte(k)) != envelopeTransformer { + t.Fatalf("Expected to get the transformer for key %v", k) + } + } + + // Wait for the cache to expire + fakeClock.Step(6 * time.Second) + for i := 0; i < 10; i++ { + k := fmt.Sprintf("key-%d", i) + if cache.get([]byte(k)) != nil { + t.Fatalf("Expected to get nil for key %v", k) + } + } +}