agent: refine device.rs for better maintenance

1) pass reference instead of value when possible.
2) simplify code.
3) rename get_device_pci_address() as get_pci_device_address() to keep
   consistency get_pci_device_name().
4) refine get_device_name() for maintenance.

Signed-off-by: Liu Jiang <gerry@linux.alibaba.com>
This commit is contained in:
Liu Jiang 2019-12-02 12:13:58 +08:00
parent 94311e4997
commit 000bb8592d
3 changed files with 81 additions and 120 deletions

View File

@ -3,21 +3,19 @@
// SPDX-License-Identifier: Apache-2.0
//
use rustjail::errors::*;
use std::fs;
// use std::io::Write;
use libc::{c_uint, major, minor};
use std::collections::HashMap;
use std::fs;
use std::os::unix::fs::MetadataExt;
use std::path::Path;
use std::sync::mpsc;
use std::sync::{Arc, Mutex};
use std::sync::{mpsc, Arc, Mutex};
use crate::mount::{DRIVERBLKTYPE, DRIVERMMIOBLKTYPE, DRIVERNVDIMMTYPE, DRIVERSCSITYPE};
use crate::sandbox::Sandbox;
use crate::{AGENT_CONFIG, GLOBAL_DEVICE_WATCHER};
use protocols::agent::Device;
use protocols::oci::Spec;
use rustjail::errors::*;
// Convenience macro to obtain the scope logger
macro_rules! sl {
@ -54,23 +52,18 @@ pub const SCSI_BLOCK_SUFFIX: &str = "block";
const SCSI_DISK_SUFFIX: &str = "/device/block";
const SCSI_HOST_PATH: &str = "/sys/class/scsi_host";
// DeviceHandler is the type of callback to be defined to handle every
// type of device driver.
type DeviceHandler = fn(&Device, &mut Spec, Arc<Mutex<Sandbox>>) -> Result<()>;
// DeviceHandler is the type of callback to be defined to handle every type of device driver.
type DeviceHandler = fn(&Device, &mut Spec, &Arc<Mutex<Sandbox>>) -> Result<()>;
// DeviceHandlerList lists the supported drivers.
#[cfg_attr(rustfmt, rustfmt_skip)]
lazy_static! {
pub static ref DEVICEHANDLERLIST: HashMap<&'static str, DeviceHandler> = {
let mut m = HashMap::new();
let blk: DeviceHandler = virtio_blk_device_handler;
m.insert(DRIVERBLKTYPE, blk);
let virtiommio: DeviceHandler = virtiommio_blk_device_handler;
m.insert(DRIVERMMIOBLKTYPE, virtiommio);
let local: DeviceHandler = virtio_nvdimm_device_handler;
m.insert(DRIVERNVDIMMTYPE, local);
let scsi: DeviceHandler = virtio_scsi_device_handler;
m.insert(DRIVERSCSITYPE, scsi);
static ref DEVICEHANDLERLIST: HashMap<&'static str, DeviceHandler> = {
let mut m: HashMap<&'static str, DeviceHandler> = HashMap::new();
m.insert(DRIVERBLKTYPE, virtio_blk_device_handler);
m.insert(DRIVERMMIOBLKTYPE, virtiommio_blk_device_handler);
m.insert(DRIVERNVDIMMTYPE, virtio_nvdimm_device_handler);
m.insert(DRIVERSCSITYPE, virtio_scsi_device_handler);
m
};
}
@ -84,11 +77,11 @@ pub fn online_device(path: &str) -> Result<()> {
Ok(())
}
// get_device_pci_address fetches the complete PCI address in sysfs, based on the PCI
// get_pci_device_address fetches the complete PCI address in sysfs, based on the PCI
// identifier provided. This should be in the format: "bridgeAddr/deviceAddr".
// Here, bridgeAddr is the address at which the brige is attached on the root bus,
// Here, bridgeAddr is the address at which the bridge is attached on the root bus,
// while deviceAddr is the address at which the device is attached on the bridge.
pub fn get_device_pci_address(pci_id: &str) -> Result<String> {
fn get_pci_device_address(pci_id: &str) -> Result<String> {
let tokens: Vec<&str> = pci_id.split("/").collect();
if tokens.len() != 2 {
@ -139,78 +132,60 @@ pub fn get_device_pci_address(pci_id: &str) -> Result<String> {
Ok(bridge_device_pci_addr)
}
pub fn get_device_name(sandbox: Arc<Mutex<Sandbox>>, dev_addr: &str) -> Result<String> {
let mut dev_name: String = String::default();
let (tx, rx) = mpsc::channel::<String>();
{
let watcher = GLOBAL_DEVICE_WATCHER.clone();
let mut w = watcher.lock().unwrap();
let s = sandbox.clone();
let sb = s.lock().unwrap();
for (key, value) in &(sb.pci_device_map) {
fn get_device_name(sandbox: &Arc<Mutex<Sandbox>>, dev_addr: &str) -> Result<String> {
// Keep the same lock order as uevent::handle_block_add_event(), otherwise it may cause deadlock.
let mut w = GLOBAL_DEVICE_WATCHER.lock().unwrap();
let sb = sandbox.lock().unwrap();
for (key, value) in sb.pci_device_map.iter() {
if key.contains(dev_addr) {
dev_name = value.to_string();
info!(sl!(), "Device {} found in pci device map", dev_addr);
break;
return Ok(format!("{}/{}", SYSTEM_DEV_PATH, value));
}
}
drop(sb);
// If device is not found in the device map, hotplug event has not
// been received yet, create and add channel to the watchers map.
// The key of the watchers map is the device we are interested in.
// Note this is done inside the lock, not to miss any events from the
// global udev listener.
if dev_name == "" {
let (tx, rx) = mpsc::channel::<String>();
w.insert(dev_addr.to_string(), tx);
}
}
drop(w);
if dev_name == "" {
info!(sl!(), "Waiting on channel for device notification\n");
let agent_config = AGENT_CONFIG.clone();
let config = agent_config.read().unwrap();
match rx.recv_timeout(config.hotplug_timeout) {
Ok(name) => dev_name = name,
let hotplug_timeout = AGENT_CONFIG.read().unwrap().hotplug_timeout;
let dev_name = match rx.recv_timeout(hotplug_timeout) {
Ok(name) => name,
Err(_) => {
let watcher = GLOBAL_DEVICE_WATCHER.clone();
let mut w = watcher.lock().unwrap();
w.remove_entry(dev_addr);
GLOBAL_DEVICE_WATCHER.lock().unwrap().remove_entry(dev_addr);
return Err(ErrorKind::ErrorCode(format!(
"Timeout reached after {:?} waiting for device {}",
config.hotplug_timeout, dev_addr
hotplug_timeout, dev_addr
))
.into());
}
}
}
};
Ok(format!("{}/{}", SYSTEM_DEV_PATH, &dev_name))
}
pub fn get_scsi_device_name(sandbox: Arc<Mutex<Sandbox>>, scsi_addr: &str) -> Result<String> {
scan_scsi_bus(scsi_addr)?;
pub fn get_scsi_device_name(sandbox: &Arc<Mutex<Sandbox>>, scsi_addr: &str) -> Result<String> {
let dev_sub_path = format!("{}{}/{}", SCSI_HOST_CHANNEL, scsi_addr, SCSI_BLOCK_SUFFIX);
get_device_name(sandbox, dev_sub_path.as_str())
scan_scsi_bus(scsi_addr)?;
get_device_name(sandbox, &dev_sub_path)
}
pub fn get_pci_device_name(sandbox: Arc<Mutex<Sandbox>>, pci_id: &str) -> Result<String> {
let pci_addr = get_device_pci_address(pci_id)?;
pub fn get_pci_device_name(sandbox: &Arc<Mutex<Sandbox>>, pci_id: &str) -> Result<String> {
let pci_addr = get_pci_device_address(pci_id)?;
rescan_pci_bus()?;
get_device_name(sandbox, pci_addr.as_str())
get_device_name(sandbox, &pci_addr)
}
// scan_scsi_bus scans SCSI bus for the given SCSI address(SCSI-Id and LUN)
pub fn scan_scsi_bus(scsi_addr: &str) -> Result<()> {
/// Scan SCSI bus for the given SCSI address(SCSI-Id and LUN)
fn scan_scsi_bus(scsi_addr: &str) -> Result<()> {
let tokens: Vec<&str> = scsi_addr.split(":").collect();
if tokens.len() != 2 {
return Err(ErrorKind::Msg(format!(
@ -220,14 +195,12 @@ pub fn scan_scsi_bus(scsi_addr: &str) -> Result<()> {
.into());
}
// Scan scsi host passing in the channel, SCSI id and LUN. Channel
// is always 0 because we have only one SCSI controller.
// Scan scsi host passing in the channel, SCSI id and LUN.
// Channel is always 0 because we have only one SCSI controller.
let scan_data = format!("0 {} {}", tokens[0], tokens[1]);
for entry in fs::read_dir(SCSI_HOST_PATH)? {
let entry = entry?;
let host = entry.file_name();
let host = entry?.file_name();
let scan_path = format!("{}/{}/{}", SCSI_HOST_PATH, host.to_str().unwrap(), "scan");
fs::write(scan_path, &scan_data)?;
@ -243,9 +216,6 @@ pub fn scan_scsi_bus(scsi_addr: &str) -> Result<()> {
// This is needed to update information about minor/major numbers that cannot
// be predicted from the caller.
fn update_spec_device_list(device: &Device, spec: &mut Spec) -> Result<()> {
// If no container_path is provided, we won't be able to match and
// update the device in the OCI spec device list. This is an error.
let major_id: c_uint;
let minor_id: c_uint;
@ -304,9 +274,7 @@ fn update_spec_device_list(device: &Device, spec: &mut Spec) -> Result<()> {
// Resources must be updated since they are used to identify the
// device in the devices cgroup.
let resource = linux.Resources.as_mut();
if resource.is_some() {
let res = resource.unwrap();
if let Some(res) = linux.Resources.as_mut() {
let ds = res.Devices.as_mut_slice();
for d in ds.iter_mut() {
if d.Major == host_major && d.Minor == host_minor {
@ -331,7 +299,7 @@ fn update_spec_device_list(device: &Device, spec: &mut Spec) -> Result<()> {
fn virtiommio_blk_device_handler(
device: &Device,
spec: &mut Spec,
_sandbox: Arc<Mutex<Sandbox>>,
_sandbox: &Arc<Mutex<Sandbox>>,
) -> Result<()> {
if device.vm_path == "" {
return Err(ErrorKind::Msg("Invalid path for virtio mmio blk device".to_string()).into());
@ -346,13 +314,10 @@ fn virtiommio_blk_device_handler(
fn virtio_blk_device_handler(
device: &Device,
spec: &mut Spec,
sandbox: Arc<Mutex<Sandbox>>,
sandbox: &Arc<Mutex<Sandbox>>,
) -> Result<()> {
let dev_path = get_pci_device_name(sandbox, device.id.as_str())?;
let mut dev = device.clone();
dev.vm_path = dev_path;
dev.vm_path = get_pci_device_name(sandbox, &device.id)?;
update_spec_device_list(&dev, spec)
}
@ -360,39 +325,39 @@ fn virtio_blk_device_handler(
fn virtio_scsi_device_handler(
device: &Device,
spec: &mut Spec,
sandbox: Arc<Mutex<Sandbox>>,
sandbox: &Arc<Mutex<Sandbox>>,
) -> Result<()> {
let dev_path = get_scsi_device_name(sandbox, device.id.as_str())?;
let mut dev = device.clone();
dev.vm_path = dev_path;
dev.vm_path = get_scsi_device_name(sandbox, &device.id)?;
update_spec_device_list(&dev, spec)
}
fn virtio_nvdimm_device_handler(
device: &Device,
spec: &mut Spec,
_sandbox: Arc<Mutex<Sandbox>>,
_sandbox: &Arc<Mutex<Sandbox>>,
) -> Result<()> {
if device.vm_path == "" {
return Err(ErrorKind::Msg("Invalid path for nvdimm device".to_string()).into());
}
update_spec_device_list(device, spec)
}
pub fn add_devices(
devices: Vec<Device>,
devices: &[Device],
spec: &mut Spec,
sandbox: Arc<Mutex<Sandbox>>,
sandbox: &Arc<Mutex<Sandbox>>,
) -> Result<()> {
for device in devices.iter() {
add_device(device, spec, sandbox.clone())?;
add_device(device, spec, sandbox)?;
}
Ok(())
}
fn add_device(device: &Device, spec: &mut Spec, sandbox: Arc<Mutex<Sandbox>>) -> Result<()> {
// log before validation to help with debugging gRPC protocol
// version differences.
fn add_device(device: &Device, spec: &mut Spec, sandbox: &Arc<Mutex<Sandbox>>) -> Result<()> {
// log before validation to help with debugging gRPC protocol version differences.
info!(sl!(), "device-id: {}, device-type: {}, device-vm-path: {}, device-container-path: {}, device-options: {:?}",
device.id, device.field_type, device.vm_path, device.container_path, device.options);
@ -412,12 +377,8 @@ fn add_device(device: &Device, spec: &mut Spec, sandbox: Arc<Mutex<Sandbox>>) ->
);
}
let dev_handler = match DEVICEHANDLERLIST.get(device.field_type.as_str()) {
None => {
return Err(ErrorKind::Msg(format!("Unknown device type {}", device.field_type)).into())
match DEVICEHANDLERLIST.get(device.field_type.as_str()) {
None => Err(ErrorKind::Msg(format!("Unknown device type {}", device.field_type)).into()),
Some(dev_handler) => dev_handler(device, spec, sandbox),
}
Some(t) => t,
};
dev_handler(device, spec, sandbox)
}

View File

@ -95,7 +95,7 @@ impl agentService {
// updates the devices listed in the OCI spec, so that they actually
// match real devices inside the VM. This step is necessary since we
// cannot predict everything from the caller.
add_devices(req.devices.to_vec(), oci, self.sandbox.clone())?;
add_devices(&req.devices.to_vec(), oci, &self.sandbox)?;
// Both rootfs and volumes (invoked with --volume for instance) will
// be processed the same way. The idea is to always mount any provided

View File

@ -341,7 +341,7 @@ fn virtio_blk_storage_handler(
return Err(ErrorKind::ErrorCode(format!("Invalid device {}", &storage.source)).into());
}
} else {
let dev_path = get_pci_device_name(sandbox, &storage.source)?;
let dev_path = get_pci_device_name(&sandbox, &storage.source)?;
storage.source = dev_path;
}
@ -357,7 +357,7 @@ fn virtio_scsi_storage_handler(
let mut storage = storage.clone();
// Retrieve the device path from SCSI address.
let dev_path = get_scsi_device_name(sandbox, &storage.source)?;
let dev_path = get_scsi_device_name(&sandbox, &storage.source)?;
storage.source = dev_path;
common_storage_handler(logger, &storage)