From 7999e72659192da204905b46cbe97c3eb3996f85 Mon Sep 17 00:00:00 2001 From: Andy Goldstein Date: Fri, 2 Oct 2015 16:48:50 -0400 Subject: [PATCH] Fix potential goroutine leaks in pollers --- pkg/util/wait/wait.go | 27 ++++++++++++++++++++------- pkg/util/wait/wait_test.go | 21 ++++++++++++++++----- 2 files changed, 36 insertions(+), 12 deletions(-) diff --git a/pkg/util/wait/wait.go b/pkg/util/wait/wait.go index 1e8b786300a..f31349248d1 100644 --- a/pkg/util/wait/wait.go +++ b/pkg/util/wait/wait.go @@ -48,13 +48,17 @@ type ConditionFunc func() (done bool, err error) func Poll(interval, timeout time.Duration, condition ConditionFunc) error { return pollInternal(poller(interval, timeout), condition) } + func pollInternal(wait WaitFunc, condition ConditionFunc) error { - return WaitFor(wait, condition) + done := make(chan struct{}) + defer close(done) + return WaitFor(wait, condition, done) } func PollImmediate(interval, timeout time.Duration, condition ConditionFunc) error { return pollImmediateInternal(poller(interval, timeout), condition) } + func pollImmediateInternal(wait WaitFunc, condition ConditionFunc) error { done, err := condition() if err != nil { @@ -68,20 +72,22 @@ func pollImmediateInternal(wait WaitFunc, condition ConditionFunc) error { // PollInfinite polls forever. func PollInfinite(interval time.Duration, condition ConditionFunc) error { - return WaitFor(poller(interval, 0), condition) + done := make(chan struct{}) + defer close(done) + return WaitFor(poller(interval, 0), condition, done) } // WaitFunc creates a channel that receives an item every time a test // should be executed and is closed when the last test should be invoked. -type WaitFunc func() <-chan struct{} +type WaitFunc func(done <-chan struct{}) <-chan struct{} // WaitFor gets a channel from wait(), and then invokes fn once for every value // placed on the channel and once more when the channel is closed. If fn // returns an error the loop ends and that error is returned, and if fn returns // true the loop ends and nil is returned. ErrWaitTimeout will be returned if // the channel is closed without fn ever returning true. -func WaitFor(wait WaitFunc, fn ConditionFunc) error { - c := wait() +func WaitFor(wait WaitFunc, fn ConditionFunc, done <-chan struct{}) error { + c := wait(done) for { _, open := <-c ok, err := fn() @@ -104,11 +110,15 @@ func WaitFor(wait WaitFunc, fn ConditionFunc) error { // the channel is closed. If timeout is 0, the channel // will never be closed. func poller(interval, timeout time.Duration) WaitFunc { - return WaitFunc(func() <-chan struct{} { + return WaitFunc(func(done <-chan struct{}) <-chan struct{} { ch := make(chan struct{}) + go func() { + defer close(ch) + tick := time.NewTicker(interval) defer tick.Stop() + var after <-chan time.Time if timeout != 0 { // time.After is more convenient, but it @@ -118,16 +128,19 @@ func poller(interval, timeout time.Duration) WaitFunc { after = timer.C defer timer.Stop() } + for { select { case <-tick.C: ch <- struct{}{} case <-after: - close(ch) + return + case <-done: return } } }() + return ch }) } diff --git a/pkg/util/wait/wait_test.go b/pkg/util/wait/wait_test.go index 53eaae72005..315b9b4ca3a 100644 --- a/pkg/util/wait/wait_test.go +++ b/pkg/util/wait/wait_test.go @@ -26,8 +26,10 @@ import ( ) func TestPoller(t *testing.T) { + done := make(chan struct{}) + defer close(done) w := poller(time.Millisecond, 2*time.Millisecond) - ch := w() + ch := w(done) count := 0 DRAIN: for { @@ -47,16 +49,20 @@ DRAIN: } func fakeTicker(max int, used *int32) WaitFunc { - return func() <-chan struct{} { + return func(done <-chan struct{}) <-chan struct{} { ch := make(chan struct{}) go func() { + defer close(ch) for i := 0; i < max; i++ { - ch <- struct{}{} + select { + case ch <- struct{}{}: + case <-done: + return + } if used != nil { atomic.AddInt32(used, 1) } } - close(ch) }() return ch } @@ -155,6 +161,7 @@ func TestPollForever(t *testing.T) { } return false, nil }) + if err := PollInfinite(time.Microsecond, f); err != nil { t.Fatalf("unexpected error %v", err) } @@ -232,7 +239,11 @@ func TestWaitFor(t *testing.T) { for k, c := range testCases { invocations = 0 ticker := fakeTicker(c.Ticks, nil) - err := WaitFor(ticker, c.F) + err := func() error { + done := make(chan struct{}) + defer close(done) + return WaitFor(ticker, c.F, done) + }() switch { case c.Err && err == nil: t.Errorf("%s: Expected error, got nil", k)