diff --git a/src/agent/src/image_rpc.rs b/src/agent/src/image_rpc.rs index 298b79b7bf..1176735351 100644 --- a/src/agent/src/image_rpc.rs +++ b/src/agent/src/image_rpc.rs @@ -203,6 +203,24 @@ impl ImageService { Ok(()) } + /// Determines the container id (cid) to use for a given request. + /// + /// If the request specifies a non-empty id, use it; otherwise derive it from the image path. + /// In either case, verify that the chosen id is valid. + fn cid_from_request(req: &image::PullImageRequest) -> Result { + let req_cid = req.get_container_id(); + let cid = if !req_cid.is_empty() { + req_cid.to_string() + } else if let Some(last) = req.get_image().rsplit('/').next() { + // ':' have special meaning for umoci during upack + last.replace(':', "_") + } else { + return Err(anyhow!("Invalid image name. {}", req.get_image())); + }; + verify_cid(&cid)?; + Ok(cid) + } + async fn pull_image(&self, req: &image::PullImageRequest) -> Result { env::set_var("OCICRYPT_KEYPROVIDER_CONFIG", OCICRYPT_CONFIG_PATH); @@ -216,33 +234,19 @@ impl ImageService { env::set_var("NO_PROXY", no_proxy); } + let cid = Self::cid_from_request(req)?; let image = req.get_image(); - let mut cid = req.get_container_id().to_string(); - - let aa_kbc_params = &AGENT_CONFIG.read().await.aa_kbc_params; - - if cid.is_empty() { - let v: Vec<&str> = image.rsplit('/').collect(); - if !v[0].is_empty() { - // ':' have special meaning for umoci during upack - cid = v[0].replace(':', "_"); - } else { - return Err(anyhow!("Invalid image name. {}", image)); - } - } else { - verify_cid(&cid)?; - } - // Can switch to use cid directly when we remove umoci let v: Vec<&str> = image.rsplit('/').collect(); if !v[0].is_empty() && v[0].starts_with("pause:") { Self::unpack_pause_image(&cid)?; let mut sandbox = self.sandbox.lock().await; - sandbox.images.insert(String::from(image), cid.to_string()); + sandbox.images.insert(String::from(image), cid); return Ok(image.to_owned()); } + let aa_kbc_params = &AGENT_CONFIG.read().await.aa_kbc_params; if !aa_kbc_params.is_empty() { match self.attestation_agent_started.compare_exchange_weak( false, @@ -288,7 +292,7 @@ impl ImageService { } let mut sandbox = self.sandbox.lock().await; - sandbox.images.insert(String::from(image), cid.to_string()); + sandbox.images.insert(String::from(image), cid); Ok(image.to_owned()) } } @@ -312,3 +316,96 @@ impl protocols::image_ttrpc_async::Image for ImageService { } } } + +#[cfg(test)] +mod tests { + use super::ImageService; + use protocols::image; + + #[test] + fn test_cid_from_request() { + struct Case { + cid: &'static str, + image: &'static str, + result: Option<&'static str>, + } + + let cases = [ + Case { + cid: "", + image: "", + result: None, + }, + Case { + cid: "..", + image: "", + result: None, + }, + Case { + cid: "", + image: "..", + result: None, + }, + Case { + cid: "", + image: "abc/..", + result: None, + }, + Case { + cid: "", + image: "abc/", + result: None, + }, + Case { + cid: "", + image: "../abc", + result: Some("abc"), + }, + Case { + cid: "", + image: "../9abc", + result: Some("9abc"), + }, + Case { + cid: "some-string.1_2", + image: "", + result: Some("some-string.1_2"), + }, + Case { + cid: "0some-string.1_2", + image: "", + result: Some("0some-string.1_2"), + }, + Case { + cid: "a:b", + image: "", + result: None, + }, + Case { + cid: "", + image: "prefix/a:b", + result: Some("a_b"), + }, + Case { + cid: "", + image: "/a/b/c/d:e", + result: Some("d_e"), + }, + ]; + + for case in &cases { + let mut req = image::PullImageRequest::new(); + req.set_image(case.image.to_string()); + req.set_container_id(case.cid.to_string()); + let ret = ImageService::cid_from_request(&req); + match (case.result, ret) { + (Some(expected), Ok(actual)) => assert_eq!(expected, actual), + (None, Err(_)) => (), + (None, Ok(r)) => panic!("Expected an error, got {}", r), + (Some(expected), Err(e)) => { + panic!("Expected {} but got an error ({})", expected, e) + } + } + } + } +}