1
0
mirror of https://github.com/rancher/steve.git synced 2025-09-05 01:12:09 +00:00

sql: use a closure to wrap transactions (#469)

This introduces the a `WithTransaction` function, which is then used for all transactional work in Steve.

Because `WithTransaction` takes care of all `Begin`s, `Commit`s and `Rollback`s, it eliminates the problem where forgotten open transactions can block all other operations (with long stalling and `SQLITE_BUSY` errors).

This also:

- merges together the disparate `DBClient` interfaces in one only `db.Client` interface with one unexported non-test implementation. I found this much easier to follow
- refactors the transaction package in order to make it as minimal as possible, and as close to the wrapped `sql.Tx` and `sql.Stmt` functions as possible, in order to reduce cognitive load when working with this part of the codebase
- simplifies tests accordingly
- adds a couple of known files to `.gitignore`
    
Credits to @tomleb for suggesting the approach: https://github.com/rancher/lasso/pull/121#pullrequestreview-2515872507
This commit is contained in:
Silvio Moioli
2025-02-05 10:05:52 +01:00
committed by GitHub
parent 6a46a1e091
commit 772dc7577e
28 changed files with 1543 additions and 2059 deletions

View File

@@ -11,6 +11,7 @@ import (
"strconv"
"strings"
"github.com/rancher/steve/pkg/sqlcache/db/transaction"
"github.com/sirupsen/logrus"
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
"k8s.io/client-go/tools/cache"
@@ -108,51 +109,49 @@ func NewListOptionIndexer(fields [][]string, s Store, namespaced bool) (*ListOpt
columnDefs[index] = column
}
tx, err := l.BeginTx(context.Background(), true)
if err != nil {
return nil, err
}
dbName := db.Sanitize(i.GetName())
err = tx.Exec(fmt.Sprintf(createFieldsTableFmt, dbName, strings.Join(columnDefs, ", ")))
if err != nil {
return nil, err
}
columns := make([]string, len(indexedFields))
qmarks := make([]string, len(indexedFields))
setStatements := make([]string, len(indexedFields))
for index, field := range indexedFields {
// create index for field
err = tx.Exec(fmt.Sprintf(createFieldsIndexFmt, dbName, field, dbName, field))
err = l.WithTransaction(context.Background(), true, func(tx transaction.Client) error {
_, err = tx.Exec(fmt.Sprintf(createFieldsTableFmt, dbName, strings.Join(columnDefs, ", ")))
if err != nil {
return nil, err
return err
}
// format field into column for prepared statement
column := fmt.Sprintf(`"%s"`, field)
columns[index] = column
for index, field := range indexedFields {
// create index for field
_, err = tx.Exec(fmt.Sprintf(createFieldsIndexFmt, dbName, field, dbName, field))
if err != nil {
return err
}
// add placeholder for column's value in prepared statement
qmarks[index] = "?"
// format field into column for prepared statement
column := fmt.Sprintf(`"%s"`, field)
columns[index] = column
// add formatted set statement for prepared statement
setStatement := fmt.Sprintf(`"%s" = excluded."%s"`, field, field)
setStatements[index] = setStatement
}
createLabelsTableQuery := fmt.Sprintf(createLabelsTableFmt, dbName, dbName)
err = tx.Exec(createLabelsTableQuery)
if err != nil {
return nil, &db.QueryError{QueryString: createLabelsTableQuery, Err: err}
}
// add placeholder for column's value in prepared statement
qmarks[index] = "?"
createLabelsTableIndexQuery := fmt.Sprintf(createLabelsTableIndexFmt, dbName, dbName)
err = tx.Exec(createLabelsTableIndexQuery)
if err != nil {
return nil, &db.QueryError{QueryString: createLabelsTableIndexQuery, Err: err}
}
// add formatted set statement for prepared statement
setStatement := fmt.Sprintf(`"%s" = excluded."%s"`, field, field)
setStatements[index] = setStatement
}
createLabelsTableQuery := fmt.Sprintf(createLabelsTableFmt, dbName, dbName)
_, err = tx.Exec(createLabelsTableQuery)
if err != nil {
return &db.QueryError{QueryString: createLabelsTableQuery, Err: err}
}
err = tx.Commit()
createLabelsTableIndexQuery := fmt.Sprintf(createLabelsTableIndexFmt, dbName, dbName)
_, err = tx.Exec(createLabelsTableIndexQuery)
if err != nil {
return &db.QueryError{QueryString: createLabelsTableIndexQuery, Err: err}
}
return nil
})
if err != nil {
return nil, err
}
@@ -180,16 +179,12 @@ func NewListOptionIndexer(fields [][]string, s Store, namespaced bool) (*ListOpt
/* Core methods */
// addIndexFields saves sortable/filterable fields into tables
func (l *ListOptionIndexer) addIndexFields(key string, obj any, tx db.TXClient) error {
func (l *ListOptionIndexer) addIndexFields(key string, obj any, tx transaction.Client) error {
args := []any{key}
for _, field := range l.indexedFields {
value, err := getField(obj, field)
if err != nil {
logrus.Errorf("cannot index object of type [%s] with key [%s] for indexer [%s]: %v", l.GetType().String(), key, l.GetName(), err)
cErr := tx.Cancel()
if cErr != nil {
return fmt.Errorf("could not cancel transaction: %s while recovering from error: %w", cErr, err)
}
return err
}
switch typedValue := value.(type) {
@@ -201,15 +196,11 @@ func (l *ListOptionIndexer) addIndexFields(key string, obj any, tx db.TXClient)
args = append(args, strings.Join(typedValue, "|"))
default:
err2 := fmt.Errorf("field %v has a non-supported type value: %v", field, value)
cErr := tx.Cancel()
if cErr != nil {
return fmt.Errorf("could not cancel transaction: %s while recovering from error: %w", cErr, err2)
}
return err2
}
}
err := tx.StmtExec(tx.Stmt(l.addFieldStmt), args...)
_, err := tx.Stmt(l.addFieldStmt).Exec(args...)
if err != nil {
return &db.QueryError{QueryString: l.addFieldQuery, Err: err}
}
@@ -217,14 +208,14 @@ func (l *ListOptionIndexer) addIndexFields(key string, obj any, tx db.TXClient)
}
// labels are stored in tables that shadow the underlying object table for each GVK
func (l *ListOptionIndexer) addLabels(key string, obj any, tx db.TXClient) error {
func (l *ListOptionIndexer) addLabels(key string, obj any, tx transaction.Client) error {
k8sObj, ok := obj.(*unstructured.Unstructured)
if !ok {
return fmt.Errorf("addLabels: unexpected object type, expected unstructured.Unstructured: %v", obj)
}
incomingLabels := k8sObj.GetLabels()
for k, v := range incomingLabels {
err := tx.StmtExec(tx.Stmt(l.upsertLabelsStmt), key, k, v)
_, err := tx.Stmt(l.upsertLabelsStmt).Exec(key, k, v)
if err != nil {
return &db.QueryError{QueryString: l.upsertLabelsQuery, Err: err}
}
@@ -232,18 +223,18 @@ func (l *ListOptionIndexer) addLabels(key string, obj any, tx db.TXClient) error
return nil
}
func (l *ListOptionIndexer) deleteIndexFields(key string, tx db.TXClient) error {
func (l *ListOptionIndexer) deleteIndexFields(key string, tx transaction.Client) error {
args := []any{key}
err := tx.StmtExec(tx.Stmt(l.deleteFieldStmt), args...)
_, err := tx.Stmt(l.deleteFieldStmt).Exec(args...)
if err != nil {
return &db.QueryError{QueryString: l.deleteFieldQuery, Err: err}
}
return nil
}
func (l *ListOptionIndexer) deleteLabels(key string, tx db.TXClient) error {
err := tx.StmtExec(tx.Stmt(l.deleteLabelsStmt), key)
func (l *ListOptionIndexer) deleteLabels(key string, tx transaction.Client) error {
_, err := tx.Stmt(l.deleteLabelsStmt).Exec(key)
if err != nil {
return &db.QueryError{QueryString: l.deleteLabelsQuery, Err: err}
}
@@ -467,48 +458,37 @@ func (l *ListOptionIndexer) executeQuery(ctx context.Context, queryInfo *QueryIn
stmt := l.Prepare(queryInfo.query)
defer l.CloseStmt(stmt)
tx, err := l.BeginTx(ctx, false)
if err != nil {
return nil, 0, "", err
}
txStmt := tx.Stmt(stmt)
rows, err := txStmt.QueryContext(ctx, queryInfo.params...)
if err != nil {
if cerr := tx.Cancel(); cerr != nil {
return nil, 0, "", fmt.Errorf("failed to cancel transaction (%v) after error: %w", cerr, err)
}
return nil, 0, "", &db.QueryError{QueryString: queryInfo.query, Err: err}
}
items, err := l.ReadObjects(rows, l.GetType(), l.GetShouldEncrypt())
if err != nil {
if cerr := tx.Cancel(); cerr != nil {
return nil, 0, "", fmt.Errorf("failed to cancel transaction (%v) after error: %w", cerr, err)
}
return nil, 0, "", err
}
total := len(items)
if queryInfo.countQuery != "" {
countStmt := l.Prepare(queryInfo.countQuery)
defer l.CloseStmt(countStmt)
txStmt := tx.Stmt(countStmt)
rows, err := txStmt.QueryContext(ctx, queryInfo.countParams...)
var items []any
var total int
err := l.WithTransaction(ctx, false, func(tx transaction.Client) error {
txStmt := tx.Stmt(stmt)
rows, err := txStmt.QueryContext(ctx, queryInfo.params...)
if err != nil {
if cerr := tx.Cancel(); cerr != nil {
return nil, 0, "", fmt.Errorf("failed to cancel transaction (%v) after error: %w", cerr, err)
}
return nil, 0, "", &db.QueryError{QueryString: queryInfo.countQuery, Err: err}
return &db.QueryError{QueryString: queryInfo.query, Err: err}
}
total, err = l.ReadInt(rows)
items, err = l.ReadObjects(rows, l.GetType(), l.GetShouldEncrypt())
if err != nil {
if cerr := tx.Cancel(); cerr != nil {
return nil, 0, "", fmt.Errorf("failed to cancel transaction (%v) after error: %w", cerr, err)
}
return nil, 0, "", fmt.Errorf("error reading query results: %w", err)
return err
}
}
if err := tx.Commit(); err != nil {
total = len(items)
if queryInfo.countQuery != "" {
countStmt := l.Prepare(queryInfo.countQuery)
defer l.CloseStmt(countStmt)
txStmt := tx.Stmt(countStmt)
rows, err := txStmt.QueryContext(ctx, queryInfo.countParams...)
if err != nil {
return &db.QueryError{QueryString: queryInfo.countQuery, Err: err}
}
total, err = l.ReadInt(rows)
if err != nil {
return fmt.Errorf("error reading query results: %w", err)
}
}
return nil
})
if err != nil {
return nil, 0, "", err
}