agent:util: Refactor the unit tests to leverage rstest

Refactor the unit tests in util.rs to leverage rstest for parameterization.

Fixes: #9314

Signed-off-by: ChengyuZhu6 <chengyu.zhu@intel.com>
This commit is contained in:
ChengyuZhu6
2024-03-20 14:44:33 +08:00
parent 2df2b4d30d
commit 7a49ec1c80

View File

@@ -75,6 +75,7 @@ pub async fn get_vsock_stream(fd: RawFd) -> Result<VsockStream> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use rstest::rstest;
use std::io; use std::io;
use std::io::Cursor; use std::io::Cursor;
use std::io::Write; use std::io::Write;
@@ -172,88 +173,55 @@ mod tests {
} }
} }
#[rstest]
#[case("".into())]
#[case("a".into())]
#[case("foo".into())]
#[case("b".repeat(BUF_SIZE - 1))]
#[case("c".repeat(BUF_SIZE))]
#[case("d".repeat(BUF_SIZE + 1))]
#[case("e".repeat((2 * BUF_SIZE) - 1))]
#[case("f".repeat(2 * BUF_SIZE))]
#[case("g".repeat((2 * BUF_SIZE) + 1))]
#[tokio::test(flavor = "multi_thread", worker_threads = 2)] #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_interruptable_io_copier_reader() { async fn test_interruptable_io_copier_reader(#[case] reader_value: String) {
#[derive(Debug)] let (tx, rx) = channel(true);
struct TestData { let reader = Cursor::new(reader_value.clone());
reader_value: String, let writer = BufWriter::new();
// XXX: Pass a copy of the writer to the copier to allow the
// result of the write operation to be checked below.
let handle = tokio::spawn(interruptable_io_copier(reader, writer.clone(), rx));
// Allow time for the thread to be spawned.
tokio::time::sleep(Duration::from_secs(1)).await;
let timeout = tokio::time::sleep(Duration::from_secs(1));
pin!(timeout);
// Since the readers only specify a small number of bytes, the
// copier will quickly read zero and kill the task, closing the
// Receiver.
assert!(tx.is_closed());
let spawn_result: std::result::Result<std::result::Result<u64, std::io::Error>, JoinError>;
select! {
res = handle => spawn_result = res,
_ = &mut timeout => panic!("timed out"),
} }
let tests = &[ assert!(spawn_result.is_ok());
TestData {
reader_value: "".into(),
},
TestData {
reader_value: "a".into(),
},
TestData {
reader_value: "foo".into(),
},
TestData {
reader_value: "b".repeat(BUF_SIZE - 1),
},
TestData {
reader_value: "c".repeat(BUF_SIZE),
},
TestData {
reader_value: "d".repeat(BUF_SIZE + 1),
},
TestData {
reader_value: "e".repeat((2 * BUF_SIZE) - 1),
},
TestData {
reader_value: "f".repeat(2 * BUF_SIZE),
},
TestData {
reader_value: "g".repeat((2 * BUF_SIZE) + 1),
},
];
for (i, d) in tests.iter().enumerate() { let result: std::result::Result<u64, std::io::Error> = spawn_result.unwrap();
// Create a string containing details of the test
let msg = format!("test[{}]: {:?}", i, d);
let (tx, rx) = channel(true); assert!(result.is_ok());
let reader = Cursor::new(d.reader_value.clone());
let writer = BufWriter::new();
// XXX: Pass a copy of the writer to the copier to allow the let byte_count = result.unwrap() as usize;
// result of the write operation to be checked below. assert_eq!(byte_count, reader_value.len());
let handle = tokio::spawn(interruptable_io_copier(reader, writer.clone(), rx));
// Allow time for the thread to be spawned. let value = writer.to_string();
tokio::time::sleep(Duration::from_secs(1)).await; assert_eq!(value, reader_value);
let timeout = tokio::time::sleep(Duration::from_secs(1));
pin!(timeout);
// Since the readers only specify a small number of bytes, the
// copier will quickly read zero and kill the task, closing the
// Receiver.
assert!(tx.is_closed(), "{}", msg);
let spawn_result: std::result::Result<
std::result::Result<u64, std::io::Error>,
JoinError,
>;
select! {
res = handle => spawn_result = res,
_ = &mut timeout => panic!("timed out"),
}
assert!(spawn_result.is_ok());
let result: std::result::Result<u64, std::io::Error> = spawn_result.unwrap();
assert!(result.is_ok());
let byte_count = result.unwrap() as usize;
assert_eq!(byte_count, d.reader_value.len(), "{}", msg);
let value = writer.to_string();
assert_eq!(value, d.reader_value, "{}", msg);
}
} }
#[tokio::test(flavor = "multi_thread", worker_threads = 2)] #[tokio::test(flavor = "multi_thread", worker_threads = 2)]