From 0a2dfba067d7c75fafb9844f3cf4539153b582cf Mon Sep 17 00:00:00 2001 From: Joe Betz Date: Tue, 27 Aug 2024 14:42:58 -0400 Subject: [PATCH] Add equality cost checking --- .../pkg/cel/environment/base_test.go | 5 +- .../k8s.io/apiserver/pkg/cel/library/cost.go | 59 ++++++++---------- .../apiserver/pkg/cel/library/cost_test.go | 61 +++++++++++++++---- .../cel/library/library_compatibility_test.go | 14 ++++- 4 files changed, 91 insertions(+), 48 deletions(-) diff --git a/staging/src/k8s.io/apiserver/pkg/cel/environment/base_test.go b/staging/src/k8s.io/apiserver/pkg/cel/environment/base_test.go index d893f650c96..2b68f6b2617 100644 --- a/staging/src/k8s.io/apiserver/pkg/cel/environment/base_test.go +++ b/staging/src/k8s.io/apiserver/pkg/cel/environment/base_test.go @@ -115,6 +115,8 @@ func TestLibraryCoverage(t *testing.T) { } } +// TestKnownLibraries ensures that all libraries used in the base environment are also registered with +// KnownLibraries. Other tests rely on KnownLibraries to provide an up-to-date list of CEL libraries. func TestKnownLibraries(t *testing.T) { known := sets.New[string]() used := sets.New[string]() @@ -132,9 +134,8 @@ func TestKnownLibraries(t *testing.T) { unexpected := used.Difference(known) if len(unexpected) != 0 { - t.Errorf("Expected all libraries in the base environment to be included k8s.io/apiserver/pkg/cel/library's KnownLibraries, but found missing libraries: %v", unexpected) + t.Errorf("Expected all libraries in the base environment to be included in k8s.io/apiserver/pkg/cel/library's KnownLibraries, but found missing libraries: %v", unexpected) } - } func librariesInVersions(t *testing.T, vops ...VersionedOptions) []string { diff --git a/staging/src/k8s.io/apiserver/pkg/cel/library/cost.go b/staging/src/k8s.io/apiserver/pkg/cel/library/cost.go index 63863088c02..2ffb0755d6b 100644 --- a/staging/src/k8s.io/apiserver/pkg/cel/library/cost.go +++ b/staging/src/k8s.io/apiserver/pkg/cel/library/cost.go @@ -18,17 +18,15 @@ package library import ( "fmt" - "math" - "reflect" - "github.com/google/cel-go/checker" "github.com/google/cel-go/common" "github.com/google/cel-go/common/ast" "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" "github.com/google/cel-go/common/types/traits" + "math" + "strings" - "k8s.io/apimachinery/pkg/util/sets" "k8s.io/apiserver/pkg/cel" ) @@ -50,22 +48,6 @@ var knownUnhandledFunctions = map[string]bool{ "strings.quote": true, } -// TODO: Replace this with a utility that extracts types from libraries. -var knownKubernetesRuntimeTypes = sets.New[reflect.Type]( - reflect.ValueOf(cel.URL{}).Type(), - reflect.ValueOf(cel.IP{}).Type(), - reflect.ValueOf(cel.CIDR{}).Type(), - reflect.ValueOf(&cel.Format{}).Type(), - reflect.ValueOf(cel.Quantity{}).Type(), -) -var knownKubernetesCompilerTypes = sets.New[ref.Type]( - cel.CIDRType, - cel.IPType, - cel.FormatType, - cel.QuantityType, - cel.URLType, -) - // CostEstimator implements CEL's interpretable.ActualCostEstimator and checker.CostEstimator. type CostEstimator struct { // SizeEstimator provides a CostEstimator.EstimateSize that this CostEstimator will delegate size estimation @@ -258,18 +240,16 @@ func (l *CostEstimator) CallCost(function, overloadId string, args []ref.Val, re unitCost := uint64(1) lhs := args[0] switch lhs.(type) { - case cel.Quantity: - return &unitCost - case cel.IP: - return &unitCost - case cel.CIDR: - return &unitCost - case *cel.Format: // Formats have a small max size. - return &unitCost - case cel.URL: // TODO: Computing the actual cost is expensive, and changing this would be a breaking change + case *cel.Quantity, cel.Quantity, + *cel.IP, cel.IP, + *cel.CIDR, cel.CIDR, + *cel.Format, // Formats have a small max size. Format takes pointer receiver. + *cel.URL, cel.URL, // TODO: Computing the actual cost is expensive, and changing this would be a breaking change + *authorizerVal, authorizerVal, *pathCheckVal, pathCheckVal, *groupCheckVal, groupCheckVal, + *resourceCheckVal, resourceCheckVal, *decisionVal, decisionVal: return &unitCost default: - if panicOnUnknown && knownKubernetesRuntimeTypes.Has(reflect.ValueOf(lhs).Type()) { + if panicOnUnknown && isKubernetesType(lhs.Type()) { panic(fmt.Errorf("CallCost: unhandled equality for Kubernetes type %T", lhs)) } } @@ -528,7 +508,8 @@ func (l *CostEstimator) EstimateCallCost(function, overloadId string, target *ch } if t.Kind() == types.StructKind { switch t { - case cel.QuantityType: // O(1) cost equality checks + case cel.QuantityType, AuthorizerType, PathCheckType, // O(1) cost equality checks + GroupCheckType, ResourceCheckType, DecisionType: return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 1, Max: 1}} case cel.FormatType: return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 1, Max: cel.MaxFormatSize}.MultiplyByCostFactor(common.StringTraversalCostFactor)} @@ -542,7 +523,7 @@ func (l *CostEstimator) EstimateCallCost(function, overloadId string, target *ch return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 1, Max: size.Max}.MultiplyByCostFactor(common.StringTraversalCostFactor)} } } - if panicOnUnknown && knownKubernetesCompilerTypes.Has(t) { + if panicOnUnknown && isKubernetesType(t) { panic(fmt.Errorf("EstimateCallCost: unhandled equality for Kubernetes type %v", t)) } } @@ -651,3 +632,17 @@ func traversalCost(v ref.Val) uint64 { return 1 } } + +// isKubernetesType returns ture if a type is type defined by Kubernetes, +// as identified by opaque or struct types with a "kubernetes." prefix. +func isKubernetesType(t ref.Type) bool { + if tt, ok := t.(*types.Type); ok { + switch tt.Kind() { + case types.OpaqueKind, types.StructKind: + return strings.HasPrefix(tt.TypeName(), "kubernetes.") + default: + return false + } + } + return false +} diff --git a/staging/src/k8s.io/apiserver/pkg/cel/library/cost_test.go b/staging/src/k8s.io/apiserver/pkg/cel/library/cost_test.go index de4eaf009ab..b629a007b98 100644 --- a/staging/src/k8s.io/apiserver/pkg/cel/library/cost_test.go +++ b/staging/src/k8s.io/apiserver/pkg/cel/library/cost_test.go @@ -19,6 +19,7 @@ package library import ( "context" "fmt" + "github.com/google/cel-go/common/types/ref" "testing" "github.com/google/cel-go/cel" @@ -30,6 +31,7 @@ import ( exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" "k8s.io/apiserver/pkg/authorization/authorizer" + apiservercel "k8s.io/apiserver/pkg/cel" ) const ( @@ -1231,10 +1233,10 @@ func TestSize(t *testing.T) { est := &CostEstimator{SizeEstimator: &testCostEstimator{}} for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - var targetNode checker.AstNode = testSizeNode{size: tc.targetSize} + var targetNode checker.AstNode = testNode{size: tc.targetSize} argNodes := make([]checker.AstNode, len(tc.argSizes)) for i, arg := range tc.argSizes { - argNodes[i] = testSizeNode{size: arg} + argNodes[i] = testNode{size: arg} } result := est.EstimateCallCost(tc.function, tc.overload, &targetNode, argNodes) if result.ResultSize == nil { @@ -1247,25 +1249,62 @@ func TestSize(t *testing.T) { } } -type testSizeNode struct { +// TestTypeEquality ensures that cost is tested for all custom types used by Kubernetes libraries. +func TestTypeEquality(t *testing.T) { + examples := map[string]ref.Val{ + // Add example ref.Val's for custom types in Kubernetes here: + "kubernetes.authorization.Authorizer": authorizerVal{}, + "kubernetes.authorization.PathCheck": pathCheckVal{}, + "kubernetes.authorization.GroupCheck": groupCheckVal{}, + "kubernetes.authorization.ResourceCheck": resourceCheckVal{}, + "kubernetes.authorization.Decision": decisionVal{}, + "kubernetes.URL": apiservercel.URL{}, + "kubernetes.Quantity": apiservercel.Quantity{}, + "net.IP": apiservercel.IP{}, + "net.CIDR": apiservercel.CIDR{}, + "kubernetes.NamedFormat": &apiservercel.Format{}, + } + + originalPanicOnUnknown := panicOnUnknown + panicOnUnknown = true + t.Cleanup(func() { panicOnUnknown = originalPanicOnUnknown }) + est := &CostEstimator{SizeEstimator: &testCostEstimator{}} + + for _, lib := range KnownLibraries() { + for _, kt := range lib.Types() { + t.Run(kt.TypeName(), func(t *testing.T) { + typeNode := testNode{size: checker.SizeEstimate{Min: 10, Max: 100}, typ: kt} + est.EstimateCallCost("_==_", "", nil, []checker.AstNode{typeNode, typeNode}) + ex, ok := examples[kt.TypeName()] + if !ok { + t.Errorf("missing example for type: %s", kt.TypeName()) + } + est.CallCost("_==_", "", []ref.Val{ex, ex}, nil) + }) + } + } +} + +type testNode struct { size checker.SizeEstimate + typ *types.Type } -var _ checker.AstNode = (*testSizeNode)(nil) +var _ checker.AstNode = (*testNode)(nil) -func (t testSizeNode) Path() []string { +func (t testNode) Path() []string { return nil // not needed } -func (t testSizeNode) Type() *types.Type { +func (t testNode) Type() *types.Type { + return t.typ // not needed +} + +func (t testNode) Expr() ast.Expr { return nil // not needed } -func (t testSizeNode) Expr() ast.Expr { - return nil // not needed -} - -func (t testSizeNode) ComputedSize() *checker.SizeEstimate { +func (t testNode) ComputedSize() *checker.SizeEstimate { return &t.size } diff --git a/staging/src/k8s.io/apiserver/pkg/cel/library/library_compatibility_test.go b/staging/src/k8s.io/apiserver/pkg/cel/library/library_compatibility_test.go index cedbcb04f68..d42a1b91a15 100644 --- a/staging/src/k8s.io/apiserver/pkg/cel/library/library_compatibility_test.go +++ b/staging/src/k8s.io/apiserver/pkg/cel/library/library_compatibility_test.go @@ -66,21 +66,29 @@ func TestLibraryCompatibility(t *testing.T) { } } +// TestTypeRegistration ensures that all custom types defined and used by Kubernetes CEL libraries +// are returned by library.Types(). Other tests depend on Types() to provide an up-to-date list of +// types declared in a library. func TestTypeRegistration(t *testing.T) { for _, lib := range KnownLibraries() { registeredTypes := sets.New[*cel.Type]() usedTypes := sets.New[*cel.Type]() - // scan all registered functions + // scan all registered function declarations for the library for _, fn := range lib.declarations() { - testFn, err := decls.NewFunction("test", fn...) + fn, err := decls.NewFunction("placeholder-not-used", fn...) if err != nil { t.Fatal(err) } - for _, o := range testFn.OverloadDecls() { + for _, o := range fn.OverloadDecls() { + // ArgTypes include both the receiver type (if present) and + // all function argument types. for _, at := range o.ArgTypes() { switch at.Kind() { + // User defined types are either Opaque or Struct. case types.OpaqueKind, types.StructKind: usedTypes.Insert(at) + default: + // skip } } }