diff --git a/pkg/proxy/iptables/proxier_test.go b/pkg/proxy/iptables/proxier_test.go index bbf55363d8c..a2248f23f77 100644 --- a/pkg/proxy/iptables/proxier_test.go +++ b/pkg/proxy/iptables/proxier_test.go @@ -5014,3 +5014,123 @@ COMMIT assertIPTablesRulesEqual(t, expected, fp.iptablesData.String()) } + +func countEndpointsAndComments(iptablesData string, matchEndpoint string) (string, int, int) { + var numEndpoints, numComments int + var matched string + for _, line := range strings.Split(iptablesData, "\n") { + if strings.HasPrefix(line, "-A KUBE-SEP-") && strings.Contains(line, "-j DNAT") { + numEndpoints++ + if strings.Contains(line, "--comment") { + numComments++ + } + if strings.Contains(line, matchEndpoint) { + matched = line + } + } + } + return matched, numEndpoints, numComments +} + +func TestEndpointCommentElision(t *testing.T) { + ipt := iptablestest.NewFake() + fp := NewFakeProxier(ipt) + fp.masqueradeAll = true + + makeServiceMap(fp, + makeTestService("ns1", "svc1", func(svc *v1.Service) { + svc.Spec.Type = v1.ServiceTypeClusterIP + svc.Spec.ClusterIP = "10.20.30.41" + svc.Spec.Ports = []v1.ServicePort{{ + Name: "p80", + Port: 80, + Protocol: v1.ProtocolTCP, + }} + }), + makeTestService("ns2", "svc2", func(svc *v1.Service) { + svc.Spec.Type = v1.ServiceTypeClusterIP + svc.Spec.ClusterIP = "10.20.30.42" + svc.Spec.Ports = []v1.ServicePort{{ + Name: "p8080", + Port: 8080, + Protocol: v1.ProtocolTCP, + }} + }), + makeTestService("ns3", "svc3", func(svc *v1.Service) { + svc.Spec.Type = v1.ServiceTypeClusterIP + svc.Spec.ClusterIP = "10.20.30.43" + svc.Spec.Ports = []v1.ServicePort{{ + Name: "p8081", + Port: 8081, + Protocol: v1.ProtocolTCP, + }} + }), + ) + + tcpProtocol := v1.ProtocolTCP + populateEndpointSlices(fp, + makeTestEndpointSlice("ns1", "svc1", 1, func(eps *discovery.EndpointSlice) { + eps.AddressType = discovery.AddressTypeIPv4 + eps.Endpoints = make([]discovery.Endpoint, endpointChainsNumberThreshold/2-1) + for i := range eps.Endpoints { + eps.Endpoints[i].Addresses = []string{fmt.Sprintf("10.0.%d.%d", i%256, i/256)} + } + eps.Ports = []discovery.EndpointPort{{ + Name: utilpointer.StringPtr("p80"), + Port: utilpointer.Int32(80), + Protocol: &tcpProtocol, + }} + }), + makeTestEndpointSlice("ns2", "svc2", 1, func(eps *discovery.EndpointSlice) { + eps.AddressType = discovery.AddressTypeIPv4 + eps.Endpoints = make([]discovery.Endpoint, endpointChainsNumberThreshold/2-1) + for i := range eps.Endpoints { + eps.Endpoints[i].Addresses = []string{fmt.Sprintf("10.1.%d.%d", i%256, i/256)} + } + eps.Ports = []discovery.EndpointPort{{ + Name: utilpointer.StringPtr("p8080"), + Port: utilpointer.Int32(8080), + Protocol: &tcpProtocol, + }} + }), + ) + + fp.syncProxyRules() + + expectedEndpoints := 2 * (endpointChainsNumberThreshold/2 - 1) + firstEndpoint, numEndpoints, numComments := countEndpointsAndComments(fp.iptablesData.String(), "10.0.0.0") + assert.Equal(t, "-A KUBE-SEP-DKGQUZGBKLTPAR56 -m comment --comment ns1/svc1:p80 -m tcp -p tcp -j DNAT --to-destination 10.0.0.0:80", firstEndpoint) + if numEndpoints != expectedEndpoints { + t.Errorf("Found wrong number of endpoints: expected %d, got %d", expectedEndpoints, numEndpoints) + } + if numComments != numEndpoints { + t.Errorf("numComments (%d) != numEndpoints (%d) when numEndpoints < threshold (%d)", numComments, numEndpoints, endpointChainsNumberThreshold) + } + + fp.OnEndpointSliceAdd(makeTestEndpointSlice("ns3", "svc3", 1, func(eps *discovery.EndpointSlice) { + eps.AddressType = discovery.AddressTypeIPv4 + eps.Endpoints = []discovery.Endpoint{{ + Addresses: []string{"1.2.3.4"}, + }, { + Addresses: []string{"5.6.7.8"}, + }, { + Addresses: []string{"9.10.11.12"}, + }} + eps.Ports = []discovery.EndpointPort{{ + Name: utilpointer.StringPtr("p8081"), + Port: utilpointer.Int32(8081), + Protocol: &tcpProtocol, + }} + })) + fp.syncProxyRules() + + expectedEndpoints += 3 + firstEndpoint, numEndpoints, numComments = countEndpointsAndComments(fp.iptablesData.String(), "10.0.0.0") + assert.Equal(t, "-A KUBE-SEP-DKGQUZGBKLTPAR56 -m tcp -p tcp -j DNAT --to-destination 10.0.0.0:80", firstEndpoint) + if numEndpoints != expectedEndpoints { + t.Errorf("Found wrong number of endpoints: expected %d, got %d", expectedEndpoints, numEndpoints) + } + if numComments != 0 { + t.Errorf("numComments (%d) != 0 when numEndpoints (%d) > threshold (%d)", numComments, numEndpoints, endpointChainsNumberThreshold) + } +}