refit cost estimator with ast.Expr

This commit is contained in:
Jiahui Feng 2024-04-15 13:50:04 -07:00
parent 94997c6fef
commit ac5391fa21
2 changed files with 13 additions and 6 deletions

View File

@ -21,10 +21,10 @@ import (
"github.com/google/cel-go/checker" "github.com/google/cel-go/checker"
"github.com/google/cel-go/common" "github.com/google/cel-go/common"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref" "github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits" "github.com/google/cel-go/common/types/traits"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
) )
// CostEstimator implements CEL's interpretable.ActualCostEstimator and checker.CostEstimator. // CostEstimator implements CEL's interpretable.ActualCostEstimator and checker.CostEstimator.
@ -255,8 +255,10 @@ func (l *CostEstimator) EstimateCallCost(function, overloadId string, target *ch
// Worst case size is where is that a separator of "" is used, and each char is returned as a list element. // Worst case size is where is that a separator of "" is used, and each char is returned as a list element.
max := sz.Max max := sz.Max
if len(args) > 1 { if len(args) > 1 {
if c := args[1].Expr().GetConstExpr(); c != nil { if v := args[1].Expr().AsLiteral(); v != nil {
max = uint64(c.GetInt64Value()) if i, ok := v.Value().(int64); ok {
max = uint64(i)
}
} }
} }
// Cost is the traversal plus the construction of the result. // Cost is the traversal plus the construction of the result.
@ -425,7 +427,7 @@ func (l *CostEstimator) EstimateSize(element checker.AstNode) *checker.SizeEstim
type itemsNode struct { type itemsNode struct {
path []string path []string
t *types.Type t *types.Type
expr *exprpb.Expr expr ast.Expr
} }
func (i *itemsNode) Path() []string { func (i *itemsNode) Path() []string {
@ -436,7 +438,7 @@ func (i *itemsNode) Type() *types.Type {
return i.t return i.t
} }
func (i *itemsNode) Expr() *exprpb.Expr { func (i *itemsNode) Expr() ast.Expr {
return i.expr return i.expr
} }
@ -444,6 +446,8 @@ func (i *itemsNode) ComputedSize() *checker.SizeEstimate {
return nil return nil
} }
var _ checker.AstNode = (*itemsNode)(nil)
// traversalCost computes the cost of traversing a ref.Val as a data tree. // traversalCost computes the cost of traversing a ref.Val as a data tree.
func traversalCost(v ref.Val) uint64 { func traversalCost(v ref.Val) uint64 {
// TODO: This could potentially be optimized by sampling maps and lists instead of traversing. // TODO: This could potentially be optimized by sampling maps and lists instead of traversing.

View File

@ -23,6 +23,7 @@ import (
"github.com/google/cel-go/cel" "github.com/google/cel-go/cel"
"github.com/google/cel-go/checker" "github.com/google/cel-go/checker"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types"
"github.com/google/cel-go/ext" "github.com/google/cel-go/ext"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
@ -1147,6 +1148,8 @@ type testSizeNode struct {
size checker.SizeEstimate size checker.SizeEstimate
} }
var _ checker.AstNode = (*testSizeNode)(nil)
func (t testSizeNode) Path() []string { func (t testSizeNode) Path() []string {
return nil // not needed return nil // not needed
} }
@ -1155,7 +1158,7 @@ func (t testSizeNode) Type() *types.Type {
return nil // not needed return nil // not needed
} }
func (t testSizeNode) Expr() *exprpb.Expr { func (t testSizeNode) Expr() ast.Expr {
return nil // not needed return nil // not needed
} }