Add optional slice and map support to protobuf

Specifying // +protobuf.nullable=true on a Go type that is an alias of a
map or slice will generate a synthetic protobuf message with the type
name that will serialize to the wire in a way that allows the difference
between empty and nil to be recorded.

For instance:

    // +protobuf.nullable=true
    types OptionalMap map[string]string

will create the following message:

    message OptionalMap {
      map<string, string> Items = 1
    }

and generate marshallers that use the presence of OptionalMap to
determine whether the map is nil (rather than Items, which protobuf
provides no way to delineate between empty and nil).
This commit is contained in:
Clayton Coleman 2016-06-12 18:08:34 -04:00
parent 9f7e16c256
commit 5f9e7a00b8
No known key found for this signature in database
GPG Key ID: 3D16906B4F1C5CB3
7 changed files with 278 additions and 23 deletions

View File

@ -124,6 +124,9 @@ func Run(g *Generator) {
protobufNames := NewProtobufNamer()
outputPackages := generator.Packages{}
for _, d := range strings.Split(g.Packages, ",") {
if strings.Contains(d, "-") {
log.Fatalf("Package names must be valid protobuf package identifiers, which allow only [a-z0-9_]: %s", d)
}
generateAllTypes, outputPackage := true, true
switch {
case strings.HasPrefix(d, "+"):
@ -235,7 +238,7 @@ func Run(g *Generator) {
// alter the generated protobuf file to remove the generated types (but leave the serializers) and rewrite the
// package statement to match the desired package name
if err := RewriteGeneratedGogoProtobufFile(outputPath, p.ExtractGeneratedType, buf.Bytes()); err != nil {
if err := RewriteGeneratedGogoProtobufFile(outputPath, p.ExtractGeneratedType, p.OptionalTypeName, buf.Bytes()); err != nil {
log.Fatalf("Unable to rewrite generated %s: %v", outputPath, err)
}

View File

@ -118,6 +118,19 @@ func isProtoable(seen map[*types.Type]bool, t *types.Type) bool {
}
}
// isOptionalAlias should return true if the specified type has an underlying type
// (is an alias) of a map or slice and has the comment tag protobuf.nullable=true,
// indicating that the type should be nullable in protobuf.
func isOptionalAlias(t *types.Type) bool {
if t.Underlying == nil || (t.Underlying.Kind != types.Map && t.Underlying.Kind != types.Slice) {
return false
}
if types.ExtractCommentTags("+", t.CommentLines)["protobuf.nullable"] != "true" {
return false
}
return true
}
func (g *genProtoIDL) Imports(c *generator.Context) (imports []string) {
lines := []string{}
// TODO: this could be expressed more cleanly
@ -149,6 +162,8 @@ func (g *genProtoIDL) GenerateType(c *generator.Context, t *types.Type, w io.Wri
t: t,
}
switch t.Kind {
case types.Alias:
return b.doAlias(sw)
case types.Struct:
return b.doStruct(sw)
default:
@ -206,7 +221,7 @@ func (p protobufLocator) ProtoTypeFor(t *types.Type) (*types.Type, error) {
return t, nil
}
// it's a message
if t.Kind == types.Struct {
if t.Kind == types.Struct || isOptionalAlias(t) {
t := &types.Type{
Name: p.namer.GoNameToProtoName(t.Name),
Kind: types.Protobuf,
@ -232,6 +247,37 @@ func (b bodyGen) unknown(sw *generator.SnippetWriter) error {
return fmt.Errorf("not sure how to generate: %#v", b.t)
}
func (b bodyGen) doAlias(sw *generator.SnippetWriter) error {
if !isOptionalAlias(b.t) {
return nil
}
var kind string
switch b.t.Underlying.Kind {
case types.Map:
kind = "map"
default:
kind = "slice"
}
optional := &types.Type{
Name: b.t.Name,
Kind: types.Struct,
CommentLines: b.t.CommentLines,
SecondClosestCommentLines: b.t.SecondClosestCommentLines,
Members: []types.Member{
{
Name: "Items",
CommentLines: fmt.Sprintf("items, if empty, will result in an empty %s\n", kind),
Type: b.t.Underlying,
},
},
}
nested := b
nested.t = optional
return nested.doStruct(sw)
}
func (b bodyGen) doStruct(sw *generator.SnippetWriter) error {
if len(b.t.Name.Name) == 0 {
return nil
@ -421,7 +467,7 @@ func memberTypeToProtobufField(locator ProtobufLocator, field *protoField, t *ty
if err := memberTypeToProtobufField(locator, keyField, t.Key); err != nil {
return err
}
// All other protobuf types has kind types.Protobuf, so setting types.Map
// All other protobuf types have kind types.Protobuf, so setting types.Map
// here would be very misleading.
field.Type = &types.Type{
Kind: types.Protobuf,
@ -444,14 +490,19 @@ func memberTypeToProtobufField(locator ProtobufLocator, field *protoField, t *ty
}
field.Nullable = true
case types.Alias:
if err := memberTypeToProtobufField(locator, field, t.Underlying); err != nil {
log.Printf("failed to alias: %s %s: err %v", t.Name, t.Underlying.Name, err)
return err
if isOptionalAlias(t) {
field.Type, err = locator.ProtoTypeFor(t)
field.Nullable = true
} else {
if err := memberTypeToProtobufField(locator, field, t.Underlying); err != nil {
log.Printf("failed to alias: %s %s: err %v", t.Name, t.Underlying.Name, err)
return err
}
if field.Extras == nil {
field.Extras = make(map[string]string)
}
field.Extras["(gogoproto.casttype)"] = strconv.Quote(locator.CastTypeName(t.Name))
}
if field.Extras == nil {
field.Extras = make(map[string]string)
}
field.Extras["(gogoproto.casttype)"] = strconv.Quote(locator.CastTypeName(t.Name))
case types.Slice:
if t.Elem.Name.Name == "byte" && len(t.Elem.Name.Package) == 0 {
field.Type = &types.Type{Name: types.Name{Name: "bytes"}, Kind: types.Protobuf}
@ -661,7 +712,7 @@ func assembleProtoFile(w io.Writer, f *generator.File) {
fmt.Fprint(w, "syntax = 'proto2';\n\n")
if len(f.PackageName) > 0 {
fmt.Fprintf(w, "package %v;\n\n", f.PackageName)
fmt.Fprintf(w, "package %s;\n\n", f.PackageName)
}
if len(f.Imports) > 0 {

View File

@ -101,7 +101,7 @@ type typeNameSet map[types.Name]*protobufPackage
// assignGoTypeToProtoPackage looks for Go and Protobuf types that are referenced by a type in
// a package. It will not recurse into protobuf types.
func assignGoTypeToProtoPackage(p *protobufPackage, t *types.Type, local, global typeNameSet) {
func assignGoTypeToProtoPackage(p *protobufPackage, t *types.Type, local, global typeNameSet, optional map[types.Name]struct{}) {
newT, isProto := isFundamentalProtoType(t)
if isProto {
t = newT
@ -136,20 +136,23 @@ func assignGoTypeToProtoPackage(p *protobufPackage, t *types.Type, local, global
continue
}
if err := protobufTagToField(tag, field, m, t, p.ProtoTypeName()); err == nil && field.Type != nil {
assignGoTypeToProtoPackage(p, field.Type, local, global)
assignGoTypeToProtoPackage(p, field.Type, local, global, optional)
continue
}
assignGoTypeToProtoPackage(p, m.Type, local, global)
assignGoTypeToProtoPackage(p, m.Type, local, global, optional)
}
// TODO: should methods be walked?
if t.Elem != nil {
assignGoTypeToProtoPackage(p, t.Elem, local, global)
assignGoTypeToProtoPackage(p, t.Elem, local, global, optional)
}
if t.Key != nil {
assignGoTypeToProtoPackage(p, t.Key, local, global)
assignGoTypeToProtoPackage(p, t.Key, local, global, optional)
}
if t.Underlying != nil {
assignGoTypeToProtoPackage(p, t.Underlying, local, global)
if t.Kind == types.Alias && isOptionalAlias(t) {
optional[t.Name] = struct{}{}
}
assignGoTypeToProtoPackage(p, t.Underlying, local, global, optional)
}
}
@ -157,19 +160,24 @@ func (n *protobufNamer) AssignTypesToPackages(c *generator.Context) error {
global := make(typeNameSet)
for _, p := range n.packages {
local := make(typeNameSet)
optional := make(map[types.Name]struct{})
p.Imports = NewImportTracker(p.ProtoTypeName())
for _, t := range c.Order {
if t.Name.Package != p.PackagePath {
continue
}
assignGoTypeToProtoPackage(p, t, local, global)
assignGoTypeToProtoPackage(p, t, local, global, optional)
}
p.FilterTypes = make(map[types.Name]struct{})
p.LocalNames = make(map[string]struct{})
p.OptionalTypeNames = make(map[string]struct{})
for k, v := range local {
if v == p {
p.FilterTypes[k] = struct{}{}
p.LocalNames[k.Name] = struct{}{}
if _, ok := optional[k]; ok {
p.OptionalTypeNames[k.Name] = struct{}{}
}
}
}
}

View File

@ -75,6 +75,10 @@ type protobufPackage struct {
// A list of names that this package exports
LocalNames map[string]struct{}
// A list of type names in this package that will need marshaller rewriting
// to remove synthetic protobuf fields.
OptionalTypeNames map[string]struct{}
// A list of struct tags to generate onto named struct fields
StructTags map[string]map[string]string
@ -110,7 +114,9 @@ func (p *protobufPackage) filterFunc(c *generator.Context, t *types.Type) bool {
case types.Builtin:
return false
case types.Alias:
return false
if !isOptionalAlias(t) {
return false
}
case types.Slice, types.Array, types.Map:
return false
case types.Pointer:
@ -128,6 +134,11 @@ func (p *protobufPackage) HasGoType(name string) bool {
return ok
}
func (p *protobufPackage) OptionalTypeName(name string) bool {
_, ok := p.OptionalTypeNames[name]
return ok
}
func (p *protobufPackage) ExtractGeneratedType(t *ast.TypeSpec) bool {
if !p.HasGoType(t.Name.Name) {
return false

View File

@ -74,10 +74,19 @@ func rewriteFile(name string, header []byte, rewriteFn func(*token.FileSet, *ast
// removed from the destination file.
type ExtractFunc func(*ast.TypeSpec) bool
func RewriteGeneratedGogoProtobufFile(name string, extractFn ExtractFunc, header []byte) error {
// OptionalFunc returns true if the provided local name is a type that has protobuf.nullable=true
// and should have its marshal functions adjusted to remove the 'Items' accessor.
type OptionalFunc func(name string) bool
func RewriteGeneratedGogoProtobufFile(name string, extractFn ExtractFunc, optionalFn OptionalFunc, header []byte) error {
return rewriteFile(name, header, func(fset *token.FileSet, file *ast.File) error {
cmap := ast.NewCommentMap(fset, file, file.Comments)
// transform methods that point to optional maps or slices
for _, d := range file.Decls {
rewriteOptionalMethods(d, optionalFn)
}
// remove types that are already declared
decls := []ast.Decl{}
for _, d := range file.Decls {
@ -97,6 +106,168 @@ func RewriteGeneratedGogoProtobufFile(name string, extractFn ExtractFunc, header
})
}
// rewriteOptionalMethods makes specific mutations to marshaller methods that belong to types identified
// as being "optional" (they may be nil on the wire). This allows protobuf to serialize a map or slice and
// properly discriminate between empty and nil (which is not possible in protobuf).
// TODO: move into upstream gogo-protobuf once https://github.com/gogo/protobuf/issues/181
// has agreement
func rewriteOptionalMethods(decl ast.Decl, isOptional OptionalFunc) {
switch t := decl.(type) {
case *ast.FuncDecl:
ident, ptr, ok := receiver(t)
if !ok {
return
}
// correct initialization of the form `m.Field = &OptionalType{}` to
// `m.Field = OptionalType{}`
if t.Name.Name == "Unmarshal" {
ast.Walk(optionalAssignmentVisitor{fn: isOptional}, t.Body)
}
if !isOptional(ident.Name) {
return
}
switch t.Name.Name {
case "Unmarshal":
ast.Walk(&optionalItemsVisitor{}, t.Body)
case "MarshalTo", "Size":
ast.Walk(&optionalItemsVisitor{}, t.Body)
fallthrough
case "Marshal":
// if the method has a pointer receiver, set it back to a normal receiver
if ptr {
t.Recv.List[0].Type = ident
}
}
}
}
type optionalAssignmentVisitor struct {
fn OptionalFunc
}
// Visit walks the provided node, transforming field initializations of the form
// m.Field = &OptionalType{} -> m.Field = OptionalType{}
func (v optionalAssignmentVisitor) Visit(n ast.Node) ast.Visitor {
switch t := n.(type) {
case *ast.AssignStmt:
if len(t.Lhs) == 1 && len(t.Rhs) == 1 {
if !isFieldSelector(t.Lhs[0], "m", "") {
return nil
}
unary, ok := t.Rhs[0].(*ast.UnaryExpr)
if !ok || unary.Op != token.AND {
return nil
}
composite, ok := unary.X.(*ast.CompositeLit)
if !ok || composite.Type == nil || len(composite.Elts) != 0 {
return nil
}
if ident, ok := composite.Type.(*ast.Ident); ok && v.fn(ident.Name) {
t.Rhs[0] = composite
}
}
return nil
}
return v
}
type optionalItemsVisitor struct{}
// Visit walks the provided node, looking for specific patterns to transform that match
// the effective outcome of turning struct{ map[x]y || []x } into map[x]y or []x.
func (v *optionalItemsVisitor) Visit(n ast.Node) ast.Visitor {
switch t := n.(type) {
case *ast.RangeStmt:
if isFieldSelector(t.X, "m", "Items") {
t.X = &ast.Ident{Name: "m"}
}
case *ast.AssignStmt:
if len(t.Lhs) == 1 && len(t.Rhs) == 1 {
switch lhs := t.Lhs[0].(type) {
case *ast.IndexExpr:
if isFieldSelector(lhs.X, "m", "Items") {
lhs.X = &ast.StarExpr{X: &ast.Ident{Name: "m"}}
}
default:
if isFieldSelector(t.Lhs[0], "m", "Items") {
t.Lhs[0] = &ast.StarExpr{X: &ast.Ident{Name: "m"}}
}
}
switch rhs := t.Rhs[0].(type) {
case *ast.CallExpr:
if ident, ok := rhs.Fun.(*ast.Ident); ok && ident.Name == "append" {
ast.Walk(v, rhs)
if len(rhs.Args) > 0 {
switch arg := rhs.Args[0].(type) {
case *ast.Ident:
if arg.Name == "m" {
rhs.Args[0] = &ast.StarExpr{X: &ast.Ident{Name: "m"}}
}
}
}
return nil
}
}
}
case *ast.IfStmt:
if b, ok := t.Cond.(*ast.BinaryExpr); ok && b.Op == token.EQL {
if isFieldSelector(b.X, "m", "Items") && isIdent(b.Y, "nil") {
b.X = &ast.StarExpr{X: &ast.Ident{Name: "m"}}
}
}
case *ast.IndexExpr:
if isFieldSelector(t.X, "m", "Items") {
t.X = &ast.Ident{Name: "m"}
return nil
}
case *ast.CallExpr:
changed := false
for i := range t.Args {
if isFieldSelector(t.Args[i], "m", "Items") {
t.Args[i] = &ast.Ident{Name: "m"}
changed = true
}
}
if changed {
return nil
}
}
return v
}
func isFieldSelector(n ast.Expr, name, field string) bool {
s, ok := n.(*ast.SelectorExpr)
if !ok || s.Sel == nil || (field != "" && s.Sel.Name != field) {
return false
}
return isIdent(s.X, name)
}
func isIdent(n ast.Expr, value string) bool {
ident, ok := n.(*ast.Ident)
return ok && ident.Name == value
}
func receiver(f *ast.FuncDecl) (ident *ast.Ident, pointer bool, ok bool) {
if f.Recv == nil || len(f.Recv.List) != 1 {
return nil, false, false
}
switch t := f.Recv.List[0].Type.(type) {
case *ast.StarExpr:
identity, ok := t.X.(*ast.Ident)
if !ok {
return nil, false, false
}
return identity, true, true
case *ast.Ident:
return t, false, true
}
return nil, false, false
}
// dropExistingTypeDeclarations removes any type declaration for which extractFn returns true. The function
// returns true if the entire declaration should be dropped.
func dropExistingTypeDeclarations(decl ast.Decl, extractFn ExtractFunc) bool {

View File

@ -576,7 +576,7 @@ func (b *Builder) walkType(u types.Universe, useName *types.Name, in tc.Type) *t
return out
case *tc.Named:
switch t.Underlying().(type) {
case *tc.Named, *tc.Basic:
case *tc.Named, *tc.Basic, *tc.Map, *tc.Slice:
name := tcNameToName(t.String())
out := u.Type(name)
if out.Kind != types.Unknown {
@ -591,6 +591,9 @@ func (b *Builder) walkType(u types.Universe, useName *types.Name, in tc.Type) *t
// "feature" for users. This flattens those types
// together.
name := tcNameToName(t.String())
if name.Name == "OptionalMap" {
fmt.Printf("DEBUG: flattening %T -> %T\n", t, t.Underlying())
}
if out := u.Type(name); out.Kind != types.Unknown {
return out // short circuit if we've already made this.
}

View File

@ -391,8 +391,16 @@ type Interface interface{Method(a, b string) (c, d string)}
t.Errorf("type %s not found", n)
continue
}
if e, a := item.k, thisType.Kind; e != a {
t.Errorf("%v-%s: type kind wrong, wanted %v, got %v (%#v)", nameIndex, n, e, a, thisType)
underlyingType := thisType
if item.k != types.Alias && thisType.Kind == types.Alias {
underlyingType = thisType.Underlying
if underlyingType == nil {
t.Errorf("underlying type %s not found", n)
continue
}
}
if e, a := item.k, underlyingType.Kind; e != a {
t.Errorf("%v-%s: type kind wrong, wanted %v, got %v (%#v)", nameIndex, n, e, a, underlyingType)
}
if e, a := item.names[nameIndex], namer.Name(thisType); e != a {
t.Errorf("%v-%s: Expected %q, got %q", nameIndex, n, e, a)