diff --git a/src/runtime-rs/crates/agent/src/sock/vsock.rs b/src/runtime-rs/crates/agent/src/sock/vsock.rs index dc4a41b8bd..a24e82a046 100644 --- a/src/runtime-rs/crates/agent/src/sock/vsock.rs +++ b/src/runtime-rs/crates/agent/src/sock/vsock.rs @@ -3,11 +3,8 @@ // // SPDX-License-Identifier: Apache-2.0 // - -use std::{ - os::unix::prelude::{AsRawFd, FromRawFd}, - time::Duration, -}; +use std::os::unix::prelude::{AsRawFd, FromRawFd}; +use std::time::{Duration, Instant}; use anyhow::{anyhow, Context, Result}; use async_trait::async_trait; @@ -31,70 +28,95 @@ impl Vsock { #[async_trait] impl Sock for Vsock { async fn connect(&self, config: &ConnectConfig) -> Result { - let mut last_err = None; - let retry_times = config.reconnect_timeout_ms / config.dial_timeout_ms; let sock_addr = VsockAddr::new(self.vsock_cid, self.port); - let connect_once = || { - // Create socket fd - let socket = socket( - AddressFamily::Vsock, - SockType::Stream, - SockFlag::empty(), - None, - ) - .context("failed to create vsock socket")?; + let deadline = Instant::now() + Duration::from_millis(config.reconnect_timeout_ms); - // Wrap the socket fd in a UnixStream, so that it is closed when - // anything fails. - // We MUST NOT reuse a vsock socket which has failed a connection - // attempt before, since a ECONNRESET error marks the whole socket as - // broken and non-reusable. - let socket = unsafe { std::os::unix::net::UnixStream::from_raw_fd(socket) }; + let mut backoff = Duration::from_millis(config.dial_timeout_ms); - // Connect the socket to vsock server. - connect(socket.as_raw_fd(), &sock_addr) - .with_context(|| format!("failed to connect to {sock_addr}"))?; + let min_backoff = Duration::from_millis(10); + let max_backoff = Duration::from_millis(500); + if backoff < min_backoff { + backoff = min_backoff; + } else if backoff > max_backoff { + backoff = max_backoff; + } - // Started from tokio v1.44.0+, it would panic when giving - // `from_std()` a blocking socket. A workaround is to set the - // socket to non-blocking, see [1]. - // - // https://github.com/tokio-rs/tokio/issues/7172 - socket - .set_nonblocking(true) - .context("failed to set non-blocking")?; + let mut last_err: Option = None; + let mut attempts: u64 = 0; - // Finally, convert the std UnixSocket to tokio's UnixSocket. - UnixStream::from_std(socket).context("from_std") - }; + while Instant::now() < deadline { + attempts += 1; - for i in 0..retry_times { - match connect_once() { + let sa = sock_addr; + let res: Result = + tokio::task::spawn_blocking(move || -> Result { + // Create socket fd + let fd = socket( + AddressFamily::Vsock, + SockType::Stream, + SockFlag::empty(), + None, + ) + .context("failed to create vsock socket")?; + + // Wrap fd so it closes on error + let socket = unsafe { std::os::unix::net::UnixStream::from_raw_fd(fd) }; + + // Blocking connect (usually returns quickly for vsock) + connect(socket.as_raw_fd(), &sa) + .with_context(|| format!("failed to connect to {sa}"))?; + + // Tokio requires non-blocking std socket before from_std() + socket + .set_nonblocking(true) + .context("failed to set non-blocking")?; + + UnixStream::from_std(socket).context("from_std") + }) + .await + .context("vsock: connect task join failed")?; + + match res { Ok(stream) => { - info!(sl!(), "vsock: connected to {:?}", self); + info!( + sl!(), + "vsock: connected to {:?} after {} attempts", self, attempts + ); return Ok(Stream::Vsock(stream)); } Err(e) => { + last_err = Some(e); + + let now = Instant::now(); + if now >= deadline { + break; + } + + let remaining = deadline.saturating_duration_since(now); + let sleep_dur = std::cmp::min(backoff, remaining); + trace!( sl!(), - "vsock: failed to connect to {:?}, err {:?}, attempts {}, will retry after {} ms", + "vsock: failed to connect to {:?}, attempts {}, retry after {:?}, err {:?}", self, - e, - i, - config.dial_timeout_ms, + attempts, + sleep_dur, + last_err.as_ref().unwrap(), ); - last_err = Some(e); - tokio::time::sleep(Duration::from_millis(config.dial_timeout_ms)).await; + + tokio::time::sleep(sleep_dur).await; + + backoff = std::cmp::min(backoff.saturating_mul(2), max_backoff); } } } - // Safe to unwrap the last_err, as this line will be unreachable if - // no errors occurred. Err(anyhow!( - "vsock: failed to connect to {:?}, err {:?}", + "vsock: failed to connect to {:?} within {:?} (attempts={}), last_err={:?}", self, - last_err.unwrap() + Duration::from_millis(config.reconnect_timeout_ms), + attempts, + last_err )) } }