diff --git a/pkg/proxy/nftables/helpers_test.go b/pkg/proxy/nftables/helpers_test.go index f9e316c10d3..0ca7eca330e 100644 --- a/pkg/proxy/nftables/helpers_test.go +++ b/pkg/proxy/nftables/helpers_test.go @@ -19,6 +19,7 @@ package nftables import ( "context" "fmt" + "regexp" "runtime" "sort" "strings" @@ -29,6 +30,8 @@ import ( "github.com/lithammer/dedent" "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/util/sets" + netutils "k8s.io/utils/net" ) // getLine returns a string containing the file and line number of the caller, if @@ -55,6 +58,15 @@ var objectOrder = map[string]int{ // anything else: 0 } +// For most chains we leave the rules in order (because the order matters), but for chains +// with per-service rules, we don't know what order syncProxyRules is going to output them +// in, but the order doesn't matter anyway. So we sort the rules in those chains. +var sortedChains = sets.New( + kubeServicesFilterChain, + kubeExternalServicesChain, + kubeFirewallChain, +) + // sortNFTablesTransaction sorts an nftables transaction into a standard order for comparison func sortNFTablesTransaction(tx string) string { lines := strings.Split(tx, "\n") @@ -93,8 +105,13 @@ func sortNFTablesTransaction(tx string) string { return wi[4] < wj[4] } - // Leave rules in the order they were added in if wi[1] == "rule" { + // Sort rules in chains that need to be sorted + if sortedChains.Has(wi[4]) { + return li < lj + } + + // Otherwise leave rules in the order they were originally added. return false } @@ -146,6 +163,224 @@ func assertNFTablesChainEqual(t *testing.T, line string, nft *knftables.Fake, ch } } +// nftablesTracer holds data used while virtually tracing a packet through a set of +// nftables rules +type nftablesTracer struct { + nft *knftables.Fake + nodeIPs sets.Set[string] + t *testing.T + + // matches accumulates the list of rules that were matched, for debugging purposes. 
+ matches []string + + // outputs accumulates the list of matched terminal rule targets (endpoint + // IP:ports, or a special target like "REJECT") and is eventually used to generate + // the return value of tracePacket. + outputs []string + + // markMasq tracks whether the packet has been marked for masquerading + markMasq bool +} + +// newNFTablesTracer creates an nftablesTracer. nodeIPs are the IPs to treat as local node +// IPs (for determining whether rules with "fib saddr type local" or "fib daddr type +// local" match). +func newNFTablesTracer(t *testing.T, nft *knftables.Fake, nodeIPs []string) *nftablesTracer { + return &nftablesTracer{ + nft: nft, + nodeIPs: sets.New(nodeIPs...), + t: t, + } +} + +// addressMatches reports whether ipStr matches ruleAddress (an IP or a CIDR), inverted when not is "!= ". +func (tracer *nftablesTracer) addressMatches(ipStr, not, ruleAddress string) bool { + ip := netutils.ParseIPSloppy(ipStr) + if ip == nil { + tracer.t.Fatalf("Bad IP in test case: %s", ipStr) + } + + var match bool + if strings.Contains(ruleAddress, "/") { + _, cidr, err := netutils.ParseCIDRSloppy(ruleAddress) + if err != nil { + tracer.t.Fatalf("Bad CIDR in kube-proxy output: %v", err) // cidr is nil here; abort before dereferencing it + } + match = cidr.Contains(ip) + } else { + ip2 := netutils.ParseIPSloppy(ruleAddress) + if ip2 == nil { + tracer.t.Errorf("Bad IP/CIDR in kube-proxy output: %s", ruleAddress) + } + match = ip.Equal(ip2) + } + + if not == "!= " { + return !match + } + return match +} + +// We intentionally don't try to parse arbitrary nftables rules, as the syntax is quite +// complicated and context sensitive. (E.g., "ip daddr" could be the start of an address +// comparison, or it could be the start of a set/map lookup.) Instead, we just have +// regexps to recognize the specific pieces of rules that we create in proxier.go. +// Anything matching ignoredRegexp gets stripped out of the rule, and then what's left +// *must* match one of the cases in runChain or an error will be logged. 
In cases where +// the regexp doesn't end with `$`, and the matched rule succeeds against the input data, +// runChain will continue trying to match the rest of the rule. E.g., "ip daddr 10.0.0.1 +// drop" would first match destAddrRegexp, and then (assuming destIP was "10.0.0.1") would +// match verdictRegexp. + +var destAddrRegexp = regexp.MustCompile(`^ip6? daddr (!= )?(\S+)`) +var destAddrLocalRegexp = regexp.MustCompile(`^fib daddr type local`) +var destPortRegexp = regexp.MustCompile(`^(tcp|udp|sctp) dport (\d+)`) + +var jumpRegexp = regexp.MustCompile(`^(jump|goto) (\S+)$`) +var verdictRegexp = regexp.MustCompile(`^(drop|reject)$`) + +var ignoredRegexp = regexp.MustCompile(strings.Join( + []string{ + // Ignore comments (which can only appear at the end of a rule). + ` *comment "[^"]*"$`, + + // The trace tests only check new connections, so for our purposes, this + // check always succeeds (and thus can be ignored). + `^ct state new`, + + // Likewise, this rule never matches and thus never drops anything, and so + // can be ignored. + `^ct state invalid drop$`, + }, + "|", +)) + +// runChain runs the given packet through the rules in the named chain of tracer.nft's table, updating +// tracer's internal state accordingly. It returns true if it hits a terminal action or the chain is unknown. +func (tracer *nftablesTracer) runChain(chname, sourceIP, protocol, destIP, destPort string) bool { + ch := tracer.nft.Table.Chains[chname] + if ch == nil { + tracer.t.Errorf("unknown chain %q", chname) + return true + } + + for _, ruleObj := range ch.Rules { + rule := ignoredRegexp.ReplaceAllLiteralString(ruleObj.Rule, "") + for rule != "" { + rule = strings.TrimLeft(rule, " ") + + switch { + case destAddrRegexp.MatchString(rule): + // `^ip6? daddr (!= )?(\S+)` + // Tests whether destIP does/doesn't match a literal. 
+ match := destAddrRegexp.FindStringSubmatch(rule) + rule = strings.TrimPrefix(rule, match[0]) + not, ip := match[1], match[2] + if !tracer.addressMatches(destIP, not, ip) { + rule = "" + break + } + + case destAddrLocalRegexp.MatchString(rule): + // `^fib daddr type local` + // Tests whether destIP is a local IP. + match := destAddrLocalRegexp.FindStringSubmatch(rule) + rule = strings.TrimPrefix(rule, match[0]) + if !tracer.nodeIPs.Has(destIP) { + rule = "" + break + } + + case destPortRegexp.MatchString(rule): + // `^(tcp|udp|sctp) dport (\d+)` + // Tests whether destPort matches a literal. + match := destPortRegexp.FindStringSubmatch(rule) + rule = strings.TrimPrefix(rule, match[0]) + proto, port := match[1], match[2] + if protocol != proto || destPort != port { + rule = "" + break + } + + case jumpRegexp.MatchString(rule): + // `^(jump|goto) (\S+)$` + // Jumps to another chain. + match := jumpRegexp.FindStringSubmatch(rule) + rule = strings.TrimPrefix(rule, match[0]) + action, destChain := match[1], match[2] + + tracer.matches = append(tracer.matches, ruleObj.Rule) + terminated := tracer.runChain(destChain, sourceIP, protocol, destIP, destPort) + if terminated { + // destChain reached a terminal statement, so we + // terminate too. + return true + } else if action == "goto" { + // After a goto, return to our calling chain + // (without terminating) rather than continuing + // with this chain. + return false + } + + case verdictRegexp.MatchString(rule): + // `^(drop|reject)$` + // Drop/reject the packet and terminate processing. 
+ match := verdictRegexp.FindStringSubmatch(rule) + verdict := match[1] + + tracer.matches = append(tracer.matches, ruleObj.Rule) + tracer.outputs = append(tracer.outputs, strings.ToUpper(verdict)) + return true + + default: + tracer.t.Errorf("unmatched rule: %s", ruleObj.Rule) + rule = "" + } + } + } + + return false +} + +// tracePacket determines what would happen to a packet with the given sourceIP, destIP, +// and destPort, given the nftables rules held in nft. nodeIPs are the local node IPs (for +// rules matching "local"). (The protocol value should be lowercase as in nftables +// rules, not uppercase as in corev1.) +// +// The return values are: an array of matched rules (for debugging), the final packet +// destinations (a comma-separated list of IPs, or one of the special targets "ACCEPT", +// "DROP", or "REJECT"), and whether the packet would be masqueraded. +func tracePacket(t *testing.T, nft *knftables.Fake, sourceIP, protocol, destIP, destPort string, nodeIPs []string) ([]string, string, bool) { + tracer := newNFTablesTracer(t, nft, nodeIPs) + + // Collect "base chains" (ie, the chains that are run by netfilter directly rather + // than only being run when they are jumped to). Skip postrouting because it only + // does masquerading and we handle that separately. 
+ var baseChains []string + for chname, ch := range nft.Table.Chains { + if ch.Priority != nil && chname != "nat-postrouting" { + baseChains = append(baseChains, chname) + } + } + + // Sort by priority + sort.Slice(baseChains, func(i, j int) bool { + // FIXME: IPv4 vs IPv6 doesn't actually matter here + iprio, _ := knftables.ParsePriority(knftables.IPv4Family, string(*nft.Table.Chains[baseChains[i]].Priority)) + jprio, _ := knftables.ParsePriority(knftables.IPv4Family, string(*nft.Table.Chains[baseChains[j]].Priority)) + return iprio < jprio + }) + + for _, chname := range baseChains { + terminated := tracer.runChain(chname, sourceIP, protocol, destIP, destPort) + if terminated { + break + } + } + + return tracer.matches, strings.Join(tracer.outputs, ", "), tracer.markMasq +} + type packetFlowTest struct { name string sourceIP string @@ -158,7 +393,28 @@ type packetFlowTest struct { func runPacketFlowTests(t *testing.T, line string, nft *knftables.Fake, nodeIPs []string, testCases []packetFlowTest) { for _, tc := range testCases { - t.Logf("Skipping test %s which doesn't work yet", tc.name) + if tc.output != "DROP" && tc.output != "REJECT" && tc.output != "" { + t.Logf("Skipping test %s which doesn't work yet", tc.name) + continue + } + t.Run(tc.name, func(t *testing.T) { + protocol := strings.ToLower(string(tc.protocol)) + if protocol == "" { + protocol = "tcp" + } + matches, output, masq := tracePacket(t, nft, tc.sourceIP, protocol, tc.destIP, fmt.Sprintf("%d", tc.destPort), nodeIPs) + var errors []string + if output != tc.output { + errors = append(errors, fmt.Sprintf("wrong output: expected %q got %q", tc.output, output)) + } + if masq != tc.masq { + errors = append(errors, fmt.Sprintf("wrong masq: expected %v got %v", tc.masq, masq)) + } + if errors != nil { + t.Errorf("Test %q of a packet from %s to %s:%d%s got result:\n%s\n\nBy matching:\n%s\n\n", + tc.name, tc.sourceIP, tc.destIP, tc.destPort, line, strings.Join(errors, "\n"), strings.Join(matches, "\n")) + 
} + }) } } diff --git a/pkg/proxy/nftables/proxier.go b/pkg/proxy/nftables/proxier.go index 531a88764b4..b95a627dbe6 100644 --- a/pkg/proxy/nftables/proxier.go +++ b/pkg/proxy/nftables/proxier.go @@ -174,8 +174,7 @@ type Proxier struct { // The following buffers are used to reuse memory and avoid allocations // that are significantly impacting performance. - filterRules proxyutil.LineBuffer - natRules proxyutil.LineBuffer + natRules proxyutil.LineBuffer // conntrackTCPLiberal indicates whether the system sets the kernel nf_conntrack_tcp_be_liberal conntrackTCPLiberal bool @@ -263,7 +262,6 @@ func NewProxier(ipFamily v1.IPFamily, serviceHealthServer: serviceHealthServer, healthzServer: healthzServer, precomputedProbabilities: make([]string, 0, 1001), - filterRules: proxyutil.NewLineBuffer(), natRules: proxyutil.NewLineBuffer(), nodePortAddresses: nodePortAddresses, networkInterfacer: proxyutil.RealNetwork{}, @@ -842,9 +840,14 @@ func (proxier *Proxier) syncProxyRules() { tx := proxier.nftables.NewTransaction() proxier.setupNFTables(tx) + // We need to use, eg, "ip daddr" for IPv4 but "ip6 daddr" for IPv6 + ipX := "ip" + if proxier.ipFamily == v1.IPv6Protocol { + ipX = "ip6" + } + // Reset all buffers used later. // This is to avoid memory reallocations and thus improve performance. - proxier.filterRules.Reset() proxier.natRules.Reset() // Accumulate service/endpoint chains to keep. 
@@ -953,25 +956,25 @@ func (proxier *Proxier) syncProxyRules() { loadBalancerTrafficChain = fwChain } - var internalTrafficFilterTarget, internalTrafficFilterComment string - var externalTrafficFilterTarget, externalTrafficFilterComment string + var internalTrafficFilterVerdict, internalTrafficFilterComment string + var externalTrafficFilterVerdict, externalTrafficFilterComment string if !hasEndpoints { // The service has no endpoints at all; hasInternalEndpoints and // hasExternalEndpoints will also be false, and we will not // generate any chains in the "nat" table for the service; only // rules in the "filter" table rejecting incoming packets for // the service's IPs. - internalTrafficFilterTarget = "REJECT" - internalTrafficFilterComment = fmt.Sprintf(`"%s has no endpoints"`, svcPortNameString) - externalTrafficFilterTarget = "REJECT" + internalTrafficFilterVerdict = "reject" + internalTrafficFilterComment = fmt.Sprintf("%s has no endpoints", svcPortNameString) + externalTrafficFilterVerdict = "reject" externalTrafficFilterComment = internalTrafficFilterComment } else { if !hasInternalEndpoints { // The internalTrafficPolicy is "Local" but there are no local // endpoints. Traffic to the clusterIP will be dropped, but // external traffic may still be accepted. - internalTrafficFilterTarget = "DROP" - internalTrafficFilterComment = fmt.Sprintf(`"%s has no local endpoints"`, svcPortNameString) + internalTrafficFilterVerdict = "drop" + internalTrafficFilterComment = fmt.Sprintf("%s has no local endpoints", svcPortNameString) serviceNoLocalEndpointsTotalInternal++ } if !hasExternalEndpoints { @@ -979,8 +982,8 @@ func (proxier *Proxier) syncProxyRules() { // local endpoints. Traffic to "external" IPs from outside // the cluster will be dropped, but traffic from inside // the cluster may still be accepted. 
- externalTrafficFilterTarget = "DROP" - externalTrafficFilterComment = fmt.Sprintf(`"%s has no local endpoints"`, svcPortNameString) + externalTrafficFilterVerdict = "drop" + externalTrafficFilterComment = fmt.Sprintf("%s has no local endpoints", svcPortNameString) serviceNoLocalEndpointsTotalExternal++ } } @@ -996,14 +999,15 @@ func (proxier *Proxier) syncProxyRules() { "-j", string(internalTrafficChain)) } else { // No endpoints. - proxier.filterRules.Write( - "-A", string(kubeServicesFilterChain), - "-m", "comment", "--comment", internalTrafficFilterComment, - "-m", protocol, "-p", protocol, - "-d", svcInfo.ClusterIP().String(), - "--dport", strconv.Itoa(svcInfo.Port()), - "-j", internalTrafficFilterTarget, - ) + tx.Add(&knftables.Rule{ + Chain: kubeServicesFilterChain, + Rule: knftables.Concat( + ipX, "daddr", svcInfo.ClusterIP(), + protocol, "dport", svcInfo.Port(), + internalTrafficFilterVerdict, + ), + Comment: &internalTrafficFilterComment, + }) } // Capture externalIPs. @@ -1023,14 +1027,15 @@ func (proxier *Proxier) syncProxyRules() { // Either no endpoints at all (REJECT) or no endpoints for // external traffic (DROP anything that didn't get // short-circuited by the EXT chain.) 
- proxier.filterRules.Write( - "-A", string(kubeExternalServicesChain), - "-m", "comment", "--comment", externalTrafficFilterComment, - "-m", protocol, "-p", protocol, - "-d", externalIP, - "--dport", strconv.Itoa(svcInfo.Port()), - "-j", externalTrafficFilterTarget, - ) + tx.Add(&knftables.Rule{ + Chain: kubeExternalServicesChain, + Rule: knftables.Concat( + ipX, "daddr", externalIP, + protocol, "dport", svcInfo.Port(), + externalTrafficFilterVerdict, + ), + Comment: &externalTrafficFilterComment, + }) } } @@ -1047,13 +1052,16 @@ func (proxier *Proxier) syncProxyRules() { } if usesFWChain { - proxier.filterRules.Write( - "-A", string(kubeFirewallChain), - "-m", "comment", "--comment", fmt.Sprintf(`"%s traffic not accepted by %s"`, svcPortNameString, svcInfo.firewallChainName), - "-m", protocol, "-p", protocol, - "-d", lbip, - "--dport", strconv.Itoa(svcInfo.Port()), - "-j", "DROP") + comment := fmt.Sprintf("%s traffic not accepted by %s", svcPortNameString, svcInfo.firewallChainName) + tx.Add(&knftables.Rule{ + Chain: kubeFirewallChain, + Rule: knftables.Concat( + ipX, "daddr", lbip, + protocol, "dport", svcInfo.Port(), + "drop", + ), + Comment: &comment, + }) } } if !hasExternalEndpoints { @@ -1061,14 +1069,15 @@ func (proxier *Proxier) syncProxyRules() { // external traffic (DROP anything that didn't get short-circuited // by the EXT chain.) 
for _, lbip := range svcInfo.LoadBalancerVIPStrings() { - proxier.filterRules.Write( - "-A", string(kubeExternalServicesChain), - "-m", "comment", "--comment", externalTrafficFilterComment, - "-m", protocol, "-p", protocol, - "-d", lbip, - "--dport", strconv.Itoa(svcInfo.Port()), - "-j", externalTrafficFilterTarget, - ) + tx.Add(&knftables.Rule{ + Chain: kubeExternalServicesChain, + Rule: knftables.Concat( + ipX, "daddr", lbip, + protocol, "dport", svcInfo.Port(), + externalTrafficFilterVerdict, + ), + Comment: &externalTrafficFilterComment, + }) } } @@ -1089,14 +1098,15 @@ func (proxier *Proxier) syncProxyRules() { // Either no endpoints at all (REJECT) or no endpoints for // external traffic (DROP anything that didn't get // short-circuited by the EXT chain.) - proxier.filterRules.Write( - "-A", string(kubeExternalServicesChain), - "-m", "comment", "--comment", externalTrafficFilterComment, - "-m", "addrtype", "--dst-type", "LOCAL", - "-m", protocol, "-p", protocol, - "--dport", strconv.Itoa(svcInfo.NodePort()), - "-j", externalTrafficFilterTarget, - ) + tx.Add(&knftables.Rule{ + Chain: kubeExternalServicesChain, + Rule: knftables.Concat( + "fib daddr type local", + protocol, "dport", svcInfo.NodePort(), + externalTrafficFilterVerdict, + ), + Comment: &externalTrafficFilterComment, + }) } } @@ -1331,14 +1341,12 @@ func (proxier *Proxier) syncProxyRules() { klog.ErrorS(err, "Failed to list nftables chains: stale chains will not be deleted") } - metrics.IptablesRulesTotal.WithLabelValues(string(utiliptables.TableFilter)).Set(float64(proxier.filterRules.Lines())) metrics.IptablesRulesTotal.WithLabelValues(string(utiliptables.TableNAT)).Set(float64(proxier.natRules.Lines())) // Sync rules. 
klog.V(2).InfoS("Reloading service nftables data", "numServices", len(proxier.svcPortMap), "numEndpoints", totalEndpoints, - "numFilterRules", proxier.filterRules.Lines(), "numNATRules", proxier.natRules.Lines(), ) diff --git a/pkg/proxy/nftables/proxier_test.go b/pkg/proxy/nftables/proxier_test.go index d10a81c3aa7..660fea1f99e 100644 --- a/pkg/proxy/nftables/proxier_test.go +++ b/pkg/proxy/nftables/proxier_test.go @@ -321,7 +321,6 @@ func NewFakeProxier(ipFamily v1.IPFamily) (*knftables.Fake, *Proxier) { hostname: testHostname, serviceHealthServer: healthcheck.NewFakeServiceHealthServer(), precomputedProbabilities: make([]string, 0, 1001), - filterRules: proxyutil.NewLineBuffer(), natRules: proxyutil.NewLineBuffer(), nodeIP: netutils.ParseIPSloppy(testNodeIP), nodePortAddresses: proxyutil.NewNodePortAddresses(ipFamily, nil), @@ -540,6 +539,9 @@ func TestOverallNFTablesRules(t *testing.T) { add chain ip kube-proxy service-42NFTM6N-ns2/svc2/tcp/p80 add chain ip kube-proxy external-42NFTM6N-ns2/svc2/tcp/p80 add chain ip kube-proxy endpoint-SGOXE6O3-ns2/svc2/tcp/p80__10.180.0.2/80 + add rule ip kube-proxy external-services ip daddr 192.168.99.22 tcp dport 80 drop comment "ns2/svc2:p80 has no local endpoints" + add rule ip kube-proxy external-services ip daddr 1.2.3.4 tcp dport 80 drop comment "ns2/svc2:p80 has no local endpoints" + add rule ip kube-proxy external-services fib daddr type local tcp dport 3001 drop comment "ns2/svc2:p80 has no local endpoints" # svc3 add chain ip kube-proxy service-4AT6LBPK-ns3/svc3/tcp/p80 @@ -557,6 +559,10 @@ func TestOverallNFTablesRules(t *testing.T) { add chain ip kube-proxy external-HVFWP5L3-ns5/svc5/tcp/p80 add chain ip kube-proxy firewall-HVFWP5L3-ns5/svc5/tcp/p80 add chain ip kube-proxy endpoint-GTK6MW7G-ns5/svc5/tcp/p80__10.180.0.3/80 + add rule ip kube-proxy firewall ip daddr 5.6.7.8 tcp dport 80 drop comment "ns5/svc5:p80 traffic not accepted by firewall-HVFWP5L3-ns5/svc5/tcp/p80" + + # svc6 + add rule ip kube-proxy 
services-filter ip daddr 172.30.0.46 tcp dport 80 reject comment "ns6/svc6:p80 has no endpoints" `) assertNFTablesTransactionEqual(t, getLine(), expected, nft.Dump()) @@ -4397,6 +4403,8 @@ func TestSyncProxyRulesRepeated(t *testing.T) { add chain ip kube-proxy service-4AT6LBPK-ns3/svc3/tcp/p80 add chain ip kube-proxy endpoint-2OCDJSZQ-ns3/svc3/tcp/p80__10.0.3.1/80 + + add rule ip kube-proxy services-filter ip daddr 172.30.0.44 tcp dport 80 reject comment "ns4/svc4:p80 has no endpoints" `) assertNFTablesTransactionEqual(t, getLine(), expected, nft.Dump())