Refactor Rewrite functions

This commit is contained in:
Wojciech Tyczynski 2016-03-09 12:14:00 +01:00
parent aca37830b1
commit dcf292174e
2 changed files with 42 additions and 59 deletions

View File

@ -229,7 +229,7 @@ func Run(g *Generator) {
// alter the generated protobuf file to remove the generated types (but leave the serializers) and rewrite the // alter the generated protobuf file to remove the generated types (but leave the serializers) and rewrite the
// package statement to match the desired package name // package statement to match the desired package name
if err := RewriteGeneratedGogoProtobufFile(outputPath, p.GoPackageName(), p.ExtractGeneratedType, buf.Bytes()); err != nil { if err := RewriteGeneratedGogoProtobufFile(outputPath, p.ExtractGeneratedType, buf.Bytes()); err != nil {
log.Fatalf("Unable to rewrite generated %s: %v", outputPath, err) log.Fatalf("Unable to rewrite generated %s: %v", outputPath, err)
} }

View File

@ -33,11 +33,7 @@ import (
customreflect "k8s.io/kubernetes/third_party/golang/reflect" customreflect "k8s.io/kubernetes/third_party/golang/reflect"
) )
// ExtractFunc extracts information from the provided TypeSpec and returns true if the type should be func rewriteFile(name string, header []byte, rewriteFn func(*token.FileSet, *ast.File) error) error {
// removed from the destination file.
type ExtractFunc func(*ast.TypeSpec) bool
func RewriteGeneratedGogoProtobufFile(name string, packageName string, extractFn ExtractFunc, header []byte) error {
fset := token.NewFileSet() fset := token.NewFileSet()
src, err := ioutil.ReadFile(name) src, err := ioutil.ReadFile(name)
if err != nil { if err != nil {
@ -47,19 +43,10 @@ func RewriteGeneratedGogoProtobufFile(name string, packageName string, extractFn
if err != nil { if err != nil {
return err return err
} }
cmap := ast.NewCommentMap(fset, file, file.Comments)
// remove types that are already declared if err := rewriteFn(fset, file); err != nil {
decls := []ast.Decl{} return err
for _, d := range file.Decls {
if !dropExistingTypeDeclarations(d, extractFn) {
decls = append(decls, d)
} }
}
file.Decls = decls
// remove unmapped comments
file.Comments = cmap.Filter(file).Comments()
b := &bytes.Buffer{} b := &bytes.Buffer{}
b.Write(header) b.Write(header)
@ -83,6 +70,29 @@ func RewriteGeneratedGogoProtobufFile(name string, packageName string, extractFn
return f.Close() return f.Close()
} }
// ExtractFunc extracts information from the provided TypeSpec and returns true if the type should be
// removed from the destination file.
type ExtractFunc func(*ast.TypeSpec) bool
func RewriteGeneratedGogoProtobufFile(name string, extractFn ExtractFunc, header []byte) error {
return rewriteFile(name, header, func(fset *token.FileSet, file *ast.File) error {
cmap := ast.NewCommentMap(fset, file, file.Comments)
// remove types that are already declared
decls := []ast.Decl{}
for _, d := range file.Decls {
if !dropExistingTypeDeclarations(d, extractFn) {
decls = append(decls, d)
}
}
file.Decls = decls
// remove unmapped comments
file.Comments = cmap.Filter(file).Comments()
return nil
})
}
func dropExistingTypeDeclarations(decl ast.Decl, extractFn ExtractFunc) bool { func dropExistingTypeDeclarations(decl ast.Decl, extractFn ExtractFunc) bool {
switch t := decl.(type) { switch t := decl.(type) {
case *ast.GenDecl: case *ast.GenDecl:
@ -108,16 +118,7 @@ func dropExistingTypeDeclarations(decl ast.Decl, extractFn ExtractFunc) bool {
} }
func RewriteTypesWithProtobufStructTags(name string, structTags map[string]map[string]string) error { func RewriteTypesWithProtobufStructTags(name string, structTags map[string]map[string]string) error {
fset := token.NewFileSet() return rewriteFile(name, []byte{}, func(fset *token.FileSet, file *ast.File) error {
src, err := ioutil.ReadFile(name)
if err != nil {
return err
}
file, err := parser.ParseFile(fset, name, src, parser.DeclarationErrors|parser.ParseComments)
if err != nil {
return err
}
allErrs := []error{} allErrs := []error{}
// set any new struct tags // set any new struct tags
@ -134,26 +135,8 @@ func RewriteTypesWithProtobufStructTags(name string, structTags map[string]map[s
} }
return errors.New(s) return errors.New(s)
} }
return nil
b := &bytes.Buffer{} })
if err := printer.Fprint(b, fset, file); err != nil {
return err
}
body, err := format.Source(b.Bytes())
if err != nil {
return fmt.Errorf("%s\n---\nunable to format %q: %v", b, name, err)
}
f, err := os.OpenFile(name, os.O_WRONLY|os.O_TRUNC, 0644)
if err != nil {
return err
}
defer f.Close()
if _, err := f.Write(body); err != nil {
return err
}
return f.Close()
} }
func updateStructTags(decl ast.Decl, structTags map[string]map[string]string, toCopy []string) []error { func updateStructTags(decl ast.Decl, structTags map[string]map[string]string, toCopy []string) []error {