diff --git a/.gitignore b/.gitignore index b4b9df1b..b376df47 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ +.DS_Store + /.dapper /.cache /certs @@ -6,3 +8,7 @@ *.swp .idea steve + +informer_object_cache.db +informer_object_cache.db-shm +informer_object_cache.db-wal diff --git a/pkg/sqlcache/db/client.go b/pkg/sqlcache/db/client.go index dffbb6dd..5675bffb 100644 --- a/pkg/sqlcache/db/client.go +++ b/pkg/sqlcache/db/client.go @@ -15,7 +15,8 @@ import ( "reflect" "sync" - "github.com/pkg/errors" + "errors" + "github.com/rancher/steve/pkg/sqlcache/db/transaction" // needed for drivers @@ -29,8 +30,59 @@ const ( informerObjectCachePerms fs.FileMode = 0o600 ) -// Client is a database client that provides encrypting, decrypting, and database resetting. -type Client struct { +// Client defines a database client that provides encrypting, decrypting, and database resetting +type Client interface { + WithTransaction(ctx context.Context, forWriting bool, f WithTransactionFunction) error + Prepare(stmt string) *sql.Stmt + QueryForRows(ctx context.Context, stmt transaction.Stmt, params ...any) (*sql.Rows, error) + ReadObjects(rows Rows, typ reflect.Type, shouldDecrypt bool) ([]any, error) + ReadStrings(rows Rows) ([]string, error) + ReadInt(rows Rows) (int, error) + Upsert(tx transaction.Client, stmt *sql.Stmt, key string, obj any, shouldEncrypt bool) error + CloseStmt(closable Closable) error + NewConnection() error +} + +// WithTransaction runs f within a transaction. +// +// If forWriting is true, this method blocks until all other concurrent forWriting +// transactions have either committed or rolled back. +// If forWriting is false, it is assumed the returned transaction will exclusively +// be used for DQL (e.g. SELECT) queries. +// Not respecting the above rule might result in transactions failing with unexpected +// SQLITE_BUSY (5) errors (aka "Runtime error: database is locked"). +// See discussion in https://github.com/rancher/lasso/pull/98 for details +// +// The transaction is committed if f returns nil, otherwise it is rolled back. +func (c *client) WithTransaction(ctx context.Context, forWriting bool, f WithTransactionFunction) error { + c.connLock.RLock() + // note: this assumes _txlock=immediate in the connection string, see NewConnection + tx, err := c.conn.BeginTx(ctx, &sql.TxOptions{ + ReadOnly: !forWriting, + }) + c.connLock.RUnlock() + if err != nil { + return err + } + + err = f(transaction.NewClient(tx)) + + if err != nil { + rerr := tx.Rollback() + err = errors.Join(err, rerr) + } else { + cerr := tx.Commit() + err = errors.Join(err, cerr) + } + + return err +} + +// WithTransactionFunction is a function that uses a transaction +type WithTransactionFunction func(tx transaction.Client) error + +// client is the main implementation of Client. Other implementations exist for test purposes +type client struct { conn Connection connLock sync.RWMutex encryptor Encryptor @@ -74,15 +126,6 @@ func (e *QueryError) Unwrap() error { return e.Err } -// TXClient represents a sql transaction. The TXClient must manage rollbacks as rollback functionality is not exposed. -type TXClient interface { - StmtExec(stmt transaction.Stmt, args ...any) error - Exec(stmt string, args ...any) error - Commit() error - Stmt(stmt *sql.Stmt) transaction.Stmt - Cancel() error -} - // Encryptor encrypts data with a key which is rotated to avoid wear-out. type Encryptor interface { // Encrypt encrypts the specified data, returning: the encrypted data, the nonce used to encrypt the data, and an ID identifying the key that was used (as it rotates). On failure error is returned instead. @@ -95,9 +138,9 @@ type Decryptor interface { Decrypt([]byte, []byte, uint32) ([]byte, error) } -// NewClient returns a Client. If the given connection is nil then a default one will be created. -func NewClient(c Connection, encryptor Encryptor, decryptor Decryptor) (*Client, error) { - client := &Client{ +// NewClient returns a client. If the given connection is nil then a default one will be created. +func NewClient(c Connection, encryptor Encryptor, decryptor Decryptor) (Client, error) { + client := &client{ encryptor: encryptor, decryptor: decryptor, } @@ -114,19 +157,19 @@ func NewClient(c Connection, encryptor Encryptor, decryptor Decryptor) (*Client, } // Prepare prepares the given string into a sql statement on the client's connection. -func (c *Client) Prepare(stmt string) *sql.Stmt { +func (c *client) Prepare(stmt string) *sql.Stmt { c.connLock.RLock() defer c.connLock.RUnlock() prepared, err := c.conn.Prepare(stmt) if err != nil { - panic(errors.Errorf("Error preparing statement: %s\n%v", stmt, err)) + panic(fmt.Errorf("Error preparing statement: %s\n%w", stmt, err)) } return prepared } // QueryForRows queries the given stmt with the given params and returns the resulting rows. The query wil be retried // given a sqlite busy error. -func (c *Client) QueryForRows(ctx context.Context, stmt transaction.Stmt, params ...any) (*sql.Rows, error) { +func (c *client) QueryForRows(ctx context.Context, stmt transaction.Stmt, params ...any) (*sql.Rows, error) { c.connLock.RLock() defer c.connLock.RUnlock() @@ -135,13 +178,13 @@ func (c *Client) QueryForRows(ctx context.Context, stmt transaction.Stmt, params // CloseStmt will call close on the given Closable. It is intended to be used with a sql statement. This function is meant // to replace stmt.Close which can cause panics when callers unit-test since there usually is no real underlying connection. -func (c *Client) CloseStmt(closable Closable) error { +func (c *client) CloseStmt(closable Closable) error { return closable.Close() } // ReadObjects Scans the given rows, performs any necessary decryption, converts the data to objects of the given type, // and returns a slice of those objects. -func (c *Client) ReadObjects(rows Rows, typ reflect.Type, shouldDecrypt bool) ([]any, error) { +func (c *client) ReadObjects(rows Rows, typ reflect.Type, shouldDecrypt bool) ([]any, error) { c.connLock.RLock() defer c.connLock.RUnlock() @@ -171,7 +214,7 @@ func (c *Client) ReadObjects(rows Rows, typ reflect.Type, shouldDecrypt bool) ([ } // ReadStrings scans the given rows into strings, and then returns the strings as a slice. -func (c *Client) ReadStrings(rows Rows) ([]string, error) { +func (c *client) ReadStrings(rows Rows) ([]string, error) { c.connLock.RLock() defer c.connLock.RUnlock() @@ -199,7 +242,7 @@ func (c *Client) ReadStrings(rows Rows) ([]string, error) { } // ReadInt scans the first of the given rows into a single int (eg. for COUNT() queries) -func (c *Client) ReadInt(rows Rows) (int, error) { +func (c *client) ReadInt(rows Rows) (int, error) { c.connLock.RLock() defer c.connLock.RUnlock() @@ -226,28 +269,7 @@ func (c *Client) ReadInt(rows Rows) (int, error) { return result, nil } -// BeginTx attempts to begin a transaction. -// If forWriting is true, this method blocks until all other concurrent forWriting -// transactions have either committed or rolled back. -// If forWriting is false, it is assumed the returned transaction will exclusively -// be used for DQL (e.g. SELECT) queries. -// Not respecting the above rule might result in transactions failing with unexpected -// SQLITE_BUSY (5) errors (aka "Runtime error: database is locked"). -// See discussion in https://github.com/rancher/lasso/pull/98 for details -func (c *Client) BeginTx(ctx context.Context, forWriting bool) (TXClient, error) { - c.connLock.RLock() - defer c.connLock.RUnlock() - // note: this assumes _txlock=immediate in the connection string, see NewConnection - sqlTx, err := c.conn.BeginTx(ctx, &sql.TxOptions{ - ReadOnly: !forWriting, - }) - if err != nil { - return nil, err - } - return transaction.NewClient(sqlTx), nil -} - -func (c *Client) decryptScan(rows Rows, shouldDecrypt bool) ([]byte, error) { +func (c *client) decryptScan(rows Rows, shouldDecrypt bool) ([]byte, error) { var data, dataNonce sql.RawBytes var kid uint32 err := rows.Scan(&data, &dataNonce, &kid) @@ -264,8 +286,9 @@ func (c *Client) decryptScan(rows Rows, shouldDecrypt bool) ([]byte, error) { return data, nil } -// Upsert used to be called upsertEncrypted in store package before move -func (c *Client) Upsert(tx TXClient, stmt *sql.Stmt, key string, obj any, shouldEncrypt bool) error { +// Upsert executes an upsert statement encrypting arguments if necessary +// note the statement should have 4 parameters: key, objBytes, dataNonce, kid +func (c *client) Upsert(tx transaction.Client, stmt *sql.Stmt, key string, obj any, shouldEncrypt bool) error { objBytes := toBytes(obj) var dataNonce []byte var err error @@ -277,7 +300,8 @@ func (c *Client) Upsert(tx TXClient, stmt *sql.Stmt, key string, obj any, should } } - return tx.StmtExec(tx.Stmt(stmt), key, objBytes, dataNonce, kid) + _, err = tx.Stmt(stmt).Exec(key, objBytes, dataNonce, kid) + return err } // toBytes encodes an object to a byte slice @@ -312,7 +336,7 @@ func closeRowsOnError(rows Rows, err error) error { // NewConnection checks for currently existing connection, closes one if it exists, removes any relevant db files, and opens a new connection which subsequently // creates new files. -func (c *Client) NewConnection() error { +func (c *client) NewConnection() error { c.connLock.Lock() defer c.connLock.Unlock() if c.conn != nil { diff --git a/pkg/sqlcache/db/client_test.go b/pkg/sqlcache/db/client_test.go index 8b7951f1..0f7f5290 100644 --- a/pkg/sqlcache/db/client_test.go +++ b/pkg/sqlcache/db/client_test.go @@ -11,13 +11,14 @@ import ( "reflect" "testing" + "github.com/rancher/steve/pkg/sqlcache/db/transaction" "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" ) // Mocks for this test are generated with the following command. -//go:generate mockgen --build_flags=--mod=mod -package db -destination ./db_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db Rows,Connection,Encryptor,Decryptor,TXClient -//go:generate mockgen --build_flags=--mod=mod -package db -destination ./transaction_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db/transaction Stmt,SQLTx +//go:generate mockgen --build_flags=--mod=mod -package db -destination ./db_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db Rows,Connection,Encryptor,Decryptor +//go:generate mockgen --build_flags=--mod=mod -package db -destination ./transaction_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db/transaction Client,Stmt type testStoreObject struct { Id string @@ -37,7 +38,7 @@ func TestNewClient(t *testing.T) { c := SetupMockConnection(t) e := SetupMockEncryptor(t) d := SetupMockDecryptor(t) - expectedClient := &Client{ + expectedClient := &client{ conn: c, encryptor: e, decryptor: d, @@ -389,58 +390,6 @@ func TestReadInt(t *testing.T) { } } -func TestBegin(t *testing.T) { - type testCase struct { - description string - test func(t *testing.T) - } - - var tests []testCase - - // Tests with shouldEncryptSet to false - tests = append(tests, testCase{description: "BeginTx(), with no errors", test: func(t *testing.T) { - c := SetupMockConnection(t) - e := SetupMockEncryptor(t) - d := SetupMockDecryptor(t) - - sqlTx := &sql.Tx{} - c.EXPECT().BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true}).Return(sqlTx, nil) - client := SetupClient(t, c, e, d) - txC, err := client.BeginTx(context.Background(), false) - assert.Nil(t, err) - assert.NotNil(t, txC) - }, - }) - tests = append(tests, testCase{description: "BeginTx(), with forWriting option set", test: func(t *testing.T) { - c := SetupMockConnection(t) - e := SetupMockEncryptor(t) - d := SetupMockDecryptor(t) - - sqlTx := &sql.Tx{} - c.EXPECT().BeginTx(context.Background(), &sql.TxOptions{ReadOnly: false}).Return(sqlTx, nil) - client := SetupClient(t, c, e, d) - txC, err := client.BeginTx(context.Background(), true) - assert.Nil(t, err) - assert.NotNil(t, txC) - }, - }) - tests = append(tests, testCase{description: "BeginTx(), with connection Begin() error", test: func(t *testing.T) { - c := SetupMockConnection(t) - e := SetupMockEncryptor(t) - d := SetupMockDecryptor(t) - - c.EXPECT().BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true}).Return(nil, fmt.Errorf("error")) - client := SetupClient(t, c, e, d) - _, err := client.BeginTx(context.Background(), false) - assert.NotNil(t, err) - }, - }) - t.Parallel() - for _, test := range tests { - t.Run(test.description, func(t *testing.T) { test.test(t) }) - } -} - func TestUpsert(t *testing.T) { type testCase struct { description string @@ -459,14 +408,14 @@ func TestUpsert(t *testing.T) { d := SetupMockDecryptor(t) client := SetupClient(t, c, e, d) - txC := NewMockTXClient(gomock.NewController(t)) + txC := NewMockClient(gomock.NewController(t)) sqlStmt := &sql.Stmt{} stmt := NewMockStmt(gomock.NewController(t)) testObjBytes := toBytes(testObject) testByteValue := []byte("something") e.EXPECT().Encrypt(testObjBytes).Return(testByteValue, testByteValue, keyID, nil) txC.EXPECT().Stmt(sqlStmt).Return(stmt) - txC.EXPECT().StmtExec(stmt, "somekey", testByteValue, testByteValue, keyID).Return(nil) + stmt.EXPECT().Exec("somekey", testByteValue, testByteValue, keyID).Return(nil, nil) err := client.Upsert(txC, sqlStmt, "somekey", testObject, true) assert.Nil(t, err) }, @@ -477,7 +426,7 @@ func TestUpsert(t *testing.T) { d := SetupMockDecryptor(t) client := SetupClient(t, c, e, d) - txC := NewMockTXClient(gomock.NewController(t)) + txC := NewMockClient(gomock.NewController(t)) sqlStmt := &sql.Stmt{} testObjBytes := toBytes(testObject) e.EXPECT().Encrypt(testObjBytes).Return(nil, nil, uint32(0), fmt.Errorf("error")) @@ -491,14 +440,14 @@ func TestUpsert(t *testing.T) { d := SetupMockDecryptor(t) client := SetupClient(t, c, e, d) - txC := NewMockTXClient(gomock.NewController(t)) + txC := NewMockClient(gomock.NewController(t)) sqlStmt := &sql.Stmt{} stmt := NewMockStmt(gomock.NewController(t)) testObjBytes := toBytes(testObject) testByteValue := []byte("something") e.EXPECT().Encrypt(testObjBytes).Return(testByteValue, testByteValue, keyID, nil) txC.EXPECT().Stmt(sqlStmt).Return(stmt) - txC.EXPECT().StmtExec(stmt, "somekey", testByteValue, testByteValue, keyID).Return(fmt.Errorf("error")) + stmt.EXPECT().Exec("somekey", testByteValue, testByteValue, keyID).Return(nil, fmt.Errorf("error")) err := client.Upsert(txC, sqlStmt, "somekey", testObject, true) assert.NotNil(t, err) }, @@ -509,13 +458,13 @@ func TestUpsert(t *testing.T) { e := SetupMockEncryptor(t) client := SetupClient(t, c, e, d) - txC := NewMockTXClient(gomock.NewController(t)) + txC := NewMockClient(gomock.NewController(t)) sqlStmt := &sql.Stmt{} stmt := NewMockStmt(gomock.NewController(t)) var testByteValue []byte testObjBytes := toBytes(testObject) txC.EXPECT().Stmt(sqlStmt).Return(stmt) - txC.EXPECT().StmtExec(stmt, "somekey", testObjBytes, testByteValue, uint32(0)).Return(nil) + stmt.EXPECT().Exec("somekey", testObjBytes, testByteValue, uint32(0)).Return(nil, nil) err := client.Upsert(txC, sqlStmt, "somekey", testObject, false) assert.Nil(t, err) }, @@ -582,9 +531,10 @@ func TestNewConnection(t *testing.T) { assert.Nil(t, err) // Create a transaction to ensure that the file is written to disk. - txC, err := client.BeginTx(context.Background(), false) + err = client.WithTransaction(context.Background(), false, func(tx transaction.Client) error { + return nil + }) assert.NoError(t, err) - assert.NoError(t, txC.Commit()) assert.FileExists(t, InformerObjectCacheDBPath) assertFileHasPermissions(t, InformerObjectCacheDBPath, 0600) @@ -630,7 +580,7 @@ func SetupMockRows(t *testing.T) *MockRows { return MockR } -func SetupClient(t *testing.T, connection Connection, encryptor Encryptor, decryptor Decryptor) *Client { +func SetupClient(t *testing.T, connection Connection, encryptor Encryptor, decryptor Decryptor) Client { c, _ := NewClient(connection, encryptor, decryptor) return c } diff --git a/pkg/sqlcache/db/db_mocks_test.go b/pkg/sqlcache/db/db_mocks_test.go index 54199ba4..a7559615 100644 --- a/pkg/sqlcache/db/db_mocks_test.go +++ b/pkg/sqlcache/db/db_mocks_test.go @@ -1,9 +1,9 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/rancher/steve/pkg/sqlcache/db (interfaces: Rows,Connection,Encryptor,Decryptor,TXClient) +// Source: github.com/rancher/steve/pkg/sqlcache/db (interfaces: Rows,Connection,Encryptor,Decryptor) // // Generated by this command: // -// mockgen --build_flags=--mod=mod -package db -destination ./db_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db Rows,Connection,Encryptor,Decryptor,TXClient +// mockgen --build_flags=--mod=mod -package db -destination ./db_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db Rows,Connection,Encryptor,Decryptor // // Package db is a generated GoMock package. @@ -14,7 +14,6 @@ import ( sql "database/sql" reflect "reflect" - transaction "github.com/rancher/steve/pkg/sqlcache/db/transaction" gomock "go.uber.org/mock/gomock" ) @@ -265,106 +264,3 @@ func (mr *MockDecryptorMockRecorder) Decrypt(arg0, arg1, arg2 any) *gomock.Call mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Decrypt", reflect.TypeOf((*MockDecryptor)(nil).Decrypt), arg0, arg1, arg2) } - -// MockTXClient is a mock of TXClient interface. -type MockTXClient struct { - ctrl *gomock.Controller - recorder *MockTXClientMockRecorder -} - -// MockTXClientMockRecorder is the mock recorder for MockTXClient. -type MockTXClientMockRecorder struct { - mock *MockTXClient -} - -// NewMockTXClient creates a new mock instance. -func NewMockTXClient(ctrl *gomock.Controller) *MockTXClient { - mock := &MockTXClient{ctrl: ctrl} - mock.recorder = &MockTXClientMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockTXClient) EXPECT() *MockTXClientMockRecorder { - return m.recorder -} - -// Cancel mocks base method. -func (m *MockTXClient) Cancel() error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Cancel") - ret0, _ := ret[0].(error) - return ret0 -} - -// Cancel indicates an expected call of Cancel. -func (mr *MockTXClientMockRecorder) Cancel() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Cancel", reflect.TypeOf((*MockTXClient)(nil).Cancel)) -} - -// Commit mocks base method. -func (m *MockTXClient) Commit() error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Commit") - ret0, _ := ret[0].(error) - return ret0 -} - -// Commit indicates an expected call of Commit. -func (mr *MockTXClientMockRecorder) Commit() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockTXClient)(nil).Commit)) -} - -// Exec mocks base method. -func (m *MockTXClient) Exec(arg0 string, arg1 ...any) error { - m.ctrl.T.Helper() - varargs := []any{arg0} - for _, a := range arg1 { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "Exec", varargs...) - ret0, _ := ret[0].(error) - return ret0 -} - -// Exec indicates an expected call of Exec. -func (mr *MockTXClientMockRecorder) Exec(arg0 any, arg1 ...any) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0}, arg1...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockTXClient)(nil).Exec), varargs...) -} - -// Stmt mocks base method. -func (m *MockTXClient) Stmt(arg0 *sql.Stmt) transaction.Stmt { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Stmt", arg0) - ret0, _ := ret[0].(transaction.Stmt) - return ret0 -} - -// Stmt indicates an expected call of Stmt. -func (mr *MockTXClientMockRecorder) Stmt(arg0 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stmt", reflect.TypeOf((*MockTXClient)(nil).Stmt), arg0) -} - -// StmtExec mocks base method. -func (m *MockTXClient) StmtExec(arg0 transaction.Stmt, arg1 ...any) error { - m.ctrl.T.Helper() - varargs := []any{arg0} - for _, a := range arg1 { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "StmtExec", varargs...) - ret0, _ := ret[0].(error) - return ret0 -} - -// StmtExec indicates an expected call of StmtExec. -func (mr *MockTXClientMockRecorder) StmtExec(arg0 any, arg1 ...any) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0}, arg1...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StmtExec", reflect.TypeOf((*MockTXClient)(nil).StmtExec), varargs...) -} diff --git a/pkg/sqlcache/db/transaction/transaction.go b/pkg/sqlcache/db/transaction/transaction.go index 68b476ea..cc9e0901 100644 --- a/pkg/sqlcache/db/transaction/transaction.go +++ b/pkg/sqlcache/db/transaction/transaction.go @@ -1,92 +1,44 @@ /* -Package transaction provides a client for a live transaction, and interfaces for some relevant sql types. The transaction client automatically performs rollbacks on failures. -The use of this package simplifies testing for callers by making the underlying transaction mock-able. +Package transaction provides mockable interfaces of sql package struct types. */ package transaction import ( "context" "database/sql" - - "github.com/pkg/errors" - "github.com/sirupsen/logrus" ) -// Client provides a way to interact with the underlying sql transaction. -type Client struct { - sqlTx SQLTx -} - -// SQLTx represents a sql transaction -type SQLTx interface { +// Client is an interface over a subset of sql.Tx methods +// rationale 1: explicitly forbid direct access to Commit and Rollback functionality +// as that is exclusively dealt with by WithTransaction in ../db +// rationale 2: allow mocking +type Client interface { Exec(query string, args ...any) (sql.Result, error) - Stmt(stmt *sql.Stmt) *sql.Stmt - Commit() error - Rollback() error + Stmt(stmt *sql.Stmt) Stmt } -// Stmt represents a sql stmt. It is used as a return type to offer some testability over returning sql's Stmt type -// because we are able to mock its outputs and do not need an actual connection. +// client is the main implementation of Client, delegates to sql.Tx +// other implementations exist for testing purposes +type client struct { + tx *sql.Tx +} + +func NewClient(tx *sql.Tx) Client { + return &client{tx: tx} +} + +func (c client) Exec(query string, args ...any) (sql.Result, error) { + return c.tx.Exec(query, args...) +} + +func (c client) Stmt(stmt *sql.Stmt) Stmt { + return c.tx.Stmt(stmt) +} + +// Stmt is an interface over a subset of sql.Stmt methods +// rationale: allow mocking type Stmt interface { Exec(args ...any) (sql.Result, error) Query(args ...any) (*sql.Rows, error) QueryContext(ctx context.Context, args ...any) (*sql.Rows, error) } - -// NewClient returns a Client with the given transaction assigned. -func NewClient(tx SQLTx) *Client { - return &Client{sqlTx: tx} -} - -// Commit commits the transaction and then unlocks the database. -func (c *Client) Commit() error { - return c.sqlTx.Commit() -} - -// Exec uses the sqlTX Exec() with the given stmt and args. The transaction will be automatically rolled back if Exec() -// returns an error. -func (c *Client) Exec(stmt string, args ...any) error { - _, err := c.sqlTx.Exec(stmt, args...) - if err != nil { - return c.rollback(c.sqlTx, err) - } - return nil -} - -// Stmt adds the given sql.Stmt to the client's transaction and then returns a Stmt. An interface is being returned -// here to aid in testing callers by providing a way to configure the statement's behavior. -func (c *Client) Stmt(stmt *sql.Stmt) Stmt { - s := c.sqlTx.Stmt(stmt) - return s -} - -// StmtExec Execs the given statement with the given args. It assumes the stmt has been added to the transaction. The -// transaction is rolled back if Stmt.Exec() returns an error. -func (c *Client) StmtExec(stmt Stmt, args ...any) error { - _, err := stmt.Exec(args...) - if err != nil { - logrus.Debugf("StmtExec failed: query %s, args: %s, err: %s", stmt, args, err) - return c.rollback(c.sqlTx, err) - } - return nil -} - -// rollback handles rollbacks and wraps errors if needed -func (c *Client) rollback(tx SQLTx, err error) error { - rerr := tx.Rollback() - if rerr != nil { - return errors.Wrapf(err, "Encountered error, then encountered another error while rolling back: %v", rerr) - } - return errors.Wrapf(err, "Encountered error, successfully rolled back") -} - -// Cancel rollbacks the transaction without wrapping an error. This only needs to be called if Client has not returned -// an error yet or has not committed. Otherwise, transaction has already rolled back, or in the case of Commit() it is too -// late. -func (c *Client) Cancel() error { - rerr := c.sqlTx.Rollback() - if rerr != sql.ErrTxDone { - return rerr - } - return nil -} diff --git a/pkg/sqlcache/db/transaction/transaction_test.go b/pkg/sqlcache/db/transaction/transaction_test.go deleted file mode 100644 index 0ede5d2e..00000000 --- a/pkg/sqlcache/db/transaction/transaction_test.go +++ /dev/null @@ -1,182 +0,0 @@ -package transaction - -import ( - "database/sql" - "fmt" - "testing" - - "github.com/stretchr/testify/assert" - "go.uber.org/mock/gomock" -) - -//go:generate mockgen --build_flags=--mod=mod -package transaction -destination ./transaction_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db/transaction Stmt,SQLTx - -func TestNewClient(t *testing.T) { - tx := NewMockSQLTx(gomock.NewController(t)) - c := NewClient(tx) - assert.Equal(t, tx, c.sqlTx) -} - -func TestCommit(t *testing.T) { - type testCase struct { - description string - test func(t *testing.T) - } - - var tests []testCase - - tests = append(tests, testCase{description: "Commit() with no errors returned from sql TX should return no error", test: func(t *testing.T) { - tx := NewMockSQLTx(gomock.NewController(t)) - tx.EXPECT().Commit().Return(nil) - c := &Client{ - sqlTx: tx, - } - err := c.Commit() - assert.Nil(t, err) - }}) - tests = append(tests, testCase{description: "Commit() with error from sql TX commit() should return error", test: func(t *testing.T) { - tx := NewMockSQLTx(gomock.NewController(t)) - tx.EXPECT().Commit().Return(fmt.Errorf("error")) - c := &Client{ - sqlTx: tx, - } - err := c.Commit() - assert.NotNil(t, err) - }}) - t.Parallel() - for _, test := range tests { - t.Run(test.description, func(t *testing.T) { test.test(t) }) - } -} - -func TestExec(t *testing.T) { - type testCase struct { - description string - test func(t *testing.T) - } - - var tests []testCase - - tests = append(tests, testCase{description: "Exec() with no errors returned from sql TX should return no error", test: func(t *testing.T) { - tx := NewMockSQLTx(gomock.NewController(t)) - stmtStr := "some statement %s" - arg := 5 - // should be passed same statement and arg that was passed to parent function - tx.EXPECT().Exec(stmtStr, arg).Return(nil, nil) - c := &Client{ - sqlTx: tx, - } - err := c.Exec(stmtStr, arg) - assert.Nil(t, err) - }}) - tests = append(tests, testCase{description: "Exec() with error returned from sql TX Exec() and Rollback() error should return an error", test: func(t *testing.T) { - tx := NewMockSQLTx(gomock.NewController(t)) - stmtStr := "some statement %s" - arg := 5 - // should be passed same statement and arg that was passed to parent function - tx.EXPECT().Exec(stmtStr, arg).Return(nil, fmt.Errorf("error")) - tx.EXPECT().Rollback().Return(nil) - c := &Client{ - sqlTx: tx, - } - err := c.Exec(stmtStr, arg) - assert.NotNil(t, err) - }}) - tests = append(tests, testCase{description: "Exec() with error returned from sql TX Exec() and Rollback() error should return an error", test: func(t *testing.T) { - tx := NewMockSQLTx(gomock.NewController(t)) - stmtStr := "some statement %s" - arg := 5 - // should be passed same statement and arg that was passed to parent function - tx.EXPECT().Exec(stmtStr, arg).Return(nil, fmt.Errorf("error")) - tx.EXPECT().Rollback().Return(fmt.Errorf("error")) - c := &Client{ - sqlTx: tx, - } - err := c.Exec(stmtStr, arg) - assert.NotNil(t, err) - }}) - t.Parallel() - for _, test := range tests { - t.Run(test.description, func(t *testing.T) { test.test(t) }) - } -} - -func TestStmt(t *testing.T) { - type testCase struct { - description string - test func(t *testing.T) - } - - var tests []testCase - - tests = append(tests, testCase{description: "Exec() with no errors returned from sql TX should return no error", test: func(t *testing.T) { - tx := NewMockSQLTx(gomock.NewController(t)) - stmt := &sql.Stmt{} - var returnedTXStmt *sql.Stmt - // should be passed same statement and arg that was passed to parent function - tx.EXPECT().Stmt(stmt).Return(returnedTXStmt) - c := &Client{ - sqlTx: tx, - } - returnedStmt := c.Stmt(stmt) - // whatever tx returned should be returned here. Nil was used because none of sql.Stmt's fields are exported so its simpler to test nil as it - // won't be equal to an empty struct - assert.Equal(t, returnedTXStmt, returnedStmt) - }}) - t.Parallel() - for _, test := range tests { - t.Run(test.description, func(t *testing.T) { test.test(t) }) - } -} - -func TestStmtExec(t *testing.T) { - type testCase struct { - description string - test func(t *testing.T) - } - - var tests []testCase - - tests = append(tests, testCase{description: "StmtExec with no errors returned from Stmt should return no error", test: func(t *testing.T) { - tx := NewMockSQLTx(gomock.NewController(t)) - stmt := NewMockStmt(gomock.NewController(t)) - arg := "something" - // should be passed same arg that was passed to parent function - stmt.EXPECT().Exec(arg).Return(nil, nil) - c := &Client{ - sqlTx: tx, - } - err := c.StmtExec(stmt, arg) - assert.Nil(t, err) - }}) - tests = append(tests, testCase{description: "StmtExec with error returned from Stmt Exec and no Tx Rollback() error should return error", test: func(t *testing.T) { - tx := NewMockSQLTx(gomock.NewController(t)) - stmt := NewMockStmt(gomock.NewController(t)) - arg := "something" - // should be passed same arg that was passed to parent function - stmt.EXPECT().Exec(arg).Return(nil, fmt.Errorf("error")) - tx.EXPECT().Rollback().Return(nil) - c := &Client{ - sqlTx: tx, - } - err := c.StmtExec(stmt, arg) - assert.NotNil(t, err) - }}) - tests = append(tests, testCase{description: "StmtExec with error returned from Stmt Exec and Tx Rollback() error should return error", test: func(t *testing.T) { - tx := NewMockSQLTx(gomock.NewController(t)) - stmt := NewMockStmt(gomock.NewController(t)) - arg := "something" - // should be passed same arg that was passed to parent function - stmt.EXPECT().Exec(arg).Return(nil, fmt.Errorf("error")) - tx.EXPECT().Rollback().Return(fmt.Errorf("error2")) - c := &Client{ - sqlTx: tx, - } - err := c.StmtExec(stmt, arg) - assert.NotNil(t, err) - }}) - t.Parallel() - for _, test := range tests { - t.Run(test.description, func(t *testing.T) { test.test(t) }) - } -} diff --git a/pkg/sqlcache/db/transaction_mocks_test.go b/pkg/sqlcache/db/transaction_mocks_test.go index 1cac5caf..4a64cc16 100644 --- a/pkg/sqlcache/db/transaction_mocks_test.go +++ b/pkg/sqlcache/db/transaction_mocks_test.go @@ -1,9 +1,9 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/rancher/steve/pkg/sqlcache/db/transaction (interfaces: Stmt,SQLTx) +// Source: github.com/rancher/steve/pkg/sqlcache/db/transaction (interfaces: Client,Stmt) // // Generated by this command: // -// mockgen --build_flags=--mod=mod -package db -destination ./transaction_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db/transaction Stmt,SQLTx +// mockgen --build_flags=--mod=mod -package db -destination ./transaction_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db/transaction Client,Stmt // // Package db is a generated GoMock package. @@ -14,9 +14,67 @@ import ( sql "database/sql" reflect "reflect" + transaction "github.com/rancher/steve/pkg/sqlcache/db/transaction" gomock "go.uber.org/mock/gomock" ) +// MockClient is a mock of Client interface. +type MockClient struct { + ctrl *gomock.Controller + recorder *MockClientMockRecorder +} + +// MockClientMockRecorder is the mock recorder for MockClient. +type MockClientMockRecorder struct { + mock *MockClient +} + +// NewMockClient creates a new mock instance. +func NewMockClient(ctrl *gomock.Controller) *MockClient { + mock := &MockClient{ctrl: ctrl} + mock.recorder = &MockClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockClient) EXPECT() *MockClientMockRecorder { + return m.recorder +} + +// Exec mocks base method. +func (m *MockClient) Exec(arg0 string, arg1 ...any) (sql.Result, error) { + m.ctrl.T.Helper() + varargs := []any{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Exec", varargs...) + ret0, _ := ret[0].(sql.Result) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Exec indicates an expected call of Exec. +func (mr *MockClientMockRecorder) Exec(arg0 any, arg1 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockClient)(nil).Exec), varargs...) +} + +// Stmt mocks base method. +func (m *MockClient) Stmt(arg0 *sql.Stmt) transaction.Stmt { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Stmt", arg0) + ret0, _ := ret[0].(transaction.Stmt) + return ret0 +} + +// Stmt indicates an expected call of Stmt. +func (mr *MockClientMockRecorder) Stmt(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stmt", reflect.TypeOf((*MockClient)(nil).Stmt), arg0) +} + // MockStmt is a mock of Stmt interface. type MockStmt struct { ctrl *gomock.Controller @@ -97,88 +155,3 @@ func (mr *MockStmtMockRecorder) QueryContext(arg0 any, arg1 ...any) *gomock.Call varargs := append([]any{arg0}, arg1...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryContext", reflect.TypeOf((*MockStmt)(nil).QueryContext), varargs...) } - -// MockSQLTx is a mock of SQLTx interface. -type MockSQLTx struct { - ctrl *gomock.Controller - recorder *MockSQLTxMockRecorder -} - -// MockSQLTxMockRecorder is the mock recorder for MockSQLTx. -type MockSQLTxMockRecorder struct { - mock *MockSQLTx -} - -// NewMockSQLTx creates a new mock instance. -func NewMockSQLTx(ctrl *gomock.Controller) *MockSQLTx { - mock := &MockSQLTx{ctrl: ctrl} - mock.recorder = &MockSQLTxMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockSQLTx) EXPECT() *MockSQLTxMockRecorder { - return m.recorder -} - -// Commit mocks base method. -func (m *MockSQLTx) Commit() error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Commit") - ret0, _ := ret[0].(error) - return ret0 -} - -// Commit indicates an expected call of Commit. -func (mr *MockSQLTxMockRecorder) Commit() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockSQLTx)(nil).Commit)) -} - -// Exec mocks base method. -func (m *MockSQLTx) Exec(arg0 string, arg1 ...any) (sql.Result, error) { - m.ctrl.T.Helper() - varargs := []any{arg0} - for _, a := range arg1 { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "Exec", varargs...) - ret0, _ := ret[0].(sql.Result) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Exec indicates an expected call of Exec. -func (mr *MockSQLTxMockRecorder) Exec(arg0 any, arg1 ...any) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0}, arg1...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockSQLTx)(nil).Exec), varargs...) -} - -// Rollback mocks base method. -func (m *MockSQLTx) Rollback() error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Rollback") - ret0, _ := ret[0].(error) - return ret0 -} - -// Rollback indicates an expected call of Rollback. -func (mr *MockSQLTxMockRecorder) Rollback() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Rollback", reflect.TypeOf((*MockSQLTx)(nil).Rollback)) -} - -// Stmt mocks base method. -func (m *MockSQLTx) Stmt(arg0 *sql.Stmt) *sql.Stmt { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Stmt", arg0) - ret0, _ := ret[0].(*sql.Stmt) - return ret0 -} - -// Stmt indicates an expected call of Stmt. -func (mr *MockSQLTxMockRecorder) Stmt(arg0 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stmt", reflect.TypeOf((*MockSQLTx)(nil).Stmt), arg0) -} diff --git a/pkg/sqlcache/informer/db_mocks_test.go b/pkg/sqlcache/informer/db_mocks_test.go index 7d2c81ce..63fa214f 100644 --- a/pkg/sqlcache/informer/db_mocks_test.go +++ b/pkg/sqlcache/informer/db_mocks_test.go @@ -1,125 +1,24 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/rancher/steve/pkg/sqlcache/db (interfaces: TXClient,Rows) +// Source: github.com/rancher/steve/pkg/sqlcache/db (interfaces: Rows,Client) // // Generated by this command: // -// mockgen --build_flags=--mod=mod -package informer -destination ./db_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db TXClient,Rows +// mockgen --build_flags=--mod=mod -package informer -destination ./db_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db Rows,Client // // Package informer is a generated GoMock package. package informer import ( + context "context" sql "database/sql" reflect "reflect" + db "github.com/rancher/steve/pkg/sqlcache/db" transaction "github.com/rancher/steve/pkg/sqlcache/db/transaction" gomock "go.uber.org/mock/gomock" ) -// MockTXClient is a mock of TXClient interface. -type MockTXClient struct { - ctrl *gomock.Controller - recorder *MockTXClientMockRecorder -} - -// MockTXClientMockRecorder is the mock recorder for MockTXClient. -type MockTXClientMockRecorder struct { - mock *MockTXClient -} - -// NewMockTXClient creates a new mock instance. -func NewMockTXClient(ctrl *gomock.Controller) *MockTXClient { - mock := &MockTXClient{ctrl: ctrl} - mock.recorder = &MockTXClientMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockTXClient) EXPECT() *MockTXClientMockRecorder { - return m.recorder -} - -// Cancel mocks base method. -func (m *MockTXClient) Cancel() error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Cancel") - ret0, _ := ret[0].(error) - return ret0 -} - -// Cancel indicates an expected call of Cancel. -func (mr *MockTXClientMockRecorder) Cancel() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Cancel", reflect.TypeOf((*MockTXClient)(nil).Cancel)) -} - -// Commit mocks base method. -func (m *MockTXClient) Commit() error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Commit") - ret0, _ := ret[0].(error) - return ret0 -} - -// Commit indicates an expected call of Commit. -func (mr *MockTXClientMockRecorder) Commit() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockTXClient)(nil).Commit)) -} - -// Exec mocks base method. -func (m *MockTXClient) Exec(arg0 string, arg1 ...any) error { - m.ctrl.T.Helper() - varargs := []any{arg0} - for _, a := range arg1 { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "Exec", varargs...) - ret0, _ := ret[0].(error) - return ret0 -} - -// Exec indicates an expected call of Exec. -func (mr *MockTXClientMockRecorder) Exec(arg0 any, arg1 ...any) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0}, arg1...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockTXClient)(nil).Exec), varargs...) -} - -// Stmt mocks base method. -func (m *MockTXClient) Stmt(arg0 *sql.Stmt) transaction.Stmt { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Stmt", arg0) - ret0, _ := ret[0].(transaction.Stmt) - return ret0 -} - -// Stmt indicates an expected call of Stmt. -func (mr *MockTXClientMockRecorder) Stmt(arg0 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stmt", reflect.TypeOf((*MockTXClient)(nil).Stmt), arg0) -} - -// StmtExec mocks base method. -func (m *MockTXClient) StmtExec(arg0 transaction.Stmt, arg1 ...any) error { - m.ctrl.T.Helper() - varargs := []any{arg0} - for _, a := range arg1 { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "StmtExec", varargs...) - ret0, _ := ret[0].(error) - return ret0 -} - -// StmtExec indicates an expected call of StmtExec. -func (mr *MockTXClientMockRecorder) StmtExec(arg0 any, arg1 ...any) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0}, arg1...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StmtExec", reflect.TypeOf((*MockTXClient)(nil).StmtExec), varargs...) -} - // MockRows is a mock of Rows interface. type MockRows struct { ctrl *gomock.Controller @@ -202,3 +101,161 @@ func (mr *MockRowsMockRecorder) Scan(arg0 ...any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Scan", reflect.TypeOf((*MockRows)(nil).Scan), arg0...) } + +// MockClient is a mock of Client interface. +type MockClient struct { + ctrl *gomock.Controller + recorder *MockClientMockRecorder +} + +// MockClientMockRecorder is the mock recorder for MockClient. +type MockClientMockRecorder struct { + mock *MockClient +} + +// NewMockClient creates a new mock instance. +func NewMockClient(ctrl *gomock.Controller) *MockClient { + mock := &MockClient{ctrl: ctrl} + mock.recorder = &MockClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockClient) EXPECT() *MockClientMockRecorder { + return m.recorder +} + +// CloseStmt mocks base method. +func (m *MockClient) CloseStmt(arg0 db.Closable) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CloseStmt", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// CloseStmt indicates an expected call of CloseStmt. +func (mr *MockClientMockRecorder) CloseStmt(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseStmt", reflect.TypeOf((*MockClient)(nil).CloseStmt), arg0) +} + +// NewConnection mocks base method. +func (m *MockClient) NewConnection() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NewConnection") + ret0, _ := ret[0].(error) + return ret0 +} + +// NewConnection indicates an expected call of NewConnection. +func (mr *MockClientMockRecorder) NewConnection() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewConnection", reflect.TypeOf((*MockClient)(nil).NewConnection)) +} + +// Prepare mocks base method. +func (m *MockClient) Prepare(arg0 string) *sql.Stmt { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Prepare", arg0) + ret0, _ := ret[0].(*sql.Stmt) + return ret0 +} + +// Prepare indicates an expected call of Prepare. +func (mr *MockClientMockRecorder) Prepare(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Prepare", reflect.TypeOf((*MockClient)(nil).Prepare), arg0) +} + +// QueryForRows mocks base method. +func (m *MockClient) QueryForRows(arg0 context.Context, arg1 transaction.Stmt, arg2 ...any) (*sql.Rows, error) { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "QueryForRows", varargs...) + ret0, _ := ret[0].(*sql.Rows) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// QueryForRows indicates an expected call of QueryForRows. +func (mr *MockClientMockRecorder) QueryForRows(arg0, arg1 any, arg2 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryForRows", reflect.TypeOf((*MockClient)(nil).QueryForRows), varargs...) +} + +// ReadInt mocks base method. +func (m *MockClient) ReadInt(arg0 db.Rows) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadInt", arg0) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReadInt indicates an expected call of ReadInt. +func (mr *MockClientMockRecorder) ReadInt(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadInt", reflect.TypeOf((*MockClient)(nil).ReadInt), arg0) +} + +// ReadObjects mocks base method. +func (m *MockClient) ReadObjects(arg0 db.Rows, arg1 reflect.Type, arg2 bool) ([]any, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadObjects", arg0, arg1, arg2) + ret0, _ := ret[0].([]any) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReadObjects indicates an expected call of ReadObjects. +func (mr *MockClientMockRecorder) ReadObjects(arg0, arg1, arg2 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadObjects", reflect.TypeOf((*MockClient)(nil).ReadObjects), arg0, arg1, arg2) +} + +// ReadStrings mocks base method. +func (m *MockClient) ReadStrings(arg0 db.Rows) ([]string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadStrings", arg0) + ret0, _ := ret[0].([]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReadStrings indicates an expected call of ReadStrings. +func (mr *MockClientMockRecorder) ReadStrings(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadStrings", reflect.TypeOf((*MockClient)(nil).ReadStrings), arg0) +} + +// Upsert mocks base method. +func (m *MockClient) Upsert(arg0 transaction.Client, arg1 *sql.Stmt, arg2 string, arg3 any, arg4 bool) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Upsert", arg0, arg1, arg2, arg3, arg4) + ret0, _ := ret[0].(error) + return ret0 +} + +// Upsert indicates an expected call of Upsert. +func (mr *MockClientMockRecorder) Upsert(arg0, arg1, arg2, arg3, arg4 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Upsert", reflect.TypeOf((*MockClient)(nil).Upsert), arg0, arg1, arg2, arg3, arg4) +} + +// WithTransaction mocks base method. +func (m *MockClient) WithTransaction(arg0 context.Context, arg1 bool, arg2 db.WithTransactionFunction) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "WithTransaction", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// WithTransaction indicates an expected call of WithTransaction. +func (mr *MockClientMockRecorder) WithTransaction(arg0, arg1, arg2 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WithTransaction", reflect.TypeOf((*MockClient)(nil).WithTransaction), arg0, arg1, arg2) +} diff --git a/pkg/sqlcache/informer/factory/db_mocks_test.go b/pkg/sqlcache/informer/factory/db_mocks_test.go index 9ac55bb3..76fca697 100644 --- a/pkg/sqlcache/informer/factory/db_mocks_test.go +++ b/pkg/sqlcache/informer/factory/db_mocks_test.go @@ -1,121 +1,178 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/rancher/steve/pkg/sqlcache/db (interfaces: TXClient) +// Source: github.com/rancher/steve/pkg/sqlcache/db (interfaces: Client) // // Generated by this command: // -// mockgen --build_flags=--mod=mod -package factory -destination ./db_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db TXClient +// mockgen --build_flags=--mod=mod -package factory -destination ./db_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db Client // // Package factory is a generated GoMock package. package factory import ( + context "context" sql "database/sql" reflect "reflect" + db "github.com/rancher/steve/pkg/sqlcache/db" transaction "github.com/rancher/steve/pkg/sqlcache/db/transaction" gomock "go.uber.org/mock/gomock" ) -// MockTXClient is a mock of TXClient interface. -type MockTXClient struct { +// MockClient is a mock of Client interface. +type MockClient struct { ctrl *gomock.Controller - recorder *MockTXClientMockRecorder + recorder *MockClientMockRecorder } -// MockTXClientMockRecorder is the mock recorder for MockTXClient. -type MockTXClientMockRecorder struct { - mock *MockTXClient +// MockClientMockRecorder is the mock recorder for MockClient. +type MockClientMockRecorder struct { + mock *MockClient } -// NewMockTXClient creates a new mock instance. -func NewMockTXClient(ctrl *gomock.Controller) *MockTXClient { - mock := &MockTXClient{ctrl: ctrl} - mock.recorder = &MockTXClientMockRecorder{mock} +// NewMockClient creates a new mock instance. +func NewMockClient(ctrl *gomock.Controller) *MockClient { + mock := &MockClient{ctrl: ctrl} + mock.recorder = &MockClientMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockTXClient) EXPECT() *MockTXClientMockRecorder { +func (m *MockClient) EXPECT() *MockClientMockRecorder { return m.recorder } -// Cancel mocks base method. -func (m *MockTXClient) Cancel() error { +// CloseStmt mocks base method. +func (m *MockClient) CloseStmt(arg0 db.Closable) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Cancel") + ret := m.ctrl.Call(m, "CloseStmt", arg0) ret0, _ := ret[0].(error) return ret0 } -// Cancel indicates an expected call of Cancel. -func (mr *MockTXClientMockRecorder) Cancel() *gomock.Call { +// CloseStmt indicates an expected call of CloseStmt. +func (mr *MockClientMockRecorder) CloseStmt(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Cancel", reflect.TypeOf((*MockTXClient)(nil).Cancel)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseStmt", reflect.TypeOf((*MockClient)(nil).CloseStmt), arg0) } -// Commit mocks base method. -func (m *MockTXClient) Commit() error { +// NewConnection mocks base method. +func (m *MockClient) NewConnection() error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Commit") + ret := m.ctrl.Call(m, "NewConnection") ret0, _ := ret[0].(error) return ret0 } -// Commit indicates an expected call of Commit. -func (mr *MockTXClientMockRecorder) Commit() *gomock.Call { +// NewConnection indicates an expected call of NewConnection. +func (mr *MockClientMockRecorder) NewConnection() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockTXClient)(nil).Commit)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewConnection", reflect.TypeOf((*MockClient)(nil).NewConnection)) } -// Exec mocks base method. -func (m *MockTXClient) Exec(arg0 string, arg1 ...any) error { +// Prepare mocks base method. +func (m *MockClient) Prepare(arg0 string) *sql.Stmt { m.ctrl.T.Helper() - varargs := []any{arg0} - for _, a := range arg1 { + ret := m.ctrl.Call(m, "Prepare", arg0) + ret0, _ := ret[0].(*sql.Stmt) + return ret0 +} + +// Prepare indicates an expected call of Prepare. +func (mr *MockClientMockRecorder) Prepare(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Prepare", reflect.TypeOf((*MockClient)(nil).Prepare), arg0) +} + +// QueryForRows mocks base method. +func (m *MockClient) QueryForRows(arg0 context.Context, arg1 transaction.Stmt, arg2 ...any) (*sql.Rows, error) { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1} + for _, a := range arg2 { varargs = append(varargs, a) } - ret := m.ctrl.Call(m, "Exec", varargs...) + ret := m.ctrl.Call(m, "QueryForRows", varargs...) + ret0, _ := ret[0].(*sql.Rows) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// QueryForRows indicates an expected call of QueryForRows. +func (mr *MockClientMockRecorder) QueryForRows(arg0, arg1 any, arg2 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryForRows", reflect.TypeOf((*MockClient)(nil).QueryForRows), varargs...) +} + +// ReadInt mocks base method. +func (m *MockClient) ReadInt(arg0 db.Rows) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadInt", arg0) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReadInt indicates an expected call of ReadInt. +func (mr *MockClientMockRecorder) ReadInt(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadInt", reflect.TypeOf((*MockClient)(nil).ReadInt), arg0) +} + +// ReadObjects mocks base method. +func (m *MockClient) ReadObjects(arg0 db.Rows, arg1 reflect.Type, arg2 bool) ([]any, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadObjects", arg0, arg1, arg2) + ret0, _ := ret[0].([]any) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReadObjects indicates an expected call of ReadObjects. +func (mr *MockClientMockRecorder) ReadObjects(arg0, arg1, arg2 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadObjects", reflect.TypeOf((*MockClient)(nil).ReadObjects), arg0, arg1, arg2) +} + +// ReadStrings mocks base method. +func (m *MockClient) ReadStrings(arg0 db.Rows) ([]string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadStrings", arg0) + ret0, _ := ret[0].([]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReadStrings indicates an expected call of ReadStrings. +func (mr *MockClientMockRecorder) ReadStrings(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadStrings", reflect.TypeOf((*MockClient)(nil).ReadStrings), arg0) +} + +// Upsert mocks base method. +func (m *MockClient) Upsert(arg0 transaction.Client, arg1 *sql.Stmt, arg2 string, arg3 any, arg4 bool) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Upsert", arg0, arg1, arg2, arg3, arg4) ret0, _ := ret[0].(error) return ret0 } -// Exec indicates an expected call of Exec. -func (mr *MockTXClientMockRecorder) Exec(arg0 any, arg1 ...any) *gomock.Call { +// Upsert indicates an expected call of Upsert. +func (mr *MockClientMockRecorder) Upsert(arg0, arg1, arg2, arg3, arg4 any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0}, arg1...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockTXClient)(nil).Exec), varargs...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Upsert", reflect.TypeOf((*MockClient)(nil).Upsert), arg0, arg1, arg2, arg3, arg4) } -// Stmt mocks base method. -func (m *MockTXClient) Stmt(arg0 *sql.Stmt) transaction.Stmt { +// WithTransaction mocks base method. +func (m *MockClient) WithTransaction(arg0 context.Context, arg1 bool, arg2 db.WithTransactionFunction) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Stmt", arg0) - ret0, _ := ret[0].(transaction.Stmt) - return ret0 -} - -// Stmt indicates an expected call of Stmt. -func (mr *MockTXClientMockRecorder) Stmt(arg0 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stmt", reflect.TypeOf((*MockTXClient)(nil).Stmt), arg0) -} - -// StmtExec mocks base method. -func (m *MockTXClient) StmtExec(arg0 transaction.Stmt, arg1 ...any) error { - m.ctrl.T.Helper() - varargs := []any{arg0} - for _, a := range arg1 { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "StmtExec", varargs...) + ret := m.ctrl.Call(m, "WithTransaction", arg0, arg1, arg2) ret0, _ := ret[0].(error) return ret0 } -// StmtExec indicates an expected call of StmtExec. -func (mr *MockTXClientMockRecorder) StmtExec(arg0 any, arg1 ...any) *gomock.Call { +// WithTransaction indicates an expected call of WithTransaction. +func (mr *MockClientMockRecorder) WithTransaction(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0}, arg1...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StmtExec", reflect.TypeOf((*MockTXClient)(nil).StmtExec), varargs...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WithTransaction", reflect.TypeOf((*MockClient)(nil).WithTransaction), arg0, arg1, arg2) } diff --git a/pkg/sqlcache/informer/factory/factory_mocks_test.go b/pkg/sqlcache/informer/factory/factory_mocks_test.go deleted file mode 100644 index a7adab6a..00000000 --- a/pkg/sqlcache/informer/factory/factory_mocks_test.go +++ /dev/null @@ -1,179 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/rancher/steve/pkg/sqlcache/informer/factory (interfaces: DBClient) -// -// Generated by this command: -// -// mockgen --build_flags=--mod=mod -package factory -destination ./factory_mocks_test.go github.com/rancher/steve/pkg/sqlcache/informer/factory DBClient -// - -// Package factory is a generated GoMock package. -package factory - -import ( - context "context" - sql "database/sql" - reflect "reflect" - - db "github.com/rancher/steve/pkg/sqlcache/db" - transaction "github.com/rancher/steve/pkg/sqlcache/db/transaction" - gomock "go.uber.org/mock/gomock" -) - -// MockDBClient is a mock of DBClient interface. -type MockDBClient struct { - ctrl *gomock.Controller - recorder *MockDBClientMockRecorder -} - -// MockDBClientMockRecorder is the mock recorder for MockDBClient. -type MockDBClientMockRecorder struct { - mock *MockDBClient -} - -// NewMockDBClient creates a new mock instance. -func NewMockDBClient(ctrl *gomock.Controller) *MockDBClient { - mock := &MockDBClient{ctrl: ctrl} - mock.recorder = &MockDBClientMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockDBClient) EXPECT() *MockDBClientMockRecorder { - return m.recorder -} - -// BeginTx mocks base method. -func (m *MockDBClient) BeginTx(arg0 context.Context, arg1 bool) (db.TXClient, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "BeginTx", arg0, arg1) - ret0, _ := ret[0].(db.TXClient) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// BeginTx indicates an expected call of BeginTx. -func (mr *MockDBClientMockRecorder) BeginTx(arg0, arg1 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BeginTx", reflect.TypeOf((*MockDBClient)(nil).BeginTx), arg0, arg1) -} - -// CloseStmt mocks base method. -func (m *MockDBClient) CloseStmt(arg0 db.Closable) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CloseStmt", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// CloseStmt indicates an expected call of CloseStmt. -func (mr *MockDBClientMockRecorder) CloseStmt(arg0 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseStmt", reflect.TypeOf((*MockDBClient)(nil).CloseStmt), arg0) -} - -// NewConnection mocks base method. -func (m *MockDBClient) NewConnection() error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "NewConnection") - ret0, _ := ret[0].(error) - return ret0 -} - -// NewConnection indicates an expected call of NewConnection. -func (mr *MockDBClientMockRecorder) NewConnection() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewConnection", reflect.TypeOf((*MockDBClient)(nil).NewConnection)) -} - -// Prepare mocks base method. -func (m *MockDBClient) Prepare(arg0 string) *sql.Stmt { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Prepare", arg0) - ret0, _ := ret[0].(*sql.Stmt) - return ret0 -} - -// Prepare indicates an expected call of Prepare. -func (mr *MockDBClientMockRecorder) Prepare(arg0 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Prepare", reflect.TypeOf((*MockDBClient)(nil).Prepare), arg0) -} - -// QueryForRows mocks base method. -func (m *MockDBClient) QueryForRows(arg0 context.Context, arg1 transaction.Stmt, arg2 ...any) (*sql.Rows, error) { - m.ctrl.T.Helper() - varargs := []any{arg0, arg1} - for _, a := range arg2 { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "QueryForRows", varargs...) - ret0, _ := ret[0].(*sql.Rows) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// QueryForRows indicates an expected call of QueryForRows. -func (mr *MockDBClientMockRecorder) QueryForRows(arg0, arg1 any, arg2 ...any) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0, arg1}, arg2...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryForRows", reflect.TypeOf((*MockDBClient)(nil).QueryForRows), varargs...) -} - -// ReadInt mocks base method. -func (m *MockDBClient) ReadInt(arg0 db.Rows) (int, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ReadInt", arg0) - ret0, _ := ret[0].(int) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// ReadInt indicates an expected call of ReadInt. -func (mr *MockDBClientMockRecorder) ReadInt(arg0 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadInt", reflect.TypeOf((*MockDBClient)(nil).ReadInt), arg0) -} - -// ReadObjects mocks base method. -func (m *MockDBClient) ReadObjects(arg0 db.Rows, arg1 reflect.Type, arg2 bool) ([]any, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ReadObjects", arg0, arg1, arg2) - ret0, _ := ret[0].([]any) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// ReadObjects indicates an expected call of ReadObjects. -func (mr *MockDBClientMockRecorder) ReadObjects(arg0, arg1, arg2 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadObjects", reflect.TypeOf((*MockDBClient)(nil).ReadObjects), arg0, arg1, arg2) -} - -// ReadStrings mocks base method. -func (m *MockDBClient) ReadStrings(arg0 db.Rows) ([]string, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ReadStrings", arg0) - ret0, _ := ret[0].([]string) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// ReadStrings indicates an expected call of ReadStrings. -func (mr *MockDBClientMockRecorder) ReadStrings(arg0 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadStrings", reflect.TypeOf((*MockDBClient)(nil).ReadStrings), arg0) -} - -// Upsert mocks base method. -func (m *MockDBClient) Upsert(arg0 db.TXClient, arg1 *sql.Stmt, arg2 string, arg3 any, arg4 bool) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Upsert", arg0, arg1, arg2, arg3, arg4) - ret0, _ := ret[0].(error) - return ret0 -} - -// Upsert indicates an expected call of Upsert. -func (mr *MockDBClientMockRecorder) Upsert(arg0, arg1, arg2, arg3, arg4 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Upsert", reflect.TypeOf((*MockDBClient)(nil).Upsert), arg0, arg1, arg2, arg3, arg4) -} diff --git a/pkg/sqlcache/informer/factory/informer_factory.go b/pkg/sqlcache/informer/factory/informer_factory.go index e2fd019f..71470200 100644 --- a/pkg/sqlcache/informer/factory/informer_factory.go +++ b/pkg/sqlcache/informer/factory/informer_factory.go @@ -13,7 +13,6 @@ import ( "github.com/rancher/steve/pkg/sqlcache/db" "github.com/rancher/steve/pkg/sqlcache/encryption" "github.com/rancher/steve/pkg/sqlcache/informer" - sqlStore "github.com/rancher/steve/pkg/sqlcache/store" "k8s.io/apimachinery/pkg/api/errors" "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/apimachinery/pkg/util/wait" @@ -28,7 +27,7 @@ const EncryptAllEnvVar = "CATTLE_ENCRYPT_CACHE_ALL" // CacheFactory builds Informer instances and keeps a cache of instances it created type CacheFactory struct { wg wait.Group - dbClient DBClient + dbClient db.Client stopCh chan struct{} mutex sync.RWMutex encryptAll bool @@ -44,22 +43,12 @@ type guardedInformer struct { mutex *sync.Mutex } -type newInformer func(client dynamic.ResourceInterface, fields [][]string, transform cache.TransformFunc, gvk schema.GroupVersionKind, db sqlStore.DBClient, shouldEncrypt bool, namespace bool) (*informer.Informer, error) - -type DBClient interface { - informer.DBClient - sqlStore.DBClient - connector -} +type newInformer func(client dynamic.ResourceInterface, fields [][]string, transform cache.TransformFunc, gvk schema.GroupVersionKind, db db.Client, shouldEncrypt bool, namespace bool) (*informer.Informer, error) type Cache struct { informer.ByOptionsLister } -type connector interface { - NewConnection() error -} - var defaultEncryptedResourceTypes = map[schema.GroupVersionKind]struct{}{ { Version: "v1", diff --git a/pkg/sqlcache/informer/factory/informer_factory_test.go b/pkg/sqlcache/informer/factory/informer_factory_test.go index 1daf0261..3abacae2 100644 --- a/pkg/sqlcache/informer/factory/informer_factory_test.go +++ b/pkg/sqlcache/informer/factory/informer_factory_test.go @@ -5,9 +5,9 @@ import ( "testing" "time" + "github.com/rancher/steve/pkg/sqlcache/db" "github.com/rancher/steve/pkg/sqlcache/informer" - sqlStore "github.com/rancher/steve/pkg/sqlcache/store" "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" "k8s.io/apimachinery/pkg/runtime/schema" @@ -15,8 +15,8 @@ import ( "k8s.io/client-go/tools/cache" ) -//go:generate mockgen --build_flags=--mod=mod -package factory -destination ./factory_mocks_test.go github.com/rancher/steve/pkg/sqlcache/informer/factory DBClient -//go:generate mockgen --build_flags=--mod=mod -package factory -destination ./db_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db TXClient +//go:generate mockgen --build_flags=--mod=mod -package factory -destination ./db_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db Client +//go:generate mockgen --build_flags=--mod=mod -package factory -destination ./transaction_mocks_tests.go -mock_names Client=MockTXClient github.com/rancher/steve/pkg/sqlcache/db/transaction Client //go:generate mockgen --build_flags=--mod=mod -package factory -destination ./dynamic_mocks_test.go k8s.io/client-go/dynamic ResourceInterface //go:generate mockgen --build_flags=--mod=mod -package factory -destination ./k8s_cache_mocks_test.go k8s.io/client-go/tools/cache SharedIndexInformer @@ -58,7 +58,7 @@ func TestCacheFor(t *testing.T) { var tests []testCase tests = append(tests, testCase{description: "CacheFor() with no errors returned, HasSync returning true, and stopCh not closed, should return no error and should call Informer.Run(). A subsequent call to CacheFor() should return same informer", test: func(t *testing.T) { - dbClient := NewMockDBClient(gomock.NewController(t)) + dbClient := NewMockClient(gomock.NewController(t)) dynamicClient := NewMockResourceInterface(gomock.NewController(t)) fields := [][]string{{"something"}} expectedGVK := schema.GroupVersionKind{} @@ -73,7 +73,7 @@ func TestCacheFor(t *testing.T) { expectedC := Cache{ ByOptionsLister: i, } - testNewInformer := func(client dynamic.ResourceInterface, fields [][]string, transform cache.TransformFunc, gvk schema.GroupVersionKind, db sqlStore.DBClient, shouldEncrypt bool, namespaced bool) (*informer.Informer, error) { + testNewInformer := func(client dynamic.ResourceInterface, fields [][]string, transform cache.TransformFunc, gvk schema.GroupVersionKind, db db.Client, shouldEncrypt bool, namespaced bool) (*informer.Informer, error) { assert.Equal(t, client, dynamicClient) assert.Equal(t, fields, fields) assert.Equal(t, expectedGVK, gvk) @@ -105,7 +105,7 @@ func TestCacheFor(t *testing.T) { assert.Equal(t, c, c2) }}) tests = append(tests, testCase{description: "CacheFor() with no errors returned, HasSync returning false, and stopCh not closed, should call Run() and return an error", test: func(t *testing.T) { - dbClient := NewMockDBClient(gomock.NewController(t)) + dbClient := NewMockClient(gomock.NewController(t)) dynamicClient := NewMockResourceInterface(gomock.NewController(t)) fields := [][]string{{"something"}} expectedGVK := schema.GroupVersionKind{} @@ -118,7 +118,7 @@ func TestCacheFor(t *testing.T) { // need to set this so Run function is not nil SharedIndexInformer: sii, } - testNewInformer := func(client dynamic.ResourceInterface, fields [][]string, transform cache.TransformFunc, gvk schema.GroupVersionKind, db sqlStore.DBClient, shouldEncrypt, namespaced bool) (*informer.Informer, error) { + testNewInformer := func(client dynamic.ResourceInterface, fields [][]string, transform cache.TransformFunc, gvk schema.GroupVersionKind, db db.Client, shouldEncrypt, namespaced bool) (*informer.Informer, error) { assert.Equal(t, client, dynamicClient) assert.Equal(t, fields, fields) assert.Equal(t, expectedGVK, gvk) @@ -143,7 +143,7 @@ func TestCacheFor(t *testing.T) { time.Sleep(2 * time.Second) }}) tests = append(tests, testCase{description: "CacheFor() with no errors returned, HasSync returning true, and stopCh closed, should not call Run() more than once and not return an error", test: func(t *testing.T) { - dbClient := NewMockDBClient(gomock.NewController(t)) + dbClient := NewMockClient(gomock.NewController(t)) dynamicClient := NewMockResourceInterface(gomock.NewController(t)) fields := [][]string{{"something"}} expectedGVK := schema.GroupVersionKind{} @@ -160,7 +160,7 @@ func TestCacheFor(t *testing.T) { expectedC := Cache{ ByOptionsLister: i, } - testNewInformer := func(client dynamic.ResourceInterface, fields [][]string, transform cache.TransformFunc, gvk schema.GroupVersionKind, db sqlStore.DBClient, shouldEncrypt, namespaced bool) (*informer.Informer, error) { + testNewInformer := func(client dynamic.ResourceInterface, fields [][]string, transform cache.TransformFunc, gvk schema.GroupVersionKind, db db.Client, shouldEncrypt, namespaced bool) (*informer.Informer, error) { assert.Equal(t, client, dynamicClient) assert.Equal(t, fields, fields) assert.Equal(t, expectedGVK, gvk) @@ -184,7 +184,7 @@ func TestCacheFor(t *testing.T) { time.Sleep(1 * time.Second) }}) tests = append(tests, testCase{description: "CacheFor() with no errors returned and encryptAll set to true, should return no error and pass shouldEncrypt as true to newInformer func", test: func(t *testing.T) { - dbClient := NewMockDBClient(gomock.NewController(t)) + dbClient := NewMockClient(gomock.NewController(t)) dynamicClient := NewMockResourceInterface(gomock.NewController(t)) fields := [][]string{{"something"}} expectedGVK := schema.GroupVersionKind{} @@ -199,7 +199,7 @@ func TestCacheFor(t *testing.T) { expectedC := Cache{ ByOptionsLister: i, } - testNewInformer := func(client dynamic.ResourceInterface, fields [][]string, transform cache.TransformFunc, gvk schema.GroupVersionKind, db sqlStore.DBClient, shouldEncrypt, namespaced bool) (*informer.Informer, error) { + testNewInformer := func(client dynamic.ResourceInterface, fields [][]string, transform cache.TransformFunc, gvk schema.GroupVersionKind, db db.Client, shouldEncrypt, namespaced bool) (*informer.Informer, error) { assert.Equal(t, client, dynamicClient) assert.Equal(t, fields, fields) assert.Equal(t, expectedGVK, gvk) @@ -228,7 +228,7 @@ func TestCacheFor(t *testing.T) { }}) tests = append(tests, testCase{description: "CacheFor() should encrypt v1 Secrets", test: func(t *testing.T) { - dbClient := NewMockDBClient(gomock.NewController(t)) + dbClient := NewMockClient(gomock.NewController(t)) dynamicClient := NewMockResourceInterface(gomock.NewController(t)) fields := [][]string{{"something"}} expectedGVK := schema.GroupVersionKind{ @@ -247,7 +247,7 @@ func TestCacheFor(t *testing.T) { expectedC := Cache{ ByOptionsLister: i, } - testNewInformer := func(client dynamic.ResourceInterface, fields [][]string, transform cache.TransformFunc, gvk schema.GroupVersionKind, db sqlStore.DBClient, shouldEncrypt, namespaced bool) (*informer.Informer, error) { + testNewInformer := func(client dynamic.ResourceInterface, fields [][]string, transform cache.TransformFunc, gvk schema.GroupVersionKind, db db.Client, shouldEncrypt, namespaced bool) (*informer.Informer, error) { assert.Equal(t, client, dynamicClient) assert.Equal(t, fields, fields) assert.Equal(t, expectedGVK, gvk) @@ -275,7 +275,7 @@ func TestCacheFor(t *testing.T) { time.Sleep(1 * time.Second) }}) tests = append(tests, testCase{description: "CacheFor() should encrypt management.cattle.io tokens", test: func(t *testing.T) { - dbClient := NewMockDBClient(gomock.NewController(t)) + dbClient := NewMockClient(gomock.NewController(t)) dynamicClient := NewMockResourceInterface(gomock.NewController(t)) fields := [][]string{{"something"}} expectedGVK := schema.GroupVersionKind{ @@ -294,7 +294,7 @@ func TestCacheFor(t *testing.T) { expectedC := Cache{ ByOptionsLister: i, } - testNewInformer := func(client dynamic.ResourceInterface, fields [][]string, transform cache.TransformFunc, gvk schema.GroupVersionKind, db sqlStore.DBClient, shouldEncrypt, namespaced bool) (*informer.Informer, error) { + testNewInformer := func(client dynamic.ResourceInterface, fields [][]string, transform cache.TransformFunc, gvk schema.GroupVersionKind, db db.Client, shouldEncrypt, namespaced bool) (*informer.Informer, error) { assert.Equal(t, client, dynamicClient) assert.Equal(t, fields, fields) assert.Equal(t, expectedGVK, gvk) @@ -323,7 +323,7 @@ func TestCacheFor(t *testing.T) { }}) tests = append(tests, testCase{description: "CacheFor() with no errors returned, HasSync returning true, stopCh not closed, and transform func should return no error", test: func(t *testing.T) { - dbClient := NewMockDBClient(gomock.NewController(t)) + dbClient := NewMockClient(gomock.NewController(t)) dynamicClient := NewMockResourceInterface(gomock.NewController(t)) fields := [][]string{{"something"}} expectedGVK := schema.GroupVersionKind{} @@ -341,7 +341,7 @@ func TestCacheFor(t *testing.T) { expectedC := Cache{ ByOptionsLister: i, } - testNewInformer := func(client dynamic.ResourceInterface, fields [][]string, transform cache.TransformFunc, gvk schema.GroupVersionKind, db sqlStore.DBClient, shouldEncrypt bool, namespaced bool) (*informer.Informer, error) { + testNewInformer := func(client dynamic.ResourceInterface, fields [][]string, transform cache.TransformFunc, gvk schema.GroupVersionKind, db db.Client, shouldEncrypt bool, namespaced bool) (*informer.Informer, error) { // we can't test func == func, so instead we check if the output was as expected input := "someinput" ouput, err := transform(input) diff --git a/pkg/sqlcache/informer/factory/transaction_mocks_tests.go b/pkg/sqlcache/informer/factory/transaction_mocks_tests.go new file mode 100644 index 00000000..9a2b677b --- /dev/null +++ b/pkg/sqlcache/informer/factory/transaction_mocks_tests.go @@ -0,0 +1,75 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/rancher/steve/pkg/sqlcache/db/transaction (interfaces: Client) +// +// Generated by this command: +// +// mockgen --build_flags=--mod=mod -package factory -destination ./transaction_mocks_tests.go -mock_names Client=MockTXClient github.com/rancher/steve/pkg/sqlcache/db/transaction Client +// + +// Package factory is a generated GoMock package. +package factory + +import ( + sql "database/sql" + reflect "reflect" + + transaction "github.com/rancher/steve/pkg/sqlcache/db/transaction" + gomock "go.uber.org/mock/gomock" +) + +// MockTXClient is a mock of Client interface. +type MockTXClient struct { + ctrl *gomock.Controller + recorder *MockTXClientMockRecorder +} + +// MockTXClientMockRecorder is the mock recorder for MockTXClient. +type MockTXClientMockRecorder struct { + mock *MockTXClient +} + +// NewMockTXClient creates a new mock instance. +func NewMockTXClient(ctrl *gomock.Controller) *MockTXClient { + mock := &MockTXClient{ctrl: ctrl} + mock.recorder = &MockTXClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockTXClient) EXPECT() *MockTXClientMockRecorder { + return m.recorder +} + +// Exec mocks base method. +func (m *MockTXClient) Exec(arg0 string, arg1 ...any) (sql.Result, error) { + m.ctrl.T.Helper() + varargs := []any{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Exec", varargs...) + ret0, _ := ret[0].(sql.Result) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Exec indicates an expected call of Exec. +func (mr *MockTXClientMockRecorder) Exec(arg0 any, arg1 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockTXClient)(nil).Exec), varargs...) +} + +// Stmt mocks base method. +func (m *MockTXClient) Stmt(arg0 *sql.Stmt) transaction.Stmt { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Stmt", arg0) + ret0, _ := ret[0].(transaction.Stmt) + return ret0 +} + +// Stmt indicates an expected call of Stmt. +func (mr *MockTXClientMockRecorder) Stmt(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stmt", reflect.TypeOf((*MockTXClient)(nil).Stmt), arg0) +} diff --git a/pkg/sqlcache/informer/indexer.go b/pkg/sqlcache/informer/indexer.go index 7ed4451b..210e1e2f 100644 --- a/pkg/sqlcache/informer/indexer.go +++ b/pkg/sqlcache/informer/indexer.go @@ -62,45 +62,34 @@ type Indexer struct { var _ cache.Indexer = (*Indexer)(nil) type Store interface { - DBClient + db.Client cache.Store GetByKey(key string) (item any, exists bool, err error) GetName() string - RegisterAfterUpsert(f func(key string, obj any, tx db.TXClient) error) - RegisterAfterDelete(f func(key string, tx db.TXClient) error) + RegisterAfterUpsert(f func(key string, obj any, tx transaction.Client) error) + RegisterAfterDelete(f func(key string, tx transaction.Client) error) GetShouldEncrypt() bool GetType() reflect.Type } -type DBClient interface { - BeginTx(ctx context.Context, forWriting bool) (db.TXClient, error) - QueryForRows(ctx context.Context, stmt transaction.Stmt, params ...any) (*sql.Rows, error) - ReadObjects(rows db.Rows, typ reflect.Type, shouldDecrypt bool) ([]any, error) - ReadStrings(rows db.Rows) ([]string, error) - ReadInt(rows db.Rows) (int, error) - Prepare(stmt string) *sql.Stmt - CloseStmt(stmt db.Closable) error -} - // NewIndexer returns a cache.Indexer backed by SQLite for objects of the given example type func NewIndexer(indexers cache.Indexers, s Store) (*Indexer, error) { - tx, err := s.BeginTx(context.Background(), true) - if err != nil { - return nil, err - } dbName := db.Sanitize(s.GetName()) - createTableQuery := fmt.Sprintf(createTableFmt, dbName) - err = tx.Exec(createTableQuery) - if err != nil { - return nil, &db.QueryError{QueryString: createTableQuery, Err: err} - } - createIndexQuery := fmt.Sprintf(createIndexFmt, dbName) - err = tx.Exec(createIndexQuery) - if err != nil { - return nil, &db.QueryError{QueryString: createIndexQuery, Err: err} - } - err = tx.Commit() + + err := s.WithTransaction(context.Background(), true, func(tx transaction.Client) error { + createTableQuery := fmt.Sprintf(createTableFmt, dbName) + _, err := tx.Exec(createTableQuery) + if err != nil { + return &db.QueryError{QueryString: createTableQuery, Err: err} + } + createIndexQuery := fmt.Sprintf(createIndexFmt, dbName) + _, err = tx.Exec(createIndexQuery) + if err != nil { + return &db.QueryError{QueryString: createIndexQuery, Err: err} + } + return nil + }) if err != nil { return nil, err } @@ -129,9 +118,9 @@ func NewIndexer(indexers cache.Indexers, s Store) (*Indexer, error) { /* Core methods */ // AfterUpsert updates indices of an object -func (i *Indexer) AfterUpsert(key string, obj any, tx db.TXClient) error { +func (i *Indexer) AfterUpsert(key string, obj any, tx transaction.Client) error { // delete all - err := tx.StmtExec(tx.Stmt(i.deleteIndicesStmt), key) + _, err := tx.Stmt(i.deleteIndicesStmt).Exec(key) if err != nil { return &db.QueryError{QueryString: i.deleteIndicesQuery, Err: err} } @@ -146,7 +135,7 @@ func (i *Indexer) AfterUpsert(key string, obj any, tx db.TXClient) error { } for _, value := range values { - err = tx.StmtExec(tx.Stmt(i.addIndexStmt), indexName, value, key) + _, err = tx.Stmt(i.addIndexStmt).Exec(indexName, value, key) if err != nil { return &db.QueryError{QueryString: i.addIndexQuery, Err: err} } diff --git a/pkg/sqlcache/informer/indexer_test.go b/pkg/sqlcache/informer/indexer_test.go index 4118118c..4c10b5de 100644 --- a/pkg/sqlcache/informer/indexer_test.go +++ b/pkg/sqlcache/informer/indexer_test.go @@ -13,13 +13,15 @@ import ( "reflect" "testing" + "github.com/rancher/steve/pkg/sqlcache/db" "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" "k8s.io/client-go/tools/cache" ) //go:generate mockgen --build_flags=--mod=mod -package informer -destination ./sql_mocks_test.go github.com/rancher/steve/pkg/sqlcache/informer Store -//go:generate mockgen --build_flags=--mod=mod -package informer -destination ./db_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db TXClient,Rows +//go:generate mockgen --build_flags=--mod=mod -package informer -destination ./db_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db Rows,Client +//go:generate mockgen --build_flags=--mod=mod -package informer -destination ./transaction_mocks_test.go -mock_names Client=MockTXClient github.com/rancher/steve/pkg/sqlcache/db/transaction Stmt,Client type testStoreObject struct { Id string @@ -34,7 +36,7 @@ func TestNewIndexer(t *testing.T) { var tests []testCase - tests = append(tests, testCase{description: "NewIndexer() with no errors returned from Store or TXClient, should return no error", test: func(t *testing.T) { + tests = append(tests, testCase{description: "NewIndexer() with no errors returned from Store or Client, should return no error", test: func(t *testing.T) { store := NewMockStore(gomock.NewController(t)) client := NewMockTXClient(gomock.NewController(t)) @@ -45,11 +47,17 @@ func TestNewIndexer(t *testing.T) { }, } storeName := "someStoreName" - store.EXPECT().BeginTx(gomock.Any(), true).Return(client, nil) + store.EXPECT().GetName().AnyTimes().Return(storeName) - client.EXPECT().Exec(fmt.Sprintf(createTableFmt, storeName, storeName)).Return(nil) - client.EXPECT().Exec(fmt.Sprintf(createIndexFmt, storeName, storeName)).Return(nil) - client.EXPECT().Commit().Return(nil) + client.EXPECT().Exec(fmt.Sprintf(createTableFmt, storeName, storeName)).Return(nil, nil) + client.EXPECT().Exec(fmt.Sprintf(createIndexFmt, storeName, storeName)).Return(nil, nil) + store.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(nil).Do( + func(ctx context.Context, shouldEncrypt bool, f db.WithTransactionFunction) { + err := f(client) + if err != nil { + t.Fail() + } + }) store.EXPECT().RegisterAfterUpsert(gomock.Any()) store.EXPECT().Prepare(fmt.Sprintf(deleteIndicesFmt, storeName)) store.EXPECT().Prepare(fmt.Sprintf(addIndexFmt, storeName)) @@ -60,7 +68,7 @@ func TestNewIndexer(t *testing.T) { assert.Nil(t, err) assert.Equal(t, cache.Indexers(indexers), indexer.indexers) }}) - tests = append(tests, testCase{description: "NewIndexer() with Store Begin() error, should return error", test: func(t *testing.T) { + tests = append(tests, testCase{description: "NewIndexer() with WithTransaction() error, should return error", test: func(t *testing.T) { store := NewMockStore(gomock.NewController(t)) objKey := "objKey" @@ -69,11 +77,12 @@ func TestNewIndexer(t *testing.T) { return []string{objKey}, nil }, } - store.EXPECT().BeginTx(gomock.Any(), true).Return(nil, fmt.Errorf("error")) + store.EXPECT().GetName().AnyTimes().Return("someStoreName") + store.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(fmt.Errorf("error")) _, err := NewIndexer(indexers, store) assert.NotNil(t, err) }}) - tests = append(tests, testCase{description: "NewIndexer() with TXClient Exec() error on first call to Exec(), should return error", test: func(t *testing.T) { + tests = append(tests, testCase{description: "NewIndexer() with Client Exec() error on first call to Exec(), should return error", test: func(t *testing.T) { store := NewMockStore(gomock.NewController(t)) client := NewMockTXClient(gomock.NewController(t)) @@ -84,13 +93,20 @@ func TestNewIndexer(t *testing.T) { }, } storeName := "someStoreName" - store.EXPECT().BeginTx(gomock.Any(), true).Return(client, nil) store.EXPECT().GetName().AnyTimes().Return(storeName) - client.EXPECT().Exec(fmt.Sprintf(createTableFmt, storeName, storeName)).Return(fmt.Errorf("error")) + client.EXPECT().Exec(fmt.Sprintf(createTableFmt, storeName, storeName)).Return(nil, fmt.Errorf("error")) + + store.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(fmt.Errorf("error")).Do( + func(ctx context.Context, shouldEncrypt bool, f db.WithTransactionFunction) { + err := f(client) + if err == nil { + t.Fail() + } + }) _, err := NewIndexer(indexers, store) assert.NotNil(t, err) }}) - tests = append(tests, testCase{description: "NewIndexer() with TXClient Exec() error on second call to Exec(), should return error", test: func(t *testing.T) { + tests = append(tests, testCase{description: "NewIndexer() with Client Exec() error on second call to Exec(), should return error", test: func(t *testing.T) { store := NewMockStore(gomock.NewController(t)) client := NewMockTXClient(gomock.NewController(t)) @@ -101,14 +117,22 @@ func TestNewIndexer(t *testing.T) { }, } storeName := "someStoreName" - store.EXPECT().BeginTx(gomock.Any(), true).Return(client, nil) store.EXPECT().GetName().AnyTimes().Return(storeName) - client.EXPECT().Exec(fmt.Sprintf(createTableFmt, storeName, storeName)).Return(nil) - client.EXPECT().Exec(fmt.Sprintf(createIndexFmt, storeName, storeName)).Return(fmt.Errorf("error")) + client.EXPECT().Exec(fmt.Sprintf(createTableFmt, storeName, storeName)).Return(nil, nil) + client.EXPECT().Exec(fmt.Sprintf(createIndexFmt, storeName, storeName)).Return(nil, fmt.Errorf("error")) + + store.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(fmt.Errorf("error")).Do( + func(ctx context.Context, shouldEncrypt bool, f db.WithTransactionFunction) { + err := f(client) + if err == nil { + t.Fail() + } + }) + _, err := NewIndexer(indexers, store) assert.NotNil(t, err) }}) - tests = append(tests, testCase{description: "NewIndexer() with TXClient Commit() error, should return error", test: func(t *testing.T) { + tests = append(tests, testCase{description: "NewIndexer() with Client Commit() error, should return error", test: func(t *testing.T) { store := NewMockStore(gomock.NewController(t)) client := NewMockTXClient(gomock.NewController(t)) @@ -119,11 +143,16 @@ func TestNewIndexer(t *testing.T) { }, } storeName := "someStoreName" - store.EXPECT().BeginTx(gomock.Any(), true).Return(client, nil) store.EXPECT().GetName().AnyTimes().Return(storeName) - client.EXPECT().Exec(fmt.Sprintf(createTableFmt, storeName, storeName)).Return(nil) - client.EXPECT().Exec(fmt.Sprintf(createIndexFmt, storeName, storeName)).Return(nil) - client.EXPECT().Commit().Return(fmt.Errorf("error")) + client.EXPECT().Exec(fmt.Sprintf(createTableFmt, storeName, storeName)).Return(nil, nil) + client.EXPECT().Exec(fmt.Sprintf(createIndexFmt, storeName, storeName)).Return(nil, nil) + store.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(fmt.Errorf("error")).Do( + func(ctx context.Context, shouldEncrypt bool, f db.WithTransactionFunction) { + err := f(client) + if err != nil { + t.Fail() + } + }) _, err := NewIndexer(indexers, store) assert.NotNil(t, err) }}) @@ -141,16 +170,14 @@ func TestAfterUpsert(t *testing.T) { var tests []testCase - tests = append(tests, testCase{description: "AfterUpsert() with no errors returned from TXClient should return no error", test: func(t *testing.T) { + tests = append(tests, testCase{description: "AfterUpsert() with no errors returned from Client should return no error", test: func(t *testing.T) { store := NewMockStore(gomock.NewController(t)) client := NewMockTXClient(gomock.NewController(t)) - deleteStmt := &sql.Stmt{} - addStmt := &sql.Stmt{} objKey := "key" + deleteIndicesStmt := NewMockStmt(gomock.NewController(t)) + addIndexStmt := NewMockStmt(gomock.NewController(t)) indexer := &Indexer{ - Store: store, - deleteIndicesStmt: deleteStmt, - addIndexStmt: addStmt, + Store: store, indexers: map[string]cache.IndexFunc{ "a": func(obj interface{}) ([]string, error) { return []string{objKey}, nil @@ -158,24 +185,22 @@ func TestAfterUpsert(t *testing.T) { }, } key := "somekey" - client.EXPECT().Stmt(indexer.deleteIndicesStmt).Return(indexer.deleteIndicesStmt) - client.EXPECT().StmtExec(indexer.deleteIndicesStmt, key).Return(nil) - client.EXPECT().Stmt(indexer.addIndexStmt).Return(indexer.addIndexStmt) - client.EXPECT().StmtExec(indexer.addIndexStmt, "a", objKey, key).Return(nil) + client.EXPECT().Stmt(indexer.deleteIndicesStmt).Return(deleteIndicesStmt) + deleteIndicesStmt.EXPECT().Exec(key).Return(nil, nil) + client.EXPECT().Stmt(indexer.addIndexStmt).Return(addIndexStmt) + addIndexStmt.EXPECT().Exec("a", objKey, key).Return(nil, nil) testObject := testStoreObject{Id: "something", Val: "a"} err := indexer.AfterUpsert(key, testObject, client) assert.Nil(t, err) }}) - tests = append(tests, testCase{description: "AfterUpsert() with error returned from TXClient StmtExec() should return an error", test: func(t *testing.T) { + tests = append(tests, testCase{description: "AfterUpsert() with error returned from Client StmtExec() should return an error", test: func(t *testing.T) { store := NewMockStore(gomock.NewController(t)) client := NewMockTXClient(gomock.NewController(t)) - deleteStmt := &sql.Stmt{} - addStmt := &sql.Stmt{} objKey := "key" + deleteIndicesStmt := NewMockStmt(gomock.NewController(t)) indexer := &Indexer{ - Store: store, - deleteIndicesStmt: deleteStmt, - addIndexStmt: addStmt, + Store: store, + indexers: map[string]cache.IndexFunc{ "a": func(obj interface{}) ([]string, error) { return []string{objKey}, nil @@ -183,22 +208,20 @@ func TestAfterUpsert(t *testing.T) { }, } key := "somekey" - client.EXPECT().Stmt(indexer.deleteIndicesStmt).Return(indexer.deleteIndicesStmt) - client.EXPECT().StmtExec(indexer.deleteIndicesStmt, key).Return(fmt.Errorf("error")) + client.EXPECT().Stmt(indexer.deleteIndicesStmt).Return(deleteIndicesStmt) + deleteIndicesStmt.EXPECT().Exec(key).Return(nil, fmt.Errorf("error")) testObject := testStoreObject{Id: "something", Val: "a"} err := indexer.AfterUpsert(key, testObject, client) assert.NotNil(t, err) }}) - tests = append(tests, testCase{description: "AfterUpsert() with error returned from TXClient second StmtExec() call should return an error", test: func(t *testing.T) { + tests = append(tests, testCase{description: "AfterUpsert() with error returned from Client second StmtExec() call should return an error", test: func(t *testing.T) { store := NewMockStore(gomock.NewController(t)) client := NewMockTXClient(gomock.NewController(t)) - deleteStmt := &sql.Stmt{} - addStmt := &sql.Stmt{} + deleteIndicesStmt := NewMockStmt(gomock.NewController(t)) + addIndexStmt := NewMockStmt(gomock.NewController(t)) objKey := "key" indexer := &Indexer{ - Store: store, - deleteIndicesStmt: deleteStmt, - addIndexStmt: addStmt, + Store: store, indexers: map[string]cache.IndexFunc{ "a": func(obj interface{}) ([]string, error) { return []string{objKey}, nil @@ -206,10 +229,10 @@ func TestAfterUpsert(t *testing.T) { }, } key := "somekey" - client.EXPECT().Stmt(indexer.deleteIndicesStmt).Return(indexer.deleteIndicesStmt) - client.EXPECT().StmtExec(indexer.deleteIndicesStmt, key).Return(nil) - client.EXPECT().Stmt(indexer.addIndexStmt).Return(indexer.addIndexStmt) - client.EXPECT().StmtExec(indexer.addIndexStmt, "a", objKey, key).Return(fmt.Errorf("error")) + client.EXPECT().Stmt(indexer.deleteIndicesStmt).Return(deleteIndicesStmt) + deleteIndicesStmt.EXPECT().Exec(key).Return(nil, nil) + client.EXPECT().Stmt(indexer.addIndexStmt).Return(addIndexStmt) + addIndexStmt.EXPECT().Exec("a", objKey, key).Return(nil, fmt.Errorf("error")) testObject := testStoreObject{Id: "something", Val: "a"} err := indexer.AfterUpsert(key, testObject, client) assert.NotNil(t, err) diff --git a/pkg/sqlcache/informer/informer.go b/pkg/sqlcache/informer/informer.go index 8cf0355f..a54d9980 100644 --- a/pkg/sqlcache/informer/informer.go +++ b/pkg/sqlcache/informer/informer.go @@ -8,6 +8,7 @@ import ( "context" "time" + "github.com/rancher/steve/pkg/sqlcache/db" "github.com/rancher/steve/pkg/sqlcache/partition" sqlStore "github.com/rancher/steve/pkg/sqlcache/store" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -34,7 +35,7 @@ var newInformer = cache.NewSharedIndexInformer // NewInformer returns a new SQLite-backed Informer for the type specified by schema in unstructured.Unstructured form // using the specified client -func NewInformer(client dynamic.ResourceInterface, fields [][]string, transform cache.TransformFunc, gvk schema.GroupVersionKind, db sqlStore.DBClient, shouldEncrypt bool, namespaced bool) (*Informer, error) { +func NewInformer(client dynamic.ResourceInterface, fields [][]string, transform cache.TransformFunc, gvk schema.GroupVersionKind, db db.Client, shouldEncrypt bool, namespaced bool) (*Informer, error) { listWatcher := &cache.ListWatch{ ListFunc: func(options metav1.ListOptions) (runtime.Object, error) { a, err := client.List(context.Background(), options) diff --git a/pkg/sqlcache/informer/informer_test.go b/pkg/sqlcache/informer/informer_test.go index 5337ee8e..60509f4d 100644 --- a/pkg/sqlcache/informer/informer_test.go +++ b/pkg/sqlcache/informer/informer_test.go @@ -7,6 +7,7 @@ import ( "testing" "time" + "github.com/rancher/steve/pkg/sqlcache/db" "github.com/rancher/steve/pkg/sqlcache/partition" "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" @@ -20,7 +21,6 @@ import ( //go:generate mockgen --build_flags=--mod=mod -package informer -destination ./informer_mocks_test.go github.com/rancher/steve/pkg/sqlcache/informer ByOptionsLister //go:generate mockgen --build_flags=--mod=mod -package informer -destination ./dynamic_mocks_test.go k8s.io/client-go/dynamic ResourceInterface -//go:generate mockgen --build_flags=--mod=mod -package informer -destination ./store_mocks_test.go github.com/rancher/steve/pkg/sqlcache/store DBClient func TestNewInformer(t *testing.T) { type testCase struct { @@ -31,7 +31,7 @@ func TestNewInformer(t *testing.T) { var tests []testCase tests = append(tests, testCase{description: "NewInformer() with no errors returned, should return no error", test: func(t *testing.T) { - dbClient := NewMockDBClient(gomock.NewController(t)) + dbClient := NewMockClient(gomock.NewController(t)) txClient := NewMockTXClient(gomock.NewController(t)) dynamicClient := NewMockResourceInterface(gomock.NewController(t)) @@ -40,29 +40,44 @@ func TestNewInformer(t *testing.T) { // NewStore() from store package logic. This package is only concerned with whether it returns err or not as NewStore // is tested in depth in its own package. - dbClient.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) - txClient.EXPECT().Exec(gomock.Any()).Return(nil) - txClient.EXPECT().Exec(gomock.Any()).Return(nil) - txClient.EXPECT().Commit().Return(nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil, nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil, nil) + dbClient.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(nil).Do( + func(ctx context.Context, shouldEncrypt bool, f db.WithTransactionFunction) { + err := f(txClient) + if err != nil { + t.Fail() + } + }) dbClient.EXPECT().Prepare(gomock.Any()).Return(&sql.Stmt{}).AnyTimes() // NewIndexer() logic (within NewListOptionIndexer(). This test is only concerned with whether it returns err or not as NewIndexer // is tested in depth in its own indexer_test.go - dbClient.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) - txClient.EXPECT().Exec(gomock.Any()).Return(nil) - txClient.EXPECT().Exec(gomock.Any()).Return(nil) - txClient.EXPECT().Commit().Return(nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil, nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil, nil) + dbClient.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(nil).Do( + func(ctx context.Context, shouldEncrypt bool, f db.WithTransactionFunction) { + err := f(txClient) + if err != nil { + t.Fail() + } + }) // NewListOptionIndexer() logic. This test is only concerned with whether it returns err or not as NewIndexer // is tested in depth in its own indexer_test.go - dbClient.EXPECT().BeginTx(context.Background(), true).Return(txClient, nil) - txClient.EXPECT().Exec(gomock.Any()).Return(nil) - txClient.EXPECT().Exec(gomock.Any()).Return(nil) - txClient.EXPECT().Exec(gomock.Any()).Return(nil) - txClient.EXPECT().Exec(gomock.Any()).Return(nil) - txClient.EXPECT().Exec(gomock.Any()).Return(nil) - txClient.EXPECT().Exec(gomock.Any()).Return(nil) - txClient.EXPECT().Commit().Return(nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil, nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil, nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil, nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil, nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil, nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil, nil) + dbClient.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(nil).Do( + func(ctx context.Context, shouldEncrypt bool, f db.WithTransactionFunction) { + err := f(txClient) + if err != nil { + t.Fail() + } + }) informer, err := NewInformer(dynamicClient, fields, nil, gvk, dbClient, false, true) assert.Nil(t, err) @@ -70,7 +85,7 @@ func TestNewInformer(t *testing.T) { assert.NotNil(t, informer.SharedIndexInformer) }}) tests = append(tests, testCase{description: "NewInformer() with errors returned from NewStore(), should return an error", test: func(t *testing.T) { - dbClient := NewMockDBClient(gomock.NewController(t)) + dbClient := NewMockClient(gomock.NewController(t)) txClient := NewMockTXClient(gomock.NewController(t)) dynamicClient := NewMockResourceInterface(gomock.NewController(t)) @@ -79,15 +94,20 @@ func TestNewInformer(t *testing.T) { // NewStore() from store package logic. This package is only concerned with whether it returns err or not as NewStore // is tested in depth in its own package. - dbClient.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) - txClient.EXPECT().Exec(gomock.Any()).Return(nil) - txClient.EXPECT().Commit().Return(fmt.Errorf("error")) + txClient.EXPECT().Exec(gomock.Any()).Return(nil, nil) + dbClient.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(fmt.Errorf("error")).Do( + func(ctx context.Context, shouldEncrypt bool, f db.WithTransactionFunction) { + err := f(txClient) + if err != nil { + t.Fail() + } + }) _, err := NewInformer(dynamicClient, fields, nil, gvk, dbClient, false, true) assert.NotNil(t, err) }}) tests = append(tests, testCase{description: "NewInformer() with errors returned from NewIndexer(), should return an error", test: func(t *testing.T) { - dbClient := NewMockDBClient(gomock.NewController(t)) + dbClient := NewMockClient(gomock.NewController(t)) txClient := NewMockTXClient(gomock.NewController(t)) dynamicClient := NewMockResourceInterface(gomock.NewController(t)) @@ -96,23 +116,33 @@ func TestNewInformer(t *testing.T) { // NewStore() from store package logic. This package is only concerned with whether it returns err or not as NewStore // is tested in depth in its own package. - dbClient.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) - txClient.EXPECT().Exec(gomock.Any()).Return(nil) - txClient.EXPECT().Exec(gomock.Any()).Return(nil) - txClient.EXPECT().Commit().Return(nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil, nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil, nil) + dbClient.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(nil).Do( + func(ctx context.Context, shouldEncrypt bool, f db.WithTransactionFunction) { + err := f(txClient) + if err != nil { + t.Fail() + } + }) dbClient.EXPECT().Prepare(gomock.Any()).Return(&sql.Stmt{}).AnyTimes() // NewIndexer() logic (within NewListOptionIndexer(). This test is only concerned with whether it returns err or not as NewIndexer // is tested in depth in its own indexer_test.go - dbClient.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) - txClient.EXPECT().Exec(gomock.Any()).Return(nil) - txClient.EXPECT().Commit().Return(fmt.Errorf("error")) + txClient.EXPECT().Exec(gomock.Any()).Return(nil, nil) + dbClient.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(fmt.Errorf("error")).Do( + func(ctx context.Context, shouldEncrypt bool, f db.WithTransactionFunction) { + err := f(txClient) + if err != nil { + t.Fail() + } + }) _, err := NewInformer(dynamicClient, fields, nil, gvk, dbClient, false, true) assert.NotNil(t, err) }}) tests = append(tests, testCase{description: "NewInformer() with errors returned from NewListOptionIndexer(), should return an error", test: func(t *testing.T) { - dbClient := NewMockDBClient(gomock.NewController(t)) + dbClient := NewMockClient(gomock.NewController(t)) txClient := NewMockTXClient(gomock.NewController(t)) dynamicClient := NewMockResourceInterface(gomock.NewController(t)) @@ -121,35 +151,50 @@ func TestNewInformer(t *testing.T) { // NewStore() from store package logic. This package is only concerned with whether it returns err or not as NewStore // is tested in depth in its own package. - dbClient.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) - txClient.EXPECT().Exec(gomock.Any()).Return(nil) - txClient.EXPECT().Exec(gomock.Any()).Return(nil) - txClient.EXPECT().Commit().Return(nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil, nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil, nil) + dbClient.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(nil).Do( + func(ctx context.Context, shouldEncrypt bool, f db.WithTransactionFunction) { + err := f(txClient) + if err != nil { + t.Fail() + } + }) dbClient.EXPECT().Prepare(gomock.Any()).Return(&sql.Stmt{}).AnyTimes() // NewIndexer() logic (within NewListOptionIndexer(). This test is only concerned with whether it returns err or not as NewIndexer // is tested in depth in its own indexer_test.go - dbClient.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) - txClient.EXPECT().Exec(gomock.Any()).Return(nil) - txClient.EXPECT().Exec(gomock.Any()).Return(nil) - txClient.EXPECT().Commit().Return(nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil, nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil, nil) + dbClient.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(nil).Do( + func(ctx context.Context, shouldEncrypt bool, f db.WithTransactionFunction) { + err := f(txClient) + if err != nil { + t.Fail() + } + }) // NewListOptionIndexer() logic. This test is only concerned with whether it returns err or not as NewIndexer // is tested in depth in its own indexer_test.go - dbClient.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) - txClient.EXPECT().Exec(gomock.Any()).Return(nil) - txClient.EXPECT().Exec(gomock.Any()).Return(nil) - txClient.EXPECT().Exec(gomock.Any()).Return(nil) - txClient.EXPECT().Exec(gomock.Any()).Return(nil) - txClient.EXPECT().Exec(gomock.Any()).Return(nil) - txClient.EXPECT().Exec(gomock.Any()).Return(nil) - txClient.EXPECT().Commit().Return(fmt.Errorf("error")) + txClient.EXPECT().Exec(gomock.Any()).Return(nil, nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil, nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil, nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil, nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil, nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil, nil) + dbClient.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(fmt.Errorf("error")).Do( + func(ctx context.Context, shouldEncrypt bool, f db.WithTransactionFunction) { + err := f(txClient) + if err != nil { + t.Fail() + } + }) _, err := NewInformer(dynamicClient, fields, nil, gvk, dbClient, false, true) assert.NotNil(t, err) }}) tests = append(tests, testCase{description: "NewInformer() with transform func", test: func(t *testing.T) { - dbClient := NewMockDBClient(gomock.NewController(t)) + dbClient := NewMockClient(gomock.NewController(t)) txClient := NewMockTXClient(gomock.NewController(t)) dynamicClient := NewMockResourceInterface(gomock.NewController(t)) mockInformer := mockInformer{} @@ -166,29 +211,44 @@ func TestNewInformer(t *testing.T) { // NewStore() from store package logic. This package is only concerned with whether it returns err or not as NewStore // is tested in depth in its own package. - dbClient.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) - txClient.EXPECT().Exec(gomock.Any()).Return(nil) - txClient.EXPECT().Exec(gomock.Any()).Return(nil) - txClient.EXPECT().Commit().Return(nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil, nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil, nil) + dbClient.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(nil).Do( + func(ctx context.Context, shouldEncrypt bool, f db.WithTransactionFunction) { + err := f(txClient) + if err != nil { + t.Fail() + } + }) dbClient.EXPECT().Prepare(gomock.Any()).Return(&sql.Stmt{}).AnyTimes() // NewIndexer() logic (within NewListOptionIndexer(). This test is only concerned with whether it returns err or not as NewIndexer // is tested in depth in its own indexer_test.go - dbClient.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) - txClient.EXPECT().Exec(gomock.Any()).Return(nil) - txClient.EXPECT().Exec(gomock.Any()).Return(nil) - txClient.EXPECT().Commit().Return(nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil, nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil, nil) + dbClient.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(nil).Do( + func(ctx context.Context, shouldEncrypt bool, f db.WithTransactionFunction) { + err := f(txClient) + if err != nil { + t.Fail() + } + }) // NewListOptionIndexer() logic. This test is only concerned with whether it returns err or not as NewIndexer // is tested in depth in its own indexer_test.go - dbClient.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) - txClient.EXPECT().Exec(gomock.Any()).Return(nil) - txClient.EXPECT().Exec(gomock.Any()).Return(nil) - txClient.EXPECT().Exec(gomock.Any()).Return(nil) - txClient.EXPECT().Exec(gomock.Any()).Return(nil) - txClient.EXPECT().Exec(gomock.Any()).Return(nil) - txClient.EXPECT().Exec(gomock.Any()).Return(nil) - txClient.EXPECT().Commit().Return(nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil, nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil, nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil, nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil, nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil, nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil, nil) + dbClient.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(nil).Do( + func(ctx context.Context, shouldEncrypt bool, f db.WithTransactionFunction) { + err := f(txClient) + if err != nil { + t.Fail() + } + }) transformFunc := func(input interface{}) (interface{}, error) { return "someoutput", nil @@ -210,7 +270,7 @@ func TestNewInformer(t *testing.T) { newInformer = cache.NewSharedIndexInformer }}) tests = append(tests, testCase{description: "NewInformer() unable to set transform func", test: func(t *testing.T) { - dbClient := NewMockDBClient(gomock.NewController(t)) + dbClient := NewMockClient(gomock.NewController(t)) dynamicClient := NewMockResourceInterface(gomock.NewController(t)) mockInformer := mockInformer{ setTranformErr: fmt.Errorf("some error"), diff --git a/pkg/sqlcache/informer/listoption_indexer.go b/pkg/sqlcache/informer/listoption_indexer.go index 93ce5f08..0131efd4 100644 --- a/pkg/sqlcache/informer/listoption_indexer.go +++ b/pkg/sqlcache/informer/listoption_indexer.go @@ -11,6 +11,7 @@ import ( "strconv" "strings" + "github.com/rancher/steve/pkg/sqlcache/db/transaction" "github.com/sirupsen/logrus" "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" "k8s.io/client-go/tools/cache" @@ -108,51 +109,49 @@ func NewListOptionIndexer(fields [][]string, s Store, namespaced bool) (*ListOpt columnDefs[index] = column } - tx, err := l.BeginTx(context.Background(), true) - if err != nil { - return nil, err - } dbName := db.Sanitize(i.GetName()) - err = tx.Exec(fmt.Sprintf(createFieldsTableFmt, dbName, strings.Join(columnDefs, ", "))) - if err != nil { - return nil, err - } - columns := make([]string, len(indexedFields)) qmarks := make([]string, len(indexedFields)) setStatements := make([]string, len(indexedFields)) - for index, field := range indexedFields { - // create index for field - err = tx.Exec(fmt.Sprintf(createFieldsIndexFmt, dbName, field, dbName, field)) + err = l.WithTransaction(context.Background(), true, func(tx transaction.Client) error { + _, err = tx.Exec(fmt.Sprintf(createFieldsTableFmt, dbName, strings.Join(columnDefs, ", "))) if err != nil { - return nil, err + return err } - // format field into column for prepared statement - column := fmt.Sprintf(`"%s"`, field) - columns[index] = column + for index, field := range indexedFields { + // create index for field + _, err = tx.Exec(fmt.Sprintf(createFieldsIndexFmt, dbName, field, dbName, field)) + if err != nil { + return err + } - // add placeholder for column's value in prepared statement - qmarks[index] = "?" + // format field into column for prepared statement + column := fmt.Sprintf(`"%s"`, field) + columns[index] = column - // add formatted set statement for prepared statement - setStatement := fmt.Sprintf(`"%s" = excluded."%s"`, field, field) - setStatements[index] = setStatement - } - createLabelsTableQuery := fmt.Sprintf(createLabelsTableFmt, dbName, dbName) - err = tx.Exec(createLabelsTableQuery) - if err != nil { - return nil, &db.QueryError{QueryString: createLabelsTableQuery, Err: err} - } + // add placeholder for column's value in prepared statement + qmarks[index] = "?" - createLabelsTableIndexQuery := fmt.Sprintf(createLabelsTableIndexFmt, dbName, dbName) - err = tx.Exec(createLabelsTableIndexQuery) - if err != nil { - return nil, &db.QueryError{QueryString: createLabelsTableIndexQuery, Err: err} - } + // add formatted set statement for prepared statement + setStatement := fmt.Sprintf(`"%s" = excluded."%s"`, field, field) + setStatements[index] = setStatement + } + createLabelsTableQuery := fmt.Sprintf(createLabelsTableFmt, dbName, dbName) + _, err = tx.Exec(createLabelsTableQuery) + if err != nil { + return &db.QueryError{QueryString: createLabelsTableQuery, Err: err} + } - err = tx.Commit() + createLabelsTableIndexQuery := fmt.Sprintf(createLabelsTableIndexFmt, dbName, dbName) + _, err = tx.Exec(createLabelsTableIndexQuery) + if err != nil { + return &db.QueryError{QueryString: createLabelsTableIndexQuery, Err: err} + } + + return nil + }) if err != nil { return nil, err } @@ -180,16 +179,12 @@ func NewListOptionIndexer(fields [][]string, s Store, namespaced bool) (*ListOpt /* Core methods */ // addIndexFields saves sortable/filterable fields into tables -func (l *ListOptionIndexer) addIndexFields(key string, obj any, tx db.TXClient) error { +func (l *ListOptionIndexer) addIndexFields(key string, obj any, tx transaction.Client) error { args := []any{key} for _, field := range l.indexedFields { value, err := getField(obj, field) if err != nil { logrus.Errorf("cannot index object of type [%s] with key [%s] for indexer [%s]: %v", l.GetType().String(), key, l.GetName(), err) - cErr := tx.Cancel() - if cErr != nil { - return fmt.Errorf("could not cancel transaction: %s while recovering from error: %w", cErr, err) - } return err } switch typedValue := value.(type) { @@ -201,15 +196,11 @@ func (l *ListOptionIndexer) addIndexFields(key string, obj any, tx db.TXClient) args = append(args, strings.Join(typedValue, "|")) default: err2 := fmt.Errorf("field %v has a non-supported type value: %v", field, value) - cErr := tx.Cancel() - if cErr != nil { - return fmt.Errorf("could not cancel transaction: %s while recovering from error: %w", cErr, err2) - } return err2 } } - err := tx.StmtExec(tx.Stmt(l.addFieldStmt), args...) + _, err := tx.Stmt(l.addFieldStmt).Exec(args...) if err != nil { return &db.QueryError{QueryString: l.addFieldQuery, Err: err} } @@ -217,14 +208,14 @@ func (l *ListOptionIndexer) addIndexFields(key string, obj any, tx db.TXClient) } // labels are stored in tables that shadow the underlying object table for each GVK -func (l *ListOptionIndexer) addLabels(key string, obj any, tx db.TXClient) error { +func (l *ListOptionIndexer) addLabels(key string, obj any, tx transaction.Client) error { k8sObj, ok := obj.(*unstructured.Unstructured) if !ok { return fmt.Errorf("addLabels: unexpected object type, expected unstructured.Unstructured: %v", obj) } incomingLabels := k8sObj.GetLabels() for k, v := range incomingLabels { - err := tx.StmtExec(tx.Stmt(l.upsertLabelsStmt), key, k, v) + _, err := tx.Stmt(l.upsertLabelsStmt).Exec(key, k, v) if err != nil { return &db.QueryError{QueryString: l.upsertLabelsQuery, Err: err} } @@ -232,18 +223,18 @@ func (l *ListOptionIndexer) addLabels(key string, obj any, tx db.TXClient) error return nil } -func (l *ListOptionIndexer) deleteIndexFields(key string, tx db.TXClient) error { +func (l *ListOptionIndexer) deleteIndexFields(key string, tx transaction.Client) error { args := []any{key} - err := tx.StmtExec(tx.Stmt(l.deleteFieldStmt), args...) + _, err := tx.Stmt(l.deleteFieldStmt).Exec(args...) if err != nil { return &db.QueryError{QueryString: l.deleteFieldQuery, Err: err} } return nil } -func (l *ListOptionIndexer) deleteLabels(key string, tx db.TXClient) error { - err := tx.StmtExec(tx.Stmt(l.deleteLabelsStmt), key) +func (l *ListOptionIndexer) deleteLabels(key string, tx transaction.Client) error { + _, err := tx.Stmt(l.deleteLabelsStmt).Exec(key) if err != nil { return &db.QueryError{QueryString: l.deleteLabelsQuery, Err: err} } @@ -467,48 +458,37 @@ func (l *ListOptionIndexer) executeQuery(ctx context.Context, queryInfo *QueryIn stmt := l.Prepare(queryInfo.query) defer l.CloseStmt(stmt) - tx, err := l.BeginTx(ctx, false) - if err != nil { - return nil, 0, "", err - } - - txStmt := tx.Stmt(stmt) - rows, err := txStmt.QueryContext(ctx, queryInfo.params...) - if err != nil { - if cerr := tx.Cancel(); cerr != nil { - return nil, 0, "", fmt.Errorf("failed to cancel transaction (%v) after error: %w", cerr, err) - } - return nil, 0, "", &db.QueryError{QueryString: queryInfo.query, Err: err} - } - items, err := l.ReadObjects(rows, l.GetType(), l.GetShouldEncrypt()) - if err != nil { - if cerr := tx.Cancel(); cerr != nil { - return nil, 0, "", fmt.Errorf("failed to cancel transaction (%v) after error: %w", cerr, err) - } - return nil, 0, "", err - } - - total := len(items) - if queryInfo.countQuery != "" { - countStmt := l.Prepare(queryInfo.countQuery) - defer l.CloseStmt(countStmt) - txStmt := tx.Stmt(countStmt) - rows, err := txStmt.QueryContext(ctx, queryInfo.countParams...) + var items []any + var total int + err := l.WithTransaction(ctx, false, func(tx transaction.Client) error { + txStmt := tx.Stmt(stmt) + rows, err := txStmt.QueryContext(ctx, queryInfo.params...) if err != nil { - if cerr := tx.Cancel(); cerr != nil { - return nil, 0, "", fmt.Errorf("failed to cancel transaction (%v) after error: %w", cerr, err) - } - return nil, 0, "", &db.QueryError{QueryString: queryInfo.countQuery, Err: err} + return &db.QueryError{QueryString: queryInfo.query, Err: err} } - total, err = l.ReadInt(rows) + items, err = l.ReadObjects(rows, l.GetType(), l.GetShouldEncrypt()) if err != nil { - if cerr := tx.Cancel(); cerr != nil { - return nil, 0, "", fmt.Errorf("failed to cancel transaction (%v) after error: %w", cerr, err) - } - return nil, 0, "", fmt.Errorf("error reading query results: %w", err) + return err } - } - if err := tx.Commit(); err != nil { + + total = len(items) + if queryInfo.countQuery != "" { + countStmt := l.Prepare(queryInfo.countQuery) + defer l.CloseStmt(countStmt) + txStmt := tx.Stmt(countStmt) + rows, err := txStmt.QueryContext(ctx, queryInfo.countParams...) + if err != nil { + return &db.QueryError{QueryString: queryInfo.countQuery, Err: err} + } + total, err = l.ReadInt(rows) + if err != nil { + return fmt.Errorf("error reading query results: %w", err) + } + } + + return nil + }) + if err != nil { return nil, 0, "", err } diff --git a/pkg/sqlcache/informer/listoption_indexer_test.go b/pkg/sqlcache/informer/listoption_indexer_test.go index 5352cebd..fa36db55 100644 --- a/pkg/sqlcache/informer/listoption_indexer_test.go +++ b/pkg/sqlcache/informer/listoption_indexer_test.go @@ -15,6 +15,7 @@ import ( "strings" "testing" + "github.com/rancher/steve/pkg/sqlcache/db" "github.com/rancher/steve/pkg/sqlcache/partition" "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" @@ -37,11 +38,16 @@ func TestNewListOptionIndexer(t *testing.T) { id := "somename" stmt := &sql.Stmt{} // logic for NewIndexer(), only interested in if this results in error or not - store.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) store.EXPECT().GetName().Return(id).AnyTimes() - txClient.EXPECT().Exec(gomock.Any()).Return(nil) - txClient.EXPECT().Exec(gomock.Any()).Return(nil) - txClient.EXPECT().Commit().Return(nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil, nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil, nil) + store.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(nil).Do( + func(ctx context.Context, shouldEncrypt bool, f db.WithTransactionFunction) { + err := f(txClient) + if err != nil { + t.Fail() + } + }) store.EXPECT().RegisterAfterUpsert(gomock.Any()) store.EXPECT().Prepare(gomock.Any()).Return(stmt).AnyTimes() // end NewIndexer() logic @@ -49,17 +55,22 @@ func TestNewListOptionIndexer(t *testing.T) { store.EXPECT().RegisterAfterUpsert(gomock.Any()).Times(2) store.EXPECT().RegisterAfterDelete(gomock.Any()).Times(2) - store.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) // create field table - txClient.EXPECT().Exec(fmt.Sprintf(createFieldsTableFmt, id, `"metadata.name" TEXT, "metadata.creationTimestamp" TEXT, "metadata.namespace" TEXT, "something" TEXT`)).Return(nil) + txClient.EXPECT().Exec(fmt.Sprintf(createFieldsTableFmt, id, `"metadata.name" TEXT, "metadata.creationTimestamp" TEXT, "metadata.namespace" TEXT, "something" TEXT`)).Return(nil, nil) // create field table indexes - txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, "metadata.name", id, "metadata.name")).Return(nil) - txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, "metadata.namespace", id, "metadata.namespace")).Return(nil) - txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, "metadata.creationTimestamp", id, "metadata.creationTimestamp")).Return(nil) - txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, fields[0][0], id, fields[0][0])).Return(nil) - txClient.EXPECT().Exec(fmt.Sprintf(createLabelsTableFmt, id, id)).Return(nil) - txClient.EXPECT().Exec(fmt.Sprintf(createLabelsTableIndexFmt, id, id)).Return(nil) - txClient.EXPECT().Commit().Return(nil) + txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, "metadata.name", id, "metadata.name")).Return(nil, nil) + txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, "metadata.namespace", id, "metadata.namespace")).Return(nil, nil) + txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, "metadata.creationTimestamp", id, "metadata.creationTimestamp")).Return(nil, nil) + txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, fields[0][0], id, fields[0][0])).Return(nil, nil) + txClient.EXPECT().Exec(fmt.Sprintf(createLabelsTableFmt, id, id)).Return(nil, nil) + txClient.EXPECT().Exec(fmt.Sprintf(createLabelsTableIndexFmt, id, id)).Return(nil, nil) + store.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(nil).Do( + func(ctx context.Context, shouldEncrypt bool, f db.WithTransactionFunction) { + err := f(txClient) + if err != nil { + t.Fail() + } + }) loi, err := NewListOptionIndexer(fields, store, true) assert.Nil(t, err) @@ -71,11 +82,16 @@ func TestNewListOptionIndexer(t *testing.T) { fields := [][]string{{"something"}} id := "somename" // logic for NewIndexer(), only interested in if this results in error or not - store.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) store.EXPECT().GetName().Return(id).AnyTimes() - txClient.EXPECT().Exec(gomock.Any()).Return(nil) - txClient.EXPECT().Exec(gomock.Any()).Return(nil) - txClient.EXPECT().Commit().Return(fmt.Errorf("error")) + txClient.EXPECT().Exec(gomock.Any()).Return(nil, nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil, nil) + store.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(fmt.Errorf("error")).Do( + func(ctx context.Context, shouldEncrypt bool, f db.WithTransactionFunction) { + err := f(txClient) + if err != nil { + t.Fail() + } + }) _, err := NewListOptionIndexer(fields, store, false) assert.NotNil(t, err) @@ -87,11 +103,16 @@ func TestNewListOptionIndexer(t *testing.T) { id := "somename" stmt := &sql.Stmt{} // logic for NewIndexer(), only interested in if this results in error or not - store.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) store.EXPECT().GetName().Return(id).AnyTimes() - txClient.EXPECT().Exec(gomock.Any()).Return(nil) - txClient.EXPECT().Exec(gomock.Any()).Return(nil) - txClient.EXPECT().Commit().Return(nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil, nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil, nil) + store.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(nil).Do( + func(ctx context.Context, shouldEncrypt bool, f db.WithTransactionFunction) { + err := f(txClient) + if err != nil { + t.Fail() + } + }) store.EXPECT().RegisterAfterUpsert(gomock.Any()) store.EXPECT().Prepare(gomock.Any()).Return(stmt).AnyTimes() // end NewIndexer() logic @@ -99,7 +120,7 @@ func TestNewListOptionIndexer(t *testing.T) { store.EXPECT().RegisterAfterUpsert(gomock.Any()).Times(2) store.EXPECT().RegisterAfterDelete(gomock.Any()).Times(2) - store.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, fmt.Errorf("error")) + store.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(fmt.Errorf("error")) _, err := NewListOptionIndexer(fields, store, false) assert.NotNil(t, err) @@ -111,11 +132,16 @@ func TestNewListOptionIndexer(t *testing.T) { id := "somename" stmt := &sql.Stmt{} // logic for NewIndexer(), only interested in if this results in error or not - store.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) store.EXPECT().GetName().Return(id).AnyTimes() - txClient.EXPECT().Exec(gomock.Any()).Return(nil) - txClient.EXPECT().Exec(gomock.Any()).Return(nil) - txClient.EXPECT().Commit().Return(nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil, nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil, nil) + store.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(nil).Do( + func(ctx context.Context, shouldEncrypt bool, f db.WithTransactionFunction) { + err := f(txClient) + if err != nil { + t.Fail() + } + }) store.EXPECT().RegisterAfterUpsert(gomock.Any()) store.EXPECT().Prepare(gomock.Any()).Return(stmt).AnyTimes() // end NewIndexer() logic @@ -123,9 +149,15 @@ func TestNewListOptionIndexer(t *testing.T) { store.EXPECT().RegisterAfterUpsert(gomock.Any()).Times(2) store.EXPECT().RegisterAfterDelete(gomock.Any()).Times(2) - store.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) - txClient.EXPECT().Exec(fmt.Sprintf(createFieldsTableFmt, id, `"metadata.name" TEXT, "metadata.creationTimestamp" TEXT, "metadata.namespace" TEXT, "something" TEXT`)).Return(nil) - txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, "metadata.name", id, "metadata.name")).Return(fmt.Errorf("error")) + txClient.EXPECT().Exec(fmt.Sprintf(createFieldsTableFmt, id, `"metadata.name" TEXT, "metadata.creationTimestamp" TEXT, "metadata.namespace" TEXT, "something" TEXT`)).Return(nil, nil) + txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, "metadata.name", id, "metadata.name")).Return(nil, fmt.Errorf("error")) + store.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(fmt.Errorf("error")).Do( + func(ctx context.Context, shouldEncrypt bool, f db.WithTransactionFunction) { + err := f(txClient) + if err == nil { + t.Fail() + } + }) _, err := NewListOptionIndexer(fields, store, true) assert.NotNil(t, err) @@ -137,11 +169,16 @@ func TestNewListOptionIndexer(t *testing.T) { id := "somename" stmt := &sql.Stmt{} // logic for NewIndexer(), only interested in if this results in error or not - store.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) store.EXPECT().GetName().Return(id).AnyTimes() - txClient.EXPECT().Exec(gomock.Any()).Return(nil) - txClient.EXPECT().Exec(gomock.Any()).Return(nil) - txClient.EXPECT().Commit().Return(nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil, nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil, nil) + store.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(nil).Do( + func(ctx context.Context, shouldEncrypt bool, f db.WithTransactionFunction) { + err := f(txClient) + if err != nil { + t.Fail() + } + }) store.EXPECT().RegisterAfterUpsert(gomock.Any()) store.EXPECT().Prepare(gomock.Any()).Return(stmt).AnyTimes() // end NewIndexer() logic @@ -149,13 +186,19 @@ func TestNewListOptionIndexer(t *testing.T) { store.EXPECT().RegisterAfterUpsert(gomock.Any()).Times(2) store.EXPECT().RegisterAfterDelete(gomock.Any()).Times(2) - store.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) - txClient.EXPECT().Exec(fmt.Sprintf(createFieldsTableFmt, id, `"metadata.name" TEXT, "metadata.creationTimestamp" TEXT, "metadata.namespace" TEXT, "something" TEXT`)).Return(nil) - txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, "metadata.name", id, "metadata.name")).Return(nil) - txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, "metadata.namespace", id, "metadata.namespace")).Return(nil) - txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, "metadata.creationTimestamp", id, "metadata.creationTimestamp")).Return(nil) - txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, fields[0][0], id, fields[0][0])).Return(nil) - txClient.EXPECT().Exec(fmt.Sprintf(createLabelsTableFmt, id, id)).Return(fmt.Errorf("error")) + txClient.EXPECT().Exec(fmt.Sprintf(createFieldsTableFmt, id, `"metadata.name" TEXT, "metadata.creationTimestamp" TEXT, "metadata.namespace" TEXT, "something" TEXT`)).Return(nil, nil) + txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, "metadata.name", id, "metadata.name")).Return(nil, nil) + txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, "metadata.namespace", id, "metadata.namespace")).Return(nil, nil) + txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, "metadata.creationTimestamp", id, "metadata.creationTimestamp")).Return(nil, nil) + txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, fields[0][0], id, fields[0][0])).Return(nil, nil) + txClient.EXPECT().Exec(fmt.Sprintf(createLabelsTableFmt, id, id)).Return(nil, fmt.Errorf("error")) + store.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(fmt.Errorf("error")).Do( + func(ctx context.Context, shouldEncrypt bool, f db.WithTransactionFunction) { + err := f(txClient) + if err == nil { + t.Fail() + } + }) _, err := NewListOptionIndexer(fields, store, true) assert.NotNil(t, err) @@ -167,11 +210,16 @@ func TestNewListOptionIndexer(t *testing.T) { id := "somename" stmt := &sql.Stmt{} // logic for NewIndexer(), only interested in if this results in error or not - store.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) store.EXPECT().GetName().Return(id).AnyTimes() - txClient.EXPECT().Exec(gomock.Any()).Return(nil) - txClient.EXPECT().Exec(gomock.Any()).Return(nil) - txClient.EXPECT().Commit().Return(nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil, nil) + txClient.EXPECT().Exec(gomock.Any()).Return(nil, nil) + store.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(nil).Do( + func(ctx context.Context, shouldEncrypt bool, f db.WithTransactionFunction) { + err := f(txClient) + if err != nil { + t.Fail() + } + }) store.EXPECT().RegisterAfterUpsert(gomock.Any()) store.EXPECT().Prepare(gomock.Any()).Return(stmt).AnyTimes() // end NewIndexer() logic @@ -179,15 +227,20 @@ func TestNewListOptionIndexer(t *testing.T) { store.EXPECT().RegisterAfterUpsert(gomock.Any()).Times(2) store.EXPECT().RegisterAfterDelete(gomock.Any()).Times(2) - store.EXPECT().BeginTx(gomock.Any(), true).Return(txClient, nil) - txClient.EXPECT().Exec(fmt.Sprintf(createFieldsTableFmt, id, `"metadata.name" TEXT, "metadata.creationTimestamp" TEXT, "metadata.namespace" TEXT, "something" TEXT`)).Return(nil) - txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, "metadata.name", id, "metadata.name")).Return(nil) - txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, "metadata.namespace", id, "metadata.namespace")).Return(nil) - txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, "metadata.creationTimestamp", id, "metadata.creationTimestamp")).Return(nil) - txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, fields[0][0], id, fields[0][0])).Return(nil) - txClient.EXPECT().Exec(fmt.Sprintf(createLabelsTableFmt, id, id)).Return(nil) - txClient.EXPECT().Exec(fmt.Sprintf(createLabelsTableIndexFmt, id, id)).Return(nil) - txClient.EXPECT().Commit().Return(fmt.Errorf("error")) + txClient.EXPECT().Exec(fmt.Sprintf(createFieldsTableFmt, id, `"metadata.name" TEXT, "metadata.creationTimestamp" TEXT, "metadata.namespace" TEXT, "something" TEXT`)).Return(nil, nil) + txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, "metadata.name", id, "metadata.name")).Return(nil, nil) + txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, "metadata.namespace", id, "metadata.namespace")).Return(nil, nil) + txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, "metadata.creationTimestamp", id, "metadata.creationTimestamp")).Return(nil, nil) + txClient.EXPECT().Exec(fmt.Sprintf(createFieldsIndexFmt, id, fields[0][0], id, fields[0][0])).Return(nil, nil) + txClient.EXPECT().Exec(fmt.Sprintf(createLabelsTableFmt, id, id)).Return(nil, nil) + txClient.EXPECT().Exec(fmt.Sprintf(createLabelsTableIndexFmt, id, id)).Return(nil, nil) + store.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(fmt.Errorf("error")).Do( + func(ctx context.Context, shouldEncrypt bool, f db.WithTransactionFunction) { + err := f(txClient) + if err != nil { + t.Fail() + } + }) _, err := NewListOptionIndexer(fields, store, true) assert.NotNil(t, err) @@ -876,7 +929,6 @@ func TestListByOptions(t *testing.T) { stmt := &sql.Stmt{} rows := &sql.Rows{} objType := reflect.TypeOf(testObject) - store.EXPECT().BeginTx(gomock.Any(), false).Return(txClient, nil) txClient.EXPECT().Stmt(gomock.Any()).Return(stmts).AnyTimes() store.EXPECT().Prepare(test.expectedStmt).Do(func(a ...any) { fmt.Println(a) @@ -895,13 +947,21 @@ func TestListByOptions(t *testing.T) { store.EXPECT().ReadObjects(rows, objType, false).Return(test.returnList, nil) store.EXPECT().CloseStmt(stmt).Return(nil) + store.EXPECT().WithTransaction(gomock.Any(), false, gomock.Any()).Return(nil).Do( + func(ctx context.Context, shouldEncrypt bool, f db.WithTransactionFunction) { + err := f(txClient) + if test.expectedErr == nil { + assert.Nil(t, err) + } else { + assert.Equal(t, test.expectedErr, err) + } + }) + if test.expectedCountStmt != "" { store.EXPECT().Prepare(test.expectedCountStmt).Return(stmt) - //store.EXPECT().QueryForRows(context.TODO(), stmt, test.expectedCountStmtArgs...).Return(rows, nil) store.EXPECT().ReadInt(rows).Return(len(test.expectedList.Items), nil) store.EXPECT().CloseStmt(stmt).Return(nil) } - txClient.EXPECT().Commit() list, total, contToken, err := lii.executeQuery(context.TODO(), queryInfo) if test.expectedErr == nil { assert.Nil(t, err) diff --git a/pkg/sqlcache/informer/sql_mocks_test.go b/pkg/sqlcache/informer/sql_mocks_test.go index c269b01b..b8f7c578 100644 --- a/pkg/sqlcache/informer/sql_mocks_test.go +++ b/pkg/sqlcache/informer/sql_mocks_test.go @@ -56,21 +56,6 @@ func (mr *MockStoreMockRecorder) Add(arg0 any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockStore)(nil).Add), arg0) } -// BeginTx mocks base method. -func (m *MockStore) BeginTx(arg0 context.Context, arg1 bool) (db.TXClient, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "BeginTx", arg0, arg1) - ret0, _ := ret[0].(db.TXClient) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// BeginTx indicates an expected call of BeginTx. -func (mr *MockStoreMockRecorder) BeginTx(arg0, arg1 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BeginTx", reflect.TypeOf((*MockStore)(nil).BeginTx), arg0, arg1) -} - // CloseStmt mocks base method. func (m *MockStore) CloseStmt(arg0 db.Closable) error { m.ctrl.T.Helper() @@ -201,6 +186,20 @@ func (mr *MockStoreMockRecorder) ListKeys() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListKeys", reflect.TypeOf((*MockStore)(nil).ListKeys)) } +// NewConnection mocks base method. +func (m *MockStore) NewConnection() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NewConnection") + ret0, _ := ret[0].(error) + return ret0 +} + +// NewConnection indicates an expected call of NewConnection. +func (mr *MockStoreMockRecorder) NewConnection() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewConnection", reflect.TypeOf((*MockStore)(nil).NewConnection)) +} + // Prepare mocks base method. func (m *MockStore) Prepare(arg0 string) *sql.Stmt { m.ctrl.T.Helper() @@ -281,7 +280,7 @@ func (mr *MockStoreMockRecorder) ReadStrings(arg0 any) *gomock.Call { } // RegisterAfterDelete mocks base method. -func (m *MockStore) RegisterAfterDelete(arg0 func(string, db.TXClient) error) { +func (m *MockStore) RegisterAfterDelete(arg0 func(string, transaction.Client) error) { m.ctrl.T.Helper() m.ctrl.Call(m, "RegisterAfterDelete", arg0) } @@ -293,7 +292,7 @@ func (mr *MockStoreMockRecorder) RegisterAfterDelete(arg0 any) *gomock.Call { } // RegisterAfterUpsert mocks base method. -func (m *MockStore) RegisterAfterUpsert(arg0 func(string, any, db.TXClient) error) { +func (m *MockStore) RegisterAfterUpsert(arg0 func(string, any, transaction.Client) error) { m.ctrl.T.Helper() m.ctrl.Call(m, "RegisterAfterUpsert", arg0) } @@ -345,3 +344,31 @@ func (mr *MockStoreMockRecorder) Update(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockStore)(nil).Update), arg0) } + +// Upsert mocks base method. +func (m *MockStore) Upsert(arg0 transaction.Client, arg1 *sql.Stmt, arg2 string, arg3 any, arg4 bool) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Upsert", arg0, arg1, arg2, arg3, arg4) + ret0, _ := ret[0].(error) + return ret0 +} + +// Upsert indicates an expected call of Upsert. +func (mr *MockStoreMockRecorder) Upsert(arg0, arg1, arg2, arg3, arg4 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Upsert", reflect.TypeOf((*MockStore)(nil).Upsert), arg0, arg1, arg2, arg3, arg4) +} + +// WithTransaction mocks base method. +func (m *MockStore) WithTransaction(arg0 context.Context, arg1 bool, arg2 db.WithTransactionFunction) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "WithTransaction", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// WithTransaction indicates an expected call of WithTransaction. +func (mr *MockStoreMockRecorder) WithTransaction(arg0, arg1, arg2 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WithTransaction", reflect.TypeOf((*MockStore)(nil).WithTransaction), arg0, arg1, arg2) +} diff --git a/pkg/sqlcache/informer/store_mocks_test.go b/pkg/sqlcache/informer/store_mocks_test.go deleted file mode 100644 index c1c7d426..00000000 --- a/pkg/sqlcache/informer/store_mocks_test.go +++ /dev/null @@ -1,165 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/rancher/steve/pkg/sqlcache/store (interfaces: DBClient) -// -// Generated by this command: -// -// mockgen --build_flags=--mod=mod -package informer -destination ./store_mocks_test.go github.com/rancher/steve/pkg/sqlcache/store DBClient -// - -// Package informer is a generated GoMock package. -package informer - -import ( - context "context" - sql "database/sql" - reflect "reflect" - - db "github.com/rancher/steve/pkg/sqlcache/db" - transaction "github.com/rancher/steve/pkg/sqlcache/db/transaction" - gomock "go.uber.org/mock/gomock" -) - -// MockDBClient is a mock of DBClient interface. -type MockDBClient struct { - ctrl *gomock.Controller - recorder *MockDBClientMockRecorder -} - -// MockDBClientMockRecorder is the mock recorder for MockDBClient. -type MockDBClientMockRecorder struct { - mock *MockDBClient -} - -// NewMockDBClient creates a new mock instance. -func NewMockDBClient(ctrl *gomock.Controller) *MockDBClient { - mock := &MockDBClient{ctrl: ctrl} - mock.recorder = &MockDBClientMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockDBClient) EXPECT() *MockDBClientMockRecorder { - return m.recorder -} - -// BeginTx mocks base method. -func (m *MockDBClient) BeginTx(arg0 context.Context, arg1 bool) (db.TXClient, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "BeginTx", arg0, arg1) - ret0, _ := ret[0].(db.TXClient) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// BeginTx indicates an expected call of BeginTx. -func (mr *MockDBClientMockRecorder) BeginTx(arg0, arg1 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BeginTx", reflect.TypeOf((*MockDBClient)(nil).BeginTx), arg0, arg1) -} - -// CloseStmt mocks base method. -func (m *MockDBClient) CloseStmt(arg0 db.Closable) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CloseStmt", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// CloseStmt indicates an expected call of CloseStmt. -func (mr *MockDBClientMockRecorder) CloseStmt(arg0 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseStmt", reflect.TypeOf((*MockDBClient)(nil).CloseStmt), arg0) -} - -// Prepare mocks base method. -func (m *MockDBClient) Prepare(arg0 string) *sql.Stmt { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Prepare", arg0) - ret0, _ := ret[0].(*sql.Stmt) - return ret0 -} - -// Prepare indicates an expected call of Prepare. -func (mr *MockDBClientMockRecorder) Prepare(arg0 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Prepare", reflect.TypeOf((*MockDBClient)(nil).Prepare), arg0) -} - -// QueryForRows mocks base method. -func (m *MockDBClient) QueryForRows(arg0 context.Context, arg1 transaction.Stmt, arg2 ...any) (*sql.Rows, error) { - m.ctrl.T.Helper() - varargs := []any{arg0, arg1} - for _, a := range arg2 { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "QueryForRows", varargs...) - ret0, _ := ret[0].(*sql.Rows) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// QueryForRows indicates an expected call of QueryForRows. -func (mr *MockDBClientMockRecorder) QueryForRows(arg0, arg1 any, arg2 ...any) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0, arg1}, arg2...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryForRows", reflect.TypeOf((*MockDBClient)(nil).QueryForRows), varargs...) -} - -// ReadInt mocks base method. -func (m *MockDBClient) ReadInt(arg0 db.Rows) (int, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ReadInt", arg0) - ret0, _ := ret[0].(int) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// ReadInt indicates an expected call of ReadInt. -func (mr *MockDBClientMockRecorder) ReadInt(arg0 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadInt", reflect.TypeOf((*MockDBClient)(nil).ReadInt), arg0) -} - -// ReadObjects mocks base method. -func (m *MockDBClient) ReadObjects(arg0 db.Rows, arg1 reflect.Type, arg2 bool) ([]any, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ReadObjects", arg0, arg1, arg2) - ret0, _ := ret[0].([]any) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// ReadObjects indicates an expected call of ReadObjects. -func (mr *MockDBClientMockRecorder) ReadObjects(arg0, arg1, arg2 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadObjects", reflect.TypeOf((*MockDBClient)(nil).ReadObjects), arg0, arg1, arg2) -} - -// ReadStrings mocks base method. -func (m *MockDBClient) ReadStrings(arg0 db.Rows) ([]string, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ReadStrings", arg0) - ret0, _ := ret[0].([]string) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// ReadStrings indicates an expected call of ReadStrings. -func (mr *MockDBClientMockRecorder) ReadStrings(arg0 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadStrings", reflect.TypeOf((*MockDBClient)(nil).ReadStrings), arg0) -} - -// Upsert mocks base method. -func (m *MockDBClient) Upsert(arg0 db.TXClient, arg1 *sql.Stmt, arg2 string, arg3 any, arg4 bool) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Upsert", arg0, arg1, arg2, arg3, arg4) - ret0, _ := ret[0].(error) - return ret0 -} - -// Upsert indicates an expected call of Upsert. -func (mr *MockDBClientMockRecorder) Upsert(arg0, arg1, arg2, arg3, arg4 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Upsert", reflect.TypeOf((*MockDBClient)(nil).Upsert), arg0, arg1, arg2, arg3, arg4) -} diff --git a/pkg/sqlcache/db/transaction/transaction_mocks_test.go b/pkg/sqlcache/informer/transaction_mocks_test.go similarity index 64% rename from pkg/sqlcache/db/transaction/transaction_mocks_test.go rename to pkg/sqlcache/informer/transaction_mocks_test.go index 0d7fdaa7..64e885ee 100644 --- a/pkg/sqlcache/db/transaction/transaction_mocks_test.go +++ b/pkg/sqlcache/informer/transaction_mocks_test.go @@ -1,19 +1,20 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/rancher/steve/pkg/sqlcache/db/transaction (interfaces: Stmt,SQLTx) +// Source: github.com/rancher/steve/pkg/sqlcache/db/transaction (interfaces: Stmt,Client) // // Generated by this command: // -// mockgen --build_flags=--mod=mod -package transaction -destination ./transaction_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db/transaction Stmt,SQLTx +// mockgen --build_flags=--mod=mod -package informer -destination ./transaction_mocks_test.go -mock_names Client=MockTXClient github.com/rancher/steve/pkg/sqlcache/db/transaction Stmt,Client // -// Package transaction is a generated GoMock package. -package transaction +// Package informer is a generated GoMock package. +package informer import ( context "context" sql "database/sql" reflect "reflect" + transaction "github.com/rancher/steve/pkg/sqlcache/db/transaction" gomock "go.uber.org/mock/gomock" ) @@ -98,45 +99,31 @@ func (mr *MockStmtMockRecorder) QueryContext(arg0 any, arg1 ...any) *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryContext", reflect.TypeOf((*MockStmt)(nil).QueryContext), varargs...) } -// MockSQLTx is a mock of SQLTx interface. -type MockSQLTx struct { +// MockTXClient is a mock of Client interface. +type MockTXClient struct { ctrl *gomock.Controller - recorder *MockSQLTxMockRecorder + recorder *MockTXClientMockRecorder } -// MockSQLTxMockRecorder is the mock recorder for MockSQLTx. -type MockSQLTxMockRecorder struct { - mock *MockSQLTx +// MockTXClientMockRecorder is the mock recorder for MockTXClient. +type MockTXClientMockRecorder struct { + mock *MockTXClient } -// NewMockSQLTx creates a new mock instance. -func NewMockSQLTx(ctrl *gomock.Controller) *MockSQLTx { - mock := &MockSQLTx{ctrl: ctrl} - mock.recorder = &MockSQLTxMockRecorder{mock} +// NewMockTXClient creates a new mock instance. +func NewMockTXClient(ctrl *gomock.Controller) *MockTXClient { + mock := &MockTXClient{ctrl: ctrl} + mock.recorder = &MockTXClientMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockSQLTx) EXPECT() *MockSQLTxMockRecorder { +func (m *MockTXClient) EXPECT() *MockTXClientMockRecorder { return m.recorder } -// Commit mocks base method. -func (m *MockSQLTx) Commit() error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Commit") - ret0, _ := ret[0].(error) - return ret0 -} - -// Commit indicates an expected call of Commit. -func (mr *MockSQLTxMockRecorder) Commit() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockSQLTx)(nil).Commit)) -} - // Exec mocks base method. -func (m *MockSQLTx) Exec(arg0 string, arg1 ...any) (sql.Result, error) { +func (m *MockTXClient) Exec(arg0 string, arg1 ...any) (sql.Result, error) { m.ctrl.T.Helper() varargs := []any{arg0} for _, a := range arg1 { @@ -149,36 +136,22 @@ func (m *MockSQLTx) Exec(arg0 string, arg1 ...any) (sql.Result, error) { } // Exec indicates an expected call of Exec. -func (mr *MockSQLTxMockRecorder) Exec(arg0 any, arg1 ...any) *gomock.Call { +func (mr *MockTXClientMockRecorder) Exec(arg0 any, arg1 ...any) *gomock.Call { mr.mock.ctrl.T.Helper() varargs := append([]any{arg0}, arg1...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockSQLTx)(nil).Exec), varargs...) -} - -// Rollback mocks base method. -func (m *MockSQLTx) Rollback() error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Rollback") - ret0, _ := ret[0].(error) - return ret0 -} - -// Rollback indicates an expected call of Rollback. -func (mr *MockSQLTxMockRecorder) Rollback() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Rollback", reflect.TypeOf((*MockSQLTx)(nil).Rollback)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockTXClient)(nil).Exec), varargs...) } // Stmt mocks base method. -func (m *MockSQLTx) Stmt(arg0 *sql.Stmt) *sql.Stmt { +func (m *MockTXClient) Stmt(arg0 *sql.Stmt) transaction.Stmt { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Stmt", arg0) - ret0, _ := ret[0].(*sql.Stmt) + ret0, _ := ret[0].(transaction.Stmt) return ret0 } // Stmt indicates an expected call of Stmt. -func (mr *MockSQLTxMockRecorder) Stmt(arg0 any) *gomock.Call { +func (mr *MockTXClientMockRecorder) Stmt(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stmt", reflect.TypeOf((*MockSQLTx)(nil).Stmt), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stmt", reflect.TypeOf((*MockTXClient)(nil).Stmt), arg0) } diff --git a/pkg/sqlcache/informer/tx_mocks_test.go b/pkg/sqlcache/informer/tx_mocks_test.go deleted file mode 100644 index 9383411d..00000000 --- a/pkg/sqlcache/informer/tx_mocks_test.go +++ /dev/null @@ -1,99 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/rancher/steve/pkg/sqlcache/db/transaction (interfaces: Stmt) -// -// Generated by this command: -// -// mockgen --build_flags=--mod=mod -package informer -destination ./pkg/cache/sql/informer/tx_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db/transaction Stmt -// - -// Package informer is a generated GoMock package. -package informer - -import ( - context "context" - sql "database/sql" - reflect "reflect" - - gomock "go.uber.org/mock/gomock" -) - -// MockStmt is a mock of Stmt interface. -type MockStmt struct { - ctrl *gomock.Controller - recorder *MockStmtMockRecorder -} - -// MockStmtMockRecorder is the mock recorder for MockStmt. -type MockStmtMockRecorder struct { - mock *MockStmt -} - -// NewMockStmt creates a new mock instance. -func NewMockStmt(ctrl *gomock.Controller) *MockStmt { - mock := &MockStmt{ctrl: ctrl} - mock.recorder = &MockStmtMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockStmt) EXPECT() *MockStmtMockRecorder { - return m.recorder -} - -// Exec mocks base method. -func (m *MockStmt) Exec(arg0 ...any) (sql.Result, error) { - m.ctrl.T.Helper() - varargs := []any{} - for _, a := range arg0 { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "Exec", varargs...) - ret0, _ := ret[0].(sql.Result) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Exec indicates an expected call of Exec. -func (mr *MockStmtMockRecorder) Exec(arg0 ...any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockStmt)(nil).Exec), arg0...) -} - -// Query mocks base method. -func (m *MockStmt) Query(arg0 ...any) (*sql.Rows, error) { - m.ctrl.T.Helper() - varargs := []any{} - for _, a := range arg0 { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "Query", varargs...) - ret0, _ := ret[0].(*sql.Rows) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Query indicates an expected call of Query. -func (mr *MockStmtMockRecorder) Query(arg0 ...any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockStmt)(nil).Query), arg0...) -} - -// QueryContext mocks base method. -func (m *MockStmt) QueryContext(arg0 context.Context, arg1 ...any) (*sql.Rows, error) { - m.ctrl.T.Helper() - varargs := []any{arg0} - for _, a := range arg1 { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "QueryContext", varargs...) - ret0, _ := ret[0].(*sql.Rows) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// QueryContext indicates an expected call of QueryContext. -func (mr *MockStmtMockRecorder) QueryContext(arg0 any, arg1 ...any) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0}, arg1...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryContext", reflect.TypeOf((*MockStmt)(nil).QueryContext), varargs...) -} diff --git a/pkg/sqlcache/store/db_mocks_test.go b/pkg/sqlcache/store/db_mocks_test.go index 75f70b6e..fb938c0d 100644 --- a/pkg/sqlcache/store/db_mocks_test.go +++ b/pkg/sqlcache/store/db_mocks_test.go @@ -1,125 +1,24 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/rancher/steve/pkg/sqlcache/db (interfaces: TXClient,Rows) +// Source: github.com/rancher/steve/pkg/sqlcache/db (interfaces: Rows,Client) // // Generated by this command: // -// mockgen --build_flags=--mod=mod -package store -destination ./db_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db TXClient,Rows +// mockgen --build_flags=--mod=mod -package store -destination ./db_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db Rows,Client // // Package store is a generated GoMock package. package store import ( + context "context" sql "database/sql" reflect "reflect" + db "github.com/rancher/steve/pkg/sqlcache/db" transaction "github.com/rancher/steve/pkg/sqlcache/db/transaction" gomock "go.uber.org/mock/gomock" ) -// MockTXClient is a mock of TXClient interface. -type MockTXClient struct { - ctrl *gomock.Controller - recorder *MockTXClientMockRecorder -} - -// MockTXClientMockRecorder is the mock recorder for MockTXClient. -type MockTXClientMockRecorder struct { - mock *MockTXClient -} - -// NewMockTXClient creates a new mock instance. -func NewMockTXClient(ctrl *gomock.Controller) *MockTXClient { - mock := &MockTXClient{ctrl: ctrl} - mock.recorder = &MockTXClientMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockTXClient) EXPECT() *MockTXClientMockRecorder { - return m.recorder -} - -// Cancel mocks base method. -func (m *MockTXClient) Cancel() error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Cancel") - ret0, _ := ret[0].(error) - return ret0 -} - -// Cancel indicates an expected call of Cancel. -func (mr *MockTXClientMockRecorder) Cancel() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Cancel", reflect.TypeOf((*MockTXClient)(nil).Cancel)) -} - -// Commit mocks base method. -func (m *MockTXClient) Commit() error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Commit") - ret0, _ := ret[0].(error) - return ret0 -} - -// Commit indicates an expected call of Commit. -func (mr *MockTXClientMockRecorder) Commit() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockTXClient)(nil).Commit)) -} - -// Exec mocks base method. -func (m *MockTXClient) Exec(arg0 string, arg1 ...any) error { - m.ctrl.T.Helper() - varargs := []any{arg0} - for _, a := range arg1 { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "Exec", varargs...) - ret0, _ := ret[0].(error) - return ret0 -} - -// Exec indicates an expected call of Exec. -func (mr *MockTXClientMockRecorder) Exec(arg0 any, arg1 ...any) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0}, arg1...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockTXClient)(nil).Exec), varargs...) -} - -// Stmt mocks base method. -func (m *MockTXClient) Stmt(arg0 *sql.Stmt) transaction.Stmt { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Stmt", arg0) - ret0, _ := ret[0].(transaction.Stmt) - return ret0 -} - -// Stmt indicates an expected call of Stmt. -func (mr *MockTXClientMockRecorder) Stmt(arg0 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stmt", reflect.TypeOf((*MockTXClient)(nil).Stmt), arg0) -} - -// StmtExec mocks base method. -func (m *MockTXClient) StmtExec(arg0 transaction.Stmt, arg1 ...any) error { - m.ctrl.T.Helper() - varargs := []any{arg0} - for _, a := range arg1 { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "StmtExec", varargs...) - ret0, _ := ret[0].(error) - return ret0 -} - -// StmtExec indicates an expected call of StmtExec. -func (mr *MockTXClientMockRecorder) StmtExec(arg0 any, arg1 ...any) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0}, arg1...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StmtExec", reflect.TypeOf((*MockTXClient)(nil).StmtExec), varargs...) -} - // MockRows is a mock of Rows interface. type MockRows struct { ctrl *gomock.Controller @@ -202,3 +101,161 @@ func (mr *MockRowsMockRecorder) Scan(arg0 ...any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Scan", reflect.TypeOf((*MockRows)(nil).Scan), arg0...) } + +// MockClient is a mock of Client interface. +type MockClient struct { + ctrl *gomock.Controller + recorder *MockClientMockRecorder +} + +// MockClientMockRecorder is the mock recorder for MockClient. +type MockClientMockRecorder struct { + mock *MockClient +} + +// NewMockClient creates a new mock instance. +func NewMockClient(ctrl *gomock.Controller) *MockClient { + mock := &MockClient{ctrl: ctrl} + mock.recorder = &MockClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockClient) EXPECT() *MockClientMockRecorder { + return m.recorder +} + +// CloseStmt mocks base method. +func (m *MockClient) CloseStmt(arg0 db.Closable) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CloseStmt", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// CloseStmt indicates an expected call of CloseStmt. +func (mr *MockClientMockRecorder) CloseStmt(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseStmt", reflect.TypeOf((*MockClient)(nil).CloseStmt), arg0) +} + +// NewConnection mocks base method. +func (m *MockClient) NewConnection() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NewConnection") + ret0, _ := ret[0].(error) + return ret0 +} + +// NewConnection indicates an expected call of NewConnection. +func (mr *MockClientMockRecorder) NewConnection() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewConnection", reflect.TypeOf((*MockClient)(nil).NewConnection)) +} + +// Prepare mocks base method. +func (m *MockClient) Prepare(arg0 string) *sql.Stmt { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Prepare", arg0) + ret0, _ := ret[0].(*sql.Stmt) + return ret0 +} + +// Prepare indicates an expected call of Prepare. +func (mr *MockClientMockRecorder) Prepare(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Prepare", reflect.TypeOf((*MockClient)(nil).Prepare), arg0) +} + +// QueryForRows mocks base method. +func (m *MockClient) QueryForRows(arg0 context.Context, arg1 transaction.Stmt, arg2 ...any) (*sql.Rows, error) { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "QueryForRows", varargs...) + ret0, _ := ret[0].(*sql.Rows) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// QueryForRows indicates an expected call of QueryForRows. +func (mr *MockClientMockRecorder) QueryForRows(arg0, arg1 any, arg2 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryForRows", reflect.TypeOf((*MockClient)(nil).QueryForRows), varargs...) +} + +// ReadInt mocks base method. +func (m *MockClient) ReadInt(arg0 db.Rows) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadInt", arg0) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReadInt indicates an expected call of ReadInt. +func (mr *MockClientMockRecorder) ReadInt(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadInt", reflect.TypeOf((*MockClient)(nil).ReadInt), arg0) +} + +// ReadObjects mocks base method. +func (m *MockClient) ReadObjects(arg0 db.Rows, arg1 reflect.Type, arg2 bool) ([]any, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadObjects", arg0, arg1, arg2) + ret0, _ := ret[0].([]any) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReadObjects indicates an expected call of ReadObjects. +func (mr *MockClientMockRecorder) ReadObjects(arg0, arg1, arg2 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadObjects", reflect.TypeOf((*MockClient)(nil).ReadObjects), arg0, arg1, arg2) +} + +// ReadStrings mocks base method. +func (m *MockClient) ReadStrings(arg0 db.Rows) ([]string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadStrings", arg0) + ret0, _ := ret[0].([]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReadStrings indicates an expected call of ReadStrings. +func (mr *MockClientMockRecorder) ReadStrings(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadStrings", reflect.TypeOf((*MockClient)(nil).ReadStrings), arg0) +} + +// Upsert mocks base method. +func (m *MockClient) Upsert(arg0 transaction.Client, arg1 *sql.Stmt, arg2 string, arg3 any, arg4 bool) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Upsert", arg0, arg1, arg2, arg3, arg4) + ret0, _ := ret[0].(error) + return ret0 +} + +// Upsert indicates an expected call of Upsert. +func (mr *MockClientMockRecorder) Upsert(arg0, arg1, arg2, arg3, arg4 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Upsert", reflect.TypeOf((*MockClient)(nil).Upsert), arg0, arg1, arg2, arg3, arg4) +} + +// WithTransaction mocks base method. +func (m *MockClient) WithTransaction(arg0 context.Context, arg1 bool, arg2 db.WithTransactionFunction) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "WithTransaction", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// WithTransaction indicates an expected call of WithTransaction. +func (mr *MockClientMockRecorder) WithTransaction(arg0, arg1, arg2 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WithTransaction", reflect.TypeOf((*MockClient)(nil).WithTransaction), arg0, arg1, arg2) +} diff --git a/pkg/sqlcache/store/store.go b/pkg/sqlcache/store/store.go index a228ee86..d76e0542 100644 --- a/pkg/sqlcache/store/store.go +++ b/pkg/sqlcache/store/store.go @@ -34,7 +34,7 @@ const ( // Store is a SQLite-backed cache.Store type Store struct { - DBClient + db.Client name string typ reflect.Type @@ -53,49 +53,37 @@ type Store struct { listStmt *sql.Stmt listKeysStmt *sql.Stmt - afterUpsert []func(key string, obj any, tx db.TXClient) error - afterDelete []func(key string, tx db.TXClient) error + afterUpsert []func(key string, obj any, tx transaction.Client) error + afterDelete []func(key string, tx transaction.Client) error } // Test that Store implements cache.Indexer var _ cache.Store = (*Store)(nil) -type DBClient interface { - BeginTx(ctx context.Context, forWriting bool) (db.TXClient, error) - Prepare(stmt string) *sql.Stmt - QueryForRows(ctx context.Context, stmt transaction.Stmt, params ...any) (*sql.Rows, error) - ReadObjects(rows db.Rows, typ reflect.Type, shouldDecrypt bool) ([]any, error) - ReadStrings(rows db.Rows) ([]string, error) - ReadInt(rows db.Rows) (int, error) - Upsert(tx db.TXClient, stmt *sql.Stmt, key string, obj any, shouldEncrypt bool) error - CloseStmt(closable db.Closable) error -} - // NewStore creates a SQLite-backed cache.Store for objects of the given example type -func NewStore(example any, keyFunc cache.KeyFunc, c DBClient, shouldEncrypt bool, name string) (*Store, error) { +func NewStore(example any, keyFunc cache.KeyFunc, c db.Client, shouldEncrypt bool, name string) (*Store, error) { s := &Store{ name: name, typ: reflect.TypeOf(example), - DBClient: c, + Client: c, keyFunc: keyFunc, shouldEncrypt: shouldEncrypt, - afterUpsert: []func(key string, obj any, tx db.TXClient) error{}, - afterDelete: []func(key string, tx db.TXClient) error{}, + afterUpsert: []func(key string, obj any, tx transaction.Client) error{}, + afterDelete: []func(key string, tx transaction.Client) error{}, } + dbName := db.Sanitize(s.name) + // once multiple informerfactories are needed, this can accept the case where table already exists error is received - txC, err := s.BeginTx(context.Background(), true) - if err != nil { - return nil, err - } - dbName := db.Sanitize(s.name) - createTableQuery := fmt.Sprintf(createTableFmt, dbName) - err = txC.Exec(createTableQuery) - if err != nil { - return nil, &db.QueryError{QueryString: createTableQuery, Err: err} - } + err := s.WithTransaction(context.Background(), true, func(tx transaction.Client) error { + createTableQuery := fmt.Sprintf(createTableFmt, dbName) + _, err := tx.Exec(createTableQuery) + if err != nil { + return &db.QueryError{QueryString: createTableQuery, Err: err} + } - err = txC.Commit() + return nil + }) if err != nil { return nil, err } @@ -118,42 +106,36 @@ func NewStore(example any, keyFunc cache.KeyFunc, c DBClient, shouldEncrypt bool /* Core methods */ // upsert saves an obj with its key, or updates key with obj if it exists in this Store func (s *Store) upsert(key string, obj any) error { - tx, err := s.BeginTx(context.Background(), true) - if err != nil { - return err - } + return s.WithTransaction(context.Background(), true, func(tx transaction.Client) error { + err := s.Upsert(tx, s.upsertStmt, key, obj, s.shouldEncrypt) + if err != nil { + return &db.QueryError{QueryString: s.upsertQuery, Err: err} + } - err = s.Upsert(tx, s.upsertStmt, key, obj, s.shouldEncrypt) - if err != nil { - return &db.QueryError{QueryString: s.upsertQuery, Err: err} - } + err = s.runAfterUpsert(key, obj, tx) + if err != nil { + return err + } - err = s.runAfterUpsert(key, obj, tx) - if err != nil { - return err - } - - return tx.Commit() + return nil + }) } // deleteByKey deletes the object associated with key, if it exists in this Store func (s *Store) deleteByKey(key string) error { - tx, err := s.BeginTx(context.Background(), true) - if err != nil { - return err - } + return s.WithTransaction(context.Background(), true, func(tx transaction.Client) error { + _, err := tx.Stmt(s.deleteStmt).Exec(key) + if err != nil { + return &db.QueryError{QueryString: s.deleteQuery, Err: err} + } - err = tx.StmtExec(tx.Stmt(s.deleteStmt), key) - if err != nil { - return &db.QueryError{QueryString: s.deleteQuery, Err: err} - } + err = s.runAfterDelete(key, tx) + if err != nil { + return err + } - err = s.runAfterDelete(key, tx) - if err != nil { - return err - } - - return tx.Commit() + return nil + }) } // GetByKey returns the object associated with the given object's key @@ -267,45 +249,42 @@ func (s *Store) Replace(objects []any, _ string) error { // replaceByKey will delete the contents of the Store, using instead the given key to obj map func (s *Store) replaceByKey(objects map[string]any) error { - txC, err := s.BeginTx(context.Background(), true) - if err != nil { - return err - } + return s.WithTransaction(context.Background(), true, func(txC transaction.Client) error { + txCListKeys := txC.Stmt(s.listKeysStmt) - txCListKeys := txC.Stmt(s.listKeysStmt) - - rows, err := s.QueryForRows(context.TODO(), txCListKeys) - if err != nil { - return err - } - keys, err := s.ReadStrings(rows) - if err != nil { - return err - } - - for _, key := range keys { - err = txC.StmtExec(txC.Stmt(s.deleteStmt), key) + rows, err := s.QueryForRows(context.TODO(), txCListKeys) if err != nil { return err } - err = s.runAfterDelete(key, txC) + keys, err := s.ReadStrings(rows) if err != nil { return err } - } - for key, obj := range objects { - err = s.Upsert(txC, s.upsertStmt, key, obj, s.shouldEncrypt) - if err != nil { - return err + for _, key := range keys { + _, err = txC.Stmt(s.deleteStmt).Exec(key) + if err != nil { + return err + } + err = s.runAfterDelete(key, txC) + if err != nil { + return err + } } - err = s.runAfterUpsert(key, obj, txC) - if err != nil { - return err - } - } - return txC.Commit() + for key, obj := range objects { + err = s.Upsert(txC, s.upsertStmt, key, obj, s.shouldEncrypt) + if err != nil { + return err + } + err = s.runAfterUpsert(key, obj, txC) + if err != nil { + return err + } + } + + return nil + }) } // Resync is a no-op and is deprecated @@ -316,7 +295,7 @@ func (s *Store) Resync() error { /* Utilities */ // RegisterAfterUpsert registers a func to be called after each upsert -func (s *Store) RegisterAfterUpsert(f func(key string, obj any, txC db.TXClient) error) { +func (s *Store) RegisterAfterUpsert(f func(key string, obj any, txC transaction.Client) error) { s.afterUpsert = append(s.afterUpsert, f) } @@ -334,7 +313,7 @@ func (s *Store) GetType() reflect.Type { // keep // runAfterUpsert executes functions registered to run after upsert -func (s *Store) runAfterUpsert(key string, obj any, txC db.TXClient) error { +func (s *Store) runAfterUpsert(key string, obj any, txC transaction.Client) error { for _, f := range s.afterUpsert { err := f(key, obj, txC) if err != nil { @@ -345,13 +324,13 @@ func (s *Store) runAfterUpsert(key string, obj any, txC db.TXClient) error { } // RegisterAfterDelete registers a func to be called after each deletion -func (s *Store) RegisterAfterDelete(f func(key string, txC db.TXClient) error) { +func (s *Store) RegisterAfterDelete(f func(key string, txC transaction.Client) error) { s.afterDelete = append(s.afterDelete, f) } // keep // runAfterDelete executes functions registered to run after upsert -func (s *Store) runAfterDelete(key string, txC db.TXClient) error { +func (s *Store) runAfterDelete(key string, txC transaction.Client) error { for _, f := range s.afterDelete { err := f(key, txC) if err != nil { diff --git a/pkg/sqlcache/store/store_mocks_test.go b/pkg/sqlcache/store/store_mocks_test.go deleted file mode 100644 index d30df82b..00000000 --- a/pkg/sqlcache/store/store_mocks_test.go +++ /dev/null @@ -1,165 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/rancher/steve/pkg/sqlcache/store (interfaces: DBClient) -// -// Generated by this command: -// -// mockgen --build_flags=--mod=mod -package store -destination ./store_mocks_test.go github.com/rancher/steve/pkg/sqlcache/store DBClient -// - -// Package store is a generated GoMock package. -package store - -import ( - context "context" - sql "database/sql" - reflect "reflect" - - db "github.com/rancher/steve/pkg/sqlcache/db" - transaction "github.com/rancher/steve/pkg/sqlcache/db/transaction" - gomock "go.uber.org/mock/gomock" -) - -// MockDBClient is a mock of DBClient interface. -type MockDBClient struct { - ctrl *gomock.Controller - recorder *MockDBClientMockRecorder -} - -// MockDBClientMockRecorder is the mock recorder for MockDBClient. -type MockDBClientMockRecorder struct { - mock *MockDBClient -} - -// NewMockDBClient creates a new mock instance. -func NewMockDBClient(ctrl *gomock.Controller) *MockDBClient { - mock := &MockDBClient{ctrl: ctrl} - mock.recorder = &MockDBClientMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockDBClient) EXPECT() *MockDBClientMockRecorder { - return m.recorder -} - -// BeginTx mocks base method. -func (m *MockDBClient) BeginTx(arg0 context.Context, arg1 bool) (db.TXClient, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "BeginTx", arg0, arg1) - ret0, _ := ret[0].(db.TXClient) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// BeginTx indicates an expected call of BeginTx. -func (mr *MockDBClientMockRecorder) BeginTx(arg0, arg1 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BeginTx", reflect.TypeOf((*MockDBClient)(nil).BeginTx), arg0, arg1) -} - -// CloseStmt mocks base method. -func (m *MockDBClient) CloseStmt(arg0 db.Closable) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CloseStmt", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// CloseStmt indicates an expected call of CloseStmt. -func (mr *MockDBClientMockRecorder) CloseStmt(arg0 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseStmt", reflect.TypeOf((*MockDBClient)(nil).CloseStmt), arg0) -} - -// Prepare mocks base method. -func (m *MockDBClient) Prepare(arg0 string) *sql.Stmt { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Prepare", arg0) - ret0, _ := ret[0].(*sql.Stmt) - return ret0 -} - -// Prepare indicates an expected call of Prepare. -func (mr *MockDBClientMockRecorder) Prepare(arg0 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Prepare", reflect.TypeOf((*MockDBClient)(nil).Prepare), arg0) -} - -// QueryForRows mocks base method. -func (m *MockDBClient) QueryForRows(arg0 context.Context, arg1 transaction.Stmt, arg2 ...any) (*sql.Rows, error) { - m.ctrl.T.Helper() - varargs := []any{arg0, arg1} - for _, a := range arg2 { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "QueryForRows", varargs...) - ret0, _ := ret[0].(*sql.Rows) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// QueryForRows indicates an expected call of QueryForRows. -func (mr *MockDBClientMockRecorder) QueryForRows(arg0, arg1 any, arg2 ...any) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0, arg1}, arg2...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryForRows", reflect.TypeOf((*MockDBClient)(nil).QueryForRows), varargs...) -} - -// ReadInt mocks base method. -func (m *MockDBClient) ReadInt(arg0 db.Rows) (int, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ReadInt", arg0) - ret0, _ := ret[0].(int) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// ReadInt indicates an expected call of ReadInt. -func (mr *MockDBClientMockRecorder) ReadInt(arg0 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadInt", reflect.TypeOf((*MockDBClient)(nil).ReadInt), arg0) -} - -// ReadObjects mocks base method. -func (m *MockDBClient) ReadObjects(arg0 db.Rows, arg1 reflect.Type, arg2 bool) ([]any, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ReadObjects", arg0, arg1, arg2) - ret0, _ := ret[0].([]any) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// ReadObjects indicates an expected call of ReadObjects. -func (mr *MockDBClientMockRecorder) ReadObjects(arg0, arg1, arg2 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadObjects", reflect.TypeOf((*MockDBClient)(nil).ReadObjects), arg0, arg1, arg2) -} - -// ReadStrings mocks base method. -func (m *MockDBClient) ReadStrings(arg0 db.Rows) ([]string, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ReadStrings", arg0) - ret0, _ := ret[0].([]string) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// ReadStrings indicates an expected call of ReadStrings. -func (mr *MockDBClientMockRecorder) ReadStrings(arg0 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadStrings", reflect.TypeOf((*MockDBClient)(nil).ReadStrings), arg0) -} - -// Upsert mocks base method. -func (m *MockDBClient) Upsert(arg0 db.TXClient, arg1 *sql.Stmt, arg2 string, arg3 any, arg4 bool) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Upsert", arg0, arg1, arg2, arg3, arg4) - ret0, _ := ret[0].(error) - return ret0 -} - -// Upsert indicates an expected call of Upsert. -func (mr *MockDBClientMockRecorder) Upsert(arg0, arg1, arg2, arg3, arg4 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Upsert", reflect.TypeOf((*MockDBClient)(nil).Upsert), arg0, arg1, arg2, arg3, arg4) -} diff --git a/pkg/sqlcache/store/store_test.go b/pkg/sqlcache/store/store_test.go index 1d4e2613..e1bff33b 100644 --- a/pkg/sqlcache/store/store_test.go +++ b/pkg/sqlcache/store/store_test.go @@ -7,9 +7,8 @@ Adapted from client-go, Copyright 2014 The Kubernetes Authors. package store // Mocks for this test are generated with the following command. -//go:generate mockgen --build_flags=--mod=mod -package store -destination ./store_mocks_test.go github.com/rancher/steve/pkg/sqlcache/store DBClient -//go:generate mockgen --build_flags=--mod=mod -package store -destination ./db_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db TXClient,Rows -//go:generate mockgen --build_flags=--mod=mod -package store -destination ./tx_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db/transaction Stmt +//go:generate mockgen --build_flags=--mod=mod -package store -destination ./db_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db Rows,Client +//go:generate mockgen --build_flags=--mod=mod -package store -destination ./transaction_mocks_test.go -mock_names Client=MockTXClient github.com/rancher/steve/pkg/sqlcache/db/transaction Stmt,Client import ( "context" @@ -19,6 +18,7 @@ import ( "testing" "github.com/rancher/steve/pkg/sqlcache/db" + "github.com/rancher/steve/pkg/sqlcache/db/transaction" "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" @@ -44,27 +44,39 @@ func TestAdd(t *testing.T) { var tests []testCase // Tests with shouldEncryptSet to false - tests = append(tests, testCase{description: "Add with no DB Client errors", test: func(t *testing.T, shouldEncrypt bool) { + tests = append(tests, testCase{description: "Add with no DB client errors", test: func(t *testing.T, shouldEncrypt bool) { c, txC := SetupMockDB(t) store := SetupStore(t, c, shouldEncrypt) - c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + c.EXPECT().Upsert(txC, store.upsertStmt, "something", testObject, store.shouldEncrypt).Return(nil) - txC.EXPECT().Commit().Return(nil) + c.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(nil).Do( + func(ctx context.Context, shouldEncrypt bool, f db.WithTransactionFunction) { + err := f(txC) + if err != nil { + t.Fail() + } + }) + err := store.Add(testObject) assert.Nil(t, err) - // dbclient beginerr }, }) - tests = append(tests, testCase{description: "Add with no DB Client errors and an afterUpsert function", test: func(t *testing.T, shouldEncrypt bool) { + tests = append(tests, testCase{description: "Add with no DB client errors and an afterUpsert function", test: func(t *testing.T, shouldEncrypt bool) { c, txC := SetupMockDB(t) store := SetupStore(t, c, shouldEncrypt) - c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) - txC.EXPECT().Commit().Return(nil) + c.EXPECT().Upsert(txC, store.upsertStmt, "something", testObject, store.shouldEncrypt).Return(nil) + c.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(nil).Do( + func(ctx context.Context, shouldEncrypt bool, f db.WithTransactionFunction) { + err := f(txC) + if err != nil { + t.Fail() + } + }) var count int - store.afterUpsert = append(store.afterUpsert, func(key string, object any, tx db.TXClient) error { + store.afterUpsert = append(store.afterUpsert, func(key string, object any, tx transaction.Client) error { count++ return nil }) @@ -74,53 +86,63 @@ func TestAdd(t *testing.T) { }, }) - tests = append(tests, testCase{description: "Add with no DB Client errors and an afterUpsert function that returns error", test: func(t *testing.T, shouldEncrypt bool) { + tests = append(tests, testCase{description: "Add with no DB client errors and an afterUpsert function that returns error", test: func(t *testing.T, shouldEncrypt bool) { c, txC := SetupMockDB(t) store := SetupStore(t, c, shouldEncrypt) - c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + c.EXPECT().Upsert(txC, store.upsertStmt, "something", testObject, store.shouldEncrypt).Return(nil) - store.afterUpsert = append(store.afterUpsert, func(key string, object any, txC db.TXClient) error { + c.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(fmt.Errorf("error")).Do( + func(ctx context.Context, shouldEncrypt bool, f db.WithTransactionFunction) { + err := f(txC) + if err == nil { + t.Fail() + } + }) + + store.afterUpsert = append(store.afterUpsert, func(key string, object any, txC transaction.Client) error { return fmt.Errorf("error") }) err := store.Add(testObject) assert.NotNil(t, err) - // dbclient beginerr }, }) - tests = append(tests, testCase{description: "Add with DB Client BeginTx(gomock.Any(), true) error", test: func(t *testing.T, shouldEncrypt bool) { + tests = append(tests, testCase{description: "Add with DB client WithTransaction error", test: func(t *testing.T, shouldEncrypt bool) { c, _ := SetupMockDB(t) - c.EXPECT().BeginTx(gomock.Any(), true).Return(nil, fmt.Errorf("failed")) + c.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(fmt.Errorf("failed")) store := SetupStore(t, c, shouldEncrypt) err := store.Add(testObject) assert.NotNil(t, err) }}) - tests = append(tests, testCase{description: "Add with DB Client Upsert() error", test: func(t *testing.T, shouldEncrypt bool) { + tests = append(tests, testCase{description: "Add with DB client Upsert() error", test: func(t *testing.T, shouldEncrypt bool) { c, txC := SetupMockDB(t) store := SetupStore(t, c, shouldEncrypt) - c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) c.EXPECT().Upsert(txC, store.upsertStmt, "something", testObject, store.shouldEncrypt).Return(fmt.Errorf("failed")) + c.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(fmt.Errorf("failed")).Do( + func(ctx context.Context, shouldEncrypt bool, f db.WithTransactionFunction) { + err := f(txC) + if err == nil { + t.Fail() + } + }) err := store.Add(testObject) assert.NotNil(t, err) }}) - tests = append(tests, testCase{description: "Add with DB Client Upsert() error with following Rollback() error", test: func(t *testing.T, shouldEncrypt bool) { + tests = append(tests, testCase{description: "Add with DB client Commit() error", test: func(t *testing.T, shouldEncrypt bool) { c, txC := SetupMockDB(t) store := SetupStore(t, c, shouldEncrypt) - c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) - c.EXPECT().Upsert(txC, store.upsertStmt, "something", testObject, store.shouldEncrypt).Return(fmt.Errorf("failed")) - err := store.Add(testObject) - assert.NotNil(t, err) - }}) - tests = append(tests, testCase{description: "Add with DB Client Commit() error", test: func(t *testing.T, shouldEncrypt bool) { - c, txC := SetupMockDB(t) - store := SetupStore(t, c, shouldEncrypt) - c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) c.EXPECT().Upsert(txC, store.upsertStmt, "something", testObject, store.shouldEncrypt).Return(nil) - txC.EXPECT().Commit().Return(fmt.Errorf("failed")) + c.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(fmt.Errorf("failed")).Do( + func(ctx context.Context, shouldEncrypt bool, f db.WithTransactionFunction) { + err := f(txC) + if err != nil { + t.Fail() + } + }) err := store.Add(testObject) assert.NotNil(t, err) @@ -145,27 +167,38 @@ func TestUpdate(t *testing.T) { var tests []testCase // Tests with shouldEncryptSet to false - tests = append(tests, testCase{description: "Update with no DB Client errors", test: func(t *testing.T, shouldEncrypt bool) { + tests = append(tests, testCase{description: "Update with no DB client errors", test: func(t *testing.T, shouldEncrypt bool) { c, txC := SetupMockDB(t) store := SetupStore(t, c, shouldEncrypt) - c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + c.EXPECT().Upsert(txC, store.upsertStmt, "something", testObject, store.shouldEncrypt).Return(nil) - txC.EXPECT().Commit().Return(nil) + c.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(nil).Do( + func(ctx context.Context, shouldEncrypt bool, f db.WithTransactionFunction) { + err := f(txC) + if err != nil { + t.Fail() + } + }) err := store.Update(testObject) assert.Nil(t, err) - // dbclient beginerr }, }) - tests = append(tests, testCase{description: "Update with no DB Client errors and an afterUpsert function", test: func(t *testing.T, shouldEncrypt bool) { + tests = append(tests, testCase{description: "Update with no DB client errors and an afterUpsert function", test: func(t *testing.T, shouldEncrypt bool) { c, txC := SetupMockDB(t) store := SetupStore(t, c, shouldEncrypt) - c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + c.EXPECT().Upsert(txC, store.upsertStmt, "something", testObject, store.shouldEncrypt).Return(nil) - txC.EXPECT().Commit().Return(nil) + c.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(nil).Do( + func(ctx context.Context, shouldEncrypt bool, f db.WithTransactionFunction) { + err := f(txC) + if err != nil { + t.Fail() + } + }) var count int - store.afterUpsert = append(store.afterUpsert, func(key string, object any, txC db.TXClient) error { + store.afterUpsert = append(store.afterUpsert, func(key string, object any, txC transaction.Client) error { count++ return nil }) @@ -175,13 +208,21 @@ func TestUpdate(t *testing.T) { }, }) - tests = append(tests, testCase{description: "Update with no DB Client errors and an afterUpsert function that returns error", test: func(t *testing.T, shouldEncrypt bool) { + tests = append(tests, testCase{description: "Update with no DB client errors and an afterUpsert function that returns error", test: func(t *testing.T, shouldEncrypt bool) { c, txC := SetupMockDB(t) store := SetupStore(t, c, shouldEncrypt) - c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + c.EXPECT().Upsert(txC, store.upsertStmt, "something", testObject, store.shouldEncrypt).Return(nil) - store.afterUpsert = append(store.afterUpsert, func(key string, object any, txC db.TXClient) error { + c.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(fmt.Errorf("error")).Do( + func(ctx context.Context, shouldEncrypt bool, f db.WithTransactionFunction) { + err := f(txC) + if err == nil { + t.Fail() + } + }) + + store.afterUpsert = append(store.afterUpsert, func(key string, object any, txC transaction.Client) error { return fmt.Errorf("error") }) err := store.Update(testObject) @@ -189,40 +230,28 @@ func TestUpdate(t *testing.T) { }, }) - tests = append(tests, testCase{description: "Update with DB Client BeginTx(gomock.Any(), true) error", test: func(t *testing.T, shouldEncrypt bool) { + tests = append(tests, testCase{description: "Update with DB client WithTransaction returning error", test: func(t *testing.T, shouldEncrypt bool) { c, _ := SetupMockDB(t) - c.EXPECT().BeginTx(gomock.Any(), true).Return(nil, fmt.Errorf("failed")) + + c.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(fmt.Errorf("error")) store := SetupStore(t, c, shouldEncrypt) err := store.Update(testObject) assert.NotNil(t, err) }}) - tests = append(tests, testCase{description: "Update with DB Client Upsert() error", test: func(t *testing.T, shouldEncrypt bool) { + tests = append(tests, testCase{description: "Update with DB client Upsert() error", test: func(t *testing.T, shouldEncrypt bool) { c, txC := SetupMockDB(t) store := SetupStore(t, c, shouldEncrypt) - c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + c.EXPECT().Upsert(txC, store.upsertStmt, "something", testObject, store.shouldEncrypt).Return(fmt.Errorf("failed")) - err := store.Update(testObject) - assert.NotNil(t, err) - }}) - - tests = append(tests, testCase{description: "Update with DB Client Upsert() error with following Rollback() error", test: func(t *testing.T, shouldEncrypt bool) { - c, txC := SetupMockDB(t) - store := SetupStore(t, c, shouldEncrypt) - c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) - c.EXPECT().Upsert(txC, store.upsertStmt, "something", testObject, store.shouldEncrypt).Return(fmt.Errorf("failed")) - err := store.Update(testObject) - assert.NotNil(t, err) - }}) - - tests = append(tests, testCase{description: "Update with DB Client Commit() error", test: func(t *testing.T, shouldEncrypt bool) { - c, txC := SetupMockDB(t) - store := SetupStore(t, c, shouldEncrypt) - c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) - c.EXPECT().Upsert(txC, store.upsertStmt, "something", testObject, store.shouldEncrypt).Return(nil) - txC.EXPECT().Commit().Return(fmt.Errorf("failed")) - + c.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(fmt.Errorf("error")).Do( + func(ctx context.Context, shouldEncrypt bool, f db.WithTransactionFunction) { + err := f(txC) + if err == nil { + t.Fail() + } + }) err := store.Update(testObject) assert.NotNil(t, err) }}) @@ -246,47 +275,67 @@ func TestDelete(t *testing.T) { var tests []testCase // Tests with shouldEncryptSet to false - tests = append(tests, testCase{description: "Delete with no DB Client errors", test: func(t *testing.T, shouldEncrypt bool) { + tests = append(tests, testCase{description: "Delete with no DB client errors", test: func(t *testing.T, shouldEncrypt bool) { c, txC := SetupMockDB(t) store := SetupStore(t, c, shouldEncrypt) - c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) - // deleteStmt here will be an empty string since Prepare mock returns an empty *sql.Stmt - txC.EXPECT().Stmt(store.deleteStmt).Return(store.deleteStmt) - txC.EXPECT().StmtExec(store.deleteStmt, testObject.Id).Return(nil) - txC.EXPECT().Commit().Return(nil) + stmt := NewMockStmt(gomock.NewController(t)) + txC.EXPECT().Stmt(store.deleteStmt).Return(stmt) + stmt.EXPECT().Exec(testObject.Id).Return(nil, nil) + + c.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(nil).Do( + func(ctx context.Context, shouldEncrypt bool, f db.WithTransactionFunction) { + err := f(txC) + if err != nil { + t.Fail() + } + }) + err := store.Delete(testObject) assert.Nil(t, err) }, }) - tests = append(tests, testCase{description: "Delete with DB Client BeginTx(gomock.Any(), true) error", test: func(t *testing.T, shouldEncrypt bool) { + tests = append(tests, testCase{description: "Delete with DB client WithTransaction returning error", test: func(t *testing.T, shouldEncrypt bool) { c, _ := SetupMockDB(t) store := SetupStore(t, c, shouldEncrypt) - c.EXPECT().BeginTx(gomock.Any(), true).Return(nil, fmt.Errorf("error")) - // deleteStmt here will be an empty string since Prepare mock returns an empty *sql.Stmt + c.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(fmt.Errorf("error")) err := store.Delete(testObject) assert.NotNil(t, err) }, }) - tests = append(tests, testCase{description: "Delete with TX Client StmtExec() error", test: func(t *testing.T, shouldEncrypt bool) { + tests = append(tests, testCase{description: "Delete with TX client Exec() error", test: func(t *testing.T, shouldEncrypt bool) { c, txC := SetupMockDB(t) store := SetupStore(t, c, shouldEncrypt) - c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) - txC.EXPECT().Stmt(store.deleteStmt).Return(store.deleteStmt) - txC.EXPECT().StmtExec(store.deleteStmt, testObject.Id).Return(fmt.Errorf("error")) - // deleteStmt here will be an empty string since Prepare mock returns an empty *sql.Stmt + stmt := NewMockStmt(gomock.NewController(t)) + txC.EXPECT().Stmt(store.deleteStmt).Return(stmt) + stmt.EXPECT().Exec(testObject.Id).Return(nil, fmt.Errorf("error")) + + c.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(fmt.Errorf("error")).Do( + func(ctx context.Context, shouldEncrypt bool, f db.WithTransactionFunction) { + err := f(txC) + if err == nil { + t.Fail() + } + }) + err := store.Delete(testObject) assert.NotNil(t, err) }, }) - tests = append(tests, testCase{description: "Delete with DB Client Commit() error", test: func(t *testing.T, shouldEncrypt bool) { + tests = append(tests, testCase{description: "Delete with DB client Commit() error", test: func(t *testing.T, shouldEncrypt bool) { c, txC := SetupMockDB(t) store := SetupStore(t, c, shouldEncrypt) - c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) - // deleteStmt here will be an empty string since Prepare mock returns an empty *sql.Stmt - txC.EXPECT().Stmt(store.deleteStmt).Return(store.deleteStmt) - // tx.EXPECT(). - txC.EXPECT().StmtExec(store.deleteStmt, testObject.Id).Return(nil) - txC.EXPECT().Commit().Return(fmt.Errorf("error")) + stmt := NewMockStmt(gomock.NewController(t)) + txC.EXPECT().Stmt(store.deleteStmt).Return(stmt) + stmt.EXPECT().Exec(testObject.Id).Return(nil, nil) + + c.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(fmt.Errorf("error")).Do( + func(ctx context.Context, shouldEncrypt bool, f db.WithTransactionFunction) { + err := f(txC) + if err != nil { + t.Fail() + } + }) + err := store.Delete(testObject) assert.NotNil(t, err) }, @@ -309,7 +358,7 @@ func TestList(t *testing.T) { var tests []testCase - tests = append(tests, testCase{description: "List with no DB Client errors and no items", test: func(t *testing.T, shouldEncrypt bool) { + tests = append(tests, testCase{description: "List with no DB client errors and no items", test: func(t *testing.T, shouldEncrypt bool) { c, _ := SetupMockDB(t) store := SetupStore(t, c, shouldEncrypt) r := &sql.Rows{} @@ -319,7 +368,7 @@ func TestList(t *testing.T) { assert.Len(t, items, 0) }, }) - tests = append(tests, testCase{description: "List with no DB Client errors and some items", test: func(t *testing.T, shouldEncrypt bool) { + tests = append(tests, testCase{description: "List with no DB client errors and some items", test: func(t *testing.T, shouldEncrypt bool) { c, _ := SetupMockDB(t) store := SetupStore(t, c, shouldEncrypt) fakeItemsToReturn := []any{"something1", 2, false} @@ -330,7 +379,7 @@ func TestList(t *testing.T) { assert.Equal(t, fakeItemsToReturn, items) }, }) - tests = append(tests, testCase{description: "List with DB Client ReadObjects() error", test: func(t *testing.T, shouldEncrypt bool) { + tests = append(tests, testCase{description: "List with DB client ReadObjects() error", test: func(t *testing.T, shouldEncrypt bool) { c, _ := SetupMockDB(t) store := SetupStore(t, c, shouldEncrypt) r := &sql.Rows{} @@ -359,7 +408,7 @@ func TestListKeys(t *testing.T) { var tests []testCase - tests = append(tests, testCase{description: "ListKeys with no DB Client errors and some items", test: func(t *testing.T, shouldEncrypt bool) { + tests = append(tests, testCase{description: "ListKeys with no DB client errors and some items", test: func(t *testing.T, shouldEncrypt bool) { c, _ := SetupMockDB(t) store := SetupStore(t, c, shouldEncrypt) r := &sql.Rows{} @@ -370,7 +419,7 @@ func TestListKeys(t *testing.T) { }, }) - tests = append(tests, testCase{description: "ListKeys with DB Client ReadStrings() error", test: func(t *testing.T, shouldEncrypt bool) { + tests = append(tests, testCase{description: "ListKeys with DB client ReadStrings() error", test: func(t *testing.T, shouldEncrypt bool) { c, _ := SetupMockDB(t) store := SetupStore(t, c, shouldEncrypt) r := &sql.Rows{} @@ -396,7 +445,7 @@ func TestGet(t *testing.T) { var tests []testCase testObject := testStoreObject{Id: "something", Val: "a"} - tests = append(tests, testCase{description: "Get with no DB Client errors and object exists", test: func(t *testing.T, shouldEncrypt bool) { + tests = append(tests, testCase{description: "Get with no DB client errors and object exists", test: func(t *testing.T, shouldEncrypt bool) { c, _ := SetupMockDB(t) store := SetupStore(t, c, shouldEncrypt) r := &sql.Rows{} @@ -408,7 +457,7 @@ func TestGet(t *testing.T) { assert.True(t, exists) }, }) - tests = append(tests, testCase{description: "Get with no DB Client errors and object does not exist", test: func(t *testing.T, shouldEncrypt bool) { + tests = append(tests, testCase{description: "Get with no DB client errors and object does not exist", test: func(t *testing.T, shouldEncrypt bool) { c, _ := SetupMockDB(t) store := SetupStore(t, c, shouldEncrypt) r := &sql.Rows{} @@ -420,7 +469,7 @@ func TestGet(t *testing.T) { assert.False(t, exists) }, }) - tests = append(tests, testCase{description: "Get with DB Client ReadObjects() error", test: func(t *testing.T, shouldEncrypt bool) { + tests = append(tests, testCase{description: "Get with DB client ReadObjects() error", test: func(t *testing.T, shouldEncrypt bool) { c, _ := SetupMockDB(t) store := SetupStore(t, c, shouldEncrypt) r := &sql.Rows{} @@ -446,7 +495,7 @@ func TestGetByKey(t *testing.T) { var tests []testCase testObject := testStoreObject{Id: "something", Val: "a"} - tests = append(tests, testCase{description: "GetByKey with no DB Client errors and item exists", test: func(t *testing.T, shouldEncrypt bool) { + tests = append(tests, testCase{description: "GetByKey with no DB client errors and item exists", test: func(t *testing.T, shouldEncrypt bool) { c, _ := SetupMockDB(t) store := SetupStore(t, c, shouldEncrypt) r := &sql.Rows{} @@ -458,7 +507,7 @@ func TestGetByKey(t *testing.T) { assert.True(t, exists) }, }) - tests = append(tests, testCase{description: "GetByKey with no DB Client errors and item does not exist", test: func(t *testing.T, shouldEncrypt bool) { + tests = append(tests, testCase{description: "GetByKey with no DB client errors and item does not exist", test: func(t *testing.T, shouldEncrypt bool) { c, _ := SetupMockDB(t) store := SetupStore(t, c, shouldEncrypt) r := &sql.Rows{} @@ -470,7 +519,7 @@ func TestGetByKey(t *testing.T) { assert.False(t, exists) }, }) - tests = append(tests, testCase{description: "GetByKey with DB Client ReadObjects() error", test: func(t *testing.T, shouldEncrypt bool) { + tests = append(tests, testCase{description: "GetByKey with DB client ReadObjects() error", test: func(t *testing.T, shouldEncrypt bool) { c, _ := SetupMockDB(t) store := SetupStore(t, c, shouldEncrypt) r := &sql.Rows{} @@ -498,93 +547,128 @@ func TestReplace(t *testing.T) { var tests []testCase testObject := testStoreObject{Id: "something", Val: "a"} - tests = append(tests, testCase{description: "Replace with no DB Client errors and some items", test: func(t *testing.T, shouldEncrypt bool) { + tests = append(tests, testCase{description: "Replace with no DB client errors and some items", test: func(t *testing.T, shouldEncrypt bool) { c, txC := SetupMockDB(t) store := SetupStore(t, c, shouldEncrypt) r := &sql.Rows{} - c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) - txC.EXPECT().Stmt(store.listKeysStmt).Return(store.listKeysStmt) - c.EXPECT().QueryForRows(context.TODO(), store.listKeysStmt).Return(r, nil) + stmt := NewMockStmt(gomock.NewController(t)) + + txC.EXPECT().Stmt(store.listKeysStmt).Return(stmt) + c.EXPECT().QueryForRows(context.TODO(), stmt).Return(r, nil) c.EXPECT().ReadStrings(r).Return([]string{testObject.Id}, nil) - txC.EXPECT().Stmt(store.deleteStmt).Return(store.deleteStmt) - txC.EXPECT().StmtExec(store.deleteStmt, testObject.Id) + txC.EXPECT().Stmt(store.deleteStmt).Return(stmt) + stmt.EXPECT().Exec(testObject.Id) c.EXPECT().Upsert(txC, store.upsertStmt, testObject.Id, testObject, store.shouldEncrypt) - txC.EXPECT().Commit() + + c.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(nil).Do( + func(ctx context.Context, shouldEncrypt bool, f db.WithTransactionFunction) { + err := f(txC) + if err != nil { + t.Fail() + } + }) + err := store.Replace([]any{testObject}, testObject.Id) assert.Nil(t, err) }, }) - tests = append(tests, testCase{description: "Replace with no DB Client errors and no items", test: func(t *testing.T, shouldEncrypt bool) { + tests = append(tests, testCase{description: "Replace with no DB client errors and no items", test: func(t *testing.T, shouldEncrypt bool) { c, txC := SetupMockDB(t) store := SetupStore(t, c, shouldEncrypt) r := &sql.Rows{} - c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + txC.EXPECT().Stmt(store.listKeysStmt).Return(store.listKeysStmt) c.EXPECT().QueryForRows(context.TODO(), store.listKeysStmt).Return(r, nil) c.EXPECT().ReadStrings(r).Return([]string{}, nil) c.EXPECT().Upsert(txC, store.upsertStmt, testObject.Id, testObject, store.shouldEncrypt) - txC.EXPECT().Commit() + + c.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(nil).Do( + func(ctx context.Context, shouldEncrypt bool, f db.WithTransactionFunction) { + err := f(txC) + if err != nil { + t.Fail() + } + }) + err := store.Replace([]any{testObject}, testObject.Id) assert.Nil(t, err) }, }) - tests = append(tests, testCase{description: "Replace with DB Client BeginTx(gomock.Any(), true) error", test: func(t *testing.T, shouldEncrypt bool) { + tests = append(tests, testCase{description: "Replace with DB client WithTransaction returning error", test: func(t *testing.T, shouldEncrypt bool) { c, _ := SetupMockDB(t) store := SetupStore(t, c, shouldEncrypt) - c.EXPECT().BeginTx(gomock.Any(), true).Return(nil, fmt.Errorf("error")) + c.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(fmt.Errorf("error")) err := store.Replace([]any{testObject}, testObject.Id) assert.NotNil(t, err) }, }) - tests = append(tests, testCase{description: "Replace with no DB Client ReadStrings() error", test: func(t *testing.T, shouldEncrypt bool) { + tests = append(tests, testCase{description: "Replace with DB client ReadStrings() error", test: func(t *testing.T, shouldEncrypt bool) { c, txC := SetupMockDB(t) store := SetupStore(t, c, shouldEncrypt) r := &sql.Rows{} - c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + txC.EXPECT().Stmt(store.listKeysStmt).Return(store.listKeysStmt) c.EXPECT().QueryForRows(context.TODO(), store.listKeysStmt).Return(r, nil) c.EXPECT().ReadStrings(r).Return(nil, fmt.Errorf("error")) + + c.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(fmt.Errorf("error")).Do( + func(ctx context.Context, shouldEncrypt bool, f db.WithTransactionFunction) { + err := f(txC) + if err == nil { + t.Fail() + } + }) + err := store.Replace([]any{testObject}, testObject.Id) assert.NotNil(t, err) }, }) - tests = append(tests, testCase{description: "Replace with ReadStrings() error", test: func(t *testing.T, shouldEncrypt bool) { + tests = append(tests, testCase{description: "Replace with TX client StmtExec() error", test: func(t *testing.T, shouldEncrypt bool) { c, txC := SetupMockDB(t) store := SetupStore(t, c, shouldEncrypt) r := &sql.Rows{} - c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) - txC.EXPECT().Stmt(store.listKeysStmt).Return(store.listKeysStmt) - c.EXPECT().QueryForRows(context.TODO(), store.listKeysStmt).Return(r, nil) - c.EXPECT().ReadStrings(r).Return(nil, fmt.Errorf("error")) - err := store.Replace([]any{testObject}, testObject.Id) - assert.NotNil(t, err) - }, - }) - tests = append(tests, testCase{description: "Replace with TX Client StmtExec() error", test: func(t *testing.T, shouldEncrypt bool) { - c, txC := SetupMockDB(t) - store := SetupStore(t, c, shouldEncrypt) - r := &sql.Rows{} - c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) - txC.EXPECT().Stmt(store.listKeysStmt).Return(store.listKeysStmt) - c.EXPECT().QueryForRows(context.TODO(), store.listKeysStmt).Return(r, nil) + listKeysStmt := NewMockStmt(gomock.NewController(t)) + deleteStmt := NewMockStmt(gomock.NewController(t)) + + txC.EXPECT().Stmt(store.listKeysStmt).Return(listKeysStmt) + c.EXPECT().QueryForRows(context.TODO(), listKeysStmt).Return(r, nil) c.EXPECT().ReadStrings(r).Return([]string{testObject.Id}, nil) - txC.EXPECT().Stmt(store.deleteStmt).Return(store.deleteStmt) - txC.EXPECT().StmtExec(store.deleteStmt, testObject.Id).Return(fmt.Errorf("error")) + txC.EXPECT().Stmt(store.deleteStmt).Return(deleteStmt) + deleteStmt.EXPECT().Exec(testObject.Id).Return(nil, fmt.Errorf("error")) + c.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(fmt.Errorf("error")).Do( + func(ctx context.Context, shouldEncrypt bool, f db.WithTransactionFunction) { + err := f(txC) + if err == nil { + t.Fail() + } + }) + err := store.Replace([]any{testObject}, testObject.Id) assert.NotNil(t, err) }, }) - tests = append(tests, testCase{description: "Replace with DB Client Upsert() error", test: func(t *testing.T, shouldEncrypt bool) { + tests = append(tests, testCase{description: "Replace with DB client Upsert() error", test: func(t *testing.T, shouldEncrypt bool) { c, txC := SetupMockDB(t) store := SetupStore(t, c, shouldEncrypt) r := &sql.Rows{} - c.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) - txC.EXPECT().Stmt(store.listKeysStmt).Return(store.listKeysStmt) - c.EXPECT().QueryForRows(context.TODO(), store.listKeysStmt).Return(r, nil) + listKeysStmt := NewMockStmt(gomock.NewController(t)) + deleteStmt := NewMockStmt(gomock.NewController(t)) + + txC.EXPECT().Stmt(store.listKeysStmt).Return(listKeysStmt) + c.EXPECT().QueryForRows(context.TODO(), listKeysStmt).Return(r, nil) c.EXPECT().ReadStrings(r).Return([]string{testObject.Id}, nil) - txC.EXPECT().Stmt(store.deleteStmt).Return(store.deleteStmt) - txC.EXPECT().StmtExec(store.deleteStmt, testObject.Id).Return(nil) + txC.EXPECT().Stmt(store.deleteStmt).Return(deleteStmt) + deleteStmt.EXPECT().Exec(testObject.Id).Return(nil, nil) c.EXPECT().Upsert(txC, store.upsertStmt, testObject.Id, testObject, store.shouldEncrypt).Return(fmt.Errorf("error")) + + c.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(fmt.Errorf("error")).Do( + func(ctx context.Context, shouldEncrypt bool, f db.WithTransactionFunction) { + err := f(txC) + if err == nil { + t.Fail() + } + }) + err := store.Replace([]any{testObject}, testObject.Id) assert.NotNil(t, err) }, @@ -620,13 +704,17 @@ func TestResync(t *testing.T) { } } -func SetupMockDB(t *testing.T) (*MockDBClient, *MockTXClient) { - dbC := NewMockDBClient(gomock.NewController(t)) // add functionality once store expectation are known +func SetupMockDB(t *testing.T) (*MockClient, *MockTXClient) { + dbC := NewMockClient(gomock.NewController(t)) // add functionality once store expectation are known txC := NewMockTXClient(gomock.NewController(t)) - // stmt := NewMockStmt(gomock.NewController()) - txC.EXPECT().Exec(fmt.Sprintf(createTableFmt, "testStoreObject")).Return(nil) - txC.EXPECT().Commit().Return(nil) - dbC.EXPECT().BeginTx(gomock.Any(), true).Return(txC, nil) + txC.EXPECT().Exec(fmt.Sprintf(createTableFmt, "testStoreObject")).Return(nil, nil) + dbC.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(nil).Do( + func(ctx context.Context, shouldEncrypt bool, f db.WithTransactionFunction) { + err := f(txC) + if err != nil { + t.Fail() + } + }) // use stmt mock here dbC.EXPECT().Prepare(fmt.Sprintf(upsertStmtFmt, "testStoreObject")).Return(&sql.Stmt{}) @@ -637,7 +725,7 @@ func SetupMockDB(t *testing.T) (*MockDBClient, *MockTXClient) { return dbC, txC } -func SetupStore(t *testing.T, client *MockDBClient, shouldEncrypt bool) *Store { +func SetupStore(t *testing.T, client *MockClient, shouldEncrypt bool) *Store { store, err := NewStore(testStoreObject{}, testStoreKeyFunc, client, shouldEncrypt, "testStoreObject") if err != nil { t.Error(err) diff --git a/pkg/sqlcache/store/tx_mocks_test.go b/pkg/sqlcache/store/transaction_mocks_test.go similarity index 58% rename from pkg/sqlcache/store/tx_mocks_test.go rename to pkg/sqlcache/store/transaction_mocks_test.go index 0c05ab7f..85a3e177 100644 --- a/pkg/sqlcache/store/tx_mocks_test.go +++ b/pkg/sqlcache/store/transaction_mocks_test.go @@ -1,9 +1,9 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/rancher/steve/pkg/sqlcache/db/transaction (interfaces: Stmt) +// Source: github.com/rancher/steve/pkg/sqlcache/db/transaction (interfaces: Stmt,Client) // // Generated by this command: // -// mockgen --build_flags=--mod=mod -package store -destination ./tx_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db/transaction Stmt +// mockgen --build_flags=--mod=mod -package store -destination ./transaction_mocks_test.go -mock_names Client=MockTXClient github.com/rancher/steve/pkg/sqlcache/db/transaction Stmt,Client // // Package store is a generated GoMock package. @@ -14,6 +14,7 @@ import ( sql "database/sql" reflect "reflect" + transaction "github.com/rancher/steve/pkg/sqlcache/db/transaction" gomock "go.uber.org/mock/gomock" ) @@ -97,3 +98,60 @@ func (mr *MockStmtMockRecorder) QueryContext(arg0 any, arg1 ...any) *gomock.Call varargs := append([]any{arg0}, arg1...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryContext", reflect.TypeOf((*MockStmt)(nil).QueryContext), varargs...) } + +// MockTXClient is a mock of Client interface. +type MockTXClient struct { + ctrl *gomock.Controller + recorder *MockTXClientMockRecorder +} + +// MockTXClientMockRecorder is the mock recorder for MockTXClient. +type MockTXClientMockRecorder struct { + mock *MockTXClient +} + +// NewMockTXClient creates a new mock instance. +func NewMockTXClient(ctrl *gomock.Controller) *MockTXClient { + mock := &MockTXClient{ctrl: ctrl} + mock.recorder = &MockTXClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockTXClient) EXPECT() *MockTXClientMockRecorder { + return m.recorder +} + +// Exec mocks base method. +func (m *MockTXClient) Exec(arg0 string, arg1 ...any) (sql.Result, error) { + m.ctrl.T.Helper() + varargs := []any{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Exec", varargs...) + ret0, _ := ret[0].(sql.Result) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Exec indicates an expected call of Exec. +func (mr *MockTXClientMockRecorder) Exec(arg0 any, arg1 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockTXClient)(nil).Exec), varargs...) +} + +// Stmt mocks base method. +func (m *MockTXClient) Stmt(arg0 *sql.Stmt) transaction.Stmt { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Stmt", arg0) + ret0, _ := ret[0].(transaction.Stmt) + return ret0 +} + +// Stmt indicates an expected call of Stmt. +func (mr *MockTXClientMockRecorder) Stmt(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stmt", reflect.TypeOf((*MockTXClient)(nil).Stmt), arg0) +}