diff --git a/pkg/util/goroutinemap/goroutinemap.go b/pkg/util/goroutinemap/goroutinemap.go index 277f7809d74..067c47bd9ee 100644 --- a/pkg/util/goroutinemap/goroutinemap.go +++ b/pkg/util/goroutinemap/goroutinemap.go @@ -37,6 +37,11 @@ type GoRoutineMap interface { // removed from the list of executing operations allowing a new operation // to be started with the same name without error. NewGoRoutine(operationName string, operation func() error) error + + // Wait blocks until all operations are completed. This is typically + // necessary during tests - the test should wait until all operations finish + // and evaluate results after that. + Wait() } // NewGoRoutineMap returns a new instance of GoRoutineMap. @@ -49,6 +54,7 @@ func NewGoRoutineMap() GoRoutineMap { type goRoutineMap struct { operations map[string]bool sync.Mutex + wg sync.WaitGroup } func (grm *goRoutineMap) NewGoRoutine(operationName string, operation func() error) error { @@ -60,6 +66,7 @@ func (grm *goRoutineMap) NewGoRoutine(operationName string, operation func() err } grm.operations[operationName] = true + grm.wg.Add(1) go func() { defer grm.operationComplete(operationName) defer runtime.HandleCrash() @@ -70,7 +77,12 @@ func (grm *goRoutineMap) NewGoRoutine(operationName string, operation func() err } func (grm *goRoutineMap) operationComplete(operationName string) { + defer grm.wg.Done() grm.Lock() defer grm.Unlock() delete(grm.operations, operationName) } + +func (grm *goRoutineMap) Wait() { + grm.wg.Wait() +} diff --git a/pkg/util/goroutinemap/goroutinemap_test.go b/pkg/util/goroutinemap/goroutinemap_test.go index 48f7cf6e8b0..fccbe717ae9 100644 --- a/pkg/util/goroutinemap/goroutinemap_test.go +++ b/pkg/util/goroutinemap/goroutinemap_test.go @@ -17,6 +17,7 @@ limitations under the License. package goroutinemap import ( + "fmt" "testing" "time" @@ -195,3 +196,71 @@ func retryWithExponentialBackOff(initialDuration time.Duration, fn wait.Conditio } return wait.ExponentialBackoff(backoff, fn) } + +func Test_NewGoRoutineMap_Positive_WaitEmpty(t *testing.T) { + // Test than Wait() on empty GoRoutineMap always succeeds without blocking + // Arrange + grm := NewGoRoutineMap() + + // Act + waitDoneCh := make(chan interface{}, 1) + go func() { + grm.Wait() + waitDoneCh <- true + }() + + // Assert + // Tolerate 50 milliseconds for goroutine context switches etc. + err := waitChannelWithTimeout(waitDoneCh, 50*time.Millisecond) + if err != nil { + t.Errorf("Error waiting for GoRoutineMap.Wait: %v", err) + } +} + +func Test_NewGoRoutineMap_Positive_Wait(t *testing.T) { + // Test that Wait() really blocks until the last operation succeeds + // Arrange + grm := NewGoRoutineMap() + operationName := "operation-name" + operation1DoneCh := make(chan interface{}, 0 /* bufferSize */) + operation1 := generateWaitFunc(operation1DoneCh) + err := grm.NewGoRoutine(operationName, operation1) + if err != nil { + t.Fatalf("NewGoRoutine failed. Expected: Actual: <%v>", err) + } + + // Act + waitDoneCh := make(chan interface{}, 1) + go func() { + grm.Wait() + waitDoneCh <- true + }() + + // Assert + // Check that Wait() really blocks + err = waitChannelWithTimeout(waitDoneCh, 100*time.Millisecond) + if err == nil { + t.Fatalf("Expected Wait() to block but it returned early") + } + + // Finish the operation + operation1DoneCh <- true + + // check that Wait() finishes in reasonable time + err = waitChannelWithTimeout(waitDoneCh, 50*time.Millisecond) + if err != nil { + t.Fatalf("Error waiting for GoRoutineMap.Wait: %v", err) + } +} + +func waitChannelWithTimeout(ch <-chan interface{}, timeout time.Duration) error { + timer := time.NewTimer(timeout) + + select { + case <-ch: + // Success! + return nil + case <-timer.C: + return fmt.Errorf("timeout after %v", timeout) + } +}