diff --git a/src/agent/src/config.rs b/src/agent/src/config.rs index 108464a4a5..1089afc780 100644 --- a/src/agent/src/config.rs +++ b/src/agent/src/config.rs @@ -5,6 +5,7 @@ use crate::tracer; use anyhow::{bail, ensure, Context, Result}; use serde::Deserialize; +use std::collections::HashSet; use std::env; use std::fs; use std::str::FromStr; @@ -55,6 +56,12 @@ pub struct EndpointsConfig { pub allowed: Vec, } +#[derive(Debug, Default)] +pub struct AgentEndpoints { + pub allowed: HashSet, + pub all_allowed: bool, +} + #[derive(Debug)] pub struct AgentConfig { pub debug_console: bool, @@ -67,7 +74,7 @@ pub struct AgentConfig { pub server_addr: String, pub unified_cgroup_hierarchy: bool, pub tracing: tracer::TraceType, - pub endpoints: EndpointsConfig, + pub endpoints: AgentEndpoints, } #[derive(Debug, Deserialize)] @@ -171,7 +178,13 @@ impl FromStr for AgentConfig { config_override!(agent_config_builder, agent_config, server_addr); config_override!(agent_config_builder, agent_config, unified_cgroup_hierarchy); config_override!(agent_config_builder, agent_config, tracing); - config_override!(agent_config_builder, agent_config, endpoints); + + // Populate the allowed endpoints hash set, if we got any from the config file. + if let Some(endpoints) = agent_config_builder.endpoints { + for ep in endpoints.allowed { + agent_config.endpoints.allowed.insert(ep); + } + } Ok(agent_config) } @@ -270,6 +283,9 @@ impl AgentConfig { } } + // We did not get a configuration file: allow all endpoints. + config.endpoints.all_allowed = true; + Ok(config) } @@ -278,6 +294,10 @@ impl AgentConfig { let config = fs::read_to_string(file)?; AgentConfig::from_str(&config) } + + pub fn is_allowed_endpoint(&self, ep: &str) -> bool { + self.endpoints.all_allowed || self.endpoints.allowed.contains(ep) + } } #[instrument] @@ -1353,12 +1373,18 @@ Caused by: ) .unwrap(); + // Verify that the all_allowed flag is false + assert!(!config.endpoints.all_allowed); + // Verify that the override worked assert!(config.dev_mode); assert_eq!(config.server_addr, "vsock://8:2048"); assert_eq!( config.endpoints.allowed, vec!["CreateContainer".to_string(), "StartContainer".to_string()] + .iter() + .cloned() + .collect() ); // Verify that the default values are valid diff --git a/src/agent/src/rpc.rs b/src/agent/src/rpc.rs index 8154975b74..d48815db63 100644 --- a/src/agent/src/rpc.rs +++ b/src/agent/src/rpc.rs @@ -20,7 +20,7 @@ use ttrpc::{ use anyhow::{anyhow, Context, Result}; use oci::{LinuxNamespace, Root, Spec}; -use protobuf::{RepeatedField, SingularPtrField}; +use protobuf::{Message, RepeatedField, SingularPtrField}; use protocols::agent::{ AddSwapRequest, AgentDetails, CopyFileRequest, GuestDetailsResponse, Interfaces, Metrics, OOMEvent, ReadStreamResponse, Routes, StatsContainerResponse, WaitProcessResponse, @@ -86,6 +86,21 @@ macro_rules! sl { }; } +macro_rules! is_allowed { + ($req:ident) => { + if !AGENT_CONFIG + .read() + .await + .is_allowed_endpoint($req.descriptor().name()) + { + return Err(ttrpc_error( + ttrpc::Code::UNIMPLEMENTED, + format!("{} is blocked", $req.descriptor().name()), + )); + } + }; +} + #[derive(Clone, Debug)] pub struct AgentService { sandbox: Arc>, @@ -531,6 +546,7 @@ impl protocols::agent_ttrpc::AgentService for AgentService { req: protocols::agent::CreateContainerRequest, ) -> ttrpc::Result { trace_rpc_call!(ctx, "create_container", req); + is_allowed!(req); match self.do_create_container(req).await { Err(e) => Err(ttrpc_error(ttrpc::Code::INTERNAL, e.to_string())), Ok(_) => Ok(Empty::new()), @@ -543,6 +559,7 @@ impl protocols::agent_ttrpc::AgentService for AgentService { req: protocols::agent::StartContainerRequest, ) -> ttrpc::Result { trace_rpc_call!(ctx, "start_container", req); + is_allowed!(req); match self.do_start_container(req).await { Err(e) => Err(ttrpc_error(ttrpc::Code::INTERNAL, e.to_string())), Ok(_) => Ok(Empty::new()), @@ -555,6 +572,7 @@ impl protocols::agent_ttrpc::AgentService for AgentService { req: protocols::agent::RemoveContainerRequest, ) -> ttrpc::Result { trace_rpc_call!(ctx, "remove_container", req); + is_allowed!(req); match self.do_remove_container(req).await { Err(e) => Err(ttrpc_error(ttrpc::Code::INTERNAL, e.to_string())), Ok(_) => Ok(Empty::new()), @@ -567,6 +585,7 @@ impl protocols::agent_ttrpc::AgentService for AgentService { req: protocols::agent::ExecProcessRequest, ) -> ttrpc::Result { trace_rpc_call!(ctx, "exec_process", req); + is_allowed!(req); match self.do_exec_process(req).await { Err(e) => Err(ttrpc_error(ttrpc::Code::INTERNAL, e.to_string())), Ok(_) => Ok(Empty::new()), @@ -579,6 +598,7 @@ impl protocols::agent_ttrpc::AgentService for AgentService { req: protocols::agent::SignalProcessRequest, ) -> ttrpc::Result { trace_rpc_call!(ctx, "signal_process", req); + is_allowed!(req); match self.do_signal_process(req).await { Err(e) => Err(ttrpc_error(ttrpc::Code::INTERNAL, e.to_string())), Ok(_) => Ok(Empty::new()), @@ -591,6 +611,7 @@ impl protocols::agent_ttrpc::AgentService for AgentService { req: protocols::agent::WaitProcessRequest, ) -> ttrpc::Result { trace_rpc_call!(ctx, "wait_process", req); + is_allowed!(req); self.do_wait_process(req) .await .map_err(|e| ttrpc_error(ttrpc::Code::INTERNAL, e.to_string())) @@ -602,6 +623,7 @@ impl protocols::agent_ttrpc::AgentService for AgentService { req: protocols::agent::UpdateContainerRequest, ) -> ttrpc::Result { trace_rpc_call!(ctx, "update_container", req); + is_allowed!(req); let cid = req.container_id.clone(); let res = req.resources; @@ -637,6 +659,7 @@ impl protocols::agent_ttrpc::AgentService for AgentService { req: protocols::agent::StatsContainerRequest, ) -> ttrpc::Result { trace_rpc_call!(ctx, "stats_container", req); + is_allowed!(req); let cid = req.container_id; let s = Arc::clone(&self.sandbox); let mut sandbox = s.lock().await; @@ -658,6 +681,7 @@ impl protocols::agent_ttrpc::AgentService for AgentService { req: protocols::agent::PauseContainerRequest, ) -> ttrpc::Result { trace_rpc_call!(ctx, "pause_container", req); + is_allowed!(req); let cid = req.get_container_id(); let s = Arc::clone(&self.sandbox); let mut sandbox = s.lock().await; @@ -681,6 +705,7 @@ impl protocols::agent_ttrpc::AgentService for AgentService { req: protocols::agent::ResumeContainerRequest, ) -> ttrpc::Result { trace_rpc_call!(ctx, "resume_container", req); + is_allowed!(req); let cid = req.get_container_id(); let s = Arc::clone(&self.sandbox); let mut sandbox = s.lock().await; @@ -703,6 +728,7 @@ impl protocols::agent_ttrpc::AgentService for AgentService { _ctx: &TtrpcContext, req: protocols::agent::WriteStreamRequest, ) -> ttrpc::Result { + is_allowed!(req); self.do_write_stream(req) .await .map_err(|e| ttrpc_error(ttrpc::Code::INTERNAL, e.to_string())) @@ -713,6 +739,7 @@ impl protocols::agent_ttrpc::AgentService for AgentService { _ctx: &TtrpcContext, req: protocols::agent::ReadStreamRequest, ) -> ttrpc::Result { + is_allowed!(req); self.do_read_stream(req, true) .await .map_err(|e| ttrpc_error(ttrpc::Code::INTERNAL, e.to_string())) @@ -723,6 +750,7 @@ impl protocols::agent_ttrpc::AgentService for AgentService { _ctx: &TtrpcContext, req: protocols::agent::ReadStreamRequest, ) -> ttrpc::Result { + is_allowed!(req); self.do_read_stream(req, false) .await .map_err(|e| ttrpc_error(ttrpc::Code::INTERNAL, e.to_string())) @@ -734,6 +762,7 @@ impl protocols::agent_ttrpc::AgentService for AgentService { req: protocols::agent::CloseStdinRequest, ) -> ttrpc::Result { trace_rpc_call!(ctx, "close_stdin", req); + is_allowed!(req); let cid = req.container_id.clone(); let eid = req.exec_id; @@ -770,6 +799,7 @@ impl protocols::agent_ttrpc::AgentService for AgentService { req: protocols::agent::TtyWinResizeRequest, ) -> ttrpc::Result { trace_rpc_call!(ctx, "tty_win_resize", req); + is_allowed!(req); let cid = req.container_id.clone(); let eid = req.exec_id.clone(); @@ -810,6 +840,7 @@ impl protocols::agent_ttrpc::AgentService for AgentService { req: protocols::agent::UpdateInterfaceRequest, ) -> ttrpc::Result { trace_rpc_call!(ctx, "update_interface", req); + is_allowed!(req); let interface = req.interface.into_option().ok_or_else(|| { ttrpc_error( @@ -837,6 +868,7 @@ impl protocols::agent_ttrpc::AgentService for AgentService { req: protocols::agent::UpdateRoutesRequest, ) -> ttrpc::Result { trace_rpc_call!(ctx, "update_routes", req); + is_allowed!(req); let new_routes = req .routes @@ -877,6 +909,7 @@ impl protocols::agent_ttrpc::AgentService for AgentService { req: protocols::agent::ListInterfacesRequest, ) -> ttrpc::Result { trace_rpc_call!(ctx, "list_interfaces", req); + is_allowed!(req); let list = self .sandbox @@ -904,6 +937,7 @@ impl protocols::agent_ttrpc::AgentService for AgentService { req: protocols::agent::ListRoutesRequest, ) -> ttrpc::Result { trace_rpc_call!(ctx, "list_routes", req); + is_allowed!(req); let list = self .sandbox @@ -926,14 +960,16 @@ impl protocols::agent_ttrpc::AgentService for AgentService { req: protocols::agent::StartTracingRequest, ) -> ttrpc::Result { info!(sl!(), "start_tracing {:?}", req); + is_allowed!(req); Ok(Empty::new()) } async fn stop_tracing( &self, _ctx: &TtrpcContext, - _req: protocols::agent::StopTracingRequest, + req: protocols::agent::StopTracingRequest, ) -> ttrpc::Result { + is_allowed!(req); Ok(Empty::new()) } @@ -943,6 +979,7 @@ impl protocols::agent_ttrpc::AgentService for AgentService { req: protocols::agent::CreateSandboxRequest, ) -> ttrpc::Result { trace_rpc_call!(ctx, "create_sandbox", req); + is_allowed!(req); { let sandbox = self.sandbox.clone(); @@ -1008,6 +1045,7 @@ impl protocols::agent_ttrpc::AgentService for AgentService { req: protocols::agent::DestroySandboxRequest, ) -> ttrpc::Result { trace_rpc_call!(ctx, "destroy_sandbox", req); + is_allowed!(req); let s = Arc::clone(&self.sandbox); let mut sandbox = s.lock().await; @@ -1029,6 +1067,7 @@ impl protocols::agent_ttrpc::AgentService for AgentService { req: protocols::agent::AddARPNeighborsRequest, ) -> ttrpc::Result { trace_rpc_call!(ctx, "add_arp_neighbors", req); + is_allowed!(req); let neighs = req .neighbors @@ -1062,6 +1101,7 @@ impl protocols::agent_ttrpc::AgentService for AgentService { ctx: &TtrpcContext, req: protocols::agent::OnlineCPUMemRequest, ) -> ttrpc::Result { + is_allowed!(req); let s = Arc::clone(&self.sandbox); let sandbox = s.lock().await; trace_rpc_call!(ctx, "online_cpu_mem", req); @@ -1079,6 +1119,7 @@ impl protocols::agent_ttrpc::AgentService for AgentService { req: protocols::agent::ReseedRandomDevRequest, ) -> ttrpc::Result { trace_rpc_call!(ctx, "reseed_random_dev", req); + is_allowed!(req); random::reseed_rng(req.data.as_slice()) .map_err(|e| ttrpc_error(ttrpc::Code::INTERNAL, e.to_string()))?; @@ -1092,6 +1133,7 @@ impl protocols::agent_ttrpc::AgentService for AgentService { req: protocols::agent::GuestDetailsRequest, ) -> ttrpc::Result { trace_rpc_call!(ctx, "get_guest_details", req); + is_allowed!(req); info!(sl!(), "get guest details!"); let mut resp = GuestDetailsResponse::new(); @@ -1120,6 +1162,7 @@ impl protocols::agent_ttrpc::AgentService for AgentService { req: protocols::agent::MemHotplugByProbeRequest, ) -> ttrpc::Result { trace_rpc_call!(ctx, "mem_hotplug_by_probe", req); + is_allowed!(req); do_mem_hotplug_by_probe(&req.memHotplugProbeAddr) .map_err(|e| ttrpc_error(ttrpc::Code::INTERNAL, e.to_string()))?; @@ -1133,6 +1176,7 @@ impl protocols::agent_ttrpc::AgentService for AgentService { req: protocols::agent::SetGuestDateTimeRequest, ) -> ttrpc::Result { trace_rpc_call!(ctx, "set_guest_date_time", req); + is_allowed!(req); do_set_guest_date_time(req.Sec, req.Usec) .map_err(|e| ttrpc_error(ttrpc::Code::INTERNAL, e.to_string()))?; @@ -1146,6 +1190,7 @@ impl protocols::agent_ttrpc::AgentService for AgentService { req: protocols::agent::CopyFileRequest, ) -> ttrpc::Result { trace_rpc_call!(ctx, "copy_file", req); + is_allowed!(req); do_copy_file(&req).map_err(|e| ttrpc_error(ttrpc::Code::INTERNAL, e.to_string()))?; @@ -1158,6 +1203,7 @@ impl protocols::agent_ttrpc::AgentService for AgentService { req: protocols::agent::GetMetricsRequest, ) -> ttrpc::Result { trace_rpc_call!(ctx, "get_metrics", req); + is_allowed!(req); match get_metrics(&req) { Err(e) => Err(ttrpc_error(ttrpc::Code::INTERNAL, e.to_string())), @@ -1172,8 +1218,9 @@ impl protocols::agent_ttrpc::AgentService for AgentService { async fn get_oom_event( &self, _ctx: &TtrpcContext, - _req: protocols::agent::GetOOMEventRequest, + req: protocols::agent::GetOOMEventRequest, ) -> ttrpc::Result { + is_allowed!(req); let sandbox = self.sandbox.clone(); let s = sandbox.lock().await; let event_rx = &s.event_rx.clone(); @@ -1199,6 +1246,7 @@ impl protocols::agent_ttrpc::AgentService for AgentService { req: protocols::agent::AddSwapRequest, ) -> ttrpc::Result { trace_rpc_call!(ctx, "add_swap", req); + is_allowed!(req); do_add_swap(&self.sandbox, &req) .await