From 65db13a3a5f4fa5141839f374785b277f9e98df8 Mon Sep 17 00:00:00 2001 From: Cheng Xing Date: Wed, 23 Jun 2021 16:29:46 -0700 Subject: [PATCH] Combine capability check implementations --- pkg/volume/csi/csi_client.go | 6 +-- pkg/volume/csi/csi_client_test.go | 65 +++++++++++++++++++++++++++++++ 2 files changed, 66 insertions(+), 5 deletions(-) diff --git a/pkg/volume/csi/csi_client.go b/pkg/volume/csi/csi_client.go index 3df67888013..10c949193a0 100644 --- a/pkg/volume/csi/csi_client.go +++ b/pkg/volume/csi/csi_client.go @@ -454,12 +454,10 @@ func (c *csiDriverClient) NodeUnstageVolume(ctx context.Context, volID, stagingT } func (c *csiDriverClient) NodeSupportsNodeExpand(ctx context.Context) (bool, error) { - klog.V(4).Info(log("calling NodeGetCapabilities rpc to determine if Node has EXPAND_VOLUME capability")) return c.nodeSupportsCapability(ctx, csipbv1.NodeServiceCapability_RPC_EXPAND_VOLUME) } func (c *csiDriverClient) NodeSupportsStageUnstage(ctx context.Context) (bool, error) { - klog.V(4).Info(log("calling NodeGetCapabilities rpc to determine if NodeSupportsStageUnstage")) return c.nodeSupportsCapability(ctx, csipbv1.NodeServiceCapability_RPC_STAGE_UNSTAGE_VOLUME) } @@ -553,12 +551,10 @@ func (c *csiClientGetter) Get() (csiClient, error) { } func (c *csiDriverClient) NodeSupportsVolumeStats(ctx context.Context) (bool, error) { - klog.V(5).Info(log("calling NodeGetCapabilities rpc to determine if NodeSupportsVolumeStats")) return c.nodeSupportsCapability(ctx, csipbv1.NodeServiceCapability_RPC_GET_VOLUME_STATS) } func (c *csiDriverClient) NodeSupportsSingleNodeMultiWriterAccessMode(ctx context.Context) (bool, error) { - klog.V(4).Info(log("calling NodeGetCapabilities rpc to determine if NodeSupportsSingleNodeMultiWriterAccessMode")) return c.nodeSupportsCapability(ctx, csipbv1.NodeServiceCapability_RPC_SINGLE_NODE_MULTI_WRITER) } @@ -637,11 +633,11 @@ func (c *csiDriverClient) NodeGetVolumeStats(ctx context.Context, volID string, } func (c *csiDriverClient) nodeSupportsVolumeCondition(ctx context.Context) (bool, error) { - klog.V(5).Info(log("calling NodeGetCapabilities rpc to determine if nodeSupportsVolumeCondition")) return c.nodeSupportsCapability(ctx, csipbv1.NodeServiceCapability_RPC_VOLUME_CONDITION) } func (c *csiDriverClient) nodeSupportsCapability(ctx context.Context, capabilityType csipbv1.NodeServiceCapability_RPC_Type) (bool, error) { + klog.V(4).Info(log("calling NodeGetCapabilities rpc to determine if the node service has %s capability", capabilityType)) capabilities, err := c.nodeGetCapabilities(ctx) if err != nil { return false, err diff --git a/pkg/volume/csi/csi_client_test.go b/pkg/volume/csi/csi_client_test.go index 5954090285d..0b4b4042097 100644 --- a/pkg/volume/csi/csi_client_test.go +++ b/pkg/volume/csi/csi_client_test.go @@ -621,6 +621,71 @@ func TestClientNodeUnstageVolume(t *testing.T) { } } +func TestClientNodeSupportsStageUnstage(t *testing.T) { + testClientNodeSupportsCapabilities(t, + func(client *csiDriverClient) (bool, error) { + return client.NodeSupportsStageUnstage(context.Background()) + }, + func(stagingCapable bool) *fake.NodeClient { + // Creates a staging-capable client + return fake.NewNodeClient(stagingCapable) + }) +} + +func TestClientNodeSupportsNodeExpand(t *testing.T) { + testClientNodeSupportsCapabilities(t, + func(client *csiDriverClient) (bool, error) { + return client.NodeSupportsNodeExpand(context.Background()) + }, + func(expansionCapable bool) *fake.NodeClient { + return fake.NewNodeClientWithExpansion(false /* stageCapable */, expansionCapable) + }) +} + +func TestClientNodeSupportsVolumeStats(t *testing.T) { + testClientNodeSupportsCapabilities(t, + func(client *csiDriverClient) (bool, error) { + return client.NodeSupportsVolumeStats(context.Background()) + }, + func(volumeStatsCapable bool) *fake.NodeClient { + return fake.NewNodeClientWithVolumeStats(volumeStatsCapable) + }) +} + +func testClientNodeSupportsCapabilities( + t *testing.T, + capabilityMethodToTest func(*csiDriverClient) (bool, error), + nodeClientGenerator func(bool) *fake.NodeClient) { + + testCases := []struct { + name string + capable bool + }{ + {name: "positive", capable: true}, + {name: "negative", capable: false}, + } + + for _, tc := range testCases { + t.Logf("Running test case: %s", tc.name) + fakeCloser := fake.NewCloser(t) + client := &csiDriverClient{ + driverName: "Fake Driver Name", + nodeV1ClientCreator: func(addr csiAddr, m *MetricsManager) (csipbv1.NodeClient, io.Closer, error) { + nodeClient := nodeClientGenerator(tc.capable) + return nodeClient, fakeCloser, nil + }, + } + + got, _ := capabilityMethodToTest(client) + + if got != tc.capable { + t.Errorf("Expected capability support to be %v, got: %v", tc.capable, got) + } + + fakeCloser.Check() + } +} + func TestNodeExpandVolume(t *testing.T) { testCases := []struct { name string