diff --git a/pkg/cloudprovider/providers/aws/BUILD b/pkg/cloudprovider/providers/aws/BUILD index 22c40fd0ce4..0afec086431 100644 --- a/pkg/cloudprovider/providers/aws/BUILD +++ b/pkg/cloudprovider/providers/aws/BUILD @@ -78,6 +78,7 @@ go_test( embed = [":go_default_library"], deps = [ "//pkg/kubelet/apis:go_default_library", + "//pkg/volume:go_default_library", "//staging/src/k8s.io/api/core/v1:go_default_library", "//staging/src/k8s.io/apimachinery/pkg/apis/meta/v1:go_default_library", "//staging/src/k8s.io/apimachinery/pkg/types:go_default_library", diff --git a/pkg/cloudprovider/providers/aws/aws.go b/pkg/cloudprovider/providers/aws/aws.go index a6c4365fceb..b67a09d6f17 100644 --- a/pkg/cloudprovider/providers/aws/aws.go +++ b/pkg/cloudprovider/providers/aws/aws.go @@ -2323,6 +2323,11 @@ func (c *Cloud) checkIfAvailable(disk *awsDisk, opName string, instance string) // GetLabelsForVolume gets the volume labels for a volume func (c *Cloud) GetLabelsForVolume(ctx context.Context, pv *v1.PersistentVolume) (map[string]string, error) { + // Ignore if not AWSElasticBlockStore. + if pv.Spec.AWSElasticBlockStore == nil { + return nil, nil + } + // Ignore any volumes that are being provisioned if pv.Spec.AWSElasticBlockStore.VolumeID == volume.ProvisionedVolumeName { return nil, nil diff --git a/pkg/cloudprovider/providers/aws/aws_test.go b/pkg/cloudprovider/providers/aws/aws_test.go index 72a799357e6..b3d5c774a12 100644 --- a/pkg/cloudprovider/providers/aws/aws_test.go +++ b/pkg/cloudprovider/providers/aws/aws_test.go @@ -36,6 +36,7 @@ import ( "github.com/stretchr/testify/require" "k8s.io/apimachinery/pkg/types" kubeletapis "k8s.io/kubernetes/pkg/kubelet/apis" + "k8s.io/kubernetes/pkg/volume" ) const TestClusterID = "clusterid.test" @@ -907,6 +908,98 @@ func TestGetVolumeLabels(t *testing.T) { awsServices.ec2.(*MockedFakeEC2).AssertExpectations(t) } +func TestGetLabelsForVolume(t *testing.T) { + defaultVolume := EBSVolumeID("vol-VolumeId").awsString() + tests := []struct { + name string + pv *v1.PersistentVolume + expectedVolumeID *string + expectedEC2Volumes []*ec2.Volume + expectedLabels map[string]string + expectedError error + }{ + { + "not an EBS volume", + &v1.PersistentVolume{ + Spec: v1.PersistentVolumeSpec{}, + }, + nil, + nil, + nil, + nil, + }, + { + "volume which is being provisioned", + &v1.PersistentVolume{ + Spec: v1.PersistentVolumeSpec{ + PersistentVolumeSource: v1.PersistentVolumeSource{ + AWSElasticBlockStore: &v1.AWSElasticBlockStoreVolumeSource{ + VolumeID: volume.ProvisionedVolumeName, + }, + }, + }, + }, + nil, + nil, + nil, + nil, + }, + { + "no volumes found", + &v1.PersistentVolume{ + Spec: v1.PersistentVolumeSpec{ + PersistentVolumeSource: v1.PersistentVolumeSource{ + AWSElasticBlockStore: &v1.AWSElasticBlockStoreVolumeSource{ + VolumeID: "vol-VolumeId", + }, + }, + }, + }, + defaultVolume, + nil, + nil, + fmt.Errorf("no volumes found"), + }, + { + "correct labels for volume", + &v1.PersistentVolume{ + Spec: v1.PersistentVolumeSpec{ + PersistentVolumeSource: v1.PersistentVolumeSource{ + AWSElasticBlockStore: &v1.AWSElasticBlockStoreVolumeSource{ + VolumeID: "vol-VolumeId", + }, + }, + }, + }, + defaultVolume, + []*ec2.Volume{{ + VolumeId: defaultVolume, + AvailabilityZone: aws.String("us-east-1a"), + }}, + map[string]string{ + kubeletapis.LabelZoneFailureDomain: "us-east-1a", + kubeletapis.LabelZoneRegion: "us-east-1", + }, + nil, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + awsServices := newMockedFakeAWSServices(TestClusterID) + expectedVolumeRequest := &ec2.DescribeVolumesInput{VolumeIds: []*string{test.expectedVolumeID}} + awsServices.ec2.(*MockedFakeEC2).On("DescribeVolumes", expectedVolumeRequest).Return(test.expectedEC2Volumes) + + c, err := newAWSCloud(CloudConfig{}, awsServices) + assert.Nil(t, err, "Error building aws cloud: %v", err) + + l, err := c.GetLabelsForVolume(context.TODO(), test.pv) + assert.Equal(t, test.expectedLabels, l) + assert.Equal(t, test.expectedError, err) + }) + + } +} + func TestDescribeLoadBalancerOnDelete(t *testing.T) { awsServices := newMockedFakeAWSServices(TestClusterID) c, _ := newAWSCloud(CloudConfig{}, awsServices)