mirror of
				https://github.com/kata-containers/kata-containers.git
				synced 2025-10-31 17:37:20 +00:00 
			
		
		
		
	rustjail: rework execute_hook
Fixes: #1532 Signed-off-by: Tim Zhang <tim@hyper.sh>
This commit is contained in:
		| @@ -1457,11 +1457,9 @@ fn set_sysctls(sysctls: &HashMap<String, String>) -> Result<()> { | |||||||
|     Ok(()) |     Ok(()) | ||||||
| } | } | ||||||
|  |  | ||||||
| use std::io::Read; |  | ||||||
| use std::os::unix::process::ExitStatusExt; |  | ||||||
| use std::process::Stdio; | use std::process::Stdio; | ||||||
| use std::sync::mpsc::{self, RecvTimeoutError}; |  | ||||||
| use std::time::Duration; | use std::time::Duration; | ||||||
|  | use tokio::io::{AsyncReadExt, AsyncWriteExt}; | ||||||
|  |  | ||||||
| async fn execute_hook(logger: &Logger, h: &Hook, st: &OCIState) -> Result<()> { | async fn execute_hook(logger: &Logger, h: &Hook, st: &OCIState) -> Result<()> { | ||||||
|     let logger = logger.new(o!("action" => "execute-hook")); |     let logger = logger.new(o!("action" => "execute-hook")); | ||||||
| @@ -1473,160 +1471,91 @@ async fn execute_hook(logger: &Logger, h: &Hook, st: &OCIState) -> Result<()> { | |||||||
|     } |     } | ||||||
|  |  | ||||||
|     let args = h.args.clone(); |     let args = h.args.clone(); | ||||||
|     let envs = h.env.clone(); |     let env: HashMap<String, String> = h | ||||||
|  |         .env | ||||||
|  |         .iter() | ||||||
|  |         .map(|e| { | ||||||
|  |             let v: Vec<&str> = e.split('=').collect(); | ||||||
|  |             (v[0].to_string(), v[1].to_string()) | ||||||
|  |         }) | ||||||
|  |         .collect(); | ||||||
|  |  | ||||||
|  |     let mut child = tokio::process::Command::new(path) | ||||||
|  |         .args(args.iter()) | ||||||
|  |         .envs(env.iter()) | ||||||
|  |         .kill_on_drop(true) | ||||||
|  |         .stdin(Stdio::piped()) | ||||||
|  |         .stdout(Stdio::piped()) | ||||||
|  |         .stderr(Stdio::piped()) | ||||||
|  |         .spawn() | ||||||
|  |         .unwrap(); | ||||||
|  |  | ||||||
|  |     // default timeout 10s | ||||||
|  |     let mut timeout: u64 = 10; | ||||||
|  |  | ||||||
|  |     // if timeout is set if hook, then use the specified value | ||||||
|  |     if let Some(t) = h.timeout { | ||||||
|  |         if t > 0 { | ||||||
|  |             timeout = t as u64; | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|     let state = serde_json::to_string(st)?; |     let state = serde_json::to_string(st)?; | ||||||
|  |     let path = h.path.clone(); | ||||||
|  |  | ||||||
|     let (rfd, wfd) = unistd::pipe2(OFlag::O_CLOEXEC)?; |     let join_handle = tokio::spawn(async move { | ||||||
|     defer!({ |         child | ||||||
|         let _ = unistd::close(rfd); |             .stdin | ||||||
|         let _ = unistd::close(wfd); |             .as_mut() | ||||||
|     }); |             .unwrap() | ||||||
|  |             .write_all(state.as_bytes()) | ||||||
|  |             .await | ||||||
|  |             .unwrap(); | ||||||
|  |  | ||||||
|     match unistd::fork()? { |         // Close stdin so that hook program could receive EOF | ||||||
|         ForkResult::Parent { child } => { |         child.stdin.take(); | ||||||
|             let mut pipe_r = PipeStream::from_fd(rfd); |  | ||||||
|             let buf = read_async(&mut pipe_r).await?; |  | ||||||
|             let status = if buf.len() == 4 { |  | ||||||
|                 let buf_array: [u8; 4] = [buf[0], buf[1], buf[2], buf[3]]; |  | ||||||
|                 i32::from_be_bytes(buf_array) |  | ||||||
|             } else { |  | ||||||
|                 -libc::EPIPE |  | ||||||
|             }; |  | ||||||
|  |  | ||||||
|             info!(logger, "hook child: {} status: {}", child, status); |         // read something from stdout for debug | ||||||
|  |         let mut out = String::new(); | ||||||
|  |         child | ||||||
|  |             .stdout | ||||||
|  |             .as_mut() | ||||||
|  |             .unwrap() | ||||||
|  |             .read_to_string(&mut out) | ||||||
|  |             .await | ||||||
|  |             .unwrap(); | ||||||
|  |         info!(logger, "child stdout: {}", out.as_str()); | ||||||
|  |  | ||||||
|             if status != 0 { |         match child.wait().await { | ||||||
|                 if status == -libc::ETIMEDOUT { |             Ok(exit) => { | ||||||
|                     return Err(anyhow!(nix::Error::from_errno(Errno::ETIMEDOUT))); |                 let code = match exit.code() { | ||||||
|                 } else if status == -libc::EPIPE { |                     Some(c) => c, | ||||||
|                     return Err(anyhow!(nix::Error::from_errno(Errno::EPIPE))); |                     None => { | ||||||
|  |                         return Err(anyhow!("hook exit status has no status code")); | ||||||
|  |                     } | ||||||
|  |                 }; | ||||||
|  |  | ||||||
|  |                 if code == 0 { | ||||||
|  |                     debug!(logger, "hook {} exit status is 0", &path); | ||||||
|  |                     return Ok(()); | ||||||
|                 } else { |                 } else { | ||||||
|  |                     error!(logger, "hook {} exit status is {}", &path, code); | ||||||
|                     return Err(anyhow!(nix::Error::from_errno(Errno::UnknownErrno))); |                     return Err(anyhow!(nix::Error::from_errno(Errno::UnknownErrno))); | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
|  |             Err(e) => { | ||||||
|             Ok(()) |                 return Err(anyhow!( | ||||||
|  |                     "wait child error: {} {}", | ||||||
|  |                     e, | ||||||
|  |                     e.raw_os_error().unwrap() | ||||||
|  |                 )); | ||||||
|  |             } | ||||||
|         } |         } | ||||||
|  |     }); | ||||||
|  |  | ||||||
|         ForkResult::Child => { |     match tokio::time::timeout(Duration::new(timeout, 0), join_handle).await { | ||||||
|             let (tx, rx) = mpsc::channel(); |         Ok(r) => r.unwrap(), | ||||||
|             let (tx_logger, rx_logger) = mpsc::channel(); |         Err(_) => Err(anyhow!(nix::Error::from_errno(Errno::ETIMEDOUT))), | ||||||
|  |  | ||||||
|             tx_logger.send(logger.clone()).unwrap(); |  | ||||||
|  |  | ||||||
|             let handle = std::thread::spawn(move || { |  | ||||||
|                 let logger = rx_logger.recv().unwrap(); |  | ||||||
|  |  | ||||||
|                 // write oci state to child |  | ||||||
|                 let env: HashMap<String, String> = envs |  | ||||||
|                     .iter() |  | ||||||
|                     .map(|e| { |  | ||||||
|                         let v: Vec<&str> = e.split('=').collect(); |  | ||||||
|                         (v[0].to_string(), v[1].to_string()) |  | ||||||
|                     }) |  | ||||||
|                     .collect(); |  | ||||||
|  |  | ||||||
|                 let mut child = std::process::Command::new(path.to_str().unwrap()) |  | ||||||
|                     .args(args.iter()) |  | ||||||
|                     .envs(env.iter()) |  | ||||||
|                     .stdin(Stdio::piped()) |  | ||||||
|                     .stdout(Stdio::piped()) |  | ||||||
|                     .stderr(Stdio::piped()) |  | ||||||
|                     .spawn() |  | ||||||
|                     .unwrap(); |  | ||||||
|  |  | ||||||
|                 // send out our pid |  | ||||||
|                 tx.send(child.id() as libc::pid_t).unwrap(); |  | ||||||
|                 info!(logger, "hook grand: {}", child.id()); |  | ||||||
|  |  | ||||||
|                 child |  | ||||||
|                     .stdin |  | ||||||
|                     .as_mut() |  | ||||||
|                     .unwrap() |  | ||||||
|                     .write_all(state.as_bytes()) |  | ||||||
|                     .unwrap(); |  | ||||||
|  |  | ||||||
|                 // Close stdin so that hook program could receive EOF. |  | ||||||
|                 child.stdin.take(); |  | ||||||
|  |  | ||||||
|                 // read something from stdout for debug |  | ||||||
|                 let mut out = String::new(); |  | ||||||
|                 child |  | ||||||
|                     .stdout |  | ||||||
|                     .as_mut() |  | ||||||
|                     .unwrap() |  | ||||||
|                     .read_to_string(&mut out) |  | ||||||
|                     .unwrap(); |  | ||||||
|                 info!(logger, "child stdout: {}", out.as_str()); |  | ||||||
|                 match child.wait() { |  | ||||||
|                     Ok(exit) => { |  | ||||||
|                         let code: i32 = if exit.success() { |  | ||||||
|                             0 |  | ||||||
|                         } else { |  | ||||||
|                             match exit.code() { |  | ||||||
|                                 Some(c) => (c as u32 | 0x80000000) as i32, |  | ||||||
|                                 None => exit.signal().unwrap(), |  | ||||||
|                             } |  | ||||||
|                         }; |  | ||||||
|  |  | ||||||
|                         tx.send(code).unwrap(); |  | ||||||
|                     } |  | ||||||
|  |  | ||||||
|                     Err(e) => { |  | ||||||
|                         info!( |  | ||||||
|                             logger, |  | ||||||
|                             "wait child error: {} {}", |  | ||||||
|                             e, |  | ||||||
|                             e.raw_os_error().unwrap() |  | ||||||
|                         ); |  | ||||||
|  |  | ||||||
|                         // There is apparently race between this wait and |  | ||||||
|                         // child reaper. Ie, the child can already |  | ||||||
|                         // be reaped by subreaper, child.wait returns |  | ||||||
|                         // ECHILD. I have no idea how to get the |  | ||||||
|                         // correct exit status here at present, |  | ||||||
|                         // just pretend it exits successfully. |  | ||||||
|                         // -- FIXME |  | ||||||
|                         // just in case. Should not happen any more |  | ||||||
|  |  | ||||||
|                         tx.send(0).unwrap(); |  | ||||||
|                     } |  | ||||||
|                 } |  | ||||||
|             }); |  | ||||||
|  |  | ||||||
|             let pid = rx.recv().unwrap(); |  | ||||||
|             info!(logger, "hook grand: {}", pid); |  | ||||||
|  |  | ||||||
|             let status = { |  | ||||||
|                 if let Some(timeout) = h.timeout { |  | ||||||
|                     match rx.recv_timeout(Duration::from_secs(timeout as u64)) { |  | ||||||
|                         Ok(s) => s, |  | ||||||
|                         Err(e) => { |  | ||||||
|                             let error = if e == RecvTimeoutError::Timeout { |  | ||||||
|                                 -libc::ETIMEDOUT |  | ||||||
|                             } else { |  | ||||||
|                                 -libc::EPIPE |  | ||||||
|                             }; |  | ||||||
|                             let _ = signal::kill(Pid::from_raw(pid), Some(Signal::SIGKILL)); |  | ||||||
|                             error |  | ||||||
|                         } |  | ||||||
|                     } |  | ||||||
|                 } else if let Ok(s) = rx.recv() { |  | ||||||
|                     s |  | ||||||
|                 } else { |  | ||||||
|                     let _ = signal::kill(Pid::from_raw(pid), Some(Signal::SIGKILL)); |  | ||||||
|                     -libc::EPIPE |  | ||||||
|                 } |  | ||||||
|             }; |  | ||||||
|  |  | ||||||
|             handle.join().unwrap(); |  | ||||||
|             let _ = write_sync( |  | ||||||
|                 wfd, |  | ||||||
|                 SYNC_DATA, |  | ||||||
|                 std::str::from_utf8(&status.to_be_bytes()).unwrap_or_default(), |  | ||||||
|             ); |  | ||||||
|             std::process::exit(0); |  | ||||||
|         } |  | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -1669,6 +1598,34 @@ mod tests { | |||||||
|         .unwrap() |         .unwrap() | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     #[tokio::test] | ||||||
|  |     async fn test_execute_hook_with_timeout() { | ||||||
|  |         let res = execute_hook( | ||||||
|  |             &slog_scope::logger(), | ||||||
|  |             &Hook { | ||||||
|  |                 path: "/usr/bin/sleep".to_string(), | ||||||
|  |                 args: vec!["2".to_string()], | ||||||
|  |                 env: vec![], | ||||||
|  |                 timeout: Some(1), | ||||||
|  |             }, | ||||||
|  |             &OCIState { | ||||||
|  |                 version: "1.2.3".to_string(), | ||||||
|  |                 id: "321".to_string(), | ||||||
|  |                 status: ContainerState::RUNNING, | ||||||
|  |                 pid: 2, | ||||||
|  |                 bundle: "".to_string(), | ||||||
|  |                 annotations: Default::default(), | ||||||
|  |             }, | ||||||
|  |         ) | ||||||
|  |         .await; | ||||||
|  |  | ||||||
|  |         let expected_err = nix::Error::from_errno(Errno::ETIMEDOUT); | ||||||
|  |         assert_eq!( | ||||||
|  |             res.unwrap_err().downcast::<nix::Error>().unwrap(), | ||||||
|  |             expected_err | ||||||
|  |         ); | ||||||
|  |     } | ||||||
|  |  | ||||||
|     #[test] |     #[test] | ||||||
|     fn test_status_transtition() { |     fn test_status_transtition() { | ||||||
|         let mut status = ContainerStatus::new(); |         let mut status = ContainerStatus::new(); | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user