Put service/endpoint sync into syncProxyRules

After this, syncProxyRules() can reliably be called in any context to do
the right thing.  Now it cn be made async.
This commit is contained in:
Tim Hockin 2017-04-02 23:21:03 -07:00
parent 5e442a3f61
commit c716886215
2 changed files with 256 additions and 162 deletions

View File

@ -519,22 +519,7 @@ func (proxier *Proxier) OnServiceUpdate(allServices []*api.Service) {
glog.V(2).Info("Received first Services update") glog.V(2).Info("Received first Services update")
} }
proxier.allServices = allServices proxier.allServices = allServices
proxier.syncProxyRules(syncReasonServices)
newServiceMap, hcPorts, staleUDPServices := buildNewServiceMap(allServices, proxier.serviceMap)
// update healthcheck ports
if err := proxier.healthChecker.SyncServices(hcPorts); err != nil {
glog.Errorf("Error syncing healtcheck ports: %v", err)
}
if len(newServiceMap) != len(proxier.serviceMap) || !reflect.DeepEqual(newServiceMap, proxier.serviceMap) {
proxier.serviceMap = newServiceMap
proxier.syncProxyRules(syncReasonServices)
} else {
glog.V(4).Infof("Skipping proxy iptables rule sync on service update because nothing changed")
}
utilproxy.DeleteServiceConnections(proxier.exec, staleUDPServices.List())
} }
// OnEndpointsUpdate takes in a slice of updated endpoints. // OnEndpointsUpdate takes in a slice of updated endpoints.
@ -545,23 +530,7 @@ func (proxier *Proxier) OnEndpointsUpdate(allEndpoints []*api.Endpoints) {
glog.V(2).Info("Received first Endpoints update") glog.V(2).Info("Received first Endpoints update")
} }
proxier.allEndpoints = allEndpoints proxier.allEndpoints = allEndpoints
proxier.syncProxyRules(syncReasonEndpoints)
// TODO: once service has made this same transform, move this into proxier.syncProxyRules()
newMap, hcEndpoints, staleConnections := buildNewEndpointsMap(proxier.allEndpoints, proxier.endpointsMap, proxier.hostname)
// update healthcheck endpoints
if err := proxier.healthChecker.SyncEndpoints(hcEndpoints); err != nil {
glog.Errorf("Error syncing healthcheck endoints: %v", err)
}
if len(newMap) != len(proxier.endpointsMap) || !reflect.DeepEqual(newMap, proxier.endpointsMap) {
proxier.endpointsMap = newMap
proxier.syncProxyRules(syncReasonEndpoints)
} else {
glog.V(4).Infof("Skipping proxy iptables rule sync on endpoint update because nothing changed")
}
proxier.deleteEndpointConnections(staleConnections)
} }
// Convert a slice of api.Endpoints objects into a map of service-port -> endpoints. // Convert a slice of api.Endpoints objects into a map of service-port -> endpoints.
@ -761,6 +730,25 @@ func (proxier *Proxier) syncProxyRules(reason syncReason) {
glog.V(2).Info("Not syncing iptables until Services and Endpoints have been received from master") glog.V(2).Info("Not syncing iptables until Services and Endpoints have been received from master")
return return
} }
// Figure out the new services we need to activate.
newServices, hcServices, staleServices := buildNewServiceMap(proxier.allServices, proxier.serviceMap)
// If this was called because of a services update, but nothing actionable has changed, skip it.
if reason == syncReasonServices && reflect.DeepEqual(newServices, proxier.serviceMap) {
glog.V(3).Infof("Skipping iptables sync because nothing changed")
return
}
// Figure out the new endpoints we need to activate.
newEndpoints, hcEndpoints, staleEndpoints := buildNewEndpointsMap(proxier.allEndpoints, proxier.endpointsMap, proxier.hostname)
// If this was called because of an endpoints update, but nothing actionable has changed, skip it.
if reason == syncReasonEndpoints && reflect.DeepEqual(newEndpoints, proxier.endpointsMap) {
glog.V(3).Infof("Skipping iptables sync because nothing changed")
return
}
glog.V(3).Infof("Syncing iptables rules") glog.V(3).Infof("Syncing iptables rules")
// Create and link the kube services chain. // Create and link the kube services chain.
@ -891,7 +879,7 @@ func (proxier *Proxier) syncProxyRules(reason syncReason) {
replacementPortsMap := map[localPort]closeable{} replacementPortsMap := map[localPort]closeable{}
// Build rules for each service. // Build rules for each service.
for svcName, svcInfo := range proxier.serviceMap { for svcName, svcInfo := range newServices {
protocol := strings.ToLower(string(svcInfo.protocol)) protocol := strings.ToLower(string(svcInfo.protocol))
// Create the per-service chain, retaining counters if possible. // Create the per-service chain, retaining counters if possible.
@ -1082,7 +1070,7 @@ func (proxier *Proxier) syncProxyRules(reason syncReason) {
continue continue
} }
if lp.protocol == "udp" { if lp.protocol == "udp" {
proxier.clearUdpConntrackForPort(lp.port) proxier.clearUDPConntrackForPort(lp.port)
} }
replacementPortsMap[lp] = socket replacementPortsMap[lp] = socket
} // We're holding the port, so it's OK to install iptables rules. } // We're holding the port, so it's OK to install iptables rules.
@ -1108,7 +1096,7 @@ func (proxier *Proxier) syncProxyRules(reason syncReason) {
// table doesn't currently have the same per-service structure that // table doesn't currently have the same per-service structure that
// the nat table does, so we just stick this into the kube-services // the nat table does, so we just stick this into the kube-services
// chain. // chain.
if len(proxier.endpointsMap[svcName]) == 0 { if len(newEndpoints[svcName]) == 0 {
writeLine(filterRules, writeLine(filterRules,
"-A", string(kubeServicesChain), "-A", string(kubeServicesChain),
"-m", "comment", "--comment", fmt.Sprintf(`"%s has no endpoints"`, svcName.String()), "-m", "comment", "--comment", fmt.Sprintf(`"%s has no endpoints"`, svcName.String()),
@ -1121,7 +1109,7 @@ func (proxier *Proxier) syncProxyRules(reason syncReason) {
} }
// If the service has no endpoints then reject packets. // If the service has no endpoints then reject packets.
if len(proxier.endpointsMap[svcName]) == 0 { if len(newEndpoints[svcName]) == 0 {
writeLine(filterRules, writeLine(filterRules,
"-A", string(kubeServicesChain), "-A", string(kubeServicesChain),
"-m", "comment", "--comment", fmt.Sprintf(`"%s has no endpoints"`, svcName.String()), "-m", "comment", "--comment", fmt.Sprintf(`"%s has no endpoints"`, svcName.String()),
@ -1140,7 +1128,7 @@ func (proxier *Proxier) syncProxyRules(reason syncReason) {
// These two slices parallel each other - keep in sync // These two slices parallel each other - keep in sync
endpoints := make([]*endpointsInfo, 0) endpoints := make([]*endpointsInfo, 0)
endpointChains := make([]utiliptables.Chain, 0) endpointChains := make([]utiliptables.Chain, 0)
for _, ep := range proxier.endpointsMap[svcName] { for _, ep := range newEndpoints[svcName] {
endpoints = append(endpoints, ep) endpoints = append(endpoints, ep)
endpointChain := servicePortEndpointChainName(svcName, protocol, ep.endpoint) endpointChain := servicePortEndpointChainName(svcName, protocol, ep.endpoint)
endpointChains = append(endpointChains, endpointChain) endpointChains = append(endpointChains, endpointChain)
@ -1317,6 +1305,22 @@ func (proxier *Proxier) syncProxyRules(reason syncReason) {
} }
} }
proxier.portsMap = replacementPortsMap proxier.portsMap = replacementPortsMap
// Update healthchecks.
if err := proxier.healthChecker.SyncServices(hcServices); err != nil {
glog.Errorf("Error syncing healtcheck services: %v", err)
}
if err := proxier.healthChecker.SyncEndpoints(hcEndpoints); err != nil {
glog.Errorf("Error syncing healthcheck endoints: %v", err)
}
// Finish housekeeping.
proxier.serviceMap = newServices
proxier.endpointsMap = newEndpoints
// TODO: these and clearUDPConntrackForPort() could be made more consistent.
utilproxy.DeleteServiceConnections(proxier.exec, staleServices.List())
proxier.deleteEndpointConnections(staleEndpoints)
} }
// Clear UDP conntrack for port or all conntrack entries when port equal zero. // Clear UDP conntrack for port or all conntrack entries when port equal zero.
@ -1324,7 +1328,7 @@ func (proxier *Proxier) syncProxyRules(reason syncReason) {
// The solution is clearing the conntrack. Known issus: // The solution is clearing the conntrack. Known issus:
// https://github.com/docker/docker/issues/8795 // https://github.com/docker/docker/issues/8795
// https://github.com/kubernetes/kubernetes/issues/31983 // https://github.com/kubernetes/kubernetes/issues/31983
func (proxier *Proxier) clearUdpConntrackForPort(port int) { func (proxier *Proxier) clearUDPConntrackForPort(port int) {
glog.V(2).Infof("Deleting conntrack entries for udp connections") glog.V(2).Infof("Deleting conntrack entries for udp connections")
if port > 0 { if port > 0 {
err := utilproxy.ExecConntrackTool(proxier.exec, "-D", "-p", "udp", "--dport", strconv.Itoa(port)) err := utilproxy.ExecConntrackTool(proxier.exec, "-D", "-p", "udp", "--dport", strconv.Itoa(port))

View File

@ -551,7 +551,7 @@ func hasDNAT(rules []iptablestest.Rule, endpoint string) bool {
func errorf(msg string, rules []iptablestest.Rule, t *testing.T) { func errorf(msg string, rules []iptablestest.Rule, t *testing.T) {
for _, r := range rules { for _, r := range rules {
t.Logf("%v", r) t.Logf("%q", r)
} }
t.Errorf("%v", msg) t.Errorf("%v", msg)
} }
@ -559,56 +559,80 @@ func errorf(msg string, rules []iptablestest.Rule, t *testing.T) {
func TestClusterIPReject(t *testing.T) { func TestClusterIPReject(t *testing.T) {
ipt := iptablestest.NewFake() ipt := iptablestest.NewFake()
fp := NewFakeProxier(ipt) fp := NewFakeProxier(ipt)
svcName := "svc1" svcIP := "10.20.30.41"
svcIP := net.IPv4(10, 20, 30, 41) svcPort := 80
svcPortName := proxy.ServicePortName{
NamespacedName: makeNSN("ns1", "svc1"),
Port: "p80",
}
svc := proxy.ServicePortName{NamespacedName: types.NamespacedName{Namespace: "ns1", Name: svcName}, Port: "p80"} fp.allServices = []*api.Service{
fp.serviceMap[svc] = newFakeServiceInfo(svc, svcIP, 80, api.ProtocolTCP, false) makeTestService(svcPortName.Namespace, svcPortName.Namespace, func(svc *api.Service) {
svc.Spec.ClusterIP = svcIP
svc.Spec.Ports = []api.ServicePort{{
Name: svcPortName.Port,
Port: int32(svcPort),
Protocol: api.ProtocolTCP,
}}
}),
}
fp.syncProxyRules(syncReasonForce) fp.syncProxyRules(syncReasonForce)
svcChain := string(servicePortChainName(svc, strings.ToLower(string(api.ProtocolTCP)))) svcChain := string(servicePortChainName(svcPortName, strings.ToLower(string(api.ProtocolTCP))))
svcRules := ipt.GetRules(svcChain) svcRules := ipt.GetRules(svcChain)
if len(svcRules) != 0 { if len(svcRules) != 0 {
errorf(fmt.Sprintf("Unexpected rule for chain %v service %v without endpoints", svcChain, svcName), svcRules, t) errorf(fmt.Sprintf("Unexpected rule for chain %v service %v without endpoints", svcChain, svcPortName), svcRules, t)
} }
kubeSvcRules := ipt.GetRules(string(kubeServicesChain)) kubeSvcRules := ipt.GetRules(string(kubeServicesChain))
if !hasJump(kubeSvcRules, iptablestest.Reject, svcIP.String(), 80) { if !hasJump(kubeSvcRules, iptablestest.Reject, svcIP, svcPort) {
errorf(fmt.Sprintf("Failed to find a %v rule for service %v with no endpoints", iptablestest.Reject, svcName), kubeSvcRules, t) errorf(fmt.Sprintf("Failed to find a %v rule for service %v with no endpoints", iptablestest.Reject, svcPortName), kubeSvcRules, t)
} }
} }
func TestClusterIPEndpointsJump(t *testing.T) { func TestClusterIPEndpointsJump(t *testing.T) {
ipt := iptablestest.NewFake() ipt := iptablestest.NewFake()
fp := NewFakeProxier(ipt) fp := NewFakeProxier(ipt)
svcName := "svc1" svcIP := "10.20.30.41"
svcIP := net.IPv4(10, 20, 30, 41) svcPort := 80
svcPortName := proxy.ServicePortName{
NamespacedName: makeNSN("ns1", "svc1"),
Port: "p80",
}
svc := proxy.ServicePortName{NamespacedName: types.NamespacedName{Namespace: "ns1", Name: svcName}, Port: "p80"} fp.allServices = []*api.Service{
fp.serviceMap[svc] = newFakeServiceInfo(svc, svcIP, 80, api.ProtocolTCP, true) makeTestService(svcPortName.Namespace, svcPortName.Name, func(svc *api.Service) {
ip := "10.180.0.1" svc.Spec.ClusterIP = svcIP
port := 80 svc.Spec.Ports = []api.ServicePort{{
ep := fmt.Sprintf("%s:%d", ip, port) Name: svcPortName.Port,
allEndpoints := []*api.Endpoints{ Port: int32(svcPort),
makeTestEndpoints("ns1", svcName, func(ept *api.Endpoints) { Protocol: api.ProtocolTCP,
}}
}),
}
epIP := "10.180.0.1"
fp.allEndpoints = []*api.Endpoints{
makeTestEndpoints(svcPortName.Namespace, svcPortName.Name, func(ept *api.Endpoints) {
ept.Subsets = []api.EndpointSubset{{ ept.Subsets = []api.EndpointSubset{{
Addresses: []api.EndpointAddress{{ Addresses: []api.EndpointAddress{{
IP: ip, IP: epIP,
}}, }},
Ports: []api.EndpointPort{{ Ports: []api.EndpointPort{{
Name: "p80", Name: svcPortName.Port,
Port: int32(port), Port: int32(svcPort),
}}, }},
}} }}
}), }),
} }
fp.OnEndpointsUpdate(allEndpoints) fp.syncProxyRules(syncReasonForce)
svcChain := string(servicePortChainName(svc, strings.ToLower(string(api.ProtocolTCP)))) epStr := fmt.Sprintf("%s:%d", epIP, svcPort)
epChain := string(servicePortEndpointChainName(svc, strings.ToLower(string(api.ProtocolTCP)), ep)) svcChain := string(servicePortChainName(svcPortName, strings.ToLower(string(api.ProtocolTCP))))
epChain := string(servicePortEndpointChainName(svcPortName, strings.ToLower(string(api.ProtocolTCP)), epStr))
kubeSvcRules := ipt.GetRules(string(kubeServicesChain)) kubeSvcRules := ipt.GetRules(string(kubeServicesChain))
if !hasJump(kubeSvcRules, svcChain, svcIP.String(), 80) { if !hasJump(kubeSvcRules, svcChain, svcIP, svcPort) {
errorf(fmt.Sprintf("Failed to find jump from KUBE-SERVICES to %v chain", svcChain), kubeSvcRules, t) errorf(fmt.Sprintf("Failed to find jump from KUBE-SERVICES to %v chain", svcChain), kubeSvcRules, t)
} }
@ -617,40 +641,49 @@ func TestClusterIPEndpointsJump(t *testing.T) {
errorf(fmt.Sprintf("Failed to jump to ep chain %v", epChain), svcRules, t) errorf(fmt.Sprintf("Failed to jump to ep chain %v", epChain), svcRules, t)
} }
epRules := ipt.GetRules(epChain) epRules := ipt.GetRules(epChain)
if !hasDNAT(epRules, ep) { if !hasDNAT(epRules, epStr) {
errorf(fmt.Sprintf("Endpoint chain %v lacks DNAT to %v", epChain, ep), epRules, t) errorf(fmt.Sprintf("Endpoint chain %v lacks DNAT to %v", epChain, epStr), epRules, t)
} }
} }
func typeLoadBalancer(svcInfo *serviceInfo) *serviceInfo {
svcInfo.nodePort = 3001
svcInfo.loadBalancerStatus = api.LoadBalancerStatus{
Ingress: []api.LoadBalancerIngress{{IP: "1.2.3.4"}},
}
return svcInfo
}
func TestLoadBalancer(t *testing.T) { func TestLoadBalancer(t *testing.T) {
ipt := iptablestest.NewFake() ipt := iptablestest.NewFake()
fp := NewFakeProxier(ipt) fp := NewFakeProxier(ipt)
svcName := "svc1" svcIP := "10.20.30.41"
svcIP := net.IPv4(10, 20, 30, 41) svcPort := 80
svcNodePort := 3001
svcLBIP := "1.2.3.4"
svcPortName := proxy.ServicePortName{
NamespacedName: makeNSN("ns1", "svc1"),
Port: "p80",
}
svc := proxy.ServicePortName{NamespacedName: types.NamespacedName{Namespace: "ns1", Name: svcName}, Port: "p80"} fp.allServices = []*api.Service{
svcInfo := newFakeServiceInfo(svc, svcIP, 80, api.ProtocolTCP, false) makeTestService(svcPortName.Namespace, svcPortName.Name, func(svc *api.Service) {
fp.serviceMap[svc] = typeLoadBalancer(svcInfo) svc.Spec.Type = "LoadBalancer"
svc.Spec.ClusterIP = svcIP
svc.Spec.Ports = []api.ServicePort{{
Name: svcPortName.Port,
Port: int32(svcPort),
Protocol: api.ProtocolTCP,
NodePort: int32(svcNodePort),
}}
svc.Status.LoadBalancer.Ingress = []api.LoadBalancerIngress{{
IP: svcLBIP,
}}
}),
}
ip := "10.180.0.1" epIP := "10.180.0.1"
port := 80
fp.allEndpoints = []*api.Endpoints{ fp.allEndpoints = []*api.Endpoints{
makeTestEndpoints("ns1", svcName, func(ept *api.Endpoints) { makeTestEndpoints(svcPortName.Namespace, svcPortName.Name, func(ept *api.Endpoints) {
ept.Subsets = []api.EndpointSubset{{ ept.Subsets = []api.EndpointSubset{{
Addresses: []api.EndpointAddress{{ Addresses: []api.EndpointAddress{{
IP: ip, IP: epIP,
}}, }},
Ports: []api.EndpointPort{{ Ports: []api.EndpointPort{{
Name: "p80", Name: svcPortName.Port,
Port: int32(port), Port: int32(svcPort),
}}, }},
}} }}
}), }),
@ -659,12 +692,12 @@ func TestLoadBalancer(t *testing.T) {
fp.syncProxyRules(syncReasonForce) fp.syncProxyRules(syncReasonForce)
proto := strings.ToLower(string(api.ProtocolTCP)) proto := strings.ToLower(string(api.ProtocolTCP))
fwChain := string(serviceFirewallChainName(svc, proto)) fwChain := string(serviceFirewallChainName(svcPortName, proto))
svcChain := string(servicePortChainName(svc, strings.ToLower(string(api.ProtocolTCP)))) svcChain := string(servicePortChainName(svcPortName, proto))
//lbChain := string(serviceLBChainName(svc, proto)) //lbChain := string(serviceLBChainName(svcPortName, proto))
kubeSvcRules := ipt.GetRules(string(kubeServicesChain)) kubeSvcRules := ipt.GetRules(string(kubeServicesChain))
if !hasJump(kubeSvcRules, fwChain, svcInfo.loadBalancerStatus.Ingress[0].IP, 80) { if !hasJump(kubeSvcRules, fwChain, svcLBIP, svcPort) {
errorf(fmt.Sprintf("Failed to find jump to firewall chain %v", fwChain), kubeSvcRules, t) errorf(fmt.Sprintf("Failed to find jump to firewall chain %v", fwChain), kubeSvcRules, t)
} }
@ -677,25 +710,37 @@ func TestLoadBalancer(t *testing.T) {
func TestNodePort(t *testing.T) { func TestNodePort(t *testing.T) {
ipt := iptablestest.NewFake() ipt := iptablestest.NewFake()
fp := NewFakeProxier(ipt) fp := NewFakeProxier(ipt)
svcName := "svc1" svcIP := "10.20.30.41"
svcIP := net.IPv4(10, 20, 30, 41) svcPort := 80
svcNodePort := 3001
svcPortName := proxy.ServicePortName{
NamespacedName: makeNSN("ns1", "svc1"),
Port: "p80",
}
svc := proxy.ServicePortName{NamespacedName: types.NamespacedName{Namespace: "ns1", Name: svcName}, Port: "p80"} fp.allServices = []*api.Service{
svcInfo := newFakeServiceInfo(svc, svcIP, 80, api.ProtocolTCP, false) makeTestService(svcPortName.Namespace, svcPortName.Name, func(svc *api.Service) {
svcInfo.nodePort = 3001 svc.Spec.Type = "NodePort"
fp.serviceMap[svc] = svcInfo svc.Spec.ClusterIP = svcIP
svc.Spec.Ports = []api.ServicePort{{
Name: svcPortName.Port,
Port: int32(svcPort),
Protocol: api.ProtocolTCP,
NodePort: int32(svcNodePort),
}}
}),
}
ip := "10.180.0.1" epIP := "10.180.0.1"
port := 80
fp.allEndpoints = []*api.Endpoints{ fp.allEndpoints = []*api.Endpoints{
makeTestEndpoints("ns1", svcName, func(ept *api.Endpoints) { makeTestEndpoints(svcPortName.Namespace, svcPortName.Name, func(ept *api.Endpoints) {
ept.Subsets = []api.EndpointSubset{{ ept.Subsets = []api.EndpointSubset{{
Addresses: []api.EndpointAddress{{ Addresses: []api.EndpointAddress{{
IP: ip, IP: epIP,
}}, }},
Ports: []api.EndpointPort{{ Ports: []api.EndpointPort{{
Name: "p80", Name: svcPortName.Port,
Port: int32(port), Port: int32(svcPort),
}}, }},
}} }}
}), }),
@ -704,10 +749,10 @@ func TestNodePort(t *testing.T) {
fp.syncProxyRules(syncReasonForce) fp.syncProxyRules(syncReasonForce)
proto := strings.ToLower(string(api.ProtocolTCP)) proto := strings.ToLower(string(api.ProtocolTCP))
svcChain := string(servicePortChainName(svc, strings.ToLower(proto))) svcChain := string(servicePortChainName(svcPortName, proto))
kubeNodePortRules := ipt.GetRules(string(kubeNodePortsChain)) kubeNodePortRules := ipt.GetRules(string(kubeNodePortsChain))
if !hasJump(kubeNodePortRules, svcChain, "", svcInfo.nodePort) { if !hasJump(kubeNodePortRules, svcChain, "", svcNodePort) {
errorf(fmt.Sprintf("Failed to find jump to svc chain %v", svcChain), kubeNodePortRules, t) errorf(fmt.Sprintf("Failed to find jump to svc chain %v", svcChain), kubeNodePortRules, t)
} }
} }
@ -715,19 +760,32 @@ func TestNodePort(t *testing.T) {
func TestNodePortReject(t *testing.T) { func TestNodePortReject(t *testing.T) {
ipt := iptablestest.NewFake() ipt := iptablestest.NewFake()
fp := NewFakeProxier(ipt) fp := NewFakeProxier(ipt)
svcName := "svc1" svcIP := "10.20.30.41"
svcIP := net.IPv4(10, 20, 30, 41) svcPort := 80
svcNodePort := 3001
svcPortName := proxy.ServicePortName{
NamespacedName: makeNSN("ns1", "svc1"),
Port: "p80",
}
svc := proxy.ServicePortName{NamespacedName: types.NamespacedName{Namespace: "ns1", Name: svcName}, Port: "p80"} fp.allServices = []*api.Service{
svcInfo := newFakeServiceInfo(svc, svcIP, 80, api.ProtocolTCP, false) makeTestService(svcPortName.Namespace, svcPortName.Name, func(svc *api.Service) {
svcInfo.nodePort = 3001 svc.Spec.Type = "NodePort"
fp.serviceMap[svc] = svcInfo svc.Spec.ClusterIP = svcIP
svc.Spec.Ports = []api.ServicePort{{
Name: svcPortName.Port,
Port: int32(svcPort),
Protocol: api.ProtocolTCP,
NodePort: int32(svcNodePort),
}}
}),
}
fp.syncProxyRules(syncReasonForce) fp.syncProxyRules(syncReasonForce)
kubeSvcRules := ipt.GetRules(string(kubeServicesChain)) kubeSvcRules := ipt.GetRules(string(kubeServicesChain))
if !hasJump(kubeSvcRules, iptablestest.Reject, svcIP.String(), 3001) { if !hasJump(kubeSvcRules, iptablestest.Reject, svcIP, svcNodePort) {
errorf(fmt.Sprintf("Failed to find a %v rule for service %v with no endpoints", iptablestest.Reject, svcName), kubeSvcRules, t) errorf(fmt.Sprintf("Failed to find a %v rule for service %v with no endpoints", iptablestest.Reject, svcPortName), kubeSvcRules, t)
} }
} }
@ -738,47 +796,65 @@ func strPtr(s string) *string {
func TestOnlyLocalLoadBalancing(t *testing.T) { func TestOnlyLocalLoadBalancing(t *testing.T) {
ipt := iptablestest.NewFake() ipt := iptablestest.NewFake()
fp := NewFakeProxier(ipt) fp := NewFakeProxier(ipt)
svcName := "svc1" svcIP := "10.20.30.41"
svcIP := net.IPv4(10, 20, 30, 41) svcPort := 80
svcNodePort := 3001
svcLBIP := "1.2.3.4"
svcPortName := proxy.ServicePortName{
NamespacedName: makeNSN("ns1", "svc1"),
Port: "p80",
}
svc := proxy.ServicePortName{NamespacedName: types.NamespacedName{Namespace: "ns1", Name: svcName}, Port: "p80"} fp.allServices = []*api.Service{
svcInfo := newFakeServiceInfo(svc, svcIP, 80, api.ProtocolTCP, true) makeTestService(svcPortName.Namespace, svcPortName.Name, func(svc *api.Service) {
fp.serviceMap[svc] = typeLoadBalancer(svcInfo) svc.Spec.Type = "LoadBalancer"
svc.Spec.ClusterIP = svcIP
svc.Spec.Ports = []api.ServicePort{{
Name: svcPortName.Port,
Port: int32(svcPort),
Protocol: api.ProtocolTCP,
NodePort: int32(svcNodePort),
}}
svc.Status.LoadBalancer.Ingress = []api.LoadBalancerIngress{{
IP: svcLBIP,
}}
svc.Annotations[service.BetaAnnotationExternalTraffic] = service.AnnotationValueExternalTrafficLocal
}),
}
ip1 := "10.180.0.1" epIP1 := "10.180.0.1"
ip2 := "10.180.2.1" epIP2 := "10.180.2.1"
port := 80 epStrLocal := fmt.Sprintf("%s:%d", epIP1, svcPort)
nonLocalEp := fmt.Sprintf("%s:%d", ip1, port) epStrNonLocal := fmt.Sprintf("%s:%d", epIP2, svcPort)
localEp := fmt.Sprintf("%s:%d", ip2, port) fp.allEndpoints = []*api.Endpoints{
allEndpoints := []*api.Endpoints{ makeTestEndpoints(svcPortName.Namespace, svcPortName.Name, func(ept *api.Endpoints) {
makeTestEndpoints("ns1", svcName, func(ept *api.Endpoints) {
ept.Subsets = []api.EndpointSubset{{ ept.Subsets = []api.EndpointSubset{{
Addresses: []api.EndpointAddress{{ Addresses: []api.EndpointAddress{{
IP: ip1, IP: epIP1,
NodeName: nil, NodeName: nil,
}, { }, {
IP: ip2, IP: epIP2,
NodeName: strPtr(testHostname), NodeName: strPtr(testHostname),
}}, }},
Ports: []api.EndpointPort{{ Ports: []api.EndpointPort{{
Name: "p80", Name: svcPortName.Port,
Port: int32(port), Port: int32(svcPort),
}}, }},
}} }}
}), }),
} }
fp.OnEndpointsUpdate(allEndpoints) fp.syncProxyRules(syncReasonForce)
proto := strings.ToLower(string(api.ProtocolTCP)) proto := strings.ToLower(string(api.ProtocolTCP))
fwChain := string(serviceFirewallChainName(svc, proto)) fwChain := string(serviceFirewallChainName(svcPortName, proto))
lbChain := string(serviceLBChainName(svc, proto)) lbChain := string(serviceLBChainName(svcPortName, proto))
nonLocalEpChain := string(servicePortEndpointChainName(svc, strings.ToLower(string(api.ProtocolTCP)), nonLocalEp)) nonLocalEpChain := string(servicePortEndpointChainName(svcPortName, strings.ToLower(string(api.ProtocolTCP)), epStrLocal))
localEpChain := string(servicePortEndpointChainName(svc, strings.ToLower(string(api.ProtocolTCP)), localEp)) localEpChain := string(servicePortEndpointChainName(svcPortName, strings.ToLower(string(api.ProtocolTCP)), epStrNonLocal))
kubeSvcRules := ipt.GetRules(string(kubeServicesChain)) kubeSvcRules := ipt.GetRules(string(kubeServicesChain))
if !hasJump(kubeSvcRules, fwChain, svcInfo.loadBalancerStatus.Ingress[0].IP, 0) { if !hasJump(kubeSvcRules, fwChain, svcLBIP, svcPort) {
errorf(fmt.Sprintf("Failed to find jump to firewall chain %v", fwChain), kubeSvcRules, t) errorf(fmt.Sprintf("Failed to find jump to firewall chain %v", fwChain), kubeSvcRules, t)
} }
@ -792,10 +868,10 @@ func TestOnlyLocalLoadBalancing(t *testing.T) {
lbRules := ipt.GetRules(lbChain) lbRules := ipt.GetRules(lbChain)
if hasJump(lbRules, nonLocalEpChain, "", 0) { if hasJump(lbRules, nonLocalEpChain, "", 0) {
errorf(fmt.Sprintf("Found jump from lb chain %v to non-local ep %v", lbChain, nonLocalEp), lbRules, t) errorf(fmt.Sprintf("Found jump from lb chain %v to non-local ep %v", lbChain, epStrLocal), lbRules, t)
} }
if !hasJump(lbRules, localEpChain, "", 0) { if !hasJump(lbRules, localEpChain, "", 0) {
errorf(fmt.Sprintf("Didn't find jump from lb chain %v to local ep %v", lbChain, localEp), lbRules, t) errorf(fmt.Sprintf("Didn't find jump from lb chain %v to local ep %v", lbChain, epStrNonLocal), lbRules, t)
} }
} }
@ -815,54 +891,67 @@ func TestOnlyLocalNodePorts(t *testing.T) {
func onlyLocalNodePorts(t *testing.T, fp *Proxier, ipt *iptablestest.FakeIPTables) { func onlyLocalNodePorts(t *testing.T, fp *Proxier, ipt *iptablestest.FakeIPTables) {
shouldLBTOSVCRuleExist := len(fp.clusterCIDR) > 0 shouldLBTOSVCRuleExist := len(fp.clusterCIDR) > 0
svcName := "svc1" svcIP := "10.20.30.41"
svcIP := net.IPv4(10, 20, 30, 41) svcPort := 80
svcNodePort := 3001
svcPortName := proxy.ServicePortName{
NamespacedName: makeNSN("ns1", "svc1"),
Port: "p80",
}
svc := proxy.ServicePortName{NamespacedName: types.NamespacedName{Namespace: "ns1", Name: svcName}, Port: "p80"} fp.allServices = []*api.Service{
svcInfo := newFakeServiceInfo(svc, svcIP, 80, api.ProtocolTCP, true) makeTestService(svcPortName.Namespace, svcPortName.Name, func(svc *api.Service) {
svcInfo.nodePort = 3001 svc.Spec.Type = "NodePort"
fp.serviceMap[svc] = svcInfo svc.Spec.ClusterIP = svcIP
svc.Spec.Ports = []api.ServicePort{{
Name: svcPortName.Port,
Port: int32(svcPort),
Protocol: api.ProtocolTCP,
NodePort: int32(svcNodePort),
}}
svc.Annotations[service.BetaAnnotationExternalTraffic] = service.AnnotationValueExternalTrafficLocal
}),
}
ip1 := "10.180.0.1" epIP1 := "10.180.0.1"
ip2 := "10.180.2.1" epIP2 := "10.180.2.1"
port := 80 epStrLocal := fmt.Sprintf("%s:%d", epIP1, svcPort)
nonLocalEp := fmt.Sprintf("%s:%d", ip1, port) epStrNonLocal := fmt.Sprintf("%s:%d", epIP2, svcPort)
localEp := fmt.Sprintf("%s:%d", ip2, port) fp.allEndpoints = []*api.Endpoints{
allEndpoints := []*api.Endpoints{ makeTestEndpoints(svcPortName.Namespace, svcPortName.Name, func(ept *api.Endpoints) {
makeTestEndpoints("ns1", svcName, func(ept *api.Endpoints) {
ept.Subsets = []api.EndpointSubset{{ ept.Subsets = []api.EndpointSubset{{
Addresses: []api.EndpointAddress{{ Addresses: []api.EndpointAddress{{
IP: ip1, IP: epIP1,
NodeName: nil, NodeName: nil,
}, { }, {
IP: ip2, IP: epIP2,
NodeName: strPtr(testHostname), NodeName: strPtr(testHostname),
}}, }},
Ports: []api.EndpointPort{{ Ports: []api.EndpointPort{{
Name: "p80", Name: svcPortName.Port,
Port: int32(port), Port: int32(svcPort),
}}, }},
}} }}
}), }),
} }
fp.OnEndpointsUpdate(allEndpoints) fp.syncProxyRules(syncReasonForce)
proto := strings.ToLower(string(api.ProtocolTCP)) proto := strings.ToLower(string(api.ProtocolTCP))
lbChain := string(serviceLBChainName(svc, proto)) lbChain := string(serviceLBChainName(svcPortName, proto))
nonLocalEpChain := string(servicePortEndpointChainName(svc, strings.ToLower(string(api.ProtocolTCP)), nonLocalEp)) nonLocalEpChain := string(servicePortEndpointChainName(svcPortName, proto, epStrLocal))
localEpChain := string(servicePortEndpointChainName(svc, strings.ToLower(string(api.ProtocolTCP)), localEp)) localEpChain := string(servicePortEndpointChainName(svcPortName, proto, epStrNonLocal))
kubeNodePortRules := ipt.GetRules(string(kubeNodePortsChain)) kubeNodePortRules := ipt.GetRules(string(kubeNodePortsChain))
if !hasJump(kubeNodePortRules, lbChain, "", svcInfo.nodePort) { if !hasJump(kubeNodePortRules, lbChain, "", svcNodePort) {
errorf(fmt.Sprintf("Failed to find jump to lb chain %v", lbChain), kubeNodePortRules, t) errorf(fmt.Sprintf("Failed to find jump to lb chain %v", lbChain), kubeNodePortRules, t)
} }
svcChain := string(servicePortChainName(svc, strings.ToLower(string(api.ProtocolTCP)))) svcChain := string(servicePortChainName(svcPortName, proto))
lbRules := ipt.GetRules(lbChain) lbRules := ipt.GetRules(lbChain)
if hasJump(lbRules, nonLocalEpChain, "", 0) { if hasJump(lbRules, nonLocalEpChain, "", 0) {
errorf(fmt.Sprintf("Found jump from lb chain %v to non-local ep %v", lbChain, nonLocalEp), lbRules, t) errorf(fmt.Sprintf("Found jump from lb chain %v to non-local ep %v", lbChain, epStrLocal), lbRules, t)
} }
if hasJump(lbRules, svcChain, "", 0) != shouldLBTOSVCRuleExist { if hasJump(lbRules, svcChain, "", 0) != shouldLBTOSVCRuleExist {
prefix := "Did not find " prefix := "Did not find "
@ -872,15 +961,16 @@ func onlyLocalNodePorts(t *testing.T, fp *Proxier, ipt *iptablestest.FakeIPTable
errorf(fmt.Sprintf("%s jump from lb chain %v to svc %v", prefix, lbChain, svcChain), lbRules, t) errorf(fmt.Sprintf("%s jump from lb chain %v to svc %v", prefix, lbChain, svcChain), lbRules, t)
} }
if !hasJump(lbRules, localEpChain, "", 0) { if !hasJump(lbRules, localEpChain, "", 0) {
errorf(fmt.Sprintf("Didn't find jump from lb chain %v to local ep %v", lbChain, nonLocalEp), lbRules, t) errorf(fmt.Sprintf("Didn't find jump from lb chain %v to local ep %v", lbChain, epStrLocal), lbRules, t)
} }
} }
func makeTestService(namespace, name string, svcFunc func(*api.Service)) *api.Service { func makeTestService(namespace, name string, svcFunc func(*api.Service)) *api.Service {
svc := &api.Service{ svc := &api.Service{
ObjectMeta: metav1.ObjectMeta{ ObjectMeta: metav1.ObjectMeta{
Name: name, Name: name,
Namespace: namespace, Namespace: namespace,
Annotations: map[string]string{},
}, },
Spec: api.ServiceSpec{}, Spec: api.ServiceSpec{},
Status: api.ServiceStatus{}, Status: api.ServiceStatus{},