mirror of
https://github.com/k3s-io/kubernetes.git
synced 2025-08-18 08:09:58 +00:00
Merge pull request #126359 from jpbetz/quantity-estimated-cost
Fix estimated cost for Kubernetes defined CEL types for equals
This commit is contained in:
commit
f88281768c
@ -2010,14 +2010,14 @@ func TestCelEstimatedCostStability(t *testing.T) {
|
|||||||
`isQuantity(self.val2)`: 314575,
|
`isQuantity(self.val2)`: 314575,
|
||||||
`isQuantity("200M")`: 1,
|
`isQuantity("200M")`: 1,
|
||||||
`isQuantity("20Mi")`: 1,
|
`isQuantity("20Mi")`: 1,
|
||||||
`quantity("200M") == quantity("0.2G") && quantity("0.2G") == quantity("200M")`: uint64(3689348814741910532),
|
`quantity("200M") == quantity("0.2G") && quantity("0.2G") == quantity("200M")`: uint64(6),
|
||||||
`quantity("2M") == quantity("0.002G") && quantity("2000k") == quantity("2M") && quantity("0.002G") == quantity("2000k")`: uint64(5534023222112865798),
|
`quantity("2M") == quantity("0.002G") && quantity("2000k") == quantity("2M") && quantity("0.002G") == quantity("2000k")`: uint64(9),
|
||||||
`quantity(self.val1).isLessThan(quantity(self.val2))`: 629151,
|
`quantity(self.val1).isLessThan(quantity(self.val2))`: 629151,
|
||||||
`quantity("50M").isLessThan(quantity("100M"))`: 3,
|
`quantity("50M").isLessThan(quantity("100M"))`: 3,
|
||||||
`quantity("50Mi").isGreaterThan(quantity("50M"))`: 3,
|
`quantity("50Mi").isGreaterThan(quantity("50M"))`: 3,
|
||||||
`quantity("200M").compareTo(quantity("0.2G")) == 0`: 4,
|
`quantity("200M").compareTo(quantity("0.2G")) == 0`: 4,
|
||||||
`quantity("50k").add(quantity("20")) == quantity("50.02k")`: uint64(1844674407370955268),
|
`quantity("50k").add(quantity("20")) == quantity("50.02k")`: uint64(5),
|
||||||
`quantity("50k").sub(20) == quantity("49980")`: uint64(1844674407370955267),
|
`quantity("50k").sub(20) == quantity("49980")`: uint64(4),
|
||||||
`quantity("50").isInteger()`: 2,
|
`quantity("50").isInteger()`: 2,
|
||||||
`quantity(self.val1).isInteger()`: 314576,
|
`quantity(self.val1).isInteger()`: 314576,
|
||||||
},
|
},
|
||||||
|
@ -19,6 +19,7 @@ package library
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
|
"reflect"
|
||||||
|
|
||||||
"github.com/google/cel-go/checker"
|
"github.com/google/cel-go/checker"
|
||||||
"github.com/google/cel-go/common"
|
"github.com/google/cel-go/common"
|
||||||
@ -27,6 +28,7 @@ import (
|
|||||||
"github.com/google/cel-go/common/types/ref"
|
"github.com/google/cel-go/common/types/ref"
|
||||||
"github.com/google/cel-go/common/types/traits"
|
"github.com/google/cel-go/common/types/traits"
|
||||||
|
|
||||||
|
"k8s.io/apimachinery/pkg/util/sets"
|
||||||
"k8s.io/apiserver/pkg/cel"
|
"k8s.io/apiserver/pkg/cel"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -48,6 +50,22 @@ var knownUnhandledFunctions = map[string]bool{
|
|||||||
"strings.quote": true,
|
"strings.quote": true,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: Replace this with a utility that extracts types from libraries.
|
||||||
|
var knownKubernetesRuntimeTypes = sets.New[reflect.Type](
|
||||||
|
reflect.ValueOf(cel.URL{}).Type(),
|
||||||
|
reflect.ValueOf(cel.IP{}).Type(),
|
||||||
|
reflect.ValueOf(cel.CIDR{}).Type(),
|
||||||
|
reflect.ValueOf(&cel.Format{}).Type(),
|
||||||
|
reflect.ValueOf(cel.Quantity{}).Type(),
|
||||||
|
)
|
||||||
|
var knownKubernetesCompilerTypes = sets.New[ref.Type](
|
||||||
|
cel.CIDRType,
|
||||||
|
cel.IPType,
|
||||||
|
cel.FormatType,
|
||||||
|
cel.QuantityType,
|
||||||
|
cel.URLType,
|
||||||
|
)
|
||||||
|
|
||||||
// CostEstimator implements CEL's interpretable.ActualCostEstimator and checker.CostEstimator.
|
// CostEstimator implements CEL's interpretable.ActualCostEstimator and checker.CostEstimator.
|
||||||
type CostEstimator struct {
|
type CostEstimator struct {
|
||||||
// SizeEstimator provides a CostEstimator.EstimateSize that this CostEstimator will delegate size estimation
|
// SizeEstimator provides a CostEstimator.EstimateSize that this CostEstimator will delegate size estimation
|
||||||
@ -235,6 +253,27 @@ func (l *CostEstimator) CallCost(function, overloadId string, args []ref.Val, re
|
|||||||
// url accessors
|
// url accessors
|
||||||
cost := uint64(1)
|
cost := uint64(1)
|
||||||
return &cost
|
return &cost
|
||||||
|
case "_==_":
|
||||||
|
if len(args) == 2 {
|
||||||
|
unitCost := uint64(1)
|
||||||
|
lhs := args[0]
|
||||||
|
switch lhs.(type) {
|
||||||
|
case cel.Quantity:
|
||||||
|
return &unitCost
|
||||||
|
case cel.IP:
|
||||||
|
return &unitCost
|
||||||
|
case cel.CIDR:
|
||||||
|
return &unitCost
|
||||||
|
case *cel.Format: // Formats have a small max size.
|
||||||
|
return &unitCost
|
||||||
|
case cel.URL: // TODO: Computing the actual cost is expensive, and changing this would be a breaking change
|
||||||
|
return &unitCost
|
||||||
|
default:
|
||||||
|
if panicOnUnknown && knownKubernetesRuntimeTypes.Has(reflect.ValueOf(lhs).Type()) {
|
||||||
|
panic(fmt.Errorf("CallCost: unhandled equality for Kubernetes type %T", lhs))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if panicOnUnknown && !knownUnhandledFunctions[function] {
|
if panicOnUnknown && !knownUnhandledFunctions[function] {
|
||||||
panic(fmt.Errorf("CallCost: unhandled function %q or args %v", function, args))
|
panic(fmt.Errorf("CallCost: unhandled function %q or args %v", function, args))
|
||||||
@ -278,7 +317,7 @@ func (l *CostEstimator) EstimateCallCost(function, overloadId string, target *ch
|
|||||||
case "url":
|
case "url":
|
||||||
if len(args) == 1 {
|
if len(args) == 1 {
|
||||||
sz := l.sizeEstimate(args[0])
|
sz := l.sizeEstimate(args[0])
|
||||||
return &checker.CallEstimate{CostEstimate: sz.MultiplyByCostFactor(common.StringTraversalCostFactor)}
|
return &checker.CallEstimate{CostEstimate: sz.MultiplyByCostFactor(common.StringTraversalCostFactor), ResultSize: &sz}
|
||||||
}
|
}
|
||||||
case "lowerAscii", "upperAscii", "substring", "trim":
|
case "lowerAscii", "upperAscii", "substring", "trim":
|
||||||
if target != nil {
|
if target != nil {
|
||||||
@ -475,6 +514,39 @@ func (l *CostEstimator) EstimateCallCost(function, overloadId string, target *ch
|
|||||||
case "getScheme", "getHostname", "getHost", "getPort", "getEscapedPath", "getQuery":
|
case "getScheme", "getHostname", "getHost", "getPort", "getEscapedPath", "getQuery":
|
||||||
// url accessors
|
// url accessors
|
||||||
return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 1, Max: 1}}
|
return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 1, Max: 1}}
|
||||||
|
case "_==_":
|
||||||
|
if len(args) == 2 {
|
||||||
|
lhs := args[0]
|
||||||
|
rhs := args[1]
|
||||||
|
if lhs.Type().Equal(rhs.Type()) == types.True {
|
||||||
|
t := lhs.Type()
|
||||||
|
if t.Kind() == types.OpaqueKind {
|
||||||
|
switch t.TypeName() {
|
||||||
|
case cel.IPType.TypeName(), cel.CIDRType.TypeName():
|
||||||
|
return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 1, Max: 1}}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if t.Kind() == types.StructKind {
|
||||||
|
switch t {
|
||||||
|
case cel.QuantityType: // O(1) cost equality checks
|
||||||
|
return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 1, Max: 1}}
|
||||||
|
case cel.FormatType:
|
||||||
|
return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 1, Max: cel.MaxFormatSize}.MultiplyByCostFactor(common.StringTraversalCostFactor)}
|
||||||
|
case cel.URLType:
|
||||||
|
size := checker.SizeEstimate{Min: 1, Max: 1}
|
||||||
|
rhSize := rhs.ComputedSize()
|
||||||
|
lhSize := rhs.ComputedSize()
|
||||||
|
if rhSize != nil && lhSize != nil {
|
||||||
|
size = rhSize.Union(*lhSize)
|
||||||
|
}
|
||||||
|
return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 1, Max: size.Max}.MultiplyByCostFactor(common.StringTraversalCostFactor)}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if panicOnUnknown && knownKubernetesCompilerTypes.Has(t) {
|
||||||
|
panic(fmt.Errorf("EstimateCallCost: unhandled equality for Kubernetes type %v", t))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if panicOnUnknown && !knownUnhandledFunctions[function] {
|
if panicOnUnknown && !knownUnhandledFunctions[function] {
|
||||||
panic(fmt.Errorf("EstimateCallCost: unhandled function %q, target %v, args %v", function, target, args))
|
panic(fmt.Errorf("EstimateCallCost: unhandled function %q, target %v, args %v", function, target, args))
|
||||||
|
@ -206,6 +206,16 @@ func TestURLsCost(t *testing.T) {
|
|||||||
expectEsimatedCost: checker.CostEstimate{Min: 4, Max: 4},
|
expectEsimatedCost: checker.CostEstimate{Min: 4, Max: 4},
|
||||||
expectRuntimeCost: 4,
|
expectRuntimeCost: 4,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
ops: []string{" == url('https:://kubernetes.io/')"},
|
||||||
|
expectEsimatedCost: checker.CostEstimate{Min: 7, Max: 9},
|
||||||
|
expectRuntimeCost: 7,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ops: []string{" == url('http://x.b')"},
|
||||||
|
expectEsimatedCost: checker.CostEstimate{Min: 5, Max: 5},
|
||||||
|
expectRuntimeCost: 5,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tc := range cases {
|
for _, tc := range cases {
|
||||||
@ -245,6 +255,14 @@ func TestIPCost(t *testing.T) {
|
|||||||
},
|
},
|
||||||
expectRuntimeCost: func(c uint64) uint64 { return c + 1 },
|
expectRuntimeCost: func(c uint64) uint64 { return c + 1 },
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
ops: []string{" == ip('192.168.0.1')"},
|
||||||
|
// For most other operations, the cost is expected to be the base + 1.
|
||||||
|
expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate {
|
||||||
|
return c.Add(ipv4BaseEstimatedCost).Add(checker.CostEstimate{Min: 1, Max: 1})
|
||||||
|
},
|
||||||
|
expectRuntimeCost: func(c uint64) uint64 { return c + ipv4BaseRuntimeCost + 1 },
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
@ -320,6 +338,14 @@ func TestCIDRCost(t *testing.T) {
|
|||||||
},
|
},
|
||||||
expectRuntimeCost: func(c uint64) uint64 { return c + 1 },
|
expectRuntimeCost: func(c uint64) uint64 { return c + 1 },
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
ops: []string{" == cidr('2001:db8::/32')"},
|
||||||
|
// For most other operations, the cost is expected to be the base + 1.
|
||||||
|
expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate {
|
||||||
|
return c.Add(ipv6BaseEstimatedCost).Add(checker.CostEstimate{Min: 1, Max: 1})
|
||||||
|
},
|
||||||
|
expectRuntimeCost: func(c uint64) uint64 { return c + ipv6BaseRuntimeCost + 1 },
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
//nolint:gocritic
|
//nolint:gocritic
|
||||||
@ -708,19 +734,19 @@ func TestQuantityCost(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "equality_reflexivity",
|
name: "equality_reflexivity",
|
||||||
expr: `quantity("200M") == quantity("200M")`,
|
expr: `quantity("200M") == quantity("200M")`,
|
||||||
expectEstimatedCost: checker.CostEstimate{Min: 3, Max: 1844674407370955266},
|
expectEstimatedCost: checker.CostEstimate{Min: 3, Max: 3},
|
||||||
expectRuntimeCost: 3,
|
expectRuntimeCost: 3,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "equality_symmetry",
|
name: "equality_symmetry",
|
||||||
expr: `quantity("200M") == quantity("0.2G") && quantity("0.2G") == quantity("200M")`,
|
expr: `quantity("200M") == quantity("0.2G") && quantity("0.2G") == quantity("200M")`,
|
||||||
expectEstimatedCost: checker.CostEstimate{Min: 3, Max: 3689348814741910532},
|
expectEstimatedCost: checker.CostEstimate{Min: 3, Max: 6},
|
||||||
expectRuntimeCost: 6,
|
expectRuntimeCost: 6,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "equality_transitivity",
|
name: "equality_transitivity",
|
||||||
expr: `quantity("2M") == quantity("0.002G") && quantity("2000k") == quantity("2M") && quantity("0.002G") == quantity("2000k")`,
|
expr: `quantity("2M") == quantity("0.002G") && quantity("2000k") == quantity("2M") && quantity("0.002G") == quantity("2000k")`,
|
||||||
expectEstimatedCost: checker.CostEstimate{Min: 3, Max: 5534023222112865798},
|
expectEstimatedCost: checker.CostEstimate{Min: 3, Max: 9},
|
||||||
expectRuntimeCost: 9,
|
expectRuntimeCost: 9,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -744,19 +770,19 @@ func TestQuantityCost(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "add_quantity",
|
name: "add_quantity",
|
||||||
expr: `quantity("50k").add(quantity("20")) == quantity("50.02k")`,
|
expr: `quantity("50k").add(quantity("20")) == quantity("50.02k")`,
|
||||||
expectEstimatedCost: checker.CostEstimate{Min: 5, Max: 1844674407370955268},
|
expectEstimatedCost: checker.CostEstimate{Min: 5, Max: 5},
|
||||||
expectRuntimeCost: 5,
|
expectRuntimeCost: 5,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "sub_quantity",
|
name: "sub_quantity",
|
||||||
expr: `quantity("50k").sub(quantity("20")) == quantity("49.98k")`,
|
expr: `quantity("50k").sub(quantity("20")) == quantity("49.98k")`,
|
||||||
expectEstimatedCost: checker.CostEstimate{Min: 5, Max: 1844674407370955268},
|
expectEstimatedCost: checker.CostEstimate{Min: 5, Max: 5},
|
||||||
expectRuntimeCost: 5,
|
expectRuntimeCost: 5,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "sub_int",
|
name: "sub_int",
|
||||||
expr: `quantity("50k").sub(20) == quantity("49980")`,
|
expr: `quantity("50k").sub(20) == quantity("49980")`,
|
||||||
expectEstimatedCost: checker.CostEstimate{Min: 4, Max: 1844674407370955267},
|
expectEstimatedCost: checker.CostEstimate{Min: 4, Max: 4},
|
||||||
expectRuntimeCost: 4,
|
expectRuntimeCost: 4,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -825,6 +851,18 @@ func TestNameFormatCost(t *testing.T) {
|
|||||||
expectEstimatedCost: checker.CostEstimate{Min: 34, Max: 34},
|
expectEstimatedCost: checker.CostEstimate{Min: 34, Max: 34},
|
||||||
expectRuntimeCost: 10,
|
expectRuntimeCost: 10,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "format.dns1123label.validate",
|
||||||
|
expr: `format.named("dns1123Label").value().validate("my-name")`,
|
||||||
|
expectEstimatedCost: checker.CostEstimate{Min: 34, Max: 34},
|
||||||
|
expectRuntimeCost: 10,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "format.dns1123label.validate",
|
||||||
|
expr: `format.named("dns1123Label").value() == format.named("dns1123Label").value()`,
|
||||||
|
expectEstimatedCost: checker.CostEstimate{Min: 5, Max: 11},
|
||||||
|
expectRuntimeCost: 5,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tc := range cases {
|
for _, tc := range cases {
|
||||||
|
@ -22,6 +22,8 @@ import (
|
|||||||
|
|
||||||
"github.com/google/cel-go/common/types"
|
"github.com/google/cel-go/common/types"
|
||||||
"github.com/google/cel-go/common/types/ref"
|
"github.com/google/cel-go/common/types/ref"
|
||||||
|
|
||||||
|
"k8s.io/apiserver/pkg/cel"
|
||||||
"k8s.io/apiserver/pkg/cel/library"
|
"k8s.io/apiserver/pkg/cel/library"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -228,3 +230,11 @@ func TestFormat(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSizeLimit(t *testing.T) {
|
||||||
|
for name := range library.ConstantFormats {
|
||||||
|
if len(name) > cel.MaxFormatSize {
|
||||||
|
t.Fatalf("All formats must be <= %d chars in length", cel.MaxFormatSize)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -48,5 +48,7 @@ const (
|
|||||||
// MinNumberSize is the length of literal 0
|
// MinNumberSize is the length of literal 0
|
||||||
MinNumberSize = 1
|
MinNumberSize = 1
|
||||||
|
|
||||||
|
// MaxFormatSize is the maximum size we allow for format strings
|
||||||
|
MaxFormatSize = 64
|
||||||
MaxNameFormatRegexSize = 128
|
MaxNameFormatRegexSize = 128
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user