Switch to use cel.TypeToExprType(celType) to generate the exprType.

This commit is contained in:
Cici Huang 2022-07-13 20:19:59 +00:00
parent 6883198898
commit 4c726155a2
5 changed files with 88 additions and 68 deletions

View File

@ -103,15 +103,35 @@ var paramA = cel.TypeParamType("A")
// CEL typeParams can be used to constraint to a specific trait (e.g. traits.ComparableType) if the 1st operand is the type to constrain. // CEL typeParams can be used to constraint to a specific trait (e.g. traits.ComparableType) if the 1st operand is the type to constrain.
// But the functions we need to constrain are <list<paramType>>, not just <paramType>. // But the functions we need to constrain are <list<paramType>>, not just <paramType>.
var summableTypes = map[string]*cel.Type{"int": cel.IntType, "uint": cel.UintType, "double": cel.DoubleType, "duration": cel.DurationType} // Make sure the order of overload set is deterministic
type namedCELType struct {
typeName string
celType *cel.Type
}
var summableTypes = []namedCELType{
{typeName: "int", celType: cel.IntType},
{typeName: "uint", celType: cel.UintType},
{typeName: "double", celType: cel.DoubleType},
{typeName: "duration", celType: cel.DurationType},
}
var zeroValuesOfSummableTypes = map[string]ref.Val{ var zeroValuesOfSummableTypes = map[string]ref.Val{
"int": types.Int(0), "int": types.Int(0),
"uint": types.Uint(0), "uint": types.Uint(0),
"double": types.Double(0.0), "double": types.Double(0.0),
"duration": types.Duration{Duration: 0}, "duration": types.Duration{Duration: 0},
} }
var comparableTypes = map[string]*cel.Type{"bool": cel.BoolType, "int": cel.IntType, "uint": cel.UintType, "double": cel.DoubleType, var comparableTypes = []namedCELType{
"duration": cel.DurationType, "timestamp": cel.TimestampType, "string": cel.StringType, "bytes": cel.BytesType} {typeName: "int", celType: cel.IntType},
{typeName: "uint", celType: cel.UintType},
{typeName: "double", celType: cel.DoubleType},
{typeName: "bool", celType: cel.BoolType},
{typeName: "duration", celType: cel.DurationType},
{typeName: "timestamp", celType: cel.TimestampType},
{typeName: "string", celType: cel.StringType},
{typeName: "bytes", celType: cel.BytesType},
}
// WARNING: All library additions or modifications must follow // WARNING: All library additions or modifications must follow
// https://github.com/kubernetes/enhancements/tree/master/keps/sig-api-machinery/2876-crd-validation-expression-language#function-library-updates // https://github.com/kubernetes/enhancements/tree/master/keps/sig-api-machinery/2876-crd-validation-expression-language#function-library-updates
@ -285,11 +305,11 @@ func lastIndexOf(list ref.Val, item ref.Val) ref.Val {
// templatedOverloads returns overloads for each of the provided types. The template function is called with each type // templatedOverloads returns overloads for each of the provided types. The template function is called with each type
// name (map key) and type to construct the overloads. // name (map key) and type to construct the overloads.
func templatedOverloads(types map[string]*cel.Type, template func(name string, t *cel.Type) cel.FunctionOpt) []cel.FunctionOpt { func templatedOverloads(types []namedCELType, template func(name string, t *cel.Type) cel.FunctionOpt) []cel.FunctionOpt {
overloads := make([]cel.FunctionOpt, len(types)) overloads := make([]cel.FunctionOpt, len(types))
i := 0 i := 0
for name, t := range types { for _, t := range types {
overloads[i] = template(name, t) overloads[i] = template(t.typeName, t.celType)
i++ i++
} }
return overloads return overloads

View File

@ -18,7 +18,6 @@ import (
"github.com/google/cel-go/cel" "github.com/google/cel-go/cel"
"time" "time"
"github.com/google/cel-go/checker/decls"
"github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types"
"k8s.io/apiextensions-apiserver/pkg/apiserver/schema" "k8s.io/apiextensions-apiserver/pkg/apiserver/schema"
@ -73,7 +72,7 @@ func SchemaDeclType(s *schema.Structural, isResourceRoot bool) *DeclType {
// To validate requirements on both the int and string representation: // To validate requirements on both the int and string representation:
// `type(intOrStringField) == int ? intOrStringField < 5 : double(intOrStringField.replace('%', '')) < 0.5 // `type(intOrStringField) == int ? intOrStringField < 5 : double(intOrStringField.replace('%', '')) < 0.5
// //
dyn := newSimpleType("dyn", decls.Dyn, cel.DynType, nil) dyn := newSimpleType("dyn", cel.DynType, nil)
// handle x-kubernetes-int-or-string by returning the max length of the largest possible string // handle x-kubernetes-int-or-string by returning the max length of the largest possible string
dyn.MaxElements = maxRequestSizeBytes - 2 dyn.MaxElements = maxRequestSizeBytes - 2
return dyn return dyn
@ -150,7 +149,7 @@ func SchemaDeclType(s *schema.Structural, isResourceRoot bool) *DeclType {
if s.ValueValidation != nil { if s.ValueValidation != nil {
switch s.ValueValidation.Format { switch s.ValueValidation.Format {
case "byte": case "byte":
byteWithMaxLength := newSimpleType("bytes", decls.Bytes, cel.BytesType, types.Bytes([]byte{})) byteWithMaxLength := newSimpleType("bytes", cel.BytesType, types.Bytes([]byte{}))
if s.ValueValidation.MaxLength != nil { if s.ValueValidation.MaxLength != nil {
byteWithMaxLength.MaxElements = zeroIfNegative(*s.ValueValidation.MaxLength) byteWithMaxLength.MaxElements = zeroIfNegative(*s.ValueValidation.MaxLength)
} else { } else {
@ -158,16 +157,16 @@ func SchemaDeclType(s *schema.Structural, isResourceRoot bool) *DeclType {
} }
return byteWithMaxLength return byteWithMaxLength
case "duration": case "duration":
durationWithMaxLength := newSimpleType("duration", decls.Duration, cel.DurationType, types.Duration{Duration: time.Duration(0)}) durationWithMaxLength := newSimpleType("duration", cel.DurationType, types.Duration{Duration: time.Duration(0)})
durationWithMaxLength.MaxElements = estimateMaxStringLengthPerRequest(s) durationWithMaxLength.MaxElements = estimateMaxStringLengthPerRequest(s)
return durationWithMaxLength return durationWithMaxLength
case "date", "date-time": case "date", "date-time":
timestampWithMaxLength := newSimpleType("timestamp", decls.Timestamp, cel.TimestampType, types.Timestamp{Time: time.Time{}}) timestampWithMaxLength := newSimpleType("timestamp", cel.TimestampType, types.Timestamp{Time: time.Time{}})
timestampWithMaxLength.MaxElements = estimateMaxStringLengthPerRequest(s) timestampWithMaxLength.MaxElements = estimateMaxStringLengthPerRequest(s)
return timestampWithMaxLength return timestampWithMaxLength
} }
} }
strWithMaxLength := newSimpleType("string", decls.String, cel.StringType, types.String("")) strWithMaxLength := newSimpleType("string", cel.StringType, types.String(""))
if s.ValueValidation != nil && s.ValueValidation.MaxLength != nil { if s.ValueValidation != nil && s.ValueValidation.MaxLength != nil {
// multiply the user-provided max length by 4 in the case of an otherwise-untyped string // multiply the user-provided max length by 4 in the case of an otherwise-untyped string
// we do this because the OpenAPIv3 spec indicates that maxLength is specified in runes/code points, // we do this because the OpenAPIv3 spec indicates that maxLength is specified in runes/code points,

View File

@ -107,8 +107,16 @@ func TestSchemaDeclTypes(t *testing.T) {
t.Errorf("missing type in rule types: %s", exp) t.Errorf("missing type in rule types: %s", exp)
continue continue
} }
if !proto.Equal(expType.ExprType(), actType.ExprType()) { expT, err := expType.ExprType()
t.Errorf("incompatible CEL types. got=%v, wanted=%v", actType.ExprType(), expType.ExprType()) if err != nil {
t.Errorf("fail to get cel type: %s", err)
}
actT, err := actType.ExprType()
if err != nil {
t.Errorf("fail to get cel type: %s", err)
}
if !proto.Equal(expT, actT) {
t.Errorf("incompatible CEL types. got=%v, wanted=%v", expT, actT)
} }
} }
} }

View File

@ -20,7 +20,6 @@ import (
"time" "time"
"github.com/google/cel-go/cel" "github.com/google/cel-go/cel"
"github.com/google/cel-go/checker/decls"
"github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref" "github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits" "github.com/google/cel-go/common/types/traits"
@ -41,7 +40,6 @@ func NewListType(elem *DeclType, maxItems int64) *DeclType {
name: "list", name: "list",
ElemType: elem, ElemType: elem,
MaxElements: maxItems, MaxElements: maxItems,
exprType: decls.NewListType(elem.ExprType()),
celType: cel.ListType(elem.CelType()), celType: cel.ListType(elem.CelType()),
defaultValue: NewListValue(), defaultValue: NewListValue(),
} }
@ -54,7 +52,6 @@ func NewMapType(key, elem *DeclType, maxProperties int64) *DeclType {
KeyType: key, KeyType: key,
ElemType: elem, ElemType: elem,
MaxElements: maxProperties, MaxElements: maxProperties,
exprType: decls.NewMapType(key.ExprType(), elem.ExprType()),
celType: cel.MapType(key.CelType(), elem.CelType()), celType: cel.MapType(key.CelType(), elem.CelType()),
defaultValue: NewMapValue(), defaultValue: NewMapValue(),
} }
@ -65,7 +62,6 @@ func NewObjectType(name string, fields map[string]*DeclField) *DeclType {
t := &DeclType{ t := &DeclType{
name: name, name: name,
Fields: fields, Fields: fields,
exprType: decls.NewObjectType(name),
celType: cel.ObjectType(name), celType: cel.ObjectType(name),
traitMask: traits.FieldTesterType | traits.IndexerType, traitMask: traits.FieldTesterType | traits.IndexerType,
} }
@ -73,34 +69,9 @@ func NewObjectType(name string, fields map[string]*DeclField) *DeclType {
return t return t
} }
// NewObjectTypeRef returns a reference to an object type by name func newSimpleType(name string, celType *cel.Type, zeroVal ref.Val) *DeclType {
func NewObjectTypeRef(name string) *DeclType {
t := &DeclType{
name: name,
exprType: decls.NewObjectType(name),
celType: cel.ObjectType(name),
traitMask: traits.FieldTesterType | traits.IndexerType,
}
return t
}
// NewTypeParam creates a type parameter type with a simple name.
//
// Type parameters are resolved at compilation time to concrete types, or CEL 'dyn' type if no
// type assignment can be inferred.
func NewTypeParam(name string) *DeclType {
return &DeclType{ return &DeclType{
name: name, name: name,
TypeParam: true,
exprType: decls.NewTypeParamType(name),
celType: cel.TypeParamType(name),
}
}
func newSimpleType(name string, exprType *exprpb.Type, celType *cel.Type, zeroVal ref.Val) *DeclType {
return &DeclType{
name: name,
exprType: exprType,
celType: celType, celType: celType,
defaultValue: zeroVal, defaultValue: zeroVal,
} }
@ -118,7 +89,6 @@ type DeclType struct {
Metadata map[string]string Metadata map[string]string
MaxElements int64 MaxElements int64
exprType *exprpb.Type
celType *cel.Type celType *cel.Type
traitMask int traitMask int
defaultValue ref.Val defaultValue ref.Val
@ -164,7 +134,6 @@ func (t *DeclType) MaybeAssignTypeName(name string) *DeclType {
ElemType: t.ElemType, ElemType: t.ElemType,
TypeParam: t.TypeParam, TypeParam: t.TypeParam,
Metadata: t.Metadata, Metadata: t.Metadata,
exprType: decls.NewObjectType(name),
celType: cel.ObjectType(name), celType: cel.ObjectType(name),
traitMask: t.traitMask, traitMask: t.traitMask,
defaultValue: t.defaultValue, defaultValue: t.defaultValue,
@ -190,8 +159,8 @@ func (t *DeclType) MaybeAssignTypeName(name string) *DeclType {
} }
// ExprType returns the CEL expression type of this declaration. // ExprType returns the CEL expression type of this declaration.
func (t *DeclType) ExprType() *exprpb.Type { func (t *DeclType) ExprType() (*exprpb.Type, error) {
return t.exprType return cel.TypeToExprType(t.celType)
} }
// CelType returns the CEL type of this declaration. // CelType returns the CEL type of this declaration.
@ -382,7 +351,11 @@ func (rt *RuleTypes) EnvOptions(tp ref.TypeProvider) ([]cel.EnvOption, error) {
} }
for name, declType := range rt.ruleSchemaDeclTypes.types { for name, declType := range rt.ruleSchemaDeclTypes.types {
tpType, found := tp.FindType(name) tpType, found := tp.FindType(name)
if found && !proto.Equal(tpType, declType.ExprType()) { expT, err := declType.ExprType()
if err != nil {
return nil, fmt.Errorf("fail to get cel type: %s", err)
}
if found && !proto.Equal(tpType, expT) {
return nil, fmt.Errorf( return nil, fmt.Errorf(
"type %s definition differs between CEL environment and rule", name) "type %s definition differs between CEL environment and rule", name)
} }
@ -407,7 +380,11 @@ func (rt *RuleTypes) FindType(typeName string) (*exprpb.Type, bool) {
} }
declType, found := rt.findDeclType(typeName) declType, found := rt.findDeclType(typeName)
if found { if found {
return declType.ExprType(), found expT, err := declType.ExprType()
if err != nil {
return expT, false
}
return expT, found
} }
return rt.TypeProvider.FindType(typeName) return rt.TypeProvider.FindType(typeName)
} }
@ -435,15 +412,23 @@ func (rt *RuleTypes) FindFieldType(typeName, fieldName string) (*ref.FieldType,
f, found := st.Fields[fieldName] f, found := st.Fields[fieldName]
if found { if found {
ft := f.Type ft := f.Type
expT, err := ft.ExprType()
if err != nil {
return nil, false
}
return &ref.FieldType{ return &ref.FieldType{
Type: ft.ExprType(), Type: expT,
}, true }, true
} }
// This could be a dynamic map. // This could be a dynamic map.
if st.IsMap() { if st.IsMap() {
et := st.ElemType et := st.ElemType
expT, err := et.ExprType()
if err != nil {
return nil, false
}
return &ref.FieldType{ return &ref.FieldType{
Type: et.ExprType(), Type: expT,
}, true }, true
} }
return nil, false return nil, false
@ -528,42 +513,42 @@ type schemaTypeProvider struct {
var ( var (
// AnyType is equivalent to the CEL 'protobuf.Any' type in that the value may have any of the // AnyType is equivalent to the CEL 'protobuf.Any' type in that the value may have any of the
// types supported. // types supported.
AnyType = newSimpleType("any", decls.Any, cel.AnyType, nil) AnyType = newSimpleType("any", cel.AnyType, nil)
// BoolType is equivalent to the CEL 'bool' type. // BoolType is equivalent to the CEL 'bool' type.
BoolType = newSimpleType("bool", decls.Bool, cel.BoolType, types.False) BoolType = newSimpleType("bool", cel.BoolType, types.False)
// BytesType is equivalent to the CEL 'bytes' type. // BytesType is equivalent to the CEL 'bytes' type.
BytesType = newSimpleType("bytes", decls.Bytes, cel.BytesType, types.Bytes([]byte{})) BytesType = newSimpleType("bytes", cel.BytesType, types.Bytes([]byte{}))
// DoubleType is equivalent to the CEL 'double' type which is a 64-bit floating point value. // DoubleType is equivalent to the CEL 'double' type which is a 64-bit floating point value.
DoubleType = newSimpleType("double", decls.Double, cel.DoubleType, types.Double(0)) DoubleType = newSimpleType("double", cel.DoubleType, types.Double(0))
// DurationType is equivalent to the CEL 'duration' type. // DurationType is equivalent to the CEL 'duration' type.
DurationType = newSimpleType("duration", decls.Duration, cel.DurationType, types.Duration{Duration: time.Duration(0)}) DurationType = newSimpleType("duration", cel.DurationType, types.Duration{Duration: time.Duration(0)})
// DateType is equivalent to the CEL 'date' type. // DateType is equivalent to the CEL 'date' type.
DateType = newSimpleType("date", decls.Timestamp, cel.TimestampType, types.Timestamp{Time: time.Time{}}) DateType = newSimpleType("date", cel.TimestampType, types.Timestamp{Time: time.Time{}})
// DynType is the equivalent of the CEL 'dyn' concept which indicates that the type will be // DynType is the equivalent of the CEL 'dyn' concept which indicates that the type will be
// determined at runtime rather than compile time. // determined at runtime rather than compile time.
DynType = newSimpleType("dyn", decls.Dyn, cel.DynType, nil) DynType = newSimpleType("dyn", cel.DynType, nil)
// IntType is equivalent to the CEL 'int' type which is a 64-bit signed int. // IntType is equivalent to the CEL 'int' type which is a 64-bit signed int.
IntType = newSimpleType("int", decls.Int, cel.IntType, types.IntZero) IntType = newSimpleType("int", cel.IntType, types.IntZero)
// NullType is equivalent to the CEL 'null_type'. // NullType is equivalent to the CEL 'null_type'.
NullType = newSimpleType("null_type", decls.Null, cel.NullType, types.NullValue) NullType = newSimpleType("null_type", cel.NullType, types.NullValue)
// StringType is equivalent to the CEL 'string' type which is expected to be a UTF-8 string. // StringType is equivalent to the CEL 'string' type which is expected to be a UTF-8 string.
// StringType values may either be string literals or expression strings. // StringType values may either be string literals or expression strings.
StringType = newSimpleType("string", decls.String, cel.StringType, types.String("")) StringType = newSimpleType("string", cel.StringType, types.String(""))
// TimestampType corresponds to the well-known protobuf.Timestamp type supported within CEL. // TimestampType corresponds to the well-known protobuf.Timestamp type supported within CEL.
TimestampType = newSimpleType("timestamp", decls.Timestamp, cel.TimestampType, types.Timestamp{Time: time.Time{}}) TimestampType = newSimpleType("timestamp", cel.TimestampType, types.Timestamp{Time: time.Time{}})
// UintType is equivalent to the CEL 'uint' type. // UintType is equivalent to the CEL 'uint' type.
UintType = newSimpleType("uint", decls.Uint, cel.UintType, types.Uint(0)) UintType = newSimpleType("uint", cel.UintType, types.Uint(0))
// ListType is equivalent to the CEL 'list' type. // ListType is equivalent to the CEL 'list' type.
ListType = NewListType(AnyType, noMaxLength) ListType = NewListType(AnyType, noMaxLength)

View File

@ -37,8 +37,12 @@ func TestTypes_ListType(t *testing.T) {
if list.ElemType.TypeName() != "string" { if list.ElemType.TypeName() != "string" {
t.Errorf("got %s, wanted elem type of string", list.ElemType.TypeName()) t.Errorf("got %s, wanted elem type of string", list.ElemType.TypeName())
} }
if list.ExprType().GetListType() == nil { expT, err := list.ExprType()
t.Errorf("got %v, wanted CEL list type", list.ExprType()) if err != nil {
t.Errorf("fail to get cel type: %s", err)
}
if expT.GetListType() == nil {
t.Errorf("got %v, wanted CEL list type", expT)
} }
} }
@ -59,8 +63,12 @@ func TestTypes_MapType(t *testing.T) {
if mp.ElemType.TypeName() != "int" { if mp.ElemType.TypeName() != "int" {
t.Errorf("got %s, wanted elem type of int", mp.ElemType.TypeName()) t.Errorf("got %s, wanted elem type of int", mp.ElemType.TypeName())
} }
if mp.ExprType().GetMapType() == nil { expT, err := mp.ExprType()
t.Errorf("got %v, wanted CEL map type", mp.ExprType()) if err != nil {
t.Errorf("fail to get cel type: %s", err)
}
if expT.GetMapType() == nil {
t.Errorf("got %v, wanted CEL map type", expT)
} }
} }