diff --git a/pkg/util/iptree/iptree.go b/pkg/util/iptree/iptree.go new file mode 100644 index 00000000000..441062c58d2 --- /dev/null +++ b/pkg/util/iptree/iptree.go @@ -0,0 +1,679 @@ +/* +Copyright 2023 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package iptree + +import ( + "fmt" + "math/bits" + "net/netip" +) + +// iptree implement a radix tree that uses IP prefixes as nodes and allows to store values in each node. +// Example: +// +// r := New[int]() +// +// prefixes := []string{ +// "0.0.0.0/0", +// "10.0.0.0/8", +// "10.0.0.0/16", +// "10.1.0.0/16", +// "10.1.1.0/24", +// "10.1.244.0/24", +// "10.0.0.0/24", +// "10.0.0.3/32", +// "192.168.0.0/24", +// "192.168.0.0/28", +// "192.168.129.0/28", +// } +// for _, k := range prefixes { +// r.InsertPrefix(netip.MustParsePrefix(k), 0) +// } +// +// (*) means the node is not public, is not storing any value +// +// 0.0.0.0/0 --- 10.0.0.0/8 --- *10.0.0.0/15 --- 10.0.0.0/16 --- 10.0.0.0/24 --- 10.0.0.3/32 +// | | +// | \ -------- 10.1.0.0/16 --- 10.1.1.0/24 +// | | +// | \ ------- 10.1.244.0/24 +// | +// \------ *192.168.0.0/16 --- 192.168.0.0/24 --- 192.168.0.0/28 +// | +// \ -------- 192.168.129.0/28 + +// node is an element of radix tree with a netip.Prefix optimized to store IP prefixes. +type node[T any] struct { + // prefix network CIDR + prefix netip.Prefix + // public nodes are used to store values + public bool + val T + + child [2]*node[T] // binary tree +} + +// mergeChild allow to compress the tree +// when n has exactly one child and no value +// p -> n -> b -> c ==> p -> b -> c +func (n *node[T]) mergeChild() { + // public nodes can not be merged + if n.public { + return + } + // can not merge if there are two children + if n.child[0] != nil && + n.child[1] != nil { + return + } + // can not merge if there are no children + if n.child[0] == nil && + n.child[1] == nil { + return + } + // find the child and merge it + var child *node[T] + if n.child[0] != nil { + child = n.child[0] + } else if n.child[1] != nil { + child = n.child[1] + } + n.prefix = child.prefix + n.public = child.public + n.val = child.val + n.child = child.child + // remove any references from the deleted node + // to avoid memory leak + child.child[0] = nil + child.child[1] = nil +} + +// Tree is a radix tree for IPv4 and IPv6 networks. +type Tree[T any] struct { + rootV4 *node[T] + rootV6 *node[T] +} + +// New creates a new Radix Tree for IP addresses. +func New[T any]() *Tree[T] { + return &Tree[T]{ + rootV4: &node[T]{ + prefix: netip.PrefixFrom(netip.IPv4Unspecified(), 0), + }, + rootV6: &node[T]{ + prefix: netip.PrefixFrom(netip.IPv6Unspecified(), 0), + }, + } +} + +// GetPrefix returns the stored value and true if the exact prefix exists in the tree. +func (t *Tree[T]) GetPrefix(prefix netip.Prefix) (T, bool) { + var zeroT T + + n := t.rootV4 + if prefix.Addr().Is6() { + n = t.rootV6 + } + bitPosition := 0 + // mask the address for sanity + address := prefix.Masked().Addr() + // we can't check longer than the request mask + mask := prefix.Bits() + // walk the network bits of the prefix + for bitPosition < mask { + // Look for a child checking the bit position after the mask + n = n.child[getBitFromAddr(address, bitPosition+1)] + if n == nil { + return zeroT, false + } + // check we are in the right branch comparing the suffixes + if !n.prefix.Contains(address) { + return zeroT, false + } + // update the new bit position with the new node mask + bitPosition = n.prefix.Bits() + } + // check if this node is a public node and contains a prefix + if n != nil && n.public && n.prefix == prefix { + return n.val, true + } + + return zeroT, false +} + +// LongestPrefixMatch returns the longest prefix match, the stored value and true if exist. +// For example, considering the following prefixes 192.168.20.16/28 and 192.168.0.0/16, +// when the address 192.168.20.19/32 is looked up it will return 192.168.20.16/28. +func (t *Tree[T]) LongestPrefixMatch(prefix netip.Prefix) (netip.Prefix, T, bool) { + n := t.rootV4 + if prefix.Addr().Is6() { + n = t.rootV6 + } + + var last *node[T] + // bit position is given by the mask bits + bitPosition := 0 + // mask the address + address := prefix.Masked().Addr() + mask := prefix.Bits() + // walk the network bits of the prefix + for bitPosition < mask { + if n.public { + last = n + } + // Look for a child checking the bit position after the mask + n = n.child[getBitFromAddr(address, bitPosition+1)] + if n == nil { + break + } + // check we are in the right branch comparing the suffixes + if !n.prefix.Contains(address) { + break + } + // update the new bit position with the new node mask + bitPosition = n.prefix.Bits() + } + + if n != nil && n.public && n.prefix == prefix { + last = n + } + + if last != nil { + return last.prefix, last.val, true + } + var zeroT T + return netip.Prefix{}, zeroT, false +} + +// ShortestPrefixMatch returns the shortest prefix match, the stored value and true if exist. +// For example, considering the following prefixes 192.168.20.16/28 and 192.168.0.0/16, +// when the address 192.168.20.19/32 is looked up it will return 192.168.0.0/16. +func (t *Tree[T]) ShortestPrefixMatch(prefix netip.Prefix) (netip.Prefix, T, bool) { + var zeroT T + + n := t.rootV4 + if prefix.Addr().Is6() { + n = t.rootV6 + } + // bit position is given by the mask bits + bitPosition := 0 + // mask the address + address := prefix.Masked().Addr() + mask := prefix.Bits() + for bitPosition < mask { + if n.public { + return n.prefix, n.val, true + } + // Look for a child checking the bit position after the mask + n = n.child[getBitFromAddr(address, bitPosition+1)] + if n == nil { + return netip.Prefix{}, zeroT, false + } + // check we are in the right branch comparing the suffixes + if !n.prefix.Contains(address) { + return netip.Prefix{}, zeroT, false + } + // update the new bit position with the new node mask + bitPosition = n.prefix.Bits() + } + + if n != nil && n.public && n.prefix == prefix { + return n.prefix, n.val, true + } + return netip.Prefix{}, zeroT, false +} + +// InsertPrefix is used to add a new entry or update +// an existing entry. Returns true if updated. +func (t *Tree[T]) InsertPrefix(prefix netip.Prefix, v T) bool { + n := t.rootV4 + if prefix.Addr().Is6() { + n = t.rootV6 + } + var parent *node[T] + // bit position is given by the mask bits + bitPosition := 0 + // mask the address + address := prefix.Masked().Addr() + mask := prefix.Bits() + for bitPosition < mask { + // Look for a child checking the bit position after the mask + childIndex := getBitFromAddr(address, bitPosition+1) + parent = n + n = n.child[childIndex] + // if no child create a new one with + if n == nil { + parent.child[childIndex] = &node[T]{ + public: true, + val: v, + prefix: prefix, + } + return false + } + + // update the new bit position with the new node mask + bitPosition = n.prefix.Bits() + + // continue if we are in the right branch and current + // node is our parent + if n.prefix.Contains(address) && bitPosition <= mask { + continue + } + + // Split the node and add a new child: + // - Case 1: parent -> child -> n + // - Case 2: parent -> newnode |--> child + // |--> n + child := &node[T]{ + prefix: prefix, + public: true, + val: v, + } + // Case 1: existing node is a sibling + if prefix.Contains(n.prefix.Addr()) && bitPosition > mask { + // parent to child + parent.child[childIndex] = child + pos := prefix.Bits() + 1 + // calculate if the sibling is at the left or right + child.child[getBitFromAddr(n.prefix.Addr(), pos)] = n + return false + } + + // Case 2: existing node has the same mask but different base address + // add common ancestor and branch on it + ancestor := findAncestor(prefix, n.prefix) + link := &node[T]{ + prefix: ancestor, + } + pos := parent.prefix.Bits() + 1 + parent.child[getBitFromAddr(ancestor.Addr(), pos)] = link + // ancestor -> children + pos = ancestor.Bits() + 1 + idxChild := getBitFromAddr(prefix.Addr(), pos) + idxN := getBitFromAddr(n.prefix.Addr(), pos) + if idxChild == idxN { + panic(fmt.Sprintf("wrong ancestor %s: child %s N %s", ancestor.String(), prefix.String(), n.prefix.String())) + } + link.child[idxChild] = child + link.child[idxN] = n + return false + } + + // if already exist update it and make it public + if n != nil && n.prefix == prefix { + if n.public { + n.val = v + n.public = true + return true + } + n.val = v + n.public = true + return false + } + + return false +} + +// DeletePrefix delete the exact prefix and return true if it existed. +func (t *Tree[T]) DeletePrefix(prefix netip.Prefix) bool { + root := t.rootV4 + if prefix.Addr().Is6() { + root = t.rootV6 + } + var parent *node[T] + n := root + // bit position is given by the mask bits + bitPosition := 0 + // mask the address + address := prefix.Masked().Addr() + mask := prefix.Bits() + for bitPosition < mask { + // Look for a child checking the bit position after the mask + parent = n + n = n.child[getBitFromAddr(address, bitPosition+1)] + if n == nil { + return false + } + // check we are in the right branch comparing the suffixes + if !n.prefix.Contains(address) { + return false + } + // update the new bit position with the new node mask + bitPosition = n.prefix.Bits() + } + // check if the node contains the prefix we want to delete + if n.prefix != prefix { + return false + } + // Delete the value + n.public = false + var zeroT T + n.val = zeroT + + nodeChildren := 0 + if n.child[0] != nil { + nodeChildren++ + } + if n.child[1] != nil { + nodeChildren++ + } + // If there is a parent and this node does not have any children + // this is a leaf so we can delete this node. + // - parent -> child(to be deleted) + if parent != nil && nodeChildren == 0 { + if parent.child[0] != nil && parent.child[0] == n { + parent.child[0] = nil + } else if parent.child[1] != nil && parent.child[1] == n { + parent.child[1] = nil + } else { + panic("wrong parent") + } + n = nil + } + // Check if we should merge this node + // The root node can not be merged + if n != root && nodeChildren == 1 { + n.mergeChild() + } + // Check if we should merge the parent's other child + // parent -> deletedNode + // |--> child + parentChildren := 0 + if parent != nil { + if parent.child[0] != nil { + parentChildren++ + } + if parent.child[1] != nil { + parentChildren++ + } + if parent != root && parentChildren == 1 && !parent.public { + parent.mergeChild() + } + } + return true +} + +// for testing, returns the number of public nodes in the tree. +func (t *Tree[T]) Len(isV6 bool) int { + count := 0 + t.DepthFirstWalk(isV6, func(k netip.Prefix, v T) bool { + count++ + return false + }) + return count +} + +// WalkFn is used when walking the tree. Takes a +// key and value, returning if iteration should +// be terminated. +type WalkFn[T any] func(s netip.Prefix, v T) bool + +// DepthFirstWalk is used to walk the tree of the corresponding IP family +func (t *Tree[T]) DepthFirstWalk(isIPv6 bool, fn WalkFn[T]) { + if isIPv6 { + recursiveWalk(t.rootV6, fn) + } + recursiveWalk(t.rootV4, fn) +} + +// recursiveWalk is used to do a pre-order walk of a node +// recursively. Returns true if the walk should be aborted +func recursiveWalk[T any](n *node[T], fn WalkFn[T]) bool { + if n == nil { + return true + } + // Visit the public values if any + if n.public && fn(n.prefix, n.val) { + return true + } + + // Recurse on the children + if n.child[0] != nil { + if recursiveWalk(n.child[0], fn) { + return true + } + } + if n.child[1] != nil { + if recursiveWalk(n.child[1], fn) { + return true + } + } + return false +} + +// WalkPrefix is used to walk the tree under a prefix +func (t *Tree[T]) WalkPrefix(prefix netip.Prefix, fn WalkFn[T]) { + n := t.rootV4 + if prefix.Addr().Is6() { + n = t.rootV6 + } + bitPosition := 0 + // mask the address for sanity + address := prefix.Masked().Addr() + // we can't check longer than the request mask + mask := prefix.Bits() + // walk the network bits of the prefix + for bitPosition < mask { + // Look for a child checking the bit position after the mask + n = n.child[getBitFromAddr(address, bitPosition+1)] + if n == nil { + return + } + // check we are in the right branch comparing the suffixes + if !n.prefix.Contains(address) { + break + } + // update the new bit position with the new node mask + bitPosition = n.prefix.Bits() + } + recursiveWalk[T](n, fn) + +} + +// WalkPath is used to walk the tree, but only visiting nodes +// from the root down to a given IP prefix. Where WalkPrefix walks +// all the entries *under* the given prefix, this walks the +// entries *above* the given prefix. +func (t *Tree[T]) WalkPath(path netip.Prefix, fn WalkFn[T]) { + n := t.rootV4 + if path.Addr().Is6() { + n = t.rootV6 + } + bitPosition := 0 + // mask the address for sanity + address := path.Masked().Addr() + // we can't check longer than the request mask + mask := path.Bits() + // walk the network bits of the prefix + for bitPosition < mask { + // Visit the public values if any + if n.public && fn(n.prefix, n.val) { + return + } + // Look for a child checking the bit position after the mask + n = n.child[getBitFromAddr(address, bitPosition+1)] + if n == nil { + return + } + // check we are in the right branch comparing the suffixes + if !n.prefix.Contains(address) { + return + } + // update the new bit position with the new node mask + bitPosition = n.prefix.Bits() + } + // check if this node is a public node and contains a prefix + if n != nil && n.public && n.prefix == path { + fn(n.prefix, n.val) + } +} + +// TopLevelPrefixes is used to return a map with all the Top Level prefixes +// from the corresponding IP family and its values. +// For example, if the tree contains entries for 10.0.0.0/8, 10.1.0.0/16, and 192.168.0.0/16, +// this will return 10.0.0.0/8 and 192.168.0.0/16. +func (t *Tree[T]) TopLevelPrefixes(isIPv6 bool) map[string]T { + if isIPv6 { + return t.topLevelPrefixes(t.rootV6) + } + return t.topLevelPrefixes(t.rootV4) +} + +// topLevelPrefixes is used to return a map with all the Top Level prefixes and its values +func (t *Tree[T]) topLevelPrefixes(root *node[T]) map[string]T { + result := map[string]T{} + queue := []*node[T]{root} + + for len(queue) > 0 { + n := queue[0] + queue = queue[1:] + // store and continue, only interested on the top level prefixes + if n.public { + result[n.prefix.String()] = n.val + continue + } + if n.child[0] != nil { + queue = append(queue, n.child[0]) + } + if n.child[1] != nil { + queue = append(queue, n.child[1]) + } + } + return result +} + +// GetHostIPPrefixMatches returns the list of prefixes that contain the specified Host IP. +// An IP is considered a Host IP if is within the subnet range and is not the network address +// or, if IPv4, the broadcast address (RFC 1878). +func (t *Tree[T]) GetHostIPPrefixMatches(ip netip.Addr) map[netip.Prefix]T { + // walk the tree to find all the prefixes containing this IP + ipPrefix := netip.PrefixFrom(ip, ip.BitLen()) + prefixes := map[netip.Prefix]T{} + t.WalkPath(ipPrefix, func(k netip.Prefix, v T) bool { + if prefixContainIP(k, ipPrefix.Addr()) { + prefixes[k] = v + } + return false + }) + return prefixes +} + +// assume starts at 0 from the MSB: 0.1.2......31 +// return 0 or 1 +func getBitFromAddr(ip netip.Addr, pos int) int { + bytes := ip.AsSlice() + // get the byte in the slice + index := (pos - 1) / 8 + if index >= len(bytes) { + panic(fmt.Sprintf("ip %s pos %d index %d bytes %v", ip, pos, index, bytes)) + } + // get the offset inside the byte + offset := (pos - 1) % 8 + // check if the bit is set + if bytes[index]&(uint8(0x80)>>offset) > 0 { + return 1 + } + return 0 +} + +// find the common subnet, aka the one with the common prefix +func findAncestor(a, b netip.Prefix) netip.Prefix { + bytesA := a.Addr().AsSlice() + bytesB := b.Addr().AsSlice() + bytes := make([]byte, len(bytesA)) + + max := a.Bits() + if l := b.Bits(); l < max { + max = l + } + + mask := 0 + for i := range bytesA { + xor := bytesA[i] ^ bytesB[i] + if xor == 0 { + bytes[i] = bytesA[i] + mask += 8 + + } else { + pos := bits.LeadingZeros8(xor) + mask += pos + // mask off the non leading zeros + bytes[i] = bytesA[i] & (^uint8(0) << (8 - pos)) + break + } + } + if mask > max { + mask = max + } + + addr, ok := netip.AddrFromSlice(bytes) + if !ok { + panic(bytes) + } + ancestor := netip.PrefixFrom(addr, mask) + return ancestor.Masked() +} + +// prefixContainIP returns true if the given IP is contained with the prefix, +// is not the network address and also, if IPv4, is not the broadcast address. +// This is required because the Kubernetes allocators reserve these addresses +// so IPAddresses can not block deletion of this ranges. +func prefixContainIP(prefix netip.Prefix, ip netip.Addr) bool { + // if the IP is the network address is not contained + if prefix.Masked().Addr() == ip { + return false + } + // the broadcast address is not considered contained for IPv4 + if !ip.Is6() { + ipLast, err := broadcastAddress(prefix) + if err != nil || ipLast == ip { + return false + } + } + return prefix.Contains(ip) +} + +// TODO(aojea) consolidate all these IPs utils +// pkg/registry/core/service/ipallocator/ipallocator.go +// broadcastAddress returns the broadcast address of the subnet +// The broadcast address is obtained by setting all the host bits +// in a subnet to 1. +// network 192.168.0.0/24 : subnet bits 24 host bits 32 - 24 = 8 +// broadcast address 192.168.0.255 +func broadcastAddress(subnet netip.Prefix) (netip.Addr, error) { + base := subnet.Masked().Addr() + bytes := base.AsSlice() + // get all the host bits from the subnet + n := 8*len(bytes) - subnet.Bits() + // set all the host bits to 1 + for i := len(bytes) - 1; i >= 0 && n > 0; i-- { + if n >= 8 { + bytes[i] = 0xff + n -= 8 + } else { + mask := ^uint8(0) >> (8 - n) + bytes[i] |= mask + break + } + } + + addr, ok := netip.AddrFromSlice(bytes) + if !ok { + return netip.Addr{}, fmt.Errorf("invalid address %v", bytes) + } + return addr, nil +} diff --git a/pkg/util/iptree/iptree_test.go b/pkg/util/iptree/iptree_test.go new file mode 100644 index 00000000000..f25ff5800f9 --- /dev/null +++ b/pkg/util/iptree/iptree_test.go @@ -0,0 +1,781 @@ +/* +Copyright 2023 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package iptree + +import ( + "math/rand" + "net/netip" + "reflect" + "sort" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "k8s.io/apimachinery/pkg/util/sets" +) + +func Test_InsertGetDelete(t *testing.T) { + testCases := []struct { + name string + prefix netip.Prefix + }{ + { + name: "ipv4", + prefix: netip.MustParsePrefix("192.168.0.0/24"), + }, + { + name: "ipv6", + prefix: netip.MustParsePrefix("fd00:1:2:3::/124"), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tree := New[int]() + ok := tree.InsertPrefix(tc.prefix, 1) + if ok { + t.Fatal("should not exist") + } + if _, ok := tree.GetPrefix(tc.prefix); !ok { + t.Errorf("CIDR %s not found", tc.prefix) + } + if ok := tree.DeletePrefix(tc.prefix); !ok { + t.Errorf("CIDR %s not deleted", tc.prefix) + } + if _, ok := tree.GetPrefix(tc.prefix); ok { + t.Errorf("CIDR %s found", tc.prefix) + } + }) + } + +} + +func TestBasicIPv4(t *testing.T) { + tree := New[int]() + // insert + ipnet := netip.MustParsePrefix("192.168.0.0/24") + ok := tree.InsertPrefix(ipnet, 1) + if ok { + t.Fatal("should not exist") + } + // check exist + if _, ok := tree.GetPrefix(ipnet); !ok { + t.Errorf("CIDR %s not found", ipnet) + } + + // check does not exist + ipnet2 := netip.MustParsePrefix("12.1.0.0/16") + if _, ok := tree.GetPrefix(ipnet2); ok { + t.Errorf("CIDR %s not expected", ipnet2) + } + + // check insert existing prefix updates the value + ok = tree.InsertPrefix(ipnet2, 2) + if ok { + t.Errorf("should not exist: %s", ipnet2) + } + + ok = tree.InsertPrefix(ipnet2, 3) + if !ok { + t.Errorf("should be updated: %s", ipnet2) + } + + if v, ok := tree.GetPrefix(ipnet2); !ok || v != 3 { + t.Errorf("CIDR %s not expected", ipnet2) + } + + // check longer prefix matching + ipnet3 := netip.MustParsePrefix("12.1.0.2/32") + lpm, _, ok := tree.LongestPrefixMatch(ipnet3) + if !ok || lpm != ipnet2 { + t.Errorf("expected %s got %s", ipnet2, lpm) + } +} + +func TestBasicIPv6(t *testing.T) { + tree := New[int]() + // insert + ipnet := netip.MustParsePrefix("2001:db8::/64") + ok := tree.InsertPrefix(ipnet, 1) + if ok { + t.Fatal("should not exist") + } + // check exist + if _, ok := tree.GetPrefix(ipnet); !ok { + t.Errorf("CIDR %s not found", ipnet) + } + + // check does not exist + ipnet2 := netip.MustParsePrefix("2001:db8:1:3:4::/64") + if _, ok := tree.GetPrefix(ipnet2); ok { + t.Errorf("CIDR %s not expected", ipnet2) + } + + // check insert existing prefix updates the value + ok = tree.InsertPrefix(ipnet2, 2) + if ok { + t.Errorf("should not exist: %s", ipnet2) + } + + ok = tree.InsertPrefix(ipnet2, 3) + if !ok { + t.Errorf("should be updated: %s", ipnet2) + } + + if v, ok := tree.GetPrefix(ipnet2); !ok || v != 3 { + t.Errorf("CIDR %s not expected", ipnet2) + } + + // check longer prefix matching + ipnet3 := netip.MustParsePrefix("2001:db8:1:3:4::/96") + lpm, _, ok := tree.LongestPrefixMatch(ipnet3) + if !ok || lpm != ipnet2 { + t.Errorf("expected %s got %s", ipnet2, lpm) + } +} + +func TestInsertGetDelete100K(t *testing.T) { + testCases := []struct { + name string + is6 bool + }{ + { + name: "ipv4", + }, + { + name: "ipv6", + is6: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + cidrs := generateRandomCIDRs(tc.is6, 100*1000) + tree := New[string]() + + for k := range cidrs { + ok := tree.InsertPrefix(k, k.String()) + if ok { + t.Errorf("error inserting: %v", k) + } + } + + if tree.Len(tc.is6) != len(cidrs) { + t.Errorf("expected %d nodes on the tree, got %d", len(cidrs), tree.Len(tc.is6)) + } + + list := cidrs.UnsortedList() + for _, k := range list { + if v, ok := tree.GetPrefix(k); !ok { + t.Errorf("CIDR %s not found", k) + return + } else if v != k.String() { + t.Errorf("CIDR value %s not found", k) + return + } + ok := tree.DeletePrefix(k) + if !ok { + t.Errorf("CIDR delete %s error", k) + } + } + + if tree.Len(tc.is6) != 0 { + t.Errorf("No node expected on the tree, got: %d %v", tree.Len(tc.is6), cidrs) + } + }) + } +} + +func Test_findAncestor(t *testing.T) { + tests := []struct { + name string + a netip.Prefix + b netip.Prefix + want netip.Prefix + }{ + { + name: "ipv4 direct parent", + a: netip.MustParsePrefix("192.168.0.0/24"), + b: netip.MustParsePrefix("192.168.1.0/24"), + want: netip.MustParsePrefix("192.168.0.0/23"), + }, + { + name: "ipv4 root parent ", + a: netip.MustParsePrefix("192.168.0.0/24"), + b: netip.MustParsePrefix("1.168.1.0/24"), + want: netip.MustParsePrefix("0.0.0.0/0"), + }, + { + name: "ipv4 parent /1", + a: netip.MustParsePrefix("192.168.0.0/24"), + b: netip.MustParsePrefix("184.168.1.0/24"), + want: netip.MustParsePrefix("128.0.0.0/1"), + }, + { + name: "ipv6 direct parent", + a: netip.MustParsePrefix("fd00:1:1:1::/64"), + b: netip.MustParsePrefix("fd00:1:1:2::/64"), + want: netip.MustParsePrefix("fd00:1:1::/62"), + }, + { + name: "ipv6 root parent ", + a: netip.MustParsePrefix("fd00:1:1:1::/64"), + b: netip.MustParsePrefix("1:1:1:1::/64"), + want: netip.MustParsePrefix("::/0"), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := findAncestor(tt.a, tt.b); !reflect.DeepEqual(got, tt.want) { + t.Errorf("findAncestor() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_getBitFromAddr(t *testing.T) { + tests := []struct { + name string + ip netip.Addr + pos int + want int + }{ + // 192.168.0.0 + // 11000000.10101000.00000000.00000001 + { + name: "ipv4 first is a one", + ip: netip.MustParseAddr("192.168.0.0"), + pos: 1, + want: 1, + }, + { + name: "ipv4 middle is a zero", + ip: netip.MustParseAddr("192.168.0.0"), + pos: 16, + want: 0, + }, + { + name: "ipv4 middle is a one", + ip: netip.MustParseAddr("192.168.0.0"), + pos: 13, + want: 1, + }, + { + name: "ipv4 last is a zero", + ip: netip.MustParseAddr("192.168.0.0"), + pos: 32, + want: 0, + }, + // 2001:db8::ff00:42:8329 + // 0010000000000001:0000110110111000:0000000000000000:0000000000000000:0000000000000000:1111111100000000:0000000001000010:1000001100101001 + { + name: "ipv6 first is a zero", + ip: netip.MustParseAddr("2001:db8::ff00:42:8329"), + pos: 1, + want: 0, + }, + { + name: "ipv6 middle is a zero", + ip: netip.MustParseAddr("2001:db8::ff00:42:8329"), + pos: 56, + want: 0, + }, + { + name: "ipv6 middle is a one", + ip: netip.MustParseAddr("2001:db8::ff00:42:8329"), + pos: 81, + want: 1, + }, + { + name: "ipv6 last is a one", + ip: netip.MustParseAddr("2001:db8::ff00:42:8329"), + pos: 128, + want: 1, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := getBitFromAddr(tt.ip, tt.pos); got != tt.want { + t.Errorf("getBitFromAddr() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestShortestPrefix(t *testing.T) { + r := New[int]() + + keys := []string{ + "10.0.0.0/8", + "10.21.0.0/16", + "10.221.0.0/16", + "10.1.2.3/32", + "10.1.2.0/24", + "192.168.0.0/24", + "192.168.0.0/16", + } + for _, k := range keys { + ok := r.InsertPrefix(netip.MustParsePrefix(k), 0) + if ok { + t.Errorf("unexpected update on insert %s", k) + } + } + if r.Len(false) != len(keys) { + t.Fatalf("bad len: %v %v", r.Len(false), len(keys)) + } + + type exp struct { + inp string + out string + } + cases := []exp{ + {"192.168.0.3/32", "192.168.0.0/16"}, + {"10.1.2.4/21", "10.0.0.0/8"}, + {"192.168.0.0/16", "192.168.0.0/16"}, + {"192.168.0.0/32", "192.168.0.0/16"}, + {"10.1.2.3/32", "10.0.0.0/8"}, + } + for _, test := range cases { + m, _, ok := r.ShortestPrefixMatch(netip.MustParsePrefix(test.inp)) + if !ok { + t.Fatalf("no match: %v", test) + } + if m != netip.MustParsePrefix(test.out) { + t.Fatalf("mis-match: %v %v", m, test) + } + } + + // not match + _, _, ok := r.ShortestPrefixMatch(netip.MustParsePrefix("0.0.0.0/0")) + if ok { + t.Fatalf("match unexpected for 0.0.0.0/0") + } +} + +func TestLongestPrefixMatch(t *testing.T) { + r := New[int]() + + keys := []string{ + "10.0.0.0/8", + "10.21.0.0/16", + "10.221.0.0/16", + "10.1.2.3/32", + "10.1.2.0/24", + "192.168.0.0/24", + "192.168.0.0/16", + } + for _, k := range keys { + ok := r.InsertPrefix(netip.MustParsePrefix(k), 0) + if ok { + t.Errorf("unexpected update on insert %s", k) + } + } + if r.Len(false) != len(keys) { + t.Fatalf("bad len: %v %v", r.Len(false), len(keys)) + } + + type exp struct { + inp string + out string + } + cases := []exp{ + {"192.168.0.3/32", "192.168.0.0/24"}, + {"10.1.2.4/21", "10.0.0.0/8"}, + {"10.21.2.0/24", "10.21.0.0/16"}, + {"10.1.2.3/32", "10.1.2.3/32"}, + } + for _, test := range cases { + m, _, ok := r.LongestPrefixMatch(netip.MustParsePrefix(test.inp)) + if !ok { + t.Fatalf("no match: %v", test) + } + if m != netip.MustParsePrefix(test.out) { + t.Fatalf("mis-match: %v %v", m, test) + } + } + // not match + _, _, ok := r.LongestPrefixMatch(netip.MustParsePrefix("0.0.0.0/0")) + if ok { + t.Fatalf("match unexpected for 0.0.0.0/0") + } +} + +func TestTopLevelPrefixesV4(t *testing.T) { + r := New[string]() + + keys := []string{ + "10.0.0.0/8", + "10.21.0.0/16", + "10.221.0.0/16", + "10.1.2.3/32", + "10.1.2.0/24", + "192.168.0.0/20", + "192.168.1.0/24", + "172.16.0.0/12", + "172.21.23.0/24", + } + for _, k := range keys { + ok := r.InsertPrefix(netip.MustParsePrefix(k), k) + if ok { + t.Errorf("unexpected update on insert %s", k) + } + } + if r.Len(false) != len(keys) { + t.Fatalf("bad len: %v %v", r.Len(false), len(keys)) + } + + expected := []string{ + "10.0.0.0/8", + "192.168.0.0/20", + "172.16.0.0/12", + } + parents := r.TopLevelPrefixes(false) + if len(parents) != len(expected) { + t.Fatalf("bad len: %v %v", len(parents), len(expected)) + } + + for _, k := range expected { + v, ok := parents[k] + if !ok { + t.Errorf("key %s not found", k) + } + if v != k { + t.Errorf("value expected %s got %s", k, v) + } + } +} + +func TestTopLevelPrefixesV6(t *testing.T) { + r := New[string]() + + keys := []string{ + "2001:db8:1:2:3::/64", + "2001:db8::/64", + "2001:db8:1:1:1::/64", + "2001:db8:1:1:1::/112", + } + for _, k := range keys { + ok := r.InsertPrefix(netip.MustParsePrefix(k), k) + if ok { + t.Errorf("unexpected update on insert %s", k) + } + } + + if r.Len(true) != len(keys) { + t.Fatalf("bad len: %v %v", r.Len(true), len(keys)) + } + + expected := []string{ + "2001:db8::/64", + "2001:db8:1:2:3::/64", + "2001:db8:1:1:1::/64", + } + parents := r.TopLevelPrefixes(true) + if len(parents) != len(expected) { + t.Fatalf("bad len: %v %v", len(parents), len(expected)) + } + + for _, k := range expected { + v, ok := parents[k] + if !ok { + t.Errorf("key %s not found", k) + } + if v != k { + t.Errorf("value expected %s got %s", k, v) + } + } +} + +func TestWalkV4(t *testing.T) { + r := New[int]() + + keys := []string{ + "10.0.0.0/8", + "10.1.0.0/16", + "10.1.1.0/24", + "10.1.1.32/26", + "10.1.1.33/32", + } + for _, k := range keys { + ok := r.InsertPrefix(netip.MustParsePrefix(k), 0) + if ok { + t.Errorf("unexpected update on insert %s", k) + } + } + if r.Len(false) != len(keys) { + t.Fatalf("bad len: %v %v", r.Len(false), len(keys)) + } + + // match exact prefix + path := []string{} + r.WalkPath(netip.MustParsePrefix("10.1.1.32/26"), func(k netip.Prefix, v int) bool { + path = append(path, k.String()) + return false + }) + if !cmp.Equal(path, keys[:4]) { + t.Errorf("Walkpath expected %v got %v", keys[:4], path) + } + // not match on prefix + path = []string{} + r.WalkPath(netip.MustParsePrefix("10.1.1.33/26"), func(k netip.Prefix, v int) bool { + path = append(path, k.String()) + return false + }) + if !cmp.Equal(path, keys[:3]) { + t.Errorf("Walkpath expected %v got %v", keys[:3], path) + } + // match exact prefix + path = []string{} + r.WalkPrefix(netip.MustParsePrefix("10.0.0.0/8"), func(k netip.Prefix, v int) bool { + path = append(path, k.String()) + return false + }) + if !cmp.Equal(path, keys) { + t.Errorf("WalkPrefix expected %v got %v", keys, path) + } + // not match on prefix + path = []string{} + r.WalkPrefix(netip.MustParsePrefix("10.0.0.0/9"), func(k netip.Prefix, v int) bool { + path = append(path, k.String()) + return false + }) + if !cmp.Equal(path, keys[1:]) { + t.Errorf("WalkPrefix expected %v got %v", keys[1:], path) + } +} + +func TestWalkV6(t *testing.T) { + r := New[int]() + + keys := []string{ + "2001:db8::/48", + "2001:db8::/64", + "2001:db8::/96", + "2001:db8::/112", + "2001:db8::/128", + } + for _, k := range keys { + ok := r.InsertPrefix(netip.MustParsePrefix(k), 0) + if ok { + t.Errorf("unexpected update on insert %s", k) + } + } + if r.Len(true) != len(keys) { + t.Fatalf("bad len: %v %v", r.Len(false), len(keys)) + } + + // match exact prefix + path := []string{} + r.WalkPath(netip.MustParsePrefix("2001:db8::/112"), func(k netip.Prefix, v int) bool { + path = append(path, k.String()) + return false + }) + if !cmp.Equal(path, keys[:4]) { + t.Errorf("Walkpath expected %v got %v", keys[:4], path) + } + // not match on prefix + path = []string{} + r.WalkPath(netip.MustParsePrefix("2001:db8::1/112"), func(k netip.Prefix, v int) bool { + path = append(path, k.String()) + return false + }) + if !cmp.Equal(path, keys[:3]) { + t.Errorf("Walkpath expected %v got %v", keys[:3], path) + } + // match exact prefix + path = []string{} + r.WalkPrefix(netip.MustParsePrefix("2001:db8::/48"), func(k netip.Prefix, v int) bool { + path = append(path, k.String()) + return false + }) + if !cmp.Equal(path, keys) { + t.Errorf("WalkPrefix expected %v got %v", keys, path) + } + // not match on prefix + path = []string{} + r.WalkPrefix(netip.MustParsePrefix("2001:db8::/49"), func(k netip.Prefix, v int) bool { + path = append(path, k.String()) + return false + }) + if !cmp.Equal(path, keys[1:]) { + t.Errorf("WalkPrefix expected %v got %v", keys[1:], path) + } +} + +func TestGetHostIPPrefixMatches(t *testing.T) { + r := New[int]() + + keys := []string{ + "10.0.0.0/8", + "10.21.0.0/16", + "10.221.0.0/16", + "10.1.2.3/32", + "10.1.2.0/24", + "192.168.0.0/24", + "192.168.0.0/16", + "2001:db8::/48", + "2001:db8::/64", + "2001:db8::/96", + } + for _, k := range keys { + ok := r.InsertPrefix(netip.MustParsePrefix(k), 0) + if ok { + t.Errorf("unexpected update on insert %s", k) + } + } + + type exp struct { + inp string + out []string + } + cases := []exp{ + {"192.168.0.3", []string{"192.168.0.0/24", "192.168.0.0/16"}}, + {"10.1.2.4", []string{"10.1.2.0/24", "10.0.0.0/8"}}, + {"10.1.2.0", []string{"10.0.0.0/8"}}, + {"10.1.2.255", []string{"10.0.0.0/8"}}, + {"192.168.0.0", []string{}}, + {"192.168.1.0", []string{"192.168.0.0/16"}}, + {"10.1.2.255", []string{"10.0.0.0/8"}}, + {"2001:db8::1", []string{"2001:db8::/96", "2001:db8::/64", "2001:db8::/48"}}, + {"2001:db8::", []string{}}, + {"2001:db8::ffff:ffff:ffff:ffff", []string{"2001:db8::/64", "2001:db8::/48"}}, + } + for _, test := range cases { + m := r.GetHostIPPrefixMatches(netip.MustParseAddr(test.inp)) + in := []netip.Prefix{} + for k := range m { + in = append(in, k) + } + out := []netip.Prefix{} + for _, s := range test.out { + out = append(out, netip.MustParsePrefix(s)) + } + + // sort by prefix bits to avoid flakes + sort.Slice(in, func(i, j int) bool { return in[i].Bits() < in[j].Bits() }) + sort.Slice(out, func(i, j int) bool { return out[i].Bits() < out[j].Bits() }) + if !reflect.DeepEqual(in, out) { + t.Fatalf("mis-match: %v %v", in, out) + } + } + + // not match + _, _, ok := r.ShortestPrefixMatch(netip.MustParsePrefix("0.0.0.0/0")) + if ok { + t.Fatalf("match unexpected for 0.0.0.0/0") + } +} + +func Test_prefixContainIP(t *testing.T) { + tests := []struct { + name string + prefix netip.Prefix + ip netip.Addr + want bool + }{ + { + name: "IPv4 contains", + prefix: netip.MustParsePrefix("192.168.0.0/24"), + ip: netip.MustParseAddr("192.168.0.1"), + want: true, + }, + { + name: "IPv4 network address", + prefix: netip.MustParsePrefix("192.168.0.0/24"), + ip: netip.MustParseAddr("192.168.0.0"), + }, + { + name: "IPv4 broadcast address", + prefix: netip.MustParsePrefix("192.168.0.0/24"), + ip: netip.MustParseAddr("192.168.0.255"), + }, + { + name: "IPv4 does not contain", + prefix: netip.MustParsePrefix("192.168.0.0/24"), + ip: netip.MustParseAddr("192.168.1.2"), + }, + { + name: "IPv6 contains", + prefix: netip.MustParsePrefix("2001:db2::/96"), + ip: netip.MustParseAddr("2001:db2::1"), + want: true, + }, + { + name: "IPv6 network address", + prefix: netip.MustParsePrefix("2001:db2::/96"), + ip: netip.MustParseAddr("2001:db2::"), + }, + { + name: "IPv6 broadcast address", + prefix: netip.MustParsePrefix("2001:db2::/96"), + ip: netip.MustParseAddr("2001:db2::ffff:ffff"), + want: true, + }, + { + name: "IPv6 does not contain", + prefix: netip.MustParsePrefix("2001:db2::/96"), + ip: netip.MustParseAddr("2001:db2:1:2:3::1"), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := prefixContainIP(tt.prefix, tt.ip); got != tt.want { + t.Errorf("prefixContainIP() = %v, want %v", got, tt.want) + } + }) + } +} + +func BenchmarkInsertUpdate(b *testing.B) { + r := New[bool]() + ipList := generateRandomCIDRs(true, 20000).UnsortedList() + for _, ip := range ipList { + r.InsertPrefix(ip, true) + } + + b.ResetTimer() + for n := 0; n < b.N; n++ { + r.InsertPrefix(ipList[n%len(ipList)], true) + } +} + +func generateRandomCIDRs(is6 bool, number int) sets.Set[netip.Prefix] { + n := 4 + if is6 { + n = 16 + } + cidrs := sets.Set[netip.Prefix]{} + rand.New(rand.NewSource(time.Now().UnixNano())) + for i := 0; i < number; i++ { + bytes := make([]byte, n) + for i := 0; i < n; i++ { + bytes[i] = uint8(rand.Intn(255)) + } + + ip, ok := netip.AddrFromSlice(bytes) + if !ok { + continue + } + + bits := rand.Intn(n * 8) + prefix := netip.PrefixFrom(ip, bits).Masked() + if prefix.IsValid() { + cidrs.Insert(prefix) + } + } + return cidrs +}