diff --git a/pkg/cloudprovider/providers/aws/aws.go b/pkg/cloudprovider/providers/aws/aws.go index 79afc933387..b80dff30e51 100644 --- a/pkg/cloudprovider/providers/aws/aws.go +++ b/pkg/cloudprovider/providers/aws/aws.go @@ -2175,7 +2175,6 @@ func (c *Cloud) CreateDisk(volumeOptions *VolumeOptions) (KubernetesVolumeID, er return "", fmt.Errorf("invalid AWS VolumeType %q", volumeOptions.VolumeType) } - // TODO: Should we tag this with the cluster id (so it gets deleted when the cluster does?) request := &ec2.CreateVolumeInput{} request.AvailabilityZone = aws.String(volumeOptions.AvailabilityZone) request.Size = aws.Int64(int64(volumeOptions.CapacityGB)) @@ -2188,6 +2187,21 @@ func (c *Cloud) CreateDisk(volumeOptions *VolumeOptions) (KubernetesVolumeID, er if iops > 0 { request.Iops = aws.Int64(iops) } + + tags := volumeOptions.Tags + tags = c.tagging.buildTags(ResourceLifecycleOwned, tags) + + var tagList []*ec2.Tag + for k, v := range tags { + tagList = append(tagList, &ec2.Tag{ + Key: aws.String(k), Value: aws.String(v), + }) + } + request.TagSpecifications = append(request.TagSpecifications, &ec2.TagSpecification{ + Tags: tagList, + ResourceType: aws.String(ec2.ResourceTypeVolume), + }) + response, err := c.ec2.CreateVolume(request) if err != nil { return "", err @@ -2199,17 +2213,6 @@ func (c *Cloud) CreateDisk(volumeOptions *VolumeOptions) (KubernetesVolumeID, er } volumeName := KubernetesVolumeID("aws://" + aws.StringValue(response.AvailabilityZone) + "/" + string(awsID)) - // apply tags - if err := c.tagging.createTags(c.ec2, string(awsID), ResourceLifecycleOwned, volumeOptions.Tags); err != nil { - // delete the volume and hope it succeeds - _, delerr := c.DeleteDisk(volumeName) - if delerr != nil { - // delete did not succeed, we have a stray volume! - return "", fmt.Errorf("error tagging volume %s, could not delete the volume: %q", volumeName, delerr) - } - return "", fmt.Errorf("error tagging volume %s: %q", volumeName, err) - } - // AWS has a bad habbit of reporting success when creating a volume with // encryption keys that either don't exists or have wrong permissions. // Such volume lives for couple of seconds and then it's silently deleted diff --git a/pkg/cloudprovider/providers/aws/aws_test.go b/pkg/cloudprovider/providers/aws/aws_test.go index 8d7531f992c..50e303d43d3 100644 --- a/pkg/cloudprovider/providers/aws/aws_test.go +++ b/pkg/cloudprovider/providers/aws/aws_test.go @@ -68,6 +68,11 @@ func (m *MockedFakeEC2) DescribeSecurityGroups(request *ec2.DescribeSecurityGrou return args.Get(0).([]*ec2.SecurityGroup), nil } +func (m *MockedFakeEC2) CreateVolume(request *ec2.CreateVolumeInput) (*ec2.Volume, error) { + args := m.Called(request) + return args.Get(0).(*ec2.Volume), nil +} + type MockedFakeELB struct { *FakeELB mock.Mock @@ -1393,6 +1398,38 @@ func TestFindSecurityGroupForInstanceMultipleTagged(t *testing.T) { assert.Contains(t, err.Error(), "sg123(another_group)") } +func TestCreateDisk(t *testing.T) { + awsServices := newMockedFakeAWSServices(TestClusterID) + c, _ := newAWSCloud(CloudConfig{}, awsServices) + + volumeOptions := &VolumeOptions{ + AvailabilityZone: "us-east-1a", + CapacityGB: 10, + } + request := &ec2.CreateVolumeInput{ + AvailabilityZone: aws.String("us-east-1a"), + Encrypted: aws.Bool(false), + VolumeType: aws.String(DefaultVolumeType), + Size: aws.Int64(10), + TagSpecifications: []*ec2.TagSpecification{ + {ResourceType: aws.String(ec2.ResourceTypeVolume), Tags: []*ec2.Tag{ + {Key: aws.String(TagNameKubernetesClusterLegacy), Value: aws.String(TestClusterID)}, + {Key: aws.String(fmt.Sprintf("%s%s", TagNameKubernetesClusterPrefix, TestClusterID)), Value: aws.String(ResourceLifecycleOwned)}, + }}, + }, + } + volume := &ec2.Volume{ + AvailabilityZone: aws.String("us-east-1a"), + VolumeId: aws.String("vol-volumeId0"), + } + awsServices.ec2.(*MockedFakeEC2).On("CreateVolume", request).Return(volume, nil) + + volumeID, err := c.CreateDisk(volumeOptions) + assert.Nil(t, err, "Error creating disk: %v", err) + assert.Equal(t, volumeID, KubernetesVolumeID("aws://us-east-1a/vol-volumeId0")) + awsServices.ec2.(*MockedFakeEC2).AssertExpectations(t) +} + func newMockedFakeAWSServices(id string) *FakeAWSServices { s := NewFakeAWSServices(id) s.ec2 = &MockedFakeEC2{FakeEC2Impl: s.ec2.(*FakeEC2Impl)}