mirror of
https://github.com/kata-containers/kata-containers.git
synced 2025-09-16 06:18:58 +00:00
Merge pull request #8450 from adamqqqplay/vhost-user-general
dragonball: add vhost-user connection management logic
This commit is contained in:
@@ -53,3 +53,4 @@ virtio-mem = ["virtio-mmio"]
|
||||
virtio-balloon = ["virtio-mmio"]
|
||||
vhost = ["virtio-mmio", "vhost-rs/vhost-user-master", "vhost-rs/vhost-kern"]
|
||||
vhost-net = ["vhost", "vhost-rs/vhost-net"]
|
||||
vhost-user = ["vhost"]
|
@@ -125,6 +125,32 @@ pub enum ActivateError {
|
||||
InvalidQueueConfig,
|
||||
#[error("IO: {0}.")]
|
||||
IOError(#[from] IOError),
|
||||
#[error("Virtio error")]
|
||||
VirtioError(Error),
|
||||
#[error("Epoll manager error")]
|
||||
EpollMgr(dbs_utils::epoll_manager::Error),
|
||||
#[cfg(feature = "vhost")]
|
||||
#[error("Vhost activate error")]
|
||||
VhostActivate(vhost_rs::Error),
|
||||
}
|
||||
|
||||
impl std::convert::From<Error> for ActivateError {
|
||||
fn from(error: Error) -> ActivateError {
|
||||
ActivateError::VirtioError(error)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::convert::From<dbs_utils::epoll_manager::Error> for ActivateError {
|
||||
fn from(error: dbs_utils::epoll_manager::Error) -> ActivateError {
|
||||
ActivateError::EpollMgr(error)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "vhost")]
|
||||
impl std::convert::From<vhost_rs::Error> for ActivateError {
|
||||
fn from(error: vhost_rs::Error) -> ActivateError {
|
||||
ActivateError::VhostActivate(error)
|
||||
}
|
||||
}
|
||||
|
||||
/// Error code for VirtioDevice::read_config()/write_config().
|
||||
@@ -155,6 +181,9 @@ pub enum Error {
|
||||
/// Guest gave us a descriptor that was too big to use.
|
||||
#[error("descriptor length too big.")]
|
||||
DescriptorLengthTooBig,
|
||||
/// Error from the epoll event manager
|
||||
#[error("dbs_utils error: {0:?}.")]
|
||||
EpollMgr(dbs_utils::epoll_manager::Error),
|
||||
/// Guest gave us a write only descriptor that protocol says to read from.
|
||||
#[error("unexpected write only descriptor.")]
|
||||
UnexpectedWriteOnlyDescriptor,
|
||||
@@ -181,7 +210,7 @@ pub enum Error {
|
||||
VirtioQueueError(#[from] VqError),
|
||||
/// Error from Device activate.
|
||||
#[error("Device activate error: {0}")]
|
||||
ActivateError(#[from] ActivateError),
|
||||
ActivateError(#[from] Box<ActivateError>),
|
||||
/// Error from Interrupt.
|
||||
#[error("Interrupt error: {0}")]
|
||||
InterruptError(IOError),
|
||||
@@ -229,6 +258,15 @@ pub enum Error {
|
||||
#[cfg(feature = "virtio-balloon")]
|
||||
#[error("Virtio-balloon error: {0}")]
|
||||
VirtioBalloonError(#[from] balloon::BalloonError),
|
||||
|
||||
#[cfg(feature = "vhost")]
|
||||
/// Error from the vhost subsystem
|
||||
#[error("Vhost error: {0:?}")]
|
||||
VhostError(vhost_rs::Error),
|
||||
#[cfg(feature = "vhost")]
|
||||
/// Error from the vhost user subsystem
|
||||
#[error("Vhost-user error: {0:?}")]
|
||||
VhostUserError(vhost_rs::vhost_user::Error),
|
||||
}
|
||||
|
||||
// Error for tap devices
|
||||
|
@@ -124,7 +124,9 @@ where
|
||||
// If the driver incorrectly sets up the queues, the following check will fail and take
|
||||
// the device into an unusable state.
|
||||
if !self.check_queues_valid() {
|
||||
return Err(Error::ActivateError(ActivateError::InvalidQueueConfig));
|
||||
return Err(Error::ActivateError(Box::new(
|
||||
ActivateError::InvalidQueueConfig,
|
||||
)));
|
||||
}
|
||||
|
||||
self.register_ioevent()?;
|
||||
@@ -138,7 +140,7 @@ where
|
||||
.map(|_| self.device_activated = true)
|
||||
.map_err(|e| {
|
||||
error!("device activate error: {:?}", e);
|
||||
Error::ActivateError(e)
|
||||
Error::ActivateError(Box::new(e))
|
||||
})
|
||||
}
|
||||
|
||||
|
@@ -6,3 +6,21 @@
|
||||
|
||||
#[cfg(feature = "vhost-net")]
|
||||
pub mod vhost_kern;
|
||||
|
||||
pub use vhost_rs::vhost_user::Error as VhostUserError;
|
||||
pub use vhost_rs::Error as VhostError;
|
||||
|
||||
#[cfg(feature = "vhost-user")]
|
||||
pub mod vhost_user;
|
||||
|
||||
impl std::convert::From<VhostError> for super::Error {
|
||||
fn from(e: VhostError) -> Self {
|
||||
super::Error::VhostError(e)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::convert::From<VhostUserError> for super::Error {
|
||||
fn from(e: VhostUserError) -> Self {
|
||||
super::Error::VhostUserError(e)
|
||||
}
|
||||
}
|
||||
|
@@ -290,7 +290,7 @@ where
|
||||
"{}: Invalid virtio queue pairs, expected a value greater than 0, but got {}",
|
||||
NET_DRIVER_NAME, self.vq_pairs
|
||||
);
|
||||
return Err(VirtioError::ActivateError(ActivateError::InvalidParam));
|
||||
return Err(VirtioError::ActivateError(Box::new(ActivateError::InvalidParam)));
|
||||
}
|
||||
|
||||
if self.handles.len() != self.vq_pairs || self.taps.len() != self.vq_pairs {
|
||||
@@ -299,7 +299,7 @@ where
|
||||
self.handles.len(),
|
||||
self.taps.len(),
|
||||
self.vq_pairs);
|
||||
return Err(VirtioError::ActivateError(ActivateError::InternalError));
|
||||
return Err(VirtioError::ActivateError(Box::new(ActivateError::InternalError)));
|
||||
}
|
||||
|
||||
for idx in 0..self.vq_pairs {
|
||||
|
@@ -0,0 +1,552 @@
|
||||
// Copyright (C) 2019-2023 Alibaba Cloud. All rights reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
//! Helper utilities for vhost-user communication channel.
|
||||
|
||||
use std::ops::Deref;
|
||||
use std::os::unix::io::{AsRawFd, RawFd};
|
||||
|
||||
use dbs_utils::epoll_manager::{EventOps, EventSet, Events};
|
||||
use log::*;
|
||||
use vhost_rs::vhost_user::message::{VhostUserProtocolFeatures, VhostUserVringAddrFlags};
|
||||
use vhost_rs::vhost_user::{
|
||||
Error as VhostUserError, Listener as VhostUserListener, Master, VhostUserMaster,
|
||||
};
|
||||
use vhost_rs::{VhostBackend, VhostUserMemoryRegionInfo, VringConfigData};
|
||||
use virtio_queue::QueueT;
|
||||
use vm_memory::{
|
||||
Address, GuestAddress, GuestAddressSpace, GuestMemory, GuestMemoryRegion, MemoryRegionAddress,
|
||||
};
|
||||
use vmm_sys_util::eventfd::EventFd;
|
||||
|
||||
use super::super::super::device::VirtioDeviceConfig;
|
||||
use super::super::super::{Error as VirtioError, Result as VirtioResult};
|
||||
use super::VhostError;
|
||||
|
||||
enum EndpointProtocolFlags {
|
||||
ProtocolMq = 1,
|
||||
}
|
||||
|
||||
pub(super) struct Listener {
|
||||
listener: VhostUserListener,
|
||||
/// Slot to register epoll event for the underlying socket.
|
||||
slot: u32,
|
||||
name: String,
|
||||
path: String,
|
||||
}
|
||||
|
||||
impl Listener {
|
||||
pub fn new(name: String, path: String, force: bool, slot: u32) -> VirtioResult<Self> {
|
||||
info!("vhost-user: create listener at {} for {}", path, name);
|
||||
Ok(Listener {
|
||||
listener: VhostUserListener::new(&path, force)?,
|
||||
slot,
|
||||
name,
|
||||
path,
|
||||
})
|
||||
}
|
||||
|
||||
// Wait for an incoming connection until success.
|
||||
pub fn accept(&self) -> VirtioResult<(Master, u64)> {
|
||||
loop {
|
||||
match self.try_accept() {
|
||||
Ok(Some((master, feature))) => return Ok((master, feature)),
|
||||
Ok(None) => continue,
|
||||
Err(e) => return Err(e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn try_accept(&self) -> VirtioResult<Option<(Master, u64)>> {
|
||||
let sock = match self.listener.accept() {
|
||||
Ok(Some(conn)) => conn,
|
||||
Ok(None) => return Ok(None),
|
||||
Err(e) => return Err(e.into()),
|
||||
};
|
||||
|
||||
let mut master = Master::from_stream(sock, 1);
|
||||
info!("{}: try to get virtio features from slave.", self.name);
|
||||
match Endpoint::initialize(&mut master) {
|
||||
Ok(Some(features)) => Ok(Some((master, features))),
|
||||
// The new connection has been closed, try again.
|
||||
Ok(None) => {
|
||||
warn!(
|
||||
"{}: new connection get closed during initialization, waiting for another one.",
|
||||
self.name
|
||||
);
|
||||
Ok(None)
|
||||
}
|
||||
// Unrecoverable error happened
|
||||
Err(e) => {
|
||||
error!("{}: failed to get virtio features, {}", self.name, e);
|
||||
Err(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Register the underlying listener to be monitored for incoming connection.
|
||||
pub fn register_epoll_event(&self, ops: &mut EventOps) -> VirtioResult<()> {
|
||||
info!("{}: monitor incoming connect at {}", self.name, self.path);
|
||||
// Switch to nonblocking mode.
|
||||
self.listener.set_nonblocking(true)?;
|
||||
let event = Events::with_data(&self.listener, self.slot, EventSet::IN);
|
||||
ops.add(event).map_err(VirtioError::EpollMgr)
|
||||
}
|
||||
}
|
||||
|
||||
/// Struct to pass info to vhost user backend
|
||||
#[derive(Clone)]
|
||||
pub struct BackendInfo {
|
||||
/// -1 means to tell backend to destroy corresponding
|
||||
/// device, while others means construct it
|
||||
fd: i32,
|
||||
/// cluster id of device, must set
|
||||
cluster_id: u32,
|
||||
/// device id of device, must set
|
||||
device_id: u64,
|
||||
/// device config file path
|
||||
filename: [u8; 128],
|
||||
}
|
||||
|
||||
/// Struct to pass function parameters to methods of Endpoint.
|
||||
pub(super) struct EndpointParam<'a, AS: GuestAddressSpace, Q: QueueT, R: GuestMemoryRegion> {
|
||||
pub virtio_config: &'a VirtioDeviceConfig<AS, Q, R>,
|
||||
pub intr_evts: Vec<&'a EventFd>,
|
||||
pub queue_sizes: &'a [u16],
|
||||
pub features: u64,
|
||||
pub protocol_flag: u16,
|
||||
pub dev_protocol_features: VhostUserProtocolFeatures,
|
||||
pub reconnect: bool,
|
||||
pub backend: Option<BackendInfo>,
|
||||
pub init_queues: u32,
|
||||
pub slave_req_fd: Option<RawFd>,
|
||||
}
|
||||
|
||||
impl<'a, AS: GuestAddressSpace, Q: QueueT, R: GuestMemoryRegion> EndpointParam<'a, AS, Q, R> {
|
||||
fn get_host_address(&self, addr: GuestAddress, mem: &AS::M) -> VirtioResult<*mut u8> {
|
||||
mem.get_host_address(addr)
|
||||
.map_err(|_| VirtioError::InvalidGuestAddress(addr))
|
||||
}
|
||||
|
||||
/// set protocol multi-queue bit
|
||||
pub fn set_protocol_mq(&mut self) {
|
||||
self.protocol_flag |= EndpointProtocolFlags::ProtocolMq as u16;
|
||||
}
|
||||
|
||||
/// check if multi-queue bit is set
|
||||
pub fn has_protocol_mq(&self) -> bool {
|
||||
(self.protocol_flag & (EndpointProtocolFlags::ProtocolMq as u16)) != 0
|
||||
}
|
||||
}
|
||||
|
||||
/// Communication channel from the master to the slave.
|
||||
///
|
||||
/// It encapsulates a low-level vhost-user master side communication endpoint, and provides
|
||||
/// connection initialization, monitoring and reconnect functionalities for vhost-user devices.
|
||||
///
|
||||
/// Caller needs to ensure mutual exclusive access to the object.
|
||||
pub(super) struct Endpoint {
|
||||
/// Underlying vhost-user communication endpoint.
|
||||
conn: Option<Master>,
|
||||
old: Option<Master>,
|
||||
/// Token to register epoll event for the underlying socket.
|
||||
slot: u32,
|
||||
/// Identifier string for logs.
|
||||
name: String,
|
||||
}
|
||||
|
||||
impl Endpoint {
|
||||
pub fn new(master: Master, slot: u32, name: String) -> Self {
|
||||
Endpoint {
|
||||
conn: Some(master),
|
||||
old: None,
|
||||
slot,
|
||||
name,
|
||||
}
|
||||
}
|
||||
|
||||
/// First state of the connection negotiation between the master and the slave.
|
||||
///
|
||||
/// If Ok(None) is returned, the underlying communication channel gets broken and the caller may
|
||||
/// try to recreate the communication channel and negotiate again.
|
||||
///
|
||||
/// # Return
|
||||
/// * - Ok(Some(avial_features)): virtio features from the slave
|
||||
/// * - Ok(None): underlying communicaiton channel gets broken during negotiation
|
||||
/// * - Err(e): error conditions
|
||||
fn initialize(master: &mut Master) -> VirtioResult<Option<u64>> {
|
||||
// 1. Seems that some vhost-user slaves depend on the get_features request to driver its
|
||||
// internal state machine.
|
||||
// N.B. it's really TDD, we just found it works in this way. Any spec about this?
|
||||
let features = match master.get_features() {
|
||||
Ok(val) => val,
|
||||
Err(VhostError::VhostUserProtocol(VhostUserError::SocketBroken(_e))) => {
|
||||
return Ok(None)
|
||||
}
|
||||
Err(e) => return Err(e.into()),
|
||||
};
|
||||
|
||||
Ok(Some(features))
|
||||
}
|
||||
|
||||
pub fn update_memory<AS: GuestAddressSpace>(&mut self, vm_as: &AS) -> VirtioResult<()> {
|
||||
let master = match self.conn.as_mut() {
|
||||
Some(conn) => conn,
|
||||
None => {
|
||||
error!("vhost user master is None!");
|
||||
return Err(VirtioError::InternalError);
|
||||
}
|
||||
};
|
||||
let guard = vm_as.memory();
|
||||
let mem = guard.deref();
|
||||
let mut regions = Vec::new();
|
||||
for region in mem.iter() {
|
||||
let guest_phys_addr = region.start_addr();
|
||||
let file_offset = region.file_offset().ok_or_else(|| {
|
||||
error!("region file_offset get error!");
|
||||
VirtioError::InvalidGuestAddress(guest_phys_addr)
|
||||
})?;
|
||||
let userspace_addr = region
|
||||
.get_host_address(MemoryRegionAddress(0))
|
||||
.map_err(|e| {
|
||||
error!("get_host_address error! {:?}", e);
|
||||
VirtioError::InvalidGuestAddress(guest_phys_addr)
|
||||
})?;
|
||||
|
||||
regions.push(VhostUserMemoryRegionInfo {
|
||||
guest_phys_addr: guest_phys_addr.raw_value() as u64,
|
||||
memory_size: region.len() as u64,
|
||||
userspace_addr: userspace_addr as *const u8 as u64,
|
||||
mmap_offset: file_offset.start(),
|
||||
mmap_handle: file_offset.file().as_raw_fd(),
|
||||
});
|
||||
}
|
||||
master.set_mem_table(®ions)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Drive the negotiation and initialization process with the vhost-user slave.
|
||||
pub fn negotiate<AS: GuestAddressSpace, Q: QueueT, R: GuestMemoryRegion>(
|
||||
&mut self,
|
||||
config: &EndpointParam<AS, Q, R>,
|
||||
mut old: Option<&mut Master>,
|
||||
) -> VirtioResult<()> {
|
||||
let guard = config.virtio_config.lock_guest_memory();
|
||||
let mem = guard.deref();
|
||||
let queue_num = config.virtio_config.queues.len();
|
||||
assert_eq!(queue_num, config.queue_sizes.len());
|
||||
assert_eq!(queue_num, config.intr_evts.len());
|
||||
|
||||
let master = match self.conn.as_mut() {
|
||||
Some(conn) => conn,
|
||||
None => return Err(VirtioError::InternalError),
|
||||
};
|
||||
|
||||
info!("{}: negotiate()", self.name);
|
||||
master.set_owner()?;
|
||||
info!("{}: set_owner()", self.name);
|
||||
|
||||
// 3. query features again after set owner.
|
||||
let features = master.get_features()?;
|
||||
info!("{}: get_features({:X})", self.name, features);
|
||||
|
||||
// 4. set virtio features.
|
||||
master.set_features(config.features)?;
|
||||
info!("{}: set_features({:X})", self.name, config.features);
|
||||
|
||||
// 5. set vhost-user protocol features
|
||||
// typical protocol features: 0x37
|
||||
let mut protocol_features = master.get_protocol_features()?;
|
||||
info!(
|
||||
"{}: get_protocol_features({:X})",
|
||||
self.name, protocol_features
|
||||
);
|
||||
// There are two virtque for rx/tx.
|
||||
if config.has_protocol_mq() && !protocol_features.contains(VhostUserProtocolFeatures::MQ) {
|
||||
return Err(VhostError::VhostUserProtocol(VhostUserError::FeatureMismatch).into());
|
||||
}
|
||||
protocol_features &= config.dev_protocol_features;
|
||||
master.set_protocol_features(protocol_features)?;
|
||||
info!(
|
||||
"{}: set_protocol_features({:X}), dev_protocol_features({:X})",
|
||||
self.name, protocol_features, config.dev_protocol_features
|
||||
);
|
||||
|
||||
// Setup slave channel if SLAVE_REQ protocol feature is set
|
||||
if protocol_features.contains(VhostUserProtocolFeatures::SLAVE_REQ) {
|
||||
match config.slave_req_fd {
|
||||
Some(fd) => master.set_slave_request_fd(&fd)?,
|
||||
None => {
|
||||
error!(
|
||||
"{}: Protocol feature SLAVE_REQ is set but not slave channel fd",
|
||||
self.name
|
||||
);
|
||||
return Err(VhostError::VhostUserProtocol(VhostUserError::InvalidParam).into());
|
||||
}
|
||||
}
|
||||
} else {
|
||||
info!("{}: has no SLAVE_REQ protocol feature set", self.name);
|
||||
}
|
||||
|
||||
// 6. check number of queues supported
|
||||
if config.has_protocol_mq() {
|
||||
let queue_num = master.get_queue_num()?;
|
||||
info!("{}: get_queue_num({:X})", self.name, queue_num);
|
||||
if queue_num < config.queue_sizes.len() as u64 {
|
||||
return Err(VhostError::VhostUserProtocol(VhostUserError::FeatureMismatch).into());
|
||||
}
|
||||
}
|
||||
|
||||
// 7. trigger the backend state machine.
|
||||
for queue_index in 0..queue_num {
|
||||
master.set_vring_call(queue_index, config.intr_evts[queue_index])?;
|
||||
}
|
||||
info!("{}: set_vring_call()", self.name);
|
||||
|
||||
// 8. set mem_table
|
||||
let mut regions = Vec::new();
|
||||
for region in mem.iter() {
|
||||
let guest_phys_addr = region.start_addr();
|
||||
let file_offset = region
|
||||
.file_offset()
|
||||
.ok_or(VirtioError::InvalidGuestAddress(guest_phys_addr))?;
|
||||
let userspace_addr = region
|
||||
.get_host_address(MemoryRegionAddress(0))
|
||||
.map_err(|_| VirtioError::InvalidGuestAddress(guest_phys_addr))?;
|
||||
|
||||
regions.push(VhostUserMemoryRegionInfo {
|
||||
guest_phys_addr: guest_phys_addr.raw_value() as u64,
|
||||
memory_size: region.len() as u64,
|
||||
userspace_addr: userspace_addr as *const u8 as u64,
|
||||
mmap_offset: file_offset.start(),
|
||||
mmap_handle: file_offset.file().as_raw_fd(),
|
||||
});
|
||||
}
|
||||
master.set_mem_table(®ions)?;
|
||||
info!("{}: set_mem_table()", self.name);
|
||||
|
||||
// 9. setup vrings
|
||||
for queue_cfg in config.virtio_config.queues.iter() {
|
||||
master.set_vring_num(queue_cfg.index() as usize, queue_cfg.actual_size() as u16)?;
|
||||
info!(
|
||||
"{}: set_vring_num(idx: {}, size: {})",
|
||||
self.name,
|
||||
queue_cfg.index(),
|
||||
queue_cfg.actual_size(),
|
||||
);
|
||||
}
|
||||
// On reconnection, the slave may have processed some packets in virtque and queue
|
||||
// base is not zero any more. So don't set queue base on reconnection.
|
||||
// N.B. it's really TDD, we just found it works in this way. Any spec about this?
|
||||
for queue_index in 0..queue_num {
|
||||
let base = if old.is_some() {
|
||||
let conn = old.as_mut().unwrap();
|
||||
match conn.get_vring_base(queue_index) {
|
||||
Ok(val) => Some(val),
|
||||
Err(_) => None,
|
||||
}
|
||||
} else if !config.reconnect {
|
||||
Some(0)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
if let Some(val) = base {
|
||||
master.set_vring_base(queue_index, val as u16)?;
|
||||
info!(
|
||||
"{}: set_vring_base(idx: {}, base: {})",
|
||||
self.name, queue_index, val
|
||||
);
|
||||
}
|
||||
}
|
||||
for queue_cfg in config.virtio_config.queues.iter() {
|
||||
let queue = &queue_cfg.queue;
|
||||
let queue_index = queue_cfg.index() as usize;
|
||||
let desc_addr =
|
||||
config.get_host_address(vm_memory::GuestAddress(queue.desc_table()), mem)?;
|
||||
let used_addr =
|
||||
config.get_host_address(vm_memory::GuestAddress(queue.used_ring()), mem)?;
|
||||
let avail_addr =
|
||||
config.get_host_address(vm_memory::GuestAddress(queue.avail_ring()), mem)?;
|
||||
master.set_vring_addr(
|
||||
queue_index,
|
||||
&VringConfigData {
|
||||
queue_max_size: queue.max_size(),
|
||||
queue_size: queue_cfg.actual_size(),
|
||||
flags: VhostUserVringAddrFlags::empty().bits(),
|
||||
desc_table_addr: desc_addr as u64,
|
||||
used_ring_addr: used_addr as u64,
|
||||
avail_ring_addr: avail_addr as u64,
|
||||
log_addr: None,
|
||||
},
|
||||
)?;
|
||||
info!(
|
||||
"{}: set_vring_addr(idx: {}, addr: {:p})",
|
||||
self.name, queue_index, desc_addr
|
||||
);
|
||||
}
|
||||
for queue_index in 0..queue_num {
|
||||
master.set_vring_kick(
|
||||
queue_index,
|
||||
&config.virtio_config.queues[queue_index].eventfd,
|
||||
)?;
|
||||
info!(
|
||||
"{}: set_vring_kick(idx: {}, fd: {})",
|
||||
self.name,
|
||||
queue_index,
|
||||
config.virtio_config.queues[queue_index].eventfd.as_raw_fd()
|
||||
);
|
||||
}
|
||||
for queue_index in 0..queue_num {
|
||||
let intr_index = if config.intr_evts.len() == 1 {
|
||||
0
|
||||
} else {
|
||||
queue_index
|
||||
};
|
||||
master.set_vring_call(queue_index, config.intr_evts[intr_index])?;
|
||||
info!(
|
||||
"{}: set_vring_call(idx: {}, fd: {})",
|
||||
self.name,
|
||||
queue_index,
|
||||
config.intr_evts[intr_index].as_raw_fd()
|
||||
);
|
||||
}
|
||||
for queue_index in 0..queue_num {
|
||||
master.set_vring_enable(queue_index, true)?;
|
||||
info!(
|
||||
"{}: set_vring_enable(idx: {}, enable: {})",
|
||||
self.name, queue_index, true
|
||||
);
|
||||
if (queue_index + 1) == config.init_queues as usize {
|
||||
break;
|
||||
}
|
||||
}
|
||||
info!("{}: protocol negotiate completed successfully.", self.name);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn set_queues_attach(&mut self, curr_queues: u32) -> VirtioResult<()> {
|
||||
let master = match self.conn.as_mut() {
|
||||
Some(conn) => conn,
|
||||
None => return Err(VirtioError::InternalError),
|
||||
};
|
||||
|
||||
for index in 0..curr_queues {
|
||||
master.set_vring_enable(index as usize, true)?;
|
||||
info!(
|
||||
"{}: set_vring_enable(idx: {}, enable: {})",
|
||||
self.name, index, true
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Restore communication with the vhost-user slave on reconnect.
|
||||
pub fn reconnect<AS: GuestAddressSpace, Q: QueueT, R: GuestMemoryRegion>(
|
||||
&mut self,
|
||||
master: Master,
|
||||
config: &EndpointParam<AS, Q, R>,
|
||||
ops: &mut EventOps,
|
||||
) -> VirtioResult<()> {
|
||||
let mut old = self.conn.replace(master);
|
||||
if let Err(e) = self.negotiate(config, old.as_mut()) {
|
||||
error!("{}: failed to initialize connection: {}", self.name, e);
|
||||
self.conn = old;
|
||||
return Err(e);
|
||||
}
|
||||
if let Err(e) = self.register_epoll_event(ops) {
|
||||
error!("{}: failed to add fd to epoll: {}", self.name, e);
|
||||
self.conn = old;
|
||||
return Err(e);
|
||||
}
|
||||
self.old = old;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Teardown the communication channel to the vhost-user slave.
|
||||
pub fn disconnect(&mut self, ops: &mut EventOps) -> VirtioResult<()> {
|
||||
info!("vhost-user-net: disconnect communication channel.");
|
||||
match self.old.take() {
|
||||
Some(master) => {
|
||||
info!("close old connection");
|
||||
self.deregister_epoll_event(&master, ops)
|
||||
}
|
||||
None => match self.conn.take() {
|
||||
Some(master) => {
|
||||
info!("disconnect connection.");
|
||||
self.deregister_epoll_event(&master, ops)
|
||||
}
|
||||
None => {
|
||||
info!("get disconnect notification when it's already disconnected.");
|
||||
Ok(())
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Register the underlying socket to be monitored for socket disconnect events.
|
||||
pub fn register_epoll_event(&self, ops: &mut EventOps) -> VirtioResult<()> {
|
||||
match self.conn.as_ref() {
|
||||
Some(master) => {
|
||||
info!(
|
||||
"{}: monitor disconnect event for fd {}.",
|
||||
self.name,
|
||||
master.as_raw_fd()
|
||||
);
|
||||
ops.add(Events::with_data(
|
||||
master,
|
||||
self.slot,
|
||||
EventSet::HANG_UP | EventSet::EDGE_TRIGGERED,
|
||||
))
|
||||
.map_err(VirtioError::EpollMgr)
|
||||
}
|
||||
None => Err(VirtioError::InternalError),
|
||||
}
|
||||
}
|
||||
|
||||
/// Deregister the underlying socket from the epoll controller.
|
||||
pub fn deregister_epoll_event(&self, master: &Master, ops: &mut EventOps) -> VirtioResult<()> {
|
||||
info!(
|
||||
"{}: unregister epoll event for fd {}.",
|
||||
self.name,
|
||||
master.as_raw_fd()
|
||||
);
|
||||
ops.remove(Events::with_data(
|
||||
master,
|
||||
self.slot,
|
||||
EventSet::HANG_UP | EventSet::EDGE_TRIGGERED,
|
||||
))
|
||||
.map_err(VirtioError::EpollMgr)
|
||||
}
|
||||
|
||||
pub fn set_master(&mut self, master: Master) {
|
||||
self.conn = Some(master);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_endpoint_flags() {
|
||||
assert_eq!(EndpointProtocolFlags::ProtocolMq as u16, 0x1);
|
||||
}
|
||||
|
||||
#[should_panic]
|
||||
#[test]
|
||||
fn test_connect_try_accept() {
|
||||
let listener = Listener::new(
|
||||
"test_listener".to_string(),
|
||||
"/tmp/test_vhost_listener".to_string(),
|
||||
true,
|
||||
1,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
listener.listener.set_nonblocking(true).unwrap();
|
||||
|
||||
assert!(listener.try_accept().is_err());
|
||||
}
|
||||
}
|
@@ -0,0 +1,11 @@
|
||||
// Copyright (C) 2019-2023 Alibaba Cloud. All rights reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
//! Vhost-based virtio device backend implementations.
|
||||
|
||||
use super::VhostError;
|
||||
|
||||
pub mod connection;
|
||||
|
||||
#[cfg(test)]
|
||||
mod test_utils;
|
@@ -0,0 +1,750 @@
|
||||
// Copyright (C) 2021 Alibaba Cloud Computing. All rights reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0 or BSD-3-Clause
|
||||
|
||||
use std::fmt::Debug;
|
||||
use std::marker::PhantomData;
|
||||
use std::os::unix::io::{AsRawFd, RawFd};
|
||||
use std::os::unix::net::UnixStream;
|
||||
use std::{mem, slice};
|
||||
|
||||
use vmm_sys_util::tempfile::TempFile;
|
||||
use libc::{c_void, iovec};
|
||||
use vhost_rs::vhost_user::message::{
|
||||
VhostUserHeaderFlag, VhostUserInflight, VhostUserMemory, VhostUserMemoryRegion,
|
||||
VhostUserMsgValidator, VhostUserProtocolFeatures, VhostUserU64, VhostUserVirtioFeatures,
|
||||
VhostUserVringAddr, VhostUserVringState, MAX_MSG_SIZE,
|
||||
};
|
||||
use vhost_rs::vhost_user::Error;
|
||||
use vmm_sys_util::sock_ctrl_msg::ScmSocket;
|
||||
|
||||
pub const MAX_ATTACHED_FD_ENTRIES: usize = 32;
|
||||
|
||||
pub(crate) trait Req:
|
||||
Clone + Copy + Debug + PartialEq + Eq + PartialOrd + Ord + Into<u32>
|
||||
{
|
||||
fn is_valid(&self) -> bool;
|
||||
}
|
||||
|
||||
pub type Result<T> = std::result::Result<T, Error>;
|
||||
|
||||
/// Type of requests sending from masters to slaves.
|
||||
#[repr(u32)]
|
||||
#[allow(unused, non_camel_case_types)]
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
|
||||
pub enum MasterReq {
|
||||
/// Null operation.
|
||||
NOOP = 0,
|
||||
/// Get from the underlying vhost implementation the features bit mask.
|
||||
GET_FEATURES = 1,
|
||||
/// Enable features in the underlying vhost implementation using a bit mask.
|
||||
SET_FEATURES = 2,
|
||||
/// Set the current Master as an owner of the session.
|
||||
SET_OWNER = 3,
|
||||
/// No longer used.
|
||||
RESET_OWNER = 4,
|
||||
/// Set the memory map regions on the slave so it can translate the vring addresses.
|
||||
SET_MEM_TABLE = 5,
|
||||
/// Set logging shared memory space.
|
||||
SET_LOG_BASE = 6,
|
||||
/// Set the logging file descriptor, which is passed as ancillary data.
|
||||
SET_LOG_FD = 7,
|
||||
/// Set the size of the queue.
|
||||
SET_VRING_NUM = 8,
|
||||
/// Set the addresses of the different aspects of the vring.
|
||||
SET_VRING_ADDR = 9,
|
||||
/// Set the base offset in the available vring.
|
||||
SET_VRING_BASE = 10,
|
||||
/// Get the available vring base offset.
|
||||
GET_VRING_BASE = 11,
|
||||
/// Set the event file descriptor for adding buffers to the vring.
|
||||
SET_VRING_KICK = 12,
|
||||
/// Set the event file descriptor to signal when buffers are used.
|
||||
SET_VRING_CALL = 13,
|
||||
/// Set the event file descriptor to signal when error occurs.
|
||||
SET_VRING_ERR = 14,
|
||||
/// Get the protocol feature bit mask from the underlying vhost implementation.
|
||||
GET_PROTOCOL_FEATURES = 15,
|
||||
/// Enable protocol features in the underlying vhost implementation.
|
||||
SET_PROTOCOL_FEATURES = 16,
|
||||
/// Query how many queues the backend supports.
|
||||
GET_QUEUE_NUM = 17,
|
||||
/// Signal slave to enable or disable corresponding vring.
|
||||
SET_VRING_ENABLE = 18,
|
||||
/// Ask vhost user backend to broadcast a fake RARP to notify the migration is terminated
|
||||
/// for guest that does not support GUEST_ANNOUNCE.
|
||||
SEND_RARP = 19,
|
||||
/// Set host MTU value exposed to the guest.
|
||||
NET_SET_MTU = 20,
|
||||
/// Set the socket file descriptor for slave initiated requests.
|
||||
SET_SLAVE_REQ_FD = 21,
|
||||
/// Send IOTLB messages with struct vhost_iotlb_msg as payload.
|
||||
IOTLB_MSG = 22,
|
||||
/// Set the endianness of a VQ for legacy devices.
|
||||
SET_VRING_ENDIAN = 23,
|
||||
/// Fetch the contents of the virtio device configuration space.
|
||||
GET_CONFIG = 24,
|
||||
/// Change the contents of the virtio device configuration space.
|
||||
SET_CONFIG = 25,
|
||||
/// Create a session for crypto operation.
|
||||
CREATE_CRYPTO_SESSION = 26,
|
||||
/// Close a session for crypto operation.
|
||||
CLOSE_CRYPTO_SESSION = 27,
|
||||
/// Advise slave that a migration with postcopy enabled is underway.
|
||||
POSTCOPY_ADVISE = 28,
|
||||
/// Advise slave that a transition to postcopy mode has happened.
|
||||
POSTCOPY_LISTEN = 29,
|
||||
/// Advise that postcopy migration has now completed.
|
||||
POSTCOPY_END = 30,
|
||||
/// Get a shared buffer from slave.
|
||||
GET_INFLIGHT_FD = 31,
|
||||
/// Send the shared inflight buffer back to slave
|
||||
SET_INFLIGHT_FD = 32,
|
||||
/// Upper bound of valid commands.
|
||||
MAX_CMD = 33,
|
||||
}
|
||||
|
||||
impl Into<u32> for MasterReq {
|
||||
fn into(self) -> u32 {
|
||||
self as u32
|
||||
}
|
||||
}
|
||||
|
||||
impl Req for MasterReq {
|
||||
fn is_valid(&self) -> bool {
|
||||
(*self > MasterReq::NOOP) && (*self < MasterReq::MAX_CMD)
|
||||
}
|
||||
}
|
||||
|
||||
// Given a slice of sizes and the `skip_size`, return the offset of `skip_size` in the slice.
|
||||
// For example:
|
||||
// let iov_lens = vec![4, 4, 5];
|
||||
// let size = 6;
|
||||
// assert_eq!(get_sub_iovs_offset(&iov_len, size), (1, 2));
|
||||
fn get_sub_iovs_offset(iov_lens: &[usize], skip_size: usize) -> (usize, usize) {
|
||||
let mut size = skip_size;
|
||||
let mut nr_skip = 0;
|
||||
|
||||
for len in iov_lens {
|
||||
if size >= *len {
|
||||
size -= *len;
|
||||
nr_skip += 1;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
(nr_skip, size)
|
||||
}
|
||||
|
||||
/// Common message header for vhost-user requests and replies.
|
||||
/// A vhost-user message consists of 3 header fields and an optional payload. All numbers are in the
|
||||
/// machine native byte order.
|
||||
#[repr(packed)]
|
||||
#[derive(Copy)]
|
||||
pub(crate) struct VhostUserMsgHeader<R: Req> {
|
||||
request: u32,
|
||||
flags: u32,
|
||||
size: u32,
|
||||
_r: PhantomData<R>,
|
||||
}
|
||||
|
||||
impl<R: Req> Debug for VhostUserMsgHeader<R> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("Point")
|
||||
.field("request", &{ self.request })
|
||||
.field("flags", &{ self.flags })
|
||||
.field("size", &{ self.size })
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Req> VhostUserMsgValidator for VhostUserMsgHeader<T> {
|
||||
#[allow(clippy::if_same_then_else)]
|
||||
fn is_valid(&self) -> bool {
|
||||
if !self.get_code().is_valid() {
|
||||
return false;
|
||||
} else if self.size as usize > MAX_MSG_SIZE {
|
||||
return false;
|
||||
} else if self.get_version() != 0x1 {
|
||||
return false;
|
||||
} else if (self.flags & VhostUserHeaderFlag::RESERVED_BITS.bits()) != 0 {
|
||||
return false;
|
||||
}
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: Req> Clone for VhostUserMsgHeader<R> {
|
||||
fn clone(&self) -> VhostUserMsgHeader<R> {
|
||||
*self
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: Req> VhostUserMsgHeader<R> {
|
||||
/// Create a new instance of `VhostUserMsgHeader`.
|
||||
pub fn new(request: R, flags: u32, size: u32) -> Self {
|
||||
// Default to protocol version 1
|
||||
let fl = (flags & VhostUserHeaderFlag::ALL_FLAGS.bits()) | 0x1;
|
||||
VhostUserMsgHeader {
|
||||
request: request.into(),
|
||||
flags: fl,
|
||||
size,
|
||||
_r: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get message type.
|
||||
pub fn get_code(&self) -> R {
|
||||
// It's safe because R is marked as repr(u32).
|
||||
unsafe { std::mem::transmute_copy::<u32, R>(&{ self.request }) }
|
||||
}
|
||||
|
||||
/// Get message version number.
|
||||
pub fn get_version(&self) -> u32 {
|
||||
self.flags & 0x3
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: Req> Default for VhostUserMsgHeader<R> {
|
||||
fn default() -> Self {
|
||||
VhostUserMsgHeader {
|
||||
request: 0,
|
||||
flags: 0x1,
|
||||
size: 0,
|
||||
_r: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Unix domain socket endpoint for vhost-user connection.
|
||||
pub(crate) struct Endpoint<R: Req> {
|
||||
sock: UnixStream,
|
||||
_r: PhantomData<R>,
|
||||
}
|
||||
|
||||
impl<R: Req> Endpoint<R> {
|
||||
/// Create a new stream by connecting to server at `str`.
|
||||
///
|
||||
/// # Return:
|
||||
/// * - the new Endpoint object on success.
|
||||
/// * - SocketConnect: failed to connect to peer.
|
||||
pub fn connect(path: &str) -> Result<Self> {
|
||||
let sock = UnixStream::connect(path).map_err(Error::SocketConnect)?;
|
||||
Ok(Self::from_stream(sock))
|
||||
}
|
||||
|
||||
/// Create an endpoint from a stream object.
|
||||
pub fn from_stream(sock: UnixStream) -> Self {
|
||||
Endpoint {
|
||||
sock,
|
||||
_r: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
/// Sends bytes from scatter-gather vectors over the socket with optional attached file
|
||||
/// descriptors.
|
||||
///
|
||||
/// # Return:
|
||||
/// * - number of bytes sent on success
|
||||
/// * - SocketRetry: temporary error caused by signals or short of resources.
|
||||
/// * - SocketBroken: the underline socket is broken.
|
||||
/// * - SocketError: other socket related errors.
|
||||
pub fn send_iovec(&mut self, iovs: &[&[u8]], fds: Option<&[RawFd]>) -> Result<usize> {
|
||||
let rfds = match fds {
|
||||
Some(rfds) => rfds,
|
||||
_ => &[],
|
||||
};
|
||||
self.sock.send_with_fds(iovs, rfds).map_err(Into::into)
|
||||
}
|
||||
|
||||
/// Sends all bytes from scatter-gather vectors over the socket with optional attached file
|
||||
/// descriptors. Will loop until all data has been transfered.
|
||||
///
|
||||
/// # Return:
|
||||
/// * - number of bytes sent on success
|
||||
/// * - SocketBroken: the underline socket is broken.
|
||||
/// * - SocketError: other socket related errors.
|
||||
pub fn send_iovec_all(&mut self, iovs: &[&[u8]], fds: Option<&[RawFd]>) -> Result<usize> {
|
||||
let mut data_sent = 0;
|
||||
let mut data_total = 0;
|
||||
let iov_lens: Vec<usize> = iovs.iter().map(|iov| iov.len()).collect();
|
||||
for len in &iov_lens {
|
||||
data_total += len;
|
||||
}
|
||||
|
||||
while (data_total - data_sent) > 0 {
|
||||
let (nr_skip, offset) = get_sub_iovs_offset(&iov_lens, data_sent);
|
||||
let iov = &iovs[nr_skip][offset..];
|
||||
|
||||
let data = &[&[iov], &iovs[(nr_skip + 1)..]].concat();
|
||||
let sfds = if data_sent == 0 { fds } else { None };
|
||||
|
||||
let sent = self.send_iovec(data, sfds);
|
||||
match sent {
|
||||
Ok(0) => return Ok(data_sent),
|
||||
Ok(n) => data_sent += n,
|
||||
Err(e) => match e {
|
||||
Error::SocketRetry(_) => {}
|
||||
_ => return Err(e),
|
||||
},
|
||||
}
|
||||
}
|
||||
Ok(data_sent)
|
||||
}
|
||||
|
||||
/// Sends a header-only message with optional attached file descriptors.
|
||||
///
|
||||
/// # Return:
|
||||
/// * - number of bytes sent on success
|
||||
/// * - SocketRetry: temporary error caused by signals or short of resources.
|
||||
/// * - SocketBroken: the underline socket is broken.
|
||||
/// * - SocketError: other socket related errors.
|
||||
/// * - PartialMessage: received a partial message.
|
||||
pub fn send_header(
|
||||
&mut self,
|
||||
hdr: &VhostUserMsgHeader<R>,
|
||||
fds: Option<&[RawFd]>,
|
||||
) -> Result<()> {
|
||||
// Safe because there can't be other mutable referance to hdr.
|
||||
let iovs = unsafe {
|
||||
[slice::from_raw_parts(
|
||||
hdr as *const VhostUserMsgHeader<R> as *const u8,
|
||||
mem::size_of::<VhostUserMsgHeader<R>>(),
|
||||
)]
|
||||
};
|
||||
let bytes = self.send_iovec_all(&iovs[..], fds)?;
|
||||
if bytes != mem::size_of::<VhostUserMsgHeader<R>>() {
|
||||
return Err(Error::PartialMessage);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Send a message with header and body. Optional file descriptors may be attached to
|
||||
/// the message.
|
||||
///
|
||||
/// # Return:
|
||||
/// * - number of bytes sent on success
|
||||
/// * - SocketRetry: temporary error caused by signals or short of resources.
|
||||
/// * - SocketBroken: the underline socket is broken.
|
||||
/// * - SocketError: other socket related errors.
|
||||
/// * - PartialMessage: received a partial message.
|
||||
pub fn send_message<T: Sized>(
|
||||
&mut self,
|
||||
hdr: &VhostUserMsgHeader<R>,
|
||||
body: &T,
|
||||
fds: Option<&[RawFd]>,
|
||||
) -> Result<()> {
|
||||
// Safe because there can't be other mutable referance to hdr and body.
|
||||
let iovs = unsafe {
|
||||
[
|
||||
slice::from_raw_parts(
|
||||
hdr as *const VhostUserMsgHeader<R> as *const u8,
|
||||
mem::size_of::<VhostUserMsgHeader<R>>(),
|
||||
),
|
||||
slice::from_raw_parts(body as *const T as *const u8, mem::size_of::<T>()),
|
||||
]
|
||||
};
|
||||
|
||||
let bytes = self.send_iovec_all(&iovs[..], fds)?;
|
||||
if bytes != mem::size_of::<VhostUserMsgHeader<R>>() + mem::size_of::<T>() {
|
||||
return Err(Error::PartialMessage);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Reads bytes from the socket into the given scatter/gather vectors with optional attached
|
||||
/// file descriptors.
|
||||
///
|
||||
/// The underlying communication channel is a Unix domain socket in STREAM mode. It's a little
|
||||
/// tricky to pass file descriptors through such a communication channel. Let's assume that a
|
||||
/// sender sending a message with some file descriptors attached. To successfully receive those
|
||||
/// attached file descriptors, the receiver must obey following rules:
|
||||
/// 1) file descriptors are attached to a message.
|
||||
/// 2) message(packet) boundaries must be respected on the receive side.
|
||||
/// In other words, recvmsg() operations must not cross the packet boundary, otherwise the
|
||||
/// attached file descriptors will get lost.
|
||||
///
|
||||
/// # Return:
|
||||
/// * - (number of bytes received, [received fds]) on success
|
||||
/// * - SocketRetry: temporary error caused by signals or short of resources.
|
||||
/// * - SocketBroken: the underline socket is broken.
|
||||
/// * - SocketError: other socket related errors.
|
||||
pub fn recv_into_iovec(&mut self, iovs: &mut [iovec]) -> Result<(usize, Option<Vec<RawFd>>)> {
|
||||
let mut fd_array = vec![0; MAX_ATTACHED_FD_ENTRIES];
|
||||
let (bytes, fds) = unsafe { self.sock.recv_with_fds(iovs, &mut fd_array)? };
|
||||
let rfds = match fds {
|
||||
0 => None,
|
||||
n => {
|
||||
let mut fds = Vec::with_capacity(n);
|
||||
fds.extend_from_slice(&fd_array[0..n]);
|
||||
Some(fds)
|
||||
}
|
||||
};
|
||||
|
||||
Ok((bytes, rfds))
|
||||
}
|
||||
|
||||
/// Reads all bytes from the socket into the given scatter/gather vectors with optional
|
||||
/// attached file descriptors. Will loop until all data has been transfered.
|
||||
///
|
||||
/// The underlying communication channel is a Unix domain socket in STREAM mode. It's a little
|
||||
/// tricky to pass file descriptors through such a communication channel. Let's assume that a
|
||||
/// sender sending a message with some file descriptors attached. To successfully receive those
|
||||
/// attached file descriptors, the receiver must obey following rules:
|
||||
/// 1) file descriptors are attached to a message.
|
||||
/// 2) message(packet) boundaries must be respected on the receive side.
|
||||
/// In other words, recvmsg() operations must not cross the packet boundary, otherwise the
|
||||
/// attached file descriptors will get lost.
|
||||
///
|
||||
/// # Return:
|
||||
/// * - (number of bytes received, [received fds]) on success
|
||||
/// * - SocketBroken: the underline socket is broken.
|
||||
/// * - SocketError: other socket related errors.
|
||||
pub fn recv_into_iovec_all(
|
||||
&mut self,
|
||||
iovs: &mut [iovec],
|
||||
) -> Result<(usize, Option<Vec<RawFd>>)> {
|
||||
let mut data_read = 0;
|
||||
let mut data_total = 0;
|
||||
let mut rfds = None;
|
||||
let iov_lens: Vec<usize> = iovs.iter().map(|iov| iov.iov_len).collect();
|
||||
for len in &iov_lens {
|
||||
data_total += len;
|
||||
}
|
||||
|
||||
while (data_total - data_read) > 0 {
|
||||
let (nr_skip, offset) = get_sub_iovs_offset(&iov_lens, data_read);
|
||||
let iov = &mut iovs[nr_skip];
|
||||
|
||||
let mut data = [
|
||||
&[iovec {
|
||||
iov_base: (iov.iov_base as usize + offset) as *mut c_void,
|
||||
iov_len: iov.iov_len - offset,
|
||||
}],
|
||||
&iovs[(nr_skip + 1)..],
|
||||
]
|
||||
.concat();
|
||||
|
||||
let res = self.recv_into_iovec(&mut data);
|
||||
match res {
|
||||
Ok((0, _)) => return Ok((data_read, rfds)),
|
||||
Ok((n, fds)) => {
|
||||
if data_read == 0 {
|
||||
rfds = fds;
|
||||
}
|
||||
data_read += n;
|
||||
}
|
||||
Err(e) => match e {
|
||||
Error::SocketRetry(_) => {}
|
||||
_ => return Err(e),
|
||||
},
|
||||
}
|
||||
}
|
||||
Ok((data_read, rfds))
|
||||
}
|
||||
|
||||
/// Receive a header-only message with optional attached file descriptors.
|
||||
/// Note, only the first MAX_ATTACHED_FD_ENTRIES file descriptors will be
|
||||
/// accepted and all other file descriptor will be discard silently.
|
||||
///
|
||||
/// # Return:
|
||||
/// * - (message header, [received fds]) on success.
|
||||
/// * - SocketRetry: temporary error caused by signals or short of resources.
|
||||
/// * - SocketBroken: the underline socket is broken.
|
||||
/// * - SocketError: other socket related errors.
|
||||
/// * - PartialMessage: received a partial message.
|
||||
/// * - InvalidMessage: received a invalid message.
|
||||
pub fn recv_header(&mut self) -> Result<(VhostUserMsgHeader<R>, Option<Vec<RawFd>>)> {
|
||||
let mut hdr = VhostUserMsgHeader::default();
|
||||
let mut iovs = [iovec {
|
||||
iov_base: (&mut hdr as *mut VhostUserMsgHeader<R>) as *mut c_void,
|
||||
iov_len: mem::size_of::<VhostUserMsgHeader<R>>(),
|
||||
}];
|
||||
let (bytes, rfds) = self.recv_into_iovec_all(&mut iovs[..])?;
|
||||
|
||||
if bytes != mem::size_of::<VhostUserMsgHeader<R>>() {
|
||||
return Err(Error::PartialMessage);
|
||||
} else if !hdr.is_valid() {
|
||||
return Err(Error::InvalidMessage);
|
||||
}
|
||||
|
||||
Ok((hdr, rfds))
|
||||
}
|
||||
|
||||
/// Receive a message with optional attached file descriptors.
|
||||
/// Note, only the first MAX_ATTACHED_FD_ENTRIES file descriptors will be
|
||||
/// accepted and all other file descriptor will be discard silently.
|
||||
///
|
||||
/// # Return:
|
||||
/// * - (message header, message body, [received fds]) on success.
|
||||
/// * - SocketRetry: temporary error caused by signals or short of resources.
|
||||
/// * - SocketBroken: the underline socket is broken.
|
||||
/// * - SocketError: other socket related errors.
|
||||
/// * - PartialMessage: received a partial message.
|
||||
/// * - InvalidMessage: received a invalid message.
|
||||
pub fn recv_body<T: Sized + Default + VhostUserMsgValidator>(
|
||||
&mut self,
|
||||
) -> Result<(VhostUserMsgHeader<R>, T, Option<Vec<RawFd>>)> {
|
||||
let mut hdr = VhostUserMsgHeader::default();
|
||||
let mut body: T = Default::default();
|
||||
let mut iovs = [
|
||||
iovec {
|
||||
iov_base: (&mut hdr as *mut VhostUserMsgHeader<R>) as *mut c_void,
|
||||
iov_len: mem::size_of::<VhostUserMsgHeader<R>>(),
|
||||
},
|
||||
iovec {
|
||||
iov_base: (&mut body as *mut T) as *mut c_void,
|
||||
iov_len: mem::size_of::<T>(),
|
||||
},
|
||||
];
|
||||
let (bytes, rfds) = self.recv_into_iovec_all(&mut iovs[..])?;
|
||||
|
||||
let total = mem::size_of::<VhostUserMsgHeader<R>>() + mem::size_of::<T>();
|
||||
if bytes != total {
|
||||
return Err(Error::PartialMessage);
|
||||
} else if !hdr.is_valid() || !body.is_valid() {
|
||||
return Err(Error::InvalidMessage);
|
||||
}
|
||||
|
||||
Ok((hdr, body, rfds))
|
||||
}
|
||||
|
||||
/// Send a message with header, body and payload. Optional file descriptors
|
||||
/// may also be attached to the message.
|
||||
///
|
||||
/// # Return:
|
||||
/// * - number of bytes sent on success
|
||||
/// * - SocketRetry: temporary error caused by signals or short of resources.
|
||||
/// * - SocketBroken: the underline socket is broken.
|
||||
/// * - SocketError: other socket related errors.
|
||||
/// * - OversizedMsg: message size is too big.
|
||||
/// * - PartialMessage: received a partial message.
|
||||
/// * - IncorrectFds: wrong number of attached fds.
|
||||
pub fn send_message_with_payload<T: Sized, P: Sized>(
|
||||
&mut self,
|
||||
hdr: &VhostUserMsgHeader<R>,
|
||||
body: &T,
|
||||
payload: &[P],
|
||||
fds: Option<&[RawFd]>,
|
||||
) -> Result<()> {
|
||||
let len = payload.len() * mem::size_of::<P>();
|
||||
if len > MAX_MSG_SIZE - mem::size_of::<T>() {
|
||||
return Err(Error::OversizedMsg);
|
||||
}
|
||||
if let Some(fd_arr) = fds {
|
||||
if fd_arr.len() > MAX_ATTACHED_FD_ENTRIES {
|
||||
return Err(Error::IncorrectFds);
|
||||
}
|
||||
}
|
||||
|
||||
// Safe because there can't be other mutable reference to hdr, body and payload.
|
||||
let iovs = unsafe {
|
||||
[
|
||||
slice::from_raw_parts(
|
||||
hdr as *const VhostUserMsgHeader<R> as *const u8,
|
||||
mem::size_of::<VhostUserMsgHeader<R>>(),
|
||||
),
|
||||
slice::from_raw_parts(body as *const T as *const u8, mem::size_of::<T>()),
|
||||
slice::from_raw_parts(payload.as_ptr() as *const u8, len),
|
||||
]
|
||||
};
|
||||
let total = mem::size_of::<VhostUserMsgHeader<R>>() + mem::size_of::<T>() + len;
|
||||
let len = self.send_iovec_all(&iovs, fds)?;
|
||||
if len != total {
|
||||
return Err(Error::PartialMessage);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Receive a message with optional payload and attached file descriptors.
|
||||
/// Note, only the first MAX_ATTACHED_FD_ENTRIES file descriptors will be
|
||||
/// accepted and all other file descriptor will be discard silently.
|
||||
///
|
||||
/// # Return:
|
||||
/// * - (message header, message body, size of payload, [received fds]) on success.
|
||||
/// * - SocketRetry: temporary error caused by signals or short of resources.
|
||||
/// * - SocketBroken: the underline socket is broken.
|
||||
/// * - SocketError: other socket related errors.
|
||||
/// * - PartialMessage: received a partial message.
|
||||
/// * - InvalidMessage: received a invalid message.
|
||||
#[cfg_attr(feature = "cargo-clippy", allow(clippy::type_complexity))]
|
||||
pub fn recv_payload_into_buf<T: Sized + Default + VhostUserMsgValidator>(
|
||||
&mut self,
|
||||
buf: &mut [u8],
|
||||
) -> Result<(VhostUserMsgHeader<R>, T, usize, Option<Vec<RawFd>>)> {
|
||||
let mut hdr = VhostUserMsgHeader::default();
|
||||
let mut body: T = Default::default();
|
||||
let mut iovs = [
|
||||
iovec {
|
||||
iov_base: (&mut hdr as *mut VhostUserMsgHeader<R>) as *mut c_void,
|
||||
iov_len: mem::size_of::<VhostUserMsgHeader<R>>(),
|
||||
},
|
||||
iovec {
|
||||
iov_base: (&mut body as *mut T) as *mut c_void,
|
||||
iov_len: mem::size_of::<T>(),
|
||||
},
|
||||
iovec {
|
||||
iov_base: buf.as_mut_ptr() as *mut c_void,
|
||||
iov_len: buf.len(),
|
||||
},
|
||||
];
|
||||
let (bytes, rfds) = self.recv_into_iovec_all(&mut iovs[..])?;
|
||||
|
||||
let total = mem::size_of::<VhostUserMsgHeader<R>>() + mem::size_of::<T>();
|
||||
if bytes < total {
|
||||
return Err(Error::PartialMessage);
|
||||
} else if !hdr.is_valid() || !body.is_valid() {
|
||||
return Err(Error::InvalidMessage);
|
||||
}
|
||||
|
||||
Ok((hdr, body, bytes - total, rfds))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Req> AsRawFd for Endpoint<T> {
|
||||
fn as_raw_fd(&self) -> RawFd {
|
||||
self.sock.as_raw_fd()
|
||||
}
|
||||
}
|
||||
|
||||
// Negotiate process from slave.
|
||||
pub(crate) fn negotiate_slave(
|
||||
slave: &mut Endpoint<MasterReq>,
|
||||
pfeatures: VhostUserProtocolFeatures,
|
||||
use_ali_feature: bool,
|
||||
has_protocol_mq: bool,
|
||||
queue_num: u64,
|
||||
) {
|
||||
// set owner
|
||||
let (hdr, rfds) = slave.recv_header().unwrap();
|
||||
assert_eq!(hdr.get_code(), MasterReq::SET_OWNER);
|
||||
assert!(rfds.is_none());
|
||||
|
||||
// get features
|
||||
let vfeatures = 0x15 | VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits();
|
||||
let hdr = VhostUserMsgHeader::new(MasterReq::GET_FEATURES, 0x4, 8);
|
||||
let msg = VhostUserU64::new(vfeatures);
|
||||
slave.send_message(&hdr, &msg, None).unwrap();
|
||||
let (hdr, _rfds) = slave.recv_header().unwrap();
|
||||
assert_eq!(hdr.get_code(), MasterReq::GET_FEATURES);
|
||||
|
||||
// set features
|
||||
let (hdr, _msg, rfds) = slave.recv_body::<VhostUserU64>().unwrap();
|
||||
assert_eq!(hdr.get_code(), MasterReq::SET_FEATURES);
|
||||
assert!(rfds.is_none());
|
||||
|
||||
// get vhost-user protocol features
|
||||
let code = MasterReq::GET_PROTOCOL_FEATURES;
|
||||
let (hdr, rfds) = slave.recv_header().unwrap();
|
||||
assert_eq!(hdr.get_code(), code);
|
||||
assert!(rfds.is_none());
|
||||
let hdr = VhostUserMsgHeader::new(code, 0x4, 8);
|
||||
let msg = VhostUserU64::new(pfeatures.bits());
|
||||
slave.send_message(&hdr, &msg, None).unwrap();
|
||||
|
||||
// set vhost-user protocol features
|
||||
let (hdr, _msg, rfds) = slave.recv_body::<VhostUserU64>().unwrap();
|
||||
assert_eq!(hdr.get_code(), MasterReq::SET_PROTOCOL_FEATURES);
|
||||
assert!(rfds.is_none());
|
||||
|
||||
// set number of queues
|
||||
if has_protocol_mq {
|
||||
let (hdr, rfds) = slave.recv_header().unwrap();
|
||||
assert_eq!(hdr.get_code(), MasterReq::GET_QUEUE_NUM);
|
||||
assert!(rfds.is_none());
|
||||
let hdr = VhostUserMsgHeader::new(MasterReq::GET_QUEUE_NUM, 0x4, 8);
|
||||
let msg = VhostUserU64::new(queue_num);
|
||||
slave.send_message(&hdr, &msg, None).unwrap();
|
||||
}
|
||||
|
||||
// set vring call
|
||||
for _i in 0..queue_num {
|
||||
let (hdr, _msg, rfds) = slave.recv_body::<VhostUserU64>().unwrap();
|
||||
assert_eq!(hdr.get_code(), MasterReq::SET_VRING_CALL);
|
||||
assert!(rfds.is_some());
|
||||
}
|
||||
|
||||
// set mem table
|
||||
let mut region_buf: Vec<u8> = vec![0u8; mem::size_of::<VhostUserMemoryRegion>()];
|
||||
let (hdr, _msg, _payload, rfds) = slave
|
||||
.recv_payload_into_buf::<VhostUserMemory>(&mut region_buf)
|
||||
.unwrap();
|
||||
assert_eq!(hdr.get_code(), MasterReq::SET_MEM_TABLE);
|
||||
assert!(rfds.is_some());
|
||||
|
||||
if pfeatures.contains(VhostUserProtocolFeatures::INFLIGHT_SHMFD) {
|
||||
// get inflight fd
|
||||
let (hdr, _msg, rfds) = slave.recv_body::<VhostUserInflight>().unwrap();
|
||||
assert_eq!(hdr.get_code(), MasterReq::GET_INFLIGHT_FD);
|
||||
assert!(rfds.is_none());
|
||||
let msg = VhostUserInflight {
|
||||
mmap_size: 0x100,
|
||||
mmap_offset: 0x0,
|
||||
..Default::default()
|
||||
};
|
||||
let inflight_file = TempFile::new().unwrap().into_file();
|
||||
inflight_file.set_len(0x100).unwrap();
|
||||
let fds = [inflight_file.as_raw_fd()];
|
||||
let hdr = VhostUserMsgHeader::new(
|
||||
MasterReq::GET_INFLIGHT_FD,
|
||||
VhostUserHeaderFlag::REPLY.bits(),
|
||||
std::mem::size_of::<VhostUserInflight>() as u32,
|
||||
);
|
||||
slave.send_message(&hdr, &msg, Some(&fds)).unwrap();
|
||||
|
||||
// set inflight fd
|
||||
let (hdr, _msg, rfds) = slave.recv_body::<VhostUserInflight>().unwrap();
|
||||
assert_eq!(hdr.get_code(), MasterReq::SET_INFLIGHT_FD);
|
||||
assert!(rfds.is_some());
|
||||
let hdr = VhostUserMsgHeader::new(
|
||||
MasterReq::GET_INFLIGHT_FD,
|
||||
VhostUserHeaderFlag::REPLY.bits(),
|
||||
std::mem::size_of::<VhostUserInflight>() as u32,
|
||||
);
|
||||
slave.send_header(&hdr, None).unwrap();
|
||||
}
|
||||
|
||||
// set vring num
|
||||
for _i in 0..queue_num {
|
||||
let (hdr, _msg, rfds) = slave.recv_body::<VhostUserVringState>().unwrap();
|
||||
slave.send_header(&hdr, None).unwrap();
|
||||
assert!(rfds.is_none());
|
||||
}
|
||||
|
||||
// set vring base
|
||||
for _i in 0..queue_num {
|
||||
let (hdr, _msg, rfds) = slave.recv_body::<VhostUserVringState>().unwrap();
|
||||
assert_eq!(hdr.get_code(), MasterReq::SET_VRING_BASE);
|
||||
assert!(rfds.is_none());
|
||||
slave.send_header(&hdr, None).unwrap();
|
||||
}
|
||||
|
||||
// set vring addr
|
||||
for _i in 0..queue_num {
|
||||
let (hdr, _msg, rfds) = slave.recv_body::<VhostUserVringAddr>().unwrap();
|
||||
assert_eq!(hdr.get_code(), MasterReq::SET_VRING_ADDR);
|
||||
assert!(rfds.is_none());
|
||||
slave.send_header(&hdr, None).unwrap();
|
||||
}
|
||||
|
||||
// set vring kick
|
||||
for _i in 0..queue_num {
|
||||
let (hdr, _msg, rfds) = slave.recv_body::<VhostUserU64>().unwrap();
|
||||
assert_eq!(hdr.get_code(), MasterReq::SET_VRING_KICK);
|
||||
assert!(rfds.is_some());
|
||||
}
|
||||
|
||||
// set vring call
|
||||
for _i in 0..queue_num {
|
||||
let (hdr, _msg, rfds) = slave.recv_body::<VhostUserU64>().unwrap();
|
||||
assert_eq!(hdr.get_code(), MasterReq::SET_VRING_CALL);
|
||||
assert!(rfds.is_some());
|
||||
}
|
||||
|
||||
// set vring enable
|
||||
for _i in 0..queue_num {
|
||||
let (hdr, _msg, rfds) = slave.recv_body::<VhostUserVringState>().unwrap();
|
||||
assert_eq!(hdr.get_code(), MasterReq::SET_VRING_ENABLE);
|
||||
assert!(rfds.is_none());
|
||||
slave.send_header(&hdr, None).unwrap();
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user