diff --git a/src/agent/src/device.rs b/src/agent/src/device.rs index 16d1ecc5be..750aa46691 100644 --- a/src/agent/src/device.rs +++ b/src/agent/src/device.rs @@ -17,7 +17,7 @@ use crate::linux_abi::*; use crate::mount::{DRIVER_BLK_TYPE, DRIVER_MMIO_BLK_TYPE, DRIVER_NVDIMM_TYPE, DRIVER_SCSI_TYPE}; use crate::pci; use crate::sandbox::Sandbox; -use crate::{AGENT_CONFIG, GLOBAL_DEVICE_WATCHER}; +use crate::uevent::{wait_for_uevent, Uevent, UeventMatcher}; use anyhow::{anyhow, Result}; use oci::{LinuxDeviceCgroup, LinuxResources, Spec}; use protocols::agent::Device; @@ -87,46 +87,55 @@ fn pcipath_to_sysfs(root_bus_sysfs: &str, pcipath: &pci::Path) -> Result Ok(relpath) } -async fn get_device_name(sandbox: &Arc>, dev_addr: &str) -> Result { - // Keep the same lock order as uevent::handle_block_add_event(), otherwise it may cause deadlock. - let mut w = GLOBAL_DEVICE_WATCHER.lock().await; - let sb = sandbox.lock().await; - for (key, value) in sb.pci_device_map.iter() { - if key.contains(dev_addr) { - info!(sl!(), "Device {} found in pci device map", dev_addr); - return Ok(format!("{}/{}", SYSTEM_DEV_PATH, value)); +#[derive(Debug, Clone)] +struct DevAddrMatcher { + dev_addr: String, +} + +impl DevAddrMatcher { + fn new(dev_addr: &str) -> DevAddrMatcher { + DevAddrMatcher { + dev_addr: dev_addr.to_string(), } } - drop(sb); +} - // If device is not found in the device map, hotplug event has not - // been received yet, create and add channel to the watchers map. - // The key of the watchers map is the device we are interested in. - // Note this is done inside the lock, not to miss any events from the - // global udev listener. - let (tx, rx) = tokio::sync::oneshot::channel::(); - w.insert(dev_addr.to_string(), Some(tx)); - drop(w); +impl UeventMatcher for DevAddrMatcher { + fn is_match(&self, uev: &Uevent) -> bool { + let pci_root_bus_path = create_pci_root_bus_path(); + let pci_p = format!("{}{}", pci_root_bus_path, self.dev_addr); + let pmem_suffix = format!("/{}/{}", SCSI_BLOCK_SUFFIX, uev.devname); - info!(sl!(), "Waiting on channel for device notification\n"); - let hotplug_timeout = AGENT_CONFIG.read().await.hotplug_timeout; + uev.subsystem == "block" + && { + uev.devpath.starts_with(pci_root_bus_path.as_str()) + || uev.devpath.starts_with(ACPI_DEV_PATH) // NVDIMM/PMEM devices + } + && !uev.devname.is_empty() + && { + // blk block device + uev.devpath.starts_with(pci_p.as_str()) + // scsi block device + || ( + self.dev_addr.ends_with(SCSI_BLOCK_SUFFIX) && + uev.devpath.contains(self.dev_addr.as_str()) + ) + // nvdimm/pmem device + || ( + uev.devpath.starts_with(ACPI_DEV_PATH) && + uev.devpath.ends_with(pmem_suffix.as_str()) && + self.dev_addr.ends_with(pmem_suffix.as_str()) + ) + } + } +} - let dev_name = match tokio::time::timeout(hotplug_timeout, rx).await { - Ok(v) => v?, - Err(_) => { - let watcher = GLOBAL_DEVICE_WATCHER.clone(); - let mut w = watcher.lock().await; - w.remove_entry(dev_addr); +async fn get_device_name(sandbox: &Arc>, dev_addr: &str) -> Result { + let matcher = DevAddrMatcher::new(dev_addr); - return Err(anyhow!( - "Timeout reached after {:?} waiting for device {}", - hotplug_timeout, - dev_addr - )); - } - }; + let uev = wait_for_uevent(sandbox, matcher).await?; - Ok(format!("{}/{}", SYSTEM_DEV_PATH, &dev_name)) + Ok(format!("{}/{}", SYSTEM_DEV_PATH, &uev.devname)) } pub async fn get_scsi_device_name( @@ -430,6 +439,7 @@ pub fn update_device_cgroup(spec: &mut Spec) -> Result<()> { #[cfg(test)] mod tests { use super::*; + use crate::uevent::spawn_test_watcher; use oci::Linux; use tempfile::tempdir; @@ -776,4 +786,39 @@ mod tests { let relpath = pcipath_to_sysfs(rootbuspath, &path234); assert_eq!(relpath.unwrap(), "/0000:00:02.0/0000:01:03.0/0000:02:04.0"); } + + #[tokio::test] + async fn test_get_device_name() { + let devname = "vda"; + let root_bus = create_pci_root_bus_path(); + let relpath = "/0000:00:0a.0/0000:03:0b.0"; + let devpath = format!("{}{}/virtio4/block/{}", root_bus, relpath, devname); + + let mut uev = crate::uevent::Uevent::default(); + uev.action = crate::linux_abi::U_EVENT_ACTION_ADD.to_string(); + uev.subsystem = "block".to_string(); + uev.devpath = devpath.clone(); + uev.devname = devname.to_string(); + + let logger = slog::Logger::root(slog::Discard, o!()); + let sandbox = Arc::new(Mutex::new(Sandbox::new(&logger).unwrap())); + + let mut sb = sandbox.lock().await; + sb.uevent_map.insert(devpath.clone(), uev); + drop(sb); // unlock + + let name = get_device_name(&sandbox, relpath).await; + assert!(name.is_ok(), "{}", name.unwrap_err()); + assert_eq!(name.unwrap(), format!("{}/{}", SYSTEM_DEV_PATH, devname)); + + let mut sb = sandbox.lock().await; + let uev = sb.uevent_map.remove(&devpath).unwrap(); + drop(sb); // unlock + + spawn_test_watcher(sandbox.clone(), uev); + + let name = get_device_name(&sandbox, relpath).await; + assert!(name.is_ok(), "{}", name.unwrap_err()); + assert_eq!(name.unwrap(), format!("{}/{}", SYSTEM_DEV_PATH, devname)); + } } diff --git a/src/agent/src/main.rs b/src/agent/src/main.rs index c86a03254a..d546ab8635 100644 --- a/src/agent/src/main.rs +++ b/src/agent/src/main.rs @@ -28,7 +28,6 @@ use nix::sys::select::{select, FdSet}; use nix::sys::socket::{self, AddressFamily, SockAddr, SockFlag, SockType}; use nix::sys::wait; use nix::unistd::{self, close, dup, dup2, fork, setsid, ForkResult}; -use std::collections::HashMap; use std::env; use std::ffi::{CStr, CString, OsStr}; use std::fs::{self, File}; @@ -72,7 +71,6 @@ use rustjail::pipestream::PipeStream; use tokio::{ io::AsyncWrite, sync::{ - oneshot::Sender, watch::{channel, Receiver}, Mutex, RwLock, }, @@ -89,8 +87,6 @@ const CONSOLE_PATH: &str = "/dev/console"; const DEFAULT_BUF_SIZE: usize = 8 * 1024; lazy_static! { - static ref GLOBAL_DEVICE_WATCHER: Arc>>>> = - Arc::new(Mutex::new(HashMap::new())); static ref AGENT_CONFIG: Arc> = Arc::new(RwLock::new(config::AgentConfig::new())); } diff --git a/src/agent/src/sandbox.rs b/src/agent/src/sandbox.rs index 9d5244ee39..b71fa5534a 100644 --- a/src/agent/src/sandbox.rs +++ b/src/agent/src/sandbox.rs @@ -8,6 +8,7 @@ use crate::mount::{get_mount_fs_type, remove_mounts, TYPE_ROOTFS}; use crate::namespace::Namespace; use crate::netlink::Handle; use crate::network::Network; +use crate::uevent::{Uevent, UeventMatcher}; use anyhow::{anyhow, Context, Result}; use libc::pid_t; use oci::{Hook, Hooks}; @@ -25,8 +26,11 @@ use std::path::Path; use std::sync::Arc; use std::{thread, time}; use tokio::sync::mpsc::{channel, Receiver, Sender}; +use tokio::sync::oneshot; use tokio::sync::Mutex; +type UeventWatcher = (Box, oneshot::Sender); + #[derive(Debug)] pub struct Sandbox { pub logger: Logger, @@ -36,7 +40,8 @@ pub struct Sandbox { pub network: Network, pub mounts: Vec, pub container_mounts: HashMap>, - pub pci_device_map: HashMap, + pub uevent_map: HashMap, + pub uevent_watchers: Vec>, pub shared_utsns: Namespace, pub shared_ipcns: Namespace, pub sandbox_pidns: Option, @@ -65,7 +70,8 @@ impl Sandbox { containers: HashMap::new(), mounts: Vec::new(), container_mounts: HashMap::new(), - pci_device_map: HashMap::new(), + uevent_map: HashMap::new(), + uevent_watchers: Vec::new(), shared_utsns: Namespace::new(&logger), shared_ipcns: Namespace::new(&logger), sandbox_pidns: None, diff --git a/src/agent/src/uevent.rs b/src/agent/src/uevent.rs index c4067ed5ab..3d8e12b0d5 100644 --- a/src/agent/src/uevent.rs +++ b/src/agent/src/uevent.rs @@ -6,26 +6,38 @@ use crate::device::online_device; use crate::linux_abi::*; use crate::sandbox::Sandbox; -use crate::GLOBAL_DEVICE_WATCHER; +use crate::AGENT_CONFIG; use slog::Logger; -use anyhow::Result; +use anyhow::{anyhow, Result}; use netlink_sys::{protocols, SocketAddr, TokioSocket}; use nix::errno::Errno; +use std::fmt::Debug; use std::os::unix::io::FromRawFd; use std::sync::Arc; use tokio::select; use tokio::sync::watch::Receiver; use tokio::sync::Mutex; -#[derive(Debug, Default)] -struct Uevent { - action: String, - devpath: String, - devname: String, - subsystem: String, +// Convenience macro to obtain the scope logger +macro_rules! sl { + () => { + slog_scope::logger().new(o!("subsystem" => "uevent")) + }; +} + +#[derive(Debug, Default, Clone, PartialEq, Eq)] +pub struct Uevent { + pub action: String, + pub devpath: String, + pub devname: String, + pub subsystem: String, seqnum: String, - interface: String, + pub interface: String, +} + +pub trait UeventMatcher: Sync + Send + Debug + 'static { + fn is_match(&self, uev: &Uevent) -> bool; } impl Uevent { @@ -52,89 +64,87 @@ impl Uevent { event } - // Check whether this is a block device hot-add event. - fn is_block_add_event(&self) -> bool { - let pci_root_bus_path = create_pci_root_bus_path(); - self.action == U_EVENT_ACTION_ADD - && self.subsystem == "block" - && { - self.devpath.starts_with(pci_root_bus_path.as_str()) - || self.devpath.starts_with(ACPI_DEV_PATH) // NVDIMM/PMEM devices - } - && !self.devname.is_empty() - } + async fn process_add(&self, logger: &Logger, sandbox: &Arc>) { + // Special case for memory hot-adds first + let online_path = format!("{}/{}/online", SYSFS_DIR, &self.devpath); + if online_path.starts_with(SYSFS_MEMORY_ONLINE_PATH) { + let _ = online_device(online_path.as_ref()).map_err(|e| { + error!( + *logger, + "failed to online device"; + "device" => &self.devpath, + "error" => format!("{}", e), + ) + }); + return; + } - async fn handle_block_add_event(&self, sandbox: &Arc>) { - let pci_root_bus_path = create_pci_root_bus_path(); - - // Keep the same lock order as device::get_device_name(), otherwise it may cause deadlock. - let watcher = GLOBAL_DEVICE_WATCHER.clone(); - let mut w = watcher.lock().await; let mut sb = sandbox.lock().await; - // Add the device node name to the pci device map. - sb.pci_device_map - .insert(self.devpath.clone(), self.devname.clone()); + // Record the event by sysfs path + sb.uevent_map.insert(self.devpath.clone(), self.clone()); // Notify watchers that are interested in the udev event. - // Close the channel after watcher has been notified. - let devpath = self.devpath.clone(); - let empties: Vec<_> = w - .iter_mut() - .filter(|(dev_addr, _)| { - let pci_p = format!("{}/{}", pci_root_bus_path, *dev_addr); - - // blk block device - devpath.starts_with(pci_p.as_str()) || - // scsi block device - { - (*dev_addr).ends_with(SCSI_BLOCK_SUFFIX) && - devpath.contains(*dev_addr) - } || - // nvdimm/pmem device - { - let pmem_suffix = format!("/{}/{}", SCSI_BLOCK_SUFFIX, self.devname); - devpath.starts_with(ACPI_DEV_PATH) && - devpath.ends_with(pmem_suffix.as_str()) && - dev_addr.ends_with(pmem_suffix.as_str()) + for watch in &mut sb.uevent_watchers { + if let Some((matcher, _)) = watch { + if matcher.is_match(&self) { + let (_, sender) = watch.take().unwrap(); + let _ = sender.send(self.clone()); } - }) - .map(|(k, sender)| { - let devname = self.devname.clone(); - let sender = sender.take().unwrap(); - let _ = sender.send(devname); - k.clone() - }) - .collect(); - - // Remove notified nodes from the watcher map. - for empty in empties { - w.remove(&empty); + } } } async fn process(&self, logger: &Logger, sandbox: &Arc>) { - if self.is_block_add_event() { - return self.handle_block_add_event(sandbox).await; - } else if self.action == U_EVENT_ACTION_ADD { - let online_path = format!("{}/{}/online", SYSFS_DIR, &self.devpath); - // It's a memory hot-add event. - if online_path.starts_with(SYSFS_MEMORY_ONLINE_PATH) { - let _ = online_device(online_path.as_ref()).map_err(|e| { - error!( - *logger, - "failed to online device"; - "device" => &self.devpath, - "error" => format!("{}", e), - ) - }); - return; - } + if self.action == U_EVENT_ACTION_ADD { + return self.process_add(logger, sandbox).await; } debug!(*logger, "ignoring event"; "uevent" => format!("{:?}", self)); } } +pub async fn wait_for_uevent( + sandbox: &Arc>, + matcher: impl UeventMatcher, +) -> Result { + let mut sb = sandbox.lock().await; + for uev in sb.uevent_map.values() { + if matcher.is_match(uev) { + info!(sl!(), "Device {:?} found in pci device map", uev); + return Ok(uev.clone()); + } + } + + // If device is not found in the device map, hotplug event has not + // been received yet, create and add channel to the watchers map. + // The key of the watchers map is the device we are interested in. + // Note this is done inside the lock, not to miss any events from the + // global udev listener. + let (tx, rx) = tokio::sync::oneshot::channel::(); + let idx = sb.uevent_watchers.len(); + sb.uevent_watchers.push(Some((Box::new(matcher), tx))); + drop(sb); // unlock + + info!(sl!(), "Waiting on channel for uevent notification\n"); + let hotplug_timeout = AGENT_CONFIG.read().await.hotplug_timeout; + + let uev = match tokio::time::timeout(hotplug_timeout, rx).await { + Ok(v) => v?, + Err(_) => { + let mut sb = sandbox.lock().await; + let matcher = sb.uevent_watchers[idx].take().unwrap().0; + + return Err(anyhow!( + "Timeout after {:?} waiting for uevent {:?}", + hotplug_timeout, + &matcher + )); + } + }; + + Ok(uev) +} + pub async fn watch_uevents( sandbox: Arc>, mut shutdown: Receiver, @@ -199,3 +209,69 @@ pub async fn watch_uevents( Ok(()) } + +// Used in the device module unit tests +#[cfg(test)] +pub(crate) fn spawn_test_watcher(sandbox: Arc>, uev: Uevent) { + tokio::spawn(async move { + loop { + let mut sb = sandbox.lock().await; + for w in &mut sb.uevent_watchers { + if let Some((matcher, _)) = w { + if matcher.is_match(&uev) { + let (_, sender) = w.take().unwrap(); + let _ = sender.send(uev); + return; + } + } + } + drop(sb); // unlock + } + }); +} + +#[cfg(test)] +mod tests { + use super::*; + + #[derive(Debug, Clone, Copy)] + struct AlwaysMatch(); + + impl UeventMatcher for AlwaysMatch { + fn is_match(&self, _: &Uevent) -> bool { + true + } + } + + #[tokio::test] + async fn test_wait_for_uevent() { + let mut uev = crate::uevent::Uevent::default(); + uev.action = crate::linux_abi::U_EVENT_ACTION_ADD.to_string(); + uev.subsystem = "test".to_string(); + uev.devpath = "/test/sysfs/path".to_string(); + uev.devname = "testdevname".to_string(); + + let matcher = AlwaysMatch(); + + let logger = slog::Logger::root(slog::Discard, o!()); + let sandbox = Arc::new(Mutex::new(Sandbox::new(&logger).unwrap())); + + let mut sb = sandbox.lock().await; + sb.uevent_map.insert(uev.devpath.clone(), uev.clone()); + drop(sb); // unlock + + let uev2 = wait_for_uevent(&sandbox, matcher).await; + assert!(uev2.is_ok()); + assert_eq!(uev2.unwrap(), uev); + + let mut sb = sandbox.lock().await; + sb.uevent_map.remove(&uev.devpath).unwrap(); + drop(sb); // unlock + + spawn_test_watcher(sandbox.clone(), uev.clone()); + + let uev2 = wait_for_uevent(&sandbox, matcher).await; + assert!(uev2.is_ok()); + assert_eq!(uev2.unwrap(), uev); + } +}