mirror of
https://github.com/rancher/steve.git
synced 2025-09-05 01:12:09 +00:00
sql: use a closure to wrap transactions (#469)
This introduces the a `WithTransaction` function, which is then used for all transactional work in Steve. Because `WithTransaction` takes care of all `Begin`s, `Commit`s and `Rollback`s, it eliminates the problem where forgotten open transactions can block all other operations (with long stalling and `SQLITE_BUSY` errors). This also: - merges together the disparate `DBClient` interfaces in one only `db.Client` interface with one unexported non-test implementation. I found this much easier to follow - refactors the transaction package in order to make it as minimal as possible, and as close to the wrapped `sql.Tx` and `sql.Stmt` functions as possible, in order to reduce cognitive load when working with this part of the codebase - simplifies tests accordingly - adds a couple of known files to `.gitignore` Credits to @tomleb for suggesting the approach: https://github.com/rancher/lasso/pull/121#pullrequestreview-2515872507
This commit is contained in:
@@ -11,6 +11,7 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/rancher/steve/pkg/sqlcache/db/transaction"
|
||||
"github.com/sirupsen/logrus"
|
||||
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
|
||||
"k8s.io/client-go/tools/cache"
|
||||
@@ -108,51 +109,49 @@ func NewListOptionIndexer(fields [][]string, s Store, namespaced bool) (*ListOpt
|
||||
columnDefs[index] = column
|
||||
}
|
||||
|
||||
tx, err := l.BeginTx(context.Background(), true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dbName := db.Sanitize(i.GetName())
|
||||
err = tx.Exec(fmt.Sprintf(createFieldsTableFmt, dbName, strings.Join(columnDefs, ", ")))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
columns := make([]string, len(indexedFields))
|
||||
qmarks := make([]string, len(indexedFields))
|
||||
setStatements := make([]string, len(indexedFields))
|
||||
|
||||
for index, field := range indexedFields {
|
||||
// create index for field
|
||||
err = tx.Exec(fmt.Sprintf(createFieldsIndexFmt, dbName, field, dbName, field))
|
||||
err = l.WithTransaction(context.Background(), true, func(tx transaction.Client) error {
|
||||
_, err = tx.Exec(fmt.Sprintf(createFieldsTableFmt, dbName, strings.Join(columnDefs, ", ")))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
// format field into column for prepared statement
|
||||
column := fmt.Sprintf(`"%s"`, field)
|
||||
columns[index] = column
|
||||
for index, field := range indexedFields {
|
||||
// create index for field
|
||||
_, err = tx.Exec(fmt.Sprintf(createFieldsIndexFmt, dbName, field, dbName, field))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// add placeholder for column's value in prepared statement
|
||||
qmarks[index] = "?"
|
||||
// format field into column for prepared statement
|
||||
column := fmt.Sprintf(`"%s"`, field)
|
||||
columns[index] = column
|
||||
|
||||
// add formatted set statement for prepared statement
|
||||
setStatement := fmt.Sprintf(`"%s" = excluded."%s"`, field, field)
|
||||
setStatements[index] = setStatement
|
||||
}
|
||||
createLabelsTableQuery := fmt.Sprintf(createLabelsTableFmt, dbName, dbName)
|
||||
err = tx.Exec(createLabelsTableQuery)
|
||||
if err != nil {
|
||||
return nil, &db.QueryError{QueryString: createLabelsTableQuery, Err: err}
|
||||
}
|
||||
// add placeholder for column's value in prepared statement
|
||||
qmarks[index] = "?"
|
||||
|
||||
createLabelsTableIndexQuery := fmt.Sprintf(createLabelsTableIndexFmt, dbName, dbName)
|
||||
err = tx.Exec(createLabelsTableIndexQuery)
|
||||
if err != nil {
|
||||
return nil, &db.QueryError{QueryString: createLabelsTableIndexQuery, Err: err}
|
||||
}
|
||||
// add formatted set statement for prepared statement
|
||||
setStatement := fmt.Sprintf(`"%s" = excluded."%s"`, field, field)
|
||||
setStatements[index] = setStatement
|
||||
}
|
||||
createLabelsTableQuery := fmt.Sprintf(createLabelsTableFmt, dbName, dbName)
|
||||
_, err = tx.Exec(createLabelsTableQuery)
|
||||
if err != nil {
|
||||
return &db.QueryError{QueryString: createLabelsTableQuery, Err: err}
|
||||
}
|
||||
|
||||
err = tx.Commit()
|
||||
createLabelsTableIndexQuery := fmt.Sprintf(createLabelsTableIndexFmt, dbName, dbName)
|
||||
_, err = tx.Exec(createLabelsTableIndexQuery)
|
||||
if err != nil {
|
||||
return &db.QueryError{QueryString: createLabelsTableIndexQuery, Err: err}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -180,16 +179,12 @@ func NewListOptionIndexer(fields [][]string, s Store, namespaced bool) (*ListOpt
|
||||
/* Core methods */
|
||||
|
||||
// addIndexFields saves sortable/filterable fields into tables
|
||||
func (l *ListOptionIndexer) addIndexFields(key string, obj any, tx db.TXClient) error {
|
||||
func (l *ListOptionIndexer) addIndexFields(key string, obj any, tx transaction.Client) error {
|
||||
args := []any{key}
|
||||
for _, field := range l.indexedFields {
|
||||
value, err := getField(obj, field)
|
||||
if err != nil {
|
||||
logrus.Errorf("cannot index object of type [%s] with key [%s] for indexer [%s]: %v", l.GetType().String(), key, l.GetName(), err)
|
||||
cErr := tx.Cancel()
|
||||
if cErr != nil {
|
||||
return fmt.Errorf("could not cancel transaction: %s while recovering from error: %w", cErr, err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
switch typedValue := value.(type) {
|
||||
@@ -201,15 +196,11 @@ func (l *ListOptionIndexer) addIndexFields(key string, obj any, tx db.TXClient)
|
||||
args = append(args, strings.Join(typedValue, "|"))
|
||||
default:
|
||||
err2 := fmt.Errorf("field %v has a non-supported type value: %v", field, value)
|
||||
cErr := tx.Cancel()
|
||||
if cErr != nil {
|
||||
return fmt.Errorf("could not cancel transaction: %s while recovering from error: %w", cErr, err2)
|
||||
}
|
||||
return err2
|
||||
}
|
||||
}
|
||||
|
||||
err := tx.StmtExec(tx.Stmt(l.addFieldStmt), args...)
|
||||
_, err := tx.Stmt(l.addFieldStmt).Exec(args...)
|
||||
if err != nil {
|
||||
return &db.QueryError{QueryString: l.addFieldQuery, Err: err}
|
||||
}
|
||||
@@ -217,14 +208,14 @@ func (l *ListOptionIndexer) addIndexFields(key string, obj any, tx db.TXClient)
|
||||
}
|
||||
|
||||
// labels are stored in tables that shadow the underlying object table for each GVK
|
||||
func (l *ListOptionIndexer) addLabels(key string, obj any, tx db.TXClient) error {
|
||||
func (l *ListOptionIndexer) addLabels(key string, obj any, tx transaction.Client) error {
|
||||
k8sObj, ok := obj.(*unstructured.Unstructured)
|
||||
if !ok {
|
||||
return fmt.Errorf("addLabels: unexpected object type, expected unstructured.Unstructured: %v", obj)
|
||||
}
|
||||
incomingLabels := k8sObj.GetLabels()
|
||||
for k, v := range incomingLabels {
|
||||
err := tx.StmtExec(tx.Stmt(l.upsertLabelsStmt), key, k, v)
|
||||
_, err := tx.Stmt(l.upsertLabelsStmt).Exec(key, k, v)
|
||||
if err != nil {
|
||||
return &db.QueryError{QueryString: l.upsertLabelsQuery, Err: err}
|
||||
}
|
||||
@@ -232,18 +223,18 @@ func (l *ListOptionIndexer) addLabels(key string, obj any, tx db.TXClient) error
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *ListOptionIndexer) deleteIndexFields(key string, tx db.TXClient) error {
|
||||
func (l *ListOptionIndexer) deleteIndexFields(key string, tx transaction.Client) error {
|
||||
args := []any{key}
|
||||
|
||||
err := tx.StmtExec(tx.Stmt(l.deleteFieldStmt), args...)
|
||||
_, err := tx.Stmt(l.deleteFieldStmt).Exec(args...)
|
||||
if err != nil {
|
||||
return &db.QueryError{QueryString: l.deleteFieldQuery, Err: err}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *ListOptionIndexer) deleteLabels(key string, tx db.TXClient) error {
|
||||
err := tx.StmtExec(tx.Stmt(l.deleteLabelsStmt), key)
|
||||
func (l *ListOptionIndexer) deleteLabels(key string, tx transaction.Client) error {
|
||||
_, err := tx.Stmt(l.deleteLabelsStmt).Exec(key)
|
||||
if err != nil {
|
||||
return &db.QueryError{QueryString: l.deleteLabelsQuery, Err: err}
|
||||
}
|
||||
@@ -467,48 +458,37 @@ func (l *ListOptionIndexer) executeQuery(ctx context.Context, queryInfo *QueryIn
|
||||
stmt := l.Prepare(queryInfo.query)
|
||||
defer l.CloseStmt(stmt)
|
||||
|
||||
tx, err := l.BeginTx(ctx, false)
|
||||
if err != nil {
|
||||
return nil, 0, "", err
|
||||
}
|
||||
|
||||
txStmt := tx.Stmt(stmt)
|
||||
rows, err := txStmt.QueryContext(ctx, queryInfo.params...)
|
||||
if err != nil {
|
||||
if cerr := tx.Cancel(); cerr != nil {
|
||||
return nil, 0, "", fmt.Errorf("failed to cancel transaction (%v) after error: %w", cerr, err)
|
||||
}
|
||||
return nil, 0, "", &db.QueryError{QueryString: queryInfo.query, Err: err}
|
||||
}
|
||||
items, err := l.ReadObjects(rows, l.GetType(), l.GetShouldEncrypt())
|
||||
if err != nil {
|
||||
if cerr := tx.Cancel(); cerr != nil {
|
||||
return nil, 0, "", fmt.Errorf("failed to cancel transaction (%v) after error: %w", cerr, err)
|
||||
}
|
||||
return nil, 0, "", err
|
||||
}
|
||||
|
||||
total := len(items)
|
||||
if queryInfo.countQuery != "" {
|
||||
countStmt := l.Prepare(queryInfo.countQuery)
|
||||
defer l.CloseStmt(countStmt)
|
||||
txStmt := tx.Stmt(countStmt)
|
||||
rows, err := txStmt.QueryContext(ctx, queryInfo.countParams...)
|
||||
var items []any
|
||||
var total int
|
||||
err := l.WithTransaction(ctx, false, func(tx transaction.Client) error {
|
||||
txStmt := tx.Stmt(stmt)
|
||||
rows, err := txStmt.QueryContext(ctx, queryInfo.params...)
|
||||
if err != nil {
|
||||
if cerr := tx.Cancel(); cerr != nil {
|
||||
return nil, 0, "", fmt.Errorf("failed to cancel transaction (%v) after error: %w", cerr, err)
|
||||
}
|
||||
return nil, 0, "", &db.QueryError{QueryString: queryInfo.countQuery, Err: err}
|
||||
return &db.QueryError{QueryString: queryInfo.query, Err: err}
|
||||
}
|
||||
total, err = l.ReadInt(rows)
|
||||
items, err = l.ReadObjects(rows, l.GetType(), l.GetShouldEncrypt())
|
||||
if err != nil {
|
||||
if cerr := tx.Cancel(); cerr != nil {
|
||||
return nil, 0, "", fmt.Errorf("failed to cancel transaction (%v) after error: %w", cerr, err)
|
||||
}
|
||||
return nil, 0, "", fmt.Errorf("error reading query results: %w", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
|
||||
total = len(items)
|
||||
if queryInfo.countQuery != "" {
|
||||
countStmt := l.Prepare(queryInfo.countQuery)
|
||||
defer l.CloseStmt(countStmt)
|
||||
txStmt := tx.Stmt(countStmt)
|
||||
rows, err := txStmt.QueryContext(ctx, queryInfo.countParams...)
|
||||
if err != nil {
|
||||
return &db.QueryError{QueryString: queryInfo.countQuery, Err: err}
|
||||
}
|
||||
total, err = l.ReadInt(rows)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error reading query results: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, 0, "", err
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user