diff --git a/pkg/logging/src/lib.rs b/pkg/logging/src/lib.rs index 65314237cf..af6a1738f3 100644 --- a/pkg/logging/src/lib.rs +++ b/pkg/logging/src/lib.rs @@ -21,7 +21,12 @@ const LOG_LEVELS: &[(&str, slog::Level)] = &[ ]; // XXX: 'writer' param used to make testing possible. -pub fn create_logger(name: &str, source: &str, level: slog::Level, writer: W) -> slog::Logger +pub fn create_logger( + name: &str, + source: &str, + level: slog::Level, + writer: W, +) -> (slog::Logger, slog_async::AsyncGuard) where W: Write + Send + Sync + 'static, { @@ -37,17 +42,21 @@ where let filter_drain = RuntimeLevelFilter::new(unique_drain, level).fuse(); // Ensure the logger is thread-safe - let async_drain = slog_async::Async::new(filter_drain).build().fuse(); + let (async_drain, guard) = slog_async::Async::new(filter_drain) + .thread_name("slog-async-logger".into()) + .build_with_guard(); // Add some "standard" fields - slog::Logger::root( + let logger = slog::Logger::root( async_drain.fuse(), o!("version" => env!("CARGO_PKG_VERSION"), "subsystem" => "root", "pid" => process::id().to_string(), "name" => name.to_string(), "source" => source.to_string()), - ) + ); + + (logger, guard) } pub fn get_log_levels() -> Vec<&'static str> { diff --git a/src/agent/Cargo.lock b/src/agent/Cargo.lock index a7de84906b..f747dc8275 100644 --- a/src/agent/Cargo.lock +++ b/src/agent/Cargo.lock @@ -387,6 +387,15 @@ dependencies = [ "unicode-segmentation", ] +[[package]] +name = "hermit-abi" +version = "0.1.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "322f4de77956e22ed0e5032c359a0f1273f1f7f0d79bfa3b8ffbc730d7fbcc5c" +dependencies = [ + "libc", +] + [[package]] name = "hex" version = "0.4.2" @@ -753,6 +762,16 @@ dependencies = [ "autocfg", ] +[[package]] +name = "num_cpus" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05499f3756671c15885fee9034446956fff3f243d6077b91e5767df161f766b3" +dependencies = [ + "hermit-abi", + "libc", +] + [[package]] name = "object" version = "0.22.0" @@ -1483,6 +1502,7 @@ dependencies = [ "libc", "memchr", "mio", + "num_cpus", "once_cell", "pin-project-lite 0.2.4", "signal-hook-registry", diff --git a/src/agent/Cargo.toml b/src/agent/Cargo.toml index 8f86d2aba5..e632a30c3d 100644 --- a/src/agent/Cargo.toml +++ b/src/agent/Cargo.toml @@ -21,7 +21,7 @@ scopeguard = "1.0.0" regex = "1" async-trait = "0.1.42" -tokio = { version = "1.2.0", features = ["rt", "sync", "macros", "io-util", "time", "signal", "io-std", "process"] } +tokio = { version = "1.2.0", features = ["rt", "rt-multi-thread", "sync", "macros", "io-util", "time", "signal", "io-std", "process", "fs"] } futures = "0.3.12" netlink-sys = { version = "0.6.0", features = ["tokio_socket",]} tokio-vsock = "0.3.0" diff --git a/src/agent/src/config.rs b/src/agent/src/config.rs index c3b311de31..8225eee5fc 100644 --- a/src/agent/src/config.rs +++ b/src/agent/src/config.rs @@ -337,12 +337,15 @@ mod tests { assert!(*expected_level == actual_level, $msg); } else { let expected_error = $expected_result.as_ref().unwrap_err(); - let actual_error = $actual_result.unwrap_err(); - let expected_error_msg = format!("{:?}", expected_error); - let actual_error_msg = format!("{:?}", actual_error); - assert!(expected_error_msg == actual_error_msg, $msg); + if let Err(actual_error) = $actual_result { + let actual_error_msg = format!("{:?}", actual_error); + + assert!(expected_error_msg == actual_error_msg, $msg); + } else { + assert!(expected_error_msg == "expected error, got 
OK", $msg); + } } }; } diff --git a/src/agent/src/main.rs b/src/agent/src/main.rs index e58ba88f77..c86a03254a 100644 --- a/src/agent/src/main.rs +++ b/src/agent/src/main.rs @@ -26,9 +26,8 @@ use nix::libc::{STDERR_FILENO, STDIN_FILENO, STDOUT_FILENO}; use nix::pty; use nix::sys::select::{select, FdSet}; use nix::sys::socket::{self, AddressFamily, SockAddr, SockFlag, SockType}; -use nix::sys::wait::{self, WaitStatus}; +use nix::sys::wait; use nix::unistd::{self, close, dup, dup2, fork, setsid, ForkResult}; -use prctl::set_child_subreaper; use std::collections::HashMap; use std::env; use std::ffi::{CStr, CString, OsStr}; @@ -52,23 +51,32 @@ mod network; mod pci; pub mod random; mod sandbox; +mod signal; #[cfg(test)] mod test_utils; mod uevent; +mod util; mod version; use mount::{cgroups_mount, general_mount}; use sandbox::Sandbox; +use signal::setup_signal_handler; use slog::Logger; use uevent::watch_uevents; use std::sync::Mutex as SyncMutex; +use futures::future::join_all; use futures::StreamExt as _; use rustjail::pipestream::PipeStream; use tokio::{ - signal::unix::{signal, SignalKind}, - sync::{oneshot::Sender, Mutex, RwLock}, + io::AsyncWrite, + sync::{ + oneshot::Sender, + watch::{channel, Receiver}, + Mutex, RwLock, + }, + task::JoinHandle, }; use tokio_vsock::{Incoming, VsockListener, VsockStream}; @@ -121,6 +129,146 @@ async fn get_vsock_stream(fd: RawFd) -> Result { Ok(stream) } +// Create a thread to handle reading from the logger pipe. The thread will +// output to the vsock port specified, or stdout. +async fn create_logger_task(rfd: RawFd, vsock_port: u32, shutdown: Receiver) -> Result<()> { + let mut reader = PipeStream::from_fd(rfd); + let mut writer: Box; + + if vsock_port > 0 { + let listenfd = socket::socket( + AddressFamily::Vsock, + SockType::Stream, + SockFlag::SOCK_CLOEXEC, + None, + )?; + + let addr = SockAddr::new_vsock(libc::VMADDR_CID_ANY, vsock_port); + socket::bind(listenfd, &addr).unwrap(); + socket::listen(listenfd, 1).unwrap(); + + writer = Box::new(get_vsock_stream(listenfd).await.unwrap()); + } else { + writer = Box::new(tokio::io::stdout()); + } + + let _ = util::interruptable_io_copier(&mut reader, &mut writer, shutdown).await; + + Ok(()) +} + +async fn real_main() -> std::result::Result<(), Box> { + env::set_var("RUST_BACKTRACE", "full"); + + // List of tasks that need to be stopped for a clean shutdown + let mut tasks: Vec>> = vec![]; + + lazy_static::initialize(&SHELLS); + + lazy_static::initialize(&AGENT_CONFIG); + + // support vsock log + let (rfd, wfd) = unistd::pipe2(OFlag::O_CLOEXEC)?; + + let (shutdown_tx, shutdown_rx) = channel(true); + + let agent_config = AGENT_CONFIG.clone(); + + let init_mode = unistd::getpid() == Pid::from_raw(1); + if init_mode { + // dup a new file descriptor for this temporary logger writer, + // since this logger would be dropped and it's writer would + // be closed out of this code block. + let newwfd = dup(wfd)?; + let writer = unsafe { File::from_raw_fd(newwfd) }; + + // Init a temporary logger used by init agent as init process + // since before do the base mount, it wouldn't access "/proc/cmdline" + // to get the customzied debug level. 
+ let (logger, logger_async_guard) = + logging::create_logger(NAME, "agent", slog::Level::Debug, writer); + + // Must mount proc fs before parsing kernel command line + general_mount(&logger).map_err(|e| { + error!(logger, "fail general mount: {}", e); + e + })?; + + let mut config = agent_config.write().await; + config.parse_cmdline(KERNEL_CMDLINE_FILE)?; + + init_agent_as_init(&logger, config.unified_cgroup_hierarchy)?; + drop(logger_async_guard); + } else { + // once parsed cmdline and set the config, release the write lock + // as soon as possible in case other thread would get read lock on + // it. + let mut config = agent_config.write().await; + config.parse_cmdline(KERNEL_CMDLINE_FILE)?; + } + let config = agent_config.read().await; + + let log_vport = config.log_vport as u32; + + let log_handle = tokio::spawn(create_logger_task(rfd, log_vport, shutdown_rx.clone())); + + tasks.push(log_handle); + + let writer = unsafe { File::from_raw_fd(wfd) }; + + // Recreate a logger with the log level get from "/proc/cmdline". + let (logger, logger_async_guard) = + logging::create_logger(NAME, "agent", config.log_level, writer); + + announce(&logger, &config); + + // This variable is required as it enables the global (and crucially static) logger, + // which is required to satisfy the the lifetime constraints of the auto-generated gRPC code. + let global_logger = slog_scope::set_global_logger(logger.new(o!("subsystem" => "rpc"))); + + // Allow the global logger to be modified later (for shutdown) + global_logger.cancel_reset(); + + let mut ttrpc_log_guard: Result<(), log::SetLoggerError> = Ok(()); + + if config.log_level == slog::Level::Trace { + // Redirect ttrpc log calls to slog iff full debug requested + ttrpc_log_guard = Ok(slog_stdlog::init().map_err(|e| e)?); + } + + // Start the sandbox and wait for its ttRPC server to end + start_sandbox(&logger, &config, init_mode, &mut tasks, shutdown_rx.clone()).await?; + + // Install a NOP logger for the remainder of the shutdown sequence + // to ensure any log calls made by local crates using the scope logger + // don't fail. + let global_logger_guard2 = + slog_scope::set_global_logger(slog::Logger::root(slog::Discard, o!())); + global_logger_guard2.cancel_reset(); + + drop(logger_async_guard); + + drop(ttrpc_log_guard); + + // Trigger a controlled shutdown + shutdown_tx + .send(true) + .map_err(|e| anyhow!(e).context("failed to request shutdown"))?; + + // Wait for all threads to finish + let results = join_all(tasks).await; + + for result in results { + if let Err(e) = result { + return Err(anyhow!(e).into()); + } + } + + eprintln!("{} shutdown complete", NAME); + + Ok(()) +} + fn main() -> std::result::Result<(), Box> { let args: Vec = env::args().collect(); @@ -141,111 +289,20 @@ fn main() -> std::result::Result<(), Box> { exit(0); } - let rt = tokio::runtime::Builder::new_current_thread() + let rt = tokio::runtime::Builder::new_multi_thread() .enable_all() .build()?; - rt.block_on(async { - env::set_var("RUST_BACKTRACE", "full"); - - lazy_static::initialize(&SHELLS); - - lazy_static::initialize(&AGENT_CONFIG); - - // support vsock log - let (rfd, wfd) = unistd::pipe2(OFlag::O_CLOEXEC)?; - - let agent_config = AGENT_CONFIG.clone(); - - let init_mode = unistd::getpid() == Pid::from_raw(1); - if init_mode { - // dup a new file descriptor for this temporary logger writer, - // since this logger would be dropped and it's writer would - // be closed out of this code block. 
- let newwfd = dup(wfd)?; - let writer = unsafe { File::from_raw_fd(newwfd) }; - - // Init a temporary logger used by init agent as init process - // since before do the base mount, it wouldn't access "/proc/cmdline" - // to get the customzied debug level. - let logger = logging::create_logger(NAME, "agent", slog::Level::Debug, writer); - - // Must mount proc fs before parsing kernel command line - general_mount(&logger).map_err(|e| { - error!(logger, "fail general mount: {}", e); - e - })?; - - let mut config = agent_config.write().await; - config.parse_cmdline(KERNEL_CMDLINE_FILE)?; - - init_agent_as_init(&logger, config.unified_cgroup_hierarchy)?; - } else { - // once parsed cmdline and set the config, release the write lock - // as soon as possible in case other thread would get read lock on - // it. - let mut config = agent_config.write().await; - config.parse_cmdline(KERNEL_CMDLINE_FILE)?; - } - let config = agent_config.read().await; - - let log_vport = config.log_vport as u32; - let log_handle = tokio::spawn(async move { - let mut reader = PipeStream::from_fd(rfd); - - if log_vport > 0 { - let listenfd = socket::socket( - AddressFamily::Vsock, - SockType::Stream, - SockFlag::SOCK_CLOEXEC, - None, - ) - .unwrap(); - - let addr = SockAddr::new_vsock(libc::VMADDR_CID_ANY, log_vport); - socket::bind(listenfd, &addr).unwrap(); - socket::listen(listenfd, 1).unwrap(); - - let mut vsock_stream = get_vsock_stream(listenfd).await.unwrap(); - - // copy log to stdout - tokio::io::copy(&mut reader, &mut vsock_stream) - .await - .unwrap(); - } - - // copy log to stdout - let mut stdout_writer = tokio::io::stdout(); - let _ = tokio::io::copy(&mut reader, &mut stdout_writer).await; - }); - - let writer = unsafe { File::from_raw_fd(wfd) }; - - // Recreate a logger with the log level get from "/proc/cmdline". - let logger = logging::create_logger(NAME, "agent", config.log_level, writer); - - announce(&logger, &config); - - // This "unused" variable is required as it enables the global (and crucially static) logger, - // which is required to satisfy the the lifetime constraints of the auto-generated gRPC code. 
- let _guard = slog_scope::set_global_logger(logger.new(o!("subsystem" => "rpc"))); - - let mut _log_guard: Result<(), log::SetLoggerError> = Ok(()); - - if config.log_level == slog::Level::Trace { - // Redirect ttrpc log calls to slog iff full debug requested - _log_guard = Ok(slog_stdlog::init().map_err(|e| e)?); - } - - start_sandbox(&logger, &config, init_mode).await?; - - let _ = log_handle.await.unwrap(); - - Ok(()) - }) + rt.block_on(real_main()) } -async fn start_sandbox(logger: &Logger, config: &AgentConfig, init_mode: bool) -> Result<()> { +async fn start_sandbox( + logger: &Logger, + config: &AgentConfig, + init_mode: bool, + tasks: &mut Vec>>, + shutdown: Receiver, +) -> Result<()> { let shells = SHELLS.clone(); let debug_console_vport = config.debug_console_vport as u32; @@ -275,10 +332,17 @@ async fn start_sandbox(logger: &Logger, config: &AgentConfig, init_mode: bool) - let sandbox = Arc::new(Mutex::new(s)); - setup_signal_handler(&logger, sandbox.clone()) - .await - .unwrap(); - watch_uevents(sandbox.clone()).await; + let signal_handler_task = tokio::spawn(setup_signal_handler( + logger.clone(), + sandbox.clone(), + shutdown.clone(), + )); + + tasks.push(signal_handler_task); + + let uevents_handler_task = tokio::spawn(watch_uevents(sandbox.clone(), shutdown.clone())); + + tasks.push(uevents_handler_task); let (tx, rx) = tokio::sync::oneshot::channel(); sandbox.lock().await.sender = Some(tx); @@ -297,93 +361,6 @@ async fn start_sandbox(logger: &Logger, config: &AgentConfig, init_mode: bool) - Ok(()) } -use nix::sys::wait::WaitPidFlag; - -async fn setup_signal_handler(logger: &Logger, sandbox: Arc>) -> Result<()> { - let logger = logger.new(o!("subsystem" => "signals")); - - set_child_subreaper(true) - .map_err(|err| anyhow!(err).context("failed to setup agent as a child subreaper"))?; - - let mut signal_stream = signal(SignalKind::child())?; - - tokio::spawn(async move { - 'outer: loop { - signal_stream.recv().await; - info!(logger, "received signal"; "signal" => "SIGCHLD"); - - // sevral signals can be combined together - // as one. 
So loop around to reap all - // exited children - 'inner: loop { - let wait_status = match wait::waitpid( - Some(Pid::from_raw(-1)), - Some(WaitPidFlag::WNOHANG | WaitPidFlag::__WALL), - ) { - Ok(s) => { - if s == WaitStatus::StillAlive { - continue 'outer; - } - s - } - Err(e) => { - info!( - logger, - "waitpid reaper failed"; - "error" => e.as_errno().unwrap().desc() - ); - continue 'outer; - } - }; - info!(logger, "wait_status"; "wait_status result" => format!("{:?}", wait_status)); - - let pid = wait_status.pid(); - if let Some(pid) = pid { - let raw_pid = pid.as_raw(); - let child_pid = format!("{}", raw_pid); - - let logger = logger.new(o!("child-pid" => child_pid)); - - let mut sandbox = sandbox.lock().await; - let process = sandbox.find_process(raw_pid); - if process.is_none() { - info!(logger, "child exited unexpectedly"); - continue 'inner; - } - - let mut p = process.unwrap(); - - if p.exit_pipe_w.is_none() { - error!(logger, "the process's exit_pipe_w isn't set"); - continue 'inner; - } - let pipe_write = p.exit_pipe_w.unwrap(); - let ret: i32; - - match wait_status { - WaitStatus::Exited(_, c) => ret = c, - WaitStatus::Signaled(_, sig, _) => ret = sig as i32, - _ => { - info!(logger, "got wrong status for process"; - "child-status" => format!("{:?}", wait_status)); - continue 'inner; - } - } - - p.exit_code = ret; - let _ = unistd::close(pipe_write); - - info!(logger, "notify term to close"); - // close the socket file to notify readStdio to close terminal specifically - // in case this process's terminal has been inherited by its children. - p.notify_term_close(); - } - } - } - }); - Ok(()) -} - // init_agent_as_init will do the initializations such as setting up the rootfs // when this agent has been run as the init process. fn init_agent_as_init(logger: &Logger, unified_cgroup_hierarchy: bool) -> Result<()> { diff --git a/src/agent/src/signal.rs b/src/agent/src/signal.rs new file mode 100644 index 0000000000..283951117e --- /dev/null +++ b/src/agent/src/signal.rs @@ -0,0 +1,159 @@ +// Copyright (c) 2019-2020 Ant Financial +// Copyright (c) 2020 Intel Corporation +// +// SPDX-License-Identifier: Apache-2.0 +// + +use crate::sandbox::Sandbox; +use anyhow::{anyhow, Result}; +use nix::sys::wait::WaitPidFlag; +use nix::sys::wait::{self, WaitStatus}; +use nix::unistd; +use prctl::set_child_subreaper; +use slog::{error, info, o, Logger}; +use std::sync::Arc; +use tokio::select; +use tokio::signal::unix::{signal, SignalKind}; +use tokio::sync::watch::Receiver; +use tokio::sync::Mutex; +use unistd::Pid; + +async fn handle_sigchild(logger: Logger, sandbox: Arc>) -> Result<()> { + info!(logger, "handling signal"; "signal" => "SIGCHLD"); + + loop { + let result = wait::waitpid( + Some(Pid::from_raw(-1)), + Some(WaitPidFlag::WNOHANG | WaitPidFlag::__WALL), + ); + + let wait_status = match result { + Ok(s) => { + if s == WaitStatus::StillAlive { + return Ok(()); + } + s + } + Err(e) => return Err(anyhow!(e).context("waitpid reaper failed")), + }; + + info!(logger, "wait_status"; "wait_status result" => format!("{:?}", wait_status)); + + if let Some(pid) = wait_status.pid() { + let raw_pid = pid.as_raw(); + let child_pid = format!("{}", raw_pid); + + let logger = logger.new(o!("child-pid" => child_pid)); + + let sandbox_ref = sandbox.clone(); + let mut sandbox = sandbox_ref.lock().await; + + let process = sandbox.find_process(raw_pid); + if process.is_none() { + info!(logger, "child exited unexpectedly"); + continue; + } + + let mut p = process.unwrap(); + + if p.exit_pipe_w.is_none() { + 
info!(logger, "process exit pipe not set"); + continue; + } + + let pipe_write = p.exit_pipe_w.unwrap(); + let ret: i32; + + match wait_status { + WaitStatus::Exited(_, c) => ret = c, + WaitStatus::Signaled(_, sig, _) => ret = sig as i32, + _ => { + info!(logger, "got wrong status for process"; + "child-status" => format!("{:?}", wait_status)); + continue; + } + } + + p.exit_code = ret; + let _ = unistd::close(pipe_write); + + info!(logger, "notify term to close"); + // close the socket file to notify readStdio to close terminal specifically + // in case this process's terminal has been inherited by its children. + p.notify_term_close(); + } + } +} + +pub async fn setup_signal_handler( + logger: Logger, + sandbox: Arc>, + mut shutdown: Receiver, +) -> Result<()> { + let logger = logger.new(o!("subsystem" => "signals")); + + set_child_subreaper(true) + .map_err(|err| anyhow!(err).context("failed to setup agent as a child subreaper"))?; + + let mut sigchild_stream = signal(SignalKind::child())?; + + loop { + select! { + _ = shutdown.changed() => { + info!(logger, "got shutdown request"); + break; + } + + _ = sigchild_stream.recv() => { + let result = handle_sigchild(logger.clone(), sandbox.clone()).await; + + match result { + Ok(()) => (), + Err(e) => { + // Log errors, but don't abort - just wait for more signals! + error!(logger, "failed to handle signal"; "error" => format!("{:?}", e)); + } + } + } + } + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use tokio::pin; + use tokio::sync::watch::channel; + use tokio::time::Duration; + + #[tokio::test] + async fn test_setup_signal_handler() { + let logger = slog::Logger::root(slog::Discard, o!()); + let s = Sandbox::new(&logger).unwrap(); + + let sandbox = Arc::new(Mutex::new(s)); + + let (tx, rx) = channel(true); + + let handle = tokio::spawn(setup_signal_handler(logger, sandbox, rx)); + + let timeout = tokio::time::sleep(Duration::from_secs(1)); + pin!(timeout); + + tx.send(true).expect("failed to request shutdown"); + + loop { + select! 
{ + _ = handle => { + println!("INFO: task completed"); + break; + }, + _ = &mut timeout => { + panic!("signal thread failed to stop"); + } + } + } + } +} diff --git a/src/agent/src/uevent.rs b/src/agent/src/uevent.rs index 94c4253da8..c4067ed5ab 100644 --- a/src/agent/src/uevent.rs +++ b/src/agent/src/uevent.rs @@ -9,10 +9,13 @@ use crate::sandbox::Sandbox; use crate::GLOBAL_DEVICE_WATCHER; use slog::Logger; +use anyhow::Result; use netlink_sys::{protocols, SocketAddr, TokioSocket}; use nix::errno::Errno; use std::os::unix::io::FromRawFd; use std::sync::Arc; +use tokio::select; +use tokio::sync::watch::Receiver; use tokio::sync::Mutex; #[derive(Debug, Default)] @@ -132,49 +135,67 @@ impl Uevent { } } -pub async fn watch_uevents(sandbox: Arc>) { +pub async fn watch_uevents( + sandbox: Arc>, + mut shutdown: Receiver, +) -> Result<()> { let sref = sandbox.clone(); let s = sref.lock().await; let logger = s.logger.new(o!("subsystem" => "uevent")); - tokio::spawn(async move { - let mut socket; - unsafe { - let fd = libc::socket( - libc::AF_NETLINK, - libc::SOCK_DGRAM | libc::SOCK_CLOEXEC, - protocols::NETLINK_KOBJECT_UEVENT as libc::c_int, - ); - socket = TokioSocket::from_raw_fd(fd); - } - socket.bind(&SocketAddr::new(0, 1)).unwrap(); + // Unlock the sandbox to allow a successful shutdown + drop(s); - loop { - match socket.recv_from_full().await { - Err(e) => { - error!(logger, "receive uevent message failed"; "error" => format!("{}", e)) - } - Ok((buf, addr)) => { - if addr.port_number() != 0 { - // not our netlink message - let err_msg = format!("{:?}", nix::Error::Sys(Errno::EBADMSG)); - error!(logger, "receive uevent message failed"; "error" => err_msg); - return; + info!(logger, "starting uevents handler"); + + let mut socket; + + unsafe { + let fd = libc::socket( + libc::AF_NETLINK, + libc::SOCK_DGRAM | libc::SOCK_CLOEXEC, + protocols::NETLINK_KOBJECT_UEVENT as libc::c_int, + ); + socket = TokioSocket::from_raw_fd(fd); + } + + socket.bind(&SocketAddr::new(0, 1))?; + + loop { + select! 
{ + _ = shutdown.changed() => { + info!(logger, "got shutdown request"); + break; + } + result = socket.recv_from_full() => { + match result { + Err(e) => { + error!(logger, "failed to receive uevent"; "error" => format!("{}", e)) } - - let text = String::from_utf8(buf); - match text { - Err(e) => { - error!(logger, "failed to convert bytes to text"; "error" => format!("{}", e)) + Ok((buf, addr)) => { + if addr.port_number() != 0 { + // not our netlink message + let err_msg = format!("{:?}", nix::Error::Sys(Errno::EBADMSG)); + error!(logger, "receive uevent message failed"; "error" => err_msg); + continue; } - Ok(text) => { - let event = Uevent::new(&text); - info!(logger, "got uevent message"; "event" => format!("{:?}", event)); - event.process(&logger, &sandbox).await; + + let text = String::from_utf8(buf); + match text { + Err(e) => { + error!(logger, "failed to convert bytes to text"; "error" => format!("{}", e)) + } + Ok(text) => { + let event = Uevent::new(&text); + info!(logger, "got uevent message"; "event" => format!("{:?}", event)); + event.process(&logger, &sandbox).await; + } } } } } } - }); + } + + Ok(()) } diff --git a/src/agent/src/util.rs b/src/agent/src/util.rs new file mode 100644 index 0000000000..314d05a254 --- /dev/null +++ b/src/agent/src/util.rs @@ -0,0 +1,342 @@ +// Copyright (c) 2021 Intel Corporation +// +// SPDX-License-Identifier: Apache-2.0 +// + +use std::io; +use std::io::ErrorKind; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::sync::watch::Receiver; + +// Size of I/O read buffer +const BUF_SIZE: usize = 8192; + +// Interruptable I/O copy using readers and writers +// (an interruptable version of "io::copy()"). +pub async fn interruptable_io_copier( + mut reader: R, + mut writer: W, + mut shutdown: Receiver, +) -> io::Result +where + R: tokio::io::AsyncRead + Unpin, + W: tokio::io::AsyncWrite + Unpin, +{ + let mut total_bytes: u64 = 0; + + let mut buf: [u8; BUF_SIZE] = [0; BUF_SIZE]; + + loop { + tokio::select! 
{ + _ = shutdown.changed() => { + eprintln!("INFO: interruptable_io_copier: got shutdown request"); + break; + }, + + result = reader.read(&mut buf) => { + let bytes = match result { + Ok(0) => return Ok(total_bytes), + Ok(len) => len, + Err(ref e) if e.kind() == ErrorKind::Interrupted => continue, + Err(e) => return Err(e), + }; + + total_bytes += bytes as u64; + + // Actually copy the data ;) + writer.write_all(&buf[..bytes]).await?; + }, + }; + } + + Ok(total_bytes) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::io; + use std::io::Cursor; + use std::io::Write; + use std::pin::Pin; + use std::sync::{Arc, Mutex}; + use std::task::{Context, Poll, Poll::Ready}; + use tokio::pin; + use tokio::select; + use tokio::sync::watch::channel; + use tokio::task::JoinError; + use tokio::time::Duration; + + #[derive(Debug, Default, Clone)] + struct BufWriter { + data: Arc>>, + slow_write: bool, + write_delay: Duration, + } + + impl BufWriter { + fn new() -> Self { + BufWriter { + data: Arc::new(Mutex::new(Vec::::new())), + slow_write: false, + write_delay: Duration::new(0, 0), + } + } + + fn write_vec(&mut self, buf: &[u8]) -> io::Result { + let vec_ref = self.data.clone(); + + let mut vec_locked = vec_ref.lock(); + + let mut v = vec_locked.as_deref_mut().unwrap(); + + if self.write_delay.as_nanos() > 0 { + std::thread::sleep(self.write_delay); + } + + std::io::Write::write(&mut v, buf) + } + } + + impl Write for BufWriter { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.write_vec(buf) + } + + fn flush(&mut self) -> io::Result<()> { + let vec_ref = self.data.clone(); + + let mut vec_locked = vec_ref.lock(); + + let v = vec_locked.as_deref_mut().unwrap(); + + std::io::Write::flush(v) + } + } + + impl tokio::io::AsyncWrite for BufWriter { + fn poll_write( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let result = self.write_vec(buf); + + Ready(result) + } + + fn poll_flush( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + // NOP + Ready(Ok(())) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + // NOP + Ready(Ok(())) + } + } + + impl ToString for BufWriter { + fn to_string(&self) -> String { + let data_ref = self.data.clone(); + let output = data_ref.lock().unwrap(); + let s = (*output).clone(); + + String::from_utf8(s).unwrap() + } + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_interruptable_io_copier_reader() { + #[derive(Debug)] + struct TestData { + reader_value: String, + result: io::Result, + } + + let tests = &[ + TestData { + reader_value: "".into(), + result: Ok(0), + }, + TestData { + reader_value: "a".into(), + result: Ok(1), + }, + TestData { + reader_value: "foo".into(), + result: Ok(3), + }, + TestData { + reader_value: "b".repeat(BUF_SIZE - 1), + result: Ok((BUF_SIZE - 1) as u64), + }, + TestData { + reader_value: "c".repeat(BUF_SIZE), + result: Ok((BUF_SIZE) as u64), + }, + TestData { + reader_value: "d".repeat(BUF_SIZE + 1), + result: Ok((BUF_SIZE + 1) as u64), + }, + TestData { + reader_value: "e".repeat((2 * BUF_SIZE) - 1), + result: Ok(((2 * BUF_SIZE) - 1) as u64), + }, + TestData { + reader_value: "f".repeat(2 * BUF_SIZE), + result: Ok((2 * BUF_SIZE) as u64), + }, + TestData { + reader_value: "g".repeat((2 * BUF_SIZE) + 1), + result: Ok(((2 * BUF_SIZE) + 1) as u64), + }, + ]; + + for (i, d) in tests.iter().enumerate() { + // Create a string containing details of the test + let msg = format!("test[{}]: {:?}", i, d); + + let (tx, rx) 
= channel(true); + let reader = Cursor::new(d.reader_value.clone()); + let writer = BufWriter::new(); + + // XXX: Pass a copy of the writer to the copier to allow the + // result of the write operation to be checked below. + let handle = tokio::spawn(interruptable_io_copier(reader, writer.clone(), rx)); + + // Allow time for the thread to be spawned. + tokio::time::sleep(Duration::from_secs(1)).await; + + let timeout = tokio::time::sleep(Duration::from_secs(1)); + pin!(timeout); + + // Since the readers only specify a small number of bytes, the + // copier will quickly read zero and kill the task, closing the + // Receiver. + assert!(tx.is_closed(), "{}", msg); + + let spawn_result: std::result::Result< + std::result::Result, + JoinError, + >; + + let result: std::result::Result; + + select! { + res = handle => spawn_result = res, + _ = &mut timeout => panic!("timed out"), + } + + assert!(spawn_result.is_ok()); + + result = spawn_result.unwrap(); + + assert!(result.is_ok()); + + let byte_count = result.unwrap() as usize; + assert_eq!(byte_count, d.reader_value.len(), "{}", msg); + + let value = writer.to_string(); + assert_eq!(value, d.reader_value, "{}", msg); + } + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_interruptable_io_copier_eof() { + // Create an async reader that always returns EOF + let reader = tokio::io::empty(); + + let (tx, rx) = channel(true); + let writer = BufWriter::new(); + + let handle = tokio::spawn(interruptable_io_copier(reader, writer.clone(), rx)); + + // Allow time for the thread to be spawned. + tokio::time::sleep(Duration::from_secs(1)).await; + + let timeout = tokio::time::sleep(Duration::from_secs(1)); + pin!(timeout); + + assert!(tx.is_closed()); + + let spawn_result: std::result::Result, JoinError>; + + let result: std::result::Result; + + select! { + res = handle => spawn_result = res, + _ = &mut timeout => panic!("timed out"), + } + + assert!(spawn_result.is_ok()); + + result = spawn_result.unwrap(); + + assert!(result.is_ok()); + + let byte_count = result.unwrap(); + assert_eq!(byte_count, 0); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_interruptable_io_copier_shutdown() { + // Create an async reader that creates an infinite stream of bytes + // (which allows us to interrupt it, since we know it is always busy ;) + const REPEAT_CHAR: u8 = b'r'; + + let reader = tokio::io::repeat(REPEAT_CHAR); + + let (tx, rx) = channel(true); + let writer = BufWriter::new(); + + let handle = tokio::spawn(interruptable_io_copier(reader, writer.clone(), rx)); + + // Allow time for the thread to be spawned. + tokio::time::sleep(Duration::from_secs(1)).await; + + let timeout = tokio::time::sleep(Duration::from_secs(1)); + pin!(timeout); + + assert!(!tx.is_closed()); + + tx.send(true).expect("failed to request shutdown"); + + let spawn_result: std::result::Result, JoinError>; + + let result: std::result::Result; + + select! { + res = handle => spawn_result = res, + _ = &mut timeout => panic!("timed out"), + } + + assert!(spawn_result.is_ok()); + + result = spawn_result.unwrap(); + + assert!(result.is_ok()); + + let byte_count = result.unwrap(); + + let value = writer.to_string(); + + let writer_byte_count = value.len() as u64; + + assert_eq!(byte_count, writer_byte_count); + + // Remove the char used as a payload. If anything else remins, + // something went wrong. 
+ let mut remainder = value; + + remainder.retain(|c| c != REPEAT_CHAR as char); + + assert_eq!(remainder.len(), 0); + } +} diff --git a/src/trace-forwarder/src/main.rs b/src/trace-forwarder/src/main.rs index c6925e183a..8e5f646d46 100644 --- a/src/trace-forwarder/src/main.rs +++ b/src/trace-forwarder/src/main.rs @@ -180,7 +180,7 @@ fn real_main() -> Result<()> { // Setup logger let writer = io::stdout(); - let logger = logging::create_logger(name, name, log_level, writer); + let (logger, _logger_guard) = logging::create_logger(name, name, log_level, writer); announce(&logger, version); diff --git a/tools/agent-ctl/src/main.rs b/tools/agent-ctl/src/main.rs index e36ee22783..313068e0c1 100644 --- a/tools/agent-ctl/src/main.rs +++ b/tools/agent-ctl/src/main.rs @@ -142,7 +142,7 @@ fn connect(name: &str, global_args: clap::ArgMatches) -> Result<()> { let log_level = logging::level_name_to_slog_level(log_level_name).map_err(|e| anyhow!(e))?; let writer = io::stdout(); - let logger = logging::create_logger(name, crate_name!(), log_level, writer); + let (logger, _guard) = logging::create_logger(name, crate_name!(), log_level, writer); let timeout_nano: i64 = match args.value_of("timeout") { Some(t) => utils::human_time_to_ns(t).map_err(|e| e)?,
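
The patch's central change is a cooperative-shutdown pattern: real_main() creates a tokio::sync::watch channel, hands a cloned Receiver<bool> to each spawned task (logger copier, signal handler, uevent watcher), each task selects between its work and shutdown.changed(), and on exit main sends true and awaits every JoinHandle with futures::future::join_all. The following standalone sketch (not part of the patch) shows that pattern in isolation; the worker function, its sleep-based "work", and the task count are hypothetical placeholders, and it assumes the same tokio features the Cargo.toml hunk enables (rt-multi-thread, macros, time, sync).

// Sketch of the watch-channel shutdown pattern used by the agent changes above.
// Assumptions: anyhow, futures and tokio are available, as in src/agent.

use anyhow::{anyhow, Result};
use futures::future::join_all;
use tokio::sync::watch::{channel, Receiver};
use tokio::task::JoinHandle;
use tokio::time::{sleep, Duration};

// Hypothetical task standing in for the logger/signal/uevent handlers.
async fn worker(id: u32, mut shutdown: Receiver<bool>) -> Result<()> {
    loop {
        tokio::select! {
            _ = shutdown.changed() => {
                eprintln!("worker {}: got shutdown request", id);
                break;
            }
            _ = sleep(Duration::from_millis(100)) => {
                // Placeholder for the task's real work.
            }
        }
    }
    Ok(())
}

#[tokio::main]
async fn main() -> Result<()> {
    // List of tasks that need to be stopped for a clean shutdown.
    let mut tasks: Vec<JoinHandle<Result<()>>> = vec![];
    let (shutdown_tx, shutdown_rx) = channel(true);

    for id in 0..3 {
        tasks.push(tokio::spawn(worker(id, shutdown_rx.clone())));
    }

    // Trigger a controlled shutdown, then wait for every task to finish,
    // surfacing both join errors and task errors.
    shutdown_tx
        .send(true)
        .map_err(|e| anyhow!(e).context("failed to request shutdown"))?;

    for result in join_all(tasks).await {
        result??;
    }

    Ok(())
}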