diff --git a/src/runtime-rs/crates/hypervisor/ch-config/src/ch_api.rs b/src/runtime-rs/crates/hypervisor/ch-config/src/ch_api.rs index 5a5ab90cc5..2b5b34c750 100644 --- a/src/runtime-rs/crates/hypervisor/ch-config/src/ch_api.rs +++ b/src/runtime-rs/crates/hypervisor/ch-config/src/ch_api.rs @@ -2,8 +2,10 @@ // // SPDX-License-Identifier: Apache-2.0 -use crate::{DeviceConfig, DiskConfig, FsConfig, NetConfig, VmConfig, VsockConfig}; -use anyhow::{anyhow, Result}; +use crate::{ + DeviceConfig, DiskConfig, FsConfig, NetConfig, VmConfig, VmInfo, VmResize, VsockConfig, +}; +use anyhow::{anyhow, Context, Result}; use api_client::simple_api_full_command_and_response; use serde::{Deserialize, Serialize}; @@ -190,3 +192,34 @@ pub async fn cloud_hypervisor_vm_vsock_add( }) .await? } + +pub async fn cloud_hypervisor_vm_info(mut socket: UnixStream) -> Result { + let vm_info = task::spawn_blocking(move || -> Result> { + let response = simple_api_full_command_and_response(&mut socket, "GET", "vm.info", None) + .map_err(|e| anyhow!(format!("failed to run get vminfo with err: {:?}", e)))?; + + Ok(response) + }) + .await??; + + let vm_info = vm_info.ok_or(anyhow!("failed to get vminfo"))?; + serde_json::from_str(&vm_info).with_context(|| format!("failed to serde {}", vm_info)) +} + +pub async fn cloud_hypervisor_vm_resize( + mut socket: UnixStream, + vmresize: VmResize, +) -> Result> { + task::spawn_blocking(move || -> Result> { + let response = simple_api_full_command_and_response( + &mut socket, + "PUT", + "vm.resize", + Some(&serde_json::to_string(&vmresize)?), + ) + .map_err(|e| anyhow!(e))?; + + Ok(response) + }) + .await? +} diff --git a/src/runtime-rs/crates/hypervisor/ch-config/src/lib.rs b/src/runtime-rs/crates/hypervisor/ch-config/src/lib.rs index 5c90f9b8c3..0e3e4122fd 100644 --- a/src/runtime-rs/crates/hypervisor/ch-config/src/lib.rs +++ b/src/runtime-rs/crates/hypervisor/ch-config/src/lib.rs @@ -500,6 +500,32 @@ pub struct NamedHypervisorConfig { pub guest_protection_to_use: GuestProtection, } +#[derive(Clone, Debug, Deserialize, Serialize, Default)] +pub struct VmResize { + pub desired_vcpus: Option, + pub desired_ram: Option, + pub desired_balloon: Option, +} + +/// VmInfo : Virtual Machine information +#[derive(Clone, Default, Debug, Serialize, Deserialize)] +pub struct VmInfo { + pub config: VmConfig, + pub state: State, + #[serde(skip_serializing_if = "Option::is_none")] + pub memory_actual_size: Option, +} + +#[derive(Clone, Copy, Debug, Default, Serialize, Deserialize)] +#[serde(rename_all = "PascalCase")] +pub enum State { + #[default] + Created, + Running, + Shutdown, + Paused, +} + // Returns true if the enabled guest protection is Intel TDX. pub fn guest_protection_is_tdx(guest_protection_to_use: GuestProtection) -> bool { matches!(guest_protection_to_use, GuestProtection::Tdx) diff --git a/src/runtime-rs/crates/hypervisor/ch-config/src/net_util.rs b/src/runtime-rs/crates/hypervisor/ch-config/src/net_util.rs index 00a0794628..723bc576f4 100644 --- a/src/runtime-rs/crates/hypervisor/ch-config/src/net_util.rs +++ b/src/runtime-rs/crates/hypervisor/ch-config/src/net_util.rs @@ -2,16 +2,24 @@ // // SPDX-License-Identifier: Apache-2.0 -use serde::{Deserialize, Serialize, Serializer}; +use anyhow::{anyhow, Result}; +use serde::de::{self, Visitor}; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; use std::fmt; pub const MAC_ADDR_LEN: usize = 6; -#[derive(Clone, Copy, Debug, PartialEq, Eq, Deserialize, Default)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)] pub struct MacAddr { pub bytes: [u8; MAC_ADDR_LEN], } +impl MacAddr { + pub fn new(addr: [u8; MAC_ADDR_LEN]) -> MacAddr { + MacAddr { bytes: addr } + } +} + // Note: Implements ToString automatically. impl fmt::Display for MacAddr { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { @@ -30,3 +38,186 @@ impl Serialize for MacAddr { self.to_string().serialize(serializer) } } + +// Helper function: parse MAC address string to byte array +fn parse_mac_address_str(s: &str) -> Result<[u8; MAC_ADDR_LEN]> { + let parts: Vec<&str> = s.split(':').collect(); + if parts.len() != MAC_ADDR_LEN { + return Err(anyhow!( + "Invalid MAC address format: expected {} parts separated by ':', got {}", + MAC_ADDR_LEN, + parts.len() + )); + } + + let mut bytes = [0u8; MAC_ADDR_LEN]; + for (i, part) in parts.iter().enumerate() { + if part.len() != 2 { + return Err(anyhow!( + "Invalid MAC address part '{}': expected 2 hex digits", + part + )); + } + bytes[i] = u8::from_str_radix(part, 16) + .map_err(|e| anyhow!("Invalid hex digit in '{}': {}", part, e))?; + } + Ok(bytes) +} + +// Customize Deserialize implementation, because the system's own one does not work. +impl<'de> Deserialize<'de> for MacAddr { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + // We expect the deserializer to provide a string, so we use deserialize_string + deserializer.deserialize_string(MacAddrVisitor) + } +} + +// MacAddrVisitor will handle the actual conversion from string to MacAddr +struct MacAddrVisitor; + +impl Visitor<'_> for MacAddrVisitor { + type Value = MacAddr; + + // When deserialization fails, Serde will call this method to get a description of the expected format + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a MAC address string in format \"XX:XX:XX:XX:XX:XX\"") + } + + // Called when the deserializer provides a string slice + fn visit_str(self, v: &str) -> Result + where + E: de::Error, + { + // Use our auxiliary function to parse the string and convert it to MacAddr + parse_mac_address_str(v) + .map(MacAddr::new) // If the parsing is successful, create a MacAddr with a byte array + .map_err(de::Error::custom) // If parsing fails, convert the error to Serde's error type + } + + // Called when the deserializer provides a String (usually delegated to visit_str) + fn visit_string(self, v: String) -> Result + where + E: de::Error, + { + self.visit_str(&v) + } +} + +#[cfg(test)] +mod tests { + use super::*; // Import parent module items, including MAC_ADDR_LEN and parse_mac_address_str + + #[test] + fn test_parse_mac_address_str_valid() { + // Test a standard MAC address + let mac_str = "00:11:22:33:44:55"; + let expected_bytes = [0x00, 0x11, 0x22, 0x33, 0x44, 0x55]; + assert_eq!(parse_mac_address_str(mac_str).unwrap(), expected_bytes); + + // Test a MAC address with uppercase letters + let mac_str_upper = "AA:BB:CC:DD:EE:FF"; + let expected_bytes_upper = [0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF]; + assert_eq!( + parse_mac_address_str(mac_str_upper).unwrap(), + expected_bytes_upper + ); + + // Test a mixed-case MAC address + let mac_str_mixed = "aA:Bb:Cc:Dd:Ee:Ff"; + let expected_bytes_mixed = [0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF]; + assert_eq!( + parse_mac_address_str(mac_str_mixed).unwrap(), + expected_bytes_mixed + ); + + // Test an all-zero MAC address + let mac_str_zero = "00:00:00:00:00:00"; + let expected_bytes_zero = [0x00, 0x00, 0x00, 0x00, 0x00, 0x00]; + assert_eq!( + parse_mac_address_str(mac_str_zero).unwrap(), + expected_bytes_zero + ); + } + + #[test] + fn test_parse_mac_address_str_invalid_length() { + // MAC address with too few segments + let mac_str_short = "00:11:22:33:44"; + let err = parse_mac_address_str(mac_str_short).unwrap_err(); + assert!(err + .to_string() + .contains("Invalid MAC address format: expected 6 parts separated by ':', got 5")); + + // MAC address with too many segments + let mac_str_long = "00:11:22:33:44:55:66"; + let err = parse_mac_address_str(mac_str_long).unwrap_err(); + assert!(err + .to_string() + .contains("Invalid MAC address format: expected 6 parts separated by ':', got 7")); + + // Empty string + let mac_str_empty = ""; + let err = parse_mac_address_str(mac_str_empty).unwrap_err(); + // Note: split(':') on an empty string returns a Vec containing [""] if delimiter is not found, + // so its length will be 1. + assert!(err + .to_string() + .contains("Invalid MAC address format: expected 6 parts separated by ':', got 1")); + } + + #[test] + fn test_parse_mac_address_str_invalid_part_length() { + // Part with insufficient length (1 digit) + let mac_str_part_short = "0:11:22:33:44:55"; + let err = parse_mac_address_str(mac_str_part_short).unwrap_err(); + assert!(err + .to_string() + .contains("Invalid MAC address part '0': expected 2 hex digits")); + + // Part with excessive length (3 digits) + let mac_str_part_long = "000:11:22:33:44:55"; + let err = parse_mac_address_str(mac_str_part_long).unwrap_err(); + assert!(err + .to_string() + .contains("Invalid MAC address part '000': expected 2 hex digits")); + } + + #[test] + fn test_parse_mac_address_str_invalid_chars() { + // Contains non-hexadecimal character (letter G) + let mac_str_invalid_char_g = "00:11:22:33:44:GG"; + let err = parse_mac_address_str(mac_str_invalid_char_g).unwrap_err(); + assert!(err.to_string().contains("Invalid hex digit in 'GG'")); + + // Contains non-hexadecimal character (symbol @) + let mac_str_invalid_char_at = "00:11:22:33:44:@5"; + let err = parse_mac_address_str(mac_str_invalid_char_at).unwrap_err(); + assert!(err.to_string().contains("Invalid hex digit in '@5'")); + + // Contains whitespace character + let mac_str_with_space = "00:11:22:33:44: 5"; + let err = parse_mac_address_str(mac_str_with_space).unwrap_err(); + assert!(err.to_string().contains("Invalid hex digit in ' 5'")); + } + + #[test] + fn test_parse_mac_address_str_malformed_string() { + // String with only colons + let mac_str_colon_only = ":::::"; + let err = parse_mac_address_str(mac_str_colon_only).unwrap_err(); + // Each empty part will trigger the "expected 2 hex digits" error + assert!(err + .to_string() + .contains("Invalid MAC address part '': expected 2 hex digits")); + + // String with trailing colon + let mac_str_trailing_colon = "00:11:22:33:44:55:"; + let err = parse_mac_address_str(mac_str_trailing_colon).unwrap_err(); + assert!(err + .to_string() + .contains("Invalid MAC address format: expected 6 parts separated by ':', got 7")); + } +} diff --git a/src/runtime-rs/crates/hypervisor/src/ch/inner.rs b/src/runtime-rs/crates/hypervisor/src/ch/inner.rs index 40f3c5d330..c7f31bd0d8 100644 --- a/src/runtime-rs/crates/hypervisor/src/ch/inner.rs +++ b/src/runtime-rs/crates/hypervisor/src/ch/inner.rs @@ -74,8 +74,8 @@ pub struct CloudHypervisorInner { // None. pub(crate) ch_features: Option>, - /// Size of memory block of guest OS in MB (currently unused) - pub(crate) _guest_memory_block_size_mb: u32, + /// Size of memory block of guest OS in MB + pub(crate) guest_memory_block_size_mb: u32, pub(crate) exit_notify: Option>, } @@ -117,7 +117,7 @@ impl CloudHypervisorInner { tasks: None, guest_protection_to_use: GuestProtection::NoProtection, ch_features: None, - _guest_memory_block_size_mb: 0, + guest_memory_block_size_mb: 0, exit_notify, } diff --git a/src/runtime-rs/crates/hypervisor/src/ch/inner_hypervisor.rs b/src/runtime-rs/crates/hypervisor/src/ch/inner_hypervisor.rs index 353b2f7978..6f114718dc 100644 --- a/src/runtime-rs/crates/hypervisor/src/ch/inner_hypervisor.rs +++ b/src/runtime-rs/crates/hypervisor/src/ch/inner_hypervisor.rs @@ -7,15 +7,18 @@ use super::inner::CloudHypervisorInner; use crate::ch::utils::get_api_socket_path; use crate::ch::utils::get_vsock_path; use crate::kernel_param::KernelParams; -use crate::utils::{get_jailer_root, get_sandbox_path}; +use crate::utils::{bytes_to_megs, get_jailer_root, get_sandbox_path, megs_to_bytes}; use crate::MemoryConfig; use crate::VM_ROOTFS_DRIVER_BLK; use crate::VM_ROOTFS_DRIVER_PMEM; use crate::{VcpuThreadIds, VmmState}; use anyhow::{anyhow, Context, Result}; -use ch_config::ch_api::{ - cloud_hypervisor_vm_create, cloud_hypervisor_vm_start, cloud_hypervisor_vmm_ping, - cloud_hypervisor_vmm_shutdown, +use ch_config::{ + ch_api::{ + cloud_hypervisor_vm_create, cloud_hypervisor_vm_info, cloud_hypervisor_vm_resize, + cloud_hypervisor_vm_start, cloud_hypervisor_vmm_ping, cloud_hypervisor_vmm_shutdown, + }, + VmResize, }; use ch_config::{guest_protection_is_tdx, NamedHypervisorConfig, VmConfig}; use core::future::poll_fn; @@ -678,8 +681,50 @@ impl CloudHypervisorInner { Ok(()) } - pub(crate) async fn resize_vcpu(&self, old_vcpu: u32, new_vcpu: u32) -> Result<(u32, u32)> { - Ok((old_vcpu, new_vcpu)) + pub(crate) async fn resize_vcpu( + &self, + old_vcpus: u32, + mut new_vcpus: u32, + ) -> Result<(u32, u32)> { + info!( + sl!(), + "cloud hypervisor resize_vcpu(): {} -> {}", old_vcpus, new_vcpus + ); + + if new_vcpus == 0 { + return Err(anyhow!("resize to 0 vcpus requested")); + } + + if new_vcpus > self.config.cpu_info.default_maxvcpus { + warn!( + sl!(), + "Cannot allocate more vcpus than the max allowed number of vcpus. The maximum allowed amount of vcpus will be used instead."); + new_vcpus = self.config.cpu_info.default_maxvcpus; + } + + if new_vcpus == old_vcpus { + return Ok((old_vcpus, new_vcpus)); + } + + let socket = self + .api_socket + .as_ref() + .ok_or("missing socket") + .map_err(|e| anyhow!(e))?; + + let vmresize = VmResize { + desired_vcpus: Some(new_vcpus as u8), + ..Default::default() + }; + + cloud_hypervisor_vm_resize( + socket.try_clone().context("failed to clone socket")?, + vmresize, + ) + .await + .context("resize vcpus")?; + + Ok((old_vcpus, new_vcpus)) } pub(crate) async fn get_pids(&self) -> Result> { @@ -748,17 +793,99 @@ impl CloudHypervisorInner { } pub(crate) fn set_guest_memory_block_size(&mut self, size: u32) { - self._guest_memory_block_size_mb = size; + self.guest_memory_block_size_mb = bytes_to_megs(size as u64); } pub(crate) fn guest_memory_block_size_mb(&self) -> u32 { - self._guest_memory_block_size_mb + self.guest_memory_block_size_mb } - pub(crate) fn resize_memory(&self, _new_mem_mb: u32) -> Result<(u32, MemoryConfig)> { - warn!(sl!(), "CH memory resize not implemented - see https://github.com/kata-containers/kata-containers/issues/8801"); + pub(crate) async fn resize_memory(&self, new_mem_mb: u32) -> Result<(u32, MemoryConfig)> { + let socket = self + .api_socket + .as_ref() + .ok_or("missing socket") + .map_err(|e| anyhow!(e))?; - Ok((0, MemoryConfig::default())) + let vminfo = + cloud_hypervisor_vm_info(socket.try_clone().context("failed to clone socket")?) + .await + .context("get vminfo")?; + + let current_mem_size = vminfo.config.memory.size; + let new_total_mem = megs_to_bytes(new_mem_mb); + + info!( + sl!(), + "cloud-hypervisor::resize_memory(): asked to resize memory to {} MB, current memory is {} MB", new_mem_mb, bytes_to_megs(current_mem_size) + ); + + // Early Check to verify if boot memory is the same as requested + if current_mem_size == new_total_mem { + info!(sl!(), "VM alreay has requested memory"); + return Ok((new_mem_mb, MemoryConfig::default())); + } + + if current_mem_size > new_total_mem { + info!(sl!(), "Remove memory is not supported, nothing to do"); + return Ok((new_mem_mb, MemoryConfig::default())); + } + + let guest_mem_block_size = megs_to_bytes(self.guest_memory_block_size_mb); + + let mut new_hotplugged_mem = new_total_mem - current_mem_size; + + info!( + sl!(), + "new hotplugged mem before alignment: {} B ({} MB), guest_mem_block_size: {} MB", + new_hotplugged_mem, + bytes_to_megs(new_hotplugged_mem), + bytes_to_megs(guest_mem_block_size) + ); + + let is_unaligned = new_hotplugged_mem % guest_mem_block_size != 0; + if is_unaligned { + new_hotplugged_mem = ch_config::convert::checked_next_multiple_of( + new_hotplugged_mem, + guest_mem_block_size, + ) + .ok_or(anyhow!(format!( + "alignment of {} B to the block size of {} B failed", + new_hotplugged_mem, guest_mem_block_size + )))? + } + + let new_total_mem_aligned = new_hotplugged_mem + current_mem_size; + + let max_total_mem = megs_to_bytes(self.config.memory_info.default_maxmemory); + if new_total_mem_aligned > max_total_mem { + return Err(anyhow!( + "requested memory ({} MB) is greater than maximum allowed ({} MB)", + bytes_to_megs(new_total_mem_aligned), + self.config.memory_info.default_maxmemory + )); + } + + info!( + sl!(), + "hotplugged mem from {} MB to {} MB)", + bytes_to_megs(current_mem_size), + bytes_to_megs(new_total_mem_aligned) + ); + + let vmresize = VmResize { + desired_ram: Some(new_total_mem_aligned), + ..Default::default() + }; + + cloud_hypervisor_vm_resize( + socket.try_clone().context("failed to clone socket")?, + vmresize, + ) + .await + .context("resize memory")?; + + Ok((new_mem_mb, MemoryConfig::default())) } } diff --git a/src/runtime-rs/crates/hypervisor/src/ch/mod.rs b/src/runtime-rs/crates/hypervisor/src/ch/mod.rs index 31056c9052..9381569af8 100644 --- a/src/runtime-rs/crates/hypervisor/src/ch/mod.rs +++ b/src/runtime-rs/crates/hypervisor/src/ch/mod.rs @@ -206,7 +206,7 @@ impl Hypervisor for CloudHypervisor { async fn resize_memory(&self, new_mem_mb: u32) -> Result<(u32, MemoryConfig)> { let inner = self.inner.read().await; - inner.resize_memory(new_mem_mb) + inner.resize_memory(new_mem_mb).await } async fn get_passfd_listener_addr(&self) -> Result<(String, u32)> { diff --git a/src/runtime-rs/crates/hypervisor/src/qemu/inner.rs b/src/runtime-rs/crates/hypervisor/src/qemu/inner.rs index 55e7b9d7b7..b781de8177 100644 --- a/src/runtime-rs/crates/hypervisor/src/qemu/inner.rs +++ b/src/runtime-rs/crates/hypervisor/src/qemu/inner.rs @@ -7,10 +7,12 @@ use super::cmdline_generator::{get_network_device, QemuCmdLine, QMP_SOCKET_FILE} use super::qmp::Qmp; use crate::device::topology::PCIePort; use crate::{ - device::driver::ProtectionDeviceConfig, hypervisor_persist::HypervisorState, - utils::enter_netns, HypervisorConfig, MemoryConfig, VcpuThreadIds, VsockDevice, - HYPERVISOR_QEMU, + device::driver::ProtectionDeviceConfig, hypervisor_persist::HypervisorState, HypervisorConfig, + MemoryConfig, VcpuThreadIds, VsockDevice, HYPERVISOR_QEMU, }; + +use crate::utils::{bytes_to_megs, enter_netns, megs_to_bytes}; + use anyhow::{anyhow, Context, Result}; use async_trait::async_trait; use kata_sys_util::netns::NetnsGuard; @@ -456,15 +458,6 @@ impl QemuInner { "QemuInner::resize_memory(): asked to resize memory to {} MB", new_total_mem_mb ); - // stick to the apparent de facto convention and represent megabytes - // as u32 and bytes as u64 - fn bytes_to_megs(bytes: u64) -> u32 { - (bytes / (1 << 20)) as u32 - } - fn megs_to_bytes(bytes: u32) -> u64 { - bytes as u64 * (1 << 20) - } - let qmp = match self.qmp { Some(ref mut qmp) => qmp, None => { diff --git a/src/runtime-rs/crates/hypervisor/src/utils.rs b/src/runtime-rs/crates/hypervisor/src/utils.rs index d1fe47a5f1..2cdd91d5f9 100644 --- a/src/runtime-rs/crates/hypervisor/src/utils.rs +++ b/src/runtime-rs/crates/hypervisor/src/utils.rs @@ -192,6 +192,14 @@ impl std::fmt::Display for SocketAddress { } } +pub fn bytes_to_megs(bytes: u64) -> u32 { + (bytes / (1 << 20)) as u32 +} + +pub fn megs_to_bytes(bytes: u32) -> u64 { + bytes as u64 * (1 << 20) +} + #[cfg(test)] mod tests { use super::create_fds;