From 8acf185791706142c630ed989d0adde4d770c636 Mon Sep 17 00:00:00 2001 From: Dan Winship Date: Wed, 29 Nov 2023 11:09:24 -0500 Subject: [PATCH] Use a generic Set for utiliptables.GetChainsFromTable --- pkg/proxy/iptables/proxier.go | 33 +++++++--------- pkg/util/iptables/save_restore.go | 12 +++--- pkg/util/iptables/save_restore_test.go | 53 +++++++++++++------------- 3 files changed, 48 insertions(+), 50 deletions(-) diff --git a/pkg/proxy/iptables/proxier.go b/pkg/proxy/iptables/proxier.go index ff51d5aacbc..1109e096324 100644 --- a/pkg/proxy/iptables/proxier.go +++ b/pkg/proxy/iptables/proxier.go @@ -421,7 +421,7 @@ func CleanupLeftovers(ipt utiliptables.Interface) (encounteredError bool) { natChains.Write("*nat") // Start with chains we know we need to remove. for _, chain := range []utiliptables.Chain{kubeServicesChain, kubeNodePortsChain, kubePostroutingChain} { - if _, found := existingNATChains[chain]; found { + if existingNATChains.Has(chain) { chainString := string(chain) natChains.Write(utiliptables.MakeChainLine(chain)) // flush natRules.Write("-X", chainString) // delete @@ -457,7 +457,7 @@ func CleanupLeftovers(ipt utiliptables.Interface) (encounteredError bool) { filterRules := proxyutil.NewLineBuffer() filterChains.Write("*filter") for _, chain := range []utiliptables.Chain{kubeServicesChain, kubeExternalServicesChain, kubeForwardChain, kubeNodePortsChain} { - if _, found := existingFilterChains[chain]; found { + if existingFilterChains.Has(chain) { chainString := string(chain) filterChains.Write(utiliptables.MakeChainLine(chain)) filterRules.Write("-X", chainString) @@ -1380,26 +1380,21 @@ func (proxier *Proxier) syncProxyRules() { // active rules, so they're harmless other than taking up memory.) deletedChains := 0 if !proxier.largeClusterMode || time.Since(proxier.lastIPTablesCleanup) > proxier.syncPeriod { - var existingNATChains map[utiliptables.Chain]struct{} - proxier.iptablesData.Reset() if err := proxier.iptables.SaveInto(utiliptables.TableNAT, proxier.iptablesData); err == nil { - existingNATChains = utiliptables.GetChainsFromTable(proxier.iptablesData.Bytes()) - - for chain := range existingNATChains { - if !activeNATChains.Has(chain) { - chainString := string(chain) - if !isServiceChainName(chainString) { - // Ignore chains that aren't ours. - continue - } - // We must (as per iptables) write a chain-line - // for it, which has the nice effect of flushing - // the chain. Then we can remove the chain. - proxier.natChains.Write(utiliptables.MakeChainLine(chain)) - proxier.natRules.Write("-X", chainString) - deletedChains++ + existingNATChains := utiliptables.GetChainsFromTable(proxier.iptablesData.Bytes()) + for chain := range existingNATChains.Difference(activeNATChains) { + chainString := string(chain) + if !isServiceChainName(chainString) { + // Ignore chains that aren't ours. + continue } + // We must (as per iptables) write a chain-line + // for it, which has the nice effect of flushing + // the chain. Then we can remove the chain. + proxier.natChains.Write(utiliptables.MakeChainLine(chain)) + proxier.natRules.Write("-X", chainString) + deletedChains++ } proxier.lastIPTablesCleanup = time.Now() } else { diff --git a/pkg/util/iptables/save_restore.go b/pkg/util/iptables/save_restore.go index b788beb9113..5f8b78e94ae 100644 --- a/pkg/util/iptables/save_restore.go +++ b/pkg/util/iptables/save_restore.go @@ -19,6 +19,8 @@ package iptables import ( "bytes" "fmt" + + "k8s.io/apimachinery/pkg/util/sets" ) // MakeChainLine return an iptables-save/restore formatted chain line given a Chain @@ -27,10 +29,10 @@ func MakeChainLine(chain Chain) string { } // GetChainsFromTable parses iptables-save data to find the chains that are defined. It -// assumes that save contains a single table's data, and returns a map with keys for every +// assumes that save contains a single table's data, and returns a set with keys for every // chain defined in that table. -func GetChainsFromTable(save []byte) map[Chain]struct{} { - chainsMap := make(map[Chain]struct{}) +func GetChainsFromTable(save []byte) sets.Set[Chain] { + chainsSet := sets.New[Chain]() for { i := bytes.Index(save, []byte("\n:")) @@ -45,8 +47,8 @@ func GetChainsFromTable(save []byte) map[Chain]struct{} { break } chain := Chain(save[:end]) - chainsMap[chain] = struct{}{} + chainsSet.Insert(chain) save = save[end:] } - return chainsMap + return chainsSet } diff --git a/pkg/util/iptables/save_restore_test.go b/pkg/util/iptables/save_restore_test.go index 2a4c383e12f..1030f6194c5 100644 --- a/pkg/util/iptables/save_restore_test.go +++ b/pkg/util/iptables/save_restore_test.go @@ -20,19 +20,19 @@ import ( "testing" "github.com/lithammer/dedent" + + "k8s.io/apimachinery/pkg/util/sets" ) -func checkChains(t *testing.T, save []byte, expected map[Chain]struct{}) { +func checkChains(t *testing.T, save []byte, expected sets.Set[Chain]) { chains := GetChainsFromTable(save) - for chain := range expected { - if _, exists := chains[chain]; !exists { - t.Errorf("GetChainsFromTable expected chain not present: %s", chain) - } + missing := expected.Difference(chains) + if len(missing) != 0 { + t.Errorf("GetChainsFromTable expected chains not present: %v", missing.UnsortedList()) } - for chain := range chains { - if _, exists := expected[chain]; !exists { - t.Errorf("GetChainsFromTable chain unexpectedly present: %s", chain) - } + extra := chains.Difference(expected) + if len(extra) != 0 { + t.Errorf("GetChainsFromTable expected chains unexpectedly present: %v", extra.UnsortedList()) } } @@ -77,22 +77,23 @@ func TestGetChainsFromTable(t *testing.T) { -A KUBE-SVC-6666666666666666 -m comment --comment "kube-system/kube-dns:dns" -j KUBE-SVC-1111111111111111 COMMIT `) - expected := map[Chain]struct{}{ - ChainPrerouting: {}, - Chain("INPUT"): {}, - Chain("OUTPUT"): {}, - ChainPostrouting: {}, - Chain("DOCKER"): {}, - Chain("KUBE-NODEPORT-CONTAINER"): {}, - Chain("KUBE-NODEPORT-HOST"): {}, - Chain("KUBE-PORTALS-CONTAINER"): {}, - Chain("KUBE-PORTALS-HOST"): {}, - Chain("KUBE-SVC-1111111111111111"): {}, - Chain("KUBE-SVC-2222222222222222"): {}, - Chain("KUBE-SVC-3333333333333333"): {}, - Chain("KUBE-SVC-4444444444444444"): {}, - Chain("KUBE-SVC-5555555555555555"): {}, - Chain("KUBE-SVC-6666666666666666"): {}, - } + + expected := sets.New( + ChainPrerouting, + Chain("INPUT"), + Chain("OUTPUT"), + ChainPostrouting, + Chain("DOCKER"), + Chain("KUBE-NODEPORT-CONTAINER"), + Chain("KUBE-NODEPORT-HOST"), + Chain("KUBE-PORTALS-CONTAINER"), + Chain("KUBE-PORTALS-HOST"), + Chain("KUBE-SVC-1111111111111111"), + Chain("KUBE-SVC-2222222222222222"), + Chain("KUBE-SVC-3333333333333333"), + Chain("KUBE-SVC-4444444444444444"), + Chain("KUBE-SVC-5555555555555555"), + Chain("KUBE-SVC-6666666666666666"), + ) checkChains(t, []byte(iptablesSave), expected) }