1
0
mirror of https://github.com/rancher/steve.git synced 2025-09-04 08:55:55 +00:00

Add a special case handler to sort raw namespaces.

For any namespaces in a project, we want to group them by project,
sorted by the project's human name.  Then display the others.
Sort the names ascending otherwise.

This is a bit hacky -- if the client adds any filter or sort field,
we no longer do the bespoke joining.

It took all day to come up with this SQL to get all the namespaces.
LMK if there's a better way, and especially if you've got a way
to integrate this technique with user-specified filters and sort params.
This commit is contained in:
Eric Promislow
2025-03-21 13:12:09 -07:00
parent e1061a86cd
commit ca28cd31be
3 changed files with 245 additions and 29 deletions

View File

@@ -248,11 +248,18 @@ func (l *ListOptionIndexer) deleteLabels(key string, tx transaction.Client) erro
// - a continue token, if there are more pages after the returned one // - a continue token, if there are more pages after the returned one
// - an error instead of all of the above if anything went wrong // - an error instead of all of the above if anything went wrong
func (l *ListOptionIndexer) ListByOptions(ctx context.Context, lo ListOptions, partitions []partition.Partition, namespace string) (*unstructured.UnstructuredList, int, string, error) { func (l *ListOptionIndexer) ListByOptions(ctx context.Context, lo ListOptions, partitions []partition.Partition, namespace string) (*unstructured.UnstructuredList, int, string, error) {
queryInfo, err := l.constructQuery(lo, partitions, namespace, db.Sanitize(l.GetName())) dbName := db.Sanitize(l.GetName())
queryInfo, err := l.constructQuery(lo, partitions, namespace, dbName)
if err != nil { if err != nil {
return nil, 0, "", err return nil, 0, "", err
} }
return l.executeQuery(ctx, queryInfo) list, total, token, err := l.executeQuery(ctx, queryInfo)
if err != nil {
if isSpecialCaseQuery(lo, namespace, dbName) && removeSpecialCaseSituation(&lo, namespace, dbName) {
return l.ListByOptions(ctx, lo, partitions, namespace)
}
}
return list, total, token, err
} }
// QueryInfo is a helper-struct that is used to represent the core query and parameters when converting // QueryInfo is a helper-struct that is used to represent the core query and parameters when converting
@@ -266,11 +273,9 @@ type QueryInfo struct {
offset int offset int
} }
func (l *ListOptionIndexer) constructQuery(lo ListOptions, partitions []partition.Partition, namespace string, dbName string) (*QueryInfo, error) { func (l *ListOptionIndexer) constructRegularQuery(lo ListOptions, namespace string, dbName string, joinTableIndexByLabelName map[string]int) (string, []string, []any, error) {
ensureSortLabelsAreSelected(&lo) ensureSortLabelsAreSelected(&lo)
queryInfo := &QueryInfo{}
queryUsesLabels := hasLabelFilter(lo.Filters) queryUsesLabels := hasLabelFilter(lo.Filters)
joinTableIndexByLabelName := make(map[string]int)
// First, what kind of filtering will we be doing? // First, what kind of filtering will we be doing?
// 1- Intro: SELECT and JOIN clauses // 1- Intro: SELECT and JOIN clauses
@@ -308,7 +313,7 @@ func (l *ListOptionIndexer) constructQuery(lo ListOptions, partitions []partitio
for _, orFilters := range lo.Filters { for _, orFilters := range lo.Filters {
orClause, orParams, err := l.buildORClauseFromFilters(orFilters, dbName, joinTableIndexByLabelName) orClause, orParams, err := l.buildORClauseFromFilters(orFilters, dbName, joinTableIndexByLabelName)
if err != nil { if err != nil {
return queryInfo, err return query, whereClauses, params, err
} }
if orClause == "" { if orClause == "" {
continue continue
@@ -322,6 +327,60 @@ func (l *ListOptionIndexer) constructQuery(lo ListOptions, partitions []partitio
whereClauses = append(whereClauses, fmt.Sprintf(`f."metadata.namespace" = ?`)) whereClauses = append(whereClauses, fmt.Sprintf(`f."metadata.namespace" = ?`))
params = append(params, namespace) params = append(params, namespace)
} }
return query, whereClauses, params, nil
}
func constructSpecialQuery(lo ListOptions, namespace string, dbName string) (string, []string, []string, error) {
if dbName != "_v1_Namespace" {
return "", nil, nil, fmt.Errorf("internal error: dbName must be %s, got %s", "_v1_Namespace", dbName)
}
// If we're grabbing all the namespaces with no filter or sort parameters,
// we want them sorted first according to the human name of the project they're associated with,
// if there is one, and then sort by namespace name.
// Namespaces that aren't in a project show up after the others.
// Note that this query will fail if the 'management.cattle.io_v3_Project*' tables haven't
// been loaded yet. In which case we redo the query with a sort on `metadata.name`
// to prevent this query from being created.
query := `SELECT object, objectnonce, dekid FROM
(
SELECT o.object as object, o.objectnonce as objectnonce, o.dekid as dekid, o.key as key, proj."spec.displayName" as humanName FROM "_v1_Namespace" o
JOIN "_v1_Namespace_fields" f ON o.key = f.key
LEFT OUTER JOIN "_v1_Namespace_labels" nslb ON o.key = nslb.key
JOIN "management.cattle.io_v3_Project_fields" proj ON nslb.value = proj."metadata.name"
WHERE nslb.label = "field.cattle.io/projectId"
UNION ALL
SELECT o.object as object, o.objectnonce as objectnonce, o.dekid as dekid, o.key as key, NULL as humanName FROM "_v1_Namespace" o
JOIN "_v1_Namespace_fields" f ON o.key = f.key
LEFT OUTER JOIN "_v1_Namespace_labels" nslb ON o.key = nslb.key
WHERE (o.key NOT IN (SELECT o1.key FROM "_v1_Namespace" o1
JOIN "_v1_Namespace_fields" f1 ON o1.key = f1.key
LEFT OUTER JOIN "_v1_Namespace_labels" lt1i1 ON o1.key = lt1i1.key
WHERE lt1i1.label = "field.cattle.io/projectId"))
)`
whereClauses := []string{}
sortClauses := []string{"humanName ASC NULLS LAST", "key ASC"}
return query, whereClauses, sortClauses, nil
}
func (l *ListOptionIndexer) constructQuery(lo ListOptions, partitions []partition.Partition, namespace string, dbName string) (*QueryInfo, error) {
joinTableIndexByLabelName := make(map[string]int)
var err error
var query string
var whereClauses []string
var orderByClauses []string
var params []any
if isSpecialCaseQuery(lo, namespace, dbName) {
query, whereClauses, orderByClauses, err = constructSpecialQuery(lo, namespace, dbName)
params = make([]any, 0)
} else {
query, whereClauses, params, err = l.constructRegularQuery(lo, namespace, dbName, joinTableIndexByLabelName)
orderByClauses = make([]string, 0)
}
if err != nil {
return nil, err
}
// WHERE clauses (from partitions and their corresponding parameters) // WHERE clauses (from partitions and their corresponding parameters)
partitionClauses := []string{} partitionClauses := []string{}
@@ -341,7 +400,7 @@ func (l *ListOptionIndexer) constructQuery(lo ListOptions, partitions []partitio
if !thisPartition.All { if !thisPartition.All {
names := thisPartition.Names names := thisPartition.Names
if len(names) == 0 { if names.Len() == 0 {
// degenerate case, there will be no results // degenerate case, there will be no results
singlePartitionClauses = append(singlePartitionClauses, "FALSE") singlePartitionClauses = append(singlePartitionClauses, "FALSE")
} else { } else {
@@ -388,11 +447,14 @@ func (l *ListOptionIndexer) constructQuery(lo ListOptions, partitions []partitio
countParams := params[:] countParams := params[:]
// 3- Sorting: ORDER BY clauses (from lo.Sort) // 3- Sorting: ORDER BY clauses (from lo.Sort)
if len(lo.Sort.Fields) != len(lo.Sort.Orders) { if len(orderByClauses) > 0 {
// From the special-case query
query += "\n ORDER BY "
query += strings.Join(orderByClauses, ", ")
} else if len(lo.Sort.Fields) != len(lo.Sort.Orders) {
return nil, fmt.Errorf("sort fields length %d != sort orders length %d", len(lo.Sort.Fields), len(lo.Sort.Orders)) return nil, fmt.Errorf("sort fields length %d != sort orders length %d", len(lo.Sort.Fields), len(lo.Sort.Orders))
} } else if len(lo.Sort.Fields) > 0 {
if len(lo.Sort.Fields) > 0 { orderByClauses = []string{}
orderByClauses := []string{}
for i, field := range lo.Sort.Fields { for i, field := range lo.Sort.Fields {
if isLabelsFieldList(field) { if isLabelsFieldList(field) {
clause, sortParam, err := buildSortLabelsClause(field[2], joinTableIndexByLabelName, lo.Sort.Orders[i] == ASC) clause, sortParam, err := buildSortLabelsClause(field[2], joinTableIndexByLabelName, lo.Sort.Orders[i] == ASC)
@@ -404,7 +466,7 @@ func (l *ListOptionIndexer) constructQuery(lo ListOptions, partitions []partitio
} else { } else {
columnName := toColumnName(field) columnName := toColumnName(field)
if err := l.validateColumn(columnName); err != nil { if err := l.validateColumn(columnName); err != nil {
return queryInfo, err return nil, err
} }
direction := "ASC" direction := "ASC"
if lo.Sort.Orders[i] == DESC { if lo.Sort.Orders[i] == DESC {
@@ -443,7 +505,7 @@ func (l *ListOptionIndexer) constructQuery(lo ListOptions, partitions []partitio
if lo.Resume != "" { if lo.Resume != "" {
offsetInt, err := strconv.Atoi(lo.Resume) offsetInt, err := strconv.Atoi(lo.Resume)
if err != nil { if err != nil {
return queryInfo, err return nil, err
} }
offset = offsetInt offset = offsetInt
} }
@@ -454,6 +516,8 @@ func (l *ListOptionIndexer) constructQuery(lo ListOptions, partitions []partitio
offsetClause = "\n OFFSET ?" offsetClause = "\n OFFSET ?"
params = append(params, offset) params = append(params, offset)
} }
queryInfo := &QueryInfo{}
if limit > 0 || offset > 0 { if limit > 0 || offset > 0 {
query += limitClause query += limitClause
query += offsetClause query += offsetClause
@@ -807,20 +871,6 @@ func (l *ListOptionIndexer) getLabelFilter(index int, filter Filter, dbName stri
return "", nil, fmt.Errorf("unrecognized operator: %s", opString) return "", nil, fmt.Errorf("unrecognized operator: %s", opString)
} }
func prepareComparisonParameters(op Op, target string) (string, float64, error) {
num, err := strconv.ParseFloat(target, 32)
if err != nil {
return "", 0, err
}
switch op {
case Lt:
return "<", num, nil
case Gt:
return ">", num, nil
}
return "", 0, fmt.Errorf("unrecognized operator when expecting '<' or '>': '%s'", op)
}
func formatMatchTarget(filter Filter) string { func formatMatchTarget(filter Filter) string {
format := strictMatchFmt format := strictMatchFmt
if filter.Partial { if filter.Partial {
@@ -838,6 +888,37 @@ func formatMatchTargetWithFormatter(match string, format string) string {
return fmt.Sprintf(format, match) return fmt.Sprintf(format, match)
} }
func isSpecialCaseQuery(lo ListOptions, namespace string, dbName string) bool {
if dbName == "_v1_Namespace" {
return (namespace == "" || namespace == "*") && len(lo.Filters) == 0 && len(lo.Sort.Fields) == 0
}
return false
}
func prepareComparisonParameters(op Op, target string) (string, float64, error) {
num, err := strconv.ParseFloat(target, 32)
if err != nil {
return "", 0, err
}
switch op {
case Lt:
return "<", num, nil
case Gt:
return ">", num, nil
}
return "", 0, fmt.Errorf("unrecognized operator when expecting '<' or '>': '%s'", op)
}
func removeSpecialCaseSituation(lo *ListOptions, namespace string, dbName string) bool {
if dbName == "_v1_Namespace" {
// Muddy it and retry
lo.Sort.Fields = [][]string{{"metadata", "name"}}
lo.Sort.Orders = []SortOrder{ASC}
return true
}
return false
}
// There are two kinds of string arrays to turn into a string, based on the last value in the array // There are two kinds of string arrays to turn into a string, based on the last value in the array
// simple: ["a", "b", "conformsToIdentifier"] => "a.b.conformsToIdentifier" // simple: ["a", "b", "conformsToIdentifier"] => "a.b.conformsToIdentifier"
// complex: ["a", "b", "foo.io/stuff"] => "a.b[foo.io/stuff]" // complex: ["a", "b", "foo.io/stuff"] => "a.b[foo.io/stuff]"

View File

@@ -1025,6 +1025,7 @@ func TestListByOptions(t *testing.T) {
} }
} }
// `'
func TestConstructQuery(t *testing.T) { func TestConstructQuery(t *testing.T) {
type testCase struct { type testCase struct {
description string description string
@@ -1679,6 +1680,138 @@ func TestConstructQuery(t *testing.T) {
} }
} }
func TestConstructSpecialQuery(t *testing.T) {
type testCase struct {
dbname string
description string
listOptions ListOptions
partitions []partition.Partition
ns string
expectedCountStmt string
expectedCountStmtArgs []any
expectedStmt string
expectedStmtArgs []any
expectedErr error
}
var tests []testCase
tests = append(tests, testCase{
dbname: "_v1_Namespace",
description: "ConstructQuery: unsorted namespaces should be sorted by project name",
listOptions: ListOptions{},
partitions: []partition.Partition{},
ns: "",
expectedStmt: `SELECT object, objectnonce, dekid FROM
(
SELECT o.object as object, o.objectnonce as objectnonce, o.dekid as dekid, o.key as key, proj."spec.displayName" as humanName FROM "_v1_Namespace" o
JOIN "_v1_Namespace_fields" f ON o.key = f.key
LEFT OUTER JOIN "_v1_Namespace_labels" nslb ON o.key = nslb.key
JOIN "management.cattle.io_v3_Project_fields" proj ON nslb.value = proj."metadata.name"
WHERE nslb.label = "field.cattle.io/projectId"
UNION ALL
SELECT o.object as object, o.objectnonce as objectnonce, o.dekid as dekid, o.key as key, NULL as humanName FROM "_v1_Namespace" o
JOIN "_v1_Namespace_fields" f ON o.key = f.key
LEFT OUTER JOIN "_v1_Namespace_labels" nslb ON o.key = nslb.key
WHERE (o.key NOT IN (SELECT o1.key FROM "_v1_Namespace" o1
JOIN "_v1_Namespace_fields" f1 ON o1.key = f1.key
LEFT OUTER JOIN "_v1_Namespace_labels" lt1i1 ON o1.key = lt1i1.key
WHERE lt1i1.label = "field.cattle.io/projectId"))
)
WHERE
(FALSE)
ORDER BY humanName ASC NULLS LAST, key ASC`,
expectedStmtArgs: []any{},
expectedErr: nil,
})
t.Parallel()
for _, test := range tests {
t.Run(test.description, func(t *testing.T) {
store := NewMockStore(gomock.NewController(t))
i := &Indexer{
Store: store,
}
lii := &ListOptionIndexer{
Indexer: i,
indexedFields: []string{"metadata.queryField1", "status.queryField2"},
}
if test.listOptions.Filters == nil {
test.listOptions.Filters = []OrFilter{}
}
if test.listOptions.Sort.Fields == nil {
test.listOptions.Sort.Fields = [][]string{}
test.listOptions.Sort.Orders = []SortOrder{}
}
assert.True(t, isSpecialCaseQuery(test.listOptions, test.ns, test.dbname))
queryInfo, err := lii.constructQuery(test.listOptions, test.partitions, test.ns, test.dbname)
if test.expectedErr != nil {
assert.Equal(t, test.expectedErr, err)
return
}
assert.Nil(t, err)
assert.Equal(t, test.expectedStmt, queryInfo.query)
assert.Equal(t, test.expectedStmtArgs, queryInfo.params)
assert.Equal(t, test.expectedCountStmt, queryInfo.countQuery)
assert.Equal(t, test.expectedCountStmtArgs, queryInfo.countParams)
})
}
}
func TestRemoveSpecialCaseOnNamespace(t *testing.T) {
type testCase struct {
dbname string
description string
listOptions ListOptions
partitions []partition.Partition
ns string
expectedCountStmt string
expectedCountStmtArgs []any
expectedStmt string
expectedStmtArgs []any
expectedErr error
}
test := testCase{
dbname: "_v1_Namespace",
description: "RemoveSpecialCaseOnNamespace: unsorted namespaces should be sorted by namespace name",
listOptions: ListOptions{
Filters: []OrFilter{},
Sort: Sort{Fields: [][]string{}, Orders: []SortOrder{}},
},
partitions: []partition.Partition{},
ns: "",
expectedStmt: `SELECT o.object, o.objectnonce, o.dekid FROM "_v1_Namespace" o
JOIN "_v1_Namespace_fields" f ON o.key = f.key
WHERE
(FALSE)
ORDER BY f."metadata.name" ASC`,
expectedStmtArgs: []any{},
expectedErr: nil,
}
assert.True(t, isSpecialCaseQuery(test.listOptions, test.ns, test.dbname))
assert.True(t, removeSpecialCaseSituation(&test.listOptions, test.ns, test.dbname))
assert.False(t, isSpecialCaseQuery(test.listOptions, test.ns, test.dbname))
store := NewMockStore(gomock.NewController(t))
i := &Indexer{
Store: store,
}
lii := &ListOptionIndexer{
Indexer: i,
indexedFields: []string{"metadata.name"},
}
queryInfo, err := lii.constructQuery(test.listOptions, test.partitions, test.ns, test.dbname)
if test.expectedErr != nil {
assert.Equal(t, test.expectedErr, err)
return
}
assert.Nil(t, err)
assert.Equal(t, test.expectedStmt, queryInfo.query)
assert.Equal(t, test.expectedStmtArgs, queryInfo.params)
assert.Equal(t, test.expectedCountStmt, queryInfo.countQuery)
assert.Equal(t, test.expectedCountStmtArgs, queryInfo.countParams)
}
func TestSmartJoin(t *testing.T) { func TestSmartJoin(t *testing.T) {
type testCase struct { type testCase struct {
description string description string

View File

@@ -151,7 +151,9 @@ var (
gvkKey("management.cattle.io", "v3", "NodeTemplate"): { gvkKey("management.cattle.io", "v3", "NodeTemplate"): {
{"spec", "clusterName"}}, {"spec", "clusterName"}},
gvkKey("management.cattle.io", "v3", "Project"): { gvkKey("management.cattle.io", "v3", "Project"): {
{"spec", "clusterName"}}, {"spec", "clusterName"},
{"spec", "displayName"},
},
gvkKey("networking.k8s.io", "v1", "Ingress"): { gvkKey("networking.k8s.io", "v1", "Ingress"): {
{"spec", "rules", "host"}, {"spec", "rules", "host"},
{"spec", "ingressClassName"}, {"spec", "ingressClassName"},