diff --git a/pkg/accesscontrol/access_store.go b/pkg/accesscontrol/access_store.go index c14ee6e..5e51f97 100644 --- a/pkg/accesscontrol/access_store.go +++ b/pkg/accesscontrol/access_store.go @@ -14,6 +14,7 @@ import ( type AccessSetLookup interface { AccessFor(user user.Info) *AccessSet + PurgeUserData(id string) } type AccessStore struct { @@ -63,6 +64,10 @@ func (l *AccessStore) AccessFor(user user.Info) *AccessSet { return result } +func (l *AccessStore) PurgeUserData(id string) { + l.cache.Remove(id) +} + func (l *AccessStore) CacheKey(user user.Info) string { d := sha256.New() diff --git a/pkg/schema/collection.go b/pkg/schema/collection.go index 916e78f..48f23ea 100644 --- a/pkg/schema/collection.go +++ b/pkg/schema/collection.go @@ -36,6 +36,7 @@ type Collection struct { byGVR map[schema.GroupVersionResource]string byGVK map[schema.GroupVersionKind]string cache *cache.LRUExpireCache + userCache *cache.LRUExpireCache lock sync.RWMutex ctx context.Context @@ -84,6 +85,7 @@ func NewCollection(ctx context.Context, baseSchema *types.APISchemas, access acc byGVR: map[schema.GroupVersionResource]string{}, byGVK: map[schema.GroupVersionKind]string{}, cache: cache.NewLRUExpireCache(1000), + userCache: cache.NewLRUExpireCache(1000), notifiers: map[int]func(){}, ctx: ctx, as: access, diff --git a/pkg/schema/factory.go b/pkg/schema/factory.go index b16af18..5f7ebd9 100644 --- a/pkg/schema/factory.go +++ b/pkg/schema/factory.go @@ -23,6 +23,7 @@ func newSchemas() (*types.APISchemas, error) { func (c *Collection) Schemas(user user.Info) (*types.APISchemas, error) { access := c.as.AccessFor(user) + c.removeOldRecords(access, user) val, ok := c.cache.Get(access.ID) if ok { schemas, _ := val.(*types.APISchemas) @@ -33,11 +34,34 @@ func (c *Collection) Schemas(user user.Info) (*types.APISchemas, error) { if err != nil { return nil, err } - - c.cache.Add(access.ID, schemas, 24*time.Hour) + c.addToCache(access, user, schemas) return schemas, nil } +func (c *Collection) removeOldRecords(access *accesscontrol.AccessSet, user user.Info) { + current, ok := c.userCache.Get(user.GetName()) + if ok { + currentId, cOk := current.(string) + if cOk && currentId != access.ID { + // we only want to keep around one record per user. If our current access record is invalid, purge the + //record of it from the cache, so we don't keep duplicates + c.purgeUserRecords(currentId) + c.userCache.Remove(user.GetName()) + } + } +} + +func (c *Collection) addToCache(access *accesscontrol.AccessSet, user user.Info, schemas *types.APISchemas) { + c.cache.Add(access.ID, schemas, 24*time.Hour) + c.userCache.Add(user.GetName(), access.ID, 24*time.Hour) +} + +// PurgeUserRecords removes a record from the backing LRU cache before expiry +func (c *Collection) purgeUserRecords(id string) { + c.cache.Remove(id) + c.as.PurgeUserData(id) +} + func (c *Collection) schemasForSubject(access *accesscontrol.AccessSet) (*types.APISchemas, error) { c.lock.RLock() defer c.lock.RUnlock() diff --git a/pkg/schema/factory_test.go b/pkg/schema/factory_test.go new file mode 100644 index 0000000..c9dc6f6 --- /dev/null +++ b/pkg/schema/factory_test.go @@ -0,0 +1,168 @@ +package schema + +import ( + "context" + "github.com/stretchr/testify/assert" + "testing" + + "github.com/rancher/apiserver/pkg/types" + "github.com/rancher/wrangler/pkg/schemas" + k8sSchema "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/apiserver/pkg/authentication/user" +) + +const ( + testGroup = "test.k8s.io" + testVersion = "v1" +) + +type schemaTestConfig struct { + permissionVerbs []string + desiredResourceVerbs []string + desiredCollectionVerbs []string + errDesired bool +} + +func TestSchemas(t *testing.T) { + tests := []struct { + name string + config schemaTestConfig + }{ + { + name: "basic get schema test", + config: schemaTestConfig{ + permissionVerbs: []string{"get"}, + desiredResourceVerbs: []string{"GET"}, + desiredCollectionVerbs: []string{"GET"}, + errDesired: false, + }, + }, + } + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + // test caching functionality + mockLookup := newMockAccessSetLookup() + userName := "testUser" + testUser := user.DefaultInfo{ + Name: userName, + UID: userName, + Groups: []string{}, + Extra: map[string][]string{}, + } + + collection := NewCollection(context.TODO(), types.EmptyAPISchemas(), mockLookup) + collection.schemas = map[string]*types.APISchema{"testCRD": makeSchema("testCRD")} + runSchemaTest(t, test.config, mockLookup, collection, &testUser) + }) + } +} +func TestSchemaCache(t *testing.T) { + // Schemas are a frequently used resource. It's important that the cache doesn't have a leak given size/frequency of resource + tests := []struct { + name string + before schemaTestConfig + after schemaTestConfig + }{ + { + name: "permissions increase, cache size same", + before: schemaTestConfig{ + permissionVerbs: []string{"get"}, + desiredResourceVerbs: []string{"GET"}, + desiredCollectionVerbs: []string{"GET"}, + errDesired: false, + }, + after: schemaTestConfig{ + permissionVerbs: []string{"get", "create", "delete"}, + desiredResourceVerbs: []string{"GET", "DELETE"}, + desiredCollectionVerbs: []string{"GET", "POST"}, + errDesired: false, + }, + }, + { + name: "permissions decrease, cache size same", + before: schemaTestConfig{ + permissionVerbs: []string{"get", "create", "delete"}, + desiredResourceVerbs: []string{"GET", "DELETE"}, + desiredCollectionVerbs: []string{"GET", "POST"}, + errDesired: false, + }, + after: schemaTestConfig{ + permissionVerbs: []string{"get"}, + desiredResourceVerbs: []string{"GET"}, + desiredCollectionVerbs: []string{"GET"}, + errDesired: false, + }, + }, + } + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + // test caching functionality + mockLookup := newMockAccessSetLookup() + userName := "testUser" + testUser := user.DefaultInfo{ + Name: userName, + UID: userName, + Groups: []string{}, + Extra: map[string][]string{}, + } + collection := NewCollection(context.TODO(), types.EmptyAPISchemas(), mockLookup) + collection.schemas = map[string]*types.APISchema{"testCRD": makeSchema("testCRD")} + runSchemaTest(t, test.before, mockLookup, collection, &testUser) + assert.Len(t, collection.cache.Keys(), 1, "expected cache to be size 1") + mockLookup.Clear() + runSchemaTest(t, test.after, mockLookup, collection, &testUser) + assert.Len(t, collection.cache.Keys(), 1, "expected cache to be size 1") + }) + } +} + +func runSchemaTest(t *testing.T, config schemaTestConfig, lookup *mockAccessSetLookup, collection *Collection, testUser user.Info) { + for _, verb := range config.permissionVerbs { + lookup.AddAccessForUser(testUser, verb, k8sSchema.GroupResource{Group: testGroup, Resource: "testCRD"}, "*", "*") + } + + collection.schemas = map[string]*types.APISchema{"testCRD": makeSchema("testCRD")} + userSchemas, err := collection.Schemas(testUser) + if config.errDesired { + assert.Error(t, err, "expected error but none was found") + } + var testSchema *types.APISchema + for schemaName, userSchema := range userSchemas.Schemas { + if schemaName == "testCRD" { + testSchema = userSchema + } + } + assert.NotNil(t, testSchema, "expected a test schema, but was nil") + assert.Len(t, testSchema.ResourceMethods, len(config.desiredResourceVerbs), "did not get as many verbs as expected for resource methods") + assert.Len(t, testSchema.CollectionMethods, len(config.desiredCollectionVerbs), "did not get as many verbs as expected for resource methods") + for _, verb := range config.desiredResourceVerbs { + assert.Contains(t, testSchema.ResourceMethods, verb, "did not find %s in resource methods %v", verb, testSchema.ResourceMethods) + } + for _, verb := range config.desiredCollectionVerbs { + assert.Contains(t, testSchema.CollectionMethods, verb, "did not find %s in resource methods %v", verb, testSchema.CollectionMethods) + } +} + +func makeSchema(resourceType string) *types.APISchema { + return &types.APISchema{ + Schema: &schemas.Schema{ + ID: resourceType, + CollectionMethods: []string{}, + ResourceMethods: []string{}, + ResourceFields: map[string]schemas.Field{ + "name": {Type: "string"}, + "value": {Type: "string"}, + }, + Attributes: map[string]interface{}{ + "group": testGroup, + "version": testVersion, + "resource": resourceType, + "verbs": []string{"get", "list", "watch", "delete", "update", "create"}, + }, + }, + } +} diff --git a/pkg/schema/mock_test.go b/pkg/schema/mock_test.go new file mode 100644 index 0000000..c0b0b28 --- /dev/null +++ b/pkg/schema/mock_test.go @@ -0,0 +1,75 @@ +package schema + +import ( + "crypto/sha256" + "encoding/hex" + "hash" + + "github.com/rancher/steve/pkg/accesscontrol" + "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/apiserver/pkg/authentication/user" +) + +const ( + insideSeparator = "&" + outsideSeparator = "%" +) + +type mockAccessSetLookup struct { + accessSets map[string]*accesscontrol.AccessSet + currentHash map[string]hash.Hash +} + +func newMockAccessSetLookup() *mockAccessSetLookup { + return &mockAccessSetLookup{ + accessSets: map[string]*accesscontrol.AccessSet{}, + currentHash: map[string]hash.Hash{}, + } +} + +func (m *mockAccessSetLookup) AccessFor(user user.Info) *accesscontrol.AccessSet { + if set, ok := m.accessSets[user.GetName()]; ok { + return set + } + return nil +} + +func (m *mockAccessSetLookup) PurgeUserData(id string) { + var foundKey string + for key, value := range m.accessSets { + if value.ID == id { + foundKey = key + } + } + if foundKey != "" { + delete(m.accessSets, foundKey) + } +} + +func (m *mockAccessSetLookup) AddAccessForUser(user user.Info, verb string, gr schema.GroupResource, namespace string, name string) { + currentAccessSet, ok := m.accessSets[user.GetName()] + var currentHash hash.Hash + if !ok { + currentAccessSet = &accesscontrol.AccessSet{} + currentHash = sha256.New() + } else { + currentHash = m.currentHash[currentAccessSet.ID] + } + currentAccessSet.Add(verb, gr, accesscontrol.Access{Namespace: namespace, ResourceName: name}) + calculateAccessSetID(currentHash, verb, gr, namespace, name) + currentAccessSet.ID = hex.EncodeToString(currentHash.Sum(nil)) + m.accessSets[user.GetName()] = currentAccessSet + m.currentHash[currentAccessSet.ID] = currentHash +} + +func (m *mockAccessSetLookup) Clear() { + m.accessSets = map[string]*accesscontrol.AccessSet{} + m.currentHash = map[string]hash.Hash{} +} + +func calculateAccessSetID(digest hash.Hash, verb string, gr schema.GroupResource, namespace string, name string) { + digest.Write([]byte(verb + insideSeparator)) + digest.Write([]byte(gr.String() + insideSeparator)) + digest.Write([]byte(namespace + insideSeparator)) + digest.Write([]byte(name + outsideSeparator)) +}