DRA CEL: add missing size estimator

Not implementing a size estimator had the effect that strings retrieved from
the attributes were treated as "unknown size", leading to wildly overestimating
the cost and to validation errors even for simple expressions like this:

    device.attributes["qat.intel.com"].services.matches("[^a]?sym")

The maximum number of elements in maps was also ignored, and the maximum length
of the driver name string was missing. Pre-defined types like
apiservercel.StringType must be avoided because they are defined as having
a zero maximum size.
This commit is contained in:
Patrick Ohly 2025-01-16 11:50:13 +01:00
parent 6473e7b6ca
commit f89e4c08cf
6 changed files with 178 additions and 32 deletions

View File

@ -141,6 +141,10 @@ type ResourceSliceSpec struct {
Devices []Device
}
// DriverNameMaxLength is the maximum valid length of a driver name in the
// ResourceSliceSpec and other places. It's the same as for CSI driver names.
const DriverNameMaxLength = 63
// ResourcePool describes the pool that ResourceSlices belong to.
type ResourcePool struct {
// Name is used to identify the pool. For node-local devices, this

View File

@ -529,7 +529,8 @@ func TestValidateClaim(t *testing.T) {
claim.Spec.Devices.Requests[0].Selectors = []resource.DeviceSelector{
{
CEL: &resource.CELDeviceSelector{
Expression: `device.attributes["dra.example.com"].map(s, s.lowerAscii()).map(s, s.size()).sum() == 0`,
// From https://github.com/kubernetes/kubernetes/blob/50fc400f178d2078d0ca46aee955ee26375fc437/test/integration/apiserver/cel/validatingadmissionpolicy_test.go#L2150.
Expression: `[1, 2, 3, 4, 5, 6, 7, 8, 9, 10].all(x, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10].all(y, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10].all(z, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10].all(z2, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10].all(z3, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10].all(z4, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10].all(z5, int('1'.find('[0-9]*')) < 100)))))))`,
},
},
}

View File

@ -145,6 +145,10 @@ type ResourceSliceSpec struct {
Devices []Device `json:"devices" protobuf:"bytes,6,name=devices"`
}
// DriverNameMaxLength is the maximum valid length of a driver name in the
// ResourceSliceSpec and other places. It's the same as for CSI driver names.
const DriverNameMaxLength = 63
// ResourcePool describes the pool that ResourceSlices belong to.
type ResourcePool struct {
// Name is used to identify the pool. For node-local devices, this

View File

@ -144,6 +144,10 @@ type ResourceSliceSpec struct {
Devices []Device `json:"devices" protobuf:"bytes,6,name=devices"`
}
// DriverNameMaxLength is the maximum valid length of a driver name in the
// ResourceSliceSpec and other places. It's the same as for CSI driver names.
const DriverNameMaxLength = 63
// ResourcePool describes the pool that ResourceSlices belong to.
type ResourcePool struct {
// Name is used to identify the pool. For node-local devices, this

View File

@ -26,6 +26,7 @@ import (
"github.com/blang/semver/v4"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/checker"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
@ -50,6 +51,23 @@ const (
var (
// lazyCompilerInit and lazyCompiler hold the package-wide compiler instance.
// NOTE(review): presumably GetCompiler initializes lazyCompiler exactly once
// through lazyCompilerInit; its body is not visible here — confirm.
lazyCompilerInit sync.Once
lazyCompiler *compiler
// A variant of AnyType = https://github.com/kubernetes/kubernetes/blob/ec2e0de35a298363872897e5904501b029817af3/staging/src/k8s.io/apiserver/pkg/cel/types.go#L550:
// unknown actual type (could be bool, int, string, etc.) but with a known maximum size.
//
// EstimateSize compares against this exact pointer to recognize attribute
// maps, so attribute values must keep using this one instance.
attributeType = withMaxElements(apiservercel.AnyType, resourceapi.DeviceAttributeMaxValueLength)
// Other strings also have a known maximum size.
// Each one is a copy of apiservercel.StringType with an explicit limit,
// because the pre-defined StringType is declared with a zero maximum size.
domainType = withMaxElements(apiservercel.StringType, resourceapi.DeviceMaxDomainLength)
idType = withMaxElements(apiservercel.StringType, resourceapi.DeviceMaxIDLength)
driverType = withMaxElements(apiservercel.StringType, resourceapi.DriverNameMaxLength)
// Each map is bound by the maximum number of different attributes.
// The outer map is keyed by domain, the inner map by ID.
innerAttributesMapType = apiservercel.NewMapType(idType, attributeType, resourceapi.ResourceSliceMaxAttributesAndCapacitiesPerDevice)
outerAttributesMapType = apiservercel.NewMapType(domainType, innerAttributesMapType, resourceapi.ResourceSliceMaxAttributesAndCapacitiesPerDevice)
// Same for capacity.
innerCapacityMapType = apiservercel.NewMapType(idType, apiservercel.QuantityDeclType, resourceapi.ResourceSliceMaxAttributesAndCapacitiesPerDevice)
outerCapacityMapType = apiservercel.NewMapType(domainType, innerCapacityMapType, resourceapi.ResourceSliceMaxAttributesAndCapacitiesPerDevice)
)
func GetCompiler() *compiler {
@ -85,11 +103,12 @@ type Device struct {
}
type compiler struct {
envset *environment.EnvSet
}
func newCompiler() *compiler {
return &compiler{envset: mustBuildEnv()}
// deviceType is a definition for the type of the `device` variable.
// This is needed for the cost estimator. Both are currently version-independent.
// If that ever changes, some additional logic might be needed to make
// cost estimates version-dependent.
deviceType *apiservercel.DeclType
envset *environment.EnvSet
}
// Options contains several additional parameters
@ -124,7 +143,7 @@ func (c compiler) CompileCELExpression(expression string, options Options) Compi
// The size estimator supplies the maximum sizes of maps and strings
// reachable through the device variable, based on the type declared for it
// in the definition of the environment.
estimator := &library.CostEstimator{}
estimator := c.newCostEstimator()
ast, issues := env.Compile(expression)
if issues != nil {
@ -169,6 +188,10 @@ func (c compiler) CompileCELExpression(expression string, options Options) Compi
return compilationResult
}
// newCostEstimator returns a cost estimator whose size estimates are backed
// by this compiler, so that the maximum sizes declared for the device type
// are taken into account when estimating the cost of an expression.
func (c *compiler) newCostEstimator() *library.CostEstimator {
	sizes := &sizeEstimator{compiler: c}
	return &library.CostEstimator{SizeEstimator: sizes}
}
// getAttributeValue returns the native representation of the one value that
// should be stored in the attribute, otherwise an error. An error is
// also returned when there is no supported value.
@ -241,7 +264,7 @@ func (c CompilationResult) DeviceMatches(ctx context.Context, input Device) (boo
return resultBool, details, nil
}
func mustBuildEnv() *environment.EnvSet {
func newCompiler() *compiler {
envset := environment.MustBaseEnvSet(environment.DefaultCompatibilityVersion(), true /* strictCost */)
field := func(name string, declType *apiservercel.DeclType, required bool) *apiservercel.DeclField {
return apiservercel.NewDeclField(name, declType, required, nil, nil)
@ -253,10 +276,11 @@ func mustBuildEnv() *environment.EnvSet {
}
return result
}
deviceType := apiservercel.NewObjectType("kubernetes.DRADevice", fields(
field(driverVar, apiservercel.StringType, true),
field(attributesVar, apiservercel.NewMapType(apiservercel.StringType, apiservercel.NewMapType(apiservercel.StringType, apiservercel.AnyType, resourceapi.ResourceSliceMaxAttributesAndCapacitiesPerDevice), resourceapi.ResourceSliceMaxAttributesAndCapacitiesPerDevice), true),
field(capacityVar, apiservercel.NewMapType(apiservercel.StringType, apiservercel.NewMapType(apiservercel.StringType, apiservercel.QuantityDeclType, resourceapi.ResourceSliceMaxAttributesAndCapacitiesPerDevice), resourceapi.ResourceSliceMaxAttributesAndCapacitiesPerDevice), true),
field(driverVar, driverType, true),
field(attributesVar, outerAttributesMapType, true),
field(capacityVar, outerCapacityMapType, true),
))
versioned := []environment.VersionedOptions{
@ -284,7 +308,13 @@ func mustBuildEnv() *environment.EnvSet {
if err != nil {
panic(fmt.Errorf("internal error building CEL environment: %w", err))
}
return envset
return &compiler{envset: envset, deviceType: deviceType}
}
// withMaxElements returns a shallow copy of the given type declaration with
// MaxElements set to maxElements. The input declaration is not modified,
// which makes this safe to use with shared, pre-defined types.
func withMaxElements(in *apiservercel.DeclType, maxElements uint64) *apiservercel.DeclType {
	clone := new(apiservercel.DeclType)
	*clone = *in
	clone.MaxElements = int64(maxElements)
	return clone
}
// parseQualifiedName splits into domain and identifier, using the default domain
@ -322,3 +352,67 @@ func (m mapper) Find(key ref.Val) (ref.Val, bool) {
return m.defaultValue, true
}
// sizeEstimator tells the cost estimator the maximum size of maps or strings accessible through the `device` variable.
// Without this, the maximum string size of e.g. `device.attributes["dra.example.com"].services` would be unknown.
//
// sizeEstimator is derived from the sizeEstimator in k8s.io/apiextensions-apiserver/pkg/apiserver/schema/cel.
type sizeEstimator struct {
// compiler provides the declared type of the `device` variable
// (compiler.deviceType), which EstimateSize walks to find size limits.
compiler *compiler
}
// EstimateSize returns the size range of the value that the given AST node
// refers to, or nil when no estimate can be given.
//
// The node's path is resolved step by step against the declared type of the
// device variable. The estimate is nil when the path is empty, does not start
// at the device variable, or contains a step that cannot be resolved.
// Otherwise the range is [0, MaxElements] of the type the path ends at.
func (s *sizeEstimator) EstimateSize(element checker.AstNode) *checker.SizeEstimate {
	segments := element.Path()
	// Path() can return an empty list, and the only root this estimator
	// knows about is the device variable. Anything else gets no estimate.
	if len(segments) == 0 || segments[0] != deviceVar {
		return nil
	}

	// Walk the remaining path, starting at the type of the device variable.
	node := s.compiler.deviceType
	for _, segment := range segments[1:] {
		var next *apiservercel.DeclType
		switch {
		case segment == "@items" || segment == "@values":
			next = node.ElemType
		case segment == "@keys":
			next = node.KeyType
		default:
			if field, found := node.Fields[segment]; found {
				next = field.Type
			} else if node.ElemType == attributeType {
				// Entries of an attribute map all share the maximum size
				// set in attributeType, whatever their name is.
				next = attributeType
			}
		}
		if next == nil {
			// The step could not be resolved, so the size is unknown.
			return nil
		}
		node = next
	}

	return &checker.SizeEstimate{Min: 0, Max: uint64(node.MaxElements)}
}
// EstimateCallCost always returns nil: this estimator only supplies size
// information and has no custom cost estimates for function calls.
// NOTE(review): nil presumably lets the cost estimator fall back to its own
// estimates — confirm against the cel-go checker.CostEstimator documentation.
func (s *sizeEstimator) EstimateCallCost(function, overloadID string, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate {
return nil
}

View File

@ -203,6 +203,48 @@ device.attributes["dra.example.com"]["version"].isGreaterThan(semver("0.0.1"))
expectMatch: true,
expectCost: 12,
},
"check_attribute_domains": {
expression: `device.attributes.exists_one(x, x == "dra.example.com")`,
attributes: map[resourceapi.QualifiedName]resourceapi.DeviceAttribute{"services": {StringValue: ptr.To("some_example_value")}},
driver: "dra.example.com",
expectMatch: true,
expectCost: 164,
},
"check_attribute_ids": {
expression: `device.attributes["dra.example.com"].exists_one(x, x == "services")`,
attributes: map[resourceapi.QualifiedName]resourceapi.DeviceAttribute{"services": {StringValue: ptr.To("some_example_value")}},
driver: "dra.example.com",
expectMatch: true,
expectCost: 133,
},
"split_attribute": {
expression: `device.attributes["dra.example.com"].services.split("example").size() >= 2`,
attributes: map[resourceapi.QualifiedName]resourceapi.DeviceAttribute{"services": {StringValue: ptr.To("some_example_value")}},
driver: "dra.example.com",
expectMatch: true,
expectCost: 19,
},
"regexp_attribute": {
expression: `device.attributes["dra.example.com"].services.matches("[^a]?sym")`,
attributes: map[resourceapi.QualifiedName]resourceapi.DeviceAttribute{"services": {StringValue: ptr.To("asymetric")}},
driver: "dra.example.com",
expectMatch: true,
expectCost: 18,
},
"check_capacity_domains": {
expression: `device.capacity.exists_one(x, x == "dra.example.com")`,
capacity: map[resourceapi.QualifiedName]resourceapi.DeviceCapacity{"memory": {Value: resource.MustParse("1Mi")}},
driver: "dra.example.com",
expectMatch: true,
expectCost: 164,
},
"check_capacity_ids": {
expression: `device.capacity["dra.example.com"].exists_one(x, x == "memory")`,
capacity: map[resourceapi.QualifiedName]resourceapi.DeviceCapacity{"memory": {Value: resource.MustParse("1Mi")}},
driver: "dra.example.com",
expectMatch: true,
expectCost: 133,
},
"expensive": {
// The worst-case is based on the maximum number of
// attributes and the maximum attribute name length.
@ -214,21 +256,18 @@ device.attributes["dra.example.com"]["version"].isGreaterThan(semver("0.0.1"))
attribute := resourceapi.DeviceAttribute{
StringValue: ptr.To("abc"),
}
// If the cost estimate was accurate, using exactly as many attributes
// as allowed at most should exceed the limit. In practice, the estimate
// is an upper bound and significantly more attributes are needed before
// the runtime cost becomes too large.
for i := 0; i < 1000*resourceapi.ResourceSliceMaxAttributesAndCapacitiesPerDevice; i++ {
for i := 0; i < resourceapi.ResourceSliceMaxAttributesAndCapacitiesPerDevice; i++ {
suffix := fmt.Sprintf("-%d", i)
name := prefix + strings.Repeat("x", resourceapi.DeviceMaxIDLength-len(suffix)) + suffix
attributes[resourceapi.QualifiedName(name)] = attribute
}
return attributes
}(),
expression: `device.attributes["dra.example.com"].map(s, s.lowerAscii()).map(s, s.size()).sum() == 0`,
// From https://github.com/kubernetes/kubernetes/blob/50fc400f178d2078d0ca46aee955ee26375fc437/test/integration/apiserver/cel/validatingadmissionpolicy_test.go#L2150.
expression: `[1, 2, 3, 4, 5, 6, 7, 8, 9, 10].all(x, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10].all(y, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10].all(z, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10].all(z2, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10].all(z3, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10].all(z4, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10].all(z5, int('1'.find('[0-9]*')) < 100)))))))`,
driver: "dra.example.com",
expectMatchError: "actual cost limit exceeded",
expectCost: 18446744073709551615, // Exceeds limit!
expectCost: 85555551, // Exceed limit!
},
}
@ -238,50 +277,50 @@ func TestCEL(t *testing.T) {
_, ctx := ktesting.NewTestContext(t)
result := GetCompiler().CompileCELExpression(scenario.expression, Options{})
if scenario.expectCompileError != "" && result.Error == nil {
t.Fatalf("expected compile error %q, got none", scenario.expectCompileError)
t.Fatalf("FAILURE: expected compile error %q, got none", scenario.expectCompileError)
}
if result.Error != nil {
if scenario.expectCompileError == "" {
t.Fatalf("unexpected compile error: %v", result.Error)
t.Fatalf("FAILURE: unexpected compile error: %v", result.Error)
}
if !strings.Contains(result.Error.Error(), scenario.expectCompileError) {
t.Fatalf("expected compile error to contain %q, but got instead: %v", scenario.expectCompileError, result.Error)
t.Fatalf("FAILURE: expected compile error to contain %q, but got instead: %v", scenario.expectCompileError, result.Error)
}
return
}
if scenario.expectCompileError != "" {
t.Fatalf("expected compile error %q, got none", scenario.expectCompileError)
t.Fatalf("FAILURE: expected compile error %q, got none", scenario.expectCompileError)
}
if expect, actual := scenario.expectCost, result.MaxCost; expect != actual {
t.Errorf("expected CEL cost %d, got %d instead", expect, actual)
t.Errorf("ERROR: expected CEL cost %d, got %d instead (%.0f%% of limit %d)", expect, actual, float64(actual)*100.0/float64(resourceapi.CELSelectorExpressionMaxCost), resourceapi.CELSelectorExpressionMaxCost)
}
match, details, err := result.DeviceMatches(ctx, Device{Attributes: scenario.attributes, Capacity: scenario.capacity, Driver: scenario.driver})
// details.ActualCost can be called for nil details, no need to check.
actualCost := ptr.Deref(details.ActualCost(), 0)
if scenario.expectCost > 0 {
t.Logf("actual cost %d, %d%% of worst-case estimate", actualCost, actualCost*100/scenario.expectCost)
t.Logf("actual cost %d, %d%% of worst-case estimate %d", actualCost, actualCost*100/scenario.expectCost, scenario.expectCost)
} else {
t.Logf("actual cost %d, expected zero costs", actualCost)
if actualCost > 0 {
t.Errorf("expected zero costs for (presumably) constant expression %q, got instead %d", scenario.expression, actualCost)
}
}
if actualCost > result.MaxCost {
t.Errorf("ERROR: cost estimate %d underestimated the evaluation cost of %d", result.MaxCost, actualCost)
}
if err != nil {
if scenario.expectMatchError == "" {
t.Fatalf("unexpected evaluation error: %v", err)
t.Fatalf("FAILURE: unexpected evaluation error: %v", err)
}
if !strings.Contains(err.Error(), scenario.expectMatchError) {
t.Fatalf("expected evaluation error to contain %q, but got instead: %v", scenario.expectMatchError, err)
t.Fatalf("FAILURE: expected evaluation error to contain %q, but got instead: %v", scenario.expectMatchError, err)
}
return
}
if scenario.expectMatchError != "" {
t.Fatalf("expected match error %q, got none", scenario.expectMatchError)
t.Fatalf("FAILURE: expected match error %q, got none", scenario.expectMatchError)
}
if match != scenario.expectMatch {
t.Fatalf("expected result %v, got %v", scenario.expectMatch, match)
t.Fatalf("FAILURE: expected result %v, got %v", scenario.expectMatch, match)
}
})
}