Merge pull request #14996 from thockin/wait-poll

add wait.PollImmediate() and retool wait tests
This commit is contained in:
CJ Cullen 2015-10-02 16:39:42 -07:00
commit 1b841d26e7
3 changed files with 98 additions and 20 deletions

View File

@ -817,7 +817,7 @@ func TestRollingUpdater_cleanupWithClients(t *testing.T) {
t.Errorf("unexpected error: %v", err) t.Errorf("unexpected error: %v", err)
} }
if len(fake.Actions()) != len(test.expected) { if len(fake.Actions()) != len(test.expected) {
t.Fatalf("%s: unexpected actions: %v, expected %v", test.name, fake.Actions, test.expected) t.Fatalf("%s: unexpected actions: %v, expected %v", test.name, fake.Actions(), test.expected)
} }
for j, action := range fake.Actions() { for j, action := range fake.Actions() {
if e, a := test.expected[j], action.GetVerb(); e != a { if e, a := test.expected[j], action.GetVerb(); e != a {

View File

@ -45,9 +45,25 @@ type ConditionFunc func() (done bool, err error)
// may be missed if the condition takes too long or the time window is too short. // may be missed if the condition takes too long or the time window is too short.
// If you want to Poll something forever, see PollInfinite. // If you want to Poll something forever, see PollInfinite.
// Poll always waits the interval before the first check of the condition. // Poll always waits the interval before the first check of the condition.
// TODO: create a separate PollImmediate function that does not wait.
func Poll(interval, timeout time.Duration, condition ConditionFunc) error { func Poll(interval, timeout time.Duration, condition ConditionFunc) error {
return WaitFor(poller(interval, timeout), condition) return pollInternal(poller(interval, timeout), condition)
}
func pollInternal(wait WaitFunc, condition ConditionFunc) error {
return WaitFor(wait, condition)
}
func PollImmediate(interval, timeout time.Duration, condition ConditionFunc) error {
return pollImmediateInternal(poller(interval, timeout), condition)
}
func pollImmediateInternal(wait WaitFunc, condition ConditionFunc) error {
done, err := condition()
if err != nil {
return err
}
if done {
return nil
}
return pollInternal(wait, condition)
} }
// PollInfinite polls forever. // PollInfinite polls forever.
@ -59,16 +75,16 @@ func PollInfinite(interval time.Duration, condition ConditionFunc) error {
// should be executed and is closed when the last test should be invoked. // should be executed and is closed when the last test should be invoked.
type WaitFunc func() <-chan struct{} type WaitFunc func() <-chan struct{}
// WaitFor gets a channel from wait(), and then invokes c once for every value // WaitFor gets a channel from wait(), and then invokes fn once for every value
// placed on the channel and once more when the channel is closed. If c // placed on the channel and once more when the channel is closed. If fn
// returns an error the loop ends and that error is returned, and if c returns // returns an error the loop ends and that error is returned, and if fn returns
// true the loop ends and nil is returned. ErrWaitTimeout will be returned if // true the loop ends and nil is returned. ErrWaitTimeout will be returned if
// the channel is closed without c ever returning true. // the channel is closed without fn ever returning true.
func WaitFor(wait WaitFunc, c ConditionFunc) error { func WaitFor(wait WaitFunc, fn ConditionFunc) error {
w := wait() c := wait()
for { for {
_, open := <-w _, open := <-c
ok, err := c() ok, err := fn()
if err != nil { if err != nil {
return err return err
} }

View File

@ -18,6 +18,7 @@ package wait
import ( import (
"errors" "errors"
"sync/atomic"
"testing" "testing"
"time" "time"
@ -45,12 +46,15 @@ DRAIN:
} }
} }
func fakeTicker(count int) WaitFunc { func fakeTicker(max int, used *int32) WaitFunc {
return func() <-chan struct{} { return func() <-chan struct{} {
ch := make(chan struct{}) ch := make(chan struct{})
go func() { go func() {
for i := 0; i < count; i++ { for i := 0; i < max; i++ {
ch <- struct{}{} ch <- struct{}{}
if used != nil {
atomic.AddInt32(used, 1)
}
} }
close(ch) close(ch)
}() }()
@ -58,25 +62,83 @@ func fakeTicker(count int) WaitFunc {
} }
} }
type fakePoller struct {
max int
used int32 // accessed with atomics
}
func (fp *fakePoller) GetWaitFunc(interval, timeout time.Duration) WaitFunc {
return fakeTicker(fp.max, &fp.used)
}
func TestPoll(t *testing.T) { func TestPoll(t *testing.T) {
invocations := 0 invocations := 0
f := ConditionFunc(func() (bool, error) { f := ConditionFunc(func() (bool, error) {
invocations++ invocations++
return true, nil return true, nil
}) })
if err := Poll(time.Microsecond, time.Microsecond, f); err != nil { fp := fakePoller{max: 1}
if err := pollInternal(fp.GetWaitFunc(time.Microsecond, time.Microsecond), f); err != nil {
t.Fatalf("unexpected error %v", err) t.Fatalf("unexpected error %v", err)
} }
if invocations == 0 { if invocations != 1 {
t.Errorf("Expected at least one invocation, got zero") t.Errorf("Expected exactly one invocation, got %d", invocations)
} }
used := atomic.LoadInt32(&fp.used)
if used != 1 {
t.Errorf("Expected exactly one tick, got %d", used)
}
expectedError := errors.New("Expected error") expectedError := errors.New("Expected error")
f = ConditionFunc(func() (bool, error) { f = ConditionFunc(func() (bool, error) {
return false, expectedError return false, expectedError
}) })
if err := Poll(time.Microsecond, time.Microsecond, f); err == nil || err != expectedError { fp = fakePoller{max: 1}
if err := pollInternal(fp.GetWaitFunc(time.Microsecond, time.Microsecond), f); err == nil || err != expectedError {
t.Fatalf("Expected error %v, got none %v", expectedError, err) t.Fatalf("Expected error %v, got none %v", expectedError, err)
} }
if invocations != 1 {
t.Errorf("Expected exactly one invocation, got %d", invocations)
}
used = atomic.LoadInt32(&fp.used)
if used != 1 {
t.Errorf("Expected exactly one tick, got %d", used)
}
}
func TestPollImmediate(t *testing.T) {
invocations := 0
f := ConditionFunc(func() (bool, error) {
invocations++
return true, nil
})
fp := fakePoller{max: 0}
if err := pollImmediateInternal(fp.GetWaitFunc(time.Microsecond, time.Microsecond), f); err != nil {
t.Fatalf("unexpected error %v", err)
}
if invocations != 1 {
t.Errorf("Expected exactly one invocation, got %d", invocations)
}
used := atomic.LoadInt32(&fp.used)
if used != 0 {
t.Errorf("Expected exactly zero ticks, got %d", used)
}
expectedError := errors.New("Expected error")
f = ConditionFunc(func() (bool, error) {
return false, expectedError
})
fp = fakePoller{max: 0}
if err := pollImmediateInternal(fp.GetWaitFunc(time.Microsecond, time.Microsecond), f); err == nil || err != expectedError {
t.Fatalf("Expected error %v, got none %v", expectedError, err)
}
if invocations != 1 {
t.Errorf("Expected exactly one invocation, got %d", invocations)
}
used = atomic.LoadInt32(&fp.used)
if used != 0 {
t.Errorf("Expected exactly zero ticks, got %d", used)
}
} }
func TestPollForever(t *testing.T) { func TestPollForever(t *testing.T) {
@ -154,7 +216,7 @@ func TestWaitFor(t *testing.T) {
return false, nil return false, nil
}), }),
2, 2,
3, 3, // the contract of WaitFor() says the func is called once more at the end of the wait
true, true,
}, },
"returns immediately on error": { "returns immediately on error": {
@ -169,7 +231,7 @@ func TestWaitFor(t *testing.T) {
} }
for k, c := range testCases { for k, c := range testCases {
invocations = 0 invocations = 0
ticker := fakeTicker(c.Ticks) ticker := fakeTicker(c.Ticks, nil)
err := WaitFor(ticker, c.F) err := WaitFor(ticker, c.F)
switch { switch {
case c.Err && err == nil: case c.Err && err == nil:
@ -180,7 +242,7 @@ func TestWaitFor(t *testing.T) {
continue continue
} }
if invocations != c.Invoked { if invocations != c.Invoked {
t.Errorf("%s: Expected %d invocations, called %d", k, c.Invoked, invocations) t.Errorf("%s: Expected %d invocations, got %d", k, c.Invoked, invocations)
} }
} }
} }