diff --git a/pkg/proxy/conntrack/conntrack.go b/pkg/proxy/conntrack/conntrack.go index 49aae1b94ea..f22dd7c38f4 100644 --- a/pkg/proxy/conntrack/conntrack.go +++ b/pkg/proxy/conntrack/conntrack.go @@ -21,13 +21,13 @@ package conntrack import ( "fmt" - "strconv" - "strings" + + "github.com/vishvananda/netlink" + "golang.org/x/sys/unix" v1 "k8s.io/api/core/v1" "k8s.io/klog/v2" - "k8s.io/utils/exec" - utilnet "k8s.io/utils/net" + netutils "k8s.io/utils/net" ) // Interface for dealing with conntrack @@ -49,95 +49,131 @@ type Interface interface { ClearEntriesForPortNAT(dest string, port int, protocol v1.Protocol) error } -// execCT implements Interface by execing the conntrack tool -type execCT struct { - execer exec.Interface +// netlinkHandler allows consuming real and mockable implementation for testing. +type netlinkHandler interface { + ConntrackDeleteFilters(netlink.ConntrackTableType, netlink.InetFamily, ...netlink.CustomConntrackFilter) (uint, error) } -var _ Interface = &execCT{} - -func NewExec(execer exec.Interface) Interface { - return &execCT{execer: execer} +// conntracker implements Interface by using netlink APIs. +type conntracker struct { + handler netlinkHandler } -// noConnectionToDelete is the error string returned by conntrack when no matching connections are found -const noConnectionToDelete = "0 flow entries have been deleted" +var _ Interface = &conntracker{} -func protoStr(proto v1.Protocol) string { - return strings.ToLower(string(proto)) +func New() Interface { + return newConntracker(&netlink.Handle{}) } -func parametersWithFamily(isIPv6 bool, parameters ...string) []string { +func newConntracker(handler netlinkHandler) Interface { + return &conntracker{handler: handler} +} + +// getNetlinkFamily returns the Netlink IP family constant +func getNetlinkFamily(isIPv6 bool) netlink.InetFamily { if isIPv6 { - parameters = append(parameters, "-f", "ipv6") + return unix.AF_INET6 } - return parameters + return unix.AF_INET } -// exec executes the conntrack tool using the given parameters -func (ct *execCT) exec(parameters ...string) error { - conntrackPath, err := ct.execer.LookPath("conntrack") - if err != nil { - return fmt.Errorf("error looking for path of conntrack: %v", err) - } - klog.V(4).InfoS("Clearing conntrack entries", "parameters", parameters) - output, err := ct.execer.Command(conntrackPath, parameters...).CombinedOutput() - if err != nil { - return fmt.Errorf("conntrack command returned: %q, error message: %s", string(output), err) - } - klog.V(4).InfoS("Conntrack entries deleted", "output", string(output)) - return nil +// protocolMap maps v1.Protocol to the Assigned Internet Protocol Number. +// https://www.iana.org/assignments/protocol-numbers/protocol-numbers.xhtml +var protocolMap = map[v1.Protocol]uint8{ + v1.ProtocolTCP: unix.IPPROTO_TCP, + v1.ProtocolUDP: unix.IPPROTO_UDP, + v1.ProtocolSCTP: unix.IPPROTO_SCTP, } -// ClearEntriesForIP is part of Interface -func (ct *execCT) ClearEntriesForIP(ip string, protocol v1.Protocol) error { - parameters := parametersWithFamily(utilnet.IsIPv6String(ip), "-D", "--orig-dst", ip, "-p", protoStr(protocol)) - err := ct.exec(parameters...) - if err != nil && !strings.Contains(err.Error(), noConnectionToDelete) { +// ClearEntriesForIP delete the conntrack entries for connections specified by the +// destination IP(original direction). +func (ct *conntracker) ClearEntriesForIP(ip string, protocol v1.Protocol) error { + filter := &conntrackFilter{ + protocol: protocolMap[protocol], + original: &connectionTuple{ + dstIP: netutils.ParseIPSloppy(ip), + }, + } + klog.V(4).InfoS("Clearing conntrack entries", "org-dst", ip, "protocol", protocol) + + n, err := ct.handler.ConntrackDeleteFilters(netlink.ConntrackTable, getNetlinkFamily(netutils.IsIPv6String(ip)), filter) + if err != nil { // TODO: Better handling for deletion failure. When failure occur, stale udp connection may not get flushed. // These stale udp connection will keep black hole traffic. Making this a best effort operation for now, since it // is expensive to baby-sit all udp connections to kubernetes services. - return fmt.Errorf("error deleting connection tracking state for UDP service IP: %s, error: %v", ip, err) + return fmt.Errorf("error deleting connection tracking state for %s service IP: %s, error: %w", protocol, ip, err) } + klog.V(4).InfoS("Cleared conntrack entries", "count", n) return nil } -// ClearEntriesForPort is part of Interface -func (ct *execCT) ClearEntriesForPort(port int, isIPv6 bool, protocol v1.Protocol) error { +// ClearEntriesForPort delete the conntrack entries for connections specified by the +// destination Port(original direction) and IPFamily. +func (ct *conntracker) ClearEntriesForPort(port int, isIPv6 bool, protocol v1.Protocol) error { + filter := &conntrackFilter{ + protocol: protocolMap[protocol], + original: &connectionTuple{ + dstPort: uint16(port), + }, + } if port <= 0 { return fmt.Errorf("wrong port number. The port number must be greater than zero") } - parameters := parametersWithFamily(isIPv6, "-D", "-p", protoStr(protocol), "--dport", strconv.Itoa(port)) - err := ct.exec(parameters...) - if err != nil && !strings.Contains(err.Error(), noConnectionToDelete) { - return fmt.Errorf("error deleting conntrack entries for UDP port: %d, error: %v", port, err) + + klog.V(4).InfoS("Clearing conntrack entries", "org-port-dst", port, "protocol", protocol) + n, err := ct.handler.ConntrackDeleteFilters(netlink.ConntrackTable, getNetlinkFamily(isIPv6), filter) + if err != nil { + return fmt.Errorf("error deleting connection tracking state for %s port: %d, error: %w", protocol, port, err) } + klog.V(4).InfoS("Cleared conntrack entries", "count", n) return nil } -// ClearEntriesForNAT is part of Interface -func (ct *execCT) ClearEntriesForNAT(origin, dest string, protocol v1.Protocol) error { - parameters := parametersWithFamily(utilnet.IsIPv6String(origin), "-D", "--orig-dst", origin, "--dst-nat", dest, - "-p", protoStr(protocol)) - err := ct.exec(parameters...) - if err != nil && !strings.Contains(err.Error(), noConnectionToDelete) { +// ClearEntriesForNAT delete the conntrack entries for connections specified by the +// destination IP(original direction) and source IP(reply direction). +func (ct *conntracker) ClearEntriesForNAT(origin, dest string, protocol v1.Protocol) error { + filter := &conntrackFilter{ + protocol: protocolMap[protocol], + original: &connectionTuple{ + dstIP: netutils.ParseIPSloppy(origin), + }, + reply: &connectionTuple{ + srcIP: netutils.ParseIPSloppy(dest), + }, + } + + klog.V(4).InfoS("Clearing conntrack entries", "org-dst", origin, "reply-src", dest, "protocol", protocol) + n, err := ct.handler.ConntrackDeleteFilters(netlink.ConntrackTable, getNetlinkFamily(netutils.IsIPv6String(origin)), filter) + if err != nil { // TODO: Better handling for deletion failure. When failure occur, stale udp connection may not get flushed. // These stale udp connection will keep black hole traffic. Making this a best effort operation for now, since it // is expensive to baby sit all udp connections to kubernetes services. - return fmt.Errorf("error deleting conntrack entries for UDP peer {%s, %s}, error: %v", origin, dest, err) + return fmt.Errorf("error deleting conntrack entries for %s peer {%s, %s}, error: %w", protocol, origin, dest, err) } + klog.V(4).InfoS("Cleared conntrack entries", "count", n) return nil } -// ClearEntriesForPortNAT is part of Interface -func (ct *execCT) ClearEntriesForPortNAT(dest string, port int, protocol v1.Protocol) error { +// ClearEntriesForPortNAT delete the conntrack entries for connections specified by the +// destination Port(original direction) and source IP(reply direction). +func (ct *conntracker) ClearEntriesForPortNAT(dest string, port int, protocol v1.Protocol) error { if port <= 0 { return fmt.Errorf("wrong port number. The port number must be greater than zero") } - parameters := parametersWithFamily(utilnet.IsIPv6String(dest), "-D", "-p", protoStr(protocol), "--dport", strconv.Itoa(port), "--dst-nat", dest) - err := ct.exec(parameters...) - if err != nil && !strings.Contains(err.Error(), noConnectionToDelete) { - return fmt.Errorf("error deleting conntrack entries for UDP port: %d, error: %v", port, err) + filter := &conntrackFilter{ + protocol: protocolMap[protocol], + original: &connectionTuple{ + dstPort: uint16(port), + }, + reply: &connectionTuple{ + srcIP: netutils.ParseIPSloppy(dest), + }, } + klog.V(4).InfoS("Clearing conntrack entries", "reply-src", dest, "org-port-dst", port, "protocol", protocol) + n, err := ct.handler.ConntrackDeleteFilters(netlink.ConntrackTable, getNetlinkFamily(netutils.IsIPv6String(dest)), filter) + if err != nil { + return fmt.Errorf("error deleting conntrack entries for %s port: %d, error: %w", protocol, port, err) + } + klog.V(4).InfoS("Cleared conntrack entries", "count", n) return nil } diff --git a/pkg/proxy/conntrack/conntrack_test.go b/pkg/proxy/conntrack/conntrack_test.go index 1d9e3fd9f3f..1721537fac4 100644 --- a/pkg/proxy/conntrack/conntrack_test.go +++ b/pkg/proxy/conntrack/conntrack_test.go @@ -20,253 +20,208 @@ limitations under the License. package conntrack import ( - "fmt" - "strings" "testing" + "github.com/stretchr/testify/require" + "github.com/vishvananda/netlink" + "golang.org/x/sys/unix" + v1 "k8s.io/api/core/v1" - "k8s.io/utils/exec" - fakeexec "k8s.io/utils/exec/testing" + netutils "k8s.io/utils/net" ) -var success = func() ([]byte, []byte, error) { return []byte("1 flow entries have been deleted"), nil, nil } -var nothingToDelete = func() ([]byte, []byte, error) { - return []byte(""), nil, fmt.Errorf("conntrack v1.4.2 (conntrack-tools): 0 flow entries have been deleted") +type fakeHandler struct { + tableType netlink.ConntrackTableType + family netlink.InetFamily + filter *conntrackFilter } -type testCT struct { - execCT - - fcmd *fakeexec.FakeCmd +func (f *fakeHandler) ConntrackDeleteFilters(tableType netlink.ConntrackTableType, family netlink.InetFamily, filters ...netlink.CustomConntrackFilter) (uint, error) { + f.tableType = tableType + f.family = family + f.filter = filters[0].(*conntrackFilter) + return 1, nil } -func makeCT(result fakeexec.FakeAction) *testCT { - fcmd := &fakeexec.FakeCmd{ - CombinedOutputScript: []fakeexec.FakeAction{result}, - } - fexec := &fakeexec.FakeExec{ - CommandScript: []fakeexec.FakeCommandAction{ - func(cmd string, args ...string) exec.Cmd { return fakeexec.InitFakeCmd(fcmd, cmd, args...) }, - }, - LookPathFunc: func(cmd string) (string, error) { return cmd, nil }, - } +var _ netlinkHandler = (*fakeHandler)(nil) - return &testCT{execCT{fexec}, fcmd} -} - -// Gets the command that ct executed. (If it didn't execute any commands, this will -// return "".) -func (ct *testCT) getExecutedCommand() string { - // FakeExec panics if you try to run more commands than you set it up for. So the - // only possibilities here are that we ran 1 command or we ran 0. - if ct.execer.(*fakeexec.FakeExec).CommandCalls != 1 { - return "" - } - return strings.Join(ct.fcmd.CombinedOutputLog[0], " ") -} - -func TestExec(t *testing.T) { +func TestConntracker_ClearEntriesForIP(t *testing.T) { testCases := []struct { - args []string - result fakeexec.FakeAction - expectErr bool + name string + ip string + protocol v1.Protocol + expectedFamily netlink.InetFamily + expectedFilter *conntrackFilter }{ { - args: []string{"-D", "-p", "udp", "-d", "10.0.240.1"}, - result: success, - expectErr: false, + name: "ipv4 + UDP", + ip: "10.96.0.10", + protocol: v1.ProtocolUDP, + expectedFamily: unix.AF_INET, + expectedFilter: &conntrackFilter{ + protocol: 17, + original: &connectionTuple{dstIP: netutils.ParseIPSloppy("10.96.0.10")}, + }, }, { - args: []string{"-D", "-p", "udp", "--orig-dst", "10.240.0.2", "--dst-nat", "10.0.10.2"}, - result: nothingToDelete, - expectErr: true, + name: "ipv6 + TCP", + ip: "2001:db8:1::2", + protocol: v1.ProtocolTCP, + expectedFamily: unix.AF_INET6, + expectedFilter: &conntrackFilter{ + protocol: 6, + original: &connectionTuple{dstIP: netutils.ParseIPSloppy("2001:db8:1::2")}, + }, }, } for _, tc := range testCases { - ct := makeCT(tc.result) - err := ct.exec(tc.args...) - if tc.expectErr { - if err == nil { - t.Errorf("expected err, got %v", err) - } - } else { - if err != nil { - t.Errorf("expected success, got %v", err) - } - } - - execCmd := ct.getExecutedCommand() - expectCmd := "conntrack " + strings.Join(tc.args, " ") - if execCmd != expectCmd { - t.Errorf("expect execute command: %s, but got: %s", expectCmd, execCmd) - } + t.Run(tc.name, func(t *testing.T) { + handler := &fakeHandler{} + ct := newConntracker(handler) + require.NoError(t, ct.ClearEntriesForIP(tc.ip, tc.protocol)) + require.Equal(t, netlink.ConntrackTableType(netlink.ConntrackTable), handler.tableType) + require.Equal(t, tc.expectedFamily, handler.family) + require.Equal(t, tc.expectedFilter, handler.filter) + }) } } -func TestClearEntriesForIP(t *testing.T) { +func TestConntracker_ClearEntriesForPort(t *testing.T) { testCases := []struct { - name string - ip string - - expectCommand string + name string + port int + isIPv6 bool + protocol v1.Protocol + expectedFamily netlink.InetFamily + expectedFilter *conntrackFilter }{ { - name: "IPv4", - ip: "10.240.0.3", - - expectCommand: "conntrack -D --orig-dst 10.240.0.3 -p udp", + name: "ipv4 + UDP", + port: 5000, + isIPv6: false, + protocol: v1.ProtocolUDP, + expectedFamily: unix.AF_INET, + expectedFilter: &conntrackFilter{ + protocol: 17, + original: &connectionTuple{dstPort: 5000}, + }, }, { - name: "IPv6", - ip: "2001:db8::10", - - expectCommand: "conntrack -D --orig-dst 2001:db8::10 -p udp -f ipv6", + name: "ipv6 + SCTP", + port: 3000, + isIPv6: true, + protocol: v1.ProtocolSCTP, + expectedFamily: unix.AF_INET6, + expectedFilter: &conntrackFilter{ + protocol: 132, + original: &connectionTuple{dstPort: 3000}, + }, }, } for _, tc := range testCases { - ct := makeCT(success) - if err := ct.ClearEntriesForIP(tc.ip, v1.ProtocolUDP); err != nil { - t.Errorf("%s/success: Unexpected error: %v", tc.name, err) - } - execCommand := ct.getExecutedCommand() - if tc.expectCommand != execCommand { - t.Errorf("%s/success: Expect command: %s, but executed %s", tc.name, tc.expectCommand, execCommand) - } - - ct = makeCT(nothingToDelete) - if err := ct.ClearEntriesForIP(tc.ip, v1.ProtocolUDP); err != nil { - t.Errorf("%s/nothing to delete: Unexpected error: %v", tc.name, err) - } + t.Run(tc.name, func(t *testing.T) { + handler := &fakeHandler{} + ct := newConntracker(handler) + require.NoError(t, ct.ClearEntriesForPort(tc.port, tc.isIPv6, tc.protocol)) + require.Equal(t, netlink.ConntrackTableType(netlink.ConntrackTable), handler.tableType) + require.Equal(t, tc.expectedFamily, handler.family) + require.Equal(t, tc.expectedFilter, handler.filter) + }) } } -func TestClearEntriesForPort(t *testing.T) { +func TestConntracker_ClearEntriesForNAT(t *testing.T) { testCases := []struct { - name string - port int - isIPv6 bool - - expectCommand string + name string + src string + dest string + protocol v1.Protocol + expectedFamily netlink.InetFamily + expectedFilter *conntrackFilter }{ { - name: "IPv4", - port: 8080, - isIPv6: false, - - expectCommand: "conntrack -D -p udp --dport 8080", + name: "ipv4 + SCTP", + src: "10.96.0.10", + dest: "10.244.0.3", + protocol: v1.ProtocolSCTP, + expectedFamily: unix.AF_INET, + expectedFilter: &conntrackFilter{ + protocol: 132, + original: &connectionTuple{dstIP: netutils.ParseIPSloppy("10.96.0.10")}, + reply: &connectionTuple{srcIP: netutils.ParseIPSloppy("10.244.0.3")}, + }, }, { - name: "IPv6", - port: 6666, - isIPv6: true, - - expectCommand: "conntrack -D -p udp --dport 6666 -f ipv6", + name: "ipv6 + UDP", + src: "2001:db8:1::2", + dest: "4001:ab8::2", + protocol: v1.ProtocolUDP, + expectedFamily: unix.AF_INET6, + expectedFilter: &conntrackFilter{ + protocol: 17, + original: &connectionTuple{dstIP: netutils.ParseIPSloppy("2001:db8:1::2")}, + reply: &connectionTuple{srcIP: netutils.ParseIPSloppy("4001:ab8::2")}, + }, }, } for _, tc := range testCases { - ct := makeCT(success) - err := ct.ClearEntriesForPort(tc.port, tc.isIPv6, v1.ProtocolUDP) - if err != nil { - t.Errorf("%s/success: Unexpected error: %v", tc.name, err) - } - execCommand := ct.getExecutedCommand() - if tc.expectCommand != execCommand { - t.Errorf("%s/success: Expect command: %s, but executed %s", tc.name, tc.expectCommand, execCommand) - } - - ct = makeCT(nothingToDelete) - err = ct.ClearEntriesForPort(tc.port, tc.isIPv6, v1.ProtocolUDP) - if err != nil { - t.Errorf("%s/nothing to delete: Unexpected error: %v", tc.name, err) - } + t.Run(tc.name, func(t *testing.T) { + handler := &fakeHandler{} + ct := newConntracker(handler) + require.NoError(t, ct.ClearEntriesForNAT(tc.src, tc.dest, tc.protocol)) + require.Equal(t, netlink.ConntrackTableType(netlink.ConntrackTable), handler.tableType) + require.Equal(t, tc.expectedFamily, handler.family) + require.Equal(t, tc.expectedFilter, handler.filter) + }) } } -func TestClearEntriesForNAT(t *testing.T) { +func TestConntracker_ClearEntriesForPortNAT(t *testing.T) { testCases := []struct { - name string - origin string - dest string - - expectCommand string + name string + ip string + port int + protocol v1.Protocol + expectedFamily netlink.InetFamily + expectedFilter *conntrackFilter }{ { - name: "IPv4", - origin: "1.2.3.4", - dest: "10.20.30.40", - - expectCommand: "conntrack -D --orig-dst 1.2.3.4 --dst-nat 10.20.30.40 -p udp", + name: "ipv4 + TCP", + ip: "10.96.0.10", + port: 80, + protocol: v1.ProtocolTCP, + expectedFamily: unix.AF_INET, + expectedFilter: &conntrackFilter{ + protocol: 6, + original: &connectionTuple{dstPort: 80}, + reply: &connectionTuple{srcIP: netutils.ParseIPSloppy("10.96.0.10")}, + }, }, { - name: "IPv6", - origin: "fd00::600d:f00d", - dest: "2001:db8::5", - - expectCommand: "conntrack -D --orig-dst fd00::600d:f00d --dst-nat 2001:db8::5 -p udp -f ipv6", + name: "ipv6 + UDP", + ip: "2001:db8:1::2", + port: 8000, + protocol: v1.ProtocolUDP, + expectedFamily: unix.AF_INET6, + expectedFilter: &conntrackFilter{ + protocol: 17, + original: &connectionTuple{dstPort: 8000}, + reply: &connectionTuple{srcIP: netutils.ParseIPSloppy("2001:db8:1::2")}, + }, }, } for _, tc := range testCases { - ct := makeCT(success) - err := ct.ClearEntriesForNAT(tc.origin, tc.dest, v1.ProtocolUDP) - if err != nil { - t.Errorf("%s/success: unexpected error: %v", tc.name, err) - } - execCommand := ct.getExecutedCommand() - if tc.expectCommand != execCommand { - t.Errorf("%s/success: Expect command: %s, but executed %s", tc.name, tc.expectCommand, execCommand) - } - - ct = makeCT(nothingToDelete) - err = ct.ClearEntriesForNAT(tc.origin, tc.dest, v1.ProtocolUDP) - if err != nil { - t.Errorf("%s/nothing to delete: unexpected error: %v", tc.name, err) - } - } -} - -func TestClearEntriesForPortNAT(t *testing.T) { - testCases := []struct { - name string - port int - dest string - - expectCommand string - }{ - { - name: "IPv4", - port: 30211, - dest: "1.2.3.4", - - expectCommand: "conntrack -D -p udp --dport 30211 --dst-nat 1.2.3.4", - }, - { - name: "IPv6", - port: 30212, - dest: "2600:5200::7800", - - expectCommand: "conntrack -D -p udp --dport 30212 --dst-nat 2600:5200::7800 -f ipv6", - }, - } - - for _, tc := range testCases { - ct := makeCT(success) - err := ct.ClearEntriesForPortNAT(tc.dest, tc.port, v1.ProtocolUDP) - if err != nil { - t.Errorf("%s/success: unexpected error: %v", tc.name, err) - } - execCommand := ct.getExecutedCommand() - if tc.expectCommand != execCommand { - t.Errorf("%s/success: Expect command: %s, but executed %s", tc.name, tc.expectCommand, execCommand) - } - - ct = makeCT(nothingToDelete) - err = ct.ClearEntriesForPortNAT(tc.dest, tc.port, v1.ProtocolUDP) - if err != nil { - t.Errorf("%s/nothing to delete: unexpected error: %v", tc.name, err) - } + t.Run(tc.name, func(t *testing.T) { + handler := &fakeHandler{} + ct := newConntracker(handler) + require.NoError(t, ct.ClearEntriesForPortNAT(tc.ip, tc.port, tc.protocol)) + require.Equal(t, netlink.ConntrackTableType(netlink.ConntrackTable), handler.tableType) + require.Equal(t, tc.expectedFamily, handler.family) + require.Equal(t, tc.expectedFilter, handler.filter) + }) } } diff --git a/pkg/proxy/conntrack/filter.go b/pkg/proxy/conntrack/filter.go new file mode 100644 index 00000000000..75dd3f0c57f --- /dev/null +++ b/pkg/proxy/conntrack/filter.go @@ -0,0 +1,101 @@ +//go:build linux +// +build linux + +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package conntrack + +import ( + "net" + + "github.com/vishvananda/netlink" + + "k8s.io/klog/v2" +) + +type connectionTuple struct { + srcIP net.IP + srcPort uint16 + dstIP net.IP + dstPort uint16 +} + +type conntrackFilter struct { + protocol uint8 + original *connectionTuple + reply *connectionTuple +} + +var _ netlink.CustomConntrackFilter = (*conntrackFilter)(nil) + +// MatchConntrackFlow applies the filter to the flow and returns true if the flow matches the filter +// false otherwise. +func (f *conntrackFilter) MatchConntrackFlow(flow *netlink.ConntrackFlow) bool { + // return false in case of empty filter + if f.protocol == 0 && f.original == nil && f.reply == nil { + return false + } + + // -p, --protonum proto [Layer 4 Protocol, eg. 'tcp'] + if f.protocol != 0 && f.protocol != flow.Forward.Protocol { + return false + } + + // filter on original direction + if f.original != nil { + // --orig-src ip [Source address from original direction] + if f.original.srcIP != nil && !f.original.srcIP.Equal(flow.Forward.SrcIP) { + return false + } + // --orig-dst ip [Destination address from original direction] + if f.original.dstIP != nil && !f.original.dstIP.Equal(flow.Forward.DstIP) { + return false + } + // --orig-port-src port [Source port from original direction] + if f.original.srcPort != 0 && f.original.srcPort != flow.Forward.SrcPort { + return false + } + // --orig-port-dst port [Destination port from original direction] + if f.original.dstPort != 0 && f.original.dstPort != flow.Forward.DstPort { + return false + } + } + + // filter on reply direction + if f.reply != nil { + // --reply-src ip [Source NAT ip] + if f.reply.srcIP != nil && !f.reply.srcIP.Equal(flow.Reverse.SrcIP) { + return false + } + // --reply-dst ip [Destination NAT ip] + if f.reply.dstIP != nil && !f.reply.dstIP.Equal(flow.Reverse.DstIP) { + return false + } + // --reply-port-src port [Source port from reply direction] + if f.reply.srcPort != 0 && f.reply.srcPort != flow.Reverse.SrcPort { + return false + } + // --reply-port-dst port [Destination port from reply direction] + if f.reply.dstPort != 0 && f.reply.dstPort != flow.Reverse.DstPort { + return false + } + } + + // appending a new line to the flow makes klog print multiline log which is easier to debug and understand. + klog.V(4).InfoS("Deleting conntrack entry", "flow", flow.String()+"\n") + return true +} diff --git a/pkg/proxy/conntrack/filter_test.go b/pkg/proxy/conntrack/filter_test.go new file mode 100644 index 00000000000..a8b8a4b3f87 --- /dev/null +++ b/pkg/proxy/conntrack/filter_test.go @@ -0,0 +1,172 @@ +//go:build linux +// +build linux + +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package conntrack + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/vishvananda/netlink" + "golang.org/x/sys/unix" + + netutils "k8s.io/utils/net" +) + +func applyFilter(flowList []netlink.ConntrackFlow, ipv4Filter *conntrackFilter, ipv6Filter *conntrackFilter) (ipv4Match, ipv6Match int) { + for _, flow := range flowList { + if ipv4Filter.MatchConntrackFlow(&flow) == true { + ipv4Match++ + } + if ipv6Filter.MatchConntrackFlow(&flow) == true { + ipv6Match++ + } + } + return ipv4Match, ipv6Match +} + +func TestConntrackFilter(t *testing.T) { + var flowList []netlink.ConntrackFlow + flow1 := netlink.ConntrackFlow{} + flow1.FamilyType = unix.AF_INET + flow1.Forward.SrcIP = netutils.ParseIPSloppy("10.0.0.1") + flow1.Forward.DstIP = netutils.ParseIPSloppy("20.0.0.1") + flow1.Forward.SrcPort = 1000 + flow1.Forward.DstPort = 2000 + flow1.Forward.Protocol = 17 + flow1.Reverse.SrcIP = netutils.ParseIPSloppy("20.0.0.1") + flow1.Reverse.DstIP = netutils.ParseIPSloppy("192.168.1.1") + flow1.Reverse.SrcPort = 2000 + flow1.Reverse.DstPort = 1000 + flow1.Reverse.Protocol = 17 + + flow2 := netlink.ConntrackFlow{} + flow2.FamilyType = unix.AF_INET + flow2.Forward.SrcIP = netutils.ParseIPSloppy("10.0.0.2") + flow2.Forward.DstIP = netutils.ParseIPSloppy("20.0.0.2") + flow2.Forward.SrcPort = 5000 + flow2.Forward.DstPort = 6000 + flow2.Forward.Protocol = 6 + flow2.Reverse.SrcIP = netutils.ParseIPSloppy("20.0.0.2") + flow2.Reverse.DstIP = netutils.ParseIPSloppy("192.168.1.1") + flow2.Reverse.SrcPort = 6000 + flow2.Reverse.DstPort = 5000 + flow2.Reverse.Protocol = 6 + + flow3 := netlink.ConntrackFlow{} + flow3.FamilyType = unix.AF_INET6 + flow3.Forward.SrcIP = netutils.ParseIPSloppy("eeee:eeee:eeee:eeee:eeee:eeee:eeee:eeee") + flow3.Forward.DstIP = netutils.ParseIPSloppy("dddd:dddd:dddd:dddd:dddd:dddd:dddd:dddd") + flow3.Forward.SrcPort = 1000 + flow3.Forward.DstPort = 2000 + flow3.Forward.Protocol = 132 + flow3.Reverse.SrcIP = netutils.ParseIPSloppy("dddd:dddd:dddd:dddd:dddd:dddd:dddd:dddd") + flow3.Reverse.DstIP = netutils.ParseIPSloppy("eeee:eeee:eeee:eeee:eeee:eeee:eeee:eeee") + flow3.Reverse.SrcPort = 2000 + flow3.Reverse.DstPort = 1000 + flow3.Reverse.Protocol = 132 + flowList = append(flowList, flow1, flow2, flow3) + + testCases := []struct { + name string + filterV4 *conntrackFilter + filterV6 *conntrackFilter + expectedV4Matches int + expectedV6Matches int + }{ + { + name: "Empty filter", + filterV4: &conntrackFilter{}, + filterV6: &conntrackFilter{}, + expectedV4Matches: 0, + expectedV6Matches: 0, + }, + { + name: "Protocol filter", + filterV4: &conntrackFilter{protocol: 6}, + filterV6: &conntrackFilter{protocol: 17}, + expectedV4Matches: 1, + expectedV6Matches: 1, + }, + { + name: "Original Source IP filter", + filterV4: &conntrackFilter{original: &connectionTuple{srcIP: netutils.ParseIPSloppy("10.0.0.1")}}, + filterV6: &conntrackFilter{original: &connectionTuple{srcIP: netutils.ParseIPSloppy("eeee:eeee:eeee:eeee:eeee:eeee:eeee:eeee")}}, + expectedV4Matches: 1, + expectedV6Matches: 1, + }, + { + name: "Original Destination IP filter", + filterV4: &conntrackFilter{original: &connectionTuple{dstIP: netutils.ParseIPSloppy("20.0.0.1")}}, + filterV6: &conntrackFilter{original: &connectionTuple{dstIP: netutils.ParseIPSloppy("dddd:dddd:dddd:dddd:dddd:dddd:dddd:dddd")}}, + expectedV4Matches: 1, + expectedV6Matches: 1, + }, + { + name: "Original Source Port Filter", + filterV4: &conntrackFilter{protocol: 6, original: &connectionTuple{srcPort: 5000}}, + filterV6: &conntrackFilter{protocol: 132, original: &connectionTuple{srcPort: 1000}}, + expectedV4Matches: 1, + expectedV6Matches: 1, + }, + { + name: "Original Destination Port Filter", + filterV4: &conntrackFilter{protocol: 6, original: &connectionTuple{dstPort: 6000}}, + filterV6: &conntrackFilter{protocol: 132, original: &connectionTuple{dstPort: 2000}}, + expectedV4Matches: 1, + expectedV6Matches: 1, + }, + { + name: "Reply Source IP filter", + filterV4: &conntrackFilter{reply: &connectionTuple{srcIP: netutils.ParseIPSloppy("20.0.0.1")}}, + filterV6: &conntrackFilter{reply: &connectionTuple{srcIP: netutils.ParseIPSloppy("dddd:dddd:dddd:dddd:dddd:dddd:dddd:dddd")}}, + expectedV4Matches: 1, + expectedV6Matches: 1, + }, + { + name: "Reply Destination IP filter", + filterV4: &conntrackFilter{reply: &connectionTuple{dstIP: netutils.ParseIPSloppy("192.168.1.1")}}, + filterV6: &conntrackFilter{reply: &connectionTuple{dstIP: netutils.ParseIPSloppy("dddd:dddd:dddd:dddd:dddd:dddd:dddd:dddd")}}, + expectedV4Matches: 2, + expectedV6Matches: 0, + }, + { + name: "Reply Source Port filter", + filterV4: &conntrackFilter{protocol: 17, reply: &connectionTuple{srcPort: 2000}}, + filterV6: &conntrackFilter{protocol: 132, reply: &connectionTuple{srcPort: 2000}}, + expectedV4Matches: 1, + expectedV6Matches: 1, + }, + { + name: "Reply Destination Port filter", + filterV4: &conntrackFilter{protocol: 6, reply: &connectionTuple{dstPort: 5000}}, + filterV6: &conntrackFilter{protocol: 132, reply: &connectionTuple{dstPort: 1000}}, + expectedV4Matches: 1, + expectedV6Matches: 1, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + v4Matches, v6Matches := applyFilter(flowList, tc.filterV4, tc.filterV6) + require.Equal(t, tc.expectedV4Matches, v4Matches) + require.Equal(t, tc.expectedV6Matches, v6Matches) + }) + } +} diff --git a/pkg/proxy/iptables/proxier.go b/pkg/proxy/iptables/proxier.go index e5e96eee480..a87c70ec4ae 100644 --- a/pkg/proxy/iptables/proxier.go +++ b/pkg/proxy/iptables/proxier.go @@ -294,7 +294,7 @@ func NewProxier(ctx context.Context, iptables: ipt, masqueradeAll: masqueradeAll, masqueradeMark: masqueradeMark, - conntrack: conntrack.NewExec(exec), + conntrack: conntrack.New(), nfacct: nfacctRunner, localDetector: localDetector, hostname: hostname, diff --git a/pkg/proxy/ipvs/proxier.go b/pkg/proxy/ipvs/proxier.go index eece2f0d2ab..8c9f8ff02da 100644 --- a/pkg/proxy/ipvs/proxier.go +++ b/pkg/proxy/ipvs/proxier.go @@ -385,7 +385,7 @@ func NewProxier( iptables: ipt, masqueradeAll: masqueradeAll, masqueradeMark: masqueradeMark, - conntrack: conntrack.NewExec(exec), + conntrack: conntrack.New(), localDetector: localDetector, hostname: hostname, nodeIP: nodeIP, diff --git a/pkg/proxy/nftables/proxier.go b/pkg/proxy/nftables/proxier.go index c770be081dd..7da83173ecc 100644 --- a/pkg/proxy/nftables/proxier.go +++ b/pkg/proxy/nftables/proxier.go @@ -53,7 +53,6 @@ import ( proxyutil "k8s.io/kubernetes/pkg/proxy/util" "k8s.io/kubernetes/pkg/util/async" utilkernel "k8s.io/kubernetes/pkg/util/kernel" - utilexec "k8s.io/utils/exec" netutils "k8s.io/utils/net" "k8s.io/utils/ptr" "sigs.k8s.io/knftables" @@ -256,7 +255,7 @@ func NewProxier(ctx context.Context, nftables: nft, masqueradeAll: masqueradeAll, masqueradeMark: masqueradeMark, - conntrack: conntrack.NewExec(utilexec.New()), + conntrack: conntrack.New(), localDetector: localDetector, hostname: hostname, nodeIP: nodeIP,