mirror of
https://github.com/rancher/steve.git
synced 2025-08-22 16:16:49 +00:00
91 lines
2.9 KiB
Go
91 lines
2.9 KiB
Go
|
/*
Package transaction provides a client for a live SQL transaction, along with interfaces for the relevant database/sql
types. The transaction client automatically rolls back the transaction when Exec or StmtExec fails.

Using this package simplifies testing for callers because the underlying transaction can be mocked.
*/
|
||
|
package transaction
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"database/sql"
|
||
|
|
||
|
"github.com/pkg/errors"
|
||
|
)
|
||
|
|
||
|
// Client provides a way to interact with the underlying sql transaction.
|
||
|
type Client struct {
|
||
|
sqlTx SQLTx
|
||
|
}
|
||
|
|
||
|
// SQLTx is the subset of *sql.Tx methods that Client relies on. Declaring it as an interface lets tests substitute a
// fake transaction. The method signatures match those of *sql.Tx, so a real transaction satisfies it directly.
type SQLTx interface {
	// Commit finalizes the transaction.
	Commit() error
	// Rollback aborts the transaction.
	Rollback() error
	// Exec runs query with args inside the transaction.
	Exec(query string, args ...any) (sql.Result, error)
	// Stmt returns a transaction-specific version of an already-prepared statement.
	Stmt(stmt *sql.Stmt) *sql.Stmt
}
|
||
|
|
||
|
// Stmt is a subset of the methods of *sql.Stmt. Returning this interface instead of the concrete *sql.Stmt makes
// statements mockable, so tests can control their results without needing a real database connection.
type Stmt interface {
	// Exec executes the statement with args.
	Exec(args ...any) (sql.Result, error)
	// Query runs the statement with args and returns the resulting rows.
	Query(args ...any) (*sql.Rows, error)
	// QueryContext is Query with a caller-supplied context.
	QueryContext(ctx context.Context, args ...any) (*sql.Rows, error)
}
|
||
|
|
||
|
// NewClient returns a Client with the given transaction assigned.
|
||
|
func NewClient(tx SQLTx) *Client {
|
||
|
return &Client{sqlTx: tx}
|
||
|
}
|
||
|
|
||
|
// Commit commits the transaction and then unlocks the database.
|
||
|
func (c *Client) Commit() error {
|
||
|
return c.sqlTx.Commit()
|
||
|
}
|
||
|
|
||
|
// Exec uses the sqlTX Exec() with the given stmt and args. The transaction will be automatically rolled back if Exec()
|
||
|
// returns an error.
|
||
|
func (c *Client) Exec(stmt string, args ...any) error {
|
||
|
_, err := c.sqlTx.Exec(stmt, args...)
|
||
|
if err != nil {
|
||
|
return c.rollback(c.sqlTx, err)
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// Stmt adds the given sql.Stmt to the client's transaction and then returns a Stmt. An interface is being returned
|
||
|
// here to aid in testing callers by providing a way to configure the statement's behavior.
|
||
|
func (c *Client) Stmt(stmt *sql.Stmt) Stmt {
|
||
|
s := c.sqlTx.Stmt(stmt)
|
||
|
return s
|
||
|
}
|
||
|
|
||
|
// StmtExec Execs the given statement with the given args. It assumes the stmt has been added to the transaction. The
|
||
|
// transaction is rolled back if Stmt.Exec() returns an error.
|
||
|
func (c *Client) StmtExec(stmt Stmt, args ...any) error {
|
||
|
_, err := stmt.Exec(args...)
|
||
|
if err != nil {
|
||
|
return c.rollback(c.sqlTx, err)
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// rollback handles rollbacks and wraps errors if needed
|
||
|
func (c *Client) rollback(tx SQLTx, err error) error {
|
||
|
rerr := tx.Rollback()
|
||
|
if rerr != nil {
|
||
|
return errors.Wrapf(err, "Encountered error, then encountered another error while rolling back: %v", rerr)
|
||
|
}
|
||
|
return errors.Wrapf(err, "Encountered error, successfully rolled back")
|
||
|
}
|
||
|
|
||
|
// Cancel rollbacks the transaction without wrapping an error. This only needs to be called if Client has not returned
|
||
|
// an error yet or has not committed. Otherwise, transaction has already rolled back, or in the case of Commit() it is too
|
||
|
// late.
|
||
|
func (c *Client) Cancel() error {
|
||
|
rerr := c.sqlTx.Rollback()
|
||
|
if rerr != sql.ErrTxDone {
|
||
|
return rerr
|
||
|
}
|
||
|
return nil
|
||
|
}
|