diff --git a/src/agent/src/device.rs b/src/agent/src/device.rs index 3258be1165..60f661c120 100644 --- a/src/agent/src/device.rs +++ b/src/agent/src/device.rs @@ -7,7 +7,9 @@ use libc::{c_uint, major, minor}; use nix::sys::stat; use regex::Regex; use std::collections::HashMap; +use std::ffi::OsStr; use std::fs; +use std::os::unix::ffi::OsStrExt; use std::os::unix::fs::MetadataExt; use std::path::Path; use std::str::FromStr; @@ -62,6 +64,40 @@ pub fn online_device(path: &str) -> Result<()> { Ok(()) } +// Force a given PCI device to bind to the given driver, does +// basically the same thing as +// driverctl set-override +#[instrument] +pub fn pci_driver_override(syspci: T, dev: pci::Address, drv: U) -> Result<()> +where + T: AsRef + std::fmt::Debug, + U: AsRef + std::fmt::Debug, +{ + let syspci = Path::new(&syspci); + let drv = drv.as_ref(); + info!(sl!(), "rebind_pci_driver: {} => {:?}", dev, drv); + + let devpath = syspci.join("devices").join(dev.to_string()); + let overridepath = &devpath.join("driver_override"); + + fs::write(overridepath, drv.as_bytes())?; + + let drvpath = &devpath.join("driver"); + let need_unbind = match fs::read_link(drvpath) { + Ok(d) if d.file_name() == Some(drv) => return Ok(()), // Nothing to do + Err(e) if e.kind() == std::io::ErrorKind::NotFound => false, // No current driver + Err(e) => return Err(anyhow!("Error checking driver on {}: {}", dev, e)), + Ok(_) => true, // Current driver needs unbinding + }; + if need_unbind { + let unbindpath = &drvpath.join("unbind"); + fs::write(unbindpath, dev.to_string())?; + } + let probepath = syspci.join("drivers_probe"); + fs::write(probepath, dev.to_string())?; + Ok(()) +} + // pcipath_to_sysfs fetches the sysfs path for a PCI path, relative to // the sysfs path for the PCI host bridge, based on the PCI path // provided. @@ -1151,4 +1187,42 @@ mod tests { assert_eq!(split_vfio_option("0000:01:00.0=02/01=rubbish"), None); assert_eq!(split_vfio_option("0000:01:00.0"), None); } + + #[test] + fn test_pci_driver_override() { + let testdir = tempdir().expect("failed to create tmpdir"); + let syspci = testdir.path(); // Path to mock /sys/bus/pci + + let dev0 = pci::Address::new(0, 0, pci::SlotFn::new(0, 0).unwrap()); + let dev0path = syspci.join("devices").join(dev0.to_string()); + let dev0drv = dev0path.join("driver"); + let dev0override = dev0path.join("driver_override"); + + let drvapath = syspci.join("drivers").join("drv_a"); + let drvaunbind = drvapath.join("unbind"); + + let probepath = syspci.join("drivers_probe"); + + // Start mocking dev0 as being unbound + fs::create_dir_all(&dev0path).unwrap(); + + pci_driver_override(syspci, dev0, "drv_a").unwrap(); + assert_eq!(fs::read_to_string(&dev0override).unwrap(), "drv_a"); + assert_eq!(fs::read_to_string(&probepath).unwrap(), dev0.to_string()); + + // Now mock dev0 already being attached to drv_a + fs::create_dir_all(&drvapath).unwrap(); + std::os::unix::fs::symlink(&drvapath, dev0drv).unwrap(); + std::fs::remove_file(&probepath).unwrap(); + + pci_driver_override(syspci, dev0, "drv_a").unwrap(); // no-op + assert_eq!(fs::read_to_string(&dev0override).unwrap(), "drv_a"); + assert!(!probepath.exists()); + + // Now try binding to a different driver + pci_driver_override(syspci, dev0, "drv_b").unwrap(); + assert_eq!(fs::read_to_string(&dev0override).unwrap(), "drv_b"); + assert_eq!(fs::read_to_string(&probepath).unwrap(), dev0.to_string()); + assert_eq!(fs::read_to_string(&drvaunbind).unwrap(), dev0.to_string()); + } }