diff --git a/pkg/conversion/converter.go b/pkg/conversion/converter.go index 9e3a6385d5d..9f3cdc3a94f 100644 --- a/pkg/conversion/converter.go +++ b/pkg/conversion/converter.go @@ -40,7 +40,8 @@ type DebugLogger interface { type Converter struct { // Map from the conversion pair to a function which can // do the conversion. - conversionFuncs map[typePair]reflect.Value + conversionFuncs map[typePair]reflect.Value + generatedConversionFuncs map[typePair]reflect.Value // This is a map from a source field type and name, to a list of destination // field type and name. @@ -72,11 +73,12 @@ type Converter struct { // NewConverter creates a new Converter object. func NewConverter() *Converter { c := &Converter{ - conversionFuncs: map[typePair]reflect.Value{}, - defaultingFuncs: map[reflect.Type]reflect.Value{}, - nameFunc: func(t reflect.Type) string { return t.Name() }, - structFieldDests: map[typeNamePair][]typeNamePair{}, - structFieldSources: map[typeNamePair][]typeNamePair{}, + conversionFuncs: map[typePair]reflect.Value{}, + generatedConversionFuncs: map[typePair]reflect.Value{}, + defaultingFuncs: map[reflect.Type]reflect.Value{}, + nameFunc: func(t reflect.Type) string { return t.Name() }, + structFieldDests: map[typeNamePair][]typeNamePair{}, + structFieldSources: map[typeNamePair][]typeNamePair{}, inputFieldMappingFuncs: map[reflect.Type]FieldMappingFunc{}, inputDefaultFlags: map[reflect.Type]FieldMatchingFlags{}, @@ -238,20 +240,8 @@ func (s *scope) error(message string, args ...interface{}) error { return fmt.Errorf(where+message, args...) } -// RegisterConversionFunc registers a conversion func with the -// Converter. conversionFunc must take three parameters: a pointer to the input -// type, a pointer to the output type, and a conversion.Scope (which should be -// used if recursive conversion calls are desired). It must return an error. -// -// Example: -// c.RegisteConversionFunc( -// func(in *Pod, out *v1beta1.Pod, s Scope) error { -// // conversion logic... -// return nil -// }) -func (c *Converter) RegisterConversionFunc(conversionFunc interface{}) error { - fv := reflect.ValueOf(conversionFunc) - ft := fv.Type() +// Verifies whether a conversion function has a correct signature. +func verifyConversionFunctionSignature(ft reflect.Type) error { if ft.Kind() != reflect.Func { return fmt.Errorf("expected func, got: %v", ft) } @@ -278,10 +268,47 @@ func (c *Converter) RegisterConversionFunc(conversionFunc interface{}) error { if ft.Out(0) != errorType { return fmt.Errorf("expected error return, got: %v", ft) } + return nil +} + +// RegisterConversionFunc registers a conversion func with the +// Converter. conversionFunc must take three parameters: a pointer to the input +// type, a pointer to the output type, and a conversion.Scope (which should be +// used if recursive conversion calls are desired). It must return an error. +// +// Example: +// c.RegisteConversionFunc( +// func(in *Pod, out *v1beta1.Pod, s Scope) error { +// // conversion logic... +// return nil +// }) +func (c *Converter) RegisterConversionFunc(conversionFunc interface{}) error { + fv := reflect.ValueOf(conversionFunc) + ft := fv.Type() + if err := verifyConversionFunctionSignature(ft); err != nil { + return err + } c.conversionFuncs[typePair{ft.In(0).Elem(), ft.In(1).Elem()}] = fv return nil } +// Similar to RegisterConversionFunc, but registers conversion function that were +// automatically generated. +func (c *Converter) RegisterGeneratedConversionFunc(conversionFunc interface{}) error { + fv := reflect.ValueOf(conversionFunc) + ft := fv.Type() + if err := verifyConversionFunctionSignature(ft); err != nil { + return err + } + c.generatedConversionFuncs[typePair{ft.In(0).Elem(), ft.In(1).Elem()}] = fv + return nil +} + +func (c *Converter) HasConversionFunc(inType, outType reflect.Type) bool { + _, found := c.conversionFuncs[typePair{inType, outType}] + return found +} + // SetStructFieldCopy registers a correspondence. Whenever a struct field is encountered // which has a type and name matching srcFieldType and srcFieldName, it wil be copied // into the field in the destination struct matching destFieldType & Name, if such a @@ -469,6 +496,12 @@ func (c *Converter) convert(sv, dv reflect.Value, scope *scope) error { } return c.callCustom(sv, dv, fv, scope) } + if fv, ok := c.generatedConversionFuncs[typePair{st, dt}]; ok { + if c.Debug != nil { + c.Debug.Logf("Calling custom conversion of '%v' to '%v'", st, dt) + } + return c.callCustom(sv, dv, fv, scope) + } return c.defaultConvert(sv, dv, scope) } diff --git a/pkg/conversion/converter_test.go b/pkg/conversion/converter_test.go index 0907c98ad20..6ba1bac9cc8 100644 --- a/pkg/conversion/converter_test.go +++ b/pkg/conversion/converter_test.go @@ -187,6 +187,28 @@ func TestConverter_CallsRegisteredFunctions(t *testing.T) { } } +func TestConverter_GeneratedConversionOverriden(t *testing.T) { + type A struct{} + type B struct{} + c := NewConverter() + if err := c.RegisterConversionFunc(func(in *A, out *B, s Scope) error { + return nil + }); err != nil { + t.Fatalf("unexpected error %v", err) + } + if err := c.RegisterGeneratedConversionFunc(func(in *A, out *B, s Scope) error { + return fmt.Errorf("generated function should be overriden") + }); err != nil { + t.Fatalf("unexpected error %v", err) + } + + a := A{} + b := B{} + if err := c.Convert(&a, &b, 0, nil); err != nil { + t.Errorf("%v", err) + } +} + func TestConverter_MapsStringArrays(t *testing.T) { type A struct { Foo string diff --git a/pkg/conversion/generator.go b/pkg/conversion/generator.go index cc810e2b0e1..02c869b51d0 100644 --- a/pkg/conversion/generator.go +++ b/pkg/conversion/generator.go @@ -64,6 +64,8 @@ func (g *generator) GenerateConversionsForType(version string, reflection reflec } func (g *generator) generateConversionsBetween(inType, outType reflect.Type) error { + existingConversion := g.scheme.Converter().HasConversionFunc(inType, outType) && g.scheme.Converter().HasConversionFunc(outType, inType) + // Avoid processing the same type multiple times. if value, found := g.convertibles[inType]; found { if value != outType { @@ -79,19 +81,50 @@ func (g *generator) generateConversionsBetween(inType, outType reflect.Type) err if inType.Kind() != outType.Kind() { return fmt.Errorf("cannot convert types of different kinds: %v %v", inType, outType) } + // We should be able to generate conversions both sides. switch inType.Kind() { case reflect.Map: - return g.generateConversionsForMap(inType, outType) - case reflect.Ptr: - return g.generateConversionsBetween(inType.Elem(), outType.Elem()) - case reflect.Slice: - return g.generateConversionsForSlice(inType, outType) - case reflect.Interface: - // TODO(wojtek-t): Currently we rely on default conversion functions for interfaces. - // Add support for reflect.Interface. + inErr := g.generateConversionsForMap(inType, outType) + outErr := g.generateConversionsForMap(outType, inType) + if !existingConversion && (inErr != nil || outErr != nil) { + return inErr + } + // We don't add it to g.convertibles - maps should be handled correctly + // inside appropriate conversion functions. return nil + case reflect.Ptr: + inErr := g.generateConversionsBetween(inType.Elem(), outType.Elem()) + outErr := g.generateConversionsBetween(outType.Elem(), inType.Elem()) + if !existingConversion && (inErr != nil || outErr != nil) { + return inErr + } + // We don't add it to g.convertibles - maps should be handled correctly + // inside appropriate conversion functions. + return nil + case reflect.Slice: + inErr := g.generateConversionsForSlice(inType, outType) + outErr := g.generateConversionsForSlice(outType, inType) + if !existingConversion && (inErr != nil || outErr != nil) { + return inErr + } + // We don't add it to g.convertibles - slices should be handled correctly + // inside appropriate conversion functions. + return nil + case reflect.Interface: + // TODO(wojtek-t): Currently we don't support converting interfaces. + return fmt.Errorf("interfaces are not supported") case reflect.Struct: - return g.generateConversionsForStruct(inType, outType) + inErr := g.generateConversionsForStruct(inType, outType) + outErr := g.generateConversionsForStruct(outType, inType) + if !existingConversion && (inErr != nil || outErr != nil) { + return inErr + } + if !existingConversion { + if _, found := g.convertibles[outType]; !found { + g.convertibles[inType] = outType + } + } + return nil default: // All simple types should be handled correctly with default conversion. return nil @@ -119,8 +152,6 @@ func (g *generator) generateConversionsForMap(inType, outType reflect.Type) erro if err := g.generateConversionsBetween(inValue, outValue); err != nil { return err } - // We don't add it to g.convertibles - maps should be handled correctly - // inside appropriate conversion functions. return nil } @@ -130,8 +161,6 @@ func (g *generator) generateConversionsForSlice(inType, outType reflect.Type) er if err := g.generateConversionsBetween(inElem, outElem); err != nil { return err } - // We don't add it to g.convertibles - slices should be handled correctly - // inside appropriate conversion functions. return nil } @@ -151,7 +180,6 @@ func (g *generator) generateConversionsForStruct(inType, outType reflect.Type) e } } } - g.convertibles[inType] = outType return nil } @@ -477,6 +505,22 @@ func (g *generator) writeConversionForStruct(w io.Writer, inType, outType reflec inField := inType.Field(i) outField, _ := outType.FieldByName(inField.Name) + if g.scheme.Converter().HasConversionFunc(inField.Type, outField.Type) { + // Use the conversion method that is already defined. + assignFormat := "if err := s.Convert(&in.%s, &out.%s, 0); err != nil {\n" + assignStmt := fmt.Sprintf(assignFormat, inField.Name, outField.Name) + if err := writeLine(w, indent, assignStmt); err != nil { + return err + } + if err := writeLine(w, indent+1, "return err\n"); err != nil { + return err + } + if err := writeLine(w, indent, "}\n"); err != nil { + return err + } + continue + } + switch inField.Type.Kind() { case reflect.Map, reflect.Ptr, reflect.Slice, reflect.Interface, reflect.Struct: // Don't copy these via assignment/conversion! diff --git a/pkg/conversion/scheme.go b/pkg/conversion/scheme.go index f842f103104..39ecae81733 100644 --- a/pkg/conversion/scheme.go +++ b/pkg/conversion/scheme.go @@ -259,7 +259,12 @@ func (s *Scheme) AddConversionFuncs(conversionFuncs ...interface{}) error { // Similar to AddConversionFuncs, but registers conversion functions that were // automatically generated. func (s *Scheme) AddGeneratedConversionFuncs(conversionFuncs ...interface{}) error { - return s.AddConversionFuncs(conversionFuncs...) + for _, f := range conversionFuncs { + if err := s.converter.RegisterGeneratedConversionFunc(f); err != nil { + return err + } + } + return nil } // AddStructFieldConversion allows you to specify a mechanical copy for a moved