diff --git a/pkg/api/helper.go b/pkg/api/helper.go index 515778d7be2..fdb86d351f2 100644 --- a/pkg/api/helper.go +++ b/pkg/api/helper.go @@ -25,16 +25,74 @@ import ( var knownTypes = map[string]reflect.Type{} func init() { - types := []interface{}{ - PodList{}, Pod{}, ReplicationControllerList{}, - ReplicationController{}, ServiceList{}, Service{}, - } + AddKnownTypes( + PodList{}, Pod{}, + ReplicationControllerList{}, ReplicationController{}, + ServiceList{}, Service{}, + ) +} + +func AddKnownTypes(types ...interface{}) { for _, obj := range types { t := reflect.TypeOf(obj) knownTypes[t.Name()] = t } } +// Encode turns the given api object into an appropriate JSON string. +// Will return an error if the object doesn't have an embedded JSONBase. +// Obj may be a pointer to a struct, or a struct. If a struct, a copy +// will be made so that the object's Kind field can be set. If a pointer, +// we change the Kind field, marshal, and then set the kind field back to +// "". Having to keep track of the kind field makes tests very annoying, +// so the rule is it's set only in wire format (json), not when in native +// format. +func Encode(obj interface{}) (data []byte, err error) { + obj = checkPtr(obj) + fieldToReset, err := prepareEncode(obj) + if err != nil { + return nil, err + } + data, err = json.Marshal(obj) + fieldToReset.SetString("") + return +} + +// Just like Encode, but produces indented output. +func EncodeIndent(obj interface{}) (data []byte, err error) { + obj = checkPtr(obj) + fieldToReset, err := prepareEncode(obj) + if err != nil { + return nil, err + } + data, err = json.MarshalIndent(obj, "", " ") + fieldToReset.SetString("") + return +} + +func checkPtr(obj interface{}) interface{} { + v := reflect.ValueOf(obj) + if v.Kind() == reflect.Ptr { + return obj + } + v2 := reflect.New(v.Type()) + v2.Elem().Set(v) + return v2.Interface() +} + +func prepareEncode(obj interface{}) (reflect.Value, error) { + name, jsonBase, err := nameAndJSONBase(obj) + if err != nil { + return reflect.Value{}, err + } + if _, contains := knownTypes[name]; !contains { + return reflect.Value{}, fmt.Errorf("struct %v won't be unmarshalable because it's not in knownTypes", name) + } + kind := jsonBase.FieldByName("Kind") + kind.SetString(name) + return kind, nil +} + // Returns the name of the type (sans pointer), and its kind field. Takes pointer-to-struct.. func nameAndJSONBase(obj interface{}) (string, reflect.Value, error) { v := reflect.ValueOf(obj) @@ -53,22 +111,6 @@ func nameAndJSONBase(obj interface{}) (string, reflect.Value, error) { return name, jsonBase, nil } -// Encode turns the given api object into an appropriate JSON string. -// Will return an error if the object doesn't have an embedded JSONBase. -// Obj must be a pointer to a struct. Note, this sets the object's Kind -// field. -func Encode(obj interface{}) (data []byte, err error) { - name, jsonBase, err := nameAndJSONBase(obj) - if err != nil { - return nil, err - } - if _, contains := knownTypes[name]; !contains { - return nil, fmt.Errorf("struct %v can't be unmarshalled because it's not in knownTypes", name) - } - jsonBase.FieldByName("Kind").SetString(name) - return json.Marshal(obj) -} - // Decode converts a JSON string back into a pointer to an api object. Deduces the type // based upon the Kind field (set by encode). func Decode(data []byte) (interface{}, error) { @@ -88,6 +130,12 @@ func Decode(data []byte) (interface{}, error) { if err != nil { return nil, err } + _, jsonBase, err := nameAndJSONBase(obj) + if err != nil { + return nil, err + } + // Don't leave these set. Track type with go's type. + jsonBase.FieldByName("Kind").SetString("") return obj, nil } @@ -104,10 +152,10 @@ func DecodeInto(data []byte, obj interface{}) error { return err } foundName := jsonBase.FieldByName("Kind").Interface().(string) - if foundName == "" { - jsonBase.FieldByName("Kind").SetString(name) - } else if foundName != name { + if foundName != "" && foundName != name { return fmt.Errorf("data had kind %v, but passed object was of type %v", foundName, name) } + // Don't leave these set. Track type with go's type. + jsonBase.FieldByName("Kind").SetString("") return nil } diff --git a/pkg/api/helper_test.go b/pkg/api/helper_test.go index 93ed887c6cc..d24083161d8 100644 --- a/pkg/api/helper_test.go +++ b/pkg/api/helper_test.go @@ -71,4 +71,34 @@ func TestTypes(t *testing.T) { } } +func TestNonPtr(t *testing.T) { + obj := interface{}(Pod{Labels: map[string]string{"name": "foo"}}) + data, err := Encode(obj) + obj2, err2 := Decode(data) + if err != nil || err2 != nil { + t.Errorf("Failure: %v %v", err2, err2) + } + if _, ok := obj2.(*Pod); !ok { + t.Errorf("Got wrong type") + } + if !reflect.DeepEqual(obj2, &Pod{Labels: map[string]string{"name": "foo"}}) { + t.Errorf("Something changed: %#v", obj2) + } +} + +func TestPtr(t *testing.T) { + obj := interface{}(&Pod{Labels: map[string]string{"name": "foo"}}) + data, err := Encode(obj) + obj2, err2 := Decode(data) + if err != nil || err2 != nil { + t.Errorf("Failure: %v %v", err2, err2) + } + if _, ok := obj2.(*Pod); !ok { + t.Errorf("Got wrong type") + } + if !reflect.DeepEqual(obj2, &Pod{Labels: map[string]string{"name": "foo"}}) { + t.Errorf("Something changed: %#v", obj2) + } +} + // TODO: test rejection of bad JSON.