diff --git a/src/agent/src/mount.rs b/src/agent/src/mount.rs index 5b023b3393..1c2b2dab23 100644 --- a/src/agent/src/mount.rs +++ b/src/agent/src/mount.rs @@ -4,22 +4,18 @@ // use std::collections::HashMap; -use std::ffi::CString; use std::fs; use std::fs::File; -use std::io; use std::io::{BufRead, BufReader}; use std::iter; use std::os::unix::fs::{MetadataExt, PermissionsExt}; use std::path::Path; -use std::ptr::null; use std::str::FromStr; use std::sync::Arc; use tokio::sync::Mutex; -use libc::{c_void, mount}; -use nix::mount::{self, MsFlags}; +use nix::mount::MsFlags; use nix::unistd::Gid; use regex::Regex; @@ -149,96 +145,53 @@ pub const STORAGE_HANDLER_LIST: &[&str] = &[ DRIVER_WATCHABLE_BIND_TYPE, ]; -#[derive(Debug, Clone)] -pub struct BareMount<'a> { - source: &'a str, - destination: &'a str, - fs_type: &'a str, +#[instrument] +pub fn baremount( + source: &str, + destination: &str, + fs_type: &str, flags: MsFlags, - options: &'a str, - logger: Logger, -} + options: &str, + logger: &Logger, +) -> Result<()> { + let logger = logger.new(o!("subsystem" => "baremount")); -// mount mounts a source in to a destination. This will do some bookkeeping: -// * evaluate all symlinks -// * ensure the source exists -impl<'a> BareMount<'a> { - #[instrument] - pub fn new( - s: &'a str, - d: &'a str, - fs_type: &'a str, - flags: MsFlags, - options: &'a str, - logger: &Logger, - ) -> Self { - BareMount { - source: s, - destination: d, - fs_type, - flags, - options, - logger: logger.new(o!("subsystem" => "baremount")), - } + if source.is_empty() { + return Err(anyhow!("need mount source")); } - #[instrument] - pub fn mount(&self) -> Result<()> { - let source; - let dest; - let fs_type; - let mut options = null(); - let cstr_options: CString; - let cstr_source: CString; - let cstr_dest: CString; - let cstr_fs_type: CString; - - if self.source.is_empty() { - return Err(anyhow!("need mount source")); - } - - if self.destination.is_empty() { - return Err(anyhow!("need mount destination")); - } - - cstr_source = CString::new(self.source)?; - source = cstr_source.as_ptr(); - - cstr_dest = CString::new(self.destination)?; - dest = cstr_dest.as_ptr(); - - if self.fs_type.is_empty() { - return Err(anyhow!("need mount FS type")); - } - - cstr_fs_type = CString::new(self.fs_type)?; - fs_type = cstr_fs_type.as_ptr(); - - if !self.options.is_empty() { - cstr_options = CString::new(self.options)?; - options = cstr_options.as_ptr() as *const c_void; - } - - info!( - self.logger, - "mount source={:?}, dest={:?}, fs_type={:?}, options={:?}", - self.source, - self.destination, - self.fs_type, - self.options - ); - let rc = unsafe { mount(source, dest, fs_type, self.flags.bits(), options) }; - - if rc < 0 { - return Err(anyhow!( - "failed to mount {:?} to {:?}, with error: {}", - self.source, - self.destination, - io::Error::last_os_error() - )); - } - Ok(()) + if destination.is_empty() { + return Err(anyhow!("need mount destination")); } + + if fs_type.is_empty() { + return Err(anyhow!("need mount FS type")); + } + + info!( + logger, + "mount source={:?}, dest={:?}, fs_type={:?}, options={:?}", + source, + destination, + fs_type, + options + ); + + nix::mount::mount( + Some(source), + destination, + Some(fs_type), + flags, + Some(options), + ) + .map_err(|e| { + anyhow!( + "failed to mount {:?} to {:?}, with error: {}", + source, + destination, + e + ) + }) } #[instrument] @@ -486,17 +439,14 @@ fn mount_storage(logger: &Logger, storage: &Storage) -> Result<()> { return Ok(()); } - match storage.fstype.as_str() { - DRIVER_9P_TYPE | DRIVER_VIRTIOFS_TYPE => { - let dest_path = Path::new(storage.mount_point.as_str()); - if !dest_path.exists() { - fs::create_dir_all(dest_path).context("Create mount destination failed")?; - } - } - _ => { - ensure_destination_exists(storage.mount_point.as_str(), storage.fstype.as_str())?; - } + let mount_path = Path::new(&storage.mount_point); + let src_path = Path::new(&storage.source); + if storage.fstype == "bind" && !src_path.is_dir() { + ensure_destination_file_exists(mount_path) + } else { + fs::create_dir_all(mount_path).map_err(anyhow::Error::from) } + .context("Could not create mountpoint")?; let options_vec = storage.options.to_vec(); let options_vec = options_vec.iter().map(String::as_str).collect(); @@ -509,16 +459,14 @@ fn mount_storage(logger: &Logger, storage: &Storage) -> Result<()> { "mount-options" => options.as_str(), ); - let bare_mount = BareMount::new( + baremount( storage.source.as_str(), storage.mount_point.as_str(), storage.fstype.as_str(), flags, options.as_str(), &logger, - ); - - bare_mount.mount() + ) } /// Looks for `mount_point` entry in the /proc/mounts. @@ -637,11 +585,9 @@ fn mount_to_rootfs(logger: &Logger, m: &InitMount) -> Result<()> { let (flags, options) = parse_mount_flags_and_options(options_vec); - let bare_mount = BareMount::new(m.src, m.dest, m.fstype, flags, options.as_str(), logger); - fs::create_dir_all(Path::new(m.dest)).context("could not create directory")?; - bare_mount.mount().or_else(|e| { + baremount(m.src, m.dest, m.fstype, flags, &options, logger).or_else(|e| { if m.src != "dev" { return Err(e); } @@ -816,32 +762,27 @@ pub fn cgroups_mount(logger: &Logger, unified_cgroup_hierarchy: bool) -> Result< #[instrument] pub fn remove_mounts(mounts: &[String]) -> Result<()> { for m in mounts.iter() { - mount::umount(m.as_str()).context(format!("failed to umount {:?}", m))?; + nix::mount::umount(m.as_str()).context(format!("failed to umount {:?}", m))?; } Ok(()) } -// ensure_destination_exists will recursively create a given mountpoint. If directories -// are created, their permissions are initialized to mountPerm(0755) #[instrument] -fn ensure_destination_exists(destination: &str, fs_type: &str) -> Result<()> { - let d = Path::new(destination); - if d.exists() { +fn ensure_destination_file_exists(path: &Path) -> Result<()> { + if path.is_file() { return Ok(()); - } - let dir = d - .parent() - .ok_or_else(|| anyhow!("mount destination {} doesn't exist", destination))?; - - if !dir.exists() { - fs::create_dir_all(dir).context(format!("create dir all {:?}", dir))?; + } else if path.exists() { + return Err(anyhow!("{:?} exists but is not a regular file", path)); } - if fs_type != "bind" || d.is_dir() { - fs::create_dir_all(d).context(format!("create dir all {:?}", d))?; - } else { - fs::File::create(d).context(format!("create file {:?}", d))?; - } + // The only way parent() can return None is if the path is /, + // which always exists, so the test above will already have caught + // it, thus the unwrap() is safe + let dir = path.parent().unwrap(); + + fs::create_dir_all(dir).context(format!("create_dir_all {:?}", dir))?; + + fs::File::create(path).context(format!("create empty file {:?}", path))?; Ok(()) } @@ -865,8 +806,6 @@ fn parse_options(option_list: Vec) -> HashMap { mod tests { use super::*; use crate::{skip_if_not_root, skip_loop_if_not_root, skip_loop_if_root}; - use libc::umount; - use std::fs::metadata; use std::fs::File; use std::fs::OpenOptions; use std::io::Write; @@ -1006,7 +945,7 @@ mod tests { std::fs::create_dir_all(d).expect("failed to created directory"); } - let bare_mount = BareMount::new( + let result = baremount( &src_filename, &dest_filename, d.fs_type, @@ -1015,25 +954,13 @@ mod tests { &logger, ); - let result = bare_mount.mount(); - let msg = format!("{}: result: {:?}", msg, result); if d.error_contains.is_empty() { assert!(result.is_ok(), "{}", msg); // Cleanup - unsafe { - let cstr_dest = - CString::new(dest_filename).expect("failed to convert dest to cstring"); - let umount_dest = cstr_dest.as_ptr(); - - let ret = umount(umount_dest); - - let msg = format!("{}: umount result: {:?}", msg, result); - - assert!(ret == 0, "{}", msg); - }; + nix::mount::umount(dest_filename.as_str()).unwrap(); continue; } @@ -1103,7 +1030,7 @@ mod tests { } // Create an actual mount - let bare_mount = BareMount::new( + let result = baremount( mnt_src_filename, mnt_dest_filename, "bind", @@ -1111,8 +1038,6 @@ mod tests { "", &logger, ); - - let result = bare_mount.mount(); assert!(result.is_ok(), "mount for test setup failed"); let tests = &[ @@ -1444,37 +1369,20 @@ mod tests { } #[test] - fn test_ensure_destination_exists() { + fn test_ensure_destination_file_exists() { let dir = tempdir().expect("failed to create tmpdir"); let mut testfile = dir.into_path(); testfile.push("testfile"); - let result = ensure_destination_exists(testfile.to_str().unwrap(), "bind"); + let result = ensure_destination_file_exists(&testfile); assert!(result.is_ok()); assert!(testfile.exists()); - let result = ensure_destination_exists(testfile.to_str().unwrap(), "bind"); + let result = ensure_destination_file_exists(&testfile); assert!(result.is_ok()); - let meta = metadata(testfile).unwrap(); - - assert!(meta.is_file()); - - let dir = tempdir().expect("failed to create tmpdir"); - let mut testdir = dir.into_path(); - testdir.push("testdir"); - - let result = ensure_destination_exists(testdir.to_str().unwrap(), "ext4"); - assert!(result.is_ok()); - assert!(testdir.exists()); - - let result = ensure_destination_exists(testdir.to_str().unwrap(), "ext4"); - assert!(result.is_ok()); - - //let meta = metadata(testdir.to_str().unwrap()).unwrap(); - let meta = metadata(testdir).unwrap(); - assert!(meta.is_dir()); + assert!(testfile.is_file()); } } diff --git a/src/agent/src/namespace.rs b/src/agent/src/namespace.rs index 200fc7c09d..8970e5f30c 100644 --- a/src/agent/src/namespace.rs +++ b/src/agent/src/namespace.rs @@ -13,7 +13,7 @@ use std::fs::File; use std::path::{Path, PathBuf}; use tracing::instrument; -use crate::mount::{BareMount, FLAGS}; +use crate::mount::{baremount, FLAGS}; use slog::Logger; const PERSISTENT_NS_DIR: &str = "/var/run/sandbox-ns"; @@ -129,8 +129,7 @@ impl Namespace { } }; - let bare_mount = BareMount::new(source, destination, "none", flags, "", &logger); - bare_mount.mount().map_err(|e| { + baremount(source, destination, "none", flags, "", &logger).map_err(|e| { anyhow!( "Failed to mount {} to {} with err:{:?}", source, diff --git a/src/agent/src/rpc.rs b/src/agent/src/rpc.rs index 4698e23339..775b119dfd 100644 --- a/src/agent/src/rpc.rs +++ b/src/agent/src/rpc.rs @@ -47,7 +47,7 @@ use rustjail::process::ProcessOperations; use crate::device::{add_devices, pcipath_to_sysfs, rescan_pci_bus, update_device_cgroup}; use crate::linux_abi::*; use crate::metrics::get_metrics; -use crate::mount::{add_storages, remove_mounts, BareMount, STORAGE_HANDLER_LIST}; +use crate::mount::{add_storages, baremount, remove_mounts, STORAGE_HANDLER_LIST}; use crate::namespace::{NSTYPEIPC, NSTYPEPID, NSTYPEUTS}; use crate::network::setup_guest_dns; use crate::random; @@ -1624,15 +1624,14 @@ fn setup_bundle(cid: &str, spec: &mut Spec) -> Result { let rootfs_path = bundle_path.join("rootfs"); fs::create_dir_all(&rootfs_path)?; - BareMount::new( + baremount( &spec_root.path, rootfs_path.to_str().unwrap(), "bind", MsFlags::MS_BIND, "", &sl!(), - ) - .mount()?; + )?; spec.root = Some(Root { path: rootfs_path.to_str().unwrap().to_owned(), readonly: spec_root.readonly, diff --git a/src/agent/src/sandbox.rs b/src/agent/src/sandbox.rs index 93a4fa0feb..6829eb2cf5 100644 --- a/src/agent/src/sandbox.rs +++ b/src/agent/src/sandbox.rs @@ -449,7 +449,7 @@ fn online_memory(logger: &Logger) -> Result<()> { #[cfg(test)] mod tests { use super::Sandbox; - use crate::{mount::BareMount, skip_if_not_root}; + use crate::{mount::baremount, skip_if_not_root}; use anyhow::Error; use nix::mount::MsFlags; use oci::{Linux, Root, Spec}; @@ -461,8 +461,7 @@ mod tests { use tempfile::Builder; fn bind_mount(src: &str, dst: &str, logger: &Logger) -> Result<(), Error> { - let baremount = BareMount::new(src, dst, "bind", MsFlags::MS_BIND, "", logger); - baremount.mount() + baremount(src, dst, "bind", MsFlags::MS_BIND, "", logger) } #[tokio::test] diff --git a/src/agent/src/watcher.rs b/src/agent/src/watcher.rs index b6ec528988..840c847dec 100644 --- a/src/agent/src/watcher.rs +++ b/src/agent/src/watcher.rs @@ -20,7 +20,7 @@ use tokio::sync::Mutex; use tokio::task; use tokio::time::{self, Duration}; -use crate::mount::BareMount; +use crate::mount::baremount; use crate::protocols::agent as protos; /// The maximum number of file system entries agent will watch for each mount. @@ -314,16 +314,14 @@ impl SandboxStorages { } } - match BareMount::new( + match baremount( entry.source_mount_point.to_str().unwrap(), entry.target_mount_point.to_str().unwrap(), "bind", MsFlags::MS_BIND, "bind", logger, - ) - .mount() - { + ) { Ok(_) => { entry.watch = false; info!(logger, "watchable mount replaced with bind mount") @@ -427,15 +425,14 @@ impl BindWatcher { async fn mount(&self, logger: &Logger) -> Result<()> { fs::create_dir_all(WATCH_MOUNT_POINT_PATH).await?; - BareMount::new( + baremount( "tmpfs", WATCH_MOUNT_POINT_PATH, "tmpfs", MsFlags::empty(), "", logger, - ) - .mount()?; + )?; Ok(()) }