diff --git a/pkg/proxy/iptables/proxier.go b/pkg/proxy/iptables/proxier.go index 9ecc3ac5090..1109e096324 100644 --- a/pkg/proxy/iptables/proxier.go +++ b/pkg/proxy/iptables/proxier.go @@ -421,7 +421,7 @@ func CleanupLeftovers(ipt utiliptables.Interface) (encounteredError bool) { natChains.Write("*nat") // Start with chains we know we need to remove. for _, chain := range []utiliptables.Chain{kubeServicesChain, kubeNodePortsChain, kubePostroutingChain} { - if _, found := existingNATChains[chain]; found { + if existingNATChains.Has(chain) { chainString := string(chain) natChains.Write(utiliptables.MakeChainLine(chain)) // flush natRules.Write("-X", chainString) // delete @@ -457,7 +457,7 @@ func CleanupLeftovers(ipt utiliptables.Interface) (encounteredError bool) { filterRules := proxyutil.NewLineBuffer() filterChains.Write("*filter") for _, chain := range []utiliptables.Chain{kubeServicesChain, kubeExternalServicesChain, kubeForwardChain, kubeNodePortsChain} { - if _, found := existingFilterChains[chain]; found { + if existingFilterChains.Has(chain) { chainString := string(chain) filterChains.Write(utiliptables.MakeChainLine(chain)) filterRules.Write("-X", chainString) @@ -926,7 +926,7 @@ func (proxier *Proxier) syncProxyRules() { } // Accumulate NAT chains to keep. - activeNATChains := map[utiliptables.Chain]bool{} // use a map as a set + activeNATChains := sets.New[utiliptables.Chain]() // To avoid growing this slice, we arbitrarily set its size to 64, // there is never more than that many arguments for a single line. @@ -964,26 +964,13 @@ func (proxier *Proxier) syncProxyRules() { allEndpoints := proxier.endpointsMap[svcName] clusterEndpoints, localEndpoints, allLocallyReachableEndpoints, hasEndpoints := proxy.CategorizeEndpoints(allEndpoints, svcInfo, proxier.nodeLabels) - // Note the endpoint chains that will be used - for _, ep := range allLocallyReachableEndpoints { - if epInfo, ok := ep.(*endpointInfo); ok { - activeNATChains[epInfo.ChainName] = true - } - } - // clusterPolicyChain contains the endpoints used with "Cluster" traffic policy clusterPolicyChain := svcInfo.clusterPolicyChainName usesClusterPolicyChain := len(clusterEndpoints) > 0 && svcInfo.UsesClusterEndpoints() - if usesClusterPolicyChain { - activeNATChains[clusterPolicyChain] = true - } // localPolicyChain contains the endpoints used with "Local" traffic policy localPolicyChain := svcInfo.localPolicyChainName usesLocalPolicyChain := len(localEndpoints) > 0 && svcInfo.UsesLocalEndpoints() - if usesLocalPolicyChain { - activeNATChains[localPolicyChain] = true - } // internalPolicyChain is the chain containing the endpoints for // "internal" (ClusterIP) traffic. internalTrafficChain is the chain that @@ -1023,9 +1010,6 @@ func (proxier *Proxier) syncProxyRules() { // because we need the local-traffic-short-circuiting rules even when there // are no externally-usable endpoints. usesExternalTrafficChain := hasEndpoints && svcInfo.ExternallyAccessible() - if usesExternalTrafficChain { - activeNATChains[externalTrafficChain] = true - } // Traffic to LoadBalancer IPs can go directly to externalTrafficChain // unless LoadBalancerSourceRanges is in use in which case we will @@ -1034,7 +1018,6 @@ func (proxier *Proxier) syncProxyRules() { fwChain := svcInfo.firewallChainName usesFWChain := hasEndpoints && len(svcInfo.LoadBalancerVIPStrings()) > 0 && len(svcInfo.LoadBalancerSourceRanges()) > 0 if usesFWChain { - activeNATChains[fwChain] = true loadBalancerTrafficChain = fwChain } @@ -1203,10 +1186,9 @@ func (proxier *Proxier) syncProxyRules() { } // If the SVC/SVL/EXT/FW/SEP chains have not changed since the last sync - // then we can omit them from the restore input. (We have already marked - // them in activeNATChains, so they won't get deleted.) However, we have - // to still figure out how many chains we _would_ have written to make the - // metrics come out right, so we just compute them and throw them away. + // then we can omit them from the restore input. However, we have to still + // figure out how many chains we _would_ have written, to make the metrics + // come out right, so we just compute them and throw them away. if tryPartialSync && !serviceChanged.Has(svcName.NamespacedName.String()) && !endpointsChanged.Has(svcName.NamespacedName.String()) { natChains = skippedNatChains natRules = skippedNatRules @@ -1245,6 +1227,7 @@ func (proxier *Proxier) syncProxyRules() { // then jump to externalPolicyChain. if usesExternalTrafficChain { natChains.Write(utiliptables.MakeChainLine(externalTrafficChain)) + activeNATChains.Insert(externalTrafficChain) if !svcInfo.ExternalPolicyLocal() { // If we are using non-local endpoints we need to masquerade, @@ -1299,6 +1282,7 @@ func (proxier *Proxier) syncProxyRules() { // Set up firewall chain, if needed if usesFWChain { natChains.Write(utiliptables.MakeChainLine(fwChain)) + activeNATChains.Insert(fwChain) // The service firewall rules are created based on the // loadBalancerSourceRanges field. This only works for VIP-like @@ -1347,6 +1331,7 @@ func (proxier *Proxier) syncProxyRules() { // from clusterPolicyChain to the clusterEndpoints if usesClusterPolicyChain { natChains.Write(utiliptables.MakeChainLine(clusterPolicyChain)) + activeNATChains.Insert(clusterPolicyChain) proxier.writeServiceToEndpointRules(natRules, svcPortNameString, svcInfo, clusterPolicyChain, clusterEndpoints, args) } @@ -1354,6 +1339,7 @@ func (proxier *Proxier) syncProxyRules() { // from localPolicyChain to the localEndpoints if usesLocalPolicyChain { natChains.Write(utiliptables.MakeChainLine(localPolicyChain)) + activeNATChains.Insert(localPolicyChain) proxier.writeServiceToEndpointRules(natRules, svcPortNameString, svcInfo, localPolicyChain, localEndpoints, args) } @@ -1369,7 +1355,7 @@ func (proxier *Proxier) syncProxyRules() { // Create the endpoint chain natChains.Write(utiliptables.MakeChainLine(endpointChain)) - activeNATChains[endpointChain] = true + activeNATChains.Insert(endpointChain) args = append(args[:0], "-A", string(endpointChain)) args = proxier.appendServiceCommentLocked(args, svcPortNameString) @@ -1394,26 +1380,21 @@ func (proxier *Proxier) syncProxyRules() { // active rules, so they're harmless other than taking up memory.) deletedChains := 0 if !proxier.largeClusterMode || time.Since(proxier.lastIPTablesCleanup) > proxier.syncPeriod { - var existingNATChains map[utiliptables.Chain]struct{} - proxier.iptablesData.Reset() if err := proxier.iptables.SaveInto(utiliptables.TableNAT, proxier.iptablesData); err == nil { - existingNATChains = utiliptables.GetChainsFromTable(proxier.iptablesData.Bytes()) - - for chain := range existingNATChains { - if !activeNATChains[chain] { - chainString := string(chain) - if !isServiceChainName(chainString) { - // Ignore chains that aren't ours. - continue - } - // We must (as per iptables) write a chain-line - // for it, which has the nice effect of flushing - // the chain. Then we can remove the chain. - proxier.natChains.Write(utiliptables.MakeChainLine(chain)) - proxier.natRules.Write("-X", chainString) - deletedChains++ + existingNATChains := utiliptables.GetChainsFromTable(proxier.iptablesData.Bytes()) + for chain := range existingNATChains.Difference(activeNATChains) { + chainString := string(chain) + if !isServiceChainName(chainString) { + // Ignore chains that aren't ours. + continue } + // We must (as per iptables) write a chain-line + // for it, which has the nice effect of flushing + // the chain. Then we can remove the chain. + proxier.natChains.Write(utiliptables.MakeChainLine(chain)) + proxier.natRules.Write("-X", chainString) + deletedChains++ } proxier.lastIPTablesCleanup = time.Now() } else { diff --git a/pkg/util/iptables/save_restore.go b/pkg/util/iptables/save_restore.go index 38cc8c6c76f..d61de34c16a 100644 --- a/pkg/util/iptables/save_restore.go +++ b/pkg/util/iptables/save_restore.go @@ -19,6 +19,8 @@ package iptables import ( "bytes" "fmt" + + "k8s.io/apimachinery/pkg/util/sets" ) // MakeChainLine return an iptables-save/restore formatted chain line given a Chain @@ -27,10 +29,10 @@ func MakeChainLine(chain Chain) string { } // GetChainsFromTable parses iptables-save data to find the chains that are defined. It -// assumes that save contains a single table's data, and returns a map with keys for every +// assumes that save contains a single table's data, and returns a set with keys for every // chain defined in that table. -func GetChainsFromTable(save []byte) map[Chain]struct{} { - chainsMap := make(map[Chain]struct{}) +func GetChainsFromTable(save []byte) sets.Set[Chain] { + chainsSet := sets.New[Chain]() for { i := bytes.Index(save, []byte("\n:")) @@ -45,8 +47,8 @@ func GetChainsFromTable(save []byte) map[Chain]struct{} { break } chain := Chain(save[:end]) - chainsMap[chain] = struct{}{} + chainsSet.Insert(chain) save = save[end:] } - return chainsMap + return chainsSet } diff --git a/pkg/util/iptables/save_restore_test.go b/pkg/util/iptables/save_restore_test.go index 2a4c383e12f..1030f6194c5 100644 --- a/pkg/util/iptables/save_restore_test.go +++ b/pkg/util/iptables/save_restore_test.go @@ -20,19 +20,19 @@ import ( "testing" "github.com/lithammer/dedent" + + "k8s.io/apimachinery/pkg/util/sets" ) -func checkChains(t *testing.T, save []byte, expected map[Chain]struct{}) { +func checkChains(t *testing.T, save []byte, expected sets.Set[Chain]) { chains := GetChainsFromTable(save) - for chain := range expected { - if _, exists := chains[chain]; !exists { - t.Errorf("GetChainsFromTable expected chain not present: %s", chain) - } + missing := expected.Difference(chains) + if len(missing) != 0 { + t.Errorf("GetChainsFromTable expected chains not present: %v", missing.UnsortedList()) } - for chain := range chains { - if _, exists := expected[chain]; !exists { - t.Errorf("GetChainsFromTable chain unexpectedly present: %s", chain) - } + extra := chains.Difference(expected) + if len(extra) != 0 { + t.Errorf("GetChainsFromTable expected chains unexpectedly present: %v", extra.UnsortedList()) } } @@ -77,22 +77,23 @@ func TestGetChainsFromTable(t *testing.T) { -A KUBE-SVC-6666666666666666 -m comment --comment "kube-system/kube-dns:dns" -j KUBE-SVC-1111111111111111 COMMIT `) - expected := map[Chain]struct{}{ - ChainPrerouting: {}, - Chain("INPUT"): {}, - Chain("OUTPUT"): {}, - ChainPostrouting: {}, - Chain("DOCKER"): {}, - Chain("KUBE-NODEPORT-CONTAINER"): {}, - Chain("KUBE-NODEPORT-HOST"): {}, - Chain("KUBE-PORTALS-CONTAINER"): {}, - Chain("KUBE-PORTALS-HOST"): {}, - Chain("KUBE-SVC-1111111111111111"): {}, - Chain("KUBE-SVC-2222222222222222"): {}, - Chain("KUBE-SVC-3333333333333333"): {}, - Chain("KUBE-SVC-4444444444444444"): {}, - Chain("KUBE-SVC-5555555555555555"): {}, - Chain("KUBE-SVC-6666666666666666"): {}, - } + + expected := sets.New( + ChainPrerouting, + Chain("INPUT"), + Chain("OUTPUT"), + ChainPostrouting, + Chain("DOCKER"), + Chain("KUBE-NODEPORT-CONTAINER"), + Chain("KUBE-NODEPORT-HOST"), + Chain("KUBE-PORTALS-CONTAINER"), + Chain("KUBE-PORTALS-HOST"), + Chain("KUBE-SVC-1111111111111111"), + Chain("KUBE-SVC-2222222222222222"), + Chain("KUBE-SVC-3333333333333333"), + Chain("KUBE-SVC-4444444444444444"), + Chain("KUBE-SVC-5555555555555555"), + Chain("KUBE-SVC-6666666666666666"), + ) checkChains(t, []byte(iptablesSave), expected) }