diff --git a/src/agent/src/device.rs b/src/agent/src/device.rs index 8297ccb4ef..c22a657951 100644 --- a/src/agent/src/device.rs +++ b/src/agent/src/device.rs @@ -22,7 +22,7 @@ use crate::linux_abi::*; use crate::pci; use crate::sandbox::Sandbox; use crate::uevent::{wait_for_uevent, Uevent, UeventMatcher}; -use anyhow::{anyhow, Result}; +use anyhow::{anyhow, Context, Result}; use oci::{LinuxDeviceCgroup, LinuxResources, Spec}; use protocols::agent::Device; use tracing::instrument; @@ -484,12 +484,15 @@ impl From for DevUpdate { #[derive(Debug, Clone, Default)] struct SpecUpdate { dev: Option, + // optional corrections for PCI addresses + pci: Vec<(pci::Address, pci::Address)>, } impl> From for SpecUpdate { fn from(dev: T) -> Self { SpecUpdate { dev: Some(dev.into()), + pci: Vec::new(), } } } @@ -583,6 +586,43 @@ fn update_spec_devices(spec: &mut Spec, mut updates: HashMap<&str, DevUpdate>) - Ok(()) } +// update_spec_pci PCI addresses in the OCI spec to be guest addresses +// instead of host addresses. It is given a map of (host address => +// guest address) +#[instrument] +fn update_spec_pci(spec: &mut Spec, updates: HashMap) -> Result<()> { + // Correct PCI addresses in the environment + if let Some(process) = spec.process.as_mut() { + for envvar in process.env.iter_mut() { + let eqpos = envvar + .find('=') + .ok_or_else(|| anyhow!("Malformed OCI env entry {:?}", envvar))?; + + let (name, eqval) = envvar.split_at(eqpos); + let val = &eqval[1..]; + + if !name.starts_with("PCIDEVICE_") { + continue; + } + + let mut guest_addrs = Vec::::new(); + + for host_addr in val.split(',') { + let host_addr = pci::Address::from_str(host_addr) + .with_context(|| format!("Can't parse {} environment variable", name))?; + let guest_addr = updates + .get(&host_addr) + .ok_or_else(|| anyhow!("Unable to translate host PCI address {}", host_addr))?; + guest_addrs.push(format!("{}", guest_addr)); + } + + envvar.replace_range(eqpos + 1.., guest_addrs.join(",").as_str()); + } + } + + Ok(()) +} + // device.Id should be the predicted device name (vda, vdb, ...) // device.VmPath already provides a way to send it in #[instrument] @@ -668,11 +708,14 @@ fn split_vfio_option(opt: &str) -> Option<(&str, &str)> { // is a PCI path to the device in the guest (see pci.rs) async fn vfio_device_handler(device: &Device, sandbox: &Arc>) -> Result { let vfio_in_guest = device.field_type != DRIVER_VFIO_GK_TYPE; + let mut pci_fixups = Vec::<(pci::Address, pci::Address)>::new(); let mut group = None; for opt in device.options.iter() { - let (_, pcipath) = + let (host, pcipath) = split_vfio_option(opt).ok_or_else(|| anyhow!("Malformed VFIO option {:?}", opt))?; + let host = + pci::Address::from_str(host).context("Bad host PCI address in VFIO option {:?}")?; let pcipath = pci::Path::from_str(pcipath)?; let guestdev = wait_for_pci_device(sandbox, &pcipath).await?; @@ -698,17 +741,24 @@ async fn vfio_device_handler(device: &Device, sandbox: &Arc>) -> } group = devgroup; + + pci_fixups.push((host, guestdev)); } } - Ok(if vfio_in_guest { + let dev_update = if vfio_in_guest { // If there are any devices at all, logic above ensures that group is not None let group = group.unwrap(); let vm_path = get_vfio_device_name(sandbox, group).await?; - DevUpdate::from_vm_path(&vm_path, vm_path.clone())?.into() + Some(DevUpdate::from_vm_path(&vm_path, vm_path.clone())?) } else { - SpecUpdate::default() + None + }; + + Ok(SpecUpdate { + dev: dev_update, + pci: pci_fixups, }) } @@ -719,6 +769,7 @@ pub async fn add_devices( sandbox: &Arc>, ) -> Result<()> { let mut dev_updates = HashMap::<&str, DevUpdate>::with_capacity(devices.len()); + let mut pci_updates = HashMap::::new(); for device in devices.iter() { let update = add_device(device, sandbox).await?; @@ -732,6 +783,17 @@ pub async fn add_devices( &device.container_path )); } + + for (host, guest) in update.pci { + if let Some(other_guest) = pci_updates.insert(host, guest) { + return Err(anyhow!( + "Conflicting guest address for host device {} ({} versus {})", + host, + guest, + other_guest + )); + } + } } } @@ -802,7 +864,7 @@ pub fn update_device_cgroup(spec: &mut Spec) -> Result<()> { mod tests { use super::*; use crate::uevent::spawn_test_watcher; - use oci::Linux; + use oci::{Linux, Process}; use std::iter::FromIterator; use tempfile::tempdir; @@ -1140,6 +1202,48 @@ mod tests { assert_eq!(final_path, specdevices[0].path); } + #[test] + fn test_update_spec_pci() { + let example_map = [ + // Each is a host,guest pair of pci addresses + ("0000:1a:01.0", "0000:01:01.0"), + ("0000:1b:02.0", "0000:01:02.0"), + // This one has the same host address as guest address + // above, to test that we're not double-translating + ("0000:01:01.0", "ffff:02:1f.7"), + ]; + + let mut spec = Spec { + process: Some(Process { + env: vec![ + "PCIDEVICE_x=0000:1a:01.0,0000:1b:02.0".to_string(), + "PCIDEVICE_y=0000:01:01.0".to_string(), + "NOTAPCIDEVICE_blah=abcd:ef:01.0".to_string(), + ], + ..Process::default() + }), + ..Spec::default() + }; + + let pci_fixups = example_map + .iter() + .map(|(h, g)| { + ( + pci::Address::from_str(h).unwrap(), + pci::Address::from_str(g).unwrap(), + ) + }) + .collect(); + + let res = update_spec_pci(&mut spec, pci_fixups); + assert!(res.is_ok()); + + let env = &spec.process.as_ref().unwrap().env; + assert_eq!(env[0], "PCIDEVICE_x=0000:01:01.0,0000:01:02.0"); + assert_eq!(env[1], "PCIDEVICE_y=ffff:02:1f.7"); + assert_eq!(env[2], "NOTAPCIDEVICE_blah=abcd:ef:01.0"); + } + #[test] fn test_pcipath_to_sysfs() { let testdir = tempdir().expect("failed to create tmpdir"); diff --git a/src/agent/src/pci.rs b/src/agent/src/pci.rs index 25ad9a6a09..4cb6b521ae 100644 --- a/src/agent/src/pci.rs +++ b/src/agent/src/pci.rs @@ -20,7 +20,7 @@ const FUNCTION_MAX: u8 = (1 << FUNCTION_BITS) - 1; // Represents a PCI function's slot (a.k.a. device) and function // numbers, giving its location on a single logical bus -#[derive(Copy, Clone, Debug, PartialEq, Eq)] +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] pub struct SlotFn(u8); impl SlotFn { @@ -94,7 +94,7 @@ impl fmt::Display for SlotFn { } } -#[derive(Copy, Clone, Debug, PartialEq, Eq)] +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] pub struct Address { domain: u16, bus: u8,