mirror of
https://github.com/k3s-io/kubernetes.git
synced 2025-07-20 18:31:15 +00:00
Merge pull request #108617 from jpbetz/function-ext-costs
CEL: Enable regex pre-compilation, Add cost estimates for function extension libraries
This commit is contained in:
commit
95e30f66c3
@ -117,17 +117,17 @@ func Compile(s *schema.Structural, isResourceRoot bool, perCallLimit uint64) ([]
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
estimator := celEstimator{root: root}
|
estimator := newCostEstimator(root)
|
||||||
// compResults is the return value which saves a list of compilation results in the same order as x-kubernetes-validations rules.
|
// compResults is the return value which saves a list of compilation results in the same order as x-kubernetes-validations rules.
|
||||||
compResults := make([]CompilationResult, len(celRules))
|
compResults := make([]CompilationResult, len(celRules))
|
||||||
for i, rule := range celRules {
|
for i, rule := range celRules {
|
||||||
compResults[i] = compileRule(rule, env, perCallLimit, &estimator)
|
compResults[i] = compileRule(rule, env, perCallLimit, estimator)
|
||||||
}
|
}
|
||||||
|
|
||||||
return compResults, nil
|
return compResults, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func compileRule(rule apiextensions.ValidationRule, env *cel.Env, perCallLimit uint64, estimator *celEstimator) (compilationResult CompilationResult) {
|
func compileRule(rule apiextensions.ValidationRule, env *cel.Env, perCallLimit uint64, estimator *library.CostEstimator) (compilationResult CompilationResult) {
|
||||||
if len(strings.TrimSpace(rule.Rule)) == 0 {
|
if len(strings.TrimSpace(rule.Rule)) == 0 {
|
||||||
// include a compilation result, but leave both program and error nil per documented return semantics of this
|
// include a compilation result, but leave both program and error nil per documented return semantics of this
|
||||||
// function
|
// function
|
||||||
@ -157,7 +157,13 @@ func compileRule(rule apiextensions.ValidationRule, env *cel.Env, perCallLimit u
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Ideally we could configure the per expression limit at validation time and set it to the remaining overall budget, but we would either need a way to pass in a limit at evaluation time or move program creation to validation time
|
// TODO: Ideally we could configure the per expression limit at validation time and set it to the remaining overall budget, but we would either need a way to pass in a limit at evaluation time or move program creation to validation time
|
||||||
prog, err := env.Program(ast, cel.EvalOptions(cel.OptOptimize, cel.OptTrackCost), cel.CostLimit(perCallLimit), cel.InterruptCheckFrequency(checkFrequency))
|
prog, err := env.Program(ast,
|
||||||
|
cel.EvalOptions(cel.OptOptimize, cel.OptTrackCost),
|
||||||
|
cel.CostLimit(perCallLimit),
|
||||||
|
cel.CostTracking(estimator),
|
||||||
|
cel.OptimizeRegex(library.ExtensionLibRegexOptimizations...),
|
||||||
|
cel.InterruptCheckFrequency(checkFrequency),
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
compilationResult.Error = &Error{ErrorTypeInvalid, "program instantiation failed: " + err.Error()}
|
compilationResult.Error = &Error{ErrorTypeInvalid, "program instantiation failed: " + err.Error()}
|
||||||
return
|
return
|
||||||
@ -180,11 +186,15 @@ func generateUniqueSelfTypeName() string {
|
|||||||
return fmt.Sprintf("selfType%d", time.Now().Nanosecond())
|
return fmt.Sprintf("selfType%d", time.Now().Nanosecond())
|
||||||
}
|
}
|
||||||
|
|
||||||
type celEstimator struct {
|
func newCostEstimator(root *celmodel.DeclType) *library.CostEstimator {
|
||||||
|
return &library.CostEstimator{SizeEstimator: &sizeEstimator{root: root}}
|
||||||
|
}
|
||||||
|
|
||||||
|
type sizeEstimator struct {
|
||||||
root *celmodel.DeclType
|
root *celmodel.DeclType
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *celEstimator) EstimateSize(element checker.AstNode) *checker.SizeEstimate {
|
func (c *sizeEstimator) EstimateSize(element checker.AstNode) *checker.SizeEstimate {
|
||||||
if len(element.Path()) == 0 {
|
if len(element.Path()) == 0 {
|
||||||
// Path() can return an empty list, early exit if it does since we can't
|
// Path() can return an empty list, early exit if it does since we can't
|
||||||
// provide size estimates when that happens
|
// provide size estimates when that happens
|
||||||
@ -218,6 +228,6 @@ func (c *celEstimator) EstimateSize(element checker.AstNode) *checker.SizeEstima
|
|||||||
return &checker.SizeEstimate{Min: 0, Max: uint64(currentNode.MaxElements)}
|
return &checker.SizeEstimate{Min: 0, Max: uint64(currentNode.MaxElements)}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *celEstimator) EstimateCallCost(function, overloadID string, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate {
|
func (c *sizeEstimator) EstimateCallCost(function, overloadID string, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -1073,7 +1073,10 @@ func genMapWithCustomItemRule(item *schema.Structural, rule string) func(maxProp
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func schemaChecker(schema *schema.Structural, expectedCost uint64, calcLimit uint64, t *testing.T) func(t *testing.T) {
|
// schemaChecker checks the cost of the validation rule declared in the provided schema (it requires there be exactly one rule)
|
||||||
|
// and checks that the resulting equals the expectedCost if expectedCost is non-zero, and that the resulting cost is >= expectedCostExceedsLimit
|
||||||
|
// if expectedCostExceedsLimit is non-zero. Typically, only expectedCost or expectedCostExceedsLimit is non-zero, not both.
|
||||||
|
func schemaChecker(schema *schema.Structural, expectedCost uint64, expectedCostExceedsLimit uint64, t *testing.T) func(t *testing.T) {
|
||||||
return func(t *testing.T) {
|
return func(t *testing.T) {
|
||||||
// TODO(DangerOnTheRanger): if perCallLimit in compilation.go changes, this needs to change as well
|
// TODO(DangerOnTheRanger): if perCallLimit in compilation.go changes, this needs to change as well
|
||||||
compilationResults, err := Compile(schema, false, uint64(math.MaxInt64))
|
compilationResults, err := Compile(schema, false, uint64(math.MaxInt64))
|
||||||
@ -1087,13 +1090,14 @@ func schemaChecker(schema *schema.Structural, expectedCost uint64, calcLimit uin
|
|||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
t.Errorf("Expected no compile-time error, got: %v", result.Error)
|
t.Errorf("Expected no compile-time error, got: %v", result.Error)
|
||||||
}
|
}
|
||||||
if calcLimit == 0 {
|
if expectedCost > 0 {
|
||||||
if result.MaxCost != expectedCost {
|
if result.MaxCost != expectedCost {
|
||||||
t.Errorf("Wrong cost (expected %d, got %d)", expectedCost, result.MaxCost)
|
t.Errorf("Wrong cost (expected %d, got %d)", expectedCost, result.MaxCost)
|
||||||
}
|
}
|
||||||
} else {
|
}
|
||||||
if result.MaxCost < calcLimit {
|
if expectedCostExceedsLimit > 0 {
|
||||||
t.Errorf("Cost did not exceed limit as expected (expected more than %d, got %d)", calcLimit, result.MaxCost)
|
if result.MaxCost < expectedCostExceedsLimit {
|
||||||
|
t.Errorf("Cost did not exceed limit as expected (expected more than %d, got %d)", expectedCostExceedsLimit, result.MaxCost)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1101,12 +1105,17 @@ func schemaChecker(schema *schema.Structural, expectedCost uint64, calcLimit uin
|
|||||||
|
|
||||||
func TestCostEstimation(t *testing.T) {
|
func TestCostEstimation(t *testing.T) {
|
||||||
cases := []struct {
|
cases := []struct {
|
||||||
name string
|
name string
|
||||||
schemaGenerator func(maxLength *int64) *schema.Structural
|
schemaGenerator func(maxLength *int64) *schema.Structural
|
||||||
|
setMaxElements int64
|
||||||
|
|
||||||
|
// calc costs expectations are checked against the generated schema without any max element limits set
|
||||||
expectedCalcCost uint64
|
expectedCalcCost uint64
|
||||||
setMaxElements int64
|
|
||||||
expectedSetCost uint64
|
|
||||||
expectCalcCostExceedsLimit uint64
|
expectCalcCostExceedsLimit uint64
|
||||||
|
|
||||||
|
// calc costs expectations are checked against the generated schema with max element limits set
|
||||||
|
expectedSetCost uint64
|
||||||
|
expectedSetCostExceedsLimit uint64
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "number array with all",
|
name: "number array with all",
|
||||||
@ -1233,7 +1242,6 @@ func TestCostEstimation(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "O(n^2) loop with numbers",
|
name: "O(n^2) loop with numbers",
|
||||||
schemaGenerator: genArrayWithRule("number", "self.all(x, self.all(y, true))"),
|
schemaGenerator: genArrayWithRule("number", "self.all(x, self.all(y, true))"),
|
||||||
expectedCalcCost: 9895601504256,
|
|
||||||
expectCalcCostExceedsLimit: costLimit,
|
expectCalcCostExceedsLimit: costLimit,
|
||||||
setMaxElements: 10,
|
setMaxElements: 10,
|
||||||
expectedSetCost: 352,
|
expectedSetCost: 352,
|
||||||
@ -1241,7 +1249,6 @@ func TestCostEstimation(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "O(n^3) loop with numbers",
|
name: "O(n^3) loop with numbers",
|
||||||
schemaGenerator: genArrayWithRule("number", "self.all(x, self.all(y, self.all(z, true)))"),
|
schemaGenerator: genArrayWithRule("number", "self.all(x, self.all(y, self.all(z, true)))"),
|
||||||
expectedCalcCost: 13499986500008999998,
|
|
||||||
expectCalcCostExceedsLimit: costLimit,
|
expectCalcCostExceedsLimit: costLimit,
|
||||||
setMaxElements: 10,
|
setMaxElements: 10,
|
||||||
expectedSetCost: 3552,
|
expectedSetCost: 3552,
|
||||||
@ -1512,6 +1519,80 @@ func TestCostEstimation(t *testing.T) {
|
|||||||
setMaxElements: 490,
|
setMaxElements: 490,
|
||||||
expectedSetCost: 0,
|
expectedSetCost: 0,
|
||||||
},
|
},
|
||||||
|
// Ensure library functions are integrated with size estimates by testing the interesting cases.
|
||||||
|
{
|
||||||
|
name: "extended library regex find",
|
||||||
|
schemaGenerator: genStringWithRule("self.find('[0-9]+') == ''"),
|
||||||
|
expectedCalcCost: 629147,
|
||||||
|
setMaxElements: 10,
|
||||||
|
expectedSetCost: 11,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "extended library join",
|
||||||
|
schemaGenerator: func(max *int64) *schema.Structural {
|
||||||
|
strType := withMaxLength(primitiveType("string", ""), max)
|
||||||
|
array := withMaxItems(arrayType("atomic", nil, &strType), max)
|
||||||
|
array = withRule(array, "self.join(' ') == 'aa bb'")
|
||||||
|
return &array
|
||||||
|
},
|
||||||
|
expectedCalcCost: 329853068905,
|
||||||
|
setMaxElements: 10,
|
||||||
|
expectedSetCost: 43,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "extended library isSorted",
|
||||||
|
schemaGenerator: func(max *int64) *schema.Structural {
|
||||||
|
strType := withMaxLength(primitiveType("string", ""), max)
|
||||||
|
array := withMaxItems(arrayType("atomic", nil, &strType), max)
|
||||||
|
array = withRule(array, "self.isSorted() == true")
|
||||||
|
return &array
|
||||||
|
},
|
||||||
|
expectedCalcCost: 329854432052,
|
||||||
|
setMaxElements: 10,
|
||||||
|
expectedSetCost: 52,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "extended library replace",
|
||||||
|
schemaGenerator: func(max *int64) *schema.Structural {
|
||||||
|
strType := withMaxLength(primitiveType("string", ""), max)
|
||||||
|
objType := objectType(map[string]schema.Structural{
|
||||||
|
"str": strType,
|
||||||
|
"before": strType,
|
||||||
|
"after": strType,
|
||||||
|
})
|
||||||
|
objType = withRule(objType, "self.str.replace(self.before, self.after) == 'does not matter'")
|
||||||
|
return &objType
|
||||||
|
},
|
||||||
|
expectedCalcCost: 629154,
|
||||||
|
setMaxElements: 10,
|
||||||
|
expectedSetCost: 16,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "extended library split",
|
||||||
|
schemaGenerator: func(max *int64) *schema.Structural {
|
||||||
|
strType := withMaxLength(primitiveType("string", ""), max)
|
||||||
|
objType := objectType(map[string]schema.Structural{
|
||||||
|
"str": strType,
|
||||||
|
"separator": strType,
|
||||||
|
})
|
||||||
|
objType = withRule(objType, "self.str.split(self.separator) == []")
|
||||||
|
return &objType
|
||||||
|
},
|
||||||
|
expectedCalcCost: 629160,
|
||||||
|
setMaxElements: 10,
|
||||||
|
expectedSetCost: 22,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "extended library lowerAscii",
|
||||||
|
schemaGenerator: func(max *int64) *schema.Structural {
|
||||||
|
strType := withMaxLength(primitiveType("string", ""), max)
|
||||||
|
strType = withRule(strType, "self.lowerAscii() == 'lower!'")
|
||||||
|
return &strType
|
||||||
|
},
|
||||||
|
expectedCalcCost: 314575,
|
||||||
|
setMaxElements: 10,
|
||||||
|
expectedSetCost: 6,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
for _, testCase := range cases {
|
for _, testCase := range cases {
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
@ -1520,7 +1601,7 @@ func TestCostEstimation(t *testing.T) {
|
|||||||
t.Run("calc maxLength", schemaChecker(schema, testCase.expectedCalcCost, testCase.expectCalcCostExceedsLimit, t))
|
t.Run("calc maxLength", schemaChecker(schema, testCase.expectedCalcCost, testCase.expectCalcCostExceedsLimit, t))
|
||||||
// static maxLength case
|
// static maxLength case
|
||||||
setSchema := testCase.schemaGenerator(&testCase.setMaxElements)
|
setSchema := testCase.schemaGenerator(&testCase.setMaxElements)
|
||||||
t.Run("set maxLength", schemaChecker(setSchema, testCase.expectedSetCost, 0, t))
|
t.Run("set maxLength", schemaChecker(setSchema, testCase.expectedSetCost, testCase.expectedSetCostExceedsLimit, t))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,268 @@
|
|||||||
|
/*
|
||||||
|
Copyright 2022 The Kubernetes Authors.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package library
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math"
|
||||||
|
|
||||||
|
"github.com/google/cel-go/checker"
|
||||||
|
"github.com/google/cel-go/common"
|
||||||
|
"github.com/google/cel-go/common/types"
|
||||||
|
"github.com/google/cel-go/common/types/ref"
|
||||||
|
"github.com/google/cel-go/common/types/traits"
|
||||||
|
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CostEstimator implements CEL's interpretable.ActualCostEstimator and checker.CostEstimator.
|
||||||
|
type CostEstimator struct {
|
||||||
|
// SizeEstimator provides a CostEstimator.EstimateSize that this CostEstimator will delegate size estimation
|
||||||
|
// calculations to if the size is not well known (i.e. a constant).
|
||||||
|
SizeEstimator checker.CostEstimator
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *CostEstimator) CallCost(function, overloadId string, args []ref.Val, result ref.Val) *uint64 {
|
||||||
|
switch function {
|
||||||
|
case "isSorted", "sum", "max", "min", "indexOf", "lastIndexOf":
|
||||||
|
var cost uint64
|
||||||
|
if len(args) > 0 {
|
||||||
|
cost += traversalCost(args[0]) // these O(n) operations all cost roughly the cost of a single traversal
|
||||||
|
}
|
||||||
|
return &cost
|
||||||
|
case "url", "lowerAscii", "upperAscii", "substring", "trim":
|
||||||
|
if len(args) >= 1 {
|
||||||
|
cost := uint64(math.Ceil(float64(actualSize(args[0])) * common.StringTraversalCostFactor))
|
||||||
|
return &cost
|
||||||
|
}
|
||||||
|
case "replace", "split":
|
||||||
|
if len(args) >= 1 {
|
||||||
|
// cost is the traversal plus the construction of the result
|
||||||
|
cost := uint64(math.Ceil(float64(actualSize(args[0])) * 2 * common.StringTraversalCostFactor))
|
||||||
|
return &cost
|
||||||
|
}
|
||||||
|
case "join":
|
||||||
|
if len(args) >= 1 {
|
||||||
|
cost := uint64(math.Ceil(float64(actualSize(result)) * 2 * common.StringTraversalCostFactor))
|
||||||
|
return &cost
|
||||||
|
}
|
||||||
|
case "find", "findAll":
|
||||||
|
if len(args) >= 2 {
|
||||||
|
strCost := uint64(math.Ceil((1.0 + float64(actualSize(args[0]))) * common.StringTraversalCostFactor))
|
||||||
|
// We don't know how many expressions are in the regex, just the string length (a huge
|
||||||
|
// improvement here would be to somehow get a count the number of expressions in the regex or
|
||||||
|
// how many states are in the regex state machine and use that to measure regex cost).
|
||||||
|
// For now, we're making a guess that each expression in a regex is typically at least 4 chars
|
||||||
|
// in length.
|
||||||
|
regexCost := uint64(math.Ceil(float64(actualSize(args[1])) * common.RegexStringLengthCostFactor))
|
||||||
|
cost := strCost * regexCost
|
||||||
|
return &cost
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *CostEstimator) EstimateCallCost(function, overloadId string, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate {
|
||||||
|
// WARNING: Any changes to this code impact API compatibility! The estimated cost is used to determine which CEL rules may be written to a
|
||||||
|
// CRD and any change (cost increases and cost decreases) are breaking.
|
||||||
|
switch function {
|
||||||
|
case "isSorted", "sum", "max", "min", "indexOf", "lastIndexOf":
|
||||||
|
if target != nil {
|
||||||
|
// Charge 1 cost for comparing each element in the list
|
||||||
|
elCost := checker.CostEstimate{Min: 1, Max: 1}
|
||||||
|
// If the list contains strings or bytes, add the cost of traversing all the strings/bytes as a way
|
||||||
|
// of estimating the additional comparison cost.
|
||||||
|
if elNode := l.listElementNode(*target); elNode != nil {
|
||||||
|
t := elNode.Type().GetPrimitive()
|
||||||
|
if t == exprpb.Type_STRING || t == exprpb.Type_BYTES {
|
||||||
|
sz := l.sizeEstimate(elNode)
|
||||||
|
elCost = elCost.Add(sz.MultiplyByCostFactor(common.StringTraversalCostFactor))
|
||||||
|
}
|
||||||
|
return &checker.CallEstimate{CostEstimate: l.sizeEstimate(*target).MultiplyByCost(elCost)}
|
||||||
|
} else { // the target is a string, which is supported by indexOf and lastIndexOf
|
||||||
|
return &checker.CallEstimate{CostEstimate: l.sizeEstimate(*target).MultiplyByCostFactor(common.StringTraversalCostFactor)}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
case "url":
|
||||||
|
if len(args) == 1 {
|
||||||
|
sz := l.sizeEstimate(args[0])
|
||||||
|
return &checker.CallEstimate{CostEstimate: sz.MultiplyByCostFactor(common.StringTraversalCostFactor)}
|
||||||
|
}
|
||||||
|
case "lowerAscii", "upperAscii", "substring", "trim":
|
||||||
|
if target != nil {
|
||||||
|
sz := l.sizeEstimate(*target)
|
||||||
|
return &checker.CallEstimate{CostEstimate: sz.MultiplyByCostFactor(common.StringTraversalCostFactor), ResultSize: &sz}
|
||||||
|
}
|
||||||
|
case "replace":
|
||||||
|
if target != nil && len(args) >= 2 {
|
||||||
|
sz := l.sizeEstimate(*target)
|
||||||
|
toReplaceSz := l.sizeEstimate(args[0])
|
||||||
|
replaceWithSz := l.sizeEstimate(args[1])
|
||||||
|
// smallest possible result: smallest input size composed of the largest possible substrings being replaced by smallest possible replacement
|
||||||
|
minSz := uint64(math.Ceil(float64(sz.Min)/float64(toReplaceSz.Max))) * replaceWithSz.Min
|
||||||
|
// largest possible result: largest input size composed of the smallest possible substrings being replaced by largest possible replacement
|
||||||
|
maxSz := uint64(math.Ceil(float64(sz.Max)/float64(toReplaceSz.Min))) * replaceWithSz.Max
|
||||||
|
|
||||||
|
// cost is the traversal plus the construction of the result
|
||||||
|
return &checker.CallEstimate{CostEstimate: sz.MultiplyByCostFactor(2 * common.StringTraversalCostFactor), ResultSize: &checker.SizeEstimate{Min: minSz, Max: maxSz}}
|
||||||
|
}
|
||||||
|
case "split":
|
||||||
|
if target != nil {
|
||||||
|
sz := l.sizeEstimate(*target)
|
||||||
|
|
||||||
|
// Worst case size is where is that a separator of "" is used, and each char is returned as a list element.
|
||||||
|
max := sz.Max
|
||||||
|
if len(args) > 1 {
|
||||||
|
if c := args[1].Expr().GetConstExpr(); c != nil {
|
||||||
|
max = uint64(c.GetInt64Value())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Cost is the traversal plus the construction of the result.
|
||||||
|
return &checker.CallEstimate{CostEstimate: sz.MultiplyByCostFactor(2 * common.StringTraversalCostFactor), ResultSize: &checker.SizeEstimate{Min: 0, Max: max}}
|
||||||
|
}
|
||||||
|
case "join":
|
||||||
|
if target != nil {
|
||||||
|
var sz checker.SizeEstimate
|
||||||
|
listSize := l.sizeEstimate(*target)
|
||||||
|
if elNode := l.listElementNode(*target); elNode != nil {
|
||||||
|
elemSize := l.sizeEstimate(elNode)
|
||||||
|
sz = listSize.Multiply(elemSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(args) > 0 {
|
||||||
|
sepSize := l.sizeEstimate(args[0])
|
||||||
|
minSeparators := uint64(0)
|
||||||
|
maxSeparators := uint64(0)
|
||||||
|
if listSize.Min > 0 {
|
||||||
|
minSeparators = listSize.Min - 1
|
||||||
|
}
|
||||||
|
if listSize.Max > 0 {
|
||||||
|
maxSeparators = listSize.Max - 1
|
||||||
|
}
|
||||||
|
sz = sz.Add(sepSize.Multiply(checker.SizeEstimate{Min: minSeparators, Max: maxSeparators}))
|
||||||
|
}
|
||||||
|
|
||||||
|
return &checker.CallEstimate{CostEstimate: sz.MultiplyByCostFactor(common.StringTraversalCostFactor), ResultSize: &sz}
|
||||||
|
}
|
||||||
|
case "find", "findAll":
|
||||||
|
if target != nil && len(args) >= 1 {
|
||||||
|
sz := l.sizeEstimate(*target)
|
||||||
|
// Add one to string length for purposes of cost calculation to prevent product of string and regex to be 0
|
||||||
|
// in case where string is empty but regex is still expensive.
|
||||||
|
strCost := sz.Add(checker.SizeEstimate{Min: 1, Max: 1}).MultiplyByCostFactor(common.StringTraversalCostFactor)
|
||||||
|
// We don't know how many expressions are in the regex, just the string length (a huge
|
||||||
|
// improvement here would be to somehow get a count the number of expressions in the regex or
|
||||||
|
// how many states are in the regex state machine and use that to measure regex cost).
|
||||||
|
// For now, we're making a guess that each expression in a regex is typically at least 4 chars
|
||||||
|
// in length.
|
||||||
|
regexCost := l.sizeEstimate(args[0]).MultiplyByCostFactor(common.RegexStringLengthCostFactor)
|
||||||
|
// worst case size of result is that every char is returned as separate find result.
|
||||||
|
return &checker.CallEstimate{CostEstimate: strCost.Multiply(regexCost), ResultSize: &checker.SizeEstimate{Min: 0, Max: sz.Max}}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func actualSize(value ref.Val) uint64 {
|
||||||
|
if sz, ok := value.(traits.Sizer); ok {
|
||||||
|
return uint64(sz.Size().(types.Int))
|
||||||
|
}
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *CostEstimator) sizeEstimate(t checker.AstNode) checker.SizeEstimate {
|
||||||
|
if sz := t.ComputedSize(); sz != nil {
|
||||||
|
return *sz
|
||||||
|
}
|
||||||
|
if sz := l.EstimateSize(t); sz != nil {
|
||||||
|
return *sz
|
||||||
|
}
|
||||||
|
return checker.SizeEstimate{Min: 0, Max: math.MaxUint64}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *CostEstimator) listElementNode(list checker.AstNode) checker.AstNode {
|
||||||
|
if lt := list.Type().GetListType(); lt != nil {
|
||||||
|
nodePath := list.Path()
|
||||||
|
if nodePath != nil {
|
||||||
|
// Provide path if we have it so that a OpenAPIv3 maxLength validation can be looked up, if it exists
|
||||||
|
// for this node.
|
||||||
|
path := make([]string, len(nodePath)+1)
|
||||||
|
copy(path, nodePath)
|
||||||
|
path[len(nodePath)] = "@items"
|
||||||
|
return &itemsNode{path: path, t: lt.GetElemType(), expr: nil}
|
||||||
|
} else {
|
||||||
|
// Provide just the type if no path is available so that worst case size can be looked up based on type.
|
||||||
|
return &itemsNode{t: lt.GetElemType(), expr: nil}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *CostEstimator) EstimateSize(element checker.AstNode) *checker.SizeEstimate {
|
||||||
|
if l.SizeEstimator != nil {
|
||||||
|
return l.SizeEstimator.EstimateSize(element)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type itemsNode struct {
|
||||||
|
path []string
|
||||||
|
t *exprpb.Type
|
||||||
|
expr *exprpb.Expr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *itemsNode) Path() []string {
|
||||||
|
return i.path
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *itemsNode) Type() *exprpb.Type {
|
||||||
|
return i.t
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *itemsNode) Expr() *exprpb.Expr {
|
||||||
|
return i.expr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *itemsNode) ComputedSize() *checker.SizeEstimate {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// traversalCost computes the cost of traversing a ref.Val as a data tree.
|
||||||
|
func traversalCost(v ref.Val) uint64 {
|
||||||
|
// TODO: This could potentially be optimized by sampling maps and lists instead of traversing.
|
||||||
|
switch vt := v.(type) {
|
||||||
|
case types.String:
|
||||||
|
return uint64(float64(len(string(vt))) * common.StringTraversalCostFactor)
|
||||||
|
case types.Bytes:
|
||||||
|
return uint64(float64(len([]byte(vt))) * common.StringTraversalCostFactor)
|
||||||
|
case traits.Lister:
|
||||||
|
cost := uint64(0)
|
||||||
|
for it := vt.Iterator(); it.HasNext() == types.True; {
|
||||||
|
i := it.Next()
|
||||||
|
cost += traversalCost(i)
|
||||||
|
}
|
||||||
|
return cost
|
||||||
|
case traits.Mapper: // maps and objects
|
||||||
|
cost := uint64(0)
|
||||||
|
for it := vt.Iterator(); it.HasNext() == types.True; {
|
||||||
|
k := it.Next()
|
||||||
|
cost += traversalCost(k) + traversalCost(vt.Get(k))
|
||||||
|
}
|
||||||
|
return cost
|
||||||
|
default:
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,363 @@
|
|||||||
|
/*
|
||||||
|
Copyright 2022 The Kubernetes Authors.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package library
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/cel-go/cel"
|
||||||
|
"github.com/google/cel-go/checker"
|
||||||
|
"github.com/google/cel-go/ext"
|
||||||
|
expr "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
intListLiteral = "[1, 2, 3, 4, 5]"
|
||||||
|
uintListLiteral = "[uint(1), uint(2), uint(3), uint(4), uint(5)]"
|
||||||
|
doubleListLiteral = "[1.0, 2.0, 3.0, 4.0, 5.0]"
|
||||||
|
boolListLiteral = "[false, true, false, true, false]"
|
||||||
|
stringListLiteral = "['012345678901', '012345678901', '012345678901', '012345678901', '012345678901']"
|
||||||
|
bytesListLiteral = "[bytes('012345678901'), bytes('012345678901'), bytes('012345678901'), bytes('012345678901'), bytes('012345678901')]"
|
||||||
|
durationListLiteral = "[duration('1s'), duration('2s'), duration('3s'), duration('4s'), duration('5s')]"
|
||||||
|
timestampListLiteral = "[timestamp('2011-01-01T00:00:00.000+01:00'), timestamp('2011-01-02T00:00:00.000+01:00'), " +
|
||||||
|
"timestamp('2011-01-03T00:00:00.000+01:00'), timestamp('2011-01-04T00:00:00.000+01:00'), " +
|
||||||
|
"timestamp('2011-01-05T00:00:00.000+01:00')]"
|
||||||
|
stringLiteral = "'01234567890123456789012345678901234567890123456789'"
|
||||||
|
)
|
||||||
|
|
||||||
|
type comparableCost struct {
|
||||||
|
comparableLiteral string
|
||||||
|
expectedEstimatedCost checker.CostEstimate
|
||||||
|
expectedRuntimeCost uint64
|
||||||
|
|
||||||
|
param string
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListsCost(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
opts []string
|
||||||
|
costs []comparableCost
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
opts: []string{".sum()"},
|
||||||
|
// 10 cost for the list declaration, the rest is the due to the function call
|
||||||
|
costs: []comparableCost{
|
||||||
|
{
|
||||||
|
comparableLiteral: intListLiteral,
|
||||||
|
expectedEstimatedCost: checker.CostEstimate{Min: 15, Max: 15}, expectedRuntimeCost: 15,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
comparableLiteral: uintListLiteral,
|
||||||
|
expectedEstimatedCost: checker.CostEstimate{Min: 20, Max: 20}, expectedRuntimeCost: 20, // +5 for casts
|
||||||
|
},
|
||||||
|
{
|
||||||
|
comparableLiteral: doubleListLiteral,
|
||||||
|
expectedEstimatedCost: checker.CostEstimate{Min: 15, Max: 15}, expectedRuntimeCost: 15,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
comparableLiteral: durationListLiteral,
|
||||||
|
expectedEstimatedCost: checker.CostEstimate{Min: 20, Max: 20}, expectedRuntimeCost: 20, // +5 for casts
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
opts: []string{".isSorted()", ".max()", ".min()"},
|
||||||
|
// 10 cost for the list declaration, the rest is the due to the function call
|
||||||
|
costs: []comparableCost{
|
||||||
|
{
|
||||||
|
comparableLiteral: intListLiteral,
|
||||||
|
expectedEstimatedCost: checker.CostEstimate{Min: 15, Max: 15}, expectedRuntimeCost: 15,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
comparableLiteral: uintListLiteral,
|
||||||
|
expectedEstimatedCost: checker.CostEstimate{Min: 20, Max: 20}, expectedRuntimeCost: 20, // +5 for numeric casts
|
||||||
|
},
|
||||||
|
{
|
||||||
|
comparableLiteral: doubleListLiteral,
|
||||||
|
expectedEstimatedCost: checker.CostEstimate{Min: 15, Max: 15}, expectedRuntimeCost: 15,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
comparableLiteral: boolListLiteral,
|
||||||
|
expectedEstimatedCost: checker.CostEstimate{Min: 15, Max: 15}, expectedRuntimeCost: 15,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
comparableLiteral: stringListLiteral,
|
||||||
|
expectedEstimatedCost: checker.CostEstimate{Min: 15, Max: 25}, expectedRuntimeCost: 15, // +5 for string comparisons
|
||||||
|
},
|
||||||
|
{
|
||||||
|
comparableLiteral: bytesListLiteral,
|
||||||
|
expectedEstimatedCost: checker.CostEstimate{Min: 25, Max: 35}, expectedRuntimeCost: 25, // +10 for casts from string to byte, +5 for byte comparisons
|
||||||
|
},
|
||||||
|
{
|
||||||
|
comparableLiteral: durationListLiteral,
|
||||||
|
expectedEstimatedCost: checker.CostEstimate{Min: 20, Max: 20}, expectedRuntimeCost: 20, // +5 for numeric casts
|
||||||
|
},
|
||||||
|
{
|
||||||
|
comparableLiteral: timestampListLiteral,
|
||||||
|
expectedEstimatedCost: checker.CostEstimate{Min: 20, Max: 20}, expectedRuntimeCost: 20, // +5 for casts
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tc := range cases {
|
||||||
|
for _, op := range tc.opts {
|
||||||
|
for _, typ := range tc.costs {
|
||||||
|
t.Run(typ.comparableLiteral+op, func(t *testing.T) {
|
||||||
|
e := typ.comparableLiteral + op
|
||||||
|
testCost(t, e, typ.expectedEstimatedCost, typ.expectedRuntimeCost)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIndexOfCost(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
opts []string
|
||||||
|
costs []comparableCost
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
opts: []string{".indexOf(%s)", ".lastIndexOf(%s)"},
|
||||||
|
// 10 cost for the list declaration, the rest is the due to the function call
|
||||||
|
costs: []comparableCost{
|
||||||
|
{
|
||||||
|
comparableLiteral: intListLiteral, param: "3",
|
||||||
|
expectedEstimatedCost: checker.CostEstimate{Min: 15, Max: 15}, expectedRuntimeCost: 15,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
comparableLiteral: uintListLiteral, param: "uint(3)",
|
||||||
|
expectedEstimatedCost: checker.CostEstimate{Min: 21, Max: 21}, expectedRuntimeCost: 21, // +5 for numeric casts
|
||||||
|
},
|
||||||
|
{
|
||||||
|
comparableLiteral: doubleListLiteral, param: "3.0",
|
||||||
|
expectedEstimatedCost: checker.CostEstimate{Min: 15, Max: 15}, expectedRuntimeCost: 15,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
comparableLiteral: boolListLiteral, param: "true",
|
||||||
|
expectedEstimatedCost: checker.CostEstimate{Min: 15, Max: 15}, expectedRuntimeCost: 15,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
comparableLiteral: stringListLiteral, param: "'x'",
|
||||||
|
expectedEstimatedCost: checker.CostEstimate{Min: 15, Max: 25}, expectedRuntimeCost: 15, // +5 for string comparisons
|
||||||
|
},
|
||||||
|
{
|
||||||
|
comparableLiteral: bytesListLiteral, param: "bytes('x')",
|
||||||
|
expectedEstimatedCost: checker.CostEstimate{Min: 26, Max: 36}, expectedRuntimeCost: 26, // +11 for casts from string to byte, +5 for byte comparisons
|
||||||
|
},
|
||||||
|
{
|
||||||
|
comparableLiteral: durationListLiteral, param: "duration('3s')",
|
||||||
|
expectedEstimatedCost: checker.CostEstimate{Min: 21, Max: 21}, expectedRuntimeCost: 21, // +6 for casts from duration to byte
|
||||||
|
},
|
||||||
|
{
|
||||||
|
comparableLiteral: timestampListLiteral, param: "timestamp('2011-01-03T00:00:00.000+01:00')",
|
||||||
|
expectedEstimatedCost: checker.CostEstimate{Min: 21, Max: 21}, expectedRuntimeCost: 21, // +6 for casts from timestamp to byte
|
||||||
|
},
|
||||||
|
|
||||||
|
// index of operations are also defined for strings
|
||||||
|
{
|
||||||
|
comparableLiteral: stringLiteral, param: "'123'",
|
||||||
|
expectedEstimatedCost: checker.CostEstimate{Min: 5, Max: 5}, expectedRuntimeCost: 5,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tc := range cases {
|
||||||
|
for _, op := range tc.opts {
|
||||||
|
for _, typ := range tc.costs {
|
||||||
|
opWithParam := fmt.Sprintf(op, typ.param)
|
||||||
|
t.Run(typ.comparableLiteral+opWithParam, func(t *testing.T) {
|
||||||
|
e := typ.comparableLiteral + opWithParam
|
||||||
|
testCost(t, e, typ.expectedEstimatedCost, typ.expectedRuntimeCost)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestURLsCost(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
ops []string
|
||||||
|
expectEsimatedCost checker.CostEstimate
|
||||||
|
expectRuntimeCost uint64
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
ops: []string{".getScheme()", ".getHostname()", ".getHost()", ".getPort()", ".getEscapedPath()", ".getQuery()"},
|
||||||
|
expectEsimatedCost: checker.CostEstimate{Min: 4, Max: 4},
|
||||||
|
expectRuntimeCost: 4,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
for _, op := range tc.ops {
|
||||||
|
t.Run("url."+op, func(t *testing.T) {
|
||||||
|
testCost(t, "url('https:://kubernetes.io/')"+op, tc.expectEsimatedCost, tc.expectRuntimeCost)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStringLibrary(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
expr string
|
||||||
|
expectEsimatedCost checker.CostEstimate
|
||||||
|
expectRuntimeCost uint64
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "lowerAscii",
|
||||||
|
expr: "'ABCDEFGHIJ abcdefghij'.lowerAscii()",
|
||||||
|
expectEsimatedCost: checker.CostEstimate{Min: 3, Max: 3},
|
||||||
|
expectRuntimeCost: 3,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "upperAscii",
|
||||||
|
expr: "'ABCDEFGHIJ abcdefghij'.upperAscii()",
|
||||||
|
expectEsimatedCost: checker.CostEstimate{Min: 3, Max: 3},
|
||||||
|
expectRuntimeCost: 3,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "replace",
|
||||||
|
expr: "'abc 123 def 123'.replace('123', '456')",
|
||||||
|
expectEsimatedCost: checker.CostEstimate{Min: 3, Max: 3},
|
||||||
|
expectRuntimeCost: 3,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "replace with limit",
|
||||||
|
expr: "'abc 123 def 123'.replace('123', '456', 1)",
|
||||||
|
expectEsimatedCost: checker.CostEstimate{Min: 3, Max: 3},
|
||||||
|
expectRuntimeCost: 3,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "split",
|
||||||
|
expr: "'abc 123 def 123'.split(' ')",
|
||||||
|
expectEsimatedCost: checker.CostEstimate{Min: 3, Max: 3},
|
||||||
|
expectRuntimeCost: 3,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "split with limit",
|
||||||
|
expr: "'abc 123 def 123'.split(' ', 1)",
|
||||||
|
expectEsimatedCost: checker.CostEstimate{Min: 3, Max: 3},
|
||||||
|
expectRuntimeCost: 3,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "substring",
|
||||||
|
expr: "'abc 123 def 123'.substring(5)",
|
||||||
|
expectEsimatedCost: checker.CostEstimate{Min: 2, Max: 2},
|
||||||
|
expectRuntimeCost: 2,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "substring with end",
|
||||||
|
expr: "'abc 123 def 123'.substring(5, 8)",
|
||||||
|
expectEsimatedCost: checker.CostEstimate{Min: 2, Max: 2},
|
||||||
|
expectRuntimeCost: 2,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "trim",
|
||||||
|
expr: "' abc 123 def 123 '.trim()",
|
||||||
|
expectEsimatedCost: checker.CostEstimate{Min: 2, Max: 2},
|
||||||
|
expectRuntimeCost: 2,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "join with separator",
|
||||||
|
expr: "['aa', 'bb', 'cc', 'd', 'e', 'f', 'g', 'h', 'i', 'j'].join(' ')",
|
||||||
|
expectEsimatedCost: checker.CostEstimate{Min: 11, Max: 23},
|
||||||
|
expectRuntimeCost: 15,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "join",
|
||||||
|
expr: "['aa', 'bb', 'cc', 'd', 'e', 'f', 'g', 'h', 'i', 'j'].join()",
|
||||||
|
expectEsimatedCost: checker.CostEstimate{Min: 10, Max: 22},
|
||||||
|
expectRuntimeCost: 13,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "find",
|
||||||
|
expr: "'abc 123 def 123'.find('123')",
|
||||||
|
expectEsimatedCost: checker.CostEstimate{Min: 2, Max: 2},
|
||||||
|
expectRuntimeCost: 2,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "findAll",
|
||||||
|
expr: "'abc 123 def 123'.findAll('123')",
|
||||||
|
expectEsimatedCost: checker.CostEstimate{Min: 2, Max: 2},
|
||||||
|
expectRuntimeCost: 2,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "findAll with limit",
|
||||||
|
expr: "'abc 123 def 123'.findAll('123', 1)",
|
||||||
|
expectEsimatedCost: checker.CostEstimate{Min: 2, Max: 2},
|
||||||
|
expectRuntimeCost: 2,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
testCost(t, tc.expr, tc.expectEsimatedCost, tc.expectRuntimeCost)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testCost(t *testing.T, expr string, expectEsimatedCost checker.CostEstimate, expectRuntimeCost uint64) {
|
||||||
|
est := &CostEstimator{SizeEstimator: &testCostEstimator{}}
|
||||||
|
env, err := cel.NewEnv(append(k8sExtensionLibs, ext.Strings())...)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("%v", err)
|
||||||
|
}
|
||||||
|
compiled, issues := env.Compile(expr)
|
||||||
|
if len(issues.Errors()) > 0 {
|
||||||
|
t.Fatalf("%v", issues.Errors())
|
||||||
|
}
|
||||||
|
estCost, err := env.EstimateCost(compiled, est)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("%v", err)
|
||||||
|
}
|
||||||
|
if estCost.Min != expectEsimatedCost.Min || estCost.Max != expectEsimatedCost.Max {
|
||||||
|
t.Errorf("Expected estimated cost of %d..%d but got %d..%d", expectEsimatedCost.Min, expectEsimatedCost.Max, estCost.Min, estCost.Max)
|
||||||
|
}
|
||||||
|
prog, err := env.Program(compiled, cel.CostTracking(est))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("%v", err)
|
||||||
|
}
|
||||||
|
_, details, err := prog.Eval(map[string]interface{}{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("%v", err)
|
||||||
|
}
|
||||||
|
cost := details.ActualCost()
|
||||||
|
if *cost != expectRuntimeCost {
|
||||||
|
t.Errorf("Expected cost of %d but got %d", expectRuntimeCost, *cost)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type testCostEstimator struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *testCostEstimator) EstimateSize(element checker.AstNode) *checker.SizeEstimate {
|
||||||
|
switch t := element.Type().TypeKind.(type) {
|
||||||
|
case *expr.Type_Primitive:
|
||||||
|
switch t.Primitive {
|
||||||
|
case expr.Type_STRING:
|
||||||
|
return &checker.SizeEstimate{Min: 0, Max: 12}
|
||||||
|
case expr.Type_BYTES:
|
||||||
|
return &checker.SizeEstimate{Min: 0, Max: 12}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *testCostEstimator) EstimateCallCost(function, overloadId string, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate {
|
||||||
|
return nil
|
||||||
|
}
|
@ -19,6 +19,7 @@ package library
|
|||||||
import (
|
import (
|
||||||
"github.com/google/cel-go/cel"
|
"github.com/google/cel-go/cel"
|
||||||
"github.com/google/cel-go/ext"
|
"github.com/google/cel-go/ext"
|
||||||
|
"github.com/google/cel-go/interpreter"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ExtensionLibs declares the set of CEL extension libraries available everywhere CEL is used in Kubernetes.
|
// ExtensionLibs declares the set of CEL extension libraries available everywhere CEL is used in Kubernetes.
|
||||||
@ -29,3 +30,5 @@ var k8sExtensionLibs = []cel.EnvOption{
|
|||||||
Regex(),
|
Regex(),
|
||||||
Lists(),
|
Lists(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var ExtensionLibRegexOptimizations = []*interpreter.RegexOptimization{FindRegexOptimization, FindAllRegexOptimization}
|
||||||
|
@ -23,6 +23,7 @@ import (
|
|||||||
"github.com/google/cel-go/checker/decls"
|
"github.com/google/cel-go/checker/decls"
|
||||||
"github.com/google/cel-go/common/types"
|
"github.com/google/cel-go/common/types"
|
||||||
"github.com/google/cel-go/common/types/ref"
|
"github.com/google/cel-go/common/types/ref"
|
||||||
|
"github.com/google/cel-go/interpreter"
|
||||||
"github.com/google/cel-go/interpreter/functions"
|
"github.com/google/cel-go/interpreter/functions"
|
||||||
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
|
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
|
||||||
)
|
)
|
||||||
@ -156,3 +157,61 @@ func findAll(args ...ref.Val) ref.Val {
|
|||||||
|
|
||||||
return types.NewStringList(types.DefaultTypeAdapter, result)
|
return types.NewStringList(types.DefaultTypeAdapter, result)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FindRegexOptimization optimizes the 'find' function by compiling the regex pattern and
|
||||||
|
// reporting any compilation errors at program creation time, and using the compiled regex pattern for all function
|
||||||
|
// call invocations.
|
||||||
|
var FindRegexOptimization = &interpreter.RegexOptimization{
|
||||||
|
Function: "find",
|
||||||
|
RegexIndex: 1,
|
||||||
|
Factory: func(call interpreter.InterpretableCall, regexPattern string) (interpreter.InterpretableCall, error) {
|
||||||
|
compiledRegex, err := regexp.Compile(regexPattern)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return interpreter.NewCall(call.ID(), call.Function(), call.OverloadID(), call.Args(), func(args ...ref.Val) ref.Val {
|
||||||
|
if len(args) != 2 {
|
||||||
|
return types.NoSuchOverloadErr()
|
||||||
|
}
|
||||||
|
in, ok := args[0].Value().(string)
|
||||||
|
if !ok {
|
||||||
|
return types.MaybeNoSuchOverloadErr(args[0])
|
||||||
|
}
|
||||||
|
return types.String(compiledRegex.FindString(in))
|
||||||
|
}), nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// FindAllRegexOptimization optimizes the 'findAll' function by compiling the regex pattern and
|
||||||
|
// reporting any compilation errors at program creation time, and using the compiled regex pattern for all function
|
||||||
|
// call invocations.
|
||||||
|
var FindAllRegexOptimization = &interpreter.RegexOptimization{
|
||||||
|
Function: "findAll",
|
||||||
|
RegexIndex: 1,
|
||||||
|
Factory: func(call interpreter.InterpretableCall, regexPattern string) (interpreter.InterpretableCall, error) {
|
||||||
|
compiledRegex, err := regexp.Compile(regexPattern)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return interpreter.NewCall(call.ID(), call.Function(), call.OverloadID(), call.Args(), func(args ...ref.Val) ref.Val {
|
||||||
|
argn := len(args)
|
||||||
|
if argn < 2 || argn > 3 {
|
||||||
|
return types.NoSuchOverloadErr()
|
||||||
|
}
|
||||||
|
str, ok := args[0].Value().(string)
|
||||||
|
if !ok {
|
||||||
|
return types.MaybeNoSuchOverloadErr(args[0])
|
||||||
|
}
|
||||||
|
n := int64(-1)
|
||||||
|
if argn == 3 {
|
||||||
|
n, ok = args[2].Value().(int64)
|
||||||
|
if !ok {
|
||||||
|
return types.MaybeNoSuchOverloadErr(args[2])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
result := compiledRegex.FindAllString(str, int(n))
|
||||||
|
return types.NewStringList(types.DefaultTypeAdapter, result)
|
||||||
|
}), nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
@ -196,6 +196,10 @@ func TestValidationExpressions(t *testing.T) {
|
|||||||
"self.val1.upperAscii() == 'ROOK TAKES 👑'",
|
"self.val1.upperAscii() == 'ROOK TAKES 👑'",
|
||||||
"self.val1.lowerAscii() == 'rook takes 👑'",
|
"self.val1.lowerAscii() == 'rook takes 👑'",
|
||||||
},
|
},
|
||||||
|
errors: map[string]string{
|
||||||
|
// Invalid regex with a string constant regex pattern is compile time error
|
||||||
|
"self.val1.matches(')')": "compile error: program instantiation failed: error parsing regexp: unexpected ): `)`",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
{name: "escaped strings",
|
{name: "escaped strings",
|
||||||
obj: objs("l1\nl2", "l1\nl2"),
|
obj: objs("l1\nl2", "l1\nl2"),
|
||||||
@ -1637,6 +1641,12 @@ func TestValidationExpressions(t *testing.T) {
|
|||||||
"self.str.findAll('xyz') == []",
|
"self.str.findAll('xyz') == []",
|
||||||
"self.str.findAll('xyz', 1) == []",
|
"self.str.findAll('xyz', 1) == []",
|
||||||
},
|
},
|
||||||
|
errors: map[string]string{
|
||||||
|
// Invalid regex with a string constant regex pattern is compile time error
|
||||||
|
"self.str.find(')') == ''": "compile error: program instantiation failed: error parsing regexp: unexpected ): `)`",
|
||||||
|
"self.str.findAll(')') == []": "compile error: program instantiation failed: error parsing regexp: unexpected ): `)`",
|
||||||
|
"self.str.findAll(')', 1) == []": "compile error: program instantiation failed: error parsing regexp: unexpected ): `)`",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
{name: "URL parsing",
|
{name: "URL parsing",
|
||||||
obj: map[string]interface{}{
|
obj: map[string]interface{}{
|
||||||
@ -2051,6 +2061,22 @@ func withRule(s schema.Structural, rule string) schema.Structural {
|
|||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func withMaxLength(s schema.Structural, maxLength *int64) schema.Structural {
|
||||||
|
if s.ValueValidation == nil {
|
||||||
|
s.ValueValidation = &schema.ValueValidation{}
|
||||||
|
}
|
||||||
|
s.ValueValidation.MaxLength = maxLength
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func withMaxItems(s schema.Structural, maxItems *int64) schema.Structural {
|
||||||
|
if s.ValueValidation == nil {
|
||||||
|
s.ValueValidation = &schema.ValueValidation{}
|
||||||
|
}
|
||||||
|
s.ValueValidation.MaxItems = maxItems
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
func withDefault(dflt interface{}, s schema.Structural) schema.Structural {
|
func withDefault(dflt interface{}, s schema.Structural) schema.Structural {
|
||||||
s.Generic.Default = schema.JSON{Object: dflt}
|
s.Generic.Default = schema.JSON{Object: dflt}
|
||||||
return s
|
return s
|
||||||
|
@ -19,6 +19,7 @@ import (
|
|||||||
|
|
||||||
"github.com/google/cel-go/checker/decls"
|
"github.com/google/cel-go/checker/decls"
|
||||||
"github.com/google/cel-go/common/types"
|
"github.com/google/cel-go/common/types"
|
||||||
|
|
||||||
"k8s.io/apiextensions-apiserver/pkg/apiserver/schema"
|
"k8s.io/apiextensions-apiserver/pkg/apiserver/schema"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user