diff --git a/pkg/cloudprovider/providers/azure/azure_loadbalancer.go b/pkg/cloudprovider/providers/azure/azure_loadbalancer.go index 13577e06581..a6c0911e57b 100644 --- a/pkg/cloudprovider/providers/azure/azure_loadbalancer.go +++ b/pkg/cloudprovider/providers/azure/azure_loadbalancer.go @@ -69,6 +69,17 @@ const ( // ServiceAnnotationLoadBalancerResourceGroup is the annotation used on the service // to specify the resource group of load balancer objects that are not in the same resource group as the cluster. ServiceAnnotationLoadBalancerResourceGroup = "service.beta.kubernetes.io/azure-load-balancer-resource-group" + + // ServiceAnnotationAllowedServiceTag is the annotation used on the service + // to specify a list of allowed service tags separated by comma + ServiceAnnotationAllowedServiceTag = "service.beta.kubernetes.io/azure-allowed-service-tags" +) + +var ( + // supportedServiceTags holds a list of supported service tags on Azure. + // Refer https://docs.microsoft.com/en-us/azure/virtual-network/security-overview#service-tags for more information. + supportedServiceTags = sets.NewString("VirtualNetwork", "VIRTUAL_NETWORK", "AzureLoadBalancer", "AZURE_LOADBALANCER", + "Internet", "INTERNET", "AzureTrafficManager", "Storage", "Sql") ) // GetLoadBalancer returns whether the specified load balancer exists, and @@ -838,8 +849,12 @@ func (az *Cloud) reconcileSecurityGroup(clusterName string, service *v1.Service, if err != nil { return nil, err } + serviceTags, err := getServiceTags(service) + if err != nil { + return nil, err + } var sourceAddressPrefixes []string - if sourceRanges == nil || serviceapi.IsAllowAll(sourceRanges) { + if (sourceRanges == nil || serviceapi.IsAllowAll(sourceRanges)) && len(serviceTags) == 0 { if !requiresInternalLoadBalancer(service) { sourceAddressPrefixes = []string{"Internet"} } @@ -847,6 +862,9 @@ func (az *Cloud) reconcileSecurityGroup(clusterName string, service *v1.Service, for _, ip := range sourceRanges { sourceAddressPrefixes = append(sourceAddressPrefixes, ip.String()) } + for _, serviceTag := range serviceTags { + sourceAddressPrefixes = append(sourceAddressPrefixes, serviceTag) + } } expectedSecurityRules := []network.SecurityRule{} @@ -1319,3 +1337,23 @@ func useSharedSecurityRule(service *v1.Service) bool { return false } + +func getServiceTags(service *v1.Service) ([]string, error) { + if serviceTags, found := service.Annotations[ServiceAnnotationAllowedServiceTag]; found { + tags := strings.Split(strings.TrimSpace(serviceTags), ",") + for _, tag := range tags { + // Storage and Sql service tags support setting regions with suffix ".Region" + if strings.HasPrefix(tag, "Storage.") || strings.HasPrefix(tag, "Sql.") { + continue + } + + if !supportedServiceTags.Has(tag) { + return nil, fmt.Errorf("only %q are allowed in service tags", supportedServiceTags.List()) + } + } + + return tags, nil + } + + return nil, nil +}