diff --git a/src/agent/Cargo.lock b/src/agent/Cargo.lock index 21600eac9e..7eb2f0bdc5 100644 --- a/src/agent/Cargo.lock +++ b/src/agent/Cargo.lock @@ -161,9 +161,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "cgroups-rs" -version = "0.2.0" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02274214de2526e48355facdd16c9d774bba2cf74d135ffb9876a60b4d613464" +checksum = "348eb6d8e20a9f5247209686b7d0ffc2f4df40ddcb95f9940de55a94a655b3f5" dependencies = [ "libc", "log", @@ -486,6 +486,28 @@ version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "644f9158b2f133fd50f5fb3242878846d9eb792e445c893805ff0e3824006e35" +[[package]] +name = "inotify" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "04c6848dfb1580647ab039713282cdd1ab2bfb47b60ecfb598e22e60e3baf3f8" +dependencies = [ + "bitflags", + "futures-core", + "inotify-sys", + "libc", + "tokio 0.3.6", +] + +[[package]] +name = "inotify-sys" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4563555856585ab3180a5bf0b2f9f8d301a728462afffc8195b3f5394229c55" +dependencies = [ + "libc", +] + [[package]] name = "iovec" version = "0.1.4" @@ -517,11 +539,13 @@ dependencies = [ "anyhow", "async-trait", "cgroups-rs", + "futures", "lazy_static", "libc", "log", "logging", "netlink", + "netlink-sys", "nix 0.17.0", "oci", "prctl", @@ -539,7 +563,8 @@ dependencies = [ "slog-scope", "slog-stdlog", "tempfile", - "tokio", + "tokio 0.2.24", + "tokio-vsock", "ttrpc", ] @@ -647,12 +672,37 @@ dependencies = [ "kernel32-sys", "libc", "log", - "miow", + "miow 0.2.2", "net2", "slab", "winapi 0.2.8", ] +[[package]] +name = "mio" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f33bc887064ef1fd66020c9adfc45bb9f33d75a42096c81e7c56c65b75dd1a8b" +dependencies = [ + "libc", + 
"log", + "miow 0.3.6", + "ntapi", + "winapi 0.3.9", +] + +[[package]] +name = "mio-named-pipes" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0840c1c50fd55e521b247f949c241c9997709f23bd7f023b9762cd561e935656" +dependencies = [ + "log", + "mio 0.6.23", + "miow 0.3.6", + "winapi 0.3.9", +] + [[package]] name = "mio-uds" version = "0.6.8" @@ -661,7 +711,7 @@ checksum = "afcb699eb26d4332647cc848492bbc15eafb26f08d0304550d5aa1f612e066f0" dependencies = [ "iovec", "libc", - "mio", + "mio 0.6.23", ] [[package]] @@ -676,6 +726,16 @@ dependencies = [ "ws2_32-sys", ] +[[package]] +name = "miow" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a33c1b55807fbed163481b5ba66db4b2fa6cde694a5027be10fb724206c5897" +dependencies = [ + "socket2", + "winapi 0.3.9", +] + [[package]] name = "multimap" version = "0.4.0" @@ -705,6 +765,19 @@ dependencies = [ "slog-scope", ] +[[package]] +name = "netlink-sys" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc9e9df13fd91bdd4b92bea93d5d2848c8035677c60fc3fee5dabddc02c3012e" +dependencies = [ + "futures", + "libc", + "log", + "mio 0.6.23", + "tokio 0.2.24", +] + [[package]] name = "nix" version = "0.16.1" @@ -761,6 +834,15 @@ version = "2.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cf51a729ecf40266a2368ad335a5fdde43471f545a967109cd62146ecf8b66ff" +[[package]] +name = "ntapi" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f6bb902e437b6d86e03cce10a7e2af662292c5dfef23b65899ea3ac9354ad44" +dependencies = [ + "winapi 0.3.9", +] + [[package]] name = "num-integer" version = "0.1.43" @@ -890,6 +972,12 @@ version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c917123afa01924fc84bb20c4c03f004d9c38e5127e3c039bbf7f4b9c76a2f6b" +[[package]] +name = "pin-project-lite" +version = 
"0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6b063f57ec186e6140e2b8b6921e5f1bd89c7356dda5b33acc5401203ca6131c" + [[package]] name = "pin-utils" version = "0.1.0" @@ -1213,12 +1301,16 @@ name = "rustjail" version = "0.1.0" dependencies = [ "anyhow", + "async-trait", "caps", "cgroups-rs", "dirs", "epoll", + "futures", + "inotify", "lazy_static", "libc", + "mio 0.6.23", "nix 0.17.0", "oci", "path-absolutize", @@ -1235,6 +1327,7 @@ dependencies = [ "slog", "slog-scope", "tempfile", + "tokio 0.2.24", ] [[package]] @@ -1413,6 +1506,17 @@ version = "1.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fbee7696b84bbf3d89a1c2eccff0850e3047ed46bfcd2e92c29a2d074d57e252" +[[package]] +name = "socket2" +version = "0.3.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "122e570113d28d773067fab24266b66753f6ea915758651696b6e35e49f88d6e" +dependencies = [ + "cfg-if 1.0.0", + "libc", + "winapi 0.3.9", +] + [[package]] name = "spin" version = "0.5.2" @@ -1513,12 +1617,27 @@ dependencies = [ "lazy_static", "libc", "memchr", - "mio", + "mio 0.6.23", + "mio-named-pipes", "mio-uds", "num_cpus", - "pin-project-lite", + "pin-project-lite 0.1.11", + "signal-hook-registry", "slab", "tokio-macros", + "winapi 0.3.9", +] + +[[package]] +name = "tokio" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "720ba21c25078711bf456d607987d95bce90f7c3bea5abe1db587862e7a1e87c" +dependencies = [ + "autocfg", + "libc", + "mio 0.7.6", + "pin-project-lite 0.2.0", ] [[package]] @@ -1542,17 +1661,17 @@ dependencies = [ "futures", "iovec", "libc", - "mio", + "mio 0.6.23", "nix 0.17.0", - "tokio", + "tokio 0.2.24", "vsock", ] [[package]] name = "ttrpc" -version = "0.4.13" +version = "0.4.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d6e99ffa09e7fbe514b58b01bd17d71e3ed4dd27c588afa43d41ec0b7fc90b0a" +checksum = 
"fc512242eee1f113eadd48087dd97cbf807ccae4820006e7a890044044399c51" dependencies = [ "async-trait", "byteorder", @@ -1563,7 +1682,7 @@ dependencies = [ "protobuf", "protobuf-codegen-pure", "thiserror", - "tokio", + "tokio 0.2.24", "tokio-vsock", ] diff --git a/src/agent/Cargo.toml b/src/agent/Cargo.toml index 1fc3e287e7..1d75a2d8cb 100644 --- a/src/agent/Cargo.toml +++ b/src/agent/Cargo.toml @@ -11,7 +11,7 @@ rustjail = { path = "rustjail" } protocols = { path = "protocols" } netlink = { path = "netlink", features = ["with-log", "with-agent-handler"] } lazy_static = "1.3.0" -ttrpc = { version="0.4.13", features=["async"] } +ttrpc = { version = "0.4.14", features = ["async", "protobuf-codec"], default-features = false } protobuf = "=2.14.0" libc = "0.2.58" nix = "0.17.0" @@ -21,8 +21,12 @@ signal-hook = "0.1.9" scan_fmt = "0.2.3" scopeguard = "1.0.0" regex = "1" -tokio = { version="0.2", features = ["macros", "rt-threaded"] } + async-trait = "0.1.42" +tokio = { version = "0.2", features = ["rt-core", "sync", "uds", "stream", "macros", "io-util", "time", "signal", "io-std", "process",] } +futures = "0.3" +netlink-sys = { version = "0.4.0", features = ["tokio_socket",]} +tokio-vsock = "0.2.2" # slog: # - Dynamic keys required to allow HashMap keys to be slog::Serialized. 
@@ -40,7 +44,7 @@ tempfile = "3.1.0" prometheus = { version = "0.9.0", features = ["process"] } procfs = "0.7.9" anyhow = "1.0.32" -cgroups = { package = "cgroups-rs", version = "0.2.0" } +cgroups = { package = "cgroups-rs", version = "0.2.1" } [workspace] members = [ diff --git a/src/agent/protocols/Cargo.toml b/src/agent/protocols/Cargo.toml index f44c4bc501..8d9eca6c2b 100644 --- a/src/agent/protocols/Cargo.toml +++ b/src/agent/protocols/Cargo.toml @@ -5,7 +5,7 @@ authors = ["The Kata Containers community "] edition = "2018" [dependencies] -ttrpc = { version="0.4.13", features=["async"] } +ttrpc = { version = "0.4.14", features = ["async"] } async-trait = "0.1.42" protobuf = "=2.14.0" diff --git a/src/agent/rustjail/Cargo.toml b/src/agent/rustjail/Cargo.toml index fef0af84f9..af8623b044 100644 --- a/src/agent/rustjail/Cargo.toml +++ b/src/agent/rustjail/Cargo.toml @@ -16,7 +16,7 @@ scopeguard = "1.0.0" prctl = "1.0.0" lazy_static = "1.3.0" libc = "0.2.58" -protobuf = "2.8.1" +protobuf = "=2.14.0" slog = "2.5.2" slog-scope = "4.1.2" scan_fmt = "0.2" @@ -24,9 +24,15 @@ regex = "1.1" path-absolutize = "1.2.0" dirs = "3.0.1" anyhow = "1.0.32" -cgroups = { package = "cgroups-rs", version = "0.2.0" } +cgroups = { package = "cgroups-rs", version = "0.2.1" } tempfile = "3.1.0" epoll = "4.3.1" +tokio = { version = "0.2", features = ["sync", "io-util", "process", "time", "macros"] } +futures = "0.3" +async-trait = "0.1.31" +mio = "0.6" +inotify = "0.9" + [dev-dependencies] serial_test = "0.5.0" diff --git a/src/agent/rustjail/src/cgroups/notifier.rs b/src/agent/rustjail/src/cgroups/notifier.rs index cbe03c9806..c81da53b4a 100644 --- a/src/agent/rustjail/src/cgroups/notifier.rs +++ b/src/agent/rustjail/src/cgroups/notifier.rs @@ -3,16 +3,18 @@ // SPDX-License-Identifier: Apache-2.0 // -use anyhow::{anyhow, Result}; +use anyhow::{anyhow, Context, Result}; use eventfd::{eventfd, EfdFlags}; use nix::sys::eventfd; -use nix::sys::inotify::{AddWatchFlags, InitFlags, Inotify}; 
use std::fs::{self, File}; -use std::io::Read; use std::os::unix::io::{AsRawFd, FromRawFd}; use std::path::{Path, PathBuf}; -use std::sync::mpsc::{self, Receiver}; -use std::thread; + +use crate::pipestream::PipeStream; +use futures::StreamExt as _; +use inotify::{Inotify, WatchMask}; +use tokio::io::AsyncReadExt; +use tokio::sync::mpsc::{channel, Receiver}; // Convenience macro to obtain the scope logger macro_rules! sl { @@ -21,11 +23,11 @@ macro_rules! sl { }; } -pub fn notify_oom(cid: &str, cg_dir: String) -> Result> { +pub async fn notify_oom(cid: &str, cg_dir: String) -> Result> { if cgroups::hierarchies::is_cgroup2_unified_mode() { - return notify_on_oom_v2(cid, cg_dir); + return notify_on_oom_v2(cid, cg_dir).await; } - notify_on_oom(cid, cg_dir) + notify_on_oom(cid, cg_dir).await } // get_value_from_cgroup parse cgroup file with `Flat keyed` @@ -52,11 +54,11 @@ fn get_value_from_cgroup(path: &PathBuf, key: &str) -> Result { // notify_on_oom returns channel on which you can expect event about OOM, // if process died without OOM this channel will be closed. 
-pub fn notify_on_oom_v2(containere_id: &str, cg_dir: String) -> Result> { - register_memory_event_v2(containere_id, cg_dir, "memory.events", "cgroup.events") +pub async fn notify_on_oom_v2(containere_id: &str, cg_dir: String) -> Result> { + register_memory_event_v2(containere_id, cg_dir, "memory.events", "cgroup.events").await } -fn register_memory_event_v2( +async fn register_memory_event_v2( containere_id: &str, cg_dir: String, memory_event_name: &str, @@ -73,54 +75,54 @@ fn register_memory_event_v2( "register_memory_event_v2 cgroup_event_control_path: {:?}", &cgroup_event_control_path ); - let fd = Inotify::init(InitFlags::empty()).unwrap(); + let mut inotify = Inotify::init().context("Failed to initialize inotify")?; // watching oom kill - let ev_fd = fd - .add_watch(&event_control_path, AddWatchFlags::IN_MODIFY) - .unwrap(); + let ev_wd = inotify.add_watch(&event_control_path, WatchMask::MODIFY)?; // Because no `unix.IN_DELETE|unix.IN_DELETE_SELF` event for cgroup file system, so watching all process exited - let cg_fd = fd - .add_watch(&cgroup_event_control_path, AddWatchFlags::IN_MODIFY) - .unwrap(); - info!(sl!(), "ev_fd: {:?}", ev_fd); - info!(sl!(), "cg_fd: {:?}", cg_fd); + let cg_wd = inotify.add_watch(&cgroup_event_control_path, WatchMask::MODIFY)?; - let (sender, receiver) = mpsc::channel(); + info!(sl!(), "ev_wd: {:?}", ev_wd); + info!(sl!(), "cg_wd: {:?}", cg_wd); + + let (mut sender, receiver) = channel(100); let containere_id = containere_id.to_string(); - thread::spawn(move || { - loop { - let events = fd.read_events().unwrap(); + tokio::spawn(async move { + let mut buffer = [0; 32]; + let mut stream = inotify + .event_stream(&mut buffer) + .expect("create inotify event stream failed"); + + while let Some(event_or_error) = stream.next().await { + let event = event_or_error.unwrap(); info!( sl!(), - "container[{}] get events for container: {:?}", &containere_id, &events + "container[{}] get event for container: {:?}", &containere_id, &event ); + 
// info!("is1: {}", event.wd == wd1); + info!(sl!(), "event.wd: {:?}", event.wd); - for event in events { - if event.mask & AddWatchFlags::IN_MODIFY != AddWatchFlags::IN_MODIFY { - continue; + if event.wd == ev_wd { + let oom = get_value_from_cgroup(&event_control_path, "oom_kill"); + if oom.unwrap_or(0) > 0 { + let _ = sender.send(containere_id.clone()).await.map_err(|e| { + error!(sl!(), "send containere_id failed, error: {:?}", e); + }); + return; } - info!(sl!(), "event.wd: {:?}", event.wd); - - if event.wd == ev_fd { - let oom = get_value_from_cgroup(&event_control_path, "oom_kill"); - if oom.unwrap_or(0) > 0 { - sender.send(containere_id.clone()).unwrap(); - return; - } - } else if event.wd == cg_fd { - let pids = get_value_from_cgroup(&cgroup_event_control_path, "populated"); - if pids.unwrap_or(-1) == 0 { - return; - } + } else if event.wd == cg_wd { + let pids = get_value_from_cgroup(&cgroup_event_control_path, "populated"); + if pids.unwrap_or(-1) == 0 { + return; } } - // When a cgroup is destroyed, an event is sent to eventfd. - // So if the control path is gone, return instead of notifying. - if !Path::new(&event_control_path).exists() { - return; - } + } + + // When a cgroup is destroyed, an event is sent to eventfd. + // So if the control path is gone, return instead of notifying. + if !Path::new(&event_control_path).exists() { + return; } }); @@ -129,16 +131,16 @@ fn register_memory_event_v2( // notify_on_oom returns channel on which you can expect event about OOM, // if process died without OOM this channel will be closed. 
-fn notify_on_oom(cid: &str, dir: String) -> Result> { +async fn notify_on_oom(cid: &str, dir: String) -> Result> { if dir == "" { return Err(anyhow!("memory controller missing")); } - register_memory_event(cid, dir, "memory.oom_control", "") + register_memory_event(cid, dir, "memory.oom_control", "").await } // level is one of "low", "medium", or "critical" -fn notify_memory_pressure(cid: &str, dir: String, level: &str) -> Result> { +async fn notify_memory_pressure(cid: &str, dir: String, level: &str) -> Result> { if dir == "" { return Err(anyhow!("memory controller missing")); } @@ -147,10 +149,10 @@ fn notify_memory_pressure(cid: &str, dir: String, level: &str) -> Result { warn!(sl!(), "failed to read from eventfd: {:?}", err); return; @@ -198,7 +201,10 @@ fn register_memory_event( if !Path::new(&event_control_path).exists() { return; } - sender.send(containere_id.clone()).unwrap(); + + let _ = sender.send(containere_id.clone()).await.map_err(|e| { + error!(sl!(), "send containere_id failed, error: {:?}", e); + }); } }); diff --git a/src/agent/rustjail/src/container.rs b/src/agent/rustjail/src/container.rs index 14531757c1..94a284a8d1 100644 --- a/src/agent/rustjail/src/container.rs +++ b/src/agent/rustjail/src/container.rs @@ -14,7 +14,6 @@ use std::fmt::Display; use std::fs; use std::os::unix::io::RawFd; use std::path::{Path, PathBuf}; -use std::process::Command; use std::time::SystemTime; use cgroups::freezer::FreezerState; @@ -28,7 +27,6 @@ use crate::cgroups::Manager; use crate::log_child; use crate::process::Process; use crate::specconv::CreateOpts; -use crate::sync::*; use crate::{mount, validator}; use protocols::agent::StatsContainerResponse; @@ -49,12 +47,16 @@ use protobuf::SingularPtrField; use oci::State as OCIState; use std::collections::HashMap; -use std::io::BufRead; -use std::io::BufReader; use std::os::unix::io::FromRawFd; use slog::{info, o, Logger}; +use crate::pipestream::PipeStream; +use crate::sync::{read_sync, write_count, write_sync, 
SYNC_DATA, SYNC_FAILED, SYNC_SUCCESS}; +use crate::sync_with_async::{read_async, write_async}; +use async_trait::async_trait; +use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt}; + const STATE_FILENAME: &str = "state.json"; const EXEC_FIFO_FILENAME: &str = "exec.fifo"; const VER_MARKER: &str = "1.2.5"; @@ -215,6 +217,7 @@ pub struct BaseState { init_process_start: u64, } +#[async_trait] pub trait BaseContainer { fn id(&self) -> String; fn status(&self) -> Status; @@ -225,9 +228,9 @@ pub trait BaseContainer { fn get_process(&mut self, eid: &str) -> Result<&mut Process>; fn stats(&self) -> Result; fn set(&mut self, config: LinuxResources) -> Result<()>; - fn start(&mut self, p: Process) -> Result<()>; - fn run(&mut self, p: Process) -> Result<()>; - fn destroy(&mut self) -> Result<()>; + async fn start(&mut self, p: Process) -> Result<()>; + async fn run(&mut self, p: Process) -> Result<()>; + async fn destroy(&mut self) -> Result<()>; fn signal(&self, sig: Signal, all: bool) -> Result<()>; fn exec(&mut self) -> Result<()>; } @@ -273,6 +276,7 @@ pub struct SyncPC { pid: pid_t, } +#[async_trait] pub trait Container: BaseContainer { fn pause(&mut self) -> Result<()>; fn resume(&mut self) -> Result<()>; @@ -723,6 +727,7 @@ fn set_stdio_permissions(uid: libc::uid_t) -> Result<()> { Ok(()) } +#[async_trait] impl BaseContainer for LinuxContainer { fn id(&self) -> String { self.id.clone() @@ -816,7 +821,7 @@ impl BaseContainer for LinuxContainer { Ok(()) } - fn start(&mut self, mut p: Process) -> Result<()> { + async fn start(&mut self, mut p: Process) -> Result<()> { let logger = self.logger.new(o!("eid" => p.exec_id.clone())); let tty = p.tty; let fifo_file = format!("{}/{}", &self.root, EXEC_FIFO_FILENAME); @@ -854,7 +859,7 @@ impl BaseContainer for LinuxContainer { .map_err(|e| warn!(logger, "fcntl pfd log FD_CLOEXEC {:?}", e)); let child_logger = logger.new(o!("action" => "child process log")); - let log_handler = setup_child_logger(pfd_log, child_logger)?; 
+ let log_handler = setup_child_logger(pfd_log, child_logger); let (prfd, cwfd) = unistd::pipe().context("failed to create pipe")?; let (crfd, pwfd) = unistd::pipe().context("failed to create pipe")?; @@ -865,10 +870,8 @@ impl BaseContainer for LinuxContainer { let _ = fcntl::fcntl(pwfd, FcntlArg::F_SETFD(FdFlag::FD_CLOEXEC)) .map_err(|e| warn!(logger, "fcntl pwfd FD_COLEXEC {:?}", e)); - defer!({ - let _ = unistd::close(prfd).map_err(|e| warn!(logger, "close prfd {:?}", e)); - let _ = unistd::close(pwfd).map_err(|e| warn!(logger, "close pwfd {:?}", e)); - }); + let mut pipe_r = PipeStream::from_fd(prfd); + let mut pipe_w = PipeStream::from_fd(pwfd); let child_stdin: std::process::Stdio; let child_stdout: std::process::Stdio; @@ -928,7 +931,7 @@ impl BaseContainer for LinuxContainer { unistd::close(cfd_log)?; // get container process's pid - let pid_buf = read_sync(prfd)?; + let pid_buf = read_async(&mut pipe_r).await?; let pid_str = std::str::from_utf8(&pid_buf).context("get pid string")?; let pid = match pid_str.parse::() { Ok(i) => i, @@ -958,9 +961,10 @@ impl BaseContainer for LinuxContainer { &p, self.cgroup_manager.as_ref().unwrap(), &st, - pwfd, - prfd, + &mut pipe_w, + &mut pipe_r, ) + .await .map_err(|e| { error!(logger, "create container process error {:?}", e); // kill the child process. 
@@ -995,15 +999,15 @@ impl BaseContainer for LinuxContainer { info!(logger, "wait on child log handler"); let _ = log_handler - .join() + .await .map_err(|e| warn!(logger, "joining log handler {:?}", e)); info!(logger, "create process completed"); Ok(()) } - fn run(&mut self, p: Process) -> Result<()> { + async fn run(&mut self, p: Process) -> Result<()> { let init = p.init; - self.start(p)?; + self.start(p).await?; if init { self.exec()?; @@ -1013,7 +1017,7 @@ impl BaseContainer for LinuxContainer { Ok(()) } - fn destroy(&mut self) -> Result<()> { + async fn destroy(&mut self) -> Result<()> { let spec = self.config.spec.as_ref().unwrap(); let st = self.oci_state()?; @@ -1025,7 +1029,7 @@ impl BaseContainer for LinuxContainer { info!(self.logger, "poststop"); let hooks = spec.hooks.as_ref().unwrap(); for h in hooks.poststop.iter() { - execute_hook(&self.logger, h, &st)?; + execute_hook(&self.logger, h, &st).await?; } } @@ -1177,42 +1181,38 @@ fn get_namespaces(linux: &Linux) -> Vec { .collect() } -pub fn setup_child_logger(fd: RawFd, child_logger: Logger) -> Result> { - let builder = thread::Builder::new(); - builder - .spawn(move || { - let log_file = unsafe { std::fs::File::from_raw_fd(fd) }; - let mut reader = BufReader::new(log_file); +pub fn setup_child_logger(fd: RawFd, child_logger: Logger) -> tokio::task::JoinHandle<()> { + tokio::spawn(async move { + let log_file_stream = PipeStream::from_fd(fd); + let buf_reader_stream = tokio::io::BufReader::new(log_file_stream); + let mut lines = buf_reader_stream.lines(); - loop { - let mut line = String::new(); - match reader.read_line(&mut line) { - Err(e) => { - info!(child_logger, "read child process log error: {:?}", e); - break; - } - Ok(count) => { - if count == 0 { - info!(child_logger, "read child process log end",); - break; - } - - info!(child_logger, "{}", line); - } + loop { + match lines.next_line().await { + Err(e) => { + info!(child_logger, "read child process log error: {:?}", e); + break; + } + 
Ok(Some(line)) => { + info!(child_logger, "{}", line); + } + Ok(None) => { + info!(child_logger, "read child process log end",); + break; } } - }) - .map_err(|e| anyhow!(e).context("failed to create thread")) + } + }) } -fn join_namespaces( +async fn join_namespaces( logger: &Logger, spec: &Spec, p: &Process, cm: &FsManager, st: &OCIState, - pwfd: RawFd, - prfd: RawFd, + pipe_w: &mut PipeStream, + pipe_r: &mut PipeStream, ) -> Result<()> { let logger = logger.new(o!("action" => "join-namespaces")); @@ -1223,25 +1223,25 @@ fn join_namespaces( info!(logger, "try to send spec from parent to child"); let spec_str = serde_json::to_string(spec)?; - write_sync(pwfd, SYNC_DATA, spec_str.as_str())?; + write_async(pipe_w, SYNC_DATA, spec_str.as_str()).await?; info!(logger, "wait child received oci spec"); - read_sync(prfd)?; + read_async(pipe_r).await?; info!(logger, "send oci process from parent to child"); let process_str = serde_json::to_string(&p.oci)?; - write_sync(pwfd, SYNC_DATA, process_str.as_str())?; + write_async(pipe_w, SYNC_DATA, process_str.as_str()).await?; info!(logger, "wait child received oci process"); - read_sync(prfd)?; + read_async(pipe_r).await?; let cm_str = serde_json::to_string(cm)?; - write_sync(pwfd, SYNC_DATA, cm_str.as_str())?; + write_async(pipe_w, SYNC_DATA, cm_str.as_str()).await?; // wait child setup user namespace info!(logger, "wait child setup user namespace"); - read_sync(prfd)?; + read_async(pipe_r).await?; if userns { info!(logger, "setup uid/gid mappings"); @@ -1270,11 +1270,11 @@ fn join_namespaces( info!(logger, "notify child to continue"); // notify child to continue - write_sync(pwfd, SYNC_SUCCESS, "")?; + write_async(pipe_w, SYNC_SUCCESS, "").await?; if p.init { info!(logger, "notify child parent ready to run prestart hook!"); - let _ = read_sync(prfd)?; + let _ = read_async(pipe_r).await?; info!(logger, "get ready to run prestart hook!"); @@ -1283,17 +1283,17 @@ fn join_namespaces( info!(logger, "prestart hook"); let hooks = 
spec.hooks.as_ref().unwrap(); for h in hooks.prestart.iter() { - execute_hook(&logger, h, st)?; + execute_hook(&logger, h, st).await?; } } // notify child run prestart hooks completed info!(logger, "notify child run prestart hook completed!"); - write_sync(pwfd, SYNC_SUCCESS, "")?; + write_async(pipe_w, SYNC_SUCCESS, "").await?; info!(logger, "notify child parent ready to run poststart hook!"); // wait to run poststart hook - read_sync(prfd)?; + read_async(pipe_r).await?; info!(logger, "get ready to run poststart hook!"); // run poststart hook @@ -1301,13 +1301,13 @@ fn join_namespaces( info!(logger, "poststart hook"); let hooks = spec.hooks.as_ref().unwrap(); for h in hooks.poststart.iter() { - execute_hook(&logger, h, st)?; + execute_hook(&logger, h, st).await?; } } } info!(logger, "wait for child process ready to run exec"); - read_sync(prfd)?; + read_async(pipe_r).await?; Ok(()) } @@ -1509,14 +1509,11 @@ fn set_sysctls(sysctls: &HashMap) -> Result<()> { Ok(()) } -use std::io::Read; use std::os::unix::process::ExitStatusExt; use std::process::Stdio; -use std::sync::mpsc::{self, RecvTimeoutError}; -use std::thread; use std::time::Duration; -fn execute_hook(logger: &Logger, h: &Hook, st: &OCIState) -> Result<()> { +async fn execute_hook(logger: &Logger, h: &Hook, st: &OCIState) -> Result<()> { let logger = logger.new(o!("action" => "execute-hook")); let binary = PathBuf::from(h.path.as_str()); @@ -1535,9 +1532,12 @@ fn execute_hook(logger: &Logger, h: &Hook, st: &OCIState) -> Result<()> { let _ = unistd::close(wfd); }); + let mut pipe_r = PipeStream::from_fd(rfd); + let mut pipe_w = PipeStream::from_fd(wfd); + match unistd::fork()? 
{ ForkResult::Parent { child } => { - let buf = read_sync(rfd)?; + let buf = read_async(&mut pipe_r).await?; let status = if buf.len() == 4 { let buf_array: [u8; 4] = [buf[0], buf[1], buf[2], buf[3]]; i32::from_be_bytes(buf_array) @@ -1561,13 +1561,13 @@ fn execute_hook(logger: &Logger, h: &Hook, st: &OCIState) -> Result<()> { } ForkResult::Child => { - let (tx, rx) = mpsc::channel(); - let (tx_logger, rx_logger) = mpsc::channel(); + let (mut tx, mut rx) = tokio::sync::mpsc::channel(100); + let (tx_logger, rx_logger) = tokio::sync::oneshot::channel(); tx_logger.send(logger.clone()).unwrap(); - let handle = thread::spawn(move || { - let logger = rx_logger.recv().unwrap(); + let handle = tokio::spawn(async move { + let logger = rx_logger.await.unwrap(); // write oci state to child let env: HashMap = envs @@ -1578,7 +1578,7 @@ fn execute_hook(logger: &Logger, h: &Hook, st: &OCIState) -> Result<()> { }) .collect(); - let mut child = Command::new(path.to_str().unwrap()) + let mut child = tokio::process::Command::new(path.to_str().unwrap()) .args(args.iter()) .envs(env.iter()) .stdin(Stdio::piped()) @@ -1588,7 +1588,7 @@ fn execute_hook(logger: &Logger, h: &Hook, st: &OCIState) -> Result<()> { .unwrap(); // send out our pid - tx.send(child.id() as libc::pid_t).unwrap(); + tx.send(child.id() as libc::pid_t).await.unwrap(); info!(logger, "hook grand: {}", child.id()); child @@ -1596,6 +1596,7 @@ fn execute_hook(logger: &Logger, h: &Hook, st: &OCIState) -> Result<()> { .as_mut() .unwrap() .write_all(state.as_bytes()) + .await .unwrap(); // read something from stdout for debug @@ -1605,9 +1606,10 @@ fn execute_hook(logger: &Logger, h: &Hook, st: &OCIState) -> Result<()> { .as_mut() .unwrap() .read_to_string(&mut out) + .await .unwrap(); info!(logger, "child stdout: {}", out.as_str()); - match child.wait() { + match child.await { Ok(exit) => { let code: i32 = if exit.success() { 0 @@ -1618,7 +1620,7 @@ fn execute_hook(logger: &Logger, h: &Hook, st: &OCIState) -> Result<()> { 
} }; - tx.send(code).unwrap(); + tx.send(code).await.unwrap(); } Err(e) => { @@ -1638,29 +1640,33 @@ fn execute_hook(logger: &Logger, h: &Hook, st: &OCIState) -> Result<()> { // -- FIXME // just in case. Should not happen any more - tx.send(0).unwrap(); + tx.send(0).await.unwrap(); } } }); - let pid = rx.recv().unwrap(); + let pid = rx.recv().await.unwrap(); info!(logger, "hook grand: {}", pid); let status = { if let Some(timeout) = h.timeout { - match rx.recv_timeout(Duration::from_secs(timeout as u64)) { - Ok(s) => s, - Err(e) => { - let error = if e == RecvTimeoutError::Timeout { - -libc::ETIMEDOUT - } else { - -libc::EPIPE - }; + let timeout = tokio::time::delay_for(Duration::from_secs(timeout as u64)); + tokio::select! { + v = rx.recv() => { + match v { + Some(s) => s, + None => { + let _ = signal::kill(Pid::from_raw(pid), Some(Signal::SIGKILL)); + -libc::EPIPE + } + } + } + _ = timeout => { let _ = signal::kill(Pid::from_raw(pid), Some(Signal::SIGKILL)); - error + -libc::ETIMEDOUT } } - } else if let Ok(s) = rx.recv() { + } else if let Some(s) = rx.recv().await { s } else { let _ = signal::kill(Pid::from_raw(pid), Some(Signal::SIGKILL)); @@ -1668,12 +1674,13 @@ fn execute_hook(logger: &Logger, h: &Hook, st: &OCIState) -> Result<()> { } }; - handle.join().unwrap(); - let _ = write_sync( - wfd, + handle.await.unwrap(); + let _ = write_async( + &mut pipe_w, SYNC_DATA, std::str::from_utf8(&status.to_be_bytes()).unwrap_or_default(), - ); + ) + .await; std::process::exit(0); } } diff --git a/src/agent/rustjail/src/lib.rs b/src/agent/rustjail/src/lib.rs index e88dde40c4..59b5b26721 100644 --- a/src/agent/rustjail/src/lib.rs +++ b/src/agent/rustjail/src/lib.rs @@ -40,10 +40,12 @@ pub mod capabilities; pub mod cgroups; pub mod container; pub mod mount; +pub mod pipestream; pub mod process; pub mod reaper; pub mod specconv; pub mod sync; +pub mod sync_with_async; pub mod validator; // pub mod factory; diff --git a/src/agent/rustjail/src/pipestream.rs 
b/src/agent/rustjail/src/pipestream.rs new file mode 100644 index 0000000000..5d08cdf893 --- /dev/null +++ b/src/agent/rustjail/src/pipestream.rs @@ -0,0 +1,159 @@ +// Copyright (c) 2020 Ant Group +// +// SPDX-License-Identifier: Apache-2.0 +// + +//! Async support for pipe or something has file descriptor + +use std::{ + fmt, io, mem, + os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd}, + pin::Pin, + task::{Context, Poll}, +}; + +use mio::event::Evented; +use mio::unix::EventedFd; +use mio::{Poll as MioPoll, PollOpt, Ready, Token}; +use nix::unistd; +use tokio::io::{AsyncRead, AsyncWrite, PollEvented}; + +unsafe fn set_nonblocking(fd: RawFd) { + libc::fcntl(fd, libc::F_SETFL, libc::O_NONBLOCK); +} + +struct StreamFd(RawFd); + +impl Evented for StreamFd { + fn register( + &self, + poll: &MioPoll, + token: Token, + interest: Ready, + opts: PollOpt, + ) -> io::Result<()> { + EventedFd(&self.0).register(poll, token, interest, opts) + } + + fn reregister( + &self, + poll: &MioPoll, + token: Token, + interest: Ready, + opts: PollOpt, + ) -> io::Result<()> { + EventedFd(&self.0).reregister(poll, token, interest, opts) + } + + fn deregister(&self, poll: &MioPoll) -> io::Result<()> { + EventedFd(&self.0).deregister(poll) + } +} + +impl io::Read for StreamFd { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + match unistd::read(self.0, buf) { + Ok(l) => Ok(l), + Err(e) => Err(e.as_errno().unwrap().into()), + } + } +} + +impl io::Write for StreamFd { + fn write(&mut self, buf: &[u8]) -> io::Result { + match unistd::write(self.0, buf) { + Ok(l) => Ok(l), + Err(e) => Err(e.as_errno().unwrap().into()), + } + } + + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } +} + +impl StreamFd { + fn close(&mut self) -> io::Result<()> { + match unistd::close(self.0) { + Ok(()) => Ok(()), + Err(e) => Err(e.as_errno().unwrap().into()), + } + } +} + +impl Drop for StreamFd { + fn drop(&mut self) { + self.close().ok(); + } +} + +/// Pipe read +pub struct PipeStream(PollEvented); + 
+impl PipeStream { + pub fn shutdown(&mut self) -> io::Result<()> { + self.0.get_mut().close() + } + + pub fn from_fd(fd: RawFd) -> Self { + unsafe { Self::from_raw_fd(fd) } + } +} + +impl AsyncRead for PipeStream { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + Pin::new(&mut self.0).poll_read(cx, buf) + } +} + +impl AsRawFd for PipeStream { + fn as_raw_fd(&self) -> RawFd { + self.0.get_ref().0 + } +} + +impl IntoRawFd for PipeStream { + fn into_raw_fd(self) -> RawFd { + let fd = self.0.get_ref().0; + mem::forget(self); + fd + } +} + +impl FromRawFd for PipeStream { + unsafe fn from_raw_fd(fd: RawFd) -> Self { + set_nonblocking(fd); + PipeStream(PollEvented::new(StreamFd(fd)).unwrap()) + } +} + +impl fmt::Debug for PipeStream { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "PipeStream({})", self.as_raw_fd()) + } +} + +impl AsyncWrite for PipeStream { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.0).poll_write(cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.0).poll_flush(cx) + } + + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + Pin::new(&mut self.0).poll_shutdown(cx) + } +} diff --git a/src/agent/rustjail/src/process.rs b/src/agent/rustjail/src/process.rs index 8f2dd9c131..2e1d6ebbd2 100644 --- a/src/agent/rustjail/src/process.rs +++ b/src/agent/rustjail/src/process.rs @@ -6,7 +6,7 @@ use libc::pid_t; use std::fs::File; use std::os::unix::io::RawFd; -use std::sync::mpsc::Sender; +use tokio::sync::mpsc::Sender; use nix::fcntl::{fcntl, FcntlArg, OFlag}; use nix::sys::signal::{self, Signal}; @@ -18,6 +18,27 @@ use crate::reaper::Epoller; use oci::Process as OCIProcess; use slog::Logger; +use crate::pipestream::PipeStream; +use std::collections::HashMap; +use std::sync::Arc; +use tokio::io::{split, ReadHalf, 
WriteHalf};
+use tokio::sync::Mutex;
+
+/// Names the file descriptors of a `Process` that can be wrapped in an async stream.
+#[derive(Debug, PartialEq, Eq, Hash, Clone)]
+pub enum StreamType {
+    Stdin,
+    Stdout,
+    Stderr,
+    ExitPipeR,
+    TermMaster,
+    ParentStdin,
+    ParentStdout,
+    ParentStderr,
+}
+
+/// Shared, lock-protected read half of a `PipeStream`.
+type Reader = Arc<Mutex<ReadHalf<PipeStream>>>;
+/// Shared, lock-protected write half of a `PipeStream`.
+type Writer = Arc<Mutex<WriteHalf<PipeStream>>>;
+
 #[derive(Debug)] pub struct Process { pub exec_id: String, @@ -42,6 +63,9 @@ pub struct Process { pub oci: OCIProcess, pub logger: Logger, pub epoller: Option,
+
+    readers: HashMap<StreamType, Reader>,
+    writers: HashMap<StreamType, Writer>,
 } pub trait ProcessOperations { @@ -94,6 +118,8 @@ impl Process { oci: ocip.clone(), logger: logger.clone(), epoller: None, + readers: HashMap::new(), + writers: HashMap::new(), }; info!(logger, "before create console socket!"); @@ -138,8 +164,59 @@ impl Process { } Ok(()) }
+
+    /// Returns the descriptor stored for `stream_type`, if the process has one.
+    fn get_fd(&self, stream_type: &StreamType) -> Option<RawFd> {
+        match stream_type {
+            StreamType::Stdin => self.stdin,
+            StreamType::Stdout => self.stdout,
+            StreamType::Stderr => self.stderr,
+            StreamType::ExitPipeR => self.exit_pipe_r,
+            StreamType::TermMaster => self.term_master,
+            StreamType::ParentStdin => self.parent_stdin,
+            StreamType::ParentStdout => self.parent_stdout,
+            StreamType::ParentStderr => self.parent_stderr,
+        }
+    }
+
+    /// Wraps the descriptor for `stream_type` in a `PipeStream`, splits it and
+    /// caches both halves.
+    ///
+    /// Returns `None` when the process has no descriptor of that type.
+    fn get_stream_and_store(&mut self, stream_type: StreamType) -> Option<(Reader, Writer)> {
+        let fd = self.get_fd(&stream_type)?;
+        let (read_half, write_half) = split(PipeStream::from_fd(fd));
+
+        let reader = Arc::new(Mutex::new(read_half));
+        let writer = Arc::new(Mutex::new(write_half));
+
+        // Both halves are always cached together, so a later lookup of either
+        // one reuses this stream instead of wrapping the same fd twice.
+        self.readers.insert(stream_type.clone(), Arc::clone(&reader));
+        self.writers.insert(stream_type, Arc::clone(&writer));
+
+        Some((reader, writer))
+    }
+
+    /// Returns the cached reader for `stream_type`, creating the stream on first use.
+    pub fn get_reader(&mut self, stream_type: StreamType) -> Option<Reader> {
+        match self.readers.get(&stream_type).cloned() {
+            Some(reader) => Some(reader),
+            None => self
+                .get_stream_and_store(stream_type)
+                .map(|(reader, _)| reader),
+        }
+    }
+
+    /// Returns the cached writer for `stream_type`, creating the stream on first use.
+    pub fn get_writer(&mut self, stream_type: StreamType) -> Option<Writer> {
+        match self.writers.get(&stream_type).cloned() {
+            Some(writer) => Some(writer),
+            None => self
+                .get_stream_and_store(stream_type)
+                .map(|(_, writer)| writer),
+        }
+    }
+
+    /// Drops the cached reader and writer for `stream_type`.
+    ///
+    /// NOTE(review): the descriptor stored on the process is left untouched
+    /// here; if dropping the last clone of both halves closes the fd, that
+    /// field would then be stale - confirm against `PipeStream`.
+    pub fn close_stream(&mut self, stream_type: StreamType) {
+        let _ = self.readers.remove(&stream_type);
+        let _ = self.writers.remove(&stream_type);
+    }
 } + fn create_extended_pipe(flags: OFlag, pipe_size: i32) -> Result<(RawFd, RawFd)> { let (r, w) = unistd::pipe2(flags)?; if pipe_size > 0 { diff --git a/src/agent/rustjail/src/sync.rs b/src/agent/rustjail/src/sync.rs index aee0c4cad0..d4dac2c63c 100644 --- a/src/agent/rustjail/src/sync.rs +++ b/src/agent/rustjail/src/sync.rs @@ -14,8 +14,8 @@ pub const SYNC_SUCCESS: i32 = 1; pub const SYNC_FAILED: i32 = 2; pub const SYNC_DATA: i32 = 3; -const DATA_SIZE: usize = 100; -const MSG_SIZE: usize = mem::size_of::(); +pub const DATA_SIZE: usize = 100; +pub const MSG_SIZE: usize = mem::size_of::(); #[macro_export] macro_rules! log_child { diff --git a/src/agent/rustjail/src/sync_with_async.rs b/src/agent/rustjail/src/sync_with_async.rs new file mode 100644 index 0000000000..7f44ab2819 --- /dev/null +++ b/src/agent/rustjail/src/sync_with_async.rs @@ -0,0 +1,148 @@ +// Copyright (c) 2020 Ant Group +// +// SPDX-License-Identifier: Apache-2.0 +// + +//! 
The async version of the sync module, used for IPC
+
+use crate::pipestream::PipeStream;
+use anyhow::{anyhow, Result};
+use nix::errno::Errno;
+use tokio::io::{AsyncReadExt, AsyncWriteExt};
+
+use crate::sync::{DATA_SIZE, MSG_SIZE, SYNC_DATA, SYNC_FAILED, SYNC_SUCCESS};
+
+/// Returns `true` when `err` is `EINTR`, i.e. the call should simply be retried.
+fn is_interrupted(err: &std::io::Error) -> bool {
+    err.raw_os_error() == Some(Errno::EINTR as i32)
+}
+
+/// Writes exactly the first `count` bytes of `buf` to `pipe_w`.
+///
+/// Returns `count` on success. Fails if `count` exceeds `buf.len()`, if the
+/// pipe stops accepting data, or on any I/O error other than `EINTR`.
+async fn write_count(pipe_w: &mut PipeStream, buf: &[u8], count: usize) -> Result<usize> {
+    if count > buf.len() {
+        return Err(anyhow!(
+            "cannot write {} bytes from a buffer of {} bytes",
+            count,
+            buf.len()
+        ));
+    }
+
+    let mut len = 0;
+    while len < count {
+        // Never hand out more than `count` bytes, even if `buf` is longer.
+        match pipe_w.write(&buf[len..count]).await {
+            // A zero-length write would otherwise make this loop spin forever.
+            Ok(0) => {
+                return Err(anyhow!(
+                    "short write to pipe: wrote {} of {} bytes",
+                    len,
+                    count
+                ));
+            }
+            Ok(l) => len += l,
+            Err(e) if is_interrupted(&e) => continue,
+            Err(e) => return Err(e.into()),
+        }
+    }
+
+    Ok(len)
+}
+
+/// Reads up to `count` bytes from `pipe_r`.
+///
+/// The returned vector is shorter than `count` only when end-of-file was hit.
+async fn read_count(pipe_r: &mut PipeStream, count: usize) -> Result<Vec<u8>> {
+    let mut v: Vec<u8> = vec![0; count];
+    let mut len = 0;
+
+    while len < count {
+        match pipe_r.read(&mut v[len..]).await {
+            // End of file: return what has been received so far.
+            Ok(0) => break,
+            Ok(l) => len += l,
+            Err(e) if is_interrupted(&e) => continue,
+            Err(e) => return Err(e.into()),
+        }
+    }
+
+    v.truncate(len);
+    Ok(v)
+}
+
+/// Reads one big-endian `i32` (`MSG_SIZE` bytes) from `pipe_r`.
+async fn read_i32(pipe_r: &mut PipeStream) -> Result<i32> {
+    let buf = read_count(pipe_r, MSG_SIZE).await?;
+    if buf.len() != MSG_SIZE {
+        return Err(anyhow!(
+            "process: {} failed to receive async message from peer: got msg length: {}, expected: {}",
+            std::process::id(),
+            buf.len(),
+            MSG_SIZE
+        ));
+    }
+
+    let mut raw = [0u8; MSG_SIZE];
+    raw.copy_from_slice(&buf);
+    Ok(i32::from_be_bytes(raw))
+}
+
+/// Receives one sync message from the peer.
+///
+/// * `SYNC_SUCCESS` yields an empty vector.
+/// * `SYNC_DATA` yields the payload that follows its big-endian length prefix.
+/// * `SYNC_FAILED` yields an error carrying the text sent by the peer.
+///
+/// Any other message type, a truncated header or a negative length is an error.
+pub async fn read_async(pipe_r: &mut PipeStream) -> Result<Vec<u8>> {
+    let msg = read_i32(pipe_r).await?;
+    match msg {
+        SYNC_SUCCESS => Ok(Vec::new()),
+        SYNC_DATA => {
+            let msg_length = read_i32(pipe_r).await?;
+            // A negative length would turn into a huge allocation when cast.
+            if msg_length < 0 {
+                return Err(anyhow!(
+                    "invalid data length in sync message: {}",
+                    msg_length
+                ));
+            }
+
+            read_count(pipe_r, msg_length as usize).await
+        }
+        SYNC_FAILED => {
+            // The error text has no length prefix: read fixed-size chunks
+            // until a short read marks the end of the message.
+            let mut error_buf = Vec::new();
+            loop {
+                let chunk = read_count(pipe_r, DATA_SIZE).await?;
+                error_buf.extend_from_slice(&chunk);
+                if chunk.len() < DATA_SIZE {
+                    break;
+                }
+            }
+
+            let error_str = String::from_utf8(error_buf).map_err(|e| {
+                anyhow!(e).context("receive error message from child process failed")
+            })?;
+
+            Err(anyhow!(error_str))
+        }
+        _ => Err(anyhow!("error in receive sync message")),
+    }
+}
+
+/// Closes `pipe_w` after a failed send and adds context to the original error.
+fn abort_send(pipe_w: &mut PipeStream, err: anyhow::Error) -> anyhow::Error {
+    // Best effort: the write error is the root cause, so keep it even if
+    // closing the pipe fails as well.
+    let _ = pipe_w.shutdown();
+    err.context("error in send message to process")
+}
+
+/// Sends one sync message of type `msg_type` to the peer.
+///
+/// * `SYNC_FAILED` sends `data_str` as the error text and then closes the pipe.
+/// * `SYNC_DATA` sends `data_str` preceded by its big-endian `i32` length.
+/// * Every other type sends only the message header.
+///
+/// When sending a payload fails, the pipe is closed before the error is returned.
+pub async fn write_async(pipe_w: &mut PipeStream, msg_type: i32, data_str: &str) -> Result<()> {
+    let payload = data_str.as_bytes();
+
+    // Reject oversized payloads before anything is sent, so the peer is not
+    // left waiting with a truncated length prefix.
+    if msg_type == SYNC_DATA && payload.len() > std::i32::MAX as usize {
+        return Err(anyhow!(
+            "sync data too large to send: {} bytes",
+            payload.len()
+        ));
+    }
+
+    write_count(pipe_w, &msg_type.to_be_bytes(), MSG_SIZE)
+        .await
+        .map_err(|e| e.context("error in send sync message"))?;
+
+    match msg_type {
+        SYNC_FAILED => match write_count(pipe_w, payload, payload.len()).await {
+            Ok(_) => pipe_w.shutdown()?,
+            Err(e) => return Err(abort_send(pipe_w, e)),
+        },
+        SYNC_DATA => {
+            let length = payload.len() as i32;
+            if let Err(e) = write_count(pipe_w, &length.to_be_bytes(), MSG_SIZE).await {
+                return Err(abort_send(pipe_w, e));
+            }
+
+            if let Err(e) = write_count(pipe_w, payload, payload.len()).await {
+                return Err(abort_send(pipe_w, e));
+            }
+        }
+
+        _ => (),
+    };
+
+    Ok(())
+}
diff --git a/src/agent/src/device.rs b/src/agent/src/device.rs index 6a1265d947..dc1a47aef6 100644 --- a/src/agent/src/device.rs +++ b/src/agent/src/device.rs @@ -9,7 +9,8 @@ use std::collections::HashMap; use std::fs; use std::os::unix::fs::MetadataExt; use std::path::Path; -use std::sync::{mpsc, Arc, Mutex}; +use std::sync::Arc; +use tokio::sync::Mutex; use crate::linux_abi::*; use crate::mount::{DRIVERBLKTYPE, DRIVERMMIOBLKTYPE, DRIVERNVDIMMTYPE, DRIVERSCSITYPE}; @@ -35,22 +36,6 @@ struct DevIndexEntry { struct DevIndex(HashMap); -// DeviceHandler is the type of callback to be defined to handle every type of device driver. 
-type DeviceHandler = fn(&Device, &mut Spec, &Arc>, &DevIndex) -> Result<()>; - -// DEVICEHANDLERLIST lists the supported drivers. -#[rustfmt::skip] -lazy_static! { - static ref DEVICEHANDLERLIST: HashMap<&'static str, DeviceHandler> = { - let mut m: HashMap<&'static str, DeviceHandler> = HashMap::new(); - m.insert(DRIVERBLKTYPE, virtio_blk_device_handler); - m.insert(DRIVERMMIOBLKTYPE, virtiommio_blk_device_handler); - m.insert(DRIVERNVDIMMTYPE, virtio_nvdimm_device_handler); - m.insert(DRIVERSCSITYPE, virtio_scsi_device_handler); - m - }; -} - pub fn rescan_pci_bus() -> Result<()> { online_device(SYSFS_PCI_BUS_RESCAN_FILE) } @@ -114,10 +99,10 @@ fn get_pci_device_address(pci_id: &str) -> Result { Ok(bridge_device_pci_addr) } -fn get_device_name(sandbox: &Arc>, dev_addr: &str) -> Result { +async fn get_device_name(sandbox: &Arc>, dev_addr: &str) -> Result { // Keep the same lock order as uevent::handle_block_add_event(), otherwise it may cause deadlock. - let mut w = GLOBAL_DEVICE_WATCHER.lock().unwrap(); - let sb = sandbox.lock().unwrap(); + let mut w = GLOBAL_DEVICE_WATCHER.lock().await; + let sb = sandbox.lock().await; for (key, value) in sb.pci_device_map.iter() { if key.contains(dev_addr) { info!(sl!(), "Device {} found in pci device map", dev_addr); @@ -131,36 +116,50 @@ fn get_device_name(sandbox: &Arc>, dev_addr: &str) -> Result(); - w.insert(dev_addr.to_string(), tx); + let (tx, rx) = tokio::sync::oneshot::channel::(); + w.insert(dev_addr.to_string(), Some(tx)); drop(w); info!(sl!(), "Waiting on channel for device notification\n"); - let hotplug_timeout = AGENT_CONFIG.read().unwrap().hotplug_timeout; - let dev_name = rx.recv_timeout(hotplug_timeout).map_err(|_| { - GLOBAL_DEVICE_WATCHER.lock().unwrap().remove_entry(dev_addr); - anyhow!( - "Timeout reached after {:?} waiting for device {}", - hotplug_timeout, - dev_addr - ) - })?; + let hotplug_timeout = AGENT_CONFIG.read().await.hotplug_timeout; + let timeout = tokio::time::delay_for(hotplug_timeout); + + 
let dev_name; + tokio::select! { + v = rx => { + dev_name = v?; + } + _ = timeout => { + let watcher = GLOBAL_DEVICE_WATCHER.clone(); + let mut w = watcher.lock().await; + w.remove_entry(dev_addr); + + return Err(anyhow!( + "Timeout reached after {:?} waiting for device {}", + hotplug_timeout, + dev_addr + )); + } + }; Ok(format!("{}/{}", SYSTEM_DEV_PATH, &dev_name)) } -pub fn get_scsi_device_name(sandbox: &Arc>, scsi_addr: &str) -> Result { +pub async fn get_scsi_device_name( + sandbox: &Arc>, + scsi_addr: &str, +) -> Result { let dev_sub_path = format!("{}{}/{}", SCSI_HOST_CHANNEL, scsi_addr, SCSI_BLOCK_SUFFIX); scan_scsi_bus(scsi_addr)?; - get_device_name(sandbox, &dev_sub_path) + get_device_name(sandbox, &dev_sub_path).await } -pub fn get_pci_device_name(sandbox: &Arc>, pci_id: &str) -> Result { +pub async fn get_pci_device_name(sandbox: &Arc>, pci_id: &str) -> Result { let pci_addr = get_pci_device_address(pci_id)?; rescan_pci_bus()?; - get_device_name(sandbox, &pci_addr) + get_device_name(sandbox, &pci_addr).await } /// Scan SCSI bus for the given SCSI address(SCSI-Id and LUN) @@ -274,7 +273,7 @@ fn update_spec_device_list(device: &Device, spec: &mut Spec, devidx: &DevIndex) // device.Id should be the predicted device name (vda, vdb, ...) // device.VmPath already provides a way to send it in -fn virtiommio_blk_device_handler( +async fn virtiommio_blk_device_handler( device: &Device, spec: &mut Spec, _sandbox: &Arc>, @@ -290,7 +289,7 @@ fn virtiommio_blk_device_handler( // device.Id should be the PCI address in the format "bridgeAddr/deviceAddr". // Here, bridgeAddr is the address at which the brige is attached on the root bus, // while deviceAddr is the address at which the device is attached on the bridge. 
-fn virtio_blk_device_handler( +async fn virtio_blk_device_handler( device: &Device, spec: &mut Spec, sandbox: &Arc>, @@ -301,25 +300,25 @@ fn virtio_blk_device_handler( // When "Id (PCIAddr)" is not set, we allow to use the predicted "VmPath" passed from kata-runtime // Note this is a special code path for cloud-hypervisor when BDF information is not available if device.id != "" { - dev.vm_path = get_pci_device_name(sandbox, &device.id)?; + dev.vm_path = get_pci_device_name(sandbox, &device.id).await?; } update_spec_device_list(&dev, spec, devidx) } // device.Id should be the SCSI address of the disk in the format "scsiID:lunID" -fn virtio_scsi_device_handler( +async fn virtio_scsi_device_handler( device: &Device, spec: &mut Spec, sandbox: &Arc>, devidx: &DevIndex, ) -> Result<()> { let mut dev = device.clone(); - dev.vm_path = get_scsi_device_name(sandbox, &device.id)?; + dev.vm_path = get_scsi_device_name(sandbox, &device.id).await?; update_spec_device_list(&dev, spec, devidx) } -fn virtio_nvdimm_device_handler( +async fn virtio_nvdimm_device_handler( device: &Device, spec: &mut Spec, _sandbox: &Arc>, @@ -357,7 +356,7 @@ impl DevIndex { } } -pub fn add_devices( +pub async fn add_devices( devices: &[Device], spec: &mut Spec, sandbox: &Arc>, @@ -365,13 +364,13 @@ pub fn add_devices( let devidx = DevIndex::new(spec); for device in devices.iter() { - add_device(device, spec, sandbox, &devidx)?; + add_device(device, spec, sandbox, &devidx).await?; } Ok(()) } -fn add_device( +async fn add_device( device: &Device, spec: &mut Spec, sandbox: &Arc>, @@ -393,9 +392,12 @@ fn add_device( return Err(anyhow!("invalid container path for device {:?}", device)); } - match DEVICEHANDLERLIST.get(device.field_type.as_str()) { - None => Err(anyhow!("Unknown device type {}", device.field_type)), - Some(dev_handler) => dev_handler(device, spec, sandbox, devidx), + match device.field_type.as_str() { + DRIVERBLKTYPE => virtio_blk_device_handler(device, spec, sandbox, devidx).await, + 
DRIVERMMIOBLKTYPE => virtiommio_blk_device_handler(device, spec, sandbox, devidx).await, + DRIVERNVDIMMTYPE => virtio_nvdimm_device_handler(device, spec, sandbox, devidx).await, + DRIVERSCSITYPE => virtio_scsi_device_handler(device, spec, sandbox, devidx).await, + _ => Err(anyhow!("Unknown device type {}", device.field_type)), } } diff --git a/src/agent/src/main.rs b/src/agent/src/main.rs index aef948769c..b75f90e07f 100644 --- a/src/agent/src/main.rs +++ b/src/agent/src/main.rs @@ -15,7 +15,6 @@ extern crate prctl; extern crate prometheus; extern crate protocols; extern crate regex; -extern crate rustjail; extern crate scan_fmt; extern crate serde_json; extern crate signal_hook; @@ -38,7 +37,6 @@ use nix::sys::socket::{self, AddressFamily, SockAddr, SockFlag, SockType}; use nix::sys::wait::{self, WaitStatus}; use nix::unistd::{self, close, dup, dup2, fork, setsid, ForkResult}; use prctl::set_child_subreaper; -use signal_hook::{iterator::Signals, SIGCHLD}; use std::collections::HashMap; use std::env; use std::ffi::{CStr, CString, OsStr}; @@ -48,9 +46,7 @@ use std::os::unix::ffi::OsStrExt; use std::os::unix::fs as unixfs; use std::os::unix::io::AsRawFd; use std::path::Path; -use std::sync::mpsc::{self, Sender}; -use std::sync::{Arc, Mutex, RwLock}; -use std::{io, thread, thread::JoinHandle}; +use std::sync::Arc; use unistd::Pid; mod config; @@ -72,6 +68,16 @@ use sandbox::Sandbox; use slog::Logger; use uevent::watch_uevents; +use std::sync::Mutex as SyncMutex; + +use futures::StreamExt as _; +use rustjail::pipestream::PipeStream; +use tokio::{ + signal::unix::{signal, SignalKind}, + sync::{oneshot::Sender, Mutex, RwLock}, +}; +use tokio_vsock::{Incoming, VsockListener, VsockStream}; + mod rpc; const NAME: &str = "kata-agent"; @@ -81,7 +87,7 @@ const CONSOLE_PATH: &str = "/dev/console"; const DEFAULT_BUF_SIZE: usize = 8 * 1024; lazy_static! 
{ - static ref GLOBAL_DEVICE_WATCHER: Arc>>> = + static ref GLOBAL_DEVICE_WATCHER: Arc>>>> = Arc::new(Mutex::new(HashMap::new())); static ref AGENT_CONFIG: Arc> = Arc::new(RwLock::new(config::agentConfig::new())); @@ -100,8 +106,28 @@ fn announce(logger: &Logger, config: &agentConfig) { ); } -#[tokio::main] -async fn main() -> Result<()> { +fn set_fd_close_exec(fd: RawFd) -> Result { + if let Err(e) = fcntl::fcntl(fd, FcntlArg::F_SETFD(FdFlag::FD_CLOEXEC)) { + return Err(anyhow!("failed to set fd: {} as close-on-exec: {}", fd, e)); + } + Ok(fd) +} + +fn get_vsock_incoming(fd: RawFd) -> Incoming { + let incoming; + unsafe { + incoming = VsockListener::from_raw_fd(fd).incoming(); + } + incoming +} + +async fn get_vsock_stream(fd: RawFd) -> Result { + let stream = get_vsock_incoming(fd).next().await.unwrap().unwrap(); + set_fd_close_exec(stream.as_raw_fd())?; + Ok(stream) +} + +fn main() -> std::result::Result<(), Box> { let args: Vec = env::args().collect(); if args.len() == 2 && args[1] == "--version" { @@ -121,107 +147,122 @@ async fn main() -> Result<()> { exit(0); } - env::set_var("RUST_BACKTRACE", "full"); + let mut rt = tokio::runtime::Builder::new() + .basic_scheduler() + .max_threads(1) + .enable_all() + .build()?; - lazy_static::initialize(&SHELLS); + rt.block_on(async { + env::set_var("RUST_BACKTRACE", "full"); - lazy_static::initialize(&AGENT_CONFIG); + lazy_static::initialize(&SHELLS); - // support vsock log - let (rfd, wfd) = unistd::pipe2(OFlag::O_CLOEXEC)?; + lazy_static::initialize(&AGENT_CONFIG); - let agentConfig = AGENT_CONFIG.clone(); + // support vsock log + let (rfd, wfd) = unistd::pipe2(OFlag::O_CLOEXEC)?; - let init_mode = unistd::getpid() == Pid::from_raw(1); - if init_mode { - // dup a new file descriptor for this temporary logger writer, - // since this logger would be dropped and it's writer would - // be closed out of this code block. 
- let newwfd = dup(wfd)?; - let writer = unsafe { File::from_raw_fd(newwfd) }; + let agentConfig = AGENT_CONFIG.clone(); - // Init a temporary logger used by init agent as init process - // since before do the base mount, it wouldn't access "/proc/cmdline" - // to get the customzied debug level. - let logger = logging::create_logger(NAME, "agent", slog::Level::Debug, writer); + let init_mode = unistd::getpid() == Pid::from_raw(1); + if init_mode { + // dup a new file descriptor for this temporary logger writer, + // since this logger would be dropped and it's writer would + // be closed out of this code block. + let newwfd = dup(wfd)?; + let writer = unsafe { File::from_raw_fd(newwfd) }; - // Must mount proc fs before parsing kernel command line - general_mount(&logger).map_err(|e| { - error!(logger, "fail general mount: {}", e); - e - })?; + // Init a temporary logger used by init agent as init process + // since before do the base mount, it wouldn't access "/proc/cmdline" + // to get the customzied debug level. + let logger = logging::create_logger(NAME, "agent", slog::Level::Debug, writer); - let mut config = agentConfig.write().unwrap(); - config.parse_cmdline(KERNEL_CMDLINE_FILE)?; + // Must mount proc fs before parsing kernel command line + general_mount(&logger).map_err(|e| { + error!(logger, "fail general mount: {}", e); + e + })?; - init_agent_as_init(&logger, config.unified_cgroup_hierarchy)?; - } else { - // once parsed cmdline and set the config, release the write lock - // as soon as possible in case other thread would get read lock on - // it. 
- let mut config = agentConfig.write().unwrap(); - config.parse_cmdline(KERNEL_CMDLINE_FILE)?; - } - let config = agentConfig.read().unwrap(); + let mut config = agentConfig.write().await; + config.parse_cmdline(KERNEL_CMDLINE_FILE)?; - let log_vport = config.log_vport as u32; - let log_handle = thread::spawn(move || -> Result<()> { - let mut reader = unsafe { File::from_raw_fd(rfd) }; - if log_vport > 0 { - let listenfd = socket::socket( - AddressFamily::Vsock, - SockType::Stream, - SockFlag::SOCK_CLOEXEC, - None, - )?; - let addr = SockAddr::new_vsock(libc::VMADDR_CID_ANY, log_vport); - socket::bind(listenfd, &addr)?; - socket::listen(listenfd, 1)?; - let datafd = socket::accept4(listenfd, SockFlag::SOCK_CLOEXEC)?; - let mut log_writer = unsafe { File::from_raw_fd(datafd) }; - let _ = io::copy(&mut reader, &mut log_writer)?; - let _ = unistd::close(listenfd); - let _ = unistd::close(datafd); + init_agent_as_init(&logger, config.unified_cgroup_hierarchy)?; + } else { + // once parsed cmdline and set the config, release the write lock + // as soon as possible in case other thread would get read lock on + // it. 
+ let mut config = agentConfig.write().await; + config.parse_cmdline(KERNEL_CMDLINE_FILE)?; } - // copy log to stdout - let mut stdout_writer = io::stdout(); - let _ = io::copy(&mut reader, &mut stdout_writer)?; + let config = agentConfig.read().await; + + let log_vport = config.log_vport as u32; + let log_handle = tokio::spawn(async move { + let mut reader = PipeStream::from_fd(rfd); + + if log_vport > 0 { + let listenfd = socket::socket( + AddressFamily::Vsock, + SockType::Stream, + SockFlag::SOCK_CLOEXEC, + None, + ) + .unwrap(); + + let addr = SockAddr::new_vsock(libc::VMADDR_CID_ANY, log_vport); + socket::bind(listenfd, &addr).unwrap(); + socket::listen(listenfd, 1).unwrap(); + + let mut vsock_stream = get_vsock_stream(listenfd).await.unwrap(); + + // copy log to stdout + tokio::io::copy(&mut reader, &mut vsock_stream) + .await + .unwrap(); + } + + // copy log to stdout + let mut stdout_writer = tokio::io::stdout(); + let _ = tokio::io::copy(&mut reader, &mut stdout_writer).await; + }); + + let writer = unsafe { File::from_raw_fd(wfd) }; + + // Recreate a logger with the log level get from "/proc/cmdline". + let logger = logging::create_logger(NAME, "agent", config.log_level, writer); + + announce(&logger, &config); + + // This "unused" variable is required as it enables the global (and crucially static) logger, + // which is required to satisfy the the lifetime constraints of the auto-generated gRPC code. + let _guard = slog_scope::set_global_logger(logger.new(o!("subsystem" => "rpc"))); + + let mut _log_guard: Result<(), log::SetLoggerError> = Ok(()); + + if config.log_level == slog::Level::Trace { + // Redirect ttrpc log calls to slog iff full debug requested + _log_guard = Ok(slog_stdlog::init().map_err(|e| e)?); + } + + start_sandbox(&logger, &config, init_mode).await?; + + let _ = log_handle.await.unwrap(); + Ok(()) - }); - - let writer = unsafe { File::from_raw_fd(wfd) }; - // Recreate a logger with the log level get from "/proc/cmdline". 
- let logger = logging::create_logger(NAME, "agent", config.log_level, writer); - - announce(&logger, &config); - - // This "unused" variable is required as it enables the global (and crucially static) logger, - // which is required to satisfy the the lifetime constraints of the auto-generated gRPC code. - let _guard = slog_scope::set_global_logger(logger.new(o!("subsystem" => "rpc"))); - - let mut _log_guard: Result<(), log::SetLoggerError> = Ok(()); - - if config.log_level == slog::Level::Trace { - // Redirect ttrpc log calls to slog iff full debug requested - _log_guard = Ok(slog_stdlog::init().map_err(|e| e)?); - } - - start_sandbox(&logger, &config, init_mode).await?; - - let _ = log_handle.join(); - - Ok(()) + }) } async fn start_sandbox(logger: &Logger, config: &agentConfig, init_mode: bool) -> Result<()> { let shells = SHELLS.clone(); let debug_console_vport = config.debug_console_vport as u32; - let mut shell_handle: Option> = None; + // TODO: async the debug console + let mut _shell_handle: Option> = None; if config.debug_console { let thread_logger = logger.clone(); - let builder = thread::Builder::new(); + let builder = std::thread::Builder::new(); let handle = builder.spawn(move || { let shells = shells.lock().unwrap(); @@ -233,7 +274,7 @@ async fn start_sandbox(logger: &Logger, config: &agentConfig, init_mode: bool) - } })?; - shell_handle = Some(handle); + _shell_handle = Some(handle); } // Initialize unique sandbox structure. 
@@ -248,41 +289,38 @@ async fn start_sandbox(logger: &Logger, config: &agentConfig, init_mode: bool) - let sandbox = Arc::new(Mutex::new(s)); - setup_signal_handler(&logger, sandbox.clone()).unwrap(); - watch_uevents(sandbox.clone()); + setup_signal_handler(&logger, sandbox.clone()) + .await + .unwrap(); + watch_uevents(sandbox.clone()).await; - let (tx, rx) = mpsc::channel::(); - sandbox.lock().unwrap().sender = Some(tx); + let (tx, rx) = tokio::sync::oneshot::channel(); + sandbox.lock().await.sender = Some(tx); // vsock:///dev/vsock, port - let mut server = rpc::start(sandbox, config.server_addr.as_str()).await?; - + let mut server = rpc::start(sandbox.clone(), config.server_addr.as_str()); server.start().await?; - let _ = rx.recv()?; - + let _ = rx.await?; server.shutdown().await?; - if let Some(handle) = shell_handle { - handle.join().map_err(|e| anyhow!("{:?}", e))?; - } - Ok(()) } use nix::sys::wait::WaitPidFlag; -fn setup_signal_handler(logger: &Logger, sandbox: Arc>) -> Result<()> { +async fn setup_signal_handler(logger: &Logger, sandbox: Arc>) -> Result<()> { let logger = logger.new(o!("subsystem" => "signals")); set_child_subreaper(true) .map_err(|err| anyhow!(err).context("failed to setup agent as a child subreaper"))?; - let signals = Signals::new(&[SIGCHLD])?; + let mut signal_stream = signal(SignalKind::child())?; - thread::spawn(move || { - 'outer: for sig in signals.forever() { - info!(logger, "received signal"; "signal" => sig); + tokio::spawn(async move { + 'outer: loop { + signal_stream.recv().await; + info!(logger, "received signal"; "signal" => "SIGCHLD"); // sevral signals can be combined together // as one. 
So loop around to reap all @@ -316,7 +354,7 @@ fn setup_signal_handler(logger: &Logger, sandbox: Arc>) -> Result let logger = logger.new(o!("child-pid" => child_pid)); - let mut sandbox = sandbox.lock().unwrap(); + let mut sandbox = sandbox.lock().await; let process = sandbox.find_process(raw_pid); if process.is_none() { info!(logger, "child exited unexpectedly"); @@ -375,7 +413,7 @@ fn init_agent_as_init(logger: &Logger, unified_cgroup_hierarchy: bool) -> Result unistd::setsid()?; unsafe { - libc::ioctl(io::stdin().as_raw_fd(), libc::TIOCSCTTY, 1); + libc::ioctl(std::io::stdin().as_raw_fd(), libc::TIOCSCTTY, 1); } env::set_var("PATH", "/bin:/sbin/:/usr/bin/:/usr/sbin/"); @@ -406,7 +444,7 @@ fn sethostname(hostname: &OsStr) -> Result<()> { } lazy_static! { - static ref SHELLS: Arc>> = { + static ref SHELLS: Arc>> = { let mut v = Vec::new(); if !cfg!(test) { @@ -414,7 +452,7 @@ lazy_static! { v.push("/bin/sh".to_string()); } - Arc::new(Mutex::new(v)) + Arc::new(SyncMutex::new(v)) }; } @@ -480,7 +518,7 @@ fn setup_debug_console(logger: &Logger, shells: Vec, port: u32) -> Resul }; } -fn io_copy(reader: &mut R, writer: &mut W) -> io::Result +fn io_copy(reader: &mut R, writer: &mut W) -> std::io::Result where R: Read, W: Write, @@ -543,10 +581,10 @@ fn run_debug_console_shell(logger: &Logger, shell: &str, socket_fd: RawFd) -> Re let debug_shell_logger = logger.clone(); // channel that used to sync between thread and main process - let (tx, rx) = mpsc::channel::(); + let (tx, rx) = std::sync::mpsc::channel::(); // start a thread to do IO copy between socket and pseduo.master - thread::spawn(move || { + std::thread::spawn(move || { let mut master_reader = unsafe { File::from_raw_fd(master_fd) }; let mut master_writer = unsafe { File::from_raw_fd(master_fd) }; let mut socket_reader = unsafe { File::from_raw_fd(socket_fd) }; diff --git a/src/agent/src/mount.rs b/src/agent/src/mount.rs index d4180ce4d2..1c2b3ea368 100644 --- a/src/agent/src/mount.rs +++ 
b/src/agent/src/mount.rs @@ -12,7 +12,8 @@ use std::os::unix::fs::PermissionsExt; use std::path::Path; use std::ptr::null; -use std::sync::{Arc, Mutex}; +use std::sync::Arc; +use tokio::sync::Mutex; use libc::{c_void, mount}; use nix::mount::{self, MsFlags}; @@ -121,32 +122,15 @@ lazy_static! { ]; } -// StorageHandler is the type of callback to be defined to handle every -// type of storage driver. -type StorageHandler = fn(&Logger, &Storage, Arc>) -> Result; - -// STORAGEHANDLERLIST lists the supported drivers. -#[rustfmt::skip] -lazy_static! { - pub static ref STORAGEHANDLERLIST: HashMap<&'static str, StorageHandler> = { - let mut m = HashMap::new(); - let blk: StorageHandler = virtio_blk_storage_handler; - m.insert(DRIVERBLKTYPE, blk); - let p9: StorageHandler= virtio9p_storage_handler; - m.insert(DRIVER9PTYPE, p9); - let virtiofs: StorageHandler = virtiofs_storage_handler; - m.insert(DRIVERVIRTIOFSTYPE, virtiofs); - let ephemeral: StorageHandler = ephemeral_storage_handler; - m.insert(DRIVEREPHEMERALTYPE, ephemeral); - let virtiommio: StorageHandler = virtiommio_blk_storage_handler; - m.insert(DRIVERMMIOBLKTYPE, virtiommio); - let local: StorageHandler = local_storage_handler; - m.insert(DRIVERLOCALTYPE, local); - let scsi: StorageHandler = virtio_scsi_storage_handler; - m.insert(DRIVERSCSITYPE, scsi); - m - }; -} +pub const STORAGE_HANDLER_LIST: [&str; 7] = [ + DRIVERBLKTYPE, + DRIVER9PTYPE, + DRIVERVIRTIOFSTYPE, + DRIVEREPHEMERALTYPE, + DRIVERMMIOBLKTYPE, + DRIVERLOCALTYPE, + DRIVERSCSITYPE, +]; #[derive(Debug, Clone)] pub struct BareMount<'a> { @@ -238,12 +222,12 @@ impl<'a> BareMount<'a> { } } -fn ephemeral_storage_handler( +async fn ephemeral_storage_handler( logger: &Logger, storage: &Storage, sandbox: Arc>, ) -> Result { - let mut sb = sandbox.lock().unwrap(); + let mut sb = sandbox.lock().await; let new_storage = sb.set_sandbox_storage(&storage.mount_point); if !new_storage { @@ -256,12 +240,12 @@ fn ephemeral_storage_handler( Ok("".to_string()) } -fn 
local_storage_handler( +async fn local_storage_handler( _logger: &Logger, storage: &Storage, sandbox: Arc>, ) -> Result { - let mut sb = sandbox.lock().unwrap(); + let mut sb = sandbox.lock().await; let new_storage = sb.set_sandbox_storage(&storage.mount_point); if !new_storage { @@ -289,7 +273,7 @@ fn local_storage_handler( Ok("".to_string()) } -fn virtio9p_storage_handler( +async fn virtio9p_storage_handler( logger: &Logger, storage: &Storage, _sandbox: Arc>, @@ -298,7 +282,7 @@ fn virtio9p_storage_handler( } // virtiommio_blk_storage_handler handles the storage for mmio blk driver. -fn virtiommio_blk_storage_handler( +async fn virtiommio_blk_storage_handler( logger: &Logger, storage: &Storage, _sandbox: Arc>, @@ -308,7 +292,7 @@ fn virtiommio_blk_storage_handler( } // virtiofs_storage_handler handles the storage for virtio-fs. -fn virtiofs_storage_handler( +async fn virtiofs_storage_handler( logger: &Logger, storage: &Storage, _sandbox: Arc>, @@ -317,7 +301,7 @@ fn virtiofs_storage_handler( } // virtio_blk_storage_handler handles the storage for blk driver. -fn virtio_blk_storage_handler( +async fn virtio_blk_storage_handler( logger: &Logger, storage: &Storage, sandbox: Arc>, @@ -334,7 +318,7 @@ fn virtio_blk_storage_handler( return Err(anyhow!("Invalid device {}", &storage.source)); } } else { - let dev_path = get_pci_device_name(&sandbox, &storage.source)?; + let dev_path = get_pci_device_name(&sandbox, &storage.source).await?; storage.source = dev_path; } @@ -342,7 +326,7 @@ fn virtio_blk_storage_handler( } // virtio_scsi_storage_handler handles the storage for scsi driver. -fn virtio_scsi_storage_handler( +async fn virtio_scsi_storage_handler( logger: &Logger, storage: &Storage, sandbox: Arc>, @@ -350,7 +334,7 @@ fn virtio_scsi_storage_handler( let mut storage = storage.clone(); // Retrieve the device path from SCSI address. 
- let dev_path = get_scsi_device_name(&sandbox, &storage.source)?; + let dev_path = get_scsi_device_name(&sandbox, &storage.source).await?; storage.source = dev_path; common_storage_handler(logger, &storage) @@ -430,7 +414,7 @@ fn parse_mount_flags_and_options(options_vec: Vec<&str>) -> (MsFlags, String) { // associated operations such as waiting for the device to show up, and mount // it to a specific location, according to the type of handler chosen, and for // each storage. -pub fn add_storages( +pub async fn add_storages( logger: Logger, storages: Vec, sandbox: Arc>, @@ -443,17 +427,30 @@ pub fn add_storages( "subsystem" => "storage", "storage-type" => handler_name.to_owned())); - let handler = STORAGEHANDLERLIST - .get(&handler_name.as_str()) - .ok_or_else(|| { - anyhow!( + let res = match handler_name.as_str() { + DRIVERBLKTYPE => virtio_blk_storage_handler(&logger, &storage, sandbox.clone()).await, + DRIVER9PTYPE => virtio9p_storage_handler(&logger, &storage, sandbox.clone()).await, + DRIVERVIRTIOFSTYPE => { + virtiofs_storage_handler(&logger, &storage, sandbox.clone()).await + } + DRIVEREPHEMERALTYPE => { + ephemeral_storage_handler(&logger, &storage, sandbox.clone()).await + } + DRIVERMMIOBLKTYPE => { + virtiommio_blk_storage_handler(&logger, &storage, sandbox.clone()).await + } + DRIVERLOCALTYPE => local_storage_handler(&logger, &storage, sandbox.clone()).await, + DRIVERSCSITYPE => virtio_scsi_storage_handler(&logger, &storage, sandbox.clone()).await, + _ => { + return Err(anyhow!( "Failed to find the storage handler {}", storage.driver.to_owned() - ) - })?; + )); + } + }; // Todo need to rollback the mounted storage if err met. 
- let mount_point = handler(&logger, &storage, sandbox.clone())?; + let mount_point = res?; if !mount_point.is_empty() { mount_list.push(mount_point); diff --git a/src/agent/src/namespace.rs b/src/agent/src/namespace.rs index 4d0d70d8d7..8053df5af3 100644 --- a/src/agent/src/namespace.rs +++ b/src/agent/src/namespace.rs @@ -11,7 +11,6 @@ use std::fmt; use std::fs; use std::fs::File; use std::path::{Path, PathBuf}; -use std::thread::{self}; use crate::mount::{BareMount, FLAGS}; use slog::Logger; @@ -76,7 +75,7 @@ impl Namespace { // setup creates persistent namespace without switching to it. // Note, pid namespaces cannot be persisted. - pub fn setup(mut self) -> Result { + pub async fn setup(mut self) -> Result { fs::create_dir_all(&self.persistent_ns_dir)?; let ns_path = PathBuf::from(&self.persistent_ns_dir); @@ -93,45 +92,51 @@ impl Namespace { self.path = new_ns_path.clone().into_os_string().into_string().unwrap(); let hostname = self.hostname.clone(); - let new_thread = thread::spawn(move || -> Result<()> { - let origin_ns_path = get_current_thread_ns_path(&ns_type.get()); + let new_thread = tokio::spawn(async move { + if let Err(err) = || -> Result<()> { + let origin_ns_path = get_current_thread_ns_path(&ns_type.get()); - File::open(Path::new(&origin_ns_path))?; + File::open(Path::new(&origin_ns_path))?; - // Create a new netns on the current thread. - let cf = ns_type.get_flags(); + // Create a new netns on the current thread. + let cf = ns_type.get_flags(); - unshare(cf)?; + unshare(cf)?; - if ns_type == NamespaceType::UTS && hostname.is_some() { - nix::unistd::sethostname(hostname.unwrap())?; + if ns_type == NamespaceType::UTS && hostname.is_some() { + nix::unistd::sethostname(hostname.unwrap())?; + } + // Bind mount the new namespace from the current thread onto the mount point to persist it. 
+ let source: &str = origin_ns_path.as_str(); + let destination: &str = new_ns_path.as_path().to_str().unwrap_or("none"); + + let mut flags = MsFlags::empty(); + + if let Some(x) = FLAGS.get("rbind") { + let (_, f) = *x; + flags |= f; + }; + + let bare_mount = BareMount::new(source, destination, "none", flags, "", &logger); + bare_mount.mount().map_err(|e| { + anyhow!( + "Failed to mount {} to {} with err:{:?}", + source, + destination, + e + ) + })?; + + Ok(()) + }() { + return Err(err); } - // Bind mount the new namespace from the current thread onto the mount point to persist it. - let source: &str = origin_ns_path.as_str(); - let destination: &str = new_ns_path.as_path().to_str().unwrap_or("none"); - - let mut flags = MsFlags::empty(); - - if let Some(x) = FLAGS.get("rbind") { - let (_, f) = *x; - flags |= f; - }; - - let bare_mount = BareMount::new(source, destination, "none", flags, "", &logger); - bare_mount.mount().map_err(|e| { - anyhow!( - "Failed to mount {} to {} with err:{:?}", - source, - destination, - e - ) - })?; Ok(()) }); new_thread - .join() + .await .map_err(|e| anyhow!("Failed to join thread {:?}!", e))??; Ok(self) diff --git a/src/agent/src/rpc.rs b/src/agent/src/rpc.rs index 50d7ac4a8a..2efd77f9f7 100644 --- a/src/agent/src/rpc.rs +++ b/src/agent/src/rpc.rs @@ -4,9 +4,12 @@ // use async_trait::async_trait; +use rustjail::{pipestream::PipeStream, process::StreamType}; +use tokio::io::{AsyncReadExt, AsyncWriteExt, ReadHalf}; +use tokio::sync::Mutex; + use std::path::Path; -use std::sync::mpsc::channel; -use std::sync::{Arc, Mutex}; +use std::sync::Arc; use ttrpc::{ self, error::get_rpc_status as ttrpc_error, @@ -29,7 +32,6 @@ use protocols::types::Interface; use rustjail::cgroups::notifier; use rustjail::container::{BaseContainer, Container, LinuxContainer}; use rustjail::process::Process; -use rustjail::reaper; use rustjail::specconv::CreateOpts; use nix::errno::Errno; @@ -42,7 +44,7 @@ use rustjail::process::ProcessOperations; use 
crate::device::{add_devices, rescan_pci_bus, update_device_cgroup}; use crate::linux_abi::*; use crate::metrics::get_metrics; -use crate::mount::{add_storages, remove_mounts, BareMount, STORAGEHANDLERLIST}; +use crate::mount::{add_storages, remove_mounts, BareMount, STORAGE_HANDLER_LIST}; use crate::namespace::{NSTYPEIPC, NSTYPEPID, NSTYPEUTS}; use crate::network::setup_guest_dns; use crate::random; @@ -54,11 +56,8 @@ use netlink::{RtnlHandle, NETLINK_ROUTE}; use libc::{self, c_ushort, pid_t, winsize, TIOCSWINSZ}; use std::convert::TryFrom; use std::fs; -use std::os::unix::io::RawFd; use std::os::unix::prelude::PermissionsExt; use std::process::{Command, Stdio}; -use std::sync::mpsc; -use std::thread; use std::time::Duration; use nix::unistd::{Gid, Uid}; @@ -83,7 +82,10 @@ pub struct agentService { } impl agentService { - fn do_create_container(&self, req: protocols::agent::CreateContainerRequest) -> Result<()> { + async fn do_create_container( + &self, + req: protocols::agent::CreateContainerRequest, + ) -> Result<()> { let cid = req.container_id.clone(); let mut oci_spec = req.OCI.clone(); @@ -111,7 +113,7 @@ impl agentService { // updates the devices listed in the OCI spec, so that they actually // match real devices inside the VM. This step is necessary since we // cannot predict everything from the caller. - add_devices(&req.devices.to_vec(), &mut oci, &self.sandbox)?; + add_devices(&req.devices.to_vec(), &mut oci, &self.sandbox).await?; // Both rootfs and volumes (invoked with --volume for instance) will // be processed the same way. The idea is to always mount any provided @@ -120,10 +122,10 @@ impl agentService { // After all those storages have been processed, no matter the order // here, the agent will rely on rustjail (using the oci.Mounts // list) to bind mount all of them inside the container. 
- let m = add_storages(sl!(), req.storages.to_vec(), self.sandbox.clone())?; + let m = add_storages(sl!(), req.storages.to_vec(), self.sandbox.clone()).await?; { sandbox = self.sandbox.clone(); - s = sandbox.lock().unwrap(); + s = sandbox.lock().await; s.container_mounts.insert(cid.clone(), m); } @@ -154,7 +156,7 @@ impl agentService { let mut ctr: LinuxContainer = LinuxContainer::new(cid.as_str(), CONTAINER_BASE, opts, &sl!())?; - let pipe_size = AGENT_CONFIG.read().unwrap().container_pipe_size; + let pipe_size = AGENT_CONFIG.read().await.container_pipe_size; let p = if oci.process.is_some() { Process::new( &sl!(), @@ -168,7 +170,7 @@ impl agentService { return Err(anyhow!(nix::Error::from_errno(nix::errno::Errno::EINVAL))); }; - ctr.start(p)?; + ctr.start(p).await?; s.update_shared_pidns(&ctr)?; s.add_container(ctr); @@ -177,11 +179,11 @@ impl agentService { Ok(()) } - fn do_start_container(&self, req: protocols::agent::StartContainerRequest) -> Result<()> { + async fn do_start_container(&self, req: protocols::agent::StartContainerRequest) -> Result<()> { let cid = req.container_id; let sandbox = self.sandbox.clone(); - let mut s = sandbox.lock().unwrap(); + let mut s = sandbox.lock().await; let sid = s.id.clone(); let ctr = s @@ -194,8 +196,8 @@ impl agentService { if sid != cid && ctr.cgroup_manager.is_some() { let cg_path = ctr.cgroup_manager.as_ref().unwrap().get_cg_path("memory"); if cg_path.is_some() { - let rx = notifier::notify_oom(cid.as_str(), cg_path.unwrap())?; - s.run_oom_event_monitor(rx, cid.clone()); + let rx = notifier::notify_oom(cid.as_str(), cg_path.unwrap()).await?; + s.run_oom_event_monitor(rx, cid.clone()).await; } } @@ -206,7 +208,10 @@ impl agentService { Ok(()) } - fn do_remove_container(&self, req: protocols::agent::RemoveContainerRequest) -> Result<()> { + async fn do_remove_container( + &self, + req: protocols::agent::RemoveContainerRequest, + ) -> Result<()> { let cid = req.container_id.clone(); let mut cmounts: Vec = vec![]; @@ 
-234,12 +239,12 @@ impl agentService { if req.timeout == 0 { let s = Arc::clone(&self.sandbox); - let mut sandbox = s.lock().unwrap(); + let mut sandbox = s.lock().await; let ctr = sandbox .get_container(&cid) .ok_or_else(|| anyhow!("Invalid container id"))?; - ctr.destroy()?; + ctr.destroy().await?; remove_container_resources(&mut sandbox)?; @@ -249,43 +254,47 @@ impl agentService { // timeout != 0 let s = self.sandbox.clone(); let cid2 = cid.clone(); - let (tx, rx) = mpsc::channel(); + let (tx, rx) = tokio::sync::oneshot::channel::(); - let handle = thread::spawn(move || { - let mut sandbox = s.lock().unwrap(); - let _ctr = sandbox - .get_container(&cid2) - .ok_or_else(|| anyhow!("Invalid container id")) - .map(|ctr| { - ctr.destroy().unwrap(); - tx.send(1).unwrap(); - ctr - }); + let handle = tokio::spawn(async move { + let mut sandbox = s.lock().await; + if let Some(ctr) = sandbox.get_container(&cid2) { + ctr.destroy().await.unwrap(); + tx.send(1).unwrap(); + }; }); - rx.recv_timeout(Duration::from_secs(req.timeout as u64)) - .map_err(|_| anyhow!(nix::Error::from_errno(nix::errno::Errno::ETIME)))?; + let timeout = tokio::time::delay_for(Duration::from_secs(req.timeout.into())); - handle - .join() - .map_err(|_| anyhow!(nix::Error::from_errno(nix::errno::Errno::UnknownErrno)))?; + tokio::select! 
{ + _ = rx => {} + _ = timeout => { + return Err(anyhow!(nix::Error::from_errno(nix::errno::Errno::ETIME))); + } + }; + + if handle.await.is_err() { + return Err(anyhow!(nix::Error::from_errno( + nix::errno::Errno::UnknownErrno + ))); + } let s = self.sandbox.clone(); - let mut sandbox = s.lock().unwrap(); + let mut sandbox = s.lock().await; remove_container_resources(&mut sandbox)?; Ok(()) } - fn do_exec_process(&self, req: protocols::agent::ExecProcessRequest) -> Result<()> { + async fn do_exec_process(&self, req: protocols::agent::ExecProcessRequest) -> Result<()> { let cid = req.container_id.clone(); let exec_id = req.exec_id.clone(); info!(sl!(), "do_exec_process cid: {} eid: {}", cid, exec_id); let s = self.sandbox.clone(); - let mut sandbox = s.lock().unwrap(); + let mut sandbox = s.lock().await; let process = if req.process.is_some() { req.process.as_ref().unwrap() @@ -293,7 +302,7 @@ impl agentService { return Err(anyhow!(nix::Error::from_errno(nix::errno::Errno::EINVAL))); }; - let pipe_size = AGENT_CONFIG.read().unwrap().container_pipe_size; + let pipe_size = AGENT_CONFIG.read().await.container_pipe_size; let ocip = rustjail::process_grpc_to_oci(process); let p = Process::new(&sl!(), &ocip, exec_id.as_str(), false, pipe_size)?; @@ -301,7 +310,7 @@ impl agentService { .get_container(&cid) .ok_or_else(|| anyhow!("Invalid container id"))?; - ctr.run(p)?; + ctr.run(p).await?; // set epoller let p = find_process(&mut sandbox, cid.as_str(), exec_id.as_str(), false)?; @@ -310,11 +319,11 @@ impl agentService { Ok(()) } - fn do_signal_process(&self, req: protocols::agent::SignalProcessRequest) -> Result<()> { + async fn do_signal_process(&self, req: protocols::agent::SignalProcessRequest) -> Result<()> { let cid = req.container_id.clone(); let eid = req.exec_id.clone(); let s = self.sandbox.clone(); - let mut sandbox = s.lock().unwrap(); + let mut sandbox = s.lock().await; let mut init = false; info!( @@ -344,7 +353,7 @@ impl agentService { Ok(()) } - fn 
do_wait_process( + async fn do_wait_process( &self, req: protocols::agent::WaitProcessRequest, ) -> Result { @@ -353,9 +362,9 @@ impl agentService { let s = self.sandbox.clone(); let mut resp = WaitProcessResponse::new(); let pid: pid_t; - let mut exit_pipe_r: RawFd = -1; - let mut buf: Vec = vec![0, 1]; - let (exit_send, exit_recv) = channel(); + let stream; + + let (exit_send, mut exit_recv) = tokio::sync::mpsc::channel(100); info!( sl!(), @@ -365,24 +374,24 @@ impl agentService { ); { - let mut sandbox = s.lock().unwrap(); - + let mut sandbox = s.lock().await; let p = find_process(&mut sandbox, cid.as_str(), eid.as_str(), false)?; - if p.exit_pipe_r.is_some() { - exit_pipe_r = p.exit_pipe_r.unwrap(); - } + stream = p.get_reader(StreamType::ExitPipeR); p.exit_watchers.push(exit_send); pid = p.pid; } - if exit_pipe_r != -1 { + if stream.is_some() { info!(sl!(), "reading exit pipe"); - let _ = unistd::read(exit_pipe_r, buf.as_mut_slice()); + + let reader = stream.unwrap(); + let mut content: Vec = vec![0, 1]; + let _ = reader.lock().await.read(&mut content).await; } - let mut sandbox = s.lock().unwrap(); + let mut sandbox = s.lock().await; let ctr = sandbox .get_container(&cid) .ok_or_else(|| anyhow!("Invalid container id"))?; @@ -391,44 +400,20 @@ impl agentService { Some(p) => p, None => { // Lost race, pick up exit code from channel - resp.status = exit_recv.recv().unwrap(); + resp.status = exit_recv.recv().await.unwrap(); return Ok(resp); } }; - // need to close all fds - if p.parent_stdin.is_some() { - let _ = unistd::close(p.parent_stdin.unwrap()); - } - - if p.parent_stdout.is_some() { - let _ = unistd::close(p.parent_stdout.unwrap()); - } - - if p.parent_stderr.is_some() { - let _ = unistd::close(p.parent_stderr.unwrap()); - } - - if p.term_master.is_some() { - let _ = unistd::close(p.term_master.unwrap()); - } - - if p.exit_pipe_r.is_some() { - let _ = unistd::close(p.exit_pipe_r.unwrap()); - } - - p.close_epoller(); - - p.parent_stdin = None; - 
p.parent_stdout = None; - p.parent_stderr = None; - p.term_master = None; + // need to close all fd + // ignore errors for some fd might be closed by stream + let _ = cleanup_process(&mut p); resp.status = p.exit_code; // broadcast exit code to all parallel watchers - for s in p.exit_watchers.iter() { + for s in p.exit_watchers.iter_mut() { // Just ignore errors in case any watcher quits unexpectedly - let _ = s.send(p.exit_code); + let _ = s.send(p.exit_code).await; } ctr.processes.remove(&pid); @@ -436,48 +421,37 @@ impl agentService { Ok(resp) } - fn do_write_stream( + async fn do_write_stream( &self, req: protocols::agent::WriteStreamRequest, ) -> Result { let cid = req.container_id.clone(); let eid = req.exec_id.clone(); - let s = self.sandbox.clone(); - let mut sandbox = s.lock().unwrap(); - let p = find_process(&mut sandbox, cid.as_str(), eid.as_str(), false)?; + let writer = { + let s = self.sandbox.clone(); + let mut sandbox = s.lock().await; + let p = find_process(&mut sandbox, cid.as_str(), eid.as_str(), false)?; - // use ptmx io - let fd = if p.term_master.is_some() { - p.term_master.unwrap() - } else { - // use piped io - p.parent_stdin.unwrap() + // use ptmx io + if p.term_master.is_some() { + p.get_writer(StreamType::TermMaster) + } else { + // use piped io + p.get_writer(StreamType::ParentStdin) + } }; - let mut l = req.data.len(); - match unistd::write(fd, req.data.as_slice()) { - Ok(v) => { - if v < l { - info!(sl!(), "write {} bytes", v); - l = v; - } - } - Err(e) => match e { - nix::Error::Sys(nix::errno::Errno::EAGAIN) => l = 0, - _ => { - return Err(anyhow!(nix::Error::from_errno(nix::errno::Errno::EIO))); - } - }, - } + let writer = writer.unwrap(); + writer.lock().await.write_all(req.data.as_slice()).await?; let mut resp = WriteStreamResponse::new(); - resp.set_len(l as u32); + resp.set_len(req.data.len() as u32); Ok(resp) } - fn do_read_stream( + async fn do_read_stream( &self, req: protocols::agent::ReadStreamRequest, stdout: bool, @@ 
-485,42 +459,35 @@ impl agentService { let cid = req.container_id; let eid = req.exec_id; - let mut fd: RawFd = -1; - let mut epoller: Option = None; - { + // let mut fd: RawFd = -1; + // let mut epoller: Option = None; + + let reader = { let s = self.sandbox.clone(); - let mut sandbox = s.lock().unwrap(); + let mut sandbox = s.lock().await; let p = find_process(&mut sandbox, cid.as_str(), eid.as_str(), false)?; if p.term_master.is_some() { - fd = p.term_master.unwrap(); - epoller = p.epoller.clone(); + // epoller = p.epoller.clone(); + p.get_reader(StreamType::TermMaster) } else if stdout { if p.parent_stdout.is_some() { - fd = p.parent_stdout.unwrap(); + p.get_reader(StreamType::ParentStdout) + } else { + None } } else { - fd = p.parent_stderr.unwrap(); + p.get_reader(StreamType::ParentStderr) } - } + }; - if let Some(epoller) = epoller { - // The process's epoller's poll() will return a file descriptor of the process's - // terminal or one end of its exited pipe. If it returns its terminal, it means - // there is data needed to be read out or it has been closed; if it returns the - // process's exited pipe, it means the process has exited and there is no data - // needed to be read out in its terminal, thus following read on it will read out - // "EOF" to terminate this process's io since the other end of this pipe has been - // closed in reap(). 
- fd = epoller.poll()?; - } - - if fd == -1 { + if reader.is_none() { return Err(anyhow!(nix::Error::from_errno(nix::errno::Errno::EINVAL))); } - let vector = read_stream(fd, req.len as usize)?; + let reader = reader.unwrap(); + let vector = read_stream(reader, req.len as usize).await?; let mut resp = ReadStreamResponse::new(); resp.set_data(vector); @@ -536,7 +503,7 @@ impl protocols::agent_ttrpc::AgentService for agentService { _ctx: &TtrpcContext, req: protocols::agent::CreateContainerRequest, ) -> ttrpc::Result { - match self.do_create_container(req) { + match self.do_create_container(req).await { Err(e) => Err(ttrpc_error(ttrpc::Code::INTERNAL, e.to_string())), Ok(_) => Ok(Empty::new()), } @@ -547,7 +514,7 @@ impl protocols::agent_ttrpc::AgentService for agentService { _ctx: &TtrpcContext, req: protocols::agent::StartContainerRequest, ) -> ttrpc::Result { - match self.do_start_container(req) { + match self.do_start_container(req).await { Err(e) => Err(ttrpc_error(ttrpc::Code::INTERNAL, e.to_string())), Ok(_) => Ok(Empty::new()), } @@ -558,7 +525,7 @@ impl protocols::agent_ttrpc::AgentService for agentService { _ctx: &TtrpcContext, req: protocols::agent::RemoveContainerRequest, ) -> ttrpc::Result { - match self.do_remove_container(req) { + match self.do_remove_container(req).await { Err(e) => Err(ttrpc_error(ttrpc::Code::INTERNAL, e.to_string())), Ok(_) => Ok(Empty::new()), } @@ -569,7 +536,7 @@ impl protocols::agent_ttrpc::AgentService for agentService { _ctx: &TtrpcContext, req: protocols::agent::ExecProcessRequest, ) -> ttrpc::Result { - match self.do_exec_process(req) { + match self.do_exec_process(req).await { Err(e) => Err(ttrpc_error(ttrpc::Code::INTERNAL, e.to_string())), Ok(_) => Ok(Empty::new()), } @@ -580,7 +547,7 @@ impl protocols::agent_ttrpc::AgentService for agentService { _ctx: &TtrpcContext, req: protocols::agent::SignalProcessRequest, ) -> ttrpc::Result { - match self.do_signal_process(req) { + match self.do_signal_process(req).await { Err(e) 
=> Err(ttrpc_error(ttrpc::Code::INTERNAL, e.to_string())), Ok(_) => Ok(Empty::new()), } @@ -592,6 +559,7 @@ impl protocols::agent_ttrpc::AgentService for agentService { req: protocols::agent::WaitProcessRequest, ) -> ttrpc::Result { self.do_wait_process(req) + .await .map_err(|e| ttrpc_error(ttrpc::Code::INTERNAL, e.to_string())) } @@ -606,7 +574,7 @@ impl protocols::agent_ttrpc::AgentService for agentService { let mut resp = ListProcessesResponse::new(); let s = Arc::clone(&self.sandbox); - let mut sandbox = s.lock().unwrap(); + let mut sandbox = s.lock().await; let ctr = sandbox.get_container(&cid).ok_or_else(|| { ttrpc_error( @@ -637,10 +605,11 @@ impl protocols::agent_ttrpc::AgentService for agentService { args = vec!["-ef".to_string()]; } - let output = Command::new("ps") + let output = tokio::process::Command::new("ps") .args(args.as_slice()) .stdout(Stdio::piped()) .output() + .await .expect("ps failed"); let out: String = String::from_utf8(output.stdout).unwrap(); @@ -688,7 +657,7 @@ impl protocols::agent_ttrpc::AgentService for agentService { let res = req.resources; let s = Arc::clone(&self.sandbox); - let mut sandbox = s.lock().unwrap(); + let mut sandbox = s.lock().await; let ctr = sandbox.get_container(&cid).ok_or_else(|| { ttrpc_error( @@ -720,7 +689,7 @@ impl protocols::agent_ttrpc::AgentService for agentService { ) -> ttrpc::Result { let cid = req.container_id; let s = Arc::clone(&self.sandbox); - let mut sandbox = s.lock().unwrap(); + let mut sandbox = s.lock().await; let ctr = sandbox.get_container(&cid).ok_or_else(|| { ttrpc_error( @@ -740,7 +709,7 @@ impl protocols::agent_ttrpc::AgentService for agentService { ) -> ttrpc::Result { let cid = req.get_container_id(); let s = Arc::clone(&self.sandbox); - let mut sandbox = s.lock().unwrap(); + let mut sandbox = s.lock().await; let ctr = sandbox.get_container(&cid).ok_or_else(|| { ttrpc_error( @@ -762,7 +731,7 @@ impl protocols::agent_ttrpc::AgentService for agentService { ) -> ttrpc::Result { let cid 
= req.get_container_id(); let s = Arc::clone(&self.sandbox); - let mut sandbox = s.lock().unwrap(); + let mut sandbox = s.lock().await; let ctr = sandbox.get_container(&cid).ok_or_else(|| { ttrpc_error( @@ -783,6 +752,7 @@ impl protocols::agent_ttrpc::AgentService for agentService { req: protocols::agent::WriteStreamRequest, ) -> ttrpc::Result { self.do_write_stream(req) + .await .map_err(|e| ttrpc_error(ttrpc::Code::INTERNAL, e.to_string())) } @@ -792,6 +762,7 @@ impl protocols::agent_ttrpc::AgentService for agentService { req: protocols::agent::ReadStreamRequest, ) -> ttrpc::Result { self.do_read_stream(req, true) + .await .map_err(|e| ttrpc_error(ttrpc::Code::INTERNAL, e.to_string())) } @@ -801,6 +772,7 @@ impl protocols::agent_ttrpc::AgentService for agentService { req: protocols::agent::ReadStreamRequest, ) -> ttrpc::Result { self.do_read_stream(req, false) + .await .map_err(|e| ttrpc_error(ttrpc::Code::INTERNAL, e.to_string())) } @@ -812,7 +784,7 @@ impl protocols::agent_ttrpc::AgentService for agentService { let cid = req.container_id.clone(); let eid = req.exec_id; let s = Arc::clone(&self.sandbox); - let mut sandbox = s.lock().unwrap(); + let mut sandbox = s.lock().await; let p = find_process(&mut sandbox, cid.as_str(), eid.as_str(), false).map_err(|e| { ttrpc_error( @@ -822,11 +794,13 @@ impl protocols::agent_ttrpc::AgentService for agentService { })?; if p.term_master.is_some() { + p.close_stream(StreamType::TermMaster); let _ = unistd::close(p.term_master.unwrap()); p.term_master = None; } if p.parent_stdin.is_some() { + p.close_stream(StreamType::ParentStdin); let _ = unistd::close(p.parent_stdin.unwrap()); p.parent_stdin = None; } @@ -844,7 +818,7 @@ impl protocols::agent_ttrpc::AgentService for agentService { let cid = req.container_id.clone(); let eid = req.exec_id.clone(); let s = Arc::clone(&self.sandbox); - let mut sandbox = s.lock().unwrap(); + let mut sandbox = s.lock().await; let p = find_process(&mut sandbox, cid.as_str(), eid.as_str(), 
false).map_err(|e| { ttrpc_error( ttrpc::Code::UNAVAILABLE, @@ -888,7 +862,7 @@ impl protocols::agent_ttrpc::AgentService for agentService { let interface = req.interface; let s = Arc::clone(&self.sandbox); - let mut sandbox = s.lock().unwrap(); + let mut sandbox = s.lock().await; if sandbox.rtnl.is_none() { sandbox.rtnl = Some(RtnlHandle::new(NETLINK_ROUTE, 0).unwrap()); @@ -921,7 +895,7 @@ impl protocols::agent_ttrpc::AgentService for agentService { let rs = req.routes.unwrap().Routes.into_vec(); let s = Arc::clone(&self.sandbox); - let mut sandbox = s.lock().unwrap(); + let mut sandbox = s.lock().await; if sandbox.rtnl.is_none() { sandbox.rtnl = Some(RtnlHandle::new(NETLINK_ROUTE, 0).unwrap()); @@ -951,7 +925,7 @@ impl protocols::agent_ttrpc::AgentService for agentService { ) -> ttrpc::Result { let mut interface = protocols::agent::Interfaces::new(); let s = Arc::clone(&self.sandbox); - let mut sandbox = s.lock().unwrap(); + let mut sandbox = s.lock().await; if sandbox.rtnl.is_none() { sandbox.rtnl = Some(RtnlHandle::new(NETLINK_ROUTE, 0).unwrap()); @@ -974,7 +948,7 @@ impl protocols::agent_ttrpc::AgentService for agentService { ) -> ttrpc::Result { let mut routes = protocols::agent::Routes::new(); let s = Arc::clone(&self.sandbox); - let mut sandbox = s.lock().unwrap(); + let mut sandbox = s.lock().await; if sandbox.rtnl.is_none() { sandbox.rtnl = Some(RtnlHandle::new(NETLINK_ROUTE, 0).unwrap()); @@ -1015,7 +989,7 @@ impl protocols::agent_ttrpc::AgentService for agentService { ) -> ttrpc::Result { { let sandbox = self.sandbox.clone(); - let mut s = sandbox.lock().unwrap(); + let mut s = sandbox.lock().await; let _ = fs::remove_dir_all(CONTAINER_BASE); let _ = fs::create_dir_all(CONTAINER_BASE); @@ -1042,13 +1016,14 @@ impl protocols::agent_ttrpc::AgentService for agentService { } s.setup_shared_namespaces() + .await .map_err(|e| ttrpc_error(ttrpc::Code::INTERNAL, e.to_string()))?; } - match add_storages(sl!(), req.storages.to_vec(), self.sandbox.clone()) { + 
match add_storages(sl!(), req.storages.to_vec(), self.sandbox.clone()).await { Ok(m) => { let sandbox = self.sandbox.clone(); - let mut s = sandbox.lock().unwrap(); + let mut s = sandbox.lock().await; s.mounts = m } Err(e) => return Err(ttrpc_error(ttrpc::Code::INTERNAL, e.to_string())), @@ -1057,7 +1032,7 @@ impl protocols::agent_ttrpc::AgentService for agentService { match setup_guest_dns(sl!(), req.dns.to_vec()) { Ok(_) => { let sandbox = self.sandbox.clone(); - let mut s = sandbox.lock().unwrap(); + let mut s = sandbox.lock().await; let _dns = req .dns .to_vec() @@ -1076,13 +1051,11 @@ impl protocols::agent_ttrpc::AgentService for agentService { _req: protocols::agent::DestroySandboxRequest, ) -> ttrpc::Result { let s = Arc::clone(&self.sandbox); - let mut sandbox = s.lock().unwrap(); + let mut sandbox = s.lock().await; // destroy all containers, clean up, notify agent to exit // etc. - sandbox.destroy().unwrap(); - - sandbox.sender.as_ref().unwrap().send(1).unwrap(); - sandbox.sender = None; + sandbox.destroy().await.unwrap(); + sandbox.sender.take().unwrap().send(1).unwrap(); Ok(Empty::new()) } @@ -1102,7 +1075,7 @@ impl protocols::agent_ttrpc::AgentService for agentService { let neighs = req.neighbors.unwrap().ARPNeighbors.into_vec(); let s = Arc::clone(&self.sandbox); - let mut sandbox = s.lock().unwrap(); + let mut sandbox = s.lock().await; if sandbox.rtnl.is_none() { sandbox.rtnl = Some(RtnlHandle::new(NETLINK_ROUTE, 0).unwrap()); @@ -1122,7 +1095,7 @@ impl protocols::agent_ttrpc::AgentService for agentService { req: protocols::agent::OnlineCPUMemRequest, ) -> ttrpc::Result { let s = Arc::clone(&self.sandbox); - let sandbox = s.lock().unwrap(); + let sandbox = s.lock().await; sandbox .online_cpu_memory(&req) @@ -1221,15 +1194,15 @@ impl protocols::agent_ttrpc::AgentService for agentService { _req: protocols::agent::GetOOMEventRequest, ) -> ttrpc::Result { let sandbox = self.sandbox.clone(); - let s = sandbox.lock().unwrap(); + let s = 
sandbox.lock().await; let event_rx = &s.event_rx.clone(); - let event_rx = event_rx.lock().unwrap(); + let mut event_rx = event_rx.lock().await; drop(s); drop(sandbox); - match event_rx.recv() { - Err(err) => Err(ttrpc_error(ttrpc::Code::INTERNAL, err.to_string())), - Ok(container_id) => { + match event_rx.recv().await { + None => Err(ttrpc_error(ttrpc::Code::INTERNAL, "")), + Some(container_id) => { info!(sl!(), "get_oom_event return {}", &container_id); let mut resp = OOMEvent::new(); resp.container_id = container_id; @@ -1326,42 +1299,28 @@ fn get_agent_details() -> AgentDetails { detail.device_handlers = RepeatedField::new(); detail.storage_handlers = RepeatedField::from_vec( - STORAGEHANDLERLIST - .keys() - .cloned() - .map(|x| x.into()) + STORAGE_HANDLER_LIST + .to_vec() + .iter() + .map(|x| x.to_string()) .collect(), ); detail } -fn read_stream(fd: RawFd, l: usize) -> Result> { - let mut v: Vec = Vec::with_capacity(l); - unsafe { - v.set_len(l); +async fn read_stream(reader: Arc>>, l: usize) -> Result> { + let mut content = vec![0u8; l]; + + let mut reader = reader.lock().await; + let len = reader.read(&mut content).await?; + content.resize(len, 0); + + if len == 0 { + return Err(anyhow!("read meet eof")); } - match unistd::read(fd, v.as_mut_slice()) { - Ok(len) => { - v.resize(len, 0); - // Rust didn't return an EOF error when the reading peer point - // was closed, instead it would return a 0 reading length, please - // see https://github.com/rust-lang/rfcs/blob/master/text/0517-io-os-reform.md#errors - if len == 0 { - return Err(anyhow!("read meet eof")); - } - } - Err(e) => match e { - nix::Error::Sys(errno) => match errno { - Errno::EAGAIN => v.clear(), - _ => return Err(anyhow!(nix::Error::Sys(errno))), - }, - _ => return Err(anyhow!("read error")), - }, - } - - Ok(v) + Ok(content) } fn find_process<'a>( @@ -1384,7 +1343,7 @@ fn find_process<'a>( ctr.get_process(eid).map_err(|_| anyhow!("Invalid exec id")) } -pub async fn start(s: Arc>, server_address: 
&str) -> Result { +pub fn start(s: Arc>, server_address: &str) -> TtrpcServer { let agent_service = Box::new(agentService { sandbox: s }) as Box; @@ -1399,13 +1358,14 @@ pub async fn start(s: Arc>, server_address: &str) -> Result server_address); - Ok(server) + server } // This function updates the container namespaces configuration based on the @@ -1641,6 +1601,42 @@ fn setup_bundle(cid: &str, spec: &mut Spec) -> Result { Ok(olddir) } +fn cleanup_process(p: &mut Process) -> Result<()> { + if p.parent_stdin.is_some() { + p.close_stream(StreamType::ParentStdin); + let _ = unistd::close(p.parent_stdin.unwrap())?; + } + + if p.parent_stdout.is_some() { + p.close_stream(StreamType::ParentStdout); + let _ = unistd::close(p.parent_stdout.unwrap())?; + } + + if p.parent_stderr.is_some() { + p.close_stream(StreamType::ParentStderr); + let _ = unistd::close(p.parent_stderr.unwrap())?; + } + + if p.term_master.is_some() { + p.close_stream(StreamType::TermMaster); + let _ = unistd::close(p.term_master.unwrap())?; + } + + if p.exit_pipe_r.is_some() { + p.close_stream(StreamType::ExitPipeR); + let _ = unistd::close(p.exit_pipe_r.unwrap())?; + } + + p.close_epoller(); + + p.parent_stdin = None; + p.parent_stdout = None; + p.parent_stderr = None; + p.term_master = None; + + Ok(()) +} + fn load_kernel_module(module: &protocols::agent::KernelModule) -> Result<()> { if module.name == "" { return Err(anyhow!("Kernel module name is empty")); diff --git a/src/agent/src/sandbox.rs b/src/agent/src/sandbox.rs index 8d9f7a1b66..f041509ec4 100644 --- a/src/agent/src/sandbox.rs +++ b/src/agent/src/sandbox.rs @@ -22,9 +22,10 @@ use std::collections::HashMap; use std::fs; use std::os::unix::fs::PermissionsExt; use std::path::Path; -use std::sync::mpsc::{self, Receiver, Sender}; -use std::sync::{Arc, Mutex}; +use std::sync::Arc; use std::{thread, time}; +use tokio::sync::mpsc::{channel, Receiver, Sender}; +use tokio::sync::Mutex; #[derive(Debug)] pub struct Sandbox { @@ -42,7 +43,7 @@ pub 
struct Sandbox { pub storages: HashMap, pub running: bool, pub no_pivot_root: bool, - pub sender: Option>, + pub sender: Option>, pub rtnl: Option, pub hooks: Option, pub event_rx: Arc>>, @@ -53,7 +54,7 @@ impl Sandbox { pub fn new(logger: &Logger) -> Result { let fs_type = get_mount_fs_type("/")?; let logger = logger.new(o!("subsystem" => "sandbox")); - let (tx, rx) = mpsc::channel::(); + let (tx, rx) = channel::(100); let event_rx = Arc::new(Mutex::new(rx)); Ok(Sandbox { @@ -157,17 +158,19 @@ impl Sandbox { self.hostname = hostname; } - pub fn setup_shared_namespaces(&mut self) -> Result { + pub async fn setup_shared_namespaces(&mut self) -> Result { // Set up shared IPC namespace self.shared_ipcns = Namespace::new(&self.logger) .get_ipc() .setup() + .await .context("Failed to setup persistent IPC namespace")?; // // Set up shared UTS namespace self.shared_utsns = Namespace::new(&self.logger) .get_uts(self.hostname.as_str()) .setup() + .await .context("Failed to setup persistent UTS namespace")?; Ok(true) @@ -214,9 +217,9 @@ impl Sandbox { None } - pub fn destroy(&mut self) -> Result<()> { + pub async fn destroy(&mut self) -> Result<()> { for ctr in self.containers.values_mut() { - ctr.destroy()?; + ctr.destroy().await?; } Ok(()) } @@ -315,15 +318,17 @@ impl Sandbox { Ok(hooks) } - pub fn run_oom_event_monitor(&self, rx: Receiver, container_id: String) { - let tx = self.event_tx.clone(); + pub async fn run_oom_event_monitor(&self, mut rx: Receiver, container_id: String) { + let mut tx = self.event_tx.clone(); let logger = self.logger.clone(); - thread::spawn(move || { - for event in rx { + tokio::spawn(async move { + loop { + let event = rx.recv().await; info!(logger, "got an OOM event {:?}", event); let _ = tx .send(container_id.clone()) + .await .map_err(|e| error!(logger, "failed to send message: {:?}", e)); } }); diff --git a/src/agent/src/uevent.rs b/src/agent/src/uevent.rs index 42f1590ab1..d3620fe1db 100644 --- a/src/agent/src/uevent.rs +++ 
b/src/agent/src/uevent.rs @@ -7,10 +7,13 @@ use crate::device::online_device; use crate::linux_abi::*; use crate::sandbox::Sandbox; use crate::GLOBAL_DEVICE_WATCHER; -use netlink::{RtnlHandle, NETLINK_UEVENT}; use slog::Logger; -use std::sync::{Arc, Mutex}; -use std::thread; + +use netlink_sys::{Protocol, Socket, SocketAddr}; +use nix::errno::Errno; +use std::os::unix::io::FromRawFd; +use std::sync::Arc; +use tokio::sync::Mutex; #[derive(Debug, Default)] struct Uevent { @@ -55,12 +58,13 @@ impl Uevent { && self.devname != "" } - fn handle_block_add_event(&self, sandbox: &Arc>) { + async fn handle_block_add_event(&self, sandbox: &Arc>) { let pci_root_bus_path = create_pci_root_bus_path(); // Keep the same lock order as device::get_device_name(), otherwise it may cause deadlock. - let mut w = GLOBAL_DEVICE_WATCHER.lock().unwrap(); - let mut sb = sandbox.lock().unwrap(); + let watcher = GLOBAL_DEVICE_WATCHER.clone(); + let mut w = watcher.lock().await; + let mut sb = sandbox.lock().await; // Add the device node name to the pci device map. sb.pci_device_map @@ -70,7 +74,7 @@ impl Uevent { // Close the channel after watcher has been notified. let devpath = self.devpath.clone(); let empties: Vec<_> = w - .iter() + .iter_mut() .filter(|(dev_addr, _)| { let pci_p = format!("{}/{}", pci_root_bus_path, *dev_addr); @@ -84,6 +88,7 @@ impl Uevent { }) .map(|(k, sender)| { let devname = self.devname.clone(); + let sender = sender.take().unwrap(); let _ = sender.send(devname); k.clone() }) @@ -95,9 +100,9 @@ impl Uevent { } } - fn process(&self, logger: &Logger, sandbox: &Arc>) { + async fn process(&self, logger: &Logger, sandbox: &Arc>) { if self.is_block_add_event() { - return self.handle_block_add_event(sandbox); + return self.handle_block_add_event(sandbox).await; } else if self.action == U_EVENT_ACTION_ADD { let online_path = format!("{}/{}/online", SYSFS_DIR, &self.devpath); // It's a memory hot-add event. 
@@ -117,22 +122,37 @@ impl Uevent { } } -pub fn watch_uevents(sandbox: Arc>) { - thread::spawn(move || { - let rtnl = RtnlHandle::new(NETLINK_UEVENT, 1).unwrap(); - let logger = sandbox - .lock() - .unwrap() - .logger - .new(o!("subsystem" => "uevent")); +pub async fn watch_uevents(sandbox: Arc>) { + let sref = sandbox.clone(); + let s = sref.lock().await; + let logger = s.logger.new(o!("subsystem" => "uevent")); + + tokio::spawn(async move { + let mut socket; + unsafe { + let fd = libc::socket( + libc::AF_NETLINK, + libc::SOCK_DGRAM | libc::SOCK_CLOEXEC, + Protocol::KObjectUevent as libc::c_int, + ); + socket = Socket::from_raw_fd(fd); + } + socket.bind(&SocketAddr::new(0, 1)).unwrap(); loop { - match rtnl.recv_message() { + match socket.recv_from_full().await { Err(e) => { error!(logger, "receive uevent message failed"; "error" => format!("{}", e)) } - Ok(data) => { - let text = String::from_utf8(data); + Ok((buf, addr)) => { + if addr.port_number() != 0 { + // not our netlink message + let err_msg = format!("{:?}", nix::Error::Sys(Errno::EBADMSG)); + error!(logger, "receive uevent message failed"; "error" => err_msg); + return; + } + + let text = String::from_utf8(buf); match text { Err(e) => { error!(logger, "failed to convert bytes to text"; "error" => format!("{}", e)) @@ -140,7 +160,7 @@ pub fn watch_uevents(sandbox: Arc>) { Ok(text) => { let event = Uevent::new(&text); info!(logger, "got uevent message"; "event" => format!("{:?}", event)); - event.process(&logger, &sandbox); + event.process(&logger, &sandbox).await; } } }