From 55a1b940a09e8cdba30250d3a707c4021bb603a3 Mon Sep 17 00:00:00 2001 From: Tom Lebreux Date: Fri, 30 May 2025 08:25:12 -0400 Subject: [PATCH] Split RegisterAfterUpsert into RegisterAfterAdd and RegisterAfterUpdate (#644) * Split RegisterAfterUpsert into two We're going to need to be able to differentiate between Add and Update for storing events in the _events table. * Update mocks --- pkg/sqlcache/informer/indexer.go | 6 +- pkg/sqlcache/informer/indexer_test.go | 3 +- pkg/sqlcache/informer/listoption_indexer.go | 6 +- .../informer/listoption_indexer_test.go | 30 +++-- pkg/sqlcache/informer/sql_mocks_test.go | 24 +++- pkg/sqlcache/store/store.go | 104 ++++++++++++------ pkg/sqlcache/store/store_test.go | 16 +-- 7 files changed, 125 insertions(+), 64 deletions(-) diff --git a/pkg/sqlcache/informer/indexer.go b/pkg/sqlcache/informer/indexer.go index 78b81882..2a154a32 100644 --- a/pkg/sqlcache/informer/indexer.go +++ b/pkg/sqlcache/informer/indexer.go @@ -70,7 +70,8 @@ type Store interface { GetByKey(key string) (item any, exists bool, err error) GetName() string - RegisterAfterUpsert(f func(key string, obj any, tx transaction.Client) error) + RegisterAfterAdd(f func(key string, obj any, tx transaction.Client) error) + RegisterAfterUpdate(f func(key string, obj any, tx transaction.Client) error) RegisterAfterDelete(f func(key string, tx transaction.Client) error) GetShouldEncrypt() bool GetType() reflect.Type @@ -102,7 +103,8 @@ func NewIndexer(ctx context.Context, indexers cache.Indexers, s Store) (*Indexer Store: s, indexers: indexers, } - i.RegisterAfterUpsert(i.AfterUpsert) + i.RegisterAfterAdd(i.AfterUpsert) + i.RegisterAfterUpdate(i.AfterUpsert) i.deleteIndicesQuery = fmt.Sprintf(deleteIndicesFmt, db.Sanitize(s.GetName())) i.addIndexQuery = fmt.Sprintf(addIndexFmt, db.Sanitize(s.GetName())) diff --git a/pkg/sqlcache/informer/indexer_test.go b/pkg/sqlcache/informer/indexer_test.go index c4d05b45..1d4f675b 100644 --- a/pkg/sqlcache/informer/indexer_test.go +++ b/pkg/sqlcache/informer/indexer_test.go @@ -58,7 +58,8 @@ func TestNewIndexer(t *testing.T) { t.Fail() } }) - store.EXPECT().RegisterAfterUpsert(gomock.Any()) + store.EXPECT().RegisterAfterAdd(gomock.Any()) + store.EXPECT().RegisterAfterUpdate(gomock.Any()) store.EXPECT().Prepare(fmt.Sprintf(deleteIndicesFmt, storeName)) store.EXPECT().Prepare(fmt.Sprintf(addIndexFmt, storeName)) store.EXPECT().Prepare(fmt.Sprintf(listByIndexFmt, storeName, storeName)) diff --git a/pkg/sqlcache/informer/listoption_indexer.go b/pkg/sqlcache/informer/listoption_indexer.go index 305d7ce9..7e542423 100644 --- a/pkg/sqlcache/informer/listoption_indexer.go +++ b/pkg/sqlcache/informer/listoption_indexer.go @@ -100,8 +100,10 @@ func NewListOptionIndexer(ctx context.Context, fields [][]string, s Store, names namespaced: namespaced, indexedFields: indexedFields, } - l.RegisterAfterUpsert(l.addIndexFields) - l.RegisterAfterUpsert(l.addLabels) + l.RegisterAfterAdd(l.addIndexFields) + l.RegisterAfterUpdate(l.addIndexFields) + l.RegisterAfterAdd(l.addLabels) + l.RegisterAfterUpdate(l.addLabels) l.RegisterAfterDelete(l.deleteIndexFields) l.RegisterAfterDelete(l.deleteLabels) columnDefs := make([]string, len(indexedFields)) diff --git a/pkg/sqlcache/informer/listoption_indexer_test.go b/pkg/sqlcache/informer/listoption_indexer_test.go index a7f30f44..14bc8d2a 100644 --- a/pkg/sqlcache/informer/listoption_indexer_test.go +++ b/pkg/sqlcache/informer/listoption_indexer_test.go @@ -50,11 +50,13 @@ func TestNewListOptionIndexer(t *testing.T) { t.Fail() } }) - store.EXPECT().RegisterAfterUpsert(gomock.Any()) + store.EXPECT().RegisterAfterAdd(gomock.Any()) + store.EXPECT().RegisterAfterUpdate(gomock.Any()) store.EXPECT().Prepare(gomock.Any()).Return(stmt).AnyTimes() // end NewIndexer() logic - store.EXPECT().RegisterAfterUpsert(gomock.Any()).Times(2) + store.EXPECT().RegisterAfterAdd(gomock.Any()).Times(2) + store.EXPECT().RegisterAfterUpdate(gomock.Any()).Times(2) store.EXPECT().RegisterAfterDelete(gomock.Any()).Times(2) // create field table @@ -115,11 +117,13 @@ func TestNewListOptionIndexer(t *testing.T) { t.Fail() } }) - store.EXPECT().RegisterAfterUpsert(gomock.Any()) + store.EXPECT().RegisterAfterAdd(gomock.Any()) + store.EXPECT().RegisterAfterUpdate(gomock.Any()) store.EXPECT().Prepare(gomock.Any()).Return(stmt).AnyTimes() // end NewIndexer() logic - store.EXPECT().RegisterAfterUpsert(gomock.Any()).Times(2) + store.EXPECT().RegisterAfterAdd(gomock.Any()).Times(2) + store.EXPECT().RegisterAfterUpdate(gomock.Any()).Times(2) store.EXPECT().RegisterAfterDelete(gomock.Any()).Times(2) store.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(fmt.Errorf("error")) @@ -144,11 +148,13 @@ func TestNewListOptionIndexer(t *testing.T) { t.Fail() } }) - store.EXPECT().RegisterAfterUpsert(gomock.Any()) + store.EXPECT().RegisterAfterAdd(gomock.Any()) + store.EXPECT().RegisterAfterUpdate(gomock.Any()) store.EXPECT().Prepare(gomock.Any()).Return(stmt).AnyTimes() // end NewIndexer() logic - store.EXPECT().RegisterAfterUpsert(gomock.Any()).Times(2) + store.EXPECT().RegisterAfterAdd(gomock.Any()).Times(2) + store.EXPECT().RegisterAfterUpdate(gomock.Any()).Times(2) store.EXPECT().RegisterAfterDelete(gomock.Any()).Times(2) txClient.EXPECT().Exec(fmt.Sprintf(createFieldsTableFmt, id, `"metadata.name" TEXT, "metadata.creationTimestamp" TEXT, "metadata.namespace" TEXT, "something" TEXT`)).Return(nil, nil) @@ -181,11 +187,13 @@ func TestNewListOptionIndexer(t *testing.T) { t.Fail() } }) - store.EXPECT().RegisterAfterUpsert(gomock.Any()) + store.EXPECT().RegisterAfterAdd(gomock.Any()) + store.EXPECT().RegisterAfterUpdate(gomock.Any()) store.EXPECT().Prepare(gomock.Any()).Return(stmt).AnyTimes() // end NewIndexer() logic - store.EXPECT().RegisterAfterUpsert(gomock.Any()).Times(2) + store.EXPECT().RegisterAfterAdd(gomock.Any()).Times(2) + store.EXPECT().RegisterAfterUpdate(gomock.Any()).Times(2) store.EXPECT().RegisterAfterDelete(gomock.Any()).Times(2) txClient.EXPECT().Exec(fmt.Sprintf(createFieldsTableFmt, id, `"metadata.name" TEXT, "metadata.creationTimestamp" TEXT, "metadata.namespace" TEXT, "something" TEXT`)).Return(nil, nil) @@ -222,11 +230,13 @@ func TestNewListOptionIndexer(t *testing.T) { t.Fail() } }) - store.EXPECT().RegisterAfterUpsert(gomock.Any()) + store.EXPECT().RegisterAfterAdd(gomock.Any()) + store.EXPECT().RegisterAfterUpdate(gomock.Any()) store.EXPECT().Prepare(gomock.Any()).Return(stmt).AnyTimes() // end NewIndexer() logic - store.EXPECT().RegisterAfterUpsert(gomock.Any()).Times(2) + store.EXPECT().RegisterAfterAdd(gomock.Any()).Times(2) + store.EXPECT().RegisterAfterUpdate(gomock.Any()).Times(2) store.EXPECT().RegisterAfterDelete(gomock.Any()).Times(2) txClient.EXPECT().Exec(fmt.Sprintf(createFieldsTableFmt, id, `"metadata.name" TEXT, "metadata.creationTimestamp" TEXT, "metadata.namespace" TEXT, "something" TEXT`)).Return(nil, nil) diff --git a/pkg/sqlcache/informer/sql_mocks_test.go b/pkg/sqlcache/informer/sql_mocks_test.go index b8f7c578..91031baa 100644 --- a/pkg/sqlcache/informer/sql_mocks_test.go +++ b/pkg/sqlcache/informer/sql_mocks_test.go @@ -279,6 +279,18 @@ func (mr *MockStoreMockRecorder) ReadStrings(arg0 any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadStrings", reflect.TypeOf((*MockStore)(nil).ReadStrings), arg0) } +// RegisterAfterAdd mocks base method. +func (m *MockStore) RegisterAfterAdd(arg0 func(string, any, transaction.Client) error) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "RegisterAfterAdd", arg0) +} + +// RegisterAfterAdd indicates an expected call of RegisterAfterAdd. +func (mr *MockStoreMockRecorder) RegisterAfterAdd(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterAfterAdd", reflect.TypeOf((*MockStore)(nil).RegisterAfterAdd), arg0) +} + // RegisterAfterDelete mocks base method. func (m *MockStore) RegisterAfterDelete(arg0 func(string, transaction.Client) error) { m.ctrl.T.Helper() @@ -291,16 +303,16 @@ func (mr *MockStoreMockRecorder) RegisterAfterDelete(arg0 any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterAfterDelete", reflect.TypeOf((*MockStore)(nil).RegisterAfterDelete), arg0) } -// RegisterAfterUpsert mocks base method. -func (m *MockStore) RegisterAfterUpsert(arg0 func(string, any, transaction.Client) error) { +// RegisterAfterUpdate mocks base method. +func (m *MockStore) RegisterAfterUpdate(arg0 func(string, any, transaction.Client) error) { m.ctrl.T.Helper() - m.ctrl.Call(m, "RegisterAfterUpsert", arg0) + m.ctrl.Call(m, "RegisterAfterUpdate", arg0) } -// RegisterAfterUpsert indicates an expected call of RegisterAfterUpsert. -func (mr *MockStoreMockRecorder) RegisterAfterUpsert(arg0 any) *gomock.Call { +// RegisterAfterUpdate indicates an expected call of RegisterAfterUpdate. +func (mr *MockStoreMockRecorder) RegisterAfterUpdate(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterAfterUpsert", reflect.TypeOf((*MockStore)(nil).RegisterAfterUpsert), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterAfterUpdate", reflect.TypeOf((*MockStore)(nil).RegisterAfterUpdate), arg0) } // Replace mocks base method. diff --git a/pkg/sqlcache/store/store.go b/pkg/sqlcache/store/store.go index 4ec00fec..4143e989 100644 --- a/pkg/sqlcache/store/store.go +++ b/pkg/sqlcache/store/store.go @@ -54,7 +54,8 @@ type Store struct { listStmt *sql.Stmt listKeysStmt *sql.Stmt - afterUpsert []func(key string, obj any, tx transaction.Client) error + afterAdd []func(key string, obj any, tx transaction.Client) error + afterUpdate []func(key string, obj any, tx transaction.Client) error afterDelete []func(key string, tx transaction.Client) error } @@ -70,7 +71,8 @@ func NewStore(ctx context.Context, example any, keyFunc cache.KeyFunc, c db.Clie Client: c, keyFunc: keyFunc, shouldEncrypt: shouldEncrypt, - afterUpsert: []func(key string, obj any, tx transaction.Client) error{}, + afterAdd: []func(key string, obj any, tx transaction.Client) error{}, + afterUpdate: []func(key string, obj any, tx transaction.Client) error{}, afterDelete: []func(key string, tx transaction.Client) error{}, } @@ -106,22 +108,6 @@ func NewStore(ctx context.Context, example any, keyFunc cache.KeyFunc, c db.Clie } /* Core methods */ -// upsert saves an obj with its key, or updates key with obj if it exists in this Store -func (s *Store) upsert(key string, obj any) error { - return s.WithTransaction(s.ctx, true, func(tx transaction.Client) error { - err := s.Upsert(tx, s.upsertStmt, key, obj, s.shouldEncrypt) - if err != nil { - return &db.QueryError{QueryString: s.upsertQuery, Err: err} - } - - err = s.runAfterUpsert(key, obj, tx) - if err != nil { - return err - } - - return nil - }) -} // deleteByKey deletes the object associated with key, if it exists in this Store func (s *Store) deleteByKey(key string) error { @@ -167,7 +153,19 @@ func (s *Store) Add(obj any) error { return err } - err = s.upsert(key, obj) + err = s.WithTransaction(s.ctx, true, func(tx transaction.Client) error { + err := s.Upsert(tx, s.upsertStmt, key, obj, s.shouldEncrypt) + if err != nil { + return &db.QueryError{QueryString: s.upsertQuery, Err: err} + } + + err = s.runAfterAdd(key, obj, tx) + if err != nil { + return err + } + + return nil + }) if err != nil { log.Errorf("Error in Store.Add for type %v: %v", s.name, err) return err @@ -177,7 +175,29 @@ func (s *Store) Add(obj any) error { // Update saves an obj, or updates it if it exists in this Store func (s *Store) Update(obj any) error { - return s.Add(obj) + key, err := s.keyFunc(obj) + if err != nil { + return err + } + + err = s.WithTransaction(s.ctx, true, func(tx transaction.Client) error { + err := s.Upsert(tx, s.upsertStmt, key, obj, s.shouldEncrypt) + if err != nil { + return &db.QueryError{QueryString: s.upsertQuery, Err: err} + } + + err = s.runAfterUpdate(key, obj, tx) + if err != nil { + return err + } + + return nil + }) + if err != nil { + log.Errorf("Error in Store.Update for type %v: %v", s.name, err) + return err + } + return nil } // Delete deletes the given object, if it exists in this Store @@ -279,7 +299,7 @@ func (s *Store) replaceByKey(objects map[string]any) error { if err != nil { return err } - err = s.runAfterUpsert(key, obj, txC) + err = s.runAfterAdd(key, obj, txC) if err != nil { return err } @@ -296,11 +316,6 @@ func (s *Store) Resync() error { /* Utilities */ -// RegisterAfterUpsert registers a func to be called after each upsert -func (s *Store) RegisterAfterUpsert(f func(key string, obj any, txC transaction.Client) error) { - s.afterUpsert = append(s.afterUpsert, f) -} - func (s *Store) GetName() string { return s.name } @@ -313,10 +328,24 @@ func (s *Store) GetType() reflect.Type { return s.typ } -// keep -// runAfterUpsert executes functions registered to run after upsert -func (s *Store) runAfterUpsert(key string, obj any, txC transaction.Client) error { - for _, f := range s.afterUpsert { +// RegisterAfterAdd registers a func to be called after each add event +func (s *Store) RegisterAfterAdd(f func(key string, obj any, txC transaction.Client) error) { + s.afterAdd = append(s.afterAdd, f) +} + +// RegisterAfterUpdate registers a func to be called after each update event +func (s *Store) RegisterAfterUpdate(f func(key string, obj any, txC transaction.Client) error) { + s.afterUpdate = append(s.afterUpdate, f) +} + +// RegisterAfterDelete registers a func to be called after each deletion +func (s *Store) RegisterAfterDelete(f func(key string, txC transaction.Client) error) { + s.afterDelete = append(s.afterDelete, f) +} + +// runAfterAdd executes functions registered to run after add event +func (s *Store) runAfterAdd(key string, obj any, txC transaction.Client) error { + for _, f := range s.afterAdd { err := f(key, obj, txC) if err != nil { return err @@ -325,13 +354,18 @@ func (s *Store) runAfterUpsert(key string, obj any, txC transaction.Client) erro return nil } -// RegisterAfterDelete registers a func to be called after each deletion -func (s *Store) RegisterAfterDelete(f func(key string, txC transaction.Client) error) { - s.afterDelete = append(s.afterDelete, f) +// runAfterUpdate executes functions registered to run after update event +func (s *Store) runAfterUpdate(key string, obj any, txC transaction.Client) error { + for _, f := range s.afterUpdate { + err := f(key, obj, txC) + if err != nil { + return err + } + } + return nil } -// keep -// runAfterDelete executes functions registered to run after upsert +// runAfterDelete executes functions registered to run after delete event func (s *Store) runAfterDelete(key string, txC transaction.Client) error { for _, f := range s.afterDelete { err := f(key, txC) diff --git a/pkg/sqlcache/store/store_test.go b/pkg/sqlcache/store/store_test.go index e7db7045..fc3a8ae6 100644 --- a/pkg/sqlcache/store/store_test.go +++ b/pkg/sqlcache/store/store_test.go @@ -62,7 +62,7 @@ func TestAdd(t *testing.T) { }, }) - tests = append(tests, testCase{description: "Add with no DB client errors and an afterUpsert function", test: func(t *testing.T, shouldEncrypt bool) { + tests = append(tests, testCase{description: "Add with no DB client errors and an afterAdd function", test: func(t *testing.T, shouldEncrypt bool) { c, txC := SetupMockDB(t) store := SetupStore(t, c, shouldEncrypt) @@ -76,7 +76,7 @@ func TestAdd(t *testing.T) { }) var count int - store.afterUpsert = append(store.afterUpsert, func(key string, object any, tx transaction.Client) error { + store.afterAdd = append(store.afterAdd, func(key string, object any, tx transaction.Client) error { count++ return nil }) @@ -86,7 +86,7 @@ func TestAdd(t *testing.T) { }, }) - tests = append(tests, testCase{description: "Add with no DB client errors and an afterUpsert function that returns error", test: func(t *testing.T, shouldEncrypt bool) { + tests = append(tests, testCase{description: "Add with no DB client errors and an afterAdd function that returns error", test: func(t *testing.T, shouldEncrypt bool) { c, txC := SetupMockDB(t) store := SetupStore(t, c, shouldEncrypt) @@ -99,7 +99,7 @@ func TestAdd(t *testing.T) { } }) - store.afterUpsert = append(store.afterUpsert, func(key string, object any, txC transaction.Client) error { + store.afterAdd = append(store.afterAdd, func(key string, object any, txC transaction.Client) error { return fmt.Errorf("error") }) err := store.Add(testObject) @@ -184,7 +184,7 @@ func TestUpdate(t *testing.T) { }, }) - tests = append(tests, testCase{description: "Update with no DB client errors and an afterUpsert function", test: func(t *testing.T, shouldEncrypt bool) { + tests = append(tests, testCase{description: "Update with no DB client errors and an afterUpdate function", test: func(t *testing.T, shouldEncrypt bool) { c, txC := SetupMockDB(t) store := SetupStore(t, c, shouldEncrypt) @@ -198,7 +198,7 @@ func TestUpdate(t *testing.T) { }) var count int - store.afterUpsert = append(store.afterUpsert, func(key string, object any, txC transaction.Client) error { + store.afterUpdate = append(store.afterUpdate, func(key string, object any, txC transaction.Client) error { count++ return nil }) @@ -208,7 +208,7 @@ func TestUpdate(t *testing.T) { }, }) - tests = append(tests, testCase{description: "Update with no DB client errors and an afterUpsert function that returns error", test: func(t *testing.T, shouldEncrypt bool) { + tests = append(tests, testCase{description: "Update with no DB client errors and an afterUpdate function that returns error", test: func(t *testing.T, shouldEncrypt bool) { c, txC := SetupMockDB(t) store := SetupStore(t, c, shouldEncrypt) @@ -222,7 +222,7 @@ func TestUpdate(t *testing.T) { } }) - store.afterUpsert = append(store.afterUpsert, func(key string, object any, txC transaction.Client) error { + store.afterUpdate = append(store.afterUpdate, func(key string, object any, txC transaction.Client) error { return fmt.Errorf("error") }) err := store.Update(testObject)