Combine capability check implementations

This commit is contained in:
Cheng Xing 2021-06-23 16:29:46 -07:00
parent 99700f7faf
commit 65db13a3a5
2 changed files with 66 additions and 5 deletions

View File

@ -454,12 +454,10 @@ func (c *csiDriverClient) NodeUnstageVolume(ctx context.Context, volID, stagingT
}
func (c *csiDriverClient) NodeSupportsNodeExpand(ctx context.Context) (bool, error) {
klog.V(4).Info(log("calling NodeGetCapabilities rpc to determine if Node has EXPAND_VOLUME capability"))
return c.nodeSupportsCapability(ctx, csipbv1.NodeServiceCapability_RPC_EXPAND_VOLUME)
}
func (c *csiDriverClient) NodeSupportsStageUnstage(ctx context.Context) (bool, error) {
klog.V(4).Info(log("calling NodeGetCapabilities rpc to determine if NodeSupportsStageUnstage"))
return c.nodeSupportsCapability(ctx, csipbv1.NodeServiceCapability_RPC_STAGE_UNSTAGE_VOLUME)
}
@ -553,12 +551,10 @@ func (c *csiClientGetter) Get() (csiClient, error) {
}
func (c *csiDriverClient) NodeSupportsVolumeStats(ctx context.Context) (bool, error) {
klog.V(5).Info(log("calling NodeGetCapabilities rpc to determine if NodeSupportsVolumeStats"))
return c.nodeSupportsCapability(ctx, csipbv1.NodeServiceCapability_RPC_GET_VOLUME_STATS)
}
func (c *csiDriverClient) NodeSupportsSingleNodeMultiWriterAccessMode(ctx context.Context) (bool, error) {
klog.V(4).Info(log("calling NodeGetCapabilities rpc to determine if NodeSupportsSingleNodeMultiWriterAccessMode"))
return c.nodeSupportsCapability(ctx, csipbv1.NodeServiceCapability_RPC_SINGLE_NODE_MULTI_WRITER)
}
@ -637,11 +633,11 @@ func (c *csiDriverClient) NodeGetVolumeStats(ctx context.Context, volID string,
}
func (c *csiDriverClient) nodeSupportsVolumeCondition(ctx context.Context) (bool, error) {
klog.V(5).Info(log("calling NodeGetCapabilities rpc to determine if nodeSupportsVolumeCondition"))
return c.nodeSupportsCapability(ctx, csipbv1.NodeServiceCapability_RPC_VOLUME_CONDITION)
}
func (c *csiDriverClient) nodeSupportsCapability(ctx context.Context, capabilityType csipbv1.NodeServiceCapability_RPC_Type) (bool, error) {
klog.V(4).Info(log("calling NodeGetCapabilities rpc to determine if the node service has %s capability", capabilityType))
capabilities, err := c.nodeGetCapabilities(ctx)
if err != nil {
return false, err

View File

@ -621,6 +621,71 @@ func TestClientNodeUnstageVolume(t *testing.T) {
}
}
func TestClientNodeSupportsStageUnstage(t *testing.T) {
testClientNodeSupportsCapabilities(t,
func(client *csiDriverClient) (bool, error) {
return client.NodeSupportsStageUnstage(context.Background())
},
func(stagingCapable bool) *fake.NodeClient {
// Creates a staging-capable client
return fake.NewNodeClient(stagingCapable)
})
}
func TestClientNodeSupportsNodeExpand(t *testing.T) {
testClientNodeSupportsCapabilities(t,
func(client *csiDriverClient) (bool, error) {
return client.NodeSupportsNodeExpand(context.Background())
},
func(expansionCapable bool) *fake.NodeClient {
return fake.NewNodeClientWithExpansion(false /* stageCapable */, expansionCapable)
})
}
func TestClientNodeSupportsVolumeStats(t *testing.T) {
testClientNodeSupportsCapabilities(t,
func(client *csiDriverClient) (bool, error) {
return client.NodeSupportsVolumeStats(context.Background())
},
func(volumeStatsCapable bool) *fake.NodeClient {
return fake.NewNodeClientWithVolumeStats(volumeStatsCapable)
})
}
func testClientNodeSupportsCapabilities(
t *testing.T,
capabilityMethodToTest func(*csiDriverClient) (bool, error),
nodeClientGenerator func(bool) *fake.NodeClient) {
testCases := []struct {
name string
capable bool
}{
{name: "positive", capable: true},
{name: "negative", capable: false},
}
for _, tc := range testCases {
t.Logf("Running test case: %s", tc.name)
fakeCloser := fake.NewCloser(t)
client := &csiDriverClient{
driverName: "Fake Driver Name",
nodeV1ClientCreator: func(addr csiAddr, m *MetricsManager) (csipbv1.NodeClient, io.Closer, error) {
nodeClient := nodeClientGenerator(tc.capable)
return nodeClient, fakeCloser, nil
},
}
got, _ := capabilityMethodToTest(client)
if got != tc.capable {
t.Errorf("Expected capability support to be %v, got: %v", tc.capable, got)
}
fakeCloser.Check()
}
}
func TestNodeExpandVolume(t *testing.T) {
testCases := []struct {
name string