diff --git a/staging/src/k8s.io/legacy-cloud-providers/azure/azure_storageaccount.go b/staging/src/k8s.io/legacy-cloud-providers/azure/azure_storageaccount.go index bfc153171d5..a87406bb1ec 100644 --- a/staging/src/k8s.io/legacy-cloud-providers/azure/azure_storageaccount.go +++ b/staging/src/k8s.io/legacy-cloud-providers/azure/azure_storageaccount.go @@ -23,6 +23,7 @@ import ( "strings" "github.com/Azure/azure-sdk-for-go/services/storage/mgmt/2019-06-01/storage" + "github.com/Azure/go-autorest/autorest/to" "k8s.io/klog/v2" ) @@ -32,17 +33,18 @@ type AccountOptions struct { Name, Type, Kind, ResourceGroup, Location string EnableHTTPSTrafficOnly bool Tags map[string]string + VirtualNetworkResourceIDs []string } type accountWithLocation struct { Name, StorageType, Location string } -// getStorageAccounts gets name, type, location of all storage accounts in a resource group which matches matchingAccountType, matchingLocation -func (az *Cloud) getStorageAccounts(matchingAccountType, matchingAccountKind, resourceGroup, matchingLocation string) ([]accountWithLocation, error) { +// getStorageAccounts get matching storage accounts +func (az *Cloud) getStorageAccounts(accountOptions *AccountOptions) ([]accountWithLocation, error) { ctx, cancel := getContextWithCancel() defer cancel() - result, rerr := az.StorageAccountClient.ListByResourceGroup(ctx, resourceGroup) + result, rerr := az.StorageAccountClient.ListByResourceGroup(ctx, accountOptions.ResourceGroup) if rerr != nil { return nil, rerr.Error() } @@ -51,18 +53,39 @@ func (az *Cloud) getStorageAccounts(matchingAccountType, matchingAccountKind, re for _, acct := range result { if acct.Name != nil && acct.Location != nil && acct.Sku != nil { storageType := string((*acct.Sku).Name) - if matchingAccountType != "" && !strings.EqualFold(matchingAccountType, storageType) { + if accountOptions.Type != "" && !strings.EqualFold(accountOptions.Type, storageType) { continue } - if matchingAccountKind != "" && !strings.EqualFold(matchingAccountKind, string(acct.Kind)) { + if accountOptions.Kind != "" && !strings.EqualFold(accountOptions.Kind, string(acct.Kind)) { continue } location := *acct.Location - if matchingLocation != "" && !strings.EqualFold(matchingLocation, location) { + if accountOptions.Location != "" && !strings.EqualFold(accountOptions.Location, location) { continue } + + if len(accountOptions.VirtualNetworkResourceIDs) > 0 { + if acct.AccountProperties == nil || acct.AccountProperties.NetworkRuleSet == nil || + acct.AccountProperties.NetworkRuleSet.VirtualNetworkRules == nil { + continue + } + + found := false + for _, subnetID := range accountOptions.VirtualNetworkResourceIDs { + for _, rule := range *acct.AccountProperties.NetworkRuleSet.VirtualNetworkRules { + if strings.EqualFold(to.String(rule.VirtualNetworkResourceID), subnetID) && rule.Action == storage.Allow { + found = true + break + } + } + } + if !found { + continue + } + } + accounts = append(accounts, accountWithLocation{Name: *acct.Name, StorageType: storageType, Location: location}) } } @@ -106,9 +129,10 @@ func (az *Cloud) EnsureStorageAccount(accountOptions *AccountOptions, genAccount resourceGroup := accountOptions.ResourceGroup location := accountOptions.Location enableHTTPSTrafficOnly := accountOptions.EnableHTTPSTrafficOnly + if len(accountName) == 0 { // find a storage account that matches accountType - accounts, err := az.getStorageAccounts(accountType, accountKind, resourceGroup, location) + accounts, err := az.getStorageAccounts(accountOptions) if err != nil { return "", "", fmt.Errorf("could not list storage accounts for account type %s: %v", accountType, err) } @@ -119,6 +143,24 @@ func (az *Cloud) EnsureStorageAccount(accountOptions *AccountOptions, genAccount } if len(accountName) == 0 { + // set network rules for storage account + var networkRuleSet *storage.NetworkRuleSet + virtualNetworkRules := []storage.VirtualNetworkRule{} + for _, subnetID := range accountOptions.VirtualNetworkResourceIDs { + vnetRule := storage.VirtualNetworkRule{ + VirtualNetworkResourceID: &subnetID, + Action: storage.Allow, + } + virtualNetworkRules = append(virtualNetworkRules, vnetRule) + klog.V(4).Infof("subnetID(%s) has been set", subnetID) + } + if len(virtualNetworkRules) > 0 { + networkRuleSet = &storage.NetworkRuleSet{ + VirtualNetworkRules: &virtualNetworkRules, + DefaultAction: storage.DefaultActionDeny, + } + } + // not found a matching account, now create a new account in current resource group accountName = generateStorageAccountName(genAccountNamePrefix) if location == "" { @@ -143,11 +185,14 @@ func (az *Cloud) EnsureStorageAccount(accountOptions *AccountOptions, genAccount accountName, resourceGroup, location, accountType, kind, accountOptions.Tags) cp := storage.AccountCreateParameters{ - Sku: &storage.Sku{Name: storage.SkuName(accountType)}, - Kind: kind, - AccountPropertiesCreateParameters: &storage.AccountPropertiesCreateParameters{EnableHTTPSTrafficOnly: &enableHTTPSTrafficOnly}, - Tags: tags, - Location: &location} + Sku: &storage.Sku{Name: storage.SkuName(accountType)}, + Kind: kind, + AccountPropertiesCreateParameters: &storage.AccountPropertiesCreateParameters{ + EnableHTTPSTrafficOnly: &enableHTTPSTrafficOnly, + NetworkRuleSet: networkRuleSet, + }, + Tags: tags, + Location: &location} ctx, cancel := getContextWithCancel() defer cancel()