Merge pull request #121912 from JoelSpeed/cel-ip-addr

CEL library extensions for IP Address and CIDR network parsing
This commit is contained in:
Kubernetes Prow Robot 2023-12-16 11:06:04 +01:00 committed by GitHub
commit 76cd7521aa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 1785 additions and 1 deletions

View File

@ -0,0 +1,87 @@
/*
Copyright 2023 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package cel
import (
"fmt"
"math"
"net/netip"
"reflect"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
)
// CIDR provides a CEL representation of an network address.
type CIDR struct {
netip.Prefix
}
var (
CIDRType = cel.OpaqueType("net.CIDR")
)
// ConvertToNative implements ref.Val.ConvertToNative.
func (d CIDR) ConvertToNative(typeDesc reflect.Type) (any, error) {
if reflect.TypeOf(d.Prefix).AssignableTo(typeDesc) {
return d.Prefix, nil
}
if reflect.TypeOf("").AssignableTo(typeDesc) {
return d.Prefix.String(), nil
}
return nil, fmt.Errorf("type conversion error from 'CIDR' to '%v'", typeDesc)
}
// ConvertToType implements ref.Val.ConvertToType.
func (d CIDR) ConvertToType(typeVal ref.Type) ref.Val {
switch typeVal {
case CIDRType:
return d
case types.TypeType:
return CIDRType
case types.StringType:
return types.String(d.Prefix.String())
}
return types.NewErr("type conversion error from '%s' to '%s'", CIDRType, typeVal)
}
// Equal implements ref.Val.Equal.
func (d CIDR) Equal(other ref.Val) ref.Val {
otherD, ok := other.(CIDR)
if !ok {
return types.ValOrErr(other, "no such overload")
}
return types.Bool(d.Prefix == otherD.Prefix)
}
// Type implements ref.Val.Type.
func (d CIDR) Type() ref.Type {
return CIDRType
}
// Value implements ref.Val.Value.
func (d CIDR) Value() any {
return d.Prefix
}
// Size returns the size of the CIDR prefix address in bytes.
// Used in the size estimation of the runtime cost.
func (d CIDR) Size() ref.Val {
return types.Int(int(math.Ceil(float64(d.Prefix.Bits()) / 8)))
}

View File

@ -123,6 +123,13 @@ var baseOpts = []VersionedOptions{
ext.Sets(),
},
},
{
IntroducedVersion: version.MajorMinor(1, 30),
EnvOptions: []cel.EnvOption{
library.IP(),
library.CIDR(),
},
},
}
// MustBaseEnvSet returns the common CEL base environments for Kubernetes for Version, or panics

View File

@ -0,0 +1,86 @@
/*
Copyright 2023 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package cel
import (
"fmt"
"math"
"net/netip"
"reflect"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
)
// IP provides a CEL representation of an IP address.
type IP struct {
netip.Addr
}
var (
IPType = cel.OpaqueType("net.IP")
)
// ConvertToNative implements ref.Val.ConvertToNative.
func (d IP) ConvertToNative(typeDesc reflect.Type) (any, error) {
if reflect.TypeOf(d.Addr).AssignableTo(typeDesc) {
return d.Addr, nil
}
if reflect.TypeOf("").AssignableTo(typeDesc) {
return d.Addr.String(), nil
}
return nil, fmt.Errorf("type conversion error from 'IP' to '%v'", typeDesc)
}
// ConvertToType implements ref.Val.ConvertToType.
func (d IP) ConvertToType(typeVal ref.Type) ref.Val {
switch typeVal {
case IPType:
return d
case types.TypeType:
return IPType
case types.StringType:
return types.String(d.Addr.String())
}
return types.NewErr("type conversion error from '%s' to '%s'", IPType, typeVal)
}
// Equal implements ref.Val.Equal.
func (d IP) Equal(other ref.Val) ref.Val {
otherD, ok := other.(IP)
if !ok {
return types.ValOrErr(other, "no such overload")
}
return types.Bool(d.Addr == otherD.Addr)
}
// Type implements ref.Val.Type.
func (d IP) Type() ref.Type {
return IPType
}
// Value implements ref.Val.Value.
func (d IP) Value() any {
return d.Addr
}
// Size returns the size of the IP address in bytes.
// Used in the size estimation of the runtime cost.
func (d IP) Size() ref.Val {
return types.Int(int(math.Ceil(float64(d.Addr.BitLen()) / 8)))
}

View File

@ -0,0 +1,287 @@
/*
Copyright 2023 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package library
import (
"fmt"
"net/netip"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
apiservercel "k8s.io/apiserver/pkg/cel"
)
// CIDR provides a CEL function library extension of CIDR notation parsing functions.
//
// cidr
//
// Converts a string in CIDR notation to a network address representation or results in an error if the string is not a valid CIDR notation.
// The CIDR must be an IPv4 or IPv6 subnet address with a mask.
// Leading zeros in IPv4 address octets are not allowed.
// IPv4-mapped IPv6 addresses (e.g. ::ffff:1.2.3.4/24) are not allowed.
//
// cidr(<string>) <CIDR>
//
// Examples:
//
// cidr('192.168.0.0/16') // returns an IPv4 address with a CIDR mask
// cidr('::1/128') // returns an IPv6 address with a CIDR mask
// cidr('192.168.0.0/33') // error
// cidr('::1/129') // error
// cidr('192.168.0.1/16') // error, because there are non-0 bits after the prefix
//
// isCIDR
//
// Returns true if a string is a valid CIDR notation respresentation of a subnet with mask.
// The CIDR must be an IPv4 or IPv6 subnet address with a mask.
// Leading zeros in IPv4 address octets are not allowed.
// IPv4-mapped IPv6 addresses (e.g. ::ffff:1.2.3.4/24) are not allowed.
//
// isCIDR(<string>) <bool>
//
// Examples:
//
// isCIDR('192.168.0.0/16') // returns true
// isCIDR('::1/128') // returns true
// isCIDR('192.168.0.0/33') // returns false
// isCIDR('::1/129') // returns false
//
// containsIP / containerCIDR / ip / masked / prefixLength
//
// - containsIP: Returns true if a the CIDR contains the given IP address.
// The IP address must be an IPv4 or IPv6 address.
// May take either a string or IP address as an argument.
//
// - containsCIDR: Returns true if a the CIDR contains the given CIDR.
// The CIDR must be an IPv4 or IPv6 subnet address with a mask.
// May take either a string or CIDR as an argument.
//
// - ip: Returns the IP address representation of the CIDR.
//
// - masked: Returns the CIDR representation of the network address with a masked prefix.
// This can be used to return the canonical form of the CIDR network.
//
// - prefixLength: Returns the prefix length of the CIDR in bits.
// This is the number of bits in the mask.
//
// Examples:
//
// cidr('192.168.0.0/24').containsIP(ip('192.168.0.1')) // returns true
// cidr('192.168.0.0/24').containsIP(ip('192.168.1.1')) // returns false
// cidr('192.168.0.0/24').containsIP('192.168.0.1') // returns true
// cidr('192.168.0.0/24').containsIP('192.168.1.1') // returns false
// cidr('192.168.0.0/16').containsCIDR(cidr('192.168.10.0/24')) // returns true
// cidr('192.168.1.0/24').containsCIDR(cidr('192.168.2.0/24')) // returns false
// cidr('192.168.0.0/16').containsCIDR('192.168.10.0/24') // returns true
// cidr('192.168.1.0/24').containsCIDR('192.168.2.0/24') // returns false
// cidr('192.168.0.1/24').ip() // returns ipAddr('192.168.0.1')
// cidr('192.168.0.1/24').ip().family() // returns '4'
// cidr('::1/128').ip() // returns ipAddr('::1')
// cidr('::1/128').ip().family() // returns '6'
// cidr('192.168.0.0/24').masked() // returns cidr('192.168.0.0/24')
// cidr('192.168.0.1/24').masked() // returns cidr('192.168.0.0/24')
// cidr('192.168.0.0/24') == cidr('192.168.0.0/24').masked() // returns true, CIDR was already in canonical format
// cidr('192.168.0.1/24') == cidr('192.168.0.1/24').masked() // returns false, CIDR was not in canonical format
// cidr('192.168.0.0/16').prefixLength() // returns 16
// cidr('::1/128').prefixLength() // returns 128
func CIDR() cel.EnvOption {
return cel.Lib(cidrsLib)
}
var cidrsLib = &cidrs{}
type cidrs struct{}
func (*cidrs) LibraryName() string {
return "net.cidr"
}
var cidrLibraryDecls = map[string][]cel.FunctionOpt{
"cidr": {
cel.Overload("string_to_cidr", []*cel.Type{cel.StringType}, apiservercel.CIDRType,
cel.UnaryBinding(stringToCIDR)),
},
"containsIP": {
cel.MemberOverload("cidr_contains_ip_string", []*cel.Type{apiservercel.CIDRType, cel.StringType}, cel.BoolType,
cel.BinaryBinding(cidrContainsIPString)),
cel.MemberOverload("cidr_contains_ip_ip", []*cel.Type{apiservercel.CIDRType, apiservercel.IPType}, cel.BoolType,
cel.BinaryBinding(cidrContainsIP)),
},
"containsCIDR": {
cel.MemberOverload("cidr_contains_cidr_string", []*cel.Type{apiservercel.CIDRType, cel.StringType}, cel.BoolType,
cel.BinaryBinding(cidrContainsCIDRString)),
cel.MemberOverload("cidr_contains_cidr", []*cel.Type{apiservercel.CIDRType, apiservercel.CIDRType}, cel.BoolType,
cel.BinaryBinding(cidrContainsCIDR)),
},
"ip": {
cel.MemberOverload("cidr_ip", []*cel.Type{apiservercel.CIDRType}, apiservercel.IPType,
cel.UnaryBinding(cidrToIP)),
},
"prefixLength": {
cel.MemberOverload("cidr_prefix_length", []*cel.Type{apiservercel.CIDRType}, cel.IntType,
cel.UnaryBinding(prefixLength)),
},
"masked": {
cel.MemberOverload("cidr_masked", []*cel.Type{apiservercel.CIDRType}, apiservercel.CIDRType,
cel.UnaryBinding(masked)),
},
"isCIDR": {
cel.Overload("is_cidr", []*cel.Type{cel.StringType}, cel.BoolType,
cel.UnaryBinding(isCIDR)),
},
"string": {
cel.Overload("cidr_to_string", []*cel.Type{apiservercel.CIDRType}, cel.StringType,
cel.UnaryBinding(cidrToString)),
},
}
func (*cidrs) CompileOptions() []cel.EnvOption {
options := []cel.EnvOption{cel.Types(apiservercel.CIDRType),
cel.Variable(apiservercel.CIDRType.TypeName(), types.NewTypeTypeWithParam(apiservercel.CIDRType)),
}
for name, overloads := range cidrLibraryDecls {
options = append(options, cel.Function(name, overloads...))
}
return options
}
func (*cidrs) ProgramOptions() []cel.ProgramOption {
return []cel.ProgramOption{}
}
func stringToCIDR(arg ref.Val) ref.Val {
s, ok := arg.Value().(string)
if !ok {
return types.MaybeNoSuchOverloadErr(arg)
}
net, err := parseCIDR(s)
if err != nil {
return types.NewErr("network address parse error during conversion from string: %v", err)
}
return apiservercel.CIDR{
Prefix: net,
}
}
func cidrToString(arg ref.Val) ref.Val {
cidr, ok := arg.(apiservercel.CIDR)
if !ok {
return types.MaybeNoSuchOverloadErr(arg)
}
return types.String(cidr.Prefix.String())
}
func cidrContainsIPString(arg ref.Val, other ref.Val) ref.Val {
return cidrContainsIP(arg, stringToIP(other))
}
func cidrContainsCIDRString(arg ref.Val, other ref.Val) ref.Val {
return cidrContainsCIDR(arg, stringToCIDR(other))
}
func cidrContainsIP(arg ref.Val, other ref.Val) ref.Val {
cidr, ok := arg.(apiservercel.CIDR)
if !ok {
return types.MaybeNoSuchOverloadErr(other)
}
ip, ok := other.(apiservercel.IP)
if !ok {
return types.MaybeNoSuchOverloadErr(arg)
}
return types.Bool(cidr.Contains(ip.Addr))
}
func cidrContainsCIDR(arg ref.Val, other ref.Val) ref.Val {
cidr, ok := arg.(apiservercel.CIDR)
if !ok {
return types.MaybeNoSuchOverloadErr(arg)
}
containsCIDR, ok := other.(apiservercel.CIDR)
if !ok {
return types.MaybeNoSuchOverloadErr(other)
}
equalMasked := cidr.Prefix.Masked() == netip.PrefixFrom(containsCIDR.Prefix.Addr(), cidr.Prefix.Bits())
return types.Bool(equalMasked && cidr.Prefix.Bits() <= containsCIDR.Prefix.Bits())
}
func prefixLength(arg ref.Val) ref.Val {
cidr, ok := arg.(apiservercel.CIDR)
if !ok {
return types.MaybeNoSuchOverloadErr(arg)
}
return types.Int(cidr.Prefix.Bits())
}
func isCIDR(arg ref.Val) ref.Val {
s, ok := arg.Value().(string)
if !ok {
return types.MaybeNoSuchOverloadErr(arg)
}
_, err := parseCIDR(s)
return types.Bool(err == nil)
}
func cidrToIP(arg ref.Val) ref.Val {
cidr, ok := arg.(apiservercel.CIDR)
if !ok {
return types.MaybeNoSuchOverloadErr(arg)
}
return apiservercel.IP{
Addr: cidr.Prefix.Addr(),
}
}
func masked(arg ref.Val) ref.Val {
cidr, ok := arg.(apiservercel.CIDR)
if !ok {
return types.MaybeNoSuchOverloadErr(arg)
}
maskedCIDR := cidr.Prefix.Masked()
return apiservercel.CIDR{
Prefix: maskedCIDR,
}
}
// parseCIDR parses a string into an CIDR.
// We use this function to parse CIDR notation in the CEL library
// so that we can share the common logic of rejecting strings
// that IPv4-mapped IPv6 addresses or contain non-zero bits after the mask.
func parseCIDR(raw string) (netip.Prefix, error) {
net, err := netip.ParsePrefix(raw)
if err != nil {
return netip.Prefix{}, fmt.Errorf("network address parse error during conversion from string: %v", err)
}
if net.Addr().Is4In6() {
return netip.Prefix{}, fmt.Errorf("IPv4-mapped IPv6 address %q is not allowed", raw)
}
return net, nil
}

View File

@ -0,0 +1,276 @@
/*
Copyright 2023 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package library_test
import (
"net/netip"
"regexp"
"testing"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/stretchr/testify/require"
"k8s.io/apimachinery/pkg/util/sets"
apiservercel "k8s.io/apiserver/pkg/cel"
"k8s.io/apiserver/pkg/cel/library"
)
func testCIDR(t *testing.T, expr string, expectResult ref.Val, expectRuntimeErr string, expectCompileErrs []string) {
env, err := cel.NewEnv(
library.IP(),
library.CIDR(),
)
if err != nil {
t.Fatalf("%v", err)
}
compiled, issues := env.Compile(expr)
if len(expectCompileErrs) > 0 {
missingCompileErrs := []string{}
matchedCompileErrs := sets.New[int]()
for _, expectedCompileErr := range expectCompileErrs {
compiledPattern, err := regexp.Compile(expectedCompileErr)
if err != nil {
t.Fatalf("failed to compile expected err regex: %v", err)
}
didMatch := false
for i, compileError := range issues.Errors() {
if compiledPattern.Match([]byte(compileError.Message)) {
didMatch = true
matchedCompileErrs.Insert(i)
}
}
if !didMatch {
missingCompileErrs = append(missingCompileErrs, expectedCompileErr)
} else if len(matchedCompileErrs) != len(issues.Errors()) {
unmatchedErrs := []cel.Error{}
for i, issue := range issues.Errors() {
if !matchedCompileErrs.Has(i) {
unmatchedErrs = append(unmatchedErrs, *issue)
}
}
require.Empty(t, unmatchedErrs, "unexpected compilation errors")
}
}
require.Empty(t, missingCompileErrs, "expected compilation errors")
return
} else if len(issues.Errors()) > 0 {
t.Fatalf("%v", issues.Errors())
}
prog, err := env.Program(compiled)
if err != nil {
t.Fatalf("%v", err)
}
res, _, err := prog.Eval(map[string]interface{}{})
if len(expectRuntimeErr) > 0 {
if err == nil {
t.Fatalf("no runtime error thrown. Expected: %v", expectRuntimeErr)
} else if expectRuntimeErr != err.Error() {
t.Fatalf("unexpected err: %v", err)
}
} else if err != nil {
t.Fatalf("%v", err)
} else if expectResult != nil {
converted := res.Equal(expectResult).Value().(bool)
require.True(t, converted, "expectation not equal to output")
} else {
t.Fatal("expected result must not be nil")
}
}
func TestCIDR(t *testing.T) {
ipv4CIDR, _ := netip.ParsePrefix("192.168.0.0/24")
ipv4Addr, _ := netip.ParseAddr("192.168.0.0")
ipv6CIDR, _ := netip.ParsePrefix("2001:db8::/32")
ipv6Addr, _ := netip.ParseAddr("2001:db8::")
trueVal := types.Bool(true)
falseVal := types.Bool(false)
cases := []struct {
name string
expr string
expectResult ref.Val
expectRuntimeErr string
expectCompileErrs []string
}{
{
name: "parse ipv4",
expr: `cidr("192.168.0.0/24")`,
expectResult: apiservercel.CIDR{Prefix: ipv4CIDR},
},
{
name: "parse invalid ipv4",
expr: `cidr("192.168.0.0/")`,
expectRuntimeErr: "network address parse error during conversion from string: network address parse error during conversion from string: netip.ParsePrefix(\"192.168.0.0/\"): bad bits after slash: \"\"",
},
{
name: "contains IP ipv4 (IP)",
expr: `cidr("192.168.0.0/24").containsIP(ip("192.168.0.1"))`,
expectResult: trueVal,
},
{
name: "does not contain IP ipv4 (IP)",
expr: `cidr("192.168.0.0/24").containsIP(ip("192.168.1.1"))`,
expectResult: falseVal,
},
{
name: "contains IP ipv4 (string)",
expr: `cidr("192.168.0.0/24").containsIP("192.168.0.1")`,
expectResult: trueVal,
},
{
name: "does not contain IP ipv4 (string)",
expr: `cidr("192.168.0.0/24").containsIP("192.168.1.1")`,
expectResult: falseVal,
},
{
name: "contains CIDR ipv4 (CIDR)",
expr: `cidr("192.168.0.0/24").containsCIDR(cidr("192.168.0.0/25"))`,
expectResult: trueVal,
},
{
name: "does not contain IP ipv4 (CIDR)",
expr: `cidr("192.168.0.0/24").containsCIDR(cidr("192.168.0.0/23"))`,
expectResult: falseVal,
},
{
name: "contains CIDR ipv4 (string)",
expr: `cidr("192.168.0.0/24").containsCIDR("192.168.0.0/25")`,
expectResult: trueVal,
},
{
name: "does not contain CIDR ipv4 (string)",
expr: `cidr("192.168.0.0/24").containsCIDR("192.168.0.0/23")`,
expectResult: falseVal,
},
{
name: "returns IP ipv4",
expr: `cidr("192.168.0.0/24").ip()`,
expectResult: apiservercel.IP{Addr: ipv4Addr},
},
{
name: "masks masked ipv4",
expr: `cidr("192.168.0.0/24").masked()`,
expectResult: apiservercel.CIDR{Prefix: netip.PrefixFrom(ipv4Addr, 24)},
},
{
name: "masks unmasked ipv4",
expr: `cidr("192.168.0.1/24").masked()`,
expectResult: apiservercel.CIDR{Prefix: netip.PrefixFrom(ipv4Addr, 24)},
},
{
name: "returns prefix length ipv4",
expr: `cidr("192.168.0.0/24").prefixLength()`,
expectResult: types.Int(24),
},
{
name: "parse ipv6",
expr: `cidr("2001:db8::/32")`,
expectResult: apiservercel.CIDR{Prefix: ipv6CIDR},
},
{
name: "parse invalid ipv6",
expr: `cidr("2001:db8::/")`,
expectRuntimeErr: "network address parse error during conversion from string: network address parse error during conversion from string: netip.ParsePrefix(\"2001:db8::/\"): bad bits after slash: \"\"",
},
{
name: "contains IP ipv6 (IP)",
expr: `cidr("2001:db8::/32").containsIP(ip("2001:db8::1"))`,
expectResult: trueVal,
},
{
name: "does not contain IP ipv6 (IP)",
expr: `cidr("2001:db8::/32").containsIP(ip("2001:dc8::1"))`,
expectResult: falseVal,
},
{
name: "contains IP ipv6 (string)",
expr: `cidr("2001:db8::/32").containsIP("2001:db8::1")`,
expectResult: trueVal,
},
{
name: "does not contain IP ipv6 (string)",
expr: `cidr("2001:db8::/32").containsIP("2001:dc8::1")`,
expectResult: falseVal,
},
{
name: "contains CIDR ipv6 (CIDR)",
expr: `cidr("2001:db8::/32").containsCIDR(cidr("2001:db8::/33"))`,
expectResult: trueVal,
},
{
name: "does not contain IP ipv6 (CIDR)",
expr: `cidr("2001:db8::/32").containsCIDR(cidr("2001:db8::/31"))`,
expectResult: falseVal,
},
{
name: "contains CIDR ipv6 (string)",
expr: `cidr("2001:db8::/32").containsCIDR("2001:db8::/33")`,
expectResult: trueVal,
},
{
name: "does not contain CIDR ipv6 (string)",
expr: `cidr("2001:db8::/32").containsCIDR("2001:db8::/31")`,
expectResult: falseVal,
},
{
name: "returns IP ipv6",
expr: `cidr("2001:db8::/32").ip()`,
expectResult: apiservercel.IP{Addr: ipv6Addr},
},
{
name: "masks masked ipv6",
expr: `cidr("2001:db8::/32").masked()`,
expectResult: apiservercel.CIDR{Prefix: netip.PrefixFrom(ipv6Addr, 32)},
},
{
name: "masks unmasked ipv6",
expr: `cidr("2001:db8:1::/32").masked()`,
expectResult: apiservercel.CIDR{Prefix: netip.PrefixFrom(ipv6Addr, 32)},
},
{
name: "returns prefix length ipv6",
expr: `cidr("2001:db8::/32").prefixLength()`,
expectResult: types.Int(32),
},
{
name: "converting a CIDR to a string",
expr: `string(cidr("192.168.0.0/24"))`,
expectResult: types.String("192.168.0.0/24"),
},
{
name: "type of CIDR is net.CIDR",
expr: `type(cidr("192.168.0.0/24")) == net.CIDR`,
expectResult: trueVal,
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
testCIDR(t, tc.expr, tc.expectResult, tc.expectRuntimeErr, tc.expectCompileErrs)
})
}
}

View File

@ -77,6 +77,74 @@ func (l *CostEstimator) CallCost(function, overloadId string, args []ref.Val, re
// in length.
regexCost := uint64(math.Ceil(float64(actualSize(args[1])) * common.RegexStringLengthCostFactor))
cost := strCost * regexCost
return &cost
}
case "cidr", "isIP", "isCIDR":
// IP and CIDR parsing is a string traversal.
if len(args) >= 1 {
cost := uint64(math.Ceil(float64(actualSize(args[0])) * common.StringTraversalCostFactor))
return &cost
}
case "ip":
// IP and CIDR parsing is a string traversal.
if len(args) >= 1 {
if overloadId == "cidr_ip" {
// The IP member of the CIDR object is just accessing a field.
// Nominal cost.
cost := uint64(1)
return &cost
}
cost := uint64(math.Ceil(float64(actualSize(args[0])) * common.StringTraversalCostFactor))
return &cost
}
case "ip.isCanonical":
if len(args) >= 1 {
// We have to parse the string and then compare the parsed string to the original string.
// So we double the cost of parsing the string.
cost := uint64(math.Ceil(float64(actualSize(args[0])) * 2 * common.StringTraversalCostFactor))
return &cost
}
case "masked", "prefixLength", "family", "isUnspecified", "isLoopback", "isLinkLocalMulticast", "isLinkLocalUnicast":
// IP and CIDR accessors are nominal cost.
cost := uint64(1)
return &cost
case "containsIP":
if len(args) >= 2 {
cidrSize := actualSize(args[0])
otherSize := actualSize(args[1])
// This is the base cost of comparing two byte lists.
// We will compare only up to the length of the CIDR prefix in bytes, so use the cidrSize twice.
cost := uint64(math.Ceil(float64(cidrSize+cidrSize) * common.StringTraversalCostFactor))
if overloadId == "cidr_contains_ip_string" {
// If we are comparing a string, we must parse the string to into the right type, so add the cost of traversing the string again.
cost += uint64(math.Ceil(float64(otherSize) * common.StringTraversalCostFactor))
}
return &cost
}
case "containsCIDR":
if len(args) >= 2 {
cidrSize := actualSize(args[0])
otherSize := actualSize(args[1])
// This is the base cost of comparing two byte lists.
// We will compare only up to the length of the CIDR prefix in bytes, so use the cidrSize twice.
cost := uint64(math.Ceil(float64(cidrSize+cidrSize) * common.StringTraversalCostFactor))
// As we are comparing if a CIDR is within another CIDR, we first mask the base CIDR and
// also compare the CIDR bits.
// This has an additional cost of the length of the IP being traversed again, plus 1.
cost += uint64(math.Ceil(float64(cidrSize)*common.StringTraversalCostFactor)) + 1
if overloadId == "cidr_contains_cidr_string" {
// If we are comparing a string, we must parse the string to into the right type, so add the cost of traversing the string again.
cost += uint64(math.Ceil(float64(otherSize) * common.StringTraversalCostFactor))
}
return &cost
}
}
@ -225,6 +293,73 @@ func (l *CostEstimator) EstimateCallCost(function, overloadId string, target *ch
// worst case size of result is that every char is returned as separate find result.
return &checker.CallEstimate{CostEstimate: strCost.Multiply(regexCost), ResultSize: &checker.SizeEstimate{Min: 0, Max: sz.Max}}
}
case "cidr", "isIP", "isCIDR":
if target != nil {
sz := l.sizeEstimate(args[0])
return &checker.CallEstimate{CostEstimate: sz.MultiplyByCostFactor(common.StringTraversalCostFactor)}
}
case "ip":
if target != nil && len(args) >= 1 {
if overloadId == "cidr_ip" {
// The IP member of the CIDR object is just accessing a field.
// Nominal cost.
return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 1, Max: 1}}
}
sz := l.sizeEstimate(args[0])
return &checker.CallEstimate{CostEstimate: sz.MultiplyByCostFactor(common.StringTraversalCostFactor)}
} else if target != nil {
// The IP member of a CIDR is a just accessing a field, nominal cost.
return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 1, Max: 1}}
}
case "ip.isCanonical":
if target != nil && len(args) >= 1 {
sz := l.sizeEstimate(args[0])
// We have to parse the string and then compare the parsed string to the original string.
// So we double the cost of parsing the string.
return &checker.CallEstimate{CostEstimate: sz.MultiplyByCostFactor(2 * common.StringTraversalCostFactor)}
}
case "masked", "prefixLength", "family", "isUnspecified", "isLoopback", "isLinkLocalMulticast", "isLinkLocalUnicast":
// IP and CIDR accessors are nominal cost.
return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 1, Max: 1}}
case "containsIP":
if target != nil && len(args) >= 1 {
// The base cost of the function is the cost of comparing two byte lists.
// The byte lists will be either ipv4 or ipv6 so will have a length of 4, or 16 bytes.
sz := checker.SizeEstimate{Min: 4, Max: 16}
// We have to compare the two strings to determine if the CIDR/IP is in the other CIDR.
ipCompCost := sz.Add(sz).MultiplyByCostFactor(common.StringTraversalCostFactor)
if overloadId == "cidr_contains_ip_string" {
// If we are comparing a string, we must parse the string to into the right type, so add the cost of traversing the string again.
ipCompCost = ipCompCost.Add(checker.CostEstimate(l.sizeEstimate(args[0])).MultiplyByCostFactor(common.StringTraversalCostFactor))
}
return &checker.CallEstimate{CostEstimate: ipCompCost}
}
case "containsCIDR":
if target != nil && len(args) >= 1 {
// The base cost of the function is the cost of comparing two byte lists.
// The byte lists will be either ipv4 or ipv6 so will have a length of 4, or 16 bytes.
sz := checker.SizeEstimate{Min: 4, Max: 16}
// We have to compare the two strings to determine if the CIDR/IP is in the other CIDR.
ipCompCost := sz.Add(sz).MultiplyByCostFactor(common.StringTraversalCostFactor)
// As we are comparing if a CIDR is within another CIDR, we first mask the base CIDR and
// also compare the CIDR bits.
// This has an additional cost of the length of the IP being traversed again, plus 1.
ipCompCost = ipCompCost.Add(sz.MultiplyByCostFactor(common.StringTraversalCostFactor))
ipCompCost = ipCompCost.Add(checker.CostEstimate{Min: 1, Max: 1})
if overloadId == "cidr_contains_cidr_string" {
// If we are comparing a string, we must parse the string to into the right type, so add the cost of traversing the string again.
ipCompCost = ipCompCost.Add(checker.CostEstimate(l.sizeEstimate(args[0])).MultiplyByCostFactor(common.StringTraversalCostFactor))
}
return &checker.CallEstimate{CostEstimate: ipCompCost}
}
}
return nil
}

View File

@ -215,6 +215,263 @@ func TestURLsCost(t *testing.T) {
}
}
func TestIPCost(t *testing.T) {
ipv4 := "ip('192.168.0.1')"
ipv4BaseEstimatedCost := checker.CostEstimate{Min: 2, Max: 2}
ipv4BaseRuntimeCost := uint64(2)
ipv6 := "ip('2001:db8:3333:4444:5555:6666:7777:8888')"
ipv6BaseEstimatedCost := checker.CostEstimate{Min: 4, Max: 4}
ipv6BaseRuntimeCost := uint64(4)
testCases := []struct {
ops []string
expectEsimatedCost func(checker.CostEstimate) checker.CostEstimate
expectRuntimeCost func(uint64) uint64
}{
{
// For just parsing the IP, the cost is expected to be the base.
ops: []string{""},
expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate { return c },
expectRuntimeCost: func(c uint64) uint64 { return c },
},
{
ops: []string{".family()", ".isUnspecified()", ".isLoopback()", ".isLinkLocalMulticast()", ".isLinkLocalUnicast()", ".isGlobalUnicast()"},
// For most other operations, the cost is expected to be the base + 1.
expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate {
return checker.CostEstimate{Min: c.Min + 1, Max: c.Max + 1}
},
expectRuntimeCost: func(c uint64) uint64 { return c + 1 },
},
}
for _, tc := range testCases {
for _, op := range tc.ops {
t.Run(ipv4+op, func(t *testing.T) {
testCost(t, ipv4+op, tc.expectEsimatedCost(ipv4BaseEstimatedCost), tc.expectRuntimeCost(ipv4BaseRuntimeCost))
})
t.Run(ipv6+op, func(t *testing.T) {
testCost(t, ipv6+op, tc.expectEsimatedCost(ipv6BaseEstimatedCost), tc.expectRuntimeCost(ipv6BaseRuntimeCost))
})
}
}
}
func TestIPIsCanonicalCost(t *testing.T) {
testCases := []struct {
op string
expectEsimatedCost checker.CostEstimate
expectRuntimeCost uint64
}{
{
op: "ip.isCanonical('192.168.0.1')",
expectEsimatedCost: checker.CostEstimate{Min: 3, Max: 3},
expectRuntimeCost: 3,
},
{
op: "ip.isCanonical('2001:db8:3333:4444:5555:6666:7777:8888')",
expectEsimatedCost: checker.CostEstimate{Min: 8, Max: 8},
expectRuntimeCost: 8,
},
{
op: "ip.isCanonical('2001:db8::abcd')",
expectEsimatedCost: checker.CostEstimate{Min: 3, Max: 3},
expectRuntimeCost: 3,
},
}
for _, tc := range testCases {
t.Run(tc.op, func(t *testing.T) {
testCost(t, tc.op, tc.expectEsimatedCost, tc.expectRuntimeCost)
})
}
}
func TestCIDRCost(t *testing.T) {
ipv4 := "cidr('192.168.0.0/16')"
ipv4BaseEstimatedCost := checker.CostEstimate{Min: 2, Max: 2}
ipv4BaseRuntimeCost := uint64(2)
ipv6 := "cidr('2001:db8::/32')"
ipv6BaseEstimatedCost := checker.CostEstimate{Min: 2, Max: 2}
ipv6BaseRuntimeCost := uint64(2)
type testCase struct {
ops []string
expectEsimatedCost func(checker.CostEstimate) checker.CostEstimate
expectRuntimeCost func(uint64) uint64
}
cases := []testCase{
{
// For just parsing the IP, the cost is expected to be the base.
ops: []string{""},
expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate { return c },
expectRuntimeCost: func(c uint64) uint64 { return c },
},
{
ops: []string{".ip()", ".prefixLength()", ".masked()"},
// For most other operations, the cost is expected to be the base + 1.
expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate {
return checker.CostEstimate{Min: c.Min + 1, Max: c.Max + 1}
},
expectRuntimeCost: func(c uint64) uint64 { return c + 1 },
},
}
//nolint:gocritic
ipv4Cases := append(cases, []testCase{
{
ops: []string{".containsCIDR(cidr('192.0.0.0/30'))"},
expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate {
return checker.CostEstimate{Min: c.Min + 5, Max: c.Max + 9}
},
expectRuntimeCost: func(c uint64) uint64 { return c + 5 },
},
{
ops: []string{".containsCIDR(cidr('192.168.0.0/16'))"},
expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate {
return checker.CostEstimate{Min: c.Min + 5, Max: c.Max + 9}
},
expectRuntimeCost: func(c uint64) uint64 { return c + 5 },
},
{
ops: []string{".containsCIDR('192.0.0.0/30')"},
expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate {
return checker.CostEstimate{Min: c.Min + 5, Max: c.Max + 9}
},
expectRuntimeCost: func(c uint64) uint64 { return c + 5 },
},
{
ops: []string{".containsCIDR('192.168.0.0/16')"},
expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate {
return checker.CostEstimate{Min: c.Min + 5, Max: c.Max + 9}
},
expectRuntimeCost: func(c uint64) uint64 { return c + 5 },
},
{
ops: []string{".containsIP(ip('192.0.0.1'))"},
expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate {
return checker.CostEstimate{Min: c.Min + 2, Max: c.Max + 5}
},
expectRuntimeCost: func(c uint64) uint64 { return c + 2 },
},
{
ops: []string{".containsIP(ip('192.169.0.1'))"},
expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate {
return checker.CostEstimate{Min: c.Min + 3, Max: c.Max + 6}
},
expectRuntimeCost: func(c uint64) uint64 { return c + 3 },
},
{
ops: []string{".containsIP(ip('192.169.169.250'))"},
expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate {
return checker.CostEstimate{Min: c.Min + 3, Max: c.Max + 6}
},
expectRuntimeCost: func(c uint64) uint64 { return c + 3 },
},
{
ops: []string{".containsIP('192.0.0.1')"},
expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate {
return checker.CostEstimate{Min: c.Min + 2, Max: c.Max + 5}
},
expectRuntimeCost: func(c uint64) uint64 { return c + 2 },
},
{
ops: []string{".containsIP('192.169.0.1')"},
expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate {
return checker.CostEstimate{Min: c.Min + 3, Max: c.Max + 6}
},
expectRuntimeCost: func(c uint64) uint64 { return c + 3 },
},
}...)
//nolint:gocritic
ipv6Cases := append(cases, []testCase{
{
ops: []string{".containsCIDR(cidr('2001:db8::/126'))"},
// For operations like checking if an IP is in a CIDR, the cost is expected to higher.
expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate {
return checker.CostEstimate{Min: c.Min + 5, Max: c.Max + 9}
},
expectRuntimeCost: func(c uint64) uint64 { return c + 5 },
},
{
ops: []string{".containsCIDR(cidr('2001:db8::/32'))"},
// For operations like checking if an IP is in a CIDR, the cost is expected to higher.
expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate {
return checker.CostEstimate{Min: c.Min + 5, Max: c.Max + 9}
},
expectRuntimeCost: func(c uint64) uint64 { return c + 5 },
},
{
ops: []string{".containsCIDR('2001:db8::/126')"},
// For operations like checking if an IP is in a CIDR, the cost is expected to higher.
expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate {
return checker.CostEstimate{Min: c.Min + 5, Max: c.Max + 9}
},
expectRuntimeCost: func(c uint64) uint64 { return c + 5 },
},
{
ops: []string{".containsCIDR('2001:db8::/32')"},
// For operations like checking if an IP is in a CIDR, the cost is expected to higher.
expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate {
return checker.CostEstimate{Min: c.Min + 5, Max: c.Max + 9}
},
expectRuntimeCost: func(c uint64) uint64 { return c + 5 },
},
{
ops: []string{".containsIP(ip('2001:db8:3333:4444:5555:6666:7777:8888'))"},
// For operations like checking if an IP is in a CIDR, the cost is expected to higher.
expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate {
return checker.CostEstimate{Min: c.Min + 5, Max: c.Max + 8}
},
expectRuntimeCost: func(c uint64) uint64 { return c + 5 },
},
{
ops: []string{".containsIP(ip('2001:db8::1'))"},
// For operations like checking if an IP is in a CIDR, the cost is expected to higher.
expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate {
return checker.CostEstimate{Min: c.Min + 3, Max: c.Max + 6}
},
expectRuntimeCost: func(c uint64) uint64 { return c + 3 },
},
{
ops: []string{".containsIP('2001:db8:3333:4444:5555:6666:7777:8888')"},
// For operations like checking if an IP is in a CIDR, the cost is expected to higher.
expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate {
return checker.CostEstimate{Min: c.Min + 5, Max: c.Max + 8}
},
expectRuntimeCost: func(c uint64) uint64 { return c + 5 },
},
{
ops: []string{".containsIP('2001:db8::1')"},
// For operations like checking if an IP is in a CIDR, the cost is expected to higher.
expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate {
return checker.CostEstimate{Min: c.Min + 3, Max: c.Max + 6}
},
expectRuntimeCost: func(c uint64) uint64 { return c + 3 },
},
}...)
for _, tc := range ipv4Cases {
for _, op := range tc.ops {
t.Run(ipv4+op, func(t *testing.T) {
testCost(t, ipv4+op, tc.expectEsimatedCost(ipv4BaseEstimatedCost), tc.expectRuntimeCost(ipv4BaseRuntimeCost))
})
}
}
for _, tc := range ipv6Cases {
for _, op := range tc.ops {
t.Run(ipv6+op, func(t *testing.T) {
testCost(t, ipv6+op, tc.expectEsimatedCost(ipv6BaseEstimatedCost), tc.expectRuntimeCost(ipv6BaseRuntimeCost))
})
}
}
}
func TestStringLibrary(t *testing.T) {
cases := []struct {
name string
@ -767,6 +1024,8 @@ func testCost(t *testing.T, expr string, expectEsimatedCost checker.CostEstimate
Authz(),
Quantity(),
ext.Sets(),
IP(),
CIDR(),
// cel-go v0.17.7 introduced CostEstimatorOptions.
// Previous the presence has a cost of 0 but cel fixed it to 1. We still set to 0 here to avoid breaking changes.
cel.CostEstimatorOptions(checker.PresenceTestHasCost(false)),

View File

@ -0,0 +1,329 @@
/*
Copyright 2023 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package library
import (
"fmt"
"net/netip"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
apiservercel "k8s.io/apiserver/pkg/cel"
)
// IP provides a CEL function library extension of IP address parsing functions.
//
// ip
//
// Converts a string to an IP address or results in an error if the string is not a valid IP address.
// The IP address must be an IPv4 or IPv6 address.
// IPv4-mapped IPv6 addresses (e.g. ::ffff:1.2.3.4) are not allowed.
// IP addresses with zones (e.g. fe80::1%eth0) are not allowed.
// Leading zeros in IPv4 address octets are not allowed.
//
// ip(<string>) <IPAddr>
//
// Examples:
//
// ip('127.0.0.1') // returns an IPv4 address
// ip('::1') // returns an IPv6 address
// ip('127.0.0.256') // error
// ip(':::1') // error
//
// isIP
//
// Returns true if a string is a valid IP address.
// The IP address must be an IPv4 or IPv6 address.
// IPv4-mapped IPv6 addresses (e.g. ::ffff:1.2.3.4) are not allowed.
// IP addresses with zones (e.g. fe80::1%eth0) are not allowed.
// Leading zeros in IPv4 address octets are not allowed.
//
// isIP(<string>) <bool>
//
// Examples:
//
// isIP('127.0.0.1') // returns true
// isIP('::1') // returns true
// isIP('127.0.0.256') // returns false
// isIP(':::1') // returns false
//
// ip.isCanonical
//
// Returns true if the IP address is in its canonical form.
// There is exactly one canonical form for every IP address, so fields containing
// IPs in canonical form can just be treated as strings when checking for equality or uniqueness.
//
// ip.isCanonical(<string>) <bool>
//
// Examples:
//
// ip.isCanonical('127.0.0.1') // returns true; all valid IPv4 addresses are canonical
// ip.isCanonical('2001:db8::abcd') // returns true
// ip.isCanonical('2001:DB8::ABCD') // returns false
// ip.isCanonical('2001:db8::0:0:0:abcd') // returns false
//
// family / isUnspecified / isLoopback / isLinkLocalMulticast / isLinkLocalUnicast / isGlobalUnicast
//
// - family: returns the IP addresses' family (IPv4 or IPv6) as an integer, either '4' or '6'.
//
// - isUnspecified: returns true if the IP address is the unspecified address.
// Either the IPv4 address "0.0.0.0" or the IPv6 address "::".
//
// - isLoopback: returns true if the IP address is the loopback address.
// Either an IPv4 address with a value of 127.x.x.x or an IPv6 address with a value of ::1.
//
// - isLinkLocalMulticast: returns true if the IP address is a link-local multicast address.
// Either an IPv4 address with a value of 224.0.0.x or an IPv6 address in the network ff00::/8.
//
// - isLinkLocalUnicast: returns true if the IP address is a link-local unicast address.
// Either an IPv4 address with a value of 169.254.x.x or an IPv6 address in the network fe80::/10.
//
// - isGlobalUnicast: returns true if the IP address is a global unicast address.
// Either an IPv4 address that is not zero or 255.255.255.255 or an IPv6 address that is not a link-local unicast, loopback or multicast address.
//
// Examples:
//
// ip('127.0.0.1').family() // returns '4”
// ip('::1').family() // returns '6'
// ip('127.0.0.1').family() == 4 // returns true
// ip('::1').family() == 6 // returns true
// ip('0.0.0.0').isUnspecified() // returns true
// ip('127.0.0.1').isUnspecified() // returns false
// ip('::').isUnspecified() // returns true
// ip('::1').isUnspecified() // returns false
// ip('127.0.0.1').isLoopback() // returns true
// ip('192.168.0.1').isLoopback() // returns false
// ip('::1').isLoopback() // returns true
// ip('2001:db8::abcd').isLoopback() // returns false
// ip('224.0.0.1').isLinkLocalMulticast() // returns true
// ip('224.0.1.1').isLinkLocalMulticast() // returns false
// ip('ff02::1').isLinkLocalMulticast() // returns true
// ip('fd00::1').isLinkLocalMulticast() // returns false
// ip('169.254.169.254').isLinkLocalUnicast() // returns true
// ip('192.168.0.1').isLinkLocalUnicast() // returns false
// ip('fe80::1').isLinkLocalUnicast() // returns true
// ip('fd80::1').isLinkLocalUnicast() // returns false
// ip('192.168.0.1').isGlobalUnicast() // returns true
// ip('255.255.255.255').isGlobalUnicast() // returns false
// ip('2001:db8::abcd').isGlobalUnicast() // returns true
// ip('ff00::1').isGlobalUnicast() // returns false
func IP() cel.EnvOption {
return cel.Lib(ipLib)
}
var ipLib = &ip{}
type ip struct{}
func (*ip) LibraryName() string {
return "net.ip"
}
var ipLibraryDecls = map[string][]cel.FunctionOpt{
"ip": {
cel.Overload("string_to_ip", []*cel.Type{cel.StringType}, apiservercel.IPType,
cel.UnaryBinding(stringToIP)),
},
"family": {
cel.MemberOverload("ip_family", []*cel.Type{apiservercel.IPType}, cel.IntType,
cel.UnaryBinding(family)),
},
"ip.isCanonical": {
cel.Overload("ip_is_canonical", []*cel.Type{cel.StringType}, cel.BoolType,
cel.UnaryBinding(ipIsCanonical)),
},
"isUnspecified": {
cel.MemberOverload("ip_is_unspecified", []*cel.Type{apiservercel.IPType}, cel.BoolType,
cel.UnaryBinding(isUnspecified)),
},
"isLoopback": {
cel.MemberOverload("ip_is_loopback", []*cel.Type{apiservercel.IPType}, cel.BoolType,
cel.UnaryBinding(isLoopback)),
},
"isLinkLocalMulticast": {
cel.MemberOverload("ip_is_link_local_multicast", []*cel.Type{apiservercel.IPType}, cel.BoolType,
cel.UnaryBinding(isLinkLocalMulticast)),
},
"isLinkLocalUnicast": {
cel.MemberOverload("ip_is_link_local_unicast", []*cel.Type{apiservercel.IPType}, cel.BoolType,
cel.UnaryBinding(isLinkLocalUnicast)),
},
"isGlobalUnicast": {
cel.MemberOverload("ip_is_global_unicast", []*cel.Type{apiservercel.IPType}, cel.BoolType,
cel.UnaryBinding(isGlobalUnicast)),
},
"isIP": {
cel.Overload("is_ip", []*cel.Type{cel.StringType}, cel.BoolType,
cel.UnaryBinding(isIP)),
},
"string": {
cel.Overload("ip_to_string", []*cel.Type{apiservercel.IPType}, cel.StringType,
cel.UnaryBinding(ipToString)),
},
}
func (*ip) CompileOptions() []cel.EnvOption {
options := []cel.EnvOption{cel.Types(apiservercel.IPType),
cel.Variable(apiservercel.IPType.TypeName(), types.NewTypeTypeWithParam(apiservercel.IPType)),
}
for name, overloads := range ipLibraryDecls {
options = append(options, cel.Function(name, overloads...))
}
return options
}
func (*ip) ProgramOptions() []cel.ProgramOption {
return []cel.ProgramOption{}
}
func stringToIP(arg ref.Val) ref.Val {
s, ok := arg.Value().(string)
if !ok {
return types.MaybeNoSuchOverloadErr(arg)
}
addr, err := parseIPAddr(s)
if err != nil {
// Don't add context, we control the error message already.
return types.NewErr("%v", err)
}
return apiservercel.IP{
Addr: addr,
}
}
func ipToString(arg ref.Val) ref.Val {
ip, ok := arg.(apiservercel.IP)
if !ok {
return types.MaybeNoSuchOverloadErr(arg)
}
return types.String(ip.Addr.String())
}
func family(arg ref.Val) ref.Val {
ip, ok := arg.(apiservercel.IP)
if !ok {
return types.MaybeNoSuchOverloadErr(arg)
}
switch {
case ip.Addr.Is4():
return types.Int(4)
case ip.Addr.Is6():
return types.Int(6)
default:
return types.NewErr("IP address %q is not an IPv4 or IPv6 address", ip.Addr.String())
}
}
func ipIsCanonical(arg ref.Val) ref.Val {
s, ok := arg.Value().(string)
if !ok {
return types.MaybeNoSuchOverloadErr(arg)
}
addr, err := parseIPAddr(s)
if err != nil {
// Don't add context, we control the error message already.
return types.NewErr("%v", err)
}
// Addr.String() always returns the canonical form of the IP address.
// Therefore comparing this with the original string representation
// will tell us if the IP address is in its canonical form.
return types.Bool(addr.String() == s)
}
func isIP(arg ref.Val) ref.Val {
s, ok := arg.Value().(string)
if !ok {
return types.MaybeNoSuchOverloadErr(arg)
}
_, err := parseIPAddr(s)
return types.Bool(err == nil)
}
func isUnspecified(arg ref.Val) ref.Val {
ip, ok := arg.(apiservercel.IP)
if !ok {
return types.MaybeNoSuchOverloadErr(arg)
}
return types.Bool(ip.Addr.IsUnspecified())
}
func isLoopback(arg ref.Val) ref.Val {
ip, ok := arg.(apiservercel.IP)
if !ok {
return types.MaybeNoSuchOverloadErr(arg)
}
return types.Bool(ip.Addr.IsLoopback())
}
func isLinkLocalMulticast(arg ref.Val) ref.Val {
ip, ok := arg.(apiservercel.IP)
if !ok {
return types.MaybeNoSuchOverloadErr(arg)
}
return types.Bool(ip.Addr.IsLinkLocalMulticast())
}
func isLinkLocalUnicast(arg ref.Val) ref.Val {
ip, ok := arg.(apiservercel.IP)
if !ok {
return types.MaybeNoSuchOverloadErr(arg)
}
return types.Bool(ip.Addr.IsLinkLocalUnicast())
}
func isGlobalUnicast(arg ref.Val) ref.Val {
ip, ok := arg.(apiservercel.IP)
if !ok {
return types.MaybeNoSuchOverloadErr(arg)
}
return types.Bool(ip.Addr.IsGlobalUnicast())
}
// parseIPAddr parses a string into an IP address.
// We use this function to parse IP addresses in the CEL library
// so that we can share the common logic of rejecting IP addresses
// that contain zones or are IPv4-mapped IPv6 addresses.
func parseIPAddr(raw string) (netip.Addr, error) {
addr, err := netip.ParseAddr(raw)
if err != nil {
return netip.Addr{}, fmt.Errorf("IP Address %q parse error during conversion from string: %v", raw, err)
}
if addr.Zone() != "" {
return netip.Addr{}, fmt.Errorf("IP address %q with zone value is not allowed", raw)
}
if addr.Is4In6() {
return netip.Addr{}, fmt.Errorf("IPv4-mapped IPv6 address %q is not allowed", raw)
}
return addr, nil
}

View File

@ -0,0 +1,316 @@
/*
Copyright 2023 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package library_test
import (
"net/netip"
"regexp"
"testing"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/stretchr/testify/require"
"k8s.io/apimachinery/pkg/util/sets"
apiservercel "k8s.io/apiserver/pkg/cel"
"k8s.io/apiserver/pkg/cel/library"
)
func testIP(t *testing.T, expr string, expectResult ref.Val, expectRuntimeErr string, expectCompileErrs []string) {
env, err := cel.NewEnv(
library.IP(),
library.CIDR(),
)
if err != nil {
t.Fatalf("%v", err)
}
compiled, issues := env.Compile(expr)
if len(expectCompileErrs) > 0 {
missingCompileErrs := []string{}
matchedCompileErrs := sets.New[int]()
for _, expectedCompileErr := range expectCompileErrs {
compiledPattern, err := regexp.Compile(expectedCompileErr)
if err != nil {
t.Fatalf("failed to compile expected err regex: %v", err)
}
didMatch := false
for i, compileError := range issues.Errors() {
if compiledPattern.Match([]byte(compileError.Message)) {
didMatch = true
matchedCompileErrs.Insert(i)
}
}
if !didMatch {
missingCompileErrs = append(missingCompileErrs, expectedCompileErr)
} else if len(matchedCompileErrs) != len(issues.Errors()) {
unmatchedErrs := []cel.Error{}
for i, issue := range issues.Errors() {
if !matchedCompileErrs.Has(i) {
unmatchedErrs = append(unmatchedErrs, *issue)
}
}
require.Empty(t, unmatchedErrs, "unexpected compilation errors")
}
}
require.Empty(t, missingCompileErrs, "expected compilation errors")
return
} else if len(issues.Errors()) > 0 {
t.Fatalf("%v", issues.Errors())
}
prog, err := env.Program(compiled)
if err != nil {
t.Fatalf("%v", err)
}
res, _, err := prog.Eval(map[string]interface{}{})
if len(expectRuntimeErr) > 0 {
if err == nil {
t.Fatalf("no runtime error thrown. Expected: %v", expectRuntimeErr)
} else if expectRuntimeErr != err.Error() {
t.Fatalf("unexpected err: %v", err)
}
} else if err != nil {
t.Fatalf("%v", err)
} else if expectResult != nil {
converted := res.Equal(expectResult).Value().(bool)
require.True(t, converted, "expectation not equal to output")
} else {
t.Fatal("expected result must not be nil")
}
}
func TestIP(t *testing.T) {
ipv4Addr, _ := netip.ParseAddr("192.168.0.1")
int4 := types.Int(4)
ipv6Addr, _ := netip.ParseAddr("2001:db8::68")
int6 := types.Int(6)
trueVal := types.Bool(true)
falseVal := types.Bool(false)
cases := []struct {
name string
expr string
expectResult ref.Val
expectRuntimeErr string
expectCompileErrs []string
}{
{
name: "parse ipv4",
expr: `ip("192.168.0.1")`,
expectResult: apiservercel.IP{Addr: ipv4Addr},
},
{
name: "parse invalid ipv4",
expr: `ip("192.168.0.1.0")`,
expectRuntimeErr: "IP Address \"192.168.0.1.0\" parse error during conversion from string: ParseAddr(\"192.168.0.1.0\"): IPv4 address too long",
},
{
name: "isIP valid ipv4",
expr: `isIP("192.168.0.1")`,
expectResult: trueVal,
},
{
name: "isIP invalid ipv4",
expr: `isIP("192.168.0.1.0")`,
expectResult: falseVal,
},
{
name: "ip.isCanonical valid ipv4",
expr: `ip.isCanonical("127.0.0.1")`,
expectResult: trueVal,
},
{
name: "ip.isCanonical invalid ipv4",
expr: `ip.isCanonical("127.0.0.1.0")`,
expectRuntimeErr: "IP Address \"127.0.0.1.0\" parse error during conversion from string: ParseAddr(\"127.0.0.1.0\"): IPv4 address too long",
},
{
name: "ipv4 family",
expr: `ip("192.168.0.1").family()`,
expectResult: int4,
},
{
name: "ipv4 isUnspecified true",
expr: `ip("0.0.0.0").isUnspecified()`,
expectResult: trueVal,
},
{
name: "ipv4 isUnspecified false",
expr: `ip("127.0.0.1").isUnspecified()`,
expectResult: falseVal,
},
{
name: "ipv4 isLoopback true",
expr: `ip("127.0.0.1").isLoopback()`,
expectResult: trueVal,
},
{
name: "ipv4 isLoopback false",
expr: `ip("1.2.3.4").isLoopback()`,
expectResult: falseVal,
},
{
name: "ipv4 isLinkLocalMulticast true",
expr: `ip("224.0.0.1").isLinkLocalMulticast()`,
expectResult: trueVal,
},
{
name: "ipv4 isLinkLocalMulticast false",
expr: `ip("224.0.1.1").isLinkLocalMulticast()`,
expectResult: falseVal,
},
{
name: "ipv4 isLinkLocalUnicast true",
expr: `ip("169.254.169.254").isLinkLocalUnicast()`,
expectResult: trueVal,
},
{
name: "ipv4 isLinkLocalUnicast false",
expr: `ip("192.168.0.1").isLinkLocalUnicast()`,
expectResult: falseVal,
},
{
name: "ipv4 isGlobalUnicast true",
expr: `ip("192.168.0.1").isGlobalUnicast()`,
expectResult: trueVal,
},
{
name: "ipv4 isGlobalUnicast false",
expr: `ip("255.255.255.255").isGlobalUnicast()`,
expectResult: falseVal,
},
{
name: "parse ipv6",
expr: `ip("2001:db8::68")`,
expectResult: apiservercel.IP{Addr: ipv6Addr},
},
{
name: "parse invalid ipv6",
expr: `ip("2001:db8:::68")`,
expectRuntimeErr: "IP Address \"2001:db8:::68\" parse error during conversion from string: ParseAddr(\"2001:db8:::68\"): each colon-separated field must have at least one digit (at \":68\")",
},
{
name: "isIP valid ipv6",
expr: `isIP("2001:db8::68")`,
expectResult: trueVal,
},
{
name: "isIP invalid ipv4",
expr: `isIP("2001:db8:::68")`,
expectResult: falseVal,
},
{
name: "ip.isCanonical valid ipv6",
expr: `ip.isCanonical("2001:db8::68")`,
expectResult: trueVal,
},
{
name: "ip.isCanonical non-canonical ipv6",
expr: `ip.isCanonical("2001:DB8::68")`,
expectResult: falseVal,
},
{
name: "ip.isCanonical invalid ipv6",
expr: `ip.isCanonical("2001:db8:::68")`,
expectRuntimeErr: "IP Address \"2001:db8:::68\" parse error during conversion from string: ParseAddr(\"2001:db8:::68\"): each colon-separated field must have at least one digit (at \":68\")",
},
{
name: "ipv6 family",
expr: `ip("2001:db8::68").family()`,
expectResult: int6,
},
{
name: "ipv6 isUnspecified true",
expr: `ip("::").isUnspecified()`,
expectResult: trueVal,
},
{
name: "ipv6 isUnspecified false",
expr: `ip("::1").isUnspecified()`,
expectResult: falseVal,
},
{
name: "ipv6 isLoopback true",
expr: `ip("::1").isLoopback()`,
expectResult: trueVal,
},
{
name: "ipv6 isLoopback false",
expr: `ip("2001:db8::abcd").isLoopback()`,
expectResult: falseVal,
},
{
name: "ipv6 isLinkLocalMulticast true",
expr: `ip("ff02::1").isLinkLocalMulticast()`,
expectResult: trueVal,
},
{
name: "ipv6 isLinkLocalMulticast false",
expr: `ip("fd00::1").isLinkLocalMulticast()`,
expectResult: falseVal,
},
{
name: "ipv6 isLinkLocalUnicast true",
expr: `ip("fe80::1").isLinkLocalUnicast()`,
expectResult: trueVal,
},
{
name: "ipv6 isLinkLocalUnicast false",
expr: `ip("fd80::1").isLinkLocalUnicast()`,
expectResult: falseVal,
},
{
name: "ipv6 isGlobalUnicast true",
expr: `ip("2001:db8::abcd").isGlobalUnicast()`,
expectResult: trueVal,
},
{
name: "ipv6 isGlobalUnicast false",
expr: `ip("ff00::1").isGlobalUnicast()`,
expectResult: falseVal,
},
{
name: "passing cidr into isIP returns compile error",
expr: `isIP(cidr("192.168.0.0/24"))`,
expectCompileErrs: []string{"found no matching overload for 'isIP' applied to '\\(net.CIDR\\)'"},
},
{
name: "converting an IP address to a string",
expr: `string(ip("192.168.0.1"))`,
expectResult: types.String("192.168.0.1"),
},
{
name: "type of IP is net.IP",
expr: `type(ip("192.168.0.1")) == net.IP`,
expectResult: trueVal,
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
testIP(t, tc.expr, tc.expectResult, tc.expectRuntimeErr, tc.expectCompileErrs)
})
}
}

View File

@ -26,7 +26,7 @@ import (
func TestLibraryCompatibility(t *testing.T) {
var libs []map[string][]cel.FunctionOpt
libs = append(libs, authzLibraryDecls, listsLibraryDecls, regexLibraryDecls, urlLibraryDecls, quantityLibraryDecls)
libs = append(libs, authzLibraryDecls, listsLibraryDecls, regexLibraryDecls, urlLibraryDecls, quantityLibraryDecls, ipLibraryDecls, cidrLibraryDecls)
functionNames := sets.New[string]()
for _, lib := range libs {
for name := range lib {
@ -47,6 +47,8 @@ func TestLibraryCompatibility(t *testing.T) {
"errored", "error",
// Kubernetes <1.29>:
"add", "asApproximateFloat", "asInteger", "compareTo", "isGreaterThan", "isInteger", "isLessThan", "isQuantity", "quantity", "sign", "sub",
// Kubernetes <1.30>:
"ip", "family", "isUnspecified", "isLoopback", "isLinkLocalMulticast", "isLinkLocalUnicast", "isGlobalUnicast", "ip.isCanonical", "isIP", "cidr", "containsIP", "containsCIDR", "masked", "prefixLength", "isCIDR", "string",
// Kubernetes <1.??>:
)