diff --git a/pkg/proxy/iptables/proxier_test.go b/pkg/proxy/iptables/proxier_test.go index 1eb672e3756..7baa2ec2e45 100644 --- a/pkg/proxy/iptables/proxier_test.go +++ b/pkg/proxy/iptables/proxier_test.go @@ -424,6 +424,18 @@ func hasJump(rules []iptablestest.Rule, destChain, destIP string, destPort int) return match } +func hasSrcType(rules []iptablestest.Rule, srcType string) bool { + for _, r := range rules { + if r[iptablestest.SrcType] != srcType { + continue + } + + return true + } + + return false +} + func TestHasJump(t *testing.T) { testCases := map[string]struct { rules []iptablestest.Rule @@ -942,10 +954,6 @@ func TestOnlyLocalNodePorts(t *testing.T) { } func onlyLocalNodePorts(t *testing.T, fp *Proxier, ipt *iptablestest.FakeIPTables) { - // LB to SVC rule should always exist for local only since - // any traffic with `--src-type LOCAL` now routes to service chain - shouldLBTOSVCRuleExist := true - svcIP := "10.20.30.41" svcPort := 80 svcNodePort := 3001 @@ -1021,12 +1029,8 @@ func onlyLocalNodePorts(t *testing.T, fp *Proxier, ipt *iptablestest.FakeIPTable if hasJump(lbRules, nonLocalEpChain, "", 0) { errorf(fmt.Sprintf("Found jump from lb chain %v to non-local ep %v", lbChain, epStrLocal), lbRules, t) } - if hasJump(lbRules, svcChain, "", 0) != shouldLBTOSVCRuleExist { - prefix := "Did not find " - if !shouldLBTOSVCRuleExist { - prefix = "Found " - } - errorf(fmt.Sprintf("%s jump from lb chain %v to svc %v", prefix, lbChain, svcChain), lbRules, t) + if !hasJump(lbRules, svcChain, "", 0) || !hasSrcType(lbRules, "LOCAL") { + errorf(fmt.Sprintf("Did not find jump from lb chain %v to svc %v with src-type LOCAL", lbChain, svcChain), lbRules, t) } if !hasJump(lbRules, localEpChain, "", 0) { errorf(fmt.Sprintf("Didn't find jump from lb chain %v to local ep %v", lbChain, epStrLocal), lbRules, t) diff --git a/pkg/util/iptables/testing/fake.go b/pkg/util/iptables/testing/fake.go index cb504f90471..66adc1a275e 100644 --- a/pkg/util/iptables/testing/fake.go +++ b/pkg/util/iptables/testing/fake.go @@ -34,6 +34,7 @@ const ( ToDest = "--to-destination " Recent = "recent " MatchSet = "--match-set " + SrcType = "--src-type " ) type Rule map[string]string @@ -113,7 +114,7 @@ func (f *FakeIPTables) GetRules(chainName string) (rules []Rule) { for _, l := range strings.Split(string(f.Lines), "\n") { if strings.Contains(l, fmt.Sprintf("-A %v", chainName)) { newRule := Rule(map[string]string{}) - for _, arg := range []string{Destination, Source, DPort, Protocol, Jump, ToDest, Recent, MatchSet} { + for _, arg := range []string{Destination, Source, DPort, Protocol, Jump, ToDest, Recent, MatchSet, SrcType} { tok := getToken(l, arg) if tok != "" { newRule[arg] = tok