diff --git a/pkg/proxy/iptables/proxier.go b/pkg/proxy/iptables/proxier.go index 79802d635f8..2da5eb76895 100644 --- a/pkg/proxy/iptables/proxier.go +++ b/pkg/proxy/iptables/proxier.go @@ -1516,6 +1516,18 @@ func (proxier *Proxier) syncProxyRules() { ) writeLine(proxier.natRules, args...) } else { + // First write session affinity rules only over local endpoints, if applicable. + if svcInfo.sessionAffinityType == api.ServiceAffinityClientIP { + for _, endpointChain := range localEndpointChains { + writeLine(proxier.natRules, + "-A", string(svcXlbChain), + "-m", "comment", "--comment", svcNameString, + "-m", "recent", "--name", string(endpointChain), + "--rcheck", "--seconds", strconv.Itoa(svcInfo.stickyMaxAgeSeconds), "--reap", + "-j", string(endpointChain)) + } + } + // Setup probability filter rules only over local endpoints for i, endpointChain := range localEndpointChains { // Balancing rules in the per-service chain. diff --git a/pkg/proxy/iptables/proxier_test.go b/pkg/proxy/iptables/proxier_test.go index 2e929347246..261bb66f333 100644 --- a/pkg/proxy/iptables/proxier_test.go +++ b/pkg/proxy/iptables/proxier_test.go @@ -325,6 +325,15 @@ func NewFakeProxier(ipt utiliptables.Interface) *Proxier { return p } +func hasSessionAffinityRule(rules []iptablestest.Rule) bool { + for _, r := range rules { + if _, ok := r[iptablestest.Recent]; ok { + return true + } + } + return false +} + func hasJump(rules []iptablestest.Rule, destChain, destIP string, destPort int) bool { destPortStr := strconv.Itoa(destPort) match := false @@ -769,6 +778,7 @@ func TestOnlyLocalLoadBalancing(t *testing.T) { NamespacedName: makeNSN("ns1", "svc1"), Port: "p80", } + svcSessionAffinityTimeout := int32(10800) makeServiceMap(fp, makeTestService(svcPortName.Namespace, svcPortName.Name, func(svc *api.Service) { @@ -784,6 +794,10 @@ func TestOnlyLocalLoadBalancing(t *testing.T) { IP: svcLBIP, }} svc.Spec.ExternalTrafficPolicy = api.ServiceExternalTrafficPolicyTypeLocal + svc.Spec.SessionAffinity = api.ServiceAffinityClientIP + svc.Spec.SessionAffinityConfig = &api.SessionAffinityConfig{ + ClientIP: &api.ClientIPConfig{TimeoutSeconds: &svcSessionAffinityTimeout}, + } }), ) @@ -838,6 +852,9 @@ func TestOnlyLocalLoadBalancing(t *testing.T) { if !hasJump(lbRules, localEpChain, "", 0) { errorf(fmt.Sprintf("Didn't find jump from lb chain %v to local ep %v", lbChain, epStrNonLocal), lbRules, t) } + if !hasSessionAffinityRule(lbRules) { + errorf(fmt.Sprintf("Didn't find session affinity rule from lb chain %v", lbChain), lbRules, t) + } } func TestOnlyLocalNodePortsNoClusterCIDR(t *testing.T) { diff --git a/pkg/util/iptables/testing/fake.go b/pkg/util/iptables/testing/fake.go index 8d9ac7c0708..6f398597f77 100644 --- a/pkg/util/iptables/testing/fake.go +++ b/pkg/util/iptables/testing/fake.go @@ -32,6 +32,7 @@ const ( Jump = "-j " Reject = "REJECT" ToDest = "--to-destination " + Recent = "recent " ) type Rule map[string]string @@ -111,7 +112,7 @@ func (f *FakeIPTables) GetRules(chainName string) (rules []Rule) { for _, l := range strings.Split(string(f.Lines), "\n") { if strings.Contains(l, fmt.Sprintf("-A %v", chainName)) { newRule := Rule(map[string]string{}) - for _, arg := range []string{Destination, Source, DPort, Protocol, Jump, ToDest} { + for _, arg := range []string{Destination, Source, DPort, Protocol, Jump, ToDest, Recent} { tok := getToken(l, arg) if tok != "" { newRule[arg] = tok