1
0
mirror of https://github.com/rancher/steve.git synced 2025-09-01 07:27:46 +00:00

sql: use a closure to wrap transactions (#469)

This introduces the a `WithTransaction` function, which is then used for all transactional work in Steve.

Because `WithTransaction` takes care of all `Begin`s, `Commit`s and `Rollback`s, it eliminates the problem where forgotten open transactions can block all other operations (with long stalling and `SQLITE_BUSY` errors).

This also:

- merges together the disparate `DBClient` interfaces in one only `db.Client` interface with one unexported non-test implementation. I found this much easier to follow
- refactors the transaction package in order to make it as minimal as possible, and as close to the wrapped `sql.Tx` and `sql.Stmt` functions as possible, in order to reduce cognitive load when working with this part of the codebase
- simplifies tests accordingly
- adds a couple of known files to `.gitignore`
    
Credits to @tomleb for suggesting the approach: https://github.com/rancher/lasso/pull/121#pullrequestreview-2515872507
This commit is contained in:
Silvio Moioli
2025-02-05 10:05:52 +01:00
committed by GitHub
parent 6a46a1e091
commit 772dc7577e
28 changed files with 1543 additions and 2059 deletions

View File

@@ -15,7 +15,8 @@ import (
"reflect"
"sync"
"github.com/pkg/errors"
"errors"
"github.com/rancher/steve/pkg/sqlcache/db/transaction"
// needed for drivers
@@ -29,8 +30,59 @@ const (
informerObjectCachePerms fs.FileMode = 0o600
)
// Client is a database client that provides encrypting, decrypting, and database resetting.
type Client struct {
// Client defines a database client that provides encrypting, decrypting, and database resetting
type Client interface {
WithTransaction(ctx context.Context, forWriting bool, f WithTransactionFunction) error
Prepare(stmt string) *sql.Stmt
QueryForRows(ctx context.Context, stmt transaction.Stmt, params ...any) (*sql.Rows, error)
ReadObjects(rows Rows, typ reflect.Type, shouldDecrypt bool) ([]any, error)
ReadStrings(rows Rows) ([]string, error)
ReadInt(rows Rows) (int, error)
Upsert(tx transaction.Client, stmt *sql.Stmt, key string, obj any, shouldEncrypt bool) error
CloseStmt(closable Closable) error
NewConnection() error
}
// WithTransaction runs f within a transaction.
//
// If forWriting is true, this method blocks until all other concurrent forWriting
// transactions have either committed or rolled back.
// If forWriting is false, it is assumed the returned transaction will exclusively
// be used for DQL (e.g. SELECT) queries.
// Not respecting the above rule might result in transactions failing with unexpected
// SQLITE_BUSY (5) errors (aka "Runtime error: database is locked").
// See discussion in https://github.com/rancher/lasso/pull/98 for details
//
// The transaction is committed if f returns nil, otherwise it is rolled back.
func (c *client) WithTransaction(ctx context.Context, forWriting bool, f WithTransactionFunction) error {
c.connLock.RLock()
// note: this assumes _txlock=immediate in the connection string, see NewConnection
tx, err := c.conn.BeginTx(ctx, &sql.TxOptions{
ReadOnly: !forWriting,
})
c.connLock.RUnlock()
if err != nil {
return err
}
err = f(transaction.NewClient(tx))
if err != nil {
rerr := tx.Rollback()
err = errors.Join(err, rerr)
} else {
cerr := tx.Commit()
err = errors.Join(err, cerr)
}
return err
}
// WithTransactionFunction is a function that uses a transaction
type WithTransactionFunction func(tx transaction.Client) error
// client is the main implementation of Client. Other implementations exist for test purposes
type client struct {
conn Connection
connLock sync.RWMutex
encryptor Encryptor
@@ -74,15 +126,6 @@ func (e *QueryError) Unwrap() error {
return e.Err
}
// TXClient represents a sql transaction. The TXClient must manage rollbacks as rollback functionality is not exposed.
type TXClient interface {
StmtExec(stmt transaction.Stmt, args ...any) error
Exec(stmt string, args ...any) error
Commit() error
Stmt(stmt *sql.Stmt) transaction.Stmt
Cancel() error
}
// Encryptor encrypts data with a key which is rotated to avoid wear-out.
type Encryptor interface {
// Encrypt encrypts the specified data, returning: the encrypted data, the nonce used to encrypt the data, and an ID identifying the key that was used (as it rotates). On failure error is returned instead.
@@ -95,9 +138,9 @@ type Decryptor interface {
Decrypt([]byte, []byte, uint32) ([]byte, error)
}
// NewClient returns a Client. If the given connection is nil then a default one will be created.
func NewClient(c Connection, encryptor Encryptor, decryptor Decryptor) (*Client, error) {
client := &Client{
// NewClient returns a client. If the given connection is nil then a default one will be created.
func NewClient(c Connection, encryptor Encryptor, decryptor Decryptor) (Client, error) {
client := &client{
encryptor: encryptor,
decryptor: decryptor,
}
@@ -114,19 +157,19 @@ func NewClient(c Connection, encryptor Encryptor, decryptor Decryptor) (*Client,
}
// Prepare prepares the given string into a sql statement on the client's connection.
func (c *Client) Prepare(stmt string) *sql.Stmt {
func (c *client) Prepare(stmt string) *sql.Stmt {
c.connLock.RLock()
defer c.connLock.RUnlock()
prepared, err := c.conn.Prepare(stmt)
if err != nil {
panic(errors.Errorf("Error preparing statement: %s\n%v", stmt, err))
panic(fmt.Errorf("Error preparing statement: %s\n%w", stmt, err))
}
return prepared
}
// QueryForRows queries the given stmt with the given params and returns the resulting rows. The query wil be retried
// given a sqlite busy error.
func (c *Client) QueryForRows(ctx context.Context, stmt transaction.Stmt, params ...any) (*sql.Rows, error) {
func (c *client) QueryForRows(ctx context.Context, stmt transaction.Stmt, params ...any) (*sql.Rows, error) {
c.connLock.RLock()
defer c.connLock.RUnlock()
@@ -135,13 +178,13 @@ func (c *Client) QueryForRows(ctx context.Context, stmt transaction.Stmt, params
// CloseStmt will call close on the given Closable. It is intended to be used with a sql statement. This function is meant
// to replace stmt.Close which can cause panics when callers unit-test since there usually is no real underlying connection.
func (c *Client) CloseStmt(closable Closable) error {
func (c *client) CloseStmt(closable Closable) error {
return closable.Close()
}
// ReadObjects Scans the given rows, performs any necessary decryption, converts the data to objects of the given type,
// and returns a slice of those objects.
func (c *Client) ReadObjects(rows Rows, typ reflect.Type, shouldDecrypt bool) ([]any, error) {
func (c *client) ReadObjects(rows Rows, typ reflect.Type, shouldDecrypt bool) ([]any, error) {
c.connLock.RLock()
defer c.connLock.RUnlock()
@@ -171,7 +214,7 @@ func (c *Client) ReadObjects(rows Rows, typ reflect.Type, shouldDecrypt bool) ([
}
// ReadStrings scans the given rows into strings, and then returns the strings as a slice.
func (c *Client) ReadStrings(rows Rows) ([]string, error) {
func (c *client) ReadStrings(rows Rows) ([]string, error) {
c.connLock.RLock()
defer c.connLock.RUnlock()
@@ -199,7 +242,7 @@ func (c *Client) ReadStrings(rows Rows) ([]string, error) {
}
// ReadInt scans the first of the given rows into a single int (eg. for COUNT() queries)
func (c *Client) ReadInt(rows Rows) (int, error) {
func (c *client) ReadInt(rows Rows) (int, error) {
c.connLock.RLock()
defer c.connLock.RUnlock()
@@ -226,28 +269,7 @@ func (c *Client) ReadInt(rows Rows) (int, error) {
return result, nil
}
// BeginTx attempts to begin a transaction.
// If forWriting is true, this method blocks until all other concurrent forWriting
// transactions have either committed or rolled back.
// If forWriting is false, it is assumed the returned transaction will exclusively
// be used for DQL (e.g. SELECT) queries.
// Not respecting the above rule might result in transactions failing with unexpected
// SQLITE_BUSY (5) errors (aka "Runtime error: database is locked").
// See discussion in https://github.com/rancher/lasso/pull/98 for details
func (c *Client) BeginTx(ctx context.Context, forWriting bool) (TXClient, error) {
c.connLock.RLock()
defer c.connLock.RUnlock()
// note: this assumes _txlock=immediate in the connection string, see NewConnection
sqlTx, err := c.conn.BeginTx(ctx, &sql.TxOptions{
ReadOnly: !forWriting,
})
if err != nil {
return nil, err
}
return transaction.NewClient(sqlTx), nil
}
func (c *Client) decryptScan(rows Rows, shouldDecrypt bool) ([]byte, error) {
func (c *client) decryptScan(rows Rows, shouldDecrypt bool) ([]byte, error) {
var data, dataNonce sql.RawBytes
var kid uint32
err := rows.Scan(&data, &dataNonce, &kid)
@@ -264,8 +286,9 @@ func (c *Client) decryptScan(rows Rows, shouldDecrypt bool) ([]byte, error) {
return data, nil
}
// Upsert used to be called upsertEncrypted in store package before move
func (c *Client) Upsert(tx TXClient, stmt *sql.Stmt, key string, obj any, shouldEncrypt bool) error {
// Upsert executes an upsert statement encrypting arguments if necessary
// note the statement should have 4 parameters: key, objBytes, dataNonce, kid
func (c *client) Upsert(tx transaction.Client, stmt *sql.Stmt, key string, obj any, shouldEncrypt bool) error {
objBytes := toBytes(obj)
var dataNonce []byte
var err error
@@ -277,7 +300,8 @@ func (c *Client) Upsert(tx TXClient, stmt *sql.Stmt, key string, obj any, should
}
}
return tx.StmtExec(tx.Stmt(stmt), key, objBytes, dataNonce, kid)
_, err = tx.Stmt(stmt).Exec(key, objBytes, dataNonce, kid)
return err
}
// toBytes encodes an object to a byte slice
@@ -312,7 +336,7 @@ func closeRowsOnError(rows Rows, err error) error {
// NewConnection checks for currently existing connection, closes one if it exists, removes any relevant db files, and opens a new connection which subsequently
// creates new files.
func (c *Client) NewConnection() error {
func (c *client) NewConnection() error {
c.connLock.Lock()
defer c.connLock.Unlock()
if c.conn != nil {

View File

@@ -11,13 +11,14 @@ import (
"reflect"
"testing"
"github.com/rancher/steve/pkg/sqlcache/db/transaction"
"github.com/stretchr/testify/assert"
"go.uber.org/mock/gomock"
)
// Mocks for this test are generated with the following command.
//go:generate mockgen --build_flags=--mod=mod -package db -destination ./db_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db Rows,Connection,Encryptor,Decryptor,TXClient
//go:generate mockgen --build_flags=--mod=mod -package db -destination ./transaction_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db/transaction Stmt,SQLTx
//go:generate mockgen --build_flags=--mod=mod -package db -destination ./db_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db Rows,Connection,Encryptor,Decryptor
//go:generate mockgen --build_flags=--mod=mod -package db -destination ./transaction_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db/transaction Client,Stmt
type testStoreObject struct {
Id string
@@ -37,7 +38,7 @@ func TestNewClient(t *testing.T) {
c := SetupMockConnection(t)
e := SetupMockEncryptor(t)
d := SetupMockDecryptor(t)
expectedClient := &Client{
expectedClient := &client{
conn: c,
encryptor: e,
decryptor: d,
@@ -389,58 +390,6 @@ func TestReadInt(t *testing.T) {
}
}
func TestBegin(t *testing.T) {
type testCase struct {
description string
test func(t *testing.T)
}
var tests []testCase
// Tests with shouldEncryptSet to false
tests = append(tests, testCase{description: "BeginTx(), with no errors", test: func(t *testing.T) {
c := SetupMockConnection(t)
e := SetupMockEncryptor(t)
d := SetupMockDecryptor(t)
sqlTx := &sql.Tx{}
c.EXPECT().BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true}).Return(sqlTx, nil)
client := SetupClient(t, c, e, d)
txC, err := client.BeginTx(context.Background(), false)
assert.Nil(t, err)
assert.NotNil(t, txC)
},
})
tests = append(tests, testCase{description: "BeginTx(), with forWriting option set", test: func(t *testing.T) {
c := SetupMockConnection(t)
e := SetupMockEncryptor(t)
d := SetupMockDecryptor(t)
sqlTx := &sql.Tx{}
c.EXPECT().BeginTx(context.Background(), &sql.TxOptions{ReadOnly: false}).Return(sqlTx, nil)
client := SetupClient(t, c, e, d)
txC, err := client.BeginTx(context.Background(), true)
assert.Nil(t, err)
assert.NotNil(t, txC)
},
})
tests = append(tests, testCase{description: "BeginTx(), with connection Begin() error", test: func(t *testing.T) {
c := SetupMockConnection(t)
e := SetupMockEncryptor(t)
d := SetupMockDecryptor(t)
c.EXPECT().BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true}).Return(nil, fmt.Errorf("error"))
client := SetupClient(t, c, e, d)
_, err := client.BeginTx(context.Background(), false)
assert.NotNil(t, err)
},
})
t.Parallel()
for _, test := range tests {
t.Run(test.description, func(t *testing.T) { test.test(t) })
}
}
func TestUpsert(t *testing.T) {
type testCase struct {
description string
@@ -459,14 +408,14 @@ func TestUpsert(t *testing.T) {
d := SetupMockDecryptor(t)
client := SetupClient(t, c, e, d)
txC := NewMockTXClient(gomock.NewController(t))
txC := NewMockClient(gomock.NewController(t))
sqlStmt := &sql.Stmt{}
stmt := NewMockStmt(gomock.NewController(t))
testObjBytes := toBytes(testObject)
testByteValue := []byte("something")
e.EXPECT().Encrypt(testObjBytes).Return(testByteValue, testByteValue, keyID, nil)
txC.EXPECT().Stmt(sqlStmt).Return(stmt)
txC.EXPECT().StmtExec(stmt, "somekey", testByteValue, testByteValue, keyID).Return(nil)
stmt.EXPECT().Exec("somekey", testByteValue, testByteValue, keyID).Return(nil, nil)
err := client.Upsert(txC, sqlStmt, "somekey", testObject, true)
assert.Nil(t, err)
},
@@ -477,7 +426,7 @@ func TestUpsert(t *testing.T) {
d := SetupMockDecryptor(t)
client := SetupClient(t, c, e, d)
txC := NewMockTXClient(gomock.NewController(t))
txC := NewMockClient(gomock.NewController(t))
sqlStmt := &sql.Stmt{}
testObjBytes := toBytes(testObject)
e.EXPECT().Encrypt(testObjBytes).Return(nil, nil, uint32(0), fmt.Errorf("error"))
@@ -491,14 +440,14 @@ func TestUpsert(t *testing.T) {
d := SetupMockDecryptor(t)
client := SetupClient(t, c, e, d)
txC := NewMockTXClient(gomock.NewController(t))
txC := NewMockClient(gomock.NewController(t))
sqlStmt := &sql.Stmt{}
stmt := NewMockStmt(gomock.NewController(t))
testObjBytes := toBytes(testObject)
testByteValue := []byte("something")
e.EXPECT().Encrypt(testObjBytes).Return(testByteValue, testByteValue, keyID, nil)
txC.EXPECT().Stmt(sqlStmt).Return(stmt)
txC.EXPECT().StmtExec(stmt, "somekey", testByteValue, testByteValue, keyID).Return(fmt.Errorf("error"))
stmt.EXPECT().Exec("somekey", testByteValue, testByteValue, keyID).Return(nil, fmt.Errorf("error"))
err := client.Upsert(txC, sqlStmt, "somekey", testObject, true)
assert.NotNil(t, err)
},
@@ -509,13 +458,13 @@ func TestUpsert(t *testing.T) {
e := SetupMockEncryptor(t)
client := SetupClient(t, c, e, d)
txC := NewMockTXClient(gomock.NewController(t))
txC := NewMockClient(gomock.NewController(t))
sqlStmt := &sql.Stmt{}
stmt := NewMockStmt(gomock.NewController(t))
var testByteValue []byte
testObjBytes := toBytes(testObject)
txC.EXPECT().Stmt(sqlStmt).Return(stmt)
txC.EXPECT().StmtExec(stmt, "somekey", testObjBytes, testByteValue, uint32(0)).Return(nil)
stmt.EXPECT().Exec("somekey", testObjBytes, testByteValue, uint32(0)).Return(nil, nil)
err := client.Upsert(txC, sqlStmt, "somekey", testObject, false)
assert.Nil(t, err)
},
@@ -582,9 +531,10 @@ func TestNewConnection(t *testing.T) {
assert.Nil(t, err)
// Create a transaction to ensure that the file is written to disk.
txC, err := client.BeginTx(context.Background(), false)
err = client.WithTransaction(context.Background(), false, func(tx transaction.Client) error {
return nil
})
assert.NoError(t, err)
assert.NoError(t, txC.Commit())
assert.FileExists(t, InformerObjectCacheDBPath)
assertFileHasPermissions(t, InformerObjectCacheDBPath, 0600)
@@ -630,7 +580,7 @@ func SetupMockRows(t *testing.T) *MockRows {
return MockR
}
func SetupClient(t *testing.T, connection Connection, encryptor Encryptor, decryptor Decryptor) *Client {
func SetupClient(t *testing.T, connection Connection, encryptor Encryptor, decryptor Decryptor) Client {
c, _ := NewClient(connection, encryptor, decryptor)
return c
}

View File

@@ -1,9 +1,9 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/rancher/steve/pkg/sqlcache/db (interfaces: Rows,Connection,Encryptor,Decryptor,TXClient)
// Source: github.com/rancher/steve/pkg/sqlcache/db (interfaces: Rows,Connection,Encryptor,Decryptor)
//
// Generated by this command:
//
// mockgen --build_flags=--mod=mod -package db -destination ./db_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db Rows,Connection,Encryptor,Decryptor,TXClient
// mockgen --build_flags=--mod=mod -package db -destination ./db_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db Rows,Connection,Encryptor,Decryptor
//
// Package db is a generated GoMock package.
@@ -14,7 +14,6 @@ import (
sql "database/sql"
reflect "reflect"
transaction "github.com/rancher/steve/pkg/sqlcache/db/transaction"
gomock "go.uber.org/mock/gomock"
)
@@ -265,106 +264,3 @@ func (mr *MockDecryptorMockRecorder) Decrypt(arg0, arg1, arg2 any) *gomock.Call
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Decrypt", reflect.TypeOf((*MockDecryptor)(nil).Decrypt), arg0, arg1, arg2)
}
// MockTXClient is a mock of TXClient interface.
type MockTXClient struct {
ctrl *gomock.Controller
recorder *MockTXClientMockRecorder
}
// MockTXClientMockRecorder is the mock recorder for MockTXClient.
type MockTXClientMockRecorder struct {
mock *MockTXClient
}
// NewMockTXClient creates a new mock instance.
func NewMockTXClient(ctrl *gomock.Controller) *MockTXClient {
mock := &MockTXClient{ctrl: ctrl}
mock.recorder = &MockTXClientMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockTXClient) EXPECT() *MockTXClientMockRecorder {
return m.recorder
}
// Cancel mocks base method.
func (m *MockTXClient) Cancel() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Cancel")
ret0, _ := ret[0].(error)
return ret0
}
// Cancel indicates an expected call of Cancel.
func (mr *MockTXClientMockRecorder) Cancel() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Cancel", reflect.TypeOf((*MockTXClient)(nil).Cancel))
}
// Commit mocks base method.
func (m *MockTXClient) Commit() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Commit")
ret0, _ := ret[0].(error)
return ret0
}
// Commit indicates an expected call of Commit.
func (mr *MockTXClientMockRecorder) Commit() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockTXClient)(nil).Commit))
}
// Exec mocks base method.
func (m *MockTXClient) Exec(arg0 string, arg1 ...any) error {
m.ctrl.T.Helper()
varargs := []any{arg0}
for _, a := range arg1 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Exec", varargs...)
ret0, _ := ret[0].(error)
return ret0
}
// Exec indicates an expected call of Exec.
func (mr *MockTXClientMockRecorder) Exec(arg0 any, arg1 ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{arg0}, arg1...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockTXClient)(nil).Exec), varargs...)
}
// Stmt mocks base method.
func (m *MockTXClient) Stmt(arg0 *sql.Stmt) transaction.Stmt {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Stmt", arg0)
ret0, _ := ret[0].(transaction.Stmt)
return ret0
}
// Stmt indicates an expected call of Stmt.
func (mr *MockTXClientMockRecorder) Stmt(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stmt", reflect.TypeOf((*MockTXClient)(nil).Stmt), arg0)
}
// StmtExec mocks base method.
func (m *MockTXClient) StmtExec(arg0 transaction.Stmt, arg1 ...any) error {
m.ctrl.T.Helper()
varargs := []any{arg0}
for _, a := range arg1 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "StmtExec", varargs...)
ret0, _ := ret[0].(error)
return ret0
}
// StmtExec indicates an expected call of StmtExec.
func (mr *MockTXClientMockRecorder) StmtExec(arg0 any, arg1 ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{arg0}, arg1...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StmtExec", reflect.TypeOf((*MockTXClient)(nil).StmtExec), varargs...)
}

View File

@@ -1,92 +1,44 @@
/*
Package transaction provides a client for a live transaction, and interfaces for some relevant sql types. The transaction client automatically performs rollbacks on failures.
The use of this package simplifies testing for callers by making the underlying transaction mock-able.
Package transaction provides mockable interfaces of sql package struct types.
*/
package transaction
import (
"context"
"database/sql"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
)
// Client provides a way to interact with the underlying sql transaction.
type Client struct {
sqlTx SQLTx
}
// SQLTx represents a sql transaction
type SQLTx interface {
// Client is an interface over a subset of sql.Tx methods
// rationale 1: explicitly forbid direct access to Commit and Rollback functionality
// as that is exclusively dealt with by WithTransaction in ../db
// rationale 2: allow mocking
type Client interface {
Exec(query string, args ...any) (sql.Result, error)
Stmt(stmt *sql.Stmt) *sql.Stmt
Commit() error
Rollback() error
Stmt(stmt *sql.Stmt) Stmt
}
// Stmt represents a sql stmt. It is used as a return type to offer some testability over returning sql's Stmt type
// because we are able to mock its outputs and do not need an actual connection.
// client is the main implementation of Client, delegates to sql.Tx
// other implementations exist for testing purposes
type client struct {
tx *sql.Tx
}
func NewClient(tx *sql.Tx) Client {
return &client{tx: tx}
}
func (c client) Exec(query string, args ...any) (sql.Result, error) {
return c.tx.Exec(query, args...)
}
func (c client) Stmt(stmt *sql.Stmt) Stmt {
return c.tx.Stmt(stmt)
}
// Stmt is an interface over a subset of sql.Stmt methods
// rationale: allow mocking
type Stmt interface {
Exec(args ...any) (sql.Result, error)
Query(args ...any) (*sql.Rows, error)
QueryContext(ctx context.Context, args ...any) (*sql.Rows, error)
}
// NewClient returns a Client with the given transaction assigned.
func NewClient(tx SQLTx) *Client {
return &Client{sqlTx: tx}
}
// Commit commits the transaction and then unlocks the database.
func (c *Client) Commit() error {
return c.sqlTx.Commit()
}
// Exec uses the sqlTX Exec() with the given stmt and args. The transaction will be automatically rolled back if Exec()
// returns an error.
func (c *Client) Exec(stmt string, args ...any) error {
_, err := c.sqlTx.Exec(stmt, args...)
if err != nil {
return c.rollback(c.sqlTx, err)
}
return nil
}
// Stmt adds the given sql.Stmt to the client's transaction and then returns a Stmt. An interface is being returned
// here to aid in testing callers by providing a way to configure the statement's behavior.
func (c *Client) Stmt(stmt *sql.Stmt) Stmt {
s := c.sqlTx.Stmt(stmt)
return s
}
// StmtExec Execs the given statement with the given args. It assumes the stmt has been added to the transaction. The
// transaction is rolled back if Stmt.Exec() returns an error.
func (c *Client) StmtExec(stmt Stmt, args ...any) error {
_, err := stmt.Exec(args...)
if err != nil {
logrus.Debugf("StmtExec failed: query %s, args: %s, err: %s", stmt, args, err)
return c.rollback(c.sqlTx, err)
}
return nil
}
// rollback handles rollbacks and wraps errors if needed
func (c *Client) rollback(tx SQLTx, err error) error {
rerr := tx.Rollback()
if rerr != nil {
return errors.Wrapf(err, "Encountered error, then encountered another error while rolling back: %v", rerr)
}
return errors.Wrapf(err, "Encountered error, successfully rolled back")
}
// Cancel rollbacks the transaction without wrapping an error. This only needs to be called if Client has not returned
// an error yet or has not committed. Otherwise, transaction has already rolled back, or in the case of Commit() it is too
// late.
func (c *Client) Cancel() error {
rerr := c.sqlTx.Rollback()
if rerr != sql.ErrTxDone {
return rerr
}
return nil
}

View File

@@ -1,184 +0,0 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/rancher/steve/pkg/sqlcache/db/transaction (interfaces: Stmt,SQLTx)
//
// Generated by this command:
//
// mockgen --build_flags=--mod=mod -package transaction -destination ./transaction_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db/transaction Stmt,SQLTx
//
// Package transaction is a generated GoMock package.
package transaction
import (
context "context"
sql "database/sql"
reflect "reflect"
gomock "go.uber.org/mock/gomock"
)
// MockStmt is a mock of Stmt interface.
type MockStmt struct {
ctrl *gomock.Controller
recorder *MockStmtMockRecorder
}
// MockStmtMockRecorder is the mock recorder for MockStmt.
type MockStmtMockRecorder struct {
mock *MockStmt
}
// NewMockStmt creates a new mock instance.
func NewMockStmt(ctrl *gomock.Controller) *MockStmt {
mock := &MockStmt{ctrl: ctrl}
mock.recorder = &MockStmtMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockStmt) EXPECT() *MockStmtMockRecorder {
return m.recorder
}
// Exec mocks base method.
func (m *MockStmt) Exec(arg0 ...any) (sql.Result, error) {
m.ctrl.T.Helper()
varargs := []any{}
for _, a := range arg0 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Exec", varargs...)
ret0, _ := ret[0].(sql.Result)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Exec indicates an expected call of Exec.
func (mr *MockStmtMockRecorder) Exec(arg0 ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockStmt)(nil).Exec), arg0...)
}
// Query mocks base method.
func (m *MockStmt) Query(arg0 ...any) (*sql.Rows, error) {
m.ctrl.T.Helper()
varargs := []any{}
for _, a := range arg0 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Query", varargs...)
ret0, _ := ret[0].(*sql.Rows)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Query indicates an expected call of Query.
func (mr *MockStmtMockRecorder) Query(arg0 ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockStmt)(nil).Query), arg0...)
}
// QueryContext mocks base method.
func (m *MockStmt) QueryContext(arg0 context.Context, arg1 ...any) (*sql.Rows, error) {
m.ctrl.T.Helper()
varargs := []any{arg0}
for _, a := range arg1 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "QueryContext", varargs...)
ret0, _ := ret[0].(*sql.Rows)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// QueryContext indicates an expected call of QueryContext.
func (mr *MockStmtMockRecorder) QueryContext(arg0 any, arg1 ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{arg0}, arg1...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryContext", reflect.TypeOf((*MockStmt)(nil).QueryContext), varargs...)
}
// MockSQLTx is a mock of SQLTx interface.
type MockSQLTx struct {
ctrl *gomock.Controller
recorder *MockSQLTxMockRecorder
}
// MockSQLTxMockRecorder is the mock recorder for MockSQLTx.
type MockSQLTxMockRecorder struct {
mock *MockSQLTx
}
// NewMockSQLTx creates a new mock instance.
func NewMockSQLTx(ctrl *gomock.Controller) *MockSQLTx {
mock := &MockSQLTx{ctrl: ctrl}
mock.recorder = &MockSQLTxMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockSQLTx) EXPECT() *MockSQLTxMockRecorder {
return m.recorder
}
// Commit mocks base method.
func (m *MockSQLTx) Commit() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Commit")
ret0, _ := ret[0].(error)
return ret0
}
// Commit indicates an expected call of Commit.
func (mr *MockSQLTxMockRecorder) Commit() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockSQLTx)(nil).Commit))
}
// Exec mocks base method.
func (m *MockSQLTx) Exec(arg0 string, arg1 ...any) (sql.Result, error) {
m.ctrl.T.Helper()
varargs := []any{arg0}
for _, a := range arg1 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Exec", varargs...)
ret0, _ := ret[0].(sql.Result)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Exec indicates an expected call of Exec.
func (mr *MockSQLTxMockRecorder) Exec(arg0 any, arg1 ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{arg0}, arg1...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockSQLTx)(nil).Exec), varargs...)
}
// Rollback mocks base method.
func (m *MockSQLTx) Rollback() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Rollback")
ret0, _ := ret[0].(error)
return ret0
}
// Rollback indicates an expected call of Rollback.
func (mr *MockSQLTxMockRecorder) Rollback() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Rollback", reflect.TypeOf((*MockSQLTx)(nil).Rollback))
}
// Stmt mocks base method.
func (m *MockSQLTx) Stmt(arg0 *sql.Stmt) *sql.Stmt {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Stmt", arg0)
ret0, _ := ret[0].(*sql.Stmt)
return ret0
}
// Stmt indicates an expected call of Stmt.
func (mr *MockSQLTxMockRecorder) Stmt(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stmt", reflect.TypeOf((*MockSQLTx)(nil).Stmt), arg0)
}

View File

@@ -1,182 +0,0 @@
package transaction
import (
"database/sql"
"fmt"
"testing"
"github.com/stretchr/testify/assert"
"go.uber.org/mock/gomock"
)
//go:generate mockgen --build_flags=--mod=mod -package transaction -destination ./transaction_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db/transaction Stmt,SQLTx
func TestNewClient(t *testing.T) {
tx := NewMockSQLTx(gomock.NewController(t))
c := NewClient(tx)
assert.Equal(t, tx, c.sqlTx)
}
func TestCommit(t *testing.T) {
type testCase struct {
description string
test func(t *testing.T)
}
var tests []testCase
tests = append(tests, testCase{description: "Commit() with no errors returned from sql TX should return no error", test: func(t *testing.T) {
tx := NewMockSQLTx(gomock.NewController(t))
tx.EXPECT().Commit().Return(nil)
c := &Client{
sqlTx: tx,
}
err := c.Commit()
assert.Nil(t, err)
}})
tests = append(tests, testCase{description: "Commit() with error from sql TX commit() should return error", test: func(t *testing.T) {
tx := NewMockSQLTx(gomock.NewController(t))
tx.EXPECT().Commit().Return(fmt.Errorf("error"))
c := &Client{
sqlTx: tx,
}
err := c.Commit()
assert.NotNil(t, err)
}})
t.Parallel()
for _, test := range tests {
t.Run(test.description, func(t *testing.T) { test.test(t) })
}
}
func TestExec(t *testing.T) {
type testCase struct {
description string
test func(t *testing.T)
}
var tests []testCase
tests = append(tests, testCase{description: "Exec() with no errors returned from sql TX should return no error", test: func(t *testing.T) {
tx := NewMockSQLTx(gomock.NewController(t))
stmtStr := "some statement %s"
arg := 5
// should be passed same statement and arg that was passed to parent function
tx.EXPECT().Exec(stmtStr, arg).Return(nil, nil)
c := &Client{
sqlTx: tx,
}
err := c.Exec(stmtStr, arg)
assert.Nil(t, err)
}})
tests = append(tests, testCase{description: "Exec() with error returned from sql TX Exec() and Rollback() error should return an error", test: func(t *testing.T) {
tx := NewMockSQLTx(gomock.NewController(t))
stmtStr := "some statement %s"
arg := 5
// should be passed same statement and arg that was passed to parent function
tx.EXPECT().Exec(stmtStr, arg).Return(nil, fmt.Errorf("error"))
tx.EXPECT().Rollback().Return(nil)
c := &Client{
sqlTx: tx,
}
err := c.Exec(stmtStr, arg)
assert.NotNil(t, err)
}})
tests = append(tests, testCase{description: "Exec() with error returned from sql TX Exec() and Rollback() error should return an error", test: func(t *testing.T) {
tx := NewMockSQLTx(gomock.NewController(t))
stmtStr := "some statement %s"
arg := 5
// should be passed same statement and arg that was passed to parent function
tx.EXPECT().Exec(stmtStr, arg).Return(nil, fmt.Errorf("error"))
tx.EXPECT().Rollback().Return(fmt.Errorf("error"))
c := &Client{
sqlTx: tx,
}
err := c.Exec(stmtStr, arg)
assert.NotNil(t, err)
}})
t.Parallel()
for _, test := range tests {
t.Run(test.description, func(t *testing.T) { test.test(t) })
}
}
func TestStmt(t *testing.T) {
type testCase struct {
description string
test func(t *testing.T)
}
var tests []testCase
tests = append(tests, testCase{description: "Exec() with no errors returned from sql TX should return no error", test: func(t *testing.T) {
tx := NewMockSQLTx(gomock.NewController(t))
stmt := &sql.Stmt{}
var returnedTXStmt *sql.Stmt
// should be passed same statement and arg that was passed to parent function
tx.EXPECT().Stmt(stmt).Return(returnedTXStmt)
c := &Client{
sqlTx: tx,
}
returnedStmt := c.Stmt(stmt)
// whatever tx returned should be returned here. Nil was used because none of sql.Stmt's fields are exported so its simpler to test nil as it
// won't be equal to an empty struct
assert.Equal(t, returnedTXStmt, returnedStmt)
}})
t.Parallel()
for _, test := range tests {
t.Run(test.description, func(t *testing.T) { test.test(t) })
}
}
func TestStmtExec(t *testing.T) {
type testCase struct {
description string
test func(t *testing.T)
}
var tests []testCase
tests = append(tests, testCase{description: "StmtExec with no errors returned from Stmt should return no error", test: func(t *testing.T) {
tx := NewMockSQLTx(gomock.NewController(t))
stmt := NewMockStmt(gomock.NewController(t))
arg := "something"
// should be passed same arg that was passed to parent function
stmt.EXPECT().Exec(arg).Return(nil, nil)
c := &Client{
sqlTx: tx,
}
err := c.StmtExec(stmt, arg)
assert.Nil(t, err)
}})
tests = append(tests, testCase{description: "StmtExec with error returned from Stmt Exec and no Tx Rollback() error should return error", test: func(t *testing.T) {
tx := NewMockSQLTx(gomock.NewController(t))
stmt := NewMockStmt(gomock.NewController(t))
arg := "something"
// should be passed same arg that was passed to parent function
stmt.EXPECT().Exec(arg).Return(nil, fmt.Errorf("error"))
tx.EXPECT().Rollback().Return(nil)
c := &Client{
sqlTx: tx,
}
err := c.StmtExec(stmt, arg)
assert.NotNil(t, err)
}})
tests = append(tests, testCase{description: "StmtExec with error returned from Stmt Exec and Tx Rollback() error should return error", test: func(t *testing.T) {
tx := NewMockSQLTx(gomock.NewController(t))
stmt := NewMockStmt(gomock.NewController(t))
arg := "something"
// should be passed same arg that was passed to parent function
stmt.EXPECT().Exec(arg).Return(nil, fmt.Errorf("error"))
tx.EXPECT().Rollback().Return(fmt.Errorf("error2"))
c := &Client{
sqlTx: tx,
}
err := c.StmtExec(stmt, arg)
assert.NotNil(t, err)
}})
t.Parallel()
for _, test := range tests {
t.Run(test.description, func(t *testing.T) { test.test(t) })
}
}

View File

@@ -1,9 +1,9 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/rancher/steve/pkg/sqlcache/db/transaction (interfaces: Stmt,SQLTx)
// Source: github.com/rancher/steve/pkg/sqlcache/db/transaction (interfaces: Client,Stmt)
//
// Generated by this command:
//
// mockgen --build_flags=--mod=mod -package db -destination ./transaction_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db/transaction Stmt,SQLTx
// mockgen --build_flags=--mod=mod -package db -destination ./transaction_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db/transaction Client,Stmt
//
// Package db is a generated GoMock package.
@@ -14,9 +14,67 @@ import (
sql "database/sql"
reflect "reflect"
transaction "github.com/rancher/steve/pkg/sqlcache/db/transaction"
gomock "go.uber.org/mock/gomock"
)
// MockClient is a mock of Client interface.
type MockClient struct {
ctrl *gomock.Controller
recorder *MockClientMockRecorder
}
// MockClientMockRecorder is the mock recorder for MockClient.
type MockClientMockRecorder struct {
mock *MockClient
}
// NewMockClient creates a new mock instance.
func NewMockClient(ctrl *gomock.Controller) *MockClient {
mock := &MockClient{ctrl: ctrl}
mock.recorder = &MockClientMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockClient) EXPECT() *MockClientMockRecorder {
return m.recorder
}
// Exec mocks base method.
func (m *MockClient) Exec(arg0 string, arg1 ...any) (sql.Result, error) {
m.ctrl.T.Helper()
varargs := []any{arg0}
for _, a := range arg1 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Exec", varargs...)
ret0, _ := ret[0].(sql.Result)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Exec indicates an expected call of Exec.
func (mr *MockClientMockRecorder) Exec(arg0 any, arg1 ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{arg0}, arg1...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockClient)(nil).Exec), varargs...)
}
// Stmt mocks base method.
func (m *MockClient) Stmt(arg0 *sql.Stmt) transaction.Stmt {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Stmt", arg0)
ret0, _ := ret[0].(transaction.Stmt)
return ret0
}
// Stmt indicates an expected call of Stmt.
func (mr *MockClientMockRecorder) Stmt(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stmt", reflect.TypeOf((*MockClient)(nil).Stmt), arg0)
}
// MockStmt is a mock of Stmt interface.
type MockStmt struct {
ctrl *gomock.Controller
@@ -97,88 +155,3 @@ func (mr *MockStmtMockRecorder) QueryContext(arg0 any, arg1 ...any) *gomock.Call
varargs := append([]any{arg0}, arg1...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryContext", reflect.TypeOf((*MockStmt)(nil).QueryContext), varargs...)
}
// MockSQLTx is a mock of SQLTx interface.
type MockSQLTx struct {
ctrl *gomock.Controller
recorder *MockSQLTxMockRecorder
}
// MockSQLTxMockRecorder is the mock recorder for MockSQLTx.
type MockSQLTxMockRecorder struct {
mock *MockSQLTx
}
// NewMockSQLTx creates a new mock instance.
func NewMockSQLTx(ctrl *gomock.Controller) *MockSQLTx {
mock := &MockSQLTx{ctrl: ctrl}
mock.recorder = &MockSQLTxMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockSQLTx) EXPECT() *MockSQLTxMockRecorder {
return m.recorder
}
// Commit mocks base method.
func (m *MockSQLTx) Commit() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Commit")
ret0, _ := ret[0].(error)
return ret0
}
// Commit indicates an expected call of Commit.
func (mr *MockSQLTxMockRecorder) Commit() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockSQLTx)(nil).Commit))
}
// Exec mocks base method.
func (m *MockSQLTx) Exec(arg0 string, arg1 ...any) (sql.Result, error) {
m.ctrl.T.Helper()
varargs := []any{arg0}
for _, a := range arg1 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Exec", varargs...)
ret0, _ := ret[0].(sql.Result)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Exec indicates an expected call of Exec.
func (mr *MockSQLTxMockRecorder) Exec(arg0 any, arg1 ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{arg0}, arg1...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockSQLTx)(nil).Exec), varargs...)
}
// Rollback mocks base method.
func (m *MockSQLTx) Rollback() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Rollback")
ret0, _ := ret[0].(error)
return ret0
}
// Rollback indicates an expected call of Rollback.
func (mr *MockSQLTxMockRecorder) Rollback() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Rollback", reflect.TypeOf((*MockSQLTx)(nil).Rollback))
}
// Stmt mocks base method.
func (m *MockSQLTx) Stmt(arg0 *sql.Stmt) *sql.Stmt {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Stmt", arg0)
ret0, _ := ret[0].(*sql.Stmt)
return ret0
}
// Stmt indicates an expected call of Stmt.
func (mr *MockSQLTxMockRecorder) Stmt(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stmt", reflect.TypeOf((*MockSQLTx)(nil).Stmt), arg0)
}