diff --git a/pkg/accesscontrol/access_store.go b/pkg/accesscontrol/access_store.go index 2380f12b..6614b9fd 100644 --- a/pkg/accesscontrol/access_store.go +++ b/pkg/accesscontrol/access_store.go @@ -31,11 +31,18 @@ type roleRevisions interface { roleRevision(string, string) string } +// accessStoreCache is a subset of the methods implemented by LRUExpireCache +type accessStoreCache interface { + Add(key interface{}, value interface{}, ttl time.Duration) + Get(key interface{}) (interface{}, bool) + Remove(key interface{}) +} + type AccessStore struct { usersPolicyRules policyRules groupsPolicyRules policyRules roles roleRevisions - cache *cache.LRUExpireCache + cache accessStoreCache } type roleKey struct { @@ -56,26 +63,29 @@ func NewAccessStore(ctx context.Context, cacheResults bool, rbac v1.Interface) * } func (l *AccessStore) AccessFor(user user.Info) *AccessSet { - var cacheKey string - if l.cache != nil { - cacheKey = l.CacheKey(user) - val, ok := l.cache.Get(cacheKey) - if ok { - as, _ := val.(*AccessSet) - return as - } + if l.cache == nil { + return l.newAccessSet(user) } + cacheKey := l.CacheKey(user) + + if val, ok := l.cache.Get(cacheKey); ok { + as, _ := val.(*AccessSet) + return as + } + + result := l.newAccessSet(user) + result.ID = cacheKey + l.cache.Add(cacheKey, result, 24*time.Hour) + + return result +} + +func (l *AccessStore) newAccessSet(user user.Info) *AccessSet { result := l.usersPolicyRules.get(user.GetName()) for _, group := range user.GetGroups() { result.Merge(l.groupsPolicyRules.get(group)) } - - if l.cache != nil { - result.ID = cacheKey - l.cache.Add(cacheKey, result, 24*time.Hour) - } - return result } diff --git a/pkg/accesscontrol/access_store_test.go b/pkg/accesscontrol/access_store_test.go index bdaf5d51..543723cd 100644 --- a/pkg/accesscontrol/access_store_test.go +++ b/pkg/accesscontrol/access_store_test.go @@ -3,7 +3,9 @@ package accesscontrol import ( "fmt" "slices" + "sync" "testing" + "time" appsv1 "k8s.io/api/apps/v1" corev1 "k8s.io/api/core/v1" @@ -274,6 +276,81 @@ func TestAccessStore_AccessFor(t *testing.T) { } } +type spyCache struct { + accessStoreCache + + mu sync.Mutex + setCalls map[any]int +} + +func (c *spyCache) Add(k interface{}, v interface{}, ttl time.Duration) { + defer c.observeAdd(k) + + time.Sleep(1 * time.Millisecond) // allow other routines to wake up, simulating heavy load + c.accessStoreCache.Add(k, v, ttl) +} + +func (c *spyCache) observeAdd(k interface{}) { + c.mu.Lock() + defer c.mu.Unlock() + + if c.setCalls == nil { + c.setCalls = make(map[any]int) + } + c.setCalls[k]++ +} + +func TestAccessStore_AccessFor_concurrent(t *testing.T) { + t.Skipf("TODO - Add a fix for this test") + testUser := &user.DefaultInfo{Name: "test-user"} + asCache := &spyCache{accessStoreCache: cache.NewLRUExpireCache(100)} + store := &AccessStore{ + roles: roleRevisionsMock(func(ns, name string) string { + return fmt.Sprintf("%s%srev", ns, name) + }), + usersPolicyRules: &policyRulesMock{ + getRBFunc: func(s string) []*rbacv1.RoleBinding { + return []*rbacv1.RoleBinding{ + makeRB("testns", "testrb", testUser.Name, "testrole"), + } + }, + getFunc: func(_ string) *AccessSet { + return &AccessSet{ + set: map[key]resourceAccessSet{ + {"get", corev1.Resource("ConfigMap")}: map[Access]bool{ + {Namespace: All, ResourceName: All}: true, + }, + }, + } + }, + }, + cache: asCache, + } + + const n = 5 // observation showed cases with up to 5 (or more) concurrent queries for the same user + + wait := make(chan struct{}) + var wg sync.WaitGroup + var id string + for range n { + wg.Add(1) + go func() { + <-wait + id = store.AccessFor(testUser).ID + wg.Done() + }() + } + close(wait) + wg.Wait() + + if got, want := len(asCache.setCalls), 1; got != want { + t.Errorf("Unexpected number of cache entries: got %d, want %d", got, want) + } + if got, want := asCache.setCalls[id], 1; got != want { + t.Errorf("Unexpected number of calls to cache.Set(): got %d, want %d", got, want) + } +} + func makeRB(ns, name, user, role string) *rbacv1.RoleBinding { return &rbacv1.RoleBinding{ ObjectMeta: metav1.ObjectMeta{Namespace: ns, Name: name},