From b5db644422d649adffce3a30e9361ad4a5206186 Mon Sep 17 00:00:00 2001 From: Joe Betz Date: Tue, 6 Jun 2023 20:26:43 -0400 Subject: [PATCH] Add merge map key validation to StorageVersions --- .../validation/validation.go | 7 +++ .../validation/validation_test.go | 54 +++++++++++++++++++ 2 files changed, 61 insertions(+) diff --git a/pkg/apis/apiserverinternal/validation/validation.go b/pkg/apis/apiserverinternal/validation/validation.go index 4493907d1cc..7b922d90ccb 100644 --- a/pkg/apis/apiserverinternal/validation/validation.go +++ b/pkg/apis/apiserverinternal/validation/validation.go @@ -21,6 +21,7 @@ import ( "strings" apimachineryvalidation "k8s.io/apimachinery/pkg/api/validation" + "k8s.io/apimachinery/pkg/util/sets" utilvalidation "k8s.io/apimachinery/pkg/util/validation" "k8s.io/apimachinery/pkg/util/validation/field" "k8s.io/kubernetes/pkg/apis/apiserverinternal" @@ -67,7 +68,13 @@ func ValidateStorageVersionStatusUpdate(sv, oldSV *apiserverinternal.StorageVers func validateStorageVersionStatus(ss apiserverinternal.StorageVersionStatus, fldPath *field.Path) field.ErrorList { var allErrs field.ErrorList + allAPIServerIDs := sets.New[string]() for i, ssv := range ss.StorageVersions { + if allAPIServerIDs.Has(ssv.APIServerID) { + allErrs = append(allErrs, field.Duplicate(fldPath.Child("storageVersions").Index(i).Child("apiServerID"), ssv.APIServerID)) + } else { + allAPIServerIDs.Insert(ssv.APIServerID) + } allErrs = append(allErrs, validateServerStorageVersion(ssv, fldPath.Child("storageVersions").Index(i))...) } if err := validateCommonVersion(ss, fldPath); err != nil { diff --git a/pkg/apis/apiserverinternal/validation/validation_test.go b/pkg/apis/apiserverinternal/validation/validation_test.go index a2bbae659d2..887cd5e7b30 100644 --- a/pkg/apis/apiserverinternal/validation/validation_test.go +++ b/pkg/apis/apiserverinternal/validation/validation_test.go @@ -22,6 +22,7 @@ import ( "k8s.io/apimachinery/pkg/util/validation/field" "k8s.io/kubernetes/pkg/apis/apiserverinternal" + "k8s.io/utils/pointer" ) func TestValidateServerStorageVersion(t *testing.T) { @@ -126,6 +127,59 @@ func TestValidateServerStorageVersion(t *testing.T) { } } +func TestValidateStorageVersionStatus(t *testing.T) { + cases := []struct { + svs apiserverinternal.StorageVersionStatus + expectedErr string + }{{ + svs: apiserverinternal.StorageVersionStatus{ + StorageVersions: []apiserverinternal.ServerStorageVersion{{ + APIServerID: "1", + EncodingVersion: "v1alpha1", + DecodableVersions: []string{"v1alpha1", "v1"}, + }, { + APIServerID: "2", + EncodingVersion: "v1alpha1", + DecodableVersions: []string{"v1alpha1", "v1"}, + }}, + CommonEncodingVersion: pointer.String("v1alpha1"), + }, + expectedErr: "", + }, { + svs: apiserverinternal.StorageVersionStatus{ + StorageVersions: []apiserverinternal.ServerStorageVersion{{ + APIServerID: "1", + EncodingVersion: "v1alpha1", + DecodableVersions: []string{"v1alpha1", "v1"}, + }, { + APIServerID: "1", + EncodingVersion: "v1beta1", + DecodableVersions: []string{"v1alpha1", "v1"}, + }}, + CommonEncodingVersion: pointer.String("v1alpha1"), + }, + expectedErr: "storageVersions[1].apiServerID: Duplicate value: \"1\"", + }} + + for _, tc := range cases { + err := validateStorageVersionStatus(tc.svs, field.NewPath("")).ToAggregate() + if err == nil && len(tc.expectedErr) == 0 { + continue + } + if err != nil && len(tc.expectedErr) == 0 { + t.Errorf("unexpected error %v", err) + continue + } + if err == nil && len(tc.expectedErr) != 0 { + t.Errorf("unexpected empty error") + continue + } + if !strings.Contains(err.Error(), tc.expectedErr) { + t.Errorf("expected error to contain %s, got %s", tc.expectedErr, err) + } + } +} + func TestValidateCommonVersion(t *testing.T) { cases := []struct { status apiserverinternal.StorageVersionStatus