mirror of
https://github.com/rancher/steve.git
synced 2025-09-08 10:49:25 +00:00
[v0.4] Move lasso SQL cache in Steve (#473)
* Copy pkg/cache/sql from lasso to pkg/sqlcache * Rename import from github.com/rancher/lasso/pkg/cache/sql to github.com/rancher/steve/pkg/sqlcache * go mod tidy * Fix lint errors * Remove lasso SQL cache mentions * Fix more CI lint errors * fix goimports Signed-off-by: Silvio Moioli <silvio@moioli.net> * Fix imports * Fix more linting errors --------- Signed-off-by: Silvio Moioli <silvio@moioli.net> Co-authored-by: Silvio Moioli <silvio@moioli.net>
This commit is contained in:
374
pkg/sqlcache/db/client.go
Normal file
374
pkg/sqlcache/db/client.go
Normal file
@@ -0,0 +1,374 @@
|
||||
/*
|
||||
Package db offers client struct and functions to interact with database connection. It provides encrypting, decrypting,
|
||||
and a way to reset the database.
|
||||
*/
|
||||
package db
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/gob"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"reflect"
|
||||
"sync"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/rancher/steve/pkg/sqlcache/db/transaction"
|
||||
|
||||
// needed for drivers
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
const (
|
||||
// InformerObjectCacheDBPath is where SQLite's object database file will be stored relative to process running steve
|
||||
InformerObjectCacheDBPath = "informer_object_cache.db"
|
||||
|
||||
informerObjectCachePerms fs.FileMode = 0o600
|
||||
)
|
||||
|
||||
// Client is a database client that provides encrypting, decrypting, and database resetting.
|
||||
type Client struct {
|
||||
conn Connection
|
||||
connLock sync.RWMutex
|
||||
encryptor Encryptor
|
||||
decryptor Decryptor
|
||||
}
|
||||
|
||||
// Connection represents a connection pool.
|
||||
type Connection interface {
|
||||
BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
|
||||
Exec(query string, args ...any) (sql.Result, error)
|
||||
Prepare(query string) (*sql.Stmt, error)
|
||||
Close() error
|
||||
}
|
||||
|
||||
// Closable Closes an underlying connection and returns an error on failure.
|
||||
type Closable interface {
|
||||
Close() error
|
||||
}
|
||||
|
||||
// Rows represents sql rows. It exposes method to navigate the rows, read their outputs, and close them.
|
||||
type Rows interface {
|
||||
Next() bool
|
||||
Err() error
|
||||
Close() error
|
||||
Scan(dest ...any) error
|
||||
}
|
||||
|
||||
// QueryError encapsulates an error while executing a query
|
||||
type QueryError struct {
|
||||
QueryString string
|
||||
Err error
|
||||
}
|
||||
|
||||
// Error returns a string representation of this QueryError
|
||||
func (e *QueryError) Error() string {
|
||||
return "while executing query: " + e.QueryString + " got error: " + e.Err.Error()
|
||||
}
|
||||
|
||||
// Unwrap returns the underlying error
|
||||
func (e *QueryError) Unwrap() error {
|
||||
return e.Err
|
||||
}
|
||||
|
||||
// TXClient represents a sql transaction. The TXClient must manage rollbacks as rollback functionality is not exposed.
|
||||
type TXClient interface {
|
||||
StmtExec(stmt transaction.Stmt, args ...any) error
|
||||
Exec(stmt string, args ...any) error
|
||||
Commit() error
|
||||
Stmt(stmt *sql.Stmt) transaction.Stmt
|
||||
Cancel() error
|
||||
}
|
||||
|
||||
// Encryptor encrypts data with a key which is rotated to avoid wear-out.
|
||||
type Encryptor interface {
|
||||
// Encrypt encrypts the specified data, returning: the encrypted data, the nonce used to encrypt the data, and an ID identifying the key that was used (as it rotates). On failure error is returned instead.
|
||||
Encrypt([]byte) ([]byte, []byte, uint32, error)
|
||||
}
|
||||
|
||||
// Decryptor decrypts data previously encrypted by Encryptor.
|
||||
type Decryptor interface {
|
||||
// Decrypt accepts a chunk of encrypted data, the nonce used to encrypt it and the ID of the used key (as it rotates). It returns the decrypted data or an error.
|
||||
Decrypt([]byte, []byte, uint32) ([]byte, error)
|
||||
}
|
||||
|
||||
// NewClient returns a Client. If the given connection is nil then a default one will be created.
|
||||
func NewClient(c Connection, encryptor Encryptor, decryptor Decryptor) (*Client, error) {
|
||||
client := &Client{
|
||||
encryptor: encryptor,
|
||||
decryptor: decryptor,
|
||||
}
|
||||
if c != nil {
|
||||
client.conn = c
|
||||
return client, nil
|
||||
}
|
||||
err := client.NewConnection()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
// Prepare prepares the given string into a sql statement on the client's connection.
|
||||
func (c *Client) Prepare(stmt string) *sql.Stmt {
|
||||
c.connLock.RLock()
|
||||
defer c.connLock.RUnlock()
|
||||
prepared, err := c.conn.Prepare(stmt)
|
||||
if err != nil {
|
||||
panic(errors.Errorf("Error preparing statement: %s\n%v", stmt, err))
|
||||
}
|
||||
return prepared
|
||||
}
|
||||
|
||||
// QueryForRows queries the given stmt with the given params and returns the resulting rows. The query wil be retried
|
||||
// given a sqlite busy error.
|
||||
func (c *Client) QueryForRows(ctx context.Context, stmt transaction.Stmt, params ...any) (*sql.Rows, error) {
|
||||
c.connLock.RLock()
|
||||
defer c.connLock.RUnlock()
|
||||
|
||||
return stmt.QueryContext(ctx, params...)
|
||||
}
|
||||
|
||||
// CloseStmt will call close on the given Closable. It is intended to be used with a sql statement. This function is meant
|
||||
// to replace stmt.Close which can cause panics when callers unit-test since there usually is no real underlying connection.
|
||||
func (c *Client) CloseStmt(closable Closable) error {
|
||||
return closable.Close()
|
||||
}
|
||||
|
||||
// ReadObjects Scans the given rows, performs any necessary decryption, converts the data to objects of the given type,
|
||||
// and returns a slice of those objects.
|
||||
func (c *Client) ReadObjects(rows Rows, typ reflect.Type, shouldDecrypt bool) ([]any, error) {
|
||||
c.connLock.RLock()
|
||||
defer c.connLock.RUnlock()
|
||||
|
||||
var result []any
|
||||
for rows.Next() {
|
||||
data, err := c.decryptScan(rows, shouldDecrypt)
|
||||
if err != nil {
|
||||
return nil, closeRowsOnError(rows, err)
|
||||
}
|
||||
singleResult, err := fromBytes(data, typ)
|
||||
if err != nil {
|
||||
return nil, closeRowsOnError(rows, err)
|
||||
}
|
||||
result = append(result, singleResult.Elem().Interface())
|
||||
}
|
||||
err := rows.Err()
|
||||
if err != nil {
|
||||
return nil, closeRowsOnError(rows, err)
|
||||
}
|
||||
|
||||
err = rows.Close()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ReadStrings scans the given rows into strings, and then returns the strings as a slice.
|
||||
func (c *Client) ReadStrings(rows Rows) ([]string, error) {
|
||||
c.connLock.RLock()
|
||||
defer c.connLock.RUnlock()
|
||||
|
||||
var result []string
|
||||
for rows.Next() {
|
||||
var key string
|
||||
err := rows.Scan(&key)
|
||||
if err != nil {
|
||||
return nil, closeRowsOnError(rows, err)
|
||||
}
|
||||
|
||||
result = append(result, key)
|
||||
}
|
||||
err := rows.Err()
|
||||
if err != nil {
|
||||
return nil, closeRowsOnError(rows, err)
|
||||
}
|
||||
|
||||
err = rows.Close()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ReadInt scans the first of the given rows into a single int (eg. for COUNT() queries)
|
||||
func (c *Client) ReadInt(rows Rows) (int, error) {
|
||||
c.connLock.RLock()
|
||||
defer c.connLock.RUnlock()
|
||||
|
||||
if !rows.Next() {
|
||||
return 0, closeRowsOnError(rows, sql.ErrNoRows)
|
||||
}
|
||||
|
||||
var result int
|
||||
err := rows.Scan(&result)
|
||||
if err != nil {
|
||||
return 0, closeRowsOnError(rows, err)
|
||||
}
|
||||
|
||||
err = rows.Err()
|
||||
if err != nil {
|
||||
return 0, closeRowsOnError(rows, err)
|
||||
}
|
||||
|
||||
err = rows.Close()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// BeginTx attempts to begin a transaction.
|
||||
// If forWriting is true, this method blocks until all other concurrent forWriting
|
||||
// transactions have either committed or rolled back.
|
||||
// If forWriting is false, it is assumed the returned transaction will exclusively
|
||||
// be used for DQL (eg. SELECT) queries.
|
||||
// Not respecting the above rule might result in transactions failing with unexpected
|
||||
// SQLITE_BUSY (5) errors (aka "Runtime error: database is locked").
|
||||
// See discussion in https://github.com/rancher/lasso/pull/98 for details
|
||||
func (c *Client) BeginTx(ctx context.Context, forWriting bool) (TXClient, error) {
|
||||
c.connLock.RLock()
|
||||
defer c.connLock.RUnlock()
|
||||
// note: this assumes _txlock=immediate in the connection string, see NewConnection
|
||||
sqlTx, err := c.conn.BeginTx(ctx, &sql.TxOptions{
|
||||
ReadOnly: !forWriting,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return transaction.NewClient(sqlTx), nil
|
||||
}
|
||||
|
||||
func (c *Client) decryptScan(rows Rows, shouldDecrypt bool) ([]byte, error) {
|
||||
var data, dataNonce sql.RawBytes
|
||||
var kid uint32
|
||||
err := rows.Scan(&data, &dataNonce, &kid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if c.decryptor != nil && shouldDecrypt {
|
||||
decryptedData, err := c.decryptor.Decrypt(data, dataNonce, kid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return decryptedData, nil
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// Upsert used to be called upsertEncrypted in store package before move
|
||||
func (c *Client) Upsert(tx TXClient, stmt *sql.Stmt, key string, obj any, shouldEncrypt bool) error {
|
||||
objBytes := toBytes(obj)
|
||||
var dataNonce []byte
|
||||
var err error
|
||||
var kid uint32
|
||||
if c.encryptor != nil && shouldEncrypt {
|
||||
objBytes, dataNonce, kid, err = c.encryptor.Encrypt(objBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return tx.StmtExec(tx.Stmt(stmt), key, objBytes, dataNonce, kid)
|
||||
}
|
||||
|
||||
// toBytes encodes an object to a byte slice
|
||||
func toBytes(obj any) []byte {
|
||||
var buf bytes.Buffer
|
||||
enc := gob.NewEncoder(&buf)
|
||||
err := enc.Encode(obj)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("error while gobbing object: %w", err))
|
||||
}
|
||||
bb := buf.Bytes()
|
||||
return bb
|
||||
}
|
||||
|
||||
// fromBytes decodes an object from a byte slice
|
||||
func fromBytes(buf sql.RawBytes, typ reflect.Type) (reflect.Value, error) {
|
||||
dec := gob.NewDecoder(bytes.NewReader(buf))
|
||||
singleResult := reflect.New(typ)
|
||||
err := dec.DecodeValue(singleResult)
|
||||
return singleResult, err
|
||||
}
|
||||
|
||||
// closeRowsOnError closes the sql.Rows object and wraps errors if needed
|
||||
func closeRowsOnError(rows Rows, err error) error {
|
||||
ce := rows.Close()
|
||||
if ce != nil {
|
||||
return fmt.Errorf("error in closing rows while handling %s: %w", err.Error(), ce)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// NewConnection checks for currently existing connection, closes one if it exists, removes any relevant db files, and opens a new connection which subsequently
|
||||
// creates new files.
|
||||
func (c *Client) NewConnection() error {
|
||||
c.connLock.Lock()
|
||||
defer c.connLock.Unlock()
|
||||
if c.conn != nil {
|
||||
err := c.conn.Close()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
err := os.RemoveAll(InformerObjectCacheDBPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Set the permissions in advance, because we can't control them if
|
||||
// the file is created by a sql.Open call instead.
|
||||
if err := touchFile(InformerObjectCacheDBPath, informerObjectCachePerms); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
sqlDB, err := sql.Open("sqlite", "file:"+InformerObjectCacheDBPath+"?"+
|
||||
// open SQLite file in read-write mode, creating it if it does not exist
|
||||
"mode=rwc&"+
|
||||
// use the WAL journal mode for consistency and efficiency
|
||||
"_pragma=journal_mode=wal&"+
|
||||
// do not even attempt to attain durability. Database is thrown away at pod restart
|
||||
"_pragma=synchronous=off&"+
|
||||
// do check foreign keys and honor ON DELETE CASCADE
|
||||
"_pragma=foreign_keys=on&"+
|
||||
// if two transactions want to write at the same time, allow 2 minutes for the first to complete
|
||||
// before baling out
|
||||
"_pragma=busy_timeout=120000&"+
|
||||
// default to IMMEDIATE mode for transactions. Setting this parameter is the only current way
|
||||
// to be able to switch between DEFERRED and IMMEDIATE modes in modernc.org/sqlite's implementation
|
||||
// of BeginTx
|
||||
"_txlock=immediate")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.conn = sqlDB
|
||||
return nil
|
||||
}
|
||||
|
||||
// This acts like "touch" for both existing files and non-existing files.
|
||||
// permissions.
|
||||
//
|
||||
// It's created with the correct perms, and if the file already exists, it will
|
||||
// be chmodded to the correct perms.
|
||||
func touchFile(filename string, perms fs.FileMode) error {
|
||||
f, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE, perms)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := f.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return os.Chmod(filename, perms)
|
||||
}
|
667
pkg/sqlcache/db/client_test.go
Normal file
667
pkg/sqlcache/db/client_test.go
Normal file
@@ -0,0 +1,667 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
// Mocks for this test are generated with the following command.
|
||||
//go:generate mockgen --build_flags=--mod=mod -package db -destination ./db_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db Rows,Connection,Encryptor,Decryptor,TXClient
|
||||
//go:generate mockgen --build_flags=--mod=mod -package db -destination ./transaction_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db/transaction Stmt,SQLTx
|
||||
|
||||
type testStoreObject struct {
|
||||
Id string
|
||||
Val string
|
||||
}
|
||||
|
||||
func TestNewClient(t *testing.T) {
|
||||
type testCase struct {
|
||||
description string
|
||||
test func(t *testing.T)
|
||||
}
|
||||
|
||||
var tests []testCase
|
||||
|
||||
// Tests with shouldEncryptSet to false
|
||||
tests = append(tests, testCase{description: "Query rows with no params, no errors", test: func(t *testing.T) {
|
||||
c := SetupMockConnection(t)
|
||||
e := SetupMockEncryptor(t)
|
||||
d := SetupMockDecryptor(t)
|
||||
expectedClient := &Client{
|
||||
conn: c,
|
||||
encryptor: e,
|
||||
decryptor: d,
|
||||
}
|
||||
client, err := NewClient(c, e, d)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, expectedClient, client)
|
||||
},
|
||||
})
|
||||
t.Parallel()
|
||||
for _, test := range tests {
|
||||
t.Run(test.description, func(t *testing.T) { test.test(t) })
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryForRows(t *testing.T) {
|
||||
type testCase struct {
|
||||
description string
|
||||
test func(t *testing.T)
|
||||
}
|
||||
|
||||
var tests []testCase
|
||||
|
||||
// Tests with shouldEncryptSet to false
|
||||
tests = append(tests, testCase{description: "Query rows with no params, no errors", test: func(t *testing.T) {
|
||||
c := SetupMockConnection(t)
|
||||
client := SetupClient(t, c, nil, nil)
|
||||
s := NewMockStmt(gomock.NewController(t))
|
||||
ctx := context.TODO()
|
||||
r := &sql.Rows{}
|
||||
s.EXPECT().QueryContext(ctx).Return(r, nil)
|
||||
rows, err := client.QueryForRows(ctx, s)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, r, rows)
|
||||
},
|
||||
})
|
||||
tests = append(tests, testCase{description: "Query rows with params, QueryContext() error", test: func(t *testing.T) {
|
||||
c := SetupMockConnection(t)
|
||||
client := SetupClient(t, c, nil, nil)
|
||||
s := NewMockStmt(gomock.NewController(t))
|
||||
ctx := context.TODO()
|
||||
s.EXPECT().QueryContext(ctx).Return(nil, fmt.Errorf("error"))
|
||||
_, err := client.QueryForRows(ctx, s)
|
||||
assert.NotNil(t, err)
|
||||
},
|
||||
})
|
||||
t.Parallel()
|
||||
for _, test := range tests {
|
||||
t.Run(test.description, func(t *testing.T) { test.test(t) })
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryObjects(t *testing.T) {
|
||||
type testCase struct {
|
||||
description string
|
||||
test func(t *testing.T)
|
||||
}
|
||||
|
||||
var tests []testCase
|
||||
|
||||
testObject := testStoreObject{Id: "something", Val: "a"}
|
||||
var keyId uint32 = math.MaxUint32
|
||||
|
||||
// Tests with shouldEncryptSet to false
|
||||
tests = append(tests, testCase{description: "Query objects, with one row, and no errors", test: func(t *testing.T) {
|
||||
c := SetupMockConnection(t)
|
||||
e := SetupMockEncryptor(t)
|
||||
d := SetupMockDecryptor(t)
|
||||
r := SetupMockRows(t)
|
||||
r.EXPECT().Next().Return(true)
|
||||
r.EXPECT().Scan(gomock.Any()).Do(func(a ...any) {
|
||||
*a[0].(*sql.RawBytes) = toBytes(testObject)
|
||||
*a[1].(*sql.RawBytes) = toBytes(testObject)
|
||||
*a[2].(*uint32) = keyId
|
||||
})
|
||||
d.EXPECT().Decrypt(toBytes(testObject), toBytes(testObject), keyId).Return(toBytes(testObject), nil)
|
||||
r.EXPECT().Err().Return(nil)
|
||||
r.EXPECT().Next().Return(false)
|
||||
r.EXPECT().Close().Return(nil)
|
||||
client := SetupClient(t, c, e, d)
|
||||
items, err := client.ReadObjects(r, reflect.TypeOf(testObject), true)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 1, len(items))
|
||||
},
|
||||
})
|
||||
tests = append(tests, testCase{description: "Query objects, with one row, and a decrypt error", test: func(t *testing.T) {
|
||||
c := SetupMockConnection(t)
|
||||
e := SetupMockEncryptor(t)
|
||||
d := SetupMockDecryptor(t)
|
||||
r := SetupMockRows(t)
|
||||
r.EXPECT().Next().Return(true)
|
||||
r.EXPECT().Scan(gomock.Any()).Do(func(a ...any) {
|
||||
*a[0].(*sql.RawBytes) = toBytes(testObject)
|
||||
*a[1].(*sql.RawBytes) = toBytes(
|
||||
testObject)
|
||||
*a[2].(*uint32) = keyId
|
||||
})
|
||||
d.EXPECT().Decrypt(toBytes(testObject), toBytes(testObject), keyId).Return(nil, fmt.Errorf("error"))
|
||||
r.EXPECT().Close().Return(nil)
|
||||
client := SetupClient(t, c, e, d)
|
||||
_, err := client.ReadObjects(r, reflect.TypeOf(testObject), true)
|
||||
assert.NotNil(t, err)
|
||||
},
|
||||
})
|
||||
tests = append(tests, testCase{description: "Query objects, with one row, and a Scan() error", test: func(t *testing.T) {
|
||||
c := SetupMockConnection(t)
|
||||
e := SetupMockEncryptor(t)
|
||||
d := SetupMockDecryptor(t)
|
||||
r := SetupMockRows(t)
|
||||
r.EXPECT().Next().Return(true)
|
||||
r.EXPECT().Scan(gomock.Any()).Return(fmt.Errorf("error"))
|
||||
r.EXPECT().Close().Return(nil)
|
||||
client := SetupClient(t, c, e, d)
|
||||
_, err := client.ReadObjects(r, reflect.TypeOf(testObject), true)
|
||||
assert.NotNil(t, err)
|
||||
},
|
||||
})
|
||||
tests = append(tests, testCase{description: "Query objects, with one row, and a Close() error", test: func(t *testing.T) {
|
||||
c := SetupMockConnection(t)
|
||||
e := SetupMockEncryptor(t)
|
||||
d := SetupMockDecryptor(t)
|
||||
r := SetupMockRows(t)
|
||||
r.EXPECT().Next().Return(true)
|
||||
r.EXPECT().Scan(gomock.Any()).Do(func(a ...any) {
|
||||
*a[0].(*sql.RawBytes) = toBytes(testObject)
|
||||
*a[1].(*sql.RawBytes) = toBytes(testObject)
|
||||
*a[2].(*uint32) = keyId
|
||||
})
|
||||
d.EXPECT().Decrypt(toBytes(testObject), toBytes(testObject), keyId).Return(toBytes(testObject), nil)
|
||||
r.EXPECT().Err().Return(nil)
|
||||
r.EXPECT().Next().Return(false)
|
||||
r.EXPECT().Close().Return(fmt.Errorf("error"))
|
||||
client := SetupClient(t, c, e, d)
|
||||
_, err := client.ReadObjects(r, reflect.TypeOf(testObject), true)
|
||||
assert.NotNil(t, err)
|
||||
},
|
||||
})
|
||||
tests = append(tests, testCase{description: "Query objects, with no rows, and no errors", test: func(t *testing.T) {
|
||||
c := SetupMockConnection(t)
|
||||
e := SetupMockEncryptor(t)
|
||||
d := SetupMockDecryptor(t)
|
||||
r := SetupMockRows(t)
|
||||
r.EXPECT().Next().Return(false)
|
||||
r.EXPECT().Err().Return(nil)
|
||||
r.EXPECT().Close().Return(nil)
|
||||
client := SetupClient(t, c, e, d)
|
||||
items, err := client.ReadObjects(r, reflect.TypeOf(testObject), true)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 0, len(items))
|
||||
},
|
||||
})
|
||||
t.Parallel()
|
||||
for _, test := range tests {
|
||||
t.Run(test.description, func(t *testing.T) { test.test(t) })
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryStrings(t *testing.T) {
|
||||
type testCase struct {
|
||||
description string
|
||||
test func(t *testing.T)
|
||||
}
|
||||
|
||||
var tests []testCase
|
||||
|
||||
testObject := testStoreObject{Id: "something", Val: "a"}
|
||||
// Tests with shouldEncryptSet to false
|
||||
tests = append(tests, testCase{description: "ReadStrings(), with one row, and no errors", test: func(t *testing.T) {
|
||||
c := SetupMockConnection(t)
|
||||
e := SetupMockEncryptor(t)
|
||||
d := SetupMockDecryptor(t)
|
||||
r := SetupMockRows(t)
|
||||
r.EXPECT().Next().Return(true)
|
||||
r.EXPECT().Scan(gomock.Any()).Do(func(a ...any) {
|
||||
for _, v := range a {
|
||||
vk := v.(*string)
|
||||
*vk = string(toBytes(testObject.Id))
|
||||
}
|
||||
})
|
||||
r.EXPECT().Err().Return(nil)
|
||||
r.EXPECT().Next().Return(false)
|
||||
r.EXPECT().Close().Return(nil)
|
||||
client := SetupClient(t, c, e, d)
|
||||
items, err := client.ReadStrings(r)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 1, len(items))
|
||||
},
|
||||
})
|
||||
tests = append(tests, testCase{description: "Query objects, with one row, and Scan error", test: func(t *testing.T) {
|
||||
c := SetupMockConnection(t)
|
||||
e := SetupMockEncryptor(t)
|
||||
d := SetupMockDecryptor(t)
|
||||
r := SetupMockRows(t)
|
||||
r.EXPECT().Next().Return(true)
|
||||
r.EXPECT().Scan(gomock.Any()).Return(fmt.Errorf("error"))
|
||||
r.EXPECT().Close().Return(nil)
|
||||
client := SetupClient(t, c, e, d)
|
||||
_, err := client.ReadStrings(r)
|
||||
assert.NotNil(t, err)
|
||||
},
|
||||
})
|
||||
tests = append(tests, testCase{description: "ReadStrings(), with one row, and Err() error", test: func(t *testing.T) {
|
||||
c := SetupMockConnection(t)
|
||||
e := SetupMockEncryptor(t)
|
||||
d := SetupMockDecryptor(t)
|
||||
r := SetupMockRows(t)
|
||||
r.EXPECT().Next().Return(true)
|
||||
r.EXPECT().Scan(gomock.Any()).Do(func(a ...any) {
|
||||
for _, v := range a {
|
||||
vk := v.(*string)
|
||||
*vk = string(toBytes(testObject.Id))
|
||||
}
|
||||
})
|
||||
r.EXPECT().Next().Return(false)
|
||||
r.EXPECT().Err().Return(fmt.Errorf("error"))
|
||||
r.EXPECT().Close().Return(nil)
|
||||
client := SetupClient(t, c, e, d)
|
||||
_, err := client.ReadStrings(r)
|
||||
assert.NotNil(t, err)
|
||||
},
|
||||
})
|
||||
tests = append(tests, testCase{description: "ReadStrings(), with one row, and Close() error", test: func(t *testing.T) {
|
||||
c := SetupMockConnection(t)
|
||||
e := SetupMockEncryptor(t)
|
||||
d := SetupMockDecryptor(t)
|
||||
r := SetupMockRows(t)
|
||||
r.EXPECT().Next().Return(true)
|
||||
r.EXPECT().Scan(gomock.Any()).Do(func(a ...any) {
|
||||
for _, v := range a {
|
||||
vk := v.(*string)
|
||||
*vk = string(toBytes(testObject.Id))
|
||||
}
|
||||
})
|
||||
r.EXPECT().Err().Return(nil)
|
||||
r.EXPECT().Next().Return(false)
|
||||
r.EXPECT().Close().Return(fmt.Errorf("error"))
|
||||
client := SetupClient(t, c, e, d)
|
||||
_, err := client.ReadStrings(r)
|
||||
assert.NotNil(t, err)
|
||||
},
|
||||
})
|
||||
tests = append(tests, testCase{description: "ReadStrings(), with no rows, and no errors", test: func(t *testing.T) {
|
||||
c := SetupMockConnection(t)
|
||||
e := SetupMockEncryptor(t)
|
||||
d := SetupMockDecryptor(t)
|
||||
r := SetupMockRows(t)
|
||||
r.EXPECT().Next().Return(false)
|
||||
r.EXPECT().Err().Return(nil)
|
||||
r.EXPECT().Close().Return(nil)
|
||||
client := SetupClient(t, c, e, d)
|
||||
items, err := client.ReadStrings(r)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 0, len(items))
|
||||
},
|
||||
})
|
||||
t.Parallel()
|
||||
for _, test := range tests {
|
||||
t.Run(test.description, func(t *testing.T) { test.test(t) })
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadInt(t *testing.T) {
|
||||
type testCase struct {
|
||||
description string
|
||||
test func(t *testing.T)
|
||||
}
|
||||
|
||||
var tests []testCase
|
||||
|
||||
testResult := 42
|
||||
tests = append(tests, testCase{description: "One row, no errors", test: func(t *testing.T) {
|
||||
c := SetupMockConnection(t)
|
||||
e := SetupMockEncryptor(t)
|
||||
d := SetupMockDecryptor(t)
|
||||
r := SetupMockRows(t)
|
||||
r.EXPECT().Next().Return(true)
|
||||
r.EXPECT().Scan(gomock.Any()).Do(func(a ...any) {
|
||||
p := a[0].(*int)
|
||||
*p = testResult
|
||||
})
|
||||
r.EXPECT().Err().Return(nil)
|
||||
r.EXPECT().Close().Return(nil)
|
||||
client := SetupClient(t, c, e, d)
|
||||
result, err := client.ReadInt(r)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, 42, result)
|
||||
},
|
||||
})
|
||||
tests = append(tests, testCase{description: "One row, Scan error", test: func(t *testing.T) {
|
||||
c := SetupMockConnection(t)
|
||||
e := SetupMockEncryptor(t)
|
||||
d := SetupMockDecryptor(t)
|
||||
r := SetupMockRows(t)
|
||||
r.EXPECT().Next().Return(true)
|
||||
r.EXPECT().Scan(gomock.Any()).Return(fmt.Errorf("error"))
|
||||
r.EXPECT().Close().Return(nil)
|
||||
client := SetupClient(t, c, e, d)
|
||||
_, err := client.ReadInt(r)
|
||||
assert.NotNil(t, err)
|
||||
},
|
||||
})
|
||||
tests = append(tests, testCase{description: "One row, Err() error", test: func(t *testing.T) {
|
||||
c := SetupMockConnection(t)
|
||||
e := SetupMockEncryptor(t)
|
||||
d := SetupMockDecryptor(t)
|
||||
r := SetupMockRows(t)
|
||||
r.EXPECT().Next().Return(true)
|
||||
r.EXPECT().Scan(gomock.Any()).Do(func(a ...any) {
|
||||
a[0] = testResult
|
||||
})
|
||||
r.EXPECT().Err().Return(fmt.Errorf("error"))
|
||||
r.EXPECT().Close().Return(nil)
|
||||
client := SetupClient(t, c, e, d)
|
||||
_, err := client.ReadInt(r)
|
||||
assert.NotNil(t, err)
|
||||
},
|
||||
})
|
||||
tests = append(tests, testCase{description: "One row, Close() error", test: func(t *testing.T) {
|
||||
c := SetupMockConnection(t)
|
||||
e := SetupMockEncryptor(t)
|
||||
d := SetupMockDecryptor(t)
|
||||
r := SetupMockRows(t)
|
||||
r.EXPECT().Next().Return(true)
|
||||
r.EXPECT().Scan(gomock.Any()).Do(func(a ...any) {
|
||||
a[0] = testResult
|
||||
})
|
||||
r.EXPECT().Err().Return(nil)
|
||||
r.EXPECT().Close().Return(fmt.Errorf("error"))
|
||||
client := SetupClient(t, c, e, d)
|
||||
_, err := client.ReadInt(r)
|
||||
assert.NotNil(t, err)
|
||||
},
|
||||
})
|
||||
tests = append(tests, testCase{description: "No rows error", test: func(t *testing.T) {
|
||||
c := SetupMockConnection(t)
|
||||
e := SetupMockEncryptor(t)
|
||||
d := SetupMockDecryptor(t)
|
||||
r := SetupMockRows(t)
|
||||
r.EXPECT().Next().Return(false)
|
||||
r.EXPECT().Close().Return(nil)
|
||||
client := SetupClient(t, c, e, d)
|
||||
_, err := client.ReadInt(r)
|
||||
assert.ErrorIs(t, err, sql.ErrNoRows)
|
||||
},
|
||||
})
|
||||
t.Parallel()
|
||||
for _, test := range tests {
|
||||
t.Run(test.description, func(t *testing.T) { test.test(t) })
|
||||
}
|
||||
}
|
||||
|
||||
func TestBegin(t *testing.T) {
|
||||
type testCase struct {
|
||||
description string
|
||||
test func(t *testing.T)
|
||||
}
|
||||
|
||||
var tests []testCase
|
||||
|
||||
// Tests with shouldEncryptSet to false
|
||||
tests = append(tests, testCase{description: "BeginTx(), with no errors", test: func(t *testing.T) {
|
||||
c := SetupMockConnection(t)
|
||||
e := SetupMockEncryptor(t)
|
||||
d := SetupMockDecryptor(t)
|
||||
|
||||
sqlTx := &sql.Tx{}
|
||||
c.EXPECT().BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true}).Return(sqlTx, nil)
|
||||
client := SetupClient(t, c, e, d)
|
||||
txC, err := client.BeginTx(context.Background(), false)
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, txC)
|
||||
},
|
||||
})
|
||||
tests = append(tests, testCase{description: "BeginTx(), with forWriting option set", test: func(t *testing.T) {
|
||||
c := SetupMockConnection(t)
|
||||
e := SetupMockEncryptor(t)
|
||||
d := SetupMockDecryptor(t)
|
||||
|
||||
sqlTx := &sql.Tx{}
|
||||
c.EXPECT().BeginTx(context.Background(), &sql.TxOptions{ReadOnly: false}).Return(sqlTx, nil)
|
||||
client := SetupClient(t, c, e, d)
|
||||
txC, err := client.BeginTx(context.Background(), true)
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, txC)
|
||||
},
|
||||
})
|
||||
tests = append(tests, testCase{description: "BeginTx(), with connection Begin() error", test: func(t *testing.T) {
|
||||
c := SetupMockConnection(t)
|
||||
e := SetupMockEncryptor(t)
|
||||
d := SetupMockDecryptor(t)
|
||||
|
||||
c.EXPECT().BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true}).Return(nil, fmt.Errorf("error"))
|
||||
client := SetupClient(t, c, e, d)
|
||||
_, err := client.BeginTx(context.Background(), false)
|
||||
assert.NotNil(t, err)
|
||||
},
|
||||
})
|
||||
t.Parallel()
|
||||
for _, test := range tests {
|
||||
t.Run(test.description, func(t *testing.T) { test.test(t) })
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpsert(t *testing.T) {
|
||||
type testCase struct {
|
||||
description string
|
||||
test func(t *testing.T)
|
||||
}
|
||||
|
||||
var tests []testCase
|
||||
|
||||
testObject := testStoreObject{Id: "something", Val: "a"}
|
||||
var keyID uint32 = 5
|
||||
|
||||
// Tests with shouldEncryptSet to true
|
||||
tests = append(tests, testCase{description: "Upsert() with no errors", test: func(t *testing.T) {
|
||||
c := SetupMockConnection(t)
|
||||
e := SetupMockEncryptor(t)
|
||||
d := SetupMockDecryptor(t)
|
||||
|
||||
client := SetupClient(t, c, e, d)
|
||||
txC := NewMockTXClient(gomock.NewController(t))
|
||||
sqlStmt := &sql.Stmt{}
|
||||
stmt := NewMockStmt(gomock.NewController(t))
|
||||
testObjBytes := toBytes(testObject)
|
||||
testByteValue := []byte("something")
|
||||
e.EXPECT().Encrypt(testObjBytes).Return(testByteValue, testByteValue, keyID, nil)
|
||||
txC.EXPECT().Stmt(sqlStmt).Return(stmt)
|
||||
txC.EXPECT().StmtExec(stmt, "somekey", testByteValue, testByteValue, keyID).Return(nil)
|
||||
err := client.Upsert(txC, sqlStmt, "somekey", testObject, true)
|
||||
assert.Nil(t, err)
|
||||
},
|
||||
})
|
||||
tests = append(tests, testCase{description: "Upsert() with Encrypt() error", test: func(t *testing.T) {
|
||||
c := SetupMockConnection(t)
|
||||
e := SetupMockEncryptor(t)
|
||||
d := SetupMockDecryptor(t)
|
||||
|
||||
client := SetupClient(t, c, e, d)
|
||||
txC := NewMockTXClient(gomock.NewController(t))
|
||||
sqlStmt := &sql.Stmt{}
|
||||
testObjBytes := toBytes(testObject)
|
||||
e.EXPECT().Encrypt(testObjBytes).Return(nil, nil, uint32(0), fmt.Errorf("error"))
|
||||
err := client.Upsert(txC, sqlStmt, "somekey", testObject, true)
|
||||
assert.NotNil(t, err)
|
||||
},
|
||||
})
|
||||
tests = append(tests, testCase{description: "Upsert() with StmtExec() error", test: func(t *testing.T) {
|
||||
c := SetupMockConnection(t)
|
||||
e := SetupMockEncryptor(t)
|
||||
d := SetupMockDecryptor(t)
|
||||
|
||||
client := SetupClient(t, c, e, d)
|
||||
txC := NewMockTXClient(gomock.NewController(t))
|
||||
sqlStmt := &sql.Stmt{}
|
||||
stmt := NewMockStmt(gomock.NewController(t))
|
||||
testObjBytes := toBytes(testObject)
|
||||
testByteValue := []byte("something")
|
||||
e.EXPECT().Encrypt(testObjBytes).Return(testByteValue, testByteValue, keyID, nil)
|
||||
txC.EXPECT().Stmt(sqlStmt).Return(stmt)
|
||||
txC.EXPECT().StmtExec(stmt, "somekey", testByteValue, testByteValue, keyID).Return(fmt.Errorf("error"))
|
||||
err := client.Upsert(txC, sqlStmt, "somekey", testObject, true)
|
||||
assert.NotNil(t, err)
|
||||
},
|
||||
})
|
||||
tests = append(tests, testCase{description: "Upsert() with no errors and shouldEncrypt false", test: func(t *testing.T) {
|
||||
c := SetupMockConnection(t)
|
||||
d := SetupMockDecryptor(t)
|
||||
e := SetupMockEncryptor(t)
|
||||
|
||||
client := SetupClient(t, c, e, d)
|
||||
txC := NewMockTXClient(gomock.NewController(t))
|
||||
sqlStmt := &sql.Stmt{}
|
||||
stmt := NewMockStmt(gomock.NewController(t))
|
||||
var testByteValue []byte
|
||||
testObjBytes := toBytes(testObject)
|
||||
txC.EXPECT().Stmt(sqlStmt).Return(stmt)
|
||||
txC.EXPECT().StmtExec(stmt, "somekey", testObjBytes, testByteValue, uint32(0)).Return(nil)
|
||||
err := client.Upsert(txC, sqlStmt, "somekey", testObject, false)
|
||||
assert.Nil(t, err)
|
||||
},
|
||||
})
|
||||
t.Parallel()
|
||||
for _, test := range tests {
|
||||
t.Run(test.description, func(t *testing.T) { test.test(t) })
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrepare(t *testing.T) {
|
||||
type testCase struct {
|
||||
description string
|
||||
test func(t *testing.T)
|
||||
}
|
||||
|
||||
var tests []testCase
|
||||
tests = append(tests, testCase{description: "Prepare() with no errors", test: func(t *testing.T) {
|
||||
c := SetupMockConnection(t)
|
||||
e := SetupMockEncryptor(t)
|
||||
d := SetupMockDecryptor(t)
|
||||
|
||||
client := SetupClient(t, c, e, d)
|
||||
sqlStmt := &sql.Stmt{}
|
||||
c.EXPECT().Prepare("something").Return(sqlStmt, nil)
|
||||
|
||||
stmt := client.Prepare("something")
|
||||
assert.Equal(t, sqlStmt, stmt)
|
||||
},
|
||||
})
|
||||
tests = append(tests, testCase{description: "Prepare() with Connection Prepare() error", test: func(t *testing.T) {
|
||||
c := SetupMockConnection(t)
|
||||
e := SetupMockEncryptor(t)
|
||||
d := SetupMockDecryptor(t)
|
||||
|
||||
client := SetupClient(t, c, e, d)
|
||||
c.EXPECT().Prepare("something").Return(nil, fmt.Errorf("error"))
|
||||
|
||||
assert.Panics(t, func() { client.Prepare("something") })
|
||||
},
|
||||
})
|
||||
t.Parallel()
|
||||
for _, test := range tests {
|
||||
t.Run(test.description, func(t *testing.T) { test.test(t) })
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewConnection(t *testing.T) {
|
||||
type testCase struct {
|
||||
description string
|
||||
test func(t *testing.T)
|
||||
}
|
||||
|
||||
var tests []testCase
|
||||
tests = append(tests, testCase{description: "NewConnection replaces file", test: func(t *testing.T) {
|
||||
c := SetupMockConnection(t)
|
||||
e := SetupMockEncryptor(t)
|
||||
d := SetupMockDecryptor(t)
|
||||
|
||||
client := SetupClient(t, c, e, d)
|
||||
c.EXPECT().Close().Return(nil)
|
||||
|
||||
err := client.NewConnection()
|
||||
assert.Nil(t, err)
|
||||
|
||||
// Create a transaction to ensure that the file is written to disk.
|
||||
txC, err := client.BeginTx(context.Background(), false)
|
||||
assert.NoError(t, err)
|
||||
assert.NoError(t, txC.Commit())
|
||||
|
||||
assert.FileExists(t, InformerObjectCacheDBPath)
|
||||
assertFileHasPermissions(t, InformerObjectCacheDBPath, 0600)
|
||||
|
||||
err = os.Remove(InformerObjectCacheDBPath)
|
||||
if err != nil {
|
||||
assert.Fail(t, "could not remove object cache path after test")
|
||||
}
|
||||
},
|
||||
})
|
||||
|
||||
t.Parallel()
|
||||
for _, test := range tests {
|
||||
t.Run(test.description, func(t *testing.T) { test.test(t) })
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommit(t *testing.T) {
|
||||
|
||||
}
|
||||
|
||||
func TestRollback(t *testing.T) {
|
||||
|
||||
}
|
||||
|
||||
func SetupMockConnection(t *testing.T) *MockConnection {
|
||||
mockC := NewMockConnection(gomock.NewController(t))
|
||||
return mockC
|
||||
}
|
||||
|
||||
func SetupMockEncryptor(t *testing.T) *MockEncryptor {
|
||||
mockE := NewMockEncryptor(gomock.NewController(t))
|
||||
return mockE
|
||||
}
|
||||
|
||||
func SetupMockDecryptor(t *testing.T) *MockDecryptor {
|
||||
MockD := NewMockDecryptor(gomock.NewController(t))
|
||||
return MockD
|
||||
}
|
||||
|
||||
func SetupMockRows(t *testing.T) *MockRows {
|
||||
MockR := NewMockRows(gomock.NewController(t))
|
||||
return MockR
|
||||
}
|
||||
|
||||
func SetupClient(t *testing.T, connection Connection, encryptor Encryptor, decryptor Decryptor) *Client {
|
||||
c, _ := NewClient(connection, encryptor, decryptor)
|
||||
return c
|
||||
}
|
||||
|
||||
func TestTouchFile(t *testing.T) {
|
||||
t.Run("File doesn't exist before", func(t *testing.T) {
|
||||
filename := filepath.Join(t.TempDir(), "test1.txt")
|
||||
assert.NoError(t, touchFile(filename, 0600))
|
||||
assertFileHasPermissions(t, filename, 0600)
|
||||
})
|
||||
|
||||
t.Run("File exists with different permissions", func(t *testing.T) {
|
||||
filename := filepath.Join(t.TempDir(), "test2.txt")
|
||||
assert.NoError(t, os.WriteFile(filename, []byte("test"), 0644))
|
||||
assert.NoError(t, touchFile(filename, 0600))
|
||||
assertFileHasPermissions(t, filename, 0600)
|
||||
})
|
||||
}
|
||||
|
||||
func assertFileHasPermissions(t *testing.T, fname string, wantPerms fs.FileMode) bool {
|
||||
t.Helper()
|
||||
info, err := os.Lstat(fname)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return assert.Fail(t, fmt.Sprintf("unable to find file %q", fname))
|
||||
}
|
||||
return assert.Fail(t, fmt.Sprintf("error when running os.Lstat(%q): %s", fname, err))
|
||||
}
|
||||
|
||||
// Stringifying the perms makes it easier to read than a Hex comparison.
|
||||
assert.Equal(t, wantPerms.String(), info.Mode().Perm().String())
|
||||
|
||||
return true
|
||||
}
|
370
pkg/sqlcache/db/db_mocks_test.go
Normal file
370
pkg/sqlcache/db/db_mocks_test.go
Normal file
@@ -0,0 +1,370 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/rancher/steve/pkg/sqlcache/db (interfaces: Rows,Connection,Encryptor,Decryptor,TXClient)
|
||||
//
|
||||
// Generated by this command:
|
||||
//
|
||||
// mockgen --build_flags=--mod=mod -package db -destination ./db_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db Rows,Connection,Encryptor,Decryptor,TXClient
|
||||
//
|
||||
|
||||
// Package db is a generated GoMock package.
|
||||
package db
|
||||
|
||||
import (
|
||||
context "context"
|
||||
sql "database/sql"
|
||||
reflect "reflect"
|
||||
|
||||
transaction "github.com/rancher/steve/pkg/sqlcache/db/transaction"
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
// MockRows is a mock of Rows interface.
|
||||
type MockRows struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockRowsMockRecorder
|
||||
}
|
||||
|
||||
// MockRowsMockRecorder is the mock recorder for MockRows.
|
||||
type MockRowsMockRecorder struct {
|
||||
mock *MockRows
|
||||
}
|
||||
|
||||
// NewMockRows creates a new mock instance.
|
||||
func NewMockRows(ctrl *gomock.Controller) *MockRows {
|
||||
mock := &MockRows{ctrl: ctrl}
|
||||
mock.recorder = &MockRowsMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockRows) EXPECT() *MockRowsMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Close mocks base method.
|
||||
func (m *MockRows) Close() error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Close")
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Close indicates an expected call of Close.
|
||||
func (mr *MockRowsMockRecorder) Close() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockRows)(nil).Close))
|
||||
}
|
||||
|
||||
// Err mocks base method.
|
||||
func (m *MockRows) Err() error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Err")
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Err indicates an expected call of Err.
|
||||
func (mr *MockRowsMockRecorder) Err() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Err", reflect.TypeOf((*MockRows)(nil).Err))
|
||||
}
|
||||
|
||||
// Next mocks base method.
|
||||
func (m *MockRows) Next() bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Next")
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Next indicates an expected call of Next.
|
||||
func (mr *MockRowsMockRecorder) Next() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Next", reflect.TypeOf((*MockRows)(nil).Next))
|
||||
}
|
||||
|
||||
// Scan mocks base method.
|
||||
func (m *MockRows) Scan(arg0 ...any) error {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{}
|
||||
for _, a := range arg0 {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "Scan", varargs...)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Scan indicates an expected call of Scan.
|
||||
func (mr *MockRowsMockRecorder) Scan(arg0 ...any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Scan", reflect.TypeOf((*MockRows)(nil).Scan), arg0...)
|
||||
}
|
||||
|
||||
// MockConnection is a mock of Connection interface.
|
||||
type MockConnection struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockConnectionMockRecorder
|
||||
}
|
||||
|
||||
// MockConnectionMockRecorder is the mock recorder for MockConnection.
|
||||
type MockConnectionMockRecorder struct {
|
||||
mock *MockConnection
|
||||
}
|
||||
|
||||
// NewMockConnection creates a new mock instance.
|
||||
func NewMockConnection(ctrl *gomock.Controller) *MockConnection {
|
||||
mock := &MockConnection{ctrl: ctrl}
|
||||
mock.recorder = &MockConnectionMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockConnection) EXPECT() *MockConnectionMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// BeginTx mocks base method.
|
||||
func (m *MockConnection) BeginTx(arg0 context.Context, arg1 *sql.TxOptions) (*sql.Tx, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "BeginTx", arg0, arg1)
|
||||
ret0, _ := ret[0].(*sql.Tx)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// BeginTx indicates an expected call of BeginTx.
|
||||
func (mr *MockConnectionMockRecorder) BeginTx(arg0, arg1 any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BeginTx", reflect.TypeOf((*MockConnection)(nil).BeginTx), arg0, arg1)
|
||||
}
|
||||
|
||||
// Close mocks base method.
|
||||
func (m *MockConnection) Close() error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Close")
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Close indicates an expected call of Close.
|
||||
func (mr *MockConnectionMockRecorder) Close() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockConnection)(nil).Close))
|
||||
}
|
||||
|
||||
// Exec mocks base method.
|
||||
func (m *MockConnection) Exec(arg0 string, arg1 ...any) (sql.Result, error) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{arg0}
|
||||
for _, a := range arg1 {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "Exec", varargs...)
|
||||
ret0, _ := ret[0].(sql.Result)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Exec indicates an expected call of Exec.
|
||||
func (mr *MockConnectionMockRecorder) Exec(arg0 any, arg1 ...any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]any{arg0}, arg1...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockConnection)(nil).Exec), varargs...)
|
||||
}
|
||||
|
||||
// Prepare mocks base method.
|
||||
func (m *MockConnection) Prepare(arg0 string) (*sql.Stmt, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Prepare", arg0)
|
||||
ret0, _ := ret[0].(*sql.Stmt)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Prepare indicates an expected call of Prepare.
|
||||
func (mr *MockConnectionMockRecorder) Prepare(arg0 any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Prepare", reflect.TypeOf((*MockConnection)(nil).Prepare), arg0)
|
||||
}
|
||||
|
||||
// MockEncryptor is a mock of Encryptor interface.
|
||||
type MockEncryptor struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockEncryptorMockRecorder
|
||||
}
|
||||
|
||||
// MockEncryptorMockRecorder is the mock recorder for MockEncryptor.
|
||||
type MockEncryptorMockRecorder struct {
|
||||
mock *MockEncryptor
|
||||
}
|
||||
|
||||
// NewMockEncryptor creates a new mock instance.
|
||||
func NewMockEncryptor(ctrl *gomock.Controller) *MockEncryptor {
|
||||
mock := &MockEncryptor{ctrl: ctrl}
|
||||
mock.recorder = &MockEncryptorMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockEncryptor) EXPECT() *MockEncryptorMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Encrypt mocks base method.
|
||||
func (m *MockEncryptor) Encrypt(arg0 []byte) ([]byte, []byte, uint32, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Encrypt", arg0)
|
||||
ret0, _ := ret[0].([]byte)
|
||||
ret1, _ := ret[1].([]byte)
|
||||
ret2, _ := ret[2].(uint32)
|
||||
ret3, _ := ret[3].(error)
|
||||
return ret0, ret1, ret2, ret3
|
||||
}
|
||||
|
||||
// Encrypt indicates an expected call of Encrypt.
|
||||
func (mr *MockEncryptorMockRecorder) Encrypt(arg0 any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Encrypt", reflect.TypeOf((*MockEncryptor)(nil).Encrypt), arg0)
|
||||
}
|
||||
|
||||
// MockDecryptor is a mock of Decryptor interface.
|
||||
type MockDecryptor struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockDecryptorMockRecorder
|
||||
}
|
||||
|
||||
// MockDecryptorMockRecorder is the mock recorder for MockDecryptor.
|
||||
type MockDecryptorMockRecorder struct {
|
||||
mock *MockDecryptor
|
||||
}
|
||||
|
||||
// NewMockDecryptor creates a new mock instance.
|
||||
func NewMockDecryptor(ctrl *gomock.Controller) *MockDecryptor {
|
||||
mock := &MockDecryptor{ctrl: ctrl}
|
||||
mock.recorder = &MockDecryptorMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockDecryptor) EXPECT() *MockDecryptorMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Decrypt mocks base method.
|
||||
func (m *MockDecryptor) Decrypt(arg0, arg1 []byte, arg2 uint32) ([]byte, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Decrypt", arg0, arg1, arg2)
|
||||
ret0, _ := ret[0].([]byte)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Decrypt indicates an expected call of Decrypt.
|
||||
func (mr *MockDecryptorMockRecorder) Decrypt(arg0, arg1, arg2 any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Decrypt", reflect.TypeOf((*MockDecryptor)(nil).Decrypt), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// MockTXClient is a mock of TXClient interface.
|
||||
type MockTXClient struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockTXClientMockRecorder
|
||||
}
|
||||
|
||||
// MockTXClientMockRecorder is the mock recorder for MockTXClient.
|
||||
type MockTXClientMockRecorder struct {
|
||||
mock *MockTXClient
|
||||
}
|
||||
|
||||
// NewMockTXClient creates a new mock instance.
|
||||
func NewMockTXClient(ctrl *gomock.Controller) *MockTXClient {
|
||||
mock := &MockTXClient{ctrl: ctrl}
|
||||
mock.recorder = &MockTXClientMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockTXClient) EXPECT() *MockTXClientMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Cancel mocks base method.
|
||||
func (m *MockTXClient) Cancel() error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Cancel")
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Cancel indicates an expected call of Cancel.
|
||||
func (mr *MockTXClientMockRecorder) Cancel() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Cancel", reflect.TypeOf((*MockTXClient)(nil).Cancel))
|
||||
}
|
||||
|
||||
// Commit mocks base method.
|
||||
func (m *MockTXClient) Commit() error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Commit")
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Commit indicates an expected call of Commit.
|
||||
func (mr *MockTXClientMockRecorder) Commit() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockTXClient)(nil).Commit))
|
||||
}
|
||||
|
||||
// Exec mocks base method.
|
||||
func (m *MockTXClient) Exec(arg0 string, arg1 ...any) error {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{arg0}
|
||||
for _, a := range arg1 {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "Exec", varargs...)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Exec indicates an expected call of Exec.
|
||||
func (mr *MockTXClientMockRecorder) Exec(arg0 any, arg1 ...any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]any{arg0}, arg1...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockTXClient)(nil).Exec), varargs...)
|
||||
}
|
||||
|
||||
// Stmt mocks base method.
|
||||
func (m *MockTXClient) Stmt(arg0 *sql.Stmt) transaction.Stmt {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Stmt", arg0)
|
||||
ret0, _ := ret[0].(transaction.Stmt)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Stmt indicates an expected call of Stmt.
|
||||
func (mr *MockTXClientMockRecorder) Stmt(arg0 any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stmt", reflect.TypeOf((*MockTXClient)(nil).Stmt), arg0)
|
||||
}
|
||||
|
||||
// StmtExec mocks base method.
|
||||
func (m *MockTXClient) StmtExec(arg0 transaction.Stmt, arg1 ...any) error {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{arg0}
|
||||
for _, a := range arg1 {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "StmtExec", varargs...)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// StmtExec indicates an expected call of StmtExec.
|
||||
func (mr *MockTXClientMockRecorder) StmtExec(arg0 any, arg1 ...any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]any{arg0}, arg1...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StmtExec", reflect.TypeOf((*MockTXClient)(nil).StmtExec), varargs...)
|
||||
}
|
90
pkg/sqlcache/db/transaction/transaction.go
Normal file
90
pkg/sqlcache/db/transaction/transaction.go
Normal file
@@ -0,0 +1,90 @@
|
||||
/*
|
||||
Package transaction provides a client for a live transaction, and interfaces for some relevant sql types. The transaction client automatically performs rollbacks on failures.
|
||||
The use of this package simplifies testing for callers by making the underlying transaction mock-able.
|
||||
*/
|
||||
package transaction
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Client provides a way to interact with the underlying sql transaction.
|
||||
type Client struct {
|
||||
sqlTx SQLTx
|
||||
}
|
||||
|
||||
// SQLTx represents a sql transaction
|
||||
type SQLTx interface {
|
||||
Exec(query string, args ...any) (sql.Result, error)
|
||||
Stmt(stmt *sql.Stmt) *sql.Stmt
|
||||
Commit() error
|
||||
Rollback() error
|
||||
}
|
||||
|
||||
// Stmt represents a sql stmt. It is used as a return type to offer some testability over returning sql's Stmt type
|
||||
// because we are able to mock its outputs and do not need an actual connection.
|
||||
type Stmt interface {
|
||||
Exec(args ...any) (sql.Result, error)
|
||||
Query(args ...any) (*sql.Rows, error)
|
||||
QueryContext(ctx context.Context, args ...any) (*sql.Rows, error)
|
||||
}
|
||||
|
||||
// NewClient returns a Client with the given transaction assigned.
|
||||
func NewClient(tx SQLTx) *Client {
|
||||
return &Client{sqlTx: tx}
|
||||
}
|
||||
|
||||
// Commit commits the transaction and then unlocks the database.
|
||||
func (c *Client) Commit() error {
|
||||
return c.sqlTx.Commit()
|
||||
}
|
||||
|
||||
// Exec uses the sqlTX Exec() with the given stmt and args. The transaction will be automatically rolled back if Exec()
|
||||
// returns an error.
|
||||
func (c *Client) Exec(stmt string, args ...any) error {
|
||||
_, err := c.sqlTx.Exec(stmt, args...)
|
||||
if err != nil {
|
||||
return c.rollback(c.sqlTx, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stmt adds the given sql.Stmt to the client's transaction and then returns a Stmt. An interface is being returned
|
||||
// here to aid in testing callers by providing a way to configure the statement's behavior.
|
||||
func (c *Client) Stmt(stmt *sql.Stmt) Stmt {
|
||||
s := c.sqlTx.Stmt(stmt)
|
||||
return s
|
||||
}
|
||||
|
||||
// StmtExec Execs the given statement with the given args. It assumes the stmt has been added to the transaction. The
|
||||
// transaction is rolled back if Stmt.Exec() returns an error.
|
||||
func (c *Client) StmtExec(stmt Stmt, args ...any) error {
|
||||
_, err := stmt.Exec(args...)
|
||||
if err != nil {
|
||||
return c.rollback(c.sqlTx, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// rollback handles rollbacks and wraps errors if needed
|
||||
func (c *Client) rollback(tx SQLTx, err error) error {
|
||||
rerr := tx.Rollback()
|
||||
if rerr != nil {
|
||||
return errors.Wrapf(err, "Encountered error, then encountered another error while rolling back: %v", rerr)
|
||||
}
|
||||
return errors.Wrapf(err, "Encountered error, successfully rolled back")
|
||||
}
|
||||
|
||||
// Cancel rollbacks the transaction without wrapping an error. This only needs to be called if Client has not returned
|
||||
// an error yet or has not committed. Otherwise, transaction has already rolled back, or in the case of Commit() it is too
|
||||
// late.
|
||||
func (c *Client) Cancel() error {
|
||||
rerr := c.sqlTx.Rollback()
|
||||
if rerr != sql.ErrTxDone {
|
||||
return rerr
|
||||
}
|
||||
return nil
|
||||
}
|
184
pkg/sqlcache/db/transaction/transaction_mocks_test.go
Normal file
184
pkg/sqlcache/db/transaction/transaction_mocks_test.go
Normal file
@@ -0,0 +1,184 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/rancher/steve/pkg/sqlcache/db/transaction (interfaces: Stmt,SQLTx)
|
||||
//
|
||||
// Generated by this command:
|
||||
//
|
||||
// mockgen --build_flags=--mod=mod -package transaction -destination ./transaction_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db/transaction Stmt,SQLTx
|
||||
//
|
||||
|
||||
// Package transaction is a generated GoMock package.
|
||||
package transaction
|
||||
|
||||
import (
|
||||
context "context"
|
||||
sql "database/sql"
|
||||
reflect "reflect"
|
||||
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
// MockStmt is a mock of Stmt interface.
|
||||
type MockStmt struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockStmtMockRecorder
|
||||
}
|
||||
|
||||
// MockStmtMockRecorder is the mock recorder for MockStmt.
|
||||
type MockStmtMockRecorder struct {
|
||||
mock *MockStmt
|
||||
}
|
||||
|
||||
// NewMockStmt creates a new mock instance.
|
||||
func NewMockStmt(ctrl *gomock.Controller) *MockStmt {
|
||||
mock := &MockStmt{ctrl: ctrl}
|
||||
mock.recorder = &MockStmtMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockStmt) EXPECT() *MockStmtMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Exec mocks base method.
|
||||
func (m *MockStmt) Exec(arg0 ...any) (sql.Result, error) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{}
|
||||
for _, a := range arg0 {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "Exec", varargs...)
|
||||
ret0, _ := ret[0].(sql.Result)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Exec indicates an expected call of Exec.
|
||||
func (mr *MockStmtMockRecorder) Exec(arg0 ...any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockStmt)(nil).Exec), arg0...)
|
||||
}
|
||||
|
||||
// Query mocks base method.
|
||||
func (m *MockStmt) Query(arg0 ...any) (*sql.Rows, error) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{}
|
||||
for _, a := range arg0 {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "Query", varargs...)
|
||||
ret0, _ := ret[0].(*sql.Rows)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Query indicates an expected call of Query.
|
||||
func (mr *MockStmtMockRecorder) Query(arg0 ...any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockStmt)(nil).Query), arg0...)
|
||||
}
|
||||
|
||||
// QueryContext mocks base method.
|
||||
func (m *MockStmt) QueryContext(arg0 context.Context, arg1 ...any) (*sql.Rows, error) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{arg0}
|
||||
for _, a := range arg1 {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "QueryContext", varargs...)
|
||||
ret0, _ := ret[0].(*sql.Rows)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// QueryContext indicates an expected call of QueryContext.
|
||||
func (mr *MockStmtMockRecorder) QueryContext(arg0 any, arg1 ...any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]any{arg0}, arg1...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryContext", reflect.TypeOf((*MockStmt)(nil).QueryContext), varargs...)
|
||||
}
|
||||
|
||||
// MockSQLTx is a mock of SQLTx interface.
|
||||
type MockSQLTx struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockSQLTxMockRecorder
|
||||
}
|
||||
|
||||
// MockSQLTxMockRecorder is the mock recorder for MockSQLTx.
|
||||
type MockSQLTxMockRecorder struct {
|
||||
mock *MockSQLTx
|
||||
}
|
||||
|
||||
// NewMockSQLTx creates a new mock instance.
|
||||
func NewMockSQLTx(ctrl *gomock.Controller) *MockSQLTx {
|
||||
mock := &MockSQLTx{ctrl: ctrl}
|
||||
mock.recorder = &MockSQLTxMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockSQLTx) EXPECT() *MockSQLTxMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Commit mocks base method.
|
||||
func (m *MockSQLTx) Commit() error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Commit")
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Commit indicates an expected call of Commit.
|
||||
func (mr *MockSQLTxMockRecorder) Commit() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockSQLTx)(nil).Commit))
|
||||
}
|
||||
|
||||
// Exec mocks base method.
|
||||
func (m *MockSQLTx) Exec(arg0 string, arg1 ...any) (sql.Result, error) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{arg0}
|
||||
for _, a := range arg1 {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "Exec", varargs...)
|
||||
ret0, _ := ret[0].(sql.Result)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Exec indicates an expected call of Exec.
|
||||
func (mr *MockSQLTxMockRecorder) Exec(arg0 any, arg1 ...any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]any{arg0}, arg1...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockSQLTx)(nil).Exec), varargs...)
|
||||
}
|
||||
|
||||
// Rollback mocks base method.
|
||||
func (m *MockSQLTx) Rollback() error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Rollback")
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Rollback indicates an expected call of Rollback.
|
||||
func (mr *MockSQLTxMockRecorder) Rollback() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Rollback", reflect.TypeOf((*MockSQLTx)(nil).Rollback))
|
||||
}
|
||||
|
||||
// Stmt mocks base method.
|
||||
func (m *MockSQLTx) Stmt(arg0 *sql.Stmt) *sql.Stmt {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Stmt", arg0)
|
||||
ret0, _ := ret[0].(*sql.Stmt)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Stmt indicates an expected call of Stmt.
|
||||
func (mr *MockSQLTxMockRecorder) Stmt(arg0 any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stmt", reflect.TypeOf((*MockSQLTx)(nil).Stmt), arg0)
|
||||
}
|
182
pkg/sqlcache/db/transaction/transaction_test.go
Normal file
182
pkg/sqlcache/db/transaction/transaction_test.go
Normal file
@@ -0,0 +1,182 @@
|
||||
package transaction
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
//go:generate mockgen --build_flags=--mod=mod -package transaction -destination ./transaction_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db/transaction Stmt,SQLTx
|
||||
|
||||
func TestNewClient(t *testing.T) {
|
||||
tx := NewMockSQLTx(gomock.NewController(t))
|
||||
c := NewClient(tx)
|
||||
assert.Equal(t, tx, c.sqlTx)
|
||||
}
|
||||
|
||||
func TestCommit(t *testing.T) {
|
||||
type testCase struct {
|
||||
description string
|
||||
test func(t *testing.T)
|
||||
}
|
||||
|
||||
var tests []testCase
|
||||
|
||||
tests = append(tests, testCase{description: "Commit() with no errors returned from sql TX should return no error", test: func(t *testing.T) {
|
||||
tx := NewMockSQLTx(gomock.NewController(t))
|
||||
tx.EXPECT().Commit().Return(nil)
|
||||
c := &Client{
|
||||
sqlTx: tx,
|
||||
}
|
||||
err := c.Commit()
|
||||
assert.Nil(t, err)
|
||||
}})
|
||||
tests = append(tests, testCase{description: "Commit() with error from sql TX commit() should return error", test: func(t *testing.T) {
|
||||
tx := NewMockSQLTx(gomock.NewController(t))
|
||||
tx.EXPECT().Commit().Return(fmt.Errorf("error"))
|
||||
c := &Client{
|
||||
sqlTx: tx,
|
||||
}
|
||||
err := c.Commit()
|
||||
assert.NotNil(t, err)
|
||||
}})
|
||||
t.Parallel()
|
||||
for _, test := range tests {
|
||||
t.Run(test.description, func(t *testing.T) { test.test(t) })
|
||||
}
|
||||
}
|
||||
|
||||
func TestExec(t *testing.T) {
|
||||
type testCase struct {
|
||||
description string
|
||||
test func(t *testing.T)
|
||||
}
|
||||
|
||||
var tests []testCase
|
||||
|
||||
tests = append(tests, testCase{description: "Exec() with no errors returned from sql TX should return no error", test: func(t *testing.T) {
|
||||
tx := NewMockSQLTx(gomock.NewController(t))
|
||||
stmtStr := "some statement %s"
|
||||
arg := 5
|
||||
// should be passed same statement and arg that was passed to parent function
|
||||
tx.EXPECT().Exec(stmtStr, arg).Return(nil, nil)
|
||||
c := &Client{
|
||||
sqlTx: tx,
|
||||
}
|
||||
err := c.Exec(stmtStr, arg)
|
||||
assert.Nil(t, err)
|
||||
}})
|
||||
tests = append(tests, testCase{description: "Exec() with error returned from sql TX Exec() and Rollback() error should return an error", test: func(t *testing.T) {
|
||||
tx := NewMockSQLTx(gomock.NewController(t))
|
||||
stmtStr := "some statement %s"
|
||||
arg := 5
|
||||
// should be passed same statement and arg that was passed to parent function
|
||||
tx.EXPECT().Exec(stmtStr, arg).Return(nil, fmt.Errorf("error"))
|
||||
tx.EXPECT().Rollback().Return(nil)
|
||||
c := &Client{
|
||||
sqlTx: tx,
|
||||
}
|
||||
err := c.Exec(stmtStr, arg)
|
||||
assert.NotNil(t, err)
|
||||
}})
|
||||
tests = append(tests, testCase{description: "Exec() with error returned from sql TX Exec() and Rollback() error should return an error", test: func(t *testing.T) {
|
||||
tx := NewMockSQLTx(gomock.NewController(t))
|
||||
stmtStr := "some statement %s"
|
||||
arg := 5
|
||||
// should be passed same statement and arg that was passed to parent function
|
||||
tx.EXPECT().Exec(stmtStr, arg).Return(nil, fmt.Errorf("error"))
|
||||
tx.EXPECT().Rollback().Return(fmt.Errorf("error"))
|
||||
c := &Client{
|
||||
sqlTx: tx,
|
||||
}
|
||||
err := c.Exec(stmtStr, arg)
|
||||
assert.NotNil(t, err)
|
||||
}})
|
||||
t.Parallel()
|
||||
for _, test := range tests {
|
||||
t.Run(test.description, func(t *testing.T) { test.test(t) })
|
||||
}
|
||||
}
|
||||
|
||||
func TestStmt(t *testing.T) {
|
||||
type testCase struct {
|
||||
description string
|
||||
test func(t *testing.T)
|
||||
}
|
||||
|
||||
var tests []testCase
|
||||
|
||||
tests = append(tests, testCase{description: "Exec() with no errors returned from sql TX should return no error", test: func(t *testing.T) {
|
||||
tx := NewMockSQLTx(gomock.NewController(t))
|
||||
stmt := &sql.Stmt{}
|
||||
var returnedTXStmt *sql.Stmt
|
||||
// should be passed same statement and arg that was passed to parent function
|
||||
tx.EXPECT().Stmt(stmt).Return(returnedTXStmt)
|
||||
c := &Client{
|
||||
sqlTx: tx,
|
||||
}
|
||||
returnedStmt := c.Stmt(stmt)
|
||||
// whatever tx returned should be returned here. Nil was used because none of sql.Stmt's fields are exported so its simpler to test nil as it
|
||||
// won't be equal to an empty struct
|
||||
assert.Equal(t, returnedTXStmt, returnedStmt)
|
||||
}})
|
||||
t.Parallel()
|
||||
for _, test := range tests {
|
||||
t.Run(test.description, func(t *testing.T) { test.test(t) })
|
||||
}
|
||||
}
|
||||
|
||||
func TestStmtExec(t *testing.T) {
|
||||
type testCase struct {
|
||||
description string
|
||||
test func(t *testing.T)
|
||||
}
|
||||
|
||||
var tests []testCase
|
||||
|
||||
tests = append(tests, testCase{description: "StmtExec with no errors returned from Stmt should return no error", test: func(t *testing.T) {
|
||||
tx := NewMockSQLTx(gomock.NewController(t))
|
||||
stmt := NewMockStmt(gomock.NewController(t))
|
||||
arg := "something"
|
||||
// should be passed same arg that was passed to parent function
|
||||
stmt.EXPECT().Exec(arg).Return(nil, nil)
|
||||
c := &Client{
|
||||
sqlTx: tx,
|
||||
}
|
||||
err := c.StmtExec(stmt, arg)
|
||||
assert.Nil(t, err)
|
||||
}})
|
||||
tests = append(tests, testCase{description: "StmtExec with error returned from Stmt Exec and no Tx Rollback() error should return error", test: func(t *testing.T) {
|
||||
tx := NewMockSQLTx(gomock.NewController(t))
|
||||
stmt := NewMockStmt(gomock.NewController(t))
|
||||
arg := "something"
|
||||
// should be passed same arg that was passed to parent function
|
||||
stmt.EXPECT().Exec(arg).Return(nil, fmt.Errorf("error"))
|
||||
tx.EXPECT().Rollback().Return(nil)
|
||||
c := &Client{
|
||||
sqlTx: tx,
|
||||
}
|
||||
err := c.StmtExec(stmt, arg)
|
||||
assert.NotNil(t, err)
|
||||
}})
|
||||
tests = append(tests, testCase{description: "StmtExec with error returned from Stmt Exec and Tx Rollback() error should return error", test: func(t *testing.T) {
|
||||
tx := NewMockSQLTx(gomock.NewController(t))
|
||||
stmt := NewMockStmt(gomock.NewController(t))
|
||||
arg := "something"
|
||||
// should be passed same arg that was passed to parent function
|
||||
stmt.EXPECT().Exec(arg).Return(nil, fmt.Errorf("error"))
|
||||
tx.EXPECT().Rollback().Return(fmt.Errorf("error2"))
|
||||
c := &Client{
|
||||
sqlTx: tx,
|
||||
}
|
||||
err := c.StmtExec(stmt, arg)
|
||||
assert.NotNil(t, err)
|
||||
}})
|
||||
t.Parallel()
|
||||
for _, test := range tests {
|
||||
t.Run(test.description, func(t *testing.T) { test.test(t) })
|
||||
}
|
||||
}
|
184
pkg/sqlcache/db/transaction_mocks_test.go
Normal file
184
pkg/sqlcache/db/transaction_mocks_test.go
Normal file
@@ -0,0 +1,184 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/rancher/steve/pkg/sqlcache/db/transaction (interfaces: Stmt,SQLTx)
|
||||
//
|
||||
// Generated by this command:
|
||||
//
|
||||
// mockgen --build_flags=--mod=mod -package db -destination ./transaction_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db/transaction Stmt,SQLTx
|
||||
//
|
||||
|
||||
// Package db is a generated GoMock package.
|
||||
package db
|
||||
|
||||
import (
|
||||
context "context"
|
||||
sql "database/sql"
|
||||
reflect "reflect"
|
||||
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
// MockStmt is a mock of Stmt interface.
|
||||
type MockStmt struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockStmtMockRecorder
|
||||
}
|
||||
|
||||
// MockStmtMockRecorder is the mock recorder for MockStmt.
|
||||
type MockStmtMockRecorder struct {
|
||||
mock *MockStmt
|
||||
}
|
||||
|
||||
// NewMockStmt creates a new mock instance.
|
||||
func NewMockStmt(ctrl *gomock.Controller) *MockStmt {
|
||||
mock := &MockStmt{ctrl: ctrl}
|
||||
mock.recorder = &MockStmtMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockStmt) EXPECT() *MockStmtMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Exec mocks base method.
|
||||
func (m *MockStmt) Exec(arg0 ...any) (sql.Result, error) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{}
|
||||
for _, a := range arg0 {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "Exec", varargs...)
|
||||
ret0, _ := ret[0].(sql.Result)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Exec indicates an expected call of Exec.
|
||||
func (mr *MockStmtMockRecorder) Exec(arg0 ...any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockStmt)(nil).Exec), arg0...)
|
||||
}
|
||||
|
||||
// Query mocks base method.
|
||||
func (m *MockStmt) Query(arg0 ...any) (*sql.Rows, error) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{}
|
||||
for _, a := range arg0 {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "Query", varargs...)
|
||||
ret0, _ := ret[0].(*sql.Rows)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Query indicates an expected call of Query.
|
||||
func (mr *MockStmtMockRecorder) Query(arg0 ...any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockStmt)(nil).Query), arg0...)
|
||||
}
|
||||
|
||||
// QueryContext mocks base method.
|
||||
func (m *MockStmt) QueryContext(arg0 context.Context, arg1 ...any) (*sql.Rows, error) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{arg0}
|
||||
for _, a := range arg1 {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "QueryContext", varargs...)
|
||||
ret0, _ := ret[0].(*sql.Rows)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// QueryContext indicates an expected call of QueryContext.
|
||||
func (mr *MockStmtMockRecorder) QueryContext(arg0 any, arg1 ...any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]any{arg0}, arg1...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryContext", reflect.TypeOf((*MockStmt)(nil).QueryContext), varargs...)
|
||||
}
|
||||
|
||||
// MockSQLTx is a mock of SQLTx interface.
|
||||
type MockSQLTx struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockSQLTxMockRecorder
|
||||
}
|
||||
|
||||
// MockSQLTxMockRecorder is the mock recorder for MockSQLTx.
|
||||
type MockSQLTxMockRecorder struct {
|
||||
mock *MockSQLTx
|
||||
}
|
||||
|
||||
// NewMockSQLTx creates a new mock instance.
|
||||
func NewMockSQLTx(ctrl *gomock.Controller) *MockSQLTx {
|
||||
mock := &MockSQLTx{ctrl: ctrl}
|
||||
mock.recorder = &MockSQLTxMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockSQLTx) EXPECT() *MockSQLTxMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Commit mocks base method.
|
||||
func (m *MockSQLTx) Commit() error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Commit")
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Commit indicates an expected call of Commit.
|
||||
func (mr *MockSQLTxMockRecorder) Commit() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockSQLTx)(nil).Commit))
|
||||
}
|
||||
|
||||
// Exec mocks base method.
|
||||
func (m *MockSQLTx) Exec(arg0 string, arg1 ...any) (sql.Result, error) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{arg0}
|
||||
for _, a := range arg1 {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "Exec", varargs...)
|
||||
ret0, _ := ret[0].(sql.Result)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Exec indicates an expected call of Exec.
|
||||
func (mr *MockSQLTxMockRecorder) Exec(arg0 any, arg1 ...any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]any{arg0}, arg1...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockSQLTx)(nil).Exec), varargs...)
|
||||
}
|
||||
|
||||
// Rollback mocks base method.
|
||||
func (m *MockSQLTx) Rollback() error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Rollback")
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Rollback indicates an expected call of Rollback.
|
||||
func (mr *MockSQLTxMockRecorder) Rollback() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Rollback", reflect.TypeOf((*MockSQLTx)(nil).Rollback))
|
||||
}
|
||||
|
||||
// Stmt mocks base method.
|
||||
func (m *MockSQLTx) Stmt(arg0 *sql.Stmt) *sql.Stmt {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Stmt", arg0)
|
||||
ret0, _ := ret[0].(*sql.Stmt)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Stmt indicates an expected call of Stmt.
|
||||
func (mr *MockSQLTxMockRecorder) Stmt(arg0 any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stmt", reflect.TypeOf((*MockSQLTx)(nil).Stmt), arg0)
|
||||
}
|
8
pkg/sqlcache/db/utility.go
Normal file
8
pkg/sqlcache/db/utility.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package db
|
||||
|
||||
import "strings"
|
||||
|
||||
// Sanitize returns a string that can be used in SQL as a name
|
||||
func Sanitize(s string) string {
|
||||
return strings.ReplaceAll(s, "\"", "")
|
||||
}
|
Reference in New Issue
Block a user