diff --git a/src/agent/src/device.rs b/src/agent/src/device.rs index c22a65795..7c73dd2ca 100644 --- a/src/agent/src/device.rs +++ b/src/agent/src/device.rs @@ -11,7 +11,7 @@ use std::fmt; use std::fs; use std::os::unix::ffi::OsStrExt; use std::os::unix::fs::MetadataExt; -use std::path::Path; +use std::path::{Path, PathBuf}; use std::str::FromStr; use std::sync::Arc; use tokio::sync::Mutex; @@ -157,20 +157,22 @@ pub fn pcipath_to_sysfs(root_bus_sysfs: &str, pcipath: &pci::Path) -> Result = fs::read_dir(&bridgebuspath)?.collect(); - if files.len() != 1 { - return Err(anyhow!( - "Expected exactly one PCI bus in {}, got {} instead", - bridgebuspath, - files.len() - )); - } - - // unwrap is safe, because of the length test above - let busfile = files.pop().unwrap()?; - bus = busfile - .file_name() - .into_string() - .map_err(|e| anyhow!("Bad filename under {}: {:?}", &bridgebuspath, e))?; + match files.pop() { + Some(busfile) if files.is_empty() => { + bus = busfile? + .file_name() + .into_string() + .map_err(|e| anyhow!("Bad filename under {}: {:?}", &bridgebuspath, e))?; + } + _ => { + return Err(anyhow!( + "Expected exactly one PCI bus in {}, got {} instead", + bridgebuspath, + // Adjust to original value as we've already popped + files.len() + 1 + )); + } + }; } Ok(relpath) @@ -218,8 +220,9 @@ impl VirtioBlkPciMatcher { fn new(relpath: &str) -> VirtioBlkPciMatcher { let root_bus = create_pci_root_bus_path(); let re = format!(r"^{}{}/virtio[0-9]+/block/", root_bus, relpath); + VirtioBlkPciMatcher { - rex: Regex::new(&re).unwrap(), + rex: Regex::new(&re).expect("BUG: failed to compile VirtioBlkPciMatcher regex"), } } } @@ -257,7 +260,7 @@ impl VirtioBlkCCWMatcher { root_bus_path, device ); VirtioBlkCCWMatcher { - rex: Regex::new(&re).unwrap(), + rex: Regex::new(&re).expect("BUG: failed to compile VirtioBlkCCWMatcher regex"), } } } @@ -413,12 +416,15 @@ fn scan_scsi_bus(scsi_addr: &str) -> Result<()> { for entry in fs::read_dir(SYSFS_SCSI_HOST_PATH)? { let host = entry?.file_name(); - let scan_path = format!( - "{}/{}/{}", - SYSFS_SCSI_HOST_PATH, - host.to_str().unwrap(), - "scan" - ); + + let host_str = host.to_str().ok_or_else(|| { + anyhow!( + "failed to convert directory entry to unicode for file {:?}", + host + ) + })?; + + let scan_path = PathBuf::from(&format!("{}/{}/{}", SYSFS_SCSI_HOST_PATH, host_str, "scan")); fs::write(scan_path, &scan_data)?; } @@ -722,25 +728,17 @@ async fn vfio_device_handler(device: &Device, sandbox: &Arc>) -> if vfio_in_guest { pci_driver_override(SYSFS_BUS_PCI_PATH, guestdev, "vfio-pci")?; - let devgroup = pci_iommu_group(SYSFS_BUS_PCI_PATH, guestdev)?; - if devgroup.is_none() { - // Devices must have an IOMMU group to be usable via VFIO - return Err(anyhow!("{} has no IOMMU group", guestdev)); + // Devices must have an IOMMU group to be usable via VFIO + let devgroup = pci_iommu_group(SYSFS_BUS_PCI_PATH, guestdev)? + .ok_or_else(|| anyhow!("{} has no IOMMU group", guestdev))?; + + if let Some(g) = group { + if g != devgroup { + return Err(anyhow!("{} is not in guest IOMMU group {}", guestdev, g)); + } } - if group.is_some() && group != devgroup { - // If PCI devices associated with the same VFIO device - // (and therefore group) in the host don't end up in - // the same group in the guest, something has gone - // horribly wrong - return Err(anyhow!( - "{} is not in guest IOMMU group {}", - guestdev, - group.unwrap() - )); - } - - group = devgroup; + group = Some(devgroup); pci_fixups.push((host, guestdev)); } @@ -748,7 +746,8 @@ async fn vfio_device_handler(device: &Device, sandbox: &Arc>) -> let dev_update = if vfio_in_guest { // If there are any devices at all, logic above ensures that group is not None - let group = group.unwrap(); + let group = group.ok_or_else(|| anyhow!("failed to get VFIO group: {:?}"))?; + let vm_path = get_vfio_device_name(sandbox, group).await?; Some(DevUpdate::from_vm_path(&vm_path, vm_path.clone())?) @@ -844,11 +843,8 @@ pub fn update_device_cgroup(spec: &mut Spec) -> Result<()> { .as_mut() .ok_or_else(|| anyhow!("Spec didn't container linux field"))?; - if linux.resources.is_none() { - linux.resources = Some(LinuxResources::default()); - } + let resources = linux.resources.get_or_insert(LinuxResources::default()); - let resources = linux.resources.as_mut().unwrap(); resources.devices.push(LinuxDeviceCgroup { allow: false, major: Some(major), diff --git a/src/agent/src/main.rs b/src/agent/src/main.rs index c2ae8c769..d1745fb01 100644 --- a/src/agent/src/main.rs +++ b/src/agent/src/main.rs @@ -113,10 +113,10 @@ async fn create_logger_task(rfd: RawFd, vsock_port: u32, shutdown: Receiver Result { AGENT_SCRAPE_COUNT.inc(); // update agent process metrics - update_agent_metrics(); + update_agent_metrics()?; // update guest os metrics update_guest_metrics(); @@ -84,23 +85,26 @@ pub fn get_metrics(_: &protocols::agent::GetMetricsRequest) -> Result { let mut buffer = Vec::new(); let encoder = TextEncoder::new(); - encoder.encode(&metric_families, &mut buffer).unwrap(); + encoder.encode(&metric_families, &mut buffer)?; - Ok(String::from_utf8(buffer).unwrap()) + Ok(String::from_utf8(buffer)?) } #[instrument] -fn update_agent_metrics() { +fn update_agent_metrics() -> Result<()> { let me = procfs::process::Process::myself(); - if let Err(err) = me { - error!(sl!(), "failed to create process instance: {:?}", err); - return; - } + let me = match me { + Ok(p) => p, + Err(e) => { + // FIXME: return Ok for all errors? + warn!(sl!(), "failed to create process instance: {:?}", e); - let me = me.unwrap(); + return Ok(()); + } + }; - let tps = procfs::ticks_per_second().unwrap(); + let tps = procfs::ticks_per_second()?; // process total time AGENT_TOTAL_TIME.set((me.stat.utime + me.stat.stime) as f64 / (tps as f64)); @@ -109,7 +113,7 @@ fn update_agent_metrics() { AGENT_TOTAL_VM.set(me.stat.vsize as f64); // Total resident set - let page_size = procfs::page_size().unwrap() as f64; + let page_size = procfs::page_size()? as f64; AGENT_TOTAL_RSS.set(me.stat.rss as f64 * page_size); // io @@ -132,11 +136,11 @@ fn update_agent_metrics() { } match me.status() { - Err(err) => { - info!(sl!(), "failed to get process status: {:?}", err); - } + Err(err) => error!(sl!(), "failed to get process status: {:?}", err), Ok(status) => set_gauge_vec_proc_status(&AGENT_PROC_STATUS, &status), } + + Ok(()) } #[instrument] diff --git a/src/agent/src/mount.rs b/src/agent/src/mount.rs index 3d55f874f..20fd0494c 100644 --- a/src/agent/src/mount.rs +++ b/src/agent/src/mount.rs @@ -139,8 +139,8 @@ pub const STORAGE_HANDLER_LIST: &[&str] = &[ #[instrument] pub fn baremount( - source: &str, - destination: &str, + source: &Path, + destination: &Path, fs_type: &str, flags: MsFlags, options: &str, @@ -148,11 +148,11 @@ pub fn baremount( ) -> Result<()> { let logger = logger.new(o!("subsystem" => "baremount")); - if source.is_empty() { + if source.as_os_str().is_empty() { return Err(anyhow!("need mount source")); } - if destination.is_empty() { + if destination.as_os_str().is_empty() { return Err(anyhow!("need mount destination")); } @@ -444,16 +444,19 @@ fn mount_storage(logger: &Logger, storage: &Storage) -> Result<()> { let options_vec = options_vec.iter().map(String::as_str).collect(); let (flags, options) = parse_mount_flags_and_options(options_vec); + let source = Path::new(&storage.source); + let mount_point = Path::new(&storage.mount_point); + info!(logger, "mounting storage"; - "mount-source:" => storage.source.as_str(), - "mount-destination" => storage.mount_point.as_str(), + "mount-source" => source.display(), + "mount-destination" => mount_point.display(), "mount-fstype" => storage.fstype.as_str(), "mount-options" => options.as_str(), ); baremount( - storage.source.as_str(), - storage.mount_point.as_str(), + source, + mount_point, storage.fstype.as_str(), flags, options.as_str(), @@ -579,7 +582,10 @@ fn mount_to_rootfs(logger: &Logger, m: &InitMount) -> Result<()> { fs::create_dir_all(Path::new(m.dest)).context("could not create directory")?; - baremount(m.src, m.dest, m.fstype, flags, &options, logger).or_else(|e| { + let source = Path::new(m.src); + let dest = Path::new(m.dest); + + baremount(source, dest, m.fstype, flags, &options, logger).or_else(|e| { if m.src != "dev" { return Err(e); } @@ -622,8 +628,7 @@ pub fn get_mount_fs_type_from_file(mount_file: &str, mount_point: &str) -> Resul let file = File::open(mount_file)?; let reader = BufReader::new(file); - let re = Regex::new(format!("device .+ mounted on {} with fstype (.+)", mount_point).as_str()) - .unwrap(); + let re = Regex::new(format!("device .+ mounted on {} with fstype (.+)", mount_point).as_str())?; // Read the file line by line using the lines() iterator from std::io::BufRead. for (_index, line) in reader.lines().enumerate() { @@ -701,20 +706,21 @@ pub fn get_cgroup_mounts( } } - if fields[0].is_empty() { + let subsystem_name = fields[0]; + + if subsystem_name.is_empty() { continue; } - if fields[0] == "devices" { + if subsystem_name == "devices" { has_device_cgroup = true; } - if let Some(value) = CGROUPS.get(&fields[0]) { - let key = CGROUPS.keys().find(|&&f| f == fields[0]).unwrap(); + if let Some((key, value)) = CGROUPS.get_key_value(subsystem_name) { cg_mounts.push(InitMount { fstype: "cgroup", src: "cgroup", - dest: *value, + dest: value, options: vec!["nosuid", "nodev", "noexec", "relatime", key], }); } @@ -767,10 +773,9 @@ fn ensure_destination_file_exists(path: &Path) -> Result<()> { return Err(anyhow!("{:?} exists but is not a regular file", path)); } - // The only way parent() can return None is if the path is /, - // which always exists, so the test above will already have caught - // it, thus the unwrap() is safe - let dir = path.parent().unwrap(); + let dir = path + .parent() + .ok_or_else(|| anyhow!("failed to find parent path for {:?}", path))?; fs::create_dir_all(dir).context(format!("create_dir_all {:?}", dir))?; @@ -937,14 +942,10 @@ mod tests { std::fs::create_dir_all(d).expect("failed to created directory"); } - let result = baremount( - &src_filename, - &dest_filename, - d.fs_type, - d.flags, - d.options, - &logger, - ); + let src = Path::new(&src_filename); + let dest = Path::new(&dest_filename); + + let result = baremount(src, dest, d.fs_type, d.flags, d.options, &logger); let msg = format!("{}: result: {:?}", msg, result); @@ -1021,15 +1022,11 @@ mod tests { .unwrap_or_else(|_| panic!("failed to create directory {}", d)); } + let src = Path::new(mnt_src_filename); + let dest = Path::new(mnt_dest_filename); + // Create an actual mount - let result = baremount( - mnt_src_filename, - mnt_dest_filename, - "bind", - MsFlags::MS_BIND, - "", - &logger, - ); + let result = baremount(src, dest, "bind", MsFlags::MS_BIND, "", &logger); assert!(result.is_ok(), "mount for test setup failed"); let tests = &[ diff --git a/src/agent/src/namespace.rs b/src/agent/src/namespace.rs index 061370a46..c821a0acb 100644 --- a/src/agent/src/namespace.rs +++ b/src/agent/src/namespace.rs @@ -104,7 +104,10 @@ impl Namespace { if let Err(err) = || -> Result<()> { let origin_ns_path = get_current_thread_ns_path(ns_type.get()); - File::open(Path::new(&origin_ns_path))?; + let source = Path::new(&origin_ns_path); + let destination = new_ns_path.as_path(); + + File::open(&source)?; // Create a new netns on the current thread. let cf = ns_type.get_flags(); @@ -115,8 +118,6 @@ impl Namespace { nix::unistd::sethostname(hostname.unwrap())?; } // Bind mount the new namespace from the current thread onto the mount point to persist it. - let source: &str = origin_ns_path.as_str(); - let destination: &str = new_ns_path.as_path().to_str().unwrap_or("none"); let mut flags = MsFlags::empty(); @@ -131,7 +132,7 @@ impl Namespace { baremount(source, destination, "none", flags, "", &logger).map_err(|e| { anyhow!( - "Failed to mount {} to {} with err:{:?}", + "Failed to mount {:?} to {:?} with err:{:?}", source, destination, e diff --git a/src/agent/src/rpc.rs b/src/agent/src/rpc.rs index 778da07fe..87cdc4f84 100644 --- a/src/agent/src/rpc.rs +++ b/src/agent/src/rpc.rs @@ -111,11 +111,18 @@ pub struct AgentService { // ^[a-zA-Z0-9][a-zA-Z0-9_.-]+$ // fn verify_cid(id: &str) -> Result<()> { - let valid = id.len() > 1 - && id.chars().next().unwrap().is_alphanumeric() - && id - .chars() - .all(|c| (c.is_alphanumeric() || ['.', '-', '_'].contains(&c))); + let mut chars = id.chars(); + + let valid = match chars.next() { + Some(first) + if first.is_alphanumeric() + && id.len() > 1 + && chars.all(|c| c.is_alphanumeric() || ['.', '-', '_'].contains(&c)) => + { + true + } + _ => false, + }; match valid { true => Ok(()), @@ -176,7 +183,7 @@ impl AgentService { update_device_cgroup(&mut oci)?; // Append guest hooks - append_guest_hooks(&s, &mut oci); + append_guest_hooks(&s, &mut oci)?; // write spec to bundle path, hooks might // read ocispec @@ -198,21 +205,14 @@ impl AgentService { LinuxContainer::new(cid.as_str(), CONTAINER_BASE, opts, &sl!())?; let pipe_size = AGENT_CONFIG.read().await.container_pipe_size; - let p = if oci.process.is_some() { - Process::new( - &sl!(), - oci.process.as_ref().unwrap(), - cid.as_str(), - true, - pipe_size, - )? + + let p = if let Some(p) = oci.process { + Process::new(&sl!(), &p, cid.as_str(), true, pipe_size)? } else { info!(sl!(), "no process configurations!"); return Err(anyhow!(nix::Error::from_errno(nix::errno::Errno::EINVAL))); }; - ctr.start(p).await?; - s.update_shared_pidns(&ctr)?; s.add_container(ctr); info!(sl!(), "created container!"); @@ -234,11 +234,17 @@ impl AgentService { ctr.exec()?; + if sid == cid { + return Ok(()); + } + // start oom event loop - if sid != cid && ctr.cgroup_manager.is_some() { - let cg_path = ctr.cgroup_manager.as_ref().unwrap().get_cg_path("memory"); - if cg_path.is_some() { - let rx = notifier::notify_oom(cid.as_str(), cg_path.unwrap()).await?; + if let Some(ref ctr) = ctr.cgroup_manager { + let cg_path = ctr.get_cg_path("memory"); + + if let Some(cg_path) = cg_path { + let rx = notifier::notify_oom(cid.as_str(), cg_path.to_string()).await?; + s.run_oom_event_monitor(rx, cid.clone()).await; } } @@ -338,14 +344,13 @@ impl AgentService { let s = self.sandbox.clone(); let mut sandbox = s.lock().await; - let process = if req.process.is_some() { - req.process.as_ref().unwrap() - } else { - return Err(anyhow!(nix::Error::from_errno(nix::errno::Errno::EINVAL))); - }; + let process = req + .process + .into_option() + .ok_or_else(|| anyhow!(nix::Error::from_errno(nix::errno::Errno::EINVAL)))?; let pipe_size = AGENT_CONFIG.read().await.container_pipe_size; - let ocip = rustjail::process_grpc_to_oci(process); + let ocip = rustjail::process_grpc_to_oci(&process); let p = Process::new(&sl!(), &ocip, exec_id.as_str(), false, pipe_size)?; let ctr = sandbox @@ -378,7 +383,12 @@ impl AgentService { let p = find_process(&mut sandbox, cid.as_str(), eid.as_str(), init)?; - let mut signal = Signal::try_from(req.signal as i32).unwrap(); + let mut signal = Signal::try_from(req.signal as i32).map_err(|e| { + anyhow!(e).context(format!( + "failed to convert {:?} to signal (container-id: {}, exec-id: {})", + req.signal, cid, eid + )) + })?; // For container initProcess, if it hasn't installed handler for "SIGTERM" signal, // it will ignore the "SIGTERM" signal sent to it, thus send it "SIGKILL" signal @@ -437,7 +447,11 @@ impl AgentService { Some(p) => p, None => { // Lost race, pick up exit code from channel - resp.status = exit_recv.recv().await.unwrap(); + resp.status = exit_recv + .recv() + .await + .ok_or_else(|| anyhow!("Failed to receive exit code"))?; + return Ok(resp); } }; @@ -479,7 +493,7 @@ impl AgentService { } }; - let writer = writer.unwrap(); + let writer = writer.ok_or_else(|| anyhow!("cannot get writer"))?; writer.lock().await.write_all(req.data.as_slice()).await?; let mut resp = WriteStreamResponse::new(); @@ -521,7 +535,7 @@ impl AgentService { return Err(anyhow!(nix::Error::from_errno(nix::errno::Errno::EINVAL))); } - let reader = reader.unwrap(); + let reader = reader.ok_or_else(|| anyhow!("cannot get stream reader"))?; tokio::select! { _ = term_exit_notifier.notified() => { @@ -639,8 +653,8 @@ impl protocols::agent_ttrpc::AgentService for AgentService { let resp = Empty::new(); - if res.is_some() { - let oci_res = rustjail::resources_grpc_to_oci(&res.unwrap()); + if let Some(res) = res.as_ref() { + let oci_res = rustjail::resources_grpc_to_oci(res); match ctr.set(oci_res) { Err(e) => { return Err(ttrpc_error(ttrpc::Code::INTERNAL, e.to_string())); @@ -800,25 +814,24 @@ impl protocols::agent_ttrpc::AgentService for AgentService { ) })?; - if p.term_master.is_none() { + if let Some(fd) = p.term_master { + unsafe { + let win = winsize { + ws_row: req.row as c_ushort, + ws_col: req.column as c_ushort, + ws_xpixel: 0, + ws_ypixel: 0, + }; + + let err = libc::ioctl(fd, TIOCSWINSZ, &win); + Errno::result(err).map(drop).map_err(|e| { + ttrpc_error(ttrpc::Code::INTERNAL, format!("ioctl error: {:?}", e)) + })?; + } + } else { return Err(ttrpc_error(ttrpc::Code::UNAVAILABLE, "no tty".to_string())); } - let fd = p.term_master.unwrap(); - unsafe { - let win = winsize { - ws_row: req.row as c_ushort, - ws_col: req.column as c_ushort, - ws_xpixel: 0, - ws_ypixel: 0, - }; - - let err = libc::ioctl(fd, TIOCSWINSZ, &win); - Errno::result(err) - .map(drop) - .map_err(|e| ttrpc_error(ttrpc::Code::INTERNAL, format!("ioctl error: {:?}", e)))?; - } - Ok(Empty::new()) } @@ -1020,12 +1033,25 @@ impl protocols::agent_ttrpc::AgentService for AgentService { let mut sandbox = s.lock().await; // destroy all containers, clean up, notify agent to exit // etc. - sandbox.destroy().await.unwrap(); + sandbox + .destroy() + .await + .map_err(|e| ttrpc_error(ttrpc::Code::INTERNAL, e.to_string()))?; // Close get_oom_event connection, // otherwise it will block the shutdown of ttrpc. sandbox.event_tx.take(); - sandbox.sender.take().unwrap().send(1).unwrap(); + sandbox + .sender + .take() + .ok_or_else(|| { + ttrpc_error( + ttrpc::Code::INTERNAL, + "failed to get sandbox sender channel".to_string(), + ) + })? + .send(1) + .map_err(|e| ttrpc_error(ttrpc::Code::INTERNAL, e.to_string()))?; Ok(Empty::new()) } @@ -1284,11 +1310,7 @@ fn get_memory_info(block_size: bool, hotplug: bool) -> Result<(u64, bool)> { match stat::stat(SYSFS_MEMORY_HOTPLUG_PROBE_PATH) { Ok(_) => plug = true, Err(e) => { - info!( - sl!(), - "hotplug memory error: {}", - e.as_errno().unwrap().desc() - ); + info!(sl!(), "hotplug memory error: {:?}", e); match e { nix::Error::Sys(errno) => match errno { Errno::ENOENT => plug = false, @@ -1364,7 +1386,7 @@ fn find_process<'a>( ctr.get_process(eid).map_err(|_| anyhow!("Invalid exec id")) } -pub fn start(s: Arc>, server_address: &str) -> TtrpcServer { +pub fn start(s: Arc>, server_address: &str) -> Result { let agent_service = Box::new(AgentService { sandbox: s }) as Box; @@ -1379,14 +1401,13 @@ pub fn start(s: Arc>, server_address: &str) -> TtrpcServer { let hservice = protocols::health_ttrpc::create_health(health_worker); let server = TtrpcServer::new() - .bind(server_address) - .unwrap() + .bind(server_address)? .register_service(aservice) .register_service(hservice); info!(sl!(), "ttRPC server started"; "address" => server_address); - server + Ok(server) } // This function updates the container namespaces configuration based on the @@ -1431,24 +1452,28 @@ fn update_container_namespaces( // the create_sandbox request or create_container request. // Else set this to empty string so that a new pid namespace is // created for the container. - if sandbox_pidns && sandbox.sandbox_pidns.is_some() { - pid_ns.path = String::from(sandbox.sandbox_pidns.as_ref().unwrap().path.as_str()); + if sandbox_pidns { + if let Some(ref pidns) = &sandbox.sandbox_pidns { + pid_ns.path = String::from(pidns.path.as_str()); + } else { + return Err(anyhow!("failed to get sandbox pidns")); + } } linux.namespaces.push(pid_ns); Ok(()) } -fn append_guest_hooks(s: &Sandbox, oci: &mut Spec) { - if s.hooks.is_none() { - return; +fn append_guest_hooks(s: &Sandbox, oci: &mut Spec) -> Result<()> { + if let Some(ref guest_hooks) = s.hooks { + let mut hooks = oci.hooks.take().unwrap_or_default(); + hooks.prestart.append(&mut guest_hooks.prestart.clone()); + hooks.poststart.append(&mut guest_hooks.poststart.clone()); + hooks.poststop.append(&mut guest_hooks.poststop.clone()); + oci.hooks = Some(hooks); } - let guest_hooks = s.hooks.as_ref().unwrap(); - let mut hooks = oci.hooks.take().unwrap_or_default(); - hooks.prestart.append(&mut guest_hooks.prestart.clone()); - hooks.poststart.append(&mut guest_hooks.poststart.clone()); - hooks.poststop.append(&mut guest_hooks.poststop.clone()); - oci.hooks = Some(hooks); + + Ok(()) } // Check is the container process installed the @@ -1538,7 +1563,7 @@ fn do_copy_file(req: &CopyFileRequest) -> Result<()> { PathBuf::from("/") }; - fs::create_dir_all(dir.to_str().unwrap()).or_else(|e| { + fs::create_dir_all(&dir).or_else(|e| { if e.kind() != std::io::ErrorKind::AlreadyExists { return Err(e); } @@ -1546,10 +1571,7 @@ fn do_copy_file(req: &CopyFileRequest) -> Result<()> { Ok(()) })?; - std::fs::set_permissions( - dir.to_str().unwrap(), - std::fs::Permissions::from_mode(req.dir_mode), - )?; + std::fs::set_permissions(&dir, std::fs::Permissions::from_mode(req.dir_mode))?; let mut tmpfile = path.clone(); tmpfile.set_extension("tmp"); @@ -1558,10 +1580,10 @@ fn do_copy_file(req: &CopyFileRequest) -> Result<()> { .write(true) .create(true) .truncate(false) - .open(tmpfile.to_str().unwrap())?; + .open(&tmpfile)?; file.write_all_at(req.data.as_slice(), req.offset as u64)?; - let st = stat::stat(tmpfile.to_str().unwrap())?; + let st = stat::stat(&tmpfile)?; if st.st_size != req.file_size { return Ok(()); @@ -1570,7 +1592,7 @@ fn do_copy_file(req: &CopyFileRequest) -> Result<()> { file.set_permissions(std::fs::Permissions::from_mode(req.file_mode))?; unistd::chown( - tmpfile.to_str().unwrap(), + &tmpfile, Some(Uid::from_raw(req.uid as u32)), Some(Gid::from_raw(req.gid as u32)), )?; @@ -1607,10 +1629,13 @@ async fn do_add_swap(sandbox: &Arc>, req: &AddSwapRequest) -> Res // - container rootfs bind mounted at ///rootfs // - modify container spec root to point to ///rootfs fn setup_bundle(cid: &str, spec: &mut Spec) -> Result { - if spec.root.is_none() { + let spec_root = if let Some(sr) = &spec.root { + sr + } else { return Err(nix::Error::Sys(Errno::EINVAL).into()); - } - let spec_root = spec.root.as_ref().unwrap(); + }; + + let spec_root_path = Path::new(&spec_root.path); let bundle_path = Path::new(CONTAINER_BASE).join(cid); let config_path = bundle_path.join("config.json"); @@ -1618,22 +1643,36 @@ fn setup_bundle(cid: &str, spec: &mut Spec) -> Result { fs::create_dir_all(&rootfs_path)?; baremount( - &spec_root.path, - rootfs_path.to_str().unwrap(), + spec_root_path, + &rootfs_path, "bind", MsFlags::MS_BIND, "", &sl!(), )?; + + let rootfs_path_name = rootfs_path + .to_str() + .ok_or_else(|| anyhow!("failed to convert rootfs to unicode"))? + .to_string(); + spec.root = Some(Root { - path: rootfs_path.to_str().unwrap().to_owned(), + path: rootfs_path_name, readonly: spec_root.readonly, }); - let _ = spec.save(config_path.to_str().unwrap()); + let _ = spec.save( + config_path + .to_str() + .ok_or_else(|| anyhow!("cannot convert path to unicode"))?, + ); let olddir = unistd::getcwd().context("cannot getcwd")?; - unistd::chdir(bundle_path.to_str().unwrap())?; + unistd::chdir( + bundle_path + .to_str() + .ok_or_else(|| anyhow!("cannot convert bundle path to unicode"))?, + )?; Ok(olddir) } @@ -1666,8 +1705,8 @@ fn load_kernel_module(module: &protocols::agent::KernelModule) -> Result<()> { match status.code() { Some(code) => { - let std_out: String = String::from_utf8(output.stdout).unwrap(); - let std_err: String = String::from_utf8(output.stderr).unwrap(); + let std_out = String::from_utf8_lossy(&output.stdout); + let std_err = String::from_utf8_lossy(&output.stderr); let msg = format!( "load_kernel_module return code: {} stdout:{} stderr:{}", code, std_out, std_err @@ -1730,7 +1769,7 @@ mod tests { let mut oci = Spec { ..Default::default() }; - append_guest_hooks(&s, &mut oci); + append_guest_hooks(&s, &mut oci).unwrap(); assert_eq!(s.hooks, oci.hooks); } diff --git a/src/agent/src/sandbox.rs b/src/agent/src/sandbox.rs index ddc18c5c9..526464014 100644 --- a/src/agent/src/sandbox.rs +++ b/src/agent/src/sandbox.rs @@ -458,10 +458,14 @@ mod tests { use slog::Logger; use std::fs::{self, File}; use std::os::unix::fs::PermissionsExt; + use std::path::Path; use tempfile::Builder; fn bind_mount(src: &str, dst: &str, logger: &Logger) -> Result<(), Error> { - baremount(src, dst, "bind", MsFlags::MS_BIND, "", logger) + let src_path = Path::new(src); + let dst_path = Path::new(dst); + + baremount(src_path, dst_path, "bind", MsFlags::MS_BIND, "", logger) } use serial_test::serial; diff --git a/src/agent/src/util.rs b/src/agent/src/util.rs index 0e262e7ee..1be52f730 100644 --- a/src/agent/src/util.rs +++ b/src/agent/src/util.rs @@ -3,7 +3,7 @@ // SPDX-License-Identifier: Apache-2.0 // -use anyhow::Result; +use anyhow::{anyhow, Result}; use futures::StreamExt; use std::io; use std::io::ErrorKind; @@ -64,8 +64,12 @@ pub fn get_vsock_incoming(fd: RawFd) -> Incoming { #[instrument] pub async fn get_vsock_stream(fd: RawFd) -> Result { - let stream = get_vsock_incoming(fd).next().await.unwrap()?; - Ok(stream) + let stream = get_vsock_incoming(fd) + .next() + .await + .ok_or_else(|| anyhow!("cannot handle incoming vsock connection"))?; + + Ok(stream?) } #[cfg(test)] @@ -124,7 +128,9 @@ mod tests { let mut vec_locked = vec_ref.lock(); - let v = vec_locked.as_deref_mut().unwrap(); + let v = vec_locked + .as_deref_mut() + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?; std::io::Write::flush(v) } diff --git a/src/agent/src/watcher.rs b/src/agent/src/watcher.rs index bb3fb1641..b3cd3f832 100644 --- a/src/agent/src/watcher.rs +++ b/src/agent/src/watcher.rs @@ -366,8 +366,8 @@ impl SandboxStorages { } match baremount( - entry.source_mount_point.to_str().unwrap(), - entry.target_mount_point.to_str().unwrap(), + entry.source_mount_point.as_path(), + entry.target_mount_point.as_path(), "bind", MsFlags::MS_BIND, "bind", @@ -477,8 +477,8 @@ impl BindWatcher { fs::create_dir_all(WATCH_MOUNT_POINT_PATH).await?; baremount( - "tmpfs", - WATCH_MOUNT_POINT_PATH, + Path::new("tmpfs"), + Path::new(WATCH_MOUNT_POINT_PATH), "tmpfs", MsFlags::empty(), "", diff --git a/tools/agent-ctl/src/main.rs b/tools/agent-ctl/src/main.rs index 00519c932..88c12e984 100644 --- a/tools/agent-ctl/src/main.rs +++ b/tools/agent-ctl/src/main.rs @@ -134,16 +134,14 @@ fn make_examples_text(program_name: &str) -> String { fn connect(name: &str, global_args: clap::ArgMatches) -> Result<()> { let args = global_args .subcommand_matches("connect") - .ok_or("BUG: missing sub-command arguments".to_string()) - .map_err(|e| anyhow!(e))?; + .ok_or_else(|| anyhow!("BUG: missing sub-command arguments"))?; let interactive = args.is_present("interactive"); let ignore_errors = args.is_present("ignore-errors"); let server_address = args .value_of("server-address") - .ok_or("need server adddress".to_string()) - .map_err(|e| anyhow!(e))? + .ok_or_else(|| anyhow!("need server adddress"))? .to_string(); let mut commands: Vec<&str> = Vec::new(); @@ -151,13 +149,13 @@ fn connect(name: &str, global_args: clap::ArgMatches) -> Result<()> { if !interactive { commands = args .values_of("cmd") - .ok_or("need commands to send to the server".to_string()) - .map_err(|e| anyhow!(e))? + .ok_or_else(|| anyhow!("need commands to send to the server"))? .collect(); } - // Cannot fail as a default has been specified - let log_level_name = global_args.value_of("log-level").unwrap(); + let log_level_name = global_args + .value_of("log-level") + .ok_or_else(|| anyhow!("cannot get log level"))?; let log_level = logging::level_name_to_slog_level(log_level_name).map_err(|e| anyhow!(e))?; @@ -169,10 +167,10 @@ fn connect(name: &str, global_args: clap::ArgMatches) -> Result<()> { None => 0, }; - let hybrid_vsock_port: u64 = args + let hybrid_vsock_port = args .value_of("hybrid-vsock-port") - .ok_or("Need Hybrid VSOCK port number") - .map(|p| p.parse::().unwrap()) + .ok_or_else(|| anyhow!("Need Hybrid VSOCK port number"))? + .parse::() .map_err(|e| anyhow!("VSOCK port number must be an integer: {:?}", e))?; let bundle_dir = args.value_of("bundle-dir").unwrap_or("").to_string(); @@ -218,7 +216,7 @@ fn real_main() -> Result<()> { .long("log-level") .short("l") .help("specific log level") - .default_value(logging::slog_level_to_level_name(DEFAULT_LOG_LEVEL).unwrap()) + .default_value(logging::slog_level_to_level_name(DEFAULT_LOG_LEVEL).map_err(|e| anyhow!(e))?) .possible_values(&logging::get_log_levels()) .takes_value(true) .required(false), @@ -304,35 +302,29 @@ fn real_main() -> Result<()> { let subcmd = args .subcommand_name() - .ok_or("need sub-command".to_string()) - .map_err(|e| anyhow!(e))?; + .ok_or_else(|| anyhow!("need sub-command"))?; match subcmd { "generate-cid" => { println!("{}", utils::random_container_id()); - return Ok(()); + Ok(()) } "generate-sid" => { println!("{}", utils::random_sandbox_id()); - return Ok(()); + Ok(()) } "examples" => { println!("{}", make_examples_text(name)); - return Ok(()); - } - "connect" => { - return connect(name, args); + Ok(()) } + "connect" => connect(name, args), _ => return Err(anyhow!(format!("invalid sub-command: {:?}", subcmd))), } } fn main() { - match real_main() { - Err(e) => { - eprintln!("ERROR: {}", e); - exit(1); - } - _ => (), - }; + if let Err(e) = real_main() { + eprintln!("ERROR: {}", e); + exit(1); + } }