diff --git a/src/agent/src/device.rs b/src/agent/src/device.rs index 5d90075c8e..7e87f099a9 100644 --- a/src/agent/src/device.rs +++ b/src/agent/src/device.rs @@ -3,21 +3,19 @@ // SPDX-License-Identifier: Apache-2.0 // -use rustjail::errors::*; -use std::fs; -// use std::io::Write; use libc::{c_uint, major, minor}; use std::collections::HashMap; +use std::fs; use std::os::unix::fs::MetadataExt; use std::path::Path; -use std::sync::mpsc; -use std::sync::{Arc, Mutex}; +use std::sync::{mpsc, Arc, Mutex}; use crate::mount::{DRIVERBLKTYPE, DRIVERMMIOBLKTYPE, DRIVERNVDIMMTYPE, DRIVERSCSITYPE}; use crate::sandbox::Sandbox; use crate::{AGENT_CONFIG, GLOBAL_DEVICE_WATCHER}; use protocols::agent::Device; use protocols::oci::Spec; +use rustjail::errors::*; // Convenience macro to obtain the scope logger macro_rules! sl { @@ -54,23 +52,18 @@ pub const SCSI_BLOCK_SUFFIX: &str = "block"; const SCSI_DISK_SUFFIX: &str = "/device/block"; const SCSI_HOST_PATH: &str = "/sys/class/scsi_host"; -// DeviceHandler is the type of callback to be defined to handle every -// type of device driver. -type DeviceHandler = fn(&Device, &mut Spec, Arc>) -> Result<()>; +// DeviceHandler is the type of callback to be defined to handle every type of device driver. +type DeviceHandler = fn(&Device, &mut Spec, &Arc>) -> Result<()>; // DeviceHandlerList lists the supported drivers. #[cfg_attr(rustfmt, rustfmt_skip)] lazy_static! { - pub static ref DEVICEHANDLERLIST: HashMap<&'static str, DeviceHandler> = { - let mut m = HashMap::new(); - let blk: DeviceHandler = virtio_blk_device_handler; - m.insert(DRIVERBLKTYPE, blk); - let virtiommio: DeviceHandler = virtiommio_blk_device_handler; - m.insert(DRIVERMMIOBLKTYPE, virtiommio); - let local: DeviceHandler = virtio_nvdimm_device_handler; - m.insert(DRIVERNVDIMMTYPE, local); - let scsi: DeviceHandler = virtio_scsi_device_handler; - m.insert(DRIVERSCSITYPE, scsi); + static ref DEVICEHANDLERLIST: HashMap<&'static str, DeviceHandler> = { + let mut m: HashMap<&'static str, DeviceHandler> = HashMap::new(); + m.insert(DRIVERBLKTYPE, virtio_blk_device_handler); + m.insert(DRIVERMMIOBLKTYPE, virtiommio_blk_device_handler); + m.insert(DRIVERNVDIMMTYPE, virtio_nvdimm_device_handler); + m.insert(DRIVERSCSITYPE, virtio_scsi_device_handler); m }; } @@ -84,11 +77,11 @@ pub fn online_device(path: &str) -> Result<()> { Ok(()) } -// get_device_pci_address fetches the complete PCI address in sysfs, based on the PCI +// get_pci_device_address fetches the complete PCI address in sysfs, based on the PCI // identifier provided. This should be in the format: "bridgeAddr/deviceAddr". -// Here, bridgeAddr is the address at which the brige is attached on the root bus, +// Here, bridgeAddr is the address at which the bridge is attached on the root bus, // while deviceAddr is the address at which the device is attached on the bridge. -pub fn get_device_pci_address(pci_id: &str) -> Result { +fn get_pci_device_address(pci_id: &str) -> Result { let tokens: Vec<&str> = pci_id.split("/").collect(); if tokens.len() != 2 { @@ -139,78 +132,60 @@ pub fn get_device_pci_address(pci_id: &str) -> Result { Ok(bridge_device_pci_addr) } -pub fn get_device_name(sandbox: Arc>, dev_addr: &str) -> Result { - let mut dev_name: String = String::default(); +fn get_device_name(sandbox: &Arc>, dev_addr: &str) -> Result { + // Keep the same lock order as uevent::handle_block_add_event(), otherwise it may cause deadlock. + let mut w = GLOBAL_DEVICE_WATCHER.lock().unwrap(); + let sb = sandbox.lock().unwrap(); + for (key, value) in sb.pci_device_map.iter() { + if key.contains(dev_addr) { + info!(sl!(), "Device {} found in pci device map", dev_addr); + return Ok(format!("{}/{}", SYSTEM_DEV_PATH, value)); + } + } + drop(sb); + + // If device is not found in the device map, hotplug event has not + // been received yet, create and add channel to the watchers map. + // The key of the watchers map is the device we are interested in. + // Note this is done inside the lock, not to miss any events from the + // global udev listener. let (tx, rx) = mpsc::channel::(); + w.insert(dev_addr.to_string(), tx); + drop(w); - { - let watcher = GLOBAL_DEVICE_WATCHER.clone(); - let mut w = watcher.lock().unwrap(); - - let s = sandbox.clone(); - let sb = s.lock().unwrap(); - - for (key, value) in &(sb.pci_device_map) { - if key.contains(dev_addr) { - dev_name = value.to_string(); - info!(sl!(), "Device {} found in pci device map", dev_addr); - break; - } + info!(sl!(), "Waiting on channel for device notification\n"); + let hotplug_timeout = AGENT_CONFIG.read().unwrap().hotplug_timeout; + let dev_name = match rx.recv_timeout(hotplug_timeout) { + Ok(name) => name, + Err(_) => { + GLOBAL_DEVICE_WATCHER.lock().unwrap().remove_entry(dev_addr); + return Err(ErrorKind::ErrorCode(format!( + "Timeout reached after {:?} waiting for device {}", + hotplug_timeout, dev_addr + )) + .into()); } - - // If device is not found in the device map, hotplug event has not - // been received yet, create and add channel to the watchers map. - // The key of the watchers map is the device we are interested in. - // Note this is done inside the lock, not to miss any events from the - // global udev listener. - if dev_name == "" { - w.insert(dev_addr.to_string(), tx); - } - } - - if dev_name == "" { - info!(sl!(), "Waiting on channel for device notification\n"); - - let agent_config = AGENT_CONFIG.clone(); - let config = agent_config.read().unwrap(); - - match rx.recv_timeout(config.hotplug_timeout) { - Ok(name) => dev_name = name, - Err(_) => { - let watcher = GLOBAL_DEVICE_WATCHER.clone(); - let mut w = watcher.lock().unwrap(); - w.remove_entry(dev_addr); - - return Err(ErrorKind::ErrorCode(format!( - "Timeout reached after {:?} waiting for device {}", - config.hotplug_timeout, dev_addr - )) - .into()); - } - } - } + }; Ok(format!("{}/{}", SYSTEM_DEV_PATH, &dev_name)) } -pub fn get_scsi_device_name(sandbox: Arc>, scsi_addr: &str) -> Result { - scan_scsi_bus(scsi_addr)?; - +pub fn get_scsi_device_name(sandbox: &Arc>, scsi_addr: &str) -> Result { let dev_sub_path = format!("{}{}/{}", SCSI_HOST_CHANNEL, scsi_addr, SCSI_BLOCK_SUFFIX); - get_device_name(sandbox, dev_sub_path.as_str()) + scan_scsi_bus(scsi_addr)?; + get_device_name(sandbox, &dev_sub_path) } -pub fn get_pci_device_name(sandbox: Arc>, pci_id: &str) -> Result { - let pci_addr = get_device_pci_address(pci_id)?; +pub fn get_pci_device_name(sandbox: &Arc>, pci_id: &str) -> Result { + let pci_addr = get_pci_device_address(pci_id)?; rescan_pci_bus()?; - - get_device_name(sandbox, pci_addr.as_str()) + get_device_name(sandbox, &pci_addr) } -// scan_scsi_bus scans SCSI bus for the given SCSI address(SCSI-Id and LUN) -pub fn scan_scsi_bus(scsi_addr: &str) -> Result<()> { +/// Scan SCSI bus for the given SCSI address(SCSI-Id and LUN) +fn scan_scsi_bus(scsi_addr: &str) -> Result<()> { let tokens: Vec<&str> = scsi_addr.split(":").collect(); if tokens.len() != 2 { return Err(ErrorKind::Msg(format!( @@ -220,14 +195,12 @@ pub fn scan_scsi_bus(scsi_addr: &str) -> Result<()> { .into()); } - // Scan scsi host passing in the channel, SCSI id and LUN. Channel - // is always 0 because we have only one SCSI controller. + // Scan scsi host passing in the channel, SCSI id and LUN. + // Channel is always 0 because we have only one SCSI controller. let scan_data = format!("0 {} {}", tokens[0], tokens[1]); for entry in fs::read_dir(SCSI_HOST_PATH)? { - let entry = entry?; - - let host = entry.file_name(); + let host = entry?.file_name(); let scan_path = format!("{}/{}/{}", SCSI_HOST_PATH, host.to_str().unwrap(), "scan"); fs::write(scan_path, &scan_data)?; @@ -243,9 +216,6 @@ pub fn scan_scsi_bus(scsi_addr: &str) -> Result<()> { // This is needed to update information about minor/major numbers that cannot // be predicted from the caller. fn update_spec_device_list(device: &Device, spec: &mut Spec) -> Result<()> { - // If no container_path is provided, we won't be able to match and - // update the device in the OCI spec device list. This is an error. - let major_id: c_uint; let minor_id: c_uint; @@ -253,7 +223,7 @@ fn update_spec_device_list(device: &Device, spec: &mut Spec) -> Result<()> { // update the device in the OCI spec device list. This is an error. if device.container_path == "" { return Err(ErrorKind::Msg(format!( - "container_path cannot empty for device {:?}", + "container_path cannot empty for device {:?}", device )) .into()); @@ -304,9 +274,7 @@ fn update_spec_device_list(device: &Device, spec: &mut Spec) -> Result<()> { // Resources must be updated since they are used to identify the // device in the devices cgroup. - let resource = linux.Resources.as_mut(); - if resource.is_some() { - let res = resource.unwrap(); + if let Some(res) = linux.Resources.as_mut() { let ds = res.Devices.as_mut_slice(); for d in ds.iter_mut() { if d.Major == host_major && d.Minor == host_minor { @@ -331,10 +299,10 @@ fn update_spec_device_list(device: &Device, spec: &mut Spec) -> Result<()> { fn virtiommio_blk_device_handler( device: &Device, spec: &mut Spec, - _sandbox: Arc>, + _sandbox: &Arc>, ) -> Result<()> { if device.vm_path == "" { - return Err(ErrorKind::Msg("Invalid path for virtiommioblkdevice".to_string()).into()); + return Err(ErrorKind::Msg("Invalid path for virtio mmio blk device".to_string()).into()); } update_spec_device_list(device, spec) @@ -346,13 +314,10 @@ fn virtiommio_blk_device_handler( fn virtio_blk_device_handler( device: &Device, spec: &mut Spec, - sandbox: Arc>, + sandbox: &Arc>, ) -> Result<()> { - let dev_path = get_pci_device_name(sandbox, device.id.as_str())?; - let mut dev = device.clone(); - dev.vm_path = dev_path; - + dev.vm_path = get_pci_device_name(sandbox, &device.id)?; update_spec_device_list(&dev, spec) } @@ -360,39 +325,39 @@ fn virtio_blk_device_handler( fn virtio_scsi_device_handler( device: &Device, spec: &mut Spec, - sandbox: Arc>, + sandbox: &Arc>, ) -> Result<()> { - let dev_path = get_scsi_device_name(sandbox, device.id.as_str())?; - let mut dev = device.clone(); - dev.vm_path = dev_path; - + dev.vm_path = get_scsi_device_name(sandbox, &device.id)?; update_spec_device_list(&dev, spec) } fn virtio_nvdimm_device_handler( device: &Device, spec: &mut Spec, - _sandbox: Arc>, + _sandbox: &Arc>, ) -> Result<()> { + if device.vm_path == "" { + return Err(ErrorKind::Msg("Invalid path for nvdimm device".to_string()).into()); + } + update_spec_device_list(device, spec) } pub fn add_devices( - devices: Vec, + devices: &[Device], spec: &mut Spec, - sandbox: Arc>, + sandbox: &Arc>, ) -> Result<()> { for device in devices.iter() { - add_device(device, spec, sandbox.clone())?; + add_device(device, spec, sandbox)?; } Ok(()) } -fn add_device(device: &Device, spec: &mut Spec, sandbox: Arc>) -> Result<()> { - // log before validation to help with debugging gRPC protocol - // version differences. +fn add_device(device: &Device, spec: &mut Spec, sandbox: &Arc>) -> Result<()> { + // log before validation to help with debugging gRPC protocol version differences. info!(sl!(), "device-id: {}, device-type: {}, device-vm-path: {}, device-container-path: {}, device-options: {:?}", device.id, device.field_type, device.vm_path, device.container_path, device.options); @@ -412,12 +377,8 @@ fn add_device(device: &Device, spec: &mut Spec, sandbox: Arc>) -> ); } - let dev_handler = match DEVICEHANDLERLIST.get(device.field_type.as_str()) { - None => { - return Err(ErrorKind::Msg(format!("Unknown device type {}", device.field_type)).into()) - } - Some(t) => t, - }; - - dev_handler(device, spec, sandbox) + match DEVICEHANDLERLIST.get(device.field_type.as_str()) { + None => Err(ErrorKind::Msg(format!("Unknown device type {}", device.field_type)).into()), + Some(dev_handler) => dev_handler(device, spec, sandbox), + } } diff --git a/src/agent/src/grpc.rs b/src/agent/src/grpc.rs index ad84c997dc..009d482d1b 100644 --- a/src/agent/src/grpc.rs +++ b/src/agent/src/grpc.rs @@ -95,7 +95,7 @@ impl agentService { // updates the devices listed in the OCI spec, so that they actually // match real devices inside the VM. This step is necessary since we // cannot predict everything from the caller. - add_devices(req.devices.to_vec(), oci, self.sandbox.clone())?; + add_devices(&req.devices.to_vec(), oci, &self.sandbox)?; // Both rootfs and volumes (invoked with --volume for instance) will // be processed the same way. The idea is to always mount any provided diff --git a/src/agent/src/mount.rs b/src/agent/src/mount.rs index 31d4af0d92..aeeac2e185 100644 --- a/src/agent/src/mount.rs +++ b/src/agent/src/mount.rs @@ -341,7 +341,7 @@ fn virtio_blk_storage_handler( return Err(ErrorKind::ErrorCode(format!("Invalid device {}", &storage.source)).into()); } } else { - let dev_path = get_pci_device_name(sandbox, &storage.source)?; + let dev_path = get_pci_device_name(&sandbox, &storage.source)?; storage.source = dev_path; } @@ -357,7 +357,7 @@ fn virtio_scsi_storage_handler( let mut storage = storage.clone(); // Retrieve the device path from SCSI address. - let dev_path = get_scsi_device_name(sandbox, &storage.source)?; + let dev_path = get_scsi_device_name(&sandbox, &storage.source)?; storage.source = dev_path; common_storage_handler(logger, &storage)