Add optional slice and map support to protobuf

Specifying // +protobuf.nullable=true on a Go type that is an alias of a
map or slice will generate a synthetic protobuf message with the type
name that will serialize to the wire in a way that allows the difference
between empty and nil to be recorded.

For instance:

    // +protobuf.nullable=true
    types OptionalMap map[string]string

will create the following message:

    message OptionalMap {
      map<string, string> Items = 1
    }

and generate marshallers that use the presence of OptionalMap to
determine whether the map is nil (rather than Items, which protobuf
provides no way to delineate between empty and nil).
This commit is contained in:
Clayton Coleman 2016-06-12 18:08:34 -04:00
parent 9f7e16c256
commit 5f9e7a00b8
No known key found for this signature in database
GPG Key ID: 3D16906B4F1C5CB3
7 changed files with 278 additions and 23 deletions

View File

@ -124,6 +124,9 @@ func Run(g *Generator) {
protobufNames := NewProtobufNamer() protobufNames := NewProtobufNamer()
outputPackages := generator.Packages{} outputPackages := generator.Packages{}
for _, d := range strings.Split(g.Packages, ",") { for _, d := range strings.Split(g.Packages, ",") {
if strings.Contains(d, "-") {
log.Fatalf("Package names must be valid protobuf package identifiers, which allow only [a-z0-9_]: %s", d)
}
generateAllTypes, outputPackage := true, true generateAllTypes, outputPackage := true, true
switch { switch {
case strings.HasPrefix(d, "+"): case strings.HasPrefix(d, "+"):
@ -235,7 +238,7 @@ func Run(g *Generator) {
// alter the generated protobuf file to remove the generated types (but leave the serializers) and rewrite the // alter the generated protobuf file to remove the generated types (but leave the serializers) and rewrite the
// package statement to match the desired package name // package statement to match the desired package name
if err := RewriteGeneratedGogoProtobufFile(outputPath, p.ExtractGeneratedType, buf.Bytes()); err != nil { if err := RewriteGeneratedGogoProtobufFile(outputPath, p.ExtractGeneratedType, p.OptionalTypeName, buf.Bytes()); err != nil {
log.Fatalf("Unable to rewrite generated %s: %v", outputPath, err) log.Fatalf("Unable to rewrite generated %s: %v", outputPath, err)
} }

View File

@ -118,6 +118,19 @@ func isProtoable(seen map[*types.Type]bool, t *types.Type) bool {
} }
} }
// isOptionalAlias should return true if the specified type has an underlying type
// (is an alias) of a map or slice and has the comment tag protobuf.nullable=true,
// indicating that the type should be nullable in protobuf.
func isOptionalAlias(t *types.Type) bool {
if t.Underlying == nil || (t.Underlying.Kind != types.Map && t.Underlying.Kind != types.Slice) {
return false
}
if types.ExtractCommentTags("+", t.CommentLines)["protobuf.nullable"] != "true" {
return false
}
return true
}
func (g *genProtoIDL) Imports(c *generator.Context) (imports []string) { func (g *genProtoIDL) Imports(c *generator.Context) (imports []string) {
lines := []string{} lines := []string{}
// TODO: this could be expressed more cleanly // TODO: this could be expressed more cleanly
@ -149,6 +162,8 @@ func (g *genProtoIDL) GenerateType(c *generator.Context, t *types.Type, w io.Wri
t: t, t: t,
} }
switch t.Kind { switch t.Kind {
case types.Alias:
return b.doAlias(sw)
case types.Struct: case types.Struct:
return b.doStruct(sw) return b.doStruct(sw)
default: default:
@ -206,7 +221,7 @@ func (p protobufLocator) ProtoTypeFor(t *types.Type) (*types.Type, error) {
return t, nil return t, nil
} }
// it's a message // it's a message
if t.Kind == types.Struct { if t.Kind == types.Struct || isOptionalAlias(t) {
t := &types.Type{ t := &types.Type{
Name: p.namer.GoNameToProtoName(t.Name), Name: p.namer.GoNameToProtoName(t.Name),
Kind: types.Protobuf, Kind: types.Protobuf,
@ -232,6 +247,37 @@ func (b bodyGen) unknown(sw *generator.SnippetWriter) error {
return fmt.Errorf("not sure how to generate: %#v", b.t) return fmt.Errorf("not sure how to generate: %#v", b.t)
} }
func (b bodyGen) doAlias(sw *generator.SnippetWriter) error {
if !isOptionalAlias(b.t) {
return nil
}
var kind string
switch b.t.Underlying.Kind {
case types.Map:
kind = "map"
default:
kind = "slice"
}
optional := &types.Type{
Name: b.t.Name,
Kind: types.Struct,
CommentLines: b.t.CommentLines,
SecondClosestCommentLines: b.t.SecondClosestCommentLines,
Members: []types.Member{
{
Name: "Items",
CommentLines: fmt.Sprintf("items, if empty, will result in an empty %s\n", kind),
Type: b.t.Underlying,
},
},
}
nested := b
nested.t = optional
return nested.doStruct(sw)
}
func (b bodyGen) doStruct(sw *generator.SnippetWriter) error { func (b bodyGen) doStruct(sw *generator.SnippetWriter) error {
if len(b.t.Name.Name) == 0 { if len(b.t.Name.Name) == 0 {
return nil return nil
@ -421,7 +467,7 @@ func memberTypeToProtobufField(locator ProtobufLocator, field *protoField, t *ty
if err := memberTypeToProtobufField(locator, keyField, t.Key); err != nil { if err := memberTypeToProtobufField(locator, keyField, t.Key); err != nil {
return err return err
} }
// All other protobuf types has kind types.Protobuf, so setting types.Map // All other protobuf types have kind types.Protobuf, so setting types.Map
// here would be very misleading. // here would be very misleading.
field.Type = &types.Type{ field.Type = &types.Type{
Kind: types.Protobuf, Kind: types.Protobuf,
@ -444,14 +490,19 @@ func memberTypeToProtobufField(locator ProtobufLocator, field *protoField, t *ty
} }
field.Nullable = true field.Nullable = true
case types.Alias: case types.Alias:
if err := memberTypeToProtobufField(locator, field, t.Underlying); err != nil { if isOptionalAlias(t) {
log.Printf("failed to alias: %s %s: err %v", t.Name, t.Underlying.Name, err) field.Type, err = locator.ProtoTypeFor(t)
return err field.Nullable = true
} else {
if err := memberTypeToProtobufField(locator, field, t.Underlying); err != nil {
log.Printf("failed to alias: %s %s: err %v", t.Name, t.Underlying.Name, err)
return err
}
if field.Extras == nil {
field.Extras = make(map[string]string)
}
field.Extras["(gogoproto.casttype)"] = strconv.Quote(locator.CastTypeName(t.Name))
} }
if field.Extras == nil {
field.Extras = make(map[string]string)
}
field.Extras["(gogoproto.casttype)"] = strconv.Quote(locator.CastTypeName(t.Name))
case types.Slice: case types.Slice:
if t.Elem.Name.Name == "byte" && len(t.Elem.Name.Package) == 0 { if t.Elem.Name.Name == "byte" && len(t.Elem.Name.Package) == 0 {
field.Type = &types.Type{Name: types.Name{Name: "bytes"}, Kind: types.Protobuf} field.Type = &types.Type{Name: types.Name{Name: "bytes"}, Kind: types.Protobuf}
@ -661,7 +712,7 @@ func assembleProtoFile(w io.Writer, f *generator.File) {
fmt.Fprint(w, "syntax = 'proto2';\n\n") fmt.Fprint(w, "syntax = 'proto2';\n\n")
if len(f.PackageName) > 0 { if len(f.PackageName) > 0 {
fmt.Fprintf(w, "package %v;\n\n", f.PackageName) fmt.Fprintf(w, "package %s;\n\n", f.PackageName)
} }
if len(f.Imports) > 0 { if len(f.Imports) > 0 {

View File

@ -101,7 +101,7 @@ type typeNameSet map[types.Name]*protobufPackage
// assignGoTypeToProtoPackage looks for Go and Protobuf types that are referenced by a type in // assignGoTypeToProtoPackage looks for Go and Protobuf types that are referenced by a type in
// a package. It will not recurse into protobuf types. // a package. It will not recurse into protobuf types.
func assignGoTypeToProtoPackage(p *protobufPackage, t *types.Type, local, global typeNameSet) { func assignGoTypeToProtoPackage(p *protobufPackage, t *types.Type, local, global typeNameSet, optional map[types.Name]struct{}) {
newT, isProto := isFundamentalProtoType(t) newT, isProto := isFundamentalProtoType(t)
if isProto { if isProto {
t = newT t = newT
@ -136,20 +136,23 @@ func assignGoTypeToProtoPackage(p *protobufPackage, t *types.Type, local, global
continue continue
} }
if err := protobufTagToField(tag, field, m, t, p.ProtoTypeName()); err == nil && field.Type != nil { if err := protobufTagToField(tag, field, m, t, p.ProtoTypeName()); err == nil && field.Type != nil {
assignGoTypeToProtoPackage(p, field.Type, local, global) assignGoTypeToProtoPackage(p, field.Type, local, global, optional)
continue continue
} }
assignGoTypeToProtoPackage(p, m.Type, local, global) assignGoTypeToProtoPackage(p, m.Type, local, global, optional)
} }
// TODO: should methods be walked? // TODO: should methods be walked?
if t.Elem != nil { if t.Elem != nil {
assignGoTypeToProtoPackage(p, t.Elem, local, global) assignGoTypeToProtoPackage(p, t.Elem, local, global, optional)
} }
if t.Key != nil { if t.Key != nil {
assignGoTypeToProtoPackage(p, t.Key, local, global) assignGoTypeToProtoPackage(p, t.Key, local, global, optional)
} }
if t.Underlying != nil { if t.Underlying != nil {
assignGoTypeToProtoPackage(p, t.Underlying, local, global) if t.Kind == types.Alias && isOptionalAlias(t) {
optional[t.Name] = struct{}{}
}
assignGoTypeToProtoPackage(p, t.Underlying, local, global, optional)
} }
} }
@ -157,19 +160,24 @@ func (n *protobufNamer) AssignTypesToPackages(c *generator.Context) error {
global := make(typeNameSet) global := make(typeNameSet)
for _, p := range n.packages { for _, p := range n.packages {
local := make(typeNameSet) local := make(typeNameSet)
optional := make(map[types.Name]struct{})
p.Imports = NewImportTracker(p.ProtoTypeName()) p.Imports = NewImportTracker(p.ProtoTypeName())
for _, t := range c.Order { for _, t := range c.Order {
if t.Name.Package != p.PackagePath { if t.Name.Package != p.PackagePath {
continue continue
} }
assignGoTypeToProtoPackage(p, t, local, global) assignGoTypeToProtoPackage(p, t, local, global, optional)
} }
p.FilterTypes = make(map[types.Name]struct{}) p.FilterTypes = make(map[types.Name]struct{})
p.LocalNames = make(map[string]struct{}) p.LocalNames = make(map[string]struct{})
p.OptionalTypeNames = make(map[string]struct{})
for k, v := range local { for k, v := range local {
if v == p { if v == p {
p.FilterTypes[k] = struct{}{} p.FilterTypes[k] = struct{}{}
p.LocalNames[k.Name] = struct{}{} p.LocalNames[k.Name] = struct{}{}
if _, ok := optional[k]; ok {
p.OptionalTypeNames[k.Name] = struct{}{}
}
} }
} }
} }

View File

@ -75,6 +75,10 @@ type protobufPackage struct {
// A list of names that this package exports // A list of names that this package exports
LocalNames map[string]struct{} LocalNames map[string]struct{}
// A list of type names in this package that will need marshaller rewriting
// to remove synthetic protobuf fields.
OptionalTypeNames map[string]struct{}
// A list of struct tags to generate onto named struct fields // A list of struct tags to generate onto named struct fields
StructTags map[string]map[string]string StructTags map[string]map[string]string
@ -110,7 +114,9 @@ func (p *protobufPackage) filterFunc(c *generator.Context, t *types.Type) bool {
case types.Builtin: case types.Builtin:
return false return false
case types.Alias: case types.Alias:
return false if !isOptionalAlias(t) {
return false
}
case types.Slice, types.Array, types.Map: case types.Slice, types.Array, types.Map:
return false return false
case types.Pointer: case types.Pointer:
@ -128,6 +134,11 @@ func (p *protobufPackage) HasGoType(name string) bool {
return ok return ok
} }
func (p *protobufPackage) OptionalTypeName(name string) bool {
_, ok := p.OptionalTypeNames[name]
return ok
}
func (p *protobufPackage) ExtractGeneratedType(t *ast.TypeSpec) bool { func (p *protobufPackage) ExtractGeneratedType(t *ast.TypeSpec) bool {
if !p.HasGoType(t.Name.Name) { if !p.HasGoType(t.Name.Name) {
return false return false

View File

@ -74,10 +74,19 @@ func rewriteFile(name string, header []byte, rewriteFn func(*token.FileSet, *ast
// removed from the destination file. // removed from the destination file.
type ExtractFunc func(*ast.TypeSpec) bool type ExtractFunc func(*ast.TypeSpec) bool
func RewriteGeneratedGogoProtobufFile(name string, extractFn ExtractFunc, header []byte) error { // OptionalFunc returns true if the provided local name is a type that has protobuf.nullable=true
// and should have its marshal functions adjusted to remove the 'Items' accessor.
type OptionalFunc func(name string) bool
func RewriteGeneratedGogoProtobufFile(name string, extractFn ExtractFunc, optionalFn OptionalFunc, header []byte) error {
return rewriteFile(name, header, func(fset *token.FileSet, file *ast.File) error { return rewriteFile(name, header, func(fset *token.FileSet, file *ast.File) error {
cmap := ast.NewCommentMap(fset, file, file.Comments) cmap := ast.NewCommentMap(fset, file, file.Comments)
// transform methods that point to optional maps or slices
for _, d := range file.Decls {
rewriteOptionalMethods(d, optionalFn)
}
// remove types that are already declared // remove types that are already declared
decls := []ast.Decl{} decls := []ast.Decl{}
for _, d := range file.Decls { for _, d := range file.Decls {
@ -97,6 +106,168 @@ func RewriteGeneratedGogoProtobufFile(name string, extractFn ExtractFunc, header
}) })
} }
// rewriteOptionalMethods makes specific mutations to marshaller methods that belong to types identified
// as being "optional" (they may be nil on the wire). This allows protobuf to serialize a map or slice and
// properly discriminate between empty and nil (which is not possible in protobuf).
// TODO: move into upstream gogo-protobuf once https://github.com/gogo/protobuf/issues/181
// has agreement
func rewriteOptionalMethods(decl ast.Decl, isOptional OptionalFunc) {
switch t := decl.(type) {
case *ast.FuncDecl:
ident, ptr, ok := receiver(t)
if !ok {
return
}
// correct initialization of the form `m.Field = &OptionalType{}` to
// `m.Field = OptionalType{}`
if t.Name.Name == "Unmarshal" {
ast.Walk(optionalAssignmentVisitor{fn: isOptional}, t.Body)
}
if !isOptional(ident.Name) {
return
}
switch t.Name.Name {
case "Unmarshal":
ast.Walk(&optionalItemsVisitor{}, t.Body)
case "MarshalTo", "Size":
ast.Walk(&optionalItemsVisitor{}, t.Body)
fallthrough
case "Marshal":
// if the method has a pointer receiver, set it back to a normal receiver
if ptr {
t.Recv.List[0].Type = ident
}
}
}
}
type optionalAssignmentVisitor struct {
fn OptionalFunc
}
// Visit walks the provided node, transforming field initializations of the form
// m.Field = &OptionalType{} -> m.Field = OptionalType{}
func (v optionalAssignmentVisitor) Visit(n ast.Node) ast.Visitor {
switch t := n.(type) {
case *ast.AssignStmt:
if len(t.Lhs) == 1 && len(t.Rhs) == 1 {
if !isFieldSelector(t.Lhs[0], "m", "") {
return nil
}
unary, ok := t.Rhs[0].(*ast.UnaryExpr)
if !ok || unary.Op != token.AND {
return nil
}
composite, ok := unary.X.(*ast.CompositeLit)
if !ok || composite.Type == nil || len(composite.Elts) != 0 {
return nil
}
if ident, ok := composite.Type.(*ast.Ident); ok && v.fn(ident.Name) {
t.Rhs[0] = composite
}
}
return nil
}
return v
}
type optionalItemsVisitor struct{}
// Visit walks the provided node, looking for specific patterns to transform that match
// the effective outcome of turning struct{ map[x]y || []x } into map[x]y or []x.
func (v *optionalItemsVisitor) Visit(n ast.Node) ast.Visitor {
switch t := n.(type) {
case *ast.RangeStmt:
if isFieldSelector(t.X, "m", "Items") {
t.X = &ast.Ident{Name: "m"}
}
case *ast.AssignStmt:
if len(t.Lhs) == 1 && len(t.Rhs) == 1 {
switch lhs := t.Lhs[0].(type) {
case *ast.IndexExpr:
if isFieldSelector(lhs.X, "m", "Items") {
lhs.X = &ast.StarExpr{X: &ast.Ident{Name: "m"}}
}
default:
if isFieldSelector(t.Lhs[0], "m", "Items") {
t.Lhs[0] = &ast.StarExpr{X: &ast.Ident{Name: "m"}}
}
}
switch rhs := t.Rhs[0].(type) {
case *ast.CallExpr:
if ident, ok := rhs.Fun.(*ast.Ident); ok && ident.Name == "append" {
ast.Walk(v, rhs)
if len(rhs.Args) > 0 {
switch arg := rhs.Args[0].(type) {
case *ast.Ident:
if arg.Name == "m" {
rhs.Args[0] = &ast.StarExpr{X: &ast.Ident{Name: "m"}}
}
}
}
return nil
}
}
}
case *ast.IfStmt:
if b, ok := t.Cond.(*ast.BinaryExpr); ok && b.Op == token.EQL {
if isFieldSelector(b.X, "m", "Items") && isIdent(b.Y, "nil") {
b.X = &ast.StarExpr{X: &ast.Ident{Name: "m"}}
}
}
case *ast.IndexExpr:
if isFieldSelector(t.X, "m", "Items") {
t.X = &ast.Ident{Name: "m"}
return nil
}
case *ast.CallExpr:
changed := false
for i := range t.Args {
if isFieldSelector(t.Args[i], "m", "Items") {
t.Args[i] = &ast.Ident{Name: "m"}
changed = true
}
}
if changed {
return nil
}
}
return v
}
func isFieldSelector(n ast.Expr, name, field string) bool {
s, ok := n.(*ast.SelectorExpr)
if !ok || s.Sel == nil || (field != "" && s.Sel.Name != field) {
return false
}
return isIdent(s.X, name)
}
func isIdent(n ast.Expr, value string) bool {
ident, ok := n.(*ast.Ident)
return ok && ident.Name == value
}
func receiver(f *ast.FuncDecl) (ident *ast.Ident, pointer bool, ok bool) {
if f.Recv == nil || len(f.Recv.List) != 1 {
return nil, false, false
}
switch t := f.Recv.List[0].Type.(type) {
case *ast.StarExpr:
identity, ok := t.X.(*ast.Ident)
if !ok {
return nil, false, false
}
return identity, true, true
case *ast.Ident:
return t, false, true
}
return nil, false, false
}
// dropExistingTypeDeclarations removes any type declaration for which extractFn returns true. The function // dropExistingTypeDeclarations removes any type declaration for which extractFn returns true. The function
// returns true if the entire declaration should be dropped. // returns true if the entire declaration should be dropped.
func dropExistingTypeDeclarations(decl ast.Decl, extractFn ExtractFunc) bool { func dropExistingTypeDeclarations(decl ast.Decl, extractFn ExtractFunc) bool {

View File

@ -576,7 +576,7 @@ func (b *Builder) walkType(u types.Universe, useName *types.Name, in tc.Type) *t
return out return out
case *tc.Named: case *tc.Named:
switch t.Underlying().(type) { switch t.Underlying().(type) {
case *tc.Named, *tc.Basic: case *tc.Named, *tc.Basic, *tc.Map, *tc.Slice:
name := tcNameToName(t.String()) name := tcNameToName(t.String())
out := u.Type(name) out := u.Type(name)
if out.Kind != types.Unknown { if out.Kind != types.Unknown {
@ -591,6 +591,9 @@ func (b *Builder) walkType(u types.Universe, useName *types.Name, in tc.Type) *t
// "feature" for users. This flattens those types // "feature" for users. This flattens those types
// together. // together.
name := tcNameToName(t.String()) name := tcNameToName(t.String())
if name.Name == "OptionalMap" {
fmt.Printf("DEBUG: flattening %T -> %T\n", t, t.Underlying())
}
if out := u.Type(name); out.Kind != types.Unknown { if out := u.Type(name); out.Kind != types.Unknown {
return out // short circuit if we've already made this. return out // short circuit if we've already made this.
} }

View File

@ -391,8 +391,16 @@ type Interface interface{Method(a, b string) (c, d string)}
t.Errorf("type %s not found", n) t.Errorf("type %s not found", n)
continue continue
} }
if e, a := item.k, thisType.Kind; e != a { underlyingType := thisType
t.Errorf("%v-%s: type kind wrong, wanted %v, got %v (%#v)", nameIndex, n, e, a, thisType) if item.k != types.Alias && thisType.Kind == types.Alias {
underlyingType = thisType.Underlying
if underlyingType == nil {
t.Errorf("underlying type %s not found", n)
continue
}
}
if e, a := item.k, underlyingType.Kind; e != a {
t.Errorf("%v-%s: type kind wrong, wanted %v, got %v (%#v)", nameIndex, n, e, a, underlyingType)
} }
if e, a := item.names[nameIndex], namer.Name(thisType); e != a { if e, a := item.names[nameIndex], namer.Name(thisType); e != a {
t.Errorf("%v-%s: Expected %q, got %q", nameIndex, n, e, a) t.Errorf("%v-%s: Expected %q, got %q", nameIndex, n, e, a)