diff --git a/src/agent/src/device.rs b/src/agent/src/device.rs index abc2cbdea7..a07ca6c7ca 100644 --- a/src/agent/src/device.rs +++ b/src/agent/src/device.rs @@ -8,6 +8,7 @@ use nix::sys::stat; use regex::Regex; use std::collections::HashMap; use std::ffi::OsStr; +use std::fmt; use std::fs; use std::os::unix::ffi::OsStrExt; use std::os::unix::fs::MetadataExt; @@ -101,6 +102,49 @@ where Ok(()) } +// Represents an IOMMU group +#[derive(Clone, Debug, PartialEq, Eq)] +struct IommuGroup(u32); + +impl fmt::Display for IommuGroup { + fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { + write!(f, "{}", self.0) + } +} + +// Determine the IOMMU group of a PCI device +#[instrument] +fn pci_iommu_group(syspci: T, dev: pci::Address) -> Result> +where + T: AsRef + std::fmt::Debug, +{ + let syspci = Path::new(&syspci); + let grouppath = syspci + .join("devices") + .join(dev.to_string()) + .join("iommu_group"); + + match fs::read_link(&grouppath) { + // Device has no group + Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(None), + Err(e) => Err(anyhow!("Error reading link {:?}: {}", &grouppath, e)), + Ok(group) => { + if let Some(group) = group.file_name() { + if let Some(group) = group.to_str() { + if let Ok(group) = group.parse::() { + return Ok(Some(IommuGroup(group))); + } + } + } + Err(anyhow!( + "Unexpected IOMMU group link {:?} => {:?}", + grouppath, + group + )) + } + } +} + // pcipath_to_sysfs fetches the sysfs path for a PCI path, relative to // the sysfs path for the PCI host bridge, based on the PCI path // provided. @@ -1235,4 +1279,43 @@ mod tests { assert_eq!(fs::read_to_string(&probepath).unwrap(), dev0.to_string()); assert_eq!(fs::read_to_string(&drvaunbind).unwrap(), dev0.to_string()); } + + #[test] + fn test_pci_iommu_group() { + let testdir = tempdir().expect("failed to create tmpdir"); // mock /sys + let syspci = testdir.path().join("bus").join("pci"); + + // Mock dev0, which has no group + let dev0 = pci::Address::new(0, 0, pci::SlotFn::new(0, 0).unwrap()); + let dev0path = syspci.join("devices").join(dev0.to_string()); + + fs::create_dir_all(&dev0path).unwrap(); + + // Test dev0 + assert!(pci_iommu_group(&syspci, dev0).unwrap().is_none()); + + // Mock dev1, which is in group 12 + let dev1 = pci::Address::new(0, 1, pci::SlotFn::new(0, 0).unwrap()); + let dev1path = syspci.join("devices").join(dev1.to_string()); + let dev1group = dev1path.join("iommu_group"); + + fs::create_dir_all(&dev1path).unwrap(); + std::os::unix::fs::symlink("../../../kernel/iommu_groups/12", &dev1group).unwrap(); + + // Test dev1 + assert_eq!( + pci_iommu_group(&syspci, dev1).unwrap(), + Some(IommuGroup(12)) + ); + + // Mock dev2, which has a bogus group (dir instead of symlink) + let dev2 = pci::Address::new(0, 2, pci::SlotFn::new(0, 0).unwrap()); + let dev2path = syspci.join("devices").join(dev2.to_string()); + let dev2group = dev2path.join("iommu_group"); + + fs::create_dir_all(&dev2group).unwrap(); + + // Test dev2 + assert!(pci_iommu_group(&syspci, dev2).is_err()); + } }