diff --git a/cmd/libs/go2idl/go-to-protobuf/protobuf/cmd.go b/cmd/libs/go2idl/go-to-protobuf/protobuf/cmd.go index 5f21816a4ab..7f8d04e1c43 100644 --- a/cmd/libs/go2idl/go-to-protobuf/protobuf/cmd.go +++ b/cmd/libs/go2idl/go-to-protobuf/protobuf/cmd.go @@ -124,6 +124,9 @@ func Run(g *Generator) { protobufNames := NewProtobufNamer() outputPackages := generator.Packages{} for _, d := range strings.Split(g.Packages, ",") { + if strings.Contains(d, "-") { + log.Fatalf("Package names must be valid protobuf package identifiers, which allow only [a-z0-9_]: %s", d) + } generateAllTypes, outputPackage := true, true switch { case strings.HasPrefix(d, "+"): @@ -235,7 +238,7 @@ func Run(g *Generator) { // alter the generated protobuf file to remove the generated types (but leave the serializers) and rewrite the // package statement to match the desired package name - if err := RewriteGeneratedGogoProtobufFile(outputPath, p.ExtractGeneratedType, buf.Bytes()); err != nil { + if err := RewriteGeneratedGogoProtobufFile(outputPath, p.ExtractGeneratedType, p.OptionalTypeName, buf.Bytes()); err != nil { log.Fatalf("Unable to rewrite generated %s: %v", outputPath, err) } diff --git a/cmd/libs/go2idl/go-to-protobuf/protobuf/generator.go b/cmd/libs/go2idl/go-to-protobuf/protobuf/generator.go index 92e9dc85e8f..739a6ef70f4 100644 --- a/cmd/libs/go2idl/go-to-protobuf/protobuf/generator.go +++ b/cmd/libs/go2idl/go-to-protobuf/protobuf/generator.go @@ -118,6 +118,19 @@ func isProtoable(seen map[*types.Type]bool, t *types.Type) bool { } } +// isOptionalAlias should return true if the specified type has an underlying type +// (is an alias) of a map or slice and has the comment tag protobuf.nullable=true, +// indicating that the type should be nullable in protobuf. +func isOptionalAlias(t *types.Type) bool { + if t.Underlying == nil || (t.Underlying.Kind != types.Map && t.Underlying.Kind != types.Slice) { + return false + } + if types.ExtractCommentTags("+", t.CommentLines)["protobuf.nullable"] != "true" { + return false + } + return true +} + func (g *genProtoIDL) Imports(c *generator.Context) (imports []string) { lines := []string{} // TODO: this could be expressed more cleanly @@ -149,6 +162,8 @@ func (g *genProtoIDL) GenerateType(c *generator.Context, t *types.Type, w io.Wri t: t, } switch t.Kind { + case types.Alias: + return b.doAlias(sw) case types.Struct: return b.doStruct(sw) default: @@ -206,7 +221,7 @@ func (p protobufLocator) ProtoTypeFor(t *types.Type) (*types.Type, error) { return t, nil } // it's a message - if t.Kind == types.Struct { + if t.Kind == types.Struct || isOptionalAlias(t) { t := &types.Type{ Name: p.namer.GoNameToProtoName(t.Name), Kind: types.Protobuf, @@ -232,6 +247,37 @@ func (b bodyGen) unknown(sw *generator.SnippetWriter) error { return fmt.Errorf("not sure how to generate: %#v", b.t) } +func (b bodyGen) doAlias(sw *generator.SnippetWriter) error { + if !isOptionalAlias(b.t) { + return nil + } + + var kind string + switch b.t.Underlying.Kind { + case types.Map: + kind = "map" + default: + kind = "slice" + } + optional := &types.Type{ + Name: b.t.Name, + Kind: types.Struct, + + CommentLines: b.t.CommentLines, + SecondClosestCommentLines: b.t.SecondClosestCommentLines, + Members: []types.Member{ + { + Name: "Items", + CommentLines: fmt.Sprintf("items, if empty, will result in an empty %s\n", kind), + Type: b.t.Underlying, + }, + }, + } + nested := b + nested.t = optional + return nested.doStruct(sw) +} + func (b bodyGen) doStruct(sw *generator.SnippetWriter) error { if len(b.t.Name.Name) == 0 { return nil @@ -421,7 +467,7 @@ func memberTypeToProtobufField(locator ProtobufLocator, field *protoField, t *ty if err := memberTypeToProtobufField(locator, keyField, t.Key); err != nil { return err } - // All other protobuf types has kind types.Protobuf, so setting types.Map + // All other protobuf types have kind types.Protobuf, so setting types.Map // here would be very misleading. field.Type = &types.Type{ Kind: types.Protobuf, @@ -444,14 +490,19 @@ func memberTypeToProtobufField(locator ProtobufLocator, field *protoField, t *ty } field.Nullable = true case types.Alias: - if err := memberTypeToProtobufField(locator, field, t.Underlying); err != nil { - log.Printf("failed to alias: %s %s: err %v", t.Name, t.Underlying.Name, err) - return err + if isOptionalAlias(t) { + field.Type, err = locator.ProtoTypeFor(t) + field.Nullable = true + } else { + if err := memberTypeToProtobufField(locator, field, t.Underlying); err != nil { + log.Printf("failed to alias: %s %s: err %v", t.Name, t.Underlying.Name, err) + return err + } + if field.Extras == nil { + field.Extras = make(map[string]string) + } + field.Extras["(gogoproto.casttype)"] = strconv.Quote(locator.CastTypeName(t.Name)) } - if field.Extras == nil { - field.Extras = make(map[string]string) - } - field.Extras["(gogoproto.casttype)"] = strconv.Quote(locator.CastTypeName(t.Name)) case types.Slice: if t.Elem.Name.Name == "byte" && len(t.Elem.Name.Package) == 0 { field.Type = &types.Type{Name: types.Name{Name: "bytes"}, Kind: types.Protobuf} @@ -661,7 +712,7 @@ func assembleProtoFile(w io.Writer, f *generator.File) { fmt.Fprint(w, "syntax = 'proto2';\n\n") if len(f.PackageName) > 0 { - fmt.Fprintf(w, "package %v;\n\n", f.PackageName) + fmt.Fprintf(w, "package %s;\n\n", f.PackageName) } if len(f.Imports) > 0 { diff --git a/cmd/libs/go2idl/go-to-protobuf/protobuf/namer.go b/cmd/libs/go2idl/go-to-protobuf/protobuf/namer.go index 9e2b9853462..8da08af4cd6 100644 --- a/cmd/libs/go2idl/go-to-protobuf/protobuf/namer.go +++ b/cmd/libs/go2idl/go-to-protobuf/protobuf/namer.go @@ -101,7 +101,7 @@ type typeNameSet map[types.Name]*protobufPackage // assignGoTypeToProtoPackage looks for Go and Protobuf types that are referenced by a type in // a package. It will not recurse into protobuf types. -func assignGoTypeToProtoPackage(p *protobufPackage, t *types.Type, local, global typeNameSet) { +func assignGoTypeToProtoPackage(p *protobufPackage, t *types.Type, local, global typeNameSet, optional map[types.Name]struct{}) { newT, isProto := isFundamentalProtoType(t) if isProto { t = newT @@ -136,20 +136,23 @@ func assignGoTypeToProtoPackage(p *protobufPackage, t *types.Type, local, global continue } if err := protobufTagToField(tag, field, m, t, p.ProtoTypeName()); err == nil && field.Type != nil { - assignGoTypeToProtoPackage(p, field.Type, local, global) + assignGoTypeToProtoPackage(p, field.Type, local, global, optional) continue } - assignGoTypeToProtoPackage(p, m.Type, local, global) + assignGoTypeToProtoPackage(p, m.Type, local, global, optional) } // TODO: should methods be walked? if t.Elem != nil { - assignGoTypeToProtoPackage(p, t.Elem, local, global) + assignGoTypeToProtoPackage(p, t.Elem, local, global, optional) } if t.Key != nil { - assignGoTypeToProtoPackage(p, t.Key, local, global) + assignGoTypeToProtoPackage(p, t.Key, local, global, optional) } if t.Underlying != nil { - assignGoTypeToProtoPackage(p, t.Underlying, local, global) + if t.Kind == types.Alias && isOptionalAlias(t) { + optional[t.Name] = struct{}{} + } + assignGoTypeToProtoPackage(p, t.Underlying, local, global, optional) } } @@ -157,19 +160,24 @@ func (n *protobufNamer) AssignTypesToPackages(c *generator.Context) error { global := make(typeNameSet) for _, p := range n.packages { local := make(typeNameSet) + optional := make(map[types.Name]struct{}) p.Imports = NewImportTracker(p.ProtoTypeName()) for _, t := range c.Order { if t.Name.Package != p.PackagePath { continue } - assignGoTypeToProtoPackage(p, t, local, global) + assignGoTypeToProtoPackage(p, t, local, global, optional) } p.FilterTypes = make(map[types.Name]struct{}) p.LocalNames = make(map[string]struct{}) + p.OptionalTypeNames = make(map[string]struct{}) for k, v := range local { if v == p { p.FilterTypes[k] = struct{}{} p.LocalNames[k.Name] = struct{}{} + if _, ok := optional[k]; ok { + p.OptionalTypeNames[k.Name] = struct{}{} + } } } } diff --git a/cmd/libs/go2idl/go-to-protobuf/protobuf/package.go b/cmd/libs/go2idl/go-to-protobuf/protobuf/package.go index cab22352211..742487e191d 100644 --- a/cmd/libs/go2idl/go-to-protobuf/protobuf/package.go +++ b/cmd/libs/go2idl/go-to-protobuf/protobuf/package.go @@ -75,6 +75,10 @@ type protobufPackage struct { // A list of names that this package exports LocalNames map[string]struct{} + // A list of type names in this package that will need marshaller rewriting + // to remove synthetic protobuf fields. + OptionalTypeNames map[string]struct{} + // A list of struct tags to generate onto named struct fields StructTags map[string]map[string]string @@ -110,7 +114,9 @@ func (p *protobufPackage) filterFunc(c *generator.Context, t *types.Type) bool { case types.Builtin: return false case types.Alias: - return false + if !isOptionalAlias(t) { + return false + } case types.Slice, types.Array, types.Map: return false case types.Pointer: @@ -128,6 +134,11 @@ func (p *protobufPackage) HasGoType(name string) bool { return ok } +func (p *protobufPackage) OptionalTypeName(name string) bool { + _, ok := p.OptionalTypeNames[name] + return ok +} + func (p *protobufPackage) ExtractGeneratedType(t *ast.TypeSpec) bool { if !p.HasGoType(t.Name.Name) { return false diff --git a/cmd/libs/go2idl/go-to-protobuf/protobuf/parser.go b/cmd/libs/go2idl/go-to-protobuf/protobuf/parser.go index 4ac4b86a37d..4fef53ff93a 100644 --- a/cmd/libs/go2idl/go-to-protobuf/protobuf/parser.go +++ b/cmd/libs/go2idl/go-to-protobuf/protobuf/parser.go @@ -74,10 +74,19 @@ func rewriteFile(name string, header []byte, rewriteFn func(*token.FileSet, *ast // removed from the destination file. type ExtractFunc func(*ast.TypeSpec) bool -func RewriteGeneratedGogoProtobufFile(name string, extractFn ExtractFunc, header []byte) error { +// OptionalFunc returns true if the provided local name is a type that has protobuf.nullable=true +// and should have its marshal functions adjusted to remove the 'Items' accessor. +type OptionalFunc func(name string) bool + +func RewriteGeneratedGogoProtobufFile(name string, extractFn ExtractFunc, optionalFn OptionalFunc, header []byte) error { return rewriteFile(name, header, func(fset *token.FileSet, file *ast.File) error { cmap := ast.NewCommentMap(fset, file, file.Comments) + // transform methods that point to optional maps or slices + for _, d := range file.Decls { + rewriteOptionalMethods(d, optionalFn) + } + // remove types that are already declared decls := []ast.Decl{} for _, d := range file.Decls { @@ -97,6 +106,168 @@ func RewriteGeneratedGogoProtobufFile(name string, extractFn ExtractFunc, header }) } +// rewriteOptionalMethods makes specific mutations to marshaller methods that belong to types identified +// as being "optional" (they may be nil on the wire). This allows protobuf to serialize a map or slice and +// properly discriminate between empty and nil (which is not possible in protobuf). +// TODO: move into upstream gogo-protobuf once https://github.com/gogo/protobuf/issues/181 +// has agreement +func rewriteOptionalMethods(decl ast.Decl, isOptional OptionalFunc) { + switch t := decl.(type) { + case *ast.FuncDecl: + ident, ptr, ok := receiver(t) + if !ok { + return + } + + // correct initialization of the form `m.Field = &OptionalType{}` to + // `m.Field = OptionalType{}` + if t.Name.Name == "Unmarshal" { + ast.Walk(optionalAssignmentVisitor{fn: isOptional}, t.Body) + } + + if !isOptional(ident.Name) { + return + } + + switch t.Name.Name { + case "Unmarshal": + ast.Walk(&optionalItemsVisitor{}, t.Body) + case "MarshalTo", "Size": + ast.Walk(&optionalItemsVisitor{}, t.Body) + fallthrough + case "Marshal": + // if the method has a pointer receiver, set it back to a normal receiver + if ptr { + t.Recv.List[0].Type = ident + } + } + } +} + +type optionalAssignmentVisitor struct { + fn OptionalFunc +} + +// Visit walks the provided node, transforming field initializations of the form +// m.Field = &OptionalType{} -> m.Field = OptionalType{} +func (v optionalAssignmentVisitor) Visit(n ast.Node) ast.Visitor { + switch t := n.(type) { + case *ast.AssignStmt: + if len(t.Lhs) == 1 && len(t.Rhs) == 1 { + if !isFieldSelector(t.Lhs[0], "m", "") { + return nil + } + unary, ok := t.Rhs[0].(*ast.UnaryExpr) + if !ok || unary.Op != token.AND { + return nil + } + composite, ok := unary.X.(*ast.CompositeLit) + if !ok || composite.Type == nil || len(composite.Elts) != 0 { + return nil + } + if ident, ok := composite.Type.(*ast.Ident); ok && v.fn(ident.Name) { + t.Rhs[0] = composite + } + } + return nil + } + return v +} + +type optionalItemsVisitor struct{} + +// Visit walks the provided node, looking for specific patterns to transform that match +// the effective outcome of turning struct{ map[x]y || []x } into map[x]y or []x. +func (v *optionalItemsVisitor) Visit(n ast.Node) ast.Visitor { + switch t := n.(type) { + case *ast.RangeStmt: + if isFieldSelector(t.X, "m", "Items") { + t.X = &ast.Ident{Name: "m"} + } + case *ast.AssignStmt: + if len(t.Lhs) == 1 && len(t.Rhs) == 1 { + switch lhs := t.Lhs[0].(type) { + case *ast.IndexExpr: + if isFieldSelector(lhs.X, "m", "Items") { + lhs.X = &ast.StarExpr{X: &ast.Ident{Name: "m"}} + } + default: + if isFieldSelector(t.Lhs[0], "m", "Items") { + t.Lhs[0] = &ast.StarExpr{X: &ast.Ident{Name: "m"}} + } + } + switch rhs := t.Rhs[0].(type) { + case *ast.CallExpr: + if ident, ok := rhs.Fun.(*ast.Ident); ok && ident.Name == "append" { + ast.Walk(v, rhs) + if len(rhs.Args) > 0 { + switch arg := rhs.Args[0].(type) { + case *ast.Ident: + if arg.Name == "m" { + rhs.Args[0] = &ast.StarExpr{X: &ast.Ident{Name: "m"}} + } + } + } + return nil + } + } + } + case *ast.IfStmt: + if b, ok := t.Cond.(*ast.BinaryExpr); ok && b.Op == token.EQL { + if isFieldSelector(b.X, "m", "Items") && isIdent(b.Y, "nil") { + b.X = &ast.StarExpr{X: &ast.Ident{Name: "m"}} + } + } + case *ast.IndexExpr: + if isFieldSelector(t.X, "m", "Items") { + t.X = &ast.Ident{Name: "m"} + return nil + } + case *ast.CallExpr: + changed := false + for i := range t.Args { + if isFieldSelector(t.Args[i], "m", "Items") { + t.Args[i] = &ast.Ident{Name: "m"} + changed = true + } + } + if changed { + return nil + } + } + return v +} + +func isFieldSelector(n ast.Expr, name, field string) bool { + s, ok := n.(*ast.SelectorExpr) + if !ok || s.Sel == nil || (field != "" && s.Sel.Name != field) { + return false + } + return isIdent(s.X, name) +} + +func isIdent(n ast.Expr, value string) bool { + ident, ok := n.(*ast.Ident) + return ok && ident.Name == value +} + +func receiver(f *ast.FuncDecl) (ident *ast.Ident, pointer bool, ok bool) { + if f.Recv == nil || len(f.Recv.List) != 1 { + return nil, false, false + } + switch t := f.Recv.List[0].Type.(type) { + case *ast.StarExpr: + identity, ok := t.X.(*ast.Ident) + if !ok { + return nil, false, false + } + return identity, true, true + case *ast.Ident: + return t, false, true + } + return nil, false, false +} + // dropExistingTypeDeclarations removes any type declaration for which extractFn returns true. The function // returns true if the entire declaration should be dropped. func dropExistingTypeDeclarations(decl ast.Decl, extractFn ExtractFunc) bool { diff --git a/cmd/libs/go2idl/parser/parse.go b/cmd/libs/go2idl/parser/parse.go index b82b862d922..8957d92d938 100644 --- a/cmd/libs/go2idl/parser/parse.go +++ b/cmd/libs/go2idl/parser/parse.go @@ -576,7 +576,7 @@ func (b *Builder) walkType(u types.Universe, useName *types.Name, in tc.Type) *t return out case *tc.Named: switch t.Underlying().(type) { - case *tc.Named, *tc.Basic: + case *tc.Named, *tc.Basic, *tc.Map, *tc.Slice: name := tcNameToName(t.String()) out := u.Type(name) if out.Kind != types.Unknown { @@ -591,6 +591,9 @@ func (b *Builder) walkType(u types.Universe, useName *types.Name, in tc.Type) *t // "feature" for users. This flattens those types // together. name := tcNameToName(t.String()) + if name.Name == "OptionalMap" { + fmt.Printf("DEBUG: flattening %T -> %T\n", t, t.Underlying()) + } if out := u.Type(name); out.Kind != types.Unknown { return out // short circuit if we've already made this. } diff --git a/cmd/libs/go2idl/parser/parse_test.go b/cmd/libs/go2idl/parser/parse_test.go index 8d232c604fd..83194b9cc14 100644 --- a/cmd/libs/go2idl/parser/parse_test.go +++ b/cmd/libs/go2idl/parser/parse_test.go @@ -391,8 +391,16 @@ type Interface interface{Method(a, b string) (c, d string)} t.Errorf("type %s not found", n) continue } - if e, a := item.k, thisType.Kind; e != a { - t.Errorf("%v-%s: type kind wrong, wanted %v, got %v (%#v)", nameIndex, n, e, a, thisType) + underlyingType := thisType + if item.k != types.Alias && thisType.Kind == types.Alias { + underlyingType = thisType.Underlying + if underlyingType == nil { + t.Errorf("underlying type %s not found", n) + continue + } + } + if e, a := item.k, underlyingType.Kind; e != a { + t.Errorf("%v-%s: type kind wrong, wanted %v, got %v (%#v)", nameIndex, n, e, a, underlyingType) } if e, a := item.names[nameIndex], namer.Name(thisType); e != a { t.Errorf("%v-%s: Expected %q, got %q", nameIndex, n, e, a)