Fix estimated cost for Kubernetes defined CEL types

This commit is contained in:
Joe Betz 2024-07-25 14:14:20 -04:00
parent b95f9c32d6
commit 0a4e863373
6 changed files with 98 additions and 9 deletions

View File

@ -33,7 +33,7 @@ type CIDR struct {
}
var (
CIDRType = cel.OpaqueType("net.CIDR")
CIDRType = cel.ObjectType("net.CIDR")
)
// ConvertToNative implements ref.Val.ConvertToNative.

View File

@ -33,7 +33,7 @@ type IP struct {
}
var (
IPType = cel.OpaqueType("net.IP")
IPType = cel.ObjectType("net.IP")
)
// ConvertToNative implements ref.Val.ConvertToNative.

View File

@ -235,6 +235,23 @@ func (l *CostEstimator) CallCost(function, overloadId string, args []ref.Val, re
// url accessors
cost := uint64(1)
return &cost
case "_==_":
if len(args) == 2 {
unitCost := uint64(1)
lhs := args[0]
switch lhs.(type) {
case cel.Quantity:
return &unitCost
case cel.IP:
return &unitCost
case cel.CIDR:
return &unitCost
case *cel.Format: // Formats have a small max size.
return &unitCost
case cel.URL: // TODO: Computing the actual cost is expensive, and changing this would be a breaking change
return &unitCost
}
}
}
if panicOnUnknown && !knownUnhandledFunctions[function] {
panic(fmt.Errorf("CallCost: unhandled function %q or args %v", function, args))
@ -278,7 +295,7 @@ func (l *CostEstimator) EstimateCallCost(function, overloadId string, target *ch
case "url":
if len(args) == 1 {
sz := l.sizeEstimate(args[0])
return &checker.CallEstimate{CostEstimate: sz.MultiplyByCostFactor(common.StringTraversalCostFactor)}
return &checker.CallEstimate{CostEstimate: sz.MultiplyByCostFactor(common.StringTraversalCostFactor), ResultSize: &sz}
}
case "lowerAscii", "upperAscii", "substring", "trim":
if target != nil {
@ -475,6 +492,28 @@ func (l *CostEstimator) EstimateCallCost(function, overloadId string, target *ch
case "getScheme", "getHostname", "getHost", "getPort", "getEscapedPath", "getQuery":
// url accessors
return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 1, Max: 1}}
case "_==_":
if len(args) == 2 {
lhs := args[0]
rhs := args[1]
if lhs.Type().Equal(rhs.Type()) == types.True {
t := lhs.Type()
switch t {
case cel.IPType, cel.CIDRType, cel.QuantityType: // O(1) cost equality checks
return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 1, Max: 1}}
case cel.FormatType:
return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 1, Max: cel.MaxFormatSize}.MultiplyByCostFactor(common.StringTraversalCostFactor)}
case cel.URLType:
size := checker.SizeEstimate{Min: 1, Max: 1}
rhSize := rhs.ComputedSize()
lhSize := rhs.ComputedSize()
if rhSize != nil && lhSize != nil {
size = rhSize.Union(*lhSize)
}
return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 1, Max: size.Max}.MultiplyByCostFactor(common.StringTraversalCostFactor)}
}
}
}
}
if panicOnUnknown && !knownUnhandledFunctions[function] {
panic(fmt.Errorf("EstimateCallCost: unhandled function %q, target %v, args %v", function, target, args))

View File

@ -206,6 +206,16 @@ func TestURLsCost(t *testing.T) {
expectEsimatedCost: checker.CostEstimate{Min: 4, Max: 4},
expectRuntimeCost: 4,
},
{
ops: []string{" == url('https:://kubernetes.io/')"},
expectEsimatedCost: checker.CostEstimate{Min: 7, Max: 9},
expectRuntimeCost: 7,
},
{
ops: []string{" == url('http://x.b')"},
expectEsimatedCost: checker.CostEstimate{Min: 5, Max: 5},
expectRuntimeCost: 5,
},
}
for _, tc := range cases {
@ -245,6 +255,14 @@ func TestIPCost(t *testing.T) {
},
expectRuntimeCost: func(c uint64) uint64 { return c + 1 },
},
{
ops: []string{" == ip('192.168.0.1')"},
// For most other operations, the cost is expected to be the base + 1.
expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate {
return c.Add(ipv4BaseEstimatedCost).Add(checker.CostEstimate{Min: 1, Max: 1})
},
expectRuntimeCost: func(c uint64) uint64 { return c + ipv4BaseRuntimeCost + 1 },
},
}
for _, tc := range testCases {
@ -320,6 +338,14 @@ func TestCIDRCost(t *testing.T) {
},
expectRuntimeCost: func(c uint64) uint64 { return c + 1 },
},
{
ops: []string{" == cidr('2001:db8::/32')"},
// For most other operations, the cost is expected to be the base + 1.
expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate {
return c.Add(ipv6BaseEstimatedCost).Add(checker.CostEstimate{Min: 1, Max: 1})
},
expectRuntimeCost: func(c uint64) uint64 { return c + ipv6BaseRuntimeCost + 1 },
},
}
//nolint:gocritic
@ -708,19 +734,19 @@ func TestQuantityCost(t *testing.T) {
{
name: "equality_reflexivity",
expr: `quantity("200M") == quantity("200M")`,
expectEstimatedCost: checker.CostEstimate{Min: 3, Max: 1844674407370955266},
expectEstimatedCost: checker.CostEstimate{Min: 3, Max: 3},
expectRuntimeCost: 3,
},
{
name: "equality_symmetry",
expr: `quantity("200M") == quantity("0.2G") && quantity("0.2G") == quantity("200M")`,
expectEstimatedCost: checker.CostEstimate{Min: 3, Max: 3689348814741910532},
expectEstimatedCost: checker.CostEstimate{Min: 3, Max: 6},
expectRuntimeCost: 6,
},
{
name: "equality_transitivity",
expr: `quantity("2M") == quantity("0.002G") && quantity("2000k") == quantity("2M") && quantity("0.002G") == quantity("2000k")`,
expectEstimatedCost: checker.CostEstimate{Min: 3, Max: 5534023222112865798},
expectEstimatedCost: checker.CostEstimate{Min: 3, Max: 9},
expectRuntimeCost: 9,
},
{
@ -744,19 +770,19 @@ func TestQuantityCost(t *testing.T) {
{
name: "add_quantity",
expr: `quantity("50k").add(quantity("20")) == quantity("50.02k")`,
expectEstimatedCost: checker.CostEstimate{Min: 5, Max: 1844674407370955268},
expectEstimatedCost: checker.CostEstimate{Min: 5, Max: 5},
expectRuntimeCost: 5,
},
{
name: "sub_quantity",
expr: `quantity("50k").sub(quantity("20")) == quantity("49.98k")`,
expectEstimatedCost: checker.CostEstimate{Min: 5, Max: 1844674407370955268},
expectEstimatedCost: checker.CostEstimate{Min: 5, Max: 5},
expectRuntimeCost: 5,
},
{
name: "sub_int",
expr: `quantity("50k").sub(20) == quantity("49980")`,
expectEstimatedCost: checker.CostEstimate{Min: 4, Max: 1844674407370955267},
expectEstimatedCost: checker.CostEstimate{Min: 4, Max: 4},
expectRuntimeCost: 4,
},
{
@ -825,6 +851,18 @@ func TestNameFormatCost(t *testing.T) {
expectEstimatedCost: checker.CostEstimate{Min: 34, Max: 34},
expectRuntimeCost: 10,
},
{
name: "format.dns1123label.validate",
expr: `format.named("dns1123Label").value().validate("my-name")`,
expectEstimatedCost: checker.CostEstimate{Min: 34, Max: 34},
expectRuntimeCost: 10,
},
{
name: "format.dns1123label.validate",
expr: `format.named("dns1123Label").value() == format.named("dns1123Label").value()`,
expectEstimatedCost: checker.CostEstimate{Min: 5, Max: 11},
expectRuntimeCost: 5,
},
}
for _, tc := range cases {

View File

@ -22,6 +22,8 @@ import (
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"k8s.io/apiserver/pkg/cel"
"k8s.io/apiserver/pkg/cel/library"
)
@ -228,3 +230,11 @@ func TestFormat(t *testing.T) {
})
}
}
func TestSizeLimit(t *testing.T) {
for name := range library.ConstantFormats {
if len(name) > cel.MaxFormatSize {
t.Fatalf("All formats must be <= %d chars in length", cel.MaxFormatSize)
}
}
}

View File

@ -48,5 +48,7 @@ const (
// MinNumberSize is the length of literal 0
MinNumberSize = 1
// MaxFormatSize is the maximum size we allow for format strings
MaxFormatSize = 64
MaxNameFormatRegexSize = 128
)