agent/device: Add function to get IOMMU group for a PCI device

For upcoming VFIO extensions we'll need to work with the IOMMU groups of
VFIO devices.  This helps us towards that by adding pci_iommu_group() to
retrieve the IOMMU group (if any) of a given PCI device.

Signed-off-by: David Gibson <david@gibson.dropbear.id.au>
This commit is contained in:
David Gibson 2021-10-06 15:55:49 +11:00
parent 13b06a35d5
commit ff59db7534

View File

@ -8,6 +8,7 @@ use nix::sys::stat;
use regex::Regex;
use std::collections::HashMap;
use std::ffi::OsStr;
use std::fmt;
use std::fs;
use std::os::unix::ffi::OsStrExt;
use std::os::unix::fs::MetadataExt;
@ -101,6 +102,49 @@ where
Ok(())
}
// Represents an IOMMU group
#[derive(Clone, Debug, PartialEq, Eq)]
struct IommuGroup(u32);
impl fmt::Display for IommuGroup {
fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
write!(f, "{}", self.0)
}
}
// Determine the IOMMU group of a PCI device
#[instrument]
fn pci_iommu_group<T>(syspci: T, dev: pci::Address) -> Result<Option<IommuGroup>>
where
T: AsRef<OsStr> + std::fmt::Debug,
{
let syspci = Path::new(&syspci);
let grouppath = syspci
.join("devices")
.join(dev.to_string())
.join("iommu_group");
match fs::read_link(&grouppath) {
// Device has no group
Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(None),
Err(e) => Err(anyhow!("Error reading link {:?}: {}", &grouppath, e)),
Ok(group) => {
if let Some(group) = group.file_name() {
if let Some(group) = group.to_str() {
if let Ok(group) = group.parse::<u32>() {
return Ok(Some(IommuGroup(group)));
}
}
}
Err(anyhow!(
"Unexpected IOMMU group link {:?} => {:?}",
grouppath,
group
))
}
}
}
// pcipath_to_sysfs fetches the sysfs path for a PCI path, relative to
// the sysfs path for the PCI host bridge, based on the PCI path
// provided.
@ -1235,4 +1279,43 @@ mod tests {
assert_eq!(fs::read_to_string(&probepath).unwrap(), dev0.to_string());
assert_eq!(fs::read_to_string(&drvaunbind).unwrap(), dev0.to_string());
}
#[test]
fn test_pci_iommu_group() {
let testdir = tempdir().expect("failed to create tmpdir"); // mock /sys
let syspci = testdir.path().join("bus").join("pci");
// Mock dev0, which has no group
let dev0 = pci::Address::new(0, 0, pci::SlotFn::new(0, 0).unwrap());
let dev0path = syspci.join("devices").join(dev0.to_string());
fs::create_dir_all(&dev0path).unwrap();
// Test dev0
assert!(pci_iommu_group(&syspci, dev0).unwrap().is_none());
// Mock dev1, which is in group 12
let dev1 = pci::Address::new(0, 1, pci::SlotFn::new(0, 0).unwrap());
let dev1path = syspci.join("devices").join(dev1.to_string());
let dev1group = dev1path.join("iommu_group");
fs::create_dir_all(&dev1path).unwrap();
std::os::unix::fs::symlink("../../../kernel/iommu_groups/12", &dev1group).unwrap();
// Test dev1
assert_eq!(
pci_iommu_group(&syspci, dev1).unwrap(),
Some(IommuGroup(12))
);
// Mock dev2, which has a bogus group (dir instead of symlink)
let dev2 = pci::Address::new(0, 2, pci::SlotFn::new(0, 0).unwrap());
let dev2path = syspci.join("devices").join(dev2.to_string());
let dev2group = dev2path.join("iommu_group");
fs::create_dir_all(&dev2group).unwrap();
// Test dev2
assert!(pci_iommu_group(&syspci, dev2).is_err());
}
}