Improve CEL cost tests to catch unhandled estimates or types

This commit is contained in:
Jordan Liggitt 2024-06-26 21:38:48 -04:00
parent 92e3445e9d
commit 1d2ad282cf
No known key found for this signature in database
2 changed files with 48 additions and 2 deletions

View File

@ -17,6 +17,7 @@ limitations under the License.
package library
import (
"fmt"
"math"
"github.com/google/cel-go/checker"
@ -25,9 +26,28 @@ import (
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
"k8s.io/apiserver/pkg/cel"
)
// panicOnUnknown makes cost estimate functions panic on unrecognized functions.
// This is only set to true for unit tests.
var panicOnUnknown = false
// builtInFunctions is a list of functions used in cost tests that are not handled by CostEstimator.
var knownUnhandledFunctions = map[string]bool{
"uint": true,
"duration": true,
"bytes": true,
"timestamp": true,
"value": true,
"_==_": true,
"_&&_": true,
"_>_": true,
"!_": true,
"strings.quote": true,
}
// CostEstimator implements CEL's interpretable.ActualCostEstimator and checker.CostEstimator.
type CostEstimator struct {
// SizeEstimator provides a CostEstimator.EstimateSize that this CostEstimator will delegate size estimation
@ -106,7 +126,7 @@ func (l *CostEstimator) CallCost(function, overloadId string, args []ref.Val, re
cost := uint64(math.Ceil(float64(actualSize(args[0])) * 2 * common.StringTraversalCostFactor))
return &cost
}
case "masked", "prefixLength", "family", "isUnspecified", "isLoopback", "isLinkLocalMulticast", "isLinkLocalUnicast":
case "masked", "prefixLength", "family", "isUnspecified", "isLoopback", "isLinkLocalMulticast", "isLinkLocalUnicast", "isGlobalUnicast":
// IP and CIDR accessors are nominal cost.
cost := uint64(1)
return &cost
@ -185,6 +205,13 @@ func (l *CostEstimator) CallCost(function, overloadId string, args []ref.Val, re
case "sign", "asInteger", "isInteger", "asApproximateFloat", "isGreaterThan", "isLessThan", "compareTo", "add", "sub":
cost := uint64(1)
return &cost
case "getScheme", "getHostname", "getHost", "getPort", "getEscapedPath", "getQuery":
// url accessors
cost := uint64(1)
return &cost
}
if panicOnUnknown && !knownUnhandledFunctions[function] {
panic(fmt.Errorf("CallCost: unhandled function %q or args %v", function, args))
}
return nil
}
@ -359,7 +386,7 @@ func (l *CostEstimator) EstimateCallCost(function, overloadId string, target *ch
// So we double the cost of parsing the string.
return &checker.CallEstimate{CostEstimate: sz.MultiplyByCostFactor(2 * common.StringTraversalCostFactor)}
}
case "masked", "prefixLength", "family", "isUnspecified", "isLoopback", "isLinkLocalMulticast", "isLinkLocalUnicast":
case "masked", "prefixLength", "family", "isUnspecified", "isLoopback", "isLinkLocalMulticast", "isLinkLocalUnicast", "isGlobalUnicast":
// IP and CIDR accessors are nominal cost.
return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 1, Max: 1}}
case "containsIP":
@ -414,6 +441,12 @@ func (l *CostEstimator) EstimateCallCost(function, overloadId string, target *ch
return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 1, Max: 1}}
case "sign", "asInteger", "isInteger", "asApproximateFloat", "isGreaterThan", "isLessThan", "compareTo", "add", "sub":
return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 1, Max: 1}}
case "getScheme", "getHostname", "getHost", "getPort", "getEscapedPath", "getQuery":
// url accessors
return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 1, Max: 1}}
}
if panicOnUnknown && !knownUnhandledFunctions[function] {
panic(fmt.Errorf("EstimateCallCost: unhandled function %q, target %v, args %v", function, target, args))
}
return nil
}
@ -422,6 +455,10 @@ func actualSize(value ref.Val) uint64 {
if sz, ok := value.(traits.Sizer); ok {
return uint64(sz.Size().(types.Int))
}
if panicOnUnknown {
// debug.PrintStack()
panic(fmt.Errorf("actualSize: non-sizer type %T", value))
}
return 1
}

View File

@ -1053,6 +1053,10 @@ func TestSetsCost(t *testing.T) {
}
func testCost(t *testing.T, expr string, expectEsimatedCost checker.CostEstimate, expectRuntimeCost uint64) {
originalPanicOnUnknown := panicOnUnknown
panicOnUnknown = true
t.Cleanup(func() { panicOnUnknown = originalPanicOnUnknown })
est := &CostEstimator{SizeEstimator: &testCostEstimator{}}
env, err := cel.NewEnv(
ext.Strings(ext.StringsVersion(2)),
@ -1168,6 +1172,11 @@ func TestSize(t *testing.T) {
expectSize: checker.SizeEstimate{Min: 2, Max: 4},
},
}
originalPanicOnUnknown := panicOnUnknown
panicOnUnknown = true
t.Cleanup(func() { panicOnUnknown = originalPanicOnUnknown })
est := &CostEstimator{SizeEstimator: &testCostEstimator{}}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {