diff --git a/pkg/volume/csi/nodeinfomanager/nodeinfomanager.go b/pkg/volume/csi/nodeinfomanager/nodeinfomanager.go index 37f24151b7e..a27e2966919 100644 --- a/pkg/volume/csi/nodeinfomanager/nodeinfomanager.go +++ b/pkg/volume/csi/nodeinfomanager/nodeinfomanager.go @@ -21,6 +21,7 @@ package nodeinfomanager // import "k8s.io/kubernetes/pkg/volume/csi/nodeinfomana import ( "encoding/json" "fmt" + "strings" csipb "github.com/container-storage-interface/spec/lib/go/csi/v0" "k8s.io/api/core/v1" @@ -386,6 +387,10 @@ func (nim *nodeInfoManager) CreateCSINodeInfo() (*csiv1alpha1.CSINodeInfo, error }, } + err = validateCSINodeInfo(nodeInfo) + if err != nil { + return err + } return csiKubeClient.CsiV1alpha1().CSINodeInfos().Create(nodeInfo) } @@ -458,7 +463,11 @@ func (nim *nodeInfoManager) installDriverToCSINodeInfo( nodeInfo.Spec.Drivers = newDriverSpecs nodeInfo.Status.Drivers = newDriverStatuses - _, err := csiKubeClient.CsiV1alpha1().CSINodeInfos().Update(nodeInfo) + err := validateCSINodeInfo(nodeInfo) + if err != nil { + return err + } + _, err = csiKubeClient.CsiV1alpha1().CSINodeInfos().Update(nodeInfo) return err // do not wrap error } @@ -557,3 +566,33 @@ func removeMaxAttachLimit(driverName string) nodeUpdateFunc { return node, true, nil } } + +// validateCSINodeInfo ensures members of CSINodeInfo object satisfies map and set semantics. +// Before calling CSINodeInfoInterface.Create() or CSINodeInfoInterface.Update() +// validateCSINodeInfo() should be invoked to make sure the CSINodeInfo is compliant +// TODO: move this logic to an external webhook +func validateCSINodeInfo(nodeInfo *csiv1alpha1.CSINodeInfo) error { + if len(nodeInfo.CSIDrivers) < 1 { + return fmt.Errorf("at least one CSI Driver entry is required") + } + // check for duplicate entries for the same driver + var errors []string + driverNames := make(sets.String) + for _, driverInfo := range nodeInfo.CSIDrivers { + if driverNames.Has(driverInfo.Driver) { + errors = append(errors, fmt.Sprintf("duplicate entries found for driver %s", driverNames)) + } + driverNames.Insert(driverInfo.Driver) + topoKeys := make(sets.String) + for _, key := range driverInfo.TopologyKeys { + if topoKeys.Has(key) { + errors = append(errors, fmt.Sprintf("duplicate topology keys %s found for driver %s", key, driverInfo.Driver)) + } + topoKeys.Insert(key) + } + } + if len(errors) == 0 { + return nil + } + return fmt.Errorf(strings.Join(errors, ", ")) +} diff --git a/pkg/volume/csi/nodeinfomanager/nodeinfomanager_test.go b/pkg/volume/csi/nodeinfomanager/nodeinfomanager_test.go index b58ab813dcd..f9803d745b8 100644 --- a/pkg/volume/csi/nodeinfomanager/nodeinfomanager_test.go +++ b/pkg/volume/csi/nodeinfomanager/nodeinfomanager_test.go @@ -637,6 +637,110 @@ func TestInstallCSIDriverExistingAnnotation(t *testing.T) { } } +func TestValidateCSINodeInfo(t *testing.T) { + testcases := []struct { + name string + nodeInfo *csiv1alpha1.CSINodeInfo + expectErr bool + }{ + { + name: "multiple drivers with same ids and different topology keys", + nodeInfo: &csiv1alpha1.CSINodeInfo{ + CSIDrivers: []csiv1alpha1.CSIDriverInfo{ + { + Driver: "driver1", + NodeID: "node1", + TopologyKeys: []string{"key1, key2"}, + }, + { + Driver: "driverB", + NodeID: "nodeA", + TopologyKeys: []string{"keyA", "keyB"}, + }, + }, + }, + expectErr: false, + }, + { + name: "multiple drivers with same ids and similar topology keys", + nodeInfo: &csiv1alpha1.CSINodeInfo{ + CSIDrivers: []csiv1alpha1.CSIDriverInfo{ + { + Driver: "driver1", + NodeID: "node1", + TopologyKeys: []string{"key1"}, + }, + { + Driver: "driver2", + NodeID: "node1", + TopologyKeys: []string{"key1"}, + }, + }, + }, + expectErr: false, + }, + { + name: "duplicate drivers", + nodeInfo: &csiv1alpha1.CSINodeInfo{ + CSIDrivers: []csiv1alpha1.CSIDriverInfo{ + { + Driver: "driver1", + NodeID: "node1", + TopologyKeys: []string{"key1", "key2"}, + }, + { + Driver: "driver1", + NodeID: "nodeX", + TopologyKeys: []string{"keyA", "keyB"}, + }, + }, + }, + expectErr: true, + }, + { + name: "single driver with duplicate topology keys", + nodeInfo: &csiv1alpha1.CSINodeInfo{ + CSIDrivers: []csiv1alpha1.CSIDriverInfo{ + { + Driver: "driver1", + NodeID: "node1", + TopologyKeys: []string{"key1", "key1"}, + }, + }, + }, + expectErr: true, + }, + { + name: "multiple drivers with one set of duplicate topology keys ", + nodeInfo: &csiv1alpha1.CSINodeInfo{ + CSIDrivers: []csiv1alpha1.CSIDriverInfo{ + { + Driver: "driver1", + NodeID: "node1", + TopologyKeys: []string{"key1"}, + }, + { + Driver: "driver2", + NodeID: "nodeX", + TopologyKeys: []string{"keyA", "keyA"}, + }, + }, + }, + expectErr: true, + }, + } + for _, tc := range testcases { + t.Logf("test case: %q", tc.name) + err := validateCSINodeInfo(tc.nodeInfo) + if err != nil && !tc.expectErr { + t.Errorf("expected no errors from validateCSINodeInfo but got error %v", err) + } + if err == nil && tc.expectErr { + t.Errorf("expected error from validateCSINodeInfo but got no errors") + } + } +} + func test(t *testing.T, addNodeInfo bool, csiNodeInfoEnabled bool, testcases []testcase) { defer utilfeaturetesting.SetFeatureGateDuringTest(t, utilfeature.DefaultFeatureGate, features.CSINodeInfo, csiNodeInfoEnabled)() defer utilfeaturetesting.SetFeatureGateDuringTest(t, utilfeature.DefaultFeatureGate, features.AttachVolumeLimit, true)()