diff --git a/pkg/volume/csi/csi_client.go b/pkg/volume/csi/csi_client.go index ce75b0bd589..6ca19a32d22 100644 --- a/pkg/volume/csi/csi_client.go +++ b/pkg/volume/csi/csi_client.go @@ -438,66 +438,12 @@ func (c *csiDriverClient) NodeUnstageVolume(ctx context.Context, volID, stagingT func (c *csiDriverClient) NodeSupportsNodeExpand(ctx context.Context) (bool, error) { klog.V(4).Info(log("calling NodeGetCapabilities rpc to determine if Node has EXPAND_VOLUME capability")) - if c.nodeV1ClientCreator == nil { - return false, errors.New("nodeV1ClientCreate is nil") - } - - nodeClient, closer, err := c.nodeV1ClientCreator(c.addr, c.metricsManager) - if err != nil { - return false, err - } - defer closer.Close() - - req := &csipbv1.NodeGetCapabilitiesRequest{} - resp, err := nodeClient.NodeGetCapabilities(ctx, req) - if err != nil { - return false, err - } - - capabilities := resp.GetCapabilities() - - if capabilities == nil { - return false, nil - } - for _, capability := range capabilities { - if capability.GetRpc().GetType() == csipbv1.NodeServiceCapability_RPC_EXPAND_VOLUME { - return true, nil - } - } - return false, nil + return c.nodeSupportsCapability(ctx, csipbv1.NodeServiceCapability_RPC_EXPAND_VOLUME) } func (c *csiDriverClient) NodeSupportsStageUnstage(ctx context.Context) (bool, error) { klog.V(4).Info(log("calling NodeGetCapabilities rpc to determine if NodeSupportsStageUnstage")) - if c.nodeV1ClientCreator == nil { - return false, errors.New("nodeV1ClientCreate is nil") - } - - nodeClient, closer, err := c.nodeV1ClientCreator(c.addr, c.metricsManager) - if err != nil { - return false, err - } - defer closer.Close() - - req := &csipbv1.NodeGetCapabilitiesRequest{} - resp, err := nodeClient.NodeGetCapabilities(ctx, req) - if err != nil { - return false, err - } - - capabilities := resp.GetCapabilities() - - stageUnstageSet := false - if capabilities == nil { - return false, nil - } - for _, capability := range capabilities { - if capability.GetRpc().GetType() == csipbv1.NodeServiceCapability_RPC_STAGE_UNSTAGE_VOLUME { - stageUnstageSet = true - break - } - } - return stageUnstageSet, nil + return c.nodeSupportsCapability(ctx, csipbv1.NodeServiceCapability_RPC_STAGE_UNSTAGE_VOLUME) } func asCSIAccessModeV1(am api.PersistentVolumeAccessMode) csipbv1.VolumeCapability_AccessMode_Mode { @@ -561,30 +507,7 @@ func (c *csiClientGetter) Get() (csiClient, error) { func (c *csiDriverClient) NodeSupportsVolumeStats(ctx context.Context) (bool, error) { klog.V(5).Info(log("calling NodeGetCapabilities rpc to determine if NodeSupportsVolumeStats")) - if c.nodeV1ClientCreator == nil { - return false, errors.New("nodeV1ClientCreate is nil") - } - - nodeClient, closer, err := c.nodeV1ClientCreator(c.addr, c.metricsManager) - if err != nil { - return false, err - } - defer closer.Close() - req := &csipbv1.NodeGetCapabilitiesRequest{} - resp, err := nodeClient.NodeGetCapabilities(ctx, req) - if err != nil { - return false, err - } - capabilities := resp.GetCapabilities() - if capabilities == nil { - return false, nil - } - for _, capability := range capabilities { - if capability.GetRpc().GetType() == csipbv1.NodeServiceCapability_RPC_GET_VOLUME_STATS { - return true, nil - } - } - return false, nil + return c.nodeSupportsCapability(ctx, csipbv1.NodeServiceCapability_RPC_GET_VOLUME_STATS) } func (c *csiDriverClient) NodeGetVolumeStats(ctx context.Context, volID string, targetPath string) (*volume.Metrics, error) { @@ -628,7 +551,7 @@ func (c *csiDriverClient) NodeGetVolumeStats(ctx context.Context, volID string, } if utilfeature.DefaultFeatureGate.Enabled(features.CSIVolumeHealth) { - isSupportNodeVolumeCondition, err := supportNodeGetVolumeCondition(ctx, nodeClient) + isSupportNodeVolumeCondition, err := c.nodeSupportsVolumeCondition(ctx) if err != nil { return nil, err } @@ -661,30 +584,47 @@ func (c *csiDriverClient) NodeGetVolumeStats(ctx context.Context, volID string, return metrics, nil } -func supportNodeGetVolumeCondition(ctx context.Context, nodeClient csipbv1.NodeClient) (supportNodeGetVolumeCondition bool, err error) { - req := csipbv1.NodeGetCapabilitiesRequest{} - rsp, err := nodeClient.NodeGetCapabilities(ctx, &req) +func (c *csiDriverClient) nodeSupportsVolumeCondition(ctx context.Context) (bool, error) { + klog.V(5).Info(log("calling NodeGetCapabilities rpc to determine if nodeSupportsVolumeCondition")) + return c.nodeSupportsCapability(ctx, csipbv1.NodeServiceCapability_RPC_VOLUME_CONDITION) +} + +func (c *csiDriverClient) nodeSupportsCapability(ctx context.Context, capabilityType csipbv1.NodeServiceCapability_RPC_Type) (bool, error) { + capabilities, err := c.nodeGetCapabilities(ctx) if err != nil { return false, err } - for _, cap := range rsp.GetCapabilities() { - if cap == nil { + for _, capability := range capabilities { + if capability == nil || capability.GetRpc() == nil { continue } - rpc := cap.GetRpc() - if rpc == nil { - continue - } - t := rpc.GetType() - if t == csipbv1.NodeServiceCapability_RPC_VOLUME_CONDITION { + if capability.GetRpc().GetType() == capabilityType { return true, nil } } - return false, nil } +func (c *csiDriverClient) nodeGetCapabilities(ctx context.Context) ([]*csipbv1.NodeServiceCapability, error) { + if c.nodeV1ClientCreator == nil { + return []*csipbv1.NodeServiceCapability{}, errors.New("nodeV1ClientCreate is nil") + } + + nodeClient, closer, err := c.nodeV1ClientCreator(c.addr, c.metricsManager) + if err != nil { + return []*csipbv1.NodeServiceCapability{}, err + } + defer closer.Close() + + req := &csipbv1.NodeGetCapabilitiesRequest{} + resp, err := nodeClient.NodeGetCapabilities(ctx, req) + if err != nil { + return []*csipbv1.NodeServiceCapability{}, err + } + return resp.GetCapabilities(), nil +} + func isFinalError(err error) bool { // Sources: // https://github.com/grpc/grpc/blob/master/doc/statuscodes.md diff --git a/pkg/volume/csi/csi_client_test.go b/pkg/volume/csi/csi_client_test.go index 53238bedacf..c7ca8fa05c4 100644 --- a/pkg/volume/csi/csi_client_test.go +++ b/pkg/volume/csi/csi_client_test.go @@ -106,7 +106,7 @@ func (c *fakeCsiDriverClient) NodeGetVolumeStats(ctx context.Context, volID stri metrics := &volume.Metrics{} - isSupportNodeVolumeCondition, err := supportNodeGetVolumeCondition(ctx, c.nodeClient) + isSupportNodeVolumeCondition, err := c.nodeSupportsVolumeCondition(ctx) if err != nil { return nil, err } @@ -137,21 +137,7 @@ func (c *fakeCsiDriverClient) NodeGetVolumeStats(ctx context.Context, volID stri func (c *fakeCsiDriverClient) NodeSupportsVolumeStats(ctx context.Context) (bool, error) { c.t.Log("calling fake.NodeSupportsVolumeStats...") - req := &csipbv1.NodeGetCapabilitiesRequest{} - resp, err := c.nodeClient.NodeGetCapabilities(ctx, req) - if err != nil { - return false, err - } - capabilities := resp.GetCapabilities() - if capabilities == nil { - return false, nil - } - for _, capability := range capabilities { - if capability.GetRpc().GetType() == csipbv1.NodeServiceCapability_RPC_GET_VOLUME_STATS { - return true, nil - } - } - return false, nil + return c.nodeSupportsCapability(ctx, csipbv1.NodeServiceCapability_RPC_GET_VOLUME_STATS) } func (c *fakeCsiDriverClient) NodePublishVolume( @@ -269,46 +255,12 @@ func (c *fakeCsiDriverClient) NodeUnstageVolume(ctx context.Context, volID, stag func (c *fakeCsiDriverClient) NodeSupportsNodeExpand(ctx context.Context) (bool, error) { c.t.Log("calling fake.NodeSupportsNodeExpand...") - req := &csipbv1.NodeGetCapabilitiesRequest{} - - resp, err := c.nodeClient.NodeGetCapabilities(ctx, req) - if err != nil { - return false, err - } - - capabilities := resp.GetCapabilities() - - if capabilities == nil { - return false, nil - } - for _, capability := range capabilities { - if capability.GetRpc().GetType() == csipbv1.NodeServiceCapability_RPC_EXPAND_VOLUME { - return true, nil - } - } - return false, nil + return c.nodeSupportsCapability(ctx, csipbv1.NodeServiceCapability_RPC_EXPAND_VOLUME) } func (c *fakeCsiDriverClient) NodeSupportsStageUnstage(ctx context.Context) (bool, error) { c.t.Log("calling fake.NodeGetCapabilities for NodeSupportsStageUnstage...") - req := &csipbv1.NodeGetCapabilitiesRequest{} - resp, err := c.nodeClient.NodeGetCapabilities(ctx, req) - if err != nil { - return false, err - } - - capabilities := resp.GetCapabilities() - - stageUnstageSet := false - if capabilities == nil { - return false, nil - } - for _, capability := range capabilities { - if capability.GetRpc().GetType() == csipbv1.NodeServiceCapability_RPC_STAGE_UNSTAGE_VOLUME { - stageUnstageSet = true - } - } - return stageUnstageSet, nil + return c.nodeSupportsCapability(ctx, csipbv1.NodeServiceCapability_RPC_STAGE_UNSTAGE_VOLUME) } func (c *fakeCsiDriverClient) NodeExpandVolume(ctx context.Context, opts csiResizeOptions) (resource.Quantity, error) { @@ -344,6 +296,34 @@ func (c *fakeCsiDriverClient) NodeExpandVolume(ctx context.Context, opts csiResi return *updatedQuantity, nil } +func (c *fakeCsiDriverClient) nodeSupportsVolumeCondition(ctx context.Context) (bool, error) { + c.t.Log("calling fake.nodeSupportsVolumeCondition...") + return c.nodeSupportsCapability(ctx, csipbv1.NodeServiceCapability_RPC_VOLUME_CONDITION) +} + +func (c *fakeCsiDriverClient) nodeSupportsCapability(ctx context.Context, capabilityType csipbv1.NodeServiceCapability_RPC_Type) (bool, error) { + capabilities, err := c.nodeGetCapabilities(ctx) + if err != nil { + return false, err + } + + for _, capability := range capabilities { + if capability.GetRpc().GetType() == capabilityType { + return true, nil + } + } + return false, nil +} + +func (c *fakeCsiDriverClient) nodeGetCapabilities(ctx context.Context) ([]*csipbv1.NodeServiceCapability, error) { + req := &csipbv1.NodeGetCapabilitiesRequest{} + resp, err := c.nodeClient.NodeGetCapabilities(ctx, req) + if err != nil { + return []*csipbv1.NodeServiceCapability{}, err + } + return resp.GetCapabilities(), nil +} + func setupClient(t *testing.T, stageUnstageSet bool) csiClient { return newFakeCsiDriverClient(t, stageUnstageSet) }