Don't make people have to worry about the Kind field.

This commit is contained in:
Daniel Smith 2014-06-20 16:08:54 -07:00
parent 2bcb44b6bd
commit d7b4915111
2 changed files with 101 additions and 23 deletions

View File

@ -25,16 +25,74 @@ import (
var knownTypes = map[string]reflect.Type{}
func init() {
types := []interface{}{
PodList{}, Pod{}, ReplicationControllerList{},
ReplicationController{}, ServiceList{}, Service{},
}
AddKnownTypes(
PodList{}, Pod{},
ReplicationControllerList{}, ReplicationController{},
ServiceList{}, Service{},
)
}
func AddKnownTypes(types ...interface{}) {
for _, obj := range types {
t := reflect.TypeOf(obj)
knownTypes[t.Name()] = t
}
}
// Encode turns the given api object into an appropriate JSON string.
// Will return an error if the object doesn't have an embedded JSONBase.
// Obj may be a pointer to a struct, or a struct. If a struct, a copy
// will be made so that the object's Kind field can be set. If a pointer,
// we change the Kind field, marshal, and then set the kind field back to
// "". Having to keep track of the kind field makes tests very annoying,
// so the rule is it's set only in wire format (json), not when in native
// format.
func Encode(obj interface{}) (data []byte, err error) {
obj = checkPtr(obj)
fieldToReset, err := prepareEncode(obj)
if err != nil {
return nil, err
}
data, err = json.Marshal(obj)
fieldToReset.SetString("")
return
}
// Just like Encode, but produces indented output.
func EncodeIndent(obj interface{}) (data []byte, err error) {
obj = checkPtr(obj)
fieldToReset, err := prepareEncode(obj)
if err != nil {
return nil, err
}
data, err = json.MarshalIndent(obj, "", " ")
fieldToReset.SetString("")
return
}
func checkPtr(obj interface{}) interface{} {
v := reflect.ValueOf(obj)
if v.Kind() == reflect.Ptr {
return obj
}
v2 := reflect.New(v.Type())
v2.Elem().Set(v)
return v2.Interface()
}
func prepareEncode(obj interface{}) (reflect.Value, error) {
name, jsonBase, err := nameAndJSONBase(obj)
if err != nil {
return reflect.Value{}, err
}
if _, contains := knownTypes[name]; !contains {
return reflect.Value{}, fmt.Errorf("struct %v won't be unmarshalable because it's not in knownTypes", name)
}
kind := jsonBase.FieldByName("Kind")
kind.SetString(name)
return kind, nil
}
// Returns the name of the type (sans pointer), and its kind field. Takes pointer-to-struct..
func nameAndJSONBase(obj interface{}) (string, reflect.Value, error) {
v := reflect.ValueOf(obj)
@ -53,22 +111,6 @@ func nameAndJSONBase(obj interface{}) (string, reflect.Value, error) {
return name, jsonBase, nil
}
// Encode turns the given api object into an appropriate JSON string.
// Will return an error if the object doesn't have an embedded JSONBase.
// Obj must be a pointer to a struct. Note, this sets the object's Kind
// field.
func Encode(obj interface{}) (data []byte, err error) {
name, jsonBase, err := nameAndJSONBase(obj)
if err != nil {
return nil, err
}
if _, contains := knownTypes[name]; !contains {
return nil, fmt.Errorf("struct %v can't be unmarshalled because it's not in knownTypes", name)
}
jsonBase.FieldByName("Kind").SetString(name)
return json.Marshal(obj)
}
// Decode converts a JSON string back into a pointer to an api object. Deduces the type
// based upon the Kind field (set by encode).
func Decode(data []byte) (interface{}, error) {
@ -88,6 +130,12 @@ func Decode(data []byte) (interface{}, error) {
if err != nil {
return nil, err
}
_, jsonBase, err := nameAndJSONBase(obj)
if err != nil {
return nil, err
}
// Don't leave these set. Track type with go's type.
jsonBase.FieldByName("Kind").SetString("")
return obj, nil
}
@ -104,10 +152,10 @@ func DecodeInto(data []byte, obj interface{}) error {
return err
}
foundName := jsonBase.FieldByName("Kind").Interface().(string)
if foundName == "" {
jsonBase.FieldByName("Kind").SetString(name)
} else if foundName != name {
if foundName != "" && foundName != name {
return fmt.Errorf("data had kind %v, but passed object was of type %v", foundName, name)
}
// Don't leave these set. Track type with go's type.
jsonBase.FieldByName("Kind").SetString("")
return nil
}

View File

@ -71,4 +71,34 @@ func TestTypes(t *testing.T) {
}
}
func TestNonPtr(t *testing.T) {
obj := interface{}(Pod{Labels: map[string]string{"name": "foo"}})
data, err := Encode(obj)
obj2, err2 := Decode(data)
if err != nil || err2 != nil {
t.Errorf("Failure: %v %v", err2, err2)
}
if _, ok := obj2.(*Pod); !ok {
t.Errorf("Got wrong type")
}
if !reflect.DeepEqual(obj2, &Pod{Labels: map[string]string{"name": "foo"}}) {
t.Errorf("Something changed: %#v", obj2)
}
}
func TestPtr(t *testing.T) {
obj := interface{}(&Pod{Labels: map[string]string{"name": "foo"}})
data, err := Encode(obj)
obj2, err2 := Decode(data)
if err != nil || err2 != nil {
t.Errorf("Failure: %v %v", err2, err2)
}
if _, ok := obj2.(*Pod); !ok {
t.Errorf("Got wrong type")
}
if !reflect.DeepEqual(obj2, &Pod{Labels: map[string]string{"name": "foo"}}) {
t.Errorf("Something changed: %#v", obj2)
}
}
// TODO: test rejection of bad JSON.