From dcf292174ec8767d913530d87a5e8c9648da8c90 Mon Sep 17 00:00:00 2001 From: Wojciech Tyczynski Date: Wed, 9 Mar 2016 12:14:00 +0100 Subject: [PATCH] Refactor Rewrite functions --- .../go2idl/go-to-protobuf/protobuf/cmd.go | 2 +- .../go2idl/go-to-protobuf/protobuf/parser.go | 99 ++++++++----------- 2 files changed, 42 insertions(+), 59 deletions(-) diff --git a/cmd/libs/go2idl/go-to-protobuf/protobuf/cmd.go b/cmd/libs/go2idl/go-to-protobuf/protobuf/cmd.go index 147dd5453bd..20f70fd4513 100644 --- a/cmd/libs/go2idl/go-to-protobuf/protobuf/cmd.go +++ b/cmd/libs/go2idl/go-to-protobuf/protobuf/cmd.go @@ -229,7 +229,7 @@ func Run(g *Generator) { // alter the generated protobuf file to remove the generated types (but leave the serializers) and rewrite the // package statement to match the desired package name - if err := RewriteGeneratedGogoProtobufFile(outputPath, p.GoPackageName(), p.ExtractGeneratedType, buf.Bytes()); err != nil { + if err := RewriteGeneratedGogoProtobufFile(outputPath, p.ExtractGeneratedType, buf.Bytes()); err != nil { log.Fatalf("Unable to rewrite generated %s: %v", outputPath, err) } diff --git a/cmd/libs/go2idl/go-to-protobuf/protobuf/parser.go b/cmd/libs/go2idl/go-to-protobuf/protobuf/parser.go index 29c6a7f1c9d..4da6e070730 100644 --- a/cmd/libs/go2idl/go-to-protobuf/protobuf/parser.go +++ b/cmd/libs/go2idl/go-to-protobuf/protobuf/parser.go @@ -33,11 +33,7 @@ import ( customreflect "k8s.io/kubernetes/third_party/golang/reflect" ) -// ExtractFunc extracts information from the provided TypeSpec and returns true if the type should be -// removed from the destination file. -type ExtractFunc func(*ast.TypeSpec) bool - -func RewriteGeneratedGogoProtobufFile(name string, packageName string, extractFn ExtractFunc, header []byte) error { +func rewriteFile(name string, header []byte, rewriteFn func(*token.FileSet, *ast.File) error) error { fset := token.NewFileSet() src, err := ioutil.ReadFile(name) if err != nil { @@ -47,19 +43,10 @@ func RewriteGeneratedGogoProtobufFile(name string, packageName string, extractFn if err != nil { return err } - cmap := ast.NewCommentMap(fset, file, file.Comments) - // remove types that are already declared - decls := []ast.Decl{} - for _, d := range file.Decls { - if !dropExistingTypeDeclarations(d, extractFn) { - decls = append(decls, d) - } + if err := rewriteFn(fset, file); err != nil { + return err } - file.Decls = decls - - // remove unmapped comments - file.Comments = cmap.Filter(file).Comments() b := &bytes.Buffer{} b.Write(header) @@ -83,6 +70,29 @@ func RewriteGeneratedGogoProtobufFile(name string, packageName string, extractFn return f.Close() } +// ExtractFunc extracts information from the provided TypeSpec and returns true if the type should be +// removed from the destination file. +type ExtractFunc func(*ast.TypeSpec) bool + +func RewriteGeneratedGogoProtobufFile(name string, extractFn ExtractFunc, header []byte) error { + return rewriteFile(name, header, func(fset *token.FileSet, file *ast.File) error { + cmap := ast.NewCommentMap(fset, file, file.Comments) + + // remove types that are already declared + decls := []ast.Decl{} + for _, d := range file.Decls { + if !dropExistingTypeDeclarations(d, extractFn) { + decls = append(decls, d) + } + } + file.Decls = decls + + // remove unmapped comments + file.Comments = cmap.Filter(file).Comments() + return nil + }) +} + func dropExistingTypeDeclarations(decl ast.Decl, extractFn ExtractFunc) bool { switch t := decl.(type) { case *ast.GenDecl: @@ -108,52 +118,25 @@ func dropExistingTypeDeclarations(decl ast.Decl, extractFn ExtractFunc) bool { } func RewriteTypesWithProtobufStructTags(name string, structTags map[string]map[string]string) error { - fset := token.NewFileSet() - src, err := ioutil.ReadFile(name) - if err != nil { - return err - } - file, err := parser.ParseFile(fset, name, src, parser.DeclarationErrors|parser.ParseComments) - if err != nil { - return err - } + return rewriteFile(name, []byte{}, func(fset *token.FileSet, file *ast.File) error { + allErrs := []error{} - allErrs := []error{} - - // set any new struct tags - for _, d := range file.Decls { - if errs := updateStructTags(d, structTags, []string{"protobuf"}); len(errs) > 0 { - allErrs = append(allErrs, errs...) + // set any new struct tags + for _, d := range file.Decls { + if errs := updateStructTags(d, structTags, []string{"protobuf"}); len(errs) > 0 { + allErrs = append(allErrs, errs...) + } } - } - if len(allErrs) > 0 { - var s string - for _, err := range allErrs { - s += err.Error() + "\n" + if len(allErrs) > 0 { + var s string + for _, err := range allErrs { + s += err.Error() + "\n" + } + return errors.New(s) } - return errors.New(s) - } - - b := &bytes.Buffer{} - if err := printer.Fprint(b, fset, file); err != nil { - return err - } - - body, err := format.Source(b.Bytes()) - if err != nil { - return fmt.Errorf("%s\n---\nunable to format %q: %v", b, name, err) - } - - f, err := os.OpenFile(name, os.O_WRONLY|os.O_TRUNC, 0644) - if err != nil { - return err - } - defer f.Close() - if _, err := f.Write(body); err != nil { - return err - } - return f.Close() + return nil + }) } func updateStructTags(decl ast.Decl, structTags map[string]map[string]string, toCopy []string) []error {