1
0
mirror of https://github.com/rancher/steve.git synced 2025-09-04 17:01:16 +00:00

Add object to RegisterAfterDelete and introduce RegisterAfterDeleteAll (#649)

* Add object to AfterDelete callbacks

* Add RegisterAfterDeleteAll
This commit is contained in:
Tom Lebreux
2025-06-03 15:32:43 -06:00
committed by GitHub
parent a8f3ce48d6
commit 2672969496
6 changed files with 153 additions and 123 deletions

View File

@@ -72,7 +72,8 @@ type Store interface {
GetName() string GetName() string
RegisterAfterAdd(f func(key string, obj any, tx transaction.Client) error) RegisterAfterAdd(f func(key string, obj any, tx transaction.Client) error)
RegisterAfterUpdate(f func(key string, obj any, tx transaction.Client) error) RegisterAfterUpdate(f func(key string, obj any, tx transaction.Client) error)
RegisterAfterDelete(f func(key string, tx transaction.Client) error) RegisterAfterDelete(f func(key string, obj any, tx transaction.Client) error)
RegisterAfterDeleteAll(f func(tx transaction.Client) error)
GetShouldEncrypt() bool GetShouldEncrypt() bool
GetType() reflect.Type GetType() reflect.Type
} }

View File

@@ -28,15 +28,19 @@ type ListOptionIndexer struct {
namespaced bool namespaced bool
indexedFields []string indexedFields []string
addFieldQuery string addFieldsQuery string
deleteFieldQuery string deleteFieldsByKeyQuery string
upsertLabelsQuery string deleteFieldsQuery string
deleteLabelsQuery string upsertLabelsQuery string
deleteLabelsByKeyQuery string
deleteLabelsQuery string
addFieldStmt *sql.Stmt addFieldsStmt *sql.Stmt
deleteFieldStmt *sql.Stmt deleteFieldsByKeyStmt *sql.Stmt
upsertLabelsStmt *sql.Stmt deleteFieldsStmt *sql.Stmt
deleteLabelsStmt *sql.Stmt upsertLabelsStmt *sql.Stmt
deleteLabelsByKeyStmt *sql.Stmt
deleteLabelsStmt *sql.Stmt
} }
var ( var (
@@ -56,6 +60,7 @@ const (
%s %s
)` )`
createFieldsIndexFmt = `CREATE INDEX "%s_%s_index" ON "%s_fields"("%s")` createFieldsIndexFmt = `CREATE INDEX "%s_%s_index" ON "%s_fields"("%s")`
deleteFieldsFmt = `DELETE FROM "%s_fields"`
failedToGetFromSliceFmt = "[listoption indexer] failed to get subfield [%s] from slice items" failedToGetFromSliceFmt = "[listoption indexer] failed to get subfield [%s] from slice items"
@@ -67,8 +72,9 @@ const (
)` )`
createLabelsTableIndexFmt = `CREATE INDEX IF NOT EXISTS "%s_labels_index" ON "%s_labels"(label, value)` createLabelsTableIndexFmt = `CREATE INDEX IF NOT EXISTS "%s_labels_index" ON "%s_labels"(label, value)`
upsertLabelsStmtFmt = `REPLACE INTO "%s_labels"(key, label, value) VALUES (?, ?, ?)` upsertLabelsStmtFmt = `REPLACE INTO "%s_labels"(key, label, value) VALUES (?, ?, ?)`
deleteLabelsStmtFmt = `DELETE FROM "%s_labels" WHERE KEY = ?` deleteLabelsByKeyStmtFmt = `DELETE FROM "%s_labels" WHERE KEY = ?`
deleteLabelsStmtFmt = `DELETE FROM "%s_labels"`
) )
// NewListOptionIndexer returns a SQLite-backed cache.Indexer of unstructured.Unstructured Kubernetes resources of a certain GVK // NewListOptionIndexer returns a SQLite-backed cache.Indexer of unstructured.Unstructured Kubernetes resources of a certain GVK
@@ -104,8 +110,10 @@ func NewListOptionIndexer(ctx context.Context, fields [][]string, s Store, names
l.RegisterAfterUpdate(l.addIndexFields) l.RegisterAfterUpdate(l.addIndexFields)
l.RegisterAfterAdd(l.addLabels) l.RegisterAfterAdd(l.addLabels)
l.RegisterAfterUpdate(l.addLabels) l.RegisterAfterUpdate(l.addLabels)
l.RegisterAfterDelete(l.deleteIndexFields) l.RegisterAfterDelete(l.deleteFieldsByKey)
l.RegisterAfterDelete(l.deleteLabels) l.RegisterAfterDelete(l.deleteLabelsByKey)
l.RegisterAfterDeleteAll(l.deleteFields)
l.RegisterAfterDeleteAll(l.deleteLabels)
columnDefs := make([]string, len(indexedFields)) columnDefs := make([]string, len(indexedFields))
for index, field := range indexedFields { for index, field := range indexedFields {
column := fmt.Sprintf(`"%s" TEXT`, field) column := fmt.Sprintf(`"%s" TEXT`, field)
@@ -159,21 +167,25 @@ func NewListOptionIndexer(ctx context.Context, fields [][]string, s Store, names
return nil, err return nil, err
} }
l.addFieldQuery = fmt.Sprintf( l.addFieldsQuery = fmt.Sprintf(
`INSERT INTO "%s_fields"(key, %s) VALUES (?, %s) ON CONFLICT DO UPDATE SET %s`, `INSERT INTO "%s_fields"(key, %s) VALUES (?, %s) ON CONFLICT DO UPDATE SET %s`,
dbName, dbName,
strings.Join(columns, ", "), strings.Join(columns, ", "),
strings.Join(qmarks, ", "), strings.Join(qmarks, ", "),
strings.Join(setStatements, ", "), strings.Join(setStatements, ", "),
) )
l.deleteFieldQuery = fmt.Sprintf(`DELETE FROM "%s_fields" WHERE key = ?`, dbName) l.deleteFieldsByKeyQuery = fmt.Sprintf(`DELETE FROM "%s_fields" WHERE key = ?`, dbName)
l.deleteFieldsQuery = fmt.Sprintf(deleteFieldsFmt, dbName)
l.addFieldStmt = l.Prepare(l.addFieldQuery) l.addFieldsStmt = l.Prepare(l.addFieldsQuery)
l.deleteFieldStmt = l.Prepare(l.deleteFieldQuery) l.deleteFieldsByKeyStmt = l.Prepare(l.deleteFieldsByKeyQuery)
l.deleteFieldsStmt = l.Prepare(l.deleteFieldsQuery)
l.upsertLabelsQuery = fmt.Sprintf(upsertLabelsStmtFmt, dbName) l.upsertLabelsQuery = fmt.Sprintf(upsertLabelsStmtFmt, dbName)
l.deleteLabelsByKeyQuery = fmt.Sprintf(deleteLabelsByKeyStmtFmt, dbName)
l.deleteLabelsQuery = fmt.Sprintf(deleteLabelsStmtFmt, dbName) l.deleteLabelsQuery = fmt.Sprintf(deleteLabelsStmtFmt, dbName)
l.upsertLabelsStmt = l.Prepare(l.upsertLabelsQuery) l.upsertLabelsStmt = l.Prepare(l.upsertLabelsQuery)
l.deleteLabelsByKeyStmt = l.Prepare(l.deleteLabelsByKeyQuery)
l.deleteLabelsStmt = l.Prepare(l.deleteLabelsQuery) l.deleteLabelsStmt = l.Prepare(l.deleteLabelsQuery)
return l, nil return l, nil
@@ -203,9 +215,9 @@ func (l *ListOptionIndexer) addIndexFields(key string, obj any, tx transaction.C
} }
} }
_, err := tx.Stmt(l.addFieldStmt).Exec(args...) _, err := tx.Stmt(l.addFieldsStmt).Exec(args...)
if err != nil { if err != nil {
return &db.QueryError{QueryString: l.addFieldQuery, Err: err} return &db.QueryError{QueryString: l.addFieldsQuery, Err: err}
} }
return nil return nil
} }
@@ -226,18 +238,34 @@ func (l *ListOptionIndexer) addLabels(key string, obj any, tx transaction.Client
return nil return nil
} }
func (l *ListOptionIndexer) deleteIndexFields(key string, tx transaction.Client) error { func (l *ListOptionIndexer) deleteFieldsByKey(key string, _ any, tx transaction.Client) error {
args := []any{key} args := []any{key}
_, err := tx.Stmt(l.deleteFieldStmt).Exec(args...) _, err := tx.Stmt(l.deleteFieldsByKeyStmt).Exec(args...)
if err != nil { if err != nil {
return &db.QueryError{QueryString: l.deleteFieldQuery, Err: err} return &db.QueryError{QueryString: l.deleteFieldsByKeyQuery, Err: err}
} }
return nil return nil
} }
func (l *ListOptionIndexer) deleteLabels(key string, tx transaction.Client) error { func (l *ListOptionIndexer) deleteFields(tx transaction.Client) error {
_, err := tx.Stmt(l.deleteLabelsStmt).Exec(key) _, err := tx.Stmt(l.deleteFieldsStmt).Exec()
if err != nil {
return &db.QueryError{QueryString: l.deleteFieldsQuery, Err: err}
}
return nil
}
func (l *ListOptionIndexer) deleteLabelsByKey(key string, _ any, tx transaction.Client) error {
_, err := tx.Stmt(l.deleteLabelsByKeyStmt).Exec(key)
if err != nil {
return &db.QueryError{QueryString: l.deleteLabelsByKeyQuery, Err: err}
}
return nil
}
func (l *ListOptionIndexer) deleteLabels(tx transaction.Client) error {
_, err := tx.Stmt(l.deleteLabelsStmt).Exec()
if err != nil { if err != nil {
return &db.QueryError{QueryString: l.deleteLabelsQuery, Err: err} return &db.QueryError{QueryString: l.deleteLabelsQuery, Err: err}
} }

View File

@@ -91,6 +91,7 @@ func TestNewListOptionIndexer(t *testing.T) {
store.EXPECT().RegisterAfterAdd(gomock.Any()).Times(2) store.EXPECT().RegisterAfterAdd(gomock.Any()).Times(2)
store.EXPECT().RegisterAfterUpdate(gomock.Any()).Times(2) store.EXPECT().RegisterAfterUpdate(gomock.Any()).Times(2)
store.EXPECT().RegisterAfterDelete(gomock.Any()).Times(2) store.EXPECT().RegisterAfterDelete(gomock.Any()).Times(2)
store.EXPECT().RegisterAfterDeleteAll(gomock.Any()).Times(2)
// create field table // create field table
txClient.EXPECT().Exec(fmt.Sprintf(createFieldsTableFmt, id, `"metadata.name" TEXT, "metadata.creationTimestamp" TEXT, "metadata.namespace" TEXT, "something" TEXT`)).Return(nil, nil) txClient.EXPECT().Exec(fmt.Sprintf(createFieldsTableFmt, id, `"metadata.name" TEXT, "metadata.creationTimestamp" TEXT, "metadata.namespace" TEXT, "something" TEXT`)).Return(nil, nil)
@@ -158,6 +159,7 @@ func TestNewListOptionIndexer(t *testing.T) {
store.EXPECT().RegisterAfterAdd(gomock.Any()).Times(2) store.EXPECT().RegisterAfterAdd(gomock.Any()).Times(2)
store.EXPECT().RegisterAfterUpdate(gomock.Any()).Times(2) store.EXPECT().RegisterAfterUpdate(gomock.Any()).Times(2)
store.EXPECT().RegisterAfterDelete(gomock.Any()).Times(2) store.EXPECT().RegisterAfterDelete(gomock.Any()).Times(2)
store.EXPECT().RegisterAfterDeleteAll(gomock.Any()).Times(2)
store.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(fmt.Errorf("error")) store.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(fmt.Errorf("error"))
@@ -189,6 +191,7 @@ func TestNewListOptionIndexer(t *testing.T) {
store.EXPECT().RegisterAfterAdd(gomock.Any()).Times(2) store.EXPECT().RegisterAfterAdd(gomock.Any()).Times(2)
store.EXPECT().RegisterAfterUpdate(gomock.Any()).Times(2) store.EXPECT().RegisterAfterUpdate(gomock.Any()).Times(2)
store.EXPECT().RegisterAfterDelete(gomock.Any()).Times(2) store.EXPECT().RegisterAfterDelete(gomock.Any()).Times(2)
store.EXPECT().RegisterAfterDeleteAll(gomock.Any()).Times(2)
txClient.EXPECT().Exec(fmt.Sprintf(createFieldsTableFmt, id, `"metadata.name" TEXT, "metadata.creationTimestamp" TEXT, "metadata.namespace" TEXT, "something" TEXT`)).Return(nil, nil) txClient.EXPECT().Exec(fmt.Sprintf(createFieldsTableFmt, id, `"metadata.name" TEXT, "metadata.creationTimestamp" TEXT, "metadata.namespace" TEXT, "something" TEXT`)).Return(nil, nil)
txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, "metadata.name", id, "metadata.name")).Return(nil, fmt.Errorf("error")) txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, "metadata.name", id, "metadata.name")).Return(nil, fmt.Errorf("error"))
@@ -228,6 +231,7 @@ func TestNewListOptionIndexer(t *testing.T) {
store.EXPECT().RegisterAfterAdd(gomock.Any()).Times(2) store.EXPECT().RegisterAfterAdd(gomock.Any()).Times(2)
store.EXPECT().RegisterAfterUpdate(gomock.Any()).Times(2) store.EXPECT().RegisterAfterUpdate(gomock.Any()).Times(2)
store.EXPECT().RegisterAfterDelete(gomock.Any()).Times(2) store.EXPECT().RegisterAfterDelete(gomock.Any()).Times(2)
store.EXPECT().RegisterAfterDeleteAll(gomock.Any()).Times(2)
txClient.EXPECT().Exec(fmt.Sprintf(createFieldsTableFmt, id, `"metadata.name" TEXT, "metadata.creationTimestamp" TEXT, "metadata.namespace" TEXT, "something" TEXT`)).Return(nil, nil) txClient.EXPECT().Exec(fmt.Sprintf(createFieldsTableFmt, id, `"metadata.name" TEXT, "metadata.creationTimestamp" TEXT, "metadata.namespace" TEXT, "something" TEXT`)).Return(nil, nil)
txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, "metadata.name", id, "metadata.name")).Return(nil, nil) txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, "metadata.name", id, "metadata.name")).Return(nil, nil)
@@ -271,6 +275,7 @@ func TestNewListOptionIndexer(t *testing.T) {
store.EXPECT().RegisterAfterAdd(gomock.Any()).Times(2) store.EXPECT().RegisterAfterAdd(gomock.Any()).Times(2)
store.EXPECT().RegisterAfterUpdate(gomock.Any()).Times(2) store.EXPECT().RegisterAfterUpdate(gomock.Any()).Times(2)
store.EXPECT().RegisterAfterDelete(gomock.Any()).Times(2) store.EXPECT().RegisterAfterDelete(gomock.Any()).Times(2)
store.EXPECT().RegisterAfterDeleteAll(gomock.Any()).Times(2)
txClient.EXPECT().Exec(fmt.Sprintf(createFieldsTableFmt, id, `"metadata.name" TEXT, "metadata.creationTimestamp" TEXT, "metadata.namespace" TEXT, "something" TEXT`)).Return(nil, nil) txClient.EXPECT().Exec(fmt.Sprintf(createFieldsTableFmt, id, `"metadata.name" TEXT, "metadata.creationTimestamp" TEXT, "metadata.namespace" TEXT, "something" TEXT`)).Return(nil, nil)
txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, "metadata.name", id, "metadata.name")).Return(nil, nil) txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, "metadata.name", id, "metadata.name")).Return(nil, nil)

View File

@@ -292,7 +292,7 @@ func (mr *MockStoreMockRecorder) RegisterAfterAdd(arg0 any) *gomock.Call {
} }
// RegisterAfterDelete mocks base method. // RegisterAfterDelete mocks base method.
func (m *MockStore) RegisterAfterDelete(arg0 func(string, transaction.Client) error) { func (m *MockStore) RegisterAfterDelete(arg0 func(string, any, transaction.Client) error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
m.ctrl.Call(m, "RegisterAfterDelete", arg0) m.ctrl.Call(m, "RegisterAfterDelete", arg0)
} }
@@ -303,6 +303,18 @@ func (mr *MockStoreMockRecorder) RegisterAfterDelete(arg0 any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterAfterDelete", reflect.TypeOf((*MockStore)(nil).RegisterAfterDelete), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterAfterDelete", reflect.TypeOf((*MockStore)(nil).RegisterAfterDelete), arg0)
} }
// RegisterAfterDeleteAll mocks base method.
func (m *MockStore) RegisterAfterDeleteAll(arg0 func(transaction.Client) error) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "RegisterAfterDeleteAll", arg0)
}
// RegisterAfterDeleteAll indicates an expected call of RegisterAfterDeleteAll.
func (mr *MockStoreMockRecorder) RegisterAfterDeleteAll(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterAfterDeleteAll", reflect.TypeOf((*MockStore)(nil).RegisterAfterDeleteAll), arg0)
}
// RegisterAfterUpdate mocks base method. // RegisterAfterUpdate mocks base method.
func (m *MockStore) RegisterAfterUpdate(arg0 func(string, any, transaction.Client) error) { func (m *MockStore) RegisterAfterUpdate(arg0 func(string, any, transaction.Client) error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()

View File

@@ -19,12 +19,13 @@ import (
) )
const ( const (
upsertStmtFmt = `REPLACE INTO "%s"(key, object, objectnonce, dekid) VALUES (?, ?, ?, ?)` upsertStmtFmt = `REPLACE INTO "%s"(key, object, objectnonce, dekid) VALUES (?, ?, ?, ?)`
deleteStmtFmt = `DELETE FROM "%s" WHERE key = ?` deleteStmtFmt = `DELETE FROM "%s" WHERE key = ?`
getStmtFmt = `SELECT object, objectnonce, dekid FROM "%s" WHERE key = ?` deleteAllStmtFmt = `DELETE FROM "%s"`
listStmtFmt = `SELECT object, objectnonce, dekid FROM "%s"` getStmtFmt = `SELECT object, objectnonce, dekid FROM "%s" WHERE key = ?`
listKeysStmtFmt = `SELECT key FROM "%s"` listStmtFmt = `SELECT object, objectnonce, dekid FROM "%s"`
createTableFmt = `CREATE TABLE IF NOT EXISTS "%s" ( listKeysStmtFmt = `SELECT key FROM "%s"`
createTableFmt = `CREATE TABLE IF NOT EXISTS "%s" (
key TEXT UNIQUE NOT NULL PRIMARY KEY, key TEXT UNIQUE NOT NULL PRIMARY KEY,
object BLOB, object BLOB,
objectnonce BLOB, objectnonce BLOB,
@@ -42,21 +43,24 @@ type Store struct {
keyFunc cache.KeyFunc keyFunc cache.KeyFunc
shouldEncrypt bool shouldEncrypt bool
upsertQuery string upsertQuery string
deleteQuery string deleteQuery string
getQuery string deleteAllQuery string
listQuery string getQuery string
listKeysQuery string listQuery string
listKeysQuery string
upsertStmt *sql.Stmt upsertStmt *sql.Stmt
deleteStmt *sql.Stmt deleteStmt *sql.Stmt
getStmt *sql.Stmt deleteAllStmt *sql.Stmt
listStmt *sql.Stmt getStmt *sql.Stmt
listKeysStmt *sql.Stmt listStmt *sql.Stmt
listKeysStmt *sql.Stmt
afterAdd []func(key string, obj any, tx transaction.Client) error afterAdd []func(key string, obj any, tx transaction.Client) error
afterUpdate []func(key string, obj any, tx transaction.Client) error afterUpdate []func(key string, obj any, tx transaction.Client) error
afterDelete []func(key string, tx transaction.Client) error afterDelete []func(key string, obj any, tx transaction.Client) error
afterDeleteAll []func(tx transaction.Client) error
} }
// Test that Store implements cache.Indexer // Test that Store implements cache.Indexer
@@ -65,15 +69,16 @@ var _ cache.Store = (*Store)(nil)
// NewStore creates a SQLite-backed cache.Store for objects of the given example type // NewStore creates a SQLite-backed cache.Store for objects of the given example type
func NewStore(ctx context.Context, example any, keyFunc cache.KeyFunc, c db.Client, shouldEncrypt bool, name string) (*Store, error) { func NewStore(ctx context.Context, example any, keyFunc cache.KeyFunc, c db.Client, shouldEncrypt bool, name string) (*Store, error) {
s := &Store{ s := &Store{
ctx: ctx, ctx: ctx,
name: name, name: name,
typ: reflect.TypeOf(example), typ: reflect.TypeOf(example),
Client: c, Client: c,
keyFunc: keyFunc, keyFunc: keyFunc,
shouldEncrypt: shouldEncrypt, shouldEncrypt: shouldEncrypt,
afterAdd: []func(key string, obj any, tx transaction.Client) error{}, afterAdd: []func(key string, obj any, tx transaction.Client) error{},
afterUpdate: []func(key string, obj any, tx transaction.Client) error{}, afterUpdate: []func(key string, obj any, tx transaction.Client) error{},
afterDelete: []func(key string, tx transaction.Client) error{}, afterDelete: []func(key string, obj any, tx transaction.Client) error{},
afterDeleteAll: []func(tx transaction.Client) error{},
} }
dbName := db.Sanitize(s.name) dbName := db.Sanitize(s.name)
@@ -94,12 +99,14 @@ func NewStore(ctx context.Context, example any, keyFunc cache.KeyFunc, c db.Clie
s.upsertQuery = fmt.Sprintf(upsertStmtFmt, dbName) s.upsertQuery = fmt.Sprintf(upsertStmtFmt, dbName)
s.deleteQuery = fmt.Sprintf(deleteStmtFmt, dbName) s.deleteQuery = fmt.Sprintf(deleteStmtFmt, dbName)
s.deleteAllQuery = fmt.Sprintf(deleteAllStmtFmt, dbName)
s.getQuery = fmt.Sprintf(getStmtFmt, dbName) s.getQuery = fmt.Sprintf(getStmtFmt, dbName)
s.listQuery = fmt.Sprintf(listStmtFmt, dbName) s.listQuery = fmt.Sprintf(listStmtFmt, dbName)
s.listKeysQuery = fmt.Sprintf(listKeysStmtFmt, dbName) s.listKeysQuery = fmt.Sprintf(listKeysStmtFmt, dbName)
s.upsertStmt = s.Prepare(s.upsertQuery) s.upsertStmt = s.Prepare(s.upsertQuery)
s.deleteStmt = s.Prepare(s.deleteQuery) s.deleteStmt = s.Prepare(s.deleteQuery)
s.deleteAllStmt = s.Prepare(s.deleteAllQuery)
s.getStmt = s.Prepare(s.getQuery) s.getStmt = s.Prepare(s.getQuery)
s.listStmt = s.Prepare(s.listQuery) s.listStmt = s.Prepare(s.listQuery)
s.listKeysStmt = s.Prepare(s.listKeysQuery) s.listKeysStmt = s.Prepare(s.listKeysQuery)
@@ -110,14 +117,14 @@ func NewStore(ctx context.Context, example any, keyFunc cache.KeyFunc, c db.Clie
/* Core methods */ /* Core methods */
// deleteByKey deletes the object associated with key, if it exists in this Store // deleteByKey deletes the object associated with key, if it exists in this Store
func (s *Store) deleteByKey(key string) error { func (s *Store) deleteByKey(key string, obj any) error {
return s.WithTransaction(s.ctx, true, func(tx transaction.Client) error { return s.WithTransaction(s.ctx, true, func(tx transaction.Client) error {
_, err := tx.Stmt(s.deleteStmt).Exec(key) _, err := tx.Stmt(s.deleteStmt).Exec(key)
if err != nil { if err != nil {
return &db.QueryError{QueryString: s.deleteQuery, Err: err} return &db.QueryError{QueryString: s.deleteQuery, Err: err}
} }
err = s.runAfterDelete(key, tx) err = s.runAfterDelete(key, obj, tx)
if err != nil { if err != nil {
return err return err
} }
@@ -206,7 +213,7 @@ func (s *Store) Delete(obj any) error {
if err != nil { if err != nil {
return err return err
} }
err = s.deleteByKey(key) err = s.deleteByKey(key, obj)
if err != nil { if err != nil {
log.Errorf("Error in Store.Delete for type %v: %v", s.name, err) log.Errorf("Error in Store.Delete for type %v: %v", s.name, err)
return err return err
@@ -266,32 +273,25 @@ func (s *Store) Replace(objects []any, _ string) error {
} }
objectMap[key] = object objectMap[key] = object
} }
return s.replaceByKey(objectMap) err := s.replaceByKey(objectMap)
if err != nil {
log.Errorf("Error in Store.Replace for type %v: %v", s.name, err)
return err
}
return nil
} }
// replaceByKey will delete the contents of the Store, using instead the given key to obj map // replaceByKey will delete the contents of the Store, using instead the given key to obj map
func (s *Store) replaceByKey(objects map[string]any) error { func (s *Store) replaceByKey(objects map[string]any) error {
return s.WithTransaction(s.ctx, true, func(txC transaction.Client) error { return s.WithTransaction(s.ctx, true, func(txC transaction.Client) error {
txCListKeys := txC.Stmt(s.listKeysStmt) _, err := txC.Stmt(s.deleteAllStmt).Exec()
rows, err := s.QueryForRows(s.ctx, txCListKeys)
if err != nil { if err != nil {
return err return &db.QueryError{QueryString: s.deleteAllQuery, Err: err}
}
keys, err := s.ReadStrings(rows)
if err != nil {
return err
} }
for _, key := range keys { err = s.runAfterDeleteAll(txC)
_, err = txC.Stmt(s.deleteStmt).Exec(key) if err != nil {
if err != nil { return err
return err
}
err = s.runAfterDelete(key, txC)
if err != nil {
return err
}
} }
for key, obj := range objects { for key, obj := range objects {
@@ -339,10 +339,15 @@ func (s *Store) RegisterAfterUpdate(f func(key string, obj any, txC transaction.
} }
// RegisterAfterDelete registers a func to be called after each deletion // RegisterAfterDelete registers a func to be called after each deletion
func (s *Store) RegisterAfterDelete(f func(key string, txC transaction.Client) error) { func (s *Store) RegisterAfterDelete(f func(key string, obj any, txC transaction.Client) error) {
s.afterDelete = append(s.afterDelete, f) s.afterDelete = append(s.afterDelete, f)
} }
// RegisterAfterDelete registers a func to be called after each deletion
func (s *Store) RegisterAfterDeleteAll(f func(txC transaction.Client) error) {
s.afterDeleteAll = append(s.afterDeleteAll, f)
}
// runAfterAdd executes functions registered to run after add event // runAfterAdd executes functions registered to run after add event
func (s *Store) runAfterAdd(key string, obj any, txC transaction.Client) error { func (s *Store) runAfterAdd(key string, obj any, txC transaction.Client) error {
for _, f := range s.afterAdd { for _, f := range s.afterAdd {
@@ -366,9 +371,21 @@ func (s *Store) runAfterUpdate(key string, obj any, txC transaction.Client) erro
} }
// runAfterDelete executes functions registered to run after delete event // runAfterDelete executes functions registered to run after delete event
func (s *Store) runAfterDelete(key string, txC transaction.Client) error { func (s *Store) runAfterDelete(key string, obj any, txC transaction.Client) error {
for _, f := range s.afterDelete { for _, f := range s.afterDelete {
err := f(key, txC) err := f(key, obj, txC)
if err != nil {
return err
}
}
return nil
}
// runAfterDeleteAll executes functions registered to run after delete events when
// the database is being replaced.
func (s *Store) runAfterDeleteAll(txC transaction.Client) error {
for _, f := range s.afterDeleteAll {
err := f(txC)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -550,14 +550,10 @@ func TestReplace(t *testing.T) {
tests = append(tests, testCase{description: "Replace with no DB client errors and some items", test: func(t *testing.T, shouldEncrypt bool) { tests = append(tests, testCase{description: "Replace with no DB client errors and some items", test: func(t *testing.T, shouldEncrypt bool) {
c, txC := SetupMockDB(t) c, txC := SetupMockDB(t)
store := SetupStore(t, c, shouldEncrypt) store := SetupStore(t, c, shouldEncrypt)
r := &sql.Rows{}
stmt := NewMockStmt(gomock.NewController(t)) stmt := NewMockStmt(gomock.NewController(t))
txC.EXPECT().Stmt(store.listKeysStmt).Return(stmt) txC.EXPECT().Stmt(store.deleteAllStmt).Return(stmt)
c.EXPECT().QueryForRows(context.Background(), stmt).Return(r, nil) stmt.EXPECT().Exec()
c.EXPECT().ReadStrings(r).Return([]string{testObject.Id}, nil)
txC.EXPECT().Stmt(store.deleteStmt).Return(stmt)
stmt.EXPECT().Exec(testObject.Id)
c.EXPECT().Upsert(txC, store.upsertStmt, testObject.Id, testObject, store.shouldEncrypt) c.EXPECT().Upsert(txC, store.upsertStmt, testObject.Id, testObject, store.shouldEncrypt)
c.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(nil).Do( c.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(nil).Do(
@@ -575,11 +571,10 @@ func TestReplace(t *testing.T) {
tests = append(tests, testCase{description: "Replace with no DB client errors and no items", test: func(t *testing.T, shouldEncrypt bool) { tests = append(tests, testCase{description: "Replace with no DB client errors and no items", test: func(t *testing.T, shouldEncrypt bool) {
c, txC := SetupMockDB(t) c, txC := SetupMockDB(t)
store := SetupStore(t, c, shouldEncrypt) store := SetupStore(t, c, shouldEncrypt)
r := &sql.Rows{}
txC.EXPECT().Stmt(store.listKeysStmt).Return(store.listKeysStmt) stmt := NewMockStmt(gomock.NewController(t))
c.EXPECT().QueryForRows(context.Background(), store.listKeysStmt).Return(r, nil) txC.EXPECT().Stmt(store.deleteAllStmt).Return(stmt)
c.EXPECT().ReadStrings(r).Return([]string{}, nil) stmt.EXPECT().Exec()
c.EXPECT().Upsert(txC, store.upsertStmt, testObject.Id, testObject, store.shouldEncrypt) c.EXPECT().Upsert(txC, store.upsertStmt, testObject.Id, testObject, store.shouldEncrypt)
c.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(nil).Do( c.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(nil).Do(
@@ -602,39 +597,15 @@ func TestReplace(t *testing.T) {
assert.NotNil(t, err) assert.NotNil(t, err)
}, },
}) })
tests = append(tests, testCase{description: "Replace with DB client ReadStrings() error", test: func(t *testing.T, shouldEncrypt bool) { tests = append(tests, testCase{description: "Replace with DB client deleteAllStmt error", test: func(t *testing.T, shouldEncrypt bool) {
c, txC := SetupMockDB(t) c, txC := SetupMockDB(t)
store := SetupStore(t, c, shouldEncrypt) store := SetupStore(t, c, shouldEncrypt)
r := &sql.Rows{}
txC.EXPECT().Stmt(store.listKeysStmt).Return(store.listKeysStmt) deleteAllStmt := NewMockStmt(gomock.NewController(t))
c.EXPECT().QueryForRows(context.Background(), store.listKeysStmt).Return(r, nil)
c.EXPECT().ReadStrings(r).Return(nil, fmt.Errorf("error"))
c.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(fmt.Errorf("error")).Do( txC.EXPECT().Stmt(store.deleteAllStmt).Return(deleteAllStmt)
func(ctx context.Context, shouldEncrypt bool, f db.WithTransactionFunction) { deleteAllStmt.EXPECT().Exec().Return(nil, fmt.Errorf("error"))
err := f(txC)
if err == nil {
t.Fail()
}
})
err := store.Replace([]any{testObject}, testObject.Id)
assert.NotNil(t, err)
},
})
tests = append(tests, testCase{description: "Replace with TX client StmtExec() error", test: func(t *testing.T, shouldEncrypt bool) {
c, txC := SetupMockDB(t)
store := SetupStore(t, c, shouldEncrypt)
r := &sql.Rows{}
listKeysStmt := NewMockStmt(gomock.NewController(t))
deleteStmt := NewMockStmt(gomock.NewController(t))
txC.EXPECT().Stmt(store.listKeysStmt).Return(listKeysStmt)
c.EXPECT().QueryForRows(context.Background(), listKeysStmt).Return(r, nil)
c.EXPECT().ReadStrings(r).Return([]string{testObject.Id}, nil)
txC.EXPECT().Stmt(store.deleteStmt).Return(deleteStmt)
deleteStmt.EXPECT().Exec(testObject.Id).Return(nil, fmt.Errorf("error"))
c.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(fmt.Errorf("error")).Do( c.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(fmt.Errorf("error")).Do(
func(ctx context.Context, shouldEncrypt bool, f db.WithTransactionFunction) { func(ctx context.Context, shouldEncrypt bool, f db.WithTransactionFunction) {
err := f(txC) err := f(txC)
@@ -650,15 +621,10 @@ func TestReplace(t *testing.T) {
tests = append(tests, testCase{description: "Replace with DB client Upsert() error", test: func(t *testing.T, shouldEncrypt bool) { tests = append(tests, testCase{description: "Replace with DB client Upsert() error", test: func(t *testing.T, shouldEncrypt bool) {
c, txC := SetupMockDB(t) c, txC := SetupMockDB(t)
store := SetupStore(t, c, shouldEncrypt) store := SetupStore(t, c, shouldEncrypt)
r := &sql.Rows{} deleteAllStmt := NewMockStmt(gomock.NewController(t))
listKeysStmt := NewMockStmt(gomock.NewController(t))
deleteStmt := NewMockStmt(gomock.NewController(t))
txC.EXPECT().Stmt(store.listKeysStmt).Return(listKeysStmt) txC.EXPECT().Stmt(store.deleteAllStmt).Return(deleteAllStmt)
c.EXPECT().QueryForRows(context.Background(), listKeysStmt).Return(r, nil) deleteAllStmt.EXPECT().Exec()
c.EXPECT().ReadStrings(r).Return([]string{testObject.Id}, nil)
txC.EXPECT().Stmt(store.deleteStmt).Return(deleteStmt)
deleteStmt.EXPECT().Exec(testObject.Id).Return(nil, nil)
c.EXPECT().Upsert(txC, store.upsertStmt, testObject.Id, testObject, store.shouldEncrypt).Return(fmt.Errorf("error")) c.EXPECT().Upsert(txC, store.upsertStmt, testObject.Id, testObject, store.shouldEncrypt).Return(fmt.Errorf("error"))
c.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(fmt.Errorf("error")).Do( c.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(fmt.Errorf("error")).Do(
@@ -719,6 +685,7 @@ func SetupMockDB(t *testing.T) (*MockClient, *MockTXClient) {
// use stmt mock here // use stmt mock here
dbC.EXPECT().Prepare(fmt.Sprintf(upsertStmtFmt, "testStoreObject")).Return(&sql.Stmt{}) dbC.EXPECT().Prepare(fmt.Sprintf(upsertStmtFmt, "testStoreObject")).Return(&sql.Stmt{})
dbC.EXPECT().Prepare(fmt.Sprintf(deleteStmtFmt, "testStoreObject")).Return(&sql.Stmt{}) dbC.EXPECT().Prepare(fmt.Sprintf(deleteStmtFmt, "testStoreObject")).Return(&sql.Stmt{})
dbC.EXPECT().Prepare(fmt.Sprintf(deleteAllStmtFmt, "testStoreObject")).Return(&sql.Stmt{})
dbC.EXPECT().Prepare(fmt.Sprintf(getStmtFmt, "testStoreObject")).Return(&sql.Stmt{}) dbC.EXPECT().Prepare(fmt.Sprintf(getStmtFmt, "testStoreObject")).Return(&sql.Stmt{})
dbC.EXPECT().Prepare(fmt.Sprintf(listStmtFmt, "testStoreObject")).Return(&sql.Stmt{}) dbC.EXPECT().Prepare(fmt.Sprintf(listStmtFmt, "testStoreObject")).Return(&sql.Stmt{})
dbC.EXPECT().Prepare(fmt.Sprintf(listKeysStmtFmt, "testStoreObject")).Return(&sql.Stmt{}) dbC.EXPECT().Prepare(fmt.Sprintf(listKeysStmtFmt, "testStoreObject")).Return(&sql.Stmt{})