Fix potential goroutine leaks in pollers

This commit is contained in:
Andy Goldstein 2015-10-02 16:48:50 -04:00
parent 7ba48583fa
commit 7999e72659
2 changed files with 36 additions and 12 deletions

View File

@ -48,13 +48,17 @@ type ConditionFunc func() (done bool, err error)
func Poll(interval, timeout time.Duration, condition ConditionFunc) error { func Poll(interval, timeout time.Duration, condition ConditionFunc) error {
return pollInternal(poller(interval, timeout), condition) return pollInternal(poller(interval, timeout), condition)
} }
func pollInternal(wait WaitFunc, condition ConditionFunc) error { func pollInternal(wait WaitFunc, condition ConditionFunc) error {
return WaitFor(wait, condition) done := make(chan struct{})
defer close(done)
return WaitFor(wait, condition, done)
} }
func PollImmediate(interval, timeout time.Duration, condition ConditionFunc) error { func PollImmediate(interval, timeout time.Duration, condition ConditionFunc) error {
return pollImmediateInternal(poller(interval, timeout), condition) return pollImmediateInternal(poller(interval, timeout), condition)
} }
func pollImmediateInternal(wait WaitFunc, condition ConditionFunc) error { func pollImmediateInternal(wait WaitFunc, condition ConditionFunc) error {
done, err := condition() done, err := condition()
if err != nil { if err != nil {
@ -68,20 +72,22 @@ func pollImmediateInternal(wait WaitFunc, condition ConditionFunc) error {
// PollInfinite polls forever. // PollInfinite polls forever.
func PollInfinite(interval time.Duration, condition ConditionFunc) error { func PollInfinite(interval time.Duration, condition ConditionFunc) error {
return WaitFor(poller(interval, 0), condition) done := make(chan struct{})
defer close(done)
return WaitFor(poller(interval, 0), condition, done)
} }
// WaitFunc creates a channel that receives an item every time a test // WaitFunc creates a channel that receives an item every time a test
// should be executed and is closed when the last test should be invoked. // should be executed and is closed when the last test should be invoked.
type WaitFunc func() <-chan struct{} type WaitFunc func(done <-chan struct{}) <-chan struct{}
// WaitFor gets a channel from wait(), and then invokes fn once for every value // WaitFor gets a channel from wait(), and then invokes fn once for every value
// placed on the channel and once more when the channel is closed. If fn // placed on the channel and once more when the channel is closed. If fn
// returns an error the loop ends and that error is returned, and if fn returns // returns an error the loop ends and that error is returned, and if fn returns
// true the loop ends and nil is returned. ErrWaitTimeout will be returned if // true the loop ends and nil is returned. ErrWaitTimeout will be returned if
// the channel is closed without fn ever returning true. // the channel is closed without fn ever returning true.
func WaitFor(wait WaitFunc, fn ConditionFunc) error { func WaitFor(wait WaitFunc, fn ConditionFunc, done <-chan struct{}) error {
c := wait() c := wait(done)
for { for {
_, open := <-c _, open := <-c
ok, err := fn() ok, err := fn()
@ -104,11 +110,15 @@ func WaitFor(wait WaitFunc, fn ConditionFunc) error {
// the channel is closed. If timeout is 0, the channel // the channel is closed. If timeout is 0, the channel
// will never be closed. // will never be closed.
func poller(interval, timeout time.Duration) WaitFunc { func poller(interval, timeout time.Duration) WaitFunc {
return WaitFunc(func() <-chan struct{} { return WaitFunc(func(done <-chan struct{}) <-chan struct{} {
ch := make(chan struct{}) ch := make(chan struct{})
go func() { go func() {
defer close(ch)
tick := time.NewTicker(interval) tick := time.NewTicker(interval)
defer tick.Stop() defer tick.Stop()
var after <-chan time.Time var after <-chan time.Time
if timeout != 0 { if timeout != 0 {
// time.After is more convenient, but it // time.After is more convenient, but it
@ -118,16 +128,19 @@ func poller(interval, timeout time.Duration) WaitFunc {
after = timer.C after = timer.C
defer timer.Stop() defer timer.Stop()
} }
for { for {
select { select {
case <-tick.C: case <-tick.C:
ch <- struct{}{} ch <- struct{}{}
case <-after: case <-after:
close(ch) return
case <-done:
return return
} }
} }
}() }()
return ch return ch
}) })
} }

View File

@ -26,8 +26,10 @@ import (
) )
func TestPoller(t *testing.T) { func TestPoller(t *testing.T) {
done := make(chan struct{})
defer close(done)
w := poller(time.Millisecond, 2*time.Millisecond) w := poller(time.Millisecond, 2*time.Millisecond)
ch := w() ch := w(done)
count := 0 count := 0
DRAIN: DRAIN:
for { for {
@ -47,16 +49,20 @@ DRAIN:
} }
func fakeTicker(max int, used *int32) WaitFunc { func fakeTicker(max int, used *int32) WaitFunc {
return func() <-chan struct{} { return func(done <-chan struct{}) <-chan struct{} {
ch := make(chan struct{}) ch := make(chan struct{})
go func() { go func() {
defer close(ch)
for i := 0; i < max; i++ { for i := 0; i < max; i++ {
ch <- struct{}{} select {
case ch <- struct{}{}:
case <-done:
return
}
if used != nil { if used != nil {
atomic.AddInt32(used, 1) atomic.AddInt32(used, 1)
} }
} }
close(ch)
}() }()
return ch return ch
} }
@ -155,6 +161,7 @@ func TestPollForever(t *testing.T) {
} }
return false, nil return false, nil
}) })
if err := PollInfinite(time.Microsecond, f); err != nil { if err := PollInfinite(time.Microsecond, f); err != nil {
t.Fatalf("unexpected error %v", err) t.Fatalf("unexpected error %v", err)
} }
@ -232,7 +239,11 @@ func TestWaitFor(t *testing.T) {
for k, c := range testCases { for k, c := range testCases {
invocations = 0 invocations = 0
ticker := fakeTicker(c.Ticks, nil) ticker := fakeTicker(c.Ticks, nil)
err := WaitFor(ticker, c.F) err := func() error {
done := make(chan struct{})
defer close(done)
return WaitFor(ticker, c.F, done)
}()
switch { switch {
case c.Err && err == nil: case c.Err && err == nil:
t.Errorf("%s: Expected error, got nil", k) t.Errorf("%s: Expected error, got nil", k)