diff --git a/staging/src/k8s.io/apiserver/pkg/authentication/token/cache/BUILD b/staging/src/k8s.io/apiserver/pkg/authentication/token/cache/BUILD index 66c09d44844..3c8ca28129f 100644 --- a/staging/src/k8s.io/apiserver/pkg/authentication/token/cache/BUILD +++ b/staging/src/k8s.io/apiserver/pkg/authentication/token/cache/BUILD @@ -17,6 +17,7 @@ go_test( "//staging/src/k8s.io/apimachinery/pkg/util/clock:go_default_library", "//staging/src/k8s.io/apiserver/pkg/authentication/authenticator:go_default_library", "//staging/src/k8s.io/apiserver/pkg/authentication/user:go_default_library", + "//staging/src/k8s.io/apiserver/pkg/endpoints/request:go_default_library", "//vendor/github.com/pborman/uuid:go_default_library", ], ) @@ -34,6 +35,7 @@ go_library( "//staging/src/k8s.io/apimachinery/pkg/util/cache:go_default_library", "//staging/src/k8s.io/apimachinery/pkg/util/clock:go_default_library", "//staging/src/k8s.io/apiserver/pkg/authentication/authenticator:go_default_library", + "//staging/src/k8s.io/apiserver/pkg/endpoints/request:go_default_library", ], ) diff --git a/staging/src/k8s.io/apiserver/pkg/authentication/token/cache/cache_striped.go b/staging/src/k8s.io/apiserver/pkg/authentication/token/cache/cache_striped.go index b791260fc24..e5b7afe4e7d 100644 --- a/staging/src/k8s.io/apiserver/pkg/authentication/token/cache/cache_striped.go +++ b/staging/src/k8s.io/apiserver/pkg/authentication/token/cache/cache_striped.go @@ -24,36 +24,36 @@ import ( // split cache lookups across N striped caches type stripedCache struct { stripeCount uint32 - keyFunc func(string) uint32 + hashFunc func(string) uint32 caches []cache } -type keyFunc func(string) uint32 +type hashFunc func(string) uint32 type newCacheFunc func() cache -func newStripedCache(stripeCount int, keyFunc keyFunc, newCacheFunc newCacheFunc) cache { +func newStripedCache(stripeCount int, hash hashFunc, newCacheFunc newCacheFunc) cache { caches := []cache{} for i := 0; i < stripeCount; i++ { caches = append(caches, newCacheFunc()) } return &stripedCache{ stripeCount: uint32(stripeCount), - keyFunc: keyFunc, + hashFunc: hash, caches: caches, } } func (c *stripedCache) get(key string) (*cacheRecord, bool) { - return c.caches[c.keyFunc(key)%c.stripeCount].get(key) + return c.caches[c.hashFunc(key)%c.stripeCount].get(key) } func (c *stripedCache) set(key string, value *cacheRecord, ttl time.Duration) { - c.caches[c.keyFunc(key)%c.stripeCount].set(key, value, ttl) + c.caches[c.hashFunc(key)%c.stripeCount].set(key, value, ttl) } func (c *stripedCache) remove(key string) { - c.caches[c.keyFunc(key)%c.stripeCount].remove(key) + c.caches[c.hashFunc(key)%c.stripeCount].remove(key) } -func fnvKeyFunc(key string) uint32 { +func fnvHashFunc(key string) uint32 { f := fnv.New32() f.Write([]byte(key)) return f.Sum32() diff --git a/staging/src/k8s.io/apiserver/pkg/authentication/token/cache/cache_test.go b/staging/src/k8s.io/apiserver/pkg/authentication/token/cache/cache_test.go index 96c794583b1..afc3e30c9ab 100644 --- a/staging/src/k8s.io/apiserver/pkg/authentication/token/cache/cache_test.go +++ b/staging/src/k8s.io/apiserver/pkg/authentication/token/cache/cache_test.go @@ -22,6 +22,7 @@ import ( "time" "github.com/pborman/uuid" + "k8s.io/apimachinery/pkg/util/clock" "k8s.io/apiserver/pkg/authentication/authenticator" "k8s.io/apiserver/pkg/authentication/user" @@ -36,11 +37,11 @@ func BenchmarkSimpleCache(b *testing.B) { } func TestStripedCache(t *testing.T) { - testCache(newStripedCache(32, fnvKeyFunc, func() cache { return newSimpleCache(128, clock.RealClock{}) }), t) + testCache(newStripedCache(32, fnvHashFunc, func() cache { return newSimpleCache(128, clock.RealClock{}) }), t) } func BenchmarkStripedCache(b *testing.B) { - benchmarkCache(newStripedCache(32, fnvKeyFunc, func() cache { return newSimpleCache(128, clock.RealClock{}) }), b) + benchmarkCache(newStripedCache(32, fnvHashFunc, func() cache { return newSimpleCache(128, clock.RealClock{}) }), b) } func benchmarkCache(cache cache, b *testing.B) { diff --git a/staging/src/k8s.io/apiserver/pkg/authentication/token/cache/cached_token_authenticator.go b/staging/src/k8s.io/apiserver/pkg/authentication/token/cache/cached_token_authenticator.go index 40243290bc9..ec5af39d8bc 100644 --- a/staging/src/k8s.io/apiserver/pkg/authentication/token/cache/cached_token_authenticator.go +++ b/staging/src/k8s.io/apiserver/pkg/authentication/token/cache/cached_token_authenticator.go @@ -18,10 +18,12 @@ package cache import ( "context" + "fmt" "time" utilclock "k8s.io/apimachinery/pkg/util/clock" "k8s.io/apiserver/pkg/authentication/authenticator" + "k8s.io/apiserver/pkg/endpoints/request" ) // cacheRecord holds the three return values of the authenticator.Token AuthenticateToken method @@ -59,15 +61,16 @@ func newWithClock(authenticator authenticator.Token, successTTL, failureTTL time authenticator: authenticator, successTTL: successTTL, failureTTL: failureTTL, - cache: newStripedCache(32, fnvKeyFunc, func() cache { return newSimpleCache(128, clock) }), + cache: newStripedCache(32, fnvHashFunc, func() cache { return newSimpleCache(128, clock) }), } } // AuthenticateToken implements authenticator.Token func (a *cachedTokenAuthenticator) AuthenticateToken(ctx context.Context, token string) (*authenticator.Response, bool, error) { - // TODO(mikedanese): The key needs to incorporate any relevant data in the - // context. - if record, ok := a.cache.get(token); ok { + auds, _ := request.AudiencesFrom(ctx) + + key := keyFunc(auds, token) + if record, ok := a.cache.get(key); ok { return record.resp, record.ok, record.err } @@ -75,10 +78,14 @@ func (a *cachedTokenAuthenticator) AuthenticateToken(ctx context.Context, token switch { case ok && a.successTTL > 0: - a.cache.set(token, &cacheRecord{resp: resp, ok: ok, err: err}, a.successTTL) + a.cache.set(key, &cacheRecord{resp: resp, ok: ok, err: err}, a.successTTL) case !ok && a.failureTTL > 0: - a.cache.set(token, &cacheRecord{resp: resp, ok: ok, err: err}, a.failureTTL) + a.cache.set(key, &cacheRecord{resp: resp, ok: ok, err: err}, a.failureTTL) } return resp, ok, err } + +func keyFunc(auds []string, token string) string { + return fmt.Sprintf("%#v|%v", auds, token) +} diff --git a/staging/src/k8s.io/apiserver/pkg/authentication/token/cache/cached_token_authenticator_test.go b/staging/src/k8s.io/apiserver/pkg/authentication/token/cache/cached_token_authenticator_test.go index eb993888e16..e92e957a4d3 100644 --- a/staging/src/k8s.io/apiserver/pkg/authentication/token/cache/cached_token_authenticator_test.go +++ b/staging/src/k8s.io/apiserver/pkg/authentication/token/cache/cached_token_authenticator_test.go @@ -25,6 +25,7 @@ import ( utilclock "k8s.io/apimachinery/pkg/util/clock" "k8s.io/apiserver/pkg/authentication/authenticator" "k8s.io/apiserver/pkg/authentication/user" + "k8s.io/apiserver/pkg/endpoints/request" ) func TestCachedTokenAuthenticator(t *testing.T) { @@ -104,3 +105,24 @@ func TestCachedTokenAuthenticator(t *testing.T) { t.Errorf("Expected token calls, got %v", calledWithToken) } } + +func TestCachedTokenAuthenticatorWithAudiences(t *testing.T) { + resultUsers := make(map[string]user.Info) + fakeAuth := authenticator.TokenFunc(func(ctx context.Context, token string) (*authenticator.Response, bool, error) { + auds, _ := request.AudiencesFrom(ctx) + return &authenticator.Response{User: resultUsers[auds[0]+token]}, true, nil + }) + fakeClock := utilclock.NewFakeClock(time.Now()) + + a := newWithClock(fakeAuth, time.Minute, 0, fakeClock) + + resultUsers["audAusertoken1"] = &user.DefaultInfo{Name: "user1"} + resultUsers["audBusertoken1"] = &user.DefaultInfo{Name: "user1-different"} + + if u, ok, _ := a.AuthenticateToken(request.WithAudiences(context.Background(), []string{"audA"}), "usertoken1"); !ok || u.User.GetName() != "user1" { + t.Errorf("Expected user1") + } + if u, ok, _ := a.AuthenticateToken(request.WithAudiences(context.Background(), []string{"audB"}), "usertoken1"); !ok || u.User.GetName() != "user1-different" { + t.Errorf("Expected user1-different") + } +}