From 55ed2ddd071162d41bc052baf84d5ceec38e5b5f Mon Sep 17 00:00:00 2001 From: David Gibson Date: Thu, 4 Mar 2021 14:07:53 +1100 Subject: [PATCH] agent: Store uevent watchers in Vec rather than HashMap Sandbox:dev_watcher is a HashMap from a "device address" to a channel used to notify get_device_name() that a suitable uevent has been found. However, "device address" isn't well defined, having somewhat different meanings for different device/event types. We never actually look up this HashMap by key, except to remove entries. Not looking up by key suggests that a map is not the appropriate data structure here. Furthermore, HashMap imposes limitations on the types which will prevent some future extensions we want. So, replace the HashMap with a Vec>. We need the Option<> so that we can remove entries by index (removing them from the Vec completely would hange the indices of other entries, possibly breaking concurrent work. This does mean that the vector will keep growing as we watch for different events during startup. However, we don't expect the number of device events we watch for during a run to be very large, so that shouldn't be a problem. We can optimize this later if it becomes a problem. Signed-off-by: David Gibson --- src/agent/src/device.rs | 24 ++++++++++----------- src/agent/src/sandbox.rs | 5 +++-- src/agent/src/uevent.rs | 46 +++++++++++++++------------------------- 3 files changed, 31 insertions(+), 44 deletions(-) diff --git a/src/agent/src/device.rs b/src/agent/src/device.rs index 8d756a9d0b..7e8b221ca1 100644 --- a/src/agent/src/device.rs +++ b/src/agent/src/device.rs @@ -103,7 +103,8 @@ async fn get_device_name(sandbox: &Arc>, dev_addr: &str) -> Resul // Note this is done inside the lock, not to miss any events from the // global udev listener. let (tx, rx) = tokio::sync::oneshot::channel::(); - sb.dev_watcher.insert(dev_addr.to_string(), tx); + let idx = sb.uevent_watchers.len(); + sb.uevent_watchers.push(Some((dev_addr.to_string(), tx))); drop(sb); // unlock info!(sl!(), "Waiting on channel for device notification\n"); @@ -113,7 +114,7 @@ async fn get_device_name(sandbox: &Arc>, dev_addr: &str) -> Resul Ok(v) => v?, Err(_) => { let mut sb = sandbox.lock().await; - sb.dev_watcher.remove_entry(dev_addr); + sb.uevent_watchers[idx].take(); return Err(anyhow!( "Timeout reached after {:?} waiting for device {}", @@ -804,17 +805,14 @@ mod tests { tokio::spawn(async move { loop { let mut sb = watcher_sandbox.lock().await; - let matched_key = sb - .dev_watcher - .keys() - .filter(|dev_addr| devpath.contains(*dev_addr)) - .cloned() - .next(); - - if let Some(k) = matched_key { - let sender = sb.dev_watcher.remove(&k).unwrap(); - let _ = sender.send(uev); - return; + for w in &mut sb.uevent_watchers { + if let Some((dev_addr, _)) = w { + if devpath.contains(dev_addr.as_str()) { + let (_, sender) = w.take().unwrap(); + let _ = sender.send(uev); + return; + } + } } drop(sb); // unlock } diff --git a/src/agent/src/sandbox.rs b/src/agent/src/sandbox.rs index fd9eab0fbc..02a60465bd 100644 --- a/src/agent/src/sandbox.rs +++ b/src/agent/src/sandbox.rs @@ -26,6 +26,7 @@ use std::path::Path; use std::sync::Arc; use std::{thread, time}; use tokio::sync::mpsc::{channel, Receiver, Sender}; +use tokio::sync::oneshot; use tokio::sync::Mutex; #[derive(Debug)] @@ -38,7 +39,7 @@ pub struct Sandbox { pub mounts: Vec, pub container_mounts: HashMap>, pub uevent_map: HashMap, - pub dev_watcher: HashMap>, + pub uevent_watchers: Vec)>>, pub shared_utsns: Namespace, pub shared_ipcns: Namespace, pub sandbox_pidns: Option, @@ -68,7 +69,7 @@ impl Sandbox { mounts: Vec::new(), container_mounts: HashMap::new(), uevent_map: HashMap::new(), - dev_watcher: HashMap::new(), + uevent_watchers: Vec::new(), shared_utsns: Namespace::new(&logger), shared_ipcns: Namespace::new(&logger), sandbox_pidns: None, diff --git a/src/agent/src/uevent.rs b/src/agent/src/uevent.rs index 5de4cecb47..9d7a39fbf3 100644 --- a/src/agent/src/uevent.rs +++ b/src/agent/src/uevent.rs @@ -71,38 +71,26 @@ impl Uevent { sb.uevent_map.insert(self.devpath.clone(), self.clone()); // Notify watchers that are interested in the udev event. - // Close the channel after watcher has been notified. - let devpath = self.devpath.clone(); - let keys: Vec<_> = sb - .dev_watcher - .keys() - .filter(|dev_addr| { - let pci_p = format!("{}{}", pci_root_bus_path, *dev_addr); + for watch in &mut sb.uevent_watchers { + if let Some((dev_addr, _)) = watch { + let pci_p = format!("{}{}", pci_root_bus_path, dev_addr); + let pmem_suffix = format!("/{}/{}", SCSI_BLOCK_SUFFIX, self.devname); - // blk block device - devpath.starts_with(pci_p.as_str()) || - // scsi block device + if self.devpath.starts_with(pci_p.as_str()) || // blk block device + ( // scsi block device + dev_addr.ends_with(SCSI_BLOCK_SUFFIX) && + self.devpath.contains(dev_addr.as_str()) + ) || + ( // nvdimm/pmem device + self.devpath.starts_with(ACPI_DEV_PATH) && + self.devpath.ends_with(pmem_suffix.as_str()) && + dev_addr.ends_with(pmem_suffix.as_str()) + ) { - (*dev_addr).ends_with(SCSI_BLOCK_SUFFIX) && - devpath.contains(*dev_addr) - } || - // nvdimm/pmem device - { - let pmem_suffix = format!("/{}/{}", SCSI_BLOCK_SUFFIX, self.devname); - devpath.starts_with(ACPI_DEV_PATH) && - devpath.ends_with(pmem_suffix.as_str()) && - dev_addr.ends_with(pmem_suffix.as_str()) + let (_, sender) = watch.take().unwrap(); + let _ = sender.send(self.clone()); } - }) - .cloned() - .collect(); - - for k in keys { - // unwrap() is safe because logic above ensures k exists - // in the map, and it's locked so no-one else can change - // that - let sender = sb.dev_watcher.remove(&k).unwrap(); - let _ = sender.send(self.clone()); + } } }