diff --git a/pkg/api/resource/quantity.go b/pkg/api/resource/quantity.go index f57acbb9fee..4c6d669e523 100644 --- a/pkg/api/resource/quantity.go +++ b/pkg/api/resource/quantity.go @@ -328,15 +328,28 @@ func (q *Quantity) Cmp(y Quantity) int { } func (q *Quantity) Add(y Quantity) error { - q.Amount.Add(q.Amount, y.Amount) + switch { + case y.Amount == nil: + // Adding 0: do nothing. + case q.Amount == nil: + q.Amount = &inf.Dec{} + return q.Add(y) + default: + q.Amount.Add(q.Amount, y.Amount) + } return nil } func (q *Quantity) Sub(y Quantity) error { - if q.Format != y.Format { - return fmt.Errorf("format mismatch: %v vs. %v", q.Format, y.Format) + switch { + case y.Amount == nil: + // Subtracting 0: do nothing. + case q.Amount == nil: + q.Amount = &inf.Dec{} + return q.Sub(y) + default: + q.Amount.Sub(q.Amount, y.Amount) } - q.Amount.Sub(q.Amount, y.Amount) return nil } diff --git a/pkg/api/resource/quantity_test.go b/pkg/api/resource/quantity_test.go index da8858ea97d..c600294c60d 100644 --- a/pkg/api/resource/quantity_test.go +++ b/pkg/api/resource/quantity_test.go @@ -535,3 +535,47 @@ func TestQFlagIsPFlag(t *testing.T) { t.Errorf("Unexpected result %v != %v", e, a) } } + +func TestSub(t *testing.T) { + tests := []struct { + a Quantity + b Quantity + expected Quantity + }{ + {Quantity{dec(10, 0), DecimalSI}, Quantity{dec(1, 1), DecimalSI}, Quantity{dec(0, 0), DecimalSI}}, + {Quantity{dec(10, 0), DecimalSI}, Quantity{dec(1, 0), BinarySI}, Quantity{dec(9, 0), DecimalSI}}, + {Quantity{dec(10, 0), BinarySI}, Quantity{dec(1, 0), DecimalSI}, Quantity{dec(9, 0), BinarySI}}, + {Quantity{nil, DecimalSI}, Quantity{dec(50, 0), DecimalSI}, Quantity{dec(-50, 0), DecimalSI}}, + {Quantity{dec(50, 0), DecimalSI}, Quantity{nil, DecimalSI}, Quantity{dec(50, 0), DecimalSI}}, + {Quantity{nil, DecimalSI}, Quantity{nil, DecimalSI}, Quantity{dec(0, 0), DecimalSI}}, + } + + for i, test := range tests { + test.a.Sub(test.b) + if test.a.Cmp(test.expected) != 0 { + t.Errorf("[%d] Expected %q, got %q", i, test.expected.String(), test.a.String()) + } + } +} + +func TestAdd(t *testing.T) { + tests := []struct { + a Quantity + b Quantity + expected Quantity + }{ + {Quantity{dec(10, 0), DecimalSI}, Quantity{dec(1, 1), DecimalSI}, Quantity{dec(20, 0), DecimalSI}}, + {Quantity{dec(10, 0), DecimalSI}, Quantity{dec(1, 0), BinarySI}, Quantity{dec(11, 0), DecimalSI}}, + {Quantity{dec(10, 0), BinarySI}, Quantity{dec(1, 0), DecimalSI}, Quantity{dec(11, 0), BinarySI}}, + {Quantity{nil, DecimalSI}, Quantity{dec(50, 0), DecimalSI}, Quantity{dec(50, 0), DecimalSI}}, + {Quantity{dec(50, 0), DecimalSI}, Quantity{nil, DecimalSI}, Quantity{dec(50, 0), DecimalSI}}, + {Quantity{nil, DecimalSI}, Quantity{nil, DecimalSI}, Quantity{dec(0, 0), DecimalSI}}, + } + + for i, test := range tests { + test.a.Add(test.b) + if test.a.Cmp(test.expected) != 0 { + t.Errorf("[%d] Expected %q, got %q", i, test.expected.String(), test.a.String()) + } + } +}