diff --git a/src/runtime-rs/crates/runtimes/virt_container/src/container_manager/io/container_io.rs b/src/runtime-rs/crates/runtimes/virt_container/src/container_manager/io/container_io.rs index c211e8bca4..2c243dece7 100644 --- a/src/runtime-rs/crates/runtimes/virt_container/src/container_manager/io/container_io.rs +++ b/src/runtime-rs/crates/runtimes/virt_container/src/container_manager/io/container_io.rs @@ -44,6 +44,8 @@ struct ContainerIoWrite<'inner> { pub info: Arc, write_future: Option> + Send + 'inner>>>, + shutdown_future: + Option> + Send + 'inner>>>, } impl<'inner> ContainerIoWrite<'inner> { @@ -51,6 +53,7 @@ impl<'inner> ContainerIoWrite<'inner> { Self { info, write_future: Default::default(), + shutdown_future: Default::default(), } } @@ -80,6 +83,30 @@ impl<'inner> ContainerIoWrite<'inner> { } } } + + // Call rpc agent.write_stdin() with empty data to tell agent to close stdin of the process + fn poll_shutdown_inner(&'inner mut self, cx: &mut Context<'_>) -> Poll> { + let mut shutdown_future = self.shutdown_future.take(); + if shutdown_future.is_none() { + let req = agent::WriteStreamRequest { + process_id: self.info.process.clone().into(), + data: Vec::with_capacity(0), + }; + shutdown_future = Some(Box::pin(self.info.agent.write_stdin(req))); + } + + let mut shutdown_future = shutdown_future.unwrap(); + match shutdown_future.as_mut().poll(cx) { + Poll::Ready(v) => match v { + Ok(_) => Poll::Ready(Ok(())), + Err(err) => Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::Other, err))), + }, + Poll::Pending => { + self.shutdown_future = Some(shutdown_future); + Poll::Pending + } + } + } } impl<'inner> AsyncWrite for ContainerIoWrite<'inner> { @@ -100,8 +127,13 @@ impl<'inner> AsyncWrite for ContainerIoWrite<'inner> { Poll::Ready(Ok(())) } - fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let me = unsafe { + std::mem::transmute::<&mut ContainerIoWrite<'_>, &mut ContainerIoWrite<'inner>>( + &mut *self, + ) + }; + me.poll_shutdown_inner(cx) } } diff --git a/src/runtime-rs/crates/runtimes/virt_container/src/container_manager/process.rs b/src/runtime-rs/crates/runtimes/virt_container/src/container_manager/process.rs index 34e9a9a60b..17b2fde464 100644 --- a/src/runtime-rs/crates/runtimes/virt_container/src/container_manager/process.rs +++ b/src/runtime-rs/crates/runtimes/virt_container/src/container_manager/process.rs @@ -11,7 +11,7 @@ use agent::Agent; use anyhow::{Context, Result}; use awaitgroup::{WaitGroup, Worker as WaitGroupWorker}; use common::types::{ContainerProcess, ProcessExitStatus, ProcessStateInfo, ProcessStatus, PID}; -use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tokio::sync::{watch, RwLock}; use super::container::Container; @@ -23,6 +23,13 @@ pub type ProcessWatcher = ( Arc>, ); +#[derive(Debug, PartialEq)] +enum StdIoType { + Stdin, + Stdout, + Stderr, +} + #[derive(Debug)] pub struct Process { pub process: ContainerProcess, @@ -62,10 +69,6 @@ pub struct Process { pub exit_status: Arc>, pub exit_watcher_rx: Option>, pub exit_watcher_tx: Option>, - // used to sync between stdin io copy thread(tokio) and the close it call. - // close io call should wait until the stdin io copy finished to - // prevent stdin data lost. - pub wg_stdin: WaitGroup, // io streams using vsock fd passthrough feature pub passfd_io: Option, @@ -119,7 +122,6 @@ impl Process { exit_status: Arc::new(RwLock::new(ProcessExitStatus::new())), exit_watcher_rx: Some(receiver), exit_watcher_tx: Some(sender), - wg_stdin: WaitGroup::new(), passfd_io: None, } } @@ -246,9 +248,8 @@ impl Process { self.post_fifos_open()?; // start io copy for stdin - let wgw_stdin = self.wg_stdin.worker(); if let Some(stdin) = shim_io.stdin { - self.run_io_copy("stdin", wgw_stdin, stdin, container_io.stdin) + self.run_io_copy(StdIoType::Stdin, None, stdin, container_io.stdin) .await?; } @@ -258,14 +259,19 @@ impl Process { // start io copy for stdout if let Some(stdout) = shim_io.stdout { - self.run_io_copy("stdout", wgw.clone(), container_io.stdout, stdout) - .await?; + self.run_io_copy( + StdIoType::Stdout, + Some(wgw.clone()), + container_io.stdout, + stdout, + ) + .await?; } // start io copy for stderr if !self.terminal { if let Some(stderr) = shim_io.stderr { - self.run_io_copy("stderr", wgw, container_io.stderr, stderr) + self.run_io_copy(StdIoType::Stderr, Some(wgw), container_io.stderr, stderr) .await?; } } @@ -276,27 +282,51 @@ impl Process { Ok(()) } - async fn run_io_copy<'a>( - &'a self, - io_name: &'a str, - wgw: WaitGroupWorker, + async fn run_io_copy( + &self, + io_type: StdIoType, + wgw: Option, mut reader: Box, mut writer: Box, ) -> Result<()> { - info!(self.logger, "run io copy for {}", io_name); - let io_name = io_name.to_string(); - let logger = self.logger.new(o!("io_name" => io_name)); + let io_name = format!("{:?}", io_type); + + info!(self.logger, "run_io_copy[{}] starts", io_name); + let logger = self.logger.new(o!("io_name" => io_name.clone())); + tokio::spawn(async move { match tokio::io::copy(&mut reader, &mut writer).await { Err(e) => { - warn!(logger, "run_io_copy: failed to copy stream: {}", e); + warn!( + logger, + "run_io_copy[{}]: failed to copy stream: {}", io_name, e + ); } Ok(length) => { - info!(logger, "run_io_copy: stop to copy stream length {}", length) + info!( + logger, + "run_io_copy[{}]: stop to copy stream length {}", io_name, length + ); + // Send EOF to agent by calling rpc write_stdin with 0 length data + if io_type == StdIoType::Stdin { + writer + .shutdown() + .await + .map_err(|e| { + error!( + logger, + "run_io_copy[{}]: failed to shutdown: {:?}", io_name, e + ); + e + }) + .ok(); + } } }; - wgw.done(); + if let Some(w) = wgw { + w.done() + } }); Ok(()) @@ -400,24 +430,13 @@ impl Process { } /// Close the stdin of the process in container. - pub async fn close_io(&mut self, agent: Arc) { + pub async fn close_io(&mut self, _agent: Arc) { // Close the stdin writer keeper so that // the end signal could be received in the read side self.stdin_w.take(); - // In passfd io mode, the stdin close and sync logic is handled - // in the agent side. - if self.passfd_io.is_none() { - self.wg_stdin.wait().await; - } - - let req = agent::CloseStdinRequest { - process_id: self.process.clone().into(), - }; - - if let Err(e) = agent.close_stdin(req).await { - warn!(self.logger, "failed close process io: {:?}", e); - } + // The stdin will be closed when EOF is got in rpc `read_stdout` of agent + // so we will not call agent.close_stdin anymore. } pub async fn get_status(&self) -> ProcessStatus {