diff --git a/src/agent/rustjail/src/container.rs b/src/agent/rustjail/src/container.rs index bdaac295d2..ca77a181f4 100644 --- a/src/agent/rustjail/src/container.rs +++ b/src/agent/rustjail/src/container.rs @@ -1457,11 +1457,9 @@ fn set_sysctls(sysctls: &HashMap) -> Result<()> { Ok(()) } -use std::io::Read; -use std::os::unix::process::ExitStatusExt; use std::process::Stdio; -use std::sync::mpsc::{self, RecvTimeoutError}; use std::time::Duration; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; async fn execute_hook(logger: &Logger, h: &Hook, st: &OCIState) -> Result<()> { let logger = logger.new(o!("action" => "execute-hook")); @@ -1473,160 +1471,91 @@ async fn execute_hook(logger: &Logger, h: &Hook, st: &OCIState) -> Result<()> { } let args = h.args.clone(); - let envs = h.env.clone(); + let env: HashMap = h + .env + .iter() + .map(|e| { + let v: Vec<&str> = e.split('=').collect(); + (v[0].to_string(), v[1].to_string()) + }) + .collect(); + + let mut child = tokio::process::Command::new(path) + .args(args.iter()) + .envs(env.iter()) + .kill_on_drop(true) + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn() + .unwrap(); + + // default timeout 10s + let mut timeout: u64 = 10; + + // if timeout is set if hook, then use the specified value + if let Some(t) = h.timeout { + if t > 0 { + timeout = t as u64; + } + } + let state = serde_json::to_string(st)?; + let path = h.path.clone(); - let (rfd, wfd) = unistd::pipe2(OFlag::O_CLOEXEC)?; - defer!({ - let _ = unistd::close(rfd); - let _ = unistd::close(wfd); - }); + let join_handle = tokio::spawn(async move { + child + .stdin + .as_mut() + .unwrap() + .write_all(state.as_bytes()) + .await + .unwrap(); - match unistd::fork()? { - ForkResult::Parent { child } => { - let mut pipe_r = PipeStream::from_fd(rfd); - let buf = read_async(&mut pipe_r).await?; - let status = if buf.len() == 4 { - let buf_array: [u8; 4] = [buf[0], buf[1], buf[2], buf[3]]; - i32::from_be_bytes(buf_array) - } else { - -libc::EPIPE - }; + // Close stdin so that hook program could receive EOF + child.stdin.take(); - info!(logger, "hook child: {} status: {}", child, status); + // read something from stdout for debug + let mut out = String::new(); + child + .stdout + .as_mut() + .unwrap() + .read_to_string(&mut out) + .await + .unwrap(); + info!(logger, "child stdout: {}", out.as_str()); - if status != 0 { - if status == -libc::ETIMEDOUT { - return Err(anyhow!(nix::Error::from_errno(Errno::ETIMEDOUT))); - } else if status == -libc::EPIPE { - return Err(anyhow!(nix::Error::from_errno(Errno::EPIPE))); + match child.wait().await { + Ok(exit) => { + let code = match exit.code() { + Some(c) => c, + None => { + return Err(anyhow!("hook exit status has no status code")); + } + }; + + if code == 0 { + debug!(logger, "hook {} exit status is 0", &path); + return Ok(()); } else { + error!(logger, "hook {} exit status is {}", &path, code); return Err(anyhow!(nix::Error::from_errno(Errno::UnknownErrno))); } } - - Ok(()) + Err(e) => { + return Err(anyhow!( + "wait child error: {} {}", + e, + e.raw_os_error().unwrap() + )); + } } + }); - ForkResult::Child => { - let (tx, rx) = mpsc::channel(); - let (tx_logger, rx_logger) = mpsc::channel(); - - tx_logger.send(logger.clone()).unwrap(); - - let handle = std::thread::spawn(move || { - let logger = rx_logger.recv().unwrap(); - - // write oci state to child - let env: HashMap = envs - .iter() - .map(|e| { - let v: Vec<&str> = e.split('=').collect(); - (v[0].to_string(), v[1].to_string()) - }) - .collect(); - - let mut child = std::process::Command::new(path.to_str().unwrap()) - .args(args.iter()) - .envs(env.iter()) - .stdin(Stdio::piped()) - .stdout(Stdio::piped()) - .stderr(Stdio::piped()) - .spawn() - .unwrap(); - - // send out our pid - tx.send(child.id() as libc::pid_t).unwrap(); - info!(logger, "hook grand: {}", child.id()); - - child - .stdin - .as_mut() - .unwrap() - .write_all(state.as_bytes()) - .unwrap(); - - // Close stdin so that hook program could receive EOF. - child.stdin.take(); - - // read something from stdout for debug - let mut out = String::new(); - child - .stdout - .as_mut() - .unwrap() - .read_to_string(&mut out) - .unwrap(); - info!(logger, "child stdout: {}", out.as_str()); - match child.wait() { - Ok(exit) => { - let code: i32 = if exit.success() { - 0 - } else { - match exit.code() { - Some(c) => (c as u32 | 0x80000000) as i32, - None => exit.signal().unwrap(), - } - }; - - tx.send(code).unwrap(); - } - - Err(e) => { - info!( - logger, - "wait child error: {} {}", - e, - e.raw_os_error().unwrap() - ); - - // There is apparently race between this wait and - // child reaper. Ie, the child can already - // be reaped by subreaper, child.wait returns - // ECHILD. I have no idea how to get the - // correct exit status here at present, - // just pretend it exits successfully. - // -- FIXME - // just in case. Should not happen any more - - tx.send(0).unwrap(); - } - } - }); - - let pid = rx.recv().unwrap(); - info!(logger, "hook grand: {}", pid); - - let status = { - if let Some(timeout) = h.timeout { - match rx.recv_timeout(Duration::from_secs(timeout as u64)) { - Ok(s) => s, - Err(e) => { - let error = if e == RecvTimeoutError::Timeout { - -libc::ETIMEDOUT - } else { - -libc::EPIPE - }; - let _ = signal::kill(Pid::from_raw(pid), Some(Signal::SIGKILL)); - error - } - } - } else if let Ok(s) = rx.recv() { - s - } else { - let _ = signal::kill(Pid::from_raw(pid), Some(Signal::SIGKILL)); - -libc::EPIPE - } - }; - - handle.join().unwrap(); - let _ = write_sync( - wfd, - SYNC_DATA, - std::str::from_utf8(&status.to_be_bytes()).unwrap_or_default(), - ); - std::process::exit(0); - } + match tokio::time::timeout(Duration::new(timeout, 0), join_handle).await { + Ok(r) => r.unwrap(), + Err(_) => Err(anyhow!(nix::Error::from_errno(Errno::ETIMEDOUT))), } } @@ -1669,6 +1598,34 @@ mod tests { .unwrap() } + #[tokio::test] + async fn test_execute_hook_with_timeout() { + let res = execute_hook( + &slog_scope::logger(), + &Hook { + path: "/usr/bin/sleep".to_string(), + args: vec!["2".to_string()], + env: vec![], + timeout: Some(1), + }, + &OCIState { + version: "1.2.3".to_string(), + id: "321".to_string(), + status: ContainerState::RUNNING, + pid: 2, + bundle: "".to_string(), + annotations: Default::default(), + }, + ) + .await; + + let expected_err = nix::Error::from_errno(Errno::ETIMEDOUT); + assert_eq!( + res.unwrap_err().downcast::().unwrap(), + expected_err + ); + } + #[test] fn test_status_transtition() { let mut status = ContainerStatus::new();