mirror of
https://github.com/k3s-io/kubernetes.git
synced 2025-07-21 10:51:29 +00:00
[KMSv2] store hash of encrypted DEK as key in cache
Signed-off-by: Anish Ramasekar <anish.ramasekar@gmail.com>
This commit is contained in:
parent
d35da348c6
commit
f72cf5c510
@ -18,8 +18,11 @@ limitations under the License.
|
||||
package kmsv2
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"crypto/sha256"
|
||||
"hash"
|
||||
"sync"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
utilcache "k8s.io/apimachinery/pkg/util/cache"
|
||||
"k8s.io/apiserver/pkg/storage/value"
|
||||
@ -29,18 +32,26 @@ import (
|
||||
type simpleCache struct {
|
||||
cache *utilcache.Expiring
|
||||
ttl time.Duration
|
||||
// hashPool is a per cache pool of hash.Hash (to avoid allocations from building the Hash)
|
||||
// SHA-256 is used to prevent collisions
|
||||
hashPool *sync.Pool
|
||||
}
|
||||
|
||||
func newSimpleCache(clock clock.Clock, ttl time.Duration) *simpleCache {
|
||||
return &simpleCache{
|
||||
cache: utilcache.NewExpiringWithClock(clock),
|
||||
ttl: ttl,
|
||||
hashPool: &sync.Pool{
|
||||
New: func() interface{} {
|
||||
return sha256.New()
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// given a key, return the transformer, or nil if it does not exist in the cache
|
||||
func (c *simpleCache) get(key []byte) value.Transformer {
|
||||
record, ok := c.cache.Get(base64.StdEncoding.EncodeToString(key))
|
||||
record, ok := c.cache.Get(c.keyFunc(key))
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
@ -55,5 +66,25 @@ func (c *simpleCache) set(key []byte, transformer value.Transformer) {
|
||||
if transformer == nil {
|
||||
panic("transformer must not be nil")
|
||||
}
|
||||
c.cache.Set(base64.StdEncoding.EncodeToString(key), transformer, c.ttl)
|
||||
c.cache.Set(c.keyFunc(key), transformer, c.ttl)
|
||||
}
|
||||
|
||||
// keyFunc generates a string key by hashing the inputs.
|
||||
// This lowers the memory requirement of the cache.
|
||||
func (c *simpleCache) keyFunc(s []byte) string {
|
||||
h := c.hashPool.Get().(hash.Hash)
|
||||
h.Reset()
|
||||
|
||||
if _, err := h.Write(s); err != nil {
|
||||
panic(err) // Write() on hash never fails
|
||||
}
|
||||
key := toString(h.Sum(nil)) // skip base64 encoding to save an allocation
|
||||
c.hashPool.Put(h)
|
||||
|
||||
return key
|
||||
}
|
||||
|
||||
// toString performs unholy acts to avoid allocations
|
||||
func toString(b []byte) string {
|
||||
return *(*string)(unsafe.Pointer(&b))
|
||||
}
|
||||
|
@ -18,6 +18,9 @@ limitations under the License.
|
||||
package kmsv2
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -57,3 +60,75 @@ func TestSimpleCacheSetError(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestKeyFunc(t *testing.T) {
|
||||
fakeClock := testingclock.NewFakeClock(time.Now())
|
||||
cache := newSimpleCache(fakeClock, time.Second)
|
||||
|
||||
t.Run("AllocsPerRun test", func(t *testing.T) {
|
||||
key, err := generateKey(encryptedDEKMaxSize) // simulate worst case EDEK
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
f := func() {
|
||||
out := cache.keyFunc(key)
|
||||
if len(out) != sha256.Size {
|
||||
t.Errorf("Expected %d bytes, got %d", sha256.Size, len(out))
|
||||
}
|
||||
}
|
||||
|
||||
// prime the key func
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 100; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
f()
|
||||
wg.Done()
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
allocs := testing.AllocsPerRun(100, f)
|
||||
if allocs > 1 {
|
||||
t.Errorf("Expected 1 allocations, got %v", allocs)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSimpleCache(t *testing.T) {
|
||||
fakeClock := testingclock.NewFakeClock(time.Now())
|
||||
cache := newSimpleCache(fakeClock, 5*time.Second)
|
||||
envelopeTransformer := &envelopeTransformer{}
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
for i := 0; i < 10; i++ {
|
||||
k := fmt.Sprintf("key-%d", i)
|
||||
wg.Add(1)
|
||||
go func(key string) {
|
||||
defer wg.Done()
|
||||
cache.set([]byte(key), envelopeTransformer)
|
||||
}(k)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
if cache.cache.Len() != 10 {
|
||||
t.Fatalf("Expected 10 items in the cache, got %v", cache.cache.Len())
|
||||
}
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
k := fmt.Sprintf("key-%d", i)
|
||||
if cache.get([]byte(k)) != envelopeTransformer {
|
||||
t.Fatalf("Expected to get the transformer for key %v", k)
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for the cache to expire
|
||||
fakeClock.Step(6 * time.Second)
|
||||
for i := 0; i < 10; i++ {
|
||||
k := fmt.Sprintf("key-%d", i)
|
||||
if cache.get([]byte(k)) != nil {
|
||||
t.Fatalf("Expected to get nil for key %v", k)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user