Merge pull request #101668 from tkashem/wait-poll-with-context

apimachinery: add context bound polling
This commit is contained in:
Kubernetes Prow Robot 2021-05-06 14:14:02 -07:00 committed by GitHub
commit d110c2dd6a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 660 additions and 86 deletions

View File

@ -205,10 +205,29 @@ var ErrWaitTimeout = errors.New("timed out waiting for the condition")
// if the loop should be aborted.
type ConditionFunc func() (done bool, err error)
// ConditionWithContextFunc returns true if the condition is satisfied, or an error
// if the loop should be aborted.
//
// The caller passes along a context that can be used by the condition function.
type ConditionWithContextFunc func(context.Context) (done bool, err error)
// WithContext converts a ConditionFunc into a ConditionWithContextFunc
func (cf ConditionFunc) WithContext() ConditionWithContextFunc {
return func(context.Context) (done bool, err error) {
return cf()
}
}
// runConditionWithCrashProtection runs a ConditionFunc with crash protection
func runConditionWithCrashProtection(condition ConditionFunc) (bool, error) {
return runConditionWithCrashProtectionWithContext(context.TODO(), condition.WithContext())
}
// runConditionWithCrashProtectionWithContext runs a
// ConditionWithContextFunc with crash protection.
func runConditionWithCrashProtectionWithContext(ctx context.Context, condition ConditionWithContextFunc) (bool, error) {
defer runtime.HandleCrash()
return condition()
return condition(ctx)
}
// Backoff holds parameters applied to a Backoff function.
@ -418,13 +437,62 @@ func ExponentialBackoff(backoff Backoff, condition ConditionFunc) error {
//
// If you want to Poll something forever, see PollInfinite.
func Poll(interval, timeout time.Duration, condition ConditionFunc) error {
return pollInternal(poller(interval, timeout), condition)
return PollWithContext(context.Background(), interval, timeout, condition.WithContext())
}
func pollInternal(wait WaitFunc, condition ConditionFunc) error {
done := make(chan struct{})
defer close(done)
return WaitFor(wait, condition, done)
// PollWithContext tries a condition func until it returns true, an error,
// or when the context expires or the timeout is reached, whichever
// happens first.
//
// PollWithContext always waits the interval before the run of 'condition'.
// 'condition' will always be invoked at least once.
//
// Some intervals may be missed if the condition takes too long or the time
// window is too short.
//
// If you want to Poll something forever, see PollInfinite.
func PollWithContext(ctx context.Context, interval, timeout time.Duration, condition ConditionWithContextFunc) error {
return poll(ctx, false, poller(interval, timeout), condition)
}
// PollUntil tries a condition func until it returns true, an error or stopCh is
// closed.
//
// PollUntil always waits interval before the first run of 'condition'.
// 'condition' will always be invoked at least once.
func PollUntil(interval time.Duration, condition ConditionFunc, stopCh <-chan struct{}) error {
ctx, cancel := contextForChannel(stopCh)
defer cancel()
return PollUntilWithContext(ctx, interval, condition.WithContext())
}
// PollUntilWithContext tries a condition func until it returns true,
// an error or the specified context is cancelled or expired.
//
// PollUntilWithContext always waits interval before the first run of 'condition'.
// 'condition' will always be invoked at least once.
func PollUntilWithContext(ctx context.Context, interval time.Duration, condition ConditionWithContextFunc) error {
return poll(ctx, false, poller(interval, 0), condition)
}
// PollInfinite tries a condition func until it returns true or an error
//
// PollInfinite always waits the interval before the run of 'condition'.
//
// Some intervals may be missed if the condition takes too long or the time
// window is too short.
func PollInfinite(interval time.Duration, condition ConditionFunc) error {
return PollInfiniteWithContext(context.Background(), interval, condition.WithContext())
}
// PollInfiniteWithContext tries a condition func until it returns true or an error
//
// PollInfiniteWithContext always waits the interval before the run of 'condition'.
//
// Some intervals may be missed if the condition takes too long or the time
// window is too short.
func PollInfiniteWithContext(ctx context.Context, interval time.Duration, condition ConditionWithContextFunc) error {
return poll(ctx, false, poller(interval, 0), condition)
}
// PollImmediate tries a condition func until it returns true, an error, or the timeout
@ -438,30 +506,40 @@ func pollInternal(wait WaitFunc, condition ConditionFunc) error {
//
// If you want to immediately Poll something forever, see PollImmediateInfinite.
func PollImmediate(interval, timeout time.Duration, condition ConditionFunc) error {
return pollImmediateInternal(poller(interval, timeout), condition)
return PollImmediateWithContext(context.Background(), interval, timeout, condition.WithContext())
}
func pollImmediateInternal(wait WaitFunc, condition ConditionFunc) error {
done, err := runConditionWithCrashProtection(condition)
if err != nil {
return err
}
if done {
return nil
}
return pollInternal(wait, condition)
}
// PollInfinite tries a condition func until it returns true or an error
// PollImmediateWithContext tries a condition func until it returns true, an error,
// or the timeout is reached or the specified context expires, whichever happens first.
//
// PollInfinite always waits the interval before the run of 'condition'.
// PollImmediateWithContext always checks 'condition' before waiting for the interval.
// 'condition' will always be invoked at least once.
//
// Some intervals may be missed if the condition takes too long or the time
// window is too short.
func PollInfinite(interval time.Duration, condition ConditionFunc) error {
done := make(chan struct{})
defer close(done)
return PollUntil(interval, condition, done)
//
// If you want to immediately Poll something forever, see PollImmediateInfinite.
func PollImmediateWithContext(ctx context.Context, interval, timeout time.Duration, condition ConditionWithContextFunc) error {
return poll(ctx, true, poller(interval, timeout), condition)
}
// PollImmediateUntil tries a condition func until it returns true, an error or stopCh is closed.
//
// PollImmediateUntil runs the 'condition' before waiting for the interval.
// 'condition' will always be invoked at least once.
func PollImmediateUntil(interval time.Duration, condition ConditionFunc, stopCh <-chan struct{}) error {
ctx, cancel := contextForChannel(stopCh)
defer cancel()
return PollImmediateUntilWithContext(ctx, interval, condition.WithContext())
}
// PollImmediateUntilWithContext tries a condition func until it returns true,
// an error or the specified context is cancelled or expired.
//
// PollImmediateUntilWithContext runs the 'condition' before waiting for the interval.
// 'condition' will always be invoked at least once.
func PollImmediateUntilWithContext(ctx context.Context, interval time.Duration, condition ConditionWithContextFunc) error {
return poll(ctx, true, poller(interval, 0), condition)
}
// PollImmediateInfinite tries a condition func until it returns true or an error
@ -471,44 +549,46 @@ func PollInfinite(interval time.Duration, condition ConditionFunc) error {
// Some intervals may be missed if the condition takes too long or the time
// window is too short.
func PollImmediateInfinite(interval time.Duration, condition ConditionFunc) error {
done, err := runConditionWithCrashProtection(condition)
if err != nil {
return err
}
if done {
return nil
}
return PollInfinite(interval, condition)
return PollImmediateInfiniteWithContext(context.Background(), interval, condition.WithContext())
}
// PollUntil tries a condition func until it returns true, an error or stopCh is
// closed.
// PollImmediateInfiniteWithContext tries a condition func until it returns true
// or an error or the specified context gets cancelled or expired.
//
// PollUntil always waits interval before the first run of 'condition'.
// 'condition' will always be invoked at least once.
func PollUntil(interval time.Duration, condition ConditionFunc, stopCh <-chan struct{}) error {
ctx, cancel := contextForChannel(stopCh)
defer cancel()
return WaitFor(poller(interval, 0), condition, ctx.Done())
// PollImmediateInfiniteWithContext runs the 'condition' before waiting for the interval.
//
// Some intervals may be missed if the condition takes too long or the time
// window is too short.
func PollImmediateInfiniteWithContext(ctx context.Context, interval time.Duration, condition ConditionWithContextFunc) error {
return poll(ctx, true, poller(interval, 0), condition)
}
// PollImmediateUntil tries a condition func until it returns true, an error or stopCh is closed.
//
// PollImmediateUntil runs the 'condition' before waiting for the interval.
// 'condition' will always be invoked at least once.
func PollImmediateUntil(interval time.Duration, condition ConditionFunc, stopCh <-chan struct{}) error {
done, err := condition()
if err != nil {
return err
}
if done {
return nil
// Internally used, each of the the public 'Poll*' function defined in this
// package should invoke this internal function with appropriate parameters.
// ctx: the context specified by the caller, for infinite polling pass
// a context that never gets cancelled or expired.
// immediate: if true, the 'condition' will be invoked before waiting for the interval,
// in this case 'condition' will always be invoked at least once.
// wait: user specified WaitFunc function that controls at what interval the condition
// function should be invoked periodically and whether it is bound by a timeout.
// condition: user specified ConditionWithContextFunc function.
func poll(ctx context.Context, immediate bool, wait WaitWithContextFunc, condition ConditionWithContextFunc) error {
if immediate {
done, err := runConditionWithCrashProtectionWithContext(ctx, condition)
if err != nil {
return err
}
if done {
return nil
}
}
select {
case <-stopCh:
case <-ctx.Done():
// returning ctx.Err() will break backward compatibility
return ErrWaitTimeout
default:
return PollUntil(interval, condition, stopCh)
return WaitForWithContext(ctx, wait, condition)
}
}
@ -516,6 +596,20 @@ func PollImmediateUntil(interval time.Duration, condition ConditionFunc, stopCh
// should be executed and is closed when the last test should be invoked.
type WaitFunc func(done <-chan struct{}) <-chan struct{}
// WithContext converts the WaitFunc to an equivalent WaitWithContextFunc
func (w WaitFunc) WithContext() WaitWithContextFunc {
return func(ctx context.Context) <-chan struct{} {
return w(ctx.Done())
}
}
// WaitWithContextFunc creates a channel that receives an item every time a test
// should be executed and is closed when the last test should be invoked.
//
// When the specified context gets cancelled or expires the function
// stops sending item and returns immediately.
type WaitWithContextFunc func(ctx context.Context) <-chan struct{}
// WaitFor continually checks 'fn' as driven by 'wait'.
//
// WaitFor gets a channel from 'wait()'', and then invokes 'fn' once for every value
@ -532,13 +626,35 @@ type WaitFunc func(done <-chan struct{}) <-chan struct{}
// "uniform pseudo-random", the `fn` might still run one or multiple time,
// though eventually `WaitFor` will return.
func WaitFor(wait WaitFunc, fn ConditionFunc, done <-chan struct{}) error {
stopCh := make(chan struct{})
defer close(stopCh)
c := wait(stopCh)
ctx, cancel := contextForChannel(done)
defer cancel()
return WaitForWithContext(ctx, wait.WithContext(), fn.WithContext())
}
// WaitForWithContext continually checks 'fn' as driven by 'wait'.
//
// WaitForWithContext gets a channel from 'wait()'', and then invokes 'fn'
// once for every value placed on the channel and once more when the
// channel is closed. If the channel is closed and 'fn'
// returns false without error, WaitForWithContext returns ErrWaitTimeout.
//
// If 'fn' returns an error the loop ends and that error is returned. If
// 'fn' returns true the loop ends and nil is returned.
//
// context.Canceled will be returned if the ctx.Done() channel is closed
// without fn ever returning true.
//
// When the ctx.Done() channel is closed, because the golang `select` statement is
// "uniform pseudo-random", the `fn` might still run one or multiple times,
// though eventually `WaitForWithContext` will return.
func WaitForWithContext(ctx context.Context, wait WaitWithContextFunc, fn ConditionWithContextFunc) error {
waitCtx, cancel := context.WithCancel(context.Background())
defer cancel()
c := wait(waitCtx)
for {
select {
case _, open := <-c:
ok, err := runConditionWithCrashProtection(fn)
ok, err := runConditionWithCrashProtectionWithContext(ctx, fn)
if err != nil {
return err
}
@ -548,7 +664,8 @@ func WaitFor(wait WaitFunc, fn ConditionFunc, done <-chan struct{}) error {
if !open {
return ErrWaitTimeout
}
case <-done:
case <-ctx.Done():
// returning ctx.Err() will break backward compatibility
return ErrWaitTimeout
}
}
@ -564,8 +681,8 @@ func WaitFor(wait WaitFunc, fn ConditionFunc, done <-chan struct{}) error {
//
// Output ticks are not buffered. If the channel is not ready to receive an
// item, the tick is skipped.
func poller(interval, timeout time.Duration) WaitFunc {
return WaitFunc(func(done <-chan struct{}) <-chan struct{} {
func poller(interval, timeout time.Duration) WaitWithContextFunc {
return WaitWithContextFunc(func(ctx context.Context) <-chan struct{} {
ch := make(chan struct{})
go func() {
@ -595,7 +712,7 @@ func poller(interval, timeout time.Duration) WaitFunc {
}
case <-after:
return
case <-done:
case <-ctx.Done():
return
}
}

View File

@ -282,10 +282,10 @@ func TestExponentialBackoff(t *testing.T) {
}
func TestPoller(t *testing.T) {
done := make(chan struct{})
defer close(done)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
w := poller(time.Millisecond, 2*time.Millisecond)
ch := w(done)
ch := w(ctx)
count := 0
DRAIN:
for {
@ -343,7 +343,10 @@ func TestPoll(t *testing.T) {
return true, nil
})
fp := fakePoller{max: 1}
if err := pollInternal(fp.GetWaitFunc(), f); err != nil {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
if err := poll(ctx, false, fp.GetWaitFunc().WithContext(), f.WithContext()); err != nil {
t.Fatalf("unexpected error %v", err)
}
fp.wg.Wait()
@ -362,7 +365,10 @@ func TestPollError(t *testing.T) {
return false, expectedError
})
fp := fakePoller{max: 1}
if err := pollInternal(fp.GetWaitFunc(), f); err == nil || err != expectedError {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
if err := poll(ctx, false, fp.GetWaitFunc().WithContext(), f.WithContext()); err == nil || err != expectedError {
t.Fatalf("Expected error %v, got none %v", expectedError, err)
}
fp.wg.Wait()
@ -379,7 +385,10 @@ func TestPollImmediate(t *testing.T) {
return true, nil
})
fp := fakePoller{max: 0}
if err := pollImmediateInternal(fp.GetWaitFunc(), f); err != nil {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
if err := poll(ctx, true, fp.GetWaitFunc().WithContext(), f.WithContext()); err != nil {
t.Fatalf("unexpected error %v", err)
}
// We don't need to wait for fp.wg, as pollImmediate shouldn't call WaitFunc at all.
@ -398,7 +407,10 @@ func TestPollImmediateError(t *testing.T) {
return false, expectedError
})
fp := fakePoller{max: 0}
if err := pollImmediateInternal(fp.GetWaitFunc(), f); err == nil || err != expectedError {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
if err := poll(ctx, true, fp.GetWaitFunc().WithContext(), f.WithContext()); err == nil || err != expectedError {
t.Fatalf("Expected error %v, got none %v", expectedError, err)
}
// We don't need to wait for fp.wg, as pollImmediate shouldn't call WaitFunc at all.
@ -567,28 +579,24 @@ func TestWaitForWithClosedChannel(t *testing.T) {
}
}
// TestWaitForClosesStopCh verifies that after the condition func returns true, WaitFor() closes the stop channel it supplies to the WaitFunc.
func TestWaitForClosesStopCh(t *testing.T) {
stopCh := make(chan struct{})
defer close(stopCh)
// TestWaitForWithContextCancelsContext verifies that after the condition func returns true,
// WaitForWithContext cancels the context it supplies to the WaitWithContextFunc.
func TestWaitForWithContextCancelsContext(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
waitFunc := poller(time.Millisecond, ForeverTestTimeout)
var doneCh <-chan struct{}
WaitFor(func(done <-chan struct{}) <-chan struct{} {
doneCh = done
return waitFunc(done)
}, func() (bool, error) {
var ctxPassedToWait context.Context
WaitForWithContext(ctx, func(ctx context.Context) <-chan struct{} {
ctxPassedToWait = ctx
return waitFunc(ctx)
}, func(ctx context.Context) (bool, error) {
time.Sleep(10 * time.Millisecond)
return true, nil
}, stopCh)
// The polling goroutine should be closed after WaitFor returning.
select {
case _, ok := <-doneCh:
if ok {
t.Errorf("expected closed channel after WaitFunc returning")
}
default:
t.Errorf("expected an ack of the done signal")
})
// The polling goroutine should be closed after WaitForWithContext returning.
if ctxPassedToWait.Err() != context.Canceled {
t.Errorf("expected the context passed to WaitForWithContext to be closed with: %v, but got: %v", context.Canceled, ctxPassedToWait.Err())
}
}
@ -873,3 +881,452 @@ func TestExponentialBackoffWithContext(t *testing.T) {
})
}
}
func TestPollImmediateUntilWithContext(t *testing.T) {
fakeErr := errors.New("my error")
tests := []struct {
name string
condition func(int) ConditionWithContextFunc
context func() (context.Context, context.CancelFunc)
cancelContextAfterNthAttempt int
errExpected error
attemptsExpected int
}{
{
name: "condition throws error on immediate attempt, no retry is attempted",
condition: func(int) ConditionWithContextFunc {
return func(context.Context) (done bool, err error) {
return false, fakeErr
}
},
context: func() (context.Context, context.CancelFunc) {
return context.WithCancel(context.Background())
},
errExpected: fakeErr,
attemptsExpected: 1,
},
{
name: "condition returns done=true on immediate attempt, no retry is attempted",
condition: func(int) ConditionWithContextFunc {
return func(context.Context) (done bool, err error) {
return true, nil
}
},
context: func() (context.Context, context.CancelFunc) {
return context.WithCancel(context.Background())
},
errExpected: nil,
attemptsExpected: 1,
},
{
name: "condition returns done=false on immediate attempt, context is already cancelled, no retry is attempted",
condition: func(int) ConditionWithContextFunc {
return func(context.Context) (done bool, err error) {
return false, nil
}
},
context: func() (context.Context, context.CancelFunc) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
return ctx, cancel
},
errExpected: ErrWaitTimeout,
attemptsExpected: 1,
},
{
name: "condition returns done=false on immediate attempt, context is not cancelled, retry is attempted",
condition: func(attempts int) ConditionWithContextFunc {
return func(context.Context) (done bool, err error) {
// let first 3 attempts fail and the last one succeed
if attempts <= 3 {
return false, nil
}
return true, nil
}
},
context: func() (context.Context, context.CancelFunc) {
return context.WithCancel(context.Background())
},
errExpected: nil,
attemptsExpected: 4,
},
{
name: "condition always returns done=false, context gets cancelled after N attempts",
condition: func(attempts int) ConditionWithContextFunc {
return func(ctx context.Context) (done bool, err error) {
return false, nil
}
},
context: func() (context.Context, context.CancelFunc) {
return context.WithCancel(context.Background())
},
cancelContextAfterNthAttempt: 4,
errExpected: ErrWaitTimeout,
attemptsExpected: 4,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
ctx, cancel := test.context()
defer cancel()
var attempts int
conditionWrapper := func(ctx context.Context) (done bool, err error) {
attempts++
defer func() {
if test.cancelContextAfterNthAttempt == attempts {
cancel()
}
}()
c := test.condition(attempts)
return c(ctx)
}
err := PollImmediateUntilWithContext(ctx, time.Millisecond, conditionWrapper)
if test.errExpected != err {
t.Errorf("Expected error: %v, but got: %v", test.errExpected, err)
}
if test.attemptsExpected != attempts {
t.Errorf("Expected ConditionFunc to be invoked: %d times, but got: %d", test.attemptsExpected, attempts)
}
})
}
}
func TestWaitForWithContext(t *testing.T) {
fakeErr := errors.New("fake error")
tests := []struct {
name string
context func() (context.Context, context.CancelFunc)
condition ConditionWithContextFunc
waitFunc func() WaitFunc
attemptsExpected int
errExpected error
}{
{
name: "condition returns done=true on first attempt, no retry is attempted",
context: func() (context.Context, context.CancelFunc) {
return context.WithCancel(context.Background())
},
condition: ConditionWithContextFunc(func(context.Context) (bool, error) {
return true, nil
}),
waitFunc: func() WaitFunc { return fakeTicker(2, nil, func() {}) },
attemptsExpected: 1,
errExpected: nil,
},
{
name: "condition always returns done=false, timeout error expected",
context: func() (context.Context, context.CancelFunc) {
return context.WithCancel(context.Background())
},
condition: ConditionWithContextFunc(func(context.Context) (bool, error) {
return false, nil
}),
waitFunc: func() WaitFunc { return fakeTicker(2, nil, func() {}) },
// the contract of WaitForWithContext() says the func is called once more at the end of the wait
attemptsExpected: 3,
errExpected: ErrWaitTimeout,
},
{
name: "condition returns an error on first attempt, the error is returned",
context: func() (context.Context, context.CancelFunc) {
return context.WithCancel(context.Background())
},
condition: ConditionWithContextFunc(func(context.Context) (bool, error) {
return false, fakeErr
}),
waitFunc: func() WaitFunc { return fakeTicker(2, nil, func() {}) },
attemptsExpected: 1,
errExpected: fakeErr,
},
{
name: "context is cancelled, context cancelled error expected",
context: func() (context.Context, context.CancelFunc) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
return ctx, cancel
},
condition: ConditionWithContextFunc(func(context.Context) (bool, error) {
return false, nil
}),
waitFunc: func() WaitFunc {
return func(done <-chan struct{}) <-chan struct{} {
ch := make(chan struct{})
// never tick on this channel
return ch
}
},
attemptsExpected: 0,
errExpected: ErrWaitTimeout,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
var attempts int
conditionWrapper := func(ctx context.Context) (done bool, err error) {
attempts++
return test.condition(ctx)
}
ticker := test.waitFunc()
err := func() error {
ctx, cancel := test.context()
defer cancel()
return WaitForWithContext(ctx, ticker.WithContext(), conditionWrapper)
}()
if test.errExpected != err {
t.Errorf("Expected error: %v, but got: %v", test.errExpected, err)
}
if test.attemptsExpected != attempts {
t.Errorf("Expected %d invocations, got %d", test.attemptsExpected, attempts)
}
})
}
}
func TestPollInternal(t *testing.T) {
fakeErr := errors.New("fake error")
tests := []struct {
name string
context func() (context.Context, context.CancelFunc)
immediate bool
waitFunc func() WaitFunc
condition ConditionWithContextFunc
cancelContextAfter int
attemptsExpected int
errExpected error
}{
{
name: "immediate is true, condition returns an error",
immediate: true,
context: func() (context.Context, context.CancelFunc) {
// use a cancelled context, we want to make sure the
// condition is expected to be invoked immediately.
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
return ctx, cancel
},
condition: ConditionWithContextFunc(func(context.Context) (bool, error) {
return false, fakeErr
}),
waitFunc: nil,
attemptsExpected: 1,
errExpected: fakeErr,
},
{
name: "immediate is true, condition returns true",
immediate: true,
context: func() (context.Context, context.CancelFunc) {
// use a cancelled context, we want to make sure the
// condition is expected to be invoked immediately.
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
return ctx, cancel
},
condition: ConditionWithContextFunc(func(context.Context) (bool, error) {
return true, nil
}),
waitFunc: nil,
attemptsExpected: 1,
errExpected: nil,
},
{
name: "immediate is true, context is cancelled, condition return false",
immediate: true,
context: func() (context.Context, context.CancelFunc) {
// use a cancelled context, we want to make sure the
// condition is expected to be invoked immediately.
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
return ctx, cancel
},
condition: ConditionWithContextFunc(func(context.Context) (bool, error) {
return false, nil
}),
waitFunc: nil,
attemptsExpected: 1,
errExpected: ErrWaitTimeout,
},
{
name: "immediate is false, context is cancelled",
immediate: false,
context: func() (context.Context, context.CancelFunc) {
// use a cancelled context, we want to make sure the
// condition is expected to be invoked immediately.
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
return ctx, cancel
},
condition: ConditionWithContextFunc(func(context.Context) (bool, error) {
return false, nil
}),
waitFunc: nil,
attemptsExpected: 0,
errExpected: ErrWaitTimeout,
},
{
name: "immediate is false, condition returns an error",
immediate: false,
context: func() (context.Context, context.CancelFunc) {
return context.WithCancel(context.Background())
},
condition: ConditionWithContextFunc(func(context.Context) (bool, error) {
return false, fakeErr
}),
waitFunc: func() WaitFunc { return fakeTicker(5, nil, func() {}) },
attemptsExpected: 1,
errExpected: fakeErr,
},
{
name: "immediate is false, condition returns true",
immediate: false,
context: func() (context.Context, context.CancelFunc) {
return context.WithCancel(context.Background())
},
condition: ConditionWithContextFunc(func(context.Context) (bool, error) {
return true, nil
}),
waitFunc: func() WaitFunc { return fakeTicker(5, nil, func() {}) },
attemptsExpected: 1,
errExpected: nil,
},
{
name: "immediate is false, ticker channel is closed, condition returns true",
immediate: false,
context: func() (context.Context, context.CancelFunc) {
return context.WithCancel(context.Background())
},
condition: ConditionWithContextFunc(func(context.Context) (bool, error) {
return true, nil
}),
waitFunc: func() WaitFunc {
return func(done <-chan struct{}) <-chan struct{} {
ch := make(chan struct{})
close(ch)
return ch
}
},
attemptsExpected: 1,
errExpected: nil,
},
{
name: "immediate is false, ticker channel is closed, condition returns error",
immediate: false,
context: func() (context.Context, context.CancelFunc) {
return context.WithCancel(context.Background())
},
condition: ConditionWithContextFunc(func(context.Context) (bool, error) {
return false, fakeErr
}),
waitFunc: func() WaitFunc {
return func(done <-chan struct{}) <-chan struct{} {
ch := make(chan struct{})
close(ch)
return ch
}
},
attemptsExpected: 1,
errExpected: fakeErr,
},
{
name: "immediate is false, ticker channel is closed, condition returns false",
immediate: false,
context: func() (context.Context, context.CancelFunc) {
return context.WithCancel(context.Background())
},
condition: ConditionWithContextFunc(func(context.Context) (bool, error) {
return false, nil
}),
waitFunc: func() WaitFunc {
return func(done <-chan struct{}) <-chan struct{} {
ch := make(chan struct{})
close(ch)
return ch
}
},
attemptsExpected: 1,
errExpected: ErrWaitTimeout,
},
{
name: "condition always returns false, timeout error expected",
immediate: false,
context: func() (context.Context, context.CancelFunc) {
return context.WithCancel(context.Background())
},
condition: ConditionWithContextFunc(func(context.Context) (bool, error) {
return false, nil
}),
waitFunc: func() WaitFunc { return fakeTicker(2, nil, func() {}) },
// the contract of WaitForWithContext() says the func is called once more at the end of the wait
attemptsExpected: 3,
errExpected: ErrWaitTimeout,
},
{
name: "context is cancelled after N attempts, timeout error expected",
immediate: false,
context: func() (context.Context, context.CancelFunc) {
return context.WithCancel(context.Background())
},
condition: ConditionWithContextFunc(func(context.Context) (bool, error) {
return false, nil
}),
waitFunc: func() WaitFunc {
return func(done <-chan struct{}) <-chan struct{} {
ch := make(chan struct{})
// just tick twice
go func() {
ch <- struct{}{}
ch <- struct{}{}
}()
return ch
}
},
cancelContextAfter: 2,
attemptsExpected: 2,
errExpected: ErrWaitTimeout,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
var attempts int
ticker := WaitFunc(func(done <-chan struct{}) <-chan struct{} {
return nil
})
if test.waitFunc != nil {
ticker = test.waitFunc()
}
err := func() error {
ctx, cancel := test.context()
defer cancel()
conditionWrapper := func(ctx context.Context) (done bool, err error) {
attempts++
defer func() {
if test.cancelContextAfter == attempts {
cancel()
}
}()
return test.condition(ctx)
}
return poll(ctx, test.immediate, ticker.WithContext(), conditionWrapper)
}()
if test.errExpected != err {
t.Errorf("Expected error: %v, but got: %v", test.errExpected, err)
}
if test.attemptsExpected != attempts {
t.Errorf("Expected %d invocations, got %d", test.attemptsExpected, attempts)
}
})
}
}