From 7cedc3d7417d67e71cbe9671d8d50a0f1c74e8d6 Mon Sep 17 00:00:00 2001 From: Dan Winship Date: Wed, 29 Nov 2023 10:58:48 -0500 Subject: [PATCH 1/2] Simplify creation/tracking of chains In the original version of "MinimizeIPTablesRestore", we skipped the bottom half of the sync loop when we weren't re-syncing a service, so certain things that couldn't be skipped had to be done in the top half. But the code was later changed to always run through the whole loop body (just not necessarily writing out rules in the bottom half), so we can reorganize things now to put some related bits of code back together. (In particular, this also resolves the fact that we were accidentally adding the endpoint chains to activeNATChains twice.) Also change activeNATChains to be a proper generic Set type. --- pkg/proxy/iptables/proxier.go | 34 ++++++++++------------------------ 1 file changed, 10 insertions(+), 24 deletions(-) diff --git a/pkg/proxy/iptables/proxier.go b/pkg/proxy/iptables/proxier.go index 9ecc3ac5090..ff51d5aacbc 100644 --- a/pkg/proxy/iptables/proxier.go +++ b/pkg/proxy/iptables/proxier.go @@ -926,7 +926,7 @@ func (proxier *Proxier) syncProxyRules() { } // Accumulate NAT chains to keep. - activeNATChains := map[utiliptables.Chain]bool{} // use a map as a set + activeNATChains := sets.New[utiliptables.Chain]() // To avoid growing this slice, we arbitrarily set its size to 64, // there is never more than that many arguments for a single line. @@ -964,26 +964,13 @@ func (proxier *Proxier) syncProxyRules() { allEndpoints := proxier.endpointsMap[svcName] clusterEndpoints, localEndpoints, allLocallyReachableEndpoints, hasEndpoints := proxy.CategorizeEndpoints(allEndpoints, svcInfo, proxier.nodeLabels) - // Note the endpoint chains that will be used - for _, ep := range allLocallyReachableEndpoints { - if epInfo, ok := ep.(*endpointInfo); ok { - activeNATChains[epInfo.ChainName] = true - } - } - // clusterPolicyChain contains the endpoints used with "Cluster" traffic policy clusterPolicyChain := svcInfo.clusterPolicyChainName usesClusterPolicyChain := len(clusterEndpoints) > 0 && svcInfo.UsesClusterEndpoints() - if usesClusterPolicyChain { - activeNATChains[clusterPolicyChain] = true - } // localPolicyChain contains the endpoints used with "Local" traffic policy localPolicyChain := svcInfo.localPolicyChainName usesLocalPolicyChain := len(localEndpoints) > 0 && svcInfo.UsesLocalEndpoints() - if usesLocalPolicyChain { - activeNATChains[localPolicyChain] = true - } // internalPolicyChain is the chain containing the endpoints for // "internal" (ClusterIP) traffic. internalTrafficChain is the chain that @@ -1023,9 +1010,6 @@ func (proxier *Proxier) syncProxyRules() { // because we need the local-traffic-short-circuiting rules even when there // are no externally-usable endpoints. usesExternalTrafficChain := hasEndpoints && svcInfo.ExternallyAccessible() - if usesExternalTrafficChain { - activeNATChains[externalTrafficChain] = true - } // Traffic to LoadBalancer IPs can go directly to externalTrafficChain // unless LoadBalancerSourceRanges is in use in which case we will @@ -1034,7 +1018,6 @@ func (proxier *Proxier) syncProxyRules() { fwChain := svcInfo.firewallChainName usesFWChain := hasEndpoints && len(svcInfo.LoadBalancerVIPStrings()) > 0 && len(svcInfo.LoadBalancerSourceRanges()) > 0 if usesFWChain { - activeNATChains[fwChain] = true loadBalancerTrafficChain = fwChain } @@ -1203,10 +1186,9 @@ func (proxier *Proxier) syncProxyRules() { } // If the SVC/SVL/EXT/FW/SEP chains have not changed since the last sync - // then we can omit them from the restore input. (We have already marked - // them in activeNATChains, so they won't get deleted.) However, we have - // to still figure out how many chains we _would_ have written to make the - // metrics come out right, so we just compute them and throw them away. + // then we can omit them from the restore input. However, we have to still + // figure out how many chains we _would_ have written, to make the metrics + // come out right, so we just compute them and throw them away. if tryPartialSync && !serviceChanged.Has(svcName.NamespacedName.String()) && !endpointsChanged.Has(svcName.NamespacedName.String()) { natChains = skippedNatChains natRules = skippedNatRules @@ -1245,6 +1227,7 @@ func (proxier *Proxier) syncProxyRules() { // then jump to externalPolicyChain. if usesExternalTrafficChain { natChains.Write(utiliptables.MakeChainLine(externalTrafficChain)) + activeNATChains.Insert(externalTrafficChain) if !svcInfo.ExternalPolicyLocal() { // If we are using non-local endpoints we need to masquerade, @@ -1299,6 +1282,7 @@ func (proxier *Proxier) syncProxyRules() { // Set up firewall chain, if needed if usesFWChain { natChains.Write(utiliptables.MakeChainLine(fwChain)) + activeNATChains.Insert(fwChain) // The service firewall rules are created based on the // loadBalancerSourceRanges field. This only works for VIP-like @@ -1347,6 +1331,7 @@ func (proxier *Proxier) syncProxyRules() { // from clusterPolicyChain to the clusterEndpoints if usesClusterPolicyChain { natChains.Write(utiliptables.MakeChainLine(clusterPolicyChain)) + activeNATChains.Insert(clusterPolicyChain) proxier.writeServiceToEndpointRules(natRules, svcPortNameString, svcInfo, clusterPolicyChain, clusterEndpoints, args) } @@ -1354,6 +1339,7 @@ func (proxier *Proxier) syncProxyRules() { // from localPolicyChain to the localEndpoints if usesLocalPolicyChain { natChains.Write(utiliptables.MakeChainLine(localPolicyChain)) + activeNATChains.Insert(localPolicyChain) proxier.writeServiceToEndpointRules(natRules, svcPortNameString, svcInfo, localPolicyChain, localEndpoints, args) } @@ -1369,7 +1355,7 @@ func (proxier *Proxier) syncProxyRules() { // Create the endpoint chain natChains.Write(utiliptables.MakeChainLine(endpointChain)) - activeNATChains[endpointChain] = true + activeNATChains.Insert(endpointChain) args = append(args[:0], "-A", string(endpointChain)) args = proxier.appendServiceCommentLocked(args, svcPortNameString) @@ -1401,7 +1387,7 @@ func (proxier *Proxier) syncProxyRules() { existingNATChains = utiliptables.GetChainsFromTable(proxier.iptablesData.Bytes()) for chain := range existingNATChains { - if !activeNATChains[chain] { + if !activeNATChains.Has(chain) { chainString := string(chain) if !isServiceChainName(chainString) { // Ignore chains that aren't ours. From 8acf185791706142c630ed989d0adde4d770c636 Mon Sep 17 00:00:00 2001 From: Dan Winship Date: Wed, 29 Nov 2023 11:09:24 -0500 Subject: [PATCH 2/2] Use a generic Set for utiliptables.GetChainsFromTable --- pkg/proxy/iptables/proxier.go | 33 +++++++--------- pkg/util/iptables/save_restore.go | 12 +++--- pkg/util/iptables/save_restore_test.go | 53 +++++++++++++------------- 3 files changed, 48 insertions(+), 50 deletions(-) diff --git a/pkg/proxy/iptables/proxier.go b/pkg/proxy/iptables/proxier.go index ff51d5aacbc..1109e096324 100644 --- a/pkg/proxy/iptables/proxier.go +++ b/pkg/proxy/iptables/proxier.go @@ -421,7 +421,7 @@ func CleanupLeftovers(ipt utiliptables.Interface) (encounteredError bool) { natChains.Write("*nat") // Start with chains we know we need to remove. for _, chain := range []utiliptables.Chain{kubeServicesChain, kubeNodePortsChain, kubePostroutingChain} { - if _, found := existingNATChains[chain]; found { + if existingNATChains.Has(chain) { chainString := string(chain) natChains.Write(utiliptables.MakeChainLine(chain)) // flush natRules.Write("-X", chainString) // delete @@ -457,7 +457,7 @@ func CleanupLeftovers(ipt utiliptables.Interface) (encounteredError bool) { filterRules := proxyutil.NewLineBuffer() filterChains.Write("*filter") for _, chain := range []utiliptables.Chain{kubeServicesChain, kubeExternalServicesChain, kubeForwardChain, kubeNodePortsChain} { - if _, found := existingFilterChains[chain]; found { + if existingFilterChains.Has(chain) { chainString := string(chain) filterChains.Write(utiliptables.MakeChainLine(chain)) filterRules.Write("-X", chainString) @@ -1380,26 +1380,21 @@ func (proxier *Proxier) syncProxyRules() { // active rules, so they're harmless other than taking up memory.) deletedChains := 0 if !proxier.largeClusterMode || time.Since(proxier.lastIPTablesCleanup) > proxier.syncPeriod { - var existingNATChains map[utiliptables.Chain]struct{} - proxier.iptablesData.Reset() if err := proxier.iptables.SaveInto(utiliptables.TableNAT, proxier.iptablesData); err == nil { - existingNATChains = utiliptables.GetChainsFromTable(proxier.iptablesData.Bytes()) - - for chain := range existingNATChains { - if !activeNATChains.Has(chain) { - chainString := string(chain) - if !isServiceChainName(chainString) { - // Ignore chains that aren't ours. - continue - } - // We must (as per iptables) write a chain-line - // for it, which has the nice effect of flushing - // the chain. Then we can remove the chain. - proxier.natChains.Write(utiliptables.MakeChainLine(chain)) - proxier.natRules.Write("-X", chainString) - deletedChains++ + existingNATChains := utiliptables.GetChainsFromTable(proxier.iptablesData.Bytes()) + for chain := range existingNATChains.Difference(activeNATChains) { + chainString := string(chain) + if !isServiceChainName(chainString) { + // Ignore chains that aren't ours. + continue } + // We must (as per iptables) write a chain-line + // for it, which has the nice effect of flushing + // the chain. Then we can remove the chain. + proxier.natChains.Write(utiliptables.MakeChainLine(chain)) + proxier.natRules.Write("-X", chainString) + deletedChains++ } proxier.lastIPTablesCleanup = time.Now() } else { diff --git a/pkg/util/iptables/save_restore.go b/pkg/util/iptables/save_restore.go index b788beb9113..5f8b78e94ae 100644 --- a/pkg/util/iptables/save_restore.go +++ b/pkg/util/iptables/save_restore.go @@ -19,6 +19,8 @@ package iptables import ( "bytes" "fmt" + + "k8s.io/apimachinery/pkg/util/sets" ) // MakeChainLine return an iptables-save/restore formatted chain line given a Chain @@ -27,10 +29,10 @@ func MakeChainLine(chain Chain) string { } // GetChainsFromTable parses iptables-save data to find the chains that are defined. It -// assumes that save contains a single table's data, and returns a map with keys for every +// assumes that save contains a single table's data, and returns a set with keys for every // chain defined in that table. -func GetChainsFromTable(save []byte) map[Chain]struct{} { - chainsMap := make(map[Chain]struct{}) +func GetChainsFromTable(save []byte) sets.Set[Chain] { + chainsSet := sets.New[Chain]() for { i := bytes.Index(save, []byte("\n:")) @@ -45,8 +47,8 @@ func GetChainsFromTable(save []byte) map[Chain]struct{} { break } chain := Chain(save[:end]) - chainsMap[chain] = struct{}{} + chainsSet.Insert(chain) save = save[end:] } - return chainsMap + return chainsSet } diff --git a/pkg/util/iptables/save_restore_test.go b/pkg/util/iptables/save_restore_test.go index 2a4c383e12f..1030f6194c5 100644 --- a/pkg/util/iptables/save_restore_test.go +++ b/pkg/util/iptables/save_restore_test.go @@ -20,19 +20,19 @@ import ( "testing" "github.com/lithammer/dedent" + + "k8s.io/apimachinery/pkg/util/sets" ) -func checkChains(t *testing.T, save []byte, expected map[Chain]struct{}) { +func checkChains(t *testing.T, save []byte, expected sets.Set[Chain]) { chains := GetChainsFromTable(save) - for chain := range expected { - if _, exists := chains[chain]; !exists { - t.Errorf("GetChainsFromTable expected chain not present: %s", chain) - } + missing := expected.Difference(chains) + if len(missing) != 0 { + t.Errorf("GetChainsFromTable expected chains not present: %v", missing.UnsortedList()) } - for chain := range chains { - if _, exists := expected[chain]; !exists { - t.Errorf("GetChainsFromTable chain unexpectedly present: %s", chain) - } + extra := chains.Difference(expected) + if len(extra) != 0 { + t.Errorf("GetChainsFromTable expected chains unexpectedly present: %v", extra.UnsortedList()) } } @@ -77,22 +77,23 @@ func TestGetChainsFromTable(t *testing.T) { -A KUBE-SVC-6666666666666666 -m comment --comment "kube-system/kube-dns:dns" -j KUBE-SVC-1111111111111111 COMMIT `) - expected := map[Chain]struct{}{ - ChainPrerouting: {}, - Chain("INPUT"): {}, - Chain("OUTPUT"): {}, - ChainPostrouting: {}, - Chain("DOCKER"): {}, - Chain("KUBE-NODEPORT-CONTAINER"): {}, - Chain("KUBE-NODEPORT-HOST"): {}, - Chain("KUBE-PORTALS-CONTAINER"): {}, - Chain("KUBE-PORTALS-HOST"): {}, - Chain("KUBE-SVC-1111111111111111"): {}, - Chain("KUBE-SVC-2222222222222222"): {}, - Chain("KUBE-SVC-3333333333333333"): {}, - Chain("KUBE-SVC-4444444444444444"): {}, - Chain("KUBE-SVC-5555555555555555"): {}, - Chain("KUBE-SVC-6666666666666666"): {}, - } + + expected := sets.New( + ChainPrerouting, + Chain("INPUT"), + Chain("OUTPUT"), + ChainPostrouting, + Chain("DOCKER"), + Chain("KUBE-NODEPORT-CONTAINER"), + Chain("KUBE-NODEPORT-HOST"), + Chain("KUBE-PORTALS-CONTAINER"), + Chain("KUBE-PORTALS-HOST"), + Chain("KUBE-SVC-1111111111111111"), + Chain("KUBE-SVC-2222222222222222"), + Chain("KUBE-SVC-3333333333333333"), + Chain("KUBE-SVC-4444444444444444"), + Chain("KUBE-SVC-5555555555555555"), + Chain("KUBE-SVC-6666666666666666"), + ) checkChains(t, []byte(iptablesSave), expected) }