diff --git a/src/agent/src/config.rs b/src/agent/src/config.rs index 759b696c4e..0bdf20c2dc 100644 --- a/src/agent/src/config.rs +++ b/src/agent/src/config.rs @@ -432,6 +432,8 @@ fn get_container_pipe_size(param: &str) -> Result { #[cfg(test)] mod tests { + use crate::assert_result; + use super::*; use anyhow::anyhow; use std::fs::File; @@ -439,32 +441,6 @@ mod tests { use std::time; use tempfile::tempdir; - // Parameters: - // - // 1: expected Result - // 2: actual Result - // 3: string used to identify the test on error - macro_rules! assert_result { - ($expected_result:expr, $actual_result:expr, $msg:expr) => { - if $expected_result.is_ok() { - let expected_level = $expected_result.as_ref().unwrap(); - let actual_level = $actual_result.unwrap(); - assert!(*expected_level == actual_level, "{}", $msg); - } else { - let expected_error = $expected_result.as_ref().unwrap_err(); - let expected_error_msg = format!("{:?}", expected_error); - - if let Err(actual_error) = $actual_result { - let actual_error_msg = format!("{:?}", actual_error); - - assert!(expected_error_msg == actual_error_msg, "{}", $msg); - } else { - assert!(expected_error_msg == "expected error, got OK", "{}", $msg); - } - } - }; - } - #[test] fn test_new() { let config: AgentConfig = Default::default(); diff --git a/src/agent/src/rpc.rs b/src/agent/src/rpc.rs index 765a3a6bf1..2d36413d29 100644 --- a/src/agent/src/rpc.rs +++ b/src/agent/src/rpc.rs @@ -82,6 +82,11 @@ use std::path::PathBuf; const CONTAINER_BASE: &str = "/run/kata-containers"; const MODPROBE_PATH: &str = "/sbin/modprobe"; +const ERR_CANNOT_GET_WRITER: &str = "Cannot get writer"; +const ERR_INVALID_BLOCK_SIZE: &str = "Invalid block size"; +const ERR_NO_LINUX_FIELD: &str = "Spec does not contain linux field"; +const ERR_NO_SANDBOX_PIDNS: &str = "Sandbox does not have sandbox_pidns"; + // Convenience macro to obtain the scope logger macro_rules! sl { () => { @@ -401,7 +406,8 @@ impl AgentService { // For container initProcess, if it hasn't installed handler for "SIGTERM" signal, // it will ignore the "SIGTERM" signal sent to it, thus send it "SIGKILL" signal // instead of "SIGTERM" to terminate it. - if p.init && sig == libc::SIGTERM && !is_signal_handled(p.pid, sig as u32) { + let proc_status_file = format!("/proc/{}/status", p.pid); + if p.init && sig == libc::SIGTERM && !is_signal_handled(&proc_status_file, sig as u32) { sig = libc::SIGKILL; } p.signal(sig)?; @@ -572,7 +578,7 @@ impl AgentService { } }; - let writer = writer.ok_or_else(|| anyhow!("cannot get writer"))?; + let writer = writer.ok_or_else(|| anyhow!(ERR_CANNOT_GET_WRITER))?; writer.lock().await.write_all(req.data.as_slice()).await?; let mut resp = WriteStreamResponse::new(); @@ -1216,7 +1222,12 @@ impl protocols::agent_ttrpc::AgentService for AgentService { info!(sl!(), "get guest details!"); let mut resp = GuestDetailsResponse::new(); // to get memory block size - match get_memory_info(req.mem_block_size, req.mem_hotplug_probe) { + match get_memory_info( + req.mem_block_size, + req.mem_hotplug_probe, + SYSFS_MEMORY_BLOCK_SIZE_PATH, + SYSFS_MEMORY_HOTPLUG_PROBE_PATH, + ) { Ok((u, v)) => { resp.mem_block_size_bytes = u; resp.support_mem_hotplug_probe = v; @@ -1405,24 +1416,29 @@ impl protocols::health_ttrpc::Health for HealthService { } } -fn get_memory_info(block_size: bool, hotplug: bool) -> Result<(u64, bool)> { +fn get_memory_info( + block_size: bool, + hotplug: bool, + block_size_path: &str, + hotplug_probe_path: &str, +) -> Result<(u64, bool)> { let mut size: u64 = 0; let mut plug: bool = false; if block_size { - match fs::read_to_string(SYSFS_MEMORY_BLOCK_SIZE_PATH) { + match fs::read_to_string(block_size_path) { Ok(v) => { if v.is_empty() { - info!(sl!(), "string in empty???"); - return Err(anyhow!("Invalid block size")); + warn!(sl!(), "file {} is empty", block_size_path); + return Err(anyhow!(ERR_INVALID_BLOCK_SIZE)); } size = u64::from_str_radix(v.trim(), 16).map_err(|_| { warn!(sl!(), "failed to parse the str {} to hex", size); - anyhow!("Invalid block size") + anyhow!(ERR_INVALID_BLOCK_SIZE) })?; } Err(e) => { - info!(sl!(), "memory block size error: {:?}", e.kind()); + warn!(sl!(), "memory block size error: {:?}", e.kind()); if e.kind() != std::io::ErrorKind::NotFound { return Err(anyhow!(e)); } @@ -1431,10 +1447,10 @@ fn get_memory_info(block_size: bool, hotplug: bool) -> Result<(u64, bool)> { } if hotplug { - match stat::stat(SYSFS_MEMORY_HOTPLUG_PROBE_PATH) { + match stat::stat(hotplug_probe_path) { Ok(_) => plug = true, Err(e) => { - info!(sl!(), "hotplug memory error: {:?}", e); + warn!(sl!(), "hotplug memory error: {:?}", e); match e { nix::Error::ENOENT => plug = false, _ => return Err(anyhow!(e)), @@ -1555,7 +1571,7 @@ fn update_container_namespaces( let linux = spec .linux .as_mut() - .ok_or_else(|| anyhow!("Spec didn't container linux field"))?; + .ok_or_else(|| anyhow!(ERR_NO_LINUX_FIELD))?; let namespaces = linux.namespaces.as_mut_slice(); for namespace in namespaces.iter_mut() { @@ -1582,7 +1598,7 @@ fn update_container_namespaces( if let Some(ref pidns) = &sandbox.sandbox_pidns { pid_ns.path = String::from(pidns.path.as_str()); } else { - return Err(anyhow!("failed to get sandbox pidns")); + return Err(anyhow!(ERR_NO_SANDBOX_PIDNS)); } } @@ -1602,21 +1618,33 @@ fn append_guest_hooks(s: &Sandbox, oci: &mut Spec) -> Result<()> { Ok(()) } -// Check is the container process installed the +// Check if the container process installed the // handler for specific signal. -fn is_signal_handled(pid: pid_t, signum: u32) -> bool { - let sig_mask: u64 = 1u64 << (signum - 1); - let file_name = format!("/proc/{}/status", pid); +fn is_signal_handled(proc_status_file: &str, signum: u32) -> bool { + let shift_count: u64 = if signum == 0 { + // signum 0 is used to check for process liveness. + // Since that signal is not part of the mask in the file, we only need + // to know if the file (and therefore) process exists to handle + // that signal. + return fs::metadata(proc_status_file).is_ok(); + } else if signum > 64 { + // Ensure invalid signum won't break bit shift logic + warn!(sl!(), "received invalid signum {}", signum); + return false; + } else { + (signum - 1).into() + }; // Open the file in read-only mode (ignoring errors). - let file = match File::open(&file_name) { + let file = match File::open(proc_status_file) { Ok(f) => f, Err(_) => { - warn!(sl!(), "failed to open file {}\n", file_name); + warn!(sl!(), "failed to open file {}", proc_status_file); return false; } }; + let sig_mask: u64 = 1 << shift_count; let reader = BufReader::new(file); // Read the file line by line using the lines() iterator from std::io::BufRead. @@ -1624,21 +1652,21 @@ fn is_signal_handled(pid: pid_t, signum: u32) -> bool { let line = match line { Ok(l) => l, Err(_) => { - warn!(sl!(), "failed to read file {}\n", file_name); + warn!(sl!(), "failed to read file {}", proc_status_file); return false; } }; if line.starts_with("SigCgt:") { let mask_vec: Vec<&str> = line.split(':').collect(); if mask_vec.len() != 2 { - warn!(sl!(), "parse the SigCgt field failed\n"); + warn!(sl!(), "parse the SigCgt field failed"); return false; } - let sig_cgt_str = mask_vec[1]; + let sig_cgt_str = mask_vec[1].trim(); let sig_cgt_mask = match u64::from_str_radix(sig_cgt_str, 16) { Ok(h) => h, Err(_) => { - warn!(sl!(), "failed to parse the str {} to hex\n", sig_cgt_str); + warn!(sl!(), "failed to parse the str {} to hex", sig_cgt_str); return false; } }; @@ -1846,9 +1874,13 @@ fn load_kernel_module(module: &protocols::agent::KernelModule) -> Result<()> { #[cfg(test)] mod tests { use super::*; - use crate::{protocols::agent_ttrpc::AgentService as _, skip_if_not_root}; + use crate::{ + assert_result, namespace::Namespace, protocols::agent_ttrpc::AgentService as _, + skip_if_not_root, + }; use nix::mount; - use oci::{Hook, Hooks}; + use oci::{Hook, Hooks, Linux, LinuxNamespace}; + use tempfile::{tempdir, TempDir}; use ttrpc::{r#async::TtrpcContext, MessageHeader}; fn mk_ttrpc_context() -> TtrpcContext { @@ -1860,6 +1892,44 @@ mod tests { } } + fn create_dummy_opts() -> CreateOpts { + let root = Root { + path: String::from("/"), + ..Default::default() + }; + + let spec = Spec { + linux: Some(oci::Linux::default()), + root: Some(root), + ..Default::default() + }; + + CreateOpts { + cgroup_name: "".to_string(), + use_systemd_cgroup: false, + no_pivot_root: false, + no_new_keyring: false, + spec: Some(spec), + rootless_euid: false, + rootless_cgroup: false, + } + } + + fn create_linuxcontainer() -> (LinuxContainer, TempDir) { + let dir = tempdir().expect("failed to make tempdir"); + + ( + LinuxContainer::new( + "some_id", + dir.path().join("rootfs").to_str().unwrap(), + create_dummy_opts(), + &slog_scope::logger(), + ) + .unwrap(), + dir, + ) + } + #[test] fn test_load_kernel_module() { let mut m = protocols::agent::KernelModule { @@ -1952,6 +2022,511 @@ mod tests { assert!(result.is_err(), "expected add arp neighbors to fail"); } + #[tokio::test] + async fn test_do_write_stream() { + #[derive(Debug)] + struct TestData<'a> { + create_container: bool, + has_fd: bool, + has_tty: bool, + break_pipe: bool, + + container_id: &'a str, + exec_id: &'a str, + data: Vec, + result: Result, + } + + impl Default for TestData<'_> { + fn default() -> Self { + TestData { + create_container: true, + has_fd: true, + has_tty: true, + break_pipe: false, + + container_id: "1", + exec_id: "2", + data: vec![1, 2, 3], + result: Ok(WriteStreamResponse { + len: 3, + ..WriteStreamResponse::default() + }), + } + } + } + + let tests = &[ + TestData { + ..Default::default() + }, + TestData { + has_tty: false, + ..Default::default() + }, + TestData { + break_pipe: true, + result: Err(anyhow!(std::io::Error::from_raw_os_error(libc::EPIPE))), + ..Default::default() + }, + TestData { + create_container: false, + result: Err(anyhow!(crate::sandbox::ERR_INVALID_CONTAINER_ID)), + ..Default::default() + }, + TestData { + container_id: "8181", + result: Err(anyhow!(crate::sandbox::ERR_INVALID_CONTAINER_ID)), + ..Default::default() + }, + TestData { + data: vec![], + result: Ok(WriteStreamResponse { + len: 0, + ..WriteStreamResponse::default() + }), + ..Default::default() + }, + TestData { + has_fd: false, + result: Err(anyhow!(ERR_CANNOT_GET_WRITER)), + ..Default::default() + }, + ]; + + for (i, d) in tests.iter().enumerate() { + let msg = format!("test[{}]: {:?}", i, d); + + let logger = slog::Logger::root(slog::Discard, o!()); + let mut sandbox = Sandbox::new(&logger).unwrap(); + + let (rfd, wfd) = unistd::pipe().unwrap(); + if d.break_pipe { + unistd::close(rfd).unwrap(); + } + + if d.create_container { + let (mut linux_container, _root) = create_linuxcontainer(); + let exec_process_id = 2; + + linux_container.id = "1".to_string(); + + let mut exec_process = Process::new( + &logger, + &oci::Process::default(), + &exec_process_id.to_string(), + false, + 1, + ) + .unwrap(); + + let fd = { + if d.has_fd { + Some(wfd) + } else { + None + } + }; + + if d.has_tty { + exec_process.parent_stdin = None; + exec_process.term_master = fd; + } else { + exec_process.parent_stdin = fd; + exec_process.term_master = None; + } + linux_container + .processes + .insert(exec_process_id, exec_process); + + sandbox.add_container(linux_container); + } + + let agent_service = Box::new(AgentService { + sandbox: Arc::new(Mutex::new(sandbox)), + }); + + let result = agent_service + .do_write_stream(protocols::agent::WriteStreamRequest { + container_id: d.container_id.to_string(), + exec_id: d.exec_id.to_string(), + data: d.data.clone(), + ..Default::default() + }) + .await; + + if !d.break_pipe { + unistd::close(rfd).unwrap(); + } + unistd::close(wfd).unwrap(); + + let msg = format!("{}, result: {:?}", msg, result); + assert_result!(d.result, result, msg); + } + } + + #[tokio::test] + async fn test_update_container_namespaces() { + #[derive(Debug)] + struct TestData<'a> { + has_linux_in_spec: bool, + sandbox_pidns_path: Option<&'a str>, + + namespaces: Vec, + use_sandbox_pidns: bool, + result: Result<()>, + expected_namespaces: Vec, + } + + impl Default for TestData<'_> { + fn default() -> Self { + TestData { + has_linux_in_spec: true, + sandbox_pidns_path: Some("sharedpidns"), + namespaces: vec![ + LinuxNamespace { + r#type: NSTYPEIPC.to_string(), + path: "ipcpath".to_string(), + }, + LinuxNamespace { + r#type: NSTYPEUTS.to_string(), + path: "utspath".to_string(), + }, + ], + use_sandbox_pidns: false, + result: Ok(()), + expected_namespaces: vec![ + LinuxNamespace { + r#type: NSTYPEIPC.to_string(), + path: "".to_string(), + }, + LinuxNamespace { + r#type: NSTYPEUTS.to_string(), + path: "".to_string(), + }, + LinuxNamespace { + r#type: NSTYPEPID.to_string(), + path: "".to_string(), + }, + ], + } + } + } + + let tests = &[ + TestData { + ..Default::default() + }, + TestData { + use_sandbox_pidns: true, + expected_namespaces: vec![ + LinuxNamespace { + r#type: NSTYPEIPC.to_string(), + path: "".to_string(), + }, + LinuxNamespace { + r#type: NSTYPEUTS.to_string(), + path: "".to_string(), + }, + LinuxNamespace { + r#type: NSTYPEPID.to_string(), + path: "sharedpidns".to_string(), + }, + ], + ..Default::default() + }, + TestData { + namespaces: vec![], + use_sandbox_pidns: true, + expected_namespaces: vec![LinuxNamespace { + r#type: NSTYPEPID.to_string(), + path: "sharedpidns".to_string(), + }], + ..Default::default() + }, + TestData { + namespaces: vec![], + use_sandbox_pidns: false, + expected_namespaces: vec![LinuxNamespace { + r#type: NSTYPEPID.to_string(), + path: "".to_string(), + }], + ..Default::default() + }, + TestData { + namespaces: vec![], + sandbox_pidns_path: None, + use_sandbox_pidns: true, + result: Err(anyhow!(ERR_NO_SANDBOX_PIDNS)), + expected_namespaces: vec![], + ..Default::default() + }, + TestData { + has_linux_in_spec: false, + result: Err(anyhow!(ERR_NO_LINUX_FIELD)), + ..Default::default() + }, + ]; + + for (i, d) in tests.iter().enumerate() { + let msg = format!("test[{}]: {:?}", i, d); + + let logger = slog::Logger::root(slog::Discard, o!()); + let mut sandbox = Sandbox::new(&logger).unwrap(); + if let Some(pidns_path) = d.sandbox_pidns_path { + let mut sandbox_pidns = Namespace::new(&logger); + sandbox_pidns.path = pidns_path.to_string(); + sandbox.sandbox_pidns = Some(sandbox_pidns); + } + + let mut oci = Spec::default(); + if d.has_linux_in_spec { + oci.linux = Some(Linux { + namespaces: d.namespaces.clone(), + ..Default::default() + }); + } + + let result = update_container_namespaces(&sandbox, &mut oci, d.use_sandbox_pidns); + + let msg = format!("{}, result: {:?}", msg, result); + + assert_result!(d.result, result, msg); + if let Some(linux) = oci.linux { + assert_eq!(d.expected_namespaces, linux.namespaces, "{}", msg); + } + } + } + + #[tokio::test] + async fn test_get_memory_info() { + #[derive(Debug)] + struct TestData<'a> { + // if None is provided, no file will be generated, else the data in the Option will populate the file + block_size_data: Option<&'a str>, + + hotplug_probe_data: bool, + get_block_size: bool, + get_hotplug: bool, + result: Result<(u64, bool)>, + } + + let tests = &[ + TestData { + block_size_data: Some("10000000"), + hotplug_probe_data: true, + get_block_size: true, + get_hotplug: true, + result: Ok((268435456, true)), + }, + TestData { + block_size_data: Some("100"), + hotplug_probe_data: false, + get_block_size: true, + get_hotplug: true, + result: Ok((256, false)), + }, + TestData { + block_size_data: None, + hotplug_probe_data: false, + get_block_size: true, + get_hotplug: true, + result: Ok((0, false)), + }, + TestData { + block_size_data: Some(""), + hotplug_probe_data: false, + get_block_size: true, + get_hotplug: false, + result: Err(anyhow!(ERR_INVALID_BLOCK_SIZE)), + }, + TestData { + block_size_data: Some("-1"), + hotplug_probe_data: false, + get_block_size: true, + get_hotplug: false, + result: Err(anyhow!(ERR_INVALID_BLOCK_SIZE)), + }, + TestData { + block_size_data: Some(" "), + hotplug_probe_data: false, + get_block_size: true, + get_hotplug: false, + result: Err(anyhow!(ERR_INVALID_BLOCK_SIZE)), + }, + TestData { + block_size_data: Some("some data"), + hotplug_probe_data: false, + get_block_size: true, + get_hotplug: false, + result: Err(anyhow!(ERR_INVALID_BLOCK_SIZE)), + }, + TestData { + block_size_data: Some("some data"), + hotplug_probe_data: true, + get_block_size: false, + get_hotplug: false, + result: Ok((0, false)), + }, + TestData { + block_size_data: Some("100"), + hotplug_probe_data: true, + get_block_size: false, + get_hotplug: false, + result: Ok((0, false)), + }, + TestData { + block_size_data: Some("100"), + hotplug_probe_data: true, + get_block_size: false, + get_hotplug: true, + result: Ok((0, true)), + }, + ]; + + for (i, d) in tests.iter().enumerate() { + let msg = format!("test[{}]: {:?}", i, d); + + let dir = tempdir().expect("failed to make tempdir"); + let block_size_path = dir.path().join("block_size_bytes"); + let hotplug_probe_path = dir.path().join("probe"); + + if let Some(block_size_data) = d.block_size_data { + fs::write(&block_size_path, block_size_data).unwrap(); + } + if d.hotplug_probe_data { + fs::write(&hotplug_probe_path, []).unwrap(); + } + + let result = get_memory_info( + d.get_block_size, + d.get_hotplug, + block_size_path.to_str().unwrap(), + hotplug_probe_path.to_str().unwrap(), + ); + + let msg = format!("{}, result: {:?}", msg, result); + + assert_result!(d.result, result, msg); + } + } + + #[tokio::test] + async fn test_is_signal_handled() { + #[derive(Debug)] + struct TestData<'a> { + status_file_data: Option<&'a str>, + signum: u32, + result: bool, + } + + let tests = &[ + TestData { + status_file_data: Some( + r#" +SigBlk:0000000000010000 +SigCgt:0000000000000001 +OtherField:other + "#, + ), + signum: 1, + result: true, + }, + TestData { + status_file_data: Some("SigCgt:000000004b813efb"), + signum: 4, + result: true, + }, + TestData { + status_file_data: Some("SigCgt:\t000000004b813efb"), + signum: 4, + result: true, + }, + TestData { + status_file_data: Some("SigCgt: 000000004b813efb"), + signum: 4, + result: true, + }, + TestData { + status_file_data: Some("SigCgt:000000004b813efb "), + signum: 4, + result: true, + }, + TestData { + status_file_data: Some("SigCgt:\t000000004b813efb "), + signum: 4, + result: true, + }, + TestData { + status_file_data: Some("SigCgt:000000004b813efb"), + signum: 3, + result: false, + }, + TestData { + status_file_data: Some("SigCgt:000000004b813efb"), + signum: 65, + result: false, + }, + TestData { + status_file_data: Some("SigCgt:000000004b813efb"), + signum: 0, + result: true, + }, + TestData { + status_file_data: Some("SigCgt:ZZZZZZZZ"), + signum: 1, + result: false, + }, + TestData { + status_file_data: Some("SigCgt:-1"), + signum: 1, + result: false, + }, + TestData { + status_file_data: Some("SigCgt"), + signum: 1, + result: false, + }, + TestData { + status_file_data: Some("any data"), + signum: 0, + result: true, + }, + TestData { + status_file_data: Some("SigBlk:0000000000000001"), + signum: 1, + result: false, + }, + TestData { + status_file_data: None, + signum: 1, + result: false, + }, + TestData { + status_file_data: None, + signum: 0, + result: false, + }, + ]; + + for (i, d) in tests.iter().enumerate() { + let msg = format!("test[{}]: {:?}", i, d); + + let dir = tempdir().expect("failed to make tempdir"); + let proc_status_file_path = dir.path().join("status"); + + if let Some(file_data) = d.status_file_data { + fs::write(&proc_status_file_path, file_data).unwrap(); + } + + let result = is_signal_handled(proc_status_file_path.to_str().unwrap(), d.signum); + + let msg = format!("{}, result: {:?}", msg, result); + + assert_eq!(d.result, result, "{}", msg); + } + } + #[tokio::test] async fn test_verify_cid() { #[derive(Debug)] diff --git a/src/agent/src/sandbox.rs b/src/agent/src/sandbox.rs index 78e2305f21..84cc659298 100644 --- a/src/agent/src/sandbox.rs +++ b/src/agent/src/sandbox.rs @@ -32,6 +32,8 @@ use tokio::sync::oneshot; use tokio::sync::Mutex; use tracing::instrument; +pub const ERR_INVALID_CONTAINER_ID: &str = "Invalid container id"; + type UeventWatcher = (Box, oneshot::Sender); #[derive(Debug)] @@ -237,7 +239,7 @@ impl Sandbox { pub fn find_container_process(&mut self, cid: &str, eid: &str) -> Result<&mut Process> { let ctr = self .get_container(cid) - .ok_or_else(|| anyhow!("Invalid container id"))?; + .ok_or_else(|| anyhow!(ERR_INVALID_CONTAINER_ID))?; if eid.is_empty() { return ctr diff --git a/src/agent/src/test_utils.rs b/src/agent/src/test_utils.rs index f8f2d7ab31..becb845fa7 100644 --- a/src/agent/src/test_utils.rs +++ b/src/agent/src/test_utils.rs @@ -53,4 +53,29 @@ mod test_utils { } }; } + + // Parameters: + // + // 1: expected Result + // 2: actual Result + // 3: string used to identify the test on error + #[macro_export] + macro_rules! assert_result { + ($expected_result:expr, $actual_result:expr, $msg:expr) => { + if $expected_result.is_ok() { + let expected_value = $expected_result.as_ref().unwrap(); + let actual_value = $actual_result.unwrap(); + assert!(*expected_value == actual_value, "{}", $msg); + } else { + assert!($actual_result.is_err(), "{}", $msg); + + let expected_error = $expected_result.as_ref().unwrap_err(); + let expected_error_msg = format!("{:?}", expected_error); + + let actual_error_msg = format!("{:?}", $actual_result.unwrap_err()); + + assert!(expected_error_msg == actual_error_msg, "{}", $msg); + } + }; + } }