diff --git a/src/agent/src/image_rpc.rs b/src/agent/src/image_rpc.rs index bcdb0cf189..ec88aea88c 100644 --- a/src/agent/src/image_rpc.rs +++ b/src/agent/src/image_rpc.rs @@ -9,7 +9,7 @@ use std::env; use std::fs; use std::path::Path; use std::process::Command; -use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::atomic::{AtomicBool, AtomicU16, Ordering}; use std::sync::Arc; use anyhow::{anyhow, Result}; @@ -45,6 +45,7 @@ pub struct ImageService { sandbox: Arc>, attestation_agent_started: AtomicBool, image_client: Arc>, + container_count: Arc, } impl ImageService { @@ -54,6 +55,7 @@ impl ImageService { sandbox, attestation_agent_started: AtomicBool::new(false), image_client: Arc::new(Mutex::new(ImageClient::default())), + container_count: Arc::new(AtomicU16::new(0)), } } @@ -115,13 +117,16 @@ impl ImageService { /// /// If the request specifies a non-empty id, use it; otherwise derive it from the image path. /// In either case, verify that the chosen id is valid. - fn cid_from_request(req: &image::PullImageRequest) -> Result { + fn cid_from_request(&self, req: &image::PullImageRequest) -> Result { let req_cid = req.get_container_id(); let cid = if !req_cid.is_empty() { req_cid.to_string() } else if let Some(last) = req.get_image().rsplit('/').next() { + // Support multiple containers with same image + let index = self.container_count.fetch_add(1, Ordering::Relaxed); + // ':' not valid for container id - last.replace(':', "_") + format!("{}_{}", last.replace(':', "_"), index) } else { return Err(anyhow!("Invalid image name. {}", req.get_image())); }; @@ -142,7 +147,7 @@ impl ImageService { env::set_var("NO_PROXY", no_proxy); } - let cid = Self::cid_from_request(req)?; + let cid = self.cid_from_request(req)?; let image = req.get_image(); if cid.starts_with("pause") { Self::unpack_pause_image(&cid)?; @@ -246,10 +251,13 @@ impl protocols::image_ttrpc_async::Image for ImageService { #[cfg(test)] mod tests { use super::ImageService; + use crate::sandbox::Sandbox; use protocols::image; + use std::sync::Arc; + use tokio::sync::Mutex; - #[test] - fn test_cid_from_request() { + #[tokio::test] + async fn test_cid_from_request() { struct Case { cid: &'static str, image: &'static str, @@ -285,12 +293,12 @@ mod tests { Case { cid: "", image: "../abc", - result: Some("abc"), + result: Some("abc_4"), }, Case { cid: "", image: "../9abc", - result: Some("9abc"), + result: Some("9abc_5"), }, Case { cid: "some-string.1_2", @@ -310,20 +318,23 @@ mod tests { Case { cid: "", image: "prefix/a:b", - result: Some("a_b"), + result: Some("a_b_6"), }, Case { cid: "", image: "/a/b/c/d:e", - result: Some("d_e"), + result: Some("d_e_7"), }, ]; + let logger = slog::Logger::root(slog::Discard, o!()); + let s = Sandbox::new(&logger).unwrap(); + let image_service = ImageService::new(Arc::new(Mutex::new(s))); for case in &cases { let mut req = image::PullImageRequest::new(); req.set_image(case.image.to_string()); req.set_container_id(case.cid.to_string()); - let ret = ImageService::cid_from_request(&req); + let ret = image_service.cid_from_request(&req); match (case.result, ret) { (Some(expected), Ok(actual)) => assert_eq!(expected, actual), (None, Err(_)) => (),