diff --git a/staging/src/k8s.io/apimachinery/pkg/util/wait/backoff.go b/staging/src/k8s.io/apimachinery/pkg/util/wait/backoff.go index ed419d10527..4187619256e 100644 --- a/staging/src/k8s.io/apimachinery/pkg/util/wait/backoff.go +++ b/staging/src/k8s.io/apimachinery/pkg/util/wait/backoff.go @@ -19,6 +19,7 @@ package wait import ( "context" "math" + "sync" "time" "k8s.io/apimachinery/pkg/util/runtime" @@ -51,33 +52,104 @@ type Backoff struct { Cap time.Duration } -// Step (1) returns an amount of time to sleep determined by the -// original Duration and Jitter and (2) mutates the provided Backoff -// to update its Steps and Duration. +// Step returns an amount of time to sleep determined by the original +// Duration and Jitter. The backoff is mutated to update its Steps and +// Duration. A nil Backoff always has a zero-duration step. func (b *Backoff) Step() time.Duration { - if b.Steps < 1 { - if b.Jitter > 0 { - return Jitter(b.Duration, b.Jitter) - } - return b.Duration + if b == nil { + return 0 } - b.Steps-- + var nextDuration time.Duration + nextDuration, b.Duration, b.Steps = delay(b.Steps, b.Duration, b.Cap, b.Factor, b.Jitter) + return nextDuration +} +// DelayFunc returns a function that will compute the next interval to +// wait given the arguments in b. It does not mutate the original backoff +// but the function is safe to use only from a single goroutine. +func (b Backoff) DelayFunc() DelayFunc { + steps := b.Steps duration := b.Duration + cap := b.Cap + factor := b.Factor + jitter := b.Jitter - // calculate the next step - if b.Factor != 0 { - b.Duration = time.Duration(float64(b.Duration) * b.Factor) - if b.Cap > 0 && b.Duration > b.Cap { - b.Duration = b.Cap - b.Steps = 0 + return func() time.Duration { + var nextDuration time.Duration + // jitter is applied per step and is not cumulative over multiple steps + nextDuration, duration, steps = delay(steps, duration, cap, factor, jitter) + return nextDuration + } +} + +// Timer returns a timer implementation appropriate to this backoff's parameters +// for use with wait functions. +func (b Backoff) Timer() Timer { + if b.Steps > 1 || b.Jitter != 0 { + return &variableTimer{new: internalClock.NewTimer, fn: b.DelayFunc()} + } + if b.Duration > 0 { + return &fixedTimer{new: internalClock.NewTicker, interval: b.Duration} + } + return newNoopTimer() +} + +// delay implements the core delay algorithm used in this package. +func delay(steps int, duration, cap time.Duration, factor, jitter float64) (_ time.Duration, next time.Duration, nextSteps int) { + // when steps is non-positive, do not alter the base duration + if steps < 1 { + if jitter > 0 { + return Jitter(duration, jitter), duration, 0 } + return duration, duration, 0 + } + steps-- + + // calculate the next step's interval + if factor != 0 { + next = time.Duration(float64(duration) * factor) + if cap > 0 && next > cap { + next = cap + steps = 0 + } + } else { + next = duration } - if b.Jitter > 0 { - duration = Jitter(duration, b.Jitter) + // add jitter for this step + if jitter > 0 { + duration = Jitter(duration, jitter) } - return duration + + return duration, next, steps + +} + +// DelayWithReset returns a DelayFunc that will return the appropriate next interval to +// wait. Every resetInterval the backoff parameters are reset to their initial state. +// This method is safe to invoke from multiple goroutines, but all calls will advance +// the backoff state when Factor is set. If Factor is zero, this method is the same as +// invoking b.DelayFunc() since Steps has no impact without Factor. If resetInterval is +// zero no backoff will be performed as the same calling DelayFunc with a zero factor +// and steps. +func (b Backoff) DelayWithReset(c clock.Clock, resetInterval time.Duration) DelayFunc { + if b.Factor <= 0 { + return b.DelayFunc() + } + if resetInterval <= 0 { + b.Steps = 0 + b.Factor = 0 + return b.DelayFunc() + } + return (&backoffManager{ + backoff: b, + initialBackoff: b, + resetInterval: resetInterval, + + clock: c, + lastStart: c.Now(), + timer: nil, + }).Step } // Until loops until stop channel is closed, running f every period. @@ -187,15 +259,65 @@ func JitterUntilWithContext(ctx context.Context, f func(context.Context), period JitterUntil(func() { f(ctx) }, period, jitterFactor, sliding, ctx.Done()) } -// BackoffManager manages backoff with a particular scheme based on its underlying implementation. It provides -// an interface to return a timer for backoff, and caller shall backoff until Timer.C() drains. If the second Backoff() -// is called before the timer from the first Backoff() call finishes, the first timer will NOT be drained and result in -// undetermined behavior. -// The BackoffManager is supposed to be called in a single-threaded environment. +// backoffManager provides simple backoff behavior in a threadsafe manner to a caller. +type backoffManager struct { + backoff Backoff + initialBackoff Backoff + resetInterval time.Duration + + clock clock.Clock + + lock sync.Mutex + lastStart time.Time + timer clock.Timer +} + +// Step returns the expected next duration to wait. +func (b *backoffManager) Step() time.Duration { + b.lock.Lock() + defer b.lock.Unlock() + + switch { + case b.resetInterval == 0: + b.backoff = b.initialBackoff + case b.clock.Now().Sub(b.lastStart) > b.resetInterval: + b.backoff = b.initialBackoff + b.lastStart = b.clock.Now() + } + return b.backoff.Step() +} + +// Backoff implements BackoffManager.Backoff, it returns a timer so caller can block on the timer +// for exponential backoff. The returned timer must be drained before calling Backoff() the second +// time. +func (b *backoffManager) Backoff() clock.Timer { + b.lock.Lock() + defer b.lock.Unlock() + if b.timer == nil { + b.timer = b.clock.NewTimer(b.Step()) + } else { + b.timer.Reset(b.Step()) + } + return b.timer +} + +// Timer returns a new Timer instance that shares the clock and the reset behavior with all other +// timers. +func (b *backoffManager) Timer() Timer { + return DelayFunc(b.Step).Timer(b.clock) +} + +// BackoffManager manages backoff with a particular scheme based on its underlying implementation. type BackoffManager interface { + // Backoff returns a shared clock.Timer that is Reset on every invocation. This method is not + // safe for use from multiple threads. It returns a timer for backoff, and caller shall backoff + // until Timer.C() drains. If the second Backoff() is called before the timer from the first + // Backoff() call finishes, the first timer will NOT be drained and result in undetermined + // behavior. Backoff() clock.Timer } +// Deprecated: Will be removed when the legacy polling functions are removed. type exponentialBackoffManagerImpl struct { backoff *Backoff backoffTimer clock.Timer @@ -208,6 +330,27 @@ type exponentialBackoffManagerImpl struct { // NewExponentialBackoffManager returns a manager for managing exponential backoff. Each backoff is jittered and // backoff will not exceed the given max. If the backoff is not called within resetDuration, the backoff is reset. // This backoff manager is used to reduce load during upstream unhealthiness. +// +// Deprecated: Will be removed when the legacy Poll methods are removed. Callers should construct a +// Backoff struct, use DelayWithReset() to get a DelayFunc that periodically resets itself, and then +// invoke Timer() when calling wait.BackoffUntil. +// +// Instead of: +// +// bm := wait.NewExponentialBackoffManager(init, max, reset, factor, jitter, clock) +// ... +// wait.BackoffUntil(..., bm.Backoff, ...) +// +// Use: +// +// delayFn := wait.Backoff{ +// Duration: init, +// Cap: max, +// Steps: int(math.Ceil(float64(max) / float64(init))), // now a required argument +// Factor: factor, +// Jitter: jitter, +// }.DelayWithReset(reset, clock) +// wait.BackoffUntil(..., delayFn.Timer(), ...) func NewExponentialBackoffManager(initBackoff, maxBackoff, resetDuration time.Duration, backoffFactor, jitter float64, c clock.Clock) BackoffManager { return &exponentialBackoffManagerImpl{ backoff: &Backoff{ @@ -248,6 +391,7 @@ func (b *exponentialBackoffManagerImpl) Backoff() clock.Timer { return b.backoffTimer } +// Deprecated: Will be removed when the legacy polling functions are removed. type jitteredBackoffManagerImpl struct { clock clock.Clock duration time.Duration @@ -257,6 +401,19 @@ type jitteredBackoffManagerImpl struct { // NewJitteredBackoffManager returns a BackoffManager that backoffs with given duration plus given jitter. If the jitter // is negative, backoff will not be jittered. +// +// Deprecated: Will be removed when the legacy Poll methods are removed. Callers should construct a +// Backoff struct and invoke Timer() when calling wait.BackoffUntil. +// +// Instead of: +// +// bm := wait.NewJitteredBackoffManager(duration, jitter, clock) +// ... +// wait.BackoffUntil(..., bm.Backoff, ...) +// +// Use: +// +// wait.BackoffUntil(..., wait.Backoff{Duration: duration, Jitter: jitter}.Timer(), ...) func NewJitteredBackoffManager(duration time.Duration, jitter float64, c clock.Clock) BackoffManager { return &jitteredBackoffManagerImpl{ clock: c, @@ -296,6 +453,9 @@ func (j *jitteredBackoffManagerImpl) Backoff() clock.Timer { // 3. a sleep truncated by the cap on duration has been completed. // In case (1) the returned error is what the condition function returned. // In all other cases, ErrWaitTimeout is returned. +// +// Since backoffs are often subject to cancellation, we recommend using +// ExponentialBackoffWithContext and passing a context to the method. func ExponentialBackoff(backoff Backoff, condition ConditionFunc) error { for backoff.Steps > 0 { if ok, err := runConditionWithCrashProtection(condition); err != nil || ok { @@ -309,8 +469,11 @@ func ExponentialBackoff(backoff Backoff, condition ConditionFunc) error { return ErrWaitTimeout } -// ExponentialBackoffWithContext works with a request context and a Backoff. It ensures that the retry wait never -// exceeds the deadline specified by the request context. +// ExponentialBackoffWithContext repeats a condition check with exponential backoff. +// It immediately returns an error if the condition returns an error, the context is cancelled +// or hits the deadline, or if the maximum attempts defined in backoff is exceeded (ErrWaitTimeout). +// If an error is returned by the condition the backoff stops immediately. The condition will +// never be invoked more than backoff.Steps times. func ExponentialBackoffWithContext(ctx context.Context, backoff Backoff, condition ConditionWithContextFunc) error { for backoff.Steps > 0 { select { diff --git a/staging/src/k8s.io/apimachinery/pkg/util/wait/delay.go b/staging/src/k8s.io/apimachinery/pkg/util/wait/delay.go new file mode 100644 index 00000000000..1d3dcaa74ec --- /dev/null +++ b/staging/src/k8s.io/apimachinery/pkg/util/wait/delay.go @@ -0,0 +1,51 @@ +/* +Copyright 2023 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package wait + +import ( + "context" + "sync" + "time" + + "k8s.io/utils/clock" +) + +// DelayFunc returns the next time interval to wait. +type DelayFunc func() time.Duration + +// Timer takes an arbitrary delay function and returns a timer that can handle arbitrary interval changes. +// Use Backoff{...}.Timer() for simple delays and more efficient timers. +func (fn DelayFunc) Timer(c clock.Clock) Timer { + return &variableTimer{fn: fn, new: c.NewTimer} +} + +// Until takes an arbitrary delay function and runs until cancelled or the condition indicates exit. This +// offers all of the functionality of the methods in this package. +func (fn DelayFunc) Until(ctx context.Context, immediate, sliding bool, condition ConditionWithContextFunc) error { + return loopConditionUntilContext(ctx, &variableTimer{fn: fn, new: internalClock.NewTimer}, immediate, sliding, condition) +} + +// Concurrent returns a version of this DelayFunc that is safe for use by multiple goroutines that +// wish to share a single delay timer. +func (fn DelayFunc) Concurrent() DelayFunc { + var lock sync.Mutex + return func() time.Duration { + lock.Lock() + defer lock.Unlock() + return fn() + } +} diff --git a/staging/src/k8s.io/apimachinery/pkg/util/wait/error.go b/staging/src/k8s.io/apimachinery/pkg/util/wait/error.go index 5172f08dff7..dd75801d829 100644 --- a/staging/src/k8s.io/apimachinery/pkg/util/wait/error.go +++ b/staging/src/k8s.io/apimachinery/pkg/util/wait/error.go @@ -16,7 +16,81 @@ limitations under the License. package wait -import "errors" +import ( + "context" + "errors" +) -// ErrWaitTimeout is returned when the condition exited without success. -var ErrWaitTimeout = errors.New("timed out waiting for the condition") +// ErrWaitTimeout is returned when the condition was not satisfied in time. +// +// Deprecated: This type will be made private in favor of Interrupted() +// for checking errors or ErrorInterrupted(err) for returning a wrapped error. +var ErrWaitTimeout = ErrorInterrupted(errors.New("timed out waiting for the condition")) + +// Interrupted returns true if the error indicates a Poll, ExponentialBackoff, or +// Until loop exited for any reason besides the condition returning true or an +// error. A loop is considered interrupted if the calling context is cancelled, +// the context reaches its deadline, or a backoff reaches its maximum allowed +// steps. +// +// Callers should use this method instead of comparing the error value directly to +// ErrWaitTimeout, as methods that cancel a context may not return that error. +// +// Instead of: +// +// err := wait.Poll(...) +// if err == wait.ErrWaitTimeout { +// log.Infof("Wait for operation exceeded") +// } else ... +// +// Use: +// +// err := wait.Poll(...) +// if wait.Interrupted(err) { +// log.Infof("Wait for operation exceeded") +// } else ... +func Interrupted(err error) bool { + switch { + case errors.Is(err, errWaitTimeout), + errors.Is(err, context.Canceled), + errors.Is(err, context.DeadlineExceeded): + return true + default: + return false + } +} + +// errInterrupted +type errInterrupted struct { + cause error +} + +// ErrorInterrupted returns an error that indicates the wait was ended +// early for a given reason. If no cause is provided a generic error +// will be used but callers are encouraged to provide a real cause for +// clarity in debugging. +func ErrorInterrupted(cause error) error { + switch cause.(type) { + case errInterrupted: + // no need to wrap twice since errInterrupted is only needed + // once in a chain + return cause + default: + return errInterrupted{cause} + } +} + +// errWaitTimeout is the private version of the previous ErrWaitTimeout +// and is private to prevent direct comparison. Use ErrorInterrupted(err) +// to get an error that will return true for Interrupted(err). +var errWaitTimeout = errInterrupted{} + +func (e errInterrupted) Unwrap() error { return e.cause } +func (e errInterrupted) Is(target error) bool { return target == errWaitTimeout } +func (e errInterrupted) Error() string { + if e.cause == nil { + // returns the same error message as historical behavior + return "timed out waiting for the condition" + } + return e.cause.Error() +} diff --git a/staging/src/k8s.io/apimachinery/pkg/util/wait/error_test.go b/staging/src/k8s.io/apimachinery/pkg/util/wait/error_test.go new file mode 100644 index 00000000000..0c96f06198b --- /dev/null +++ b/staging/src/k8s.io/apimachinery/pkg/util/wait/error_test.go @@ -0,0 +1,144 @@ +/* +Copyright 2023 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package wait + +import ( + "context" + "errors" + "fmt" + "testing" +) + +type errWrapper struct { + wrapped error +} + +func (w errWrapper) Unwrap() error { + return w.wrapped +} +func (w errWrapper) Error() string { + return fmt.Sprintf("wrapped: %v", w.wrapped) +} + +type errNotWrapper struct { + wrapped error +} + +func (w errNotWrapper) Error() string { + return fmt.Sprintf("wrapped: %v", w.wrapped) +} + +func TestInterrupted(t *testing.T) { + tests := []struct { + name string + err error + want bool + }{ + { + err: ErrWaitTimeout, + want: true, + }, + { + err: context.Canceled, + want: true, + }, { + err: context.DeadlineExceeded, + want: true, + }, + { + err: errWrapper{ErrWaitTimeout}, + want: true, + }, + { + err: errWrapper{context.Canceled}, + want: true, + }, + { + err: errWrapper{context.DeadlineExceeded}, + want: true, + }, + { + err: ErrorInterrupted(nil), + want: true, + }, + { + err: ErrorInterrupted(errors.New("unknown")), + want: true, + }, + { + err: ErrorInterrupted(context.Canceled), + want: true, + }, + { + err: ErrorInterrupted(ErrWaitTimeout), + want: true, + }, + + { + err: nil, + }, + { + err: errors.New("not a cancellation"), + }, + { + err: errNotWrapper{ErrWaitTimeout}, + }, + { + err: errNotWrapper{context.Canceled}, + }, + { + err: errNotWrapper{context.DeadlineExceeded}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := Interrupted(tt.err); got != tt.want { + t.Errorf("Interrupted() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestErrorInterrupted(t *testing.T) { + internalErr := errInterrupted{} + if ErrorInterrupted(internalErr) != internalErr { + t.Fatalf("error should not be wrapped twice") + } + + internalErr = errInterrupted{errInterrupted{}} + if ErrorInterrupted(internalErr) != internalErr { + t.Fatalf("object should be identical") + } + + in := errors.New("test") + actual, expected := ErrorInterrupted(in), (errInterrupted{in}) + if actual != expected { + t.Fatalf("did not wrap error") + } + if !errors.Is(actual, errWaitTimeout) { + t.Fatalf("does not obey errors.Is contract") + } + if actual.Error() != in.Error() { + t.Fatalf("unexpected error output") + } + if !Interrupted(actual) { + t.Fatalf("is not Interrupted") + } + if Interrupted(in) { + t.Fatalf("should not be Interrupted") + } +} diff --git a/staging/src/k8s.io/apimachinery/pkg/util/wait/loop.go b/staging/src/k8s.io/apimachinery/pkg/util/wait/loop.go new file mode 100644 index 00000000000..51864d70f95 --- /dev/null +++ b/staging/src/k8s.io/apimachinery/pkg/util/wait/loop.go @@ -0,0 +1,86 @@ +/* +Copyright 2023 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package wait + +import ( + "context" + "time" + + "k8s.io/apimachinery/pkg/util/runtime" +) + +// loopConditionUntilContext executes the provided condition at intervals defined by +// the provided timer until the provided context is cancelled, the condition returns +// true, or the condition returns an error. If sliding is true, the period is computed +// after condition runs. If it is false then period includes the runtime for condition. +// If immediate is false the first delay happens before any call to condition. The +// returned error is the error returned by the last condition or the context error if +// the context was terminated. +// +// This is the common loop construct for all polling in the wait package. +func loopConditionUntilContext(ctx context.Context, t Timer, immediate, sliding bool, condition ConditionWithContextFunc) error { + defer t.Stop() + + var timeCh <-chan time.Time + doneCh := ctx.Done() + + // if we haven't requested immediate execution, delay once + if !immediate { + timeCh = t.C() + select { + case <-doneCh: + return ctx.Err() + case <-timeCh: + } + } + + for { + // checking ctx.Err() is slightly faster than checking a select + if err := ctx.Err(); err != nil { + return err + } + + if !sliding { + t.Next() + } + if ok, err := func() (bool, error) { + defer runtime.HandleCrash() + return condition(ctx) + }(); err != nil || ok { + return err + } + if sliding { + t.Next() + } + + if timeCh == nil { + timeCh = t.C() + } + + // NOTE: b/c there is no priority selection in golang + // it is possible for this to race, meaning we could + // trigger t.C and doneCh, and t.C select falls through. + // In order to mitigate we re-check doneCh at the beginning + // of every loop to guarantee at-most one extra execution + // of condition. + select { + case <-doneCh: + return ctx.Err() + case <-timeCh: + } + } +} diff --git a/staging/src/k8s.io/apimachinery/pkg/util/wait/loop_test.go b/staging/src/k8s.io/apimachinery/pkg/util/wait/loop_test.go new file mode 100644 index 00000000000..c5849250aa2 --- /dev/null +++ b/staging/src/k8s.io/apimachinery/pkg/util/wait/loop_test.go @@ -0,0 +1,447 @@ +/* +Copyright 2023 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package wait + +import ( + "context" + "errors" + "fmt" + "reflect" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "k8s.io/utils/clock" + testingclock "k8s.io/utils/clock/testing" +) + +func timerWithClock(t Timer, c clock.WithTicker) Timer { + switch t := t.(type) { + case *fixedTimer: + t.new = c.NewTicker + case *variableTimer: + t.new = c.NewTimer + default: + panic("unrecognized timer type, cannot inject clock") + } + return t +} + +func Test_loopConditionWithContextImmediateDelay(t *testing.T) { + fakeClock := testingclock.NewFakeClock(time.Time{}) + backoff := Backoff{Duration: time.Second} + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + expectedError := errors.New("Expected error") + var attempt int + f := ConditionFunc(func() (bool, error) { + attempt++ + return false, expectedError + }) + + doneCh := make(chan struct{}) + go func() { + defer close(doneCh) + if err := loopConditionUntilContext(ctx, timerWithClock(backoff.Timer(), fakeClock), false, true, f.WithContext()); err == nil || err != expectedError { + t.Errorf("unexpected error: %v", err) + } + }() + + for !fakeClock.HasWaiters() { + time.Sleep(time.Microsecond) + } + + fakeClock.Step(time.Second - time.Millisecond) + if attempt != 0 { + t.Fatalf("should still be waiting for condition") + } + fakeClock.Step(2 * time.Millisecond) + + select { + case <-doneCh: + case <-time.After(time.Second): + t.Fatalf("should have exited after a single loop") + } + if attempt != 1 { + t.Fatalf("expected attempt") + } +} + +func Test_loopConditionUntilContext_semantic(t *testing.T) { + defaultCallback := func(_ int) (bool, error) { + return false, nil + } + + conditionErr := errors.New("condition failed") + + tests := []struct { + name string + immediate bool + sliding bool + context func() (context.Context, context.CancelFunc) + callback func(calls int) (bool, error) + cancelContextAfter int + attemptsExpected int + errExpected error + }{ + { + name: "condition successful is only one attempt", + callback: func(attempts int) (bool, error) { + return true, nil + }, + attemptsExpected: 1, + }, + { + name: "delayed condition successful causes return and attempts", + callback: func(attempts int) (bool, error) { + return attempts > 1, nil + }, + attemptsExpected: 2, + }, + { + name: "delayed condition successful causes return and attempts many times", + callback: func(attempts int) (bool, error) { + return attempts >= 100, nil + }, + attemptsExpected: 100, + }, + { + name: "condition returns error even if ok is true", + callback: func(_ int) (bool, error) { + return true, conditionErr + }, + attemptsExpected: 1, + errExpected: conditionErr, + }, + { + name: "condition exits after an error", + callback: func(_ int) (bool, error) { + return false, conditionErr + }, + attemptsExpected: 1, + errExpected: conditionErr, + }, + { + name: "context already canceled no attempts expected", + context: cancelledContext, + callback: defaultCallback, + attemptsExpected: 0, + errExpected: context.Canceled, + }, + { + name: "context cancelled after 5 attempts", + context: defaultContext, + callback: defaultCallback, + cancelContextAfter: 5, + attemptsExpected: 5, + errExpected: context.Canceled, + }, + { + name: "context at deadline no attempts expected", + context: deadlinedContext, + callback: defaultCallback, + attemptsExpected: 0, + errExpected: context.DeadlineExceeded, + }, + } + + for _, test := range tests { + for _, immediate := range []bool{true, false} { + t.Run(fmt.Sprintf("immediate=%t", immediate), func(t *testing.T) { + for _, sliding := range []bool{true, false} { + t.Run(fmt.Sprintf("sliding=%t", sliding), func(t *testing.T) { + t.Run(test.name, func(t *testing.T) { + contextFn := test.context + if contextFn == nil { + contextFn = defaultContext + } + ctx, cancel := contextFn() + defer cancel() + + timer := Backoff{Duration: time.Microsecond}.Timer() + attempts := 0 + err := loopConditionUntilContext(ctx, timer, test.immediate, test.sliding, func(_ context.Context) (bool, error) { + attempts++ + defer func() { + if test.cancelContextAfter > 0 && test.cancelContextAfter == attempts { + cancel() + } + }() + return test.callback(attempts) + }) + + if test.errExpected != err { + t.Errorf("expected error: %v but got: %v", test.errExpected, err) + } + + if test.attemptsExpected != attempts { + t.Errorf("expected attempts count: %d but got: %d", test.attemptsExpected, attempts) + } + }) + }) + } + }) + } + } +} + +type timerWrapper struct { + timer clock.Timer + resets []time.Duration + onReset func(d time.Duration) +} + +func (w *timerWrapper) C() <-chan time.Time { return w.timer.C() } +func (w *timerWrapper) Stop() bool { return w.timer.Stop() } +func (w *timerWrapper) Reset(d time.Duration) bool { + w.resets = append(w.resets, d) + b := w.timer.Reset(d) + if w.onReset != nil { + w.onReset(d) + } + return b +} + +func Test_loopConditionUntilContext_timings(t *testing.T) { + // Verify that timings returned by the delay func are passed to the timer, and that + // the timer advancing is enough to drive the state machine. Not a deep verification + // of the behavior of the loop, but tests that we drive the scenario to completion. + tests := []struct { + name string + delayFn DelayFunc + immediate bool + sliding bool + context func() (context.Context, context.CancelFunc) + callback func(calls int, lastInterval time.Duration) (bool, error) + cancelContextAfter int + attemptsExpected int + errExpected error + expectedIntervals func(t *testing.T, delays []time.Duration, delaysRequested []time.Duration) + }{ + { + name: "condition success", + delayFn: Backoff{Duration: time.Second, Steps: 2, Factor: 2.0, Jitter: 0}.DelayFunc(), + callback: func(attempts int, _ time.Duration) (bool, error) { + return true, nil + }, + attemptsExpected: 1, + expectedIntervals: func(t *testing.T, delays []time.Duration, delaysRequested []time.Duration) { + if reflect.DeepEqual(delays, []time.Duration{time.Second, 2 * time.Second}) { + return + } + if reflect.DeepEqual(delaysRequested, []time.Duration{time.Second}) { + return + } + }, + }, + { + name: "condition success and immediate", + immediate: true, + delayFn: Backoff{Duration: time.Second, Steps: 2, Factor: 2.0, Jitter: 0}.DelayFunc(), + callback: func(attempts int, _ time.Duration) (bool, error) { + return true, nil + }, + attemptsExpected: 1, + expectedIntervals: func(t *testing.T, delays []time.Duration, delaysRequested []time.Duration) { + if reflect.DeepEqual(delays, []time.Duration{time.Second}) { + return + } + if reflect.DeepEqual(delaysRequested, []time.Duration{}) { + return + } + }, + }, + { + name: "condition success and sliding", + sliding: true, + delayFn: Backoff{Duration: time.Second, Steps: 2, Factor: 2.0, Jitter: 0}.DelayFunc(), + callback: func(attempts int, _ time.Duration) (bool, error) { + return true, nil + }, + attemptsExpected: 1, + expectedIntervals: func(t *testing.T, delays []time.Duration, delaysRequested []time.Duration) { + if reflect.DeepEqual(delays, []time.Duration{time.Second}) { + return + } + if !reflect.DeepEqual(delays, delaysRequested) { + t.Fatalf("sliding non-immediate should have equal delays: %v", cmp.Diff(delays, delaysRequested)) + } + }, + }, + } + + for _, test := range tests { + t.Run(fmt.Sprintf("%s/sliding=%t/immediate=%t", test.name, test.sliding, test.immediate), func(t *testing.T) { + contextFn := test.context + if contextFn == nil { + contextFn = defaultContext + } + ctx, cancel := contextFn() + defer cancel() + + fakeClock := &testingclock.FakeClock{} + var fakeTimers []*timerWrapper + timerFn := func(d time.Duration) clock.Timer { + t := fakeClock.NewTimer(d) + fakeClock.Step(d + 1) + w := &timerWrapper{timer: t, resets: []time.Duration{d}, onReset: func(d time.Duration) { + fakeClock.Step(d + 1) + }} + fakeTimers = append(fakeTimers, w) + return w + } + + delayFn := test.delayFn + if delayFn == nil { + delayFn = Backoff{Duration: time.Microsecond}.DelayFunc() + } + var delays []time.Duration + wrappedDelayFn := func() time.Duration { + d := delayFn() + delays = append(delays, d) + return d + } + timer := &variableTimer{fn: wrappedDelayFn, new: timerFn} + + attempts := 0 + err := loopConditionUntilContext(ctx, timer, test.immediate, test.sliding, func(_ context.Context) (bool, error) { + attempts++ + defer func() { + if test.cancelContextAfter > 0 && test.cancelContextAfter == attempts { + cancel() + } + }() + lastInterval := time.Duration(-1) + if len(delays) > 0 { + lastInterval = delays[len(delays)-1] + } + return test.callback(attempts, lastInterval) + }) + + if test.errExpected != err { + t.Errorf("expected error: %v but got: %v", test.errExpected, err) + } + + if test.attemptsExpected != attempts { + t.Errorf("expected attempts count: %d but got: %d", test.attemptsExpected, attempts) + } + switch len(fakeTimers) { + case 0: + test.expectedIntervals(t, delays, nil) + case 1: + test.expectedIntervals(t, delays, fakeTimers[0].resets) + default: + t.Fatalf("expected zero or one timers: %#v", fakeTimers) + } + }) + } +} + +// Test_loopConditionUntilContext_timings runs actual timing loops and calculates the delta. This +// test depends on high precision wakeups which depends on low CPU contention so it is not a +// candidate to run during normal unit test execution (nor is it a benchmark or example). Instead, +// it can be run manually if there is a scenario where we suspect the timings are off and other +// tests haven't caught it. A final sanity test that would have to be run serially in isolation. +func Test_loopConditionUntilContext_Elapsed(t *testing.T) { + const maxAttempts = 10 + // TODO: this may be too aggressive, but the overhead should be minor + const estimatedLoopOverhead = time.Millisecond + // estimate how long this delay can be + intervalMax := func(backoff Backoff) time.Duration { + d := backoff.Duration + if backoff.Jitter > 0 { + d += time.Duration(backoff.Jitter * float64(d)) + } + return d + } + // estimate how short this delay can be + intervalMin := func(backoff Backoff) time.Duration { + d := backoff.Duration + return d + } + + // Because timing is dependent other factors in test environments, such as + // whether the OS or go runtime scheduler wake the timers, excess duration + // is logged by default and can be converted to a fatal error for testing. + // fail := t.Fatalf + fail := t.Logf + + for _, test := range []struct { + name string + backoff Backoff + t reflect.Type + }{ + {name: "variable timer with jitter", backoff: Backoff{Duration: time.Millisecond, Jitter: 1.0}, t: reflect.TypeOf(&variableTimer{})}, + {name: "fixed timer", backoff: Backoff{Duration: time.Millisecond}, t: reflect.TypeOf(&fixedTimer{})}, + {name: "no-op timer", backoff: Backoff{}, t: reflect.TypeOf(noopTimer{})}, + } { + t.Run(test.name, func(t *testing.T) { + var attempts int + start := time.Now() + timer := test.backoff.Timer() + if test.t != reflect.ValueOf(timer).Type() { + t.Fatalf("unexpected timer type %T: expected %v", timer, test.t) + } + if err := loopConditionUntilContext(context.Background(), timer, false, false, func(_ context.Context) (bool, error) { + attempts++ + if attempts > maxAttempts { + t.Fatalf("should not reach %d attempts", maxAttempts+1) + } + return attempts >= maxAttempts, nil + }); err != nil { + t.Fatal(err) + } + duration := time.Since(start) + if min := maxAttempts * intervalMin(test.backoff); duration < min { + fail("elapsed duration %v < expected min duration %v", duration, min) + } + if max := maxAttempts * (intervalMax(test.backoff) + estimatedLoopOverhead); duration > max { + fail("elapsed duration %v > expected max duration %v", duration, max) + } + }) + } +} + +func Benchmark_loopConditionUntilContext_ZeroDuration(b *testing.B) { + ctx := context.Background() + b.ResetTimer() + for i := 0; i < b.N; i++ { + attempts := 0 + if err := loopConditionUntilContext(ctx, Backoff{Duration: 0}.Timer(), true, false, func(_ context.Context) (bool, error) { + attempts++ + return attempts >= 100, nil + }); err != nil { + b.Fatalf("unexpected err: %v", err) + } + } +} + +func Benchmark_loopConditionUntilContext_ShortDuration(b *testing.B) { + ctx := context.Background() + b.ResetTimer() + for i := 0; i < b.N; i++ { + attempts := 0 + if err := loopConditionUntilContext(ctx, Backoff{Duration: time.Microsecond}.Timer(), true, false, func(_ context.Context) (bool, error) { + attempts++ + return attempts >= 100, nil + }); err != nil { + b.Fatalf("unexpected err: %v", err) + } + } +} diff --git a/staging/src/k8s.io/apimachinery/pkg/util/wait/poll.go b/staging/src/k8s.io/apimachinery/pkg/util/wait/poll.go index 564e9b9d290..32e8688ca0f 100644 --- a/staging/src/k8s.io/apimachinery/pkg/util/wait/poll.go +++ b/staging/src/k8s.io/apimachinery/pkg/util/wait/poll.go @@ -21,6 +21,33 @@ import ( "time" ) +// PollUntilContextCancel tries a condition func until it returns true, an error, or the context +// is cancelled or hits a deadline. condition will be invoked after the first interval if the +// context is not cancelled first. The returned error will be from ctx.Err(), the condition's +// err return value, or nil. If invoking condition takes longer than interval the next condition +// will be invoked immediately. When using very short intervals, condition may be invoked multiple +// times before a context cancellation is detected. If immediate is true, condition will be +// invoked before waiting and guarantees that condition is invoked at least once, regardless of +// whether the context has been cancelled. +func PollUntilContextCancel(ctx context.Context, interval time.Duration, immediate bool, condition ConditionWithContextFunc) error { + return loopConditionUntilContext(ctx, Backoff{Duration: interval}.Timer(), immediate, false, condition) +} + +// PollUntilContextTimeout will terminate polling after timeout duration by setting a context +// timeout. This is provided as a convenience function for callers not currently executing under +// a deadline and is equivalent to: +// +// deadlineCtx, deadlineCancel := context.WithTimeout(ctx, timeout) +// err := PollUntilContextCancel(ctx, interval, immediate, condition) +// +// The deadline context will be cancelled if the Poll succeeds before the timeout, simplifying +// inline usage. All other behavior is identical to PollWithContextTimeout. +func PollUntilContextTimeout(ctx context.Context, interval, timeout time.Duration, immediate bool, condition ConditionWithContextFunc) error { + deadlineCtx, deadlineCancel := context.WithTimeout(ctx, timeout) + defer deadlineCancel() + return loopConditionUntilContext(deadlineCtx, Backoff{Duration: interval}.Timer(), immediate, false, condition) +} + // Poll tries a condition func until it returns true, an error, or the timeout // is reached. // @@ -31,6 +58,10 @@ import ( // window is too short. // // If you want to Poll something forever, see PollInfinite. +// +// Deprecated: This method does not return errors from context, use PollWithContextTimeout. +// Note that the new method will no longer return ErrWaitTimeout and instead return errors +// defined by the context package. Will be removed in a future release. func Poll(interval, timeout time.Duration, condition ConditionFunc) error { return PollWithContext(context.Background(), interval, timeout, condition.WithContext()) } @@ -46,6 +77,10 @@ func Poll(interval, timeout time.Duration, condition ConditionFunc) error { // window is too short. // // If you want to Poll something forever, see PollInfinite. +// +// Deprecated: This method does not return errors from context, use PollWithContextTimeout. +// Note that the new method will no longer return ErrWaitTimeout and instead return errors +// defined by the context package. Will be removed in a future release. func PollWithContext(ctx context.Context, interval, timeout time.Duration, condition ConditionWithContextFunc) error { return poll(ctx, false, poller(interval, timeout), condition) } @@ -55,6 +90,10 @@ func PollWithContext(ctx context.Context, interval, timeout time.Duration, condi // // PollUntil always waits interval before the first run of 'condition'. // 'condition' will always be invoked at least once. +// +// Deprecated: This method does not return errors from context, use PollWithContextCancel. +// Note that the new method will no longer return ErrWaitTimeout and instead return errors +// defined by the context package. Will be removed in a future release. func PollUntil(interval time.Duration, condition ConditionFunc, stopCh <-chan struct{}) error { return PollUntilWithContext(ContextForChannel(stopCh), interval, condition.WithContext()) } @@ -64,6 +103,10 @@ func PollUntil(interval time.Duration, condition ConditionFunc, stopCh <-chan st // // PollUntilWithContext always waits interval before the first run of 'condition'. // 'condition' will always be invoked at least once. +// +// Deprecated: This method does not return errors from context, use PollWithContextCancel. +// Note that the new method will no longer return ErrWaitTimeout and instead return errors +// defined by the context package. Will be removed in a future release. func PollUntilWithContext(ctx context.Context, interval time.Duration, condition ConditionWithContextFunc) error { return poll(ctx, false, poller(interval, 0), condition) } @@ -74,6 +117,10 @@ func PollUntilWithContext(ctx context.Context, interval time.Duration, condition // // Some intervals may be missed if the condition takes too long or the time // window is too short. +// +// Deprecated: This method does not return errors from context, use PollWithContextCancel. +// Note that the new method will no longer return ErrWaitTimeout and instead return errors +// defined by the context package. Will be removed in a future release. func PollInfinite(interval time.Duration, condition ConditionFunc) error { return PollInfiniteWithContext(context.Background(), interval, condition.WithContext()) } @@ -84,6 +131,10 @@ func PollInfinite(interval time.Duration, condition ConditionFunc) error { // // Some intervals may be missed if the condition takes too long or the time // window is too short. +// +// Deprecated: This method does not return errors from context, use PollWithContextCancel. +// Note that the new method will no longer return ErrWaitTimeout and instead return errors +// defined by the context package. Will be removed in a future release. func PollInfiniteWithContext(ctx context.Context, interval time.Duration, condition ConditionWithContextFunc) error { return poll(ctx, false, poller(interval, 0), condition) } @@ -98,6 +149,10 @@ func PollInfiniteWithContext(ctx context.Context, interval time.Duration, condit // window is too short. // // If you want to immediately Poll something forever, see PollImmediateInfinite. +// +// Deprecated: This method does not return errors from context, use PollWithContextTimeout. +// Note that the new method will no longer return ErrWaitTimeout and instead return errors +// defined by the context package. Will be removed in a future release. func PollImmediate(interval, timeout time.Duration, condition ConditionFunc) error { return PollImmediateWithContext(context.Background(), interval, timeout, condition.WithContext()) } @@ -112,6 +167,10 @@ func PollImmediate(interval, timeout time.Duration, condition ConditionFunc) err // window is too short. // // If you want to immediately Poll something forever, see PollImmediateInfinite. +// +// Deprecated: This method does not return errors from context, use PollWithContextTimeout. +// Note that the new method will no longer return ErrWaitTimeout and instead return errors +// defined by the context package. Will be removed in a future release. func PollImmediateWithContext(ctx context.Context, interval, timeout time.Duration, condition ConditionWithContextFunc) error { return poll(ctx, true, poller(interval, timeout), condition) } @@ -120,6 +179,10 @@ func PollImmediateWithContext(ctx context.Context, interval, timeout time.Durati // // PollImmediateUntil runs the 'condition' before waiting for the interval. // 'condition' will always be invoked at least once. +// +// Deprecated: This method does not return errors from context, use PollWithContextCancel. +// Note that the new method will no longer return ErrWaitTimeout and instead return errors +// defined by the context package. Will be removed in a future release. func PollImmediateUntil(interval time.Duration, condition ConditionFunc, stopCh <-chan struct{}) error { return PollImmediateUntilWithContext(ContextForChannel(stopCh), interval, condition.WithContext()) } @@ -129,6 +192,10 @@ func PollImmediateUntil(interval time.Duration, condition ConditionFunc, stopCh // // PollImmediateUntilWithContext runs the 'condition' before waiting for the interval. // 'condition' will always be invoked at least once. +// +// Deprecated: This method does not return errors from context, use PollWithContextCancel. +// Note that the new method will no longer return ErrWaitTimeout and instead return errors +// defined by the context package. Will be removed in a future release. func PollImmediateUntilWithContext(ctx context.Context, interval time.Duration, condition ConditionWithContextFunc) error { return poll(ctx, true, poller(interval, 0), condition) } @@ -139,6 +206,10 @@ func PollImmediateUntilWithContext(ctx context.Context, interval time.Duration, // // Some intervals may be missed if the condition takes too long or the time // window is too short. +// +// Deprecated: This method does not return errors from context, use PollWithContextCancel. +// Note that the new method will no longer return ErrWaitTimeout and instead return errors +// defined by the context package. Will be removed in a future release. func PollImmediateInfinite(interval time.Duration, condition ConditionFunc) error { return PollImmediateInfiniteWithContext(context.Background(), interval, condition.WithContext()) } @@ -150,6 +221,10 @@ func PollImmediateInfinite(interval time.Duration, condition ConditionFunc) erro // // Some intervals may be missed if the condition takes too long or the time // window is too short. +// +// Deprecated: This method does not return errors from context, use PollWithContextCancel. +// Note that the new method will no longer return ErrWaitTimeout and instead return errors +// defined by the context package. Will be removed in a future release. func PollImmediateInfiniteWithContext(ctx context.Context, interval time.Duration, condition ConditionWithContextFunc) error { return poll(ctx, true, poller(interval, 0), condition) } @@ -163,6 +238,8 @@ func PollImmediateInfiniteWithContext(ctx context.Context, interval time.Duratio // wait: user specified WaitFunc function that controls at what interval the condition // function should be invoked periodically and whether it is bound by a timeout. // condition: user specified ConditionWithContextFunc function. +// +// Deprecated: will be removed in favor of loopConditionUntilContext. func poll(ctx context.Context, immediate bool, wait waitWithContextFunc, condition ConditionWithContextFunc) error { if immediate { done, err := runConditionWithCrashProtectionWithContext(ctx, condition) @@ -176,7 +253,8 @@ func poll(ctx context.Context, immediate bool, wait waitWithContextFunc, conditi select { case <-ctx.Done(): - // returning ctx.Err() will break backward compatibility + // returning ctx.Err() will break backward compatibility, use new PollUntilContext* + // methods instead return ErrWaitTimeout default: return waitForWithContext(ctx, wait, condition) @@ -193,6 +271,8 @@ func poll(ctx context.Context, immediate bool, wait waitWithContextFunc, conditi // // Output ticks are not buffered. If the channel is not ready to receive an // item, the tick is skipped. +// +// Deprecated: Will be removed in a future release. func poller(interval, timeout time.Duration) waitWithContextFunc { return waitWithContextFunc(func(ctx context.Context) <-chan struct{} { ch := make(chan struct{}) diff --git a/staging/src/k8s.io/apimachinery/pkg/util/wait/timer.go b/staging/src/k8s.io/apimachinery/pkg/util/wait/timer.go new file mode 100644 index 00000000000..3efba321325 --- /dev/null +++ b/staging/src/k8s.io/apimachinery/pkg/util/wait/timer.go @@ -0,0 +1,121 @@ +/* +Copyright 2023 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package wait + +import ( + "time" + + "k8s.io/utils/clock" +) + +// Timer abstracts how wait functions interact with time runtime efficiently. Test +// code may implement this interface directly but package consumers are encouraged +// to use the Backoff type as the primary mechanism for acquiring a Timer. The +// interface is a simplification of clock.Timer to prevent misuse. Timers are not +// expected to be safe for calls from multiple goroutines. +type Timer interface { + // C returns a channel that will receive a struct{} each time the timer fires. + // The channel should not be waited on after Stop() is invoked. It is allowed + // to cache the returned value of C() for the lifetime of the Timer. + C() <-chan time.Time + // Next is invoked by wait functions to signal timers that the next interval + // should begin. You may only use Next() if you have drained the channel C(). + // You should not call Next() after Stop() is invoked. + Next() + // Stop releases the timer. It is safe to invoke if no other methods have been + // called. + Stop() +} + +type noopTimer struct { + closedCh <-chan time.Time +} + +// newNoopTimer creates a timer with a unique channel to avoid contention +// for the channel's lock across multiple unrelated timers. +func newNoopTimer() noopTimer { + ch := make(chan time.Time) + close(ch) + return noopTimer{closedCh: ch} +} + +func (t noopTimer) C() <-chan time.Time { + return t.closedCh +} +func (noopTimer) Next() {} +func (noopTimer) Stop() {} + +type variableTimer struct { + fn DelayFunc + t clock.Timer + new func(time.Duration) clock.Timer +} + +func (t *variableTimer) C() <-chan time.Time { + if t.t == nil { + d := t.fn() + t.t = t.new(d) + } + return t.t.C() +} +func (t *variableTimer) Next() { + if t.t == nil { + return + } + d := t.fn() + t.t.Reset(d) +} +func (t *variableTimer) Stop() { + if t.t == nil { + return + } + t.t.Stop() + t.t = nil +} + +type fixedTimer struct { + interval time.Duration + t clock.Ticker + new func(time.Duration) clock.Ticker +} + +func (t *fixedTimer) C() <-chan time.Time { + if t.t == nil { + t.t = t.new(t.interval) + } + return t.t.C() +} +func (t *fixedTimer) Next() { + // no-op for fixed timers +} +func (t *fixedTimer) Stop() { + if t.t == nil { + return + } + t.t.Stop() + t.t = nil +} + +var ( + // RealTimer can be passed to methods that need a clock.Timer. + RealTimer = clock.RealClock{}.NewTimer +) + +var ( + // internalClock is used for test injection of clocks + internalClock = clock.RealClock{} +) diff --git a/staging/src/k8s.io/apimachinery/pkg/util/wait/wait.go b/staging/src/k8s.io/apimachinery/pkg/util/wait/wait.go index c6e516dfc82..6805e8cf948 100644 --- a/staging/src/k8s.io/apimachinery/pkg/util/wait/wait.go +++ b/staging/src/k8s.io/apimachinery/pkg/util/wait/wait.go @@ -137,13 +137,18 @@ func (c channelContext) Err() error { func (c channelContext) Deadline() (time.Time, bool) { return time.Time{}, false } func (c channelContext) Value(key any) any { return nil } -// runConditionWithCrashProtection runs a ConditionFunc with crash protection +// runConditionWithCrashProtection runs a ConditionFunc with crash protection. +// +// Deprecated: Will be removed when the legacy polling methods are removed. func runConditionWithCrashProtection(condition ConditionFunc) (bool, error) { - return runConditionWithCrashProtectionWithContext(context.TODO(), condition.WithContext()) + defer runtime.HandleCrash() + return condition() } -// runConditionWithCrashProtectionWithContext runs a -// ConditionWithContextFunc with crash protection. +// runConditionWithCrashProtectionWithContext runs a ConditionWithContextFunc +// with crash protection. +// +// Deprecated: Will be removed when the legacy polling methods are removed. func runConditionWithCrashProtectionWithContext(ctx context.Context, condition ConditionWithContextFunc) (bool, error) { defer runtime.HandleCrash() return condition(ctx) @@ -151,6 +156,9 @@ func runConditionWithCrashProtectionWithContext(ctx context.Context, condition C // waitFunc creates a channel that receives an item every time a test // should be executed and is closed when the last test should be invoked. +// +// Deprecated: Will be removed in a future release in favor of +// loopConditionUntilContext. type waitFunc func(done <-chan struct{}) <-chan struct{} // WithContext converts the WaitFunc to an equivalent WaitWithContextFunc @@ -166,7 +174,8 @@ func (w waitFunc) WithContext() waitWithContextFunc { // When the specified context gets cancelled or expires the function // stops sending item and returns immediately. // -// Deprecated: Will be removed when the legacy Poll methods are removed. +// Deprecated: Will be removed in a future release in favor of +// loopConditionUntilContext. type waitWithContextFunc func(ctx context.Context) <-chan struct{} // waitForWithContext continually checks 'fn' as driven by 'wait'. @@ -186,7 +195,8 @@ type waitWithContextFunc func(ctx context.Context) <-chan struct{} // "uniform pseudo-random", the `fn` might still run one or multiple times, // though eventually `waitForWithContext` will return. // -// Deprecated: Will be removed when the legacy Poll methods are removed. +// Deprecated: Will be removed in a future release in favor of +// loopConditionUntilContext. func waitForWithContext(ctx context.Context, wait waitWithContextFunc, fn ConditionWithContextFunc) error { waitCtx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -205,7 +215,8 @@ func waitForWithContext(ctx context.Context, wait waitWithContextFunc, fn Condit return ErrWaitTimeout } case <-ctx.Done(): - // returning ctx.Err() will break backward compatibility + // returning ctx.Err() will break backward compatibility, use new PollUntilContext* + // methods instead return ErrWaitTimeout } } diff --git a/staging/src/k8s.io/apimachinery/pkg/util/wait/wait_test.go b/staging/src/k8s.io/apimachinery/pkg/util/wait/wait_test.go index 82ff8866ff8..c8dd0bf58b3 100644 --- a/staging/src/k8s.io/apimachinery/pkg/util/wait/wait_test.go +++ b/staging/src/k8s.io/apimachinery/pkg/util/wait/wait_test.go @@ -114,7 +114,12 @@ func TestNonSlidingUntilWithContext(t *testing.T) { func TestUntilReturnsImmediately(t *testing.T) { now := time.Now() ch := make(chan struct{}) + var attempts int Until(func() { + attempts++ + if attempts > 1 { + t.Fatalf("invoked after close of channel") + } close(ch) }, 30*time.Second, ch) if now.Add(25 * time.Second).Before(time.Now()) { @@ -233,15 +238,24 @@ func TestJitterUntilNegativeFactor(t *testing.T) { if now.Add(3 * time.Second).Before(time.Now()) { t.Errorf("JitterUntil did not returned after predefined period with negative jitter factor when the stop chan was closed inside the func") } - } func TestExponentialBackoff(t *testing.T) { + // exits immediately + i := 0 + err := ExponentialBackoff(Backoff{Factor: 1.0}, func() (bool, error) { + i++ + return false, nil + }) + if err != ErrWaitTimeout || i != 0 { + t.Errorf("unexpected error: %v", err) + } + opts := Backoff{Factor: 1.0, Steps: 3} // waits up to steps - i := 0 - err := ExponentialBackoff(opts, func() (bool, error) { + i = 0 + err = ExponentialBackoff(opts, func() (bool, error) { i++ return false, nil }) @@ -339,7 +353,7 @@ func (fp *fakePoller) GetwaitFunc() waitFunc { func TestPoll(t *testing.T) { invocations := 0 - f := ConditionFunc(func() (bool, error) { + f := ConditionWithContextFunc(func(ctx context.Context) (bool, error) { invocations++ return true, nil }) @@ -347,7 +361,7 @@ func TestPoll(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - if err := poll(ctx, false, fp.GetwaitFunc().WithContext(), f.WithContext()); err != nil { + if err := poll(ctx, false, fp.GetwaitFunc().WithContext(), f); err != nil { t.Fatalf("unexpected error %v", err) } fp.wg.Wait() @@ -540,7 +554,7 @@ func Test_waitFor(t *testing.T) { } } -// Test_waitForWithEarlyClosing_waitFunc tests waitFor when the waitFunc closes its channel. The waitFor should +// Test_waitForWithEarlyClosing_waitFunc tests WaitFor when the waitFunc closes its channel. The WaitFor should // always return ErrWaitTimeout. func Test_waitForWithEarlyClosing_waitFunc(t *testing.T) { stopCh := make(chan struct{}) @@ -597,12 +611,12 @@ func Test_waitForWithClosedChannel(t *testing.T) { func Test_waitForWithContextCancelsContext(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - waitFunc := poller(time.Millisecond, ForeverTestTimeout) + waitFn := poller(time.Millisecond, ForeverTestTimeout) var ctxPassedToWait context.Context waitForWithContext(ctx, func(ctx context.Context) <-chan struct{} { ctxPassedToWait = ctx - return waitFunc(ctx) + return waitFn(ctx) }, func(ctx context.Context) (bool, error) { time.Sleep(10 * time.Millisecond) return true, nil @@ -633,14 +647,14 @@ func TestPollUntil(t *testing.T) { close(stopCh) go func() { - // release the condition func if needed - for { - <-called + // release the condition func if needed + for range called { } }() // make sure we finished the poll <-pollDone + close(called) } func TestBackoff_Step(t *testing.T) { @@ -648,6 +662,8 @@ func TestBackoff_Step(t *testing.T) { initial *Backoff want []time.Duration }{ + {initial: nil, want: []time.Duration{0, 0, 0, 0}}, + {initial: &Backoff{Duration: time.Second, Steps: -1}, want: []time.Duration{time.Second, time.Second, time.Second}}, {initial: &Backoff{Duration: time.Second, Steps: 0}, want: []time.Duration{time.Second, time.Second, time.Second}}, {initial: &Backoff{Duration: time.Second, Steps: 1}, want: []time.Duration{time.Second, time.Second, time.Second}}, {initial: &Backoff{Duration: time.Second, Factor: 1.0, Steps: 1}, want: []time.Duration{time.Second, time.Second, time.Second}}, @@ -658,13 +674,19 @@ func TestBackoff_Step(t *testing.T) { } for seed := int64(0); seed < 5; seed++ { for _, tt := range tests { - initial := *tt.initial + var initial *Backoff + if tt.initial != nil { + copied := *tt.initial + initial = &copied + } else { + initial = nil + } t.Run(fmt.Sprintf("%#v seed=%d", initial, seed), func(t *testing.T) { rand.Seed(seed) for i := 0; i < len(tt.want); i++ { got := initial.Step() t.Logf("[%d]=%s", i, got) - if initial.Jitter > 0 { + if initial != nil && initial.Jitter > 0 { if got == tt.want[i] { // this is statistically unlikely to happen by chance t.Errorf("Backoff.Step(%d) = %v, no jitter", i, got) @@ -779,11 +801,105 @@ func TestExponentialBackoffManagerWithRealClock(t *testing.T) { } } -func TestExponentialBackoffWithContext(t *testing.T) { - defaultCtx := func() context.Context { - return context.Background() +func TestBackoffDelayWithResetExponential(t *testing.T) { + fc := testingclock.NewFakeClock(time.Now()) + backoff := Backoff{Duration: 1, Cap: 10, Factor: 2.0, Jitter: 0.0, Steps: 10}.DelayWithReset(fc, 10) + durations := []time.Duration{1, 2, 4, 8, 10, 10, 10} + for i := 0; i < len(durations); i++ { + generatedBackoff := backoff() + if generatedBackoff != durations[i] { + t.Errorf("unexpected %d-th backoff: %d, expecting %d", i, generatedBackoff, durations[i]) + } } + fc.Step(11) + resetDuration := backoff() + if resetDuration != 1 { + t.Errorf("after reset, backoff should be 1, but got %d", resetDuration) + } +} + +func TestBackoffDelayWithResetEmpty(t *testing.T) { + fc := testingclock.NewFakeClock(time.Now()) + backoff := Backoff{Duration: 1, Cap: 10, Factor: 2.0, Jitter: 0.0, Steps: 10}.DelayWithReset(fc, 0) + // we reset to initial duration because the resetInterval is 0, immediate + durations := []time.Duration{1, 1, 1, 1, 1, 1, 1} + for i := 0; i < len(durations); i++ { + generatedBackoff := backoff() + if generatedBackoff != durations[i] { + t.Errorf("unexpected %d-th backoff: %d, expecting %d", i, generatedBackoff, durations[i]) + } + } + + fc.Step(11) + resetDuration := backoff() + if resetDuration != 1 { + t.Errorf("after reset, backoff should be 1, but got %d", resetDuration) + } +} + +func TestBackoffDelayWithResetJitter(t *testing.T) { + // positive jitter + backoff := Backoff{Duration: 1, Jitter: 1}.DelayWithReset(testingclock.NewFakeClock(time.Now()), 0) + for i := 0; i < 5; i++ { + value := backoff() + if value < 1 || value > 2 { + t.Errorf("backoff out of range: %d", value) + } + } + + // negative jitter, shall be a fixed backoff + backoff = Backoff{Duration: 1, Jitter: -1}.DelayWithReset(testingclock.NewFakeClock(time.Now()), 0) + value := backoff() + if value != 1 { + t.Errorf("backoff should be 1, but got %d", value) + } +} + +func TestBackoffDelayWithResetWithRealClockJitter(t *testing.T) { + backoff := Backoff{Duration: 1 * time.Millisecond, Jitter: 0}.DelayWithReset(&clock.RealClock{}, 0) + for i := 0; i < 5; i++ { + start := time.Now() + <-RealTimer(backoff()).C() + passed := time.Since(start) + if passed < 1*time.Millisecond { + t.Errorf("backoff should be at least 1ms, but got %s", passed.String()) + } + } +} + +func TestBackoffDelayWithResetWithRealClockExponential(t *testing.T) { + // backoff at least 1ms, 2ms, 4ms, 8ms, 10ms, 10ms, 10ms + durationFactors := []time.Duration{1, 2, 4, 8, 10, 10, 10} + backoff := Backoff{Duration: 1 * time.Millisecond, Cap: 10 * time.Millisecond, Factor: 2.0, Jitter: 0.0, Steps: 10}.DelayWithReset(&clock.RealClock{}, 1*time.Hour) + + for i := range durationFactors { + start := time.Now() + <-RealTimer(backoff()).C() + passed := time.Since(start) + if passed < durationFactors[i]*time.Millisecond { + t.Errorf("backoff should be at least %d ms, but got %s", durationFactors[i], passed.String()) + } + } +} + +func defaultContext() (context.Context, context.CancelFunc) { + return context.WithCancel(context.Background()) +} +func cancelledContext() (context.Context, context.CancelFunc) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + return ctx, cancel +} +func deadlinedContext() (context.Context, context.CancelFunc) { + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) + for ctx.Err() != context.DeadlineExceeded { + time.Sleep(501 * time.Microsecond) + } + return ctx, cancel +} + +func TestExponentialBackoffWithContext(t *testing.T) { defaultCallback := func(_ int) (bool, error) { return false, nil } @@ -791,17 +907,18 @@ func TestExponentialBackoffWithContext(t *testing.T) { conditionErr := errors.New("condition failed") tests := []struct { - name string - steps int - ctxGetter func() context.Context - callback func(calls int) (bool, error) - attemptsExpected int - errExpected error + name string + steps int + zeroDuration bool + context func() (context.Context, context.CancelFunc) + callback func(calls int) (bool, error) + cancelContextAfter int + attemptsExpected int + errExpected error }{ { name: "no attempts expected with zero backoff steps", steps: 0, - ctxGetter: defaultCtx, callback: defaultCallback, attemptsExpected: 0, errExpected: ErrWaitTimeout, @@ -809,15 +926,13 @@ func TestExponentialBackoffWithContext(t *testing.T) { { name: "condition returns false with single backoff step", steps: 1, - ctxGetter: defaultCtx, callback: defaultCallback, attemptsExpected: 1, errExpected: ErrWaitTimeout, }, { - name: "condition returns true with single backoff step", - steps: 1, - ctxGetter: defaultCtx, + name: "condition returns true with single backoff step", + steps: 1, callback: func(_ int) (bool, error) { return true, nil }, @@ -827,15 +942,13 @@ func TestExponentialBackoffWithContext(t *testing.T) { { name: "condition always returns false with multiple backoff steps", steps: 5, - ctxGetter: defaultCtx, callback: defaultCallback, attemptsExpected: 5, errExpected: ErrWaitTimeout, }, { - name: "condition returns true after certain attempts with multiple backoff steps", - steps: 5, - ctxGetter: defaultCtx, + name: "condition returns true after certain attempts with multiple backoff steps", + steps: 5, callback: func(attempts int) (bool, error) { if attempts == 3 { return true, nil @@ -846,9 +959,8 @@ func TestExponentialBackoffWithContext(t *testing.T) { errExpected: nil, }, { - name: "condition returns error no further attempts expected", - steps: 5, - ctxGetter: defaultCtx, + name: "condition returns error no further attempts expected", + steps: 5, callback: func(_ int) (bool, error) { return true, conditionErr }, @@ -856,30 +968,118 @@ func TestExponentialBackoffWithContext(t *testing.T) { errExpected: conditionErr, }, { - name: "context already canceled no attempts expected", - steps: 5, - ctxGetter: func() context.Context { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - return ctx - }, + name: "context already canceled no attempts expected", + steps: 5, + context: cancelledContext, callback: defaultCallback, attemptsExpected: 0, errExpected: context.Canceled, }, + { + name: "context at deadline no attempts expected", + steps: 5, + context: deadlinedContext, + callback: defaultCallback, + attemptsExpected: 0, + errExpected: context.DeadlineExceeded, + }, + { + name: "no attempts expected with zero backoff steps", + steps: 0, + callback: defaultCallback, + attemptsExpected: 0, + errExpected: ErrWaitTimeout, + }, + { + name: "condition returns false with single backoff step", + steps: 1, + callback: defaultCallback, + attemptsExpected: 1, + errExpected: ErrWaitTimeout, + }, + { + name: "condition returns true with single backoff step", + steps: 1, + callback: func(_ int) (bool, error) { + return true, nil + }, + attemptsExpected: 1, + errExpected: nil, + }, + { + name: "condition always returns false with multiple backoff steps but is cancelled at step 4", + steps: 5, + callback: defaultCallback, + attemptsExpected: 4, + cancelContextAfter: 4, + errExpected: context.Canceled, + }, + { + name: "condition returns true after certain attempts with multiple backoff steps and zero duration", + steps: 5, + zeroDuration: true, + callback: func(attempts int) (bool, error) { + if attempts == 3 { + return true, nil + } + return false, nil + }, + attemptsExpected: 3, + errExpected: nil, + }, + { + name: "condition returns error no further attempts expected", + steps: 5, + callback: func(_ int) (bool, error) { + return true, conditionErr + }, + attemptsExpected: 1, + errExpected: conditionErr, + }, + { + name: "context already canceled no attempts expected", + steps: 5, + context: cancelledContext, + callback: defaultCallback, + attemptsExpected: 0, + errExpected: context.Canceled, + }, + { + name: "context at deadline no attempts expected", + steps: 5, + context: deadlinedContext, + callback: defaultCallback, + attemptsExpected: 0, + errExpected: context.DeadlineExceeded, + }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { backoff := Backoff{ - Duration: 1 * time.Millisecond, + Duration: 1 * time.Microsecond, Factor: 1.0, Steps: test.steps, } + if test.zeroDuration { + backoff.Duration = 0 + } + + contextFn := test.context + if contextFn == nil { + contextFn = defaultContext + } + ctx, cancel := contextFn() + defer cancel() attempts := 0 - err := ExponentialBackoffWithContext(test.ctxGetter(), backoff, func(_ context.Context) (bool, error) { + err := ExponentialBackoffWithContext(ctx, backoff, func(_ context.Context) (bool, error) { attempts++ + defer func() { + if test.cancelContextAfter > 0 && test.cancelContextAfter == attempts { + cancel() + } + }() return test.callback(attempts) }) @@ -894,6 +1094,26 @@ func TestExponentialBackoffWithContext(t *testing.T) { } } +func BenchmarkExponentialBackoffWithContext(b *testing.B) { + backoff := Backoff{ + Duration: 0, + Factor: 0, + Steps: 101, + } + ctx := context.Background() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + attempts := 0 + if err := ExponentialBackoffWithContext(ctx, backoff, func(_ context.Context) (bool, error) { + attempts++ + return attempts >= 100, nil + }); err != nil { + b.Fatalf("unexpected err: %v", err) + } + } +} + func TestPollImmediateUntilWithContext(t *testing.T) { fakeErr := errors.New("my error") tests := []struct { @@ -911,9 +1131,6 @@ func TestPollImmediateUntilWithContext(t *testing.T) { return false, fakeErr } }, - context: func() (context.Context, context.CancelFunc) { - return context.WithCancel(context.Background()) - }, errExpected: fakeErr, attemptsExpected: 1, }, @@ -924,9 +1141,6 @@ func TestPollImmediateUntilWithContext(t *testing.T) { return true, nil } }, - context: func() (context.Context, context.CancelFunc) { - return context.WithCancel(context.Background()) - }, errExpected: nil, attemptsExpected: 1, }, @@ -937,12 +1151,8 @@ func TestPollImmediateUntilWithContext(t *testing.T) { return false, nil } }, - context: func() (context.Context, context.CancelFunc) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - return ctx, cancel - }, - errExpected: ErrWaitTimeout, + context: cancelledContext, + errExpected: ErrWaitTimeout, // this should be context.Canceled but that would break callers that assume all errors are ErrWaitTimeout attemptsExpected: 1, }, { @@ -956,9 +1166,6 @@ func TestPollImmediateUntilWithContext(t *testing.T) { return true, nil } }, - context: func() (context.Context, context.CancelFunc) { - return context.WithCancel(context.Background()) - }, errExpected: nil, attemptsExpected: 4, }, @@ -969,18 +1176,19 @@ func TestPollImmediateUntilWithContext(t *testing.T) { return false, nil } }, - context: func() (context.Context, context.CancelFunc) { - return context.WithCancel(context.Background()) - }, cancelContextAfterNthAttempt: 4, - errExpected: ErrWaitTimeout, + errExpected: ErrWaitTimeout, // this should be context.Canceled, but this method cannot change attemptsExpected: 4, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ctx, cancel := test.context() + contextFn := test.context + if contextFn == nil { + contextFn = defaultContext + } + ctx, cancel := contextFn() defer cancel() var attempts int @@ -1018,10 +1226,8 @@ func Test_waitForWithContext(t *testing.T) { errExpected error }{ { - name: "condition returns done=true on first attempt, no retry is attempted", - context: func() (context.Context, context.CancelFunc) { - return context.WithCancel(context.Background()) - }, + name: "condition returns done=true on first attempt, no retry is attempted", + context: defaultContext, condition: ConditionWithContextFunc(func(context.Context) (bool, error) { return true, nil }), @@ -1030,10 +1236,8 @@ func Test_waitForWithContext(t *testing.T) { errExpected: nil, }, { - name: "condition always returns done=false, timeout error expected", - context: func() (context.Context, context.CancelFunc) { - return context.WithCancel(context.Background()) - }, + name: "condition always returns done=false, timeout error expected", + context: defaultContext, condition: ConditionWithContextFunc(func(context.Context) (bool, error) { return false, nil }), @@ -1043,10 +1247,8 @@ func Test_waitForWithContext(t *testing.T) { errExpected: ErrWaitTimeout, }, { - name: "condition returns an error on first attempt, the error is returned", - context: func() (context.Context, context.CancelFunc) { - return context.WithCancel(context.Background()) - }, + name: "condition returns an error on first attempt, the error is returned", + context: defaultContext, condition: ConditionWithContextFunc(func(context.Context) (bool, error) { return false, fakeErr }), @@ -1055,12 +1257,8 @@ func Test_waitForWithContext(t *testing.T) { errExpected: fakeErr, }, { - name: "context is cancelled, context cancelled error expected", - context: func() (context.Context, context.CancelFunc) { - ctx, cancel := context.WithCancel(context.Background()) - cancel() - return ctx, cancel - }, + name: "context is cancelled, context cancelled error expected", + context: cancelledContext, condition: ConditionWithContextFunc(func(context.Context) (bool, error) { return false, nil }), @@ -1086,7 +1284,11 @@ func Test_waitForWithContext(t *testing.T) { ticker := test.waitFunc() err := func() error { - ctx, cancel := test.context() + contextFn := test.context + if contextFn == nil { + contextFn = defaultContext + } + ctx, cancel := contextFn() defer cancel() return waitForWithContext(ctx, ticker.WithContext(), conditionWrapper) @@ -1102,7 +1304,7 @@ func Test_waitForWithContext(t *testing.T) { } } -func TestPollInternal(t *testing.T) { +func Test_poll(t *testing.T) { fakeErr := errors.New("fake error") tests := []struct { name string @@ -1117,13 +1319,6 @@ func TestPollInternal(t *testing.T) { { name: "immediate is true, condition returns an error", immediate: true, - context: func() (context.Context, context.CancelFunc) { - // use a cancelled context, we want to make sure the - // condition is expected to be invoked immediately. - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - return ctx, cancel - }, condition: ConditionWithContextFunc(func(context.Context) (bool, error) { return false, fakeErr }), @@ -1134,13 +1329,6 @@ func TestPollInternal(t *testing.T) { { name: "immediate is true, condition returns true", immediate: true, - context: func() (context.Context, context.CancelFunc) { - // use a cancelled context, we want to make sure the - // condition is expected to be invoked immediately. - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - return ctx, cancel - }, condition: ConditionWithContextFunc(func(context.Context) (bool, error) { return true, nil }), @@ -1151,13 +1339,7 @@ func TestPollInternal(t *testing.T) { { name: "immediate is true, context is cancelled, condition return false", immediate: true, - context: func() (context.Context, context.CancelFunc) { - // use a cancelled context, we want to make sure the - // condition is expected to be invoked immediately. - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - return ctx, cancel - }, + context: cancelledContext, condition: ConditionWithContextFunc(func(context.Context) (bool, error) { return false, nil }), @@ -1168,13 +1350,7 @@ func TestPollInternal(t *testing.T) { { name: "immediate is false, context is cancelled", immediate: false, - context: func() (context.Context, context.CancelFunc) { - // use a cancelled context, we want to make sure the - // condition is expected to be invoked immediately. - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - return ctx, cancel - }, + context: cancelledContext, condition: ConditionWithContextFunc(func(context.Context) (bool, error) { return false, nil }), @@ -1185,9 +1361,6 @@ func TestPollInternal(t *testing.T) { { name: "immediate is false, condition returns an error", immediate: false, - context: func() (context.Context, context.CancelFunc) { - return context.WithCancel(context.Background()) - }, condition: ConditionWithContextFunc(func(context.Context) (bool, error) { return false, fakeErr }), @@ -1198,9 +1371,6 @@ func TestPollInternal(t *testing.T) { { name: "immediate is false, condition returns true", immediate: false, - context: func() (context.Context, context.CancelFunc) { - return context.WithCancel(context.Background()) - }, condition: ConditionWithContextFunc(func(context.Context) (bool, error) { return true, nil }), @@ -1211,9 +1381,6 @@ func TestPollInternal(t *testing.T) { { name: "immediate is false, ticker channel is closed, condition returns true", immediate: false, - context: func() (context.Context, context.CancelFunc) { - return context.WithCancel(context.Background()) - }, condition: ConditionWithContextFunc(func(context.Context) (bool, error) { return true, nil }), @@ -1230,9 +1397,6 @@ func TestPollInternal(t *testing.T) { { name: "immediate is false, ticker channel is closed, condition returns error", immediate: false, - context: func() (context.Context, context.CancelFunc) { - return context.WithCancel(context.Background()) - }, condition: ConditionWithContextFunc(func(context.Context) (bool, error) { return false, fakeErr }), @@ -1249,9 +1413,6 @@ func TestPollInternal(t *testing.T) { { name: "immediate is false, ticker channel is closed, condition returns false", immediate: false, - context: func() (context.Context, context.CancelFunc) { - return context.WithCancel(context.Background()) - }, condition: ConditionWithContextFunc(func(context.Context) (bool, error) { return false, nil }), @@ -1268,9 +1429,6 @@ func TestPollInternal(t *testing.T) { { name: "condition always returns false, timeout error expected", immediate: false, - context: func() (context.Context, context.CancelFunc) { - return context.WithCancel(context.Background()) - }, condition: ConditionWithContextFunc(func(context.Context) (bool, error) { return false, nil }), @@ -1282,9 +1440,27 @@ func TestPollInternal(t *testing.T) { { name: "context is cancelled after N attempts, timeout error expected", immediate: false, - context: func() (context.Context, context.CancelFunc) { - return context.WithCancel(context.Background()) + condition: ConditionWithContextFunc(func(context.Context) (bool, error) { + return false, nil + }), + waitFunc: func() waitFunc { + return func(done <-chan struct{}) <-chan struct{} { + ch := make(chan struct{}) + // just tick twice + go func() { + ch <- struct{}{} + ch <- struct{}{} + }() + return ch + } }, + cancelContextAfter: 2, + attemptsExpected: 2, + errExpected: ErrWaitTimeout, + }, + { + name: "context is cancelled after N attempts, context error not expected (legacy behavior)", + immediate: false, condition: ConditionWithContextFunc(func(context.Context) (bool, error) { return false, nil }), @@ -1315,7 +1491,11 @@ func TestPollInternal(t *testing.T) { ticker = test.waitFunc() } err := func() error { - ctx, cancel := test.context() + contextFn := test.context + if contextFn == nil { + contextFn = defaultContext + } + ctx, cancel := contextFn() defer cancel() conditionWrapper := func(ctx context.Context) (done bool, err error) { @@ -1342,3 +1522,17 @@ func TestPollInternal(t *testing.T) { }) } } + +func Benchmark_poll(b *testing.B) { + ctx := context.Background() + b.ResetTimer() + for i := 0; i < b.N; i++ { + attempts := 0 + if err := poll(ctx, true, poller(time.Microsecond, 0), func(_ context.Context) (bool, error) { + attempts++ + return attempts >= 100, nil + }); err != nil { + b.Fatalf("unexpected err: %v", err) + } + } +} diff --git a/staging/src/k8s.io/client-go/tools/watch/until_test.go b/staging/src/k8s.io/client-go/tools/watch/until_test.go index 2d64c423ac2..a9367a81817 100644 --- a/staging/src/k8s.io/client-go/tools/watch/until_test.go +++ b/staging/src/k8s.io/client-go/tools/watch/until_test.go @@ -209,7 +209,7 @@ func TestUntilWithSync(t *testing.T) { conditionFunc: func(e watch.Event) (bool, error) { return true, nil }, - expectedErr: errors.New("timed out waiting for the condition"), + expectedErr: wait.ErrWaitTimeout, expectedEvent: nil, }, { diff --git a/staging/src/k8s.io/legacy-cloud-providers/azure/azure_instances_test.go b/staging/src/k8s.io/legacy-cloud-providers/azure/azure_instances_test.go index 6c47a37b763..c699d5c9d64 100644 --- a/staging/src/k8s.io/legacy-cloud-providers/azure/azure_instances_test.go +++ b/staging/src/k8s.io/legacy-cloud-providers/azure/azure_instances_test.go @@ -34,6 +34,7 @@ import ( v1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/types" + "k8s.io/apimachinery/pkg/util/wait" cloudprovider "k8s.io/cloud-provider" azcache "k8s.io/legacy-cloud-providers/azure/cache" "k8s.io/legacy-cloud-providers/azure/clients/interfaceclient/mockinterfaceclient" @@ -487,7 +488,7 @@ func TestNodeAddresses(t *testing.T) { metadataName: "vm1", vmType: vmTypeStandard, useInstanceMetadata: true, - expectedErrMsg: fmt.Errorf("timed out waiting for the condition"), + expectedErrMsg: wait.ErrWaitTimeout, }, { name: "NodeAddresses should get IP addresses from Azure API if node's name isn't equal to metadataName", diff --git a/staging/src/k8s.io/legacy-cloud-providers/azure/azure_routes_test.go b/staging/src/k8s.io/legacy-cloud-providers/azure/azure_routes_test.go index 4f2412feeae..2a0e1ccf381 100644 --- a/staging/src/k8s.io/legacy-cloud-providers/azure/azure_routes_test.go +++ b/staging/src/k8s.io/legacy-cloud-providers/azure/azure_routes_test.go @@ -33,6 +33,7 @@ import ( "k8s.io/apimachinery/pkg/types" "k8s.io/apimachinery/pkg/util/sets" + "k8s.io/apimachinery/pkg/util/wait" cloudprovider "k8s.io/cloud-provider" "k8s.io/legacy-cloud-providers/azure/clients/routetableclient/mockroutetableclient" "k8s.io/legacy-cloud-providers/azure/mockvmsets" @@ -226,7 +227,7 @@ func TestCreateRoute(t *testing.T) { name: "CreateRoute should report error if error occurs when invoke GetIPByNodeName", routeTableName: "rt7", getIPError: fmt.Errorf("getIP error"), - expectedErrMsg: fmt.Errorf("timed out waiting for the condition"), + expectedErrMsg: wait.ErrWaitTimeout, }, { name: "CreateRoute should add route to cloud.RouteCIDRs if node is unmanaged",