agent: switch to async runtime

Fixes: #1209

Signed-off-by: Tim Zhang <tim@hyper.sh>
This commit is contained in:
Tim Zhang 2020-12-23 15:11:40 +08:00
parent 5561755e3c
commit 332fa4c65f
18 changed files with 1225 additions and 634 deletions

143
src/agent/Cargo.lock generated
View File

@ -161,9 +161,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]] [[package]]
name = "cgroups-rs" name = "cgroups-rs"
version = "0.2.0" version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "02274214de2526e48355facdd16c9d774bba2cf74d135ffb9876a60b4d613464" checksum = "348eb6d8e20a9f5247209686b7d0ffc2f4df40ddcb95f9940de55a94a655b3f5"
dependencies = [ dependencies = [
"libc", "libc",
"log", "log",
@ -486,6 +486,28 @@ version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "644f9158b2f133fd50f5fb3242878846d9eb792e445c893805ff0e3824006e35" checksum = "644f9158b2f133fd50f5fb3242878846d9eb792e445c893805ff0e3824006e35"
[[package]]
name = "inotify"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "04c6848dfb1580647ab039713282cdd1ab2bfb47b60ecfb598e22e60e3baf3f8"
dependencies = [
"bitflags",
"futures-core",
"inotify-sys",
"libc",
"tokio 0.3.6",
]
[[package]]
name = "inotify-sys"
version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c4563555856585ab3180a5bf0b2f9f8d301a728462afffc8195b3f5394229c55"
dependencies = [
"libc",
]
[[package]] [[package]]
name = "iovec" name = "iovec"
version = "0.1.4" version = "0.1.4"
@ -517,11 +539,13 @@ dependencies = [
"anyhow", "anyhow",
"async-trait", "async-trait",
"cgroups-rs", "cgroups-rs",
"futures",
"lazy_static", "lazy_static",
"libc", "libc",
"log", "log",
"logging", "logging",
"netlink", "netlink",
"netlink-sys",
"nix 0.17.0", "nix 0.17.0",
"oci", "oci",
"prctl", "prctl",
@ -539,7 +563,8 @@ dependencies = [
"slog-scope", "slog-scope",
"slog-stdlog", "slog-stdlog",
"tempfile", "tempfile",
"tokio", "tokio 0.2.24",
"tokio-vsock",
"ttrpc", "ttrpc",
] ]
@ -647,12 +672,37 @@ dependencies = [
"kernel32-sys", "kernel32-sys",
"libc", "libc",
"log", "log",
"miow", "miow 0.2.2",
"net2", "net2",
"slab", "slab",
"winapi 0.2.8", "winapi 0.2.8",
] ]
[[package]]
name = "mio"
version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f33bc887064ef1fd66020c9adfc45bb9f33d75a42096c81e7c56c65b75dd1a8b"
dependencies = [
"libc",
"log",
"miow 0.3.6",
"ntapi",
"winapi 0.3.9",
]
[[package]]
name = "mio-named-pipes"
version = "0.1.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0840c1c50fd55e521b247f949c241c9997709f23bd7f023b9762cd561e935656"
dependencies = [
"log",
"mio 0.6.23",
"miow 0.3.6",
"winapi 0.3.9",
]
[[package]] [[package]]
name = "mio-uds" name = "mio-uds"
version = "0.6.8" version = "0.6.8"
@ -661,7 +711,7 @@ checksum = "afcb699eb26d4332647cc848492bbc15eafb26f08d0304550d5aa1f612e066f0"
dependencies = [ dependencies = [
"iovec", "iovec",
"libc", "libc",
"mio", "mio 0.6.23",
] ]
[[package]] [[package]]
@ -676,6 +726,16 @@ dependencies = [
"ws2_32-sys", "ws2_32-sys",
] ]
[[package]]
name = "miow"
version = "0.3.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a33c1b55807fbed163481b5ba66db4b2fa6cde694a5027be10fb724206c5897"
dependencies = [
"socket2",
"winapi 0.3.9",
]
[[package]] [[package]]
name = "multimap" name = "multimap"
version = "0.4.0" version = "0.4.0"
@ -705,6 +765,19 @@ dependencies = [
"slog-scope", "slog-scope",
] ]
[[package]]
name = "netlink-sys"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cc9e9df13fd91bdd4b92bea93d5d2848c8035677c60fc3fee5dabddc02c3012e"
dependencies = [
"futures",
"libc",
"log",
"mio 0.6.23",
"tokio 0.2.24",
]
[[package]] [[package]]
name = "nix" name = "nix"
version = "0.16.1" version = "0.16.1"
@ -761,6 +834,15 @@ version = "2.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cf51a729ecf40266a2368ad335a5fdde43471f545a967109cd62146ecf8b66ff" checksum = "cf51a729ecf40266a2368ad335a5fdde43471f545a967109cd62146ecf8b66ff"
[[package]]
name = "ntapi"
version = "0.3.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3f6bb902e437b6d86e03cce10a7e2af662292c5dfef23b65899ea3ac9354ad44"
dependencies = [
"winapi 0.3.9",
]
[[package]] [[package]]
name = "num-integer" name = "num-integer"
version = "0.1.43" version = "0.1.43"
@ -890,6 +972,12 @@ version = "0.1.11"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c917123afa01924fc84bb20c4c03f004d9c38e5127e3c039bbf7f4b9c76a2f6b" checksum = "c917123afa01924fc84bb20c4c03f004d9c38e5127e3c039bbf7f4b9c76a2f6b"
[[package]]
name = "pin-project-lite"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6b063f57ec186e6140e2b8b6921e5f1bd89c7356dda5b33acc5401203ca6131c"
[[package]] [[package]]
name = "pin-utils" name = "pin-utils"
version = "0.1.0" version = "0.1.0"
@ -1213,12 +1301,16 @@ name = "rustjail"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"async-trait",
"caps", "caps",
"cgroups-rs", "cgroups-rs",
"dirs", "dirs",
"epoll", "epoll",
"futures",
"inotify",
"lazy_static", "lazy_static",
"libc", "libc",
"mio 0.6.23",
"nix 0.17.0", "nix 0.17.0",
"oci", "oci",
"path-absolutize", "path-absolutize",
@ -1235,6 +1327,7 @@ dependencies = [
"slog", "slog",
"slog-scope", "slog-scope",
"tempfile", "tempfile",
"tokio 0.2.24",
] ]
[[package]] [[package]]
@ -1413,6 +1506,17 @@ version = "1.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fbee7696b84bbf3d89a1c2eccff0850e3047ed46bfcd2e92c29a2d074d57e252" checksum = "fbee7696b84bbf3d89a1c2eccff0850e3047ed46bfcd2e92c29a2d074d57e252"
[[package]]
name = "socket2"
version = "0.3.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "122e570113d28d773067fab24266b66753f6ea915758651696b6e35e49f88d6e"
dependencies = [
"cfg-if 1.0.0",
"libc",
"winapi 0.3.9",
]
[[package]] [[package]]
name = "spin" name = "spin"
version = "0.5.2" version = "0.5.2"
@ -1513,12 +1617,27 @@ dependencies = [
"lazy_static", "lazy_static",
"libc", "libc",
"memchr", "memchr",
"mio", "mio 0.6.23",
"mio-named-pipes",
"mio-uds", "mio-uds",
"num_cpus", "num_cpus",
"pin-project-lite", "pin-project-lite 0.1.11",
"signal-hook-registry",
"slab", "slab",
"tokio-macros", "tokio-macros",
"winapi 0.3.9",
]
[[package]]
name = "tokio"
version = "0.3.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "720ba21c25078711bf456d607987d95bce90f7c3bea5abe1db587862e7a1e87c"
dependencies = [
"autocfg",
"libc",
"mio 0.7.6",
"pin-project-lite 0.2.0",
] ]
[[package]] [[package]]
@ -1542,17 +1661,17 @@ dependencies = [
"futures", "futures",
"iovec", "iovec",
"libc", "libc",
"mio", "mio 0.6.23",
"nix 0.17.0", "nix 0.17.0",
"tokio", "tokio 0.2.24",
"vsock", "vsock",
] ]
[[package]] [[package]]
name = "ttrpc" name = "ttrpc"
version = "0.4.13" version = "0.4.14"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d6e99ffa09e7fbe514b58b01bd17d71e3ed4dd27c588afa43d41ec0b7fc90b0a" checksum = "fc512242eee1f113eadd48087dd97cbf807ccae4820006e7a890044044399c51"
dependencies = [ dependencies = [
"async-trait", "async-trait",
"byteorder", "byteorder",
@ -1563,7 +1682,7 @@ dependencies = [
"protobuf", "protobuf",
"protobuf-codegen-pure", "protobuf-codegen-pure",
"thiserror", "thiserror",
"tokio", "tokio 0.2.24",
"tokio-vsock", "tokio-vsock",
] ]

View File

@ -11,7 +11,7 @@ rustjail = { path = "rustjail" }
protocols = { path = "protocols" } protocols = { path = "protocols" }
netlink = { path = "netlink", features = ["with-log", "with-agent-handler"] } netlink = { path = "netlink", features = ["with-log", "with-agent-handler"] }
lazy_static = "1.3.0" lazy_static = "1.3.0"
ttrpc = { version="0.4.13", features=["async"] } ttrpc = { version = "0.4.14", features = ["async", "protobuf-codec"], default-features = false }
protobuf = "=2.14.0" protobuf = "=2.14.0"
libc = "0.2.58" libc = "0.2.58"
nix = "0.17.0" nix = "0.17.0"
@ -21,8 +21,12 @@ signal-hook = "0.1.9"
scan_fmt = "0.2.3" scan_fmt = "0.2.3"
scopeguard = "1.0.0" scopeguard = "1.0.0"
regex = "1" regex = "1"
tokio = { version="0.2", features = ["macros", "rt-threaded"] }
async-trait = "0.1.42" async-trait = "0.1.42"
tokio = { version = "0.2", features = ["rt-core", "sync", "uds", "stream", "macros", "io-util", "time", "signal", "io-std", "process",] }
futures = "0.3"
netlink-sys = { version = "0.4.0", features = ["tokio_socket",]}
tokio-vsock = "0.2.2"
# slog: # slog:
# - Dynamic keys required to allow HashMap keys to be slog::Serialized. # - Dynamic keys required to allow HashMap keys to be slog::Serialized.
@ -40,7 +44,7 @@ tempfile = "3.1.0"
prometheus = { version = "0.9.0", features = ["process"] } prometheus = { version = "0.9.0", features = ["process"] }
procfs = "0.7.9" procfs = "0.7.9"
anyhow = "1.0.32" anyhow = "1.0.32"
cgroups = { package = "cgroups-rs", version = "0.2.0" } cgroups = { package = "cgroups-rs", version = "0.2.1" }
[workspace] [workspace]
members = [ members = [

View File

@ -5,7 +5,7 @@ authors = ["The Kata Containers community <kata-dev@lists.katacontainers.io>"]
edition = "2018" edition = "2018"
[dependencies] [dependencies]
ttrpc = { version="0.4.13", features=["async"] } ttrpc = { version = "0.4.14", features = ["async"] }
async-trait = "0.1.42" async-trait = "0.1.42"
protobuf = "=2.14.0" protobuf = "=2.14.0"

View File

@ -16,7 +16,7 @@ scopeguard = "1.0.0"
prctl = "1.0.0" prctl = "1.0.0"
lazy_static = "1.3.0" lazy_static = "1.3.0"
libc = "0.2.58" libc = "0.2.58"
protobuf = "2.8.1" protobuf = "=2.14.0"
slog = "2.5.2" slog = "2.5.2"
slog-scope = "4.1.2" slog-scope = "4.1.2"
scan_fmt = "0.2" scan_fmt = "0.2"
@ -24,9 +24,15 @@ regex = "1.1"
path-absolutize = "1.2.0" path-absolutize = "1.2.0"
dirs = "3.0.1" dirs = "3.0.1"
anyhow = "1.0.32" anyhow = "1.0.32"
cgroups = { package = "cgroups-rs", version = "0.2.0" } cgroups = { package = "cgroups-rs", version = "0.2.1" }
tempfile = "3.1.0" tempfile = "3.1.0"
epoll = "4.3.1" epoll = "4.3.1"
tokio = { version = "0.2", features = ["sync", "io-util", "process", "time", "macros"] }
futures = "0.3"
async-trait = "0.1.31"
mio = "0.6"
inotify = "0.9"
[dev-dependencies] [dev-dependencies]
serial_test = "0.5.0" serial_test = "0.5.0"

View File

@ -3,16 +3,18 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
// //
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Context, Result};
use eventfd::{eventfd, EfdFlags}; use eventfd::{eventfd, EfdFlags};
use nix::sys::eventfd; use nix::sys::eventfd;
use nix::sys::inotify::{AddWatchFlags, InitFlags, Inotify};
use std::fs::{self, File}; use std::fs::{self, File};
use std::io::Read;
use std::os::unix::io::{AsRawFd, FromRawFd}; use std::os::unix::io::{AsRawFd, FromRawFd};
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
use std::sync::mpsc::{self, Receiver};
use std::thread; use crate::pipestream::PipeStream;
use futures::StreamExt as _;
use inotify::{Inotify, WatchMask};
use tokio::io::AsyncReadExt;
use tokio::sync::mpsc::{channel, Receiver};
// Convenience macro to obtain the scope logger // Convenience macro to obtain the scope logger
macro_rules! sl { macro_rules! sl {
@ -21,11 +23,11 @@ macro_rules! sl {
}; };
} }
pub fn notify_oom(cid: &str, cg_dir: String) -> Result<Receiver<String>> { pub async fn notify_oom(cid: &str, cg_dir: String) -> Result<Receiver<String>> {
if cgroups::hierarchies::is_cgroup2_unified_mode() { if cgroups::hierarchies::is_cgroup2_unified_mode() {
return notify_on_oom_v2(cid, cg_dir); return notify_on_oom_v2(cid, cg_dir).await;
} }
notify_on_oom(cid, cg_dir) notify_on_oom(cid, cg_dir).await
} }
// get_value_from_cgroup parse cgroup file with `Flat keyed` // get_value_from_cgroup parse cgroup file with `Flat keyed`
@ -52,11 +54,11 @@ fn get_value_from_cgroup(path: &PathBuf, key: &str) -> Result<i64> {
// notify_on_oom returns channel on which you can expect event about OOM, // notify_on_oom returns channel on which you can expect event about OOM,
// if process died without OOM this channel will be closed. // if process died without OOM this channel will be closed.
pub fn notify_on_oom_v2(containere_id: &str, cg_dir: String) -> Result<Receiver<String>> { pub async fn notify_on_oom_v2(containere_id: &str, cg_dir: String) -> Result<Receiver<String>> {
register_memory_event_v2(containere_id, cg_dir, "memory.events", "cgroup.events") register_memory_event_v2(containere_id, cg_dir, "memory.events", "cgroup.events").await
} }
fn register_memory_event_v2( async fn register_memory_event_v2(
containere_id: &str, containere_id: &str,
cg_dir: String, cg_dir: String,
memory_event_name: &str, memory_event_name: &str,
@ -73,54 +75,54 @@ fn register_memory_event_v2(
"register_memory_event_v2 cgroup_event_control_path: {:?}", &cgroup_event_control_path "register_memory_event_v2 cgroup_event_control_path: {:?}", &cgroup_event_control_path
); );
let fd = Inotify::init(InitFlags::empty()).unwrap(); let mut inotify = Inotify::init().context("Failed to initialize inotify")?;
// watching oom kill // watching oom kill
let ev_fd = fd let ev_wd = inotify.add_watch(&event_control_path, WatchMask::MODIFY)?;
.add_watch(&event_control_path, AddWatchFlags::IN_MODIFY)
.unwrap();
// Because no `unix.IN_DELETE|unix.IN_DELETE_SELF` event for cgroup file system, so watching all process exited // Because no `unix.IN_DELETE|unix.IN_DELETE_SELF` event for cgroup file system, so watching all process exited
let cg_fd = fd let cg_wd = inotify.add_watch(&cgroup_event_control_path, WatchMask::MODIFY)?;
.add_watch(&cgroup_event_control_path, AddWatchFlags::IN_MODIFY)
.unwrap();
info!(sl!(), "ev_fd: {:?}", ev_fd);
info!(sl!(), "cg_fd: {:?}", cg_fd);
let (sender, receiver) = mpsc::channel(); info!(sl!(), "ev_wd: {:?}", ev_wd);
info!(sl!(), "cg_wd: {:?}", cg_wd);
let (mut sender, receiver) = channel(100);
let containere_id = containere_id.to_string(); let containere_id = containere_id.to_string();
thread::spawn(move || { tokio::spawn(async move {
loop { let mut buffer = [0; 32];
let events = fd.read_events().unwrap(); let mut stream = inotify
.event_stream(&mut buffer)
.expect("create inotify event stream failed");
while let Some(event_or_error) = stream.next().await {
let event = event_or_error.unwrap();
info!( info!(
sl!(), sl!(),
"container[{}] get events for container: {:?}", &containere_id, &events "container[{}] get event for container: {:?}", &containere_id, &event
); );
// info!("is1: {}", event.wd == wd1);
info!(sl!(), "event.wd: {:?}", event.wd);
for event in events { if event.wd == ev_wd {
if event.mask & AddWatchFlags::IN_MODIFY != AddWatchFlags::IN_MODIFY { let oom = get_value_from_cgroup(&event_control_path, "oom_kill");
continue; if oom.unwrap_or(0) > 0 {
let _ = sender.send(containere_id.clone()).await.map_err(|e| {
error!(sl!(), "send containere_id failed, error: {:?}", e);
});
return;
} }
info!(sl!(), "event.wd: {:?}", event.wd); } else if event.wd == cg_wd {
let pids = get_value_from_cgroup(&cgroup_event_control_path, "populated");
if event.wd == ev_fd { if pids.unwrap_or(-1) == 0 {
let oom = get_value_from_cgroup(&event_control_path, "oom_kill"); return;
if oom.unwrap_or(0) > 0 {
sender.send(containere_id.clone()).unwrap();
return;
}
} else if event.wd == cg_fd {
let pids = get_value_from_cgroup(&cgroup_event_control_path, "populated");
if pids.unwrap_or(-1) == 0 {
return;
}
} }
} }
// When a cgroup is destroyed, an event is sent to eventfd. }
// So if the control path is gone, return instead of notifying.
if !Path::new(&event_control_path).exists() { // When a cgroup is destroyed, an event is sent to eventfd.
return; // So if the control path is gone, return instead of notifying.
} if !Path::new(&event_control_path).exists() {
return;
} }
}); });
@ -129,16 +131,16 @@ fn register_memory_event_v2(
// notify_on_oom returns channel on which you can expect event about OOM, // notify_on_oom returns channel on which you can expect event about OOM,
// if process died without OOM this channel will be closed. // if process died without OOM this channel will be closed.
fn notify_on_oom(cid: &str, dir: String) -> Result<Receiver<String>> { async fn notify_on_oom(cid: &str, dir: String) -> Result<Receiver<String>> {
if dir == "" { if dir == "" {
return Err(anyhow!("memory controller missing")); return Err(anyhow!("memory controller missing"));
} }
register_memory_event(cid, dir, "memory.oom_control", "") register_memory_event(cid, dir, "memory.oom_control", "").await
} }
// level is one of "low", "medium", or "critical" // level is one of "low", "medium", or "critical"
fn notify_memory_pressure(cid: &str, dir: String, level: &str) -> Result<Receiver<String>> { async fn notify_memory_pressure(cid: &str, dir: String, level: &str) -> Result<Receiver<String>> {
if dir == "" { if dir == "" {
return Err(anyhow!("memory controller missing")); return Err(anyhow!("memory controller missing"));
} }
@ -147,10 +149,10 @@ fn notify_memory_pressure(cid: &str, dir: String, level: &str) -> Result<Receive
return Err(anyhow!("invalid pressure level {}", level)); return Err(anyhow!("invalid pressure level {}", level));
} }
register_memory_event(cid, dir, "memory.pressure_level", level) register_memory_event(cid, dir, "memory.pressure_level", level).await
} }
fn register_memory_event( async fn register_memory_event(
cid: &str, cid: &str,
cg_dir: String, cg_dir: String,
event_name: &str, event_name: &str,
@ -171,15 +173,16 @@ fn register_memory_event(
fs::write(&event_control_path, data)?; fs::write(&event_control_path, data)?;
let mut eventfd_file = unsafe { File::from_raw_fd(eventfd) }; let mut eventfd_stream = unsafe { PipeStream::from_raw_fd(eventfd) };
let (sender, receiver) = mpsc::channel(); let (sender, receiver) = tokio::sync::mpsc::channel(100);
let containere_id = cid.to_string(); let containere_id = cid.to_string();
thread::spawn(move || { tokio::spawn(async move {
loop { loop {
let mut buf = [0; 8]; let mut sender = sender.clone();
match eventfd_file.read(&mut buf) { let mut buf = [0u8; 8];
match eventfd_stream.read(&mut buf).await {
Err(err) => { Err(err) => {
warn!(sl!(), "failed to read from eventfd: {:?}", err); warn!(sl!(), "failed to read from eventfd: {:?}", err);
return; return;
@ -198,7 +201,10 @@ fn register_memory_event(
if !Path::new(&event_control_path).exists() { if !Path::new(&event_control_path).exists() {
return; return;
} }
sender.send(containere_id.clone()).unwrap();
let _ = sender.send(containere_id.clone()).await.map_err(|e| {
error!(sl!(), "send containere_id failed, error: {:?}", e);
});
} }
}); });

View File

@ -14,7 +14,6 @@ use std::fmt::Display;
use std::fs; use std::fs;
use std::os::unix::io::RawFd; use std::os::unix::io::RawFd;
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
use std::process::Command;
use std::time::SystemTime; use std::time::SystemTime;
use cgroups::freezer::FreezerState; use cgroups::freezer::FreezerState;
@ -28,7 +27,6 @@ use crate::cgroups::Manager;
use crate::log_child; use crate::log_child;
use crate::process::Process; use crate::process::Process;
use crate::specconv::CreateOpts; use crate::specconv::CreateOpts;
use crate::sync::*;
use crate::{mount, validator}; use crate::{mount, validator};
use protocols::agent::StatsContainerResponse; use protocols::agent::StatsContainerResponse;
@ -49,12 +47,16 @@ use protobuf::SingularPtrField;
use oci::State as OCIState; use oci::State as OCIState;
use std::collections::HashMap; use std::collections::HashMap;
use std::io::BufRead;
use std::io::BufReader;
use std::os::unix::io::FromRawFd; use std::os::unix::io::FromRawFd;
use slog::{info, o, Logger}; use slog::{info, o, Logger};
use crate::pipestream::PipeStream;
use crate::sync::{read_sync, write_count, write_sync, SYNC_DATA, SYNC_FAILED, SYNC_SUCCESS};
use crate::sync_with_async::{read_async, write_async};
use async_trait::async_trait;
use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt};
const STATE_FILENAME: &str = "state.json"; const STATE_FILENAME: &str = "state.json";
const EXEC_FIFO_FILENAME: &str = "exec.fifo"; const EXEC_FIFO_FILENAME: &str = "exec.fifo";
const VER_MARKER: &str = "1.2.5"; const VER_MARKER: &str = "1.2.5";
@ -215,6 +217,7 @@ pub struct BaseState {
init_process_start: u64, init_process_start: u64,
} }
#[async_trait]
pub trait BaseContainer { pub trait BaseContainer {
fn id(&self) -> String; fn id(&self) -> String;
fn status(&self) -> Status; fn status(&self) -> Status;
@ -225,9 +228,9 @@ pub trait BaseContainer {
fn get_process(&mut self, eid: &str) -> Result<&mut Process>; fn get_process(&mut self, eid: &str) -> Result<&mut Process>;
fn stats(&self) -> Result<StatsContainerResponse>; fn stats(&self) -> Result<StatsContainerResponse>;
fn set(&mut self, config: LinuxResources) -> Result<()>; fn set(&mut self, config: LinuxResources) -> Result<()>;
fn start(&mut self, p: Process) -> Result<()>; async fn start(&mut self, p: Process) -> Result<()>;
fn run(&mut self, p: Process) -> Result<()>; async fn run(&mut self, p: Process) -> Result<()>;
fn destroy(&mut self) -> Result<()>; async fn destroy(&mut self) -> Result<()>;
fn signal(&self, sig: Signal, all: bool) -> Result<()>; fn signal(&self, sig: Signal, all: bool) -> Result<()>;
fn exec(&mut self) -> Result<()>; fn exec(&mut self) -> Result<()>;
} }
@ -273,6 +276,7 @@ pub struct SyncPC {
pid: pid_t, pid: pid_t,
} }
#[async_trait]
pub trait Container: BaseContainer { pub trait Container: BaseContainer {
fn pause(&mut self) -> Result<()>; fn pause(&mut self) -> Result<()>;
fn resume(&mut self) -> Result<()>; fn resume(&mut self) -> Result<()>;
@ -723,6 +727,7 @@ fn set_stdio_permissions(uid: libc::uid_t) -> Result<()> {
Ok(()) Ok(())
} }
#[async_trait]
impl BaseContainer for LinuxContainer { impl BaseContainer for LinuxContainer {
fn id(&self) -> String { fn id(&self) -> String {
self.id.clone() self.id.clone()
@ -816,7 +821,7 @@ impl BaseContainer for LinuxContainer {
Ok(()) Ok(())
} }
fn start(&mut self, mut p: Process) -> Result<()> { async fn start(&mut self, mut p: Process) -> Result<()> {
let logger = self.logger.new(o!("eid" => p.exec_id.clone())); let logger = self.logger.new(o!("eid" => p.exec_id.clone()));
let tty = p.tty; let tty = p.tty;
let fifo_file = format!("{}/{}", &self.root, EXEC_FIFO_FILENAME); let fifo_file = format!("{}/{}", &self.root, EXEC_FIFO_FILENAME);
@ -854,7 +859,7 @@ impl BaseContainer for LinuxContainer {
.map_err(|e| warn!(logger, "fcntl pfd log FD_CLOEXEC {:?}", e)); .map_err(|e| warn!(logger, "fcntl pfd log FD_CLOEXEC {:?}", e));
let child_logger = logger.new(o!("action" => "child process log")); let child_logger = logger.new(o!("action" => "child process log"));
let log_handler = setup_child_logger(pfd_log, child_logger)?; let log_handler = setup_child_logger(pfd_log, child_logger);
let (prfd, cwfd) = unistd::pipe().context("failed to create pipe")?; let (prfd, cwfd) = unistd::pipe().context("failed to create pipe")?;
let (crfd, pwfd) = unistd::pipe().context("failed to create pipe")?; let (crfd, pwfd) = unistd::pipe().context("failed to create pipe")?;
@ -865,10 +870,8 @@ impl BaseContainer for LinuxContainer {
let _ = fcntl::fcntl(pwfd, FcntlArg::F_SETFD(FdFlag::FD_CLOEXEC)) let _ = fcntl::fcntl(pwfd, FcntlArg::F_SETFD(FdFlag::FD_CLOEXEC))
.map_err(|e| warn!(logger, "fcntl pwfd FD_COLEXEC {:?}", e)); .map_err(|e| warn!(logger, "fcntl pwfd FD_COLEXEC {:?}", e));
defer!({ let mut pipe_r = PipeStream::from_fd(prfd);
let _ = unistd::close(prfd).map_err(|e| warn!(logger, "close prfd {:?}", e)); let mut pipe_w = PipeStream::from_fd(pwfd);
let _ = unistd::close(pwfd).map_err(|e| warn!(logger, "close pwfd {:?}", e));
});
let child_stdin: std::process::Stdio; let child_stdin: std::process::Stdio;
let child_stdout: std::process::Stdio; let child_stdout: std::process::Stdio;
@ -928,7 +931,7 @@ impl BaseContainer for LinuxContainer {
unistd::close(cfd_log)?; unistd::close(cfd_log)?;
// get container process's pid // get container process's pid
let pid_buf = read_sync(prfd)?; let pid_buf = read_async(&mut pipe_r).await?;
let pid_str = std::str::from_utf8(&pid_buf).context("get pid string")?; let pid_str = std::str::from_utf8(&pid_buf).context("get pid string")?;
let pid = match pid_str.parse::<i32>() { let pid = match pid_str.parse::<i32>() {
Ok(i) => i, Ok(i) => i,
@ -958,9 +961,10 @@ impl BaseContainer for LinuxContainer {
&p, &p,
self.cgroup_manager.as_ref().unwrap(), self.cgroup_manager.as_ref().unwrap(),
&st, &st,
pwfd, &mut pipe_w,
prfd, &mut pipe_r,
) )
.await
.map_err(|e| { .map_err(|e| {
error!(logger, "create container process error {:?}", e); error!(logger, "create container process error {:?}", e);
// kill the child process. // kill the child process.
@ -995,15 +999,15 @@ impl BaseContainer for LinuxContainer {
info!(logger, "wait on child log handler"); info!(logger, "wait on child log handler");
let _ = log_handler let _ = log_handler
.join() .await
.map_err(|e| warn!(logger, "joining log handler {:?}", e)); .map_err(|e| warn!(logger, "joining log handler {:?}", e));
info!(logger, "create process completed"); info!(logger, "create process completed");
Ok(()) Ok(())
} }
fn run(&mut self, p: Process) -> Result<()> { async fn run(&mut self, p: Process) -> Result<()> {
let init = p.init; let init = p.init;
self.start(p)?; self.start(p).await?;
if init { if init {
self.exec()?; self.exec()?;
@ -1013,7 +1017,7 @@ impl BaseContainer for LinuxContainer {
Ok(()) Ok(())
} }
fn destroy(&mut self) -> Result<()> { async fn destroy(&mut self) -> Result<()> {
let spec = self.config.spec.as_ref().unwrap(); let spec = self.config.spec.as_ref().unwrap();
let st = self.oci_state()?; let st = self.oci_state()?;
@ -1025,7 +1029,7 @@ impl BaseContainer for LinuxContainer {
info!(self.logger, "poststop"); info!(self.logger, "poststop");
let hooks = spec.hooks.as_ref().unwrap(); let hooks = spec.hooks.as_ref().unwrap();
for h in hooks.poststop.iter() { for h in hooks.poststop.iter() {
execute_hook(&self.logger, h, &st)?; execute_hook(&self.logger, h, &st).await?;
} }
} }
@ -1177,42 +1181,38 @@ fn get_namespaces(linux: &Linux) -> Vec<LinuxNamespace> {
.collect() .collect()
} }
pub fn setup_child_logger(fd: RawFd, child_logger: Logger) -> Result<std::thread::JoinHandle<()>> { pub fn setup_child_logger(fd: RawFd, child_logger: Logger) -> tokio::task::JoinHandle<()> {
let builder = thread::Builder::new(); tokio::spawn(async move {
builder let log_file_stream = PipeStream::from_fd(fd);
.spawn(move || { let buf_reader_stream = tokio::io::BufReader::new(log_file_stream);
let log_file = unsafe { std::fs::File::from_raw_fd(fd) }; let mut lines = buf_reader_stream.lines();
let mut reader = BufReader::new(log_file);
loop { loop {
let mut line = String::new(); match lines.next_line().await {
match reader.read_line(&mut line) { Err(e) => {
Err(e) => { info!(child_logger, "read child process log error: {:?}", e);
info!(child_logger, "read child process log error: {:?}", e); break;
break; }
} Ok(Some(line)) => {
Ok(count) => { info!(child_logger, "{}", line);
if count == 0 { }
info!(child_logger, "read child process log end",); Ok(None) => {
break; info!(child_logger, "read child process log end",);
} break;
info!(child_logger, "{}", line);
}
} }
} }
}) }
.map_err(|e| anyhow!(e).context("failed to create thread")) })
} }
fn join_namespaces( async fn join_namespaces(
logger: &Logger, logger: &Logger,
spec: &Spec, spec: &Spec,
p: &Process, p: &Process,
cm: &FsManager, cm: &FsManager,
st: &OCIState, st: &OCIState,
pwfd: RawFd, pipe_w: &mut PipeStream,
prfd: RawFd, pipe_r: &mut PipeStream,
) -> Result<()> { ) -> Result<()> {
let logger = logger.new(o!("action" => "join-namespaces")); let logger = logger.new(o!("action" => "join-namespaces"));
@ -1223,25 +1223,25 @@ fn join_namespaces(
info!(logger, "try to send spec from parent to child"); info!(logger, "try to send spec from parent to child");
let spec_str = serde_json::to_string(spec)?; let spec_str = serde_json::to_string(spec)?;
write_sync(pwfd, SYNC_DATA, spec_str.as_str())?; write_async(pipe_w, SYNC_DATA, spec_str.as_str()).await?;
info!(logger, "wait child received oci spec"); info!(logger, "wait child received oci spec");
read_sync(prfd)?; read_async(pipe_r).await?;
info!(logger, "send oci process from parent to child"); info!(logger, "send oci process from parent to child");
let process_str = serde_json::to_string(&p.oci)?; let process_str = serde_json::to_string(&p.oci)?;
write_sync(pwfd, SYNC_DATA, process_str.as_str())?; write_async(pipe_w, SYNC_DATA, process_str.as_str()).await?;
info!(logger, "wait child received oci process"); info!(logger, "wait child received oci process");
read_sync(prfd)?; read_async(pipe_r).await?;
let cm_str = serde_json::to_string(cm)?; let cm_str = serde_json::to_string(cm)?;
write_sync(pwfd, SYNC_DATA, cm_str.as_str())?; write_async(pipe_w, SYNC_DATA, cm_str.as_str()).await?;
// wait child setup user namespace // wait child setup user namespace
info!(logger, "wait child setup user namespace"); info!(logger, "wait child setup user namespace");
read_sync(prfd)?; read_async(pipe_r).await?;
if userns { if userns {
info!(logger, "setup uid/gid mappings"); info!(logger, "setup uid/gid mappings");
@ -1270,11 +1270,11 @@ fn join_namespaces(
info!(logger, "notify child to continue"); info!(logger, "notify child to continue");
// notify child to continue // notify child to continue
write_sync(pwfd, SYNC_SUCCESS, "")?; write_async(pipe_w, SYNC_SUCCESS, "").await?;
if p.init { if p.init {
info!(logger, "notify child parent ready to run prestart hook!"); info!(logger, "notify child parent ready to run prestart hook!");
let _ = read_sync(prfd)?; let _ = read_async(pipe_r).await?;
info!(logger, "get ready to run prestart hook!"); info!(logger, "get ready to run prestart hook!");
@ -1283,17 +1283,17 @@ fn join_namespaces(
info!(logger, "prestart hook"); info!(logger, "prestart hook");
let hooks = spec.hooks.as_ref().unwrap(); let hooks = spec.hooks.as_ref().unwrap();
for h in hooks.prestart.iter() { for h in hooks.prestart.iter() {
execute_hook(&logger, h, st)?; execute_hook(&logger, h, st).await?;
} }
} }
// notify child run prestart hooks completed // notify child run prestart hooks completed
info!(logger, "notify child run prestart hook completed!"); info!(logger, "notify child run prestart hook completed!");
write_sync(pwfd, SYNC_SUCCESS, "")?; write_async(pipe_w, SYNC_SUCCESS, "").await?;
info!(logger, "notify child parent ready to run poststart hook!"); info!(logger, "notify child parent ready to run poststart hook!");
// wait to run poststart hook // wait to run poststart hook
read_sync(prfd)?; read_async(pipe_r).await?;
info!(logger, "get ready to run poststart hook!"); info!(logger, "get ready to run poststart hook!");
// run poststart hook // run poststart hook
@ -1301,13 +1301,13 @@ fn join_namespaces(
info!(logger, "poststart hook"); info!(logger, "poststart hook");
let hooks = spec.hooks.as_ref().unwrap(); let hooks = spec.hooks.as_ref().unwrap();
for h in hooks.poststart.iter() { for h in hooks.poststart.iter() {
execute_hook(&logger, h, st)?; execute_hook(&logger, h, st).await?;
} }
} }
} }
info!(logger, "wait for child process ready to run exec"); info!(logger, "wait for child process ready to run exec");
read_sync(prfd)?; read_async(pipe_r).await?;
Ok(()) Ok(())
} }
@ -1509,14 +1509,11 @@ fn set_sysctls(sysctls: &HashMap<String, String>) -> Result<()> {
Ok(()) Ok(())
} }
use std::io::Read;
use std::os::unix::process::ExitStatusExt; use std::os::unix::process::ExitStatusExt;
use std::process::Stdio; use std::process::Stdio;
use std::sync::mpsc::{self, RecvTimeoutError};
use std::thread;
use std::time::Duration; use std::time::Duration;
fn execute_hook(logger: &Logger, h: &Hook, st: &OCIState) -> Result<()> { async fn execute_hook(logger: &Logger, h: &Hook, st: &OCIState) -> Result<()> {
let logger = logger.new(o!("action" => "execute-hook")); let logger = logger.new(o!("action" => "execute-hook"));
let binary = PathBuf::from(h.path.as_str()); let binary = PathBuf::from(h.path.as_str());
@ -1535,9 +1532,12 @@ fn execute_hook(logger: &Logger, h: &Hook, st: &OCIState) -> Result<()> {
let _ = unistd::close(wfd); let _ = unistd::close(wfd);
}); });
let mut pipe_r = PipeStream::from_fd(rfd);
let mut pipe_w = PipeStream::from_fd(wfd);
match unistd::fork()? { match unistd::fork()? {
ForkResult::Parent { child } => { ForkResult::Parent { child } => {
let buf = read_sync(rfd)?; let buf = read_async(&mut pipe_r).await?;
let status = if buf.len() == 4 { let status = if buf.len() == 4 {
let buf_array: [u8; 4] = [buf[0], buf[1], buf[2], buf[3]]; let buf_array: [u8; 4] = [buf[0], buf[1], buf[2], buf[3]];
i32::from_be_bytes(buf_array) i32::from_be_bytes(buf_array)
@ -1561,13 +1561,13 @@ fn execute_hook(logger: &Logger, h: &Hook, st: &OCIState) -> Result<()> {
} }
ForkResult::Child => { ForkResult::Child => {
let (tx, rx) = mpsc::channel(); let (mut tx, mut rx) = tokio::sync::mpsc::channel(100);
let (tx_logger, rx_logger) = mpsc::channel(); let (tx_logger, rx_logger) = tokio::sync::oneshot::channel();
tx_logger.send(logger.clone()).unwrap(); tx_logger.send(logger.clone()).unwrap();
let handle = thread::spawn(move || { let handle = tokio::spawn(async move {
let logger = rx_logger.recv().unwrap(); let logger = rx_logger.await.unwrap();
// write oci state to child // write oci state to child
let env: HashMap<String, String> = envs let env: HashMap<String, String> = envs
@ -1578,7 +1578,7 @@ fn execute_hook(logger: &Logger, h: &Hook, st: &OCIState) -> Result<()> {
}) })
.collect(); .collect();
let mut child = Command::new(path.to_str().unwrap()) let mut child = tokio::process::Command::new(path.to_str().unwrap())
.args(args.iter()) .args(args.iter())
.envs(env.iter()) .envs(env.iter())
.stdin(Stdio::piped()) .stdin(Stdio::piped())
@ -1588,7 +1588,7 @@ fn execute_hook(logger: &Logger, h: &Hook, st: &OCIState) -> Result<()> {
.unwrap(); .unwrap();
// send out our pid // send out our pid
tx.send(child.id() as libc::pid_t).unwrap(); tx.send(child.id() as libc::pid_t).await.unwrap();
info!(logger, "hook grand: {}", child.id()); info!(logger, "hook grand: {}", child.id());
child child
@ -1596,6 +1596,7 @@ fn execute_hook(logger: &Logger, h: &Hook, st: &OCIState) -> Result<()> {
.as_mut() .as_mut()
.unwrap() .unwrap()
.write_all(state.as_bytes()) .write_all(state.as_bytes())
.await
.unwrap(); .unwrap();
// read something from stdout for debug // read something from stdout for debug
@ -1605,9 +1606,10 @@ fn execute_hook(logger: &Logger, h: &Hook, st: &OCIState) -> Result<()> {
.as_mut() .as_mut()
.unwrap() .unwrap()
.read_to_string(&mut out) .read_to_string(&mut out)
.await
.unwrap(); .unwrap();
info!(logger, "child stdout: {}", out.as_str()); info!(logger, "child stdout: {}", out.as_str());
match child.wait() { match child.await {
Ok(exit) => { Ok(exit) => {
let code: i32 = if exit.success() { let code: i32 = if exit.success() {
0 0
@ -1618,7 +1620,7 @@ fn execute_hook(logger: &Logger, h: &Hook, st: &OCIState) -> Result<()> {
} }
}; };
tx.send(code).unwrap(); tx.send(code).await.unwrap();
} }
Err(e) => { Err(e) => {
@ -1638,29 +1640,33 @@ fn execute_hook(logger: &Logger, h: &Hook, st: &OCIState) -> Result<()> {
// -- FIXME // -- FIXME
// just in case. Should not happen any more // just in case. Should not happen any more
tx.send(0).unwrap(); tx.send(0).await.unwrap();
} }
} }
}); });
let pid = rx.recv().unwrap(); let pid = rx.recv().await.unwrap();
info!(logger, "hook grand: {}", pid); info!(logger, "hook grand: {}", pid);
let status = { let status = {
if let Some(timeout) = h.timeout { if let Some(timeout) = h.timeout {
match rx.recv_timeout(Duration::from_secs(timeout as u64)) { let timeout = tokio::time::delay_for(Duration::from_secs(timeout as u64));
Ok(s) => s, tokio::select! {
Err(e) => { v = rx.recv() => {
let error = if e == RecvTimeoutError::Timeout { match v {
-libc::ETIMEDOUT Some(s) => s,
} else { None => {
-libc::EPIPE let _ = signal::kill(Pid::from_raw(pid), Some(Signal::SIGKILL));
}; -libc::EPIPE
}
}
}
_ = timeout => {
let _ = signal::kill(Pid::from_raw(pid), Some(Signal::SIGKILL)); let _ = signal::kill(Pid::from_raw(pid), Some(Signal::SIGKILL));
error -libc::ETIMEDOUT
} }
} }
} else if let Ok(s) = rx.recv() { } else if let Some(s) = rx.recv().await {
s s
} else { } else {
let _ = signal::kill(Pid::from_raw(pid), Some(Signal::SIGKILL)); let _ = signal::kill(Pid::from_raw(pid), Some(Signal::SIGKILL));
@ -1668,12 +1674,13 @@ fn execute_hook(logger: &Logger, h: &Hook, st: &OCIState) -> Result<()> {
} }
}; };
handle.join().unwrap(); handle.await.unwrap();
let _ = write_sync( let _ = write_async(
wfd, &mut pipe_w,
SYNC_DATA, SYNC_DATA,
std::str::from_utf8(&status.to_be_bytes()).unwrap_or_default(), std::str::from_utf8(&status.to_be_bytes()).unwrap_or_default(),
); )
.await;
std::process::exit(0); std::process::exit(0);
} }
} }

View File

@ -40,10 +40,12 @@ pub mod capabilities;
pub mod cgroups; pub mod cgroups;
pub mod container; pub mod container;
pub mod mount; pub mod mount;
pub mod pipestream;
pub mod process; pub mod process;
pub mod reaper; pub mod reaper;
pub mod specconv; pub mod specconv;
pub mod sync; pub mod sync;
pub mod sync_with_async;
pub mod validator; pub mod validator;
// pub mod factory; // pub mod factory;

View File

@ -0,0 +1,159 @@
// Copyright (c) 2020 Ant Group
//
// SPDX-License-Identifier: Apache-2.0
//
//! Async support for pipe or something has file descriptor
use std::{
fmt, io, mem,
os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd},
pin::Pin,
task::{Context, Poll},
};
use mio::event::Evented;
use mio::unix::EventedFd;
use mio::{Poll as MioPoll, PollOpt, Ready, Token};
use nix::unistd;
use tokio::io::{AsyncRead, AsyncWrite, PollEvented};
/// Switch `fd` into non-blocking mode so it can be driven by the tokio
/// reactor.
///
/// # Safety
///
/// `fd` must be a valid, open file descriptor owned by the caller.
unsafe fn set_nonblocking(fd: RawFd) {
    // F_SETFL replaces the *entire* status-flag set, so read the current
    // flags first and OR in O_NONBLOCK instead of clobbering flags such as
    // O_APPEND that may already be set on the descriptor.
    let flags = libc::fcntl(fd, libc::F_GETFL);
    if flags >= 0 {
        libc::fcntl(fd, libc::F_SETFL, flags | libc::O_NONBLOCK);
    }
}
struct StreamFd(RawFd);
/// mio event-source integration: every operation is forwarded to an
/// `EventedFd` built on the wrapped descriptor, so the kernel fd itself is
/// what gets registered with the poller.
impl Evented for StreamFd {
    fn register(
        &self,
        poll: &MioPoll,
        token: Token,
        interest: Ready,
        opts: PollOpt,
    ) -> io::Result<()> {
        let source = EventedFd(&self.0);
        source.register(poll, token, interest, opts)
    }

    fn reregister(
        &self,
        poll: &MioPoll,
        token: Token,
        interest: Ready,
        opts: PollOpt,
    ) -> io::Result<()> {
        let source = EventedFd(&self.0);
        source.reregister(poll, token, interest, opts)
    }

    fn deregister(&self, poll: &MioPoll) -> io::Result<()> {
        let source = EventedFd(&self.0);
        source.deregister(poll)
    }
}
impl io::Read for StreamFd {
    /// Non-blocking read via nix; the nix error is translated back into an
    /// `io::Error` so `PollEvented` can recognize `EWOULDBLOCK`.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        unistd::read(self.0, buf).map_err(|e| e.as_errno().unwrap().into())
    }
}
impl io::Write for StreamFd {
    /// Non-blocking write via nix, mapping nix errors back to `io::Error`.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        unistd::write(self.0, buf).map_err(|e| e.as_errno().unwrap().into())
    }

    /// Pipes have no userspace buffering, so flushing is a no-op.
    fn flush(&mut self) -> io::Result<()> {
        Ok(())
    }
}
impl StreamFd {
    /// Close the underlying descriptor, translating the nix error (if any)
    /// into an `io::Error`.
    fn close(&mut self) -> io::Result<()> {
        unistd::close(self.0).map_err(|e| e.as_errno().unwrap().into())
    }
}
impl Drop for StreamFd {
    // Best-effort close on drop; a failure here has nowhere to be reported.
    fn drop(&mut self) {
        let _ = self.close();
    }
}
/// Pipe read
pub struct PipeStream(PollEvented<StreamFd>);
impl PipeStream {
pub fn shutdown(&mut self) -> io::Result<()> {
self.0.get_mut().close()
}
pub fn from_fd(fd: RawFd) -> Self {
unsafe { Self::from_raw_fd(fd) }
}
}
impl AsyncRead for PipeStream {
    // PipeStream is Unpin, so the outer pin can be unwrapped and a fresh
    // pin created around the inner PollEvented, which does the real work.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        let inner = &mut self.get_mut().0;
        Pin::new(inner).poll_read(cx, buf)
    }
}
impl AsRawFd for PipeStream {
    /// Expose the wrapped descriptor without giving up ownership.
    fn as_raw_fd(&self) -> RawFd {
        let inner = self.0.get_ref();
        inner.0
    }
}
impl IntoRawFd for PipeStream {
fn into_raw_fd(self) -> RawFd {
let fd = self.0.get_ref().0;
mem::forget(self);
fd
}
}
impl FromRawFd for PipeStream {
    unsafe fn from_raw_fd(fd: RawFd) -> Self {
        // The reactor requires non-blocking descriptors.
        set_nonblocking(fd);
        let evented = PollEvented::new(StreamFd(fd)).unwrap();
        PipeStream(evented)
    }
}
impl fmt::Debug for PipeStream {
    // Renders as `PipeStream(<fd>)`, matching the tuple-struct shape.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_tuple("PipeStream").field(&self.as_raw_fd()).finish()
    }
}
impl AsyncWrite for PipeStream {
    // All three operations delegate to the inner PollEvented after
    // unwrapping the (Unpin) outer pin and re-pinning the inner stream.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, io::Error>> {
        Pin::new(&mut self.get_mut().0).poll_write(cx, buf)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        Pin::new(&mut self.get_mut().0).poll_flush(cx)
    }

    fn poll_shutdown(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), io::Error>> {
        Pin::new(&mut self.get_mut().0).poll_shutdown(cx)
    }
}

View File

@ -6,7 +6,7 @@
use libc::pid_t; use libc::pid_t;
use std::fs::File; use std::fs::File;
use std::os::unix::io::RawFd; use std::os::unix::io::RawFd;
use std::sync::mpsc::Sender; use tokio::sync::mpsc::Sender;
use nix::fcntl::{fcntl, FcntlArg, OFlag}; use nix::fcntl::{fcntl, FcntlArg, OFlag};
use nix::sys::signal::{self, Signal}; use nix::sys::signal::{self, Signal};
@ -18,6 +18,27 @@ use crate::reaper::Epoller;
use oci::Process as OCIProcess; use oci::Process as OCIProcess;
use slog::Logger; use slog::Logger;
use crate::pipestream::PipeStream;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::io::{split, ReadHalf, WriteHalf};
use tokio::sync::Mutex;
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
pub enum StreamType {
Stdin,
Stdout,
Stderr,
ExitPipeR,
TermMaster,
ParentStdin,
ParentStdout,
ParentStderr,
}
type Reader = Arc<Mutex<ReadHalf<PipeStream>>>;
type Writer = Arc<Mutex<WriteHalf<PipeStream>>>;
#[derive(Debug)] #[derive(Debug)]
pub struct Process { pub struct Process {
pub exec_id: String, pub exec_id: String,
@ -42,6 +63,9 @@ pub struct Process {
pub oci: OCIProcess, pub oci: OCIProcess,
pub logger: Logger, pub logger: Logger,
pub epoller: Option<Epoller>, pub epoller: Option<Epoller>,
readers: HashMap<StreamType, Reader>,
writers: HashMap<StreamType, Writer>,
} }
pub trait ProcessOperations { pub trait ProcessOperations {
@ -94,6 +118,8 @@ impl Process {
oci: ocip.clone(), oci: ocip.clone(),
logger: logger.clone(), logger: logger.clone(),
epoller: None, epoller: None,
readers: HashMap::new(),
writers: HashMap::new(),
}; };
info!(logger, "before create console socket!"); info!(logger, "before create console socket!");
@ -138,8 +164,59 @@ impl Process {
} }
Ok(()) Ok(())
} }
fn get_fd(&self, stream_type: &StreamType) -> Option<RawFd> {
match stream_type {
StreamType::Stdin => self.stdin,
StreamType::Stdout => self.stdout,
StreamType::Stderr => self.stderr,
StreamType::ExitPipeR => self.exit_pipe_r,
StreamType::TermMaster => self.term_master,
StreamType::ParentStdin => self.parent_stdin,
StreamType::ParentStdout => self.parent_stdout,
StreamType::ParentStderr => self.parent_stderr,
}
}
fn get_stream_and_store(&mut self, stream_type: StreamType) -> Option<(Reader, Writer)> {
let fd = self.get_fd(&stream_type)?;
let stream = PipeStream::from_fd(fd);
let (reader, writer) = split(stream);
let reader = Arc::new(Mutex::new(reader));
let writer = Arc::new(Mutex::new(writer));
self.readers.insert(stream_type.clone(), reader.clone());
self.writers.insert(stream_type, writer.clone());
Some((reader, writer))
}
pub fn get_reader(&mut self, stream_type: StreamType) -> Option<Reader> {
if let Some(reader) = self.readers.get(&stream_type) {
return Some(reader.clone());
}
let (reader, _) = self.get_stream_and_store(stream_type)?;
Some(reader)
}
pub fn get_writer(&mut self, stream_type: StreamType) -> Option<Writer> {
if let Some(writer) = self.writers.get(&stream_type) {
return Some(writer.clone());
}
let (_, writer) = self.get_stream_and_store(stream_type)?;
Some(writer)
}
pub fn close_stream(&mut self, stream_type: StreamType) {
let _ = self.readers.remove(&stream_type);
let _ = self.writers.remove(&stream_type);
}
} }
fn create_extended_pipe(flags: OFlag, pipe_size: i32) -> Result<(RawFd, RawFd)> { fn create_extended_pipe(flags: OFlag, pipe_size: i32) -> Result<(RawFd, RawFd)> {
let (r, w) = unistd::pipe2(flags)?; let (r, w) = unistd::pipe2(flags)?;
if pipe_size > 0 { if pipe_size > 0 {

View File

@ -14,8 +14,8 @@ pub const SYNC_SUCCESS: i32 = 1;
pub const SYNC_FAILED: i32 = 2; pub const SYNC_FAILED: i32 = 2;
pub const SYNC_DATA: i32 = 3; pub const SYNC_DATA: i32 = 3;
const DATA_SIZE: usize = 100; pub const DATA_SIZE: usize = 100;
const MSG_SIZE: usize = mem::size_of::<i32>(); pub const MSG_SIZE: usize = mem::size_of::<i32>();
#[macro_export] #[macro_export]
macro_rules! log_child { macro_rules! log_child {

View File

@ -0,0 +1,148 @@
// Copyright (c) 2020 Ant Group
//
// SPDX-License-Identifier: Apache-2.0
//
//! The async version of sync module used for IPC
use crate::pipestream::PipeStream;
use anyhow::{anyhow, Result};
use nix::errno::Errno;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use crate::sync::{DATA_SIZE, MSG_SIZE, SYNC_DATA, SYNC_FAILED, SYNC_SUCCESS};
/// Write exactly `count` bytes of `buf` to `pipe_w`, retrying on EINTR.
///
/// Returns the number of bytes actually written; a short count (only
/// possible when the peer stops accepting data) is left for the caller to
/// detect and turn into an error.
async fn write_count(pipe_w: &mut PipeStream, buf: &[u8], count: usize) -> Result<usize> {
    let mut len = 0;
    loop {
        match pipe_w.write(&buf[len..]).await {
            // A zero-length write means no further progress is possible;
            // break instead of spinning forever (the original looped).
            Ok(0) => break,
            Ok(l) => {
                len += l;
                if len == count {
                    break;
                }
            }
            Err(e) => {
                // Retry only on EINTR.  Compare against Some(..) so an error
                // without an OS errno is propagated rather than panicking on
                // raw_os_error().unwrap().
                if e.raw_os_error() != Some(Errno::EINTR as i32) {
                    return Err(e.into());
                }
            }
        }
    }

    Ok(len)
}
/// Read up to `count` bytes from `pipe_r`, retrying on EINTR.
///
/// Stops early on EOF, so the returned vector may be shorter than `count`;
/// callers are expected to check the length themselves.
async fn read_count(pipe_r: &mut PipeStream, count: usize) -> Result<Vec<u8>> {
    let mut v: Vec<u8> = vec![0; count];
    let mut len = 0;

    loop {
        match pipe_r.read(&mut v[len..]).await {
            Ok(l) => {
                len += l;
                // Done when the buffer is full or the peer closed (EOF).
                if len == count || l == 0 {
                    break;
                }
            }
            Err(e) => {
                // Retry only on EINTR.  Compare against Some(..) so an error
                // without an OS errno is propagated rather than panicking on
                // raw_os_error().unwrap().
                if e.raw_os_error() != Some(Errno::EINTR as i32) {
                    return Err(e.into());
                }
            }
        }
    }

    Ok(v[0..len].to_vec())
}
/// Receive one framed message from the peer.
///
/// Wire format: a big-endian `i32` message type, then
/// * `SYNC_SUCCESS` - no payload; returns an empty vector,
/// * `SYNC_DATA`    - a big-endian `i32` length followed by that many
///                    payload bytes, which are returned,
/// * `SYNC_FAILED`  - a UTF-8 error string read in `DATA_SIZE` chunks until
///                    a short chunk, returned as `Err`.
pub async fn read_async(pipe_r: &mut PipeStream) -> Result<Vec<u8>> {
    let buf = read_count(pipe_r, MSG_SIZE).await?;
    if buf.len() != MSG_SIZE {
        return Err(anyhow!(
            "process: {} failed to receive async message from peer: got msg length: {}, expected: {}",
            std::process::id(),
            buf.len(),
            MSG_SIZE
        ));
    }

    let buf_array: [u8; MSG_SIZE] = [buf[0], buf[1], buf[2], buf[3]];
    let msg: i32 = i32::from_be_bytes(buf_array);
    match msg {
        SYNC_SUCCESS => Ok(Vec::new()),
        SYNC_DATA => {
            let buf = read_count(pipe_r, MSG_SIZE).await?;
            // Guard against a short read (EOF) before indexing; without this
            // check, buf[3] below would panic when the peer closes early.
            if buf.len() != MSG_SIZE {
                return Err(anyhow!(
                    "process: {} failed to receive data length from peer: got msg length: {}, expected: {}",
                    std::process::id(),
                    buf.len(),
                    MSG_SIZE
                ));
            }
            let buf_array: [u8; MSG_SIZE] = [buf[0], buf[1], buf[2], buf[3]];
            let msg_length: i32 = i32::from_be_bytes(buf_array);
            let data_buf = read_count(pipe_r, msg_length as usize).await?;

            Ok(data_buf)
        }
        SYNC_FAILED => {
            let mut error_buf = vec![];
            loop {
                let buf = read_count(pipe_r, DATA_SIZE).await?;
                error_buf.extend(&buf);
                // A full chunk means more data may follow; a short chunk
                // marks the end of the error string.
                if DATA_SIZE == buf.len() {
                    continue;
                } else {
                    break;
                }
            }

            let error_str = match std::str::from_utf8(&error_buf) {
                Ok(v) => String::from(v),
                Err(e) => {
                    return Err(
                        anyhow!(e).context("receive error message from child process failed")
                    );
                }
            };

            Err(anyhow!(error_str))
        }
        _ => Err(anyhow!("error in receive sync message")),
    }
}
/// Send one framed message to the peer (the inverse of `read_async`).
///
/// Writes the big-endian `i32` `msg_type` header first, then a payload that
/// depends on the type:
/// * `SYNC_FAILED` - `data_str` is the error text; the pipe is shut down
///   afterwards (success or failure) so the reader sees EOF,
/// * `SYNC_DATA`   - a big-endian `i32` length followed by `data_str`; the
///   pipe is shut down only when a write fails,
/// * anything else - header only, no payload.
pub async fn write_async(pipe_w: &mut PipeStream, msg_type: i32, data_str: &str) -> Result<()> {
    let buf = msg_type.to_be_bytes();
    // A short write of the header means the peer is gone.
    let count = write_count(pipe_w, &buf, MSG_SIZE).await?;
    if count != MSG_SIZE {
        return Err(anyhow!("error in send sync message"));
    }

    match msg_type {
        SYNC_FAILED => match write_count(pipe_w, data_str.as_bytes(), data_str.len()).await {
            // Shut down in both arms: the reader detects end-of-error-text
            // by the short read that the shutdown produces.
            Ok(_) => pipe_w.shutdown()?,
            Err(e) => {
                pipe_w.shutdown()?;
                return Err(anyhow!(e).context("error in send message to process"));
            }
        },
        SYNC_DATA => {
            // NOTE(review): assumes data_str.len() fits in i32 — TODO confirm
            // callers never exceed that.
            let length: i32 = data_str.len() as i32;
            write_count(pipe_w, &length.to_be_bytes(), MSG_SIZE)
                .await
                .or_else(|e| {
                    // Close the pipe on failure so the peer doesn't block
                    // waiting for a payload that will never arrive.
                    pipe_w.shutdown()?;
                    Err(anyhow!(e).context("error in send message to process"))
                })?;

            write_count(pipe_w, data_str.as_bytes(), data_str.len())
                .await
                .or_else(|e| {
                    pipe_w.shutdown()?;
                    Err(anyhow!(e).context("error in send message to process"))
                })?;
        }
        _ => (),
    };

    Ok(())
}

View File

@ -9,7 +9,8 @@ use std::collections::HashMap;
use std::fs; use std::fs;
use std::os::unix::fs::MetadataExt; use std::os::unix::fs::MetadataExt;
use std::path::Path; use std::path::Path;
use std::sync::{mpsc, Arc, Mutex}; use std::sync::Arc;
use tokio::sync::Mutex;
use crate::linux_abi::*; use crate::linux_abi::*;
use crate::mount::{DRIVERBLKTYPE, DRIVERMMIOBLKTYPE, DRIVERNVDIMMTYPE, DRIVERSCSITYPE}; use crate::mount::{DRIVERBLKTYPE, DRIVERMMIOBLKTYPE, DRIVERNVDIMMTYPE, DRIVERSCSITYPE};
@ -35,22 +36,6 @@ struct DevIndexEntry {
struct DevIndex(HashMap<String, DevIndexEntry>); struct DevIndex(HashMap<String, DevIndexEntry>);
// DeviceHandler is the type of callback to be defined to handle every type of device driver.
type DeviceHandler = fn(&Device, &mut Spec, &Arc<Mutex<Sandbox>>, &DevIndex) -> Result<()>;
// DEVICEHANDLERLIST lists the supported drivers.
#[rustfmt::skip]
lazy_static! {
static ref DEVICEHANDLERLIST: HashMap<&'static str, DeviceHandler> = {
let mut m: HashMap<&'static str, DeviceHandler> = HashMap::new();
m.insert(DRIVERBLKTYPE, virtio_blk_device_handler);
m.insert(DRIVERMMIOBLKTYPE, virtiommio_blk_device_handler);
m.insert(DRIVERNVDIMMTYPE, virtio_nvdimm_device_handler);
m.insert(DRIVERSCSITYPE, virtio_scsi_device_handler);
m
};
}
pub fn rescan_pci_bus() -> Result<()> { pub fn rescan_pci_bus() -> Result<()> {
online_device(SYSFS_PCI_BUS_RESCAN_FILE) online_device(SYSFS_PCI_BUS_RESCAN_FILE)
} }
@ -114,10 +99,10 @@ fn get_pci_device_address(pci_id: &str) -> Result<String> {
Ok(bridge_device_pci_addr) Ok(bridge_device_pci_addr)
} }
fn get_device_name(sandbox: &Arc<Mutex<Sandbox>>, dev_addr: &str) -> Result<String> { async fn get_device_name(sandbox: &Arc<Mutex<Sandbox>>, dev_addr: &str) -> Result<String> {
// Keep the same lock order as uevent::handle_block_add_event(), otherwise it may cause deadlock. // Keep the same lock order as uevent::handle_block_add_event(), otherwise it may cause deadlock.
let mut w = GLOBAL_DEVICE_WATCHER.lock().unwrap(); let mut w = GLOBAL_DEVICE_WATCHER.lock().await;
let sb = sandbox.lock().unwrap(); let sb = sandbox.lock().await;
for (key, value) in sb.pci_device_map.iter() { for (key, value) in sb.pci_device_map.iter() {
if key.contains(dev_addr) { if key.contains(dev_addr) {
info!(sl!(), "Device {} found in pci device map", dev_addr); info!(sl!(), "Device {} found in pci device map", dev_addr);
@ -131,36 +116,50 @@ fn get_device_name(sandbox: &Arc<Mutex<Sandbox>>, dev_addr: &str) -> Result<Stri
// The key of the watchers map is the device we are interested in. // The key of the watchers map is the device we are interested in.
// Note this is done inside the lock, not to miss any events from the // Note this is done inside the lock, not to miss any events from the
// global udev listener. // global udev listener.
let (tx, rx) = mpsc::channel::<String>(); let (tx, rx) = tokio::sync::oneshot::channel::<String>();
w.insert(dev_addr.to_string(), tx); w.insert(dev_addr.to_string(), Some(tx));
drop(w); drop(w);
info!(sl!(), "Waiting on channel for device notification\n"); info!(sl!(), "Waiting on channel for device notification\n");
let hotplug_timeout = AGENT_CONFIG.read().unwrap().hotplug_timeout; let hotplug_timeout = AGENT_CONFIG.read().await.hotplug_timeout;
let dev_name = rx.recv_timeout(hotplug_timeout).map_err(|_| { let timeout = tokio::time::delay_for(hotplug_timeout);
GLOBAL_DEVICE_WATCHER.lock().unwrap().remove_entry(dev_addr);
anyhow!( let dev_name;
"Timeout reached after {:?} waiting for device {}", tokio::select! {
hotplug_timeout, v = rx => {
dev_addr dev_name = v?;
) }
})?; _ = timeout => {
let watcher = GLOBAL_DEVICE_WATCHER.clone();
let mut w = watcher.lock().await;
w.remove_entry(dev_addr);
return Err(anyhow!(
"Timeout reached after {:?} waiting for device {}",
hotplug_timeout,
dev_addr
));
}
};
Ok(format!("{}/{}", SYSTEM_DEV_PATH, &dev_name)) Ok(format!("{}/{}", SYSTEM_DEV_PATH, &dev_name))
} }
pub fn get_scsi_device_name(sandbox: &Arc<Mutex<Sandbox>>, scsi_addr: &str) -> Result<String> { pub async fn get_scsi_device_name(
sandbox: &Arc<Mutex<Sandbox>>,
scsi_addr: &str,
) -> Result<String> {
let dev_sub_path = format!("{}{}/{}", SCSI_HOST_CHANNEL, scsi_addr, SCSI_BLOCK_SUFFIX); let dev_sub_path = format!("{}{}/{}", SCSI_HOST_CHANNEL, scsi_addr, SCSI_BLOCK_SUFFIX);
scan_scsi_bus(scsi_addr)?; scan_scsi_bus(scsi_addr)?;
get_device_name(sandbox, &dev_sub_path) get_device_name(sandbox, &dev_sub_path).await
} }
pub fn get_pci_device_name(sandbox: &Arc<Mutex<Sandbox>>, pci_id: &str) -> Result<String> { pub async fn get_pci_device_name(sandbox: &Arc<Mutex<Sandbox>>, pci_id: &str) -> Result<String> {
let pci_addr = get_pci_device_address(pci_id)?; let pci_addr = get_pci_device_address(pci_id)?;
rescan_pci_bus()?; rescan_pci_bus()?;
get_device_name(sandbox, &pci_addr) get_device_name(sandbox, &pci_addr).await
} }
/// Scan SCSI bus for the given SCSI address(SCSI-Id and LUN) /// Scan SCSI bus for the given SCSI address(SCSI-Id and LUN)
@ -274,7 +273,7 @@ fn update_spec_device_list(device: &Device, spec: &mut Spec, devidx: &DevIndex)
// device.Id should be the predicted device name (vda, vdb, ...) // device.Id should be the predicted device name (vda, vdb, ...)
// device.VmPath already provides a way to send it in // device.VmPath already provides a way to send it in
fn virtiommio_blk_device_handler( async fn virtiommio_blk_device_handler(
device: &Device, device: &Device,
spec: &mut Spec, spec: &mut Spec,
_sandbox: &Arc<Mutex<Sandbox>>, _sandbox: &Arc<Mutex<Sandbox>>,
@ -290,7 +289,7 @@ fn virtiommio_blk_device_handler(
// device.Id should be the PCI address in the format "bridgeAddr/deviceAddr". // device.Id should be the PCI address in the format "bridgeAddr/deviceAddr".
// Here, bridgeAddr is the address at which the brige is attached on the root bus, // Here, bridgeAddr is the address at which the brige is attached on the root bus,
// while deviceAddr is the address at which the device is attached on the bridge. // while deviceAddr is the address at which the device is attached on the bridge.
fn virtio_blk_device_handler( async fn virtio_blk_device_handler(
device: &Device, device: &Device,
spec: &mut Spec, spec: &mut Spec,
sandbox: &Arc<Mutex<Sandbox>>, sandbox: &Arc<Mutex<Sandbox>>,
@ -301,25 +300,25 @@ fn virtio_blk_device_handler(
// When "Id (PCIAddr)" is not set, we allow to use the predicted "VmPath" passed from kata-runtime // When "Id (PCIAddr)" is not set, we allow to use the predicted "VmPath" passed from kata-runtime
// Note this is a special code path for cloud-hypervisor when BDF information is not available // Note this is a special code path for cloud-hypervisor when BDF information is not available
if device.id != "" { if device.id != "" {
dev.vm_path = get_pci_device_name(sandbox, &device.id)?; dev.vm_path = get_pci_device_name(sandbox, &device.id).await?;
} }
update_spec_device_list(&dev, spec, devidx) update_spec_device_list(&dev, spec, devidx)
} }
// device.Id should be the SCSI address of the disk in the format "scsiID:lunID" // device.Id should be the SCSI address of the disk in the format "scsiID:lunID"
fn virtio_scsi_device_handler( async fn virtio_scsi_device_handler(
device: &Device, device: &Device,
spec: &mut Spec, spec: &mut Spec,
sandbox: &Arc<Mutex<Sandbox>>, sandbox: &Arc<Mutex<Sandbox>>,
devidx: &DevIndex, devidx: &DevIndex,
) -> Result<()> { ) -> Result<()> {
let mut dev = device.clone(); let mut dev = device.clone();
dev.vm_path = get_scsi_device_name(sandbox, &device.id)?; dev.vm_path = get_scsi_device_name(sandbox, &device.id).await?;
update_spec_device_list(&dev, spec, devidx) update_spec_device_list(&dev, spec, devidx)
} }
fn virtio_nvdimm_device_handler( async fn virtio_nvdimm_device_handler(
device: &Device, device: &Device,
spec: &mut Spec, spec: &mut Spec,
_sandbox: &Arc<Mutex<Sandbox>>, _sandbox: &Arc<Mutex<Sandbox>>,
@ -357,7 +356,7 @@ impl DevIndex {
} }
} }
pub fn add_devices( pub async fn add_devices(
devices: &[Device], devices: &[Device],
spec: &mut Spec, spec: &mut Spec,
sandbox: &Arc<Mutex<Sandbox>>, sandbox: &Arc<Mutex<Sandbox>>,
@ -365,13 +364,13 @@ pub fn add_devices(
let devidx = DevIndex::new(spec); let devidx = DevIndex::new(spec);
for device in devices.iter() { for device in devices.iter() {
add_device(device, spec, sandbox, &devidx)?; add_device(device, spec, sandbox, &devidx).await?;
} }
Ok(()) Ok(())
} }
fn add_device( async fn add_device(
device: &Device, device: &Device,
spec: &mut Spec, spec: &mut Spec,
sandbox: &Arc<Mutex<Sandbox>>, sandbox: &Arc<Mutex<Sandbox>>,
@ -393,9 +392,12 @@ fn add_device(
return Err(anyhow!("invalid container path for device {:?}", device)); return Err(anyhow!("invalid container path for device {:?}", device));
} }
match DEVICEHANDLERLIST.get(device.field_type.as_str()) { match device.field_type.as_str() {
None => Err(anyhow!("Unknown device type {}", device.field_type)), DRIVERBLKTYPE => virtio_blk_device_handler(device, spec, sandbox, devidx).await,
Some(dev_handler) => dev_handler(device, spec, sandbox, devidx), DRIVERMMIOBLKTYPE => virtiommio_blk_device_handler(device, spec, sandbox, devidx).await,
DRIVERNVDIMMTYPE => virtio_nvdimm_device_handler(device, spec, sandbox, devidx).await,
DRIVERSCSITYPE => virtio_scsi_device_handler(device, spec, sandbox, devidx).await,
_ => Err(anyhow!("Unknown device type {}", device.field_type)),
} }
} }

View File

@ -15,7 +15,6 @@ extern crate prctl;
extern crate prometheus; extern crate prometheus;
extern crate protocols; extern crate protocols;
extern crate regex; extern crate regex;
extern crate rustjail;
extern crate scan_fmt; extern crate scan_fmt;
extern crate serde_json; extern crate serde_json;
extern crate signal_hook; extern crate signal_hook;
@ -38,7 +37,6 @@ use nix::sys::socket::{self, AddressFamily, SockAddr, SockFlag, SockType};
use nix::sys::wait::{self, WaitStatus}; use nix::sys::wait::{self, WaitStatus};
use nix::unistd::{self, close, dup, dup2, fork, setsid, ForkResult}; use nix::unistd::{self, close, dup, dup2, fork, setsid, ForkResult};
use prctl::set_child_subreaper; use prctl::set_child_subreaper;
use signal_hook::{iterator::Signals, SIGCHLD};
use std::collections::HashMap; use std::collections::HashMap;
use std::env; use std::env;
use std::ffi::{CStr, CString, OsStr}; use std::ffi::{CStr, CString, OsStr};
@ -48,9 +46,7 @@ use std::os::unix::ffi::OsStrExt;
use std::os::unix::fs as unixfs; use std::os::unix::fs as unixfs;
use std::os::unix::io::AsRawFd; use std::os::unix::io::AsRawFd;
use std::path::Path; use std::path::Path;
use std::sync::mpsc::{self, Sender}; use std::sync::Arc;
use std::sync::{Arc, Mutex, RwLock};
use std::{io, thread, thread::JoinHandle};
use unistd::Pid; use unistd::Pid;
mod config; mod config;
@ -72,6 +68,16 @@ use sandbox::Sandbox;
use slog::Logger; use slog::Logger;
use uevent::watch_uevents; use uevent::watch_uevents;
use std::sync::Mutex as SyncMutex;
use futures::StreamExt as _;
use rustjail::pipestream::PipeStream;
use tokio::{
signal::unix::{signal, SignalKind},
sync::{oneshot::Sender, Mutex, RwLock},
};
use tokio_vsock::{Incoming, VsockListener, VsockStream};
mod rpc; mod rpc;
const NAME: &str = "kata-agent"; const NAME: &str = "kata-agent";
@ -81,7 +87,7 @@ const CONSOLE_PATH: &str = "/dev/console";
const DEFAULT_BUF_SIZE: usize = 8 * 1024; const DEFAULT_BUF_SIZE: usize = 8 * 1024;
lazy_static! { lazy_static! {
static ref GLOBAL_DEVICE_WATCHER: Arc<Mutex<HashMap<String, Sender<String>>>> = static ref GLOBAL_DEVICE_WATCHER: Arc<Mutex<HashMap<String, Option<Sender<String>>>>> =
Arc::new(Mutex::new(HashMap::new())); Arc::new(Mutex::new(HashMap::new()));
static ref AGENT_CONFIG: Arc<RwLock<agentConfig>> = static ref AGENT_CONFIG: Arc<RwLock<agentConfig>> =
Arc::new(RwLock::new(config::agentConfig::new())); Arc::new(RwLock::new(config::agentConfig::new()));
@ -100,8 +106,28 @@ fn announce(logger: &Logger, config: &agentConfig) {
); );
} }
#[tokio::main] fn set_fd_close_exec(fd: RawFd) -> Result<RawFd> {
async fn main() -> Result<()> { if let Err(e) = fcntl::fcntl(fd, FcntlArg::F_SETFD(FdFlag::FD_CLOEXEC)) {
return Err(anyhow!("failed to set fd: {} as close-on-exec: {}", fd, e));
}
Ok(fd)
}
fn get_vsock_incoming(fd: RawFd) -> Incoming {
let incoming;
unsafe {
incoming = VsockListener::from_raw_fd(fd).incoming();
}
incoming
}
async fn get_vsock_stream(fd: RawFd) -> Result<VsockStream> {
let stream = get_vsock_incoming(fd).next().await.unwrap().unwrap();
set_fd_close_exec(stream.as_raw_fd())?;
Ok(stream)
}
fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
let args: Vec<String> = env::args().collect(); let args: Vec<String> = env::args().collect();
if args.len() == 2 && args[1] == "--version" { if args.len() == 2 && args[1] == "--version" {
@ -121,107 +147,122 @@ async fn main() -> Result<()> {
exit(0); exit(0);
} }
env::set_var("RUST_BACKTRACE", "full"); let mut rt = tokio::runtime::Builder::new()
.basic_scheduler()
.max_threads(1)
.enable_all()
.build()?;
lazy_static::initialize(&SHELLS); rt.block_on(async {
env::set_var("RUST_BACKTRACE", "full");
lazy_static::initialize(&AGENT_CONFIG); lazy_static::initialize(&SHELLS);
// support vsock log lazy_static::initialize(&AGENT_CONFIG);
let (rfd, wfd) = unistd::pipe2(OFlag::O_CLOEXEC)?;
let agentConfig = AGENT_CONFIG.clone(); // support vsock log
let (rfd, wfd) = unistd::pipe2(OFlag::O_CLOEXEC)?;
let init_mode = unistd::getpid() == Pid::from_raw(1); let agentConfig = AGENT_CONFIG.clone();
if init_mode {
// dup a new file descriptor for this temporary logger writer,
// since this logger would be dropped and it's writer would
// be closed out of this code block.
let newwfd = dup(wfd)?;
let writer = unsafe { File::from_raw_fd(newwfd) };
// Init a temporary logger used by init agent as init process let init_mode = unistd::getpid() == Pid::from_raw(1);
// since before do the base mount, it wouldn't access "/proc/cmdline" if init_mode {
// to get the customzied debug level. // dup a new file descriptor for this temporary logger writer,
let logger = logging::create_logger(NAME, "agent", slog::Level::Debug, writer); // since this logger would be dropped and it's writer would
// be closed out of this code block.
let newwfd = dup(wfd)?;
let writer = unsafe { File::from_raw_fd(newwfd) };
// Must mount proc fs before parsing kernel command line // Init a temporary logger used by init agent as init process
general_mount(&logger).map_err(|e| { // since before do the base mount, it wouldn't access "/proc/cmdline"
error!(logger, "fail general mount: {}", e); // to get the customzied debug level.
e let logger = logging::create_logger(NAME, "agent", slog::Level::Debug, writer);
})?;
let mut config = agentConfig.write().unwrap(); // Must mount proc fs before parsing kernel command line
config.parse_cmdline(KERNEL_CMDLINE_FILE)?; general_mount(&logger).map_err(|e| {
error!(logger, "fail general mount: {}", e);
e
})?;
init_agent_as_init(&logger, config.unified_cgroup_hierarchy)?; let mut config = agentConfig.write().await;
} else { config.parse_cmdline(KERNEL_CMDLINE_FILE)?;
// once parsed cmdline and set the config, release the write lock
// as soon as possible in case other thread would get read lock on
// it.
let mut config = agentConfig.write().unwrap();
config.parse_cmdline(KERNEL_CMDLINE_FILE)?;
}
let config = agentConfig.read().unwrap();
let log_vport = config.log_vport as u32; init_agent_as_init(&logger, config.unified_cgroup_hierarchy)?;
let log_handle = thread::spawn(move || -> Result<()> { } else {
let mut reader = unsafe { File::from_raw_fd(rfd) }; // once parsed cmdline and set the config, release the write lock
if log_vport > 0 { // as soon as possible in case other thread would get read lock on
let listenfd = socket::socket( // it.
AddressFamily::Vsock, let mut config = agentConfig.write().await;
SockType::Stream, config.parse_cmdline(KERNEL_CMDLINE_FILE)?;
SockFlag::SOCK_CLOEXEC,
None,
)?;
let addr = SockAddr::new_vsock(libc::VMADDR_CID_ANY, log_vport);
socket::bind(listenfd, &addr)?;
socket::listen(listenfd, 1)?;
let datafd = socket::accept4(listenfd, SockFlag::SOCK_CLOEXEC)?;
let mut log_writer = unsafe { File::from_raw_fd(datafd) };
let _ = io::copy(&mut reader, &mut log_writer)?;
let _ = unistd::close(listenfd);
let _ = unistd::close(datafd);
} }
// copy log to stdout let config = agentConfig.read().await;
let mut stdout_writer = io::stdout();
let _ = io::copy(&mut reader, &mut stdout_writer)?; let log_vport = config.log_vport as u32;
let log_handle = tokio::spawn(async move {
let mut reader = PipeStream::from_fd(rfd);
if log_vport > 0 {
let listenfd = socket::socket(
AddressFamily::Vsock,
SockType::Stream,
SockFlag::SOCK_CLOEXEC,
None,
)
.unwrap();
let addr = SockAddr::new_vsock(libc::VMADDR_CID_ANY, log_vport);
socket::bind(listenfd, &addr).unwrap();
socket::listen(listenfd, 1).unwrap();
let mut vsock_stream = get_vsock_stream(listenfd).await.unwrap();
// copy log to stdout
tokio::io::copy(&mut reader, &mut vsock_stream)
.await
.unwrap();
}
// copy log to stdout
let mut stdout_writer = tokio::io::stdout();
let _ = tokio::io::copy(&mut reader, &mut stdout_writer).await;
});
let writer = unsafe { File::from_raw_fd(wfd) };
// Recreate a logger with the log level get from "/proc/cmdline".
let logger = logging::create_logger(NAME, "agent", config.log_level, writer);
announce(&logger, &config);
// This "unused" variable is required as it enables the global (and crucially static) logger,
// which is required to satisfy the the lifetime constraints of the auto-generated gRPC code.
let _guard = slog_scope::set_global_logger(logger.new(o!("subsystem" => "rpc")));
let mut _log_guard: Result<(), log::SetLoggerError> = Ok(());
if config.log_level == slog::Level::Trace {
// Redirect ttrpc log calls to slog iff full debug requested
_log_guard = Ok(slog_stdlog::init().map_err(|e| e)?);
}
start_sandbox(&logger, &config, init_mode).await?;
let _ = log_handle.await.unwrap();
Ok(()) Ok(())
}); })
let writer = unsafe { File::from_raw_fd(wfd) };
// Recreate a logger with the log level get from "/proc/cmdline".
let logger = logging::create_logger(NAME, "agent", config.log_level, writer);
announce(&logger, &config);
// This "unused" variable is required as it enables the global (and crucially static) logger,
// which is required to satisfy the the lifetime constraints of the auto-generated gRPC code.
let _guard = slog_scope::set_global_logger(logger.new(o!("subsystem" => "rpc")));
let mut _log_guard: Result<(), log::SetLoggerError> = Ok(());
if config.log_level == slog::Level::Trace {
// Redirect ttrpc log calls to slog iff full debug requested
_log_guard = Ok(slog_stdlog::init().map_err(|e| e)?);
}
start_sandbox(&logger, &config, init_mode).await?;
let _ = log_handle.join();
Ok(())
} }
async fn start_sandbox(logger: &Logger, config: &agentConfig, init_mode: bool) -> Result<()> { async fn start_sandbox(logger: &Logger, config: &agentConfig, init_mode: bool) -> Result<()> {
let shells = SHELLS.clone(); let shells = SHELLS.clone();
let debug_console_vport = config.debug_console_vport as u32; let debug_console_vport = config.debug_console_vport as u32;
let mut shell_handle: Option<JoinHandle<()>> = None; // TODO: async the debug console
let mut _shell_handle: Option<std::thread::JoinHandle<()>> = None;
if config.debug_console { if config.debug_console {
let thread_logger = logger.clone(); let thread_logger = logger.clone();
let builder = thread::Builder::new(); let builder = std::thread::Builder::new();
let handle = builder.spawn(move || { let handle = builder.spawn(move || {
let shells = shells.lock().unwrap(); let shells = shells.lock().unwrap();
@ -233,7 +274,7 @@ async fn start_sandbox(logger: &Logger, config: &agentConfig, init_mode: bool) -
} }
})?; })?;
shell_handle = Some(handle); _shell_handle = Some(handle);
} }
// Initialize unique sandbox structure. // Initialize unique sandbox structure.
@ -248,41 +289,38 @@ async fn start_sandbox(logger: &Logger, config: &agentConfig, init_mode: bool) -
let sandbox = Arc::new(Mutex::new(s)); let sandbox = Arc::new(Mutex::new(s));
setup_signal_handler(&logger, sandbox.clone()).unwrap(); setup_signal_handler(&logger, sandbox.clone())
watch_uevents(sandbox.clone()); .await
.unwrap();
watch_uevents(sandbox.clone()).await;
let (tx, rx) = mpsc::channel::<i32>(); let (tx, rx) = tokio::sync::oneshot::channel();
sandbox.lock().unwrap().sender = Some(tx); sandbox.lock().await.sender = Some(tx);
// vsock:///dev/vsock, port // vsock:///dev/vsock, port
let mut server = rpc::start(sandbox, config.server_addr.as_str()).await?; let mut server = rpc::start(sandbox.clone(), config.server_addr.as_str());
server.start().await?; server.start().await?;
let _ = rx.recv()?; let _ = rx.await?;
server.shutdown().await?; server.shutdown().await?;
if let Some(handle) = shell_handle {
handle.join().map_err(|e| anyhow!("{:?}", e))?;
}
Ok(()) Ok(())
} }
use nix::sys::wait::WaitPidFlag; use nix::sys::wait::WaitPidFlag;
fn setup_signal_handler(logger: &Logger, sandbox: Arc<Mutex<Sandbox>>) -> Result<()> { async fn setup_signal_handler(logger: &Logger, sandbox: Arc<Mutex<Sandbox>>) -> Result<()> {
let logger = logger.new(o!("subsystem" => "signals")); let logger = logger.new(o!("subsystem" => "signals"));
set_child_subreaper(true) set_child_subreaper(true)
.map_err(|err| anyhow!(err).context("failed to setup agent as a child subreaper"))?; .map_err(|err| anyhow!(err).context("failed to setup agent as a child subreaper"))?;
let signals = Signals::new(&[SIGCHLD])?; let mut signal_stream = signal(SignalKind::child())?;
thread::spawn(move || { tokio::spawn(async move {
'outer: for sig in signals.forever() { 'outer: loop {
info!(logger, "received signal"; "signal" => sig); signal_stream.recv().await;
info!(logger, "received signal"; "signal" => "SIGCHLD");
// sevral signals can be combined together // sevral signals can be combined together
// as one. So loop around to reap all // as one. So loop around to reap all
@ -316,7 +354,7 @@ fn setup_signal_handler(logger: &Logger, sandbox: Arc<Mutex<Sandbox>>) -> Result
let logger = logger.new(o!("child-pid" => child_pid)); let logger = logger.new(o!("child-pid" => child_pid));
let mut sandbox = sandbox.lock().unwrap(); let mut sandbox = sandbox.lock().await;
let process = sandbox.find_process(raw_pid); let process = sandbox.find_process(raw_pid);
if process.is_none() { if process.is_none() {
info!(logger, "child exited unexpectedly"); info!(logger, "child exited unexpectedly");
@ -375,7 +413,7 @@ fn init_agent_as_init(logger: &Logger, unified_cgroup_hierarchy: bool) -> Result
unistd::setsid()?; unistd::setsid()?;
unsafe { unsafe {
libc::ioctl(io::stdin().as_raw_fd(), libc::TIOCSCTTY, 1); libc::ioctl(std::io::stdin().as_raw_fd(), libc::TIOCSCTTY, 1);
} }
env::set_var("PATH", "/bin:/sbin/:/usr/bin/:/usr/sbin/"); env::set_var("PATH", "/bin:/sbin/:/usr/bin/:/usr/sbin/");
@ -406,7 +444,7 @@ fn sethostname(hostname: &OsStr) -> Result<()> {
} }
lazy_static! { lazy_static! {
static ref SHELLS: Arc<Mutex<Vec<String>>> = { static ref SHELLS: Arc<SyncMutex<Vec<String>>> = {
let mut v = Vec::new(); let mut v = Vec::new();
if !cfg!(test) { if !cfg!(test) {
@ -414,7 +452,7 @@ lazy_static! {
v.push("/bin/sh".to_string()); v.push("/bin/sh".to_string());
} }
Arc::new(Mutex::new(v)) Arc::new(SyncMutex::new(v))
}; };
} }
@ -480,7 +518,7 @@ fn setup_debug_console(logger: &Logger, shells: Vec<String>, port: u32) -> Resul
}; };
} }
fn io_copy<R: ?Sized, W: ?Sized>(reader: &mut R, writer: &mut W) -> io::Result<u64> fn io_copy<R: ?Sized, W: ?Sized>(reader: &mut R, writer: &mut W) -> std::io::Result<u64>
where where
R: Read, R: Read,
W: Write, W: Write,
@ -543,10 +581,10 @@ fn run_debug_console_shell(logger: &Logger, shell: &str, socket_fd: RawFd) -> Re
let debug_shell_logger = logger.clone(); let debug_shell_logger = logger.clone();
// channel that used to sync between thread and main process // channel that used to sync between thread and main process
let (tx, rx) = mpsc::channel::<i32>(); let (tx, rx) = std::sync::mpsc::channel::<i32>();
// start a thread to do IO copy between socket and pseduo.master // start a thread to do IO copy between socket and pseduo.master
thread::spawn(move || { std::thread::spawn(move || {
let mut master_reader = unsafe { File::from_raw_fd(master_fd) }; let mut master_reader = unsafe { File::from_raw_fd(master_fd) };
let mut master_writer = unsafe { File::from_raw_fd(master_fd) }; let mut master_writer = unsafe { File::from_raw_fd(master_fd) };
let mut socket_reader = unsafe { File::from_raw_fd(socket_fd) }; let mut socket_reader = unsafe { File::from_raw_fd(socket_fd) };

View File

@ -12,7 +12,8 @@ use std::os::unix::fs::PermissionsExt;
use std::path::Path; use std::path::Path;
use std::ptr::null; use std::ptr::null;
use std::sync::{Arc, Mutex}; use std::sync::Arc;
use tokio::sync::Mutex;
use libc::{c_void, mount}; use libc::{c_void, mount};
use nix::mount::{self, MsFlags}; use nix::mount::{self, MsFlags};
@ -121,32 +122,15 @@ lazy_static! {
]; ];
} }
// StorageHandler is the type of callback to be defined to handle every pub const STORAGE_HANDLER_LIST: [&str; 7] = [
// type of storage driver. DRIVERBLKTYPE,
type StorageHandler = fn(&Logger, &Storage, Arc<Mutex<Sandbox>>) -> Result<String>; DRIVER9PTYPE,
DRIVERVIRTIOFSTYPE,
// STORAGEHANDLERLIST lists the supported drivers. DRIVEREPHEMERALTYPE,
#[rustfmt::skip] DRIVERMMIOBLKTYPE,
lazy_static! { DRIVERLOCALTYPE,
pub static ref STORAGEHANDLERLIST: HashMap<&'static str, StorageHandler> = { DRIVERSCSITYPE,
let mut m = HashMap::new(); ];
let blk: StorageHandler = virtio_blk_storage_handler;
m.insert(DRIVERBLKTYPE, blk);
let p9: StorageHandler= virtio9p_storage_handler;
m.insert(DRIVER9PTYPE, p9);
let virtiofs: StorageHandler = virtiofs_storage_handler;
m.insert(DRIVERVIRTIOFSTYPE, virtiofs);
let ephemeral: StorageHandler = ephemeral_storage_handler;
m.insert(DRIVEREPHEMERALTYPE, ephemeral);
let virtiommio: StorageHandler = virtiommio_blk_storage_handler;
m.insert(DRIVERMMIOBLKTYPE, virtiommio);
let local: StorageHandler = local_storage_handler;
m.insert(DRIVERLOCALTYPE, local);
let scsi: StorageHandler = virtio_scsi_storage_handler;
m.insert(DRIVERSCSITYPE, scsi);
m
};
}
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct BareMount<'a> { pub struct BareMount<'a> {
@ -238,12 +222,12 @@ impl<'a> BareMount<'a> {
} }
} }
fn ephemeral_storage_handler( async fn ephemeral_storage_handler(
logger: &Logger, logger: &Logger,
storage: &Storage, storage: &Storage,
sandbox: Arc<Mutex<Sandbox>>, sandbox: Arc<Mutex<Sandbox>>,
) -> Result<String> { ) -> Result<String> {
let mut sb = sandbox.lock().unwrap(); let mut sb = sandbox.lock().await;
let new_storage = sb.set_sandbox_storage(&storage.mount_point); let new_storage = sb.set_sandbox_storage(&storage.mount_point);
if !new_storage { if !new_storage {
@ -256,12 +240,12 @@ fn ephemeral_storage_handler(
Ok("".to_string()) Ok("".to_string())
} }
fn local_storage_handler( async fn local_storage_handler(
_logger: &Logger, _logger: &Logger,
storage: &Storage, storage: &Storage,
sandbox: Arc<Mutex<Sandbox>>, sandbox: Arc<Mutex<Sandbox>>,
) -> Result<String> { ) -> Result<String> {
let mut sb = sandbox.lock().unwrap(); let mut sb = sandbox.lock().await;
let new_storage = sb.set_sandbox_storage(&storage.mount_point); let new_storage = sb.set_sandbox_storage(&storage.mount_point);
if !new_storage { if !new_storage {
@ -289,7 +273,7 @@ fn local_storage_handler(
Ok("".to_string()) Ok("".to_string())
} }
fn virtio9p_storage_handler( async fn virtio9p_storage_handler(
logger: &Logger, logger: &Logger,
storage: &Storage, storage: &Storage,
_sandbox: Arc<Mutex<Sandbox>>, _sandbox: Arc<Mutex<Sandbox>>,
@ -298,7 +282,7 @@ fn virtio9p_storage_handler(
} }
// virtiommio_blk_storage_handler handles the storage for mmio blk driver. // virtiommio_blk_storage_handler handles the storage for mmio blk driver.
fn virtiommio_blk_storage_handler( async fn virtiommio_blk_storage_handler(
logger: &Logger, logger: &Logger,
storage: &Storage, storage: &Storage,
_sandbox: Arc<Mutex<Sandbox>>, _sandbox: Arc<Mutex<Sandbox>>,
@ -308,7 +292,7 @@ fn virtiommio_blk_storage_handler(
} }
// virtiofs_storage_handler handles the storage for virtio-fs. // virtiofs_storage_handler handles the storage for virtio-fs.
fn virtiofs_storage_handler( async fn virtiofs_storage_handler(
logger: &Logger, logger: &Logger,
storage: &Storage, storage: &Storage,
_sandbox: Arc<Mutex<Sandbox>>, _sandbox: Arc<Mutex<Sandbox>>,
@ -317,7 +301,7 @@ fn virtiofs_storage_handler(
} }
// virtio_blk_storage_handler handles the storage for blk driver. // virtio_blk_storage_handler handles the storage for blk driver.
fn virtio_blk_storage_handler( async fn virtio_blk_storage_handler(
logger: &Logger, logger: &Logger,
storage: &Storage, storage: &Storage,
sandbox: Arc<Mutex<Sandbox>>, sandbox: Arc<Mutex<Sandbox>>,
@ -334,7 +318,7 @@ fn virtio_blk_storage_handler(
return Err(anyhow!("Invalid device {}", &storage.source)); return Err(anyhow!("Invalid device {}", &storage.source));
} }
} else { } else {
let dev_path = get_pci_device_name(&sandbox, &storage.source)?; let dev_path = get_pci_device_name(&sandbox, &storage.source).await?;
storage.source = dev_path; storage.source = dev_path;
} }
@ -342,7 +326,7 @@ fn virtio_blk_storage_handler(
} }
// virtio_scsi_storage_handler handles the storage for scsi driver. // virtio_scsi_storage_handler handles the storage for scsi driver.
fn virtio_scsi_storage_handler( async fn virtio_scsi_storage_handler(
logger: &Logger, logger: &Logger,
storage: &Storage, storage: &Storage,
sandbox: Arc<Mutex<Sandbox>>, sandbox: Arc<Mutex<Sandbox>>,
@ -350,7 +334,7 @@ fn virtio_scsi_storage_handler(
let mut storage = storage.clone(); let mut storage = storage.clone();
// Retrieve the device path from SCSI address. // Retrieve the device path from SCSI address.
let dev_path = get_scsi_device_name(&sandbox, &storage.source)?; let dev_path = get_scsi_device_name(&sandbox, &storage.source).await?;
storage.source = dev_path; storage.source = dev_path;
common_storage_handler(logger, &storage) common_storage_handler(logger, &storage)
@ -430,7 +414,7 @@ fn parse_mount_flags_and_options(options_vec: Vec<&str>) -> (MsFlags, String) {
// associated operations such as waiting for the device to show up, and mount // associated operations such as waiting for the device to show up, and mount
// it to a specific location, according to the type of handler chosen, and for // it to a specific location, according to the type of handler chosen, and for
// each storage. // each storage.
pub fn add_storages( pub async fn add_storages(
logger: Logger, logger: Logger,
storages: Vec<Storage>, storages: Vec<Storage>,
sandbox: Arc<Mutex<Sandbox>>, sandbox: Arc<Mutex<Sandbox>>,
@ -443,17 +427,30 @@ pub fn add_storages(
"subsystem" => "storage", "subsystem" => "storage",
"storage-type" => handler_name.to_owned())); "storage-type" => handler_name.to_owned()));
let handler = STORAGEHANDLERLIST let res = match handler_name.as_str() {
.get(&handler_name.as_str()) DRIVERBLKTYPE => virtio_blk_storage_handler(&logger, &storage, sandbox.clone()).await,
.ok_or_else(|| { DRIVER9PTYPE => virtio9p_storage_handler(&logger, &storage, sandbox.clone()).await,
anyhow!( DRIVERVIRTIOFSTYPE => {
virtiofs_storage_handler(&logger, &storage, sandbox.clone()).await
}
DRIVEREPHEMERALTYPE => {
ephemeral_storage_handler(&logger, &storage, sandbox.clone()).await
}
DRIVERMMIOBLKTYPE => {
virtiommio_blk_storage_handler(&logger, &storage, sandbox.clone()).await
}
DRIVERLOCALTYPE => local_storage_handler(&logger, &storage, sandbox.clone()).await,
DRIVERSCSITYPE => virtio_scsi_storage_handler(&logger, &storage, sandbox.clone()).await,
_ => {
return Err(anyhow!(
"Failed to find the storage handler {}", "Failed to find the storage handler {}",
storage.driver.to_owned() storage.driver.to_owned()
) ));
})?; }
};
// Todo need to rollback the mounted storage if err met. // Todo need to rollback the mounted storage if err met.
let mount_point = handler(&logger, &storage, sandbox.clone())?; let mount_point = res?;
if !mount_point.is_empty() { if !mount_point.is_empty() {
mount_list.push(mount_point); mount_list.push(mount_point);

View File

@ -11,7 +11,6 @@ use std::fmt;
use std::fs; use std::fs;
use std::fs::File; use std::fs::File;
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
use std::thread::{self};
use crate::mount::{BareMount, FLAGS}; use crate::mount::{BareMount, FLAGS};
use slog::Logger; use slog::Logger;
@ -76,7 +75,7 @@ impl Namespace {
// setup creates persistent namespace without switching to it. // setup creates persistent namespace without switching to it.
// Note, pid namespaces cannot be persisted. // Note, pid namespaces cannot be persisted.
pub fn setup(mut self) -> Result<Self> { pub async fn setup(mut self) -> Result<Self> {
fs::create_dir_all(&self.persistent_ns_dir)?; fs::create_dir_all(&self.persistent_ns_dir)?;
let ns_path = PathBuf::from(&self.persistent_ns_dir); let ns_path = PathBuf::from(&self.persistent_ns_dir);
@ -93,45 +92,51 @@ impl Namespace {
self.path = new_ns_path.clone().into_os_string().into_string().unwrap(); self.path = new_ns_path.clone().into_os_string().into_string().unwrap();
let hostname = self.hostname.clone(); let hostname = self.hostname.clone();
let new_thread = thread::spawn(move || -> Result<()> { let new_thread = tokio::spawn(async move {
let origin_ns_path = get_current_thread_ns_path(&ns_type.get()); if let Err(err) = || -> Result<()> {
let origin_ns_path = get_current_thread_ns_path(&ns_type.get());
File::open(Path::new(&origin_ns_path))?; File::open(Path::new(&origin_ns_path))?;
// Create a new netns on the current thread. // Create a new netns on the current thread.
let cf = ns_type.get_flags(); let cf = ns_type.get_flags();
unshare(cf)?; unshare(cf)?;
if ns_type == NamespaceType::UTS && hostname.is_some() { if ns_type == NamespaceType::UTS && hostname.is_some() {
nix::unistd::sethostname(hostname.unwrap())?; nix::unistd::sethostname(hostname.unwrap())?;
}
// Bind mount the new namespace from the current thread onto the mount point to persist it.
let source: &str = origin_ns_path.as_str();
let destination: &str = new_ns_path.as_path().to_str().unwrap_or("none");
let mut flags = MsFlags::empty();
if let Some(x) = FLAGS.get("rbind") {
let (_, f) = *x;
flags |= f;
};
let bare_mount = BareMount::new(source, destination, "none", flags, "", &logger);
bare_mount.mount().map_err(|e| {
anyhow!(
"Failed to mount {} to {} with err:{:?}",
source,
destination,
e
)
})?;
Ok(())
}() {
return Err(err);
} }
// Bind mount the new namespace from the current thread onto the mount point to persist it.
let source: &str = origin_ns_path.as_str();
let destination: &str = new_ns_path.as_path().to_str().unwrap_or("none");
let mut flags = MsFlags::empty();
if let Some(x) = FLAGS.get("rbind") {
let (_, f) = *x;
flags |= f;
};
let bare_mount = BareMount::new(source, destination, "none", flags, "", &logger);
bare_mount.mount().map_err(|e| {
anyhow!(
"Failed to mount {} to {} with err:{:?}",
source,
destination,
e
)
})?;
Ok(()) Ok(())
}); });
new_thread new_thread
.join() .await
.map_err(|e| anyhow!("Failed to join thread {:?}!", e))??; .map_err(|e| anyhow!("Failed to join thread {:?}!", e))??;
Ok(self) Ok(self)

View File

@ -4,9 +4,12 @@
// //
use async_trait::async_trait; use async_trait::async_trait;
use rustjail::{pipestream::PipeStream, process::StreamType};
use tokio::io::{AsyncReadExt, AsyncWriteExt, ReadHalf};
use tokio::sync::Mutex;
use std::path::Path; use std::path::Path;
use std::sync::mpsc::channel; use std::sync::Arc;
use std::sync::{Arc, Mutex};
use ttrpc::{ use ttrpc::{
self, self,
error::get_rpc_status as ttrpc_error, error::get_rpc_status as ttrpc_error,
@ -29,7 +32,6 @@ use protocols::types::Interface;
use rustjail::cgroups::notifier; use rustjail::cgroups::notifier;
use rustjail::container::{BaseContainer, Container, LinuxContainer}; use rustjail::container::{BaseContainer, Container, LinuxContainer};
use rustjail::process::Process; use rustjail::process::Process;
use rustjail::reaper;
use rustjail::specconv::CreateOpts; use rustjail::specconv::CreateOpts;
use nix::errno::Errno; use nix::errno::Errno;
@ -42,7 +44,7 @@ use rustjail::process::ProcessOperations;
use crate::device::{add_devices, rescan_pci_bus, update_device_cgroup}; use crate::device::{add_devices, rescan_pci_bus, update_device_cgroup};
use crate::linux_abi::*; use crate::linux_abi::*;
use crate::metrics::get_metrics; use crate::metrics::get_metrics;
use crate::mount::{add_storages, remove_mounts, BareMount, STORAGEHANDLERLIST}; use crate::mount::{add_storages, remove_mounts, BareMount, STORAGE_HANDLER_LIST};
use crate::namespace::{NSTYPEIPC, NSTYPEPID, NSTYPEUTS}; use crate::namespace::{NSTYPEIPC, NSTYPEPID, NSTYPEUTS};
use crate::network::setup_guest_dns; use crate::network::setup_guest_dns;
use crate::random; use crate::random;
@ -54,11 +56,8 @@ use netlink::{RtnlHandle, NETLINK_ROUTE};
use libc::{self, c_ushort, pid_t, winsize, TIOCSWINSZ}; use libc::{self, c_ushort, pid_t, winsize, TIOCSWINSZ};
use std::convert::TryFrom; use std::convert::TryFrom;
use std::fs; use std::fs;
use std::os::unix::io::RawFd;
use std::os::unix::prelude::PermissionsExt; use std::os::unix::prelude::PermissionsExt;
use std::process::{Command, Stdio}; use std::process::{Command, Stdio};
use std::sync::mpsc;
use std::thread;
use std::time::Duration; use std::time::Duration;
use nix::unistd::{Gid, Uid}; use nix::unistd::{Gid, Uid};
@ -83,7 +82,10 @@ pub struct agentService {
} }
impl agentService { impl agentService {
fn do_create_container(&self, req: protocols::agent::CreateContainerRequest) -> Result<()> { async fn do_create_container(
&self,
req: protocols::agent::CreateContainerRequest,
) -> Result<()> {
let cid = req.container_id.clone(); let cid = req.container_id.clone();
let mut oci_spec = req.OCI.clone(); let mut oci_spec = req.OCI.clone();
@ -111,7 +113,7 @@ impl agentService {
// updates the devices listed in the OCI spec, so that they actually // updates the devices listed in the OCI spec, so that they actually
// match real devices inside the VM. This step is necessary since we // match real devices inside the VM. This step is necessary since we
// cannot predict everything from the caller. // cannot predict everything from the caller.
add_devices(&req.devices.to_vec(), &mut oci, &self.sandbox)?; add_devices(&req.devices.to_vec(), &mut oci, &self.sandbox).await?;
// Both rootfs and volumes (invoked with --volume for instance) will // Both rootfs and volumes (invoked with --volume for instance) will
// be processed the same way. The idea is to always mount any provided // be processed the same way. The idea is to always mount any provided
@ -120,10 +122,10 @@ impl agentService {
// After all those storages have been processed, no matter the order // After all those storages have been processed, no matter the order
// here, the agent will rely on rustjail (using the oci.Mounts // here, the agent will rely on rustjail (using the oci.Mounts
// list) to bind mount all of them inside the container. // list) to bind mount all of them inside the container.
let m = add_storages(sl!(), req.storages.to_vec(), self.sandbox.clone())?; let m = add_storages(sl!(), req.storages.to_vec(), self.sandbox.clone()).await?;
{ {
sandbox = self.sandbox.clone(); sandbox = self.sandbox.clone();
s = sandbox.lock().unwrap(); s = sandbox.lock().await;
s.container_mounts.insert(cid.clone(), m); s.container_mounts.insert(cid.clone(), m);
} }
@ -154,7 +156,7 @@ impl agentService {
let mut ctr: LinuxContainer = let mut ctr: LinuxContainer =
LinuxContainer::new(cid.as_str(), CONTAINER_BASE, opts, &sl!())?; LinuxContainer::new(cid.as_str(), CONTAINER_BASE, opts, &sl!())?;
let pipe_size = AGENT_CONFIG.read().unwrap().container_pipe_size; let pipe_size = AGENT_CONFIG.read().await.container_pipe_size;
let p = if oci.process.is_some() { let p = if oci.process.is_some() {
Process::new( Process::new(
&sl!(), &sl!(),
@ -168,7 +170,7 @@ impl agentService {
return Err(anyhow!(nix::Error::from_errno(nix::errno::Errno::EINVAL))); return Err(anyhow!(nix::Error::from_errno(nix::errno::Errno::EINVAL)));
}; };
ctr.start(p)?; ctr.start(p).await?;
s.update_shared_pidns(&ctr)?; s.update_shared_pidns(&ctr)?;
s.add_container(ctr); s.add_container(ctr);
@ -177,11 +179,11 @@ impl agentService {
Ok(()) Ok(())
} }
fn do_start_container(&self, req: protocols::agent::StartContainerRequest) -> Result<()> { async fn do_start_container(&self, req: protocols::agent::StartContainerRequest) -> Result<()> {
let cid = req.container_id; let cid = req.container_id;
let sandbox = self.sandbox.clone(); let sandbox = self.sandbox.clone();
let mut s = sandbox.lock().unwrap(); let mut s = sandbox.lock().await;
let sid = s.id.clone(); let sid = s.id.clone();
let ctr = s let ctr = s
@ -194,8 +196,8 @@ impl agentService {
if sid != cid && ctr.cgroup_manager.is_some() { if sid != cid && ctr.cgroup_manager.is_some() {
let cg_path = ctr.cgroup_manager.as_ref().unwrap().get_cg_path("memory"); let cg_path = ctr.cgroup_manager.as_ref().unwrap().get_cg_path("memory");
if cg_path.is_some() { if cg_path.is_some() {
let rx = notifier::notify_oom(cid.as_str(), cg_path.unwrap())?; let rx = notifier::notify_oom(cid.as_str(), cg_path.unwrap()).await?;
s.run_oom_event_monitor(rx, cid.clone()); s.run_oom_event_monitor(rx, cid.clone()).await;
} }
} }
@ -206,7 +208,10 @@ impl agentService {
Ok(()) Ok(())
} }
fn do_remove_container(&self, req: protocols::agent::RemoveContainerRequest) -> Result<()> { async fn do_remove_container(
&self,
req: protocols::agent::RemoveContainerRequest,
) -> Result<()> {
let cid = req.container_id.clone(); let cid = req.container_id.clone();
let mut cmounts: Vec<String> = vec![]; let mut cmounts: Vec<String> = vec![];
@ -234,12 +239,12 @@ impl agentService {
if req.timeout == 0 { if req.timeout == 0 {
let s = Arc::clone(&self.sandbox); let s = Arc::clone(&self.sandbox);
let mut sandbox = s.lock().unwrap(); let mut sandbox = s.lock().await;
let ctr = sandbox let ctr = sandbox
.get_container(&cid) .get_container(&cid)
.ok_or_else(|| anyhow!("Invalid container id"))?; .ok_or_else(|| anyhow!("Invalid container id"))?;
ctr.destroy()?; ctr.destroy().await?;
remove_container_resources(&mut sandbox)?; remove_container_resources(&mut sandbox)?;
@ -249,43 +254,47 @@ impl agentService {
// timeout != 0 // timeout != 0
let s = self.sandbox.clone(); let s = self.sandbox.clone();
let cid2 = cid.clone(); let cid2 = cid.clone();
let (tx, rx) = mpsc::channel(); let (tx, rx) = tokio::sync::oneshot::channel::<i32>();
let handle = thread::spawn(move || { let handle = tokio::spawn(async move {
let mut sandbox = s.lock().unwrap(); let mut sandbox = s.lock().await;
let _ctr = sandbox if let Some(ctr) = sandbox.get_container(&cid2) {
.get_container(&cid2) ctr.destroy().await.unwrap();
.ok_or_else(|| anyhow!("Invalid container id")) tx.send(1).unwrap();
.map(|ctr| { };
ctr.destroy().unwrap();
tx.send(1).unwrap();
ctr
});
}); });
rx.recv_timeout(Duration::from_secs(req.timeout as u64)) let timeout = tokio::time::delay_for(Duration::from_secs(req.timeout.into()));
.map_err(|_| anyhow!(nix::Error::from_errno(nix::errno::Errno::ETIME)))?;
handle tokio::select! {
.join() _ = rx => {}
.map_err(|_| anyhow!(nix::Error::from_errno(nix::errno::Errno::UnknownErrno)))?; _ = timeout => {
return Err(anyhow!(nix::Error::from_errno(nix::errno::Errno::ETIME)));
}
};
if handle.await.is_err() {
return Err(anyhow!(nix::Error::from_errno(
nix::errno::Errno::UnknownErrno
)));
}
let s = self.sandbox.clone(); let s = self.sandbox.clone();
let mut sandbox = s.lock().unwrap(); let mut sandbox = s.lock().await;
remove_container_resources(&mut sandbox)?; remove_container_resources(&mut sandbox)?;
Ok(()) Ok(())
} }
fn do_exec_process(&self, req: protocols::agent::ExecProcessRequest) -> Result<()> { async fn do_exec_process(&self, req: protocols::agent::ExecProcessRequest) -> Result<()> {
let cid = req.container_id.clone(); let cid = req.container_id.clone();
let exec_id = req.exec_id.clone(); let exec_id = req.exec_id.clone();
info!(sl!(), "do_exec_process cid: {} eid: {}", cid, exec_id); info!(sl!(), "do_exec_process cid: {} eid: {}", cid, exec_id);
let s = self.sandbox.clone(); let s = self.sandbox.clone();
let mut sandbox = s.lock().unwrap(); let mut sandbox = s.lock().await;
let process = if req.process.is_some() { let process = if req.process.is_some() {
req.process.as_ref().unwrap() req.process.as_ref().unwrap()
@ -293,7 +302,7 @@ impl agentService {
return Err(anyhow!(nix::Error::from_errno(nix::errno::Errno::EINVAL))); return Err(anyhow!(nix::Error::from_errno(nix::errno::Errno::EINVAL)));
}; };
let pipe_size = AGENT_CONFIG.read().unwrap().container_pipe_size; let pipe_size = AGENT_CONFIG.read().await.container_pipe_size;
let ocip = rustjail::process_grpc_to_oci(process); let ocip = rustjail::process_grpc_to_oci(process);
let p = Process::new(&sl!(), &ocip, exec_id.as_str(), false, pipe_size)?; let p = Process::new(&sl!(), &ocip, exec_id.as_str(), false, pipe_size)?;
@ -301,7 +310,7 @@ impl agentService {
.get_container(&cid) .get_container(&cid)
.ok_or_else(|| anyhow!("Invalid container id"))?; .ok_or_else(|| anyhow!("Invalid container id"))?;
ctr.run(p)?; ctr.run(p).await?;
// set epoller // set epoller
let p = find_process(&mut sandbox, cid.as_str(), exec_id.as_str(), false)?; let p = find_process(&mut sandbox, cid.as_str(), exec_id.as_str(), false)?;
@ -310,11 +319,11 @@ impl agentService {
Ok(()) Ok(())
} }
fn do_signal_process(&self, req: protocols::agent::SignalProcessRequest) -> Result<()> { async fn do_signal_process(&self, req: protocols::agent::SignalProcessRequest) -> Result<()> {
let cid = req.container_id.clone(); let cid = req.container_id.clone();
let eid = req.exec_id.clone(); let eid = req.exec_id.clone();
let s = self.sandbox.clone(); let s = self.sandbox.clone();
let mut sandbox = s.lock().unwrap(); let mut sandbox = s.lock().await;
let mut init = false; let mut init = false;
info!( info!(
@ -344,7 +353,7 @@ impl agentService {
Ok(()) Ok(())
} }
fn do_wait_process( async fn do_wait_process(
&self, &self,
req: protocols::agent::WaitProcessRequest, req: protocols::agent::WaitProcessRequest,
) -> Result<protocols::agent::WaitProcessResponse> { ) -> Result<protocols::agent::WaitProcessResponse> {
@ -353,9 +362,9 @@ impl agentService {
let s = self.sandbox.clone(); let s = self.sandbox.clone();
let mut resp = WaitProcessResponse::new(); let mut resp = WaitProcessResponse::new();
let pid: pid_t; let pid: pid_t;
let mut exit_pipe_r: RawFd = -1; let stream;
let mut buf: Vec<u8> = vec![0, 1];
let (exit_send, exit_recv) = channel(); let (exit_send, mut exit_recv) = tokio::sync::mpsc::channel(100);
info!( info!(
sl!(), sl!(),
@ -365,24 +374,24 @@ impl agentService {
); );
{ {
let mut sandbox = s.lock().unwrap(); let mut sandbox = s.lock().await;
let p = find_process(&mut sandbox, cid.as_str(), eid.as_str(), false)?; let p = find_process(&mut sandbox, cid.as_str(), eid.as_str(), false)?;
if p.exit_pipe_r.is_some() { stream = p.get_reader(StreamType::ExitPipeR);
exit_pipe_r = p.exit_pipe_r.unwrap();
}
p.exit_watchers.push(exit_send); p.exit_watchers.push(exit_send);
pid = p.pid; pid = p.pid;
} }
if exit_pipe_r != -1 { if stream.is_some() {
info!(sl!(), "reading exit pipe"); info!(sl!(), "reading exit pipe");
let _ = unistd::read(exit_pipe_r, buf.as_mut_slice());
let reader = stream.unwrap();
let mut content: Vec<u8> = vec![0, 1];
let _ = reader.lock().await.read(&mut content).await;
} }
let mut sandbox = s.lock().unwrap(); let mut sandbox = s.lock().await;
let ctr = sandbox let ctr = sandbox
.get_container(&cid) .get_container(&cid)
.ok_or_else(|| anyhow!("Invalid container id"))?; .ok_or_else(|| anyhow!("Invalid container id"))?;
@ -391,44 +400,20 @@ impl agentService {
Some(p) => p, Some(p) => p,
None => { None => {
// Lost race, pick up exit code from channel // Lost race, pick up exit code from channel
resp.status = exit_recv.recv().unwrap(); resp.status = exit_recv.recv().await.unwrap();
return Ok(resp); return Ok(resp);
} }
}; };
// need to close all fds // need to close all fd
if p.parent_stdin.is_some() { // ignore errors for some fd might be closed by stream
let _ = unistd::close(p.parent_stdin.unwrap()); let _ = cleanup_process(&mut p);
}
if p.parent_stdout.is_some() {
let _ = unistd::close(p.parent_stdout.unwrap());
}
if p.parent_stderr.is_some() {
let _ = unistd::close(p.parent_stderr.unwrap());
}
if p.term_master.is_some() {
let _ = unistd::close(p.term_master.unwrap());
}
if p.exit_pipe_r.is_some() {
let _ = unistd::close(p.exit_pipe_r.unwrap());
}
p.close_epoller();
p.parent_stdin = None;
p.parent_stdout = None;
p.parent_stderr = None;
p.term_master = None;
resp.status = p.exit_code; resp.status = p.exit_code;
// broadcast exit code to all parallel watchers // broadcast exit code to all parallel watchers
for s in p.exit_watchers.iter() { for s in p.exit_watchers.iter_mut() {
// Just ignore errors in case any watcher quits unexpectedly // Just ignore errors in case any watcher quits unexpectedly
let _ = s.send(p.exit_code); let _ = s.send(p.exit_code).await;
} }
ctr.processes.remove(&pid); ctr.processes.remove(&pid);
@ -436,48 +421,37 @@ impl agentService {
Ok(resp) Ok(resp)
} }
fn do_write_stream( async fn do_write_stream(
&self, &self,
req: protocols::agent::WriteStreamRequest, req: protocols::agent::WriteStreamRequest,
) -> Result<protocols::agent::WriteStreamResponse> { ) -> Result<protocols::agent::WriteStreamResponse> {
let cid = req.container_id.clone(); let cid = req.container_id.clone();
let eid = req.exec_id.clone(); let eid = req.exec_id.clone();
let s = self.sandbox.clone(); let writer = {
let mut sandbox = s.lock().unwrap(); let s = self.sandbox.clone();
let p = find_process(&mut sandbox, cid.as_str(), eid.as_str(), false)?; let mut sandbox = s.lock().await;
let p = find_process(&mut sandbox, cid.as_str(), eid.as_str(), false)?;
// use ptmx io // use ptmx io
let fd = if p.term_master.is_some() { if p.term_master.is_some() {
p.term_master.unwrap() p.get_writer(StreamType::TermMaster)
} else { } else {
// use piped io // use piped io
p.parent_stdin.unwrap() p.get_writer(StreamType::ParentStdin)
}
}; };
let mut l = req.data.len(); let writer = writer.unwrap();
match unistd::write(fd, req.data.as_slice()) { writer.lock().await.write_all(req.data.as_slice()).await?;
Ok(v) => {
if v < l {
info!(sl!(), "write {} bytes", v);
l = v;
}
}
Err(e) => match e {
nix::Error::Sys(nix::errno::Errno::EAGAIN) => l = 0,
_ => {
return Err(anyhow!(nix::Error::from_errno(nix::errno::Errno::EIO)));
}
},
}
let mut resp = WriteStreamResponse::new(); let mut resp = WriteStreamResponse::new();
resp.set_len(l as u32); resp.set_len(req.data.len() as u32);
Ok(resp) Ok(resp)
} }
fn do_read_stream( async fn do_read_stream(
&self, &self,
req: protocols::agent::ReadStreamRequest, req: protocols::agent::ReadStreamRequest,
stdout: bool, stdout: bool,
@ -485,42 +459,35 @@ impl agentService {
let cid = req.container_id; let cid = req.container_id;
let eid = req.exec_id; let eid = req.exec_id;
let mut fd: RawFd = -1; // let mut fd: RawFd = -1;
let mut epoller: Option<reaper::Epoller> = None; // let mut epoller: Option<reaper::Epoller> = None;
{
let reader = {
let s = self.sandbox.clone(); let s = self.sandbox.clone();
let mut sandbox = s.lock().unwrap(); let mut sandbox = s.lock().await;
let p = find_process(&mut sandbox, cid.as_str(), eid.as_str(), false)?; let p = find_process(&mut sandbox, cid.as_str(), eid.as_str(), false)?;
if p.term_master.is_some() { if p.term_master.is_some() {
fd = p.term_master.unwrap(); // epoller = p.epoller.clone();
epoller = p.epoller.clone(); p.get_reader(StreamType::TermMaster)
} else if stdout { } else if stdout {
if p.parent_stdout.is_some() { if p.parent_stdout.is_some() {
fd = p.parent_stdout.unwrap(); p.get_reader(StreamType::ParentStdout)
} else {
None
} }
} else { } else {
fd = p.parent_stderr.unwrap(); p.get_reader(StreamType::ParentStderr)
} }
} };
if let Some(epoller) = epoller { if reader.is_none() {
// The process's epoller's poll() will return a file descriptor of the process's
// terminal or one end of its exited pipe. If it returns its terminal, it means
// there is data needed to be read out or it has been closed; if it returns the
// process's exited pipe, it means the process has exited and there is no data
// needed to be read out in its terminal, thus following read on it will read out
// "EOF" to terminate this process's io since the other end of this pipe has been
// closed in reap().
fd = epoller.poll()?;
}
if fd == -1 {
return Err(anyhow!(nix::Error::from_errno(nix::errno::Errno::EINVAL))); return Err(anyhow!(nix::Error::from_errno(nix::errno::Errno::EINVAL)));
} }
let vector = read_stream(fd, req.len as usize)?; let reader = reader.unwrap();
let vector = read_stream(reader, req.len as usize).await?;
let mut resp = ReadStreamResponse::new(); let mut resp = ReadStreamResponse::new();
resp.set_data(vector); resp.set_data(vector);
@ -536,7 +503,7 @@ impl protocols::agent_ttrpc::AgentService for agentService {
_ctx: &TtrpcContext, _ctx: &TtrpcContext,
req: protocols::agent::CreateContainerRequest, req: protocols::agent::CreateContainerRequest,
) -> ttrpc::Result<Empty> { ) -> ttrpc::Result<Empty> {
match self.do_create_container(req) { match self.do_create_container(req).await {
Err(e) => Err(ttrpc_error(ttrpc::Code::INTERNAL, e.to_string())), Err(e) => Err(ttrpc_error(ttrpc::Code::INTERNAL, e.to_string())),
Ok(_) => Ok(Empty::new()), Ok(_) => Ok(Empty::new()),
} }
@ -547,7 +514,7 @@ impl protocols::agent_ttrpc::AgentService for agentService {
_ctx: &TtrpcContext, _ctx: &TtrpcContext,
req: protocols::agent::StartContainerRequest, req: protocols::agent::StartContainerRequest,
) -> ttrpc::Result<Empty> { ) -> ttrpc::Result<Empty> {
match self.do_start_container(req) { match self.do_start_container(req).await {
Err(e) => Err(ttrpc_error(ttrpc::Code::INTERNAL, e.to_string())), Err(e) => Err(ttrpc_error(ttrpc::Code::INTERNAL, e.to_string())),
Ok(_) => Ok(Empty::new()), Ok(_) => Ok(Empty::new()),
} }
@ -558,7 +525,7 @@ impl protocols::agent_ttrpc::AgentService for agentService {
_ctx: &TtrpcContext, _ctx: &TtrpcContext,
req: protocols::agent::RemoveContainerRequest, req: protocols::agent::RemoveContainerRequest,
) -> ttrpc::Result<Empty> { ) -> ttrpc::Result<Empty> {
match self.do_remove_container(req) { match self.do_remove_container(req).await {
Err(e) => Err(ttrpc_error(ttrpc::Code::INTERNAL, e.to_string())), Err(e) => Err(ttrpc_error(ttrpc::Code::INTERNAL, e.to_string())),
Ok(_) => Ok(Empty::new()), Ok(_) => Ok(Empty::new()),
} }
@ -569,7 +536,7 @@ impl protocols::agent_ttrpc::AgentService for agentService {
_ctx: &TtrpcContext, _ctx: &TtrpcContext,
req: protocols::agent::ExecProcessRequest, req: protocols::agent::ExecProcessRequest,
) -> ttrpc::Result<Empty> { ) -> ttrpc::Result<Empty> {
match self.do_exec_process(req) { match self.do_exec_process(req).await {
Err(e) => Err(ttrpc_error(ttrpc::Code::INTERNAL, e.to_string())), Err(e) => Err(ttrpc_error(ttrpc::Code::INTERNAL, e.to_string())),
Ok(_) => Ok(Empty::new()), Ok(_) => Ok(Empty::new()),
} }
@ -580,7 +547,7 @@ impl protocols::agent_ttrpc::AgentService for agentService {
_ctx: &TtrpcContext, _ctx: &TtrpcContext,
req: protocols::agent::SignalProcessRequest, req: protocols::agent::SignalProcessRequest,
) -> ttrpc::Result<Empty> { ) -> ttrpc::Result<Empty> {
match self.do_signal_process(req) { match self.do_signal_process(req).await {
Err(e) => Err(ttrpc_error(ttrpc::Code::INTERNAL, e.to_string())), Err(e) => Err(ttrpc_error(ttrpc::Code::INTERNAL, e.to_string())),
Ok(_) => Ok(Empty::new()), Ok(_) => Ok(Empty::new()),
} }
@ -592,6 +559,7 @@ impl protocols::agent_ttrpc::AgentService for agentService {
req: protocols::agent::WaitProcessRequest, req: protocols::agent::WaitProcessRequest,
) -> ttrpc::Result<WaitProcessResponse> { ) -> ttrpc::Result<WaitProcessResponse> {
self.do_wait_process(req) self.do_wait_process(req)
.await
.map_err(|e| ttrpc_error(ttrpc::Code::INTERNAL, e.to_string())) .map_err(|e| ttrpc_error(ttrpc::Code::INTERNAL, e.to_string()))
} }
@ -606,7 +574,7 @@ impl protocols::agent_ttrpc::AgentService for agentService {
let mut resp = ListProcessesResponse::new(); let mut resp = ListProcessesResponse::new();
let s = Arc::clone(&self.sandbox); let s = Arc::clone(&self.sandbox);
let mut sandbox = s.lock().unwrap(); let mut sandbox = s.lock().await;
let ctr = sandbox.get_container(&cid).ok_or_else(|| { let ctr = sandbox.get_container(&cid).ok_or_else(|| {
ttrpc_error( ttrpc_error(
@ -637,10 +605,11 @@ impl protocols::agent_ttrpc::AgentService for agentService {
args = vec!["-ef".to_string()]; args = vec!["-ef".to_string()];
} }
let output = Command::new("ps") let output = tokio::process::Command::new("ps")
.args(args.as_slice()) .args(args.as_slice())
.stdout(Stdio::piped()) .stdout(Stdio::piped())
.output() .output()
.await
.expect("ps failed"); .expect("ps failed");
let out: String = String::from_utf8(output.stdout).unwrap(); let out: String = String::from_utf8(output.stdout).unwrap();
@ -688,7 +657,7 @@ impl protocols::agent_ttrpc::AgentService for agentService {
let res = req.resources; let res = req.resources;
let s = Arc::clone(&self.sandbox); let s = Arc::clone(&self.sandbox);
let mut sandbox = s.lock().unwrap(); let mut sandbox = s.lock().await;
let ctr = sandbox.get_container(&cid).ok_or_else(|| { let ctr = sandbox.get_container(&cid).ok_or_else(|| {
ttrpc_error( ttrpc_error(
@ -720,7 +689,7 @@ impl protocols::agent_ttrpc::AgentService for agentService {
) -> ttrpc::Result<StatsContainerResponse> { ) -> ttrpc::Result<StatsContainerResponse> {
let cid = req.container_id; let cid = req.container_id;
let s = Arc::clone(&self.sandbox); let s = Arc::clone(&self.sandbox);
let mut sandbox = s.lock().unwrap(); let mut sandbox = s.lock().await;
let ctr = sandbox.get_container(&cid).ok_or_else(|| { let ctr = sandbox.get_container(&cid).ok_or_else(|| {
ttrpc_error( ttrpc_error(
@ -740,7 +709,7 @@ impl protocols::agent_ttrpc::AgentService for agentService {
) -> ttrpc::Result<protocols::empty::Empty> { ) -> ttrpc::Result<protocols::empty::Empty> {
let cid = req.get_container_id(); let cid = req.get_container_id();
let s = Arc::clone(&self.sandbox); let s = Arc::clone(&self.sandbox);
let mut sandbox = s.lock().unwrap(); let mut sandbox = s.lock().await;
let ctr = sandbox.get_container(&cid).ok_or_else(|| { let ctr = sandbox.get_container(&cid).ok_or_else(|| {
ttrpc_error( ttrpc_error(
@ -762,7 +731,7 @@ impl protocols::agent_ttrpc::AgentService for agentService {
) -> ttrpc::Result<protocols::empty::Empty> { ) -> ttrpc::Result<protocols::empty::Empty> {
let cid = req.get_container_id(); let cid = req.get_container_id();
let s = Arc::clone(&self.sandbox); let s = Arc::clone(&self.sandbox);
let mut sandbox = s.lock().unwrap(); let mut sandbox = s.lock().await;
let ctr = sandbox.get_container(&cid).ok_or_else(|| { let ctr = sandbox.get_container(&cid).ok_or_else(|| {
ttrpc_error( ttrpc_error(
@ -783,6 +752,7 @@ impl protocols::agent_ttrpc::AgentService for agentService {
req: protocols::agent::WriteStreamRequest, req: protocols::agent::WriteStreamRequest,
) -> ttrpc::Result<WriteStreamResponse> { ) -> ttrpc::Result<WriteStreamResponse> {
self.do_write_stream(req) self.do_write_stream(req)
.await
.map_err(|e| ttrpc_error(ttrpc::Code::INTERNAL, e.to_string())) .map_err(|e| ttrpc_error(ttrpc::Code::INTERNAL, e.to_string()))
} }
@ -792,6 +762,7 @@ impl protocols::agent_ttrpc::AgentService for agentService {
req: protocols::agent::ReadStreamRequest, req: protocols::agent::ReadStreamRequest,
) -> ttrpc::Result<ReadStreamResponse> { ) -> ttrpc::Result<ReadStreamResponse> {
self.do_read_stream(req, true) self.do_read_stream(req, true)
.await
.map_err(|e| ttrpc_error(ttrpc::Code::INTERNAL, e.to_string())) .map_err(|e| ttrpc_error(ttrpc::Code::INTERNAL, e.to_string()))
} }
@ -801,6 +772,7 @@ impl protocols::agent_ttrpc::AgentService for agentService {
req: protocols::agent::ReadStreamRequest, req: protocols::agent::ReadStreamRequest,
) -> ttrpc::Result<ReadStreamResponse> { ) -> ttrpc::Result<ReadStreamResponse> {
self.do_read_stream(req, false) self.do_read_stream(req, false)
.await
.map_err(|e| ttrpc_error(ttrpc::Code::INTERNAL, e.to_string())) .map_err(|e| ttrpc_error(ttrpc::Code::INTERNAL, e.to_string()))
} }
@ -812,7 +784,7 @@ impl protocols::agent_ttrpc::AgentService for agentService {
let cid = req.container_id.clone(); let cid = req.container_id.clone();
let eid = req.exec_id; let eid = req.exec_id;
let s = Arc::clone(&self.sandbox); let s = Arc::clone(&self.sandbox);
let mut sandbox = s.lock().unwrap(); let mut sandbox = s.lock().await;
let p = find_process(&mut sandbox, cid.as_str(), eid.as_str(), false).map_err(|e| { let p = find_process(&mut sandbox, cid.as_str(), eid.as_str(), false).map_err(|e| {
ttrpc_error( ttrpc_error(
@ -822,11 +794,13 @@ impl protocols::agent_ttrpc::AgentService for agentService {
})?; })?;
if p.term_master.is_some() { if p.term_master.is_some() {
p.close_stream(StreamType::TermMaster);
let _ = unistd::close(p.term_master.unwrap()); let _ = unistd::close(p.term_master.unwrap());
p.term_master = None; p.term_master = None;
} }
if p.parent_stdin.is_some() { if p.parent_stdin.is_some() {
p.close_stream(StreamType::ParentStdin);
let _ = unistd::close(p.parent_stdin.unwrap()); let _ = unistd::close(p.parent_stdin.unwrap());
p.parent_stdin = None; p.parent_stdin = None;
} }
@ -844,7 +818,7 @@ impl protocols::agent_ttrpc::AgentService for agentService {
let cid = req.container_id.clone(); let cid = req.container_id.clone();
let eid = req.exec_id.clone(); let eid = req.exec_id.clone();
let s = Arc::clone(&self.sandbox); let s = Arc::clone(&self.sandbox);
let mut sandbox = s.lock().unwrap(); let mut sandbox = s.lock().await;
let p = find_process(&mut sandbox, cid.as_str(), eid.as_str(), false).map_err(|e| { let p = find_process(&mut sandbox, cid.as_str(), eid.as_str(), false).map_err(|e| {
ttrpc_error( ttrpc_error(
ttrpc::Code::UNAVAILABLE, ttrpc::Code::UNAVAILABLE,
@ -888,7 +862,7 @@ impl protocols::agent_ttrpc::AgentService for agentService {
let interface = req.interface; let interface = req.interface;
let s = Arc::clone(&self.sandbox); let s = Arc::clone(&self.sandbox);
let mut sandbox = s.lock().unwrap(); let mut sandbox = s.lock().await;
if sandbox.rtnl.is_none() { if sandbox.rtnl.is_none() {
sandbox.rtnl = Some(RtnlHandle::new(NETLINK_ROUTE, 0).unwrap()); sandbox.rtnl = Some(RtnlHandle::new(NETLINK_ROUTE, 0).unwrap());
@ -921,7 +895,7 @@ impl protocols::agent_ttrpc::AgentService for agentService {
let rs = req.routes.unwrap().Routes.into_vec(); let rs = req.routes.unwrap().Routes.into_vec();
let s = Arc::clone(&self.sandbox); let s = Arc::clone(&self.sandbox);
let mut sandbox = s.lock().unwrap(); let mut sandbox = s.lock().await;
if sandbox.rtnl.is_none() { if sandbox.rtnl.is_none() {
sandbox.rtnl = Some(RtnlHandle::new(NETLINK_ROUTE, 0).unwrap()); sandbox.rtnl = Some(RtnlHandle::new(NETLINK_ROUTE, 0).unwrap());
@ -951,7 +925,7 @@ impl protocols::agent_ttrpc::AgentService for agentService {
) -> ttrpc::Result<Interfaces> { ) -> ttrpc::Result<Interfaces> {
let mut interface = protocols::agent::Interfaces::new(); let mut interface = protocols::agent::Interfaces::new();
let s = Arc::clone(&self.sandbox); let s = Arc::clone(&self.sandbox);
let mut sandbox = s.lock().unwrap(); let mut sandbox = s.lock().await;
if sandbox.rtnl.is_none() { if sandbox.rtnl.is_none() {
sandbox.rtnl = Some(RtnlHandle::new(NETLINK_ROUTE, 0).unwrap()); sandbox.rtnl = Some(RtnlHandle::new(NETLINK_ROUTE, 0).unwrap());
@ -974,7 +948,7 @@ impl protocols::agent_ttrpc::AgentService for agentService {
) -> ttrpc::Result<Routes> { ) -> ttrpc::Result<Routes> {
let mut routes = protocols::agent::Routes::new(); let mut routes = protocols::agent::Routes::new();
let s = Arc::clone(&self.sandbox); let s = Arc::clone(&self.sandbox);
let mut sandbox = s.lock().unwrap(); let mut sandbox = s.lock().await;
if sandbox.rtnl.is_none() { if sandbox.rtnl.is_none() {
sandbox.rtnl = Some(RtnlHandle::new(NETLINK_ROUTE, 0).unwrap()); sandbox.rtnl = Some(RtnlHandle::new(NETLINK_ROUTE, 0).unwrap());
@ -1015,7 +989,7 @@ impl protocols::agent_ttrpc::AgentService for agentService {
) -> ttrpc::Result<Empty> { ) -> ttrpc::Result<Empty> {
{ {
let sandbox = self.sandbox.clone(); let sandbox = self.sandbox.clone();
let mut s = sandbox.lock().unwrap(); let mut s = sandbox.lock().await;
let _ = fs::remove_dir_all(CONTAINER_BASE); let _ = fs::remove_dir_all(CONTAINER_BASE);
let _ = fs::create_dir_all(CONTAINER_BASE); let _ = fs::create_dir_all(CONTAINER_BASE);
@ -1042,13 +1016,14 @@ impl protocols::agent_ttrpc::AgentService for agentService {
} }
s.setup_shared_namespaces() s.setup_shared_namespaces()
.await
.map_err(|e| ttrpc_error(ttrpc::Code::INTERNAL, e.to_string()))?; .map_err(|e| ttrpc_error(ttrpc::Code::INTERNAL, e.to_string()))?;
} }
match add_storages(sl!(), req.storages.to_vec(), self.sandbox.clone()) { match add_storages(sl!(), req.storages.to_vec(), self.sandbox.clone()).await {
Ok(m) => { Ok(m) => {
let sandbox = self.sandbox.clone(); let sandbox = self.sandbox.clone();
let mut s = sandbox.lock().unwrap(); let mut s = sandbox.lock().await;
s.mounts = m s.mounts = m
} }
Err(e) => return Err(ttrpc_error(ttrpc::Code::INTERNAL, e.to_string())), Err(e) => return Err(ttrpc_error(ttrpc::Code::INTERNAL, e.to_string())),
@ -1057,7 +1032,7 @@ impl protocols::agent_ttrpc::AgentService for agentService {
match setup_guest_dns(sl!(), req.dns.to_vec()) { match setup_guest_dns(sl!(), req.dns.to_vec()) {
Ok(_) => { Ok(_) => {
let sandbox = self.sandbox.clone(); let sandbox = self.sandbox.clone();
let mut s = sandbox.lock().unwrap(); let mut s = sandbox.lock().await;
let _dns = req let _dns = req
.dns .dns
.to_vec() .to_vec()
@ -1076,13 +1051,11 @@ impl protocols::agent_ttrpc::AgentService for agentService {
_req: protocols::agent::DestroySandboxRequest, _req: protocols::agent::DestroySandboxRequest,
) -> ttrpc::Result<Empty> { ) -> ttrpc::Result<Empty> {
let s = Arc::clone(&self.sandbox); let s = Arc::clone(&self.sandbox);
let mut sandbox = s.lock().unwrap(); let mut sandbox = s.lock().await;
// destroy all containers, clean up, notify agent to exit // destroy all containers, clean up, notify agent to exit
// etc. // etc.
sandbox.destroy().unwrap(); sandbox.destroy().await.unwrap();
sandbox.sender.take().unwrap().send(1).unwrap();
sandbox.sender.as_ref().unwrap().send(1).unwrap();
sandbox.sender = None;
Ok(Empty::new()) Ok(Empty::new())
} }
@ -1102,7 +1075,7 @@ impl protocols::agent_ttrpc::AgentService for agentService {
let neighs = req.neighbors.unwrap().ARPNeighbors.into_vec(); let neighs = req.neighbors.unwrap().ARPNeighbors.into_vec();
let s = Arc::clone(&self.sandbox); let s = Arc::clone(&self.sandbox);
let mut sandbox = s.lock().unwrap(); let mut sandbox = s.lock().await;
if sandbox.rtnl.is_none() { if sandbox.rtnl.is_none() {
sandbox.rtnl = Some(RtnlHandle::new(NETLINK_ROUTE, 0).unwrap()); sandbox.rtnl = Some(RtnlHandle::new(NETLINK_ROUTE, 0).unwrap());
@ -1122,7 +1095,7 @@ impl protocols::agent_ttrpc::AgentService for agentService {
req: protocols::agent::OnlineCPUMemRequest, req: protocols::agent::OnlineCPUMemRequest,
) -> ttrpc::Result<Empty> { ) -> ttrpc::Result<Empty> {
let s = Arc::clone(&self.sandbox); let s = Arc::clone(&self.sandbox);
let sandbox = s.lock().unwrap(); let sandbox = s.lock().await;
sandbox sandbox
.online_cpu_memory(&req) .online_cpu_memory(&req)
@ -1221,15 +1194,15 @@ impl protocols::agent_ttrpc::AgentService for agentService {
_req: protocols::agent::GetOOMEventRequest, _req: protocols::agent::GetOOMEventRequest,
) -> ttrpc::Result<OOMEvent> { ) -> ttrpc::Result<OOMEvent> {
let sandbox = self.sandbox.clone(); let sandbox = self.sandbox.clone();
let s = sandbox.lock().unwrap(); let s = sandbox.lock().await;
let event_rx = &s.event_rx.clone(); let event_rx = &s.event_rx.clone();
let event_rx = event_rx.lock().unwrap(); let mut event_rx = event_rx.lock().await;
drop(s); drop(s);
drop(sandbox); drop(sandbox);
match event_rx.recv() { match event_rx.recv().await {
Err(err) => Err(ttrpc_error(ttrpc::Code::INTERNAL, err.to_string())), None => Err(ttrpc_error(ttrpc::Code::INTERNAL, "")),
Ok(container_id) => { Some(container_id) => {
info!(sl!(), "get_oom_event return {}", &container_id); info!(sl!(), "get_oom_event return {}", &container_id);
let mut resp = OOMEvent::new(); let mut resp = OOMEvent::new();
resp.container_id = container_id; resp.container_id = container_id;
@ -1326,42 +1299,28 @@ fn get_agent_details() -> AgentDetails {
detail.device_handlers = RepeatedField::new(); detail.device_handlers = RepeatedField::new();
detail.storage_handlers = RepeatedField::from_vec( detail.storage_handlers = RepeatedField::from_vec(
STORAGEHANDLERLIST STORAGE_HANDLER_LIST
.keys() .to_vec()
.cloned() .iter()
.map(|x| x.into()) .map(|x| x.to_string())
.collect(), .collect(),
); );
detail detail
} }
fn read_stream(fd: RawFd, l: usize) -> Result<Vec<u8>> { async fn read_stream(reader: Arc<Mutex<ReadHalf<PipeStream>>>, l: usize) -> Result<Vec<u8>> {
let mut v: Vec<u8> = Vec::with_capacity(l); let mut content = vec![0u8; l];
unsafe {
v.set_len(l); let mut reader = reader.lock().await;
let len = reader.read(&mut content).await?;
content.resize(len, 0);
if len == 0 {
return Err(anyhow!("read meet eof"));
} }
match unistd::read(fd, v.as_mut_slice()) { Ok(content)
Ok(len) => {
v.resize(len, 0);
// Rust didn't return an EOF error when the reading peer point
// was closed, instead it would return a 0 reading length, please
// see https://github.com/rust-lang/rfcs/blob/master/text/0517-io-os-reform.md#errors
if len == 0 {
return Err(anyhow!("read meet eof"));
}
}
Err(e) => match e {
nix::Error::Sys(errno) => match errno {
Errno::EAGAIN => v.clear(),
_ => return Err(anyhow!(nix::Error::Sys(errno))),
},
_ => return Err(anyhow!("read error")),
},
}
Ok(v)
} }
fn find_process<'a>( fn find_process<'a>(
@ -1384,7 +1343,7 @@ fn find_process<'a>(
ctr.get_process(eid).map_err(|_| anyhow!("Invalid exec id")) ctr.get_process(eid).map_err(|_| anyhow!("Invalid exec id"))
} }
pub async fn start(s: Arc<Mutex<Sandbox>>, server_address: &str) -> Result<TtrpcServer> { pub fn start(s: Arc<Mutex<Sandbox>>, server_address: &str) -> TtrpcServer {
let agent_service = Box::new(agentService { sandbox: s }) let agent_service = Box::new(agentService { sandbox: s })
as Box<dyn protocols::agent_ttrpc::AgentService + Send + Sync>; as Box<dyn protocols::agent_ttrpc::AgentService + Send + Sync>;
@ -1399,13 +1358,14 @@ pub async fn start(s: Arc<Mutex<Sandbox>>, server_address: &str) -> Result<Ttrpc
let hservice = protocols::health_ttrpc::create_health(health_worker); let hservice = protocols::health_ttrpc::create_health(health_worker);
let server = TtrpcServer::new() let server = TtrpcServer::new()
.bind(server_address)? .bind(server_address)
.unwrap()
.register_service(aservice) .register_service(aservice)
.register_service(hservice); .register_service(hservice);
info!(sl!(), "ttRPC server started"; "address" => server_address); info!(sl!(), "ttRPC server started"; "address" => server_address);
Ok(server) server
} }
// This function updates the container namespaces configuration based on the // This function updates the container namespaces configuration based on the
@ -1641,6 +1601,42 @@ fn setup_bundle(cid: &str, spec: &mut Spec) -> Result<PathBuf> {
Ok(olddir) Ok(olddir)
} }
fn cleanup_process(p: &mut Process) -> Result<()> {
if p.parent_stdin.is_some() {
p.close_stream(StreamType::ParentStdin);
let _ = unistd::close(p.parent_stdin.unwrap())?;
}
if p.parent_stdout.is_some() {
p.close_stream(StreamType::ParentStdout);
let _ = unistd::close(p.parent_stdout.unwrap())?;
}
if p.parent_stderr.is_some() {
p.close_stream(StreamType::ParentStderr);
let _ = unistd::close(p.parent_stderr.unwrap())?;
}
if p.term_master.is_some() {
p.close_stream(StreamType::TermMaster);
let _ = unistd::close(p.term_master.unwrap())?;
}
if p.exit_pipe_r.is_some() {
p.close_stream(StreamType::ExitPipeR);
let _ = unistd::close(p.exit_pipe_r.unwrap())?;
}
p.close_epoller();
p.parent_stdin = None;
p.parent_stdout = None;
p.parent_stderr = None;
p.term_master = None;
Ok(())
}
fn load_kernel_module(module: &protocols::agent::KernelModule) -> Result<()> { fn load_kernel_module(module: &protocols::agent::KernelModule) -> Result<()> {
if module.name == "" { if module.name == "" {
return Err(anyhow!("Kernel module name is empty")); return Err(anyhow!("Kernel module name is empty"));

View File

@ -22,9 +22,10 @@ use std::collections::HashMap;
use std::fs; use std::fs;
use std::os::unix::fs::PermissionsExt; use std::os::unix::fs::PermissionsExt;
use std::path::Path; use std::path::Path;
use std::sync::mpsc::{self, Receiver, Sender}; use std::sync::Arc;
use std::sync::{Arc, Mutex};
use std::{thread, time}; use std::{thread, time};
use tokio::sync::mpsc::{channel, Receiver, Sender};
use tokio::sync::Mutex;
#[derive(Debug)] #[derive(Debug)]
pub struct Sandbox { pub struct Sandbox {
@ -42,7 +43,7 @@ pub struct Sandbox {
pub storages: HashMap<String, u32>, pub storages: HashMap<String, u32>,
pub running: bool, pub running: bool,
pub no_pivot_root: bool, pub no_pivot_root: bool,
pub sender: Option<Sender<i32>>, pub sender: Option<tokio::sync::oneshot::Sender<i32>>,
pub rtnl: Option<RtnlHandle>, pub rtnl: Option<RtnlHandle>,
pub hooks: Option<Hooks>, pub hooks: Option<Hooks>,
pub event_rx: Arc<Mutex<Receiver<String>>>, pub event_rx: Arc<Mutex<Receiver<String>>>,
@ -53,7 +54,7 @@ impl Sandbox {
pub fn new(logger: &Logger) -> Result<Self> { pub fn new(logger: &Logger) -> Result<Self> {
let fs_type = get_mount_fs_type("/")?; let fs_type = get_mount_fs_type("/")?;
let logger = logger.new(o!("subsystem" => "sandbox")); let logger = logger.new(o!("subsystem" => "sandbox"));
let (tx, rx) = mpsc::channel::<String>(); let (tx, rx) = channel::<String>(100);
let event_rx = Arc::new(Mutex::new(rx)); let event_rx = Arc::new(Mutex::new(rx));
Ok(Sandbox { Ok(Sandbox {
@ -157,17 +158,19 @@ impl Sandbox {
self.hostname = hostname; self.hostname = hostname;
} }
pub fn setup_shared_namespaces(&mut self) -> Result<bool> { pub async fn setup_shared_namespaces(&mut self) -> Result<bool> {
// Set up shared IPC namespace // Set up shared IPC namespace
self.shared_ipcns = Namespace::new(&self.logger) self.shared_ipcns = Namespace::new(&self.logger)
.get_ipc() .get_ipc()
.setup() .setup()
.await
.context("Failed to setup persistent IPC namespace")?; .context("Failed to setup persistent IPC namespace")?;
// // Set up shared UTS namespace // // Set up shared UTS namespace
self.shared_utsns = Namespace::new(&self.logger) self.shared_utsns = Namespace::new(&self.logger)
.get_uts(self.hostname.as_str()) .get_uts(self.hostname.as_str())
.setup() .setup()
.await
.context("Failed to setup persistent UTS namespace")?; .context("Failed to setup persistent UTS namespace")?;
Ok(true) Ok(true)
@ -214,9 +217,9 @@ impl Sandbox {
None None
} }
pub fn destroy(&mut self) -> Result<()> { pub async fn destroy(&mut self) -> Result<()> {
for ctr in self.containers.values_mut() { for ctr in self.containers.values_mut() {
ctr.destroy()?; ctr.destroy().await?;
} }
Ok(()) Ok(())
} }
@ -315,15 +318,17 @@ impl Sandbox {
Ok(hooks) Ok(hooks)
} }
pub fn run_oom_event_monitor(&self, rx: Receiver<String>, container_id: String) { pub async fn run_oom_event_monitor(&self, mut rx: Receiver<String>, container_id: String) {
let tx = self.event_tx.clone(); let mut tx = self.event_tx.clone();
let logger = self.logger.clone(); let logger = self.logger.clone();
thread::spawn(move || { tokio::spawn(async move {
for event in rx { loop {
let event = rx.recv().await;
info!(logger, "got an OOM event {:?}", event); info!(logger, "got an OOM event {:?}", event);
let _ = tx let _ = tx
.send(container_id.clone()) .send(container_id.clone())
.await
.map_err(|e| error!(logger, "failed to send message: {:?}", e)); .map_err(|e| error!(logger, "failed to send message: {:?}", e));
} }
}); });

View File

@ -7,10 +7,13 @@ use crate::device::online_device;
use crate::linux_abi::*; use crate::linux_abi::*;
use crate::sandbox::Sandbox; use crate::sandbox::Sandbox;
use crate::GLOBAL_DEVICE_WATCHER; use crate::GLOBAL_DEVICE_WATCHER;
use netlink::{RtnlHandle, NETLINK_UEVENT};
use slog::Logger; use slog::Logger;
use std::sync::{Arc, Mutex};
use std::thread; use netlink_sys::{Protocol, Socket, SocketAddr};
use nix::errno::Errno;
use std::os::unix::io::FromRawFd;
use std::sync::Arc;
use tokio::sync::Mutex;
#[derive(Debug, Default)] #[derive(Debug, Default)]
struct Uevent { struct Uevent {
@ -55,12 +58,13 @@ impl Uevent {
&& self.devname != "" && self.devname != ""
} }
fn handle_block_add_event(&self, sandbox: &Arc<Mutex<Sandbox>>) { async fn handle_block_add_event(&self, sandbox: &Arc<Mutex<Sandbox>>) {
let pci_root_bus_path = create_pci_root_bus_path(); let pci_root_bus_path = create_pci_root_bus_path();
// Keep the same lock order as device::get_device_name(), otherwise it may cause deadlock. // Keep the same lock order as device::get_device_name(), otherwise it may cause deadlock.
let mut w = GLOBAL_DEVICE_WATCHER.lock().unwrap(); let watcher = GLOBAL_DEVICE_WATCHER.clone();
let mut sb = sandbox.lock().unwrap(); let mut w = watcher.lock().await;
let mut sb = sandbox.lock().await;
// Add the device node name to the pci device map. // Add the device node name to the pci device map.
sb.pci_device_map sb.pci_device_map
@ -70,7 +74,7 @@ impl Uevent {
// Close the channel after watcher has been notified. // Close the channel after watcher has been notified.
let devpath = self.devpath.clone(); let devpath = self.devpath.clone();
let empties: Vec<_> = w let empties: Vec<_> = w
.iter() .iter_mut()
.filter(|(dev_addr, _)| { .filter(|(dev_addr, _)| {
let pci_p = format!("{}/{}", pci_root_bus_path, *dev_addr); let pci_p = format!("{}/{}", pci_root_bus_path, *dev_addr);
@ -84,6 +88,7 @@ impl Uevent {
}) })
.map(|(k, sender)| { .map(|(k, sender)| {
let devname = self.devname.clone(); let devname = self.devname.clone();
let sender = sender.take().unwrap();
let _ = sender.send(devname); let _ = sender.send(devname);
k.clone() k.clone()
}) })
@ -95,9 +100,9 @@ impl Uevent {
} }
} }
fn process(&self, logger: &Logger, sandbox: &Arc<Mutex<Sandbox>>) { async fn process(&self, logger: &Logger, sandbox: &Arc<Mutex<Sandbox>>) {
if self.is_block_add_event() { if self.is_block_add_event() {
return self.handle_block_add_event(sandbox); return self.handle_block_add_event(sandbox).await;
} else if self.action == U_EVENT_ACTION_ADD { } else if self.action == U_EVENT_ACTION_ADD {
let online_path = format!("{}/{}/online", SYSFS_DIR, &self.devpath); let online_path = format!("{}/{}/online", SYSFS_DIR, &self.devpath);
// It's a memory hot-add event. // It's a memory hot-add event.
@ -117,22 +122,37 @@ impl Uevent {
} }
} }
pub fn watch_uevents(sandbox: Arc<Mutex<Sandbox>>) { pub async fn watch_uevents(sandbox: Arc<Mutex<Sandbox>>) {
thread::spawn(move || { let sref = sandbox.clone();
let rtnl = RtnlHandle::new(NETLINK_UEVENT, 1).unwrap(); let s = sref.lock().await;
let logger = sandbox let logger = s.logger.new(o!("subsystem" => "uevent"));
.lock()
.unwrap() tokio::spawn(async move {
.logger let mut socket;
.new(o!("subsystem" => "uevent")); unsafe {
let fd = libc::socket(
libc::AF_NETLINK,
libc::SOCK_DGRAM | libc::SOCK_CLOEXEC,
Protocol::KObjectUevent as libc::c_int,
);
socket = Socket::from_raw_fd(fd);
}
socket.bind(&SocketAddr::new(0, 1)).unwrap();
loop { loop {
match rtnl.recv_message() { match socket.recv_from_full().await {
Err(e) => { Err(e) => {
error!(logger, "receive uevent message failed"; "error" => format!("{}", e)) error!(logger, "receive uevent message failed"; "error" => format!("{}", e))
} }
Ok(data) => { Ok((buf, addr)) => {
let text = String::from_utf8(data); if addr.port_number() != 0 {
// not our netlink message
let err_msg = format!("{:?}", nix::Error::Sys(Errno::EBADMSG));
error!(logger, "receive uevent message failed"; "error" => err_msg);
return;
}
let text = String::from_utf8(buf);
match text { match text {
Err(e) => { Err(e) => {
error!(logger, "failed to convert bytes to text"; "error" => format!("{}", e)) error!(logger, "failed to convert bytes to text"; "error" => format!("{}", e))
@ -140,7 +160,7 @@ pub fn watch_uevents(sandbox: Arc<Mutex<Sandbox>>) {
Ok(text) => { Ok(text) => {
let event = Uevent::new(&text); let event = Uevent::new(&text);
info!(logger, "got uevent message"; "event" => format!("{:?}", event)); info!(logger, "got uevent message"; "event" => format!("{:?}", event));
event.process(&logger, &sandbox); event.process(&logger, &sandbox).await;
} }
} }
} }