diff --git a/staging/src/k8s.io/apimachinery/pkg/util/wait/wait.go b/staging/src/k8s.io/apimachinery/pkg/util/wait/wait.go index 760e17066c5..e60f45bffd0 100644 --- a/staging/src/k8s.io/apimachinery/pkg/util/wait/wait.go +++ b/staging/src/k8s.io/apimachinery/pkg/util/wait/wait.go @@ -351,22 +351,21 @@ type WaitFunc func(done <-chan struct{}) <-chan struct{} // WaitFor continually checks 'fn' as driven by 'wait'. // // WaitFor gets a channel from 'wait()'', and then invokes 'fn' once for every value -// placed on the channel and once more when the channel is closed. +// placed on the channel and once more when the channel is closed. If the channel is closed +// and 'fn' returns false without error, WaitFor returns ErrWaitTimeout. // -// If 'fn' returns an error the loop ends and that error is returned, and if +// If 'fn' returns an error the loop ends and that error is returned. If // 'fn' returns true the loop ends and nil is returned. // -// ErrWaitTimeout will be returned if the channel is closed without fn ever +// ErrWaitTimeout will be returned if the 'done' channel is closed without fn ever // returning true. +// +// When the done channel is closed, because the golang `select` statement is +// "uniform pseudo-random", the `fn` might still run one or multiple time, +// though eventually `WaitFor` will return. func WaitFor(wait WaitFunc, fn ConditionFunc, done <-chan struct{}) error { stopCh := make(chan struct{}) - once := sync.Once{} - closeCh := func() { - once.Do(func() { - close(stopCh) - }) - } - defer closeCh() + defer close(stopCh) c := wait(stopCh) for { select { @@ -382,10 +381,9 @@ func WaitFor(wait WaitFunc, fn ConditionFunc, done <-chan struct{}) error { return ErrWaitTimeout } case <-done: - closeCh() + return ErrWaitTimeout } } - return ErrWaitTimeout } // poller returns a WaitFunc that will send to the channel every interval until diff --git a/staging/src/k8s.io/apimachinery/pkg/util/wait/wait_test.go b/staging/src/k8s.io/apimachinery/pkg/util/wait/wait_test.go index 611fa5a0a0e..89987983d7c 100644 --- a/staging/src/k8s.io/apimachinery/pkg/util/wait/wait_test.go +++ b/staging/src/k8s.io/apimachinery/pkg/util/wait/wait_test.go @@ -456,11 +456,42 @@ func TestWaitFor(t *testing.T) { } } +// TestWaitForWithEarlyClosingWaitFunc tests WaitFor when the WaitFunc closes its channel. The WaitFor should +// always return ErrWaitTimeout. +func TestWaitForWithEarlyClosingWaitFunc(t *testing.T) { + stopCh := make(chan struct{}) + defer close(stopCh) + + start := time.Now() + err := WaitFor(func(done <-chan struct{}) <-chan struct{} { + c := make(chan struct{}) + close(c) + return c + }, func() (bool, error) { + return false, nil + }, stopCh) + duration := time.Now().Sub(start) + + // The WaitFor should return immediately, so the duration is close to 0s. + if duration >= ForeverTestTimeout/2 { + t.Errorf("expected short timeout duration") + } + if err != ErrWaitTimeout { + t.Errorf("expected ErrWaitTimeout from WaitFunc") + } +} + +// TestWaitForWithClosedChannel tests WaitFor when it receives a closed channel. The WaitFor should +// always return ErrWaitTimeout. func TestWaitForWithClosedChannel(t *testing.T) { stopCh := make(chan struct{}) close(stopCh) + c := make(chan struct{}) + defer close(c) start := time.Now() - err := WaitFor(poller(ForeverTestTimeout, ForeverTestTimeout), func() (bool, error) { + err := WaitFor(func(done <-chan struct{}) <-chan struct{} { + return c + }, func() (bool, error) { return false, nil }, stopCh) duration := time.Now().Sub(start)