diff --git a/src/runtime-rs/crates/hypervisor/src/utils.rs b/src/runtime-rs/crates/hypervisor/src/utils.rs index 6a0e41bc0..d1fe47a5f 100644 --- a/src/runtime-rs/crates/hypervisor/src/utils.rs +++ b/src/runtime-rs/crates/hypervisor/src/utils.rs @@ -16,6 +16,8 @@ use nix::{ fcntl, sched::{setns, CloneFlags}, }; +use serde::{Deserialize, Serialize}; +use serde_json; use crate::device::Tap; @@ -144,9 +146,56 @@ fn create_fds(device: &str, num_fds: usize) -> Result> { Ok(fds) } +// QGS_SOCKET_PATH: the Unix Domain Socket Path served by Intel TDX Quote Generation Service +const QGS_SOCKET_PATH: &str = "/var/run/tdx-qgs/qgs.socket"; + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct SocketAddress { + #[serde(rename = "type")] + pub typ: String, + + #[serde(rename = "cid", skip_serializing_if = "String::is_empty")] + pub cid: String, + + #[serde(rename = "port", skip_serializing_if = "String::is_empty")] + pub port: String, + + #[serde(rename = "path", skip_serializing_if = "String::is_empty")] + pub path: String, +} + +impl SocketAddress { + pub fn new(port: u32) -> Self { + if port == 0 { + Self { + typ: "unix".to_string(), + cid: "".to_string(), + port: "".to_string(), + path: QGS_SOCKET_PATH.to_string(), + } + } else { + Self { + typ: "vsock".to_string(), + cid: format!("{}", 2), + port: port.to_string(), + path: "".to_string(), + } + } + } +} + +impl std::fmt::Display for SocketAddress { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + serde_json::to_string(self) + .map_err(|_| std::fmt::Error) + .and_then(|s| write!(f, "{}", s)) + } +} + #[cfg(test)] mod tests { use super::create_fds; + use super::SocketAddress; #[test] fn test_ctreate_fds() { @@ -156,4 +205,43 @@ mod tests { assert!(fds.is_ok()); assert_eq!(fds.unwrap().len(), num_fds); } + + #[test] + fn test_vsocket_address_new() { + let socket = SocketAddress::new(8866); + assert_eq!(socket.typ, "vsock"); + assert_eq!(socket.cid, "2"); + assert_eq!(socket.port, "8866"); + } + + #[test] + fn test_unix_address_new() { + let socket = SocketAddress::new(0); + assert_eq!(socket.typ, "unix"); + assert_eq!(socket.path, "/var/run/tdx-qgs/qgs.socket"); + } + + #[test] + fn test_socket_address_display() { + let socket = SocketAddress::new(6688); + let expected_json = r#"{"type":"vsock","cid":"2","port":"6688"}"#; + assert_eq!(format!("{}", socket), expected_json); + } + + #[test] + fn test_socket_address_serialize_deserialize() { + let socket = SocketAddress::new(0); + let serialized = serde_json::to_string(&socket).unwrap(); + let expected_json = r#"{"type":"unix","path":"/var/run/tdx-qgs/qgs.socket"}"#; + assert_eq!(expected_json, serialized); + } + + #[test] + fn test_socket_address_kebab_case() { + let socket = SocketAddress::new(6868); + let serialized = serde_json::to_string(&socket).unwrap(); + assert!(serialized.contains(r#""type":"#)); + assert!(serialized.contains(r#""cid":"#)); + assert!(serialized.contains(r#""port":"#)); + } }