diff --git a/staging/src/k8s.io/apiextensions-apiserver/pkg/apiserver/schema/cel/compilation.go b/staging/src/k8s.io/apiextensions-apiserver/pkg/apiserver/schema/cel/compilation.go index 0167c817a78..fae806d3318 100644 --- a/staging/src/k8s.io/apiextensions-apiserver/pkg/apiserver/schema/cel/compilation.go +++ b/staging/src/k8s.io/apiextensions-apiserver/pkg/apiserver/schema/cel/compilation.go @@ -117,17 +117,17 @@ func Compile(s *schema.Structural, isResourceRoot bool, perCallLimit uint64) ([] if err != nil { return nil, err } - estimator := celEstimator{root: root} + estimator := newCostEstimator(root) // compResults is the return value which saves a list of compilation results in the same order as x-kubernetes-validations rules. compResults := make([]CompilationResult, len(celRules)) for i, rule := range celRules { - compResults[i] = compileRule(rule, env, perCallLimit, &estimator) + compResults[i] = compileRule(rule, env, perCallLimit, estimator) } return compResults, nil } -func compileRule(rule apiextensions.ValidationRule, env *cel.Env, perCallLimit uint64, estimator *celEstimator) (compilationResult CompilationResult) { +func compileRule(rule apiextensions.ValidationRule, env *cel.Env, perCallLimit uint64, estimator *library.CostEstimator) (compilationResult CompilationResult) { if len(strings.TrimSpace(rule.Rule)) == 0 { // include a compilation result, but leave both program and error nil per documented return semantics of this // function @@ -157,7 +157,13 @@ func compileRule(rule apiextensions.ValidationRule, env *cel.Env, perCallLimit u } // TODO: Ideally we could configure the per expression limit at validation time and set it to the remaining overall budget, but we would either need a way to pass in a limit at evaluation time or move program creation to validation time - prog, err := env.Program(ast, cel.EvalOptions(cel.OptOptimize, cel.OptTrackCost), cel.CostLimit(perCallLimit), cel.InterruptCheckFrequency(checkFrequency)) + prog, err := 
env.Program(ast, + cel.EvalOptions(cel.OptOptimize, cel.OptTrackCost), + cel.CostLimit(perCallLimit), + cel.CostTracking(estimator), + cel.OptimizeRegex(library.ExtensionLibRegexOptimizations...), + cel.InterruptCheckFrequency(checkFrequency), + ) if err != nil { compilationResult.Error = &Error{ErrorTypeInvalid, "program instantiation failed: " + err.Error()} return @@ -180,11 +186,15 @@ func generateUniqueSelfTypeName() string { return fmt.Sprintf("selfType%d", time.Now().Nanosecond()) } -type celEstimator struct { +func newCostEstimator(root *celmodel.DeclType) *library.CostEstimator { + return &library.CostEstimator{SizeEstimator: &sizeEstimator{root: root}} +} + +type sizeEstimator struct { root *celmodel.DeclType } -func (c *celEstimator) EstimateSize(element checker.AstNode) *checker.SizeEstimate { +func (c *sizeEstimator) EstimateSize(element checker.AstNode) *checker.SizeEstimate { if len(element.Path()) == 0 { // Path() can return an empty list, early exit if it does since we can't // provide size estimates when that happens @@ -218,6 +228,6 @@ func (c *celEstimator) EstimateSize(element checker.AstNode) *checker.SizeEstima return &checker.SizeEstimate{Min: 0, Max: uint64(currentNode.MaxElements)} } -func (c *celEstimator) EstimateCallCost(function, overloadID string, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate { +func (c *sizeEstimator) EstimateCallCost(function, overloadID string, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate { return nil } diff --git a/staging/src/k8s.io/apiextensions-apiserver/pkg/apiserver/schema/cel/compilation_test.go b/staging/src/k8s.io/apiextensions-apiserver/pkg/apiserver/schema/cel/compilation_test.go index 112b1d4ea76..c0b914dae99 100644 --- a/staging/src/k8s.io/apiextensions-apiserver/pkg/apiserver/schema/cel/compilation_test.go +++ b/staging/src/k8s.io/apiextensions-apiserver/pkg/apiserver/schema/cel/compilation_test.go @@ -1073,7 +1073,10 @@ func 
genMapWithCustomItemRule(item *schema.Structural, rule string) func(maxProp } } -func schemaChecker(schema *schema.Structural, expectedCost uint64, calcLimit uint64, t *testing.T) func(t *testing.T) { +// schemaChecker checks the cost of the validation rule declared in the provided schema (it requires there be exactly one rule) +// and checks that the resulting cost equals the expectedCost if expectedCost is non-zero, and that the resulting cost is >= expectedCostExceedsLimit +// if expectedCostExceedsLimit is non-zero. Typically, only expectedCost or expectedCostExceedsLimit is non-zero, not both. +func schemaChecker(schema *schema.Structural, expectedCost uint64, expectedCostExceedsLimit uint64, t *testing.T) func(t *testing.T) { return func(t *testing.T) { // TODO(DangerOnTheRanger): if perCallLimit in compilation.go changes, this needs to change as well compilationResults, err := Compile(schema, false, uint64(math.MaxInt64)) @@ -1087,13 +1090,14 @@ func schemaChecker(schema *schema.Structural, expectedCost uint64, calcLimit uin if result.Error != nil { t.Errorf("Expected no compile-time error, got: %v", result.Error) } - if calcLimit == 0 { + if expectedCost > 0 { if result.MaxCost != expectedCost { t.Errorf("Wrong cost (expected %d, got %d)", expectedCost, result.MaxCost) } - } else { - if result.MaxCost < calcLimit { - t.Errorf("Cost did not exceed limit as expected (expected more than %d, got %d)", calcLimit, result.MaxCost) + } + if expectedCostExceedsLimit > 0 { + if result.MaxCost < expectedCostExceedsLimit { + t.Errorf("Cost did not exceed limit as expected (expected more than %d, got %d)", expectedCostExceedsLimit, result.MaxCost) } } } @@ -1101,12 +1105,17 @@ func schemaChecker(schema *schema.Structural, expectedCost uint64, calcLimit uin func TestCostEstimation(t *testing.T) { cases := []struct { - name string - schemaGenerator func(maxLength *int64) *schema.Structural + name string + schemaGenerator func(maxLength *int64) *schema.Structural + 
setMaxElements int64 + + // calc cost expectations are checked against the generated schema without any max element limits set expectedCalcCost uint64 - setMaxElements int64 - expectedSetCost uint64 expectCalcCostExceedsLimit uint64 + + // set cost expectations are checked against the generated schema with max element limits set + expectedSetCost uint64 + expectedSetCostExceedsLimit uint64 }{ { name: "number array with all", @@ -1233,7 +1242,6 @@ func TestCostEstimation(t *testing.T) { { name: "O(n^2) loop with numbers", schemaGenerator: genArrayWithRule("number", "self.all(x, self.all(y, true))"), - expectedCalcCost: 9895601504256, expectCalcCostExceedsLimit: costLimit, setMaxElements: 10, expectedSetCost: 352, @@ -1241,7 +1249,6 @@ func TestCostEstimation(t *testing.T) { { name: "O(n^3) loop with numbers", schemaGenerator: genArrayWithRule("number", "self.all(x, self.all(y, self.all(z, true)))"), - expectedCalcCost: 13499986500008999998, expectCalcCostExceedsLimit: costLimit, setMaxElements: 10, expectedSetCost: 3552, @@ -1512,6 +1519,80 @@ func TestCostEstimation(t *testing.T) { setMaxElements: 490, expectedSetCost: 0, }, + // Ensure library functions are integrated with size estimates by testing the interesting cases. 
+ { + name: "extended library regex find", + schemaGenerator: genStringWithRule("self.find('[0-9]+') == ''"), + expectedCalcCost: 629147, + setMaxElements: 10, + expectedSetCost: 11, + }, + { + name: "extended library join", + schemaGenerator: func(max *int64) *schema.Structural { + strType := withMaxLength(primitiveType("string", ""), max) + array := withMaxItems(arrayType("atomic", nil, &strType), max) + array = withRule(array, "self.join(' ') == 'aa bb'") + return &array + }, + expectedCalcCost: 329853068905, + setMaxElements: 10, + expectedSetCost: 43, + }, + { + name: "extended library isSorted", + schemaGenerator: func(max *int64) *schema.Structural { + strType := withMaxLength(primitiveType("string", ""), max) + array := withMaxItems(arrayType("atomic", nil, &strType), max) + array = withRule(array, "self.isSorted() == true") + return &array + }, + expectedCalcCost: 329854432052, + setMaxElements: 10, + expectedSetCost: 52, + }, + { + name: "extended library replace", + schemaGenerator: func(max *int64) *schema.Structural { + strType := withMaxLength(primitiveType("string", ""), max) + objType := objectType(map[string]schema.Structural{ + "str": strType, + "before": strType, + "after": strType, + }) + objType = withRule(objType, "self.str.replace(self.before, self.after) == 'does not matter'") + return &objType + }, + expectedCalcCost: 629154, + setMaxElements: 10, + expectedSetCost: 16, + }, + { + name: "extended library split", + schemaGenerator: func(max *int64) *schema.Structural { + strType := withMaxLength(primitiveType("string", ""), max) + objType := objectType(map[string]schema.Structural{ + "str": strType, + "separator": strType, + }) + objType = withRule(objType, "self.str.split(self.separator) == []") + return &objType + }, + expectedCalcCost: 629160, + setMaxElements: 10, + expectedSetCost: 22, + }, + { + name: "extended library lowerAscii", + schemaGenerator: func(max *int64) *schema.Structural { + strType := 
withMaxLength(primitiveType("string", ""), max) + strType = withRule(strType, "self.lowerAscii() == 'lower!'") + return &strType + }, + expectedCalcCost: 314575, + setMaxElements: 10, + expectedSetCost: 6, + }, } for _, testCase := range cases { t.Run(testCase.name, func(t *testing.T) { @@ -1520,7 +1601,7 @@ func TestCostEstimation(t *testing.T) { t.Run("calc maxLength", schemaChecker(schema, testCase.expectedCalcCost, testCase.expectCalcCostExceedsLimit, t)) // static maxLength case setSchema := testCase.schemaGenerator(&testCase.setMaxElements) - t.Run("set maxLength", schemaChecker(setSchema, testCase.expectedSetCost, 0, t)) + t.Run("set maxLength", schemaChecker(setSchema, testCase.expectedSetCost, testCase.expectedSetCostExceedsLimit, t)) }) } } diff --git a/staging/src/k8s.io/apiextensions-apiserver/pkg/apiserver/schema/cel/library/cost.go b/staging/src/k8s.io/apiextensions-apiserver/pkg/apiserver/schema/cel/library/cost.go new file mode 100644 index 00000000000..39098e3f605 --- /dev/null +++ b/staging/src/k8s.io/apiextensions-apiserver/pkg/apiserver/schema/cel/library/cost.go @@ -0,0 +1,268 @@ +/* +Copyright 2022 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package library + +import ( + "math" + + "github.com/google/cel-go/checker" + "github.com/google/cel-go/common" + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" + "github.com/google/cel-go/common/types/traits" + exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" +) + +// CostEstimator implements CEL's interpreter.ActualCostEstimator and checker.CostEstimator. +type CostEstimator struct { + // SizeEstimator provides a CostEstimator.EstimateSize that this CostEstimator will delegate size estimation + // calculations to if the size is not well known (i.e. a constant). + SizeEstimator checker.CostEstimator +} + +func (l *CostEstimator) CallCost(function, overloadId string, args []ref.Val, result ref.Val) *uint64 { + switch function { + case "isSorted", "sum", "max", "min", "indexOf", "lastIndexOf": + var cost uint64 + if len(args) > 0 { + cost += traversalCost(args[0]) // these O(n) operations all cost roughly the cost of a single traversal + } + return &cost + case "url", "lowerAscii", "upperAscii", "substring", "trim": + if len(args) >= 1 { + cost := uint64(math.Ceil(float64(actualSize(args[0])) * common.StringTraversalCostFactor)) + return &cost + } + case "replace", "split": + if len(args) >= 1 { + // cost is the traversal plus the construction of the result + cost := uint64(math.Ceil(float64(actualSize(args[0])) * 2 * common.StringTraversalCostFactor)) + return &cost + } + case "join": + if len(args) >= 1 { + cost := uint64(math.Ceil(float64(actualSize(result)) * 2 * common.StringTraversalCostFactor)) + return &cost + } + case "find", "findAll": + if len(args) >= 2 { + strCost := uint64(math.Ceil((1.0 + float64(actualSize(args[0]))) * common.StringTraversalCostFactor)) + // We don't know how many expressions are in the regex, just the string length (a huge + // improvement here would be to somehow get a count of the number of expressions in the regex or + // how many states are in the regex state machine 
and use that to measure regex cost). + // For now, we're making a guess that each expression in a regex is typically at least 4 chars + // in length. + regexCost := uint64(math.Ceil(float64(actualSize(args[1])) * common.RegexStringLengthCostFactor)) + cost := strCost * regexCost + return &cost + } + } + return nil +} + +func (l *CostEstimator) EstimateCallCost(function, overloadId string, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate { + // WARNING: Any changes to this code impact API compatibility! The estimated cost is used to determine which CEL rules may be written to a + // CRD and any change (cost increases and cost decreases) are breaking. + switch function { + case "isSorted", "sum", "max", "min", "indexOf", "lastIndexOf": + if target != nil { + // Charge 1 cost for comparing each element in the list + elCost := checker.CostEstimate{Min: 1, Max: 1} + // If the list contains strings or bytes, add the cost of traversing all the strings/bytes as a way + // of estimating the additional comparison cost. 
+ if elNode := l.listElementNode(*target); elNode != nil { + t := elNode.Type().GetPrimitive() + if t == exprpb.Type_STRING || t == exprpb.Type_BYTES { + sz := l.sizeEstimate(elNode) + elCost = elCost.Add(sz.MultiplyByCostFactor(common.StringTraversalCostFactor)) + } + return &checker.CallEstimate{CostEstimate: l.sizeEstimate(*target).MultiplyByCost(elCost)} + } else { // the target is a string, which is supported by indexOf and lastIndexOf + return &checker.CallEstimate{CostEstimate: l.sizeEstimate(*target).MultiplyByCostFactor(common.StringTraversalCostFactor)} + } + + } + case "url": + if len(args) == 1 { + sz := l.sizeEstimate(args[0]) + return &checker.CallEstimate{CostEstimate: sz.MultiplyByCostFactor(common.StringTraversalCostFactor)} + } + case "lowerAscii", "upperAscii", "substring", "trim": + if target != nil { + sz := l.sizeEstimate(*target) + return &checker.CallEstimate{CostEstimate: sz.MultiplyByCostFactor(common.StringTraversalCostFactor), ResultSize: &sz} + } + case "replace": + if target != nil && len(args) >= 2 { + sz := l.sizeEstimate(*target) + toReplaceSz := l.sizeEstimate(args[0]) + replaceWithSz := l.sizeEstimate(args[1]) + // smallest possible result: smallest input size composed of the largest possible substrings being replaced by smallest possible replacement + minSz := uint64(math.Ceil(float64(sz.Min)/float64(toReplaceSz.Max))) * replaceWithSz.Min + // largest possible result: largest input size composed of the smallest possible substrings being replaced by largest possible replacement. NOTE(review): toReplaceSz.Min can be 0, so this float division can yield +Inf/NaN before the uint64 conversion; confirm and guard. + maxSz := uint64(math.Ceil(float64(sz.Max)/float64(toReplaceSz.Min))) * replaceWithSz.Max + + // cost is the traversal plus the construction of the result + return &checker.CallEstimate{CostEstimate: sz.MultiplyByCostFactor(2 * common.StringTraversalCostFactor), ResultSize: &checker.SizeEstimate{Min: minSz, Max: maxSz}} + } + case "split": + if target != nil { + sz := l.sizeEstimate(*target) + + // Worst case size is where a separator of "" is 
used, and each char is returned as a list element. + max := sz.Max + if len(args) > 1 { + if c := args[1].Expr().GetConstExpr(); c != nil { + max = uint64(c.GetInt64Value()) + } + } + // Cost is the traversal plus the construction of the result. + return &checker.CallEstimate{CostEstimate: sz.MultiplyByCostFactor(2 * common.StringTraversalCostFactor), ResultSize: &checker.SizeEstimate{Min: 0, Max: max}} + } + case "join": + if target != nil { + var sz checker.SizeEstimate + listSize := l.sizeEstimate(*target) + if elNode := l.listElementNode(*target); elNode != nil { + elemSize := l.sizeEstimate(elNode) + sz = listSize.Multiply(elemSize) + } + + if len(args) > 0 { + sepSize := l.sizeEstimate(args[0]) + minSeparators := uint64(0) + maxSeparators := uint64(0) + if listSize.Min > 0 { + minSeparators = listSize.Min - 1 + } + if listSize.Max > 0 { + maxSeparators = listSize.Max - 1 + } + sz = sz.Add(sepSize.Multiply(checker.SizeEstimate{Min: minSeparators, Max: maxSeparators})) + } + + return &checker.CallEstimate{CostEstimate: sz.MultiplyByCostFactor(common.StringTraversalCostFactor), ResultSize: &sz} + } + case "find", "findAll": + if target != nil && len(args) >= 1 { + sz := l.sizeEstimate(*target) + // Add one to string length for purposes of cost calculation to prevent the product of string and regex cost from being 0 + // in case where string is empty but regex is still expensive. + strCost := sz.Add(checker.SizeEstimate{Min: 1, Max: 1}).MultiplyByCostFactor(common.StringTraversalCostFactor) + // We don't know how many expressions are in the regex, just the string length (a huge + // improvement here would be to somehow get a count of the number of expressions in the regex or + // how many states are in the regex state machine and use that to measure regex cost). + // For now, we're making a guess that each expression in a regex is typically at least 4 chars + // in length. 
+ regexCost := l.sizeEstimate(args[0]).MultiplyByCostFactor(common.RegexStringLengthCostFactor) + // worst case size of result is that every char is returned as separate find result. + return &checker.CallEstimate{CostEstimate: strCost.Multiply(regexCost), ResultSize: &checker.SizeEstimate{Min: 0, Max: sz.Max}} + } + } + return nil +} + +func actualSize(value ref.Val) uint64 { + if sz, ok := value.(traits.Sizer); ok { + return uint64(sz.Size().(types.Int)) + } + return 1 +} + +func (l *CostEstimator) sizeEstimate(t checker.AstNode) checker.SizeEstimate { + if sz := t.ComputedSize(); sz != nil { + return *sz + } + if sz := l.EstimateSize(t); sz != nil { + return *sz + } + return checker.SizeEstimate{Min: 0, Max: math.MaxUint64} +} + +func (l *CostEstimator) listElementNode(list checker.AstNode) checker.AstNode { + if lt := list.Type().GetListType(); lt != nil { + nodePath := list.Path() + if nodePath != nil { + // Provide path if we have it so that a OpenAPIv3 maxLength validation can be looked up, if it exists + // for this node. + path := make([]string, len(nodePath)+1) + copy(path, nodePath) + path[len(nodePath)] = "@items" + return &itemsNode{path: path, t: lt.GetElemType(), expr: nil} + } else { + // Provide just the type if no path is available so that worst case size can be looked up based on type. 
+ return &itemsNode{t: lt.GetElemType(), expr: nil} + } + } + return nil +} + +func (l *CostEstimator) EstimateSize(element checker.AstNode) *checker.SizeEstimate { + if l.SizeEstimator != nil { + return l.SizeEstimator.EstimateSize(element) + } + return nil +} + +type itemsNode struct { + path []string + t *exprpb.Type + expr *exprpb.Expr +} + +func (i *itemsNode) Path() []string { + return i.path +} + +func (i *itemsNode) Type() *exprpb.Type { + return i.t +} + +func (i *itemsNode) Expr() *exprpb.Expr { + return i.expr +} + +func (i *itemsNode) ComputedSize() *checker.SizeEstimate { + return nil +} + +// traversalCost computes the cost of traversing a ref.Val as a data tree. +func traversalCost(v ref.Val) uint64 { + // TODO: This could potentially be optimized by sampling maps and lists instead of traversing. + switch vt := v.(type) { + case types.String: + return uint64(float64(len(string(vt))) * common.StringTraversalCostFactor) + case types.Bytes: + return uint64(float64(len([]byte(vt))) * common.StringTraversalCostFactor) + case traits.Lister: + cost := uint64(0) + for it := vt.Iterator(); it.HasNext() == types.True; { + i := it.Next() + cost += traversalCost(i) + } + return cost + case traits.Mapper: // maps and objects + cost := uint64(0) + for it := vt.Iterator(); it.HasNext() == types.True; { + k := it.Next() + cost += traversalCost(k) + traversalCost(vt.Get(k)) + } + return cost + default: + return 1 + } +} diff --git a/staging/src/k8s.io/apiextensions-apiserver/pkg/apiserver/schema/cel/library/cost_test.go b/staging/src/k8s.io/apiextensions-apiserver/pkg/apiserver/schema/cel/library/cost_test.go new file mode 100644 index 00000000000..0b1e0020c63 --- /dev/null +++ b/staging/src/k8s.io/apiextensions-apiserver/pkg/apiserver/schema/cel/library/cost_test.go @@ -0,0 +1,363 @@ +/* +Copyright 2022 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package library + +import ( + "fmt" + "testing" + + "github.com/google/cel-go/cel" + "github.com/google/cel-go/checker" + "github.com/google/cel-go/ext" + expr "google.golang.org/genproto/googleapis/api/expr/v1alpha1" +) + +const ( + intListLiteral = "[1, 2, 3, 4, 5]" + uintListLiteral = "[uint(1), uint(2), uint(3), uint(4), uint(5)]" + doubleListLiteral = "[1.0, 2.0, 3.0, 4.0, 5.0]" + boolListLiteral = "[false, true, false, true, false]" + stringListLiteral = "['012345678901', '012345678901', '012345678901', '012345678901', '012345678901']" + bytesListLiteral = "[bytes('012345678901'), bytes('012345678901'), bytes('012345678901'), bytes('012345678901'), bytes('012345678901')]" + durationListLiteral = "[duration('1s'), duration('2s'), duration('3s'), duration('4s'), duration('5s')]" + timestampListLiteral = "[timestamp('2011-01-01T00:00:00.000+01:00'), timestamp('2011-01-02T00:00:00.000+01:00'), " + + "timestamp('2011-01-03T00:00:00.000+01:00'), timestamp('2011-01-04T00:00:00.000+01:00'), " + + "timestamp('2011-01-05T00:00:00.000+01:00')]" + stringLiteral = "'01234567890123456789012345678901234567890123456789'" +) + +type comparableCost struct { + comparableLiteral string + expectedEstimatedCost checker.CostEstimate + expectedRuntimeCost uint64 + + param string +} + +func TestListsCost(t *testing.T) { + cases := []struct { + opts []string + costs []comparableCost + }{ + { + opts: []string{".sum()"}, + // 10 cost for the list declaration, the rest is the due to the function call + costs: []comparableCost{ + { + comparableLiteral: intListLiteral, + 
expectedEstimatedCost: checker.CostEstimate{Min: 15, Max: 15}, expectedRuntimeCost: 15, + }, + { + comparableLiteral: uintListLiteral, + expectedEstimatedCost: checker.CostEstimate{Min: 20, Max: 20}, expectedRuntimeCost: 20, // +5 for casts + }, + { + comparableLiteral: doubleListLiteral, + expectedEstimatedCost: checker.CostEstimate{Min: 15, Max: 15}, expectedRuntimeCost: 15, + }, + { + comparableLiteral: durationListLiteral, + expectedEstimatedCost: checker.CostEstimate{Min: 20, Max: 20}, expectedRuntimeCost: 20, // +5 for casts + }, + }, + }, + { + opts: []string{".isSorted()", ".max()", ".min()"}, + // 10 cost for the list declaration, the rest is the due to the function call + costs: []comparableCost{ + { + comparableLiteral: intListLiteral, + expectedEstimatedCost: checker.CostEstimate{Min: 15, Max: 15}, expectedRuntimeCost: 15, + }, + { + comparableLiteral: uintListLiteral, + expectedEstimatedCost: checker.CostEstimate{Min: 20, Max: 20}, expectedRuntimeCost: 20, // +5 for numeric casts + }, + { + comparableLiteral: doubleListLiteral, + expectedEstimatedCost: checker.CostEstimate{Min: 15, Max: 15}, expectedRuntimeCost: 15, + }, + { + comparableLiteral: boolListLiteral, + expectedEstimatedCost: checker.CostEstimate{Min: 15, Max: 15}, expectedRuntimeCost: 15, + }, + { + comparableLiteral: stringListLiteral, + expectedEstimatedCost: checker.CostEstimate{Min: 15, Max: 25}, expectedRuntimeCost: 15, // +5 for string comparisons + }, + { + comparableLiteral: bytesListLiteral, + expectedEstimatedCost: checker.CostEstimate{Min: 25, Max: 35}, expectedRuntimeCost: 25, // +10 for casts from string to byte, +5 for byte comparisons + }, + { + comparableLiteral: durationListLiteral, + expectedEstimatedCost: checker.CostEstimate{Min: 20, Max: 20}, expectedRuntimeCost: 20, // +5 for numeric casts + }, + { + comparableLiteral: timestampListLiteral, + expectedEstimatedCost: checker.CostEstimate{Min: 20, Max: 20}, expectedRuntimeCost: 20, // +5 for casts + }, + }, + }, + } + for 
_, tc := range cases { + for _, op := range tc.opts { + for _, typ := range tc.costs { + t.Run(typ.comparableLiteral+op, func(t *testing.T) { + e := typ.comparableLiteral + op + testCost(t, e, typ.expectedEstimatedCost, typ.expectedRuntimeCost) + }) + } + } + } +} + +func TestIndexOfCost(t *testing.T) { + cases := []struct { + opts []string + costs []comparableCost + }{ + { + opts: []string{".indexOf(%s)", ".lastIndexOf(%s)"}, + // 10 cost for the list declaration, the rest is the due to the function call + costs: []comparableCost{ + { + comparableLiteral: intListLiteral, param: "3", + expectedEstimatedCost: checker.CostEstimate{Min: 15, Max: 15}, expectedRuntimeCost: 15, + }, + { + comparableLiteral: uintListLiteral, param: "uint(3)", + expectedEstimatedCost: checker.CostEstimate{Min: 21, Max: 21}, expectedRuntimeCost: 21, // +5 for numeric casts + }, + { + comparableLiteral: doubleListLiteral, param: "3.0", + expectedEstimatedCost: checker.CostEstimate{Min: 15, Max: 15}, expectedRuntimeCost: 15, + }, + { + comparableLiteral: boolListLiteral, param: "true", + expectedEstimatedCost: checker.CostEstimate{Min: 15, Max: 15}, expectedRuntimeCost: 15, + }, + { + comparableLiteral: stringListLiteral, param: "'x'", + expectedEstimatedCost: checker.CostEstimate{Min: 15, Max: 25}, expectedRuntimeCost: 15, // +5 for string comparisons + }, + { + comparableLiteral: bytesListLiteral, param: "bytes('x')", + expectedEstimatedCost: checker.CostEstimate{Min: 26, Max: 36}, expectedRuntimeCost: 26, // +11 for casts from string to byte, +5 for byte comparisons + }, + { + comparableLiteral: durationListLiteral, param: "duration('3s')", + expectedEstimatedCost: checker.CostEstimate{Min: 21, Max: 21}, expectedRuntimeCost: 21, // +6 for casts from duration to byte + }, + { + comparableLiteral: timestampListLiteral, param: "timestamp('2011-01-03T00:00:00.000+01:00')", + expectedEstimatedCost: checker.CostEstimate{Min: 21, Max: 21}, expectedRuntimeCost: 21, // +6 for casts from timestamp 
to byte + }, + + // index of operations are also defined for strings + { + comparableLiteral: stringLiteral, param: "'123'", + expectedEstimatedCost: checker.CostEstimate{Min: 5, Max: 5}, expectedRuntimeCost: 5, + }, + }, + }, + } + for _, tc := range cases { + for _, op := range tc.opts { + for _, typ := range tc.costs { + opWithParam := fmt.Sprintf(op, typ.param) + t.Run(typ.comparableLiteral+opWithParam, func(t *testing.T) { + e := typ.comparableLiteral + opWithParam + testCost(t, e, typ.expectedEstimatedCost, typ.expectedRuntimeCost) + }) + } + } + } +} + +func TestURLsCost(t *testing.T) { + cases := []struct { + ops []string + expectEsimatedCost checker.CostEstimate + expectRuntimeCost uint64 + }{ + { + ops: []string{".getScheme()", ".getHostname()", ".getHost()", ".getPort()", ".getEscapedPath()", ".getQuery()"}, + expectEsimatedCost: checker.CostEstimate{Min: 4, Max: 4}, + expectRuntimeCost: 4, + }, + } + + for _, tc := range cases { + for _, op := range tc.ops { + t.Run("url."+op, func(t *testing.T) { + testCost(t, "url('https:://kubernetes.io/')"+op, tc.expectEsimatedCost, tc.expectRuntimeCost) + }) + } + } +} + +func TestStringLibrary(t *testing.T) { + cases := []struct { + name string + expr string + expectEsimatedCost checker.CostEstimate + expectRuntimeCost uint64 + }{ + { + name: "lowerAscii", + expr: "'ABCDEFGHIJ abcdefghij'.lowerAscii()", + expectEsimatedCost: checker.CostEstimate{Min: 3, Max: 3}, + expectRuntimeCost: 3, + }, + { + name: "upperAscii", + expr: "'ABCDEFGHIJ abcdefghij'.upperAscii()", + expectEsimatedCost: checker.CostEstimate{Min: 3, Max: 3}, + expectRuntimeCost: 3, + }, + { + name: "replace", + expr: "'abc 123 def 123'.replace('123', '456')", + expectEsimatedCost: checker.CostEstimate{Min: 3, Max: 3}, + expectRuntimeCost: 3, + }, + { + name: "replace with limit", + expr: "'abc 123 def 123'.replace('123', '456', 1)", + expectEsimatedCost: checker.CostEstimate{Min: 3, Max: 3}, + expectRuntimeCost: 3, + }, + { + name: "split", + expr: 
"'abc 123 def 123'.split(' ')", + expectEsimatedCost: checker.CostEstimate{Min: 3, Max: 3}, + expectRuntimeCost: 3, + }, + { + name: "split with limit", + expr: "'abc 123 def 123'.split(' ', 1)", + expectEsimatedCost: checker.CostEstimate{Min: 3, Max: 3}, + expectRuntimeCost: 3, + }, + { + name: "substring", + expr: "'abc 123 def 123'.substring(5)", + expectEsimatedCost: checker.CostEstimate{Min: 2, Max: 2}, + expectRuntimeCost: 2, + }, + { + name: "substring with end", + expr: "'abc 123 def 123'.substring(5, 8)", + expectEsimatedCost: checker.CostEstimate{Min: 2, Max: 2}, + expectRuntimeCost: 2, + }, + { + name: "trim", + expr: "' abc 123 def 123 '.trim()", + expectEsimatedCost: checker.CostEstimate{Min: 2, Max: 2}, + expectRuntimeCost: 2, + }, + { + name: "join with separator", + expr: "['aa', 'bb', 'cc', 'd', 'e', 'f', 'g', 'h', 'i', 'j'].join(' ')", + expectEsimatedCost: checker.CostEstimate{Min: 11, Max: 23}, + expectRuntimeCost: 15, + }, + { + name: "join", + expr: "['aa', 'bb', 'cc', 'd', 'e', 'f', 'g', 'h', 'i', 'j'].join()", + expectEsimatedCost: checker.CostEstimate{Min: 10, Max: 22}, + expectRuntimeCost: 13, + }, + { + name: "find", + expr: "'abc 123 def 123'.find('123')", + expectEsimatedCost: checker.CostEstimate{Min: 2, Max: 2}, + expectRuntimeCost: 2, + }, + { + name: "findAll", + expr: "'abc 123 def 123'.findAll('123')", + expectEsimatedCost: checker.CostEstimate{Min: 2, Max: 2}, + expectRuntimeCost: 2, + }, + { + name: "findAll with limit", + expr: "'abc 123 def 123'.findAll('123', 1)", + expectEsimatedCost: checker.CostEstimate{Min: 2, Max: 2}, + expectRuntimeCost: 2, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + testCost(t, tc.expr, tc.expectEsimatedCost, tc.expectRuntimeCost) + }) + } +} + +func testCost(t *testing.T, expr string, expectEsimatedCost checker.CostEstimate, expectRuntimeCost uint64) { + est := &CostEstimator{SizeEstimator: &testCostEstimator{}} + env, err := cel.NewEnv(append(k8sExtensionLibs, 
ext.Strings())...) + if err != nil { + t.Fatalf("%v", err) + } + compiled, issues := env.Compile(expr) + if len(issues.Errors()) > 0 { + t.Fatalf("%v", issues.Errors()) + } + estCost, err := env.EstimateCost(compiled, est) + if err != nil { + t.Fatalf("%v", err) + } + if estCost.Min != expectEsimatedCost.Min || estCost.Max != expectEsimatedCost.Max { + t.Errorf("Expected estimated cost of %d..%d but got %d..%d", expectEsimatedCost.Min, expectEsimatedCost.Max, estCost.Min, estCost.Max) + } + prog, err := env.Program(compiled, cel.CostTracking(est)) + if err != nil { + t.Fatalf("%v", err) + } + _, details, err := prog.Eval(map[string]interface{}{}) + if err != nil { + t.Fatalf("%v", err) + } + cost := details.ActualCost() + if *cost != expectRuntimeCost { + t.Errorf("Expected cost of %d but got %d", expectRuntimeCost, *cost) + } +} + +type testCostEstimator struct { +} + +func (t *testCostEstimator) EstimateSize(element checker.AstNode) *checker.SizeEstimate { + switch t := element.Type().TypeKind.(type) { + case *expr.Type_Primitive: + switch t.Primitive { + case expr.Type_STRING: + return &checker.SizeEstimate{Min: 0, Max: 12} + case expr.Type_BYTES: + return &checker.SizeEstimate{Min: 0, Max: 12} + } + } + return nil +} + +func (t *testCostEstimator) EstimateCallCost(function, overloadId string, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate { + return nil +} diff --git a/staging/src/k8s.io/apiextensions-apiserver/pkg/apiserver/schema/cel/library/libraries.go b/staging/src/k8s.io/apiextensions-apiserver/pkg/apiserver/schema/cel/library/libraries.go index 31c2c42ab7c..18f6d7a7c2e 100644 --- a/staging/src/k8s.io/apiextensions-apiserver/pkg/apiserver/schema/cel/library/libraries.go +++ b/staging/src/k8s.io/apiextensions-apiserver/pkg/apiserver/schema/cel/library/libraries.go @@ -19,6 +19,7 @@ package library import ( "github.com/google/cel-go/cel" "github.com/google/cel-go/ext" + "github.com/google/cel-go/interpreter" ) // ExtensionLibs 
declares the set of CEL extension libraries available everywhere CEL is used in Kubernetes. @@ -29,3 +30,5 @@ var k8sExtensionLibs = []cel.EnvOption{ Regex(), Lists(), } + +var ExtensionLibRegexOptimizations = []*interpreter.RegexOptimization{FindRegexOptimization, FindAllRegexOptimization} diff --git a/staging/src/k8s.io/apiextensions-apiserver/pkg/apiserver/schema/cel/library/regex.go b/staging/src/k8s.io/apiextensions-apiserver/pkg/apiserver/schema/cel/library/regex.go index 5da192929f1..e44f82090ab 100644 --- a/staging/src/k8s.io/apiextensions-apiserver/pkg/apiserver/schema/cel/library/regex.go +++ b/staging/src/k8s.io/apiextensions-apiserver/pkg/apiserver/schema/cel/library/regex.go @@ -23,6 +23,7 @@ import ( "github.com/google/cel-go/checker/decls" "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" + "github.com/google/cel-go/interpreter" "github.com/google/cel-go/interpreter/functions" exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" ) @@ -156,3 +157,61 @@ func findAll(args ...ref.Val) ref.Val { return types.NewStringList(types.DefaultTypeAdapter, result) } + +// FindRegexOptimization optimizes the 'find' function by compiling the regex pattern and +// reporting any compilation errors at program creation time, and using the compiled regex pattern for all function +// call invocations. 
+var FindRegexOptimization = &interpreter.RegexOptimization{ + Function: "find", + RegexIndex: 1, + Factory: func(call interpreter.InterpretableCall, regexPattern string) (interpreter.InterpretableCall, error) { + compiledRegex, err := regexp.Compile(regexPattern) + if err != nil { + return nil, err + } + return interpreter.NewCall(call.ID(), call.Function(), call.OverloadID(), call.Args(), func(args ...ref.Val) ref.Val { + if len(args) != 2 { + return types.NoSuchOverloadErr() + } + in, ok := args[0].Value().(string) + if !ok { + return types.MaybeNoSuchOverloadErr(args[0]) + } + return types.String(compiledRegex.FindString(in)) + }), nil + }, +} + +// FindAllRegexOptimization optimizes the 'findAll' function by compiling the regex pattern and +// reporting any compilation errors at program creation time, and using the compiled regex pattern for all function +// call invocations. +var FindAllRegexOptimization = &interpreter.RegexOptimization{ + Function: "findAll", + RegexIndex: 1, + Factory: func(call interpreter.InterpretableCall, regexPattern string) (interpreter.InterpretableCall, error) { + compiledRegex, err := regexp.Compile(regexPattern) + if err != nil { + return nil, err + } + return interpreter.NewCall(call.ID(), call.Function(), call.OverloadID(), call.Args(), func(args ...ref.Val) ref.Val { + argn := len(args) + if argn < 2 || argn > 3 { + return types.NoSuchOverloadErr() + } + str, ok := args[0].Value().(string) + if !ok { + return types.MaybeNoSuchOverloadErr(args[0]) + } + n := int64(-1) + if argn == 3 { + n, ok = args[2].Value().(int64) + if !ok { + return types.MaybeNoSuchOverloadErr(args[2]) + } + } + + result := compiledRegex.FindAllString(str, int(n)) + return types.NewStringList(types.DefaultTypeAdapter, result) + }), nil + }, +} diff --git a/staging/src/k8s.io/apiextensions-apiserver/pkg/apiserver/schema/cel/validation_test.go b/staging/src/k8s.io/apiextensions-apiserver/pkg/apiserver/schema/cel/validation_test.go index 
efc54a3e1cb..c3ad6f5b13b 100644 --- a/staging/src/k8s.io/apiextensions-apiserver/pkg/apiserver/schema/cel/validation_test.go +++ b/staging/src/k8s.io/apiextensions-apiserver/pkg/apiserver/schema/cel/validation_test.go @@ -196,6 +196,10 @@ func TestValidationExpressions(t *testing.T) { "self.val1.upperAscii() == 'ROOK TAKES 👑'", "self.val1.lowerAscii() == 'rook takes 👑'", }, + errors: map[string]string{ + // Invalid regex with a string constant regex pattern is compile time error + "self.val1.matches(')')": "compile error: program instantiation failed: error parsing regexp: unexpected ): `)`", + }, }, {name: "escaped strings", obj: objs("l1\nl2", "l1\nl2"), @@ -1637,6 +1641,12 @@ func TestValidationExpressions(t *testing.T) { "self.str.findAll('xyz') == []", "self.str.findAll('xyz', 1) == []", }, + errors: map[string]string{ + // Invalid regex with a string constant regex pattern is compile time error + "self.str.find(')') == ''": "compile error: program instantiation failed: error parsing regexp: unexpected ): `)`", + "self.str.findAll(')') == []": "compile error: program instantiation failed: error parsing regexp: unexpected ): `)`", + "self.str.findAll(')', 1) == []": "compile error: program instantiation failed: error parsing regexp: unexpected ): `)`", + }, }, {name: "URL parsing", obj: map[string]interface{}{ @@ -2051,6 +2061,22 @@ func withRule(s schema.Structural, rule string) schema.Structural { return s } +func withMaxLength(s schema.Structural, maxLength *int64) schema.Structural { + if s.ValueValidation == nil { + s.ValueValidation = &schema.ValueValidation{} + } + s.ValueValidation.MaxLength = maxLength + return s +} + +func withMaxItems(s schema.Structural, maxItems *int64) schema.Structural { + if s.ValueValidation == nil { + s.ValueValidation = &schema.ValueValidation{} + } + s.ValueValidation.MaxItems = maxItems + return s +} + func withDefault(dflt interface{}, s schema.Structural) schema.Structural { s.Generic.Default = schema.JSON{Object: dflt} 
return s diff --git a/staging/src/k8s.io/apiextensions-apiserver/third_party/forked/celopenapi/model/schemas.go b/staging/src/k8s.io/apiextensions-apiserver/third_party/forked/celopenapi/model/schemas.go index b0ada9b1163..51895a64665 100644 --- a/staging/src/k8s.io/apiextensions-apiserver/third_party/forked/celopenapi/model/schemas.go +++ b/staging/src/k8s.io/apiextensions-apiserver/third_party/forked/celopenapi/model/schemas.go @@ -19,6 +19,7 @@ import ( "github.com/google/cel-go/checker/decls" "github.com/google/cel-go/common/types" + "k8s.io/apiextensions-apiserver/pkg/apiserver/schema" )