diff --git a/pkg/sqlcache/db/client.go b/pkg/sqlcache/db/client.go index 043bbd94..8a5554a7 100644 --- a/pkg/sqlcache/db/client.go +++ b/pkg/sqlcache/db/client.go @@ -8,11 +8,14 @@ import ( "bytes" "context" "database/sql" + "database/sql/driver" "encoding/gob" "fmt" "io/fs" "os" "reflect" + "strconv" + "strings" "sync" "errors" @@ -21,11 +24,15 @@ import ( // needed for drivers _ "modernc.org/sqlite" + sqlite "modernc.org/sqlite" ) const ( // InformerObjectCacheDBPath is where SQLite's object database file will be stored relative to process running steve - InformerObjectCacheDBPath = "informer_object_cache.db" + // It's given in two parts because the root is used as the suffix for the tempfile, and then we'll add a ".db" after it. + // In non-test mode, we can append the ".db" extension right here. + InformerObjectCacheDBPathRoot = "informer_object_cache" + InformerObjectCacheDBPath = InformerObjectCacheDBPathRoot + ".db" informerObjectCachePerms fs.FileMode = 0o600 ) @@ -40,7 +47,7 @@ type Client interface { ReadInt(rows Rows) (int, error) Upsert(tx transaction.Client, stmt *sql.Stmt, key string, obj any, shouldEncrypt bool) error CloseStmt(closable Closable) error - NewConnection() error + NewConnection(isTemp bool) (string, error) } // WithTransaction runs f within a transaction. @@ -155,22 +162,22 @@ type Decryptor interface { Decrypt([]byte, []byte, uint32) ([]byte, error) } -// NewClient returns a client. If the given connection is nil then a default one will be created. -func NewClient(c Connection, encryptor Encryptor, decryptor Decryptor) (Client, error) { +// NewClient returns a client and the path to the database. If the given connection is nil then a default one will be created. +func NewClient(c Connection, encryptor Encryptor, decryptor Decryptor, useTempDir bool) (Client, string, error) { client := &client{ encryptor: encryptor, decryptor: decryptor, } if c != nil { client.conn = c - return client, nil + return client, "", nil } - err := client.NewConnection() + dbPath, err := client.NewConnection(useTempDir) if err != nil { - return nil, err + return nil, "", err } - return client, nil + return client, dbPath, nil } // Prepare prepares the given string into a sql statement on the client's connection. @@ -353,27 +360,43 @@ func closeRowsOnError(rows Rows, err error) error { // NewConnection checks for currently existing connection, closes one if it exists, removes any relevant db files, and opens a new connection which subsequently // creates new files. -func (c *client) NewConnection() error { +func (c *client) NewConnection(useTempDir bool) (string, error) { c.connLock.Lock() defer c.connLock.Unlock() if c.conn != nil { err := c.conn.Close() if err != nil { - return err + return "", err } } - err := os.RemoveAll(InformerObjectCacheDBPath) - if err != nil { - return err + if !useTempDir { + err := os.RemoveAll(InformerObjectCacheDBPath) + if err != nil { + return "", err + } } // Set the permissions in advance, because we can't control them if // the file is created by a sql.Open call instead. - if err := touchFile(InformerObjectCacheDBPath, informerObjectCachePerms); err != nil { - return nil + var dbPath string + if useTempDir { + dir := os.TempDir() + f, err := os.CreateTemp(dir, InformerObjectCacheDBPathRoot) + if err != nil { + return "", err + } + path := f.Name() + dbPath = path + ".db" + f.Close() + os.Remove(path) + } else { + dbPath = InformerObjectCacheDBPath + } + if err := touchFile(dbPath, informerObjectCachePerms); err != nil { + return dbPath, nil } - sqlDB, err := sql.Open("sqlite", "file:"+InformerObjectCacheDBPath+"?"+ + sqlDB, err := sql.Open("sqlite", "file:"+dbPath+"?"+ // open SQLite file in read-write mode, creating it if it does not exist "mode=rwc&"+ // use the WAL journal mode for consistency and efficiency @@ -390,11 +413,45 @@ func (c *client) NewConnection() error { // of BeginTx "_txlock=immediate") if err != nil { - return err + return dbPath, err } - + sqlite.RegisterDeterministicScalarFunction( + "extractBarredValue", + 2, + func(ctx *sqlite.FunctionContext, args []driver.Value) (driver.Value, error) { + var arg1 string + var arg2 int + switch argTyped := args[0].(type) { + case string: + arg1 = argTyped + case []byte: + arg1 = string(argTyped) + default: + return nil, fmt.Errorf("unsupported type for arg1: expected a string, got :%T", args[0]) + } + var err error + switch argTyped := args[1].(type) { + case int: + arg2 = argTyped + case string: + arg2, err = strconv.Atoi(argTyped) + case []byte: + arg2, err = strconv.Atoi(string(argTyped)) + default: + return nil, fmt.Errorf("unsupported type for arg2: expected an int, got: %T", args[0]) + } + if err != nil { + return nil, fmt.Errorf("problem with arg2: %w", err) + } + parts := strings.Split(arg1, "|") + if arg2 >= len(parts) || arg2 < 0 { + return "", nil + } + return parts[arg2], nil + }, + ) c.conn = sqlDB - return nil + return dbPath, nil } // This acts like "touch" for both existing files and non-existing files. diff --git a/pkg/sqlcache/db/client_test.go b/pkg/sqlcache/db/client_test.go index 4eef36e6..65d71cbc 100644 --- a/pkg/sqlcache/db/client_test.go +++ b/pkg/sqlcache/db/client_test.go @@ -43,7 +43,7 @@ func TestNewClient(t *testing.T) { encryptor: e, decryptor: d, } - client, err := NewClient(c, e, d) + client, _, err := NewClient(c, e, d, false) assert.Nil(t, err) assert.Equal(t, expectedClient, client) }, @@ -527,7 +527,7 @@ func TestNewConnection(t *testing.T) { client := SetupClient(t, c, e, d) c.EXPECT().Close().Return(nil) - err := client.NewConnection() + dbPath, err := client.NewConnection(true) assert.Nil(t, err) // Create a transaction to ensure that the file is written to disk. @@ -536,10 +536,10 @@ func TestNewConnection(t *testing.T) { }) assert.NoError(t, err) - assert.FileExists(t, InformerObjectCacheDBPath) - assertFileHasPermissions(t, InformerObjectCacheDBPath, 0600) + assert.FileExists(t, dbPath) + assertFileHasPermissions(t, dbPath, 0600) - err = os.Remove(InformerObjectCacheDBPath) + err = os.Remove(dbPath) if err != nil { assert.Fail(t, "could not remove object cache path after test") } @@ -581,7 +581,8 @@ func SetupMockRows(t *testing.T) *MockRows { } func SetupClient(t *testing.T, connection Connection, encryptor Encryptor, decryptor Decryptor) Client { - c, _ := NewClient(connection, encryptor, decryptor) + // No need to specify temp dir for this client because the connection is mocked + c, _, _ := NewClient(connection, encryptor, decryptor, false) return c } diff --git a/pkg/sqlcache/informer/db_mocks_test.go b/pkg/sqlcache/informer/db_mocks_test.go index 63fa214f..cf089c82 100644 --- a/pkg/sqlcache/informer/db_mocks_test.go +++ b/pkg/sqlcache/informer/db_mocks_test.go @@ -140,17 +140,18 @@ func (mr *MockClientMockRecorder) CloseStmt(arg0 any) *gomock.Call { } // NewConnection mocks base method. -func (m *MockClient) NewConnection() error { +func (m *MockClient) NewConnection(arg0 bool) (string, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "NewConnection") - ret0, _ := ret[0].(error) - return ret0 + ret := m.ctrl.Call(m, "NewConnection", arg0) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 } // NewConnection indicates an expected call of NewConnection. -func (mr *MockClientMockRecorder) NewConnection() *gomock.Call { +func (mr *MockClientMockRecorder) NewConnection(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewConnection", reflect.TypeOf((*MockClient)(nil).NewConnection)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewConnection", reflect.TypeOf((*MockClient)(nil).NewConnection), arg0) } // Prepare mocks base method. diff --git a/pkg/sqlcache/informer/factory/db_mocks_test.go b/pkg/sqlcache/informer/factory/db_mocks_test.go index 76fca697..e6905f15 100644 --- a/pkg/sqlcache/informer/factory/db_mocks_test.go +++ b/pkg/sqlcache/informer/factory/db_mocks_test.go @@ -57,17 +57,18 @@ func (mr *MockClientMockRecorder) CloseStmt(arg0 any) *gomock.Call { } // NewConnection mocks base method. -func (m *MockClient) NewConnection() error { +func (m *MockClient) NewConnection(arg0 bool) (string, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "NewConnection") - ret0, _ := ret[0].(error) - return ret0 + ret := m.ctrl.Call(m, "NewConnection", arg0) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 } // NewConnection indicates an expected call of NewConnection. -func (mr *MockClientMockRecorder) NewConnection() *gomock.Call { +func (mr *MockClientMockRecorder) NewConnection(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewConnection", reflect.TypeOf((*MockClient)(nil).NewConnection)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewConnection", reflect.TypeOf((*MockClient)(nil).NewConnection), arg0) } // Prepare mocks base method. diff --git a/pkg/sqlcache/informer/factory/informer_factory.go b/pkg/sqlcache/informer/factory/informer_factory.go index 18b2a962..f855d6ea 100644 --- a/pkg/sqlcache/informer/factory/informer_factory.go +++ b/pkg/sqlcache/informer/factory/informer_factory.go @@ -88,7 +88,7 @@ func NewCacheFactory(opts CacheFactoryOptions) (*CacheFactory, error) { if err != nil { return nil, err } - dbClient, err := db.NewClient(nil, m, m) + dbClient, _, err := db.NewClient(nil, m, m, false) if err != nil { return nil, err } @@ -204,7 +204,7 @@ func (f *CacheFactory) Reset() error { f.informers = make(map[schema.GroupVersionKind]*guardedInformer) // finally, reset the DB connection - err := f.dbClient.NewConnection() + _, err := f.dbClient.NewConnection(false) if err != nil { return err } diff --git a/pkg/sqlcache/informer/listoption_indexer.go b/pkg/sqlcache/informer/listoption_indexer.go index 2e3f159a..81e00db4 100644 --- a/pkg/sqlcache/informer/listoption_indexer.go +++ b/pkg/sqlcache/informer/listoption_indexer.go @@ -68,9 +68,10 @@ type ListOptionIndexer struct { } var ( - defaultIndexedFields = []string{"metadata.name", "metadata.creationTimestamp"} - defaultIndexNamespaced = "metadata.namespace" - subfieldRegex = regexp.MustCompile(`([a-zA-Z]+)|(\[[-a-zA-Z./]+])|(\[[0-9]+])`) + defaultIndexedFields = []string{"metadata.name", "metadata.creationTimestamp"} + defaultIndexNamespaced = "metadata.namespace" + subfieldRegex = regexp.MustCompile(`([a-zA-Z]+)|(\[[-a-zA-Z./]+])|(\[[0-9]+])`) + containsNonNumericRegex = regexp.MustCompile(`\D`) ErrInvalidColumn = errors.New("supplied column is invalid") ErrTooOld = errors.New("resourceversion too old") @@ -725,15 +726,15 @@ func (l *ListOptionIndexer) constructQuery(lo *sqltypes.ListOptions, partitions orderByClauses = append(orderByClauses, clause) params = append(params, sortParam) } else { - columnName := toColumnName(fields) - if err := l.validateColumn(columnName); err != nil { + fieldEntry, err := l.getValidFieldEntry("f", fields) + if err != nil { return queryInfo, err } direction := "ASC" if sortDirective.Order == sqltypes.DESC { direction = "DESC" } - orderByClauses = append(orderByClauses, fmt.Sprintf(`f."%s" %s`, columnName, direction)) + orderByClauses = append(orderByClauses, fmt.Sprintf("%s %s", fieldEntry, direction)) } } query += "\n ORDER BY " @@ -854,6 +855,49 @@ func (l *ListOptionIndexer) validateColumn(column string) error { return fmt.Errorf("column is invalid [%s]: %w", column, ErrInvalidColumn) } +// Suppose the query access something like 'spec.containers[3].image' but only +// spec.containers.image is specified in the index. If `spec.containers` is +// an array, then spec.containers.image is a pseudo-array of |-separated strings, +// and we can use our custom registered extractBarredValue function to extract the +// desired substring. +// +// The index can appear anywhere in the list of fields after the first entry, +// but we always end up with a |-separated list of substrings. Most of the time +// the index will be the second-last entry, but we lose nothing allowing for any +// position. +// Indices are 0-based. + +func (l *ListOptionIndexer) getValidFieldEntry(prefix string, fields []string) (string, error) { + columnName := toColumnName(fields) + err := l.validateColumn(columnName) + if err == nil { + return fmt.Sprintf(`%s."%s"`, prefix, columnName), nil + } + if len(fields) <= 2 { + return "", err + } + idx := -1 + for i := len(fields) - 1; i > 0; i-- { + if !containsNonNumericRegex.MatchString(fields[i]) { + idx = i + break + } + } + if idx == -1 { + // We don't have an index onto a valid field + return "", err + } + indexField := fields[idx] + // fields[len(fields):] gives empty array + otherFields := append(fields[0:idx], fields[idx+1:]...) + leadingColumnName := toColumnName(otherFields) + if l.validateColumn(leadingColumnName) != nil { + // We have an index, but not onto a valid field + return "", err + } + return fmt.Sprintf(`extractBarredValue(%s."%s", "%s")`, prefix, leadingColumnName, indexField), nil +} + // buildORClause creates an SQLite compatible query that ORs conditions built from passed filters func (l *ListOptionIndexer) buildORClauseFromFilters(orFilters sqltypes.OrFilter, dbName string, joinTableIndexByLabelName map[string]int) (string, []any, error) { var params []any @@ -973,8 +1017,8 @@ func ensureSortLabelsAreSelected(lo *sqltypes.ListOptions) { func (l *ListOptionIndexer) getFieldFilter(filter sqltypes.Filter) (string, []any, error) { opString := "" escapeString := "" - columnName := toColumnName(filter.Field) - if err := l.validateColumn(columnName); err != nil { + fieldEntry, err := l.getValidFieldEntry("f", filter.Field) + if err != nil { return "", nil, err } switch filter.Op { @@ -985,7 +1029,7 @@ func (l *ListOptionIndexer) getFieldFilter(filter sqltypes.Filter) (string, []an } else { opString = "=" } - clause := fmt.Sprintf(`f."%s" %s ?%s`, columnName, opString, escapeString) + clause := fmt.Sprintf("%s %s ?%s", fieldEntry, opString, escapeString) return clause, []any{formatMatchTarget(filter)}, nil case sqltypes.NotEq: if filter.Partial { @@ -994,7 +1038,7 @@ func (l *ListOptionIndexer) getFieldFilter(filter sqltypes.Filter) (string, []an } else { opString = "!=" } - clause := fmt.Sprintf(`f."%s" %s ?%s`, columnName, opString, escapeString) + clause := fmt.Sprintf("%s %s ?%s", fieldEntry, opString, escapeString) return clause, []any{formatMatchTarget(filter)}, nil case sqltypes.Lt, sqltypes.Gt: @@ -1002,7 +1046,7 @@ func (l *ListOptionIndexer) getFieldFilter(filter sqltypes.Filter) (string, []an if err != nil { return "", nil, err } - clause := fmt.Sprintf(`f."%s" %s ?`, columnName, sym) + clause := fmt.Sprintf("%s %s ?", fieldEntry, sym) return clause, []any{target}, nil case sqltypes.Exists, sqltypes.NotExists: @@ -1019,7 +1063,7 @@ func (l *ListOptionIndexer) getFieldFilter(filter sqltypes.Filter) (string, []an if filter.Op == sqltypes.NotIn { opString = "NOT IN" } - clause := fmt.Sprintf(`f."%s" %s %s`, columnName, opString, target) + clause := fmt.Sprintf("%s %s %s", fieldEntry, opString, target) matches := make([]any, len(filter.Matches)) for i, match := range filter.Matches { matches[i] = match diff --git a/pkg/sqlcache/informer/listoption_indexer_test.go b/pkg/sqlcache/informer/listoption_indexer_test.go index 844ec8cd..02792de3 100644 --- a/pkg/sqlcache/informer/listoption_indexer_test.go +++ b/pkg/sqlcache/informer/listoption_indexer_test.go @@ -11,6 +11,7 @@ import ( "database/sql" "errors" "fmt" + "os" "testing" "time" @@ -30,7 +31,7 @@ import ( "k8s.io/client-go/tools/cache" ) -func makeListOptionIndexer(ctx context.Context, opts ListOptionIndexerOptions) (*ListOptionIndexer, error) { +func makeListOptionIndexer(ctx context.Context, opts ListOptionIndexerOptions) (*ListOptionIndexer, string, error) { gvk := schema.GroupVersionKind{ Group: "", Version: "v1", @@ -41,25 +42,31 @@ func makeListOptionIndexer(ctx context.Context, opts ListOptionIndexerOptions) ( name := informerNameFromGVK(gvk) m, err := encryption.NewManager() if err != nil { - return nil, err + return nil, "", err } - db, err := db.NewClient(nil, m, m) + db, dbPath, err := db.NewClient(nil, m, m, true) if err != nil { - return nil, err + return nil, "", err } s, err := store.NewStore(ctx, example, cache.DeletionHandlingMetaNamespaceKeyFunc, db, false, name) if err != nil { - return nil, err + return nil, "", err } listOptionIndexer, err := NewListOptionIndexer(ctx, s, opts) if err != nil { - return nil, err + return nil, "", err } - return listOptionIndexer, nil + return listOptionIndexer, dbPath, nil +} + +func cleanTempFiles(basePath string) { + os.Remove(basePath) + os.Remove(basePath + "-shm") + os.Remove(basePath + "-wal") } func TestNewListOptionIndexer(t *testing.T) { @@ -920,7 +927,8 @@ func TestNewListOptionIndexerEasy(t *testing.T) { Fields: fields, IsNamespaced: true, } - loi, err := makeListOptionIndexer(ctx, opts) + loi, dbPath, err := makeListOptionIndexer(ctx, opts) + defer cleanTempFiles(dbPath) assert.NoError(t, err) for _, item := range itemList.Items { @@ -941,6 +949,216 @@ func TestNewListOptionIndexerEasy(t *testing.T) { } } +func TestUserDefinedExtractFunction(t *testing.T) { + makeObj := func(name string, barSeparatedHosts string) map[string]any { + h1 := map[string]any{ + "metadata": map[string]any{ + "name": name, + }, + "spec": map[string]any{ + "rules": map[string]any{ + "host": barSeparatedHosts, + }, + }, + } + return h1 + } + ctx := context.Background() + + type testCase struct { + description string + listOptions sqltypes.ListOptions + partitions []partition.Partition + ns string + + items []*unstructured.Unstructured + + extraIndexedFields [][]string + expectedList *unstructured.UnstructuredList + expectedTotal int + expectedContToken string + expectedErr error + } + + obj01 := makeObj("obj01", "dogs|horses|humans") + obj02 := makeObj("obj02", "dogs|cats|fish") + obj03 := makeObj("obj03", "camels|clowns|zebras") + obj04 := makeObj("obj04", "aardvarks|harps|zyphyrs") + allObjects := []map[string]any{obj01, obj02, obj03, obj04} + makeList := func(t *testing.T, objs ...map[string]any) *unstructured.UnstructuredList { + t.Helper() + + if len(objs) == 0 { + return &unstructured.UnstructuredList{Object: map[string]any{"items": []any{}}, Items: []unstructured.Unstructured{}} + } + + var items []any + for _, obj := range objs { + items = append(items, obj) + } + + list := &unstructured.Unstructured{ + Object: map[string]any{ + "items": items, + }, + } + + itemList, err := list.ToList() + require.NoError(t, err) + + return itemList + } + itemList := makeList(t, allObjects...) + + var tests []testCase + tests = append(tests, testCase{ + description: "find dogs in the first substring", + listOptions: sqltypes.ListOptions{Filters: []sqltypes.OrFilter{ + { + []sqltypes.Filter{ + { + Field: []string{"spec", "rules", "0", "host"}, + Matches: []string{"dogs"}, + Op: sqltypes.Eq, + }, + }, + }, + }, + }, + partitions: []partition.Partition{{All: true}}, + ns: "", + expectedList: makeList(t, obj01, obj02), + expectedTotal: 2, + expectedContToken: "", + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "extractBarredValue on item 0 should work", + listOptions: sqltypes.ListOptions{ + SortList: sqltypes.SortList{ + SortDirectives: []sqltypes.Sort{ + { + Fields: []string{"spec", "rules", "0", "host"}, + Order: sqltypes.ASC, + }, + }, + }, + }, + partitions: []partition.Partition{{All: true}}, + ns: "", + expectedList: makeList(t, obj04, obj03, obj01, obj02), + expectedTotal: len(allObjects), + expectedContToken: "", + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "extractBarredValue on item 1 should work", + listOptions: sqltypes.ListOptions{ + SortList: sqltypes.SortList{ + SortDirectives: []sqltypes.Sort{ + { + Fields: []string{"spec", "rules", "1", "host"}, + Order: sqltypes.ASC, + }, + }, + }, + }, + partitions: []partition.Partition{{All: true}}, + ns: "", + expectedList: makeList(t, obj02, obj03, obj04, obj01), + expectedTotal: len(allObjects), + expectedContToken: "", + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "extractBarredValue on item 2 should work", + listOptions: sqltypes.ListOptions{ + SortList: sqltypes.SortList{ + SortDirectives: []sqltypes.Sort{ + { + Fields: []string{"spec", "rules", "2", "host"}, + Order: sqltypes.ASC, + }, + }, + }, + }, + partitions: []partition.Partition{{All: true}}, + ns: "", + expectedList: makeList(t, obj02, obj01, obj03, obj04), + expectedTotal: len(allObjects), + expectedContToken: "", + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "extractBarredValue on item 3 should fall back to default sorting", + listOptions: sqltypes.ListOptions{ + SortList: sqltypes.SortList{ + SortDirectives: []sqltypes.Sort{ + { + Fields: []string{"spec", "rules", "3", "host"}, + Order: sqltypes.ASC, + }, + }, + }, + }, + partitions: []partition.Partition{{All: true}}, + ns: "", + expectedList: makeList(t, allObjects...), + expectedTotal: len(allObjects), + expectedContToken: "", + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "extractBarredValue on item -2 should result in a compile error", + listOptions: sqltypes.ListOptions{ + SortList: sqltypes.SortList{ + SortDirectives: []sqltypes.Sort{ + { + Fields: []string{"spec", "rules", "-2", "host"}, + Order: sqltypes.ASC, + }, + }, + }, + }, + partitions: []partition.Partition{{All: true}}, + ns: "", + expectedErr: errors.New("column is invalid [spec.rules.-2.host]: supplied column is invalid"), + }) + t.Parallel() + + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { + fields := [][]string{ + {"spec", "rules", "host"}, + } + fields = append(fields, test.extraIndexedFields...) + + opts := ListOptionIndexerOptions{ + Fields: fields, + IsNamespaced: true, + } + loi, dbPath, err := makeListOptionIndexer(ctx, opts) + defer cleanTempFiles(dbPath) + assert.NoError(t, err) + + for _, item := range itemList.Items { + err = loi.Add(&item) + assert.NoError(t, err) + } + + list, total, contToken, err := loi.ListByOptions(ctx, &test.listOptions, test.partitions, test.ns) + if test.expectedErr != nil { + assert.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, test.expectedList, list) + assert.Equal(t, test.expectedTotal, total) + assert.Equal(t, test.expectedContToken, contToken) + }) + } +} + func TestConstructQuery(t *testing.T) { type testCase struct { description string @@ -1365,6 +1583,87 @@ func TestConstructQuery(t *testing.T) { expectedStmtArgs: []any{"numericThing", float64(35)}, expectedErr: nil, }) + tests = append(tests, testCase{ + description: "TestConstructQuery: uses the extractBarredValue custom function for penultimate indexer", + listOptions: sqltypes.ListOptions{Filters: []sqltypes.OrFilter{ + { + []sqltypes.Filter{ + { + Field: []string{"spec", "containers", "3", "image"}, + Matches: []string{"nginx-happy"}, + Op: sqltypes.Eq, + }, + }, + }, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (extractBarredValue(f."spec.containers.image", "3") = ?) AND + (FALSE) + ORDER BY f."metadata.name" ASC `, + expectedStmtArgs: []any{"nginx-happy"}, + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "TestConstructQuery: uses the extractBarredValue custom function for penultimate indexer when sorting", + listOptions: sqltypes.ListOptions{ + SortList: sqltypes.SortList{ + SortDirectives: []sqltypes.Sort{ + { + Fields: []string{"spec", "containers", "16", "image"}, + Order: sqltypes.ASC, + }, + }, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (FALSE) + ORDER BY extractBarredValue(f."spec.containers.image", "16") ASC`, + expectedStmtArgs: []any{}, + expectedErr: nil, + }) + tests = append(tests, testCase{ + description: "TestConstructQuery: uses the extractBarredValue custom function for penultimate indexer when both filtering and sorting", + listOptions: sqltypes.ListOptions{ + Filters: []sqltypes.OrFilter{ + { + []sqltypes.Filter{ + { + Field: []string{"spec", "containers", "3", "image"}, + Matches: []string{"nginx-happy"}, + Op: sqltypes.Eq, + }, + }, + }, + }, + SortList: sqltypes.SortList{ + SortDirectives: []sqltypes.Sort{ + { + Fields: []string{"spec", "containers", "16", "image"}, + Order: sqltypes.ASC, + }, + }, + }, + }, + partitions: []partition.Partition{}, + ns: "", + expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "something" o + JOIN "something_fields" f ON o.key = f.key + WHERE + (extractBarredValue(f."spec.containers.image", "3") = ?) AND + (FALSE) + ORDER BY extractBarredValue(f."spec.containers.image", "16") ASC`, + expectedStmtArgs: []any{"nginx-happy"}, + expectedErr: nil, + }) tests = append(tests, testCase{ description: "multiple filters with a positive label test and a negative non-label test still outer-join", listOptions: sqltypes.ListOptions{Filters: []sqltypes.OrFilter{ @@ -1565,7 +1864,7 @@ func TestConstructQuery(t *testing.T) { } lii := &ListOptionIndexer{ Indexer: i, - indexedFields: []string{"metadata.queryField1", "status.queryField2"}, + indexedFields: []string{"metadata.queryField1", "status.queryField2", "spec.containers.image"}, } queryInfo, err := lii.constructQuery(&test.listOptions, test.partitions, test.ns, "something") if test.expectedErr != nil { @@ -1862,7 +2161,8 @@ func TestWatchMany(t *testing.T) { }, IsNamespaced: true, } - loi, err := makeListOptionIndexer(ctx, opts) + loi, dbPath, err := makeListOptionIndexer(ctx, opts) + defer cleanTempFiles(dbPath) assert.NoError(t, err) startWatcher := func(ctx context.Context) (chan watch.Event, chan error) { @@ -2118,7 +2418,8 @@ func TestWatchFilter(t *testing.T) { Fields: [][]string{{"metadata", "somefield"}}, IsNamespaced: true, } - loi, err := makeListOptionIndexer(ctx, opts) + loi, dbPath, err := makeListOptionIndexer(ctx, opts) + defer cleanTempFiles(dbPath) assert.NoError(t, err) wCh, errCh := startWatcher(ctx, loi, WatchFilter{ @@ -2209,7 +2510,8 @@ func TestWatchResourceVersion(t *testing.T) { opts := ListOptionIndexerOptions{ IsNamespaced: true, } - loi, err := makeListOptionIndexer(parentCtx, opts) + loi, dbPath, err := makeListOptionIndexer(parentCtx, opts) + defer cleanTempFiles(dbPath) assert.NoError(t, err) getRV := func(t *testing.T) string { @@ -2361,7 +2663,8 @@ func TestWatchGarbageCollection(t *testing.T) { opts := ListOptionIndexerOptions{ MaximumEventsCount: 2, } - loi, err := makeListOptionIndexer(parentCtx, opts) + loi, dbPath, err := makeListOptionIndexer(parentCtx, opts) + defer cleanTempFiles(dbPath) assert.NoError(t, err) getRV := func(t *testing.T) string { @@ -2465,7 +2768,8 @@ func TestNonNumberResourceVersion(t *testing.T) { Fields: [][]string{{"metadata", "somefield"}}, IsNamespaced: true, } - loi, err := makeListOptionIndexer(ctx, opts) + loi, dbPath, err := makeListOptionIndexer(ctx, opts) + defer cleanTempFiles(dbPath) assert.NoError(t, err) foo := &unstructured.Unstructured{ diff --git a/pkg/sqlcache/informer/sql_mocks_test.go b/pkg/sqlcache/informer/sql_mocks_test.go index 85336a12..450281e9 100644 --- a/pkg/sqlcache/informer/sql_mocks_test.go +++ b/pkg/sqlcache/informer/sql_mocks_test.go @@ -187,17 +187,18 @@ func (mr *MockStoreMockRecorder) ListKeys() *gomock.Call { } // NewConnection mocks base method. -func (m *MockStore) NewConnection() error { +func (m *MockStore) NewConnection(arg0 bool) (string, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "NewConnection") - ret0, _ := ret[0].(error) - return ret0 + ret := m.ctrl.Call(m, "NewConnection", arg0) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 } // NewConnection indicates an expected call of NewConnection. -func (mr *MockStoreMockRecorder) NewConnection() *gomock.Call { +func (mr *MockStoreMockRecorder) NewConnection(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewConnection", reflect.TypeOf((*MockStore)(nil).NewConnection)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewConnection", reflect.TypeOf((*MockStore)(nil).NewConnection), arg0) } // Prepare mocks base method. diff --git a/pkg/sqlcache/store/db_mocks_test.go b/pkg/sqlcache/store/db_mocks_test.go index fb938c0d..326100bb 100644 --- a/pkg/sqlcache/store/db_mocks_test.go +++ b/pkg/sqlcache/store/db_mocks_test.go @@ -140,17 +140,18 @@ func (mr *MockClientMockRecorder) CloseStmt(arg0 any) *gomock.Call { } // NewConnection mocks base method. -func (m *MockClient) NewConnection() error { +func (m *MockClient) NewConnection(arg0 bool) (string, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "NewConnection") - ret0, _ := ret[0].(error) - return ret0 + ret := m.ctrl.Call(m, "NewConnection", arg0) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 } // NewConnection indicates an expected call of NewConnection. -func (mr *MockClientMockRecorder) NewConnection() *gomock.Call { +func (mr *MockClientMockRecorder) NewConnection(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewConnection", reflect.TypeOf((*MockClient)(nil).NewConnection)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewConnection", reflect.TypeOf((*MockClient)(nil).NewConnection), arg0) } // Prepare mocks base method.