diff --git a/staging/src/k8s.io/apimachinery/pkg/util/wait/wait.go b/staging/src/k8s.io/apimachinery/pkg/util/wait/wait.go index e60f45bffd0..204177563b0 100644 --- a/staging/src/k8s.io/apimachinery/pkg/util/wait/wait.go +++ b/staging/src/k8s.io/apimachinery/pkg/util/wait/wait.go @@ -88,6 +88,15 @@ func Until(f func(), period time.Duration, stopCh <-chan struct{}) { JitterUntil(f, period, 0.0, true, stopCh) } +// UntilWithContext loops until context is done, running f every period. +// +// UntilWithContext is syntactic sugar on top of JitterUntilWithContext +// with zero jitter factor and with sliding = true (which means the timer +// for period starts after the f completes). +func UntilWithContext(ctx context.Context, f func(context.Context), period time.Duration) { + JitterUntilWithContext(ctx, f, period, 0.0, true) +} + // NonSlidingUntil loops until stop channel is closed, running f every // period. // @@ -98,6 +107,16 @@ func NonSlidingUntil(f func(), period time.Duration, stopCh <-chan struct{}) { JitterUntil(f, period, 0.0, false, stopCh) } +// NonSlidingUntilWithContext loops until context is done, running f every +// period. +// +// NonSlidingUntilWithContext is syntactic sugar on top of JitterUntilWithContext +// with zero jitter factor, with sliding = false (meaning the timer for period +// starts at the same time as the function starts). +func NonSlidingUntilWithContext(ctx context.Context, f func(context.Context), period time.Duration) { + JitterUntilWithContext(ctx, f, period, 0.0, false) +} + // JitterUntil loops until stop channel is closed, running f every period. // // If jitterFactor is positive, the period is jittered before every run of f. @@ -151,6 +170,19 @@ func JitterUntil(f func(), period time.Duration, jitterFactor float64, sliding b } } +// JitterUntilWithContext loops until context is done, running f every period. +// +// If jitterFactor is positive, the period is jittered before every run of f. +// If jitterFactor is not positive, the period is unchanged and not jittered. +// +// If sliding is true, the period is computed after f runs. If it is false then +// period includes the runtime for f. +// +// Cancel context to stop. f may not be invoked if context is already expired. +func JitterUntilWithContext(ctx context.Context, f func(context.Context), period time.Duration, jitterFactor float64, sliding bool) { + JitterUntil(func() { f(ctx) }, period, jitterFactor, sliding, ctx.Done()) +} + // Jitter returns a time.Duration between duration and duration + maxFactor * // duration. // diff --git a/staging/src/k8s.io/apimachinery/pkg/util/wait/wait_test.go b/staging/src/k8s.io/apimachinery/pkg/util/wait/wait_test.go index 89987983d7c..50c06d47549 100644 --- a/staging/src/k8s.io/apimachinery/pkg/util/wait/wait_test.go +++ b/staging/src/k8s.io/apimachinery/pkg/util/wait/wait_test.go @@ -17,6 +17,7 @@ limitations under the License. package wait import ( + "context" "errors" "fmt" "math/rand" @@ -48,6 +49,26 @@ func TestUntil(t *testing.T) { <-called } +func TestUntilWithContext(t *testing.T) { + ctx, cancel := context.WithCancel(context.TODO()) + cancel() + UntilWithContext(ctx, func(context.Context) { + t.Fatal("should not have been invoked") + }, 0) + + ctx, cancel = context.WithCancel(context.TODO()) + called := make(chan struct{}) + go func() { + UntilWithContext(ctx, func(context.Context) { + called <- struct{}{} + }, 0) + close(called) + }() + <-called + cancel() + <-called +} + func TestNonSlidingUntil(t *testing.T) { ch := make(chan struct{}) close(ch) @@ -68,6 +89,26 @@ func TestNonSlidingUntil(t *testing.T) { <-called } +func TestNonSlidingUntilWithContext(t *testing.T) { + ctx, cancel := context.WithCancel(context.TODO()) + cancel() + NonSlidingUntilWithContext(ctx, func(context.Context) { + t.Fatal("should not have been invoked") + }, 0) + + ctx, cancel = context.WithCancel(context.TODO()) + called := make(chan struct{}) + go func() { + NonSlidingUntilWithContext(ctx, func(context.Context) { + called <- struct{}{} + }, 0) + close(called) + }() + <-called + cancel() + <-called +} + func TestUntilReturnsImmediately(t *testing.T) { now := time.Now() ch := make(chan struct{}) @@ -101,6 +142,26 @@ func TestJitterUntil(t *testing.T) { <-called } +func TestJitterUntilWithContext(t *testing.T) { + ctx, cancel := context.WithCancel(context.TODO()) + cancel() + JitterUntilWithContext(ctx, func(context.Context) { + t.Fatal("should not have been invoked") + }, 0, 1.0, true) + + ctx, cancel = context.WithCancel(context.TODO()) + called := make(chan struct{}) + go func() { + JitterUntilWithContext(ctx, func(context.Context) { + called <- struct{}{} + }, 0, 1.0, true) + close(called) + }() + <-called + cancel() + <-called +} + func TestJitterUntilReturnsImmediately(t *testing.T) { now := time.Now() ch := make(chan struct{})