diff --git a/staging/src/k8s.io/apimachinery/pkg/util/wait/wait.go b/staging/src/k8s.io/apimachinery/pkg/util/wait/wait.go index 3dea7fe7f9e..afb24876adf 100644 --- a/staging/src/k8s.io/apimachinery/pkg/util/wait/wait.go +++ b/staging/src/k8s.io/apimachinery/pkg/util/wait/wait.go @@ -205,10 +205,29 @@ var ErrWaitTimeout = errors.New("timed out waiting for the condition") // if the loop should be aborted. type ConditionFunc func() (done bool, err error) +// ConditionWithContextFunc returns true if the condition is satisfied, or an error +// if the loop should be aborted. +// +// The caller passes along a context that can be used by the condition function. +type ConditionWithContextFunc func(context.Context) (done bool, err error) + +// WithContext converts a ConditionFunc into a ConditionWithContextFunc +func (cf ConditionFunc) WithContext() ConditionWithContextFunc { + return func(context.Context) (done bool, err error) { + return cf() + } +} + // runConditionWithCrashProtection runs a ConditionFunc with crash protection func runConditionWithCrashProtection(condition ConditionFunc) (bool, error) { + return runConditionWithCrashProtectionWithContext(context.TODO(), condition.WithContext()) +} + +// runConditionWithCrashProtectionWithContext runs a +// ConditionWithContextFunc with crash protection. +func runConditionWithCrashProtectionWithContext(ctx context.Context, condition ConditionWithContextFunc) (bool, error) { defer runtime.HandleCrash() - return condition() + return condition(ctx) } // Backoff holds parameters applied to a Backoff function. @@ -418,13 +437,62 @@ func ExponentialBackoff(backoff Backoff, condition ConditionFunc) error { // // If you want to Poll something forever, see PollInfinite. func Poll(interval, timeout time.Duration, condition ConditionFunc) error { - return pollInternal(poller(interval, timeout), condition) + return PollWithContext(context.Background(), interval, timeout, condition.WithContext()) } -func pollInternal(wait WaitFunc, condition ConditionFunc) error { - done := make(chan struct{}) - defer close(done) - return WaitFor(wait, condition, done) +// PollWithContext tries a condition func until it returns true, an error, +// or when the context expires or the timeout is reached, whichever +// happens first. +// +// PollWithContext always waits the interval before the run of 'condition'. +// 'condition' will always be invoked at least once. +// +// Some intervals may be missed if the condition takes too long or the time +// window is too short. +// +// If you want to Poll something forever, see PollInfinite. +func PollWithContext(ctx context.Context, interval, timeout time.Duration, condition ConditionWithContextFunc) error { + return poll(ctx, false, poller(interval, timeout), condition) +} + +// PollUntil tries a condition func until it returns true, an error or stopCh is +// closed. +// +// PollUntil always waits interval before the first run of 'condition'. +// 'condition' will always be invoked at least once. +func PollUntil(interval time.Duration, condition ConditionFunc, stopCh <-chan struct{}) error { + ctx, cancel := contextForChannel(stopCh) + defer cancel() + return PollUntilWithContext(ctx, interval, condition.WithContext()) +} + +// PollUntilWithContext tries a condition func until it returns true, +// an error or the specified context is cancelled or expired. +// +// PollUntilWithContext always waits interval before the first run of 'condition'. +// 'condition' will always be invoked at least once. +func PollUntilWithContext(ctx context.Context, interval time.Duration, condition ConditionWithContextFunc) error { + return poll(ctx, false, poller(interval, 0), condition) +} + +// PollInfinite tries a condition func until it returns true or an error +// +// PollInfinite always waits the interval before the run of 'condition'. +// +// Some intervals may be missed if the condition takes too long or the time +// window is too short. +func PollInfinite(interval time.Duration, condition ConditionFunc) error { + return PollInfiniteWithContext(context.Background(), interval, condition.WithContext()) +} + +// PollInfiniteWithContext tries a condition func until it returns true or an error +// +// PollInfiniteWithContext always waits the interval before the run of 'condition'. +// +// Some intervals may be missed if the condition takes too long or the time +// window is too short. +func PollInfiniteWithContext(ctx context.Context, interval time.Duration, condition ConditionWithContextFunc) error { + return poll(ctx, false, poller(interval, 0), condition) } // PollImmediate tries a condition func until it returns true, an error, or the timeout @@ -438,30 +506,40 @@ func pollInternal(wait WaitFunc, condition ConditionFunc) error { // // If you want to immediately Poll something forever, see PollImmediateInfinite. func PollImmediate(interval, timeout time.Duration, condition ConditionFunc) error { - return pollImmediateInternal(poller(interval, timeout), condition) + return PollImmediateWithContext(context.Background(), interval, timeout, condition.WithContext()) } -func pollImmediateInternal(wait WaitFunc, condition ConditionFunc) error { - done, err := runConditionWithCrashProtection(condition) - if err != nil { - return err - } - if done { - return nil - } - return pollInternal(wait, condition) -} - -// PollInfinite tries a condition func until it returns true or an error +// PollImmediateWithContext tries a condition func until it returns true, an error, +// or the timeout is reached or the specified context expires, whichever happens first. // -// PollInfinite always waits the interval before the run of 'condition'. +// PollImmediateWithContext always checks 'condition' before waiting for the interval. +// 'condition' will always be invoked at least once. // // Some intervals may be missed if the condition takes too long or the time // window is too short. -func PollInfinite(interval time.Duration, condition ConditionFunc) error { - done := make(chan struct{}) - defer close(done) - return PollUntil(interval, condition, done) +// +// If you want to immediately Poll something forever, see PollImmediateInfinite. +func PollImmediateWithContext(ctx context.Context, interval, timeout time.Duration, condition ConditionWithContextFunc) error { + return poll(ctx, true, poller(interval, timeout), condition) +} + +// PollImmediateUntil tries a condition func until it returns true, an error or stopCh is closed. +// +// PollImmediateUntil runs the 'condition' before waiting for the interval. +// 'condition' will always be invoked at least once. +func PollImmediateUntil(interval time.Duration, condition ConditionFunc, stopCh <-chan struct{}) error { + ctx, cancel := contextForChannel(stopCh) + defer cancel() + return PollImmediateUntilWithContext(ctx, interval, condition.WithContext()) +} + +// PollImmediateUntilWithContext tries a condition func until it returns true, +// an error or the specified context is cancelled or expired. +// +// PollImmediateUntilWithContext runs the 'condition' before waiting for the interval. +// 'condition' will always be invoked at least once. +func PollImmediateUntilWithContext(ctx context.Context, interval time.Duration, condition ConditionWithContextFunc) error { + return poll(ctx, true, poller(interval, 0), condition) } // PollImmediateInfinite tries a condition func until it returns true or an error @@ -471,44 +549,46 @@ func PollInfinite(interval time.Duration, condition ConditionFunc) error { // Some intervals may be missed if the condition takes too long or the time // window is too short. func PollImmediateInfinite(interval time.Duration, condition ConditionFunc) error { - done, err := runConditionWithCrashProtection(condition) - if err != nil { - return err - } - if done { - return nil - } - return PollInfinite(interval, condition) + return PollImmediateInfiniteWithContext(context.Background(), interval, condition.WithContext()) } -// PollUntil tries a condition func until it returns true, an error or stopCh is -// closed. +// PollImmediateInfiniteWithContext tries a condition func until it returns true +// or an error or the specified context gets cancelled or expired. // -// PollUntil always waits interval before the first run of 'condition'. -// 'condition' will always be invoked at least once. -func PollUntil(interval time.Duration, condition ConditionFunc, stopCh <-chan struct{}) error { - ctx, cancel := contextForChannel(stopCh) - defer cancel() - return WaitFor(poller(interval, 0), condition, ctx.Done()) +// PollImmediateInfiniteWithContext runs the 'condition' before waiting for the interval. +// +// Some intervals may be missed if the condition takes too long or the time +// window is too short. +func PollImmediateInfiniteWithContext(ctx context.Context, interval time.Duration, condition ConditionWithContextFunc) error { + return poll(ctx, true, poller(interval, 0), condition) } -// PollImmediateUntil tries a condition func until it returns true, an error or stopCh is closed. -// -// PollImmediateUntil runs the 'condition' before waiting for the interval. -// 'condition' will always be invoked at least once. -func PollImmediateUntil(interval time.Duration, condition ConditionFunc, stopCh <-chan struct{}) error { - done, err := condition() - if err != nil { - return err - } - if done { - return nil +// Internally used, each of the the public 'Poll*' function defined in this +// package should invoke this internal function with appropriate parameters. +// ctx: the context specified by the caller, for infinite polling pass +// a context that never gets cancelled or expired. +// immediate: if true, the 'condition' will be invoked before waiting for the interval, +// in this case 'condition' will always be invoked at least once. +// wait: user specified WaitFunc function that controls at what interval the condition +// function should be invoked periodically and whether it is bound by a timeout. +// condition: user specified ConditionWithContextFunc function. +func poll(ctx context.Context, immediate bool, wait WaitWithContextFunc, condition ConditionWithContextFunc) error { + if immediate { + done, err := runConditionWithCrashProtectionWithContext(ctx, condition) + if err != nil { + return err + } + if done { + return nil + } } + select { - case <-stopCh: + case <-ctx.Done(): + // returning ctx.Err() will break backward compatibility return ErrWaitTimeout default: - return PollUntil(interval, condition, stopCh) + return WaitForWithContext(ctx, wait, condition) } } @@ -516,6 +596,20 @@ func PollImmediateUntil(interval time.Duration, condition ConditionFunc, stopCh // should be executed and is closed when the last test should be invoked. type WaitFunc func(done <-chan struct{}) <-chan struct{} +// WithContext converts the WaitFunc to an equivalent WaitWithContextFunc +func (w WaitFunc) WithContext() WaitWithContextFunc { + return func(ctx context.Context) <-chan struct{} { + return w(ctx.Done()) + } +} + +// WaitWithContextFunc creates a channel that receives an item every time a test +// should be executed and is closed when the last test should be invoked. +// +// When the specified context gets cancelled or expires the function +// stops sending item and returns immediately. +type WaitWithContextFunc func(ctx context.Context) <-chan struct{} + // WaitFor continually checks 'fn' as driven by 'wait'. // // WaitFor gets a channel from 'wait()'', and then invokes 'fn' once for every value @@ -532,13 +626,35 @@ type WaitFunc func(done <-chan struct{}) <-chan struct{} // "uniform pseudo-random", the `fn` might still run one or multiple time, // though eventually `WaitFor` will return. func WaitFor(wait WaitFunc, fn ConditionFunc, done <-chan struct{}) error { - stopCh := make(chan struct{}) - defer close(stopCh) - c := wait(stopCh) + ctx, cancel := contextForChannel(done) + defer cancel() + return WaitForWithContext(ctx, wait.WithContext(), fn.WithContext()) +} + +// WaitForWithContext continually checks 'fn' as driven by 'wait'. +// +// WaitForWithContext gets a channel from 'wait()'', and then invokes 'fn' +// once for every value placed on the channel and once more when the +// channel is closed. If the channel is closed and 'fn' +// returns false without error, WaitForWithContext returns ErrWaitTimeout. +// +// If 'fn' returns an error the loop ends and that error is returned. If +// 'fn' returns true the loop ends and nil is returned. +// +// context.Canceled will be returned if the ctx.Done() channel is closed +// without fn ever returning true. +// +// When the ctx.Done() channel is closed, because the golang `select` statement is +// "uniform pseudo-random", the `fn` might still run one or multiple times, +// though eventually `WaitForWithContext` will return. +func WaitForWithContext(ctx context.Context, wait WaitWithContextFunc, fn ConditionWithContextFunc) error { + waitCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + c := wait(waitCtx) for { select { case _, open := <-c: - ok, err := runConditionWithCrashProtection(fn) + ok, err := runConditionWithCrashProtectionWithContext(ctx, fn) if err != nil { return err } @@ -548,7 +664,8 @@ func WaitFor(wait WaitFunc, fn ConditionFunc, done <-chan struct{}) error { if !open { return ErrWaitTimeout } - case <-done: + case <-ctx.Done(): + // returning ctx.Err() will break backward compatibility return ErrWaitTimeout } } @@ -564,8 +681,8 @@ func WaitFor(wait WaitFunc, fn ConditionFunc, done <-chan struct{}) error { // // Output ticks are not buffered. If the channel is not ready to receive an // item, the tick is skipped. -func poller(interval, timeout time.Duration) WaitFunc { - return WaitFunc(func(done <-chan struct{}) <-chan struct{} { +func poller(interval, timeout time.Duration) WaitWithContextFunc { + return WaitWithContextFunc(func(ctx context.Context) <-chan struct{} { ch := make(chan struct{}) go func() { @@ -595,7 +712,7 @@ func poller(interval, timeout time.Duration) WaitFunc { } case <-after: return - case <-done: + case <-ctx.Done(): return } } diff --git a/staging/src/k8s.io/apimachinery/pkg/util/wait/wait_test.go b/staging/src/k8s.io/apimachinery/pkg/util/wait/wait_test.go index 03a147146fe..1729736467c 100644 --- a/staging/src/k8s.io/apimachinery/pkg/util/wait/wait_test.go +++ b/staging/src/k8s.io/apimachinery/pkg/util/wait/wait_test.go @@ -282,10 +282,10 @@ func TestExponentialBackoff(t *testing.T) { } func TestPoller(t *testing.T) { - done := make(chan struct{}) - defer close(done) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() w := poller(time.Millisecond, 2*time.Millisecond) - ch := w(done) + ch := w(ctx) count := 0 DRAIN: for { @@ -343,7 +343,10 @@ func TestPoll(t *testing.T) { return true, nil }) fp := fakePoller{max: 1} - if err := pollInternal(fp.GetWaitFunc(), f); err != nil { + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + if err := poll(ctx, false, fp.GetWaitFunc().WithContext(), f.WithContext()); err != nil { t.Fatalf("unexpected error %v", err) } fp.wg.Wait() @@ -362,7 +365,10 @@ func TestPollError(t *testing.T) { return false, expectedError }) fp := fakePoller{max: 1} - if err := pollInternal(fp.GetWaitFunc(), f); err == nil || err != expectedError { + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + if err := poll(ctx, false, fp.GetWaitFunc().WithContext(), f.WithContext()); err == nil || err != expectedError { t.Fatalf("Expected error %v, got none %v", expectedError, err) } fp.wg.Wait() @@ -379,7 +385,10 @@ func TestPollImmediate(t *testing.T) { return true, nil }) fp := fakePoller{max: 0} - if err := pollImmediateInternal(fp.GetWaitFunc(), f); err != nil { + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + if err := poll(ctx, true, fp.GetWaitFunc().WithContext(), f.WithContext()); err != nil { t.Fatalf("unexpected error %v", err) } // We don't need to wait for fp.wg, as pollImmediate shouldn't call WaitFunc at all. @@ -398,7 +407,10 @@ func TestPollImmediateError(t *testing.T) { return false, expectedError }) fp := fakePoller{max: 0} - if err := pollImmediateInternal(fp.GetWaitFunc(), f); err == nil || err != expectedError { + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + if err := poll(ctx, true, fp.GetWaitFunc().WithContext(), f.WithContext()); err == nil || err != expectedError { t.Fatalf("Expected error %v, got none %v", expectedError, err) } // We don't need to wait for fp.wg, as pollImmediate shouldn't call WaitFunc at all. @@ -567,28 +579,24 @@ func TestWaitForWithClosedChannel(t *testing.T) { } } -// TestWaitForClosesStopCh verifies that after the condition func returns true, WaitFor() closes the stop channel it supplies to the WaitFunc. -func TestWaitForClosesStopCh(t *testing.T) { - stopCh := make(chan struct{}) - defer close(stopCh) +// TestWaitForWithContextCancelsContext verifies that after the condition func returns true, +// WaitForWithContext cancels the context it supplies to the WaitWithContextFunc. +func TestWaitForWithContextCancelsContext(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() waitFunc := poller(time.Millisecond, ForeverTestTimeout) - var doneCh <-chan struct{} - WaitFor(func(done <-chan struct{}) <-chan struct{} { - doneCh = done - return waitFunc(done) - }, func() (bool, error) { + var ctxPassedToWait context.Context + WaitForWithContext(ctx, func(ctx context.Context) <-chan struct{} { + ctxPassedToWait = ctx + return waitFunc(ctx) + }, func(ctx context.Context) (bool, error) { time.Sleep(10 * time.Millisecond) return true, nil - }, stopCh) - // The polling goroutine should be closed after WaitFor returning. - select { - case _, ok := <-doneCh: - if ok { - t.Errorf("expected closed channel after WaitFunc returning") - } - default: - t.Errorf("expected an ack of the done signal") + }) + // The polling goroutine should be closed after WaitForWithContext returning. + if ctxPassedToWait.Err() != context.Canceled { + t.Errorf("expected the context passed to WaitForWithContext to be closed with: %v, but got: %v", context.Canceled, ctxPassedToWait.Err()) } } @@ -873,3 +881,452 @@ func TestExponentialBackoffWithContext(t *testing.T) { }) } } + +func TestPollImmediateUntilWithContext(t *testing.T) { + fakeErr := errors.New("my error") + tests := []struct { + name string + condition func(int) ConditionWithContextFunc + context func() (context.Context, context.CancelFunc) + cancelContextAfterNthAttempt int + errExpected error + attemptsExpected int + }{ + { + name: "condition throws error on immediate attempt, no retry is attempted", + condition: func(int) ConditionWithContextFunc { + return func(context.Context) (done bool, err error) { + return false, fakeErr + } + }, + context: func() (context.Context, context.CancelFunc) { + return context.WithCancel(context.Background()) + }, + errExpected: fakeErr, + attemptsExpected: 1, + }, + { + name: "condition returns done=true on immediate attempt, no retry is attempted", + condition: func(int) ConditionWithContextFunc { + return func(context.Context) (done bool, err error) { + return true, nil + } + }, + context: func() (context.Context, context.CancelFunc) { + return context.WithCancel(context.Background()) + }, + errExpected: nil, + attemptsExpected: 1, + }, + { + name: "condition returns done=false on immediate attempt, context is already cancelled, no retry is attempted", + condition: func(int) ConditionWithContextFunc { + return func(context.Context) (done bool, err error) { + return false, nil + } + }, + context: func() (context.Context, context.CancelFunc) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + return ctx, cancel + }, + errExpected: ErrWaitTimeout, + attemptsExpected: 1, + }, + { + name: "condition returns done=false on immediate attempt, context is not cancelled, retry is attempted", + condition: func(attempts int) ConditionWithContextFunc { + return func(context.Context) (done bool, err error) { + // let first 3 attempts fail and the last one succeed + if attempts <= 3 { + return false, nil + } + return true, nil + } + }, + context: func() (context.Context, context.CancelFunc) { + return context.WithCancel(context.Background()) + }, + errExpected: nil, + attemptsExpected: 4, + }, + { + name: "condition always returns done=false, context gets cancelled after N attempts", + condition: func(attempts int) ConditionWithContextFunc { + return func(ctx context.Context) (done bool, err error) { + return false, nil + } + }, + context: func() (context.Context, context.CancelFunc) { + return context.WithCancel(context.Background()) + }, + cancelContextAfterNthAttempt: 4, + errExpected: ErrWaitTimeout, + attemptsExpected: 4, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ctx, cancel := test.context() + defer cancel() + + var attempts int + conditionWrapper := func(ctx context.Context) (done bool, err error) { + attempts++ + defer func() { + if test.cancelContextAfterNthAttempt == attempts { + cancel() + } + }() + + c := test.condition(attempts) + return c(ctx) + } + + err := PollImmediateUntilWithContext(ctx, time.Millisecond, conditionWrapper) + if test.errExpected != err { + t.Errorf("Expected error: %v, but got: %v", test.errExpected, err) + } + if test.attemptsExpected != attempts { + t.Errorf("Expected ConditionFunc to be invoked: %d times, but got: %d", test.attemptsExpected, attempts) + } + }) + } +} + +func TestWaitForWithContext(t *testing.T) { + fakeErr := errors.New("fake error") + tests := []struct { + name string + context func() (context.Context, context.CancelFunc) + condition ConditionWithContextFunc + waitFunc func() WaitFunc + attemptsExpected int + errExpected error + }{ + { + name: "condition returns done=true on first attempt, no retry is attempted", + context: func() (context.Context, context.CancelFunc) { + return context.WithCancel(context.Background()) + }, + condition: ConditionWithContextFunc(func(context.Context) (bool, error) { + return true, nil + }), + waitFunc: func() WaitFunc { return fakeTicker(2, nil, func() {}) }, + attemptsExpected: 1, + errExpected: nil, + }, + { + name: "condition always returns done=false, timeout error expected", + context: func() (context.Context, context.CancelFunc) { + return context.WithCancel(context.Background()) + }, + condition: ConditionWithContextFunc(func(context.Context) (bool, error) { + return false, nil + }), + waitFunc: func() WaitFunc { return fakeTicker(2, nil, func() {}) }, + // the contract of WaitForWithContext() says the func is called once more at the end of the wait + attemptsExpected: 3, + errExpected: ErrWaitTimeout, + }, + { + name: "condition returns an error on first attempt, the error is returned", + context: func() (context.Context, context.CancelFunc) { + return context.WithCancel(context.Background()) + }, + condition: ConditionWithContextFunc(func(context.Context) (bool, error) { + return false, fakeErr + }), + waitFunc: func() WaitFunc { return fakeTicker(2, nil, func() {}) }, + attemptsExpected: 1, + errExpected: fakeErr, + }, + { + name: "context is cancelled, context cancelled error expected", + context: func() (context.Context, context.CancelFunc) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + return ctx, cancel + }, + condition: ConditionWithContextFunc(func(context.Context) (bool, error) { + return false, nil + }), + waitFunc: func() WaitFunc { + return func(done <-chan struct{}) <-chan struct{} { + ch := make(chan struct{}) + // never tick on this channel + return ch + } + }, + attemptsExpected: 0, + errExpected: ErrWaitTimeout, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var attempts int + conditionWrapper := func(ctx context.Context) (done bool, err error) { + attempts++ + return test.condition(ctx) + } + + ticker := test.waitFunc() + err := func() error { + ctx, cancel := test.context() + defer cancel() + + return WaitForWithContext(ctx, ticker.WithContext(), conditionWrapper) + }() + + if test.errExpected != err { + t.Errorf("Expected error: %v, but got: %v", test.errExpected, err) + } + if test.attemptsExpected != attempts { + t.Errorf("Expected %d invocations, got %d", test.attemptsExpected, attempts) + } + }) + } +} + +func TestPollInternal(t *testing.T) { + fakeErr := errors.New("fake error") + tests := []struct { + name string + context func() (context.Context, context.CancelFunc) + immediate bool + waitFunc func() WaitFunc + condition ConditionWithContextFunc + cancelContextAfter int + attemptsExpected int + errExpected error + }{ + { + name: "immediate is true, condition returns an error", + immediate: true, + context: func() (context.Context, context.CancelFunc) { + // use a cancelled context, we want to make sure the + // condition is expected to be invoked immediately. + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + return ctx, cancel + }, + condition: ConditionWithContextFunc(func(context.Context) (bool, error) { + return false, fakeErr + }), + waitFunc: nil, + attemptsExpected: 1, + errExpected: fakeErr, + }, + { + name: "immediate is true, condition returns true", + immediate: true, + context: func() (context.Context, context.CancelFunc) { + // use a cancelled context, we want to make sure the + // condition is expected to be invoked immediately. + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + return ctx, cancel + }, + condition: ConditionWithContextFunc(func(context.Context) (bool, error) { + return true, nil + }), + waitFunc: nil, + attemptsExpected: 1, + errExpected: nil, + }, + { + name: "immediate is true, context is cancelled, condition return false", + immediate: true, + context: func() (context.Context, context.CancelFunc) { + // use a cancelled context, we want to make sure the + // condition is expected to be invoked immediately. + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + return ctx, cancel + }, + condition: ConditionWithContextFunc(func(context.Context) (bool, error) { + return false, nil + }), + waitFunc: nil, + attemptsExpected: 1, + errExpected: ErrWaitTimeout, + }, + { + name: "immediate is false, context is cancelled", + immediate: false, + context: func() (context.Context, context.CancelFunc) { + // use a cancelled context, we want to make sure the + // condition is expected to be invoked immediately. + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + return ctx, cancel + }, + condition: ConditionWithContextFunc(func(context.Context) (bool, error) { + return false, nil + }), + waitFunc: nil, + attemptsExpected: 0, + errExpected: ErrWaitTimeout, + }, + { + name: "immediate is false, condition returns an error", + immediate: false, + context: func() (context.Context, context.CancelFunc) { + return context.WithCancel(context.Background()) + }, + condition: ConditionWithContextFunc(func(context.Context) (bool, error) { + return false, fakeErr + }), + waitFunc: func() WaitFunc { return fakeTicker(5, nil, func() {}) }, + attemptsExpected: 1, + errExpected: fakeErr, + }, + { + name: "immediate is false, condition returns true", + immediate: false, + context: func() (context.Context, context.CancelFunc) { + return context.WithCancel(context.Background()) + }, + condition: ConditionWithContextFunc(func(context.Context) (bool, error) { + return true, nil + }), + waitFunc: func() WaitFunc { return fakeTicker(5, nil, func() {}) }, + attemptsExpected: 1, + errExpected: nil, + }, + { + name: "immediate is false, ticker channel is closed, condition returns true", + immediate: false, + context: func() (context.Context, context.CancelFunc) { + return context.WithCancel(context.Background()) + }, + condition: ConditionWithContextFunc(func(context.Context) (bool, error) { + return true, nil + }), + waitFunc: func() WaitFunc { + return func(done <-chan struct{}) <-chan struct{} { + ch := make(chan struct{}) + close(ch) + return ch + } + }, + attemptsExpected: 1, + errExpected: nil, + }, + { + name: "immediate is false, ticker channel is closed, condition returns error", + immediate: false, + context: func() (context.Context, context.CancelFunc) { + return context.WithCancel(context.Background()) + }, + condition: ConditionWithContextFunc(func(context.Context) (bool, error) { + return false, fakeErr + }), + waitFunc: func() WaitFunc { + return func(done <-chan struct{}) <-chan struct{} { + ch := make(chan struct{}) + close(ch) + return ch + } + }, + attemptsExpected: 1, + errExpected: fakeErr, + }, + { + name: "immediate is false, ticker channel is closed, condition returns false", + immediate: false, + context: func() (context.Context, context.CancelFunc) { + return context.WithCancel(context.Background()) + }, + condition: ConditionWithContextFunc(func(context.Context) (bool, error) { + return false, nil + }), + waitFunc: func() WaitFunc { + return func(done <-chan struct{}) <-chan struct{} { + ch := make(chan struct{}) + close(ch) + return ch + } + }, + attemptsExpected: 1, + errExpected: ErrWaitTimeout, + }, + { + name: "condition always returns false, timeout error expected", + immediate: false, + context: func() (context.Context, context.CancelFunc) { + return context.WithCancel(context.Background()) + }, + condition: ConditionWithContextFunc(func(context.Context) (bool, error) { + return false, nil + }), + waitFunc: func() WaitFunc { return fakeTicker(2, nil, func() {}) }, + // the contract of WaitForWithContext() says the func is called once more at the end of the wait + attemptsExpected: 3, + errExpected: ErrWaitTimeout, + }, + { + name: "context is cancelled after N attempts, timeout error expected", + immediate: false, + context: func() (context.Context, context.CancelFunc) { + return context.WithCancel(context.Background()) + }, + condition: ConditionWithContextFunc(func(context.Context) (bool, error) { + return false, nil + }), + waitFunc: func() WaitFunc { + return func(done <-chan struct{}) <-chan struct{} { + ch := make(chan struct{}) + // just tick twice + go func() { + ch <- struct{}{} + ch <- struct{}{} + }() + return ch + } + }, + cancelContextAfter: 2, + attemptsExpected: 2, + errExpected: ErrWaitTimeout, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var attempts int + ticker := WaitFunc(func(done <-chan struct{}) <-chan struct{} { + return nil + }) + if test.waitFunc != nil { + ticker = test.waitFunc() + } + err := func() error { + ctx, cancel := test.context() + defer cancel() + + conditionWrapper := func(ctx context.Context) (done bool, err error) { + attempts++ + + defer func() { + if test.cancelContextAfter == attempts { + cancel() + } + }() + + return test.condition(ctx) + } + + return poll(ctx, test.immediate, ticker.WithContext(), conditionWrapper) + }() + + if test.errExpected != err { + t.Errorf("Expected error: %v, but got: %v", test.errExpected, err) + } + if test.attemptsExpected != attempts { + t.Errorf("Expected %d invocations, got %d", test.attemptsExpected, attempts) + } + }) + } +}