Merge pull request #129690 from pohly/automated-cherry-pick-of-#129661-origin-release-1.32

Automated cherry pick of #129661: DRA CEL: add missing size estimator
This commit is contained in:
Kubernetes Prow Robot 2025-03-03 05:21:16 -08:00 committed by GitHub
commit 8207fb465f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 203 additions and 41 deletions

View File

@ -141,6 +141,10 @@ type ResourceSliceSpec struct {
Devices []Device
}
// DriverNameMaxLength is the maximum valid length of a driver name in the
// ResourceSliceSpec and other places. It's the same as for CSI driver names.
const DriverNameMaxLength = 63
// ResourcePool describes the pool that ResourceSlices belong to.
type ResourcePool struct {
// Name is used to identify the pool. For node-local devices, this

View File

@ -529,7 +529,8 @@ func TestValidateClaim(t *testing.T) {
claim.Spec.Devices.Requests[0].Selectors = []resource.DeviceSelector{
{
CEL: &resource.CELDeviceSelector{
Expression: `device.attributes["dra.example.com"].map(s, s.lowerAscii()).map(s, s.size()).sum() == 0`,
// From https://github.com/kubernetes/kubernetes/blob/50fc400f178d2078d0ca46aee955ee26375fc437/test/integration/apiserver/cel/validatingadmissionpolicy_test.go#L2150.
Expression: `[1, 2, 3, 4, 5, 6, 7, 8, 9, 10].all(x, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10].all(y, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10].all(z, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10].all(z2, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10].all(z3, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10].all(z4, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10].all(z5, int('1'.find('[0-9]*')) < 100)))))))`,
},
},
}

View File

@ -145,6 +145,10 @@ type ResourceSliceSpec struct {
Devices []Device `json:"devices" protobuf:"bytes,6,name=devices"`
}
// DriverNameMaxLength is the maximum valid length of a driver name in the
// ResourceSliceSpec and other places. It's the same as for CSI driver names.
const DriverNameMaxLength = 63
// ResourcePool describes the pool that ResourceSlices belong to.
type ResourcePool struct {
// Name is used to identify the pool. For node-local devices, this

View File

@ -144,6 +144,10 @@ type ResourceSliceSpec struct {
Devices []Device `json:"devices" protobuf:"bytes,6,name=devices"`
}
// DriverNameMaxLength is the maximum valid length of a driver name in the
// ResourceSliceSpec and other places. It's the same as for CSI driver names.
const DriverNameMaxLength = 63
// ResourcePool describes the pool that ResourceSlices belong to.
type ResourcePool struct {
// Name is used to identify the pool. For node-local devices, this

View File

@ -43,6 +43,8 @@ func NewCache(maxCacheEntries int) *Cache {
// GetOrCompile checks whether the cache already has a compilation result
// and returns that if available. Otherwise it compiles, stores successful
// results and returns the new result.
//
// Cost estimation is disabled.
func (c *Cache) GetOrCompile(expression string) CompilationResult {
// Compiling a CEL expression is expensive enough that it is cheaper
// to lock a mutex than doing it several times in parallel.
@ -55,7 +57,7 @@ func (c *Cache) GetOrCompile(expression string) CompilationResult {
return *cached
}
expr := GetCompiler().CompileCELExpression(expression, Options{})
expr := GetCompiler().CompileCELExpression(expression, Options{DisableCostEstimation: true})
if expr.Error == nil {
c.add(expression, &expr)
}

View File

@ -18,6 +18,7 @@ package cel
import (
"fmt"
"math"
"sync"
"testing"
@ -73,6 +74,11 @@ func TestCacheSemantic(t *testing.T) {
if resultFalse == resultFalseAgain {
t.Fatal("result of compiling `false` should have been evicted from the cache")
}
// Cost estimation must be off (not needed by scheduler).
if resultFalseAgain.MaxCost != math.MaxUint64 {
t.Error("cost estimation should have been disabled, was enabled")
}
}
func TestCacheConcurrency(t *testing.T) {

View File

@ -20,12 +20,14 @@ import (
"context"
"errors"
"fmt"
"math"
"reflect"
"strings"
"sync"
"github.com/blang/semver/v4"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/checker"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
@ -50,6 +52,23 @@ const (
var (
lazyCompilerInit sync.Once
lazyCompiler *compiler
// A variant of AnyType = https://github.com/kubernetes/kubernetes/blob/ec2e0de35a298363872897e5904501b029817af3/staging/src/k8s.io/apiserver/pkg/cel/types.go#L550:
// unknown actual type (could be bool, int, string, etc.) but with a known maximum size.
attributeType = withMaxElements(apiservercel.AnyType, resourceapi.DeviceAttributeMaxValueLength)
// Other strings also have a known maximum size.
domainType = withMaxElements(apiservercel.StringType, resourceapi.DeviceMaxDomainLength)
idType = withMaxElements(apiservercel.StringType, resourceapi.DeviceMaxIDLength)
driverType = withMaxElements(apiservercel.StringType, resourceapi.DriverNameMaxLength)
// Each map is bound by the maximum number of different attributes.
innerAttributesMapType = apiservercel.NewMapType(idType, attributeType, resourceapi.ResourceSliceMaxAttributesAndCapacitiesPerDevice)
outerAttributesMapType = apiservercel.NewMapType(domainType, innerAttributesMapType, resourceapi.ResourceSliceMaxAttributesAndCapacitiesPerDevice)
// Same for capacity.
innerCapacityMapType = apiservercel.NewMapType(idType, apiservercel.QuantityDeclType, resourceapi.ResourceSliceMaxAttributesAndCapacitiesPerDevice)
outerCapacityMapType = apiservercel.NewMapType(domainType, innerCapacityMapType, resourceapi.ResourceSliceMaxAttributesAndCapacitiesPerDevice)
)
func GetCompiler() *compiler {
@ -85,13 +104,14 @@ type Device struct {
}
type compiler struct {
// deviceType is a definition for the type of the `device` variable.
// This is needed for the cost estimator. Both are currently version-independent.
// If that ever changes, some additional logic might be needed to make
// cost estimates version-dependent.
deviceType *apiservercel.DeclType
envset *environment.EnvSet
}
func newCompiler() *compiler {
return &compiler{envset: mustBuildEnv()}
}
// Options contains several additional parameters
// for [CompileCELExpression]. All of them have reasonable
// defaults.
@ -101,6 +121,10 @@ type Options struct {
// CostLimit allows overriding the default runtime cost limit [resourceapi.CELSelectorExpressionMaxCost].
CostLimit *uint64
// DisableCostEstimation can be set to skip estimating the worst-case CEL cost.
// If disabled or after an error, [CompilationResult.MaxCost] will be set to [math.Uint64].
DisableCostEstimation bool
}
// CompileCELExpression returns a compiled CEL expression. It evaluates to bool.
@ -114,6 +138,7 @@ func (c compiler) CompileCELExpression(expression string, options Options) Compi
Detail: errorString,
},
Expression: expression,
MaxCost: math.MaxUint64,
}
}
@ -122,10 +147,6 @@ func (c compiler) CompileCELExpression(expression string, options Options) Compi
return resultError(fmt.Sprintf("unexpected error loading CEL environment: %v", err), apiservercel.ErrorTypeInternal)
}
// We don't have a SizeEstimator. The potential size of the input (= a
// device) is already declared in the definition of the environment.
estimator := &library.CostEstimator{}
ast, issues := env.Compile(expression)
if issues != nil {
return resultError("compilation failed: "+issues.String(), apiservercel.ErrorTypeInvalid)
@ -157,18 +178,28 @@ func (c compiler) CompileCELExpression(expression string, options Options) Compi
OutputType: ast.OutputType(),
Environment: env,
emptyMapVal: env.CELTypeAdapter().NativeToValue(map[string]any{}),
MaxCost: math.MaxUint64,
}
if !options.DisableCostEstimation {
// We don't have a SizeEstimator. The potential size of the input (= a
// device) is already declared in the definition of the environment.
estimator := c.newCostEstimator()
costEst, err := env.EstimateCost(ast, estimator)
if err != nil {
compilationResult.Error = &apiservercel.Error{Type: apiservercel.ErrorTypeInternal, Detail: "cost estimation failed: " + err.Error()}
return compilationResult
}
compilationResult.MaxCost = costEst.Max
}
return compilationResult
}
func (c *compiler) newCostEstimator() *library.CostEstimator {
return &library.CostEstimator{SizeEstimator: &sizeEstimator{compiler: c}}
}
// getAttributeValue returns the native representation of the one value that
// should be stored in the attribute, otherwise an error. An error is
// also returned when there is no supported value.
@ -241,7 +272,7 @@ func (c CompilationResult) DeviceMatches(ctx context.Context, input Device) (boo
return resultBool, details, nil
}
func mustBuildEnv() *environment.EnvSet {
func newCompiler() *compiler {
envset := environment.MustBaseEnvSet(environment.DefaultCompatibilityVersion(), true /* strictCost */)
field := func(name string, declType *apiservercel.DeclType, required bool) *apiservercel.DeclField {
return apiservercel.NewDeclField(name, declType, required, nil, nil)
@ -253,10 +284,11 @@ func mustBuildEnv() *environment.EnvSet {
}
return result
}
deviceType := apiservercel.NewObjectType("kubernetes.DRADevice", fields(
field(driverVar, apiservercel.StringType, true),
field(attributesVar, apiservercel.NewMapType(apiservercel.StringType, apiservercel.NewMapType(apiservercel.StringType, apiservercel.AnyType, resourceapi.ResourceSliceMaxAttributesAndCapacitiesPerDevice), resourceapi.ResourceSliceMaxAttributesAndCapacitiesPerDevice), true),
field(capacityVar, apiservercel.NewMapType(apiservercel.StringType, apiservercel.NewMapType(apiservercel.StringType, apiservercel.QuantityDeclType, resourceapi.ResourceSliceMaxAttributesAndCapacitiesPerDevice), resourceapi.ResourceSliceMaxAttributesAndCapacitiesPerDevice), true),
field(driverVar, driverType, true),
field(attributesVar, outerAttributesMapType, true),
field(capacityVar, outerCapacityMapType, true),
))
versioned := []environment.VersionedOptions{
@ -284,7 +316,13 @@ func mustBuildEnv() *environment.EnvSet {
if err != nil {
panic(fmt.Errorf("internal error building CEL environment: %w", err))
}
return envset
return &compiler{envset: envset, deviceType: deviceType}
}
func withMaxElements(in *apiservercel.DeclType, maxElements uint64) *apiservercel.DeclType {
out := *in
out.MaxElements = int64(maxElements)
return &out
}
// parseQualifiedName splits into domain and identified, using the default domain
@ -322,3 +360,67 @@ func (m mapper) Find(key ref.Val) (ref.Val, bool) {
return m.defaultValue, true
}
// sizeEstimator tells the cost estimator the maximum size of maps or strings accessible through the `device` variable.
// Without this, the maximum string size of e.g. `device.attributes["dra.example.com"].services` would be unknown.
//
// sizeEstimator is derived from the sizeEstimator in k8s.io/apiextensions-apiserver/pkg/apiserver/schema/cel.
type sizeEstimator struct {
compiler *compiler
}
func (s *sizeEstimator) EstimateSize(element checker.AstNode) *checker.SizeEstimate {
path := element.Path()
if len(path) == 0 {
// Path() can return an empty list, early exit if it does since we can't
// provide size estimates when that happens
return nil
}
// The estimator provides information about the environment's variable(s).
var currentNode *apiservercel.DeclType
switch path[0] {
case deviceVar:
currentNode = s.compiler.deviceType
default:
// Unknown root, shouldn't happen.
return nil
}
// Cut off initial variable from path, it was checked above.
for _, name := range path[1:] {
switch name {
case "@items", "@values":
if currentNode.ElemType == nil {
return nil
}
currentNode = currentNode.ElemType
case "@keys":
if currentNode.KeyType == nil {
return nil
}
currentNode = currentNode.KeyType
default:
field, ok := currentNode.Fields[name]
if !ok {
// If this is an attribute map, then we know that all elements
// have the same maximum size as set in attributeType, regardless
// of their name.
if currentNode.ElemType == attributeType {
currentNode = attributeType
continue
}
return nil
}
if field.Type == nil {
return nil
}
currentNode = field.Type
}
}
return &checker.SizeEstimate{Min: 0, Max: uint64(currentNode.MaxElements)}
}
func (s *sizeEstimator) EstimateCallCost(function, overloadID string, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate {
return nil
}

View File

@ -203,6 +203,48 @@ device.attributes["dra.example.com"]["version"].isGreaterThan(semver("0.0.1"))
expectMatch: true,
expectCost: 12,
},
"check_attribute_domains": {
expression: `device.attributes.exists_one(x, x == "dra.example.com")`,
attributes: map[resourceapi.QualifiedName]resourceapi.DeviceAttribute{"services": {StringValue: ptr.To("some_example_value")}},
driver: "dra.example.com",
expectMatch: true,
expectCost: 164,
},
"check_attribute_ids": {
expression: `device.attributes["dra.example.com"].exists_one(x, x == "services")`,
attributes: map[resourceapi.QualifiedName]resourceapi.DeviceAttribute{"services": {StringValue: ptr.To("some_example_value")}},
driver: "dra.example.com",
expectMatch: true,
expectCost: 133,
},
"split_attribute": {
expression: `device.attributes["dra.example.com"].services.split("example").size() >= 2`,
attributes: map[resourceapi.QualifiedName]resourceapi.DeviceAttribute{"services": {StringValue: ptr.To("some_example_value")}},
driver: "dra.example.com",
expectMatch: true,
expectCost: 19,
},
"regexp_attribute": {
expression: `device.attributes["dra.example.com"].services.matches("[^a]?sym")`,
attributes: map[resourceapi.QualifiedName]resourceapi.DeviceAttribute{"services": {StringValue: ptr.To("asymetric")}},
driver: "dra.example.com",
expectMatch: true,
expectCost: 18,
},
"check_capacity_domains": {
expression: `device.capacity.exists_one(x, x == "dra.example.com")`,
capacity: map[resourceapi.QualifiedName]resourceapi.DeviceCapacity{"memory": {Value: resource.MustParse("1Mi")}},
driver: "dra.example.com",
expectMatch: true,
expectCost: 164,
},
"check_capacity_ids": {
expression: `device.capacity["dra.example.com"].exists_one(x, x == "memory")`,
capacity: map[resourceapi.QualifiedName]resourceapi.DeviceCapacity{"memory": {Value: resource.MustParse("1Mi")}},
driver: "dra.example.com",
expectMatch: true,
expectCost: 133,
},
"expensive": {
// The worst-case is based on the maximum number of
// attributes and the maximum attribute name length.
@ -214,21 +256,18 @@ device.attributes["dra.example.com"]["version"].isGreaterThan(semver("0.0.1"))
attribute := resourceapi.DeviceAttribute{
StringValue: ptr.To("abc"),
}
// If the cost estimate was accurate, using exactly as many attributes
// as allowed at most should exceed the limit. In practice, the estimate
// is an upper bound and significantly more attributes are needed before
// the runtime cost becomes too large.
for i := 0; i < 1000*resourceapi.ResourceSliceMaxAttributesAndCapacitiesPerDevice; i++ {
for i := 0; i < resourceapi.ResourceSliceMaxAttributesAndCapacitiesPerDevice; i++ {
suffix := fmt.Sprintf("-%d", i)
name := prefix + strings.Repeat("x", resourceapi.DeviceMaxIDLength-len(suffix)) + suffix
attributes[resourceapi.QualifiedName(name)] = attribute
}
return attributes
}(),
expression: `device.attributes["dra.example.com"].map(s, s.lowerAscii()).map(s, s.size()).sum() == 0`,
// From https://github.com/kubernetes/kubernetes/blob/50fc400f178d2078d0ca46aee955ee26375fc437/test/integration/apiserver/cel/validatingadmissionpolicy_test.go#L2150.
expression: `[1, 2, 3, 4, 5, 6, 7, 8, 9, 10].all(x, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10].all(y, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10].all(z, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10].all(z2, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10].all(z3, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10].all(z4, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10].all(z5, int('1'.find('[0-9]*')) < 100)))))))`,
driver: "dra.example.com",
expectMatchError: "actual cost limit exceeded",
expectCost: 18446744073709551615, // Exceeds limit!
expectCost: 85555551, // Exceed limit!
},
}
@ -238,50 +277,50 @@ func TestCEL(t *testing.T) {
_, ctx := ktesting.NewTestContext(t)
result := GetCompiler().CompileCELExpression(scenario.expression, Options{})
if scenario.expectCompileError != "" && result.Error == nil {
t.Fatalf("expected compile error %q, got none", scenario.expectCompileError)
t.Fatalf("FAILURE: expected compile error %q, got none", scenario.expectCompileError)
}
if result.Error != nil {
if scenario.expectCompileError == "" {
t.Fatalf("unexpected compile error: %v", result.Error)
t.Fatalf("FAILURE: unexpected compile error: %v", result.Error)
}
if !strings.Contains(result.Error.Error(), scenario.expectCompileError) {
t.Fatalf("expected compile error to contain %q, but got instead: %v", scenario.expectCompileError, result.Error)
t.Fatalf("FAILURE: expected compile error to contain %q, but got instead: %v", scenario.expectCompileError, result.Error)
}
return
}
if scenario.expectCompileError != "" {
t.Fatalf("expected compile error %q, got none", scenario.expectCompileError)
t.Fatalf("FAILURE: expected compile error %q, got none", scenario.expectCompileError)
}
if expect, actual := scenario.expectCost, result.MaxCost; expect != actual {
t.Errorf("expected CEL cost %d, got %d instead", expect, actual)
t.Errorf("ERROR: expected CEL cost %d, got %d instead (%.0f%% of limit %d)", expect, actual, float64(actual)*100.0/float64(resourceapi.CELSelectorExpressionMaxCost), resourceapi.CELSelectorExpressionMaxCost)
}
match, details, err := result.DeviceMatches(ctx, Device{Attributes: scenario.attributes, Capacity: scenario.capacity, Driver: scenario.driver})
// details.ActualCost can be called for nil details, no need to check.
actualCost := ptr.Deref(details.ActualCost(), 0)
if scenario.expectCost > 0 {
t.Logf("actual cost %d, %d%% of worst-case estimate", actualCost, actualCost*100/scenario.expectCost)
t.Logf("actual cost %d, %d%% of worst-case estimate %d", actualCost, actualCost*100/scenario.expectCost, scenario.expectCost)
} else {
t.Logf("actual cost %d, expected zero costs", actualCost)
if actualCost > 0 {
t.Errorf("expected zero costs for (presumably) constant expression %q, got instead %d", scenario.expression, actualCost)
}
if actualCost > result.MaxCost {
t.Errorf("ERROR: cost estimate %d underestimated the evaluation cost of %d", result.MaxCost, actualCost)
}
if err != nil {
if scenario.expectMatchError == "" {
t.Fatalf("unexpected evaluation error: %v", err)
t.Fatalf("FAILURE: unexpected evaluation error: %v", err)
}
if !strings.Contains(err.Error(), scenario.expectMatchError) {
t.Fatalf("expected evaluation error to contain %q, but got instead: %v", scenario.expectMatchError, err)
t.Fatalf("FAILURE: expected evaluation error to contain %q, but got instead: %v", scenario.expectMatchError, err)
}
return
}
if scenario.expectMatchError != "" {
t.Fatalf("expected match error %q, got none", scenario.expectMatchError)
t.Fatalf("FAILURE: expected match error %q, got none", scenario.expectMatchError)
}
if match != scenario.expectMatch {
t.Fatalf("expected result %v, got %v", scenario.expectMatch, match)
t.Fatalf("FAILURE: expected result %v, got %v", scenario.expectMatch, match)
}
})
}