diff --git a/src/agent/src/device.rs b/src/agent/src/device.rs index 7e8b221ca1..971b31d135 100644 --- a/src/agent/src/device.rs +++ b/src/agent/src/device.rs @@ -17,7 +17,7 @@ use crate::linux_abi::*; use crate::mount::{DRIVER_BLK_TYPE, DRIVER_MMIO_BLK_TYPE, DRIVER_NVDIMM_TYPE, DRIVER_SCSI_TYPE}; use crate::pci; use crate::sandbox::Sandbox; -use crate::uevent::Uevent; +use crate::uevent::{Uevent, UeventMatcher}; use crate::AGENT_CONFIG; use anyhow::{anyhow, Result}; use oci::{LinuxDeviceCgroup, LinuxResources, Spec}; @@ -88,7 +88,52 @@ fn pcipath_to_sysfs(root_bus_sysfs: &str, pcipath: &pci::Path) -> Result Ok(relpath) } +#[derive(Debug)] +struct DevAddrMatcher { + dev_addr: String, +} + +impl DevAddrMatcher { + fn new(dev_addr: &str) -> DevAddrMatcher { + DevAddrMatcher { + dev_addr: dev_addr.to_string(), + } + } +} + +impl UeventMatcher for DevAddrMatcher { + fn is_match(&self, uev: &Uevent) -> bool { + let pci_root_bus_path = create_pci_root_bus_path(); + let pci_p = format!("{}{}", pci_root_bus_path, self.dev_addr); + let pmem_suffix = format!("/{}/{}", SCSI_BLOCK_SUFFIX, uev.devname); + + uev.subsystem == "block" + && { + uev.devpath.starts_with(pci_root_bus_path.as_str()) + || uev.devpath.starts_with(ACPI_DEV_PATH) // NVDIMM/PMEM devices + } + && !uev.devname.is_empty() + && { + // blk block device + uev.devpath.starts_with(pci_p.as_str()) + // scsi block device + || ( + self.dev_addr.ends_with(SCSI_BLOCK_SUFFIX) && + uev.devpath.contains(self.dev_addr.as_str()) + ) + // nvdimm/pmem device + || ( + uev.devpath.starts_with(ACPI_DEV_PATH) && + uev.devpath.ends_with(pmem_suffix.as_str()) && + self.dev_addr.ends_with(pmem_suffix.as_str()) + ) + } + } +} + async fn get_device_name(sandbox: &Arc>, dev_addr: &str) -> Result { + let matcher = DevAddrMatcher::new(dev_addr); + let mut sb = sandbox.lock().await; for (key, uev) in sb.uevent_map.iter() { if key.contains(dev_addr) { @@ -104,7 +149,7 @@ async fn get_device_name(sandbox: &Arc>, dev_addr: &str) -> Resul // global udev listener. let (tx, rx) = tokio::sync::oneshot::channel::(); let idx = sb.uevent_watchers.len(); - sb.uevent_watchers.push(Some((dev_addr.to_string(), tx))); + sb.uevent_watchers.push(Some((Box::new(matcher), tx))); drop(sb); // unlock info!(sl!(), "Waiting on channel for device notification\n"); @@ -783,6 +828,8 @@ mod tests { let devpath = format!("{}{}/virtio4/block/{}", root_bus, relpath, devname); let mut uev = crate::uevent::Uevent::default(); + uev.action = crate::linux_abi::U_EVENT_ACTION_ADD.to_string(); + uev.subsystem = "block".to_string(); uev.devpath = devpath.clone(); uev.devname = devname.to_string(); @@ -806,8 +853,8 @@ mod tests { loop { let mut sb = watcher_sandbox.lock().await; for w in &mut sb.uevent_watchers { - if let Some((dev_addr, _)) = w { - if devpath.contains(dev_addr.as_str()) { + if let Some((matcher, _)) = w { + if matcher.is_match(&uev) { let (_, sender) = w.take().unwrap(); let _ = sender.send(uev); return; diff --git a/src/agent/src/sandbox.rs b/src/agent/src/sandbox.rs index 02a60465bd..b71fa5534a 100644 --- a/src/agent/src/sandbox.rs +++ b/src/agent/src/sandbox.rs @@ -8,7 +8,7 @@ use crate::mount::{get_mount_fs_type, remove_mounts, TYPE_ROOTFS}; use crate::namespace::Namespace; use crate::netlink::Handle; use crate::network::Network; -use crate::uevent::Uevent; +use crate::uevent::{Uevent, UeventMatcher}; use anyhow::{anyhow, Context, Result}; use libc::pid_t; use oci::{Hook, Hooks}; @@ -29,6 +29,8 @@ use tokio::sync::mpsc::{channel, Receiver, Sender}; use tokio::sync::oneshot; use tokio::sync::Mutex; +type UeventWatcher = (Box, oneshot::Sender); + #[derive(Debug)] pub struct Sandbox { pub logger: Logger, @@ -39,7 +41,7 @@ pub struct Sandbox { pub mounts: Vec, pub container_mounts: HashMap>, pub uevent_map: HashMap, - pub uevent_watchers: Vec)>>, + pub uevent_watchers: Vec>, pub shared_utsns: Namespace, pub shared_ipcns: Namespace, pub sandbox_pidns: Option, diff --git a/src/agent/src/uevent.rs b/src/agent/src/uevent.rs index ecec3b46e5..4d1fb8db70 100644 --- a/src/agent/src/uevent.rs +++ b/src/agent/src/uevent.rs @@ -11,6 +11,7 @@ use slog::Logger; use anyhow::Result; use netlink_sys::{protocols, SocketAddr, TokioSocket}; use nix::errno::Errno; +use std::fmt::Debug; use std::os::unix::io::FromRawFd; use std::sync::Arc; use tokio::select; @@ -27,6 +28,10 @@ pub struct Uevent { pub interface: String, } +pub trait UeventMatcher: Sync + Send + Debug + 'static { + fn is_match(&self, uev: &Uevent) -> bool; +} + impl Uevent { fn new(message: &str) -> Self { let mut msg_iter = message.split('\0'); @@ -73,37 +78,8 @@ impl Uevent { // Notify watchers that are interested in the udev event. for watch in &mut sb.uevent_watchers { - if let Some((dev_addr, _)) = watch { - let pci_root_bus_path = create_pci_root_bus_path(); - let pci_p = format!("{}{}", pci_root_bus_path, dev_addr); - - let is_match = |uev: &Uevent| { - let pmem_suffix = format!("/{}/{}", SCSI_BLOCK_SUFFIX, uev.devname); - - uev.subsystem == "block" - && { - uev.devpath.starts_with(pci_root_bus_path.as_str()) - || uev.devpath.starts_with(ACPI_DEV_PATH) // NVDIMM/PMEM devices - } - && !uev.devname.is_empty() - && { - // blk block device - uev.devpath.starts_with(pci_p.as_str()) - // scsi block device - || ( - dev_addr.ends_with(SCSI_BLOCK_SUFFIX) && - uev.devpath.contains(dev_addr.as_str()) - ) - // nvdimm/pmem device - || ( - uev.devpath.starts_with(ACPI_DEV_PATH) && - uev.devpath.ends_with(pmem_suffix.as_str()) && - dev_addr.ends_with(pmem_suffix.as_str()) - ) - } - }; - - if is_match(&self) { + if let Some((matcher, _)) = watch { + if matcher.is_match(&self) { let (_, sender) = watch.take().unwrap(); let _ = sender.send(self.clone()); }