diff --git a/pkg/proxy/iptables/proxier.go b/pkg/proxy/iptables/proxier.go index 8d1bade3146..23920a55cb4 100644 --- a/pkg/proxy/iptables/proxier.go +++ b/pkg/proxy/iptables/proxier.go @@ -166,26 +166,9 @@ type endpointsInfo struct { chainName utiliptables.Chain } -// Returns just the IP part of an IP or IP:port or endpoint string. If the IP -// part is an IPv6 address enclosed in brackets (e.g. "[fd00:1::5]:9999"), -// then the brackets are stripped as well. -func ipPart(s string) string { - if ip := net.ParseIP(s); ip != nil { - // IP address without port - return s - } - // Must be IP:port - ip, _, err := net.SplitHostPort(s) - if err != nil { - glog.Errorf("Error parsing '%s': %v", s, err) - return "" - } - return ip -} - -// Returns just the IP part of the endpoint. +// IPPart returns just the IP part of the endpoint. func (e *endpointsInfo) IPPart() string { - return ipPart(e.endpoint) + return utilproxy.IPPart(e.endpoint) } // Returns the endpoint chain name for a given endpointsInfo. @@ -944,7 +927,7 @@ type endpointServicePair struct { } func (esp *endpointServicePair) IPPart() string { - return ipPart(esp.endpoint) + return utilproxy.IPPart(esp.endpoint) } // After a UDP endpoint has been removed, we must flush any pending conntrack entries to it, or else we @@ -953,7 +936,7 @@ func (esp *endpointServicePair) IPPart() string { func (proxier *Proxier) deleteEndpointConnections(connectionMap map[endpointServicePair]bool) { for epSvcPair := range connectionMap { if svcInfo, ok := proxier.serviceMap[epSvcPair.servicePortName]; ok && svcInfo.protocol == api.ProtocolUDP { - endpointIP := epSvcPair.endpoint[0:strings.Index(epSvcPair.endpoint, ":")] + endpointIP := utilproxy.IPPart(epSvcPair.endpoint) err := utilproxy.ClearUDPConntrackForPeers(proxier.exec, svcInfo.clusterIP.String(), endpointIP) if err != nil { glog.Errorf("Failed to delete %s endpoint connections, error: %v", epSvcPair.servicePortName.String(), err) @@ -962,16 +945,6 @@ func (proxier *Proxier) deleteEndpointConnections(connectionMap map[endpointServ } } -// hostAddress returns a host address of the form /32 for -// IPv4 and /128 for IPv6 -func hostAddress(ip net.IP) string { - len := 32 - if ip.To4() == nil { - len = 128 - } - return fmt.Sprintf("%s/%d", ip.String(), len) -} - // This is where all of the iptables-save/restore calls happen. // The only other iptables rules are those that are setup in iptablesInit() // This assumes proxier.mu is NOT held @@ -1189,7 +1162,7 @@ func (proxier *Proxier) syncProxyRules() { "-A", string(kubeServicesChain), "-m", "comment", "--comment", fmt.Sprintf(`"%s cluster IP"`, svcNameString), "-m", protocol, "-p", protocol, - "-d", hostAddress(svcInfo.clusterIP), + "-d", utilproxy.ToCIDR(svcInfo.clusterIP), "--dport", strconv.Itoa(svcInfo.port), ) if proxier.masqueradeAll { @@ -1243,7 +1216,7 @@ func (proxier *Proxier) syncProxyRules() { "-A", string(kubeServicesChain), "-m", "comment", "--comment", fmt.Sprintf(`"%s external IP"`, svcNameString), "-m", protocol, "-p", protocol, - "-d", hostAddress(net.ParseIP(externalIP)), + "-d", utilproxy.ToCIDR(net.ParseIP(externalIP)), "--dport", strconv.Itoa(svcInfo.port), ) // We have to SNAT packets to external IPs. @@ -1269,7 +1242,7 @@ func (proxier *Proxier) syncProxyRules() { "-A", string(kubeServicesChain), "-m", "comment", "--comment", fmt.Sprintf(`"%s has no endpoints"`, svcNameString), "-m", protocol, "-p", protocol, - "-d", hostAddress(net.ParseIP(externalIP)), + "-d", utilproxy.ToCIDR(net.ParseIP(externalIP)), "--dport", strconv.Itoa(svcInfo.port), "-j", "REJECT", ) @@ -1295,7 +1268,7 @@ func (proxier *Proxier) syncProxyRules() { "-A", string(kubeServicesChain), "-m", "comment", "--comment", fmt.Sprintf(`"%s loadbalancer IP"`, svcNameString), "-m", protocol, "-p", protocol, - "-d", hostAddress(net.ParseIP(ingress.IP)), + "-d", utilproxy.ToCIDR(net.ParseIP(ingress.IP)), "--dport", strconv.Itoa(svcInfo.port), ) // jump to service firewall chain @@ -1333,7 +1306,7 @@ func (proxier *Proxier) syncProxyRules() { // loadbalancer's backend hosts. In this case, request will not hit the loadbalancer but loop back directly. // Need to add the following rule to allow request on host. if allowFromNode { - writeLine(proxier.natRules, append(args, "-s", hostAddress(net.ParseIP(ingress.IP)), "-j", string(chosenChain))...) + writeLine(proxier.natRules, append(args, "-s", utilproxy.ToCIDR(net.ParseIP(ingress.IP)), "-j", string(chosenChain))...) } } @@ -1417,7 +1390,7 @@ func (proxier *Proxier) syncProxyRules() { "-A", string(kubeServicesChain), "-m", "comment", "--comment", fmt.Sprintf(`"%s has no endpoints"`, svcNameString), "-m", protocol, "-p", protocol, - "-d", hostAddress(svcInfo.clusterIP), + "-d", utilproxy.ToCIDR(svcInfo.clusterIP), "--dport", strconv.Itoa(svcInfo.port), "-j", "REJECT", ) @@ -1489,7 +1462,7 @@ func (proxier *Proxier) syncProxyRules() { ) // Handle traffic that loops back to the originator with SNAT. writeLine(proxier.natRules, append(args, - "-s", hostAddress(net.ParseIP(epIP)), + "-s", utilproxy.ToCIDR(net.ParseIP(epIP)), "-j", string(KubeMarkMasqChain))...) // Update client-affinity lists. if svcInfo.sessionAffinityType == api.ServiceAffinityClientIP { diff --git a/pkg/proxy/iptables/proxier_test.go b/pkg/proxy/iptables/proxier_test.go index 7ba8e13594e..aba1b5d58ee 100644 --- a/pkg/proxy/iptables/proxier_test.go +++ b/pkg/proxy/iptables/proxier_test.go @@ -56,52 +56,6 @@ func checkAllLines(t *testing.T, table utiliptables.Table, save []byte, expected } } -func TestIpPart(t *testing.T) { - const noError = "" - - testCases := []struct { - endpoint string - expectedIP string - expectedError string - }{ - {"1.2.3.4", "1.2.3.4", noError}, - {"1.2.3.4:9999", "1.2.3.4", noError}, - {"2001:db8::1:1", "2001:db8::1:1", noError}, - {"[2001:db8::2:2]:9999", "2001:db8::2:2", noError}, - {"1.2.3.4::9999", "", "too many colons"}, - {"1.2.3.4:[0]", "", "unexpected '[' in address"}, - } - - for _, tc := range testCases { - ip := ipPart(tc.endpoint) - if tc.expectedError == noError { - if ip != tc.expectedIP { - t.Errorf("Unexpected IP for %s: Expected: %s, Got %s", tc.endpoint, tc.expectedIP, ip) - } - } else if ip != "" { - t.Errorf("Error did not occur for %s, expected: '%s' error", tc.endpoint, tc.expectedError) - } - } -} - -func TestHostAddress(t *testing.T) { - testCases := []struct { - ip string - expectedAddr string - }{ - {"1.2.3.4", "1.2.3.4/32"}, - {"2001:db8::1:1", "2001:db8::1:1/128"}, - } - - for _, tc := range testCases { - ip := net.ParseIP(tc.ip) - addr := hostAddress(ip) - if addr != tc.expectedAddr { - t.Errorf("Unexpected host address for %s: Expected: %s, Got %s", tc.ip, tc.expectedAddr, addr) - } - } -} - func TestReadLinesFromByteBuffer(t *testing.T) { testFn := func(byteArray []byte, expected []string) { index := 0 @@ -272,6 +226,10 @@ func TestDeleteEndpointConnections(t *testing.T) { endpoint: "10.240.0.5:80", servicePortName: svc2, }, + { + endpoint: "[fd00:1::5]:8080", + servicePortName: svc2, + }, } expectCommandExecCount := 0 @@ -281,7 +239,7 @@ func TestDeleteEndpointConnections(t *testing.T) { svcInfo := fakeProxier.serviceMap[testCases[i].servicePortName] if svcInfo.protocol == api.ProtocolUDP { svcIp := svcInfo.clusterIP.String() - endpointIp := strings.Split(testCases[i].endpoint, ":")[0] + endpointIp := utilproxy.IPPart(testCases[i].endpoint) expectCommand := fmt.Sprintf("conntrack -D --orig-dst %s --dst-nat %s -p udp", svcIp, endpointIp) execCommand := strings.Join(fcmd.CombinedOutputLog[expectCommandExecCount], " ") if expectCommand != execCommand { diff --git a/pkg/proxy/ipvs/proxier.go b/pkg/proxy/ipvs/proxier.go index 6da4a0a4859..87c7e00ca0f 100644 --- a/pkg/proxy/ipvs/proxier.go +++ b/pkg/proxy/ipvs/proxier.go @@ -478,10 +478,7 @@ func (e *endpointsInfo) String() string { // IPPart returns just the IP part of the endpoint. func (e *endpointsInfo) IPPart() string { - if index := strings.Index(e.endpoint, ":"); index != -1 { - return e.endpoint[0:index] - } - return e.endpoint + return utilproxy.IPPart(e.endpoint) } type endpointServicePair struct { @@ -1262,7 +1259,7 @@ func (proxier *Proxier) syncProxyRules() { func (proxier *Proxier) deleteEndpointConnections(connectionMap map[endpointServicePair]bool) { for epSvcPair := range connectionMap { if svcInfo, ok := proxier.serviceMap[epSvcPair.servicePortName]; ok && svcInfo.protocol == api.ProtocolUDP { - endpointIP := epSvcPair.endpoint[0:strings.Index(epSvcPair.endpoint, ":")] + endpointIP := utilproxy.IPPart(epSvcPair.endpoint) err := utilproxy.ClearUDPConntrackForPeers(proxier.exec, svcInfo.clusterIP.String(), endpointIP) if err != nil { glog.Errorf("Failed to delete %s endpoint connections, error: %v", epSvcPair.servicePortName.String(), err) diff --git a/pkg/proxy/util/BUILD b/pkg/proxy/util/BUILD index 0fb554dc0cf..03faa116ea1 100644 --- a/pkg/proxy/util/BUILD +++ b/pkg/proxy/util/BUILD @@ -4,6 +4,7 @@ go_library( name = "go_default_library", srcs = [ "conntrack.go", + "endpoints.go", "port.go", "utils.go", ], @@ -21,6 +22,7 @@ go_test( name = "go_default_test", srcs = [ "conntrack_test.go", + "endpoints_test.go", "port_test.go", "utils_test.go", ], diff --git a/pkg/proxy/util/endpoints.go b/pkg/proxy/util/endpoints.go new file mode 100644 index 00000000000..32e770d4f94 --- /dev/null +++ b/pkg/proxy/util/endpoints.go @@ -0,0 +1,51 @@ +/* +Copyright 2017 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package util + +import ( + "fmt" + "net" + + "github.com/golang/glog" +) + +// IPPart returns just the IP part of an IP or IP:port or endpoint string. If the IP +// part is an IPv6 address enclosed in brackets (e.g. "[fd00:1::5]:9999"), +// then the brackets are stripped as well. +func IPPart(s string) string { + if ip := net.ParseIP(s); ip != nil { + // IP address without port + return s + } + // Must be IP:port + ip, _, err := net.SplitHostPort(s) + if err != nil { + glog.Errorf("Error parsing '%s': %v", s, err) + return "" + } + return ip +} + +// ToCIDR returns a host address of the form /32 for +// IPv4 and /128 for IPv6 +func ToCIDR(ip net.IP) string { + len := 32 + if ip.To4() == nil { + len = 128 + } + return fmt.Sprintf("%s/%d", ip.String(), len) +} diff --git a/pkg/proxy/util/endpoints_test.go b/pkg/proxy/util/endpoints_test.go new file mode 100644 index 00000000000..618f59e96a8 --- /dev/null +++ b/pkg/proxy/util/endpoints_test.go @@ -0,0 +1,68 @@ +/* +Copyright 2017 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package util + +import ( + "net" + "testing" +) + +func TestIPPart(t *testing.T) { + const noError = "" + + testCases := []struct { + endpoint string + expectedIP string + expectedError string + }{ + {"1.2.3.4", "1.2.3.4", noError}, + {"1.2.3.4:9999", "1.2.3.4", noError}, + {"2001:db8::1:1", "2001:db8::1:1", noError}, + {"[2001:db8::2:2]:9999", "2001:db8::2:2", noError}, + {"1.2.3.4::9999", "", "too many colons"}, + {"1.2.3.4:[0]", "", "unexpected '[' in address"}, + } + + for _, tc := range testCases { + ip := IPPart(tc.endpoint) + if tc.expectedError == noError { + if ip != tc.expectedIP { + t.Errorf("Unexpected IP for %s: Expected: %s, Got %s", tc.endpoint, tc.expectedIP, ip) + } + } else if ip != "" { + t.Errorf("Error did not occur for %s, expected: '%s' error", tc.endpoint, tc.expectedError) + } + } +} + +func TestToCIDR(t *testing.T) { + testCases := []struct { + ip string + expectedAddr string + }{ + {"1.2.3.4", "1.2.3.4/32"}, + {"2001:db8::1:1", "2001:db8::1:1/128"}, + } + + for _, tc := range testCases { + ip := net.ParseIP(tc.ip) + addr := ToCIDR(ip) + if addr != tc.expectedAddr { + t.Errorf("Unexpected host address for %s: Expected: %s, Got %s", tc.ip, tc.expectedAddr, addr) + } + } +}