diff --git a/pkg/cloudprovider/providers/azure/azure_loadbalancer.go b/pkg/cloudprovider/providers/azure/azure_loadbalancer.go index abca3a7f423..940723d2a46 100644 --- a/pkg/cloudprovider/providers/azure/azure_loadbalancer.go +++ b/pkg/cloudprovider/providers/azure/azure_loadbalancer.go @@ -503,11 +503,11 @@ func (az *Cloud) reconcileLoadBalancer(lb network.LoadBalancer, fipConfiguration } else { ports = []v1.ServicePort{} } - lbRuleName := getRuleName(service, port) var expectedProbes []network.Probe var expectedRules []network.LoadBalancingRule for _, port := range ports { + lbRuleName := getLoadBalancerRuleName(service, port) transportProto, _, probeProto, err := getProtocolsFromKubernetesProtocol(port.Protocol) if err != nil { @@ -690,13 +690,13 @@ func (az *Cloud) reconcileSecurityGroup(sg network.SecurityGroup, clusterName st expectedSecurityRules := make([]network.SecurityRule, len(ports)*len(sourceAddressPrefixes)) for i, port := range ports { - securityRuleName := getRuleName(service, port) _, securityProto, _, err := getProtocolsFromKubernetesProtocol(port.Protocol) if err != nil { return sg, false, err } for j := range sourceAddressPrefixes { ix := i*len(sourceAddressPrefixes) + j + securityRuleName := getSecurityRuleName(service, port, sourceAddressPrefixes[j]) expectedSecurityRules[ix] = network.SecurityRule{ Name: to.StringPtr(securityRuleName), SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ diff --git a/pkg/cloudprovider/providers/azure/azure_test.go b/pkg/cloudprovider/providers/azure/azure_test.go index e38fe8979ac..a0718068aaa 100644 --- a/pkg/cloudprovider/providers/azure/azure_test.go +++ b/pkg/cloudprovider/providers/azure/azure_test.go @@ -260,8 +260,8 @@ func TestReconcileSecurityWithSourceRanges(t *testing.T) { az := getTestCloud() svc := getTestService("servicea", v1.ProtocolTCP, 80, 443) svc.Spec.LoadBalancerSourceRanges = []string{ - "192.168.0.1/24", - "10.0.0.1/32", + "192.168.0.0/24", + "10.0.0.0/32", } sg := getTestSecurityGroup(svc) @@ -336,7 +336,7 @@ func getTestLoadBalancer(services ...v1.Service) network.LoadBalancer { for _, service := range services { for _, port := range service.Spec.Ports { - ruleName := getRuleName(&service, port) + ruleName := getLoadBalancerRuleName(&service, port) rules = append(rules, network.LoadBalancingRule{ Name: to.StringPtr(ruleName), LoadBalancingRulePropertiesFormat: &network.LoadBalancingRulePropertiesFormat{ @@ -378,10 +378,9 @@ func getTestSecurityGroup(services ...v1.Service) network.SecurityGroup { for _, service := range services { for _, port := range service.Spec.Ports { - ruleName := getRuleName(&service, port) - sources := getServiceSourceRanges(&service) for _, src := range sources { + ruleName := getSecurityRuleName(&service, port, src) rules = append(rules, network.SecurityRule{ Name: to.StringPtr(ruleName), SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ @@ -412,7 +411,7 @@ func validateLoadBalancer(t *testing.T, loadBalancer network.LoadBalancer, servi } for _, wantedRule := range svc.Spec.Ports { expectedRuleCount++ - wantedRuleName := getRuleName(&svc, wantedRule) + wantedRuleName := getLoadBalancerRuleName(&svc, wantedRule) foundRule := false for _, actualRule := range *loadBalancer.LoadBalancingRules { if strings.EqualFold(*actualRule.Name, wantedRuleName) && @@ -483,8 +482,8 @@ func validateSecurityGroup(t *testing.T, securityGroup network.SecurityGroup, se for _, svc := range services { for _, wantedRule := range svc.Spec.Ports { sources := getServiceSourceRanges(&svc) - wantedRuleName := getRuleName(&svc, wantedRule) for _, source := range sources { + wantedRuleName := getSecurityRuleName(&svc, wantedRule, source) expectedRuleCount++ foundRule := false for _, actualRule := range *securityGroup.SecurityRules { @@ -557,22 +556,28 @@ func TestProtocolTranslationTCP(t *testing.T) { t.Error(err) } - if transportProto != network.TransportProtocolTCP { + if *transportProto != network.TransportProtocolTCP { t.Errorf("Expected TCP LoadBalancer Rule Protocol. Got %v", transportProto) } - if securityGroupProto != network.TCP { + if *securityGroupProto != network.TCP { t.Errorf("Expected TCP SecurityGroup Protocol. Got %v", transportProto) } - if probeProto != network.ProbeProtocolTCP { + if *probeProto != network.ProbeProtocolTCP { t.Errorf("Expected TCP LoadBalancer Probe Protocol. Got %v", transportProto) } } func TestProtocolTranslationUDP(t *testing.T) { proto := v1.ProtocolUDP - _, _, _, err := getProtocolsFromKubernetesProtocol(proto) - if err == nil { - t.Error("Expected an error. UDP is unsupported.") + transportProto, securityGroupProto, probeProto, _ := getProtocolsFromKubernetesProtocol(proto) + if *transportProto != network.TransportProtocolUDP { + t.Errorf("Expected UDP LoadBalancer Rule Protocol. Got %v", transportProto) + } + if *securityGroupProto != network.UDP { + t.Errorf("Expected UDP SecurityGroup Protocol. Got %v", transportProto) + } + if probeProto != nil { + t.Errorf("Expected UDP LoadBalancer Probe Protocol. Got %v", transportProto) } } diff --git a/pkg/cloudprovider/providers/azure/azure_util.go b/pkg/cloudprovider/providers/azure/azure_util.go index 660a7db8b2e..094f3aaf903 100644 --- a/pkg/cloudprovider/providers/azure/azure_util.go +++ b/pkg/cloudprovider/providers/azure/azure_util.go @@ -132,14 +132,15 @@ func getProtocolsFromKubernetesProtocol(protocol v1.Protocol) (*network.Transpor transportProto = network.TransportProtocolTCP securityProto = network.TCP probeProto = network.ProbeProtocolTCP + return &transportProto, &securityProto, &probeProto, nil case v1.ProtocolUDP: transportProto = network.TransportProtocolUDP securityProto = network.UDP + return &transportProto, &securityProto, nil, nil default: return &transportProto, &securityProto, &probeProto, fmt.Errorf("Only TCP and UDP are supported for Azure LoadBalancers") } - return &transportProto, &securityProto, &probeProto, nil } // This returns the full identifier of the primary NIC for the given VM. @@ -186,8 +187,13 @@ func getBackendPoolName(clusterName string) string { return clusterName } -func getRuleName(service *v1.Service, port v1.ServicePort) string { - return fmt.Sprintf("%s-%s-%d-%d", getRulePrefix(service), port.Protocol, port.Port, port.NodePort) +func getLoadBalancerRuleName(service *v1.Service, port v1.ServicePort) string { + return fmt.Sprintf("%s-%s-%d", getRulePrefix(service), port.Protocol, port.Port) +} + +func getSecurityRuleName(service *v1.Service, port v1.ServicePort, sourceAddrPrefix string) string { + safePrefix := strings.Replace(sourceAddrPrefix, "/", "_", -1) + return fmt.Sprintf("%s-%s-%d-%s", getRulePrefix(service), port.Protocol, port.Port, safePrefix) } // This returns a human-readable version of the Service used to tag some resources.