mirror of
https://github.com/rancher/steve.git
synced 2025-09-12 13:31:57 +00:00
Handle transaction failure due to canceled context.Context (#662)
* Re-order SQL event hooks so events are last * Add QueryRowContext for single line queries * Add test case for unknown resource version * Properly check rows and close it * More accurate error message when context.Context is canceled * Re-order test check
This commit is contained in:
@@ -65,16 +65,33 @@ func (c *client) WithTransaction(ctx context.Context, forWriting bool, f WithTra
|
||||
return err
|
||||
}
|
||||
|
||||
err = f(transaction.NewClient(tx))
|
||||
|
||||
if err != nil {
|
||||
rerr := tx.Rollback()
|
||||
err = errors.Join(err, rerr)
|
||||
} else {
|
||||
cerr := tx.Commit()
|
||||
err = errors.Join(err, cerr)
|
||||
if err = f(transaction.NewClient(tx)); err != nil {
|
||||
rerr := c.rollback(ctx, tx)
|
||||
return errors.Join(err, rerr)
|
||||
}
|
||||
|
||||
err = c.commit(ctx, tx)
|
||||
if err != nil {
|
||||
// When the context.Context given to BeginTx is canceled, then the
|
||||
// Tx is rolled back already, so rolling back again could have failed.
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) commit(ctx context.Context, tx *sql.Tx) error {
|
||||
err := tx.Commit()
|
||||
if errors.Is(err, sql.ErrTxDone) && ctx.Err() == context.Canceled {
|
||||
return fmt.Errorf("commit failed due to canceled context")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *client) rollback(ctx context.Context, tx *sql.Tx) error {
|
||||
err := tx.Rollback()
|
||||
if errors.Is(err, sql.ErrTxDone) && ctx.Err() == context.Canceled {
|
||||
return fmt.Errorf("rollback failed due to canceled context")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
|
@@ -41,4 +41,5 @@ type Stmt interface {
|
||||
Exec(args ...any) (sql.Result, error)
|
||||
Query(args ...any) (*sql.Rows, error)
|
||||
QueryContext(ctx context.Context, args ...any) (*sql.Rows, error)
|
||||
QueryRowContext(ctx context.Context, args ...any) *sql.Row
|
||||
}
|
||||
|
@@ -155,3 +155,22 @@ func (mr *MockStmtMockRecorder) QueryContext(arg0 any, arg1 ...any) *gomock.Call
|
||||
varargs := append([]any{arg0}, arg1...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryContext", reflect.TypeOf((*MockStmt)(nil).QueryContext), varargs...)
|
||||
}
|
||||
|
||||
// QueryRowContext mocks base method.
|
||||
func (m *MockStmt) QueryRowContext(arg0 context.Context, arg1 ...any) *sql.Row {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{arg0}
|
||||
for _, a := range arg1 {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "QueryRowContext", varargs...)
|
||||
ret0, _ := ret[0].(*sql.Row)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// QueryRowContext indicates an expected call of QueryRowContext.
|
||||
func (mr *MockStmtMockRecorder) QueryRowContext(arg0 any, arg1 ...any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]any{arg0}, arg1...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryRowContext", reflect.TypeOf((*MockStmt)(nil).QueryRowContext), varargs...)
|
||||
}
|
||||
|
Reference in New Issue
Block a user