1
0
mirror of https://github.com/rancher/steve.git synced 2025-09-05 09:21:12 +00:00

Handle transaction failure due to canceled context.Context (#662)

* Re-order SQL event hooks so events are last

* Add QueryRowContext for single line queries

* Add test case for unknown resource version

* Properly check rows and close it

* More accurate error message when context.Context is canceled

* Re-order test check
This commit is contained in:
Tom Lebreux
2025-06-09 13:39:09 -06:00
committed by GitHub
parent b695567794
commit b4db257cdb
7 changed files with 118 additions and 21 deletions

View File

@@ -65,16 +65,33 @@ func (c *client) WithTransaction(ctx context.Context, forWriting bool, f WithTra
return err
}
err = f(transaction.NewClient(tx))
if err != nil {
rerr := tx.Rollback()
err = errors.Join(err, rerr)
} else {
cerr := tx.Commit()
err = errors.Join(err, cerr)
if err = f(transaction.NewClient(tx)); err != nil {
rerr := c.rollback(ctx, tx)
return errors.Join(err, rerr)
}
err = c.commit(ctx, tx)
if err != nil {
// When the context.Context given to BeginTx is canceled, then the
// Tx is rolled back already, so rolling back again could have failed.
return err
}
return nil
}
func (c *client) commit(ctx context.Context, tx *sql.Tx) error {
err := tx.Commit()
if errors.Is(err, sql.ErrTxDone) && ctx.Err() == context.Canceled {
return fmt.Errorf("commit failed due to canceled context")
}
return err
}
func (c *client) rollback(ctx context.Context, tx *sql.Tx) error {
err := tx.Rollback()
if errors.Is(err, sql.ErrTxDone) && ctx.Err() == context.Canceled {
return fmt.Errorf("rollback failed due to canceled context")
}
return err
}

View File

@@ -41,4 +41,5 @@ type Stmt interface {
Exec(args ...any) (sql.Result, error)
Query(args ...any) (*sql.Rows, error)
QueryContext(ctx context.Context, args ...any) (*sql.Rows, error)
QueryRowContext(ctx context.Context, args ...any) *sql.Row
}

View File

@@ -155,3 +155,22 @@ func (mr *MockStmtMockRecorder) QueryContext(arg0 any, arg1 ...any) *gomock.Call
varargs := append([]any{arg0}, arg1...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryContext", reflect.TypeOf((*MockStmt)(nil).QueryContext), varargs...)
}
// QueryRowContext mocks base method.
func (m *MockStmt) QueryRowContext(arg0 context.Context, arg1 ...any) *sql.Row {
m.ctrl.T.Helper()
varargs := []any{arg0}
for _, a := range arg1 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "QueryRowContext", varargs...)
ret0, _ := ret[0].(*sql.Row)
return ret0
}
// QueryRowContext indicates an expected call of QueryRowContext.
func (mr *MockStmtMockRecorder) QueryRowContext(arg0 any, arg1 ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{arg0}, arg1...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryRowContext", reflect.TypeOf((*MockStmt)(nil).QueryRowContext), varargs...)
}

View File

@@ -68,6 +68,7 @@ var (
subfieldRegex = regexp.MustCompile(`([a-zA-Z]+)|(\[[-a-zA-Z./]+])|(\[[0-9]+])`)
ErrInvalidColumn = errors.New("supplied column is invalid")
ErrTooOld = errors.New("resourceversion too old")
)
const (
@@ -144,11 +145,11 @@ func NewListOptionIndexer(ctx context.Context, fields [][]string, s Store, names
watchers: make(map[*watchKey]*watcher),
}
l.RegisterAfterAdd(l.addIndexFields)
l.RegisterAfterAdd(l.addLabels)
l.RegisterAfterAdd(l.notifyEventAdded)
l.RegisterAfterUpdate(l.addIndexFields)
l.RegisterAfterUpdate(l.notifyEventModified)
l.RegisterAfterAdd(l.addLabels)
l.RegisterAfterUpdate(l.addLabels)
l.RegisterAfterUpdate(l.notifyEventModified)
l.RegisterAfterDelete(l.deleteFieldsByKey)
l.RegisterAfterDelete(l.deleteLabelsByKey)
l.RegisterAfterDelete(l.notifyEventDeleted)
@@ -264,22 +265,27 @@ func (l *ListOptionIndexer) Watch(ctx context.Context, opts WatchOptions, events
// Even though we're not writing in this transaction, we prevent other writes to SQL
// because we don't want to add more events while we're backfilling events, so we don't miss events
err := l.WithTransaction(ctx, true, func(tx transaction.Client) error {
rowIDRows, err := tx.Stmt(l.findEventsRowByRVStmt).QueryContext(ctx, targetRV)
if err != nil {
return &db.QueryError{QueryString: l.listEventsAfterQuery, Err: err}
}
if !rowIDRows.Next() && targetRV != latestRV {
return fmt.Errorf("resourceversion too old")
rowIDRow := tx.Stmt(l.findEventsRowByRVStmt).QueryRowContext(ctx, targetRV)
if err := rowIDRow.Err(); err != nil {
return &db.QueryError{QueryString: l.findEventsRowByRVQuery, Err: err}
}
var rowID int
rowIDRows.Scan(&rowID)
err := rowIDRow.Scan(&rowID)
if errors.Is(err, sql.ErrNoRows) {
if targetRV != latestRV {
return ErrTooOld
}
} else if err != nil {
return fmt.Errorf("failed scan rowid: %w", err)
}
// Backfilling previous events from resourceVersion
rows, err := tx.Stmt(l.listEventsAfterStmt).QueryContext(ctx, rowID)
if err != nil {
return &db.QueryError{QueryString: l.listEventsAfterQuery, Err: err}
}
defer rows.Close()
for rows.Next() {
var typ, rv string
@@ -311,6 +317,10 @@ func (l *ListOptionIndexer) Watch(ctx context.Context, opts WatchOptions, events
})
}
if err := rows.Err(); err != nil {
return err
}
for _, event := range events {
eventsCh <- event
}
@@ -318,9 +328,13 @@ func (l *ListOptionIndexer) Watch(ctx context.Context, opts WatchOptions, events
key = l.addWatcher(eventsCh, opts.Filter)
return nil
})
if err != nil {
return err
}
<-ctx.Done()
l.removeWatcher(key)
return err
return nil
}
func toBytes(obj any) []byte {

View File

@@ -2199,6 +2199,7 @@ func TestWatchResourceVersion(t *testing.T) {
tests := []struct {
rv string
expectedEvents []watch.Event
expectedErr error
}{
{
rv: "",
@@ -2237,6 +2238,10 @@ func TestWatchResourceVersion(t *testing.T) {
rv: rv5,
expectedEvents: nil,
},
{
rv: "unknown",
expectedErr: ErrTooOld,
},
}
for _, test := range tests {
@@ -2245,11 +2250,14 @@ func TestWatchResourceVersion(t *testing.T) {
watcherCh, errCh := startWatcher(ctx, loi, test.rv)
gotEvents := receiveEvents(watcherCh)
assert.Equal(t, test.expectedEvents, gotEvents)
cancel()
err := waitStopWatcher(errCh)
if test.expectedErr != nil {
assert.ErrorIs(t, err, ErrTooOld)
} else {
assert.NoError(t, err)
assert.Equal(t, test.expectedEvents, gotEvents)
}
})
}
}

View File

@@ -99,6 +99,25 @@ func (mr *MockStmtMockRecorder) QueryContext(arg0 any, arg1 ...any) *gomock.Call
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryContext", reflect.TypeOf((*MockStmt)(nil).QueryContext), varargs...)
}
// QueryRowContext mocks base method.
func (m *MockStmt) QueryRowContext(arg0 context.Context, arg1 ...any) *sql.Row {
m.ctrl.T.Helper()
varargs := []any{arg0}
for _, a := range arg1 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "QueryRowContext", varargs...)
ret0, _ := ret[0].(*sql.Row)
return ret0
}
// QueryRowContext indicates an expected call of QueryRowContext.
func (mr *MockStmtMockRecorder) QueryRowContext(arg0 any, arg1 ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{arg0}, arg1...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryRowContext", reflect.TypeOf((*MockStmt)(nil).QueryRowContext), varargs...)
}
// MockTXClient is a mock of Client interface.
type MockTXClient struct {
ctrl *gomock.Controller

View File

@@ -99,6 +99,25 @@ func (mr *MockStmtMockRecorder) QueryContext(arg0 any, arg1 ...any) *gomock.Call
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryContext", reflect.TypeOf((*MockStmt)(nil).QueryContext), varargs...)
}
// QueryRowContext mocks base method.
func (m *MockStmt) QueryRowContext(arg0 context.Context, arg1 ...any) *sql.Row {
m.ctrl.T.Helper()
varargs := []any{arg0}
for _, a := range arg1 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "QueryRowContext", varargs...)
ret0, _ := ret[0].(*sql.Row)
return ret0
}
// QueryRowContext indicates an expected call of QueryRowContext.
func (mr *MockStmtMockRecorder) QueryRowContext(arg0 any, arg1 ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{arg0}, arg1...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryRowContext", reflect.TypeOf((*MockStmt)(nil).QueryRowContext), varargs...)
}
// MockTXClient is a mock of Client interface.
type MockTXClient struct {
ctrl *gomock.Controller