add unit tests for -src-type=LOCAL from LB chain

Signed-off-by: Andrew Sy Kim <kiman@vmware.com>
This commit is contained in:
Andrew Sy Kim 2019-05-07 15:13:37 -04:00
parent b926fb9d2b
commit 8dfd4def99
2 changed files with 16 additions and 11 deletions

View File

@ -424,6 +424,18 @@ func hasJump(rules []iptablestest.Rule, destChain, destIP string, destPort int)
return match
}
func hasSrcType(rules []iptablestest.Rule, srcType string) bool {
for _, r := range rules {
if r[iptablestest.SrcType] != srcType {
continue
}
return true
}
return false
}
func TestHasJump(t *testing.T) {
testCases := map[string]struct {
rules []iptablestest.Rule
@ -942,10 +954,6 @@ func TestOnlyLocalNodePorts(t *testing.T) {
}
func onlyLocalNodePorts(t *testing.T, fp *Proxier, ipt *iptablestest.FakeIPTables) {
// LB to SVC rule should always exist for local only since
// any traffic with `--src-type LOCAL` now routes to service chain
shouldLBTOSVCRuleExist := true
svcIP := "10.20.30.41"
svcPort := 80
svcNodePort := 3001
@ -1021,12 +1029,8 @@ func onlyLocalNodePorts(t *testing.T, fp *Proxier, ipt *iptablestest.FakeIPTable
if hasJump(lbRules, nonLocalEpChain, "", 0) {
errorf(fmt.Sprintf("Found jump from lb chain %v to non-local ep %v", lbChain, epStrLocal), lbRules, t)
}
if hasJump(lbRules, svcChain, "", 0) != shouldLBTOSVCRuleExist {
prefix := "Did not find "
if !shouldLBTOSVCRuleExist {
prefix = "Found "
}
errorf(fmt.Sprintf("%s jump from lb chain %v to svc %v", prefix, lbChain, svcChain), lbRules, t)
if !hasJump(lbRules, svcChain, "", 0) || !hasSrcType(lbRules, "LOCAL") {
errorf(fmt.Sprintf("Did not find jump from lb chain %v to svc %v with src-type LOCAL", lbChain, svcChain), lbRules, t)
}
if !hasJump(lbRules, localEpChain, "", 0) {
errorf(fmt.Sprintf("Didn't find jump from lb chain %v to local ep %v", lbChain, epStrLocal), lbRules, t)

View File

@ -34,6 +34,7 @@ const (
ToDest = "--to-destination "
Recent = "recent "
MatchSet = "--match-set "
SrcType = "--src-type "
)
type Rule map[string]string
@ -113,7 +114,7 @@ func (f *FakeIPTables) GetRules(chainName string) (rules []Rule) {
for _, l := range strings.Split(string(f.Lines), "\n") {
if strings.Contains(l, fmt.Sprintf("-A %v", chainName)) {
newRule := Rule(map[string]string{})
for _, arg := range []string{Destination, Source, DPort, Protocol, Jump, ToDest, Recent, MatchSet} {
for _, arg := range []string{Destination, Source, DPort, Protocol, Jump, ToDest, Recent, MatchSet, SrcType} {
tok := getToken(l, arg)
if tok != "" {
newRule[arg] = tok