diff --git a/src/agent/src/rpc.rs b/src/agent/src/rpc.rs index 1f4729b7d0..2325347160 100644 --- a/src/agent/src/rpc.rs +++ b/src/agent/src/rpc.rs @@ -85,6 +85,8 @@ use std::path::PathBuf; const CONTAINER_BASE: &str = "/run/kata-containers"; const MODPROBE_PATH: &str = "/sbin/modprobe"; +const ERR_INVALID_BLOCK_SIZE: &str = "Invalid block size"; + // Convenience macro to obtain the scope logger macro_rules! sl { () => { @@ -1154,7 +1156,12 @@ impl protocols::agent_ttrpc::AgentService for AgentService { info!(sl!(), "get guest details!"); let mut resp = GuestDetailsResponse::new(); // to get memory block size - match get_memory_info(req.mem_block_size, req.mem_hotplug_probe) { + match get_memory_info( + req.mem_block_size, + req.mem_hotplug_probe, + SYSFS_MEMORY_BLOCK_SIZE_PATH, + SYSFS_MEMORY_HOTPLUG_PROBE_PATH, + ) { Ok((u, v)) => { resp.mem_block_size_bytes = u; resp.support_mem_hotplug_probe = v; @@ -1343,24 +1350,29 @@ impl protocols::health_ttrpc::Health for HealthService { } } -fn get_memory_info(block_size: bool, hotplug: bool) -> Result<(u64, bool)> { +fn get_memory_info( + block_size: bool, + hotplug: bool, + block_size_path: &str, + hotplug_probe_path: &str, +) -> Result<(u64, bool)> { let mut size: u64 = 0; let mut plug: bool = false; if block_size { - match fs::read_to_string(SYSFS_MEMORY_BLOCK_SIZE_PATH) { + match fs::read_to_string(block_size_path) { Ok(v) => { if v.is_empty() { - info!(sl!(), "string in empty???"); - return Err(anyhow!("Invalid block size")); + warn!(sl!(), "file {} is empty", block_size_path); + return Err(anyhow!(ERR_INVALID_BLOCK_SIZE)); } size = u64::from_str_radix(v.trim(), 16).map_err(|_| { warn!(sl!(), "failed to parse the str {} to hex", size); - anyhow!("Invalid block size") + anyhow!(ERR_INVALID_BLOCK_SIZE) })?; } Err(e) => { - info!(sl!(), "memory block size error: {:?}", e.kind()); + warn!(sl!(), "memory block size error: {:?}", e.kind()); if e.kind() != std::io::ErrorKind::NotFound { return Err(anyhow!(e)); } @@ -1369,10 +1381,10 @@ fn get_memory_info(block_size: bool, hotplug: bool) -> Result<(u64, bool)> { } if hotplug { - match stat::stat(SYSFS_MEMORY_HOTPLUG_PROBE_PATH) { + match stat::stat(hotplug_probe_path) { Ok(_) => plug = true, Err(e) => { - info!(sl!(), "hotplug memory error: {:?}", e); + warn!(sl!(), "hotplug memory error: {:?}", e); match e { nix::Error::ENOENT => plug = false, _ => return Err(anyhow!(e)), @@ -1803,8 +1815,35 @@ mod tests { use super::*; use crate::protocols::agent_ttrpc::AgentService as _; use oci::{Hook, Hooks}; + use tempfile::tempdir; use ttrpc::{r#async::TtrpcContext, MessageHeader}; + // Parameters: + // + // 1: expected Result + // 2: actual Result + // 3: string used to identify the test on error + macro_rules! assert_result { + ($expected_result:expr, $actual_result:expr, $msg:expr) => { + if $expected_result.is_ok() { + let expected_level = $expected_result.as_ref().unwrap(); + let actual_level = $actual_result.unwrap(); + assert!(*expected_level == actual_level, "{}", $msg); + } else { + let expected_error = $expected_result.as_ref().unwrap_err(); + let expected_error_msg = format!("{:?}", expected_error); + + if let Err(actual_error) = $actual_result { + let actual_error_msg = format!("{:?}", actual_error); + + assert!(expected_error_msg == actual_error_msg, "{}", $msg); + } else { + assert!(expected_error_msg == "expected error, got OK", "{}", $msg); + } + } + }; + } + fn mk_ttrpc_context() -> TtrpcContext { TtrpcContext { fd: -1, @@ -1906,6 +1945,119 @@ mod tests { assert!(result.is_err(), "expected add arp neighbors to fail"); } + #[tokio::test] + async fn test_get_memory_info() { + #[derive(Debug)] + struct TestData<'a> { + // if None is provided, no file will be generated, else the data in the Option will populate the file + block_size_data: Option<&'a str>, + + hotplug_probe_data: bool, + get_block_size: bool, + get_hotplug: bool, + result: Result<(u64, bool)>, + } + + let tests = &[ + TestData { + block_size_data: Some("10000000"), + hotplug_probe_data: true, + get_block_size: true, + get_hotplug: true, + result: Ok((268435456, true)), + }, + TestData { + block_size_data: Some("100"), + hotplug_probe_data: false, + get_block_size: true, + get_hotplug: true, + result: Ok((256, false)), + }, + TestData { + block_size_data: None, + hotplug_probe_data: false, + get_block_size: true, + get_hotplug: true, + result: Ok((0, false)), + }, + TestData { + block_size_data: Some(""), + hotplug_probe_data: false, + get_block_size: true, + get_hotplug: false, + result: Err(anyhow!(ERR_INVALID_BLOCK_SIZE)), + }, + TestData { + block_size_data: Some("-1"), + hotplug_probe_data: false, + get_block_size: true, + get_hotplug: false, + result: Err(anyhow!(ERR_INVALID_BLOCK_SIZE)), + }, + TestData { + block_size_data: Some(" "), + hotplug_probe_data: false, + get_block_size: true, + get_hotplug: false, + result: Err(anyhow!(ERR_INVALID_BLOCK_SIZE)), + }, + TestData { + block_size_data: Some("some data"), + hotplug_probe_data: false, + get_block_size: true, + get_hotplug: false, + result: Err(anyhow!(ERR_INVALID_BLOCK_SIZE)), + }, + TestData { + block_size_data: Some("some data"), + hotplug_probe_data: true, + get_block_size: false, + get_hotplug: false, + result: Ok((0, false)), + }, + TestData { + block_size_data: Some("100"), + hotplug_probe_data: true, + get_block_size: false, + get_hotplug: false, + result: Ok((0, false)), + }, + TestData { + block_size_data: Some("100"), + hotplug_probe_data: true, + get_block_size: false, + get_hotplug: true, + result: Ok((0, true)), + }, + ]; + + for (i, d) in tests.iter().enumerate() { + let msg = format!("test[{}]: {:?}", i, d); + + let dir = tempdir().expect("failed to make tempdir"); + let block_size_path = dir.path().join("block_size_bytes"); + let hotplug_probe_path = dir.path().join("probe"); + + if let Some(block_size_data) = d.block_size_data { + fs::write(&block_size_path, block_size_data).unwrap(); + } + if d.hotplug_probe_data { + fs::write(&hotplug_probe_path, []).unwrap(); + } + + let result = get_memory_info( + d.get_block_size, + d.get_hotplug, + block_size_path.to_str().unwrap(), + hotplug_probe_path.to_str().unwrap(), + ); + + let msg = format!("{}, result: {:?}", msg, result); + + assert_result!(d.result, result, msg); + } + } + #[tokio::test] async fn test_verify_cid() { #[derive(Debug)]