From fc9d3c9de2b521c9866dd8da2864a0cc209337ca Mon Sep 17 00:00:00 2001 From: Eric Paris Date: Thu, 17 Sep 2015 16:17:11 -0400 Subject: [PATCH] pkg/util/sets: add Intersection function I actually want to use this over in kubernetes/contrib --- pkg/util/sets/set.go | 23 +++++++++++ pkg/util/sets/set_test.go | 85 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 108 insertions(+) diff --git a/pkg/util/sets/set.go b/pkg/util/sets/set.go index 910b254de19..5ac72017ef9 100644 --- a/pkg/util/sets/set.go +++ b/pkg/util/sets/set.go @@ -121,6 +121,29 @@ func (s1 String) Union(s2 String) String { return result } +// Intersection returns a new set which includes the item in BOTH s1 and s2 +// For example: +// s1 = {1, 2} +// s2 = {2, 3} +// s1.Intersection(s2) = {2} +func (s1 String) Intersection(s2 String) String { + var walk, other String + result := NewString() + if s1.Len() < s2.Len() { + walk = s1 + other = s2 + } else { + walk = s2 + other = s1 + } + for key := range walk { + if other.Has(key) { + result.Insert(key) + } + } + return result +} + // IsSuperset returns true if and only if s1 is a superset of s2. func (s1 String) IsSuperset(s2 String) bool { for item := range s2 { diff --git a/pkg/util/sets/set_test.go b/pkg/util/sets/set_test.go index dd401cfd197..69415000a2d 100644 --- a/pkg/util/sets/set_test.go +++ b/pkg/util/sets/set_test.go @@ -183,3 +183,88 @@ func TestStringSetEquals(t *testing.T) { t.Errorf("Expected to be not-equal: %v vs %v", a, b) } } + +func TestStringUnion(t *testing.T) { + tests := []struct { + s1 String + s2 String + expected String + }{ + { + NewString("1", "2", "3", "4"), + NewString("3", "4", "5", "6"), + NewString("1", "2", "3", "4", "5", "6"), + }, + { + NewString("1", "2", "3", "4"), + NewString(), + NewString("1", "2", "3", "4"), + }, + { + NewString(), + NewString("1", "2", "3", "4"), + NewString("1", "2", "3", "4"), + }, + { + NewString(), + NewString(), + NewString(), + }, + } + + for _, test := range tests { + union := test.s1.Union(test.s2) + if union.Len() != test.expected.Len() { + t.Errorf("Expected union.Len()=%d but got %d", test.expected.Len(), union.Len()) + } + + if !union.Equal(test.expected) { + t.Errorf("Expected union.Equal(expected) but not true. union:%v expected:%v", union.List(), test.expected.List()) + } + } +} + +func TestStringIntersection(t *testing.T) { + tests := []struct { + s1 String + s2 String + expected String + }{ + { + NewString("1", "2", "3", "4"), + NewString("3", "4", "5", "6"), + NewString("3", "4"), + }, + { + NewString("1", "2", "3", "4"), + NewString("1", "2", "3", "4"), + NewString("1", "2", "3", "4"), + }, + { + NewString("1", "2", "3", "4"), + NewString(), + NewString(), + }, + { + NewString(), + NewString("1", "2", "3", "4"), + NewString(), + }, + { + NewString(), + NewString(), + NewString(), + }, + } + + for _, test := range tests { + intersection := test.s1.Intersection(test.s2) + if intersection.Len() != test.expected.Len() { + t.Errorf("Expected intersection.Len()=%d but got %d", test.expected.Len(), intersection.Len()) + } + + if !intersection.Equal(test.expected) { + t.Errorf("Expected intersection.Equal(expected) but not true. intersection:%v expected:%v", intersection.List(), test.expected.List()) + } + } +}