diff --git a/pkg/volume/csi/csi_attacher.go b/pkg/volume/csi/csi_attacher.go index 3460ce0d7e8..720cbce04da 100644 --- a/pkg/volume/csi/csi_attacher.go +++ b/pkg/volume/csi/csi_attacher.go @@ -20,11 +20,16 @@ import ( "crypto/sha256" "errors" "fmt" + "os" + "path" + "path/filepath" "strings" "time" "github.com/golang/glog" + grpctx "golang.org/x/net/context" + csipb "github.com/container-storage-interface/spec/lib/go/csi/v0" "k8s.io/api/core/v1" storage "k8s.io/api/storage/v1beta1" apierrs "k8s.io/apimachinery/pkg/api/errors" @@ -35,10 +40,17 @@ import ( "k8s.io/kubernetes/pkg/volume" ) +const ( + persistentVolumeInGlobalPath = "pv" + globalMountInGlobalPath = "globalmount" +) + type csiAttacher struct { plugin *csiPlugin k8s kubernetes.Interface waitSleepTime time.Duration + + csiClient csiClient } // volume.Attacher methods @@ -229,12 +241,125 @@ func (c *csiAttacher) VolumesAreAttached(specs []*volume.Spec, nodeName types.No } func (c *csiAttacher) GetDeviceMountPath(spec *volume.Spec) (string, error) { - glog.V(4).Info(log("attacher.GetDeviceMountPath is not implemented")) - return "", nil + glog.V(4).Info(log("attacher.GetDeviceMountPath(%v)", spec)) + deviceMountPath, err := makeDeviceMountPath(c.plugin, spec) + if err != nil { + glog.Error(log("attacher.GetDeviceMountPath failed to make device mount path: %v", err)) + return "", err + } + glog.V(4).Infof("attacher.GetDeviceMountPath succeeded, deviceMountPath: %s", deviceMountPath) + return deviceMountPath, nil } func (c *csiAttacher) MountDevice(spec *volume.Spec, devicePath string, deviceMountPath string) error { - glog.V(4).Info(log("attacher.MountDevice is not implemented")) + glog.V(4).Infof(log("attacher.MountDevice(%s, %s)", devicePath, deviceMountPath)) + + mounted, err := isDirMounted(c.plugin, deviceMountPath) + if err != nil { + glog.Error(log("attacher.MountDevice failed while checking mount status for dir [%s]", deviceMountPath)) + return err + } + + if mounted { + glog.V(4).Info(log("attacher.MountDevice skipping mount, dir already mounted [%s]", deviceMountPath)) + return nil + } + + // Setup + if spec == nil { + return fmt.Errorf("attacher.MountDevice failed, spec is nil") + } + csiSource, err := getCSISourceFromSpec(spec) + if err != nil { + glog.Error(log("attacher.MountDevice failed to get CSI persistent source: %v", err)) + return err + } + + if c.csiClient == nil { + if csiSource.Driver == "" { + return fmt.Errorf("attacher.MountDevice failed, driver name is empty") + } + addr := fmt.Sprintf(csiAddrTemplate, csiSource.Driver) + c.csiClient = newCsiDriverClient("unix", addr) + } + csi := c.csiClient + + ctx, cancel := grpctx.WithTimeout(grpctx.Background(), csiTimeout) + defer cancel() + // Check whether "STAGE_UNSTAGE_VOLUME" is set + stageUnstageSet, err := hasStageUnstageCapability(ctx, csi) + if err != nil { + glog.Error(log("attacher.MountDevice failed to check STAGE_UNSTAGE_VOLUME: %v", err)) + return err + } + if !stageUnstageSet { + glog.Infof(log("attacher.MountDevice STAGE_UNSTAGE_VOLUME capability not set. Skipping MountDevice...")) + return nil + } + + // Start MountDevice + if deviceMountPath == "" { + return fmt.Errorf("attacher.MountDevice failed, deviceMountPath is empty") + } + + nodeName := string(c.plugin.host.GetNodeName()) + attachID := getAttachmentName(csiSource.VolumeHandle, csiSource.Driver, nodeName) + + // search for attachment by VolumeAttachment.Spec.Source.PersistentVolumeName + attachment, err := c.k8s.StorageV1beta1().VolumeAttachments().Get(attachID, meta.GetOptions{}) + if err != nil { + glog.Error(log("attacher.MountDevice failed while getting volume attachment [id=%v]: %v", attachID, err)) + return err + } + + if attachment == nil { + glog.Error(log("unable to find VolumeAttachment [id=%s]", attachID)) + return errors.New("no existing VolumeAttachment found") + } + publishVolumeInfo := attachment.Status.AttachmentMetadata + + // create target_dir before call to NodeStageVolume + if err := os.MkdirAll(deviceMountPath, 0750); err != nil { + glog.Error(log("attacher.MountDevice failed to create dir %#v: %v", deviceMountPath, err)) + return err + } + glog.V(4).Info(log("created target path successfully [%s]", deviceMountPath)) + + //TODO (vladimirvivien) implement better AccessModes mapping between k8s and CSI + accessMode := v1.ReadWriteOnce + if spec.PersistentVolume.Spec.AccessModes != nil { + accessMode = spec.PersistentVolume.Spec.AccessModes[0] + } + + fsType := csiSource.FSType + if len(fsType) == 0 { + fsType = defaultFSType + } + + nodeStageSecrets := map[string]string{} + if csiSource.NodeStageSecretRef != nil { + nodeStageSecrets = getCredentialsFromSecret(c.k8s, csiSource.NodeStageSecretRef) + } + + err = csi.NodeStageVolume(ctx, + csiSource.VolumeHandle, + publishVolumeInfo, + deviceMountPath, + fsType, + accessMode, + nodeStageSecrets, + csiSource.VolumeAttributes) + + if err != nil { + glog.Errorf(log("attacher.MountDevice failed: %v", err)) + if err := removeMountDir(c.plugin, deviceMountPath); err != nil { + glog.Error(log("attacher.MountDevice failed to remove mount dir after a NodeStageVolume() error [%s]: %v", deviceMountPath, err)) + return err + } + return err + } + + glog.V(4).Infof(log("attacher.MountDevice successfully requested NodeStageVolume [%s]", deviceMountPath)) return nil } @@ -335,12 +460,111 @@ func (c *csiAttacher) waitForVolumeDetachmentInternal(volumeHandle, attachID str } func (c *csiAttacher) UnmountDevice(deviceMountPath string) error { - glog.V(4).Info(log("detacher.UnmountDevice is not implemented")) + glog.V(4).Info(log("attacher.UnmountDevice(%s)", deviceMountPath)) + + // Setup + driverName, volID, err := getDriverAndVolNameFromDeviceMountPath(c.k8s, deviceMountPath) + if err != nil { + glog.Errorf(log("attacher.UnmountDevice failed to get driver and volume name from device mount path: %v", err)) + return err + } + + if c.csiClient == nil { + addr := fmt.Sprintf(csiAddrTemplate, driverName) + c.csiClient = newCsiDriverClient("unix", addr) + } + csi := c.csiClient + + ctx, cancel := grpctx.WithTimeout(grpctx.Background(), csiTimeout) + defer cancel() + // Check whether "STAGE_UNSTAGE_VOLUME" is set + stageUnstageSet, err := hasStageUnstageCapability(ctx, csi) + if err != nil { + glog.Errorf(log("attacher.UnmountDevice failed to check whether STAGE_UNSTAGE_VOLUME set: %v", err)) + return err + } + if !stageUnstageSet { + glog.Infof(log("attacher.UnmountDevice STAGE_UNSTAGE_VOLUME capability not set. Skipping UnmountDevice...")) + return nil + } + + // Start UnmountDevice + err = csi.NodeUnstageVolume(ctx, + volID, + deviceMountPath) + + if err != nil { + glog.Errorf(log("attacher.UnmountDevice failed: %v", err)) + return err + } + + glog.V(4).Infof(log("attacher.UnmountDevice successfully requested NodeStageVolume [%s]", deviceMountPath)) return nil } +func hasStageUnstageCapability(ctx grpctx.Context, csi csiClient) (bool, error) { + capabilities, err := csi.NodeGetCapabilities(ctx) + if err != nil { + return false, err + } + + stageUnstageSet := false + if capabilities == nil { + return false, nil + } + for _, capability := range capabilities { + if capability.GetRpc().GetType() == csipb.NodeServiceCapability_RPC_STAGE_UNSTAGE_VOLUME { + stageUnstageSet = true + } + } + return stageUnstageSet, nil +} + // getAttachmentName returns csi- func getAttachmentName(volName, csiDriverName, nodeName string) string { result := sha256.Sum256([]byte(fmt.Sprintf("%s%s%s", volName, csiDriverName, nodeName))) return fmt.Sprintf("csi-%x", result) } + +func makeDeviceMountPath(plugin *csiPlugin, spec *volume.Spec) (string, error) { + if spec == nil { + return "", fmt.Errorf("makeDeviceMountPath failed, spec is nil") + } + + pvName := spec.PersistentVolume.Name + if pvName == "" { + return "", fmt.Errorf("makeDeviceMountPath failed, pv name empty") + } + + return path.Join(plugin.host.GetPluginDir(plugin.GetPluginName()), persistentVolumeInGlobalPath, pvName, globalMountInGlobalPath), nil +} + +func getDriverAndVolNameFromDeviceMountPath(k8s kubernetes.Interface, deviceMountPath string) (string, string, error) { + // deviceMountPath structure: /var/lib/kubelet/plugins/kubernetes.io/csi/pv/{pvname}/globalmount + dir := filepath.Dir(deviceMountPath) + if file := filepath.Base(deviceMountPath); file != globalMountInGlobalPath { + return "", "", fmt.Errorf("getDriverAndVolNameFromDeviceMountPath failed, path did not end in %s", globalMountInGlobalPath) + } + // dir is now /var/lib/kubelet/plugins/kubernetes.io/csi/pv/{pvname} + pvName := filepath.Base(dir) + + // Get PV and check for errors + pv, err := k8s.CoreV1().PersistentVolumes().Get(pvName, meta.GetOptions{}) + if err != nil { + return "", "", err + } + if pv == nil || pv.Spec.CSI == nil { + return "", "", fmt.Errorf("getDriverAndVolNameFromDeviceMountPath could not find CSI Persistent Volume Source for pv: %s", pvName) + } + + // Get VolumeHandle and PluginName from pv + csiSource := pv.Spec.CSI + if csiSource.Driver == "" { + return "", "", fmt.Errorf("getDriverAndVolNameFromDeviceMountPath failed, driver name empty") + } + if csiSource.VolumeHandle == "" { + return "", "", fmt.Errorf("getDriverAndVolNameFromDeviceMountPath failed, VolumeHandle empty") + } + + return csiSource.Driver, csiSource.VolumeHandle, nil +} diff --git a/pkg/volume/csi/csi_attacher_test.go b/pkg/volume/csi/csi_attacher_test.go index 4b0d67baa2b..9266acbb145 100644 --- a/pkg/volume/csi/csi_attacher_test.go +++ b/pkg/volume/csi/csi_attacher_test.go @@ -19,6 +19,7 @@ package csi import ( "fmt" "os" + "path/filepath" "testing" "time" @@ -31,6 +32,7 @@ import ( core "k8s.io/client-go/testing" utiltesting "k8s.io/client-go/util/testing" "k8s.io/kubernetes/pkg/volume" + "k8s.io/kubernetes/pkg/volume/csi/fake" volumetest "k8s.io/kubernetes/pkg/volume/testing" ) @@ -386,6 +388,294 @@ func TestAttacherDetach(t *testing.T) { } } +func TestAttacherGetDeviceMountPath(t *testing.T) { + // Setup + // Create a new attacher + plug, _, tmpDir := newTestWatchPlugin(t) + defer os.RemoveAll(tmpDir) + attacher, err0 := plug.NewAttacher() + if err0 != nil { + t.Fatalf("failed to create new attacher: %v", err0) + } + csiAttacher := attacher.(*csiAttacher) + + pluginDir := csiAttacher.plugin.host.GetPluginDir(plug.GetPluginName()) + + testCases := []struct { + testName string + pvName string + expectedMountPath string + shouldFail bool + }{ + { + testName: "normal test", + pvName: "test-pv1", + expectedMountPath: pluginDir + "/pv/test-pv1/globalmount", + }, + { + testName: "no pv name", + pvName: "", + expectedMountPath: pluginDir + "/pv/test-pv1/globalmount", + shouldFail: true, + }, + } + + for _, tc := range testCases { + t.Logf("Running test case: %s", tc.testName) + var spec *volume.Spec + + // Create spec + pv := makeTestPV(tc.pvName, 10, testDriver, "testvol") + spec = volume.NewSpecFromPersistentVolume(pv, pv.Spec.PersistentVolumeSource.CSI.ReadOnly) + + // Run + mountPath, err := csiAttacher.GetDeviceMountPath(spec) + + // Verify + if err != nil && !tc.shouldFail { + t.Errorf("test should not fail, but error occurred: %v", err) + } else if err == nil { + if tc.shouldFail { + t.Errorf("test should fail, but no error occurred") + } else if mountPath != tc.expectedMountPath { + t.Errorf("mountPath does not equal expectedMountPath. Got: %s. Expected: %s", mountPath, tc.expectedMountPath) + } + } + } +} + +func TestAttacherMountDevice(t *testing.T) { + testCases := []struct { + testName string + volName string + devicePath string + deviceMountPath string + stageUnstageSet bool + shouldFail bool + }{ + { + testName: "normal", + volName: "test-vol1", + devicePath: "path1", + deviceMountPath: "path2", + stageUnstageSet: true, + }, + { + testName: "no vol name", + volName: "", + devicePath: "path1", + deviceMountPath: "path2", + stageUnstageSet: true, + shouldFail: true, + }, + { + testName: "no device path", + volName: "test-vol1", + devicePath: "", + deviceMountPath: "path2", + stageUnstageSet: true, + shouldFail: true, + }, + { + testName: "no device mount path", + volName: "test-vol1", + devicePath: "path1", + deviceMountPath: "", + stageUnstageSet: true, + shouldFail: true, + }, + { + testName: "stage_unstage cap not set", + volName: "test-vol1", + devicePath: "path1", + deviceMountPath: "path2", + stageUnstageSet: false, + }, + { + testName: "stage_unstage not set no vars should not fail", + stageUnstageSet: false, + }, + } + + for _, tc := range testCases { + t.Logf("Running test case: %s", tc.testName) + var spec *volume.Spec + pvName := "test-pv" + + // Setup + // Create a new attacher + plug, fakeWatcher, tmpDir := newTestWatchPlugin(t) + defer os.RemoveAll(tmpDir) + attacher, err0 := plug.NewAttacher() + if err0 != nil { + t.Fatalf("failed to create new attacher: %v", err0) + } + csiAttacher := attacher.(*csiAttacher) + csiAttacher.csiClient = setupClient(t, tc.stageUnstageSet) + + nodeName := string(csiAttacher.plugin.host.GetNodeName()) + + // Create spec + pv := makeTestPV(pvName, 10, testDriver, tc.volName) + spec = volume.NewSpecFromPersistentVolume(pv, pv.Spec.PersistentVolumeSource.CSI.ReadOnly) + + attachID := getAttachmentName(tc.volName, testDriver, nodeName) + + // Set up volume attachment + attachment := makeTestAttachment(attachID, nodeName, pvName) + _, err := csiAttacher.k8s.StorageV1beta1().VolumeAttachments().Create(attachment) + if err != nil { + t.Fatalf("failed to attach: %v", err) + } + go func() { + fakeWatcher.Delete(attachment) + }() + + // Run + err = csiAttacher.MountDevice(spec, tc.devicePath, tc.deviceMountPath) + + // Verify + if err != nil { + if !tc.shouldFail { + t.Errorf("test should not fail, but error occurred: %v", err) + } + return + } + if err == nil && tc.shouldFail { + t.Errorf("test should fail, but no error occurred") + } + + // Verify call goes through all the way + numStaged := 1 + if !tc.stageUnstageSet { + numStaged = 0 + } + + cdc := csiAttacher.csiClient.(*csiDriverClient) + staged := cdc.nodeClient.(*fake.NodeClient).GetNodeStagedVolumes() + if len(staged) != numStaged { + t.Errorf("got wrong number of staged volumes, expecting %v got: %v", numStaged, len(staged)) + } + if tc.stageUnstageSet { + gotPath, ok := staged[tc.volName] + if !ok { + t.Errorf("could not find staged volume: %s", tc.volName) + } + if gotPath != tc.deviceMountPath { + t.Errorf("expected mount path: %s. got: %s", tc.deviceMountPath, gotPath) + } + } + } +} + +func TestAttacherUnmountDevice(t *testing.T) { + testCases := []struct { + testName string + volID string + deviceMountPath string + stageUnstageSet bool + shouldFail bool + }{ + { + testName: "normal", + volID: "project/zone/test-vol1", + deviceMountPath: "/tmp/csi-test049507108/plugins/csi/pv/test-pv-name/globalmount", + stageUnstageSet: true, + }, + { + testName: "no device mount path", + volID: "project/zone/test-vol1", + deviceMountPath: "", + stageUnstageSet: true, + shouldFail: true, + }, + { + testName: "missing part of device mount path", + volID: "project/zone/test-vol1", + deviceMountPath: "/tmp/csi-test049507108/plugins/csi/pv/test-pv-name/globalmount", + stageUnstageSet: true, + shouldFail: true, + }, + { + testName: "test volume name mismatch", + volID: "project/zone/test-vol1", + deviceMountPath: "/tmp/csi-test049507108/plugins/csi/pv/test-pv-name/globalmount", + stageUnstageSet: true, + shouldFail: true, + }, + { + testName: "stage_unstage not set", + volID: "project/zone/test-vol1", + deviceMountPath: "/tmp/csi-test049507108/plugins/csi/pv/test-pv-name/globalmount", + stageUnstageSet: false, + }, + { + testName: "stage_unstage not set no vars should not fail", + stageUnstageSet: false, + }, + } + + for _, tc := range testCases { + t.Logf("Running test case: %s", tc.testName) + // Setup + // Create a new attacher + plug, _, tmpDir := newTestWatchPlugin(t) + defer os.RemoveAll(tmpDir) + attacher, err0 := plug.NewAttacher() + if err0 != nil { + t.Fatalf("failed to create new attacher: %v", err0) + } + csiAttacher := attacher.(*csiAttacher) + csiAttacher.csiClient = setupClient(t, tc.stageUnstageSet) + + // Add the volume to NodeStagedVolumes + cdc := csiAttacher.csiClient.(*csiDriverClient) + cdc.nodeClient.(*fake.NodeClient).AddNodeStagedVolume(tc.volID, tc.deviceMountPath) + + // Make the PV for this object + dir := filepath.Dir(tc.deviceMountPath) + // dir is now /var/lib/kubelet/plugins/kubernetes.io/csi/pv/{pvname} + pvName := filepath.Base(dir) + pv := makeTestPV(pvName, 5, "csi", tc.volID) + _, err := csiAttacher.k8s.CoreV1().PersistentVolumes().Create(pv) + if err != nil && !tc.shouldFail { + t.Fatalf("Failed to create PV: %v", err) + } + + // Run + err = csiAttacher.UnmountDevice(tc.deviceMountPath) + + // Verify + if err != nil { + if !tc.shouldFail { + t.Errorf("test should not fail, but error occurred: %v", err) + } + return + } + if err == nil && tc.shouldFail { + t.Errorf("test should fail, but no error occurred") + } + + // Verify call goes through all the way + expectedSet := 0 + if !tc.stageUnstageSet { + expectedSet = 1 + } + staged := cdc.nodeClient.(*fake.NodeClient).GetNodeStagedVolumes() + if len(staged) != expectedSet { + t.Errorf("got wrong number of staged volumes, expecting %v got: %v", expectedSet, len(staged)) + } + + _, ok := staged[tc.volID] + if ok && tc.stageUnstageSet { + t.Errorf("found unexpected staged volume: %s", tc.volID) + } else if !ok && !tc.stageUnstageSet { + t.Errorf("could not find expected staged volume: %s", tc.volID) + } + + } +} + // create a plugin mgr to load plugins and setup a fake client func newTestWatchPlugin(t *testing.T) (*csiPlugin, *watch.FakeWatcher, string) { tmpDir, err := utiltesting.MkTmpdir("csi-test") diff --git a/pkg/volume/csi/csi_client.go b/pkg/volume/csi/csi_client.go index 5dfc83482ca..4ec6f9575b7 100644 --- a/pkg/volume/csi/csi_client.go +++ b/pkg/volume/csi/csi_client.go @@ -33,6 +33,7 @@ type csiClient interface { ctx grpctx.Context, volumeid string, readOnly bool, + stagingTargetPath string, targetPath string, accessMode api.PersistentVolumeAccessMode, volumeInfo map[string]string, @@ -45,6 +46,17 @@ type csiClient interface { volID string, targetPath string, ) error + NodeStageVolume(ctx grpctx.Context, + volID string, + publishVolumeInfo map[string]string, + stagingTargetPath string, + fsType string, + accessMode api.PersistentVolumeAccessMode, + nodeStageSecrets map[string]string, + volumeAttribs map[string]string, + ) error + NodeUnstageVolume(ctx grpctx.Context, volID, stagingTargetPath string) error + NodeGetCapabilities(ctx grpctx.Context) ([]*csipb.NodeServiceCapability, error) } // csiClient encapsulates all csi-plugin methods @@ -94,6 +106,7 @@ func (c *csiDriverClient) NodePublishVolume( ctx grpctx.Context, volID string, readOnly bool, + stagingTargetPath string, targetPath string, accessMode api.PersistentVolumeAccessMode, volumeInfo map[string]string, @@ -131,6 +144,9 @@ func (c *csiDriverClient) NodePublishVolume( }, }, } + if stagingTargetPath != "" { + req.StagingTargetPath = stagingTargetPath + } _, err := c.nodeClient.NodePublishVolume(ctx, req) return err @@ -158,6 +174,84 @@ func (c *csiDriverClient) NodeUnpublishVolume(ctx grpctx.Context, volID string, return err } +func (c *csiDriverClient) NodeStageVolume(ctx grpctx.Context, + volID string, + publishInfo map[string]string, + stagingTargetPath string, + fsType string, + accessMode api.PersistentVolumeAccessMode, + nodeStageSecrets map[string]string, + volumeAttribs map[string]string, +) error { + glog.V(4).Info(log("calling NodeStageVolume rpc [volid=%s,staging_target_path=%s]", volID, stagingTargetPath)) + if volID == "" { + return errors.New("missing volume id") + } + if stagingTargetPath == "" { + return errors.New("missing staging target path") + } + if err := c.assertConnection(); err != nil { + glog.Errorf("%v: failed to assert a connection: %v", csiPluginName, err) + return err + } + + req := &csipb.NodeStageVolumeRequest{ + VolumeId: volID, + PublishInfo: publishInfo, + StagingTargetPath: stagingTargetPath, + VolumeCapability: &csipb.VolumeCapability{ + AccessMode: &csipb.VolumeCapability_AccessMode{ + Mode: asCSIAccessMode(accessMode), + }, + AccessType: &csipb.VolumeCapability_Mount{ + Mount: &csipb.VolumeCapability_MountVolume{ + FsType: fsType, + }, + }, + }, + NodeStageSecrets: nodeStageSecrets, + VolumeAttributes: volumeAttribs, + } + + _, err := c.nodeClient.NodeStageVolume(ctx, req) + return err +} + +func (c *csiDriverClient) NodeUnstageVolume(ctx grpctx.Context, volID, stagingTargetPath string) error { + glog.V(4).Info(log("calling NodeUnstageVolume rpc [volid=%s,staging_target_path=%s]", volID, stagingTargetPath)) + if volID == "" { + return errors.New("missing volume id") + } + if stagingTargetPath == "" { + return errors.New("missing staging target path") + } + if err := c.assertConnection(); err != nil { + glog.Errorf("%v: failed to assert a connection: %v", csiPluginName, err) + return err + } + + req := &csipb.NodeUnstageVolumeRequest{ + VolumeId: volID, + StagingTargetPath: stagingTargetPath, + } + _, err := c.nodeClient.NodeUnstageVolume(ctx, req) + return err +} + +func (c *csiDriverClient) NodeGetCapabilities(ctx grpctx.Context) ([]*csipb.NodeServiceCapability, error) { + glog.V(4).Info(log("calling NodeGetCapabilities rpc")) + if err := c.assertConnection(); err != nil { + glog.Errorf("%v: failed to assert a connection: %v", csiPluginName, err) + return nil, err + } + req := &csipb.NodeGetCapabilitiesRequest{} + resp, err := c.nodeClient.NodeGetCapabilities(ctx, req) + if err != nil { + return nil, err + } + return resp.GetCapabilities(), nil +} + func asCSIAccessMode(am api.PersistentVolumeAccessMode) csipb.VolumeCapability_AccessMode_Mode { switch am { case api.ReadWriteOnce: diff --git a/pkg/volume/csi/csi_client_test.go b/pkg/volume/csi/csi_client_test.go index 4a53ebc5290..753396add36 100644 --- a/pkg/volume/csi/csi_client_test.go +++ b/pkg/volume/csi/csi_client_test.go @@ -26,13 +26,13 @@ import ( "k8s.io/kubernetes/pkg/volume/csi/fake" ) -func setupClient(t *testing.T) *csiDriverClient { +func setupClient(t *testing.T, stageUnstageSet bool) *csiDriverClient { client := newCsiDriverClient("unix", "/tmp/test.sock") client.conn = new(grpc.ClientConn) //avoids creating conn object // setup mock grpc clients client.idClient = fake.NewIdentityClient() - client.nodeClient = fake.NewNodeClient() + client.nodeClient = fake.NewNodeClient(stageUnstageSet) client.ctrlClient = fake.NewControllerClient() return client @@ -54,7 +54,7 @@ func TestClientNodePublishVolume(t *testing.T) { {name: "grpc error", volID: "vol-test", targetPath: "/test/path", mustFail: true, err: errors.New("grpc error")}, } - client := setupClient(t) + client := setupClient(t, false) for _, tc := range testCases { t.Logf("test case: %s", tc.name) @@ -63,6 +63,7 @@ func TestClientNodePublishVolume(t *testing.T) { grpctx.Background(), tc.volID, false, + "", tc.targetPath, api.ReadWriteOnce, map[string]string{"device": "/dev/null"}, @@ -91,7 +92,7 @@ func TestClientNodeUnpublishVolume(t *testing.T) { {name: "grpc error", volID: "vol-test", targetPath: "/test/path", mustFail: true, err: errors.New("grpc error")}, } - client := setupClient(t) + client := setupClient(t, false) for _, tc := range testCases { t.Logf("test case: %s", tc.name) @@ -102,3 +103,71 @@ func TestClientNodeUnpublishVolume(t *testing.T) { } } } + +func TestClientNodeStageVolume(t *testing.T) { + testCases := []struct { + name string + volID string + stagingTargetPath string + fsType string + secret map[string]string + mustFail bool + err error + }{ + {name: "test ok", volID: "vol-test", stagingTargetPath: "/test/path", fsType: "ext4"}, + {name: "missing volID", stagingTargetPath: "/test/path", mustFail: true}, + {name: "missing target path", volID: "vol-test", mustFail: true}, + {name: "bad fs", volID: "vol-test", stagingTargetPath: "/test/path", fsType: "badfs", mustFail: true}, + {name: "grpc error", volID: "vol-test", stagingTargetPath: "/test/path", mustFail: true, err: errors.New("grpc error")}, + } + + client := setupClient(t, false) + + for _, tc := range testCases { + t.Logf("Running test case: %s", tc.name) + client.nodeClient.(*fake.NodeClient).SetNextError(tc.err) + err := client.NodeStageVolume( + grpctx.Background(), + tc.volID, + map[string]string{"device": "/dev/null"}, + tc.stagingTargetPath, + tc.fsType, + api.ReadWriteOnce, + tc.secret, + map[string]string{"attr0": "val0"}, + ) + + if tc.mustFail && err == nil { + t.Error("test must fail, but err is nil") + } + } +} + +func TestClientNodeUnstageVolume(t *testing.T) { + testCases := []struct { + name string + volID string + stagingTargetPath string + mustFail bool + err error + }{ + {name: "test ok", volID: "vol-test", stagingTargetPath: "/test/path"}, + {name: "missing volID", stagingTargetPath: "/test/path", mustFail: true}, + {name: "missing target path", volID: "vol-test", mustFail: true}, + {name: "grpc error", volID: "vol-test", stagingTargetPath: "/test/path", mustFail: true, err: errors.New("grpc error")}, + } + + client := setupClient(t, false) + + for _, tc := range testCases { + t.Logf("Running test case: %s", tc.name) + client.nodeClient.(*fake.NodeClient).SetNextError(tc.err) + err := client.NodeUnstageVolume( + grpctx.Background(), + tc.volID, tc.stagingTargetPath, + ) + if tc.mustFail && err == nil { + t.Error("test must fail, but err is nil") + } + } +} diff --git a/pkg/volume/csi/csi_mounter.go b/pkg/volume/csi/csi_mounter.go index 576ef09e787..213d0b58200 100644 --- a/pkg/volume/csi/csi_mounter.go +++ b/pkg/volume/csi/csi_mounter.go @@ -114,13 +114,28 @@ func (c *csiMountMgr) SetUpAt(dir string, fsGroup *int64) error { return err } - ctx, cancel := grpctx.WithTimeout(grpctx.Background(), csiTimeout) - defer cancel() - csi := c.csiClient nodeName := string(c.plugin.host.GetNodeName()) attachID := getAttachmentName(csiSource.VolumeHandle, csiSource.Driver, nodeName) + ctx, cancel := grpctx.WithTimeout(grpctx.Background(), csiTimeout) + defer cancel() + // Check for STAGE_UNSTAGE_VOLUME set and populate deviceMountPath if so + deviceMountPath := "" + stageUnstageSet, err := hasStageUnstageCapability(ctx, csi) + if err != nil { + glog.Error(log("mounter.SetUpAt failed to check for STAGE_UNSTAGE_VOLUME capabilty: %v", err)) + return err + } + + if stageUnstageSet { + deviceMountPath, err = makeDeviceMountPath(c.plugin, c.spec) + if err != nil { + glog.Error(log("mounter.SetUpAt failed to make device mount path: %v", err)) + return err + } + } + // search for attachment by VolumeAttachment.Spec.Source.PersistentVolumeName if c.volumeInfo == nil { attachment, err := c.k8s.StorageV1beta1().VolumeAttachments().Get(attachID, meta.GetOptions{}) @@ -181,6 +196,7 @@ func (c *csiMountMgr) SetUpAt(dir string, fsGroup *int64) error { ctx, c.volumeID, c.readOnly, + deviceMountPath, dir, accessMode, c.volumeInfo, diff --git a/pkg/volume/csi/csi_mounter_test.go b/pkg/volume/csi/csi_mounter_test.go index 90555063dd6..53ad8ba8ecc 100644 --- a/pkg/volume/csi/csi_mounter_test.go +++ b/pkg/volume/csi/csi_mounter_test.go @@ -114,7 +114,7 @@ func TestMounterSetUp(t *testing.T) { } csiMounter := mounter.(*csiMountMgr) - csiMounter.csiClient = setupClient(t) + csiMounter.csiClient = setupClient(t, false) attachID := getAttachmentName(csiMounter.volumeID, csiMounter.driverName, string(plug.host.GetNodeName())) @@ -172,7 +172,7 @@ func TestUnmounterTeardown(t *testing.T) { } csiUnmounter := unmounter.(*csiMountMgr) - csiUnmounter.csiClient = setupClient(t) + csiUnmounter.csiClient = setupClient(t, false) dir := csiUnmounter.GetPath() diff --git a/pkg/volume/csi/csi_plugin_test.go b/pkg/volume/csi/csi_plugin_test.go index 32a0df32912..66151e0ac7d 100644 --- a/pkg/volume/csi/csi_plugin_test.go +++ b/pkg/volume/csi/csi_plugin_test.go @@ -64,8 +64,7 @@ func newTestPlugin(t *testing.T) (*csiPlugin, string) { func makeTestPV(name string, sizeGig int, driverName, volID string) *api.PersistentVolume { return &api.PersistentVolume{ ObjectMeta: meta.ObjectMeta{ - Name: name, - Namespace: testns, + Name: name, }, Spec: api.PersistentVolumeSpec{ AccessModes: []api.PersistentVolumeAccessMode{api.ReadWriteOnce}, diff --git a/pkg/volume/csi/fake/fake_client.go b/pkg/volume/csi/fake/fake_client.go index 9fe28926e02..ce95ff94736 100644 --- a/pkg/volume/csi/fake/fake_client.go +++ b/pkg/volume/csi/fake/fake_client.go @@ -60,12 +60,18 @@ func (f *IdentityClient) Probe(ctx context.Context, in *csipb.ProbeRequest, opts // NodeClient returns CSI node client type NodeClient struct { nodePublishedVolumes map[string]string + nodeStagedVolumes map[string]string + stageUnstageSet bool nextErr error } // NewNodeClient returns fake node client -func NewNodeClient() *NodeClient { - return &NodeClient{nodePublishedVolumes: make(map[string]string)} +func NewNodeClient(stageUnstageSet bool) *NodeClient { + return &NodeClient{ + nodePublishedVolumes: make(map[string]string), + nodeStagedVolumes: make(map[string]string), + stageUnstageSet: stageUnstageSet, + } } // SetNextError injects next expected error @@ -78,6 +84,15 @@ func (f *NodeClient) GetNodePublishedVolumes() map[string]string { return f.nodePublishedVolumes } +// GetNodeStagedVolumes returns node staged volumes +func (f *NodeClient) GetNodeStagedVolumes() map[string]string { + return f.nodeStagedVolumes +} + +func (f *NodeClient) AddNodeStagedVolume(volID, deviceMountPath string) { + f.nodeStagedVolumes[volID] = deviceMountPath +} + // NodePublishVolume implements CSI NodePublishVolume func (f *NodeClient) NodePublishVolume(ctx grpctx.Context, req *csipb.NodePublishVolumeRequest, opts ...grpc.CallOption) (*csipb.NodePublishVolumeResponse, error) { @@ -116,6 +131,50 @@ func (f *NodeClient) NodeUnpublishVolume(ctx context.Context, req *csipb.NodeUnp return &csipb.NodeUnpublishVolumeResponse{}, nil } +// NodeStagevolume implements csi method +func (f *NodeClient) NodeStageVolume(ctx context.Context, req *csipb.NodeStageVolumeRequest, opts ...grpc.CallOption) (*csipb.NodeStageVolumeResponse, error) { + if f.nextErr != nil { + return nil, f.nextErr + } + + if req.GetVolumeId() == "" { + return nil, errors.New("missing volume id") + } + if req.GetStagingTargetPath() == "" { + return nil, errors.New("missing staging target path") + } + + fsType := "" + fsTypes := "ext4|xfs|zfs" + mounted := req.GetVolumeCapability().GetMount() + if mounted != nil { + fsType = mounted.GetFsType() + } + if !strings.Contains(fsTypes, fsType) { + return nil, errors.New("invalid fstype") + } + + f.nodeStagedVolumes[req.GetVolumeId()] = req.GetStagingTargetPath() + return &csipb.NodeStageVolumeResponse{}, nil +} + +// NodeUnstageVolume implements csi method +func (f *NodeClient) NodeUnstageVolume(ctx context.Context, req *csipb.NodeUnstageVolumeRequest, opts ...grpc.CallOption) (*csipb.NodeUnstageVolumeResponse, error) { + if f.nextErr != nil { + return nil, f.nextErr + } + + if req.GetVolumeId() == "" { + return nil, errors.New("missing volume id") + } + if req.GetStagingTargetPath() == "" { + return nil, errors.New("missing staging target path") + } + + delete(f.nodeStagedVolumes, req.GetVolumeId()) + return &csipb.NodeUnstageVolumeResponse{}, nil +} + // NodeGetId implements method func (f *NodeClient) NodeGetId(ctx context.Context, in *csipb.NodeGetIdRequest, opts ...grpc.CallOption) (*csipb.NodeGetIdResponse, error) { return nil, nil @@ -123,16 +182,20 @@ func (f *NodeClient) NodeGetId(ctx context.Context, in *csipb.NodeGetIdRequest, // NodeGetCapabilities implements csi method func (f *NodeClient) NodeGetCapabilities(ctx context.Context, in *csipb.NodeGetCapabilitiesRequest, opts ...grpc.CallOption) (*csipb.NodeGetCapabilitiesResponse, error) { - return nil, nil -} - -// NodeStageVolume implements csi method -func (f *NodeClient) NodeStageVolume(ctx context.Context, in *csipb.NodeStageVolumeRequest, opts ...grpc.CallOption) (*csipb.NodeStageVolumeResponse, error) { - return nil, nil -} - -// NodeUnstageVolume implements csi method -func (f *NodeClient) NodeUnstageVolume(ctx context.Context, in *csipb.NodeUnstageVolumeRequest, opts ...grpc.CallOption) (*csipb.NodeUnstageVolumeResponse, error) { + resp := &csipb.NodeGetCapabilitiesResponse{ + Capabilities: []*csipb.NodeServiceCapability{ + { + Type: &csipb.NodeServiceCapability_Rpc{ + Rpc: &csipb.NodeServiceCapability_RPC{ + Type: csipb.NodeServiceCapability_RPC_STAGE_UNSTAGE_VOLUME, + }, + }, + }, + }, + } + if f.stageUnstageSet { + return resp, nil + } return nil, nil }