mirror of
https://github.com/rancher/steve.git
synced 2025-09-05 09:21:12 +00:00
Handle transaction failure due to canceled context.Context (#662)
* Re-order SQL event hooks so events are last * Add QueryRowContext for single line queries * Add test case for unknown resource version * Properly check rows and close it * More accurate error message when context.Context is canceled * Re-order test check
This commit is contained in:
@@ -65,16 +65,33 @@ func (c *client) WithTransaction(ctx context.Context, forWriting bool, f WithTra
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = f(transaction.NewClient(tx))
|
if err = f(transaction.NewClient(tx)); err != nil {
|
||||||
|
rerr := c.rollback(ctx, tx)
|
||||||
if err != nil {
|
return errors.Join(err, rerr)
|
||||||
rerr := tx.Rollback()
|
|
||||||
err = errors.Join(err, rerr)
|
|
||||||
} else {
|
|
||||||
cerr := tx.Commit()
|
|
||||||
err = errors.Join(err, cerr)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = c.commit(ctx, tx)
|
||||||
|
if err != nil {
|
||||||
|
// When the context.Context given to BeginTx is canceled, then the
|
||||||
|
// Tx is rolled back already, so rolling back again could have failed.
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *client) commit(ctx context.Context, tx *sql.Tx) error {
|
||||||
|
err := tx.Commit()
|
||||||
|
if errors.Is(err, sql.ErrTxDone) && ctx.Err() == context.Canceled {
|
||||||
|
return fmt.Errorf("commit failed due to canceled context")
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *client) rollback(ctx context.Context, tx *sql.Tx) error {
|
||||||
|
err := tx.Rollback()
|
||||||
|
if errors.Is(err, sql.ErrTxDone) && ctx.Err() == context.Canceled {
|
||||||
|
return fmt.Errorf("rollback failed due to canceled context")
|
||||||
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -41,4 +41,5 @@ type Stmt interface {
|
|||||||
Exec(args ...any) (sql.Result, error)
|
Exec(args ...any) (sql.Result, error)
|
||||||
Query(args ...any) (*sql.Rows, error)
|
Query(args ...any) (*sql.Rows, error)
|
||||||
QueryContext(ctx context.Context, args ...any) (*sql.Rows, error)
|
QueryContext(ctx context.Context, args ...any) (*sql.Rows, error)
|
||||||
|
QueryRowContext(ctx context.Context, args ...any) *sql.Row
|
||||||
}
|
}
|
||||||
|
@@ -155,3 +155,22 @@ func (mr *MockStmtMockRecorder) QueryContext(arg0 any, arg1 ...any) *gomock.Call
|
|||||||
varargs := append([]any{arg0}, arg1...)
|
varargs := append([]any{arg0}, arg1...)
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryContext", reflect.TypeOf((*MockStmt)(nil).QueryContext), varargs...)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryContext", reflect.TypeOf((*MockStmt)(nil).QueryContext), varargs...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// QueryRowContext mocks base method.
|
||||||
|
func (m *MockStmt) QueryRowContext(arg0 context.Context, arg1 ...any) *sql.Row {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
varargs := []any{arg0}
|
||||||
|
for _, a := range arg1 {
|
||||||
|
varargs = append(varargs, a)
|
||||||
|
}
|
||||||
|
ret := m.ctrl.Call(m, "QueryRowContext", varargs...)
|
||||||
|
ret0, _ := ret[0].(*sql.Row)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryRowContext indicates an expected call of QueryRowContext.
|
||||||
|
func (mr *MockStmtMockRecorder) QueryRowContext(arg0 any, arg1 ...any) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
varargs := append([]any{arg0}, arg1...)
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryRowContext", reflect.TypeOf((*MockStmt)(nil).QueryRowContext), varargs...)
|
||||||
|
}
|
||||||
|
@@ -68,6 +68,7 @@ var (
|
|||||||
subfieldRegex = regexp.MustCompile(`([a-zA-Z]+)|(\[[-a-zA-Z./]+])|(\[[0-9]+])`)
|
subfieldRegex = regexp.MustCompile(`([a-zA-Z]+)|(\[[-a-zA-Z./]+])|(\[[0-9]+])`)
|
||||||
|
|
||||||
ErrInvalidColumn = errors.New("supplied column is invalid")
|
ErrInvalidColumn = errors.New("supplied column is invalid")
|
||||||
|
ErrTooOld = errors.New("resourceversion too old")
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -144,11 +145,11 @@ func NewListOptionIndexer(ctx context.Context, fields [][]string, s Store, names
|
|||||||
watchers: make(map[*watchKey]*watcher),
|
watchers: make(map[*watchKey]*watcher),
|
||||||
}
|
}
|
||||||
l.RegisterAfterAdd(l.addIndexFields)
|
l.RegisterAfterAdd(l.addIndexFields)
|
||||||
|
l.RegisterAfterAdd(l.addLabels)
|
||||||
l.RegisterAfterAdd(l.notifyEventAdded)
|
l.RegisterAfterAdd(l.notifyEventAdded)
|
||||||
l.RegisterAfterUpdate(l.addIndexFields)
|
l.RegisterAfterUpdate(l.addIndexFields)
|
||||||
l.RegisterAfterUpdate(l.notifyEventModified)
|
|
||||||
l.RegisterAfterAdd(l.addLabels)
|
|
||||||
l.RegisterAfterUpdate(l.addLabels)
|
l.RegisterAfterUpdate(l.addLabels)
|
||||||
|
l.RegisterAfterUpdate(l.notifyEventModified)
|
||||||
l.RegisterAfterDelete(l.deleteFieldsByKey)
|
l.RegisterAfterDelete(l.deleteFieldsByKey)
|
||||||
l.RegisterAfterDelete(l.deleteLabelsByKey)
|
l.RegisterAfterDelete(l.deleteLabelsByKey)
|
||||||
l.RegisterAfterDelete(l.notifyEventDeleted)
|
l.RegisterAfterDelete(l.notifyEventDeleted)
|
||||||
@@ -264,22 +265,27 @@ func (l *ListOptionIndexer) Watch(ctx context.Context, opts WatchOptions, events
|
|||||||
// Even though we're not writing in this transaction, we prevent other writes to SQL
|
// Even though we're not writing in this transaction, we prevent other writes to SQL
|
||||||
// because we don't want to add more events while we're backfilling events, so we don't miss events
|
// because we don't want to add more events while we're backfilling events, so we don't miss events
|
||||||
err := l.WithTransaction(ctx, true, func(tx transaction.Client) error {
|
err := l.WithTransaction(ctx, true, func(tx transaction.Client) error {
|
||||||
rowIDRows, err := tx.Stmt(l.findEventsRowByRVStmt).QueryContext(ctx, targetRV)
|
rowIDRow := tx.Stmt(l.findEventsRowByRVStmt).QueryRowContext(ctx, targetRV)
|
||||||
if err != nil {
|
if err := rowIDRow.Err(); err != nil {
|
||||||
return &db.QueryError{QueryString: l.listEventsAfterQuery, Err: err}
|
return &db.QueryError{QueryString: l.findEventsRowByRVQuery, Err: err}
|
||||||
}
|
|
||||||
if !rowIDRows.Next() && targetRV != latestRV {
|
|
||||||
return fmt.Errorf("resourceversion too old")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var rowID int
|
var rowID int
|
||||||
rowIDRows.Scan(&rowID)
|
err := rowIDRow.Scan(&rowID)
|
||||||
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
|
if targetRV != latestRV {
|
||||||
|
return ErrTooOld
|
||||||
|
}
|
||||||
|
} else if err != nil {
|
||||||
|
return fmt.Errorf("failed scan rowid: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
// Backfilling previous events from resourceVersion
|
// Backfilling previous events from resourceVersion
|
||||||
rows, err := tx.Stmt(l.listEventsAfterStmt).QueryContext(ctx, rowID)
|
rows, err := tx.Stmt(l.listEventsAfterStmt).QueryContext(ctx, rowID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &db.QueryError{QueryString: l.listEventsAfterQuery, Err: err}
|
return &db.QueryError{QueryString: l.listEventsAfterQuery, Err: err}
|
||||||
}
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var typ, rv string
|
var typ, rv string
|
||||||
@@ -311,6 +317,10 @@ func (l *ListOptionIndexer) Watch(ctx context.Context, opts WatchOptions, events
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
for _, event := range events {
|
for _, event := range events {
|
||||||
eventsCh <- event
|
eventsCh <- event
|
||||||
}
|
}
|
||||||
@@ -318,9 +328,13 @@ func (l *ListOptionIndexer) Watch(ctx context.Context, opts WatchOptions, events
|
|||||||
key = l.addWatcher(eventsCh, opts.Filter)
|
key = l.addWatcher(eventsCh, opts.Filter)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
<-ctx.Done()
|
<-ctx.Done()
|
||||||
l.removeWatcher(key)
|
l.removeWatcher(key)
|
||||||
return err
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func toBytes(obj any) []byte {
|
func toBytes(obj any) []byte {
|
||||||
|
@@ -2199,6 +2199,7 @@ func TestWatchResourceVersion(t *testing.T) {
|
|||||||
tests := []struct {
|
tests := []struct {
|
||||||
rv string
|
rv string
|
||||||
expectedEvents []watch.Event
|
expectedEvents []watch.Event
|
||||||
|
expectedErr error
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
rv: "",
|
rv: "",
|
||||||
@@ -2237,6 +2238,10 @@ func TestWatchResourceVersion(t *testing.T) {
|
|||||||
rv: rv5,
|
rv: rv5,
|
||||||
expectedEvents: nil,
|
expectedEvents: nil,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
rv: "unknown",
|
||||||
|
expectedErr: ErrTooOld,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
@@ -2245,11 +2250,14 @@ func TestWatchResourceVersion(t *testing.T) {
|
|||||||
watcherCh, errCh := startWatcher(ctx, loi, test.rv)
|
watcherCh, errCh := startWatcher(ctx, loi, test.rv)
|
||||||
gotEvents := receiveEvents(watcherCh)
|
gotEvents := receiveEvents(watcherCh)
|
||||||
|
|
||||||
assert.Equal(t, test.expectedEvents, gotEvents)
|
|
||||||
|
|
||||||
cancel()
|
cancel()
|
||||||
err := waitStopWatcher(errCh)
|
err := waitStopWatcher(errCh)
|
||||||
|
if test.expectedErr != nil {
|
||||||
|
assert.ErrorIs(t, err, ErrTooOld)
|
||||||
|
} else {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, test.expectedEvents, gotEvents)
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -99,6 +99,25 @@ func (mr *MockStmtMockRecorder) QueryContext(arg0 any, arg1 ...any) *gomock.Call
|
|||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryContext", reflect.TypeOf((*MockStmt)(nil).QueryContext), varargs...)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryContext", reflect.TypeOf((*MockStmt)(nil).QueryContext), varargs...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// QueryRowContext mocks base method.
|
||||||
|
func (m *MockStmt) QueryRowContext(arg0 context.Context, arg1 ...any) *sql.Row {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
varargs := []any{arg0}
|
||||||
|
for _, a := range arg1 {
|
||||||
|
varargs = append(varargs, a)
|
||||||
|
}
|
||||||
|
ret := m.ctrl.Call(m, "QueryRowContext", varargs...)
|
||||||
|
ret0, _ := ret[0].(*sql.Row)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryRowContext indicates an expected call of QueryRowContext.
|
||||||
|
func (mr *MockStmtMockRecorder) QueryRowContext(arg0 any, arg1 ...any) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
varargs := append([]any{arg0}, arg1...)
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryRowContext", reflect.TypeOf((*MockStmt)(nil).QueryRowContext), varargs...)
|
||||||
|
}
|
||||||
|
|
||||||
// MockTXClient is a mock of Client interface.
|
// MockTXClient is a mock of Client interface.
|
||||||
type MockTXClient struct {
|
type MockTXClient struct {
|
||||||
ctrl *gomock.Controller
|
ctrl *gomock.Controller
|
||||||
|
@@ -99,6 +99,25 @@ func (mr *MockStmtMockRecorder) QueryContext(arg0 any, arg1 ...any) *gomock.Call
|
|||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryContext", reflect.TypeOf((*MockStmt)(nil).QueryContext), varargs...)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryContext", reflect.TypeOf((*MockStmt)(nil).QueryContext), varargs...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// QueryRowContext mocks base method.
|
||||||
|
func (m *MockStmt) QueryRowContext(arg0 context.Context, arg1 ...any) *sql.Row {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
varargs := []any{arg0}
|
||||||
|
for _, a := range arg1 {
|
||||||
|
varargs = append(varargs, a)
|
||||||
|
}
|
||||||
|
ret := m.ctrl.Call(m, "QueryRowContext", varargs...)
|
||||||
|
ret0, _ := ret[0].(*sql.Row)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryRowContext indicates an expected call of QueryRowContext.
|
||||||
|
func (mr *MockStmtMockRecorder) QueryRowContext(arg0 any, arg1 ...any) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
varargs := append([]any{arg0}, arg1...)
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryRowContext", reflect.TypeOf((*MockStmt)(nil).QueryRowContext), varargs...)
|
||||||
|
}
|
||||||
|
|
||||||
// MockTXClient is a mock of Client interface.
|
// MockTXClient is a mock of Client interface.
|
||||||
type MockTXClient struct {
|
type MockTXClient struct {
|
||||||
ctrl *gomock.Controller
|
ctrl *gomock.Controller
|
||||||
|
Reference in New Issue
Block a user