[KMSv2] store hash of encrypted DEK as key in cache

Signed-off-by: Anish Ramasekar <anish.ramasekar@gmail.com>
This commit is contained in:
Anish Ramasekar 2023-01-27 00:19:32 +00:00
parent d35da348c6
commit f72cf5c510
No known key found for this signature in database
GPG Key ID: F1F7F3518F1ECB0C
2 changed files with 109 additions and 3 deletions

View File

@ -18,8 +18,11 @@ limitations under the License.
package kmsv2
import (
"encoding/base64"
"crypto/sha256"
"hash"
"sync"
"time"
"unsafe"
utilcache "k8s.io/apimachinery/pkg/util/cache"
"k8s.io/apiserver/pkg/storage/value"
@ -29,18 +32,26 @@ import (
type simpleCache struct {
cache *utilcache.Expiring
ttl time.Duration
// hashPool is a per cache pool of hash.Hash (to avoid allocations from building the Hash)
// SHA-256 is used to prevent collisions
hashPool *sync.Pool
}
func newSimpleCache(clock clock.Clock, ttl time.Duration) *simpleCache {
return &simpleCache{
cache: utilcache.NewExpiringWithClock(clock),
ttl: ttl,
hashPool: &sync.Pool{
New: func() interface{} {
return sha256.New()
},
},
}
}
// given a key, return the transformer, or nil if it does not exist in the cache
func (c *simpleCache) get(key []byte) value.Transformer {
record, ok := c.cache.Get(base64.StdEncoding.EncodeToString(key))
record, ok := c.cache.Get(c.keyFunc(key))
if !ok {
return nil
}
@ -55,5 +66,25 @@ func (c *simpleCache) set(key []byte, transformer value.Transformer) {
if transformer == nil {
panic("transformer must not be nil")
}
c.cache.Set(base64.StdEncoding.EncodeToString(key), transformer, c.ttl)
c.cache.Set(c.keyFunc(key), transformer, c.ttl)
}
// keyFunc generates a string key by hashing the inputs.
// This lowers the memory requirement of the cache.
func (c *simpleCache) keyFunc(s []byte) string {
h := c.hashPool.Get().(hash.Hash)
h.Reset()
if _, err := h.Write(s); err != nil {
panic(err) // Write() on hash never fails
}
key := toString(h.Sum(nil)) // skip base64 encoding to save an allocation
c.hashPool.Put(h)
return key
}
// toString performs unholy acts to avoid allocations
func toString(b []byte) string {
return *(*string)(unsafe.Pointer(&b))
}

View File

@ -18,6 +18,9 @@ limitations under the License.
package kmsv2
import (
"crypto/sha256"
"fmt"
"sync"
"testing"
"time"
@ -57,3 +60,75 @@ func TestSimpleCacheSetError(t *testing.T) {
})
}
}
func TestKeyFunc(t *testing.T) {
fakeClock := testingclock.NewFakeClock(time.Now())
cache := newSimpleCache(fakeClock, time.Second)
t.Run("AllocsPerRun test", func(t *testing.T) {
key, err := generateKey(encryptedDEKMaxSize) // simulate worst case EDEK
if err != nil {
t.Fatal(err)
}
f := func() {
out := cache.keyFunc(key)
if len(out) != sha256.Size {
t.Errorf("Expected %d bytes, got %d", sha256.Size, len(out))
}
}
// prime the key func
var wg sync.WaitGroup
for i := 0; i < 100; i++ {
wg.Add(1)
go func() {
f()
wg.Done()
}()
}
wg.Wait()
allocs := testing.AllocsPerRun(100, f)
if allocs > 1 {
t.Errorf("Expected 1 allocations, got %v", allocs)
}
})
}
func TestSimpleCache(t *testing.T) {
fakeClock := testingclock.NewFakeClock(time.Now())
cache := newSimpleCache(fakeClock, 5*time.Second)
envelopeTransformer := &envelopeTransformer{}
wg := sync.WaitGroup{}
for i := 0; i < 10; i++ {
k := fmt.Sprintf("key-%d", i)
wg.Add(1)
go func(key string) {
defer wg.Done()
cache.set([]byte(key), envelopeTransformer)
}(k)
}
wg.Wait()
if cache.cache.Len() != 10 {
t.Fatalf("Expected 10 items in the cache, got %v", cache.cache.Len())
}
for i := 0; i < 10; i++ {
k := fmt.Sprintf("key-%d", i)
if cache.get([]byte(k)) != envelopeTransformer {
t.Fatalf("Expected to get the transformer for key %v", k)
}
}
// Wait for the cache to expire
fakeClock.Step(6 * time.Second)
for i := 0; i < 10; i++ {
k := fmt.Sprintf("key-%d", i)
if cache.get([]byte(k)) != nil {
t.Fatalf("Expected to get nil for key %v", k)
}
}
}