diff --git a/src/agent/src/device.rs b/src/agent/src/device.rs index 8ce34a64d6..9d3180b0c5 100644 --- a/src/agent/src/device.rs +++ b/src/agent/src/device.rs @@ -17,7 +17,7 @@ use crate::linux_abi::*; use crate::mount::{DRIVER_BLK_TYPE, DRIVER_MMIO_BLK_TYPE, DRIVER_NVDIMM_TYPE, DRIVER_SCSI_TYPE}; use crate::pci; use crate::sandbox::Sandbox; -use crate::{AGENT_CONFIG, GLOBAL_DEVICE_WATCHER}; +use crate::AGENT_CONFIG; use anyhow::{anyhow, Result}; use oci::{LinuxDeviceCgroup, LinuxResources, Spec}; use protocols::agent::Device; @@ -88,16 +88,13 @@ fn pcipath_to_sysfs(root_bus_sysfs: &str, pcipath: &pci::Path) -> Result } async fn get_device_name(sandbox: &Arc>, dev_addr: &str) -> Result { - // Keep the same lock order as uevent::handle_block_add_event(), otherwise it may cause deadlock. - let mut w = GLOBAL_DEVICE_WATCHER.lock().await; - let sb = sandbox.lock().await; + let mut sb = sandbox.lock().await; for (key, value) in sb.pci_device_map.iter() { if key.contains(dev_addr) { info!(sl!(), "Device {} found in pci device map", dev_addr); return Ok(format!("{}/{}", SYSTEM_DEV_PATH, value)); } } - drop(sb); // If device is not found in the device map, hotplug event has not // been received yet, create and add channel to the watchers map. @@ -105,8 +102,8 @@ async fn get_device_name(sandbox: &Arc>, dev_addr: &str) -> Resul // Note this is done inside the lock, not to miss any events from the // global udev listener. let (tx, rx) = tokio::sync::oneshot::channel::(); - w.insert(dev_addr.to_string(), Some(tx)); - drop(w); + sb.dev_watcher.insert(dev_addr.to_string(), tx); + drop(sb); // unlock info!(sl!(), "Waiting on channel for device notification\n"); let hotplug_timeout = AGENT_CONFIG.read().await.hotplug_timeout; @@ -114,9 +111,8 @@ async fn get_device_name(sandbox: &Arc>, dev_addr: &str) -> Resul let dev_name = match tokio::time::timeout(hotplug_timeout, rx).await { Ok(v) => v?, Err(_) => { - let watcher = GLOBAL_DEVICE_WATCHER.clone(); - let mut w = watcher.lock().await; - w.remove_entry(dev_addr); + let mut sb = sandbox.lock().await; + sb.dev_watcher.remove_entry(dev_addr); return Err(anyhow!( "Timeout reached after {:?} waiting for device {}", @@ -800,21 +796,23 @@ mod tests { sb.pci_device_map.remove(&devpath); drop(sb); // unlock + let watcher_sandbox = Arc::clone(&sandbox); tokio::spawn(async move { loop { - let mut w = GLOBAL_DEVICE_WATCHER.lock().await; - let matched_key = w + let mut sb = watcher_sandbox.lock().await; + let matched_key = sb + .dev_watcher .keys() .filter(|dev_addr| devpath.contains(*dev_addr)) .cloned() .next(); if let Some(k) = matched_key { - let sender = w.remove(&k).unwrap().unwrap(); + let sender = sb.dev_watcher.remove(&k).unwrap(); let _ = sender.send(devname.to_string()); return; } - drop(w); // unlock + drop(sb); // unlock } }); diff --git a/src/agent/src/main.rs b/src/agent/src/main.rs index c86a03254a..d546ab8635 100644 --- a/src/agent/src/main.rs +++ b/src/agent/src/main.rs @@ -28,7 +28,6 @@ use nix::sys::select::{select, FdSet}; use nix::sys::socket::{self, AddressFamily, SockAddr, SockFlag, SockType}; use nix::sys::wait; use nix::unistd::{self, close, dup, dup2, fork, setsid, ForkResult}; -use std::collections::HashMap; use std::env; use std::ffi::{CStr, CString, OsStr}; use std::fs::{self, File}; @@ -72,7 +71,6 @@ use rustjail::pipestream::PipeStream; use tokio::{ io::AsyncWrite, sync::{ - oneshot::Sender, watch::{channel, Receiver}, Mutex, RwLock, }, @@ -89,8 +87,6 @@ const CONSOLE_PATH: &str = "/dev/console"; const DEFAULT_BUF_SIZE: usize = 8 * 1024; lazy_static! { - static ref GLOBAL_DEVICE_WATCHER: Arc>>>> = - Arc::new(Mutex::new(HashMap::new())); static ref AGENT_CONFIG: Arc> = Arc::new(RwLock::new(config::AgentConfig::new())); } diff --git a/src/agent/src/sandbox.rs b/src/agent/src/sandbox.rs index 9d5244ee39..4ec3b3d6ee 100644 --- a/src/agent/src/sandbox.rs +++ b/src/agent/src/sandbox.rs @@ -37,6 +37,7 @@ pub struct Sandbox { pub mounts: Vec, pub container_mounts: HashMap>, pub pci_device_map: HashMap, + pub dev_watcher: HashMap>, pub shared_utsns: Namespace, pub shared_ipcns: Namespace, pub sandbox_pidns: Option, @@ -66,6 +67,7 @@ impl Sandbox { mounts: Vec::new(), container_mounts: HashMap::new(), pci_device_map: HashMap::new(), + dev_watcher: HashMap::new(), shared_utsns: Namespace::new(&logger), shared_ipcns: Namespace::new(&logger), sandbox_pidns: None, diff --git a/src/agent/src/uevent.rs b/src/agent/src/uevent.rs index b534282315..1c52863b3a 100644 --- a/src/agent/src/uevent.rs +++ b/src/agent/src/uevent.rs @@ -6,7 +6,6 @@ use crate::device::online_device; use crate::linux_abi::*; use crate::sandbox::Sandbox; -use crate::GLOBAL_DEVICE_WATCHER; use slog::Logger; use anyhow::Result; @@ -66,10 +65,6 @@ impl Uevent { async fn handle_block_add_event(&self, sandbox: &Arc>) { let pci_root_bus_path = create_pci_root_bus_path(); - - // Keep the same lock order as device::get_device_name(), otherwise it may cause deadlock. - let watcher = GLOBAL_DEVICE_WATCHER.clone(); - let mut w = watcher.lock().await; let mut sb = sandbox.lock().await; // Add the device node name to the pci device map. @@ -79,9 +74,10 @@ impl Uevent { // Notify watchers that are interested in the udev event. // Close the channel after watcher has been notified. let devpath = self.devpath.clone(); - let empties: Vec<_> = w - .iter_mut() - .filter(|(dev_addr, _)| { + let keys: Vec<_> = sb + .dev_watcher + .keys() + .filter(|dev_addr| { let pci_p = format!("{}{}", pci_root_bus_path, *dev_addr); // blk block device @@ -99,17 +95,16 @@ impl Uevent { dev_addr.ends_with(pmem_suffix.as_str()) } }) - .map(|(k, sender)| { - let devname = self.devname.clone(); - let sender = sender.take().unwrap(); - let _ = sender.send(devname); - k.clone() - }) + .cloned() .collect(); - // Remove notified nodes from the watcher map. - for empty in empties { - w.remove(&empty); + for k in keys { + let devname = self.devname.clone(); + // unwrap() is safe because logic above ensures k exists + // in the map, and it's locked so no-one else can change + // that + let sender = sb.dev_watcher.remove(&k).unwrap(); + let _ = sender.send(devname); } }