From 9973b1cb5e3a8ac201d4c2efbe677d2720ab73b7 Mon Sep 17 00:00:00 2001 From: andyzhangx Date: Sun, 21 Jun 2020 12:57:41 +0000 Subject: [PATCH] feat: add tags support for azure disk driver merge tags map --- pkg/volume/azure_dd/azure_provision.go | 42 ++++++++++++- pkg/volume/azure_dd/azure_provision_test.go | 69 ++++++++++++++++++++- 2 files changed, 108 insertions(+), 3 deletions(-) diff --git a/pkg/volume/azure_dd/azure_provision.go b/pkg/volume/azure_dd/azure_provision.go index 3ce4074b0d3..af7d0957670 100644 --- a/pkg/volume/azure_dd/azure_provision.go +++ b/pkg/volume/azure_dd/azure_provision.go @@ -35,6 +35,11 @@ import ( "k8s.io/legacy-cloud-providers/azure" ) +const ( + TagsDelimiter = "," + TagKeyValueDelimiter = "=" +) + type azureDiskProvisioner struct { plugin *azureDataDiskPlugin options volume.VolumeOptions @@ -118,6 +123,7 @@ func (p *azureDiskProvisioner) Provision(selectedNode *v1.Node, allowedTopologie diskIopsReadWrite string diskMbpsReadWrite string diskEncryptionSetID string + customTags string maxShares int ) @@ -164,6 +170,8 @@ func (p *azureDiskProvisioner) Provision(selectedNode *v1.Node, allowedTopologie diskMbpsReadWrite = v case "diskencryptionsetid": diskEncryptionSetID = v + case "tags": + customTags = v case azure.WriteAcceleratorEnabled: writeAcceleratorEnabled = v case "maxshares": @@ -261,9 +269,14 @@ func (p *azureDiskProvisioner) Provision(selectedNode *v1.Node, allowedTopologie diskURI := "" labels := map[string]string{} if kind == v1.AzureManagedDisk { - tags := make(map[string]string) + tags, err := ConvertTagsToMap(customTags) + if err != nil { + return nil, err + } if p.options.CloudTags != nil { - tags = *(p.options.CloudTags) + for k, v := range *(p.options.CloudTags) { + tags[k] = v + } } if strings.EqualFold(writeAcceleratorEnabled, "true") { tags[azure.WriteAcceleratorEnabled] = "true" @@ -386,3 +399,28 @@ func (p *azureDiskProvisioner) Provision(selectedNode *v1.Node, allowedTopologie return pv, nil } + +// ConvertTagsToMap convert the tags from string to map +// the valid tags fomat is "key1=value1,key2=value2", which could be converted to +// {"key1": "value1", "key2": "value2"} +func ConvertTagsToMap(tags string) (map[string]string, error) { + m := make(map[string]string) + if tags == "" { + return m, nil + } + s := strings.Split(tags, TagsDelimiter) + for _, tag := range s { + kv := strings.Split(tag, TagKeyValueDelimiter) + if len(kv) != 2 { + return nil, fmt.Errorf("Tags '%s' are invalid, the format should like: 'key1=value1,key2=value2'", tags) + } + key := strings.TrimSpace(kv[0]) + if key == "" { + return nil, fmt.Errorf("Tags '%s' are invalid, the format should like: 'key1=value1,key2=value2'", tags) + } + value := strings.TrimSpace(kv[1]) + m[key] = value + } + + return m, nil +} diff --git a/pkg/volume/azure_dd/azure_provision_test.go b/pkg/volume/azure_dd/azure_provision_test.go index dfc8fc80573..49320b2b9bd 100644 --- a/pkg/volume/azure_dd/azure_provision_test.go +++ b/pkg/volume/azure_dd/azure_provision_test.go @@ -20,10 +20,11 @@ package azure_dd import ( "fmt" + "reflect" "testing" "github.com/stretchr/testify/assert" - "k8s.io/api/core/v1" + v1 "k8s.io/api/core/v1" ) func TestParseZoned(t *testing.T) { @@ -96,3 +97,69 @@ func TestParseZoned(t *testing.T) { } } } + +func TestConvertTagsToMap(t *testing.T) { + testCases := []struct { + desc string + tags string + expectedOutput map[string]string + expectedError bool + }{ + { + desc: "should return empty map when tag is empty", + tags: "", + expectedOutput: map[string]string{}, + expectedError: false, + }, + { + desc: "sing valid tag should be converted", + tags: "key=value", + expectedOutput: map[string]string{ + "key": "value", + }, + expectedError: false, + }, + { + desc: "multiple valid tags should be converted", + tags: "key1=value1,key2=value2", + expectedOutput: map[string]string{ + "key1": "value1", + "key2": "value2", + }, + expectedError: false, + }, + { + desc: "whitespaces should be trimmed", + tags: "key1=value1, key2=value2", + expectedOutput: map[string]string{ + "key1": "value1", + "key2": "value2", + }, + expectedError: false, + }, + { + desc: "should return error for invalid format", + tags: "foo,bar", + expectedOutput: nil, + expectedError: true, + }, + { + desc: "should return error for when key is missed", + tags: "key1=value1,=bar", + expectedOutput: nil, + expectedError: true, + }, + } + + for i, c := range testCases { + m, err := ConvertTagsToMap(c.tags) + if c.expectedError { + assert.NotNil(t, err, "TestCase[%d]: %s", i, c.desc) + } else { + assert.Nil(t, err, "TestCase[%d]: %s", i, c.desc) + if !reflect.DeepEqual(m, c.expectedOutput) { + t.Errorf("got: %v, expected: %v, desc: %v", m, c.expectedOutput, c.desc) + } + } + } +}