Merge pull request #1492 from dgibson/uevent

Make uevent watching mechanism more flexible
This commit is contained in:
GabyCT 2021-04-07 10:15:33 -05:00 committed by GitHub
commit 81bcded9a3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 240 additions and 117 deletions

View File

@ -17,7 +17,7 @@ use crate::linux_abi::*;
use crate::mount::{DRIVER_BLK_TYPE, DRIVER_MMIO_BLK_TYPE, DRIVER_NVDIMM_TYPE, DRIVER_SCSI_TYPE};
use crate::pci;
use crate::sandbox::Sandbox;
use crate::{AGENT_CONFIG, GLOBAL_DEVICE_WATCHER};
use crate::uevent::{wait_for_uevent, Uevent, UeventMatcher};
use anyhow::{anyhow, Result};
use oci::{LinuxDeviceCgroup, LinuxResources, Spec};
use protocols::agent::Device;
@ -87,46 +87,55 @@ fn pcipath_to_sysfs(root_bus_sysfs: &str, pcipath: &pci::Path) -> Result<String>
Ok(relpath)
}
async fn get_device_name(sandbox: &Arc<Mutex<Sandbox>>, dev_addr: &str) -> Result<String> {
// Keep the same lock order as uevent::handle_block_add_event(), otherwise it may cause deadlock.
let mut w = GLOBAL_DEVICE_WATCHER.lock().await;
let sb = sandbox.lock().await;
for (key, value) in sb.pci_device_map.iter() {
if key.contains(dev_addr) {
info!(sl!(), "Device {} found in pci device map", dev_addr);
return Ok(format!("{}/{}", SYSTEM_DEV_PATH, value));
#[derive(Debug, Clone)]
struct DevAddrMatcher {
dev_addr: String,
}
impl DevAddrMatcher {
fn new(dev_addr: &str) -> DevAddrMatcher {
DevAddrMatcher {
dev_addr: dev_addr.to_string(),
}
}
drop(sb);
}
// If device is not found in the device map, hotplug event has not
// been received yet, create and add channel to the watchers map.
// The key of the watchers map is the device we are interested in.
// Note this is done inside the lock, not to miss any events from the
// global udev listener.
let (tx, rx) = tokio::sync::oneshot::channel::<String>();
w.insert(dev_addr.to_string(), Some(tx));
drop(w);
impl UeventMatcher for DevAddrMatcher {
fn is_match(&self, uev: &Uevent) -> bool {
let pci_root_bus_path = create_pci_root_bus_path();
let pci_p = format!("{}{}", pci_root_bus_path, self.dev_addr);
let pmem_suffix = format!("/{}/{}", SCSI_BLOCK_SUFFIX, uev.devname);
info!(sl!(), "Waiting on channel for device notification\n");
let hotplug_timeout = AGENT_CONFIG.read().await.hotplug_timeout;
uev.subsystem == "block"
&& {
uev.devpath.starts_with(pci_root_bus_path.as_str())
|| uev.devpath.starts_with(ACPI_DEV_PATH) // NVDIMM/PMEM devices
}
&& !uev.devname.is_empty()
&& {
// blk block device
uev.devpath.starts_with(pci_p.as_str())
// scsi block device
|| (
self.dev_addr.ends_with(SCSI_BLOCK_SUFFIX) &&
uev.devpath.contains(self.dev_addr.as_str())
)
// nvdimm/pmem device
|| (
uev.devpath.starts_with(ACPI_DEV_PATH) &&
uev.devpath.ends_with(pmem_suffix.as_str()) &&
self.dev_addr.ends_with(pmem_suffix.as_str())
)
}
}
}
let dev_name = match tokio::time::timeout(hotplug_timeout, rx).await {
Ok(v) => v?,
Err(_) => {
let watcher = GLOBAL_DEVICE_WATCHER.clone();
let mut w = watcher.lock().await;
w.remove_entry(dev_addr);
async fn get_device_name(sandbox: &Arc<Mutex<Sandbox>>, dev_addr: &str) -> Result<String> {
let matcher = DevAddrMatcher::new(dev_addr);
return Err(anyhow!(
"Timeout reached after {:?} waiting for device {}",
hotplug_timeout,
dev_addr
));
}
};
let uev = wait_for_uevent(sandbox, matcher).await?;
Ok(format!("{}/{}", SYSTEM_DEV_PATH, &dev_name))
Ok(format!("{}/{}", SYSTEM_DEV_PATH, &uev.devname))
}
pub async fn get_scsi_device_name(
@ -430,6 +439,7 @@ pub fn update_device_cgroup(spec: &mut Spec) -> Result<()> {
#[cfg(test)]
mod tests {
use super::*;
use crate::uevent::spawn_test_watcher;
use oci::Linux;
use tempfile::tempdir;
@ -776,4 +786,39 @@ mod tests {
let relpath = pcipath_to_sysfs(rootbuspath, &path234);
assert_eq!(relpath.unwrap(), "/0000:00:02.0/0000:01:03.0/0000:02:04.0");
}
#[tokio::test]
async fn test_get_device_name() {
let devname = "vda";
let root_bus = create_pci_root_bus_path();
let relpath = "/0000:00:0a.0/0000:03:0b.0";
let devpath = format!("{}{}/virtio4/block/{}", root_bus, relpath, devname);
let mut uev = crate::uevent::Uevent::default();
uev.action = crate::linux_abi::U_EVENT_ACTION_ADD.to_string();
uev.subsystem = "block".to_string();
uev.devpath = devpath.clone();
uev.devname = devname.to_string();
let logger = slog::Logger::root(slog::Discard, o!());
let sandbox = Arc::new(Mutex::new(Sandbox::new(&logger).unwrap()));
let mut sb = sandbox.lock().await;
sb.uevent_map.insert(devpath.clone(), uev);
drop(sb); // unlock
let name = get_device_name(&sandbox, relpath).await;
assert!(name.is_ok(), "{}", name.unwrap_err());
assert_eq!(name.unwrap(), format!("{}/{}", SYSTEM_DEV_PATH, devname));
let mut sb = sandbox.lock().await;
let uev = sb.uevent_map.remove(&devpath).unwrap();
drop(sb); // unlock
spawn_test_watcher(sandbox.clone(), uev);
let name = get_device_name(&sandbox, relpath).await;
assert!(name.is_ok(), "{}", name.unwrap_err());
assert_eq!(name.unwrap(), format!("{}/{}", SYSTEM_DEV_PATH, devname));
}
}

View File

@ -28,7 +28,6 @@ use nix::sys::select::{select, FdSet};
use nix::sys::socket::{self, AddressFamily, SockAddr, SockFlag, SockType};
use nix::sys::wait;
use nix::unistd::{self, close, dup, dup2, fork, setsid, ForkResult};
use std::collections::HashMap;
use std::env;
use std::ffi::{CStr, CString, OsStr};
use std::fs::{self, File};
@ -72,7 +71,6 @@ use rustjail::pipestream::PipeStream;
use tokio::{
io::AsyncWrite,
sync::{
oneshot::Sender,
watch::{channel, Receiver},
Mutex, RwLock,
},
@ -89,8 +87,6 @@ const CONSOLE_PATH: &str = "/dev/console";
const DEFAULT_BUF_SIZE: usize = 8 * 1024;
lazy_static! {
static ref GLOBAL_DEVICE_WATCHER: Arc<Mutex<HashMap<String, Option<Sender<String>>>>> =
Arc::new(Mutex::new(HashMap::new()));
static ref AGENT_CONFIG: Arc<RwLock<AgentConfig>> =
Arc::new(RwLock::new(config::AgentConfig::new()));
}

View File

@ -8,6 +8,7 @@ use crate::mount::{get_mount_fs_type, remove_mounts, TYPE_ROOTFS};
use crate::namespace::Namespace;
use crate::netlink::Handle;
use crate::network::Network;
use crate::uevent::{Uevent, UeventMatcher};
use anyhow::{anyhow, Context, Result};
use libc::pid_t;
use oci::{Hook, Hooks};
@ -25,8 +26,11 @@ use std::path::Path;
use std::sync::Arc;
use std::{thread, time};
use tokio::sync::mpsc::{channel, Receiver, Sender};
use tokio::sync::oneshot;
use tokio::sync::Mutex;
type UeventWatcher = (Box<dyn UeventMatcher>, oneshot::Sender<Uevent>);
#[derive(Debug)]
pub struct Sandbox {
pub logger: Logger,
@ -36,7 +40,8 @@ pub struct Sandbox {
pub network: Network,
pub mounts: Vec<String>,
pub container_mounts: HashMap<String, Vec<String>>,
pub pci_device_map: HashMap<String, String>,
pub uevent_map: HashMap<String, Uevent>,
pub uevent_watchers: Vec<Option<UeventWatcher>>,
pub shared_utsns: Namespace,
pub shared_ipcns: Namespace,
pub sandbox_pidns: Option<Namespace>,
@ -65,7 +70,8 @@ impl Sandbox {
containers: HashMap::new(),
mounts: Vec::new(),
container_mounts: HashMap::new(),
pci_device_map: HashMap::new(),
uevent_map: HashMap::new(),
uevent_watchers: Vec::new(),
shared_utsns: Namespace::new(&logger),
shared_ipcns: Namespace::new(&logger),
sandbox_pidns: None,

View File

@ -6,26 +6,38 @@
use crate::device::online_device;
use crate::linux_abi::*;
use crate::sandbox::Sandbox;
use crate::GLOBAL_DEVICE_WATCHER;
use crate::AGENT_CONFIG;
use slog::Logger;
use anyhow::Result;
use anyhow::{anyhow, Result};
use netlink_sys::{protocols, SocketAddr, TokioSocket};
use nix::errno::Errno;
use std::fmt::Debug;
use std::os::unix::io::FromRawFd;
use std::sync::Arc;
use tokio::select;
use tokio::sync::watch::Receiver;
use tokio::sync::Mutex;
#[derive(Debug, Default)]
struct Uevent {
action: String,
devpath: String,
devname: String,
subsystem: String,
// Convenience macro to obtain the scope logger
macro_rules! sl {
() => {
slog_scope::logger().new(o!("subsystem" => "uevent"))
};
}
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct Uevent {
pub action: String,
pub devpath: String,
pub devname: String,
pub subsystem: String,
seqnum: String,
interface: String,
pub interface: String,
}
pub trait UeventMatcher: Sync + Send + Debug + 'static {
fn is_match(&self, uev: &Uevent) -> bool;
}
impl Uevent {
@ -52,89 +64,87 @@ impl Uevent {
event
}
// Check whether this is a block device hot-add event.
fn is_block_add_event(&self) -> bool {
let pci_root_bus_path = create_pci_root_bus_path();
self.action == U_EVENT_ACTION_ADD
&& self.subsystem == "block"
&& {
self.devpath.starts_with(pci_root_bus_path.as_str())
|| self.devpath.starts_with(ACPI_DEV_PATH) // NVDIMM/PMEM devices
}
&& !self.devname.is_empty()
}
async fn process_add(&self, logger: &Logger, sandbox: &Arc<Mutex<Sandbox>>) {
// Special case for memory hot-adds first
let online_path = format!("{}/{}/online", SYSFS_DIR, &self.devpath);
if online_path.starts_with(SYSFS_MEMORY_ONLINE_PATH) {
let _ = online_device(online_path.as_ref()).map_err(|e| {
error!(
*logger,
"failed to online device";
"device" => &self.devpath,
"error" => format!("{}", e),
)
});
return;
}
async fn handle_block_add_event(&self, sandbox: &Arc<Mutex<Sandbox>>) {
let pci_root_bus_path = create_pci_root_bus_path();
// Keep the same lock order as device::get_device_name(), otherwise it may cause deadlock.
let watcher = GLOBAL_DEVICE_WATCHER.clone();
let mut w = watcher.lock().await;
let mut sb = sandbox.lock().await;
// Add the device node name to the pci device map.
sb.pci_device_map
.insert(self.devpath.clone(), self.devname.clone());
// Record the event by sysfs path
sb.uevent_map.insert(self.devpath.clone(), self.clone());
// Notify watchers that are interested in the udev event.
// Close the channel after watcher has been notified.
let devpath = self.devpath.clone();
let empties: Vec<_> = w
.iter_mut()
.filter(|(dev_addr, _)| {
let pci_p = format!("{}/{}", pci_root_bus_path, *dev_addr);
// blk block device
devpath.starts_with(pci_p.as_str()) ||
// scsi block device
{
(*dev_addr).ends_with(SCSI_BLOCK_SUFFIX) &&
devpath.contains(*dev_addr)
} ||
// nvdimm/pmem device
{
let pmem_suffix = format!("/{}/{}", SCSI_BLOCK_SUFFIX, self.devname);
devpath.starts_with(ACPI_DEV_PATH) &&
devpath.ends_with(pmem_suffix.as_str()) &&
dev_addr.ends_with(pmem_suffix.as_str())
for watch in &mut sb.uevent_watchers {
if let Some((matcher, _)) = watch {
if matcher.is_match(&self) {
let (_, sender) = watch.take().unwrap();
let _ = sender.send(self.clone());
}
})
.map(|(k, sender)| {
let devname = self.devname.clone();
let sender = sender.take().unwrap();
let _ = sender.send(devname);
k.clone()
})
.collect();
// Remove notified nodes from the watcher map.
for empty in empties {
w.remove(&empty);
}
}
}
async fn process(&self, logger: &Logger, sandbox: &Arc<Mutex<Sandbox>>) {
if self.is_block_add_event() {
return self.handle_block_add_event(sandbox).await;
} else if self.action == U_EVENT_ACTION_ADD {
let online_path = format!("{}/{}/online", SYSFS_DIR, &self.devpath);
// It's a memory hot-add event.
if online_path.starts_with(SYSFS_MEMORY_ONLINE_PATH) {
let _ = online_device(online_path.as_ref()).map_err(|e| {
error!(
*logger,
"failed to online device";
"device" => &self.devpath,
"error" => format!("{}", e),
)
});
return;
}
if self.action == U_EVENT_ACTION_ADD {
return self.process_add(logger, sandbox).await;
}
debug!(*logger, "ignoring event"; "uevent" => format!("{:?}", self));
}
}
pub async fn wait_for_uevent(
sandbox: &Arc<Mutex<Sandbox>>,
matcher: impl UeventMatcher,
) -> Result<Uevent> {
let mut sb = sandbox.lock().await;
for uev in sb.uevent_map.values() {
if matcher.is_match(uev) {
info!(sl!(), "Device {:?} found in pci device map", uev);
return Ok(uev.clone());
}
}
// If device is not found in the device map, hotplug event has not
// been received yet, create and add channel to the watchers map.
// The key of the watchers map is the device we are interested in.
// Note this is done inside the lock, not to miss any events from the
// global udev listener.
let (tx, rx) = tokio::sync::oneshot::channel::<Uevent>();
let idx = sb.uevent_watchers.len();
sb.uevent_watchers.push(Some((Box::new(matcher), tx)));
drop(sb); // unlock
info!(sl!(), "Waiting on channel for uevent notification\n");
let hotplug_timeout = AGENT_CONFIG.read().await.hotplug_timeout;
let uev = match tokio::time::timeout(hotplug_timeout, rx).await {
Ok(v) => v?,
Err(_) => {
let mut sb = sandbox.lock().await;
let matcher = sb.uevent_watchers[idx].take().unwrap().0;
return Err(anyhow!(
"Timeout after {:?} waiting for uevent {:?}",
hotplug_timeout,
&matcher
));
}
};
Ok(uev)
}
pub async fn watch_uevents(
sandbox: Arc<Mutex<Sandbox>>,
mut shutdown: Receiver<bool>,
@ -199,3 +209,69 @@ pub async fn watch_uevents(
Ok(())
}
// Used in the device module unit tests
#[cfg(test)]
pub(crate) fn spawn_test_watcher(sandbox: Arc<Mutex<Sandbox>>, uev: Uevent) {
tokio::spawn(async move {
loop {
let mut sb = sandbox.lock().await;
for w in &mut sb.uevent_watchers {
if let Some((matcher, _)) = w {
if matcher.is_match(&uev) {
let (_, sender) = w.take().unwrap();
let _ = sender.send(uev);
return;
}
}
}
drop(sb); // unlock
}
});
}
#[cfg(test)]
mod tests {
use super::*;
#[derive(Debug, Clone, Copy)]
struct AlwaysMatch();
impl UeventMatcher for AlwaysMatch {
fn is_match(&self, _: &Uevent) -> bool {
true
}
}
#[tokio::test]
async fn test_wait_for_uevent() {
let mut uev = crate::uevent::Uevent::default();
uev.action = crate::linux_abi::U_EVENT_ACTION_ADD.to_string();
uev.subsystem = "test".to_string();
uev.devpath = "/test/sysfs/path".to_string();
uev.devname = "testdevname".to_string();
let matcher = AlwaysMatch();
let logger = slog::Logger::root(slog::Discard, o!());
let sandbox = Arc::new(Mutex::new(Sandbox::new(&logger).unwrap()));
let mut sb = sandbox.lock().await;
sb.uevent_map.insert(uev.devpath.clone(), uev.clone());
drop(sb); // unlock
let uev2 = wait_for_uevent(&sandbox, matcher).await;
assert!(uev2.is_ok());
assert_eq!(uev2.unwrap(), uev);
let mut sb = sandbox.lock().await;
sb.uevent_map.remove(&uev.devpath).unwrap();
drop(sb); // unlock
spawn_test_watcher(sandbox.clone(), uev.clone());
let uev2 = wait_for_uevent(&sandbox, matcher).await;
assert!(uev2.is_ok());
assert_eq!(uev2.unwrap(), uev);
}
}