agent: Store uevent watchers in Vec rather than HashMap

Sandbox::dev_watcher is a HashMap from a "device address" to a channel used
to notify get_device_name() that a suitable uevent has been found.
However, "device address" isn't well defined, having somewhat different
meanings for different device/event types.  We never actually look up this
HashMap by key, except to remove entries.
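To make that access pattern concrete, here is a simplified sketch of how the
map is used (String stands in for the real Uevent payload type, and the tokio
crate with its sync feature is assumed; this is illustrative, not the agent
code verbatim):

    use std::collections::HashMap;
    use tokio::sync::oneshot;

    struct SandboxSketch {
        dev_watcher: HashMap<String, oneshot::Sender<String>>,
    }

    impl SandboxSketch {
        // The only "lookup" is a linear scan of the keys for one that
        // substring-matches the incoming uevent's devpath; the keyed
        // remove() only happens after the scan has already found the entry.
        fn notify(&mut self, devpath: &str, uev: String) {
            let matched: Option<String> = self
                .dev_watcher
                .keys()
                .filter(|dev_addr| devpath.contains(dev_addr.as_str()))
                .cloned()
                .next();
            if let Some(k) = matched {
                if let Some(sender) = self.dev_watcher.remove(&k) {
                    let _ = sender.send(uev);
                }
            }
        }
    }

    fn main() {
        let (tx, mut rx) = oneshot::channel::<String>();
        let mut sb = SandboxSketch { dev_watcher: HashMap::new() };
        sb.dev_watcher.insert("0000:00/0000:00:04.0".to_string(), tx);
        sb.notify("/devices/pci0000:00/0000:00:04.0/virtio2/block/vda", "vda".to_string());
        assert_eq!(rx.try_recv().unwrap(), "vda");
    }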

Not looking up by key suggests that a map is not the appropriate data
structure here.  Furthermore, HashMap imposes limitations on the key type
(it must implement Hash and Eq) which will prevent some future extensions
we want.
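Concretely, if the watcher "key" were later something richer than a String,
say an arbitrary matcher trait object (that specific extension is an
assumption made here for illustration, not something this commit adds), it
could not be a HashMap key at all, while a Vec places no bound on its element
type:

    // Hypothetical richer "key": an arbitrary predicate over device paths.
    // This trait is only an illustration of the kind of extension meant.
    trait UeventMatcher {
        fn is_match(&self, devpath: &str) -> bool;
    }

    struct PrefixMatcher(String);

    impl UeventMatcher for PrefixMatcher {
        fn is_match(&self, devpath: &str) -> bool {
            devpath.starts_with(self.0.as_str())
        }
    }

    fn main() {
        // A HashMap keyed by Box<dyn UeventMatcher> could be declared, but
        // insert()/get() would not compile: the boxed trait object
        // implements neither Hash nor Eq.

        // A Vec has no trait bounds on its element type, so this works:
        let mut watchers: Vec<Option<Box<dyn UeventMatcher>>> = Vec::new();
        let m: Box<dyn UeventMatcher> =
            Box::new(PrefixMatcher("/devices/pci".to_string()));
        watchers.push(Some(m));
        assert!(watchers[0].as_ref().unwrap().is_match("/devices/pci0000:00"));
    }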

So, replace the HashMap with a Vec<Option<>>.  We need the Option<> so that
we can remove entries by index (removing them from the Vec completely would
change the indices of other entries, possibly breaking concurrent work).
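The resulting pattern, in isolation (a generic sketch, not the agent's actual
types):

    // Entries are removed by index with Option::take(), so the indices
    // handed out to other (possibly concurrent) tasks never shift.
    struct Slots<T> {
        entries: Vec<Option<T>>,
    }

    impl<T> Slots<T> {
        fn new() -> Self {
            Slots { entries: Vec::new() }
        }

        // Register a value and return a stable index for later removal.
        fn add(&mut self, value: T) -> usize {
            let idx = self.entries.len();
            self.entries.push(Some(value));
            idx
        }

        // Remove by index; the slot stays in place as None, so other
        // indices remain valid.  (Vec::remove() would shift every later
        // entry down by one.)
        fn take(&mut self, idx: usize) -> Option<T> {
            self.entries[idx].take()
        }
    }

    fn main() {
        let mut slots = Slots::new();
        let a = slots.add("watcher-a");
        let b = slots.add("watcher-b");
        assert_eq!(slots.take(a), Some("watcher-a"));
        // Index b is still valid even after a was removed.
        assert_eq!(slots.take(b), Some("watcher-b"));
    }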

This does mean that the vector will keep growing as we watch for different
events during startup.  However, we don't expect the number of device
events we watch for during a run to be very large, so that shouldn't be
a problem.  We can optimize this later if it ever becomes one.
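If the growth ever did matter, one obvious later optimization (not part of
this commit) would be to reuse vacated slots on insert instead of always
pushing:

    // Hypothetical slot-reuse variant, sketched only to show the growth is
    // easy to address later; the commit deliberately does not do this.
    fn add_reusing_slots<T>(entries: &mut Vec<Option<T>>, value: T) -> usize {
        if let Some(idx) = entries.iter().position(|e| e.is_none()) {
            entries[idx] = Some(value);
            idx
        } else {
            entries.push(Some(value));
            entries.len() - 1
        }
    }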

Signed-off-by: David Gibson <david@gibson.dropbear.id.au>
Author: David Gibson
Date:   2021-03-04 14:07:53 +11:00
Parent: 91e0ef5c90
Commit: 55ed2ddd07

3 changed files with 31 additions and 44 deletions


@@ -103,7 +103,8 @@ async fn get_device_name(sandbox: &Arc<Mutex<Sandbox>>, dev_addr: &str) -> Resul
     // Note this is done inside the lock, not to miss any events from the
     // global udev listener.
     let (tx, rx) = tokio::sync::oneshot::channel::<Uevent>();
-    sb.dev_watcher.insert(dev_addr.to_string(), tx);
+    let idx = sb.uevent_watchers.len();
+    sb.uevent_watchers.push(Some((dev_addr.to_string(), tx)));
     drop(sb); // unlock

     info!(sl!(), "Waiting on channel for device notification\n");
@@ -113,7 +114,7 @@ async fn get_device_name(sandbox: &Arc<Mutex<Sandbox>>, dev_addr: &str) -> Resul
         Ok(v) => v?,
         Err(_) => {
             let mut sb = sandbox.lock().await;
-            sb.dev_watcher.remove_entry(dev_addr);
+            sb.uevent_watchers[idx].take();

             return Err(anyhow!(
                 "Timeout reached after {:?} waiting for device {}",
@@ -804,18 +805,15 @@ mod tests {
         tokio::spawn(async move {
             loop {
                 let mut sb = watcher_sandbox.lock().await;
-                let matched_key = sb
-                    .dev_watcher
-                    .keys()
-                    .filter(|dev_addr| devpath.contains(*dev_addr))
-                    .cloned()
-                    .next();
-
-                if let Some(k) = matched_key {
-                    let sender = sb.dev_watcher.remove(&k).unwrap();
-                    let _ = sender.send(uev);
-                    return;
+                for w in &mut sb.uevent_watchers {
+                    if let Some((dev_addr, _)) = w {
+                        if devpath.contains(dev_addr.as_str()) {
+                            let (_, sender) = w.take().unwrap();
+                            let _ = sender.send(uev);
+                            return;
+                        }
+                    }
                 }
                 drop(sb); // unlock
             }
         });


@@ -26,6 +26,7 @@ use std::path::Path;
 use std::sync::Arc;
 use std::{thread, time};
 use tokio::sync::mpsc::{channel, Receiver, Sender};
+use tokio::sync::oneshot;
 use tokio::sync::Mutex;

 #[derive(Debug)]
@@ -38,7 +39,7 @@ pub struct Sandbox {
     pub mounts: Vec<String>,
     pub container_mounts: HashMap<String, Vec<String>>,
     pub uevent_map: HashMap<String, Uevent>,
-    pub dev_watcher: HashMap<String, tokio::sync::oneshot::Sender<Uevent>>,
+    pub uevent_watchers: Vec<Option<(String, oneshot::Sender<Uevent>)>>,
     pub shared_utsns: Namespace,
     pub shared_ipcns: Namespace,
     pub sandbox_pidns: Option<Namespace>,
@@ -68,7 +69,7 @@ impl Sandbox {
             mounts: Vec::new(),
             container_mounts: HashMap::new(),
             uevent_map: HashMap::new(),
-            dev_watcher: HashMap::new(),
+            uevent_watchers: Vec::new(),
             shared_utsns: Namespace::new(&logger),
             shared_ipcns: Namespace::new(&logger),
             sandbox_pidns: None,


@@ -71,40 +71,28 @@ impl Uevent {
         sb.uevent_map.insert(self.devpath.clone(), self.clone());

         // Notify watchers that are interested in the udev event.
-        // Close the channel after watcher has been notified.
-        let devpath = self.devpath.clone();
-        let keys: Vec<_> = sb
-            .dev_watcher
-            .keys()
-            .filter(|dev_addr| {
-                let pci_p = format!("{}{}", pci_root_bus_path, *dev_addr);
-
-                // blk block device
-                devpath.starts_with(pci_p.as_str()) ||
-                // scsi block device
-                {
-                    (*dev_addr).ends_with(SCSI_BLOCK_SUFFIX) &&
-                        devpath.contains(*dev_addr)
-                } ||
-                // nvdimm/pmem device
-                {
-                    let pmem_suffix = format!("/{}/{}", SCSI_BLOCK_SUFFIX, self.devname);
-                    devpath.starts_with(ACPI_DEV_PATH) &&
-                        devpath.ends_with(pmem_suffix.as_str()) &&
-                        dev_addr.ends_with(pmem_suffix.as_str())
-                }
-            })
-            .cloned()
-            .collect();
-
-        for k in keys {
-            // unwrap() is safe because logic above ensures k exists
-            // in the map, and it's locked so no-one else can change
-            // that
-            let sender = sb.dev_watcher.remove(&k).unwrap();
-            let _ = sender.send(self.clone());
+        for watch in &mut sb.uevent_watchers {
+            if let Some((dev_addr, _)) = watch {
+                let pci_p = format!("{}{}", pci_root_bus_path, dev_addr);
+                let pmem_suffix = format!("/{}/{}", SCSI_BLOCK_SUFFIX, self.devname);
+
+                if self.devpath.starts_with(pci_p.as_str()) || // blk block device
+                    ( // scsi block device
+                        dev_addr.ends_with(SCSI_BLOCK_SUFFIX) &&
+                            self.devpath.contains(dev_addr.as_str())
+                    ) ||
+                    ( // nvdimm/pmem device
+                        self.devpath.starts_with(ACPI_DEV_PATH) &&
+                            self.devpath.ends_with(pmem_suffix.as_str()) &&
+                            dev_addr.ends_with(pmem_suffix.as_str())
+                    )
+                {
+                    let (_, sender) = watch.take().unwrap();
+                    let _ = sender.send(self.clone());
+                }
+            }
         }
     }

     async fn process(&self, logger: &Logger, sandbox: &Arc<Mutex<Sandbox>>) {
         if self.is_block_add_event() {