From 1d2ad282cff163e51e5c24569a0ac762ed814e74 Mon Sep 17 00:00:00 2001 From: Jordan Liggitt Date: Wed, 26 Jun 2024 21:38:48 -0400 Subject: [PATCH] Improve CEL cost tests to catch unhandled estimates or types --- .../k8s.io/apiserver/pkg/cel/library/cost.go | 41 ++++++++++++++++++- .../apiserver/pkg/cel/library/cost_test.go | 9 ++++ 2 files changed, 48 insertions(+), 2 deletions(-) diff --git a/staging/src/k8s.io/apiserver/pkg/cel/library/cost.go b/staging/src/k8s.io/apiserver/pkg/cel/library/cost.go index 2be62c0b124..79980d4bcbb 100644 --- a/staging/src/k8s.io/apiserver/pkg/cel/library/cost.go +++ b/staging/src/k8s.io/apiserver/pkg/cel/library/cost.go @@ -17,6 +17,7 @@ limitations under the License. package library import ( + "fmt" "math" "github.com/google/cel-go/checker" @@ -25,9 +26,28 @@ import ( "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" "github.com/google/cel-go/common/types/traits" + "k8s.io/apiserver/pkg/cel" ) +// panicOnUnknown makes cost estimate functions panic on unrecognized functions. +// This is only set to true for unit tests. +var panicOnUnknown = false + +// builtInFunctions is a list of functions used in cost tests that are not handled by CostEstimator. +var knownUnhandledFunctions = map[string]bool{ + "uint": true, + "duration": true, + "bytes": true, + "timestamp": true, + "value": true, + "_==_": true, + "_&&_": true, + "_>_": true, + "!_": true, + "strings.quote": true, +} + // CostEstimator implements CEL's interpretable.ActualCostEstimator and checker.CostEstimator. type CostEstimator struct { // SizeEstimator provides a CostEstimator.EstimateSize that this CostEstimator will delegate size estimation @@ -106,7 +126,7 @@ func (l *CostEstimator) CallCost(function, overloadId string, args []ref.Val, re cost := uint64(math.Ceil(float64(actualSize(args[0])) * 2 * common.StringTraversalCostFactor)) return &cost } - case "masked", "prefixLength", "family", "isUnspecified", "isLoopback", "isLinkLocalMulticast", "isLinkLocalUnicast": + case "masked", "prefixLength", "family", "isUnspecified", "isLoopback", "isLinkLocalMulticast", "isLinkLocalUnicast", "isGlobalUnicast": // IP and CIDR accessors are nominal cost. cost := uint64(1) return &cost @@ -185,6 +205,13 @@ func (l *CostEstimator) CallCost(function, overloadId string, args []ref.Val, re case "sign", "asInteger", "isInteger", "asApproximateFloat", "isGreaterThan", "isLessThan", "compareTo", "add", "sub": cost := uint64(1) return &cost + case "getScheme", "getHostname", "getHost", "getPort", "getEscapedPath", "getQuery": + // url accessors + cost := uint64(1) + return &cost + } + if panicOnUnknown && !knownUnhandledFunctions[function] { + panic(fmt.Errorf("CallCost: unhandled function %q or args %v", function, args)) } return nil } @@ -359,7 +386,7 @@ func (l *CostEstimator) EstimateCallCost(function, overloadId string, target *ch // So we double the cost of parsing the string. return &checker.CallEstimate{CostEstimate: sz.MultiplyByCostFactor(2 * common.StringTraversalCostFactor)} } - case "masked", "prefixLength", "family", "isUnspecified", "isLoopback", "isLinkLocalMulticast", "isLinkLocalUnicast": + case "masked", "prefixLength", "family", "isUnspecified", "isLoopback", "isLinkLocalMulticast", "isLinkLocalUnicast", "isGlobalUnicast": // IP and CIDR accessors are nominal cost. return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 1, Max: 1}} case "containsIP": @@ -414,6 +441,12 @@ func (l *CostEstimator) EstimateCallCost(function, overloadId string, target *ch return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 1, Max: 1}} case "sign", "asInteger", "isInteger", "asApproximateFloat", "isGreaterThan", "isLessThan", "compareTo", "add", "sub": return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 1, Max: 1}} + case "getScheme", "getHostname", "getHost", "getPort", "getEscapedPath", "getQuery": + // url accessors + return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 1, Max: 1}} + } + if panicOnUnknown && !knownUnhandledFunctions[function] { + panic(fmt.Errorf("EstimateCallCost: unhandled function %q, target %v, args %v", function, target, args)) } return nil } @@ -422,6 +455,10 @@ func actualSize(value ref.Val) uint64 { if sz, ok := value.(traits.Sizer); ok { return uint64(sz.Size().(types.Int)) } + if panicOnUnknown { + // debug.PrintStack() + panic(fmt.Errorf("actualSize: non-sizer type %T", value)) + } return 1 } diff --git a/staging/src/k8s.io/apiserver/pkg/cel/library/cost_test.go b/staging/src/k8s.io/apiserver/pkg/cel/library/cost_test.go index e42cebbf4dc..b46591cd2d4 100644 --- a/staging/src/k8s.io/apiserver/pkg/cel/library/cost_test.go +++ b/staging/src/k8s.io/apiserver/pkg/cel/library/cost_test.go @@ -1053,6 +1053,10 @@ func TestSetsCost(t *testing.T) { } func testCost(t *testing.T, expr string, expectEsimatedCost checker.CostEstimate, expectRuntimeCost uint64) { + originalPanicOnUnknown := panicOnUnknown + panicOnUnknown = true + t.Cleanup(func() { panicOnUnknown = originalPanicOnUnknown }) + est := &CostEstimator{SizeEstimator: &testCostEstimator{}} env, err := cel.NewEnv( ext.Strings(ext.StringsVersion(2)), @@ -1168,6 +1172,11 @@ func TestSize(t *testing.T) { expectSize: checker.SizeEstimate{Min: 2, Max: 4}, }, } + + originalPanicOnUnknown := panicOnUnknown + panicOnUnknown = true + t.Cleanup(func() { panicOnUnknown = originalPanicOnUnknown }) + est := &CostEstimator{SizeEstimator: &testCostEstimator{}} for _, tc := range cases { t.Run(tc.name, func(t *testing.T) {