diff --git a/pkg/conversion/converter.go b/pkg/conversion/converter.go index 9f3cdc3a94f..0596bec740d 100644 --- a/pkg/conversion/converter.go +++ b/pkg/conversion/converter.go @@ -55,6 +55,9 @@ type Converter struct { // Map from a type to a function which applies defaults. defaultingFuncs map[reflect.Type]reflect.Value + // Similar to above, but function is stored as interface{}. + defaultingInterfaces map[reflect.Type]interface{} + // Map from an input type to a function which can apply a key name mapping inputFieldMappingFuncs map[reflect.Type]FieldMappingFunc @@ -76,6 +79,7 @@ func NewConverter() *Converter { conversionFuncs: map[typePair]reflect.Value{}, generatedConversionFuncs: map[typePair]reflect.Value{}, defaultingFuncs: map[reflect.Type]reflect.Value{}, + defaultingInterfaces: map[reflect.Type]interface{}{}, nameFunc: func(t reflect.Type) string { return t.Name() }, structFieldDests: map[typeNamePair][]typeNamePair{}, structFieldSources: map[typeNamePair][]typeNamePair{}, @@ -106,6 +110,10 @@ type Scope interface { // on the current stack frame. This makes it safe to call from a conversion func. DefaultConvert(src, dest interface{}, flags FieldMatchingFlags) error + // If registered, returns a function applying defaults for objects of a given type. + // Used for automatically generating convertion functions. + DefaultingInterface(inType reflect.Type) (interface{}, bool) + // SrcTags and DestTags contain the struct tags that src and dest had, respectively. // If the enclosing object was not a struct, then these will contain no tags, of course. SrcTag() reflect.StructTag @@ -184,6 +192,11 @@ func (s scopeStack) describe() string { return desc } +func (s *scope) DefaultingInterface(inType reflect.Type) (interface{}, bool) { + value, found := s.converter.defaultingInterfaces[inType] + return value, found +} + // Formats src & dest as indices for printing. func (s *scope) setIndices(src, dest int) { s.srcStack.top().key = fmt.Sprintf("[%v]", src) @@ -277,7 +290,7 @@ func verifyConversionFunctionSignature(ft reflect.Type) error { // used if recursive conversion calls are desired). It must return an error. // // Example: -// c.RegisteConversionFunc( +// c.RegisterConversionFunc( // func(in *Pod, out *v1beta1.Pod, s Scope) error { // // conversion logic... // return nil @@ -348,7 +361,9 @@ func (c *Converter) RegisterDefaultingFunc(defaultingFunc interface{}) error { if ft.In(0).Kind() != reflect.Ptr { return fmt.Errorf("expected pointer arg for 'in' param 0, got: %v", ft) } - c.defaultingFuncs[ft.In(0).Elem()] = fv + inType := ft.In(0).Elem() + c.defaultingFuncs[inType] = fv + c.defaultingInterfaces[inType] = defaultingFunc return nil } diff --git a/pkg/conversion/generator.go b/pkg/conversion/generator.go index 02c869b51d0..e3ef67faea9 100644 --- a/pkg/conversion/generator.go +++ b/pkg/conversion/generator.go @@ -270,6 +270,22 @@ func (g *generator) typeName(inType reflect.Type) string { } } +func (g *generator) writeDefaultingFunc(w io.Writer, inType reflect.Type, indent int) error { + getStmt := "if defaulting, found := s.DefaultingInterface(reflect.TypeOf(*in)); found {\n" + if err := writeLine(w, indent, getStmt); err != nil { + return err + } + callFormat := "defaulting.(func(*%s))(in)\n" + callStmt := fmt.Sprintf(callFormat, g.typeName(inType)) + if err := writeLine(w, indent+1, callStmt); err != nil { + return err + } + if err := writeLine(w, indent, "}\n"); err != nil { + return err + } + return nil +} + func packageForName(inType reflect.Type) string { if inType.PkgPath() == "" { return "" @@ -378,6 +394,14 @@ func (g *generator) writeConversionForMap(w io.Writer, inField, outField reflect if err := writeLine(w, indent+1, "}\n"); err != nil { return err } + if err := writeLine(w, indent, "} else {\n"); err != nil { + return err + } + nilFormat := "out.%s = nil\n" + nilStmt := fmt.Sprintf(nilFormat, outField.Name) + if err := writeLine(w, indent+1, nilStmt); err != nil { + return err + } if err := writeLine(w, indent, "}\n"); err != nil { return err } @@ -424,8 +448,15 @@ func (g *generator) writeConversionForSlice(w io.Writer, inField, outField refle } } if !assigned { - assignFormat := "if err := s.Convert(&in.%s[i], &out.%s[i], 0); err != nil {\n" - assignStmt := fmt.Sprintf(assignFormat, inField.Name, outField.Name) + assignStmt := "" + if g.existsConvertionFunction(inField.Type.Elem(), outField.Type.Elem()) { + assignFormat := "if err := %s(&in.%s[i], &out.%s[i], s); err != nil {\n" + funcName := conversionFunctionName(inField.Type.Elem(), outField.Type.Elem()) + assignStmt = fmt.Sprintf(assignFormat, funcName, inField.Name, outField.Name) + } else { + assignFormat := "if err := s.Convert(&in.%s[i], &out.%s[i], 0); err != nil {\n" + assignStmt = fmt.Sprintf(assignFormat, inField.Name, outField.Name) + } if err := writeLine(w, indent+2, assignStmt); err != nil { return err } @@ -439,6 +470,14 @@ func (g *generator) writeConversionForSlice(w io.Writer, inField, outField refle if err := writeLine(w, indent+1, "}\n"); err != nil { return err } + if err := writeLine(w, indent, "} else {\n"); err != nil { + return err + } + nilFormat := "out.%s = nil\n" + nilStmt := fmt.Sprintf(nilFormat, outField.Name) + if err := writeLine(w, indent+1, nilStmt); err != nil { + return err + } if err := writeLine(w, indent, "}\n"); err != nil { return err } @@ -479,6 +518,14 @@ func (g *generator) writeConversionForPtr(w io.Writer, inField, outField reflect } } if assignable || convertible { + if err := writeLine(w, indent, "} else {\n"); err != nil { + return err + } + nilFormat := "out.%s = nil\n" + nilStmt := fmt.Sprintf(nilFormat, outField.Name) + if err := writeLine(w, indent+1, nilStmt); err != nil { + return err + } if err := writeLine(w, indent, "}\n"); err != nil { return err } @@ -486,12 +533,40 @@ func (g *generator) writeConversionForPtr(w io.Writer, inField, outField reflect } } - assignFormat := "if err := s.Convert(&in.%s, &out.%s, 0); err != nil {\n" - assignStmt := fmt.Sprintf(assignFormat, inField.Name, outField.Name) - if err := writeLine(w, indent, assignStmt); err != nil { + ifFormat := "if in.%s != nil {\n" + ifStmt := fmt.Sprintf(ifFormat, inField.Name) + if err := writeLine(w, indent, ifStmt); err != nil { return err } - if err := writeLine(w, indent+1, "return err\n"); err != nil { + assignStmt := "" + if g.existsConvertionFunction(inField.Type.Elem(), outField.Type.Elem()) { + newFormat := "out.%s = new(%s)\n" + newStmt := fmt.Sprintf(newFormat, outField.Name, g.typeName(outField.Type.Elem())) + if err := writeLine(w, indent+1, newStmt); err != nil { + return err + } + assignFormat := "if err := %s(in.%s, out.%s, s); err != nil {\n" + funcName := conversionFunctionName(inField.Type.Elem(), outField.Type.Elem()) + assignStmt = fmt.Sprintf(assignFormat, funcName, inField.Name, outField.Name) + } else { + assignFormat := "if err := s.Convert(&in.%s, &out.%s, 0); err != nil {\n" + assignStmt = fmt.Sprintf(assignFormat, inField.Name, outField.Name) + } + if err := writeLine(w, indent+1, assignStmt); err != nil { + return err + } + if err := writeLine(w, indent+2, "return err\n"); err != nil { + return err + } + if err := writeLine(w, indent+1, "}\n"); err != nil { + return err + } + if err := writeLine(w, indent, "} else {\n"); err != nil { + return err + } + nilFormat := "out.%s = nil\n" + nilStmt := fmt.Sprintf(nilFormat, outField.Name) + if err := writeLine(w, indent+1, nilStmt); err != nil { return err } if err := writeLine(w, indent, "}\n"); err != nil { @@ -568,8 +643,15 @@ func (g *generator) writeConversionForStruct(w io.Writer, inType, outType reflec continue } - assignFormat := "if err := s.Convert(&in.%s, &out.%s, 0); err != nil {\n" - assignStmt := fmt.Sprintf(assignFormat, inField.Name, outField.Name) + assignStmt := "" + if g.existsConvertionFunction(inField.Type, outField.Type) { + assignFormat := "if err := %s(&in.%s, &out.%s, s); err != nil {\n" + funcName := conversionFunctionName(inField.Type, outField.Type) + assignStmt = fmt.Sprintf(assignFormat, funcName, inField.Name, outField.Name) + } else { + assignFormat := "if err := s.Convert(&in.%s, &out.%s, 0); err != nil {\n" + assignStmt = fmt.Sprintf(assignFormat, inField.Name, outField.Name) + } if err := writeLine(w, indent, assignStmt); err != nil { return err } @@ -588,6 +670,9 @@ func (g *generator) writeConversionForType(w io.Writer, inType, outType reflect. if err := writeHeader(w, funcName, g.typeName(inType), g.typeName(outType), indent); err != nil { return err } + if err := g.writeDefaultingFunc(w, inType, indent+1); err != nil { + return err + } switch inType.Kind() { case reflect.Struct: if err := g.writeConversionForStruct(w, inType, outType, indent+1); err != nil { @@ -605,6 +690,16 @@ func (g *generator) writeConversionForType(w io.Writer, inType, outType reflect. return nil } +func (g *generator) existsConvertionFunction(inType, outType reflect.Type) bool { + if val, found := g.convertibles[inType]; found && val == outType { + return true + } + if val, found := g.convertibles[outType]; found && val == inType { + return true + } + return false +} + func (g *generator) OverwritePackage(pkg, overwrite string) { g.pkgOverwrites[pkg] = overwrite }