diff --git a/src/agent/src/policy.rs b/src/agent/src/policy.rs index 020251024..6809ae51a 100644 --- a/src/agent/src/policy.rs +++ b/src/agent/src/policy.rs @@ -4,11 +4,15 @@ // use anyhow::{bail, Result}; +use protobuf::MessageDyn; use serde::{Deserialize, Serialize}; use slog::Drain; use tokio::io::AsyncWriteExt; use tokio::time::{sleep, Duration}; +use crate::rpc::ttrpc_error; +use crate::AGENT_POLICY; + static EMPTY_JSON_INPUT: &str = "{\"input\":{}}"; static OPA_DATA_PATH: &str = "/data"; @@ -23,6 +27,34 @@ macro_rules! sl { }; } +async fn allow_request(policy: &mut AgentPolicy, ep: &str, request: &str) -> ttrpc::Result<()> { + if !policy.allow_request(ep, request).await { + warn!(sl!(), "{ep} is blocked by policy"); + Err(ttrpc_error( + ttrpc::Code::PERMISSION_DENIED, + format!("{ep} is blocked by policy"), + )) + } else { + Ok(()) + } +} + +pub async fn is_allowed(req: &(impl MessageDyn + serde::Serialize)) -> ttrpc::Result<()> { + let request = serde_json::to_string(req).unwrap(); + let mut policy = AGENT_POLICY.lock().await; + allow_request(&mut policy, req.descriptor_dyn().name(), &request).await +} + +pub async fn do_set_policy(req: &protocols::agent::SetPolicyRequest) -> ttrpc::Result<()> { + let request = serde_json::to_string(req).unwrap(); + let mut policy = AGENT_POLICY.lock().await; + allow_request(&mut policy, "SetPolicyRequest", &request).await?; + policy + .set_policy(&req.policy) + .await + .map_err(|e| ttrpc_error(ttrpc::Code::INVALID_ARGUMENT, e)) +} + /// Example of HTTP response from OPA: {"result":true} #[derive(Debug, Serialize, Deserialize)] struct AllowResponse { @@ -127,7 +159,7 @@ impl AgentPolicy { } /// Ask OPA to check if an API call should be allowed or not. - pub async fn is_allowed_endpoint(&mut self, ep: &str, request: &str) -> bool { + pub async fn allow_request(&mut self, ep: &str, request: &str) -> bool { let post_input = format!("{{\"input\":{request}}}"); self.log_opa_input(ep, &post_input).await; match self.post_query(ep, &post_input).await { diff --git a/src/agent/src/rpc.rs b/src/agent/src/rpc.rs index 955588625..57690cdfb 100644 --- a/src/agent/src/rpc.rs +++ b/src/agent/src/rpc.rs @@ -23,7 +23,7 @@ use ttrpc::{ use anyhow::{anyhow, Context, Result}; use cgroups::freezer::FreezerState; use oci::{LinuxNamespace, Root, Spec}; -use protobuf::{MessageDyn, MessageField}; +use protobuf::MessageField; use protocols::agent::{ AddSwapRequest, AgentDetails, CopyFileRequest, GetIPTablesRequest, GetIPTablesResponse, GuestDetailsResponse, Interfaces, Metrics, OOMEvent, ReadStreamResponse, Routes, @@ -69,7 +69,7 @@ use crate::trace_rpc_call; use crate::tracer::extract_carrier_from_ttrpc; #[cfg(feature = "agent-policy")] -use crate::AGENT_POLICY; +use crate::policy::{do_set_policy, is_allowed}; use opentelemetry::global; use tracing::span; @@ -123,33 +123,15 @@ fn sl() -> slog::Logger { } // Convenience function to wrap an error and response to ttrpc client -fn ttrpc_error(code: ttrpc::Code, err: impl Debug) -> ttrpc::Error { +pub fn ttrpc_error(code: ttrpc::Code, err: impl Debug) -> ttrpc::Error { get_rpc_status(code, format!("{:?}", err)) } #[cfg(not(feature = "agent-policy"))] -async fn is_allowed(_req: &(impl MessageDyn + serde::Serialize)) -> ttrpc::Result<()> { +async fn is_allowed(_req: &impl serde::Serialize) -> ttrpc::Result<()> { Ok(()) } -#[cfg(feature = "agent-policy")] -async fn is_allowed(req: &(impl MessageDyn + serde::Serialize)) -> ttrpc::Result<()> { - let request = serde_json::to_string(req).unwrap(); - let mut policy = AGENT_POLICY.lock().await; - if !policy - .is_allowed_endpoint(req.descriptor_dyn().name(), &request) - .await - { - warn!(sl(), "{} is blocked by policy", req.descriptor_dyn().name()); - Err(ttrpc_error( - ttrpc::Code::PERMISSION_DENIED, - format!("{} is blocked by policy", req.descriptor_dyn().name()), - )) - } else { - Ok(()) - } -} - fn same(e: E) -> E { e } @@ -1439,14 +1421,8 @@ impl agent_ttrpc::AgentService for AgentService { req: protocols::agent::SetPolicyRequest, ) -> ttrpc::Result { trace_rpc_call!(ctx, "set_policy", req); - is_allowed(&req).await?; - AGENT_POLICY - .lock() - .await - .set_policy(&req.policy) - .await - .map_ttrpc_err(same)?; + do_set_policy(&req).await?; Ok(Empty::new()) }