diff --git a/pkg/sqlcache/informer/indexer.go b/pkg/sqlcache/informer/indexer.go index 2a154a32..88f8ce03 100644 --- a/pkg/sqlcache/informer/indexer.go +++ b/pkg/sqlcache/informer/indexer.go @@ -72,7 +72,8 @@ type Store interface { GetName() string RegisterAfterAdd(f func(key string, obj any, tx transaction.Client) error) RegisterAfterUpdate(f func(key string, obj any, tx transaction.Client) error) - RegisterAfterDelete(f func(key string, tx transaction.Client) error) + RegisterAfterDelete(f func(key string, obj any, tx transaction.Client) error) + RegisterAfterDeleteAll(f func(tx transaction.Client) error) GetShouldEncrypt() bool GetType() reflect.Type } diff --git a/pkg/sqlcache/informer/listoption_indexer.go b/pkg/sqlcache/informer/listoption_indexer.go index 2c6c2e40..8898cd51 100644 --- a/pkg/sqlcache/informer/listoption_indexer.go +++ b/pkg/sqlcache/informer/listoption_indexer.go @@ -28,15 +28,19 @@ type ListOptionIndexer struct { namespaced bool indexedFields []string - addFieldQuery string - deleteFieldQuery string - upsertLabelsQuery string - deleteLabelsQuery string + addFieldsQuery string + deleteFieldsByKeyQuery string + deleteFieldsQuery string + upsertLabelsQuery string + deleteLabelsByKeyQuery string + deleteLabelsQuery string - addFieldStmt *sql.Stmt - deleteFieldStmt *sql.Stmt - upsertLabelsStmt *sql.Stmt - deleteLabelsStmt *sql.Stmt + addFieldsStmt *sql.Stmt + deleteFieldsByKeyStmt *sql.Stmt + deleteFieldsStmt *sql.Stmt + upsertLabelsStmt *sql.Stmt + deleteLabelsByKeyStmt *sql.Stmt + deleteLabelsStmt *sql.Stmt } var ( @@ -56,6 +60,7 @@ const ( %s )` createFieldsIndexFmt = `CREATE INDEX "%s_%s_index" ON "%s_fields"("%s")` + deleteFieldsFmt = `DELETE FROM "%s_fields"` failedToGetFromSliceFmt = "[listoption indexer] failed to get subfield [%s] from slice items" @@ -67,8 +72,9 @@ const ( )` createLabelsTableIndexFmt = `CREATE INDEX IF NOT EXISTS "%s_labels_index" ON "%s_labels"(label, value)` - upsertLabelsStmtFmt = `REPLACE INTO "%s_labels"(key, label, value) VALUES (?, ?, ?)` - deleteLabelsStmtFmt = `DELETE FROM "%s_labels" WHERE KEY = ?` + upsertLabelsStmtFmt = `REPLACE INTO "%s_labels"(key, label, value) VALUES (?, ?, ?)` + deleteLabelsByKeyStmtFmt = `DELETE FROM "%s_labels" WHERE KEY = ?` + deleteLabelsStmtFmt = `DELETE FROM "%s_labels"` ) // NewListOptionIndexer returns a SQLite-backed cache.Indexer of unstructured.Unstructured Kubernetes resources of a certain GVK @@ -104,8 +110,10 @@ func NewListOptionIndexer(ctx context.Context, fields [][]string, s Store, names l.RegisterAfterUpdate(l.addIndexFields) l.RegisterAfterAdd(l.addLabels) l.RegisterAfterUpdate(l.addLabels) - l.RegisterAfterDelete(l.deleteIndexFields) - l.RegisterAfterDelete(l.deleteLabels) + l.RegisterAfterDelete(l.deleteFieldsByKey) + l.RegisterAfterDelete(l.deleteLabelsByKey) + l.RegisterAfterDeleteAll(l.deleteFields) + l.RegisterAfterDeleteAll(l.deleteLabels) columnDefs := make([]string, len(indexedFields)) for index, field := range indexedFields { column := fmt.Sprintf(`"%s" TEXT`, field) @@ -159,21 +167,25 @@ func NewListOptionIndexer(ctx context.Context, fields [][]string, s Store, names return nil, err } - l.addFieldQuery = fmt.Sprintf( + l.addFieldsQuery = fmt.Sprintf( `INSERT INTO "%s_fields"(key, %s) VALUES (?, %s) ON CONFLICT DO UPDATE SET %s`, dbName, strings.Join(columns, ", "), strings.Join(qmarks, ", "), strings.Join(setStatements, ", "), ) - l.deleteFieldQuery = fmt.Sprintf(`DELETE FROM "%s_fields" WHERE key = ?`, dbName) + l.deleteFieldsByKeyQuery = fmt.Sprintf(`DELETE FROM "%s_fields" WHERE key = ?`, dbName) + l.deleteFieldsQuery = fmt.Sprintf(deleteFieldsFmt, dbName) - l.addFieldStmt = l.Prepare(l.addFieldQuery) - l.deleteFieldStmt = l.Prepare(l.deleteFieldQuery) + l.addFieldsStmt = l.Prepare(l.addFieldsQuery) + l.deleteFieldsByKeyStmt = l.Prepare(l.deleteFieldsByKeyQuery) + l.deleteFieldsStmt = l.Prepare(l.deleteFieldsQuery) l.upsertLabelsQuery = fmt.Sprintf(upsertLabelsStmtFmt, dbName) + l.deleteLabelsByKeyQuery = fmt.Sprintf(deleteLabelsByKeyStmtFmt, dbName) l.deleteLabelsQuery = fmt.Sprintf(deleteLabelsStmtFmt, dbName) l.upsertLabelsStmt = l.Prepare(l.upsertLabelsQuery) + l.deleteLabelsByKeyStmt = l.Prepare(l.deleteLabelsByKeyQuery) l.deleteLabelsStmt = l.Prepare(l.deleteLabelsQuery) return l, nil @@ -203,9 +215,9 @@ func (l *ListOptionIndexer) addIndexFields(key string, obj any, tx transaction.C } } - _, err := tx.Stmt(l.addFieldStmt).Exec(args...) + _, err := tx.Stmt(l.addFieldsStmt).Exec(args...) if err != nil { - return &db.QueryError{QueryString: l.addFieldQuery, Err: err} + return &db.QueryError{QueryString: l.addFieldsQuery, Err: err} } return nil } @@ -226,18 +238,34 @@ func (l *ListOptionIndexer) addLabels(key string, obj any, tx transaction.Client return nil } -func (l *ListOptionIndexer) deleteIndexFields(key string, tx transaction.Client) error { +func (l *ListOptionIndexer) deleteFieldsByKey(key string, _ any, tx transaction.Client) error { args := []any{key} - _, err := tx.Stmt(l.deleteFieldStmt).Exec(args...) + _, err := tx.Stmt(l.deleteFieldsByKeyStmt).Exec(args...) if err != nil { - return &db.QueryError{QueryString: l.deleteFieldQuery, Err: err} + return &db.QueryError{QueryString: l.deleteFieldsByKeyQuery, Err: err} } return nil } -func (l *ListOptionIndexer) deleteLabels(key string, tx transaction.Client) error { - _, err := tx.Stmt(l.deleteLabelsStmt).Exec(key) +func (l *ListOptionIndexer) deleteFields(tx transaction.Client) error { + _, err := tx.Stmt(l.deleteFieldsStmt).Exec() + if err != nil { + return &db.QueryError{QueryString: l.deleteFieldsQuery, Err: err} + } + return nil +} + +func (l *ListOptionIndexer) deleteLabelsByKey(key string, _ any, tx transaction.Client) error { + _, err := tx.Stmt(l.deleteLabelsByKeyStmt).Exec(key) + if err != nil { + return &db.QueryError{QueryString: l.deleteLabelsByKeyQuery, Err: err} + } + return nil +} + +func (l *ListOptionIndexer) deleteLabels(tx transaction.Client) error { + _, err := tx.Stmt(l.deleteLabelsStmt).Exec() if err != nil { return &db.QueryError{QueryString: l.deleteLabelsQuery, Err: err} } diff --git a/pkg/sqlcache/informer/listoption_indexer_test.go b/pkg/sqlcache/informer/listoption_indexer_test.go index 81b27e76..7df40c21 100644 --- a/pkg/sqlcache/informer/listoption_indexer_test.go +++ b/pkg/sqlcache/informer/listoption_indexer_test.go @@ -91,6 +91,7 @@ func TestNewListOptionIndexer(t *testing.T) { store.EXPECT().RegisterAfterAdd(gomock.Any()).Times(2) store.EXPECT().RegisterAfterUpdate(gomock.Any()).Times(2) store.EXPECT().RegisterAfterDelete(gomock.Any()).Times(2) + store.EXPECT().RegisterAfterDeleteAll(gomock.Any()).Times(2) // create field table txClient.EXPECT().Exec(fmt.Sprintf(createFieldsTableFmt, id, `"metadata.name" TEXT, "metadata.creationTimestamp" TEXT, "metadata.namespace" TEXT, "something" TEXT`)).Return(nil, nil) @@ -158,6 +159,7 @@ func TestNewListOptionIndexer(t *testing.T) { store.EXPECT().RegisterAfterAdd(gomock.Any()).Times(2) store.EXPECT().RegisterAfterUpdate(gomock.Any()).Times(2) store.EXPECT().RegisterAfterDelete(gomock.Any()).Times(2) + store.EXPECT().RegisterAfterDeleteAll(gomock.Any()).Times(2) store.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(fmt.Errorf("error")) @@ -189,6 +191,7 @@ func TestNewListOptionIndexer(t *testing.T) { store.EXPECT().RegisterAfterAdd(gomock.Any()).Times(2) store.EXPECT().RegisterAfterUpdate(gomock.Any()).Times(2) store.EXPECT().RegisterAfterDelete(gomock.Any()).Times(2) + store.EXPECT().RegisterAfterDeleteAll(gomock.Any()).Times(2) txClient.EXPECT().Exec(fmt.Sprintf(createFieldsTableFmt, id, `"metadata.name" TEXT, "metadata.creationTimestamp" TEXT, "metadata.namespace" TEXT, "something" TEXT`)).Return(nil, nil) txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, "metadata.name", id, "metadata.name")).Return(nil, fmt.Errorf("error")) @@ -228,6 +231,7 @@ func TestNewListOptionIndexer(t *testing.T) { store.EXPECT().RegisterAfterAdd(gomock.Any()).Times(2) store.EXPECT().RegisterAfterUpdate(gomock.Any()).Times(2) store.EXPECT().RegisterAfterDelete(gomock.Any()).Times(2) + store.EXPECT().RegisterAfterDeleteAll(gomock.Any()).Times(2) txClient.EXPECT().Exec(fmt.Sprintf(createFieldsTableFmt, id, `"metadata.name" TEXT, "metadata.creationTimestamp" TEXT, "metadata.namespace" TEXT, "something" TEXT`)).Return(nil, nil) txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, "metadata.name", id, "metadata.name")).Return(nil, nil) @@ -271,6 +275,7 @@ func TestNewListOptionIndexer(t *testing.T) { store.EXPECT().RegisterAfterAdd(gomock.Any()).Times(2) store.EXPECT().RegisterAfterUpdate(gomock.Any()).Times(2) store.EXPECT().RegisterAfterDelete(gomock.Any()).Times(2) + store.EXPECT().RegisterAfterDeleteAll(gomock.Any()).Times(2) txClient.EXPECT().Exec(fmt.Sprintf(createFieldsTableFmt, id, `"metadata.name" TEXT, "metadata.creationTimestamp" TEXT, "metadata.namespace" TEXT, "something" TEXT`)).Return(nil, nil) txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, "metadata.name", id, "metadata.name")).Return(nil, nil) diff --git a/pkg/sqlcache/informer/sql_mocks_test.go b/pkg/sqlcache/informer/sql_mocks_test.go index 91031baa..85336a12 100644 --- a/pkg/sqlcache/informer/sql_mocks_test.go +++ b/pkg/sqlcache/informer/sql_mocks_test.go @@ -292,7 +292,7 @@ func (mr *MockStoreMockRecorder) RegisterAfterAdd(arg0 any) *gomock.Call { } // RegisterAfterDelete mocks base method. -func (m *MockStore) RegisterAfterDelete(arg0 func(string, transaction.Client) error) { +func (m *MockStore) RegisterAfterDelete(arg0 func(string, any, transaction.Client) error) { m.ctrl.T.Helper() m.ctrl.Call(m, "RegisterAfterDelete", arg0) } @@ -303,6 +303,18 @@ func (mr *MockStoreMockRecorder) RegisterAfterDelete(arg0 any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterAfterDelete", reflect.TypeOf((*MockStore)(nil).RegisterAfterDelete), arg0) } +// RegisterAfterDeleteAll mocks base method. +func (m *MockStore) RegisterAfterDeleteAll(arg0 func(transaction.Client) error) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "RegisterAfterDeleteAll", arg0) +} + +// RegisterAfterDeleteAll indicates an expected call of RegisterAfterDeleteAll. +func (mr *MockStoreMockRecorder) RegisterAfterDeleteAll(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterAfterDeleteAll", reflect.TypeOf((*MockStore)(nil).RegisterAfterDeleteAll), arg0) +} + // RegisterAfterUpdate mocks base method. func (m *MockStore) RegisterAfterUpdate(arg0 func(string, any, transaction.Client) error) { m.ctrl.T.Helper() diff --git a/pkg/sqlcache/store/store.go b/pkg/sqlcache/store/store.go index 4143e989..d73b7354 100644 --- a/pkg/sqlcache/store/store.go +++ b/pkg/sqlcache/store/store.go @@ -19,12 +19,13 @@ import ( ) const ( - upsertStmtFmt = `REPLACE INTO "%s"(key, object, objectnonce, dekid) VALUES (?, ?, ?, ?)` - deleteStmtFmt = `DELETE FROM "%s" WHERE key = ?` - getStmtFmt = `SELECT object, objectnonce, dekid FROM "%s" WHERE key = ?` - listStmtFmt = `SELECT object, objectnonce, dekid FROM "%s"` - listKeysStmtFmt = `SELECT key FROM "%s"` - createTableFmt = `CREATE TABLE IF NOT EXISTS "%s" ( + upsertStmtFmt = `REPLACE INTO "%s"(key, object, objectnonce, dekid) VALUES (?, ?, ?, ?)` + deleteStmtFmt = `DELETE FROM "%s" WHERE key = ?` + deleteAllStmtFmt = `DELETE FROM "%s"` + getStmtFmt = `SELECT object, objectnonce, dekid FROM "%s" WHERE key = ?` + listStmtFmt = `SELECT object, objectnonce, dekid FROM "%s"` + listKeysStmtFmt = `SELECT key FROM "%s"` + createTableFmt = `CREATE TABLE IF NOT EXISTS "%s" ( key TEXT UNIQUE NOT NULL PRIMARY KEY, object BLOB, objectnonce BLOB, @@ -42,21 +43,24 @@ type Store struct { keyFunc cache.KeyFunc shouldEncrypt bool - upsertQuery string - deleteQuery string - getQuery string - listQuery string - listKeysQuery string + upsertQuery string + deleteQuery string + deleteAllQuery string + getQuery string + listQuery string + listKeysQuery string - upsertStmt *sql.Stmt - deleteStmt *sql.Stmt - getStmt *sql.Stmt - listStmt *sql.Stmt - listKeysStmt *sql.Stmt + upsertStmt *sql.Stmt + deleteStmt *sql.Stmt + deleteAllStmt *sql.Stmt + getStmt *sql.Stmt + listStmt *sql.Stmt + listKeysStmt *sql.Stmt - afterAdd []func(key string, obj any, tx transaction.Client) error - afterUpdate []func(key string, obj any, tx transaction.Client) error - afterDelete []func(key string, tx transaction.Client) error + afterAdd []func(key string, obj any, tx transaction.Client) error + afterUpdate []func(key string, obj any, tx transaction.Client) error + afterDelete []func(key string, obj any, tx transaction.Client) error + afterDeleteAll []func(tx transaction.Client) error } // Test that Store implements cache.Indexer @@ -65,15 +69,16 @@ var _ cache.Store = (*Store)(nil) // NewStore creates a SQLite-backed cache.Store for objects of the given example type func NewStore(ctx context.Context, example any, keyFunc cache.KeyFunc, c db.Client, shouldEncrypt bool, name string) (*Store, error) { s := &Store{ - ctx: ctx, - name: name, - typ: reflect.TypeOf(example), - Client: c, - keyFunc: keyFunc, - shouldEncrypt: shouldEncrypt, - afterAdd: []func(key string, obj any, tx transaction.Client) error{}, - afterUpdate: []func(key string, obj any, tx transaction.Client) error{}, - afterDelete: []func(key string, tx transaction.Client) error{}, + ctx: ctx, + name: name, + typ: reflect.TypeOf(example), + Client: c, + keyFunc: keyFunc, + shouldEncrypt: shouldEncrypt, + afterAdd: []func(key string, obj any, tx transaction.Client) error{}, + afterUpdate: []func(key string, obj any, tx transaction.Client) error{}, + afterDelete: []func(key string, obj any, tx transaction.Client) error{}, + afterDeleteAll: []func(tx transaction.Client) error{}, } dbName := db.Sanitize(s.name) @@ -94,12 +99,14 @@ func NewStore(ctx context.Context, example any, keyFunc cache.KeyFunc, c db.Clie s.upsertQuery = fmt.Sprintf(upsertStmtFmt, dbName) s.deleteQuery = fmt.Sprintf(deleteStmtFmt, dbName) + s.deleteAllQuery = fmt.Sprintf(deleteAllStmtFmt, dbName) s.getQuery = fmt.Sprintf(getStmtFmt, dbName) s.listQuery = fmt.Sprintf(listStmtFmt, dbName) s.listKeysQuery = fmt.Sprintf(listKeysStmtFmt, dbName) s.upsertStmt = s.Prepare(s.upsertQuery) s.deleteStmt = s.Prepare(s.deleteQuery) + s.deleteAllStmt = s.Prepare(s.deleteAllQuery) s.getStmt = s.Prepare(s.getQuery) s.listStmt = s.Prepare(s.listQuery) s.listKeysStmt = s.Prepare(s.listKeysQuery) @@ -110,14 +117,14 @@ func NewStore(ctx context.Context, example any, keyFunc cache.KeyFunc, c db.Clie /* Core methods */ // deleteByKey deletes the object associated with key, if it exists in this Store -func (s *Store) deleteByKey(key string) error { +func (s *Store) deleteByKey(key string, obj any) error { return s.WithTransaction(s.ctx, true, func(tx transaction.Client) error { _, err := tx.Stmt(s.deleteStmt).Exec(key) if err != nil { return &db.QueryError{QueryString: s.deleteQuery, Err: err} } - err = s.runAfterDelete(key, tx) + err = s.runAfterDelete(key, obj, tx) if err != nil { return err } @@ -206,7 +213,7 @@ func (s *Store) Delete(obj any) error { if err != nil { return err } - err = s.deleteByKey(key) + err = s.deleteByKey(key, obj) if err != nil { log.Errorf("Error in Store.Delete for type %v: %v", s.name, err) return err @@ -266,32 +273,25 @@ func (s *Store) Replace(objects []any, _ string) error { } objectMap[key] = object } - return s.replaceByKey(objectMap) + err := s.replaceByKey(objectMap) + if err != nil { + log.Errorf("Error in Store.Replace for type %v: %v", s.name, err) + return err + } + return nil } // replaceByKey will delete the contents of the Store, using instead the given key to obj map func (s *Store) replaceByKey(objects map[string]any) error { return s.WithTransaction(s.ctx, true, func(txC transaction.Client) error { - txCListKeys := txC.Stmt(s.listKeysStmt) - - rows, err := s.QueryForRows(s.ctx, txCListKeys) + _, err := txC.Stmt(s.deleteAllStmt).Exec() if err != nil { - return err - } - keys, err := s.ReadStrings(rows) - if err != nil { - return err + return &db.QueryError{QueryString: s.deleteAllQuery, Err: err} } - for _, key := range keys { - _, err = txC.Stmt(s.deleteStmt).Exec(key) - if err != nil { - return err - } - err = s.runAfterDelete(key, txC) - if err != nil { - return err - } + err = s.runAfterDeleteAll(txC) + if err != nil { + return err } for key, obj := range objects { @@ -339,10 +339,15 @@ func (s *Store) RegisterAfterUpdate(f func(key string, obj any, txC transaction. } // RegisterAfterDelete registers a func to be called after each deletion -func (s *Store) RegisterAfterDelete(f func(key string, txC transaction.Client) error) { +func (s *Store) RegisterAfterDelete(f func(key string, obj any, txC transaction.Client) error) { s.afterDelete = append(s.afterDelete, f) } +// RegisterAfterDelete registers a func to be called after each deletion +func (s *Store) RegisterAfterDeleteAll(f func(txC transaction.Client) error) { + s.afterDeleteAll = append(s.afterDeleteAll, f) +} + // runAfterAdd executes functions registered to run after add event func (s *Store) runAfterAdd(key string, obj any, txC transaction.Client) error { for _, f := range s.afterAdd { @@ -366,9 +371,21 @@ func (s *Store) runAfterUpdate(key string, obj any, txC transaction.Client) erro } // runAfterDelete executes functions registered to run after delete event -func (s *Store) runAfterDelete(key string, txC transaction.Client) error { +func (s *Store) runAfterDelete(key string, obj any, txC transaction.Client) error { for _, f := range s.afterDelete { - err := f(key, txC) + err := f(key, obj, txC) + if err != nil { + return err + } + } + return nil +} + +// runAfterDeleteAll executes functions registered to run after delete events when +// the database is being replaced. +func (s *Store) runAfterDeleteAll(txC transaction.Client) error { + for _, f := range s.afterDeleteAll { + err := f(txC) if err != nil { return err } diff --git a/pkg/sqlcache/store/store_test.go b/pkg/sqlcache/store/store_test.go index fc3a8ae6..0e16dc70 100644 --- a/pkg/sqlcache/store/store_test.go +++ b/pkg/sqlcache/store/store_test.go @@ -550,14 +550,10 @@ func TestReplace(t *testing.T) { tests = append(tests, testCase{description: "Replace with no DB client errors and some items", test: func(t *testing.T, shouldEncrypt bool) { c, txC := SetupMockDB(t) store := SetupStore(t, c, shouldEncrypt) - r := &sql.Rows{} stmt := NewMockStmt(gomock.NewController(t)) - txC.EXPECT().Stmt(store.listKeysStmt).Return(stmt) - c.EXPECT().QueryForRows(context.Background(), stmt).Return(r, nil) - c.EXPECT().ReadStrings(r).Return([]string{testObject.Id}, nil) - txC.EXPECT().Stmt(store.deleteStmt).Return(stmt) - stmt.EXPECT().Exec(testObject.Id) + txC.EXPECT().Stmt(store.deleteAllStmt).Return(stmt) + stmt.EXPECT().Exec() c.EXPECT().Upsert(txC, store.upsertStmt, testObject.Id, testObject, store.shouldEncrypt) c.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(nil).Do( @@ -575,11 +571,10 @@ func TestReplace(t *testing.T) { tests = append(tests, testCase{description: "Replace with no DB client errors and no items", test: func(t *testing.T, shouldEncrypt bool) { c, txC := SetupMockDB(t) store := SetupStore(t, c, shouldEncrypt) - r := &sql.Rows{} - txC.EXPECT().Stmt(store.listKeysStmt).Return(store.listKeysStmt) - c.EXPECT().QueryForRows(context.Background(), store.listKeysStmt).Return(r, nil) - c.EXPECT().ReadStrings(r).Return([]string{}, nil) + stmt := NewMockStmt(gomock.NewController(t)) + txC.EXPECT().Stmt(store.deleteAllStmt).Return(stmt) + stmt.EXPECT().Exec() c.EXPECT().Upsert(txC, store.upsertStmt, testObject.Id, testObject, store.shouldEncrypt) c.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(nil).Do( @@ -602,39 +597,15 @@ func TestReplace(t *testing.T) { assert.NotNil(t, err) }, }) - tests = append(tests, testCase{description: "Replace with DB client ReadStrings() error", test: func(t *testing.T, shouldEncrypt bool) { + tests = append(tests, testCase{description: "Replace with DB client deleteAllStmt error", test: func(t *testing.T, shouldEncrypt bool) { c, txC := SetupMockDB(t) store := SetupStore(t, c, shouldEncrypt) - r := &sql.Rows{} - txC.EXPECT().Stmt(store.listKeysStmt).Return(store.listKeysStmt) - c.EXPECT().QueryForRows(context.Background(), store.listKeysStmt).Return(r, nil) - c.EXPECT().ReadStrings(r).Return(nil, fmt.Errorf("error")) + deleteAllStmt := NewMockStmt(gomock.NewController(t)) - c.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(fmt.Errorf("error")).Do( - func(ctx context.Context, shouldEncrypt bool, f db.WithTransactionFunction) { - err := f(txC) - if err == nil { - t.Fail() - } - }) + txC.EXPECT().Stmt(store.deleteAllStmt).Return(deleteAllStmt) + deleteAllStmt.EXPECT().Exec().Return(nil, fmt.Errorf("error")) - err := store.Replace([]any{testObject}, testObject.Id) - assert.NotNil(t, err) - }, - }) - tests = append(tests, testCase{description: "Replace with TX client StmtExec() error", test: func(t *testing.T, shouldEncrypt bool) { - c, txC := SetupMockDB(t) - store := SetupStore(t, c, shouldEncrypt) - r := &sql.Rows{} - listKeysStmt := NewMockStmt(gomock.NewController(t)) - deleteStmt := NewMockStmt(gomock.NewController(t)) - - txC.EXPECT().Stmt(store.listKeysStmt).Return(listKeysStmt) - c.EXPECT().QueryForRows(context.Background(), listKeysStmt).Return(r, nil) - c.EXPECT().ReadStrings(r).Return([]string{testObject.Id}, nil) - txC.EXPECT().Stmt(store.deleteStmt).Return(deleteStmt) - deleteStmt.EXPECT().Exec(testObject.Id).Return(nil, fmt.Errorf("error")) c.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(fmt.Errorf("error")).Do( func(ctx context.Context, shouldEncrypt bool, f db.WithTransactionFunction) { err := f(txC) @@ -650,15 +621,10 @@ func TestReplace(t *testing.T) { tests = append(tests, testCase{description: "Replace with DB client Upsert() error", test: func(t *testing.T, shouldEncrypt bool) { c, txC := SetupMockDB(t) store := SetupStore(t, c, shouldEncrypt) - r := &sql.Rows{} - listKeysStmt := NewMockStmt(gomock.NewController(t)) - deleteStmt := NewMockStmt(gomock.NewController(t)) + deleteAllStmt := NewMockStmt(gomock.NewController(t)) - txC.EXPECT().Stmt(store.listKeysStmt).Return(listKeysStmt) - c.EXPECT().QueryForRows(context.Background(), listKeysStmt).Return(r, nil) - c.EXPECT().ReadStrings(r).Return([]string{testObject.Id}, nil) - txC.EXPECT().Stmt(store.deleteStmt).Return(deleteStmt) - deleteStmt.EXPECT().Exec(testObject.Id).Return(nil, nil) + txC.EXPECT().Stmt(store.deleteAllStmt).Return(deleteAllStmt) + deleteAllStmt.EXPECT().Exec() c.EXPECT().Upsert(txC, store.upsertStmt, testObject.Id, testObject, store.shouldEncrypt).Return(fmt.Errorf("error")) c.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(fmt.Errorf("error")).Do( @@ -719,6 +685,7 @@ func SetupMockDB(t *testing.T) (*MockClient, *MockTXClient) { // use stmt mock here dbC.EXPECT().Prepare(fmt.Sprintf(upsertStmtFmt, "testStoreObject")).Return(&sql.Stmt{}) dbC.EXPECT().Prepare(fmt.Sprintf(deleteStmtFmt, "testStoreObject")).Return(&sql.Stmt{}) + dbC.EXPECT().Prepare(fmt.Sprintf(deleteAllStmtFmt, "testStoreObject")).Return(&sql.Stmt{}) dbC.EXPECT().Prepare(fmt.Sprintf(getStmtFmt, "testStoreObject")).Return(&sql.Stmt{}) dbC.EXPECT().Prepare(fmt.Sprintf(listStmtFmt, "testStoreObject")).Return(&sql.Stmt{}) dbC.EXPECT().Prepare(fmt.Sprintf(listKeysStmtFmt, "testStoreObject")).Return(&sql.Stmt{})