From 31f463721701a35c1061935d554a81141b0b68bb Mon Sep 17 00:00:00 2001 From: Joe Betz Date: Mon, 3 Mar 2025 09:49:51 -0500 Subject: [PATCH] Add validators: eachkey, eachval, subfield Introduce a composable set of tags for validating child data. This allows for point-of-use validation of shared types. Co-authored-by: Tim Hockin Co-authored-by: Aaron Prindle Co-authored-by: Yongrui Lin --- .../apimachinery/pkg/api/validate/each.go | 119 +++++ .../pkg/api/validate/each_test.go | 316 +++++++++++++ .../apimachinery/pkg/api/validate/subfield.go | 41 ++ .../cmd/validation-gen/validators/each.go | 430 ++++++++++++++++++ .../cmd/validation-gen/validators/subfield.go | 126 +++++ 5 files changed, 1032 insertions(+) create mode 100644 staging/src/k8s.io/apimachinery/pkg/api/validate/each.go create mode 100644 staging/src/k8s.io/apimachinery/pkg/api/validate/each_test.go create mode 100644 staging/src/k8s.io/apimachinery/pkg/api/validate/subfield.go create mode 100644 staging/src/k8s.io/code-generator/cmd/validation-gen/validators/each.go create mode 100644 staging/src/k8s.io/code-generator/cmd/validation-gen/validators/subfield.go diff --git a/staging/src/k8s.io/apimachinery/pkg/api/validate/each.go b/staging/src/k8s.io/apimachinery/pkg/api/validate/each.go new file mode 100644 index 00000000000..ccff8d19b1f --- /dev/null +++ b/staging/src/k8s.io/apimachinery/pkg/api/validate/each.go @@ -0,0 +1,119 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package validate + +import ( + "context" + + "k8s.io/apimachinery/pkg/api/operation" + "k8s.io/apimachinery/pkg/util/validation/field" +) + +// CompareFunc is a function that compares two values of the same type. +type CompareFunc[T any] func(T, T) bool + +// EachSliceVal validates each element of newSlice with the specified +// validation function. The comparison function is used to find the +// corresponding value in oldSlice. The value-type of the slices is assumed to +// not be nilable. +func EachSliceVal[T any](ctx context.Context, op operation.Operation, fldPath *field.Path, newSlice, oldSlice []T, + cmp CompareFunc[T], validator ValidateFunc[*T]) field.ErrorList { + var errs field.ErrorList + for i, val := range newSlice { + var old *T + if cmp != nil && len(oldSlice) > 0 { + old = lookup(oldSlice, val, cmp) + } + errs = append(errs, validator(ctx, op, fldPath.Index(i), &val, old)...) + } + return errs +} + +// EachSliceValNilable validates each element of newSlice with the specified +// validation function. The comparison function is used to find the +// corresponding value in oldSlice. The value-type of the slices is assumed to +// be nilable. +func EachSliceValNilable[T any](ctx context.Context, op operation.Operation, fldPath *field.Path, newSlice, oldSlice []T, + cmp CompareFunc[T], validator ValidateFunc[T]) field.ErrorList { + var errs field.ErrorList + for i, val := range newSlice { + var old T + if cmp != nil && len(oldSlice) > 0 { + p := lookup(oldSlice, val, cmp) + if p != nil { + old = *p + } + } + errs = append(errs, validator(ctx, op, fldPath.Index(i), val, old)...) + } + return errs +} + +// lookup returns a pointer to the first element in the list that matches the +// target, according to the provided comparison function, or else nil. +func lookup[T any](list []T, target T, cmp func(T, T) bool) *T { + for i := range list { + if cmp(list[i], target) { + return &list[i] + } + } + return nil +} + +// EachMapVal validates each element of newMap with the specified validation +// function and, if the corresponding key is found in oldMap, the old value. +// The value-type of the slices is assumed to not be nilable. +func EachMapVal[K ~string, V any](ctx context.Context, op operation.Operation, fldPath *field.Path, newMap, oldMap map[K]V, + validator ValidateFunc[*V]) field.ErrorList { + var errs field.ErrorList + for key, val := range newMap { + var old *V + if o, found := oldMap[key]; found { + old = &o + } + errs = append(errs, validator(ctx, op, fldPath.Key(string(key)), &val, old)...) + } + return errs +} + +// EachMapValNilable validates each element of newMap with the specified +// validation function and, if the corresponding key is found in oldMap, the +// old value. The value-type of the slices is assumed to be nilable. +func EachMapValNilable[K ~string, V any](ctx context.Context, op operation.Operation, fldPath *field.Path, newMap, oldMap map[K]V, + validator ValidateFunc[V]) field.ErrorList { + var errs field.ErrorList + for key, val := range newMap { + var old V + if o, found := oldMap[key]; found { + old = o + } + errs = append(errs, validator(ctx, op, fldPath.Key(string(key)), val, old)...) + } + return errs +} + +// EachMapKey validates each element of newMap with the specified +// validation function. The oldMap argument is not used. +func EachMapKey[K ~string, T any](ctx context.Context, op operation.Operation, fldPath *field.Path, newMap, oldMap map[K]T, + validator ValidateFunc[*K]) field.ErrorList { + var errs field.ErrorList + for key := range newMap { + // Note: the field path is the field, not the key. + errs = append(errs, validator(ctx, op, fldPath, &key, nil)...) + } + return errs +} diff --git a/staging/src/k8s.io/apimachinery/pkg/api/validate/each_test.go b/staging/src/k8s.io/apimachinery/pkg/api/validate/each_test.go new file mode 100644 index 00000000000..312d8545839 --- /dev/null +++ b/staging/src/k8s.io/apimachinery/pkg/api/validate/each_test.go @@ -0,0 +1,316 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package validate + +import ( + "context" + "fmt" + "reflect" + "slices" + "testing" + + "k8s.io/apimachinery/pkg/api/operation" + "k8s.io/apimachinery/pkg/util/validation/field" + "k8s.io/utils/ptr" +) + +type TestStruct struct { + I int + D string +} + +func TestEachSliceVal(t *testing.T) { + testEachSliceVal(t, "valid", []int{11, 12, 13}) + testEachSliceVal(t, "valid", []string{"a", "b", "c"}) + testEachSliceVal(t, "valid", []TestStruct{{11, "a"}, {12, "b"}, {13, "c"}}) + + testEachSliceVal(t, "empty", []int{}) + testEachSliceVal(t, "empty", []string{}) + testEachSliceVal(t, "empty", []TestStruct{}) + + testEachSliceVal[int](t, "nil", nil) + testEachSliceVal[string](t, "nil", nil) + testEachSliceVal[TestStruct](t, "nil", nil) + + testEachSliceValUpdate(t, "valid", []int{11, 12, 13}) + testEachSliceValUpdate(t, "valid", []string{"a", "b", "c"}) + testEachSliceValUpdate(t, "valid", []TestStruct{{11, "a"}, {12, "b"}, {13, "c"}}) + + testEachSliceValUpdate(t, "empty", []int{}) + testEachSliceValUpdate(t, "empty", []string{}) + testEachSliceValUpdate(t, "empty", []TestStruct{}) + + testEachSliceValUpdate[int](t, "nil", nil) + testEachSliceValUpdate[string](t, "nil", nil) + testEachSliceValUpdate[TestStruct](t, "nil", nil) +} + +func testEachSliceVal[T any](t *testing.T, name string, input []T) { + var zero T + t.Run(fmt.Sprintf("%s(%T)", name, zero), func(t *testing.T) { + calls := 0 + vfn := func(ctx context.Context, op operation.Operation, fldPath *field.Path, newVal, oldVal *T) field.ErrorList { + if oldVal != nil { + t.Errorf("expected nil oldVal, got %v", *oldVal) + } + calls++ + return nil + } + _ = EachSliceVal(context.Background(), operation.Operation{}, field.NewPath("test"), input, nil, nil, vfn) + if calls != len(input) { + t.Errorf("expected %d calls, got %d", len(input), calls) + } + }) +} + +func testEachSliceValUpdate[T any](t *testing.T, name string, input []T) { + var zero T + t.Run(fmt.Sprintf("%s(%T)", name, zero), func(t *testing.T) { + calls := 0 + vfn := func(ctx context.Context, op operation.Operation, fldPath *field.Path, newVal, oldVal *T) field.ErrorList { + if oldVal == nil { + t.Fatalf("expected non-nil oldVal") + } + if !reflect.DeepEqual(*newVal, *oldVal) { + t.Errorf("expected oldVal == newVal, got %v, %v", *oldVal, *newVal) + } + calls++ + return nil + } + old := make([]T, len(input)) + copy(old, input) + slices.Reverse(old) + cmp := func(a, b T) bool { return reflect.DeepEqual(a, b) } + _ = EachSliceVal(context.Background(), operation.Operation{}, field.NewPath("test"), input, old, cmp, vfn) + if calls != len(input) { + t.Errorf("expected %d calls, got %d", len(input), calls) + } + }) +} + +func TestEachSliceValNilablePointer(t *testing.T) { + testEachSliceValNilable(t, "valid", []*int{ptr.To(11), ptr.To(12), ptr.To(13)}) + testEachSliceValNilable(t, "valid", []*string{ptr.To("a"), ptr.To("b"), ptr.To("c")}) + testEachSliceValNilable(t, "valid", []*TestStruct{{11, "a"}, {12, "b"}, {13, "c"}}) + + testEachSliceValNilable(t, "empty", []*int{}) + testEachSliceValNilable(t, "empty", []*string{}) + testEachSliceValNilable(t, "empty", []*TestStruct{}) + + testEachSliceValNilable[int](t, "nil", nil) + testEachSliceValNilable[string](t, "nil", nil) + testEachSliceValNilable[TestStruct](t, "nil", nil) + + testEachSliceValNilableUpdate(t, "valid", []*int{ptr.To(11), ptr.To(12), ptr.To(13)}) + testEachSliceValNilableUpdate(t, "valid", []*string{ptr.To("a"), ptr.To("b"), ptr.To("c")}) + testEachSliceValNilableUpdate(t, "valid", []*TestStruct{{11, "a"}, {12, "b"}, {13, "c"}}) + + testEachSliceValNilableUpdate(t, "empty", []*int{}) + testEachSliceValNilableUpdate(t, "empty", []*string{}) + testEachSliceValNilableUpdate(t, "empty", []*TestStruct{}) + + testEachSliceValNilableUpdate[int](t, "nil", nil) + testEachSliceValNilableUpdate[string](t, "nil", nil) + testEachSliceValNilableUpdate[TestStruct](t, "nil", nil) +} + +func TestEachSliceValNilableSlice(t *testing.T) { + testEachSliceValNilable(t, "valid", [][]int{{11, 12, 13}, {12, 13, 14}, {13, 14, 15}}) + testEachSliceValNilable(t, "valid", [][]string{{"a", "b", "c"}, {"b", "c", "d"}, {"c", "d", "e"}}) + testEachSliceValNilable(t, "valid", [][]TestStruct{ + {{11, "a"}, {12, "b"}, {13, "c"}}, + {{12, "a"}, {13, "b"}, {14, "c"}}, + {{13, "a"}, {14, "b"}, {15, "c"}}}) + + testEachSliceValNilable(t, "empty", [][]int{{}, {}, {}}) + testEachSliceValNilable(t, "empty", [][]string{{}, {}, {}}) + testEachSliceValNilable(t, "empty", [][]TestStruct{{}, {}, {}}) + + testEachSliceValNilable[int](t, "nil", nil) + testEachSliceValNilable[string](t, "nil", nil) + testEachSliceValNilable[TestStruct](t, "nil", nil) + + testEachSliceValNilableUpdate(t, "valid", [][]int{{11, 12, 13}, {12, 13, 14}, {13, 14, 15}}) + testEachSliceValNilableUpdate(t, "valid", [][]string{{"a", "b", "c"}, {"b", "c", "d"}, {"c", "d", "e"}}) + testEachSliceValNilableUpdate(t, "valid", [][]TestStruct{ + {{11, "a"}, {12, "b"}, {13, "c"}}, + {{12, "a"}, {13, "b"}, {14, "c"}}, + {{13, "a"}, {14, "b"}, {15, "c"}}}) + + testEachSliceValNilableUpdate(t, "empty", [][]int{{}, {}, {}}) + testEachSliceValNilableUpdate(t, "empty", [][]string{{}, {}, {}}) + testEachSliceValNilableUpdate(t, "empty", [][]TestStruct{{}, {}, {}}) + + testEachSliceValNilableUpdate[int](t, "nil", nil) + testEachSliceValNilableUpdate[string](t, "nil", nil) + testEachSliceValNilableUpdate[TestStruct](t, "nil", nil) +} + +func testEachSliceValNilable[T any](t *testing.T, name string, input []T) { + var zero T + t.Run(fmt.Sprintf("%s(%T)", name, zero), func(t *testing.T) { + calls := 0 + vfn := func(ctx context.Context, op operation.Operation, fldPath *field.Path, newVal, oldVal T) field.ErrorList { + if !reflect.DeepEqual(oldVal, zero) { + t.Errorf("expected nil oldVal, got %v", oldVal) + } + calls++ + return nil + } + _ = EachSliceValNilable(context.Background(), operation.Operation{}, field.NewPath("test"), input, nil, nil, vfn) + if calls != len(input) { + t.Errorf("expected %d calls, got %d", len(input), calls) + } + }) +} + +func testEachSliceValNilableUpdate[T any](t *testing.T, name string, input []T) { + var zero T + t.Run(fmt.Sprintf("%s(%T)", name, zero), func(t *testing.T) { + calls := 0 + vfn := func(ctx context.Context, op operation.Operation, fldPath *field.Path, newVal, oldVal T) field.ErrorList { + if reflect.DeepEqual(oldVal, zero) { + t.Fatalf("expected non-nil oldVal") + } + if !reflect.DeepEqual(newVal, oldVal) { + t.Errorf("expected oldVal == newVal, got %v, %v", oldVal, newVal) + } + calls++ + return nil + } + old := make([]T, len(input)) + copy(old, input) + slices.Reverse(old) + cmp := func(a, b T) bool { return reflect.DeepEqual(a, b) } + _ = EachSliceValNilable(context.Background(), operation.Operation{}, field.NewPath("test"), input, old, cmp, vfn) + if calls != len(input) { + t.Errorf("expected %d calls, got %d", len(input), calls) + } + }) +} + +func TestEachMapVal(t *testing.T) { + testEachMapVal(t, "valid", map[string]int{"one": 11, "two": 12, "three": 13}) + testEachMapVal(t, "valid", map[string]string{"A": "a", "B": "b", "C": "c"}) + testEachMapVal(t, "valid", map[string]TestStruct{"one": {11, "a"}, "two": {12, "b"}, "three": {13, "c"}}) + + testEachMapVal(t, "empty", map[string]int{}) + testEachMapVal(t, "empty", map[string]string{}) + testEachMapVal(t, "empty", map[string]TestStruct{}) + + testEachMapVal[int](t, "nil", nil) + testEachMapVal[string](t, "nil", nil) + testEachMapVal[TestStruct](t, "nil", nil) +} + +func testEachMapVal[T any](t *testing.T, name string, input map[string]T) { + var zero T + t.Run(fmt.Sprintf("%s(%T)", name, zero), func(t *testing.T) { + calls := 0 + vfn := func(ctx context.Context, op operation.Operation, fldPath *field.Path, newVal, oldVal *T) field.ErrorList { + if oldVal != nil { + t.Errorf("expected nil oldVal, got %v", *oldVal) + } + calls++ + return nil + } + _ = EachMapVal(context.Background(), operation.Operation{}, field.NewPath("test"), input, nil, vfn) + if calls != len(input) { + t.Errorf("expected %d calls, got %d", len(input), calls) + } + }) +} + +func TestEachMapValNilablePointer(t *testing.T) { + testEachMapValNilable(t, "valid", map[string]*int{"one": ptr.To(11), "two": ptr.To(12), "three": ptr.To(13)}) + testEachMapValNilable(t, "valid", map[string]*string{"A": ptr.To("a"), "B": ptr.To("b"), "C": ptr.To("c")}) + testEachMapValNilable(t, "valid", map[string]*TestStruct{"one": {11, "a"}, "two": {12, "b"}, "three": {13, "c"}}) + + testEachMapValNilable(t, "empty", map[string]*int{}) + testEachMapValNilable(t, "empty", map[string]*string{}) + testEachMapValNilable(t, "empty", map[string]*TestStruct{}) + + testEachMapValNilable[int](t, "nil", nil) + testEachMapValNilable[string](t, "nil", nil) + testEachMapValNilable[TestStruct](t, "nil", nil) +} + +func TestEachMapValNilableSlice(t *testing.T) { + testEachMapValNilable(t, "valid", map[string][]int{ + "one": {11, 12, 13}, + "two": {12, 13, 14}, + "three": {13, 14, 15}}) + testEachMapValNilable(t, "valid", map[string][]string{ + "A": {"a", "b", "c"}, + "B": {"b", "c", "d"}, + "C": {"c", "d", "e"}}) + testEachMapValNilable(t, "valid", map[string][]TestStruct{ + "one": {{11, "a"}, {12, "b"}, {13, "c"}}, + "two": {{12, "a"}, {13, "b"}, {14, "c"}}, + "three": {{13, "a"}, {14, "b"}, {15, "c"}}}) + + testEachMapValNilable(t, "empty", map[string][]int{"a": {}, "b": {}, "c": {}}) + testEachMapValNilable(t, "empty", map[string][]string{"a": {}, "b": {}, "c": {}}) + testEachMapValNilable(t, "empty", map[string][]TestStruct{"a": {}, "b": {}, "c": {}}) + + testEachMapValNilable[int](t, "nil", nil) + testEachMapValNilable[string](t, "nil", nil) + testEachMapValNilable[TestStruct](t, "nil", nil) +} + +func testEachMapValNilable[T any](t *testing.T, name string, input map[string]T) { + var zero T + t.Run(fmt.Sprintf("%s(%T)", name, zero), func(t *testing.T) { + calls := 0 + vfn := func(ctx context.Context, op operation.Operation, fldPath *field.Path, newVal, oldVal T) field.ErrorList { + if !reflect.DeepEqual(oldVal, zero) { + t.Errorf("expected nil oldVal, got %v", oldVal) + } + calls++ + return nil + } + _ = EachMapValNilable(context.Background(), operation.Operation{}, field.NewPath("test"), input, nil, vfn) + if calls != len(input) { + t.Errorf("expected %d calls, got %d", len(input), calls) + } + }) +} + +type StringType string + +func TestEachMapKey(t *testing.T) { + testEachMapKey(t, "valid", map[string]int{"one": 11, "two": 12, "three": 13}) + testEachMapKey(t, "valid", map[StringType]string{"A": "a", "B": "b", "C": "c"}) +} + +func testEachMapKey[K ~string, V any](t *testing.T, name string, input map[K]V) { + var zero K + t.Run(fmt.Sprintf("%s(%T)", name, zero), func(t *testing.T) { + calls := 0 + vfn := func(ctx context.Context, op operation.Operation, fldPath *field.Path, newVal, oldVal *K) field.ErrorList { + if oldVal != nil { + t.Errorf("expected nil oldVal, got %v", *oldVal) + } + calls++ + return nil + } + _ = EachMapKey(context.Background(), operation.Operation{}, field.NewPath("test"), input, nil, vfn) + if calls != len(input) { + t.Errorf("expected %d calls, got %d", len(input), calls) + } + }) +} diff --git a/staging/src/k8s.io/apimachinery/pkg/api/validate/subfield.go b/staging/src/k8s.io/apimachinery/pkg/api/validate/subfield.go new file mode 100644 index 00000000000..3dcd28f26ec --- /dev/null +++ b/staging/src/k8s.io/apimachinery/pkg/api/validate/subfield.go @@ -0,0 +1,41 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package validate + +import ( + "context" + + "k8s.io/apimachinery/pkg/api/operation" + "k8s.io/apimachinery/pkg/util/validation/field" +) + +// GetFieldFunc is a function that extracts a field from a type and returns a +// nilable value. +type GetFieldFunc[Tstruct any, Tfield any] func(*Tstruct) Tfield + +// Subfield validates a subfield of a struct against a validator function. +func Subfield[Tstruct any, Tfield any](ctx context.Context, op operation.Operation, fldPath *field.Path, newStruct, oldStruct *Tstruct, + fldName string, getField GetFieldFunc[Tstruct, Tfield], validator ValidateFunc[Tfield]) field.ErrorList { + var errs field.ErrorList + newVal := getField(newStruct) + var oldVal Tfield + if oldStruct != nil { + oldVal = getField(oldStruct) + } + errs = append(errs, validator(ctx, op, fldPath.Child(fldName), newVal, oldVal)...) + return errs +} diff --git a/staging/src/k8s.io/code-generator/cmd/validation-gen/validators/each.go b/staging/src/k8s.io/code-generator/cmd/validation-gen/validators/each.go new file mode 100644 index 00000000000..43aa2c227e1 --- /dev/null +++ b/staging/src/k8s.io/code-generator/cmd/validation-gen/validators/each.go @@ -0,0 +1,430 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package validators + +import ( + "fmt" + "strings" + + "k8s.io/apimachinery/pkg/util/sets" + "k8s.io/apimachinery/pkg/util/validation/field" + "k8s.io/gengo/v2/types" +) + +const ( + listTypeTagName = "k8s:listType" + ListMapKeyTagName = "k8s:listMapKey" + eachValTagName = "k8s:eachVal" + eachKeyTagName = "k8s:eachKey" +) + +// We keep the eachVal and eachKey validators around because the main +// code-generation logic calls them directly. We could move them into the main +// pkg, but it's easier and cleaner to leave them here. +var globalEachVal *eachValTagValidator +var globalEachKey *eachKeyTagValidator + +func init() { + // Lists with list-map semantics are comprised of multiple tags, which need + // to share information between them. + shared := map[string]*listMap{} // keyed by the fieldpath + RegisterTagValidator(listTypeTagValidator{shared}) + RegisterTagValidator(listMapKeyTagValidator{shared}) + + globalEachVal = &eachValTagValidator{shared, nil} + RegisterTagValidator(globalEachVal) + + globalEachKey = &eachKeyTagValidator{nil} + RegisterTagValidator(globalEachKey) +} + +// This applies to all tags in this file. +var listTagsValidScopes = sets.New(ScopeAny) + +// listMap collects information about a single list with map semantics. +type listMap struct { + declaredAsMap bool + keyFields []string +} + +type listTypeTagValidator struct { + byFieldPath map[string]*listMap +} + +func (listTypeTagValidator) Init(Config) {} + +func (listTypeTagValidator) TagName() string { + return listTypeTagName +} + +func (listTypeTagValidator) ValidScopes() sets.Set[Scope] { + return listTagsValidScopes +} + +func (lttv listTypeTagValidator) GetValidations(context Context, _ []string, payload string) (Validations, error) { + t := context.Type + if t.Kind == types.Alias { + t = t.Underlying + } + if t.Kind != types.Slice && t.Kind != types.Array { + return Validations{}, fmt.Errorf("can only be used on list types") + } + + switch payload { + case "atomic", "set": + // Allowed but no special handling. + case "map": + if realType(t.Elem).Kind != types.Struct { + return Validations{}, fmt.Errorf("only lists of structs can be list-maps") + } + + // Save the fact that this list is a map. + if lttv.byFieldPath[context.Path.String()] == nil { + lttv.byFieldPath[context.Path.String()] = &listMap{} + } + lm := lttv.byFieldPath[context.Path.String()] + lm.declaredAsMap = true + default: + return Validations{}, fmt.Errorf("unknown list type %q", payload) + } + + // This tag doesn't generate any validations. It just accumulates + // information for other tags to use. + return Validations{}, nil +} + +func realType(t *types.Type) *types.Type { + for { + if t.Kind == types.Alias { + t = t.Underlying + } else if t.Kind == types.Pointer { + t = t.Elem + } else { + break + } + } + return t +} + +func (lttv listTypeTagValidator) Docs() TagDoc { + doc := TagDoc{ + Tag: lttv.TagName(), + Scopes: lttv.ValidScopes().UnsortedList(), + Description: "Declares a list field's semantic type.", + Payloads: []TagPayloadDoc{{ + Description: "", + Docs: "map | atomic", + }}, + } + return doc +} + +type listMapKeyTagValidator struct { + byFieldPath map[string]*listMap +} + +func (listMapKeyTagValidator) Init(Config) {} + +func (listMapKeyTagValidator) TagName() string { + return ListMapKeyTagName +} + +func (listMapKeyTagValidator) ValidScopes() sets.Set[Scope] { + return listTagsValidScopes +} + +func (lmktv listMapKeyTagValidator) GetValidations(context Context, _ []string, payload string) (Validations, error) { + t := context.Type + if t.Kind == types.Alias { + t = t.Underlying + } + if t.Kind != types.Slice && t.Kind != types.Array { + return Validations{}, fmt.Errorf("can only be used on list types") + } + if realType(t.Elem).Kind != types.Struct { + return Validations{}, fmt.Errorf("only lists of structs can be list-maps") + } + + var fieldName string + if memb := getMemberByJSON(realType(t.Elem), payload); memb == nil { + return Validations{}, fmt.Errorf("no field for JSON name %q", payload) + } else if k := realType(memb.Type).Kind; k != types.Builtin { + return Validations{}, fmt.Errorf("only primitive types can be list-map keys, not %s", k) + } else { + fieldName = memb.Name + } + + if lmktv.byFieldPath[context.Path.String()] == nil { + lmktv.byFieldPath[context.Path.String()] = &listMap{} + } + lm := lmktv.byFieldPath[context.Path.String()] + lm.keyFields = append(lm.keyFields, fieldName) + + // This tag doesn't generate any validations. It just accumulates + // information for other tags to use. + return Validations{}, nil +} + +func (lmktv listMapKeyTagValidator) Docs() TagDoc { + doc := TagDoc{ + Tag: lmktv.TagName(), + Scopes: lmktv.ValidScopes().UnsortedList(), + Description: "Declares a named sub-field of a list's value-type to be part of the list-map key.", + Payloads: []TagPayloadDoc{{ + Description: "", + Docs: "The name of the field.", + }}, + } + return doc +} + +type eachValTagValidator struct { + byFieldPath map[string]*listMap + validator Validator +} + +func (evtv *eachValTagValidator) Init(cfg Config) { + evtv.validator = cfg.Validator +} + +func (eachValTagValidator) TagName() string { + return eachValTagName +} + +func (eachValTagValidator) ValidScopes() sets.Set[Scope] { + return listTagsValidScopes +} + +// LateTagValidator indicatesa that validator has to run after the listType and +// listMapKey tags. +func (eachValTagValidator) LateTagValidator() {} + +var ( + validateEachSliceVal = types.Name{Package: libValidationPkg, Name: "EachSliceVal"} + validateEachSliceValNilable = types.Name{Package: libValidationPkg, Name: "EachSliceValNilable"} + validateEachMapVal = types.Name{Package: libValidationPkg, Name: "EachMapVal"} + validateEachMapValNilable = types.Name{Package: libValidationPkg, Name: "EachMapValNilable"} +) + +func (evtv eachValTagValidator) GetValidations(context Context, _ []string, payload string) (Validations, error) { + t := context.Type + if t.Kind == types.Alias { + t = t.Underlying + } + switch t.Kind { + case types.Slice, types.Array, types.Map: + default: + return Validations{}, fmt.Errorf("can only be used on list or map types") + } + + fakeComments := []string{payload} + elemContext := Context{ + Type: t.Elem, + Parent: t, + Path: context.Path.Key("*"), + } + switch t.Kind { + case types.Slice, types.Array: + elemContext.Scope = ScopeListVal + case types.Map: + elemContext.Scope = ScopeMapVal + } + if validations, err := evtv.validator.ExtractValidations(elemContext, fakeComments); err != nil { + return Validations{}, err + } else { + if len(validations.Variables) > 0 { + return Validations{}, fmt.Errorf("variable generation is not supported") + } + return evtv.getValidations(context.Path, t, validations) + } +} + +func (evtv eachValTagValidator) getValidations(fldPath *field.Path, t *types.Type, validations Validations) (Validations, error) { + switch t.Kind { + case types.Slice, types.Array: + return evtv.getListValidations(fldPath, t, validations) + case types.Map: + return evtv.getMapValidations(t, validations) + } + return Validations{}, fmt.Errorf("non-iterable type: %v", t) +} + +// ForEachVal returns a validation that applies a function to each element of +// a list or map. +func ForEachVal(fldPath *field.Path, t *types.Type, fn FunctionGen) (Validations, error) { + return globalEachVal.getValidations(fldPath, t, Validations{Functions: []FunctionGen{fn}}) +} + +func (evtv eachValTagValidator) getListValidations(fldPath *field.Path, t *types.Type, validations Validations) (Validations, error) { + result := Validations{} + result.OpaqueValType = validations.OpaqueType + + var listMap *listMap + if lm, found := evtv.byFieldPath[fldPath.String()]; found { + if !lm.declaredAsMap { + return Validations{}, fmt.Errorf("found listMapKey without listType=map") + } + if len(lm.keyFields) == 0 { + return Validations{}, fmt.Errorf("found listType=map without listMapKey") + } + listMap = lm + } + for _, vfn := range validations.Functions { + // Which function we call depends on whether the value-type is + // already nilable or not. + var validateEach types.Name + + if isNilableType(t.Elem) { + validateEach = validateEachSliceValNilable + } else { + validateEach = validateEachSliceVal + } + + var cmpArg any = Literal("nil") + if listMap != nil { + cmpFn := FunctionLiteral{ + Parameters: []ParamResult{{"a", t.Elem}, {"b", t.Elem}}, + Results: []ParamResult{{"", types.Bool}}, + } + buf := strings.Builder{} + buf.WriteString("return ") + for i, fld := range listMap.keyFields { + if i > 0 { + buf.WriteString(" && ") + } + buf.WriteString(fmt.Sprintf("a.%s == b.%s", fld, fld)) + } + cmpFn.Body = buf.String() + cmpArg = cmpFn + } + f := Function(eachValTagName, vfn.Flags(), validateEach, cmpArg, WrapperFunction{vfn, t.Elem}) + result.Functions = append(result.Functions, f) + } + + return result, nil +} + +func (evtv eachValTagValidator) getMapValidations(t *types.Type, validations Validations) (Validations, error) { + result := Validations{} + result.OpaqueValType = validations.OpaqueType + + for _, vfn := range validations.Functions { + // Which function we call depends on whether the value-type is + // already nilable or not. + var validateEach types.Name + + if isNilableType(t.Elem) { + validateEach = validateEachMapValNilable + } else { + validateEach = validateEachMapVal + } + + f := Function(eachValTagName, vfn.Flags(), validateEach, WrapperFunction{vfn, t.Elem}) + result.Functions = append(result.Functions, f) + } + + return result, nil +} + +func (evtv eachValTagValidator) Docs() TagDoc { + doc := TagDoc{ + Tag: evtv.TagName(), + Scopes: evtv.ValidScopes().UnsortedList(), + Description: "Declares a validation for each value in a map or list.", + Payloads: []TagPayloadDoc{{ + Description: "", + Docs: "The tag to evaluate for each value.", + }}, + } + return doc +} + +type eachKeyTagValidator struct { + validator Validator +} + +func (ektv *eachKeyTagValidator) Init(cfg Config) { + ektv.validator = cfg.Validator +} + +func (eachKeyTagValidator) TagName() string { + return eachKeyTagName +} + +func (eachKeyTagValidator) ValidScopes() sets.Set[Scope] { + return listTagsValidScopes +} + +var ( + validateEachMapKey = types.Name{Package: libValidationPkg, Name: "EachMapKey"} +) + +func (ektv eachKeyTagValidator) GetValidations(context Context, _ []string, payload string) (Validations, error) { + t := context.Type + if t.Kind == types.Alias { + t = t.Underlying + } + if t.Kind != types.Map { + return Validations{}, fmt.Errorf("can only be used on map types") + } + + fakeComments := []string{payload} + elemContext := Context{ + Scope: ScopeMapKey, + Type: t.Elem, + Parent: t, + Path: context.Path.Child("(keys)"), + } + if validations, err := ektv.validator.ExtractValidations(elemContext, fakeComments); err != nil { + return Validations{}, err + } else { + if len(validations.Variables) > 0 { + return Validations{}, fmt.Errorf("variable generation is not supported") + } + + return ektv.getValidations(t, validations) + } +} + +func (ektv eachKeyTagValidator) getValidations(t *types.Type, validations Validations) (Validations, error) { + result := Validations{} + result.OpaqueKeyType = validations.OpaqueType + for _, vfn := range validations.Functions { + f := Function(eachKeyTagName, vfn.Flags(), validateEachMapKey, WrapperFunction{vfn, t.Key}) + result.Functions = append(result.Functions, f) + } + return result, nil +} + +// ForEachKey returns a validation that applies a function to each key of +// a map. +func ForEachKey(_ *field.Path, t *types.Type, fn FunctionGen) (Validations, error) { + return globalEachKey.getValidations(t, Validations{Functions: []FunctionGen{fn}}) +} + +func (ektv eachKeyTagValidator) Docs() TagDoc { + doc := TagDoc{ + Tag: ektv.TagName(), + Scopes: ektv.ValidScopes().UnsortedList(), + Description: "Declares a validation for each value in a map or list.", + Payloads: []TagPayloadDoc{{ + Description: "", + Docs: "The tag to evaluate for each value.", + }}, + } + return doc +} diff --git a/staging/src/k8s.io/code-generator/cmd/validation-gen/validators/subfield.go b/staging/src/k8s.io/code-generator/cmd/validation-gen/validators/subfield.go new file mode 100644 index 00000000000..f6913b19c20 --- /dev/null +++ b/staging/src/k8s.io/code-generator/cmd/validation-gen/validators/subfield.go @@ -0,0 +1,126 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package validators + +import ( + "fmt" + + "k8s.io/apimachinery/pkg/util/sets" + "k8s.io/gengo/v2/types" +) + +const ( + subfieldTagName = "k8s:subfield" +) + +func init() { + RegisterTagValidator(&subfieldTagValidator{}) +} + +type subfieldTagValidator struct { + validator Validator +} + +func (stv *subfieldTagValidator) Init(cfg Config) { + stv.validator = cfg.Validator +} + +func (subfieldTagValidator) TagName() string { + return subfieldTagName +} + +var subfieldTagValidScopes = sets.New(ScopeAny) + +func (subfieldTagValidator) ValidScopes() sets.Set[Scope] { + return subfieldTagValidScopes +} + +var ( + validateSubfield = types.Name{Package: libValidationPkg, Name: "Subfield"} +) + +func (stv subfieldTagValidator) GetValidations(context Context, args []string, payload string) (Validations, error) { + t := realType(context.Type) + if t.Kind != types.Struct { + return Validations{}, fmt.Errorf("can only be used on struct types") + } + if len(args) != 1 { + return Validations{}, fmt.Errorf("requires exactly one arg") + } + subname := args[0] + submemb := getMemberByJSON(t, subname) + if submemb == nil { + return Validations{}, fmt.Errorf("no field for json name %q", subname) + } + + result := Validations{} + + fakeComments := []string{payload} + subContext := Context{ + Scope: ScopeField, + Type: submemb.Type, + Parent: t, + Path: context.Path.Child(subname), + } + if validations, err := stv.validator.ExtractValidations(subContext, fakeComments); err != nil { + return Validations{}, err + } else { + if len(validations.Variables) > 0 { + return Validations{}, fmt.Errorf("variable generation is not supported") + } + + for _, vfn := range validations.Functions { + nilableStructType := context.Type + if !isNilableType(nilableStructType) { + nilableStructType = types.PointerTo(nilableStructType) + } + nilableFieldType := submemb.Type + fieldExprPrefix := "" + if !isNilableType(nilableFieldType) { + nilableFieldType = types.PointerTo(nilableFieldType) + fieldExprPrefix = "&" + } + + getFn := FunctionLiteral{ + Parameters: []ParamResult{{"o", nilableStructType}}, + Results: []ParamResult{{"", nilableFieldType}}, + } + getFn.Body = fmt.Sprintf("return %so.%s", fieldExprPrefix, submemb.Name) + f := Function(subfieldTagName, vfn.Flags(), validateSubfield, subname, getFn, WrapperFunction{vfn, submemb.Type}) + result.Functions = append(result.Functions, f) + result.Variables = append(result.Variables, validations.Variables...) + } + } + return result, nil +} + +func (stv subfieldTagValidator) Docs() TagDoc { + doc := TagDoc{ + Tag: stv.TagName(), + Scopes: stv.ValidScopes().UnsortedList(), + Description: "Declares a validation for a subfield of a struct.", + Args: []TagArgDoc{{ + Description: "", + }}, + Docs: "The named subfield must be a direct field of the struct, or of an embedded struct.", + Payloads: []TagPayloadDoc{{ + Description: "", + Docs: "The tag to evaluate for the subfield.", + }}, + } + return doc +}