From 2df2b4d30ddd977ae20dbd4e50b8d7c59b1dab31 Mon Sep 17 00:00:00 2001 From: ChengyuZhu6 Date: Wed, 20 Mar 2024 14:15:45 +0800 Subject: [PATCH 1/2] agent:namespace: Refactor unit tests to leverage rstest Refactor the unit tests in `namespace.rs` to leverage rstest for parameterization. Signed-off-by: ChengyuZhu6 --- src/agent/src/namespace.rs | 95 ++++++++++---------------------------- 1 file changed, 24 insertions(+), 71 deletions(-) diff --git a/src/agent/src/namespace.rs b/src/agent/src/namespace.rs index bf24cd1048..62529885f1 100644 --- a/src/agent/src/namespace.rs +++ b/src/agent/src/namespace.rs @@ -182,6 +182,7 @@ mod tests { use super::{Namespace, NamespaceType}; use crate::mount::remove_mounts; use nix::sched::CloneFlags; + use rstest::rstest; use tempfile::Builder; use test_utils::skip_if_not_root; @@ -226,26 +227,23 @@ mod tests { assert!(ns_pid.is_err()); } - #[test] - fn test_namespace_type() { - let ipc = NamespaceType::Ipc; - assert_eq!("ipc", ipc.get()); - assert_eq!(CloneFlags::CLONE_NEWIPC, ipc.get_flags()); - - let uts = NamespaceType::Uts; - assert_eq!("uts", uts.get()); - assert_eq!(CloneFlags::CLONE_NEWUTS, uts.get_flags()); - - let pid = NamespaceType::Pid; - assert_eq!("pid", pid.get()); - assert_eq!(CloneFlags::CLONE_NEWPID, pid.get_flags()); + #[rstest] + #[case::ipc(NamespaceType::Ipc, "ipc", CloneFlags::CLONE_NEWIPC)] + #[case::uts(NamespaceType::Uts, "uts", CloneFlags::CLONE_NEWUTS)] + #[case::pid(NamespaceType::Pid, "pid", CloneFlags::CLONE_NEWPID)] + fn test_namespace_type( + #[case] ns_type: NamespaceType, + #[case] ns_name: &str, + #[case] ns_flag: CloneFlags, + ) { + assert_eq!(ns_name, ns_type.get()); + assert_eq!(ns_flag, ns_type.get_flags()); } #[test] fn test_new() { // Create dummy logger and temp folder. let logger = slog::Logger::root(slog::Discard, o!()); - let ns_ipc = Namespace::new(&logger); assert_eq!(NamespaceType::Ipc, ns_ipc.ns_type); } @@ -301,65 +299,20 @@ mod tests { assert_eq!(ns_root.persistent_ns_dir, tmpdir.path().to_str().unwrap()); } - #[test] - fn test_namespace_type_get() { - #[derive(Debug)] - struct TestData<'a> { - ns_type: NamespaceType, - str: &'a str, - } - - let tests = &[ - TestData { - ns_type: NamespaceType::Ipc, - str: "ipc", - }, - TestData { - ns_type: NamespaceType::Uts, - str: "uts", - }, - TestData { - ns_type: NamespaceType::Pid, - str: "pid", - }, - ]; - - // Run the tests - for (i, d) in tests.iter().enumerate() { - // Create a string containing details of the test - let msg = format!("test[{}]: {:?}", i, d); - assert_eq!(d.str, d.ns_type.get(), "{}", msg) - } + #[rstest] + #[case::namespace_type_get_ipc(NamespaceType::Ipc, "ipc")] + #[case::namespace_type_get_uts(NamespaceType::Uts, "uts")] + #[case::namespace_type_get_pid(NamespaceType::Pid, "pid")] + fn test_namespace_type_get(#[case] ns_type: NamespaceType, #[case] ns_name: &str) { + assert_eq!(ns_name, ns_type.get()) } - #[test] - fn test_namespace_type_get_flags() { - #[derive(Debug)] - struct TestData { - ns_type: NamespaceType, - ns_flag: CloneFlags, - } - - let tests = &[ - TestData { - ns_type: NamespaceType::Ipc, - ns_flag: CloneFlags::CLONE_NEWIPC, - }, - TestData { - ns_type: NamespaceType::Uts, - ns_flag: CloneFlags::CLONE_NEWUTS, - }, - TestData { - ns_type: NamespaceType::Pid, - ns_flag: CloneFlags::CLONE_NEWPID, - }, - ]; - + #[rstest] + #[case::namespace_type_get_flags_ipc(NamespaceType::Ipc, CloneFlags::CLONE_NEWIPC)] + #[case::namespace_type_get_flags_uts(NamespaceType::Uts, CloneFlags::CLONE_NEWUTS)] + #[case::namespace_type_get_flags_pid(NamespaceType::Pid, CloneFlags::CLONE_NEWPID)] + fn test_namespace_type_get_flags(#[case] ns_type: NamespaceType, #[case] ns_flag: CloneFlags) { // Run the tests - for (i, d) in tests.iter().enumerate() { - // Create a string containing details of the test - let msg = format!("test[{}]: {:?}", i, d); - assert_eq!(d.ns_flag, d.ns_type.get_flags(), "{}", msg) - } + assert_eq!(ns_flag, ns_type.get_flags()) } } From 7a49ec1c8069ddd01e68c762703f0a07b4d4ecac Mon Sep 17 00:00:00 2001 From: ChengyuZhu6 Date: Wed, 20 Mar 2024 14:44:33 +0800 Subject: [PATCH 2/2] agent:util: Refactor the unit tests to leverage rstest Refactor the unit tests in util.rs to leverage rstest for parameterization. Fixes: #9314 Signed-off-by: ChengyuZhu6 --- src/agent/src/util.rs | 118 +++++++++++++++--------------------------- 1 file changed, 43 insertions(+), 75 deletions(-) diff --git a/src/agent/src/util.rs b/src/agent/src/util.rs index c4645f9753..e2587477e7 100644 --- a/src/agent/src/util.rs +++ b/src/agent/src/util.rs @@ -75,6 +75,7 @@ pub async fn get_vsock_stream(fd: RawFd) -> Result { #[cfg(test)] mod tests { use super::*; + use rstest::rstest; use std::io; use std::io::Cursor; use std::io::Write; @@ -172,88 +173,55 @@ mod tests { } } + #[rstest] + #[case("".into())] + #[case("a".into())] + #[case("foo".into())] + #[case("b".repeat(BUF_SIZE - 1))] + #[case("c".repeat(BUF_SIZE))] + #[case("d".repeat(BUF_SIZE + 1))] + #[case("e".repeat((2 * BUF_SIZE) - 1))] + #[case("f".repeat(2 * BUF_SIZE))] + #[case("g".repeat((2 * BUF_SIZE) + 1))] #[tokio::test(flavor = "multi_thread", worker_threads = 2)] - async fn test_interruptable_io_copier_reader() { - #[derive(Debug)] - struct TestData { - reader_value: String, + async fn test_interruptable_io_copier_reader(#[case] reader_value: String) { + let (tx, rx) = channel(true); + let reader = Cursor::new(reader_value.clone()); + let writer = BufWriter::new(); + + // XXX: Pass a copy of the writer to the copier to allow the + // result of the write operation to be checked below. + let handle = tokio::spawn(interruptable_io_copier(reader, writer.clone(), rx)); + + // Allow time for the thread to be spawned. + tokio::time::sleep(Duration::from_secs(1)).await; + + let timeout = tokio::time::sleep(Duration::from_secs(1)); + pin!(timeout); + + // Since the readers only specify a small number of bytes, the + // copier will quickly read zero and kill the task, closing the + // Receiver. + assert!(tx.is_closed()); + + let spawn_result: std::result::Result, JoinError>; + + select! { + res = handle => spawn_result = res, + _ = &mut timeout => panic!("timed out"), } - let tests = &[ - TestData { - reader_value: "".into(), - }, - TestData { - reader_value: "a".into(), - }, - TestData { - reader_value: "foo".into(), - }, - TestData { - reader_value: "b".repeat(BUF_SIZE - 1), - }, - TestData { - reader_value: "c".repeat(BUF_SIZE), - }, - TestData { - reader_value: "d".repeat(BUF_SIZE + 1), - }, - TestData { - reader_value: "e".repeat((2 * BUF_SIZE) - 1), - }, - TestData { - reader_value: "f".repeat(2 * BUF_SIZE), - }, - TestData { - reader_value: "g".repeat((2 * BUF_SIZE) + 1), - }, - ]; + assert!(spawn_result.is_ok()); - for (i, d) in tests.iter().enumerate() { - // Create a string containing details of the test - let msg = format!("test[{}]: {:?}", i, d); + let result: std::result::Result = spawn_result.unwrap(); - let (tx, rx) = channel(true); - let reader = Cursor::new(d.reader_value.clone()); - let writer = BufWriter::new(); + assert!(result.is_ok()); - // XXX: Pass a copy of the writer to the copier to allow the - // result of the write operation to be checked below. - let handle = tokio::spawn(interruptable_io_copier(reader, writer.clone(), rx)); + let byte_count = result.unwrap() as usize; + assert_eq!(byte_count, reader_value.len()); - // Allow time for the thread to be spawned. - tokio::time::sleep(Duration::from_secs(1)).await; - - let timeout = tokio::time::sleep(Duration::from_secs(1)); - pin!(timeout); - - // Since the readers only specify a small number of bytes, the - // copier will quickly read zero and kill the task, closing the - // Receiver. - assert!(tx.is_closed(), "{}", msg); - - let spawn_result: std::result::Result< - std::result::Result, - JoinError, - >; - - select! { - res = handle => spawn_result = res, - _ = &mut timeout => panic!("timed out"), - } - - assert!(spawn_result.is_ok()); - - let result: std::result::Result = spawn_result.unwrap(); - - assert!(result.is_ok()); - - let byte_count = result.unwrap() as usize; - assert_eq!(byte_count, d.reader_value.len(), "{}", msg); - - let value = writer.to_string(); - assert_eq!(value, d.reader_value, "{}", msg); - } + let value = writer.to_string(); + assert_eq!(value, reader_value); } #[tokio::test(flavor = "multi_thread", worker_threads = 2)]