Merge pull request #121651 from jiahuif-forks/fix/cel/type-resolver-safe-guard

CEL type resolvers: avoid infinite recursion for type resolvers.
This commit is contained in:
Kubernetes Prow Robot 2023-10-31 21:50:37 +01:00 committed by GitHub
commit 3631efd85c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 58 additions and 27 deletions

View File

@ -38,7 +38,7 @@ type combinedSchemaResolver struct {
// If the DefinitionsSchemaResolver knows the gvk, the DefinitionsSchemaResolver handles the resolution, // If the DefinitionsSchemaResolver knows the gvk, the DefinitionsSchemaResolver handles the resolution,
// otherwise, the secondary does. // otherwise, the secondary does.
func (r *combinedSchemaResolver) ResolveSchema(gvk schema.GroupVersionKind) (*spec.Schema, error) { func (r *combinedSchemaResolver) ResolveSchema(gvk schema.GroupVersionKind) (*spec.Schema, error) {
if _, ok := r.definitions.gvkToSchema[gvk]; ok { if _, ok := r.definitions.gvkToRef[gvk]; ok {
return r.definitions.ResolveSchema(gvk) return r.definitions.ResolveSchema(gvk)
} }
return r.secondary.ResolveSchema(gvk) return r.secondary.ResolveSchema(gvk)

View File

@ -30,7 +30,7 @@ import (
// by looking up the OpenAPI definitions. // by looking up the OpenAPI definitions.
type DefinitionsSchemaResolver struct { type DefinitionsSchemaResolver struct {
defs map[string]common.OpenAPIDefinition defs map[string]common.OpenAPIDefinition
gvkToSchema map[schema.GroupVersionKind]*spec.Schema gvkToRef map[schema.GroupVersionKind]string
} }
// NewDefinitionsSchemaResolver creates a new DefinitionsSchemaResolver. // NewDefinitionsSchemaResolver creates a new DefinitionsSchemaResolver.
@ -38,31 +38,30 @@ type DefinitionsSchemaResolver struct {
// getDefinitions = "k8s.io/kubernetes/pkg/generated/openapi".GetOpenAPIDefinitions // getDefinitions = "k8s.io/kubernetes/pkg/generated/openapi".GetOpenAPIDefinitions
// scheme = "k8s.io/client-go/kubernetes/scheme".Scheme // scheme = "k8s.io/client-go/kubernetes/scheme".Scheme
func NewDefinitionsSchemaResolver(getDefinitions common.GetOpenAPIDefinitions, schemes ...*runtime.Scheme) *DefinitionsSchemaResolver { func NewDefinitionsSchemaResolver(getDefinitions common.GetOpenAPIDefinitions, schemes ...*runtime.Scheme) *DefinitionsSchemaResolver {
gvkToSchema := make(map[schema.GroupVersionKind]*spec.Schema) gvkToRef := make(map[schema.GroupVersionKind]string)
namer := openapi.NewDefinitionNamer(schemes...) namer := openapi.NewDefinitionNamer(schemes...)
defs := getDefinitions(func(path string) spec.Ref { defs := getDefinitions(func(path string) spec.Ref {
return spec.MustCreateRef(path) return spec.MustCreateRef(path)
}) })
for name, def := range defs { for name := range defs {
_, e := namer.GetDefinitionName(name) _, e := namer.GetDefinitionName(name)
gvks := extensionsToGVKs(e) gvks := extensionsToGVKs(e)
s := def.Schema // map value not addressable, make copy
for _, gvk := range gvks { for _, gvk := range gvks {
gvkToSchema[gvk] = &s gvkToRef[gvk] = name
} }
} }
return &DefinitionsSchemaResolver{ return &DefinitionsSchemaResolver{
gvkToSchema: gvkToSchema, gvkToRef: gvkToRef,
defs: defs, defs: defs,
} }
} }
func (d *DefinitionsSchemaResolver) ResolveSchema(gvk schema.GroupVersionKind) (*spec.Schema, error) { func (d *DefinitionsSchemaResolver) ResolveSchema(gvk schema.GroupVersionKind) (*spec.Schema, error) {
s, ok := d.gvkToSchema[gvk] ref, ok := d.gvkToRef[gvk]
if !ok { if !ok {
return nil, fmt.Errorf("cannot resolve %v: %w", gvk, ErrSchemaNotFound) return nil, fmt.Errorf("cannot resolve %v: %w", gvk, ErrSchemaNotFound)
} }
s, err := populateRefs(func(ref string) (*spec.Schema, bool) { s, err := PopulateRefs(func(ref string) (*spec.Schema, bool) {
// find the schema by the ref string, and return a deep copy // find the schema by the ref string, and return a deep copy
def, ok := d.defs[ref] def, ok := d.defs[ref]
if !ok { if !ok {
@ -70,7 +69,7 @@ func (d *DefinitionsSchemaResolver) ResolveSchema(gvk schema.GroupVersionKind) (
} }
s := def.Schema s := def.Schema
return &s, true return &s, true
}, s) }, ref)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -53,34 +53,34 @@ func (r *ClientDiscoveryResolver) ResolveSchema(gvk schema.GroupVersionKind) (*s
if err != nil { if err != nil {
return nil, err return nil, err
} }
s, err := resolveType(resp, gvk) ref, err := resolveRef(resp, gvk)
if err != nil { if err != nil {
return nil, err return nil, err
} }
s, err = populateRefs(func(ref string) (*spec.Schema, bool) { s, err := PopulateRefs(func(ref string) (*spec.Schema, bool) {
s, ok := resp.Components.Schemas[strings.TrimPrefix(ref, refPrefix)] s, ok := resp.Components.Schemas[strings.TrimPrefix(ref, refPrefix)]
return s, ok return s, ok
}, s) }, ref)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return s, nil return s, nil
} }
func resolveType(resp *schemaResponse, gvk schema.GroupVersionKind) (*spec.Schema, error) { func resolveRef(resp *schemaResponse, gvk schema.GroupVersionKind) (string, error) {
for _, s := range resp.Components.Schemas { for ref, s := range resp.Components.Schemas {
var gvks []schema.GroupVersionKind var gvks []schema.GroupVersionKind
err := s.Extensions.GetObject(extGVK, &gvks) err := s.Extensions.GetObject(extGVK, &gvks)
if err != nil { if err != nil {
return nil, err return "", err
} }
for _, g := range gvks { for _, g := range gvks {
if g == gvk { if g == gvk {
return s, nil return ref, nil
} }
} }
} }
return nil, fmt.Errorf("cannot resolve group version kind %q: %w", gvk, ErrSchemaNotFound) return "", fmt.Errorf("cannot resolve group version kind %q: %w", gvk, ErrSchemaNotFound)
} }
func resourcePathFromGV(gv schema.GroupVersion) string { func resourcePathFromGV(gv schema.GroupVersion) string {

View File

@ -19,19 +19,41 @@ package resolver
import ( import (
"fmt" "fmt"
"k8s.io/apimachinery/pkg/util/sets"
"k8s.io/kube-openapi/pkg/validation/spec" "k8s.io/kube-openapi/pkg/validation/spec"
) )
// populateRefs recursively replaces Refs in the schema with the referred one. // PopulateRefs recursively replaces Refs in the schema with the referred one.
// schemaOf is the callback to find the corresponding schema by the ref. // schemaOf is the callback to find the corresponding schema by the ref.
// This function will not mutate the original schema. If the schema needs to be // This function will not mutate the original schema. If the schema needs to be
// mutated, a copy will be returned, otherwise it returns the original schema. // mutated, a copy will be returned, otherwise it returns the original schema.
func populateRefs(schemaOf func(ref string) (*spec.Schema, bool), schema *spec.Schema) (*spec.Schema, error) { func PopulateRefs(schemaOf func(ref string) (*spec.Schema, bool), rootRef string) (*spec.Schema, error) {
visitedRefs := sets.New[string]()
rootSchema, ok := schemaOf(rootRef)
visitedRefs.Insert(rootRef)
if !ok {
return nil, fmt.Errorf("internal error: cannot resolve Ref for root schema %q: %w", rootRef, ErrSchemaNotFound)
}
return populateRefs(schemaOf, visitedRefs, rootSchema)
}
func populateRefs(schemaOf func(ref string) (*spec.Schema, bool), visited sets.Set[string], schema *spec.Schema) (*spec.Schema, error) {
result := *schema result := *schema
changed := false changed := false
ref, isRef := refOf(schema) ref, isRef := refOf(schema)
if isRef { if isRef {
if visited.Has(ref) {
return &spec.Schema{
// for circular ref, return an empty object as placeholder
SchemaProps: spec.SchemaProps{Type: []string{"object"}},
}, nil
}
visited.Insert(ref)
// restore visited state at the end of the recursion.
defer func() {
visited.Delete(ref)
}()
// replace the whole schema with the referred one. // replace the whole schema with the referred one.
resolved, ok := schemaOf(ref) resolved, ok := schemaOf(ref)
if !ok { if !ok {
@ -44,7 +66,7 @@ func populateRefs(schemaOf func(ref string) (*spec.Schema, bool), schema *spec.S
props := make(map[string]spec.Schema, len(schema.Properties)) props := make(map[string]spec.Schema, len(schema.Properties))
propsChanged := false propsChanged := false
for name, prop := range result.Properties { for name, prop := range result.Properties {
populated, err := populateRefs(schemaOf, &prop) populated, err := populateRefs(schemaOf, visited, &prop)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -58,7 +80,7 @@ func populateRefs(schemaOf func(ref string) (*spec.Schema, bool), schema *spec.S
result.Properties = props result.Properties = props
} }
if result.AdditionalProperties != nil && result.AdditionalProperties.Schema != nil { if result.AdditionalProperties != nil && result.AdditionalProperties.Schema != nil {
populated, err := populateRefs(schemaOf, result.AdditionalProperties.Schema) populated, err := populateRefs(schemaOf, visited, result.AdditionalProperties.Schema)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -69,7 +91,7 @@ func populateRefs(schemaOf func(ref string) (*spec.Schema, bool), schema *spec.S
} }
// schema is a list, populate its items // schema is a list, populate its items
if result.Items != nil && result.Items.Schema != nil { if result.Items != nil && result.Items.Schema != nil {
populated, err := populateRefs(schemaOf, result.Items.Schema) populated, err := populateRefs(schemaOf, visited, result.Items.Schema)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -36,6 +36,7 @@ import (
storagev1 "k8s.io/api/storage/v1" storagev1 "k8s.io/api/storage/v1"
apiextensionsv1 "k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1" apiextensionsv1 "k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1"
extclientset "k8s.io/apiextensions-apiserver/pkg/client/clientset/clientset" extclientset "k8s.io/apiextensions-apiserver/pkg/client/clientset/clientset"
apiextensionsscheme "k8s.io/apiextensions-apiserver/pkg/client/clientset/clientset/scheme"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
"k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/runtime"
@ -79,7 +80,7 @@ func TestTypeResolver(t *testing.T) {
} }
}(crd) }(crd)
discoveryResolver := &resolver.ClientDiscoveryResolver{Discovery: client.Discovery()} discoveryResolver := &resolver.ClientDiscoveryResolver{Discovery: client.Discovery()}
definitionsResolver := resolver.NewDefinitionsSchemaResolver(openapi.GetOpenAPIDefinitions, k8sscheme.Scheme) definitionsResolver := resolver.NewDefinitionsSchemaResolver(openapi.GetOpenAPIDefinitions, k8sscheme.Scheme, apiextensionsscheme.Scheme)
// wait until the CRD schema is published at the OpenAPI v3 endpoint // wait until the CRD schema is published at the OpenAPI v3 endpoint
err = wait.PollImmediate(time.Second, time.Minute, func() (done bool, err error) { err = wait.PollImmediate(time.Second, time.Minute, func() (done bool, err error) {
p, err := client.OpenAPIV3().Paths() p, err := client.OpenAPIV3().Paths()
@ -330,7 +331,7 @@ func TestBuiltinResolution(t *testing.T) {
}{ }{
{ {
name: "definitions", name: "definitions",
resolver: resolver.NewDefinitionsSchemaResolver(openapi.GetOpenAPIDefinitions, k8sscheme.Scheme), resolver: resolver.NewDefinitionsSchemaResolver(openapi.GetOpenAPIDefinitions, k8sscheme.Scheme, apiextensionsscheme.Scheme),
scheme: buildTestScheme(), scheme: buildTestScheme(),
}, },
{ {
@ -354,6 +355,14 @@ func TestBuiltinResolution(t *testing.T) {
if gvk.Version == "__internal" { if gvk.Version == "__internal" {
continue continue
} }
// apiextensions.k8s.io/v1beta1 not published
if tc.name == "discovery" && gvk.Group == "apiextensions.k8s.io" && gvk.Version == "v1beta1" {
continue
}
// apiextensions.k8s.io ConversionReview not published
if tc.name == "discovery" && gvk.Group == "apiextensions.k8s.io" && gvk.Kind == "ConversionReview" {
continue
}
_, err = tc.resolver.ResolveSchema(gvk) _, err = tc.resolver.ResolveSchema(gvk)
if err != nil { if err != nil {
t.Errorf("resolver %q cannot resolve %v", tc.name, gvk) t.Errorf("resolver %q cannot resolve %v", tc.name, gvk)
@ -506,5 +515,6 @@ func buildTestScheme() *runtime.Scheme {
_ = networkingv1.AddToScheme(scheme) _ = networkingv1.AddToScheme(scheme)
_ = nodev1.AddToScheme(scheme) _ = nodev1.AddToScheme(scheme)
_ = storagev1.AddToScheme(scheme) _ = storagev1.AddToScheme(scheme)
_ = apiextensionsscheme.AddToScheme(scheme)
return scheme return scheme
} }