Add stdlib of CEL functions to Kubernetes that extends the standard library provided by CEL

This commit is contained in:
Joe Betz 2022-03-01 16:42:14 -05:00
parent 16c9d59d2d
commit fd5ae0451d
10 changed files with 1263 additions and 6 deletions

View File

@ -23,12 +23,11 @@ import (
"github.com/google/cel-go/cel"
"github.com/google/cel-go/checker/decls"
"github.com/google/cel-go/ext"
expr "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
"google.golang.org/protobuf/proto"
"k8s.io/apiextensions-apiserver/pkg/apiserver/schema"
"k8s.io/apiextensions-apiserver/pkg/apiserver/schema/cel/library"
celmodel "k8s.io/apiextensions-apiserver/third_party/forked/celopenapi/model"
)
@ -82,8 +81,8 @@ func Compile(s *schema.Structural, isResourceRoot bool) ([]CompilationResult, er
root = rootDecl.MaybeAssignTypeName(scopedTypeName)
}
propDecls = append(propDecls, decls.NewVar(ScopedVarName, root.ExprType()))
opts = append(opts, cel.Declarations(propDecls...))
opts = append(opts, ext.Strings())
opts = append(opts, cel.Declarations(propDecls...), cel.HomogeneousAggregateLiterals())
opts = append(opts, library.ExtensionLibs...)
env, err = env.Extend(opts...)
if err != nil {
return nil, err
@ -103,7 +102,7 @@ func Compile(s *schema.Structural, isResourceRoot bool) ([]CompilationResult, er
} else if !proto.Equal(ast.ResultType(), decls.Bool) {
compilationResult.Error = &Error{ErrorTypeInvalid, "cel expression must evaluate to a bool"}
} else {
prog, err := env.Program(ast)
prog, err := env.Program(ast, cel.EvalOptions(cel.OptOptimize))
if err != nil {
compilationResult.Error = &Error{ErrorTypeInvalid, "program instantiation failed: " + err.Error()}
} else {

View File

@ -0,0 +1,31 @@
/*
Copyright 2022 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package library
import (
"github.com/google/cel-go/cel"
"github.com/google/cel-go/ext"
)
// ExtensionLibs declares the set of CEL extension libraries available everywhere CEL is used in Kubernetes.
var ExtensionLibs = append(k8sExtensionLibs, ext.Strings())
var k8sExtensionLibs = []cel.EnvOption{
URLs(),
Regex(),
Lists(),
}

View File

@ -0,0 +1,59 @@
/*
Copyright 2022 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package library
import (
"testing"
"github.com/google/cel-go/cel"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
func TestLibraryCompatibility(t *testing.T) {
functionNames := map[string]struct{}{}
decls := map[cel.Library][]*exprpb.Decl{
urlsLib: urlLibraryDecls,
listsLib: listsLibraryDecls,
regexLib: regexLibraryDecls,
}
if len(k8sExtensionLibs) != len(decls) {
t.Errorf("Expected the same number of libraries in the ExtensionLibs as are tested for compatibility")
}
for _, l := range decls {
for _, d := range l {
functionNames[d.GetName()] = struct{}{}
}
}
// WARN: All library changes must follow
// https://github.com/kubernetes/enhancements/tree/master/keps/sig-api-machinery/2876-crd-validation-expression-language#function-library-updates
// and must track the functions here along with which Kubernetes version introduced them.
knownFunctions := []string{
// Kubernetes 1.24:
"isSorted", "sum", "max", "min", "indexOf", "lastIndexOf", "find", "findAll", "url", "getScheme", "getHost", "getHostname",
"getPort", "getEscapedPath", "getQuery", "isURL",
// Kubernetes <1.??>:
}
for _, fn := range knownFunctions {
delete(functionNames, fn)
}
if len(functionNames) != 0 {
t.Errorf("Expected all functions in the libraries to be assigned to a kubernetes release, but found the unassigned function names: %v", functionNames)
}
}

View File

@ -0,0 +1,386 @@
/*
Copyright 2022 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package library
import (
"fmt"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/checker/decls"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
"github.com/google/cel-go/interpreter/functions"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
// Lists provides a CEL function library extension of list utility functions.
//
// isSorted
//
// Returns true if the provided list of comparable elements is sorted, else returns false.
//
// <list<T>>.isSorted() <bool>, T must be a comparable type
//
// Examples:
//
// [1, 2, 3].isSorted() // return true
// ['a', 'b', 'b', 'c'].isSorted() // return true
// [2.0, 1.0].isSorted() // return false
// [1].isSorted() // return true
// [].isSorted() // return true
//
//
// sum
//
// Returns the sum of the elements of the provided list. Supports CEL number (int, uint, double) and duration types.
//
// <list<T>>.sum() <T>, T must be a numeric type or a duration
//
// Examples:
//
// [1, 3].sum() // returns 4
// [1.0, 3.0].sum() // returns 4.0
// ['1m', '1s'].sum() // returns '1m1s'
// emptyIntList.sum() // returns 0
// emptyDoubleList.sum() // returns 0.0
// [].sum() // returns 0
//
//
// min / max
//
// Returns the minimum/maximum valued element of the provided list. Supports all comparable types.
// If the list is empty, an error is returned.
//
// <list<T>>.min() <T>, T must be a comparable type
// <list<T>>.max() <T>, T must be a comparable type
//
// Examples:
//
// [1, 3].min() // returns 1
// [1, 3].max() // returns 3
// [].min() // error
// [1].min() // returns 1
// ([0] + emptyList).min() // returns 0
//
//
// indexOf / lastIndexOf
//
// Returns either the first or last positional index of the provided element in the list.
// If the element is not found, -1 is returned. Supports all equatable types.
//
// <list<T>>.indexOf(<T>) <int>, T must be an equatable type
// <list<T>>.lastIndexOf(<T>) <int>, T must be an equatable type
//
// Examples:
//
// [1, 2, 2, 3].indexOf(2) // returns 1
// ['a', 'b', 'b', 'c'].lastIndexOf('b') // returns 2
// [1.0].indexOf(1.1) // returns -1
// [].indexOf('string') // returns -1
//
func Lists() cel.EnvOption {
return cel.Lib(listsLib)
}
var listsLib = &lists{}
type lists struct{}
var paramA = decls.NewTypeParamType("A")
// CEL typeParams can be used to constraint to a specific trait (e.g. traits.ComparableType) if the 1st operand is the type to constrain.
// But the functions we need to constrain are <list<paramType>>, not just <paramType>.
var summableTypes = map[string]*exprpb.Type{"int": decls.Int, "uint": decls.Uint, "double": decls.Double, "duration": decls.Duration}
var comparableTypes = map[string]*exprpb.Type{"bool": decls.Bool, "int": decls.Int, "uint": decls.Uint, "double": decls.Double,
"duration": decls.Duration, "timestamp": decls.Timestamp, "string": decls.String, "bytes": decls.Bytes}
// WARNING: All library additions or modifications must follow
// https://github.com/kubernetes/enhancements/tree/master/keps/sig-api-machinery/2876-crd-validation-expression-language#function-library-updates
var listsLibraryDecls = []*exprpb.Decl{
decls.NewFunction("isSorted",
templatedOverloads(comparableTypes, func(name string, paramType *exprpb.Type) *exprpb.Decl_FunctionDecl_Overload {
return decls.NewInstanceOverload(fmt.Sprintf("list_%s_is_sorted_bool", name),
[]*exprpb.Type{decls.NewListType(paramType)},
decls.Bool)
})...,
),
decls.NewFunction("sum",
templatedOverloads(summableTypes, func(name string, paramType *exprpb.Type) *exprpb.Decl_FunctionDecl_Overload {
return decls.NewInstanceOverload(fmt.Sprintf("list_%s_sum_%s", name, name),
[]*exprpb.Type{decls.NewListType(paramType)},
paramType)
})...,
),
decls.NewFunction("max",
templatedOverloads(comparableTypes, func(name string, paramType *exprpb.Type) *exprpb.Decl_FunctionDecl_Overload {
return decls.NewInstanceOverload(fmt.Sprintf("list_%s_max_%s", name, name),
[]*exprpb.Type{decls.NewListType(paramType)},
paramType)
})...,
),
decls.NewFunction("min",
templatedOverloads(comparableTypes, func(name string, paramType *exprpb.Type) *exprpb.Decl_FunctionDecl_Overload {
return decls.NewInstanceOverload(fmt.Sprintf("list_%s_min_%s", name, name),
[]*exprpb.Type{decls.NewListType(paramType)},
paramType)
})...,
),
decls.NewFunction("indexOf",
decls.NewInstanceOverload("list_a_index_of_int",
[]*exprpb.Type{decls.NewListType(paramA), paramA},
decls.Int),
),
decls.NewFunction("lastIndexOf",
decls.NewInstanceOverload("list_a_last_index_of_int",
[]*exprpb.Type{decls.NewListType(paramA), paramA},
decls.Int),
),
}
func (*lists) CompileOptions() []cel.EnvOption {
return []cel.EnvOption{
cel.Declarations(listsLibraryDecls...),
}
}
func (*lists) ProgramOptions() []cel.ProgramOption {
return []cel.ProgramOption{
cel.Functions(
&functions.Overload{
Operator: "isSorted",
Unary: isSorted,
},
// if 'sum' is called directly, it is via dynamic dispatch, and we infer the type from the 1st element of the
// list if it has one, otherwise we return int64(0)
&functions.Overload{
Operator: "sum",
Unary: dynSum(),
},
// use overload names for sum so an initial accumulator value can be assigned to each
&functions.Overload{
Operator: "list_int_sum_int",
Unary: sum(func() ref.Val {
return types.Int(0)
}),
},
&functions.Overload{
Operator: "list_uint_sum_uint",
Unary: sum(func() ref.Val {
return types.Uint(0)
}),
},
&functions.Overload{
Operator: "list_double_sum_double",
Unary: sum(func() ref.Val {
return types.Double(0.0)
}),
},
&functions.Overload{
Operator: "list_duration_sum_duration",
Unary: sum(func() ref.Val {
return types.Duration{Duration: 0}
}),
},
&functions.Overload{
Operator: "max",
Unary: max(),
},
&functions.Overload{
Operator: "min",
Unary: min(),
},
// use overload names for indexOf and lastIndexOf to de-conflict with function of same name in strings extension library
&functions.Overload{
Operator: "list_a_index_of_int",
Binary: indexOf,
},
&functions.Overload{
Operator: "list_a_last_index_of_int",
Binary: lastIndexOf,
},
),
}
}
func isSorted(val ref.Val) ref.Val {
var prev traits.Comparer
iterable, ok := val.(traits.Iterable)
if !ok {
return types.MaybeNoSuchOverloadErr(val)
}
for it := iterable.Iterator(); it.HasNext() == types.True; {
next := it.Next()
nextCmp, ok := next.(traits.Comparer)
if !ok {
return types.MaybeNoSuchOverloadErr(next)
}
if prev != nil {
cmp := prev.Compare(next)
if cmp == types.IntOne {
return types.False
}
}
prev = nextCmp
}
return types.True
}
func dynSum() functions.UnaryOp {
return func(val ref.Val) ref.Val {
iterable, ok := val.(traits.Iterable)
if !ok {
return types.MaybeNoSuchOverloadErr(val)
}
it := iterable.Iterator()
var initval ref.Val
if it.HasNext() == types.True {
first := it.Next()
switch first.Type() {
case types.IntType:
initval = types.Int(0)
case types.UintType:
initval = types.Uint(0)
case types.DoubleType:
initval = types.Double(0.0)
case types.DurationType:
initval = types.Duration{Duration: 0}
default:
return types.MaybeNoSuchOverloadErr(first)
}
} else {
initval = types.Int(0)
}
initFn := func() ref.Val {
return initval
}
return sum(initFn)(val)
}
}
func sum(init func() ref.Val) functions.UnaryOp {
return func(val ref.Val) ref.Val {
i := init()
acc, ok := i.(traits.Adder)
if !ok {
// Should never happen since all passed in init values are valid
return types.MaybeNoSuchOverloadErr(i)
}
iterable, ok := val.(traits.Iterable)
if !ok {
return types.MaybeNoSuchOverloadErr(val)
}
for it := iterable.Iterator(); it.HasNext() == types.True; {
next := it.Next()
nextAdder, ok := next.(traits.Adder)
if !ok {
// Should never happen for type checked CEL programs
return types.MaybeNoSuchOverloadErr(next)
}
if acc != nil {
s := acc.Add(next)
sum, ok := s.(traits.Adder)
if !ok {
// Should never happen for type checked CEL programs
return types.MaybeNoSuchOverloadErr(s)
}
acc = sum
} else {
acc = nextAdder
}
}
return acc.(ref.Val)
}
}
func min() functions.UnaryOp {
return cmp("min", types.IntOne)
}
func max() functions.UnaryOp {
return cmp("max", types.IntNegOne)
}
func cmp(opName string, opPreferCmpResult ref.Val) functions.UnaryOp {
return func(val ref.Val) ref.Val {
var result traits.Comparer
iterable, ok := val.(traits.Iterable)
if !ok {
return types.MaybeNoSuchOverloadErr(val)
}
for it := iterable.Iterator(); it.HasNext() == types.True; {
next := it.Next()
nextCmp, ok := next.(traits.Comparer)
if !ok {
// Should never happen for type checked CEL programs
return types.MaybeNoSuchOverloadErr(next)
}
if result == nil {
result = nextCmp
} else {
cmp := result.Compare(next)
if cmp == opPreferCmpResult {
result = nextCmp
}
}
}
if result == nil {
return types.NewErr("%s called on empty list", opName)
}
return result.(ref.Val)
}
}
func indexOf(list ref.Val, item ref.Val) ref.Val {
lister, ok := list.(traits.Lister)
if !ok {
return types.MaybeNoSuchOverloadErr(list)
}
sz := lister.Size().(types.Int)
for i := types.Int(0); i < sz; i++ {
if lister.Get(types.Int(i)).Equal(item) == types.True {
return types.Int(i)
}
}
return types.Int(-1)
}
func lastIndexOf(list ref.Val, item ref.Val) ref.Val {
lister, ok := list.(traits.Lister)
if !ok {
return types.MaybeNoSuchOverloadErr(list)
}
sz := lister.Size().(types.Int)
for i := sz - 1; i >= 0; i-- {
if lister.Get(types.Int(i)).Equal(item) == types.True {
return types.Int(i)
}
}
return types.Int(-1)
}
// templatedOverloads returns overloads for each of the provided types. The template function is called with each type
// name (map key) and type to construct the overloads.
func templatedOverloads(types map[string]*exprpb.Type, template func(name string, t *exprpb.Type) *exprpb.Decl_FunctionDecl_Overload) []*exprpb.Decl_FunctionDecl_Overload {
overloads := make([]*exprpb.Decl_FunctionDecl_Overload, len(types))
i := 0
for name, t := range types {
overloads[i] = template(name, t)
i++
}
return overloads
}

View File

@ -0,0 +1,158 @@
/*
Copyright 2022 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package library
import (
"regexp"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/checker/decls"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/interpreter/functions"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
// Regex provides a CEL function library extension of regex utility functions.
//
// find / findAll
//
// Returns substrings that match the provided regular expression. find returns the first match. findAll may optionally
// be provided a limit. If the limit is set and >= 0, no more than the limit number of matches are returned.
//
// <string>.find(<string>) <string>
// <string>.findAll(<string>) <list <string>>
// <string>.findAll(<string>, <int>) <list <string>>
//
// Examples:
//
// "abc 123".find('[0-9]*') // returns '123'
// "abc 123".find('xyz') // returns ''
// "123 abc 456".findAll('[0-9]*') // returns ['123', '456']
// "123 abc 456".findAll('[0-9]*', 1) // returns ['123']
// "123 abc 456".findAll('xyz') // returns []
//
func Regex() cel.EnvOption {
return cel.Lib(regexLib)
}
var regexLib = &regex{}
type regex struct{}
var regexLibraryDecls = []*exprpb.Decl{
decls.NewFunction("find",
decls.NewInstanceOverload("string_find_string",
[]*exprpb.Type{decls.String, decls.String},
decls.String),
),
decls.NewFunction("findAll",
decls.NewInstanceOverload("string_find_all_string",
[]*exprpb.Type{decls.String, decls.String},
decls.NewListType(decls.String)),
decls.NewInstanceOverload("string_find_all_string_int",
[]*exprpb.Type{decls.String, decls.String, decls.Int},
decls.NewListType(decls.String)),
),
}
func (*regex) CompileOptions() []cel.EnvOption {
return []cel.EnvOption{
cel.Declarations(regexLibraryDecls...),
}
}
func (*regex) ProgramOptions() []cel.ProgramOption {
return []cel.ProgramOption{
cel.Functions(
&functions.Overload{
Operator: "find",
Binary: find,
},
&functions.Overload{
Operator: "string_find_string",
Binary: find,
},
&functions.Overload{
Operator: "findAll",
Binary: func(str, regex ref.Val) ref.Val {
return findAll(str, regex, types.Int(-1))
},
Function: findAll,
},
&functions.Overload{
Operator: "string_find_all_string",
Binary: func(str, regex ref.Val) ref.Val {
return findAll(str, regex, types.Int(-1))
},
},
&functions.Overload{
Operator: "string_find_all_string_int",
Function: findAll,
},
),
}
}
func find(strVal ref.Val, regexVal ref.Val) ref.Val {
str, ok := strVal.Value().(string)
if !ok {
return types.MaybeNoSuchOverloadErr(strVal)
}
regex, ok := regexVal.Value().(string)
if !ok {
return types.MaybeNoSuchOverloadErr(regexVal)
}
re, err := regexp.Compile(regex)
if err != nil {
return types.NewErr("Illegal regex: %v", err.Error())
}
result := re.FindString(str)
return types.String(result)
}
func findAll(args ...ref.Val) ref.Val {
argn := len(args)
if argn < 2 || argn > 3 {
return types.NoSuchOverloadErr()
}
str, ok := args[0].Value().(string)
if !ok {
return types.MaybeNoSuchOverloadErr(args[0])
}
regex, ok := args[1].Value().(string)
if !ok {
return types.MaybeNoSuchOverloadErr(args[1])
}
n := int64(-1)
if argn == 3 {
n, ok = args[2].Value().(int64)
if !ok {
return types.MaybeNoSuchOverloadErr(args[2])
}
}
re, err := regexp.Compile(regex)
if err != nil {
return types.NewErr("Illegal regex: %v", err.Error())
}
result := re.FindAllString(str, int(n))
return types.NewStringList(types.DefaultTypeAdapter, result)
}

View File

@ -0,0 +1,315 @@
/*
Copyright 2022 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package library
import (
"net/url"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/checker/decls"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/interpreter/functions"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
"k8s.io/apiextensions-apiserver/third_party/forked/celopenapi/model"
)
// URLs provides a CEL function library extension of URL parsing functions.
//
// url
//
// Converts a string to a URL or results in an error if the string is not a valid URL. The URL must be an absolute URI
// or an absolute path.
//
// url(<string>) <URL>
//
// Examples:
//
// url('https://user:pass@example.com:80/path?query=val#fragment') // returns a URL
// url('/absolute-path') // returns a URL
// url('https://a:b:c/') // error
// url('../relative-path') // error
//
// isURL
//
// Returns true if a string is a valid URL. The URL must be an absolute URI or an absolute path.
//
// isURL( <string>) <bool>
//
// Examples:
//
// isURL('https://user:pass@example.com:80/path?query=val#fragment') // returns true
// isURL('/absolute-path') // returns true
// isURL('https://a:b:c/') // returns false
// isURL('../relative-path') // returns false
//
//
// getScheme / getHost / getHostname / getPort / getEscapedPath / getQuery
//
// Return the parsed components of a URL.
// - getScheme: If absent in the URL, returns an empty string.
// - getHostname: IPv6 addresses are returned with braces, e.g. "[::1]". If absent in the URL, returns an empty string.
// - getHost: IPv6 addresses are returned without braces, e.g. "::1". If absent in the URL, returns an empty string.
// - getEscapedPath: The string returned by getEscapedPath is URL escaped, e.g. "with space" becomes "with%20space".
// If absent in the URL, returns an empty string.
// - getPort: If absent in the URL, returns an empty string.
// - getQuery: Returns the query parameters in "matrix" form where a repeated query key is interpreted to
// mean that there are multiple values for that key. The keys and values are returned unescaped.
// If absent in the URL, returns an empty map.
//
// <URL>.getScheme() <string>
// <URL>.getHost() <string>
// <URL>.getHostname() <string>
// <URL>.getPort() <string>
// <URL>.getEscapedPath() <string>
// <URL>.getQuery() <map <string>, <list <string>>
//
// Examples:
//
// url('/path').getScheme() // returns ''
// url('https://example.com/').getScheme() // returns 'https'
// url('https://example.com:80/').getHost() // returns 'example.com:80'
// url('https://example.com/').getHost() // returns 'example.com'
// url('https://[::1]:80/').getHost() // returns '[::1]:80'
// url('https://[::1]/').getHost() // returns '[::1]'
// url('/path').getHost() // returns ''
// url('https://example.com:80/').getHostname() // returns 'example.com'
// url('https://127.0.0.1:80/').getHostname() // returns '127.0.0.1'
// url('https://[::1]:80/').getHostname() // returns '::1'
// url('/path').getHostname() // returns ''
// url('https://example.com:80/').getPort() // returns '80'
// url('https://example.com/').getPort() // returns ''
// url('/path').getPort() // returns ''
// url('https://example.com/path').getEscapedPath() // returns '/path'
// url('https://example.com/path with spaces/').getEscapedPath() // returns '/path%20with%20spaces/'
// url('https://example.com').getEscapedPath() // returns ''
// url('https://example.com/path?k1=a&k2=b&k2=c').getQuery() // returns { 'k1': ['a'], 'k2': ['b', 'c']}
// url('https://example.com/path?key with spaces=value with spaces').getQuery() // returns { 'key with spaces': ['value with spaces']}
// url('https://example.com/path?').getQuery() // returns {}
// url('https://example.com/path').getQuery() // returns {}
//
func URLs() cel.EnvOption {
return cel.Lib(urlsLib)
}
var urlsLib = &urls{}
type urls struct{}
var urlLibraryDecls = []*exprpb.Decl{
decls.NewFunction("url",
decls.NewOverload("string_to_url",
[]*exprpb.Type{decls.String},
model.URLObject),
),
decls.NewFunction("getScheme",
decls.NewInstanceOverload("url_get_scheme",
[]*exprpb.Type{model.URLObject},
decls.String),
),
decls.NewFunction("getHost",
decls.NewInstanceOverload("url_get_host",
[]*exprpb.Type{model.URLObject},
decls.String),
),
decls.NewFunction("getHostname",
decls.NewInstanceOverload("url_get_hostname",
[]*exprpb.Type{model.URLObject},
decls.String),
),
decls.NewFunction("getPort",
decls.NewInstanceOverload("url_get_port",
[]*exprpb.Type{model.URLObject},
decls.String),
),
decls.NewFunction("getEscapedPath",
decls.NewInstanceOverload("url_get_escaped_path",
[]*exprpb.Type{model.URLObject},
decls.String),
),
decls.NewFunction("getQuery",
decls.NewInstanceOverload("url_get_query",
[]*exprpb.Type{model.URLObject},
decls.NewMapType(decls.String, decls.NewListType(decls.String))),
),
decls.NewFunction("isURL",
decls.NewOverload("is_url_string",
[]*exprpb.Type{decls.String},
decls.Bool),
),
}
func (*urls) CompileOptions() []cel.EnvOption {
return []cel.EnvOption{
cel.Declarations(urlLibraryDecls...),
}
}
func (*urls) ProgramOptions() []cel.ProgramOption {
return []cel.ProgramOption{
cel.Functions(
&functions.Overload{
Operator: "url",
Unary: stringToUrl,
},
&functions.Overload{
Operator: "string_to_url",
Unary: stringToUrl,
},
&functions.Overload{
Operator: "getScheme",
Unary: getScheme,
},
&functions.Overload{
Operator: "url_get_scheme",
Unary: getScheme,
},
&functions.Overload{
Operator: "getHost",
Unary: getHost,
},
&functions.Overload{
Operator: "url_get_host",
Unary: getHost,
},
&functions.Overload{
Operator: "getHostname",
Unary: getHostname,
},
&functions.Overload{
Operator: "url_get_hostname",
Unary: getHostname,
},
&functions.Overload{
Operator: "getPort",
Unary: getPort,
},
&functions.Overload{
Operator: "url_get_port",
Unary: getPort,
},
&functions.Overload{
Operator: "getEscapedPath",
Unary: getEscapedPath,
},
&functions.Overload{
Operator: "url_get_escaped_path",
Unary: getEscapedPath,
},
&functions.Overload{
Operator: "getQuery",
Unary: getQuery,
},
&functions.Overload{
Operator: "url_get_query",
Unary: getQuery,
},
&functions.Overload{
Operator: "isURL",
Unary: isURL,
},
&functions.Overload{
Operator: "is_url_string",
Unary: isURL,
},
),
}
}
func stringToUrl(arg ref.Val) ref.Val {
s, ok := arg.Value().(string)
if !ok {
return types.MaybeNoSuchOverloadErr(arg)
}
// Use ParseRequestURI to check the URL before conversion.
// ParseRequestURI requires absolute URLs and is used by the OpenAPIv3 'uri' format.
_, err := url.ParseRequestURI(s)
if err != nil {
return types.NewErr("URL parse error during conversion from string: %v", err)
}
// We must parse again with Parse since ParseRequestURI incorrectly parses URLs that contain a fragment
// part and will incorrectly append the fragment to either the path or the query, depending on which it was adjacent to.
u, err := url.Parse(s)
if err != nil {
// Errors are not expected here since Parse is a more lenient parser than ParseRequestURI.
return types.NewErr("URL parse error during conversion from string: %v", err)
}
return model.URL{URL: u}
}
func getScheme(arg ref.Val) ref.Val {
u, ok := arg.Value().(*url.URL)
if !ok {
return types.MaybeNoSuchOverloadErr(arg)
}
return types.String(u.Scheme)
}
func getHost(arg ref.Val) ref.Val {
u, ok := arg.Value().(*url.URL)
if !ok {
return types.MaybeNoSuchOverloadErr(arg)
}
return types.String(u.Host)
}
func getHostname(arg ref.Val) ref.Val {
u, ok := arg.Value().(*url.URL)
if !ok {
return types.MaybeNoSuchOverloadErr(arg)
}
return types.String(u.Hostname())
}
func getPort(arg ref.Val) ref.Val {
u, ok := arg.Value().(*url.URL)
if !ok {
return types.MaybeNoSuchOverloadErr(arg)
}
return types.String(u.Port())
}
func getEscapedPath(arg ref.Val) ref.Val {
u, ok := arg.Value().(*url.URL)
if !ok {
return types.MaybeNoSuchOverloadErr(arg)
}
return types.String(u.EscapedPath())
}
func getQuery(arg ref.Val) ref.Val {
u, ok := arg.Value().(*url.URL)
if !ok {
return types.MaybeNoSuchOverloadErr(arg)
}
result := map[ref.Val]ref.Val{}
for k, v := range u.Query() {
result[types.String(k)] = types.NewStringList(types.DefaultTypeAdapter, v)
}
return types.NewRefValMap(types.DefaultTypeAdapter, result)
}
func isURL(arg ref.Val) ref.Val {
s, ok := arg.Value().(string)
if !ok {
return types.MaybeNoSuchOverloadErr(arg)
}
_, err := url.ParseRequestURI(s)
return types.Bool(err == nil)
}

View File

@ -1353,6 +1353,238 @@ func TestValidationExpressions(t *testing.T) {
// TODO: also find a way to test the errors returned for: array with no items, object with no properties or additionalProperties, invalid listType and invalid type.
},
},
{name: "stdlib list functions",
obj: map[string]interface{}{
"ints": []interface{}{int64(1), int64(2), int64(2), int64(3)},
"unsortedInts": []interface{}{int64(2), int64(1)},
"emptyInts": []interface{}{},
"doubles": []interface{}{float64(1), float64(2), float64(2), float64(3)},
"unsortedDoubles": []interface{}{float64(2), float64(1)},
"emptyDoubles": []interface{}{},
"intBackedDoubles": []interface{}{int64(1), int64(2), int64(2), int64(3)},
"unsortedIntBackedDDoubles": []interface{}{int64(2), int64(1)},
"emptyIntBackedDDoubles": []interface{}{},
"durations": []interface{}{"1s", "1m", "1m", "1h"},
"unsortedDurations": []interface{}{"1m", "1s"},
"emptyDurations": []interface{}{},
"strings": []interface{}{"a", "b", "b", "c"},
"unsortedStrings": []interface{}{"b", "a"},
"emptyStrings": []interface{}{},
"dates": []interface{}{"2000-01-01", "2000-02-01", "2000-02-01", "2010-01-01"},
"unsortedDates": []interface{}{"2000-02-01", "2000-01-01"},
"emptyDates": []interface{}{},
"objs": []interface{}{
map[string]interface{}{"f1": "a", "f2": "a"},
map[string]interface{}{"f1": "a", "f2": "b"},
map[string]interface{}{"f1": "a", "f2": "b"},
map[string]interface{}{"f1": "a", "f2": "c"},
},
},
schema: objectTypePtr(map[string]schema.Structural{
"ints": listType(&integerType),
"unsortedInts": listType(&integerType),
"emptyInts": listType(&integerType),
"doubles": listType(&doubleType),
"unsortedDoubles": listType(&doubleType),
"emptyDoubles": listType(&doubleType),
"intBackedDoubles": listType(&doubleType),
"unsortedIntBackedDDoubles": listType(&doubleType),
"emptyIntBackedDDoubles": listType(&doubleType),
"durations": listType(&durationFormat),
"unsortedDurations": listType(&durationFormat),
"emptyDurations": listType(&durationFormat),
"strings": listType(&stringType),
"unsortedStrings": listType(&stringType),
"emptyStrings": listType(&stringType),
"dates": listType(&dateFormat),
"unsortedDates": listType(&dateFormat),
"emptyDates": listType(&dateFormat),
"objs": listType(objectTypePtr(map[string]schema.Structural{
"f1": stringType,
"f2": stringType,
})),
}),
valid: []string{
"self.ints.sum() == 8",
"self.ints.min() == 1",
"self.ints.max() == 3",
"self.emptyInts.sum() == 0",
"self.ints.isSorted()",
"self.emptyInts.isSorted()",
"self.unsortedInts.isSorted() == false",
"self.ints.indexOf(2) == 1",
"self.ints.lastIndexOf(2) == 2",
"self.ints.indexOf(10) == -1",
"self.ints.lastIndexOf(10) == -1",
"self.doubles.sum() == 8.0",
"self.doubles.min() == 1.0",
"self.doubles.max() == 3.0",
"self.emptyDoubles.sum() == 0.0",
"self.doubles.isSorted()",
"self.emptyDoubles.isSorted()",
"self.unsortedDoubles.isSorted() == false",
"self.doubles.indexOf(2.0) == 1",
"self.doubles.lastIndexOf(2.0) == 2",
"self.doubles.indexOf(10.0) == -1",
"self.doubles.lastIndexOf(10.0) == -1",
"self.intBackedDoubles.sum() == 8.0",
"self.intBackedDoubles.min() == 1.0",
"self.intBackedDoubles.max() == 3.0",
"self.emptyIntBackedDDoubles.sum() == 0.0",
"self.intBackedDoubles.isSorted()",
"self.emptyDoubles.isSorted()",
"self.unsortedIntBackedDDoubles.isSorted() == false",
"self.intBackedDoubles.indexOf(2.0) == 1",
"self.intBackedDoubles.lastIndexOf(2.0) == 2",
"self.intBackedDoubles.indexOf(10.0) == -1",
"self.intBackedDoubles.lastIndexOf(10.0) == -1",
"self.durations.sum() == duration('1h2m1s')",
"self.durations.min() == duration('1s')",
"self.durations.max() == duration('1h')",
"self.emptyDurations.sum() == duration('0')",
"self.durations.isSorted()",
"self.emptyDurations.isSorted()",
"self.unsortedDurations.isSorted() == false",
"self.durations.indexOf(duration('1m')) == 1",
"self.durations.lastIndexOf(duration('1m')) == 2",
"self.durations.indexOf(duration('2m')) == -1",
"self.durations.lastIndexOf(duration('2m')) == -1",
"self.strings.min() == 'a'",
"self.strings.max() == 'c'",
"self.strings.isSorted()",
"self.emptyStrings.isSorted()",
"self.unsortedStrings.isSorted() == false",
"self.strings.indexOf('b') == 1",
"self.strings.lastIndexOf('b') == 2",
"self.strings.indexOf('x') == -1",
"self.strings.lastIndexOf('x') == -1",
"self.dates.min() == timestamp('2000-01-01T00:00:00.000Z')",
"self.dates.max() == timestamp('2010-01-01T00:00:00.000Z')",
"self.dates.isSorted()",
"self.emptyDates.isSorted()",
"self.unsortedDates.isSorted() == false",
"self.dates.indexOf(timestamp('2000-02-01T00:00:00.000Z')) == 1",
"self.dates.lastIndexOf(timestamp('2000-02-01T00:00:00.000Z')) == 2",
"self.dates.indexOf(timestamp('2005-02-01T00:00:00.000Z')) == -1",
"self.dates.lastIndexOf(timestamp('2005-02-01T00:00:00.000Z')) == -1",
// array, map and object types use structural equality (aka "deep equals")
"[[1], [2]].indexOf([1]) == 0",
"[{'a': 1}, {'b': 2}].lastIndexOf({'b': 2}) == 1",
"self.objs.indexOf(self.objs[1]) == 1",
"self.objs.lastIndexOf(self.objs[1]) == 2",
// avoiding empty list error with min and max by appending an acceptable default minimum value
"([0] + self.emptyInts).min() == 0",
// handle CEL's dynamic dispatch appropriately (special cases to handle an empty list)
"dyn([]).sum() == 0",
"dyn([1, 2]).sum() == 3",
"dyn([1.0, 2.0]).sum() == 3.0",
// TODO: enable once type system fix it made to CEL
//"[].sum() == 0", // An empty list returns an 0 int
},
errors: map[string]string{
// return an error for min/max on empty list
"self.emptyInts.min() == 1": "min called on empty list",
"self.emptyInts.max() == 3": "max called on empty list",
"self.emptyDoubles.min() == 1.0": "min called on empty list",
"self.emptyDoubles.max() == 3.0": "max called on empty list",
"self.emptyStrings.min() == 'a'": "min called on empty list",
"self.emptyStrings.max() == 'c'": "max called on empty list",
// only allow sum on numeric types and duration
"['a', 'b'].sum() == 'c'": "found no matching overload for 'sum' applied to 'list(string).()", // compiler type checking error
// only allow min/max/indexOf/lastIndexOf on comparable types
"[[1], [2]].min() == [1]": "found no matching overload for 'min' applied to 'list(list(int)).()", // compiler type checking error
"[{'a': 1}, {'b': 2}].max() == {'b': 2}": "found no matching overload for 'max' applied to 'list(map(string, int)).()", // compiler type checking error
},
},
{name: "stdlib regex functions",
obj: map[string]interface{}{
"str": "this is a 123 string 456",
},
schema: objectTypePtr(map[string]schema.Structural{
"str": stringType,
}),
valid: []string{
"self.str.find('[0-9]+') == '123'",
"self.str.find('[0-9]+') != '456'",
"self.str.find('xyz') == ''",
"self.str.findAll('[0-9]+') == ['123', '456']",
"self.str.findAll('[0-9]+', 0) == []",
"self.str.findAll('[0-9]+', 1) == ['123']",
"self.str.findAll('[0-9]+', 2) == ['123', '456']",
"self.str.findAll('[0-9]+', 3) == ['123', '456']",
"self.str.findAll('[0-9]+', -1) == ['123', '456']",
"self.str.findAll('xyz') == []",
"self.str.findAll('xyz', 1) == []",
},
},
{name: "URL parsing",
obj: map[string]interface{}{
"url": "https://user:pass@kubernetes.io:80/docs/home?k1=a&k2=b&k2=c#anchor",
},
schema: objectTypePtr(map[string]schema.Structural{
"url": stringType,
}),
valid: []string{
"url('/path').getScheme() == ''",
"url('https://example.com/').getScheme() == 'https'",
"url('https://example.com:80/').getHost() == 'example.com:80'",
"url('https://example.com/').getHost() == 'example.com'",
"url('https://[::1]:80/').getHost() == '[::1]:80'",
"url('https://[::1]/').getHost() == '[::1]'",
"url('/path').getHost() == ''",
"url('https://example.com:80/').getHostname() == 'example.com'",
"url('https://127.0.0.1/').getHostname() == '127.0.0.1'",
"url('https://[::1]/').getHostname() == '::1'",
"url('/path').getHostname() == ''",
"url('https://example.com:80/').getPort() == '80'",
"url('https://example.com/').getPort() == ''",
"url('/path').getPort() == ''",
"url('https://example.com/path').getEscapedPath() == '/path'",
"url('https://example.com/with space/').getEscapedPath() == '/with%20space/'",
"url('https://example.com').getEscapedPath() == ''",
"url('https://example.com/path?k1=a&k2=b&k2=c').getQuery() == { 'k1': ['a'], 'k2': ['b', 'c']}",
"url('https://example.com/path?key with spaces=value with spaces').getQuery() == { 'key with spaces': ['value with spaces']}",
"url('https://example.com/path?').getQuery() == {}",
"url('https://example.com/path').getQuery() == {}",
// test with string input
"url(self.url).getScheme() == 'https'",
"url(self.url).getHost() == 'kubernetes.io:80'",
"url(self.url).getHostname() == 'kubernetes.io'",
"url(self.url).getPort() == '80'",
"url(self.url).getEscapedPath() == '/docs/home'",
"url(self.url).getQuery() == {'k1': ['a'], 'k2': ['b', 'c']}",
"isURL('https://user:pass@example.com:80/path?query=val#fragment')",
"isURL('/path') == true",
"isURL('https://a:b:c/') == false",
"isURL('../relative-path') == false",
},
},
}
for _, tt := range tests {

View File

@ -33,7 +33,7 @@ type Resolver interface {
func NewRegistry(stdExprEnv *cel.Env) *Registry {
return &Registry{
exprEnvs: map[string]*cel.Env{"": stdExprEnv},
schemas: map[string]*schema.Structural{},
schemas: map[string]*schema.Structural{},
types: map[string]*DeclType{
BoolType.TypeName(): BoolType,
BytesType.TypeName(): BytesType,

View File

@ -0,0 +1,76 @@
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package model
import (
"fmt"
"net/url"
"reflect"
"github.com/google/cel-go/checker/decls"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
)
// URL provides a CEL representation of a URL.
type URL struct {
*url.URL
}
var (
URLObject = decls.NewObjectType("kubernetes.URL")
typeValue = types.NewTypeValue("kubernetes.URL")
)
// ConvertToNative implements ref.Val.ConvertToNative.
func (d URL) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
if reflect.TypeOf(d.URL).AssignableTo(typeDesc) {
return d.URL, nil
}
if reflect.TypeOf("").AssignableTo(typeDesc) {
return d.URL.String(), nil
}
return nil, fmt.Errorf("type conversion error from 'URL' to '%v'", typeDesc)
}
// ConvertToType implements ref.Val.ConvertToType.
func (d URL) ConvertToType(typeVal ref.Type) ref.Val {
switch typeVal {
case typeValue:
return d
case types.TypeType:
return typeValue
}
return types.NewErr("type conversion error from '%s' to '%s'", typeValue, typeVal)
}
// Equal implements ref.Val.Equal.
func (d URL) Equal(other ref.Val) ref.Val {
otherDur, ok := other.(URL)
if !ok {
return types.MaybeNoSuchOverloadErr(other)
}
return types.Bool(d.URL.String() == otherDur.URL.String())
}
// Type implements ref.Val.Type.
func (d URL) Type() ref.Type {
return typeValue
}
// Value implements ref.Val.Value.
func (d URL) Value() interface{} {
return d.URL
}

1
vendor/modules.txt vendored
View File

@ -1312,6 +1312,7 @@ k8s.io/apiextensions-apiserver/pkg/apiserver
k8s.io/apiextensions-apiserver/pkg/apiserver/conversion
k8s.io/apiextensions-apiserver/pkg/apiserver/schema
k8s.io/apiextensions-apiserver/pkg/apiserver/schema/cel
k8s.io/apiextensions-apiserver/pkg/apiserver/schema/cel/library
k8s.io/apiextensions-apiserver/pkg/apiserver/schema/defaulting
k8s.io/apiextensions-apiserver/pkg/apiserver/schema/listtype
k8s.io/apiextensions-apiserver/pkg/apiserver/schema/objectmeta