1
0
mirror of https://github.com/rancher/steve.git synced 2025-09-05 09:21:12 +00:00

sql: propagate and use contexts (#465)

Previous SQLite-related code used context.Background() and context.TODO() because it was not developed with context awareness.

This commit propagates the main Steve context so that it can be used when interacting with SQL context-aware functions.

This PR removes all production-code use of context.Background() and context.TODO() and replaces test-code use of TODO with Background.

Contributes to rancher/rancher#47825
This commit is contained in:
Silvio Moioli
2025-02-12 09:46:10 +01:00
committed by GitHub
parent e71f8c455d
commit 3350323f91
16 changed files with 160 additions and 121 deletions

View File

@@ -64,7 +64,7 @@ func TestNewIndexer(t *testing.T) {
store.EXPECT().Prepare(fmt.Sprintf(listByIndexFmt, storeName, storeName))
store.EXPECT().Prepare(fmt.Sprintf(listKeyByIndexFmt, storeName))
store.EXPECT().Prepare(fmt.Sprintf(listIndexValuesFmt, storeName))
indexer, err := NewIndexer(indexers, store)
indexer, err := NewIndexer(context.Background(), indexers, store)
assert.Nil(t, err)
assert.Equal(t, cache.Indexers(indexers), indexer.indexers)
}})
@@ -79,7 +79,7 @@ func TestNewIndexer(t *testing.T) {
}
store.EXPECT().GetName().AnyTimes().Return("someStoreName")
store.EXPECT().WithTransaction(gomock.Any(), true, gomock.Any()).Return(fmt.Errorf("error"))
_, err := NewIndexer(indexers, store)
_, err := NewIndexer(context.Background(), indexers, store)
assert.NotNil(t, err)
}})
tests = append(tests, testCase{description: "NewIndexer() with Client Exec() error on first call to Exec(), should return error", test: func(t *testing.T) {
@@ -103,7 +103,7 @@ func TestNewIndexer(t *testing.T) {
t.Fail()
}
})
_, err := NewIndexer(indexers, store)
_, err := NewIndexer(context.Background(), indexers, store)
assert.NotNil(t, err)
}})
tests = append(tests, testCase{description: "NewIndexer() with Client Exec() error on second call to Exec(), should return error", test: func(t *testing.T) {
@@ -129,7 +129,7 @@ func TestNewIndexer(t *testing.T) {
}
})
_, err := NewIndexer(indexers, store)
_, err := NewIndexer(context.Background(), indexers, store)
assert.NotNil(t, err)
}})
tests = append(tests, testCase{description: "NewIndexer() with Client Commit() error, should return error", test: func(t *testing.T) {
@@ -153,7 +153,7 @@ func TestNewIndexer(t *testing.T) {
t.Fail()
}
})
_, err := NewIndexer(indexers, store)
_, err := NewIndexer(context.Background(), indexers, store)
assert.NotNil(t, err)
}})
t.Parallel()
@@ -177,6 +177,7 @@ func TestAfterUpsert(t *testing.T) {
deleteIndicesStmt := NewMockStmt(gomock.NewController(t))
addIndexStmt := NewMockStmt(gomock.NewController(t))
indexer := &Indexer{
ctx: context.Background(),
Store: store,
indexers: map[string]cache.IndexFunc{
"a": func(obj interface{}) ([]string, error) {
@@ -199,6 +200,7 @@ func TestAfterUpsert(t *testing.T) {
objKey := "key"
deleteIndicesStmt := NewMockStmt(gomock.NewController(t))
indexer := &Indexer{
ctx: context.Background(),
Store: store,
indexers: map[string]cache.IndexFunc{
@@ -221,6 +223,7 @@ func TestAfterUpsert(t *testing.T) {
addIndexStmt := NewMockStmt(gomock.NewController(t))
objKey := "key"
indexer := &Indexer{
ctx: context.Background(),
Store: store,
indexers: map[string]cache.IndexFunc{
"a": func(obj interface{}) ([]string, error) {
@@ -258,6 +261,7 @@ func TestIndex(t *testing.T) {
objKey := "key"
indexName := "someindexname"
indexer := &Indexer{
ctx: context.Background(),
Store: store,
listByIndexStmt: listStmt,
indexers: map[string]cache.IndexFunc{
@@ -268,7 +272,7 @@ func TestIndex(t *testing.T) {
}
testObject := testStoreObject{Id: "something", Val: "a"}
store.EXPECT().QueryForRows(context.TODO(), indexer.listByIndexStmt, indexName, objKey).Return(rows, nil)
store.EXPECT().QueryForRows(context.Background(), indexer.listByIndexStmt, indexName, objKey).Return(rows, nil)
store.EXPECT().GetType().Return(reflect.TypeOf(testObject))
store.EXPECT().GetShouldEncrypt().Return(false)
store.EXPECT().ReadObjects(rows, reflect.TypeOf(testObject), false).Return([]any{testObject}, nil)
@@ -283,6 +287,7 @@ func TestIndex(t *testing.T) {
objKey := "key"
indexName := "someindexname"
indexer := &Indexer{
ctx: context.Background(),
Store: store,
listByIndexStmt: listStmt,
indexers: map[string]cache.IndexFunc{
@@ -293,7 +298,7 @@ func TestIndex(t *testing.T) {
}
testObject := testStoreObject{Id: "something", Val: "a"}
store.EXPECT().QueryForRows(context.TODO(), indexer.listByIndexStmt, indexName, objKey).Return(rows, nil)
store.EXPECT().QueryForRows(context.Background(), indexer.listByIndexStmt, indexName, objKey).Return(rows, nil)
store.EXPECT().GetType().Return(reflect.TypeOf(testObject))
store.EXPECT().GetShouldEncrypt().Return(false)
store.EXPECT().ReadObjects(rows, reflect.TypeOf(testObject), false).Return([]any{testObject, testObject}, nil)
@@ -308,6 +313,7 @@ func TestIndex(t *testing.T) {
objKey := "key"
indexName := "someindexname"
indexer := &Indexer{
ctx: context.Background(),
Store: store,
listByIndexStmt: listStmt,
indexers: map[string]cache.IndexFunc{
@@ -318,7 +324,7 @@ func TestIndex(t *testing.T) {
}
testObject := testStoreObject{Id: "something", Val: "a"}
store.EXPECT().QueryForRows(context.TODO(), indexer.listByIndexStmt, indexName, objKey).Return(rows, nil)
store.EXPECT().QueryForRows(context.Background(), indexer.listByIndexStmt, indexName, objKey).Return(rows, nil)
store.EXPECT().GetType().Return(reflect.TypeOf(testObject))
store.EXPECT().GetShouldEncrypt().Return(false)
store.EXPECT().ReadObjects(rows, reflect.TypeOf(testObject), false).Return([]any{}, nil)
@@ -332,6 +338,7 @@ func TestIndex(t *testing.T) {
objKey := "key"
indexName := "someindexname"
indexer := &Indexer{
ctx: context.Background(),
Store: store,
listByIndexStmt: listStmt,
indexers: map[string]cache.IndexFunc{
@@ -351,6 +358,7 @@ func TestIndex(t *testing.T) {
objKey := "key"
indexName := "someindexname"
indexer := &Indexer{
ctx: context.Background(),
Store: store,
listByIndexStmt: listStmt,
indexers: map[string]cache.IndexFunc{
@@ -361,7 +369,7 @@ func TestIndex(t *testing.T) {
}
testObject := testStoreObject{Id: "something", Val: "a"}
store.EXPECT().QueryForRows(context.TODO(), indexer.listByIndexStmt, indexName, objKey).Return(nil, fmt.Errorf("error"))
store.EXPECT().QueryForRows(context.Background(), indexer.listByIndexStmt, indexName, objKey).Return(nil, fmt.Errorf("error"))
_, err := indexer.Index(indexName, testObject)
assert.NotNil(t, err)
}})
@@ -372,6 +380,7 @@ func TestIndex(t *testing.T) {
objKey := "key"
indexName := "someindexname"
indexer := &Indexer{
ctx: context.Background(),
Store: store,
listByIndexStmt: listStmt,
indexers: map[string]cache.IndexFunc{
@@ -382,7 +391,7 @@ func TestIndex(t *testing.T) {
}
testObject := testStoreObject{Id: "something", Val: "a"}
store.EXPECT().QueryForRows(context.TODO(), indexer.listByIndexStmt, indexName, objKey).Return(rows, nil)
store.EXPECT().QueryForRows(context.Background(), indexer.listByIndexStmt, indexName, objKey).Return(rows, nil)
store.EXPECT().GetType().Return(reflect.TypeOf(testObject))
store.EXPECT().GetShouldEncrypt().Return(false)
store.EXPECT().ReadObjects(rows, reflect.TypeOf(testObject), false).Return([]any{testObject}, fmt.Errorf("error"))
@@ -396,6 +405,7 @@ func TestIndex(t *testing.T) {
objKey := "key"
indexName := "someindexname"
indexer := &Indexer{
ctx: context.Background(),
Store: store,
listByIndexStmt: listStmt,
indexers: map[string]cache.IndexFunc{
@@ -409,7 +419,7 @@ func TestIndex(t *testing.T) {
store.EXPECT().GetName().Return("name")
stmt := &sql.Stmt{}
store.EXPECT().Prepare(fmt.Sprintf(selectQueryFmt, "name", ", ?")).Return(stmt)
store.EXPECT().QueryForRows(context.TODO(), indexer.listByIndexStmt, indexName, objKey, objKey+"2").Return(rows, nil)
store.EXPECT().QueryForRows(context.Background(), indexer.listByIndexStmt, indexName, objKey, objKey+"2").Return(rows, nil)
store.EXPECT().GetType().Return(reflect.TypeOf(testObject))
store.EXPECT().GetShouldEncrypt().Return(false)
store.EXPECT().ReadObjects(rows, reflect.TypeOf(testObject), false).Return([]any{testObject}, nil)
@@ -439,12 +449,13 @@ func TestByIndex(t *testing.T) {
objKey := "key"
indexName := "someindexname"
indexer := &Indexer{
ctx: context.Background(),
Store: store,
listByIndexStmt: listStmt,
}
testObject := testStoreObject{Id: "something", Val: "a"}
store.EXPECT().QueryForRows(context.TODO(), indexer.listByIndexStmt, indexName, objKey).Return(rows, nil)
store.EXPECT().QueryForRows(context.Background(), indexer.listByIndexStmt, indexName, objKey).Return(rows, nil)
store.EXPECT().GetType().Return(reflect.TypeOf(testObject))
store.EXPECT().GetShouldEncrypt().Return(false)
store.EXPECT().ReadObjects(rows, reflect.TypeOf(testObject), false).Return([]any{testObject}, nil)
@@ -459,12 +470,13 @@ func TestByIndex(t *testing.T) {
objKey := "key"
indexName := "someindexname"
indexer := &Indexer{
ctx: context.Background(),
Store: store,
listByIndexStmt: listStmt,
}
testObject := testStoreObject{Id: "something", Val: "a"}
store.EXPECT().QueryForRows(context.TODO(), indexer.listByIndexStmt, indexName, objKey).Return(rows, nil)
store.EXPECT().QueryForRows(context.Background(), indexer.listByIndexStmt, indexName, objKey).Return(rows, nil)
store.EXPECT().GetType().Return(reflect.TypeOf(testObject))
store.EXPECT().GetShouldEncrypt().Return(false)
store.EXPECT().ReadObjects(rows, reflect.TypeOf(testObject), false).Return([]any{testObject, testObject}, nil)
@@ -479,12 +491,13 @@ func TestByIndex(t *testing.T) {
objKey := "key"
indexName := "someindexname"
indexer := &Indexer{
ctx: context.Background(),
Store: store,
listByIndexStmt: listStmt,
}
testObject := testStoreObject{Id: "something", Val: "a"}
store.EXPECT().QueryForRows(context.TODO(), indexer.listByIndexStmt, indexName, objKey).Return(rows, nil)
store.EXPECT().QueryForRows(context.Background(), indexer.listByIndexStmt, indexName, objKey).Return(rows, nil)
store.EXPECT().GetType().Return(reflect.TypeOf(testObject))
store.EXPECT().GetShouldEncrypt().Return(false)
store.EXPECT().ReadObjects(rows, reflect.TypeOf(testObject), false).Return([]any{}, nil)
@@ -498,11 +511,12 @@ func TestByIndex(t *testing.T) {
objKey := "key"
indexName := "someindexname"
indexer := &Indexer{
ctx: context.Background(),
Store: store,
listByIndexStmt: listStmt,
}
store.EXPECT().QueryForRows(context.TODO(), indexer.listByIndexStmt, indexName, objKey).Return(nil, fmt.Errorf("error"))
store.EXPECT().QueryForRows(context.Background(), indexer.listByIndexStmt, indexName, objKey).Return(nil, fmt.Errorf("error"))
_, err := indexer.ByIndex(indexName, objKey)
assert.NotNil(t, err)
}})
@@ -513,12 +527,13 @@ func TestByIndex(t *testing.T) {
objKey := "key"
indexName := "someindexname"
indexer := &Indexer{
ctx: context.Background(),
Store: store,
listByIndexStmt: listStmt,
}
testObject := testStoreObject{Id: "something", Val: "a"}
store.EXPECT().QueryForRows(context.TODO(), indexer.listByIndexStmt, indexName, objKey).Return(rows, nil)
store.EXPECT().QueryForRows(context.Background(), indexer.listByIndexStmt, indexName, objKey).Return(rows, nil)
store.EXPECT().GetType().Return(reflect.TypeOf(testObject))
store.EXPECT().GetShouldEncrypt().Return(false)
store.EXPECT().ReadObjects(rows, reflect.TypeOf(testObject), false).Return([]any{testObject}, fmt.Errorf("error"))
@@ -545,10 +560,11 @@ func TestListIndexFuncValues(t *testing.T) {
listStmt := &sql.Stmt{}
indexName := "someindexname"
indexer := &Indexer{
ctx: context.Background(),
Store: store,
listByIndexStmt: listStmt,
}
store.EXPECT().QueryForRows(context.TODO(), indexer.listIndexValuesStmt, indexName).Return(rows, nil)
store.EXPECT().QueryForRows(context.Background(), indexer.listIndexValuesStmt, indexName).Return(rows, nil)
store.EXPECT().ReadStrings(rows).Return([]string{"somestrings"}, nil)
vals := indexer.ListIndexFuncValues(indexName)
assert.Equal(t, []string{"somestrings"}, vals)
@@ -558,10 +574,11 @@ func TestListIndexFuncValues(t *testing.T) {
listStmt := &sql.Stmt{}
indexName := "someindexname"
indexer := &Indexer{
ctx: context.Background(),
Store: store,
listByIndexStmt: listStmt,
}
store.EXPECT().QueryForRows(context.TODO(), indexer.listIndexValuesStmt, indexName).Return(nil, fmt.Errorf("error"))
store.EXPECT().QueryForRows(context.Background(), indexer.listIndexValuesStmt, indexName).Return(nil, fmt.Errorf("error"))
assert.Panics(t, func() { indexer.ListIndexFuncValues(indexName) })
}})
tests = append(tests, testCase{description: "ListIndexFuncvalues() with ReadStrings() error returned from store, should panic", test: func(t *testing.T) {
@@ -570,10 +587,11 @@ func TestListIndexFuncValues(t *testing.T) {
listStmt := &sql.Stmt{}
indexName := "someindexname"
indexer := &Indexer{
ctx: context.Background(),
Store: store,
listByIndexStmt: listStmt,
}
store.EXPECT().QueryForRows(context.TODO(), indexer.listIndexValuesStmt, indexName).Return(rows, nil)
store.EXPECT().QueryForRows(context.Background(), indexer.listIndexValuesStmt, indexName).Return(rows, nil)
store.EXPECT().ReadStrings(rows).Return([]string{"somestrings"}, fmt.Errorf("error"))
assert.Panics(t, func() { indexer.ListIndexFuncValues(indexName) })
}})
@@ -599,6 +617,7 @@ func TestGetIndexers(t *testing.T) {
},
}
indexer := &Indexer{
ctx: context.Background(),
indexers: expectedIndexers,
}
indexers := indexer.GetIndexers()