1
0
mirror of https://github.com/rancher/steve.git synced 2025-09-16 23:39:16 +00:00
Files
steve/pkg/sqlcache/informer/query_generator.go
2025-04-24 16:48:13 -07:00

1103 lines
41 KiB
Go

package informer
import (
"encoding/json"
"errors"
"fmt"
"regexp"
"sort"
"strconv"
"strings"
"github.com/rancher/steve/pkg/sqlcache/partition"
"github.com/rancher/steve/pkg/sqlcache/sqltypes"
)
var (
	// badTableNameChars matches each run of characters that may not appear in
	// a generated table or column name: anything other than ASCII letters,
	// digits, '-', '.' and '_'. It is used both to sanitize names (replacing
	// such runs) and to reject user-supplied column names outright.
	badTableNameChars = regexp.MustCompile(`[^-a-zA-Z0-9._]+`)

	// nonIdentifierChars matches each run of characters outside the plain
	// identifier alphabet: ASCII letters, digits and '_'.
	nonIdentifierChars = regexp.MustCompile(`[^a-zA-Z0-9_]+`)
)
// constructQuery translates lo into SQL against the table dbName, restricted
// to namespace and to what partitions allow.
//
// A request whose indirect sort directive sorts on a label is handed to
// constructIndirectSortQuery. Every other request first has its sorted-on
// labels bound into its filters (ensureSortLabelsAreSelected) and is then
// built by finishConstructQuery. Both paths share a fresh join-table index map.
func (l *ListOptionIndexer) constructQuery(lo *sqltypes.ListOptions, partitions []partition.Partition, namespace string, dbName string) (*QueryInfo, error) {
	sortDirective, err := checkForIndirectSortDirective(lo)
	if err != nil {
		return nil, err
	}
	joinTableIndexByLabelName := map[string]int{}

	sortsIndirectlyOnLabel := sortDirective != nil && isLabelsFieldList(sortDirective.Fields)
	if !sortsIndirectlyOnLabel {
		ensureSortLabelsAreSelected(lo)
		return l.finishConstructQuery(lo, partitions, namespace, dbName, joinTableIndexByLabelName)
	}
	return l.constructIndirectSortQuery(lo, partitions, namespace, dbName, joinTableIndexByLabelName)
}
// constructIndirectSortQuery builds the query for a list request that has an
// indirect sort on a label. The sort value is read from a column of another
// resource's "_fields" table, looked up through the label's value.
//
// Two sub-queries are generated and combined with UNION ALL:
//   - one built from lo itself, whose sort list is reduced to just the
//     indirect directive;
//   - one built from a copy of lo (made by a JSON round trip) that carries
//     every other sort directive.
//
// applyIndirectLabelTests receives both option sets. Per the comments below it
// is meant to add an existence test for the label to the first and a
// non-existence test to the second, so each row lands in exactly one
// sub-query (NOTE(review): defined elsewhere — confirm).
//
// Limitations:
//  1. Only one indirect sort directive is honored. If several are present, the
//     last one wins and the others are dropped from both sub-queries.
//  2. No LIMIT/OFFSET is added here, and QueryInfo.limit/offset stay zero.
//
// The position of the indirect sort term in the final ORDER BY matches its
// position in the original sort list (see putIndirectSortInPosition).
//
// lo's sort list is replaced in place. If either sub-query's last WHERE clause
// is a bare "FALSE", it is removed from that sub-query and a single
// "WHERE FALSE" is applied to the combined result instead.
func (l *ListOptionIndexer) constructIndirectSortQuery(lo *sqltypes.ListOptions, partitions []partition.Partition, namespace string, dbName string, joinTableIndexByLabelName map[string]int) (*QueryInfo, error) {
	// Split the sort directives: the indirect one goes to lo (newSortList1),
	// all the others go to the copy (newSortList2).
	var indirectSortDirective sqltypes.Sort
	newSortList1 := make([]sqltypes.Sort, 1)
	// The capacity is only a hint. Using the full length rather than length-1
	// avoids a negative capacity (which panics) on an empty sort list.
	newSortList2 := make([]sqltypes.Sort, 0, len(lo.SortList.SortDirectives))
	indirectSortPosition := -1
	for i, sd := range lo.SortList.SortDirectives {
		if sd.IsIndirect {
			indirectSortDirective = lo.SortList.SortDirectives[i]
			newSortList1[0] = indirectSortDirective
			indirectSortPosition = i
		} else {
			newSortList2 = append(newSortList2, sd)
		}
	}
	if indirectSortPosition == -1 {
		return nil, fmt.Errorf("expected an indirect sort directive, didn't find one")
	}
	if len(indirectSortDirective.IndirectFields) != 4 {
		return nil, fmt.Errorf("expected indirect sort directive to have 4 indirect fields, got %d", len(indirectSortDirective.IndirectFields))
	}

	// Copy lo through JSON so the two sub-queries get independent filters and
	// sort lists.
	var loNoLabel sqltypes.ListOptions
	encoded, err := json.Marshal(*lo)
	if err != nil {
		return nil, fmt.Errorf("can't json-encode list options: %w", err)
	}
	err = json.Unmarshal(encoded, &loNoLabel)
	if err != nil {
		return nil, fmt.Errorf("can't json-decode list options: %w", err)
	}
	applyIndirectLabelTests(lo, &loNoLabel, &indirectSortDirective)
	lo.SortList.SortDirectives = newSortList1
	loNoLabel.SortList.SortDirectives = newSortList2

	// Sub-query 1: only the indirect sort directive. Its own ORDER BY output is
	// not used; processOrderByFields builds that term below.
	joinParts1, whereClauses1, params1, needsDistinctModifier1, _, _, _, err := l.getQueryParts(lo, partitions, namespace, dbName, joinTableIndexByLabelName)
	if err != nil {
		return nil, err
	}
	// Sub-query 2: every other sort directive; its ORDER BY terms are kept.
	joinParts2, whereClauses2, params2, needsDistinctModifier2, orderByClauses2, orderByParams2, _, err := l.getQueryParts(&loNoLabel, partitions, namespace, dbName, joinTableIndexByLabelName)
	if err != nil {
		return nil, err
	}

	// Hoist a trailing bare "FALSE" (getQueryParts emits one when there are no
	// partitions) out of each sub-query. The length check keeps an empty clause
	// list from panicking.
	trimTrailingFalse := func(clauses []string) ([]string, bool) {
		if n := len(clauses); n > 0 && clauses[n-1] == "FALSE" {
			return clauses[:n-1], true
		}
		return clauses, false
	}
	whereClauses1, falseTest1 := trimTrailingFalse(whereClauses1)
	whereClauses2, falseTest2 := trimTrailingFalse(whereClauses2)
	addFalseTest := falseTest1 || falseTest2

	distinctModifier := ""
	if needsDistinctModifier1 || needsDistinctModifier2 {
		distinctModifier = " DISTINCT"
	}
	externalTableName := getExternalTableName(&indirectSortDirective)
	extIndex, ok := joinTableIndexByLabelName[externalTableName]
	if !ok {
		return nil, fmt.Errorf("internal error: unable to find an entry for external table %s", externalTableName)
	}
	sortParts, importWithParts, importAsNullParts := processOrderByFields(&indirectSortDirective, extIndex, orderByClauses2)
	selectLine := fmt.Sprintf("SELECT%s o.object AS __ix_object, o.objectnonce AS __ix_objectnonce, o.dekid AS __ix_dekid", distinctModifier)
	indent1 := " "
	indent2 := indent1 + indent1
	indent3 := indent2 + indent1
	where1 := joinWhereClauses(whereClauses1, indent2, indent3, "AND")
	where2 := joinWhereClauses(whereClauses2, indent2, indent3, "AND")
	parts := []string{
		"SELECT __ix_object, __ix_objectnonce, __ix_dekid FROM (",
		fmt.Sprintf(`%s%s, %s FROM %s`, indent1, selectLine, strings.Join(importWithParts, ", "), strings.Join(joinParts1, "\n"+indent2)),
		where1,
		"UNION ALL",
		fmt.Sprintf(`%s%s, %s FROM %s`, indent1, selectLine, strings.Join(importAsNullParts, ", "), strings.Join(joinParts2, "\n"+indent2)),
		where2,
		")",
	}
	if addFalseTest {
		parts = append(parts, "WHERE FALSE")
	}

	params := make([]any, 0, len(params1)+len(params2)+len(orderByParams2))
	params = append(params, params1...)
	params = append(params, params2...)
	fullQuery := strings.Join(parts, "\n")
	countQuery := fmt.Sprintf("SELECT COUNT(*) FROM (%s)", fullQuery)
	// Cap countParams at its current length so it never shares spare capacity
	// with params. Otherwise the ORDER BY parameters appended next (or a later
	// append to countParams) could overwrite the other slice's elements.
	countParams := params[:len(params):len(params)]
	params = append(params, orderByParams2...)
	fixedSortParts := putIndirectSortInPosition(sortParts, indirectSortPosition)
	fullQuery += fmt.Sprintf("\n%sORDER BY %s", indent1, strings.Join(fixedSortParts, ", "))
	queryInfo := &QueryInfo{
		query:       fullQuery,
		params:      params,
		countQuery:  countQuery,
		countParams: countParams,
	}
	return queryInfo, nil
}
// putIndirectSortInPosition returns a new ORDER BY term list in which the
// indirect sort term, expected at sortParts[0], is moved to
// indirectSortPosition. That is the index the indirect directive had in the
// caller's original sort list. The remaining terms keep their relative order.
//
// Out-of-range positions are clamped: negative values go to the front and
// values past the end go to the end. An empty sortParts yields an empty result
// instead of panicking. sortParts itself is never modified.
func putIndirectSortInPosition(sortParts []string, indirectSortPosition int) []string {
	fixedSortParts := make([]string, 0, len(sortParts))
	if len(sortParts) == 0 {
		return fixedSortParts
	}
	indirectSortPart, otherParts := sortParts[0], sortParts[1:]
	position := indirectSortPosition
	if position < 0 {
		position = 0
	} else if position > len(otherParts) {
		position = len(otherParts)
	}
	fixedSortParts = append(fixedSortParts, otherParts[:position]...)
	fixedSortParts = append(fixedSortParts, indirectSortPart)
	fixedSortParts = append(fixedSortParts, otherParts[position:]...)
	return fixedSortParts
}
// finishConstructQuery builds a single SELECT (no UNION) for a list request
// from the parts returned by getQueryParts. It appends ORDER BY and then
// applies LIMIT/OFFSET pagination.
//
// The limit is the smaller of lo.Pagination.PageSize and lo.ChunkSize,
// ignoring whichever of them is zero. The offset is lo.Resume parsed as an
// integer, plus PageSize*(Page-1) when Page >= 1.
//
// countQuery/countParams (the query without ORDER BY/LIMIT/OFFSET) and
// limit/offset are only filled in when a positive limit or offset applies.
// Otherwise they keep their zero values and the executor won't do pagination
// work. A non-integer lo.Resume returns the strconv error.
func (l *ListOptionIndexer) finishConstructQuery(lo *sqltypes.ListOptions, partitions []partition.Partition, namespace string, dbName string, joinTableIndexByLabelName map[string]int) (*QueryInfo, error) {
	joinParts, whereClauses, params, needsDistinctModifier, orderByClauses, orderByParams, sortSelectField, err := l.getQueryParts(lo, partitions, namespace, dbName, joinTableIndexByLabelName)
	if err != nil {
		return nil, err
	}
	distinctModifier := ""
	if needsDistinctModifier {
		distinctModifier = " DISTINCT"
	}
	queryInfo := &QueryInfo{}
	// sortSelectField is spliced directly after o.dekid, so make sure it starts
	// with a space. NOTE(review): no comma is inserted here, so a value without
	// its own leading comma would produce invalid SQL. Confirm what
	// getSortDirectives can return on this path.
	if len(sortSelectField) > 0 && sortSelectField[0] != ' ' {
		sortSelectField = " " + sortSelectField
	}
	query := fmt.Sprintf(`SELECT%s o.object, o.objectnonce, o.dekid%s FROM `, distinctModifier, sortSelectField)
	query += strings.Join(joinParts, "\n ")
	if len(whereClauses) > 0 {
		indent := " "
		separator := fmt.Sprintf(") AND\n%s(", indent)
		query += fmt.Sprintf("\n WHERE\n%s(%s)", indent, strings.Join(whereClauses, separator))
	}
	// Before proceeding, save a copy of the query and params without
	// LIMIT/OFFSET/ORDER info for COUNTing all results later. The full slice
	// expression caps countParams' capacity, so the appends to params below
	// cannot share (and later clobber) its backing storage.
	countQuery := fmt.Sprintf("SELECT COUNT(*) FROM (%s)", query)
	countParams := params[:len(params):len(params)]
	if len(orderByClauses) > 0 {
		query += "\n ORDER BY "
		query += strings.Join(orderByClauses, ", ")
		params = append(params, orderByParams...)
	}
	// 4- sqltypes.Pagination: LIMIT clause (from lo.Pagination and/or lo.ChunkSize/lo.Resume)
	limitClause := ""
	// take the smallest limit between lo.Pagination and lo.ChunkSize
	limit := lo.Pagination.PageSize
	if limit == 0 || (lo.ChunkSize > 0 && lo.ChunkSize < limit) {
		limit = lo.ChunkSize
	}
	if limit > 0 {
		limitClause = "\n LIMIT ?"
		params = append(params, limit)
	}
	// OFFSET clause (from lo.Pagination and/or lo.Resume)
	offsetClause := ""
	offset := 0
	if lo.Resume != "" {
		offsetInt, err := strconv.Atoi(lo.Resume)
		if err != nil {
			return queryInfo, err
		}
		offset = offsetInt
	}
	if lo.Pagination.Page >= 1 {
		offset += lo.Pagination.PageSize * (lo.Pagination.Page - 1)
	}
	if offset > 0 {
		offsetClause = "\n OFFSET ?"
		params = append(params, offset)
	}
	if limit > 0 || offset > 0 {
		query += limitClause
		query += offsetClause
		queryInfo.countQuery = countQuery
		queryInfo.countParams = countParams
		queryInfo.limit = limit
		queryInfo.offset = offset
	}
	// Otherwise leave these as default values and the executor won't do pagination work
	queryInfo.query = query
	queryInfo.params = params
	return queryInfo, nil
}
// Other ListOptionIndexer methods for generating SQL in alphabetical order:

// buildORClauseFromFilters creates an SQLite-compatible condition that ORs
// together the conditions built from each filter in orFilters.
//
// Label filters get a LEFT OUTER JOIN of "<dbName>_labels" as lt<N>. The join
// is added once per query (tracked in joinedTables), and N is allocated in
// joinTableIndexByLabelName under "<dbName>:<label>". The filter is then
// delegated to getLabelFilter, or getIndirectLabelFilter for indirect
// filters. Non-label indirect filters go to getIndirectNonLabelFilter; all
// other filters go to getFieldFilter.
//
// It returns:
//   - the combined fragment: "" for no conditions, the bare condition for one,
//     and "(a) OR (b) ..." otherwise;
//   - the JOIN clauses the filters need;
//   - the positional parameters;
//   - whether any label filter was seen, meaning the caller must use
//     SELECT DISTINCT (presumably because the label join can repeat rows);
//   - the first error encountered.
func (l *ListOptionIndexer) buildORClauseFromFilters(orFilters sqltypes.OrFilter, dbName string, joinTableIndexByLabelName map[string]int, joinedTables map[string]bool) (string, []string, []any, bool, error) {
	var (
		conditions = make([]string, 0, len(orFilters.Filters))
		joins      = make([]string, 0)
		args       = make([]any, 0)
		distinct   bool
	)
	for _, filter := range orFilters.Filters {
		var (
			condition  string
			extraJoins []string
			extraArgs  []any
			err        error
		)
		switch {
		case isLabelFilter(&filter):
			tableKey := fmt.Sprintf("%s:%s", dbName, filter.Field[2])
			idx, known := joinTableIndexByLabelName[tableKey]
			if !known {
				idx = len(joinTableIndexByLabelName) + 1
				joinTableIndexByLabelName[tableKey] = idx
			}
			if _, alreadyJoined := joinedTables[tableKey]; !alreadyJoined {
				joinedTables[tableKey] = true
				joins = append(joins, fmt.Sprintf(`LEFT OUTER JOIN "%s_labels" lt%d ON o.key = lt%d.key`, dbName, idx, idx))
			}
			distinct = true
			build := l.getLabelFilter
			if isIndirectFilter(&filter) {
				build = l.getIndirectLabelFilter
			}
			condition, extraJoins, extraArgs, err = build(filter, dbName, joinTableIndexByLabelName, joinedTables)
		case isIndirectFilter(&filter):
			condition, extraJoins, extraArgs, err = l.getIndirectNonLabelFilter(filter, dbName, joinTableIndexByLabelName, joinedTables)
		default:
			condition, extraArgs, err = l.getFieldFilter(filter)
		}
		if err != nil {
			return "", nil, nil, distinct, err
		}
		joins = append(joins, extraJoins...)
		if condition != "" {
			conditions = append(conditions, condition)
		}
		args = append(args, extraArgs...)
	}
	if len(conditions) <= 1 {
		// Zero conditions join to "", and a single condition needs no wrapping.
		return strings.Join(conditions, ""), joins, args, distinct, nil
	}
	return "(" + strings.Join(conditions, ") OR (") + ")", joins, args, distinct, nil
}
// ensureSortLabelsAreSelected makes sure every label named in a sort directive
// is bound by a label filter. Label filters are what cause
// buildORClauseFromFilters to join the label table and register its index,
// which the ORDER BY then refers to. It is used for requests without an
// indirect label sort.
//
// If lo has no filters, a single OR-group is created containing one Exists
// filter per sorted-on label. Otherwise, each OR-group that does not already
// filter on a sorted-on label gets an Exists filter for it appended. Each
// group is handled independently.
//
// Labels are processed in the order they first appear in the sort list, not
// in map iteration order (which Go randomizes). This keeps the generated
// filters, the SQL text and the join-table numbering deterministic.
//
// NOTE(review): appending to an existing OR-group widens it to
// "(existing conditions) OR (label exists)". Confirm this is intended.
func ensureSortLabelsAreSelected(lo *sqltypes.ListOptions) {
	if len(lo.SortList.SortDirectives) == 0 {
		return
	}
	sortLabels := make([]string, 0, len(lo.SortList.SortDirectives))
	seen := make(map[string]bool, len(lo.SortList.SortDirectives))
	for _, sortDirective := range lo.SortList.SortDirectives {
		fields := sortDirective.Fields
		if !isLabelsFieldList(fields) || seen[fields[2]] {
			continue
		}
		seen[fields[2]] = true
		sortLabels = append(sortLabels, fields[2])
	}
	if len(sortLabels) == 0 {
		return
	}
	newExistsFilter := func(labelName string) sqltypes.Filter {
		return sqltypes.Filter{
			Field: []string{"metadata", "labels", labelName},
			Op:    sqltypes.Exists,
		}
	}
	// If we have sort directives but no filters, add an exists-filter for each label.
	if len(lo.Filters) == 0 {
		lo.Filters = make([]sqltypes.OrFilter, 1)
		lo.Filters[0].Filters = make([]sqltypes.Filter, 0, len(sortLabels))
		for _, labelName := range sortLabels {
			lo.Filters[0].Filters = append(lo.Filters[0].Filters, newExistsFilter(labelName))
		}
		return
	}
	// The labels must be bound in every OR-group, so each group is checked on its own.
	for i := range lo.Filters {
		alreadyBound := make(map[string]bool)
		for _, filter := range lo.Filters[i].Filters {
			if isLabelFilter(&filter) {
				alreadyBound[filter.Field[2]] = true
			}
		}
		for _, labelName := range sortLabels {
			if !alreadyBound[labelName] {
				lo.Filters[i].Filters = append(lo.Filters[i].Filters, newExistsFilter(labelName))
			}
		}
	}
}
// getFieldFilter builds the WHERE fragment for a filter on a plain column of
// the fields table (aliased f in the generated query).
//
// Possible ops from the k8s parser:
// KEY = and == (same) VALUE
// KEY != VALUE
// KEY exists [] # ,KEY, => this filter
// KEY ! [] # ,!KEY, => assert KEY doesn't exist
// KEY in VALUES
// KEY notin VALUES
//
// Exists/NotExists are rejected for plain fields. An empty in/notin list
// becomes "()". It returns the fragment and its positional parameters. It
// returns an error for an invalid column, a missing comparison value, or an
// unsupported operator.
func (l *ListOptionIndexer) getFieldFilter(filter sqltypes.Filter) (string, []any, error) {
	opString := ""
	escapeString := ""
	columnName := toColumnName(filter.Field)
	if err := l.validateColumn(columnName); err != nil {
		return "", nil, err
	}
	switch filter.Op {
	case sqltypes.Eq:
		if filter.Partial {
			opString = "LIKE"
			escapeString = escapeBackslashDirective
		} else {
			opString = "="
		}
		clause := fmt.Sprintf(`f."%s" %s ?%s`, columnName, opString, escapeString)
		return clause, []any{formatMatchTarget(filter)}, nil
	case sqltypes.NotEq:
		if filter.Partial {
			opString = "NOT LIKE"
			escapeString = escapeBackslashDirective
		} else {
			opString = "!="
		}
		clause := fmt.Sprintf(`f."%s" %s ?%s`, columnName, opString, escapeString)
		return clause, []any{formatMatchTarget(filter)}, nil
	case sqltypes.Lt, sqltypes.Gt:
		// Guard the Matches[0] access below against an empty match list.
		if len(filter.Matches) == 0 {
			return "", nil, fmt.Errorf("operator %v requires a value to compare against", filter.Op)
		}
		sym, target, err := prepareComparisonParameters(filter.Op, filter.Matches[0])
		if err != nil {
			return "", nil, err
		}
		clause := fmt.Sprintf(`f."%s" %s ?`, columnName, sym)
		return clause, []any{target}, nil
	case sqltypes.Exists, sqltypes.NotExists:
		return "", nil, errors.New("NULL and NOT NULL tests aren't supported for non-label queries")
	case sqltypes.In, sqltypes.NotIn:
		target := "()"
		if len(filter.Matches) > 0 {
			target = fmt.Sprintf("(?%s)", strings.Repeat(", ?", len(filter.Matches)-1))
		}
		opString = "IN"
		if filter.Op == sqltypes.NotIn {
			opString = "NOT IN"
		}
		clause := fmt.Sprintf(`f."%s" %s %s`, columnName, opString, target)
		matches := make([]any, len(filter.Matches))
		for i, match := range filter.Matches {
			matches[i] = match
		}
		return clause, matches, nil
	}
	// opString is never set for an unknown operator, so report the operator itself.
	return "", nil, fmt.Errorf("unrecognized operator: %v", filter.Op)
}
// getIndirectLabelFilter builds the WHERE fragment for a filter on a label.
// The label's value is used to look up a row in another resource's "_fields"
// table, and the filter's operator is applied to a column of that row.
//
// filter.IndirectFields must hold exactly four parts:
//   - [0] and [1]: joined with "_" and sanitized with badTableNameChars to
//     form the external table name;
//   - [2]: the external column matched against the label's value;
//   - [3]: the external column the operator is applied to.
//
// The caller must already have joined the label table as lt<N> and recorded N
// under "<dbName>:<label>" in joinTableIndexByLabelName; otherwise an internal
// error is returned. A JOIN of the external table as ext<M> is returned the
// first time that table is needed in the query (tracked in joinedTables).
//
// It returns the fragment, the extra join clauses, the positional parameters
// (the label name first), and an error.
func (l *ListOptionIndexer) getIndirectLabelFilter(filter sqltypes.Filter, dbName string, joinTableIndexByLabelName map[string]int, joinedTables map[string]bool) (string, []string, []any, error) {
	if len(filter.IndirectFields) != 4 {
		s := "<empty>"
		if len(filter.IndirectFields) > 0 {
			s = strings.Join(filter.IndirectFields, " ")
		}
		return "", nil, nil, fmt.Errorf("expected exactly 4 indirect field parts, got %s (%d)", s, len(filter.IndirectFields))
	}
	labelName := filter.Field[2]
	fullName := fmt.Sprintf("%s:%s", dbName, labelName)
	labelIndex, ok := joinTableIndexByLabelName[fullName]
	if !ok {
		return "", nil, nil, fmt.Errorf("internal error: can't find an entry for table %s", fullName)
	}
	joinClauses := make([]string, 0)
	// badTableNameChars already matches '/' (and every other disallowed
	// character), so no separate slash replacement is needed.
	extDBName := fmt.Sprintf("%s_%s", filter.IndirectFields[0], filter.IndirectFields[1])
	extDBName = badTableNameChars.ReplaceAllString(extDBName, "_")
	extIndex, ok := joinTableIndexByLabelName[extDBName]
	if !ok {
		extIndex = len(joinTableIndexByLabelName) + 1
		joinTableIndexByLabelName[extDBName] = extIndex
	}
	selectorFieldName := filter.IndirectFields[2]
	if badTableNameChars.MatchString(selectorFieldName) {
		return "", nil, nil, fmt.Errorf("invalid database column name '%s'", selectorFieldName)
	}
	targetFieldName := filter.IndirectFields[3]
	if badTableNameChars.MatchString(targetFieldName) {
		return "", nil, nil, fmt.Errorf("invalid database column name '%s'", targetFieldName)
	}
	extDBNameFields := fmt.Sprintf("%s_fields", extDBName)
	if _, joined := joinedTables[extDBNameFields]; !joined {
		joinedTables[extDBNameFields] = true
		joinClauses = append(joinClauses, fmt.Sprintf(`JOIN "%s" ext%d ON lt%d.value = ext%d."%s"`, extDBNameFields, extIndex, labelIndex, extIndex, selectorFieldName))
	}
	labelWhereSubClause := fmt.Sprintf("lt%d.label = ?", labelIndex)
	targetFieldReference := fmt.Sprintf(`ext%d."%s"`, extIndex, targetFieldName)
	params := []any{labelName}
	opString := ""
	escapeString := ""
	matchFmtToUse := strictMatchFmt

	// These operators read filter.Matches[0]; reject an empty list instead of panicking.
	switch filter.Op {
	case sqltypes.Eq, sqltypes.NotEq, sqltypes.Lt, sqltypes.Gt:
		if len(filter.Matches) == 0 {
			return "", nil, nil, fmt.Errorf("operator %v on label %s requires a value", filter.Op, labelName)
		}
	}

	switch filter.Op {
	case sqltypes.Eq:
		if filter.Partial {
			opString = "LIKE"
			escapeString = escapeBackslashDirective
			matchFmtToUse = matchFmt
		} else {
			opString = "="
		}
		clause := fmt.Sprintf(`%s AND %s %s ?%s`, labelWhereSubClause, targetFieldReference, opString, escapeString)
		params = append(params, formatMatchTargetWithFormatter(filter.Matches[0], matchFmtToUse))
		return clause, joinClauses, params, nil
	case sqltypes.NotEq:
		if filter.Partial {
			opString = "NOT LIKE"
			escapeString = escapeBackslashDirective
			matchFmtToUse = matchFmt
		} else {
			opString = "!="
		}
		clause := fmt.Sprintf(`%s AND %s %s ?%s`, labelWhereSubClause, targetFieldReference, opString, escapeString)
		params = append(params, formatMatchTargetWithFormatter(filter.Matches[0], matchFmtToUse))
		return clause, joinClauses, params, nil
	case sqltypes.Lt, sqltypes.Gt:
		sym, target, err := prepareComparisonParameters(filter.Op, filter.Matches[0])
		if err != nil {
			return "", nil, nil, err
		}
		clause := fmt.Sprintf(`%s AND %s %s ?`, labelWhereSubClause, targetFieldReference, sym)
		params = append(params, target)
		return clause, joinClauses, params, nil
	case sqltypes.Exists:
		// "x != NULL" evaluates to NULL (never true) in SQL; IS NOT NULL is the null test.
		clause := fmt.Sprintf(`%s AND %s IS NOT NULL`, labelWhereSubClause, targetFieldReference)
		return clause, joinClauses, params, nil
	case sqltypes.NotExists:
		// Likewise "x == NULL" is never true; IS NULL is the null test.
		clause := fmt.Sprintf(`%s AND %s IS NULL`, labelWhereSubClause, targetFieldReference)
		return clause, joinClauses, params, nil
	case sqltypes.In, sqltypes.NotIn:
		// One placeholder per match; an empty list becomes "()" so the
		// placeholder count always equals the number of parameters.
		target := "()"
		if len(filter.Matches) > 0 {
			target = fmt.Sprintf("(?%s)", strings.Repeat(", ?", len(filter.Matches)-1))
		}
		op := "IN"
		if filter.Op == sqltypes.NotIn {
			op = "NOT IN"
		}
		clause := fmt.Sprintf(`%s AND %s %s %s`, labelWhereSubClause, targetFieldReference, op, target)
		for _, match := range filter.Matches {
			params = append(params, match)
		}
		return clause, joinClauses, params, nil
	}
	// opString is never set for an unknown operator, so report the operator itself.
	return "", nil, nil, fmt.Errorf("unrecognized operator: %v", filter.Op)
}
// getIndirectNonLabelFilter builds the WHERE fragment for a filter on a plain
// column f."<column>". The column's value is used to look up a row in another
// resource's "_fields" table, and the operator is applied to a column of that
// row.
//
// filter.IndirectFields must hold exactly four parts:
//   - [0] and [1]: joined with "_" and sanitized with badTableNameChars to
//     form the external table name;
//   - [2]: the external column matched against f."<column>";
//   - [3]: the external column the operator is applied to.
//
// A JOIN of the external table as ext<N> is returned the first time that table
// is needed in the query (tracked in joinedTables). N is allocated in
// joinTableIndexByLabelName. It returns the fragment, the extra join clauses,
// the positional parameters, and an error.
func (l *ListOptionIndexer) getIndirectNonLabelFilter(filter sqltypes.Filter, dbName string, joinTableIndexByLabelName map[string]int, joinedTables map[string]bool) (string, []string, []any, error) {
	if len(filter.IndirectFields) != 4 {
		s := "<empty>"
		if len(filter.IndirectFields) > 0 {
			s = strings.Join(filter.IndirectFields, " ")
		}
		return "", nil, nil, fmt.Errorf("expected exactly 4 indirect field parts, got %s (%d)", s, len(filter.IndirectFields))
	}
	columnName := toColumnName(filter.Field)
	if err := l.validateColumn(columnName); err != nil {
		return "", nil, nil, err
	}
	extDBName := fmt.Sprintf("%s_%s", filter.IndirectFields[0], filter.IndirectFields[1])
	extDBName = badTableNameChars.ReplaceAllString(extDBName, "_")
	extIndex, ok := joinTableIndexByLabelName[extDBName]
	if !ok {
		extIndex = len(joinTableIndexByLabelName) + 1
		joinTableIndexByLabelName[extDBName] = extIndex
	}
	selectorFieldName := filter.IndirectFields[2]
	if badTableNameChars.MatchString(selectorFieldName) {
		return "", nil, nil, fmt.Errorf("invalid database column name '%s'", selectorFieldName)
	}
	externalFieldName := filter.IndirectFields[3]
	if badTableNameChars.MatchString(externalFieldName) {
		return "", nil, nil, fmt.Errorf("invalid database column name '%s'", externalFieldName)
	}
	extDBNameFields := fmt.Sprintf("%s_fields", extDBName)
	joinClauses := make([]string, 0)
	if _, joined := joinedTables[extDBNameFields]; !joined {
		joinedTables[extDBNameFields] = true
		joinClauses = append(joinClauses, fmt.Sprintf(`JOIN "%s_fields" ext%d ON f."%s" = ext%d."%s"`, extDBName, extIndex, columnName, extIndex, selectorFieldName))
	}
	targetFieldReference := fmt.Sprintf(`ext%d."%s"`, extIndex, externalFieldName)
	opString := ""
	escapeString := ""
	matchFmtToUse := strictMatchFmt

	// These operators read filter.Matches[0]; reject an empty list instead of panicking.
	switch filter.Op {
	case sqltypes.Eq, sqltypes.NotEq, sqltypes.Lt, sqltypes.Gt:
		if len(filter.Matches) == 0 {
			return "", nil, nil, fmt.Errorf("operator %v requires a value to compare against", filter.Op)
		}
	}

	switch filter.Op {
	case sqltypes.Eq:
		if filter.Partial {
			opString = "LIKE"
			escapeString = escapeBackslashDirective
			matchFmtToUse = matchFmt
		} else {
			opString = "="
		}
		clause := fmt.Sprintf(`%s %s ?%s`, targetFieldReference, opString, escapeString)
		return clause, joinClauses, []any{formatMatchTargetWithFormatter(filter.Matches[0], matchFmtToUse)}, nil
	case sqltypes.NotEq:
		if filter.Partial {
			opString = "NOT LIKE"
			escapeString = escapeBackslashDirective
			matchFmtToUse = matchFmt
		} else {
			opString = "!="
		}
		clause := fmt.Sprintf(`%s %s ?%s`, targetFieldReference, opString, escapeString)
		return clause, joinClauses, []any{formatMatchTargetWithFormatter(filter.Matches[0], matchFmtToUse)}, nil
	case sqltypes.Lt, sqltypes.Gt:
		sym, target, err := prepareComparisonParameters(filter.Op, filter.Matches[0])
		if err != nil {
			return "", nil, nil, err
		}
		clause := fmt.Sprintf(`%s %s ?`, targetFieldReference, sym)
		return clause, joinClauses, []any{target}, nil
	case sqltypes.Exists:
		// "x != NULL" evaluates to NULL (never true) in SQL; IS NOT NULL is the null test.
		clause := fmt.Sprintf(`%s IS NOT NULL`, targetFieldReference)
		return clause, joinClauses, []any{}, nil
	case sqltypes.NotExists:
		// Likewise "x == NULL" is never true; IS NULL is the null test.
		clause := fmt.Sprintf(`%s IS NULL`, targetFieldReference)
		return clause, joinClauses, []any{}, nil
	case sqltypes.In, sqltypes.NotIn:
		// One placeholder per match; an empty list becomes "()" so the
		// placeholder count always equals the number of parameters.
		target := "()"
		if len(filter.Matches) > 0 {
			target = fmt.Sprintf("(?%s)", strings.Repeat(", ?", len(filter.Matches)-1))
		}
		opString = "IN"
		if filter.Op == sqltypes.NotIn {
			opString = "NOT IN"
		}
		clause := fmt.Sprintf(`%s %s %s`, targetFieldReference, opString, target)
		params := make([]any, 0, len(filter.Matches))
		for _, match := range filter.Matches {
			params = append(params, match)
		}
		return clause, joinClauses, params, nil
	}
	// opString is never set for an unknown operator, so report the operator itself.
	return "", nil, nil, fmt.Errorf("unrecognized operator: %v", filter.Op)
}
// getLabelFilter builds the WHERE fragment for a filter on
// metadata.labels.<name> (filter.Field[2] is the label name).
//
// The caller must already have joined "<dbName>_labels" as lt<N> and recorded
// N under "<dbName>:<label>" in joinTableIndexByLabelName; otherwise an
// internal error is returned. Positive operators test lt<N>.label and
// lt<N>.value. NotExists uses a NOT IN sub-select over the objects that do
// carry the label. The negative operators (!=, notin) OR in that non-existence
// test, so objects without the label also match.
//
// It returns the fragment, extra join clauses (always empty here), the
// positional parameters, and an error.
func (l *ListOptionIndexer) getLabelFilter(filter sqltypes.Filter, dbName string, joinTableIndexByLabelName map[string]int, joinedTables map[string]bool) (string, []string, []any, error) {
	opString := ""
	escapeString := ""
	matchFmtToUse := strictMatchFmt
	labelName := filter.Field[2]
	fullName := fmt.Sprintf("%s:%s", dbName, labelName)
	labelIndex, ok := joinTableIndexByLabelName[fullName]
	if !ok {
		return "", nil, nil, fmt.Errorf("internal error: can't find an entry for table %s", fullName)
	}
	joinClauses := make([]string, 0)

	// These operators read filter.Matches[0]; reject an empty list instead of panicking.
	switch filter.Op {
	case sqltypes.Eq, sqltypes.NotEq, sqltypes.Lt, sqltypes.Gt:
		if len(filter.Matches) == 0 {
			return "", nil, nil, fmt.Errorf("operator %v on label %s requires a value", filter.Op, labelName)
		}
	}
	// inListTarget returns "(?, ?, ...)" with one placeholder per match, or
	// "()" for no matches, so the placeholder count always equals the number
	// of parameters.
	inListTarget := func() string {
		if len(filter.Matches) == 0 {
			return "()"
		}
		return fmt.Sprintf("(?%s)", strings.Repeat(", ?", len(filter.Matches)-1))
	}

	switch filter.Op {
	case sqltypes.Eq:
		if filter.Partial {
			opString = "LIKE"
			escapeString = escapeBackslashDirective
			matchFmtToUse = matchFmt
		} else {
			opString = "="
		}
		clause := fmt.Sprintf(`lt%d.label = ? AND lt%d.value %s ?%s`, labelIndex, labelIndex, opString, escapeString)
		return clause, joinClauses, []any{labelName, formatMatchTargetWithFormatter(filter.Matches[0], matchFmtToUse)}, nil
	case sqltypes.NotEq:
		if filter.Partial {
			opString = "NOT LIKE"
			escapeString = escapeBackslashDirective
			matchFmtToUse = matchFmt
		} else {
			opString = "!="
		}
		subFilter := sqltypes.Filter{
			Field: filter.Field,
			Op:    sqltypes.NotExists,
		}
		existenceClause, _, subParams, err := l.getLabelFilter(subFilter, dbName, joinTableIndexByLabelName, joinedTables)
		if err != nil {
			return "", nil, nil, err
		}
		clause := fmt.Sprintf(`(%s) OR (lt%d.label = ? AND lt%d.value %s ?%s)`, existenceClause, labelIndex, labelIndex, opString, escapeString)
		params := append(subParams, labelName, formatMatchTargetWithFormatter(filter.Matches[0], matchFmtToUse))
		return clause, joinClauses, params, nil
	case sqltypes.Lt, sqltypes.Gt:
		sym, target, err := prepareComparisonParameters(filter.Op, filter.Matches[0])
		if err != nil {
			return "", nil, nil, err
		}
		clause := fmt.Sprintf(`lt%d.label = ? AND lt%d.value %s ?`, labelIndex, labelIndex, sym)
		return clause, joinClauses, []any{labelName, target}, nil
	case sqltypes.Exists:
		clause := fmt.Sprintf(`lt%d.label = ?`, labelIndex)
		return clause, joinClauses, []any{labelName}, nil
	case sqltypes.NotExists:
		clause := fmt.Sprintf(`o.key NOT IN (SELECT o1.key FROM "%s" o1
JOIN "%s_fields" f1 ON o1.key = f1.key
LEFT OUTER JOIN "%s_labels" lt%di1 ON o1.key = lt%di1.key
WHERE lt%di1.label = ?)`, dbName, dbName, dbName, labelIndex, labelIndex, labelIndex)
		return clause, joinClauses, []any{labelName}, nil
	case sqltypes.In:
		clause := fmt.Sprintf(`lt%d.label = ? AND lt%d.value IN %s`, labelIndex, labelIndex, inListTarget())
		matches := make([]any, len(filter.Matches)+1)
		matches[0] = labelName
		for i, match := range filter.Matches {
			matches[i+1] = match
		}
		return clause, joinClauses, matches, nil
	case sqltypes.NotIn:
		subFilter := sqltypes.Filter{
			Field: filter.Field,
			Op:    sqltypes.NotExists,
		}
		existenceClause, _, subParams, err := l.getLabelFilter(subFilter, dbName, joinTableIndexByLabelName, joinedTables)
		if err != nil {
			return "", nil, nil, err
		}
		clause := fmt.Sprintf(`(%s) OR (lt%d.label = ? AND lt%d.value NOT IN %s)`, existenceClause, labelIndex, labelIndex, inListTarget())
		matches := append(subParams, labelName)
		for _, match := range filter.Matches {
			matches = append(matches, match)
		}
		return clause, joinClauses, matches, nil
	}
	// opString is never set for an unknown operator, so report the operator itself.
	return "", nil, nil, fmt.Errorf("unrecognized operator: %v", filter.Op)
}
// getQueryParts assembles the pieces shared by every list query for dbName.
// It returns, in order:
//   - joinParts: `"<dbName>" o` and the JOIN of "<dbName>_fields" f, followed
//     by whatever JOINs the filters and sort directives need;
//   - whereClauses: one entry per OR-group of lo.Filters, then the namespace
//     test, then a bare "FALSE" if partitions is empty, then the combined
//     partition test, then any sort-related tests (callers AND them together);
//   - params: the positional parameters for the clauses above, in order;
//   - whether SELECT DISTINCT is needed (some OR-group reported it);
//   - orderByClauses, orderByParams and an extra select field from getSortDirectives;
//   - an error from filter or sort processing.
//
// namespace "" or "*" means no namespace restriction. Each non-passthrough
// partition contributes the AND of its namespace test and its name list (a
// non-All partition with no names contributes "FALSE"). The partitions' tests
// are OR-ed together.
func (l *ListOptionIndexer) getQueryParts(lo *sqltypes.ListOptions, partitions []partition.Partition, namespace string, dbName string, joinTableIndexByLabelName map[string]int) ([]string, []string, []any, bool, []string, []any, string, error) {
	joinParts := []string{fmt.Sprintf(`"%s" o`, dbName), fmt.Sprintf(`JOIN "%s_fields" f ON o.key = f.key`, dbName)}
	whereClauses := make([]string, 0)
	params := make([]any, 0)
	needDistinctFinal := false
	// The main table and its fields table are already part of joinParts.
	joinedTables := make(map[string]bool)
	joinedTables[dbName] = true
	joinedTables[fmt.Sprintf("%s_fields", dbName)] = true

	// 1- Figure out what we'll be joining and testing
	for _, orFilters := range lo.Filters {
		newWhereClause, newJoinParts, newParams, needDistinct, err := l.buildORClauseFromFilters(orFilters, dbName, joinTableIndexByLabelName, joinedTables)
		if err != nil {
			return joinParts, whereClauses, params, needDistinctFinal, nil, nil, "", err
		}
		joinParts = append(joinParts, newJoinParts...)
		if len(newWhereClause) > 0 {
			whereClauses = append(whereClauses, newWhereClause)
		}
		params = append(params, newParams...)
		if needDistinct {
			needDistinctFinal = true
		}
	}

	// WHERE clauses (from namespace)
	if namespace != "" && namespace != "*" {
		whereClauses = append(whereClauses, `f."metadata.namespace" = ?`)
		params = append(params, namespace)
	}

	// WHERE clauses (from partitions and their corresponding parameters)
	// NOTE(review): a passthrough partition contributes no clause, but the
	// clauses of any other partitions are still applied. Mixing passthrough
	// with restricted partitions therefore narrows results. Presumably
	// passthrough is only ever used alone; confirm.
	partitionClauses := make([]string, 0)
	for _, thisPartition := range partitions {
		if thisPartition.Passthrough {
			// nothing to do, no extra filtering to apply by definition
			continue
		}
		singlePartitionClauses := make([]string, 0)
		// filter by namespace
		if thisPartition.Namespace != "" && thisPartition.Namespace != "*" {
			singlePartitionClauses = append(singlePartitionClauses, `f."metadata.namespace" = ?`)
			params = append(params, thisPartition.Namespace)
		}
		// optionally filter by names
		if !thisPartition.All {
			names := thisPartition.Names
			if names.Len() == 0 {
				// degenerate case, there will be no results
				singlePartitionClauses = append(singlePartitionClauses, "FALSE")
			} else {
				singlePartitionClauses = append(singlePartitionClauses, fmt.Sprintf(`f."metadata.name" IN (?%s)`, strings.Repeat(", ?", names.Len()-1)))
				// sort for reproducibility
				sortedNames := names.UnsortedList()
				sort.Strings(sortedNames)
				for _, name := range sortedNames {
					params = append(params, name)
				}
			}
		}
		if len(singlePartitionClauses) > 0 {
			partitionClauses = append(partitionClauses, strings.Join(singlePartitionClauses, " AND "))
		}
	}
	if len(partitions) == 0 {
		// degenerate case, there will be no results
		whereClauses = append(whereClauses, "FALSE")
	}
	if len(partitionClauses) == 1 {
		whereClauses = append(whereClauses, partitionClauses[0])
	} else if len(partitionClauses) > 1 {
		whereClauses = append(whereClauses, "(\n ("+strings.Join(partitionClauses, ") OR\n (")+")\n)")
	}

	sortSelectField, sortJoinClauses, sortWhereClauses, orderByClauses, orderByParams, err := l.getSortDirectives(lo, dbName, joinTableIndexByLabelName)
	joinParts = append(joinParts, sortJoinClauses...)
	whereClauses = append(whereClauses, sortWhereClauses...)
	return joinParts, whereClauses, params, needDistinctFinal, orderByClauses, orderByParams, sortSelectField, err
}
func (l *ListOptionIndexer) getSortDirectives(lo *sqltypes.ListOptions, dbName string, joinTableIndexByLabelName map[string]int) (string, []string, []string, []string, []any, error) {
sortSelectField := ""
sortJoinClauses := make([]string, 0)
sortWhereClauses := make([]string, 0)
orderByClauses := make([]string, 0)
orderByParams := make([]any, 0)
if len(lo.SortList.SortDirectives) == 0 {
// make sure at least one default order is always picked
orderByClauses = append(orderByClauses, `f."metadata.name" ASC`)
if l.namespaced {
orderByClauses = append(orderByClauses, `f."metadata.namespace" ASC`)
}
return sortSelectField, sortJoinClauses, sortWhereClauses, orderByClauses, orderByParams, nil
}
for _, sortDirective := range lo.SortList.SortDirectives {
fields := sortDirective.Fields
if isLabelsFieldList(fields) {
labelName := sortDirective.Fields[2]
fullName := fmt.Sprintf("%s:%s", dbName, labelName)
labelIndex, ok := joinTableIndexByLabelName[fullName]
if !ok {
if sortDirective.IsIndirect {
return sortSelectField, sortJoinClauses, sortWhereClauses, orderByClauses, orderByParams, fmt.Errorf(`internal error: no join-table index given for labelName "%s"`, labelName)
}
labelIndex = len(joinTableIndexByLabelName) + 1
joinTableIndexByLabelName[fullName] = labelIndex
}
if sortDirective.IsIndirect {
//TODO: check the external table name.
externalTableName := getExternalTableName(&sortDirective)
extIndex, ok := joinTableIndexByLabelName[externalTableName]
if !ok {
extIndex = len(joinTableIndexByLabelName) + 1
joinTableIndexByLabelName[externalTableName] = extIndex
}
selectorFieldName := sortDirective.IndirectFields[2]
if badTableNameChars.MatchString(selectorFieldName) {
return sortSelectField, sortJoinClauses, sortWhereClauses, orderByClauses, orderByParams, fmt.Errorf("invalid database column name '%s'", selectorFieldName)
}
externalFieldName := sortDirective.IndirectFields[3]
if badTableNameChars.MatchString(externalFieldName) {
return sortSelectField, sortJoinClauses, sortWhereClauses, orderByClauses, orderByParams, fmt.Errorf("invalid database column name '%s'", externalFieldName)
}
sortJoinClauses = append(sortJoinClauses, fmt.Sprintf(`JOIN "%s_fields" ext%d ON lt%d.value = ext%d."%s"`, externalTableName, extIndex, labelIndex, extIndex, selectorFieldName))
//TODO: Verify the field name
sortSelectField = fmt.Sprintf(`ext%d."%s" as ext%d_target`, extIndex, externalFieldName, extIndex)
}
clause, sortParam, err := buildSortLabelsClause(fields[2], dbName, joinTableIndexByLabelName, sortDirective.Order == sqltypes.ASC)
if err != nil {
return sortSelectField, sortJoinClauses, sortWhereClauses, orderByClauses, orderByParams, err
}
orderByClauses = append(orderByClauses, clause)
orderByParams = append(orderByParams, sortParam)
} else if sortDirective.IsIndirect {
if len(sortDirective.IndirectFields) != 4 {
return sortSelectField, sortJoinClauses, sortWhereClauses, orderByClauses, orderByParams, fmt.Errorf("expected indirect sort directive to have 4 indirect fields, got %d", len(sortDirective.IndirectFields))
}
externalTableName := getExternalTableName(&sortDirective)
extIndex, ok := joinTableIndexByLabelName[externalTableName]
if !ok {
extIndex = len(joinTableIndexByLabelName) + 1
joinTableIndexByLabelName[externalTableName] = extIndex
}
selectorFieldName := sortDirective.IndirectFields[2]
if badTableNameChars.MatchString(selectorFieldName) {
return sortSelectField, sortJoinClauses, sortWhereClauses, orderByClauses, orderByParams, fmt.Errorf("invalid database column name '%s'", selectorFieldName)
}
externalFieldName := sortDirective.IndirectFields[3]
if badTableNameChars.MatchString(externalFieldName) {
return sortSelectField, sortJoinClauses, sortWhereClauses, orderByClauses, orderByParams, fmt.Errorf("invalid database column name '%s'", externalFieldName)
}
columnName := toColumnName(fields)
if err := l.validateColumn(columnName); err != nil {
return sortSelectField, sortJoinClauses, sortWhereClauses, orderByClauses, orderByParams, err
}
sortJoinClauses = append(sortJoinClauses, fmt.Sprintf(`JOIN "%s_fields" ext%d ON f."%s" = ext%d."%s"`, externalTableName, extIndex, columnName, extIndex, selectorFieldName))
direction := "ASC"
nullsPlace := "LAST"
if sortDirective.Order == sqltypes.DESC {
direction = "DESC"
nullsPlace = "FIRST"
}
orderByClauses = append(orderByClauses, fmt.Sprintf(`ext%d."%s" %s NULLS %s`, extIndex, externalFieldName, direction, nullsPlace))
} else {
columnName := toColumnName(fields)
if err := l.validateColumn(columnName); err != nil {
return sortSelectField, sortJoinClauses, sortWhereClauses, orderByClauses, orderByParams, err
}
direction := "ASC"
if sortDirective.Order == sqltypes.DESC {
direction = "DESC"
}
orderByClauses = append(orderByClauses, fmt.Sprintf(`f."%s" %s`, columnName, direction))
}
}
return sortSelectField, sortJoinClauses, sortWhereClauses, orderByClauses, orderByParams, nil
}
// validateColumn checks that column is one of the indexer's indexed fields.
// It returns nil when the column is known, and otherwise an error that wraps
// ErrInvalidColumn (so callers can test for it with errors.Is).
func (l *ListOptionIndexer) validateColumn(column string) error {
	for i := range l.indexedFields {
		if l.indexedFields[i] != column {
			continue
		}
		return nil
	}
	return fmt.Errorf("column is invalid [%s]: %w", column, ErrInvalidColumn)
}
// Helper functions for generating SQL in alphabetical order:
func applyIndirectLabelTests(loWithLabel *sqltypes.ListOptions, loWithoutLabel *sqltypes.ListOptions, indirectSortDirective *sqltypes.Sort) {
	// Both filters test the same field path: the sort directive's label field.
	field := indirectSortDirective.Fields[:]
	existsTest := sqltypes.Filter{Field: field, Op: sqltypes.Exists}
	missingTest := sqltypes.Filter{Field: field, Op: sqltypes.NotExists}

	// Each filter is added as its own single-member OR group, so it is
	// AND-ed with whatever filters the options already carry.
	loWithLabel.Filters = append(loWithLabel.Filters, sqltypes.OrFilter{Filters: []sqltypes.Filter{existsTest}})
	loWithoutLabel.Filters = append(loWithoutLabel.Filters, sqltypes.OrFilter{Filters: []sqltypes.Filter{missingTest}})
}
// buildSortLabelsClause builds an ORDER BY term that sorts on the value of the
// label labelName. It uses the join-table alias index registered under
// "dbName:labelName" in joinTableIndexByLabelName.
//
// The term contains one `?` placeholder; the second return value (labelName)
// is the parameter to bind to it. The CASE yields NULL when the joined label is
// not labelName. NULLs sort last for ascending order and first for descending.
// An error is returned if no join-table index was registered for the label.
func buildSortLabelsClause(labelName string, dbName string, joinTableIndexByLabelName map[string]int, isAsc bool) (string, string, error) {
	idx, found := joinTableIndexByLabelName[fmt.Sprintf("%s:%s", dbName, labelName)]
	if !found {
		return "", "", fmt.Errorf(`internal error: no join-table index given for labelName "%s"`, labelName)
	}

	dir, nulls := "DESC", "FIRST"
	if isAsc {
		dir, nulls = "ASC", "LAST"
	}

	caseExpr := fmt.Sprintf(`CASE lt%d.label WHEN ? THEN lt%d.value ELSE NULL END`, idx, idx)
	return fmt.Sprintf("(%s) %s NULLS %s", caseExpr, dir, nulls), labelName, nil
}
// checkForIndirectSortDirective returns the indirect sort directive in lo, or
// nil if there is none. Only one indirect directive is supported per query;
// if more than one is present it returns an error listing each one's indirect
// fields.
func checkForIndirectSortDirective(lo *sqltypes.ListOptions) (*sqltypes.Sort, error) {
	var found *sqltypes.Sort
	descriptions := make([]string, 0)
	for _, directive := range lo.SortList.SortDirectives {
		if !directive.IsIndirect {
			continue
		}
		// Point at a private copy of this directive rather than at the range
		// variable itself.
		directiveCopy := directive
		found = &directiveCopy
		descriptions = append(descriptions, "["+strings.Join(directive.IndirectFields, "][")+"]")
	}
	if count := len(descriptions); count > 1 {
		return nil, fmt.Errorf("can have at most one indirect sort directive, have %d: %s", count, descriptions)
	}
	return found, nil
}
// formatMatchTarget escapes the first match value of filter and substitutes it
// into matchFmt for partial matches, or into strictMatchFmt otherwise.
// It assumes filter.Matches is non-empty.
func formatMatchTarget(filter sqltypes.Filter) string {
	if filter.Partial {
		return formatMatchTargetWithFormatter(filter.Matches[0], matchFmt)
	}
	return formatMatchTargetWithFormatter(filter.Matches[0], strictMatchFmt)
}
// likePatternEscaper backslash-escapes `_` and `%` (the wildcard characters of
// a SQL LIKE pattern) as well as the backslash itself. A strings.Replacer
// scans its input once, so the backslashes it inserts are never escaped again.
var likePatternEscaper = strings.NewReplacer(`\`, `\\`, `_`, `\_`, `%`, `\%`)

// formatMatchTargetWithFormatter escapes match and substitutes the result into
// format, which is expected to contain a single %s verb.
func formatMatchTargetWithFormatter(match string, format string) string {
	return fmt.Sprintf(format, likePatternEscaper.Replace(match))
}
// getExternalTableName joins the first two indirect fields of sd with "_" and
// replaces every "/" in the result with "_". It assumes sd.IndirectFields has
// at least two entries.
func getExternalTableName(sd *sqltypes.Sort) string {
	prefix := sd.IndirectFields[0:2]
	joined := strings.Join(prefix, "_")
	return strings.ReplaceAll(joined, "/", "_")
}
// isIndirectFilter reports whether filter is flagged as indirect.
func isIndirectFilter(filter *sqltypes.Filter) bool {
	indirect := filter.IsIndirect
	return indirect
}
// isLabelFilter reports whether f's field path starts with "metadata", "labels".
func isLabelFilter(f *sqltypes.Filter) bool {
	if len(f.Field) < 2 {
		return false
	}
	return f.Field[0] == "metadata" && f.Field[1] == "labels"
}
// isLabelsFieldList reports whether fields names exactly one label, i.e. has
// the form ["metadata", "labels", <label-name>].
func isLabelsFieldList(fields []string) bool {
	if len(fields) != 3 {
		return false
	}
	return fields[0] == "metadata" && fields[1] == "labels"
}
// joinWhereClauses renders whereClauses as a WHERE clause that ends with a newline.
//
// With no clauses it returns "". A single clause is emitted bare:
// "<leadingIndent>WHERE <clause>\n". Multiple clauses are each parenthesized
// and joined with op. Every clause after the first starts a new line indented
// by continuingIndent.
func joinWhereClauses(whereClauses []string, leadingIndent string, continuingIndent string, op string) string {
	if len(whereClauses) == 0 {
		return ""
	}
	if len(whereClauses) == 1 {
		return leadingIndent + "WHERE " + whereClauses[0] + "\n"
	}

	var b strings.Builder
	b.WriteString(leadingIndent)
	b.WriteString("WHERE (")
	for i, clause := range whereClauses {
		if i > 0 {
			b.WriteString(") ")
			b.WriteString(op)
			b.WriteString("\n")
			b.WriteString(continuingIndent)
			b.WriteString("(")
		}
		b.WriteString(clause)
	}
	b.WriteString(")\n")
	return b.String()
}
// prepareComparisonParameters maps a Lt/Gt filter operator to its SQL
// comparison symbol ("<" or ">") and parses target as the numeric bound.
//
// target is parsed at full float64 precision. Parsing with bitSize 32 would
// round the bound to the nearest float32: "1.1" would become 1.100000023841858
// and "16777217" would become 16777216, so the query would compare against a
// different value than requested. It would also reject values outside float32
// range (e.g. "1e39") with a range error.
//
// It returns the strconv.ParseFloat error if target is not a number, and an
// error for any operator other than sqltypes.Lt or sqltypes.Gt.
func prepareComparisonParameters(op sqltypes.Op, target string) (string, float64, error) {
	num, err := strconv.ParseFloat(target, 64)
	if err != nil {
		return "", 0, err
	}
	switch op {
	case sqltypes.Lt:
		return "<", num, nil
	case sqltypes.Gt:
		return ">", num, nil
	default:
		return "", 0, fmt.Errorf("unrecognized operator when expecting '<' or '>': '%s'", op)
	}
}
// processOrderByFields prepares ORDER BY terms and matching select-list
// expressions for an indirect-sort query. Every sort expression gets a
// synthetic column alias, so the final ordering can be written in terms of
// those aliases.
//
// It returns three parallel slices, each of length 1+len(orderByClauses):
//   - sortParts: ORDER BY terms written against the aliases;
//   - importWithParts: select expressions of the form `<expr> AS <alias>`;
//     index 0 selects the external field `ext<extIndex>."<field>"`;
//   - importAsNullParts: identical, except index 0 is `NULL AS <alias>`
//     (presumably for the branch of the query that has no external row —
//     confirm against the caller).
//
// Index 0 always describes the indirect sort on sd.IndirectFields[3]. Index
// i+1 describes orderByClauses[i].
//
// NOTE(review): assumes sd.IndirectFields has at least 4 entries. Each clause
// is assumed to have the shape `<prefix>.<name> <rest>` (e.g.
// `f."metadata.name" ASC`). A clause with no space, no "." in its first token,
// or an empty name after the "." panics with an index-out-of-range error.
// Label-sort clauses such as `(CASE lt1.label ...) ASC NULLS LAST` would fail
// this way; presumably callers never pass those — confirm.
// NOTE(review): a clause whose first token was already seen is skipped. That
// leaves "" at its index in all three slices; confirm callers tolerate it.
func processOrderByFields(sd *sqltypes.Sort, extIndex int, orderByClauses []string) ([]string, []string, []string) {
	// Maps each original sort expression to its alias; used to skip repeats.
	// The seeded entry is keyed by the bare external field name, so it can only
	// match a clause whose first token is exactly that name.
	sortFieldMap := make(map[string]string)
	externalFieldName := sd.IndirectFields[3]
	// Aliases must be plain identifiers, so any other characters become "_".
	newName := fmt.Sprintf("__ix_ext%d_%s", extIndex, nonIdentifierChars.ReplaceAllString(externalFieldName, "_"))
	sortFieldMap[externalFieldName] = newName
	sortParts := make([]string, 1+len(orderByClauses))
	// NULLs sort last when ascending and first when descending.
	direction := "ASC"
	nullPosition := "LAST"
	if sd.Order == sqltypes.DESC {
		direction = "DESC"
		nullPosition = "FIRST"
	}
	sortParts[0] = fmt.Sprintf("%s %s NULLS %s", newName, direction, nullPosition)
	importWithParts := make([]string, 1+len(orderByClauses))
	importWithParts[0] = fmt.Sprintf(`ext%d."%s" AS %s`, extIndex, externalFieldName, newName)
	importAsNullParts := make([]string, 1+len(orderByClauses))
	importAsNullParts[0] = fmt.Sprintf("NULL AS %s", newName)
	for i, clause := range orderByClauses {
		// Split "<expr> <direction...>" at the first space.
		orderParts := strings.SplitN(clause, " ", 2)
		fieldName := orderParts[0]
		_, ok := sortFieldMap[fieldName]
		if ok {
			// Repeated expression: leaves "" at index i+1 (see NOTE above).
			continue
		}
		// Split "<prefix>.<name>" at the first "." only, so dots inside a
		// quoted column name such as f."metadata.name" stay in the name.
		fieldParts := strings.SplitN(fieldName, ".", 2)
		prefix := fieldParts[0]
		baseName := fieldParts[1]
		// Strip surrounding double quotes; assumes a leading quote is matched
		// by a trailing one.
		if baseName[0] == '"' {
			baseName = baseName[1 : len(baseName)-1]
		}
		newBaseName := nonIdentifierChars.ReplaceAllString(baseName, "_")
		// This newName shadows the outer one (the external field's alias).
		newName := fmt.Sprintf("__ix_%s_%s", prefix, newBaseName)
		sortFieldMap[fieldName] = newName
		importWithParts[i+1] = fmt.Sprintf("%s AS %s", fieldName, newName)
		// Ordinary sort expressions are selected the same way in both variants.
		importAsNullParts[i+1] = importWithParts[i+1]
		// Keep the clause's original direction/NULLS text, now applied to the alias.
		sortParts[i+1] = fmt.Sprintf("%s %s", newName, orderParts[1])
	}
	return sortParts, importWithParts, importAsNullParts
}
// smartJoinSimpleName matches a path segment that can follow a "." without
// bracket quoting: a letter or underscore, then letters, digits or
// underscores. It is compiled once at package initialisation; smartJoin used
// to recompile the pattern on every call.
var smartJoinSimpleName = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`)

// smartJoin turns a field path into a single string. The output form depends
// on the last element of s:
//
//	simple:  ["a", "b", "conformsToIdentifier"] => "a.b.conformsToIdentifier"
//	complex: ["a", "b", "foo.io/stuff"]         => "a.b[foo.io/stuff]"
//
// An empty path yields "", and a one-element path is returned unchanged
// (whatever characters it contains).
func smartJoin(s []string) string {
	switch len(s) {
	case 0:
		return ""
	case 1:
		return s[0]
	}
	lastBit := s[len(s)-1]
	if smartJoinSimpleName.MatchString(lastBit) {
		return strings.Join(s, ".")
	}
	return fmt.Sprintf("%s[%s]", strings.Join(s[:len(s)-1], "."), lastBit)
}