From a6b4aa70051fceb7e4d2d9c487cefcc4b657350a Mon Sep 17 00:00:00 2001 From: Daman Arora Date: Fri, 30 Aug 2024 18:56:21 +0530 Subject: [PATCH] proxy/conntrack: consolidate flow cleanup Signed-off-by: Daman Arora --- pkg/proxy/conntrack/cleanup.go | 118 +++++++++++--- pkg/proxy/conntrack/cleanup_test.go | 153 ++++++++++++++++++ pkg/proxy/conntrack/conntrack.go | 130 ++-------------- pkg/proxy/conntrack/conntrack_test.go | 214 ++++++-------------------- pkg/proxy/conntrack/fake.go | 82 +++++----- 5 files changed, 342 insertions(+), 355 deletions(-) diff --git a/pkg/proxy/conntrack/cleanup.go b/pkg/proxy/conntrack/cleanup.go index 27b38685e5e..6debead9a38 100644 --- a/pkg/proxy/conntrack/cleanup.go +++ b/pkg/proxy/conntrack/cleanup.go @@ -20,6 +20,9 @@ limitations under the License. package conntrack import ( + "github.com/vishvananda/netlink" + "golang.org/x/sys/unix" + v1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/util/sets" "k8s.io/klog/v2" @@ -40,6 +43,7 @@ func CleanStaleEntries(ct Interface, svcPortMap proxy.ServicePortMap, // may create "black hole" entries for that IP+port. When the service gets endpoints we // need to delete those entries so further traffic doesn't get dropped. func deleteStaleServiceConntrackEntries(ct Interface, svcPortMap proxy.ServicePortMap, serviceUpdateResult proxy.UpdateServiceMapResult, endpointsUpdateResult proxy.UpdateEndpointsMapResult) { + var filters []netlink.CustomConntrackFilter conntrackCleanupServiceIPs := serviceUpdateResult.DeletedUDPClusterIPs conntrackCleanupServiceNodePorts := sets.New[int]() isIPv6 := false @@ -48,6 +52,7 @@ func deleteStaleServiceConntrackEntries(ct Interface, svcPortMap proxy.ServicePo // a UDP service that changes from 0 to non-0 endpoints is newly active. for _, svcPortName := range endpointsUpdateResult.NewlyActiveUDPServices { if svcInfo, ok := svcPortMap[svcPortName]; ok { + isIPv6 = netutils.IsIPv6(svcInfo.ClusterIP()) klog.V(4).InfoS("Newly-active UDP service may have stale conntrack entries", "servicePortName", svcPortName) conntrackCleanupServiceIPs.Insert(svcInfo.ClusterIP().String()) for _, extIP := range svcInfo.ExternalIPs() { @@ -59,23 +64,21 @@ func deleteStaleServiceConntrackEntries(ct Interface, svcPortMap proxy.ServicePo nodePort := svcInfo.NodePort() if svcInfo.Protocol() == v1.ProtocolUDP && nodePort != 0 { conntrackCleanupServiceNodePorts.Insert(nodePort) - isIPv6 = netutils.IsIPv6(svcInfo.ClusterIP()) } } } klog.V(4).InfoS("Deleting conntrack stale entries for services", "IPs", conntrackCleanupServiceIPs.UnsortedList()) for _, svcIP := range conntrackCleanupServiceIPs.UnsortedList() { - if err := ct.ClearEntriesForIP(svcIP, v1.ProtocolUDP); err != nil { - klog.ErrorS(err, "Failed to delete stale service connections", "IP", svcIP) - } + filters = append(filters, filterForIP(svcIP, v1.ProtocolUDP)) } klog.V(4).InfoS("Deleting conntrack stale entries for services", "nodePorts", conntrackCleanupServiceNodePorts.UnsortedList()) for _, nodePort := range conntrackCleanupServiceNodePorts.UnsortedList() { - err := ct.ClearEntriesForPort(nodePort, isIPv6, v1.ProtocolUDP) - if err != nil { - klog.ErrorS(err, "Failed to clear udp conntrack", "nodePort", nodePort) - } + filters = append(filters, filterForPort(nodePort, v1.ProtocolUDP)) + } + + if err := ct.ClearEntries(getUnixIPFamily(isIPv6), filters...); err != nil { + klog.ErrorS(err, "Failed to delete stale service connections") } } @@ -83,33 +86,98 @@ func deleteStaleServiceConntrackEntries(ct Interface, svcPortMap proxy.ServicePo // to UDP endpoints. After a UDP endpoint is removed we must flush any conntrack entries // for it so that if the same client keeps sending, the packets will get routed to a new endpoint. func deleteStaleEndpointConntrackEntries(ct Interface, svcPortMap proxy.ServicePortMap, endpointsUpdateResult proxy.UpdateEndpointsMapResult) { + var filters []netlink.CustomConntrackFilter + isIPv6 := false for _, epSvcPair := range endpointsUpdateResult.DeletedUDPEndpoints { if svcInfo, ok := svcPortMap[epSvcPair.ServicePortName]; ok { + isIPv6 = netutils.IsIPv6(svcInfo.ClusterIP()) endpointIP := proxyutil.IPPart(epSvcPair.Endpoint) nodePort := svcInfo.NodePort() - var err error if nodePort != 0 { - err = ct.ClearEntriesForPortNAT(endpointIP, nodePort, v1.ProtocolUDP) - if err != nil { - klog.ErrorS(err, "Failed to delete nodeport-related endpoint connections", "servicePortName", epSvcPair.ServicePortName) - } - } - err = ct.ClearEntriesForNAT(svcInfo.ClusterIP().String(), endpointIP, v1.ProtocolUDP) - if err != nil { - klog.ErrorS(err, "Failed to delete endpoint connections", "servicePortName", epSvcPair.ServicePortName) + filters = append(filters, filterForPortNAT(endpointIP, nodePort, v1.ProtocolUDP)) + } + filters = append(filters, filterForNAT(svcInfo.ClusterIP().String(), endpointIP, v1.ProtocolUDP)) for _, extIP := range svcInfo.ExternalIPs() { - err := ct.ClearEntriesForNAT(extIP.String(), endpointIP, v1.ProtocolUDP) - if err != nil { - klog.ErrorS(err, "Failed to delete endpoint connections for externalIP", "servicePortName", epSvcPair.ServicePortName, "externalIP", extIP) - } + filters = append(filters, filterForNAT(extIP.String(), endpointIP, v1.ProtocolUDP)) } for _, lbIP := range svcInfo.LoadBalancerVIPs() { - err := ct.ClearEntriesForNAT(lbIP.String(), endpointIP, v1.ProtocolUDP) - if err != nil { - klog.ErrorS(err, "Failed to delete endpoint connections for LoadBalancerIP", "servicePortName", epSvcPair.ServicePortName, "loadBalancerIP", lbIP) - } + filters = append(filters, filterForNAT(lbIP.String(), endpointIP, v1.ProtocolUDP)) } } } + + if err := ct.ClearEntries(getUnixIPFamily(isIPv6), filters...); err != nil { + klog.ErrorS(err, "Failed to delete stale endpoint connections") + } +} + +// getUnixIPFamily returns the unix IPFamily constant. +func getUnixIPFamily(isIPv6 bool) uint8 { + if isIPv6 { + return unix.AF_INET6 + } + return unix.AF_INET +} + +// protocolMap maps v1.Protocol to the Assigned Internet Protocol Number. +// https://www.iana.org/assignments/protocol-numbers/protocol-numbers.xhtml +var protocolMap = map[v1.Protocol]uint8{ + v1.ProtocolTCP: unix.IPPROTO_TCP, + v1.ProtocolUDP: unix.IPPROTO_UDP, + v1.ProtocolSCTP: unix.IPPROTO_SCTP, +} + +// filterForIP returns *conntrackFilter to delete the conntrack entries for connections +// specified by the destination IP (original direction). +func filterForIP(ip string, protocol v1.Protocol) *conntrackFilter { + klog.V(4).InfoS("Adding conntrack filter for cleanup", "org-dst", ip, "protocol", protocol) + return &conntrackFilter{ + protocol: protocolMap[protocol], + original: &connectionTuple{ + dstIP: netutils.ParseIPSloppy(ip), + }, + } +} + +// filterForPort returns *conntrackFilter to delete the conntrack entries for connections +// specified by the destination Port (original direction). +func filterForPort(port int, protocol v1.Protocol) *conntrackFilter { + klog.V(4).InfoS("Adding conntrack filter for cleanup", "org-port-dst", port, "protocol", protocol) + return &conntrackFilter{ + protocol: protocolMap[protocol], + original: &connectionTuple{ + dstPort: uint16(port), + }, + } +} + +// filterForNAT returns *conntrackFilter to delete the conntrack entries for connections +// specified by the destination IP (original direction) and source IP (reply direction). +func filterForNAT(origin, dest string, protocol v1.Protocol) *conntrackFilter { + klog.V(4).InfoS("Adding conntrack filter for cleanup", "org-dst", origin, "reply-src", dest, "protocol", protocol) + return &conntrackFilter{ + protocol: protocolMap[protocol], + original: &connectionTuple{ + dstIP: netutils.ParseIPSloppy(origin), + }, + reply: &connectionTuple{ + srcIP: netutils.ParseIPSloppy(dest), + }, + } +} + +// filterForPortNAT returns *conntrackFilter to delete the conntrack entries for connections +// specified by the destination Port (original direction) and source IP (reply direction). +func filterForPortNAT(dest string, port int, protocol v1.Protocol) *conntrackFilter { + klog.V(4).InfoS("Adding conntrack filter for cleanup", "org-port-dst", port, "reply-src", dest, "protocol", protocol) + return &conntrackFilter{ + protocol: protocolMap[protocol], + original: &connectionTuple{ + dstPort: uint16(port), + }, + reply: &connectionTuple{ + srcIP: netutils.ParseIPSloppy(dest), + }, + } } diff --git a/pkg/proxy/conntrack/cleanup_test.go b/pkg/proxy/conntrack/cleanup_test.go index 781800280fb..e17961050dd 100644 --- a/pkg/proxy/conntrack/cleanup_test.go +++ b/pkg/proxy/conntrack/cleanup_test.go @@ -24,11 +24,15 @@ import ( "reflect" "testing" + "github.com/stretchr/testify/require" + "github.com/vishvananda/netlink" + v1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" "k8s.io/apimachinery/pkg/util/sets" "k8s.io/kubernetes/pkg/proxy" + netutils "k8s.io/utils/net" ) const ( @@ -261,3 +265,152 @@ func TestCleanStaleEntries(t *testing.T) { }) } } + +func TestFilterForIP(t *testing.T) { + testCases := []struct { + name string + ip string + protocol v1.Protocol + expectedFamily netlink.InetFamily + expectedFilter *conntrackFilter + }{ + { + name: "ipv4 + UDP", + ip: "10.96.0.10", + protocol: v1.ProtocolUDP, + expectedFilter: &conntrackFilter{ + protocol: 17, + original: &connectionTuple{dstIP: netutils.ParseIPSloppy("10.96.0.10")}, + }, + }, + { + name: "ipv6 + TCP", + ip: "2001:db8:1::2", + protocol: v1.ProtocolTCP, + expectedFilter: &conntrackFilter{ + protocol: 6, + original: &connectionTuple{dstIP: netutils.ParseIPSloppy("2001:db8:1::2")}, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + require.Equal(t, tc.expectedFilter, filterForIP(tc.ip, tc.protocol)) + }) + } +} + +func TestFilterForPort(t *testing.T) { + testCases := []struct { + name string + port int + protocol v1.Protocol + expectedFilter *conntrackFilter + }{ + { + name: "UDP", + port: 5000, + protocol: v1.ProtocolUDP, + + expectedFilter: &conntrackFilter{ + protocol: 17, + original: &connectionTuple{dstPort: 5000}, + }, + }, + { + name: "SCTP", + port: 3000, + protocol: v1.ProtocolSCTP, + expectedFilter: &conntrackFilter{ + protocol: 132, + original: &connectionTuple{dstPort: 3000}, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + require.Equal(t, tc.expectedFilter, filterForPort(tc.port, tc.protocol)) + }) + } +} + +func TestFilterForNAT(t *testing.T) { + testCases := []struct { + name string + orig string + dest string + protocol v1.Protocol + expectedFilter *conntrackFilter + }{ + { + name: "ipv4 + SCTP", + orig: "10.96.0.10", + dest: "10.244.0.3", + protocol: v1.ProtocolSCTP, + expectedFilter: &conntrackFilter{ + protocol: 132, + original: &connectionTuple{dstIP: netutils.ParseIPSloppy("10.96.0.10")}, + reply: &connectionTuple{srcIP: netutils.ParseIPSloppy("10.244.0.3")}, + }, + }, + { + name: "ipv6 + UDP", + orig: "2001:db8:1::2", + dest: "4001:ab8::2", + protocol: v1.ProtocolUDP, + expectedFilter: &conntrackFilter{ + protocol: 17, + original: &connectionTuple{dstIP: netutils.ParseIPSloppy("2001:db8:1::2")}, + reply: &connectionTuple{srcIP: netutils.ParseIPSloppy("4001:ab8::2")}, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + require.Equal(t, tc.expectedFilter, filterForNAT(tc.orig, tc.dest, tc.protocol)) + }) + } +} + +func TestFilterForPortNAT(t *testing.T) { + testCases := []struct { + name string + dest string + port int + protocol v1.Protocol + expectedFamily netlink.InetFamily + expectedFilter *conntrackFilter + }{ + { + name: "ipv4 + TCP", + dest: "10.96.0.10", + port: 80, + protocol: v1.ProtocolTCP, + expectedFilter: &conntrackFilter{ + protocol: 6, + original: &connectionTuple{dstPort: 80}, + reply: &connectionTuple{srcIP: netutils.ParseIPSloppy("10.96.0.10")}, + }, + }, + { + name: "ipv6 + UDP", + dest: "2001:db8:1::2", + port: 8000, + protocol: v1.ProtocolUDP, + expectedFilter: &conntrackFilter{ + protocol: 17, + original: &connectionTuple{dstPort: 8000}, + reply: &connectionTuple{srcIP: netutils.ParseIPSloppy("2001:db8:1::2")}, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + require.Equal(t, tc.expectedFilter, filterForPortNAT(tc.dest, tc.port, tc.protocol)) + }) + } +} diff --git a/pkg/proxy/conntrack/conntrack.go b/pkg/proxy/conntrack/conntrack.go index f22dd7c38f4..c5ec762f3cb 100644 --- a/pkg/proxy/conntrack/conntrack.go +++ b/pkg/proxy/conntrack/conntrack.go @@ -23,30 +23,15 @@ import ( "fmt" "github.com/vishvananda/netlink" - "golang.org/x/sys/unix" - v1 "k8s.io/api/core/v1" "k8s.io/klog/v2" - netutils "k8s.io/utils/net" ) // Interface for dealing with conntrack type Interface interface { - // ClearEntriesForIP deletes conntrack entries for connections of the given - // protocol, to the given IP. - ClearEntriesForIP(ip string, protocol v1.Protocol) error - - // ClearEntriesForPort deletes conntrack entries for connections of the given - // protocol and IP family, to the given port. - ClearEntriesForPort(port int, isIPv6 bool, protocol v1.Protocol) error - - // ClearEntriesForNAT deletes conntrack entries for connections of the given - // protocol, which had been DNATted from origin to dest. - ClearEntriesForNAT(origin, dest string, protocol v1.Protocol) error - - // ClearEntriesForPortNAT deletes conntrack entries for connections of the given - // protocol, which had been DNATted from the given port (on any IP) to dest. - ClearEntriesForPortNAT(dest string, port int, protocol v1.Protocol) error + // ClearEntries deletes conntrack entries for connections of the given IP family, + // filtered by the given filters. + ClearEntries(ipFamily uint8, filters ...netlink.CustomConntrackFilter) error } // netlinkHandler allows consuming real and mockable implementation for testing. @@ -69,110 +54,17 @@ func newConntracker(handler netlinkHandler) Interface { return &conntracker{handler: handler} } -// getNetlinkFamily returns the Netlink IP family constant -func getNetlinkFamily(isIPv6 bool) netlink.InetFamily { - if isIPv6 { - return unix.AF_INET6 +// ClearEntries deletes conntrack entries for connections of the given IP family, +// filtered by the given filters. +func (ct *conntracker) ClearEntries(ipFamily uint8, filters ...netlink.CustomConntrackFilter) error { + if len(filters) == 0 { + klog.V(7).InfoS("no conntrack filters provided") + return nil } - return unix.AF_INET -} -// protocolMap maps v1.Protocol to the Assigned Internet Protocol Number. -// https://www.iana.org/assignments/protocol-numbers/protocol-numbers.xhtml -var protocolMap = map[v1.Protocol]uint8{ - v1.ProtocolTCP: unix.IPPROTO_TCP, - v1.ProtocolUDP: unix.IPPROTO_UDP, - v1.ProtocolSCTP: unix.IPPROTO_SCTP, -} - -// ClearEntriesForIP delete the conntrack entries for connections specified by the -// destination IP(original direction). -func (ct *conntracker) ClearEntriesForIP(ip string, protocol v1.Protocol) error { - filter := &conntrackFilter{ - protocol: protocolMap[protocol], - original: &connectionTuple{ - dstIP: netutils.ParseIPSloppy(ip), - }, - } - klog.V(4).InfoS("Clearing conntrack entries", "org-dst", ip, "protocol", protocol) - - n, err := ct.handler.ConntrackDeleteFilters(netlink.ConntrackTable, getNetlinkFamily(netutils.IsIPv6String(ip)), filter) + n, err := ct.handler.ConntrackDeleteFilters(netlink.ConntrackTable, netlink.InetFamily(ipFamily), filters...) if err != nil { - // TODO: Better handling for deletion failure. When failure occur, stale udp connection may not get flushed. - // These stale udp connection will keep black hole traffic. Making this a best effort operation for now, since it - // is expensive to baby-sit all udp connections to kubernetes services. - return fmt.Errorf("error deleting connection tracking state for %s service IP: %s, error: %w", protocol, ip, err) - } - klog.V(4).InfoS("Cleared conntrack entries", "count", n) - return nil -} - -// ClearEntriesForPort delete the conntrack entries for connections specified by the -// destination Port(original direction) and IPFamily. -func (ct *conntracker) ClearEntriesForPort(port int, isIPv6 bool, protocol v1.Protocol) error { - filter := &conntrackFilter{ - protocol: protocolMap[protocol], - original: &connectionTuple{ - dstPort: uint16(port), - }, - } - if port <= 0 { - return fmt.Errorf("wrong port number. The port number must be greater than zero") - } - - klog.V(4).InfoS("Clearing conntrack entries", "org-port-dst", port, "protocol", protocol) - n, err := ct.handler.ConntrackDeleteFilters(netlink.ConntrackTable, getNetlinkFamily(isIPv6), filter) - if err != nil { - return fmt.Errorf("error deleting connection tracking state for %s port: %d, error: %w", protocol, port, err) - } - klog.V(4).InfoS("Cleared conntrack entries", "count", n) - return nil -} - -// ClearEntriesForNAT delete the conntrack entries for connections specified by the -// destination IP(original direction) and source IP(reply direction). -func (ct *conntracker) ClearEntriesForNAT(origin, dest string, protocol v1.Protocol) error { - filter := &conntrackFilter{ - protocol: protocolMap[protocol], - original: &connectionTuple{ - dstIP: netutils.ParseIPSloppy(origin), - }, - reply: &connectionTuple{ - srcIP: netutils.ParseIPSloppy(dest), - }, - } - - klog.V(4).InfoS("Clearing conntrack entries", "org-dst", origin, "reply-src", dest, "protocol", protocol) - n, err := ct.handler.ConntrackDeleteFilters(netlink.ConntrackTable, getNetlinkFamily(netutils.IsIPv6String(origin)), filter) - if err != nil { - // TODO: Better handling for deletion failure. When failure occur, stale udp connection may not get flushed. - // These stale udp connection will keep black hole traffic. Making this a best effort operation for now, since it - // is expensive to baby sit all udp connections to kubernetes services. - return fmt.Errorf("error deleting conntrack entries for %s peer {%s, %s}, error: %w", protocol, origin, dest, err) - } - klog.V(4).InfoS("Cleared conntrack entries", "count", n) - return nil -} - -// ClearEntriesForPortNAT delete the conntrack entries for connections specified by the -// destination Port(original direction) and source IP(reply direction). -func (ct *conntracker) ClearEntriesForPortNAT(dest string, port int, protocol v1.Protocol) error { - if port <= 0 { - return fmt.Errorf("wrong port number. The port number must be greater than zero") - } - filter := &conntrackFilter{ - protocol: protocolMap[protocol], - original: &connectionTuple{ - dstPort: uint16(port), - }, - reply: &connectionTuple{ - srcIP: netutils.ParseIPSloppy(dest), - }, - } - klog.V(4).InfoS("Clearing conntrack entries", "reply-src", dest, "org-port-dst", port, "protocol", protocol) - n, err := ct.handler.ConntrackDeleteFilters(netlink.ConntrackTable, getNetlinkFamily(netutils.IsIPv6String(dest)), filter) - if err != nil { - return fmt.Errorf("error deleting conntrack entries for %s port: %d, error: %w", protocol, port, err) + return fmt.Errorf("error deleting conntrack entries, error: %w", err) } klog.V(4).InfoS("Cleared conntrack entries", "count", n) return nil diff --git a/pkg/proxy/conntrack/conntrack_test.go b/pkg/proxy/conntrack/conntrack_test.go index 1721537fac4..729ac39b443 100644 --- a/pkg/proxy/conntrack/conntrack_test.go +++ b/pkg/proxy/conntrack/conntrack_test.go @@ -26,51 +26,63 @@ import ( "github.com/vishvananda/netlink" "golang.org/x/sys/unix" - v1 "k8s.io/api/core/v1" netutils "k8s.io/utils/net" ) type fakeHandler struct { tableType netlink.ConntrackTableType - family netlink.InetFamily - filter *conntrackFilter + ipFamily netlink.InetFamily + filters []*conntrackFilter } -func (f *fakeHandler) ConntrackDeleteFilters(tableType netlink.ConntrackTableType, family netlink.InetFamily, filters ...netlink.CustomConntrackFilter) (uint, error) { +func (f *fakeHandler) ConntrackDeleteFilters(tableType netlink.ConntrackTableType, family netlink.InetFamily, netlinkFilters ...netlink.CustomConntrackFilter) (uint, error) { f.tableType = tableType - f.family = family - f.filter = filters[0].(*conntrackFilter) - return 1, nil + f.ipFamily = family + f.filters = make([]*conntrackFilter, 0, len(netlinkFilters)) + for _, netlinkFilter := range netlinkFilters { + f.filters = append(f.filters, netlinkFilter.(*conntrackFilter)) + } + return uint(len(f.filters)), nil } var _ netlinkHandler = (*fakeHandler)(nil) -func TestConntracker_ClearEntriesForIP(t *testing.T) { +func TestConntracker_ClearEntries(t *testing.T) { + testCases := []struct { - name string - ip string - protocol v1.Protocol - expectedFamily netlink.InetFamily - expectedFilter *conntrackFilter + name string + ipFamily uint8 + filters []netlink.CustomConntrackFilter }{ { - name: "ipv4 + UDP", - ip: "10.96.0.10", - protocol: v1.ProtocolUDP, - expectedFamily: unix.AF_INET, - expectedFilter: &conntrackFilter{ - protocol: 17, - original: &connectionTuple{dstIP: netutils.ParseIPSloppy("10.96.0.10")}, + name: "single IPv6 filter", + ipFamily: unix.AF_INET6, + filters: []netlink.CustomConntrackFilter{ + &conntrackFilter{ + protocol: 17, + original: &connectionTuple{dstPort: 8000}, + reply: &connectionTuple{srcIP: netutils.ParseIPSloppy("2001:db8:1::2")}, + }, }, }, { - name: "ipv6 + TCP", - ip: "2001:db8:1::2", - protocol: v1.ProtocolTCP, - expectedFamily: unix.AF_INET6, - expectedFilter: &conntrackFilter{ - protocol: 6, - original: &connectionTuple{dstIP: netutils.ParseIPSloppy("2001:db8:1::2")}, + name: "multiple IPv4 filters", + ipFamily: unix.AF_INET, + filters: []netlink.CustomConntrackFilter{ + &conntrackFilter{ + protocol: 6, + original: &connectionTuple{dstPort: 3000}, + }, + &conntrackFilter{ + protocol: 17, + original: &connectionTuple{dstPort: 5000}, + reply: &connectionTuple{srcIP: netutils.ParseIPSloppy("10.244.0.3")}, + }, + &conntrackFilter{ + protocol: 132, + original: &connectionTuple{dstIP: netutils.ParseIPSloppy("10.96.0.10")}, + reply: &connectionTuple{srcIP: netutils.ParseIPSloppy("10.244.0.3")}, + }, }, }, } @@ -79,149 +91,13 @@ func TestConntracker_ClearEntriesForIP(t *testing.T) { t.Run(tc.name, func(t *testing.T) { handler := &fakeHandler{} ct := newConntracker(handler) - require.NoError(t, ct.ClearEntriesForIP(tc.ip, tc.protocol)) + require.NoError(t, ct.ClearEntries(tc.ipFamily, tc.filters...)) require.Equal(t, netlink.ConntrackTableType(netlink.ConntrackTable), handler.tableType) - require.Equal(t, tc.expectedFamily, handler.family) - require.Equal(t, tc.expectedFilter, handler.filter) - }) - } -} - -func TestConntracker_ClearEntriesForPort(t *testing.T) { - testCases := []struct { - name string - port int - isIPv6 bool - protocol v1.Protocol - expectedFamily netlink.InetFamily - expectedFilter *conntrackFilter - }{ - { - name: "ipv4 + UDP", - port: 5000, - isIPv6: false, - protocol: v1.ProtocolUDP, - expectedFamily: unix.AF_INET, - expectedFilter: &conntrackFilter{ - protocol: 17, - original: &connectionTuple{dstPort: 5000}, - }, - }, - { - name: "ipv6 + SCTP", - port: 3000, - isIPv6: true, - protocol: v1.ProtocolSCTP, - expectedFamily: unix.AF_INET6, - expectedFilter: &conntrackFilter{ - protocol: 132, - original: &connectionTuple{dstPort: 3000}, - }, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - handler := &fakeHandler{} - ct := newConntracker(handler) - require.NoError(t, ct.ClearEntriesForPort(tc.port, tc.isIPv6, tc.protocol)) - require.Equal(t, netlink.ConntrackTableType(netlink.ConntrackTable), handler.tableType) - require.Equal(t, tc.expectedFamily, handler.family) - require.Equal(t, tc.expectedFilter, handler.filter) - }) - } -} - -func TestConntracker_ClearEntriesForNAT(t *testing.T) { - testCases := []struct { - name string - src string - dest string - protocol v1.Protocol - expectedFamily netlink.InetFamily - expectedFilter *conntrackFilter - }{ - { - name: "ipv4 + SCTP", - src: "10.96.0.10", - dest: "10.244.0.3", - protocol: v1.ProtocolSCTP, - expectedFamily: unix.AF_INET, - expectedFilter: &conntrackFilter{ - protocol: 132, - original: &connectionTuple{dstIP: netutils.ParseIPSloppy("10.96.0.10")}, - reply: &connectionTuple{srcIP: netutils.ParseIPSloppy("10.244.0.3")}, - }, - }, - { - name: "ipv6 + UDP", - src: "2001:db8:1::2", - dest: "4001:ab8::2", - protocol: v1.ProtocolUDP, - expectedFamily: unix.AF_INET6, - expectedFilter: &conntrackFilter{ - protocol: 17, - original: &connectionTuple{dstIP: netutils.ParseIPSloppy("2001:db8:1::2")}, - reply: &connectionTuple{srcIP: netutils.ParseIPSloppy("4001:ab8::2")}, - }, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - handler := &fakeHandler{} - ct := newConntracker(handler) - require.NoError(t, ct.ClearEntriesForNAT(tc.src, tc.dest, tc.protocol)) - require.Equal(t, netlink.ConntrackTableType(netlink.ConntrackTable), handler.tableType) - require.Equal(t, tc.expectedFamily, handler.family) - require.Equal(t, tc.expectedFilter, handler.filter) - }) - } -} - -func TestConntracker_ClearEntriesForPortNAT(t *testing.T) { - testCases := []struct { - name string - ip string - port int - protocol v1.Protocol - expectedFamily netlink.InetFamily - expectedFilter *conntrackFilter - }{ - { - name: "ipv4 + TCP", - ip: "10.96.0.10", - port: 80, - protocol: v1.ProtocolTCP, - expectedFamily: unix.AF_INET, - expectedFilter: &conntrackFilter{ - protocol: 6, - original: &connectionTuple{dstPort: 80}, - reply: &connectionTuple{srcIP: netutils.ParseIPSloppy("10.96.0.10")}, - }, - }, - { - name: "ipv6 + UDP", - ip: "2001:db8:1::2", - port: 8000, - protocol: v1.ProtocolUDP, - expectedFamily: unix.AF_INET6, - expectedFilter: &conntrackFilter{ - protocol: 17, - original: &connectionTuple{dstPort: 8000}, - reply: &connectionTuple{srcIP: netutils.ParseIPSloppy("2001:db8:1::2")}, - }, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - handler := &fakeHandler{} - ct := newConntracker(handler) - require.NoError(t, ct.ClearEntriesForPortNAT(tc.ip, tc.port, tc.protocol)) - require.Equal(t, netlink.ConntrackTableType(netlink.ConntrackTable), handler.tableType) - require.Equal(t, tc.expectedFamily, handler.family) - require.Equal(t, tc.expectedFilter, handler.filter) + require.Equal(t, netlink.InetFamily(tc.ipFamily), handler.ipFamily) + require.Equal(t, len(tc.filters), len(handler.filters)) + for i := 0; i < len(tc.filters); i++ { + require.Equal(t, tc.filters[i], handler.filters[i]) + } }) } } diff --git a/pkg/proxy/conntrack/fake.go b/pkg/proxy/conntrack/fake.go index 2bc1f2ce518..a90aa3f2cfd 100644 --- a/pkg/proxy/conntrack/fake.go +++ b/pkg/proxy/conntrack/fake.go @@ -22,6 +22,8 @@ package conntrack import ( "fmt" + "github.com/vishvananda/netlink" + v1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/util/sets" ) @@ -51,48 +53,44 @@ func (fake *FakeInterface) Reset() { fake.ClearedPortNATs = make(map[int]string) } -// ClearEntriesForIP is part of Interface -func (fake *FakeInterface) ClearEntriesForIP(ip string, protocol v1.Protocol) error { - if protocol != v1.ProtocolUDP { - return fmt.Errorf("FakeInterface currently only supports UDP") - } +// ClearEntries is part of Interface +func (fake *FakeInterface) ClearEntries(_ uint8, filters ...netlink.CustomConntrackFilter) error { + for _, anyFilter := range filters { + filter := anyFilter.(*conntrackFilter) + if filter.protocol != protocolMap[v1.ProtocolUDP] { + return fmt.Errorf("FakeInterface currently only supports UDP") + } - fake.ClearedIPs.Insert(ip) - return nil -} - -// ClearEntriesForPort is part of Interface -func (fake *FakeInterface) ClearEntriesForPort(port int, isIPv6 bool, protocol v1.Protocol) error { - if protocol != v1.ProtocolUDP { - return fmt.Errorf("FakeInterface currently only supports UDP") - } - - fake.ClearedPorts.Insert(port) - return nil -} - -// ClearEntriesForNAT is part of Interface -func (fake *FakeInterface) ClearEntriesForNAT(origin, dest string, protocol v1.Protocol) error { - if protocol != v1.ProtocolUDP { - return fmt.Errorf("FakeInterface currently only supports UDP") - } - if previous, exists := fake.ClearedNATs[origin]; exists && previous != dest { - return fmt.Errorf("ClearEntriesForNAT called with same origin (%s), different destination (%s / %s)", origin, previous, dest) - } - - fake.ClearedNATs[origin] = dest - return nil -} - -// ClearEntriesForPortNAT is part of Interface -func (fake *FakeInterface) ClearEntriesForPortNAT(dest string, port int, protocol v1.Protocol) error { - if protocol != v1.ProtocolUDP { - return fmt.Errorf("FakeInterface currently only supports UDP") - } - if previous, exists := fake.ClearedPortNATs[port]; exists && previous != dest { - return fmt.Errorf("ClearEntriesForPortNAT called with same port (%d), different destination (%s / %s)", port, previous, dest) - } - - fake.ClearedPortNATs[port] = dest + // record IP and Port entries + if filter.original != nil && filter.reply == nil { + if filter.original.dstIP != nil { + fake.ClearedIPs.Insert(filter.original.dstIP.String()) + } + if filter.original.dstPort != 0 { + fake.ClearedPorts.Insert(int(filter.original.dstPort)) + } + } + + // record NAT and NATPort entries + if filter.original != nil && filter.reply != nil { + if filter.original.dstIP != nil && filter.reply.srcIP != nil { + origin := filter.original.dstIP.String() + dest := filter.reply.srcIP.String() + if previous, exists := fake.ClearedNATs[origin]; exists && previous != dest { + return fmt.Errorf("filter for NAT passed with same origin (%s), different destination (%s / %s)", origin, previous, dest) + } + fake.ClearedNATs[filter.original.dstIP.String()] = filter.reply.srcIP.String() + } + + if filter.original.dstPort != 0 && filter.reply.srcIP != nil { + dest := filter.reply.srcIP.String() + port := int(filter.original.dstPort) + if previous, exists := fake.ClearedPortNATs[port]; exists && previous != dest { + return fmt.Errorf("filter for PortNAT passed with same port (%d), different destination (%s / %s)", port, previous, dest) + } + fake.ClearedPortNATs[port] = dest + } + } + } return nil }