mirror of
https://github.com/rancher/steve.git
synced 2025-09-03 08:25:13 +00:00
Add object to RegisterAfterDelete and introduce RegisterAfterDeleteAll (#649)
* Add object to AfterDelete callbacks * Add RegisterAfterDeleteAll
This commit is contained in:
@@ -19,12 +19,13 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
upsertStmtFmt = `REPLACE INTO "%s"(key, object, objectnonce, dekid) VALUES (?, ?, ?, ?)`
|
||||
deleteStmtFmt = `DELETE FROM "%s" WHERE key = ?`
|
||||
getStmtFmt = `SELECT object, objectnonce, dekid FROM "%s" WHERE key = ?`
|
||||
listStmtFmt = `SELECT object, objectnonce, dekid FROM "%s"`
|
||||
listKeysStmtFmt = `SELECT key FROM "%s"`
|
||||
createTableFmt = `CREATE TABLE IF NOT EXISTS "%s" (
|
||||
upsertStmtFmt = `REPLACE INTO "%s"(key, object, objectnonce, dekid) VALUES (?, ?, ?, ?)`
|
||||
deleteStmtFmt = `DELETE FROM "%s" WHERE key = ?`
|
||||
deleteAllStmtFmt = `DELETE FROM "%s"`
|
||||
getStmtFmt = `SELECT object, objectnonce, dekid FROM "%s" WHERE key = ?`
|
||||
listStmtFmt = `SELECT object, objectnonce, dekid FROM "%s"`
|
||||
listKeysStmtFmt = `SELECT key FROM "%s"`
|
||||
createTableFmt = `CREATE TABLE IF NOT EXISTS "%s" (
|
||||
key TEXT UNIQUE NOT NULL PRIMARY KEY,
|
||||
object BLOB,
|
||||
objectnonce BLOB,
|
||||
@@ -42,21 +43,24 @@ type Store struct {
|
||||
keyFunc cache.KeyFunc
|
||||
shouldEncrypt bool
|
||||
|
||||
upsertQuery string
|
||||
deleteQuery string
|
||||
getQuery string
|
||||
listQuery string
|
||||
listKeysQuery string
|
||||
upsertQuery string
|
||||
deleteQuery string
|
||||
deleteAllQuery string
|
||||
getQuery string
|
||||
listQuery string
|
||||
listKeysQuery string
|
||||
|
||||
upsertStmt *sql.Stmt
|
||||
deleteStmt *sql.Stmt
|
||||
getStmt *sql.Stmt
|
||||
listStmt *sql.Stmt
|
||||
listKeysStmt *sql.Stmt
|
||||
upsertStmt *sql.Stmt
|
||||
deleteStmt *sql.Stmt
|
||||
deleteAllStmt *sql.Stmt
|
||||
getStmt *sql.Stmt
|
||||
listStmt *sql.Stmt
|
||||
listKeysStmt *sql.Stmt
|
||||
|
||||
afterAdd []func(key string, obj any, tx transaction.Client) error
|
||||
afterUpdate []func(key string, obj any, tx transaction.Client) error
|
||||
afterDelete []func(key string, tx transaction.Client) error
|
||||
afterAdd []func(key string, obj any, tx transaction.Client) error
|
||||
afterUpdate []func(key string, obj any, tx transaction.Client) error
|
||||
afterDelete []func(key string, obj any, tx transaction.Client) error
|
||||
afterDeleteAll []func(tx transaction.Client) error
|
||||
}
|
||||
|
||||
// Test that Store implements cache.Indexer
|
||||
@@ -65,15 +69,16 @@ var _ cache.Store = (*Store)(nil)
|
||||
// NewStore creates a SQLite-backed cache.Store for objects of the given example type
|
||||
func NewStore(ctx context.Context, example any, keyFunc cache.KeyFunc, c db.Client, shouldEncrypt bool, name string) (*Store, error) {
|
||||
s := &Store{
|
||||
ctx: ctx,
|
||||
name: name,
|
||||
typ: reflect.TypeOf(example),
|
||||
Client: c,
|
||||
keyFunc: keyFunc,
|
||||
shouldEncrypt: shouldEncrypt,
|
||||
afterAdd: []func(key string, obj any, tx transaction.Client) error{},
|
||||
afterUpdate: []func(key string, obj any, tx transaction.Client) error{},
|
||||
afterDelete: []func(key string, tx transaction.Client) error{},
|
||||
ctx: ctx,
|
||||
name: name,
|
||||
typ: reflect.TypeOf(example),
|
||||
Client: c,
|
||||
keyFunc: keyFunc,
|
||||
shouldEncrypt: shouldEncrypt,
|
||||
afterAdd: []func(key string, obj any, tx transaction.Client) error{},
|
||||
afterUpdate: []func(key string, obj any, tx transaction.Client) error{},
|
||||
afterDelete: []func(key string, obj any, tx transaction.Client) error{},
|
||||
afterDeleteAll: []func(tx transaction.Client) error{},
|
||||
}
|
||||
|
||||
dbName := db.Sanitize(s.name)
|
||||
@@ -94,12 +99,14 @@ func NewStore(ctx context.Context, example any, keyFunc cache.KeyFunc, c db.Clie
|
||||
|
||||
s.upsertQuery = fmt.Sprintf(upsertStmtFmt, dbName)
|
||||
s.deleteQuery = fmt.Sprintf(deleteStmtFmt, dbName)
|
||||
s.deleteAllQuery = fmt.Sprintf(deleteAllStmtFmt, dbName)
|
||||
s.getQuery = fmt.Sprintf(getStmtFmt, dbName)
|
||||
s.listQuery = fmt.Sprintf(listStmtFmt, dbName)
|
||||
s.listKeysQuery = fmt.Sprintf(listKeysStmtFmt, dbName)
|
||||
|
||||
s.upsertStmt = s.Prepare(s.upsertQuery)
|
||||
s.deleteStmt = s.Prepare(s.deleteQuery)
|
||||
s.deleteAllStmt = s.Prepare(s.deleteAllQuery)
|
||||
s.getStmt = s.Prepare(s.getQuery)
|
||||
s.listStmt = s.Prepare(s.listQuery)
|
||||
s.listKeysStmt = s.Prepare(s.listKeysQuery)
|
||||
@@ -110,14 +117,14 @@ func NewStore(ctx context.Context, example any, keyFunc cache.KeyFunc, c db.Clie
|
||||
/* Core methods */
|
||||
|
||||
// deleteByKey deletes the object associated with key, if it exists in this Store
|
||||
func (s *Store) deleteByKey(key string) error {
|
||||
func (s *Store) deleteByKey(key string, obj any) error {
|
||||
return s.WithTransaction(s.ctx, true, func(tx transaction.Client) error {
|
||||
_, err := tx.Stmt(s.deleteStmt).Exec(key)
|
||||
if err != nil {
|
||||
return &db.QueryError{QueryString: s.deleteQuery, Err: err}
|
||||
}
|
||||
|
||||
err = s.runAfterDelete(key, tx)
|
||||
err = s.runAfterDelete(key, obj, tx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -206,7 +213,7 @@ func (s *Store) Delete(obj any) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = s.deleteByKey(key)
|
||||
err = s.deleteByKey(key, obj)
|
||||
if err != nil {
|
||||
log.Errorf("Error in Store.Delete for type %v: %v", s.name, err)
|
||||
return err
|
||||
@@ -266,32 +273,25 @@ func (s *Store) Replace(objects []any, _ string) error {
|
||||
}
|
||||
objectMap[key] = object
|
||||
}
|
||||
return s.replaceByKey(objectMap)
|
||||
err := s.replaceByKey(objectMap)
|
||||
if err != nil {
|
||||
log.Errorf("Error in Store.Replace for type %v: %v", s.name, err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// replaceByKey will delete the contents of the Store, using instead the given key to obj map
|
||||
func (s *Store) replaceByKey(objects map[string]any) error {
|
||||
return s.WithTransaction(s.ctx, true, func(txC transaction.Client) error {
|
||||
txCListKeys := txC.Stmt(s.listKeysStmt)
|
||||
|
||||
rows, err := s.QueryForRows(s.ctx, txCListKeys)
|
||||
_, err := txC.Stmt(s.deleteAllStmt).Exec()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
keys, err := s.ReadStrings(rows)
|
||||
if err != nil {
|
||||
return err
|
||||
return &db.QueryError{QueryString: s.deleteAllQuery, Err: err}
|
||||
}
|
||||
|
||||
for _, key := range keys {
|
||||
_, err = txC.Stmt(s.deleteStmt).Exec(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = s.runAfterDelete(key, txC)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = s.runAfterDeleteAll(txC)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for key, obj := range objects {
|
||||
@@ -339,10 +339,15 @@ func (s *Store) RegisterAfterUpdate(f func(key string, obj any, txC transaction.
|
||||
}
|
||||
|
||||
// RegisterAfterDelete registers a func to be called after each deletion
|
||||
func (s *Store) RegisterAfterDelete(f func(key string, txC transaction.Client) error) {
|
||||
func (s *Store) RegisterAfterDelete(f func(key string, obj any, txC transaction.Client) error) {
|
||||
s.afterDelete = append(s.afterDelete, f)
|
||||
}
|
||||
|
||||
// RegisterAfterDelete registers a func to be called after each deletion
|
||||
func (s *Store) RegisterAfterDeleteAll(f func(txC transaction.Client) error) {
|
||||
s.afterDeleteAll = append(s.afterDeleteAll, f)
|
||||
}
|
||||
|
||||
// runAfterAdd executes functions registered to run after add event
|
||||
func (s *Store) runAfterAdd(key string, obj any, txC transaction.Client) error {
|
||||
for _, f := range s.afterAdd {
|
||||
@@ -366,9 +371,21 @@ func (s *Store) runAfterUpdate(key string, obj any, txC transaction.Client) erro
|
||||
}
|
||||
|
||||
// runAfterDelete executes functions registered to run after delete event
|
||||
func (s *Store) runAfterDelete(key string, txC transaction.Client) error {
|
||||
func (s *Store) runAfterDelete(key string, obj any, txC transaction.Client) error {
|
||||
for _, f := range s.afterDelete {
|
||||
err := f(key, txC)
|
||||
err := f(key, obj, txC)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// runAfterDeleteAll executes functions registered to run after delete events when
|
||||
// the database is being replaced.
|
||||
func (s *Store) runAfterDeleteAll(txC transaction.Client) error {
|
||||
for _, f := range s.afterDeleteAll {
|
||||
err := f(txC)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@@ -550,14 +550,10 @@ func TestReplace(t *testing.T) {
|
||||
tests = append(tests, testCase{description: "Replace with no DB client errors and some items", test: func(t *testing.T, shouldEncrypt bool) {
|
||||
c, txC := SetupMockDB(t)
|
||||
store := SetupStore(t, c, shouldEncrypt)
|
||||
r := &sql.Rows{}
|
||||
stmt := NewMockStmt(gomock.NewController(t))
|
||||
|
||||
txC.EXPECT().Stmt(store.listKeysStmt).Return(stmt)
|
||||
c.EXPECT().QueryForRows(context.Background(), stmt).Return(r, nil)
|
||||
c.EXPECT().ReadStrings(r).Return([]string{testObject.Id}, nil)
|
||||
txC.EXPECT().Stmt(store.deleteStmt).Return(stmt)
|
||||
stmt.EXPECT().Exec(testObject.Id)
|
||||
txC.EXPECT().Stmt(store.deleteAllStmt).Return(stmt)
|
||||
stmt.EXPECT().Exec()
|
||||
c.EXPECT().Upsert(txC, store.upsertStmt, testObject.Id, testObject, store.shouldEncrypt)
|
||||
|
||||
c.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(nil).Do(
|
||||
@@ -575,11 +571,10 @@ func TestReplace(t *testing.T) {
|
||||
tests = append(tests, testCase{description: "Replace with no DB client errors and no items", test: func(t *testing.T, shouldEncrypt bool) {
|
||||
c, txC := SetupMockDB(t)
|
||||
store := SetupStore(t, c, shouldEncrypt)
|
||||
r := &sql.Rows{}
|
||||
|
||||
txC.EXPECT().Stmt(store.listKeysStmt).Return(store.listKeysStmt)
|
||||
c.EXPECT().QueryForRows(context.Background(), store.listKeysStmt).Return(r, nil)
|
||||
c.EXPECT().ReadStrings(r).Return([]string{}, nil)
|
||||
stmt := NewMockStmt(gomock.NewController(t))
|
||||
txC.EXPECT().Stmt(store.deleteAllStmt).Return(stmt)
|
||||
stmt.EXPECT().Exec()
|
||||
c.EXPECT().Upsert(txC, store.upsertStmt, testObject.Id, testObject, store.shouldEncrypt)
|
||||
|
||||
c.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(nil).Do(
|
||||
@@ -602,39 +597,15 @@ func TestReplace(t *testing.T) {
|
||||
assert.NotNil(t, err)
|
||||
},
|
||||
})
|
||||
tests = append(tests, testCase{description: "Replace with DB client ReadStrings() error", test: func(t *testing.T, shouldEncrypt bool) {
|
||||
tests = append(tests, testCase{description: "Replace with DB client deleteAllStmt error", test: func(t *testing.T, shouldEncrypt bool) {
|
||||
c, txC := SetupMockDB(t)
|
||||
store := SetupStore(t, c, shouldEncrypt)
|
||||
r := &sql.Rows{}
|
||||
|
||||
txC.EXPECT().Stmt(store.listKeysStmt).Return(store.listKeysStmt)
|
||||
c.EXPECT().QueryForRows(context.Background(), store.listKeysStmt).Return(r, nil)
|
||||
c.EXPECT().ReadStrings(r).Return(nil, fmt.Errorf("error"))
|
||||
deleteAllStmt := NewMockStmt(gomock.NewController(t))
|
||||
|
||||
c.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(fmt.Errorf("error")).Do(
|
||||
func(ctx context.Context, shouldEncrypt bool, f db.WithTransactionFunction) {
|
||||
err := f(txC)
|
||||
if err == nil {
|
||||
t.Fail()
|
||||
}
|
||||
})
|
||||
txC.EXPECT().Stmt(store.deleteAllStmt).Return(deleteAllStmt)
|
||||
deleteAllStmt.EXPECT().Exec().Return(nil, fmt.Errorf("error"))
|
||||
|
||||
err := store.Replace([]any{testObject}, testObject.Id)
|
||||
assert.NotNil(t, err)
|
||||
},
|
||||
})
|
||||
tests = append(tests, testCase{description: "Replace with TX client StmtExec() error", test: func(t *testing.T, shouldEncrypt bool) {
|
||||
c, txC := SetupMockDB(t)
|
||||
store := SetupStore(t, c, shouldEncrypt)
|
||||
r := &sql.Rows{}
|
||||
listKeysStmt := NewMockStmt(gomock.NewController(t))
|
||||
deleteStmt := NewMockStmt(gomock.NewController(t))
|
||||
|
||||
txC.EXPECT().Stmt(store.listKeysStmt).Return(listKeysStmt)
|
||||
c.EXPECT().QueryForRows(context.Background(), listKeysStmt).Return(r, nil)
|
||||
c.EXPECT().ReadStrings(r).Return([]string{testObject.Id}, nil)
|
||||
txC.EXPECT().Stmt(store.deleteStmt).Return(deleteStmt)
|
||||
deleteStmt.EXPECT().Exec(testObject.Id).Return(nil, fmt.Errorf("error"))
|
||||
c.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(fmt.Errorf("error")).Do(
|
||||
func(ctx context.Context, shouldEncrypt bool, f db.WithTransactionFunction) {
|
||||
err := f(txC)
|
||||
@@ -650,15 +621,10 @@ func TestReplace(t *testing.T) {
|
||||
tests = append(tests, testCase{description: "Replace with DB client Upsert() error", test: func(t *testing.T, shouldEncrypt bool) {
|
||||
c, txC := SetupMockDB(t)
|
||||
store := SetupStore(t, c, shouldEncrypt)
|
||||
r := &sql.Rows{}
|
||||
listKeysStmt := NewMockStmt(gomock.NewController(t))
|
||||
deleteStmt := NewMockStmt(gomock.NewController(t))
|
||||
deleteAllStmt := NewMockStmt(gomock.NewController(t))
|
||||
|
||||
txC.EXPECT().Stmt(store.listKeysStmt).Return(listKeysStmt)
|
||||
c.EXPECT().QueryForRows(context.Background(), listKeysStmt).Return(r, nil)
|
||||
c.EXPECT().ReadStrings(r).Return([]string{testObject.Id}, nil)
|
||||
txC.EXPECT().Stmt(store.deleteStmt).Return(deleteStmt)
|
||||
deleteStmt.EXPECT().Exec(testObject.Id).Return(nil, nil)
|
||||
txC.EXPECT().Stmt(store.deleteAllStmt).Return(deleteAllStmt)
|
||||
deleteAllStmt.EXPECT().Exec()
|
||||
c.EXPECT().Upsert(txC, store.upsertStmt, testObject.Id, testObject, store.shouldEncrypt).Return(fmt.Errorf("error"))
|
||||
|
||||
c.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(fmt.Errorf("error")).Do(
|
||||
@@ -719,6 +685,7 @@ func SetupMockDB(t *testing.T) (*MockClient, *MockTXClient) {
|
||||
// use stmt mock here
|
||||
dbC.EXPECT().Prepare(fmt.Sprintf(upsertStmtFmt, "testStoreObject")).Return(&sql.Stmt{})
|
||||
dbC.EXPECT().Prepare(fmt.Sprintf(deleteStmtFmt, "testStoreObject")).Return(&sql.Stmt{})
|
||||
dbC.EXPECT().Prepare(fmt.Sprintf(deleteAllStmtFmt, "testStoreObject")).Return(&sql.Stmt{})
|
||||
dbC.EXPECT().Prepare(fmt.Sprintf(getStmtFmt, "testStoreObject")).Return(&sql.Stmt{})
|
||||
dbC.EXPECT().Prepare(fmt.Sprintf(listStmtFmt, "testStoreObject")).Return(&sql.Stmt{})
|
||||
dbC.EXPECT().Prepare(fmt.Sprintf(listKeysStmtFmt, "testStoreObject")).Return(&sql.Stmt{})
|
||||
|
Reference in New Issue
Block a user