Use conversion.EnforcePtr() where appropriate

Signed-off-by: Vojtech Vitek (V-Teq) <vvitek@redhat.com>
This commit is contained in:
Vojtech Vitek (V-Teq) 2014-10-28 09:02:29 +01:00
parent 8969f0df7e
commit 90809c270d
5 changed files with 41 additions and 33 deletions

View File

@ -224,11 +224,10 @@ func fieldPtr(v reflect.Value, fieldName string, dest interface{}) error {
if !field.IsValid() { if !field.IsValid() {
return fmt.Errorf("Couldn't find %v field in %#v", fieldName, v.Interface()) return fmt.Errorf("Couldn't find %v field in %#v", fieldName, v.Interface())
} }
v = reflect.ValueOf(dest) v, err := conversion.EnforcePtr(dest)
if v.Kind() != reflect.Ptr { if err != nil {
return fmt.Errorf("dest should be ptr") return err
} }
v = v.Elem()
field = field.Addr() field = field.Addr()
if field.Type().AssignableTo(v.Type()) { if field.Type().AssignableTo(v.Type()) {
v.Set(field) v.Set(field)

View File

@ -213,18 +213,17 @@ func (f FieldMatchingFlags) IsSet(flag FieldMatchingFlags) bool {
// it is not used by Convert() other than storing it in the scope. // it is not used by Convert() other than storing it in the scope.
// Not safe for objects with cyclic references! // Not safe for objects with cyclic references!
func (c *Converter) Convert(src, dest interface{}, flags FieldMatchingFlags, meta *Meta) error { func (c *Converter) Convert(src, dest interface{}, flags FieldMatchingFlags, meta *Meta) error {
dv, sv := reflect.ValueOf(dest), reflect.ValueOf(src) dv, err := EnforcePtr(dest)
if dv.Kind() != reflect.Ptr { if err != nil {
return fmt.Errorf("Need pointer, but got %#v", dest) return err
} }
if sv.Kind() != reflect.Ptr {
return fmt.Errorf("Need pointer, but got %#v", src)
}
dv = dv.Elem()
sv = sv.Elem()
if !dv.CanAddr() { if !dv.CanAddr() {
return fmt.Errorf("Can't write to dest") return fmt.Errorf("Can't write to dest")
} }
sv, err := EnforcePtr(src)
if err != nil {
return err
}
s := &scope{ s := &scope{
converter: c, converter: c,
flags: flags, flags: flags,

View File

@ -130,7 +130,7 @@ func TestGetFloat(t *testing.T) {
for _, test := range tests { for _, test := range tests {
val := GetFloatResource(test.res, test.name, test.def) val := GetFloatResource(test.res, test.name, test.def)
if val != test.expected { if val != test.expected {
t.Errorf("%expected: %d found %d", test.expected, val) t.Errorf("expected: %d found %d", test.expected, val)
} }
} }
} }

View File

@ -19,6 +19,8 @@ package runtime
import ( import (
"fmt" "fmt"
"reflect" "reflect"
"github.com/GoogleCloudPlatform/kubernetes/pkg/conversion"
) )
// GetItemsPtr returns a pointer to the list object's Items member. // GetItemsPtr returns a pointer to the list object's Items member.
@ -26,11 +28,11 @@ import (
// and an error will be returned. // and an error will be returned.
// This function will either return a pointer to a slice, or an error, but not both. // This function will either return a pointer to a slice, or an error, but not both.
func GetItemsPtr(list Object) (interface{}, error) { func GetItemsPtr(list Object) (interface{}, error) {
v := reflect.ValueOf(list) v, err := conversion.EnforcePtr(list)
if !v.IsValid() { if err != nil {
return nil, fmt.Errorf("nil list object") return nil, err
} }
items := v.Elem().FieldByName("Items") items := v.FieldByName("Items")
if !items.IsValid() { if !items.IsValid() {
return nil, fmt.Errorf("no Items field in %#v", list) return nil, fmt.Errorf("no Items field in %#v", list)
} }
@ -47,13 +49,16 @@ func ExtractList(obj Object) ([]Object, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
items := reflect.ValueOf(itemsPtr).Elem() items, err := conversion.EnforcePtr(itemsPtr)
if err != nil {
return nil, err
}
list := make([]Object, items.Len()) list := make([]Object, items.Len())
for i := range list { for i := range list {
raw := items.Index(i) raw := items.Index(i)
item, ok := raw.Addr().Interface().(Object) item, ok := raw.Addr().Interface().(Object)
if !ok { if !ok {
return nil, fmt.Errorf("item in index %v isn't an object: %#v", i, raw.Interface()) return nil, fmt.Errorf("item[%v]: Expected object, got %#v", i, raw.Interface())
} }
list[i] = item list[i] = item
} }
@ -69,21 +74,23 @@ func SetList(list Object, objects []Object) error {
if err != nil { if err != nil {
return err return err
} }
items := reflect.ValueOf(itemsPtr).Elem() items, err := conversion.EnforcePtr(itemsPtr)
if err != nil {
return err
}
slice := reflect.MakeSlice(items.Type(), len(objects), len(objects)) slice := reflect.MakeSlice(items.Type(), len(objects), len(objects))
for i := range objects { for i := range objects {
dest := slice.Index(i) dest := slice.Index(i)
src := reflect.ValueOf(objects[i]) src, err := conversion.EnforcePtr(objects[i])
if !src.IsValid() || src.IsNil() { if err != nil {
return fmt.Errorf("an object was nil") return err
} }
src = src.Elem() // Object is a pointer, but the items in slice are not.
if src.Type().AssignableTo(dest.Type()) { if src.Type().AssignableTo(dest.Type()) {
dest.Set(src) dest.Set(src)
} else if src.Type().ConvertibleTo(dest.Type()) { } else if src.Type().ConvertibleTo(dest.Type()) {
dest.Set(src.Convert(dest.Type())) dest.Set(src.Convert(dest.Type()))
} else { } else {
return fmt.Errorf("wrong type: need %v, got %v", dest.Type(), src.Type()) return fmt.Errorf("item[%v]: Type mismatch: Expected %v, got %v", dest.Type(), src.Type())
} }
} }
items.Set(slice) items.Set(slice)

View File

@ -22,6 +22,7 @@ import (
"reflect" "reflect"
"strconv" "strconv"
"github.com/GoogleCloudPlatform/kubernetes/pkg/conversion"
"github.com/GoogleCloudPlatform/kubernetes/pkg/runtime" "github.com/GoogleCloudPlatform/kubernetes/pkg/runtime"
"github.com/coreos/go-etcd/etcd" "github.com/coreos/go-etcd/etcd"
) )
@ -169,12 +170,11 @@ func (h *EtcdHelper) ExtractList(key string, slicePtr interface{}, resourceVersi
// decodeNodeList walks the tree of each node in the list and decodes into the specified object // decodeNodeList walks the tree of each node in the list and decodes into the specified object
func (h *EtcdHelper) decodeNodeList(nodes []*etcd.Node, slicePtr interface{}) error { func (h *EtcdHelper) decodeNodeList(nodes []*etcd.Node, slicePtr interface{}) error {
pv := reflect.ValueOf(slicePtr) v, err := conversion.EnforcePtr(slicePtr)
if pv.Type().Kind() != reflect.Ptr || pv.Type().Elem().Kind() != reflect.Slice { if err != nil || v.Kind() != reflect.Slice {
// This should not happen at runtime. // This should not happen at runtime.
panic("need ptr to slice") panic("need ptr to slice")
} }
v := pv.Elem()
for _, node := range nodes { for _, node := range nodes {
if node.Dir { if node.Dir {
if err := h.decodeNodeList(node.Nodes, slicePtr); err != nil { if err := h.decodeNodeList(node.Nodes, slicePtr); err != nil {
@ -230,8 +230,11 @@ func (h *EtcdHelper) bodyAndExtractObj(key string, objPtr runtime.Object, ignore
} }
if err != nil || response.Node == nil || len(response.Node.Value) == 0 { if err != nil || response.Node == nil || len(response.Node.Value) == 0 {
if ignoreNotFound { if ignoreNotFound {
pv := reflect.ValueOf(objPtr) v, err := conversion.EnforcePtr(objPtr)
pv.Elem().Set(reflect.Zero(pv.Type().Elem())) if err != nil {
return "", 0, err
}
v.Set(reflect.Zero(v.Type()))
return "", 0, nil return "", 0, nil
} else if err != nil { } else if err != nil {
return "", 0, err return "", 0, err
@ -313,13 +316,13 @@ type EtcdUpdateFunc func(input runtime.Object) (output runtime.Object, err error
// }) // })
// //
func (h *EtcdHelper) AtomicUpdate(key string, ptrToType runtime.Object, tryUpdate EtcdUpdateFunc) error { func (h *EtcdHelper) AtomicUpdate(key string, ptrToType runtime.Object, tryUpdate EtcdUpdateFunc) error {
pt := reflect.TypeOf(ptrToType) v, err := conversion.EnforcePtr(ptrToType)
if pt.Kind() != reflect.Ptr { if err != nil {
// Panic is appropriate, because this is a programming error. // Panic is appropriate, because this is a programming error.
panic("need ptr to type") panic("need ptr to type")
} }
for { for {
obj := reflect.New(pt.Elem()).Interface().(runtime.Object) obj := reflect.New(v.Type()).Interface().(runtime.Object)
origBody, index, err := h.bodyAndExtractObj(key, obj, true) origBody, index, err := h.bodyAndExtractObj(key, obj, true)
if err != nil { if err != nil {
return err return err