diff --git a/pkg/apis/core/validation/validation.go b/pkg/apis/core/validation/validation.go index d7260c655fc..bfda6a76f0a 100644 --- a/pkg/apis/core/validation/validation.go +++ b/pkg/apis/core/validation/validation.go @@ -1530,6 +1530,23 @@ func validateStorageOSPersistentVolumeSource(storageos *core.StorageOSPersistent return allErrs } +// validatePVSecretReference check whether provided SecretReference object is valid in terms of secret name and namespace. + +func validatePVSecretReference(secretRef *core.SecretReference, fldPath *field.Path) field.ErrorList { + var allErrs field.ErrorList + if len(secretRef.Name) == 0 { + allErrs = append(allErrs, field.Required(fldPath.Child("name"), "")) + } else { + allErrs = append(allErrs, ValidateDNS1123Label(secretRef.Name, fldPath.Child("name"))...) + } + if len(secretRef.Namespace) == 0 { + allErrs = append(allErrs, field.Required(fldPath.Child("namespace"), "")) + } else { + allErrs = append(allErrs, ValidateDNS1123Label(secretRef.Namespace, fldPath.Child("namespace"))...) + } + return allErrs +} + func ValidateCSIDriverName(driverName string, fldPath *field.Path) field.ErrorList { allErrs := field.ErrorList{} @@ -1556,58 +1573,18 @@ func validateCSIPersistentVolumeSource(csi *core.CSIPersistentVolumeSource, fldP if len(csi.VolumeHandle) == 0 { allErrs = append(allErrs, field.Required(fldPath.Child("volumeHandle"), "")) } - if csi.ControllerPublishSecretRef != nil { - if len(csi.ControllerPublishSecretRef.Name) == 0 { - allErrs = append(allErrs, field.Required(fldPath.Child("controllerPublishSecretRef", "name"), "")) - } else { - allErrs = append(allErrs, ValidateDNS1123Label(csi.ControllerPublishSecretRef.Name, fldPath.Child("name"))...) - } - if len(csi.ControllerPublishSecretRef.Namespace) == 0 { - allErrs = append(allErrs, field.Required(fldPath.Child("controllerPublishSecretRef", "namespace"), "")) - } else { - allErrs = append(allErrs, ValidateDNS1123Label(csi.ControllerPublishSecretRef.Namespace, fldPath.Child("namespace"))...) - } + allErrs = append(allErrs, validatePVSecretReference(csi.ControllerPublishSecretRef, fldPath.Child("controllerPublishSecretRef"))...) } - if csi.ControllerExpandSecretRef != nil { - if len(csi.ControllerExpandSecretRef.Name) == 0 { - allErrs = append(allErrs, field.Required(fldPath.Child("controllerExpandSecretRef", "name"), "")) - } else { - allErrs = append(allErrs, ValidateDNS1123Label(csi.ControllerExpandSecretRef.Name, fldPath.Child("name"))...) - } - if len(csi.ControllerExpandSecretRef.Namespace) == 0 { - allErrs = append(allErrs, field.Required(fldPath.Child("controllerExpandSecretRef", "namespace"), "")) - } else { - allErrs = append(allErrs, ValidateDNS1123Label(csi.ControllerExpandSecretRef.Namespace, fldPath.Child("namespace"))...) - } + allErrs = append(allErrs, validatePVSecretReference(csi.ControllerExpandSecretRef, fldPath.Child("controllerExpandSecretRef"))...) } - if csi.NodePublishSecretRef != nil { - if len(csi.NodePublishSecretRef.Name) == 0 { - allErrs = append(allErrs, field.Required(fldPath.Child("nodePublishSecretRef", "name"), "")) - } else { - allErrs = append(allErrs, ValidateDNS1123Label(csi.NodePublishSecretRef.Name, fldPath.Child("name"))...) - } - if len(csi.NodePublishSecretRef.Namespace) == 0 { - allErrs = append(allErrs, field.Required(fldPath.Child("nodePublishSecretRef", "namespace"), "")) - } else { - allErrs = append(allErrs, ValidateDNS1123Label(csi.NodePublishSecretRef.Namespace, fldPath.Child("namespace"))...) - } + allErrs = append(allErrs, validatePVSecretReference(csi.NodePublishSecretRef, fldPath.Child("nodePublishSecretRef"))...) } if csi.NodeExpandSecretRef != nil { - if len(csi.NodeExpandSecretRef.Name) == 0 { - allErrs = append(allErrs, field.Required(fldPath.Child("nodeExpandSecretRef", "name"), "")) - } else { - allErrs = append(allErrs, ValidateDNS1123Subdomain(csi.NodeExpandSecretRef.Name, fldPath.Child("name"))...) - } - if len(csi.NodeExpandSecretRef.Namespace) == 0 { - allErrs = append(allErrs, field.Required(fldPath.Child("nodeExpandSecretRef", "namespace"), "")) - } else { - allErrs = append(allErrs, ValidateDNS1123Label(csi.NodeExpandSecretRef.Namespace, fldPath.Child("namespace"))...) - } + allErrs = append(allErrs, validatePVSecretReference(csi.NodeExpandSecretRef, fldPath.Child("nodeExpandSecretRef"))...) } - return allErrs } diff --git a/pkg/apis/core/validation/validation_test.go b/pkg/apis/core/validation/validation_test.go index a4383152440..ea61e5c5998 100644 --- a/pkg/apis/core/validation/validation_test.go +++ b/pkg/apis/core/validation/validation_test.go @@ -20471,3 +20471,65 @@ func TestValidateAppArmorProfileFormat(t *testing.T) { } } } + +func TestValidatePVSecretReference(t *testing.T) { + rootFld := field.NewPath("name") + type args struct { + secretRef *core.SecretReference + fldPath *field.Path + } + tests := []struct { + name string + args args + expectError bool + expectedError string + }{ + { + name: "invalid secret ref name", + args: args{&core.SecretReference{Name: "$%^&*#", Namespace: "default"}, rootFld}, + expectError: true, + expectedError: "name.name: Invalid value: \"$%^&*#\": " + dnsLabelErrMsg, + }, + { + name: "invalid secret ref namespace", + args: args{&core.SecretReference{Name: "valid", Namespace: "$%^&*#"}, rootFld}, + expectError: true, + expectedError: "name.namespace: Invalid value: \"$%^&*#\": " + dnsLabelErrMsg, + }, + { + name: "invalid secret: missing namespace", + args: args{&core.SecretReference{Name: "valid"}, rootFld}, + expectError: true, + expectedError: "name.namespace: Required value", + }, + { + name: "invalid secret : missing name", + args: args{&core.SecretReference{Namespace: "default"}, rootFld}, + expectError: true, + expectedError: "name.name: Required value", + }, + { + name: "valid secret", + args: args{&core.SecretReference{Name: "valid", Namespace: "default"}, rootFld}, + expectError: false, + expectedError: "", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + errs := validatePVSecretReference(tt.args.secretRef, tt.args.fldPath) + if tt.expectError && len(errs) == 0 { + t.Errorf("Unexpected success") + } + if tt.expectError && len(errs) != 0 { + str := errs[0].Error() + if str != "" && !strings.Contains(str, tt.expectedError) { + t.Errorf("%s: expected error detail either empty or %q, got %q", tt.name, tt.expectedError, str) + } + } + if !tt.expectError && len(errs) != 0 { + t.Errorf("Unexpected error(s): %v", errs) + } + }) + } +}