diff --git a/staging/src/k8s.io/apimachinery/pkg/api/resource/quantity.go b/staging/src/k8s.io/apimachinery/pkg/api/resource/quantity.go index 69f1bc336d3..50af8334f01 100644 --- a/staging/src/k8s.io/apimachinery/pkg/api/resource/quantity.go +++ b/staging/src/k8s.io/apimachinery/pkg/api/resource/quantity.go @@ -25,6 +25,8 @@ import ( "strconv" "strings" + cbor "k8s.io/apimachinery/pkg/runtime/serializer/cbor/direct" + inf "gopkg.in/inf.v0" ) @@ -683,6 +685,12 @@ func (q Quantity) MarshalJSON() ([]byte, error) { return result, nil } +func (q Quantity) MarshalCBOR() ([]byte, error) { + // The call to String() should never return the string "" because the receiver's + // address will never be nil. + return cbor.Marshal(q.String()) +} + // ToUnstructured implements the value.UnstructuredConverter interface. func (q Quantity) ToUnstructured() interface{} { return q.String() @@ -711,6 +719,27 @@ func (q *Quantity) UnmarshalJSON(value []byte) error { return nil } +func (q *Quantity) UnmarshalCBOR(value []byte) error { + var s *string + if err := cbor.Unmarshal(value, &s); err != nil { + return err + } + + if s == nil { + q.d.Dec = nil + q.i = int64Amount{} + return nil + } + + parsed, err := ParseQuantity(strings.TrimSpace(*s)) + if err != nil { + return err + } + + *q = parsed + return nil +} + // NewDecimalQuantity returns a new Quantity representing the given // value in the given format. func NewDecimalQuantity(b inf.Dec, format Format) *Quantity { diff --git a/staging/src/k8s.io/apimachinery/pkg/api/resource/quantity_test.go b/staging/src/k8s.io/apimachinery/pkg/api/resource/quantity_test.go index 646caee7b78..c78ff8d6bd0 100644 --- a/staging/src/k8s.io/apimachinery/pkg/api/resource/quantity_test.go +++ b/staging/src/k8s.io/apimachinery/pkg/api/resource/quantity_test.go @@ -27,10 +27,12 @@ import ( "testing" "unicode" + "github.com/google/go-cmp/cmp" fuzz "github.com/google/gofuzz" "github.com/spf13/pflag" - inf "gopkg.in/inf.v0" + + cbor "k8s.io/apimachinery/pkg/runtime/serializer/cbor/direct" ) var ( @@ -1615,3 +1617,88 @@ func ExampleQuantityValue() { // Output: // --mem quantity sets amount of memory (default 1Mi) } + +func TestQuantityUnmarshalCBOR(t *testing.T) { + for _, tc := range []struct { + name string + in []byte + want Quantity + errMessage string + }{ + { + name: "null", + in: []byte{0xf6}, // null + want: Quantity{}, + }, + { + name: "text string input", + in: []byte("\x621M"), // "1M" + want: Quantity{i: int64Amount{value: 1, scale: 6}}, + }, + { + name: "byte string input", + in: []byte("\x421M"), // '1M' + want: Quantity{i: int64Amount{value: 1, scale: 6}}, + }, + { + name: "whitespace", + in: []byte("\x4a \t\n\r1M \t\n\r"), // h'20090a0d314d20090a0d' + want: Quantity{i: int64Amount{value: 1, scale: 6}}, + }, + { + name: "empty byte string", + in: []byte{0x40}, + errMessage: ErrFormatWrong.Error(), + }, + { + name: "empty text string", + in: []byte{0x60}, + errMessage: ErrFormatWrong.Error(), + }, + { + name: "unsupported input type", + in: []byte{0x07}, // 7 + errMessage: "cbor: cannot unmarshal positive integer into Go value of type string", + }, + } { + t.Run(tc.name, func(t *testing.T) { + var got Quantity + if err := got.UnmarshalCBOR(tc.in); err != nil { + if tc.errMessage == "" { + t.Fatalf("want nil error, got: %v", err) + } else if gotMessage := err.Error(); tc.errMessage != gotMessage { + t.Fatalf("want error: %q, got: %q", tc.errMessage, gotMessage) + } + } else if tc.errMessage != "" { + t.Fatalf("got nil error, want: %s", tc.errMessage) + } + + if diff := cmp.Diff(tc.want, got); diff != "" { + t.Errorf("unexpected diff:\n%s", diff) + } + }) + } +} + +func TestQuantityRoundtripCBOR(t *testing.T) { + for i := 0; i < 500; i++ { + var initial, final Quantity + fuzzer.Fuzz(&initial) + b, err := cbor.Marshal(initial) + if err != nil { + t.Errorf("error encoding %v: %v", initial, err) + continue + } + err = cbor.Unmarshal(b, &final) + if err != nil { + t.Errorf("%v: error decoding %v: %v", initial, string(b), err) + } + if final.Cmp(initial) != 0 { + diag, err := cbor.Diagnose(b) + if err != nil { + t.Logf("failed to produce diagnostic encoding of 0x%x: %v", b, err) + } + t.Errorf("Expected equal: %v, %v (cbor was '%s')", initial, final, diag) + } + } +}