diff --git a/src/agent/src/device/block_device_handler.rs b/src/agent/src/device/block_device_handler.rs index 5e22ac28d4..f944e182ba 100644 --- a/src/agent/src/device/block_device_handler.rs +++ b/src/agent/src/device/block_device_handler.rs @@ -234,48 +234,89 @@ impl UeventMatcher for VirtioBlkCCWMatcher { #[cfg(test)] mod tests { use super::*; + use crate::device::test_helpers; + use rstest::rstest; #[cfg(target_arch = "s390x")] use std::str::FromStr; + // Test constants + const TEST_DEVNAME: &str = "vda"; + const TEST_PCI_RELPATH: &str = "/0000:00:0a.0"; + const TEST_ROOT_COMPLEX: &str = "00"; + + // Helper to create a standard PCI uevent for testing + fn create_pci_uevent( + devname: &str, + relpath: &str, + root_complex: &str, + virtio_id: u32, + ) -> crate::uevent::Uevent { + let root_bus = create_pci_root_bus_path(root_complex); + let mut uev = crate::uevent::Uevent::default(); + uev.action = crate::linux_abi::U_EVENT_ACTION_ADD.to_string(); + uev.subsystem = BLOCK.to_string(); + uev.devname = devname.to_string(); + uev.devpath = format!("{root_bus}{relpath}/virtio{virtio_id}/block/{devname}"); + uev + } + + #[rstest] + #[case::matcher_a_matches_uev_a("/0000:00:0a.0", "/0000:00:0a.0", 4, true)] + #[case::matcher_b_matches_uev_b( + "/0000:00:0a.0/0000:00:0b.0", + "/0000:00:0a.0/0000:00:0b.0", + 0, + true + )] + #[case::matcher_a_rejects_uev_b("/0000:00:0a.0", "/0000:00:0a.0/0000:00:0b.0", 0, false)] + #[case::matcher_b_rejects_uev_a("/0000:00:0a.0/0000:00:0b.0", "/0000:00:0a.0", 4, false)] #[tokio::test] - #[allow(clippy::redundant_clone)] - async fn test_virtio_blk_matcher() { - let root_bus = create_pci_root_bus_path("00"); - let devname = "vda"; + async fn test_virtio_blk_pci_matcher_basic_matching( + #[case] matcher_relpath: &str, + #[case] uevent_relpath: &str, + #[case] virtio_id: u32, + #[case] should_match: bool, + ) { + let matcher = VirtioBlkPciMatcher::new(matcher_relpath, TEST_ROOT_COMPLEX); + let uev = create_pci_uevent(TEST_DEVNAME, uevent_relpath, TEST_ROOT_COMPLEX, virtio_id); - let mut uev_a = crate::uevent::Uevent::default(); - let relpath_a = "/0000:00:0a.0"; - uev_a.action = crate::linux_abi::U_EVENT_ACTION_ADD.to_string(); - uev_a.subsystem = BLOCK.to_string(); - uev_a.devname = devname.to_string(); - uev_a.devpath = format!("{root_bus}{relpath_a}/virtio4/block/{devname}"); - let matcher_a = VirtioBlkPciMatcher::new(relpath_a, "00"); + assert_eq!( + matcher.is_match(&uev), + should_match, + "Matcher for '{}' should {} uevent for '{}'", + matcher_relpath, + if should_match { "match" } else { "reject" }, + uevent_relpath + ); + } - let mut uev_b = uev_a.clone(); - let relpath_b = "/0000:00:0a.0/0000:00:0b.0"; - uev_b.devpath = format!("{root_bus}{relpath_b}/virtio0/block/{devname}"); - let matcher_b = VirtioBlkPciMatcher::new(relpath_b, "00"); + #[rstest] + #[case::partition_vda1("vda1", "vda1")] + #[case::partition_vda91("vda91", "vda91")] + #[tokio::test] + async fn test_virtio_blk_pci_matcher_rejects_partitions( + #[case] partition_devname: &str, + #[case] partition_suffix: &str, + ) { + let root_bus = create_pci_root_bus_path(TEST_ROOT_COMPLEX); + let matcher = VirtioBlkPciMatcher::new(TEST_PCI_RELPATH, TEST_ROOT_COMPLEX); + let mut uev = create_pci_uevent(TEST_DEVNAME, TEST_PCI_RELPATH, TEST_ROOT_COMPLEX, 4); + uev.devname = partition_devname.to_string(); + uev.devpath = format!( + "{root_bus}{}/virtio4/block/{}/{}", + TEST_PCI_RELPATH, TEST_DEVNAME, partition_suffix + ); - assert!(matcher_a.is_match(&uev_a)); - assert!(matcher_b.is_match(&uev_b)); - assert!(!matcher_b.is_match(&uev_a)); - assert!(!matcher_a.is_match(&uev_b)); - - // Partition uevents must NOT match (only the whole-disk uevent should match) - let mut uev_part = uev_a.clone(); - uev_part.devname = "vda1".to_string(); - uev_part.devpath = format!("{root_bus}{relpath_a}/virtio4/block/{devname}/vda1"); - assert!(!matcher_a.is_match(&uev_part)); - - let mut uev_part91 = uev_a.clone(); - uev_part91.devname = "vda91".to_string(); - uev_part91.devpath = format!("{root_bus}{relpath_a}/virtio4/block/{devname}/vda91"); - assert!(!matcher_a.is_match(&uev_part91)); + assert!( + !matcher.is_match(&uev), + "Matcher should reject partition uevent for '{}'", + partition_devname + ); } #[cfg(target_arch = "s390x")] #[tokio::test] - async fn test_virtio_blk_ccw_matcher() { + async fn test_virtio_blk_ccw_matcher_valid_path() { let root_bus = CCW_ROOT_BUS_PATH; let subsystem = "block"; let devname = "vda"; @@ -287,55 +328,209 @@ mod tests { uev.devname = devname.to_string(); uev.devpath = format!("{root_bus}/0.0.0001/{relpath}/virtio1/{subsystem}/{devname}"); - // Valid path let device = ccw::Device::from_str(relpath).unwrap(); let matcher = VirtioBlkCCWMatcher::new(root_bus, &device); - assert!(matcher.is_match(&uev)); - // Invalid paths - uev.devpath = format!("{root_bus}/0.0.0001/0.0.0003/virtio1/{subsystem}/{devname}"); - assert!(!matcher.is_match(&uev)); + assert!( + matcher.is_match(&uev), + "Matcher should match valid CCW device path" + ); + } - uev.devpath = format!("0.0.0001/{relpath}/virtio1/{subsystem}/{devname}"); - assert!(!matcher.is_match(&uev)); + #[cfg(target_arch = "s390x")] + #[rstest] + #[case::wrong_device_id( + "/devices/css0/0.0.0001/0.0.0003/virtio1/block/vda", + "Wrong device ID should be rejected" + )] + #[case::missing_root_bus( + "0.0.0001/0.0.0002/virtio1/block/vda", + "Missing root bus path should be rejected" + )] + #[case::missing_virtio_number( + "/devices/css0/0.0.0001/0.0.0002/virtio/block/vda", + "Missing virtio number should be rejected" + )] + #[case::incomplete_path( + "/devices/css0/0.0.0001/0.0.0002/virtio1", + "Incomplete path should be rejected" + )] + #[case::invalid_subchannel_set_high( + "/devices/css0/1.0.0001/0.0.0002/virtio1/block/vda", + "Invalid subchannel set (>0) should be rejected" + )] + #[case::invalid_subchannel_set_range( + "/devices/css0/0.4.0001/0.0.0002/virtio1/block/vda", + "Invalid subchannel set (>3) should be rejected" + )] + #[case::invalid_devno_range( + "/devices/css0/0.0.10000/0.0.0002/virtio1/block/vda", + "Invalid devno (>0xffff) should be rejected" + )] + #[tokio::test] + async fn test_virtio_blk_ccw_matcher_invalid_paths( + #[case] devpath: &str, + #[case] description: &str, + ) { + let root_bus = CCW_ROOT_BUS_PATH; + let subsystem = "block"; + let devname = "vda"; + let relpath = "0.0.0002"; - uev.devpath = format!("{root_bus}/0.0.0001/{relpath}/virtio/{subsystem}/{devname}"); - assert!(!matcher.is_match(&uev)); + let mut uev = crate::uevent::Uevent::default(); + uev.action = crate::linux_abi::U_EVENT_ACTION_ADD.to_string(); + uev.subsystem = subsystem.to_string(); + uev.devname = devname.to_string(); + uev.devpath = devpath.to_string(); - uev.devpath = format!("{root_bus}/0.0.0001/{relpath}/virtio1"); - assert!(!matcher.is_match(&uev)); + let device = ccw::Device::from_str(relpath).unwrap(); + let matcher = VirtioBlkCCWMatcher::new(root_bus, &device); - uev.devpath = format!("{root_bus}/1.0.0001/{relpath}/virtio1/{subsystem}/{devname}"); - assert!(!matcher.is_match(&uev)); + assert!(!matcher.is_match(&uev), "{}", description); + } - uev.devpath = format!("{root_bus}/0.4.0001/{relpath}/virtio1/{subsystem}/{devname}"); - assert!(!matcher.is_match(&uev)); + // Helper to create a standard uevent for testing + fn create_mmio_uevent(devname: &str, mmio_id: u32, virtio_id: u32) -> crate::uevent::Uevent { + let mut uev = crate::uevent::Uevent::default(); + uev.action = crate::linux_abi::U_EVENT_ACTION_ADD.to_string(); + uev.subsystem = BLOCK.to_string(); + uev.devname = devname.to_string(); + uev.devpath = format!( + "/sys/devices/virtio-mmio-cmdline/virtio-mmio.{}/virtio{}/block/{}", + mmio_id, virtio_id, devname + ); + uev + } - uev.devpath = format!("{root_bus}/0.0.10000/{relpath}/virtio1/{subsystem}/{devname}"); - assert!(!matcher.is_match(&uev)); + #[rstest] + #[case::vda_matches_vda("vda", "vda", 0, 0, true)] + #[case::vdb_matches_vdb("vdb", "vdb", 4, 4, true)] + #[case::vda_rejects_vdb("vda", "vdb", 0, 0, false)] + #[case::vdb_rejects_vda("vdb", "vda", 4, 4, false)] + #[tokio::test] + async fn test_virtio_blk_mmio_matcher_basic_matching( + #[case] matcher_devname: &str, + #[case] uevent_devname: &str, + #[case] mmio_id: u32, + #[case] virtio_id: u32, + #[case] should_match: bool, + ) { + let matcher = VirtioBlkMmioMatcher::new(matcher_devname); + let uev = create_mmio_uevent(uevent_devname, mmio_id, virtio_id); + + assert_eq!( + matcher.is_match(&uev), + should_match, + "Matcher for '{}' should {} uevent for '{}'", + matcher_devname, + if should_match { "match" } else { "reject" }, + uevent_devname + ); + } + + #[rstest] + #[case::wrong_subsystem(test_helpers::SUBSYSTEM_NET, "Wrong subsystem should be rejected")] + #[tokio::test] + async fn test_virtio_blk_mmio_matcher_wrong_subsystem( + #[case] wrong_subsystem: &str, + #[case] description: &str, + ) { + let matcher = VirtioBlkMmioMatcher::new(TEST_DEVNAME); + let mut uev = create_mmio_uevent(TEST_DEVNAME, 0, 0); + uev.subsystem = wrong_subsystem.to_string(); + + assert!(!matcher.is_match(&uev), "{}", description); } #[tokio::test] - #[allow(clippy::redundant_clone)] - async fn test_virtio_blk_mmio_matcher() { - let devname_a = "vda"; - let devname_b = "vdb"; - let mut uev_a = crate::uevent::Uevent::default(); - uev_a.action = crate::linux_abi::U_EVENT_ACTION_ADD.to_string(); - uev_a.subsystem = BLOCK.to_string(); - uev_a.devname = devname_a.to_string(); - uev_a.devpath = - format!("/sys/devices/virtio-mmio-cmdline/virtio-mmio.0/virtio0/block/{devname_a}"); - let matcher_a = VirtioBlkMmioMatcher::new(devname_a); + async fn test_virtio_blk_mmio_matcher_empty_devname() { + let matcher = VirtioBlkMmioMatcher::new(TEST_DEVNAME); + let mut uev = create_mmio_uevent(TEST_DEVNAME, 0, 0); + uev.devname = String::new(); - let mut uev_b = uev_a.clone(); - uev_b.devpath = - format!("/sys/devices/virtio-mmio-cmdline/virtio-mmio.4/virtio4/block/{devname_b}"); - let matcher_b = VirtioBlkMmioMatcher::new(devname_b); + assert!( + !matcher.is_match(&uev), + "Matcher should reject uevent with empty devname" + ); + } - assert!(matcher_a.is_match(&uev_a)); - assert!(matcher_b.is_match(&uev_b)); - assert!(!matcher_b.is_match(&uev_a)); - assert!(!matcher_a.is_match(&uev_b)); + #[tokio::test] + async fn test_virtio_blk_mmio_matcher_wrong_device_suffix() { + let matcher = VirtioBlkMmioMatcher::new(TEST_DEVNAME); + let mut uev = create_mmio_uevent(TEST_DEVNAME, 0, 0); + // Modify to create the wrong suffix scenario + uev.devpath = + "/sys/devices/virtio-mmio-cmdline/virtio-mmio.0/virtio0/block/vdc".to_string(); + + assert!( + !matcher.is_match(&uev), + "Matcher should reject uevent with wrong device suffix" + ); + } + + #[tokio::test] + async fn test_virtio_blk_pci_matcher_correct_match() { + let matcher = VirtioBlkPciMatcher::new(TEST_PCI_RELPATH, TEST_ROOT_COMPLEX); + let uev = create_pci_uevent(TEST_DEVNAME, TEST_PCI_RELPATH, TEST_ROOT_COMPLEX, 4); + + assert!( + matcher.is_match(&uev), + "Matcher should match correctly formatted uevent" + ); + } + + #[rstest] + #[case::wrong_subsystem(test_helpers::SUBSYSTEM_NET, "Wrong subsystem should be rejected")] + #[tokio::test] + async fn test_virtio_blk_pci_matcher_wrong_subsystem( + #[case] wrong_subsystem: &str, + #[case] description: &str, + ) { + let matcher = VirtioBlkPciMatcher::new(TEST_PCI_RELPATH, TEST_ROOT_COMPLEX); + let mut uev = create_pci_uevent(TEST_DEVNAME, TEST_PCI_RELPATH, TEST_ROOT_COMPLEX, 4); + uev.subsystem = wrong_subsystem.to_string(); + + assert!(!matcher.is_match(&uev), "{}", description); + } + + #[tokio::test] + async fn test_virtio_blk_pci_matcher_empty_devname() { + let matcher = VirtioBlkPciMatcher::new(TEST_PCI_RELPATH, TEST_ROOT_COMPLEX); + let mut uev = create_pci_uevent(TEST_DEVNAME, TEST_PCI_RELPATH, TEST_ROOT_COMPLEX, 4); + uev.devname = String::new(); + + assert!( + !matcher.is_match(&uev), + "Matcher should reject uevent with empty devname" + ); + } + + #[tokio::test] + async fn test_virtio_blk_pci_matcher_missing_virtio_component() { + let root_bus = create_pci_root_bus_path(TEST_ROOT_COMPLEX); + let matcher = VirtioBlkPciMatcher::new(TEST_PCI_RELPATH, TEST_ROOT_COMPLEX); + let mut uev = create_pci_uevent(TEST_DEVNAME, TEST_PCI_RELPATH, TEST_ROOT_COMPLEX, 4); + uev.devpath = format!("{root_bus}{}/block/{}", TEST_PCI_RELPATH, TEST_DEVNAME); + + assert!( + !matcher.is_match(&uev), + "Matcher should reject path without virtio component" + ); + } + + #[rstest] + #[case::virtio4(4)] + #[case::virtio99(99)] + #[case::virtio0(0)] + #[tokio::test] + async fn test_virtio_blk_pci_matcher_accepts_any_virtio_number(#[case] virtio_id: u32) { + let matcher = VirtioBlkPciMatcher::new(TEST_PCI_RELPATH, TEST_ROOT_COMPLEX); + let uev = create_pci_uevent(TEST_DEVNAME, TEST_PCI_RELPATH, TEST_ROOT_COMPLEX, virtio_id); + + assert!( + matcher.is_match(&uev), + "Matcher should accept virtio{} number", + virtio_id + ); } } diff --git a/src/agent/src/device/mod.rs b/src/agent/src/device/mod.rs index dbe11e4a5e..6a036c9a11 100644 --- a/src/agent/src/device/mod.rs +++ b/src/agent/src/device/mod.rs @@ -1084,6 +1084,16 @@ fn expose_guest_infiniband_devices(logger: &Logger, spec: &mut Spec) -> Result<( Ok(()) } +// Test helper constants for common edge case testing +#[cfg(test)] +pub(crate) mod test_helpers { + #[cfg(not(target_arch = "s390x"))] + pub const SUBSYSTEM_BLOCK: &str = "block"; + pub const SUBSYSTEM_NET: &str = "net"; + #[cfg(not(target_arch = "s390x"))] + pub const ACTION_REMOVE: &str = "remove"; +} + #[cfg(test)] mod tests { use super::*; @@ -1094,10 +1104,225 @@ mod tests { LinuxResources, LinuxResourcesBuilder, SpecBuilder, }; use oci_spec::runtime as oci; + use rstest::rstest; use std::iter::FromIterator; use tempfile::tempdir; const VM_ROOTFS: &str = "/"; + const TEST_CONTAINER_PATH: &str = "/dev/null"; + const TEST_VM_PATH: &str = "/dev/null"; + const TEST_MAJOR: i64 = 7; + const TEST_MINOR: i64 = 2; + + // Helper function to create a test logger + fn create_test_logger() -> slog::Logger { + slog::Logger::root(slog::Discard, o!()) + } + + // Helper function to create a device update map + fn create_device_update<'a>( + container_path: &'a str, + vm_path: &str, + ) -> HashMap<&'a str, DevUpdate> { + HashMap::from_iter(vec![( + container_path, + DevUpdate::new(container_path, vm_path).unwrap(), + )]) + } + + #[rstest] + #[case::valid_zeros("0000:00:00.0", true)] + #[case::valid_normal("0000:01:02.3", true)] + #[case::valid_max("ffff:ff:1f.7", true)] + #[case::invalid_text("invalid", false)] + #[case::empty_string("", false)] + #[case::invalid_format("not_a_pci_address", false)] + #[case::random_string("random_string", false)] + #[test] + fn test_parse_pci_bdf_name(#[case] input: &str, #[case] should_parse: bool) { + let result = parse_pci_bdf_name(input); + assert_eq!( + result.is_some(), + should_parse, + "parse_pci_bdf_name('{}') should {} parse", + input, + if should_parse { + "successfully" + } else { + "fail to" + } + ); + } + + #[rstest] + #[case::normal(0, 1, 2, 3, "0000:01")] + #[case::max_values(0xffff, 0xff, 0x1f, 7, "ffff:ff")] + #[case::all_zeros(0, 0, 0, 0, "0000:00")] + #[test] + fn test_bus_of_addr( + #[case] domain: u16, + #[case] bus: u8, + #[case] slot: u8, + #[case] func: u8, + #[case] expected: &str, + ) { + let addr = pci::Address::new(domain, bus, pci::SlotFn::new(slot, func).unwrap()); + assert_eq!(bus_of_addr(&addr).unwrap(), expected); + } + + #[rstest] + #[case::single_bus( + vec![(0, 1, 0, 0), (0, 1, 1, 0), (0, 1, 2, 0)], + Some("0000:01") + )] + #[case::multiple_buses( + vec![(0, 1, 0, 0), (0, 2, 0, 0)], + None + )] + #[case::empty_list(vec![], None)] + #[test] + fn test_unique_bus_from_pci_addresses( + #[case] addr_tuples: Vec<(u16, u8, u8, u8)>, + #[case] expected: Option<&str>, + ) { + let addrs: Vec = addr_tuples + .into_iter() + .map(|(d, b, s, f)| pci::Address::new(d, b, pci::SlotFn::new(s, f).unwrap())) + .collect(); + + match expected { + Some(bus) => assert_eq!(unique_bus_from_pci_addresses(&addrs).unwrap(), bus), + None => assert!(unique_bus_from_pci_addresses(&addrs).is_err()), + } + } + + #[test] + fn test_read_single_bus_from_pci_bus_dir() { + let testdir = tempdir().expect("failed to create tmpdir"); + let bridgebuspath = testdir.path().join("pci_bus"); + fs::create_dir_all(&bridgebuspath).unwrap(); + + let bus_dir = bridgebuspath.join("0000:01"); + fs::create_dir(&bus_dir).unwrap(); + assert_eq!( + read_single_bus_from_pci_bus_dir(&bridgebuspath).unwrap(), + "0000:01" + ); + + let bus_dir2 = bridgebuspath.join("0000:02"); + fs::create_dir(&bus_dir2).unwrap(); + assert!(read_single_bus_from_pci_bus_dir(&bridgebuspath).is_err()); + + let empty_dir = testdir.path().join("empty_pci_bus"); + fs::create_dir_all(&empty_dir).unwrap(); + assert!(read_single_bus_from_pci_bus_dir(&empty_dir).is_err()); + } + + #[test] + fn test_infer_bus_from_child_devices() { + let testdir = tempdir().expect("failed to create tmpdir"); + let devpath = testdir.path(); + + let dev1 = devpath.join("0000:01:00.0"); + let dev2 = devpath.join("0000:01:01.0"); + let dev3 = devpath.join("0000:01:02.0"); + fs::create_dir(&dev1).unwrap(); + fs::create_dir(&dev2).unwrap(); + fs::create_dir(&dev3).unwrap(); + + assert_eq!( + infer_bus_from_child_devices(&devpath.to_path_buf()).unwrap(), + "0000:01" + ); + + let dev4 = devpath.join("0000:02:00.0"); + fs::create_dir(&dev4).unwrap(); + assert!(infer_bus_from_child_devices(&devpath.to_path_buf()).is_err()); + + let empty_dir = testdir.path().join("no_devices"); + fs::create_dir_all(&empty_dir).unwrap(); + assert!(infer_bus_from_child_devices(&empty_dir).is_err()); + + let non_pci_dir = testdir.path().join("with_non_pci"); + fs::create_dir_all(&non_pci_dir).unwrap(); + let pci_dev = non_pci_dir.join("0000:03:00.0"); + let non_pci = non_pci_dir.join("not_a_pci_device"); + fs::create_dir(&pci_dev).unwrap(); + fs::create_dir(&non_pci).unwrap(); + assert_eq!( + infer_bus_from_child_devices(&non_pci_dir).unwrap(), + "0000:03" + ); + } + + #[test] + fn test_online_device() { + let testdir = tempdir().expect("failed to create tmpdir"); + let device_path = testdir.path().join("online"); + + online_device(device_path.to_str().unwrap()).unwrap(); + assert_eq!(fs::read_to_string(&device_path).unwrap(), "1"); + + assert!(online_device("/nonexistent/path/to/device").is_err()); + } + + #[test] + fn test_dev_update_new() { + let result = DevUpdate::new("/dev/null", "/dev/null"); + assert!(result.is_ok()); + + let update = result.unwrap(); + assert_eq!(update.final_path, Some("/dev/null".to_string())); + + let result2 = DevUpdate::new("/dev/null", "/dev/custom"); + assert!(result2.is_ok()); + let update2 = result2.unwrap(); + assert_eq!(update2.final_path, Some("/dev/custom".to_string())); + + let result_invalid = DevUpdate::new("/nonexistent/device", "/dev/null"); + assert!(result_invalid.is_err()); + } + + #[rstest] + #[case::char_device("/dev/null", true, "c", true)] + #[case::block_device("/", false, "b", true)] + #[case::nonexistent("/nonexistent/path", true, "", false)] + #[case::empty_path("", true, "", false)] + #[test] + fn test_device_info_new( + #[case] path: &str, + #[case] is_char: bool, + #[case] expected_type: &str, + #[case] should_succeed: bool, + ) { + let result = DeviceInfo::new(path, is_char); + + if should_succeed { + let info = result.unwrap(); + assert_eq!(info.cgroup_type, expected_type); + assert!(info.guest_major >= 0); + assert!(info.guest_minor >= 0); + } else { + assert!(result.is_err()); + } + } + + #[test] + fn test_spec_update_conversions() { + let info = DeviceInfo::new("/dev/null", true).unwrap(); + let spec_update: SpecUpdate = info.into(); + assert!(spec_update.dev.is_some()); + assert_eq!(spec_update.pci.len(), 0); + + let dev_update = DevUpdate::new("/dev/null", "/dev/null").unwrap(); + let spec_update2: SpecUpdate = dev_update.into(); + assert!(spec_update2.dev.is_some()); + assert_eq!(spec_update2.pci.len(), 0); + + let spec_update3 = SpecUpdate::default(); + assert!(spec_update3.dev.is_none()); + assert_eq!(spec_update3.pci.len(), 0); + } #[test] fn test_cdi_devices_from_visible_devices() { @@ -1173,7 +1398,7 @@ mod tests { #[test] fn test_update_device_cgroup() { - let logger = slog::Logger::root(slog::Discard, o!()); + let logger = create_test_logger(); let mut linux = Linux::default(); linux.set_resources(Some(LinuxResources::default())); let mut spec = SpecBuilder::default().linux(linux).build().unwrap(); @@ -1204,8 +1429,7 @@ mod tests { #[test] fn test_update_spec_devices() { - let logger = slog::Logger::root(slog::Discard, o!()); - let (major, minor) = (7, 2); + let logger = create_test_logger(); let mut spec = Spec::default(); // vm_path empty @@ -1213,15 +1437,10 @@ mod tests { assert!(update.is_err()); // linux is empty - let container_path = "/dev/null"; - let vm_path = "/dev/null"; let res = update_spec_devices( &logger, &mut spec, - HashMap::from_iter(vec![( - container_path, - DeviceInfo::new(vm_path, true).unwrap().into(), - )]), + create_device_update(TEST_CONTAINER_PATH, TEST_VM_PATH), ); assert!(res.is_err()); @@ -1231,10 +1450,7 @@ mod tests { let res = update_spec_devices( &logger, &mut spec, - HashMap::from_iter(vec![( - container_path, - DeviceInfo::new(vm_path, true).unwrap().into(), - )]), + create_device_update(TEST_CONTAINER_PATH, TEST_VM_PATH), ); assert!(res.is_err()); @@ -1243,8 +1459,8 @@ mod tests { .unwrap() .set_devices(Some(vec![LinuxDeviceBuilder::default() .path(PathBuf::from("/dev/null2")) - .major(major) - .minor(minor) + .major(TEST_MAJOR) + .minor(TEST_MINOR) .build() .unwrap()])); @@ -1252,16 +1468,13 @@ mod tests { let res = update_spec_devices( &logger, &mut spec, - HashMap::from_iter(vec![( - container_path, - DeviceInfo::new(vm_path, true).unwrap().into(), - )]), + create_device_update(TEST_CONTAINER_PATH, TEST_VM_PATH), ); assert!( res.is_err(), "container_path={:?} vm_path={:?} spec={:?}", - container_path, - vm_path, + TEST_CONTAINER_PATH, + TEST_VM_PATH, spec ); @@ -1271,16 +1484,13 @@ mod tests { .devices_mut() .as_mut() .unwrap()[0] - .set_path(PathBuf::from(container_path)); + .set_path(PathBuf::from(TEST_CONTAINER_PATH)); // spec.linux.resources is empty let res = update_spec_devices( &logger, &mut spec, - HashMap::from_iter(vec![( - container_path, - DeviceInfo::new(vm_path, true).unwrap().into(), - )]), + create_device_update(TEST_CONTAINER_PATH, TEST_VM_PATH), ); assert!(res.is_ok()); @@ -1289,17 +1499,17 @@ mod tests { .as_mut() .unwrap() .set_devices(Some(vec![LinuxDeviceBuilder::default() - .path(PathBuf::from(container_path)) - .major(major) - .minor(minor) + .path(PathBuf::from(TEST_CONTAINER_PATH)) + .major(TEST_MAJOR) + .minor(TEST_MINOR) .build() .unwrap()])); spec.linux_mut().as_mut().unwrap().set_resources(Some( oci::LinuxResourcesBuilder::default() .devices(vec![LinuxDeviceCgroupBuilder::default() - .major(major) - .minor(minor) + .major(TEST_MAJOR) + .minor(TEST_MINOR) .build() .unwrap()]) .build() @@ -1309,17 +1519,14 @@ mod tests { let res = update_spec_devices( &logger, &mut spec, - HashMap::from_iter(vec![( - container_path, - DeviceInfo::new(vm_path, true).unwrap().into(), - )]), + create_device_update(TEST_CONTAINER_PATH, TEST_VM_PATH), ); assert!(res.is_ok()); } #[test] fn test_update_spec_devices_guest_host_conflict() { - let logger = slog::Logger::root(slog::Discard, o!()); + let logger = create_test_logger(); let null_rdev = fs::metadata("/dev/null").unwrap().rdev(); let zero_rdev = fs::metadata("/dev/zero").unwrap().rdev(); @@ -1444,7 +1651,7 @@ mod tests { #[test] fn test_update_spec_devices_char_block_conflict() { - let logger = slog::Logger::root(slog::Discard, o!()); + let logger = create_test_logger(); let null_rdev = fs::metadata("/dev/null").unwrap().rdev(); @@ -1544,7 +1751,7 @@ mod tests { #[test] fn test_update_spec_devices_final_path() { - let logger = slog::Logger::root(slog::Discard, o!()); + let logger = create_test_logger(); let null_rdev = fs::metadata("/dev/null").unwrap().rdev(); let guest_major = stat::major(null_rdev) as i64; @@ -1831,7 +2038,7 @@ mod tests { uev.devpath = devpath.clone(); uev.devname = devname.to_string(); - let logger = slog::Logger::root(slog::Discard, o!()); + let logger = create_test_logger(); let sandbox = Arc::new(Mutex::new(Sandbox::new(&logger).unwrap())); let mut sb = sandbox.lock().await; @@ -1855,7 +2062,7 @@ mod tests { #[tokio::test] async fn test_handle_cdi_devices() { - let logger = slog::Logger::root(slog::Discard, o!()); + let logger = create_test_logger(); let mut spec = Spec::default(); let mut annotations = HashMap::new(); diff --git a/src/agent/src/device/network_device_handler.rs b/src/agent/src/device/network_device_handler.rs index 54504d4877..83d2151a5d 100644 --- a/src/agent/src/device/network_device_handler.rs +++ b/src/agent/src/device/network_device_handler.rs @@ -127,70 +127,186 @@ impl UeventMatcher for NetCcwMatcher { #[cfg(test)] mod tests { use super::*; + #[cfg(not(target_arch = "s390x"))] + use crate::device::test_helpers; + use rstest::rstest; + + #[cfg(not(target_arch = "s390x"))] + // Helper to create a network PCI uevent + fn create_net_pci_uevent( + relpath: &str, + root_complex: &str, + interface: &str, + ) -> crate::uevent::Uevent { + let root_bus = create_pci_root_bus_path(root_complex); + let mut uev = crate::uevent::Uevent::default(); + uev.action = crate::linux_abi::U_EVENT_ACTION_ADD.to_string(); + uev.devpath = format!("{root_bus}{relpath}"); + uev.subsystem = String::from("net"); + uev.interface = String::from(interface); + uev + } + + #[cfg(not(target_arch = "s390x"))] + #[rstest] + #[case::matcher_a_matches_uev_a( + "/0000:00:02.0/0000:01:01.0", + "/0000:00:02.0/0000:01:01.0", + true + )] + #[case::matcher_b_matches_uev_b( + "/0000:00:02.0/0000:01:02.0", + "/0000:00:02.0/0000:01:02.0", + true + )] + #[case::matcher_a_rejects_uev_b( + "/0000:00:02.0/0000:01:01.0", + "/0000:00:02.0/0000:01:02.0", + false + )] + #[case::matcher_b_rejects_uev_a( + "/0000:00:02.0/0000:01:02.0", + "/0000:00:02.0/0000:01:01.0", + false + )] + #[tokio::test] + async fn test_net_pci_matcher_basic_matching( + #[case] matcher_relpath: &str, + #[case] uevent_relpath: &str, + #[case] should_match: bool, + ) { + let matcher = NetPciMatcher::new(matcher_relpath, "00"); + let uev = create_net_pci_uevent(uevent_relpath, "00", "eth0"); + + assert_eq!( + matcher.is_match(&uev), + should_match, + "Matcher for '{}' should {} uevent for '{}'", + matcher_relpath, + if should_match { "match" } else { "reject" }, + uevent_relpath + ); + } #[cfg(not(target_arch = "s390x"))] #[tokio::test] - #[allow(clippy::redundant_clone)] - async fn test_net_pci_matcher() { + async fn test_net_pci_matcher_with_net_substring() { + let relpath = "/0000:00:02.0/0000:01:03.0"; let root_bus = create_pci_root_bus_path("00"); - let relpath_a = "/0000:00:02.0/0000:01:01.0"; - - let mut uev_a = crate::uevent::Uevent::default(); - uev_a.action = crate::linux_abi::U_EVENT_ACTION_ADD.to_string(); - uev_a.devpath = format!("{root_bus}{relpath_a}"); - uev_a.subsystem = String::from("net"); - uev_a.interface = String::from("eth0"); - let matcher_a = NetPciMatcher::new(relpath_a, "00"); - println!("Matcher a : {}", matcher_a.devpath); - - let relpath_b = "/0000:00:02.0/0000:01:02.0"; - let mut uev_b = uev_a.clone(); - uev_b.devpath = format!("{root_bus}{relpath_b}"); - let matcher_b = NetPciMatcher::new(relpath_b, "00"); - - assert!(matcher_a.is_match(&uev_a)); - assert!(matcher_b.is_match(&uev_b)); - assert!(!matcher_b.is_match(&uev_a)); - assert!(!matcher_a.is_match(&uev_b)); - - let relpath_c = "/0000:00:02.0/0000:01:03.0"; let net_substr = "/net/eth0"; - let mut uev_c = uev_a.clone(); - uev_c.devpath = format!("{root_bus}{relpath_c}{net_substr}"); - let matcher_c = NetPciMatcher::new(relpath_c, "00"); - assert!(matcher_c.is_match(&uev_c)); - assert!(!matcher_a.is_match(&uev_c)); - assert!(!matcher_b.is_match(&uev_c)); + let matcher = NetPciMatcher::new(relpath, "00"); + let mut uev = create_net_pci_uevent(relpath, "00", "eth0"); + uev.devpath = format!("{root_bus}{relpath}{net_substr}"); + + assert!( + matcher.is_match(&uev), + "Matcher should match uevent with /net/ substring in devpath" + ); + } + + #[cfg(not(target_arch = "s390x"))] + #[rstest] + #[case::wrong_subsystem(test_helpers::SUBSYSTEM_BLOCK, "Wrong subsystem should be rejected")] + #[tokio::test] + async fn test_net_pci_matcher_wrong_subsystem( + #[case] wrong_subsystem: &str, + #[case] description: &str, + ) { + let relpath = "/0000:00:02.0/0000:01:01.0"; + let matcher = NetPciMatcher::new(relpath, "00"); + let mut uev = create_net_pci_uevent(relpath, "00", "eth0"); + uev.subsystem = wrong_subsystem.to_string(); + + assert!(!matcher.is_match(&uev), "{}", description); + } + + #[cfg(not(target_arch = "s390x"))] + #[rstest] + #[case::wrong_action(test_helpers::ACTION_REMOVE, "Wrong action should be rejected")] + #[tokio::test] + async fn test_net_pci_matcher_wrong_action( + #[case] wrong_action: &str, + #[case] description: &str, + ) { + let relpath = "/0000:00:02.0/0000:01:01.0"; + let matcher = NetPciMatcher::new(relpath, "00"); + let mut uev = create_net_pci_uevent(relpath, "00", "eth0"); + uev.action = wrong_action.to_string(); + + assert!(!matcher.is_match(&uev), "{}", description); + } + + #[cfg(not(target_arch = "s390x"))] + #[tokio::test] + async fn test_net_pci_matcher_empty_interface() { + let relpath = "/0000:00:02.0/0000:01:01.0"; + let matcher = NetPciMatcher::new(relpath, "00"); + let mut uev = create_net_pci_uevent(relpath, "00", "eth0"); + uev.interface = String::new(); + + assert!( + !matcher.is_match(&uev), + "Matcher should reject uevent with empty interface" + ); + } + + #[cfg(not(target_arch = "s390x"))] + #[tokio::test] + async fn test_net_pci_matcher_wrong_devpath() { + let relpath = "/0000:00:02.0/0000:01:01.0"; + let root_bus = create_pci_root_bus_path("00"); + let matcher = NetPciMatcher::new(relpath, "00"); + let mut uev = create_net_pci_uevent(relpath, "00", "eth0"); + uev.devpath = format!("{}/0000:00:03.0", root_bus); + + assert!( + !matcher.is_match(&uev), + "Matcher should reject uevent with wrong devpath" + ); } #[cfg(target_arch = "s390x")] + // Helper to create a network CCW uevent + fn create_net_ccw_uevent(device: &ccw::Device, interface: &str) -> crate::uevent::Uevent { + let mut uev = crate::uevent::Uevent::default(); + uev.action = crate::linux_abi::U_EVENT_ACTION_ADD.to_string(); + uev.subsystem = String::from("net"); + uev.interface = String::from(interface); + uev.devpath = format!( + "{}/0.0.0001/{}/virtio1/{}/{}", + CCW_ROOT_BUS_PATH, device, uev.subsystem, uev.interface + ); + uev + } + + #[cfg(target_arch = "s390x")] + #[rstest] + #[case::dev_a_matches_uev_a(0, 1, 0, 1, true)] + #[case::dev_b_matches_uev_b(1, 2, 1, 2, true)] + #[case::dev_a_rejects_uev_b(0, 1, 1, 2, false)] + #[case::dev_b_rejects_uev_a(1, 2, 0, 1, false)] #[tokio::test] - async fn test_net_ccw_matcher() { - let dev_a = ccw::Device::new(0, 1).unwrap(); - let dev_b = ccw::Device::new(1, 2).unwrap(); + async fn test_net_ccw_matcher_basic_matching( + #[case] matcher_ssid: u8, + #[case] matcher_devno: u16, + #[case] uevent_ssid: u8, + #[case] uevent_devno: u16, + #[case] should_match: bool, + ) { + let matcher_dev = ccw::Device::new(matcher_ssid, matcher_devno).unwrap(); + let uevent_dev = ccw::Device::new(uevent_ssid, uevent_devno).unwrap(); - let mut uev_a = crate::uevent::Uevent::default(); - uev_a.action = crate::linux_abi::U_EVENT_ACTION_ADD.to_string(); - uev_a.subsystem = String::from("net"); - uev_a.interface = String::from("eth0"); - uev_a.devpath = format!( - "{}/0.0.0001/{}/virtio1/{}/{}", - CCW_ROOT_BUS_PATH, dev_a, uev_a.subsystem, uev_a.interface + let matcher = NetCcwMatcher::new(CCW_ROOT_BUS_PATH, &matcher_dev); + let uev = create_net_ccw_uevent(&uevent_dev, "eth0"); + + assert_eq!( + matcher.is_match(&uev), + should_match, + "Matcher for device {} should {} uevent for device {}", + matcher_dev, + if should_match { "match" } else { "reject" }, + uevent_dev ); - - let mut uev_b = uev_a.clone(); - uev_b.devpath = format!( - "{}/0.0.0001/{}/virtio1/{}/{}", - CCW_ROOT_BUS_PATH, dev_b, uev_b.subsystem, uev_b.interface - ); - - let matcher_a = NetCcwMatcher::new(CCW_ROOT_BUS_PATH, &dev_a); - let matcher_b = NetCcwMatcher::new(CCW_ROOT_BUS_PATH, &dev_b); - - assert!(matcher_a.is_match(&uev_a)); - assert!(matcher_b.is_match(&uev_b)); - assert!(!matcher_b.is_match(&uev_a)); - assert!(!matcher_a.is_match(&uev_b)); } } diff --git a/src/agent/src/device/nvdimm_device_handler.rs b/src/agent/src/device/nvdimm_device_handler.rs index d5a6cdca87..008dc74eef 100644 --- a/src/agent/src/device/nvdimm_device_handler.rs +++ b/src/agent/src/device/nvdimm_device_handler.rs @@ -81,3 +81,106 @@ impl UeventMatcher for PmemBlockMatcher { && !uev.devname.is_empty() } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::device::test_helpers; + use rstest::rstest; + + // Helper to create a PMEM uevent + fn create_pmem_uevent(devname: &str, region: u32) -> crate::uevent::Uevent { + let mut uev = crate::uevent::Uevent::default(); + uev.action = crate::linux_abi::U_EVENT_ACTION_ADD.to_string(); + uev.subsystem = BLOCK.to_string(); + uev.devname = devname.to_string(); + uev.devpath = format!( + "{}/LNXSYSTM:00/LNXSYBUS:00/ACPI0012:00/ndbus0/region{}/btt{}.0/block/{}", + ACPI_DEV_PATH, region, region, devname + ); + uev + } + + #[rstest] + #[case::pmem0_matches_pmem0("pmem0", "pmem0", 0, 0, true)] + #[case::pmem1_matches_pmem1("pmem1", "pmem1", 1, 1, true)] + #[case::pmem0_rejects_pmem1("pmem0", "pmem1", 0, 1, false)] + #[case::pmem1_rejects_pmem0("pmem1", "pmem0", 1, 0, false)] + #[tokio::test] + async fn test_pmem_block_matcher_basic_matching( + #[case] matcher_devname: &str, + #[case] uevent_devname: &str, + #[case] _matcher_region: u32, + #[case] uevent_region: u32, + #[case] should_match: bool, + ) { + let matcher = PmemBlockMatcher::new(matcher_devname); + let uev = create_pmem_uevent(uevent_devname, uevent_region); + + assert_eq!( + matcher.is_match(&uev), + should_match, + "Matcher for '{}' should {} uevent for '{}'", + matcher_devname, + if should_match { "match" } else { "reject" }, + uevent_devname + ); + } + + #[rstest] + #[case::wrong_subsystem(test_helpers::SUBSYSTEM_NET, "Wrong subsystem should be rejected")] + #[tokio::test] + async fn test_pmem_block_matcher_wrong_subsystem( + #[case] wrong_subsystem: &str, + #[case] description: &str, + ) { + let devname = "pmem0"; + let matcher = PmemBlockMatcher::new(devname); + let mut uev = create_pmem_uevent(devname, 0); + uev.subsystem = wrong_subsystem.to_string(); + + assert!(!matcher.is_match(&uev), "{}", description); + } + + #[tokio::test] + async fn test_pmem_block_matcher_empty_devname() { + let devname = "pmem0"; + let matcher = PmemBlockMatcher::new(devname); + let mut uev = create_pmem_uevent(devname, 0); + uev.devname = String::new(); + + assert!( + !matcher.is_match(&uev), + "Matcher should reject uevent with empty devname" + ); + } + + #[tokio::test] + async fn test_pmem_block_matcher_wrong_prefix() { + let devname = "pmem0"; + let matcher = PmemBlockMatcher::new(devname); + let mut uev = create_pmem_uevent(devname, 0); + uev.devpath = format!("/devices/pci0000:00/block/{}", devname); + + assert!( + !matcher.is_match(&uev), + "Matcher should reject devpath not starting with ACPI_DEV_PATH" + ); + } + + #[tokio::test] + async fn test_pmem_block_matcher_wrong_suffix() { + let devname = "pmem0"; + let matcher = PmemBlockMatcher::new(devname); + let mut uev = create_pmem_uevent(devname, 0); + uev.devpath = format!( + "{}/LNXSYSTM:00/LNXSYBUS:00/ACPI0012:00/ndbus0/region0/btt0.0/block/pmem2", + ACPI_DEV_PATH + ); + + assert!( + !matcher.is_match(&uev), + "Matcher should reject devpath with wrong device suffix" + ); + } +} diff --git a/src/agent/src/device/scsi_device_handler.rs b/src/agent/src/device/scsi_device_handler.rs index 7a9e68acae..941bdd15aa 100644 --- a/src/agent/src/device/scsi_device_handler.rs +++ b/src/agent/src/device/scsi_device_handler.rs @@ -130,7 +130,9 @@ fn scan_scsi_bus(scsi_addr: &str) -> Result<()> { #[cfg(test)] mod tests { use super::*; + use crate::device::test_helpers; use crate::linux_abi::U_EVENT_ACTION_ADD; + use rstest::rstest; fn make_scsi_block_uevent(addr: &str, devname: &str, devpath_suffix: &str) -> Uevent { let root_bus = create_pci_root_bus_path("00"); @@ -148,48 +150,125 @@ mod tests { uev } + #[rstest] + #[case::addr_a_matches_uev_a("0:0", "sda", "0:0", "sda", true)] + #[case::addr_b_matches_uev_b("2:0", "sdb", "2:0", "sdb", true)] + #[case::addr_a_rejects_uev_b("0:0", "sda", "2:0", "sdb", false)] + #[case::addr_b_rejects_uev_a("2:0", "sdb", "0:0", "sda", false)] #[tokio::test] - #[allow(clippy::redundant_clone)] - async fn test_scsi_block_matcher() { - let root_bus = create_pci_root_bus_path("00"); - let devname = "sda"; + async fn test_scsi_block_matcher_basic_matching( + #[case] matcher_addr: &str, + #[case] _matcher_devname: &str, + #[case] uevent_addr: &str, + #[case] uevent_devname: &str, + #[case] should_match: bool, + ) { + let matcher = ScsiBlockMatcher::new(matcher_addr); + let uev = make_scsi_block_uevent(uevent_addr, uevent_devname, uevent_devname); - let mut uev_a = crate::uevent::Uevent::default(); - let addr_a = "0:0"; - uev_a.action = crate::linux_abi::U_EVENT_ACTION_ADD.to_string(); - uev_a.subsystem = BLOCK.to_string(); - uev_a.devname = devname.to_string(); - uev_a.devpath = - format!("{root_bus}/0000:00:00.0/virtio0/host0/target0:0:0/0:0:{addr_a}/block/sda"); - let matcher_a = ScsiBlockMatcher::new(addr_a); + assert_eq!( + matcher.is_match(&uev), + should_match, + "Matcher for SCSI addr '{}' should {} uevent for addr '{}'", + matcher_addr, + if should_match { "match" } else { "reject" }, + uevent_addr + ); + } - let mut uev_b = uev_a.clone(); - let addr_b = "2:0"; - uev_b.devname = "sdb".to_string(); - uev_b.devpath = - format!("{root_bus}/0000:00:00.0/virtio0/host0/target0:0:2/0:0:{addr_b}/block/sdb"); - let matcher_b = ScsiBlockMatcher::new(addr_b); + #[rstest] + #[case::wrong_subsystem(test_helpers::SUBSYSTEM_NET, "Wrong subsystem should be rejected")] + #[tokio::test] + async fn test_scsi_block_matcher_wrong_subsystem( + #[case] wrong_subsystem: &str, + #[case] description: &str, + ) { + let addr = "0:0"; + let matcher = ScsiBlockMatcher::new(addr); + let mut uev = make_scsi_block_uevent(addr, "sda", "sda"); + uev.subsystem = wrong_subsystem.to_string(); - assert!(matcher_a.is_match(&uev_a)); - assert!(matcher_b.is_match(&uev_b)); - assert!(!matcher_b.is_match(&uev_a)); - assert!(!matcher_a.is_match(&uev_b)); + assert!(!matcher.is_match(&uev), "{}", description); } #[tokio::test] - async fn test_scsi_block_matcher_rejects_partitions() { - let uev_whole = make_scsi_block_uevent("0:0", "sda", "sda"); - let uev_part = make_scsi_block_uevent("0:0", "sda1", "sda/sda1"); - - let matcher = ScsiBlockMatcher::new("0:0"); + async fn test_scsi_block_matcher_empty_devname() { + let addr = "0:0"; + let matcher = ScsiBlockMatcher::new(addr); + let mut uev = make_scsi_block_uevent(addr, "sda", "sda"); + uev.devname = String::new(); assert!( - matcher.is_match(&uev_whole), - "whole disk uevent should match" + !matcher.is_match(&uev), + "Matcher should reject uevent with empty devname" ); + } + + #[tokio::test] + async fn test_scsi_block_matcher_wrong_path() { + let root_bus = create_pci_root_bus_path("00"); + let addr = "0:0"; + let matcher = ScsiBlockMatcher::new(addr); + let mut uev = make_scsi_block_uevent(addr, "sda", "sda"); + uev.devpath = + format!("{root_bus}/0000:00:00.0/virtio0/host0/target0:0:1/0:0:1:0/block/sdc"); + assert!( - !matcher.is_match(&uev_part), - "partition uevent should not match" + !matcher.is_match(&uev), + "Matcher should reject devpath not containing the SCSI address search string" + ); + } + + #[rstest] + #[case::addr_1_1_matches("1:1", "sdc", true)] + #[case::addr_0_0_rejects("0:0", "sda", false)] + #[tokio::test] + async fn test_scsi_block_matcher_different_addresses( + #[case] test_addr: &str, + #[case] test_devname: &str, + #[case] should_match_1_1: bool, + ) { + let root_bus = create_pci_root_bus_path("00"); + let matcher = ScsiBlockMatcher::new("1:1"); + let mut uev = make_scsi_block_uevent(test_addr, test_devname, test_devname); + + // Adjust devpath for addr 1:1 + if test_addr == "1:1" { + uev.devpath = format!("{root_bus}/0000:00:00.0/virtio0/host0/target0:0:1/0:0:{test_addr}/block/{test_devname}"); + } + + assert_eq!( + matcher.is_match(&uev), + should_match_1_1, + "Matcher for '1:1' should {} uevent for addr '{}'", + if should_match_1_1 { "match" } else { "reject" }, + test_addr + ); + } + + #[rstest] + #[case::whole_disk("0:0", "sda", "sda", true)] + #[case::partition("0:0", "sda1", "sda/sda1", false)] + #[tokio::test] + async fn test_scsi_block_matcher_rejects_partitions( + #[case] addr: &str, + #[case] devname: &str, + #[case] devpath_suffix: &str, + #[case] should_match: bool, + ) { + let matcher = ScsiBlockMatcher::new(addr); + let uev = make_scsi_block_uevent(addr, devname, devpath_suffix); + + assert_eq!( + matcher.is_match(&uev), + should_match, + "{} uevent should {} match", + if devpath_suffix.contains('/') { + "partition" + } else { + "whole disk" + }, + if should_match { "" } else { "not" } ); } } diff --git a/src/agent/src/device/vfio_device_handler.rs b/src/agent/src/device/vfio_device_handler.rs index 8873dbd468..f6d852f047 100644 --- a/src/agent/src/device/vfio_device_handler.rs +++ b/src/agent/src/device/vfio_device_handler.rs @@ -374,7 +374,7 @@ pub async fn wait_for_pci_device( } // Represents an IOMMU group -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Copy, Clone, Debug, PartialEq, Eq)] pub struct IommuGroup(u32); impl fmt::Display for IommuGroup { @@ -469,37 +469,75 @@ where #[cfg(test)] mod tests { use super::*; + use rstest::rstest; use tempfile::tempdir; - #[tokio::test] - #[allow(clippy::redundant_clone)] - async fn test_vfio_matcher() { - let grpa = IommuGroup(1); - let grpb = IommuGroup(22); - - let mut uev_a = crate::uevent::Uevent::default(); - uev_a.action = crate::linux_abi::U_EVENT_ACTION_ADD.to_string(); - uev_a.devname = format!("vfio/{grpa}"); - uev_a.devpath = format!("/devices/virtual/vfio/{grpa}"); - let matcher_a = VfioMatcher::new(grpa); - - let mut uev_b = uev_a.clone(); - uev_b.devpath = format!("/devices/virtual/vfio/{grpb}"); - let matcher_b = VfioMatcher::new(grpb); - - assert!(matcher_a.is_match(&uev_a)); - assert!(matcher_b.is_match(&uev_b)); - assert!(!matcher_b.is_match(&uev_a)); - assert!(!matcher_a.is_match(&uev_b)); + // Helper to create a VFIO uevent for testing + fn create_vfio_uevent(group: IommuGroup) -> Uevent { + let mut uev = Uevent::default(); + uev.action = crate::linux_abi::U_EVENT_ACTION_ADD.to_string(); + uev.devname = format!("vfio/{group}"); + uev.devpath = format!("/devices/virtual/vfio/{group}"); + uev } - #[test] - fn test_split_vfio_pci_option() { + + #[rstest] + #[case::group_1_matches(IommuGroup(1), IommuGroup(1), true)] + #[case::group_22_matches(IommuGroup(22), IommuGroup(22), true)] + #[case::group_1_rejects_22(IommuGroup(1), IommuGroup(22), false)] + #[case::group_22_rejects_1(IommuGroup(22), IommuGroup(1), false)] + #[tokio::test] + async fn test_vfio_matcher_basic_matching( + #[case] matcher_group: IommuGroup, + #[case] uevent_group: IommuGroup, + #[case] should_match: bool, + ) { + let matcher = VfioMatcher::new(matcher_group); + let uev = create_vfio_uevent(uevent_group); + assert_eq!( - split_vfio_pci_option("0000:01:00.0=02/01"), - Some(("0000:01:00.0", "02/01")) + matcher.is_match(&uev), + should_match, + "Matcher for group {} should {} uevent for group {}", + matcher_group, + if should_match { "match" } else { "reject" }, + uevent_group ); - assert_eq!(split_vfio_pci_option("0000:01:00.0=02/01=rubbish"), None); - assert_eq!(split_vfio_pci_option("0000:01:00.0"), None); + } + + #[tokio::test] + async fn test_vfio_matcher_wrong_devpath() { + let group = IommuGroup(1); + let matcher = VfioMatcher::new(group); + let mut uev = create_vfio_uevent(group); + uev.devpath = "/devices/virtual/vfio/wrong".to_string(); + + assert!( + !matcher.is_match(&uev), + "Matcher should reject devpath with wrong IOMMU group" + ); + } + + #[tokio::test] + async fn test_vfio_matcher_partial_match() { + let group = IommuGroup(1); + let matcher = VfioMatcher::new(group); + let mut uev = create_vfio_uevent(group); + uev.devpath = "/devices/virtual/vfio/1extra".to_string(); + + assert!( + !matcher.is_match(&uev), + "Matcher should reject devpath with extra characters after group number" + ); + } + + #[rstest] + #[case::valid_option("0000:01:00.0=02/01", Some(("0000:01:00.0", "02/01")))] + #[case::too_many_equals("0000:01:00.0=02/01=rubbish", None)] + #[case::missing_equals("0000:01:00.0", None)] + #[test] + fn test_split_vfio_pci_option(#[case] input: &str, #[case] expected: Option<(&str, &str)>) { + assert_eq!(split_vfio_pci_option(input), expected); } #[test] @@ -580,29 +618,124 @@ mod tests { assert!(pci_iommu_group(&syspci, dev2).is_err()); } + // Helper to create a PCI uevent for testing + fn create_pci_uevent(relpath: &str, root_complex: &str) -> Uevent { + let root_bus = create_pci_root_bus_path(root_complex); + let mut uev = Uevent::default(); + uev.action = crate::linux_abi::U_EVENT_ACTION_ADD.to_string(); + uev.devpath = format!("{root_bus}{relpath}"); + uev + } + + #[rstest] + #[case::relpath_a_matches("/0000:00:06.0", "/0000:00:06.0", "00", true)] + #[case::relpath_b_matches( + "/0000:00:06.0/0000:02:00.0", + "/0000:00:06.0/0000:02:00.0", + "00", + true + )] + #[case::relpath_a_rejects_b("/0000:00:06.0", "/0000:00:06.0/0000:02:00.0", "00", false)] + #[case::relpath_b_rejects_a("/0000:00:06.0/0000:02:00.0", "/0000:00:06.0", "00", false)] + #[test] + fn test_pci_matcher_basic_matching( + #[case] matcher_relpath: &str, + #[case] uevent_relpath: &str, + #[case] root_complex: &str, + #[case] should_match: bool, + ) { + let matcher = PciMatcher::new(matcher_relpath, root_complex).unwrap(); + let uev = create_pci_uevent(uevent_relpath, root_complex); + + assert_eq!( + matcher.is_match(&uev), + should_match, + "Matcher for '{}' should {} uevent for '{}'", + matcher_relpath, + if should_match { "match" } else { "reject" }, + uevent_relpath + ); + } + + #[test] + fn test_pci_matcher_different_root_complex() { + let relpath = "/0000:00:06.0"; + let matcher = PciMatcher::new(relpath, "00").unwrap(); + let uev = create_pci_uevent(relpath, "01"); + + assert!( + !matcher.is_match(&uev), + "Matcher should reject uevent from different root complex" + ); + } + + #[test] + fn test_pci_matcher_partial_path() { + let root_bus = create_pci_root_bus_path("00"); + let relpath = "/0000:00:06.0"; + let matcher = PciMatcher::new(relpath, "00").unwrap(); + let mut uev = create_pci_uevent(relpath, "00"); + uev.devpath = format!("{root_bus}/0000:00:06"); + + assert!( + !matcher.is_match(&uev), + "Matcher should reject partial PCI path match" + ); + } + + #[cfg(target_arch = "s390x")] + // Helper to create an AP uevent for testing + fn create_ap_uevent(card: &str, relpath: &str, action: &str) -> Uevent { + let mut uev = Uevent::default(); + uev.action = action.to_string(); + uev.subsystem = "ap".to_string(); + uev.devpath = format!("{AP_ROOT_BUS_PATH}/card{card}/{relpath}"); + uev + } + #[cfg(target_arch = "s390x")] #[tokio::test] - async fn test_vfio_ap_matcher() { - let subsystem = "ap"; + async fn test_vfio_ap_matcher_add_action() { let card = "0a"; let relpath = format!("{card}.0001"); - - let mut uev = Uevent::default(); - uev.action = U_EVENT_ACTION_ADD.to_string(); - uev.subsystem = subsystem.to_string(); - uev.devpath = format!("{AP_ROOT_BUS_PATH}/card{card}/{relpath}"); - let ap_address = ap::Address::from_str(&relpath).unwrap(); let matcher = ApMatcher::new(ap_address); + let uev = create_ap_uevent(card, &relpath, U_EVENT_ACTION_ADD); - assert!(matcher.is_match(&uev)); + assert!( + matcher.is_match(&uev), + "Matcher should match uevent with add action" + ); + } - let mut uev_remove = uev.clone(); - uev_remove.action = U_EVENT_ACTION_REMOVE.to_string(); - assert!(!matcher.is_match(&uev_remove)); + #[cfg(target_arch = "s390x")] + #[tokio::test] + async fn test_vfio_ap_matcher_remove_action() { + let card = "0a"; + let relpath = format!("{card}.0001"); + let ap_address = ap::Address::from_str(&relpath).unwrap(); + let matcher = ApMatcher::new(ap_address); + let uev = create_ap_uevent(card, &relpath, U_EVENT_ACTION_REMOVE); - let mut uev_other_device = uev.clone(); - uev_other_device.devpath = format!("{AP_ROOT_BUS_PATH}/card{card}/{card}.0002"); - assert!(!matcher.is_match(&uev_other_device)); + assert!( + !matcher.is_match(&uev), + "Matcher should reject uevent with remove action" + ); + } + + #[cfg(target_arch = "s390x")] + #[tokio::test] + async fn test_vfio_ap_matcher_different_device() { + let card = "0a"; + let relpath = format!("{card}.0001"); + let ap_address = ap::Address::from_str(&relpath).unwrap(); + let matcher = ApMatcher::new(ap_address); + let other_relpath = format!("{card}.0002"); + let uev = create_ap_uevent(card, &other_relpath, U_EVENT_ACTION_ADD); + + assert!( + !matcher.is_match(&uev), + "Matcher should reject uevent for different AP device" + ); } }