diff --git a/src/agent/rustjail/src/container.rs b/src/agent/rustjail/src/container.rs index d716c5344d..c3df3c1c2c 100644 --- a/src/agent/rustjail/src/container.rs +++ b/src/agent/rustjail/src/container.rs @@ -636,11 +636,10 @@ fn do_init_child(cwfd: RawFd) -> Result<()> { // setup the envs for e in env.iter() { - let v: Vec<&str> = e.splitn(2, '=').collect(); - if v.len() != 2 { - continue; + match valid_env(e) { + Some((key, value)) => env::set_var(key, value), + None => log_child!(cfd_log, "invalid env key-value: {:?}", e), } - env::set_var(v[0], v[1]); } // set the "HOME" env getting from "/etc/passwd", if @@ -1479,15 +1478,15 @@ async fn execute_hook(logger: &Logger, h: &Hook, st: &OCIState) -> Result<()> { return Err(anyhow!(nix::Error::from_errno(Errno::EINVAL))); } - let args = h.args.clone(); - let env: HashMap = h - .env - .iter() - .map(|e| { - let v: Vec<&str> = e.split('=').collect(); - (v[0].to_string(), v[1].to_string()) - }) - .collect(); + let mut args = h.args.clone(); + // the hook.args[0] is the hook binary name which shouldn't be included + // in the Command.args + if args.len() > 1 { + args.remove(0); + } + + // all invalid envs will be omitted, only valid envs will be passed to hook. + let env: HashMap<&str, &str> = h.env.iter().filter_map(|e| valid_env(e)).collect(); // Avoid the exit signal to be reaped by the global reaper. let _wait_locker = WAIT_PID_LOCKER.lock().await; @@ -1498,8 +1497,7 @@ async fn execute_hook(logger: &Logger, h: &Hook, st: &OCIState) -> Result<()> { .stdin(Stdio::piped()) .stdout(Stdio::piped()) .stderr(Stdio::piped()) - .spawn() - .unwrap(); + .spawn()?; // default timeout 10s let mut timeout: u64 = 10; @@ -1515,27 +1513,39 @@ async fn execute_hook(logger: &Logger, h: &Hook, st: &OCIState) -> Result<()> { let path = h.path.clone(); let join_handle = tokio::spawn(async move { - child - .stdin - .as_mut() - .unwrap() - .write_all(state.as_bytes()) - .await - .unwrap(); + if let Some(mut stdin) = child.stdin.take() { + match stdin.write_all(state.as_bytes()).await { + Ok(_) => {} + Err(e) => { + info!(logger, "write to child stdin failed: {:?}", e); + } + } + } - // Close stdin so that hook program could receive EOF - child.stdin.take(); + // read something from stdout and stderr for debug + if let Some(stdout) = child.stdout.as_mut() { + let mut out = String::new(); + match stdout.read_to_string(&mut out).await { + Ok(_) => { + info!(logger, "child stdout: {}", out.as_str()); + } + Err(e) => { + info!(logger, "read from child stdout failed: {:?}", e); + } + } + } - // read something from stdout for debug - let mut out = String::new(); - child - .stdout - .as_mut() - .unwrap() - .read_to_string(&mut out) - .await - .unwrap(); - info!(logger, "child stdout: {}", out.as_str()); + let mut err = String::new(); + if let Some(stderr) = child.stderr.as_mut() { + match stderr.read_to_string(&mut err).await { + Ok(_) => { + info!(logger, "child stderr: {}", err.as_str()); + } + Err(e) => { + info!(logger, "read from child stderr failed: {:?}", e); + } + } + } match child.wait().await { Ok(exit) => { @@ -1544,7 +1554,10 @@ async fn execute_hook(logger: &Logger, h: &Hook, st: &OCIState) -> Result<()> { .ok_or_else(|| anyhow!("hook exit status has no status code"))?; if code != 0 { - error!(logger, "hook {} exit status is {}", &path, code); + error!( + logger, + "hook {} exit status is {}, error message is {}", &path, code, err + ); return Err(anyhow!(nix::Error::from_errno(Errno::UnknownErrno))); } @@ -1565,6 +1578,30 @@ async fn execute_hook(logger: &Logger, h: &Hook, st: &OCIState) -> Result<()> { } } +// valid environment variables according to https://doc.rust-lang.org/std/env/fn.set_var.html#panics +fn valid_env(e: &str) -> Option<(&str, &str)> { + // wherther key or value will contain NULL char. + if e.as_bytes().contains(&b'\0') { + return None; + } + + let v: Vec<&str> = e.splitn(2, '=').collect(); + + // key can't hold an `equal` sign, but value can + if v.len() != 2 { + return None; + } + + let (key, value) = (v[0].trim(), v[1].trim()); + + // key can't be empty + if key.is_empty() { + return None; + } + + Some((key, value)) +} + #[cfg(test)] mod tests { use super::*; @@ -1597,13 +1634,47 @@ mod tests { #[tokio::test] async fn test_execute_hook() { - let xargs = which("xargs").await; + let temp_file = "/tmp/test_execute_hook"; + + let touch = which("touch").await; + + defer!(fs::remove_file(temp_file).unwrap();); + let invalid_str = vec![97, b'\0', 98]; + let invalid_string = std::str::from_utf8(&invalid_str).unwrap(); + let invalid_env = format!("{}=value", invalid_string); execute_hook( &slog_scope::logger(), &Hook { - path: xargs, - args: vec![], + path: touch, + args: vec!["touch".to_string(), temp_file.to_string()], + env: vec![invalid_env], + timeout: Some(10), + }, + &OCIState { + version: "1.2.3".to_string(), + id: "321".to_string(), + status: ContainerState::Running, + pid: 2, + bundle: "".to_string(), + annotations: Default::default(), + }, + ) + .await + .unwrap(); + + assert_eq!(Path::new(&temp_file).exists(), true); + } + + #[tokio::test] + async fn test_execute_hook_with_error() { + let ls = which("ls").await; + + let res = execute_hook( + &slog_scope::logger(), + &Hook { + path: ls, + args: vec!["ls".to_string(), "/tmp/not-exist".to_string()], env: vec![], timeout: None, }, @@ -1616,8 +1687,13 @@ mod tests { annotations: Default::default(), }, ) - .await - .unwrap() + .await; + + let expected_err = nix::Error::from_errno(Errno::UnknownErrno); + assert_eq!( + res.unwrap_err().downcast::().unwrap(), + expected_err + ); } #[tokio::test] @@ -1628,7 +1704,7 @@ mod tests { &slog_scope::logger(), &Hook { path: sleep, - args: vec!["2".to_string()], + args: vec!["sleep".to_string(), "2".to_string()], env: vec![], timeout: Some(1), }, @@ -1988,4 +2064,49 @@ mod tests { let ret = do_init_child(std::io::stdin().as_raw_fd()); assert!(ret.is_err(), "Expecting Err, Got {:?}", ret); } + + #[test] + fn test_valid_env() { + let env = valid_env("a=b=c"); + assert_eq!(Some(("a", "b=c")), env); + + let env = valid_env("a=b"); + assert_eq!(Some(("a", "b")), env); + let env = valid_env("a =b"); + assert_eq!(Some(("a", "b")), env); + + let env = valid_env(" a =b"); + assert_eq!(Some(("a", "b")), env); + + let env = valid_env("a= b"); + assert_eq!(Some(("a", "b")), env); + + let env = valid_env("a=b "); + assert_eq!(Some(("a", "b")), env); + let env = valid_env("a=b c "); + assert_eq!(Some(("a", "b c")), env); + + let env = valid_env("=b"); + assert_eq!(None, env); + + let env = valid_env("a="); + assert_eq!(Some(("a", "")), env); + + let env = valid_env("a=="); + assert_eq!(Some(("a", "=")), env); + + let env = valid_env("a"); + assert_eq!(None, env); + + let invalid_str = vec![97, b'\0', 98]; + let invalid_string = std::str::from_utf8(&invalid_str).unwrap(); + + let invalid_env = format!("{}=value", invalid_string); + let env = valid_env(&invalid_env); + assert_eq!(None, env); + + let invalid_env = format!("key={}", invalid_string); + let env = valid_env(&invalid_env); + assert_eq!(None, env); + } }