diff --git a/src/agent/src/device.rs b/src/agent/src/device.rs index 629cc3fc23..31c5c120bc 100644 --- a/src/agent/src/device.rs +++ b/src/agent/src/device.rs @@ -592,38 +592,38 @@ fn update_spec_devices(spec: &mut Spec, mut updates: HashMap<&str, DevUpdate>) - Ok(()) } -// update_spec_pci PCI addresses in the OCI spec to be guest addresses -// instead of host addresses. It is given a map of (host address => -// guest address) +// update_env_pci alters PCI addresses in a set of environment +// variables to be correct for the VM instead of the host. It is +// given a map of (host address => guest address) #[instrument] -fn update_spec_pci(spec: &mut Spec, updates: HashMap) -> Result<()> { - // Correct PCI addresses in the environment - if let Some(process) = spec.process.as_mut() { - for envvar in process.env.iter_mut() { - let eqpos = envvar - .find('=') - .ok_or_else(|| anyhow!("Malformed OCI env entry {:?}", envvar))?; +pub fn update_env_pci( + env: &mut [String], + pcimap: &HashMap, +) -> Result<()> { + for envvar in env { + let eqpos = envvar + .find('=') + .ok_or_else(|| anyhow!("Malformed OCI env entry {:?}", envvar))?; - let (name, eqval) = envvar.split_at(eqpos); - let val = &eqval[1..]; + let (name, eqval) = envvar.split_at(eqpos); + let val = &eqval[1..]; - if !name.starts_with("PCIDEVICE_") { - continue; - } - - let mut guest_addrs = Vec::::new(); - - for host_addr in val.split(',') { - let host_addr = pci::Address::from_str(host_addr) - .with_context(|| format!("Can't parse {} environment variable", name))?; - let guest_addr = updates - .get(&host_addr) - .ok_or_else(|| anyhow!("Unable to translate host PCI address {}", host_addr))?; - guest_addrs.push(format!("{}", guest_addr)); - } - - envvar.replace_range(eqpos + 1.., guest_addrs.join(",").as_str()); + if !name.starts_with("PCIDEVICE_") { + continue; } + + let mut guest_addrs = Vec::::new(); + + for host_addr in val.split(',') { + let host_addr = pci::Address::from_str(host_addr) + .with_context(|| format!("Can't parse {} environment variable", name))?; + let guest_addr = pcimap + .get(&host_addr) + .ok_or_else(|| anyhow!("Unable to translate host PCI address {}", host_addr))?; + guest_addrs.push(format!("{}", guest_addr)); + } + + envvar.replace_range(eqpos + 1.., guest_addrs.join(",").as_str()); } Ok(()) @@ -768,7 +768,6 @@ pub async fn add_devices( sandbox: &Arc>, ) -> Result<()> { let mut dev_updates = HashMap::<&str, DevUpdate>::with_capacity(devices.len()); - let mut pci_updates = HashMap::::new(); for device in devices.iter() { let update = add_device(device, sandbox).await?; @@ -783,8 +782,9 @@ pub async fn add_devices( )); } + let mut sb = sandbox.lock().await; for (host, guest) in update.pci { - if let Some(other_guest) = pci_updates.insert(host, guest) { + if let Some(other_guest) = sb.pcimap.insert(host, guest) { return Err(anyhow!( "Conflicting guest address for host device {} ({} versus {})", host, @@ -796,6 +796,9 @@ pub async fn add_devices( } } + if let Some(process) = spec.process.as_mut() { + update_env_pci(&mut process.env, &sandbox.lock().await.pcimap)? + } update_spec_devices(spec, dev_updates) } @@ -860,7 +863,7 @@ pub fn update_device_cgroup(spec: &mut Spec) -> Result<()> { mod tests { use super::*; use crate::uevent::spawn_test_watcher; - use oci::{Linux, Process}; + use oci::Linux; use std::iter::FromIterator; use tempfile::tempdir; @@ -1199,7 +1202,7 @@ mod tests { } #[test] - fn test_update_spec_pci() { + fn test_update_env_pci() { let example_map = [ // Each is a host,guest pair of pci addresses ("0000:1a:01.0", "0000:01:01.0"), @@ -1209,17 +1212,11 @@ mod tests { ("0000:01:01.0", "ffff:02:1f.7"), ]; - let mut spec = Spec { - process: Some(Process { - env: vec![ - "PCIDEVICE_x=0000:1a:01.0,0000:1b:02.0".to_string(), - "PCIDEVICE_y=0000:01:01.0".to_string(), - "NOTAPCIDEVICE_blah=abcd:ef:01.0".to_string(), - ], - ..Process::default() - }), - ..Spec::default() - }; + let mut env = vec![ + "PCIDEVICE_x=0000:1a:01.0,0000:1b:02.0".to_string(), + "PCIDEVICE_y=0000:01:01.0".to_string(), + "NOTAPCIDEVICE_blah=abcd:ef:01.0".to_string(), + ]; let pci_fixups = example_map .iter() @@ -1231,10 +1228,9 @@ mod tests { }) .collect(); - let res = update_spec_pci(&mut spec, pci_fixups); + let res = update_env_pci(&mut env, &pci_fixups); assert!(res.is_ok()); - let env = &spec.process.as_ref().unwrap().env; assert_eq!(env[0], "PCIDEVICE_x=0000:01:01.0,0000:01:02.0"); assert_eq!(env[1], "PCIDEVICE_y=ffff:02:1f.7"); assert_eq!(env[2], "NOTAPCIDEVICE_blah=abcd:ef:01.0"); diff --git a/src/agent/src/rpc.rs b/src/agent/src/rpc.rs index 444be723cf..276746b9fe 100644 --- a/src/agent/src/rpc.rs +++ b/src/agent/src/rpc.rs @@ -43,7 +43,9 @@ use nix::sys::stat; use nix::unistd::{self, Pid}; use rustjail::process::ProcessOperations; -use crate::device::{add_devices, get_virtio_blk_pci_device_name, update_device_cgroup}; +use crate::device::{ + add_devices, get_virtio_blk_pci_device_name, update_device_cgroup, update_env_pci, +}; use crate::linux_abi::*; use crate::metrics::get_metrics; use crate::mount::{add_storages, baremount, remove_mounts, STORAGE_HANDLER_LIST}; @@ -359,11 +361,14 @@ impl AgentService { let s = self.sandbox.clone(); let mut sandbox = s.lock().await; - let process = req + let mut process = req .process .into_option() .ok_or_else(|| anyhow!(nix::Error::EINVAL))?; + // Apply any necessary corrections for PCI addresses + update_env_pci(&mut process.Env, &sandbox.pcimap)?; + let pipe_size = AGENT_CONFIG.read().await.container_pipe_size; let ocip = rustjail::process_grpc_to_oci(&process); let p = Process::new(&sl!(), &ocip, exec_id.as_str(), false, pipe_size)?; diff --git a/src/agent/src/sandbox.rs b/src/agent/src/sandbox.rs index 4dfb2eda58..3edbe7ea02 100644 --- a/src/agent/src/sandbox.rs +++ b/src/agent/src/sandbox.rs @@ -8,6 +8,7 @@ use crate::mount::{get_mount_fs_type, remove_mounts, TYPE_ROOTFS}; use crate::namespace::Namespace; use crate::netlink::Handle; use crate::network::Network; +use crate::pci; use crate::uevent::{Uevent, UeventMatcher}; use crate::watcher::BindWatcher; use anyhow::{anyhow, Context, Result}; @@ -56,6 +57,7 @@ pub struct Sandbox { pub event_rx: Arc>>, pub event_tx: Option>, pub bind_watcher: BindWatcher, + pub pcimap: HashMap, } impl Sandbox { @@ -88,6 +90,7 @@ impl Sandbox { event_rx, event_tx: Some(tx), bind_watcher: BindWatcher::new(), + pcimap: HashMap::new(), }) }