diff --git a/src/runtime/virtcontainers/network.go b/src/runtime/virtcontainers/network.go index 0960d47d37..9d6a8faede 100644 --- a/src/runtime/virtcontainers/network.go +++ b/src/runtime/virtcontainers/network.go @@ -315,13 +315,16 @@ func generateVCNetworkStructures(ctx context.Context, endpoints []Endpoint) ([]* routes = append(routes, &r) } - for _, neigh := range endpoint.Properties().Neighbors { - var n pbTypes.ARPNeighbor + gatewaySet := gatewaySetFromRoutes(endpoint.Properties().Routes) - if !validGuestNeighbor(neigh) { + for _, neigh := range endpoint.Properties().Neighbors { + + if !validGuestNeighbor(neigh, gatewaySet) { continue } + var n pbTypes.ARPNeighbor + n.Device = endpoint.Name() n.State = int32(neigh.State) n.Flags = int32(neigh.Flags) diff --git a/src/runtime/virtcontainers/network_darwin.go b/src/runtime/virtcontainers/network_darwin.go index 922da24c63..7a50df8ca0 100644 --- a/src/runtime/virtcontainers/network_darwin.go +++ b/src/runtime/virtcontainers/network_darwin.go @@ -100,6 +100,10 @@ func validGuestRoute(route netlink.Route) bool { return true } -func validGuestNeighbor(route netlink.Neigh) bool { +func validGuestNeighbor(neigh netlink.Neigh, gatewaySet map[string]struct{}) bool { return true } + +func gatewaySetFromRoutes(routes []netlink.Route) map[string]struct{} { + return make(map[string]struct{}) +} diff --git a/src/runtime/virtcontainers/network_linux.go b/src/runtime/virtcontainers/network_linux.go index a9034b20eb..8977938a92 100644 --- a/src/runtime/virtcontainers/network_linux.go +++ b/src/runtime/virtcontainers/network_linux.go @@ -1735,7 +1735,39 @@ func validGuestRoute(route netlink.Route) bool { return route.Protocol != unix.RTPROT_KERNEL } -func validGuestNeighbor(neigh netlink.Neigh) bool { - // We add only static ARP entries - return neigh.State == netlink.NUD_PERMANENT +// neighbor is valid if it is static or a default-gateway +func validGuestNeighbor(neigh netlink.Neigh, gatewaySet map[string]struct{}) bool { + // need a MAC for the guest + if neigh.HardwareAddr == nil { + return false + } + // Keep all static entries + if neigh.State == netlink.NUD_PERMANENT { + return true + } + // Gateway-only exception: allow the default-gateway IP: + // On some setups, the pod subnet gateway does not appear in the host ARP cache as a static entry. + // On these setups an ARP request storm happens when many Kata PODs are started at the same time and they all look for the gateway MAC address. + // This forces the gateway to churn a lot of ARP requests and render the ARP request full, hence dropping some ARP requests. + // Manually pre-populating the ARP entry in the UVM guest ARP cache for that gateway solves that problem. + _, isGw := gatewaySet[neigh.IP.String()] + return isGw && neigh.State == netlink.NUD_REACHABLE +} + +// helper: default routes => set of gateway IP strings +func gatewaySetFromRoutes(routes []netlink.Route) map[string]struct{} { + gatewaySet := make(map[string]struct{}) + for _, route := range routes { + if route.Gw == nil { + continue + } + if route.Dst == nil { + gatewaySet[route.Gw.String()] = struct{}{} + continue + } + if ones, _ := route.Dst.Mask.Size(); ones == 0 { // 0.0.0.0/0 or ::/0 + gatewaySet[route.Gw.String()] = struct{}{} + } + } + return gatewaySet } diff --git a/src/runtime/virtcontainers/network_linux_test.go b/src/runtime/virtcontainers/network_linux_test.go index 3e93c2356c..3f98c47c60 100644 --- a/src/runtime/virtcontainers/network_linux_test.go +++ b/src/runtime/virtcontainers/network_linux_test.go @@ -381,3 +381,236 @@ func TestAddEndpoints_Dan(t *testing.T) { assert.Equal(t, ep.Type(), VfioEndpointType) assert.Equal(t, ep.PciPath().String(), "") } + +func TestValidGuestNeighbor(t *testing.T) { + assert := assert.New(t) + + // Setup gateway set with a known gateway IP + gatewaySet := map[string]struct{}{ + "10.0.0.1": {}, + } + + neighborMAC, _ := net.ParseMAC("aa:bb:cc:dd:ee:ff") + + tests := []struct { + name string + neighbor netlink.Neigh + gateways map[string]struct{} + expected bool + desc string + }{ + { + name: "PermanentNeighborAlwaysValid", + neighbor: netlink.Neigh{ + IP: net.ParseIP("192.168.1.1"), + HardwareAddr: neighborMAC, + State: netlink.NUD_PERMANENT, + }, + gateways: gatewaySet, + expected: true, + desc: "PERMANENT neighbors should always be included regardless of gateway status", + }, + { + name: "NonPermanentGatewayReachable", + neighbor: netlink.Neigh{ + IP: net.ParseIP("10.0.0.1"), + HardwareAddr: neighborMAC, + State: netlink.NUD_REACHABLE, + }, + gateways: gatewaySet, + expected: true, + desc: "Non-PERMANENT gateway neighbor in REACHABLE state should be included", + }, + { + name: "NonPermanentGatewayNotReachable", + neighbor: netlink.Neigh{ + IP: net.ParseIP("10.0.0.1"), + HardwareAddr: neighborMAC, + State: netlink.NUD_STALE, + }, + gateways: gatewaySet, + expected: false, + desc: "Non-PERMANENT gateway neighbor in STALE state should be filtered out", + }, + { + name: "NonPermanentNonGateway", + neighbor: netlink.Neigh{ + IP: net.ParseIP("192.168.1.1"), + HardwareAddr: neighborMAC, + State: netlink.NUD_REACHABLE, + }, + gateways: gatewaySet, + expected: false, + desc: "Non-PERMANENT non-gateway neighbor should be filtered out even if REACHABLE", + }, + { + name: "MissingHardwareAddr", + neighbor: netlink.Neigh{ + IP: net.ParseIP("10.0.0.1"), + State: netlink.NUD_REACHABLE, + }, + gateways: gatewaySet, + expected: false, + desc: "Any neighbor without a MAC address should be filtered out", + }, + { + name: "PermanentNeighborWithoutMAC", + neighbor: netlink.Neigh{ + IP: net.ParseIP("10.0.0.1"), + State: netlink.NUD_PERMANENT, + }, + gateways: gatewaySet, + expected: false, + desc: "Even PERMANENT neighbors need a MAC address", + }, + { + name: "OtherTransientStateAsGateway", + neighbor: netlink.Neigh{ + IP: net.ParseIP("10.0.0.1"), + HardwareAddr: neighborMAC, + State: netlink.NUD_DELAY, + }, + gateways: gatewaySet, + expected: false, + desc: "Transient states like DELAY should be filtered even for gateways", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result := validGuestNeighbor(tc.neighbor, tc.gateways) + assert.Equal(tc.expected, result, tc.desc) + }) + } +} + +func TestGatewaySetFromRoutes(t *testing.T) { + assert := assert.New(t) + + gwIPv4 := net.ParseIP("10.0.0.1") + gwIPv6 := net.ParseIP("fe80::1") + nonGwIP := net.ParseIP("192.168.1.1") + + tests := []struct { + name string + routes []netlink.Route + expectedGateways map[string]struct{} + desc string + }{ + { + name: "DefaultRouteIPv4", + routes: []netlink.Route{ + { + Dst: nil, // nil destination means default route + Gw: gwIPv4, + }, + }, + expectedGateways: map[string]struct{}{ + "10.0.0.1": {}, + }, + desc: "IPv4 default route (nil destination) should add gateway to set", + }, + { + name: "DefaultRouteIPv6", + routes: []netlink.Route{ + { + Dst: nil, + Gw: gwIPv6, + }, + }, + expectedGateways: map[string]struct{}{ + "fe80::1": {}, + }, + desc: "IPv6 default route (nil destination) should add gateway to set", + }, + { + name: "ExplicitDefaultRouteIPv4", + routes: []netlink.Route{ + { + Dst: &net.IPNet{ + IP: net.IPv4(0, 0, 0, 0), + Mask: net.CIDRMask(0, 32), + }, + Gw: gwIPv4, + }, + }, + expectedGateways: map[string]struct{}{ + "10.0.0.1": {}, + }, + desc: "Explicit IPv4 default route (0.0.0.0/0) should add gateway to set", + }, + { + name: "ExplicitDefaultRouteIPv6", + routes: []netlink.Route{ + { + Dst: &net.IPNet{ + IP: net.ParseIP("::"), + Mask: net.CIDRMask(0, 128), + }, + Gw: gwIPv6, + }, + }, + expectedGateways: map[string]struct{}{ + "fe80::1": {}, + }, + desc: "Explicit IPv6 default route (::/0) should add gateway to set", + }, + { + name: "NonDefaultRouteFiltered", + routes: []netlink.Route{ + { + Dst: &net.IPNet{ + IP: net.IPv4(172, 16, 0, 0), + Mask: net.CIDRMask(16, 32), + }, + Gw: nonGwIP, + }, + }, + expectedGateways: map[string]struct{}{}, + desc: "Non-default routes should not add gateway to set", + }, + { + name: "RouteWithoutGateway", + routes: []netlink.Route{ + { + Dst: nil, + Gw: nil, + }, + }, + expectedGateways: map[string]struct{}{}, + desc: "Routes without a gateway should be skipped", + }, + { + name: "MultipleDefaultRoutes", + routes: []netlink.Route{ + { + Dst: nil, + Gw: gwIPv4, + }, + { + Dst: nil, + Gw: gwIPv6, + }, + { + Dst: &net.IPNet{ + IP: net.IPv4(172, 16, 0, 0), + Mask: net.CIDRMask(16, 32), + }, + Gw: nonGwIP, + }, + }, + expectedGateways: map[string]struct{}{ + "10.0.0.1": {}, + "fe80::1": {}, + }, + desc: "Multiple default routes should populate the set; non-default routes should be ignored", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result := gatewaySetFromRoutes(tc.routes) + assert.Equal(tc.expectedGateways, result, tc.desc) + }) + } +}