diff --git a/src/agent/src/device.rs b/src/agent/src/device.rs index bbb70bc65a..750aa46691 100644 --- a/src/agent/src/device.rs +++ b/src/agent/src/device.rs @@ -17,8 +17,7 @@ use crate::linux_abi::*; use crate::mount::{DRIVER_BLK_TYPE, DRIVER_MMIO_BLK_TYPE, DRIVER_NVDIMM_TYPE, DRIVER_SCSI_TYPE}; use crate::pci; use crate::sandbox::Sandbox; -use crate::uevent::{Uevent, UeventMatcher}; -use crate::AGENT_CONFIG; +use crate::uevent::{wait_for_uevent, Uevent, UeventMatcher}; use anyhow::{anyhow, Result}; use oci::{LinuxDeviceCgroup, LinuxResources, Spec}; use protocols::agent::Device; @@ -88,7 +87,7 @@ fn pcipath_to_sysfs(root_bus_sysfs: &str, pcipath: &pci::Path) -> Result Ok(relpath) } -#[derive(Debug)] +#[derive(Debug, Clone)] struct DevAddrMatcher { dev_addr: String, } @@ -134,40 +133,7 @@ impl UeventMatcher for DevAddrMatcher { async fn get_device_name(sandbox: &Arc>, dev_addr: &str) -> Result { let matcher = DevAddrMatcher::new(dev_addr); - let mut sb = sandbox.lock().await; - for uev in sb.uevent_map.values() { - if matcher.is_match(uev) { - info!(sl!(), "Device {} found in pci device map", dev_addr); - return Ok(format!("{}/{}", SYSTEM_DEV_PATH, uev.devname)); - } - } - - // If device is not found in the device map, hotplug event has not - // been received yet, create and add channel to the watchers map. - // The key of the watchers map is the device we are interested in. - // Note this is done inside the lock, not to miss any events from the - // global udev listener. - let (tx, rx) = tokio::sync::oneshot::channel::(); - let idx = sb.uevent_watchers.len(); - sb.uevent_watchers.push(Some((Box::new(matcher), tx))); - drop(sb); // unlock - - info!(sl!(), "Waiting on channel for device notification\n"); - let hotplug_timeout = AGENT_CONFIG.read().await.hotplug_timeout; - - let uev = match tokio::time::timeout(hotplug_timeout, rx).await { - Ok(v) => v?, - Err(_) => { - let mut sb = sandbox.lock().await; - sb.uevent_watchers[idx].take(); - - return Err(anyhow!( - "Timeout reached after {:?} waiting for device {}", - hotplug_timeout, - dev_addr - )); - } - }; + let uev = wait_for_uevent(sandbox, matcher).await?; Ok(format!("{}/{}", SYSTEM_DEV_PATH, &uev.devname)) } @@ -473,6 +439,7 @@ pub fn update_device_cgroup(spec: &mut Spec) -> Result<()> { #[cfg(test)] mod tests { use super::*; + use crate::uevent::spawn_test_watcher; use oci::Linux; use tempfile::tempdir; @@ -848,22 +815,7 @@ mod tests { let uev = sb.uevent_map.remove(&devpath).unwrap(); drop(sb); // unlock - let watcher_sandbox = Arc::clone(&sandbox); - tokio::spawn(async move { - loop { - let mut sb = watcher_sandbox.lock().await; - for w in &mut sb.uevent_watchers { - if let Some((matcher, _)) = w { - if matcher.is_match(&uev) { - let (_, sender) = w.take().unwrap(); - let _ = sender.send(uev); - return; - } - } - } - drop(sb); // unlock - } - }); + spawn_test_watcher(sandbox.clone(), uev); let name = get_device_name(&sandbox, relpath).await; assert!(name.is_ok(), "{}", name.unwrap_err()); diff --git a/src/agent/src/uevent.rs b/src/agent/src/uevent.rs index 4d1fb8db70..3d8e12b0d5 100644 --- a/src/agent/src/uevent.rs +++ b/src/agent/src/uevent.rs @@ -6,9 +6,10 @@ use crate::device::online_device; use crate::linux_abi::*; use crate::sandbox::Sandbox; +use crate::AGENT_CONFIG; use slog::Logger; -use anyhow::Result; +use anyhow::{anyhow, Result}; use netlink_sys::{protocols, SocketAddr, TokioSocket}; use nix::errno::Errno; use std::fmt::Debug; @@ -18,7 +19,14 @@ use tokio::select; use tokio::sync::watch::Receiver; use tokio::sync::Mutex; -#[derive(Debug, Default, Clone)] +// Convenience macro to obtain the scope logger +macro_rules! sl { + () => { + slog_scope::logger().new(o!("subsystem" => "uevent")) + }; +} + +#[derive(Debug, Default, Clone, PartialEq, Eq)] pub struct Uevent { pub action: String, pub devpath: String, @@ -95,6 +103,48 @@ impl Uevent { } } +pub async fn wait_for_uevent( + sandbox: &Arc>, + matcher: impl UeventMatcher, +) -> Result { + let mut sb = sandbox.lock().await; + for uev in sb.uevent_map.values() { + if matcher.is_match(uev) { + info!(sl!(), "Device {:?} found in pci device map", uev); + return Ok(uev.clone()); + } + } + + // If device is not found in the device map, hotplug event has not + // been received yet, create and add channel to the watchers map. + // The key of the watchers map is the device we are interested in. + // Note this is done inside the lock, not to miss any events from the + // global udev listener. + let (tx, rx) = tokio::sync::oneshot::channel::(); + let idx = sb.uevent_watchers.len(); + sb.uevent_watchers.push(Some((Box::new(matcher), tx))); + drop(sb); // unlock + + info!(sl!(), "Waiting on channel for uevent notification\n"); + let hotplug_timeout = AGENT_CONFIG.read().await.hotplug_timeout; + + let uev = match tokio::time::timeout(hotplug_timeout, rx).await { + Ok(v) => v?, + Err(_) => { + let mut sb = sandbox.lock().await; + let matcher = sb.uevent_watchers[idx].take().unwrap().0; + + return Err(anyhow!( + "Timeout after {:?} waiting for uevent {:?}", + hotplug_timeout, + &matcher + )); + } + }; + + Ok(uev) +} + pub async fn watch_uevents( sandbox: Arc>, mut shutdown: Receiver, @@ -159,3 +209,69 @@ pub async fn watch_uevents( Ok(()) } + +// Used in the device module unit tests +#[cfg(test)] +pub(crate) fn spawn_test_watcher(sandbox: Arc>, uev: Uevent) { + tokio::spawn(async move { + loop { + let mut sb = sandbox.lock().await; + for w in &mut sb.uevent_watchers { + if let Some((matcher, _)) = w { + if matcher.is_match(&uev) { + let (_, sender) = w.take().unwrap(); + let _ = sender.send(uev); + return; + } + } + } + drop(sb); // unlock + } + }); +} + +#[cfg(test)] +mod tests { + use super::*; + + #[derive(Debug, Clone, Copy)] + struct AlwaysMatch(); + + impl UeventMatcher for AlwaysMatch { + fn is_match(&self, _: &Uevent) -> bool { + true + } + } + + #[tokio::test] + async fn test_wait_for_uevent() { + let mut uev = crate::uevent::Uevent::default(); + uev.action = crate::linux_abi::U_EVENT_ACTION_ADD.to_string(); + uev.subsystem = "test".to_string(); + uev.devpath = "/test/sysfs/path".to_string(); + uev.devname = "testdevname".to_string(); + + let matcher = AlwaysMatch(); + + let logger = slog::Logger::root(slog::Discard, o!()); + let sandbox = Arc::new(Mutex::new(Sandbox::new(&logger).unwrap())); + + let mut sb = sandbox.lock().await; + sb.uevent_map.insert(uev.devpath.clone(), uev.clone()); + drop(sb); // unlock + + let uev2 = wait_for_uevent(&sandbox, matcher).await; + assert!(uev2.is_ok()); + assert_eq!(uev2.unwrap(), uev); + + let mut sb = sandbox.lock().await; + sb.uevent_map.remove(&uev.devpath).unwrap(); + drop(sb); // unlock + + spawn_test_watcher(sandbox.clone(), uev.clone()); + + let uev2 = wait_for_uevent(&sandbox, matcher).await; + assert!(uev2.is_ok()); + assert_eq!(uev2.unwrap(), uev); + } +}