diff --git a/pkg/scheduler/algorithm/predicates/max_attachable_volume_predicate_test.go b/pkg/scheduler/algorithm/predicates/max_attachable_volume_predicate_test.go index af658268a79..1ad456ac317 100644 --- a/pkg/scheduler/algorithm/predicates/max_attachable_volume_predicate_test.go +++ b/pkg/scheduler/algorithm/predicates/max_attachable_volume_predicate_test.go @@ -1067,6 +1067,41 @@ func TestMaxVolumeFuncM4(t *testing.T) { } } +func TestMaxVolumeFuncM4WithOnlyStableLabels(t *testing.T) { + node := &v1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Name: "node-for-m4-instance", + Labels: map[string]string{ + v1.LabelInstanceTypeStable: "m4.2xlarge", + }, + }, + } + os.Unsetenv(KubeMaxPDVols) + maxVolumeFunc := getMaxVolumeFunc(EBSVolumeFilterType) + maxVolume := maxVolumeFunc(node) + if maxVolume != volumeutil.DefaultMaxEBSVolumes { + t.Errorf("Expected max volume to be %d got %d", volumeutil.DefaultMaxEBSVolumes, maxVolume) + } +} + +func TestMaxVolumeFuncM4WithBothBetaAndStableLabels(t *testing.T) { + node := &v1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Name: "node-for-m4-instance", + Labels: map[string]string{ + v1.LabelInstanceType: "m4.2xlarge", + v1.LabelInstanceTypeStable: "m4.2xlarge", + }, + }, + } + os.Unsetenv(KubeMaxPDVols) + maxVolumeFunc := getMaxVolumeFunc(EBSVolumeFilterType) + maxVolume := maxVolumeFunc(node) + if maxVolume != volumeutil.DefaultMaxEBSVolumes { + t.Errorf("Expected max volume to be %d got %d", volumeutil.DefaultMaxEBSVolumes, maxVolume) + } +} + func getNodeWithPodAndVolumeLimits(limitSource string, pods []*v1.Pod, limit int64, driverNames ...string) (*schedulernodeinfo.NodeInfo, *v1beta1.CSINode) { nodeInfo := schedulernodeinfo.NewNodeInfo(pods...) node := &v1.Node{ diff --git a/pkg/scheduler/algorithm/predicates/predicates.go b/pkg/scheduler/algorithm/predicates/predicates.go index d4ff8cade1b..1005c5ac8f9 100644 --- a/pkg/scheduler/algorithm/predicates/predicates.go +++ b/pkg/scheduler/algorithm/predicates/predicates.go @@ -305,8 +305,9 @@ func getMaxVolumeFunc(filterName string) func(node *v1.Node) int { var nodeInstanceType string for k, v := range node.ObjectMeta.Labels { - if k == v1.LabelInstanceType { + if k == v1.LabelInstanceType || k == v1.LabelInstanceTypeStable { nodeInstanceType = v + break } } switch filterName {