From 1ad524ac248a36c92e37f779b2d8a1b83bd23d20 Mon Sep 17 00:00:00 2001 From: Wojciech Tyczynski Date: Sun, 3 Jan 2016 09:56:57 +0100 Subject: [PATCH] Fix wait_test flakes --- pkg/util/wait/wait_test.go | 49 +++++++++++++++++++++----------------- 1 file changed, 27 insertions(+), 22 deletions(-) diff --git a/pkg/util/wait/wait_test.go b/pkg/util/wait/wait_test.go index 54d35637f3c..b28947fd086 100644 --- a/pkg/util/wait/wait_test.go +++ b/pkg/util/wait/wait_test.go @@ -18,6 +18,7 @@ package wait import ( "errors" + "sync" "sync/atomic" "testing" "time" @@ -48,10 +49,17 @@ DRAIN: } } -func fakeTicker(max int, used *int32) WaitFunc { +type fakePoller struct { + max int + used int32 // accessed with atomics + wg sync.WaitGroup +} + +func fakeTicker(max int, used *int32, doneFunc func()) WaitFunc { return func(done <-chan struct{}) <-chan struct{} { ch := make(chan struct{}) go func() { + defer doneFunc() defer close(ch) for i := 0; i < max; i++ { select { @@ -68,13 +76,9 @@ func fakeTicker(max int, used *int32) WaitFunc { } } -type fakePoller struct { - max int - used int32 // accessed with atomics -} - -func (fp *fakePoller) GetWaitFunc(interval, timeout time.Duration) WaitFunc { - return fakeTicker(fp.max, &fp.used) +func (fp *fakePoller) GetWaitFunc() WaitFunc { + fp.wg.Add(1) + return fakeTicker(fp.max, &fp.used, fp.wg.Done) } func TestPoll(t *testing.T) { @@ -83,10 +87,11 @@ func TestPoll(t *testing.T) { invocations++ return true, nil }) - fp := fakePoller{max: 1} - if err := pollInternal(fp.GetWaitFunc(time.Microsecond, time.Second), f); err != nil { + fp := fakePoller{max: 1, wg: sync.WaitGroup{}} + if err := pollInternal(fp.GetWaitFunc(), f); err != nil { t.Fatalf("unexpected error %v", err) } + fp.wg.Wait() if invocations != 1 { t.Errorf("Expected exactly one invocation, got %d", invocations) } @@ -101,10 +106,11 @@ func TestPollError(t *testing.T) { f := ConditionFunc(func() (bool, error) { return false, expectedError }) - fp := fakePoller{max: 1} - if err := pollInternal(fp.GetWaitFunc(time.Microsecond, time.Second), f); err == nil || err != expectedError { + fp := fakePoller{max: 1, wg: sync.WaitGroup{}} + if err := pollInternal(fp.GetWaitFunc(), f); err == nil || err != expectedError { t.Fatalf("Expected error %v, got none %v", expectedError, err) } + fp.wg.Wait() used := atomic.LoadInt32(&fp.used) if used != 1 { t.Errorf("Expected exactly one tick, got %d", used) @@ -117,8 +123,8 @@ func TestPollImmediate(t *testing.T) { invocations++ return true, nil }) - fp := fakePoller{max: 0} - if err := pollImmediateInternal(fp.GetWaitFunc(time.Microsecond, time.Microsecond), f); err != nil { + fp := fakePoller{max: 0, wg: sync.WaitGroup{}} + if err := pollImmediateInternal(fp.GetWaitFunc(), f); err != nil { t.Fatalf("unexpected error %v", err) } if invocations != 1 { @@ -128,19 +134,18 @@ func TestPollImmediate(t *testing.T) { if used != 0 { t.Errorf("Expected exactly zero ticks, got %d", used) } +} +func TestPollImmediateError(t *testing.T) { expectedError := errors.New("Expected error") - f = ConditionFunc(func() (bool, error) { + f := ConditionFunc(func() (bool, error) { return false, expectedError }) - fp = fakePoller{max: 0} - if err := pollImmediateInternal(fp.GetWaitFunc(time.Microsecond, time.Microsecond), f); err == nil || err != expectedError { + fp := fakePoller{max: 0, wg: sync.WaitGroup{}} + if err := pollImmediateInternal(fp.GetWaitFunc(), f); err == nil || err != expectedError { t.Fatalf("Expected error %v, got none %v", expectedError, err) } - if invocations != 1 { - t.Errorf("Expected exactly one invocation, got %d", invocations) - } - used = atomic.LoadInt32(&fp.used) + used := atomic.LoadInt32(&fp.used) if used != 0 { t.Errorf("Expected exactly zero ticks, got %d", used) } @@ -236,7 +241,7 @@ func TestWaitFor(t *testing.T) { } for k, c := range testCases { invocations = 0 - ticker := fakeTicker(c.Ticks, nil) + ticker := fakeTicker(c.Ticks, nil, func() {}) err := func() error { done := make(chan struct{}) defer close(done)