diff --git a/src/agent/src/rpc.rs b/src/agent/src/rpc.rs index 27a698bb5f..550784dfe7 100644 --- a/src/agent/src/rpc.rs +++ b/src/agent/src/rpc.rs @@ -406,7 +406,8 @@ impl AgentService { // For container initProcess, if it hasn't installed handler for "SIGTERM" signal, // it will ignore the "SIGTERM" signal sent to it, thus send it "SIGKILL" signal // instead of "SIGTERM" to terminate it. - if p.init && sig == libc::SIGTERM && !is_signal_handled(p.pid, sig as u32) { + let proc_status_file = format!("/proc/{}/status", p.pid); + if p.init && sig == libc::SIGTERM && !is_signal_handled(&proc_status_file, sig as u32) { sig = libc::SIGKILL; } p.signal(sig)?; @@ -1617,21 +1618,33 @@ fn append_guest_hooks(s: &Sandbox, oci: &mut Spec) -> Result<()> { Ok(()) } -// Check is the container process installed the +// Check if the container process installed the // handler for specific signal. -fn is_signal_handled(pid: pid_t, signum: u32) -> bool { - let sig_mask: u64 = 1u64 << (signum - 1); - let file_name = format!("/proc/{}/status", pid); +fn is_signal_handled(proc_status_file: &str, signum: u32) -> bool { + let shift_count: u64 = if signum == 0 { + // signum 0 is used to check for process liveness. + // Since that signal is not part of the mask in the file, we only need + // to know if the file (and therefore) process exists to handle + // that signal. + return fs::metadata(proc_status_file).is_ok(); + } else if signum > 64 { + // Ensure invalid signum won't break bit shift logic + warn!(sl!(), "received invalid signum {}", signum); + return false; + } else { + (signum - 1).into() + }; // Open the file in read-only mode (ignoring errors). - let file = match File::open(&file_name) { + let file = match File::open(proc_status_file) { Ok(f) => f, Err(_) => { - warn!(sl!(), "failed to open file {}\n", file_name); + warn!(sl!(), "failed to open file {}", proc_status_file); return false; } }; + let sig_mask: u64 = 1 << shift_count; let reader = BufReader::new(file); // Read the file line by line using the lines() iterator from std::io::BufRead. @@ -1639,21 +1652,21 @@ fn is_signal_handled(pid: pid_t, signum: u32) -> bool { let line = match line { Ok(l) => l, Err(_) => { - warn!(sl!(), "failed to read file {}\n", file_name); + warn!(sl!(), "failed to read file {}", proc_status_file); return false; } }; if line.starts_with("SigCgt:") { let mask_vec: Vec<&str> = line.split(':').collect(); if mask_vec.len() != 2 { - warn!(sl!(), "parse the SigCgt field failed\n"); + warn!(sl!(), "parse the SigCgt field failed"); return false; } let sig_cgt_str = mask_vec[1]; let sig_cgt_mask = match u64::from_str_radix(sig_cgt_str, 16) { Ok(h) => h, Err(_) => { - warn!(sl!(), "failed to parse the str {} to hex\n", sig_cgt_str); + warn!(sl!(), "failed to parse the str {} to hex", sig_cgt_str); return false; } }; @@ -2423,6 +2436,102 @@ mod tests { } } + #[tokio::test] + async fn test_is_signal_handled() { + #[derive(Debug)] + struct TestData<'a> { + status_file_data: Option<&'a str>, + signum: u32, + result: bool, + } + + let tests = &[ + TestData { + status_file_data: Some( + r#" +SigBlk:0000000000010000 +SigCgt:0000000000000001 +OtherField:other + "#, + ), + signum: 1, + result: true, + }, + TestData { + status_file_data: Some("SigCgt:000000004b813efb"), + signum: 4, + result: true, + }, + TestData { + status_file_data: Some("SigCgt:000000004b813efb"), + signum: 3, + result: false, + }, + TestData { + status_file_data: Some("SigCgt:000000004b813efb"), + signum: 65, + result: false, + }, + TestData { + status_file_data: Some("SigCgt:000000004b813efb"), + signum: 0, + result: true, + }, + TestData { + status_file_data: Some("SigCgt:ZZZZZZZZ"), + signum: 1, + result: false, + }, + TestData { + status_file_data: Some("SigCgt:-1"), + signum: 1, + result: false, + }, + TestData { + status_file_data: Some("SigCgt"), + signum: 1, + result: false, + }, + TestData { + status_file_data: Some("any data"), + signum: 0, + result: true, + }, + TestData { + status_file_data: Some("SigBlk:0000000000000001"), + signum: 1, + result: false, + }, + TestData { + status_file_data: None, + signum: 1, + result: false, + }, + TestData { + status_file_data: None, + signum: 0, + result: false, + }, + ]; + + for (i, d) in tests.iter().enumerate() { + let msg = format!("test[{}]: {:?}", i, d); + + let dir = tempdir().expect("failed to make tempdir"); + let proc_status_file_path = dir.path().join("status"); + + if let Some(file_data) = d.status_file_data { + fs::write(&proc_status_file_path, file_data).unwrap(); + } + + let result = is_signal_handled(proc_status_file_path.to_str().unwrap(), d.signum); + + let msg = format!("{}, result: {:?}", msg, result); + + assert_eq!(d.result, result, "{}", msg); + } + } + #[tokio::test] async fn test_verify_cid() { #[derive(Debug)]