diff --git a/pkg/util/iptree/iptree.go b/pkg/util/iptree/iptree.go deleted file mode 100644 index 441062c58d2..00000000000 --- a/pkg/util/iptree/iptree.go +++ /dev/null @@ -1,679 +0,0 @@ -/* -Copyright 2023 The Kubernetes Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package iptree - -import ( - "fmt" - "math/bits" - "net/netip" -) - -// iptree implement a radix tree that uses IP prefixes as nodes and allows to store values in each node. -// Example: -// -// r := New[int]() -// -// prefixes := []string{ -// "0.0.0.0/0", -// "10.0.0.0/8", -// "10.0.0.0/16", -// "10.1.0.0/16", -// "10.1.1.0/24", -// "10.1.244.0/24", -// "10.0.0.0/24", -// "10.0.0.3/32", -// "192.168.0.0/24", -// "192.168.0.0/28", -// "192.168.129.0/28", -// } -// for _, k := range prefixes { -// r.InsertPrefix(netip.MustParsePrefix(k), 0) -// } -// -// (*) means the node is not public, is not storing any value -// -// 0.0.0.0/0 --- 10.0.0.0/8 --- *10.0.0.0/15 --- 10.0.0.0/16 --- 10.0.0.0/24 --- 10.0.0.3/32 -// | | -// | \ -------- 10.1.0.0/16 --- 10.1.1.0/24 -// | | -// | \ ------- 10.1.244.0/24 -// | -// \------ *192.168.0.0/16 --- 192.168.0.0/24 --- 192.168.0.0/28 -// | -// \ -------- 192.168.129.0/28 - -// node is an element of radix tree with a netip.Prefix optimized to store IP prefixes. -type node[T any] struct { - // prefix network CIDR - prefix netip.Prefix - // public nodes are used to store values - public bool - val T - - child [2]*node[T] // binary tree -} - -// mergeChild allow to compress the tree -// when n has exactly one child and no value -// p -> n -> b -> c ==> p -> b -> c -func (n *node[T]) mergeChild() { - // public nodes can not be merged - if n.public { - return - } - // can not merge if there are two children - if n.child[0] != nil && - n.child[1] != nil { - return - } - // can not merge if there are no children - if n.child[0] == nil && - n.child[1] == nil { - return - } - // find the child and merge it - var child *node[T] - if n.child[0] != nil { - child = n.child[0] - } else if n.child[1] != nil { - child = n.child[1] - } - n.prefix = child.prefix - n.public = child.public - n.val = child.val - n.child = child.child - // remove any references from the deleted node - // to avoid memory leak - child.child[0] = nil - child.child[1] = nil -} - -// Tree is a radix tree for IPv4 and IPv6 networks. -type Tree[T any] struct { - rootV4 *node[T] - rootV6 *node[T] -} - -// New creates a new Radix Tree for IP addresses. -func New[T any]() *Tree[T] { - return &Tree[T]{ - rootV4: &node[T]{ - prefix: netip.PrefixFrom(netip.IPv4Unspecified(), 0), - }, - rootV6: &node[T]{ - prefix: netip.PrefixFrom(netip.IPv6Unspecified(), 0), - }, - } -} - -// GetPrefix returns the stored value and true if the exact prefix exists in the tree. -func (t *Tree[T]) GetPrefix(prefix netip.Prefix) (T, bool) { - var zeroT T - - n := t.rootV4 - if prefix.Addr().Is6() { - n = t.rootV6 - } - bitPosition := 0 - // mask the address for sanity - address := prefix.Masked().Addr() - // we can't check longer than the request mask - mask := prefix.Bits() - // walk the network bits of the prefix - for bitPosition < mask { - // Look for a child checking the bit position after the mask - n = n.child[getBitFromAddr(address, bitPosition+1)] - if n == nil { - return zeroT, false - } - // check we are in the right branch comparing the suffixes - if !n.prefix.Contains(address) { - return zeroT, false - } - // update the new bit position with the new node mask - bitPosition = n.prefix.Bits() - } - // check if this node is a public node and contains a prefix - if n != nil && n.public && n.prefix == prefix { - return n.val, true - } - - return zeroT, false -} - -// LongestPrefixMatch returns the longest prefix match, the stored value and true if exist. -// For example, considering the following prefixes 192.168.20.16/28 and 192.168.0.0/16, -// when the address 192.168.20.19/32 is looked up it will return 192.168.20.16/28. -func (t *Tree[T]) LongestPrefixMatch(prefix netip.Prefix) (netip.Prefix, T, bool) { - n := t.rootV4 - if prefix.Addr().Is6() { - n = t.rootV6 - } - - var last *node[T] - // bit position is given by the mask bits - bitPosition := 0 - // mask the address - address := prefix.Masked().Addr() - mask := prefix.Bits() - // walk the network bits of the prefix - for bitPosition < mask { - if n.public { - last = n - } - // Look for a child checking the bit position after the mask - n = n.child[getBitFromAddr(address, bitPosition+1)] - if n == nil { - break - } - // check we are in the right branch comparing the suffixes - if !n.prefix.Contains(address) { - break - } - // update the new bit position with the new node mask - bitPosition = n.prefix.Bits() - } - - if n != nil && n.public && n.prefix == prefix { - last = n - } - - if last != nil { - return last.prefix, last.val, true - } - var zeroT T - return netip.Prefix{}, zeroT, false -} - -// ShortestPrefixMatch returns the shortest prefix match, the stored value and true if exist. -// For example, considering the following prefixes 192.168.20.16/28 and 192.168.0.0/16, -// when the address 192.168.20.19/32 is looked up it will return 192.168.0.0/16. -func (t *Tree[T]) ShortestPrefixMatch(prefix netip.Prefix) (netip.Prefix, T, bool) { - var zeroT T - - n := t.rootV4 - if prefix.Addr().Is6() { - n = t.rootV6 - } - // bit position is given by the mask bits - bitPosition := 0 - // mask the address - address := prefix.Masked().Addr() - mask := prefix.Bits() - for bitPosition < mask { - if n.public { - return n.prefix, n.val, true - } - // Look for a child checking the bit position after the mask - n = n.child[getBitFromAddr(address, bitPosition+1)] - if n == nil { - return netip.Prefix{}, zeroT, false - } - // check we are in the right branch comparing the suffixes - if !n.prefix.Contains(address) { - return netip.Prefix{}, zeroT, false - } - // update the new bit position with the new node mask - bitPosition = n.prefix.Bits() - } - - if n != nil && n.public && n.prefix == prefix { - return n.prefix, n.val, true - } - return netip.Prefix{}, zeroT, false -} - -// InsertPrefix is used to add a new entry or update -// an existing entry. Returns true if updated. -func (t *Tree[T]) InsertPrefix(prefix netip.Prefix, v T) bool { - n := t.rootV4 - if prefix.Addr().Is6() { - n = t.rootV6 - } - var parent *node[T] - // bit position is given by the mask bits - bitPosition := 0 - // mask the address - address := prefix.Masked().Addr() - mask := prefix.Bits() - for bitPosition < mask { - // Look for a child checking the bit position after the mask - childIndex := getBitFromAddr(address, bitPosition+1) - parent = n - n = n.child[childIndex] - // if no child create a new one with - if n == nil { - parent.child[childIndex] = &node[T]{ - public: true, - val: v, - prefix: prefix, - } - return false - } - - // update the new bit position with the new node mask - bitPosition = n.prefix.Bits() - - // continue if we are in the right branch and current - // node is our parent - if n.prefix.Contains(address) && bitPosition <= mask { - continue - } - - // Split the node and add a new child: - // - Case 1: parent -> child -> n - // - Case 2: parent -> newnode |--> child - // |--> n - child := &node[T]{ - prefix: prefix, - public: true, - val: v, - } - // Case 1: existing node is a sibling - if prefix.Contains(n.prefix.Addr()) && bitPosition > mask { - // parent to child - parent.child[childIndex] = child - pos := prefix.Bits() + 1 - // calculate if the sibling is at the left or right - child.child[getBitFromAddr(n.prefix.Addr(), pos)] = n - return false - } - - // Case 2: existing node has the same mask but different base address - // add common ancestor and branch on it - ancestor := findAncestor(prefix, n.prefix) - link := &node[T]{ - prefix: ancestor, - } - pos := parent.prefix.Bits() + 1 - parent.child[getBitFromAddr(ancestor.Addr(), pos)] = link - // ancestor -> children - pos = ancestor.Bits() + 1 - idxChild := getBitFromAddr(prefix.Addr(), pos) - idxN := getBitFromAddr(n.prefix.Addr(), pos) - if idxChild == idxN { - panic(fmt.Sprintf("wrong ancestor %s: child %s N %s", ancestor.String(), prefix.String(), n.prefix.String())) - } - link.child[idxChild] = child - link.child[idxN] = n - return false - } - - // if already exist update it and make it public - if n != nil && n.prefix == prefix { - if n.public { - n.val = v - n.public = true - return true - } - n.val = v - n.public = true - return false - } - - return false -} - -// DeletePrefix delete the exact prefix and return true if it existed. -func (t *Tree[T]) DeletePrefix(prefix netip.Prefix) bool { - root := t.rootV4 - if prefix.Addr().Is6() { - root = t.rootV6 - } - var parent *node[T] - n := root - // bit position is given by the mask bits - bitPosition := 0 - // mask the address - address := prefix.Masked().Addr() - mask := prefix.Bits() - for bitPosition < mask { - // Look for a child checking the bit position after the mask - parent = n - n = n.child[getBitFromAddr(address, bitPosition+1)] - if n == nil { - return false - } - // check we are in the right branch comparing the suffixes - if !n.prefix.Contains(address) { - return false - } - // update the new bit position with the new node mask - bitPosition = n.prefix.Bits() - } - // check if the node contains the prefix we want to delete - if n.prefix != prefix { - return false - } - // Delete the value - n.public = false - var zeroT T - n.val = zeroT - - nodeChildren := 0 - if n.child[0] != nil { - nodeChildren++ - } - if n.child[1] != nil { - nodeChildren++ - } - // If there is a parent and this node does not have any children - // this is a leaf so we can delete this node. - // - parent -> child(to be deleted) - if parent != nil && nodeChildren == 0 { - if parent.child[0] != nil && parent.child[0] == n { - parent.child[0] = nil - } else if parent.child[1] != nil && parent.child[1] == n { - parent.child[1] = nil - } else { - panic("wrong parent") - } - n = nil - } - // Check if we should merge this node - // The root node can not be merged - if n != root && nodeChildren == 1 { - n.mergeChild() - } - // Check if we should merge the parent's other child - // parent -> deletedNode - // |--> child - parentChildren := 0 - if parent != nil { - if parent.child[0] != nil { - parentChildren++ - } - if parent.child[1] != nil { - parentChildren++ - } - if parent != root && parentChildren == 1 && !parent.public { - parent.mergeChild() - } - } - return true -} - -// for testing, returns the number of public nodes in the tree. -func (t *Tree[T]) Len(isV6 bool) int { - count := 0 - t.DepthFirstWalk(isV6, func(k netip.Prefix, v T) bool { - count++ - return false - }) - return count -} - -// WalkFn is used when walking the tree. Takes a -// key and value, returning if iteration should -// be terminated. -type WalkFn[T any] func(s netip.Prefix, v T) bool - -// DepthFirstWalk is used to walk the tree of the corresponding IP family -func (t *Tree[T]) DepthFirstWalk(isIPv6 bool, fn WalkFn[T]) { - if isIPv6 { - recursiveWalk(t.rootV6, fn) - } - recursiveWalk(t.rootV4, fn) -} - -// recursiveWalk is used to do a pre-order walk of a node -// recursively. Returns true if the walk should be aborted -func recursiveWalk[T any](n *node[T], fn WalkFn[T]) bool { - if n == nil { - return true - } - // Visit the public values if any - if n.public && fn(n.prefix, n.val) { - return true - } - - // Recurse on the children - if n.child[0] != nil { - if recursiveWalk(n.child[0], fn) { - return true - } - } - if n.child[1] != nil { - if recursiveWalk(n.child[1], fn) { - return true - } - } - return false -} - -// WalkPrefix is used to walk the tree under a prefix -func (t *Tree[T]) WalkPrefix(prefix netip.Prefix, fn WalkFn[T]) { - n := t.rootV4 - if prefix.Addr().Is6() { - n = t.rootV6 - } - bitPosition := 0 - // mask the address for sanity - address := prefix.Masked().Addr() - // we can't check longer than the request mask - mask := prefix.Bits() - // walk the network bits of the prefix - for bitPosition < mask { - // Look for a child checking the bit position after the mask - n = n.child[getBitFromAddr(address, bitPosition+1)] - if n == nil { - return - } - // check we are in the right branch comparing the suffixes - if !n.prefix.Contains(address) { - break - } - // update the new bit position with the new node mask - bitPosition = n.prefix.Bits() - } - recursiveWalk[T](n, fn) - -} - -// WalkPath is used to walk the tree, but only visiting nodes -// from the root down to a given IP prefix. Where WalkPrefix walks -// all the entries *under* the given prefix, this walks the -// entries *above* the given prefix. -func (t *Tree[T]) WalkPath(path netip.Prefix, fn WalkFn[T]) { - n := t.rootV4 - if path.Addr().Is6() { - n = t.rootV6 - } - bitPosition := 0 - // mask the address for sanity - address := path.Masked().Addr() - // we can't check longer than the request mask - mask := path.Bits() - // walk the network bits of the prefix - for bitPosition < mask { - // Visit the public values if any - if n.public && fn(n.prefix, n.val) { - return - } - // Look for a child checking the bit position after the mask - n = n.child[getBitFromAddr(address, bitPosition+1)] - if n == nil { - return - } - // check we are in the right branch comparing the suffixes - if !n.prefix.Contains(address) { - return - } - // update the new bit position with the new node mask - bitPosition = n.prefix.Bits() - } - // check if this node is a public node and contains a prefix - if n != nil && n.public && n.prefix == path { - fn(n.prefix, n.val) - } -} - -// TopLevelPrefixes is used to return a map with all the Top Level prefixes -// from the corresponding IP family and its values. -// For example, if the tree contains entries for 10.0.0.0/8, 10.1.0.0/16, and 192.168.0.0/16, -// this will return 10.0.0.0/8 and 192.168.0.0/16. -func (t *Tree[T]) TopLevelPrefixes(isIPv6 bool) map[string]T { - if isIPv6 { - return t.topLevelPrefixes(t.rootV6) - } - return t.topLevelPrefixes(t.rootV4) -} - -// topLevelPrefixes is used to return a map with all the Top Level prefixes and its values -func (t *Tree[T]) topLevelPrefixes(root *node[T]) map[string]T { - result := map[string]T{} - queue := []*node[T]{root} - - for len(queue) > 0 { - n := queue[0] - queue = queue[1:] - // store and continue, only interested on the top level prefixes - if n.public { - result[n.prefix.String()] = n.val - continue - } - if n.child[0] != nil { - queue = append(queue, n.child[0]) - } - if n.child[1] != nil { - queue = append(queue, n.child[1]) - } - } - return result -} - -// GetHostIPPrefixMatches returns the list of prefixes that contain the specified Host IP. -// An IP is considered a Host IP if is within the subnet range and is not the network address -// or, if IPv4, the broadcast address (RFC 1878). -func (t *Tree[T]) GetHostIPPrefixMatches(ip netip.Addr) map[netip.Prefix]T { - // walk the tree to find all the prefixes containing this IP - ipPrefix := netip.PrefixFrom(ip, ip.BitLen()) - prefixes := map[netip.Prefix]T{} - t.WalkPath(ipPrefix, func(k netip.Prefix, v T) bool { - if prefixContainIP(k, ipPrefix.Addr()) { - prefixes[k] = v - } - return false - }) - return prefixes -} - -// assume starts at 0 from the MSB: 0.1.2......31 -// return 0 or 1 -func getBitFromAddr(ip netip.Addr, pos int) int { - bytes := ip.AsSlice() - // get the byte in the slice - index := (pos - 1) / 8 - if index >= len(bytes) { - panic(fmt.Sprintf("ip %s pos %d index %d bytes %v", ip, pos, index, bytes)) - } - // get the offset inside the byte - offset := (pos - 1) % 8 - // check if the bit is set - if bytes[index]&(uint8(0x80)>>offset) > 0 { - return 1 - } - return 0 -} - -// find the common subnet, aka the one with the common prefix -func findAncestor(a, b netip.Prefix) netip.Prefix { - bytesA := a.Addr().AsSlice() - bytesB := b.Addr().AsSlice() - bytes := make([]byte, len(bytesA)) - - max := a.Bits() - if l := b.Bits(); l < max { - max = l - } - - mask := 0 - for i := range bytesA { - xor := bytesA[i] ^ bytesB[i] - if xor == 0 { - bytes[i] = bytesA[i] - mask += 8 - - } else { - pos := bits.LeadingZeros8(xor) - mask += pos - // mask off the non leading zeros - bytes[i] = bytesA[i] & (^uint8(0) << (8 - pos)) - break - } - } - if mask > max { - mask = max - } - - addr, ok := netip.AddrFromSlice(bytes) - if !ok { - panic(bytes) - } - ancestor := netip.PrefixFrom(addr, mask) - return ancestor.Masked() -} - -// prefixContainIP returns true if the given IP is contained with the prefix, -// is not the network address and also, if IPv4, is not the broadcast address. -// This is required because the Kubernetes allocators reserve these addresses -// so IPAddresses can not block deletion of this ranges. -func prefixContainIP(prefix netip.Prefix, ip netip.Addr) bool { - // if the IP is the network address is not contained - if prefix.Masked().Addr() == ip { - return false - } - // the broadcast address is not considered contained for IPv4 - if !ip.Is6() { - ipLast, err := broadcastAddress(prefix) - if err != nil || ipLast == ip { - return false - } - } - return prefix.Contains(ip) -} - -// TODO(aojea) consolidate all these IPs utils -// pkg/registry/core/service/ipallocator/ipallocator.go -// broadcastAddress returns the broadcast address of the subnet -// The broadcast address is obtained by setting all the host bits -// in a subnet to 1. -// network 192.168.0.0/24 : subnet bits 24 host bits 32 - 24 = 8 -// broadcast address 192.168.0.255 -func broadcastAddress(subnet netip.Prefix) (netip.Addr, error) { - base := subnet.Masked().Addr() - bytes := base.AsSlice() - // get all the host bits from the subnet - n := 8*len(bytes) - subnet.Bits() - // set all the host bits to 1 - for i := len(bytes) - 1; i >= 0 && n > 0; i-- { - if n >= 8 { - bytes[i] = 0xff - n -= 8 - } else { - mask := ^uint8(0) >> (8 - n) - bytes[i] |= mask - break - } - } - - addr, ok := netip.AddrFromSlice(bytes) - if !ok { - return netip.Addr{}, fmt.Errorf("invalid address %v", bytes) - } - return addr, nil -} diff --git a/pkg/util/iptree/iptree_test.go b/pkg/util/iptree/iptree_test.go deleted file mode 100644 index f25ff5800f9..00000000000 --- a/pkg/util/iptree/iptree_test.go +++ /dev/null @@ -1,781 +0,0 @@ -/* -Copyright 2023 The Kubernetes Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package iptree - -import ( - "math/rand" - "net/netip" - "reflect" - "sort" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "k8s.io/apimachinery/pkg/util/sets" -) - -func Test_InsertGetDelete(t *testing.T) { - testCases := []struct { - name string - prefix netip.Prefix - }{ - { - name: "ipv4", - prefix: netip.MustParsePrefix("192.168.0.0/24"), - }, - { - name: "ipv6", - prefix: netip.MustParsePrefix("fd00:1:2:3::/124"), - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - tree := New[int]() - ok := tree.InsertPrefix(tc.prefix, 1) - if ok { - t.Fatal("should not exist") - } - if _, ok := tree.GetPrefix(tc.prefix); !ok { - t.Errorf("CIDR %s not found", tc.prefix) - } - if ok := tree.DeletePrefix(tc.prefix); !ok { - t.Errorf("CIDR %s not deleted", tc.prefix) - } - if _, ok := tree.GetPrefix(tc.prefix); ok { - t.Errorf("CIDR %s found", tc.prefix) - } - }) - } - -} - -func TestBasicIPv4(t *testing.T) { - tree := New[int]() - // insert - ipnet := netip.MustParsePrefix("192.168.0.0/24") - ok := tree.InsertPrefix(ipnet, 1) - if ok { - t.Fatal("should not exist") - } - // check exist - if _, ok := tree.GetPrefix(ipnet); !ok { - t.Errorf("CIDR %s not found", ipnet) - } - - // check does not exist - ipnet2 := netip.MustParsePrefix("12.1.0.0/16") - if _, ok := tree.GetPrefix(ipnet2); ok { - t.Errorf("CIDR %s not expected", ipnet2) - } - - // check insert existing prefix updates the value - ok = tree.InsertPrefix(ipnet2, 2) - if ok { - t.Errorf("should not exist: %s", ipnet2) - } - - ok = tree.InsertPrefix(ipnet2, 3) - if !ok { - t.Errorf("should be updated: %s", ipnet2) - } - - if v, ok := tree.GetPrefix(ipnet2); !ok || v != 3 { - t.Errorf("CIDR %s not expected", ipnet2) - } - - // check longer prefix matching - ipnet3 := netip.MustParsePrefix("12.1.0.2/32") - lpm, _, ok := tree.LongestPrefixMatch(ipnet3) - if !ok || lpm != ipnet2 { - t.Errorf("expected %s got %s", ipnet2, lpm) - } -} - -func TestBasicIPv6(t *testing.T) { - tree := New[int]() - // insert - ipnet := netip.MustParsePrefix("2001:db8::/64") - ok := tree.InsertPrefix(ipnet, 1) - if ok { - t.Fatal("should not exist") - } - // check exist - if _, ok := tree.GetPrefix(ipnet); !ok { - t.Errorf("CIDR %s not found", ipnet) - } - - // check does not exist - ipnet2 := netip.MustParsePrefix("2001:db8:1:3:4::/64") - if _, ok := tree.GetPrefix(ipnet2); ok { - t.Errorf("CIDR %s not expected", ipnet2) - } - - // check insert existing prefix updates the value - ok = tree.InsertPrefix(ipnet2, 2) - if ok { - t.Errorf("should not exist: %s", ipnet2) - } - - ok = tree.InsertPrefix(ipnet2, 3) - if !ok { - t.Errorf("should be updated: %s", ipnet2) - } - - if v, ok := tree.GetPrefix(ipnet2); !ok || v != 3 { - t.Errorf("CIDR %s not expected", ipnet2) - } - - // check longer prefix matching - ipnet3 := netip.MustParsePrefix("2001:db8:1:3:4::/96") - lpm, _, ok := tree.LongestPrefixMatch(ipnet3) - if !ok || lpm != ipnet2 { - t.Errorf("expected %s got %s", ipnet2, lpm) - } -} - -func TestInsertGetDelete100K(t *testing.T) { - testCases := []struct { - name string - is6 bool - }{ - { - name: "ipv4", - }, - { - name: "ipv6", - is6: true, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - cidrs := generateRandomCIDRs(tc.is6, 100*1000) - tree := New[string]() - - for k := range cidrs { - ok := tree.InsertPrefix(k, k.String()) - if ok { - t.Errorf("error inserting: %v", k) - } - } - - if tree.Len(tc.is6) != len(cidrs) { - t.Errorf("expected %d nodes on the tree, got %d", len(cidrs), tree.Len(tc.is6)) - } - - list := cidrs.UnsortedList() - for _, k := range list { - if v, ok := tree.GetPrefix(k); !ok { - t.Errorf("CIDR %s not found", k) - return - } else if v != k.String() { - t.Errorf("CIDR value %s not found", k) - return - } - ok := tree.DeletePrefix(k) - if !ok { - t.Errorf("CIDR delete %s error", k) - } - } - - if tree.Len(tc.is6) != 0 { - t.Errorf("No node expected on the tree, got: %d %v", tree.Len(tc.is6), cidrs) - } - }) - } -} - -func Test_findAncestor(t *testing.T) { - tests := []struct { - name string - a netip.Prefix - b netip.Prefix - want netip.Prefix - }{ - { - name: "ipv4 direct parent", - a: netip.MustParsePrefix("192.168.0.0/24"), - b: netip.MustParsePrefix("192.168.1.0/24"), - want: netip.MustParsePrefix("192.168.0.0/23"), - }, - { - name: "ipv4 root parent ", - a: netip.MustParsePrefix("192.168.0.0/24"), - b: netip.MustParsePrefix("1.168.1.0/24"), - want: netip.MustParsePrefix("0.0.0.0/0"), - }, - { - name: "ipv4 parent /1", - a: netip.MustParsePrefix("192.168.0.0/24"), - b: netip.MustParsePrefix("184.168.1.0/24"), - want: netip.MustParsePrefix("128.0.0.0/1"), - }, - { - name: "ipv6 direct parent", - a: netip.MustParsePrefix("fd00:1:1:1::/64"), - b: netip.MustParsePrefix("fd00:1:1:2::/64"), - want: netip.MustParsePrefix("fd00:1:1::/62"), - }, - { - name: "ipv6 root parent ", - a: netip.MustParsePrefix("fd00:1:1:1::/64"), - b: netip.MustParsePrefix("1:1:1:1::/64"), - want: netip.MustParsePrefix("::/0"), - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := findAncestor(tt.a, tt.b); !reflect.DeepEqual(got, tt.want) { - t.Errorf("findAncestor() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_getBitFromAddr(t *testing.T) { - tests := []struct { - name string - ip netip.Addr - pos int - want int - }{ - // 192.168.0.0 - // 11000000.10101000.00000000.00000001 - { - name: "ipv4 first is a one", - ip: netip.MustParseAddr("192.168.0.0"), - pos: 1, - want: 1, - }, - { - name: "ipv4 middle is a zero", - ip: netip.MustParseAddr("192.168.0.0"), - pos: 16, - want: 0, - }, - { - name: "ipv4 middle is a one", - ip: netip.MustParseAddr("192.168.0.0"), - pos: 13, - want: 1, - }, - { - name: "ipv4 last is a zero", - ip: netip.MustParseAddr("192.168.0.0"), - pos: 32, - want: 0, - }, - // 2001:db8::ff00:42:8329 - // 0010000000000001:0000110110111000:0000000000000000:0000000000000000:0000000000000000:1111111100000000:0000000001000010:1000001100101001 - { - name: "ipv6 first is a zero", - ip: netip.MustParseAddr("2001:db8::ff00:42:8329"), - pos: 1, - want: 0, - }, - { - name: "ipv6 middle is a zero", - ip: netip.MustParseAddr("2001:db8::ff00:42:8329"), - pos: 56, - want: 0, - }, - { - name: "ipv6 middle is a one", - ip: netip.MustParseAddr("2001:db8::ff00:42:8329"), - pos: 81, - want: 1, - }, - { - name: "ipv6 last is a one", - ip: netip.MustParseAddr("2001:db8::ff00:42:8329"), - pos: 128, - want: 1, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := getBitFromAddr(tt.ip, tt.pos); got != tt.want { - t.Errorf("getBitFromAddr() = %v, want %v", got, tt.want) - } - }) - } -} - -func TestShortestPrefix(t *testing.T) { - r := New[int]() - - keys := []string{ - "10.0.0.0/8", - "10.21.0.0/16", - "10.221.0.0/16", - "10.1.2.3/32", - "10.1.2.0/24", - "192.168.0.0/24", - "192.168.0.0/16", - } - for _, k := range keys { - ok := r.InsertPrefix(netip.MustParsePrefix(k), 0) - if ok { - t.Errorf("unexpected update on insert %s", k) - } - } - if r.Len(false) != len(keys) { - t.Fatalf("bad len: %v %v", r.Len(false), len(keys)) - } - - type exp struct { - inp string - out string - } - cases := []exp{ - {"192.168.0.3/32", "192.168.0.0/16"}, - {"10.1.2.4/21", "10.0.0.0/8"}, - {"192.168.0.0/16", "192.168.0.0/16"}, - {"192.168.0.0/32", "192.168.0.0/16"}, - {"10.1.2.3/32", "10.0.0.0/8"}, - } - for _, test := range cases { - m, _, ok := r.ShortestPrefixMatch(netip.MustParsePrefix(test.inp)) - if !ok { - t.Fatalf("no match: %v", test) - } - if m != netip.MustParsePrefix(test.out) { - t.Fatalf("mis-match: %v %v", m, test) - } - } - - // not match - _, _, ok := r.ShortestPrefixMatch(netip.MustParsePrefix("0.0.0.0/0")) - if ok { - t.Fatalf("match unexpected for 0.0.0.0/0") - } -} - -func TestLongestPrefixMatch(t *testing.T) { - r := New[int]() - - keys := []string{ - "10.0.0.0/8", - "10.21.0.0/16", - "10.221.0.0/16", - "10.1.2.3/32", - "10.1.2.0/24", - "192.168.0.0/24", - "192.168.0.0/16", - } - for _, k := range keys { - ok := r.InsertPrefix(netip.MustParsePrefix(k), 0) - if ok { - t.Errorf("unexpected update on insert %s", k) - } - } - if r.Len(false) != len(keys) { - t.Fatalf("bad len: %v %v", r.Len(false), len(keys)) - } - - type exp struct { - inp string - out string - } - cases := []exp{ - {"192.168.0.3/32", "192.168.0.0/24"}, - {"10.1.2.4/21", "10.0.0.0/8"}, - {"10.21.2.0/24", "10.21.0.0/16"}, - {"10.1.2.3/32", "10.1.2.3/32"}, - } - for _, test := range cases { - m, _, ok := r.LongestPrefixMatch(netip.MustParsePrefix(test.inp)) - if !ok { - t.Fatalf("no match: %v", test) - } - if m != netip.MustParsePrefix(test.out) { - t.Fatalf("mis-match: %v %v", m, test) - } - } - // not match - _, _, ok := r.LongestPrefixMatch(netip.MustParsePrefix("0.0.0.0/0")) - if ok { - t.Fatalf("match unexpected for 0.0.0.0/0") - } -} - -func TestTopLevelPrefixesV4(t *testing.T) { - r := New[string]() - - keys := []string{ - "10.0.0.0/8", - "10.21.0.0/16", - "10.221.0.0/16", - "10.1.2.3/32", - "10.1.2.0/24", - "192.168.0.0/20", - "192.168.1.0/24", - "172.16.0.0/12", - "172.21.23.0/24", - } - for _, k := range keys { - ok := r.InsertPrefix(netip.MustParsePrefix(k), k) - if ok { - t.Errorf("unexpected update on insert %s", k) - } - } - if r.Len(false) != len(keys) { - t.Fatalf("bad len: %v %v", r.Len(false), len(keys)) - } - - expected := []string{ - "10.0.0.0/8", - "192.168.0.0/20", - "172.16.0.0/12", - } - parents := r.TopLevelPrefixes(false) - if len(parents) != len(expected) { - t.Fatalf("bad len: %v %v", len(parents), len(expected)) - } - - for _, k := range expected { - v, ok := parents[k] - if !ok { - t.Errorf("key %s not found", k) - } - if v != k { - t.Errorf("value expected %s got %s", k, v) - } - } -} - -func TestTopLevelPrefixesV6(t *testing.T) { - r := New[string]() - - keys := []string{ - "2001:db8:1:2:3::/64", - "2001:db8::/64", - "2001:db8:1:1:1::/64", - "2001:db8:1:1:1::/112", - } - for _, k := range keys { - ok := r.InsertPrefix(netip.MustParsePrefix(k), k) - if ok { - t.Errorf("unexpected update on insert %s", k) - } - } - - if r.Len(true) != len(keys) { - t.Fatalf("bad len: %v %v", r.Len(true), len(keys)) - } - - expected := []string{ - "2001:db8::/64", - "2001:db8:1:2:3::/64", - "2001:db8:1:1:1::/64", - } - parents := r.TopLevelPrefixes(true) - if len(parents) != len(expected) { - t.Fatalf("bad len: %v %v", len(parents), len(expected)) - } - - for _, k := range expected { - v, ok := parents[k] - if !ok { - t.Errorf("key %s not found", k) - } - if v != k { - t.Errorf("value expected %s got %s", k, v) - } - } -} - -func TestWalkV4(t *testing.T) { - r := New[int]() - - keys := []string{ - "10.0.0.0/8", - "10.1.0.0/16", - "10.1.1.0/24", - "10.1.1.32/26", - "10.1.1.33/32", - } - for _, k := range keys { - ok := r.InsertPrefix(netip.MustParsePrefix(k), 0) - if ok { - t.Errorf("unexpected update on insert %s", k) - } - } - if r.Len(false) != len(keys) { - t.Fatalf("bad len: %v %v", r.Len(false), len(keys)) - } - - // match exact prefix - path := []string{} - r.WalkPath(netip.MustParsePrefix("10.1.1.32/26"), func(k netip.Prefix, v int) bool { - path = append(path, k.String()) - return false - }) - if !cmp.Equal(path, keys[:4]) { - t.Errorf("Walkpath expected %v got %v", keys[:4], path) - } - // not match on prefix - path = []string{} - r.WalkPath(netip.MustParsePrefix("10.1.1.33/26"), func(k netip.Prefix, v int) bool { - path = append(path, k.String()) - return false - }) - if !cmp.Equal(path, keys[:3]) { - t.Errorf("Walkpath expected %v got %v", keys[:3], path) - } - // match exact prefix - path = []string{} - r.WalkPrefix(netip.MustParsePrefix("10.0.0.0/8"), func(k netip.Prefix, v int) bool { - path = append(path, k.String()) - return false - }) - if !cmp.Equal(path, keys) { - t.Errorf("WalkPrefix expected %v got %v", keys, path) - } - // not match on prefix - path = []string{} - r.WalkPrefix(netip.MustParsePrefix("10.0.0.0/9"), func(k netip.Prefix, v int) bool { - path = append(path, k.String()) - return false - }) - if !cmp.Equal(path, keys[1:]) { - t.Errorf("WalkPrefix expected %v got %v", keys[1:], path) - } -} - -func TestWalkV6(t *testing.T) { - r := New[int]() - - keys := []string{ - "2001:db8::/48", - "2001:db8::/64", - "2001:db8::/96", - "2001:db8::/112", - "2001:db8::/128", - } - for _, k := range keys { - ok := r.InsertPrefix(netip.MustParsePrefix(k), 0) - if ok { - t.Errorf("unexpected update on insert %s", k) - } - } - if r.Len(true) != len(keys) { - t.Fatalf("bad len: %v %v", r.Len(false), len(keys)) - } - - // match exact prefix - path := []string{} - r.WalkPath(netip.MustParsePrefix("2001:db8::/112"), func(k netip.Prefix, v int) bool { - path = append(path, k.String()) - return false - }) - if !cmp.Equal(path, keys[:4]) { - t.Errorf("Walkpath expected %v got %v", keys[:4], path) - } - // not match on prefix - path = []string{} - r.WalkPath(netip.MustParsePrefix("2001:db8::1/112"), func(k netip.Prefix, v int) bool { - path = append(path, k.String()) - return false - }) - if !cmp.Equal(path, keys[:3]) { - t.Errorf("Walkpath expected %v got %v", keys[:3], path) - } - // match exact prefix - path = []string{} - r.WalkPrefix(netip.MustParsePrefix("2001:db8::/48"), func(k netip.Prefix, v int) bool { - path = append(path, k.String()) - return false - }) - if !cmp.Equal(path, keys) { - t.Errorf("WalkPrefix expected %v got %v", keys, path) - } - // not match on prefix - path = []string{} - r.WalkPrefix(netip.MustParsePrefix("2001:db8::/49"), func(k netip.Prefix, v int) bool { - path = append(path, k.String()) - return false - }) - if !cmp.Equal(path, keys[1:]) { - t.Errorf("WalkPrefix expected %v got %v", keys[1:], path) - } -} - -func TestGetHostIPPrefixMatches(t *testing.T) { - r := New[int]() - - keys := []string{ - "10.0.0.0/8", - "10.21.0.0/16", - "10.221.0.0/16", - "10.1.2.3/32", - "10.1.2.0/24", - "192.168.0.0/24", - "192.168.0.0/16", - "2001:db8::/48", - "2001:db8::/64", - "2001:db8::/96", - } - for _, k := range keys { - ok := r.InsertPrefix(netip.MustParsePrefix(k), 0) - if ok { - t.Errorf("unexpected update on insert %s", k) - } - } - - type exp struct { - inp string - out []string - } - cases := []exp{ - {"192.168.0.3", []string{"192.168.0.0/24", "192.168.0.0/16"}}, - {"10.1.2.4", []string{"10.1.2.0/24", "10.0.0.0/8"}}, - {"10.1.2.0", []string{"10.0.0.0/8"}}, - {"10.1.2.255", []string{"10.0.0.0/8"}}, - {"192.168.0.0", []string{}}, - {"192.168.1.0", []string{"192.168.0.0/16"}}, - {"10.1.2.255", []string{"10.0.0.0/8"}}, - {"2001:db8::1", []string{"2001:db8::/96", "2001:db8::/64", "2001:db8::/48"}}, - {"2001:db8::", []string{}}, - {"2001:db8::ffff:ffff:ffff:ffff", []string{"2001:db8::/64", "2001:db8::/48"}}, - } - for _, test := range cases { - m := r.GetHostIPPrefixMatches(netip.MustParseAddr(test.inp)) - in := []netip.Prefix{} - for k := range m { - in = append(in, k) - } - out := []netip.Prefix{} - for _, s := range test.out { - out = append(out, netip.MustParsePrefix(s)) - } - - // sort by prefix bits to avoid flakes - sort.Slice(in, func(i, j int) bool { return in[i].Bits() < in[j].Bits() }) - sort.Slice(out, func(i, j int) bool { return out[i].Bits() < out[j].Bits() }) - if !reflect.DeepEqual(in, out) { - t.Fatalf("mis-match: %v %v", in, out) - } - } - - // not match - _, _, ok := r.ShortestPrefixMatch(netip.MustParsePrefix("0.0.0.0/0")) - if ok { - t.Fatalf("match unexpected for 0.0.0.0/0") - } -} - -func Test_prefixContainIP(t *testing.T) { - tests := []struct { - name string - prefix netip.Prefix - ip netip.Addr - want bool - }{ - { - name: "IPv4 contains", - prefix: netip.MustParsePrefix("192.168.0.0/24"), - ip: netip.MustParseAddr("192.168.0.1"), - want: true, - }, - { - name: "IPv4 network address", - prefix: netip.MustParsePrefix("192.168.0.0/24"), - ip: netip.MustParseAddr("192.168.0.0"), - }, - { - name: "IPv4 broadcast address", - prefix: netip.MustParsePrefix("192.168.0.0/24"), - ip: netip.MustParseAddr("192.168.0.255"), - }, - { - name: "IPv4 does not contain", - prefix: netip.MustParsePrefix("192.168.0.0/24"), - ip: netip.MustParseAddr("192.168.1.2"), - }, - { - name: "IPv6 contains", - prefix: netip.MustParsePrefix("2001:db2::/96"), - ip: netip.MustParseAddr("2001:db2::1"), - want: true, - }, - { - name: "IPv6 network address", - prefix: netip.MustParsePrefix("2001:db2::/96"), - ip: netip.MustParseAddr("2001:db2::"), - }, - { - name: "IPv6 broadcast address", - prefix: netip.MustParsePrefix("2001:db2::/96"), - ip: netip.MustParseAddr("2001:db2::ffff:ffff"), - want: true, - }, - { - name: "IPv6 does not contain", - prefix: netip.MustParsePrefix("2001:db2::/96"), - ip: netip.MustParseAddr("2001:db2:1:2:3::1"), - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := prefixContainIP(tt.prefix, tt.ip); got != tt.want { - t.Errorf("prefixContainIP() = %v, want %v", got, tt.want) - } - }) - } -} - -func BenchmarkInsertUpdate(b *testing.B) { - r := New[bool]() - ipList := generateRandomCIDRs(true, 20000).UnsortedList() - for _, ip := range ipList { - r.InsertPrefix(ip, true) - } - - b.ResetTimer() - for n := 0; n < b.N; n++ { - r.InsertPrefix(ipList[n%len(ipList)], true) - } -} - -func generateRandomCIDRs(is6 bool, number int) sets.Set[netip.Prefix] { - n := 4 - if is6 { - n = 16 - } - cidrs := sets.Set[netip.Prefix]{} - rand.New(rand.NewSource(time.Now().UnixNano())) - for i := 0; i < number; i++ { - bytes := make([]byte, n) - for i := 0; i < n; i++ { - bytes[i] = uint8(rand.Intn(255)) - } - - ip, ok := netip.AddrFromSlice(bytes) - if !ok { - continue - } - - bits := rand.Intn(n * 8) - prefix := netip.PrefixFrom(ip, bits).Masked() - if prefix.IsValid() { - cidrs.Insert(prefix) - } - } - return cidrs -}