diff --git a/src/runtime-rs/crates/hypervisor/src/device/driver/vfio.rs b/src/runtime-rs/crates/hypervisor/src/device/driver/vfio.rs index 9e329186a7..27e69d9dde 100644 --- a/src/runtime-rs/crates/hypervisor/src/device/driver/vfio.rs +++ b/src/runtime-rs/crates/hypervisor/src/device/driver/vfio.rs @@ -36,6 +36,10 @@ pub const DRIVER_VFIO_PCI_TYPE: &str = "vfio-pci"; pub const DRIVER_VFIO_AP_TYPE: &str = "vfio-ap"; pub const MAX_DEV_ID_SIZE: usize = 31; +/// PCI class bitmasks for devices that must be ignored when enumerating an IOMMU group. +/// Host Bridge: 0x0600, Audio device: 0x0403. +const IOMMU_IGNORE: &[u64] = &[0x0600, 0x403]; + const VFIO_PCI_DRIVER_NEW_ID: &str = "/sys/bus/pci/drivers/vfio-pci/new_id"; const VFIO_PCI_DRIVER_UNBIND: &str = "/sys/bus/pci/drivers/vfio-pci/unbind"; const SYS_CLASS_IOMMU: &str = "/sys/class/iommu"; @@ -437,10 +441,9 @@ impl VfioDevice { } } - // filter Host or PCI Bridges that are in the same IOMMU group as the - // passed-through devices. One CANNOT pass-through a PCI bridge or Host - // bridge. Class 0x0604 is PCI bridge, 0x0600 is Host bridge - fn filter_bridge_device(&self, bdf: &str, bitmask: u64) -> Option { + // filter Host or PCI Bridges and audio devices that are in the same IOMMU + // group as the passed-through devices. + fn filter_bridge_device(&self, bdf: &str, bitmasks: &[u64]) -> Option { let device_class = match get_device_property(bdf, "class") { Ok(dev_class) => dev_class, Err(_) => "".to_string(), @@ -454,11 +457,12 @@ impl VfioDevice { Ok(cid_u32) => { // class code is 16 bits, remove the two trailing zeros let class_code = u64::from(cid_u32) >> 8; - if class_code & bitmask == bitmask { - Some(class_code) - } else { - None + for &bitmask in bitmasks { + if class_code & bitmask == bitmask { + return Some(class_code); + } } + None } _ => None, } @@ -491,8 +495,8 @@ impl VfioDevice { // pass all devices in iommu group, and use index to identify device. for (index, device) in iommu_devices.iter().enumerate() { - // filter host or PCI bridge - if self.filter_bridge_device(device, 0x0600).is_some() { + // filter host/PCI bridge, audio, etc. + if self.filter_bridge_device(device, IOMMU_IGNORE).is_some() { continue; } diff --git a/src/runtime-rs/crates/hypervisor/src/device/driver/vfio_device/core.rs b/src/runtime-rs/crates/hypervisor/src/device/driver/vfio_device/core.rs index 6bfa41a1ba..88aa8b6720 100644 --- a/src/runtime-rs/crates/hypervisor/src/device/driver/vfio_device/core.rs +++ b/src/runtime-rs/crates/hypervisor/src/device/driver/vfio_device/core.rs @@ -340,8 +340,8 @@ fn validate_group_basic(devices: &[DeviceInfo]) -> bool { if let DeviceAddress::Pci(bdf) = &device.addr { // filter host or PCI bridge let bdf_str = bdf.to_string(); - // Filter out host or PCI bridges (cannot be passed through) - if filter_bridge_device(&bdf_str, 0x0600).is_some() { + // Filter out devices that cannot be passed through (bridges, audio, etc.) + if filter_bridge_device(&bdf_str, IOMMU_IGNORE).is_some() { continue; } } @@ -362,9 +362,13 @@ fn get_device_property(device_bdf: &str, property: &str) -> Result { Ok(cfg_path.trim().to_string()) } -/// Filters for Host or PCI bridges within an IOMMU group. -/// PCI Bridge: Class 0x0604, Host Bridge: Class 0x0600. -fn filter_bridge_device(bdf: &str, bitmask: u64) -> Option { +/// PCI class bitmasks for devices that must be ignored when enumerating an IOMMU group. +/// Host Bridge: 0x0600, Audio device: 0x0403. +const IOMMU_IGNORE: &[u64] = &[0x0600, 0x403]; + +/// Filters for devices that cannot or should not be passed through within an IOMMU group +/// (Host/PCI bridges, audio controllers that share the GPU's IOMMU group, etc.). +fn filter_bridge_device(bdf: &str, bitmasks: &[u64]) -> Option { let device_class = get_device_property(bdf, "class").unwrap_or_default(); if device_class.is_empty() { @@ -375,11 +379,12 @@ fn filter_bridge_device(bdf: &str, bitmask: u64) -> Option { Ok(cid_u32) => { // PCI class code is 24 bits, shift right 8 to get base+sub class let class_code = u64::from(cid_u32) >> 8; - if class_code & bitmask == bitmask { - Some(class_code) - } else { - None + for &bitmask in bitmasks { + if class_code & bitmask == bitmask { + return Some(class_code); + } } + None } _ => None, }