Add verification to code gen

This commit is contained in:
Daniel Smith
2015-11-04 13:56:38 -08:00
parent d524bd8f52
commit ad925dd2e8
16 changed files with 193 additions and 22 deletions

View File

@@ -21,6 +21,7 @@ import (
"fmt"
"go/format"
"io"
"io/ioutil"
"log"
"os"
"path/filepath"
@@ -30,17 +31,29 @@ import (
"k8s.io/kubernetes/cmd/libs/go2idl/types"
)
func errs2strings(errors []error) []string {
strs := make([]string, len(errors))
for i := range errors {
strs[i] = errors[i].Error()
}
return strs
}
// ExecutePackages runs the generators for every package in 'packages'. 'outDir'
// is the base directory in which to place all the generated packages; it
// should be a physical path on disk, not an import path. e.g.:
// /path/to/home/path/to/gopath/src/
// Each package has its import path already, this will be appended to 'outDir'.
func (c *Context) ExecutePackages(outDir string, packages Packages) error {
var errors []error
for _, p := range packages {
if err := c.ExecutePackage(outDir, p); err != nil {
return err
errors = append(errors, err)
}
}
if len(errors) > 0 {
return fmt.Errorf("some packages had errors:\n%v\n", strings.Join(errs2strings(errors), "\n"))
}
return nil
}
@@ -61,8 +74,11 @@ func (ft golangFileType) AssembleFile(f *File, pathname string) error {
return et.Error()
}
if formatted, err := format.Source(b.Bytes()); err != nil {
log.Printf("Warning: unable to run gofmt on %q (%v).", pathname, err)
_, err = destFile.Write(b.Bytes())
err = fmt.Errorf("unable to run gofmt on %q (%v).", pathname, err)
// Write the file anyway, so they can see what's going wrong and fix the generator.
if _, err2 := destFile.Write(b.Bytes()); err2 != nil {
return err2
}
return err
} else {
_, err = destFile.Write(formatted)
@@ -70,6 +86,41 @@ func (ft golangFileType) AssembleFile(f *File, pathname string) error {
}
}
func (ft golangFileType) VerifyFile(f *File, pathname string) error {
log.Printf("Verifying file %q", pathname)
friendlyName := filepath.Join(f.PackageName, f.Name)
b := &bytes.Buffer{}
et := NewErrorTracker(b)
ft.assemble(et, f)
if et.Error() != nil {
return et.Error()
}
formatted, err := format.Source(b.Bytes())
if err != nil {
return fmt.Errorf("unable to gofmt the output for %q: %v", friendlyName, err)
}
existing, err := ioutil.ReadFile(pathname)
if err != nil {
return fmt.Errorf("unable to read file %q for comparison: %v", friendlyName, err)
}
if bytes.Compare(formatted, existing) == 0 {
return nil
}
// Be nice and find the first place where they differ
i := 0
for i < len(formatted) && i < len(existing) && formatted[i] == existing[i] {
i++
}
eDiff, fDiff := existing[i:], formatted[i:]
if len(eDiff) > 100 {
eDiff = eDiff[:100]
}
if len(fDiff) > 100 {
fDiff = fDiff[:100]
}
return fmt.Errorf("output for %q differs; first existing/expected diff: \n %q\n %q", friendlyName, string(eDiff), string(fDiff))
}
func (ft golangFileType) assemble(w io.Writer, f *File) {
w.Write(f.Header)
fmt.Fprintf(w, "package %v\n\n", f.PackageName)
@@ -149,7 +200,7 @@ func (c *Context) addNameSystems(namers namer.NameSystems) *Context {
// import path already, this will be appended to 'outDir'.
func (c *Context) ExecutePackage(outDir string, p Package) error {
path := filepath.Join(outDir, p.Path())
log.Printf("Executing package %v into %v", p.Name(), path)
log.Printf("Processing package %q, disk location %q", p.Name(), path)
// Filter out any types the *package* doesn't care about.
packageContext := c.filteredBy(p.Filter)
os.MkdirAll(path, 0755)
@@ -207,14 +258,25 @@ func (c *Context) ExecutePackage(outDir string, p Package) error {
}
}
var errors []error
for _, f := range files {
finalPath := filepath.Join(path, f.Name)
assembler, ok := c.FileTypes[f.FileType]
if !ok {
return fmt.Errorf("the file type %q registered for file %q does not exist in the context", f.FileType, f.Name)
}
if err := assembler.AssembleFile(f, filepath.Join(path, f.Name)); err != nil {
return err
var err error
if c.Verify {
err = assembler.VerifyFile(f, finalPath)
} else {
err = assembler.AssembleFile(f, finalPath)
}
if err != nil {
errors = append(errors, err)
}
}
if len(errors) > 0 {
return fmt.Errorf("errors in package %q:\n%v\n", p.Name(), strings.Join(errs2strings(errors), "\n"))
}
return nil
}