1
0
mirror of https://github.com/rancher/steve.git synced 2025-09-08 10:49:25 +00:00

[v0.4] Move lasso SQL cache in Steve (#473)

* Copy pkg/cache/sql from lasso to pkg/sqlcache

* Rename import from github.com/rancher/lasso/pkg/cache/sql to github.com/rancher/steve/pkg/sqlcache

* go mod tidy

* Fix lint errors

* Remove lasso SQL cache mentions

* Fix more CI lint errors

* fix goimports

Signed-off-by: Silvio Moioli <silvio@moioli.net>

* Fix imports

* Fix more linting errors

---------

Signed-off-by: Silvio Moioli <silvio@moioli.net>
Co-authored-by: Silvio Moioli <silvio@moioli.net>
This commit is contained in:
Tom Lebreux
2025-02-04 12:42:13 -05:00
committed by GitHub
parent 41674fa0cf
commit d030e42148
53 changed files with 9967 additions and 31 deletions

374
pkg/sqlcache/db/client.go Normal file
View File

@@ -0,0 +1,374 @@
/*
Package db offers client struct and functions to interact with database connection. It provides encrypting, decrypting,
and a way to reset the database.
*/
package db
import (
"bytes"
"context"
"database/sql"
"encoding/gob"
"fmt"
"io/fs"
"os"
"reflect"
"sync"
"github.com/pkg/errors"
"github.com/rancher/steve/pkg/sqlcache/db/transaction"
// needed for drivers
_ "modernc.org/sqlite"
)
const (
// InformerObjectCacheDBPath is where SQLite's object database file will be stored relative to process running steve
InformerObjectCacheDBPath = "informer_object_cache.db"
informerObjectCachePerms fs.FileMode = 0o600
)
// Client is a database client that provides encrypting, decrypting, and database resetting.
type Client struct {
conn Connection
connLock sync.RWMutex
encryptor Encryptor
decryptor Decryptor
}
// Connection represents a connection pool.
type Connection interface {
BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
Exec(query string, args ...any) (sql.Result, error)
Prepare(query string) (*sql.Stmt, error)
Close() error
}
// Closable Closes an underlying connection and returns an error on failure.
type Closable interface {
Close() error
}
// Rows represents sql rows. It exposes method to navigate the rows, read their outputs, and close them.
type Rows interface {
Next() bool
Err() error
Close() error
Scan(dest ...any) error
}
// QueryError encapsulates an error while executing a query
type QueryError struct {
QueryString string
Err error
}
// Error returns a string representation of this QueryError
func (e *QueryError) Error() string {
return "while executing query: " + e.QueryString + " got error: " + e.Err.Error()
}
// Unwrap returns the underlying error
func (e *QueryError) Unwrap() error {
return e.Err
}
// TXClient represents a sql transaction. The TXClient must manage rollbacks as rollback functionality is not exposed.
type TXClient interface {
StmtExec(stmt transaction.Stmt, args ...any) error
Exec(stmt string, args ...any) error
Commit() error
Stmt(stmt *sql.Stmt) transaction.Stmt
Cancel() error
}
// Encryptor encrypts data with a key which is rotated to avoid wear-out.
type Encryptor interface {
// Encrypt encrypts the specified data, returning: the encrypted data, the nonce used to encrypt the data, and an ID identifying the key that was used (as it rotates). On failure error is returned instead.
Encrypt([]byte) ([]byte, []byte, uint32, error)
}
// Decryptor decrypts data previously encrypted by Encryptor.
type Decryptor interface {
// Decrypt accepts a chunk of encrypted data, the nonce used to encrypt it and the ID of the used key (as it rotates). It returns the decrypted data or an error.
Decrypt([]byte, []byte, uint32) ([]byte, error)
}
// NewClient returns a Client. If the given connection is nil then a default one will be created.
func NewClient(c Connection, encryptor Encryptor, decryptor Decryptor) (*Client, error) {
client := &Client{
encryptor: encryptor,
decryptor: decryptor,
}
if c != nil {
client.conn = c
return client, nil
}
err := client.NewConnection()
if err != nil {
return nil, err
}
return client, nil
}
// Prepare prepares the given string into a sql statement on the client's connection.
func (c *Client) Prepare(stmt string) *sql.Stmt {
c.connLock.RLock()
defer c.connLock.RUnlock()
prepared, err := c.conn.Prepare(stmt)
if err != nil {
panic(errors.Errorf("Error preparing statement: %s\n%v", stmt, err))
}
return prepared
}
// QueryForRows queries the given stmt with the given params and returns the resulting rows. The query wil be retried
// given a sqlite busy error.
func (c *Client) QueryForRows(ctx context.Context, stmt transaction.Stmt, params ...any) (*sql.Rows, error) {
c.connLock.RLock()
defer c.connLock.RUnlock()
return stmt.QueryContext(ctx, params...)
}
// CloseStmt will call close on the given Closable. It is intended to be used with a sql statement. This function is meant
// to replace stmt.Close which can cause panics when callers unit-test since there usually is no real underlying connection.
func (c *Client) CloseStmt(closable Closable) error {
return closable.Close()
}
// ReadObjects Scans the given rows, performs any necessary decryption, converts the data to objects of the given type,
// and returns a slice of those objects.
func (c *Client) ReadObjects(rows Rows, typ reflect.Type, shouldDecrypt bool) ([]any, error) {
c.connLock.RLock()
defer c.connLock.RUnlock()
var result []any
for rows.Next() {
data, err := c.decryptScan(rows, shouldDecrypt)
if err != nil {
return nil, closeRowsOnError(rows, err)
}
singleResult, err := fromBytes(data, typ)
if err != nil {
return nil, closeRowsOnError(rows, err)
}
result = append(result, singleResult.Elem().Interface())
}
err := rows.Err()
if err != nil {
return nil, closeRowsOnError(rows, err)
}
err = rows.Close()
if err != nil {
return nil, err
}
return result, nil
}
// ReadStrings scans the given rows into strings, and then returns the strings as a slice.
func (c *Client) ReadStrings(rows Rows) ([]string, error) {
c.connLock.RLock()
defer c.connLock.RUnlock()
var result []string
for rows.Next() {
var key string
err := rows.Scan(&key)
if err != nil {
return nil, closeRowsOnError(rows, err)
}
result = append(result, key)
}
err := rows.Err()
if err != nil {
return nil, closeRowsOnError(rows, err)
}
err = rows.Close()
if err != nil {
return nil, err
}
return result, nil
}
// ReadInt scans the first of the given rows into a single int (eg. for COUNT() queries)
func (c *Client) ReadInt(rows Rows) (int, error) {
c.connLock.RLock()
defer c.connLock.RUnlock()
if !rows.Next() {
return 0, closeRowsOnError(rows, sql.ErrNoRows)
}
var result int
err := rows.Scan(&result)
if err != nil {
return 0, closeRowsOnError(rows, err)
}
err = rows.Err()
if err != nil {
return 0, closeRowsOnError(rows, err)
}
err = rows.Close()
if err != nil {
return 0, err
}
return result, nil
}
// BeginTx attempts to begin a transaction.
// If forWriting is true, this method blocks until all other concurrent forWriting
// transactions have either committed or rolled back.
// If forWriting is false, it is assumed the returned transaction will exclusively
// be used for DQL (eg. SELECT) queries.
// Not respecting the above rule might result in transactions failing with unexpected
// SQLITE_BUSY (5) errors (aka "Runtime error: database is locked").
// See discussion in https://github.com/rancher/lasso/pull/98 for details
func (c *Client) BeginTx(ctx context.Context, forWriting bool) (TXClient, error) {
c.connLock.RLock()
defer c.connLock.RUnlock()
// note: this assumes _txlock=immediate in the connection string, see NewConnection
sqlTx, err := c.conn.BeginTx(ctx, &sql.TxOptions{
ReadOnly: !forWriting,
})
if err != nil {
return nil, err
}
return transaction.NewClient(sqlTx), nil
}
func (c *Client) decryptScan(rows Rows, shouldDecrypt bool) ([]byte, error) {
var data, dataNonce sql.RawBytes
var kid uint32
err := rows.Scan(&data, &dataNonce, &kid)
if err != nil {
return nil, err
}
if c.decryptor != nil && shouldDecrypt {
decryptedData, err := c.decryptor.Decrypt(data, dataNonce, kid)
if err != nil {
return nil, err
}
return decryptedData, nil
}
return data, nil
}
// Upsert used to be called upsertEncrypted in store package before move
func (c *Client) Upsert(tx TXClient, stmt *sql.Stmt, key string, obj any, shouldEncrypt bool) error {
objBytes := toBytes(obj)
var dataNonce []byte
var err error
var kid uint32
if c.encryptor != nil && shouldEncrypt {
objBytes, dataNonce, kid, err = c.encryptor.Encrypt(objBytes)
if err != nil {
return err
}
}
return tx.StmtExec(tx.Stmt(stmt), key, objBytes, dataNonce, kid)
}
// toBytes encodes an object to a byte slice
func toBytes(obj any) []byte {
var buf bytes.Buffer
enc := gob.NewEncoder(&buf)
err := enc.Encode(obj)
if err != nil {
panic(fmt.Errorf("error while gobbing object: %w", err))
}
bb := buf.Bytes()
return bb
}
// fromBytes decodes an object from a byte slice
func fromBytes(buf sql.RawBytes, typ reflect.Type) (reflect.Value, error) {
dec := gob.NewDecoder(bytes.NewReader(buf))
singleResult := reflect.New(typ)
err := dec.DecodeValue(singleResult)
return singleResult, err
}
// closeRowsOnError closes the sql.Rows object and wraps errors if needed
func closeRowsOnError(rows Rows, err error) error {
ce := rows.Close()
if ce != nil {
return fmt.Errorf("error in closing rows while handling %s: %w", err.Error(), ce)
}
return err
}
// NewConnection checks for currently existing connection, closes one if it exists, removes any relevant db files, and opens a new connection which subsequently
// creates new files.
func (c *Client) NewConnection() error {
c.connLock.Lock()
defer c.connLock.Unlock()
if c.conn != nil {
err := c.conn.Close()
if err != nil {
return err
}
}
err := os.RemoveAll(InformerObjectCacheDBPath)
if err != nil {
return err
}
// Set the permissions in advance, because we can't control them if
// the file is created by a sql.Open call instead.
if err := touchFile(InformerObjectCacheDBPath, informerObjectCachePerms); err != nil {
return nil
}
sqlDB, err := sql.Open("sqlite", "file:"+InformerObjectCacheDBPath+"?"+
// open SQLite file in read-write mode, creating it if it does not exist
"mode=rwc&"+
// use the WAL journal mode for consistency and efficiency
"_pragma=journal_mode=wal&"+
// do not even attempt to attain durability. Database is thrown away at pod restart
"_pragma=synchronous=off&"+
// do check foreign keys and honor ON DELETE CASCADE
"_pragma=foreign_keys=on&"+
// if two transactions want to write at the same time, allow 2 minutes for the first to complete
// before baling out
"_pragma=busy_timeout=120000&"+
// default to IMMEDIATE mode for transactions. Setting this parameter is the only current way
// to be able to switch between DEFERRED and IMMEDIATE modes in modernc.org/sqlite's implementation
// of BeginTx
"_txlock=immediate")
if err != nil {
return err
}
c.conn = sqlDB
return nil
}
// This acts like "touch" for both existing files and non-existing files.
// permissions.
//
// It's created with the correct perms, and if the file already exists, it will
// be chmodded to the correct perms.
func touchFile(filename string, perms fs.FileMode) error {
f, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE, perms)
if err != nil {
return err
}
if err := f.Close(); err != nil {
return err
}
return os.Chmod(filename, perms)
}

View File

@@ -0,0 +1,667 @@
package db
import (
"context"
"database/sql"
"fmt"
"io/fs"
"math"
"os"
"path/filepath"
"reflect"
"testing"
"github.com/stretchr/testify/assert"
"go.uber.org/mock/gomock"
)
// Mocks for this test are generated with the following command.
//go:generate mockgen --build_flags=--mod=mod -package db -destination ./db_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db Rows,Connection,Encryptor,Decryptor,TXClient
//go:generate mockgen --build_flags=--mod=mod -package db -destination ./transaction_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db/transaction Stmt,SQLTx
type testStoreObject struct {
Id string
Val string
}
func TestNewClient(t *testing.T) {
type testCase struct {
description string
test func(t *testing.T)
}
var tests []testCase
// Tests with shouldEncryptSet to false
tests = append(tests, testCase{description: "Query rows with no params, no errors", test: func(t *testing.T) {
c := SetupMockConnection(t)
e := SetupMockEncryptor(t)
d := SetupMockDecryptor(t)
expectedClient := &Client{
conn: c,
encryptor: e,
decryptor: d,
}
client, err := NewClient(c, e, d)
assert.Nil(t, err)
assert.Equal(t, expectedClient, client)
},
})
t.Parallel()
for _, test := range tests {
t.Run(test.description, func(t *testing.T) { test.test(t) })
}
}
func TestQueryForRows(t *testing.T) {
type testCase struct {
description string
test func(t *testing.T)
}
var tests []testCase
// Tests with shouldEncryptSet to false
tests = append(tests, testCase{description: "Query rows with no params, no errors", test: func(t *testing.T) {
c := SetupMockConnection(t)
client := SetupClient(t, c, nil, nil)
s := NewMockStmt(gomock.NewController(t))
ctx := context.TODO()
r := &sql.Rows{}
s.EXPECT().QueryContext(ctx).Return(r, nil)
rows, err := client.QueryForRows(ctx, s)
assert.Nil(t, err)
assert.Equal(t, r, rows)
},
})
tests = append(tests, testCase{description: "Query rows with params, QueryContext() error", test: func(t *testing.T) {
c := SetupMockConnection(t)
client := SetupClient(t, c, nil, nil)
s := NewMockStmt(gomock.NewController(t))
ctx := context.TODO()
s.EXPECT().QueryContext(ctx).Return(nil, fmt.Errorf("error"))
_, err := client.QueryForRows(ctx, s)
assert.NotNil(t, err)
},
})
t.Parallel()
for _, test := range tests {
t.Run(test.description, func(t *testing.T) { test.test(t) })
}
}
func TestQueryObjects(t *testing.T) {
type testCase struct {
description string
test func(t *testing.T)
}
var tests []testCase
testObject := testStoreObject{Id: "something", Val: "a"}
var keyId uint32 = math.MaxUint32
// Tests with shouldEncryptSet to false
tests = append(tests, testCase{description: "Query objects, with one row, and no errors", test: func(t *testing.T) {
c := SetupMockConnection(t)
e := SetupMockEncryptor(t)
d := SetupMockDecryptor(t)
r := SetupMockRows(t)
r.EXPECT().Next().Return(true)
r.EXPECT().Scan(gomock.Any()).Do(func(a ...any) {
*a[0].(*sql.RawBytes) = toBytes(testObject)
*a[1].(*sql.RawBytes) = toBytes(testObject)
*a[2].(*uint32) = keyId
})
d.EXPECT().Decrypt(toBytes(testObject), toBytes(testObject), keyId).Return(toBytes(testObject), nil)
r.EXPECT().Err().Return(nil)
r.EXPECT().Next().Return(false)
r.EXPECT().Close().Return(nil)
client := SetupClient(t, c, e, d)
items, err := client.ReadObjects(r, reflect.TypeOf(testObject), true)
assert.Nil(t, err)
assert.Equal(t, 1, len(items))
},
})
tests = append(tests, testCase{description: "Query objects, with one row, and a decrypt error", test: func(t *testing.T) {
c := SetupMockConnection(t)
e := SetupMockEncryptor(t)
d := SetupMockDecryptor(t)
r := SetupMockRows(t)
r.EXPECT().Next().Return(true)
r.EXPECT().Scan(gomock.Any()).Do(func(a ...any) {
*a[0].(*sql.RawBytes) = toBytes(testObject)
*a[1].(*sql.RawBytes) = toBytes(
testObject)
*a[2].(*uint32) = keyId
})
d.EXPECT().Decrypt(toBytes(testObject), toBytes(testObject), keyId).Return(nil, fmt.Errorf("error"))
r.EXPECT().Close().Return(nil)
client := SetupClient(t, c, e, d)
_, err := client.ReadObjects(r, reflect.TypeOf(testObject), true)
assert.NotNil(t, err)
},
})
tests = append(tests, testCase{description: "Query objects, with one row, and a Scan() error", test: func(t *testing.T) {
c := SetupMockConnection(t)
e := SetupMockEncryptor(t)
d := SetupMockDecryptor(t)
r := SetupMockRows(t)
r.EXPECT().Next().Return(true)
r.EXPECT().Scan(gomock.Any()).Return(fmt.Errorf("error"))
r.EXPECT().Close().Return(nil)
client := SetupClient(t, c, e, d)
_, err := client.ReadObjects(r, reflect.TypeOf(testObject), true)
assert.NotNil(t, err)
},
})
tests = append(tests, testCase{description: "Query objects, with one row, and a Close() error", test: func(t *testing.T) {
c := SetupMockConnection(t)
e := SetupMockEncryptor(t)
d := SetupMockDecryptor(t)
r := SetupMockRows(t)
r.EXPECT().Next().Return(true)
r.EXPECT().Scan(gomock.Any()).Do(func(a ...any) {
*a[0].(*sql.RawBytes) = toBytes(testObject)
*a[1].(*sql.RawBytes) = toBytes(testObject)
*a[2].(*uint32) = keyId
})
d.EXPECT().Decrypt(toBytes(testObject), toBytes(testObject), keyId).Return(toBytes(testObject), nil)
r.EXPECT().Err().Return(nil)
r.EXPECT().Next().Return(false)
r.EXPECT().Close().Return(fmt.Errorf("error"))
client := SetupClient(t, c, e, d)
_, err := client.ReadObjects(r, reflect.TypeOf(testObject), true)
assert.NotNil(t, err)
},
})
tests = append(tests, testCase{description: "Query objects, with no rows, and no errors", test: func(t *testing.T) {
c := SetupMockConnection(t)
e := SetupMockEncryptor(t)
d := SetupMockDecryptor(t)
r := SetupMockRows(t)
r.EXPECT().Next().Return(false)
r.EXPECT().Err().Return(nil)
r.EXPECT().Close().Return(nil)
client := SetupClient(t, c, e, d)
items, err := client.ReadObjects(r, reflect.TypeOf(testObject), true)
assert.Nil(t, err)
assert.Equal(t, 0, len(items))
},
})
t.Parallel()
for _, test := range tests {
t.Run(test.description, func(t *testing.T) { test.test(t) })
}
}
func TestQueryStrings(t *testing.T) {
type testCase struct {
description string
test func(t *testing.T)
}
var tests []testCase
testObject := testStoreObject{Id: "something", Val: "a"}
// Tests with shouldEncryptSet to false
tests = append(tests, testCase{description: "ReadStrings(), with one row, and no errors", test: func(t *testing.T) {
c := SetupMockConnection(t)
e := SetupMockEncryptor(t)
d := SetupMockDecryptor(t)
r := SetupMockRows(t)
r.EXPECT().Next().Return(true)
r.EXPECT().Scan(gomock.Any()).Do(func(a ...any) {
for _, v := range a {
vk := v.(*string)
*vk = string(toBytes(testObject.Id))
}
})
r.EXPECT().Err().Return(nil)
r.EXPECT().Next().Return(false)
r.EXPECT().Close().Return(nil)
client := SetupClient(t, c, e, d)
items, err := client.ReadStrings(r)
assert.Nil(t, err)
assert.Equal(t, 1, len(items))
},
})
tests = append(tests, testCase{description: "Query objects, with one row, and Scan error", test: func(t *testing.T) {
c := SetupMockConnection(t)
e := SetupMockEncryptor(t)
d := SetupMockDecryptor(t)
r := SetupMockRows(t)
r.EXPECT().Next().Return(true)
r.EXPECT().Scan(gomock.Any()).Return(fmt.Errorf("error"))
r.EXPECT().Close().Return(nil)
client := SetupClient(t, c, e, d)
_, err := client.ReadStrings(r)
assert.NotNil(t, err)
},
})
tests = append(tests, testCase{description: "ReadStrings(), with one row, and Err() error", test: func(t *testing.T) {
c := SetupMockConnection(t)
e := SetupMockEncryptor(t)
d := SetupMockDecryptor(t)
r := SetupMockRows(t)
r.EXPECT().Next().Return(true)
r.EXPECT().Scan(gomock.Any()).Do(func(a ...any) {
for _, v := range a {
vk := v.(*string)
*vk = string(toBytes(testObject.Id))
}
})
r.EXPECT().Next().Return(false)
r.EXPECT().Err().Return(fmt.Errorf("error"))
r.EXPECT().Close().Return(nil)
client := SetupClient(t, c, e, d)
_, err := client.ReadStrings(r)
assert.NotNil(t, err)
},
})
tests = append(tests, testCase{description: "ReadStrings(), with one row, and Close() error", test: func(t *testing.T) {
c := SetupMockConnection(t)
e := SetupMockEncryptor(t)
d := SetupMockDecryptor(t)
r := SetupMockRows(t)
r.EXPECT().Next().Return(true)
r.EXPECT().Scan(gomock.Any()).Do(func(a ...any) {
for _, v := range a {
vk := v.(*string)
*vk = string(toBytes(testObject.Id))
}
})
r.EXPECT().Err().Return(nil)
r.EXPECT().Next().Return(false)
r.EXPECT().Close().Return(fmt.Errorf("error"))
client := SetupClient(t, c, e, d)
_, err := client.ReadStrings(r)
assert.NotNil(t, err)
},
})
tests = append(tests, testCase{description: "ReadStrings(), with no rows, and no errors", test: func(t *testing.T) {
c := SetupMockConnection(t)
e := SetupMockEncryptor(t)
d := SetupMockDecryptor(t)
r := SetupMockRows(t)
r.EXPECT().Next().Return(false)
r.EXPECT().Err().Return(nil)
r.EXPECT().Close().Return(nil)
client := SetupClient(t, c, e, d)
items, err := client.ReadStrings(r)
assert.Nil(t, err)
assert.Equal(t, 0, len(items))
},
})
t.Parallel()
for _, test := range tests {
t.Run(test.description, func(t *testing.T) { test.test(t) })
}
}
func TestReadInt(t *testing.T) {
type testCase struct {
description string
test func(t *testing.T)
}
var tests []testCase
testResult := 42
tests = append(tests, testCase{description: "One row, no errors", test: func(t *testing.T) {
c := SetupMockConnection(t)
e := SetupMockEncryptor(t)
d := SetupMockDecryptor(t)
r := SetupMockRows(t)
r.EXPECT().Next().Return(true)
r.EXPECT().Scan(gomock.Any()).Do(func(a ...any) {
p := a[0].(*int)
*p = testResult
})
r.EXPECT().Err().Return(nil)
r.EXPECT().Close().Return(nil)
client := SetupClient(t, c, e, d)
result, err := client.ReadInt(r)
assert.Nil(t, err)
assert.Equal(t, 42, result)
},
})
tests = append(tests, testCase{description: "One row, Scan error", test: func(t *testing.T) {
c := SetupMockConnection(t)
e := SetupMockEncryptor(t)
d := SetupMockDecryptor(t)
r := SetupMockRows(t)
r.EXPECT().Next().Return(true)
r.EXPECT().Scan(gomock.Any()).Return(fmt.Errorf("error"))
r.EXPECT().Close().Return(nil)
client := SetupClient(t, c, e, d)
_, err := client.ReadInt(r)
assert.NotNil(t, err)
},
})
tests = append(tests, testCase{description: "One row, Err() error", test: func(t *testing.T) {
c := SetupMockConnection(t)
e := SetupMockEncryptor(t)
d := SetupMockDecryptor(t)
r := SetupMockRows(t)
r.EXPECT().Next().Return(true)
r.EXPECT().Scan(gomock.Any()).Do(func(a ...any) {
a[0] = testResult
})
r.EXPECT().Err().Return(fmt.Errorf("error"))
r.EXPECT().Close().Return(nil)
client := SetupClient(t, c, e, d)
_, err := client.ReadInt(r)
assert.NotNil(t, err)
},
})
tests = append(tests, testCase{description: "One row, Close() error", test: func(t *testing.T) {
c := SetupMockConnection(t)
e := SetupMockEncryptor(t)
d := SetupMockDecryptor(t)
r := SetupMockRows(t)
r.EXPECT().Next().Return(true)
r.EXPECT().Scan(gomock.Any()).Do(func(a ...any) {
a[0] = testResult
})
r.EXPECT().Err().Return(nil)
r.EXPECT().Close().Return(fmt.Errorf("error"))
client := SetupClient(t, c, e, d)
_, err := client.ReadInt(r)
assert.NotNil(t, err)
},
})
tests = append(tests, testCase{description: "No rows error", test: func(t *testing.T) {
c := SetupMockConnection(t)
e := SetupMockEncryptor(t)
d := SetupMockDecryptor(t)
r := SetupMockRows(t)
r.EXPECT().Next().Return(false)
r.EXPECT().Close().Return(nil)
client := SetupClient(t, c, e, d)
_, err := client.ReadInt(r)
assert.ErrorIs(t, err, sql.ErrNoRows)
},
})
t.Parallel()
for _, test := range tests {
t.Run(test.description, func(t *testing.T) { test.test(t) })
}
}
func TestBegin(t *testing.T) {
type testCase struct {
description string
test func(t *testing.T)
}
var tests []testCase
// Tests with shouldEncryptSet to false
tests = append(tests, testCase{description: "BeginTx(), with no errors", test: func(t *testing.T) {
c := SetupMockConnection(t)
e := SetupMockEncryptor(t)
d := SetupMockDecryptor(t)
sqlTx := &sql.Tx{}
c.EXPECT().BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true}).Return(sqlTx, nil)
client := SetupClient(t, c, e, d)
txC, err := client.BeginTx(context.Background(), false)
assert.Nil(t, err)
assert.NotNil(t, txC)
},
})
tests = append(tests, testCase{description: "BeginTx(), with forWriting option set", test: func(t *testing.T) {
c := SetupMockConnection(t)
e := SetupMockEncryptor(t)
d := SetupMockDecryptor(t)
sqlTx := &sql.Tx{}
c.EXPECT().BeginTx(context.Background(), &sql.TxOptions{ReadOnly: false}).Return(sqlTx, nil)
client := SetupClient(t, c, e, d)
txC, err := client.BeginTx(context.Background(), true)
assert.Nil(t, err)
assert.NotNil(t, txC)
},
})
tests = append(tests, testCase{description: "BeginTx(), with connection Begin() error", test: func(t *testing.T) {
c := SetupMockConnection(t)
e := SetupMockEncryptor(t)
d := SetupMockDecryptor(t)
c.EXPECT().BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true}).Return(nil, fmt.Errorf("error"))
client := SetupClient(t, c, e, d)
_, err := client.BeginTx(context.Background(), false)
assert.NotNil(t, err)
},
})
t.Parallel()
for _, test := range tests {
t.Run(test.description, func(t *testing.T) { test.test(t) })
}
}
func TestUpsert(t *testing.T) {
type testCase struct {
description string
test func(t *testing.T)
}
var tests []testCase
testObject := testStoreObject{Id: "something", Val: "a"}
var keyID uint32 = 5
// Tests with shouldEncryptSet to true
tests = append(tests, testCase{description: "Upsert() with no errors", test: func(t *testing.T) {
c := SetupMockConnection(t)
e := SetupMockEncryptor(t)
d := SetupMockDecryptor(t)
client := SetupClient(t, c, e, d)
txC := NewMockTXClient(gomock.NewController(t))
sqlStmt := &sql.Stmt{}
stmt := NewMockStmt(gomock.NewController(t))
testObjBytes := toBytes(testObject)
testByteValue := []byte("something")
e.EXPECT().Encrypt(testObjBytes).Return(testByteValue, testByteValue, keyID, nil)
txC.EXPECT().Stmt(sqlStmt).Return(stmt)
txC.EXPECT().StmtExec(stmt, "somekey", testByteValue, testByteValue, keyID).Return(nil)
err := client.Upsert(txC, sqlStmt, "somekey", testObject, true)
assert.Nil(t, err)
},
})
tests = append(tests, testCase{description: "Upsert() with Encrypt() error", test: func(t *testing.T) {
c := SetupMockConnection(t)
e := SetupMockEncryptor(t)
d := SetupMockDecryptor(t)
client := SetupClient(t, c, e, d)
txC := NewMockTXClient(gomock.NewController(t))
sqlStmt := &sql.Stmt{}
testObjBytes := toBytes(testObject)
e.EXPECT().Encrypt(testObjBytes).Return(nil, nil, uint32(0), fmt.Errorf("error"))
err := client.Upsert(txC, sqlStmt, "somekey", testObject, true)
assert.NotNil(t, err)
},
})
tests = append(tests, testCase{description: "Upsert() with StmtExec() error", test: func(t *testing.T) {
c := SetupMockConnection(t)
e := SetupMockEncryptor(t)
d := SetupMockDecryptor(t)
client := SetupClient(t, c, e, d)
txC := NewMockTXClient(gomock.NewController(t))
sqlStmt := &sql.Stmt{}
stmt := NewMockStmt(gomock.NewController(t))
testObjBytes := toBytes(testObject)
testByteValue := []byte("something")
e.EXPECT().Encrypt(testObjBytes).Return(testByteValue, testByteValue, keyID, nil)
txC.EXPECT().Stmt(sqlStmt).Return(stmt)
txC.EXPECT().StmtExec(stmt, "somekey", testByteValue, testByteValue, keyID).Return(fmt.Errorf("error"))
err := client.Upsert(txC, sqlStmt, "somekey", testObject, true)
assert.NotNil(t, err)
},
})
tests = append(tests, testCase{description: "Upsert() with no errors and shouldEncrypt false", test: func(t *testing.T) {
c := SetupMockConnection(t)
d := SetupMockDecryptor(t)
e := SetupMockEncryptor(t)
client := SetupClient(t, c, e, d)
txC := NewMockTXClient(gomock.NewController(t))
sqlStmt := &sql.Stmt{}
stmt := NewMockStmt(gomock.NewController(t))
var testByteValue []byte
testObjBytes := toBytes(testObject)
txC.EXPECT().Stmt(sqlStmt).Return(stmt)
txC.EXPECT().StmtExec(stmt, "somekey", testObjBytes, testByteValue, uint32(0)).Return(nil)
err := client.Upsert(txC, sqlStmt, "somekey", testObject, false)
assert.Nil(t, err)
},
})
t.Parallel()
for _, test := range tests {
t.Run(test.description, func(t *testing.T) { test.test(t) })
}
}
func TestPrepare(t *testing.T) {
type testCase struct {
description string
test func(t *testing.T)
}
var tests []testCase
tests = append(tests, testCase{description: "Prepare() with no errors", test: func(t *testing.T) {
c := SetupMockConnection(t)
e := SetupMockEncryptor(t)
d := SetupMockDecryptor(t)
client := SetupClient(t, c, e, d)
sqlStmt := &sql.Stmt{}
c.EXPECT().Prepare("something").Return(sqlStmt, nil)
stmt := client.Prepare("something")
assert.Equal(t, sqlStmt, stmt)
},
})
tests = append(tests, testCase{description: "Prepare() with Connection Prepare() error", test: func(t *testing.T) {
c := SetupMockConnection(t)
e := SetupMockEncryptor(t)
d := SetupMockDecryptor(t)
client := SetupClient(t, c, e, d)
c.EXPECT().Prepare("something").Return(nil, fmt.Errorf("error"))
assert.Panics(t, func() { client.Prepare("something") })
},
})
t.Parallel()
for _, test := range tests {
t.Run(test.description, func(t *testing.T) { test.test(t) })
}
}
func TestNewConnection(t *testing.T) {
type testCase struct {
description string
test func(t *testing.T)
}
var tests []testCase
tests = append(tests, testCase{description: "NewConnection replaces file", test: func(t *testing.T) {
c := SetupMockConnection(t)
e := SetupMockEncryptor(t)
d := SetupMockDecryptor(t)
client := SetupClient(t, c, e, d)
c.EXPECT().Close().Return(nil)
err := client.NewConnection()
assert.Nil(t, err)
// Create a transaction to ensure that the file is written to disk.
txC, err := client.BeginTx(context.Background(), false)
assert.NoError(t, err)
assert.NoError(t, txC.Commit())
assert.FileExists(t, InformerObjectCacheDBPath)
assertFileHasPermissions(t, InformerObjectCacheDBPath, 0600)
err = os.Remove(InformerObjectCacheDBPath)
if err != nil {
assert.Fail(t, "could not remove object cache path after test")
}
},
})
t.Parallel()
for _, test := range tests {
t.Run(test.description, func(t *testing.T) { test.test(t) })
}
}
func TestCommit(t *testing.T) {
}
func TestRollback(t *testing.T) {
}
func SetupMockConnection(t *testing.T) *MockConnection {
mockC := NewMockConnection(gomock.NewController(t))
return mockC
}
func SetupMockEncryptor(t *testing.T) *MockEncryptor {
mockE := NewMockEncryptor(gomock.NewController(t))
return mockE
}
func SetupMockDecryptor(t *testing.T) *MockDecryptor {
MockD := NewMockDecryptor(gomock.NewController(t))
return MockD
}
func SetupMockRows(t *testing.T) *MockRows {
MockR := NewMockRows(gomock.NewController(t))
return MockR
}
func SetupClient(t *testing.T, connection Connection, encryptor Encryptor, decryptor Decryptor) *Client {
c, _ := NewClient(connection, encryptor, decryptor)
return c
}
func TestTouchFile(t *testing.T) {
t.Run("File doesn't exist before", func(t *testing.T) {
filename := filepath.Join(t.TempDir(), "test1.txt")
assert.NoError(t, touchFile(filename, 0600))
assertFileHasPermissions(t, filename, 0600)
})
t.Run("File exists with different permissions", func(t *testing.T) {
filename := filepath.Join(t.TempDir(), "test2.txt")
assert.NoError(t, os.WriteFile(filename, []byte("test"), 0644))
assert.NoError(t, touchFile(filename, 0600))
assertFileHasPermissions(t, filename, 0600)
})
}
func assertFileHasPermissions(t *testing.T, fname string, wantPerms fs.FileMode) bool {
t.Helper()
info, err := os.Lstat(fname)
if err != nil {
if os.IsNotExist(err) {
return assert.Fail(t, fmt.Sprintf("unable to find file %q", fname))
}
return assert.Fail(t, fmt.Sprintf("error when running os.Lstat(%q): %s", fname, err))
}
// Stringifying the perms makes it easier to read than a Hex comparison.
assert.Equal(t, wantPerms.String(), info.Mode().Perm().String())
return true
}

View File

@@ -0,0 +1,370 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/rancher/steve/pkg/sqlcache/db (interfaces: Rows,Connection,Encryptor,Decryptor,TXClient)
//
// Generated by this command:
//
// mockgen --build_flags=--mod=mod -package db -destination ./db_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db Rows,Connection,Encryptor,Decryptor,TXClient
//
// Package db is a generated GoMock package.
package db
import (
context "context"
sql "database/sql"
reflect "reflect"
transaction "github.com/rancher/steve/pkg/sqlcache/db/transaction"
gomock "go.uber.org/mock/gomock"
)
// MockRows is a mock of Rows interface.
type MockRows struct {
ctrl *gomock.Controller
recorder *MockRowsMockRecorder
}
// MockRowsMockRecorder is the mock recorder for MockRows.
type MockRowsMockRecorder struct {
mock *MockRows
}
// NewMockRows creates a new mock instance.
func NewMockRows(ctrl *gomock.Controller) *MockRows {
mock := &MockRows{ctrl: ctrl}
mock.recorder = &MockRowsMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockRows) EXPECT() *MockRowsMockRecorder {
return m.recorder
}
// Close mocks base method.
func (m *MockRows) Close() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Close")
ret0, _ := ret[0].(error)
return ret0
}
// Close indicates an expected call of Close.
func (mr *MockRowsMockRecorder) Close() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockRows)(nil).Close))
}
// Err mocks base method.
func (m *MockRows) Err() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Err")
ret0, _ := ret[0].(error)
return ret0
}
// Err indicates an expected call of Err.
func (mr *MockRowsMockRecorder) Err() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Err", reflect.TypeOf((*MockRows)(nil).Err))
}
// Next mocks base method.
func (m *MockRows) Next() bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Next")
ret0, _ := ret[0].(bool)
return ret0
}
// Next indicates an expected call of Next.
func (mr *MockRowsMockRecorder) Next() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Next", reflect.TypeOf((*MockRows)(nil).Next))
}
// Scan mocks base method.
func (m *MockRows) Scan(arg0 ...any) error {
m.ctrl.T.Helper()
varargs := []any{}
for _, a := range arg0 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Scan", varargs...)
ret0, _ := ret[0].(error)
return ret0
}
// Scan indicates an expected call of Scan.
func (mr *MockRowsMockRecorder) Scan(arg0 ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Scan", reflect.TypeOf((*MockRows)(nil).Scan), arg0...)
}
// MockConnection is a mock of Connection interface.
type MockConnection struct {
ctrl *gomock.Controller
recorder *MockConnectionMockRecorder
}
// MockConnectionMockRecorder is the mock recorder for MockConnection.
type MockConnectionMockRecorder struct {
mock *MockConnection
}
// NewMockConnection creates a new mock instance.
func NewMockConnection(ctrl *gomock.Controller) *MockConnection {
mock := &MockConnection{ctrl: ctrl}
mock.recorder = &MockConnectionMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockConnection) EXPECT() *MockConnectionMockRecorder {
return m.recorder
}
// BeginTx mocks base method.
func (m *MockConnection) BeginTx(arg0 context.Context, arg1 *sql.TxOptions) (*sql.Tx, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "BeginTx", arg0, arg1)
ret0, _ := ret[0].(*sql.Tx)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// BeginTx indicates an expected call of BeginTx.
func (mr *MockConnectionMockRecorder) BeginTx(arg0, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BeginTx", reflect.TypeOf((*MockConnection)(nil).BeginTx), arg0, arg1)
}
// Close mocks base method.
func (m *MockConnection) Close() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Close")
ret0, _ := ret[0].(error)
return ret0
}
// Close indicates an expected call of Close.
func (mr *MockConnectionMockRecorder) Close() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockConnection)(nil).Close))
}
// Exec mocks base method.
func (m *MockConnection) Exec(arg0 string, arg1 ...any) (sql.Result, error) {
m.ctrl.T.Helper()
varargs := []any{arg0}
for _, a := range arg1 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Exec", varargs...)
ret0, _ := ret[0].(sql.Result)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Exec indicates an expected call of Exec.
func (mr *MockConnectionMockRecorder) Exec(arg0 any, arg1 ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{arg0}, arg1...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockConnection)(nil).Exec), varargs...)
}
// Prepare mocks base method.
func (m *MockConnection) Prepare(arg0 string) (*sql.Stmt, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Prepare", arg0)
ret0, _ := ret[0].(*sql.Stmt)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Prepare indicates an expected call of Prepare.
func (mr *MockConnectionMockRecorder) Prepare(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Prepare", reflect.TypeOf((*MockConnection)(nil).Prepare), arg0)
}
// MockEncryptor is a mock of Encryptor interface.
type MockEncryptor struct {
ctrl *gomock.Controller
recorder *MockEncryptorMockRecorder
}
// MockEncryptorMockRecorder is the mock recorder for MockEncryptor.
type MockEncryptorMockRecorder struct {
mock *MockEncryptor
}
// NewMockEncryptor creates a new mock instance.
func NewMockEncryptor(ctrl *gomock.Controller) *MockEncryptor {
mock := &MockEncryptor{ctrl: ctrl}
mock.recorder = &MockEncryptorMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockEncryptor) EXPECT() *MockEncryptorMockRecorder {
return m.recorder
}
// Encrypt mocks base method.
func (m *MockEncryptor) Encrypt(arg0 []byte) ([]byte, []byte, uint32, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Encrypt", arg0)
ret0, _ := ret[0].([]byte)
ret1, _ := ret[1].([]byte)
ret2, _ := ret[2].(uint32)
ret3, _ := ret[3].(error)
return ret0, ret1, ret2, ret3
}
// Encrypt indicates an expected call of Encrypt.
func (mr *MockEncryptorMockRecorder) Encrypt(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Encrypt", reflect.TypeOf((*MockEncryptor)(nil).Encrypt), arg0)
}
// MockDecryptor is a mock of Decryptor interface.
type MockDecryptor struct {
ctrl *gomock.Controller
recorder *MockDecryptorMockRecorder
}
// MockDecryptorMockRecorder is the mock recorder for MockDecryptor.
type MockDecryptorMockRecorder struct {
mock *MockDecryptor
}
// NewMockDecryptor creates a new mock instance.
func NewMockDecryptor(ctrl *gomock.Controller) *MockDecryptor {
mock := &MockDecryptor{ctrl: ctrl}
mock.recorder = &MockDecryptorMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockDecryptor) EXPECT() *MockDecryptorMockRecorder {
return m.recorder
}
// Decrypt mocks base method.
func (m *MockDecryptor) Decrypt(arg0, arg1 []byte, arg2 uint32) ([]byte, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Decrypt", arg0, arg1, arg2)
ret0, _ := ret[0].([]byte)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Decrypt indicates an expected call of Decrypt.
func (mr *MockDecryptorMockRecorder) Decrypt(arg0, arg1, arg2 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Decrypt", reflect.TypeOf((*MockDecryptor)(nil).Decrypt), arg0, arg1, arg2)
}
// MockTXClient is a mock of TXClient interface.
type MockTXClient struct {
ctrl *gomock.Controller
recorder *MockTXClientMockRecorder
}
// MockTXClientMockRecorder is the mock recorder for MockTXClient.
type MockTXClientMockRecorder struct {
mock *MockTXClient
}
// NewMockTXClient creates a new mock instance.
func NewMockTXClient(ctrl *gomock.Controller) *MockTXClient {
mock := &MockTXClient{ctrl: ctrl}
mock.recorder = &MockTXClientMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockTXClient) EXPECT() *MockTXClientMockRecorder {
return m.recorder
}
// Cancel mocks base method.
func (m *MockTXClient) Cancel() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Cancel")
ret0, _ := ret[0].(error)
return ret0
}
// Cancel indicates an expected call of Cancel.
func (mr *MockTXClientMockRecorder) Cancel() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Cancel", reflect.TypeOf((*MockTXClient)(nil).Cancel))
}
// Commit mocks base method.
func (m *MockTXClient) Commit() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Commit")
ret0, _ := ret[0].(error)
return ret0
}
// Commit indicates an expected call of Commit.
func (mr *MockTXClientMockRecorder) Commit() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockTXClient)(nil).Commit))
}
// Exec mocks base method.
func (m *MockTXClient) Exec(arg0 string, arg1 ...any) error {
m.ctrl.T.Helper()
varargs := []any{arg0}
for _, a := range arg1 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Exec", varargs...)
ret0, _ := ret[0].(error)
return ret0
}
// Exec indicates an expected call of Exec.
func (mr *MockTXClientMockRecorder) Exec(arg0 any, arg1 ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{arg0}, arg1...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockTXClient)(nil).Exec), varargs...)
}
// Stmt mocks base method.
func (m *MockTXClient) Stmt(arg0 *sql.Stmt) transaction.Stmt {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Stmt", arg0)
ret0, _ := ret[0].(transaction.Stmt)
return ret0
}
// Stmt indicates an expected call of Stmt.
func (mr *MockTXClientMockRecorder) Stmt(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stmt", reflect.TypeOf((*MockTXClient)(nil).Stmt), arg0)
}
// StmtExec mocks base method.
func (m *MockTXClient) StmtExec(arg0 transaction.Stmt, arg1 ...any) error {
m.ctrl.T.Helper()
varargs := []any{arg0}
for _, a := range arg1 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "StmtExec", varargs...)
ret0, _ := ret[0].(error)
return ret0
}
// StmtExec indicates an expected call of StmtExec.
func (mr *MockTXClientMockRecorder) StmtExec(arg0 any, arg1 ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{arg0}, arg1...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StmtExec", reflect.TypeOf((*MockTXClient)(nil).StmtExec), varargs...)
}

View File

@@ -0,0 +1,90 @@
/*
Package transaction provides a client for a live transaction, and interfaces for some relevant sql types. The transaction client automatically performs rollbacks on failures.
The use of this package simplifies testing for callers by making the underlying transaction mock-able.
*/
package transaction
import (
"context"
"database/sql"
"github.com/pkg/errors"
)
// Client provides a way to interact with the underlying sql transaction.
type Client struct {
sqlTx SQLTx
}
// SQLTx represents a sql transaction
type SQLTx interface {
Exec(query string, args ...any) (sql.Result, error)
Stmt(stmt *sql.Stmt) *sql.Stmt
Commit() error
Rollback() error
}
// Stmt represents a sql stmt. It is used as a return type to offer some testability over returning sql's Stmt type
// because we are able to mock its outputs and do not need an actual connection.
type Stmt interface {
Exec(args ...any) (sql.Result, error)
Query(args ...any) (*sql.Rows, error)
QueryContext(ctx context.Context, args ...any) (*sql.Rows, error)
}
// NewClient returns a Client with the given transaction assigned.
func NewClient(tx SQLTx) *Client {
return &Client{sqlTx: tx}
}
// Commit commits the transaction and then unlocks the database.
func (c *Client) Commit() error {
return c.sqlTx.Commit()
}
// Exec uses the sqlTX Exec() with the given stmt and args. The transaction will be automatically rolled back if Exec()
// returns an error.
func (c *Client) Exec(stmt string, args ...any) error {
_, err := c.sqlTx.Exec(stmt, args...)
if err != nil {
return c.rollback(c.sqlTx, err)
}
return nil
}
// Stmt adds the given sql.Stmt to the client's transaction and then returns a Stmt. An interface is being returned
// here to aid in testing callers by providing a way to configure the statement's behavior.
func (c *Client) Stmt(stmt *sql.Stmt) Stmt {
s := c.sqlTx.Stmt(stmt)
return s
}
// StmtExec Execs the given statement with the given args. It assumes the stmt has been added to the transaction. The
// transaction is rolled back if Stmt.Exec() returns an error.
func (c *Client) StmtExec(stmt Stmt, args ...any) error {
_, err := stmt.Exec(args...)
if err != nil {
return c.rollback(c.sqlTx, err)
}
return nil
}
// rollback handles rollbacks and wraps errors if needed
func (c *Client) rollback(tx SQLTx, err error) error {
rerr := tx.Rollback()
if rerr != nil {
return errors.Wrapf(err, "Encountered error, then encountered another error while rolling back: %v", rerr)
}
return errors.Wrapf(err, "Encountered error, successfully rolled back")
}
// Cancel rollbacks the transaction without wrapping an error. This only needs to be called if Client has not returned
// an error yet or has not committed. Otherwise, transaction has already rolled back, or in the case of Commit() it is too
// late.
func (c *Client) Cancel() error {
rerr := c.sqlTx.Rollback()
if rerr != sql.ErrTxDone {
return rerr
}
return nil
}

View File

@@ -0,0 +1,184 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/rancher/steve/pkg/sqlcache/db/transaction (interfaces: Stmt,SQLTx)
//
// Generated by this command:
//
// mockgen --build_flags=--mod=mod -package transaction -destination ./transaction_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db/transaction Stmt,SQLTx
//
// Package transaction is a generated GoMock package.
package transaction
import (
context "context"
sql "database/sql"
reflect "reflect"
gomock "go.uber.org/mock/gomock"
)
// MockStmt is a mock of Stmt interface.
type MockStmt struct {
ctrl *gomock.Controller
recorder *MockStmtMockRecorder
}
// MockStmtMockRecorder is the mock recorder for MockStmt.
type MockStmtMockRecorder struct {
mock *MockStmt
}
// NewMockStmt creates a new mock instance.
func NewMockStmt(ctrl *gomock.Controller) *MockStmt {
mock := &MockStmt{ctrl: ctrl}
mock.recorder = &MockStmtMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockStmt) EXPECT() *MockStmtMockRecorder {
return m.recorder
}
// Exec mocks base method.
func (m *MockStmt) Exec(arg0 ...any) (sql.Result, error) {
m.ctrl.T.Helper()
varargs := []any{}
for _, a := range arg0 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Exec", varargs...)
ret0, _ := ret[0].(sql.Result)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Exec indicates an expected call of Exec.
func (mr *MockStmtMockRecorder) Exec(arg0 ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockStmt)(nil).Exec), arg0...)
}
// Query mocks base method.
func (m *MockStmt) Query(arg0 ...any) (*sql.Rows, error) {
m.ctrl.T.Helper()
varargs := []any{}
for _, a := range arg0 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Query", varargs...)
ret0, _ := ret[0].(*sql.Rows)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Query indicates an expected call of Query.
func (mr *MockStmtMockRecorder) Query(arg0 ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockStmt)(nil).Query), arg0...)
}
// QueryContext mocks base method.
func (m *MockStmt) QueryContext(arg0 context.Context, arg1 ...any) (*sql.Rows, error) {
m.ctrl.T.Helper()
varargs := []any{arg0}
for _, a := range arg1 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "QueryContext", varargs...)
ret0, _ := ret[0].(*sql.Rows)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// QueryContext indicates an expected call of QueryContext.
func (mr *MockStmtMockRecorder) QueryContext(arg0 any, arg1 ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{arg0}, arg1...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryContext", reflect.TypeOf((*MockStmt)(nil).QueryContext), varargs...)
}
// MockSQLTx is a mock of SQLTx interface.
type MockSQLTx struct {
ctrl *gomock.Controller
recorder *MockSQLTxMockRecorder
}
// MockSQLTxMockRecorder is the mock recorder for MockSQLTx.
type MockSQLTxMockRecorder struct {
mock *MockSQLTx
}
// NewMockSQLTx creates a new mock instance.
func NewMockSQLTx(ctrl *gomock.Controller) *MockSQLTx {
mock := &MockSQLTx{ctrl: ctrl}
mock.recorder = &MockSQLTxMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockSQLTx) EXPECT() *MockSQLTxMockRecorder {
return m.recorder
}
// Commit mocks base method.
func (m *MockSQLTx) Commit() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Commit")
ret0, _ := ret[0].(error)
return ret0
}
// Commit indicates an expected call of Commit.
func (mr *MockSQLTxMockRecorder) Commit() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockSQLTx)(nil).Commit))
}
// Exec mocks base method.
func (m *MockSQLTx) Exec(arg0 string, arg1 ...any) (sql.Result, error) {
m.ctrl.T.Helper()
varargs := []any{arg0}
for _, a := range arg1 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Exec", varargs...)
ret0, _ := ret[0].(sql.Result)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Exec indicates an expected call of Exec.
func (mr *MockSQLTxMockRecorder) Exec(arg0 any, arg1 ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{arg0}, arg1...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockSQLTx)(nil).Exec), varargs...)
}
// Rollback mocks base method.
func (m *MockSQLTx) Rollback() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Rollback")
ret0, _ := ret[0].(error)
return ret0
}
// Rollback indicates an expected call of Rollback.
func (mr *MockSQLTxMockRecorder) Rollback() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Rollback", reflect.TypeOf((*MockSQLTx)(nil).Rollback))
}
// Stmt mocks base method.
func (m *MockSQLTx) Stmt(arg0 *sql.Stmt) *sql.Stmt {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Stmt", arg0)
ret0, _ := ret[0].(*sql.Stmt)
return ret0
}
// Stmt indicates an expected call of Stmt.
func (mr *MockSQLTxMockRecorder) Stmt(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stmt", reflect.TypeOf((*MockSQLTx)(nil).Stmt), arg0)
}

View File

@@ -0,0 +1,182 @@
package transaction
import (
"database/sql"
"fmt"
"testing"
"github.com/stretchr/testify/assert"
"go.uber.org/mock/gomock"
)
//go:generate mockgen --build_flags=--mod=mod -package transaction -destination ./transaction_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db/transaction Stmt,SQLTx
func TestNewClient(t *testing.T) {
tx := NewMockSQLTx(gomock.NewController(t))
c := NewClient(tx)
assert.Equal(t, tx, c.sqlTx)
}
func TestCommit(t *testing.T) {
type testCase struct {
description string
test func(t *testing.T)
}
var tests []testCase
tests = append(tests, testCase{description: "Commit() with no errors returned from sql TX should return no error", test: func(t *testing.T) {
tx := NewMockSQLTx(gomock.NewController(t))
tx.EXPECT().Commit().Return(nil)
c := &Client{
sqlTx: tx,
}
err := c.Commit()
assert.Nil(t, err)
}})
tests = append(tests, testCase{description: "Commit() with error from sql TX commit() should return error", test: func(t *testing.T) {
tx := NewMockSQLTx(gomock.NewController(t))
tx.EXPECT().Commit().Return(fmt.Errorf("error"))
c := &Client{
sqlTx: tx,
}
err := c.Commit()
assert.NotNil(t, err)
}})
t.Parallel()
for _, test := range tests {
t.Run(test.description, func(t *testing.T) { test.test(t) })
}
}
func TestExec(t *testing.T) {
type testCase struct {
description string
test func(t *testing.T)
}
var tests []testCase
tests = append(tests, testCase{description: "Exec() with no errors returned from sql TX should return no error", test: func(t *testing.T) {
tx := NewMockSQLTx(gomock.NewController(t))
stmtStr := "some statement %s"
arg := 5
// should be passed same statement and arg that was passed to parent function
tx.EXPECT().Exec(stmtStr, arg).Return(nil, nil)
c := &Client{
sqlTx: tx,
}
err := c.Exec(stmtStr, arg)
assert.Nil(t, err)
}})
tests = append(tests, testCase{description: "Exec() with error returned from sql TX Exec() and Rollback() error should return an error", test: func(t *testing.T) {
tx := NewMockSQLTx(gomock.NewController(t))
stmtStr := "some statement %s"
arg := 5
// should be passed same statement and arg that was passed to parent function
tx.EXPECT().Exec(stmtStr, arg).Return(nil, fmt.Errorf("error"))
tx.EXPECT().Rollback().Return(nil)
c := &Client{
sqlTx: tx,
}
err := c.Exec(stmtStr, arg)
assert.NotNil(t, err)
}})
tests = append(tests, testCase{description: "Exec() with error returned from sql TX Exec() and Rollback() error should return an error", test: func(t *testing.T) {
tx := NewMockSQLTx(gomock.NewController(t))
stmtStr := "some statement %s"
arg := 5
// should be passed same statement and arg that was passed to parent function
tx.EXPECT().Exec(stmtStr, arg).Return(nil, fmt.Errorf("error"))
tx.EXPECT().Rollback().Return(fmt.Errorf("error"))
c := &Client{
sqlTx: tx,
}
err := c.Exec(stmtStr, arg)
assert.NotNil(t, err)
}})
t.Parallel()
for _, test := range tests {
t.Run(test.description, func(t *testing.T) { test.test(t) })
}
}
func TestStmt(t *testing.T) {
type testCase struct {
description string
test func(t *testing.T)
}
var tests []testCase
tests = append(tests, testCase{description: "Exec() with no errors returned from sql TX should return no error", test: func(t *testing.T) {
tx := NewMockSQLTx(gomock.NewController(t))
stmt := &sql.Stmt{}
var returnedTXStmt *sql.Stmt
// should be passed same statement and arg that was passed to parent function
tx.EXPECT().Stmt(stmt).Return(returnedTXStmt)
c := &Client{
sqlTx: tx,
}
returnedStmt := c.Stmt(stmt)
// whatever tx returned should be returned here. Nil was used because none of sql.Stmt's fields are exported so its simpler to test nil as it
// won't be equal to an empty struct
assert.Equal(t, returnedTXStmt, returnedStmt)
}})
t.Parallel()
for _, test := range tests {
t.Run(test.description, func(t *testing.T) { test.test(t) })
}
}
func TestStmtExec(t *testing.T) {
type testCase struct {
description string
test func(t *testing.T)
}
var tests []testCase
tests = append(tests, testCase{description: "StmtExec with no errors returned from Stmt should return no error", test: func(t *testing.T) {
tx := NewMockSQLTx(gomock.NewController(t))
stmt := NewMockStmt(gomock.NewController(t))
arg := "something"
// should be passed same arg that was passed to parent function
stmt.EXPECT().Exec(arg).Return(nil, nil)
c := &Client{
sqlTx: tx,
}
err := c.StmtExec(stmt, arg)
assert.Nil(t, err)
}})
tests = append(tests, testCase{description: "StmtExec with error returned from Stmt Exec and no Tx Rollback() error should return error", test: func(t *testing.T) {
tx := NewMockSQLTx(gomock.NewController(t))
stmt := NewMockStmt(gomock.NewController(t))
arg := "something"
// should be passed same arg that was passed to parent function
stmt.EXPECT().Exec(arg).Return(nil, fmt.Errorf("error"))
tx.EXPECT().Rollback().Return(nil)
c := &Client{
sqlTx: tx,
}
err := c.StmtExec(stmt, arg)
assert.NotNil(t, err)
}})
tests = append(tests, testCase{description: "StmtExec with error returned from Stmt Exec and Tx Rollback() error should return error", test: func(t *testing.T) {
tx := NewMockSQLTx(gomock.NewController(t))
stmt := NewMockStmt(gomock.NewController(t))
arg := "something"
// should be passed same arg that was passed to parent function
stmt.EXPECT().Exec(arg).Return(nil, fmt.Errorf("error"))
tx.EXPECT().Rollback().Return(fmt.Errorf("error2"))
c := &Client{
sqlTx: tx,
}
err := c.StmtExec(stmt, arg)
assert.NotNil(t, err)
}})
t.Parallel()
for _, test := range tests {
t.Run(test.description, func(t *testing.T) { test.test(t) })
}
}

View File

@@ -0,0 +1,184 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/rancher/steve/pkg/sqlcache/db/transaction (interfaces: Stmt,SQLTx)
//
// Generated by this command:
//
// mockgen --build_flags=--mod=mod -package db -destination ./transaction_mocks_test.go github.com/rancher/steve/pkg/sqlcache/db/transaction Stmt,SQLTx
//
// Package db is a generated GoMock package.
package db
import (
context "context"
sql "database/sql"
reflect "reflect"
gomock "go.uber.org/mock/gomock"
)
// MockStmt is a mock of Stmt interface.
type MockStmt struct {
ctrl *gomock.Controller
recorder *MockStmtMockRecorder
}
// MockStmtMockRecorder is the mock recorder for MockStmt.
type MockStmtMockRecorder struct {
mock *MockStmt
}
// NewMockStmt creates a new mock instance.
func NewMockStmt(ctrl *gomock.Controller) *MockStmt {
mock := &MockStmt{ctrl: ctrl}
mock.recorder = &MockStmtMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockStmt) EXPECT() *MockStmtMockRecorder {
return m.recorder
}
// Exec mocks base method.
func (m *MockStmt) Exec(arg0 ...any) (sql.Result, error) {
m.ctrl.T.Helper()
varargs := []any{}
for _, a := range arg0 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Exec", varargs...)
ret0, _ := ret[0].(sql.Result)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Exec indicates an expected call of Exec.
func (mr *MockStmtMockRecorder) Exec(arg0 ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockStmt)(nil).Exec), arg0...)
}
// Query mocks base method.
func (m *MockStmt) Query(arg0 ...any) (*sql.Rows, error) {
m.ctrl.T.Helper()
varargs := []any{}
for _, a := range arg0 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Query", varargs...)
ret0, _ := ret[0].(*sql.Rows)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Query indicates an expected call of Query.
func (mr *MockStmtMockRecorder) Query(arg0 ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockStmt)(nil).Query), arg0...)
}
// QueryContext mocks base method.
func (m *MockStmt) QueryContext(arg0 context.Context, arg1 ...any) (*sql.Rows, error) {
m.ctrl.T.Helper()
varargs := []any{arg0}
for _, a := range arg1 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "QueryContext", varargs...)
ret0, _ := ret[0].(*sql.Rows)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// QueryContext indicates an expected call of QueryContext.
func (mr *MockStmtMockRecorder) QueryContext(arg0 any, arg1 ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{arg0}, arg1...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryContext", reflect.TypeOf((*MockStmt)(nil).QueryContext), varargs...)
}
// MockSQLTx is a mock of SQLTx interface.
type MockSQLTx struct {
ctrl *gomock.Controller
recorder *MockSQLTxMockRecorder
}
// MockSQLTxMockRecorder is the mock recorder for MockSQLTx.
type MockSQLTxMockRecorder struct {
mock *MockSQLTx
}
// NewMockSQLTx creates a new mock instance.
func NewMockSQLTx(ctrl *gomock.Controller) *MockSQLTx {
mock := &MockSQLTx{ctrl: ctrl}
mock.recorder = &MockSQLTxMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockSQLTx) EXPECT() *MockSQLTxMockRecorder {
return m.recorder
}
// Commit mocks base method.
func (m *MockSQLTx) Commit() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Commit")
ret0, _ := ret[0].(error)
return ret0
}
// Commit indicates an expected call of Commit.
func (mr *MockSQLTxMockRecorder) Commit() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockSQLTx)(nil).Commit))
}
// Exec mocks base method.
func (m *MockSQLTx) Exec(arg0 string, arg1 ...any) (sql.Result, error) {
m.ctrl.T.Helper()
varargs := []any{arg0}
for _, a := range arg1 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Exec", varargs...)
ret0, _ := ret[0].(sql.Result)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Exec indicates an expected call of Exec.
func (mr *MockSQLTxMockRecorder) Exec(arg0 any, arg1 ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{arg0}, arg1...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockSQLTx)(nil).Exec), varargs...)
}
// Rollback mocks base method.
func (m *MockSQLTx) Rollback() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Rollback")
ret0, _ := ret[0].(error)
return ret0
}
// Rollback indicates an expected call of Rollback.
func (mr *MockSQLTxMockRecorder) Rollback() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Rollback", reflect.TypeOf((*MockSQLTx)(nil).Rollback))
}
// Stmt mocks base method.
func (m *MockSQLTx) Stmt(arg0 *sql.Stmt) *sql.Stmt {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Stmt", arg0)
ret0, _ := ret[0].(*sql.Stmt)
return ret0
}
// Stmt indicates an expected call of Stmt.
func (mr *MockSQLTxMockRecorder) Stmt(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stmt", reflect.TypeOf((*MockSQLTx)(nil).Stmt), arg0)
}

View File

@@ -0,0 +1,8 @@
package db
import "strings"
// Sanitize returns a string that can be used in SQL as a name
func Sanitize(s string) string {
return strings.ReplaceAll(s, "\"", "")
}