diff --git a/staging/src/k8s.io/legacy-cloud-providers/azure/azure_loadbalancer.go b/staging/src/k8s.io/legacy-cloud-providers/azure/azure_loadbalancer.go index 6f15fb5fd05..ab333899633 100644 --- a/staging/src/k8s.io/legacy-cloud-providers/azure/azure_loadbalancer.go +++ b/staging/src/k8s.io/legacy-cloud-providers/azure/azure_loadbalancer.go @@ -74,6 +74,7 @@ const ( // ServiceAnnotationAllowedServiceTag is the annotation used on the service // to specify a list of allowed service tags separated by comma + // Refer https://docs.microsoft.com/en-us/azure/virtual-network/security-overview#service-tags for all supported service tags. ServiceAnnotationAllowedServiceTag = "service.beta.kubernetes.io/azure-allowed-service-tags" // ServiceAnnotationLoadBalancerIdleTimeout is the annotation used on the service @@ -90,13 +91,6 @@ const ( clusterNameKey = "kubernetes-cluster-name" ) -var ( - // supportedServiceTags holds a list of supported service tags on Azure. - // Refer https://docs.microsoft.com/en-us/azure/virtual-network/security-overview#service-tags for more information. - supportedServiceTags = sets.NewString("VirtualNetwork", "VIRTUAL_NETWORK", "AzureLoadBalancer", "AZURE_LOADBALANCER", - "Internet", "INTERNET", "AzureTrafficManager", "Storage", "Sql") -) - // GetLoadBalancer returns whether the specified load balancer exists, and // if so, what its status is. func (az *Cloud) GetLoadBalancer(ctx context.Context, clusterName string, service *v1.Service) (status *v1.LoadBalancerStatus, exists bool, err error) { @@ -1028,10 +1022,7 @@ func (az *Cloud) reconcileSecurityGroup(clusterName string, service *v1.Service, if err != nil { return nil, err } - serviceTags, err := getServiceTags(service) - if err != nil { - return nil, err - } + serviceTags := getServiceTags(service) var sourceAddressPrefixes []string if (sourceRanges == nil || servicehelpers.IsAllowAll(sourceRanges)) && len(serviceTags) == 0 { if !requiresInternalLoadBalancer(service) { @@ -1609,24 +1600,25 @@ func useSharedSecurityRule(service *v1.Service) bool { return false } -func getServiceTags(service *v1.Service) ([]string, error) { +func getServiceTags(service *v1.Service) []string { + if service == nil { + return nil + } + if serviceTags, found := service.Annotations[ServiceAnnotationAllowedServiceTag]; found { + result := []string{} tags := strings.Split(strings.TrimSpace(serviceTags), ",") for _, tag := range tags { - // Storage and Sql service tags support setting regions with suffix ".Region" - if strings.HasPrefix(tag, "Storage.") || strings.HasPrefix(tag, "Sql.") { - continue - } - - if !supportedServiceTags.Has(tag) { - return nil, fmt.Errorf("only %q are allowed in service tags", supportedServiceTags.List()) + serviceTag := strings.TrimSpace(tag) + if serviceTag != "" { + result = append(result, serviceTag) } } - return tags, nil + return result } - return nil, nil + return nil } func serviceOwnsPublicIP(pip *network.PublicIPAddress, clusterName, serviceName string) bool { diff --git a/staging/src/k8s.io/legacy-cloud-providers/azure/azure_loadbalancer_test.go b/staging/src/k8s.io/legacy-cloud-providers/azure/azure_loadbalancer_test.go index 0efb7ee2974..86fb679f4cc 100644 --- a/staging/src/k8s.io/legacy-cloud-providers/azure/azure_loadbalancer_test.go +++ b/staging/src/k8s.io/legacy-cloud-providers/azure/azure_loadbalancer_test.go @@ -447,3 +447,60 @@ func TestServiceOwnsPublicIP(t *testing.T) { assert.Equal(t, owns, c.expected, "TestCase[%d]: %s", i, c.desc) } } + +func TestGetServiceTags(t *testing.T) { + tests := []struct { + desc string + service *v1.Service + expected []string + }{ + { + desc: "nil should be returned when service is nil", + service: nil, + expected: nil, + }, + { + desc: "nil should be returned when service has no annotations", + service: &v1.Service{}, + expected: nil, + }, + { + desc: "single tag should be returned when service has set one annotations", + service: &v1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Annotations: map[string]string{ + ServiceAnnotationAllowedServiceTag: "tag1", + }, + }, + }, + expected: []string{"tag1"}, + }, + { + desc: "multiple tags should be returned when service has set multi-annotations", + service: &v1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Annotations: map[string]string{ + ServiceAnnotationAllowedServiceTag: "tag1, tag2", + }, + }, + }, + expected: []string{"tag1", "tag2"}, + }, + { + desc: "correct tags should be returned when comma or spaces are included in the annotations", + service: &v1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Annotations: map[string]string{ + ServiceAnnotationAllowedServiceTag: ", tag1, ", + }, + }, + }, + expected: []string{"tag1"}, + }, + } + + for i, c := range tests { + tags := getServiceTags(c.service) + assert.Equal(t, tags, c.expected, "TestCase[%d]: %s", i, c.desc) + } +}