remove iptree from tree

This commit is contained in:
Antonio Ojea 2024-05-20 19:55:21 +00:00
parent b04ca186d8
commit f36975b193
2 changed files with 0 additions and 1460 deletions

View File

@ -1,679 +0,0 @@
/*
Copyright 2023 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package iptree
import (
"fmt"
"math/bits"
"net/netip"
)
// iptree implement a radix tree that uses IP prefixes as nodes and allows to store values in each node.
// Example:
//
// r := New[int]()
//
// prefixes := []string{
// "0.0.0.0/0",
// "10.0.0.0/8",
// "10.0.0.0/16",
// "10.1.0.0/16",
// "10.1.1.0/24",
// "10.1.244.0/24",
// "10.0.0.0/24",
// "10.0.0.3/32",
// "192.168.0.0/24",
// "192.168.0.0/28",
// "192.168.129.0/28",
// }
// for _, k := range prefixes {
// r.InsertPrefix(netip.MustParsePrefix(k), 0)
// }
//
// (*) means the node is not public, is not storing any value
//
// 0.0.0.0/0 --- 10.0.0.0/8 --- *10.0.0.0/15 --- 10.0.0.0/16 --- 10.0.0.0/24 --- 10.0.0.3/32
// | |
// | \ -------- 10.1.0.0/16 --- 10.1.1.0/24
// | |
// | \ ------- 10.1.244.0/24
// |
// \------ *192.168.0.0/16 --- 192.168.0.0/24 --- 192.168.0.0/28
// |
// \ -------- 192.168.129.0/28
// node is an element of radix tree with a netip.Prefix optimized to store IP prefixes.
type node[T any] struct {
// prefix network CIDR
prefix netip.Prefix
// public nodes are used to store values
public bool
val T
child [2]*node[T] // binary tree
}
// mergeChild allow to compress the tree
// when n has exactly one child and no value
// p -> n -> b -> c ==> p -> b -> c
func (n *node[T]) mergeChild() {
// public nodes can not be merged
if n.public {
return
}
// can not merge if there are two children
if n.child[0] != nil &&
n.child[1] != nil {
return
}
// can not merge if there are no children
if n.child[0] == nil &&
n.child[1] == nil {
return
}
// find the child and merge it
var child *node[T]
if n.child[0] != nil {
child = n.child[0]
} else if n.child[1] != nil {
child = n.child[1]
}
n.prefix = child.prefix
n.public = child.public
n.val = child.val
n.child = child.child
// remove any references from the deleted node
// to avoid memory leak
child.child[0] = nil
child.child[1] = nil
}
// Tree is a radix tree for IPv4 and IPv6 networks.
type Tree[T any] struct {
rootV4 *node[T]
rootV6 *node[T]
}
// New creates a new Radix Tree for IP addresses.
func New[T any]() *Tree[T] {
return &Tree[T]{
rootV4: &node[T]{
prefix: netip.PrefixFrom(netip.IPv4Unspecified(), 0),
},
rootV6: &node[T]{
prefix: netip.PrefixFrom(netip.IPv6Unspecified(), 0),
},
}
}
// GetPrefix returns the stored value and true if the exact prefix exists in the tree.
func (t *Tree[T]) GetPrefix(prefix netip.Prefix) (T, bool) {
var zeroT T
n := t.rootV4
if prefix.Addr().Is6() {
n = t.rootV6
}
bitPosition := 0
// mask the address for sanity
address := prefix.Masked().Addr()
// we can't check longer than the request mask
mask := prefix.Bits()
// walk the network bits of the prefix
for bitPosition < mask {
// Look for a child checking the bit position after the mask
n = n.child[getBitFromAddr(address, bitPosition+1)]
if n == nil {
return zeroT, false
}
// check we are in the right branch comparing the suffixes
if !n.prefix.Contains(address) {
return zeroT, false
}
// update the new bit position with the new node mask
bitPosition = n.prefix.Bits()
}
// check if this node is a public node and contains a prefix
if n != nil && n.public && n.prefix == prefix {
return n.val, true
}
return zeroT, false
}
// LongestPrefixMatch returns the longest prefix match, the stored value and true if exist.
// For example, considering the following prefixes 192.168.20.16/28 and 192.168.0.0/16,
// when the address 192.168.20.19/32 is looked up it will return 192.168.20.16/28.
func (t *Tree[T]) LongestPrefixMatch(prefix netip.Prefix) (netip.Prefix, T, bool) {
n := t.rootV4
if prefix.Addr().Is6() {
n = t.rootV6
}
var last *node[T]
// bit position is given by the mask bits
bitPosition := 0
// mask the address
address := prefix.Masked().Addr()
mask := prefix.Bits()
// walk the network bits of the prefix
for bitPosition < mask {
if n.public {
last = n
}
// Look for a child checking the bit position after the mask
n = n.child[getBitFromAddr(address, bitPosition+1)]
if n == nil {
break
}
// check we are in the right branch comparing the suffixes
if !n.prefix.Contains(address) {
break
}
// update the new bit position with the new node mask
bitPosition = n.prefix.Bits()
}
if n != nil && n.public && n.prefix == prefix {
last = n
}
if last != nil {
return last.prefix, last.val, true
}
var zeroT T
return netip.Prefix{}, zeroT, false
}
// ShortestPrefixMatch returns the shortest prefix match, the stored value and true if exist.
// For example, considering the following prefixes 192.168.20.16/28 and 192.168.0.0/16,
// when the address 192.168.20.19/32 is looked up it will return 192.168.0.0/16.
func (t *Tree[T]) ShortestPrefixMatch(prefix netip.Prefix) (netip.Prefix, T, bool) {
var zeroT T
n := t.rootV4
if prefix.Addr().Is6() {
n = t.rootV6
}
// bit position is given by the mask bits
bitPosition := 0
// mask the address
address := prefix.Masked().Addr()
mask := prefix.Bits()
for bitPosition < mask {
if n.public {
return n.prefix, n.val, true
}
// Look for a child checking the bit position after the mask
n = n.child[getBitFromAddr(address, bitPosition+1)]
if n == nil {
return netip.Prefix{}, zeroT, false
}
// check we are in the right branch comparing the suffixes
if !n.prefix.Contains(address) {
return netip.Prefix{}, zeroT, false
}
// update the new bit position with the new node mask
bitPosition = n.prefix.Bits()
}
if n != nil && n.public && n.prefix == prefix {
return n.prefix, n.val, true
}
return netip.Prefix{}, zeroT, false
}
// InsertPrefix is used to add a new entry or update
// an existing entry. Returns true if updated.
func (t *Tree[T]) InsertPrefix(prefix netip.Prefix, v T) bool {
n := t.rootV4
if prefix.Addr().Is6() {
n = t.rootV6
}
var parent *node[T]
// bit position is given by the mask bits
bitPosition := 0
// mask the address
address := prefix.Masked().Addr()
mask := prefix.Bits()
for bitPosition < mask {
// Look for a child checking the bit position after the mask
childIndex := getBitFromAddr(address, bitPosition+1)
parent = n
n = n.child[childIndex]
// if no child create a new one with
if n == nil {
parent.child[childIndex] = &node[T]{
public: true,
val: v,
prefix: prefix,
}
return false
}
// update the new bit position with the new node mask
bitPosition = n.prefix.Bits()
// continue if we are in the right branch and current
// node is our parent
if n.prefix.Contains(address) && bitPosition <= mask {
continue
}
// Split the node and add a new child:
// - Case 1: parent -> child -> n
// - Case 2: parent -> newnode |--> child
// |--> n
child := &node[T]{
prefix: prefix,
public: true,
val: v,
}
// Case 1: existing node is a sibling
if prefix.Contains(n.prefix.Addr()) && bitPosition > mask {
// parent to child
parent.child[childIndex] = child
pos := prefix.Bits() + 1
// calculate if the sibling is at the left or right
child.child[getBitFromAddr(n.prefix.Addr(), pos)] = n
return false
}
// Case 2: existing node has the same mask but different base address
// add common ancestor and branch on it
ancestor := findAncestor(prefix, n.prefix)
link := &node[T]{
prefix: ancestor,
}
pos := parent.prefix.Bits() + 1
parent.child[getBitFromAddr(ancestor.Addr(), pos)] = link
// ancestor -> children
pos = ancestor.Bits() + 1
idxChild := getBitFromAddr(prefix.Addr(), pos)
idxN := getBitFromAddr(n.prefix.Addr(), pos)
if idxChild == idxN {
panic(fmt.Sprintf("wrong ancestor %s: child %s N %s", ancestor.String(), prefix.String(), n.prefix.String()))
}
link.child[idxChild] = child
link.child[idxN] = n
return false
}
// if already exist update it and make it public
if n != nil && n.prefix == prefix {
if n.public {
n.val = v
n.public = true
return true
}
n.val = v
n.public = true
return false
}
return false
}
// DeletePrefix delete the exact prefix and return true if it existed.
func (t *Tree[T]) DeletePrefix(prefix netip.Prefix) bool {
root := t.rootV4
if prefix.Addr().Is6() {
root = t.rootV6
}
var parent *node[T]
n := root
// bit position is given by the mask bits
bitPosition := 0
// mask the address
address := prefix.Masked().Addr()
mask := prefix.Bits()
for bitPosition < mask {
// Look for a child checking the bit position after the mask
parent = n
n = n.child[getBitFromAddr(address, bitPosition+1)]
if n == nil {
return false
}
// check we are in the right branch comparing the suffixes
if !n.prefix.Contains(address) {
return false
}
// update the new bit position with the new node mask
bitPosition = n.prefix.Bits()
}
// check if the node contains the prefix we want to delete
if n.prefix != prefix {
return false
}
// Delete the value
n.public = false
var zeroT T
n.val = zeroT
nodeChildren := 0
if n.child[0] != nil {
nodeChildren++
}
if n.child[1] != nil {
nodeChildren++
}
// If there is a parent and this node does not have any children
// this is a leaf so we can delete this node.
// - parent -> child(to be deleted)
if parent != nil && nodeChildren == 0 {
if parent.child[0] != nil && parent.child[0] == n {
parent.child[0] = nil
} else if parent.child[1] != nil && parent.child[1] == n {
parent.child[1] = nil
} else {
panic("wrong parent")
}
n = nil
}
// Check if we should merge this node
// The root node can not be merged
if n != root && nodeChildren == 1 {
n.mergeChild()
}
// Check if we should merge the parent's other child
// parent -> deletedNode
// |--> child
parentChildren := 0
if parent != nil {
if parent.child[0] != nil {
parentChildren++
}
if parent.child[1] != nil {
parentChildren++
}
if parent != root && parentChildren == 1 && !parent.public {
parent.mergeChild()
}
}
return true
}
// for testing, returns the number of public nodes in the tree.
func (t *Tree[T]) Len(isV6 bool) int {
count := 0
t.DepthFirstWalk(isV6, func(k netip.Prefix, v T) bool {
count++
return false
})
return count
}
// WalkFn is used when walking the tree. Takes a
// key and value, returning if iteration should
// be terminated.
type WalkFn[T any] func(s netip.Prefix, v T) bool
// DepthFirstWalk is used to walk the tree of the corresponding IP family
func (t *Tree[T]) DepthFirstWalk(isIPv6 bool, fn WalkFn[T]) {
if isIPv6 {
recursiveWalk(t.rootV6, fn)
}
recursiveWalk(t.rootV4, fn)
}
// recursiveWalk is used to do a pre-order walk of a node
// recursively. Returns true if the walk should be aborted
func recursiveWalk[T any](n *node[T], fn WalkFn[T]) bool {
if n == nil {
return true
}
// Visit the public values if any
if n.public && fn(n.prefix, n.val) {
return true
}
// Recurse on the children
if n.child[0] != nil {
if recursiveWalk(n.child[0], fn) {
return true
}
}
if n.child[1] != nil {
if recursiveWalk(n.child[1], fn) {
return true
}
}
return false
}
// WalkPrefix is used to walk the tree under a prefix
func (t *Tree[T]) WalkPrefix(prefix netip.Prefix, fn WalkFn[T]) {
n := t.rootV4
if prefix.Addr().Is6() {
n = t.rootV6
}
bitPosition := 0
// mask the address for sanity
address := prefix.Masked().Addr()
// we can't check longer than the request mask
mask := prefix.Bits()
// walk the network bits of the prefix
for bitPosition < mask {
// Look for a child checking the bit position after the mask
n = n.child[getBitFromAddr(address, bitPosition+1)]
if n == nil {
return
}
// check we are in the right branch comparing the suffixes
if !n.prefix.Contains(address) {
break
}
// update the new bit position with the new node mask
bitPosition = n.prefix.Bits()
}
recursiveWalk[T](n, fn)
}
// WalkPath is used to walk the tree, but only visiting nodes
// from the root down to a given IP prefix. Where WalkPrefix walks
// all the entries *under* the given prefix, this walks the
// entries *above* the given prefix.
func (t *Tree[T]) WalkPath(path netip.Prefix, fn WalkFn[T]) {
n := t.rootV4
if path.Addr().Is6() {
n = t.rootV6
}
bitPosition := 0
// mask the address for sanity
address := path.Masked().Addr()
// we can't check longer than the request mask
mask := path.Bits()
// walk the network bits of the prefix
for bitPosition < mask {
// Visit the public values if any
if n.public && fn(n.prefix, n.val) {
return
}
// Look for a child checking the bit position after the mask
n = n.child[getBitFromAddr(address, bitPosition+1)]
if n == nil {
return
}
// check we are in the right branch comparing the suffixes
if !n.prefix.Contains(address) {
return
}
// update the new bit position with the new node mask
bitPosition = n.prefix.Bits()
}
// check if this node is a public node and contains a prefix
if n != nil && n.public && n.prefix == path {
fn(n.prefix, n.val)
}
}
// TopLevelPrefixes is used to return a map with all the Top Level prefixes
// from the corresponding IP family and its values.
// For example, if the tree contains entries for 10.0.0.0/8, 10.1.0.0/16, and 192.168.0.0/16,
// this will return 10.0.0.0/8 and 192.168.0.0/16.
func (t *Tree[T]) TopLevelPrefixes(isIPv6 bool) map[string]T {
if isIPv6 {
return t.topLevelPrefixes(t.rootV6)
}
return t.topLevelPrefixes(t.rootV4)
}
// topLevelPrefixes is used to return a map with all the Top Level prefixes and its values
func (t *Tree[T]) topLevelPrefixes(root *node[T]) map[string]T {
result := map[string]T{}
queue := []*node[T]{root}
for len(queue) > 0 {
n := queue[0]
queue = queue[1:]
// store and continue, only interested on the top level prefixes
if n.public {
result[n.prefix.String()] = n.val
continue
}
if n.child[0] != nil {
queue = append(queue, n.child[0])
}
if n.child[1] != nil {
queue = append(queue, n.child[1])
}
}
return result
}
// GetHostIPPrefixMatches returns the list of prefixes that contain the specified Host IP.
// An IP is considered a Host IP if is within the subnet range and is not the network address
// or, if IPv4, the broadcast address (RFC 1878).
func (t *Tree[T]) GetHostIPPrefixMatches(ip netip.Addr) map[netip.Prefix]T {
// walk the tree to find all the prefixes containing this IP
ipPrefix := netip.PrefixFrom(ip, ip.BitLen())
prefixes := map[netip.Prefix]T{}
t.WalkPath(ipPrefix, func(k netip.Prefix, v T) bool {
if prefixContainIP(k, ipPrefix.Addr()) {
prefixes[k] = v
}
return false
})
return prefixes
}
// assume starts at 0 from the MSB: 0.1.2......31
// return 0 or 1
func getBitFromAddr(ip netip.Addr, pos int) int {
bytes := ip.AsSlice()
// get the byte in the slice
index := (pos - 1) / 8
if index >= len(bytes) {
panic(fmt.Sprintf("ip %s pos %d index %d bytes %v", ip, pos, index, bytes))
}
// get the offset inside the byte
offset := (pos - 1) % 8
// check if the bit is set
if bytes[index]&(uint8(0x80)>>offset) > 0 {
return 1
}
return 0
}
// find the common subnet, aka the one with the common prefix
func findAncestor(a, b netip.Prefix) netip.Prefix {
bytesA := a.Addr().AsSlice()
bytesB := b.Addr().AsSlice()
bytes := make([]byte, len(bytesA))
max := a.Bits()
if l := b.Bits(); l < max {
max = l
}
mask := 0
for i := range bytesA {
xor := bytesA[i] ^ bytesB[i]
if xor == 0 {
bytes[i] = bytesA[i]
mask += 8
} else {
pos := bits.LeadingZeros8(xor)
mask += pos
// mask off the non leading zeros
bytes[i] = bytesA[i] & (^uint8(0) << (8 - pos))
break
}
}
if mask > max {
mask = max
}
addr, ok := netip.AddrFromSlice(bytes)
if !ok {
panic(bytes)
}
ancestor := netip.PrefixFrom(addr, mask)
return ancestor.Masked()
}
// prefixContainIP returns true if the given IP is contained with the prefix,
// is not the network address and also, if IPv4, is not the broadcast address.
// This is required because the Kubernetes allocators reserve these addresses
// so IPAddresses can not block deletion of this ranges.
func prefixContainIP(prefix netip.Prefix, ip netip.Addr) bool {
// if the IP is the network address is not contained
if prefix.Masked().Addr() == ip {
return false
}
// the broadcast address is not considered contained for IPv4
if !ip.Is6() {
ipLast, err := broadcastAddress(prefix)
if err != nil || ipLast == ip {
return false
}
}
return prefix.Contains(ip)
}
// TODO(aojea) consolidate all these IPs utils
// pkg/registry/core/service/ipallocator/ipallocator.go
// broadcastAddress returns the broadcast address of the subnet
// The broadcast address is obtained by setting all the host bits
// in a subnet to 1.
// network 192.168.0.0/24 : subnet bits 24 host bits 32 - 24 = 8
// broadcast address 192.168.0.255
func broadcastAddress(subnet netip.Prefix) (netip.Addr, error) {
base := subnet.Masked().Addr()
bytes := base.AsSlice()
// get all the host bits from the subnet
n := 8*len(bytes) - subnet.Bits()
// set all the host bits to 1
for i := len(bytes) - 1; i >= 0 && n > 0; i-- {
if n >= 8 {
bytes[i] = 0xff
n -= 8
} else {
mask := ^uint8(0) >> (8 - n)
bytes[i] |= mask
break
}
}
addr, ok := netip.AddrFromSlice(bytes)
if !ok {
return netip.Addr{}, fmt.Errorf("invalid address %v", bytes)
}
return addr, nil
}

View File

@ -1,781 +0,0 @@
/*
Copyright 2023 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package iptree
import (
"math/rand"
"net/netip"
"reflect"
"sort"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"k8s.io/apimachinery/pkg/util/sets"
)
func Test_InsertGetDelete(t *testing.T) {
testCases := []struct {
name string
prefix netip.Prefix
}{
{
name: "ipv4",
prefix: netip.MustParsePrefix("192.168.0.0/24"),
},
{
name: "ipv6",
prefix: netip.MustParsePrefix("fd00:1:2:3::/124"),
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
tree := New[int]()
ok := tree.InsertPrefix(tc.prefix, 1)
if ok {
t.Fatal("should not exist")
}
if _, ok := tree.GetPrefix(tc.prefix); !ok {
t.Errorf("CIDR %s not found", tc.prefix)
}
if ok := tree.DeletePrefix(tc.prefix); !ok {
t.Errorf("CIDR %s not deleted", tc.prefix)
}
if _, ok := tree.GetPrefix(tc.prefix); ok {
t.Errorf("CIDR %s found", tc.prefix)
}
})
}
}
func TestBasicIPv4(t *testing.T) {
tree := New[int]()
// insert
ipnet := netip.MustParsePrefix("192.168.0.0/24")
ok := tree.InsertPrefix(ipnet, 1)
if ok {
t.Fatal("should not exist")
}
// check exist
if _, ok := tree.GetPrefix(ipnet); !ok {
t.Errorf("CIDR %s not found", ipnet)
}
// check does not exist
ipnet2 := netip.MustParsePrefix("12.1.0.0/16")
if _, ok := tree.GetPrefix(ipnet2); ok {
t.Errorf("CIDR %s not expected", ipnet2)
}
// check insert existing prefix updates the value
ok = tree.InsertPrefix(ipnet2, 2)
if ok {
t.Errorf("should not exist: %s", ipnet2)
}
ok = tree.InsertPrefix(ipnet2, 3)
if !ok {
t.Errorf("should be updated: %s", ipnet2)
}
if v, ok := tree.GetPrefix(ipnet2); !ok || v != 3 {
t.Errorf("CIDR %s not expected", ipnet2)
}
// check longer prefix matching
ipnet3 := netip.MustParsePrefix("12.1.0.2/32")
lpm, _, ok := tree.LongestPrefixMatch(ipnet3)
if !ok || lpm != ipnet2 {
t.Errorf("expected %s got %s", ipnet2, lpm)
}
}
func TestBasicIPv6(t *testing.T) {
tree := New[int]()
// insert
ipnet := netip.MustParsePrefix("2001:db8::/64")
ok := tree.InsertPrefix(ipnet, 1)
if ok {
t.Fatal("should not exist")
}
// check exist
if _, ok := tree.GetPrefix(ipnet); !ok {
t.Errorf("CIDR %s not found", ipnet)
}
// check does not exist
ipnet2 := netip.MustParsePrefix("2001:db8:1:3:4::/64")
if _, ok := tree.GetPrefix(ipnet2); ok {
t.Errorf("CIDR %s not expected", ipnet2)
}
// check insert existing prefix updates the value
ok = tree.InsertPrefix(ipnet2, 2)
if ok {
t.Errorf("should not exist: %s", ipnet2)
}
ok = tree.InsertPrefix(ipnet2, 3)
if !ok {
t.Errorf("should be updated: %s", ipnet2)
}
if v, ok := tree.GetPrefix(ipnet2); !ok || v != 3 {
t.Errorf("CIDR %s not expected", ipnet2)
}
// check longer prefix matching
ipnet3 := netip.MustParsePrefix("2001:db8:1:3:4::/96")
lpm, _, ok := tree.LongestPrefixMatch(ipnet3)
if !ok || lpm != ipnet2 {
t.Errorf("expected %s got %s", ipnet2, lpm)
}
}
func TestInsertGetDelete100K(t *testing.T) {
testCases := []struct {
name string
is6 bool
}{
{
name: "ipv4",
},
{
name: "ipv6",
is6: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
cidrs := generateRandomCIDRs(tc.is6, 100*1000)
tree := New[string]()
for k := range cidrs {
ok := tree.InsertPrefix(k, k.String())
if ok {
t.Errorf("error inserting: %v", k)
}
}
if tree.Len(tc.is6) != len(cidrs) {
t.Errorf("expected %d nodes on the tree, got %d", len(cidrs), tree.Len(tc.is6))
}
list := cidrs.UnsortedList()
for _, k := range list {
if v, ok := tree.GetPrefix(k); !ok {
t.Errorf("CIDR %s not found", k)
return
} else if v != k.String() {
t.Errorf("CIDR value %s not found", k)
return
}
ok := tree.DeletePrefix(k)
if !ok {
t.Errorf("CIDR delete %s error", k)
}
}
if tree.Len(tc.is6) != 0 {
t.Errorf("No node expected on the tree, got: %d %v", tree.Len(tc.is6), cidrs)
}
})
}
}
func Test_findAncestor(t *testing.T) {
tests := []struct {
name string
a netip.Prefix
b netip.Prefix
want netip.Prefix
}{
{
name: "ipv4 direct parent",
a: netip.MustParsePrefix("192.168.0.0/24"),
b: netip.MustParsePrefix("192.168.1.0/24"),
want: netip.MustParsePrefix("192.168.0.0/23"),
},
{
name: "ipv4 root parent ",
a: netip.MustParsePrefix("192.168.0.0/24"),
b: netip.MustParsePrefix("1.168.1.0/24"),
want: netip.MustParsePrefix("0.0.0.0/0"),
},
{
name: "ipv4 parent /1",
a: netip.MustParsePrefix("192.168.0.0/24"),
b: netip.MustParsePrefix("184.168.1.0/24"),
want: netip.MustParsePrefix("128.0.0.0/1"),
},
{
name: "ipv6 direct parent",
a: netip.MustParsePrefix("fd00:1:1:1::/64"),
b: netip.MustParsePrefix("fd00:1:1:2::/64"),
want: netip.MustParsePrefix("fd00:1:1::/62"),
},
{
name: "ipv6 root parent ",
a: netip.MustParsePrefix("fd00:1:1:1::/64"),
b: netip.MustParsePrefix("1:1:1:1::/64"),
want: netip.MustParsePrefix("::/0"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := findAncestor(tt.a, tt.b); !reflect.DeepEqual(got, tt.want) {
t.Errorf("findAncestor() = %v, want %v", got, tt.want)
}
})
}
}
func Test_getBitFromAddr(t *testing.T) {
tests := []struct {
name string
ip netip.Addr
pos int
want int
}{
// 192.168.0.0
// 11000000.10101000.00000000.00000001
{
name: "ipv4 first is a one",
ip: netip.MustParseAddr("192.168.0.0"),
pos: 1,
want: 1,
},
{
name: "ipv4 middle is a zero",
ip: netip.MustParseAddr("192.168.0.0"),
pos: 16,
want: 0,
},
{
name: "ipv4 middle is a one",
ip: netip.MustParseAddr("192.168.0.0"),
pos: 13,
want: 1,
},
{
name: "ipv4 last is a zero",
ip: netip.MustParseAddr("192.168.0.0"),
pos: 32,
want: 0,
},
// 2001:db8::ff00:42:8329
// 0010000000000001:0000110110111000:0000000000000000:0000000000000000:0000000000000000:1111111100000000:0000000001000010:1000001100101001
{
name: "ipv6 first is a zero",
ip: netip.MustParseAddr("2001:db8::ff00:42:8329"),
pos: 1,
want: 0,
},
{
name: "ipv6 middle is a zero",
ip: netip.MustParseAddr("2001:db8::ff00:42:8329"),
pos: 56,
want: 0,
},
{
name: "ipv6 middle is a one",
ip: netip.MustParseAddr("2001:db8::ff00:42:8329"),
pos: 81,
want: 1,
},
{
name: "ipv6 last is a one",
ip: netip.MustParseAddr("2001:db8::ff00:42:8329"),
pos: 128,
want: 1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := getBitFromAddr(tt.ip, tt.pos); got != tt.want {
t.Errorf("getBitFromAddr() = %v, want %v", got, tt.want)
}
})
}
}
func TestShortestPrefix(t *testing.T) {
r := New[int]()
keys := []string{
"10.0.0.0/8",
"10.21.0.0/16",
"10.221.0.0/16",
"10.1.2.3/32",
"10.1.2.0/24",
"192.168.0.0/24",
"192.168.0.0/16",
}
for _, k := range keys {
ok := r.InsertPrefix(netip.MustParsePrefix(k), 0)
if ok {
t.Errorf("unexpected update on insert %s", k)
}
}
if r.Len(false) != len(keys) {
t.Fatalf("bad len: %v %v", r.Len(false), len(keys))
}
type exp struct {
inp string
out string
}
cases := []exp{
{"192.168.0.3/32", "192.168.0.0/16"},
{"10.1.2.4/21", "10.0.0.0/8"},
{"192.168.0.0/16", "192.168.0.0/16"},
{"192.168.0.0/32", "192.168.0.0/16"},
{"10.1.2.3/32", "10.0.0.0/8"},
}
for _, test := range cases {
m, _, ok := r.ShortestPrefixMatch(netip.MustParsePrefix(test.inp))
if !ok {
t.Fatalf("no match: %v", test)
}
if m != netip.MustParsePrefix(test.out) {
t.Fatalf("mis-match: %v %v", m, test)
}
}
// not match
_, _, ok := r.ShortestPrefixMatch(netip.MustParsePrefix("0.0.0.0/0"))
if ok {
t.Fatalf("match unexpected for 0.0.0.0/0")
}
}
func TestLongestPrefixMatch(t *testing.T) {
r := New[int]()
keys := []string{
"10.0.0.0/8",
"10.21.0.0/16",
"10.221.0.0/16",
"10.1.2.3/32",
"10.1.2.0/24",
"192.168.0.0/24",
"192.168.0.0/16",
}
for _, k := range keys {
ok := r.InsertPrefix(netip.MustParsePrefix(k), 0)
if ok {
t.Errorf("unexpected update on insert %s", k)
}
}
if r.Len(false) != len(keys) {
t.Fatalf("bad len: %v %v", r.Len(false), len(keys))
}
type exp struct {
inp string
out string
}
cases := []exp{
{"192.168.0.3/32", "192.168.0.0/24"},
{"10.1.2.4/21", "10.0.0.0/8"},
{"10.21.2.0/24", "10.21.0.0/16"},
{"10.1.2.3/32", "10.1.2.3/32"},
}
for _, test := range cases {
m, _, ok := r.LongestPrefixMatch(netip.MustParsePrefix(test.inp))
if !ok {
t.Fatalf("no match: %v", test)
}
if m != netip.MustParsePrefix(test.out) {
t.Fatalf("mis-match: %v %v", m, test)
}
}
// not match
_, _, ok := r.LongestPrefixMatch(netip.MustParsePrefix("0.0.0.0/0"))
if ok {
t.Fatalf("match unexpected for 0.0.0.0/0")
}
}
func TestTopLevelPrefixesV4(t *testing.T) {
r := New[string]()
keys := []string{
"10.0.0.0/8",
"10.21.0.0/16",
"10.221.0.0/16",
"10.1.2.3/32",
"10.1.2.0/24",
"192.168.0.0/20",
"192.168.1.0/24",
"172.16.0.0/12",
"172.21.23.0/24",
}
for _, k := range keys {
ok := r.InsertPrefix(netip.MustParsePrefix(k), k)
if ok {
t.Errorf("unexpected update on insert %s", k)
}
}
if r.Len(false) != len(keys) {
t.Fatalf("bad len: %v %v", r.Len(false), len(keys))
}
expected := []string{
"10.0.0.0/8",
"192.168.0.0/20",
"172.16.0.0/12",
}
parents := r.TopLevelPrefixes(false)
if len(parents) != len(expected) {
t.Fatalf("bad len: %v %v", len(parents), len(expected))
}
for _, k := range expected {
v, ok := parents[k]
if !ok {
t.Errorf("key %s not found", k)
}
if v != k {
t.Errorf("value expected %s got %s", k, v)
}
}
}
func TestTopLevelPrefixesV6(t *testing.T) {
r := New[string]()
keys := []string{
"2001:db8:1:2:3::/64",
"2001:db8::/64",
"2001:db8:1:1:1::/64",
"2001:db8:1:1:1::/112",
}
for _, k := range keys {
ok := r.InsertPrefix(netip.MustParsePrefix(k), k)
if ok {
t.Errorf("unexpected update on insert %s", k)
}
}
if r.Len(true) != len(keys) {
t.Fatalf("bad len: %v %v", r.Len(true), len(keys))
}
expected := []string{
"2001:db8::/64",
"2001:db8:1:2:3::/64",
"2001:db8:1:1:1::/64",
}
parents := r.TopLevelPrefixes(true)
if len(parents) != len(expected) {
t.Fatalf("bad len: %v %v", len(parents), len(expected))
}
for _, k := range expected {
v, ok := parents[k]
if !ok {
t.Errorf("key %s not found", k)
}
if v != k {
t.Errorf("value expected %s got %s", k, v)
}
}
}
func TestWalkV4(t *testing.T) {
r := New[int]()
keys := []string{
"10.0.0.0/8",
"10.1.0.0/16",
"10.1.1.0/24",
"10.1.1.32/26",
"10.1.1.33/32",
}
for _, k := range keys {
ok := r.InsertPrefix(netip.MustParsePrefix(k), 0)
if ok {
t.Errorf("unexpected update on insert %s", k)
}
}
if r.Len(false) != len(keys) {
t.Fatalf("bad len: %v %v", r.Len(false), len(keys))
}
// match exact prefix
path := []string{}
r.WalkPath(netip.MustParsePrefix("10.1.1.32/26"), func(k netip.Prefix, v int) bool {
path = append(path, k.String())
return false
})
if !cmp.Equal(path, keys[:4]) {
t.Errorf("Walkpath expected %v got %v", keys[:4], path)
}
// not match on prefix
path = []string{}
r.WalkPath(netip.MustParsePrefix("10.1.1.33/26"), func(k netip.Prefix, v int) bool {
path = append(path, k.String())
return false
})
if !cmp.Equal(path, keys[:3]) {
t.Errorf("Walkpath expected %v got %v", keys[:3], path)
}
// match exact prefix
path = []string{}
r.WalkPrefix(netip.MustParsePrefix("10.0.0.0/8"), func(k netip.Prefix, v int) bool {
path = append(path, k.String())
return false
})
if !cmp.Equal(path, keys) {
t.Errorf("WalkPrefix expected %v got %v", keys, path)
}
// not match on prefix
path = []string{}
r.WalkPrefix(netip.MustParsePrefix("10.0.0.0/9"), func(k netip.Prefix, v int) bool {
path = append(path, k.String())
return false
})
if !cmp.Equal(path, keys[1:]) {
t.Errorf("WalkPrefix expected %v got %v", keys[1:], path)
}
}
func TestWalkV6(t *testing.T) {
r := New[int]()
keys := []string{
"2001:db8::/48",
"2001:db8::/64",
"2001:db8::/96",
"2001:db8::/112",
"2001:db8::/128",
}
for _, k := range keys {
ok := r.InsertPrefix(netip.MustParsePrefix(k), 0)
if ok {
t.Errorf("unexpected update on insert %s", k)
}
}
if r.Len(true) != len(keys) {
t.Fatalf("bad len: %v %v", r.Len(false), len(keys))
}
// match exact prefix
path := []string{}
r.WalkPath(netip.MustParsePrefix("2001:db8::/112"), func(k netip.Prefix, v int) bool {
path = append(path, k.String())
return false
})
if !cmp.Equal(path, keys[:4]) {
t.Errorf("Walkpath expected %v got %v", keys[:4], path)
}
// not match on prefix
path = []string{}
r.WalkPath(netip.MustParsePrefix("2001:db8::1/112"), func(k netip.Prefix, v int) bool {
path = append(path, k.String())
return false
})
if !cmp.Equal(path, keys[:3]) {
t.Errorf("Walkpath expected %v got %v", keys[:3], path)
}
// match exact prefix
path = []string{}
r.WalkPrefix(netip.MustParsePrefix("2001:db8::/48"), func(k netip.Prefix, v int) bool {
path = append(path, k.String())
return false
})
if !cmp.Equal(path, keys) {
t.Errorf("WalkPrefix expected %v got %v", keys, path)
}
// not match on prefix
path = []string{}
r.WalkPrefix(netip.MustParsePrefix("2001:db8::/49"), func(k netip.Prefix, v int) bool {
path = append(path, k.String())
return false
})
if !cmp.Equal(path, keys[1:]) {
t.Errorf("WalkPrefix expected %v got %v", keys[1:], path)
}
}
func TestGetHostIPPrefixMatches(t *testing.T) {
r := New[int]()
keys := []string{
"10.0.0.0/8",
"10.21.0.0/16",
"10.221.0.0/16",
"10.1.2.3/32",
"10.1.2.0/24",
"192.168.0.0/24",
"192.168.0.0/16",
"2001:db8::/48",
"2001:db8::/64",
"2001:db8::/96",
}
for _, k := range keys {
ok := r.InsertPrefix(netip.MustParsePrefix(k), 0)
if ok {
t.Errorf("unexpected update on insert %s", k)
}
}
type exp struct {
inp string
out []string
}
cases := []exp{
{"192.168.0.3", []string{"192.168.0.0/24", "192.168.0.0/16"}},
{"10.1.2.4", []string{"10.1.2.0/24", "10.0.0.0/8"}},
{"10.1.2.0", []string{"10.0.0.0/8"}},
{"10.1.2.255", []string{"10.0.0.0/8"}},
{"192.168.0.0", []string{}},
{"192.168.1.0", []string{"192.168.0.0/16"}},
{"10.1.2.255", []string{"10.0.0.0/8"}},
{"2001:db8::1", []string{"2001:db8::/96", "2001:db8::/64", "2001:db8::/48"}},
{"2001:db8::", []string{}},
{"2001:db8::ffff:ffff:ffff:ffff", []string{"2001:db8::/64", "2001:db8::/48"}},
}
for _, test := range cases {
m := r.GetHostIPPrefixMatches(netip.MustParseAddr(test.inp))
in := []netip.Prefix{}
for k := range m {
in = append(in, k)
}
out := []netip.Prefix{}
for _, s := range test.out {
out = append(out, netip.MustParsePrefix(s))
}
// sort by prefix bits to avoid flakes
sort.Slice(in, func(i, j int) bool { return in[i].Bits() < in[j].Bits() })
sort.Slice(out, func(i, j int) bool { return out[i].Bits() < out[j].Bits() })
if !reflect.DeepEqual(in, out) {
t.Fatalf("mis-match: %v %v", in, out)
}
}
// not match
_, _, ok := r.ShortestPrefixMatch(netip.MustParsePrefix("0.0.0.0/0"))
if ok {
t.Fatalf("match unexpected for 0.0.0.0/0")
}
}
func Test_prefixContainIP(t *testing.T) {
tests := []struct {
name string
prefix netip.Prefix
ip netip.Addr
want bool
}{
{
name: "IPv4 contains",
prefix: netip.MustParsePrefix("192.168.0.0/24"),
ip: netip.MustParseAddr("192.168.0.1"),
want: true,
},
{
name: "IPv4 network address",
prefix: netip.MustParsePrefix("192.168.0.0/24"),
ip: netip.MustParseAddr("192.168.0.0"),
},
{
name: "IPv4 broadcast address",
prefix: netip.MustParsePrefix("192.168.0.0/24"),
ip: netip.MustParseAddr("192.168.0.255"),
},
{
name: "IPv4 does not contain",
prefix: netip.MustParsePrefix("192.168.0.0/24"),
ip: netip.MustParseAddr("192.168.1.2"),
},
{
name: "IPv6 contains",
prefix: netip.MustParsePrefix("2001:db2::/96"),
ip: netip.MustParseAddr("2001:db2::1"),
want: true,
},
{
name: "IPv6 network address",
prefix: netip.MustParsePrefix("2001:db2::/96"),
ip: netip.MustParseAddr("2001:db2::"),
},
{
name: "IPv6 broadcast address",
prefix: netip.MustParsePrefix("2001:db2::/96"),
ip: netip.MustParseAddr("2001:db2::ffff:ffff"),
want: true,
},
{
name: "IPv6 does not contain",
prefix: netip.MustParsePrefix("2001:db2::/96"),
ip: netip.MustParseAddr("2001:db2:1:2:3::1"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := prefixContainIP(tt.prefix, tt.ip); got != tt.want {
t.Errorf("prefixContainIP() = %v, want %v", got, tt.want)
}
})
}
}
func BenchmarkInsertUpdate(b *testing.B) {
r := New[bool]()
ipList := generateRandomCIDRs(true, 20000).UnsortedList()
for _, ip := range ipList {
r.InsertPrefix(ip, true)
}
b.ResetTimer()
for n := 0; n < b.N; n++ {
r.InsertPrefix(ipList[n%len(ipList)], true)
}
}
func generateRandomCIDRs(is6 bool, number int) sets.Set[netip.Prefix] {
n := 4
if is6 {
n = 16
}
cidrs := sets.Set[netip.Prefix]{}
rand.New(rand.NewSource(time.Now().UnixNano()))
for i := 0; i < number; i++ {
bytes := make([]byte, n)
for i := 0; i < n; i++ {
bytes[i] = uint8(rand.Intn(255))
}
ip, ok := netip.AddrFromSlice(bytes)
if !ok {
continue
}
bits := rand.Intn(n * 8)
prefix := netip.PrefixFrom(ip, bits).Masked()
if prefix.IsValid() {
cidrs.Insert(prefix)
}
}
return cidrs
}