Merge pull request #7556 from wojtek-t/conversions_with_defaulting

Auto-generated conversion methods calling one another
This commit is contained in:
Clayton Coleman 2015-05-07 11:28:35 -04:00
commit b6fb8c861e
5 changed files with 4491 additions and 3131 deletions

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -146,7 +146,7 @@ func TestNoManualChangesToGenerateConversions(t *testing.T) {
versions := []string{"v1beta3"}
for _, version := range versions {
fileName := fmt.Sprintf("../../pkg/api/%s/conversion.go", version)
fileName := fmt.Sprintf("../../pkg/api/%s/conversion_generated.go", version)
existingFunctions, existingNames := bufferExistingConversions(t, fileName)
generatedFunctions, generatedNames := generateConversions(t, version)

View File

@ -55,6 +55,9 @@ type Converter struct {
// Map from a type to a function which applies defaults.
defaultingFuncs map[reflect.Type]reflect.Value
// Similar to above, but function is stored as interface{}.
defaultingInterfaces map[reflect.Type]interface{}
// Map from an input type to a function which can apply a key name mapping
inputFieldMappingFuncs map[reflect.Type]FieldMappingFunc
@ -76,6 +79,7 @@ func NewConverter() *Converter {
conversionFuncs: map[typePair]reflect.Value{},
generatedConversionFuncs: map[typePair]reflect.Value{},
defaultingFuncs: map[reflect.Type]reflect.Value{},
defaultingInterfaces: map[reflect.Type]interface{}{},
nameFunc: func(t reflect.Type) string { return t.Name() },
structFieldDests: map[typeNamePair][]typeNamePair{},
structFieldSources: map[typeNamePair][]typeNamePair{},
@ -106,6 +110,10 @@ type Scope interface {
// on the current stack frame. This makes it safe to call from a conversion func.
DefaultConvert(src, dest interface{}, flags FieldMatchingFlags) error
// If registered, returns a function applying defaults for objects of a given type.
// Used for automatically generating convertion functions.
DefaultingInterface(inType reflect.Type) (interface{}, bool)
// SrcTags and DestTags contain the struct tags that src and dest had, respectively.
// If the enclosing object was not a struct, then these will contain no tags, of course.
SrcTag() reflect.StructTag
@ -184,6 +192,11 @@ func (s scopeStack) describe() string {
return desc
}
func (s *scope) DefaultingInterface(inType reflect.Type) (interface{}, bool) {
value, found := s.converter.defaultingInterfaces[inType]
return value, found
}
// Formats src & dest as indices for printing.
func (s *scope) setIndices(src, dest int) {
s.srcStack.top().key = fmt.Sprintf("[%v]", src)
@ -277,7 +290,7 @@ func verifyConversionFunctionSignature(ft reflect.Type) error {
// used if recursive conversion calls are desired). It must return an error.
//
// Example:
// c.RegisteConversionFunc(
// c.RegisterConversionFunc(
// func(in *Pod, out *v1beta1.Pod, s Scope) error {
// // conversion logic...
// return nil
@ -348,7 +361,9 @@ func (c *Converter) RegisterDefaultingFunc(defaultingFunc interface{}) error {
if ft.In(0).Kind() != reflect.Ptr {
return fmt.Errorf("expected pointer arg for 'in' param 0, got: %v", ft)
}
c.defaultingFuncs[ft.In(0).Elem()] = fv
inType := ft.In(0).Elem()
c.defaultingFuncs[inType] = fv
c.defaultingInterfaces[inType] = defaultingFunc
return nil
}

View File

@ -270,6 +270,22 @@ func (g *generator) typeName(inType reflect.Type) string {
}
}
func (g *generator) writeDefaultingFunc(w io.Writer, inType reflect.Type, indent int) error {
getStmt := "if defaulting, found := s.DefaultingInterface(reflect.TypeOf(*in)); found {\n"
if err := writeLine(w, indent, getStmt); err != nil {
return err
}
callFormat := "defaulting.(func(*%s))(in)\n"
callStmt := fmt.Sprintf(callFormat, g.typeName(inType))
if err := writeLine(w, indent+1, callStmt); err != nil {
return err
}
if err := writeLine(w, indent, "}\n"); err != nil {
return err
}
return nil
}
func packageForName(inType reflect.Type) string {
if inType.PkgPath() == "" {
return ""
@ -378,6 +394,14 @@ func (g *generator) writeConversionForMap(w io.Writer, inField, outField reflect
if err := writeLine(w, indent+1, "}\n"); err != nil {
return err
}
if err := writeLine(w, indent, "} else {\n"); err != nil {
return err
}
nilFormat := "out.%s = nil\n"
nilStmt := fmt.Sprintf(nilFormat, outField.Name)
if err := writeLine(w, indent+1, nilStmt); err != nil {
return err
}
if err := writeLine(w, indent, "}\n"); err != nil {
return err
}
@ -424,8 +448,15 @@ func (g *generator) writeConversionForSlice(w io.Writer, inField, outField refle
}
}
if !assigned {
assignFormat := "if err := s.Convert(&in.%s[i], &out.%s[i], 0); err != nil {\n"
assignStmt := fmt.Sprintf(assignFormat, inField.Name, outField.Name)
assignStmt := ""
if g.existsConvertionFunction(inField.Type.Elem(), outField.Type.Elem()) {
assignFormat := "if err := %s(&in.%s[i], &out.%s[i], s); err != nil {\n"
funcName := conversionFunctionName(inField.Type.Elem(), outField.Type.Elem())
assignStmt = fmt.Sprintf(assignFormat, funcName, inField.Name, outField.Name)
} else {
assignFormat := "if err := s.Convert(&in.%s[i], &out.%s[i], 0); err != nil {\n"
assignStmt = fmt.Sprintf(assignFormat, inField.Name, outField.Name)
}
if err := writeLine(w, indent+2, assignStmt); err != nil {
return err
}
@ -439,6 +470,14 @@ func (g *generator) writeConversionForSlice(w io.Writer, inField, outField refle
if err := writeLine(w, indent+1, "}\n"); err != nil {
return err
}
if err := writeLine(w, indent, "} else {\n"); err != nil {
return err
}
nilFormat := "out.%s = nil\n"
nilStmt := fmt.Sprintf(nilFormat, outField.Name)
if err := writeLine(w, indent+1, nilStmt); err != nil {
return err
}
if err := writeLine(w, indent, "}\n"); err != nil {
return err
}
@ -479,6 +518,14 @@ func (g *generator) writeConversionForPtr(w io.Writer, inField, outField reflect
}
}
if assignable || convertible {
if err := writeLine(w, indent, "} else {\n"); err != nil {
return err
}
nilFormat := "out.%s = nil\n"
nilStmt := fmt.Sprintf(nilFormat, outField.Name)
if err := writeLine(w, indent+1, nilStmt); err != nil {
return err
}
if err := writeLine(w, indent, "}\n"); err != nil {
return err
}
@ -486,12 +533,40 @@ func (g *generator) writeConversionForPtr(w io.Writer, inField, outField reflect
}
}
assignFormat := "if err := s.Convert(&in.%s, &out.%s, 0); err != nil {\n"
assignStmt := fmt.Sprintf(assignFormat, inField.Name, outField.Name)
if err := writeLine(w, indent, assignStmt); err != nil {
ifFormat := "if in.%s != nil {\n"
ifStmt := fmt.Sprintf(ifFormat, inField.Name)
if err := writeLine(w, indent, ifStmt); err != nil {
return err
}
if err := writeLine(w, indent+1, "return err\n"); err != nil {
assignStmt := ""
if g.existsConvertionFunction(inField.Type.Elem(), outField.Type.Elem()) {
newFormat := "out.%s = new(%s)\n"
newStmt := fmt.Sprintf(newFormat, outField.Name, g.typeName(outField.Type.Elem()))
if err := writeLine(w, indent+1, newStmt); err != nil {
return err
}
assignFormat := "if err := %s(in.%s, out.%s, s); err != nil {\n"
funcName := conversionFunctionName(inField.Type.Elem(), outField.Type.Elem())
assignStmt = fmt.Sprintf(assignFormat, funcName, inField.Name, outField.Name)
} else {
assignFormat := "if err := s.Convert(&in.%s, &out.%s, 0); err != nil {\n"
assignStmt = fmt.Sprintf(assignFormat, inField.Name, outField.Name)
}
if err := writeLine(w, indent+1, assignStmt); err != nil {
return err
}
if err := writeLine(w, indent+2, "return err\n"); err != nil {
return err
}
if err := writeLine(w, indent+1, "}\n"); err != nil {
return err
}
if err := writeLine(w, indent, "} else {\n"); err != nil {
return err
}
nilFormat := "out.%s = nil\n"
nilStmt := fmt.Sprintf(nilFormat, outField.Name)
if err := writeLine(w, indent+1, nilStmt); err != nil {
return err
}
if err := writeLine(w, indent, "}\n"); err != nil {
@ -568,8 +643,15 @@ func (g *generator) writeConversionForStruct(w io.Writer, inType, outType reflec
continue
}
assignFormat := "if err := s.Convert(&in.%s, &out.%s, 0); err != nil {\n"
assignStmt := fmt.Sprintf(assignFormat, inField.Name, outField.Name)
assignStmt := ""
if g.existsConvertionFunction(inField.Type, outField.Type) {
assignFormat := "if err := %s(&in.%s, &out.%s, s); err != nil {\n"
funcName := conversionFunctionName(inField.Type, outField.Type)
assignStmt = fmt.Sprintf(assignFormat, funcName, inField.Name, outField.Name)
} else {
assignFormat := "if err := s.Convert(&in.%s, &out.%s, 0); err != nil {\n"
assignStmt = fmt.Sprintf(assignFormat, inField.Name, outField.Name)
}
if err := writeLine(w, indent, assignStmt); err != nil {
return err
}
@ -588,6 +670,9 @@ func (g *generator) writeConversionForType(w io.Writer, inType, outType reflect.
if err := writeHeader(w, funcName, g.typeName(inType), g.typeName(outType), indent); err != nil {
return err
}
if err := g.writeDefaultingFunc(w, inType, indent+1); err != nil {
return err
}
switch inType.Kind() {
case reflect.Struct:
if err := g.writeConversionForStruct(w, inType, outType, indent+1); err != nil {
@ -605,6 +690,16 @@ func (g *generator) writeConversionForType(w io.Writer, inType, outType reflect.
return nil
}
func (g *generator) existsConvertionFunction(inType, outType reflect.Type) bool {
if val, found := g.convertibles[inType]; found && val == outType {
return true
}
if val, found := g.convertibles[outType]; found && val == inType {
return true
}
return false
}
func (g *generator) OverwritePackage(pkg, overwrite string) {
g.pkgOverwrites[pkg] = overwrite
}