From e3e670c56f82412250ef7978b6634f43beb34977 Mon Sep 17 00:00:00 2001 From: David Gibson Date: Fri, 12 Feb 2021 15:13:53 +1100 Subject: [PATCH 01/12] agent/device: Forward port test for get_device_name() from Kata 1.x Kata 1.x had a testcase for the equivalent getDeviceName function in Go, this adapts it to Rust and adds it to Kata 2.x. Signed-off-by: David Gibson --- src/agent/src/device.rs | 45 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/src/agent/src/device.rs b/src/agent/src/device.rs index 16d1ecc5be..e08f91cc43 100644 --- a/src/agent/src/device.rs +++ b/src/agent/src/device.rs @@ -776,4 +776,49 @@ mod tests { let relpath = pcipath_to_sysfs(rootbuspath, &path234); assert_eq!(relpath.unwrap(), "/0000:00:02.0/0000:01:03.0/0000:02:04.0"); } + + #[tokio::test] + async fn test_get_device_name() { + let devname = "vda"; + let busid = "0.0.0005"; + let devpath = format!("/dev/ces/css0/0.0.0004/{}/virtio4/block/{}", busid, devname); + + let logger = slog::Logger::root(slog::Discard, o!()); + let sandbox = Arc::new(Mutex::new(Sandbox::new(&logger).unwrap())); + + let mut sb = sandbox.lock().await; + sb.pci_device_map + .insert(devpath.clone(), devname.to_string()); + drop(sb); // unlock + + let name = get_device_name(&sandbox, busid).await; + assert!(name.is_ok(), "{}", name.unwrap_err()); + assert_eq!(name.unwrap(), format!("{}/{}", SYSTEM_DEV_PATH, devname)); + + let mut sb = sandbox.lock().await; + sb.pci_device_map.remove(&devpath); + drop(sb); // unlock + + tokio::spawn(async move { + loop { + let mut w = GLOBAL_DEVICE_WATCHER.lock().await; + let matched_key = w + .keys() + .filter(|dev_addr| devpath.contains(*dev_addr)) + .cloned() + .next(); + + if let Some(k) = matched_key { + let sender = w.remove(&k).unwrap().unwrap(); + let _ = sender.send(devname.to_string()); + return; + } + drop(w); // unlock + } + }); + + let name = get_device_name(&sandbox, busid).await; + assert!(name.is_ok(), "{}", name.unwrap_err()); + 
assert_eq!(name.unwrap(), format!("{}/{}", SYSTEM_DEV_PATH, devname)); + } } From 4f60880414f8d9a8fb682d680bd8a7ba30910bbe Mon Sep 17 00:00:00 2001 From: David Gibson Date: Tue, 6 Apr 2021 20:49:48 +1000 Subject: [PATCH 02/12] agent/device: Update test_get_device_name() The current test_get_device_name(), ported from Kata 1.x doesn't really reflect how the function is used in practice. The example path appears to be for a virtio-blk device, but it's an s390 specific variant, not a PCI device. The s390 form isn't actually supported by any of the existing users of get_device_name(). Change it to a plausible virtio-blk-pci style path to better test how get_device_name() will actually be used in practice. Signed-off-by: David Gibson --- src/agent/src/device.rs | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/agent/src/device.rs b/src/agent/src/device.rs index e08f91cc43..8ce34a64d6 100644 --- a/src/agent/src/device.rs +++ b/src/agent/src/device.rs @@ -780,8 +780,9 @@ mod tests { #[tokio::test] async fn test_get_device_name() { let devname = "vda"; - let busid = "0.0.0005"; - let devpath = format!("/dev/ces/css0/0.0.0004/{}/virtio4/block/{}", busid, devname); + let root_bus = create_pci_root_bus_path(); + let relpath = "/0000:00:0a.0/0000:03:0b.0"; + let devpath = format!("{}{}/virtio4/block/{}", root_bus, relpath, devname); let logger = slog::Logger::root(slog::Discard, o!()); let sandbox = Arc::new(Mutex::new(Sandbox::new(&logger).unwrap())); @@ -791,7 +792,7 @@ mod tests { .insert(devpath.clone(), devname.to_string()); drop(sb); // unlock - let name = get_device_name(&sandbox, busid).await; + let name = get_device_name(&sandbox, relpath).await; assert!(name.is_ok(), "{}", name.unwrap_err()); assert_eq!(name.unwrap(), format!("{}/{}", SYSTEM_DEV_PATH, devname)); @@ -817,7 +818,7 @@ mod tests { } }); - let name = get_device_name(&sandbox, busid).await; + let name = get_device_name(&sandbox, relpath).await; assert!(name.is_ok(), "{}", 
name.unwrap_err()); assert_eq!(name.unwrap(), format!("{}/{}", SYSTEM_DEV_PATH, devname)); } From 11ae32e3c04906b38faf9a263a2bec3b30d7c75a Mon Sep 17 00:00:00 2001 From: David Gibson Date: Tue, 6 Apr 2021 20:58:06 +1000 Subject: [PATCH 03/12] agent/device: Fix path matching for PCI devices For the case of virtio-blk PCI devices, when matching uevents we create a pci_p temporary. However, we build it incorrectly: the dev_addr values we use for PCI devices are relative sysfs paths from the PCI root to the device in question *including an initial /*. But when we construct pci_p we add an extra /, meaning the resulting path will *not* match properly. AFAICT the only reason we got away with this is because in practice the virtio-blk devices were discovered by the kernel before we looked for them, meaning the looser matching in get_device_name() was used, rather than the pci_p logic in handle_block_add_event(). Signed-off-by: David Gibson --- src/agent/src/uevent.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/agent/src/uevent.rs b/src/agent/src/uevent.rs index c4067ed5ab..b534282315 100644 --- a/src/agent/src/uevent.rs +++ b/src/agent/src/uevent.rs @@ -82,7 +82,7 @@ impl Uevent { let empties: Vec<_> = w .iter_mut() .filter(|(dev_addr, _)| { - let pci_p = format!("{}/{}", pci_root_bus_path, *dev_addr); + let pci_p = format!("{}{}", pci_root_bus_path, *dev_addr); // blk block device devpath.starts_with(pci_p.as_str()) || From 0616202580f4323bcbc9ecd6d3bceb45191d32b8 Mon Sep 17 00:00:00 2001 From: David Gibson Date: Thu, 25 Feb 2021 11:32:50 +1100 Subject: [PATCH 04/12] agent/device: Move GLOBAL_DEVICE_WATCHER into Sandbox In Kata 1.x, both the sysToDevMap and the deviceWatchers are in the sandbox structure. For some reason in Kata 2.x, the device watchers have moved to a separate global variable, GLOBAL_DEVICE_WATCHER. 
This is a bad idea: apart from introducing an extra global variable unnecessarily, it means that Sandbox::pci_device_map and GLOBAL_DEVICE_WATCHER are protected by separate mutexes. Since the information in these two structures has to be kept in sync with each other, it makes much more sense to keep them both under the same single Sandbox mutex. Signed-off-by: David Gibson --- src/agent/src/device.rs | 26 ++++++++++++-------------- src/agent/src/main.rs | 4 ---- src/agent/src/sandbox.rs | 2 ++ src/agent/src/uevent.rs | 29 ++++++++++++----------------- 4 files changed, 26 insertions(+), 35 deletions(-) diff --git a/src/agent/src/device.rs b/src/agent/src/device.rs index 8ce34a64d6..9d3180b0c5 100644 --- a/src/agent/src/device.rs +++ b/src/agent/src/device.rs @@ -17,7 +17,7 @@ use crate::linux_abi::*; use crate::mount::{DRIVER_BLK_TYPE, DRIVER_MMIO_BLK_TYPE, DRIVER_NVDIMM_TYPE, DRIVER_SCSI_TYPE}; use crate::pci; use crate::sandbox::Sandbox; -use crate::{AGENT_CONFIG, GLOBAL_DEVICE_WATCHER}; +use crate::AGENT_CONFIG; use anyhow::{anyhow, Result}; use oci::{LinuxDeviceCgroup, LinuxResources, Spec}; use protocols::agent::Device; @@ -88,16 +88,13 @@ fn pcipath_to_sysfs(root_bus_sysfs: &str, pcipath: &pci::Path) -> Result } async fn get_device_name(sandbox: &Arc>, dev_addr: &str) -> Result { - // Keep the same lock order as uevent::handle_block_add_event(), otherwise it may cause deadlock. - let mut w = GLOBAL_DEVICE_WATCHER.lock().await; - let sb = sandbox.lock().await; + let mut sb = sandbox.lock().await; for (key, value) in sb.pci_device_map.iter() { if key.contains(dev_addr) { info!(sl!(), "Device {} found in pci device map", dev_addr); return Ok(format!("{}/{}", SYSTEM_DEV_PATH, value)); } } - drop(sb); // If device is not found in the device map, hotplug event has not // been received yet, create and add channel to the watchers map. 
@@ -105,8 +102,8 @@ async fn get_device_name(sandbox: &Arc>, dev_addr: &str) -> Resul // Note this is done inside the lock, not to miss any events from the // global udev listener. let (tx, rx) = tokio::sync::oneshot::channel::(); - w.insert(dev_addr.to_string(), Some(tx)); - drop(w); + sb.dev_watcher.insert(dev_addr.to_string(), tx); + drop(sb); // unlock info!(sl!(), "Waiting on channel for device notification\n"); let hotplug_timeout = AGENT_CONFIG.read().await.hotplug_timeout; @@ -114,9 +111,8 @@ async fn get_device_name(sandbox: &Arc>, dev_addr: &str) -> Resul let dev_name = match tokio::time::timeout(hotplug_timeout, rx).await { Ok(v) => v?, Err(_) => { - let watcher = GLOBAL_DEVICE_WATCHER.clone(); - let mut w = watcher.lock().await; - w.remove_entry(dev_addr); + let mut sb = sandbox.lock().await; + sb.dev_watcher.remove_entry(dev_addr); return Err(anyhow!( "Timeout reached after {:?} waiting for device {}", @@ -800,21 +796,23 @@ mod tests { sb.pci_device_map.remove(&devpath); drop(sb); // unlock + let watcher_sandbox = Arc::clone(&sandbox); tokio::spawn(async move { loop { - let mut w = GLOBAL_DEVICE_WATCHER.lock().await; - let matched_key = w + let mut sb = watcher_sandbox.lock().await; + let matched_key = sb + .dev_watcher .keys() .filter(|dev_addr| devpath.contains(*dev_addr)) .cloned() .next(); if let Some(k) = matched_key { - let sender = w.remove(&k).unwrap().unwrap(); + let sender = sb.dev_watcher.remove(&k).unwrap(); let _ = sender.send(devname.to_string()); return; } - drop(w); // unlock + drop(sb); // unlock } }); diff --git a/src/agent/src/main.rs b/src/agent/src/main.rs index c86a03254a..d546ab8635 100644 --- a/src/agent/src/main.rs +++ b/src/agent/src/main.rs @@ -28,7 +28,6 @@ use nix::sys::select::{select, FdSet}; use nix::sys::socket::{self, AddressFamily, SockAddr, SockFlag, SockType}; use nix::sys::wait; use nix::unistd::{self, close, dup, dup2, fork, setsid, ForkResult}; -use std::collections::HashMap; use std::env; use std::ffi::{CStr, 
CString, OsStr}; use std::fs::{self, File}; @@ -72,7 +71,6 @@ use rustjail::pipestream::PipeStream; use tokio::{ io::AsyncWrite, sync::{ - oneshot::Sender, watch::{channel, Receiver}, Mutex, RwLock, }, @@ -89,8 +87,6 @@ const CONSOLE_PATH: &str = "/dev/console"; const DEFAULT_BUF_SIZE: usize = 8 * 1024; lazy_static! { - static ref GLOBAL_DEVICE_WATCHER: Arc>>>> = - Arc::new(Mutex::new(HashMap::new())); static ref AGENT_CONFIG: Arc> = Arc::new(RwLock::new(config::AgentConfig::new())); } diff --git a/src/agent/src/sandbox.rs b/src/agent/src/sandbox.rs index 9d5244ee39..4ec3b3d6ee 100644 --- a/src/agent/src/sandbox.rs +++ b/src/agent/src/sandbox.rs @@ -37,6 +37,7 @@ pub struct Sandbox { pub mounts: Vec, pub container_mounts: HashMap>, pub pci_device_map: HashMap, + pub dev_watcher: HashMap>, pub shared_utsns: Namespace, pub shared_ipcns: Namespace, pub sandbox_pidns: Option, @@ -66,6 +67,7 @@ impl Sandbox { mounts: Vec::new(), container_mounts: HashMap::new(), pci_device_map: HashMap::new(), + dev_watcher: HashMap::new(), shared_utsns: Namespace::new(&logger), shared_ipcns: Namespace::new(&logger), sandbox_pidns: None, diff --git a/src/agent/src/uevent.rs b/src/agent/src/uevent.rs index b534282315..1c52863b3a 100644 --- a/src/agent/src/uevent.rs +++ b/src/agent/src/uevent.rs @@ -6,7 +6,6 @@ use crate::device::online_device; use crate::linux_abi::*; use crate::sandbox::Sandbox; -use crate::GLOBAL_DEVICE_WATCHER; use slog::Logger; use anyhow::Result; @@ -66,10 +65,6 @@ impl Uevent { async fn handle_block_add_event(&self, sandbox: &Arc>) { let pci_root_bus_path = create_pci_root_bus_path(); - - // Keep the same lock order as device::get_device_name(), otherwise it may cause deadlock. - let watcher = GLOBAL_DEVICE_WATCHER.clone(); - let mut w = watcher.lock().await; let mut sb = sandbox.lock().await; // Add the device node name to the pci device map. @@ -79,9 +74,10 @@ impl Uevent { // Notify watchers that are interested in the udev event. 
// Close the channel after watcher has been notified. let devpath = self.devpath.clone(); - let empties: Vec<_> = w - .iter_mut() - .filter(|(dev_addr, _)| { + let keys: Vec<_> = sb + .dev_watcher + .keys() + .filter(|dev_addr| { let pci_p = format!("{}{}", pci_root_bus_path, *dev_addr); // blk block device @@ -99,17 +95,16 @@ impl Uevent { dev_addr.ends_with(pmem_suffix.as_str()) } }) - .map(|(k, sender)| { - let devname = self.devname.clone(); - let sender = sender.take().unwrap(); - let _ = sender.send(devname); - k.clone() - }) + .cloned() .collect(); - // Remove notified nodes from the watcher map. - for empty in empties { - w.remove(&empty); + for k in keys { + let devname = self.devname.clone(); + // unwrap() is safe because logic above ensures k exists + // in the map, and it's locked so no-one else can change + // that + let sender = sb.dev_watcher.remove(&k).unwrap(); + let _ = sender.send(devname); } } From 3642005479025b0342f596b784a0843368438497 Mon Sep 17 00:00:00 2001 From: David Gibson Date: Wed, 3 Mar 2021 16:41:35 +1100 Subject: [PATCH 05/12] agent: Store whole Uevent in map, rather than just /dev name Sandbox::pci_device_map contains a mapping from sysfs paths to /dev entries which is used by get_device_name() to look up the right /dev node. But, the map only supplies the answer if the uevent for the device has already been received, otherwise get_device_name() has to wait for it. However the matching for already-received and yet-to-come uevents isn't quite the same which makes the whole system fragile. In order to make sure the matching for both cases is identical, we need the already-received side to store the whole uevent to match against, not just the sysfs path and device name. So, rename pci_device_map to uevent_map and store the whole uevent there verbatim. 
Signed-off-by: David Gibson --- src/agent/src/device.rs | 13 ++++++++----- src/agent/src/sandbox.rs | 5 +++-- src/agent/src/uevent.rs | 19 +++++++++---------- 3 files changed, 20 insertions(+), 17 deletions(-) diff --git a/src/agent/src/device.rs b/src/agent/src/device.rs index 9d3180b0c5..d6117839ce 100644 --- a/src/agent/src/device.rs +++ b/src/agent/src/device.rs @@ -89,10 +89,10 @@ fn pcipath_to_sysfs(root_bus_sysfs: &str, pcipath: &pci::Path) -> Result async fn get_device_name(sandbox: &Arc>, dev_addr: &str) -> Result { let mut sb = sandbox.lock().await; - for (key, value) in sb.pci_device_map.iter() { + for (key, uev) in sb.uevent_map.iter() { if key.contains(dev_addr) { info!(sl!(), "Device {} found in pci device map", dev_addr); - return Ok(format!("{}/{}", SYSTEM_DEV_PATH, value)); + return Ok(format!("{}/{}", SYSTEM_DEV_PATH, uev.devname)); } } @@ -780,12 +780,15 @@ mod tests { let relpath = "/0000:00:0a.0/0000:03:0b.0"; let devpath = format!("{}{}/virtio4/block/{}", root_bus, relpath, devname); + let mut uev = crate::uevent::Uevent::default(); + uev.devpath = devpath.clone(); + uev.devname = devname.to_string(); + let logger = slog::Logger::root(slog::Discard, o!()); let sandbox = Arc::new(Mutex::new(Sandbox::new(&logger).unwrap())); let mut sb = sandbox.lock().await; - sb.pci_device_map - .insert(devpath.clone(), devname.to_string()); + sb.uevent_map.insert(devpath.clone(), uev); drop(sb); // unlock let name = get_device_name(&sandbox, relpath).await; @@ -793,7 +796,7 @@ mod tests { assert_eq!(name.unwrap(), format!("{}/{}", SYSTEM_DEV_PATH, devname)); let mut sb = sandbox.lock().await; - sb.pci_device_map.remove(&devpath); + sb.uevent_map.remove(&devpath); drop(sb); // unlock let watcher_sandbox = Arc::clone(&sandbox); diff --git a/src/agent/src/sandbox.rs b/src/agent/src/sandbox.rs index 4ec3b3d6ee..8b1c4561ca 100644 --- a/src/agent/src/sandbox.rs +++ b/src/agent/src/sandbox.rs @@ -8,6 +8,7 @@ use crate::mount::{get_mount_fs_type, remove_mounts, 
TYPE_ROOTFS}; use crate::namespace::Namespace; use crate::netlink::Handle; use crate::network::Network; +use crate::uevent::Uevent; use anyhow::{anyhow, Context, Result}; use libc::pid_t; use oci::{Hook, Hooks}; @@ -36,7 +37,7 @@ pub struct Sandbox { pub network: Network, pub mounts: Vec, pub container_mounts: HashMap>, - pub pci_device_map: HashMap, + pub uevent_map: HashMap, pub dev_watcher: HashMap>, pub shared_utsns: Namespace, pub shared_ipcns: Namespace, @@ -66,7 +67,7 @@ impl Sandbox { containers: HashMap::new(), mounts: Vec::new(), container_mounts: HashMap::new(), - pci_device_map: HashMap::new(), + uevent_map: HashMap::new(), dev_watcher: HashMap::new(), shared_utsns: Namespace::new(&logger), shared_ipcns: Namespace::new(&logger), diff --git a/src/agent/src/uevent.rs b/src/agent/src/uevent.rs index 1c52863b3a..78bc400abd 100644 --- a/src/agent/src/uevent.rs +++ b/src/agent/src/uevent.rs @@ -17,14 +17,14 @@ use tokio::select; use tokio::sync::watch::Receiver; use tokio::sync::Mutex; -#[derive(Debug, Default)] -struct Uevent { - action: String, - devpath: String, - devname: String, - subsystem: String, +#[derive(Debug, Default, Clone)] +pub struct Uevent { + pub action: String, + pub devpath: String, + pub devname: String, + pub subsystem: String, seqnum: String, - interface: String, + pub interface: String, } impl Uevent { @@ -67,9 +67,8 @@ impl Uevent { let pci_root_bus_path = create_pci_root_bus_path(); let mut sb = sandbox.lock().await; - // Add the device node name to the pci device map. - sb.pci_device_map - .insert(self.devpath.clone(), self.devname.clone()); + // Record the event by sysfs path + sb.uevent_map.insert(self.devpath.clone(), self.clone()); // Notify watchers that are interested in the udev event. // Close the channel after watcher has been notified. 
From 91e0ef5c908a9db34df1460bf9814b969c5b0059 Mon Sep 17 00:00:00 2001 From: David Gibson Date: Fri, 5 Mar 2021 16:07:42 +1100 Subject: [PATCH 06/12] agent/uevent: Report whole Uevents to device watchers Currently, when Uevent::handle_block_add_event() receives an event matching a registered watcher, it reports the /dev node name from the event back to the watcher. This changes it to report the entire uevent, not just the /dev node name. This will allow various future extensions. It also makes the client side of the uevent watching - get_device_name() - more consistent between its two paths: finding a past uevent in Sandbox::uevent_map() or waiting for a new uevent via a watcher. Signed-off-by: David Gibson --- src/agent/src/device.rs | 11 ++++++----- src/agent/src/sandbox.rs | 2 +- src/agent/src/uevent.rs | 3 +-- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/agent/src/device.rs b/src/agent/src/device.rs index d6117839ce..8d756a9d0b 100644 --- a/src/agent/src/device.rs +++ b/src/agent/src/device.rs @@ -17,6 +17,7 @@ use crate::linux_abi::*; use crate::mount::{DRIVER_BLK_TYPE, DRIVER_MMIO_BLK_TYPE, DRIVER_NVDIMM_TYPE, DRIVER_SCSI_TYPE}; use crate::pci; use crate::sandbox::Sandbox; +use crate::uevent::Uevent; use crate::AGENT_CONFIG; use anyhow::{anyhow, Result}; use oci::{LinuxDeviceCgroup, LinuxResources, Spec}; @@ -101,14 +102,14 @@ async fn get_device_name(sandbox: &Arc>, dev_addr: &str) -> Resul // The key of the watchers map is the device we are interested in. // Note this is done inside the lock, not to miss any events from the // global udev listener. 
- let (tx, rx) = tokio::sync::oneshot::channel::(); + let (tx, rx) = tokio::sync::oneshot::channel::(); sb.dev_watcher.insert(dev_addr.to_string(), tx); drop(sb); // unlock info!(sl!(), "Waiting on channel for device notification\n"); let hotplug_timeout = AGENT_CONFIG.read().await.hotplug_timeout; - let dev_name = match tokio::time::timeout(hotplug_timeout, rx).await { + let uev = match tokio::time::timeout(hotplug_timeout, rx).await { Ok(v) => v?, Err(_) => { let mut sb = sandbox.lock().await; @@ -122,7 +123,7 @@ async fn get_device_name(sandbox: &Arc>, dev_addr: &str) -> Resul } }; - Ok(format!("{}/{}", SYSTEM_DEV_PATH, &dev_name)) + Ok(format!("{}/{}", SYSTEM_DEV_PATH, &uev.devname)) } pub async fn get_scsi_device_name( @@ -796,7 +797,7 @@ mod tests { assert_eq!(name.unwrap(), format!("{}/{}", SYSTEM_DEV_PATH, devname)); let mut sb = sandbox.lock().await; - sb.uevent_map.remove(&devpath); + let uev = sb.uevent_map.remove(&devpath).unwrap(); drop(sb); // unlock let watcher_sandbox = Arc::clone(&sandbox); @@ -812,7 +813,7 @@ mod tests { if let Some(k) = matched_key { let sender = sb.dev_watcher.remove(&k).unwrap(); - let _ = sender.send(devname.to_string()); + let _ = sender.send(uev); return; } drop(sb); // unlock diff --git a/src/agent/src/sandbox.rs b/src/agent/src/sandbox.rs index 8b1c4561ca..fd9eab0fbc 100644 --- a/src/agent/src/sandbox.rs +++ b/src/agent/src/sandbox.rs @@ -38,7 +38,7 @@ pub struct Sandbox { pub mounts: Vec, pub container_mounts: HashMap>, pub uevent_map: HashMap, - pub dev_watcher: HashMap>, + pub dev_watcher: HashMap>, pub shared_utsns: Namespace, pub shared_ipcns: Namespace, pub sandbox_pidns: Option, diff --git a/src/agent/src/uevent.rs b/src/agent/src/uevent.rs index 78bc400abd..5de4cecb47 100644 --- a/src/agent/src/uevent.rs +++ b/src/agent/src/uevent.rs @@ -98,12 +98,11 @@ impl Uevent { .collect(); for k in keys { - let devname = self.devname.clone(); // unwrap() is safe because logic above ensures k exists // in the map, and it's 
locked so no-one else can change // that let sender = sb.dev_watcher.remove(&k).unwrap(); - let _ = sender.send(devname); + let _ = sender.send(self.clone()); } } From 55ed2ddd071162d41bc052baf84d5ceec38e5b5f Mon Sep 17 00:00:00 2001 From: David Gibson Date: Thu, 4 Mar 2021 14:07:53 +1100 Subject: [PATCH 07/12] agent: Store uevent watchers in Vec rather than HashMap Sandbox::dev_watcher is a HashMap from a "device address" to a channel used to notify get_device_name() that a suitable uevent has been found. However, "device address" isn't well defined, having somewhat different meanings for different device/event types. We never actually look up this HashMap by key, except to remove entries. Not looking up by key suggests that a map is not the appropriate data structure here. Furthermore, HashMap imposes limitations on the types which will prevent some future extensions we want. So, replace the HashMap with a Vec<Option<...>>. We need the Option<> so that we can remove entries by index (removing them from the Vec completely would change the indices of other entries, possibly breaking concurrent work). This does mean that the vector will keep growing as we watch for different events during startup. However, we don't expect the number of device events we watch for during a run to be very large, so that shouldn't be a problem. We can optimize this later if it becomes a problem. Signed-off-by: David Gibson --- src/agent/src/device.rs | 24 ++++++++++----------- src/agent/src/sandbox.rs | 5 +++-- src/agent/src/uevent.rs | 46 +++++++++++++++------------------------- 3 files changed, 31 insertions(+), 44 deletions(-) diff --git a/src/agent/src/device.rs b/src/agent/src/device.rs index 8d756a9d0b..7e8b221ca1 100644 --- a/src/agent/src/device.rs +++ b/src/agent/src/device.rs @@ -103,7 +103,8 @@ async fn get_device_name(sandbox: &Arc>, dev_addr: &str) -> Resul // Note this is done inside the lock, not to miss any events from the // global udev listener. 
let (tx, rx) = tokio::sync::oneshot::channel::(); - sb.dev_watcher.insert(dev_addr.to_string(), tx); + let idx = sb.uevent_watchers.len(); + sb.uevent_watchers.push(Some((dev_addr.to_string(), tx))); drop(sb); // unlock info!(sl!(), "Waiting on channel for device notification\n"); @@ -113,7 +114,7 @@ async fn get_device_name(sandbox: &Arc>, dev_addr: &str) -> Resul Ok(v) => v?, Err(_) => { let mut sb = sandbox.lock().await; - sb.dev_watcher.remove_entry(dev_addr); + sb.uevent_watchers[idx].take(); return Err(anyhow!( "Timeout reached after {:?} waiting for device {}", @@ -804,17 +805,14 @@ mod tests { tokio::spawn(async move { loop { let mut sb = watcher_sandbox.lock().await; - let matched_key = sb - .dev_watcher - .keys() - .filter(|dev_addr| devpath.contains(*dev_addr)) - .cloned() - .next(); - - if let Some(k) = matched_key { - let sender = sb.dev_watcher.remove(&k).unwrap(); - let _ = sender.send(uev); - return; + for w in &mut sb.uevent_watchers { + if let Some((dev_addr, _)) = w { + if devpath.contains(dev_addr.as_str()) { + let (_, sender) = w.take().unwrap(); + let _ = sender.send(uev); + return; + } + } } drop(sb); // unlock } diff --git a/src/agent/src/sandbox.rs b/src/agent/src/sandbox.rs index fd9eab0fbc..02a60465bd 100644 --- a/src/agent/src/sandbox.rs +++ b/src/agent/src/sandbox.rs @@ -26,6 +26,7 @@ use std::path::Path; use std::sync::Arc; use std::{thread, time}; use tokio::sync::mpsc::{channel, Receiver, Sender}; +use tokio::sync::oneshot; use tokio::sync::Mutex; #[derive(Debug)] @@ -38,7 +39,7 @@ pub struct Sandbox { pub mounts: Vec, pub container_mounts: HashMap>, pub uevent_map: HashMap, - pub dev_watcher: HashMap>, + pub uevent_watchers: Vec)>>, pub shared_utsns: Namespace, pub shared_ipcns: Namespace, pub sandbox_pidns: Option, @@ -68,7 +69,7 @@ impl Sandbox { mounts: Vec::new(), container_mounts: HashMap::new(), uevent_map: HashMap::new(), - dev_watcher: HashMap::new(), + uevent_watchers: Vec::new(), shared_utsns: Namespace::new(&logger), 
shared_ipcns: Namespace::new(&logger), sandbox_pidns: None, diff --git a/src/agent/src/uevent.rs b/src/agent/src/uevent.rs index 5de4cecb47..9d7a39fbf3 100644 --- a/src/agent/src/uevent.rs +++ b/src/agent/src/uevent.rs @@ -71,38 +71,26 @@ impl Uevent { sb.uevent_map.insert(self.devpath.clone(), self.clone()); // Notify watchers that are interested in the udev event. - // Close the channel after watcher has been notified. - let devpath = self.devpath.clone(); - let keys: Vec<_> = sb - .dev_watcher - .keys() - .filter(|dev_addr| { - let pci_p = format!("{}{}", pci_root_bus_path, *dev_addr); + for watch in &mut sb.uevent_watchers { + if let Some((dev_addr, _)) = watch { + let pci_p = format!("{}{}", pci_root_bus_path, dev_addr); + let pmem_suffix = format!("/{}/{}", SCSI_BLOCK_SUFFIX, self.devname); - // blk block device - devpath.starts_with(pci_p.as_str()) || - // scsi block device + if self.devpath.starts_with(pci_p.as_str()) || // blk block device + ( // scsi block device + dev_addr.ends_with(SCSI_BLOCK_SUFFIX) && + self.devpath.contains(dev_addr.as_str()) + ) || + ( // nvdimm/pmem device + self.devpath.starts_with(ACPI_DEV_PATH) && + self.devpath.ends_with(pmem_suffix.as_str()) && + dev_addr.ends_with(pmem_suffix.as_str()) + ) { - (*dev_addr).ends_with(SCSI_BLOCK_SUFFIX) && - devpath.contains(*dev_addr) - } || - // nvdimm/pmem device - { - let pmem_suffix = format!("/{}/{}", SCSI_BLOCK_SUFFIX, self.devname); - devpath.starts_with(ACPI_DEV_PATH) && - devpath.ends_with(pmem_suffix.as_str()) && - dev_addr.ends_with(pmem_suffix.as_str()) + let (_, sender) = watch.take().unwrap(); + let _ = sender.send(self.clone()); } - }) - .cloned() - .collect(); - - for k in keys { - // unwrap() is safe because logic above ensures k exists - // in the map, and it's locked so no-one else can change - // that - let sender = sb.dev_watcher.remove(&k).unwrap(); - let _ = sender.send(self.clone()); + } } } From d2caff6c550b53a5a618f35d2f4911e2eeefe68f Mon Sep 17 00:00:00 2001 From: 
David Gibson Date: Thu, 4 Mar 2021 14:38:33 +1100 Subject: [PATCH 08/12] agent: Re-organize uevent processing Uevent::process() is a bit oddly organized. It treats the onlining of hotplugged memory as the "default" case, although that's quite specific, while treating the handling of hotplugged block devices more like a special case, although that's pretty close to being very general. Furthermore splitting Uevent::is_block_add_event() from Uevent::handle_block_add_event() doesn't make a lot of sense, since their logic is intimately related to each other. Alter the code to be a bit more sensible: first split on the "action" type since that's the most fundamental difference, then handle the memory onlining special case, then the block device add (which will become a lot more general in future changes). Signed-off-by: David Gibson --- src/agent/src/uevent.rs | 48 ++++++++++++++++++++--------------------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/src/agent/src/uevent.rs b/src/agent/src/uevent.rs index 9d7a39fbf3..7f6269ddef 100644 --- a/src/agent/src/uevent.rs +++ b/src/agent/src/uevent.rs @@ -51,20 +51,34 @@ impl Uevent { event } - // Check whether this is a block device hot-add event. - fn is_block_add_event(&self) -> bool { + async fn process_add(&self, logger: &Logger, sandbox: &Arc>) { + // Special case for memory hot-adds first + let online_path = format!("{}/{}/online", SYSFS_DIR, &self.devpath); + if online_path.starts_with(SYSFS_MEMORY_ONLINE_PATH) { + let _ = online_device(online_path.as_ref()).map_err(|e| { + error!( + *logger, + "failed to online device"; + "device" => &self.devpath, + "error" => format!("{}", e), + ) + }); + return; + } + let pci_root_bus_path = create_pci_root_bus_path(); - self.action == U_EVENT_ACTION_ADD - && self.subsystem == "block" + + // Check whether this is a block device hot-add event. 
+ if !(self.subsystem == "block" && { self.devpath.starts_with(pci_root_bus_path.as_str()) || self.devpath.starts_with(ACPI_DEV_PATH) // NVDIMM/PMEM devices } - && !self.devname.is_empty() - } + && !self.devname.is_empty()) + { + return; + } - async fn handle_block_add_event(&self, sandbox: &Arc>) { - let pci_root_bus_path = create_pci_root_bus_path(); let mut sb = sandbox.lock().await; // Record the event by sysfs path @@ -95,22 +109,8 @@ impl Uevent { } async fn process(&self, logger: &Logger, sandbox: &Arc>) { - if self.is_block_add_event() { - return self.handle_block_add_event(sandbox).await; - } else if self.action == U_EVENT_ACTION_ADD { - let online_path = format!("{}/{}/online", SYSFS_DIR, &self.devpath); - // It's a memory hot-add event. - if online_path.starts_with(SYSFS_MEMORY_ONLINE_PATH) { - let _ = online_device(online_path.as_ref()).map_err(|e| { - error!( - *logger, - "failed to online device"; - "device" => &self.devpath, - "error" => format!("{}", e), - ) - }); - return; - } + if self.action == U_EVENT_ACTION_ADD { + return self.process_add(logger, sandbox).await; } debug!(*logger, "ignoring event"; "uevent" => format!("{:?}", self)); } From b8b322482cef8ca13ca3862848278b014f1a5394 Mon Sep 17 00:00:00 2001 From: David Gibson Date: Thu, 4 Mar 2021 15:03:49 +1100 Subject: [PATCH 09/12] agent/uevent: Consolidate event matching logic The event matching logic in Uevent::process_add() is split into two parts. The first checks if we care about the event at all, the second checks whether the event is relevant to a particular watcher. However, we're going to be adding more types of watchers in future, which will make the global filter too restrictive. Fold the two bits of logic together into a per-watcher filter function. 
Signed-off-by: David Gibson --- src/agent/src/uevent.rs | 53 ++++++++++++++++++++++------------------- 1 file changed, 28 insertions(+), 25 deletions(-) diff --git a/src/agent/src/uevent.rs b/src/agent/src/uevent.rs index 7f6269ddef..ecec3b46e5 100644 --- a/src/agent/src/uevent.rs +++ b/src/agent/src/uevent.rs @@ -66,19 +66,6 @@ impl Uevent { return; } - let pci_root_bus_path = create_pci_root_bus_path(); - - // Check whether this is a block device hot-add event. - if !(self.subsystem == "block" - && { - self.devpath.starts_with(pci_root_bus_path.as_str()) - || self.devpath.starts_with(ACPI_DEV_PATH) // NVDIMM/PMEM devices - } - && !self.devname.is_empty()) - { - return; - } - let mut sb = sandbox.lock().await; // Record the event by sysfs path @@ -87,20 +74,36 @@ impl Uevent { // Notify watchers that are interested in the udev event. for watch in &mut sb.uevent_watchers { if let Some((dev_addr, _)) = watch { + let pci_root_bus_path = create_pci_root_bus_path(); let pci_p = format!("{}{}", pci_root_bus_path, dev_addr); - let pmem_suffix = format!("/{}/{}", SCSI_BLOCK_SUFFIX, self.devname); - if self.devpath.starts_with(pci_p.as_str()) || // blk block device - ( // scsi block device - dev_addr.ends_with(SCSI_BLOCK_SUFFIX) && - self.devpath.contains(dev_addr.as_str()) - ) || - ( // nvdimm/pmem device - self.devpath.starts_with(ACPI_DEV_PATH) && - self.devpath.ends_with(pmem_suffix.as_str()) && - dev_addr.ends_with(pmem_suffix.as_str()) - ) - { + let is_match = |uev: &Uevent| { + let pmem_suffix = format!("/{}/{}", SCSI_BLOCK_SUFFIX, uev.devname); + + uev.subsystem == "block" + && { + uev.devpath.starts_with(pci_root_bus_path.as_str()) + || uev.devpath.starts_with(ACPI_DEV_PATH) // NVDIMM/PMEM devices + } + && !uev.devname.is_empty() + && { + // blk block device + uev.devpath.starts_with(pci_p.as_str()) + // scsi block device + || ( + dev_addr.ends_with(SCSI_BLOCK_SUFFIX) && + uev.devpath.contains(dev_addr.as_str()) + ) + // nvdimm/pmem device + || ( + 
uev.devpath.starts_with(ACPI_DEV_PATH) && + uev.devpath.ends_with(pmem_suffix.as_str()) && + dev_addr.ends_with(pmem_suffix.as_str()) + ) + } + }; + + if is_match(&self) { let (_, sender) = watch.take().unwrap(); let _ = sender.send(self.clone()); } From 4b16681d87ed60a9ddafb0b06213bd01ca27b3b2 Mon Sep 17 00:00:00 2001 From: David Gibson Date: Wed, 10 Mar 2021 16:29:11 +1100 Subject: [PATCH 10/12] agent/uevent: Put matcher object rather than "device address" in watch list Currently, Sandbox::uevent_watchers lists uevents to watch for by a "device address" string. This is not very clearly defined, and is matched against events with a rather complex closure created in Uevent::process_add(). That closure makes a bunch of fragile assumptions about what sort of events we could ever be interested in. In some ways it is too restrictive (requires everything to be a block device), but in others is not restrictive enough (allows things matching NVDIMM paths, even if we're looking for a PCI block device). To allow the clients more precise control over uevent matching, we define a new UeventMatcher trait with a method to match uevents. We then have the watchers list include UeventMatcher trait objects which are used directly by Uevent::process_add(), instead of constructing our match directly from dev_addr. For now we don't actually change the matching function, or even use multiple different trait implementations, but we'll refine that in future.
Signed-off-by: David Gibson --- src/agent/src/device.rs | 55 +++++++++++++++++++++++++++++++++++++--- src/agent/src/sandbox.rs | 6 +++-- src/agent/src/uevent.rs | 38 +++++---------------------- 3 files changed, 62 insertions(+), 37 deletions(-) diff --git a/src/agent/src/device.rs b/src/agent/src/device.rs index 7e8b221ca1..971b31d135 100644 --- a/src/agent/src/device.rs +++ b/src/agent/src/device.rs @@ -17,7 +17,7 @@ use crate::linux_abi::*; use crate::mount::{DRIVER_BLK_TYPE, DRIVER_MMIO_BLK_TYPE, DRIVER_NVDIMM_TYPE, DRIVER_SCSI_TYPE}; use crate::pci; use crate::sandbox::Sandbox; -use crate::uevent::Uevent; +use crate::uevent::{Uevent, UeventMatcher}; use crate::AGENT_CONFIG; use anyhow::{anyhow, Result}; use oci::{LinuxDeviceCgroup, LinuxResources, Spec}; @@ -88,7 +88,52 @@ fn pcipath_to_sysfs(root_bus_sysfs: &str, pcipath: &pci::Path) -> Result Ok(relpath) } +#[derive(Debug)] +struct DevAddrMatcher { + dev_addr: String, +} + +impl DevAddrMatcher { + fn new(dev_addr: &str) -> DevAddrMatcher { + DevAddrMatcher { + dev_addr: dev_addr.to_string(), + } + } +} + +impl UeventMatcher for DevAddrMatcher { + fn is_match(&self, uev: &Uevent) -> bool { + let pci_root_bus_path = create_pci_root_bus_path(); + let pci_p = format!("{}{}", pci_root_bus_path, self.dev_addr); + let pmem_suffix = format!("/{}/{}", SCSI_BLOCK_SUFFIX, uev.devname); + + uev.subsystem == "block" + && { + uev.devpath.starts_with(pci_root_bus_path.as_str()) + || uev.devpath.starts_with(ACPI_DEV_PATH) // NVDIMM/PMEM devices + } + && !uev.devname.is_empty() + && { + // blk block device + uev.devpath.starts_with(pci_p.as_str()) + // scsi block device + || ( + self.dev_addr.ends_with(SCSI_BLOCK_SUFFIX) && + uev.devpath.contains(self.dev_addr.as_str()) + ) + // nvdimm/pmem device + || ( + uev.devpath.starts_with(ACPI_DEV_PATH) && + uev.devpath.ends_with(pmem_suffix.as_str()) && + self.dev_addr.ends_with(pmem_suffix.as_str()) + ) + } + } +} + async fn get_device_name(sandbox: &Arc>, dev_addr: &str) -> Result 
{ + let matcher = DevAddrMatcher::new(dev_addr); + let mut sb = sandbox.lock().await; for (key, uev) in sb.uevent_map.iter() { if key.contains(dev_addr) { @@ -104,7 +149,7 @@ async fn get_device_name(sandbox: &Arc>, dev_addr: &str) -> Resul // global udev listener. let (tx, rx) = tokio::sync::oneshot::channel::(); let idx = sb.uevent_watchers.len(); - sb.uevent_watchers.push(Some((dev_addr.to_string(), tx))); + sb.uevent_watchers.push(Some((Box::new(matcher), tx))); drop(sb); // unlock info!(sl!(), "Waiting on channel for device notification\n"); @@ -783,6 +828,8 @@ mod tests { let devpath = format!("{}{}/virtio4/block/{}", root_bus, relpath, devname); let mut uev = crate::uevent::Uevent::default(); + uev.action = crate::linux_abi::U_EVENT_ACTION_ADD.to_string(); + uev.subsystem = "block".to_string(); uev.devpath = devpath.clone(); uev.devname = devname.to_string(); @@ -806,8 +853,8 @@ mod tests { loop { let mut sb = watcher_sandbox.lock().await; for w in &mut sb.uevent_watchers { - if let Some((dev_addr, _)) = w { - if devpath.contains(dev_addr.as_str()) { + if let Some((matcher, _)) = w { + if matcher.is_match(&uev) { let (_, sender) = w.take().unwrap(); let _ = sender.send(uev); return; diff --git a/src/agent/src/sandbox.rs b/src/agent/src/sandbox.rs index 02a60465bd..b71fa5534a 100644 --- a/src/agent/src/sandbox.rs +++ b/src/agent/src/sandbox.rs @@ -8,7 +8,7 @@ use crate::mount::{get_mount_fs_type, remove_mounts, TYPE_ROOTFS}; use crate::namespace::Namespace; use crate::netlink::Handle; use crate::network::Network; -use crate::uevent::Uevent; +use crate::uevent::{Uevent, UeventMatcher}; use anyhow::{anyhow, Context, Result}; use libc::pid_t; use oci::{Hook, Hooks}; @@ -29,6 +29,8 @@ use tokio::sync::mpsc::{channel, Receiver, Sender}; use tokio::sync::oneshot; use tokio::sync::Mutex; +type UeventWatcher = (Box, oneshot::Sender); + #[derive(Debug)] pub struct Sandbox { pub logger: Logger, @@ -39,7 +41,7 @@ pub struct Sandbox { pub mounts: Vec, pub 
container_mounts: HashMap>, pub uevent_map: HashMap, - pub uevent_watchers: Vec)>>, + pub uevent_watchers: Vec>, pub shared_utsns: Namespace, pub shared_ipcns: Namespace, pub sandbox_pidns: Option, diff --git a/src/agent/src/uevent.rs b/src/agent/src/uevent.rs index ecec3b46e5..4d1fb8db70 100644 --- a/src/agent/src/uevent.rs +++ b/src/agent/src/uevent.rs @@ -11,6 +11,7 @@ use slog::Logger; use anyhow::Result; use netlink_sys::{protocols, SocketAddr, TokioSocket}; use nix::errno::Errno; +use std::fmt::Debug; use std::os::unix::io::FromRawFd; use std::sync::Arc; use tokio::select; @@ -27,6 +28,10 @@ pub struct Uevent { pub interface: String, } +pub trait UeventMatcher: Sync + Send + Debug + 'static { + fn is_match(&self, uev: &Uevent) -> bool; +} + impl Uevent { fn new(message: &str) -> Self { let mut msg_iter = message.split('\0'); @@ -73,37 +78,8 @@ impl Uevent { // Notify watchers that are interested in the udev event. for watch in &mut sb.uevent_watchers { - if let Some((dev_addr, _)) = watch { - let pci_root_bus_path = create_pci_root_bus_path(); - let pci_p = format!("{}{}", pci_root_bus_path, dev_addr); - - let is_match = |uev: &Uevent| { - let pmem_suffix = format!("/{}/{}", SCSI_BLOCK_SUFFIX, uev.devname); - - uev.subsystem == "block" - && { - uev.devpath.starts_with(pci_root_bus_path.as_str()) - || uev.devpath.starts_with(ACPI_DEV_PATH) // NVDIMM/PMEM devices - } - && !uev.devname.is_empty() - && { - // blk block device - uev.devpath.starts_with(pci_p.as_str()) - // scsi block device - || ( - dev_addr.ends_with(SCSI_BLOCK_SUFFIX) && - uev.devpath.contains(dev_addr.as_str()) - ) - // nvdimm/pmem device - || ( - uev.devpath.starts_with(ACPI_DEV_PATH) && - uev.devpath.ends_with(pmem_suffix.as_str()) && - dev_addr.ends_with(pmem_suffix.as_str()) - ) - } - }; - - if is_match(&self) { + if let Some((matcher, _)) = watch { + if matcher.is_match(&self) { let (_, sender) = watch.take().unwrap(); let _ = sender.send(self.clone()); } From 
16ed55e440afd14437d1455773a00af235cd0dde Mon Sep 17 00:00:00 2001 From: David Gibson Date: Fri, 5 Mar 2021 15:47:45 +1100 Subject: [PATCH 11/12] agent/device: Use consistent matching for past and future uevents get_device_name() looks at kernel uevents to work out the device name for a given PCI (usually) address. However, when we call it we can't know if the uevent we're interested in has already happened (in which case it will have been recorded in Sandbox::uevent_map) or yet to come, in which case we need to register to watch it. However, we currently match differently against past and future events. For past events we simply look for a sysfs path including the address, but for future events we use a complex bit of logic in the is_match() closure. Change it to use the exact same matching logic in both cases. fixes #1397 Signed-off-by: David Gibson --- src/agent/src/device.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/agent/src/device.rs b/src/agent/src/device.rs index 971b31d135..bbb70bc65a 100644 --- a/src/agent/src/device.rs +++ b/src/agent/src/device.rs @@ -135,8 +135,8 @@ async fn get_device_name(sandbox: &Arc>, dev_addr: &str) -> Resul let matcher = DevAddrMatcher::new(dev_addr); let mut sb = sandbox.lock().await; - for (key, uev) in sb.uevent_map.iter() { - if key.contains(dev_addr) { + for uev in sb.uevent_map.values() { + if matcher.is_match(uev) { info!(sl!(), "Device {} found in pci device map", dev_addr); return Ok(format!("{}/{}", SYSTEM_DEV_PATH, uev.devname)); } From 0828f9ba707b4a2d4a381b657439f0802d2d5cc7 Mon Sep 17 00:00:00 2001 From: David Gibson Date: Wed, 10 Mar 2021 15:18:37 +1100 Subject: [PATCH 12/12] agent/uevent: Introduce wait_for_uevent() helper get_device_name() contains logic to wait for a specific uevent, then extract the /dev node name from it. In future we're going to want similar logic to wait on uevents, but using different match criteria, or getting different information out. 
To simplify this, add a wait_for_uevent() helper in the uevent module, which takes an explicit UeventMatcher object and returns the whole uevent found. To make testing easier, we also extract the cut down uevent watcher from test_get_device_name() into a new spawn_test_watcher() helper. It's used for both test_get_device_name() and a new test_wait_for_uevent() and will be useful for more tests in future. fixes #1484 Signed-off-by: David Gibson --- src/agent/src/device.rs | 58 ++----------------- src/agent/src/uevent.rs | 120 +++++++++++++++++++++++++++++++++++++++- 2 files changed, 123 insertions(+), 55 deletions(-) diff --git a/src/agent/src/device.rs b/src/agent/src/device.rs index bbb70bc65a..750aa46691 100644 --- a/src/agent/src/device.rs +++ b/src/agent/src/device.rs @@ -17,8 +17,7 @@ use crate::linux_abi::*; use crate::mount::{DRIVER_BLK_TYPE, DRIVER_MMIO_BLK_TYPE, DRIVER_NVDIMM_TYPE, DRIVER_SCSI_TYPE}; use crate::pci; use crate::sandbox::Sandbox; -use crate::uevent::{Uevent, UeventMatcher}; -use crate::AGENT_CONFIG; +use crate::uevent::{wait_for_uevent, Uevent, UeventMatcher}; use anyhow::{anyhow, Result}; use oci::{LinuxDeviceCgroup, LinuxResources, Spec}; use protocols::agent::Device; @@ -88,7 +87,7 @@ fn pcipath_to_sysfs(root_bus_sysfs: &str, pcipath: &pci::Path) -> Result Ok(relpath) } -#[derive(Debug)] +#[derive(Debug, Clone)] struct DevAddrMatcher { dev_addr: String, } @@ -134,40 +133,7 @@ impl UeventMatcher for DevAddrMatcher { async fn get_device_name(sandbox: &Arc>, dev_addr: &str) -> Result { let matcher = DevAddrMatcher::new(dev_addr); - let mut sb = sandbox.lock().await; - for uev in sb.uevent_map.values() { - if matcher.is_match(uev) { - info!(sl!(), "Device {} found in pci device map", dev_addr); - return Ok(format!("{}/{}", SYSTEM_DEV_PATH, uev.devname)); - } - } - - // If device is not found in the device map, hotplug event has not - // been received yet, create and add channel to the watchers map.
- // The key of the watchers map is the device we are interested in. - // Note this is done inside the lock, not to miss any events from the - // global udev listener. - let (tx, rx) = tokio::sync::oneshot::channel::(); - let idx = sb.uevent_watchers.len(); - sb.uevent_watchers.push(Some((Box::new(matcher), tx))); - drop(sb); // unlock - - info!(sl!(), "Waiting on channel for device notification\n"); - let hotplug_timeout = AGENT_CONFIG.read().await.hotplug_timeout; - - let uev = match tokio::time::timeout(hotplug_timeout, rx).await { - Ok(v) => v?, - Err(_) => { - let mut sb = sandbox.lock().await; - sb.uevent_watchers[idx].take(); - - return Err(anyhow!( - "Timeout reached after {:?} waiting for device {}", - hotplug_timeout, - dev_addr - )); - } - }; + let uev = wait_for_uevent(sandbox, matcher).await?; Ok(format!("{}/{}", SYSTEM_DEV_PATH, &uev.devname)) } @@ -473,6 +439,7 @@ pub fn update_device_cgroup(spec: &mut Spec) -> Result<()> { #[cfg(test)] mod tests { use super::*; + use crate::uevent::spawn_test_watcher; use oci::Linux; use tempfile::tempdir; @@ -848,22 +815,7 @@ mod tests { let uev = sb.uevent_map.remove(&devpath).unwrap(); drop(sb); // unlock - let watcher_sandbox = Arc::clone(&sandbox); - tokio::spawn(async move { - loop { - let mut sb = watcher_sandbox.lock().await; - for w in &mut sb.uevent_watchers { - if let Some((matcher, _)) = w { - if matcher.is_match(&uev) { - let (_, sender) = w.take().unwrap(); - let _ = sender.send(uev); - return; - } - } - } - drop(sb); // unlock - } - }); + spawn_test_watcher(sandbox.clone(), uev); let name = get_device_name(&sandbox, relpath).await; assert!(name.is_ok(), "{}", name.unwrap_err()); diff --git a/src/agent/src/uevent.rs b/src/agent/src/uevent.rs index 4d1fb8db70..3d8e12b0d5 100644 --- a/src/agent/src/uevent.rs +++ b/src/agent/src/uevent.rs @@ -6,9 +6,10 @@ use crate::device::online_device; use crate::linux_abi::*; use crate::sandbox::Sandbox; +use crate::AGENT_CONFIG; use slog::Logger; -use anyhow::Result; 
+use anyhow::{anyhow, Result}; use netlink_sys::{protocols, SocketAddr, TokioSocket}; use nix::errno::Errno; use std::fmt::Debug; @@ -18,7 +19,14 @@ use tokio::select; use tokio::sync::watch::Receiver; use tokio::sync::Mutex; -#[derive(Debug, Default, Clone)] +// Convenience macro to obtain the scope logger +macro_rules! sl { + () => { + slog_scope::logger().new(o!("subsystem" => "uevent")) + }; +} + +#[derive(Debug, Default, Clone, PartialEq, Eq)] pub struct Uevent { pub action: String, pub devpath: String, @@ -95,6 +103,48 @@ impl Uevent { } } +pub async fn wait_for_uevent( + sandbox: &Arc>, + matcher: impl UeventMatcher, +) -> Result { + let mut sb = sandbox.lock().await; + for uev in sb.uevent_map.values() { + if matcher.is_match(uev) { + info!(sl!(), "Device {:?} found in pci device map", uev); + return Ok(uev.clone()); + } + } + + // If device is not found in the device map, hotplug event has not + // been received yet, create and add channel to the watchers map. + // The key of the watchers map is the device we are interested in. + // Note this is done inside the lock, not to miss any events from the + // global udev listener. 
+ let (tx, rx) = tokio::sync::oneshot::channel::(); + let idx = sb.uevent_watchers.len(); + sb.uevent_watchers.push(Some((Box::new(matcher), tx))); + drop(sb); // unlock + + info!(sl!(), "Waiting on channel for uevent notification\n"); + let hotplug_timeout = AGENT_CONFIG.read().await.hotplug_timeout; + + let uev = match tokio::time::timeout(hotplug_timeout, rx).await { + Ok(v) => v?, + Err(_) => { + let mut sb = sandbox.lock().await; + let matcher = sb.uevent_watchers[idx].take().unwrap().0; + + return Err(anyhow!( + "Timeout after {:?} waiting for uevent {:?}", + hotplug_timeout, + &matcher + )); + } + }; + + Ok(uev) +} + pub async fn watch_uevents( sandbox: Arc>, mut shutdown: Receiver, @@ -159,3 +209,69 @@ pub async fn watch_uevents( Ok(()) } + +// Used in the device module unit tests +#[cfg(test)] +pub(crate) fn spawn_test_watcher(sandbox: Arc>, uev: Uevent) { + tokio::spawn(async move { + loop { + let mut sb = sandbox.lock().await; + for w in &mut sb.uevent_watchers { + if let Some((matcher, _)) = w { + if matcher.is_match(&uev) { + let (_, sender) = w.take().unwrap(); + let _ = sender.send(uev); + return; + } + } + } + drop(sb); // unlock + } + }); +} + +#[cfg(test)] +mod tests { + use super::*; + + #[derive(Debug, Clone, Copy)] + struct AlwaysMatch(); + + impl UeventMatcher for AlwaysMatch { + fn is_match(&self, _: &Uevent) -> bool { + true + } + } + + #[tokio::test] + async fn test_wait_for_uevent() { + let mut uev = crate::uevent::Uevent::default(); + uev.action = crate::linux_abi::U_EVENT_ACTION_ADD.to_string(); + uev.subsystem = "test".to_string(); + uev.devpath = "/test/sysfs/path".to_string(); + uev.devname = "testdevname".to_string(); + + let matcher = AlwaysMatch(); + + let logger = slog::Logger::root(slog::Discard, o!()); + let sandbox = Arc::new(Mutex::new(Sandbox::new(&logger).unwrap())); + + let mut sb = sandbox.lock().await; + sb.uevent_map.insert(uev.devpath.clone(), uev.clone()); + drop(sb); // unlock + + let uev2 = wait_for_uevent(&sandbox, 
matcher).await; + assert!(uev2.is_ok()); + assert_eq!(uev2.unwrap(), uev); + + let mut sb = sandbox.lock().await; + sb.uevent_map.remove(&uev.devpath).unwrap(); + drop(sb); // unlock + + spawn_test_watcher(sandbox.clone(), uev.clone()); + + let uev2 = wait_for_uevent(&sandbox, matcher).await; + assert!(uev2.is_ok()); + assert_eq!(uev2.unwrap(), uev); + } +}