Merge pull request #119762 from AxeZhan/PollUntilContextCancel

wait.PollUntilContextCancel immediately executes condition once
This commit is contained in:
Kubernetes Prow Robot 2023-11-02 05:40:03 +01:00 committed by GitHub
commit 227d1b2357
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 96 additions and 54 deletions

View File

@ -40,6 +40,10 @@ func loopConditionUntilContext(ctx context.Context, t Timer, immediate, sliding
var timeCh <-chan time.Time
doneCh := ctx.Done()
if !sliding {
timeCh = t.C()
}
// if immediate is true the condition is
// guaranteed to be executed at least once,
// if we haven't requested immediate execution, delay once
@ -50,17 +54,27 @@ func loopConditionUntilContext(ctx context.Context, t Timer, immediate, sliding
}(); err != nil || ok {
return err
}
} else {
}
if sliding {
timeCh = t.C()
}
for {
// Wait for either the context to be cancelled or the next invocation be called
select {
case <-doneCh:
return ctx.Err()
case <-timeCh:
}
}
for {
// checking ctx.Err() is slightly faster than checking a select
// IMPORTANT: Because there is no channel priority selection in golang
// it is possible for very short timers to "win" the race in the previous select
// repeatedly even when the context has been canceled. We therefore must
// explicitly check for context cancellation on every loop and exit if true to
// guarantee that we don't invoke condition more than once after context has
// been cancelled.
if err := ctx.Err(); err != nil {
return err
}
@ -77,21 +91,5 @@ func loopConditionUntilContext(ctx context.Context, t Timer, immediate, sliding
if sliding {
t.Next()
}
if timeCh == nil {
timeCh = t.C()
}
// NOTE: b/c there is no priority selection in golang
// it is possible for this to race, meaning we could
// trigger t.C and doneCh, and t.C select falls through.
// In order to mitigate we re-check doneCh at the beginning
// of every loop to guarantee at-most one extra execution
// of condition.
select {
case <-doneCh:
return ctx.Err()
case <-timeCh:
}
}
}

View File

@ -99,6 +99,7 @@ func Test_loopConditionUntilContext_semantic(t *testing.T) {
cancelContextAfter int
attemptsExpected int
errExpected error
timer Timer
}{
{
name: "condition successful is only one attempt",
@ -203,45 +204,88 @@ func Test_loopConditionUntilContext_semantic(t *testing.T) {
attemptsExpected: 0,
errExpected: context.DeadlineExceeded,
},
{
name: "context canceled before the second execution and immediate",
immediate: true,
context: func() (context.Context, context.CancelFunc) {
return context.WithTimeout(context.Background(), time.Second)
},
callback: func(attempts int) (bool, error) {
return false, nil
},
attemptsExpected: 1,
errExpected: context.DeadlineExceeded,
timer: Backoff{Duration: 2 * time.Second}.Timer(),
},
{
name: "immediate and long duration of condition and sliding false",
immediate: true,
sliding: false,
context: func() (context.Context, context.CancelFunc) {
return context.WithTimeout(context.Background(), time.Second)
},
callback: func(attempts int) (bool, error) {
if attempts >= 4 {
return true, nil
}
time.Sleep(time.Second / 5)
return false, nil
},
attemptsExpected: 4,
timer: Backoff{Duration: time.Second / 5, Jitter: 0.001}.Timer(),
},
{
name: "immediate and long duration of condition and sliding true",
immediate: true,
sliding: true,
context: func() (context.Context, context.CancelFunc) {
return context.WithTimeout(context.Background(), time.Second)
},
callback: func(attempts int) (bool, error) {
if attempts >= 4 {
return true, nil
}
time.Sleep(time.Second / 5)
return false, nil
},
errExpected: context.DeadlineExceeded,
attemptsExpected: 3,
timer: Backoff{Duration: time.Second / 5, Jitter: 0.001}.Timer(),
},
}
for _, test := range tests {
for _, immediate := range []bool{true, false} {
t.Run(fmt.Sprintf("immediate=%t", immediate), func(t *testing.T) {
for _, sliding := range []bool{true, false} {
t.Run(fmt.Sprintf("sliding=%t", sliding), func(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
contextFn := test.context
if contextFn == nil {
contextFn = defaultContext
}
ctx, cancel := contextFn()
defer cancel()
t.Run(test.name, func(t *testing.T) {
contextFn := test.context
if contextFn == nil {
contextFn = defaultContext
}
ctx, cancel := contextFn()
defer cancel()
timer := Backoff{Duration: time.Microsecond}.Timer()
attempts := 0
err := loopConditionUntilContext(ctx, timer, test.immediate, test.sliding, func(_ context.Context) (bool, error) {
attempts++
defer func() {
if test.cancelContextAfter > 0 && test.cancelContextAfter == attempts {
cancel()
}
}()
return test.callback(attempts)
})
if test.errExpected != err {
t.Errorf("expected error: %v but got: %v", test.errExpected, err)
}
if test.attemptsExpected != attempts {
t.Errorf("expected attempts count: %d but got: %d", test.attemptsExpected, attempts)
}
})
})
}
timer := test.timer
if timer == nil {
timer = Backoff{Duration: time.Microsecond}.Timer()
}
attempts := 0
err := loopConditionUntilContext(ctx, timer, test.immediate, test.sliding, func(_ context.Context) (bool, error) {
attempts++
defer func() {
if test.cancelContextAfter > 0 && test.cancelContextAfter == attempts {
cancel()
}
}()
return test.callback(attempts)
})
}
if test.errExpected != err {
t.Errorf("expected error: %v but got: %v", test.errExpected, err)
}
if test.attemptsExpected != attempts {
t.Errorf("expected attempts count: %d but got: %d", test.attemptsExpected, attempts)
}
})
}
}