Mirror of https://github.com/kata-containers/kata-containers.git (synced 2025-09-05 10:50:18 +00:00)
Merge pull request #1535 from jodh-intel/agent-shutdown
agent: Enable clean shutdown
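The change lets the agent shut down cleanly: each long-running task (signal handling, uevent watching, log forwarding) is handed a tokio watch channel it can be interrupted through, and main() now waits for every spawned task to finish before exiting. A minimal, illustrative sketch of that pattern follows; the task and variable names are hypothetical, not the agent's actual code:

use std::time::Duration;
use tokio::sync::watch;

// Illustrative sketch of the shutdown pattern used throughout this change:
// each long-running task select!s on a watch::Receiver and stops when a
// shutdown value is sent. Names here are hypothetical.
async fn example_task(mut shutdown: watch::Receiver<bool>) {
    loop {
        tokio::select! {
            _ = shutdown.changed() => break, // shutdown requested
            _ = tokio::time::sleep(Duration::from_secs(1)) => {
                // ... one unit of real work would go here ...
            }
        }
    }
}

#[tokio::main]
async fn main() {
    let (shutdown_tx, shutdown_rx) = watch::channel(true);
    let task = tokio::spawn(example_task(shutdown_rx));

    // ... run until it is time to stop ...

    let _ = shutdown_tx.send(true); // request shutdown
    let _ = task.await;             // wait for the task to finish
}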
@@ -21,7 +21,12 @@ const LOG_LEVELS: &[(&str, slog::Level)] = &[
 ];
 
 // XXX: 'writer' param used to make testing possible.
-pub fn create_logger<W>(name: &str, source: &str, level: slog::Level, writer: W) -> slog::Logger
+pub fn create_logger<W>(
+    name: &str,
+    source: &str,
+    level: slog::Level,
+    writer: W,
+) -> (slog::Logger, slog_async::AsyncGuard)
 where
     W: Write + Send + Sync + 'static,
 {
@@ -37,17 +42,21 @@ where
     let filter_drain = RuntimeLevelFilter::new(unique_drain, level).fuse();
 
     // Ensure the logger is thread-safe
-    let async_drain = slog_async::Async::new(filter_drain).build().fuse();
+    let (async_drain, guard) = slog_async::Async::new(filter_drain)
+        .thread_name("slog-async-logger".into())
+        .build_with_guard();
 
     // Add some "standard" fields
-    slog::Logger::root(
+    let logger = slog::Logger::root(
         async_drain.fuse(),
         o!("version" => env!("CARGO_PKG_VERSION"),
            "subsystem" => "root",
            "pid" => process::id().to_string(),
            "name" => name.to_string(),
            "source" => source.to_string()),
-    )
+    );
+
+    (logger, guard)
 }
 
 pub fn get_log_levels() -> Vec<&'static str> {
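Callers of the new create_logger() hold the returned slog_async::AsyncGuard for as long as they log and drop it during shutdown so the async drain flushes and its worker thread exits, as the agent and tool changes below do. A rough usage sketch (names illustrative only):

// Sketch only: keep the guard alive while logging, drop it on shutdown so
// queued records are flushed and the slog-async worker thread stops.
let (logger, logger_async_guard) =
    logging::create_logger("example", "agent", slog::Level::Info, std::io::stdout());

slog::info!(logger, "starting");

// ... normal operation ...

drop(logger_async_guard); // flush and stop the async drain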
src/agent/Cargo.lock (generated, 20 lines changed)
@@ -387,6 +387,15 @@ dependencies = [
  "unicode-segmentation",
 ]
 
+[[package]]
+name = "hermit-abi"
+version = "0.1.18"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "322f4de77956e22ed0e5032c359a0f1273f1f7f0d79bfa3b8ffbc730d7fbcc5c"
+dependencies = [
+ "libc",
+]
+
 [[package]]
 name = "hex"
 version = "0.4.2"
@@ -753,6 +762,16 @@ dependencies = [
  "autocfg",
 ]
 
+[[package]]
+name = "num_cpus"
+version = "1.13.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "05499f3756671c15885fee9034446956fff3f243d6077b91e5767df161f766b3"
+dependencies = [
+ "hermit-abi",
+ "libc",
+]
+
 [[package]]
 name = "object"
 version = "0.22.0"
@@ -1483,6 +1502,7 @@ dependencies = [
  "libc",
  "memchr",
  "mio",
+ "num_cpus",
  "once_cell",
  "pin-project-lite 0.2.4",
  "signal-hook-registry",
@@ -21,7 +21,7 @@ scopeguard = "1.0.0"
 regex = "1"
 
 async-trait = "0.1.42"
-tokio = { version = "1.2.0", features = ["rt", "sync", "macros", "io-util", "time", "signal", "io-std", "process"] }
+tokio = { version = "1.2.0", features = ["rt", "rt-multi-thread", "sync", "macros", "io-util", "time", "signal", "io-std", "process", "fs"] }
 futures = "0.3.12"
 netlink-sys = { version = "0.6.0", features = ["tokio_socket",]}
 tokio-vsock = "0.3.0"
@@ -337,12 +337,15 @@ mod tests {
                 assert!(*expected_level == actual_level, $msg);
             } else {
                 let expected_error = $expected_result.as_ref().unwrap_err();
-                let actual_error = $actual_result.unwrap_err();
                 let expected_error_msg = format!("{:?}", expected_error);
 
-                let actual_error_msg = format!("{:?}", actual_error);
-
-                assert!(expected_error_msg == actual_error_msg, $msg);
+                if let Err(actual_error) = $actual_result {
+                    let actual_error_msg = format!("{:?}", actual_error);
+
+                    assert!(expected_error_msg == actual_error_msg, $msg);
+                } else {
+                    assert!(expected_error_msg == "expected error, got OK", $msg);
+                }
             }
         };
     }
@@ -26,9 +26,8 @@ use nix::libc::{STDERR_FILENO, STDIN_FILENO, STDOUT_FILENO};
 use nix::pty;
 use nix::sys::select::{select, FdSet};
 use nix::sys::socket::{self, AddressFamily, SockAddr, SockFlag, SockType};
-use nix::sys::wait::{self, WaitStatus};
+use nix::sys::wait;
 use nix::unistd::{self, close, dup, dup2, fork, setsid, ForkResult};
-use prctl::set_child_subreaper;
 use std::collections::HashMap;
 use std::env;
 use std::ffi::{CStr, CString, OsStr};
@@ -52,23 +51,32 @@ mod network;
 mod pci;
 pub mod random;
 mod sandbox;
+mod signal;
 #[cfg(test)]
 mod test_utils;
 mod uevent;
+mod util;
 mod version;
 
 use mount::{cgroups_mount, general_mount};
 use sandbox::Sandbox;
+use signal::setup_signal_handler;
 use slog::Logger;
 use uevent::watch_uevents;
 
 use std::sync::Mutex as SyncMutex;
 
+use futures::future::join_all;
 use futures::StreamExt as _;
 use rustjail::pipestream::PipeStream;
 use tokio::{
-    signal::unix::{signal, SignalKind},
-    sync::{oneshot::Sender, Mutex, RwLock},
+    io::AsyncWrite,
+    sync::{
+        oneshot::Sender,
+        watch::{channel, Receiver},
+        Mutex, RwLock,
+    },
+    task::JoinHandle,
 };
 use tokio_vsock::{Incoming, VsockListener, VsockStream};
 
@@ -121,6 +129,146 @@ async fn get_vsock_stream(fd: RawFd) -> Result<VsockStream> {
     Ok(stream)
 }
 
+// Create a thread to handle reading from the logger pipe. The thread will
+// output to the vsock port specified, or stdout.
+async fn create_logger_task(rfd: RawFd, vsock_port: u32, shutdown: Receiver<bool>) -> Result<()> {
+    let mut reader = PipeStream::from_fd(rfd);
+    let mut writer: Box<dyn AsyncWrite + Unpin + Send>;
+
+    if vsock_port > 0 {
+        let listenfd = socket::socket(
+            AddressFamily::Vsock,
+            SockType::Stream,
+            SockFlag::SOCK_CLOEXEC,
+            None,
+        )?;
+
+        let addr = SockAddr::new_vsock(libc::VMADDR_CID_ANY, vsock_port);
+        socket::bind(listenfd, &addr).unwrap();
+        socket::listen(listenfd, 1).unwrap();
+
+        writer = Box::new(get_vsock_stream(listenfd).await.unwrap());
+    } else {
+        writer = Box::new(tokio::io::stdout());
+    }
+
+    let _ = util::interruptable_io_copier(&mut reader, &mut writer, shutdown).await;
+
+    Ok(())
+}
+
+async fn real_main() -> std::result::Result<(), Box<dyn std::error::Error>> {
+    env::set_var("RUST_BACKTRACE", "full");
+
+    // List of tasks that need to be stopped for a clean shutdown
+    let mut tasks: Vec<JoinHandle<Result<()>>> = vec![];
+
+    lazy_static::initialize(&SHELLS);
+
+    lazy_static::initialize(&AGENT_CONFIG);
+
+    // support vsock log
+    let (rfd, wfd) = unistd::pipe2(OFlag::O_CLOEXEC)?;
+
+    let (shutdown_tx, shutdown_rx) = channel(true);
+
+    let agent_config = AGENT_CONFIG.clone();
+
+    let init_mode = unistd::getpid() == Pid::from_raw(1);
+    if init_mode {
+        // dup a new file descriptor for this temporary logger writer,
+        // since this logger would be dropped and it's writer would
+        // be closed out of this code block.
+        let newwfd = dup(wfd)?;
+        let writer = unsafe { File::from_raw_fd(newwfd) };
+
+        // Init a temporary logger used by init agent as init process
+        // since before do the base mount, it wouldn't access "/proc/cmdline"
+        // to get the customzied debug level.
+        let (logger, logger_async_guard) =
+            logging::create_logger(NAME, "agent", slog::Level::Debug, writer);
+
+        // Must mount proc fs before parsing kernel command line
+        general_mount(&logger).map_err(|e| {
+            error!(logger, "fail general mount: {}", e);
+            e
+        })?;
+
+        let mut config = agent_config.write().await;
+        config.parse_cmdline(KERNEL_CMDLINE_FILE)?;
+
+        init_agent_as_init(&logger, config.unified_cgroup_hierarchy)?;
+        drop(logger_async_guard);
+    } else {
+        // once parsed cmdline and set the config, release the write lock
+        // as soon as possible in case other thread would get read lock on
+        // it.
+        let mut config = agent_config.write().await;
+        config.parse_cmdline(KERNEL_CMDLINE_FILE)?;
+    }
+    let config = agent_config.read().await;
+
+    let log_vport = config.log_vport as u32;
+
+    let log_handle = tokio::spawn(create_logger_task(rfd, log_vport, shutdown_rx.clone()));
+
+    tasks.push(log_handle);
+
+    let writer = unsafe { File::from_raw_fd(wfd) };
+
+    // Recreate a logger with the log level get from "/proc/cmdline".
+    let (logger, logger_async_guard) =
+        logging::create_logger(NAME, "agent", config.log_level, writer);
+
+    announce(&logger, &config);
+
+    // This variable is required as it enables the global (and crucially static) logger,
+    // which is required to satisfy the the lifetime constraints of the auto-generated gRPC code.
+    let global_logger = slog_scope::set_global_logger(logger.new(o!("subsystem" => "rpc")));
+
+    // Allow the global logger to be modified later (for shutdown)
+    global_logger.cancel_reset();
+
+    let mut ttrpc_log_guard: Result<(), log::SetLoggerError> = Ok(());
+
+    if config.log_level == slog::Level::Trace {
+        // Redirect ttrpc log calls to slog iff full debug requested
+        ttrpc_log_guard = Ok(slog_stdlog::init().map_err(|e| e)?);
+    }
+
+    // Start the sandbox and wait for its ttRPC server to end
+    start_sandbox(&logger, &config, init_mode, &mut tasks, shutdown_rx.clone()).await?;
+
+    // Install a NOP logger for the remainder of the shutdown sequence
+    // to ensure any log calls made by local crates using the scope logger
+    // don't fail.
+    let global_logger_guard2 =
+        slog_scope::set_global_logger(slog::Logger::root(slog::Discard, o!()));
+    global_logger_guard2.cancel_reset();
+
+    drop(logger_async_guard);
+
+    drop(ttrpc_log_guard);
+
+    // Trigger a controlled shutdown
+    shutdown_tx
+        .send(true)
+        .map_err(|e| anyhow!(e).context("failed to request shutdown"))?;
+
+    // Wait for all threads to finish
+    let results = join_all(tasks).await;
+
+    for result in results {
+        if let Err(e) = result {
+            return Err(anyhow!(e).into());
+        }
+    }
+
+    eprintln!("{} shutdown complete", NAME);
+
+    Ok(())
+}
+
 fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
     let args: Vec<String> = env::args().collect();
 
@@ -141,111 +289,20 @@ fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
         exit(0);
     }
 
-    let rt = tokio::runtime::Builder::new_current_thread()
+    let rt = tokio::runtime::Builder::new_multi_thread()
         .enable_all()
         .build()?;
 
-    rt.block_on(async {
-        env::set_var("RUST_BACKTRACE", "full");
-
-        lazy_static::initialize(&SHELLS);
-
-        lazy_static::initialize(&AGENT_CONFIG);
-
-        // support vsock log
-        let (rfd, wfd) = unistd::pipe2(OFlag::O_CLOEXEC)?;
-
-        let agent_config = AGENT_CONFIG.clone();
-
-        let init_mode = unistd::getpid() == Pid::from_raw(1);
-        if init_mode {
-            // dup a new file descriptor for this temporary logger writer,
-            // since this logger would be dropped and it's writer would
-            // be closed out of this code block.
-            let newwfd = dup(wfd)?;
-            let writer = unsafe { File::from_raw_fd(newwfd) };
-
-            // Init a temporary logger used by init agent as init process
-            // since before do the base mount, it wouldn't access "/proc/cmdline"
-            // to get the customzied debug level.
-            let logger = logging::create_logger(NAME, "agent", slog::Level::Debug, writer);
-
-            // Must mount proc fs before parsing kernel command line
-            general_mount(&logger).map_err(|e| {
-                error!(logger, "fail general mount: {}", e);
-                e
-            })?;
-
-            let mut config = agent_config.write().await;
-            config.parse_cmdline(KERNEL_CMDLINE_FILE)?;
-
-            init_agent_as_init(&logger, config.unified_cgroup_hierarchy)?;
-        } else {
-            // once parsed cmdline and set the config, release the write lock
-            // as soon as possible in case other thread would get read lock on
-            // it.
-            let mut config = agent_config.write().await;
-            config.parse_cmdline(KERNEL_CMDLINE_FILE)?;
-        }
-        let config = agent_config.read().await;
-
-        let log_vport = config.log_vport as u32;
-        let log_handle = tokio::spawn(async move {
-            let mut reader = PipeStream::from_fd(rfd);
-
-            if log_vport > 0 {
-                let listenfd = socket::socket(
-                    AddressFamily::Vsock,
-                    SockType::Stream,
-                    SockFlag::SOCK_CLOEXEC,
-                    None,
-                )
-                .unwrap();
-
-                let addr = SockAddr::new_vsock(libc::VMADDR_CID_ANY, log_vport);
-                socket::bind(listenfd, &addr).unwrap();
-                socket::listen(listenfd, 1).unwrap();
-
-                let mut vsock_stream = get_vsock_stream(listenfd).await.unwrap();
-
-                // copy log to stdout
-                tokio::io::copy(&mut reader, &mut vsock_stream)
-                    .await
-                    .unwrap();
-            }
-
-            // copy log to stdout
-            let mut stdout_writer = tokio::io::stdout();
-            let _ = tokio::io::copy(&mut reader, &mut stdout_writer).await;
-        });
-
-        let writer = unsafe { File::from_raw_fd(wfd) };
-
-        // Recreate a logger with the log level get from "/proc/cmdline".
-        let logger = logging::create_logger(NAME, "agent", config.log_level, writer);
-
-        announce(&logger, &config);
-
-        // This "unused" variable is required as it enables the global (and crucially static) logger,
-        // which is required to satisfy the the lifetime constraints of the auto-generated gRPC code.
-        let _guard = slog_scope::set_global_logger(logger.new(o!("subsystem" => "rpc")));
-
-        let mut _log_guard: Result<(), log::SetLoggerError> = Ok(());
-
-        if config.log_level == slog::Level::Trace {
-            // Redirect ttrpc log calls to slog iff full debug requested
-            _log_guard = Ok(slog_stdlog::init().map_err(|e| e)?);
-        }
-
-        start_sandbox(&logger, &config, init_mode).await?;
-
-        let _ = log_handle.await.unwrap();
-
-        Ok(())
-    })
+    rt.block_on(real_main())
 }
 
-async fn start_sandbox(logger: &Logger, config: &AgentConfig, init_mode: bool) -> Result<()> {
+async fn start_sandbox(
+    logger: &Logger,
+    config: &AgentConfig,
+    init_mode: bool,
+    tasks: &mut Vec<JoinHandle<Result<()>>>,
+    shutdown: Receiver<bool>,
+) -> Result<()> {
     let shells = SHELLS.clone();
     let debug_console_vport = config.debug_console_vport as u32;
 
@@ -275,10 +332,17 @@ async fn start_sandbox(logger: &Logger, config: &AgentConfig, init_mode: bool) -
 
     let sandbox = Arc::new(Mutex::new(s));
 
-    setup_signal_handler(&logger, sandbox.clone())
-        .await
-        .unwrap();
-    watch_uevents(sandbox.clone()).await;
+    let signal_handler_task = tokio::spawn(setup_signal_handler(
+        logger.clone(),
+        sandbox.clone(),
+        shutdown.clone(),
+    ));
+
+    tasks.push(signal_handler_task);
+
+    let uevents_handler_task = tokio::spawn(watch_uevents(sandbox.clone(), shutdown.clone()));
+
+    tasks.push(uevents_handler_task);
 
     let (tx, rx) = tokio::sync::oneshot::channel();
     sandbox.lock().await.sender = Some(tx);
@@ -297,93 +361,6 @@ async fn start_sandbox(logger: &Logger, config: &AgentConfig, init_mode: bool) -
     Ok(())
 }
 
-use nix::sys::wait::WaitPidFlag;
-
-async fn setup_signal_handler(logger: &Logger, sandbox: Arc<Mutex<Sandbox>>) -> Result<()> {
-    let logger = logger.new(o!("subsystem" => "signals"));
-
-    set_child_subreaper(true)
-        .map_err(|err| anyhow!(err).context("failed to setup agent as a child subreaper"))?;
-
-    let mut signal_stream = signal(SignalKind::child())?;
-
-    tokio::spawn(async move {
-        'outer: loop {
-            signal_stream.recv().await;
-            info!(logger, "received signal"; "signal" => "SIGCHLD");
-
-            // sevral signals can be combined together
-            // as one. So loop around to reap all
-            // exited children
-            'inner: loop {
-                let wait_status = match wait::waitpid(
-                    Some(Pid::from_raw(-1)),
-                    Some(WaitPidFlag::WNOHANG | WaitPidFlag::__WALL),
-                ) {
-                    Ok(s) => {
-                        if s == WaitStatus::StillAlive {
-                            continue 'outer;
-                        }
-                        s
-                    }
-                    Err(e) => {
-                        info!(
-                            logger,
-                            "waitpid reaper failed";
-                            "error" => e.as_errno().unwrap().desc()
-                        );
-                        continue 'outer;
-                    }
-                };
-                info!(logger, "wait_status"; "wait_status result" => format!("{:?}", wait_status));
-
-                let pid = wait_status.pid();
-                if let Some(pid) = pid {
-                    let raw_pid = pid.as_raw();
-                    let child_pid = format!("{}", raw_pid);
-
-                    let logger = logger.new(o!("child-pid" => child_pid));
-
-                    let mut sandbox = sandbox.lock().await;
-                    let process = sandbox.find_process(raw_pid);
-                    if process.is_none() {
-                        info!(logger, "child exited unexpectedly");
-                        continue 'inner;
-                    }
-
-                    let mut p = process.unwrap();
-
-                    if p.exit_pipe_w.is_none() {
-                        error!(logger, "the process's exit_pipe_w isn't set");
-                        continue 'inner;
-                    }
-                    let pipe_write = p.exit_pipe_w.unwrap();
-                    let ret: i32;
-
-                    match wait_status {
-                        WaitStatus::Exited(_, c) => ret = c,
-                        WaitStatus::Signaled(_, sig, _) => ret = sig as i32,
-                        _ => {
-                            info!(logger, "got wrong status for process";
-                                  "child-status" => format!("{:?}", wait_status));
-                            continue 'inner;
-                        }
-                    }
-
-                    p.exit_code = ret;
-                    let _ = unistd::close(pipe_write);
-
-                    info!(logger, "notify term to close");
-                    // close the socket file to notify readStdio to close terminal specifically
-                    // in case this process's terminal has been inherited by its children.
-                    p.notify_term_close();
-                }
-            }
-        }
-    });
-    Ok(())
-}
-
 // init_agent_as_init will do the initializations such as setting up the rootfs
 // when this agent has been run as the init process.
 fn init_agent_as_init(logger: &Logger, unified_cgroup_hierarchy: bool) -> Result<()> {
src/agent/src/signal.rs (new file, 159 lines)
// Copyright (c) 2019-2020 Ant Financial
// Copyright (c) 2020 Intel Corporation
//
// SPDX-License-Identifier: Apache-2.0
//

use crate::sandbox::Sandbox;
use anyhow::{anyhow, Result};
use nix::sys::wait::WaitPidFlag;
use nix::sys::wait::{self, WaitStatus};
use nix::unistd;
use prctl::set_child_subreaper;
use slog::{error, info, o, Logger};
use std::sync::Arc;
use tokio::select;
use tokio::signal::unix::{signal, SignalKind};
use tokio::sync::watch::Receiver;
use tokio::sync::Mutex;
use unistd::Pid;

async fn handle_sigchild(logger: Logger, sandbox: Arc<Mutex<Sandbox>>) -> Result<()> {
    info!(logger, "handling signal"; "signal" => "SIGCHLD");

    loop {
        let result = wait::waitpid(
            Some(Pid::from_raw(-1)),
            Some(WaitPidFlag::WNOHANG | WaitPidFlag::__WALL),
        );

        let wait_status = match result {
            Ok(s) => {
                if s == WaitStatus::StillAlive {
                    return Ok(());
                }
                s
            }
            Err(e) => return Err(anyhow!(e).context("waitpid reaper failed")),
        };

        info!(logger, "wait_status"; "wait_status result" => format!("{:?}", wait_status));

        if let Some(pid) = wait_status.pid() {
            let raw_pid = pid.as_raw();
            let child_pid = format!("{}", raw_pid);

            let logger = logger.new(o!("child-pid" => child_pid));

            let sandbox_ref = sandbox.clone();
            let mut sandbox = sandbox_ref.lock().await;

            let process = sandbox.find_process(raw_pid);
            if process.is_none() {
                info!(logger, "child exited unexpectedly");
                continue;
            }

            let mut p = process.unwrap();

            if p.exit_pipe_w.is_none() {
                info!(logger, "process exit pipe not set");
                continue;
            }

            let pipe_write = p.exit_pipe_w.unwrap();
            let ret: i32;

            match wait_status {
                WaitStatus::Exited(_, c) => ret = c,
                WaitStatus::Signaled(_, sig, _) => ret = sig as i32,
                _ => {
                    info!(logger, "got wrong status for process";
                          "child-status" => format!("{:?}", wait_status));
                    continue;
                }
            }

            p.exit_code = ret;
            let _ = unistd::close(pipe_write);

            info!(logger, "notify term to close");
            // close the socket file to notify readStdio to close terminal specifically
            // in case this process's terminal has been inherited by its children.
            p.notify_term_close();
        }
    }
}

pub async fn setup_signal_handler(
    logger: Logger,
    sandbox: Arc<Mutex<Sandbox>>,
    mut shutdown: Receiver<bool>,
) -> Result<()> {
    let logger = logger.new(o!("subsystem" => "signals"));

    set_child_subreaper(true)
        .map_err(|err| anyhow!(err).context("failed to setup agent as a child subreaper"))?;

    let mut sigchild_stream = signal(SignalKind::child())?;

    loop {
        select! {
            _ = shutdown.changed() => {
                info!(logger, "got shutdown request");
                break;
            }

            _ = sigchild_stream.recv() => {
                let result = handle_sigchild(logger.clone(), sandbox.clone()).await;

                match result {
                    Ok(()) => (),
                    Err(e) => {
                        // Log errors, but don't abort - just wait for more signals!
                        error!(logger, "failed to handle signal"; "error" => format!("{:?}", e));
                    }
                }
            }
        }
    }

    Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;
    use tokio::pin;
    use tokio::sync::watch::channel;
    use tokio::time::Duration;

    #[tokio::test]
    async fn test_setup_signal_handler() {
        let logger = slog::Logger::root(slog::Discard, o!());
        let s = Sandbox::new(&logger).unwrap();

        let sandbox = Arc::new(Mutex::new(s));

        let (tx, rx) = channel(true);

        let handle = tokio::spawn(setup_signal_handler(logger, sandbox, rx));

        let timeout = tokio::time::sleep(Duration::from_secs(1));
        pin!(timeout);

        tx.send(true).expect("failed to request shutdown");

        loop {
            select! {
                _ = handle => {
                    println!("INFO: task completed");
                    break;
                },

                _ = &mut timeout => {
                    panic!("signal thread failed to stop");
                }
            }
        }
    }
}
@@ -9,10 +9,13 @@ use crate::sandbox::Sandbox;
 use crate::GLOBAL_DEVICE_WATCHER;
 use slog::Logger;
 
+use anyhow::Result;
 use netlink_sys::{protocols, SocketAddr, TokioSocket};
 use nix::errno::Errno;
 use std::os::unix::io::FromRawFd;
 use std::sync::Arc;
+use tokio::select;
+use tokio::sync::watch::Receiver;
 use tokio::sync::Mutex;
 
 #[derive(Debug, Default)]
@@ -132,13 +135,21 @@ impl Uevent {
     }
 }
 
-pub async fn watch_uevents(sandbox: Arc<Mutex<Sandbox>>) {
+pub async fn watch_uevents(
+    sandbox: Arc<Mutex<Sandbox>>,
+    mut shutdown: Receiver<bool>,
+) -> Result<()> {
     let sref = sandbox.clone();
     let s = sref.lock().await;
     let logger = s.logger.new(o!("subsystem" => "uevent"));
 
-    tokio::spawn(async move {
+    // Unlock the sandbox to allow a successful shutdown
+    drop(s);
+
+    info!(logger, "starting uevents handler");
+
     let mut socket;
 
     unsafe {
         let fd = libc::socket(
             libc::AF_NETLINK,
@@ -147,19 +158,26 @@ pub async fn watch_uevents(sandbox: Arc<Mutex<Sandbox>>) {
         );
         socket = TokioSocket::from_raw_fd(fd);
     }
-    socket.bind(&SocketAddr::new(0, 1)).unwrap();
+
+    socket.bind(&SocketAddr::new(0, 1))?;
 
     loop {
-        match socket.recv_from_full().await {
+        select! {
+            _ = shutdown.changed() => {
+                info!(logger, "got shutdown request");
+                break;
+            }
+            result = socket.recv_from_full() => {
+                match result {
                     Err(e) => {
-                        error!(logger, "receive uevent message failed"; "error" => format!("{}", e))
+                        error!(logger, "failed to receive uevent"; "error" => format!("{}", e))
                     }
                     Ok((buf, addr)) => {
                         if addr.port_number() != 0 {
                             // not our netlink message
                             let err_msg = format!("{:?}", nix::Error::Sys(Errno::EBADMSG));
                             error!(logger, "receive uevent message failed"; "error" => err_msg);
-                            return;
+                            continue;
                         }
 
                         let text = String::from_utf8(buf);
@@ -176,5 +194,8 @@ pub async fn watch_uevents(sandbox: Arc<Mutex<Sandbox>>) {
                     }
                 }
             }
-    });
+        }
+    }
+
+    Ok(())
 }
src/agent/src/util.rs (new file, 342 lines)
// Copyright (c) 2021 Intel Corporation
//
// SPDX-License-Identifier: Apache-2.0
//

use std::io;
use std::io::ErrorKind;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::sync::watch::Receiver;

// Size of I/O read buffer
const BUF_SIZE: usize = 8192;

// Interruptable I/O copy using readers and writers
// (an interruptable version of "io::copy()").
pub async fn interruptable_io_copier<R: Sized, W: Sized>(
    mut reader: R,
    mut writer: W,
    mut shutdown: Receiver<bool>,
) -> io::Result<u64>
where
    R: tokio::io::AsyncRead + Unpin,
    W: tokio::io::AsyncWrite + Unpin,
{
    let mut total_bytes: u64 = 0;

    let mut buf: [u8; BUF_SIZE] = [0; BUF_SIZE];

    loop {
        tokio::select! {
            _ = shutdown.changed() => {
                eprintln!("INFO: interruptable_io_copier: got shutdown request");
                break;
            },

            result = reader.read(&mut buf) => {
                let bytes = match result {
                    Ok(0) => return Ok(total_bytes),
                    Ok(len) => len,
                    Err(ref e) if e.kind() == ErrorKind::Interrupted => continue,
                    Err(e) => return Err(e),
                };

                total_bytes += bytes as u64;

                // Actually copy the data ;)
                writer.write_all(&buf[..bytes]).await?;
            },
        };
    }

    Ok(total_bytes)
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::io;
    use std::io::Cursor;
    use std::io::Write;
    use std::pin::Pin;
    use std::sync::{Arc, Mutex};
    use std::task::{Context, Poll, Poll::Ready};
    use tokio::pin;
    use tokio::select;
    use tokio::sync::watch::channel;
    use tokio::task::JoinError;
    use tokio::time::Duration;

    #[derive(Debug, Default, Clone)]
    struct BufWriter {
        data: Arc<Mutex<Vec<u8>>>,
        slow_write: bool,
        write_delay: Duration,
    }

    impl BufWriter {
        fn new() -> Self {
            BufWriter {
                data: Arc::new(Mutex::new(Vec::<u8>::new())),
                slow_write: false,
                write_delay: Duration::new(0, 0),
            }
        }

        fn write_vec(&mut self, buf: &[u8]) -> io::Result<usize> {
            let vec_ref = self.data.clone();

            let mut vec_locked = vec_ref.lock();

            let mut v = vec_locked.as_deref_mut().unwrap();

            if self.write_delay.as_nanos() > 0 {
                std::thread::sleep(self.write_delay);
            }

            std::io::Write::write(&mut v, buf)
        }
    }

    impl Write for BufWriter {
        fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
            self.write_vec(buf)
        }

        fn flush(&mut self) -> io::Result<()> {
            let vec_ref = self.data.clone();

            let mut vec_locked = vec_ref.lock();

            let v = vec_locked.as_deref_mut().unwrap();

            std::io::Write::flush(v)
        }
    }

    impl tokio::io::AsyncWrite for BufWriter {
        fn poll_write(
            mut self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<Result<usize, io::Error>> {
            let result = self.write_vec(buf);

            Ready(result)
        }

        fn poll_flush(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<(), std::io::Error>> {
            // NOP
            Ready(Ok(()))
        }

        fn poll_shutdown(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<(), std::io::Error>> {
            // NOP
            Ready(Ok(()))
        }
    }

    impl ToString for BufWriter {
        fn to_string(&self) -> String {
            let data_ref = self.data.clone();
            let output = data_ref.lock().unwrap();
            let s = (*output).clone();

            String::from_utf8(s).unwrap()
        }
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_interruptable_io_copier_reader() {
        #[derive(Debug)]
        struct TestData {
            reader_value: String,
            result: io::Result<u64>,
        }

        let tests = &[
            TestData {
                reader_value: "".into(),
                result: Ok(0),
            },
            TestData {
                reader_value: "a".into(),
                result: Ok(1),
            },
            TestData {
                reader_value: "foo".into(),
                result: Ok(3),
            },
            TestData {
                reader_value: "b".repeat(BUF_SIZE - 1),
                result: Ok((BUF_SIZE - 1) as u64),
            },
            TestData {
                reader_value: "c".repeat(BUF_SIZE),
                result: Ok((BUF_SIZE) as u64),
            },
            TestData {
                reader_value: "d".repeat(BUF_SIZE + 1),
                result: Ok((BUF_SIZE + 1) as u64),
            },
            TestData {
                reader_value: "e".repeat((2 * BUF_SIZE) - 1),
                result: Ok(((2 * BUF_SIZE) - 1) as u64),
            },
            TestData {
                reader_value: "f".repeat(2 * BUF_SIZE),
                result: Ok((2 * BUF_SIZE) as u64),
            },
            TestData {
                reader_value: "g".repeat((2 * BUF_SIZE) + 1),
                result: Ok(((2 * BUF_SIZE) + 1) as u64),
            },
        ];

        for (i, d) in tests.iter().enumerate() {
            // Create a string containing details of the test
            let msg = format!("test[{}]: {:?}", i, d);

            let (tx, rx) = channel(true);
            let reader = Cursor::new(d.reader_value.clone());
            let writer = BufWriter::new();

            // XXX: Pass a copy of the writer to the copier to allow the
            // result of the write operation to be checked below.
            let handle = tokio::spawn(interruptable_io_copier(reader, writer.clone(), rx));

            // Allow time for the thread to be spawned.
            tokio::time::sleep(Duration::from_secs(1)).await;

            let timeout = tokio::time::sleep(Duration::from_secs(1));
            pin!(timeout);

            // Since the readers only specify a small number of bytes, the
            // copier will quickly read zero and kill the task, closing the
            // Receiver.
            assert!(tx.is_closed(), "{}", msg);

            let spawn_result: std::result::Result<
                std::result::Result<u64, std::io::Error>,
                JoinError,
            >;

            let result: std::result::Result<u64, std::io::Error>;

            select! {
                res = handle => spawn_result = res,
                _ = &mut timeout => panic!("timed out"),
            }

            assert!(spawn_result.is_ok());

            result = spawn_result.unwrap();

            assert!(result.is_ok());

            let byte_count = result.unwrap() as usize;
            assert_eq!(byte_count, d.reader_value.len(), "{}", msg);

            let value = writer.to_string();
            assert_eq!(value, d.reader_value, "{}", msg);
        }
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_interruptable_io_copier_eof() {
        // Create an async reader that always returns EOF
        let reader = tokio::io::empty();

        let (tx, rx) = channel(true);
        let writer = BufWriter::new();

        let handle = tokio::spawn(interruptable_io_copier(reader, writer.clone(), rx));

        // Allow time for the thread to be spawned.
        tokio::time::sleep(Duration::from_secs(1)).await;

        let timeout = tokio::time::sleep(Duration::from_secs(1));
        pin!(timeout);

        assert!(tx.is_closed());

        let spawn_result: std::result::Result<std::result::Result<u64, std::io::Error>, JoinError>;

        let result: std::result::Result<u64, std::io::Error>;

        select! {
            res = handle => spawn_result = res,
            _ = &mut timeout => panic!("timed out"),
        }

        assert!(spawn_result.is_ok());

        result = spawn_result.unwrap();

        assert!(result.is_ok());

        let byte_count = result.unwrap();
        assert_eq!(byte_count, 0);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_interruptable_io_copier_shutdown() {
        // Create an async reader that creates an infinite stream of bytes
        // (which allows us to interrupt it, since we know it is always busy ;)
        const REPEAT_CHAR: u8 = b'r';

        let reader = tokio::io::repeat(REPEAT_CHAR);

        let (tx, rx) = channel(true);
        let writer = BufWriter::new();

        let handle = tokio::spawn(interruptable_io_copier(reader, writer.clone(), rx));

        // Allow time for the thread to be spawned.
        tokio::time::sleep(Duration::from_secs(1)).await;

        let timeout = tokio::time::sleep(Duration::from_secs(1));
        pin!(timeout);

        assert!(!tx.is_closed());

        tx.send(true).expect("failed to request shutdown");

        let spawn_result: std::result::Result<std::result::Result<u64, std::io::Error>, JoinError>;

        let result: std::result::Result<u64, std::io::Error>;

        select! {
            res = handle => spawn_result = res,
            _ = &mut timeout => panic!("timed out"),
        }

        assert!(spawn_result.is_ok());

        result = spawn_result.unwrap();

        assert!(result.is_ok());

        let byte_count = result.unwrap();

        let value = writer.to_string();

        let writer_byte_count = value.len() as u64;

        assert_eq!(byte_count, writer_byte_count);

        // Remove the char used as a payload. If anything else remins,
        // something went wrong.
        let mut remainder = value;

        remainder.retain(|c| c != REPEAT_CHAR as char);

        assert_eq!(remainder.len(), 0);
    }
}
@@ -180,7 +180,7 @@ fn real_main() -> Result<()> {
 
     // Setup logger
     let writer = io::stdout();
-    let logger = logging::create_logger(name, name, log_level, writer);
+    let (logger, _logger_guard) = logging::create_logger(name, name, log_level, writer);
 
     announce(&logger, version);
 
@@ -142,7 +142,7 @@ fn connect(name: &str, global_args: clap::ArgMatches) -> Result<()> {
     let log_level = logging::level_name_to_slog_level(log_level_name).map_err(|e| anyhow!(e))?;
 
     let writer = io::stdout();
-    let logger = logging::create_logger(name, crate_name!(), log_level, writer);
+    let (logger, _guard) = logging::create_logger(name, crate_name!(), log_level, writer);
 
     let timeout_nano: i64 = match args.value_of("timeout") {
         Some(t) => utils::human_time_to_ns(t).map_err(|e| e)?,