Merge pull request #108617 from jpbetz/function-ext-costs

CEL: Enable regex pre-compilation, Add cost estimates for function extension libraries
This commit is contained in:
Kubernetes Prow Robot 2022-03-22 14:11:59 -07:00 committed by GitHub
commit 95e30f66c3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 830 additions and 19 deletions

View File

@ -117,17 +117,17 @@ func Compile(s *schema.Structural, isResourceRoot bool, perCallLimit uint64) ([]
if err != nil {
return nil, err
}
estimator := celEstimator{root: root}
estimator := newCostEstimator(root)
// compResults is the return value which saves a list of compilation results in the same order as x-kubernetes-validations rules.
compResults := make([]CompilationResult, len(celRules))
for i, rule := range celRules {
compResults[i] = compileRule(rule, env, perCallLimit, &estimator)
compResults[i] = compileRule(rule, env, perCallLimit, estimator)
}
return compResults, nil
}
func compileRule(rule apiextensions.ValidationRule, env *cel.Env, perCallLimit uint64, estimator *celEstimator) (compilationResult CompilationResult) {
func compileRule(rule apiextensions.ValidationRule, env *cel.Env, perCallLimit uint64, estimator *library.CostEstimator) (compilationResult CompilationResult) {
if len(strings.TrimSpace(rule.Rule)) == 0 {
// include a compilation result, but leave both program and error nil per documented return semantics of this
// function
@ -157,7 +157,13 @@ func compileRule(rule apiextensions.ValidationRule, env *cel.Env, perCallLimit u
}
// TODO: Ideally we could configure the per expression limit at validation time and set it to the remaining overall budget, but we would either need a way to pass in a limit at evaluation time or move program creation to validation time
prog, err := env.Program(ast, cel.EvalOptions(cel.OptOptimize, cel.OptTrackCost), cel.CostLimit(perCallLimit), cel.InterruptCheckFrequency(checkFrequency))
prog, err := env.Program(ast,
cel.EvalOptions(cel.OptOptimize, cel.OptTrackCost),
cel.CostLimit(perCallLimit),
cel.CostTracking(estimator),
cel.OptimizeRegex(library.ExtensionLibRegexOptimizations...),
cel.InterruptCheckFrequency(checkFrequency),
)
if err != nil {
compilationResult.Error = &Error{ErrorTypeInvalid, "program instantiation failed: " + err.Error()}
return
@ -180,11 +186,15 @@ func generateUniqueSelfTypeName() string {
return fmt.Sprintf("selfType%d", time.Now().Nanosecond())
}
type celEstimator struct {
// newCostEstimator builds the library.CostEstimator used during rule compilation. Size estimation is
// delegated to a sizeEstimator constructed from the provided root declaration type.
func newCostEstimator(root *celmodel.DeclType) *library.CostEstimator {
	sizes := &sizeEstimator{root: root}
	return &library.CostEstimator{SizeEstimator: sizes}
}
// sizeEstimator provides the size-estimation half of CEL cost estimation; it is plugged into
// library.CostEstimator by newCostEstimator.
type sizeEstimator struct {
	// root is the declaration type handed to newCostEstimator.
	root *celmodel.DeclType
}
func (c *celEstimator) EstimateSize(element checker.AstNode) *checker.SizeEstimate {
func (c *sizeEstimator) EstimateSize(element checker.AstNode) *checker.SizeEstimate {
if len(element.Path()) == 0 {
// Path() can return an empty list, early exit if it does since we can't
// provide size estimates when that happens
@ -218,6 +228,6 @@ func (c *celEstimator) EstimateSize(element checker.AstNode) *checker.SizeEstima
return &checker.SizeEstimate{Min: 0, Max: uint64(currentNode.MaxElements)}
}
func (c *celEstimator) EstimateCallCost(function, overloadID string, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate {
// EstimateCallCost always returns nil: sizeEstimator contributes size estimates only and provides no
// per-call cost estimate of its own.
func (c *sizeEstimator) EstimateCallCost(function, overloadID string, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate {
	return nil
}

View File

@ -1073,7 +1073,10 @@ func genMapWithCustomItemRule(item *schema.Structural, rule string) func(maxProp
}
}
func schemaChecker(schema *schema.Structural, expectedCost uint64, calcLimit uint64, t *testing.T) func(t *testing.T) {
// schemaChecker checks the cost of the validation rule declared in the provided schema (it requires there be exactly one rule)
// and checks that the resulting cost equals expectedCost if expectedCost is non-zero, and that the resulting cost is >= expectedCostExceedsLimit
// if expectedCostExceedsLimit is non-zero. Typically, only expectedCost or expectedCostExceedsLimit is non-zero, not both.
func schemaChecker(schema *schema.Structural, expectedCost uint64, expectedCostExceedsLimit uint64, t *testing.T) func(t *testing.T) {
return func(t *testing.T) {
// TODO(DangerOnTheRanger): if perCallLimit in compilation.go changes, this needs to change as well
compilationResults, err := Compile(schema, false, uint64(math.MaxInt64))
@ -1087,13 +1090,14 @@ func schemaChecker(schema *schema.Structural, expectedCost uint64, calcLimit uin
if result.Error != nil {
t.Errorf("Expected no compile-time error, got: %v", result.Error)
}
if calcLimit == 0 {
if expectedCost > 0 {
if result.MaxCost != expectedCost {
t.Errorf("Wrong cost (expected %d, got %d)", expectedCost, result.MaxCost)
}
} else {
if result.MaxCost < calcLimit {
t.Errorf("Cost did not exceed limit as expected (expected more than %d, got %d)", calcLimit, result.MaxCost)
}
if expectedCostExceedsLimit > 0 {
if result.MaxCost < expectedCostExceedsLimit {
t.Errorf("Cost did not exceed limit as expected (expected more than %d, got %d)", expectedCostExceedsLimit, result.MaxCost)
}
}
}
@ -1103,10 +1107,15 @@ func TestCostEstimation(t *testing.T) {
cases := []struct {
name string
schemaGenerator func(maxLength *int64) *schema.Structural
expectedCalcCost uint64
setMaxElements int64
expectedSetCost uint64
// calc cost expectations are checked against the generated schema without any max element limits set
expectedCalcCost uint64
expectCalcCostExceedsLimit uint64
// set cost expectations are checked against the generated schema with max element limits set
expectedSetCost uint64
expectedSetCostExceedsLimit uint64
}{
{
name: "number array with all",
@ -1233,7 +1242,6 @@ func TestCostEstimation(t *testing.T) {
{
name: "O(n^2) loop with numbers",
schemaGenerator: genArrayWithRule("number", "self.all(x, self.all(y, true))"),
expectedCalcCost: 9895601504256,
expectCalcCostExceedsLimit: costLimit,
setMaxElements: 10,
expectedSetCost: 352,
@ -1241,7 +1249,6 @@ func TestCostEstimation(t *testing.T) {
{
name: "O(n^3) loop with numbers",
schemaGenerator: genArrayWithRule("number", "self.all(x, self.all(y, self.all(z, true)))"),
expectedCalcCost: 13499986500008999998,
expectCalcCostExceedsLimit: costLimit,
setMaxElements: 10,
expectedSetCost: 3552,
@ -1512,6 +1519,80 @@ func TestCostEstimation(t *testing.T) {
setMaxElements: 490,
expectedSetCost: 0,
},
// Ensure library functions are integrated with size estimates by testing the interesting cases.
{
name: "extended library regex find",
schemaGenerator: genStringWithRule("self.find('[0-9]+') == ''"),
expectedCalcCost: 629147,
setMaxElements: 10,
expectedSetCost: 11,
},
{
name: "extended library join",
schemaGenerator: func(max *int64) *schema.Structural {
strType := withMaxLength(primitiveType("string", ""), max)
array := withMaxItems(arrayType("atomic", nil, &strType), max)
array = withRule(array, "self.join(' ') == 'aa bb'")
return &array
},
expectedCalcCost: 329853068905,
setMaxElements: 10,
expectedSetCost: 43,
},
{
name: "extended library isSorted",
schemaGenerator: func(max *int64) *schema.Structural {
strType := withMaxLength(primitiveType("string", ""), max)
array := withMaxItems(arrayType("atomic", nil, &strType), max)
array = withRule(array, "self.isSorted() == true")
return &array
},
expectedCalcCost: 329854432052,
setMaxElements: 10,
expectedSetCost: 52,
},
{
name: "extended library replace",
schemaGenerator: func(max *int64) *schema.Structural {
strType := withMaxLength(primitiveType("string", ""), max)
objType := objectType(map[string]schema.Structural{
"str": strType,
"before": strType,
"after": strType,
})
objType = withRule(objType, "self.str.replace(self.before, self.after) == 'does not matter'")
return &objType
},
expectedCalcCost: 629154,
setMaxElements: 10,
expectedSetCost: 16,
},
{
name: "extended library split",
schemaGenerator: func(max *int64) *schema.Structural {
strType := withMaxLength(primitiveType("string", ""), max)
objType := objectType(map[string]schema.Structural{
"str": strType,
"separator": strType,
})
objType = withRule(objType, "self.str.split(self.separator) == []")
return &objType
},
expectedCalcCost: 629160,
setMaxElements: 10,
expectedSetCost: 22,
},
{
name: "extended library lowerAscii",
schemaGenerator: func(max *int64) *schema.Structural {
strType := withMaxLength(primitiveType("string", ""), max)
strType = withRule(strType, "self.lowerAscii() == 'lower!'")
return &strType
},
expectedCalcCost: 314575,
setMaxElements: 10,
expectedSetCost: 6,
},
}
for _, testCase := range cases {
t.Run(testCase.name, func(t *testing.T) {
@ -1520,7 +1601,7 @@ func TestCostEstimation(t *testing.T) {
t.Run("calc maxLength", schemaChecker(schema, testCase.expectedCalcCost, testCase.expectCalcCostExceedsLimit, t))
// static maxLength case
setSchema := testCase.schemaGenerator(&testCase.setMaxElements)
t.Run("set maxLength", schemaChecker(setSchema, testCase.expectedSetCost, 0, t))
t.Run("set maxLength", schemaChecker(setSchema, testCase.expectedSetCost, testCase.expectedSetCostExceedsLimit, t))
})
}
}

View File

@ -0,0 +1,268 @@
/*
Copyright 2022 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package library
import (
"math"
"github.com/google/cel-go/checker"
"github.com/google/cel-go/common"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
// CostEstimator implements CEL's interpretable.ActualCostEstimator and checker.CostEstimator for the
// Kubernetes CEL extension library functions.
type CostEstimator struct {
	// SizeEstimator is consulted (via its EstimateSize method) whenever a node has no computed size of
	// its own, i.e. the size is not already well known such as for a constant. It may be nil, in which
	// case EstimateSize on this CostEstimator returns nil.
	SizeEstimator checker.CostEstimator
}
// CallCost computes the runtime cost of a call to an extension library function from the actual
// argument values (and, for join, from the actual result). It returns nil when the function is not
// one handled here or when fewer arguments than expected were supplied.
func (l *CostEstimator) CallCost(function, overloadId string, args []ref.Val, result ref.Val) *uint64 {
	argCount := len(args)
	switch function {
	case "isSorted", "sum", "max", "min", "indexOf", "lastIndexOf":
		// All of these are O(n) operations, so each is charged roughly one traversal of args[0].
		total := uint64(0)
		if argCount > 0 {
			total += traversalCost(args[0])
		}
		return &total
	case "url", "lowerAscii", "upperAscii", "substring", "trim":
		if argCount == 0 {
			return nil
		}
		inputLen := float64(actualSize(args[0]))
		total := uint64(math.Ceil(inputLen * common.StringTraversalCostFactor))
		return &total
	case "replace", "split":
		if argCount == 0 {
			return nil
		}
		// Charged twice: once for traversing the input and once for constructing the result.
		inputLen := float64(actualSize(args[0]))
		total := uint64(math.Ceil(inputLen * 2 * common.StringTraversalCostFactor))
		return &total
	case "join":
		if argCount == 0 {
			return nil
		}
		// Based on the size of the produced value rather than the size of the input.
		outputLen := float64(actualSize(result))
		total := uint64(math.Ceil(outputLen * 2 * common.StringTraversalCostFactor))
		return &total
	case "find", "findAll":
		if argCount < 2 {
			return nil
		}
		// One is added to the string length before scaling.
		inputCost := uint64(math.Ceil((1.0 + float64(actualSize(args[0]))) * common.StringTraversalCostFactor))
		// We don't know how many expressions are in the regex, just the string length (a huge
		// improvement here would be to somehow get a count of the number of expressions in the regex
		// or how many states are in the regex state machine and use that to measure regex cost).
		// For now, we're making a guess that each expression in a regex is typically at least 4 chars
		// in length.
		patternCost := uint64(math.Ceil(float64(actualSize(args[1])) * common.RegexStringLengthCostFactor))
		total := inputCost * patternCost
		return &total
	}
	return nil
}
// EstimateCallCost returns the statically estimated cost, and where it can be bounded the result
// size, of a call to an extension library function. target is the receiver of the call (nil when
// there is none) and args are the remaining arguments. It returns nil when the function is not one
// handled here or when the expected target/arguments are missing.
func (l *CostEstimator) EstimateCallCost(function, overloadId string, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate {
	// WARNING: Any changes to this code impact API compatibility! The estimated cost is used to determine which CEL rules may be written to a
	// CRD and any change (cost increases and cost decreases) are breaking.
	switch function {
	case "isSorted", "sum", "max", "min", "indexOf", "lastIndexOf":
		if target != nil {
			// Charge 1 cost for comparing each element in the list
			elCost := checker.CostEstimate{Min: 1, Max: 1}
			// If the list contains strings or bytes, add the cost of traversing all the strings/bytes as a way
			// of estimating the additional comparison cost.
			if elNode := l.listElementNode(*target); elNode != nil {
				t := elNode.Type().GetPrimitive()
				if t == exprpb.Type_STRING || t == exprpb.Type_BYTES {
					sz := l.sizeEstimate(elNode)
					elCost = elCost.Add(sz.MultiplyByCostFactor(common.StringTraversalCostFactor))
				}
				return &checker.CallEstimate{CostEstimate: l.sizeEstimate(*target).MultiplyByCost(elCost)}
			}
			// the target is a string, which is supported by indexOf and lastIndexOf
			return &checker.CallEstimate{CostEstimate: l.sizeEstimate(*target).MultiplyByCostFactor(common.StringTraversalCostFactor)}
		}
	case "url":
		if len(args) == 1 {
			sz := l.sizeEstimate(args[0])
			return &checker.CallEstimate{CostEstimate: sz.MultiplyByCostFactor(common.StringTraversalCostFactor)}
		}
	case "lowerAscii", "upperAscii", "substring", "trim":
		if target != nil {
			sz := l.sizeEstimate(*target)
			return &checker.CallEstimate{CostEstimate: sz.MultiplyByCostFactor(common.StringTraversalCostFactor), ResultSize: &sz}
		}
	case "replace":
		if target != nil && len(args) >= 2 {
			sz := l.sizeEstimate(*target)
			resultSz := replaceResultSize(sz, l.sizeEstimate(args[0]), l.sizeEstimate(args[1]))
			// cost is the traversal plus the construction of the result
			return &checker.CallEstimate{CostEstimate: sz.MultiplyByCostFactor(2 * common.StringTraversalCostFactor), ResultSize: &resultSz}
		}
	case "split":
		if target != nil {
			sz := l.sizeEstimate(*target)
			// Worst case size is when a separator of "" is used, and each char is returned as a list element.
			maxParts := sz.Max
			if len(args) > 1 {
				if c := args[1].Expr().GetConstExpr(); c != nil {
					// A negative limit means "no limit"; converting it to uint64 would wrap to a huge
					// bound, so only a non-negative constant limit replaces the worst case.
					if limit := c.GetInt64Value(); limit >= 0 {
						maxParts = uint64(limit)
					}
				}
			}
			// Cost is the traversal plus the construction of the result.
			return &checker.CallEstimate{CostEstimate: sz.MultiplyByCostFactor(2 * common.StringTraversalCostFactor), ResultSize: &checker.SizeEstimate{Min: 0, Max: maxParts}}
		}
	case "join":
		if target != nil {
			var sz checker.SizeEstimate
			listSize := l.sizeEstimate(*target)
			if elNode := l.listElementNode(*target); elNode != nil {
				elemSize := l.sizeEstimate(elNode)
				sz = listSize.Multiply(elemSize)
			}
			if len(args) > 0 {
				sepSize := l.sizeEstimate(args[0])
				// A list of n elements is joined with n-1 separators (and none when the list is empty).
				minSeparators := uint64(0)
				maxSeparators := uint64(0)
				if listSize.Min > 0 {
					minSeparators = listSize.Min - 1
				}
				if listSize.Max > 0 {
					maxSeparators = listSize.Max - 1
				}
				sz = sz.Add(sepSize.Multiply(checker.SizeEstimate{Min: minSeparators, Max: maxSeparators}))
			}
			return &checker.CallEstimate{CostEstimate: sz.MultiplyByCostFactor(common.StringTraversalCostFactor), ResultSize: &sz}
		}
	case "find", "findAll":
		if target != nil && len(args) >= 1 {
			sz := l.sizeEstimate(*target)
			// Add one to string length for purposes of cost calculation to prevent product of string and regex to be 0
			// in case where string is empty but regex is still expensive.
			strCost := sz.Add(checker.SizeEstimate{Min: 1, Max: 1}).MultiplyByCostFactor(common.StringTraversalCostFactor)
			// We don't know how many expressions are in the regex, just the string length (a huge
			// improvement here would be to somehow get a count of the number of expressions in the regex or
			// how many states are in the regex state machine and use that to measure regex cost).
			// For now, we're making a guess that each expression in a regex is typically at least 4 chars
			// in length.
			regexCost := l.sizeEstimate(args[0]).MultiplyByCostFactor(common.RegexStringLengthCostFactor)
			// worst case size of result is that every char is returned as separate find result.
			return &checker.CallEstimate{CostEstimate: strCost.Multiply(regexCost), ResultSize: &checker.SizeEstimate{Min: 0, Max: sz.Max}}
		}
	}
	return nil
}

// replaceResultSize bounds the size of the string produced by replacing occurrences of a search
// string (size toReplaceSz) with a replacement (size replaceWithSz) in an input of size sz.
//
// It uses only integer arithmetic and the saturating SizeEstimate operations. A search string that
// may be empty (size 0) is handled explicitly rather than used as a divisor, and an unbounded input
// size (math.MaxUint64) does not overflow.
func replaceResultSize(sz, toReplaceSz, replaceWithSz checker.SizeEstimate) checker.SizeEstimate {
	// The result is modelled as (number of replacements * replacement size) + retained input.
	var replaceCount, retainedSz checker.SizeEstimate

	// Largest possible result.
	switch {
	case toReplaceSz.Min == 0:
		// An empty search string matches before every character and at the end of the input, so there
		// may be one more replacement than there are characters, with the whole input retained.
		replaceCount.Max = sz.Max
		if sz.Max < math.MaxUint64 {
			replaceCount.Max = sz.Max + 1
		}
		retainedSz.Max = sz.Max
	case replaceWithSz.Max <= toReplaceSz.Min:
		// A replacement can never make the string longer, so the untouched input is the worst case.
		retainedSz.Max = sz.Max
	default:
		// Largest input made up of the smallest possible matches, each replaced by the largest
		// possible replacement. The count is rounded up to stay an upper bound.
		replaceCount.Max = sz.Max / toReplaceSz.Min
		if sz.Max%toReplaceSz.Min != 0 {
			replaceCount.Max++
		}
	}

	// Smallest possible result.
	switch {
	case toReplaceSz.Max == 0:
		// The search string is always empty: the input is retained and a replacement is inserted
		// before every character and at the end.
		replaceCount.Min = sz.Min
		if sz.Min < math.MaxUint64 {
			replaceCount.Min = sz.Min + 1
		}
		retainedSz.Min = sz.Min
	case toReplaceSz.Max <= replaceWithSz.Min:
		// A replacement can never make the string shorter, so the untouched input is the best case.
		retainedSz.Min = sz.Min
	default:
		// Smallest input made up of the largest possible matches, each replaced by the smallest
		// possible replacement. The count is rounded down to stay a lower bound.
		replaceCount.Min = sz.Min / toReplaceSz.Max
	}

	return replaceCount.Multiply(replaceWithSz).Add(retainedSz)
}
// actualSize returns the size reported by value when it implements traits.Sizer, and 1 for any
// other value.
func actualSize(value ref.Val) uint64 {
	sizer, isSizer := value.(traits.Sizer)
	if !isSizer {
		return 1
	}
	return uint64(sizer.Size().(types.Int))
}
// sizeEstimate resolves a size estimate for t. The node's own computed size is preferred, then the
// estimate from EstimateSize; when neither is available the widest possible range is returned.
func (l *CostEstimator) sizeEstimate(t checker.AstNode) checker.SizeEstimate {
	if computed := t.ComputedSize(); computed != nil {
		return *computed
	}
	if estimated := l.EstimateSize(t); estimated != nil {
		return *estimated
	}
	// Nothing is known about the size, so assume anything from empty to the maximum.
	return checker.SizeEstimate{Min: 0, Max: math.MaxUint64}
}
// listElementNode returns a synthetic node describing the elements of list, or nil when list is not
// of a list type.
func (l *CostEstimator) listElementNode(list checker.AstNode) checker.AstNode {
	listType := list.Type().GetListType()
	if listType == nil {
		return nil
	}
	parentPath := list.Path()
	if parentPath == nil {
		// Provide just the type if no path is available so that worst case size can be looked up based on type.
		return &itemsNode{t: listType.GetElemType(), expr: nil}
	}
	// Provide path if we have it so that a OpenAPIv3 maxLength validation can be looked up, if it exists
	// for this node. The element path is the list's path followed by "@items".
	elemPath := make([]string, 0, len(parentPath)+1)
	elemPath = append(elemPath, parentPath...)
	elemPath = append(elemPath, "@items")
	return &itemsNode{path: elemPath, t: listType.GetElemType(), expr: nil}
}
// EstimateSize forwards to the configured SizeEstimator. It returns nil when no SizeEstimator has
// been set.
func (l *CostEstimator) EstimateSize(element checker.AstNode) *checker.SizeEstimate {
	if l.SizeEstimator == nil {
		return nil
	}
	return l.SizeEstimator.EstimateSize(element)
}
// itemsNode is the synthetic node that listElementNode returns (as a checker.AstNode) to describe
// the elements of a list.
type itemsNode struct {
	// path is the path of the list's elements; it is left nil when the list had no path.
	path []string
	// t is the element type of the list.
	t *exprpb.Type
	// expr is the expression for the node; listElementNode always leaves it nil.
	expr *exprpb.Expr
}

// Path returns the stored element path.
func (n *itemsNode) Path() []string {
	return n.path
}

// Type returns the stored element type.
func (n *itemsNode) Type() *exprpb.Type {
	return n.t
}

// Expr returns the stored expression.
func (n *itemsNode) Expr() *exprpb.Expr {
	return n.expr
}

// ComputedSize always returns nil: no size is precomputed for this node.
func (n *itemsNode) ComputedSize() *checker.SizeEstimate {
	return nil
}
// traversalCost computes the cost of traversing a ref.Val as a data tree. Strings and bytes cost
// their length scaled by common.StringTraversalCostFactor (truncated to an integer), lists and maps
// cost the sum of their contents (keys and values for maps), and every other value costs 1.
func traversalCost(v ref.Val) uint64 {
	// TODO: This could potentially be optimized by sampling maps and lists instead of traversing.
	switch val := v.(type) {
	case types.String:
		return uint64(float64(len(string(val))) * common.StringTraversalCostFactor)
	case types.Bytes:
		return uint64(float64(len([]byte(val))) * common.StringTraversalCostFactor)
	case traits.Lister:
		var total uint64
		elements := val.Iterator()
		for elements.HasNext() == types.True {
			total += traversalCost(elements.Next())
		}
		return total
	case traits.Mapper: // maps and objects
		var total uint64
		keys := val.Iterator()
		for keys.HasNext() == types.True {
			key := keys.Next()
			total += traversalCost(key) + traversalCost(val.Get(key))
		}
		return total
	}
	return 1
}

View File

@ -0,0 +1,363 @@
/*
Copyright 2022 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package library
import (
"fmt"
"testing"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/checker"
"github.com/google/cel-go/ext"
expr "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
// CEL source literals shared by the cost tests in this file. Each list literal has five elements;
// the cost tests append a function call (for example ".sum()") to a literal to form the expression
// under test.
const (
intListLiteral = "[1, 2, 3, 4, 5]"
uintListLiteral = "[uint(1), uint(2), uint(3), uint(4), uint(5)]"
doubleListLiteral = "[1.0, 2.0, 3.0, 4.0, 5.0]"
boolListLiteral = "[false, true, false, true, false]"
stringListLiteral = "['012345678901', '012345678901', '012345678901', '012345678901', '012345678901']"
bytesListLiteral = "[bytes('012345678901'), bytes('012345678901'), bytes('012345678901'), bytes('012345678901'), bytes('012345678901')]"
durationListLiteral = "[duration('1s'), duration('2s'), duration('3s'), duration('4s'), duration('5s')]"
timestampListLiteral = "[timestamp('2011-01-01T00:00:00.000+01:00'), timestamp('2011-01-02T00:00:00.000+01:00'), " +
"timestamp('2011-01-03T00:00:00.000+01:00'), timestamp('2011-01-04T00:00:00.000+01:00'), " +
"timestamp('2011-01-05T00:00:00.000+01:00')]"
// stringLiteral is a single 50-character string literal.
stringLiteral = "'01234567890123456789012345678901234567890123456789'"
)
// comparableCost is one row of a cost table: a CEL literal to call a function on, together with the
// estimated and runtime costs expected for the resulting expression.
type comparableCost struct {
	// comparableLiteral is the CEL literal that the operation under test is appended to.
	comparableLiteral string
	// expectedEstimatedCost is the expected static cost estimate.
	expectedEstimatedCost checker.CostEstimate
	// expectedRuntimeCost is the expected cost measured during evaluation.
	expectedRuntimeCost uint64
	// param is the argument substituted into operations that take one (used by TestIndexOfCost).
	param string
}
// TestListsCost checks the estimated and runtime cost of the argument-less list functions when
// called on list literals of each supported element type.
func TestListsCost(t *testing.T) {
	tests := []struct {
		ops  []string
		rows []comparableCost
	}{
		{
			ops: []string{".sum()"},
			// 10 cost for the list declaration, the rest is due to the function call
			rows: []comparableCost{
				{comparableLiteral: intListLiteral, expectedEstimatedCost: checker.CostEstimate{Min: 15, Max: 15}, expectedRuntimeCost: 15},
				{comparableLiteral: uintListLiteral, expectedEstimatedCost: checker.CostEstimate{Min: 20, Max: 20}, expectedRuntimeCost: 20}, // +5 for casts
				{comparableLiteral: doubleListLiteral, expectedEstimatedCost: checker.CostEstimate{Min: 15, Max: 15}, expectedRuntimeCost: 15},
				{comparableLiteral: durationListLiteral, expectedEstimatedCost: checker.CostEstimate{Min: 20, Max: 20}, expectedRuntimeCost: 20}, // +5 for casts
			},
		},
		{
			ops: []string{".isSorted()", ".max()", ".min()"},
			// 10 cost for the list declaration, the rest is due to the function call
			rows: []comparableCost{
				{comparableLiteral: intListLiteral, expectedEstimatedCost: checker.CostEstimate{Min: 15, Max: 15}, expectedRuntimeCost: 15},
				{comparableLiteral: uintListLiteral, expectedEstimatedCost: checker.CostEstimate{Min: 20, Max: 20}, expectedRuntimeCost: 20}, // +5 for numeric casts
				{comparableLiteral: doubleListLiteral, expectedEstimatedCost: checker.CostEstimate{Min: 15, Max: 15}, expectedRuntimeCost: 15},
				{comparableLiteral: boolListLiteral, expectedEstimatedCost: checker.CostEstimate{Min: 15, Max: 15}, expectedRuntimeCost: 15},
				{comparableLiteral: stringListLiteral, expectedEstimatedCost: checker.CostEstimate{Min: 15, Max: 25}, expectedRuntimeCost: 15}, // +5 for string comparisons
				{comparableLiteral: bytesListLiteral, expectedEstimatedCost: checker.CostEstimate{Min: 25, Max: 35}, expectedRuntimeCost: 25},  // +10 for casts from string to byte, +5 for byte comparisons
				{comparableLiteral: durationListLiteral, expectedEstimatedCost: checker.CostEstimate{Min: 20, Max: 20}, expectedRuntimeCost: 20}, // +5 for numeric casts
				{comparableLiteral: timestampListLiteral, expectedEstimatedCost: checker.CostEstimate{Min: 20, Max: 20}, expectedRuntimeCost: 20}, // +5 for casts
			},
		},
	}
	for _, group := range tests {
		for _, op := range group.ops {
			for _, row := range group.rows {
				// The expression under test is the literal followed by the operation.
				celExpr := row.comparableLiteral + op
				t.Run(celExpr, func(t *testing.T) {
					testCost(t, celExpr, row.expectedEstimatedCost, row.expectedRuntimeCost)
				})
			}
		}
	}
}
// TestIndexOfCost checks the estimated and runtime cost of indexOf and lastIndexOf on list literals
// of each supported element type, and on a string literal.
func TestIndexOfCost(t *testing.T) {
	tests := []struct {
		ops  []string
		rows []comparableCost
	}{
		{
			ops: []string{".indexOf(%s)", ".lastIndexOf(%s)"},
			// 10 cost for the list declaration, the rest is due to the function call
			rows: []comparableCost{
				{comparableLiteral: intListLiteral, param: "3", expectedEstimatedCost: checker.CostEstimate{Min: 15, Max: 15}, expectedRuntimeCost: 15},
				{comparableLiteral: uintListLiteral, param: "uint(3)", expectedEstimatedCost: checker.CostEstimate{Min: 21, Max: 21}, expectedRuntimeCost: 21}, // +6 for the uint() conversions (5 in the list, 1 in the argument)
				{comparableLiteral: doubleListLiteral, param: "3.0", expectedEstimatedCost: checker.CostEstimate{Min: 15, Max: 15}, expectedRuntimeCost: 15},
				{comparableLiteral: boolListLiteral, param: "true", expectedEstimatedCost: checker.CostEstimate{Min: 15, Max: 15}, expectedRuntimeCost: 15},
				{comparableLiteral: stringListLiteral, param: "'x'", expectedEstimatedCost: checker.CostEstimate{Min: 15, Max: 25}, expectedRuntimeCost: 15},       // +5 for string comparisons
				{comparableLiteral: bytesListLiteral, param: "bytes('x')", expectedEstimatedCost: checker.CostEstimate{Min: 26, Max: 36}, expectedRuntimeCost: 26}, // +11 for casts from string to byte, +5 for byte comparisons
				{comparableLiteral: durationListLiteral, param: "duration('3s')", expectedEstimatedCost: checker.CostEstimate{Min: 21, Max: 21}, expectedRuntimeCost: 21}, // +6 for the duration() conversions
				{comparableLiteral: timestampListLiteral, param: "timestamp('2011-01-03T00:00:00.000+01:00')", expectedEstimatedCost: checker.CostEstimate{Min: 21, Max: 21}, expectedRuntimeCost: 21}, // +6 for the timestamp() conversions
				// index of operations are also defined for strings
				{comparableLiteral: stringLiteral, param: "'123'", expectedEstimatedCost: checker.CostEstimate{Min: 5, Max: 5}, expectedRuntimeCost: 5},
			},
		},
	}
	for _, group := range tests {
		for _, opFormat := range group.ops {
			for _, row := range group.rows {
				// Substitute the row's argument into the operation, then append it to the literal.
				celExpr := row.comparableLiteral + fmt.Sprintf(opFormat, row.param)
				t.Run(celExpr, func(t *testing.T) {
					testCost(t, celExpr, row.expectedEstimatedCost, row.expectedRuntimeCost)
				})
			}
		}
	}
}
// TestURLsCost checks the estimated and runtime cost of constructing a URL with url() and calling
// each accessor on it.
func TestURLsCost(t *testing.T) {
	cases := []struct {
		ops                 []string
		expectEstimatedCost checker.CostEstimate
		expectRuntimeCost   uint64
	}{
		{
			ops:                 []string{".getScheme()", ".getHostname()", ".getHost()", ".getPort()", ".getEscapedPath()", ".getQuery()"},
			expectEstimatedCost: checker.CostEstimate{Min: 4, Max: 4},
			expectRuntimeCost:   4,
		},
	}

	for _, tc := range cases {
		for _, op := range tc.ops {
			// op already starts with ".", so the subtest is named e.g. "url.getScheme()".
			t.Run("url"+op, func(t *testing.T) {
				// The URL was previously written with a doubled colon ("https:://"), which is not a
				// well-formed URL.
				testCost(t, "url('https://kubernetes.io/')"+op, tc.expectEstimatedCost, tc.expectRuntimeCost)
			})
		}
	}
}
// TestStringLibrary checks the estimated and runtime cost of the string functions (and join) when
// called on literals.
func TestStringLibrary(t *testing.T) {
	tests := []struct {
		name                string
		expr                string
		expectEstimatedCost checker.CostEstimate
		expectRuntimeCost   uint64
	}{
		{name: "lowerAscii", expr: "'ABCDEFGHIJ abcdefghij'.lowerAscii()", expectEstimatedCost: checker.CostEstimate{Min: 3, Max: 3}, expectRuntimeCost: 3},
		{name: "upperAscii", expr: "'ABCDEFGHIJ abcdefghij'.upperAscii()", expectEstimatedCost: checker.CostEstimate{Min: 3, Max: 3}, expectRuntimeCost: 3},
		{name: "replace", expr: "'abc 123 def 123'.replace('123', '456')", expectEstimatedCost: checker.CostEstimate{Min: 3, Max: 3}, expectRuntimeCost: 3},
		{name: "replace with limit", expr: "'abc 123 def 123'.replace('123', '456', 1)", expectEstimatedCost: checker.CostEstimate{Min: 3, Max: 3}, expectRuntimeCost: 3},
		{name: "split", expr: "'abc 123 def 123'.split(' ')", expectEstimatedCost: checker.CostEstimate{Min: 3, Max: 3}, expectRuntimeCost: 3},
		{name: "split with limit", expr: "'abc 123 def 123'.split(' ', 1)", expectEstimatedCost: checker.CostEstimate{Min: 3, Max: 3}, expectRuntimeCost: 3},
		{name: "substring", expr: "'abc 123 def 123'.substring(5)", expectEstimatedCost: checker.CostEstimate{Min: 2, Max: 2}, expectRuntimeCost: 2},
		{name: "substring with end", expr: "'abc 123 def 123'.substring(5, 8)", expectEstimatedCost: checker.CostEstimate{Min: 2, Max: 2}, expectRuntimeCost: 2},
		{name: "trim", expr: "' abc 123 def 123 '.trim()", expectEstimatedCost: checker.CostEstimate{Min: 2, Max: 2}, expectRuntimeCost: 2},
		{name: "join with separator", expr: "['aa', 'bb', 'cc', 'd', 'e', 'f', 'g', 'h', 'i', 'j'].join(' ')", expectEstimatedCost: checker.CostEstimate{Min: 11, Max: 23}, expectRuntimeCost: 15},
		{name: "join", expr: "['aa', 'bb', 'cc', 'd', 'e', 'f', 'g', 'h', 'i', 'j'].join()", expectEstimatedCost: checker.CostEstimate{Min: 10, Max: 22}, expectRuntimeCost: 13},
		{name: "find", expr: "'abc 123 def 123'.find('123')", expectEstimatedCost: checker.CostEstimate{Min: 2, Max: 2}, expectRuntimeCost: 2},
		{name: "findAll", expr: "'abc 123 def 123'.findAll('123')", expectEstimatedCost: checker.CostEstimate{Min: 2, Max: 2}, expectRuntimeCost: 2},
		{name: "findAll with limit", expr: "'abc 123 def 123'.findAll('123', 1)", expectEstimatedCost: checker.CostEstimate{Min: 2, Max: 2}, expectRuntimeCost: 2},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			testCost(t, tt.expr, tt.expectEstimatedCost, tt.expectRuntimeCost)
		})
	}
}
// testCost compiles expr in an environment containing the Kubernetes extension libraries plus the
// CEL strings extension, then verifies both the statically estimated cost range and the cost
// tracked while evaluating the expression with no input variables.
func testCost(t *testing.T, expr string, expectEsimatedCost checker.CostEstimate, expectRuntimeCost uint64) {
	// Mark this function as a helper so the test log attributes failures to the calling test case.
	t.Helper()

	est := &CostEstimator{SizeEstimator: &testCostEstimator{}}
	env, err := cel.NewEnv(append(k8sExtensionLibs, ext.Strings())...)
	if err != nil {
		t.Fatalf("%v", err)
	}
	compiled, issues := env.Compile(expr)
	if len(issues.Errors()) > 0 {
		t.Fatalf("%v", issues.Errors())
	}

	// Static estimate.
	estCost, err := env.EstimateCost(compiled, est)
	if err != nil {
		t.Fatalf("%v", err)
	}
	if estCost.Min != expectEsimatedCost.Min || estCost.Max != expectEsimatedCost.Max {
		t.Errorf("Expected estimated cost of %d..%d but got %d..%d", expectEsimatedCost.Min, expectEsimatedCost.Max, estCost.Min, estCost.Max)
	}

	// Runtime cost.
	prog, err := env.Program(compiled, cel.CostTracking(est))
	if err != nil {
		t.Fatalf("%v", err)
	}
	_, details, err := prog.Eval(map[string]interface{}{})
	if err != nil {
		t.Fatalf("%v", err)
	}
	// Fail the test cleanly, rather than panic on a nil dereference, if no cost was reported.
	if details == nil {
		t.Fatalf("Expected cost of %d but evaluation returned no details", expectRuntimeCost)
	}
	cost := details.ActualCost()
	if cost == nil {
		t.Fatalf("Expected cost of %d but no runtime cost was tracked", expectRuntimeCost)
	}
	if *cost != expectRuntimeCost {
		t.Errorf("Expected cost of %d but got %d", expectRuntimeCost, *cost)
	}
}
// testCostEstimator is the size estimator used by testCost. It reports a size of 0..12 for string
// and bytes nodes and has no estimate for anything else.
type testCostEstimator struct{}

// EstimateSize returns a 0..12 estimate for nodes whose type is the string or bytes primitive, and
// nil for every other node.
func (e *testCostEstimator) EstimateSize(element checker.AstNode) *checker.SizeEstimate {
	primitive, isPrimitive := element.Type().TypeKind.(*expr.Type_Primitive)
	if !isPrimitive {
		return nil
	}
	switch primitive.Primitive {
	case expr.Type_STRING, expr.Type_BYTES:
		return &checker.SizeEstimate{Min: 0, Max: 12}
	}
	return nil
}

// EstimateCallCost always returns nil; this estimator supplies size estimates only.
func (e *testCostEstimator) EstimateCallCost(function, overloadId string, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate {
	return nil
}

View File

@ -19,6 +19,7 @@ package library
import (
"github.com/google/cel-go/cel"
"github.com/google/cel-go/ext"
"github.com/google/cel-go/interpreter"
)
// ExtensionLibs declares the set of CEL extension libraries available everywhere CEL is used in Kubernetes.
@ -29,3 +30,5 @@ var k8sExtensionLibs = []cel.EnvOption{
Regex(),
Lists(),
}
// ExtensionLibRegexOptimizations lists the regex optimizations defined for the extension library
// functions (find and findAll), for callers to pass to cel.OptimizeRegex when creating a program.
var ExtensionLibRegexOptimizations = []*interpreter.RegexOptimization{FindRegexOptimization, FindAllRegexOptimization}

View File

@ -23,6 +23,7 @@ import (
"github.com/google/cel-go/checker/decls"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/interpreter"
"github.com/google/cel-go/interpreter/functions"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
@ -156,3 +157,61 @@ func findAll(args ...ref.Val) ref.Val {
return types.NewStringList(types.DefaultTypeAdapter, result)
}
// FindRegexOptimization optimizes the 'find' function: the regex pattern (argument index 1) is
// compiled once by the factory, so a pattern that fails to compile is reported as an error from the
// factory, and the compiled pattern is reused by every invocation of the returned call.
var FindRegexOptimization = &interpreter.RegexOptimization{
	Function:   "find",
	RegexIndex: 1,
	Factory: func(call interpreter.InterpretableCall, regexPattern string) (interpreter.InterpretableCall, error) {
		re, compileErr := regexp.Compile(regexPattern)
		if compileErr != nil {
			return nil, compileErr
		}
		// findFirst replaces the original implementation; it closes over the precompiled pattern.
		findFirst := func(args ...ref.Val) ref.Val {
			if len(args) != 2 {
				return types.NoSuchOverloadErr()
			}
			input, isString := args[0].Value().(string)
			if !isString {
				return types.MaybeNoSuchOverloadErr(args[0])
			}
			return types.String(re.FindString(input))
		}
		return interpreter.NewCall(call.ID(), call.Function(), call.OverloadID(), call.Args(), findFirst), nil
	},
}
// FindAllRegexOptimization optimizes the 'findAll' function: the regex pattern (argument index 1) is
// compiled once by the factory, so a pattern that fails to compile is reported as an error from the
// factory, and the compiled pattern is reused by every invocation of the returned call.
var FindAllRegexOptimization = &interpreter.RegexOptimization{
	Function:   "findAll",
	RegexIndex: 1,
	Factory: func(call interpreter.InterpretableCall, regexPattern string) (interpreter.InterpretableCall, error) {
		re, compileErr := regexp.Compile(regexPattern)
		if compileErr != nil {
			return nil, compileErr
		}
		// findEvery replaces the original implementation; it closes over the precompiled pattern.
		findEvery := func(args ...ref.Val) ref.Val {
			argCount := len(args)
			if argCount < 2 || argCount > 3 {
				return types.NoSuchOverloadErr()
			}
			input, isString := args[0].Value().(string)
			if !isString {
				return types.MaybeNoSuchOverloadErr(args[0])
			}
			// Without an explicit third argument the limit passed to FindAllString is -1.
			limit := int64(-1)
			if argCount == 3 {
				requested, isInt := args[2].Value().(int64)
				if !isInt {
					return types.MaybeNoSuchOverloadErr(args[2])
				}
				limit = requested
			}
			matches := re.FindAllString(input, int(limit))
			return types.NewStringList(types.DefaultTypeAdapter, matches)
		}
		return interpreter.NewCall(call.ID(), call.Function(), call.OverloadID(), call.Args(), findEvery), nil
	},
}

View File

@ -196,6 +196,10 @@ func TestValidationExpressions(t *testing.T) {
"self.val1.upperAscii() == 'ROOK TAKES 👑'",
"self.val1.lowerAscii() == 'rook takes 👑'",
},
errors: map[string]string{
// Invalid regex with a string constant regex pattern is compile time error
"self.val1.matches(')')": "compile error: program instantiation failed: error parsing regexp: unexpected ): `)`",
},
},
{name: "escaped strings",
obj: objs("l1\nl2", "l1\nl2"),
@ -1637,6 +1641,12 @@ func TestValidationExpressions(t *testing.T) {
"self.str.findAll('xyz') == []",
"self.str.findAll('xyz', 1) == []",
},
errors: map[string]string{
// Invalid regex with a string constant regex pattern is compile time error
"self.str.find(')') == ''": "compile error: program instantiation failed: error parsing regexp: unexpected ): `)`",
"self.str.findAll(')') == []": "compile error: program instantiation failed: error parsing regexp: unexpected ): `)`",
"self.str.findAll(')', 1) == []": "compile error: program instantiation failed: error parsing regexp: unexpected ): `)`",
},
},
{name: "URL parsing",
obj: map[string]interface{}{
@ -2051,6 +2061,22 @@ func withRule(s schema.Structural, rule string) schema.Structural {
return s
}
// withMaxLength returns s with ValueValidation.MaxLength set to maxLength, creating the
// ValueValidation first when s has none.
func withMaxLength(s schema.Structural, maxLength *int64) schema.Structural {
	validation := s.ValueValidation
	if validation == nil {
		validation = &schema.ValueValidation{}
		s.ValueValidation = validation
	}
	validation.MaxLength = maxLength
	return s
}
// withMaxItems returns s with ValueValidation.MaxItems set to maxItems, creating the
// ValueValidation first when s has none.
func withMaxItems(s schema.Structural, maxItems *int64) schema.Structural {
	validation := s.ValueValidation
	if validation == nil {
		validation = &schema.ValueValidation{}
		s.ValueValidation = validation
	}
	validation.MaxItems = maxItems
	return s
}
func withDefault(dflt interface{}, s schema.Structural) schema.Structural {
s.Generic.Default = schema.JSON{Object: dflt}
return s

View File

@ -19,6 +19,7 @@ import (
"github.com/google/cel-go/checker/decls"
"github.com/google/cel-go/common/types"
"k8s.io/apiextensions-apiserver/pkg/apiserver/schema"
)