diff --git a/pkg/controller/volume/scheduling/BUILD b/pkg/controller/volume/scheduling/BUILD
index c9e2acec895..32affc54ccc 100644
--- a/pkg/controller/volume/scheduling/BUILD
+++ b/pkg/controller/volume/scheduling/BUILD
@@ -14,18 +14,24 @@ go_library(
         "//pkg/apis/core/v1/helper:go_default_library",
         "//pkg/controller/volume/persistentvolume/util:go_default_library",
         "//pkg/controller/volume/scheduling/metrics:go_default_library",
+        "//pkg/features:go_default_library",
         "//pkg/volume/util:go_default_library",
         "//staging/src/k8s.io/api/core/v1:go_default_library",
+        "//staging/src/k8s.io/api/storage/v1:go_default_library",
         "//staging/src/k8s.io/apimachinery/pkg/api/meta:go_default_library",
         "//staging/src/k8s.io/apimachinery/pkg/apis/meta/v1:go_default_library",
         "//staging/src/k8s.io/apimachinery/pkg/labels:go_default_library",
+        "//staging/src/k8s.io/apimachinery/pkg/util/sets:go_default_library",
         "//staging/src/k8s.io/apimachinery/pkg/util/wait:go_default_library",
         "//staging/src/k8s.io/apiserver/pkg/storage/etcd3:go_default_library",
+        "//staging/src/k8s.io/apiserver/pkg/util/feature:go_default_library",
         "//staging/src/k8s.io/client-go/informers/core/v1:go_default_library",
         "//staging/src/k8s.io/client-go/informers/storage/v1:go_default_library",
         "//staging/src/k8s.io/client-go/kubernetes:go_default_library",
         "//staging/src/k8s.io/client-go/listers/storage/v1:go_default_library",
         "//staging/src/k8s.io/client-go/tools/cache:go_default_library",
+        "//staging/src/k8s.io/csi-translation-lib:go_default_library",
+        "//staging/src/k8s.io/csi-translation-lib/plugins:go_default_library",
         "//vendor/k8s.io/klog:go_default_library",
     ],
 )
@@ -43,6 +49,7 @@ go_test(
         "//pkg/controller:go_default_library",
         "//pkg/controller/volume/persistentvolume/testing:go_default_library",
         "//pkg/controller/volume/persistentvolume/util:go_default_library",
+        "//pkg/features:go_default_library",
         "//staging/src/k8s.io/api/core/v1:go_default_library",
         "//staging/src/k8s.io/api/storage/v1:go_default_library",
         "//staging/src/k8s.io/apimachinery/pkg/api/resource:go_default_library",
@@ -51,11 +58,14 @@ go_test(
         "//staging/src/k8s.io/apimachinery/pkg/util/diff:go_default_library",
         "//staging/src/k8s.io/apimachinery/pkg/util/wait:go_default_library",
         "//staging/src/k8s.io/apimachinery/pkg/watch:go_default_library",
+        "//staging/src/k8s.io/apiserver/pkg/util/feature:go_default_library",
         "//staging/src/k8s.io/client-go/informers:go_default_library",
         "//staging/src/k8s.io/client-go/informers/core/v1:go_default_library",
+        "//staging/src/k8s.io/client-go/informers/storage/v1:go_default_library",
         "//staging/src/k8s.io/client-go/kubernetes:go_default_library",
         "//staging/src/k8s.io/client-go/kubernetes/fake:go_default_library",
         "//staging/src/k8s.io/client-go/testing:go_default_library",
+        "//staging/src/k8s.io/component-base/featuregate/testing:go_default_library",
         "//vendor/k8s.io/klog:go_default_library",
     ],
 )
diff --git a/pkg/controller/volume/scheduling/scheduler_binder.go b/pkg/controller/volume/scheduling/scheduler_binder.go
index 298ae211958..21bd947e134 100644
--- a/pkg/controller/volume/scheduling/scheduler_binder.go
+++ b/pkg/controller/volume/scheduling/scheduler_binder.go
@@ -19,24 +19,39 @@ package scheduling
 import (
     "fmt"
     "sort"
+    "strings"
    "time"
 
     v1 "k8s.io/api/core/v1"
+    storagev1 "k8s.io/api/storage/v1"
     metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
     "k8s.io/apimachinery/pkg/labels"
+    "k8s.io/apimachinery/pkg/util/sets"
     "k8s.io/apimachinery/pkg/util/wait"
     "k8s.io/apiserver/pkg/storage/etcd3"
+    utilfeature "k8s.io/apiserver/pkg/util/feature"
"k8s.io/apiserver/pkg/util/feature" coreinformers "k8s.io/client-go/informers/core/v1" storageinformers "k8s.io/client-go/informers/storage/v1" clientset "k8s.io/client-go/kubernetes" storagelisters "k8s.io/client-go/listers/storage/v1" + csitrans "k8s.io/csi-translation-lib" + csiplugins "k8s.io/csi-translation-lib/plugins" "k8s.io/klog" v1helper "k8s.io/kubernetes/pkg/apis/core/v1/helper" pvutil "k8s.io/kubernetes/pkg/controller/volume/persistentvolume/util" "k8s.io/kubernetes/pkg/controller/volume/scheduling/metrics" + "k8s.io/kubernetes/pkg/features" volumeutil "k8s.io/kubernetes/pkg/volume/util" ) +// InTreeToCSITranslator contains methods required to check migratable status +// and perform translations from InTree PV's to CSI +type InTreeToCSITranslator interface { + IsPVMigratable(pv *v1.PersistentVolume) bool + GetInTreePluginNameFromSpec(pv *v1.PersistentVolume, vol *v1.Volume) (string, error) + TranslateInTreePVToCSI(pv *v1.PersistentVolume) (*v1.PersistentVolume, error) +} + // SchedulerVolumeBinder is used by the scheduler to handle PVC/PV binding // and dynamic provisioning. The binding decisions are integrated into the pod scheduling // workflow so that the PV NodeAffinity is also considered along with the pod's other @@ -103,9 +118,10 @@ type volumeBinder struct { kubeClient clientset.Interface classLister storagelisters.StorageClassLister - nodeInformer coreinformers.NodeInformer - pvcCache PVCAssumeCache - pvCache PVAssumeCache + nodeInformer coreinformers.NodeInformer + csiNodeInformer storageinformers.CSINodeInformer + pvcCache PVCAssumeCache + pvCache PVAssumeCache // Stores binding decisions that were made in FindPodVolumes for use in AssumePodVolumes. // AssumePodVolumes modifies the bindings again for use in BindPodVolumes. @@ -113,12 +129,15 @@ type volumeBinder struct { // Amount of time to wait for the bind operation to succeed bindTimeout time.Duration + + translator InTreeToCSITranslator } // NewVolumeBinder sets up all the caches needed for the scheduler to make volume binding decisions. 
 func NewVolumeBinder(
     kubeClient clientset.Interface,
     nodeInformer coreinformers.NodeInformer,
+    csiNodeInformer storageinformers.CSINodeInformer,
     pvcInformer coreinformers.PersistentVolumeClaimInformer,
     pvInformer coreinformers.PersistentVolumeInformer,
     storageClassInformer storageinformers.StorageClassInformer,
@@ -128,10 +147,12 @@ func NewVolumeBinder(
         kubeClient:      kubeClient,
         classLister:     storageClassInformer.Lister(),
         nodeInformer:    nodeInformer,
+        csiNodeInformer: csiNodeInformer,
         pvcCache:        NewPVCAssumeCache(pvcInformer.Informer()),
         pvCache:         NewPVAssumeCache(pvInformer.Informer()),
         podBindingCache: NewPodBindingCache(),
         bindTimeout:     bindTimeout,
+        translator:      csitrans.New(),
     }
 
     return b
@@ -457,6 +478,12 @@ func (b *volumeBinder) checkBindings(pod *v1.Pod, bindings []*bindingInfo, claim
         return false, fmt.Errorf("failed to get node %q: %v", pod.Spec.NodeName, err)
     }
 
+    csiNode, err := b.csiNodeInformer.Lister().Get(node.Name)
+    if err != nil {
+        // TODO: return the error once CSINode is created by default
+        klog.V(4).Infof("Could not get a CSINode object for the node %q: %v", node.Name, err)
+    }
+
     // Check for any conditions that might require scheduling retry
 
     // When pod is removed from scheduling queue because of deletion or any
@@ -485,6 +512,11 @@ func (b *volumeBinder) checkBindings(pod *v1.Pod, bindings []*bindingInfo, claim
             return false, nil
         }
 
+        pv, err = b.tryTranslatePVToCSI(pv, csiNode)
+        if err != nil {
+            return false, fmt.Errorf("failed to translate pv to csi: %v", err)
+        }
+
         // Check PV's node affinity (the node might not have the proper label)
         if err := volumeutil.CheckNodeAffinity(pv, node.Labels); err != nil {
             return false, fmt.Errorf("pv %q node affinity doesn't match node %q: %v", pv.Name, node.Name, err)
@@ -538,6 +570,12 @@ func (b *volumeBinder) checkBindings(pod *v1.Pod, bindings []*bindingInfo, claim
             }
             return false, fmt.Errorf("failed to get pv %q from cache: %v", pvc.Spec.VolumeName, err)
         }
+
+        pv, err = b.tryTranslatePVToCSI(pv, csiNode)
+        if err != nil {
+            return false, err
+        }
+
         if err := volumeutil.CheckNodeAffinity(pv, node.Labels); err != nil {
             return false, fmt.Errorf("pv %q node affinity doesn't match node %q: %v", pv.Name, node.Name, err)
         }
@@ -641,6 +679,12 @@ func (b *volumeBinder) getPodVolumes(pod *v1.Pod) (boundClaims []*v1.PersistentV
 }
 
 func (b *volumeBinder) checkBoundClaims(claims []*v1.PersistentVolumeClaim, node *v1.Node, podName string) (bool, error) {
+    csiNode, err := b.csiNodeInformer.Lister().Get(node.Name)
+    if err != nil {
+        // TODO: return the error once CSINode is created by default
+        klog.V(4).Infof("Could not get a CSINode object for the node %q: %v", node.Name, err)
+    }
+
     for _, pvc := range claims {
         pvName := pvc.Spec.VolumeName
         pv, err := b.pvCache.GetPV(pvName)
@@ -648,6 +692,11 @@ func (b *volumeBinder) checkBoundClaims(claims []*v1.PersistentVolumeClaim, node
             return false, err
         }
 
+        pv, err = b.tryTranslatePVToCSI(pv, csiNode)
+        if err != nil {
+            return false, err
+        }
+
         err = volumeutil.CheckNodeAffinity(pv, node.Labels)
         if err != nil {
             klog.V(4).Infof("PersistentVolume %q, Node %q mismatch for Pod %q: %v", pvName, node.Name, podName, err)
@@ -783,3 +832,72 @@ func (a byPVCSize) Less(i, j int) bool {
 func claimToClaimKey(claim *v1.PersistentVolumeClaim) string {
     return fmt.Sprintf("%s/%s", claim.Namespace, claim.Name)
 }
+
+// isCSIMigrationOnForPlugin checks if CSI migration is enabled for a given plugin.
+func isCSIMigrationOnForPlugin(pluginName string) bool {
+    switch pluginName {
+    case csiplugins.AWSEBSInTreePluginName:
+        return utilfeature.DefaultFeatureGate.Enabled(features.CSIMigrationAWS)
+    case csiplugins.GCEPDInTreePluginName:
+        return utilfeature.DefaultFeatureGate.Enabled(features.CSIMigrationGCE)
+    case csiplugins.AzureDiskInTreePluginName:
+        return utilfeature.DefaultFeatureGate.Enabled(features.CSIMigrationAzureDisk)
+    case csiplugins.CinderInTreePluginName:
+        return utilfeature.DefaultFeatureGate.Enabled(features.CSIMigrationOpenStack)
+    }
+    return false
+}
+
+// isPluginMigratedToCSIOnNode checks if an in-tree plugin has been migrated to a CSI driver on the node.
+func isPluginMigratedToCSIOnNode(pluginName string, csiNode *storagev1.CSINode) bool {
+    if csiNode == nil {
+        return false
+    }
+
+    csiNodeAnn := csiNode.GetAnnotations()
+    if csiNodeAnn == nil {
+        return false
+    }
+
+    var mpaSet sets.String
+    mpa := csiNodeAnn[v1.MigratedPluginsAnnotationKey]
+    if len(mpa) == 0 {
+        mpaSet = sets.NewString()
+    } else {
+        tok := strings.Split(mpa, ",")
+        mpaSet = sets.NewString(tok...)
+    }
+
+    return mpaSet.Has(pluginName)
+}
+
+// tryTranslatePVToCSI will translate the in-tree PV to CSI if it meets the criteria. If not, it returns the unmodified in-tree PV.
+func (b *volumeBinder) tryTranslatePVToCSI(pv *v1.PersistentVolume, csiNode *storagev1.CSINode) (*v1.PersistentVolume, error) {
+    if !b.translator.IsPVMigratable(pv) {
+        return pv, nil
+    }
+
+    if !utilfeature.DefaultFeatureGate.Enabled(features.CSIMigration) {
+        return pv, nil
+    }
+
+    pluginName, err := b.translator.GetInTreePluginNameFromSpec(pv, nil)
+    if err != nil {
+        return nil, fmt.Errorf("could not get plugin name from pv: %v", err)
+    }
+
+    if !isCSIMigrationOnForPlugin(pluginName) {
+        return pv, nil
+    }
+
+    if !isPluginMigratedToCSIOnNode(pluginName, csiNode) {
+        return pv, nil
+    }
+
+    transPV, err := b.translator.TranslateInTreePVToCSI(pv)
+    if err != nil {
+        return nil, fmt.Errorf("could not translate pv: %v", err)
+    }
+
+    return transPV, nil
+}
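
For context, here is a minimal sketch of the translation path that tryTranslatePVToCSI above delegates to. The PV literal and the main wrapper are illustrative only; csitrans.New, IsPVMigratable, and TranslateInTreePVToCSI are the csi-translation-lib entry points the binder uses.

package main

import (
    "fmt"

    v1 "k8s.io/api/core/v1"
    metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
    csitrans "k8s.io/csi-translation-lib"
)

func main() {
    // Hypothetical in-tree GCE PD PV carrying only a zone label,
    // similar to the fixtures in the tests below.
    pv := &v1.PersistentVolume{
        ObjectMeta: metav1.ObjectMeta{
            Name:   "example-pv", // illustrative name
            Labels: map[string]string{v1.LabelZoneFailureDomain: "us-east-1a"},
        },
        Spec: v1.PersistentVolumeSpec{
            PersistentVolumeSource: v1.PersistentVolumeSource{
                GCEPersistentDisk: &v1.GCEPersistentDiskVolumeSource{PDName: "test-disk"},
            },
        },
    }

    translator := csitrans.New()
    if !translator.IsPVMigratable(pv) {
        return
    }
    csiPV, err := translator.TranslateInTreePVToCSI(pv)
    if err != nil {
        panic(err)
    }
    // The zone label has been promoted to CSI topology on the translated PV.
    fmt.Println(csiPV.Spec.CSI.Driver, csiPV.Spec.NodeAffinity)
}
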
diff --git a/pkg/controller/volume/scheduling/scheduler_binder_test.go b/pkg/controller/volume/scheduling/scheduler_binder_test.go
index 82b4d435556..f47f47232cc 100644
--- a/pkg/controller/volume/scheduling/scheduler_binder_test.go
+++ b/pkg/controller/volume/scheduling/scheduler_binder_test.go
@@ -31,16 +31,20 @@ import (
     "k8s.io/apimachinery/pkg/util/diff"
     "k8s.io/apimachinery/pkg/util/wait"
     "k8s.io/apimachinery/pkg/watch"
+    utilfeature "k8s.io/apiserver/pkg/util/feature"
     "k8s.io/client-go/informers"
     coreinformers "k8s.io/client-go/informers/core/v1"
+    storageinformers "k8s.io/client-go/informers/storage/v1"
     clientset "k8s.io/client-go/kubernetes"
     "k8s.io/client-go/kubernetes/fake"
     k8stesting "k8s.io/client-go/testing"
+    featuregatetesting "k8s.io/component-base/featuregate/testing"
     "k8s.io/klog"
     "k8s.io/kubernetes/pkg/api/testapi"
     "k8s.io/kubernetes/pkg/controller"
     pvtesting "k8s.io/kubernetes/pkg/controller/volume/persistentvolume/testing"
     pvutil "k8s.io/kubernetes/pkg/controller/volume/persistentvolume/util"
+    "k8s.io/kubernetes/pkg/features"
 )
 
@@ -65,6 +69,10 @@ var (
     selectedNodePVC = makeTestPVC("provisioned-pvc", "1Gi", nodeLabelValue, pvcSelectedNode, "", "1", &waitClassWithProvisioner)
 
+    // PVCs for CSI migration
+    boundMigrationPVC     = makeTestPVC("pvc-migration-bound", "1G", "", pvcBound, "pv-migration-bound", "1", &waitClass)
+    provMigrationPVCBound = makeTestPVC("pvc-migration-provisioned", "1Gi", "", pvcBound, "pv-migration-bound", "1", &waitClassWithProvisioner)
+
     // PVs for manual binding
     pvNode1a = makeTestPV("pv-node1a", "node1", "5G", "1", nil, waitClass)
     pvNode1b = makeTestPV("pv-node1b", "node1", "10G", "1", nil, waitClass)
@@ -77,6 +85,10 @@ var (
     pvBoundImmediate      = makeTestPV("pv-bound-immediate", "node1", "1G", "1", immediateBoundPVC, immediateClass)
     pvBoundImmediateNode2 = makeTestPV("pv-bound-immediate", "node2", "1G", "1", immediateBoundPVC, immediateClass)
 
+    // PVs for CSI migration
+    migrationPVBound          = makeTestPVForCSIMigration(zone1Labels, boundMigrationPVC)
+    migrationPVBoundToUnbound = makeTestPVForCSIMigration(zone1Labels, unboundPVC)
+
     // storage class names
     waitClass      = "waitClass"
     immediateClass = "immediateClass"
@@ -87,20 +99,30 @@ var (
     node1         = makeNode("node1", map[string]string{nodeLabelKey: "node1"})
     node2         = makeNode("node2", map[string]string{nodeLabelKey: "node2"})
     node1NoLabels = makeNode("node1", nil)
+    node1Zone1    = makeNode("node1", map[string]string{"topology.gke.io/zone": "us-east-1"})
+    node1Zone2    = makeNode("node1", map[string]string{"topology.gke.io/zone": "us-east-2"})
+
+    // csiNode objects
+    csiNode1Migrated    = makeCSINode("node1", "kubernetes.io/gce-pd")
+    csiNode1NotMigrated = makeCSINode("node1", "")
 
     // node topology
     nodeLabelKey   = "nodeKey"
     nodeLabelValue = "node1"
+
+    // node topology for CSI migration
+    zone1Labels = map[string]string{v1.LabelZoneFailureDomain: "us-east-1", v1.LabelZoneRegion: "us-east-1a"}
 )
 
 type testEnv struct {
-    client               clientset.Interface
-    reactor              *pvtesting.VolumeReactor
-    binder               SchedulerVolumeBinder
-    internalBinder       *volumeBinder
-    internalNodeInformer coreinformers.NodeInformer
-    internalPVCache      *assumeCache
-    internalPVCCache     *assumeCache
+    client                  clientset.Interface
+    reactor                 *pvtesting.VolumeReactor
+    binder                  SchedulerVolumeBinder
+    internalBinder          *volumeBinder
+    internalNodeInformer    coreinformers.NodeInformer
+    internalCSINodeInformer storageinformers.CSINodeInformer
+    internalPVCache         *assumeCache
+    internalPVCCache        *assumeCache
 }
 
 func newTestBinder(t *testing.T, stopCh <-chan struct{}) *testEnv {
@@ -119,11 +141,13 @@ func newTestBinder(t *testing.T, stopCh <-chan struct{}) *testEnv {
     informerFactory := informers.NewSharedInformerFactory(client, controller.NoResyncPeriodFunc())
 
     nodeInformer := informerFactory.Core().V1().Nodes()
+    csiNodeInformer := informerFactory.Storage().V1().CSINodes()
     pvcInformer := informerFactory.Core().V1().PersistentVolumeClaims()
     classInformer := informerFactory.Storage().V1().StorageClasses()
     binder := NewVolumeBinder(
         client,
         nodeInformer,
+        csiNodeInformer,
         pvcInformer,
         informerFactory.Core().V1().PersistentVolumes(),
         classInformer,
@@ -214,13 +238,14 @@ func newTestBinder(t *testing.T, stopCh <-chan struct{}) *testEnv {
     }
 
     return &testEnv{
-        client:               client,
-        reactor:              reactor,
-        binder:               binder,
-        internalBinder:       internalBinder,
-        internalNodeInformer: nodeInformer,
-        internalPVCache:      internalPVCache,
-        internalPVCCache:     internalPVCCache,
+        client:                  client,
+        reactor:                 reactor,
+        binder:                  binder,
+        internalBinder:          internalBinder,
+        internalNodeInformer:    nodeInformer,
+        internalCSINodeInformer: csiNodeInformer,
+        internalPVCache:         internalPVCache,
+        internalPVCCache:        internalPVCCache,
     }
 }
 
@@ -231,6 +256,13 @@ func (env *testEnv) initNodes(cachedNodes []*v1.Node) {
     }
 }
 
+func (env *testEnv) initCSINodes(cachedCSINodes []*storagev1.CSINode) {
+    csiNodeInformer := env.internalCSINodeInformer.Informer()
+    for _, csiNode := range cachedCSINodes {
+        csiNodeInformer.GetIndexer().Add(csiNode)
+    }
+}
+
 func (env *testEnv) initClaims(cachedPVCs []*v1.PersistentVolumeClaim, apiPVCs []*v1.PersistentVolumeClaim) {
     internalPVCCache := env.internalPVCCache
     for _, pvc := range cachedPVCs {
@@ -593,6 +625,21 @@ func makeTestPV(name, node, capacity, version string, boundToPVC *v1.PersistentV
     return pv
 }
 
+func makeTestPVForCSIMigration(labels map[string]string, pvc *v1.PersistentVolumeClaim) *v1.PersistentVolume {
+    pv := makeTestPV("pv-migration-bound", "node1", "1G", "1", pvc, waitClass)
+    pv.Spec.NodeAffinity = nil // Will be written by the CSI translation lib
+    pv.ObjectMeta.Labels = labels
+    pv.Spec.PersistentVolumeSource = v1.PersistentVolumeSource{
+        GCEPersistentDisk: &v1.GCEPersistentDiskVolumeSource{
+            PDName:    "test-disk",
+            FSType:    "ext4",
+            Partition: 0,
+            ReadOnly:  false,
+        },
+    }
+    return pv
+}
+
 func pvcSetSelectedNode(pvc *v1.PersistentVolumeClaim, node string) *v1.PersistentVolumeClaim {
     newPVC := pvc.DeepCopy()
     metav1.SetMetaDataAnnotation(&newPVC.ObjectMeta, pvutil.AnnSelectedNode, node)
@@ -620,6 +667,17 @@ func makeNode(name string, labels map[string]string) *v1.Node {
     }
 }
 
+func makeCSINode(name, migratedPlugin string) *storagev1.CSINode {
+    return &storagev1.CSINode{
+        ObjectMeta: metav1.ObjectMeta{
+            Name: name,
+            Annotations: map[string]string{
+                v1.MigratedPluginsAnnotationKey: migratedPlugin,
+            },
+        },
+    }
+}
+
 func makePod(pvcs []*v1.PersistentVolumeClaim) *v1.Pod {
     pod := &v1.Pod{
         ObjectMeta: metav1.ObjectMeta{
@@ -983,6 +1041,115 @@ func TestFindPodVolumesWithProvisioning(t *testing.T) {
     }
 }
 
+// TestFindPodVolumesWithCSIMigration aims to test the node affinity check procedure that's
+// done in FindPodVolumes. In order to reach this code path, the given PVCs must be bound to a PV.
+func TestFindPodVolumesWithCSIMigration(t *testing.T) {
+    scenarios := map[string]struct {
+        // Inputs
+        pvs     []*v1.PersistentVolume
+        podPVCs []*v1.PersistentVolumeClaim
+        // If nil, use pod PVCs
+        cachePVCs []*v1.PersistentVolumeClaim
+        // If nil, makePod with podPVCs
+        pod *v1.Pod
+
+        // Setup
+        initNodes    []*v1.Node
+        initCSINodes []*storagev1.CSINode
+
+        // Expected return values
+        expectedUnbound bool
+        expectedBound   bool
+        shouldFail      bool
+    }{
+        "pvc-bound": {
+            podPVCs:         []*v1.PersistentVolumeClaim{boundMigrationPVC},
+            pvs:             []*v1.PersistentVolume{migrationPVBound},
+            initNodes:       []*v1.Node{node1Zone1},
+            initCSINodes:    []*storagev1.CSINode{csiNode1Migrated},
+            expectedBound:   true,
+            expectedUnbound: true,
+        },
+        "pvc-bound,csinode-not-migrated": {
+            podPVCs:         []*v1.PersistentVolumeClaim{boundMigrationPVC},
+            pvs:             []*v1.PersistentVolume{migrationPVBound},
+            initNodes:       []*v1.Node{node1Zone1},
+            initCSINodes:    []*storagev1.CSINode{csiNode1NotMigrated},
+            expectedBound:   true,
+            expectedUnbound: true,
+        },
+        "pvc-bound,missing-csinode": {
+            podPVCs:         []*v1.PersistentVolumeClaim{boundMigrationPVC},
+            pvs:             []*v1.PersistentVolume{migrationPVBound},
+            initNodes:       []*v1.Node{node1Zone1},
+            expectedBound:   true,
+            expectedUnbound: true,
+        },
+        "pvc-bound,node-different-zone": {
+            podPVCs:         []*v1.PersistentVolumeClaim{boundMigrationPVC},
+            pvs:             []*v1.PersistentVolume{migrationPVBound},
+            initNodes:       []*v1.Node{node1Zone2},
+            initCSINodes:    []*storagev1.CSINode{csiNode1Migrated},
+            expectedBound:   false,
+            expectedUnbound: true,
+        },
+    }
+
+    ctx, cancel := context.WithCancel(context.Background())
+    defer cancel()
+
+    defer featuregatetesting.SetFeatureGateDuringTest(t, utilfeature.DefaultFeatureGate, features.CSIMigration, true)()
+    defer featuregatetesting.SetFeatureGateDuringTest(t, utilfeature.DefaultFeatureGate, features.CSIMigrationGCE, true)()
+
+    for name, scenario := range scenarios {
+        klog.V(5).Infof("Running test case %q", name)
+
+        // Setup
+        testEnv := newTestBinder(t, ctx.Done())
+        testEnv.initVolumes(scenario.pvs, scenario.pvs)
+
+        var node *v1.Node
+        if len(scenario.initNodes) > 0 {
+            testEnv.initNodes(scenario.initNodes)
+            node = scenario.initNodes[0]
+        } else {
+            node = node1
+        }
+
+        if len(scenario.initCSINodes) > 0 {
+            testEnv.initCSINodes(scenario.initCSINodes)
+        }
+
+        // a. Init pvc cache
+        if scenario.cachePVCs == nil {
+            scenario.cachePVCs = scenario.podPVCs
+        }
+        testEnv.initClaims(scenario.cachePVCs, scenario.cachePVCs)
+
+        // b. Generate pod with given claims
+        if scenario.pod == nil {
+            scenario.pod = makePod(scenario.podPVCs)
+        }
+
+        // Execute
+        unboundSatisfied, boundSatisfied, err := testEnv.binder.FindPodVolumes(scenario.pod, node)
+
+        // Validate
+        if !scenario.shouldFail && err != nil {
+            t.Errorf("Test %q failed: returned error: %v", name, err)
+        }
+        if scenario.shouldFail && err == nil {
+            t.Errorf("Test %q failed: returned success but expected error", name)
+        }
+        if boundSatisfied != scenario.expectedBound {
+            t.Errorf("Test %q failed: expected boundSatisfied %v, got %v", name, scenario.expectedBound, boundSatisfied)
+        }
+        if unboundSatisfied != scenario.expectedUnbound {
+            t.Errorf("Test %q failed: expected unboundSatisfied %v, got %v", name, scenario.expectedUnbound, unboundSatisfied)
+        }
+    }
+}
+
 func TestAssumePodVolumes(t *testing.T) {
     scenarios := map[string]struct {
         // Inputs
@@ -1414,6 +1581,122 @@ func TestCheckBindings(t *testing.T) {
     }
 }
 
+func TestCheckBindingsWithCSIMigration(t *testing.T) {
+    scenarios := map[string]struct {
+        // Inputs
+        initPVs      []*v1.PersistentVolume
+        initPVCs     []*v1.PersistentVolumeClaim
+        initNodes    []*v1.Node
+        initCSINodes []*storagev1.CSINode
+
+        bindings        []*bindingInfo
+        provisionedPVCs []*v1.PersistentVolumeClaim
+
+        // API updates before checking
+        apiPVs  []*v1.PersistentVolume
+        apiPVCs []*v1.PersistentVolumeClaim
+
+        // Expected return values
+        shouldFail       bool
+        expectedBound    bool
+        migrationEnabled bool
+    }{
+        "provisioning-pvc-bound": {
+            bindings:        []*bindingInfo{},
+            provisionedPVCs: []*v1.PersistentVolumeClaim{addProvisionAnn(provMigrationPVCBound)},
+            initPVs:         []*v1.PersistentVolume{migrationPVBound},
+            initPVCs:        []*v1.PersistentVolumeClaim{provMigrationPVCBound},
+            initNodes:       []*v1.Node{node1Zone1},
+            initCSINodes:    []*storagev1.CSINode{csiNode1Migrated},
+            apiPVCs:         []*v1.PersistentVolumeClaim{addProvisionAnn(provMigrationPVCBound)},
+            expectedBound:   true,
+        },
+        "binding-node-pv-same-zone": {
+            bindings:         []*bindingInfo{makeBinding(unboundPVC, migrationPVBoundToUnbound)},
+            provisionedPVCs:  []*v1.PersistentVolumeClaim{},
+            initPVs:          []*v1.PersistentVolume{migrationPVBoundToUnbound},
+            initPVCs:         []*v1.PersistentVolumeClaim{unboundPVC},
+            initNodes:        []*v1.Node{node1Zone1},
+            initCSINodes:     []*storagev1.CSINode{csiNode1Migrated},
+            migrationEnabled: true,
+        },
+        "binding-without-csinode": {
+            bindings:         []*bindingInfo{makeBinding(unboundPVC, migrationPVBoundToUnbound)},
+            provisionedPVCs:  []*v1.PersistentVolumeClaim{},
+            initPVs:          []*v1.PersistentVolume{migrationPVBoundToUnbound},
+            initPVCs:         []*v1.PersistentVolumeClaim{unboundPVC},
+            initNodes:        []*v1.Node{node1Zone1},
+            initCSINodes:     []*storagev1.CSINode{},
+            migrationEnabled: true,
+        },
+        "binding-non-migrated-plugin": {
+            bindings:         []*bindingInfo{makeBinding(unboundPVC, migrationPVBoundToUnbound)},
+            provisionedPVCs:  []*v1.PersistentVolumeClaim{},
+            initPVs:          []*v1.PersistentVolume{migrationPVBoundToUnbound},
+            initPVCs:         []*v1.PersistentVolumeClaim{unboundPVC},
+            initNodes:        []*v1.Node{node1Zone1},
+            initCSINodes:     []*storagev1.CSINode{csiNode1NotMigrated},
+            migrationEnabled: true,
+        },
+        "binding-node-pv-in-different-zones": {
+            bindings:         []*bindingInfo{makeBinding(unboundPVC, migrationPVBoundToUnbound)},
+            provisionedPVCs:  []*v1.PersistentVolumeClaim{},
+            initPVs:          []*v1.PersistentVolume{migrationPVBoundToUnbound},
+            initPVCs:         []*v1.PersistentVolumeClaim{unboundPVC},
+            initNodes:        []*v1.Node{node1Zone2},
+            initCSINodes:     []*storagev1.CSINode{csiNode1Migrated},
+            migrationEnabled: true,
+            shouldFail:       true,
+        },
+        "binding-node-pv-different-zones-migration-off": {
+            bindings:         []*bindingInfo{makeBinding(unboundPVC, migrationPVBoundToUnbound)},
+            provisionedPVCs:  []*v1.PersistentVolumeClaim{},
+            initPVs:          []*v1.PersistentVolume{migrationPVBoundToUnbound},
+            initPVCs:         []*v1.PersistentVolumeClaim{unboundPVC},
+            initNodes:        []*v1.Node{node1Zone2},
+            initCSINodes:     []*storagev1.CSINode{csiNode1Migrated},
+            migrationEnabled: false,
+        },
+    }
+
+    ctx, cancel := context.WithCancel(context.Background())
+    defer cancel()
+
+    for name, scenario := range scenarios {
+        t.Run(name, func(t *testing.T) {
+            defer featuregatetesting.SetFeatureGateDuringTest(t, utilfeature.DefaultFeatureGate, features.CSIMigration, scenario.migrationEnabled)()
+            defer featuregatetesting.SetFeatureGateDuringTest(t, utilfeature.DefaultFeatureGate, features.CSIMigrationGCE, scenario.migrationEnabled)()
+
+            // Setup
+            pod := makePod(nil)
+            testEnv := newTestBinder(t, ctx.Done())
+            testEnv.initNodes(scenario.initNodes)
+            testEnv.initCSINodes(scenario.initCSINodes)
+            testEnv.initVolumes(scenario.initPVs, nil)
+            testEnv.initClaims(scenario.initPVCs, nil)
+            testEnv.assumeVolumes(t, name, "node1", pod, scenario.bindings, scenario.provisionedPVCs)
+
+            // Before execute
+            testEnv.updateVolumes(t, scenario.apiPVs, true)
+            testEnv.updateClaims(t, scenario.apiPVCs, true)
+
+            // Execute
+            allBound, err := testEnv.internalBinder.checkBindings(pod, scenario.bindings, scenario.provisionedPVCs)
+
+            // Validate
+            if !scenario.shouldFail && err != nil {
+                t.Errorf("Test %q failed: returned error: %v", name, err)
+            }
+            if scenario.shouldFail && err == nil {
+                t.Errorf("Test %q failed: returned success but expected error", name)
+            }
+            if scenario.expectedBound != allBound {
+                t.Errorf("Test %q failed: expected allBound %v, got %v", name, scenario.expectedBound, allBound)
+            }
+        })
+    }
+}
+
 func TestBindPodVolumes(t *testing.T) {
     type scenarioType struct {
         // Inputs
diff --git a/pkg/scheduler/scheduler.go b/pkg/scheduler/scheduler.go
index 7f652116e1d..0b237119944 100644
--- a/pkg/scheduler/scheduler.go
+++ b/pkg/scheduler/scheduler.go
@@ -290,6 +290,7 @@ func New(client clientset.Interface,
     volumeBinder := volumebinder.NewVolumeBinder(
         client,
         informerFactory.Core().V1().Nodes(),
+        informerFactory.Storage().V1().CSINodes(),
         informerFactory.Core().V1().PersistentVolumeClaims(),
         informerFactory.Core().V1().PersistentVolumes(),
         informerFactory.Storage().V1().StorageClasses(),
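
For downstream consumers, the signature change below means the CSINode informer now has to be threaded through binder construction. A sketch of the call-site shape; the newBinder helper and the config plumbing are illustrative, only the NewVolumeBinder argument list reflects this change.

package main

import (
    "time"

    "k8s.io/client-go/informers"
    "k8s.io/client-go/kubernetes"
    "k8s.io/client-go/rest"
    "k8s.io/kubernetes/pkg/scheduler/volumebinder"
)

// newBinder is a hypothetical helper showing how a caller wires the binder.
func newBinder(cfg *rest.Config) *volumebinder.VolumeBinder {
    client := kubernetes.NewForConfigOrDie(cfg)
    factory := informers.NewSharedInformerFactory(client, 0)
    return volumebinder.NewVolumeBinder(
        client,
        factory.Core().V1().Nodes(),
        factory.Storage().V1().CSINodes(), // the new argument
        factory.Core().V1().PersistentVolumeClaims(),
        factory.Core().V1().PersistentVolumes(),
        factory.Storage().V1().StorageClasses(),
        10*time.Second, // bind timeout; value illustrative
    )
}
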
"k8s.io/client-go/kubernetes" @@ -35,13 +35,14 @@ type VolumeBinder struct { func NewVolumeBinder( client clientset.Interface, nodeInformer coreinformers.NodeInformer, + csiNodeInformer storageinformers.CSINodeInformer, pvcInformer coreinformers.PersistentVolumeClaimInformer, pvInformer coreinformers.PersistentVolumeInformer, storageClassInformer storageinformers.StorageClassInformer, bindTimeout time.Duration) *VolumeBinder { return &VolumeBinder{ - Binder: volumescheduling.NewVolumeBinder(client, nodeInformer, pvcInformer, pvInformer, storageClassInformer, bindTimeout), + Binder: volumescheduling.NewVolumeBinder(client, nodeInformer, csiNodeInformer, pvcInformer, pvInformer, storageClassInformer, bindTimeout), } } diff --git a/staging/src/k8s.io/csi-translation-lib/BUILD b/staging/src/k8s.io/csi-translation-lib/BUILD index 64333e4fba2..cf7e98cc4c5 100644 --- a/staging/src/k8s.io/csi-translation-lib/BUILD +++ b/staging/src/k8s.io/csi-translation-lib/BUILD @@ -17,7 +17,11 @@ go_test( name = "go_default_test", srcs = ["translate_test.go"], embed = [":go_default_library"], - deps = ["//staging/src/k8s.io/api/core/v1:go_default_library"], + deps = [ + "//staging/src/k8s.io/api/core/v1:go_default_library", + "//staging/src/k8s.io/apimachinery/pkg/apis/meta/v1:go_default_library", + "//staging/src/k8s.io/csi-translation-lib/plugins:go_default_library", + ], ) filegroup( diff --git a/staging/src/k8s.io/csi-translation-lib/plugins/aws_ebs.go b/staging/src/k8s.io/csi-translation-lib/plugins/aws_ebs.go index 02a8ca1c42a..5390cbedd5d 100644 --- a/staging/src/k8s.io/csi-translation-lib/plugins/aws_ebs.go +++ b/staging/src/k8s.io/csi-translation-lib/plugins/aws_ebs.go @@ -31,6 +31,8 @@ import ( const ( // AWSEBSDriverName is the name of the CSI driver for EBS AWSEBSDriverName = "ebs.csi.aws.com" + // AWSEBSTopologyKey is the zonal topology key for AWS EBS CSI Driver + AWSEBSTopologyKey = "topology.ebs.csi.aws.com/zone" // AWSEBSInTreePluginName is the name of the intree plugin for EBS AWSEBSInTreePluginName = "kubernetes.io/aws-ebs" ) @@ -121,6 +123,10 @@ func (t *awsElasticBlockStoreCSITranslator) TranslateInTreePVToCSI(pv *v1.Persis }, } + if err := translateTopology(pv, AWSEBSTopologyKey); err != nil { + return nil, fmt.Errorf("failed to translate topology: %v", err) + } + pv.Spec.AWSElasticBlockStore = nil pv.Spec.CSI = csiSource return pv, nil diff --git a/staging/src/k8s.io/csi-translation-lib/plugins/gce_pd.go b/staging/src/k8s.io/csi-translation-lib/plugins/gce_pd.go index 7f07c814bb4..db93112f7ca 100644 --- a/staging/src/k8s.io/csi-translation-lib/plugins/gce_pd.go +++ b/staging/src/k8s.io/csi-translation-lib/plugins/gce_pd.go @@ -274,6 +274,10 @@ func (g *gcePersistentDiskCSITranslator) TranslateInTreePVToCSI(pv *v1.Persisten }, } + if err := translateTopology(pv, GCEPDTopologyKey); err != nil { + return nil, fmt.Errorf("failed to translate topology: %v", err) + } + pv.Spec.PersistentVolumeSource.GCEPersistentDisk = nil pv.Spec.PersistentVolumeSource.CSI = csiSource pv.Spec.AccessModes = backwardCompatibleAccessModes(pv.Spec.AccessModes) diff --git a/staging/src/k8s.io/csi-translation-lib/plugins/in_tree_volume.go b/staging/src/k8s.io/csi-translation-lib/plugins/in_tree_volume.go index 1415cbc7036..a93d82cfb3f 100644 --- a/staging/src/k8s.io/csi-translation-lib/plugins/in_tree_volume.go +++ b/staging/src/k8s.io/csi-translation-lib/plugins/in_tree_volume.go @@ -17,8 +17,13 @@ limitations under the License. 
 package plugins
 
 import (
+    "errors"
+    "strings"
+
     v1 "k8s.io/api/core/v1"
     storage "k8s.io/api/storage/v1"
+    "k8s.io/apimachinery/pkg/util/sets"
+    cloudvolume "k8s.io/cloud-provider/volume"
 )
 
 // InTreePlugin handles translations between CSI and in-tree sources in a PV
@@ -59,3 +64,92 @@ type InTreePlugin interface {
     // RepairVolumeHandle generates a correct volume handle based on node ID information.
     RepairVolumeHandle(volumeHandle, nodeID string) (string, error)
 }
+
+// replaceTopology overwrites an existing topology key with a new one.
+func replaceTopology(pv *v1.PersistentVolume, oldKey, newKey string) error {
+    for i := range pv.Spec.NodeAffinity.Required.NodeSelectorTerms {
+        for j, r := range pv.Spec.NodeAffinity.Required.NodeSelectorTerms[i].MatchExpressions {
+            if r.Key == oldKey {
+                pv.Spec.NodeAffinity.Required.NodeSelectorTerms[i].MatchExpressions[j].Key = newKey
+            }
+        }
+    }
+    return nil
+}
+
+// getTopologyZones returns all topology zones with the given key found in the PV.
+func getTopologyZones(pv *v1.PersistentVolume, key string) []string {
+    if pv.Spec.NodeAffinity == nil ||
+        pv.Spec.NodeAffinity.Required == nil ||
+        len(pv.Spec.NodeAffinity.Required.NodeSelectorTerms) < 1 {
+        return nil
+    }
+
+    var values []string
+    for i := range pv.Spec.NodeAffinity.Required.NodeSelectorTerms {
+        for _, r := range pv.Spec.NodeAffinity.Required.NodeSelectorTerms[i].MatchExpressions {
+            if r.Key == key {
+                values = append(values, r.Values...)
+            }
+        }
+    }
+    return values
+}
+
+// addTopology appends the topology to the given PV.
+func addTopology(pv *v1.PersistentVolume, topologyKey string, zones []string) error {
+    // Make sure there are no duplicate or empty strings
+    filteredZones := sets.String{}
+    for i := range zones {
+        zone := strings.TrimSpace(zones[i])
+        if len(zone) > 0 {
+            filteredZones.Insert(zone)
+        }
+    }
+
+    zones = filteredZones.UnsortedList()
+    if len(zones) < 1 {
+        return errors.New("there are no valid zones to add to pv")
+    }
+
+    // Make sure the necessary fields exist
+    pv.Spec.NodeAffinity = new(v1.VolumeNodeAffinity)
+    pv.Spec.NodeAffinity.Required = new(v1.NodeSelector)
+    pv.Spec.NodeAffinity.Required.NodeSelectorTerms = make([]v1.NodeSelectorTerm, 1)
+
+    topology := v1.NodeSelectorRequirement{
+        Key:      topologyKey,
+        Operator: v1.NodeSelectorOpIn,
+        Values:   zones,
+    }
+
+    pv.Spec.NodeAffinity.Required.NodeSelectorTerms[0].MatchExpressions = append(
+        pv.Spec.NodeAffinity.Required.NodeSelectorTerms[0].MatchExpressions,
+        topology,
+    )
+
+    return nil
+}
+
+// translateTopology converts existing zone labels or in-tree topology to CSI topology.
+// In-tree topology has precedence over zone labels.
+func translateTopology(pv *v1.PersistentVolume, topologyKey string) error {
+    // If topology is already set, assume the content is accurate
+    if len(getTopologyZones(pv, topologyKey)) > 0 {
+        return nil
+    }
+
+    zones := getTopologyZones(pv, v1.LabelZoneFailureDomain)
+    if len(zones) > 0 {
+        return replaceTopology(pv, v1.LabelZoneFailureDomain, topologyKey)
+    }
+
+    if label, ok := pv.Labels[v1.LabelZoneFailureDomain]; ok {
+        zones = strings.Split(label, cloudvolume.LabelMultiZoneDelimiter)
+        if len(zones) > 0 {
+            return addTopology(pv, topologyKey, zones)
+        }
+    }
+
+    return nil
+}
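
To make the precedence above concrete, here is a sketch of a test that could live in package plugins; it is hypothetical and not part of this change. A PV carrying only the beta zone label (and no NodeAffinity) takes the label path and ends up with NodeAffinity on the CSI topology key.

package plugins

import (
    "testing"

    v1 "k8s.io/api/core/v1"
    metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)

// TestTranslateTopologySketch is hypothetical and not part of this change.
func TestTranslateTopologySketch(t *testing.T) {
    // PV with only the zone label set: the label branch of translateTopology applies.
    pv := &v1.PersistentVolume{
        ObjectMeta: metav1.ObjectMeta{
            Labels: map[string]string{v1.LabelZoneFailureDomain: "us-east-1a"},
        },
    }
    if err := translateTopology(pv, GCEPDTopologyKey); err != nil {
        t.Fatal(err)
    }
    got := pv.Spec.NodeAffinity.Required.NodeSelectorTerms[0].MatchExpressions[0]
    if got.Key != GCEPDTopologyKey || len(got.Values) != 1 || got.Values[0] != "us-east-1a" {
        t.Errorf("unexpected topology requirement: %+v", got)
    }
}
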
diff --git a/staging/src/k8s.io/csi-translation-lib/plugins/openstack_cinder.go b/staging/src/k8s.io/csi-translation-lib/plugins/openstack_cinder.go
index abf578fb4a5..33252bad340 100644
--- a/staging/src/k8s.io/csi-translation-lib/plugins/openstack_cinder.go
+++ b/staging/src/k8s.io/csi-translation-lib/plugins/openstack_cinder.go
@@ -27,6 +27,8 @@ import (
 const (
     // CinderDriverName is the name of the CSI driver for Cinder
     CinderDriverName = "cinder.csi.openstack.org"
+    // CinderTopologyKey is the zonal topology key for Cinder CSI Driver
+    CinderTopologyKey = "topology.cinder.csi.openstack.org/zone"
     // CinderInTreePluginName is the name of the intree plugin for Cinder
     CinderInTreePluginName = "kubernetes.io/cinder"
 )
@@ -92,6 +94,10 @@ func (t *osCinderCSITranslator) TranslateInTreePVToCSI(pv *v1.PersistentVolume)
         VolumeAttributes: map[string]string{},
     }
 
+    if err := translateTopology(pv, CinderTopologyKey); err != nil {
+        return nil, fmt.Errorf("failed to translate topology: %v", err)
+    }
+
     pv.Spec.Cinder = nil
     pv.Spec.CSI = csiSource
     return pv, nil
diff --git a/staging/src/k8s.io/csi-translation-lib/translate_test.go b/staging/src/k8s.io/csi-translation-lib/translate_test.go
index a78f46ae923..83e6a125cf4 100644
--- a/staging/src/k8s.io/csi-translation-lib/translate_test.go
+++ b/staging/src/k8s.io/csi-translation-lib/translate_test.go
@@ -20,7 +20,19 @@ import (
     "reflect"
     "testing"
 
-    "k8s.io/api/core/v1"
+    v1 "k8s.io/api/core/v1"
+    metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+    "k8s.io/csi-translation-lib/plugins"
+)
+
+var (
+    defaultZoneLabels = map[string]string{
+        v1.LabelZoneFailureDomain: "us-east-1a",
+        v1.LabelZoneRegion:        "us-east-1",
+    }
+    regionalPDLabels = map[string]string{
+        v1.LabelZoneFailureDomain: "europe-west1-b__europe-west1-c",
+    }
 )
 
 func TestTranslationStability(t *testing.T) {
@@ -77,6 +89,229 @@ func TestTranslationStability(t *testing.T) {
     }
 }
 
+func TestTopologyTranslation(t *testing.T) {
+    testCases := []struct {
+        name                 string
+        pv                   *v1.PersistentVolume
+        expectedNodeAffinity *v1.VolumeNodeAffinity
+    }{
+        {
+            name:                 "GCE PD with zone labels",
+            pv:                   makeGCEPDPV(defaultZoneLabels, nil /*topology*/),
+            expectedNodeAffinity: makeNodeAffinity(false /*multiTerms*/, plugins.GCEPDTopologyKey, "us-east-1a"),
+        },
+        {
+            name:                 "GCE PD with existing topology (beta keys)",
+            pv:                   makeGCEPDPV(nil /*labels*/, makeTopology(v1.LabelZoneFailureDomain, "us-east-2a")),
+            expectedNodeAffinity: makeNodeAffinity(false /*multiTerms*/, plugins.GCEPDTopologyKey, "us-east-2a"),
+        },
+        {
+            name:                 "GCE PD with existing topology (CSI keys)",
+            pv:                   makeGCEPDPV(nil /*labels*/, makeTopology(plugins.GCEPDTopologyKey, "us-east-2a")),
+            expectedNodeAffinity: makeNodeAffinity(false /*multiTerms*/, plugins.GCEPDTopologyKey, "us-east-2a"),
+        },
+        {
+            name:                 "GCE PD with zone labels and topology",
+            pv:                   makeGCEPDPV(defaultZoneLabels, makeTopology(v1.LabelZoneFailureDomain, "us-east-2a")),
+            expectedNodeAffinity: makeNodeAffinity(false /*multiTerms*/, plugins.GCEPDTopologyKey, "us-east-2a"),
+        },
+        {
+            name:                 "GCE PD with regional zones",
+            pv:                   makeGCEPDPV(regionalPDLabels, nil /*topology*/),
+            expectedNodeAffinity: makeNodeAffinity(false /*multiTerms*/, plugins.GCEPDTopologyKey, "europe-west1-b", "europe-west1-c"),
+        },
+        {
+            name:                 "GCE PD with regional topology",
+            pv:                   makeGCEPDPV(nil /*labels*/, makeTopology(v1.LabelZoneFailureDomain, "europe-west1-b", "europe-west1-c")),
+            expectedNodeAffinity: makeNodeAffinity(false /*multiTerms*/, plugins.GCEPDTopologyKey, "europe-west1-b", "europe-west1-c"),
+        },
+        {
+            name:                 "GCE PD with regional zone and topology",
+            pv:                   makeGCEPDPV(regionalPDLabels, makeTopology(v1.LabelZoneFailureDomain, "europe-west1-f", "europe-west1-g")),
+            expectedNodeAffinity: makeNodeAffinity(false /*multiTerms*/, plugins.GCEPDTopologyKey, "europe-west1-f", "europe-west1-g"),
+        },
+        {
+            name: "GCE PD with multiple node selector terms",
+            pv: makeGCEPDPVMultTerms(
+                nil, /*labels*/
+                makeTopology(v1.LabelZoneFailureDomain, "europe-west1-f"),
+                makeTopology(v1.LabelZoneFailureDomain, "europe-west1-g")),
+            expectedNodeAffinity: makeNodeAffinity(
+                true, /*multiTerms*/
+                plugins.GCEPDTopologyKey, "europe-west1-f", "europe-west1-g"),
+        },
+        // EBS test cases: test mostly topology key, i.e., don't repeat testing done with GCE
+        {
+            name:                 "AWS EBS with zone labels",
+            pv:                   makeAWSEBSPV(defaultZoneLabels, nil /*topology*/),
+            expectedNodeAffinity: makeNodeAffinity(false /*multiTerms*/, plugins.AWSEBSTopologyKey, "us-east-1a"),
+        },
+        {
+            name:                 "AWS EBS with zone labels and topology",
+            pv:                   makeAWSEBSPV(defaultZoneLabels, makeTopology(v1.LabelZoneFailureDomain, "us-east-2a")),
+            expectedNodeAffinity: makeNodeAffinity(false /*multiTerms*/, plugins.AWSEBSTopologyKey, "us-east-2a"),
+        },
+        // Cinder test cases: test mostly topology key, i.e., don't repeat testing done with GCE
+        {
+            name:                 "OpenStack Cinder with zone labels",
+            pv:                   makeCinderPV(defaultZoneLabels, nil /*topology*/),
+            expectedNodeAffinity: makeNodeAffinity(false /*multiTerms*/, plugins.CinderTopologyKey, "us-east-1a"),
+        },
+        {
+            name:                 "OpenStack Cinder with zone labels and topology",
+            pv:                   makeCinderPV(defaultZoneLabels, makeTopology(v1.LabelZoneFailureDomain, "us-east-2a")),
+            expectedNodeAffinity: makeNodeAffinity(false /*multiTerms*/, plugins.CinderTopologyKey, "us-east-2a"),
+        },
+    }
+
+    for _, test := range testCases {
+        ctl := New()
+        t.Logf("Testing %v", test.name)
+
+        // Translate to CSI PV and check translated node affinity
+        newCSIPV, err := ctl.TranslateInTreePVToCSI(test.pv)
+        if err != nil {
+            t.Errorf("Error when translating to CSI: %v", err)
+        }
+
+        nodeAffinity := newCSIPV.Spec.NodeAffinity
+        if !reflect.DeepEqual(nodeAffinity, test.expectedNodeAffinity) {
+            t.Errorf("Expected node affinity %v, got %v", *test.expectedNodeAffinity, *nodeAffinity)
+        }
+
+        // Translate back to in-tree and make sure node affinity is still set
+        newInTreePV, err := ctl.TranslateCSIPVToInTree(newCSIPV)
+        if err != nil {
+            t.Errorf("Error when translating to in-tree: %v", err)
+        }
+
+        nodeAffinity = newInTreePV.Spec.NodeAffinity
+        if !reflect.DeepEqual(nodeAffinity, test.expectedNodeAffinity) {
+            t.Errorf("Expected node affinity %v, got %v", *test.expectedNodeAffinity, *nodeAffinity)
+        }
+    }
+}
+
+func makePV(labels map[string]string, topology *v1.NodeSelectorRequirement) *v1.PersistentVolume {
+    pv := &v1.PersistentVolume{
+        ObjectMeta: metav1.ObjectMeta{
+            Labels: labels,
+        },
+        Spec: v1.PersistentVolumeSpec{},
+    }
+
+    if topology != nil {
+        pv.Spec.NodeAffinity = &v1.VolumeNodeAffinity{
+            Required: &v1.NodeSelector{
+                NodeSelectorTerms: []v1.NodeSelectorTerm{
+                    {MatchExpressions: []v1.NodeSelectorRequirement{*topology}},
+                },
+            },
+        }
+    }
+
+    return pv
+}
+
+func makeGCEPDPV(labels map[string]string, topology *v1.NodeSelectorRequirement) *v1.PersistentVolume {
+    pv := makePV(labels, topology)
+    pv.Spec.PersistentVolumeSource = v1.PersistentVolumeSource{
+        GCEPersistentDisk: &v1.GCEPersistentDiskVolumeSource{
+            PDName:    "test-disk",
+            FSType:    "ext4",
+            Partition: 0,
+            ReadOnly:  false,
+        },
+    }
+    return pv
+}
+
+func makeGCEPDPVMultTerms(labels map[string]string, topologies ...*v1.NodeSelectorRequirement) *v1.PersistentVolume {
+    pv := makeGCEPDPV(labels, topologies[0])
+    for _, topology := range topologies[1:] {
+        pv.Spec.NodeAffinity.Required.NodeSelectorTerms = append(
+            pv.Spec.NodeAffinity.Required.NodeSelectorTerms,
+            v1.NodeSelectorTerm{
+                MatchExpressions: []v1.NodeSelectorRequirement{*topology},
+            },
+        )
+    }
+    return pv
+}
+
+func makeAWSEBSPV(labels map[string]string, topology *v1.NodeSelectorRequirement) *v1.PersistentVolume {
+    pv := makePV(labels, topology)
+    pv.Spec.PersistentVolumeSource = v1.PersistentVolumeSource{
+        AWSElasticBlockStore: &v1.AWSElasticBlockStoreVolumeSource{
+            VolumeID:  "vol01",
+            FSType:    "ext3",
+            Partition: 1,
+            ReadOnly:  true,
+        },
+    }
+    return pv
+}
+
+func makeCinderPV(labels map[string]string, topology *v1.NodeSelectorRequirement) *v1.PersistentVolume {
+    pv := makePV(labels, topology)
+    pv.Spec.PersistentVolumeSource = v1.PersistentVolumeSource{
+        Cinder: &v1.CinderPersistentVolumeSource{
+            VolumeID: "vol1",
+            FSType:   "ext4",
+            ReadOnly: false,
+        },
+    }
+    return pv
+}
+
+func makeNodeAffinity(multiTerms bool, key string, values ...string) *v1.VolumeNodeAffinity {
+    nodeAffinity := &v1.VolumeNodeAffinity{
+        Required: &v1.NodeSelector{
+            NodeSelectorTerms: []v1.NodeSelectorTerm{
+                {
+                    MatchExpressions: []v1.NodeSelectorRequirement{
+                        {
+                            Key:      key,
+                            Operator: v1.NodeSelectorOpIn,
+                            Values:   values,
+                        },
+                    },
+                },
+            },
+        },
+    }
+
+    // If multiple terms is NOT requested, return a single term with all values
+    if !multiTerms {
+        return nodeAffinity
+    }
+
+    // Otherwise return multiple terms, each one with a single value
+    nodeAffinity.Required.NodeSelectorTerms[0].MatchExpressions[0].Values = values[:1] // If values=[1,2,3], overwrite with [1]
+    for _, value := range values[1:] {
+        term := v1.NodeSelectorTerm{
+            MatchExpressions: []v1.NodeSelectorRequirement{
+                {
+                    Key:      key,
+                    Operator: v1.NodeSelectorOpIn,
+                    Values:   []string{value},
+                },
+            },
+        }
+        nodeAffinity.Required.NodeSelectorTerms = append(nodeAffinity.Required.NodeSelectorTerms, term)
+    }
+
+    return nodeAffinity
+}
+
+func makeTopology(key string, values ...string) *v1.NodeSelectorRequirement {
+    return &v1.NodeSelectorRequirement{
+        Key:      key,
+        Operator: v1.NodeSelectorOpIn,
+        Values:   values,
+    }
+}
+
 func TestPluginNameMappings(t *testing.T) {
     testCases := []struct {
         name string