diff --git a/src/agent/src/device.rs b/src/agent/src/device.rs index d68f533270..66ee9d9fa7 100644 --- a/src/agent/src/device.rs +++ b/src/agent/src/device.rs @@ -585,39 +585,25 @@ fn update_spec_device( #[instrument] async fn virtiommio_blk_device_handler( device: &Device, - spec: &mut Spec, _sandbox: &Arc>, - devidx: &DevIndex, -) -> Result<()> { +) -> Result { if device.vm_path.is_empty() { return Err(anyhow!("Invalid path for virtio mmio blk device")); } - update_spec_device( - spec, - devidx, - &device.container_path, - DevNumUpdate::from_vm_path(&device.vm_path)?.into(), - ) + Ok(DevNumUpdate::from_vm_path(&device.vm_path)?.into()) } // device.Id should be a PCI path string #[instrument] async fn virtio_blk_device_handler( device: &Device, - spec: &mut Spec, sandbox: &Arc>, - devidx: &DevIndex, -) -> Result<()> { +) -> Result { let pcipath = pci::Path::from_str(&device.id)?; let vm_path = get_virtio_blk_pci_device_name(sandbox, &pcipath).await?; - update_spec_device( - spec, - devidx, - &device.container_path, - DevNumUpdate::from_vm_path(vm_path)?.into(), - ) + Ok(DevNumUpdate::from_vm_path(vm_path)?.into()) } // device.id should be a CCW path string @@ -625,29 +611,17 @@ async fn virtio_blk_device_handler( #[instrument] async fn virtio_blk_ccw_device_handler( device: &Device, - spec: &mut Spec, sandbox: &Arc>, - devidx: &DevIndex, -) -> Result<()> { +) -> Result { let ccw_device = ccw::Device::from_str(&device.id)?; let vm_path = get_virtio_blk_ccw_device_name(sandbox, &ccw_device).await?; - update_spec_device( - spec, - devidx, - &device.container_path, - DevNumUpdate::from_vm_path(vm_path)?.into(), - ) + Ok(DevNumUpdate::from_vm_path(vm_path)?.into()) } #[cfg(not(target_arch = "s390x"))] #[instrument] -async fn virtio_blk_ccw_device_handler( - _: &Device, - _: &mut Spec, - _: &Arc>, - _: &DevIndex, -) -> Result<()> { +async fn virtio_blk_ccw_device_handler(_: &Device, _: &Arc>) -> Result { Err(anyhow!("CCW is only supported on s390x")) } @@ -655,37 +629,23 @@ async fn virtio_blk_ccw_device_handler( #[instrument] async fn virtio_scsi_device_handler( device: &Device, - spec: &mut Spec, sandbox: &Arc>, - devidx: &DevIndex, -) -> Result<()> { +) -> Result { let vm_path = get_scsi_device_name(sandbox, &device.id).await?; - update_spec_device( - spec, - devidx, - &device.container_path, - DevNumUpdate::from_vm_path(vm_path)?.into(), - ) + Ok(DevNumUpdate::from_vm_path(vm_path)?.into()) } #[instrument] async fn virtio_nvdimm_device_handler( device: &Device, - spec: &mut Spec, _sandbox: &Arc>, - devidx: &DevIndex, -) -> Result<()> { +) -> Result { if device.vm_path.is_empty() { return Err(anyhow!("Invalid path for nvdimm device")); } - update_spec_device( - spec, - devidx, - &device.container_path, - DevNumUpdate::from_vm_path(&device.vm_path)?.into(), - ) + Ok(DevNumUpdate::from_vm_path(&device.vm_path)?.into()) } fn split_vfio_option(opt: &str) -> Option<(&str, &str)> { @@ -703,12 +663,7 @@ fn split_vfio_option(opt: &str) -> Option<(&str, &str)> { // Each option should have the form "DDDD:BB:DD.F=" // DDDD:BB:DD.F is the device's PCI address in the host // is a PCI path to the device in the guest (see pci.rs) -async fn vfio_device_handler( - device: &Device, - spec: &mut Spec, - sandbox: &Arc>, - devidx: &DevIndex, -) -> Result<()> { +async fn vfio_device_handler(device: &Device, sandbox: &Arc>) -> Result { let vfio_in_guest = device.field_type != DRIVER_VFIO_GK_TYPE; let mut group = None; @@ -743,20 +698,15 @@ async fn vfio_device_handler( } } - if vfio_in_guest { + Ok(if vfio_in_guest { // If there are any devices at all, logic above ensures that group is not None let group = group.unwrap(); let vm_path = get_vfio_device_name(sandbox, group).await?; - update_spec_device( - spec, - devidx, - &device.container_path, - DevUpdate::from_vm_path(&vm_path, vm_path.clone())?, - )?; - } - - Ok(()) + DevUpdate::from_vm_path(&vm_path, vm_path.clone())?.into() + } else { + SpecUpdate::default() + }) } impl DevIndex { @@ -790,22 +740,25 @@ pub async fn add_devices( spec: &mut Spec, sandbox: &Arc>, ) -> Result<()> { - let devidx = DevIndex::new(spec); + let mut dev_updates = HashMap::<&str, DevUpdate>::with_capacity(devices.len()); for device in devices.iter() { - add_device(device, spec, sandbox, &devidx).await?; + let update = add_device(device, sandbox).await?; + if let Some(dev_update) = update.dev { + dev_updates.insert(&device.container_path, dev_update); + } + } + + let devidx = DevIndex::new(spec); + for (container_path, update) in dev_updates.drain() { + update_spec_device(spec, &devidx, container_path, update)?; } Ok(()) } #[instrument] -async fn add_device( - device: &Device, - spec: &mut Spec, - sandbox: &Arc>, - devidx: &DevIndex, -) -> Result<()> { +async fn add_device(device: &Device, sandbox: &Arc>) -> Result { // log before validation to help with debugging gRPC protocol version differences. info!(sl!(), "device-id: {}, device-type: {}, device-vm-path: {}, device-container-path: {}, device-options: {:?}", device.id, device.field_type, device.vm_path, device.container_path, device.options); @@ -823,14 +776,12 @@ async fn add_device( } match device.field_type.as_str() { - DRIVER_BLK_TYPE => virtio_blk_device_handler(device, spec, sandbox, devidx).await, - DRIVER_BLK_CCW_TYPE => virtio_blk_ccw_device_handler(device, spec, sandbox, devidx).await, - DRIVER_MMIO_BLK_TYPE => virtiommio_blk_device_handler(device, spec, sandbox, devidx).await, - DRIVER_NVDIMM_TYPE => virtio_nvdimm_device_handler(device, spec, sandbox, devidx).await, - DRIVER_SCSI_TYPE => virtio_scsi_device_handler(device, spec, sandbox, devidx).await, - DRIVER_VFIO_GK_TYPE | DRIVER_VFIO_TYPE => { - vfio_device_handler(device, spec, sandbox, devidx).await - } + DRIVER_BLK_TYPE => virtio_blk_device_handler(device, sandbox).await, + DRIVER_BLK_CCW_TYPE => virtio_blk_ccw_device_handler(device, sandbox).await, + DRIVER_MMIO_BLK_TYPE => virtiommio_blk_device_handler(device, sandbox).await, + DRIVER_NVDIMM_TYPE => virtio_nvdimm_device_handler(device, sandbox).await, + DRIVER_SCSI_TYPE => virtio_scsi_device_handler(device, sandbox).await, + DRIVER_VFIO_GK_TYPE | DRIVER_VFIO_TYPE => vfio_device_handler(device, sandbox).await, _ => Err(anyhow!("Unknown device type {}", device.field_type)), } }