1
0
mirror of https://github.com/rancher/steve.git synced 2025-08-31 06:46:25 +00:00

Add object to RegisterAfterDelete and introduce RegisterAfterDeleteAll (#649)

* Add object to AfterDelete callbacks

* Add RegisterAfterDeleteAll
This commit is contained in:
Tom Lebreux
2025-06-03 15:32:43 -06:00
committed by GitHub
parent a8f3ce48d6
commit 2672969496
6 changed files with 153 additions and 123 deletions

View File

@@ -72,7 +72,8 @@ type Store interface {
GetName() string
RegisterAfterAdd(f func(key string, obj any, tx transaction.Client) error)
RegisterAfterUpdate(f func(key string, obj any, tx transaction.Client) error)
RegisterAfterDelete(f func(key string, tx transaction.Client) error)
RegisterAfterDelete(f func(key string, obj any, tx transaction.Client) error)
RegisterAfterDeleteAll(f func(tx transaction.Client) error)
GetShouldEncrypt() bool
GetType() reflect.Type
}

View File

@@ -28,15 +28,19 @@ type ListOptionIndexer struct {
namespaced bool
indexedFields []string
addFieldQuery string
deleteFieldQuery string
upsertLabelsQuery string
deleteLabelsQuery string
addFieldsQuery string
deleteFieldsByKeyQuery string
deleteFieldsQuery string
upsertLabelsQuery string
deleteLabelsByKeyQuery string
deleteLabelsQuery string
addFieldStmt *sql.Stmt
deleteFieldStmt *sql.Stmt
upsertLabelsStmt *sql.Stmt
deleteLabelsStmt *sql.Stmt
addFieldsStmt *sql.Stmt
deleteFieldsByKeyStmt *sql.Stmt
deleteFieldsStmt *sql.Stmt
upsertLabelsStmt *sql.Stmt
deleteLabelsByKeyStmt *sql.Stmt
deleteLabelsStmt *sql.Stmt
}
var (
@@ -56,6 +60,7 @@ const (
%s
)`
createFieldsIndexFmt = `CREATE INDEX "%s_%s_index" ON "%s_fields"("%s")`
deleteFieldsFmt = `DELETE FROM "%s_fields"`
failedToGetFromSliceFmt = "[listoption indexer] failed to get subfield [%s] from slice items"
@@ -67,8 +72,9 @@ const (
)`
createLabelsTableIndexFmt = `CREATE INDEX IF NOT EXISTS "%s_labels_index" ON "%s_labels"(label, value)`
upsertLabelsStmtFmt = `REPLACE INTO "%s_labels"(key, label, value) VALUES (?, ?, ?)`
deleteLabelsStmtFmt = `DELETE FROM "%s_labels" WHERE KEY = ?`
upsertLabelsStmtFmt = `REPLACE INTO "%s_labels"(key, label, value) VALUES (?, ?, ?)`
deleteLabelsByKeyStmtFmt = `DELETE FROM "%s_labels" WHERE KEY = ?`
deleteLabelsStmtFmt = `DELETE FROM "%s_labels"`
)
// NewListOptionIndexer returns a SQLite-backed cache.Indexer of unstructured.Unstructured Kubernetes resources of a certain GVK
@@ -104,8 +110,10 @@ func NewListOptionIndexer(ctx context.Context, fields [][]string, s Store, names
l.RegisterAfterUpdate(l.addIndexFields)
l.RegisterAfterAdd(l.addLabels)
l.RegisterAfterUpdate(l.addLabels)
l.RegisterAfterDelete(l.deleteIndexFields)
l.RegisterAfterDelete(l.deleteLabels)
l.RegisterAfterDelete(l.deleteFieldsByKey)
l.RegisterAfterDelete(l.deleteLabelsByKey)
l.RegisterAfterDeleteAll(l.deleteFields)
l.RegisterAfterDeleteAll(l.deleteLabels)
columnDefs := make([]string, len(indexedFields))
for index, field := range indexedFields {
column := fmt.Sprintf(`"%s" TEXT`, field)
@@ -159,21 +167,25 @@ func NewListOptionIndexer(ctx context.Context, fields [][]string, s Store, names
return nil, err
}
l.addFieldQuery = fmt.Sprintf(
l.addFieldsQuery = fmt.Sprintf(
`INSERT INTO "%s_fields"(key, %s) VALUES (?, %s) ON CONFLICT DO UPDATE SET %s`,
dbName,
strings.Join(columns, ", "),
strings.Join(qmarks, ", "),
strings.Join(setStatements, ", "),
)
l.deleteFieldQuery = fmt.Sprintf(`DELETE FROM "%s_fields" WHERE key = ?`, dbName)
l.deleteFieldsByKeyQuery = fmt.Sprintf(`DELETE FROM "%s_fields" WHERE key = ?`, dbName)
l.deleteFieldsQuery = fmt.Sprintf(deleteFieldsFmt, dbName)
l.addFieldStmt = l.Prepare(l.addFieldQuery)
l.deleteFieldStmt = l.Prepare(l.deleteFieldQuery)
l.addFieldsStmt = l.Prepare(l.addFieldsQuery)
l.deleteFieldsByKeyStmt = l.Prepare(l.deleteFieldsByKeyQuery)
l.deleteFieldsStmt = l.Prepare(l.deleteFieldsQuery)
l.upsertLabelsQuery = fmt.Sprintf(upsertLabelsStmtFmt, dbName)
l.deleteLabelsByKeyQuery = fmt.Sprintf(deleteLabelsByKeyStmtFmt, dbName)
l.deleteLabelsQuery = fmt.Sprintf(deleteLabelsStmtFmt, dbName)
l.upsertLabelsStmt = l.Prepare(l.upsertLabelsQuery)
l.deleteLabelsByKeyStmt = l.Prepare(l.deleteLabelsByKeyQuery)
l.deleteLabelsStmt = l.Prepare(l.deleteLabelsQuery)
return l, nil
@@ -203,9 +215,9 @@ func (l *ListOptionIndexer) addIndexFields(key string, obj any, tx transaction.C
}
}
_, err := tx.Stmt(l.addFieldStmt).Exec(args...)
_, err := tx.Stmt(l.addFieldsStmt).Exec(args...)
if err != nil {
return &db.QueryError{QueryString: l.addFieldQuery, Err: err}
return &db.QueryError{QueryString: l.addFieldsQuery, Err: err}
}
return nil
}
@@ -226,18 +238,34 @@ func (l *ListOptionIndexer) addLabels(key string, obj any, tx transaction.Client
return nil
}
func (l *ListOptionIndexer) deleteIndexFields(key string, tx transaction.Client) error {
func (l *ListOptionIndexer) deleteFieldsByKey(key string, _ any, tx transaction.Client) error {
args := []any{key}
_, err := tx.Stmt(l.deleteFieldStmt).Exec(args...)
_, err := tx.Stmt(l.deleteFieldsByKeyStmt).Exec(args...)
if err != nil {
return &db.QueryError{QueryString: l.deleteFieldQuery, Err: err}
return &db.QueryError{QueryString: l.deleteFieldsByKeyQuery, Err: err}
}
return nil
}
func (l *ListOptionIndexer) deleteLabels(key string, tx transaction.Client) error {
_, err := tx.Stmt(l.deleteLabelsStmt).Exec(key)
func (l *ListOptionIndexer) deleteFields(tx transaction.Client) error {
_, err := tx.Stmt(l.deleteFieldsStmt).Exec()
if err != nil {
return &db.QueryError{QueryString: l.deleteFieldsQuery, Err: err}
}
return nil
}
func (l *ListOptionIndexer) deleteLabelsByKey(key string, _ any, tx transaction.Client) error {
_, err := tx.Stmt(l.deleteLabelsByKeyStmt).Exec(key)
if err != nil {
return &db.QueryError{QueryString: l.deleteLabelsByKeyQuery, Err: err}
}
return nil
}
func (l *ListOptionIndexer) deleteLabels(tx transaction.Client) error {
_, err := tx.Stmt(l.deleteLabelsStmt).Exec()
if err != nil {
return &db.QueryError{QueryString: l.deleteLabelsQuery, Err: err}
}

View File

@@ -91,6 +91,7 @@ func TestNewListOptionIndexer(t *testing.T) {
store.EXPECT().RegisterAfterAdd(gomock.Any()).Times(2)
store.EXPECT().RegisterAfterUpdate(gomock.Any()).Times(2)
store.EXPECT().RegisterAfterDelete(gomock.Any()).Times(2)
store.EXPECT().RegisterAfterDeleteAll(gomock.Any()).Times(2)
// create field table
txClient.EXPECT().Exec(fmt.Sprintf(createFieldsTableFmt, id, `"metadata.name" TEXT, "metadata.creationTimestamp" TEXT, "metadata.namespace" TEXT, "something" TEXT`)).Return(nil, nil)
@@ -158,6 +159,7 @@ func TestNewListOptionIndexer(t *testing.T) {
store.EXPECT().RegisterAfterAdd(gomock.Any()).Times(2)
store.EXPECT().RegisterAfterUpdate(gomock.Any()).Times(2)
store.EXPECT().RegisterAfterDelete(gomock.Any()).Times(2)
store.EXPECT().RegisterAfterDeleteAll(gomock.Any()).Times(2)
store.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(fmt.Errorf("error"))
@@ -189,6 +191,7 @@ func TestNewListOptionIndexer(t *testing.T) {
store.EXPECT().RegisterAfterAdd(gomock.Any()).Times(2)
store.EXPECT().RegisterAfterUpdate(gomock.Any()).Times(2)
store.EXPECT().RegisterAfterDelete(gomock.Any()).Times(2)
store.EXPECT().RegisterAfterDeleteAll(gomock.Any()).Times(2)
txClient.EXPECT().Exec(fmt.Sprintf(createFieldsTableFmt, id, `"metadata.name" TEXT, "metadata.creationTimestamp" TEXT, "metadata.namespace" TEXT, "something" TEXT`)).Return(nil, nil)
txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, "metadata.name", id, "metadata.name")).Return(nil, fmt.Errorf("error"))
@@ -228,6 +231,7 @@ func TestNewListOptionIndexer(t *testing.T) {
store.EXPECT().RegisterAfterAdd(gomock.Any()).Times(2)
store.EXPECT().RegisterAfterUpdate(gomock.Any()).Times(2)
store.EXPECT().RegisterAfterDelete(gomock.Any()).Times(2)
store.EXPECT().RegisterAfterDeleteAll(gomock.Any()).Times(2)
txClient.EXPECT().Exec(fmt.Sprintf(createFieldsTableFmt, id, `"metadata.name" TEXT, "metadata.creationTimestamp" TEXT, "metadata.namespace" TEXT, "something" TEXT`)).Return(nil, nil)
txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, "metadata.name", id, "metadata.name")).Return(nil, nil)
@@ -271,6 +275,7 @@ func TestNewListOptionIndexer(t *testing.T) {
store.EXPECT().RegisterAfterAdd(gomock.Any()).Times(2)
store.EXPECT().RegisterAfterUpdate(gomock.Any()).Times(2)
store.EXPECT().RegisterAfterDelete(gomock.Any()).Times(2)
store.EXPECT().RegisterAfterDeleteAll(gomock.Any()).Times(2)
txClient.EXPECT().Exec(fmt.Sprintf(createFieldsTableFmt, id, `"metadata.name" TEXT, "metadata.creationTimestamp" TEXT, "metadata.namespace" TEXT, "something" TEXT`)).Return(nil, nil)
txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, "metadata.name", id, "metadata.name")).Return(nil, nil)

View File

@@ -292,7 +292,7 @@ func (mr *MockStoreMockRecorder) RegisterAfterAdd(arg0 any) *gomock.Call {
}
// RegisterAfterDelete mocks base method.
func (m *MockStore) RegisterAfterDelete(arg0 func(string, transaction.Client) error) {
func (m *MockStore) RegisterAfterDelete(arg0 func(string, any, transaction.Client) error) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "RegisterAfterDelete", arg0)
}
@@ -303,6 +303,18 @@ func (mr *MockStoreMockRecorder) RegisterAfterDelete(arg0 any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterAfterDelete", reflect.TypeOf((*MockStore)(nil).RegisterAfterDelete), arg0)
}
// RegisterAfterDeleteAll mocks base method.
func (m *MockStore) RegisterAfterDeleteAll(arg0 func(transaction.Client) error) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "RegisterAfterDeleteAll", arg0)
}
// RegisterAfterDeleteAll indicates an expected call of RegisterAfterDeleteAll.
func (mr *MockStoreMockRecorder) RegisterAfterDeleteAll(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterAfterDeleteAll", reflect.TypeOf((*MockStore)(nil).RegisterAfterDeleteAll), arg0)
}
// RegisterAfterUpdate mocks base method.
func (m *MockStore) RegisterAfterUpdate(arg0 func(string, any, transaction.Client) error) {
m.ctrl.T.Helper()

View File

@@ -19,12 +19,13 @@ import (
)
const (
upsertStmtFmt = `REPLACE INTO "%s"(key, object, objectnonce, dekid) VALUES (?, ?, ?, ?)`
deleteStmtFmt = `DELETE FROM "%s" WHERE key = ?`
getStmtFmt = `SELECT object, objectnonce, dekid FROM "%s" WHERE key = ?`
listStmtFmt = `SELECT object, objectnonce, dekid FROM "%s"`
listKeysStmtFmt = `SELECT key FROM "%s"`
createTableFmt = `CREATE TABLE IF NOT EXISTS "%s" (
upsertStmtFmt = `REPLACE INTO "%s"(key, object, objectnonce, dekid) VALUES (?, ?, ?, ?)`
deleteStmtFmt = `DELETE FROM "%s" WHERE key = ?`
deleteAllStmtFmt = `DELETE FROM "%s"`
getStmtFmt = `SELECT object, objectnonce, dekid FROM "%s" WHERE key = ?`
listStmtFmt = `SELECT object, objectnonce, dekid FROM "%s"`
listKeysStmtFmt = `SELECT key FROM "%s"`
createTableFmt = `CREATE TABLE IF NOT EXISTS "%s" (
key TEXT UNIQUE NOT NULL PRIMARY KEY,
object BLOB,
objectnonce BLOB,
@@ -42,21 +43,24 @@ type Store struct {
keyFunc cache.KeyFunc
shouldEncrypt bool
upsertQuery string
deleteQuery string
getQuery string
listQuery string
listKeysQuery string
upsertQuery string
deleteQuery string
deleteAllQuery string
getQuery string
listQuery string
listKeysQuery string
upsertStmt *sql.Stmt
deleteStmt *sql.Stmt
getStmt *sql.Stmt
listStmt *sql.Stmt
listKeysStmt *sql.Stmt
upsertStmt *sql.Stmt
deleteStmt *sql.Stmt
deleteAllStmt *sql.Stmt
getStmt *sql.Stmt
listStmt *sql.Stmt
listKeysStmt *sql.Stmt
afterAdd []func(key string, obj any, tx transaction.Client) error
afterUpdate []func(key string, obj any, tx transaction.Client) error
afterDelete []func(key string, tx transaction.Client) error
afterAdd []func(key string, obj any, tx transaction.Client) error
afterUpdate []func(key string, obj any, tx transaction.Client) error
afterDelete []func(key string, obj any, tx transaction.Client) error
afterDeleteAll []func(tx transaction.Client) error
}
// Test that Store implements cache.Indexer
@@ -65,15 +69,16 @@ var _ cache.Store = (*Store)(nil)
// NewStore creates a SQLite-backed cache.Store for objects of the given example type
func NewStore(ctx context.Context, example any, keyFunc cache.KeyFunc, c db.Client, shouldEncrypt bool, name string) (*Store, error) {
s := &Store{
ctx: ctx,
name: name,
typ: reflect.TypeOf(example),
Client: c,
keyFunc: keyFunc,
shouldEncrypt: shouldEncrypt,
afterAdd: []func(key string, obj any, tx transaction.Client) error{},
afterUpdate: []func(key string, obj any, tx transaction.Client) error{},
afterDelete: []func(key string, tx transaction.Client) error{},
ctx: ctx,
name: name,
typ: reflect.TypeOf(example),
Client: c,
keyFunc: keyFunc,
shouldEncrypt: shouldEncrypt,
afterAdd: []func(key string, obj any, tx transaction.Client) error{},
afterUpdate: []func(key string, obj any, tx transaction.Client) error{},
afterDelete: []func(key string, obj any, tx transaction.Client) error{},
afterDeleteAll: []func(tx transaction.Client) error{},
}
dbName := db.Sanitize(s.name)
@@ -94,12 +99,14 @@ func NewStore(ctx context.Context, example any, keyFunc cache.KeyFunc, c db.Clie
s.upsertQuery = fmt.Sprintf(upsertStmtFmt, dbName)
s.deleteQuery = fmt.Sprintf(deleteStmtFmt, dbName)
s.deleteAllQuery = fmt.Sprintf(deleteAllStmtFmt, dbName)
s.getQuery = fmt.Sprintf(getStmtFmt, dbName)
s.listQuery = fmt.Sprintf(listStmtFmt, dbName)
s.listKeysQuery = fmt.Sprintf(listKeysStmtFmt, dbName)
s.upsertStmt = s.Prepare(s.upsertQuery)
s.deleteStmt = s.Prepare(s.deleteQuery)
s.deleteAllStmt = s.Prepare(s.deleteAllQuery)
s.getStmt = s.Prepare(s.getQuery)
s.listStmt = s.Prepare(s.listQuery)
s.listKeysStmt = s.Prepare(s.listKeysQuery)
@@ -110,14 +117,14 @@ func NewStore(ctx context.Context, example any, keyFunc cache.KeyFunc, c db.Clie
/* Core methods */
// deleteByKey deletes the object associated with key, if it exists in this Store
func (s *Store) deleteByKey(key string) error {
func (s *Store) deleteByKey(key string, obj any) error {
return s.WithTransaction(s.ctx, true, func(tx transaction.Client) error {
_, err := tx.Stmt(s.deleteStmt).Exec(key)
if err != nil {
return &db.QueryError{QueryString: s.deleteQuery, Err: err}
}
err = s.runAfterDelete(key, tx)
err = s.runAfterDelete(key, obj, tx)
if err != nil {
return err
}
@@ -206,7 +213,7 @@ func (s *Store) Delete(obj any) error {
if err != nil {
return err
}
err = s.deleteByKey(key)
err = s.deleteByKey(key, obj)
if err != nil {
log.Errorf("Error in Store.Delete for type %v: %v", s.name, err)
return err
@@ -266,32 +273,25 @@ func (s *Store) Replace(objects []any, _ string) error {
}
objectMap[key] = object
}
return s.replaceByKey(objectMap)
err := s.replaceByKey(objectMap)
if err != nil {
log.Errorf("Error in Store.Replace for type %v: %v", s.name, err)
return err
}
return nil
}
// replaceByKey will delete the contents of the Store, using instead the given key to obj map
func (s *Store) replaceByKey(objects map[string]any) error {
return s.WithTransaction(s.ctx, true, func(txC transaction.Client) error {
txCListKeys := txC.Stmt(s.listKeysStmt)
rows, err := s.QueryForRows(s.ctx, txCListKeys)
_, err := txC.Stmt(s.deleteAllStmt).Exec()
if err != nil {
return err
}
keys, err := s.ReadStrings(rows)
if err != nil {
return err
return &db.QueryError{QueryString: s.deleteAllQuery, Err: err}
}
for _, key := range keys {
_, err = txC.Stmt(s.deleteStmt).Exec(key)
if err != nil {
return err
}
err = s.runAfterDelete(key, txC)
if err != nil {
return err
}
err = s.runAfterDeleteAll(txC)
if err != nil {
return err
}
for key, obj := range objects {
@@ -339,10 +339,15 @@ func (s *Store) RegisterAfterUpdate(f func(key string, obj any, txC transaction.
}
// RegisterAfterDelete registers a func to be called after each deletion
func (s *Store) RegisterAfterDelete(f func(key string, txC transaction.Client) error) {
func (s *Store) RegisterAfterDelete(f func(key string, obj any, txC transaction.Client) error) {
s.afterDelete = append(s.afterDelete, f)
}
// RegisterAfterDelete registers a func to be called after each deletion
func (s *Store) RegisterAfterDeleteAll(f func(txC transaction.Client) error) {
s.afterDeleteAll = append(s.afterDeleteAll, f)
}
// runAfterAdd executes functions registered to run after add event
func (s *Store) runAfterAdd(key string, obj any, txC transaction.Client) error {
for _, f := range s.afterAdd {
@@ -366,9 +371,21 @@ func (s *Store) runAfterUpdate(key string, obj any, txC transaction.Client) erro
}
// runAfterDelete executes functions registered to run after delete event
func (s *Store) runAfterDelete(key string, txC transaction.Client) error {
func (s *Store) runAfterDelete(key string, obj any, txC transaction.Client) error {
for _, f := range s.afterDelete {
err := f(key, txC)
err := f(key, obj, txC)
if err != nil {
return err
}
}
return nil
}
// runAfterDeleteAll executes functions registered to run after delete events when
// the database is being replaced.
func (s *Store) runAfterDeleteAll(txC transaction.Client) error {
for _, f := range s.afterDeleteAll {
err := f(txC)
if err != nil {
return err
}

View File

@@ -550,14 +550,10 @@ func TestReplace(t *testing.T) {
tests = append(tests, testCase{description: "Replace with no DB client errors and some items", test: func(t *testing.T, shouldEncrypt bool) {
c, txC := SetupMockDB(t)
store := SetupStore(t, c, shouldEncrypt)
r := &sql.Rows{}
stmt := NewMockStmt(gomock.NewController(t))
txC.EXPECT().Stmt(store.listKeysStmt).Return(stmt)
c.EXPECT().QueryForRows(context.Background(), stmt).Return(r, nil)
c.EXPECT().ReadStrings(r).Return([]string{testObject.Id}, nil)
txC.EXPECT().Stmt(store.deleteStmt).Return(stmt)
stmt.EXPECT().Exec(testObject.Id)
txC.EXPECT().Stmt(store.deleteAllStmt).Return(stmt)
stmt.EXPECT().Exec()
c.EXPECT().Upsert(txC, store.upsertStmt, testObject.Id, testObject, store.shouldEncrypt)
c.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(nil).Do(
@@ -575,11 +571,10 @@ func TestReplace(t *testing.T) {
tests = append(tests, testCase{description: "Replace with no DB client errors and no items", test: func(t *testing.T, shouldEncrypt bool) {
c, txC := SetupMockDB(t)
store := SetupStore(t, c, shouldEncrypt)
r := &sql.Rows{}
txC.EXPECT().Stmt(store.listKeysStmt).Return(store.listKeysStmt)
c.EXPECT().QueryForRows(context.Background(), store.listKeysStmt).Return(r, nil)
c.EXPECT().ReadStrings(r).Return([]string{}, nil)
stmt := NewMockStmt(gomock.NewController(t))
txC.EXPECT().Stmt(store.deleteAllStmt).Return(stmt)
stmt.EXPECT().Exec()
c.EXPECT().Upsert(txC, store.upsertStmt, testObject.Id, testObject, store.shouldEncrypt)
c.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(nil).Do(
@@ -602,39 +597,15 @@ func TestReplace(t *testing.T) {
assert.NotNil(t, err)
},
})
tests = append(tests, testCase{description: "Replace with DB client ReadStrings() error", test: func(t *testing.T, shouldEncrypt bool) {
tests = append(tests, testCase{description: "Replace with DB client deleteAllStmt error", test: func(t *testing.T, shouldEncrypt bool) {
c, txC := SetupMockDB(t)
store := SetupStore(t, c, shouldEncrypt)
r := &sql.Rows{}
txC.EXPECT().Stmt(store.listKeysStmt).Return(store.listKeysStmt)
c.EXPECT().QueryForRows(context.Background(), store.listKeysStmt).Return(r, nil)
c.EXPECT().ReadStrings(r).Return(nil, fmt.Errorf("error"))
deleteAllStmt := NewMockStmt(gomock.NewController(t))
c.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(fmt.Errorf("error")).Do(
func(ctx context.Context, shouldEncrypt bool, f db.WithTransactionFunction) {
err := f(txC)
if err == nil {
t.Fail()
}
})
txC.EXPECT().Stmt(store.deleteAllStmt).Return(deleteAllStmt)
deleteAllStmt.EXPECT().Exec().Return(nil, fmt.Errorf("error"))
err := store.Replace([]any{testObject}, testObject.Id)
assert.NotNil(t, err)
},
})
tests = append(tests, testCase{description: "Replace with TX client StmtExec() error", test: func(t *testing.T, shouldEncrypt bool) {
c, txC := SetupMockDB(t)
store := SetupStore(t, c, shouldEncrypt)
r := &sql.Rows{}
listKeysStmt := NewMockStmt(gomock.NewController(t))
deleteStmt := NewMockStmt(gomock.NewController(t))
txC.EXPECT().Stmt(store.listKeysStmt).Return(listKeysStmt)
c.EXPECT().QueryForRows(context.Background(), listKeysStmt).Return(r, nil)
c.EXPECT().ReadStrings(r).Return([]string{testObject.Id}, nil)
txC.EXPECT().Stmt(store.deleteStmt).Return(deleteStmt)
deleteStmt.EXPECT().Exec(testObject.Id).Return(nil, fmt.Errorf("error"))
c.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(fmt.Errorf("error")).Do(
func(ctx context.Context, shouldEncrypt bool, f db.WithTransactionFunction) {
err := f(txC)
@@ -650,15 +621,10 @@ func TestReplace(t *testing.T) {
tests = append(tests, testCase{description: "Replace with DB client Upsert() error", test: func(t *testing.T, shouldEncrypt bool) {
c, txC := SetupMockDB(t)
store := SetupStore(t, c, shouldEncrypt)
r := &sql.Rows{}
listKeysStmt := NewMockStmt(gomock.NewController(t))
deleteStmt := NewMockStmt(gomock.NewController(t))
deleteAllStmt := NewMockStmt(gomock.NewController(t))
txC.EXPECT().Stmt(store.listKeysStmt).Return(listKeysStmt)
c.EXPECT().QueryForRows(context.Background(), listKeysStmt).Return(r, nil)
c.EXPECT().ReadStrings(r).Return([]string{testObject.Id}, nil)
txC.EXPECT().Stmt(store.deleteStmt).Return(deleteStmt)
deleteStmt.EXPECT().Exec(testObject.Id).Return(nil, nil)
txC.EXPECT().Stmt(store.deleteAllStmt).Return(deleteAllStmt)
deleteAllStmt.EXPECT().Exec()
c.EXPECT().Upsert(txC, store.upsertStmt, testObject.Id, testObject, store.shouldEncrypt).Return(fmt.Errorf("error"))
c.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(fmt.Errorf("error")).Do(
@@ -719,6 +685,7 @@ func SetupMockDB(t *testing.T) (*MockClient, *MockTXClient) {
// use stmt mock here
dbC.EXPECT().Prepare(fmt.Sprintf(upsertStmtFmt, "testStoreObject")).Return(&sql.Stmt{})
dbC.EXPECT().Prepare(fmt.Sprintf(deleteStmtFmt, "testStoreObject")).Return(&sql.Stmt{})
dbC.EXPECT().Prepare(fmt.Sprintf(deleteAllStmtFmt, "testStoreObject")).Return(&sql.Stmt{})
dbC.EXPECT().Prepare(fmt.Sprintf(getStmtFmt, "testStoreObject")).Return(&sql.Stmt{})
dbC.EXPECT().Prepare(fmt.Sprintf(listStmtFmt, "testStoreObject")).Return(&sql.Stmt{})
dbC.EXPECT().Prepare(fmt.Sprintf(listKeysStmtFmt, "testStoreObject")).Return(&sql.Stmt{})