1
0
mirror of https://github.com/rancher/steve.git synced 2025-09-12 13:31:57 +00:00

Handle transaction failure due to canceled context.Context (#662)

* Re-order SQL event hooks so events are last

* Add QueryRowContext for single line queries

* Add test case for unknown resource version

* Properly check rows and close it

* More accurate error message when context.Context is canceled

* Re-order test check
This commit is contained in:
Tom Lebreux
2025-06-09 13:39:09 -06:00
committed by GitHub
parent b695567794
commit b4db257cdb
7 changed files with 118 additions and 21 deletions

View File

@@ -65,16 +65,33 @@ func (c *client) WithTransaction(ctx context.Context, forWriting bool, f WithTra
return err
}
err = f(transaction.NewClient(tx))
if err != nil {
rerr := tx.Rollback()
err = errors.Join(err, rerr)
} else {
cerr := tx.Commit()
err = errors.Join(err, cerr)
if err = f(transaction.NewClient(tx)); err != nil {
rerr := c.rollback(ctx, tx)
return errors.Join(err, rerr)
}
err = c.commit(ctx, tx)
if err != nil {
// When the context.Context given to BeginTx is canceled, then the
// Tx is rolled back already, so rolling back again could have failed.
return err
}
return nil
}
func (c *client) commit(ctx context.Context, tx *sql.Tx) error {
err := tx.Commit()
if errors.Is(err, sql.ErrTxDone) && ctx.Err() == context.Canceled {
return fmt.Errorf("commit failed due to canceled context")
}
return err
}
func (c *client) rollback(ctx context.Context, tx *sql.Tx) error {
err := tx.Rollback()
if errors.Is(err, sql.ErrTxDone) && ctx.Err() == context.Canceled {
return fmt.Errorf("rollback failed due to canceled context")
}
return err
}

View File

@@ -41,4 +41,5 @@ type Stmt interface {
Exec(args ...any) (sql.Result, error)
Query(args ...any) (*sql.Rows, error)
QueryContext(ctx context.Context, args ...any) (*sql.Rows, error)
QueryRowContext(ctx context.Context, args ...any) *sql.Row
}

View File

@@ -155,3 +155,22 @@ func (mr *MockStmtMockRecorder) QueryContext(arg0 any, arg1 ...any) *gomock.Call
varargs := append([]any{arg0}, arg1...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryContext", reflect.TypeOf((*MockStmt)(nil).QueryContext), varargs...)
}
// QueryRowContext mocks base method.
func (m *MockStmt) QueryRowContext(arg0 context.Context, arg1 ...any) *sql.Row {
m.ctrl.T.Helper()
varargs := []any{arg0}
for _, a := range arg1 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "QueryRowContext", varargs...)
ret0, _ := ret[0].(*sql.Row)
return ret0
}
// QueryRowContext indicates an expected call of QueryRowContext.
func (mr *MockStmtMockRecorder) QueryRowContext(arg0 any, arg1 ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{arg0}, arg1...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryRowContext", reflect.TypeOf((*MockStmt)(nil).QueryRowContext), varargs...)
}