diff --git a/src/runtime-rs/crates/hypervisor/src/device/driver/vfio.rs b/src/runtime-rs/crates/hypervisor/src/device/driver/vfio.rs index 7d8abd2e09..1113d2b879 100644 --- a/src/runtime-rs/crates/hypervisor/src/device/driver/vfio.rs +++ b/src/runtime-rs/crates/hypervisor/src/device/driver/vfio.rs @@ -20,6 +20,7 @@ use crate::{ device::{ pci_path::PciPath, topology::{do_add_pcie_endpoint, PCIeTopology}, + util::{do_decrease_count, do_increase_count}, Device, DeviceType, PCIeDevice, }, register_pcie_device, unregister_pcie_device, update_pcie_device, Hypervisor as hypervisor, @@ -456,7 +457,13 @@ impl Device for VfioDevice { .await .context("failed to increase attach count")? { - return Err(anyhow!("attach count increased failed as some reason.")); + warn!( + sl!(), + "The device {:?} is not allowed to be attached more than one times.", + self.device_id + ); + + return Ok(()); } // do add device for vfio deivce @@ -516,33 +523,11 @@ impl Device for VfioDevice { } async fn increase_attach_count(&mut self) -> Result { - match self.attach_count { - 0 => { - // do real attach - self.attach_count += 1; - Ok(false) - } - std::u64::MAX => Err(anyhow!("device was attached too many times")), - _ => { - self.attach_count += 1; - Ok(true) - } - } + do_increase_count(&mut self.attach_count) } async fn decrease_attach_count(&mut self) -> Result { - match self.attach_count { - 0 => Err(anyhow!("detaching a device that wasn't attached")), - 1 => { - // do real wrok - self.attach_count -= 1; - Ok(false) - } - _ => { - self.attach_count -= 1; - Ok(true) - } - } + do_decrease_count(&mut self.attach_count) } async fn get_device_info(&self) -> DeviceType { diff --git a/src/runtime-rs/crates/hypervisor/src/device/driver/vhost_user_blk.rs b/src/runtime-rs/crates/hypervisor/src/device/driver/vhost_user_blk.rs index b2a1d90f92..a3335ae32b 100644 --- a/src/runtime-rs/crates/hypervisor/src/device/driver/vhost_user_blk.rs +++ b/src/runtime-rs/crates/hypervisor/src/device/driver/vhost_user_blk.rs @@ -4,12 +4,16 @@ // SPDX-License-Identifier: Apache-2.0 // -use anyhow::{anyhow, Context, Result}; +use anyhow::{Context, Result}; use async_trait::async_trait; use super::VhostUserConfig; use crate::{ - device::{topology::PCIeTopology, Device, DeviceType}, + device::{ + topology::PCIeTopology, + util::{do_decrease_count, do_increase_count}, + Device, DeviceType, + }, Hypervisor as hypervisor, }; @@ -104,32 +108,10 @@ impl Device for VhostUserBlkDevice { } async fn increase_attach_count(&mut self) -> Result { - match self.attach_count { - 0 => { - // do real attach - self.attach_count += 1; - Ok(false) - } - std::u64::MAX => Err(anyhow!("device was attached too many times")), - _ => { - self.attach_count += 1; - Ok(true) - } - } + do_increase_count(&mut self.attach_count) } async fn decrease_attach_count(&mut self) -> Result { - match self.attach_count { - 0 => Err(anyhow!("detaching a device that wasn't attached")), - 1 => { - // do real wrok - self.attach_count -= 1; - Ok(false) - } - _ => { - self.attach_count -= 1; - Ok(true) - } - } + do_decrease_count(&mut self.attach_count) } } diff --git a/src/runtime-rs/crates/hypervisor/src/device/driver/virtio_blk.rs b/src/runtime-rs/crates/hypervisor/src/device/driver/virtio_blk.rs index fdf0f7ea2d..56d3fbd639 100644 --- a/src/runtime-rs/crates/hypervisor/src/device/driver/virtio_blk.rs +++ b/src/runtime-rs/crates/hypervisor/src/device/driver/virtio_blk.rs @@ -6,10 +6,12 @@ use crate::device::pci_path::PciPath; use crate::device::topology::PCIeTopology; +use crate::device::util::do_decrease_count; +use crate::device::util::do_increase_count; use crate::device::Device; use crate::device::DeviceType; use crate::Hypervisor as hypervisor; -use anyhow::{anyhow, Context, Result}; +use anyhow::{Context, Result}; use async_trait::async_trait; /// VIRTIO_BLOCK_PCI indicates block driver is virtio-pci based @@ -135,32 +137,10 @@ impl Device for BlockDevice { } async fn increase_attach_count(&mut self) -> Result { - match self.attach_count { - 0 => { - // do real attach - self.attach_count += 1; - Ok(false) - } - std::u64::MAX => Err(anyhow!("device was attached too many times")), - _ => { - self.attach_count += 1; - Ok(true) - } - } + do_increase_count(&mut self.attach_count) } async fn decrease_attach_count(&mut self) -> Result { - match self.attach_count { - 0 => Err(anyhow!("detaching a device that wasn't attached")), - 1 => { - // do real wrok - self.attach_count -= 1; - Ok(false) - } - _ => { - self.attach_count -= 1; - Ok(true) - } - } + do_decrease_count(&mut self.attach_count) } } diff --git a/src/runtime-rs/crates/hypervisor/src/device/util.rs b/src/runtime-rs/crates/hypervisor/src/device/util.rs index 5d999d8f6c..d719fb9e31 100644 --- a/src/runtime-rs/crates/hypervisor/src/device/util.rs +++ b/src/runtime-rs/crates/hypervisor/src/device/util.rs @@ -69,9 +69,43 @@ pub(crate) fn get_virt_drive_name(mut index: i32) -> Result { Ok(String::from(PREFIX) + std::str::from_utf8(&disk_letters)?) } +// Using the return value of do_increase_count to indicate whether a device has been inserted into the guest. +// Specially, Increment the reference count by 1, then check the incremented ref_count: +// If the incremented reference count is not equal to 1, the device has been inserted into the guest. Return true. +// If the reference count is equal to 1, the device has not been inserted into the guest. Return false. +pub fn do_increase_count(ref_count: &mut u64) -> Result { + // ref_count = 0: Device is new and not attached. + // ref_count > 0: Device has been attempted to be attached many times. + *ref_count = (*ref_count) + .checked_add(1) + .ok_or("device reference count overflow") + .map_err(|e| anyhow!(e))?; + + Ok((*ref_count) != 1) +} + +// The return value of do_decrease_count can be used to indicate whether the device is still in use. +// Specifically, the reference count can be decremented by 1 first, then check the decremented ref_count: +// If the decremented reference count is not equal to 0, it indicates that the device is still in use by +// the guest and cannot be detached. Return true. +// If the reference count is equal to 0, it indicates that the device will not be used and can be unplugged +// from the guest. Return false. +pub fn do_decrease_count(ref_count: &mut u64) -> Result { + // ref_count = 0: Device not inserted (cannot decrease further). + // ref_count = 1: Device is attached to the Guest. Decrement ref_count and notify Device Manager of detachment. + // ref_count > 1: Device remains attached to the Guest. Simply decrement ref_count and notify Device Manager. + *ref_count = (*ref_count) + .checked_sub(1) + .ok_or("The device is not attached") + .map_err(|e| anyhow!(e))?; + + Ok((*ref_count) != 0) +} + #[cfg(test)] mod tests { use crate::device::util::get_virt_drive_name; + use crate::device::util::{do_decrease_count, do_increase_count}; #[actix_rt::test] async fn test_get_virt_drive_name() { @@ -88,4 +122,61 @@ mod tests { assert_eq!(&out, output); } } + + #[test] + fn test_do_increase_count() { + // First, ref_count is 0 + let ref_count_0 = &mut 0_u64; + assert!(do_decrease_count(ref_count_0).is_err()); + + assert!(!do_increase_count(ref_count_0).unwrap()); + assert!(!do_decrease_count(ref_count_0).unwrap()); + + // Second, ref_count > 0 + let ref_count_3 = &mut 3_u64; + assert!(do_increase_count(ref_count_3).unwrap()); + assert!(do_decrease_count(ref_count_3).unwrap()); + + // Third, ref_count is MAX + let mut max_count = std::u64::MAX; + let ref_count_max: &mut u64 = &mut max_count; + assert!(do_increase_count(ref_count_max).is_err()); + } + + #[test] + fn test_data_reference_count() { + #[derive(Default)] + struct TestData { + ref_cnt: u64, + } + + impl TestData { + fn attach(&mut self) -> bool { + do_increase_count(&mut self.ref_cnt).unwrap() + } + + fn detach(&mut self) -> bool { + do_decrease_count(&mut self.ref_cnt).unwrap() + } + } + + let testd = &mut TestData { ref_cnt: 0_u64 }; + + // First, ref_cnt is 0 + assert!(!testd.attach()); + assert_eq!(testd.ref_cnt, 1_u64); + // Second, ref_cnt > 0 + assert!(testd.attach()); + assert_eq!(testd.ref_cnt, 2_u64); + assert!(testd.attach()); + assert_eq!(testd.ref_cnt, 3_u64); + + let testd2 = &mut TestData { ref_cnt: 2_u64 }; + + assert!(testd2.detach()); + assert_eq!(testd2.ref_cnt, 1_u64); + + assert!(!testd2.detach()); + assert_eq!(testd2.ref_cnt, 0_u64); + } }