proxy/conntrack: consolidate flow cleanup

Signed-off-by: Daman Arora <aroradaman@gmail.com>
This commit is contained in:
Daman Arora 2024-08-30 18:56:21 +05:30
parent b0f823e6cc
commit a6b4aa7005
5 changed files with 342 additions and 355 deletions

View File

@ -20,6 +20,9 @@ limitations under the License.
package conntrack
import (
"github.com/vishvananda/netlink"
"golang.org/x/sys/unix"
v1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/util/sets"
"k8s.io/klog/v2"
@ -40,6 +43,7 @@ func CleanStaleEntries(ct Interface, svcPortMap proxy.ServicePortMap,
// may create "black hole" entries for that IP+port. When the service gets endpoints we
// need to delete those entries so further traffic doesn't get dropped.
func deleteStaleServiceConntrackEntries(ct Interface, svcPortMap proxy.ServicePortMap, serviceUpdateResult proxy.UpdateServiceMapResult, endpointsUpdateResult proxy.UpdateEndpointsMapResult) {
var filters []netlink.CustomConntrackFilter
conntrackCleanupServiceIPs := serviceUpdateResult.DeletedUDPClusterIPs
conntrackCleanupServiceNodePorts := sets.New[int]()
isIPv6 := false
@ -48,6 +52,7 @@ func deleteStaleServiceConntrackEntries(ct Interface, svcPortMap proxy.ServicePo
// a UDP service that changes from 0 to non-0 endpoints is newly active.
for _, svcPortName := range endpointsUpdateResult.NewlyActiveUDPServices {
if svcInfo, ok := svcPortMap[svcPortName]; ok {
isIPv6 = netutils.IsIPv6(svcInfo.ClusterIP())
klog.V(4).InfoS("Newly-active UDP service may have stale conntrack entries", "servicePortName", svcPortName)
conntrackCleanupServiceIPs.Insert(svcInfo.ClusterIP().String())
for _, extIP := range svcInfo.ExternalIPs() {
@ -59,23 +64,21 @@ func deleteStaleServiceConntrackEntries(ct Interface, svcPortMap proxy.ServicePo
nodePort := svcInfo.NodePort()
if svcInfo.Protocol() == v1.ProtocolUDP && nodePort != 0 {
conntrackCleanupServiceNodePorts.Insert(nodePort)
isIPv6 = netutils.IsIPv6(svcInfo.ClusterIP())
}
}
}
klog.V(4).InfoS("Deleting conntrack stale entries for services", "IPs", conntrackCleanupServiceIPs.UnsortedList())
for _, svcIP := range conntrackCleanupServiceIPs.UnsortedList() {
if err := ct.ClearEntriesForIP(svcIP, v1.ProtocolUDP); err != nil {
klog.ErrorS(err, "Failed to delete stale service connections", "IP", svcIP)
}
filters = append(filters, filterForIP(svcIP, v1.ProtocolUDP))
}
klog.V(4).InfoS("Deleting conntrack stale entries for services", "nodePorts", conntrackCleanupServiceNodePorts.UnsortedList())
for _, nodePort := range conntrackCleanupServiceNodePorts.UnsortedList() {
err := ct.ClearEntriesForPort(nodePort, isIPv6, v1.ProtocolUDP)
if err != nil {
klog.ErrorS(err, "Failed to clear udp conntrack", "nodePort", nodePort)
}
filters = append(filters, filterForPort(nodePort, v1.ProtocolUDP))
}
if err := ct.ClearEntries(getUnixIPFamily(isIPv6), filters...); err != nil {
klog.ErrorS(err, "Failed to delete stale service connections")
}
}
@ -83,33 +86,98 @@ func deleteStaleServiceConntrackEntries(ct Interface, svcPortMap proxy.ServicePo
// to UDP endpoints. After a UDP endpoint is removed we must flush any conntrack entries
// for it so that if the same client keeps sending, the packets will get routed to a new endpoint.
func deleteStaleEndpointConntrackEntries(ct Interface, svcPortMap proxy.ServicePortMap, endpointsUpdateResult proxy.UpdateEndpointsMapResult) {
var filters []netlink.CustomConntrackFilter
isIPv6 := false
for _, epSvcPair := range endpointsUpdateResult.DeletedUDPEndpoints {
if svcInfo, ok := svcPortMap[epSvcPair.ServicePortName]; ok {
isIPv6 = netutils.IsIPv6(svcInfo.ClusterIP())
endpointIP := proxyutil.IPPart(epSvcPair.Endpoint)
nodePort := svcInfo.NodePort()
var err error
if nodePort != 0 {
err = ct.ClearEntriesForPortNAT(endpointIP, nodePort, v1.ProtocolUDP)
if err != nil {
klog.ErrorS(err, "Failed to delete nodeport-related endpoint connections", "servicePortName", epSvcPair.ServicePortName)
}
}
err = ct.ClearEntriesForNAT(svcInfo.ClusterIP().String(), endpointIP, v1.ProtocolUDP)
if err != nil {
klog.ErrorS(err, "Failed to delete endpoint connections", "servicePortName", epSvcPair.ServicePortName)
filters = append(filters, filterForPortNAT(endpointIP, nodePort, v1.ProtocolUDP))
}
filters = append(filters, filterForNAT(svcInfo.ClusterIP().String(), endpointIP, v1.ProtocolUDP))
for _, extIP := range svcInfo.ExternalIPs() {
err := ct.ClearEntriesForNAT(extIP.String(), endpointIP, v1.ProtocolUDP)
if err != nil {
klog.ErrorS(err, "Failed to delete endpoint connections for externalIP", "servicePortName", epSvcPair.ServicePortName, "externalIP", extIP)
}
filters = append(filters, filterForNAT(extIP.String(), endpointIP, v1.ProtocolUDP))
}
for _, lbIP := range svcInfo.LoadBalancerVIPs() {
err := ct.ClearEntriesForNAT(lbIP.String(), endpointIP, v1.ProtocolUDP)
if err != nil {
klog.ErrorS(err, "Failed to delete endpoint connections for LoadBalancerIP", "servicePortName", epSvcPair.ServicePortName, "loadBalancerIP", lbIP)
}
filters = append(filters, filterForNAT(lbIP.String(), endpointIP, v1.ProtocolUDP))
}
}
}
if err := ct.ClearEntries(getUnixIPFamily(isIPv6), filters...); err != nil {
klog.ErrorS(err, "Failed to delete stale endpoint connections")
}
}
// getUnixIPFamily returns the unix IPFamily constant.
func getUnixIPFamily(isIPv6 bool) uint8 {
if isIPv6 {
return unix.AF_INET6
}
return unix.AF_INET
}
// protocolMap maps v1.Protocol to the Assigned Internet Protocol Number.
// https://www.iana.org/assignments/protocol-numbers/protocol-numbers.xhtml
var protocolMap = map[v1.Protocol]uint8{
v1.ProtocolTCP: unix.IPPROTO_TCP,
v1.ProtocolUDP: unix.IPPROTO_UDP,
v1.ProtocolSCTP: unix.IPPROTO_SCTP,
}
// filterForIP returns *conntrackFilter to delete the conntrack entries for connections
// specified by the destination IP (original direction).
func filterForIP(ip string, protocol v1.Protocol) *conntrackFilter {
klog.V(4).InfoS("Adding conntrack filter for cleanup", "org-dst", ip, "protocol", protocol)
return &conntrackFilter{
protocol: protocolMap[protocol],
original: &connectionTuple{
dstIP: netutils.ParseIPSloppy(ip),
},
}
}
// filterForPort returns *conntrackFilter to delete the conntrack entries for connections
// specified by the destination Port (original direction).
func filterForPort(port int, protocol v1.Protocol) *conntrackFilter {
klog.V(4).InfoS("Adding conntrack filter for cleanup", "org-port-dst", port, "protocol", protocol)
return &conntrackFilter{
protocol: protocolMap[protocol],
original: &connectionTuple{
dstPort: uint16(port),
},
}
}
// filterForNAT returns *conntrackFilter to delete the conntrack entries for connections
// specified by the destination IP (original direction) and source IP (reply direction).
func filterForNAT(origin, dest string, protocol v1.Protocol) *conntrackFilter {
klog.V(4).InfoS("Adding conntrack filter for cleanup", "org-dst", origin, "reply-src", dest, "protocol", protocol)
return &conntrackFilter{
protocol: protocolMap[protocol],
original: &connectionTuple{
dstIP: netutils.ParseIPSloppy(origin),
},
reply: &connectionTuple{
srcIP: netutils.ParseIPSloppy(dest),
},
}
}
// filterForPortNAT returns *conntrackFilter to delete the conntrack entries for connections
// specified by the destination Port (original direction) and source IP (reply direction).
func filterForPortNAT(dest string, port int, protocol v1.Protocol) *conntrackFilter {
klog.V(4).InfoS("Adding conntrack filter for cleanup", "org-port-dst", port, "reply-src", dest, "protocol", protocol)
return &conntrackFilter{
protocol: protocolMap[protocol],
original: &connectionTuple{
dstPort: uint16(port),
},
reply: &connectionTuple{
srcIP: netutils.ParseIPSloppy(dest),
},
}
}

View File

@ -24,11 +24,15 @@ import (
"reflect"
"testing"
"github.com/stretchr/testify/require"
"github.com/vishvananda/netlink"
v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/types"
"k8s.io/apimachinery/pkg/util/sets"
"k8s.io/kubernetes/pkg/proxy"
netutils "k8s.io/utils/net"
)
const (
@ -261,3 +265,152 @@ func TestCleanStaleEntries(t *testing.T) {
})
}
}
func TestFilterForIP(t *testing.T) {
testCases := []struct {
name string
ip string
protocol v1.Protocol
expectedFamily netlink.InetFamily
expectedFilter *conntrackFilter
}{
{
name: "ipv4 + UDP",
ip: "10.96.0.10",
protocol: v1.ProtocolUDP,
expectedFilter: &conntrackFilter{
protocol: 17,
original: &connectionTuple{dstIP: netutils.ParseIPSloppy("10.96.0.10")},
},
},
{
name: "ipv6 + TCP",
ip: "2001:db8:1::2",
protocol: v1.ProtocolTCP,
expectedFilter: &conntrackFilter{
protocol: 6,
original: &connectionTuple{dstIP: netutils.ParseIPSloppy("2001:db8:1::2")},
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
require.Equal(t, tc.expectedFilter, filterForIP(tc.ip, tc.protocol))
})
}
}
func TestFilterForPort(t *testing.T) {
testCases := []struct {
name string
port int
protocol v1.Protocol
expectedFilter *conntrackFilter
}{
{
name: "UDP",
port: 5000,
protocol: v1.ProtocolUDP,
expectedFilter: &conntrackFilter{
protocol: 17,
original: &connectionTuple{dstPort: 5000},
},
},
{
name: "SCTP",
port: 3000,
protocol: v1.ProtocolSCTP,
expectedFilter: &conntrackFilter{
protocol: 132,
original: &connectionTuple{dstPort: 3000},
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
require.Equal(t, tc.expectedFilter, filterForPort(tc.port, tc.protocol))
})
}
}
func TestFilterForNAT(t *testing.T) {
testCases := []struct {
name string
orig string
dest string
protocol v1.Protocol
expectedFilter *conntrackFilter
}{
{
name: "ipv4 + SCTP",
orig: "10.96.0.10",
dest: "10.244.0.3",
protocol: v1.ProtocolSCTP,
expectedFilter: &conntrackFilter{
protocol: 132,
original: &connectionTuple{dstIP: netutils.ParseIPSloppy("10.96.0.10")},
reply: &connectionTuple{srcIP: netutils.ParseIPSloppy("10.244.0.3")},
},
},
{
name: "ipv6 + UDP",
orig: "2001:db8:1::2",
dest: "4001:ab8::2",
protocol: v1.ProtocolUDP,
expectedFilter: &conntrackFilter{
protocol: 17,
original: &connectionTuple{dstIP: netutils.ParseIPSloppy("2001:db8:1::2")},
reply: &connectionTuple{srcIP: netutils.ParseIPSloppy("4001:ab8::2")},
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
require.Equal(t, tc.expectedFilter, filterForNAT(tc.orig, tc.dest, tc.protocol))
})
}
}
func TestFilterForPortNAT(t *testing.T) {
testCases := []struct {
name string
dest string
port int
protocol v1.Protocol
expectedFamily netlink.InetFamily
expectedFilter *conntrackFilter
}{
{
name: "ipv4 + TCP",
dest: "10.96.0.10",
port: 80,
protocol: v1.ProtocolTCP,
expectedFilter: &conntrackFilter{
protocol: 6,
original: &connectionTuple{dstPort: 80},
reply: &connectionTuple{srcIP: netutils.ParseIPSloppy("10.96.0.10")},
},
},
{
name: "ipv6 + UDP",
dest: "2001:db8:1::2",
port: 8000,
protocol: v1.ProtocolUDP,
expectedFilter: &conntrackFilter{
protocol: 17,
original: &connectionTuple{dstPort: 8000},
reply: &connectionTuple{srcIP: netutils.ParseIPSloppy("2001:db8:1::2")},
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
require.Equal(t, tc.expectedFilter, filterForPortNAT(tc.dest, tc.port, tc.protocol))
})
}
}

View File

@ -23,30 +23,15 @@ import (
"fmt"
"github.com/vishvananda/netlink"
"golang.org/x/sys/unix"
v1 "k8s.io/api/core/v1"
"k8s.io/klog/v2"
netutils "k8s.io/utils/net"
)
// Interface for dealing with conntrack
type Interface interface {
// ClearEntriesForIP deletes conntrack entries for connections of the given
// protocol, to the given IP.
ClearEntriesForIP(ip string, protocol v1.Protocol) error
// ClearEntriesForPort deletes conntrack entries for connections of the given
// protocol and IP family, to the given port.
ClearEntriesForPort(port int, isIPv6 bool, protocol v1.Protocol) error
// ClearEntriesForNAT deletes conntrack entries for connections of the given
// protocol, which had been DNATted from origin to dest.
ClearEntriesForNAT(origin, dest string, protocol v1.Protocol) error
// ClearEntriesForPortNAT deletes conntrack entries for connections of the given
// protocol, which had been DNATted from the given port (on any IP) to dest.
ClearEntriesForPortNAT(dest string, port int, protocol v1.Protocol) error
// ClearEntries deletes conntrack entries for connections of the given IP family,
// filtered by the given filters.
ClearEntries(ipFamily uint8, filters ...netlink.CustomConntrackFilter) error
}
// netlinkHandler allows consuming real and mockable implementation for testing.
@ -69,110 +54,17 @@ func newConntracker(handler netlinkHandler) Interface {
return &conntracker{handler: handler}
}
// getNetlinkFamily returns the Netlink IP family constant
func getNetlinkFamily(isIPv6 bool) netlink.InetFamily {
if isIPv6 {
return unix.AF_INET6
// ClearEntries deletes conntrack entries for connections of the given IP family,
// filtered by the given filters.
func (ct *conntracker) ClearEntries(ipFamily uint8, filters ...netlink.CustomConntrackFilter) error {
if len(filters) == 0 {
klog.V(7).InfoS("no conntrack filters provided")
return nil
}
return unix.AF_INET
}
// protocolMap maps v1.Protocol to the Assigned Internet Protocol Number.
// https://www.iana.org/assignments/protocol-numbers/protocol-numbers.xhtml
var protocolMap = map[v1.Protocol]uint8{
v1.ProtocolTCP: unix.IPPROTO_TCP,
v1.ProtocolUDP: unix.IPPROTO_UDP,
v1.ProtocolSCTP: unix.IPPROTO_SCTP,
}
// ClearEntriesForIP delete the conntrack entries for connections specified by the
// destination IP(original direction).
func (ct *conntracker) ClearEntriesForIP(ip string, protocol v1.Protocol) error {
filter := &conntrackFilter{
protocol: protocolMap[protocol],
original: &connectionTuple{
dstIP: netutils.ParseIPSloppy(ip),
},
}
klog.V(4).InfoS("Clearing conntrack entries", "org-dst", ip, "protocol", protocol)
n, err := ct.handler.ConntrackDeleteFilters(netlink.ConntrackTable, getNetlinkFamily(netutils.IsIPv6String(ip)), filter)
n, err := ct.handler.ConntrackDeleteFilters(netlink.ConntrackTable, netlink.InetFamily(ipFamily), filters...)
if err != nil {
// TODO: Better handling for deletion failure. When failure occur, stale udp connection may not get flushed.
// These stale udp connection will keep black hole traffic. Making this a best effort operation for now, since it
// is expensive to baby-sit all udp connections to kubernetes services.
return fmt.Errorf("error deleting connection tracking state for %s service IP: %s, error: %w", protocol, ip, err)
}
klog.V(4).InfoS("Cleared conntrack entries", "count", n)
return nil
}
// ClearEntriesForPort delete the conntrack entries for connections specified by the
// destination Port(original direction) and IPFamily.
func (ct *conntracker) ClearEntriesForPort(port int, isIPv6 bool, protocol v1.Protocol) error {
filter := &conntrackFilter{
protocol: protocolMap[protocol],
original: &connectionTuple{
dstPort: uint16(port),
},
}
if port <= 0 {
return fmt.Errorf("wrong port number. The port number must be greater than zero")
}
klog.V(4).InfoS("Clearing conntrack entries", "org-port-dst", port, "protocol", protocol)
n, err := ct.handler.ConntrackDeleteFilters(netlink.ConntrackTable, getNetlinkFamily(isIPv6), filter)
if err != nil {
return fmt.Errorf("error deleting connection tracking state for %s port: %d, error: %w", protocol, port, err)
}
klog.V(4).InfoS("Cleared conntrack entries", "count", n)
return nil
}
// ClearEntriesForNAT delete the conntrack entries for connections specified by the
// destination IP(original direction) and source IP(reply direction).
func (ct *conntracker) ClearEntriesForNAT(origin, dest string, protocol v1.Protocol) error {
filter := &conntrackFilter{
protocol: protocolMap[protocol],
original: &connectionTuple{
dstIP: netutils.ParseIPSloppy(origin),
},
reply: &connectionTuple{
srcIP: netutils.ParseIPSloppy(dest),
},
}
klog.V(4).InfoS("Clearing conntrack entries", "org-dst", origin, "reply-src", dest, "protocol", protocol)
n, err := ct.handler.ConntrackDeleteFilters(netlink.ConntrackTable, getNetlinkFamily(netutils.IsIPv6String(origin)), filter)
if err != nil {
// TODO: Better handling for deletion failure. When failure occur, stale udp connection may not get flushed.
// These stale udp connection will keep black hole traffic. Making this a best effort operation for now, since it
// is expensive to baby sit all udp connections to kubernetes services.
return fmt.Errorf("error deleting conntrack entries for %s peer {%s, %s}, error: %w", protocol, origin, dest, err)
}
klog.V(4).InfoS("Cleared conntrack entries", "count", n)
return nil
}
// ClearEntriesForPortNAT delete the conntrack entries for connections specified by the
// destination Port(original direction) and source IP(reply direction).
func (ct *conntracker) ClearEntriesForPortNAT(dest string, port int, protocol v1.Protocol) error {
if port <= 0 {
return fmt.Errorf("wrong port number. The port number must be greater than zero")
}
filter := &conntrackFilter{
protocol: protocolMap[protocol],
original: &connectionTuple{
dstPort: uint16(port),
},
reply: &connectionTuple{
srcIP: netutils.ParseIPSloppy(dest),
},
}
klog.V(4).InfoS("Clearing conntrack entries", "reply-src", dest, "org-port-dst", port, "protocol", protocol)
n, err := ct.handler.ConntrackDeleteFilters(netlink.ConntrackTable, getNetlinkFamily(netutils.IsIPv6String(dest)), filter)
if err != nil {
return fmt.Errorf("error deleting conntrack entries for %s port: %d, error: %w", protocol, port, err)
return fmt.Errorf("error deleting conntrack entries, error: %w", err)
}
klog.V(4).InfoS("Cleared conntrack entries", "count", n)
return nil

View File

@ -26,51 +26,63 @@ import (
"github.com/vishvananda/netlink"
"golang.org/x/sys/unix"
v1 "k8s.io/api/core/v1"
netutils "k8s.io/utils/net"
)
type fakeHandler struct {
tableType netlink.ConntrackTableType
family netlink.InetFamily
filter *conntrackFilter
ipFamily netlink.InetFamily
filters []*conntrackFilter
}
func (f *fakeHandler) ConntrackDeleteFilters(tableType netlink.ConntrackTableType, family netlink.InetFamily, filters ...netlink.CustomConntrackFilter) (uint, error) {
func (f *fakeHandler) ConntrackDeleteFilters(tableType netlink.ConntrackTableType, family netlink.InetFamily, netlinkFilters ...netlink.CustomConntrackFilter) (uint, error) {
f.tableType = tableType
f.family = family
f.filter = filters[0].(*conntrackFilter)
return 1, nil
f.ipFamily = family
f.filters = make([]*conntrackFilter, 0, len(netlinkFilters))
for _, netlinkFilter := range netlinkFilters {
f.filters = append(f.filters, netlinkFilter.(*conntrackFilter))
}
return uint(len(f.filters)), nil
}
var _ netlinkHandler = (*fakeHandler)(nil)
func TestConntracker_ClearEntriesForIP(t *testing.T) {
func TestConntracker_ClearEntries(t *testing.T) {
testCases := []struct {
name string
ip string
protocol v1.Protocol
expectedFamily netlink.InetFamily
expectedFilter *conntrackFilter
name string
ipFamily uint8
filters []netlink.CustomConntrackFilter
}{
{
name: "ipv4 + UDP",
ip: "10.96.0.10",
protocol: v1.ProtocolUDP,
expectedFamily: unix.AF_INET,
expectedFilter: &conntrackFilter{
protocol: 17,
original: &connectionTuple{dstIP: netutils.ParseIPSloppy("10.96.0.10")},
name: "single IPv6 filter",
ipFamily: unix.AF_INET6,
filters: []netlink.CustomConntrackFilter{
&conntrackFilter{
protocol: 17,
original: &connectionTuple{dstPort: 8000},
reply: &connectionTuple{srcIP: netutils.ParseIPSloppy("2001:db8:1::2")},
},
},
},
{
name: "ipv6 + TCP",
ip: "2001:db8:1::2",
protocol: v1.ProtocolTCP,
expectedFamily: unix.AF_INET6,
expectedFilter: &conntrackFilter{
protocol: 6,
original: &connectionTuple{dstIP: netutils.ParseIPSloppy("2001:db8:1::2")},
name: "multiple IPv4 filters",
ipFamily: unix.AF_INET,
filters: []netlink.CustomConntrackFilter{
&conntrackFilter{
protocol: 6,
original: &connectionTuple{dstPort: 3000},
},
&conntrackFilter{
protocol: 17,
original: &connectionTuple{dstPort: 5000},
reply: &connectionTuple{srcIP: netutils.ParseIPSloppy("10.244.0.3")},
},
&conntrackFilter{
protocol: 132,
original: &connectionTuple{dstIP: netutils.ParseIPSloppy("10.96.0.10")},
reply: &connectionTuple{srcIP: netutils.ParseIPSloppy("10.244.0.3")},
},
},
},
}
@ -79,149 +91,13 @@ func TestConntracker_ClearEntriesForIP(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
handler := &fakeHandler{}
ct := newConntracker(handler)
require.NoError(t, ct.ClearEntriesForIP(tc.ip, tc.protocol))
require.NoError(t, ct.ClearEntries(tc.ipFamily, tc.filters...))
require.Equal(t, netlink.ConntrackTableType(netlink.ConntrackTable), handler.tableType)
require.Equal(t, tc.expectedFamily, handler.family)
require.Equal(t, tc.expectedFilter, handler.filter)
})
}
}
func TestConntracker_ClearEntriesForPort(t *testing.T) {
testCases := []struct {
name string
port int
isIPv6 bool
protocol v1.Protocol
expectedFamily netlink.InetFamily
expectedFilter *conntrackFilter
}{
{
name: "ipv4 + UDP",
port: 5000,
isIPv6: false,
protocol: v1.ProtocolUDP,
expectedFamily: unix.AF_INET,
expectedFilter: &conntrackFilter{
protocol: 17,
original: &connectionTuple{dstPort: 5000},
},
},
{
name: "ipv6 + SCTP",
port: 3000,
isIPv6: true,
protocol: v1.ProtocolSCTP,
expectedFamily: unix.AF_INET6,
expectedFilter: &conntrackFilter{
protocol: 132,
original: &connectionTuple{dstPort: 3000},
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
handler := &fakeHandler{}
ct := newConntracker(handler)
require.NoError(t, ct.ClearEntriesForPort(tc.port, tc.isIPv6, tc.protocol))
require.Equal(t, netlink.ConntrackTableType(netlink.ConntrackTable), handler.tableType)
require.Equal(t, tc.expectedFamily, handler.family)
require.Equal(t, tc.expectedFilter, handler.filter)
})
}
}
func TestConntracker_ClearEntriesForNAT(t *testing.T) {
testCases := []struct {
name string
src string
dest string
protocol v1.Protocol
expectedFamily netlink.InetFamily
expectedFilter *conntrackFilter
}{
{
name: "ipv4 + SCTP",
src: "10.96.0.10",
dest: "10.244.0.3",
protocol: v1.ProtocolSCTP,
expectedFamily: unix.AF_INET,
expectedFilter: &conntrackFilter{
protocol: 132,
original: &connectionTuple{dstIP: netutils.ParseIPSloppy("10.96.0.10")},
reply: &connectionTuple{srcIP: netutils.ParseIPSloppy("10.244.0.3")},
},
},
{
name: "ipv6 + UDP",
src: "2001:db8:1::2",
dest: "4001:ab8::2",
protocol: v1.ProtocolUDP,
expectedFamily: unix.AF_INET6,
expectedFilter: &conntrackFilter{
protocol: 17,
original: &connectionTuple{dstIP: netutils.ParseIPSloppy("2001:db8:1::2")},
reply: &connectionTuple{srcIP: netutils.ParseIPSloppy("4001:ab8::2")},
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
handler := &fakeHandler{}
ct := newConntracker(handler)
require.NoError(t, ct.ClearEntriesForNAT(tc.src, tc.dest, tc.protocol))
require.Equal(t, netlink.ConntrackTableType(netlink.ConntrackTable), handler.tableType)
require.Equal(t, tc.expectedFamily, handler.family)
require.Equal(t, tc.expectedFilter, handler.filter)
})
}
}
func TestConntracker_ClearEntriesForPortNAT(t *testing.T) {
testCases := []struct {
name string
ip string
port int
protocol v1.Protocol
expectedFamily netlink.InetFamily
expectedFilter *conntrackFilter
}{
{
name: "ipv4 + TCP",
ip: "10.96.0.10",
port: 80,
protocol: v1.ProtocolTCP,
expectedFamily: unix.AF_INET,
expectedFilter: &conntrackFilter{
protocol: 6,
original: &connectionTuple{dstPort: 80},
reply: &connectionTuple{srcIP: netutils.ParseIPSloppy("10.96.0.10")},
},
},
{
name: "ipv6 + UDP",
ip: "2001:db8:1::2",
port: 8000,
protocol: v1.ProtocolUDP,
expectedFamily: unix.AF_INET6,
expectedFilter: &conntrackFilter{
protocol: 17,
original: &connectionTuple{dstPort: 8000},
reply: &connectionTuple{srcIP: netutils.ParseIPSloppy("2001:db8:1::2")},
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
handler := &fakeHandler{}
ct := newConntracker(handler)
require.NoError(t, ct.ClearEntriesForPortNAT(tc.ip, tc.port, tc.protocol))
require.Equal(t, netlink.ConntrackTableType(netlink.ConntrackTable), handler.tableType)
require.Equal(t, tc.expectedFamily, handler.family)
require.Equal(t, tc.expectedFilter, handler.filter)
require.Equal(t, netlink.InetFamily(tc.ipFamily), handler.ipFamily)
require.Equal(t, len(tc.filters), len(handler.filters))
for i := 0; i < len(tc.filters); i++ {
require.Equal(t, tc.filters[i], handler.filters[i])
}
})
}
}

View File

@ -22,6 +22,8 @@ package conntrack
import (
"fmt"
"github.com/vishvananda/netlink"
v1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/util/sets"
)
@ -51,48 +53,44 @@ func (fake *FakeInterface) Reset() {
fake.ClearedPortNATs = make(map[int]string)
}
// ClearEntriesForIP is part of Interface
func (fake *FakeInterface) ClearEntriesForIP(ip string, protocol v1.Protocol) error {
if protocol != v1.ProtocolUDP {
return fmt.Errorf("FakeInterface currently only supports UDP")
}
// ClearEntries is part of Interface
func (fake *FakeInterface) ClearEntries(_ uint8, filters ...netlink.CustomConntrackFilter) error {
for _, anyFilter := range filters {
filter := anyFilter.(*conntrackFilter)
if filter.protocol != protocolMap[v1.ProtocolUDP] {
return fmt.Errorf("FakeInterface currently only supports UDP")
}
fake.ClearedIPs.Insert(ip)
return nil
}
// ClearEntriesForPort is part of Interface
func (fake *FakeInterface) ClearEntriesForPort(port int, isIPv6 bool, protocol v1.Protocol) error {
if protocol != v1.ProtocolUDP {
return fmt.Errorf("FakeInterface currently only supports UDP")
}
fake.ClearedPorts.Insert(port)
return nil
}
// ClearEntriesForNAT is part of Interface
func (fake *FakeInterface) ClearEntriesForNAT(origin, dest string, protocol v1.Protocol) error {
if protocol != v1.ProtocolUDP {
return fmt.Errorf("FakeInterface currently only supports UDP")
}
if previous, exists := fake.ClearedNATs[origin]; exists && previous != dest {
return fmt.Errorf("ClearEntriesForNAT called with same origin (%s), different destination (%s / %s)", origin, previous, dest)
}
fake.ClearedNATs[origin] = dest
return nil
}
// ClearEntriesForPortNAT is part of Interface
func (fake *FakeInterface) ClearEntriesForPortNAT(dest string, port int, protocol v1.Protocol) error {
if protocol != v1.ProtocolUDP {
return fmt.Errorf("FakeInterface currently only supports UDP")
}
if previous, exists := fake.ClearedPortNATs[port]; exists && previous != dest {
return fmt.Errorf("ClearEntriesForPortNAT called with same port (%d), different destination (%s / %s)", port, previous, dest)
}
fake.ClearedPortNATs[port] = dest
// record IP and Port entries
if filter.original != nil && filter.reply == nil {
if filter.original.dstIP != nil {
fake.ClearedIPs.Insert(filter.original.dstIP.String())
}
if filter.original.dstPort != 0 {
fake.ClearedPorts.Insert(int(filter.original.dstPort))
}
}
// record NAT and NATPort entries
if filter.original != nil && filter.reply != nil {
if filter.original.dstIP != nil && filter.reply.srcIP != nil {
origin := filter.original.dstIP.String()
dest := filter.reply.srcIP.String()
if previous, exists := fake.ClearedNATs[origin]; exists && previous != dest {
return fmt.Errorf("filter for NAT passed with same origin (%s), different destination (%s / %s)", origin, previous, dest)
}
fake.ClearedNATs[filter.original.dstIP.String()] = filter.reply.srcIP.String()
}
if filter.original.dstPort != 0 && filter.reply.srcIP != nil {
dest := filter.reply.srcIP.String()
port := int(filter.original.dstPort)
if previous, exists := fake.ClearedPortNATs[port]; exists && previous != dest {
return fmt.Errorf("filter for PortNAT passed with same port (%d), different destination (%s / %s)", port, previous, dest)
}
fake.ClearedPortNATs[port] = dest
}
}
}
return nil
}