diff --git a/src/libs/Cargo.lock b/src/libs/Cargo.lock index 1cdd13212..9f2be8395 100644 --- a/src/libs/Cargo.lock +++ b/src/libs/Cargo.lock @@ -37,6 +37,12 @@ dependencies = [ "memchr", ] +[[package]] +name = "anstyle" +version = "1.0.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55cc3b69f167a1ef2e161439aa98aed94e6028e5f9a59be9a6ffb47aef1651f9" + [[package]] name = "anyhow" version = "1.0.57" @@ -472,6 +478,12 @@ dependencies = [ "winapi", ] +[[package]] +name = "downcast" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1435fa1053d8b2fbbe9be7e97eca7f33d37b28409959813daefc1446a14247f1" + [[package]] name = "either" version = "1.6.1" @@ -510,6 +522,12 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "fragile" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c2141d6d6c8512188a7891b4b01590a45f6dac67afb4f255c4124dbb86d4eaa" + [[package]] name = "funty" version = "2.0.0" @@ -826,10 +844,12 @@ dependencies = [ "kata-types", "lazy_static", "libc", + "mockall", "nix 0.24.2", "num_cpus", "oci-spec", "once_cell", + "pci-ids", "rand", "runtime-spec", "safe-path", @@ -941,6 +961,12 @@ dependencies = [ "autocfg", ] +[[package]] +name = "minimal-lexical" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" + [[package]] name = "miniz_oxide" version = "0.7.3" @@ -961,6 +987,32 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "mockall" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39a6bfcc6c8c7eed5ee98b9c3e33adc726054389233e201c95dab2d41a3839d2" +dependencies = [ + "cfg-if", + "downcast", + "fragile", + "mockall_derive", + "predicates", + "predicates-tree", +] + +[[package]] +name = "mockall_derive" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25ca3004c2efe9011bd4e461bd8256445052b9615405b4f7ea43fc8ca5c20898" +dependencies = [ + "cfg-if", + "proc-macro2", + "quote", + "syn 2.0.66", +] + [[package]] name = "multimap" version = "0.8.3" @@ -992,6 +1044,16 @@ dependencies = [ "pin-utils", ] +[[package]] +name = "nom" +version = "7.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" +dependencies = [ + "memchr", + "minimal-lexical", +] + [[package]] name = "ntapi" version = "0.4.1" @@ -1102,6 +1164,19 @@ dependencies = [ "winapi", ] +[[package]] +name = "pci-ids" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d88ae3281b415d856e9c2ddbcdd5961e71c1a3e90138512c04d720241853a6af" +dependencies = [ + "nom", + "phf", + "phf_codegen", + "proc-macro2", + "quote", +] + [[package]] name = "petgraph" version = "0.5.1" @@ -1112,6 +1187,44 @@ dependencies = [ "indexmap", ] +[[package]] +name = "phf" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ade2d8b8f33c7333b51bcf0428d37e217e9f32192ae4772156f65063b8ce03dc" +dependencies = [ + "phf_shared", +] + +[[package]] +name = "phf_codegen" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8d39688d359e6b34654d328e262234662d16cc0f60ec8dcbe5e718709342a5a" +dependencies = [ + "phf_generator", + "phf_shared", +] + +[[package]] +name = "phf_generator" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48e4cc64c2ad9ebe670cb8fd69dd50ae301650392e81c05f9bfcb2d5bdbc24b0" +dependencies = [ + "phf_shared", + "rand", +] + +[[package]] +name = "phf_shared" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90fcb95eef784c2ac79119d1dd819e162b5da872ce6f3c3abe1e8ca1c082f72b" +dependencies = [ + "siphasher", +] + [[package]] name = "pin-project" version = "1.0.12" @@ -1156,6 +1269,32 @@ version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eb9f9e6e233e5c4a35559a617bf40a4ec447db2e84c20b55a6f83167b7e57872" +[[package]] +name = "predicates" +version = "3.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e9086cc7640c29a356d1a29fd134380bee9d8f79a17410aa76e7ad295f42c97" +dependencies = [ + "anstyle", + "predicates-core", +] + +[[package]] +name = "predicates-core" +version = "1.0.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae8177bee8e75d6846599c6b9ff679ed51e882816914eec639944d7c9aa11931" + +[[package]] +name = "predicates-tree" +version = "1.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41b740d195ed3166cd147c8047ec98db0e22ec019eb8eeb76d343b795304fb13" +dependencies = [ + "predicates-core", + "termtree", +] + [[package]] name = "proc-macro-crate" version = "0.1.5" @@ -1676,6 +1815,12 @@ version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f27f6278552951f1f2b8cf9da965d10969b2efdea95a6ec47987ab46edfe263a" +[[package]] +name = "siphasher" +version = "0.3.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38b58827f4464d87d377d175e90bf58eb00fd8716ff0a62f80356b5e61555d0d" + [[package]] name = "slab" version = "0.4.6" @@ -1871,6 +2016,12 @@ dependencies = [ "winapi", ] +[[package]] +name = "termtree" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3369f5ac52d5eb6ab48c6b4ffdc8efbcad6b89c765749064ba298f2c68a16a76" + [[package]] name = "test-utils" version = "0.1.0" diff --git a/src/libs/kata-sys-util/Cargo.toml b/src/libs/kata-sys-util/Cargo.toml index 079339c9c..77574d6b8 100644 --- a/src/libs/kata-sys-util/Cargo.toml +++ b/src/libs/kata-sys-util/Cargo.toml @@ -28,6 +28,8 @@ subprocess = "0.2.8" rand = "0.8.5" thiserror = "1.0.30" hex = "0.4.3" +pci-ids = "0.2.5" +mockall = "0.13.1" kata-types = { path = "../kata-types" } oci-spec = { version = "0.6.8", features = ["runtime"] } diff --git a/src/libs/kata-sys-util/src/lib.rs b/src/libs/kata-sys-util/src/lib.rs index 1e51c6e85..8131835c6 100644 --- a/src/libs/kata-sys-util/src/lib.rs +++ b/src/libs/kata-sys-util/src/lib.rs @@ -14,6 +14,7 @@ pub mod k8s; pub mod mount; pub mod netns; pub mod numa; +pub mod pcilibs; pub mod protection; pub mod rand; pub mod spec; diff --git a/src/libs/kata-sys-util/src/pcilibs/devices.rs b/src/libs/kata-sys-util/src/pcilibs/devices.rs new file mode 100644 index 000000000..c989075f4 --- /dev/null +++ b/src/libs/kata-sys-util/src/pcilibs/devices.rs @@ -0,0 +1,160 @@ +// Copyright (c) 2024 Ant Group +// +// SPDX-License-Identifier: Apache-2.0 +// +#![allow(dead_code)] + +use super::pci_manager::{MemoryResourceTrait, PCIDevice, PCIDeviceManager, PCIDevices}; + +const PCI_DEVICES_ROOT: &str = "/sys/bus/pci/devices"; +const PCI_NVIDIA_VENDOR_ID: u16 = 0x10DE; +const PCI3D_CONTROLLER_CLASS: u32 = 0x030200; + +struct NvidiaPCIDevice { + vendor_id: u16, + class_id: u32, +} + +impl NvidiaPCIDevice { + pub fn new(vendor_id: u16, class_id: u32) -> Self { + Self { + vendor_id, + class_id, + } + } + + pub fn get_bars_max_addressable_memory(&self) -> (u64, u64) { + let mut max_32bit = 2 * 1024 * 1024; + let mut max_64bit = 2 * 1024 * 1024; + + let nvgpu_devices = self.get_pci_devices(Some(self.vendor_id)); + for dev in nvgpu_devices { + let (mem_size_32bit, mem_size_64bit) = dev.resources.get_total_addressable_memory(true); + if max_32bit < mem_size_32bit { + max_32bit = mem_size_32bit; + } + if max_64bit < mem_size_64bit { + max_64bit = mem_size_64bit; + } + } + + (max_32bit * 2, max_64bit) + } + + fn is_vga_controller(&self, device: &PCIDevice) -> bool { + self.class_id == device.class + } + + fn is_3d_controller(&self, device: &PCIDevice) -> bool { + self.class_id == device.class + } + + fn is_gpu(&self, device: &PCIDevice) -> bool { + self.is_vga_controller(device) || self.is_3d_controller(device) + } +} + +impl PCIDevices for NvidiaPCIDevice { + fn get_pci_devices(&self, vendor: Option) -> Vec { + let mut nvidia_devices: Vec = Vec::new(); + let devices = PCIDeviceManager::new(PCI_DEVICES_ROOT) + .get_all_devices(vendor) + .unwrap_or_else(|_| vec![]); + for dev in devices.iter() { + if self.is_gpu(dev) { + nvidia_devices.push(dev.clone()); + } + } + + return nvidia_devices; + } +} + +pub fn get_bars_max_addressable_memory() -> (u64, u64) { + let nvdevice = NvidiaPCIDevice::new(PCI_NVIDIA_VENDOR_ID, PCI3D_CONTROLLER_CLASS); + let (max_32bit, max_64bit) = nvdevice.get_bars_max_addressable_memory(); + + (max_32bit, max_64bit) +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + use std::path::PathBuf; + + use super::*; + use crate::pcilibs::pci_manager::{ + MemoryResource, MemoryResources, MockPCIDevices, PCI_BASE_ADDRESS_MEM_TYPE32, + PCI_BASE_ADDRESS_MEM_TYPE64, + }; + use mockall::predicate::*; + + #[test] + fn test_get_bars_max_addressable_memory() { + let pci_device = PCIDevice { + device_path: PathBuf::new(), + address: "0000:00:00.0".to_string(), + vendor: PCI_NVIDIA_VENDOR_ID, + class: PCI3D_CONTROLLER_CLASS, + class_name: "3D Controller".to_string(), + device: 0x1c82, + device_name: "NVIDIA Device".to_string(), + driver: "nvidia".to_string(), + iommu_group: 0, + numa_node: 0, + resources: MemoryResources::default(), + }; + let devices = vec![pci_device.clone()]; + + // Mock PCI device manager and devices + let mut mock_pci_manager = MockPCIDevices::default(); + // Setting up Mock to return a device + mock_pci_manager + .expect_get_pci_devices() + .with(eq(Some(PCI_NVIDIA_VENDOR_ID))) + .returning(move |_| devices.clone()); + + // Create NvidiaPCIDevice + let nvidia_device = NvidiaPCIDevice::new(PCI_NVIDIA_VENDOR_ID, PCI3D_CONTROLLER_CLASS); + + // Prepare memory resources + let mut resources: MemoryResources = HashMap::new(); + // resource0 memsz = end - start => 1024 + resources.insert( + 0, + MemoryResource { + start: 0, + end: 1023, + flags: PCI_BASE_ADDRESS_MEM_TYPE32, + path: PathBuf::from("/fake/path/resource0"), + }, + ); + // resource1 memsz = end - start => 1024 + resources.insert( + 1, + MemoryResource { + start: 1024, + end: 2047, + flags: PCI_BASE_ADDRESS_MEM_TYPE64, + path: PathBuf::from("/fake/path/resource1"), + }, + ); + + let pci_device_with_resources = PCIDevice { + resources: resources.clone(), + ..pci_device + }; + + mock_pci_manager + .expect_get_pci_devices() + .with(eq(Some(PCI_NVIDIA_VENDOR_ID))) + .returning(move |_| vec![pci_device_with_resources.clone()]); + + // Call the function under test + let (max_32bit, max_64bit) = nvidia_device.get_bars_max_addressable_memory(); + + // Assert the results + assert_eq!(max_32bit, 2 * 2 * 1024 * 1024); + assert_eq!(max_64bit, 2 * 1024 * 1024); + } +} diff --git a/src/libs/kata-sys-util/src/pcilibs/mod.rs b/src/libs/kata-sys-util/src/pcilibs/mod.rs new file mode 100644 index 000000000..453f71634 --- /dev/null +++ b/src/libs/kata-sys-util/src/pcilibs/mod.rs @@ -0,0 +1,6 @@ +// Copyright (c) 2024 Ant Group +// +// SPDX-License-Identifier: Apache-2.0 +// +mod devices; +mod pci_manager; diff --git a/src/libs/kata-sys-util/src/pcilibs/pci_manager.rs b/src/libs/kata-sys-util/src/pcilibs/pci_manager.rs new file mode 100644 index 000000000..bffc5fa87 --- /dev/null +++ b/src/libs/kata-sys-util/src/pcilibs/pci_manager.rs @@ -0,0 +1,446 @@ +// Copyright (c) 2024 Ant Group +// +// SPDX-License-Identifier: Apache-2.0 +// +#![allow(dead_code)] + +use std::collections::HashMap; +use std::fs; +use std::io; +use std::path::PathBuf; + +use mockall::automock; +use pci_ids::{Classes, Vendors}; + +const PCI_DEV_DOMAIN: &str = "0000"; +const PCI_CONFIG_SPACE_SZ: u64 = 256; + +const UNKNOWN_DEVICE: &str = "UNKNOWN_DEVICE"; +const UNKNOWN_CLASS: &str = "UNKNOWN_CLASS"; + +const PCI_IOV_NUM_BAR: usize = 6; +const PCI_BASE_ADDRESS_MEM_TYPE_MASK: u64 = 0x06; + +pub(crate) const PCI_BASE_ADDRESS_MEM_TYPE32: u64 = 0x00; // 32 bit address +pub(crate) const PCI_BASE_ADDRESS_MEM_TYPE64: u64 = 0x04; // 64 bit address + +fn address_to_id(address: &str) -> u64 { + let cleaned_address = address.replace(":", "").replace(".", ""); + u64::from_str_radix(&cleaned_address, 16).unwrap_or(0) +} + +// Calculate the next power of 2. +fn calc_next_power_of_2(mut n: u64) -> u64 { + if n < 1 { + return 1_u64; + } + + n -= 1; + n |= n >> 1; + n |= n >> 2; + n |= n >> 4; + n |= n >> 8; + n |= n >> 16; + n |= n >> 32; + n + 1 +} + +#[derive(Clone, Debug, Default)] +pub(crate) struct MemoryResource { + pub(crate) start: u64, + pub(crate) end: u64, + pub(crate) flags: u64, + pub(crate) path: PathBuf, +} + +pub(crate) type MemoryResources = HashMap; + +pub(crate) trait MemoryResourceTrait { + fn get_total_addressable_memory(&self, round_up: bool) -> (u64, u64); +} + +impl MemoryResourceTrait for MemoryResources { + fn get_total_addressable_memory(&self, round_up: bool) -> (u64, u64) { + let mut num_bar = 0; + let mut mem_size_32bit = 0u64; + let mut mem_size_64bit = 0u64; + + let mut keys: Vec<_> = self.keys().cloned().collect(); + keys.sort(); + + for key in keys { + if key as usize >= PCI_IOV_NUM_BAR || num_bar == PCI_IOV_NUM_BAR { + break; + } + num_bar += 1; + + if let Some(region) = self.get(&key) { + let flags = region.flags & PCI_BASE_ADDRESS_MEM_TYPE_MASK; + let mem_type_32bit = flags == PCI_BASE_ADDRESS_MEM_TYPE32; + let mem_type_64bit = flags == PCI_BASE_ADDRESS_MEM_TYPE64; + let mem_size = (region.end - region.start + 1) as u64; + + if mem_type_32bit { + mem_size_32bit += mem_size; + } + if mem_type_64bit { + mem_size_64bit += mem_size; + } + } + } + + if round_up { + mem_size_32bit = calc_next_power_of_2(mem_size_32bit); + mem_size_64bit = calc_next_power_of_2(mem_size_64bit); + } + + (mem_size_32bit, mem_size_64bit) + } +} + +#[automock] +pub trait PCIDevices { + fn get_pci_devices(&self, vendor: Option) -> Vec; +} + +#[derive(Clone, Debug, Default)] +pub struct PCIDevice { + pub(crate) device_path: PathBuf, + pub(crate) address: String, + pub(crate) vendor: u16, + pub(crate) class: u32, + pub(crate) class_name: String, + pub(crate) device: u16, + pub(crate) device_name: String, + pub(crate) driver: String, + pub(crate) iommu_group: i64, + pub(crate) numa_node: i64, + pub(crate) resources: MemoryResources, +} + +pub struct PCIDeviceManager { + pci_devices_root: PathBuf, +} + +impl PCIDeviceManager { + pub fn new(pci_devices_root: &str) -> Self { + PCIDeviceManager { + pci_devices_root: PathBuf::from(pci_devices_root), + } + } + + pub fn get_all_devices(&self, vendor: Option) -> io::Result> { + let mut pci_devices = Vec::new(); + let device_dirs = fs::read_dir(&self.pci_devices_root)?; + + let mut cache: HashMap = HashMap::new(); + + for entry in device_dirs { + let device_dir = entry?; + let device_address = device_dir.file_name().to_string_lossy().to_string(); + if let Ok(device) = self.get_device_by_pci_bus_id(&device_address, vendor, &mut cache) { + if let Some(dev) = device { + pci_devices.push(dev); + } + } + } + + pci_devices.sort_by_key(|dev| address_to_id(&dev.address)); + + Ok(pci_devices) + } + + fn get_device_by_pci_bus_id( + &self, + address: &str, + vendor: Option, + cache: &mut HashMap, + ) -> io::Result> { + if let Some(device) = cache.get(address) { + return Ok(Some(device.clone())); + } + + let device_path = self.pci_devices_root.join(address); + + // read vendor ID + let vendor_str = fs::read_to_string(device_path.join("vendor"))?; + let vendor_id = u16::from_str_radix(vendor_str.trim().trim_start_matches("0x"), 16) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + if let Some(vend_id) = vendor { + if vendor_id != vend_id { + return Ok(None); + } + } + + let class_str = fs::read_to_string(device_path.join("class"))?; + let class_id = u32::from_str_radix(class_str.trim().trim_start_matches("0x"), 16).unwrap(); + + let device_str = fs::read_to_string(device_path.join("device"))?; + let device_id = + u16::from_str_radix(device_str.trim().trim_start_matches("0x"), 16).unwrap(); + + let driver = match fs::read_link(device_path.join("driver")) { + Ok(path) => path.file_name().unwrap().to_string_lossy().to_string(), + Err(_) => String::new(), + }; + + let iommu_group = match fs::read_link(device_path.join("iommu_group")) { + Ok(path) => path + .file_name() + .unwrap() + .to_string_lossy() + .into_owned() + .parse::() + .unwrap_or(-1), + Err(_) => -1, + }; + + let numa_node = fs::read_to_string(device_path.join("numa_node")) + .map(|numa| numa.trim().parse::().unwrap_or(-1)) + .unwrap_or(-1); + + let resources = self.parse_resources(&device_path)?; + + let mut device_name = UNKNOWN_DEVICE.to_string(); + for vendor in Vendors::iter() { + for device in vendor.devices() { + if vendor.id() == vendor_id && device.id() == device_id { + device_name = device.name().to_owned(); + break; + } + } + } + + let mut class_name = UNKNOWN_CLASS.to_string(); + for class in Classes::iter() { + if u32::from(class.id()) == class_id { + class_name = class.name().to_owned(); + break; + } + } + + let pci_device = PCIDevice { + device_path, + address: address.to_string(), + vendor: vendor_id, + class: class_id, + device: device_id, + driver, + iommu_group, + numa_node, + resources, + device_name, + class_name, + }; + + cache.insert(address.to_string(), pci_device.clone()); + + Ok(Some(pci_device)) + } + + fn parse_resources(&self, device_path: &PathBuf) -> io::Result { + let content = fs::read_to_string(device_path.join("resource"))?; + let mut resources: MemoryResources = MemoryResources::new(); + for (i, line) in content.lines().enumerate() { + let values: Vec<&str> = line.split_whitespace().collect(); + if values.len() != 3 { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("there's more than 3 entries in line '{}'", i), + )); + } + + let mem_start = u64::from_str_radix(values[0].trim_start_matches("0x"), 16) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + let mem_end = u64::from_str_radix(values[1].trim_start_matches("0x"), 16) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + let mem_flags = u64::from_str_radix(values[2].trim_start_matches("0x"), 16) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + + if mem_end > mem_start { + resources.insert( + i, + MemoryResource { + start: mem_start, + end: mem_end, + flags: mem_flags, + path: device_path.join(format!("resource{}", i)), + }, + ); + } + } + + Ok(resources) + } +} + +/// Checks if the given BDF corresponds to a PCIe device. +/// The sysbus_pci_root is the path "/sys/bus/pci/devices" +pub fn is_pcie_device(bdf: &str, sysbus_pci_root: &str) -> bool { + let bdf_with_domain = if bdf.split(':').count() == 2 { + format!("{}:{}", PCI_DEV_DOMAIN, bdf) + } else { + bdf.to_string() + }; + + let config_path = PathBuf::from(sysbus_pci_root) + .join(bdf_with_domain) + .join("config"); + + match fs::metadata(config_path) { + Ok(metadata) => metadata.len() > PCI_CONFIG_SPACE_SZ, + // Error reading the file, assume it's not a PCIe device + Err(_) => false, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::fs; + use std::io::Write; + use std::path::{Path, PathBuf}; + + const MOCK_PCI_DEVICES_ROOT: &str = "tests/mock_devices"; + // domain number + const TEST_PCI_DEV_DOMAIN: &str = "0000"; + // sysfs path + const MOCK_SYS_BUS_PCI_DEVICES: &str = "/tmp/bus/pci/devices"; + + // Mock data + fn setup_mock_device_files() { + // Create mock path and files for PCI devices + let device_path = PathBuf::from(MOCK_PCI_DEVICES_ROOT).join("0000:ff:1f.0"); + fs::create_dir_all(&device_path).unwrap(); + fs::write(device_path.join("vendor"), "0x8086").unwrap(); + fs::write(device_path.join("device"), "0x1234").unwrap(); + fs::write(device_path.join("class"), "0x060100").unwrap(); + fs::write(device_path.join("numa_node"), "0").unwrap(); + fs::write( + device_path.join("resource"), + "0x00000000 0x0000ffff 0x00000404\n", + ) + .unwrap(); + } + // Mock data + fn cleanup_mock_device_files() { + // Create mock path and files for PCI devices + let device_path = PathBuf::from(MOCK_PCI_DEVICES_ROOT).join("0000:ff:1f.0"); + // Clean up + let _ = fs::remove_file(device_path); + } + + #[test] + fn test_calc_next_power_of_2() { + assert_eq!(calc_next_power_of_2(0), 1); + assert_eq!(calc_next_power_of_2(1), 1); + assert_eq!(calc_next_power_of_2(6), 8); + assert_eq!(calc_next_power_of_2(9), 16); + assert_eq!(calc_next_power_of_2(15), 16); + assert_eq!(calc_next_power_of_2(16), 16); + assert_eq!(calc_next_power_of_2(17), 32); + } + + #[test] + fn test_get_total_addressable_memory() { + let mut resources: MemoryResources = HashMap::new(); + + // Adding a 32b memory region + resources.insert( + 0, + MemoryResource { + start: 0, + end: 1023, + flags: PCI_BASE_ADDRESS_MEM_TYPE32, + path: PathBuf::from("/path/resource0"), + }, + ); + + // Adding a 64b memory region + resources.insert( + 1, + MemoryResource { + start: 1024, + end: 2047, + flags: PCI_BASE_ADDRESS_MEM_TYPE64, + path: PathBuf::from("/path/resource1"), + }, + ); + + let (mem32, mem64) = resources.get_total_addressable_memory(false); + assert_eq!(mem32, 1024); + assert_eq!(mem64, 1024); + + // Test with rounding up + let (mem32, mem64) = resources.get_total_addressable_memory(true); + + // Nearest power of 2 is the number itself + assert_eq!(mem32, 1024); + assert_eq!(mem64, 1024); + } + + #[test] + fn test_get_all_devices() { + // Setup mock data + setup_mock_device_files(); + + // Initialize PCI device manager with the mock path + let manager = PCIDeviceManager::new(MOCK_PCI_DEVICES_ROOT); + + // Get all devices + let devices_result = manager.get_all_devices(None); + + assert!(devices_result.is_ok()); + let devices = devices_result.unwrap(); + assert_eq!(devices.len(), 1); + + let device = &devices[0]; + assert_eq!(device.vendor, 0x8086); + assert_eq!(device.device, 0x1234); + assert_eq!(device.class, 0x060100); + + // Cleanup mock data + cleanup_mock_device_files() + } + + #[test] + fn test_parse_resources() { + let manager = PCIDeviceManager::new(MOCK_PCI_DEVICES_ROOT); + let device_path = PathBuf::from(MOCK_PCI_DEVICES_ROOT).join("0000:ff:1f.0"); + + let resources_result = manager.parse_resources(&device_path); + assert!(resources_result.is_ok()); + + let resources = resources_result.unwrap(); + assert_eq!(resources.len(), 1); + + let resource = resources.get(&0).unwrap(); + assert_eq!(resource.start, 0x00000000); + assert_eq!(resource.end, 0x0000ffff); + assert_eq!(resource.flags, 0x00000404); + } + + #[test] + fn test_is_pcie_device() { + // Create a mock PCI device config file + let bdf = format!("{}:ff:00.0", TEST_PCI_DEV_DOMAIN); + let config_path = Path::new(MOCK_SYS_BUS_PCI_DEVICES) + .join(&bdf) + .join("config"); + let _ = fs::create_dir_all(config_path.parent().unwrap()); + + // Write a file with a size larger than PCI_CONFIG_SPACE_SZ + let mut file = fs::File::create(&config_path).unwrap(); + // Test size greater than PCI_CONFIG_SPACE_SZ + file.write_all(&vec![0; 512]).unwrap(); + + // It should be true + assert!(is_pcie_device( + &format!("ff:00.0"), + MOCK_SYS_BUS_PCI_DEVICES + )); + + // Clean up + let _ = fs::remove_file(config_path); + } +}