diff --git a/pkg/volume/csi/csi_client.go b/pkg/volume/csi/csi_client.go index c99996cabfe..4fb34ff0743 100644 --- a/pkg/volume/csi/csi_client.go +++ b/pkg/volume/csi/csi_client.go @@ -22,7 +22,6 @@ import ( "fmt" "io" "net" - "strings" "sync" "time" @@ -161,9 +160,8 @@ func (c *csiDriverClient) NodeGetInfo(ctx context.Context) ( if nodeID != "" { return true, nil } - // kubelet plugin registration service not implemented is a terminal error, no need to retry - if strings.Contains(getNodeInfoError.Error(), "no handler registered for plugin type") { - return false, getNodeInfoError + if getNodeInfoError != nil { + klog.Warningf("Error calling CSI NodeGetInfo(): %v", getNodeInfoError.Error()) } // Continue with exponential backoff return false, nil diff --git a/pkg/volume/csi/csi_client_test.go b/pkg/volume/csi/csi_client_test.go index 83080da2575..149319e8bbd 100644 --- a/pkg/volume/csi/csi_client_test.go +++ b/pkg/volume/csi/csi_client_test.go @@ -26,6 +26,7 @@ import ( csipbv1 "github.com/container-storage-interface/spec/lib/go/csi" api "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" + "k8s.io/apimachinery/pkg/util/wait" "k8s.io/kubernetes/pkg/volume" "k8s.io/kubernetes/pkg/volume/csi/fake" volumetypes "k8s.io/kubernetes/pkg/volume/util/types" @@ -313,6 +314,7 @@ func TestClientNodeGetInfo(t *testing.T) { expectedMaxVolumePerNode int64 expectedAccessibleTopology map[string]string mustFail bool + mustTimeout bool err error }{ { @@ -326,6 +328,13 @@ func TestClientNodeGetInfo(t *testing.T) { mustFail: true, err: errors.New("grpc error"), }, + { + name: "test empty nodeId", + mustTimeout: true, + expectedNodeID: "", + expectedMaxVolumePerNode: 16, + expectedAccessibleTopology: map[string]string{"com.example.csi-topology/zone": "zone1"}, + }, } for _, tc := range testCases { @@ -349,7 +358,13 @@ func TestClientNodeGetInfo(t *testing.T) { } nodeID, maxVolumePerNode, accessibleTopology, err := client.NodeGetInfo(context.Background()) - checkErr(t, tc.mustFail, err) + if tc.mustTimeout { + if wait.ErrWaitTimeout.Error() != err.Error() { + t.Errorf("should have timed out : %s", tc.name) + } + } else { + checkErr(t, tc.mustFail, err) + } if nodeID != tc.expectedNodeID { t.Errorf("expected nodeID: %v; got: %v", tc.expectedNodeID, nodeID)