diff --git a/pkg/kubectl/rolling_updater_test.go b/pkg/kubectl/rolling_updater_test.go index fd55395d970..41c2d0d1a2e 100644 --- a/pkg/kubectl/rolling_updater_test.go +++ b/pkg/kubectl/rolling_updater_test.go @@ -817,7 +817,7 @@ func TestRollingUpdater_cleanupWithClients(t *testing.T) { t.Errorf("unexpected error: %v", err) } if len(fake.Actions()) != len(test.expected) { - t.Fatalf("%s: unexpected actions: %v, expected %v", test.name, fake.Actions, test.expected) + t.Fatalf("%s: unexpected actions: %v, expected %v", test.name, fake.Actions(), test.expected) } for j, action := range fake.Actions() { if e, a := test.expected[j], action.GetVerb(); e != a { diff --git a/pkg/util/wait/wait.go b/pkg/util/wait/wait.go index 61dab005a1b..1e8b786300a 100644 --- a/pkg/util/wait/wait.go +++ b/pkg/util/wait/wait.go @@ -45,9 +45,25 @@ type ConditionFunc func() (done bool, err error) // may be missed if the condition takes too long or the time window is too short. // If you want to Poll something forever, see PollInfinite. // Poll always waits the interval before the first check of the condition. -// TODO: create a separate PollImmediate function that does not wait. func Poll(interval, timeout time.Duration, condition ConditionFunc) error { - return WaitFor(poller(interval, timeout), condition) + return pollInternal(poller(interval, timeout), condition) +} +func pollInternal(wait WaitFunc, condition ConditionFunc) error { + return WaitFor(wait, condition) +} + +func PollImmediate(interval, timeout time.Duration, condition ConditionFunc) error { + return pollImmediateInternal(poller(interval, timeout), condition) +} +func pollImmediateInternal(wait WaitFunc, condition ConditionFunc) error { + done, err := condition() + if err != nil { + return err + } + if done { + return nil + } + return pollInternal(wait, condition) } // PollInfinite polls forever. @@ -59,16 +75,16 @@ func PollInfinite(interval time.Duration, condition ConditionFunc) error { // should be executed and is closed when the last test should be invoked. type WaitFunc func() <-chan struct{} -// WaitFor gets a channel from wait(), and then invokes c once for every value -// placed on the channel and once more when the channel is closed. If c -// returns an error the loop ends and that error is returned, and if c returns +// WaitFor gets a channel from wait(), and then invokes fn once for every value +// placed on the channel and once more when the channel is closed. If fn +// returns an error the loop ends and that error is returned, and if fn returns // true the loop ends and nil is returned. ErrWaitTimeout will be returned if -// the channel is closed without c ever returning true. -func WaitFor(wait WaitFunc, c ConditionFunc) error { - w := wait() +// the channel is closed without fn ever returning true. +func WaitFor(wait WaitFunc, fn ConditionFunc) error { + c := wait() for { - _, open := <-w - ok, err := c() + _, open := <-c + ok, err := fn() if err != nil { return err } diff --git a/pkg/util/wait/wait_test.go b/pkg/util/wait/wait_test.go index ddc8934b0d4..53eaae72005 100644 --- a/pkg/util/wait/wait_test.go +++ b/pkg/util/wait/wait_test.go @@ -18,6 +18,7 @@ package wait import ( "errors" + "sync/atomic" "testing" "time" @@ -45,12 +46,15 @@ DRAIN: } } -func fakeTicker(count int) WaitFunc { +func fakeTicker(max int, used *int32) WaitFunc { return func() <-chan struct{} { ch := make(chan struct{}) go func() { - for i := 0; i < count; i++ { + for i := 0; i < max; i++ { ch <- struct{}{} + if used != nil { + atomic.AddInt32(used, 1) + } } close(ch) }() @@ -58,25 +62,83 @@ func fakeTicker(count int) WaitFunc { } } +type fakePoller struct { + max int + used int32 // accessed with atomics +} + +func (fp *fakePoller) GetWaitFunc(interval, timeout time.Duration) WaitFunc { + return fakeTicker(fp.max, &fp.used) +} + func TestPoll(t *testing.T) { invocations := 0 f := ConditionFunc(func() (bool, error) { invocations++ return true, nil }) - if err := Poll(time.Microsecond, time.Microsecond, f); err != nil { + fp := fakePoller{max: 1} + if err := pollInternal(fp.GetWaitFunc(time.Microsecond, time.Microsecond), f); err != nil { t.Fatalf("unexpected error %v", err) } - if invocations == 0 { - t.Errorf("Expected at least one invocation, got zero") + if invocations != 1 { + t.Errorf("Expected exactly one invocation, got %d", invocations) } + used := atomic.LoadInt32(&fp.used) + if used != 1 { + t.Errorf("Expected exactly one tick, got %d", used) + } + expectedError := errors.New("Expected error") f = ConditionFunc(func() (bool, error) { return false, expectedError }) - if err := Poll(time.Microsecond, time.Microsecond, f); err == nil || err != expectedError { + fp = fakePoller{max: 1} + if err := pollInternal(fp.GetWaitFunc(time.Microsecond, time.Microsecond), f); err == nil || err != expectedError { t.Fatalf("Expected error %v, got none %v", expectedError, err) } + if invocations != 1 { + t.Errorf("Expected exactly one invocation, got %d", invocations) + } + used = atomic.LoadInt32(&fp.used) + if used != 1 { + t.Errorf("Expected exactly one tick, got %d", used) + } +} + +func TestPollImmediate(t *testing.T) { + invocations := 0 + f := ConditionFunc(func() (bool, error) { + invocations++ + return true, nil + }) + fp := fakePoller{max: 0} + if err := pollImmediateInternal(fp.GetWaitFunc(time.Microsecond, time.Microsecond), f); err != nil { + t.Fatalf("unexpected error %v", err) + } + if invocations != 1 { + t.Errorf("Expected exactly one invocation, got %d", invocations) + } + used := atomic.LoadInt32(&fp.used) + if used != 0 { + t.Errorf("Expected exactly zero ticks, got %d", used) + } + + expectedError := errors.New("Expected error") + f = ConditionFunc(func() (bool, error) { + return false, expectedError + }) + fp = fakePoller{max: 0} + if err := pollImmediateInternal(fp.GetWaitFunc(time.Microsecond, time.Microsecond), f); err == nil || err != expectedError { + t.Fatalf("Expected error %v, got none %v", expectedError, err) + } + if invocations != 1 { + t.Errorf("Expected exactly one invocation, got %d", invocations) + } + used = atomic.LoadInt32(&fp.used) + if used != 0 { + t.Errorf("Expected exactly zero ticks, got %d", used) + } } func TestPollForever(t *testing.T) { @@ -154,7 +216,7 @@ func TestWaitFor(t *testing.T) { return false, nil }), 2, - 3, + 3, // the contract of WaitFor() says the func is called once more at the end of the wait true, }, "returns immediately on error": { @@ -169,7 +231,7 @@ func TestWaitFor(t *testing.T) { } for k, c := range testCases { invocations = 0 - ticker := fakeTicker(c.Ticks) + ticker := fakeTicker(c.Ticks, nil) err := WaitFor(ticker, c.F) switch { case c.Err && err == nil: @@ -180,7 +242,7 @@ func TestWaitFor(t *testing.T) { continue } if invocations != c.Invoked { - t.Errorf("%s: Expected %d invocations, called %d", k, c.Invoked, invocations) + t.Errorf("%s: Expected %d invocations, got %d", k, c.Invoked, invocations) } } }