Merge pull request #108576 from jpbetz/cel-0_10_0

Bump to CEL v0.10.0 and update tests and usage
This commit is contained in:
Kubernetes Prow Robot 2022-03-08 07:06:33 -08:00 committed by GitHub
commit c964ef8d8a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
68 changed files with 3449 additions and 838 deletions

View File

@ -203,4 +203,35 @@
See the License for the specific language governing permissions and
limitations under the License.
= vendor/github.com/google/cel-go/LICENSE 3b83ef96387f14655fc854ddc3c6bd57
===========================================================================
The common/types/pb/equal.go modification of proto.Equal logic
===========================================================================
Copyright (c) 2018 The Go Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
= vendor/github.com/google/cel-go/LICENSE 9e40c7725e55fa8f61a69abf908e2c6f

2
go.mod
View File

@ -269,7 +269,7 @@ replace (
github.com/golangplus/testing => github.com/golangplus/testing v0.0.0-20180327235837-af21d9c3145e
github.com/google/btree => github.com/google/btree v1.0.1
github.com/google/cadvisor => github.com/google/cadvisor v0.43.0
github.com/google/cel-go => github.com/google/cel-go v0.9.0
github.com/google/cel-go => github.com/google/cel-go v0.10.0
github.com/google/cel-spec => github.com/google/cel-spec v0.6.0
github.com/google/go-cmp => github.com/google/go-cmp v0.5.5
github.com/google/gofuzz => github.com/google/gofuzz v1.1.0

4
go.sum
View File

@ -221,8 +221,8 @@ github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4=
github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA=
github.com/google/cadvisor v0.43.0 h1:z0ULgYPKZ7L/c7Zjq+ZD6ltklWwYdCSvBMgSjNC/hGo=
github.com/google/cadvisor v0.43.0/go.mod h1:+RdMSbc3FVr5NYCD2dOEJy/LI0jYJ/0xJXkzWXEyiFQ=
github.com/google/cel-go v0.9.0 h1:u1hg7lcZ/XWw2d3aV1jFS30ijQQ6q0/h1C2ZBeBD1gY=
github.com/google/cel-go v0.9.0/go.mod h1:U7ayypeSkw23szu4GaQTPJGx66c20mx8JklMSxrmI1w=
github.com/google/cel-go v0.10.0 h1:SBdarVzHoCXsTjqX+Lsgg9asSO7bViwgizzDi9kBigg=
github.com/google/cel-go v0.10.0/go.mod h1:U7ayypeSkw23szu4GaQTPJGx66c20mx8JklMSxrmI1w=
github.com/google/cel-spec v0.6.0/go.mod h1:Nwjgxy5CbjlPrtCWjeDjUyKMl8w41YBYGjsyDdqk0xA=
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=

View File

@ -7,7 +7,7 @@ go 1.16
require (
github.com/emicklei/go-restful v2.9.5+incompatible
github.com/gogo/protobuf v1.3.2
github.com/google/cel-go v0.9.0
github.com/google/cel-go v0.10.0
github.com/google/go-cmp v0.5.5
github.com/google/gofuzz v1.1.0
github.com/google/uuid v1.1.2

View File

@ -229,8 +229,8 @@ github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Z
github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4=
github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA=
github.com/google/cel-go v0.9.0 h1:u1hg7lcZ/XWw2d3aV1jFS30ijQQ6q0/h1C2ZBeBD1gY=
github.com/google/cel-go v0.9.0/go.mod h1:U7ayypeSkw23szu4GaQTPJGx66c20mx8JklMSxrmI1w=
github.com/google/cel-go v0.10.0 h1:SBdarVzHoCXsTjqX+Lsgg9asSO7bViwgizzDi9kBigg=
github.com/google/cel-go v0.10.0/go.mod h1:U7ayypeSkw23szu4GaQTPJGx66c20mx8JklMSxrmI1w=
github.com/google/cel-spec v0.6.0/go.mod h1:Nwjgxy5CbjlPrtCWjeDjUyKMl8w41YBYGjsyDdqk0xA=
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=

View File

@ -67,7 +67,9 @@ func Compile(s *schema.Structural, isResourceRoot bool) ([]CompilationResult, er
var propDecls []*expr.Decl
var root *celmodel.DeclType
var ok bool
env, err := cel.NewEnv()
env, err := cel.NewEnv(
cel.HomogeneousAggregateLiterals(),
)
if err != nil {
return nil, err
}

View File

@ -77,6 +77,92 @@ func TestValidationExpressions(t *testing.T) {
"self.val7 == 1.0",
},
},
{name: "numeric comparisons",
obj: objs(
int64(5), // val1, integer type, integer value
int64(10), // val2, integer type, integer value
int64(15), // val3, integer type, integer value
float64(10.0), // val4, number type, parsed from decimal literal
float64(10.0), // val5, float type, parsed from decimal literal
float64(10.0), // val6, double type, parsed from decimal literal
int64(10), // val7, number type, parsed from integer literal
int64(10), // val8, float type, parsed from integer literal
int64(10), // val9, double type, parsed from integer literal
),
schema: schemas(integerType, integerType, integerType, numberType, floatType, doubleType, numberType, floatType, doubleType),
valid: []string{
// xref: https://github.com/google/cel-spec/wiki/proposal-210
// compare integers with all float types
"double(self.val1) < self.val4",
"double(self.val1) <= self.val4",
"double(self.val2) <= self.val4",
"double(self.val2) == self.val4",
"double(self.val2) >= self.val4",
"double(self.val3) > self.val4",
"double(self.val3) >= self.val4",
"self.val1 < int(self.val4)",
"self.val2 == int(self.val4)",
"self.val3 > int(self.val4)",
"double(self.val1) < self.val5",
"double(self.val2) == self.val5",
"double(self.val3) > self.val5",
"self.val1 < int(self.val5)",
"self.val2 == int(self.val5)",
"self.val3 > int(self.val5)",
"double(self.val1) < self.val6",
"double(self.val2) == self.val6",
"double(self.val3) > self.val6",
"self.val1 < int(self.val6)",
"self.val2 == int(self.val6)",
"self.val3 > int(self.val6)",
// Also compare with float types backed by integer values,
// which is how integer literals are parsed from JSON for custom resources.
"double(self.val1) < self.val7",
"double(self.val2) == self.val7",
"double(self.val3) > self.val7",
"self.val1 < int(self.val7)",
"self.val2 == int(self.val7)",
"self.val3 > int(self.val7)",
"double(self.val1) < self.val8",
"double(self.val2) == self.val8",
"double(self.val3) > self.val8",
"self.val1 < int(self.val8)",
"self.val2 == int(self.val8)",
"self.val3 > int(self.val8)",
"double(self.val1) < self.val9",
"double(self.val2) == self.val9",
"double(self.val3) > self.val9",
"self.val1 < int(self.val9)",
"self.val2 == int(self.val9)",
"self.val3 > int(self.val9)",
// compare literal integers and floats
"double(5) < 10.0",
"double(10) == 10.0",
"double(15) > 10.0",
"5 < int(10.0)",
"10 == int(10.0)",
"15 > int(10.0)",
// compare integers with literal floats
"double(self.val1) < 10.0",
"double(self.val2) == 10.0",
"double(self.val3) > 10.0",
},
},
{name: "unicode strings",
obj: objs("Rook takes 👑", "Rook takes 👑"),
schema: schemas(stringType, stringType),
@ -698,17 +784,21 @@ func TestValidationExpressions(t *testing.T) {
"something": intOrStringType(),
}),
valid: []string{
// typical int-or-string usage would be to check both types
"type(self.something) == int ? self.something == 1 : self.something == '25%'",
// to require the value be a particular type, guard it with a runtime type check
// In Kubernetes 1.24 and later, the CEL type returns false for an int-or-string comparison against the
// other type, making it safe to write validation rules like:
"self.something == '25%'",
"self.something != 1",
"self.something == 1 || self.something == '25%'",
"self.something == '25%' || self.something == 1",
// In Kubernetes 1.23 and earlier, all int-or-string access must be guarded by a type check to prevent
// a runtime error attempting an equality check between string and int types.
"type(self.something) == string && self.something == '25%'",
},
errors: map[string]string{
// because the type is dynamic type checking fails a runtime even for unrelated types
"self.something == ['anything']": "no such overload",
// type checking fails a runtime if the value is an int and the expression assumes it is a string
// without a type check guard
"self.something == 1": "no such overload",
"type(self.something) == int ? self.something == 1 : self.something == '25%'",
// Because the type is dynamic it receives no type checking, and evaluates to false when compared to
// other types at runtime.
"self.something != ['anything']",
},
},
{name: "int in intOrString",
@ -719,17 +809,21 @@ func TestValidationExpressions(t *testing.T) {
"something": intOrStringType(),
}),
valid: []string{
// typical int-or-string usage would be to check both types
"type(self.something) == int ? self.something == 1 : self.something == '25%'",
// to require the value be a particular type, guard it with a runtime type check
// In Kubernetes 1.24 and later, the CEL type returns false for an int-or-string comparison against the
// other type, making it safe to write validation rules like:
"self.something == 1",
"self.something != 'some string'",
"self.something == 1 || self.something == '25%'",
"self.something == '25%' || self.something == 1",
// In Kubernetes 1.23 and earlier, all int-or-string access must be guarded by a type check to prevent
// a runtime error attempting an equality check between string and int types.
"type(self.something) == int && self.something == 1",
},
errors: map[string]string{
// because the type is dynamic type checking fails a runtime even for unrelated types
"self.something == ['anything']": "no such overload",
// type checking fails a runtime if the value is an int and the expression assumes it is a string
// without a type check guard
"self.something == 'anything'": "no such overload",
"type(self.something) == int ? self.something == 1 : self.something == '25%'",
// Because the type is dynamic it receives no type checking, and evaluates to false when compared to
// other types at runtime.
"self.something != ['anything']",
},
},
{name: "null in intOrString",

View File

@ -200,3 +200,34 @@
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
===========================================================================
The common/types/pb/equal.go modification of proto.Equal logic
===========================================================================
Copyright (c) 2018 The Go Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@ -24,6 +24,7 @@ go_library(
"//common/types:go_default_library",
"//common/types/pb:go_default_library",
"//common/types/ref:go_default_library",
"//common/types/traits:go_default_library",
"//interpreter:go_default_library",
"//interpreter/functions:go_default_library",
"//parser:go_default_library",
@ -34,6 +35,9 @@ go_library(
"@org_golang_google_protobuf//reflect/protoregistry:go_default_library",
"@org_golang_google_protobuf//types/descriptorpb:go_default_library",
"@org_golang_google_protobuf//types/dynamicpb:go_default_library",
"@org_golang_google_protobuf//types/known/anypb:go_default_library",
"@org_golang_google_protobuf//types/known/durationpb:go_default_library",
"@org_golang_google_protobuf//types/known/timestamppb:go_default_library",
],
)
@ -41,6 +45,7 @@ go_test(
name = "go_default_test",
srcs = [
"cel_test.go",
"io_test.go",
],
data = [
"//cel/testdata:gen_test_fds",
@ -56,9 +61,11 @@ go_test(
"//common/types/ref:go_default_library",
"//common/types/traits:go_default_library",
"//interpreter/functions:go_default_library",
"//test:go_default_library",
"//test/proto2pb:go_default_library",
"//test/proto3pb:go_default_library",
"@io_bazel_rules_go//proto/wkt:descriptor_go_proto",
"@org_golang_google_genproto//googleapis/api/expr/v1alpha1:go_default_library",
"@org_golang_google_protobuf//types/known/structpb:go_default_library",
],
)

View File

@ -16,6 +16,7 @@ package cel
import (
"errors"
"fmt"
"sync"
"github.com/google/cel-go/checker"
@ -89,13 +90,17 @@ type Env struct {
adapter ref.TypeAdapter
provider ref.TypeProvider
features map[int]bool
// program options tied to the environment.
progOpts []ProgramOption
// Internal parser representation
prsr *parser.Parser
// Internal checker representation
chk *checker.Env
chkErr error
once sync.Once
chk *checker.Env
chkErr error
chkOnce sync.Once
// Program options tied to the environment
progOpts []ProgramOption
}
// NewEnv creates a program environment configured with the standard library of CEL functions and
@ -147,18 +152,22 @@ func (e *Env) Check(ast *Ast) (*Ast, *Issues) {
pe, _ := AstToParsedExpr(ast)
// Construct the internal checker env, erroring if there is an issue adding the declarations.
e.once.Do(func() {
ce := checker.NewEnv(e.Container, e.provider)
ce.EnableDynamicAggregateLiterals(true)
if e.HasFeature(FeatureDisableDynamicAggregateLiterals) {
ce.EnableDynamicAggregateLiterals(false)
}
err := ce.Add(e.declarations...)
e.chkOnce.Do(func() {
ce, err := checker.NewEnv(e.Container, e.provider,
checker.HomogeneousAggregateLiterals(
e.HasFeature(featureDisableDynamicAggregateLiterals)),
checker.CrossTypeNumericComparisons(
e.HasFeature(featureCrossTypeNumericComparisons)))
if err != nil {
e.chkErr = err
} else {
e.chk = ce
return
}
err = ce.Add(e.declarations...)
if err != nil {
e.chkErr = err
return
}
e.chk = ce
})
// The once call will ensure that this value is set or nil for all invocations.
if e.chkErr != nil {
@ -207,11 +216,10 @@ func (e *Env) CompileSource(src common.Source) (*Ast, *Issues) {
return nil, iss
}
checked, iss2 := e.Check(ast)
iss = iss.Append(iss2)
if iss.Err() != nil {
return nil, iss
if iss2.Err() != nil {
return nil, iss2
}
return checked, iss
return checked, iss2
}
// Extend the current environment with additional options to produce a new Env.
@ -280,8 +288,8 @@ func (e *Env) Extend(opts ...EnvOption) (*Env, error) {
// HasFeature checks whether the environment enables the given feature
// flag, as enumerated in options.go.
func (e *Env) HasFeature(flag int) bool {
_, has := e.features[flag]
return has
enabled, has := e.features[flag]
return has && enabled
}
// Parse parses the input expression value `txt` to a Ast and/or a set of Issues.
@ -301,7 +309,7 @@ func (e *Env) Parse(txt string) (*Ast, *Issues) {
// It is possible to have both non-nil Ast and Issues values returned from this call; however,
// the mere presence of an Ast does not imply that it is valid for use.
func (e *Env) ParseSource(src common.Source) (*Ast, *Issues) {
res, errs := parser.ParseWithMacros(src, e.macros)
res, errs := e.prsr.Parse(src)
if len(errs.GetErrors()) > 0 {
return nil, &Issues{errs: errs}
}
@ -325,11 +333,6 @@ func (e *Env) Program(ast *Ast, opts ...ProgramOption) (Program, error) {
return newProgram(e, ast, optSet)
}
// SetFeature sets the given feature flag, as enumerated in options.go.
func (e *Env) SetFeature(flag int) {
e.features[flag] = true
}
// TypeAdapter returns the `ref.TypeAdapter` configured for the environment.
func (e *Env) TypeAdapter() ref.TypeAdapter {
return e.adapter
@ -402,6 +405,16 @@ func (e *Env) ResidualAst(a *Ast, details *EvalDetails) (*Ast, error) {
return checked, nil
}
// EstimateCost estimates the cost of a type checked CEL expression using the length estimates of input data and
// extension functions provided by estimator.
func (e *Env) EstimateCost(ast *Ast, estimator checker.CostEstimator) (checker.CostEstimate, error) {
checked, err := AstToCheckedExpr(ast)
if err != nil {
return checker.CostEstimate{}, fmt.Errorf("EsimateCost could not inspect Ast: %v", err)
}
return checker.Cost(checked, estimator), nil
}
// configure applies a series of EnvOptions to the current environment.
func (e *Env) configure(opts []EnvOption) (*Env, error) {
// Customized the environment using the provided EnvOption values. If an error is
@ -413,6 +426,14 @@ func (e *Env) configure(opts []EnvOption) (*Env, error) {
return nil, err
}
}
prsrOpts := []parser.Option{parser.Macros(e.macros...)}
if e.HasFeature(featureEnableMacroCallTracking) {
prsrOpts = append(prsrOpts, parser.PopulateMacroCalls(true))
}
e.prsr, err = parser.NewParser(prsrOpts...)
if err != nil {
return nil, err
}
return e, nil
}
@ -454,6 +475,9 @@ func (i *Issues) Append(other *Issues) *Issues {
if i == nil {
return other
}
if other == nil {
return i
}
return NewIssues(i.errs.Append(other.errs.GetErrors()))
}

View File

@ -15,12 +15,20 @@
package cel
import (
"errors"
"fmt"
"reflect"
"github.com/google/cel-go/common"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
"github.com/google/cel-go/parser"
"google.golang.org/protobuf/proto"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
anypb "google.golang.org/protobuf/types/known/anypb"
)
// CheckedExprToAst converts a checked expression proto message to an Ast.
@ -120,3 +128,153 @@ func AstToString(a *Ast) (string, error) {
info := a.SourceInfo()
return parser.Unparse(expr, info)
}
// RefValueToValue converts between ref.Val and api.expr.Value.
// The result Value is the serialized proto form. The ref.Val must not be error or unknown.
func RefValueToValue(res ref.Val) (*exprpb.Value, error) {
switch res.Type() {
case types.BoolType:
return &exprpb.Value{
Kind: &exprpb.Value_BoolValue{BoolValue: res.Value().(bool)}}, nil
case types.BytesType:
return &exprpb.Value{
Kind: &exprpb.Value_BytesValue{BytesValue: res.Value().([]byte)}}, nil
case types.DoubleType:
return &exprpb.Value{
Kind: &exprpb.Value_DoubleValue{DoubleValue: res.Value().(float64)}}, nil
case types.IntType:
return &exprpb.Value{
Kind: &exprpb.Value_Int64Value{Int64Value: res.Value().(int64)}}, nil
case types.ListType:
l := res.(traits.Lister)
sz := l.Size().(types.Int)
elts := make([]*exprpb.Value, 0, int64(sz))
for i := types.Int(0); i < sz; i++ {
v, err := RefValueToValue(l.Get(i))
if err != nil {
return nil, err
}
elts = append(elts, v)
}
return &exprpb.Value{
Kind: &exprpb.Value_ListValue{
ListValue: &exprpb.ListValue{Values: elts}}}, nil
case types.MapType:
mapper := res.(traits.Mapper)
sz := mapper.Size().(types.Int)
entries := make([]*exprpb.MapValue_Entry, 0, int64(sz))
for it := mapper.Iterator(); it.HasNext().(types.Bool); {
k := it.Next()
v := mapper.Get(k)
kv, err := RefValueToValue(k)
if err != nil {
return nil, err
}
vv, err := RefValueToValue(v)
if err != nil {
return nil, err
}
entries = append(entries, &exprpb.MapValue_Entry{Key: kv, Value: vv})
}
return &exprpb.Value{
Kind: &exprpb.Value_MapValue{
MapValue: &exprpb.MapValue{Entries: entries}}}, nil
case types.NullType:
return &exprpb.Value{
Kind: &exprpb.Value_NullValue{}}, nil
case types.StringType:
return &exprpb.Value{
Kind: &exprpb.Value_StringValue{StringValue: res.Value().(string)}}, nil
case types.TypeType:
typeName := res.(ref.Type).TypeName()
return &exprpb.Value{Kind: &exprpb.Value_TypeValue{TypeValue: typeName}}, nil
case types.UintType:
return &exprpb.Value{
Kind: &exprpb.Value_Uint64Value{Uint64Value: res.Value().(uint64)}}, nil
default:
any, err := res.ConvertToNative(anyPbType)
if err != nil {
return nil, err
}
return &exprpb.Value{
Kind: &exprpb.Value_ObjectValue{ObjectValue: any.(*anypb.Any)}}, nil
}
}
var (
typeNameToTypeValue = map[string]*types.TypeValue{
"bool": types.BoolType,
"bytes": types.BytesType,
"double": types.DoubleType,
"null_type": types.NullType,
"int": types.IntType,
"list": types.ListType,
"map": types.MapType,
"string": types.StringType,
"type": types.TypeType,
"uint": types.UintType,
}
anyPbType = reflect.TypeOf(&anypb.Any{})
)
// ValueToRefValue converts between exprpb.Value and ref.Val.
func ValueToRefValue(adapter ref.TypeAdapter, v *exprpb.Value) (ref.Val, error) {
switch v.Kind.(type) {
case *exprpb.Value_NullValue:
return types.NullValue, nil
case *exprpb.Value_BoolValue:
return types.Bool(v.GetBoolValue()), nil
case *exprpb.Value_Int64Value:
return types.Int(v.GetInt64Value()), nil
case *exprpb.Value_Uint64Value:
return types.Uint(v.GetUint64Value()), nil
case *exprpb.Value_DoubleValue:
return types.Double(v.GetDoubleValue()), nil
case *exprpb.Value_StringValue:
return types.String(v.GetStringValue()), nil
case *exprpb.Value_BytesValue:
return types.Bytes(v.GetBytesValue()), nil
case *exprpb.Value_ObjectValue:
any := v.GetObjectValue()
msg, err := anypb.UnmarshalNew(any, proto.UnmarshalOptions{DiscardUnknown: true})
if err != nil {
return nil, err
}
return adapter.NativeToValue(msg.(proto.Message)), nil
case *exprpb.Value_MapValue:
m := v.GetMapValue()
entries := make(map[ref.Val]ref.Val)
for _, entry := range m.Entries {
key, err := ValueToRefValue(adapter, entry.Key)
if err != nil {
return nil, err
}
pb, err := ValueToRefValue(adapter, entry.Value)
if err != nil {
return nil, err
}
entries[key] = pb
}
return adapter.NativeToValue(entries), nil
case *exprpb.Value_ListValue:
l := v.GetListValue()
elts := make([]ref.Val, len(l.Values))
for i, e := range l.Values {
rv, err := ValueToRefValue(adapter, e)
if err != nil {
return nil, err
}
elts[i] = rv
}
return adapter.NativeToValue(elts), nil
case *exprpb.Value_TypeValue:
typeName := v.GetTypeValue()
tv, ok := typeNameToTypeValue[typeName]
if ok {
return tv, nil
}
return types.NewObjectTypeValue(typeName), nil
}
return nil, errors.New("unknown value")
}

View File

@ -17,6 +17,12 @@ package cel
import (
"fmt"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protodesc"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry"
"google.golang.org/protobuf/types/dynamicpb"
"github.com/google/cel-go/checker/decls"
"github.com/google/cel-go/common/containers"
"github.com/google/cel-go/common/types/pb"
@ -24,11 +30,6 @@ import (
"github.com/google/cel-go/interpreter"
"github.com/google/cel-go/interpreter/functions"
"github.com/google/cel-go/parser"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protodesc"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry"
"google.golang.org/protobuf/types/dynamicpb"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
descpb "google.golang.org/protobuf/types/descriptorpb"
@ -45,7 +46,13 @@ const (
// provided as variables to the expression, as well as via conversion
// of well-known dynamic types, or with unchecked expressions.
// Affects checking. Provides a subset of standard behavior.
FeatureDisableDynamicAggregateLiterals
featureDisableDynamicAggregateLiterals
// Enable the tracking of function call expressions replaced by macros.
featureEnableMacroCallTracking
// Enable the use of cross-type numeric comparisons at the type-checker.
featureCrossTypeNumericComparisons
)
// EnvOption is a functional interface for configuring the environment.
@ -96,16 +103,6 @@ func Declarations(decls ...*exprpb.Decl) EnvOption {
}
}
// Features sets the given feature flags. See list of Feature constants above.
func Features(flags ...int) EnvOption {
return func(e *Env) (*Env, error) {
for _, flag := range flags {
e.SetFeature(flag)
}
return e, nil
}
}
// HomogeneousAggregateLiterals option ensures that list and map literal entry types must agree
// during type-checking.
//
@ -113,7 +110,7 @@ func Features(flags ...int) EnvOption {
// expression, as well as via conversion of well-known dynamic types, or with unchecked
// expressions.
func HomogeneousAggregateLiterals() EnvOption {
return Features(FeatureDisableDynamicAggregateLiterals)
return features(featureDisableDynamicAggregateLiterals, true)
}
// Macros option extends the macro set configured in the environment.
@ -334,8 +331,7 @@ func Functions(funcs ...*functions.Overload) ProgramOption {
// The vars value may either be an `interpreter.Activation` instance or a `map[string]interface{}`.
func Globals(vars interface{}) ProgramOption {
return func(p *prog) (*prog, error) {
defaultVars, err :=
interpreter.NewActivation(vars)
defaultVars, err := interpreter.NewActivation(vars)
if err != nil {
return nil, err
}
@ -344,6 +340,16 @@ func Globals(vars interface{}) ProgramOption {
}
}
// OptimizeRegex provides a way to replace the InterpretableCall for regex functions. This can be used
// to compile regex string constants at program creation time and report any errors and then use the
// compiled regex for all regex function invocations.
func OptimizeRegex(regexOptimizations ...*interpreter.RegexOptimization) ProgramOption {
return func(p *prog) (*prog, error) {
p.regexOptimizations = append(p.regexOptimizations, regexOptimizations...)
return p, nil
}
}
// EvalOption indicates an evaluation option that may affect the evaluation behavior or information
// in the output result.
type EvalOption int
@ -356,7 +362,9 @@ const (
OptExhaustiveEval EvalOption = 1<<iota | OptTrackState
// OptOptimize precomputes functions and operators with constants as arguments at program
// creation time. This flag is useful when the expression will be evaluated repeatedly against
// creation time. It also pre-compiles regex pattern constants passed to 'matches', reports any compilation errors
// at program creation and uses the compiled regex pattern for all 'matches' function invocations.
// This flag is useful when the expression will be evaluated repeatedly against
// a series of different inputs.
OptOptimize EvalOption = 1 << iota
@ -365,8 +373,12 @@ const (
// member graph.
//
// By itself, OptPartialEval does not change evaluation behavior unless the input to the
// Program Eval is an PartialVars.
// Program Eval() call is created via PartialVars().
OptPartialEval EvalOption = 1 << iota
// OptTrackCost enables the runtime cost calculation while validation and return cost within evalDetails
// cost calculation is available via func ActualCost()
OptTrackCost EvalOption = 1 << iota
)
// EvalOptions sets one or more evaluation options which may affect the evaluation or Result.
@ -379,6 +391,36 @@ func EvalOptions(opts ...EvalOption) ProgramOption {
}
}
// InterruptCheckFrequency configures the number of iterations within a comprehension to evaluate
// before checking whether the function evaluation has been interrupted.
func InterruptCheckFrequency(checkFrequency uint) ProgramOption {
return func(p *prog) (*prog, error) {
p.interruptCheckFrequency = checkFrequency
return p, nil
}
}
// CostTracking enables cost tracking and registers a ActualCostEstimator that can optionally provide a runtime cost estimate for any function calls.
func CostTracking(costEstimator interpreter.ActualCostEstimator) ProgramOption {
return func(p *prog) (*prog, error) {
p.callCostEstimator = costEstimator
p.evalOpts |= OptTrackCost
return p, nil
}
}
// CostLimit enables cost tracking and sets configures program evaluation to exit early with a
// "runtime cost limit exceeded" error if the runtime cost exceeds the costLimit.
// The CostLimit is a metric that corresponds to the number and estimated expense of operations
// performed while evaluating an expression. It is indicative of CPU usage, not memory usage.
func CostLimit(costLimit uint64) ProgramOption {
return func(p *prog) (*prog, error) {
p.costLimit = &costLimit
p.evalOpts |= OptTrackCost
return p, nil
}
}
func fieldToCELType(field protoreflect.FieldDescriptor) (*exprpb.Type, error) {
if field.Kind() == protoreflect.MessageKind {
msgName := (string)(field.Message().FullName())
@ -411,19 +453,19 @@ func fieldToDecl(field protoreflect.FieldDescriptor) (*exprpb.Decl, error) {
return nil, err
}
return decls.NewVar(name, decls.NewMapType(keyType, valueType)), nil
} else if field.IsList() {
}
if field.IsList() {
elemType, err := fieldToCELType(field)
if err != nil {
return nil, err
}
return decls.NewVar(name, decls.NewListType(elemType)), nil
} else {
celType, err := fieldToCELType(field)
if err != nil {
return nil, err
}
return decls.NewVar(name, celType), nil
}
celType, err := fieldToCELType(field)
if err != nil {
return nil, err
}
return decls.NewVar(name, celType), nil
}
// DeclareContextProto returns an option to extend CEL environment with declarations from the given context proto.
@ -449,3 +491,22 @@ func DeclareContextProto(descriptor protoreflect.MessageDescriptor) EnvOption {
return Types(dynamicpb.NewMessage(descriptor))(e)
}
}
// EnableMacroCallTracking ensures that call expressions which are replaced by macros
// are tracked in the `SourceInfo` of parsed and checked expressions.
func EnableMacroCallTracking() EnvOption {
return features(featureEnableMacroCallTracking, true)
}
// CrossTypeNumericComparisons makes it possible to compare across numeric types, e.g. double < int
func CrossTypeNumericComparisons(enabled bool) EnvOption {
return features(featureCrossTypeNumericComparisons, enabled)
}
// features sets the given feature flags. See list of Feature constants above.
func features(flag int, enabled bool) EnvOption {
return func(e *Env) (*Env, error) {
e.features[flag] = enabled
return e, nil
}
}

View File

@ -15,14 +15,16 @@
package cel
import (
"context"
"fmt"
"math"
"sync"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/interpreter"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
// Program is an evaluable view of an Ast.
@ -31,7 +33,7 @@ type Program interface {
//
// The vars value may either be an `interpreter.Activation` or a `map[string]interface{}`.
//
// If the `OptTrackState` or `OptExhaustiveEval` flags are used, the `details` response will
// If the `OptTrackState`, `OptTrackCost` or `OptExhaustiveEval` flags are used, the `details` response will
// be non-nil. Given this caveat on `details`, the return state from evaluation will be:
//
// * `val`, `details`, `nil` - Successful evaluation of a non-error result.
@ -41,7 +43,16 @@ type Program interface {
// An unsuccessful evaluation is typically the result of a series of incompatible `EnvOption`
// or `ProgramOption` values used in the creation of the evaluation environment or executable
// program.
Eval(vars interface{}) (ref.Val, *EvalDetails, error)
Eval(interface{}) (ref.Val, *EvalDetails, error)
// ContextEval evaluates the program with a set of input variables and a context object in order
// to support cancellation and timeouts. This method must be used in conjunction with the
// InterruptCheckFrequency() option for cancellation interrupts to be impact evaluation.
//
// The vars value may eitehr be an `interpreter.Activation` or `map[string]interface{}`.
//
// The output contract for `ContextEval` is otherwise identical to the `Eval` method.
ContextEval(context.Context, interface{}) (ref.Val, *EvalDetails, error)
}
// NoVars returns an empty Activation.
@ -81,7 +92,8 @@ func AttributePattern(varName string) *interpreter.AttributePattern {
// EvalDetails holds additional information observed during the Eval() call.
type EvalDetails struct {
state interpreter.EvalState
state interpreter.EvalState
costTracker *interpreter.CostTracker
}
// State of the evaluation, non-nil if the OptTrackState or OptExhaustiveEval is specified
@ -90,24 +102,43 @@ func (ed *EvalDetails) State() interpreter.EvalState {
return ed.state
}
func (ed *EvalDetails) ActualCost() *uint64 {
if ed.costTracker == nil {
return nil
}
cost := ed.costTracker.ActualCost()
return &cost
}
// prog is the internal implementation of the Program interface.
type prog struct {
*Env
evalOpts EvalOption
decorators []interpreter.InterpretableDecorator
defaultVars interpreter.Activation
dispatcher interpreter.Dispatcher
interpreter interpreter.Interpreter
interpretable interpreter.Interpretable
attrFactory interpreter.AttributeFactory
evalOpts EvalOption
defaultVars interpreter.Activation
dispatcher interpreter.Dispatcher
interpreter interpreter.Interpreter
interruptCheckFrequency uint
// Intermediate state used to configure the InterpretableDecorator set provided
// to the initInterpretable call.
decorators []interpreter.InterpretableDecorator
regexOptimizations []*interpreter.RegexOptimization
// Interpretable configured from an Ast and aggregate decorator set based on program options.
interpretable interpreter.Interpretable
callCostEstimator interpreter.ActualCostEstimator
costLimit *uint64
}
// progFactory is a helper alias for marking a program creation factory function.
type progFactory func(interpreter.EvalState) (Program, error)
// progGen holds a reference to a progFactory instance and implements the Program interface.
type progGen struct {
factory progFactory
func (p *prog) clone() *prog {
return &prog{
Env: p.Env,
evalOpts: p.evalOpts,
defaultVars: p.defaultVars,
dispatcher: p.dispatcher,
interpreter: p.interpreter,
interruptCheckFrequency: p.interruptCheckFrequency,
}
}
// newProgram creates a program instance with an environment, an ast, and an optional list of
@ -129,9 +160,6 @@ func newProgram(e *Env, ast *Ast, opts []ProgramOption) (Program, error) {
// Configure the program via the ProgramOption values.
var err error
for _, opt := range opts {
if opt == nil {
return nil, fmt.Errorf("program options should be non-nil")
}
p, err = opt(p)
if err != nil {
return nil, err
@ -139,97 +167,86 @@ func newProgram(e *Env, ast *Ast, opts []ProgramOption) (Program, error) {
}
// Set the attribute factory after the options have been set.
var attrFactory interpreter.AttributeFactory
if p.evalOpts&OptPartialEval == OptPartialEval {
p.attrFactory = interpreter.NewPartialAttributeFactory(e.Container, e.adapter, e.provider)
attrFactory = interpreter.NewPartialAttributeFactory(e.Container, e.adapter, e.provider)
} else {
p.attrFactory = interpreter.NewAttributeFactory(e.Container, e.adapter, e.provider)
attrFactory = interpreter.NewAttributeFactory(e.Container, e.adapter, e.provider)
}
interp := interpreter.NewInterpreter(disp, e.Container, e.provider, e.adapter, p.attrFactory)
interp := interpreter.NewInterpreter(disp, e.Container, e.provider, e.adapter, attrFactory)
p.interpreter = interp
// Translate the EvalOption flags into InterpretableDecorator instances.
decorators := make([]interpreter.InterpretableDecorator, len(p.decorators))
copy(decorators, p.decorators)
// Enable interrupt checking if there's a non-zero check frequency
if p.interruptCheckFrequency > 0 {
decorators = append(decorators, interpreter.InterruptableEval())
}
// Enable constant folding first.
if p.evalOpts&OptOptimize == OptOptimize {
decorators = append(decorators, interpreter.Optimize())
p.regexOptimizations = append(p.regexOptimizations, interpreter.MatchesRegexOptimization)
}
// Enable exhaustive eval over state tracking since it offers a superset of features.
if p.evalOpts&OptExhaustiveEval == OptExhaustiveEval {
// State tracking requires that each Eval() call operate on an isolated EvalState
// object; hence, the presence of the factory.
factory := func(state interpreter.EvalState) (Program, error) {
decs := append(decorators, interpreter.ExhaustiveEval(state))
clone := &prog{
evalOpts: p.evalOpts,
defaultVars: p.defaultVars,
Env: e,
dispatcher: disp,
interpreter: interp}
return initInterpretable(clone, ast, decs)
// Enable regex compilation of constants immediately after folding constants.
if len(p.regexOptimizations) > 0 {
decorators = append(decorators, interpreter.CompileRegexConstants(p.regexOptimizations...))
}
// Enable exhaustive eval, state tracking and cost tracking last since they require a factory.
if p.evalOpts&(OptExhaustiveEval|OptTrackState|OptTrackCost) != 0 {
factory := func(state interpreter.EvalState, costTracker *interpreter.CostTracker) (Program, error) {
costTracker.Estimator = p.callCostEstimator
costTracker.Limit = p.costLimit
decs := decorators
var observers []interpreter.EvalObserver
if p.evalOpts&(OptExhaustiveEval|OptTrackState) != 0 {
// EvalStateObserver is required for OptExhaustiveEval.
observers = append(observers, interpreter.EvalStateObserver(state))
}
if p.evalOpts&OptTrackCost == OptTrackCost {
observers = append(observers, interpreter.CostObserver(costTracker))
}
// Enable exhaustive eval over a basic observer since it offers a superset of features.
if p.evalOpts&OptExhaustiveEval == OptExhaustiveEval {
decs = append(decs, interpreter.ExhaustiveEval(), interpreter.Observe(observers...))
} else if len(observers) > 0 {
decs = append(decs, interpreter.Observe(observers...))
}
return p.clone().initInterpretable(ast, decs)
}
return initProgGen(factory)
return newProgGen(factory)
}
// Enable state tracking last since it too requires the factory approach but is less
// featured than the ExhaustiveEval decorator.
if p.evalOpts&OptTrackState == OptTrackState {
factory := func(state interpreter.EvalState) (Program, error) {
decs := append(decorators, interpreter.TrackState(state))
clone := &prog{
evalOpts: p.evalOpts,
defaultVars: p.defaultVars,
Env: e,
dispatcher: disp,
interpreter: interp}
return initInterpretable(clone, ast, decs)
}
return initProgGen(factory)
}
return initInterpretable(p, ast, decorators)
return p.initInterpretable(ast, decorators)
}
// initProgGen tests the factory object by calling it once and returns a factory-based Program if
// the test is successful.
func initProgGen(factory progFactory) (Program, error) {
// Test the factory to make sure that configuration errors are spotted at config
_, err := factory(interpreter.NewEvalState())
if err != nil {
return nil, err
}
return &progGen{factory: factory}, nil
}
// initIterpretable creates a checked or unchecked interpretable depending on whether the Ast
// has been run through the type-checker.
func initInterpretable(
p *prog,
ast *Ast,
decorators []interpreter.InterpretableDecorator) (Program, error) {
var err error
// Unchecked programs do not contain type and reference information and may be
// slower to execute than their checked counterparts.
func (p *prog) initInterpretable(ast *Ast, decs []interpreter.InterpretableDecorator) (*prog, error) {
// Unchecked programs do not contain type and reference information and may be slower to execute.
if !ast.IsChecked() {
p.interpretable, err =
p.interpreter.NewUncheckedInterpretable(ast.Expr(), decorators...)
interpretable, err :=
p.interpreter.NewUncheckedInterpretable(ast.Expr(), decs...)
if err != nil {
return nil, err
}
p.interpretable = interpretable
return p, nil
}
// When the AST has been checked it contains metadata that can be used to speed up program
// execution.
var checked *exprpb.CheckedExpr
checked, err = AstToCheckedExpr(ast)
if err != nil {
return nil, err
}
p.interpretable, err = p.interpreter.NewInterpretable(checked, decorators...)
if err != nil {
return nil, err
}
// When the AST has been checked it contains metadata that can be used to speed up program execution.
var checked *exprpb.CheckedExpr
checked, err := AstToCheckedExpr(ast)
if err != nil {
return nil, err
}
interpretable, err := p.interpreter.NewInterpretable(checked, decs...)
if err != nil {
return nil, err
}
p.interpretable = interpretable
return p, nil
}
@ -240,13 +257,24 @@ func (p *prog) Eval(input interface{}) (v ref.Val, det *EvalDetails, err error)
// function.
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("internal error: %v", r)
switch t := r.(type) {
case interpreter.EvalCancelledError:
err = t
default:
err = fmt.Errorf("internal error: %v", r)
}
}
}()
// Build a hierarchical activation if there are default vars set.
vars, err := interpreter.NewActivation(input)
if err != nil {
return
var vars interpreter.Activation
switch v := input.(type) {
case interpreter.Activation:
vars = v
case map[string]interface{}:
vars = activationPool.Setup(v)
defer activationPool.Put(vars)
default:
return nil, nil, fmt.Errorf("invalid input, wanted Activation or map[string]interface{}, got: (%T)%v", input, input)
}
if p.defaultVars != nil {
vars = interpreter.NewHierarchicalActivation(p.defaultVars, vars)
@ -261,23 +289,63 @@ func (p *prog) Eval(input interface{}) (v ref.Val, det *EvalDetails, err error)
return
}
// ContextEval implements the Program interface.
func (p *prog) ContextEval(ctx context.Context, input interface{}) (ref.Val, *EvalDetails, error) {
// Configure the input, making sure to wrap Activation inputs in the special ctxActivation which
// exposes the #interrupted variable and manages rate-limited checks of the ctx.Done() state.
var vars interpreter.Activation
switch v := input.(type) {
case interpreter.Activation:
vars = ctxActivationPool.Setup(v, ctx.Done(), p.interruptCheckFrequency)
defer ctxActivationPool.Put(vars)
case map[string]interface{}:
rawVars := activationPool.Setup(v)
defer activationPool.Put(rawVars)
vars = ctxActivationPool.Setup(rawVars, ctx.Done(), p.interruptCheckFrequency)
defer ctxActivationPool.Put(vars)
default:
return nil, nil, fmt.Errorf("invalid input, wanted Activation or map[string]interface{}, got: (%T)%v", input, input)
}
return p.Eval(vars)
}
// Cost implements the Coster interface method.
func (p *prog) Cost() (min, max int64) {
return estimateCost(p.interpretable)
}
// progFactory is a helper alias for marking a program creation factory function.
type progFactory func(interpreter.EvalState, *interpreter.CostTracker) (Program, error)
// progGen holds a reference to a progFactory instance and implements the Program interface.
type progGen struct {
factory progFactory
}
// newProgGen tests the factory object by calling it once and returns a factory-based Program if
// the test is successful.
func newProgGen(factory progFactory) (Program, error) {
// Test the factory to make sure that configuration errors are spotted at config
_, err := factory(interpreter.NewEvalState(), &interpreter.CostTracker{})
if err != nil {
return nil, err
}
return &progGen{factory: factory}, nil
}
// Eval implements the Program interface method.
func (gen *progGen) Eval(input interface{}) (ref.Val, *EvalDetails, error) {
// The factory based Eval() differs from the standard evaluation model in that it generates a
// new EvalState instance for each call to ensure that unique evaluations yield unique stateful
// results.
state := interpreter.NewEvalState()
det := &EvalDetails{state: state}
costTracker := &interpreter.CostTracker{}
det := &EvalDetails{state: state, costTracker: costTracker}
// Generate a new instance of the interpretable using the factory configured during the call to
// newProgram(). It is incredibly unlikely that the factory call will generate an error given
// the factory test performed within the Program() call.
p, err := gen.factory(state)
p, err := gen.factory(state, costTracker)
if err != nil {
return nil, det, err
}
@ -290,20 +358,40 @@ func (gen *progGen) Eval(input interface{}) (ref.Val, *EvalDetails, error) {
return v, det, nil
}
// ContextEval implements the Program interface method.
func (gen *progGen) ContextEval(ctx context.Context, input interface{}) (ref.Val, *EvalDetails, error) {
// The factory based Eval() differs from the standard evaluation model in that it generates a
// new EvalState instance for each call to ensure that unique evaluations yield unique stateful
// results.
state := interpreter.NewEvalState()
det := &EvalDetails{state: state}
// Generate a new instance of the interpretable using the factory configured during the call to
// newProgram(). It is incredibly unlikely that the factory call will generate an error given
// the factory test performed within the Program() call.
p, err := gen.factory(state, &interpreter.CostTracker{})
if err != nil {
return nil, det, err
}
// Evaluate the input, returning the result and the 'state' within EvalDetails.
v, _, err := p.ContextEval(ctx, input)
if err != nil {
return v, det, err
}
return v, det, nil
}
// Cost implements the Coster interface method.
func (gen *progGen) Cost() (min, max int64) {
// Use an empty state value since no evaluation is performed.
p, err := gen.factory(emptyEvalState)
p, err := gen.factory(emptyEvalState, nil)
if err != nil {
return 0, math.MaxInt64
}
return estimateCost(p)
}
var (
emptyEvalState = interpreter.NewEvalState()
)
// EstimateCost returns the heuristic cost interval for the program.
func EstimateCost(p Program) (min, max int64) {
return estimateCost(p)
@ -316,3 +404,140 @@ func estimateCost(i interface{}) (min, max int64) {
}
return c.Cost()
}
type ctxEvalActivation struct {
parent interpreter.Activation
interrupt <-chan struct{}
interruptCheckCount uint
interruptCheckFrequency uint
}
// ResolveName implements the Activation interface method, but adds a special #interrupted variable
// which is capable of testing whether a 'done' signal is provided from a context.Context channel.
func (a *ctxEvalActivation) ResolveName(name string) (interface{}, bool) {
if name == "#interrupted" {
a.interruptCheckCount++
if a.interruptCheckCount%a.interruptCheckFrequency == 0 {
select {
case <-a.interrupt:
return true, true
default:
return nil, false
}
}
return nil, false
}
return a.parent.ResolveName(name)
}
func (a *ctxEvalActivation) Parent() interpreter.Activation {
return a.parent
}
func newCtxEvalActivationPool() *ctxEvalActivationPool {
return &ctxEvalActivationPool{
Pool: sync.Pool{
New: func() interface{} {
return &ctxEvalActivation{}
},
},
}
}
type ctxEvalActivationPool struct {
sync.Pool
}
// Setup initializes a pooled Activation with the ability check for context.Context cancellation
func (p *ctxEvalActivationPool) Setup(vars interpreter.Activation, done <-chan struct{}, interruptCheckRate uint) *ctxEvalActivation {
a := p.Pool.Get().(*ctxEvalActivation)
a.parent = vars
a.interrupt = done
a.interruptCheckCount = 0
a.interruptCheckFrequency = interruptCheckRate
return a
}
type evalActivation struct {
vars map[string]interface{}
lazyVars map[string]interface{}
}
// ResolveName looks up the value of the input variable name, if found.
//
// Lazy bindings may be supplied within the map-based input in either of the following forms:
// - func() interface{}
// - func() ref.Val
//
// The lazy binding will only be invoked once per evaluation.
//
// Values which are not represented as ref.Val types on input may be adapted to a ref.Val using
// the ref.TypeAdapter configured in the environment.
func (a *evalActivation) ResolveName(name string) (interface{}, bool) {
v, found := a.vars[name]
if !found {
return nil, false
}
switch obj := v.(type) {
case func() ref.Val:
if resolved, found := a.lazyVars[name]; found {
return resolved, true
}
lazy := obj()
a.lazyVars[name] = lazy
return lazy, true
case func() interface{}:
if resolved, found := a.lazyVars[name]; found {
return resolved, true
}
lazy := obj()
a.lazyVars[name] = lazy
return lazy, true
default:
return obj, true
}
}
// Parent implements the interpreter.Activation interface
func (a *evalActivation) Parent() interpreter.Activation {
return nil
}
func newEvalActivationPool() *evalActivationPool {
return &evalActivationPool{
Pool: sync.Pool{
New: func() interface{} {
return &evalActivation{lazyVars: make(map[string]interface{})}
},
},
}
}
type evalActivationPool struct {
sync.Pool
}
// Setup initializes a pooled Activation object with the map input.
func (p *evalActivationPool) Setup(vars map[string]interface{}) *evalActivation {
a := p.Pool.Get().(*evalActivation)
a.vars = vars
return a
}
func (p *evalActivationPool) Put(value interface{}) {
a := value.(*evalActivation)
for k := range a.lazyVars {
delete(a.lazyVars, k)
}
p.Pool.Put(a)
}
var (
emptyEvalState = interpreter.NewEvalState()
// activationPool is an internally managed pool of Activation values that wrap map[string]interface{} inputs
activationPool = newEvalActivationPool()
// ctxActivationPool is an internally managed pool of Activation values that expose a special #interrupted variable
ctxActivationPool = newCtxEvalActivationPool()
)

View File

@ -8,9 +8,11 @@ go_library(
name = "go_default_library",
srcs = [
"checker.go",
"cost.go",
"env.go",
"errors.go",
"mapping.go",
"options.go",
"printer.go",
"standard.go",
"types.go",
@ -40,6 +42,7 @@ go_test(
size = "small",
srcs = [
"checker_test.go",
"cost_test.go",
"env_test.go",
],
embed = [

602
vendor/github.com/google/cel-go/checker/cost.go generated vendored Normal file
View File

@ -0,0 +1,602 @@
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package checker
import (
"math"
"github.com/google/cel-go/common"
"github.com/google/cel-go/common/overloads"
"github.com/google/cel-go/parser"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
// WARNING: Any changes to cost calculations in this file require a corresponding change in interpreter/runtimecost.go
// CostEstimator estimates the sizes of variable length input data and the costs of functions.
type CostEstimator interface {
// EstimateSize returns a SizeEstimate for the given AstNode, or nil if
// the estimator has no estimate to provide. The size is equivalent to the result of the CEL `size()` function:
// length of strings and bytes, number of map entries or number of list items.
// EstimateSize is only called for AstNodes where
// CEL does not know the size; EstimateSize is not called for values defined inline in CEL where the size
// is already obvious to CEL.
EstimateSize(element AstNode) *SizeEstimate
// EstimateCallCost returns the estimated cost of an invocation, or nil if
// the estimator has no estimate to provide.
EstimateCallCost(overloadId string, target *AstNode, args []AstNode) *CallEstimate
}
// CallEstimate includes a CostEstimate for the call, and an optional estimate of the result object size.
// The ResultSize should only be provided if the call results in a map, list, string or bytes.
type CallEstimate struct {
CostEstimate
ResultSize *SizeEstimate
}
// AstNode represents an AST node for the purpose of cost estimations.
type AstNode interface {
// Path returns a field path through the provided type declarations to the type of the AstNode, or nil if the AstNode does not
// represent type directly reachable from the provided type declarations.
// The first path element is a variable. All subsequent path elements are one of: field name, '@items', '@keys', '@values'.
Path() []string
// Type returns the deduced type of the AstNode.
Type() *exprpb.Type
// Expr returns the expression of the AstNode.
Expr() *exprpb.Expr
// ComputedSize returns a size estimate of the AstNode derived from information available in the CEL expression.
// For constants and inline list and map declarations, the exact size is returned. For concatenated list, strings
// and bytes, the size is derived from the size estimates of the operands. nil is returned if there is no
// computed size available.
ComputedSize() *SizeEstimate
}
type astNode struct {
path []string
t *exprpb.Type
expr *exprpb.Expr
derivedSize *SizeEstimate
}
func (e astNode) Path() []string {
return e.path
}
func (e astNode) Type() *exprpb.Type {
return e.t
}
func (e astNode) Expr() *exprpb.Expr {
return e.expr
}
func (e astNode) ComputedSize() *SizeEstimate {
if e.derivedSize != nil {
return e.derivedSize
}
var v uint64
switch ek := e.expr.ExprKind.(type) {
case *exprpb.Expr_ConstExpr:
switch ck := ek.ConstExpr.ConstantKind.(type) {
case *exprpb.Constant_StringValue:
v = uint64(len(ck.StringValue))
case *exprpb.Constant_BytesValue:
v = uint64(len(ck.BytesValue))
case *exprpb.Constant_BoolValue, *exprpb.Constant_DoubleValue, *exprpb.Constant_DurationValue,
*exprpb.Constant_Int64Value, *exprpb.Constant_TimestampValue, *exprpb.Constant_Uint64Value,
*exprpb.Constant_NullValue:
v = uint64(1)
default:
return nil
}
case *exprpb.Expr_ListExpr:
v = uint64(len(ek.ListExpr.Elements))
case *exprpb.Expr_StructExpr:
if ek.StructExpr.MessageName == "" {
v = uint64(len(ek.StructExpr.Entries))
}
default:
return nil
}
return &SizeEstimate{Min: v, Max: v}
}
// SizeEstimate represents an estimated size of a variable length string, bytes, map or list.
type SizeEstimate struct {
Min, Max uint64
}
// Add adds to another SizeEstimate and returns the sum.
// If add would result in an uint64 overflow, the result is Maxuint64.
func (se SizeEstimate) Add(sizeEstimate SizeEstimate) SizeEstimate {
return SizeEstimate{
addUint64NoOverflow(se.Min, sizeEstimate.Min),
addUint64NoOverflow(se.Max, sizeEstimate.Max),
}
}
// Multiply multiplies by another SizeEstimate and returns the product.
// If multiply would result in an uint64 overflow, the result is Maxuint64.
func (se SizeEstimate) Multiply(sizeEstimate SizeEstimate) SizeEstimate {
return SizeEstimate{
multiplyUint64NoOverflow(se.Min, sizeEstimate.Min),
multiplyUint64NoOverflow(se.Max, sizeEstimate.Max),
}
}
// MultiplyByCostFactor multiplies a SizeEstimate by a cost factor and returns the CostEstimate with the
// nearest integer of the result, rounded up.
func (se SizeEstimate) MultiplyByCostFactor(costPerUnit float64) CostEstimate {
return CostEstimate{
multiplyByCostFactor(se.Min, costPerUnit),
multiplyByCostFactor(se.Max, costPerUnit),
}
}
// MultiplyByCost multiplies by the cost and returns the product.
// If multiply would result in an uint64 overflow, the result is Maxuint64.
func (se SizeEstimate) MultiplyByCost(cost CostEstimate) CostEstimate {
return CostEstimate{
multiplyUint64NoOverflow(se.Min, cost.Min),
multiplyUint64NoOverflow(se.Max, cost.Max),
}
}
// Union returns a SizeEstimate that encompasses both input the SizeEstimate.
func (se SizeEstimate) Union(size SizeEstimate) SizeEstimate {
result := se
if size.Min < result.Min {
result.Min = size.Min
}
if size.Max > result.Max {
result.Max = size.Max
}
return result
}
// CostEstimate represents an estimated cost range and provides add and multiply operations
// that do not overflow.
type CostEstimate struct {
Min, Max uint64
}
// Add adds the costs and returns the sum.
// If add would result in an uint64 overflow for the min or max, the value is set to Maxuint64.
func (ce CostEstimate) Add(cost CostEstimate) CostEstimate {
return CostEstimate{
addUint64NoOverflow(ce.Min, cost.Min),
addUint64NoOverflow(ce.Max, cost.Max),
}
}
// Multiply multiplies by the cost and returns the product.
// If multiply would result in an uint64 overflow, the result is Maxuint64.
func (ce CostEstimate) Multiply(cost CostEstimate) CostEstimate {
return CostEstimate{
multiplyUint64NoOverflow(ce.Min, cost.Min),
multiplyUint64NoOverflow(ce.Max, cost.Max),
}
}
// MultiplyByCostFactor multiplies a CostEstimate by a cost factor and returns the CostEstimate with the
// nearest integer of the result, rounded up.
func (ce CostEstimate) MultiplyByCostFactor(costPerUnit float64) CostEstimate {
return CostEstimate{
multiplyByCostFactor(ce.Min, costPerUnit),
multiplyByCostFactor(ce.Max, costPerUnit),
}
}
// Union returns a CostEstimate that encompasses both input the CostEstimates.
func (ce CostEstimate) Union(size CostEstimate) CostEstimate {
result := ce
if size.Min < result.Min {
result.Min = size.Min
}
if size.Max > result.Max {
result.Max = size.Max
}
return result
}
// addUint64NoOverflow adds non-negative ints. If the result is exceeds math.MaxUint64, math.MaxUint64
// is returned.
func addUint64NoOverflow(x, y uint64) uint64 {
if y > 0 && x > math.MaxUint64-y {
return math.MaxUint64
}
return x + y
}
// multiplyUint64NoOverflow multiplies non-negative ints. If the result is exceeds math.MaxUint64, math.MaxUint64
// is returned.
func multiplyUint64NoOverflow(x, y uint64) uint64 {
if x > 0 && y > 0 && x > math.MaxUint64/y {
return math.MaxUint64
}
return x * y
}
// multiplyByFactor multiplies an integer by a cost factor float and returns the nearest integer value, rounded up.
func multiplyByCostFactor(x uint64, y float64) uint64 {
xFloat := float64(x)
if xFloat > 0 && y > 0 && xFloat > math.MaxUint64/y {
return math.MaxUint64
}
return uint64(math.Ceil(xFloat * y))
}
var (
selectAndIdentCost = CostEstimate{Min: common.SelectAndIdentCost, Max: common.SelectAndIdentCost}
constCost = CostEstimate{Min: common.ConstCost, Max: common.ConstCost}
createListBaseCost = CostEstimate{Min: common.ListCreateBaseCost, Max: common.ListCreateBaseCost}
createMapBaseCost = CostEstimate{Min: common.MapCreateBaseCost, Max: common.MapCreateBaseCost}
createMessageBaseCost = CostEstimate{Min: common.StructCreateBaseCost, Max: common.StructCreateBaseCost}
)
type coster struct {
// exprPath maps from Expr Id to field path.
exprPath map[int64][]string
// iterRanges tracks the iterRange of each iterVar.
iterRanges iterRangeScopes
// computedSizes tracks the computed sizes of call results.
computedSizes map[int64]SizeEstimate
checkedExpr *exprpb.CheckedExpr
estimator CostEstimator
}
// Use a stack of iterVar -> iterRange Expr Ids to handle shadowed variable names.
type iterRangeScopes map[string][]int64
func (vs iterRangeScopes) push(varName string, expr *exprpb.Expr) {
vs[varName] = append(vs[varName], expr.GetId())
}
func (vs iterRangeScopes) pop(varName string) {
varStack := vs[varName]
vs[varName] = varStack[:len(varStack)-1]
}
func (vs iterRangeScopes) peek(varName string) (int64, bool) {
varStack := vs[varName]
if len(varStack) > 0 {
return varStack[len(varStack)-1], true
}
return 0, false
}
// Cost estimates the cost of the parsed and type checked CEL expression.
func Cost(checker *exprpb.CheckedExpr, estimator CostEstimator) CostEstimate {
c := coster{
checkedExpr: checker,
estimator: estimator,
exprPath: map[int64][]string{},
iterRanges: map[string][]int64{},
computedSizes: map[int64]SizeEstimate{},
}
return c.cost(checker.GetExpr())
}
func (c *coster) cost(e *exprpb.Expr) CostEstimate {
if e == nil {
return CostEstimate{}
}
var cost CostEstimate
switch e.ExprKind.(type) {
case *exprpb.Expr_ConstExpr:
cost = constCost
case *exprpb.Expr_IdentExpr:
cost = c.costIdent(e)
case *exprpb.Expr_SelectExpr:
cost = c.costSelect(e)
case *exprpb.Expr_CallExpr:
cost = c.costCall(e)
case *exprpb.Expr_ListExpr:
cost = c.costCreateList(e)
case *exprpb.Expr_StructExpr:
cost = c.costCreateStruct(e)
case *exprpb.Expr_ComprehensionExpr:
cost = c.costComprehension(e)
default:
return CostEstimate{}
}
return cost
}
func (c *coster) costIdent(e *exprpb.Expr) CostEstimate {
identExpr := e.GetIdentExpr()
// build and track the field path
if iterRange, ok := c.iterRanges.peek(identExpr.GetName()); ok {
switch c.checkedExpr.TypeMap[iterRange].TypeKind.(type) {
case *exprpb.Type_ListType_:
c.addPath(e, append(c.exprPath[iterRange], "@items"))
case *exprpb.Type_MapType_:
c.addPath(e, append(c.exprPath[iterRange], "@keys"))
}
} else {
c.addPath(e, []string{identExpr.GetName()})
}
return selectAndIdentCost
}
func (c *coster) costSelect(e *exprpb.Expr) CostEstimate {
sel := e.GetSelectExpr()
var sum CostEstimate
if sel.GetTestOnly() {
return sum
}
sum = sum.Add(c.cost(sel.GetOperand()))
targetType := c.getType(sel.GetOperand())
switch kindOf(targetType) {
case kindMap, kindObject, kindTypeParam:
sum = sum.Add(selectAndIdentCost)
}
// build and track the field path
c.addPath(e, append(c.getPath(sel.GetOperand()), sel.Field))
return sum
}
func (c *coster) costCall(e *exprpb.Expr) CostEstimate {
call := e.GetCallExpr()
target := call.GetTarget()
args := call.GetArgs()
var sum CostEstimate
argTypes := make([]AstNode, len(args))
argCosts := make([]CostEstimate, len(args))
for i, arg := range args {
argCosts[i] = c.cost(arg)
argTypes[i] = c.newAstNode(arg)
}
ref := c.checkedExpr.ReferenceMap[e.GetId()]
if ref == nil || len(ref.GetOverloadId()) == 0 {
return CostEstimate{}
}
var targetType AstNode
if target != nil {
if call.Target != nil {
sum = sum.Add(c.cost(call.GetTarget()))
targetType = c.newAstNode(call.GetTarget())
}
}
// Pick a cost estimate range that covers all the overload cost estimation ranges
fnCost := CostEstimate{Min: uint64(math.MaxUint64), Max: 0}
var resultSize *SizeEstimate
for _, overload := range ref.GetOverloadId() {
overloadCost := c.functionCost(overload, &targetType, argTypes, argCosts)
fnCost = fnCost.Union(overloadCost.CostEstimate)
if overloadCost.ResultSize != nil {
if resultSize == nil {
resultSize = overloadCost.ResultSize
} else {
size := resultSize.Union(*overloadCost.ResultSize)
resultSize = &size
}
}
// build and track the field path for index operations
switch overload {
case overloads.IndexList:
if len(args) > 0 {
c.addPath(e, append(c.getPath(args[0]), "@items"))
}
case overloads.IndexMap:
if len(args) > 0 {
c.addPath(e, append(c.getPath(args[0]), "@values"))
}
}
}
if resultSize != nil {
c.computedSizes[e.GetId()] = *resultSize
}
return sum.Add(fnCost)
}
func (c *coster) costCreateList(e *exprpb.Expr) CostEstimate {
create := e.GetListExpr()
var sum CostEstimate
for _, e := range create.GetElements() {
sum = sum.Add(c.cost(e))
}
return sum.Add(createListBaseCost)
}
func (c *coster) costCreateStruct(e *exprpb.Expr) CostEstimate {
str := e.GetStructExpr()
if str.MessageName != "" {
return c.costCreateMessage(e)
} else {
return c.costCreateMap(e)
}
}
func (c *coster) costCreateMap(e *exprpb.Expr) CostEstimate {
mapVal := e.GetStructExpr()
var sum CostEstimate
for _, ent := range mapVal.GetEntries() {
key := ent.GetMapKey()
sum = sum.Add(c.cost(key))
sum = sum.Add(c.cost(ent.GetValue()))
}
return sum.Add(createMapBaseCost)
}
func (c *coster) costCreateMessage(e *exprpb.Expr) CostEstimate {
msgVal := e.GetStructExpr()
var sum CostEstimate
for _, ent := range msgVal.GetEntries() {
sum = sum.Add(c.cost(ent.GetValue()))
}
return sum.Add(createMessageBaseCost)
}
func (c *coster) costComprehension(e *exprpb.Expr) CostEstimate {
comp := e.GetComprehensionExpr()
var sum CostEstimate
sum = sum.Add(c.cost(comp.GetIterRange()))
sum = sum.Add(c.cost(comp.GetAccuInit()))
// Track the iterRange of each IterVar for field path construction
c.iterRanges.push(comp.GetIterVar(), comp.GetIterRange())
loopCost := c.cost(comp.GetLoopCondition())
stepCost := c.cost(comp.GetLoopStep())
c.iterRanges.pop(comp.GetIterVar())
sum = sum.Add(c.cost(comp.Result))
rangeCnt := c.sizeEstimate(c.newAstNode(comp.GetIterRange()))
rangeCost := rangeCnt.MultiplyByCost(stepCost.Add(loopCost))
sum = sum.Add(rangeCost)
return sum
}
func (c *coster) sizeEstimate(t AstNode) SizeEstimate {
if l := t.ComputedSize(); l != nil {
return *l
}
if l := c.estimator.EstimateSize(t); l != nil {
return *l
}
return SizeEstimate{Min: 0, Max: math.MaxUint64}
}
func (c *coster) functionCost(overloadId string, target *AstNode, args []AstNode, argCosts []CostEstimate) CallEstimate {
argCostSum := func() CostEstimate {
var sum CostEstimate
for _, a := range argCosts {
sum = sum.Add(a)
}
return sum
}
if est := c.estimator.EstimateCallCost(overloadId, target, args); est != nil {
callEst := *est
return CallEstimate{CostEstimate: callEst.Add(argCostSum())}
}
switch overloadId {
// O(n) functions
case overloads.StartsWithString, overloads.EndsWithString, overloads.StringToBytes, overloads.BytesToString:
if len(args) == 1 {
return CallEstimate{CostEstimate: c.sizeEstimate(args[0]).MultiplyByCostFactor(common.StringTraversalCostFactor).Add(argCostSum())}
}
case overloads.InList:
// If a list is composed entirely of constant values this is O(1), but we don't account for that here.
// We just assume all list containment checks are O(n).
if len(args) == 2 {
return CallEstimate{CostEstimate: c.sizeEstimate(args[1]).MultiplyByCostFactor(1).Add(argCostSum())}
}
// O(nm) functions
case overloads.MatchesString:
// https://swtch.com/~rsc/regexp/regexp1.html applies to RE2 implementation supported by CEL
if target != nil && len(args) == 1 {
// Add one to string length for purposes of cost calculation to prevent product of string and regex to be 0
// in case where string is empty but regex is still expensive.
strCost := c.sizeEstimate(*target).Add(SizeEstimate{Min: 1, Max: 1}).MultiplyByCostFactor(common.StringTraversalCostFactor)
// We don't know how many expressions are in the regex, just the string length (a huge
// improvement here would be to somehow get a count the number of expressions in the regex or
// how many states are in the regex state machine and use that to measure regex cost).
// For now, we're making a guess that each expression in a regex is typically at least 4 chars
// in length.
regexCost := c.sizeEstimate(args[0]).MultiplyByCostFactor(common.RegexStringLengthCostFactor)
return CallEstimate{CostEstimate: strCost.Multiply(regexCost).Add(argCostSum())}
}
case overloads.ContainsString:
if target != nil && len(args) == 1 {
strCost := c.sizeEstimate(*target).MultiplyByCostFactor(common.StringTraversalCostFactor)
substrCost := c.sizeEstimate(args[0]).MultiplyByCostFactor(common.StringTraversalCostFactor)
return CallEstimate{CostEstimate: strCost.Multiply(substrCost).Add(argCostSum())}
}
case overloads.LogicalOr, overloads.LogicalAnd:
lhs := argCosts[0]
rhs := argCosts[1]
// min cost is min of LHS for short circuited && or ||
argCost := CostEstimate{Min: lhs.Min, Max: lhs.Add(rhs).Max}
return CallEstimate{CostEstimate: argCost}
case overloads.Conditional:
size := c.sizeEstimate(args[1]).Union(c.sizeEstimate(args[2]))
conditionalCost := argCosts[0]
ifTrueCost := argCosts[1]
ifFalseCost := argCosts[2]
argCost := conditionalCost.Add(ifTrueCost.Union(ifFalseCost))
return CallEstimate{CostEstimate: argCost, ResultSize: &size}
case overloads.AddString, overloads.AddBytes, overloads.AddList:
if len(args) == 2 {
lhsSize := c.sizeEstimate(args[0])
rhsSize := c.sizeEstimate(args[1])
resultSize := lhsSize.Add(rhsSize)
switch overloadId {
case overloads.AddList:
// list concatenation is O(1), but we handle it here to track size
return CallEstimate{CostEstimate: CostEstimate{Min: 1, Max: 1}.Add(argCostSum()), ResultSize: &resultSize}
default:
return CallEstimate{CostEstimate: resultSize.MultiplyByCostFactor(common.StringTraversalCostFactor).Add(argCostSum()), ResultSize: &resultSize}
}
}
case overloads.LessString, overloads.GreaterString, overloads.LessEqualsString, overloads.GreaterEqualsString,
overloads.LessBytes, overloads.GreaterBytes, overloads.LessEqualsBytes, overloads.GreaterEqualsBytes,
overloads.Equals, overloads.NotEquals:
lhsCost := c.sizeEstimate(args[0])
rhsCost := c.sizeEstimate(args[1])
min := uint64(0)
smallestMax := lhsCost.Max
if rhsCost.Max < smallestMax {
smallestMax = rhsCost.Max
}
if smallestMax > 0 {
min = 1
}
// equality of 2 scalar values results in a cost of 1
return CallEstimate{CostEstimate: CostEstimate{Min: min, Max: smallestMax}.MultiplyByCostFactor(common.StringTraversalCostFactor).Add(argCostSum())}
}
// O(1) functions
// See CostTracker.costCall for more details about O(1) cost calculations
// Benchmarks suggest that most of the other operations take +/- 50% of a base cost unit
// which on an Intel xeon 2.20GHz CPU is 50ns.
return CallEstimate{CostEstimate: CostEstimate{Min: 1, Max: 1}.Add(argCostSum())}
}
func (c *coster) getType(e *exprpb.Expr) *exprpb.Type {
return c.checkedExpr.TypeMap[e.GetId()]
}
func (c *coster) getPath(e *exprpb.Expr) []string {
return c.exprPath[e.GetId()]
}
func (c *coster) addPath(e *exprpb.Expr, path []string) {
c.exprPath[e.GetId()] = path
}
func (c *coster) newAstNode(e *exprpb.Expr) *astNode {
path := c.getPath(e)
if len(path) > 0 && path[0] == parser.AccumulatorName {
// only provide paths to root vars; omit accumulator vars
path = nil
}
var derivedSize *SizeEstimate
if size, ok := c.computedSizes[e.GetId()]; ok {
derivedSize = &size
}
return &astNode{path: path, t: c.getType(e), expr: e, derivedSize: derivedSize}
}

View File

@ -20,6 +20,7 @@ import (
"github.com/google/cel-go/checker/decls"
"github.com/google/cel-go/common/containers"
"github.com/google/cel-go/common/overloads"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/pb"
"github.com/google/cel-go/common/types/ref"
@ -35,49 +36,76 @@ const (
homogenousElementType aggregateLiteralElementType = 1 << iota
)
var (
crossTypeNumericComparisonOverloads = map[string]struct{}{
// double <-> int | uint
overloads.LessDoubleInt64: {},
overloads.LessDoubleUint64: {},
overloads.LessEqualsDoubleInt64: {},
overloads.LessEqualsDoubleUint64: {},
overloads.GreaterDoubleInt64: {},
overloads.GreaterDoubleUint64: {},
overloads.GreaterEqualsDoubleInt64: {},
overloads.GreaterEqualsDoubleUint64: {},
// int <-> double | uint
overloads.LessInt64Double: {},
overloads.LessInt64Uint64: {},
overloads.LessEqualsInt64Double: {},
overloads.LessEqualsInt64Uint64: {},
overloads.GreaterInt64Double: {},
overloads.GreaterInt64Uint64: {},
overloads.GreaterEqualsInt64Double: {},
overloads.GreaterEqualsInt64Uint64: {},
// uint <-> double | int
overloads.LessUint64Double: {},
overloads.LessUint64Int64: {},
overloads.LessEqualsUint64Double: {},
overloads.LessEqualsUint64Int64: {},
overloads.GreaterUint64Double: {},
overloads.GreaterUint64Int64: {},
overloads.GreaterEqualsUint64Double: {},
overloads.GreaterEqualsUint64Int64: {},
}
)
// Env is the environment for type checking.
//
// The Env is comprised of a container, type provider, declarations, and other related objects
// which can be used to assist with type-checking.
type Env struct {
container *containers.Container
provider ref.TypeProvider
declarations *decls.Scopes
aggLitElemType aggregateLiteralElementType
container *containers.Container
provider ref.TypeProvider
declarations *decls.Scopes
aggLitElemType aggregateLiteralElementType
filteredOverloadIDs map[string]struct{}
}
// NewEnv returns a new *Env with the given parameters.
func NewEnv(container *containers.Container, provider ref.TypeProvider) *Env {
func NewEnv(container *containers.Container, provider ref.TypeProvider, opts ...Option) (*Env, error) {
declarations := decls.NewScopes()
declarations.Push()
envOptions := &options{}
for _, opt := range opts {
if err := opt(envOptions); err != nil {
return nil, err
}
}
aggLitElemType := dynElementType
if envOptions.homogeneousAggregateLiterals {
aggLitElemType = homogenousElementType
}
filteredOverloadIDs := crossTypeNumericComparisonOverloads
if envOptions.crossTypeNumericComparisons {
filteredOverloadIDs = make(map[string]struct{})
}
return &Env{
container: container,
provider: provider,
declarations: declarations,
}
}
// NewStandardEnv returns a new *Env with the given params plus standard declarations.
func NewStandardEnv(container *containers.Container, provider ref.TypeProvider) *Env {
e := NewEnv(container, provider)
if err := e.Add(StandardDeclarations()...); err != nil {
// The standard declaration set should never have duplicate declarations.
panic(err)
}
// TODO: isolate standard declarations from the custom set which may be provided layer.
return e
}
// EnableDynamicAggregateLiterals detmerines whether list and map literals may support mixed
// element types at check-time. This does not preclude the presence of a dynamic list or map
// somewhere in the CEL evaluation process.
func (e *Env) EnableDynamicAggregateLiterals(enabled bool) *Env {
e.aggLitElemType = dynElementType
if !enabled {
e.aggLitElemType = homogenousElementType
}
return e
container: container,
provider: provider,
declarations: declarations,
aggLitElemType: aggLitElemType,
filteredOverloadIDs: filteredOverloadIDs,
}, nil
}
// Add adds new Decl protos to the Env.
@ -189,6 +217,9 @@ func (e *Env) addFunction(decl *exprpb.Decl) []errorMsg {
errorMsgs := make([]errorMsg, 0)
for _, overload := range decl.GetFunction().GetOverloads() {
if _, found := e.filteredOverloadIDs[overload.GetOverloadId()]; found {
continue
}
errorMsgs = append(errorMsgs, e.addOverload(current, overload)...)
}
return errorMsgs

41
vendor/github.com/google/cel-go/checker/options.go generated vendored Normal file
View File

@ -0,0 +1,41 @@
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package checker
type options struct {
crossTypeNumericComparisons bool
homogeneousAggregateLiterals bool
}
// Option is a functional option for configuring the type-checker
type Option func(*options) error
// CrossTypeNumericComparisons toggles type-checker support for numeric comparisons across type
// See https://github.com/google/cel-spec/wiki/proposal-210 for more details.
func CrossTypeNumericComparisons(enabled bool) Option {
return func(opts *options) error {
opts.crossTypeNumericComparisons = enabled
return nil
}
}
// HomogeneousAggregateLiterals toggles support for constructing lists and maps whose elements all
// have the same type.
func HomogeneousAggregateLiterals(enabled bool) Option {
return func(opts *options) error {
opts.homogeneousAggregateLiterals = enabled
return nil
}
}

View File

@ -22,8 +22,11 @@ import (
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
// StandardDeclarations returns the Decls for all functions and constants in the evaluator.
func StandardDeclarations() []*exprpb.Decl {
var (
standardDeclarations []*exprpb.Decl
)
func init() {
// Some shortcuts we use when building declarations.
paramA := decls.NewTypeParamType("A")
typeParamAList := []string{"A"}
@ -45,9 +48,9 @@ func StandardDeclarations() []*exprpb.Decl {
decls.NewVar("null_type", decls.NewTypeType(decls.Null)),
decls.NewVar("type", decls.NewTypeType(decls.NewTypeType(nil))))
// Booleans
// TODO: allow the conditional to return a heterogenous type.
return append(idents, []*exprpb.Decl{
standardDeclarations = append(standardDeclarations, idents...)
standardDeclarations = append(standardDeclarations, []*exprpb.Decl{
// Booleans
decls.NewFunction(operators.Conditional,
decls.NewParameterizedOverload(overloads.Conditional,
[]*exprpb.Type{decls.Bool, paramA, paramA}, paramA,
@ -69,80 +72,6 @@ func StandardDeclarations() []*exprpb.Decl {
decls.NewOverload(overloads.NotStrictlyFalse,
[]*exprpb.Type{decls.Bool}, decls.Bool)),
// Relations.
decls.NewFunction(operators.Less,
decls.NewOverload(overloads.LessBool,
[]*exprpb.Type{decls.Bool, decls.Bool}, decls.Bool),
decls.NewOverload(overloads.LessInt64,
[]*exprpb.Type{decls.Int, decls.Int}, decls.Bool),
decls.NewOverload(overloads.LessUint64,
[]*exprpb.Type{decls.Uint, decls.Uint}, decls.Bool),
decls.NewOverload(overloads.LessDouble,
[]*exprpb.Type{decls.Double, decls.Double}, decls.Bool),
decls.NewOverload(overloads.LessString,
[]*exprpb.Type{decls.String, decls.String}, decls.Bool),
decls.NewOverload(overloads.LessBytes,
[]*exprpb.Type{decls.Bytes, decls.Bytes}, decls.Bool),
decls.NewOverload(overloads.LessTimestamp,
[]*exprpb.Type{decls.Timestamp, decls.Timestamp}, decls.Bool),
decls.NewOverload(overloads.LessDuration,
[]*exprpb.Type{decls.Duration, decls.Duration}, decls.Bool)),
decls.NewFunction(operators.LessEquals,
decls.NewOverload(overloads.LessEqualsBool,
[]*exprpb.Type{decls.Bool, decls.Bool}, decls.Bool),
decls.NewOverload(overloads.LessEqualsInt64,
[]*exprpb.Type{decls.Int, decls.Int}, decls.Bool),
decls.NewOverload(overloads.LessEqualsUint64,
[]*exprpb.Type{decls.Uint, decls.Uint}, decls.Bool),
decls.NewOverload(overloads.LessEqualsDouble,
[]*exprpb.Type{decls.Double, decls.Double}, decls.Bool),
decls.NewOverload(overloads.LessEqualsString,
[]*exprpb.Type{decls.String, decls.String}, decls.Bool),
decls.NewOverload(overloads.LessEqualsBytes,
[]*exprpb.Type{decls.Bytes, decls.Bytes}, decls.Bool),
decls.NewOverload(overloads.LessEqualsTimestamp,
[]*exprpb.Type{decls.Timestamp, decls.Timestamp}, decls.Bool),
decls.NewOverload(overloads.LessEqualsDuration,
[]*exprpb.Type{decls.Duration, decls.Duration}, decls.Bool)),
decls.NewFunction(operators.Greater,
decls.NewOverload(overloads.GreaterBool,
[]*exprpb.Type{decls.Bool, decls.Bool}, decls.Bool),
decls.NewOverload(overloads.GreaterInt64,
[]*exprpb.Type{decls.Int, decls.Int}, decls.Bool),
decls.NewOverload(overloads.GreaterUint64,
[]*exprpb.Type{decls.Uint, decls.Uint}, decls.Bool),
decls.NewOverload(overloads.GreaterDouble,
[]*exprpb.Type{decls.Double, decls.Double}, decls.Bool),
decls.NewOverload(overloads.GreaterString,
[]*exprpb.Type{decls.String, decls.String}, decls.Bool),
decls.NewOverload(overloads.GreaterBytes,
[]*exprpb.Type{decls.Bytes, decls.Bytes}, decls.Bool),
decls.NewOverload(overloads.GreaterTimestamp,
[]*exprpb.Type{decls.Timestamp, decls.Timestamp}, decls.Bool),
decls.NewOverload(overloads.GreaterDuration,
[]*exprpb.Type{decls.Duration, decls.Duration}, decls.Bool)),
decls.NewFunction(operators.GreaterEquals,
decls.NewOverload(overloads.GreaterEqualsBool,
[]*exprpb.Type{decls.Bool, decls.Bool}, decls.Bool),
decls.NewOverload(overloads.GreaterEqualsInt64,
[]*exprpb.Type{decls.Int, decls.Int}, decls.Bool),
decls.NewOverload(overloads.GreaterEqualsUint64,
[]*exprpb.Type{decls.Uint, decls.Uint}, decls.Bool),
decls.NewOverload(overloads.GreaterEqualsDouble,
[]*exprpb.Type{decls.Double, decls.Double}, decls.Bool),
decls.NewOverload(overloads.GreaterEqualsString,
[]*exprpb.Type{decls.String, decls.String}, decls.Bool),
decls.NewOverload(overloads.GreaterEqualsBytes,
[]*exprpb.Type{decls.Bytes, decls.Bytes}, decls.Bool),
decls.NewOverload(overloads.GreaterEqualsTimestamp,
[]*exprpb.Type{decls.Timestamp, decls.Timestamp}, decls.Bool),
decls.NewOverload(overloads.GreaterEqualsDuration,
[]*exprpb.Type{decls.Duration, decls.Duration}, decls.Bool)),
decls.NewFunction(operators.Equals,
decls.NewParameterizedOverload(overloads.Equals,
[]*exprpb.Type{paramA, paramA}, decls.Bool,
@ -227,8 +156,6 @@ func StandardDeclarations() []*exprpb.Decl {
decls.NewParameterizedOverload(overloads.IndexMap,
[]*exprpb.Type{mapOfAB, paramA}, paramB,
typeParamABList)),
//decls.NewOverload(overloads.IndexMessage,
// []*expr.Type{decls.Dyn, decls.String}, decls.Dyn)),
// Collections.
@ -267,8 +194,6 @@ func StandardDeclarations() []*exprpb.Decl {
decls.NewParameterizedOverload(overloads.InMap,
[]*exprpb.Type{paramA, mapOfAB}, decls.Bool,
typeParamABList)),
//decls.NewOverload(overloads.InMessage,
// []*expr.Type{Dyn, decls.String},decls.Bool)),
// Conversions to type.
@ -436,5 +361,132 @@ func StandardDeclarations() []*exprpb.Decl {
decls.NewInstanceOverload(overloads.TimestampToMillisecondsWithTz,
[]*exprpb.Type{decls.Timestamp, decls.String}, decls.Int),
decls.NewInstanceOverload(overloads.DurationToMilliseconds,
[]*exprpb.Type{decls.Duration}, decls.Int))}...)
[]*exprpb.Type{decls.Duration}, decls.Int)),
// Relations.
decls.NewFunction(operators.Less,
decls.NewOverload(overloads.LessBool,
[]*exprpb.Type{decls.Bool, decls.Bool}, decls.Bool),
decls.NewOverload(overloads.LessInt64,
[]*exprpb.Type{decls.Int, decls.Int}, decls.Bool),
decls.NewOverload(overloads.LessInt64Double,
[]*exprpb.Type{decls.Int, decls.Double}, decls.Bool),
decls.NewOverload(overloads.LessInt64Uint64,
[]*exprpb.Type{decls.Int, decls.Uint}, decls.Bool),
decls.NewOverload(overloads.LessUint64,
[]*exprpb.Type{decls.Uint, decls.Uint}, decls.Bool),
decls.NewOverload(overloads.LessUint64Double,
[]*exprpb.Type{decls.Uint, decls.Double}, decls.Bool),
decls.NewOverload(overloads.LessUint64Int64,
[]*exprpb.Type{decls.Uint, decls.Int}, decls.Bool),
decls.NewOverload(overloads.LessDouble,
[]*exprpb.Type{decls.Double, decls.Double}, decls.Bool),
decls.NewOverload(overloads.LessDoubleInt64,
[]*exprpb.Type{decls.Double, decls.Int}, decls.Bool),
decls.NewOverload(overloads.LessDoubleUint64,
[]*exprpb.Type{decls.Double, decls.Uint}, decls.Bool),
decls.NewOverload(overloads.LessString,
[]*exprpb.Type{decls.String, decls.String}, decls.Bool),
decls.NewOverload(overloads.LessBytes,
[]*exprpb.Type{decls.Bytes, decls.Bytes}, decls.Bool),
decls.NewOverload(overloads.LessTimestamp,
[]*exprpb.Type{decls.Timestamp, decls.Timestamp}, decls.Bool),
decls.NewOverload(overloads.LessDuration,
[]*exprpb.Type{decls.Duration, decls.Duration}, decls.Bool)),
decls.NewFunction(operators.LessEquals,
decls.NewOverload(overloads.LessEqualsBool,
[]*exprpb.Type{decls.Bool, decls.Bool}, decls.Bool),
decls.NewOverload(overloads.LessEqualsInt64,
[]*exprpb.Type{decls.Int, decls.Int}, decls.Bool),
decls.NewOverload(overloads.LessEqualsInt64Double,
[]*exprpb.Type{decls.Int, decls.Double}, decls.Bool),
decls.NewOverload(overloads.LessEqualsInt64Uint64,
[]*exprpb.Type{decls.Int, decls.Uint}, decls.Bool),
decls.NewOverload(overloads.LessEqualsUint64,
[]*exprpb.Type{decls.Uint, decls.Uint}, decls.Bool),
decls.NewOverload(overloads.LessEqualsUint64Double,
[]*exprpb.Type{decls.Uint, decls.Double}, decls.Bool),
decls.NewOverload(overloads.LessEqualsUint64Int64,
[]*exprpb.Type{decls.Uint, decls.Int}, decls.Bool),
decls.NewOverload(overloads.LessEqualsDouble,
[]*exprpb.Type{decls.Double, decls.Double}, decls.Bool),
decls.NewOverload(overloads.LessEqualsDoubleInt64,
[]*exprpb.Type{decls.Double, decls.Int}, decls.Bool),
decls.NewOverload(overloads.LessEqualsDoubleUint64,
[]*exprpb.Type{decls.Double, decls.Uint}, decls.Bool),
decls.NewOverload(overloads.LessEqualsString,
[]*exprpb.Type{decls.String, decls.String}, decls.Bool),
decls.NewOverload(overloads.LessEqualsBytes,
[]*exprpb.Type{decls.Bytes, decls.Bytes}, decls.Bool),
decls.NewOverload(overloads.LessEqualsTimestamp,
[]*exprpb.Type{decls.Timestamp, decls.Timestamp}, decls.Bool),
decls.NewOverload(overloads.LessEqualsDuration,
[]*exprpb.Type{decls.Duration, decls.Duration}, decls.Bool)),
decls.NewFunction(operators.Greater,
decls.NewOverload(overloads.GreaterBool,
[]*exprpb.Type{decls.Bool, decls.Bool}, decls.Bool),
decls.NewOverload(overloads.GreaterInt64,
[]*exprpb.Type{decls.Int, decls.Int}, decls.Bool),
decls.NewOverload(overloads.GreaterInt64Double,
[]*exprpb.Type{decls.Int, decls.Double}, decls.Bool),
decls.NewOverload(overloads.GreaterInt64Uint64,
[]*exprpb.Type{decls.Int, decls.Uint}, decls.Bool),
decls.NewOverload(overloads.GreaterUint64,
[]*exprpb.Type{decls.Uint, decls.Uint}, decls.Bool),
decls.NewOverload(overloads.GreaterUint64Double,
[]*exprpb.Type{decls.Uint, decls.Double}, decls.Bool),
decls.NewOverload(overloads.GreaterUint64Int64,
[]*exprpb.Type{decls.Uint, decls.Int}, decls.Bool),
decls.NewOverload(overloads.GreaterDouble,
[]*exprpb.Type{decls.Double, decls.Double}, decls.Bool),
decls.NewOverload(overloads.GreaterDoubleInt64,
[]*exprpb.Type{decls.Double, decls.Int}, decls.Bool),
decls.NewOverload(overloads.GreaterDoubleUint64,
[]*exprpb.Type{decls.Double, decls.Uint}, decls.Bool),
decls.NewOverload(overloads.GreaterString,
[]*exprpb.Type{decls.String, decls.String}, decls.Bool),
decls.NewOverload(overloads.GreaterBytes,
[]*exprpb.Type{decls.Bytes, decls.Bytes}, decls.Bool),
decls.NewOverload(overloads.GreaterTimestamp,
[]*exprpb.Type{decls.Timestamp, decls.Timestamp}, decls.Bool),
decls.NewOverload(overloads.GreaterDuration,
[]*exprpb.Type{decls.Duration, decls.Duration}, decls.Bool)),
decls.NewFunction(operators.GreaterEquals,
decls.NewOverload(overloads.GreaterEqualsBool,
[]*exprpb.Type{decls.Bool, decls.Bool}, decls.Bool),
decls.NewOverload(overloads.GreaterEqualsInt64,
[]*exprpb.Type{decls.Int, decls.Int}, decls.Bool),
decls.NewOverload(overloads.GreaterEqualsInt64Double,
[]*exprpb.Type{decls.Int, decls.Double}, decls.Bool),
decls.NewOverload(overloads.GreaterEqualsInt64Uint64,
[]*exprpb.Type{decls.Int, decls.Uint}, decls.Bool),
decls.NewOverload(overloads.GreaterEqualsUint64,
[]*exprpb.Type{decls.Uint, decls.Uint}, decls.Bool),
decls.NewOverload(overloads.GreaterEqualsUint64Double,
[]*exprpb.Type{decls.Uint, decls.Double}, decls.Bool),
decls.NewOverload(overloads.GreaterEqualsUint64Int64,
[]*exprpb.Type{decls.Uint, decls.Int}, decls.Bool),
decls.NewOverload(overloads.GreaterEqualsDouble,
[]*exprpb.Type{decls.Double, decls.Double}, decls.Bool),
decls.NewOverload(overloads.GreaterEqualsDoubleInt64,
[]*exprpb.Type{decls.Double, decls.Int}, decls.Bool),
decls.NewOverload(overloads.GreaterEqualsDoubleUint64,
[]*exprpb.Type{decls.Double, decls.Uint}, decls.Bool),
decls.NewOverload(overloads.GreaterEqualsString,
[]*exprpb.Type{decls.String, decls.String}, decls.Bool),
decls.NewOverload(overloads.GreaterEqualsBytes,
[]*exprpb.Type{decls.Bytes, decls.Bytes}, decls.Bool),
decls.NewOverload(overloads.GreaterEqualsTimestamp,
[]*exprpb.Type{decls.Timestamp, decls.Timestamp}, decls.Bool),
decls.NewOverload(overloads.GreaterEqualsDuration,
[]*exprpb.Type{decls.Duration, decls.Duration}, decls.Bool)),
}...)
}
// StandardDeclarations returns the Decls for all functions and constants in the evaluator.
func StandardDeclarations() []*exprpb.Decl {
return standardDeclarations
}

View File

@ -8,6 +8,7 @@ package(
go_library(
name = "go_default_library",
srcs = [
"cost.go",
"error.go",
"errors.go",
"location.go",

40
vendor/github.com/google/cel-go/common/cost.go generated vendored Normal file
View File

@ -0,0 +1,40 @@
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package common
const (
// SelectAndIdentCost is the cost of an operation that accesses an identifier or performs a select.
SelectAndIdentCost = 1
// ConstCost is the cost of an operation that accesses a constant.
ConstCost = 0
// ListCreateBaseCost is the base cost of any operation that creates a new list.
ListCreateBaseCost = 10
// MapCreateBaseCost is the base cost of any operation that creates a new map.
MapCreateBaseCost = 30
// StructCreateBaseCost is the base cost of any operation that creates a new struct.
StructCreateBaseCost = 40
// StringTraversalCostFactor is multiplied to a length of a string when computing the cost of traversing the entire
// string once.
StringTraversalCostFactor = 0.1
// RegexStringLengthCostFactor is multiplied ot the length of a regex string pattern when computing the cost of
// applying the regex to a string of unit cost.
RegexStringLengthCostFactor = 0.25
)

View File

@ -70,46 +70,35 @@ var (
"!=": NotEquals,
"-": Subtract,
}
reverseOperators = map[string]string{
Add: "+",
Divide: "/",
Equals: "==",
Greater: ">",
GreaterEquals: ">=",
In: "in",
Less: "<",
LessEquals: "<=",
LogicalAnd: "&&",
LogicalNot: "!",
LogicalOr: "||",
Modulo: "%",
Multiply: "*",
Negate: "-",
NotEquals: "!=",
OldIn: "in",
Subtract: "-",
}
// precedence of the operator, where the higher value means higher.
precedence = map[string]int{
Conditional: 8,
LogicalOr: 7,
LogicalAnd: 6,
Equals: 5,
Greater: 5,
GreaterEquals: 5,
In: 5,
Less: 5,
LessEquals: 5,
NotEquals: 5,
OldIn: 5,
Add: 4,
Subtract: 4,
Divide: 3,
Modulo: 3,
Multiply: 3,
LogicalNot: 2,
Negate: 2,
Index: 1,
// operatorMap of the operator symbol which refers to a struct containing the display name,
// if applicable, the operator precedence, and the arity.
//
// If the symbol does not have a display name listed in the map, it is only because it requires
// special casing to render properly as text.
operatorMap = map[string]struct {
displayName string
precedence int
arity int
}{
Conditional: {displayName: "", precedence: 8, arity: 3},
LogicalOr: {displayName: "||", precedence: 7, arity: 2},
LogicalAnd: {displayName: "&&", precedence: 6, arity: 2},
Equals: {displayName: "==", precedence: 5, arity: 2},
Greater: {displayName: ">", precedence: 5, arity: 2},
GreaterEquals: {displayName: ">=", precedence: 5, arity: 2},
In: {displayName: "in", precedence: 5, arity: 2},
Less: {displayName: "<", precedence: 5, arity: 2},
LessEquals: {displayName: "<=", precedence: 5, arity: 2},
NotEquals: {displayName: "!=", precedence: 5, arity: 2},
OldIn: {displayName: "in", precedence: 5, arity: 2},
Add: {displayName: "+", precedence: 4, arity: 2},
Subtract: {displayName: "-", precedence: 4, arity: 2},
Divide: {displayName: "/", precedence: 3, arity: 2},
Modulo: {displayName: "%", precedence: 3, arity: 2},
Multiply: {displayName: "*", precedence: 3, arity: 2},
LogicalNot: {displayName: "!", precedence: 2, arity: 1},
Negate: {displayName: "-", precedence: 2, arity: 1},
Index: {displayName: "", precedence: 1, arity: 2},
}
)
@ -120,26 +109,35 @@ func Find(text string) (string, bool) {
}
// FindReverse returns the unmangled, text representation of the operator.
func FindReverse(op string) (string, bool) {
txt, found := reverseOperators[op]
return txt, found
func FindReverse(symbol string) (string, bool) {
op, found := operatorMap[symbol]
if !found {
return "", false
}
return op.displayName, true
}
// FindReverseBinaryOperator returns the unmangled, text representation of a binary operator.
func FindReverseBinaryOperator(op string) (string, bool) {
if op == LogicalNot || op == Negate {
//
// If the symbol does refer to an operator, but the operator does not have a display name the
// result is false.
func FindReverseBinaryOperator(symbol string) (string, bool) {
op, found := operatorMap[symbol]
if !found || op.arity != 2 {
return "", false
}
txt, found := reverseOperators[op]
return txt, found
if op.displayName == "" {
return "", false
}
return op.displayName, true
}
// Precedence returns the operator precedence, where the higher the number indicates
// higher precedence operations.
func Precedence(op string) int {
p, found := precedence[op]
if found {
return p
func Precedence(symbol string) int {
op, found := operatorMap[symbol]
if !found {
return 0
}
return 0
return op.precedence
}

View File

@ -18,45 +18,69 @@ package overloads
// Boolean logic overloads
const (
Conditional = "conditional"
LogicalAnd = "logical_and"
LogicalOr = "logical_or"
LogicalNot = "logical_not"
NotStrictlyFalse = "not_strictly_false"
Equals = "equals"
NotEquals = "not_equals"
LessBool = "less_bool"
LessInt64 = "less_int64"
LessUint64 = "less_uint64"
LessDouble = "less_double"
LessString = "less_string"
LessBytes = "less_bytes"
LessTimestamp = "less_timestamp"
LessDuration = "less_duration"
LessEqualsBool = "less_equals_bool"
LessEqualsInt64 = "less_equals_int64"
LessEqualsUint64 = "less_equals_uint64"
LessEqualsDouble = "less_equals_double"
LessEqualsString = "less_equals_string"
LessEqualsBytes = "less_equals_bytes"
LessEqualsTimestamp = "less_equals_timestamp"
LessEqualsDuration = "less_equals_duration"
GreaterBool = "greater_bool"
GreaterInt64 = "greater_int64"
GreaterUint64 = "greater_uint64"
GreaterDouble = "greater_double"
GreaterString = "greater_string"
GreaterBytes = "greater_bytes"
GreaterTimestamp = "greater_timestamp"
GreaterDuration = "greater_duration"
GreaterEqualsBool = "greater_equals_bool"
GreaterEqualsInt64 = "greater_equals_int64"
GreaterEqualsUint64 = "greater_equals_uint64"
GreaterEqualsDouble = "greater_equals_double"
GreaterEqualsString = "greater_equals_string"
GreaterEqualsBytes = "greater_equals_bytes"
GreaterEqualsTimestamp = "greater_equals_timestamp"
GreaterEqualsDuration = "greater_equals_duration"
Conditional = "conditional"
LogicalAnd = "logical_and"
LogicalOr = "logical_or"
LogicalNot = "logical_not"
NotStrictlyFalse = "not_strictly_false"
Equals = "equals"
NotEquals = "not_equals"
LessBool = "less_bool"
LessInt64 = "less_int64"
LessInt64Double = "less_int64_double"
LessInt64Uint64 = "less_int64_uint64"
LessUint64 = "less_uint64"
LessUint64Double = "less_uint64_double"
LessUint64Int64 = "less_uint64_int64"
LessDouble = "less_double"
LessDoubleInt64 = "less_double_int64"
LessDoubleUint64 = "less_double_uint64"
LessString = "less_string"
LessBytes = "less_bytes"
LessTimestamp = "less_timestamp"
LessDuration = "less_duration"
LessEqualsBool = "less_equals_bool"
LessEqualsInt64 = "less_equals_int64"
LessEqualsInt64Double = "less_equals_int64_double"
LessEqualsInt64Uint64 = "less_equals_int64_uint64"
LessEqualsUint64 = "less_equals_uint64"
LessEqualsUint64Double = "less_equals_uint64_double"
LessEqualsUint64Int64 = "less_equals_uint64_int64"
LessEqualsDouble = "less_equals_double"
LessEqualsDoubleInt64 = "less_equals_double_int64"
LessEqualsDoubleUint64 = "less_equals_double_uint64"
LessEqualsString = "less_equals_string"
LessEqualsBytes = "less_equals_bytes"
LessEqualsTimestamp = "less_equals_timestamp"
LessEqualsDuration = "less_equals_duration"
GreaterBool = "greater_bool"
GreaterInt64 = "greater_int64"
GreaterInt64Double = "greater_int64_double"
GreaterInt64Uint64 = "greater_int64_uint64"
GreaterUint64 = "greater_uint64"
GreaterUint64Double = "greater_uint64_double"
GreaterUint64Int64 = "greater_uint64_int64"
GreaterDouble = "greater_double"
GreaterDoubleInt64 = "greater_double_int64"
GreaterDoubleUint64 = "greater_double_uint64"
GreaterString = "greater_string"
GreaterBytes = "greater_bytes"
GreaterTimestamp = "greater_timestamp"
GreaterDuration = "greater_duration"
GreaterEqualsBool = "greater_equals_bool"
GreaterEqualsInt64 = "greater_equals_int64"
GreaterEqualsInt64Double = "greater_equals_int64_double"
GreaterEqualsInt64Uint64 = "greater_equals_int64_uint64"
GreaterEqualsUint64 = "greater_equals_uint64"
GreaterEqualsUint64Double = "greater_equals_uint64_double"
GreaterEqualsUint64Int64 = "greater_equals_uint64_int64"
GreaterEqualsDouble = "greater_equals_double"
GreaterEqualsDoubleInt64 = "greater_equals_double_int64"
GreaterEqualsDoubleUint64 = "greater_equals_double_uint64"
GreaterEqualsString = "greater_equals_string"
GreaterEqualsBytes = "greater_equals_bytes"
GreaterEqualsTimestamp = "greater_equals_timestamp"
GreaterEqualsDuration = "greater_equals_duration"
)
// Math overloads

View File

@ -11,6 +11,7 @@ go_library(
"any_value.go",
"bool.go",
"bytes.go",
"compare.go",
"double.go",
"duration.go",
"err.go",
@ -38,6 +39,9 @@ go_library(
"//common/types/traits:go_default_library",
"@com_github_stoewer_go_strcase//:go_default_library",
"@org_golang_google_genproto//googleapis/api/expr/v1alpha1:go_default_library",
"@org_golang_google_genproto//googleapis/rpc/status:go_default_library",
"@org_golang_google_grpc//codes:go_default_library",
"@org_golang_google_grpc//status:go_default_library",
"@org_golang_google_protobuf//encoding/protojson:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
"@org_golang_google_protobuf//reflect/protoreflect:go_default_library",

View File

@ -111,10 +111,7 @@ func (b Bool) ConvertToType(typeVal ref.Type) ref.Val {
// Equal implements the ref.Val interface method.
func (b Bool) Equal(other ref.Val) ref.Val {
otherBool, ok := other.(Bool)
if !ok {
return ValOrErr(other, "no such overload")
}
return Bool(b == otherBool)
return Bool(ok && b == otherBool)
}
// Negate implements the traits.Negater interface method.

View File

@ -113,10 +113,7 @@ func (b Bytes) ConvertToType(typeVal ref.Type) ref.Val {
// Equal implements the ref.Val interface method.
func (b Bytes) Equal(other ref.Val) ref.Val {
otherBytes, ok := other.(Bytes)
if !ok {
return ValOrErr(other, "no such overload")
}
return Bool(bytes.Equal(b, otherBytes))
return Bool(ok && bytes.Equal(b, otherBytes))
}
// Size implements the traits.Sizer interface method.

View File

@ -0,0 +1,95 @@
// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package types
import (
"math"
)
func compareDoubleInt(d Double, i Int) Int {
if d < math.MinInt64 {
return IntNegOne
}
if d > math.MaxInt64 {
return IntOne
}
return compareDouble(d, Double(i))
}
func compareIntDouble(i Int, d Double) Int {
return -compareDoubleInt(d, i)
}
func compareDoubleUint(d Double, u Uint) Int {
if d < 0 {
return IntNegOne
}
if d > math.MaxUint64 {
return IntOne
}
return compareDouble(d, Double(u))
}
func compareUintDouble(u Uint, d Double) Int {
return -compareDoubleUint(d, u)
}
func compareIntUint(i Int, u Uint) Int {
if i < 0 || u > math.MaxInt64 {
return IntNegOne
}
cmp := i - Int(u)
if cmp < 0 {
return IntNegOne
}
if cmp > 0 {
return IntOne
}
return IntZero
}
func compareUintInt(u Uint, i Int) Int {
return -compareIntUint(i, u)
}
func compareDouble(a, b Double) Int {
if a < b {
return IntNegOne
}
if a > b {
return IntOne
}
return IntZero
}
func compareInt(a, b Int) Int {
if a < b {
return IntNegOne
}
if a > b {
return IntOne
}
return IntZero
}
func compareUint(a, b Uint) Int {
if a < b {
return IntNegOne
}
if a > b {
return IntOne
}
return IntZero
}

View File

@ -16,6 +16,7 @@ package types
import (
"fmt"
"math"
"reflect"
"github.com/google/cel-go/common/types/ref"
@ -58,17 +59,22 @@ func (d Double) Add(other ref.Val) ref.Val {
// Compare implements traits.Comparer.Compare.
func (d Double) Compare(other ref.Val) ref.Val {
otherDouble, ok := other.(Double)
if !ok {
if math.IsNaN(float64(d)) {
return NewErr("NaN values cannot be ordered")
}
switch ov := other.(type) {
case Double:
if math.IsNaN(float64(ov)) {
return NewErr("NaN values cannot be ordered")
}
return compareDouble(d, ov)
case Int:
return compareDoubleInt(d, ov)
case Uint:
return compareDoubleUint(d, ov)
default:
return MaybeNoSuchOverloadErr(other)
}
if d < otherDouble {
return IntNegOne
}
if d > otherDouble {
return IntOne
}
return IntZero
}
// ConvertToNative implements ref.Val.ConvertToNative.
@ -158,12 +164,22 @@ func (d Double) Divide(other ref.Val) ref.Val {
// Equal implements ref.Val.Equal.
func (d Double) Equal(other ref.Val) ref.Val {
otherDouble, ok := other.(Double)
if !ok {
return MaybeNoSuchOverloadErr(other)
if math.IsNaN(float64(d)) {
return False
}
switch ov := other.(type) {
case Double:
if math.IsNaN(float64(ov)) {
return False
}
return Bool(d == ov)
case Int:
return Bool(compareDoubleInt(d, ov) == 0)
case Uint:
return Bool(compareDoubleUint(d, ov) == 0)
default:
return False
}
// TODO: Handle NaNs properly.
return Bool(d == otherDouble)
}
// Multiply implements traits.Multiplier.Multiply.

View File

@ -135,10 +135,7 @@ func (d Duration) ConvertToType(typeVal ref.Type) ref.Val {
// Equal implements ref.Val.Equal.
func (d Duration) Equal(other ref.Val) ref.Val {
otherDur, ok := other.(Duration)
if !ok {
return MaybeNoSuchOverloadErr(other)
}
return Bool(d.Duration == otherDur.Duration)
return Bool(ok && d.Duration == otherDur.Duration)
}
// Negate implements traits.Negater.Negate.

View File

@ -16,6 +16,7 @@ package types
import (
"fmt"
"math"
"reflect"
"strconv"
"time"
@ -72,17 +73,19 @@ func (i Int) Add(other ref.Val) ref.Val {
// Compare implements traits.Comparer.Compare.
func (i Int) Compare(other ref.Val) ref.Val {
otherInt, ok := other.(Int)
if !ok {
switch ov := other.(type) {
case Double:
if math.IsNaN(float64(ov)) {
return NewErr("NaN values cannot be ordered")
}
return compareIntDouble(i, ov)
case Int:
return compareInt(i, ov)
case Uint:
return compareIntUint(i, ov)
default:
return MaybeNoSuchOverloadErr(other)
}
if i < otherInt {
return IntNegOne
}
if i > otherInt {
return IntOne
}
return IntZero
}
// ConvertToNative implements ref.Val.ConvertToNative.
@ -208,11 +211,19 @@ func (i Int) Divide(other ref.Val) ref.Val {
// Equal implements ref.Val.Equal.
func (i Int) Equal(other ref.Val) ref.Val {
otherInt, ok := other.(Int)
if !ok {
return MaybeNoSuchOverloadErr(other)
switch ov := other.(type) {
case Double:
if math.IsNaN(float64(ov)) {
return False
}
return Bool(compareIntDouble(i, ov) == 0)
case Int:
return Bool(i == ov)
case Uint:
return Bool(compareIntUint(i, ov) == 0)
default:
return False
}
return Bool(i == otherInt)
}
// Modulo implements traits.Modder.Modulo.

View File

@ -95,6 +95,18 @@ func NewJSONList(adapter ref.TypeAdapter, l *structpb.ListValue) traits.Lister {
}
}
// NewMutableList creates a new mutable list whose internal state can be modified.
//
// The mutable list only handles `Add` calls correctly as it is intended only for use within
// comprehension loops which generate an immutable result upon completion.
func NewMutableList(adapter ref.TypeAdapter) traits.Lister {
return &mutableList{
TypeAdapter: adapter,
baseList: nil,
mutableValues: []ref.Val{},
}
}
// baseList points to a list containing elements of any type.
// The `value` is an array of native values, and refValue is its reflection object.
// The `ref.TypeAdapter` enables native type to CEL type conversions.
@ -131,28 +143,14 @@ func (l *baseList) Add(other ref.Val) ref.Val {
// Contains implements the traits.Container interface method.
func (l *baseList) Contains(elem ref.Val) ref.Val {
if IsUnknownOrError(elem) {
return elem
}
var err ref.Val
for i := 0; i < l.size; i++ {
val := l.NativeToValue(l.get(i))
cmp := elem.Equal(val)
b, ok := cmp.(Bool)
// When there is an error on the contain check, this is not necessarily terminal.
// The contains call could find the element and return True, just as though the user
// had written a per-element comparison in an exists() macro or logical ||, e.g.
// list.exists(e, e == elem)
if !ok && err == nil {
err = ValOrErr(cmp, "no such overload")
}
if b == True {
if ok && b == True {
return True
}
}
if err != nil {
return err
}
return False
}
@ -222,25 +220,18 @@ func (l *baseList) ConvertToType(typeVal ref.Type) ref.Val {
func (l *baseList) Equal(other ref.Val) ref.Val {
otherList, ok := other.(traits.Lister)
if !ok {
return MaybeNoSuchOverloadErr(other)
return False
}
if l.Size() != otherList.Size() {
return False
}
var maybeErr ref.Val
for i := IntZero; i < l.Size().(Int); i++ {
thisElem := l.Get(i)
otherElem := otherList.Get(i)
elemEq := thisElem.Equal(otherElem)
elemEq := Equal(thisElem, otherElem)
if elemEq == False {
return False
}
if maybeErr == nil && IsUnknownOrError(elemEq) {
maybeErr = elemEq
}
}
if maybeErr != nil {
return maybeErr
}
return True
}
@ -279,6 +270,32 @@ func (l *baseList) Value() interface{} {
return l.value
}
// mutableList aggregates values into its internal storage. For use with internal CEL variables only.
type mutableList struct {
ref.TypeAdapter
*baseList
mutableValues []ref.Val
}
// Add copies elements from the other list into the internal storage of the mutable list.
func (l *mutableList) Add(other ref.Val) ref.Val {
otherList, ok := other.(traits.Lister)
if !ok {
return MaybeNoSuchOverloadErr(otherList)
}
for i := IntZero; i < otherList.Size().(Int); i++ {
l.mutableValues = append(l.mutableValues, otherList.Get(i))
}
return l
}
// ToImmutableList returns an immutable list based on the internal storage of the mutable list.
func (l *mutableList) ToImmutableList() traits.Lister {
// The reference to internal state is guaranteed to be safe as this call is only performed
// when mutations have been completed.
return NewRefValList(l.TypeAdapter, l.mutableValues)
}
// concatList combines two list implementations together into a view.
// The `ref.TypeAdapter` enables native type to CEL type conversions.
type concatList struct {
@ -349,7 +366,7 @@ func (l *concatList) ConvertToType(typeVal ref.Type) ref.Val {
func (l *concatList) Equal(other ref.Val) ref.Val {
otherList, ok := other.(traits.Lister)
if !ok {
return MaybeNoSuchOverloadErr(other)
return False
}
if l.Size() != otherList.Size() {
return False
@ -358,7 +375,7 @@ func (l *concatList) Equal(other ref.Val) ref.Val {
for i := IntZero; i < l.Size().(Int); i++ {
thisElem := l.Get(i)
otherElem := otherList.Get(i)
elemEq := thisElem.Equal(otherElem)
elemEq := Equal(thisElem, otherElem)
if elemEq == False {
return False
}

View File

@ -108,8 +108,6 @@ type mapAccessor interface {
// Find returns a value, if one exists, for the inpput key.
//
// If the key is not found the function returns (nil, false).
// If the input key is not valid for the map, or is Err or Unknown the function returns
// (Unknown|Err, false).
Find(ref.Val) (ref.Val, bool)
// Iterator returns an Iterator over the map key set.
@ -135,11 +133,7 @@ type baseMap struct {
// Contains implements the traits.Container interface method.
func (m *baseMap) Contains(index ref.Val) ref.Val {
val, found := m.Find(index)
// When the index is not found and val is non-nil, this is an error or unknown value.
if !found && val != nil && IsUnknownOrError(val) {
return val
}
_, found := m.Find(index)
return Bool(found)
}
@ -251,36 +245,23 @@ func (m *baseMap) ConvertToType(typeVal ref.Type) ref.Val {
func (m *baseMap) Equal(other ref.Val) ref.Val {
otherMap, ok := other.(traits.Mapper)
if !ok {
return MaybeNoSuchOverloadErr(other)
return False
}
if m.Size() != otherMap.Size() {
return False
}
it := m.Iterator()
var maybeErr ref.Val
for it.HasNext() == True {
key := it.Next()
thisVal, _ := m.Find(key)
otherVal, found := otherMap.Find(key)
if !found {
if otherVal == nil {
return False
}
if maybeErr == nil {
maybeErr = MaybeNoSuchOverloadErr(otherVal)
}
continue
return False
}
valEq := thisVal.Equal(otherVal)
valEq := Equal(thisVal, otherVal)
if valEq == False {
return False
}
if maybeErr == nil && IsUnknownOrError(valEq) {
maybeErr = valEq
}
}
if maybeErr != nil {
return maybeErr
}
return True
}
@ -325,12 +306,10 @@ type jsonStructAccessor struct {
// found.
//
// If the key is not found the function returns (nil, false).
// If the input key is not a String, or is an Err or Unknown, the function returns
// (Unknown|Err, false).
func (a *jsonStructAccessor) Find(key ref.Val) (ref.Val, bool) {
strKey, ok := key.(String)
if !ok {
return ValOrErr(key, "unsupported key type: %v", key.Type()), false
return nil, false
}
keyVal, found := a.st[string(strKey)]
if !found {
@ -373,39 +352,58 @@ type reflectMapAccessor struct {
// returning (value, true) if present.
//
// If the key is not found the function returns (nil, false).
// If the input key is not a String, or is an Err or Unknown, the function returns
// (Unknown|Err, false).
func (a *reflectMapAccessor) Find(key ref.Val) (ref.Val, bool) {
if IsUnknownOrError(key) {
return MaybeNoSuchOverloadErr(key), false
}
if a.refValue.Len() == 0 {
func (m *reflectMapAccessor) Find(key ref.Val) (ref.Val, bool) {
if m.refValue.Len() == 0 {
return nil, false
}
k, err := key.ConvertToNative(a.keyType)
if err != nil {
return NewErr("unsupported key type: %v", key.Type()), false
if keyVal, found := m.findInternal(key); found {
return keyVal, true
}
refKey := reflect.ValueOf(k)
val := a.refValue.MapIndex(refKey)
if val.IsValid() {
return a.NativeToValue(val.Interface()), true
}
mapIt := a.refValue.MapRange()
for mapIt.Next() {
if refKey.Kind() == mapIt.Key().Kind() {
return nil, false
switch k := key.(type) {
// Double is not a valid proto map key type, so check for the key as an int or uint.
case Double:
if ik, ok := doubleToInt64Lossless(float64(k)); ok {
if keyVal, found := m.findInternal(Int(ik)); found {
return keyVal, true
}
}
if uk, ok := doubleToUint64Lossless(float64(k)); ok {
return m.findInternal(Uint(uk))
}
// map keys of type double are not supported.
case Int:
if uk, ok := int64ToUint64Lossless(int64(k)); ok {
return m.findInternal(Uint(uk))
}
case Uint:
if ik, ok := uint64ToInt64Lossless(uint64(k)); ok {
return m.findInternal(Int(ik))
}
}
return NewErr("unsupported key type: %v", key.Type()), false
return nil, false
}
// findInternal attempts to convert the incoming key to the map's internal native type
// and then returns the value, if found.
func (m *reflectMapAccessor) findInternal(key ref.Val) (ref.Val, bool) {
k, err := key.ConvertToNative(m.keyType)
if err != nil {
return nil, false
}
refKey := reflect.ValueOf(k)
val := m.refValue.MapIndex(refKey)
if val.IsValid() {
return m.NativeToValue(val.Interface()), true
}
return nil, false
}
// Iterator creates a Golang reflection based traits.Iterator.
func (a *reflectMapAccessor) Iterator() traits.Iterator {
func (m *reflectMapAccessor) Iterator() traits.Iterator {
return &mapIterator{
TypeAdapter: a.TypeAdapter,
mapKeys: a.refValue.MapRange(),
len: a.refValue.Len(),
TypeAdapter: m.TypeAdapter,
mapKeys: m.refValue.MapRange(),
len: m.refValue.Len(),
}
}
@ -420,24 +418,37 @@ type refValMapAccessor struct {
// Find uses native map accesses to find the key, returning (value, true) if present.
//
// If the key is not found the function returns (nil, false).
// If the input key is an Err or Unknown, the function returns (Unknown|Err, false).
func (a *refValMapAccessor) Find(key ref.Val) (ref.Val, bool) {
if IsUnknownOrError(key) {
return key, false
}
if len(a.mapVal) == 0 {
return nil, false
}
keyVal, found := a.mapVal[key]
if found {
if keyVal, found := a.mapVal[key]; found {
return keyVal, true
}
for k := range a.mapVal {
if k.Type().TypeName() == key.Type().TypeName() {
return nil, false
switch k := key.(type) {
case Double:
if ik, ok := doubleToInt64Lossless(float64(k)); ok {
if keyVal, found := a.mapVal[Int(ik)]; found {
return keyVal, true
}
}
if uk, ok := doubleToUint64Lossless(float64(k)); ok {
keyVal, found := a.mapVal[Uint(uk)]
return keyVal, found
}
// map keys of type double are not supported.
case Int:
if uk, ok := int64ToUint64Lossless(int64(k)); ok {
keyVal, found := a.mapVal[Uint(uk)]
return keyVal, found
}
case Uint:
if ik, ok := uint64ToInt64Lossless(uint64(k)); ok {
keyVal, found := a.mapVal[Int(ik)]
return keyVal, found
}
}
return NewErr("unsupported key type: %v", key.Type()), found
return nil, false
}
// Iterator produces a new traits.Iterator which iterates over the map keys via Golang reflection.
@ -460,12 +471,10 @@ type stringMapAccessor struct {
// Find uses native map accesses to find the key, returning (value, true) if present.
//
// If the key is not found the function returns (nil, false).
// If the input key is not a String, or is an Err or Unknown, the function returns
// (Unknown|Err, false).
func (a *stringMapAccessor) Find(key ref.Val) (ref.Val, bool) {
strKey, ok := key.(String)
if !ok {
return ValOrErr(key, "unsupported key type: %v", key.Type()), false
return nil, false
}
keyVal, found := a.mapVal[string(strKey)]
if !found {
@ -504,12 +513,10 @@ type stringIfaceMapAccessor struct {
// Find uses native map accesses to find the key, returning (value, true) if present.
//
// If the key is not found the function returns (nil, false).
// If the input key is not a String, or is an Err or Unknown, the function returns
// (Unknown|Err, false).
func (a *stringIfaceMapAccessor) Find(key ref.Val) (ref.Val, bool) {
strKey, ok := key.(String)
if !ok {
return ValOrErr(key, "unsupported key type: %v", key.Type()), false
return nil, false
}
keyVal, found := a.mapVal[string(strKey)]
if !found {
@ -542,11 +549,7 @@ type protoMap struct {
// Contains returns whether the map contains the given key.
func (m *protoMap) Contains(key ref.Val) ref.Val {
val, found := m.Find(key)
// When the index is not found and val is non-nil, this is an error or unknown value.
if !found && val != nil && IsUnknownOrError(val) {
return val
}
_, found := m.Find(key)
return Bool(found)
}
@ -642,7 +645,7 @@ func (m *protoMap) ConvertToType(typeVal ref.Type) ref.Val {
func (m *protoMap) Equal(other ref.Val) ref.Val {
otherMap, ok := other.(traits.Mapper)
if !ok {
return MaybeNoSuchOverloadErr(other)
return False
}
if m.value.Map.Len() != int(otherMap.Size().(Int)) {
return False
@ -653,14 +656,10 @@ func (m *protoMap) Equal(other ref.Val) ref.Val {
valVal := m.NativeToValue(val)
otherVal, found := otherMap.Find(keyVal)
if !found {
if otherVal == nil {
retVal = False
return false
}
retVal = MaybeNoSuchOverloadErr(otherVal)
retVal = False
return false
}
valEq := valVal.Equal(otherVal)
valEq := Equal(valVal, otherVal)
if valEq != True {
retVal = valEq
return false
@ -673,17 +672,41 @@ func (m *protoMap) Equal(other ref.Val) ref.Val {
// Find returns whether the protoreflect.Map contains the input key.
//
// If the key is not found the function returns (nil, false).
// If the input key is not a supported proto map key type, or is an Err or Unknown,
// the function returns
// (Unknown|Err, false).
func (m *protoMap) Find(key ref.Val) (ref.Val, bool) {
if IsUnknownOrError(key) {
return key, false
if keyVal, found := m.findInternal(key); found {
return keyVal, true
}
switch k := key.(type) {
// Double is not a valid proto map key type, so check for the key as an int or uint.
case Double:
if ik, ok := doubleToInt64Lossless(float64(k)); ok {
if keyVal, found := m.findInternal(Int(ik)); found {
return keyVal, true
}
}
if uk, ok := doubleToUint64Lossless(float64(k)); ok {
return m.findInternal(Uint(uk))
}
// map keys of type double are not supported.
case Int:
if uk, ok := int64ToUint64Lossless(int64(k)); ok {
return m.findInternal(Uint(uk))
}
case Uint:
if ik, ok := uint64ToInt64Lossless(uint64(k)); ok {
return m.findInternal(Int(ik))
}
}
return nil, false
}
// findInternal attempts to convert the incoming key to the map's internal native type
// and then returns the value, if found.
func (m *protoMap) findInternal(key ref.Val) (ref.Val, bool) {
// Convert the input key to the expected protobuf key type.
ntvKey, err := key.ConvertToNative(m.value.KeyType.ReflectType())
if err != nil {
return NewErr("unsupported key type: %v", key.Type()), false
return nil, false
}
// Use protoreflection to get the key value.
val := m.value.Get(protoreflect.ValueOf(ntvKey).MapKey())
@ -694,7 +717,7 @@ func (m *protoMap) Find(key ref.Val) (ref.Val, bool) {
switch v := val.Interface().(type) {
case protoreflect.List, protoreflect.Map:
// Maps do not support list or map values
return NewErr("unsupported map element type: (%T)%v", v, v), false
return nil, false
default:
return m.NativeToValue(v), true
}

View File

@ -83,10 +83,7 @@ func (n Null) ConvertToType(typeVal ref.Type) ref.Val {
// Equal implements ref.Val.Equal.
func (n Null) Equal(other ref.Val) ref.Val {
if NullType != other.Type() {
return ValOrErr(other, "no such overload")
}
return True
return Bool(NullType == other.Type())
}
// Type implements ref.Val.Type.

View File

@ -109,10 +109,8 @@ func (o *protoObj) ConvertToType(typeVal ref.Type) ref.Val {
}
func (o *protoObj) Equal(other ref.Val) ref.Val {
if o.typeDesc.Name() != other.Type().TypeName() {
return MaybeNoSuchOverloadErr(other)
}
return Bool(proto.Equal(o.value, other.Value().(proto.Message)))
otherPB, ok := other.Value().(proto.Message)
return Bool(ok && pb.Equal(o.value, otherPB))
}
// IsSet tests whether a field which is defined is set to a non-default value.

View File

@ -355,3 +355,35 @@ func uint64ToInt64Checked(v uint64) (int64, error) {
}
return int64(v), nil
}
func doubleToUint64Lossless(v float64) (uint64, bool) {
u, err := doubleToUint64Checked(v)
if err != nil {
return 0, false
}
if float64(u) != v {
return 0, false
}
return u, true
}
func doubleToInt64Lossless(v float64) (int64, bool) {
i, err := doubleToInt64Checked(v)
if err != nil {
return 0, false
}
if float64(i) != v {
return 0, false
}
return i, true
}
func int64ToUint64Lossless(v int64) (uint64, bool) {
u, err := int64ToUint64Checked(v)
return u, err == nil
}
func uint64ToInt64Lossless(v uint64) (int64, bool) {
i, err := uint64ToInt64Checked(v)
return i, err == nil
}

View File

@ -10,6 +10,7 @@ go_library(
srcs = [
"checked.go",
"enum.go",
"equal.go",
"file.go",
"pb.go",
"type.go",
@ -17,6 +18,7 @@ go_library(
importpath = "github.com/google/cel-go/common/types/pb",
deps = [
"@org_golang_google_genproto//googleapis/api/expr/v1alpha1:go_default_library",
"@org_golang_google_protobuf//encoding/protowire:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
"@org_golang_google_protobuf//reflect/protoreflect:go_default_library",
"@org_golang_google_protobuf//reflect/protoregistry:go_default_library",
@ -34,6 +36,7 @@ go_test(
name = "go_default_test",
size = "small",
srcs = [
"equal_test.go",
"file_test.go",
"pb_test.go",
"type_test.go",

View File

@ -0,0 +1,205 @@
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pb
import (
"bytes"
"reflect"
"google.golang.org/protobuf/encoding/protowire"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/types/known/anypb"
)
// Equal returns whether two proto.Message instances are equal using the following criteria:
//
// - Messages must share the same instance of the type descriptor
// - Known set fields are compared using semantics equality
// - Bytes are compared using bytes.Equal
// - Scalar values are compared with operator ==
// - List and map types are equal if they have the same length and all elements are equal
// - Messages are equal if they share the same descriptor and all set fields are equal
// - Unknown fields are compared using byte equality
// - NaN values are not equal to each other
// - google.protobuf.Any values are unpacked before comparison
// - If the type descriptor for a protobuf.Any cannot be found, byte equality is used rather than
// semantic equality.
//
// This method of proto equality mirrors the behavior of the C++ protobuf MessageDifferencer
// whereas the golang proto.Equal implementation mirrors the Java protobuf equals() methods
// behaviors which needed to treat NaN values as equal due to Java semantics.
func Equal(x, y proto.Message) bool {
if x == nil || y == nil {
return x == nil && y == nil
}
xRef := x.ProtoReflect()
yRef := y.ProtoReflect()
return equalMessage(xRef, yRef)
}
func equalMessage(mx, my protoreflect.Message) bool {
// Note, the original proto.Equal upon which this implementation is based does not specifically handle the
// case when both messages are invalid. It is assumed that the descriptors will be equal and that byte-wise
// comparison will be used, though the semantics of validity are neither clear, nor promised within the
// proto.Equal implementation.
if mx.IsValid() != my.IsValid() || mx.Descriptor() != my.Descriptor() {
return false
}
// This is an innovation on the default proto.Equal where protobuf.Any values are unpacked before comparison
// as otherwise the Any values are compared by bytes rather than structurally.
if isAny(mx) && isAny(my) {
ax := mx.Interface().(*anypb.Any)
ay := my.Interface().(*anypb.Any)
// If the values are not the same type url, return false.
if ax.GetTypeUrl() != ay.GetTypeUrl() {
return false
}
// If the values are byte equal, then return true.
if bytes.Equal(ax.GetValue(), ay.GetValue()) {
return true
}
// Otherwise fall through to the semantic comparison of the any values.
x, err := ax.UnmarshalNew()
if err != nil {
return false
}
y, err := ay.UnmarshalNew()
if err != nil {
return false
}
// Recursively compare the unwrapped messages to ensure nested Any values are unwrapped accordingly.
return equalMessage(x.ProtoReflect(), y.ProtoReflect())
}
// Walk the set fields to determine field-wise equality
nx := 0
equal := true
mx.Range(func(fd protoreflect.FieldDescriptor, vx protoreflect.Value) bool {
nx++
equal = my.Has(fd) && equalField(fd, vx, my.Get(fd))
return equal
})
if !equal {
return false
}
// Establish the count of set fields on message y
ny := 0
my.Range(func(protoreflect.FieldDescriptor, protoreflect.Value) bool {
ny++
return true
})
// If the number of set fields is not equal return false.
if nx != ny {
return false
}
return equalUnknown(mx.GetUnknown(), my.GetUnknown())
}
func equalField(fd protoreflect.FieldDescriptor, x, y protoreflect.Value) bool {
switch {
case fd.IsMap():
return equalMap(fd, x.Map(), y.Map())
case fd.IsList():
return equalList(fd, x.List(), y.List())
default:
return equalValue(fd, x, y)
}
}
func equalMap(fd protoreflect.FieldDescriptor, x, y protoreflect.Map) bool {
if x.Len() != y.Len() {
return false
}
equal := true
x.Range(func(k protoreflect.MapKey, vx protoreflect.Value) bool {
vy := y.Get(k)
equal = y.Has(k) && equalValue(fd.MapValue(), vx, vy)
return equal
})
return equal
}
func equalList(fd protoreflect.FieldDescriptor, x, y protoreflect.List) bool {
if x.Len() != y.Len() {
return false
}
for i := x.Len() - 1; i >= 0; i-- {
if !equalValue(fd, x.Get(i), y.Get(i)) {
return false
}
}
return true
}
func equalValue(fd protoreflect.FieldDescriptor, x, y protoreflect.Value) bool {
switch fd.Kind() {
case protoreflect.BoolKind:
return x.Bool() == y.Bool()
case protoreflect.EnumKind:
return x.Enum() == y.Enum()
case protoreflect.Int32Kind, protoreflect.Sint32Kind,
protoreflect.Int64Kind, protoreflect.Sint64Kind,
protoreflect.Sfixed32Kind, protoreflect.Sfixed64Kind:
return x.Int() == y.Int()
case protoreflect.Uint32Kind, protoreflect.Uint64Kind,
protoreflect.Fixed32Kind, protoreflect.Fixed64Kind:
return x.Uint() == y.Uint()
case protoreflect.FloatKind, protoreflect.DoubleKind:
return x.Float() == y.Float()
case protoreflect.StringKind:
return x.String() == y.String()
case protoreflect.BytesKind:
return bytes.Equal(x.Bytes(), y.Bytes())
case protoreflect.MessageKind, protoreflect.GroupKind:
return equalMessage(x.Message(), y.Message())
default:
return x.Interface() == y.Interface()
}
}
func equalUnknown(x, y protoreflect.RawFields) bool {
lenX := len(x)
lenY := len(y)
if lenX != lenY {
return false
}
if lenX == 0 {
return true
}
if bytes.Equal([]byte(x), []byte(y)) {
return true
}
mx := make(map[protoreflect.FieldNumber]protoreflect.RawFields)
my := make(map[protoreflect.FieldNumber]protoreflect.RawFields)
for len(x) > 0 {
fnum, _, n := protowire.ConsumeField(x)
mx[fnum] = append(mx[fnum], x[:n]...)
x = x[n:]
}
for len(y) > 0 {
fnum, _, n := protowire.ConsumeField(y)
my[fnum] = append(my[fnum], y[:n]...)
y = y[n:]
}
return reflect.DeepEqual(mx, my)
}
func isAny(m protoreflect.Message) bool {
return string(m.Descriptor().FullName()) == "google.protobuf.Any"
}

View File

@ -151,10 +151,7 @@ func (s String) ConvertToType(typeVal ref.Type) ref.Val {
// Equal implements ref.Val.Equal.
func (s String) Equal(other ref.Val) ref.Val {
otherString, ok := other.(String)
if !ok {
return MaybeNoSuchOverloadErr(other)
}
return Bool(s == otherString)
return Bool(ok && s == otherString)
}
// Match implements traits.Matcher.Match.

View File

@ -134,10 +134,8 @@ func (t Timestamp) ConvertToType(typeVal ref.Type) ref.Val {
// Equal implements ref.Val.Equal.
func (t Timestamp) Equal(other ref.Val) ref.Val {
if TimestampType != other.Type() {
return MaybeNoSuchOverloadErr(other)
}
return Bool(t.Time.Equal(other.(Timestamp).Time))
otherTime, ok := other.(Timestamp)
return Bool(ok && t.Time.Equal(otherTime.Time))
}
// Receive implements traits.Reciever.Receive.

View File

@ -25,3 +25,8 @@ type Lister interface {
Iterable
Sizer
}
// MutableLister interface which emits an immutable result after an intermediate computation.
type MutableLister interface {
ToImmutableList() Lister
}

View File

@ -71,10 +71,8 @@ func (t *TypeValue) ConvertToType(typeVal ref.Type) ref.Val {
// Equal implements ref.Val.Equal.
func (t *TypeValue) Equal(other ref.Val) ref.Val {
if TypeType != other.Type() {
return ValOrErr(other, "no such overload")
}
return Bool(t.TypeName() == other.(ref.Type).TypeName())
otherType, ok := other.(ref.Type)
return Bool(ok && t.TypeName() == otherType.TypeName())
}
// HasTrait indicates whether the type supports the given trait.

View File

@ -16,6 +16,7 @@ package types
import (
"fmt"
"math"
"reflect"
"strconv"
@ -65,17 +66,19 @@ func (i Uint) Add(other ref.Val) ref.Val {
// Compare implements traits.Comparer.Compare.
func (i Uint) Compare(other ref.Val) ref.Val {
otherUint, ok := other.(Uint)
if !ok {
switch ov := other.(type) {
case Double:
if math.IsNaN(float64(ov)) {
return NewErr("NaN values cannot be ordered")
}
return compareUintDouble(i, ov)
case Int:
return compareUintInt(i, ov)
case Uint:
return compareUint(i, ov)
default:
return MaybeNoSuchOverloadErr(other)
}
if i < otherUint {
return IntNegOne
}
if i > otherUint {
return IntOne
}
return IntZero
}
// ConvertToNative implements ref.Val.ConvertToNative.
@ -176,11 +179,19 @@ func (i Uint) Divide(other ref.Val) ref.Val {
// Equal implements ref.Val.Equal.
func (i Uint) Equal(other ref.Val) ref.Val {
otherUint, ok := other.(Uint)
if !ok {
return MaybeNoSuchOverloadErr(other)
switch ov := other.(type) {
case Double:
if math.IsNaN(float64(ov)) {
return False
}
return Bool(compareUintDouble(i, ov) == 0)
case Int:
return Bool(compareUintInt(i, ov) == 0)
case Uint:
return Bool(i == ov)
default:
return False
}
return Bool(i == otherUint)
}
// Modulo implements traits.Modder.Modulo.

View File

@ -36,3 +36,13 @@ func IsPrimitiveType(val ref.Val) bool {
}
return false
}
// Equal returns whether the two ref.Value are heterogeneously equivalent.
func Equal(lhs ref.Val, rhs ref.Val) ref.Val {
lNull := lhs == NullValue
rNull := rhs == NullValue
if lNull || rNull {
return Bool(lNull == rNull)
}
return lhs.Equal(rhs)
}

View File

@ -18,6 +18,7 @@ go_library(
"//checker/decls:go_default_library",
"//common/types:go_default_library",
"//common/types/ref:go_default_library",
"//common/types/traits:go_default_library",
"//interpreter/functions:go_default_library",
"@org_golang_google_genproto//googleapis/api/expr/v1alpha1:go_default_library",
],

View File

@ -17,6 +17,7 @@ package ext
import (
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
"github.com/google/cel-go/interpreter/functions"
)
@ -246,3 +247,57 @@ func callInStrStrStrIntOutStr(fn func(string, string, string, int64) (string, er
return types.String(out)
}
}
func callInListStrOutStr(fn func([]string) (string, error)) functions.UnaryOp {
return func(args1 ref.Val) ref.Val {
vVal, ok := args1.(traits.Lister)
if !ok {
return types.MaybeNoSuchOverloadErr(args1)
}
strings := make([]string, vVal.Size().Value().(int64))
i := 0
for it := vVal.Iterator(); it.HasNext() == types.True; {
next := it.Next()
v, ok := next.(types.String)
if !ok {
return types.MaybeNoSuchOverloadErr(next)
}
strings[i] = string(v)
i++
}
out, err := fn(strings)
if err != nil {
return types.NewErr(err.Error())
}
return types.DefaultTypeAdapter.NativeToValue(out)
}
}
func callInListStrStrOutStr(fn func([]string, string) (string, error)) functions.BinaryOp {
return func(args1, args2 ref.Val) ref.Val {
vVal, ok := args1.(traits.Lister)
if !ok {
return types.MaybeNoSuchOverloadErr(args1)
}
arg1Val, ok := args2.(types.String)
if !ok {
return types.MaybeNoSuchOverloadErr(args2)
}
strings := make([]string, vVal.Size().Value().(int64))
i := 0
for it := vVal.Iterator(); it.HasNext() == types.True; {
next := it.Next()
v, ok := next.(types.String)
if !ok {
return types.MaybeNoSuchOverloadErr(next)
}
strings[i] = string(v)
i++
}
out, err := fn(strings, string(arg1Val))
if err != nil {
return types.NewErr(err.Error())
}
return types.DefaultTypeAdapter.NativeToValue(out)
}
}

View File

@ -67,6 +67,22 @@ import (
// 'hello mellow'.indexOf('ello', 2) // returns 7
// 'hello mellow'.indexOf('ello', 20) // error
//
// Join
//
// Returns a new string where the elements of string list are concatenated.
//
// The function also accepts an optional separator which is placed between elements in the resulting string.
//
// <list<string>>.join() -> <string>
// <list<string>>.join(<string>) -> <string>
//
// Examples:
//
// ['hello', 'mellow'].join() // returns 'hellomellow'
// ['hello', 'mellow'].join(' ') // returns 'hello mellow'
// [].join() // returns ''
// [].join('/') // returns ''
//
// LastIndexOf
//
// Returns the integer index at the start of the last occurrence of the search string. If the
@ -243,6 +259,14 @@ func (stringLib) CompileOptions() []cel.EnvOption {
decls.NewInstanceOverload("string_upper_ascii",
[]*exprpb.Type{decls.String},
decls.String)),
decls.NewFunction("join",
decls.NewInstanceOverload("list_join",
[]*exprpb.Type{decls.NewListType(decls.String)},
decls.String),
decls.NewInstanceOverload("list_join_string",
[]*exprpb.Type{decls.NewListType(decls.String), decls.String},
decls.String),
),
),
}
}
@ -356,6 +380,19 @@ func (stringLib) ProgramOptions() []cel.ProgramOption {
Operator: "string_upper_ascii",
Unary: callInStrOutStr(upperASCII),
},
&functions.Overload{
Operator: "join",
Unary: callInListStrOutStr(join),
Binary: callInListStrStrOutStr(joinSeparator),
},
&functions.Overload{
Operator: "list_join",
Unary: callInListStrOutStr(join),
},
&functions.Overload{
Operator: "list_join_string",
Binary: callInListStrStrOutStr(joinSeparator),
},
),
}
}
@ -501,3 +538,11 @@ func upperASCII(str string) (string, error) {
}
return string(runes), nil
}
func joinSeparator(strs []string, separator string) (string, error) {
return strings.Join(strs, separator), nil
}
func join(strs []string) (string, error) {
return strings.Join(strs, ""), nil
}

View File

@ -17,8 +17,10 @@ go_library(
"evalstate.go",
"interpretable.go",
"interpreter.go",
"optimizations.go",
"planner.go",
"prune.go",
"runtimecost.go",
],
importpath = "github.com/google/cel-go/interpreter",
deps = [

View File

@ -157,14 +157,23 @@ func (q *stringQualifier) QualifierValueEquals(value interface{}) bool {
// QualifierValueEquals implementation for int qualifiers.
func (q *intQualifier) QualifierValueEquals(value interface{}) bool {
ival, ok := value.(int64)
return ok && q.value == ival
return numericValueEquals(value, q.celValue)
}
// QualifierValueEquals implementation for uint qualifiers.
func (q *uintQualifier) QualifierValueEquals(value interface{}) bool {
uval, ok := value.(uint64)
return ok && q.value == uval
return numericValueEquals(value, q.celValue)
}
// QualifierValueEquals implementation for double qualifiers.
func (q *doubleQualifier) QualifierValueEquals(value interface{}) bool {
return numericValueEquals(value, q.celValue)
}
// numericValueEquals uses CEL equality to determine whether two number values are
func numericValueEquals(value interface{}, celValue ref.Val) bool {
val := types.DefaultTypeAdapter.NativeToValue(value)
return celValue.Equal(val) == types.True
}
// NewPartialAttributeFactory returns an AttributeFactory implementation capable of performing
@ -348,7 +357,8 @@ func (m *attributeMatcher) Resolve(vars Activation) (interface{}, error) {
// the standard Resolve logic applies.
func (m *attributeMatcher) TryResolve(vars Activation) (interface{}, bool, error) {
id := m.NamespacedAttribute.ID()
partial, isPartial := vars.(PartialActivation)
// Bug in how partial activation is resolved, should search parents as well.
partial, isPartial := toPartialActivation(vars)
if isPartial {
unk, err := m.fac.matchesUnknownPatterns(
partial,
@ -381,3 +391,14 @@ func (m *attributeMatcher) Qualify(vars Activation, obj interface{}) (interface{
}
return qual.Qualify(vars, obj)
}
func toPartialActivation(vars Activation) (PartialActivation, bool) {
pv, ok := vars.(PartialActivation)
if ok {
return pv, true
}
if vars.Parent() != nil {
return toPartialActivation(vars.Parent())
}
return nil, false
}

View File

@ -15,7 +15,6 @@
package interpreter
import (
"errors"
"fmt"
"math"
@ -487,9 +486,7 @@ func (a *maybeAttribute) AddQualifier(qual Qualifier) (Attribute, error) {
}
}
// Next, ensure the most specific variable / type reference is searched first.
a.attrs = append([]NamespacedAttribute{
a.fac.AbsoluteAttribute(qual.ID(), augmentedNames...),
}, a.attrs...)
a.attrs = append([]NamespacedAttribute{a.fac.AbsoluteAttribute(qual.ID(), augmentedNames...)}, a.attrs...)
return a, nil
}
@ -628,6 +625,10 @@ func newQualifier(adapter ref.TypeAdapter, id int64, v interface{}) (Qualifier,
qual = &uintQualifier{id: id, value: val, celValue: types.Uint(val), adapter: adapter}
case bool:
qual = &boolQualifier{id: id, value: val, celValue: types.Bool(val), adapter: adapter}
case float32:
qual = &doubleQualifier{id: id, value: float64(val), celValue: types.Double(val), adapter: adapter}
case float64:
qual = &doubleQualifier{id: id, value: val, celValue: types.Double(val), adapter: adapter}
case types.String:
qual = &stringQualifier{id: id, value: string(val), celValue: val, adapter: adapter}
case types.Int:
@ -714,9 +715,6 @@ func (q *stringQualifier) Qualify(vars Activation, obj interface{}) (interface{}
if err != nil {
return nil, err
}
if types.IsUnknown(elem) {
return elem, nil
}
return elem, nil
}
if isMap && !isKey {
@ -829,9 +827,6 @@ func (q *intQualifier) Qualify(vars Activation, obj interface{}) (interface{}, e
if err != nil {
return nil, err
}
if types.IsUnknown(elem) {
return elem, nil
}
return elem, nil
}
if isMap && !isKey {
@ -891,9 +886,6 @@ func (q *uintQualifier) Qualify(vars Activation, obj interface{}) (interface{},
if err != nil {
return nil, err
}
if types.IsUnknown(elem) {
return elem, nil
}
return elem, nil
}
if isMap && !isKey {
@ -942,9 +934,6 @@ func (q *boolQualifier) Qualify(vars Activation, obj interface{}) (interface{},
if err != nil {
return nil, err
}
if types.IsUnknown(elem) {
return elem, nil
}
return elem, nil
}
if !isKey {
@ -996,6 +985,37 @@ func (q *fieldQualifier) Cost() (min, max int64) {
return 0, 0
}
// doubleQualifier qualifies a CEL object, map, or list using a double value.
//
// This qualifier is used for working with dynamic data like JSON or protobuf.Any where the value
// type may not be known ahead of time and may not conform to the standard types supported as valid
// protobuf map key types.
type doubleQualifier struct {
id int64
value float64
celValue ref.Val
adapter ref.TypeAdapter
}
// ID is an implementation of the Qualifier interface method.
func (q *doubleQualifier) ID() int64 {
return q.id
}
// Qualify implements the Qualifier interface method.
func (q *doubleQualifier) Qualify(vars Activation, obj interface{}) (interface{}, error) {
switch o := obj.(type) {
case types.Unknown:
return o, nil
default:
elem, err := refResolve(q.adapter, q.celValue, obj)
if err != nil {
return nil, err
}
return elem, nil
}
}
// refResolve attempts to convert the value to a CEL value and then uses reflection methods
// to try and resolve the qualifier.
func refResolve(adapter ref.TypeAdapter, idx ref.Val, obj interface{}) (ref.Val, error) {
@ -1006,9 +1026,6 @@ func refResolve(adapter ref.TypeAdapter, idx ref.Val, obj interface{}) (ref.Val,
if !found {
return nil, fmt.Errorf("no such key: %v", idx)
}
if types.IsError(elem) {
return nil, elem.(*types.Err)
}
return elem, nil
}
indexer, isIndexer := celVal.(traits.Indexer)
@ -1028,5 +1045,5 @@ func refResolve(adapter ref.TypeAdapter, idx ref.Val, obj interface{}) (ref.Val,
if types.IsError(celVal) {
return nil, celVal.(*types.Err)
}
return nil, errors.New("no such overload")
return nil, fmt.Errorf("no such key: %v", idx)
}

View File

@ -16,7 +16,11 @@ package interpreter
import "math"
// TODO: remove Coster.
// Coster calculates the heuristic cost incurred during evaluation.
// Deprecated: Please migrate cel.EstimateCost, it supports length estimates for input data and cost estimates for
// extension functions.
type Coster interface {
Cost() (min, max int64)
}

View File

@ -25,11 +25,8 @@ import (
// Interpretable expression nodes at construction time.
type InterpretableDecorator func(Interpretable) (Interpretable, error)
// evalObserver is a functional interface that accepts an expression id and an observed value.
type evalObserver func(int64, ref.Val)
// decObserveEval records evaluation state into an EvalState object.
func decObserveEval(observer evalObserver) InterpretableDecorator {
func decObserveEval(observer EvalObserver) InterpretableDecorator {
return func(i Interpretable) (Interpretable, error) {
switch inst := i.(type) {
case *evalWatch, *evalWatchAttr, *evalWatchConst:
@ -54,6 +51,19 @@ func decObserveEval(observer evalObserver) InterpretableDecorator {
}
}
// decInterruptFolds creates an intepretable decorator which marks comprehensions as interruptable
// where the interrupt state is communicated via a hidden variable on the Activation.
func decInterruptFolds() InterpretableDecorator {
return func(i Interpretable) (Interpretable, error) {
fold, ok := i.(*evalFold)
if !ok {
return i, nil
}
fold.interruptable = true
return fold, nil
}
}
// decDisableShortcircuits ensures that all branches of an expression will be evaluated, no short-circuiting.
func decDisableShortcircuits() InterpretableDecorator {
return func(i Interpretable) (Interpretable, error) {
@ -71,16 +81,8 @@ func decDisableShortcircuits() InterpretableDecorator {
rhs: expr.rhs,
}, nil
case *evalFold:
return &evalExhaustiveFold{
id: expr.id,
accu: expr.accu,
accuVar: expr.accuVar,
iterRange: expr.iterRange,
iterVar: expr.iterVar,
cond: expr.cond,
step: expr.step,
result: expr.result,
}, nil
expr.exhaustive = true
return expr, nil
case InterpretableAttribute:
cond, isCond := expr.Attr().(*conditionalAttribute)
if isCond {
@ -118,6 +120,48 @@ func decOptimize() InterpretableDecorator {
}
}
// decRegexOptimizer compiles regex pattern string constants.
func decRegexOptimizer(regexOptimizations ...*RegexOptimization) InterpretableDecorator {
functionMatchMap := make(map[string]*RegexOptimization)
overloadMatchMap := make(map[string]*RegexOptimization)
for _, m := range regexOptimizations {
functionMatchMap[m.Function] = m
if m.OverloadID != "" {
overloadMatchMap[m.OverloadID] = m
}
}
return func(i Interpretable) (Interpretable, error) {
call, ok := i.(InterpretableCall)
if !ok {
return i, nil
}
var matcher *RegexOptimization
var found bool
if call.OverloadID() != "" {
matcher, found = overloadMatchMap[call.OverloadID()]
}
if !found {
matcher, found = functionMatchMap[call.Function()]
}
if !found || matcher.RegexIndex >= len(call.Args()) {
return i, nil
}
args := call.Args()
regexArg := args[matcher.RegexIndex]
regexStr, isConst := regexArg.(InterpretableConst)
if !isConst {
return i, nil
}
pattern, ok := regexStr.Value().(types.String)
if !ok {
return i, nil
}
return matcher.Factory(call, string(pattern))
}
}
func maybeOptimizeConstUnary(i Interpretable, call InterpretableCall) (Interpretable, error) {
args := call.Args()
if len(args) != 1 {
@ -177,7 +221,6 @@ func maybeOptimizeSetMembership(i Interpretable, inlist InterpretableCall) (Inte
return NewConstValue(inlist.ID(), types.False), nil
}
it := list.Iterator()
var typ ref.Type
valueSet := make(map[ref.Val]ref.Val)
for it.HasNext() == types.True {
elem := it.Next()
@ -185,17 +228,44 @@ func maybeOptimizeSetMembership(i Interpretable, inlist InterpretableCall) (Inte
// Note, non-primitive type are not yet supported.
return i, nil
}
if typ == nil {
typ = elem.Type()
} else if typ.TypeName() != elem.Type().TypeName() {
return i, nil
}
valueSet[elem] = types.True
switch ev := elem.(type) {
case types.Double:
iv := ev.ConvertToType(types.IntType)
// Ensure that only lossless conversions are added to the set
if !types.IsError(iv) && iv.Equal(ev) == types.True {
valueSet[iv] = types.True
}
// Ensure that only lossless conversions are added to the set
uv := ev.ConvertToType(types.UintType)
if !types.IsError(uv) && uv.Equal(ev) == types.True {
valueSet[uv] = types.True
}
case types.Int:
dv := ev.ConvertToType(types.DoubleType)
if !types.IsError(dv) {
valueSet[dv] = types.True
}
uv := ev.ConvertToType(types.UintType)
if !types.IsError(uv) {
valueSet[uv] = types.True
}
case types.Uint:
dv := ev.ConvertToType(types.DoubleType)
if !types.IsError(dv) {
valueSet[dv] = types.True
}
iv := ev.ConvertToType(types.IntType)
if !types.IsError(iv) {
valueSet[iv] = types.True
}
default:
break
}
}
return &evalSetMembership{
inst: inlist,
arg: lhs,
argTypeName: typ.TypeName(),
valueSet: valueSet,
inst: inlist,
arg: lhs,
valueSet: valueSet,
}, nil
}

View File

@ -100,8 +100,6 @@ func StandardOverloads() []*Overload {
return cmp
}},
// TODO: Verify overflow, NaN, underflow cases for numeric values.
// Add operator
{Operator: operators.Add,
OperandTrait: traits.AdderType,

View File

@ -88,6 +88,18 @@ type InterpretableCall interface {
Args() []Interpretable
}
// InterpretableConstructor interface for inspecting Interpretable instructions that initialize a list, map
// or struct.
type InterpretableConstructor interface {
Interpretable
// InitVals returns all the list elements, map key and values or struct field values.
InitVals() []Interpretable
// Type returns the type constructed.
Type() ref.Type
}
// Core Interpretable implementations used during the program planning phase.
type evalTestOnly struct {
@ -298,7 +310,13 @@ func (eq *evalEq) ID() int64 {
func (eq *evalEq) Eval(ctx Activation) ref.Val {
lVal := eq.lhs.Eval(ctx)
rVal := eq.rhs.Eval(ctx)
return lVal.Equal(rVal)
if types.IsUnknownOrError(lVal) {
return lVal
}
if types.IsUnknownOrError(rVal) {
return rVal
}
return types.Equal(lVal, rVal)
}
// Cost implements the Coster interface method.
@ -336,12 +354,13 @@ func (ne *evalNe) ID() int64 {
func (ne *evalNe) Eval(ctx Activation) ref.Val {
lVal := ne.lhs.Eval(ctx)
rVal := ne.rhs.Eval(ctx)
eqVal := lVal.Equal(rVal)
eqBool, ok := eqVal.(types.Bool)
if !ok {
return types.ValOrErr(eqVal, "no such overload: _!=_")
if types.IsUnknownOrError(lVal) {
return lVal
}
return !eqBool
if types.IsUnknownOrError(rVal) {
return rVal
}
return types.Bool(types.Equal(lVal, rVal) != types.True)
}
// Cost implements the Coster interface method.
@ -526,6 +545,17 @@ type evalVarArgs struct {
impl functions.FunctionOp
}
// NewCall creates a new call Interpretable.
func NewCall(id int64, function, overload string, args []Interpretable, impl functions.FunctionOp) InterpretableCall {
return &evalVarArgs{
id: id,
function: function,
overload: overload,
args: args,
impl: impl,
}
}
// ID implements the Interpretable interface method.
func (fn *evalVarArgs) ID() int64 {
return fn.id
@ -603,6 +633,14 @@ func (l *evalList) Eval(ctx Activation) ref.Val {
return l.adapter.NativeToValue(elemVals)
}
func (l *evalList) InitVals() []Interpretable {
return l.elems
}
func (l *evalList) Type() ref.Type {
return types.ListType
}
// Cost implements the Coster interface method.
func (l *evalList) Cost() (min, max int64) {
return sumOfCost(l.elems)
@ -638,6 +676,14 @@ func (m *evalMap) Eval(ctx Activation) ref.Val {
return m.adapter.NativeToValue(entries)
}
func (m *evalMap) InitVals() []Interpretable {
return append(m.keys, m.vals...)
}
func (m *evalMap) Type() ref.Type {
return types.MapType
}
// Cost implements the Coster interface method.
func (m *evalMap) Cost() (min, max int64) {
kMin, kMax := sumOfCost(m.keys)
@ -672,6 +718,14 @@ func (o *evalObj) Eval(ctx Activation) ref.Val {
return o.provider.NewValue(o.typeName, fieldVals)
}
func (o *evalObj) InitVals() []Interpretable {
return o.vals
}
func (o *evalObj) Type() ref.Type {
return types.NewObjectTypeValue(o.typeName)
}
// Cost implements the Coster interface method.
func (o *evalObj) Cost() (min, max int64) {
return sumOfCost(o.vals)
@ -688,14 +742,17 @@ func sumOfCost(interps []Interpretable) (min, max int64) {
}
type evalFold struct {
id int64
accuVar string
iterVar string
iterRange Interpretable
accu Interpretable
cond Interpretable
step Interpretable
result Interpretable
id int64
accuVar string
iterVar string
iterRange Interpretable
accu Interpretable
cond Interpretable
step Interpretable
result Interpretable
adapter ref.TypeAdapter
exhaustive bool
interruptable bool
}
// ID implements the Interpretable interface method.
@ -714,9 +771,19 @@ func (fold *evalFold) Eval(ctx Activation) ref.Val {
accuCtx.parent = ctx
accuCtx.name = fold.accuVar
accuCtx.val = fold.accu.Eval(ctx)
// If the accumulator starts as an empty list, then the comprehension will build a list
// so create a mutable list to optimize the cost of the inner loop.
l, ok := accuCtx.val.(traits.Lister)
buildingList := false
if !fold.exhaustive && ok && l.Size() == types.IntZero {
buildingList = true
accuCtx.val = types.NewMutableList(fold.adapter)
}
iterCtx := varActivationPool.Get().(*varActivation)
iterCtx.parent = accuCtx
iterCtx.name = fold.iterVar
interrupted := false
it := foldRange.(traits.Iterable).Iterator()
for it.HasNext() == types.True {
// Modify the iter var in the fold activation.
@ -725,17 +792,31 @@ func (fold *evalFold) Eval(ctx Activation) ref.Val {
// Evaluate the condition, terminate the loop if false.
cond := fold.cond.Eval(iterCtx)
condBool, ok := cond.(types.Bool)
if !types.IsUnknown(cond) && ok && condBool != types.True {
if !fold.exhaustive && ok && condBool != types.True {
break
}
// Evalute the evaluation step into accu var.
// Evaluate the evaluation step into accu var.
accuCtx.val = fold.step.Eval(iterCtx)
if fold.interruptable {
if stop, found := ctx.ResolveName("#interrupted"); found && stop == true {
interrupted = true
break
}
}
}
varActivationPool.Put(iterCtx)
if interrupted {
varActivationPool.Put(accuCtx)
return types.NewErr("operation interrupted")
}
// Compute the result.
res := fold.result.Eval(accuCtx)
varActivationPool.Put(iterCtx)
varActivationPool.Put(accuCtx)
// Convert a mutable list to an immutable one, if the comprehension has generated a list as a result.
if !types.IsUnknownOrError(res) && buildingList {
res = res.(traits.MutableLister).ToImmutableList()
}
return res
}
@ -760,6 +841,10 @@ func (fold *evalFold) Cost() (min, max int64) {
cMin, cMax := estimateCost(fold.cond)
sMin, sMax := estimateCost(fold.step)
rMin, rMax := estimateCost(fold.result)
if fold.exhaustive {
cMin = cMin * rangeCnt
sMin = sMin * rangeCnt
}
// The cond and step costs are multiplied by size(iterRange). The minimum possible cost incurs
// when the evaluation result can be determined by the first iteration.
@ -773,10 +858,9 @@ func (fold *evalFold) Cost() (min, max int64) {
// evalSetMembership is an Interpretable implementation which tests whether an input value
// exists within the set of map keys used to model a set.
type evalSetMembership struct {
inst Interpretable
arg Interpretable
argTypeName string
valueSet map[ref.Val]ref.Val
inst Interpretable
arg Interpretable
valueSet map[ref.Val]ref.Val
}
// ID implements the Interpretable interface method.
@ -787,9 +871,6 @@ func (e *evalSetMembership) ID() int64 {
// Eval implements the Interpretable interface method.
func (e *evalSetMembership) Eval(ctx Activation) ref.Val {
val := e.arg.Eval(ctx)
if val.Type().TypeName() != e.argTypeName {
return types.ValOrErr(val, "no such overload")
}
if ret, found := e.valueSet[val]; found {
return ret
}
@ -805,13 +886,13 @@ func (e *evalSetMembership) Cost() (min, max int64) {
// expression so that it may observe the computed value and send it to an observer.
type evalWatch struct {
Interpretable
observer evalObserver
observer EvalObserver
}
// Eval implements the Interpretable interface method.
func (e *evalWatch) Eval(ctx Activation) ref.Val {
val := e.Interpretable.Eval(ctx)
e.observer(e.ID(), val)
e.observer(e.ID(), e.Interpretable, val)
return val
}
@ -826,7 +907,7 @@ func (e *evalWatch) Cost() (min, max int64) {
// must implement the instAttr interface by proxy.
type evalWatchAttr struct {
InterpretableAttribute
observer evalObserver
observer EvalObserver
}
// AddQualifier creates a wrapper over the incoming qualifier which observes the qualification
@ -850,11 +931,23 @@ func (e *evalWatchAttr) AddQualifier(q Qualifier) (Attribute, error) {
return e, err
}
// Cost implements the Coster interface method.
func (e *evalWatchAttr) Cost() (min, max int64) {
return estimateCost(e.InterpretableAttribute)
}
// Eval implements the Interpretable interface method.
func (e *evalWatchAttr) Eval(vars Activation) ref.Val {
val := e.InterpretableAttribute.Eval(vars)
e.observer(e.ID(), e.InterpretableAttribute, val)
return val
}
// evalWatchConstQual observes the qualification of an object using a constant boolean, int,
// string, or uint.
type evalWatchConstQual struct {
ConstantQualifier
observer evalObserver
observer EvalObserver
adapter ref.TypeAdapter
}
@ -872,7 +965,7 @@ func (e *evalWatchConstQual) Qualify(vars Activation, obj interface{}) (interfac
} else {
val = e.adapter.NativeToValue(out)
}
e.observer(e.ID(), val)
e.observer(e.ID(), e.ConstantQualifier, val)
return out, err
}
@ -885,7 +978,7 @@ func (e *evalWatchConstQual) QualifierValueEquals(value interface{}) bool {
// evalWatchQual observes the qualification of an object by a value computed at runtime.
type evalWatchQual struct {
Qualifier
observer evalObserver
observer EvalObserver
adapter ref.TypeAdapter
}
@ -903,32 +996,20 @@ func (e *evalWatchQual) Qualify(vars Activation, obj interface{}) (interface{},
} else {
val = e.adapter.NativeToValue(out)
}
e.observer(e.ID(), val)
e.observer(e.ID(), e.Qualifier, val)
return out, err
}
// Cost implements the Coster interface method.
func (e *evalWatchAttr) Cost() (min, max int64) {
return estimateCost(e.InterpretableAttribute)
}
// Eval implements the Interpretable interface method.
func (e *evalWatchAttr) Eval(vars Activation) ref.Val {
val := e.InterpretableAttribute.Eval(vars)
e.observer(e.ID(), val)
return val
}
// evalWatchConst describes a watcher of an instConst Interpretable.
type evalWatchConst struct {
InterpretableConst
observer evalObserver
observer EvalObserver
}
// Eval implements the Interpretable interface method.
func (e *evalWatchConst) Eval(vars Activation) ref.Val {
val := e.Value()
e.observer(e.ID(), val)
e.observer(e.ID(), e.InterpretableConst, val)
return val
}
@ -1074,83 +1155,6 @@ func (cond *evalExhaustiveConditional) Cost() (min, max int64) {
return cond.attr.Cost()
}
// evalExhaustiveFold is like evalFold, but does not short-circuit argument evaluation.
type evalExhaustiveFold struct {
id int64
accuVar string
iterVar string
iterRange Interpretable
accu Interpretable
cond Interpretable
step Interpretable
result Interpretable
}
// ID implements the Interpretable interface method.
func (fold *evalExhaustiveFold) ID() int64 {
return fold.id
}
// Eval implements the Interpretable interface method.
func (fold *evalExhaustiveFold) Eval(ctx Activation) ref.Val {
foldRange := fold.iterRange.Eval(ctx)
if !foldRange.Type().HasTrait(traits.IterableType) {
return types.ValOrErr(foldRange, "got '%T', expected iterable type", foldRange)
}
// Configure the fold activation with the accumulator initial value.
accuCtx := varActivationPool.Get().(*varActivation)
accuCtx.parent = ctx
accuCtx.name = fold.accuVar
accuCtx.val = fold.accu.Eval(ctx)
iterCtx := varActivationPool.Get().(*varActivation)
iterCtx.parent = accuCtx
iterCtx.name = fold.iterVar
it := foldRange.(traits.Iterable).Iterator()
for it.HasNext() == types.True {
// Modify the iter var in the fold activation.
iterCtx.val = it.Next()
// Evaluate the condition, but don't terminate the loop as this is exhaustive eval!
fold.cond.Eval(iterCtx)
// Evalute the evaluation step into accu var.
accuCtx.val = fold.step.Eval(iterCtx)
}
// Compute the result.
res := fold.result.Eval(accuCtx)
varActivationPool.Put(iterCtx)
varActivationPool.Put(accuCtx)
return res
}
// Cost implements the Coster interface method.
func (fold *evalExhaustiveFold) Cost() (min, max int64) {
// Compute the cost for evaluating iterRange.
iMin, iMax := estimateCost(fold.iterRange)
// Compute the size of iterRange. If the size depends on the input, return the maximum possible
// cost range.
foldRange := fold.iterRange.Eval(EmptyActivation())
if !foldRange.Type().HasTrait(traits.IterableType) {
return 0, math.MaxInt64
}
var rangeCnt int64
it := foldRange.(traits.Iterable).Iterator()
for it.HasNext() == types.True {
it.Next()
rangeCnt++
}
aMin, aMax := estimateCost(fold.accu)
cMin, cMax := estimateCost(fold.cond)
sMin, sMax := estimateCost(fold.step)
rMin, rMax := estimateCost(fold.result)
// The cond and step costs are multiplied by size(iterRange).
return iMin + aMin + cMin*rangeCnt + sMin*rangeCnt + rMin,
iMax + aMax + cMax*rangeCnt + sMax*rangeCnt + rMax
}
// evalAttr evaluates an Attribute value.
type evalAttr struct {
adapter ref.TypeAdapter

View File

@ -38,41 +38,118 @@ type Interpreter interface {
decorators ...InterpretableDecorator) (Interpretable, error)
}
// EvalObserver is a functional interface that accepts an expression id and an observed value.
// The id identifies the expression that was evaluated, the programStep is the Interpretable or Qualifier that
// was evaluated and value is the result of the evaluation.
type EvalObserver func(id int64, programStep interface{}, value ref.Val)
// Observe constructs a decorator that calls all the provided observers in order after evaluating each Interpretable
// or Qualifier during program evaluation.
func Observe(observers ...EvalObserver) InterpretableDecorator {
if len(observers) == 1 {
return decObserveEval(observers[0])
}
observeFn := func(id int64, programStep interface{}, val ref.Val) {
for _, observer := range observers {
observer(id, programStep, val)
}
}
return decObserveEval(observeFn)
}
// EvalCancelledError represents a cancelled program evaluation operation.
type EvalCancelledError struct {
Message string
// Type identifies the cause of the cancellation.
Cause CancellationCause
}
func (e EvalCancelledError) Error() string {
return e.Message
}
// CancellationCause enumerates the ways a program evaluation operation can be cancelled.
type CancellationCause int
const (
// ContextCancelled indicates that the operation was cancelled in response to a Golang context cancellation.
ContextCancelled CancellationCause = iota
// CostLimitExceeded indicates that the operation was cancelled in response to the actual cost limit being
// exceeded.
CostLimitExceeded
)
// TODO: Replace all usages of TrackState with EvalStateObserver
// TrackState decorates each expression node with an observer which records the value
// associated with the given expression id. EvalState must be provided to the decorator.
// This decorator is not thread-safe, and the EvalState must be reset between Eval()
// calls.
// DEPRECATED: Please use EvalStateObserver instead. It composes gracefully with additional observers.
func TrackState(state EvalState) InterpretableDecorator {
observer := func(id int64, val ref.Val) {
return Observe(EvalStateObserver(state))
}
// EvalStateObserver provides an observer which records the value
// associated with the given expression id. EvalState must be provided to the observer.
// This decorator is not thread-safe, and the EvalState must be reset between Eval()
// calls.
func EvalStateObserver(state EvalState) EvalObserver {
return func(id int64, programStep interface{}, val ref.Val) {
state.SetValue(id, val)
}
return decObserveEval(observer)
}
// TODO: Replace all usages of ExhaustiveEval with ExhaustiveEvalWrapper
// ExhaustiveEval replaces operations that short-circuit with versions that evaluate
// expressions and couples this behavior with the TrackState() decorator to provide
// insight into the evaluation state of the entire expression. EvalState must be
// provided to the decorator. This decorator is not thread-safe, and the EvalState
// must be reset between Eval() calls.
func ExhaustiveEval(state EvalState) InterpretableDecorator {
func ExhaustiveEval() InterpretableDecorator {
ex := decDisableShortcircuits()
obs := TrackState(state)
return func(i Interpretable) (Interpretable, error) {
var err error
i, err = ex(i)
if err != nil {
return nil, err
}
return obs(i)
return ex(i)
}
}
func InterruptableEval() InterpretableDecorator {
return decInterruptFolds()
}
// Optimize will pre-compute operations such as list and map construction and optimize
// call arguments to set membership tests. The set of optimizations will increase over time.
func Optimize() InterpretableDecorator {
return decOptimize()
}
// RegexOptimization provides a way to replace an InterpretableCall for a regex function when the
// RegexIndex argument is a string constant. Typically, the Factory would compile the regex pattern at
// RegexIndex and report any errors (at program creation time) and then use the compiled regex for
// all regex function invocations.
type RegexOptimization struct {
// Function is the name of the function to optimize.
Function string
// OverloadID is the ID of the overload to optimize.
OverloadID string
// RegexIndex is the index position of the regex pattern argument. Only calls to the function where this argument is
// a string constant will be delegated to this optimizer.
RegexIndex int
// Factory constructs a replacement InterpretableCall node that optimizes the regex function call. Factory is
// provided with the unoptimized regex call and the string constant at the RegexIndex argument.
// The Factory may compile the regex for use across all invocations of the call, return any errors and
// return an interpreter.NewCall with the desired regex optimized function impl.
Factory func(call InterpretableCall, regexPattern string) (InterpretableCall, error)
}
// CompileRegexConstants compiles regex pattern string constants at program creation time and reports any regex pattern
// compile errors.
func CompileRegexConstants(regexOptimizations ...*RegexOptimization) InterpretableDecorator {
return decRegexOptimizer(regexOptimizations...)
}
type exprInterpreter struct {
dispatcher Dispatcher
container *containers.Container

View File

@ -0,0 +1,46 @@
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package interpreter
import (
"regexp"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
)
// MatchesRegexOptimization optimizes the 'matches' standard library function by compiling the regex pattern and
// reporting any compilation errors at program creation time, and using the compiled regex pattern for all function
// call invocations.
var MatchesRegexOptimization = &RegexOptimization{
Function: "matches",
RegexIndex: 1,
Factory: func(call InterpretableCall, regexPattern string) (InterpretableCall, error) {
compiledRegex, err := regexp.Compile(regexPattern)
if err != nil {
return nil, err
}
return NewCall(call.ID(), call.Function(), call.OverloadID(), call.Args(), func(values ...ref.Val) ref.Val {
if len(values) != 2 {
return types.NoSuchOverloadErr()
}
in, ok := values[0].Value().(string)
if !ok {
return types.NoSuchOverloadErr()
}
return types.Bool(compiledRegex.MatchString(in))
}), nil
},
}

View File

@ -617,6 +617,7 @@ func (p *planner) planComprehension(expr *exprpb.Expr) (Interpretable, error) {
cond: cond,
step: step,
result: result,
adapter: p.adapter,
}, nil
}

View File

@ -0,0 +1,192 @@
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package interpreter provides functions to evaluate parsed expressions with
// the option to augment the evaluation with inputs and functions supplied at
// evaluation time.
package interpreter
import (
"math"
"github.com/google/cel-go/common"
"github.com/google/cel-go/common/overloads"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
)
// WARNING: Any changes to cost calculations in this file require a corresponding change in checker/cost.go
// ActualCostEstimator provides function call cost estimations at runtime
// CallCost returns an estimated cost for the function overload invocation with the given args, or nil if it has no
// estimate to provide. CEL attempts to provide reasonable estimates for its standard function library, so CallCost
// should typically not need to provide an estimate for CELs standard function.
type ActualCostEstimator interface {
CallCost(overloadId string, args []ref.Val) *uint64
}
// CostObserver provides an observer that tracks runtime cost.
func CostObserver(tracker *CostTracker) EvalObserver {
observer := func(id int64, programStep interface{}, val ref.Val) {
switch t := programStep.(type) {
case ConstantQualifier:
// TODO: Push identifiers on to the stack before observing constant qualifiers that apply to them
// and enable the below pop. Once enabled this can case can be collapsed into the Qualifier case.
//tracker.stack.pop(1)
tracker.cost += 1
case InterpretableConst:
// zero cost
case InterpretableAttribute:
// Ternary has no direct cost. All cost is from the conditional and the true/false branch expressions.
_, isConditional := t.Attr().(*conditionalAttribute)
if !isConditional {
tracker.cost += common.SelectAndIdentCost
}
case *evalExhaustiveConditional, *evalOr, *evalAnd, *evalExhaustiveOr, *evalExhaustiveAnd:
// Ternary has no direct cost. All cost is from the conditional and the true/false branch expressions.
case Qualifier:
tracker.stack.pop(1)
tracker.cost += 1
case InterpretableCall:
if argVals, ok := tracker.stack.pop(len(t.Args())); ok {
tracker.cost += tracker.costCall(t, argVals)
}
case InterpretableConstructor:
switch t.Type() {
case types.ListType:
tracker.cost += common.ListCreateBaseCost
case types.MapType:
tracker.cost += common.MapCreateBaseCost
default:
tracker.cost += common.StructCreateBaseCost
}
}
tracker.stack.push(val)
if tracker.Limit != nil && tracker.cost > *tracker.Limit {
panic(EvalCancelledError{Cause: CostLimitExceeded, Message: "operation cancelled: actual cost limit exceeded"})
}
}
return observer
}
// CostTracker represents the information needed for tacking runtime cost
type CostTracker struct {
Estimator ActualCostEstimator
Limit *uint64
cost uint64
stack refValStack
}
// ActualCost returns the runtime cost
func (c CostTracker) ActualCost() uint64 {
return c.cost
}
func (c CostTracker) costCall(call InterpretableCall, argValues []ref.Val) uint64 {
var cost uint64
if c.Estimator != nil {
callCost := c.Estimator.CallCost(call.OverloadID(), argValues)
if callCost != nil {
cost += *callCost
return cost
}
}
// if user didn't specify, the default way of calculating runtime cost would be used.
// if user has their own implementation of ActualCostEstimator, make sure to cover the mapping between overloadId and cost calculation
switch call.OverloadID() {
// O(n) functions
case overloads.StartsWithString, overloads.EndsWithString, overloads.StringToBytes, overloads.BytesToString:
cost += uint64(math.Ceil(float64(c.actualSize(argValues[0])) * common.StringTraversalCostFactor))
case overloads.InList:
// If a list is composed entirely of constant values this is O(1), but we don't account for that here.
// We just assume all list containment checks are O(n).
cost += c.actualSize(argValues[1])
// O(min(m, n)) functions
case overloads.LessString, overloads.GreaterString, overloads.LessEqualsString, overloads.GreaterEqualsString,
overloads.LessBytes, overloads.GreaterBytes, overloads.LessEqualsBytes, overloads.GreaterEqualsBytes,
overloads.Equals, overloads.NotEquals:
// When we check the equality of 2 scalar values (e.g. 2 integers, 2 floating-point numbers, 2 booleans etc.),
// the CostTracker.actualSize() function by definition returns 1 for each operand, resulting in an overall cost
// of 1.
lhsSize := c.actualSize(argValues[0])
rhsSize := c.actualSize(argValues[1])
minSize := lhsSize
if rhsSize < minSize {
minSize = rhsSize
}
cost += uint64(math.Ceil(float64(minSize) * common.StringTraversalCostFactor))
// O(m+n) functions
case overloads.AddString, overloads.AddBytes:
// In the worst case scenario, we would need to reallocate a new backing store and copy both operands over.
cost += uint64(math.Ceil(float64(c.actualSize(argValues[0])+c.actualSize(argValues[1])) * common.StringTraversalCostFactor))
// O(nm) functions
case overloads.MatchesString:
// https://swtch.com/~rsc/regexp/regexp1.html applies to RE2 implementation supported by CEL
// Add one to string length for purposes of cost calculation to prevent product of string and regex to be 0
// in case where string is empty but regex is still expensive.
strCost := uint64(math.Ceil((1.0 + float64(c.actualSize(argValues[0]))) * common.StringTraversalCostFactor))
// We don't know how many expressions are in the regex, just the string length (a huge
// improvement here would be to somehow get a count the number of expressions in the regex or
// how many states are in the regex state machine and use that to measure regex cost).
// For now, we're making a guess that each expression in a regex is typically at least 4 chars
// in length.
regexCost := uint64(math.Ceil(float64(c.actualSize(argValues[1])) * common.RegexStringLengthCostFactor))
cost += strCost * regexCost
case overloads.ContainsString:
strCost := uint64(math.Ceil(float64(c.actualSize(argValues[0])) * common.StringTraversalCostFactor))
substrCost := uint64(math.Ceil(float64(c.actualSize(argValues[1])) * common.StringTraversalCostFactor))
cost += strCost * substrCost
default:
// The following operations are assumed to have O(1) complexity.
// - AddList due to the implementation. Index lookup can be O(c) the
// number of concatenated lists, but we don't track that is cost calculations.
// - Conversions, since none perform a traversal of a type of unbound length.
// - Computing the size of strings, byte sequences, lists and maps.
// - Logical operations and all operators on fixed width scalars (comparisons, equality)
// - Any functions that don't have a declared cost either here or in provided ActualCostEstimator.
cost += 1
}
return cost
}
// actualSize returns the size of value
func (c CostTracker) actualSize(value ref.Val) uint64 {
if sz, ok := value.(traits.Sizer); ok {
return uint64(sz.Size().(types.Int))
}
return 1
}
// refValStack keeps track of values of the stack for cost calculation purposes
type refValStack []ref.Val
func (s *refValStack) push(value ref.Val) {
*s = append(*s, value)
}
func (s *refValStack) pop(count int) ([]ref.Val, bool) {
if len(*s) < count {
return nil, false
}
idx := len(*s) - count
el := (*s)[idx:]
*s = (*s)[:idx]
return el, true
}

View File

@ -83,9 +83,9 @@ literal
| sign=MINUS? tok=NUM_FLOAT # Double
| tok=STRING # String
| tok=BYTES # Bytes
| tok='true' # BoolTrue
| tok='false' # BoolFalse
| tok='null' # Null
| tok=CEL_TRUE # BoolTrue
| tok=CEL_FALSE # BoolFalse
| tok=NUL # Null
;
// Lexer Rules
@ -117,9 +117,9 @@ PLUS : '+';
STAR : '*';
SLASH : '/';
PERCENT : '%';
TRUE : 'true';
FALSE : 'false';
NULL : 'null';
CEL_TRUE : 'true';
CEL_FALSE : 'false';
NUL : 'null';
fragment BACKSLASH : '\\';
fragment LETTER : 'A'..'Z' | 'a'..'z' ;

View File

@ -23,9 +23,9 @@ PLUS=22
STAR=23
SLASH=24
PERCENT=25
TRUE=26
FALSE=27
NULL=28
CEL_TRUE=26
CEL_FALSE=27
NUL=28
WHITESPACE=29
COMMENT=30
NUM_FLOAT=31

View File

@ -23,9 +23,9 @@ PLUS=22
STAR=23
SLASH=24
PERCENT=25
TRUE=26
FALSE=27
NULL=28
CEL_TRUE=26
CEL_FALSE=27
NUL=28
WHITESPACE=29
COMMENT=30
NUM_FLOAT=31

View File

@ -1,4 +1,4 @@
// Generated from /Users/tswadell/go/src/github.com/google/cel-go/bin/../parser/gen/CEL.g4 by ANTLR 4.7.
// Code generated from /Users/tswadell/go/src/github.com/google/cel-go/bin/../parser/gen/CEL.g4 by ANTLR 4.9.1. DO NOT EDIT.
package gen // CEL
import "github.com/antlr/antlr4/runtime/Go/antlr"

View File

@ -1,4 +1,4 @@
// Generated from /Users/tswadell/go/src/github.com/google/cel-go/bin/../parser/gen/CEL.g4 by ANTLR 4.7.
// Code generated from /Users/tswadell/go/src/github.com/google/cel-go/bin/../parser/gen/CEL.g4 by ANTLR 4.9.1. DO NOT EDIT.
package gen // CEL
import "github.com/antlr/antlr4/runtime/Go/antlr"

View File

@ -1,4 +1,4 @@
// Generated from /Users/tswadell/go/src/github.com/google/cel-go/bin/../parser/gen/CEL.g4 by ANTLR 4.7.
// Code generated from /Users/tswadell/go/src/github.com/google/cel-go/bin/../parser/gen/CEL.g4 by ANTLR 4.9.1. DO NOT EDIT.
package gen
@ -233,18 +233,19 @@ var lexerSymbolicNames = []string{
"", "EQUALS", "NOT_EQUALS", "IN", "LESS", "LESS_EQUALS", "GREATER_EQUALS",
"GREATER", "LOGICAL_AND", "LOGICAL_OR", "LBRACKET", "RPRACKET", "LBRACE",
"RBRACE", "LPAREN", "RPAREN", "DOT", "COMMA", "MINUS", "EXCLAM", "QUESTIONMARK",
"COLON", "PLUS", "STAR", "SLASH", "PERCENT", "TRUE", "FALSE", "NULL", "WHITESPACE",
"COMMENT", "NUM_FLOAT", "NUM_INT", "NUM_UINT", "STRING", "BYTES", "IDENTIFIER",
"COLON", "PLUS", "STAR", "SLASH", "PERCENT", "CEL_TRUE", "CEL_FALSE", "NUL",
"WHITESPACE", "COMMENT", "NUM_FLOAT", "NUM_INT", "NUM_UINT", "STRING",
"BYTES", "IDENTIFIER",
}
var lexerRuleNames = []string{
"EQUALS", "NOT_EQUALS", "IN", "LESS", "LESS_EQUALS", "GREATER_EQUALS",
"GREATER", "LOGICAL_AND", "LOGICAL_OR", "LBRACKET", "RPRACKET", "LBRACE",
"RBRACE", "LPAREN", "RPAREN", "DOT", "COMMA", "MINUS", "EXCLAM", "QUESTIONMARK",
"COLON", "PLUS", "STAR", "SLASH", "PERCENT", "TRUE", "FALSE", "NULL", "BACKSLASH",
"LETTER", "DIGIT", "EXPONENT", "HEXDIGIT", "RAW", "ESC_SEQ", "ESC_CHAR_SEQ",
"ESC_OCT_SEQ", "ESC_BYTE_SEQ", "ESC_UNI_SEQ", "WHITESPACE", "COMMENT",
"NUM_FLOAT", "NUM_INT", "NUM_UINT", "STRING", "BYTES", "IDENTIFIER",
"COLON", "PLUS", "STAR", "SLASH", "PERCENT", "CEL_TRUE", "CEL_FALSE", "NUL",
"BACKSLASH", "LETTER", "DIGIT", "EXPONENT", "HEXDIGIT", "RAW", "ESC_SEQ",
"ESC_CHAR_SEQ", "ESC_OCT_SEQ", "ESC_BYTE_SEQ", "ESC_UNI_SEQ", "WHITESPACE",
"COMMENT", "NUM_FLOAT", "NUM_INT", "NUM_UINT", "STRING", "BYTES", "IDENTIFIER",
}
type CELLexer struct {
@ -254,8 +255,13 @@ type CELLexer struct {
// TODO: EOF string
}
// NewCELLexer produces a new lexer instance for the optional input antlr.CharStream.
//
// The *CELLexer instance produced may be reused by calling the SetInputStream method.
// The initial lexer configuration is expensive to construct, and the object is not thread-safe;
// however, if used within a Golang sync.Pool, the construction cost amortizes well and the
// objects can be used in a thread-safe manner.
func NewCELLexer(input antlr.CharStream) *CELLexer {
l := new(CELLexer)
lexerDeserializer := antlr.NewATNDeserializer(nil)
lexerAtn := lexerDeserializer.DeserializeFromUInt16(serializedLexerAtn)
@ -263,7 +269,6 @@ func NewCELLexer(input antlr.CharStream) *CELLexer {
for index, ds := range lexerAtn.DecisionToState {
lexerDecisionToDFA[index] = antlr.NewDFA(ds, index)
}
l.BaseLexer = antlr.NewBaseLexer(input)
l.Interpreter = antlr.NewLexerATNSimulator(l, lexerAtn, lexerDecisionToDFA, antlr.NewPredictionContextCache())
@ -305,9 +310,9 @@ const (
CELLexerSTAR = 23
CELLexerSLASH = 24
CELLexerPERCENT = 25
CELLexerTRUE = 26
CELLexerFALSE = 27
CELLexerNULL = 28
CELLexerCEL_TRUE = 26
CELLexerCEL_FALSE = 27
CELLexerNUL = 28
CELLexerWHITESPACE = 29
CELLexerCOMMENT = 30
CELLexerNUM_FLOAT = 31

View File

@ -1,4 +1,4 @@
// Generated from /Users/tswadell/go/src/github.com/google/cel-go/bin/../parser/gen/CEL.g4 by ANTLR 4.7.
// Code generated from /Users/tswadell/go/src/github.com/google/cel-go/bin/../parser/gen/CEL.g4 by ANTLR 4.9.1. DO NOT EDIT.
package gen // CEL
import "github.com/antlr/antlr4/runtime/Go/antlr"

View File

@ -1,4 +1,4 @@
// Generated from /Users/tswadell/go/src/github.com/google/cel-go/bin/../parser/gen/CEL.g4 by ANTLR 4.7.
// Code generated from /Users/tswadell/go/src/github.com/google/cel-go/bin/../parser/gen/CEL.g4 by ANTLR 4.9.1. DO NOT EDIT.
package gen // CEL
import (
@ -112,7 +112,6 @@ var parserATN = []uint16{
2, 2, 31, 37, 44, 52, 63, 75, 77, 84, 90, 93, 103, 106, 116, 119, 122,
124, 128, 133, 136, 144, 147, 152, 155, 159, 166, 178, 191, 195, 200, 208,
}
var literalNames = []string{
"", "'=='", "'!='", "'in'", "'<'", "'<='", "'>='", "'>'", "'&&'", "'||'",
"'['", "']'", "'{'", "'}'", "'('", "')'", "'.'", "','", "'-'", "'!'", "'?'",
@ -122,8 +121,9 @@ var symbolicNames = []string{
"", "EQUALS", "NOT_EQUALS", "IN", "LESS", "LESS_EQUALS", "GREATER_EQUALS",
"GREATER", "LOGICAL_AND", "LOGICAL_OR", "LBRACKET", "RPRACKET", "LBRACE",
"RBRACE", "LPAREN", "RPAREN", "DOT", "COMMA", "MINUS", "EXCLAM", "QUESTIONMARK",
"COLON", "PLUS", "STAR", "SLASH", "PERCENT", "TRUE", "FALSE", "NULL", "WHITESPACE",
"COMMENT", "NUM_FLOAT", "NUM_INT", "NUM_UINT", "STRING", "BYTES", "IDENTIFIER",
"COLON", "PLUS", "STAR", "SLASH", "PERCENT", "CEL_TRUE", "CEL_FALSE", "NUL",
"WHITESPACE", "COMMENT", "NUM_FLOAT", "NUM_INT", "NUM_UINT", "STRING",
"BYTES", "IDENTIFIER",
}
var ruleNames = []string{
@ -136,6 +136,12 @@ type CELParser struct {
*antlr.BaseParser
}
// NewCELParser produces a new parser instance for the optional input antlr.TokenStream.
//
// The *CELParser instance produced may be reused by calling the SetInputStream method.
// The initial parser configuration is expensive to construct, and the object is not thread-safe;
// however, if used within a Golang sync.Pool, the construction cost amortizes well and the
// objects can be used in a thread-safe manner.
func NewCELParser(input antlr.TokenStream) *CELParser {
this := new(CELParser)
deserializer := antlr.NewATNDeserializer(nil)
@ -144,7 +150,6 @@ func NewCELParser(input antlr.TokenStream) *CELParser {
for index, ds := range deserializedATN.DecisionToState {
decisionToDFA[index] = antlr.NewDFA(ds, index)
}
this.BaseParser = antlr.NewBaseParser(input)
this.Interpreter = antlr.NewParserATNSimulator(this, deserializedATN, decisionToDFA, antlr.NewPredictionContextCache())
@ -184,9 +189,9 @@ const (
CELParserSTAR = 23
CELParserSLASH = 24
CELParserPERCENT = 25
CELParserTRUE = 26
CELParserFALSE = 27
CELParserNULL = 28
CELParserCEL_TRUE = 26
CELParserCEL_FALSE = 27
CELParserNUL = 28
CELParserWHITESPACE = 29
CELParserCOMMENT = 30
CELParserNUM_FLOAT = 31
@ -448,6 +453,14 @@ func (s *ExprContext) ConditionalOr(i int) IConditionalOrContext {
return t.(IConditionalOrContext)
}
func (s *ExprContext) COLON() antlr.TerminalNode {
return s.GetToken(CELParserCOLON, 0)
}
func (s *ExprContext) QUESTIONMARK() antlr.TerminalNode {
return s.GetToken(CELParserQUESTIONMARK, 0)
}
func (s *ExprContext) Expr() IExprContext {
var t = s.GetTypedRuleContext(reflect.TypeOf((*IExprContext)(nil)).Elem(), 0)
@ -669,6 +682,14 @@ func (s *ConditionalOrContext) ConditionalAnd(i int) IConditionalAndContext {
return t.(IConditionalAndContext)
}
func (s *ConditionalOrContext) AllLOGICAL_OR() []antlr.TerminalNode {
return s.GetTokens(CELParserLOGICAL_OR)
}
func (s *ConditionalOrContext) LOGICAL_OR(i int) antlr.TerminalNode {
return s.GetToken(CELParserLOGICAL_OR, i)
}
func (s *ConditionalOrContext) GetRuleContext() antlr.RuleContext {
return s
}
@ -874,6 +895,14 @@ func (s *ConditionalAndContext) Relation(i int) IRelationContext {
return t.(IRelationContext)
}
func (s *ConditionalAndContext) AllLOGICAL_AND() []antlr.TerminalNode {
return s.GetTokens(CELParserLOGICAL_AND)
}
func (s *ConditionalAndContext) LOGICAL_AND(i int) antlr.TerminalNode {
return s.GetToken(CELParserLOGICAL_AND, i)
}
func (s *ConditionalAndContext) GetRuleContext() antlr.RuleContext {
return s
}
@ -1045,6 +1074,34 @@ func (s *RelationContext) Relation(i int) IRelationContext {
return t.(IRelationContext)
}
func (s *RelationContext) LESS() antlr.TerminalNode {
return s.GetToken(CELParserLESS, 0)
}
func (s *RelationContext) LESS_EQUALS() antlr.TerminalNode {
return s.GetToken(CELParserLESS_EQUALS, 0)
}
func (s *RelationContext) GREATER_EQUALS() antlr.TerminalNode {
return s.GetToken(CELParserGREATER_EQUALS, 0)
}
func (s *RelationContext) GREATER() antlr.TerminalNode {
return s.GetToken(CELParserGREATER, 0)
}
func (s *RelationContext) EQUALS() antlr.TerminalNode {
return s.GetToken(CELParserEQUALS, 0)
}
func (s *RelationContext) NOT_EQUALS() antlr.TerminalNode {
return s.GetToken(CELParserNOT_EQUALS, 0)
}
func (s *RelationContext) IN() antlr.TerminalNode {
return s.GetToken(CELParserIN, 0)
}
func (s *RelationContext) GetRuleContext() antlr.RuleContext {
return s
}
@ -1131,21 +1188,23 @@ func (p *CELParser) relation(_p int) (localctx IRelationContext) {
if !(p.Precpred(p.GetParserRuleContext(), 1)) {
panic(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 1)", ""))
}
p.SetState(57)
{
p.SetState(57)
var _lt = p.GetTokenStream().LT(1)
var _lt = p.GetTokenStream().LT(1)
localctx.(*RelationContext).op = _lt
localctx.(*RelationContext).op = _lt
_la = p.GetTokenStream().LA(1)
_la = p.GetTokenStream().LA(1)
if !(((_la)&-(0x1f+1)) == 0 && ((1<<uint(_la))&((1<<CELParserEQUALS)|(1<<CELParserNOT_EQUALS)|(1<<CELParserIN)|(1<<CELParserLESS)|(1<<CELParserLESS_EQUALS)|(1<<CELParserGREATER_EQUALS)|(1<<CELParserGREATER))) != 0) {
var _ri = p.GetErrorHandler().RecoverInline(p)
if !(((_la)&-(0x1f+1)) == 0 && ((1<<uint(_la))&((1<<CELParserEQUALS)|(1<<CELParserNOT_EQUALS)|(1<<CELParserIN)|(1<<CELParserLESS)|(1<<CELParserLESS_EQUALS)|(1<<CELParserGREATER_EQUALS)|(1<<CELParserGREATER))) != 0) {
var _ri = p.GetErrorHandler().RecoverInline(p)
localctx.(*RelationContext).op = _ri
} else {
p.GetErrorHandler().ReportMatch(p)
p.Consume()
localctx.(*RelationContext).op = _ri
} else {
p.GetErrorHandler().ReportMatch(p)
p.Consume()
}
}
{
p.SetState(58)
@ -1243,6 +1302,26 @@ func (s *CalcContext) Calc(i int) ICalcContext {
return t.(ICalcContext)
}
func (s *CalcContext) STAR() antlr.TerminalNode {
return s.GetToken(CELParserSTAR, 0)
}
func (s *CalcContext) SLASH() antlr.TerminalNode {
return s.GetToken(CELParserSLASH, 0)
}
func (s *CalcContext) PERCENT() antlr.TerminalNode {
return s.GetToken(CELParserPERCENT, 0)
}
func (s *CalcContext) PLUS() antlr.TerminalNode {
return s.GetToken(CELParserPLUS, 0)
}
func (s *CalcContext) MINUS() antlr.TerminalNode {
return s.GetToken(CELParserMINUS, 0)
}
func (s *CalcContext) GetRuleContext() antlr.RuleContext {
return s
}
@ -1333,21 +1412,23 @@ func (p *CELParser) calc(_p int) (localctx ICalcContext) {
if !(p.Precpred(p.GetParserRuleContext(), 2)) {
panic(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 2)", ""))
}
p.SetState(68)
{
p.SetState(68)
var _lt = p.GetTokenStream().LT(1)
var _lt = p.GetTokenStream().LT(1)
localctx.(*CalcContext).op = _lt
localctx.(*CalcContext).op = _lt
_la = p.GetTokenStream().LA(1)
_la = p.GetTokenStream().LA(1)
if !(((_la)&-(0x1f+1)) == 0 && ((1<<uint(_la))&((1<<CELParserSTAR)|(1<<CELParserSLASH)|(1<<CELParserPERCENT))) != 0) {
var _ri = p.GetErrorHandler().RecoverInline(p)
if !(((_la)&-(0x1f+1)) == 0 && ((1<<uint(_la))&((1<<CELParserSTAR)|(1<<CELParserSLASH)|(1<<CELParserPERCENT))) != 0) {
var _ri = p.GetErrorHandler().RecoverInline(p)
localctx.(*CalcContext).op = _ri
} else {
p.GetErrorHandler().ReportMatch(p)
p.Consume()
localctx.(*CalcContext).op = _ri
} else {
p.GetErrorHandler().ReportMatch(p)
p.Consume()
}
}
{
p.SetState(69)
@ -1362,21 +1443,23 @@ func (p *CELParser) calc(_p int) (localctx ICalcContext) {
if !(p.Precpred(p.GetParserRuleContext(), 1)) {
panic(antlr.NewFailedPredicateException(p, "p.Precpred(p.GetParserRuleContext(), 1)", ""))
}
p.SetState(71)
{
p.SetState(71)
var _lt = p.GetTokenStream().LT(1)
var _lt = p.GetTokenStream().LT(1)
localctx.(*CalcContext).op = _lt
localctx.(*CalcContext).op = _lt
_la = p.GetTokenStream().LA(1)
_la = p.GetTokenStream().LA(1)
if !(_la == CELParserMINUS || _la == CELParserPLUS) {
var _ri = p.GetErrorHandler().RecoverInline(p)
if !(_la == CELParserMINUS || _la == CELParserPLUS) {
var _ri = p.GetErrorHandler().RecoverInline(p)
localctx.(*CalcContext).op = _ri
} else {
p.GetErrorHandler().ReportMatch(p)
p.Consume()
localctx.(*CalcContext).op = _ri
} else {
p.GetErrorHandler().ReportMatch(p)
p.Consume()
}
}
{
p.SetState(72)
@ -1482,6 +1565,14 @@ func (s *LogicalNotContext) Member() IMemberContext {
return t.(IMemberContext)
}
func (s *LogicalNotContext) AllEXCLAM() []antlr.TerminalNode {
return s.GetTokens(CELParserEXCLAM)
}
func (s *LogicalNotContext) EXCLAM(i int) antlr.TerminalNode {
return s.GetToken(CELParserEXCLAM, i)
}
func (s *LogicalNotContext) EnterRule(listener antlr.ParseTreeListener) {
if listenerT, ok := listener.(CELListener); ok {
listenerT.EnterLogicalNot(s)
@ -1592,6 +1683,14 @@ func (s *NegateContext) Member() IMemberContext {
return t.(IMemberContext)
}
func (s *NegateContext) AllMINUS() []antlr.TerminalNode {
return s.GetTokens(CELParserMINUS)
}
func (s *NegateContext) MINUS(i int) antlr.TerminalNode {
return s.GetToken(CELParserMINUS, i)
}
func (s *NegateContext) EnterRule(listener antlr.ParseTreeListener) {
if listenerT, ok := listener.(CELListener); ok {
listenerT.EnterNegate(s)
@ -1808,10 +1907,22 @@ func (s *SelectOrCallContext) Member() IMemberContext {
return t.(IMemberContext)
}
func (s *SelectOrCallContext) DOT() antlr.TerminalNode {
return s.GetToken(CELParserDOT, 0)
}
func (s *SelectOrCallContext) IDENTIFIER() antlr.TerminalNode {
return s.GetToken(CELParserIDENTIFIER, 0)
}
func (s *SelectOrCallContext) RPAREN() antlr.TerminalNode {
return s.GetToken(CELParserRPAREN, 0)
}
func (s *SelectOrCallContext) LPAREN() antlr.TerminalNode {
return s.GetToken(CELParserLPAREN, 0)
}
func (s *SelectOrCallContext) ExprList() IExprListContext {
var t = s.GetTypedRuleContext(reflect.TypeOf((*IExprListContext)(nil)).Elem(), 0)
@ -1932,6 +2043,14 @@ func (s *IndexContext) Member() IMemberContext {
return t.(IMemberContext)
}
func (s *IndexContext) RPRACKET() antlr.TerminalNode {
return s.GetToken(CELParserRPRACKET, 0)
}
func (s *IndexContext) LBRACKET() antlr.TerminalNode {
return s.GetToken(CELParserLBRACKET, 0)
}
func (s *IndexContext) Expr() IExprContext {
var t = s.GetTypedRuleContext(reflect.TypeOf((*IExprContext)(nil)).Elem(), 0)
@ -2002,6 +2121,18 @@ func (s *CreateMessageContext) Member() IMemberContext {
return t.(IMemberContext)
}
func (s *CreateMessageContext) RBRACE() antlr.TerminalNode {
return s.GetToken(CELParserRBRACE, 0)
}
func (s *CreateMessageContext) LBRACE() antlr.TerminalNode {
return s.GetToken(CELParserLBRACE, 0)
}
func (s *CreateMessageContext) COMMA() antlr.TerminalNode {
return s.GetToken(CELParserCOMMA, 0)
}
func (s *CreateMessageContext) FieldInitializerList() IFieldInitializerListContext {
var t = s.GetTypedRuleContext(reflect.TypeOf((*IFieldInitializerListContext)(nil)).Elem(), 0)
@ -2127,7 +2258,7 @@ func (p *CELParser) member(_p int) (localctx IMemberContext) {
p.GetErrorHandler().Sync(p)
_la = p.GetTokenStream().LA(1)
if ((_la-10)&-(0x1f+1)) == 0 && ((1<<uint((_la-10)))&((1<<(CELParserLBRACKET-10))|(1<<(CELParserLBRACE-10))|(1<<(CELParserLPAREN-10))|(1<<(CELParserDOT-10))|(1<<(CELParserMINUS-10))|(1<<(CELParserEXCLAM-10))|(1<<(CELParserTRUE-10))|(1<<(CELParserFALSE-10))|(1<<(CELParserNULL-10))|(1<<(CELParserNUM_FLOAT-10))|(1<<(CELParserNUM_INT-10))|(1<<(CELParserNUM_UINT-10))|(1<<(CELParserSTRING-10))|(1<<(CELParserBYTES-10))|(1<<(CELParserIDENTIFIER-10)))) != 0 {
if ((_la-10)&-(0x1f+1)) == 0 && ((1<<uint((_la-10)))&((1<<(CELParserLBRACKET-10))|(1<<(CELParserLBRACE-10))|(1<<(CELParserLPAREN-10))|(1<<(CELParserDOT-10))|(1<<(CELParserMINUS-10))|(1<<(CELParserEXCLAM-10))|(1<<(CELParserCEL_TRUE-10))|(1<<(CELParserCEL_FALSE-10))|(1<<(CELParserNUL-10))|(1<<(CELParserNUM_FLOAT-10))|(1<<(CELParserNUM_INT-10))|(1<<(CELParserNUM_UINT-10))|(1<<(CELParserSTRING-10))|(1<<(CELParserBYTES-10))|(1<<(CELParserIDENTIFIER-10)))) != 0 {
{
p.SetState(100)
@ -2305,6 +2436,18 @@ func (s *CreateListContext) GetRuleContext() antlr.RuleContext {
return s
}
func (s *CreateListContext) RPRACKET() antlr.TerminalNode {
return s.GetToken(CELParserRPRACKET, 0)
}
func (s *CreateListContext) LBRACKET() antlr.TerminalNode {
return s.GetToken(CELParserLBRACKET, 0)
}
func (s *CreateListContext) COMMA() antlr.TerminalNode {
return s.GetToken(CELParserCOMMA, 0)
}
func (s *CreateListContext) ExprList() IExprListContext {
var t = s.GetTypedRuleContext(reflect.TypeOf((*IExprListContext)(nil)).Elem(), 0)
@ -2365,6 +2508,18 @@ func (s *CreateStructContext) GetRuleContext() antlr.RuleContext {
return s
}
func (s *CreateStructContext) RBRACE() antlr.TerminalNode {
return s.GetToken(CELParserRBRACE, 0)
}
func (s *CreateStructContext) LBRACE() antlr.TerminalNode {
return s.GetToken(CELParserLBRACE, 0)
}
func (s *CreateStructContext) COMMA() antlr.TerminalNode {
return s.GetToken(CELParserCOMMA, 0)
}
func (s *CreateStructContext) MapInitializerList() IMapInitializerListContext {
var t = s.GetTypedRuleContext(reflect.TypeOf((*IMapInitializerListContext)(nil)).Elem(), 0)
@ -2470,6 +2625,14 @@ func (s *NestedContext) GetRuleContext() antlr.RuleContext {
return s
}
func (s *NestedContext) LPAREN() antlr.TerminalNode {
return s.GetToken(CELParserLPAREN, 0)
}
func (s *NestedContext) RPAREN() antlr.TerminalNode {
return s.GetToken(CELParserRPAREN, 0)
}
func (s *NestedContext) Expr() IExprContext {
var t = s.GetTypedRuleContext(reflect.TypeOf((*IExprContext)(nil)).Elem(), 0)
@ -2544,6 +2707,18 @@ func (s *IdentOrGlobalCallContext) IDENTIFIER() antlr.TerminalNode {
return s.GetToken(CELParserIDENTIFIER, 0)
}
func (s *IdentOrGlobalCallContext) RPAREN() antlr.TerminalNode {
return s.GetToken(CELParserRPAREN, 0)
}
func (s *IdentOrGlobalCallContext) DOT() antlr.TerminalNode {
return s.GetToken(CELParserDOT, 0)
}
func (s *IdentOrGlobalCallContext) LPAREN() antlr.TerminalNode {
return s.GetToken(CELParserLPAREN, 0)
}
func (s *IdentOrGlobalCallContext) ExprList() IExprListContext {
var t = s.GetTypedRuleContext(reflect.TypeOf((*IExprListContext)(nil)).Elem(), 0)
@ -2640,7 +2815,7 @@ func (p *CELParser) Primary() (localctx IPrimaryContext) {
p.GetErrorHandler().Sync(p)
_la = p.GetTokenStream().LA(1)
if ((_la-10)&-(0x1f+1)) == 0 && ((1<<uint((_la-10)))&((1<<(CELParserLBRACKET-10))|(1<<(CELParserLBRACE-10))|(1<<(CELParserLPAREN-10))|(1<<(CELParserDOT-10))|(1<<(CELParserMINUS-10))|(1<<(CELParserEXCLAM-10))|(1<<(CELParserTRUE-10))|(1<<(CELParserFALSE-10))|(1<<(CELParserNULL-10))|(1<<(CELParserNUM_FLOAT-10))|(1<<(CELParserNUM_INT-10))|(1<<(CELParserNUM_UINT-10))|(1<<(CELParserSTRING-10))|(1<<(CELParserBYTES-10))|(1<<(CELParserIDENTIFIER-10)))) != 0 {
if ((_la-10)&-(0x1f+1)) == 0 && ((1<<uint((_la-10)))&((1<<(CELParserLBRACKET-10))|(1<<(CELParserLBRACE-10))|(1<<(CELParserLPAREN-10))|(1<<(CELParserDOT-10))|(1<<(CELParserMINUS-10))|(1<<(CELParserEXCLAM-10))|(1<<(CELParserCEL_TRUE-10))|(1<<(CELParserCEL_FALSE-10))|(1<<(CELParserNUL-10))|(1<<(CELParserNUM_FLOAT-10))|(1<<(CELParserNUM_INT-10))|(1<<(CELParserNUM_UINT-10))|(1<<(CELParserSTRING-10))|(1<<(CELParserBYTES-10))|(1<<(CELParserIDENTIFIER-10)))) != 0 {
{
p.SetState(130)
@ -2690,7 +2865,7 @@ func (p *CELParser) Primary() (localctx IPrimaryContext) {
p.GetErrorHandler().Sync(p)
_la = p.GetTokenStream().LA(1)
if ((_la-10)&-(0x1f+1)) == 0 && ((1<<uint((_la-10)))&((1<<(CELParserLBRACKET-10))|(1<<(CELParserLBRACE-10))|(1<<(CELParserLPAREN-10))|(1<<(CELParserDOT-10))|(1<<(CELParserMINUS-10))|(1<<(CELParserEXCLAM-10))|(1<<(CELParserTRUE-10))|(1<<(CELParserFALSE-10))|(1<<(CELParserNULL-10))|(1<<(CELParserNUM_FLOAT-10))|(1<<(CELParserNUM_INT-10))|(1<<(CELParserNUM_UINT-10))|(1<<(CELParserSTRING-10))|(1<<(CELParserBYTES-10))|(1<<(CELParserIDENTIFIER-10)))) != 0 {
if ((_la-10)&-(0x1f+1)) == 0 && ((1<<uint((_la-10)))&((1<<(CELParserLBRACKET-10))|(1<<(CELParserLBRACE-10))|(1<<(CELParserLPAREN-10))|(1<<(CELParserDOT-10))|(1<<(CELParserMINUS-10))|(1<<(CELParserEXCLAM-10))|(1<<(CELParserCEL_TRUE-10))|(1<<(CELParserCEL_FALSE-10))|(1<<(CELParserNUL-10))|(1<<(CELParserNUM_FLOAT-10))|(1<<(CELParserNUM_INT-10))|(1<<(CELParserNUM_UINT-10))|(1<<(CELParserSTRING-10))|(1<<(CELParserBYTES-10))|(1<<(CELParserIDENTIFIER-10)))) != 0 {
{
p.SetState(141)
@ -2730,7 +2905,7 @@ func (p *CELParser) Primary() (localctx IPrimaryContext) {
p.GetErrorHandler().Sync(p)
_la = p.GetTokenStream().LA(1)
if ((_la-10)&-(0x1f+1)) == 0 && ((1<<uint((_la-10)))&((1<<(CELParserLBRACKET-10))|(1<<(CELParserLBRACE-10))|(1<<(CELParserLPAREN-10))|(1<<(CELParserDOT-10))|(1<<(CELParserMINUS-10))|(1<<(CELParserEXCLAM-10))|(1<<(CELParserTRUE-10))|(1<<(CELParserFALSE-10))|(1<<(CELParserNULL-10))|(1<<(CELParserNUM_FLOAT-10))|(1<<(CELParserNUM_INT-10))|(1<<(CELParserNUM_UINT-10))|(1<<(CELParserSTRING-10))|(1<<(CELParserBYTES-10))|(1<<(CELParserIDENTIFIER-10)))) != 0 {
if ((_la-10)&-(0x1f+1)) == 0 && ((1<<uint((_la-10)))&((1<<(CELParserLBRACKET-10))|(1<<(CELParserLBRACE-10))|(1<<(CELParserLPAREN-10))|(1<<(CELParserDOT-10))|(1<<(CELParserMINUS-10))|(1<<(CELParserEXCLAM-10))|(1<<(CELParserCEL_TRUE-10))|(1<<(CELParserCEL_FALSE-10))|(1<<(CELParserNUL-10))|(1<<(CELParserNUM_FLOAT-10))|(1<<(CELParserNUM_INT-10))|(1<<(CELParserNUM_UINT-10))|(1<<(CELParserSTRING-10))|(1<<(CELParserBYTES-10))|(1<<(CELParserIDENTIFIER-10)))) != 0 {
{
p.SetState(149)
@ -2756,7 +2931,7 @@ func (p *CELParser) Primary() (localctx IPrimaryContext) {
p.Match(CELParserRBRACE)
}
case CELParserMINUS, CELParserTRUE, CELParserFALSE, CELParserNULL, CELParserNUM_FLOAT, CELParserNUM_INT, CELParserNUM_UINT, CELParserSTRING, CELParserBYTES:
case CELParserMINUS, CELParserCEL_TRUE, CELParserCEL_FALSE, CELParserNUL, CELParserNUM_FLOAT, CELParserNUM_INT, CELParserNUM_UINT, CELParserSTRING, CELParserBYTES:
localctx = NewConstantLiteralContext(p, localctx)
p.EnterOuterAlt(localctx, 5)
{
@ -2854,6 +3029,14 @@ func (s *ExprListContext) Expr(i int) IExprContext {
return t.(IExprContext)
}
func (s *ExprListContext) AllCOMMA() []antlr.TerminalNode {
return s.GetTokens(CELParserCOMMA)
}
func (s *ExprListContext) COMMA(i int) antlr.TerminalNode {
return s.GetToken(CELParserCOMMA, i)
}
func (s *ExprListContext) GetRuleContext() antlr.RuleContext {
return s
}
@ -3055,6 +3238,14 @@ func (s *FieldInitializerListContext) IDENTIFIER(i int) antlr.TerminalNode {
return s.GetToken(CELParserIDENTIFIER, i)
}
func (s *FieldInitializerListContext) AllCOLON() []antlr.TerminalNode {
return s.GetTokens(CELParserCOLON)
}
func (s *FieldInitializerListContext) COLON(i int) antlr.TerminalNode {
return s.GetToken(CELParserCOLON, i)
}
func (s *FieldInitializerListContext) AllExpr() []IExprContext {
var ts = s.GetTypedRuleContexts(reflect.TypeOf((*IExprContext)(nil)).Elem())
var tst = make([]IExprContext, len(ts))
@ -3078,6 +3269,14 @@ func (s *FieldInitializerListContext) Expr(i int) IExprContext {
return t.(IExprContext)
}
func (s *FieldInitializerListContext) AllCOMMA() []antlr.TerminalNode {
return s.GetTokens(CELParserCOMMA)
}
func (s *FieldInitializerListContext) COMMA(i int) antlr.TerminalNode {
return s.GetToken(CELParserCOMMA, i)
}
func (s *FieldInitializerListContext) GetRuleContext() antlr.RuleContext {
return s
}
@ -3315,6 +3514,22 @@ func (s *MapInitializerListContext) Expr(i int) IExprContext {
return t.(IExprContext)
}
func (s *MapInitializerListContext) AllCOLON() []antlr.TerminalNode {
return s.GetTokens(CELParserCOLON)
}
func (s *MapInitializerListContext) COLON(i int) antlr.TerminalNode {
return s.GetToken(CELParserCOLON, i)
}
func (s *MapInitializerListContext) AllCOMMA() []antlr.TerminalNode {
return s.GetTokens(CELParserCOMMA)
}
func (s *MapInitializerListContext) COMMA(i int) antlr.TerminalNode {
return s.GetToken(CELParserCOMMA, i)
}
func (s *MapInitializerListContext) GetRuleContext() antlr.RuleContext {
return s
}
@ -3607,6 +3822,10 @@ func (s *NullContext) GetRuleContext() antlr.RuleContext {
return s
}
func (s *NullContext) NUL() antlr.TerminalNode {
return s.GetToken(CELParserNUL, 0)
}
func (s *NullContext) EnterRule(listener antlr.ParseTreeListener) {
if listenerT, ok := listener.(CELListener); ok {
listenerT.EnterNull(s)
@ -3652,6 +3871,10 @@ func (s *BoolFalseContext) GetRuleContext() antlr.RuleContext {
return s
}
func (s *BoolFalseContext) CEL_FALSE() antlr.TerminalNode {
return s.GetToken(CELParserCEL_FALSE, 0)
}
func (s *BoolFalseContext) EnterRule(listener antlr.ParseTreeListener) {
if listenerT, ok := listener.(CELListener); ok {
listenerT.EnterBoolFalse(s)
@ -3804,6 +4027,10 @@ func (s *BoolTrueContext) GetRuleContext() antlr.RuleContext {
return s
}
func (s *BoolTrueContext) CEL_TRUE() antlr.TerminalNode {
return s.GetToken(CELParserCEL_TRUE, 0)
}
func (s *BoolTrueContext) EnterRule(listener antlr.ParseTreeListener) {
if listenerT, ok := listener.(CELListener); ok {
listenerT.EnterBoolTrue(s)
@ -3997,7 +4224,7 @@ func (p *CELParser) Literal() (localctx ILiteralContext) {
{
p.SetState(203)
var _m = p.Match(CELParserTRUE)
var _m = p.Match(CELParserCEL_TRUE)
localctx.(*BoolTrueContext).tok = _m
}
@ -4008,7 +4235,7 @@ func (p *CELParser) Literal() (localctx ILiteralContext) {
{
p.SetState(204)
var _m = p.Match(CELParserFALSE)
var _m = p.Match(CELParserCEL_FALSE)
localctx.(*BoolFalseContext).tok = _m
}
@ -4019,7 +4246,7 @@ func (p *CELParser) Literal() (localctx ILiteralContext) {
{
p.SetState(205)
var _m = p.Match(CELParserNULL)
var _m = p.Match(CELParserNUL)
localctx.(*NullContext).tok = _m
}

View File

@ -1,4 +1,4 @@
// Generated from /Users/tswadell/go/src/github.com/google/cel-go/bin/../parser/gen/CEL.g4 by ANTLR 4.7.
// Code generated from /Users/tswadell/go/src/github.com/google/cel-go/bin/../parser/gen/CEL.g4 by ANTLR 4.9.1. DO NOT EDIT.
package gen // CEL
import "github.com/antlr/antlr4/runtime/Go/antlr"

View File

@ -241,6 +241,20 @@ func (p *parserHelper) buildMacroCallArg(expr *exprpb.Expr) *exprpb.Expr {
},
},
}
case *exprpb.Expr_ListExpr:
listExpr := expr.GetListExpr()
macroListArgs := make([]*exprpb.Expr, len(listExpr.GetElements()))
for i, elem := range listExpr.GetElements() {
macroListArgs[i] = p.buildMacroCallArg(elem)
}
return &exprpb.Expr{
Id: expr.GetId(),
ExprKind: &exprpb.Expr_ListExpr{
ListExpr: &exprpb.Expr_CreateList{
Elements: macroListArgs,
},
},
}
}
return expr
@ -253,6 +267,8 @@ func (p *parserHelper) addMacroCall(exprID int64, function string, target *exprp
if target != nil {
if _, found := p.macroCalls[target.GetId()]; found {
macroTarget = &exprpb.Expr{Id: target.GetId()}
} else {
macroTarget = p.buildMacroCallArg(target)
}
}

View File

@ -142,14 +142,7 @@ var reservedIds = map[string]struct{}{
//
// Deprecated: Use NewParser().Parse() instead.
func Parse(source common.Source) (*exprpb.ParsedExpr, *common.Errors) {
return ParseWithMacros(source, AllMacros)
}
// ParseWithMacros converts a source input and macros set to a parsed expression.
//
// Deprecated: Use NewParser().Parse() instead.
func ParseWithMacros(source common.Source, macros []Macro) (*exprpb.ParsedExpr, *common.Errors) {
return mustNewParser(Macros(macros...)).Parse(source)
return mustNewParser(Macros(AllMacros...)).Parse(source)
}
type recursionError struct {
@ -304,6 +297,7 @@ var (
)
func (p *parser) parse(expr runes.Buffer, desc string) *exprpb.Expr {
// TODO: get rid of these pools once https://github.com/antlr/antlr4/pull/3571 is in a release
lexer := lexerPool.Get().(*gen.CELLexer)
prsr := parserPool.Get().(*gen.CELParser)

4
vendor/modules.txt vendored
View File

@ -371,7 +371,7 @@ github.com/google/cadvisor/utils/sysfs
github.com/google/cadvisor/utils/sysinfo
github.com/google/cadvisor/version
github.com/google/cadvisor/watcher
# github.com/google/cel-go v0.9.0 => github.com/google/cel-go v0.9.0
# github.com/google/cel-go v0.10.0 => github.com/google/cel-go v0.10.0
github.com/google/cel-go/cel
github.com/google/cel-go/checker
github.com/google/cel-go/checker/decls
@ -2541,7 +2541,7 @@ sigs.k8s.io/yaml
# github.com/golangplus/testing => github.com/golangplus/testing v0.0.0-20180327235837-af21d9c3145e
# github.com/google/btree => github.com/google/btree v1.0.1
# github.com/google/cadvisor => github.com/google/cadvisor v0.43.0
# github.com/google/cel-go => github.com/google/cel-go v0.9.0
# github.com/google/cel-go => github.com/google/cel-go v0.10.0
# github.com/google/cel-spec => github.com/google/cel-spec v0.6.0
# github.com/google/go-cmp => github.com/google/go-cmp v0.5.5
# github.com/google/gofuzz => github.com/google/gofuzz v1.1.0