mirror of
				https://github.com/k3s-io/kubernetes.git
				synced 2025-10-31 05:40:42 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			453 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			453 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| /*
 | |
| Copyright 2015 The Kubernetes Authors.
 | |
| 
 | |
| Licensed under the Apache License, Version 2.0 (the "License");
 | |
| you may not use this file except in compliance with the License.
 | |
| You may obtain a copy of the License at
 | |
| 
 | |
|     http://www.apache.org/licenses/LICENSE-2.0
 | |
| 
 | |
| Unless required by applicable law or agreed to in writing, software
 | |
| distributed under the License is distributed on an "AS IS" BASIS,
 | |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| See the License for the specific language governing permissions and
 | |
| limitations under the License.
 | |
| */
 | |
| 
 | |
| package protobuf
 | |
| 
 | |
| import (
 | |
| 	"bytes"
 | |
| 	"errors"
 | |
| 	"fmt"
 | |
| 	"go/ast"
 | |
| 	"go/format"
 | |
| 	"go/parser"
 | |
| 	"go/printer"
 | |
| 	"go/token"
 | |
| 	"io/ioutil"
 | |
| 	"os"
 | |
| 	"reflect"
 | |
| 	"strings"
 | |
| 
 | |
| 	customreflect "k8s.io/apimachinery/third_party/forked/golang/reflect"
 | |
| )
 | |
| 
 | |
// rewriteFile parses the Go source file at name (keeping its comments), lets rewriteFn mutate
// the parsed AST, prints the result prefixed by header, runs it through go/format, and writes
// the formatted bytes back over name.
//
// The file is only opened for writing after reading, parsing, rewriting, printing and
// formatting have all succeeded. Any error from those steps is returned unchanged and the file
// is left as it was.
func rewriteFile(name string, header []byte, rewriteFn func(*token.FileSet, *ast.File) error) error {
	original, err := ioutil.ReadFile(name)
	if err != nil {
		return err
	}

	fileSet := token.NewFileSet()
	astFile, err := parser.ParseFile(fileSet, name, original, parser.DeclarationErrors|parser.ParseComments)
	if err != nil {
		return err
	}

	if err = rewriteFn(fileSet, astFile); err != nil {
		return err
	}

	var out bytes.Buffer
	out.Write(header)
	if err = printer.Fprint(&out, fileSet, astFile); err != nil {
		return err
	}

	formatted, err := format.Source(out.Bytes())
	if err != nil {
		return err
	}

	dst, err := os.OpenFile(name, os.O_WRONLY|os.O_TRUNC, 0644)
	if err != nil {
		return err
	}
	// The deferred Close only matters on the early-return path; the explicit Close below is
	// the one whose error is reported.
	defer dst.Close()
	if _, err = dst.Write(formatted); err != nil {
		return err
	}
	return dst.Close()
}
 | |
| 
 | |
type (
	// ExtractFunc inspects a type declaration from the generated file, optionally recording
	// information about it, and reports whether the type should be removed from the
	// destination file (for example because it is already declared elsewhere).
	ExtractFunc func(spec *ast.TypeSpec) bool

	// OptionalFunc reports whether the type with the given local name is marked
	// protobuf.nullable=true, in which case its marshal functions are adjusted to remove
	// the 'Items' accessor.
	OptionalFunc func(name string) bool
)
 | |
| 
 | |
| func RewriteGeneratedGogoProtobufFile(name string, extractFn ExtractFunc, optionalFn OptionalFunc, header []byte) error {
 | |
| 	return rewriteFile(name, header, func(fset *token.FileSet, file *ast.File) error {
 | |
| 		cmap := ast.NewCommentMap(fset, file, file.Comments)
 | |
| 
 | |
| 		// transform methods that point to optional maps or slices
 | |
| 		for _, d := range file.Decls {
 | |
| 			rewriteOptionalMethods(d, optionalFn)
 | |
| 		}
 | |
| 
 | |
| 		// remove types that are already declared
 | |
| 		decls := []ast.Decl{}
 | |
| 		for _, d := range file.Decls {
 | |
| 			if dropExistingTypeDeclarations(d, extractFn) {
 | |
| 				continue
 | |
| 			}
 | |
| 			if dropEmptyImportDeclarations(d) {
 | |
| 				continue
 | |
| 			}
 | |
| 			decls = append(decls, d)
 | |
| 		}
 | |
| 		file.Decls = decls
 | |
| 
 | |
| 		// remove unmapped comments
 | |
| 		file.Comments = cmap.Filter(file).Comments()
 | |
| 		return nil
 | |
| 	})
 | |
| }
 | |
| 
 | |
| // rewriteOptionalMethods makes specific mutations to marshaller methods that belong to types identified
 | |
| // as being "optional" (they may be nil on the wire). This allows protobuf to serialize a map or slice and
 | |
| // properly discriminate between empty and nil (which is not possible in protobuf).
 | |
| // TODO: move into upstream gogo-protobuf once https://github.com/gogo/protobuf/issues/181
 | |
| //   has agreement
 | |
| func rewriteOptionalMethods(decl ast.Decl, isOptional OptionalFunc) {
 | |
| 	switch t := decl.(type) {
 | |
| 	case *ast.FuncDecl:
 | |
| 		ident, ptr, ok := receiver(t)
 | |
| 		if !ok {
 | |
| 			return
 | |
| 		}
 | |
| 
 | |
| 		// correct initialization of the form `m.Field = &OptionalType{}` to
 | |
| 		// `m.Field = OptionalType{}`
 | |
| 		if t.Name.Name == "Unmarshal" {
 | |
| 			ast.Walk(optionalAssignmentVisitor{fn: isOptional}, t.Body)
 | |
| 		}
 | |
| 
 | |
| 		if !isOptional(ident.Name) {
 | |
| 			return
 | |
| 		}
 | |
| 
 | |
| 		switch t.Name.Name {
 | |
| 		case "Unmarshal":
 | |
| 			ast.Walk(&optionalItemsVisitor{}, t.Body)
 | |
| 		case "MarshalTo", "Size", "String":
 | |
| 			ast.Walk(&optionalItemsVisitor{}, t.Body)
 | |
| 			fallthrough
 | |
| 		case "Marshal":
 | |
| 			// if the method has a pointer receiver, set it back to a normal receiver
 | |
| 			if ptr {
 | |
| 				t.Recv.List[0].Type = ident
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| type optionalAssignmentVisitor struct {
 | |
| 	fn OptionalFunc
 | |
| }
 | |
| 
 | |
| // Visit walks the provided node, transforming field initializations of the form
 | |
| // m.Field = &OptionalType{} -> m.Field = OptionalType{}
 | |
| func (v optionalAssignmentVisitor) Visit(n ast.Node) ast.Visitor {
 | |
| 	switch t := n.(type) {
 | |
| 	case *ast.AssignStmt:
 | |
| 		if len(t.Lhs) == 1 && len(t.Rhs) == 1 {
 | |
| 			if !isFieldSelector(t.Lhs[0], "m", "") {
 | |
| 				return nil
 | |
| 			}
 | |
| 			unary, ok := t.Rhs[0].(*ast.UnaryExpr)
 | |
| 			if !ok || unary.Op != token.AND {
 | |
| 				return nil
 | |
| 			}
 | |
| 			composite, ok := unary.X.(*ast.CompositeLit)
 | |
| 			if !ok || composite.Type == nil || len(composite.Elts) != 0 {
 | |
| 				return nil
 | |
| 			}
 | |
| 			if ident, ok := composite.Type.(*ast.Ident); ok && v.fn(ident.Name) {
 | |
| 				t.Rhs[0] = composite
 | |
| 			}
 | |
| 		}
 | |
| 		return nil
 | |
| 	}
 | |
| 	return v
 | |
| }
 | |
| 
 | |
// optionalItemsVisitor is an ast.Visitor that rewrites accesses to the Items field of the
// receiver m so the method body operates on the receiver value itself. It holds no state.
type optionalItemsVisitor struct{}

// Visit walks the provided node, looking for specific patterns to transform that match
// the effective outcome of turning struct{ map[x]y || []x } into map[x]y or []x.
//
// Every pattern is keyed on the receiver being named "m". The rewrites are:
//
//	range m.Items                        -> range m
//	m.Items[k] = v                       -> (*m)[k] = v
//	m.Items = v                          -> *m = v
//	... = append(m.Items, ...)           -> ... = append(*m, ...)
//	if m.Items == nil                    -> if *m == nil
//	if err := m.F[len(m.Items)-1].X(...) -> if err := (*m)[len(*m)-1].X(...)
//	m.Items[i] (any other index expr)    -> m[i]
//	f(..., m.Items, ...)                 -> f(..., m, ...)
//
// Returning nil tells ast.Walk not to descend into n's children; returning v continues the walk.
func (v *optionalItemsVisitor) Visit(n ast.Node) ast.Visitor {
	switch t := n.(type) {
	case *ast.RangeStmt:
		// range m.Items -> range m
		if isFieldSelector(t.X, "m", "Items") {
			t.X = &ast.Ident{Name: "m"}
		}
	case *ast.AssignStmt:
		if len(t.Lhs) == 1 && len(t.Rhs) == 1 {
			switch lhs := t.Lhs[0].(type) {
			case *ast.IndexExpr:
				// m.Items[k] = ... -> (*m)[k] = ...
				if isFieldSelector(lhs.X, "m", "Items") {
					lhs.X = &ast.StarExpr{X: &ast.Ident{Name: "m"}}
				}
			default:
				// m.Items = ... -> *m = ...
				if isFieldSelector(t.Lhs[0], "m", "Items") {
					t.Lhs[0] = &ast.StarExpr{X: &ast.Ident{Name: "m"}}
				}
			}
			switch rhs := t.Rhs[0].(type) {
			case *ast.CallExpr:
				if ident, ok := rhs.Fun.(*ast.Ident); ok && ident.Name == "append" {
					// Walk the call first. The *ast.CallExpr case below turns m.Items
					// arguments into m. Then dereference a leading m so that the
					// append target is the receiver's value: append(*m, ...).
					ast.Walk(v, rhs)
					if len(rhs.Args) > 0 {
						switch arg := rhs.Args[0].(type) {
						case *ast.Ident:
							if arg.Name == "m" {
								rhs.Args[0] = &ast.StarExpr{X: &ast.Ident{Name: "m"}}
							}
						}
					}
					// The call has already been walked above; don't walk it again.
					return nil
				}
			}
		}
	case *ast.IfStmt:
		// if m.Items == nil -> if *m == nil
		switch cond := t.Cond.(type) {
		case *ast.BinaryExpr:
			if cond.Op == token.EQL {
				if isFieldSelector(cond.X, "m", "Items") && isIdent(cond.Y, "nil") {
					cond.X = &ast.StarExpr{X: &ast.Ident{Name: "m"}}
				}
			}
		}
		if t.Init != nil {
			// Find form:
			// if err := m.Items[len(m.Items)-1].Unmarshal(data[iNdEx:postIndex]); err != nil {
			// 	return err
			// }
			switch s := t.Init.(type) {
			case *ast.AssignStmt:
				if call, ok := s.Rhs[0].(*ast.CallExpr); ok {
					if sel, ok := call.Fun.(*ast.SelectorExpr); ok {
						if x, ok := sel.X.(*ast.IndexExpr); ok {
							// m.<field>[] -> (*m)[]; the selected field name is not checked here
							if sel2, ok := x.X.(*ast.SelectorExpr); ok {
								if ident, ok := sel2.X.(*ast.Ident); ok && ident.Name == "m" {
									x.X = &ast.StarExpr{X: &ast.Ident{Name: "m"}}
								}
							}
							// len(m.Items) -> len(*m); any single-argument call on the left of
							// the index's binary expression is matched, not only len
							if bin, ok := x.Index.(*ast.BinaryExpr); ok {
								if call2, ok := bin.X.(*ast.CallExpr); ok && len(call2.Args) == 1 {
									if isFieldSelector(call2.Args[0], "m", "Items") {
										call2.Args[0] = &ast.StarExpr{X: &ast.Ident{Name: "m"}}
									}
								}
							}
						}
					}
				}
			}
		}
		// Fall through to return v so the walk continues into the init, condition and bodies.
	case *ast.IndexExpr:
		// m.Items[i] anywhere else -> m[i]; the rewritten node's children are not visited
		if isFieldSelector(t.X, "m", "Items") {
			t.X = &ast.Ident{Name: "m"}
			return nil
		}
	case *ast.CallExpr:
		// m.Items passed as an argument -> m; if anything was replaced, the call's
		// children are not visited
		changed := false
		for i := range t.Args {
			if isFieldSelector(t.Args[i], "m", "Items") {
				t.Args[i] = &ast.Ident{Name: "m"}
				changed = true
			}
		}
		if changed {
			return nil
		}
	}
	return v
}
 | |
| 
 | |
// isFieldSelector reports whether n is a selector expression name.field whose operand is the
// bare identifier name. An empty field matches any selected name.
func isFieldSelector(n ast.Expr, name, field string) bool {
	sel, ok := n.(*ast.SelectorExpr)
	if !ok || sel.Sel == nil {
		return false
	}
	if field != "" && sel.Sel.Name != field {
		return false
	}
	base, ok := sel.X.(*ast.Ident)
	return ok && base.Name == name
}
 | |
| 
 | |
// isIdent reports whether n is an identifier spelled exactly value.
func isIdent(n ast.Expr, value string) bool {
	if id, ok := n.(*ast.Ident); ok {
		return id.Name == value
	}
	return false
}
 | |
| 
 | |
// receiver returns the identifier naming f's receiver base type and whether the receiver is a
// pointer. ok is false when f has no receiver, more than one receiver field, or a receiver type
// that is neither an identifier T nor a pointer *T (e.g. *pkg.T or a generic T[K]).
func receiver(f *ast.FuncDecl) (ident *ast.Ident, pointer bool, ok bool) {
	if f.Recv == nil || len(f.Recv.List) != 1 {
		return nil, false, false
	}

	recvType := f.Recv.List[0].Type
	isPointer := false
	if star, isStar := recvType.(*ast.StarExpr); isStar {
		recvType, isPointer = star.X, true
	}

	named, isNamed := recvType.(*ast.Ident)
	if !isNamed {
		return nil, false, false
	}
	return named, isPointer, true
}
 | |
| 
 | |
| // dropExistingTypeDeclarations removes any type declaration for which extractFn returns true. The function
 | |
| // returns true if the entire declaration should be dropped.
 | |
| func dropExistingTypeDeclarations(decl ast.Decl, extractFn ExtractFunc) bool {
 | |
| 	switch t := decl.(type) {
 | |
| 	case *ast.GenDecl:
 | |
| 		if t.Tok != token.TYPE {
 | |
| 			return false
 | |
| 		}
 | |
| 		specs := []ast.Spec{}
 | |
| 		for _, s := range t.Specs {
 | |
| 			switch spec := s.(type) {
 | |
| 			case *ast.TypeSpec:
 | |
| 				if extractFn(spec) {
 | |
| 					continue
 | |
| 				}
 | |
| 				specs = append(specs, spec)
 | |
| 			}
 | |
| 		}
 | |
| 		if len(specs) == 0 {
 | |
| 			return true
 | |
| 		}
 | |
| 		t.Specs = specs
 | |
| 	}
 | |
| 	return false
 | |
| }
 | |
| 
 | |
// dropEmptyImportDeclarations strips blank ("_") imports from generated code so that generation
// cannot introduce import side effects.
//
// It returns true when decl is an import declaration with no remaining specs, meaning the whole
// declaration should be dropped; decl is then left unmodified. For any other import declaration
// the specs are pruned in place and false is returned. Non-import declarations yield false.
func dropEmptyImportDeclarations(decl ast.Decl) bool {
	gen, isGen := decl.(*ast.GenDecl)
	if !isGen || gen.Tok != token.IMPORT {
		return false
	}

	remaining := make([]ast.Spec, 0, len(gen.Specs))
	for _, spec := range gen.Specs {
		imp, isImport := spec.(*ast.ImportSpec)
		if !isImport {
			continue
		}
		if imp.Name != nil && imp.Name.Name == "_" {
			continue
		}
		remaining = append(remaining, imp)
	}

	if len(remaining) == 0 {
		return true
	}
	gen.Specs = remaining
	return false
}
 | |
| 
 | |
| func RewriteTypesWithProtobufStructTags(name string, structTags map[string]map[string]string) error {
 | |
| 	return rewriteFile(name, []byte{}, func(fset *token.FileSet, file *ast.File) error {
 | |
| 		allErrs := []error{}
 | |
| 
 | |
| 		// set any new struct tags
 | |
| 		for _, d := range file.Decls {
 | |
| 			if errs := updateStructTags(d, structTags, []string{"protobuf"}); len(errs) > 0 {
 | |
| 				allErrs = append(allErrs, errs...)
 | |
| 			}
 | |
| 		}
 | |
| 
 | |
| 		if len(allErrs) > 0 {
 | |
| 			var s string
 | |
| 			for _, err := range allErrs {
 | |
| 				s += err.Error() + "\n"
 | |
| 			}
 | |
| 			return errors.New(s)
 | |
| 		}
 | |
| 		return nil
 | |
| 	})
 | |
| }
 | |
| 
 | |
| func updateStructTags(decl ast.Decl, structTags map[string]map[string]string, toCopy []string) []error {
 | |
| 	var errs []error
 | |
| 	t, ok := decl.(*ast.GenDecl)
 | |
| 	if !ok {
 | |
| 		return nil
 | |
| 	}
 | |
| 	if t.Tok != token.TYPE {
 | |
| 		return nil
 | |
| 	}
 | |
| 
 | |
| 	for _, s := range t.Specs {
 | |
| 		spec, ok := s.(*ast.TypeSpec)
 | |
| 		if !ok {
 | |
| 			continue
 | |
| 		}
 | |
| 		typeName := spec.Name.Name
 | |
| 		fieldTags, ok := structTags[typeName]
 | |
| 		if !ok {
 | |
| 			continue
 | |
| 		}
 | |
| 		st, ok := spec.Type.(*ast.StructType)
 | |
| 		if !ok {
 | |
| 			continue
 | |
| 		}
 | |
| 
 | |
| 		for i := range st.Fields.List {
 | |
| 			f := st.Fields.List[i]
 | |
| 			var name string
 | |
| 			if len(f.Names) == 0 {
 | |
| 				switch t := f.Type.(type) {
 | |
| 				case *ast.Ident:
 | |
| 					name = t.Name
 | |
| 				case *ast.SelectorExpr:
 | |
| 					name = t.Sel.Name
 | |
| 				default:
 | |
| 					errs = append(errs, fmt.Errorf("unable to get name for tag from struct %q, field %#v", spec.Name.Name, t))
 | |
| 					continue
 | |
| 				}
 | |
| 			} else {
 | |
| 				name = f.Names[0].Name
 | |
| 			}
 | |
| 			value, ok := fieldTags[name]
 | |
| 			if !ok {
 | |
| 				continue
 | |
| 			}
 | |
| 			var tags customreflect.StructTags
 | |
| 			if f.Tag != nil {
 | |
| 				oldTags, err := customreflect.ParseStructTags(strings.Trim(f.Tag.Value, "`"))
 | |
| 				if err != nil {
 | |
| 					errs = append(errs, fmt.Errorf("unable to read struct tag from struct %q, field %q: %v", spec.Name.Name, name, err))
 | |
| 					continue
 | |
| 				}
 | |
| 				tags = oldTags
 | |
| 			}
 | |
| 			for _, name := range toCopy {
 | |
| 				// don't overwrite existing tags
 | |
| 				if tags.Has(name) {
 | |
| 					continue
 | |
| 				}
 | |
| 				// append new tags
 | |
| 				if v := reflect.StructTag(value).Get(name); len(v) > 0 {
 | |
| 					tags = append(tags, customreflect.StructTag{Name: name, Value: v})
 | |
| 				}
 | |
| 			}
 | |
| 			if len(tags) == 0 {
 | |
| 				continue
 | |
| 			}
 | |
| 			if f.Tag == nil {
 | |
| 				f.Tag = &ast.BasicLit{}
 | |
| 			}
 | |
| 			f.Tag.Value = tags.String()
 | |
| 		}
 | |
| 	}
 | |
| 	return errs
 | |
| }
 |