diff --git a/src/agent/Cargo.toml b/src/agent/Cargo.toml index 34a39b561..87c9ae826 100644 --- a/src/agent/Cargo.toml +++ b/src/agent/Cargo.toml @@ -40,8 +40,9 @@ tokio = { version = "1.39.0", features = ["full"] } tokio-vsock = "0.3.4" netlink-sys = { version = "0.7.0", features = ["tokio_socket"] } -rtnetlink = "0.8.0" -netlink-packet-utils = "0.4.1" +rtnetlink = "0.14.0" +netlink-packet-route = "0.19.0" +netlink-packet-core = "0.7.0" ipnetwork = "0.17.0" # Note: this crate sets the slog 'max_*' features which allows the log level diff --git a/src/agent/src/netlink.rs b/src/agent/src/netlink.rs index 0b6767c24..9f46860cb 100644 --- a/src/agent/src/netlink.rs +++ b/src/agent/src/netlink.rs @@ -6,9 +6,18 @@ use anyhow::{anyhow, Context, Result}; use futures::{future, StreamExt, TryStreamExt}; use ipnetwork::{IpNetwork, Ipv4Network, Ipv6Network}; +use netlink_packet_route::address::{AddressAttribute, AddressMessage}; +use netlink_packet_route::link::{LinkAttribute, LinkMessage}; +use netlink_packet_route::neighbour::{self, NeighbourFlag}; +use netlink_packet_route::route::{RouteHeader, RouteProtocol, RouteScope, RouteType}; +use netlink_packet_route::{ + neighbour::{NeighbourAddress, NeighbourAttribute, NeighbourState}, + route::{RouteAddress, RouteAttribute, RouteMessage}, + AddressFamily, +}; use nix::errno::Errno; use protocols::types::{ARPNeighbor, IPAddress, IPFamily, Interface, Route}; -use rtnetlink::{new_connection, packet, IpVersion}; +use rtnetlink::{new_connection, IpVersion}; use std::convert::{TryFrom, TryInto}; use std::fmt; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; @@ -35,6 +44,17 @@ impl fmt::Display for LinkFilter<'_> { } } +const ALL_RULE_FLAGS: [NeighbourFlag; 8] = [ + NeighbourFlag::Use, + NeighbourFlag::Own, + NeighbourFlag::Controller, + NeighbourFlag::Proxy, + NeighbourFlag::ExtLearned, + NeighbourFlag::Offloaded, + NeighbourFlag::Sticky, + NeighbourFlag::Router, +]; + /// A filter to query addresses. pub enum AddressFilter { /// Return addresses that belong to the given interface. @@ -125,6 +145,7 @@ impl Handle { } // Update link + let link = self.find_link(LinkFilter::Address(&iface.hwAddr)).await?; let mut request = self.handle.link().set(link.index()); request.message_mut().header = link.header.clone(); @@ -225,7 +246,7 @@ impl Handle { let request = self.handle.link().get(); let filtered = match filter { - LinkFilter::Name(name) => request.set_name_filter(name.to_owned()), + LinkFilter::Name(name) => request.match_name(name.to_owned()), LinkFilter::Index(index) => request.match_index(index), _ => request, // Post filters }; @@ -233,7 +254,7 @@ impl Handle { let mut stream = filtered.execute(); let next = if let LinkFilter::Address(addr) = filter { - use packet::link::nlas::Nla; + use LinkAttribute as Nla; let mac_addr = parse_mac_address(addr) .with_context(|| format!("Failed to parse MAC address: {}", addr))?; @@ -242,7 +263,7 @@ impl Handle { // we may have to dump link list and then find the target link. stream .try_filter(|f| { - let result = f.nlas.iter().any(|n| match n { + let result = f.attributes.iter().any(|n| match n { Nla::Address(data) => data.eq(&mac_addr), _ => false, }); @@ -278,10 +299,7 @@ impl Handle { Ok(()) } - async fn query_routes( - &self, - ip_version: Option, - ) -> Result> { + async fn query_routes(&self, ip_version: Option) -> Result> { let list = if let Some(ip_version) = ip_version { self.handle .route() @@ -321,36 +339,46 @@ impl Handle { for msg in self.query_routes(None).await? { // Ignore non-main tables - if msg.header.table != packet::constants::RT_TABLE_MAIN { + if msg.header.table != RouteHeader::RT_TABLE_MAIN { continue; } let mut route = Route { - scope: msg.header.scope as _, + scope: u8::from(msg.header.scope) as u32, ..Default::default() }; - if let Some((ip, mask)) = msg.destination_prefix() { - route.dest = format!("{}/{}", ip, mask); - } - - if let Some((ip, mask)) = msg.source_prefix() { - route.source = format!("{}/{}", ip, mask); - } - - if let Some(addr) = msg.gateway() { - route.gateway = addr.to_string(); - - // For gateway, destination is 0.0.0.0 - route.dest = if addr.is_ipv4() { - String::from("0.0.0.0") - } else { - String::from("::1") + for attribute in &msg.attributes { + if let RouteAttribute::Destination(dest) = attribute { + if let Ok(dest) = parse_route_addr(dest) { + route.dest = format!("{}/{}", dest, msg.header.destination_prefix_length); + } } - } - if let Some(index) = msg.output_interface() { - route.device = self.find_link(LinkFilter::Index(index)).await?.name(); + if let RouteAttribute::Source(src) = attribute { + if let Ok(src) = parse_route_addr(src) { + route.source = format!("{}/{}", src, msg.header.source_prefix_length) + } + } + + if let RouteAttribute::Gateway(g) = attribute { + if let Ok(addr) = parse_route_addr(g) { + // For gateway, destination is 0.0.0.0 + if addr.is_ipv4() { + route.dest = String::from("0.0.0.0"); + } else { + route.dest = String::from("::1"); + } + } + + route.gateway = parse_route_addr(g) + .map(|g| g.to_string()) + .unwrap_or_default(); + } + + if let RouteAttribute::Oif(index) = attribute { + route.device = self.find_link(LinkFilter::Index(*index)).await?.name(); + } } if !route.dest.is_empty() { @@ -377,22 +405,22 @@ impl Handle { for route in list { let link = self.find_link(LinkFilter::Name(&route.device)).await?; - const MAIN_TABLE: u8 = packet::constants::RT_TABLE_MAIN; - const UNICAST: u8 = packet::constants::RTN_UNICAST; - const BOOT_PROT: u8 = packet::constants::RTPROT_BOOT; + const MAIN_TABLE: u32 = libc::RT_TABLE_MAIN as u32; + let uni_cast: RouteType = RouteType::from(libc::RTN_UNICAST); + let boot_prot: RouteProtocol = RouteProtocol::from(libc::RTPROT_BOOT); - let scope = route.scope as u8; + let scope = RouteScope::from(route.scope as u8); - use packet::nlas::route::Nla; + use RouteAttribute as Nla; // Build a common indeterminate ip request let request = self .handle .route() .add() - .table(MAIN_TABLE) - .kind(UNICAST) - .protocol(BOOT_PROT) + .table_id(MAIN_TABLE) + .kind(uni_cast) + .protocol(boot_prot) .scope(scope); // `rtnetlink` offers a separate request builders for different IP versions (IP v4 and v6). @@ -417,8 +445,8 @@ impl Handle { } else { request .message_mut() - .nlas - .push(Nla::PrefSource(network.ip().octets().to_vec())); + .attributes + .push(Nla::PrefSource(RouteAddress::from(network.ip()))); } } @@ -428,14 +456,16 @@ impl Handle { } if let Err(rtnetlink::Error::NetlinkError(message)) = request.execute().await { - if Errno::from_i32(message.code.abs()) != Errno::EEXIST { - return Err(anyhow!( - "Failed to add IP v6 route (src: {}, dst: {}, gtw: {},Err: {})", - route.source(), - route.dest(), - route.gateway(), - message - )); + if let Some(code) = message.code { + if Errno::from_i32(code.get()) != Errno::EEXIST { + return Err(anyhow!( + "Failed to add IP v6 route (src: {}, dst: {}, gtw: {},Err: {})", + route.source(), + route.dest(), + route.gateway(), + message + )); + } } } } else { @@ -458,8 +488,8 @@ impl Handle { } else { request .message_mut() - .nlas - .push(Nla::PrefSource(network.ip().octets().to_vec())); + .attributes + .push(RouteAttribute::PrefSource(RouteAddress::from(network.ip()))); } } @@ -469,14 +499,16 @@ impl Handle { } if let Err(rtnetlink::Error::NetlinkError(message)) = request.execute().await { - if Errno::from_i32(message.code.abs()) != Errno::EEXIST { - return Err(anyhow!( - "Failed to add IP v4 route (src: {}, dst: {}, gtw: {},Err: {})", - route.source(), - route.dest(), - route.gateway(), - message - )); + if let Some(code) = message.code { + if Errno::from_i32(code.get()) != Errno::EEXIST { + return Err(anyhow!( + "Failed to add IP v4 route (src: {}, dst: {}, gtw: {},Err: {})", + route.source(), + route.dest(), + route.gateway(), + message + )); + } } } } @@ -487,17 +519,24 @@ impl Handle { async fn delete_routes(&mut self, routes: I) -> Result<()> where - I: IntoIterator, + I: IntoIterator, { for route in routes.into_iter() { - if route.header.protocol == packet::constants::RTPROT_KERNEL { + if u8::from(route.header.protocol) == libc::RTPROT_KERNEL { continue; } - let index = if let Some(index) = route.output_interface() { - index - } else { - continue; + let mut link_index = None; + for routeattr in &route.attributes { + if let RouteAttribute::Oif(index) = routeattr { + link_index = Some(*index); + break; + } + } + + let index = match link_index { + None => continue, + Some(index) => index, }; let link = self.find_link(LinkFilter::Index(index)).await?; @@ -592,52 +631,57 @@ impl Handle { .map_err(|e| anyhow!("Failed to parse IP {}: {:?}", ip_address, e))?; // Import rtnetlink objects that make sense only for this function - use packet::constants::{ - NDA_UNSPEC, NLM_F_ACK, NLM_F_CREATE, NLM_F_REPLACE, NLM_F_REQUEST, - }; - use packet::neighbour::{NeighbourHeader, NeighbourMessage}; - use packet::nlas::neighbour::Nla; - use packet::{NetlinkMessage, NetlinkPayload, RtnlMessage}; + use libc::{NDA_UNSPEC, NLM_F_ACK, NLM_F_CREATE, NLM_F_REPLACE, NLM_F_REQUEST}; + use neighbour::{NeighbourHeader, NeighbourMessage}; + use netlink_packet_core::{NetlinkMessage, NetlinkPayload}; + use netlink_packet_route::RouteNetlinkMessage as RtnlMessage; use rtnetlink::Error; const IFA_F_PERMANENT: u16 = 0x80; // See https://github.com/little-dude/netlink/blob/0185b2952505e271805902bf175fee6ea86c42b8/netlink-packet-route/src/rtnl/constants.rs#L770 + let state = if neigh.state != 0 { + neigh.state as u16 + } else { + IFA_F_PERMANENT + }; let link = self.find_link(LinkFilter::Name(&neigh.device)).await?; - let message = NeighbourMessage { - header: NeighbourHeader { - family: match ip { - IpAddr::V4(_) => packet::AF_INET, - IpAddr::V6(_) => packet::AF_INET6, - } as u8, - ifindex: link.index(), - state: if neigh.state != 0 { - neigh.state as u16 - } else { - IFA_F_PERMANENT - }, - flags: neigh.flags as u8, - ntype: NDA_UNSPEC as u8, - }, - nlas: { - let mut nlas = vec![Nla::Destination(match ip { - IpAddr::V4(v4) => v4.octets().to_vec(), - IpAddr::V6(v6) => v6.octets().to_vec(), - })]; + let mut flags = Vec::new(); + for flag in ALL_RULE_FLAGS { + if (neigh.flags as u8 & (u8::from(flag))) > 0 { + flags.push(flag); + } + } - if !neigh.lladdr.is_empty() { - nlas.push(Nla::LinkLocalAddress( - parse_mac_address(&neigh.lladdr)?.to_vec(), - )); - } + let mut message = NeighbourMessage::default(); - nlas + message.header = NeighbourHeader { + family: match ip { + IpAddr::V4(_) => AddressFamily::Inet, + IpAddr::V6(_) => AddressFamily::Inet6, }, + ifindex: link.index(), + state: NeighbourState::from(state), + flags, + kind: RouteType::from(NDA_UNSPEC as u8), }; + let mut nlas = vec![NeighbourAttribute::Destination(match ip { + IpAddr::V4(ipv4_addr) => NeighbourAddress::from(ipv4_addr), + IpAddr::V6(ipv6_addr) => NeighbourAddress::from(ipv6_addr), + })]; + + if !neigh.lladdr.is_empty() { + nlas.push(NeighbourAttribute::LinkLocalAddress( + parse_mac_address(&neigh.lladdr)?.to_vec(), + )); + } + + message.attributes = nlas; + // Send request and ACK let mut req = NetlinkMessage::from(RtnlMessage::NewNeighbour(message)); - req.header.flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_CREATE | NLM_F_REPLACE; + req.header.flags = (NLM_F_REQUEST | NLM_F_ACK | NLM_F_CREATE | NLM_F_REPLACE) as u16; let mut response = self.handle.request(req)?; while let Some(message) = response.next().await { @@ -700,13 +744,13 @@ fn parse_mac_address(addr: &str) -> Result<[u8; 6]> { } /// Wraps external type with the local one, so we can implement various extensions and type conversions. -struct Link(packet::LinkMessage); +struct Link(LinkMessage); impl Link { /// If name. fn name(&self) -> String { - use packet::nlas::link::Nla; - self.nlas + use LinkAttribute as Nla; + self.attributes .iter() .find_map(|n| { if let Nla::IfName(name) = n { @@ -720,8 +764,8 @@ impl Link { /// Extract Mac address. fn address(&self) -> String { - use packet::nlas::link::Nla; - self.nlas + use LinkAttribute as Nla; + self.attributes .iter() .find_map(|n| { if let Nla::Address(data) = n { @@ -735,7 +779,12 @@ impl Link { /// Returns whether the link is UP fn is_up(&self) -> bool { - self.header.flags & packet::rtnl::constants::IFF_UP > 0 + let mut flags: u32 = 0; + for flag in &self.header.flags { + flags += u32::from(*flag); + } + + flags as i32 & libc::IFF_UP > 0 } fn index(&self) -> u32 { @@ -743,8 +792,8 @@ impl Link { } fn mtu(&self) -> Option { - use packet::nlas::link::Nla; - self.nlas.iter().find_map(|n| { + use LinkAttribute as Nla; + self.attributes.iter().find_map(|n| { if let Nla::Mtu(mtu) = n { Some(*mtu as u64) } else { @@ -754,21 +803,21 @@ impl Link { } } -impl From for Link { - fn from(msg: packet::LinkMessage) -> Self { +impl From for Link { + fn from(msg: LinkMessage) -> Self { Link(msg) } } impl Deref for Link { - type Target = packet::LinkMessage; + type Target = LinkMessage; fn deref(&self) -> &Self::Target { &self.0 } } -struct Address(packet::AddressMessage); +struct Address(AddressMessage); impl TryFrom
for IPAddress { type Error = anyhow::Error; @@ -798,7 +847,7 @@ impl TryFrom
for IPAddress { impl Address { fn is_ipv6(&self) -> bool { - self.0.header.family == packet::constants::AF_INET6 as u8 + u8::from(self.0.header.family) == libc::AF_INET6 as u8 } #[allow(dead_code)] @@ -807,13 +856,13 @@ impl Address { } fn address(&self) -> String { - use packet::nlas::address::Nla; + use AddressAttribute as Nla; self.0 - .nlas + .attributes .iter() .find_map(|n| { if let Nla::Address(data) = n { - format_address(data).ok() + Some(data.to_string()) } else { None } @@ -822,13 +871,13 @@ impl Address { } fn local(&self) -> String { - use packet::nlas::address::Nla; + use AddressAttribute as Nla; self.0 - .nlas + .attributes .iter() .find_map(|n| { if let Nla::Local(data) = n { - format_address(data).ok() + Some(data.to_string()) } else { None } @@ -837,10 +886,21 @@ impl Address { } } +fn parse_route_addr(ra: &RouteAddress) -> Result { + let ipaddr = match ra { + RouteAddress::Inet6(ipv6_addr) => ipv6_addr.to_canonical(), + RouteAddress::Inet(ipv4_addr) => IpAddr::from(*ipv4_addr), + _ => return Err(anyhow!("got invalid route address")), + }; + + Ok(ipaddr) +} + #[cfg(test)] mod tests { use super::*; - use rtnetlink::packet; + use netlink_packet_route::address::AddressHeader; + use netlink_packet_route::link::LinkHeader; use std::iter; use std::process::Command; use test_utils::skip_if_not_root; @@ -853,7 +913,7 @@ mod tests { .await .expect("Loopback not found"); - assert_ne!(message.header, packet::LinkHeader::default()); + assert_ne!(message.header, LinkHeader::default()); assert_eq!(message.name(), "lo"); } @@ -928,7 +988,7 @@ mod tests { assert_ne!(list.len(), 0); for addr in &list { - assert_ne!(addr.0.header, packet::AddressHeader::default()); + assert_ne!(addr.0.header, AddressHeader::default()); } }