Add validators: eachkey, eachval, subfield

Introduce a composable set of tags for validating child data.
This allows for point-of-use validation of shared types.

Co-authored-by: Tim Hockin <thockin@google.com>
Co-authored-by: Aaron Prindle <aprindle@google.com>
Co-authored-by: Yongrui Lin <yongrlin@google.com>
This commit is contained in:
Joe Betz 2025-03-03 09:49:51 -05:00
parent b5f9a00258
commit 31f4637217
5 changed files with 1032 additions and 0 deletions

View File

@ -0,0 +1,119 @@
/*
Copyright 2024 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package validate
import (
"context"
"k8s.io/apimachinery/pkg/api/operation"
"k8s.io/apimachinery/pkg/util/validation/field"
)
// CompareFunc is a function that compares two values of the same type.
type CompareFunc[T any] func(T, T) bool
// EachSliceVal validates each element of newSlice with the specified
// validation function. The comparison function is used to find the
// corresponding value in oldSlice. The value-type of the slices is assumed to
// not be nilable.
func EachSliceVal[T any](ctx context.Context, op operation.Operation, fldPath *field.Path, newSlice, oldSlice []T,
cmp CompareFunc[T], validator ValidateFunc[*T]) field.ErrorList {
var errs field.ErrorList
for i, val := range newSlice {
var old *T
if cmp != nil && len(oldSlice) > 0 {
old = lookup(oldSlice, val, cmp)
}
errs = append(errs, validator(ctx, op, fldPath.Index(i), &val, old)...)
}
return errs
}
// EachSliceValNilable validates each element of newSlice with the specified
// validation function. The comparison function is used to find the
// corresponding value in oldSlice. The value-type of the slices is assumed to
// be nilable.
func EachSliceValNilable[T any](ctx context.Context, op operation.Operation, fldPath *field.Path, newSlice, oldSlice []T,
cmp CompareFunc[T], validator ValidateFunc[T]) field.ErrorList {
var errs field.ErrorList
for i, val := range newSlice {
var old T
if cmp != nil && len(oldSlice) > 0 {
p := lookup(oldSlice, val, cmp)
if p != nil {
old = *p
}
}
errs = append(errs, validator(ctx, op, fldPath.Index(i), val, old)...)
}
return errs
}
// lookup returns a pointer to the first element in the list that matches the
// target, according to the provided comparison function, or else nil.
func lookup[T any](list []T, target T, cmp func(T, T) bool) *T {
for i := range list {
if cmp(list[i], target) {
return &list[i]
}
}
return nil
}
// EachMapVal validates each element of newMap with the specified validation
// function and, if the corresponding key is found in oldMap, the old value.
// The value-type of the slices is assumed to not be nilable.
func EachMapVal[K ~string, V any](ctx context.Context, op operation.Operation, fldPath *field.Path, newMap, oldMap map[K]V,
validator ValidateFunc[*V]) field.ErrorList {
var errs field.ErrorList
for key, val := range newMap {
var old *V
if o, found := oldMap[key]; found {
old = &o
}
errs = append(errs, validator(ctx, op, fldPath.Key(string(key)), &val, old)...)
}
return errs
}
// EachMapValNilable validates each element of newMap with the specified
// validation function and, if the corresponding key is found in oldMap, the
// old value. The value-type of the slices is assumed to be nilable.
func EachMapValNilable[K ~string, V any](ctx context.Context, op operation.Operation, fldPath *field.Path, newMap, oldMap map[K]V,
validator ValidateFunc[V]) field.ErrorList {
var errs field.ErrorList
for key, val := range newMap {
var old V
if o, found := oldMap[key]; found {
old = o
}
errs = append(errs, validator(ctx, op, fldPath.Key(string(key)), val, old)...)
}
return errs
}
// EachMapKey validates each element of newMap with the specified
// validation function. The oldMap argument is not used.
func EachMapKey[K ~string, T any](ctx context.Context, op operation.Operation, fldPath *field.Path, newMap, oldMap map[K]T,
validator ValidateFunc[*K]) field.ErrorList {
var errs field.ErrorList
for key := range newMap {
// Note: the field path is the field, not the key.
errs = append(errs, validator(ctx, op, fldPath, &key, nil)...)
}
return errs
}

View File

@ -0,0 +1,316 @@
/*
Copyright 2025 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package validate
import (
"context"
"fmt"
"reflect"
"slices"
"testing"
"k8s.io/apimachinery/pkg/api/operation"
"k8s.io/apimachinery/pkg/util/validation/field"
"k8s.io/utils/ptr"
)
type TestStruct struct {
I int
D string
}
func TestEachSliceVal(t *testing.T) {
testEachSliceVal(t, "valid", []int{11, 12, 13})
testEachSliceVal(t, "valid", []string{"a", "b", "c"})
testEachSliceVal(t, "valid", []TestStruct{{11, "a"}, {12, "b"}, {13, "c"}})
testEachSliceVal(t, "empty", []int{})
testEachSliceVal(t, "empty", []string{})
testEachSliceVal(t, "empty", []TestStruct{})
testEachSliceVal[int](t, "nil", nil)
testEachSliceVal[string](t, "nil", nil)
testEachSliceVal[TestStruct](t, "nil", nil)
testEachSliceValUpdate(t, "valid", []int{11, 12, 13})
testEachSliceValUpdate(t, "valid", []string{"a", "b", "c"})
testEachSliceValUpdate(t, "valid", []TestStruct{{11, "a"}, {12, "b"}, {13, "c"}})
testEachSliceValUpdate(t, "empty", []int{})
testEachSliceValUpdate(t, "empty", []string{})
testEachSliceValUpdate(t, "empty", []TestStruct{})
testEachSliceValUpdate[int](t, "nil", nil)
testEachSliceValUpdate[string](t, "nil", nil)
testEachSliceValUpdate[TestStruct](t, "nil", nil)
}
func testEachSliceVal[T any](t *testing.T, name string, input []T) {
var zero T
t.Run(fmt.Sprintf("%s(%T)", name, zero), func(t *testing.T) {
calls := 0
vfn := func(ctx context.Context, op operation.Operation, fldPath *field.Path, newVal, oldVal *T) field.ErrorList {
if oldVal != nil {
t.Errorf("expected nil oldVal, got %v", *oldVal)
}
calls++
return nil
}
_ = EachSliceVal(context.Background(), operation.Operation{}, field.NewPath("test"), input, nil, nil, vfn)
if calls != len(input) {
t.Errorf("expected %d calls, got %d", len(input), calls)
}
})
}
func testEachSliceValUpdate[T any](t *testing.T, name string, input []T) {
var zero T
t.Run(fmt.Sprintf("%s(%T)", name, zero), func(t *testing.T) {
calls := 0
vfn := func(ctx context.Context, op operation.Operation, fldPath *field.Path, newVal, oldVal *T) field.ErrorList {
if oldVal == nil {
t.Fatalf("expected non-nil oldVal")
}
if !reflect.DeepEqual(*newVal, *oldVal) {
t.Errorf("expected oldVal == newVal, got %v, %v", *oldVal, *newVal)
}
calls++
return nil
}
old := make([]T, len(input))
copy(old, input)
slices.Reverse(old)
cmp := func(a, b T) bool { return reflect.DeepEqual(a, b) }
_ = EachSliceVal(context.Background(), operation.Operation{}, field.NewPath("test"), input, old, cmp, vfn)
if calls != len(input) {
t.Errorf("expected %d calls, got %d", len(input), calls)
}
})
}
func TestEachSliceValNilablePointer(t *testing.T) {
testEachSliceValNilable(t, "valid", []*int{ptr.To(11), ptr.To(12), ptr.To(13)})
testEachSliceValNilable(t, "valid", []*string{ptr.To("a"), ptr.To("b"), ptr.To("c")})
testEachSliceValNilable(t, "valid", []*TestStruct{{11, "a"}, {12, "b"}, {13, "c"}})
testEachSliceValNilable(t, "empty", []*int{})
testEachSliceValNilable(t, "empty", []*string{})
testEachSliceValNilable(t, "empty", []*TestStruct{})
testEachSliceValNilable[int](t, "nil", nil)
testEachSliceValNilable[string](t, "nil", nil)
testEachSliceValNilable[TestStruct](t, "nil", nil)
testEachSliceValNilableUpdate(t, "valid", []*int{ptr.To(11), ptr.To(12), ptr.To(13)})
testEachSliceValNilableUpdate(t, "valid", []*string{ptr.To("a"), ptr.To("b"), ptr.To("c")})
testEachSliceValNilableUpdate(t, "valid", []*TestStruct{{11, "a"}, {12, "b"}, {13, "c"}})
testEachSliceValNilableUpdate(t, "empty", []*int{})
testEachSliceValNilableUpdate(t, "empty", []*string{})
testEachSliceValNilableUpdate(t, "empty", []*TestStruct{})
testEachSliceValNilableUpdate[int](t, "nil", nil)
testEachSliceValNilableUpdate[string](t, "nil", nil)
testEachSliceValNilableUpdate[TestStruct](t, "nil", nil)
}
func TestEachSliceValNilableSlice(t *testing.T) {
testEachSliceValNilable(t, "valid", [][]int{{11, 12, 13}, {12, 13, 14}, {13, 14, 15}})
testEachSliceValNilable(t, "valid", [][]string{{"a", "b", "c"}, {"b", "c", "d"}, {"c", "d", "e"}})
testEachSliceValNilable(t, "valid", [][]TestStruct{
{{11, "a"}, {12, "b"}, {13, "c"}},
{{12, "a"}, {13, "b"}, {14, "c"}},
{{13, "a"}, {14, "b"}, {15, "c"}}})
testEachSliceValNilable(t, "empty", [][]int{{}, {}, {}})
testEachSliceValNilable(t, "empty", [][]string{{}, {}, {}})
testEachSliceValNilable(t, "empty", [][]TestStruct{{}, {}, {}})
testEachSliceValNilable[int](t, "nil", nil)
testEachSliceValNilable[string](t, "nil", nil)
testEachSliceValNilable[TestStruct](t, "nil", nil)
testEachSliceValNilableUpdate(t, "valid", [][]int{{11, 12, 13}, {12, 13, 14}, {13, 14, 15}})
testEachSliceValNilableUpdate(t, "valid", [][]string{{"a", "b", "c"}, {"b", "c", "d"}, {"c", "d", "e"}})
testEachSliceValNilableUpdate(t, "valid", [][]TestStruct{
{{11, "a"}, {12, "b"}, {13, "c"}},
{{12, "a"}, {13, "b"}, {14, "c"}},
{{13, "a"}, {14, "b"}, {15, "c"}}})
testEachSliceValNilableUpdate(t, "empty", [][]int{{}, {}, {}})
testEachSliceValNilableUpdate(t, "empty", [][]string{{}, {}, {}})
testEachSliceValNilableUpdate(t, "empty", [][]TestStruct{{}, {}, {}})
testEachSliceValNilableUpdate[int](t, "nil", nil)
testEachSliceValNilableUpdate[string](t, "nil", nil)
testEachSliceValNilableUpdate[TestStruct](t, "nil", nil)
}
func testEachSliceValNilable[T any](t *testing.T, name string, input []T) {
var zero T
t.Run(fmt.Sprintf("%s(%T)", name, zero), func(t *testing.T) {
calls := 0
vfn := func(ctx context.Context, op operation.Operation, fldPath *field.Path, newVal, oldVal T) field.ErrorList {
if !reflect.DeepEqual(oldVal, zero) {
t.Errorf("expected nil oldVal, got %v", oldVal)
}
calls++
return nil
}
_ = EachSliceValNilable(context.Background(), operation.Operation{}, field.NewPath("test"), input, nil, nil, vfn)
if calls != len(input) {
t.Errorf("expected %d calls, got %d", len(input), calls)
}
})
}
func testEachSliceValNilableUpdate[T any](t *testing.T, name string, input []T) {
var zero T
t.Run(fmt.Sprintf("%s(%T)", name, zero), func(t *testing.T) {
calls := 0
vfn := func(ctx context.Context, op operation.Operation, fldPath *field.Path, newVal, oldVal T) field.ErrorList {
if reflect.DeepEqual(oldVal, zero) {
t.Fatalf("expected non-nil oldVal")
}
if !reflect.DeepEqual(newVal, oldVal) {
t.Errorf("expected oldVal == newVal, got %v, %v", oldVal, newVal)
}
calls++
return nil
}
old := make([]T, len(input))
copy(old, input)
slices.Reverse(old)
cmp := func(a, b T) bool { return reflect.DeepEqual(a, b) }
_ = EachSliceValNilable(context.Background(), operation.Operation{}, field.NewPath("test"), input, old, cmp, vfn)
if calls != len(input) {
t.Errorf("expected %d calls, got %d", len(input), calls)
}
})
}
func TestEachMapVal(t *testing.T) {
testEachMapVal(t, "valid", map[string]int{"one": 11, "two": 12, "three": 13})
testEachMapVal(t, "valid", map[string]string{"A": "a", "B": "b", "C": "c"})
testEachMapVal(t, "valid", map[string]TestStruct{"one": {11, "a"}, "two": {12, "b"}, "three": {13, "c"}})
testEachMapVal(t, "empty", map[string]int{})
testEachMapVal(t, "empty", map[string]string{})
testEachMapVal(t, "empty", map[string]TestStruct{})
testEachMapVal[int](t, "nil", nil)
testEachMapVal[string](t, "nil", nil)
testEachMapVal[TestStruct](t, "nil", nil)
}
func testEachMapVal[T any](t *testing.T, name string, input map[string]T) {
var zero T
t.Run(fmt.Sprintf("%s(%T)", name, zero), func(t *testing.T) {
calls := 0
vfn := func(ctx context.Context, op operation.Operation, fldPath *field.Path, newVal, oldVal *T) field.ErrorList {
if oldVal != nil {
t.Errorf("expected nil oldVal, got %v", *oldVal)
}
calls++
return nil
}
_ = EachMapVal(context.Background(), operation.Operation{}, field.NewPath("test"), input, nil, vfn)
if calls != len(input) {
t.Errorf("expected %d calls, got %d", len(input), calls)
}
})
}
func TestEachMapValNilablePointer(t *testing.T) {
testEachMapValNilable(t, "valid", map[string]*int{"one": ptr.To(11), "two": ptr.To(12), "three": ptr.To(13)})
testEachMapValNilable(t, "valid", map[string]*string{"A": ptr.To("a"), "B": ptr.To("b"), "C": ptr.To("c")})
testEachMapValNilable(t, "valid", map[string]*TestStruct{"one": {11, "a"}, "two": {12, "b"}, "three": {13, "c"}})
testEachMapValNilable(t, "empty", map[string]*int{})
testEachMapValNilable(t, "empty", map[string]*string{})
testEachMapValNilable(t, "empty", map[string]*TestStruct{})
testEachMapValNilable[int](t, "nil", nil)
testEachMapValNilable[string](t, "nil", nil)
testEachMapValNilable[TestStruct](t, "nil", nil)
}
func TestEachMapValNilableSlice(t *testing.T) {
testEachMapValNilable(t, "valid", map[string][]int{
"one": {11, 12, 13},
"two": {12, 13, 14},
"three": {13, 14, 15}})
testEachMapValNilable(t, "valid", map[string][]string{
"A": {"a", "b", "c"},
"B": {"b", "c", "d"},
"C": {"c", "d", "e"}})
testEachMapValNilable(t, "valid", map[string][]TestStruct{
"one": {{11, "a"}, {12, "b"}, {13, "c"}},
"two": {{12, "a"}, {13, "b"}, {14, "c"}},
"three": {{13, "a"}, {14, "b"}, {15, "c"}}})
testEachMapValNilable(t, "empty", map[string][]int{"a": {}, "b": {}, "c": {}})
testEachMapValNilable(t, "empty", map[string][]string{"a": {}, "b": {}, "c": {}})
testEachMapValNilable(t, "empty", map[string][]TestStruct{"a": {}, "b": {}, "c": {}})
testEachMapValNilable[int](t, "nil", nil)
testEachMapValNilable[string](t, "nil", nil)
testEachMapValNilable[TestStruct](t, "nil", nil)
}
func testEachMapValNilable[T any](t *testing.T, name string, input map[string]T) {
var zero T
t.Run(fmt.Sprintf("%s(%T)", name, zero), func(t *testing.T) {
calls := 0
vfn := func(ctx context.Context, op operation.Operation, fldPath *field.Path, newVal, oldVal T) field.ErrorList {
if !reflect.DeepEqual(oldVal, zero) {
t.Errorf("expected nil oldVal, got %v", oldVal)
}
calls++
return nil
}
_ = EachMapValNilable(context.Background(), operation.Operation{}, field.NewPath("test"), input, nil, vfn)
if calls != len(input) {
t.Errorf("expected %d calls, got %d", len(input), calls)
}
})
}
type StringType string
func TestEachMapKey(t *testing.T) {
testEachMapKey(t, "valid", map[string]int{"one": 11, "two": 12, "three": 13})
testEachMapKey(t, "valid", map[StringType]string{"A": "a", "B": "b", "C": "c"})
}
func testEachMapKey[K ~string, V any](t *testing.T, name string, input map[K]V) {
var zero K
t.Run(fmt.Sprintf("%s(%T)", name, zero), func(t *testing.T) {
calls := 0
vfn := func(ctx context.Context, op operation.Operation, fldPath *field.Path, newVal, oldVal *K) field.ErrorList {
if oldVal != nil {
t.Errorf("expected nil oldVal, got %v", *oldVal)
}
calls++
return nil
}
_ = EachMapKey(context.Background(), operation.Operation{}, field.NewPath("test"), input, nil, vfn)
if calls != len(input) {
t.Errorf("expected %d calls, got %d", len(input), calls)
}
})
}

View File

@ -0,0 +1,41 @@
/*
Copyright 2025 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package validate
import (
"context"
"k8s.io/apimachinery/pkg/api/operation"
"k8s.io/apimachinery/pkg/util/validation/field"
)
// GetFieldFunc is a function that extracts a field from a type and returns a
// nilable value.
type GetFieldFunc[Tstruct any, Tfield any] func(*Tstruct) Tfield
// Subfield validates a subfield of a struct against a validator function.
func Subfield[Tstruct any, Tfield any](ctx context.Context, op operation.Operation, fldPath *field.Path, newStruct, oldStruct *Tstruct,
fldName string, getField GetFieldFunc[Tstruct, Tfield], validator ValidateFunc[Tfield]) field.ErrorList {
var errs field.ErrorList
newVal := getField(newStruct)
var oldVal Tfield
if oldStruct != nil {
oldVal = getField(oldStruct)
}
errs = append(errs, validator(ctx, op, fldPath.Child(fldName), newVal, oldVal)...)
return errs
}

View File

@ -0,0 +1,430 @@
/*
Copyright 2024 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package validators
import (
"fmt"
"strings"
"k8s.io/apimachinery/pkg/util/sets"
"k8s.io/apimachinery/pkg/util/validation/field"
"k8s.io/gengo/v2/types"
)
const (
listTypeTagName = "k8s:listType"
ListMapKeyTagName = "k8s:listMapKey"
eachValTagName = "k8s:eachVal"
eachKeyTagName = "k8s:eachKey"
)
// We keep the eachVal and eachKey validators around because the main
// code-generation logic calls them directly. We could move them into the main
// pkg, but it's easier and cleaner to leave them here.
var globalEachVal *eachValTagValidator
var globalEachKey *eachKeyTagValidator
func init() {
// Lists with list-map semantics are comprised of multiple tags, which need
// to share information between them.
shared := map[string]*listMap{} // keyed by the fieldpath
RegisterTagValidator(listTypeTagValidator{shared})
RegisterTagValidator(listMapKeyTagValidator{shared})
globalEachVal = &eachValTagValidator{shared, nil}
RegisterTagValidator(globalEachVal)
globalEachKey = &eachKeyTagValidator{nil}
RegisterTagValidator(globalEachKey)
}
// This applies to all tags in this file.
var listTagsValidScopes = sets.New(ScopeAny)
// listMap collects information about a single list with map semantics.
type listMap struct {
declaredAsMap bool
keyFields []string
}
type listTypeTagValidator struct {
byFieldPath map[string]*listMap
}
func (listTypeTagValidator) Init(Config) {}
func (listTypeTagValidator) TagName() string {
return listTypeTagName
}
func (listTypeTagValidator) ValidScopes() sets.Set[Scope] {
return listTagsValidScopes
}
func (lttv listTypeTagValidator) GetValidations(context Context, _ []string, payload string) (Validations, error) {
t := context.Type
if t.Kind == types.Alias {
t = t.Underlying
}
if t.Kind != types.Slice && t.Kind != types.Array {
return Validations{}, fmt.Errorf("can only be used on list types")
}
switch payload {
case "atomic", "set":
// Allowed but no special handling.
case "map":
if realType(t.Elem).Kind != types.Struct {
return Validations{}, fmt.Errorf("only lists of structs can be list-maps")
}
// Save the fact that this list is a map.
if lttv.byFieldPath[context.Path.String()] == nil {
lttv.byFieldPath[context.Path.String()] = &listMap{}
}
lm := lttv.byFieldPath[context.Path.String()]
lm.declaredAsMap = true
default:
return Validations{}, fmt.Errorf("unknown list type %q", payload)
}
// This tag doesn't generate any validations. It just accumulates
// information for other tags to use.
return Validations{}, nil
}
func realType(t *types.Type) *types.Type {
for {
if t.Kind == types.Alias {
t = t.Underlying
} else if t.Kind == types.Pointer {
t = t.Elem
} else {
break
}
}
return t
}
func (lttv listTypeTagValidator) Docs() TagDoc {
doc := TagDoc{
Tag: lttv.TagName(),
Scopes: lttv.ValidScopes().UnsortedList(),
Description: "Declares a list field's semantic type.",
Payloads: []TagPayloadDoc{{
Description: "<type>",
Docs: "map | atomic",
}},
}
return doc
}
type listMapKeyTagValidator struct {
byFieldPath map[string]*listMap
}
func (listMapKeyTagValidator) Init(Config) {}
func (listMapKeyTagValidator) TagName() string {
return ListMapKeyTagName
}
func (listMapKeyTagValidator) ValidScopes() sets.Set[Scope] {
return listTagsValidScopes
}
func (lmktv listMapKeyTagValidator) GetValidations(context Context, _ []string, payload string) (Validations, error) {
t := context.Type
if t.Kind == types.Alias {
t = t.Underlying
}
if t.Kind != types.Slice && t.Kind != types.Array {
return Validations{}, fmt.Errorf("can only be used on list types")
}
if realType(t.Elem).Kind != types.Struct {
return Validations{}, fmt.Errorf("only lists of structs can be list-maps")
}
var fieldName string
if memb := getMemberByJSON(realType(t.Elem), payload); memb == nil {
return Validations{}, fmt.Errorf("no field for JSON name %q", payload)
} else if k := realType(memb.Type).Kind; k != types.Builtin {
return Validations{}, fmt.Errorf("only primitive types can be list-map keys, not %s", k)
} else {
fieldName = memb.Name
}
if lmktv.byFieldPath[context.Path.String()] == nil {
lmktv.byFieldPath[context.Path.String()] = &listMap{}
}
lm := lmktv.byFieldPath[context.Path.String()]
lm.keyFields = append(lm.keyFields, fieldName)
// This tag doesn't generate any validations. It just accumulates
// information for other tags to use.
return Validations{}, nil
}
func (lmktv listMapKeyTagValidator) Docs() TagDoc {
doc := TagDoc{
Tag: lmktv.TagName(),
Scopes: lmktv.ValidScopes().UnsortedList(),
Description: "Declares a named sub-field of a list's value-type to be part of the list-map key.",
Payloads: []TagPayloadDoc{{
Description: "<field-json-name>",
Docs: "The name of the field.",
}},
}
return doc
}
type eachValTagValidator struct {
byFieldPath map[string]*listMap
validator Validator
}
func (evtv *eachValTagValidator) Init(cfg Config) {
evtv.validator = cfg.Validator
}
func (eachValTagValidator) TagName() string {
return eachValTagName
}
func (eachValTagValidator) ValidScopes() sets.Set[Scope] {
return listTagsValidScopes
}
// LateTagValidator indicatesa that validator has to run after the listType and
// listMapKey tags.
func (eachValTagValidator) LateTagValidator() {}
var (
validateEachSliceVal = types.Name{Package: libValidationPkg, Name: "EachSliceVal"}
validateEachSliceValNilable = types.Name{Package: libValidationPkg, Name: "EachSliceValNilable"}
validateEachMapVal = types.Name{Package: libValidationPkg, Name: "EachMapVal"}
validateEachMapValNilable = types.Name{Package: libValidationPkg, Name: "EachMapValNilable"}
)
func (evtv eachValTagValidator) GetValidations(context Context, _ []string, payload string) (Validations, error) {
t := context.Type
if t.Kind == types.Alias {
t = t.Underlying
}
switch t.Kind {
case types.Slice, types.Array, types.Map:
default:
return Validations{}, fmt.Errorf("can only be used on list or map types")
}
fakeComments := []string{payload}
elemContext := Context{
Type: t.Elem,
Parent: t,
Path: context.Path.Key("*"),
}
switch t.Kind {
case types.Slice, types.Array:
elemContext.Scope = ScopeListVal
case types.Map:
elemContext.Scope = ScopeMapVal
}
if validations, err := evtv.validator.ExtractValidations(elemContext, fakeComments); err != nil {
return Validations{}, err
} else {
if len(validations.Variables) > 0 {
return Validations{}, fmt.Errorf("variable generation is not supported")
}
return evtv.getValidations(context.Path, t, validations)
}
}
func (evtv eachValTagValidator) getValidations(fldPath *field.Path, t *types.Type, validations Validations) (Validations, error) {
switch t.Kind {
case types.Slice, types.Array:
return evtv.getListValidations(fldPath, t, validations)
case types.Map:
return evtv.getMapValidations(t, validations)
}
return Validations{}, fmt.Errorf("non-iterable type: %v", t)
}
// ForEachVal returns a validation that applies a function to each element of
// a list or map.
func ForEachVal(fldPath *field.Path, t *types.Type, fn FunctionGen) (Validations, error) {
return globalEachVal.getValidations(fldPath, t, Validations{Functions: []FunctionGen{fn}})
}
func (evtv eachValTagValidator) getListValidations(fldPath *field.Path, t *types.Type, validations Validations) (Validations, error) {
result := Validations{}
result.OpaqueValType = validations.OpaqueType
var listMap *listMap
if lm, found := evtv.byFieldPath[fldPath.String()]; found {
if !lm.declaredAsMap {
return Validations{}, fmt.Errorf("found listMapKey without listType=map")
}
if len(lm.keyFields) == 0 {
return Validations{}, fmt.Errorf("found listType=map without listMapKey")
}
listMap = lm
}
for _, vfn := range validations.Functions {
// Which function we call depends on whether the value-type is
// already nilable or not.
var validateEach types.Name
if isNilableType(t.Elem) {
validateEach = validateEachSliceValNilable
} else {
validateEach = validateEachSliceVal
}
var cmpArg any = Literal("nil")
if listMap != nil {
cmpFn := FunctionLiteral{
Parameters: []ParamResult{{"a", t.Elem}, {"b", t.Elem}},
Results: []ParamResult{{"", types.Bool}},
}
buf := strings.Builder{}
buf.WriteString("return ")
for i, fld := range listMap.keyFields {
if i > 0 {
buf.WriteString(" && ")
}
buf.WriteString(fmt.Sprintf("a.%s == b.%s", fld, fld))
}
cmpFn.Body = buf.String()
cmpArg = cmpFn
}
f := Function(eachValTagName, vfn.Flags(), validateEach, cmpArg, WrapperFunction{vfn, t.Elem})
result.Functions = append(result.Functions, f)
}
return result, nil
}
func (evtv eachValTagValidator) getMapValidations(t *types.Type, validations Validations) (Validations, error) {
result := Validations{}
result.OpaqueValType = validations.OpaqueType
for _, vfn := range validations.Functions {
// Which function we call depends on whether the value-type is
// already nilable or not.
var validateEach types.Name
if isNilableType(t.Elem) {
validateEach = validateEachMapValNilable
} else {
validateEach = validateEachMapVal
}
f := Function(eachValTagName, vfn.Flags(), validateEach, WrapperFunction{vfn, t.Elem})
result.Functions = append(result.Functions, f)
}
return result, nil
}
func (evtv eachValTagValidator) Docs() TagDoc {
doc := TagDoc{
Tag: evtv.TagName(),
Scopes: evtv.ValidScopes().UnsortedList(),
Description: "Declares a validation for each value in a map or list.",
Payloads: []TagPayloadDoc{{
Description: "<validation-tag>",
Docs: "The tag to evaluate for each value.",
}},
}
return doc
}
type eachKeyTagValidator struct {
validator Validator
}
func (ektv *eachKeyTagValidator) Init(cfg Config) {
ektv.validator = cfg.Validator
}
func (eachKeyTagValidator) TagName() string {
return eachKeyTagName
}
func (eachKeyTagValidator) ValidScopes() sets.Set[Scope] {
return listTagsValidScopes
}
var (
validateEachMapKey = types.Name{Package: libValidationPkg, Name: "EachMapKey"}
)
func (ektv eachKeyTagValidator) GetValidations(context Context, _ []string, payload string) (Validations, error) {
t := context.Type
if t.Kind == types.Alias {
t = t.Underlying
}
if t.Kind != types.Map {
return Validations{}, fmt.Errorf("can only be used on map types")
}
fakeComments := []string{payload}
elemContext := Context{
Scope: ScopeMapKey,
Type: t.Elem,
Parent: t,
Path: context.Path.Child("(keys)"),
}
if validations, err := ektv.validator.ExtractValidations(elemContext, fakeComments); err != nil {
return Validations{}, err
} else {
if len(validations.Variables) > 0 {
return Validations{}, fmt.Errorf("variable generation is not supported")
}
return ektv.getValidations(t, validations)
}
}
func (ektv eachKeyTagValidator) getValidations(t *types.Type, validations Validations) (Validations, error) {
result := Validations{}
result.OpaqueKeyType = validations.OpaqueType
for _, vfn := range validations.Functions {
f := Function(eachKeyTagName, vfn.Flags(), validateEachMapKey, WrapperFunction{vfn, t.Key})
result.Functions = append(result.Functions, f)
}
return result, nil
}
// ForEachKey returns a validation that applies a function to each key of
// a map.
func ForEachKey(_ *field.Path, t *types.Type, fn FunctionGen) (Validations, error) {
return globalEachKey.getValidations(t, Validations{Functions: []FunctionGen{fn}})
}
func (ektv eachKeyTagValidator) Docs() TagDoc {
doc := TagDoc{
Tag: ektv.TagName(),
Scopes: ektv.ValidScopes().UnsortedList(),
Description: "Declares a validation for each value in a map or list.",
Payloads: []TagPayloadDoc{{
Description: "<validation-tag>",
Docs: "The tag to evaluate for each value.",
}},
}
return doc
}

View File

@ -0,0 +1,126 @@
/*
Copyright 2025 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package validators
import (
"fmt"
"k8s.io/apimachinery/pkg/util/sets"
"k8s.io/gengo/v2/types"
)
const (
subfieldTagName = "k8s:subfield"
)
func init() {
RegisterTagValidator(&subfieldTagValidator{})
}
type subfieldTagValidator struct {
validator Validator
}
func (stv *subfieldTagValidator) Init(cfg Config) {
stv.validator = cfg.Validator
}
func (subfieldTagValidator) TagName() string {
return subfieldTagName
}
var subfieldTagValidScopes = sets.New(ScopeAny)
func (subfieldTagValidator) ValidScopes() sets.Set[Scope] {
return subfieldTagValidScopes
}
var (
validateSubfield = types.Name{Package: libValidationPkg, Name: "Subfield"}
)
func (stv subfieldTagValidator) GetValidations(context Context, args []string, payload string) (Validations, error) {
t := realType(context.Type)
if t.Kind != types.Struct {
return Validations{}, fmt.Errorf("can only be used on struct types")
}
if len(args) != 1 {
return Validations{}, fmt.Errorf("requires exactly one arg")
}
subname := args[0]
submemb := getMemberByJSON(t, subname)
if submemb == nil {
return Validations{}, fmt.Errorf("no field for json name %q", subname)
}
result := Validations{}
fakeComments := []string{payload}
subContext := Context{
Scope: ScopeField,
Type: submemb.Type,
Parent: t,
Path: context.Path.Child(subname),
}
if validations, err := stv.validator.ExtractValidations(subContext, fakeComments); err != nil {
return Validations{}, err
} else {
if len(validations.Variables) > 0 {
return Validations{}, fmt.Errorf("variable generation is not supported")
}
for _, vfn := range validations.Functions {
nilableStructType := context.Type
if !isNilableType(nilableStructType) {
nilableStructType = types.PointerTo(nilableStructType)
}
nilableFieldType := submemb.Type
fieldExprPrefix := ""
if !isNilableType(nilableFieldType) {
nilableFieldType = types.PointerTo(nilableFieldType)
fieldExprPrefix = "&"
}
getFn := FunctionLiteral{
Parameters: []ParamResult{{"o", nilableStructType}},
Results: []ParamResult{{"", nilableFieldType}},
}
getFn.Body = fmt.Sprintf("return %so.%s", fieldExprPrefix, submemb.Name)
f := Function(subfieldTagName, vfn.Flags(), validateSubfield, subname, getFn, WrapperFunction{vfn, submemb.Type})
result.Functions = append(result.Functions, f)
result.Variables = append(result.Variables, validations.Variables...)
}
}
return result, nil
}
func (stv subfieldTagValidator) Docs() TagDoc {
doc := TagDoc{
Tag: stv.TagName(),
Scopes: stv.ValidScopes().UnsortedList(),
Description: "Declares a validation for a subfield of a struct.",
Args: []TagArgDoc{{
Description: "<field-json-name>",
}},
Docs: "The named subfield must be a direct field of the struct, or of an embedded struct.",
Payloads: []TagPayloadDoc{{
Description: "<validation-tag>",
Docs: "The tag to evaluate for the subfield.",
}},
}
return doc
}