agent: switch to async runtime

Fixes: #1209

Signed-off-by: Tim Zhang <tim@hyper.sh>
This commit is contained in:
Tim Zhang 2020-12-23 15:11:40 +08:00
parent 5561755e3c
commit 332fa4c65f
18 changed files with 1225 additions and 634 deletions

143
src/agent/Cargo.lock generated
View File

@ -161,9 +161,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "cgroups-rs"
version = "0.2.0"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "02274214de2526e48355facdd16c9d774bba2cf74d135ffb9876a60b4d613464"
checksum = "348eb6d8e20a9f5247209686b7d0ffc2f4df40ddcb95f9940de55a94a655b3f5"
dependencies = [
"libc",
"log",
@ -486,6 +486,28 @@ version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "644f9158b2f133fd50f5fb3242878846d9eb792e445c893805ff0e3824006e35"
[[package]]
name = "inotify"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "04c6848dfb1580647ab039713282cdd1ab2bfb47b60ecfb598e22e60e3baf3f8"
dependencies = [
"bitflags",
"futures-core",
"inotify-sys",
"libc",
"tokio 0.3.6",
]
[[package]]
name = "inotify-sys"
version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c4563555856585ab3180a5bf0b2f9f8d301a728462afffc8195b3f5394229c55"
dependencies = [
"libc",
]
[[package]]
name = "iovec"
version = "0.1.4"
@ -517,11 +539,13 @@ dependencies = [
"anyhow",
"async-trait",
"cgroups-rs",
"futures",
"lazy_static",
"libc",
"log",
"logging",
"netlink",
"netlink-sys",
"nix 0.17.0",
"oci",
"prctl",
@ -539,7 +563,8 @@ dependencies = [
"slog-scope",
"slog-stdlog",
"tempfile",
"tokio",
"tokio 0.2.24",
"tokio-vsock",
"ttrpc",
]
@ -647,12 +672,37 @@ dependencies = [
"kernel32-sys",
"libc",
"log",
"miow",
"miow 0.2.2",
"net2",
"slab",
"winapi 0.2.8",
]
[[package]]
name = "mio"
version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f33bc887064ef1fd66020c9adfc45bb9f33d75a42096c81e7c56c65b75dd1a8b"
dependencies = [
"libc",
"log",
"miow 0.3.6",
"ntapi",
"winapi 0.3.9",
]
[[package]]
name = "mio-named-pipes"
version = "0.1.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0840c1c50fd55e521b247f949c241c9997709f23bd7f023b9762cd561e935656"
dependencies = [
"log",
"mio 0.6.23",
"miow 0.3.6",
"winapi 0.3.9",
]
[[package]]
name = "mio-uds"
version = "0.6.8"
@ -661,7 +711,7 @@ checksum = "afcb699eb26d4332647cc848492bbc15eafb26f08d0304550d5aa1f612e066f0"
dependencies = [
"iovec",
"libc",
"mio",
"mio 0.6.23",
]
[[package]]
@ -676,6 +726,16 @@ dependencies = [
"ws2_32-sys",
]
[[package]]
name = "miow"
version = "0.3.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a33c1b55807fbed163481b5ba66db4b2fa6cde694a5027be10fb724206c5897"
dependencies = [
"socket2",
"winapi 0.3.9",
]
[[package]]
name = "multimap"
version = "0.4.0"
@ -705,6 +765,19 @@ dependencies = [
"slog-scope",
]
[[package]]
name = "netlink-sys"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cc9e9df13fd91bdd4b92bea93d5d2848c8035677c60fc3fee5dabddc02c3012e"
dependencies = [
"futures",
"libc",
"log",
"mio 0.6.23",
"tokio 0.2.24",
]
[[package]]
name = "nix"
version = "0.16.1"
@ -761,6 +834,15 @@ version = "2.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cf51a729ecf40266a2368ad335a5fdde43471f545a967109cd62146ecf8b66ff"
[[package]]
name = "ntapi"
version = "0.3.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3f6bb902e437b6d86e03cce10a7e2af662292c5dfef23b65899ea3ac9354ad44"
dependencies = [
"winapi 0.3.9",
]
[[package]]
name = "num-integer"
version = "0.1.43"
@ -890,6 +972,12 @@ version = "0.1.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c917123afa01924fc84bb20c4c03f004d9c38e5127e3c039bbf7f4b9c76a2f6b"
[[package]]
name = "pin-project-lite"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6b063f57ec186e6140e2b8b6921e5f1bd89c7356dda5b33acc5401203ca6131c"
[[package]]
name = "pin-utils"
version = "0.1.0"
@ -1213,12 +1301,16 @@ name = "rustjail"
version = "0.1.0"
dependencies = [
"anyhow",
"async-trait",
"caps",
"cgroups-rs",
"dirs",
"epoll",
"futures",
"inotify",
"lazy_static",
"libc",
"mio 0.6.23",
"nix 0.17.0",
"oci",
"path-absolutize",
@ -1235,6 +1327,7 @@ dependencies = [
"slog",
"slog-scope",
"tempfile",
"tokio 0.2.24",
]
[[package]]
@ -1413,6 +1506,17 @@ version = "1.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fbee7696b84bbf3d89a1c2eccff0850e3047ed46bfcd2e92c29a2d074d57e252"
[[package]]
name = "socket2"
version = "0.3.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "122e570113d28d773067fab24266b66753f6ea915758651696b6e35e49f88d6e"
dependencies = [
"cfg-if 1.0.0",
"libc",
"winapi 0.3.9",
]
[[package]]
name = "spin"
version = "0.5.2"
@ -1513,12 +1617,27 @@ dependencies = [
"lazy_static",
"libc",
"memchr",
"mio",
"mio 0.6.23",
"mio-named-pipes",
"mio-uds",
"num_cpus",
"pin-project-lite",
"pin-project-lite 0.1.11",
"signal-hook-registry",
"slab",
"tokio-macros",
"winapi 0.3.9",
]
[[package]]
name = "tokio"
version = "0.3.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "720ba21c25078711bf456d607987d95bce90f7c3bea5abe1db587862e7a1e87c"
dependencies = [
"autocfg",
"libc",
"mio 0.7.6",
"pin-project-lite 0.2.0",
]
[[package]]
@ -1542,17 +1661,17 @@ dependencies = [
"futures",
"iovec",
"libc",
"mio",
"mio 0.6.23",
"nix 0.17.0",
"tokio",
"tokio 0.2.24",
"vsock",
]
[[package]]
name = "ttrpc"
version = "0.4.13"
version = "0.4.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d6e99ffa09e7fbe514b58b01bd17d71e3ed4dd27c588afa43d41ec0b7fc90b0a"
checksum = "fc512242eee1f113eadd48087dd97cbf807ccae4820006e7a890044044399c51"
dependencies = [
"async-trait",
"byteorder",
@ -1563,7 +1682,7 @@ dependencies = [
"protobuf",
"protobuf-codegen-pure",
"thiserror",
"tokio",
"tokio 0.2.24",
"tokio-vsock",
]

View File

@ -11,7 +11,7 @@ rustjail = { path = "rustjail" }
protocols = { path = "protocols" }
netlink = { path = "netlink", features = ["with-log", "with-agent-handler"] }
lazy_static = "1.3.0"
ttrpc = { version="0.4.13", features=["async"] }
ttrpc = { version = "0.4.14", features = ["async", "protobuf-codec"], default-features = false }
protobuf = "=2.14.0"
libc = "0.2.58"
nix = "0.17.0"
@ -21,8 +21,12 @@ signal-hook = "0.1.9"
scan_fmt = "0.2.3"
scopeguard = "1.0.0"
regex = "1"
tokio = { version="0.2", features = ["macros", "rt-threaded"] }
async-trait = "0.1.42"
tokio = { version = "0.2", features = ["rt-core", "sync", "uds", "stream", "macros", "io-util", "time", "signal", "io-std", "process",] }
futures = "0.3"
netlink-sys = { version = "0.4.0", features = ["tokio_socket",]}
tokio-vsock = "0.2.2"
# slog:
# - Dynamic keys required to allow HashMap keys to be slog::Serialized.
@ -40,7 +44,7 @@ tempfile = "3.1.0"
prometheus = { version = "0.9.0", features = ["process"] }
procfs = "0.7.9"
anyhow = "1.0.32"
cgroups = { package = "cgroups-rs", version = "0.2.0" }
cgroups = { package = "cgroups-rs", version = "0.2.1" }
[workspace]
members = [

View File

@ -5,7 +5,7 @@ authors = ["The Kata Containers community <kata-dev@lists.katacontainers.io>"]
edition = "2018"
[dependencies]
ttrpc = { version="0.4.13", features=["async"] }
ttrpc = { version = "0.4.14", features = ["async"] }
async-trait = "0.1.42"
protobuf = "=2.14.0"

View File

@ -16,7 +16,7 @@ scopeguard = "1.0.0"
prctl = "1.0.0"
lazy_static = "1.3.0"
libc = "0.2.58"
protobuf = "2.8.1"
protobuf = "=2.14.0"
slog = "2.5.2"
slog-scope = "4.1.2"
scan_fmt = "0.2"
@ -24,9 +24,15 @@ regex = "1.1"
path-absolutize = "1.2.0"
dirs = "3.0.1"
anyhow = "1.0.32"
cgroups = { package = "cgroups-rs", version = "0.2.0" }
cgroups = { package = "cgroups-rs", version = "0.2.1" }
tempfile = "3.1.0"
epoll = "4.3.1"
tokio = { version = "0.2", features = ["sync", "io-util", "process", "time", "macros"] }
futures = "0.3"
async-trait = "0.1.31"
mio = "0.6"
inotify = "0.9"
[dev-dependencies]
serial_test = "0.5.0"

View File

@ -3,16 +3,18 @@
// SPDX-License-Identifier: Apache-2.0
//
use anyhow::{anyhow, Result};
use anyhow::{anyhow, Context, Result};
use eventfd::{eventfd, EfdFlags};
use nix::sys::eventfd;
use nix::sys::inotify::{AddWatchFlags, InitFlags, Inotify};
use std::fs::{self, File};
use std::io::Read;
use std::os::unix::io::{AsRawFd, FromRawFd};
use std::path::{Path, PathBuf};
use std::sync::mpsc::{self, Receiver};
use std::thread;
use crate::pipestream::PipeStream;
use futures::StreamExt as _;
use inotify::{Inotify, WatchMask};
use tokio::io::AsyncReadExt;
use tokio::sync::mpsc::{channel, Receiver};
// Convenience macro to obtain the scope logger
macro_rules! sl {
@ -21,11 +23,11 @@ macro_rules! sl {
};
}
pub fn notify_oom(cid: &str, cg_dir: String) -> Result<Receiver<String>> {
pub async fn notify_oom(cid: &str, cg_dir: String) -> Result<Receiver<String>> {
if cgroups::hierarchies::is_cgroup2_unified_mode() {
return notify_on_oom_v2(cid, cg_dir);
return notify_on_oom_v2(cid, cg_dir).await;
}
notify_on_oom(cid, cg_dir)
notify_on_oom(cid, cg_dir).await
}
// get_value_from_cgroup parse cgroup file with `Flat keyed`
@ -52,11 +54,11 @@ fn get_value_from_cgroup(path: &PathBuf, key: &str) -> Result<i64> {
// notify_on_oom returns channel on which you can expect event about OOM,
// if process died without OOM this channel will be closed.
pub fn notify_on_oom_v2(containere_id: &str, cg_dir: String) -> Result<Receiver<String>> {
register_memory_event_v2(containere_id, cg_dir, "memory.events", "cgroup.events")
pub async fn notify_on_oom_v2(containere_id: &str, cg_dir: String) -> Result<Receiver<String>> {
register_memory_event_v2(containere_id, cg_dir, "memory.events", "cgroup.events").await
}
fn register_memory_event_v2(
async fn register_memory_event_v2(
containere_id: &str,
cg_dir: String,
memory_event_name: &str,
@ -73,55 +75,55 @@ fn register_memory_event_v2(
"register_memory_event_v2 cgroup_event_control_path: {:?}", &cgroup_event_control_path
);
let fd = Inotify::init(InitFlags::empty()).unwrap();
let mut inotify = Inotify::init().context("Failed to initialize inotify")?;
// watching oom kill
let ev_fd = fd
.add_watch(&event_control_path, AddWatchFlags::IN_MODIFY)
.unwrap();
let ev_wd = inotify.add_watch(&event_control_path, WatchMask::MODIFY)?;
// Because no `unix.IN_DELETE|unix.IN_DELETE_SELF` event for cgroup file system, so watching all process exited
let cg_fd = fd
.add_watch(&cgroup_event_control_path, AddWatchFlags::IN_MODIFY)
.unwrap();
info!(sl!(), "ev_fd: {:?}", ev_fd);
info!(sl!(), "cg_fd: {:?}", cg_fd);
let cg_wd = inotify.add_watch(&cgroup_event_control_path, WatchMask::MODIFY)?;
let (sender, receiver) = mpsc::channel();
info!(sl!(), "ev_wd: {:?}", ev_wd);
info!(sl!(), "cg_wd: {:?}", cg_wd);
let (mut sender, receiver) = channel(100);
let containere_id = containere_id.to_string();
thread::spawn(move || {
loop {
let events = fd.read_events().unwrap();
tokio::spawn(async move {
let mut buffer = [0; 32];
let mut stream = inotify
.event_stream(&mut buffer)
.expect("create inotify event stream failed");
while let Some(event_or_error) = stream.next().await {
let event = event_or_error.unwrap();
info!(
sl!(),
"container[{}] get events for container: {:?}", &containere_id, &events
"container[{}] get event for container: {:?}", &containere_id, &event
);
for event in events {
if event.mask & AddWatchFlags::IN_MODIFY != AddWatchFlags::IN_MODIFY {
continue;
}
// info!("is1: {}", event.wd == wd1);
info!(sl!(), "event.wd: {:?}", event.wd);
if event.wd == ev_fd {
if event.wd == ev_wd {
let oom = get_value_from_cgroup(&event_control_path, "oom_kill");
if oom.unwrap_or(0) > 0 {
sender.send(containere_id.clone()).unwrap();
let _ = sender.send(containere_id.clone()).await.map_err(|e| {
error!(sl!(), "send containere_id failed, error: {:?}", e);
});
return;
}
} else if event.wd == cg_fd {
} else if event.wd == cg_wd {
let pids = get_value_from_cgroup(&cgroup_event_control_path, "populated");
if pids.unwrap_or(-1) == 0 {
return;
}
}
}
// When a cgroup is destroyed, an event is sent to eventfd.
// So if the control path is gone, return instead of notifying.
if !Path::new(&event_control_path).exists() {
return;
}
}
});
Ok(receiver)
@ -129,16 +131,16 @@ fn register_memory_event_v2(
// notify_on_oom returns channel on which you can expect event about OOM,
// if process died without OOM this channel will be closed.
fn notify_on_oom(cid: &str, dir: String) -> Result<Receiver<String>> {
async fn notify_on_oom(cid: &str, dir: String) -> Result<Receiver<String>> {
if dir == "" {
return Err(anyhow!("memory controller missing"));
}
register_memory_event(cid, dir, "memory.oom_control", "")
register_memory_event(cid, dir, "memory.oom_control", "").await
}
// level is one of "low", "medium", or "critical"
fn notify_memory_pressure(cid: &str, dir: String, level: &str) -> Result<Receiver<String>> {
async fn notify_memory_pressure(cid: &str, dir: String, level: &str) -> Result<Receiver<String>> {
if dir == "" {
return Err(anyhow!("memory controller missing"));
}
@ -147,10 +149,10 @@ fn notify_memory_pressure(cid: &str, dir: String, level: &str) -> Result<Receive
return Err(anyhow!("invalid pressure level {}", level));
}
register_memory_event(cid, dir, "memory.pressure_level", level)
register_memory_event(cid, dir, "memory.pressure_level", level).await
}
fn register_memory_event(
async fn register_memory_event(
cid: &str,
cg_dir: String,
event_name: &str,
@ -171,15 +173,16 @@ fn register_memory_event(
fs::write(&event_control_path, data)?;
let mut eventfd_file = unsafe { File::from_raw_fd(eventfd) };
let mut eventfd_stream = unsafe { PipeStream::from_raw_fd(eventfd) };
let (sender, receiver) = mpsc::channel();
let (sender, receiver) = tokio::sync::mpsc::channel(100);
let containere_id = cid.to_string();
thread::spawn(move || {
tokio::spawn(async move {
loop {
let mut buf = [0; 8];
match eventfd_file.read(&mut buf) {
let mut sender = sender.clone();
let mut buf = [0u8; 8];
match eventfd_stream.read(&mut buf).await {
Err(err) => {
warn!(sl!(), "failed to read from eventfd: {:?}", err);
return;
@ -198,7 +201,10 @@ fn register_memory_event(
if !Path::new(&event_control_path).exists() {
return;
}
sender.send(containere_id.clone()).unwrap();
let _ = sender.send(containere_id.clone()).await.map_err(|e| {
error!(sl!(), "send containere_id failed, error: {:?}", e);
});
}
});

View File

@ -14,7 +14,6 @@ use std::fmt::Display;
use std::fs;
use std::os::unix::io::RawFd;
use std::path::{Path, PathBuf};
use std::process::Command;
use std::time::SystemTime;
use cgroups::freezer::FreezerState;
@ -28,7 +27,6 @@ use crate::cgroups::Manager;
use crate::log_child;
use crate::process::Process;
use crate::specconv::CreateOpts;
use crate::sync::*;
use crate::{mount, validator};
use protocols::agent::StatsContainerResponse;
@ -49,12 +47,16 @@ use protobuf::SingularPtrField;
use oci::State as OCIState;
use std::collections::HashMap;
use std::io::BufRead;
use std::io::BufReader;
use std::os::unix::io::FromRawFd;
use slog::{info, o, Logger};
use crate::pipestream::PipeStream;
use crate::sync::{read_sync, write_count, write_sync, SYNC_DATA, SYNC_FAILED, SYNC_SUCCESS};
use crate::sync_with_async::{read_async, write_async};
use async_trait::async_trait;
use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt};
const STATE_FILENAME: &str = "state.json";
const EXEC_FIFO_FILENAME: &str = "exec.fifo";
const VER_MARKER: &str = "1.2.5";
@ -215,6 +217,7 @@ pub struct BaseState {
init_process_start: u64,
}
#[async_trait]
pub trait BaseContainer {
fn id(&self) -> String;
fn status(&self) -> Status;
@ -225,9 +228,9 @@ pub trait BaseContainer {
fn get_process(&mut self, eid: &str) -> Result<&mut Process>;
fn stats(&self) -> Result<StatsContainerResponse>;
fn set(&mut self, config: LinuxResources) -> Result<()>;
fn start(&mut self, p: Process) -> Result<()>;
fn run(&mut self, p: Process) -> Result<()>;
fn destroy(&mut self) -> Result<()>;
async fn start(&mut self, p: Process) -> Result<()>;
async fn run(&mut self, p: Process) -> Result<()>;
async fn destroy(&mut self) -> Result<()>;
fn signal(&self, sig: Signal, all: bool) -> Result<()>;
fn exec(&mut self) -> Result<()>;
}
@ -273,6 +276,7 @@ pub struct SyncPC {
pid: pid_t,
}
#[async_trait]
pub trait Container: BaseContainer {
fn pause(&mut self) -> Result<()>;
fn resume(&mut self) -> Result<()>;
@ -723,6 +727,7 @@ fn set_stdio_permissions(uid: libc::uid_t) -> Result<()> {
Ok(())
}
#[async_trait]
impl BaseContainer for LinuxContainer {
fn id(&self) -> String {
self.id.clone()
@ -816,7 +821,7 @@ impl BaseContainer for LinuxContainer {
Ok(())
}
fn start(&mut self, mut p: Process) -> Result<()> {
async fn start(&mut self, mut p: Process) -> Result<()> {
let logger = self.logger.new(o!("eid" => p.exec_id.clone()));
let tty = p.tty;
let fifo_file = format!("{}/{}", &self.root, EXEC_FIFO_FILENAME);
@ -854,7 +859,7 @@ impl BaseContainer for LinuxContainer {
.map_err(|e| warn!(logger, "fcntl pfd log FD_CLOEXEC {:?}", e));
let child_logger = logger.new(o!("action" => "child process log"));
let log_handler = setup_child_logger(pfd_log, child_logger)?;
let log_handler = setup_child_logger(pfd_log, child_logger);
let (prfd, cwfd) = unistd::pipe().context("failed to create pipe")?;
let (crfd, pwfd) = unistd::pipe().context("failed to create pipe")?;
@ -865,10 +870,8 @@ impl BaseContainer for LinuxContainer {
let _ = fcntl::fcntl(pwfd, FcntlArg::F_SETFD(FdFlag::FD_CLOEXEC))
.map_err(|e| warn!(logger, "fcntl pwfd FD_COLEXEC {:?}", e));
defer!({
let _ = unistd::close(prfd).map_err(|e| warn!(logger, "close prfd {:?}", e));
let _ = unistd::close(pwfd).map_err(|e| warn!(logger, "close pwfd {:?}", e));
});
let mut pipe_r = PipeStream::from_fd(prfd);
let mut pipe_w = PipeStream::from_fd(pwfd);
let child_stdin: std::process::Stdio;
let child_stdout: std::process::Stdio;
@ -928,7 +931,7 @@ impl BaseContainer for LinuxContainer {
unistd::close(cfd_log)?;
// get container process's pid
let pid_buf = read_sync(prfd)?;
let pid_buf = read_async(&mut pipe_r).await?;
let pid_str = std::str::from_utf8(&pid_buf).context("get pid string")?;
let pid = match pid_str.parse::<i32>() {
Ok(i) => i,
@ -958,9 +961,10 @@ impl BaseContainer for LinuxContainer {
&p,
self.cgroup_manager.as_ref().unwrap(),
&st,
pwfd,
prfd,
&mut pipe_w,
&mut pipe_r,
)
.await
.map_err(|e| {
error!(logger, "create container process error {:?}", e);
// kill the child process.
@ -995,15 +999,15 @@ impl BaseContainer for LinuxContainer {
info!(logger, "wait on child log handler");
let _ = log_handler
.join()
.await
.map_err(|e| warn!(logger, "joining log handler {:?}", e));
info!(logger, "create process completed");
Ok(())
}
fn run(&mut self, p: Process) -> Result<()> {
async fn run(&mut self, p: Process) -> Result<()> {
let init = p.init;
self.start(p)?;
self.start(p).await?;
if init {
self.exec()?;
@ -1013,7 +1017,7 @@ impl BaseContainer for LinuxContainer {
Ok(())
}
fn destroy(&mut self) -> Result<()> {
async fn destroy(&mut self) -> Result<()> {
let spec = self.config.spec.as_ref().unwrap();
let st = self.oci_state()?;
@ -1025,7 +1029,7 @@ impl BaseContainer for LinuxContainer {
info!(self.logger, "poststop");
let hooks = spec.hooks.as_ref().unwrap();
for h in hooks.poststop.iter() {
execute_hook(&self.logger, h, &st)?;
execute_hook(&self.logger, h, &st).await?;
}
}
@ -1177,42 +1181,38 @@ fn get_namespaces(linux: &Linux) -> Vec<LinuxNamespace> {
.collect()
}
pub fn setup_child_logger(fd: RawFd, child_logger: Logger) -> Result<std::thread::JoinHandle<()>> {
let builder = thread::Builder::new();
builder
.spawn(move || {
let log_file = unsafe { std::fs::File::from_raw_fd(fd) };
let mut reader = BufReader::new(log_file);
pub fn setup_child_logger(fd: RawFd, child_logger: Logger) -> tokio::task::JoinHandle<()> {
tokio::spawn(async move {
let log_file_stream = PipeStream::from_fd(fd);
let buf_reader_stream = tokio::io::BufReader::new(log_file_stream);
let mut lines = buf_reader_stream.lines();
loop {
let mut line = String::new();
match reader.read_line(&mut line) {
match lines.next_line().await {
Err(e) => {
info!(child_logger, "read child process log error: {:?}", e);
break;
}
Ok(count) => {
if count == 0 {
Ok(Some(line)) => {
info!(child_logger, "{}", line);
}
Ok(None) => {
info!(child_logger, "read child process log end",);
break;
}
info!(child_logger, "{}", line);
}
}
}
})
.map_err(|e| anyhow!(e).context("failed to create thread"))
}
fn join_namespaces(
async fn join_namespaces(
logger: &Logger,
spec: &Spec,
p: &Process,
cm: &FsManager,
st: &OCIState,
pwfd: RawFd,
prfd: RawFd,
pipe_w: &mut PipeStream,
pipe_r: &mut PipeStream,
) -> Result<()> {
let logger = logger.new(o!("action" => "join-namespaces"));
@ -1223,25 +1223,25 @@ fn join_namespaces(
info!(logger, "try to send spec from parent to child");
let spec_str = serde_json::to_string(spec)?;
write_sync(pwfd, SYNC_DATA, spec_str.as_str())?;
write_async(pipe_w, SYNC_DATA, spec_str.as_str()).await?;
info!(logger, "wait child received oci spec");
read_sync(prfd)?;
read_async(pipe_r).await?;
info!(logger, "send oci process from parent to child");
let process_str = serde_json::to_string(&p.oci)?;
write_sync(pwfd, SYNC_DATA, process_str.as_str())?;
write_async(pipe_w, SYNC_DATA, process_str.as_str()).await?;
info!(logger, "wait child received oci process");
read_sync(prfd)?;
read_async(pipe_r).await?;
let cm_str = serde_json::to_string(cm)?;
write_sync(pwfd, SYNC_DATA, cm_str.as_str())?;
write_async(pipe_w, SYNC_DATA, cm_str.as_str()).await?;
// wait child setup user namespace
info!(logger, "wait child setup user namespace");
read_sync(prfd)?;
read_async(pipe_r).await?;
if userns {
info!(logger, "setup uid/gid mappings");
@ -1270,11 +1270,11 @@ fn join_namespaces(
info!(logger, "notify child to continue");
// notify child to continue
write_sync(pwfd, SYNC_SUCCESS, "")?;
write_async(pipe_w, SYNC_SUCCESS, "").await?;
if p.init {
info!(logger, "notify child parent ready to run prestart hook!");
let _ = read_sync(prfd)?;
let _ = read_async(pipe_r).await?;
info!(logger, "get ready to run prestart hook!");
@ -1283,17 +1283,17 @@ fn join_namespaces(
info!(logger, "prestart hook");
let hooks = spec.hooks.as_ref().unwrap();
for h in hooks.prestart.iter() {
execute_hook(&logger, h, st)?;
execute_hook(&logger, h, st).await?;
}
}
// notify child run prestart hooks completed
info!(logger, "notify child run prestart hook completed!");
write_sync(pwfd, SYNC_SUCCESS, "")?;
write_async(pipe_w, SYNC_SUCCESS, "").await?;
info!(logger, "notify child parent ready to run poststart hook!");
// wait to run poststart hook
read_sync(prfd)?;
read_async(pipe_r).await?;
info!(logger, "get ready to run poststart hook!");
// run poststart hook
@ -1301,13 +1301,13 @@ fn join_namespaces(
info!(logger, "poststart hook");
let hooks = spec.hooks.as_ref().unwrap();
for h in hooks.poststart.iter() {
execute_hook(&logger, h, st)?;
execute_hook(&logger, h, st).await?;
}
}
}
info!(logger, "wait for child process ready to run exec");
read_sync(prfd)?;
read_async(pipe_r).await?;
Ok(())
}
@ -1509,14 +1509,11 @@ fn set_sysctls(sysctls: &HashMap<String, String>) -> Result<()> {
Ok(())
}
use std::io::Read;
use std::os::unix::process::ExitStatusExt;
use std::process::Stdio;
use std::sync::mpsc::{self, RecvTimeoutError};
use std::thread;
use std::time::Duration;
fn execute_hook(logger: &Logger, h: &Hook, st: &OCIState) -> Result<()> {
async fn execute_hook(logger: &Logger, h: &Hook, st: &OCIState) -> Result<()> {
let logger = logger.new(o!("action" => "execute-hook"));
let binary = PathBuf::from(h.path.as_str());
@ -1535,9 +1532,12 @@ fn execute_hook(logger: &Logger, h: &Hook, st: &OCIState) -> Result<()> {
let _ = unistd::close(wfd);
});
let mut pipe_r = PipeStream::from_fd(rfd);
let mut pipe_w = PipeStream::from_fd(wfd);
match unistd::fork()? {
ForkResult::Parent { child } => {
let buf = read_sync(rfd)?;
let buf = read_async(&mut pipe_r).await?;
let status = if buf.len() == 4 {
let buf_array: [u8; 4] = [buf[0], buf[1], buf[2], buf[3]];
i32::from_be_bytes(buf_array)
@ -1561,13 +1561,13 @@ fn execute_hook(logger: &Logger, h: &Hook, st: &OCIState) -> Result<()> {
}
ForkResult::Child => {
let (tx, rx) = mpsc::channel();
let (tx_logger, rx_logger) = mpsc::channel();
let (mut tx, mut rx) = tokio::sync::mpsc::channel(100);
let (tx_logger, rx_logger) = tokio::sync::oneshot::channel();
tx_logger.send(logger.clone()).unwrap();
let handle = thread::spawn(move || {
let logger = rx_logger.recv().unwrap();
let handle = tokio::spawn(async move {
let logger = rx_logger.await.unwrap();
// write oci state to child
let env: HashMap<String, String> = envs
@ -1578,7 +1578,7 @@ fn execute_hook(logger: &Logger, h: &Hook, st: &OCIState) -> Result<()> {
})
.collect();
let mut child = Command::new(path.to_str().unwrap())
let mut child = tokio::process::Command::new(path.to_str().unwrap())
.args(args.iter())
.envs(env.iter())
.stdin(Stdio::piped())
@ -1588,7 +1588,7 @@ fn execute_hook(logger: &Logger, h: &Hook, st: &OCIState) -> Result<()> {
.unwrap();
// send out our pid
tx.send(child.id() as libc::pid_t).unwrap();
tx.send(child.id() as libc::pid_t).await.unwrap();
info!(logger, "hook grand: {}", child.id());
child
@ -1596,6 +1596,7 @@ fn execute_hook(logger: &Logger, h: &Hook, st: &OCIState) -> Result<()> {
.as_mut()
.unwrap()
.write_all(state.as_bytes())
.await
.unwrap();
// read something from stdout for debug
@ -1605,9 +1606,10 @@ fn execute_hook(logger: &Logger, h: &Hook, st: &OCIState) -> Result<()> {
.as_mut()
.unwrap()
.read_to_string(&mut out)
.await
.unwrap();
info!(logger, "child stdout: {}", out.as_str());
match child.wait() {
match child.await {
Ok(exit) => {
let code: i32 = if exit.success() {
0
@ -1618,7 +1620,7 @@ fn execute_hook(logger: &Logger, h: &Hook, st: &OCIState) -> Result<()> {
}
};
tx.send(code).unwrap();
tx.send(code).await.unwrap();
}
Err(e) => {
@ -1638,29 +1640,33 @@ fn execute_hook(logger: &Logger, h: &Hook, st: &OCIState) -> Result<()> {
// -- FIXME
// just in case. Should not happen any more
tx.send(0).unwrap();
tx.send(0).await.unwrap();
}
}
});
let pid = rx.recv().unwrap();
let pid = rx.recv().await.unwrap();
info!(logger, "hook grand: {}", pid);
let status = {
if let Some(timeout) = h.timeout {
match rx.recv_timeout(Duration::from_secs(timeout as u64)) {
Ok(s) => s,
Err(e) => {
let error = if e == RecvTimeoutError::Timeout {
-libc::ETIMEDOUT
} else {
-libc::EPIPE
};
let timeout = tokio::time::delay_for(Duration::from_secs(timeout as u64));
tokio::select! {
v = rx.recv() => {
match v {
Some(s) => s,
None => {
let _ = signal::kill(Pid::from_raw(pid), Some(Signal::SIGKILL));
error
-libc::EPIPE
}
}
} else if let Ok(s) = rx.recv() {
}
_ = timeout => {
let _ = signal::kill(Pid::from_raw(pid), Some(Signal::SIGKILL));
-libc::ETIMEDOUT
}
}
} else if let Some(s) = rx.recv().await {
s
} else {
let _ = signal::kill(Pid::from_raw(pid), Some(Signal::SIGKILL));
@ -1668,12 +1674,13 @@ fn execute_hook(logger: &Logger, h: &Hook, st: &OCIState) -> Result<()> {
}
};
handle.join().unwrap();
let _ = write_sync(
wfd,
handle.await.unwrap();
let _ = write_async(
&mut pipe_w,
SYNC_DATA,
std::str::from_utf8(&status.to_be_bytes()).unwrap_or_default(),
);
)
.await;
std::process::exit(0);
}
}

View File

@ -40,10 +40,12 @@ pub mod capabilities;
pub mod cgroups;
pub mod container;
pub mod mount;
pub mod pipestream;
pub mod process;
pub mod reaper;
pub mod specconv;
pub mod sync;
pub mod sync_with_async;
pub mod validator;
// pub mod factory;

View File

@ -0,0 +1,159 @@
// Copyright (c) 2020 Ant Group
//
// SPDX-License-Identifier: Apache-2.0
//
//! Async support for pipe or something has file descriptor
use std::{
fmt, io, mem,
os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd},
pin::Pin,
task::{Context, Poll},
};
use mio::event::Evented;
use mio::unix::EventedFd;
use mio::{Poll as MioPoll, PollOpt, Ready, Token};
use nix::unistd;
use tokio::io::{AsyncRead, AsyncWrite, PollEvented};
/// Puts `fd` into non-blocking mode.
///
/// Reads the current status flags with `F_GETFL` and ORs in `O_NONBLOCK`,
/// so any other status flags already set on the descriptor are preserved —
/// a bare `F_SETFL` with only `O_NONBLOCK` would clobber them.
/// Errors from `fcntl` are deliberately ignored (best-effort), matching the
/// original behavior.
///
/// # Safety
/// `fd` must be a valid, open file descriptor.
unsafe fn set_nonblocking(fd: RawFd) {
    let flags = libc::fcntl(fd, libc::F_GETFL);
    if flags != -1 {
        libc::fcntl(fd, libc::F_SETFL, flags | libc::O_NONBLOCK);
    }
}
struct StreamFd(RawFd);
// mio readiness integration: every method simply delegates to `EventedFd`,
// which registers the raw descriptor itself with the mio poller. This is
// what lets tokio's `PollEvented` drive the fd asynchronously.
impl Evented for StreamFd {
    // Register the fd with the poller for the given interest/options.
    fn register(
        &self,
        poll: &MioPoll,
        token: Token,
        interest: Ready,
        opts: PollOpt,
    ) -> io::Result<()> {
        EventedFd(&self.0).register(poll, token, interest, opts)
    }
    // Re-register (change interest/options) for an already-registered fd.
    fn reregister(
        &self,
        poll: &MioPoll,
        token: Token,
        interest: Ready,
        opts: PollOpt,
    ) -> io::Result<()> {
        EventedFd(&self.0).reregister(poll, token, interest, opts)
    }
    // Remove the fd from the poller.
    fn deregister(&self, poll: &MioPoll) -> io::Result<()> {
        EventedFd(&self.0).deregister(poll)
    }
}
// Blocking-style read used by `PollEvented` when the fd is ready.
impl io::Read for StreamFd {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        // Translate the nix error into the equivalent std::io::Error.
        unistd::read(self.0, buf).map_err(|e| e.as_errno().unwrap().into())
    }
}
// Blocking-style write used by `PollEvented` when the fd is writable.
impl io::Write for StreamFd {
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        // Translate the nix error into the equivalent std::io::Error.
        unistd::write(self.0, buf).map_err(|e| e.as_errno().unwrap().into())
    }
    fn flush(&mut self) -> io::Result<()> {
        // Writes go straight to the fd; there is no userspace buffer to flush.
        Ok(())
    }
}
impl StreamFd {
    // Closes the underlying descriptor, converting the nix error to io::Error.
    fn close(&mut self) -> io::Result<()> {
        unistd::close(self.0).map_err(|e| e.as_errno().unwrap().into())
    }
}
impl Drop for StreamFd {
    // Best-effort close on drop; a failure here cannot be reported, so the
    // result is intentionally discarded.
    fn drop(&mut self) {
        let _ = self.close();
    }
}
/// Async wrapper for a pipe (or any pollable fd): owns the descriptor via
/// `PollEvented<StreamFd>` and exposes it through `AsyncRead`/`AsyncWrite`.
pub struct PipeStream(PollEvented<StreamFd>);
impl PipeStream {
    /// Closes the underlying fd immediately instead of waiting for drop.
    pub fn shutdown(&mut self) -> io::Result<()> {
        self.0.get_mut().close()
    }
    /// Takes ownership of `fd`. Safe-looking convenience over `from_raw_fd`;
    /// the caller must not use or close `fd` afterwards.
    pub fn from_fd(fd: RawFd) -> Self {
        unsafe { Self::from_raw_fd(fd) }
    }
}
impl AsyncRead for PipeStream {
    // Pin-projects to the inner `PollEvented`, which handles readiness
    // tracking and retries the blocking `StreamFd::read` when ready.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        Pin::new(&mut self.0).poll_read(cx, buf)
    }
}
impl AsRawFd for PipeStream {
    // Borrows the raw fd; ownership stays with the PipeStream.
    fn as_raw_fd(&self) -> RawFd {
        self.0.get_ref().0
    }
}
impl IntoRawFd for PipeStream {
    // Transfers ownership of the fd to the caller. `mem::forget` is required:
    // without it, dropping `self` would run `StreamFd`'s Drop impl and close
    // the very fd we just handed out.
    fn into_raw_fd(self) -> RawFd {
        let fd = self.0.get_ref().0;
        mem::forget(self);
        fd
    }
}
impl FromRawFd for PipeStream {
    // Safety: `fd` must be a valid, open descriptor; the PipeStream takes
    // ownership and will close it on drop.
    unsafe fn from_raw_fd(fd: RawFd) -> Self {
        // The fd must be non-blocking for PollEvented-driven async I/O.
        set_nonblocking(fd);
        // NOTE(review): unwrap panics if PollEvented registration fails —
        // presumably this is only called from within a tokio runtime; confirm.
        PipeStream(PollEvented::new(StreamFd(fd)).unwrap())
    }
}
impl fmt::Debug for PipeStream {
    // Renders as `PipeStream(<fd>)` for diagnostics.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let fd = self.as_raw_fd();
        write!(f, "PipeStream({})", fd)
    }
}
impl AsyncWrite for PipeStream {
    // All three methods pin-project to the inner `PollEvented`, which wraps
    // the blocking `StreamFd` io::Write impl with readiness polling.
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, io::Error>> {
        Pin::new(&mut self.0).poll_write(cx, buf)
    }
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        Pin::new(&mut self.0).poll_flush(cx)
    }
    fn poll_shutdown(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), io::Error>> {
        Pin::new(&mut self.0).poll_shutdown(cx)
    }
}

View File

@ -6,7 +6,7 @@
use libc::pid_t;
use std::fs::File;
use std::os::unix::io::RawFd;
use std::sync::mpsc::Sender;
use tokio::sync::mpsc::Sender;
use nix::fcntl::{fcntl, FcntlArg, OFlag};
use nix::sys::signal::{self, Signal};
@ -18,6 +18,27 @@ use crate::reaper::Epoller;
use oci::Process as OCIProcess;
use slog::Logger;
use crate::pipestream::PipeStream;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::io::{split, ReadHalf, WriteHalf};
use tokio::sync::Mutex;
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
pub enum StreamType {
Stdin,
Stdout,
Stderr,
ExitPipeR,
TermMaster,
ParentStdin,
ParentStdout,
ParentStderr,
}
type Reader = Arc<Mutex<ReadHalf<PipeStream>>>;
type Writer = Arc<Mutex<WriteHalf<PipeStream>>>;
#[derive(Debug)]
pub struct Process {
pub exec_id: String,
@ -42,6 +63,9 @@ pub struct Process {
pub oci: OCIProcess,
pub logger: Logger,
pub epoller: Option<Epoller>,
readers: HashMap<StreamType, Reader>,
writers: HashMap<StreamType, Writer>,
}
pub trait ProcessOperations {
@ -94,6 +118,8 @@ impl Process {
oci: ocip.clone(),
logger: logger.clone(),
epoller: None,
readers: HashMap::new(),
writers: HashMap::new(),
};
info!(logger, "before create console socket!");
@ -138,7 +164,58 @@ impl Process {
}
Ok(())
}
/// Returns the raw fd backing `stream_type`, or `None` if that stream
/// was never set up for this process (all of these fields are Options).
fn get_fd(&self, stream_type: &StreamType) -> Option<RawFd> {
match stream_type {
StreamType::Stdin => self.stdin,
StreamType::Stdout => self.stdout,
StreamType::Stderr => self.stderr,
StreamType::ExitPipeR => self.exit_pipe_r,
StreamType::TermMaster => self.term_master,
StreamType::ParentStdin => self.parent_stdin,
StreamType::ParentStdout => self.parent_stdout,
StreamType::ParentStderr => self.parent_stderr,
}
}
/// Builds async reader/writer halves for `stream_type` from its raw fd,
/// caches both halves in the per-process maps, and returns them.
/// Returns `None` when the process has no fd for that stream.
fn get_stream_and_store(&mut self, stream_type: StreamType) -> Option<(Reader, Writer)> {
    let fd = self.get_fd(&stream_type)?;
    let (read_half, write_half) = split(PipeStream::from_fd(fd));

    let reader: Reader = Arc::new(Mutex::new(read_half));
    let writer: Writer = Arc::new(Mutex::new(write_half));

    self.readers.insert(stream_type.clone(), Arc::clone(&reader));
    self.writers.insert(stream_type, Arc::clone(&writer));

    Some((reader, writer))
}
/// Returns the cached reader half for `stream_type`, creating and
/// caching it on first use.
pub fn get_reader(&mut self, stream_type: StreamType) -> Option<Reader> {
    if let Some(reader) = self.readers.get(&stream_type).cloned() {
        return Some(reader);
    }
    self.get_stream_and_store(stream_type).map(|(reader, _)| reader)
}
/// Returns the cached writer half for `stream_type`, creating and
/// caching it on first use.
pub fn get_writer(&mut self, stream_type: StreamType) -> Option<Writer> {
    if let Some(writer) = self.writers.get(&stream_type).cloned() {
        return Some(writer);
    }
    self.get_stream_and_store(stream_type).map(|(_, writer)| writer)
}
/// Drops this process's cached reader/writer handles for `stream_type`.
/// Other clones of the `Arc` handles may still keep the halves alive.
pub fn close_stream(&mut self, stream_type: StreamType) {
    self.readers.remove(&stream_type);
    self.writers.remove(&stream_type);
}
}
fn create_extended_pipe(flags: OFlag, pipe_size: i32) -> Result<(RawFd, RawFd)> {
let (r, w) = unistd::pipe2(flags)?;

View File

@ -14,8 +14,8 @@ pub const SYNC_SUCCESS: i32 = 1;
pub const SYNC_FAILED: i32 = 2;
pub const SYNC_DATA: i32 = 3;
const DATA_SIZE: usize = 100;
const MSG_SIZE: usize = mem::size_of::<i32>();
pub const DATA_SIZE: usize = 100;
pub const MSG_SIZE: usize = mem::size_of::<i32>();
#[macro_export]
macro_rules! log_child {

View File

@ -0,0 +1,148 @@
// Copyright (c) 2020 Ant Group
//
// SPDX-License-Identifier: Apache-2.0
//
//! The async version of sync module used for IPC
use crate::pipestream::PipeStream;
use anyhow::{anyhow, Result};
use nix::errno::Errno;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use crate::sync::{DATA_SIZE, MSG_SIZE, SYNC_DATA, SYNC_FAILED, SYNC_SUCCESS};
/// Writes up to `count` bytes of `buf` to `pipe_w`, retrying on EINTR.
///
/// Returns the number of bytes actually written; the result is shorter
/// than `count` only when the stream stops accepting data.  Callers
/// detect that via the returned length.
async fn write_count(pipe_w: &mut PipeStream, buf: &[u8], count: usize) -> Result<usize> {
    let mut len = 0;
    loop {
        match pipe_w.write(&buf[len..]).await {
            Ok(l) => {
                len += l;
                // Break on a zero-length write as well: a stream that
                // accepts no bytes would otherwise make this loop spin
                // forever (read_count already handles the analogous
                // l == 0 EOF case).
                if len == count || l == 0 {
                    break;
                }
            }
            Err(e) => {
                // Retry only on EINTR.  Comparing ErrorKind avoids the
                // panic the old raw_os_error().unwrap() caused for
                // errors that carry no raw OS code.
                if e.kind() != std::io::ErrorKind::Interrupted {
                    return Err(e.into());
                }
            }
        }
    }
    Ok(len)
}
/// Reads up to `count` bytes from `pipe_r`, retrying on EINTR and
/// stopping early on EOF.
///
/// The returned buffer may be shorter than `count` when the stream
/// ends first; callers must check the length.
async fn read_count(pipe_r: &mut PipeStream, count: usize) -> Result<Vec<u8>> {
    let mut v: Vec<u8> = vec![0; count];
    let mut len = 0;
    loop {
        match pipe_r.read(&mut v[len..]).await {
            Ok(l) => {
                len += l;
                // Done when the buffer is full or the peer hit EOF.
                if len == count || l == 0 {
                    break;
                }
            }
            Err(e) => {
                // Retry only on EINTR.  Comparing ErrorKind avoids the
                // panic the old raw_os_error().unwrap() caused for
                // errors that carry no raw OS code.
                if e.kind() != std::io::ErrorKind::Interrupted {
                    return Err(e.into());
                }
            }
        }
    }
    // Shrink in place instead of copying the prefix into a new Vec.
    v.truncate(len);
    Ok(v)
}
/// Receives one sync-protocol message from the peer over `pipe_r`.
///
/// Wire format: a big-endian i32 message type, then depending on the
/// type nothing (SYNC_SUCCESS), a big-endian i32 length followed by
/// that many payload bytes (SYNC_DATA), or an error string read in
/// DATA_SIZE chunks until a short chunk (SYNC_FAILED).
///
/// # Errors
/// Fails on short reads, a negative payload length, a non-UTF-8 error
/// payload, or an unknown message type; SYNC_FAILED is surfaced as an
/// `Err` carrying the peer's error text.
pub async fn read_async(pipe_r: &mut PipeStream) -> Result<Vec<u8>> {
    let buf = read_count(pipe_r, MSG_SIZE).await?;
    if buf.len() != MSG_SIZE {
        return Err(anyhow!(
            "process: {} failed to receive async message from peer: got msg length: {}, expected: {}",
            std::process::id(),
            buf.len(),
            MSG_SIZE
        ));
    }

    let buf_array: [u8; MSG_SIZE] = [buf[0], buf[1], buf[2], buf[3]];
    let msg: i32 = i32::from_be_bytes(buf_array);
    match msg {
        SYNC_SUCCESS => Ok(Vec::new()),
        SYNC_DATA => {
            let buf = read_count(pipe_r, MSG_SIZE).await?;
            // Guard the length header before indexing: a short read here
            // previously panicked on buf[1]..buf[3].
            if buf.len() != MSG_SIZE {
                return Err(anyhow!(
                    "process: {} failed to receive async message length from peer: got msg length: {}, expected: {}",
                    std::process::id(),
                    buf.len(),
                    MSG_SIZE
                ));
            }
            let buf_array: [u8; MSG_SIZE] = [buf[0], buf[1], buf[2], buf[3]];
            let msg_length: i32 = i32::from_be_bytes(buf_array);
            // A negative length cast `as usize` would request a huge
            // allocation; reject it explicitly.
            if msg_length < 0 {
                return Err(anyhow!(
                    "process: {} received invalid async message length {} from peer",
                    std::process::id(),
                    msg_length
                ));
            }
            let data_buf = read_count(pipe_r, msg_length as usize).await?;
            Ok(data_buf)
        }
        SYNC_FAILED => {
            // The error payload has no length header: keep reading
            // DATA_SIZE chunks until a short chunk marks the end.
            let mut error_buf = vec![];
            loop {
                let buf = read_count(pipe_r, DATA_SIZE).await?;
                error_buf.extend(&buf);
                if buf.len() < DATA_SIZE {
                    break;
                }
            }
            let error_str = match std::str::from_utf8(&error_buf) {
                Ok(v) => String::from(v),
                Err(e) => {
                    return Err(
                        anyhow!(e).context("receive error message from child process failed")
                    );
                }
            };
            Err(anyhow!(error_str))
        }
        _ => Err(anyhow!("error in receive sync message")),
    }
}
/// Sends one sync-protocol message of `msg_type` to the peer over
/// `pipe_w`, mirroring the wire format read_async expects: a big-endian
/// i32 type, then for SYNC_DATA a big-endian i32 length plus payload,
/// and for SYNC_FAILED the raw error string followed by a stream
/// shutdown (the reader detects the end of the error text via EOF).
pub async fn write_async(pipe_w: &mut PipeStream, msg_type: i32, data_str: &str) -> Result<()> {
let buf = msg_type.to_be_bytes();
let count = write_count(pipe_w, &buf, MSG_SIZE).await?;
if count != MSG_SIZE {
return Err(anyhow!("error in send sync message"));
}
match msg_type {
// NOTE(review): shutdown() is called without .await, so it appears
// to be an inherent synchronous PipeStream method rather than
// AsyncWriteExt::shutdown — confirm against pipestream.rs.
SYNC_FAILED => match write_count(pipe_w, data_str.as_bytes(), data_str.len()).await {
Ok(_) => pipe_w.shutdown()?,
Err(e) => {
// Close the stream first, then surface the write error; a
// shutdown failure takes precedence via `?`.
pipe_w.shutdown()?;
return Err(anyhow!(e).context("error in send message to process"));
}
},
SYNC_DATA => {
let length: i32 = data_str.len() as i32;
write_count(pipe_w, &length.to_be_bytes(), MSG_SIZE)
.await
.or_else(|e| {
pipe_w.shutdown()?;
Err(anyhow!(e).context("error in send message to process"))
})?;
write_count(pipe_w, data_str.as_bytes(), data_str.len())
.await
.or_else(|e| {
pipe_w.shutdown()?;
Err(anyhow!(e).context("error in send message to process"))
})?;
}
// SYNC_SUCCESS (and any other type) carries no payload.
_ => (),
};
Ok(())
}

View File

@ -9,7 +9,8 @@ use std::collections::HashMap;
use std::fs;
use std::os::unix::fs::MetadataExt;
use std::path::Path;
use std::sync::{mpsc, Arc, Mutex};
use std::sync::Arc;
use tokio::sync::Mutex;
use crate::linux_abi::*;
use crate::mount::{DRIVERBLKTYPE, DRIVERMMIOBLKTYPE, DRIVERNVDIMMTYPE, DRIVERSCSITYPE};
@ -35,22 +36,6 @@ struct DevIndexEntry {
struct DevIndex(HashMap<String, DevIndexEntry>);
// DeviceHandler is the type of callback to be defined to handle every type of device driver.
type DeviceHandler = fn(&Device, &mut Spec, &Arc<Mutex<Sandbox>>, &DevIndex) -> Result<()>;
// DEVICEHANDLERLIST lists the supported drivers.
#[rustfmt::skip]
lazy_static! {
static ref DEVICEHANDLERLIST: HashMap<&'static str, DeviceHandler> = {
let mut m: HashMap<&'static str, DeviceHandler> = HashMap::new();
m.insert(DRIVERBLKTYPE, virtio_blk_device_handler);
m.insert(DRIVERMMIOBLKTYPE, virtiommio_blk_device_handler);
m.insert(DRIVERNVDIMMTYPE, virtio_nvdimm_device_handler);
m.insert(DRIVERSCSITYPE, virtio_scsi_device_handler);
m
};
}
pub fn rescan_pci_bus() -> Result<()> {
online_device(SYSFS_PCI_BUS_RESCAN_FILE)
}
@ -114,10 +99,10 @@ fn get_pci_device_address(pci_id: &str) -> Result<String> {
Ok(bridge_device_pci_addr)
}
fn get_device_name(sandbox: &Arc<Mutex<Sandbox>>, dev_addr: &str) -> Result<String> {
async fn get_device_name(sandbox: &Arc<Mutex<Sandbox>>, dev_addr: &str) -> Result<String> {
// Keep the same lock order as uevent::handle_block_add_event(), otherwise it may cause deadlock.
let mut w = GLOBAL_DEVICE_WATCHER.lock().unwrap();
let sb = sandbox.lock().unwrap();
let mut w = GLOBAL_DEVICE_WATCHER.lock().await;
let sb = sandbox.lock().await;
for (key, value) in sb.pci_device_map.iter() {
if key.contains(dev_addr) {
info!(sl!(), "Device {} found in pci device map", dev_addr);
@ -131,36 +116,50 @@ fn get_device_name(sandbox: &Arc<Mutex<Sandbox>>, dev_addr: &str) -> Result<Stri
// The key of the watchers map is the device we are interested in.
// Note this is done inside the lock, not to miss any events from the
// global udev listener.
let (tx, rx) = mpsc::channel::<String>();
w.insert(dev_addr.to_string(), tx);
let (tx, rx) = tokio::sync::oneshot::channel::<String>();
w.insert(dev_addr.to_string(), Some(tx));
drop(w);
info!(sl!(), "Waiting on channel for device notification\n");
let hotplug_timeout = AGENT_CONFIG.read().unwrap().hotplug_timeout;
let dev_name = rx.recv_timeout(hotplug_timeout).map_err(|_| {
GLOBAL_DEVICE_WATCHER.lock().unwrap().remove_entry(dev_addr);
anyhow!(
let hotplug_timeout = AGENT_CONFIG.read().await.hotplug_timeout;
let timeout = tokio::time::delay_for(hotplug_timeout);
let dev_name;
tokio::select! {
v = rx => {
dev_name = v?;
}
_ = timeout => {
let watcher = GLOBAL_DEVICE_WATCHER.clone();
let mut w = watcher.lock().await;
w.remove_entry(dev_addr);
return Err(anyhow!(
"Timeout reached after {:?} waiting for device {}",
hotplug_timeout,
dev_addr
)
})?;
));
}
};
Ok(format!("{}/{}", SYSTEM_DEV_PATH, &dev_name))
}
pub fn get_scsi_device_name(sandbox: &Arc<Mutex<Sandbox>>, scsi_addr: &str) -> Result<String> {
pub async fn get_scsi_device_name(
sandbox: &Arc<Mutex<Sandbox>>,
scsi_addr: &str,
) -> Result<String> {
let dev_sub_path = format!("{}{}/{}", SCSI_HOST_CHANNEL, scsi_addr, SCSI_BLOCK_SUFFIX);
scan_scsi_bus(scsi_addr)?;
get_device_name(sandbox, &dev_sub_path)
get_device_name(sandbox, &dev_sub_path).await
}
pub fn get_pci_device_name(sandbox: &Arc<Mutex<Sandbox>>, pci_id: &str) -> Result<String> {
pub async fn get_pci_device_name(sandbox: &Arc<Mutex<Sandbox>>, pci_id: &str) -> Result<String> {
let pci_addr = get_pci_device_address(pci_id)?;
rescan_pci_bus()?;
get_device_name(sandbox, &pci_addr)
get_device_name(sandbox, &pci_addr).await
}
/// Scan SCSI bus for the given SCSI address(SCSI-Id and LUN)
@ -274,7 +273,7 @@ fn update_spec_device_list(device: &Device, spec: &mut Spec, devidx: &DevIndex)
// device.Id should be the predicted device name (vda, vdb, ...)
// device.VmPath already provides a way to send it in
fn virtiommio_blk_device_handler(
async fn virtiommio_blk_device_handler(
device: &Device,
spec: &mut Spec,
_sandbox: &Arc<Mutex<Sandbox>>,
@ -290,7 +289,7 @@ fn virtiommio_blk_device_handler(
// device.Id should be the PCI address in the format "bridgeAddr/deviceAddr".
// Here, bridgeAddr is the address at which the brige is attached on the root bus,
// while deviceAddr is the address at which the device is attached on the bridge.
fn virtio_blk_device_handler(
async fn virtio_blk_device_handler(
device: &Device,
spec: &mut Spec,
sandbox: &Arc<Mutex<Sandbox>>,
@ -301,25 +300,25 @@ fn virtio_blk_device_handler(
// When "Id (PCIAddr)" is not set, we allow to use the predicted "VmPath" passed from kata-runtime
// Note this is a special code path for cloud-hypervisor when BDF information is not available
if device.id != "" {
dev.vm_path = get_pci_device_name(sandbox, &device.id)?;
dev.vm_path = get_pci_device_name(sandbox, &device.id).await?;
}
update_spec_device_list(&dev, spec, devidx)
}
// device.Id should be the SCSI address of the disk in the format "scsiID:lunID"
fn virtio_scsi_device_handler(
async fn virtio_scsi_device_handler(
device: &Device,
spec: &mut Spec,
sandbox: &Arc<Mutex<Sandbox>>,
devidx: &DevIndex,
) -> Result<()> {
let mut dev = device.clone();
dev.vm_path = get_scsi_device_name(sandbox, &device.id)?;
dev.vm_path = get_scsi_device_name(sandbox, &device.id).await?;
update_spec_device_list(&dev, spec, devidx)
}
fn virtio_nvdimm_device_handler(
async fn virtio_nvdimm_device_handler(
device: &Device,
spec: &mut Spec,
_sandbox: &Arc<Mutex<Sandbox>>,
@ -357,7 +356,7 @@ impl DevIndex {
}
}
pub fn add_devices(
pub async fn add_devices(
devices: &[Device],
spec: &mut Spec,
sandbox: &Arc<Mutex<Sandbox>>,
@ -365,13 +364,13 @@ pub fn add_devices(
let devidx = DevIndex::new(spec);
for device in devices.iter() {
add_device(device, spec, sandbox, &devidx)?;
add_device(device, spec, sandbox, &devidx).await?;
}
Ok(())
}
fn add_device(
async fn add_device(
device: &Device,
spec: &mut Spec,
sandbox: &Arc<Mutex<Sandbox>>,
@ -393,9 +392,12 @@ fn add_device(
return Err(anyhow!("invalid container path for device {:?}", device));
}
match DEVICEHANDLERLIST.get(device.field_type.as_str()) {
None => Err(anyhow!("Unknown device type {}", device.field_type)),
Some(dev_handler) => dev_handler(device, spec, sandbox, devidx),
match device.field_type.as_str() {
DRIVERBLKTYPE => virtio_blk_device_handler(device, spec, sandbox, devidx).await,
DRIVERMMIOBLKTYPE => virtiommio_blk_device_handler(device, spec, sandbox, devidx).await,
DRIVERNVDIMMTYPE => virtio_nvdimm_device_handler(device, spec, sandbox, devidx).await,
DRIVERSCSITYPE => virtio_scsi_device_handler(device, spec, sandbox, devidx).await,
_ => Err(anyhow!("Unknown device type {}", device.field_type)),
}
}

View File

@ -15,7 +15,6 @@ extern crate prctl;
extern crate prometheus;
extern crate protocols;
extern crate regex;
extern crate rustjail;
extern crate scan_fmt;
extern crate serde_json;
extern crate signal_hook;
@ -38,7 +37,6 @@ use nix::sys::socket::{self, AddressFamily, SockAddr, SockFlag, SockType};
use nix::sys::wait::{self, WaitStatus};
use nix::unistd::{self, close, dup, dup2, fork, setsid, ForkResult};
use prctl::set_child_subreaper;
use signal_hook::{iterator::Signals, SIGCHLD};
use std::collections::HashMap;
use std::env;
use std::ffi::{CStr, CString, OsStr};
@ -48,9 +46,7 @@ use std::os::unix::ffi::OsStrExt;
use std::os::unix::fs as unixfs;
use std::os::unix::io::AsRawFd;
use std::path::Path;
use std::sync::mpsc::{self, Sender};
use std::sync::{Arc, Mutex, RwLock};
use std::{io, thread, thread::JoinHandle};
use std::sync::Arc;
use unistd::Pid;
mod config;
@ -72,6 +68,16 @@ use sandbox::Sandbox;
use slog::Logger;
use uevent::watch_uevents;
use std::sync::Mutex as SyncMutex;
use futures::StreamExt as _;
use rustjail::pipestream::PipeStream;
use tokio::{
signal::unix::{signal, SignalKind},
sync::{oneshot::Sender, Mutex, RwLock},
};
use tokio_vsock::{Incoming, VsockListener, VsockStream};
mod rpc;
const NAME: &str = "kata-agent";
@ -81,7 +87,7 @@ const CONSOLE_PATH: &str = "/dev/console";
const DEFAULT_BUF_SIZE: usize = 8 * 1024;
lazy_static! {
static ref GLOBAL_DEVICE_WATCHER: Arc<Mutex<HashMap<String, Sender<String>>>> =
static ref GLOBAL_DEVICE_WATCHER: Arc<Mutex<HashMap<String, Option<Sender<String>>>>> =
Arc::new(Mutex::new(HashMap::new()));
static ref AGENT_CONFIG: Arc<RwLock<agentConfig>> =
Arc::new(RwLock::new(config::agentConfig::new()));
@ -100,8 +106,28 @@ fn announce(logger: &Logger, config: &agentConfig) {
);
}
#[tokio::main]
async fn main() -> Result<()> {
/// Marks `fd` close-on-exec so it is not inherited by spawned children;
/// returns the same fd on success.
fn set_fd_close_exec(fd: RawFd) -> Result<RawFd> {
    fcntl::fcntl(fd, FcntlArg::F_SETFD(FdFlag::FD_CLOEXEC))
        .map_err(|e| anyhow!("failed to set fd: {} as close-on-exec: {}", fd, e))?;
    Ok(fd)
}
/// Turns an already-listening vsock fd into a stream of incoming
/// connections.
fn get_vsock_incoming(fd: RawFd) -> Incoming {
    // SAFETY: the caller passes a vsock listener fd it owns; the
    // VsockListener takes over ownership of that fd.
    unsafe { VsockListener::from_raw_fd(fd).incoming() }
}
async fn get_vsock_stream(fd: RawFd) -> Result<VsockStream> {
let stream = get_vsock_incoming(fd).next().await.unwrap().unwrap();
set_fd_close_exec(stream.as_raw_fd())?;
Ok(stream)
}
fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
let args: Vec<String> = env::args().collect();
if args.len() == 2 && args[1] == "--version" {
@ -121,6 +147,13 @@ async fn main() -> Result<()> {
exit(0);
}
let mut rt = tokio::runtime::Builder::new()
.basic_scheduler()
.max_threads(1)
.enable_all()
.build()?;
rt.block_on(async {
env::set_var("RUST_BACKTRACE", "full");
lazy_static::initialize(&SHELLS);
@ -151,7 +184,7 @@ async fn main() -> Result<()> {
e
})?;
let mut config = agentConfig.write().unwrap();
let mut config = agentConfig.write().await;
config.parse_cmdline(KERNEL_CMDLINE_FILE)?;
init_agent_as_init(&logger, config.unified_cgroup_hierarchy)?;
@ -159,37 +192,43 @@ async fn main() -> Result<()> {
// once parsed cmdline and set the config, release the write lock
// as soon as possible in case other thread would get read lock on
// it.
let mut config = agentConfig.write().unwrap();
let mut config = agentConfig.write().await;
config.parse_cmdline(KERNEL_CMDLINE_FILE)?;
}
let config = agentConfig.read().unwrap();
let config = agentConfig.read().await;
let log_vport = config.log_vport as u32;
let log_handle = thread::spawn(move || -> Result<()> {
let mut reader = unsafe { File::from_raw_fd(rfd) };
let log_handle = tokio::spawn(async move {
let mut reader = PipeStream::from_fd(rfd);
if log_vport > 0 {
let listenfd = socket::socket(
AddressFamily::Vsock,
SockType::Stream,
SockFlag::SOCK_CLOEXEC,
None,
)?;
)
.unwrap();
let addr = SockAddr::new_vsock(libc::VMADDR_CID_ANY, log_vport);
socket::bind(listenfd, &addr)?;
socket::listen(listenfd, 1)?;
let datafd = socket::accept4(listenfd, SockFlag::SOCK_CLOEXEC)?;
let mut log_writer = unsafe { File::from_raw_fd(datafd) };
let _ = io::copy(&mut reader, &mut log_writer)?;
let _ = unistd::close(listenfd);
let _ = unistd::close(datafd);
}
socket::bind(listenfd, &addr).unwrap();
socket::listen(listenfd, 1).unwrap();
let mut vsock_stream = get_vsock_stream(listenfd).await.unwrap();
// copy log to stdout
let mut stdout_writer = io::stdout();
let _ = io::copy(&mut reader, &mut stdout_writer)?;
Ok(())
tokio::io::copy(&mut reader, &mut vsock_stream)
.await
.unwrap();
}
// copy log to stdout
let mut stdout_writer = tokio::io::stdout();
let _ = tokio::io::copy(&mut reader, &mut stdout_writer).await;
});
let writer = unsafe { File::from_raw_fd(wfd) };
// Recreate a logger with the log level get from "/proc/cmdline".
let logger = logging::create_logger(NAME, "agent", config.log_level, writer);
@ -208,20 +247,22 @@ async fn main() -> Result<()> {
start_sandbox(&logger, &config, init_mode).await?;
let _ = log_handle.join();
let _ = log_handle.await.unwrap();
Ok(())
})
}
async fn start_sandbox(logger: &Logger, config: &agentConfig, init_mode: bool) -> Result<()> {
let shells = SHELLS.clone();
let debug_console_vport = config.debug_console_vport as u32;
let mut shell_handle: Option<JoinHandle<()>> = None;
// TODO: async the debug console
let mut _shell_handle: Option<std::thread::JoinHandle<()>> = None;
if config.debug_console {
let thread_logger = logger.clone();
let builder = thread::Builder::new();
let builder = std::thread::Builder::new();
let handle = builder.spawn(move || {
let shells = shells.lock().unwrap();
@ -233,7 +274,7 @@ async fn start_sandbox(logger: &Logger, config: &agentConfig, init_mode: bool) -
}
})?;
shell_handle = Some(handle);
_shell_handle = Some(handle);
}
// Initialize unique sandbox structure.
@ -248,41 +289,38 @@ async fn start_sandbox(logger: &Logger, config: &agentConfig, init_mode: bool) -
let sandbox = Arc::new(Mutex::new(s));
setup_signal_handler(&logger, sandbox.clone()).unwrap();
watch_uevents(sandbox.clone());
setup_signal_handler(&logger, sandbox.clone())
.await
.unwrap();
watch_uevents(sandbox.clone()).await;
let (tx, rx) = mpsc::channel::<i32>();
sandbox.lock().unwrap().sender = Some(tx);
let (tx, rx) = tokio::sync::oneshot::channel();
sandbox.lock().await.sender = Some(tx);
// vsock:///dev/vsock, port
let mut server = rpc::start(sandbox, config.server_addr.as_str()).await?;
let mut server = rpc::start(sandbox.clone(), config.server_addr.as_str());
server.start().await?;
let _ = rx.recv()?;
let _ = rx.await?;
server.shutdown().await?;
if let Some(handle) = shell_handle {
handle.join().map_err(|e| anyhow!("{:?}", e))?;
}
Ok(())
}
use nix::sys::wait::WaitPidFlag;
fn setup_signal_handler(logger: &Logger, sandbox: Arc<Mutex<Sandbox>>) -> Result<()> {
async fn setup_signal_handler(logger: &Logger, sandbox: Arc<Mutex<Sandbox>>) -> Result<()> {
let logger = logger.new(o!("subsystem" => "signals"));
set_child_subreaper(true)
.map_err(|err| anyhow!(err).context("failed to setup agent as a child subreaper"))?;
let signals = Signals::new(&[SIGCHLD])?;
let mut signal_stream = signal(SignalKind::child())?;
thread::spawn(move || {
'outer: for sig in signals.forever() {
info!(logger, "received signal"; "signal" => sig);
tokio::spawn(async move {
'outer: loop {
signal_stream.recv().await;
info!(logger, "received signal"; "signal" => "SIGCHLD");
// sevral signals can be combined together
// as one. So loop around to reap all
@ -316,7 +354,7 @@ fn setup_signal_handler(logger: &Logger, sandbox: Arc<Mutex<Sandbox>>) -> Result
let logger = logger.new(o!("child-pid" => child_pid));
let mut sandbox = sandbox.lock().unwrap();
let mut sandbox = sandbox.lock().await;
let process = sandbox.find_process(raw_pid);
if process.is_none() {
info!(logger, "child exited unexpectedly");
@ -375,7 +413,7 @@ fn init_agent_as_init(logger: &Logger, unified_cgroup_hierarchy: bool) -> Result
unistd::setsid()?;
unsafe {
libc::ioctl(io::stdin().as_raw_fd(), libc::TIOCSCTTY, 1);
libc::ioctl(std::io::stdin().as_raw_fd(), libc::TIOCSCTTY, 1);
}
env::set_var("PATH", "/bin:/sbin/:/usr/bin/:/usr/sbin/");
@ -406,7 +444,7 @@ fn sethostname(hostname: &OsStr) -> Result<()> {
}
lazy_static! {
static ref SHELLS: Arc<Mutex<Vec<String>>> = {
static ref SHELLS: Arc<SyncMutex<Vec<String>>> = {
let mut v = Vec::new();
if !cfg!(test) {
@ -414,7 +452,7 @@ lazy_static! {
v.push("/bin/sh".to_string());
}
Arc::new(Mutex::new(v))
Arc::new(SyncMutex::new(v))
};
}
@ -480,7 +518,7 @@ fn setup_debug_console(logger: &Logger, shells: Vec<String>, port: u32) -> Resul
};
}
fn io_copy<R: ?Sized, W: ?Sized>(reader: &mut R, writer: &mut W) -> io::Result<u64>
fn io_copy<R: ?Sized, W: ?Sized>(reader: &mut R, writer: &mut W) -> std::io::Result<u64>
where
R: Read,
W: Write,
@ -543,10 +581,10 @@ fn run_debug_console_shell(logger: &Logger, shell: &str, socket_fd: RawFd) -> Re
let debug_shell_logger = logger.clone();
// channel that used to sync between thread and main process
let (tx, rx) = mpsc::channel::<i32>();
let (tx, rx) = std::sync::mpsc::channel::<i32>();
// start a thread to do IO copy between socket and pseduo.master
thread::spawn(move || {
std::thread::spawn(move || {
let mut master_reader = unsafe { File::from_raw_fd(master_fd) };
let mut master_writer = unsafe { File::from_raw_fd(master_fd) };
let mut socket_reader = unsafe { File::from_raw_fd(socket_fd) };

View File

@ -12,7 +12,8 @@ use std::os::unix::fs::PermissionsExt;
use std::path::Path;
use std::ptr::null;
use std::sync::{Arc, Mutex};
use std::sync::Arc;
use tokio::sync::Mutex;
use libc::{c_void, mount};
use nix::mount::{self, MsFlags};
@ -121,32 +122,15 @@ lazy_static! {
];
}
// StorageHandler is the type of callback to be defined to handle every
// type of storage driver.
type StorageHandler = fn(&Logger, &Storage, Arc<Mutex<Sandbox>>) -> Result<String>;
// STORAGEHANDLERLIST lists the supported drivers.
#[rustfmt::skip]
lazy_static! {
pub static ref STORAGEHANDLERLIST: HashMap<&'static str, StorageHandler> = {
let mut m = HashMap::new();
let blk: StorageHandler = virtio_blk_storage_handler;
m.insert(DRIVERBLKTYPE, blk);
let p9: StorageHandler= virtio9p_storage_handler;
m.insert(DRIVER9PTYPE, p9);
let virtiofs: StorageHandler = virtiofs_storage_handler;
m.insert(DRIVERVIRTIOFSTYPE, virtiofs);
let ephemeral: StorageHandler = ephemeral_storage_handler;
m.insert(DRIVEREPHEMERALTYPE, ephemeral);
let virtiommio: StorageHandler = virtiommio_blk_storage_handler;
m.insert(DRIVERMMIOBLKTYPE, virtiommio);
let local: StorageHandler = local_storage_handler;
m.insert(DRIVERLOCALTYPE, local);
let scsi: StorageHandler = virtio_scsi_storage_handler;
m.insert(DRIVERSCSITYPE, scsi);
m
};
}
pub const STORAGE_HANDLER_LIST: [&str; 7] = [
DRIVERBLKTYPE,
DRIVER9PTYPE,
DRIVERVIRTIOFSTYPE,
DRIVEREPHEMERALTYPE,
DRIVERMMIOBLKTYPE,
DRIVERLOCALTYPE,
DRIVERSCSITYPE,
];
#[derive(Debug, Clone)]
pub struct BareMount<'a> {
@ -238,12 +222,12 @@ impl<'a> BareMount<'a> {
}
}
fn ephemeral_storage_handler(
async fn ephemeral_storage_handler(
logger: &Logger,
storage: &Storage,
sandbox: Arc<Mutex<Sandbox>>,
) -> Result<String> {
let mut sb = sandbox.lock().unwrap();
let mut sb = sandbox.lock().await;
let new_storage = sb.set_sandbox_storage(&storage.mount_point);
if !new_storage {
@ -256,12 +240,12 @@ fn ephemeral_storage_handler(
Ok("".to_string())
}
fn local_storage_handler(
async fn local_storage_handler(
_logger: &Logger,
storage: &Storage,
sandbox: Arc<Mutex<Sandbox>>,
) -> Result<String> {
let mut sb = sandbox.lock().unwrap();
let mut sb = sandbox.lock().await;
let new_storage = sb.set_sandbox_storage(&storage.mount_point);
if !new_storage {
@ -289,7 +273,7 @@ fn local_storage_handler(
Ok("".to_string())
}
fn virtio9p_storage_handler(
async fn virtio9p_storage_handler(
logger: &Logger,
storage: &Storage,
_sandbox: Arc<Mutex<Sandbox>>,
@ -298,7 +282,7 @@ fn virtio9p_storage_handler(
}
// virtiommio_blk_storage_handler handles the storage for mmio blk driver.
fn virtiommio_blk_storage_handler(
async fn virtiommio_blk_storage_handler(
logger: &Logger,
storage: &Storage,
_sandbox: Arc<Mutex<Sandbox>>,
@ -308,7 +292,7 @@ fn virtiommio_blk_storage_handler(
}
// virtiofs_storage_handler handles the storage for virtio-fs.
fn virtiofs_storage_handler(
async fn virtiofs_storage_handler(
logger: &Logger,
storage: &Storage,
_sandbox: Arc<Mutex<Sandbox>>,
@ -317,7 +301,7 @@ fn virtiofs_storage_handler(
}
// virtio_blk_storage_handler handles the storage for blk driver.
fn virtio_blk_storage_handler(
async fn virtio_blk_storage_handler(
logger: &Logger,
storage: &Storage,
sandbox: Arc<Mutex<Sandbox>>,
@ -334,7 +318,7 @@ fn virtio_blk_storage_handler(
return Err(anyhow!("Invalid device {}", &storage.source));
}
} else {
let dev_path = get_pci_device_name(&sandbox, &storage.source)?;
let dev_path = get_pci_device_name(&sandbox, &storage.source).await?;
storage.source = dev_path;
}
@ -342,7 +326,7 @@ fn virtio_blk_storage_handler(
}
// virtio_scsi_storage_handler handles the storage for scsi driver.
fn virtio_scsi_storage_handler(
async fn virtio_scsi_storage_handler(
logger: &Logger,
storage: &Storage,
sandbox: Arc<Mutex<Sandbox>>,
@ -350,7 +334,7 @@ fn virtio_scsi_storage_handler(
let mut storage = storage.clone();
// Retrieve the device path from SCSI address.
let dev_path = get_scsi_device_name(&sandbox, &storage.source)?;
let dev_path = get_scsi_device_name(&sandbox, &storage.source).await?;
storage.source = dev_path;
common_storage_handler(logger, &storage)
@ -430,7 +414,7 @@ fn parse_mount_flags_and_options(options_vec: Vec<&str>) -> (MsFlags, String) {
// associated operations such as waiting for the device to show up, and mount
// it to a specific location, according to the type of handler chosen, and for
// each storage.
pub fn add_storages(
pub async fn add_storages(
logger: Logger,
storages: Vec<Storage>,
sandbox: Arc<Mutex<Sandbox>>,
@ -443,17 +427,30 @@ pub fn add_storages(
"subsystem" => "storage",
"storage-type" => handler_name.to_owned()));
let handler = STORAGEHANDLERLIST
.get(&handler_name.as_str())
.ok_or_else(|| {
anyhow!(
let res = match handler_name.as_str() {
DRIVERBLKTYPE => virtio_blk_storage_handler(&logger, &storage, sandbox.clone()).await,
DRIVER9PTYPE => virtio9p_storage_handler(&logger, &storage, sandbox.clone()).await,
DRIVERVIRTIOFSTYPE => {
virtiofs_storage_handler(&logger, &storage, sandbox.clone()).await
}
DRIVEREPHEMERALTYPE => {
ephemeral_storage_handler(&logger, &storage, sandbox.clone()).await
}
DRIVERMMIOBLKTYPE => {
virtiommio_blk_storage_handler(&logger, &storage, sandbox.clone()).await
}
DRIVERLOCALTYPE => local_storage_handler(&logger, &storage, sandbox.clone()).await,
DRIVERSCSITYPE => virtio_scsi_storage_handler(&logger, &storage, sandbox.clone()).await,
_ => {
return Err(anyhow!(
"Failed to find the storage handler {}",
storage.driver.to_owned()
)
})?;
));
}
};
// Todo need to rollback the mounted storage if err met.
let mount_point = handler(&logger, &storage, sandbox.clone())?;
let mount_point = res?;
if !mount_point.is_empty() {
mount_list.push(mount_point);

View File

@ -11,7 +11,6 @@ use std::fmt;
use std::fs;
use std::fs::File;
use std::path::{Path, PathBuf};
use std::thread::{self};
use crate::mount::{BareMount, FLAGS};
use slog::Logger;
@ -76,7 +75,7 @@ impl Namespace {
// setup creates persistent namespace without switching to it.
// Note, pid namespaces cannot be persisted.
pub fn setup(mut self) -> Result<Self> {
pub async fn setup(mut self) -> Result<Self> {
fs::create_dir_all(&self.persistent_ns_dir)?;
let ns_path = PathBuf::from(&self.persistent_ns_dir);
@ -93,7 +92,8 @@ impl Namespace {
self.path = new_ns_path.clone().into_os_string().into_string().unwrap();
let hostname = self.hostname.clone();
let new_thread = thread::spawn(move || -> Result<()> {
let new_thread = tokio::spawn(async move {
if let Err(err) = || -> Result<()> {
let origin_ns_path = get_current_thread_ns_path(&ns_type.get());
File::open(Path::new(&origin_ns_path))?;
@ -127,11 +127,16 @@ impl Namespace {
)
})?;
Ok(())
}() {
return Err(err);
}
Ok(())
});
new_thread
.join()
.await
.map_err(|e| anyhow!("Failed to join thread {:?}!", e))??;
Ok(self)

View File

@ -4,9 +4,12 @@
//
use async_trait::async_trait;
use rustjail::{pipestream::PipeStream, process::StreamType};
use tokio::io::{AsyncReadExt, AsyncWriteExt, ReadHalf};
use tokio::sync::Mutex;
use std::path::Path;
use std::sync::mpsc::channel;
use std::sync::{Arc, Mutex};
use std::sync::Arc;
use ttrpc::{
self,
error::get_rpc_status as ttrpc_error,
@ -29,7 +32,6 @@ use protocols::types::Interface;
use rustjail::cgroups::notifier;
use rustjail::container::{BaseContainer, Container, LinuxContainer};
use rustjail::process::Process;
use rustjail::reaper;
use rustjail::specconv::CreateOpts;
use nix::errno::Errno;
@ -42,7 +44,7 @@ use rustjail::process::ProcessOperations;
use crate::device::{add_devices, rescan_pci_bus, update_device_cgroup};
use crate::linux_abi::*;
use crate::metrics::get_metrics;
use crate::mount::{add_storages, remove_mounts, BareMount, STORAGEHANDLERLIST};
use crate::mount::{add_storages, remove_mounts, BareMount, STORAGE_HANDLER_LIST};
use crate::namespace::{NSTYPEIPC, NSTYPEPID, NSTYPEUTS};
use crate::network::setup_guest_dns;
use crate::random;
@ -54,11 +56,8 @@ use netlink::{RtnlHandle, NETLINK_ROUTE};
use libc::{self, c_ushort, pid_t, winsize, TIOCSWINSZ};
use std::convert::TryFrom;
use std::fs;
use std::os::unix::io::RawFd;
use std::os::unix::prelude::PermissionsExt;
use std::process::{Command, Stdio};
use std::sync::mpsc;
use std::thread;
use std::time::Duration;
use nix::unistd::{Gid, Uid};
@ -83,7 +82,10 @@ pub struct agentService {
}
impl agentService {
fn do_create_container(&self, req: protocols::agent::CreateContainerRequest) -> Result<()> {
async fn do_create_container(
&self,
req: protocols::agent::CreateContainerRequest,
) -> Result<()> {
let cid = req.container_id.clone();
let mut oci_spec = req.OCI.clone();
@ -111,7 +113,7 @@ impl agentService {
// updates the devices listed in the OCI spec, so that they actually
// match real devices inside the VM. This step is necessary since we
// cannot predict everything from the caller.
add_devices(&req.devices.to_vec(), &mut oci, &self.sandbox)?;
add_devices(&req.devices.to_vec(), &mut oci, &self.sandbox).await?;
// Both rootfs and volumes (invoked with --volume for instance) will
// be processed the same way. The idea is to always mount any provided
@ -120,10 +122,10 @@ impl agentService {
// After all those storages have been processed, no matter the order
// here, the agent will rely on rustjail (using the oci.Mounts
// list) to bind mount all of them inside the container.
let m = add_storages(sl!(), req.storages.to_vec(), self.sandbox.clone())?;
let m = add_storages(sl!(), req.storages.to_vec(), self.sandbox.clone()).await?;
{
sandbox = self.sandbox.clone();
s = sandbox.lock().unwrap();
s = sandbox.lock().await;
s.container_mounts.insert(cid.clone(), m);
}
@ -154,7 +156,7 @@ impl agentService {
let mut ctr: LinuxContainer =
LinuxContainer::new(cid.as_str(), CONTAINER_BASE, opts, &sl!())?;
let pipe_size = AGENT_CONFIG.read().unwrap().container_pipe_size;
let pipe_size = AGENT_CONFIG.read().await.container_pipe_size;
let p = if oci.process.is_some() {
Process::new(
&sl!(),
@ -168,7 +170,7 @@ impl agentService {
return Err(anyhow!(nix::Error::from_errno(nix::errno::Errno::EINVAL)));
};
ctr.start(p)?;
ctr.start(p).await?;
s.update_shared_pidns(&ctr)?;
s.add_container(ctr);
@ -177,11 +179,11 @@ impl agentService {
Ok(())
}
fn do_start_container(&self, req: protocols::agent::StartContainerRequest) -> Result<()> {
async fn do_start_container(&self, req: protocols::agent::StartContainerRequest) -> Result<()> {
let cid = req.container_id;
let sandbox = self.sandbox.clone();
let mut s = sandbox.lock().unwrap();
let mut s = sandbox.lock().await;
let sid = s.id.clone();
let ctr = s
@ -194,8 +196,8 @@ impl agentService {
if sid != cid && ctr.cgroup_manager.is_some() {
let cg_path = ctr.cgroup_manager.as_ref().unwrap().get_cg_path("memory");
if cg_path.is_some() {
let rx = notifier::notify_oom(cid.as_str(), cg_path.unwrap())?;
s.run_oom_event_monitor(rx, cid.clone());
let rx = notifier::notify_oom(cid.as_str(), cg_path.unwrap()).await?;
s.run_oom_event_monitor(rx, cid.clone()).await;
}
}
@ -206,7 +208,10 @@ impl agentService {
Ok(())
}
fn do_remove_container(&self, req: protocols::agent::RemoveContainerRequest) -> Result<()> {
async fn do_remove_container(
&self,
req: protocols::agent::RemoveContainerRequest,
) -> Result<()> {
let cid = req.container_id.clone();
let mut cmounts: Vec<String> = vec![];
@ -234,12 +239,12 @@ impl agentService {
if req.timeout == 0 {
let s = Arc::clone(&self.sandbox);
let mut sandbox = s.lock().unwrap();
let mut sandbox = s.lock().await;
let ctr = sandbox
.get_container(&cid)
.ok_or_else(|| anyhow!("Invalid container id"))?;
ctr.destroy()?;
ctr.destroy().await?;
remove_container_resources(&mut sandbox)?;
@ -249,43 +254,47 @@ impl agentService {
// timeout != 0
let s = self.sandbox.clone();
let cid2 = cid.clone();
let (tx, rx) = mpsc::channel();
let (tx, rx) = tokio::sync::oneshot::channel::<i32>();
let handle = thread::spawn(move || {
let mut sandbox = s.lock().unwrap();
let _ctr = sandbox
.get_container(&cid2)
.ok_or_else(|| anyhow!("Invalid container id"))
.map(|ctr| {
ctr.destroy().unwrap();
let handle = tokio::spawn(async move {
let mut sandbox = s.lock().await;
if let Some(ctr) = sandbox.get_container(&cid2) {
ctr.destroy().await.unwrap();
tx.send(1).unwrap();
ctr
});
};
});
rx.recv_timeout(Duration::from_secs(req.timeout as u64))
.map_err(|_| anyhow!(nix::Error::from_errno(nix::errno::Errno::ETIME)))?;
let timeout = tokio::time::delay_for(Duration::from_secs(req.timeout.into()));
handle
.join()
.map_err(|_| anyhow!(nix::Error::from_errno(nix::errno::Errno::UnknownErrno)))?;
tokio::select! {
_ = rx => {}
_ = timeout => {
return Err(anyhow!(nix::Error::from_errno(nix::errno::Errno::ETIME)));
}
};
if handle.await.is_err() {
return Err(anyhow!(nix::Error::from_errno(
nix::errno::Errno::UnknownErrno
)));
}
let s = self.sandbox.clone();
let mut sandbox = s.lock().unwrap();
let mut sandbox = s.lock().await;
remove_container_resources(&mut sandbox)?;
Ok(())
}
fn do_exec_process(&self, req: protocols::agent::ExecProcessRequest) -> Result<()> {
async fn do_exec_process(&self, req: protocols::agent::ExecProcessRequest) -> Result<()> {
let cid = req.container_id.clone();
let exec_id = req.exec_id.clone();
info!(sl!(), "do_exec_process cid: {} eid: {}", cid, exec_id);
let s = self.sandbox.clone();
let mut sandbox = s.lock().unwrap();
let mut sandbox = s.lock().await;
let process = if req.process.is_some() {
req.process.as_ref().unwrap()
@ -293,7 +302,7 @@ impl agentService {
return Err(anyhow!(nix::Error::from_errno(nix::errno::Errno::EINVAL)));
};
let pipe_size = AGENT_CONFIG.read().unwrap().container_pipe_size;
let pipe_size = AGENT_CONFIG.read().await.container_pipe_size;
let ocip = rustjail::process_grpc_to_oci(process);
let p = Process::new(&sl!(), &ocip, exec_id.as_str(), false, pipe_size)?;
@ -301,7 +310,7 @@ impl agentService {
.get_container(&cid)
.ok_or_else(|| anyhow!("Invalid container id"))?;
ctr.run(p)?;
ctr.run(p).await?;
// set epoller
let p = find_process(&mut sandbox, cid.as_str(), exec_id.as_str(), false)?;
@ -310,11 +319,11 @@ impl agentService {
Ok(())
}
fn do_signal_process(&self, req: protocols::agent::SignalProcessRequest) -> Result<()> {
async fn do_signal_process(&self, req: protocols::agent::SignalProcessRequest) -> Result<()> {
let cid = req.container_id.clone();
let eid = req.exec_id.clone();
let s = self.sandbox.clone();
let mut sandbox = s.lock().unwrap();
let mut sandbox = s.lock().await;
let mut init = false;
info!(
@ -344,7 +353,7 @@ impl agentService {
Ok(())
}
fn do_wait_process(
async fn do_wait_process(
&self,
req: protocols::agent::WaitProcessRequest,
) -> Result<protocols::agent::WaitProcessResponse> {
@ -353,9 +362,9 @@ impl agentService {
let s = self.sandbox.clone();
let mut resp = WaitProcessResponse::new();
let pid: pid_t;
let mut exit_pipe_r: RawFd = -1;
let mut buf: Vec<u8> = vec![0, 1];
let (exit_send, exit_recv) = channel();
let stream;
let (exit_send, mut exit_recv) = tokio::sync::mpsc::channel(100);
info!(
sl!(),
@ -365,24 +374,24 @@ impl agentService {
);
{
let mut sandbox = s.lock().unwrap();
let mut sandbox = s.lock().await;
let p = find_process(&mut sandbox, cid.as_str(), eid.as_str(), false)?;
if p.exit_pipe_r.is_some() {
exit_pipe_r = p.exit_pipe_r.unwrap();
}
stream = p.get_reader(StreamType::ExitPipeR);
p.exit_watchers.push(exit_send);
pid = p.pid;
}
if exit_pipe_r != -1 {
if stream.is_some() {
info!(sl!(), "reading exit pipe");
let _ = unistd::read(exit_pipe_r, buf.as_mut_slice());
let reader = stream.unwrap();
let mut content: Vec<u8> = vec![0, 1];
let _ = reader.lock().await.read(&mut content).await;
}
let mut sandbox = s.lock().unwrap();
let mut sandbox = s.lock().await;
let ctr = sandbox
.get_container(&cid)
.ok_or_else(|| anyhow!("Invalid container id"))?;
@ -391,44 +400,20 @@ impl agentService {
Some(p) => p,
None => {
// Lost race, pick up exit code from channel
resp.status = exit_recv.recv().unwrap();
resp.status = exit_recv.recv().await.unwrap();
return Ok(resp);
}
};
// need to close all fds
if p.parent_stdin.is_some() {
let _ = unistd::close(p.parent_stdin.unwrap());
}
if p.parent_stdout.is_some() {
let _ = unistd::close(p.parent_stdout.unwrap());
}
if p.parent_stderr.is_some() {
let _ = unistd::close(p.parent_stderr.unwrap());
}
if p.term_master.is_some() {
let _ = unistd::close(p.term_master.unwrap());
}
if p.exit_pipe_r.is_some() {
let _ = unistd::close(p.exit_pipe_r.unwrap());
}
p.close_epoller();
p.parent_stdin = None;
p.parent_stdout = None;
p.parent_stderr = None;
p.term_master = None;
// need to close all fd
// ignore errors for some fd might be closed by stream
let _ = cleanup_process(&mut p);
resp.status = p.exit_code;
// broadcast exit code to all parallel watchers
for s in p.exit_watchers.iter() {
for s in p.exit_watchers.iter_mut() {
// Just ignore errors in case any watcher quits unexpectedly
let _ = s.send(p.exit_code);
let _ = s.send(p.exit_code).await;
}
ctr.processes.remove(&pid);
@ -436,48 +421,37 @@ impl agentService {
Ok(resp)
}
fn do_write_stream(
async fn do_write_stream(
&self,
req: protocols::agent::WriteStreamRequest,
) -> Result<protocols::agent::WriteStreamResponse> {
let cid = req.container_id.clone();
let eid = req.exec_id.clone();
let writer = {
let s = self.sandbox.clone();
let mut sandbox = s.lock().unwrap();
let mut sandbox = s.lock().await;
let p = find_process(&mut sandbox, cid.as_str(), eid.as_str(), false)?;
// use ptmx io
let fd = if p.term_master.is_some() {
p.term_master.unwrap()
if p.term_master.is_some() {
p.get_writer(StreamType::TermMaster)
} else {
// use piped io
p.parent_stdin.unwrap()
p.get_writer(StreamType::ParentStdin)
}
};
let mut l = req.data.len();
match unistd::write(fd, req.data.as_slice()) {
Ok(v) => {
if v < l {
info!(sl!(), "write {} bytes", v);
l = v;
}
}
Err(e) => match e {
nix::Error::Sys(nix::errno::Errno::EAGAIN) => l = 0,
_ => {
return Err(anyhow!(nix::Error::from_errno(nix::errno::Errno::EIO)));
}
},
}
let writer = writer.unwrap();
writer.lock().await.write_all(req.data.as_slice()).await?;
let mut resp = WriteStreamResponse::new();
resp.set_len(l as u32);
resp.set_len(req.data.len() as u32);
Ok(resp)
}
fn do_read_stream(
async fn do_read_stream(
&self,
req: protocols::agent::ReadStreamRequest,
stdout: bool,
@ -485,42 +459,35 @@ impl agentService {
let cid = req.container_id;
let eid = req.exec_id;
let mut fd: RawFd = -1;
let mut epoller: Option<reaper::Epoller> = None;
{
// let mut fd: RawFd = -1;
// let mut epoller: Option<reaper::Epoller> = None;
let reader = {
let s = self.sandbox.clone();
let mut sandbox = s.lock().unwrap();
let mut sandbox = s.lock().await;
let p = find_process(&mut sandbox, cid.as_str(), eid.as_str(), false)?;
if p.term_master.is_some() {
fd = p.term_master.unwrap();
epoller = p.epoller.clone();
// epoller = p.epoller.clone();
p.get_reader(StreamType::TermMaster)
} else if stdout {
if p.parent_stdout.is_some() {
fd = p.parent_stdout.unwrap();
p.get_reader(StreamType::ParentStdout)
} else {
None
}
} else {
fd = p.parent_stderr.unwrap();
}
p.get_reader(StreamType::ParentStderr)
}
};
if let Some(epoller) = epoller {
// The process's epoller's poll() will return a file descriptor of the process's
// terminal or one end of its exited pipe. If it returns its terminal, it means
// there is data needed to be read out or it has been closed; if it returns the
// process's exited pipe, it means the process has exited and there is no data
// needed to be read out in its terminal, thus following read on it will read out
// "EOF" to terminate this process's io since the other end of this pipe has been
// closed in reap().
fd = epoller.poll()?;
}
if fd == -1 {
if reader.is_none() {
return Err(anyhow!(nix::Error::from_errno(nix::errno::Errno::EINVAL)));
}
let vector = read_stream(fd, req.len as usize)?;
let reader = reader.unwrap();
let vector = read_stream(reader, req.len as usize).await?;
let mut resp = ReadStreamResponse::new();
resp.set_data(vector);
@ -536,7 +503,7 @@ impl protocols::agent_ttrpc::AgentService for agentService {
_ctx: &TtrpcContext,
req: protocols::agent::CreateContainerRequest,
) -> ttrpc::Result<Empty> {
match self.do_create_container(req) {
match self.do_create_container(req).await {
Err(e) => Err(ttrpc_error(ttrpc::Code::INTERNAL, e.to_string())),
Ok(_) => Ok(Empty::new()),
}
@ -547,7 +514,7 @@ impl protocols::agent_ttrpc::AgentService for agentService {
_ctx: &TtrpcContext,
req: protocols::agent::StartContainerRequest,
) -> ttrpc::Result<Empty> {
match self.do_start_container(req) {
match self.do_start_container(req).await {
Err(e) => Err(ttrpc_error(ttrpc::Code::INTERNAL, e.to_string())),
Ok(_) => Ok(Empty::new()),
}
@ -558,7 +525,7 @@ impl protocols::agent_ttrpc::AgentService for agentService {
_ctx: &TtrpcContext,
req: protocols::agent::RemoveContainerRequest,
) -> ttrpc::Result<Empty> {
match self.do_remove_container(req) {
match self.do_remove_container(req).await {
Err(e) => Err(ttrpc_error(ttrpc::Code::INTERNAL, e.to_string())),
Ok(_) => Ok(Empty::new()),
}
@ -569,7 +536,7 @@ impl protocols::agent_ttrpc::AgentService for agentService {
_ctx: &TtrpcContext,
req: protocols::agent::ExecProcessRequest,
) -> ttrpc::Result<Empty> {
match self.do_exec_process(req) {
match self.do_exec_process(req).await {
Err(e) => Err(ttrpc_error(ttrpc::Code::INTERNAL, e.to_string())),
Ok(_) => Ok(Empty::new()),
}
@ -580,7 +547,7 @@ impl protocols::agent_ttrpc::AgentService for agentService {
_ctx: &TtrpcContext,
req: protocols::agent::SignalProcessRequest,
) -> ttrpc::Result<Empty> {
match self.do_signal_process(req) {
match self.do_signal_process(req).await {
Err(e) => Err(ttrpc_error(ttrpc::Code::INTERNAL, e.to_string())),
Ok(_) => Ok(Empty::new()),
}
@ -592,6 +559,7 @@ impl protocols::agent_ttrpc::AgentService for agentService {
req: protocols::agent::WaitProcessRequest,
) -> ttrpc::Result<WaitProcessResponse> {
self.do_wait_process(req)
.await
.map_err(|e| ttrpc_error(ttrpc::Code::INTERNAL, e.to_string()))
}
@ -606,7 +574,7 @@ impl protocols::agent_ttrpc::AgentService for agentService {
let mut resp = ListProcessesResponse::new();
let s = Arc::clone(&self.sandbox);
let mut sandbox = s.lock().unwrap();
let mut sandbox = s.lock().await;
let ctr = sandbox.get_container(&cid).ok_or_else(|| {
ttrpc_error(
@ -637,10 +605,11 @@ impl protocols::agent_ttrpc::AgentService for agentService {
args = vec!["-ef".to_string()];
}
let output = Command::new("ps")
let output = tokio::process::Command::new("ps")
.args(args.as_slice())
.stdout(Stdio::piped())
.output()
.await
.expect("ps failed");
let out: String = String::from_utf8(output.stdout).unwrap();
@ -688,7 +657,7 @@ impl protocols::agent_ttrpc::AgentService for agentService {
let res = req.resources;
let s = Arc::clone(&self.sandbox);
let mut sandbox = s.lock().unwrap();
let mut sandbox = s.lock().await;
let ctr = sandbox.get_container(&cid).ok_or_else(|| {
ttrpc_error(
@ -720,7 +689,7 @@ impl protocols::agent_ttrpc::AgentService for agentService {
) -> ttrpc::Result<StatsContainerResponse> {
let cid = req.container_id;
let s = Arc::clone(&self.sandbox);
let mut sandbox = s.lock().unwrap();
let mut sandbox = s.lock().await;
let ctr = sandbox.get_container(&cid).ok_or_else(|| {
ttrpc_error(
@ -740,7 +709,7 @@ impl protocols::agent_ttrpc::AgentService for agentService {
) -> ttrpc::Result<protocols::empty::Empty> {
let cid = req.get_container_id();
let s = Arc::clone(&self.sandbox);
let mut sandbox = s.lock().unwrap();
let mut sandbox = s.lock().await;
let ctr = sandbox.get_container(&cid).ok_or_else(|| {
ttrpc_error(
@ -762,7 +731,7 @@ impl protocols::agent_ttrpc::AgentService for agentService {
) -> ttrpc::Result<protocols::empty::Empty> {
let cid = req.get_container_id();
let s = Arc::clone(&self.sandbox);
let mut sandbox = s.lock().unwrap();
let mut sandbox = s.lock().await;
let ctr = sandbox.get_container(&cid).ok_or_else(|| {
ttrpc_error(
@ -783,6 +752,7 @@ impl protocols::agent_ttrpc::AgentService for agentService {
req: protocols::agent::WriteStreamRequest,
) -> ttrpc::Result<WriteStreamResponse> {
self.do_write_stream(req)
.await
.map_err(|e| ttrpc_error(ttrpc::Code::INTERNAL, e.to_string()))
}
@ -792,6 +762,7 @@ impl protocols::agent_ttrpc::AgentService for agentService {
req: protocols::agent::ReadStreamRequest,
) -> ttrpc::Result<ReadStreamResponse> {
self.do_read_stream(req, true)
.await
.map_err(|e| ttrpc_error(ttrpc::Code::INTERNAL, e.to_string()))
}
@ -801,6 +772,7 @@ impl protocols::agent_ttrpc::AgentService for agentService {
req: protocols::agent::ReadStreamRequest,
) -> ttrpc::Result<ReadStreamResponse> {
self.do_read_stream(req, false)
.await
.map_err(|e| ttrpc_error(ttrpc::Code::INTERNAL, e.to_string()))
}
@ -812,7 +784,7 @@ impl protocols::agent_ttrpc::AgentService for agentService {
let cid = req.container_id.clone();
let eid = req.exec_id;
let s = Arc::clone(&self.sandbox);
let mut sandbox = s.lock().unwrap();
let mut sandbox = s.lock().await;
let p = find_process(&mut sandbox, cid.as_str(), eid.as_str(), false).map_err(|e| {
ttrpc_error(
@ -822,11 +794,13 @@ impl protocols::agent_ttrpc::AgentService for agentService {
})?;
if p.term_master.is_some() {
p.close_stream(StreamType::TermMaster);
let _ = unistd::close(p.term_master.unwrap());
p.term_master = None;
}
if p.parent_stdin.is_some() {
p.close_stream(StreamType::ParentStdin);
let _ = unistd::close(p.parent_stdin.unwrap());
p.parent_stdin = None;
}
@ -844,7 +818,7 @@ impl protocols::agent_ttrpc::AgentService for agentService {
let cid = req.container_id.clone();
let eid = req.exec_id.clone();
let s = Arc::clone(&self.sandbox);
let mut sandbox = s.lock().unwrap();
let mut sandbox = s.lock().await;
let p = find_process(&mut sandbox, cid.as_str(), eid.as_str(), false).map_err(|e| {
ttrpc_error(
ttrpc::Code::UNAVAILABLE,
@ -888,7 +862,7 @@ impl protocols::agent_ttrpc::AgentService for agentService {
let interface = req.interface;
let s = Arc::clone(&self.sandbox);
let mut sandbox = s.lock().unwrap();
let mut sandbox = s.lock().await;
if sandbox.rtnl.is_none() {
sandbox.rtnl = Some(RtnlHandle::new(NETLINK_ROUTE, 0).unwrap());
@ -921,7 +895,7 @@ impl protocols::agent_ttrpc::AgentService for agentService {
let rs = req.routes.unwrap().Routes.into_vec();
let s = Arc::clone(&self.sandbox);
let mut sandbox = s.lock().unwrap();
let mut sandbox = s.lock().await;
if sandbox.rtnl.is_none() {
sandbox.rtnl = Some(RtnlHandle::new(NETLINK_ROUTE, 0).unwrap());
@ -951,7 +925,7 @@ impl protocols::agent_ttrpc::AgentService for agentService {
) -> ttrpc::Result<Interfaces> {
let mut interface = protocols::agent::Interfaces::new();
let s = Arc::clone(&self.sandbox);
let mut sandbox = s.lock().unwrap();
let mut sandbox = s.lock().await;
if sandbox.rtnl.is_none() {
sandbox.rtnl = Some(RtnlHandle::new(NETLINK_ROUTE, 0).unwrap());
@ -974,7 +948,7 @@ impl protocols::agent_ttrpc::AgentService for agentService {
) -> ttrpc::Result<Routes> {
let mut routes = protocols::agent::Routes::new();
let s = Arc::clone(&self.sandbox);
let mut sandbox = s.lock().unwrap();
let mut sandbox = s.lock().await;
if sandbox.rtnl.is_none() {
sandbox.rtnl = Some(RtnlHandle::new(NETLINK_ROUTE, 0).unwrap());
@ -1015,7 +989,7 @@ impl protocols::agent_ttrpc::AgentService for agentService {
) -> ttrpc::Result<Empty> {
{
let sandbox = self.sandbox.clone();
let mut s = sandbox.lock().unwrap();
let mut s = sandbox.lock().await;
let _ = fs::remove_dir_all(CONTAINER_BASE);
let _ = fs::create_dir_all(CONTAINER_BASE);
@ -1042,13 +1016,14 @@ impl protocols::agent_ttrpc::AgentService for agentService {
}
s.setup_shared_namespaces()
.await
.map_err(|e| ttrpc_error(ttrpc::Code::INTERNAL, e.to_string()))?;
}
match add_storages(sl!(), req.storages.to_vec(), self.sandbox.clone()) {
match add_storages(sl!(), req.storages.to_vec(), self.sandbox.clone()).await {
Ok(m) => {
let sandbox = self.sandbox.clone();
let mut s = sandbox.lock().unwrap();
let mut s = sandbox.lock().await;
s.mounts = m
}
Err(e) => return Err(ttrpc_error(ttrpc::Code::INTERNAL, e.to_string())),
@ -1057,7 +1032,7 @@ impl protocols::agent_ttrpc::AgentService for agentService {
match setup_guest_dns(sl!(), req.dns.to_vec()) {
Ok(_) => {
let sandbox = self.sandbox.clone();
let mut s = sandbox.lock().unwrap();
let mut s = sandbox.lock().await;
let _dns = req
.dns
.to_vec()
@ -1076,13 +1051,11 @@ impl protocols::agent_ttrpc::AgentService for agentService {
_req: protocols::agent::DestroySandboxRequest,
) -> ttrpc::Result<Empty> {
let s = Arc::clone(&self.sandbox);
let mut sandbox = s.lock().unwrap();
let mut sandbox = s.lock().await;
// destroy all containers, clean up, notify agent to exit
// etc.
sandbox.destroy().unwrap();
sandbox.sender.as_ref().unwrap().send(1).unwrap();
sandbox.sender = None;
sandbox.destroy().await.unwrap();
sandbox.sender.take().unwrap().send(1).unwrap();
Ok(Empty::new())
}
@ -1102,7 +1075,7 @@ impl protocols::agent_ttrpc::AgentService for agentService {
let neighs = req.neighbors.unwrap().ARPNeighbors.into_vec();
let s = Arc::clone(&self.sandbox);
let mut sandbox = s.lock().unwrap();
let mut sandbox = s.lock().await;
if sandbox.rtnl.is_none() {
sandbox.rtnl = Some(RtnlHandle::new(NETLINK_ROUTE, 0).unwrap());
@ -1122,7 +1095,7 @@ impl protocols::agent_ttrpc::AgentService for agentService {
req: protocols::agent::OnlineCPUMemRequest,
) -> ttrpc::Result<Empty> {
let s = Arc::clone(&self.sandbox);
let sandbox = s.lock().unwrap();
let sandbox = s.lock().await;
sandbox
.online_cpu_memory(&req)
@ -1221,15 +1194,15 @@ impl protocols::agent_ttrpc::AgentService for agentService {
_req: protocols::agent::GetOOMEventRequest,
) -> ttrpc::Result<OOMEvent> {
let sandbox = self.sandbox.clone();
let s = sandbox.lock().unwrap();
let s = sandbox.lock().await;
let event_rx = &s.event_rx.clone();
let event_rx = event_rx.lock().unwrap();
let mut event_rx = event_rx.lock().await;
drop(s);
drop(sandbox);
match event_rx.recv() {
Err(err) => Err(ttrpc_error(ttrpc::Code::INTERNAL, err.to_string())),
Ok(container_id) => {
match event_rx.recv().await {
None => Err(ttrpc_error(ttrpc::Code::INTERNAL, "")),
Some(container_id) => {
info!(sl!(), "get_oom_event return {}", &container_id);
let mut resp = OOMEvent::new();
resp.container_id = container_id;
@ -1326,42 +1299,28 @@ fn get_agent_details() -> AgentDetails {
detail.device_handlers = RepeatedField::new();
detail.storage_handlers = RepeatedField::from_vec(
STORAGEHANDLERLIST
.keys()
.cloned()
.map(|x| x.into())
STORAGE_HANDLER_LIST
.to_vec()
.iter()
.map(|x| x.to_string())
.collect(),
);
detail
}
fn read_stream(fd: RawFd, l: usize) -> Result<Vec<u8>> {
let mut v: Vec<u8> = Vec::with_capacity(l);
unsafe {
v.set_len(l);
}
async fn read_stream(reader: Arc<Mutex<ReadHalf<PipeStream>>>, l: usize) -> Result<Vec<u8>> {
let mut content = vec![0u8; l];
let mut reader = reader.lock().await;
let len = reader.read(&mut content).await?;
content.resize(len, 0);
match unistd::read(fd, v.as_mut_slice()) {
Ok(len) => {
v.resize(len, 0);
// Rust didn't return an EOF error when the reading peer point
// was closed, instead it would return a 0 reading length, please
// see https://github.com/rust-lang/rfcs/blob/master/text/0517-io-os-reform.md#errors
if len == 0 {
return Err(anyhow!("read meet eof"));
}
}
Err(e) => match e {
nix::Error::Sys(errno) => match errno {
Errno::EAGAIN => v.clear(),
_ => return Err(anyhow!(nix::Error::Sys(errno))),
},
_ => return Err(anyhow!("read error")),
},
}
Ok(v)
Ok(content)
}
fn find_process<'a>(
@ -1384,7 +1343,7 @@ fn find_process<'a>(
ctr.get_process(eid).map_err(|_| anyhow!("Invalid exec id"))
}
pub async fn start(s: Arc<Mutex<Sandbox>>, server_address: &str) -> Result<TtrpcServer> {
pub fn start(s: Arc<Mutex<Sandbox>>, server_address: &str) -> TtrpcServer {
let agent_service = Box::new(agentService { sandbox: s })
as Box<dyn protocols::agent_ttrpc::AgentService + Send + Sync>;
@ -1399,13 +1358,14 @@ pub async fn start(s: Arc<Mutex<Sandbox>>, server_address: &str) -> Result<Ttrpc
let hservice = protocols::health_ttrpc::create_health(health_worker);
let server = TtrpcServer::new()
.bind(server_address)?
.bind(server_address)
.unwrap()
.register_service(aservice)
.register_service(hservice);
info!(sl!(), "ttRPC server started"; "address" => server_address);
Ok(server)
server
}
// This function updates the container namespaces configuration based on the
@ -1641,6 +1601,42 @@ fn setup_bundle(cid: &str, spec: &mut Spec) -> Result<PathBuf> {
Ok(olddir)
}
// Best-effort teardown of a process's I/O plumbing.
//
// For each I/O endpoint the agent holds for `p` (piped stdin/stdout/stderr,
// the terminal master, and the read end of the exit pipe) this closes the
// async stream wrapper, closes the raw fd, and clears the stored fd so it
// cannot be reused.
//
// Close errors are deliberately ignored: some fds may already have been
// closed by their associated streams (the caller invokes this as
// `let _ = cleanup_process(&mut p)`), and a failure on one fd must not
// prevent the remaining fds from being released.
//
// Always returns `Ok(())`; the `Result` return type is kept for interface
// compatibility with the call site.
fn cleanup_process(p: &mut Process) -> Result<()> {
    if let Some(fd) = p.parent_stdin {
        p.close_stream(StreamType::ParentStdin);
        let _ = unistd::close(fd);
    }

    if let Some(fd) = p.parent_stdout {
        p.close_stream(StreamType::ParentStdout);
        let _ = unistd::close(fd);
    }

    if let Some(fd) = p.parent_stderr {
        p.close_stream(StreamType::ParentStderr);
        let _ = unistd::close(fd);
    }

    if let Some(fd) = p.term_master {
        p.close_stream(StreamType::TermMaster);
        let _ = unistd::close(fd);
    }

    if let Some(fd) = p.exit_pipe_r {
        p.close_stream(StreamType::ExitPipeR);
        let _ = unistd::close(fd);
    }

    p.close_epoller();

    p.parent_stdin = None;
    p.parent_stdout = None;
    p.parent_stderr = None;
    p.term_master = None;

    Ok(())
}
fn load_kernel_module(module: &protocols::agent::KernelModule) -> Result<()> {
if module.name == "" {
return Err(anyhow!("Kernel module name is empty"));

View File

@ -22,9 +22,10 @@ use std::collections::HashMap;
use std::fs;
use std::os::unix::fs::PermissionsExt;
use std::path::Path;
use std::sync::mpsc::{self, Receiver, Sender};
use std::sync::{Arc, Mutex};
use std::sync::Arc;
use std::{thread, time};
use tokio::sync::mpsc::{channel, Receiver, Sender};
use tokio::sync::Mutex;
#[derive(Debug)]
pub struct Sandbox {
@ -42,7 +43,7 @@ pub struct Sandbox {
pub storages: HashMap<String, u32>,
pub running: bool,
pub no_pivot_root: bool,
pub sender: Option<Sender<i32>>,
pub sender: Option<tokio::sync::oneshot::Sender<i32>>,
pub rtnl: Option<RtnlHandle>,
pub hooks: Option<Hooks>,
pub event_rx: Arc<Mutex<Receiver<String>>>,
@ -53,7 +54,7 @@ impl Sandbox {
pub fn new(logger: &Logger) -> Result<Self> {
let fs_type = get_mount_fs_type("/")?;
let logger = logger.new(o!("subsystem" => "sandbox"));
let (tx, rx) = mpsc::channel::<String>();
let (tx, rx) = channel::<String>(100);
let event_rx = Arc::new(Mutex::new(rx));
Ok(Sandbox {
@ -157,17 +158,19 @@ impl Sandbox {
self.hostname = hostname;
}
pub fn setup_shared_namespaces(&mut self) -> Result<bool> {
pub async fn setup_shared_namespaces(&mut self) -> Result<bool> {
// Set up shared IPC namespace
self.shared_ipcns = Namespace::new(&self.logger)
.get_ipc()
.setup()
.await
.context("Failed to setup persistent IPC namespace")?;
// // Set up shared UTS namespace
self.shared_utsns = Namespace::new(&self.logger)
.get_uts(self.hostname.as_str())
.setup()
.await
.context("Failed to setup persistent UTS namespace")?;
Ok(true)
@ -214,9 +217,9 @@ impl Sandbox {
None
}
pub fn destroy(&mut self) -> Result<()> {
pub async fn destroy(&mut self) -> Result<()> {
for ctr in self.containers.values_mut() {
ctr.destroy()?;
ctr.destroy().await?;
}
Ok(())
}
@ -315,15 +318,17 @@ impl Sandbox {
Ok(hooks)
}
pub fn run_oom_event_monitor(&self, rx: Receiver<String>, container_id: String) {
let tx = self.event_tx.clone();
pub async fn run_oom_event_monitor(&self, mut rx: Receiver<String>, container_id: String) {
let mut tx = self.event_tx.clone();
let logger = self.logger.clone();
thread::spawn(move || {
for event in rx {
tokio::spawn(async move {
loop {
let event = rx.recv().await;
info!(logger, "got an OOM event {:?}", event);
let _ = tx
.send(container_id.clone())
.await
.map_err(|e| error!(logger, "failed to send message: {:?}", e));
}
});

View File

@ -7,10 +7,13 @@ use crate::device::online_device;
use crate::linux_abi::*;
use crate::sandbox::Sandbox;
use crate::GLOBAL_DEVICE_WATCHER;
use netlink::{RtnlHandle, NETLINK_UEVENT};
use slog::Logger;
use std::sync::{Arc, Mutex};
use std::thread;
use netlink_sys::{Protocol, Socket, SocketAddr};
use nix::errno::Errno;
use std::os::unix::io::FromRawFd;
use std::sync::Arc;
use tokio::sync::Mutex;
#[derive(Debug, Default)]
struct Uevent {
@ -55,12 +58,13 @@ impl Uevent {
&& self.devname != ""
}
fn handle_block_add_event(&self, sandbox: &Arc<Mutex<Sandbox>>) {
async fn handle_block_add_event(&self, sandbox: &Arc<Mutex<Sandbox>>) {
let pci_root_bus_path = create_pci_root_bus_path();
// Keep the same lock order as device::get_device_name(), otherwise it may cause deadlock.
let mut w = GLOBAL_DEVICE_WATCHER.lock().unwrap();
let mut sb = sandbox.lock().unwrap();
let watcher = GLOBAL_DEVICE_WATCHER.clone();
let mut w = watcher.lock().await;
let mut sb = sandbox.lock().await;
// Add the device node name to the pci device map.
sb.pci_device_map
@ -70,7 +74,7 @@ impl Uevent {
// Close the channel after watcher has been notified.
let devpath = self.devpath.clone();
let empties: Vec<_> = w
.iter()
.iter_mut()
.filter(|(dev_addr, _)| {
let pci_p = format!("{}/{}", pci_root_bus_path, *dev_addr);
@ -84,6 +88,7 @@ impl Uevent {
})
.map(|(k, sender)| {
let devname = self.devname.clone();
let sender = sender.take().unwrap();
let _ = sender.send(devname);
k.clone()
})
@ -95,9 +100,9 @@ impl Uevent {
}
}
fn process(&self, logger: &Logger, sandbox: &Arc<Mutex<Sandbox>>) {
async fn process(&self, logger: &Logger, sandbox: &Arc<Mutex<Sandbox>>) {
if self.is_block_add_event() {
return self.handle_block_add_event(sandbox);
return self.handle_block_add_event(sandbox).await;
} else if self.action == U_EVENT_ACTION_ADD {
let online_path = format!("{}/{}/online", SYSFS_DIR, &self.devpath);
// It's a memory hot-add event.
@ -117,22 +122,37 @@ impl Uevent {
}
}
pub fn watch_uevents(sandbox: Arc<Mutex<Sandbox>>) {
thread::spawn(move || {
let rtnl = RtnlHandle::new(NETLINK_UEVENT, 1).unwrap();
let logger = sandbox
.lock()
.unwrap()
.logger
.new(o!("subsystem" => "uevent"));
pub async fn watch_uevents(sandbox: Arc<Mutex<Sandbox>>) {
let sref = sandbox.clone();
let s = sref.lock().await;
let logger = s.logger.new(o!("subsystem" => "uevent"));
tokio::spawn(async move {
let mut socket;
unsafe {
let fd = libc::socket(
libc::AF_NETLINK,
libc::SOCK_DGRAM | libc::SOCK_CLOEXEC,
Protocol::KObjectUevent as libc::c_int,
);
socket = Socket::from_raw_fd(fd);
}
socket.bind(&SocketAddr::new(0, 1)).unwrap();
loop {
match rtnl.recv_message() {
match socket.recv_from_full().await {
Err(e) => {
error!(logger, "receive uevent message failed"; "error" => format!("{}", e))
}
Ok(data) => {
let text = String::from_utf8(data);
Ok((buf, addr)) => {
if addr.port_number() != 0 {
// not our netlink message
let err_msg = format!("{:?}", nix::Error::Sys(Errno::EBADMSG));
error!(logger, "receive uevent message failed"; "error" => err_msg);
return;
}
let text = String::from_utf8(buf);
match text {
Err(e) => {
error!(logger, "failed to convert bytes to text"; "error" => format!("{}", e))
@ -140,7 +160,7 @@ pub fn watch_uevents(sandbox: Arc<Mutex<Sandbox>>) {
Ok(text) => {
let event = Uevent::new(&text);
info!(logger, "got uevent message"; "event" => format!("{:?}", event));
event.process(&logger, &sandbox);
event.process(&logger, &sandbox).await;
}
}
}