diff --git a/staging/src/k8s.io/apiserver/pkg/authentication/token/cache/cached_token_authenticator.go b/staging/src/k8s.io/apiserver/pkg/authentication/token/cache/cached_token_authenticator.go index 2ffda6f1bb4..d94866d5baa 100644 --- a/staging/src/k8s.io/apiserver/pkg/authentication/token/cache/cached_token_authenticator.go +++ b/staging/src/k8s.io/apiserver/pkg/authentication/token/cache/cached_token_authenticator.go @@ -18,8 +18,15 @@ package cache import ( "context" - "fmt" + "crypto/hmac" + "crypto/rand" + "crypto/sha256" + "encoding/binary" + "hash" + "io" + "sync" "time" + "unsafe" utilclock "k8s.io/apimachinery/pkg/util/clock" "k8s.io/apiserver/pkg/authentication/authenticator" @@ -40,6 +47,11 @@ type cachedTokenAuthenticator struct { failureTTL time.Duration cache cache + + // hashPool is a per authenticator pool of hash.Hash (to avoid allocations from building the Hash) + // HMAC with SHA-256 and a random key is used to prevent precomputation and length extension attacks + // It also mitigates hash map DOS attacks via collisions (the inputs are supplied by untrusted users) + hashPool *sync.Pool } type cache interface { @@ -57,6 +69,11 @@ func New(authenticator authenticator.Token, cacheErrs bool, successTTL, failureT } func newWithClock(authenticator authenticator.Token, cacheErrs bool, successTTL, failureTTL time.Duration, clock utilclock.Clock) authenticator.Token { + randomCacheKey := make([]byte, 32) + if _, err := rand.Read(randomCacheKey); err != nil { + panic(err) // rand should never fail + } + return &cachedTokenAuthenticator{ authenticator: authenticator, cacheErrs: cacheErrs, @@ -70,6 +87,12 @@ func newWithClock(authenticator authenticator.Token, cacheErrs bool, successTTL, // namespaces; a 32k entry cache is therefore a 2x safety // margin. cache: newStripedCache(32, fnvHashFunc, func() cache { return newSimpleCache(1024, clock) }), + + hashPool: &sync.Pool{ + New: func() interface{} { + return hmac.New(sha256.New, randomCacheKey) + }, + }, } } @@ -77,7 +100,7 @@ func newWithClock(authenticator authenticator.Token, cacheErrs bool, successTTL, func (a *cachedTokenAuthenticator) AuthenticateToken(ctx context.Context, token string) (*authenticator.Response, bool, error) { auds, _ := authenticator.AudiencesFrom(ctx) - key := keyFunc(auds, token) + key := keyFunc(a.hashPool, auds, token) if record, ok := a.cache.get(key); ok { return record.resp, record.ok, record.err } @@ -97,6 +120,55 @@ func (a *cachedTokenAuthenticator) AuthenticateToken(ctx context.Context, token return resp, ok, err } -func keyFunc(auds []string, token string) string { - return fmt.Sprintf("%#v|%v", auds, token) +// keyFunc generates a string key by hashing the inputs. +// This lowers the memory requirement of the cache and keeps tokens out of memory. +func keyFunc(hashPool *sync.Pool, auds []string, token string) string { + h := hashPool.Get().(hash.Hash) + + h.Reset() + + // try to force stack allocation + var a [4]byte + b := a[:] + + writeLengthPrefixedString(h, b, token) + // encode the length of audiences to avoid ambiguities + writeLength(h, b, len(auds)) + for _, aud := range auds { + writeLengthPrefixedString(h, b, aud) + } + + key := toString(h.Sum(nil)) // skip base64 encoding to save an allocation + + hashPool.Put(h) + + return key +} + +// writeLengthPrefixedString writes s with a length prefix to prevent ambiguities, i.e. "xy" + "z" == "x" + "yz" +// the length of b is assumed to be 4 (b is mutated by this function to store the length of s) +func writeLengthPrefixedString(w io.Writer, b []byte, s string) { + writeLength(w, b, len(s)) + if _, err := w.Write(toBytes(s)); err != nil { + panic(err) // Write() on hash never fails + } +} + +// writeLength encodes length into b and then writes it via the given writer +// the length of b is assumed to be 4 +func writeLength(w io.Writer, b []byte, length int) { + binary.BigEndian.PutUint32(b, uint32(length)) + if _, err := w.Write(b); err != nil { + panic(err) // Write() on hash never fails + } +} + +// toBytes performs unholy acts to avoid allocations +func toBytes(s string) []byte { + return *(*[]byte)(unsafe.Pointer(&s)) +} + +// toString performs unholy acts to avoid allocations +func toString(b []byte) string { + return *(*string)(unsafe.Pointer(&b)) } diff --git a/staging/src/k8s.io/apiserver/pkg/authentication/token/cache/cached_token_authenticator_test.go b/staging/src/k8s.io/apiserver/pkg/authentication/token/cache/cached_token_authenticator_test.go index 09901b40033..291fb4e3212 100644 --- a/staging/src/k8s.io/apiserver/pkg/authentication/token/cache/cached_token_authenticator_test.go +++ b/staging/src/k8s.io/apiserver/pkg/authentication/token/cache/cached_token_authenticator_test.go @@ -18,7 +18,11 @@ package cache import ( "context" + "crypto/hmac" + "crypto/rand" + "crypto/sha256" "reflect" + "sync" "testing" "time" @@ -125,3 +129,39 @@ func TestCachedTokenAuthenticatorWithAudiences(t *testing.T) { t.Errorf("Expected user1-different") } } + +var bKey string + +// use a realistic token for benchmarking +const jwtToken = `eyJhbGciOiJSUzI1NiIsImtpZCI6IiJ9.eyJpc3MiOiJrdWJlcm5ldGVzL3NlcnZpY2VhY2NvdW50Iiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9uYW1lc3BhY2UiOiJvcGVuc2hpZnQtc2RuIiwia3ViZXJuZXRlcy5pby9zZXJ2aWNlYWNjb3VudC9zZWNyZXQubmFtZSI6InNkbi10b2tlbi1nNndtYyIsImt1YmVybmV0ZXMuaW8vc2VydmljZWFjY291bnQvc2VydmljZS1hY2NvdW50Lm5hbWUiOiJzZG4iLCJrdWJlcm5ldGVzLmlvL3NlcnZpY2VhY2NvdW50L3NlcnZpY2UtYWNjb3VudC51aWQiOiIzYzM5YzNhYS1kM2Q5LTExZTktYTVkMC0wMmI3YjllODg1OWUiLCJzdWIiOiJzeXN0ZW06c2VydmljZWFjY291bnQ6b3BlbnNoaWZ0LXNkbjpzZG4ifQ.PIs0rsUTekj5AX8yJeLDyW4vQB17YS4IOgO026yjEvsCY7Wv_2TD0lwyZWqyQh639q3jPh2_3LTQq2Cp0cReBP1PYOIGgprNm3C-3OFZRnkls-GH09kvPYE8J_-a1YwjxucOwytzJvEM5QTC9iXfEJNSTBfLge-HMYT1y0AGKs8DWTSC4rtd_2PedK3OYiAyDg_xHA8qNpG9pRNM8vfjV9VsmqJtlbnTVlTngqC0t5vyMaWrmLNRxN0rTbN2W9L3diXRnYqI8BUfgPQb7uhYcPuXGeypaFrN4d3yNN4NbgVxnkgdd2IXQ8elSJuQn6ynrvLgG0JPMmThOHnwvsZDeA` + +func BenchmarkKeyFunc(b *testing.B) { + randomCacheKey := make([]byte, 32) + if _, err := rand.Read(randomCacheKey); err != nil { + b.Fatal(err) // rand should never fail + } + hashPool := &sync.Pool{ + New: func() interface{} { + return hmac.New(sha256.New, randomCacheKey) + }, + } + + // use realistic audiences for benchmarking + auds := []string{"7daf30b7-a85c-429b-8b21-e666aecbb235", "c22aa267-bdde-4acb-8505-998be7818400", "44f9b4f3-7125-4333-b04c-1446a16c6113"} + + b.Run("has audiences", func(b *testing.B) { + var key string + for n := 0; n < b.N; n++ { + key = keyFunc(hashPool, auds, jwtToken) + } + bKey = key + }) + + b.Run("nil audiences", func(b *testing.B) { + var key string + for n := 0; n < b.N; n++ { + key = keyFunc(hashPool, nil, jwtToken) + } + bKey = key + }) +}