diff --git a/staging/src/k8s.io/apimachinery/pkg/util/wait/wait.go b/staging/src/k8s.io/apimachinery/pkg/util/wait/wait.go index 590c17b4c59..760e17066c5 100644 --- a/staging/src/k8s.io/apimachinery/pkg/util/wait/wait.go +++ b/staging/src/k8s.io/apimachinery/pkg/util/wait/wait.go @@ -359,18 +359,30 @@ type WaitFunc func(done <-chan struct{}) <-chan struct{} // ErrWaitTimeout will be returned if the channel is closed without fn ever // returning true. func WaitFor(wait WaitFunc, fn ConditionFunc, done <-chan struct{}) error { - c := wait(done) + stopCh := make(chan struct{}) + once := sync.Once{} + closeCh := func() { + once.Do(func() { + close(stopCh) + }) + } + defer closeCh() + c := wait(stopCh) for { - _, open := <-c - ok, err := fn() - if err != nil { - return err - } - if ok { - return nil - } - if !open { - break + select { + case _, open := <-c: + ok, err := fn() + if err != nil { + return err + } + if ok { + return nil + } + if !open { + return ErrWaitTimeout + } + case <-done: + closeCh() } } return ErrWaitTimeout diff --git a/staging/src/k8s.io/apimachinery/pkg/util/wait/wait_test.go b/staging/src/k8s.io/apimachinery/pkg/util/wait/wait_test.go index 24073bb1922..611fa5a0a0e 100644 --- a/staging/src/k8s.io/apimachinery/pkg/util/wait/wait_test.go +++ b/staging/src/k8s.io/apimachinery/pkg/util/wait/wait_test.go @@ -456,18 +456,46 @@ func TestWaitFor(t *testing.T) { } } -func TestWaitForWithDelay(t *testing.T) { - done := make(chan struct{}) - defer close(done) - WaitFor(poller(time.Millisecond, ForeverTestTimeout), func() (bool, error) { +func TestWaitForWithClosedChannel(t *testing.T) { + stopCh := make(chan struct{}) + close(stopCh) + start := time.Now() + err := WaitFor(poller(ForeverTestTimeout, ForeverTestTimeout), func() (bool, error) { + return false, nil + }, stopCh) + duration := time.Now().Sub(start) + // The WaitFor should return immediately, so the duration is close to 0s. + if duration >= ForeverTestTimeout/2 { + t.Errorf("expected short timeout duration") + } + // The interval of the poller is ForeverTestTimeout, so the WaitFor should always return ErrWaitTimeout. + if err != ErrWaitTimeout { + t.Errorf("expected ErrWaitTimeout from WaitFunc") + } +} + +// TestWaitForClosesStopCh verifies that after the condition func returns true, WaitFor() closes the stop channel it supplies to the WaitFunc. +func TestWaitForClosesStopCh(t *testing.T) { + stopCh := make(chan struct{}) + defer close(stopCh) + waitFunc := poller(time.Millisecond, ForeverTestTimeout) + var doneCh <-chan struct{} + + WaitFor(func(done <-chan struct{}) <-chan struct{} { + doneCh = done + return waitFunc(done) + }, func() (bool, error) { time.Sleep(10 * time.Millisecond) return true, nil - }, done) - // If polling goroutine doesn't see the done signal it will leak timers. + }, stopCh) + // The polling goroutine should be closed after WaitFor returning. select { - case done <- struct{}{}: - case <-time.After(ForeverTestTimeout): - t.Errorf("expected an ack of the done signal.") + case _, ok := <-doneCh: + if ok { + t.Errorf("expected closed channel after WaitFunc returning") + } + default: + t.Errorf("expected an ack of the done signal") } }