diff --git a/staging/publishing/import-restrictions.yaml b/staging/publishing/import-restrictions.yaml index 989721473e7..50a077fa2b0 100644 --- a/staging/publishing/import-restrictions.yaml +++ b/staging/publishing/import-restrictions.yaml @@ -61,6 +61,7 @@ - k8s.io/code-generator - k8s.io/kube-openapi - k8s.io/klog + - k8s.io/utils/ptr - baseImportPath: "./staging/src/k8s.io/component-base" allowedImports: diff --git a/staging/src/k8s.io/apimachinery/pkg/api/operation/operation.go b/staging/src/k8s.io/apimachinery/pkg/api/operation/operation.go new file mode 100644 index 00000000000..9f5ae7a9d40 --- /dev/null +++ b/staging/src/k8s.io/apimachinery/pkg/api/operation/operation.go @@ -0,0 +1,56 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package operation + +import "k8s.io/apimachinery/pkg/util/sets" + +// Operation provides contextual information about a validation request and the API +// operation being validated. +// This type is intended for use with generate validation code and may be enhanced +// in the future to include other information needed to validate requests. +type Operation struct { + // Type is the category of operation being validated. This does not + // differentiate between HTTP verbs like PUT and PATCH, but rather merges + // those into a single "Update" category. + Type Type + + // Options declare the options enabled for validation. + // + // Options should be set according to a resource validation strategy before validation + // is performed, and must be treated as read-only during validation. + // + // Options are identified by string names. Option string names may match the name of a feature + // gate, in which case the presence of the name in the set indicates that the feature is + // considered enabled for the resource being validated. Note that a resource may have a + // feature enabled even when the feature gate is disabled. This can happen when feature is + // already in-use by a resource, often because the feature gate was enabled when the + // resource first began using the feature. + // + // Unset options are disabled/false. + Options sets.Set[string] +} + +// Code is the request operation to be validated. +type Type uint32 + +const ( + // Create indicates the request being validated is for a resource create operation. + Create Type = iota + + // Update indicates the request being validated is for a resource update operation. + Update +) diff --git a/staging/src/k8s.io/apimachinery/pkg/api/safe/safe.go b/staging/src/k8s.io/apimachinery/pkg/api/safe/safe.go new file mode 100644 index 00000000000..aad8925dbe6 --- /dev/null +++ b/staging/src/k8s.io/apimachinery/pkg/api/safe/safe.go @@ -0,0 +1,37 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package safe + +// Field takes a pointer to any value (which may or may not be nil) and +// a function that traverses to a target type R (a typical use case is to dereference a field), +// and returns the result of the traversal, or the zero value of the target type. +// This is roughly equivalent to "value != nil ? fn(value) : zero-value" in languages that support the ternary operator. +func Field[V any, R any](value *V, fn func(*V) R) R { + if value == nil { + var zero R + return zero + } + o := fn(value) + return o +} + +// Cast takes any value, attempts to cast it to T, and returns the T value if +// the cast is successful, or else the zero value of T. +func Cast[T any](value any) T { + result, _ := value.(T) + return result +} diff --git a/staging/src/k8s.io/apimachinery/pkg/api/validate/common.go b/staging/src/k8s.io/apimachinery/pkg/api/validate/common.go new file mode 100644 index 00000000000..14a6f0da7f0 --- /dev/null +++ b/staging/src/k8s.io/apimachinery/pkg/api/validate/common.go @@ -0,0 +1,28 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package validate + +import ( + "context" + + "k8s.io/apimachinery/pkg/api/operation" + "k8s.io/apimachinery/pkg/util/validation/field" +) + +// ValidateFunc is a function that validates a value, possibly considering the +// old value (if any). +type ValidateFunc[T any] func(ctx context.Context, op operation.Operation, fldPath *field.Path, newValue, oldValue T) field.ErrorList diff --git a/staging/src/k8s.io/code-generator/cmd/validation-gen/lint.go b/staging/src/k8s.io/code-generator/cmd/validation-gen/lint.go new file mode 100644 index 00000000000..625edd121f4 --- /dev/null +++ b/staging/src/k8s.io/code-generator/cmd/validation-gen/lint.go @@ -0,0 +1,160 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package main + +import ( + "fmt" + "strings" + + "k8s.io/gengo/v2/types" + "k8s.io/klog/v2" +) + +// linter is a struct that holds the state of the linting process. +// It contains a map of types that have been linted, a list of linting rules, +// and a list of errors that occurred during the linting process. +type linter struct { + linted map[*types.Type]bool + rules []lintRule + // lintErrors is a list of errors that occurred during the linting process. + // lintErrors would be in the format: + // field : + // type : + lintErrors []error +} + +var defaultRules = []lintRule{ + ruleOptionalAndRequired, + ruleRequiredAndDefault, +} + +func (l *linter) AddError(field, msg string) { + l.lintErrors = append(l.lintErrors, fmt.Errorf("%s: %s", field, msg)) +} + +func newLinter(rules ...lintRule) *linter { + if len(rules) == 0 { + rules = defaultRules + } + return &linter{ + linted: make(map[*types.Type]bool), + rules: rules, + lintErrors: []error{}, + } +} + +func (l *linter) lintType(t *types.Type) error { + if _, ok := l.linted[t]; ok { + return nil + } + l.linted[t] = true + + if t.CommentLines != nil { + klog.V(5).Infof("linting type %s", t.Name.String()) + lintErrs, err := l.lintComments(t.CommentLines) + if err != nil { + return err + } + for _, lintErr := range lintErrs { + l.AddError("type "+t.Name.String(), lintErr) + } + } + switch t.Kind { + case types.Alias: + // Recursively lint the underlying type of the alias. + if err := l.lintType(t.Underlying); err != nil { + return err + } + case types.Struct: + // Recursively lint each member of the struct. + for _, member := range t.Members { + klog.V(5).Infof("linting comments for field %s of type %s", member.String(), t.Name.String()) + lintErrs, err := l.lintComments(member.CommentLines) + if err != nil { + return err + } + for _, lintErr := range lintErrs { + l.AddError("type "+t.Name.String(), lintErr) + } + if err := l.lintType(member.Type); err != nil { + return err + } + } + case types.Slice, types.Array, types.Pointer: + // Recursively lint the element type of the slice or array. + if err := l.lintType(t.Elem); err != nil { + return err + } + case types.Map: + // Recursively lint the key and element types of the map. + if err := l.lintType(t.Key); err != nil { + return err + } + if err := l.lintType(t.Elem); err != nil { + return err + } + } + return nil +} + +// lintRule is a function that validates a slice of comments. +// It returns a string as an error message if the comments are invalid, +// and an error there is an error happened during the linting process. +type lintRule func(comments []string) (string, error) + +// lintComments runs all registered rules on a slice of comments. +func (l *linter) lintComments(comments []string) ([]string, error) { + var lintErrs []string + for _, rule := range l.rules { + if msg, err := rule(comments); err != nil { + return nil, err + } else if msg != "" { + lintErrs = append(lintErrs, msg) + } + } + + return lintErrs, nil +} + +// conflictingTagsRule checks for conflicting tags in a slice of comments. +func conflictingTagsRule(comments []string, tags ...string) (string, error) { + if len(tags) < 2 { + return "", fmt.Errorf("at least two tags must be provided") + } + tagCount := make(map[string]bool) + for _, comment := range comments { + for _, tag := range tags { + if strings.HasPrefix(comment, tag) { + tagCount[tag] = true + } + } + } + if len(tagCount) > 1 { + return fmt.Sprintf("conflicting tags: {%s}", strings.Join(tags, ", ")), nil + } + return "", nil +} + +// ruleOptionalAndRequired checks for conflicting tags +k8s:optional and +k8s:required in a slice of comments. +func ruleOptionalAndRequired(comments []string) (string, error) { + return conflictingTagsRule(comments, "+k8s:optional", "+k8s:required") +} + +// ruleRequiredAndDefault checks for conflicting tags +k8s:required and +default in a slice of comments. +func ruleRequiredAndDefault(comments []string) (string, error) { + return conflictingTagsRule(comments, "+k8s:required", "+default") +} diff --git a/staging/src/k8s.io/code-generator/cmd/validation-gen/lint_test.go b/staging/src/k8s.io/code-generator/cmd/validation-gen/lint_test.go new file mode 100644 index 00000000000..ea5954389b6 --- /dev/null +++ b/staging/src/k8s.io/code-generator/cmd/validation-gen/lint_test.go @@ -0,0 +1,412 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package main + +import ( + "errors" + "testing" + + "k8s.io/gengo/v2/types" +) + +func ruleAlwaysPass(comments []string) (string, error) { + return "", nil +} + +func ruleAlwaysFail(comments []string) (string, error) { + return "lintfail", nil +} + +func ruleAlwaysErr(comments []string) (string, error) { + return "", errors.New("linterr") +} + +func mkCountRule(counter *int, realRule lintRule) lintRule { + return func(comments []string) (string, error) { + (*counter)++ + return realRule(comments) + } +} + +func TestLintCommentsRuleInvocation(t *testing.T) { + tests := []struct { + name string + rules []lintRule + commentLineGroups [][]string + wantErr bool + wantCount int + }{ + { + name: "0 rules, 0 comments", + rules: []lintRule{}, + commentLineGroups: [][]string{}, + wantErr: false, + wantCount: 0, + }, + { + name: "1 rule, 1 comment", + rules: []lintRule{ruleAlwaysPass}, + commentLineGroups: [][]string{{"comment"}}, + wantErr: false, + wantCount: 1, + }, + { + name: "3 rules, 3 comments", + rules: []lintRule{ruleAlwaysPass, ruleAlwaysFail, ruleAlwaysErr}, + commentLineGroups: [][]string{{"comment1"}, {"comment2"}, {"comment3"}}, + wantErr: true, + wantCount: 9, + }, + { + name: "1 rule, 1 comment, rule fails", + rules: []lintRule{ruleAlwaysFail}, + commentLineGroups: [][]string{{"comment"}}, + wantErr: false, + wantCount: 1, + }, + { + name: "1 rule, 1 comment, rule errors", + rules: []lintRule{ruleAlwaysErr}, + commentLineGroups: [][]string{{"comment"}}, + wantErr: true, + wantCount: 1, + }, + { + name: "3 rules, 1 comment, rule errors in the middle", + rules: []lintRule{ruleAlwaysPass, ruleAlwaysErr, ruleAlwaysFail}, + commentLineGroups: [][]string{{"comment"}}, + wantErr: true, + wantCount: 2, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + counter := 0 + rules := make([]lintRule, len(tt.rules)) + for i, rule := range tt.rules { + rules[i] = mkCountRule(&counter, rule) + } + l := newLinter(rules...) + for _, commentLines := range tt.commentLineGroups { + _, err := l.lintComments(commentLines) + gotErr := err != nil + if gotErr != tt.wantErr { + t.Errorf("lintComments() error = %v, wantErr %v", err, tt.wantErr) + } + } + if counter != tt.wantCount { + t.Errorf("expected %d rule invocations, got %d", tt.wantCount, counter) + } + }) + } +} + +func TestRuleOptionalAndRequired(t *testing.T) { + tests := []struct { + name string + comments []string + wantMsg string + wantErr bool + }{ + { + name: "no comments", + comments: []string{}, + wantMsg: "", + }, + { + name: "only optional", + comments: []string{"+k8s:optional"}, + wantMsg: "", + }, + { + name: "only required", + comments: []string{"+k8s:required"}, + wantMsg: "", + }, + { + name: "optional and required", + comments: []string{"+k8s:optional", "+k8s:required"}, + wantMsg: "conflicting tags: {+k8s:optional, +k8s:required}", + }, + { + name: "optional, empty, required", + comments: []string{"+k8s:optional", "", "+k8s:required"}, + wantMsg: "conflicting tags: {+k8s:optional, +k8s:required}", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + msg, _ := ruleOptionalAndRequired(tt.comments) + if msg != tt.wantMsg { + t.Errorf("ruleOptionalAndRequired() msg = %v, wantMsg %v", msg, tt.wantMsg) + } + }) + } +} + +func TestRuleRequiredAndDefault(t *testing.T) { + tests := []struct { + name string + comments []string + wantMsg string + }{ + { + name: "no comments", + comments: []string{}, + wantMsg: "", + }, + { + name: "only required", + comments: []string{"+k8s:required"}, + wantMsg: "", + }, + { + name: "only default", + comments: []string{"+default=somevalue"}, + wantMsg: "", + }, + { + name: "required and default", + comments: []string{"+k8s:required", "+default=somevalue"}, + wantMsg: "conflicting tags: {+k8s:required, +default}", + }, + { + name: "required, empty, default", + comments: []string{"+k8s:required", "", "+default=somevalue"}, + wantMsg: "conflicting tags: {+k8s:required, +default}", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + msg, _ := ruleRequiredAndDefault(tt.comments) + if msg != tt.wantMsg { + t.Errorf("ruleRequiredAndDefault() msg = %v, wantMsg %v", msg, tt.wantMsg) + } + }) + } +} + +func TestConflictingTagsRule(t *testing.T) { + tests := []struct { + name string + comments []string + tags []string + wantMsg string + wantErr bool + }{ + { + name: "no comments", + comments: []string{}, + tags: []string{"+tag1", "+tag2"}, + wantMsg: "", + }, + { + name: "only tag1", + comments: []string{"+tag1"}, + tags: []string{"+tag1", "+tag2"}, + wantMsg: "", + }, + { + name: "tag1, empty, tag2", + comments: []string{"+tag1", "", "+tag2"}, + tags: []string{"+tag1", "+tag2"}, + wantMsg: "conflicting tags: {+tag1, +tag2}", + }, + { + name: "3 tags", + comments: []string{"tag1", "+tag2", "+tag3=value"}, + tags: []string{"+tag1", "+tag2", "+tag3"}, + wantMsg: "conflicting tags: {+tag1, +tag2, +tag3}", + }, + { + name: "less than 2 tags", + comments: []string{"+tag1"}, + tags: []string{"+tag1"}, + wantMsg: "", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + msg, err := conflictingTagsRule(tt.comments, tt.tags...) + if (err != nil) != tt.wantErr { + t.Errorf("conflictingTagsRule() error = %v, wantErr %v", err, tt.wantErr) + return + } + if msg != tt.wantMsg { + t.Errorf("conflictingTagsRule() msg = %v, wantMsg %v", msg, tt.wantMsg) + } + }) + } +} + +func TestLintType(t *testing.T) { + tests := []struct { + name string + typeToLint *types.Type + wantCount int + expectError bool + }{ + { + name: "No comments", + typeToLint: &types.Type{ + Name: types.Name{Package: "testpkg", Name: "TestType"}, + CommentLines: nil, + }, + wantCount: 0, + expectError: false, + }, + { + name: "Valid comments", + typeToLint: &types.Type{ + Name: types.Name{Package: "testpkg", Name: "TestType"}, + CommentLines: []string{"+k8s:optional"}, + }, + wantCount: 1, + expectError: false, + }, + { + name: "Pointer type", + typeToLint: &types.Type{ + Name: types.Name{Package: "testpkg", Name: "TestPointer"}, + Kind: types.Pointer, + Elem: &types.Type{Name: types.Name{Package: "testpkg", Name: "ElemType"}, CommentLines: []string{"+k8s:optional"}}, + CommentLines: []string{"+k8s:optional"}, + }, + wantCount: 2, + expectError: false, + }, + { + name: "Slice of pointers", + typeToLint: &types.Type{ + Name: types.Name{Package: "testpkg", Name: "TestSlice"}, + Kind: types.Slice, + Elem: &types.Type{ + Name: types.Name{Package: "testpkg", Name: "PointerElem"}, + Kind: types.Pointer, + Elem: &types.Type{Name: types.Name{Package: "testpkg", Name: "ElemType"}, CommentLines: []string{"+k8s:optional"}}, + CommentLines: []string{"+k8s:optional"}, + }, + CommentLines: []string{"+k8s:optional"}, + }, + wantCount: 3, + expectError: false, + }, + { + name: "Map to pointers", + typeToLint: &types.Type{ + Name: types.Name{Package: "testpkg", Name: "TestMap"}, + Kind: types.Map, + Key: &types.Type{Name: types.Name{Package: "testpkg", Name: "KeyType"}, CommentLines: []string{"+k8s:required"}}, + Elem: &types.Type{ + Name: types.Name{Package: "testpkg", Name: "PointerElem"}, + Kind: types.Pointer, + Elem: &types.Type{Name: types.Name{Package: "testpkg", Name: "ElemType"}, CommentLines: []string{"+k8s:optional"}}, + CommentLines: []string{"+k8s:optional"}, + }, + CommentLines: []string{"+k8s:optional"}, + }, + wantCount: 4, + expectError: false, + }, + { + name: "Alias to pointers", + typeToLint: &types.Type{ + Name: types.Name{Package: "testpkg", Name: "TestAlias"}, + Kind: types.Alias, + Underlying: &types.Type{ + Name: types.Name{Package: "testpkg", Name: "PointerElem"}, + Kind: types.Pointer, + Elem: &types.Type{Name: types.Name{Package: "testpkg", Name: "ElemType"}, CommentLines: []string{"+k8s:optional"}}, + CommentLines: []string{"+k8s:optional"}, + }, + CommentLines: []string{"+k8s:optional"}, + }, + wantCount: 3, + expectError: false, + }, + { + name: "Struct with members", + typeToLint: &types.Type{ + Name: types.Name{Package: "testpkg", Name: "TestStruct"}, + Kind: types.Struct, + Members: []types.Member{ + { + Name: "Field1", + Type: &types.Type{Name: types.Name{Package: "testpkg", Name: "FieldType"}}, + CommentLines: []string{"+k8s:optional"}, + }, + { + Name: "Field2", + Type: &types.Type{Name: types.Name{Package: "testpkg", Name: "FieldType"}}, + CommentLines: []string{"+k8s:required"}, + }, + }, + }, + wantCount: 2, + expectError: false, + }, + { + name: "Nested types", + typeToLint: &types.Type{ + Name: types.Name{Package: "testpkg", Name: "TestStruct"}, + Kind: types.Struct, + Members: []types.Member{ + { + Name: "Field1", + Type: &types.Type{ + Name: types.Name{Package: "testpkg", Name: "NestedStruct"}, + Kind: types.Struct, + CommentLines: []string{"+k8s:optional"}, + Members: []types.Member{ + { + Name: "NestedField1", + Type: &types.Type{Name: types.Name{Package: "testpkg", Name: "NestedFieldType"}}, + CommentLines: []string{"+k8s:required"}, + }, + }, + }, + }, + }, + }, + wantCount: 3, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + counter := 0 + rules := []lintRule{mkCountRule(&counter, ruleAlwaysPass)} + l := newLinter(rules...) + if err := l.lintType(tt.typeToLint); err != nil { + t.Fatal(err) + } + gotErr := len(l.lintErrors) > 0 + if gotErr != tt.expectError { + t.Errorf("LintType() errors = %v, expectError %v", l.lintErrors, tt.expectError) + } + if counter != tt.wantCount { + t.Errorf("expected %d rule invocations, got %d", tt.wantCount, counter) + } + }) + } +} diff --git a/staging/src/k8s.io/code-generator/cmd/validation-gen/main.go b/staging/src/k8s.io/code-generator/cmd/validation-gen/main.go new file mode 100644 index 00000000000..a35dc37bb3e --- /dev/null +++ b/staging/src/k8s.io/code-generator/cmd/validation-gen/main.go @@ -0,0 +1,159 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// validation-gen is a tool for auto-generating Validation functions. +package main + +import ( + "bytes" + "cmp" + "encoding/json" + "flag" + "fmt" + "os" + "slices" + + "github.com/spf13/pflag" + + "k8s.io/code-generator/cmd/validation-gen/validators" + "k8s.io/gengo/v2" + "k8s.io/gengo/v2/generator" + "k8s.io/gengo/v2/namer" + "k8s.io/gengo/v2/types" + "k8s.io/klog/v2" +) + +func main() { + klog.InitFlags(nil) + args := &Args{} + + args.AddFlags(pflag.CommandLine) + if err := flag.Set("logtostderr", "true"); err != nil { + klog.Fatalf("Error: %v", err) + } + pflag.CommandLine.AddGoFlagSet(flag.CommandLine) + pflag.Parse() + + if err := args.Validate(); err != nil { + klog.Fatalf("Error: %v", err) + } + + if args.PrintDocs { + printDocs() + os.Exit(0) + } + + myTargets := func(context *generator.Context) []generator.Target { + return GetTargets(context, args) + } + + // Run it. + if err := gengo.Execute( + NameSystems(), + DefaultNameSystem(), + myTargets, + gengo.StdBuildTag, + pflag.Args(), + ); err != nil { + klog.Fatalf("Error: %v", err) + } + klog.V(2).Info("Completed successfully.") +} + +type Args struct { + OutputFile string + ExtraPkgs []string // Always consider these as last-ditch possibilities for validations. + GoHeaderFile string + PrintDocs bool + Lint bool +} + +// AddFlags add the generator flags to the flag set. +func (args *Args) AddFlags(fs *pflag.FlagSet) { + fs.StringVar(&args.OutputFile, "output-file", "generated.validations.go", + "the name of the file to be generated") + fs.StringSliceVar(&args.ExtraPkgs, "extra-pkg", args.ExtraPkgs, + "the import path of a package whose validation can be used by generated code, but is not being generated for") + fs.StringVar(&args.GoHeaderFile, "go-header-file", "", + "the path to a file containing boilerplate header text; the string \"YEAR\" will be replaced with the current 4-digit year") + fs.BoolVar(&args.PrintDocs, "docs", false, + "print documentation for supported declarative validations, and then exit") + fs.BoolVar(&args.Lint, "lint", false, + "only run linting checks, do not generate code") +} + +// Validate checks the given arguments. +func (args *Args) Validate() error { + if len(args.OutputFile) == 0 { + return fmt.Errorf("--output-file must be specified") + } + + return nil +} + +func printDocs() { + // We need a fake context to init the validator plugins. + c := &generator.Context{ + Namers: namer.NameSystems{}, + Universe: types.Universe{}, + FileTypes: map[string]generator.FileType{}, + } + + // Initialize all registered validators. + validator := validators.InitGlobalValidator(c) + + docs := validator.Docs() + for i := range docs { + d := &docs[i] + slices.Sort(d.Scopes) + if d.Usage == "" { + // Try to generate a usage string if none was provided. + usage := d.Tag + if len(d.Args) > 0 { + usage += "(" + for i := range d.Args { + if i > 0 { + usage += ", " + } + usage += d.Args[i].Description + } + usage += ")" + } + if len(d.Payloads) > 0 { + usage += "=" + if len(d.Payloads) == 1 { + usage += d.Payloads[0].Description + } else { + usage += "" + } + } + d.Usage = usage + } + } + slices.SortFunc(docs, func(a, b validators.TagDoc) int { + return cmp.Compare(a.Tag, b.Tag) + }) + + var buf bytes.Buffer + encoder := json.NewEncoder(&buf) + encoder.SetEscapeHTML(false) + encoder.SetIndent("", " ") + if err := encoder.Encode(docs); err != nil { + klog.Fatalf("failed to marshal docs: %v", err) + } + + fmt.Println(buf.String()) +} diff --git a/staging/src/k8s.io/code-generator/cmd/validation-gen/targets.go b/staging/src/k8s.io/code-generator/cmd/validation-gen/targets.go new file mode 100644 index 00000000000..17c2b2663d8 --- /dev/null +++ b/staging/src/k8s.io/code-generator/cmd/validation-gen/targets.go @@ -0,0 +1,383 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package main + +import ( + "cmp" + "fmt" + "reflect" + "slices" + "strings" + + "k8s.io/apimachinery/pkg/util/sets" + "k8s.io/code-generator/cmd/validation-gen/validators" + "k8s.io/gengo/v2" + "k8s.io/gengo/v2/generator" + "k8s.io/gengo/v2/namer" + "k8s.io/gengo/v2/types" + "k8s.io/klog/v2" +) + +// These are the comment tags that carry parameters for validation generation. +const ( + tagName = "k8s:validation-gen" + inputTagName = "k8s:validation-gen-input" + schemeRegistryTagName = "k8s:validation-gen-scheme-registry" // defaults to k8s.io/apimachinery/pkg.runtime.Scheme + testFixtureTagName = "k8s:validation-gen-test-fixture" // if set, generate go test files for test fixtures. Supported values: "validateFalse". +) + +var ( + runtimePkg = "k8s.io/apimachinery/pkg/runtime" + schemeType = types.Name{Package: runtimePkg, Name: "Scheme"} +) + +func extractTag(comments []string) ([]string, bool) { + tags, err := gengo.ExtractFunctionStyleCommentTags("+", []string{tagName}, comments) + if err != nil { + klog.Fatalf("Failed to extract tags: %v", err) + } + values, found := tags[tagName] + if !found || len(values) == 0 { + return nil, false + } + + result := make([]string, len(values)) + for i, tag := range values { + result[i] = tag.Value + } + return result, true +} + +func extractInputTag(comments []string) []string { + tags, err := gengo.ExtractFunctionStyleCommentTags("+", []string{inputTagName}, comments) + if err != nil { + klog.Fatalf("Failed to extract input tags: %v", err) + } + values, found := tags[inputTagName] + if !found { + return nil + } + + result := make([]string, len(values)) + for i, tag := range values { + result[i] = tag.Value + } + return result +} + +func checkTag(comments []string, require ...string) bool { + tags, err := gengo.ExtractFunctionStyleCommentTags("+", []string{tagName}, comments) + if err != nil { + klog.Fatalf("Failed to extract tags: %v", err) + } + values, found := tags[tagName] + if !found { + return false + } + + if len(require) == 0 { + return len(values) == 1 && values[0].Value == "" + } + + valueStrings := make([]string, len(values)) + for i, tag := range values { + valueStrings[i] = tag.Value + } + + return reflect.DeepEqual(valueStrings, require) +} + +func schemeRegistryTag(pkg *types.Package) types.Name { + tags, err := gengo.ExtractFunctionStyleCommentTags("+", []string{schemeRegistryTagName}, pkg.Comments) + if err != nil { + klog.Fatalf("Failed to extract scheme registry tags: %v", err) + } + values, found := tags[schemeRegistryTagName] + if !found || len(values) == 0 { + return schemeType // default + } + if len(values) > 1 { + panic(fmt.Sprintf("Package %q contains more than one usage of %q", pkg.Path, schemeRegistryTagName)) + } + return types.ParseFullyQualifiedName(values[0].Value) +} + +var testFixtureTagValues = sets.New("validateFalse") + +func testFixtureTag(pkg *types.Package) sets.Set[string] { + result := sets.New[string]() + tags, err := gengo.ExtractFunctionStyleCommentTags("+", []string{testFixtureTagName}, pkg.Comments) + if err != nil { + klog.Fatalf("Failed to extract test fixture tags: %v", err) + } + values, found := tags[testFixtureTagName] + if !found { + return result + } + + for _, tag := range values { + if !testFixtureTagValues.Has(tag.Value) { + panic(fmt.Sprintf("Package %q: %s must be one of '%s', but got: %s", pkg.Path, testFixtureTagName, testFixtureTagValues.UnsortedList(), tag.Value)) + } + result.Insert(tag.Value) + } + return result +} + +// NameSystems returns the name system used by the generators in this package. +func NameSystems() namer.NameSystems { + return namer.NameSystems{ + "public": namer.NewPublicNamer(1), + "raw": namer.NewRawNamer("", nil), + "objectvalidationfn": validationFnNamer(), + "private": namer.NewPrivateNamer(0), + "name": namer.NewPublicNamer(0), + } +} + +func validationFnNamer() *namer.NameStrategy { + return &namer.NameStrategy{ + Prefix: "Validate_", + Join: func(pre string, in []string, post string) string { + return pre + strings.Join(in, "_") + post + }, + } +} + +// DefaultNameSystem returns the default name system for ordering the types to be +// processed by the generators in this package. +func DefaultNameSystem() string { + return "public" +} + +func GetTargets(context *generator.Context, args *Args) []generator.Target { + boilerplate, err := gengo.GoBoilerplate(args.GoHeaderFile, gengo.StdBuildTag, gengo.StdGeneratedBy) + if err != nil { + klog.Fatalf("Failed loading boilerplate: %v", err) + } + + var targets []generator.Target + var lintErrs []error + + // First load other "input" packages. We do this as a single call because + // it is MUCH faster. + inputPkgs := make([]string, 0, len(context.Inputs)) + pkgToInput := map[string]string{} + for _, input := range context.Inputs { + klog.V(5).Infof("considering pkg %q", input) + + pkg := context.Universe[input] + + // if the types are not in the same package where the validation + // functions are to be emitted + inputTags := extractInputTag(pkg.Comments) + if len(inputTags) > 1 { + panic(fmt.Sprintf("there may only be one input tag, got %#v", inputTags)) + } + if len(inputTags) == 1 { + inputPath := inputTags[0] + if strings.HasPrefix(inputPath, "./") || strings.HasPrefix(inputPath, "../") { + // this is a relative dir, which will not work under gomodules. + // join with the local package path, but warn + klog.Fatalf("relative path (%s=%s) is not supported; use full package path (as used by 'import') instead", inputTagName, inputPath) + } + + klog.V(5).Infof(" input pkg %v", inputPath) + inputPkgs = append(inputPkgs, inputPath) + pkgToInput[input] = inputPath + } else { + pkgToInput[input] = input + } + } + + // Make sure explicit extra-packages are added. + var extraPkgs []string + for _, pkg := range args.ExtraPkgs { + // In case someone specifies an extra as a path into vendor, convert + // it to its "real" package path. + if i := strings.Index(pkg, "/vendor/"); i != -1 { + pkg = pkg[i+len("/vendor/"):] + } + extraPkgs = append(extraPkgs, pkg) + } + if expanded, err := context.FindPackages(extraPkgs...); err != nil { + klog.Fatalf("cannot find extra packages: %v", err) + } else { + extraPkgs = expanded // now in fully canonical form + } + for _, extra := range extraPkgs { + inputPkgs = append(inputPkgs, extra) + pkgToInput[extra] = extra + } + + // We also need the to be able to look up the packages of inputs + inputToPkg := make(map[string]string, len(pkgToInput)) + for k, v := range pkgToInput { + inputToPkg[v] = k + } + + if len(inputPkgs) > 0 { + if _, err := context.LoadPackages(inputPkgs...); err != nil { + klog.Fatalf("cannot load packages: %v", err) + } + } + // update context.Order to the latest context.Universe + orderer := namer.Orderer{Namer: namer.NewPublicNamer(1)} + context.Order = orderer.OrderUniverse(context.Universe) + + // Initialize all validator plugins exactly once. + validator := validators.InitGlobalValidator(context) + + // Build a cache of type->callNode for every type we need. + for _, input := range context.Inputs { + klog.V(2).InfoS("processing", "pkg", input) + + pkg := context.Universe[input] + + schemeRegistry := schemeRegistryTag(pkg) + + typesWith, found := extractTag(pkg.Comments) + if !found { + klog.V(2).InfoS(" did not find required tag", "tag", tagName) + continue + } + if len(typesWith) == 1 && typesWith[0] == "" { + klog.Fatalf("found package tag %q with no value", tagName) + } + shouldCreateObjectValidationFn := func(t *types.Type) bool { + // opt-out + if checkTag(t.SecondClosestCommentLines, "false") { + return false + } + // opt-in + if checkTag(t.SecondClosestCommentLines, "true") { + return true + } + // all types + for _, v := range typesWith { + if v == "*" && !namer.IsPrivateGoName(t.Name.Name) { + return true + } + } + // For every k8s:validation-gen tag at the package level, interpret the value as a + // field name (like TypeMeta, ListMeta, ObjectMeta) and trigger validation generation + // for any type with any of the matching field names. Provides a more useful package + // level validation than global (because we only need validations on a subset of objects - + // usually those with TypeMeta). + return isTypeWith(t, typesWith) + } + + // Find the right input pkg, which might not be this one. + inputPath := pkgToInput[input] + // typesPkg is where the types that need validation are defined. + // Sometimes it is different from pkg. For example, kubernetes core/v1 + // types are defined in k8s.io/api/core/v1, while the pkg which holds + // defaulter code is at k/k/pkg/api/v1. + typesPkg := context.Universe[inputPath] + + // Figure out which types we should be considering further. + var rootTypes []*types.Type + for _, t := range typesPkg.Types { + if shouldCreateObjectValidationFn(t) { + rootTypes = append(rootTypes, t) + } else { + klog.V(6).InfoS("skipping type", "type", t) + } + } + // Deterministic ordering helps in logs and debugging. + slices.SortFunc(rootTypes, func(a, b *types.Type) int { + return cmp.Compare(a.Name.String(), b.Name.String()) + }) + + td := NewTypeDiscoverer(validator, inputToPkg) + for _, t := range rootTypes { + klog.V(4).InfoS("pre-processing", "type", t) + if err := td.DiscoverType(t); err != nil { + klog.Fatalf("failed to generate validations: %v", err) + } + } + + l := newLinter() + for _, t := range rootTypes { + klog.V(4).InfoS("linting root-type", "type", t) + if err := l.lintType(t); err != nil { + klog.Fatalf("failed to lint type %q: %v", t.Name, err) + } + if len(l.lintErrors) > 0 { + lintErrs = append(lintErrs, l.lintErrors...) + } + } + if args.Lint { + klog.V(4).Info("Lint is set, skip appending targets") + continue + } + + targets = append(targets, + &generator.SimpleTarget{ + PkgName: pkg.Name, + PkgPath: pkg.Path, + PkgDir: pkg.Dir, // output pkg is the same as the input + HeaderComment: boilerplate, + + FilterFunc: func(c *generator.Context, t *types.Type) bool { + return t.Name.Package == typesPkg.Path + }, + + GeneratorsFunc: func(c *generator.Context) (generators []generator.Generator) { + generators = []generator.Generator{ + NewGenValidations(args.OutputFile, pkg.Path, rootTypes, td, inputToPkg, schemeRegistry), + } + testFixtureTags := testFixtureTag(pkg) + if testFixtureTags.Len() > 0 { + if !strings.HasSuffix(args.OutputFile, ".go") { + panic(fmt.Sprintf("%s requires that output file have .go suffix", testFixtureTagName)) + } + filename := args.OutputFile[0:len(args.OutputFile)-3] + "_test.go" + generators = append(generators, FixtureTests(filename, testFixtureTags)) + } + return generators + }, + }) + } + + if len(lintErrs) > 0 { + var lintErrsStr string + for _, err := range lintErrs { + lintErrsStr += fmt.Sprintf("\n%s", err.Error()) + } + if args.Lint { + klog.Fatalf("failed to lint comments: %s", lintErrsStr) + } else { + klog.Warningf("failed to lint comments: %s", lintErrsStr) + } + + } + return targets +} + +func isTypeWith(t *types.Type, typesWith []string) bool { + if t.Kind == types.Struct && len(typesWith) > 0 { + for _, field := range t.Members { + for _, s := range typesWith { + if field.Name == s { + return true + } + } + } + } + return false +} diff --git a/staging/src/k8s.io/code-generator/cmd/validation-gen/validation.go b/staging/src/k8s.io/code-generator/cmd/validation-gen/validation.go new file mode 100644 index 00000000000..d93b21ef1b6 --- /dev/null +++ b/staging/src/k8s.io/code-generator/cmd/validation-gen/validation.go @@ -0,0 +1,1448 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package main + +import ( + "bytes" + "cmp" + "fmt" + "io" + "reflect" + "slices" + "strconv" + "strings" + "unicode" + + "k8s.io/apimachinery/pkg/util/sets" + "k8s.io/apimachinery/pkg/util/validation/field" + "k8s.io/code-generator/cmd/validation-gen/validators" + "k8s.io/gengo/v2/generator" + "k8s.io/gengo/v2/namer" + "k8s.io/gengo/v2/parser/tags" + "k8s.io/gengo/v2/types" + "k8s.io/klog/v2" +) + +func mkPkgNames(pkg string, names ...string) []types.Name { + result := make([]types.Name, 0, len(names)) + for _, name := range names { + result = append(result, types.Name{Package: pkg, Name: name}) + } + return result +} + +var ( + fieldPkg = "k8s.io/apimachinery/pkg/util/validation/field" + fieldPkgSymbols = mkPkgNames(fieldPkg, "ErrorList", "InternalError", "Path") + fmtPkgSymbols = mkPkgNames("fmt", "Errorf") + safePkg = "k8s.io/apimachinery/pkg/api/safe" + safePkgSymbols = mkPkgNames(safePkg, "Field", "Cast") + operationPkg = "k8s.io/apimachinery/pkg/api/operation" + operationPkgSymbols = mkPkgNames(operationPkg, "Operation") + contextPkg = "context" + contextPkgSymbols = mkPkgNames(contextPkg, "Context") +) + +// genValidations produces a file with autogenerated validations. +type genValidations struct { + generator.GoGenerator + outputPackage string + inputToPkg map[string]string // Maps input packages to generated validation packages + rootTypes []*types.Type + discovered *typeDiscoverer + imports namer.ImportTracker + schemeRegistry types.Name +} + +// NewGenValidations cretes a new generator for the specified package. +func NewGenValidations(outputFilename, outputPackage string, rootTypes []*types.Type, discovered *typeDiscoverer, inputToPkg map[string]string, schemeRegistry types.Name) generator.Generator { + return &genValidations{ + GoGenerator: generator.GoGenerator{ + OutputFilename: outputFilename, + }, + outputPackage: outputPackage, + inputToPkg: inputToPkg, + rootTypes: rootTypes, + discovered: discovered, + imports: generator.NewImportTrackerForPackage(outputPackage), + schemeRegistry: schemeRegistry, + } +} + +func (g *genValidations) Namers(_ *generator.Context) namer.NameSystems { + // Have the raw namer for this file track what it imports. + return namer.NameSystems{ + "raw": namer.NewRawNamer(g.outputPackage, g.imports), + } +} + +func (g *genValidations) Filter(_ *generator.Context, t *types.Type) bool { + // We want to emit code for all root types. + for _, rt := range g.rootTypes { + if rt == t { + return true + } + } + // We want to emit for any other type that is transitively part of a root + // type and has validations. + n := g.discovered.typeNodes[t] + return n != nil && hasValidations(n) +} + +func (g *genValidations) Imports(_ *generator.Context) (imports []string) { + var importLines []string + for _, singleImport := range g.imports.ImportLines() { + if g.isOtherPackage(singleImport) { + importLines = append(importLines, singleImport) + } + } + return importLines +} + +func (g *genValidations) isOtherPackage(pkg string) bool { + if pkg == g.outputPackage { + return false + } + if strings.HasSuffix(pkg, `"`+g.outputPackage+`"`) { + return false + } + return true +} + +func (g *genValidations) Init(c *generator.Context, w io.Writer) error { + klog.V(5).Infof("emitting registration code") + sw := generator.NewSnippetWriter(w, c, "$", "$") + g.emitRegisterFunction(c, g.schemeRegistry, sw) + if err := sw.Error(); err != nil { + return err + } + return nil +} + +func (g *genValidations) GenerateType(c *generator.Context, t *types.Type, w io.Writer) error { + klog.V(5).Infof("emitting validation code for type %v", t) + + sw := generator.NewSnippetWriter(w, c, "$", "$") + g.emitValidationVariables(c, t, sw) + g.emitValidationFunction(c, t, sw) + if err := sw.Error(); err != nil { + return err + } + return nil +} + +// This is a global cache of whether a type has validations. +var hasValidationsCache = map[*typeNode]bool{} + +// hasValidations checks and caches whether the given typeNode has any +// validations, transitively. Callers must be SURE that the typeNode has +// already been full discovered, or this may return wrong answers. It should +// be totally safe in the generation phase, but discovery needs to be careful! +func hasValidations(n *typeNode) bool { + seen := map[*typeNode]bool{} + return hasValidationsImpl(n, seen) +} + +// hasValidationsImpl implements hasValidations without risk of infinite +// recursion. +func hasValidationsImpl(n *typeNode, seen map[*typeNode]bool) bool { + if n == nil { + return false + } + + if seen[n] { + return false // prevent infinite recursion + } + seen[n] = true + + if r, found := hasValidationsCache[n]; found { + return r + } + + r := hasValidationsMiss(n, seen) + hasValidationsCache[n] = r + return r +} + +// hasValidationsMiss is called in case of a cache miss. +func hasValidationsMiss(n *typeNode, seen map[*typeNode]bool) bool { + if !n.typeValidations.Empty() { + return true + } + allChildren := n.fields + if n.key != nil { + allChildren = append(allChildren, n.key) + } + if n.elem != nil { + allChildren = append(allChildren, n.elem) + } + if n.underlying != nil { + allChildren = append(allChildren, n.underlying) + } + for _, c := range allChildren { + if !c.fieldValidations.Empty() { + return true + } + if hasValidationsImpl(c.node, seen) { + return true + } + } + return false +} + +// typeDiscoverer contains fields necessary to build graphs of types. +type typeDiscoverer struct { + validator validators.Validator + inputToPkg map[string]string + + // typeNodes holds a map of gengo Type to typeNode for all of the types + // encountered during discovery. + typeNodes map[*types.Type]*typeNode +} + +// NewTypeDiscoverer creates and initializes a NewTypeDiscoverer. +func NewTypeDiscoverer(validator validators.Validator, inputToPkg map[string]string) *typeDiscoverer { + return &typeDiscoverer{ + validator: validator, + inputToPkg: inputToPkg, + typeNodes: map[*types.Type]*typeNode{}, + } +} + +// childNode represents a type which is used in another type (e.g. a struct +// field). +type childNode struct { + name string // the field name in the parent, populated when this node is a struct field + jsonName string // always populated when name is populated + childType *types.Type // the real type of the child (may be a pointer) + node *typeNode // the node of the child's value type, or nil if it is in a foreign package + + fieldValidations validators.Validations // validations on the field +} + +// typeNode represents a node in the type-graph, annotated with information +// about validations. Everything in this type, transitively, is assoctiated +// with the type, and not any specific instance of that type (e.g. when used as +// a field in a struct. +type typeNode struct { + valueType *types.Type // never a pointer, but may be a map, slice, struct, etc. + funcName types.Name // populated when this type is has a validation function + + fields []*childNode // populated when this type is a struct + key *childNode // populated when this type is a map + elem *childNode // populated when this type is a map or slice + underlying *childNode // populated when this type is an alias + + typeValidations validators.Validations // validations on the type +} + +// lookupField returns the childNode with the specified JSON name. +func (n typeNode) lookupField(jsonName string) *childNode { + for _, fld := range n.fields { + if fld.jsonName == jsonName { + return fld + } + } + return nil +} + +// DiscoverType walks the given type recursively, building a type-graph in this +// typeDiscoverer. If this is called multiple times for different types, the +// graphs will be will be merged. +func (td *typeDiscoverer) DiscoverType(t *types.Type) error { + if t.Kind == types.Pointer { + return fmt.Errorf("type %v: pointer root-types are not supported", t) + } + fldPath := field.NewPath(t.Name.String()) + if node, err := td.discover(t, fldPath); err != nil { + return err + } else if node == nil { + panic(fmt.Sprintf("discovered a nil node for type %v", t)) + } + return nil +} + +// discover walks the given type recursively and returns a typeNode +// representing it. This does not distinguish between discovering a type +// definition and discovering a field of a struct. The first time it +// encounters a type it has not seen before, it will discover that type. If it +// finds a type it has already processed, it will return the existing node. +func (td *typeDiscoverer) discover(t *types.Type, fldPath *field.Path) (*typeNode, error) { + // With the exception of builtins (which gengo puts in package ""), we + // can't traverse into packages which are not being processed by this tool. + if t.Name.Package != "" { + _, ok := td.inputToPkg[t.Name.Package] + if !ok { + return nil, nil + } + } + + // Catch some edge cases that we don't want to handle. + if t.Kind == types.Alias && t.Underlying.Kind == types.Pointer { + return nil, fmt.Errorf("field %s (%s): typedefs to pointers are not supported", fldPath.String(), t) + } + if t.Kind == types.Pointer { + pointee := t.Elem + if pointee.Kind == types.Alias { + pointee = pointee.Underlying + } + switch pointee.Kind { + case types.Pointer: + return nil, fmt.Errorf("field %s (%s): pointers to pointers are not supported", fldPath.String(), t) + case types.Slice: + return nil, fmt.Errorf("field %s (%s): pointers to slices are not supported", fldPath.String(), t) + case types.Map: + return nil, fmt.Errorf("field %s (%s): pointers to maps are not supported", fldPath.String(), t) + } + } + + // Discovery applies to values, not pointers. + if t.Kind == types.Pointer { + return td.discover(t.Elem, fldPath) + } + + // If we have done this type already, we can stop here and break any + // recursion. + if node, found := td.typeNodes[t]; found { + return node, nil + } + + // If we are descending into a named type, reboot the field path for better + // logging. Otherwise the field path might come in as something like + // .. which is true, but not super useful. + switch t.Kind { + case types.Alias, types.Struct: + fldPath = field.NewPath(t.Name.String()) + } + + // This is the type-node being assembled in the rest of this function. + thisNode := &typeNode{ + valueType: t, + } + td.typeNodes[t] = thisNode + + // If this is a known, named type, we can call its validation function. + switch t.Kind { + case types.Alias, types.Struct: + if fn, ok := td.getValidationFunctionName(t); ok { + thisNode.funcName = fn + } + } + + // Discover into this type before extracting type validations. + switch t.Kind { + case types.Builtin: + // Nothing more to do. + case types.Alias: + // Discover the underlying type. + // + // Note: By the language definition, what gengo calls "Aliases" (really + // just "type definitions") have underlying types of the type literal. + // In other words, if we define `type T1 string` and `type T2 T1`, the + // underlying type of T2 is string, not T1. This means that: + // 1) We will emit code for both underlying types. If the underlying + // type is a struct with many fields, we will emit two identical + // functions. + // 2) Validating a field of type T2 will NOT call any validation + // defined on the type T1. + // 3) In the case of a type definition whose RHS is a struct which + // has fields with validation tags, the validation for those fields + // WILL be called from the generated for for the new type. + if node, err := td.discover(t.Underlying, fldPath); err != nil { + return nil, err + } else { + thisNode.underlying = &childNode{ + childType: t.Underlying, + node: node, + } + } + case types.Struct: + // Discover into this struct, recursively. + if err := td.discoverStruct(thisNode, fldPath); err != nil { + return nil, err + } + case types.Slice, types.Array: + // Discover the element type. + if node, err := td.discover(t.Elem, fldPath.Key("vals")); err != nil { + return nil, err + } else { + thisNode.elem = &childNode{ + childType: t.Elem, + node: node, + } + } + case types.Map: + // Discover the key type. + if node, err := td.discover(t.Key, fldPath.Key("keys")); err != nil { + return nil, err + } else { + thisNode.key = &childNode{ + childType: t.Key, + node: node, + } + } + + // Discover the element type. + if node, err := td.discover(t.Elem, fldPath.Key("vals")); err != nil { + return nil, err + } else { + thisNode.elem = &childNode{ + childType: t.Elem, + node: node, + } + } + default: + return nil, fmt.Errorf("field %s (%v, kind %v) is not supported", fldPath.String(), t, t.Kind) + } + + // Extract any type-attached validation rules. We do this AFTER descending + // into the type, so that these validators have access to the full type. + // For example, all struct field validators get called before the type + // validators. This does not influence the order in which the validations + // are called in emitted code, just how we evaluate what to emit. + // + // This should only ever hit Struct and Alias types, since that is the only + // opportunity to have type-attached comments to process. + context := validators.Context{ + Scope: validators.ScopeType, + Type: t, + Path: fldPath, + } + if t.Kind == types.Alias { + context.Parent = t + context.Type = t.Underlying + } + if validations, err := td.validator.ExtractValidations(context, t.CommentLines); err != nil { + return nil, fmt.Errorf("%v: %w", fldPath, err) + } else if !validations.Empty() { + klog.V(5).InfoS("found type-attached validations", "n", validations.Len()) + thisNode.typeValidations.Add(validations) + } + + // Handle type definitions whose output depends on the rest of type + // discovery being complete. In particular, aliases to lists and maps need + // iteration, but we don't want to iterate them if the key or value types + // don't actually have validations. We also want to handle non-included + // types and make users tell us what they intended. Lastly, we want to + // handle recursive types, but we need to finish discovering the type + // before we know if there are other validations, again so we don't emit + // empty functions. + if t.Kind == types.Alias { + underlying := thisNode.underlying + + switch t.Underlying.Kind { + case types.Slice, types.Array: + // Validate each value. + if elemNode := underlying.node.elem.node; elemNode == nil { + if !thisNode.typeValidations.OpaqueValType { + return nil, fmt.Errorf("%v: value type %v is in a non-included package; "+ + "either add this package to validation-gen's --extra-pkg flag, "+ + "or add +k8s:eachVal=+k8s:opaqueType to the field to skip validation", + fldPath, underlying.node.elem.childType) + } + } else if thisNode.typeValidations.OpaqueValType { + // If the type is marked as opaque, we can treat it as it is + // were in a non-included package. + } else { + // If the value type is a named type, call the validation + // function for each element. + if funcName := elemNode.funcName; funcName.Name != "" { + // We only need the iteration function if the underlying + // type has validations, otherwise it is noise. + if hasValidations(underlying.node) { + // Note: the first argument to Function() is really + // only for debugging. + v, err := validators.ForEachVal(fldPath, underlying.childType, + validators.Function("iterateListValues", validators.DefaultFlags, funcName)) + if err != nil { + return nil, fmt.Errorf("generating list iteration: %w", err) + } else { + thisNode.typeValidations.Add(v) + } + } + } + } + case types.Map: + // Validate each key. + if keyNode := underlying.node.key.node; keyNode == nil { + if !thisNode.typeValidations.OpaqueKeyType { + return nil, fmt.Errorf("%v: key type %v is in a non-included package; "+ + "either add this package to validation-gen's --extra-pkg flag, "+ + "or add +k8s:eachKey=+k8s:opaqueType to the field to skip validation", + fldPath, underlying.node.elem.childType) + } + } else if thisNode.typeValidations.OpaqueKeyType { + // If the type is marked as opaque, we can treat it as it is + // were in a non-included package. + } else { + // If the key type is a named type, call the validation + // function for each key. + if funcName := keyNode.funcName; funcName.Name != "" { + // We only need the iteration function if the underlying + // type has validations, otherwise it is noise. + if hasValidations(underlying.node) { + // Note: the first argument to Function() is really + // only for debugging. + v, err := validators.ForEachKey(fldPath, underlying.childType, + validators.Function("iterateMapKeys", validators.DefaultFlags, funcName)) + if err != nil { + return nil, fmt.Errorf("generating map key iteration: %w", err) + } else { + thisNode.typeValidations.Add(v) + } + } + } + } + // Validate each value. + if elemNode := underlying.node.elem.node; elemNode == nil { + if !thisNode.typeValidations.OpaqueValType { + return nil, fmt.Errorf("%v: value type %v is in a non-included package; "+ + "either add this package to validation-gen's --extra-pkg flag, "+ + "or add +k8s:eachVal=+k8s:opaqueType to the field to skip validation", + fldPath, underlying.node.elem.childType) + } + } else if thisNode.typeValidations.OpaqueValType { + // If the type is marked as opaque, we can treat it as it is + // were in a non-included package. + } else { + // If the value type is a named type, call the validation + // function for each element. + if funcName := elemNode.funcName; funcName.Name != "" { + // We only need the iteration function if the underlying + // type has validations, otherwise it is noise. + if hasValidations(underlying.node) { + // Note: the first argument to Function() is really + // only for debugging. + v, err := validators.ForEachVal(fldPath, underlying.childType, + validators.Function("iterateMapValues", validators.DefaultFlags, funcName)) + if err != nil { + return nil, fmt.Errorf("generating map value iteration: %w", err) + } else { + thisNode.typeValidations.Add(v) + } + } + } + } + } + } + + return thisNode, nil +} + +// discoverStruct walks a struct type recursively. +func (td *typeDiscoverer) discoverStruct(thisNode *typeNode, fldPath *field.Path) error { + var fields []*childNode + + // Discover into each field of this struct. + for _, memb := range thisNode.valueType.Members { + name := memb.Name + if len(name) == 0 { // embedded fields + if memb.Type.Kind == types.Pointer { + name = memb.Type.Elem.Name.Name + } else { + name = memb.Type.Name.Name + } + } + // Only do exported fields. + if unicode.IsLower([]rune(name)[0]) { + continue + } + + // If we try to emit code for this field and find no JSON name, we + // will abort. + jsonName := "" + if commentTags, ok := tags.LookupJSON(memb); ok { + jsonName = commentTags.Name + } + + klog.V(5).InfoS("field", "name", name, "jsonName", jsonName, "type", memb.Type) + + // Discover the field type. + childPath := fldPath.Child(name) + childType := memb.Type + var child *childNode + if node, err := td.discover(childType, childPath); err != nil { + return err + } else { + child = &childNode{ + name: name, + jsonName: jsonName, + childType: childType, + node: node, + } + } + + // Extract any field-attached validation rules. + context := validators.Context{ + Scope: validators.ScopeField, + Type: childType, + Parent: thisNode.valueType, + Member: &memb, + Path: childPath, + } + if validations, err := td.validator.ExtractValidations(context, memb.CommentLines); err != nil { + return fmt.Errorf("field %s: %w", childPath.String(), err) + } else if !validations.Empty() { + klog.V(5).InfoS("found field-attached validations", "n", validations.Len()) + child.fieldValidations.Add(validations) + if len(validations.Variables) > 0 { + return fmt.Errorf("%v: variable generation is not supported for field validations", childPath) + } + } + + // Handle non-included types. + switch nonPtrType(childType).Kind { + case types.Struct, types.Alias: + if child.node == nil { // a non-included type + if !child.fieldValidations.OpaqueType { + return fmt.Errorf("%v: type %v is in a non-included package; "+ + "either add this package to validation-gen's --extra-pkg flag, "+ + "or add +k8s:opaqueType to the field to skip validation", + childPath, childType.String()) + } + } else if child.fieldValidations.OpaqueType { + // If the field is marked as opaque, we can treat it as it is + // were in a non-included package. + child.node = nil + } + } + + // Add any other field-attached "special" validators. We need to do + // this after all the other field validation has been processed, + // because some of this is conditional on whether other validations + // were emitted (to avoid emitting empty functions). + // + // We do this here, rather than in discover() because we need to know + // information about the field, not just the type. + switch childType.Kind { + case types.Slice, types.Array: + // Validate each value of a list field. + if elemNode := child.node.elem.node; elemNode == nil { + if !child.fieldValidations.OpaqueValType { + return fmt.Errorf("%v: value type %v is in a non-included package; "+ + "either add this package to validation-gen's --extra-pkg flag, "+ + "or add +k8s:eachVal=+k8s:opaqueType to the field to skip validation", + childPath, childType.Elem.String()) + } + } else if child.fieldValidations.OpaqueValType { + // If the field is marked as opaque, we can treat it as it is + // were in a non-included package. + } else { + // If the list's value type is a named type, call the validation + // function for each element. + if funcName := elemNode.funcName; funcName.Name != "" { + // We only emit the iteration function if the field + // has other validations, otherwise it is noise. + if hasValidations(child.node) { + // Note: the first argument to Function() is really + // only for debugging. + v, err := validators.ForEachVal(childPath, childType, + validators.Function("iterateListValues", validators.DefaultFlags, funcName)) + if err != nil { + return fmt.Errorf("generating list iteration: %w", err) + } else { + child.fieldValidations.Add(v) + } + } + } + } + case types.Map: + // Validate each key of a map field. + if keyNode := child.node.key.node; keyNode == nil { + if !child.fieldValidations.OpaqueKeyType { + return fmt.Errorf("%v: key type %v is in a non-included package; "+ + "either add this package to validation-gen's --extra-pkg flag, "+ + "or add +k8s:eachKey=+k8s:opaqueType to the field to skip validation", + childPath, childType.Key.String()) + } + } else if child.fieldValidations.OpaqueKeyType { + // If the field is marked as opaque, we can treat it as it is + // were in a non-included package. + } else { + // If the map's key type is a named type, call the validation + // function for each key. + if funcName := keyNode.funcName; funcName.Name != "" { + // We only emit the iteration function if the field + // has other validations, otherwise it is noise. + if hasValidations(child.node) { + // Note: the first argument to Function() is really + // only for debugging. + v, err := validators.ForEachKey(childPath, childType, + validators.Function("iterateMapKeys", validators.DefaultFlags, funcName)) + if err != nil { + return fmt.Errorf("generating map key iteration: %w", err) + } else { + child.fieldValidations.Add(v) + } + } + } + } + // Validate each value of a map field. + if elemNode := child.node.elem.node; elemNode == nil { + if !child.fieldValidations.OpaqueValType { + return fmt.Errorf("%v: value type %v is in a non-included package; "+ + "either add this package to validation-gen's --extra-pkg flag, "+ + "or add +k8s:eachVal=+k8s:opaqueType to the field to skip validation", + childPath, childType.Elem.String()) + } + } else if child.fieldValidations.OpaqueValType { + // If the field is marked as opaque, we can treat it as it is + // were in a non-included package. + } else { + // If the map's value type is a named type, call the validation + // function for each element. + if funcName := elemNode.funcName; funcName.Name != "" { + // We only emit the iteration function if the field + // has other validations, otherwise it is noise. + if hasValidations(child.node) { + // Note: the first argument to Function() is really + // only for debugging. + v, err := validators.ForEachVal(childPath, childType, + validators.Function("iterateMapValues", validators.DefaultFlags, funcName)) + if err != nil { + return fmt.Errorf("generating map value iteration: %w", err) + } else { + child.fieldValidations.Add(v) + } + } + } + } + } + + fields = append(fields, child) + } + + thisNode.fields = fields + return nil +} + +// nonPtrType removes any pointerness from the type. +func nonPtrType(t *types.Type) *types.Type { + for t.Kind == types.Pointer { + t = t.Elem + } + return t +} + +// getValidationFunctionName looks up the name of the specified type's +// validation function. +// +// TODO: Currently this is a "blind" call - we hope that the expected function +// exists, but we don't verify that, and we only emit calls into packages which +// are being processed by this generator. For cross-package calls we will need +// to verify the target, either by naming convention + fingerprint or by +// explicit comment-tags or something. +func (td *typeDiscoverer) getValidationFunctionName(t *types.Type) (types.Name, bool) { + pkg, ok := td.inputToPkg[t.Name.Package] + if !ok { + return types.Name{}, false + } + return types.Name{Package: pkg, Name: "Validate_" + t.Name.Name}, true +} + +func mkSymbolArgs(c *generator.Context, names []types.Name) generator.Args { + args := generator.Args{} + for _, name := range names { + args[name.Name] = c.Universe.Type(name) + } + return args +} + +// emitRegisterFunction emits the type-registration logic for validation +// functions. +func (g *genValidations) emitRegisterFunction(c *generator.Context, schemeRegistry types.Name, sw *generator.SnippetWriter) { + scheme := c.Universe.Type(schemeRegistry) + schemePtr := &types.Type{ + Kind: types.Pointer, + Elem: scheme, + } + + sw.Do("func init() { localSchemeBuilder.Register(RegisterValidations)}\n\n", nil) + + sw.Do("// RegisterValidations adds validation functions to the given scheme.\n", nil) + sw.Do("// Public to allow building arbitrary schemes.\n", nil) + sw.Do("func RegisterValidations(scheme $.|raw$) error {\n", schemePtr) + for _, rootType := range g.rootTypes { + if !hasValidations(g.discovered.typeNodes[rootType]) { + continue + } + + node := g.discovered.typeNodes[rootType] + if node == nil { + panic(fmt.Sprintf("found nil node for root-type %v", rootType)) + } + + // TODO: It would be nice if these were not hard-coded. + var statusType *types.Type + var statusField string + if status := node.lookupField("status"); status != nil { + statusType = status.node.valueType + statusField = status.name + } + + targs := generator.Args{ + "rootType": rootType, + "typePfx": "", + "statusType": statusType, + "statusField": statusField, + "field": mkSymbolArgs(c, fieldPkgSymbols), + "fmt": mkSymbolArgs(c, fmtPkgSymbols), + "operation": mkSymbolArgs(c, operationPkgSymbols), + "safe": mkSymbolArgs(c, safePkgSymbols), + "context": mkSymbolArgs(c, contextPkgSymbols), + } + if !isNilableType(rootType) { + targs["typePfx"] = "*" + } + + // This uses a typed nil pointer, rather than a real instance because + // we need the type information, but not an instance of the type. + sw.Do("scheme.AddValidationFunc(", targs) + sw.Do(" ($.typePfx$$.rootType|raw$)(nil), ", targs) + sw.Do(" func(ctx $.context.Context$, op $.operation.Operation|raw$, obj, oldObj interface{}, ", targs) + sw.Do(" subresources ...string) $.field.ErrorList|raw$ {\n", targs) + sw.Do(" if len(subresources) == 0 {\n", targs) + sw.Do(" return $.rootType|objectvalidationfn$(", targs) + sw.Do(" ctx, ", targs) + sw.Do(" op, ", targs) + sw.Do(" nil /* fldPath */, ", targs) + sw.Do(" obj.($.typePfx$$.rootType|raw$), ", targs) + sw.Do(" $.safe.Cast|raw$[$.typePfx$$.rootType|raw$](oldObj))\n", targs) + sw.Do(" }\n", targs) + + if statusType != nil { + targs["statusTypePfx"] = "" + targs["statusTypePtrPfx"] = "" + if !isNilableType(statusType) { + targs["statusTypePfx"] = "*" + targs["statusTypePtrPfx"] = "&" + } + sw.Do(" if len(subresources) == 1 && subresources[0] == \"status\" {\n", targs) + if hasValidations(g.discovered.typeNodes[statusType]) { + sw.Do(" root := obj.($.typePfx$$.rootType|raw$)\n", targs) + sw.Do(" return $.statusType|objectvalidationfn$(", targs) + sw.Do(" ctx, ", targs) + sw.Do(" op, ", targs) + sw.Do(" nil /* fldPath */, ", targs) + sw.Do(" &root.$.statusField$, ", targs) + sw.Do(" $.safe.Field|raw$(", targs) + sw.Do(" $.safe.Cast|raw$[$.typePfx$$.rootType|raw$](oldObj), ", targs) + sw.Do(" func(oldObj $.typePfx$$.rootType|raw$) $.statusTypePfx$$.statusType|raw$ { ", targs) + sw.Do(" return $.statusTypePtrPfx$oldObj.$.statusField$ ", targs) + sw.Do(" }))\n", targs) + } else { + sw.Do(" return nil // $.statusType|raw$ has no validation\n", targs) + } + sw.Do(" }\n", targs) + } + sw.Do(" return $.field.ErrorList|raw${", targs) + sw.Do(" $.field.InternalError|raw$(", targs) + sw.Do(" nil, ", targs) + sw.Do(" $.fmt.Errorf|raw$(\"no validation found for %T, subresources: %v\", obj, subresources))", targs) + sw.Do(" }\n", targs) + sw.Do("})\n", targs) + } + sw.Do("return nil\n", nil) + sw.Do("}\n\n", nil) +} + +// emitValidationFunction emits a validation function for the specified type. +func (g *genValidations) emitValidationFunction(c *generator.Context, t *types.Type, sw *generator.SnippetWriter) { + if !hasValidations(g.discovered.typeNodes[t]) { + return + } + + targs := generator.Args{ + "inType": t, + "field": mkSymbolArgs(c, fieldPkgSymbols), + "operation": mkSymbolArgs(c, operationPkgSymbols), + "context": mkSymbolArgs(c, contextPkgSymbols), + "objTypePfx": "*", + } + if isNilableType(t) { + targs["objTypePfx"] = "" + } + + node := g.discovered.typeNodes[t] + if node == nil { + panic(fmt.Sprintf("found nil node for root-type %v", t)) + } + sw.Do("func $.inType|objectvalidationfn$(", targs) + sw.Do(" ctx $.context.Context|raw$, ", targs) + sw.Do(" op $.operation.Operation|raw$, ", targs) + sw.Do(" fldPath *$.field.Path|raw$, ", targs) + sw.Do(" obj, oldObj $.objTypePfx$$.inType|raw$) ", targs) + sw.Do("(errs $.field.ErrorList|raw$) {\n", targs) + fakeChild := &childNode{ + node: node, + childType: t, + } + g.emitValidationForChild(c, fakeChild, sw) + sw.Do("return errs\n", nil) + sw.Do("}\n\n", nil) +} + +// emitValidationForChild emits code for the specified childNode, calling +// type-attached validations and then descending into the type (e.g. struct +// fields). +// +// Emitted code assumes that the value in question is always a pair of nilable +// variables named "obj" and "oldObj", and the field path to this value is +// named "fldPath". +// +// This function assumes that thisChild.node is not nil. +func (g *genValidations) emitValidationForChild(c *generator.Context, thisChild *childNode, sw *generator.SnippetWriter) { + thisNode := thisChild.node + inType := thisNode.valueType + + targs := generator.Args{ + "inType": inType, + "field": mkSymbolArgs(c, fieldPkgSymbols), + "safe": mkSymbolArgs(c, safePkgSymbols), + } + + didSome := false // for prettier output later + + // Emit code for type-attached validations. + if validations := thisNode.typeValidations; !validations.Empty() { + switch thisNode.valueType.Kind { + case types.Struct, types.Alias: // OK + default: + panic(fmt.Sprintf("unexpected type-validations on type %v, kind %s", thisNode.valueType, thisNode.valueType.Kind)) + } + sw.Do("// type $.inType|raw$\n", targs) + emitCallsToValidators(c, validations.Functions, sw) + emitComments(validations.Comments, sw) + sw.Do("\n", nil) + didSome = true + } + + // Descend into the type. + switch inType.Kind { + case types.Builtin: + // Nothing further. + case types.Slice, types.Array: + // Nothing further + case types.Map: + // Nothing further + case types.Alias: + g.emitValidationForChild(c, thisNode.underlying, sw) + case types.Struct: + for _, fld := range thisNode.fields { + if len(fld.name) == 0 { + panic(fmt.Sprintf("missing field name in type %s (field-type %s)", thisNode.valueType, fld.childType)) + } + // Missing JSON name is checked iff we have code to emit. + + // Accumulate into a buffer so we don't emit empty functions. + buf := bytes.NewBuffer(nil) + bufsw := sw.Dup(buf) + + validations := fld.fieldValidations + if !validations.Empty() { + emitCallsToValidators(c, validations.Functions, bufsw) + emitComments(validations.Comments, bufsw) + } + + // If the node is nil, this must be a type in a package we are not + // handling - it's effectively opaque to us. + if fld.node != nil { + // Get to the real type. + switch fld.node.valueType.Kind { + case types.Alias, types.Struct: + // If this field is another type, call its validation function. + g.emitCallToOtherTypeFunc(c, fld.node, bufsw) + default: + // Descend into this field. + g.emitValidationForChild(c, fld, bufsw) + } + } + + if buf.Len() > 0 { + if len(fld.jsonName) == 0 { + continue // TODO: Embedded (inline) types are expected to be unnamed. + } + + leafType, typePfx, exprPfx := getLeafTypeAndPrefixes(fld.childType) + targs := targs.WithArgs(generator.Args{ + "fieldName": fld.name, + "fieldJSON": fld.jsonName, + "fieldType": leafType, + "fieldTypePfx": typePfx, + "fieldExprPfx": exprPfx, + }) + + if didSome { + sw.Do("\n", nil) + } + sw.Do("// field $.inType|raw$.$.fieldName$\n", targs) + sw.Do("errs = append(errs,\n", targs) + sw.Do(" func(fldPath *$.field.Path|raw$, obj, oldObj $.fieldTypePfx$$.fieldType|raw$) (errs $.field.ErrorList|raw$) {\n", targs) + if err := sw.Merge(buf, bufsw); err != nil { + panic(fmt.Sprintf("failed to merge buffer: %v", err)) + } + sw.Do(" return\n", targs) + sw.Do(" }(fldPath.Child(\"$.fieldJSON$\"), ", targs) + sw.Do(" $.fieldExprPfx$obj.$.fieldName$, ", targs) + sw.Do(" $.safe.Field|raw$(oldObj, ", targs) + sw.Do(" func(oldObj *$.inType|raw$) $.fieldTypePfx$$.fieldType|raw$ {", targs) + sw.Do(" return $.fieldExprPfx$oldObj.$.fieldName$", targs) + sw.Do(" }),", targs) + sw.Do(" )...)\n", targs) + sw.Do("\n", nil) + } else { + targs := targs.WithArgs(generator.Args{ + "fieldName": fld.name, + }) + sw.Do("// field $.inType|raw$.$.fieldName$ has no validation\n", targs) + } + didSome = true + } + default: + panic(fmt.Sprintf("unhandled type: %v (kind %s)", inType, inType.Kind)) + } +} + +// emitCallToOtherTypeFunc generates a call to the specified node's generated +// validation function for a field in some parent context. +// +// Emitted code assumes that the value in question is always a pair of nilable +// variables named "obj" and "oldObj", and the field path to this value is +// named "fldPath". +func (g *genValidations) emitCallToOtherTypeFunc(c *generator.Context, node *typeNode, sw *generator.SnippetWriter) { + // If this type has no validations (transitively) then we don't need to do + // anything. + if !hasValidations(node) { + return + } + + targs := generator.Args{ + "funcName": c.Universe.Type(node.funcName), + } + sw.Do("errs = append(errs, $.funcName|raw$(ctx, op, fldPath, obj, oldObj)...)\n", targs) +} + +// emitCallsToValidators emits calls to a list of validation functions for +// a single field or type. validations is a list of functions to call, with +// arguments. +// +// When calling registered validators, we always pass a nilable type. E.g. if +// the field's type is string, we pass *string, and if the field's type is +// *string, we also pass *string. This means that validators need to do +// nil-checks themselves, if they intend to dereference the pointer. This +// makes updates more consistent. +// +// Emitted code assumes that the value in question is always a pair of nilable +// variables named "obj" and "oldObj", and the field path to this value is +// named "fldPath". +func emitCallsToValidators(c *generator.Context, validations []validators.FunctionGen, sw *generator.SnippetWriter) { + // Helper func + sort := func(in []validators.FunctionGen) []validators.FunctionGen { + sooner := make([]validators.FunctionGen, 0, len(in)) + later := make([]validators.FunctionGen, 0, len(in)) + + for _, fg := range in { + isShortCircuit := (fg.Flags().IsSet(validators.ShortCircuit)) + + if isShortCircuit { + sooner = append(sooner, fg) + } else { + later = append(later, fg) + } + } + result := sooner + result = append(result, later...) + return result + } + + validations = sort(validations) + + for _, v := range validations { + isShortCircuit := v.Flags().IsSet(validators.ShortCircuit) + isNonError := v.Flags().IsSet(validators.NonError) + + fn, extraArgs := v.SignatureAndArgs() + targs := generator.Args{ + "funcName": c.Universe.Type(fn), + "field": mkSymbolArgs(c, fieldPkgSymbols), + } + + emitCall := func() { + sw.Do("$.funcName|raw$", targs) + typeArgs := v.TypeArgs() + if len(typeArgs) > 0 { + sw.Do("[", nil) + for i, typeArg := range typeArgs { + sw.Do("$.|raw$", c.Universe.Type(typeArg)) + if i < len(typeArgs)-1 { + sw.Do(",", nil) + } + } + sw.Do("]", nil) + } + sw.Do("(ctx, op, fldPath, obj, oldObj", targs) + for _, arg := range extraArgs { + sw.Do(", ", nil) + toGolangSourceDataLiteral(sw, c, arg) + } + sw.Do(")", targs) + } + + // If validation is conditional, wrap the validation function with a conditions check. + if !v.Conditions().Empty() { + emitBaseFunction := emitCall + emitCall = func() { + sw.Do("func() $.field.ErrorList|raw$ {\n", targs) + sw.Do(" if ", nil) + firstCondition := true + if len(v.Conditions().OptionEnabled) > 0 { + sw.Do("op.Options.Has($.$)", strconv.Quote(v.Conditions().OptionEnabled)) + firstCondition = false + } + if len(v.Conditions().OptionDisabled) > 0 { + if !firstCondition { + sw.Do(" && ", nil) + } + sw.Do("!op.Options.Has($.$)", strconv.Quote(v.Conditions().OptionDisabled)) + } + sw.Do(" {\n", nil) + sw.Do(" return ", nil) + emitBaseFunction() + sw.Do("\n", nil) + sw.Do(" } else {\n", nil) + sw.Do(" return nil // skip validation\n", nil) + sw.Do(" }\n", nil) + sw.Do("}()", nil) + } + } + + if isShortCircuit { + sw.Do("if e := ", nil) + emitCall() + sw.Do("; len(e) != 0 {\n", nil) + if !isNonError { + sw.Do("errs = append(errs, e...)\n", nil) + } + sw.Do(" return // do not proceed\n", nil) + sw.Do("}\n", nil) + } else { + if isNonError { + emitCall() + } else { + sw.Do("errs = append(errs, ", nil) + emitCall() + sw.Do("...)\n", nil) + } + } + } +} + +func emitComments(comments []string, sw *generator.SnippetWriter) { + for _, comment := range comments { + sw.Do("// ", nil) + sw.Do(comment, nil) + sw.Do("\n", nil) + } +} + +// emitValidationVariables emits a list of variable declarations. Each variable declaration has a +// private (unexported) variable name, and a function invocation declaration that is expected +// to initialize the value of the variable. +func (g *genValidations) emitValidationVariables(c *generator.Context, t *types.Type, sw *generator.SnippetWriter) { + tn := g.discovered.typeNodes[t] + + variables := tn.typeValidations.Variables + slices.SortFunc(variables, func(a, b validators.VariableGen) int { + return cmp.Compare(a.Var().Name, b.Var().Name) + }) + for _, variable := range variables { + supportInitFn, supportInitArgs := variable.Init().SignatureAndArgs() + targs := generator.Args{ + "varName": c.Universe.Type(types.Name(variable.Var())), + "initFn": c.Universe.Type(supportInitFn), + } + sw.Do("var $.varName|private$ = $.initFn|raw$", targs) + typeArgs := variable.Init().TypeArgs() + if len(typeArgs) > 0 { + sw.Do("[", nil) + for i, typeArg := range typeArgs { + sw.Do("$.|raw$", c.Universe.Type(typeArg)) + if i < len(typeArgs)-1 { + sw.Do(",", nil) + } + } + sw.Do("]", nil) + } + sw.Do("(", targs) + for i, arg := range supportInitArgs { + toGolangSourceDataLiteral(sw, c, arg) + if i < len(supportInitArgs)-1 { + sw.Do(", ", nil) + } + } + sw.Do(")\n", nil) + + } +} + +func toGolangSourceDataLiteral(sw *generator.SnippetWriter, c *generator.Context, value any) { + // For safety, be strict in what values we output to visited source, and ensure strings + // are quoted. + + switch v := value.(type) { + case uint, uint8, uint16, uint32, uint64, int, int8, int16, int32, int64, float32, float64, bool: + sw.Do(fmt.Sprintf("%v", value), nil) + case string: + // If the incoming string was quoted, we still do it ourselves, JIC. + str := value.(string) + if s, err := strconv.Unquote(str); err == nil { + str = s + } + sw.Do(fmt.Sprintf("%q", str), nil) + case *types.Type: + sw.Do("$.|raw$", v) + case types.Member: + sw.Do("obj."+v.Name, nil) + case validators.Identifier: + sw.Do("$.|raw$", c.Universe.Type(types.Name(v))) + case *validators.Identifier: + sw.Do("$.|raw$", c.Universe.Type(types.Name(*v))) + case validators.PrivateVar: + sw.Do("$.|private$", c.Universe.Type(types.Name(v))) + case *validators.PrivateVar: + sw.Do("$.|private$", c.Universe.Type(types.Name(*v))) + case validators.WrapperFunction: + fn, extraArgs := v.Function.SignatureAndArgs() + if len(extraArgs) == 0 { + // If the function to be wrapped has no additional arguments, we can + // just use it directly. + targs := generator.Args{ + "funcName": c.Universe.Type(fn), + } + sw.Do("$.funcName|raw$", targs) + } else { + // If the function to be wrapped has additional arguments, we need + // a "standard signature" validation function to wrap it. + targs := generator.Args{ + "funcName": c.Universe.Type(fn), + "field": mkSymbolArgs(c, fieldPkgSymbols), + "operation": mkSymbolArgs(c, operationPkgSymbols), + "context": mkSymbolArgs(c, contextPkgSymbols), + "objType": v.ObjType, + "objTypePfx": "*", + } + if isNilableType(v.ObjType) { + targs["objTypePfx"] = "" + } + + emitCall := func() { + sw.Do("return $.funcName|raw$", targs) + typeArgs := v.Function.TypeArgs() + if len(typeArgs) > 0 { + sw.Do("[", nil) + for i, typeArg := range typeArgs { + sw.Do("$.|raw$", c.Universe.Type(typeArg)) + if i < len(typeArgs)-1 { + sw.Do(",", nil) + } + } + sw.Do("]", nil) + } + sw.Do("(ctx, op, fldPath, obj, oldObj", targs) + for _, arg := range extraArgs { + sw.Do(", ", nil) + toGolangSourceDataLiteral(sw, c, arg) + } + sw.Do(")", targs) + } + sw.Do("func(", targs) + sw.Do(" ctx $.context.Context|raw$, ", targs) + sw.Do(" op $.operation.Operation|raw$, ", targs) + sw.Do(" fldPath *$.field.Path|raw$, ", targs) + sw.Do(" obj, oldObj $.objTypePfx$$.objType|raw$ ", targs) + sw.Do(") $.field.ErrorList|raw$ {\n", targs) + emitCall() + sw.Do("\n}", targs) + } + case validators.Literal: + sw.Do("$.$", v) + case validators.FunctionLiteral: + sw.Do("func(", nil) + for i, param := range v.Parameters { + if i > 0 { + sw.Do(", ", nil) + } + targs := generator.Args{ + "name": param.Name, + "type": param.Type, + } + sw.Do("$.name$ $.type|raw$", targs) + } + sw.Do(")", nil) + if len(v.Results) > 1 { + sw.Do(" (", nil) + } + for i, ret := range v.Results { + if i > 0 { + sw.Do(", ", nil) + } + targs := generator.Args{ + "name": ret.Name, + "type": ret.Type, + } + sw.Do("$.name$ $.type|raw$", targs) + } + if len(v.Results) > 1 { + sw.Do(")", nil) + } + sw.Do(" { $.$ }", v.Body) + default: + rv := reflect.ValueOf(value) + switch rv.Kind() { + case reflect.Slice, reflect.Array: + arraySize := "" + if rv.Kind() == reflect.Array { + arraySize = strconv.Itoa(rv.Len()) + } + var itemType string + switch rv.Type().Elem().Kind() { + case reflect.String: // For now, only support lists of strings. + itemType = rv.Type().Elem().Name() + default: + panic(fmt.Sprintf("Unsupported extraArg type: %T", value)) + } + rv.Type().Elem() + sw.Do("[$.arraySize$]$.itemType${", map[string]string{"arraySize": arraySize, "itemType": itemType}) + for i := 0; i < rv.Len(); i++ { + val := rv.Index(i) + toGolangSourceDataLiteral(sw, c, val.Interface()) + if i < rv.Len()-1 { + sw.Do(", ", nil) + } + } + sw.Do("}", nil) + default: + // TODO: check this during discovery and emit an error with more useful information + panic(fmt.Sprintf("Unsupported extraArg type: %T", value)) + } + } +} + +// findMemberByFieldName finds the member which matches the specified name. +// The name is expected to be the "JSON name", rather than the Go name. This +// function will descend into embedded types which would appear in JSON to be +// directly in the parent struct. If t is not a struct this does nothing. +// nolint:unused // FIXME: Remove once all validation-gen PRs are merged +func findMemberByFieldName(t *types.Type, name string) (types.Member, bool) { + for _, m := range t.Members { + if jsonTag, found := tags.LookupJSON(m); found { + // If there is a JSON tag of the exact name, use it. + if jsonTag.Name == name { + return m, true + } + // If there is a (non-standard) "inline" tag, look in the type. + if jsonTag.Inline { + return findMemberByFieldName(m.Type, name) + } + } + // If this field was embedded, look in that type. + if m.Embedded { + return findMemberByFieldName(m.Type, name) + } + } + return types.Member{}, false +} + +// isNilableType returns true if the argument type can be compared to nil. +func isNilableType(t *types.Type) bool { + for t.Kind == types.Alias { + t = t.Underlying + } + switch t.Kind { + case types.Pointer, types.Map, types.Slice, types.Interface: // Note: Arrays are not nilable + return true + } + return false +} + +// getLeafTypeAndPrefixes returns the "leaf value type" for a given type, as +// well as type and expression prefix strings for the input type. The type +// prefix can be prepended to the given type's name to produce the nilable form +// of that type. The expression prefix can be prepended to a variable of the +// given type to produce the nilable form of that value. +// +// Example: Given an input type "string" this should produce (string, "*", "&"). +// That is to say: the value-type is "string", which yields "*string" when the +// type prefix is applied, and a variable "x" becomes "&x" when the expression +// prefix is applied. +// +// Example: Given an input type "*string" this should produce (string, "*", ""). +// That is to say: the value-type is "string", which yields "*string" when the +// type prefix is applied, and a variable "x" remains "x" when the expression +// prefix is applied. +func getLeafTypeAndPrefixes(inType *types.Type) (*types.Type, string, string) { + leafType := inType + typePfx := "" + exprPfx := "" + + nPtrs := 0 + for leafType.Kind == types.Pointer { + nPtrs++ + leafType = leafType.Elem + } + if !isNilableType(leafType) { + typePfx = "*" + if nPtrs == 0 { + exprPfx = "&" + } else { + exprPfx = strings.Repeat("*", nPtrs-1) + } + } else { + exprPfx = strings.Repeat("*", nPtrs) + } + + return leafType, typePfx, exprPfx +} + +// FixtureTests generates a test file that checks all validateFalse validations. +func FixtureTests(outputFilename string, testFixtureTags sets.Set[string]) generator.Generator { + return &fixtureTestGen{ + GoGenerator: generator.GoGenerator{ + OutputFilename: outputFilename, + }, + testFixtureTags: testFixtureTags, + } +} + +type fixtureTestGen struct { + generator.GoGenerator + testFixtureTags sets.Set[string] +} + +func (g *fixtureTestGen) Imports(_ *generator.Context) (imports []string) { + return []string{`"testing"`} +} + +func (g *fixtureTestGen) Init(c *generator.Context, w io.Writer) error { + if g.testFixtureTags.Has("validateFalse") { + sw := generator.NewSnippetWriter(w, c, "$", "$") + sw.Do("func TestValidation(t *testing.T) {\n", nil) + sw.Do(" localSchemeBuilder.Test(t).ValidateFixtures()\n", nil) + sw.Do("}\n", nil) + } + return nil +} diff --git a/staging/src/k8s.io/code-generator/cmd/validation-gen/validation_test.go b/staging/src/k8s.io/code-generator/cmd/validation-gen/validation_test.go new file mode 100644 index 00000000000..0d93ccb4ecf --- /dev/null +++ b/staging/src/k8s.io/code-generator/cmd/validation-gen/validation_test.go @@ -0,0 +1,374 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package main + +import ( + "testing" + + "k8s.io/gengo/v2/types" +) + +func TestGetLeafTypeAndPrefixes(t *testing.T) { + stringType := &types.Type{ + Name: types.Name{ + Package: "", + Name: "string", + }, + Kind: types.Builtin, + } + + ptrTo := func(t *types.Type) *types.Type { + return &types.Type{ + Name: types.Name{ + Package: "", + Name: "*" + t.Name.String(), + }, + Kind: types.Pointer, + Elem: t, + } + } + + sliceOf := func(t *types.Type) *types.Type { + return &types.Type{ + Name: types.Name{ + Package: "", + Name: "[]" + t.Name.String(), + }, + Kind: types.Slice, + Elem: t, + } + } + + mapOf := func(t *types.Type) *types.Type { + return &types.Type{ + Name: types.Name{ + Package: "", + Name: "map[string]" + t.Name.String(), + }, + Kind: types.Map, + Key: stringType, + Elem: t, + } + } + + aliasOf := func(name string, t *types.Type) *types.Type { + return &types.Type{ + Name: types.Name{ + Package: "", + Name: "Alias_" + name, + }, + Kind: types.Alias, + Underlying: t, + } + } + + cases := []struct { + in *types.Type + expectedType *types.Type + expectedTypePfx string + expectedExprPfx string + }{{ + // string + in: stringType, + expectedType: stringType, + expectedTypePfx: "*", + expectedExprPfx: "&", + }, { + // *string + in: ptrTo(stringType), + expectedType: stringType, + expectedTypePfx: "*", + expectedExprPfx: "", + }, { + // **string + in: ptrTo(ptrTo(stringType)), + expectedType: stringType, + expectedTypePfx: "*", + expectedExprPfx: "*", + }, { + // ***string + in: ptrTo(ptrTo(ptrTo(stringType))), + expectedType: stringType, + expectedTypePfx: "*", + expectedExprPfx: "**", + }, { + // []string + in: sliceOf(stringType), + expectedType: sliceOf(stringType), + expectedTypePfx: "", + expectedExprPfx: "", + }, { + // *[]string + in: ptrTo(sliceOf(stringType)), + expectedType: sliceOf(stringType), + expectedTypePfx: "", + expectedExprPfx: "*", + }, { + // **[]string + in: ptrTo(ptrTo(sliceOf(stringType))), + expectedType: sliceOf(stringType), + expectedTypePfx: "", + expectedExprPfx: "**", + }, { + // ***[]string + in: ptrTo(ptrTo(ptrTo(sliceOf(stringType)))), + expectedType: sliceOf(stringType), + expectedTypePfx: "", + expectedExprPfx: "***", + }, { + // map[string]string + in: mapOf(stringType), + expectedType: mapOf(stringType), + expectedTypePfx: "", + expectedExprPfx: "", + }, { + // *map[string]string + in: ptrTo(mapOf(stringType)), + expectedType: mapOf(stringType), + expectedTypePfx: "", + expectedExprPfx: "*", + }, { + // **map[string]string + in: ptrTo(ptrTo(mapOf(stringType))), + expectedType: mapOf(stringType), + expectedTypePfx: "", + expectedExprPfx: "**", + }, { + // ***map[string]string + in: ptrTo(ptrTo(ptrTo(mapOf(stringType)))), + expectedType: mapOf(stringType), + expectedTypePfx: "", + expectedExprPfx: "***", + }, { + // alias of string + in: aliasOf("s", stringType), + expectedType: aliasOf("s", stringType), + expectedTypePfx: "*", + expectedExprPfx: "&", + }, { + // alias of *string + in: aliasOf("ps", ptrTo(stringType)), + expectedType: aliasOf("ps", stringType), + expectedTypePfx: "", + expectedExprPfx: "", + }, { + // alias of **string + in: aliasOf("pps", ptrTo(ptrTo(stringType))), + expectedType: aliasOf("pps", stringType), + expectedTypePfx: "", + expectedExprPfx: "", + }, { + // alias of ***string + in: aliasOf("ppps", ptrTo(ptrTo(ptrTo(stringType)))), + expectedType: aliasOf("ppps", stringType), + expectedTypePfx: "", + expectedExprPfx: "", + }, { + // alias of []string + in: aliasOf("ls", sliceOf(stringType)), + expectedType: aliasOf("ls", sliceOf(stringType)), + expectedTypePfx: "", + expectedExprPfx: "", + }, { + // alias of *[]string + in: aliasOf("pls", ptrTo(sliceOf(stringType))), + expectedType: aliasOf("pls", sliceOf(stringType)), + expectedTypePfx: "", + expectedExprPfx: "", + }, { + // alias of **[]string + in: aliasOf("ppls", ptrTo(ptrTo(sliceOf(stringType)))), + expectedType: aliasOf("ppls", sliceOf(stringType)), + expectedTypePfx: "", + expectedExprPfx: "", + }, { + // alias of ***[]string + in: aliasOf("pppls", ptrTo(ptrTo(ptrTo(sliceOf(stringType))))), + expectedType: aliasOf("pppls", sliceOf(stringType)), + expectedTypePfx: "", + expectedExprPfx: "", + }, { + // alias of map[string]string + in: aliasOf("ms", mapOf(stringType)), + expectedType: aliasOf("ms", mapOf(stringType)), + expectedTypePfx: "", + expectedExprPfx: "", + }, { + // alias of *map[string]string + in: aliasOf("pms", ptrTo(mapOf(stringType))), + expectedType: aliasOf("pms", mapOf(stringType)), + expectedTypePfx: "", + expectedExprPfx: "", + }, { + // alias of **map[string]string + in: aliasOf("ppms", ptrTo(ptrTo(mapOf(stringType)))), + expectedType: aliasOf("ppms", mapOf(stringType)), + expectedTypePfx: "", + expectedExprPfx: "", + }, { + // alias of ***map[string]string + in: aliasOf("pppms", ptrTo(ptrTo(ptrTo(mapOf(stringType))))), + expectedType: aliasOf("pppms", mapOf(stringType)), + expectedTypePfx: "", + expectedExprPfx: "", + }, { + // *alias-of-string + in: ptrTo(aliasOf("s", stringType)), + expectedType: aliasOf("s", stringType), + expectedTypePfx: "*", + expectedExprPfx: "", + }, { + // **alias-of-string + in: ptrTo(ptrTo(aliasOf("s", stringType))), + expectedType: aliasOf("s", stringType), + expectedTypePfx: "*", + expectedExprPfx: "*", + }, { + // ***alias-of-string + in: ptrTo(ptrTo(ptrTo(aliasOf("s", stringType)))), + expectedType: aliasOf("s", stringType), + expectedTypePfx: "*", + expectedExprPfx: "**", + }, { + // []alias-of-string + in: sliceOf(aliasOf("s", stringType)), + expectedType: sliceOf(aliasOf("s", stringType)), + expectedTypePfx: "", + expectedExprPfx: "", + }, { + // *[]alias-of-string + in: ptrTo(sliceOf(aliasOf("s", stringType))), + expectedType: sliceOf(aliasOf("s", stringType)), + expectedTypePfx: "", + expectedExprPfx: "*", + }, { + // **[]alias-of-string + in: ptrTo(ptrTo(sliceOf(aliasOf("s", stringType)))), + expectedType: sliceOf(aliasOf("s", stringType)), + expectedTypePfx: "", + expectedExprPfx: "**", + }, { + // ***[]alias-of-string + in: ptrTo(ptrTo(ptrTo(sliceOf(aliasOf("s", stringType))))), + expectedType: sliceOf(aliasOf("s", stringType)), + expectedTypePfx: "", + expectedExprPfx: "***", + }, { + // map[string]alias-of-string + in: mapOf(aliasOf("s", stringType)), + expectedType: mapOf(aliasOf("s", stringType)), + expectedTypePfx: "", + expectedExprPfx: "", + }, { + // *map[string]alias-of-string + in: ptrTo(mapOf(aliasOf("s", stringType))), + expectedType: mapOf(aliasOf("s", stringType)), + expectedTypePfx: "", + expectedExprPfx: "*", + }, { + // **map[string]alias-of-string + in: ptrTo(ptrTo(mapOf(aliasOf("s", stringType)))), + expectedType: mapOf(aliasOf("s", stringType)), + expectedTypePfx: "", + expectedExprPfx: "**", + }, { + // ***map[string]alias-of-string + in: ptrTo(ptrTo(ptrTo(mapOf(aliasOf("s", stringType))))), + expectedType: mapOf(aliasOf("s", stringType)), + expectedTypePfx: "", + expectedExprPfx: "***", + }, { + // *alias-of-*string + in: ptrTo(aliasOf("ps", ptrTo(stringType))), + expectedType: aliasOf("ps", ptrTo(stringType)), + expectedTypePfx: "", + expectedExprPfx: "*", + }, { + // **alias-of-*string + in: ptrTo(ptrTo(aliasOf("ps", ptrTo(stringType)))), + expectedType: aliasOf("ps", ptrTo(stringType)), + expectedTypePfx: "", + expectedExprPfx: "**", + }, { + // ***alias-of-*string + in: ptrTo(ptrTo(ptrTo(aliasOf("ps", ptrTo(stringType))))), + expectedType: aliasOf("ps", ptrTo(stringType)), + expectedTypePfx: "", + expectedExprPfx: "***", + }, { + // []alias-of-*string + in: sliceOf(aliasOf("ps", ptrTo(stringType))), + expectedType: sliceOf(aliasOf("ps", ptrTo(stringType))), + expectedTypePfx: "", + expectedExprPfx: "", + }, { + // *[]alias-of-*string + in: ptrTo(sliceOf(aliasOf("ps", ptrTo(stringType)))), + expectedType: sliceOf(aliasOf("ps", ptrTo(stringType))), + expectedTypePfx: "", + expectedExprPfx: "*", + }, { + // **[]alias-of-*string + in: ptrTo(ptrTo(sliceOf(aliasOf("ps", ptrTo(stringType))))), + expectedType: sliceOf(aliasOf("ps", ptrTo(stringType))), + expectedTypePfx: "", + expectedExprPfx: "**", + }, { + // ***[]alias-of-*string + in: ptrTo(ptrTo(ptrTo(sliceOf(aliasOf("ps", ptrTo(stringType)))))), + expectedType: sliceOf(aliasOf("ps", ptrTo(stringType))), + expectedTypePfx: "", + expectedExprPfx: "***", + }, { + // map[string]alias-of-*string + in: mapOf(aliasOf("ps", ptrTo(stringType))), + expectedType: mapOf(aliasOf("ps", ptrTo(stringType))), + expectedTypePfx: "", + expectedExprPfx: "", + }, { + // *map[string]alias-of-*string + in: ptrTo(mapOf(aliasOf("ps", ptrTo(stringType)))), + expectedType: mapOf(aliasOf("ps", ptrTo(stringType))), + expectedTypePfx: "", + expectedExprPfx: "*", + }, { + // **map[string]alias-of-*string + in: ptrTo(ptrTo(mapOf(aliasOf("ps", ptrTo(stringType))))), + expectedType: mapOf(aliasOf("ps", ptrTo(stringType))), + expectedTypePfx: "", + expectedExprPfx: "**", + }, { + // ***map[string]alias-of-*string + in: ptrTo(ptrTo(ptrTo(mapOf(aliasOf("ps", ptrTo(stringType)))))), + expectedType: mapOf(aliasOf("ps", ptrTo(stringType))), + expectedTypePfx: "", + expectedExprPfx: "***", + }} + + for _, tc := range cases { + leafType, typePfx, exprPfx := getLeafTypeAndPrefixes(tc.in) + if got, want := leafType.Name.String(), tc.expectedType.Name.String(); got != want { + t.Errorf("%q: wrong leaf type: expected %q, got %q", tc.in, want, got) + } + if got, want := typePfx, tc.expectedTypePfx; got != want { + t.Errorf("%q: wrong type prefix: expected %q, got %q", tc.in, want, got) + } + if got, want := exprPfx, tc.expectedExprPfx; got != want { + t.Errorf("%q: wrong expr prefix: expected %q, got %q", tc.in, want, got) + } + } +} diff --git a/staging/src/k8s.io/code-generator/cmd/validation-gen/validators/common.go b/staging/src/k8s.io/code-generator/cmd/validation-gen/validators/common.go new file mode 100644 index 00000000000..d72950a7846 --- /dev/null +++ b/staging/src/k8s.io/code-generator/cmd/validation-gen/validators/common.go @@ -0,0 +1,51 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package validators + +import ( + "k8s.io/gengo/v2/parser/tags" + "k8s.io/gengo/v2/types" +) + +const ( + // libValidationPkg is the pkgpath to our "standard library" of validation + // functions. + libValidationPkg = "k8s.io/apimachinery/pkg/api/validate" +) + +func getMemberByJSON(t *types.Type, jsonName string) *types.Member { + for i := range t.Members { + if jsonTag, ok := tags.LookupJSON(t.Members[i]); ok { + if jsonTag.Name == jsonName { + return &t.Members[i] + } + } + } + return nil +} + +// isNilableType returns true if the argument type can be compared to nil. +func isNilableType(t *types.Type) bool { + for t.Kind == types.Alias { + t = t.Underlying + } + switch t.Kind { + case types.Pointer, types.Map, types.Slice, types.Interface: // Note: Arrays are not nilable + return true + } + return false +} diff --git a/staging/src/k8s.io/code-generator/cmd/validation-gen/validators/registry.go b/staging/src/k8s.io/code-generator/cmd/validation-gen/validators/registry.go new file mode 100644 index 00000000000..1fabcf3833a --- /dev/null +++ b/staging/src/k8s.io/code-generator/cmd/validation-gen/validators/registry.go @@ -0,0 +1,240 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package validators + +import ( + "cmp" + "fmt" + "slices" + "sort" + "sync" + "sync/atomic" + + "k8s.io/gengo/v2" + "k8s.io/gengo/v2/generator" +) + +// This is the global registry of tag validators. For simplicity this is in +// the same package as the implementations, but it should not be used directly. +var globalRegistry = ®istry{ + tagValidators: map[string]TagValidator{}, +} + +// registry holds a list of registered tags. +type registry struct { + lock sync.Mutex + initialized atomic.Bool // init() was called + + tagValidators map[string]TagValidator // keyed by tagname + tagIndex []string // all tag names + + typeValidators []TypeValidator +} + +func (reg *registry) addTagValidator(tv TagValidator) { + if reg.initialized.Load() { + panic("registry was modified after init") + } + + reg.lock.Lock() + defer reg.lock.Unlock() + + name := tv.TagName() + if _, exists := globalRegistry.tagValidators[name]; exists { + panic(fmt.Sprintf("tag %q was registered twice", name)) + } + globalRegistry.tagValidators[name] = tv +} + +func (reg *registry) addTypeValidator(tv TypeValidator) { + if reg.initialized.Load() { + panic("registry was modified after init") + } + + reg.lock.Lock() + defer reg.lock.Unlock() + + globalRegistry.typeValidators = append(globalRegistry.typeValidators, tv) +} + +func (reg *registry) init(c *generator.Context) { + if reg.initialized.Load() { + panic("registry.init() was called twice") + } + + reg.lock.Lock() + defer reg.lock.Unlock() + + cfg := Config{ + GengoContext: c, + Validator: reg, + } + + for _, tv := range globalRegistry.tagValidators { + reg.tagIndex = append(reg.tagIndex, tv.TagName()) + tv.Init(cfg) + } + sort.Strings(reg.tagIndex) + + for _, tv := range reg.typeValidators { + tv.Init(cfg) + } + slices.SortFunc(reg.typeValidators, func(a, b TypeValidator) int { + return cmp.Compare(a.Name(), b.Name()) + }) + + reg.initialized.Store(true) +} + +// ExtractValidations considers the given context (e.g. a type definition) and +// evaluates registered validators. This includes type validators (which run +// against all types) and tag validators which run only if a specific tag is +// found in the associated comment block. Any matching validators produce zero +// or more validations, which will later be rendered by the code-generation +// logic. +func (reg *registry) ExtractValidations(context Context, comments []string) (Validations, error) { + if !reg.initialized.Load() { + panic("registry.init() was not called") + } + + validations := Validations{} + + // Extract tags and run matching tag-validators first. + tags, err := gengo.ExtractFunctionStyleCommentTags("+", reg.tagIndex, comments) + if err != nil { + return Validations{}, fmt.Errorf("failed to parse tags: %w", err) + } + phases := reg.sortTagsIntoPhases(tags) + for _, idx := range phases { + for _, tag := range idx { + vals := tags[tag] + tv := reg.tagValidators[tag] + if scopes := tv.ValidScopes(); !scopes.Has(context.Scope) && !scopes.Has(ScopeAny) { + return Validations{}, fmt.Errorf("tag %q cannot be specified on %s", tv.TagName(), context.Scope) + } + for _, val := range vals { // tags may have multiple values + if theseValidations, err := tv.GetValidations(context, val.Args, val.Value); err != nil { + return Validations{}, fmt.Errorf("tag %q: %w", tv.TagName(), err) + } else { + validations.Add(theseValidations) + } + } + } + } + + // Run type-validators after tag validators are done. + if context.Scope == ScopeType { + // Run all type-validators. + for _, tv := range reg.typeValidators { + if theseValidations, err := tv.GetValidations(context); err != nil { + return Validations{}, fmt.Errorf("type validator %q: %w", tv.Name(), err) + } else { + validations.Add(theseValidations) + } + } + } + + return validations, nil +} + +func (reg *registry) sortTagsIntoPhases(tags map[string][]gengo.Tag) [][]string { + // First sort all tags by their name, so the final output is deterministic. + // + // It makes more sense to sort here, rather than when emitting because: + // + // Consider a type or field with the following comments: + // + // // +k8s:validateFalse="111" + // // +k8s:validateFalse="222" + // // +k8s:ifOptionEnabled(Foo)=+k8s:validateFalse="333" + // + // Tag extraction will retain the relative order between 111 and 222, but + // 333 is extracted as tag "k8s:ifOptionEnabled". Those are all in a map, + // which we iterate (in a random order). When it reaches the emit stage, + // the "ifOptionEnabled" part is gone, and we will have 3 functionGen + // objects, all with tag "k8s:validateFalse", in a non-deterministic order + // because of the map iteration. If we sort them at that point, we won't + // have enough information to do something smart, unless we look at the + // args, which are opaque to us. + // + // Sorting it earlier means we can sort "k8s:ifOptionEnabled" against + // "k8s:validateFalse". All of the records within each of those is + // relatively ordered, so the result here would be to put "ifOptionEnabled" + // before "validateFalse" (lexicographical is better than random). + sortedTags := []string{} + for tag := range tags { + sortedTags = append(sortedTags, tag) + } + sort.Strings(sortedTags) + + // Now split them into phases. + phase0 := []string{} // regular tags + phase1 := []string{} // "late" tags + for _, tn := range sortedTags { + tv := reg.tagValidators[tn] + if _, ok := tv.(LateTagValidator); ok { + phase1 = append(phase1, tn) + } else { + phase0 = append(phase0, tn) + } + } + return [][]string{phase0, phase1} +} + +// Docs returns documentation for each tag in this registry. +func (reg *registry) Docs() []TagDoc { + var result []TagDoc + for _, k := range reg.tagIndex { + v := reg.tagValidators[k] + result = append(result, v.Docs()) + } + return result +} + +// RegisterTagValidator must be called by any validator which wants to run when +// a specific tag is found. +func RegisterTagValidator(tv TagValidator) { + globalRegistry.addTagValidator(tv) +} + +// RegisterTypeValidator must be called by any validator which wants to run +// against every type definition. +func RegisterTypeValidator(tv TypeValidator) { + globalRegistry.addTypeValidator(tv) +} + +// Validator represents an aggregation of validator plugins. +type Validator interface { + // ExtractValidations considers the given context (e.g. a type definition) and + // evaluates registered validators. This includes type validators (which run + // against all types) and tag validators which run only if a specific tag is + // found in the associated comment block. Any matching validators produce zero + // or more validations, which will later be rendered by the code-generation + // logic. + ExtractValidations(context Context, comments []string) (Validations, error) + + // Docs returns documentation for each known tag. + Docs() []TagDoc +} + +// InitGlobalValidator must be called exactly once by the main application to +// initialize and safely access the global tag registry. Once this is called, +// no more validators may be registered. +func InitGlobalValidator(c *generator.Context) Validator { + globalRegistry.init(c) + return globalRegistry +} diff --git a/staging/src/k8s.io/code-generator/cmd/validation-gen/validators/validators.go b/staging/src/k8s.io/code-generator/cmd/validation-gen/validators/validators.go new file mode 100644 index 00000000000..4ad16b903f5 --- /dev/null +++ b/staging/src/k8s.io/code-generator/cmd/validation-gen/validators/validators.go @@ -0,0 +1,443 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package validators + +import ( + "k8s.io/apimachinery/pkg/util/sets" + "k8s.io/apimachinery/pkg/util/validation/field" + "k8s.io/gengo/v2/generator" + "k8s.io/gengo/v2/types" +) + +// TagValidator describes a single validation tag and how to use it. +type TagValidator interface { + // Init initializes the implementation. This will be called exactly once. + Init(cfg Config) + + // TagName returns the full tag name (without the "marker" prefix) for this + // tag. + TagName() string + + // ValidScopes returns the set of scopes where this tag may be used. + ValidScopes() sets.Set[Scope] + + // GetValidations returns any validations described by this tag. + GetValidations(context Context, args []string, payload string) (Validations, error) + + // Docs returns user-facing documentation for this tag. + Docs() TagDoc +} + +// LateTagValidator is an optional extension to TagValidator. Any TagValidator +// which implements this interface will be evaluated after all TagValidators +// which do not. +type LateTagValidator interface { + LateTagValidator() +} + +// TypeValidator describes a validator which runs on every type definition. +type TypeValidator interface { + // Init initializes the implementation. This will be called exactly once. + Init(cfg Config) + + // Name returns a unique name for this validator. This is used for sorting + // and logging. + Name() string + + // GetValidations returns any validations imposed by this validator for the + // given context. + // + // The way gengo handles type definitions varies between structs and other + // types. For struct definitions (e.g. `type Foo struct {}`), the realType + // is the struct itself (the Kind field will be `types.Struct`) and the + // parentType will be nil. For other types (e.g. `type Bar string`), the + // realType will be the underlying type and the parentType will be the + // newly defined type (the Kind field will be `types.Alias`). + GetValidations(context Context) (Validations, error) +} + +// Config carries optional configuration information for use by validators. +type Config struct { + // GengoContext provides gengo's generator Context. This allows validators + // to look up all sorts of other information. + GengoContext *generator.Context + + // Validator provides a way to compose validations. + // + // For example, it is possible to define a validation such as + // "+myValidator=+format=IP" by using the registry to extract the + // validation for the embedded "+format=IP" and use those to + // create the final Validations returned by the "+myValidator" tag. + // + // This field MUST NOT be used during init, since other validators may not + // be initialized yet. + Validator Validator +} + +// Scope describes where a validation (or potential validation) is located. +type Scope string + +// Note: All of these values should be strings which can be used in an error +// message such as "may not be used in %s". +const ( + // ScopeAny indicates that a validator may be use in any context. This value + // should never appear in a Context struct, since that indicates a + // specific use. + ScopeAny Scope = "anywhere" + + // ScopeType indicates a validation on a type definition, which applies to + // all instances of that type. + ScopeType Scope = "type definitions" + + // ScopeField indicates a validation on a particular struct field, which + // applies only to that field of that struct. + ScopeField Scope = "struct fields" + + // ScopeListVal indicates a validation which applies to all elements of a + // list field or type. + ScopeListVal Scope = "list values" + + // ScopeMapKey indicates a validation which applies to all keys of a map + // field or type. + ScopeMapKey Scope = "map keys" + + // ScopeMapVal indicates a validation which applies to all values of a map + // field or type. + ScopeMapVal Scope = "map values" + + // TODO: It's not clear if we need to distinguish (e.g.) list values of + // fields from list values of typedefs. We could make {type,field} be + // orthogonal to {scalar, list, list-value, map, map-key, map-value} (and + // maybe even pointers?), but that seems like extra work that is not needed + // for now. +) + +// Context describes where a tag was used, so that the scope can be checked +// and so validators can handle different cases if they need. +type Context struct { + // Scope is where the validation is being considered. + Scope Scope + + // Type provides details about the type being validated. When Scope is + // ScopeType, this is the underlying type. When Scope is ScopeField, this + // is the field's type (including any pointerness). When Scope indicates a + // list-value, map-key, or map-value, this is the type of that key or + // value. + Type *types.Type + + // Parent provides details about the logical parent type of the type being + // validated, when applicable. When Scope is ScopeType, this is the + // newly-defined type (when it exists - gengo handles struct-type + // definitions differently that other "alias" type definitions). When + // Scope is ScopeField, this is the field's parent struct's type. When + // Scope indicates a list-value, map-key, or map-value, this is the type of + // the whole list or map. + // + // Because of how gengo handles struct-type definitions, this field may be + // nil in those cases. + Parent *types.Type + + // Member provides details about a field within a struct, when Scope is + // ScopeField. For all other values of Scope, this will be nil. + Member *types.Member + + // Path provides the field path to the type or field being validated. This + // is useful for identifying an exact context, e.g. to track information + // between related tags. + Path *field.Path +} + +// TagDoc describes a comment-tag and its usage. +type TagDoc struct { + // Tag is the tag name, without the leading '+'. + Tag string + // Args lists any arguments this tag might take. + Args []TagArgDoc + // Usage is how the tag is used, including arguments. + Usage string + // Description is a short description of this tag's purpose. + Description string + // Docs is a human-oriented string explaining this tag. + Docs string + // Scopes lists the place or places this tag may be used. + Scopes []Scope + // Payloads lists zero or more varieties of value for this tag. If this tag + // never has a payload, this list should be empty, but if the payload is + // optional, this list should include an entry for "". + Payloads []TagPayloadDoc +} + +// TagArgDoc describes an argument for a tag (e.g. `+tagName(tagArg)`. +type TagArgDoc struct { + // Description is a short description of this arg (e.g. ``). + Description string +} + +// TagPayloadDoc describes a value for a tag (e.g. `+tagName=tagValue`). Some +// tags upport multiple payloads, including (e.g. `+tagName`). +type TagPayloadDoc struct { + // Description is a short description of this payload (e.g. ``). + Description string + // Docs is a human-orientd string explaining this payload. + Docs string + // Schema details a JSON payload's contents. + Schema []TagPayloadSchema +} + +// TagPayloadSchema describes a JSON tag payload. +type TagPayloadSchema struct { + Key string + Value string + Docs string + Default string +} + +// Validations defines the function calls and variables to generate to perform validation. +type Validations struct { + Functions []FunctionGen + Variables []VariableGen + Comments []string + OpaqueType bool + OpaqueKeyType bool + OpaqueValType bool +} + +func (v *Validations) Empty() bool { + return v.Len() == 0 +} + +func (v *Validations) Len() int { + return len(v.Functions) + len(v.Variables) + len(v.Comments) +} + +func (v *Validations) AddFunction(f FunctionGen) { + v.Functions = append(v.Functions, f) +} + +func (v *Validations) AddVariable(variable VariableGen) { + v.Variables = append(v.Variables, variable) +} + +func (v *Validations) AddComment(comment string) { + v.Comments = append(v.Comments, comment) +} + +func (v *Validations) Add(o Validations) { + v.Functions = append(v.Functions, o.Functions...) + v.Variables = append(v.Variables, o.Variables...) + v.Comments = append(v.Comments, o.Comments...) + v.OpaqueType = v.OpaqueType || o.OpaqueType + v.OpaqueKeyType = v.OpaqueKeyType || o.OpaqueKeyType + v.OpaqueValType = v.OpaqueValType || o.OpaqueValType +} + +// FunctionFlags define optional properties of a validator. Most validators +// can just use DefaultFlags. +type FunctionFlags uint32 + +// IsSet returns true if all of the wanted flags are set. +func (ff FunctionFlags) IsSet(wanted FunctionFlags) bool { + return (ff & wanted) == wanted +} + +const ( + // DefaultFlags is defined for clarity. + DefaultFlags FunctionFlags = 0 + + // ShortCircuit indicates that further validations should be skipped if + // this validator fails. Most validators are not fatal. + ShortCircuit FunctionFlags = 1 << iota + + // NonError indicates that a failure of this validator should not be + // accumulated as an error, but should trigger other aspects of the failure + // path (e.g. early return when combined with ShortCircuit). + NonError +) + +// FunctionGen provides validation-gen with the information needed to generate a +// validation function invocation. +type FunctionGen interface { + // TagName returns the tag which triggers this validator. + TagName() string + + // SignatureAndArgs returns the function name and all extraArg value literals that are passed when the function + // invocation is generated. + // + // The function signature must be of the form: + // func(op operation.Operation, + // fldPath field.Path, + // value, oldValue , // always nilable + // extraArgs[0] , // optional + // ..., + // extraArgs[N] ) + // + // extraArgs may contain: + // - data literals comprised of maps, slices, strings, ints, floats and bools + // - references, represented by types.Type (to reference any type in the universe), and types.Member (to reference members of the current value) + // + // If validation function to be called does not have a signature of this form, please introduce + // a function that does and use that function to call the validation function. + SignatureAndArgs() (function types.Name, extraArgs []any) + + // TypeArgs assigns types to the type parameters of the function, for invocation. + TypeArgs() []types.Name + + // Flags returns the options for this validator function. + Flags() FunctionFlags + + // Conditions returns the conditions that must true for a resource to be + // validated by this function. + Conditions() Conditions +} + +// Conditions defines what conditions must be true for a resource to be validated. +// If any of the conditions are not true, the resource is not validated. +type Conditions struct { + // OptionEnabled specifies an option name that must be set to true for the condition to be true. + OptionEnabled string + + // OptionDisabled specifies an option name that must be set to false for the condition to be true. + OptionDisabled string +} + +func (c Conditions) Empty() bool { + return len(c.OptionEnabled) == 0 && len(c.OptionDisabled) == 0 +} + +// Identifier is a name that the generator will output as an identifier. +// Identifiers are generated using the RawNamer strategy. +type Identifier types.Name + +// PrivateVar is a variable name that the generator will output as a private identifier. +// PrivateVars are generated using the PrivateNamer strategy. +type PrivateVar types.Name + +// VariableGen provides validation-gen with the information needed to generate variable. +// Variables typically support generated functions by providing static information such +// as the list of supported symbols for an enum. +type VariableGen interface { + // TagName returns the tag which triggers this validator. + TagName() string + + // Var returns the variable identifier. + Var() PrivateVar + + // Init generates the function call that the variable is assigned to. + Init() FunctionGen +} + +// Function creates a FunctionGen for a given function name and extraArgs. +func Function(tagName string, flags FunctionFlags, function types.Name, extraArgs ...any) FunctionGen { + return GenericFunction(tagName, flags, function, nil, extraArgs...) +} + +func GenericFunction(tagName string, flags FunctionFlags, function types.Name, typeArgs []types.Name, extraArgs ...any) FunctionGen { + // Callers of Signature don't care if the args are all of a known type, it just + // makes it easier to declare validators. + var anyArgs []any + if len(extraArgs) > 0 { + anyArgs = make([]any, len(extraArgs)) + copy(anyArgs, extraArgs) + } + return &functionGen{tagName: tagName, flags: flags, function: function, extraArgs: anyArgs, typeArgs: typeArgs} +} + +func WithCondition(fn FunctionGen, conditions Conditions) FunctionGen { + name, args := fn.SignatureAndArgs() + return &functionGen{ + tagName: fn.TagName(), flags: fn.Flags(), function: name, extraArgs: args, typeArgs: fn.TypeArgs(), + conditions: conditions, + } +} + +type functionGen struct { + tagName string + function types.Name + extraArgs []any + typeArgs []types.Name + flags FunctionFlags + conditions Conditions +} + +func (v *functionGen) TagName() string { + return v.tagName +} + +func (v *functionGen) SignatureAndArgs() (function types.Name, args []any) { + return v.function, v.extraArgs +} + +func (v *functionGen) TypeArgs() []types.Name { return v.typeArgs } + +func (v *functionGen) Flags() FunctionFlags { + return v.flags +} + +func (v *functionGen) Conditions() Conditions { return v.conditions } + +// Variable creates a VariableGen for a given function name and extraArgs. +func Variable(variable PrivateVar, init FunctionGen) VariableGen { + return &variableGen{ + variable: variable, + init: init, + } +} + +type variableGen struct { + variable PrivateVar + init FunctionGen +} + +func (v variableGen) TagName() string { + return v.init.TagName() +} + +func (v variableGen) Var() PrivateVar { + return v.variable +} + +func (v variableGen) Init() FunctionGen { + return v.init +} + +// WrapperFunction describes a function literal which has the fingerprint of a +// regular validation function (op, fldPath, obj, oldObj) and calls another +// validation function with the same signature, plus extra args if needed. +type WrapperFunction struct { + Function FunctionGen + ObjType *types.Type +} + +// Literal is a literal value that, when used as an argument to a validator, +// will be emitted without any further interpretation. Use this with caution, +// it will not be subject to Namers. +type Literal string + +// FunctionLiteral describes a function-literal expression that can be used as +// an argument to a validator. Unlike WrapperFunction, this does not +// necessarily have the same signature as a regular validation function. +type FunctionLiteral struct { + Parameters []ParamResult + Results []ParamResult + Body string +} + +// ParamResult represents a parameter or a result of a function. +type ParamResult struct { + Name string + Type *types.Type +}