Add costing estimations for IP and CIDR

This commit is contained in:
Joel Speed 2023-11-17 17:34:46 +00:00
parent 4710f085b3
commit e1f9aa450b
No known key found for this signature in database
GPG Key ID: 6E80578D6751DEFB
2 changed files with 394 additions and 0 deletions

View File

@ -77,6 +77,74 @@ func (l *CostEstimator) CallCost(function, overloadId string, args []ref.Val, re
// in length.
regexCost := uint64(math.Ceil(float64(actualSize(args[1])) * common.RegexStringLengthCostFactor))
cost := strCost * regexCost
return &cost
}
case "cidr", "isIP", "isCIDR":
// IP and CIDR parsing is a string traversal.
if len(args) >= 1 {
cost := uint64(math.Ceil(float64(actualSize(args[0])) * common.StringTraversalCostFactor))
return &cost
}
case "ip":
// IP and CIDR parsing is a string traversal.
if len(args) >= 1 {
if overloadId == "cidr_ip" {
// The IP member of the CIDR object is just accessing a field.
// Nominal cost.
cost := uint64(1)
return &cost
}
cost := uint64(math.Ceil(float64(actualSize(args[0])) * common.StringTraversalCostFactor))
return &cost
}
case "ip.isCanonical":
if len(args) >= 1 {
// We have to parse the string and then compare the parsed string to the original string.
// So we double the cost of parsing the string.
cost := uint64(math.Ceil(float64(actualSize(args[0])) * 2 * common.StringTraversalCostFactor))
return &cost
}
case "masked", "prefixLength", "family", "isUnspecified", "isLoopback", "isLinkLocalMulticast", "isLinkLocalUnicast":
// IP and CIDR accessors are nominal cost.
cost := uint64(1)
return &cost
case "containsIP":
if len(args) >= 2 {
cidrSize := actualSize(args[0])
otherSize := actualSize(args[1])
// This is the base cost of comparing two byte lists.
// We will compare only up to the length of the CIDR prefix in bytes, so use the cidrSize twice.
cost := uint64(math.Ceil(float64(cidrSize+cidrSize) * common.StringTraversalCostFactor))
if overloadId == "cidr_contains_ip_string" {
// If we are comparing a string, we must parse the string into the right type, so add the cost of traversing the string again.
cost += uint64(math.Ceil(float64(otherSize) * common.StringTraversalCostFactor))
}
return &cost
}
case "containsCIDR":
if len(args) >= 2 {
cidrSize := actualSize(args[0])
otherSize := actualSize(args[1])
// This is the base cost of comparing two byte lists.
// We will compare only up to the length of the CIDR prefix in bytes, so use the cidrSize twice.
cost := uint64(math.Ceil(float64(cidrSize+cidrSize) * common.StringTraversalCostFactor))
// As we are comparing if a CIDR is within another CIDR, we first mask the base CIDR and
// also compare the CIDR bits.
// This has an additional cost of the length of the IP being traversed again, plus 1.
cost += uint64(math.Ceil(float64(cidrSize)*common.StringTraversalCostFactor)) + 1
if overloadId == "cidr_contains_cidr_string" {
// If we are comparing a string, we must parse the string into the right type, so add the cost of traversing the string again.
cost += uint64(math.Ceil(float64(otherSize) * common.StringTraversalCostFactor))
}
return &cost
}
}
@ -225,6 +293,73 @@ func (l *CostEstimator) EstimateCallCost(function, overloadId string, target *ch
// worst case size of result is that every char is returned as separate find result.
return &checker.CallEstimate{CostEstimate: strCost.Multiply(regexCost), ResultSize: &checker.SizeEstimate{Min: 0, Max: sz.Max}}
}
case "cidr", "isIP", "isCIDR":
if target != nil {
sz := l.sizeEstimate(args[0])
return &checker.CallEstimate{CostEstimate: sz.MultiplyByCostFactor(common.StringTraversalCostFactor)}
}
case "ip":
if target != nil && len(args) >= 1 {
if overloadId == "cidr_ip" {
// The IP member of the CIDR object is just accessing a field.
// Nominal cost.
return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 1, Max: 1}}
}
sz := l.sizeEstimate(args[0])
return &checker.CallEstimate{CostEstimate: sz.MultiplyByCostFactor(common.StringTraversalCostFactor)}
} else if target != nil {
// The IP member of a CIDR is just accessing a field, nominal cost.
return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 1, Max: 1}}
}
case "ip.isCanonical":
if target != nil && len(args) >= 1 {
sz := l.sizeEstimate(args[0])
// We have to parse the string and then compare the parsed string to the original string.
// So we double the cost of parsing the string.
return &checker.CallEstimate{CostEstimate: sz.MultiplyByCostFactor(2 * common.StringTraversalCostFactor)}
}
case "masked", "prefixLength", "family", "isUnspecified", "isLoopback", "isLinkLocalMulticast", "isLinkLocalUnicast":
// IP and CIDR accessors are nominal cost.
return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 1, Max: 1}}
case "containsIP":
if target != nil && len(args) >= 1 {
// The base cost of the function is the cost of comparing two byte lists.
// The byte lists will be either ipv4 or ipv6 so will have a length of 4, or 16 bytes.
sz := checker.SizeEstimate{Min: 4, Max: 16}
// We have to compare the two strings to determine if the CIDR/IP is in the other CIDR.
ipCompCost := sz.Add(sz).MultiplyByCostFactor(common.StringTraversalCostFactor)
if overloadId == "cidr_contains_ip_string" {
// If we are comparing a string, we must parse the string into the right type, so add the cost of traversing the string again.
ipCompCost = ipCompCost.Add(checker.CostEstimate(l.sizeEstimate(args[0])).MultiplyByCostFactor(common.StringTraversalCostFactor))
}
return &checker.CallEstimate{CostEstimate: ipCompCost}
}
case "containsCIDR":
if target != nil && len(args) >= 1 {
// The base cost of the function is the cost of comparing two byte lists.
// The byte lists will be either ipv4 or ipv6 so will have a length of 4, or 16 bytes.
sz := checker.SizeEstimate{Min: 4, Max: 16}
// We have to compare the two strings to determine if the CIDR/IP is in the other CIDR.
ipCompCost := sz.Add(sz).MultiplyByCostFactor(common.StringTraversalCostFactor)
// As we are comparing if a CIDR is within another CIDR, we first mask the base CIDR and
// also compare the CIDR bits.
// This has an additional cost of the length of the IP being traversed again, plus 1.
ipCompCost = ipCompCost.Add(sz.MultiplyByCostFactor(common.StringTraversalCostFactor))
ipCompCost = ipCompCost.Add(checker.CostEstimate{Min: 1, Max: 1})
if overloadId == "cidr_contains_cidr_string" {
// If we are comparing a string, we must parse the string into the right type, so add the cost of traversing the string again.
ipCompCost = ipCompCost.Add(checker.CostEstimate(l.sizeEstimate(args[0])).MultiplyByCostFactor(common.StringTraversalCostFactor))
}
return &checker.CallEstimate{CostEstimate: ipCompCost}
}
}
return nil
}

View File

@ -215,6 +215,263 @@ func TestURLsCost(t *testing.T) {
}
}
// TestIPCost checks the estimated and runtime cost of ip() expressions for one
// IPv4 and one IPv6 literal: parsing alone, and parsing followed by each of
// the accessor calls listed below.
func TestIPCost(t *testing.T) {
	// address is an ip() expression together with the cost expected for
	// evaluating that expression on its own.
	type address struct {
		expr      string
		estimated checker.CostEstimate
		runtime   uint64
	}
	addresses := []address{
		{
			expr:      "ip('192.168.0.1')",
			estimated: checker.CostEstimate{Min: 2, Max: 2},
			runtime:   2,
		},
		{
			expr:      "ip('2001:db8:3333:4444:5555:6666:7777:8888')",
			estimated: checker.CostEstimate{Min: 4, Max: 4},
			runtime:   4,
		},
	}

	// scenario groups expression suffixes that are all expected to add the
	// same extra cost on top of the base parsing cost.
	type scenario struct {
		suffixes []string
		extra    uint64
	}
	scenarios := []scenario{
		{
			// Parsing the IP alone costs exactly the base.
			suffixes: []string{""},
			extra:    0,
		},
		{
			// Each accessor is expected to cost the base + 1.
			suffixes: []string{".family()", ".isUnspecified()", ".isLoopback()", ".isLinkLocalMulticast()", ".isLinkLocalUnicast()", ".isGlobalUnicast()"},
			extra:    1,
		},
	}

	for _, sc := range scenarios {
		for _, suffix := range sc.suffixes {
			// The IPv4 subtest runs before the IPv6 one for every suffix.
			for _, addr := range addresses {
				expr := addr.expr + suffix
				wantEstimate := checker.CostEstimate{
					Min: addr.estimated.Min + sc.extra,
					Max: addr.estimated.Max + sc.extra,
				}
				wantRuntime := addr.runtime + sc.extra
				t.Run(expr, func(t *testing.T) {
					testCost(t, expr, wantEstimate, wantRuntime)
				})
			}
		}
	}
}
// TestIPIsCanonicalCost checks the cost of ip.isCanonical() on string
// literals. For every literal here the estimate is exact (Min == Max) and
// equal to the runtime cost.
func TestIPIsCanonicalCost(t *testing.T) {
	// exact builds the estimate for a cost that is fully known up front.
	exact := func(cost uint64) checker.CostEstimate {
		return checker.CostEstimate{Min: cost, Max: cost}
	}

	type canonicalCase struct {
		expr string
		cost uint64
	}
	for _, tc := range []canonicalCase{
		{expr: "ip.isCanonical('192.168.0.1')", cost: 3},
		{expr: "ip.isCanonical('2001:db8:3333:4444:5555:6666:7777:8888')", cost: 8},
		{expr: "ip.isCanonical('2001:db8::abcd')", cost: 3},
	} {
		t.Run(tc.expr, func(t *testing.T) {
			testCost(t, tc.expr, exact(tc.cost), tc.cost)
		})
	}
}
// TestCIDRCost checks the estimated and runtime cost of cidr() expressions for
// one IPv4 and one IPv6 network: parsing alone, the accessor calls, and
// containsCIDR/containsIP with both typed and string arguments.
func TestCIDRCost(t *testing.T) {
	// testCase is an expression suffix plus the cost it is expected to add on
	// top of parsing the base CIDR.
	type testCase struct {
		suffix       string
		minExtra     uint64 // added to the base estimate's Min
		maxExtra     uint64 // added to the base estimate's Max
		runtimeExtra uint64 // added to the base runtime cost
	}
	// network is a base cidr() expression, the cost expected for evaluating it
	// on its own, and the cases to run against it.
	type network struct {
		expr      string
		estimated checker.CostEstimate
		runtime   uint64
		cases     []testCase
	}

	// Cases that apply to both address families.
	shared := []testCase{
		// Parsing the CIDR alone costs exactly the base.
		{"", 0, 0, 0},
		// Each accessor is expected to cost the base + 1.
		{".ip()", 1, 1, 1},
		{".prefixLength()", 1, 1, 1},
		{".masked()", 1, 1, 1},
	}
	// withShared returns a freshly allocated slice holding the shared cases
	// followed by extra, so the two families never share a backing array.
	withShared := func(extra ...testCase) []testCase {
		all := make([]testCase, 0, len(shared)+len(extra))
		all = append(all, shared...)
		return append(all, extra...)
	}

	networks := []network{
		{
			expr:      "cidr('192.168.0.0/16')",
			estimated: checker.CostEstimate{Min: 2, Max: 2},
			runtime:   2,
			cases: withShared(
				testCase{".containsCIDR(cidr('192.0.0.0/30'))", 5, 9, 5},
				testCase{".containsCIDR(cidr('192.168.0.0/16'))", 5, 9, 5},
				testCase{".containsCIDR('192.0.0.0/30')", 5, 9, 5},
				testCase{".containsCIDR('192.168.0.0/16')", 5, 9, 5},
				testCase{".containsIP(ip('192.0.0.1'))", 2, 5, 2},
				testCase{".containsIP(ip('192.169.0.1'))", 3, 6, 3},
				testCase{".containsIP(ip('192.169.169.250'))", 3, 6, 3},
				testCase{".containsIP('192.0.0.1')", 2, 5, 2},
				testCase{".containsIP('192.169.0.1')", 3, 6, 3},
			),
		},
		{
			expr:      "cidr('2001:db8::/32')",
			estimated: checker.CostEstimate{Min: 2, Max: 2},
			runtime:   2,
			// For operations like checking whether an IP or CIDR is inside a
			// CIDR, the cost is expected to be higher.
			cases: withShared(
				testCase{".containsCIDR(cidr('2001:db8::/126'))", 5, 9, 5},
				testCase{".containsCIDR(cidr('2001:db8::/32'))", 5, 9, 5},
				testCase{".containsCIDR('2001:db8::/126')", 5, 9, 5},
				testCase{".containsCIDR('2001:db8::/32')", 5, 9, 5},
				testCase{".containsIP(ip('2001:db8:3333:4444:5555:6666:7777:8888'))", 5, 8, 5},
				testCase{".containsIP(ip('2001:db8::1'))", 3, 6, 3},
				testCase{".containsIP('2001:db8:3333:4444:5555:6666:7777:8888')", 5, 8, 5},
				testCase{".containsIP('2001:db8::1')", 3, 6, 3},
			),
		},
	}

	// All IPv4 subtests run first, followed by all IPv6 subtests.
	for _, nw := range networks {
		for _, tc := range nw.cases {
			expr := nw.expr + tc.suffix
			wantEstimate := checker.CostEstimate{
				Min: nw.estimated.Min + tc.minExtra,
				Max: nw.estimated.Max + tc.maxExtra,
			}
			wantRuntime := nw.runtime + tc.runtimeExtra
			t.Run(expr, func(t *testing.T) {
				testCost(t, expr, wantEstimate, wantRuntime)
			})
		}
	}
}
func TestStringLibrary(t *testing.T) {
cases := []struct {
name string
@ -767,6 +1024,8 @@ func testCost(t *testing.T, expr string, expectEsimatedCost checker.CostEstimate
Authz(),
Quantity(),
ext.Sets(),
IP(),
CIDR(),
// cel-go v0.17.7 introduced CostEstimatorOptions.
// Previous the presence has a cost of 0 but cel fixed it to 1. We still set to 0 here to avoid breaking changes.
cel.CostEstimatorOptions(checker.PresenceTestHasCost(false)),