From e1f9aa450b7ecd62ce7284486a159d14f66c1761 Mon Sep 17 00:00:00 2001 From: Joel Speed Date: Fri, 17 Nov 2023 17:34:46 +0000 Subject: [PATCH] Add costing estimations for IP and CIDR --- .../k8s.io/apiserver/pkg/cel/library/cost.go | 135 +++++++++ .../apiserver/pkg/cel/library/cost_test.go | 259 ++++++++++++++++++ 2 files changed, 394 insertions(+) diff --git a/staging/src/k8s.io/apiserver/pkg/cel/library/cost.go b/staging/src/k8s.io/apiserver/pkg/cel/library/cost.go index d18c138ec8f..a723937f19f 100644 --- a/staging/src/k8s.io/apiserver/pkg/cel/library/cost.go +++ b/staging/src/k8s.io/apiserver/pkg/cel/library/cost.go @@ -77,6 +77,74 @@ func (l *CostEstimator) CallCost(function, overloadId string, args []ref.Val, re // in length. regexCost := uint64(math.Ceil(float64(actualSize(args[1])) * common.RegexStringLengthCostFactor)) cost := strCost * regexCost + return &cost + } + case "cidr", "isIP", "isCIDR": + // IP and CIDR parsing is a string traversal. + if len(args) >= 1 { + cost := uint64(math.Ceil(float64(actualSize(args[0])) * common.StringTraversalCostFactor)) + return &cost + } + case "ip": + // IP and CIDR parsing is a string traversal. + if len(args) >= 1 { + if overloadId == "cidr_ip" { + // The IP member of the CIDR object is just accessing a field. + // Nominal cost. + cost := uint64(1) + return &cost + } + + cost := uint64(math.Ceil(float64(actualSize(args[0])) * common.StringTraversalCostFactor)) + return &cost + } + case "ip.isCanonical": + if len(args) >= 1 { + // We have to parse the string and then compare the parsed string to the original string. + // So we double the cost of parsing the string. + cost := uint64(math.Ceil(float64(actualSize(args[0])) * 2 * common.StringTraversalCostFactor)) + return &cost + } + case "masked", "prefixLength", "family", "isUnspecified", "isLoopback", "isLinkLocalMulticast", "isLinkLocalUnicast": + // IP and CIDR accessors are nominal cost. + cost := uint64(1) + return &cost + case "containsIP": + if len(args) >= 2 { + cidrSize := actualSize(args[0]) + otherSize := actualSize(args[1]) + + // This is the base cost of comparing two byte lists. + // We will compare only up to the length of the CIDR prefix in bytes, so use the cidrSize twice. + cost := uint64(math.Ceil(float64(cidrSize+cidrSize) * common.StringTraversalCostFactor)) + + if overloadId == "cidr_contains_ip_string" { + // If we are comparing a string, we must parse the string to into the right type, so add the cost of traversing the string again. + cost += uint64(math.Ceil(float64(otherSize) * common.StringTraversalCostFactor)) + + } + + return &cost + } + case "containsCIDR": + if len(args) >= 2 { + cidrSize := actualSize(args[0]) + otherSize := actualSize(args[1]) + + // This is the base cost of comparing two byte lists. + // We will compare only up to the length of the CIDR prefix in bytes, so use the cidrSize twice. + cost := uint64(math.Ceil(float64(cidrSize+cidrSize) * common.StringTraversalCostFactor)) + + // As we are comparing if a CIDR is within another CIDR, we first mask the base CIDR and + // also compare the CIDR bits. + // This has an additional cost of the length of the IP being traversed again, plus 1. + cost += uint64(math.Ceil(float64(cidrSize)*common.StringTraversalCostFactor)) + 1 + + if overloadId == "cidr_contains_cidr_string" { + // If we are comparing a string, we must parse the string to into the right type, so add the cost of traversing the string again. + cost += uint64(math.Ceil(float64(otherSize) * common.StringTraversalCostFactor)) + } + return &cost } } @@ -225,6 +293,73 @@ func (l *CostEstimator) EstimateCallCost(function, overloadId string, target *ch // worst case size of result is that every char is returned as separate find result. return &checker.CallEstimate{CostEstimate: strCost.Multiply(regexCost), ResultSize: &checker.SizeEstimate{Min: 0, Max: sz.Max}} } + case "cidr", "isIP", "isCIDR": + if target != nil { + sz := l.sizeEstimate(args[0]) + return &checker.CallEstimate{CostEstimate: sz.MultiplyByCostFactor(common.StringTraversalCostFactor)} + } + case "ip": + if target != nil && len(args) >= 1 { + if overloadId == "cidr_ip" { + // The IP member of the CIDR object is just accessing a field. + // Nominal cost. + return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 1, Max: 1}} + } + + sz := l.sizeEstimate(args[0]) + return &checker.CallEstimate{CostEstimate: sz.MultiplyByCostFactor(common.StringTraversalCostFactor)} + } else if target != nil { + // The IP member of a CIDR is a just accessing a field, nominal cost. + return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 1, Max: 1}} + } + case "ip.isCanonical": + if target != nil && len(args) >= 1 { + sz := l.sizeEstimate(args[0]) + // We have to parse the string and then compare the parsed string to the original string. + // So we double the cost of parsing the string. + return &checker.CallEstimate{CostEstimate: sz.MultiplyByCostFactor(2 * common.StringTraversalCostFactor)} + } + case "masked", "prefixLength", "family", "isUnspecified", "isLoopback", "isLinkLocalMulticast", "isLinkLocalUnicast": + // IP and CIDR accessors are nominal cost. + return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 1, Max: 1}} + case "containsIP": + if target != nil && len(args) >= 1 { + // The base cost of the function is the cost of comparing two byte lists. + // The byte lists will be either ipv4 or ipv6 so will have a length of 4, or 16 bytes. + sz := checker.SizeEstimate{Min: 4, Max: 16} + + // We have to compare the two strings to determine if the CIDR/IP is in the other CIDR. + ipCompCost := sz.Add(sz).MultiplyByCostFactor(common.StringTraversalCostFactor) + + if overloadId == "cidr_contains_ip_string" { + // If we are comparing a string, we must parse the string to into the right type, so add the cost of traversing the string again. + ipCompCost = ipCompCost.Add(checker.CostEstimate(l.sizeEstimate(args[0])).MultiplyByCostFactor(common.StringTraversalCostFactor)) + } + + return &checker.CallEstimate{CostEstimate: ipCompCost} + } + case "containsCIDR": + if target != nil && len(args) >= 1 { + // The base cost of the function is the cost of comparing two byte lists. + // The byte lists will be either ipv4 or ipv6 so will have a length of 4, or 16 bytes. + sz := checker.SizeEstimate{Min: 4, Max: 16} + + // We have to compare the two strings to determine if the CIDR/IP is in the other CIDR. + ipCompCost := sz.Add(sz).MultiplyByCostFactor(common.StringTraversalCostFactor) + + // As we are comparing if a CIDR is within another CIDR, we first mask the base CIDR and + // also compare the CIDR bits. + // This has an additional cost of the length of the IP being traversed again, plus 1. + ipCompCost = ipCompCost.Add(sz.MultiplyByCostFactor(common.StringTraversalCostFactor)) + ipCompCost = ipCompCost.Add(checker.CostEstimate{Min: 1, Max: 1}) + + if overloadId == "cidr_contains_cidr_string" { + // If we are comparing a string, we must parse the string to into the right type, so add the cost of traversing the string again. + ipCompCost = ipCompCost.Add(checker.CostEstimate(l.sizeEstimate(args[0])).MultiplyByCostFactor(common.StringTraversalCostFactor)) + } + + return &checker.CallEstimate{CostEstimate: ipCompCost} + } } return nil } diff --git a/staging/src/k8s.io/apiserver/pkg/cel/library/cost_test.go b/staging/src/k8s.io/apiserver/pkg/cel/library/cost_test.go index efe8b42d425..03ba0c504b5 100644 --- a/staging/src/k8s.io/apiserver/pkg/cel/library/cost_test.go +++ b/staging/src/k8s.io/apiserver/pkg/cel/library/cost_test.go @@ -215,6 +215,263 @@ func TestURLsCost(t *testing.T) { } } +func TestIPCost(t *testing.T) { + ipv4 := "ip('192.168.0.1')" + ipv4BaseEstimatedCost := checker.CostEstimate{Min: 2, Max: 2} + ipv4BaseRuntimeCost := uint64(2) + + ipv6 := "ip('2001:db8:3333:4444:5555:6666:7777:8888')" + ipv6BaseEstimatedCost := checker.CostEstimate{Min: 4, Max: 4} + ipv6BaseRuntimeCost := uint64(4) + + testCases := []struct { + ops []string + expectEsimatedCost func(checker.CostEstimate) checker.CostEstimate + expectRuntimeCost func(uint64) uint64 + }{ + { + // For just parsing the IP, the cost is expected to be the base. + ops: []string{""}, + expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate { return c }, + expectRuntimeCost: func(c uint64) uint64 { return c }, + }, + { + ops: []string{".family()", ".isUnspecified()", ".isLoopback()", ".isLinkLocalMulticast()", ".isLinkLocalUnicast()", ".isGlobalUnicast()"}, + // For most other operations, the cost is expected to be the base + 1. + expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate { + return checker.CostEstimate{Min: c.Min + 1, Max: c.Max + 1} + }, + expectRuntimeCost: func(c uint64) uint64 { return c + 1 }, + }, + } + + for _, tc := range testCases { + for _, op := range tc.ops { + t.Run(ipv4+op, func(t *testing.T) { + testCost(t, ipv4+op, tc.expectEsimatedCost(ipv4BaseEstimatedCost), tc.expectRuntimeCost(ipv4BaseRuntimeCost)) + }) + + t.Run(ipv6+op, func(t *testing.T) { + testCost(t, ipv6+op, tc.expectEsimatedCost(ipv6BaseEstimatedCost), tc.expectRuntimeCost(ipv6BaseRuntimeCost)) + }) + } + } +} + +func TestIPIsCanonicalCost(t *testing.T) { + testCases := []struct { + op string + expectEsimatedCost checker.CostEstimate + expectRuntimeCost uint64 + }{ + { + op: "ip.isCanonical('192.168.0.1')", + expectEsimatedCost: checker.CostEstimate{Min: 3, Max: 3}, + expectRuntimeCost: 3, + }, + { + op: "ip.isCanonical('2001:db8:3333:4444:5555:6666:7777:8888')", + expectEsimatedCost: checker.CostEstimate{Min: 8, Max: 8}, + expectRuntimeCost: 8, + }, + { + op: "ip.isCanonical('2001:db8::abcd')", + expectEsimatedCost: checker.CostEstimate{Min: 3, Max: 3}, + expectRuntimeCost: 3, + }, + } + + for _, tc := range testCases { + t.Run(tc.op, func(t *testing.T) { + testCost(t, tc.op, tc.expectEsimatedCost, tc.expectRuntimeCost) + }) + } +} + +func TestCIDRCost(t *testing.T) { + ipv4 := "cidr('192.168.0.0/16')" + ipv4BaseEstimatedCost := checker.CostEstimate{Min: 2, Max: 2} + ipv4BaseRuntimeCost := uint64(2) + + ipv6 := "cidr('2001:db8::/32')" + ipv6BaseEstimatedCost := checker.CostEstimate{Min: 2, Max: 2} + ipv6BaseRuntimeCost := uint64(2) + + type testCase struct { + ops []string + expectEsimatedCost func(checker.CostEstimate) checker.CostEstimate + expectRuntimeCost func(uint64) uint64 + } + + cases := []testCase{ + { + // For just parsing the IP, the cost is expected to be the base. + ops: []string{""}, + expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate { return c }, + expectRuntimeCost: func(c uint64) uint64 { return c }, + }, + { + ops: []string{".ip()", ".prefixLength()", ".masked()"}, + // For most other operations, the cost is expected to be the base + 1. + expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate { + return checker.CostEstimate{Min: c.Min + 1, Max: c.Max + 1} + }, + expectRuntimeCost: func(c uint64) uint64 { return c + 1 }, + }, + } + + //nolint:gocritic + ipv4Cases := append(cases, []testCase{ + { + ops: []string{".containsCIDR(cidr('192.0.0.0/30'))"}, + expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate { + return checker.CostEstimate{Min: c.Min + 5, Max: c.Max + 9} + }, + expectRuntimeCost: func(c uint64) uint64 { return c + 5 }, + }, + { + ops: []string{".containsCIDR(cidr('192.168.0.0/16'))"}, + expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate { + return checker.CostEstimate{Min: c.Min + 5, Max: c.Max + 9} + }, + expectRuntimeCost: func(c uint64) uint64 { return c + 5 }, + }, + { + ops: []string{".containsCIDR('192.0.0.0/30')"}, + expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate { + return checker.CostEstimate{Min: c.Min + 5, Max: c.Max + 9} + }, + expectRuntimeCost: func(c uint64) uint64 { return c + 5 }, + }, + { + ops: []string{".containsCIDR('192.168.0.0/16')"}, + expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate { + return checker.CostEstimate{Min: c.Min + 5, Max: c.Max + 9} + }, + expectRuntimeCost: func(c uint64) uint64 { return c + 5 }, + }, + { + ops: []string{".containsIP(ip('192.0.0.1'))"}, + expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate { + return checker.CostEstimate{Min: c.Min + 2, Max: c.Max + 5} + }, + expectRuntimeCost: func(c uint64) uint64 { return c + 2 }, + }, + { + ops: []string{".containsIP(ip('192.169.0.1'))"}, + expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate { + return checker.CostEstimate{Min: c.Min + 3, Max: c.Max + 6} + }, + expectRuntimeCost: func(c uint64) uint64 { return c + 3 }, + }, + { + ops: []string{".containsIP(ip('192.169.169.250'))"}, + expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate { + return checker.CostEstimate{Min: c.Min + 3, Max: c.Max + 6} + }, + expectRuntimeCost: func(c uint64) uint64 { return c + 3 }, + }, + { + ops: []string{".containsIP('192.0.0.1')"}, + expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate { + return checker.CostEstimate{Min: c.Min + 2, Max: c.Max + 5} + }, + expectRuntimeCost: func(c uint64) uint64 { return c + 2 }, + }, + { + ops: []string{".containsIP('192.169.0.1')"}, + expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate { + return checker.CostEstimate{Min: c.Min + 3, Max: c.Max + 6} + }, + expectRuntimeCost: func(c uint64) uint64 { return c + 3 }, + }, + }...) + + //nolint:gocritic + ipv6Cases := append(cases, []testCase{ + { + ops: []string{".containsCIDR(cidr('2001:db8::/126'))"}, + // For operations like checking if an IP is in a CIDR, the cost is expected to higher. + expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate { + return checker.CostEstimate{Min: c.Min + 5, Max: c.Max + 9} + }, + expectRuntimeCost: func(c uint64) uint64 { return c + 5 }, + }, + { + ops: []string{".containsCIDR(cidr('2001:db8::/32'))"}, + // For operations like checking if an IP is in a CIDR, the cost is expected to higher. + expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate { + return checker.CostEstimate{Min: c.Min + 5, Max: c.Max + 9} + }, + expectRuntimeCost: func(c uint64) uint64 { return c + 5 }, + }, + { + ops: []string{".containsCIDR('2001:db8::/126')"}, + // For operations like checking if an IP is in a CIDR, the cost is expected to higher. + expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate { + return checker.CostEstimate{Min: c.Min + 5, Max: c.Max + 9} + }, + expectRuntimeCost: func(c uint64) uint64 { return c + 5 }, + }, + { + ops: []string{".containsCIDR('2001:db8::/32')"}, + // For operations like checking if an IP is in a CIDR, the cost is expected to higher. + expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate { + return checker.CostEstimate{Min: c.Min + 5, Max: c.Max + 9} + }, + expectRuntimeCost: func(c uint64) uint64 { return c + 5 }, + }, + { + ops: []string{".containsIP(ip('2001:db8:3333:4444:5555:6666:7777:8888'))"}, + // For operations like checking if an IP is in a CIDR, the cost is expected to higher. + expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate { + return checker.CostEstimate{Min: c.Min + 5, Max: c.Max + 8} + }, + expectRuntimeCost: func(c uint64) uint64 { return c + 5 }, + }, + { + ops: []string{".containsIP(ip('2001:db8::1'))"}, + // For operations like checking if an IP is in a CIDR, the cost is expected to higher. + expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate { + return checker.CostEstimate{Min: c.Min + 3, Max: c.Max + 6} + }, + expectRuntimeCost: func(c uint64) uint64 { return c + 3 }, + }, + { + ops: []string{".containsIP('2001:db8:3333:4444:5555:6666:7777:8888')"}, + // For operations like checking if an IP is in a CIDR, the cost is expected to higher. + expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate { + return checker.CostEstimate{Min: c.Min + 5, Max: c.Max + 8} + }, + expectRuntimeCost: func(c uint64) uint64 { return c + 5 }, + }, + { + ops: []string{".containsIP('2001:db8::1')"}, + // For operations like checking if an IP is in a CIDR, the cost is expected to higher. + expectEsimatedCost: func(c checker.CostEstimate) checker.CostEstimate { + return checker.CostEstimate{Min: c.Min + 3, Max: c.Max + 6} + }, + expectRuntimeCost: func(c uint64) uint64 { return c + 3 }, + }, + }...) + + for _, tc := range ipv4Cases { + for _, op := range tc.ops { + t.Run(ipv4+op, func(t *testing.T) { + testCost(t, ipv4+op, tc.expectEsimatedCost(ipv4BaseEstimatedCost), tc.expectRuntimeCost(ipv4BaseRuntimeCost)) + }) + } + } + + for _, tc := range ipv6Cases { + for _, op := range tc.ops { + t.Run(ipv6+op, func(t *testing.T) { + testCost(t, ipv6+op, tc.expectEsimatedCost(ipv6BaseEstimatedCost), tc.expectRuntimeCost(ipv6BaseRuntimeCost)) + }) + } + } +} + func TestStringLibrary(t *testing.T) { cases := []struct { name string @@ -767,6 +1024,8 @@ func testCost(t *testing.T, expr string, expectEsimatedCost checker.CostEstimate Authz(), Quantity(), ext.Sets(), + IP(), + CIDR(), // cel-go v0.17.7 introduced CostEstimatorOptions. // Previous the presence has a cost of 0 but cel fixed it to 1. We still set to 0 here to avoid breaking changes. cel.CostEstimatorOptions(checker.PresenceTestHasCost(false)),