From a6a399d5bcf985249dbc17b67d4bc186a11dc8c3 Mon Sep 17 00:00:00 2001 From: Qinqi Qu Date: Thu, 20 Apr 2023 17:10:43 +0800 Subject: [PATCH 1/2] dragonball: add vhost-user connection management logic The vhost-user connection management logic will be used by the upcoming features: vhost-user-net, vhost-user-blk and vhost-user-fs. Fixes: #8448 Signed-off-by: Liu Jiang Signed-off-by: Qinqi Qu Signed-off-by: Huang Jianan --- .../src/dbs_virtio_devices/Cargo.toml | 1 + .../src/dbs_virtio_devices/src/lib.rs | 40 +- .../dbs_virtio_devices/src/mmio/mmio_state.rs | 6 +- .../src/dbs_virtio_devices/src/vhost/mod.rs | 18 + .../src/vhost/vhost_kern/net.rs | 4 +- .../src/vhost/vhost_user/connection.rs | 552 ++++++++++++++++++ .../src/vhost/vhost_user/mod.rs | 8 + 7 files changed, 624 insertions(+), 5 deletions(-) create mode 100644 src/dragonball/src/dbs_virtio_devices/src/vhost/vhost_user/connection.rs create mode 100644 src/dragonball/src/dbs_virtio_devices/src/vhost/vhost_user/mod.rs diff --git a/src/dragonball/src/dbs_virtio_devices/Cargo.toml b/src/dragonball/src/dbs_virtio_devices/Cargo.toml index 9299915ad9..eb0912f306 100644 --- a/src/dragonball/src/dbs_virtio_devices/Cargo.toml +++ b/src/dragonball/src/dbs_virtio_devices/Cargo.toml @@ -53,3 +53,4 @@ virtio-mem = ["virtio-mmio"] virtio-balloon = ["virtio-mmio"] vhost = ["virtio-mmio", "vhost-rs/vhost-user-master", "vhost-rs/vhost-kern"] vhost-net = ["vhost", "vhost-rs/vhost-net"] +vhost-user = ["vhost"] \ No newline at end of file diff --git a/src/dragonball/src/dbs_virtio_devices/src/lib.rs b/src/dragonball/src/dbs_virtio_devices/src/lib.rs index 59727b0903..93c4525440 100644 --- a/src/dragonball/src/dbs_virtio_devices/src/lib.rs +++ b/src/dragonball/src/dbs_virtio_devices/src/lib.rs @@ -125,6 +125,32 @@ pub enum ActivateError { InvalidQueueConfig, #[error("IO: {0}.")] IOError(#[from] IOError), + #[error("Virtio error")] + VirtioError(Error), + #[error("Epoll manager error")] + 
EpollMgr(dbs_utils::epoll_manager::Error), + #[cfg(feature = "vhost")] + #[error("Vhost activate error")] + VhostActivate(vhost_rs::Error), +} + +impl std::convert::From for ActivateError { + fn from(error: Error) -> ActivateError { + ActivateError::VirtioError(error) + } +} + +impl std::convert::From for ActivateError { + fn from(error: dbs_utils::epoll_manager::Error) -> ActivateError { + ActivateError::EpollMgr(error) + } +} + +#[cfg(feature = "vhost")] +impl std::convert::From for ActivateError { + fn from(error: vhost_rs::Error) -> ActivateError { + ActivateError::VhostActivate(error) + } } /// Error code for VirtioDevice::read_config()/write_config(). @@ -155,6 +181,9 @@ pub enum Error { /// Guest gave us a descriptor that was too big to use. #[error("descriptor length too big.")] DescriptorLengthTooBig, + /// Error from the epoll event manager + #[error("dbs_utils error: {0:?}.")] + EpollMgr(dbs_utils::epoll_manager::Error), /// Guest gave us a write only descriptor that protocol says to read from. #[error("unexpected write only descriptor.")] UnexpectedWriteOnlyDescriptor, @@ -181,7 +210,7 @@ pub enum Error { VirtioQueueError(#[from] VqError), /// Error from Device activate. #[error("Device activate error: {0}")] - ActivateError(#[from] ActivateError), + ActivateError(#[from] Box), /// Error from Interrupt. 
#[error("Interrupt error: {0}")] InterruptError(IOError), @@ -229,6 +258,15 @@ pub enum Error { #[cfg(feature = "virtio-balloon")] #[error("Virtio-balloon error: {0}")] VirtioBalloonError(#[from] balloon::BalloonError), + + #[cfg(feature = "vhost")] + /// Error from the vhost subsystem + #[error("Vhost error: {0:?}")] + VhostError(vhost_rs::Error), + #[cfg(feature = "vhost")] + /// Error from the vhost user subsystem + #[error("Vhost-user error: {0:?}")] + VhostUserError(vhost_rs::vhost_user::Error), } // Error for tap devices diff --git a/src/dragonball/src/dbs_virtio_devices/src/mmio/mmio_state.rs b/src/dragonball/src/dbs_virtio_devices/src/mmio/mmio_state.rs index 434be51a91..796024879a 100644 --- a/src/dragonball/src/dbs_virtio_devices/src/mmio/mmio_state.rs +++ b/src/dragonball/src/dbs_virtio_devices/src/mmio/mmio_state.rs @@ -124,7 +124,9 @@ where // If the driver incorrectly sets up the queues, the following check will fail and take // the device into an unusable state. if !self.check_queues_valid() { - return Err(Error::ActivateError(ActivateError::InvalidQueueConfig)); + return Err(Error::ActivateError(Box::new( + ActivateError::InvalidQueueConfig, + ))); } self.register_ioevent()?; @@ -138,7 +140,7 @@ where .map(|_| self.device_activated = true) .map_err(|e| { error!("device activate error: {:?}", e); - Error::ActivateError(e) + Error::ActivateError(Box::new(e)) }) } diff --git a/src/dragonball/src/dbs_virtio_devices/src/vhost/mod.rs b/src/dragonball/src/dbs_virtio_devices/src/vhost/mod.rs index 7c2940e699..d60a281aa5 100644 --- a/src/dragonball/src/dbs_virtio_devices/src/vhost/mod.rs +++ b/src/dragonball/src/dbs_virtio_devices/src/vhost/mod.rs @@ -6,3 +6,21 @@ #[cfg(feature = "vhost-net")] pub mod vhost_kern; + +pub use vhost_rs::vhost_user::Error as VhostUserError; +pub use vhost_rs::Error as VhostError; + +#[cfg(feature = "vhost-user")] +pub mod vhost_user; + +impl std::convert::From for super::Error { + fn from(e: VhostError) -> Self { + 
super::Error::VhostError(e) + } +} + +impl std::convert::From for super::Error { + fn from(e: VhostUserError) -> Self { + super::Error::VhostUserError(e) + } +} diff --git a/src/dragonball/src/dbs_virtio_devices/src/vhost/vhost_kern/net.rs b/src/dragonball/src/dbs_virtio_devices/src/vhost/vhost_kern/net.rs index 2b9a379de5..cd65474ece 100644 --- a/src/dragonball/src/dbs_virtio_devices/src/vhost/vhost_kern/net.rs +++ b/src/dragonball/src/dbs_virtio_devices/src/vhost/vhost_kern/net.rs @@ -290,7 +290,7 @@ where "{}: Invalid virtio queue pairs, expected a value greater than 0, but got {}", NET_DRIVER_NAME, self.vq_pairs ); - return Err(VirtioError::ActivateError(ActivateError::InvalidParam)); + return Err(VirtioError::ActivateError(Box::new(ActivateError::InvalidParam))); } if self.handles.len() != self.vq_pairs || self.taps.len() != self.vq_pairs { @@ -299,7 +299,7 @@ where self.handles.len(), self.taps.len(), self.vq_pairs); - return Err(VirtioError::ActivateError(ActivateError::InternalError)); + return Err(VirtioError::ActivateError(Box::new(ActivateError::InternalError))); } for idx in 0..self.vq_pairs { diff --git a/src/dragonball/src/dbs_virtio_devices/src/vhost/vhost_user/connection.rs b/src/dragonball/src/dbs_virtio_devices/src/vhost/vhost_user/connection.rs new file mode 100644 index 0000000000..7eeeef1baf --- /dev/null +++ b/src/dragonball/src/dbs_virtio_devices/src/vhost/vhost_user/connection.rs @@ -0,0 +1,552 @@ +// Copyright (C) 2019-2023 Alibaba Cloud. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Helper utilities for vhost-user communication channel. 
+ +use std::ops::Deref; +use std::os::unix::io::{AsRawFd, RawFd}; + +use dbs_utils::epoll_manager::{EventOps, EventSet, Events}; +use log::*; +use vhost_rs::vhost_user::message::{VhostUserProtocolFeatures, VhostUserVringAddrFlags}; +use vhost_rs::vhost_user::{ + Error as VhostUserError, Listener as VhostUserListener, Master, VhostUserMaster, +}; +use vhost_rs::{VhostBackend, VhostUserMemoryRegionInfo, VringConfigData}; +use virtio_queue::QueueT; +use vm_memory::{ + Address, GuestAddress, GuestAddressSpace, GuestMemory, GuestMemoryRegion, MemoryRegionAddress, +}; +use vmm_sys_util::eventfd::EventFd; + +use super::super::super::device::VirtioDeviceConfig; +use super::super::super::{Error as VirtioError, Result as VirtioResult}; +use super::VhostError; + +enum EndpointProtocolFlags { + ProtocolMq = 1, +} + +pub(super) struct Listener { + listener: VhostUserListener, + /// Slot to register epoll event for the underlying socket. + slot: u32, + name: String, + path: String, +} + +impl Listener { + pub fn new(name: String, path: String, force: bool, slot: u32) -> VirtioResult { + info!("vhost-user: create listener at {} for {}", path, name); + Ok(Listener { + listener: VhostUserListener::new(&path, force)?, + slot, + name, + path, + }) + } + + // Wait for an incoming connection until success. + pub fn accept(&self) -> VirtioResult<(Master, u64)> { + loop { + match self.try_accept() { + Ok(Some((master, feature))) => return Ok((master, feature)), + Ok(None) => continue, + Err(e) => return Err(e), + } + } + } + + pub fn try_accept(&self) -> VirtioResult> { + let sock = match self.listener.accept() { + Ok(Some(conn)) => conn, + Ok(None) => return Ok(None), + Err(e) => return Err(e.into()), + }; + + let mut master = Master::from_stream(sock, 1); + info!("{}: try to get virtio features from slave.", self.name); + match Endpoint::initialize(&mut master) { + Ok(Some(features)) => Ok(Some((master, features))), + // The new connection has been closed, try again. 
+ Ok(None) => { + warn!( + "{}: new connection get closed during initialization, waiting for another one.", + self.name + ); + Ok(None) + } + // Unrecoverable error happened + Err(e) => { + error!("{}: failed to get virtio features, {}", self.name, e); + Err(e) + } + } + } + + /// Register the underlying listener to be monitored for incoming connection. + pub fn register_epoll_event(&self, ops: &mut EventOps) -> VirtioResult<()> { + info!("{}: monitor incoming connect at {}", self.name, self.path); + // Switch to nonblocking mode. + self.listener.set_nonblocking(true)?; + let event = Events::with_data(&self.listener, self.slot, EventSet::IN); + ops.add(event).map_err(VirtioError::EpollMgr) + } +} + +/// Struct to pass info to vhost user backend +#[derive(Clone)] +pub struct BackendInfo { + /// -1 means to tell backend to destroy corresponding + /// device, while others means construct it + fd: i32, + /// cluster id of device, must set + cluster_id: u32, + /// device id of device, must set + device_id: u64, + /// device config file path + filename: [u8; 128], +} + +/// Struct to pass function parameters to methods of Endpoint. 
+pub(super) struct EndpointParam<'a, AS: GuestAddressSpace, Q: QueueT, R: GuestMemoryRegion> { + pub virtio_config: &'a VirtioDeviceConfig, + pub intr_evts: Vec<&'a EventFd>, + pub queue_sizes: &'a [u16], + pub features: u64, + pub protocol_flag: u16, + pub dev_protocol_features: VhostUserProtocolFeatures, + pub reconnect: bool, + pub backend: Option, + pub init_queues: u32, + pub slave_req_fd: Option, +} + +impl<'a, AS: GuestAddressSpace, Q: QueueT, R: GuestMemoryRegion> EndpointParam<'a, AS, Q, R> { + fn get_host_address(&self, addr: GuestAddress, mem: &AS::M) -> VirtioResult<*mut u8> { + mem.get_host_address(addr) + .map_err(|_| VirtioError::InvalidGuestAddress(addr)) + } + + /// set protocol multi-queue bit + pub fn set_protocol_mq(&mut self) { + self.protocol_flag |= EndpointProtocolFlags::ProtocolMq as u16; + } + + /// check if multi-queue bit is set + pub fn has_protocol_mq(&self) -> bool { + (self.protocol_flag & (EndpointProtocolFlags::ProtocolMq as u16)) != 0 + } +} + +/// Communication channel from the master to the slave. +/// +/// It encapsulates a low-level vhost-user master side communication endpoint, and provides +/// connection initialization, monitoring and reconnect functionalities for vhost-user devices. +/// +/// Caller needs to ensure mutual exclusive access to the object. +pub(super) struct Endpoint { + /// Underlying vhost-user communication endpoint. + conn: Option, + old: Option, + /// Token to register epoll event for the underlying socket. + slot: u32, + /// Identifier string for logs. + name: String, +} + +impl Endpoint { + pub fn new(master: Master, slot: u32, name: String) -> Self { + Endpoint { + conn: Some(master), + old: None, + slot, + name, + } + } + + /// First state of the connection negotiation between the master and the slave. + /// + /// If Ok(None) is returned, the underlying communication channel gets broken and the caller may + /// try to recreate the communication channel and negotiate again. 
+ /// + /// # Return + /// * - Ok(Some(avail_features)): virtio features from the slave + /// * - Ok(None): underlying communication channel gets broken during negotiation + /// * - Err(e): error conditions + fn initialize(master: &mut Master) -> VirtioResult> { + // 1. Seems that some vhost-user slaves depend on the get_features request to drive its + // internal state machine. + // N.B. it's really TDD, we just found it works in this way. Any spec about this? + let features = match master.get_features() { + Ok(val) => val, + Err(VhostError::VhostUserProtocol(VhostUserError::SocketBroken(_e))) => { + return Ok(None) + } + Err(e) => return Err(e.into()), + }; + + Ok(Some(features)) + } + + pub fn update_memory(&mut self, vm_as: &AS) -> VirtioResult<()> { + let master = match self.conn.as_mut() { + Some(conn) => conn, + None => { + error!("vhost user master is None!"); + return Err(VirtioError::InternalError); + } + }; + let guard = vm_as.memory(); + let mem = guard.deref(); + let mut regions = Vec::new(); + for region in mem.iter() { + let guest_phys_addr = region.start_addr(); + let file_offset = region.file_offset().ok_or_else(|| { + error!("region file_offset get error!"); + VirtioError::InvalidGuestAddress(guest_phys_addr) + })?; + let userspace_addr = region + .get_host_address(MemoryRegionAddress(0)) + .map_err(|e| { + error!("get_host_address error! {:?}", e); + VirtioError::InvalidGuestAddress(guest_phys_addr) + })?; + + regions.push(VhostUserMemoryRegionInfo { + guest_phys_addr: guest_phys_addr.raw_value() as u64, + memory_size: region.len() as u64, + userspace_addr: userspace_addr as *const u8 as u64, + mmap_offset: file_offset.start(), + mmap_handle: file_offset.file().as_raw_fd(), + }); + } + master.set_mem_table(®ions)?; + Ok(()) + } + + /// Drive the negotiation and initialization process with the vhost-user slave. 
+ pub fn negotiate( + &mut self, + config: &EndpointParam, + mut old: Option<&mut Master>, + ) -> VirtioResult<()> { + let guard = config.virtio_config.lock_guest_memory(); + let mem = guard.deref(); + let queue_num = config.virtio_config.queues.len(); + assert_eq!(queue_num, config.queue_sizes.len()); + assert_eq!(queue_num, config.intr_evts.len()); + + let master = match self.conn.as_mut() { + Some(conn) => conn, + None => return Err(VirtioError::InternalError), + }; + + info!("{}: negotiate()", self.name); + master.set_owner()?; + info!("{}: set_owner()", self.name); + + // 3. query features again after set owner. + let features = master.get_features()?; + info!("{}: get_features({:X})", self.name, features); + + // 4. set virtio features. + master.set_features(config.features)?; + info!("{}: set_features({:X})", self.name, config.features); + + // 5. set vhost-user protocol features + // typical protocol features: 0x37 + let mut protocol_features = master.get_protocol_features()?; + info!( + "{}: get_protocol_features({:X})", + self.name, protocol_features + ); + // There are two virtque for rx/tx. 
+ if config.has_protocol_mq() && !protocol_features.contains(VhostUserProtocolFeatures::MQ) { + return Err(VhostError::VhostUserProtocol(VhostUserError::FeatureMismatch).into()); + } + protocol_features &= config.dev_protocol_features; + master.set_protocol_features(protocol_features)?; + info!( + "{}: set_protocol_features({:X}), dev_protocol_features({:X})", + self.name, protocol_features, config.dev_protocol_features + ); + + // Setup slave channel if SLAVE_REQ protocol feature is set + if protocol_features.contains(VhostUserProtocolFeatures::SLAVE_REQ) { + match config.slave_req_fd { + Some(fd) => master.set_slave_request_fd(&fd)?, + None => { + error!( + "{}: Protocol feature SLAVE_REQ is set but not slave channel fd", + self.name + ); + return Err(VhostError::VhostUserProtocol(VhostUserError::InvalidParam).into()); + } + } + } else { + info!("{}: has no SLAVE_REQ protocol feature set", self.name); + } + + // 6. check number of queues supported + if config.has_protocol_mq() { + let queue_num = master.get_queue_num()?; + info!("{}: get_queue_num({:X})", self.name, queue_num); + if queue_num < config.queue_sizes.len() as u64 { + return Err(VhostError::VhostUserProtocol(VhostUserError::FeatureMismatch).into()); + } + } + + // 7. trigger the backend state machine. + for queue_index in 0..queue_num { + master.set_vring_call(queue_index, config.intr_evts[queue_index])?; + } + info!("{}: set_vring_call()", self.name); + + // 8. 
set mem_table + let mut regions = Vec::new(); + for region in mem.iter() { + let guest_phys_addr = region.start_addr(); + let file_offset = region + .file_offset() + .ok_or(VirtioError::InvalidGuestAddress(guest_phys_addr))?; + let userspace_addr = region + .get_host_address(MemoryRegionAddress(0)) + .map_err(|_| VirtioError::InvalidGuestAddress(guest_phys_addr))?; + + regions.push(VhostUserMemoryRegionInfo { + guest_phys_addr: guest_phys_addr.raw_value() as u64, + memory_size: region.len() as u64, + userspace_addr: userspace_addr as *const u8 as u64, + mmap_offset: file_offset.start(), + mmap_handle: file_offset.file().as_raw_fd(), + }); + } + master.set_mem_table(®ions)?; + info!("{}: set_mem_table()", self.name); + + // 9. setup vrings + for queue_cfg in config.virtio_config.queues.iter() { + master.set_vring_num(queue_cfg.index() as usize, queue_cfg.actual_size() as u16)?; + info!( + "{}: set_vring_num(idx: {}, size: {})", + self.name, + queue_cfg.index(), + queue_cfg.actual_size(), + ); + } + // On reconnection, the slave may have processed some packets in virtque and queue + // base is not zero any more. So don't set queue base on reconnection. + // N.B. it's really TDD, we just found it works in this way. Any spec about this? 
+ for queue_index in 0..queue_num { + let base = if old.is_some() { + let conn = old.as_mut().unwrap(); + match conn.get_vring_base(queue_index) { + Ok(val) => Some(val), + Err(_) => None, + } + } else if !config.reconnect { + Some(0) + } else { + None + }; + if let Some(val) = base { + master.set_vring_base(queue_index, val as u16)?; + info!( + "{}: set_vring_base(idx: {}, base: {})", + self.name, queue_index, val + ); + } + } + for queue_cfg in config.virtio_config.queues.iter() { + let queue = &queue_cfg.queue; + let queue_index = queue_cfg.index() as usize; + let desc_addr = + config.get_host_address(vm_memory::GuestAddress(queue.desc_table()), mem)?; + let used_addr = + config.get_host_address(vm_memory::GuestAddress(queue.used_ring()), mem)?; + let avail_addr = + config.get_host_address(vm_memory::GuestAddress(queue.avail_ring()), mem)?; + master.set_vring_addr( + queue_index, + &VringConfigData { + queue_max_size: queue.max_size(), + queue_size: queue_cfg.actual_size(), + flags: VhostUserVringAddrFlags::empty().bits(), + desc_table_addr: desc_addr as u64, + used_ring_addr: used_addr as u64, + avail_ring_addr: avail_addr as u64, + log_addr: None, + }, + )?; + info!( + "{}: set_vring_addr(idx: {}, addr: {:p})", + self.name, queue_index, desc_addr + ); + } + for queue_index in 0..queue_num { + master.set_vring_kick( + queue_index, + &config.virtio_config.queues[queue_index].eventfd, + )?; + info!( + "{}: set_vring_kick(idx: {}, fd: {})", + self.name, + queue_index, + config.virtio_config.queues[queue_index].eventfd.as_raw_fd() + ); + } + for queue_index in 0..queue_num { + let intr_index = if config.intr_evts.len() == 1 { + 0 + } else { + queue_index + }; + master.set_vring_call(queue_index, config.intr_evts[intr_index])?; + info!( + "{}: set_vring_call(idx: {}, fd: {})", + self.name, + queue_index, + config.intr_evts[intr_index].as_raw_fd() + ); + } + for queue_index in 0..queue_num { + master.set_vring_enable(queue_index, true)?; + info!( + "{}: 
set_vring_enable(idx: {}, enable: {})", + self.name, queue_index, true + ); + if (queue_index + 1) == config.init_queues as usize { + break; + } + } + info!("{}: protocol negotiate completed successfully.", self.name); + + Ok(()) + } + + pub fn set_queues_attach(&mut self, curr_queues: u32) -> VirtioResult<()> { + let master = match self.conn.as_mut() { + Some(conn) => conn, + None => return Err(VirtioError::InternalError), + }; + + for index in 0..curr_queues { + master.set_vring_enable(index as usize, true)?; + info!( + "{}: set_vring_enable(idx: {}, enable: {})", + self.name, index, true + ); + } + + Ok(()) + } + + /// Restore communication with the vhost-user slave on reconnect. + pub fn reconnect( + &mut self, + master: Master, + config: &EndpointParam, + ops: &mut EventOps, + ) -> VirtioResult<()> { + let mut old = self.conn.replace(master); + if let Err(e) = self.negotiate(config, old.as_mut()) { + error!("{}: failed to initialize connection: {}", self.name, e); + self.conn = old; + return Err(e); + } + if let Err(e) = self.register_epoll_event(ops) { + error!("{}: failed to add fd to epoll: {}", self.name, e); + self.conn = old; + return Err(e); + } + self.old = old; + Ok(()) + } + + /// Teardown the communication channel to the vhost-user slave. + pub fn disconnect(&mut self, ops: &mut EventOps) -> VirtioResult<()> { + info!("vhost-user-net: disconnect communication channel."); + match self.old.take() { + Some(master) => { + info!("close old connection"); + self.deregister_epoll_event(&master, ops) + } + None => match self.conn.take() { + Some(master) => { + info!("disconnect connection."); + self.deregister_epoll_event(&master, ops) + } + None => { + info!("get disconnect notification when it's already disconnected."); + Ok(()) + } + }, + } + } + + /// Register the underlying socket to be monitored for socket disconnect events. 
+ pub fn register_epoll_event(&self, ops: &mut EventOps) -> VirtioResult<()> { + match self.conn.as_ref() { + Some(master) => { + info!( + "{}: monitor disconnect event for fd {}.", + self.name, + master.as_raw_fd() + ); + ops.add(Events::with_data( + master, + self.slot, + EventSet::HANG_UP | EventSet::EDGE_TRIGGERED, + )) + .map_err(VirtioError::EpollMgr) + } + None => Err(VirtioError::InternalError), + } + } + + /// Deregister the underlying socket from the epoll controller. + pub fn deregister_epoll_event(&self, master: &Master, ops: &mut EventOps) -> VirtioResult<()> { + info!( + "{}: unregister epoll event for fd {}.", + self.name, + master.as_raw_fd() + ); + ops.remove(Events::with_data( + master, + self.slot, + EventSet::HANG_UP | EventSet::EDGE_TRIGGERED, + )) + .map_err(VirtioError::EpollMgr) + } + + pub fn set_master(&mut self, master: Master) { + self.conn = Some(master); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_endpoint_flags() { + assert_eq!(EndpointProtocolFlags::ProtocolMq as u16, 0x1); + } + + #[should_panic] + #[test] + fn test_connect_try_accept() { + let listener = Listener::new( + "test_listener".to_string(), + "/tmp/test_vhost_listener".to_string(), + true, + 1, + ) + .unwrap(); + + listener.listener.set_nonblocking(true).unwrap(); + + assert!(listener.try_accept().is_err()); + } +} diff --git a/src/dragonball/src/dbs_virtio_devices/src/vhost/vhost_user/mod.rs b/src/dragonball/src/dbs_virtio_devices/src/vhost/vhost_user/mod.rs new file mode 100644 index 0000000000..995c167e90 --- /dev/null +++ b/src/dragonball/src/dbs_virtio_devices/src/vhost/vhost_user/mod.rs @@ -0,0 +1,8 @@ +// Copyright (C) 2019-2023 Alibaba Cloud. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Vhost-based virtio device backend implementations. 
+ +use super::VhostError; + +pub mod connection; From a9571398a6b8d20778d90f137e7a05c5c5b22e7c Mon Sep 17 00:00:00 2001 From: Huang Jianan Date: Mon, 24 Apr 2023 17:29:18 +0800 Subject: [PATCH 2/2] dragonball: add test utils for vhost-user The test utils will be used by the upcoming feature tests: vhost-user-net, vhost-user-blk and vhost-user-fs. Signed-off-by: Beiyue Signed-off-by: Huang Jianan --- .../src/vhost/vhost_user/mod.rs | 3 + .../src/vhost/vhost_user/test_utils.rs | 750 ++++++++++++++++++ 2 files changed, 753 insertions(+) create mode 100644 src/dragonball/src/dbs_virtio_devices/src/vhost/vhost_user/test_utils.rs diff --git a/src/dragonball/src/dbs_virtio_devices/src/vhost/vhost_user/mod.rs b/src/dragonball/src/dbs_virtio_devices/src/vhost/vhost_user/mod.rs index 995c167e90..ad78fe7635 100644 --- a/src/dragonball/src/dbs_virtio_devices/src/vhost/vhost_user/mod.rs +++ b/src/dragonball/src/dbs_virtio_devices/src/vhost/vhost_user/mod.rs @@ -6,3 +6,6 @@ use super::VhostError; pub mod connection; + +#[cfg(test)] +mod test_utils; diff --git a/src/dragonball/src/dbs_virtio_devices/src/vhost/vhost_user/test_utils.rs b/src/dragonball/src/dbs_virtio_devices/src/vhost/vhost_user/test_utils.rs new file mode 100644 index 0000000000..ac5fb9e1d7 --- /dev/null +++ b/src/dragonball/src/dbs_virtio_devices/src/vhost/vhost_user/test_utils.rs @@ -0,0 +1,750 @@ +// Copyright (C) 2021 Alibaba Cloud Computing. All rights reserved. 
+// SPDX-License-Identifier: Apache-2.0 or BSD-3-Clause + +use std::fmt::Debug; +use std::marker::PhantomData; +use std::os::unix::io::{AsRawFd, RawFd}; +use std::os::unix::net::UnixStream; +use std::{mem, slice}; + +use vmm_sys_util::tempfile::TempFile; +use libc::{c_void, iovec}; +use vhost_rs::vhost_user::message::{ + VhostUserHeaderFlag, VhostUserInflight, VhostUserMemory, VhostUserMemoryRegion, + VhostUserMsgValidator, VhostUserProtocolFeatures, VhostUserU64, VhostUserVirtioFeatures, + VhostUserVringAddr, VhostUserVringState, MAX_MSG_SIZE, +}; +use vhost_rs::vhost_user::Error; +use vmm_sys_util::sock_ctrl_msg::ScmSocket; + +pub const MAX_ATTACHED_FD_ENTRIES: usize = 32; + +pub(crate) trait Req: + Clone + Copy + Debug + PartialEq + Eq + PartialOrd + Ord + Into +{ + fn is_valid(&self) -> bool; +} + +pub type Result = std::result::Result; + +/// Type of requests sending from masters to slaves. +#[repr(u32)] +#[allow(unused, non_camel_case_types)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] +pub enum MasterReq { + /// Null operation. + NOOP = 0, + /// Get from the underlying vhost implementation the features bit mask. + GET_FEATURES = 1, + /// Enable features in the underlying vhost implementation using a bit mask. + SET_FEATURES = 2, + /// Set the current Master as an owner of the session. + SET_OWNER = 3, + /// No longer used. + RESET_OWNER = 4, + /// Set the memory map regions on the slave so it can translate the vring addresses. + SET_MEM_TABLE = 5, + /// Set logging shared memory space. + SET_LOG_BASE = 6, + /// Set the logging file descriptor, which is passed as ancillary data. + SET_LOG_FD = 7, + /// Set the size of the queue. + SET_VRING_NUM = 8, + /// Set the addresses of the different aspects of the vring. + SET_VRING_ADDR = 9, + /// Set the base offset in the available vring. + SET_VRING_BASE = 10, + /// Get the available vring base offset. + GET_VRING_BASE = 11, + /// Set the event file descriptor for adding buffers to the vring. 
+ SET_VRING_KICK = 12, + /// Set the event file descriptor to signal when buffers are used. + SET_VRING_CALL = 13, + /// Set the event file descriptor to signal when error occurs. + SET_VRING_ERR = 14, + /// Get the protocol feature bit mask from the underlying vhost implementation. + GET_PROTOCOL_FEATURES = 15, + /// Enable protocol features in the underlying vhost implementation. + SET_PROTOCOL_FEATURES = 16, + /// Query how many queues the backend supports. + GET_QUEUE_NUM = 17, + /// Signal slave to enable or disable corresponding vring. + SET_VRING_ENABLE = 18, + /// Ask vhost user backend to broadcast a fake RARP to notify the migration is terminated + /// for guest that does not support GUEST_ANNOUNCE. + SEND_RARP = 19, + /// Set host MTU value exposed to the guest. + NET_SET_MTU = 20, + /// Set the socket file descriptor for slave initiated requests. + SET_SLAVE_REQ_FD = 21, + /// Send IOTLB messages with struct vhost_iotlb_msg as payload. + IOTLB_MSG = 22, + /// Set the endianness of a VQ for legacy devices. + SET_VRING_ENDIAN = 23, + /// Fetch the contents of the virtio device configuration space. + GET_CONFIG = 24, + /// Change the contents of the virtio device configuration space. + SET_CONFIG = 25, + /// Create a session for crypto operation. + CREATE_CRYPTO_SESSION = 26, + /// Close a session for crypto operation. + CLOSE_CRYPTO_SESSION = 27, + /// Advise slave that a migration with postcopy enabled is underway. + POSTCOPY_ADVISE = 28, + /// Advise slave that a transition to postcopy mode has happened. + POSTCOPY_LISTEN = 29, + /// Advise that postcopy migration has now completed. + POSTCOPY_END = 30, + /// Get a shared buffer from slave. + GET_INFLIGHT_FD = 31, + /// Send the shared inflight buffer back to slave + SET_INFLIGHT_FD = 32, + /// Upper bound of valid commands. 
+ MAX_CMD = 33, +} + +impl Into for MasterReq { + fn into(self) -> u32 { + self as u32 + } +} + +impl Req for MasterReq { + fn is_valid(&self) -> bool { + (*self > MasterReq::NOOP) && (*self < MasterReq::MAX_CMD) + } +} + +// Given a slice of sizes and the `skip_size`, return the offset of `skip_size` in the slice. +// For example: +// let iov_lens = vec![4, 4, 5]; +// let size = 6; +// assert_eq!(get_sub_iovs_offset(&iov_len, size), (1, 2)); +fn get_sub_iovs_offset(iov_lens: &[usize], skip_size: usize) -> (usize, usize) { + let mut size = skip_size; + let mut nr_skip = 0; + + for len in iov_lens { + if size >= *len { + size -= *len; + nr_skip += 1; + } else { + break; + } + } + (nr_skip, size) +} + +/// Common message header for vhost-user requests and replies. +/// A vhost-user message consists of 3 header fields and an optional payload. All numbers are in the +/// machine native byte order. +#[repr(packed)] +#[derive(Copy)] +pub(crate) struct VhostUserMsgHeader { + request: u32, + flags: u32, + size: u32, + _r: PhantomData, +} + +impl Debug for VhostUserMsgHeader { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Point") + .field("request", &{ self.request }) + .field("flags", &{ self.flags }) + .field("size", &{ self.size }) + .finish() + } +} + +impl VhostUserMsgValidator for VhostUserMsgHeader { + #[allow(clippy::if_same_then_else)] + fn is_valid(&self) -> bool { + if !self.get_code().is_valid() { + return false; + } else if self.size as usize > MAX_MSG_SIZE { + return false; + } else if self.get_version() != 0x1 { + return false; + } else if (self.flags & VhostUserHeaderFlag::RESERVED_BITS.bits()) != 0 { + return false; + } + true + } +} + +impl Clone for VhostUserMsgHeader { + fn clone(&self) -> VhostUserMsgHeader { + *self + } +} + +impl VhostUserMsgHeader { + /// Create a new instance of `VhostUserMsgHeader`. 
+ pub fn new(request: R, flags: u32, size: u32) -> Self { + // Default to protocol version 1 + let fl = (flags & VhostUserHeaderFlag::ALL_FLAGS.bits()) | 0x1; + VhostUserMsgHeader { + request: request.into(), + flags: fl, + size, + _r: PhantomData, + } + } + + /// Get message type. + pub fn get_code(&self) -> R { + // It's safe because R is marked as repr(u32). + unsafe { std::mem::transmute_copy::(&{ self.request }) } + } + + /// Get message version number. + pub fn get_version(&self) -> u32 { + self.flags & 0x3 + } +} + +impl Default for VhostUserMsgHeader { + fn default() -> Self { + VhostUserMsgHeader { + request: 0, + flags: 0x1, + size: 0, + _r: PhantomData, + } + } +} + +/// Unix domain socket endpoint for vhost-user connection. +pub(crate) struct Endpoint { + sock: UnixStream, + _r: PhantomData, +} + +impl Endpoint { + /// Create a new stream by connecting to server at `path`. + /// + /// # Return: + /// * - the new Endpoint object on success. + /// * - SocketConnect: failed to connect to peer. + pub fn connect(path: &str) -> Result { + let sock = UnixStream::connect(path).map_err(Error::SocketConnect)?; + Ok(Self::from_stream(sock)) + } + + /// Create an endpoint from a stream object. + pub fn from_stream(sock: UnixStream) -> Self { + Endpoint { + sock, + _r: PhantomData, + } + } + + /// Sends bytes from scatter-gather vectors over the socket with optional attached file + /// descriptors. + /// + /// # Return: + /// * - number of bytes sent on success + /// * - SocketRetry: temporary error caused by signals or short of resources. + /// * - SocketBroken: the underlying socket is broken. + /// * - SocketError: other socket related errors. + pub fn send_iovec(&mut self, iovs: &[&[u8]], fds: Option<&[RawFd]>) -> Result { + let rfds = match fds { + Some(rfds) => rfds, + _ => &[], + }; + self.sock.send_with_fds(iovs, rfds).map_err(Into::into) + } + + /// Sends all bytes from scatter-gather vectors over the socket with optional attached file + /// descriptors. 
Will loop until all data has been transfered. + /// + /// # Return: + /// * - number of bytes sent on success + /// * - SocketBroken: the underline socket is broken. + /// * - SocketError: other socket related errors. + pub fn send_iovec_all(&mut self, iovs: &[&[u8]], fds: Option<&[RawFd]>) -> Result { + let mut data_sent = 0; + let mut data_total = 0; + let iov_lens: Vec = iovs.iter().map(|iov| iov.len()).collect(); + for len in &iov_lens { + data_total += len; + } + + while (data_total - data_sent) > 0 { + let (nr_skip, offset) = get_sub_iovs_offset(&iov_lens, data_sent); + let iov = &iovs[nr_skip][offset..]; + + let data = &[&[iov], &iovs[(nr_skip + 1)..]].concat(); + let sfds = if data_sent == 0 { fds } else { None }; + + let sent = self.send_iovec(data, sfds); + match sent { + Ok(0) => return Ok(data_sent), + Ok(n) => data_sent += n, + Err(e) => match e { + Error::SocketRetry(_) => {} + _ => return Err(e), + }, + } + } + Ok(data_sent) + } + + /// Sends a header-only message with optional attached file descriptors. + /// + /// # Return: + /// * - number of bytes sent on success + /// * - SocketRetry: temporary error caused by signals or short of resources. + /// * - SocketBroken: the underline socket is broken. + /// * - SocketError: other socket related errors. + /// * - PartialMessage: received a partial message. + pub fn send_header( + &mut self, + hdr: &VhostUserMsgHeader, + fds: Option<&[RawFd]>, + ) -> Result<()> { + // Safe because there can't be other mutable referance to hdr. + let iovs = unsafe { + [slice::from_raw_parts( + hdr as *const VhostUserMsgHeader as *const u8, + mem::size_of::>(), + )] + }; + let bytes = self.send_iovec_all(&iovs[..], fds)?; + if bytes != mem::size_of::>() { + return Err(Error::PartialMessage); + } + Ok(()) + } + + /// Send a message with header and body. Optional file descriptors may be attached to + /// the message. 
+ /// + /// # Return: + /// * - number of bytes sent on success + /// * - SocketRetry: temporary error caused by signals or short of resources. + /// * - SocketBroken: the underline socket is broken. + /// * - SocketError: other socket related errors. + /// * - PartialMessage: received a partial message. + pub fn send_message( + &mut self, + hdr: &VhostUserMsgHeader, + body: &T, + fds: Option<&[RawFd]>, + ) -> Result<()> { + // Safe because there can't be other mutable referance to hdr and body. + let iovs = unsafe { + [ + slice::from_raw_parts( + hdr as *const VhostUserMsgHeader as *const u8, + mem::size_of::>(), + ), + slice::from_raw_parts(body as *const T as *const u8, mem::size_of::()), + ] + }; + + let bytes = self.send_iovec_all(&iovs[..], fds)?; + if bytes != mem::size_of::>() + mem::size_of::() { + return Err(Error::PartialMessage); + } + Ok(()) + } + + /// Reads bytes from the socket into the given scatter/gather vectors with optional attached + /// file descriptors. + /// + /// The underlying communication channel is a Unix domain socket in STREAM mode. It's a little + /// tricky to pass file descriptors through such a communication channel. Let's assume that a + /// sender sending a message with some file descriptors attached. To successfully receive those + /// attached file descriptors, the receiver must obey following rules: + /// 1) file descriptors are attached to a message. + /// 2) message(packet) boundaries must be respected on the receive side. + /// In other words, recvmsg() operations must not cross the packet boundary, otherwise the + /// attached file descriptors will get lost. + /// + /// # Return: + /// * - (number of bytes received, [received fds]) on success + /// * - SocketRetry: temporary error caused by signals or short of resources. + /// * - SocketBroken: the underline socket is broken. + /// * - SocketError: other socket related errors. 
+ pub fn recv_into_iovec(&mut self, iovs: &mut [iovec]) -> Result<(usize, Option>)> { + let mut fd_array = vec![0; MAX_ATTACHED_FD_ENTRIES]; + let (bytes, fds) = unsafe { self.sock.recv_with_fds(iovs, &mut fd_array)? }; + let rfds = match fds { + 0 => None, + n => { + let mut fds = Vec::with_capacity(n); + fds.extend_from_slice(&fd_array[0..n]); + Some(fds) + } + }; + + Ok((bytes, rfds)) + } + + /// Reads all bytes from the socket into the given scatter/gather vectors with optional + /// attached file descriptors. Will loop until all data has been transfered. + /// + /// The underlying communication channel is a Unix domain socket in STREAM mode. It's a little + /// tricky to pass file descriptors through such a communication channel. Let's assume that a + /// sender sending a message with some file descriptors attached. To successfully receive those + /// attached file descriptors, the receiver must obey following rules: + /// 1) file descriptors are attached to a message. + /// 2) message(packet) boundaries must be respected on the receive side. + /// In other words, recvmsg() operations must not cross the packet boundary, otherwise the + /// attached file descriptors will get lost. + /// + /// # Return: + /// * - (number of bytes received, [received fds]) on success + /// * - SocketBroken: the underline socket is broken. + /// * - SocketError: other socket related errors. 
+ pub fn recv_into_iovec_all( + &mut self, + iovs: &mut [iovec], + ) -> Result<(usize, Option>)> { + let mut data_read = 0; + let mut data_total = 0; + let mut rfds = None; + let iov_lens: Vec = iovs.iter().map(|iov| iov.iov_len).collect(); + for len in &iov_lens { + data_total += len; + } + + while (data_total - data_read) > 0 { + let (nr_skip, offset) = get_sub_iovs_offset(&iov_lens, data_read); + let iov = &mut iovs[nr_skip]; + + let mut data = [ + &[iovec { + iov_base: (iov.iov_base as usize + offset) as *mut c_void, + iov_len: iov.iov_len - offset, + }], + &iovs[(nr_skip + 1)..], + ] + .concat(); + + let res = self.recv_into_iovec(&mut data); + match res { + Ok((0, _)) => return Ok((data_read, rfds)), + Ok((n, fds)) => { + if data_read == 0 { + rfds = fds; + } + data_read += n; + } + Err(e) => match e { + Error::SocketRetry(_) => {} + _ => return Err(e), + }, + } + } + Ok((data_read, rfds)) + } + + /// Receive a header-only message with optional attached file descriptors. + /// Note, only the first MAX_ATTACHED_FD_ENTRIES file descriptors will be + /// accepted and all other file descriptor will be discard silently. + /// + /// # Return: + /// * - (message header, [received fds]) on success. + /// * - SocketRetry: temporary error caused by signals or short of resources. + /// * - SocketBroken: the underline socket is broken. + /// * - SocketError: other socket related errors. + /// * - PartialMessage: received a partial message. + /// * - InvalidMessage: received a invalid message. 
+ pub fn recv_header(&mut self) -> Result<(VhostUserMsgHeader, Option>)> { + let mut hdr = VhostUserMsgHeader::default(); + let mut iovs = [iovec { + iov_base: (&mut hdr as *mut VhostUserMsgHeader) as *mut c_void, + iov_len: mem::size_of::>(), + }]; + let (bytes, rfds) = self.recv_into_iovec_all(&mut iovs[..])?; + + if bytes != mem::size_of::>() { + return Err(Error::PartialMessage); + } else if !hdr.is_valid() { + return Err(Error::InvalidMessage); + } + + Ok((hdr, rfds)) + } + + /// Receive a message with optional attached file descriptors. + /// Note, only the first MAX_ATTACHED_FD_ENTRIES file descriptors will be + /// accepted and all other file descriptor will be discard silently. + /// + /// # Return: + /// * - (message header, message body, [received fds]) on success. + /// * - SocketRetry: temporary error caused by signals or short of resources. + /// * - SocketBroken: the underline socket is broken. + /// * - SocketError: other socket related errors. + /// * - PartialMessage: received a partial message. + /// * - InvalidMessage: received a invalid message. + pub fn recv_body( + &mut self, + ) -> Result<(VhostUserMsgHeader, T, Option>)> { + let mut hdr = VhostUserMsgHeader::default(); + let mut body: T = Default::default(); + let mut iovs = [ + iovec { + iov_base: (&mut hdr as *mut VhostUserMsgHeader) as *mut c_void, + iov_len: mem::size_of::>(), + }, + iovec { + iov_base: (&mut body as *mut T) as *mut c_void, + iov_len: mem::size_of::(), + }, + ]; + let (bytes, rfds) = self.recv_into_iovec_all(&mut iovs[..])?; + + let total = mem::size_of::>() + mem::size_of::(); + if bytes != total { + return Err(Error::PartialMessage); + } else if !hdr.is_valid() || !body.is_valid() { + return Err(Error::InvalidMessage); + } + + Ok((hdr, body, rfds)) + } + + /// Send a message with header, body and payload. Optional file descriptors + /// may also be attached to the message. 
+ /// + /// # Return: + /// * - number of bytes sent on success + /// * - SocketRetry: temporary error caused by signals or short of resources. + /// * - SocketBroken: the underline socket is broken. + /// * - SocketError: other socket related errors. + /// * - OversizedMsg: message size is too big. + /// * - PartialMessage: received a partial message. + /// * - IncorrectFds: wrong number of attached fds. + pub fn send_message_with_payload( + &mut self, + hdr: &VhostUserMsgHeader, + body: &T, + payload: &[P], + fds: Option<&[RawFd]>, + ) -> Result<()> { + let len = payload.len() * mem::size_of::

(); + if len > MAX_MSG_SIZE - mem::size_of::() { + return Err(Error::OversizedMsg); + } + if let Some(fd_arr) = fds { + if fd_arr.len() > MAX_ATTACHED_FD_ENTRIES { + return Err(Error::IncorrectFds); + } + } + + // Safe because there can't be other mutable reference to hdr, body and payload. + let iovs = unsafe { + [ + slice::from_raw_parts( + hdr as *const VhostUserMsgHeader as *const u8, + mem::size_of::>(), + ), + slice::from_raw_parts(body as *const T as *const u8, mem::size_of::()), + slice::from_raw_parts(payload.as_ptr() as *const u8, len), + ] + }; + let total = mem::size_of::>() + mem::size_of::() + len; + let len = self.send_iovec_all(&iovs, fds)?; + if len != total { + return Err(Error::PartialMessage); + } + Ok(()) + } + + /// Receive a message with optional payload and attached file descriptors. + /// Note, only the first MAX_ATTACHED_FD_ENTRIES file descriptors will be + /// accepted and all other file descriptor will be discard silently. + /// + /// # Return: + /// * - (message header, message body, size of payload, [received fds]) on success. + /// * - SocketRetry: temporary error caused by signals or short of resources. + /// * - SocketBroken: the underline socket is broken. + /// * - SocketError: other socket related errors. + /// * - PartialMessage: received a partial message. + /// * - InvalidMessage: received a invalid message. 
+ #[cfg_attr(feature = "cargo-clippy", allow(clippy::type_complexity))] + pub fn recv_payload_into_buf( + &mut self, + buf: &mut [u8], + ) -> Result<(VhostUserMsgHeader, T, usize, Option>)> { + let mut hdr = VhostUserMsgHeader::default(); + let mut body: T = Default::default(); + let mut iovs = [ + iovec { + iov_base: (&mut hdr as *mut VhostUserMsgHeader) as *mut c_void, + iov_len: mem::size_of::>(), + }, + iovec { + iov_base: (&mut body as *mut T) as *mut c_void, + iov_len: mem::size_of::(), + }, + iovec { + iov_base: buf.as_mut_ptr() as *mut c_void, + iov_len: buf.len(), + }, + ]; + let (bytes, rfds) = self.recv_into_iovec_all(&mut iovs[..])?; + + let total = mem::size_of::>() + mem::size_of::(); + if bytes < total { + return Err(Error::PartialMessage); + } else if !hdr.is_valid() || !body.is_valid() { + return Err(Error::InvalidMessage); + } + + Ok((hdr, body, bytes - total, rfds)) + } +} + +impl AsRawFd for Endpoint { + fn as_raw_fd(&self) -> RawFd { + self.sock.as_raw_fd() + } +} + +// Negotiate process from slave. 
+pub(crate) fn negotiate_slave( + slave: &mut Endpoint, + pfeatures: VhostUserProtocolFeatures, + use_ali_feature: bool, + has_protocol_mq: bool, + queue_num: u64, +) { + // set owner + let (hdr, rfds) = slave.recv_header().unwrap(); + assert_eq!(hdr.get_code(), MasterReq::SET_OWNER); + assert!(rfds.is_none()); + + // get features + let vfeatures = 0x15 | VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits(); + let hdr = VhostUserMsgHeader::new(MasterReq::GET_FEATURES, 0x4, 8); + let msg = VhostUserU64::new(vfeatures); + slave.send_message(&hdr, &msg, None).unwrap(); + let (hdr, _rfds) = slave.recv_header().unwrap(); + assert_eq!(hdr.get_code(), MasterReq::GET_FEATURES); + + // set features + let (hdr, _msg, rfds) = slave.recv_body::().unwrap(); + assert_eq!(hdr.get_code(), MasterReq::SET_FEATURES); + assert!(rfds.is_none()); + + // get vhost-user protocol features + let code = MasterReq::GET_PROTOCOL_FEATURES; + let (hdr, rfds) = slave.recv_header().unwrap(); + assert_eq!(hdr.get_code(), code); + assert!(rfds.is_none()); + let hdr = VhostUserMsgHeader::new(code, 0x4, 8); + let msg = VhostUserU64::new(pfeatures.bits()); + slave.send_message(&hdr, &msg, None).unwrap(); + + // set vhost-user protocol features + let (hdr, _msg, rfds) = slave.recv_body::().unwrap(); + assert_eq!(hdr.get_code(), MasterReq::SET_PROTOCOL_FEATURES); + assert!(rfds.is_none()); + + // set number of queues + if has_protocol_mq { + let (hdr, rfds) = slave.recv_header().unwrap(); + assert_eq!(hdr.get_code(), MasterReq::GET_QUEUE_NUM); + assert!(rfds.is_none()); + let hdr = VhostUserMsgHeader::new(MasterReq::GET_QUEUE_NUM, 0x4, 8); + let msg = VhostUserU64::new(queue_num); + slave.send_message(&hdr, &msg, None).unwrap(); + } + + // set vring call + for _i in 0..queue_num { + let (hdr, _msg, rfds) = slave.recv_body::().unwrap(); + assert_eq!(hdr.get_code(), MasterReq::SET_VRING_CALL); + assert!(rfds.is_some()); + } + + // set mem table + let mut region_buf: Vec = vec![0u8; mem::size_of::()]; + let 
(hdr, _msg, _payload, rfds) = slave + .recv_payload_into_buf::(&mut region_buf) + .unwrap(); + assert_eq!(hdr.get_code(), MasterReq::SET_MEM_TABLE); + assert!(rfds.is_some()); + + if pfeatures.contains(VhostUserProtocolFeatures::INFLIGHT_SHMFD) { + // get inflight fd + let (hdr, _msg, rfds) = slave.recv_body::().unwrap(); + assert_eq!(hdr.get_code(), MasterReq::GET_INFLIGHT_FD); + assert!(rfds.is_none()); + let msg = VhostUserInflight { + mmap_size: 0x100, + mmap_offset: 0x0, + ..Default::default() + }; + let inflight_file = TempFile::new().unwrap().into_file(); + inflight_file.set_len(0x100).unwrap(); + let fds = [inflight_file.as_raw_fd()]; + let hdr = VhostUserMsgHeader::new( + MasterReq::GET_INFLIGHT_FD, + VhostUserHeaderFlag::REPLY.bits(), + std::mem::size_of::() as u32, + ); + slave.send_message(&hdr, &msg, Some(&fds)).unwrap(); + + // set inflight fd + let (hdr, _msg, rfds) = slave.recv_body::().unwrap(); + assert_eq!(hdr.get_code(), MasterReq::SET_INFLIGHT_FD); + assert!(rfds.is_some()); + let hdr = VhostUserMsgHeader::new( + MasterReq::GET_INFLIGHT_FD, + VhostUserHeaderFlag::REPLY.bits(), + std::mem::size_of::() as u32, + ); + slave.send_header(&hdr, None).unwrap(); + } + + // set vring num + for _i in 0..queue_num { + let (hdr, _msg, rfds) = slave.recv_body::().unwrap(); + slave.send_header(&hdr, None).unwrap(); + assert!(rfds.is_none()); + } + + // set vring base + for _i in 0..queue_num { + let (hdr, _msg, rfds) = slave.recv_body::().unwrap(); + assert_eq!(hdr.get_code(), MasterReq::SET_VRING_BASE); + assert!(rfds.is_none()); + slave.send_header(&hdr, None).unwrap(); + } + + // set vring addr + for _i in 0..queue_num { + let (hdr, _msg, rfds) = slave.recv_body::().unwrap(); + assert_eq!(hdr.get_code(), MasterReq::SET_VRING_ADDR); + assert!(rfds.is_none()); + slave.send_header(&hdr, None).unwrap(); + } + + // set vring kick + for _i in 0..queue_num { + let (hdr, _msg, rfds) = slave.recv_body::().unwrap(); + assert_eq!(hdr.get_code(), 
MasterReq::SET_VRING_KICK); + assert!(rfds.is_some()); + } + + // set vring call + for _i in 0..queue_num { + let (hdr, _msg, rfds) = slave.recv_body::().unwrap(); + assert_eq!(hdr.get_code(), MasterReq::SET_VRING_CALL); + assert!(rfds.is_some()); + } + + // set vring enable + for _i in 0..queue_num { + let (hdr, _msg, rfds) = slave.recv_body::().unwrap(); + assert_eq!(hdr.get_code(), MasterReq::SET_VRING_ENABLE); + assert!(rfds.is_none()); + slave.send_header(&hdr, None).unwrap(); + } +}