diff --git a/src/agent/src/util.rs b/src/agent/src/util.rs index c4645f9753..e2587477e7 100644 --- a/src/agent/src/util.rs +++ b/src/agent/src/util.rs @@ -75,6 +75,7 @@ pub async fn get_vsock_stream(fd: RawFd) -> Result { #[cfg(test)] mod tests { use super::*; + use rstest::rstest; use std::io; use std::io::Cursor; use std::io::Write; @@ -172,88 +173,55 @@ mod tests { } } + #[rstest] + #[case("".into())] + #[case("a".into())] + #[case("foo".into())] + #[case("b".repeat(BUF_SIZE - 1))] + #[case("c".repeat(BUF_SIZE))] + #[case("d".repeat(BUF_SIZE + 1))] + #[case("e".repeat((2 * BUF_SIZE) - 1))] + #[case("f".repeat(2 * BUF_SIZE))] + #[case("g".repeat((2 * BUF_SIZE) + 1))] #[tokio::test(flavor = "multi_thread", worker_threads = 2)] - async fn test_interruptable_io_copier_reader() { - #[derive(Debug)] - struct TestData { - reader_value: String, + async fn test_interruptable_io_copier_reader(#[case] reader_value: String) { + let (tx, rx) = channel(true); + let reader = Cursor::new(reader_value.clone()); + let writer = BufWriter::new(); + + // XXX: Pass a copy of the writer to the copier to allow the + // result of the write operation to be checked below. + let handle = tokio::spawn(interruptable_io_copier(reader, writer.clone(), rx)); + + // Allow time for the thread to be spawned. + tokio::time::sleep(Duration::from_secs(1)).await; + + let timeout = tokio::time::sleep(Duration::from_secs(1)); + pin!(timeout); + + // Since the readers only specify a small number of bytes, the + // copier will quickly read zero and kill the task, closing the + // Receiver. + assert!(tx.is_closed()); + + let spawn_result: std::result::Result, JoinError>; + + select! { + res = handle => spawn_result = res, + _ = &mut timeout => panic!("timed out"), } - let tests = &[ - TestData { - reader_value: "".into(), - }, - TestData { - reader_value: "a".into(), - }, - TestData { - reader_value: "foo".into(), - }, - TestData { - reader_value: "b".repeat(BUF_SIZE - 1), - }, - TestData { - reader_value: "c".repeat(BUF_SIZE), - }, - TestData { - reader_value: "d".repeat(BUF_SIZE + 1), - }, - TestData { - reader_value: "e".repeat((2 * BUF_SIZE) - 1), - }, - TestData { - reader_value: "f".repeat(2 * BUF_SIZE), - }, - TestData { - reader_value: "g".repeat((2 * BUF_SIZE) + 1), - }, - ]; + assert!(spawn_result.is_ok()); - for (i, d) in tests.iter().enumerate() { - // Create a string containing details of the test - let msg = format!("test[{}]: {:?}", i, d); + let result: std::result::Result = spawn_result.unwrap(); - let (tx, rx) = channel(true); - let reader = Cursor::new(d.reader_value.clone()); - let writer = BufWriter::new(); + assert!(result.is_ok()); - // XXX: Pass a copy of the writer to the copier to allow the - // result of the write operation to be checked below. - let handle = tokio::spawn(interruptable_io_copier(reader, writer.clone(), rx)); + let byte_count = result.unwrap() as usize; + assert_eq!(byte_count, reader_value.len()); - // Allow time for the thread to be spawned. - tokio::time::sleep(Duration::from_secs(1)).await; - - let timeout = tokio::time::sleep(Duration::from_secs(1)); - pin!(timeout); - - // Since the readers only specify a small number of bytes, the - // copier will quickly read zero and kill the task, closing the - // Receiver. - assert!(tx.is_closed(), "{}", msg); - - let spawn_result: std::result::Result< - std::result::Result, - JoinError, - >; - - select! { - res = handle => spawn_result = res, - _ = &mut timeout => panic!("timed out"), - } - - assert!(spawn_result.is_ok()); - - let result: std::result::Result = spawn_result.unwrap(); - - assert!(result.is_ok()); - - let byte_count = result.unwrap() as usize; - assert_eq!(byte_count, d.reader_value.len(), "{}", msg); - - let value = writer.to_string(); - assert_eq!(value, d.reader_value, "{}", msg); - } + let value = writer.to_string(); + assert_eq!(value, reader_value); } #[tokio::test(flavor = "multi_thread", worker_threads = 2)]