mirror of
https://github.com/rancher/steve.git
synced 2025-09-05 01:12:09 +00:00
Handle transaction failure due to canceled context.Context (#662)
* Re-order SQL event hooks so events are last * Add QueryRowContext for single line queries * Add test case for unknown resource version * Properly check rows and close it * More accurate error message when context.Context is canceled * Re-order test check
This commit is contained in:
@@ -65,16 +65,33 @@ func (c *client) WithTransaction(ctx context.Context, forWriting bool, f WithTra
|
||||
return err
|
||||
}
|
||||
|
||||
err = f(transaction.NewClient(tx))
|
||||
|
||||
if err != nil {
|
||||
rerr := tx.Rollback()
|
||||
err = errors.Join(err, rerr)
|
||||
} else {
|
||||
cerr := tx.Commit()
|
||||
err = errors.Join(err, cerr)
|
||||
if err = f(transaction.NewClient(tx)); err != nil {
|
||||
rerr := c.rollback(ctx, tx)
|
||||
return errors.Join(err, rerr)
|
||||
}
|
||||
|
||||
err = c.commit(ctx, tx)
|
||||
if err != nil {
|
||||
// When the context.Context given to BeginTx is canceled, then the
|
||||
// Tx is rolled back already, so rolling back again could have failed.
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) commit(ctx context.Context, tx *sql.Tx) error {
|
||||
err := tx.Commit()
|
||||
if errors.Is(err, sql.ErrTxDone) && ctx.Err() == context.Canceled {
|
||||
return fmt.Errorf("commit failed due to canceled context")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *client) rollback(ctx context.Context, tx *sql.Tx) error {
|
||||
err := tx.Rollback()
|
||||
if errors.Is(err, sql.ErrTxDone) && ctx.Err() == context.Canceled {
|
||||
return fmt.Errorf("rollback failed due to canceled context")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
|
@@ -41,4 +41,5 @@ type Stmt interface {
|
||||
Exec(args ...any) (sql.Result, error)
|
||||
Query(args ...any) (*sql.Rows, error)
|
||||
QueryContext(ctx context.Context, args ...any) (*sql.Rows, error)
|
||||
QueryRowContext(ctx context.Context, args ...any) *sql.Row
|
||||
}
|
||||
|
@@ -155,3 +155,22 @@ func (mr *MockStmtMockRecorder) QueryContext(arg0 any, arg1 ...any) *gomock.Call
|
||||
varargs := append([]any{arg0}, arg1...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryContext", reflect.TypeOf((*MockStmt)(nil).QueryContext), varargs...)
|
||||
}
|
||||
|
||||
// QueryRowContext mocks base method.
|
||||
func (m *MockStmt) QueryRowContext(arg0 context.Context, arg1 ...any) *sql.Row {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{arg0}
|
||||
for _, a := range arg1 {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "QueryRowContext", varargs...)
|
||||
ret0, _ := ret[0].(*sql.Row)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// QueryRowContext indicates an expected call of QueryRowContext.
|
||||
func (mr *MockStmtMockRecorder) QueryRowContext(arg0 any, arg1 ...any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]any{arg0}, arg1...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryRowContext", reflect.TypeOf((*MockStmt)(nil).QueryRowContext), varargs...)
|
||||
}
|
||||
|
@@ -68,6 +68,7 @@ var (
|
||||
subfieldRegex = regexp.MustCompile(`([a-zA-Z]+)|(\[[-a-zA-Z./]+])|(\[[0-9]+])`)
|
||||
|
||||
ErrInvalidColumn = errors.New("supplied column is invalid")
|
||||
ErrTooOld = errors.New("resourceversion too old")
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -144,11 +145,11 @@ func NewListOptionIndexer(ctx context.Context, fields [][]string, s Store, names
|
||||
watchers: make(map[*watchKey]*watcher),
|
||||
}
|
||||
l.RegisterAfterAdd(l.addIndexFields)
|
||||
l.RegisterAfterAdd(l.addLabels)
|
||||
l.RegisterAfterAdd(l.notifyEventAdded)
|
||||
l.RegisterAfterUpdate(l.addIndexFields)
|
||||
l.RegisterAfterUpdate(l.notifyEventModified)
|
||||
l.RegisterAfterAdd(l.addLabels)
|
||||
l.RegisterAfterUpdate(l.addLabels)
|
||||
l.RegisterAfterUpdate(l.notifyEventModified)
|
||||
l.RegisterAfterDelete(l.deleteFieldsByKey)
|
||||
l.RegisterAfterDelete(l.deleteLabelsByKey)
|
||||
l.RegisterAfterDelete(l.notifyEventDeleted)
|
||||
@@ -264,22 +265,27 @@ func (l *ListOptionIndexer) Watch(ctx context.Context, opts WatchOptions, events
|
||||
// Even though we're not writing in this transaction, we prevent other writes to SQL
|
||||
// because we don't want to add more events while we're backfilling events, so we don't miss events
|
||||
err := l.WithTransaction(ctx, true, func(tx transaction.Client) error {
|
||||
rowIDRows, err := tx.Stmt(l.findEventsRowByRVStmt).QueryContext(ctx, targetRV)
|
||||
if err != nil {
|
||||
return &db.QueryError{QueryString: l.listEventsAfterQuery, Err: err}
|
||||
}
|
||||
if !rowIDRows.Next() && targetRV != latestRV {
|
||||
return fmt.Errorf("resourceversion too old")
|
||||
rowIDRow := tx.Stmt(l.findEventsRowByRVStmt).QueryRowContext(ctx, targetRV)
|
||||
if err := rowIDRow.Err(); err != nil {
|
||||
return &db.QueryError{QueryString: l.findEventsRowByRVQuery, Err: err}
|
||||
}
|
||||
|
||||
var rowID int
|
||||
rowIDRows.Scan(&rowID)
|
||||
err := rowIDRow.Scan(&rowID)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
if targetRV != latestRV {
|
||||
return ErrTooOld
|
||||
}
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("failed scan rowid: %w", err)
|
||||
}
|
||||
|
||||
// Backfilling previous events from resourceVersion
|
||||
rows, err := tx.Stmt(l.listEventsAfterStmt).QueryContext(ctx, rowID)
|
||||
if err != nil {
|
||||
return &db.QueryError{QueryString: l.listEventsAfterQuery, Err: err}
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var typ, rv string
|
||||
@@ -311,6 +317,10 @@ func (l *ListOptionIndexer) Watch(ctx context.Context, opts WatchOptions, events
|
||||
})
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, event := range events {
|
||||
eventsCh <- event
|
||||
}
|
||||
@@ -318,9 +328,13 @@ func (l *ListOptionIndexer) Watch(ctx context.Context, opts WatchOptions, events
|
||||
key = l.addWatcher(eventsCh, opts.Filter)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
<-ctx.Done()
|
||||
l.removeWatcher(key)
|
||||
return err
|
||||
return nil
|
||||
}
|
||||
|
||||
func toBytes(obj any) []byte {
|
||||
|
@@ -2199,6 +2199,7 @@ func TestWatchResourceVersion(t *testing.T) {
|
||||
tests := []struct {
|
||||
rv string
|
||||
expectedEvents []watch.Event
|
||||
expectedErr error
|
||||
}{
|
||||
{
|
||||
rv: "",
|
||||
@@ -2237,6 +2238,10 @@ func TestWatchResourceVersion(t *testing.T) {
|
||||
rv: rv5,
|
||||
expectedEvents: nil,
|
||||
},
|
||||
{
|
||||
rv: "unknown",
|
||||
expectedErr: ErrTooOld,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
@@ -2245,11 +2250,14 @@ func TestWatchResourceVersion(t *testing.T) {
|
||||
watcherCh, errCh := startWatcher(ctx, loi, test.rv)
|
||||
gotEvents := receiveEvents(watcherCh)
|
||||
|
||||
assert.Equal(t, test.expectedEvents, gotEvents)
|
||||
|
||||
cancel()
|
||||
err := waitStopWatcher(errCh)
|
||||
assert.NoError(t, err)
|
||||
if test.expectedErr != nil {
|
||||
assert.ErrorIs(t, err, ErrTooOld)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, test.expectedEvents, gotEvents)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@@ -99,6 +99,25 @@ func (mr *MockStmtMockRecorder) QueryContext(arg0 any, arg1 ...any) *gomock.Call
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryContext", reflect.TypeOf((*MockStmt)(nil).QueryContext), varargs...)
|
||||
}
|
||||
|
||||
// QueryRowContext mocks base method.
|
||||
func (m *MockStmt) QueryRowContext(arg0 context.Context, arg1 ...any) *sql.Row {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{arg0}
|
||||
for _, a := range arg1 {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "QueryRowContext", varargs...)
|
||||
ret0, _ := ret[0].(*sql.Row)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// QueryRowContext indicates an expected call of QueryRowContext.
|
||||
func (mr *MockStmtMockRecorder) QueryRowContext(arg0 any, arg1 ...any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]any{arg0}, arg1...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryRowContext", reflect.TypeOf((*MockStmt)(nil).QueryRowContext), varargs...)
|
||||
}
|
||||
|
||||
// MockTXClient is a mock of Client interface.
|
||||
type MockTXClient struct {
|
||||
ctrl *gomock.Controller
|
||||
|
@@ -99,6 +99,25 @@ func (mr *MockStmtMockRecorder) QueryContext(arg0 any, arg1 ...any) *gomock.Call
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryContext", reflect.TypeOf((*MockStmt)(nil).QueryContext), varargs...)
|
||||
}
|
||||
|
||||
// QueryRowContext mocks base method.
|
||||
func (m *MockStmt) QueryRowContext(arg0 context.Context, arg1 ...any) *sql.Row {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{arg0}
|
||||
for _, a := range arg1 {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "QueryRowContext", varargs...)
|
||||
ret0, _ := ret[0].(*sql.Row)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// QueryRowContext indicates an expected call of QueryRowContext.
|
||||
func (mr *MockStmtMockRecorder) QueryRowContext(arg0 any, arg1 ...any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]any{arg0}, arg1...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryRowContext", reflect.TypeOf((*MockStmt)(nil).QueryRowContext), varargs...)
|
||||
}
|
||||
|
||||
// MockTXClient is a mock of Client interface.
|
||||
type MockTXClient struct {
|
||||
ctrl *gomock.Controller
|
||||
|
Reference in New Issue
Block a user