agent: add tests for do_write_stream function

Add test coverage for do_write_stream function of AgentService
in src/rpc.rs. Includes minor refactoring to make function more
easily testable.

Fixes #3984

Signed-off-by: Braden Rayhorn <bradenrayhorn@fastmail.com>
This commit is contained in:
Braden Rayhorn 2022-03-28 17:58:28 -05:00
parent 8980d04e25
commit 485aeabb6b
No known key found for this signature in database
GPG Key ID: 90BA250DB2C61574
2 changed files with 187 additions and 3 deletions

View File

@ -85,6 +85,7 @@ use std::path::PathBuf;
const CONTAINER_BASE: &str = "/run/kata-containers"; const CONTAINER_BASE: &str = "/run/kata-containers";
const MODPROBE_PATH: &str = "/sbin/modprobe"; const MODPROBE_PATH: &str = "/sbin/modprobe";
const ERR_CANNOT_GET_WRITER: &str = "Cannot get writer";
const ERR_INVALID_BLOCK_SIZE: &str = "Invalid block size"; const ERR_INVALID_BLOCK_SIZE: &str = "Invalid block size";
// Convenience macro to obtain the scope logger // Convenience macro to obtain the scope logger
@ -577,7 +578,7 @@ impl AgentService {
} }
}; };
let writer = writer.ok_or_else(|| anyhow!("cannot get writer"))?; let writer = writer.ok_or_else(|| anyhow!(ERR_CANNOT_GET_WRITER))?;
writer.lock().await.write_all(req.data.as_slice()).await?; writer.lock().await.write_all(req.data.as_slice()).await?;
let mut resp = WriteStreamResponse::new(); let mut resp = WriteStreamResponse::new();
@ -1880,7 +1881,7 @@ mod tests {
use super::*; use super::*;
use crate::protocols::agent_ttrpc::AgentService as _; use crate::protocols::agent_ttrpc::AgentService as _;
use oci::{Hook, Hooks}; use oci::{Hook, Hooks};
use tempfile::tempdir; use tempfile::{tempdir, TempDir};
use ttrpc::{r#async::TtrpcContext, MessageHeader}; use ttrpc::{r#async::TtrpcContext, MessageHeader};
// Parameters: // Parameters:
@ -1918,6 +1919,44 @@ mod tests {
} }
} }
fn create_dummy_opts() -> CreateOpts {
let root = Root {
path: String::from("/"),
..Default::default()
};
let spec = Spec {
linux: Some(oci::Linux::default()),
root: Some(root),
..Default::default()
};
CreateOpts {
cgroup_name: "".to_string(),
use_systemd_cgroup: false,
no_pivot_root: false,
no_new_keyring: false,
spec: Some(spec),
rootless_euid: false,
rootless_cgroup: false,
}
}
fn create_linuxcontainer() -> (LinuxContainer, TempDir) {
let dir = tempdir().expect("failed to make tempdir");
(
LinuxContainer::new(
"some_id",
dir.path().join("rootfs").to_str().unwrap(),
create_dummy_opts(),
&slog_scope::logger(),
)
.unwrap(),
dir,
)
}
#[test] #[test]
fn test_load_kernel_module() { fn test_load_kernel_module() {
let mut m = protocols::agent::KernelModule { let mut m = protocols::agent::KernelModule {
@ -2010,6 +2049,149 @@ mod tests {
assert!(result.is_err(), "expected add arp neighbors to fail"); assert!(result.is_err(), "expected add arp neighbors to fail");
} }
#[tokio::test]
async fn test_do_write_stream() {
#[derive(Debug)]
struct TestData<'a> {
create_container: bool,
has_fd: bool,
has_tty: bool,
break_pipe: bool,
container_id: &'a str,
exec_id: &'a str,
data: Vec<u8>,
result: Result<protocols::agent::WriteStreamResponse>,
}
impl Default for TestData<'_> {
fn default() -> Self {
TestData {
create_container: true,
has_fd: true,
has_tty: true,
break_pipe: false,
container_id: "1",
exec_id: "2",
data: vec![1, 2, 3],
result: Ok(WriteStreamResponse {
len: 3,
..WriteStreamResponse::default()
}),
}
}
}
let tests = &[
TestData {
..Default::default()
},
TestData {
has_tty: false,
..Default::default()
},
TestData {
break_pipe: true,
result: Err(anyhow!(std::io::Error::from_raw_os_error(libc::EPIPE))),
..Default::default()
},
TestData {
create_container: false,
result: Err(anyhow!(crate::sandbox::ERR_INVALID_CONTAINER_ID)),
..Default::default()
},
TestData {
container_id: "8181",
result: Err(anyhow!(crate::sandbox::ERR_INVALID_CONTAINER_ID)),
..Default::default()
},
TestData {
data: vec![],
result: Ok(WriteStreamResponse {
len: 0,
..WriteStreamResponse::default()
}),
..Default::default()
},
TestData {
has_fd: false,
result: Err(anyhow!(ERR_CANNOT_GET_WRITER)),
..Default::default()
},
];
for (i, d) in tests.iter().enumerate() {
let msg = format!("test[{}]: {:?}", i, d);
let logger = slog::Logger::root(slog::Discard, o!());
let mut sandbox = Sandbox::new(&logger).unwrap();
let (rfd, wfd) = unistd::pipe().unwrap();
if d.break_pipe {
unistd::close(rfd).unwrap();
}
if d.create_container {
let (mut linux_container, _root) = create_linuxcontainer();
let exec_process_id = 2;
linux_container.id = "1".to_string();
let mut exec_process = Process::new(
&logger,
&oci::Process::default(),
&exec_process_id.to_string(),
false,
1,
)
.unwrap();
let fd = {
if d.has_fd {
Some(wfd)
} else {
None
}
};
if d.has_tty {
exec_process.parent_stdin = None;
exec_process.term_master = fd;
} else {
exec_process.parent_stdin = fd;
exec_process.term_master = None;
}
linux_container
.processes
.insert(exec_process_id, exec_process);
sandbox.add_container(linux_container);
}
let agent_service = Box::new(AgentService {
sandbox: Arc::new(Mutex::new(sandbox)),
});
let result = agent_service
.do_write_stream(protocols::agent::WriteStreamRequest {
container_id: d.container_id.to_string(),
exec_id: d.exec_id.to_string(),
data: d.data.clone(),
..Default::default()
})
.await;
if !d.break_pipe {
unistd::close(rfd).unwrap();
}
unistd::close(wfd).unwrap();
let msg = format!("{}, result: {:?}", msg, result);
assert_result!(d.result, result, msg);
}
}
#[tokio::test] #[tokio::test]
async fn test_get_memory_info() { async fn test_get_memory_info() {
#[derive(Debug)] #[derive(Debug)]

View File

@ -32,6 +32,8 @@ use tokio::sync::oneshot;
use tokio::sync::Mutex; use tokio::sync::Mutex;
use tracing::instrument; use tracing::instrument;
pub const ERR_INVALID_CONTAINER_ID: &str = "Invalid container id";
type UeventWatcher = (Box<dyn UeventMatcher>, oneshot::Sender<Uevent>); type UeventWatcher = (Box<dyn UeventMatcher>, oneshot::Sender<Uevent>);
#[derive(Debug)] #[derive(Debug)]
@ -232,7 +234,7 @@ impl Sandbox {
pub fn find_container_process(&mut self, cid: &str, eid: &str) -> Result<&mut Process> { pub fn find_container_process(&mut self, cid: &str, eid: &str) -> Result<&mut Process> {
let ctr = self let ctr = self
.get_container(cid) .get_container(cid)
.ok_or_else(|| anyhow!("Invalid container id"))?; .ok_or_else(|| anyhow!(ERR_INVALID_CONTAINER_ID))?;
if eid.is_empty() { if eid.is_empty() {
return ctr return ctr