diff --git a/src/runtime-rs/crates/hypervisor/src/device/util.rs b/src/runtime-rs/crates/hypervisor/src/device/util.rs index 5d999d8f6..f3c6964f4 100644 --- a/src/runtime-rs/crates/hypervisor/src/device/util.rs +++ b/src/runtime-rs/crates/hypervisor/src/device/util.rs @@ -69,8 +69,24 @@ pub(crate) fn get_virt_drive_name(mut index: i32) -> Result { Ok(String::from(PREFIX) + std::str::from_utf8(&disk_letters)?) } +// Using the return value of do_increase_count to indicate whether a device has been inserted into the guest. +// Specially, Increment the reference count by 1, then check the incremented ref_count: +// If the incremented reference count is not equal to 1, the device has been inserted into the guest. Return true. +// If the reference count is equal to 1, the device has not been inserted into the guest. Return false. +pub fn do_increase_count(ref_count: &mut u64) -> Result { + // ref_count = 0: Device is new and not attached. + // ref_count > 0: Device has been attempted to be attached many times. + *ref_count = (*ref_count) + .checked_add(1) + .ok_or("device reference count overflow") + .map_err(|e| anyhow!(e))?; + + Ok((*ref_count) != 1) +} + #[cfg(test)] mod tests { + use crate::device::util::do_increase_count; use crate::device::util::get_virt_drive_name; #[actix_rt::test] @@ -88,4 +104,45 @@ mod tests { assert_eq!(&out, output); } } + + #[test] + fn test_do_increase_count() { + // First, ref_count is 0 + let ref_count_0 = &mut 0_u64; + assert!(!do_increase_count(ref_count_0).unwrap()); + + // Second, ref_count > 0 + let ref_count_3 = &mut 3_u64; + assert!(do_increase_count(ref_count_3).unwrap()); + + // Third, ref_count is MAX + let mut max_count = std::u64::MAX; + let ref_count_max: &mut u64 = &mut max_count; + assert!(do_increase_count(ref_count_max).is_err()); + } + + #[test] + fn test_data_reference_count() { + #[derive(Default)] + struct TestData { + ref_cnt: u64, + } + + impl TestData { + fn attach(&mut self) -> bool { + do_increase_count(&mut self.ref_cnt).unwrap() + } + } + + let testd = &mut TestData { ref_cnt: 0_u64 }; + + // First, ref_cnt is 0 + assert!(!testd.attach()); + assert_eq!(testd.ref_cnt, 1_u64); + // Second, ref_cnt > 0 + assert!(testd.attach()); + assert_eq!(testd.ref_cnt, 2_u64); + assert!(testd.attach()); + assert_eq!(testd.ref_cnt, 3_u64); + } }