diff --git a/pkg/sqlcache/db/client.go b/pkg/sqlcache/db/client.go index 5675bffb..043bbd94 100644 --- a/pkg/sqlcache/db/client.go +++ b/pkg/sqlcache/db/client.go @@ -65,16 +65,33 @@ func (c *client) WithTransaction(ctx context.Context, forWriting bool, f WithTra return err } - err = f(transaction.NewClient(tx)) - - if err != nil { - rerr := tx.Rollback() - err = errors.Join(err, rerr) - } else { - cerr := tx.Commit() - err = errors.Join(err, cerr) + if err = f(transaction.NewClient(tx)); err != nil { + rerr := c.rollback(ctx, tx) + return errors.Join(err, rerr) } + err = c.commit(ctx, tx) + if err != nil { + // When the context.Context given to BeginTx is canceled, then the + // Tx is rolled back already, so rolling back again could have failed. + return err + } + return nil +} + +func (c *client) commit(ctx context.Context, tx *sql.Tx) error { + err := tx.Commit() + if errors.Is(err, sql.ErrTxDone) && ctx.Err() == context.Canceled { + return fmt.Errorf("commit failed due to canceled context") + } + return err +} + +func (c *client) rollback(ctx context.Context, tx *sql.Tx) error { + err := tx.Rollback() + if errors.Is(err, sql.ErrTxDone) && ctx.Err() == context.Canceled { + return fmt.Errorf("rollback failed due to canceled context") + } return err } diff --git a/pkg/sqlcache/db/transaction/transaction.go b/pkg/sqlcache/db/transaction/transaction.go index cc9e0901..2b92d958 100644 --- a/pkg/sqlcache/db/transaction/transaction.go +++ b/pkg/sqlcache/db/transaction/transaction.go @@ -41,4 +41,5 @@ type Stmt interface { Exec(args ...any) (sql.Result, error) Query(args ...any) (*sql.Rows, error) QueryContext(ctx context.Context, args ...any) (*sql.Rows, error) + QueryRowContext(ctx context.Context, args ...any) *sql.Row } diff --git a/pkg/sqlcache/db/transaction_mocks_test.go b/pkg/sqlcache/db/transaction_mocks_test.go index 4a64cc16..c431ca47 100644 --- a/pkg/sqlcache/db/transaction_mocks_test.go +++ b/pkg/sqlcache/db/transaction_mocks_test.go @@ -155,3 +155,22 @@ func (mr *MockStmtMockRecorder) QueryContext(arg0 any, arg1 ...any) *gomock.Call varargs := append([]any{arg0}, arg1...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryContext", reflect.TypeOf((*MockStmt)(nil).QueryContext), varargs...) } + +// QueryRowContext mocks base method. +func (m *MockStmt) QueryRowContext(arg0 context.Context, arg1 ...any) *sql.Row { + m.ctrl.T.Helper() + varargs := []any{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "QueryRowContext", varargs...) + ret0, _ := ret[0].(*sql.Row) + return ret0 +} + +// QueryRowContext indicates an expected call of QueryRowContext. +func (mr *MockStmtMockRecorder) QueryRowContext(arg0 any, arg1 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryRowContext", reflect.TypeOf((*MockStmt)(nil).QueryRowContext), varargs...) +} diff --git a/pkg/sqlcache/informer/listoption_indexer.go b/pkg/sqlcache/informer/listoption_indexer.go index 4d587a99..e8bfd0e0 100644 --- a/pkg/sqlcache/informer/listoption_indexer.go +++ b/pkg/sqlcache/informer/listoption_indexer.go @@ -68,6 +68,7 @@ var ( subfieldRegex = regexp.MustCompile(`([a-zA-Z]+)|(\[[-a-zA-Z./]+])|(\[[0-9]+])`) ErrInvalidColumn = errors.New("supplied column is invalid") + ErrTooOld = errors.New("resourceversion too old") ) const ( @@ -144,11 +145,11 @@ func NewListOptionIndexer(ctx context.Context, fields [][]string, s Store, names watchers: make(map[*watchKey]*watcher), } l.RegisterAfterAdd(l.addIndexFields) + l.RegisterAfterAdd(l.addLabels) l.RegisterAfterAdd(l.notifyEventAdded) l.RegisterAfterUpdate(l.addIndexFields) - l.RegisterAfterUpdate(l.notifyEventModified) - l.RegisterAfterAdd(l.addLabels) l.RegisterAfterUpdate(l.addLabels) + l.RegisterAfterUpdate(l.notifyEventModified) l.RegisterAfterDelete(l.deleteFieldsByKey) l.RegisterAfterDelete(l.deleteLabelsByKey) l.RegisterAfterDelete(l.notifyEventDeleted) @@ -264,22 +265,27 @@ func (l *ListOptionIndexer) Watch(ctx context.Context, opts WatchOptions, events // Even though we're not writing in this transaction, we prevent other writes to SQL // because we don't want to add more events while we're backfilling events, so we don't miss events err := l.WithTransaction(ctx, true, func(tx transaction.Client) error { - rowIDRows, err := tx.Stmt(l.findEventsRowByRVStmt).QueryContext(ctx, targetRV) - if err != nil { - return &db.QueryError{QueryString: l.listEventsAfterQuery, Err: err} - } - if !rowIDRows.Next() && targetRV != latestRV { - return fmt.Errorf("resourceversion too old") + rowIDRow := tx.Stmt(l.findEventsRowByRVStmt).QueryRowContext(ctx, targetRV) + if err := rowIDRow.Err(); err != nil { + return &db.QueryError{QueryString: l.findEventsRowByRVQuery, Err: err} } var rowID int - rowIDRows.Scan(&rowID) + err := rowIDRow.Scan(&rowID) + if errors.Is(err, sql.ErrNoRows) { + if targetRV != latestRV { + return ErrTooOld + } + } else if err != nil { + return fmt.Errorf("failed scan rowid: %w", err) + } // Backfilling previous events from resourceVersion rows, err := tx.Stmt(l.listEventsAfterStmt).QueryContext(ctx, rowID) if err != nil { return &db.QueryError{QueryString: l.listEventsAfterQuery, Err: err} } + defer rows.Close() for rows.Next() { var typ, rv string @@ -311,6 +317,10 @@ func (l *ListOptionIndexer) Watch(ctx context.Context, opts WatchOptions, events }) } + if err := rows.Err(); err != nil { + return err + } + for _, event := range events { eventsCh <- event } @@ -318,9 +328,13 @@ func (l *ListOptionIndexer) Watch(ctx context.Context, opts WatchOptions, events key = l.addWatcher(eventsCh, opts.Filter) return nil }) + if err != nil { + return err + } + <-ctx.Done() l.removeWatcher(key) - return err + return nil } func toBytes(obj any) []byte { diff --git a/pkg/sqlcache/informer/listoption_indexer_test.go b/pkg/sqlcache/informer/listoption_indexer_test.go index d645a635..97d05ea4 100644 --- a/pkg/sqlcache/informer/listoption_indexer_test.go +++ b/pkg/sqlcache/informer/listoption_indexer_test.go @@ -2199,6 +2199,7 @@ func TestWatchResourceVersion(t *testing.T) { tests := []struct { rv string expectedEvents []watch.Event + expectedErr error }{ { rv: "", @@ -2237,6 +2238,10 @@ func TestWatchResourceVersion(t *testing.T) { rv: rv5, expectedEvents: nil, }, + { + rv: "unknown", + expectedErr: ErrTooOld, + }, } for _, test := range tests { @@ -2245,11 +2250,14 @@ func TestWatchResourceVersion(t *testing.T) { watcherCh, errCh := startWatcher(ctx, loi, test.rv) gotEvents := receiveEvents(watcherCh) - assert.Equal(t, test.expectedEvents, gotEvents) - cancel() err := waitStopWatcher(errCh) - assert.NoError(t, err) + if test.expectedErr != nil { + assert.ErrorIs(t, err, ErrTooOld) + } else { + assert.NoError(t, err) + assert.Equal(t, test.expectedEvents, gotEvents) + } }) } } diff --git a/pkg/sqlcache/informer/transaction_mocks_test.go b/pkg/sqlcache/informer/transaction_mocks_test.go index 64e885ee..4a86c446 100644 --- a/pkg/sqlcache/informer/transaction_mocks_test.go +++ b/pkg/sqlcache/informer/transaction_mocks_test.go @@ -99,6 +99,25 @@ func (mr *MockStmtMockRecorder) QueryContext(arg0 any, arg1 ...any) *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryContext", reflect.TypeOf((*MockStmt)(nil).QueryContext), varargs...) } +// QueryRowContext mocks base method. +func (m *MockStmt) QueryRowContext(arg0 context.Context, arg1 ...any) *sql.Row { + m.ctrl.T.Helper() + varargs := []any{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "QueryRowContext", varargs...) + ret0, _ := ret[0].(*sql.Row) + return ret0 +} + +// QueryRowContext indicates an expected call of QueryRowContext. +func (mr *MockStmtMockRecorder) QueryRowContext(arg0 any, arg1 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryRowContext", reflect.TypeOf((*MockStmt)(nil).QueryRowContext), varargs...) +} + // MockTXClient is a mock of Client interface. type MockTXClient struct { ctrl *gomock.Controller diff --git a/pkg/sqlcache/store/transaction_mocks_test.go b/pkg/sqlcache/store/transaction_mocks_test.go index 85a3e177..bd2fff44 100644 --- a/pkg/sqlcache/store/transaction_mocks_test.go +++ b/pkg/sqlcache/store/transaction_mocks_test.go @@ -99,6 +99,25 @@ func (mr *MockStmtMockRecorder) QueryContext(arg0 any, arg1 ...any) *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryContext", reflect.TypeOf((*MockStmt)(nil).QueryContext), varargs...) } +// QueryRowContext mocks base method. +func (m *MockStmt) QueryRowContext(arg0 context.Context, arg1 ...any) *sql.Row { + m.ctrl.T.Helper() + varargs := []any{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "QueryRowContext", varargs...) + ret0, _ := ret[0].(*sql.Row) + return ret0 +} + +// QueryRowContext indicates an expected call of QueryRowContext. +func (mr *MockStmtMockRecorder) QueryRowContext(arg0 any, arg1 ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryRowContext", reflect.TypeOf((*MockStmt)(nil).QueryRowContext), varargs...) +} + // MockTXClient is a mock of Client interface. type MockTXClient struct { ctrl *gomock.Controller