diff --git a/staging/src/k8s.io/apimachinery/pkg/util/wait/wait.go b/staging/src/k8s.io/apimachinery/pkg/util/wait/wait.go index d759d912be1..3dea7fe7f9e 100644 --- a/staging/src/k8s.io/apimachinery/pkg/util/wait/wait.go +++ b/staging/src/k8s.io/apimachinery/pkg/util/wait/wait.go @@ -604,3 +604,32 @@ func poller(interval, timeout time.Duration) WaitFunc { return ch }) } + +// ExponentialBackoffWithContext works with a request context and a Backoff. It ensures that the retry wait never +// exceeds the deadline specified by the request context. +func ExponentialBackoffWithContext(ctx context.Context, backoff Backoff, condition ConditionFunc) error { + for backoff.Steps > 0 { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + if ok, err := runConditionWithCrashProtection(condition); err != nil || ok { + return err + } + + if backoff.Steps == 1 { + break + } + + waitBeforeRetry := backoff.Step() + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(waitBeforeRetry): + } + } + + return ErrWaitTimeout +} diff --git a/staging/src/k8s.io/apimachinery/pkg/util/wait/wait_test.go b/staging/src/k8s.io/apimachinery/pkg/util/wait/wait_test.go index d73735b7d06..0eab37f8888 100644 --- a/staging/src/k8s.io/apimachinery/pkg/util/wait/wait_test.go +++ b/staging/src/k8s.io/apimachinery/pkg/util/wait/wait_test.go @@ -758,3 +758,118 @@ func TestExponentialBackoffManagerWithRealClock(t *testing.T) { } } } + +func TestExponentialBackoffWithContext(t *testing.T) { + defaultCtx := func() context.Context { + return context.Background() + } + + defaultCallback := func(_ int) (bool, error) { + return false, nil + } + + conditionErr := errors.New("condition failed") + + tests := []struct { + name string + steps int + ctxGetter func() context.Context + callback func(calls int) (bool, error) + attemptsExpected int + errExpected error + }{ + { + name: "no attempts expected with zero backoff steps", + steps: 0, + ctxGetter: defaultCtx, + callback: defaultCallback, + attemptsExpected: 0, + errExpected: ErrWaitTimeout, + }, + { + name: "condition returns false with single backoff step", + steps: 1, + ctxGetter: defaultCtx, + callback: defaultCallback, + attemptsExpected: 1, + errExpected: ErrWaitTimeout, + }, + { + name: "condition returns true with single backoff step", + steps: 1, + ctxGetter: defaultCtx, + callback: func(_ int) (bool, error) { + return true, nil + }, + attemptsExpected: 1, + errExpected: nil, + }, + { + name: "condition always returns false with multiple backoff steps", + steps: 5, + ctxGetter: defaultCtx, + callback: defaultCallback, + attemptsExpected: 5, + errExpected: ErrWaitTimeout, + }, + { + name: "condition returns true after certain attempts with multiple backoff steps", + steps: 5, + ctxGetter: defaultCtx, + callback: func(attempts int) (bool, error) { + if attempts == 3 { + return true, nil + } + return false, nil + }, + attemptsExpected: 3, + errExpected: nil, + }, + { + name: "condition returns error no further attempts expected", + steps: 5, + ctxGetter: defaultCtx, + callback: func(_ int) (bool, error) { + return true, conditionErr + }, + attemptsExpected: 1, + errExpected: conditionErr, + }, + { + name: "context already canceled no attempts expected", + steps: 5, + ctxGetter: func() context.Context { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + return ctx + }, + callback: defaultCallback, + attemptsExpected: 0, + errExpected: context.Canceled, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + backoff := Backoff{ + Duration: 1 * time.Millisecond, + Factor: 1.0, + Steps: test.steps, + } + + attempts := 0 + err := ExponentialBackoffWithContext(test.ctxGetter(), backoff, func() (bool, error) { + attempts++ + return test.callback(attempts) + }) + + if test.errExpected != err { + t.Errorf("expected error: %v but got: %v", test.errExpected, err) + } + + if test.attemptsExpected != attempts { + t.Errorf("expected attempts count: %d but got: %d", test.attemptsExpected, attempts) + } + }) + } +} diff --git a/staging/src/k8s.io/apiserver/pkg/util/webhook/webhook.go b/staging/src/k8s.io/apiserver/pkg/util/webhook/webhook.go index 2128647e08f..799107e1350 100644 --- a/staging/src/k8s.io/apiserver/pkg/util/webhook/webhook.go +++ b/staging/src/k8s.io/apiserver/pkg/util/webhook/webhook.go @@ -131,20 +131,27 @@ func WithExponentialBackoff(ctx context.Context, initialBackoff time.Duration, w Steps: 5, } - var err error - wait.ExponentialBackoff(backoff, func() (bool, error) { - err = webhookFn() - if ctx.Err() != nil { - // we timed out or were cancelled, we should not retry - return true, err - } - if shouldRetry(err) { + // having a webhook error allows us to track the last actual webhook error for requests that + // are later cancelled or time out. + var webhookErr error + err := wait.ExponentialBackoffWithContext(ctx, backoff, func() (bool, error) { + webhookErr = webhookFn() + if shouldRetry(webhookErr) { return false, nil } - if err != nil { - return false, err + if webhookErr != nil { + return false, webhookErr } return true, nil }) - return err + + switch { + // we check for webhookErr first, if webhookErr is set it's the most important error to return. + case webhookErr != nil: + return webhookErr + case err != nil: + return fmt.Errorf("webhook call failed: %s", err.Error()) + default: + return nil + } } diff --git a/staging/src/k8s.io/apiserver/pkg/util/webhook/webhook_test.go b/staging/src/k8s.io/apiserver/pkg/util/webhook/webhook_test.go index b70d146e838..8d59c76a6c1 100644 --- a/staging/src/k8s.io/apiserver/pkg/util/webhook/webhook_test.go +++ b/staging/src/k8s.io/apiserver/pkg/util/webhook/webhook_test.go @@ -21,6 +21,7 @@ import ( "crypto/tls" "crypto/x509" "encoding/json" + "errors" "fmt" "io/ioutil" "net/http" @@ -653,3 +654,57 @@ func newTestServer(clientCert, clientKey, caCert []byte, handler func(http.Respo return server, nil } + +func TestWithExponentialBackoffContextIsAlreadyCanceled(t *testing.T) { + alwaysRetry := func(e error) bool { + return true + } + + attemptsGot := 0 + webhookFunc := func() error { + attemptsGot++ + return nil + } + + ctx, cancel := context.WithCancel(context.TODO()) + cancel() + + // We don't expect the webhook function to be called since the context is already canceled. + err := WithExponentialBackoff(ctx, time.Millisecond, webhookFunc, alwaysRetry) + + errExpected := fmt.Errorf("webhook call failed: %s", context.Canceled) + if errExpected.Error() != err.Error() { + t.Errorf("expected error: %v, but got: %v", errExpected, err) + } + if attemptsGot != 0 { + t.Errorf("expected %d webhook attempts, but got: %d", 0, attemptsGot) + } +} + +func TestWithExponentialBackoffWebhookErrorIsMostImportant(t *testing.T) { + alwaysRetry := func(e error) bool { + return true + } + + ctx, cancel := context.WithCancel(context.TODO()) + attemptsGot := 0 + errExpected := errors.New("webhook not available") + webhookFunc := func() error { + attemptsGot++ + + // after the first attempt, the context is canceled + cancel() + + return errExpected + } + + // webhook err has higher priority than ctx error. we expect the webhook error to be returned. + err := WithExponentialBackoff(ctx, time.Millisecond, webhookFunc, alwaysRetry) + + if attemptsGot != 1 { + t.Errorf("expected %d webhook attempts, but got: %d", 1, attemptsGot) + } + if errExpected != err { + t.Errorf("expected error: %v, but got: %v", errExpected, err) + } +}