agent: Enable clean shutdown

The agent doesn't normally shut down: it doesn't need to, as it is
killed *after* the workload has finished. However, a clean and ordered
shutdown sequence is required to support agent tracing, since all trace
spans need to be completed to ensure a valid trace transaction.

Enable a controlled shutdown by allowing the main threads (tasks) to be
stopped.

To allow this to happen, each thread is now passed a shutdown channel
which it must listen to asynchronously, shutting itself down if
activity is detected on that channel.

Since some threads are created for I/O, and since the standard `io::copy`
cannot be stopped, add a new `interruptable_io_copier()` function
which shares the same semantics as `io::copy()`, but which is also
passed a shutdown channel to allow asynchronous I/O operations to be
stopped cleanly.

Fixes: #1531.

Signed-off-by: James O. D. Hunt <james.o.hunt@intel.com>
This commit is contained in:
James O. D. Hunt 2021-02-28 16:49:28 +00:00
parent dcb39c61f1
commit 7d5f88c0ad
4 changed files with 384 additions and 6 deletions

20
src/agent/Cargo.lock generated
View File

@ -387,6 +387,15 @@ dependencies = [
"unicode-segmentation", "unicode-segmentation",
] ]
[[package]]
name = "hermit-abi"
version = "0.1.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "322f4de77956e22ed0e5032c359a0f1273f1f7f0d79bfa3b8ffbc730d7fbcc5c"
dependencies = [
"libc",
]
[[package]] [[package]]
name = "hex" name = "hex"
version = "0.4.2" version = "0.4.2"
@ -753,6 +762,16 @@ dependencies = [
"autocfg", "autocfg",
] ]
[[package]]
name = "num_cpus"
version = "1.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "05499f3756671c15885fee9034446956fff3f243d6077b91e5767df161f766b3"
dependencies = [
"hermit-abi",
"libc",
]
[[package]] [[package]]
name = "object" name = "object"
version = "0.22.0" version = "0.22.0"
@ -1483,6 +1502,7 @@ dependencies = [
"libc", "libc",
"memchr", "memchr",
"mio", "mio",
"num_cpus",
"once_cell", "once_cell",
"pin-project-lite 0.2.4", "pin-project-lite 0.2.4",
"signal-hook-registry", "signal-hook-registry",

View File

@ -21,7 +21,7 @@ scopeguard = "1.0.0"
regex = "1" regex = "1"
async-trait = "0.1.42" async-trait = "0.1.42"
tokio = { version = "1.2.0", features = ["rt", "sync", "macros", "io-util", "time", "signal", "io-std", "process"] } tokio = { version = "1.2.0", features = ["rt", "rt-multi-thread", "sync", "macros", "io-util", "time", "signal", "io-std", "process", "fs"] }
futures = "0.3.12" futures = "0.3.12"
netlink-sys = { version = "0.6.0", features = ["tokio_socket",]} netlink-sys = { version = "0.6.0", features = ["tokio_socket",]}
tokio-vsock = "0.3.0" tokio-vsock = "0.3.0"

View File

@ -55,6 +55,7 @@ mod sandbox;
#[cfg(test)] #[cfg(test)]
mod test_utils; mod test_utils;
mod uevent; mod uevent;
mod util;
mod version; mod version;
use mount::{cgroups_mount, general_mount}; use mount::{cgroups_mount, general_mount};
@ -70,7 +71,11 @@ use rustjail::pipestream::PipeStream;
use tokio::{ use tokio::{
io::AsyncWrite, io::AsyncWrite,
signal::unix::{signal, SignalKind}, signal::unix::{signal, SignalKind},
sync::{oneshot::Sender, Mutex, RwLock}, sync::{
oneshot::Sender,
watch::{channel, Receiver},
Mutex, RwLock,
},
task::JoinHandle, task::JoinHandle,
}; };
use tokio_vsock::{Incoming, VsockListener, VsockStream}; use tokio_vsock::{Incoming, VsockListener, VsockStream};
@ -126,7 +131,7 @@ async fn get_vsock_stream(fd: RawFd) -> Result<VsockStream> {
// Create a thread to handle reading from the logger pipe. The thread will // Create a thread to handle reading from the logger pipe. The thread will
// output to the vsock port specified, or stdout. // output to the vsock port specified, or stdout.
async fn create_logger_task(rfd: RawFd, vsock_port: u32) -> Result<()> { async fn create_logger_task(rfd: RawFd, vsock_port: u32, shutdown: Receiver<bool>) -> Result<()> {
let mut reader = PipeStream::from_fd(rfd); let mut reader = PipeStream::from_fd(rfd);
let mut writer: Box<dyn AsyncWrite + Unpin + Send>; let mut writer: Box<dyn AsyncWrite + Unpin + Send>;
@ -147,7 +152,7 @@ async fn create_logger_task(rfd: RawFd, vsock_port: u32) -> Result<()> {
writer = Box::new(tokio::io::stdout()); writer = Box::new(tokio::io::stdout());
} }
let _ = tokio::io::copy(&mut reader, &mut writer).await; let _ = util::interruptable_io_copier(&mut reader, &mut writer, shutdown).await;
Ok(()) Ok(())
} }
@ -165,6 +170,8 @@ async fn real_main() -> std::result::Result<(), Box<dyn std::error::Error>> {
// support vsock log // support vsock log
let (rfd, wfd) = unistd::pipe2(OFlag::O_CLOEXEC)?; let (rfd, wfd) = unistd::pipe2(OFlag::O_CLOEXEC)?;
let (shutdown_tx, shutdown_rx) = channel(true);
let agent_config = AGENT_CONFIG.clone(); let agent_config = AGENT_CONFIG.clone();
let init_mode = unistd::getpid() == Pid::from_raw(1); let init_mode = unistd::getpid() == Pid::from_raw(1);
@ -203,7 +210,7 @@ async fn real_main() -> std::result::Result<(), Box<dyn std::error::Error>> {
let log_vport = config.log_vport as u32; let log_vport = config.log_vport as u32;
let log_handle = tokio::spawn(create_logger_task(rfd, log_vport)); let log_handle = tokio::spawn(create_logger_task(rfd, log_vport, shutdown_rx.clone()));
tasks.push(log_handle); tasks.push(log_handle);
@ -226,8 +233,15 @@ async fn real_main() -> std::result::Result<(), Box<dyn std::error::Error>> {
_log_guard = Ok(slog_stdlog::init().map_err(|e| e)?); _log_guard = Ok(slog_stdlog::init().map_err(|e| e)?);
} }
// Start the sandbox and wait for its ttRPC server to end
start_sandbox(&logger, &config, init_mode).await?; start_sandbox(&logger, &config, init_mode).await?;
// Trigger a controlled shutdown
shutdown_tx
.send(true)
.map_err(|e| anyhow!(e).context("failed to request shutdown"))?;
// Wait for all threads to finish
let results = join_all(tasks).await; let results = join_all(tasks).await;
for result in results { for result in results {
@ -236,6 +250,8 @@ async fn real_main() -> std::result::Result<(), Box<dyn std::error::Error>> {
} }
} }
eprintln!("{} shutdown complete", NAME);
Ok(()) Ok(())
} }
@ -259,7 +275,7 @@ fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
exit(0); exit(0);
} }
let rt = tokio::runtime::Builder::new_current_thread() let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all() .enable_all()
.build()?; .build()?;

342
src/agent/src/util.rs Normal file
View File

@ -0,0 +1,342 @@
// Copyright (c) 2021 Intel Corporation
//
// SPDX-License-Identifier: Apache-2.0
//
use std::io;
use std::io::ErrorKind;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::sync::watch::Receiver;
// Size (in bytes) of the buffer that interruptable_io_copier() reads into;
// this is the most data moved from reader to writer per loop iteration.
const BUF_SIZE: usize = 8192;
/// Interruptable I/O copy using readers and writers
/// (an interruptable version of "io::copy()").
///
/// Data is read from `reader` in chunks of up to `BUF_SIZE` bytes and each
/// chunk is written in full to `writer`. The copy stops when:
///
/// - `reader` reports end-of-file (a read of zero bytes), or
/// - the future returned by `shutdown.changed()` completes. Any completion
///   is treated as a shutdown request; the value carried by the channel is
///   not inspected.
///
/// In both cases `writer` is flushed before returning, so that data accepted
/// by a buffering writer is not lost when the caller drops it.
///
/// # Returns
///
/// The total number of bytes written to `writer`.
///
/// # Errors
///
/// Returns the first read error (other than `ErrorKind::Interrupted`, which
/// is retried), write error, or flush error encountered.
pub async fn interruptable_io_copier<R: Sized, W: Sized>(
    mut reader: R,
    mut writer: W,
    mut shutdown: Receiver<bool>,
) -> io::Result<u64>
where
    R: tokio::io::AsyncRead + Unpin,
    W: tokio::io::AsyncWrite + Unpin,
{
    let mut total_bytes: u64 = 0;
    let mut buf: [u8; BUF_SIZE] = [0; BUF_SIZE];

    loop {
        tokio::select! {
            _ = shutdown.changed() => {
                eprintln!("INFO: interruptable_io_copier: got shutdown request");
                break;
            },
            result = reader.read(&mut buf) => {
                let bytes = match result {
                    // EOF: leave the loop so the writer is flushed below.
                    Ok(0) => break,
                    Ok(len) => len,
                    Err(ref e) if e.kind() == ErrorKind::Interrupted => continue,
                    Err(e) => return Err(e),
                };

                // Actually copy the data ;)
                //
                // The write is awaited inside the branch handler, so it is
                // never cancelled half-way by a shutdown request.
                writer.write_all(&buf[..bytes]).await?;

                // Only count bytes once they have been handed to the writer.
                total_bytes += bytes as u64;
            },
        };
    }

    // Push out anything the writer is still buffering before it is dropped.
    writer.flush().await?;

    Ok(total_bytes)
}
#[cfg(test)]
mod tests {
use super::*;
use std::io;
use std::io::Cursor;
use std::io::Write;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Poll::Ready};
use tokio::pin;
use tokio::select;
use tokio::sync::watch::channel;
use tokio::task::JoinError;
use tokio::time::Duration;
/// In-memory writer used by the tests to observe what the copier wrote.
///
/// Clones share the same underlying byte vector (through the `Arc`), so a
/// test can hand one clone to the copier and inspect the data via another.
#[derive(Debug, Default, Clone)]
struct BufWriter {
    // Bytes written so far; shared between all clones of this writer.
    data: Arc<Mutex<Vec<u8>>>,
    // Artificial delay applied before every write; zero disables it.
    write_delay: Duration,
}

impl BufWriter {
    /// Create an empty writer with no write delay.
    fn new() -> Self {
        BufWriter {
            data: Arc::new(Mutex::new(Vec::new())),
            write_delay: Duration::new(0, 0),
        }
    }

    /// Append `buf` to the shared vector and return the number of bytes
    /// written (always `buf.len()`).
    ///
    /// Panics if the mutex protecting the data is poisoned.
    fn write_vec(&mut self, buf: &[u8]) -> io::Result<usize> {
        // Sleep *before* taking the lock so that a slow writer does not block
        // other clones from inspecting the data while it is delayed.
        if self.write_delay.as_nanos() > 0 {
            std::thread::sleep(self.write_delay);
        }

        let mut data = self.data.lock().unwrap();
        data.extend_from_slice(buf);

        Ok(buf.len())
    }
}
/// Synchronous `Write` support, backed by the shared in-memory vector.
impl Write for BufWriter {
    /// Append `buf` to the shared vector by delegating to `write_vec()`.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        self.write_vec(buf)
    }

    /// Flush the underlying vector while holding its lock.
    ///
    /// Panics if the mutex protecting the data is poisoned.
    fn flush(&mut self) -> io::Result<()> {
        self.data.lock().unwrap().flush()
    }
}
/// Asynchronous write support: none of these methods ever returns `Pending`.
impl tokio::io::AsyncWrite for BufWriter {
    /// Write `buf` via `write_vec()` and report the result as ready.
    fn poll_write(
        self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        Ready(self.get_mut().write_vec(buf))
    }

    /// Nothing to flush: always ready with success.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Ready(Ok(()))
    }

    /// Nothing to shut down: always ready with success.
    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Ready(Ok(()))
    }
}
/// Render the bytes written so far as UTF-8 text.
///
/// `Display` is implemented rather than `ToString` itself: the standard
/// library's blanket implementation still provides `to_string()`.
///
/// # Panics
///
/// Panics if the mutex protecting the data is poisoned, or if the written
/// bytes are not valid UTF-8.
impl std::fmt::Display for BufWriter {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let data = self.data.lock().unwrap();

        // Borrow the bytes in place rather than cloning the whole vector.
        let text = std::str::from_utf8(&data).expect("BufWriter contents are not valid UTF-8");

        f.write_str(text)
    }
}
/// Check that finite readers of various sizes (in particular around
/// multiples of `BUF_SIZE`) are copied in full, that the reported byte count
/// is correct, and that the copier exits by itself on EOF.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_interruptable_io_copier_reader() {
    #[derive(Debug)]
    struct TestData {
        // Data the reader will provide.
        reader_value: String,
        // Expected return value of the copier.
        result: io::Result<u64>,
    }

    let tests = &[
        TestData {
            reader_value: "".into(),
            result: Ok(0),
        },
        TestData {
            reader_value: "a".into(),
            result: Ok(1),
        },
        TestData {
            reader_value: "foo".into(),
            result: Ok(3),
        },
        TestData {
            reader_value: "b".repeat(BUF_SIZE - 1),
            result: Ok((BUF_SIZE - 1) as u64),
        },
        TestData {
            reader_value: "c".repeat(BUF_SIZE),
            result: Ok((BUF_SIZE) as u64),
        },
        TestData {
            reader_value: "d".repeat(BUF_SIZE + 1),
            result: Ok((BUF_SIZE + 1) as u64),
        },
        TestData {
            reader_value: "e".repeat((2 * BUF_SIZE) - 1),
            result: Ok(((2 * BUF_SIZE) - 1) as u64),
        },
        TestData {
            reader_value: "f".repeat(2 * BUF_SIZE),
            result: Ok((2 * BUF_SIZE) as u64),
        },
        TestData {
            reader_value: "g".repeat((2 * BUF_SIZE) + 1),
            result: Ok(((2 * BUF_SIZE) + 1) as u64),
        },
    ];

    for (i, d) in tests.iter().enumerate() {
        // Create a string containing details of the test
        let msg = format!("test[{}]: {:?}", i, d);

        // The sender is never used to request a shutdown here: it only has
        // to stay alive so the copier is stopped by EOF, not by the channel.
        let (tx, rx) = channel(true);
        let reader = Cursor::new(d.reader_value.clone());
        let writer = BufWriter::new();

        // XXX: Pass a copy of the writer to the copier to allow the
        // result of the write operation to be checked below.
        let handle = tokio::spawn(interruptable_io_copier(reader, writer.clone(), rx));

        // The reader is finite, so the copier must finish by itself well
        // within the timeout.
        let timeout = tokio::time::sleep(Duration::from_secs(1));
        pin!(timeout);

        let spawn_result: std::result::Result<io::Result<u64>, JoinError> = select! {
            res = handle => res,
            _ = &mut timeout => panic!("timed out"),
        };

        // The copier has returned, so the Receiver it owned has been
        // dropped. Checking this after the join (rather than after an
        // arbitrary sleep) makes the assertion independent of scheduling.
        assert!(tx.is_closed(), "{}", msg);

        let result = spawn_result.unwrap_or_else(|e| panic!("{}: task failed: {}", msg, e));
        let byte_count = result.unwrap_or_else(|e| panic!("{}: copy failed: {}", msg, e));

        let expected = *d
            .result
            .as_ref()
            .expect("test table only contains successful results");
        assert_eq!(byte_count, expected, "{}", msg);
        assert_eq!(byte_count as usize, d.reader_value.len(), "{}", msg);

        let value = writer.to_string();
        assert_eq!(value, d.reader_value, "{}", msg);
    }
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_interruptable_io_copier_eof() {
// Create an async reader that always returns EOF
let reader = tokio::io::empty();
let (tx, rx) = channel(true);
let writer = BufWriter::new();
let handle = tokio::spawn(interruptable_io_copier(reader, writer.clone(), rx));
// Allow time for the thread to be spawned.
tokio::time::sleep(Duration::from_secs(1)).await;
let timeout = tokio::time::sleep(Duration::from_secs(1));
pin!(timeout);
assert!(tx.is_closed());
let spawn_result: std::result::Result<std::result::Result<u64, std::io::Error>, JoinError>;
let result: std::result::Result<u64, std::io::Error>;
select! {
res = handle => spawn_result = res,
_ = &mut timeout => panic!("timed out"),
}
assert!(spawn_result.is_ok());
result = spawn_result.unwrap();
assert!(result.is_ok());
let byte_count = result.unwrap();
assert_eq!(byte_count, 0);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_interruptable_io_copier_shutdown() {
// Create an async reader that creates an infinite stream of bytes
// (which allows us to interrupt it, since we know it is always busy ;)
const REPEAT_CHAR: u8 = b'r';
let reader = tokio::io::repeat(REPEAT_CHAR);
let (tx, rx) = channel(true);
let writer = BufWriter::new();
let handle = tokio::spawn(interruptable_io_copier(reader, writer.clone(), rx));
// Allow time for the thread to be spawned.
tokio::time::sleep(Duration::from_secs(1)).await;
let timeout = tokio::time::sleep(Duration::from_secs(1));
pin!(timeout);
assert!(!tx.is_closed());
tx.send(true).expect("failed to request shutdown");
let spawn_result: std::result::Result<std::result::Result<u64, std::io::Error>, JoinError>;
let result: std::result::Result<u64, std::io::Error>;
select! {
res = handle => spawn_result = res,
_ = &mut timeout => panic!("timed out"),
}
assert!(spawn_result.is_ok());
result = spawn_result.unwrap();
assert!(result.is_ok());
let byte_count = result.unwrap();
let value = writer.to_string();
let writer_byte_count = value.len() as u64;
assert_eq!(byte_count, writer_byte_count);
// Remove the char used as a payload. If anything else remins,
// something went wrong.
let mut remainder = value;
remainder.retain(|c| c != REPEAT_CHAR as char);
assert_eq!(remainder.len(), 0);
}
}