diff --git a/src/agent/src/config.rs b/src/agent/src/config.rs index 7533a7c03b..9bd191113c 100644 --- a/src/agent/src/config.rs +++ b/src/agent/src/config.rs @@ -6,17 +6,17 @@ use rustjail::errors::*; use std::fs; use std::time; -const DEBUG_CONSOLE_FLAG: &'static str = "agent.debug_console"; -const DEV_MODE_FLAG: &'static str = "agent.devmode"; -const LOG_LEVEL_OPTION: &'static str = "agent.log"; -const HOTPLUG_TIMOUT_OPTION: &'static str = "agent.hotplug_timeout"; +const DEBUG_CONSOLE_FLAG: &str = "agent.debug_console"; +const DEV_MODE_FLAG: &str = "agent.devmode"; +const LOG_LEVEL_OPTION: &str = "agent.log"; +const HOTPLUG_TIMOUT_OPTION: &str = "agent.hotplug_timeout"; const DEFAULT_LOG_LEVEL: slog::Level = slog::Level::Info; const DEFAULT_HOTPLUG_TIMEOUT: time::Duration = time::Duration::from_secs(3); // FIXME: unused -const TRACE_MODE_FLAG: &'static str = "agent.trace"; -const USE_VSOCK_FLAG: &'static str = "agent.use_vsock"; +const TRACE_MODE_FLAG: &str = "agent.trace"; +const USE_VSOCK_FLAG: &str = "agent.use_vsock"; #[derive(Debug)] pub struct agentConfig { @@ -100,17 +100,11 @@ fn get_log_level(param: &str) -> Result { return Err(ErrorKind::ErrorCode(String::from("invalid log level parameter")).into()); } - let key = fields[0]; - - if key != LOG_LEVEL_OPTION { - return Err(ErrorKind::ErrorCode(String::from("invalid log level key name").into()).into()); + if fields[0] != LOG_LEVEL_OPTION { + Err(ErrorKind::ErrorCode(String::from("invalid log level key name")).into()) + } else { + Ok(logrus_to_slog_level(fields[1])?) } - - let value = fields[1]; - - let level = logrus_to_slog_level(value)?; - - Ok(level) } fn get_hotplug_timeout(param: &str) -> Result { diff --git a/src/agent/src/device.rs b/src/agent/src/device.rs index 9ef45d3b21..bbd97cc6b6 100644 --- a/src/agent/src/device.rs +++ b/src/agent/src/device.rs @@ -3,21 +3,20 @@ // SPDX-License-Identifier: Apache-2.0 // -use rustjail::errors::*; -use std::fs; -// use std::io::Write; use libc::{c_uint, major, minor}; use std::collections::HashMap; +use std::fs; use std::os::unix::fs::MetadataExt; use std::path::Path; -use std::sync::mpsc; -use std::sync::{Arc, Mutex}; +use std::sync::{mpsc, Arc, Mutex}; +use crate::linux_abi::*; use crate::mount::{DRIVERBLKTYPE, DRIVERMMIOBLKTYPE, DRIVERNVDIMMTYPE, DRIVERSCSITYPE}; use crate::sandbox::Sandbox; use crate::{AGENT_CONFIG, GLOBAL_DEVICE_WATCHER}; use protocols::agent::Device; use protocols::oci::Spec; +use rustjail::errors::*; // Convenience macro to obtain the scope logger macro_rules! sl { @@ -26,57 +25,24 @@ macro_rules! sl { }; } -#[cfg(any( - target_arch = "x86_64", - target_arch = "x86", - target_arch = "powerpc64le", - target_arch = "s390x" -))] -pub const ROOT_BUS_PATH: &'static str = "/devices/pci0000:00"; -#[cfg(target_arch = "arm")] -pub const ROOT_BUS_PATH: &'static str = "/devices/platform/4010000000.pcie/pci0000:00"; - -pub const SYSFS_DIR: &'static str = "/sys"; - -const SYS_BUS_PREFIX: &'static str = "/sys/bus/pci/devices"; -const PCI_BUS_RESCAN_FILE: &'static str = "/sys/bus/pci/rescan"; -const SYSTEM_DEV_PATH: &'static str = "/dev"; - -// SCSI const - -// Here in "0:0", the first number is the SCSI host number because -// only one SCSI controller has been plugged, while the second number -// is always 0. -pub const SCSI_HOST_CHANNEL: &'static str = "0:0:"; -const SYS_CLASS_PREFIX: &'static str = "/sys/class"; -const SCSI_DISK_PREFIX: &'static str = "/sys/class/scsi_disk/0:0:"; -pub const SCSI_BLOCK_SUFFIX: &'static str = "block"; -const SCSI_DISK_SUFFIX: &'static str = "/device/block"; -const SCSI_HOST_PATH: &'static str = "/sys/class/scsi_host"; - -// DeviceHandler is the type of callback to be defined to handle every -// type of device driver. -type DeviceHandler = fn(&Device, &mut Spec, Arc>) -> Result<()>; +// DeviceHandler is the type of callback to be defined to handle every type of device driver. +type DeviceHandler = fn(&Device, &mut Spec, &Arc>) -> Result<()>; // DeviceHandlerList lists the supported drivers. #[cfg_attr(rustfmt, rustfmt_skip)] lazy_static! { - pub static ref DEVICEHANDLERLIST: HashMap<&'static str, DeviceHandler> = { - let mut m = HashMap::new(); - let blk: DeviceHandler = virtio_blk_device_handler; - m.insert(DRIVERBLKTYPE, blk); - let virtiommio: DeviceHandler = virtiommio_blk_device_handler; - m.insert(DRIVERMMIOBLKTYPE, virtiommio); - let local: DeviceHandler = virtio_nvdimm_device_handler; - m.insert(DRIVERNVDIMMTYPE, local); - let scsi: DeviceHandler = virtio_scsi_device_handler; - m.insert(DRIVERSCSITYPE, scsi); + static ref DEVICEHANDLERLIST: HashMap<&'static str, DeviceHandler> = { + let mut m: HashMap<&'static str, DeviceHandler> = HashMap::new(); + m.insert(DRIVERBLKTYPE, virtio_blk_device_handler); + m.insert(DRIVERMMIOBLKTYPE, virtiommio_blk_device_handler); + m.insert(DRIVERNVDIMMTYPE, virtio_nvdimm_device_handler); + m.insert(DRIVERSCSITYPE, virtio_scsi_device_handler); m }; } pub fn rescan_pci_bus() -> Result<()> { - online_device(PCI_BUS_RESCAN_FILE) + online_device(SYSFS_PCI_BUS_RESCAN_FILE) } pub fn online_device(path: &str) -> Result<()> { @@ -84,11 +50,11 @@ pub fn online_device(path: &str) -> Result<()> { Ok(()) } -// get_device_pci_address fetches the complete PCI address in sysfs, based on the PCI +// get_pci_device_address fetches the complete PCI address in sysfs, based on the PCI // identifier provided. This should be in the format: "bridgeAddr/deviceAddr". -// Here, bridgeAddr is the address at which the brige is attached on the root bus, +// Here, bridgeAddr is the address at which the bridge is attached on the root bus, // while deviceAddr is the address at which the device is attached on the bridge. -pub fn get_device_pci_address(pci_id: &str) -> Result { +fn get_pci_device_address(pci_id: &str) -> Result { let tokens: Vec<&str> = pci_id.split("/").collect(); if tokens.len() != 2 { @@ -107,7 +73,7 @@ pub fn get_device_pci_address(pci_id: &str) -> Result { let pci_bridge_addr = format!("0000:00:{}.0", bridge_id); // Find out the bus exposed by bridge - let bridge_bus_path = format!("{}/{}/pci_bus/", SYS_BUS_PREFIX, pci_bridge_addr); + let bridge_bus_path = format!("{}/{}/pci_bus/", SYSFS_PCI_BUS_PREFIX, pci_bridge_addr); let files_slice: Vec<_> = fs::read_dir(&bridge_bus_path) .unwrap() @@ -139,78 +105,60 @@ pub fn get_device_pci_address(pci_id: &str) -> Result { Ok(bridge_device_pci_addr) } -pub fn get_device_name(sandbox: Arc>, dev_addr: &str) -> Result { - let mut dev_name: String = String::default(); +fn get_device_name(sandbox: &Arc>, dev_addr: &str) -> Result { + // Keep the same lock order as uevent::handle_block_add_event(), otherwise it may cause deadlock. + let mut w = GLOBAL_DEVICE_WATCHER.lock().unwrap(); + let sb = sandbox.lock().unwrap(); + for (key, value) in sb.pci_device_map.iter() { + if key.contains(dev_addr) { + info!(sl!(), "Device {} found in pci device map", dev_addr); + return Ok(format!("{}/{}", SYSTEM_DEV_PATH, value)); + } + } + drop(sb); + + // If device is not found in the device map, hotplug event has not + // been received yet, create and add channel to the watchers map. + // The key of the watchers map is the device we are interested in. + // Note this is done inside the lock, not to miss any events from the + // global udev listener. let (tx, rx) = mpsc::channel::(); + w.insert(dev_addr.to_string(), tx); + drop(w); - { - let watcher = GLOBAL_DEVICE_WATCHER.clone(); - let mut w = watcher.lock().unwrap(); - - let s = sandbox.clone(); - let sb = s.lock().unwrap(); - - for (key, value) in &(sb.pci_device_map) { - if key.contains(dev_addr) { - dev_name = value.to_string(); - info!(sl!(), "Device {} found in pci device map", dev_addr); - break; - } + info!(sl!(), "Waiting on channel for device notification\n"); + let hotplug_timeout = AGENT_CONFIG.read().unwrap().hotplug_timeout; + let dev_name = match rx.recv_timeout(hotplug_timeout) { + Ok(name) => name, + Err(_) => { + GLOBAL_DEVICE_WATCHER.lock().unwrap().remove_entry(dev_addr); + return Err(ErrorKind::ErrorCode(format!( + "Timeout reached after {:?} waiting for device {}", + hotplug_timeout, dev_addr + )) + .into()); } - - // If device is not found in the device map, hotplug event has not - // been received yet, create and add channel to the watchers map. - // The key of the watchers map is the device we are interested in. - // Note this is done inside the lock, not to miss any events from the - // global udev listener. - if dev_name == "" { - w.insert(dev_addr.to_string(), tx); - } - } - - if dev_name == "" { - info!(sl!(), "Waiting on channel for device notification\n"); - - let agent_config = AGENT_CONFIG.clone(); - let config = agent_config.read().unwrap(); - - match rx.recv_timeout(config.hotplug_timeout) { - Ok(name) => dev_name = name, - Err(_) => { - let watcher = GLOBAL_DEVICE_WATCHER.clone(); - let mut w = watcher.lock().unwrap(); - w.remove_entry(dev_addr); - - return Err(ErrorKind::ErrorCode(format!( - "Timeout reached after {:?} waiting for device {}", - config.hotplug_timeout, dev_addr - )) - .into()); - } - } - } + }; Ok(format!("{}/{}", SYSTEM_DEV_PATH, &dev_name)) } -pub fn get_scsi_device_name(sandbox: Arc>, scsi_addr: &str) -> Result { - scan_scsi_bus(scsi_addr)?; - +pub fn get_scsi_device_name(sandbox: &Arc>, scsi_addr: &str) -> Result { let dev_sub_path = format!("{}{}/{}", SCSI_HOST_CHANNEL, scsi_addr, SCSI_BLOCK_SUFFIX); - get_device_name(sandbox, dev_sub_path.as_str()) + scan_scsi_bus(scsi_addr)?; + get_device_name(sandbox, &dev_sub_path) } -pub fn get_pci_device_name(sandbox: Arc>, pci_id: &str) -> Result { - let pci_addr = get_device_pci_address(pci_id)?; +pub fn get_pci_device_name(sandbox: &Arc>, pci_id: &str) -> Result { + let pci_addr = get_pci_device_address(pci_id)?; rescan_pci_bus()?; - - get_device_name(sandbox, pci_addr.as_str()) + get_device_name(sandbox, &pci_addr) } -// scan_scsi_bus scans SCSI bus for the given SCSI address(SCSI-Id and LUN) -pub fn scan_scsi_bus(scsi_addr: &str) -> Result<()> { +/// Scan SCSI bus for the given SCSI address(SCSI-Id and LUN) +fn scan_scsi_bus(scsi_addr: &str) -> Result<()> { let tokens: Vec<&str> = scsi_addr.split(":").collect(); if tokens.len() != 2 { return Err(ErrorKind::Msg(format!( @@ -220,15 +168,18 @@ pub fn scan_scsi_bus(scsi_addr: &str) -> Result<()> { .into()); } - // Scan scsi host passing in the channel, SCSI id and LUN. Channel - // is always 0 because we have only one SCSI controller. + // Scan scsi host passing in the channel, SCSI id and LUN. + // Channel is always 0 because we have only one SCSI controller. let scan_data = format!("0 {} {}", tokens[0], tokens[1]); - for entry in fs::read_dir(SCSI_HOST_PATH)? { - let entry = entry?; - - let host = entry.file_name(); - let scan_path = format!("{}/{}/{}", SCSI_HOST_PATH, host.to_str().unwrap(), "scan"); + for entry in fs::read_dir(SYSFS_SCSI_HOST_PATH)? { + let host = entry?.file_name(); + let scan_path = format!( + "{}/{}/{}", + SYSFS_SCSI_HOST_PATH, + host.to_str().unwrap(), + "scan" + ); fs::write(scan_path, &scan_data)?; } @@ -243,9 +194,6 @@ pub fn scan_scsi_bus(scsi_addr: &str) -> Result<()> { // This is needed to update information about minor/major numbers that cannot // be predicted from the caller. fn update_spec_device_list(device: &Device, spec: &mut Spec) -> Result<()> { - // If no container_path is provided, we won't be able to match and - // update the device in the OCI spec device list. This is an error. - let major_id: c_uint; let minor_id: c_uint; @@ -253,7 +201,7 @@ fn update_spec_device_list(device: &Device, spec: &mut Spec) -> Result<()> { // update the device in the OCI spec device list. This is an error. if device.container_path == "" { return Err(ErrorKind::Msg(format!( - "container_path cannot empty for device {:?}", + "container_path cannot empty for device {:?}", device )) .into()); @@ -304,9 +252,7 @@ fn update_spec_device_list(device: &Device, spec: &mut Spec) -> Result<()> { // Resources must be updated since they are used to identify the // device in the devices cgroup. - let resource = linux.Resources.as_mut(); - if resource.is_some() { - let res = resource.unwrap(); + if let Some(res) = linux.Resources.as_mut() { let ds = res.Devices.as_mut_slice(); for d in ds.iter_mut() { if d.Major == host_major && d.Minor == host_minor { @@ -331,10 +277,10 @@ fn update_spec_device_list(device: &Device, spec: &mut Spec) -> Result<()> { fn virtiommio_blk_device_handler( device: &Device, spec: &mut Spec, - _sandbox: Arc>, + _sandbox: &Arc>, ) -> Result<()> { if device.vm_path == "" { - return Err(ErrorKind::Msg("Invalid path for virtiommioblkdevice".to_string()).into()); + return Err(ErrorKind::Msg("Invalid path for virtio mmio blk device".to_string()).into()); } update_spec_device_list(device, spec) @@ -346,13 +292,10 @@ fn virtiommio_blk_device_handler( fn virtio_blk_device_handler( device: &Device, spec: &mut Spec, - sandbox: Arc>, + sandbox: &Arc>, ) -> Result<()> { - let dev_path = get_pci_device_name(sandbox, device.id.as_str())?; - let mut dev = device.clone(); - dev.vm_path = dev_path; - + dev.vm_path = get_pci_device_name(sandbox, &device.id)?; update_spec_device_list(&dev, spec) } @@ -360,39 +303,39 @@ fn virtio_blk_device_handler( fn virtio_scsi_device_handler( device: &Device, spec: &mut Spec, - sandbox: Arc>, + sandbox: &Arc>, ) -> Result<()> { - let dev_path = get_scsi_device_name(sandbox, device.id.as_str())?; - let mut dev = device.clone(); - dev.vm_path = dev_path; - + dev.vm_path = get_scsi_device_name(sandbox, &device.id)?; update_spec_device_list(&dev, spec) } fn virtio_nvdimm_device_handler( device: &Device, spec: &mut Spec, - _sandbox: Arc>, + _sandbox: &Arc>, ) -> Result<()> { + if device.vm_path == "" { + return Err(ErrorKind::Msg("Invalid path for nvdimm device".to_string()).into()); + } + update_spec_device_list(device, spec) } pub fn add_devices( - devices: Vec, + devices: &[Device], spec: &mut Spec, - sandbox: Arc>, + sandbox: &Arc>, ) -> Result<()> { for device in devices.iter() { - add_device(device, spec, sandbox.clone())?; + add_device(device, spec, sandbox)?; } Ok(()) } -fn add_device(device: &Device, spec: &mut Spec, sandbox: Arc>) -> Result<()> { - // log before validation to help with debugging gRPC protocol - // version differences. +fn add_device(device: &Device, spec: &mut Spec, sandbox: &Arc>) -> Result<()> { + // log before validation to help with debugging gRPC protocol version differences. info!(sl!(), "device-id: {}, device-type: {}, device-vm-path: {}, device-container-path: {}, device-options: {:?}", device.id, device.field_type, device.vm_path, device.container_path, device.options); @@ -412,12 +355,8 @@ fn add_device(device: &Device, spec: &mut Spec, sandbox: Arc>) -> ); } - let dev_handler = match DEVICEHANDLERLIST.get(device.field_type.as_str()) { - None => { - return Err(ErrorKind::Msg(format!("Unknown device type {}", device.field_type)).into()) - } - Some(t) => t, - }; - - dev_handler(device, spec, sandbox) + match DEVICEHANDLERLIST.get(device.field_type.as_str()) { + None => Err(ErrorKind::Msg(format!("Unknown device type {}", device.field_type)).into()), + Some(dev_handler) => dev_handler(device, spec, sandbox), + } } diff --git a/src/agent/src/grpc.rs b/src/agent/src/grpc.rs index 2e3147f23a..afe716c75d 100644 --- a/src/agent/src/grpc.rs +++ b/src/agent/src/grpc.rs @@ -30,6 +30,7 @@ use nix::unistd::{self, Pid}; use rustjail::process::ProcessOperations; use crate::device::{add_devices, rescan_pci_bus}; +use crate::linux_abi::*; use crate::mount::{add_storages, remove_mounts, STORAGEHANDLERLIST}; use crate::namespace::{NSTYPEIPC, NSTYPEPID, NSTYPEUTS}; use crate::netlink::{RtnlHandle, NETLINK_ROUTE}; @@ -53,10 +54,7 @@ use std::io::{BufRead, BufReader}; use std::os::unix::fs::FileExt; use std::path::PathBuf; -const SYSFS_MEMORY_BLOCK_SIZE_PATH: &'static str = "/sys/devices/system/memory/block_size_bytes"; -const SYSFS_MEMORY_HOTPLUG_PROBE_PATH: &'static str = "/sys/devices/system/memory/probe"; -pub const SYSFS_MEMORY_ONLINE_PATH: &'static str = "/sys/devices/system/memory"; -const CONTAINER_BASE: &'static str = "/run/kata-containers"; +const CONTAINER_BASE: &str = "/run/kata-containers"; // Convenience macro to obtain the scope logger macro_rules! sl { @@ -95,7 +93,7 @@ impl agentService { // updates the devices listed in the OCI spec, so that they actually // match real devices inside the VM. This step is necessary since we // cannot predict everything from the caller. - add_devices(req.devices.to_vec(), oci, self.sandbox.clone())?; + add_devices(&req.devices.to_vec(), oci, &self.sandbox)?; // Both rootfs and volumes (invoked with --volume for instance) will // be processed the same way. The idea is to always mount any provided diff --git a/src/agent/src/linux_abi.rs b/src/agent/src/linux_abi.rs new file mode 100644 index 0000000000..69edb20d9d --- /dev/null +++ b/src/agent/src/linux_abi.rs @@ -0,0 +1,50 @@ +// Copyright (c) 2019 Ant Financial +// +// SPDX-License-Identifier: Apache-2.0 +// + +/// Linux ABI related constants. + +pub const SYSFS_DIR: &str = "/sys"; + +pub const SYSFS_PCI_BUS_PREFIX: &str = "/sys/bus/pci/devices"; +pub const SYSFS_PCI_BUS_RESCAN_FILE: &str = "/sys/bus/pci/rescan"; +#[cfg(any( + target_arch = "powerpc64le", + target_arch = "s390x", + target_arch = "x86_64", + target_arch = "x86" +))] +pub const PCI_ROOT_BUS_PATH: &str = "/devices/pci0000:00"; +#[cfg(target_arch = "arm")] +pub const PCI_ROOT_BUS_PATH: &str = "/devices/platform/4010000000.pcie/pci0000:00"; + +pub const SYSFS_CPU_ONLINE_PATH: &str = "/sys/devices/system/cpu"; + +pub const SYSFS_MEMORY_BLOCK_SIZE_PATH: &str = "/sys/devices/system/memory/block_size_bytes"; +pub const SYSFS_MEMORY_HOTPLUG_PROBE_PATH: &str = "/sys/devices/system/memory/probe"; +pub const SYSFS_MEMORY_ONLINE_PATH: &str = "/sys/devices/system/memory"; + +// Here in "0:0", the first number is the SCSI host number because +// only one SCSI controller has been plugged, while the second number +// is always 0. +pub const SCSI_HOST_CHANNEL: &str = "0:0:"; +pub const SCSI_BLOCK_SUFFIX: &str = "block"; +pub const SYSFS_SCSI_HOST_PATH: &str = "/sys/class/scsi_host"; + +pub const SYSFS_CGROUPPATH: &str = "/sys/fs/cgroup"; +pub const SYSFS_ONLINE_FILE: &str = "online"; + +pub const PROC_MOUNTSTATS: &str = "/proc/self/mountstats"; +pub const PROC_CGROUPS: &str = "/proc/cgroups"; + +pub const SYSTEM_DEV_PATH: &str = "/dev"; + +// Linux UEvent related consts. +pub const U_EVENT_ACTION: &str = "ACTION"; +pub const U_EVENT_ACTION_ADD: &str = "add"; +pub const U_EVENT_DEV_PATH: &str = "DEVPATH"; +pub const U_EVENT_SUB_SYSTEM: &str = "SUBSYSTEM"; +pub const U_EVENT_SEQ_NUM: &str = "SEQNUM"; +pub const U_EVENT_DEV_NAME: &str = "DEVNAME"; +pub const U_EVENT_INTERFACE: &str = "INTERFACE"; diff --git a/src/agent/src/logging.rs b/src/agent/src/logging.rs index e6a93ff0fd..a5d4a94eed 100644 --- a/src/agent/src/logging.rs +++ b/src/agent/src/logging.rs @@ -9,7 +9,7 @@ use std::io; use std::io::Write; use std::process; use std::result; -use std::sync::{Arc, Mutex}; +use std::sync::Mutex; // XXX: 'writer' param used to make testing possible. pub fn create_logger(name: &str, source: &str, level: slog::Level, writer: W) -> slog::Logger @@ -41,16 +41,6 @@ where ) } -impl KV for HashSerializer { - fn serialize(&self, _record: &Record, serializer: &mut dyn slog::Serializer) -> slog::Result { - for (key, value) in self.fields.clone().into_iter() { - serializer.emit_str(Key::from(key), &value)?; - } - - Ok(()) - } -} - // Used to convert an slog::OwnedKVList into a hash map. struct HashSerializer { fields: HashMap, @@ -77,6 +67,16 @@ impl HashSerializer { } } +impl KV for HashSerializer { + fn serialize(&self, _record: &Record, serializer: &mut dyn slog::Serializer) -> slog::Result { + for (key, value) in self.fields.iter() { + serializer.emit_str(Key::from(key.to_string()), value)?; + } + + Ok(()) + } +} + impl slog::Serializer for HashSerializer { fn emit_arguments(&mut self, key: Key, value: &std::fmt::Arguments) -> slog::Result { self.add_field(format!("{}", key), format!("{}", value)); @@ -90,13 +90,13 @@ struct UniqueDrain { impl UniqueDrain { fn new(drain: D) -> Self { - UniqueDrain { drain: drain } + UniqueDrain { drain } } } impl Drain for UniqueDrain where - D: slog::Drain, + D: Drain, { type Ok = (); type Err = io::Error; @@ -136,21 +136,19 @@ where // specified in the struct. struct RuntimeLevelFilter { drain: D, - level: Arc>, + level: Mutex, } impl RuntimeLevelFilter { fn new(drain: D, level: slog::Level) -> Self { RuntimeLevelFilter { - drain: drain, - level: Arc::new(Mutex::new(level)), + drain, + level: Mutex::new(level), } } fn set_level(&self, level: slog::Level) { - let level_ref = self.level.clone(); - - let mut log_level = level_ref.lock().unwrap(); + let mut log_level = self.level.lock().unwrap(); *log_level = level; } @@ -168,9 +166,7 @@ where record: &slog::Record, values: &slog::OwnedKVList, ) -> result::Result { - let level_ref = self.level.clone(); - - let log_level = level_ref.lock().unwrap(); + let log_level = self.level.lock().unwrap(); if record.level().is_at_least(*log_level) { self.drain.log(record, values)?; diff --git a/src/agent/src/main.rs b/src/agent/src/main.rs index a964911782..6f34bb0119 100644 --- a/src/agent/src/main.rs +++ b/src/agent/src/main.rs @@ -34,7 +34,7 @@ use signal_hook::{iterator::Signals, SIGCHLD}; use std::collections::HashMap; use std::env; use std::fs; -use std::os::unix::fs::{self as unixfs}; +use std::os::unix::fs as unixfs; use std::os::unix::io::AsRawFd; use std::path::Path; use std::sync::mpsc::{self, Sender}; @@ -44,6 +44,7 @@ use unistd::Pid; mod config; mod device; +mod linux_abi; mod logging; mod mount; mod namespace; @@ -63,11 +64,11 @@ use uevent::watch_uevents; mod grpc; -const NAME: &'static str = "kata-agent"; -const VSOCK_ADDR: &'static str = "vsock://-1"; +const NAME: &str = "kata-agent"; +const VSOCK_ADDR: &str = "vsock://-1"; const VSOCK_PORT: u16 = 1024; -const KERNEL_CMDLINE_FILE: &'static str = "/proc/cmdline"; -const CONSOLE_PATH: &'static str = "/dev/console"; +const KERNEL_CMDLINE_FILE: &str = "/proc/cmdline"; +const CONSOLE_PATH: &str = "/dev/console"; lazy_static! { static ref GLOBAL_DEVICE_WATCHER: Arc>>> = diff --git a/src/agent/src/mount.rs b/src/agent/src/mount.rs index c3212096a6..2fccdca8cd 100644 --- a/src/agent/src/mount.rs +++ b/src/agent/src/mount.rs @@ -23,27 +23,21 @@ use std::fs::File; use std::io::{BufRead, BufReader}; use crate::device::{get_pci_device_name, get_scsi_device_name, online_device}; +use crate::linux_abi::*; use crate::protocols::agent::Storage; use crate::Sandbox; use slog::Logger; -const DRIVER9PTYPE: &'static str = "9p"; -const DRIVERVIRTIOFSTYPE: &'static str = "virtio-fs"; -pub const DRIVERBLKTYPE: &'static str = "blk"; -pub const DRIVERMMIOBLKTYPE: &'static str = "mmioblk"; -pub const DRIVERSCSITYPE: &'static str = "scsi"; -pub const DRIVERNVDIMMTYPE: &'static str = "nvdimm"; -const DRIVEREPHEMERALTYPE: &'static str = "ephemeral"; -const DRIVERLOCALTYPE: &'static str = "local"; +pub const DRIVER9PTYPE: &str = "9p"; +pub const DRIVERVIRTIOFSTYPE: &str = "virtio-fs"; +pub const DRIVERBLKTYPE: &str = "blk"; +pub const DRIVERMMIOBLKTYPE: &str = "mmioblk"; +pub const DRIVERSCSITYPE: &str = "scsi"; +pub const DRIVERNVDIMMTYPE: &str = "nvdimm"; +pub const DRIVEREPHEMERALTYPE: &str = "ephemeral"; +pub const DRIVERLOCALTYPE: &str = "local"; -pub const TYPEROOTFS: &'static str = "rootfs"; - -pub const PROCMOUNTSTATS: &'static str = "/proc/self/mountstats"; - -const ROOTBUSPATH: &'static str = "/devices/pci0000:00"; - -const CGROUPPATH: &'static str = "/sys/fs/cgroup"; -const PROCCGROUPS: &'static str = "/proc/cgroups"; +pub const TYPEROOTFS: &str = "rootfs"; #[cfg_attr(rustfmt, rustfmt_skip)] lazy_static! { @@ -341,7 +335,7 @@ fn virtio_blk_storage_handler( return Err(ErrorKind::ErrorCode(format!("Invalid device {}", &storage.source)).into()); } } else { - let dev_path = get_pci_device_name(sandbox, &storage.source)?; + let dev_path = get_pci_device_name(&sandbox, &storage.source)?; storage.source = dev_path; } @@ -357,7 +351,7 @@ fn virtio_scsi_storage_handler( let mut storage = storage.clone(); // Retrieve the device path from SCSI address. - let dev_path = get_scsi_device_name(sandbox, &storage.source)?; + let dev_path = get_scsi_device_name(&sandbox, &storage.source)?; storage.source = dev_path; common_storage_handler(logger, &storage) @@ -509,7 +503,7 @@ pub fn general_mount(logger: &Logger) -> Result<()> { #[inline] pub fn get_mount_fs_type(mount_point: &str) -> Result { - get_mount_fs_type_from_file(PROCMOUNTSTATS, mount_point) + get_mount_fs_type_from_file(PROC_MOUNTSTATS, mount_point) } // get_mount_fs_type returns the FS type corresponding to the passed mount point and @@ -553,7 +547,7 @@ pub fn get_cgroup_mounts(logger: &Logger, cg_path: &str) -> Result = vec![INIT_MOUNT { fstype: "tmpfs", src: "tmpfs", - dest: CGROUPPATH, + dest: SYSFS_CGROUPPATH, options: vec!["nosuid", "nodev", "noexec", "mode=755"], }]; @@ -613,7 +607,7 @@ pub fn get_cgroup_mounts(logger: &Logger, cg_path: &str) -> Result Result Result<()> { let logger = logger.new(o!("subsystem" => "mount")); - let cgroups = get_cgroup_mounts(&logger, PROCCGROUPS)?; + let cgroups = get_cgroup_mounts(&logger, PROC_CGROUPS)?; for cg in cgroups.iter() { mount_to_rootfs(&logger, cg)?; @@ -1103,14 +1097,14 @@ mod tests { let first_mount = INIT_MOUNT { fstype: "tmpfs", src: "tmpfs", - dest: CGROUPPATH, + dest: SYSFS_CGROUPPATH, options: vec!["nosuid", "nodev", "noexec", "mode=755"], }; let last_mount = INIT_MOUNT { fstype: "tmpfs", src: "tmpfs", - dest: CGROUPPATH, + dest: SYSFS_CGROUPPATH, options: vec!["remount", "ro", "nosuid", "nodev", "noexec", "mode=755"], }; diff --git a/src/agent/src/namespace.rs b/src/agent/src/namespace.rs index 81db13fd11..e4d2ee1b35 100644 --- a/src/agent/src/namespace.rs +++ b/src/agent/src/namespace.rs @@ -17,10 +17,10 @@ use crate::mount::{BareMount, FLAGS}; use slog::Logger; //use container::Process; -const PERSISTENT_NS_DIR: &'static str = "/var/run/sandbox-ns"; -pub const NSTYPEIPC: &'static str = "ipc"; -pub const NSTYPEUTS: &'static str = "uts"; -pub const NSTYPEPID: &'static str = "pid"; +const PERSISTENT_NS_DIR: &str = "/var/run/sandbox-ns"; +pub const NSTYPEIPC: &str = "ipc"; +pub const NSTYPEUTS: &str = "uts"; +pub const NSTYPEPID: &str = "pid"; pub fn get_current_thread_ns_path(ns_type: &str) -> String { format!( @@ -64,7 +64,7 @@ impl Namespace { self } - // setup_persistent_ns creates persistent namespace without switchin to it. + // setup_persistent_ns creates persistent namespace without switching to it. // Note, pid namespaces cannot be persisted. pub fn setup(mut self) -> Result { if let Err(err) = fs::create_dir_all(&self.persistent_ns_dir) { @@ -81,14 +81,9 @@ impl Namespace { return Err(err.to_string()); } - self.path = new_ns_path.into_os_string().into_string().unwrap(); + self.path = new_ns_path.clone().into_os_string().into_string().unwrap(); let new_thread = thread::spawn(move || { - let ns_path = ns_path.clone(); - let ns_type = ns_type.clone(); - let logger = logger; - let new_ns_path = ns_path.join(&ns_type.get()); - let origin_ns_path = get_current_thread_ns_path(&ns_type.get()); let _origin_ns_fd = match File::open(Path::new(&origin_ns_path)) { @@ -107,8 +102,6 @@ impl Namespace { let source: &str = origin_ns_path.as_str(); let destination: &str = new_ns_path.as_path().to_str().unwrap_or("none"); - let _recursive = true; - let _readonly = true; let mut flags = MsFlags::empty(); match FLAGS.get("rbind") { @@ -120,13 +113,13 @@ impl Namespace { }; let bare_mount = BareMount::new(source, destination, "none", flags, "", &logger); - if let Err(err) = bare_mount.mount() { return Err(format!( "Failed to mount {} to {} with err:{:?}", source, destination, err )); } + Ok(()) }); @@ -152,11 +145,11 @@ enum NamespaceType { impl NamespaceType { /// Get the string representation of the namespace type. - pub fn get(&self) -> String { + pub fn get(&self) -> &str { match *self { - Self::IPC => String::from("ipc"), - Self::UTS => String::from("uts"), - Self::PID => String::from("pid"), + Self::IPC => "ipc", + Self::UTS => "uts", + Self::PID => "pid", } } diff --git a/src/agent/src/random.rs b/src/agent/src/random.rs index f05c4345b5..23446d3212 100644 --- a/src/agent/src/random.rs +++ b/src/agent/src/random.rs @@ -10,7 +10,7 @@ use nix::sys::stat::Mode; use rustjail::errors::*; use std::fs; -pub const RNGDEV: &'static str = "/dev/random"; +pub const RNGDEV: &str = "/dev/random"; pub const RNDADDTOENTCNT: libc::c_int = 0x40045201; pub const RNDRESEEDRNG: libc::c_int = 0x5207; diff --git a/src/agent/src/sandbox.rs b/src/agent/src/sandbox.rs index a60aecc356..13f365bc74 100644 --- a/src/agent/src/sandbox.rs +++ b/src/agent/src/sandbox.rs @@ -4,6 +4,7 @@ // //use crate::container::Container; +use crate::linux_abi::*; use crate::mount::{get_mount_fs_type, remove_mounts, TYPEROOTFS}; use crate::namespace::Namespace; use crate::netlink::{RtnlHandle, NETLINK_ROUTE}; @@ -36,7 +37,6 @@ pub struct Sandbox { pub storages: HashMap, pub running: bool, pub no_pivot_root: bool, - enable_grpc_trace: bool, pub sandbox_pid_ns: bool, pub sender: Option>, pub rtnl: Option, @@ -49,8 +49,8 @@ impl Sandbox { Ok(Sandbox { logger: logger.clone(), - id: "".to_string(), - hostname: "".to_string(), + id: String::new(), + hostname: String::new(), network: Network::new(), containers: HashMap::new(), mounts: Vec::new(), @@ -61,17 +61,37 @@ impl Sandbox { storages: HashMap::new(), running: false, no_pivot_root: fs_type.eq(TYPEROOTFS), - enable_grpc_trace: false, sandbox_pid_ns: false, sender: None, rtnl: Some(RtnlHandle::new(NETLINK_ROUTE, 0).unwrap()), }) } + // set_sandbox_storage sets the sandbox level reference + // counter for the sandbox storage. + // This method also returns a boolean to let + // callers know if the storage already existed or not. + // It will return true if storage is new. + // + // It's assumed that caller is calling this method after + // acquiring a lock on sandbox. + pub fn set_sandbox_storage(&mut self, path: &str) -> bool { + match self.storages.get_mut(path) { + None => { + self.storages.insert(path.to_string(), 1); + true + } + Some(count) => { + *count += 1; + false + } + } + } + // unset_sandbox_storage will decrement the sandbox storage // reference counter. If there aren't any containers using // that sandbox storage, this method will remove the - // storage reference from the sandbox and return 'true, nil' to + // storage reference from the sandbox and return 'true' to // let the caller know that they can clean up the storage // related directories by calling remove_sandbox_storage // @@ -84,8 +104,9 @@ impl Sandbox { *count -= 1; if *count < 1 { self.storages.remove(path); + return true; } - return true; + false } } } @@ -158,7 +179,7 @@ impl Sandbox { self.containers.get_mut(id) } - pub fn find_process<'a>(&'a mut self, pid: pid_t) -> Option<&'a mut Process> { + pub fn find_process(&mut self, pid: pid_t) -> Option<&mut Process> { for (_, c) in self.containers.iter_mut() { if c.processes.get(&pid).is_some() { return c.processes.get_mut(&pid); @@ -168,27 +189,6 @@ impl Sandbox { None } - // set_sandbox_storage sets the sandbox level reference - // counter for the sandbox storage. - // This method also returns a boolean to let - // callers know if the storage already existed or not. - // It will return true if storage is new. - // - // It's assumed that caller is calling this method after - // acquiring a lock on sandbox. - pub fn set_sandbox_storage(&mut self, path: &str) -> bool { - match self.storages.get_mut(path) { - None => { - self.storages.insert(path.to_string(), 1); - true - } - Some(count) => { - *count += 1; - false - } - } - } - pub fn destroy(&mut self) -> Result<()> { for (_, ctr) in &mut self.containers { ctr.destroy()?; @@ -221,10 +221,6 @@ impl Sandbox { } } -pub const CPU_ONLINE_PATH: &'static str = "/sys/devices/system/cpu"; -pub const MEMORY_ONLINE_PATH: &'static str = "/sys/devices/system/memory"; -pub const ONLINE_FILE: &'static str = "online"; - fn online_resources(logger: &Logger, path: &str, pattern: &str, num: i32) -> Result { let mut count = 0; let re = Regex::new(pattern)?; @@ -236,7 +232,7 @@ fn online_resources(logger: &Logger, path: &str, pattern: &str, num: i32) -> Res let p = entry.path(); if re.is_match(name) { - let file = format!("{}/{}", p.to_str().unwrap(), ONLINE_FILE); + let file = format!("{}/{}", p.to_str().unwrap(), SYSFS_ONLINE_FILE); info!(logger, "{}", file.as_str()); let c = fs::read_to_string(file.as_str())?; @@ -259,10 +255,10 @@ fn online_resources(logger: &Logger, path: &str, pattern: &str, num: i32) -> Res } fn online_cpus(logger: &Logger, num: i32) -> Result { - online_resources(logger, CPU_ONLINE_PATH, r"cpu[0-9]+", num) + online_resources(logger, SYSFS_CPU_ONLINE_PATH, r"cpu[0-9]+", num) } fn online_memory(logger: &Logger) -> Result<()> { - online_resources(logger, MEMORY_ONLINE_PATH, r"memory[0-9]+", -1)?; + online_resources(logger, SYSFS_MEMORY_ONLINE_PATH, r"memory[0-9]+", -1)?; Ok(()) } diff --git a/src/agent/src/uevent.rs b/src/agent/src/uevent.rs index 11ca068c32..9165a155d3 100644 --- a/src/agent/src/uevent.rs +++ b/src/agent/src/uevent.rs @@ -3,23 +3,17 @@ // SPDX-License-Identifier: Apache-2.0 // -use crate::device::{online_device, ROOT_BUS_PATH, SCSI_BLOCK_SUFFIX, SYSFS_DIR}; -use crate::grpc::SYSFS_MEMORY_ONLINE_PATH; +use crate::device::online_device; +use crate::linux_abi::*; use crate::netlink::{RtnlHandle, NETLINK_UEVENT}; use crate::sandbox::Sandbox; use crate::GLOBAL_DEVICE_WATCHER; +use slog::Logger; use std::sync::{Arc, Mutex}; use std::thread; -pub const U_EVENT_ACTION: &'static str = "ACTION"; -pub const U_EVENT_DEV_PATH: &'static str = "DEVPATH"; -pub const U_EVENT_SUB_SYSTEM: &'static str = "SUBSYSTEM"; -pub const U_EVENT_SEQ_NUM: &'static str = "SEQNUM"; -pub const U_EVENT_DEV_NAME: &'static str = "DEVNAME"; -pub const U_EVENT_INTERFACE: &'static str = "INTERFACE"; - #[derive(Debug, Default)] -pub struct Uevent { +struct Uevent { action: String, devpath: String, devname: String, @@ -28,36 +22,107 @@ pub struct Uevent { interface: String, } -fn parse_uevent(message: &str) -> Uevent { - let mut msg_iter = message.split('\0'); - let mut event = Uevent::default(); +impl Uevent { + fn new(message: &str) -> Self { + let mut msg_iter = message.split('\0'); + let mut event = Uevent::default(); - msg_iter.next(); // skip the first value - for arg in msg_iter { - let key_val: Vec<&str> = arg.splitn(2, '=').collect(); - if key_val.len() == 2 { - match key_val[0] { - U_EVENT_ACTION => event.action = String::from(key_val[1]), - U_EVENT_DEV_NAME => event.devname = String::from(key_val[1]), - U_EVENT_SUB_SYSTEM => event.subsystem = String::from(key_val[1]), - U_EVENT_DEV_PATH => event.devpath = String::from(key_val[1]), - U_EVENT_SEQ_NUM => event.seqnum = String::from(key_val[1]), - U_EVENT_INTERFACE => event.interface = String::from(key_val[1]), - _ => (), + msg_iter.next(); // skip the first value + for arg in msg_iter { + let key_val: Vec<&str> = arg.splitn(2, '=').collect(); + if key_val.len() == 2 { + match key_val[0] { + U_EVENT_ACTION => event.action = String::from(key_val[1]), + U_EVENT_DEV_NAME => event.devname = String::from(key_val[1]), + U_EVENT_SUB_SYSTEM => event.subsystem = String::from(key_val[1]), + U_EVENT_DEV_PATH => event.devpath = String::from(key_val[1]), + U_EVENT_SEQ_NUM => event.seqnum = String::from(key_val[1]), + U_EVENT_INTERFACE => event.interface = String::from(key_val[1]), + _ => (), + } } } + + event + } + + // Check whether this is a block device hot-add event. + fn is_block_add_event(&self) -> bool { + self.action == U_EVENT_ACTION_ADD + && self.subsystem == "block" + && self.devpath.starts_with(PCI_ROOT_BUS_PATH) + && self.devname != "" + } + + fn handle_block_add_event(&self, sandbox: &Arc>) { + // Keep the same lock order as device::get_device_name(), otherwise it may cause deadlock. + let mut w = GLOBAL_DEVICE_WATCHER.lock().unwrap(); + let mut sb = sandbox.lock().unwrap(); + + // Add the device node name to the pci device map. + sb.pci_device_map + .insert(self.devpath.clone(), self.devname.clone()); + + // Notify watchers that are interested in the udev event. + // Close the channel after watcher has been notified. + let devpath = self.devpath.clone(); + let empties: Vec<_> = w + .iter() + .filter(|(dev_addr, _)| { + let pci_p = format!("{}/{}", PCI_ROOT_BUS_PATH, *dev_addr); + + // blk block device + devpath.starts_with(pci_p.as_str()) || + // scsi block device + { + (*dev_addr).ends_with(SCSI_BLOCK_SUFFIX) && + devpath.contains(*dev_addr) + } + }) + .map(|(k, sender)| { + let devname = self.devname.clone(); + let _ = sender.send(devname); + k.clone() + }) + .collect(); + + // Remove notified nodes from the watcher map. + for empty in empties { + w.remove(&empty); + } } - event + fn process(&self, logger: &Logger, sandbox: &Arc>) { + if self.is_block_add_event() { + return self.handle_block_add_event(sandbox); + } else if self.action == U_EVENT_ACTION_ADD { + let online_path = format!("{}/{}/online", SYSFS_DIR, &self.devpath); + // It's a memory hot-add event. + if online_path.starts_with(SYSFS_MEMORY_ONLINE_PATH) { + if let Err(e) = online_device(online_path.as_ref()) { + error!( + *logger, + "failed to online device"; + "device" => &self.devpath, + "error" => format!("{}", e), + ); + } + return; + } + } + debug!(*logger, "ignoring event"; "uevent" => format!("{:?}", self)); + } } pub fn watch_uevents(sandbox: Arc>) { - let sref = sandbox.clone(); - let s = sref.lock().unwrap(); - let logger = s.logger.new(o!("subsystem" => "uevent")); - thread::spawn(move || { let rtnl = RtnlHandle::new(NETLINK_UEVENT, 1).unwrap(); + let logger = sandbox + .lock() + .unwrap() + .logger + .new(o!("subsystem" => "uevent")); + loop { match rtnl.recv_message() { Err(e) => { @@ -70,68 +135,9 @@ pub fn watch_uevents(sandbox: Arc>) { error!(logger, "failed to convert bytes to text"; "error" => format!("{}", e)) } Ok(text) => { - let event = parse_uevent(&text); + let event = Uevent::new(&text); info!(logger, "got uevent message"; "event" => format!("{:?}", event)); - - // Check if device hotplug event results in a device node being created. - if event.devname != "" - && event.devpath.starts_with(ROOT_BUS_PATH) - && event.subsystem == "block" - { - let watcher = GLOBAL_DEVICE_WATCHER.clone(); - let mut w = watcher.lock().unwrap(); - - let s = sandbox.clone(); - let mut sb = s.lock().unwrap(); - - // Add the device node name to the pci device map. - sb.pci_device_map - .insert(event.devpath.clone(), event.devname.clone()); - - // Notify watchers that are interested in the udev event. - // Close the channel after watcher has been notified. - - let devpath = event.devpath.clone(); - - let empties: Vec<_> = w - .iter() - .filter(|(dev_addr, _)| { - let pci_p = format!("{}/{}", ROOT_BUS_PATH, *dev_addr); - - // blk block device - devpath.starts_with(pci_p.as_str()) || - // scsi block device - { - (*dev_addr).ends_with(SCSI_BLOCK_SUFFIX) && - devpath.contains(*dev_addr) - } - }) - .map(|(k, sender)| { - let devname = event.devname.clone(); - let _ = sender.send(devname); - k.clone() - }) - .collect(); - - for empty in empties { - w.remove(&empty); - } - } else { - let online_path = - format!("{}/{}/online", SYSFS_DIR, &event.devpath); - if online_path.starts_with(SYSFS_MEMORY_ONLINE_PATH) { - // Check memory hotplug and online if possible - match online_device(online_path.as_ref()) { - Ok(_) => (), - Err(e) => error!( - logger, - "failed to online device"; - "device" => &event.devpath, - "error" => format!("{}", e), - ), - } - } - } + event.process(&logger, &sandbox); } } } diff --git a/src/agent/src/version.rs b/src/agent/src/version.rs index 9c9f6e21c1..4902438661 100644 --- a/src/agent/src/version.rs +++ b/src/agent/src/version.rs @@ -3,5 +3,5 @@ // SPDX-License-Identifier: Apache-2.0 // -pub const AGENT_VERSION: &'static str = "1.4.5"; -pub const API_VERSION: &'static str = "0.0.1"; +pub const AGENT_VERSION: &str = "1.4.5"; +pub const API_VERSION: &str = "0.0.1";