From 55854fadb1cdadce145df6b3930f61573e6ce558 Mon Sep 17 00:00:00 2001 From: Abu Kashem Date: Tue, 18 May 2021 15:15:28 -0400 Subject: [PATCH] client-go: add retry logic for Watch and Stream Kubernetes-commit: 607d3819498e64d969407c3d7cbbb8f53d98f0d4 --- rest/request.go | 204 ++++++++---- rest/request_test.go | 687 +++++++++++++++++++++++++++++----------- rest/with_retry.go | 4 + rest/with_retry_test.go | 4 - 4 files changed, 651 insertions(+), 248 deletions(-) diff --git a/rest/request.go b/rest/request.go index 9864dbdb..8f66c079 100644 --- a/rest/request.go +++ b/rest/request.go @@ -675,43 +675,88 @@ func (r *Request) Watch(ctx context.Context) (watch.Interface, error) { return nil, r.err } - url := r.URL().String() - req, err := http.NewRequest(r.verb, url, r.body) - if err != nil { - return nil, err - } - req = req.WithContext(ctx) - req.Header = r.headers client := r.c.Client if client == nil { client = http.DefaultClient } - r.backoff.Sleep(r.backoff.CalculateBackoff(r.URL())) - resp, err := client.Do(req) - updateURLMetrics(ctx, r, resp, err) - if r.c.base != nil { - if err != nil { - r.backoff.UpdateBackoff(r.c.base, err, 0) - } else { - r.backoff.UpdateBackoff(r.c.base, err, resp.StatusCode) - } - } - if err != nil { + + isErrRetryableFunc := func(request *http.Request, err error) bool { // The watch stream mechanism handles many common partial data errors, so closed // connections can be retried in many cases. if net.IsProbableEOF(err) || net.IsTimeout(err) { - return watch.NewEmptyWatch(), nil + return true } - return nil, err + return false } - if resp.StatusCode != http.StatusOK { - defer resp.Body.Close() - if result := r.transformResponse(resp, req); result.err != nil { - return nil, result.err + var retryAfter *RetryAfter + url := r.URL().String() + for { + req, err := r.newHTTPRequest(ctx) + if err != nil { + return nil, err } - return nil, fmt.Errorf("for request %s, got status: %v", url, resp.StatusCode) - } + r.backoff.Sleep(r.backoff.CalculateBackoff(r.URL())) + if retryAfter != nil { + // We are retrying the request that we already send to apiserver + // at least once before. + // This request should also be throttled with the client-internal rate limiter. + if err := r.tryThrottleWithInfo(ctx, retryAfter.Reason); err != nil { + return nil, err + } + retryAfter = nil + } + + resp, err := client.Do(req) + updateURLMetrics(ctx, r, resp, err) + if r.c.base != nil { + if err != nil { + r.backoff.UpdateBackoff(r.c.base, err, 0) + } else { + r.backoff.UpdateBackoff(r.c.base, err, resp.StatusCode) + } + } + if err == nil && resp.StatusCode == http.StatusOK { + return r.newStreamWatcher(resp) + } + + done, transformErr := func() (bool, error) { + defer readAndCloseResponseBody(resp) + + var retry bool + retryAfter, retry = r.retry.NextRetry(req, resp, err, isErrRetryableFunc) + if retry { + err := r.retry.BeforeNextRetry(ctx, r.backoff, retryAfter, url, r.body) + if err == nil { + return false, nil + } + klog.V(4).Infof("Could not retry request - %v", err) + } + + if resp == nil { + // the server must have sent us an error in 'err' + return true, nil + } + if result := r.transformResponse(resp, req); result.err != nil { + return true, result.err + } + return true, fmt.Errorf("for request %s, got status: %v", url, resp.StatusCode) + }() + if done { + if isErrRetryableFunc(req, err) { + return watch.NewEmptyWatch(), nil + } + if err == nil { + // if the server sent us an HTTP Response object, + // we need to return the error object from that. + err = transformErr + } + return nil, err + } + } +} + +func (r *Request) newStreamWatcher(resp *http.Response) (watch.Interface, error) { contentType := resp.Header.Get("Content-Type") mediaType, params, err := mime.ParseMediaType(contentType) if err != nil { @@ -766,49 +811,75 @@ func (r *Request) Stream(ctx context.Context) (io.ReadCloser, error) { return nil, err } - url := r.URL().String() - req, err := http.NewRequest(r.verb, url, nil) - if err != nil { - return nil, err - } - if r.body != nil { - req.Body = ioutil.NopCloser(r.body) - } - req = req.WithContext(ctx) - req.Header = r.headers client := r.c.Client if client == nil { client = http.DefaultClient } - r.backoff.Sleep(r.backoff.CalculateBackoff(r.URL())) - resp, err := client.Do(req) - updateURLMetrics(ctx, r, resp, err) - if r.c.base != nil { + + var retryAfter *RetryAfter + url := r.URL().String() + for { + req, err := r.newHTTPRequest(ctx) if err != nil { - r.backoff.UpdateBackoff(r.URL(), err, 0) - } else { - r.backoff.UpdateBackoff(r.URL(), err, resp.StatusCode) + return nil, err } - } - if err != nil { - return nil, err - } - - switch { - case (resp.StatusCode >= 200) && (resp.StatusCode < 300): - handleWarnings(resp.Header, r.warningHandler) - return resp.Body, nil - - default: - // ensure we close the body before returning the error - defer resp.Body.Close() - - result := r.transformResponse(resp, req) - err := result.Error() - if err == nil { - err = fmt.Errorf("%d while accessing %v: %s", result.statusCode, url, string(result.body)) + if r.body != nil { + req.Body = ioutil.NopCloser(r.body) + } + + r.backoff.Sleep(r.backoff.CalculateBackoff(r.URL())) + if retryAfter != nil { + // We are retrying the request that we already send to apiserver + // at least once before. + // This request should also be throttled with the client-internal rate limiter. + if err := r.tryThrottleWithInfo(ctx, retryAfter.Reason); err != nil { + return nil, err + } + retryAfter = nil + } + + resp, err := client.Do(req) + updateURLMetrics(ctx, r, resp, err) + if r.c.base != nil { + if err != nil { + r.backoff.UpdateBackoff(r.URL(), err, 0) + } else { + r.backoff.UpdateBackoff(r.URL(), err, resp.StatusCode) + } + } + if err != nil { + // we only retry on an HTTP response with 'Retry-After' header + return nil, err + } + + switch { + case (resp.StatusCode >= 200) && (resp.StatusCode < 300): + handleWarnings(resp.Header, r.warningHandler) + return resp.Body, nil + + default: + done, transformErr := func() (bool, error) { + defer resp.Body.Close() + + var retry bool + retryAfter, retry = r.retry.NextRetry(req, resp, err, neverRetryError) + if retry { + err := r.retry.BeforeNextRetry(ctx, r.backoff, retryAfter, url, r.body) + if err == nil { + return false, nil + } + klog.V(4).Infof("Could not retry request - %v", err) + } + result := r.transformResponse(resp, req) + if err := result.Error(); err != nil { + return true, err + } + return true, fmt.Errorf("%d while accessing %v: %s", result.statusCode, url, string(result.body)) + }() + if done { + return nil, transformErr + } } - return nil, err } } @@ -940,12 +1011,11 @@ func (r *Request) request(ctx context.Context, fn func(*http.Request, *http.Resp return false }) if retry { - if err := r.retry.BeforeNextRetry(ctx, r.backoff, retryAfter, req.URL.String(), r.body); err != nil { - klog.V(4).Infof("Could not retry request - %v", err) - f(req, resp) - return true + err := r.retry.BeforeNextRetry(ctx, r.backoff, retryAfter, req.URL.String(), r.body) + if err == nil { + return false } - return false + klog.V(4).Infof("Could not retry request - %v", err) } f(req, resp) diff --git a/rest/request_test.go b/rest/request_test.go index 421f7f3a..9a458c47 100644 --- a/rest/request_test.go +++ b/rest/request_test.go @@ -924,53 +924,57 @@ func TestTransformUnstructuredError(t *testing.T) { } } -type errorReader struct { - err error -} - -func (r errorReader) Read(data []byte) (int, error) { return 0, r.err } -func (r errorReader) Close() error { return nil } - func TestRequestWatch(t *testing.T) { testCases := []struct { - Request *Request - Expect []watch.Event - Err bool - ErrFn func(error) bool - Empty bool + name string + Request *Request + maxRetries int + serverReturns []responseErr + Expect []watch.Event + attemptsExpected int + Err bool + ErrFn func(error) bool + Empty bool }{ { - Request: &Request{err: errors.New("bail")}, - Err: true, + name: "Request has error", + Request: &Request{err: errors.New("bail")}, + attemptsExpected: 0, + Err: true, }, { + name: "Client is nil, should use http.DefaultClient", Request: &Request{c: &RESTClient{base: &url.URL{}}, pathPrefix: "%"}, Err: true, }, { + name: "error is not retryable", Request: &Request{ c: &RESTClient{ - Client: clientForFunc(func(req *http.Request) (*http.Response, error) { - return nil, errors.New("err") - }), base: &url.URL{}, }, }, - Err: true, + serverReturns: []responseErr{ + {response: nil, err: errors.New("err")}, + }, + attemptsExpected: 1, + Err: true, }, { + name: "server returns forbidden", Request: &Request{ c: &RESTClient{ content: defaultContentConfig(), - Client: clientForFunc(func(req *http.Request) (*http.Response, error) { - return &http.Response{ - StatusCode: http.StatusForbidden, - Body: ioutil.NopCloser(bytes.NewReader([]byte{})), - }, nil - }), - base: &url.URL{}, + base: &url.URL{}, }, }, + serverReturns: []responseErr{ + {response: &http.Response{ + StatusCode: http.StatusForbidden, + Body: ioutil.NopCloser(bytes.NewReader([]byte{})), + }, err: nil}, + }, + attemptsExpected: 1, Expect: []watch.Event{ { Type: watch.Error, @@ -1000,101 +1004,205 @@ func TestRequestWatch(t *testing.T) { }, }, { + name: "server returns forbidden", Request: &Request{ c: &RESTClient{ content: defaultContentConfig(), - Client: clientForFunc(func(req *http.Request) (*http.Response, error) { - return &http.Response{ - StatusCode: http.StatusForbidden, - Body: ioutil.NopCloser(bytes.NewReader([]byte{})), - }, nil - }), - base: &url.URL{}, + base: &url.URL{}, }, }, - Err: true, + serverReturns: []responseErr{ + {response: &http.Response{ + StatusCode: http.StatusForbidden, + Body: ioutil.NopCloser(bytes.NewReader([]byte{})), + }, err: nil}, + }, + attemptsExpected: 1, + Err: true, ErrFn: func(err error) bool { return apierrors.IsForbidden(err) }, }, { + name: "server returns unauthorized", Request: &Request{ c: &RESTClient{ content: defaultContentConfig(), - Client: clientForFunc(func(req *http.Request) (*http.Response, error) { - return &http.Response{ - StatusCode: http.StatusUnauthorized, - Body: ioutil.NopCloser(bytes.NewReader([]byte{})), - }, nil - }), - base: &url.URL{}, + base: &url.URL{}, }, }, - Err: true, + serverReturns: []responseErr{ + {response: &http.Response{ + StatusCode: http.StatusUnauthorized, + Body: ioutil.NopCloser(bytes.NewReader([]byte{})), + }, err: nil}, + }, + attemptsExpected: 1, + Err: true, ErrFn: func(err error) bool { return apierrors.IsUnauthorized(err) }, }, { + name: "server returns unauthorized", Request: &Request{ c: &RESTClient{ content: defaultContentConfig(), - Client: clientForFunc(func(req *http.Request) (*http.Response, error) { - return &http.Response{ - StatusCode: http.StatusUnauthorized, - Body: ioutil.NopCloser(bytes.NewReader([]byte(runtime.EncodeOrDie(scheme.Codecs.LegacyCodec(v1.SchemeGroupVersion), &metav1.Status{ - Status: metav1.StatusFailure, - Reason: metav1.StatusReasonUnauthorized, - })))), - }, nil - }), - base: &url.URL{}, + base: &url.URL{}, }, }, - Err: true, + serverReturns: []responseErr{ + {response: &http.Response{ + StatusCode: http.StatusUnauthorized, + Body: ioutil.NopCloser(bytes.NewReader([]byte(runtime.EncodeOrDie(scheme.Codecs.LegacyCodec(v1.SchemeGroupVersion), &metav1.Status{ + Status: metav1.StatusFailure, + Reason: metav1.StatusReasonUnauthorized, + })))), + }, err: nil}, + }, + attemptsExpected: 1, + Err: true, ErrFn: func(err error) bool { return apierrors.IsUnauthorized(err) }, }, { + name: "server returns EOF error", Request: &Request{ c: &RESTClient{ - Client: clientForFunc(func(req *http.Request) (*http.Response, error) { - return nil, io.EOF - }), base: &url.URL{}, }, }, + serverReturns: []responseErr{ + {response: nil, err: io.EOF}, + }, + attemptsExpected: 1, + Empty: true, + }, + { + name: "server returns can't write HTTP request on broken connection error", + Request: &Request{ + c: &RESTClient{ + base: &url.URL{}, + }, + }, + serverReturns: []responseErr{ + {response: nil, err: errors.New("http: can't write HTTP request on broken connection")}, + }, + attemptsExpected: 1, + Empty: true, + }, + { + name: "server returns connection reset by peer", + Request: &Request{ + c: &RESTClient{ + base: &url.URL{}, + }, + }, + serverReturns: []responseErr{ + {response: nil, err: errors.New("foo: connection reset by peer")}, + }, + attemptsExpected: 1, + Empty: true, + }, + { + name: "max retries 2, server always returns EOF error", + Request: &Request{ + c: &RESTClient{ + base: &url.URL{}, + }, + }, + maxRetries: 2, + attemptsExpected: 3, + serverReturns: []responseErr{ + {response: nil, err: io.EOF}, + {response: nil, err: io.EOF}, + {response: nil, err: io.EOF}, + }, Empty: true, }, { + name: "max retries 1, server returns a retry-after response, request body seek error", Request: &Request{ + body: &readSeeker{err: io.EOF}, c: &RESTClient{ - Client: clientForFunc(func(req *http.Request) (*http.Response, error) { - return nil, errors.New("http: can't write HTTP request on broken connection") - }), base: &url.URL{}, }, }, + maxRetries: 1, + attemptsExpected: 1, + serverReturns: []responseErr{ + {response: retryAfterResponse(), err: nil}, + }, + Err: true, + ErrFn: func(err error) bool { + return apierrors.IsInternalError(err) + }, + }, + { + name: "max retries 1, server returns a retryable error, request body seek error", + Request: &Request{ + body: &readSeeker{err: io.EOF}, + c: &RESTClient{ + base: &url.URL{}, + }, + }, + maxRetries: 1, + attemptsExpected: 1, + serverReturns: []responseErr{ + {response: nil, err: io.EOF}, + }, Empty: true, }, { + name: "max retries 2, server always returns a response with Retry-After header", Request: &Request{ c: &RESTClient{ - Client: clientForFunc(func(req *http.Request) (*http.Response, error) { - return nil, errors.New("foo: connection reset by peer") - }), base: &url.URL{}, }, }, - Empty: true, + maxRetries: 2, + attemptsExpected: 3, + serverReturns: []responseErr{ + {response: retryAfterResponse(), err: nil}, + {response: retryAfterResponse(), err: nil}, + {response: retryAfterResponse(), err: nil}, + }, + Err: true, + ErrFn: func(err error) bool { + return apierrors.IsInternalError(err) + }, }, } + for _, testCase := range testCases { - t.Run("", func(t *testing.T) { - testCase.Request.backoff = &NoBackoff{} - testCase.Request.retry = &withRetry{} + t.Run(testCase.name, func(t *testing.T) { + var attemptsGot int + client := clientForFunc(func(req *http.Request) (*http.Response, error) { + defer func() { + attemptsGot++ + }() + + if attemptsGot >= len(testCase.serverReturns) { + t.Fatalf("Wrong test setup, the server does not know what to return") + } + re := testCase.serverReturns[attemptsGot] + return re.response, re.err + }) + if c := testCase.Request.c; c != nil && len(testCase.serverReturns) > 0 { + c.Client = client + } + testCase.Request.backoff = &noSleepBackOff{} + testCase.Request.retry = &withRetry{maxRetries: testCase.maxRetries} + watch, err := testCase.Request.Watch(context.Background()) + + if watch == nil && err == nil { + t.Fatal("Both watch.Interface and err returned by Watch are nil") + } + if testCase.attemptsExpected != attemptsGot { + t.Errorf("Expected RoundTrip to be invoked %d times, but got: %d", testCase.attemptsExpected, attemptsGot) + } hasErr := err != nil if hasErr != testCase.Err { t.Fatalf("expected %t, got %t: %v", testCase.Err, hasErr, err) @@ -1132,61 +1240,72 @@ func TestRequestWatch(t *testing.T) { func TestRequestStream(t *testing.T) { testCases := []struct { - Request *Request - Err bool - ErrFn func(error) bool + name string + Request *Request + maxRetries int + serverReturns []responseErr + attemptsExpected int + Err bool + ErrFn func(error) bool }{ { - Request: &Request{err: errors.New("bail")}, - Err: true, + name: "request has error", + Request: &Request{err: errors.New("bail")}, + attemptsExpected: 0, + Err: true, }, { + name: "Client is nil, should use http.DefaultClient", Request: &Request{c: &RESTClient{base: &url.URL{}}, pathPrefix: "%"}, Err: true, }, { + name: "server returns an error", Request: &Request{ c: &RESTClient{ - Client: clientForFunc(func(req *http.Request) (*http.Response, error) { - return nil, errors.New("err") - }), base: &url.URL{}, }, }, - Err: true, + serverReturns: []responseErr{ + {response: nil, err: errors.New("err")}, + }, + attemptsExpected: 1, + Err: true, }, { Request: &Request{ c: &RESTClient{ - Client: clientForFunc(func(req *http.Request) (*http.Response, error) { - return &http.Response{ - StatusCode: http.StatusUnauthorized, - Body: ioutil.NopCloser(bytes.NewReader([]byte(runtime.EncodeOrDie(scheme.Codecs.LegacyCodec(v1.SchemeGroupVersion), &metav1.Status{ - Status: metav1.StatusFailure, - Reason: metav1.StatusReasonUnauthorized, - })))), - }, nil - }), content: defaultContentConfig(), base: &url.URL{}, }, }, - Err: true, + serverReturns: []responseErr{ + {response: &http.Response{ + StatusCode: http.StatusUnauthorized, + Body: ioutil.NopCloser(bytes.NewReader([]byte(runtime.EncodeOrDie(scheme.Codecs.LegacyCodec(v1.SchemeGroupVersion), &metav1.Status{ + Status: metav1.StatusFailure, + Reason: metav1.StatusReasonUnauthorized, + })))), + }, err: nil}, + }, + attemptsExpected: 1, + Err: true, }, { Request: &Request{ c: &RESTClient{ - Client: clientForFunc(func(req *http.Request) (*http.Response, error) { - return &http.Response{ - StatusCode: http.StatusBadRequest, - Body: ioutil.NopCloser(bytes.NewReader([]byte(`{"kind":"Status","apiVersion":"v1","metadata":{},"status":"Failure","message":"a container name must be specified for pod kube-dns-v20-mz5cv, choose one of: [kubedns dnsmasq healthz]","reason":"BadRequest","code":400}`))), - }, nil - }), content: defaultContentConfig(), base: &url.URL{}, }, }, - Err: true, + serverReturns: []responseErr{ + {response: &http.Response{ + StatusCode: http.StatusBadRequest, + Body: ioutil.NopCloser(bytes.NewReader([]byte(`{"kind":"Status","apiVersion":"v1","metadata":{},"status":"Failure","message":"a container name must be specified for pod kube-dns-v20-mz5cv, choose one of: [kubedns dnsmasq healthz]","reason":"BadRequest","code":400}`))), + }, err: nil}, + }, + attemptsExpected: 1, + Err: true, ErrFn: func(err error) bool { if err.Error() == "a container name must be specified for pod kube-dns-v20-mz5cv, choose one of: [kubedns dnsmasq healthz]" { return true @@ -1194,25 +1313,124 @@ func TestRequestStream(t *testing.T) { return false }, }, + { + name: "max retries 1, server returns a retry-after response, request body seek error", + Request: &Request{ + body: &readSeeker{err: io.EOF}, + c: &RESTClient{ + base: &url.URL{}, + }, + }, + maxRetries: 1, + attemptsExpected: 1, + serverReturns: []responseErr{ + {response: retryAfterResponse(), err: nil}, + }, + Err: true, + ErrFn: func(err error) bool { + return apierrors.IsInternalError(err) + }, + }, + { + name: "max retries 2, server always returns a response with Retry-After header", + Request: &Request{ + c: &RESTClient{ + base: &url.URL{}, + }, + }, + maxRetries: 2, + attemptsExpected: 3, + serverReturns: []responseErr{ + {response: retryAfterResponse(), err: nil}, + {response: retryAfterResponse(), err: nil}, + {response: retryAfterResponse(), err: nil}, + }, + Err: true, + ErrFn: func(err error) bool { + return apierrors.IsInternalError(err) + }, + }, + { + name: "server returns EOF after attempt 1, retry aborted", + Request: &Request{ + c: &RESTClient{ + base: &url.URL{}, + }, + }, + maxRetries: 2, + attemptsExpected: 2, + serverReturns: []responseErr{ + {response: retryAfterResponse(), err: nil}, + {response: nil, err: io.EOF}, + }, + Err: true, + ErrFn: func(err error) bool { + return unWrap(err) == io.EOF + }, + }, + { + name: "max retries 2, server returns success on the final attempt", + Request: &Request{ + c: &RESTClient{ + base: &url.URL{}, + }, + }, + maxRetries: 2, + attemptsExpected: 3, + serverReturns: []responseErr{ + {response: retryAfterResponse(), err: nil}, + {response: retryAfterResponse(), err: nil}, + {response: &http.Response{ + StatusCode: http.StatusOK, + Body: ioutil.NopCloser(bytes.NewReader([]byte{})), + }, err: nil}, + }, + }, } - for i, testCase := range testCases { - testCase.Request.backoff = &NoBackoff{} - testCase.Request.retry = &withRetry{maxRetries: 0} - body, err := testCase.Request.Stream(context.Background()) - hasErr := err != nil - if hasErr != testCase.Err { - t.Errorf("%d: expected %t, got %t: %v", i, testCase.Err, hasErr, err) - } - if hasErr && body != nil { - t.Errorf("%d: body should be nil when error is returned", i) - } + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + var attemptsGot int + client := clientForFunc(func(req *http.Request) (*http.Response, error) { + defer func() { + attemptsGot++ + }() - if hasErr { - if testCase.ErrFn != nil && !testCase.ErrFn(err) { - t.Errorf("unexpected error: %v", err) + if attemptsGot >= len(testCase.serverReturns) { + t.Fatalf("Wrong test setup, the server does not know what to return") + } + re := testCase.serverReturns[attemptsGot] + return re.response, re.err + }) + if c := testCase.Request.c; c != nil && len(testCase.serverReturns) > 0 { + c.Client = client } - } + testCase.Request.backoff = &noSleepBackOff{} + testCase.Request.retry = &withRetry{maxRetries: testCase.maxRetries} + + body, err := testCase.Request.Stream(context.Background()) + + if body == nil && err == nil { + t.Fatal("Both body and err returned by Stream are nil") + } + if testCase.attemptsExpected != attemptsGot { + t.Errorf("Expected RoundTrip to be invoked %d times, but got: %d", testCase.attemptsExpected, attemptsGot) + } + + hasErr := err != nil + if hasErr != testCase.Err { + t.Errorf("expected %t, got %t: %v", testCase.Err, hasErr, err) + } + if hasErr && body != nil { + t.Error("body should be nil when error is returned") + } + + if hasErr { + if testCase.ErrFn != nil && !testCase.ErrFn(err) { + t.Errorf("unexpected error: %#v", err) + } + } + }) } } @@ -1840,57 +2058,87 @@ func TestBody(t *testing.T) { } func TestWatch(t *testing.T) { - var table = []struct { - t watch.EventType - obj runtime.Object + tests := []struct { + name string + maxRetries int }{ - {watch.Added, &v1.Pod{ObjectMeta: metav1.ObjectMeta{Name: "first"}}}, - {watch.Modified, &v1.Pod{ObjectMeta: metav1.ObjectMeta{Name: "second"}}}, - {watch.Deleted, &v1.Pod{ObjectMeta: metav1.ObjectMeta{Name: "last"}}}, + { + name: "no retry", + maxRetries: 0, + }, + { + name: "with retries", + maxRetries: 3, + }, } - testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - flusher, ok := w.(http.Flusher) - if !ok { - panic("need flusher!") - } - - w.Header().Set("Transfer-Encoding", "chunked") - w.WriteHeader(http.StatusOK) - flusher.Flush() - - encoder := restclientwatch.NewEncoder(streaming.NewEncoder(w, scheme.Codecs.LegacyCodec(v1.SchemeGroupVersion)), scheme.Codecs.LegacyCodec(v1.SchemeGroupVersion)) - for _, item := range table { - if err := encoder.Encode(&watch.Event{Type: item.t, Object: item.obj}); err != nil { - panic(err) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var table = []struct { + t watch.EventType + obj runtime.Object + }{ + {watch.Added, &v1.Pod{ObjectMeta: metav1.ObjectMeta{Name: "first"}}}, + {watch.Modified, &v1.Pod{ObjectMeta: metav1.ObjectMeta{Name: "second"}}}, + {watch.Deleted, &v1.Pod{ObjectMeta: metav1.ObjectMeta{Name: "last"}}}, } - flusher.Flush() - } - })) - defer testServer.Close() - s := testRESTClient(t, testServer) - watching, err := s.Get().Prefix("path/to/watch/thing").Watch(context.Background()) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } + var attempts int + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer func() { + attempts++ + }() - for _, item := range table { - got, ok := <-watching.ResultChan() - if !ok { - t.Fatalf("Unexpected early close") - } - if e, a := item.t, got.Type; e != a { - t.Errorf("Expected %v, got %v", e, a) - } - if e, a := item.obj, got.Object; !apiequality.Semantic.DeepDerivative(e, a) { - t.Errorf("Expected %v, got %v", e, a) - } - } + flusher, ok := w.(http.Flusher) + if !ok { + panic("need flusher!") + } - _, ok := <-watching.ResultChan() - if ok { - t.Fatal("Unexpected non-close") + if attempts < test.maxRetries { + w.Header().Set("Retry-After", "1") + w.WriteHeader(http.StatusTooManyRequests) + return + } + + w.Header().Set("Transfer-Encoding", "chunked") + w.WriteHeader(http.StatusOK) + flusher.Flush() + + encoder := restclientwatch.NewEncoder(streaming.NewEncoder(w, scheme.Codecs.LegacyCodec(v1.SchemeGroupVersion)), scheme.Codecs.LegacyCodec(v1.SchemeGroupVersion)) + for _, item := range table { + if err := encoder.Encode(&watch.Event{Type: item.t, Object: item.obj}); err != nil { + panic(err) + } + flusher.Flush() + } + })) + defer testServer.Close() + + s := testRESTClient(t, testServer) + watching, err := s.Get().Prefix("path/to/watch/thing"). + MaxRetries(test.maxRetries).Watch(context.Background()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + for _, item := range table { + got, ok := <-watching.ResultChan() + if !ok { + t.Fatalf("Unexpected early close") + } + if e, a := item.t, got.Type; e != a { + t.Errorf("Expected %v, got %v", e, a) + } + if e, a := item.obj, got.Object; !apiequality.Semantic.DeepDerivative(e, a) { + t.Errorf("Expected %v, got %v", e, a) + } + } + + _, ok := <-watching.ResultChan() + if ok { + t.Fatal("Unexpected non-close") + } + }) } } @@ -2333,14 +2581,27 @@ type seek struct { type count struct { // keeps track of the number of Seek(offset, whence) calls. seeks []seek + // how many times {Request|Response}.Body.Close() has been invoked + lock sync.Mutex closes int } +func (c *count) close() { + c.lock.Lock() + defer c.lock.Unlock() + c.closes++ +} +func (c *count) getCloseCount() int { + c.lock.Lock() + defer c.lock.Unlock() + return c.closes +} + // used to track {Request|Response}.Body type readTracker struct { - count *count delegated io.Reader + count *count } func (r *readTracker) Seek(offset int64, whence int) (int64, error) { @@ -2357,7 +2618,7 @@ func (r *readTracker) Read(p []byte) (n int, err error) { func (r *readTracker) Close() error { if closer, ok := r.delegated.(io.Closer); ok { - r.count.closes++ + r.count.close() return closer.Close() } return nil @@ -2492,26 +2753,46 @@ func TestRequestWithRetry(t *testing.T) { } func TestRequestDoWithRetry(t *testing.T) { - testRequestWithRetry(t, func(ctx context.Context, r *Request) { + testRequestWithRetry(t, "Do", func(ctx context.Context, r *Request) { r.Do(ctx) }) } -func TestRequestDORawWithRetry(t *testing.T) { - testRequestWithRetry(t, func(ctx context.Context, r *Request) { +func TestRequestDoRawWithRetry(t *testing.T) { + // both request.Do and request.DoRaw have the same behavior and expectations + testRequestWithRetry(t, "Do", func(ctx context.Context, r *Request) { r.DoRaw(ctx) }) } -func testRequestWithRetry(t *testing.T, doFunc func(ctx context.Context, r *Request)) { +func TestRequestStreamWithRetry(t *testing.T) { + testRequestWithRetry(t, "Stream", func(ctx context.Context, r *Request) { + r.Stream(ctx) + }) +} + +func TestRequestWatchWithRetry(t *testing.T) { + testRequestWithRetry(t, "Watch", func(ctx context.Context, r *Request) { + r.Watch(ctx) + }) +} + +func testRequestWithRetry(t *testing.T, key string, doFunc func(ctx context.Context, r *Request)) { + type expected struct { + attempts int + reqCount *count + respCount *count + } + tests := []struct { - name string - verb string - body func() io.Reader - maxRetries int - serverReturns []responseErr - reqCountExpected *count - respCountExpected *count + name string + verb string + body func() io.Reader + maxRetries int + serverReturns []responseErr + + // expectations differ based on whether it is 'Watch', 'Stream' or 'Do' + expectations map[string]expected }{ { name: "server always returns retry-after response", @@ -2523,8 +2804,23 @@ func testRequestWithRetry(t *testing.T, doFunc func(ctx context.Context, r *Requ {response: retryAfterResponse(), err: nil}, {response: retryAfterResponse(), err: nil}, }, - reqCountExpected: &count{closes: 0, seeks: make([]seek, 2)}, - respCountExpected: &count{closes: 3, seeks: []seek{}}, + expectations: map[string]expected{ + "Do": { + attempts: 3, + reqCount: &count{closes: 0, seeks: make([]seek, 2)}, + respCount: &count{closes: 3, seeks: []seek{}}, + }, + "Watch": { + attempts: 3, + reqCount: &count{closes: 0, seeks: make([]seek, 2)}, + respCount: &count{closes: 3, seeks: []seek{}}, + }, + "Stream": { + attempts: 3, + reqCount: &count{closes: 0, seeks: make([]seek, 2)}, + respCount: &count{closes: 3, seeks: []seek{}}, + }, + }, }, { name: "server always returns retryable error", @@ -2536,8 +2832,24 @@ func testRequestWithRetry(t *testing.T, doFunc func(ctx context.Context, r *Requ {response: nil, err: io.EOF}, {response: nil, err: io.EOF}, }, - reqCountExpected: &count{closes: 0, seeks: make([]seek, 2)}, - respCountExpected: &count{closes: 0, seeks: []seek{}}, + expectations: map[string]expected{ + "Do": { + attempts: 3, + reqCount: &count{closes: 0, seeks: make([]seek, 2)}, + respCount: &count{closes: 0, seeks: []seek{}}, + }, + "Watch": { + attempts: 3, + reqCount: &count{closes: 0, seeks: make([]seek, 2)}, + respCount: &count{closes: 0, seeks: []seek{}}, + }, + // for Stream, we never retry on any error + "Stream": { + attempts: 1, // only the first attempt is expected + reqCount: &count{closes: 0, seeks: []seek{}}, + respCount: &count{closes: 0, seeks: []seek{}}, + }, + }, }, { name: "server returns success on the final retry", @@ -2549,8 +2861,24 @@ func testRequestWithRetry(t *testing.T, doFunc func(ctx context.Context, r *Requ {response: nil, err: io.EOF}, {response: &http.Response{StatusCode: http.StatusOK}, err: nil}, }, - reqCountExpected: &count{closes: 0, seeks: make([]seek, 2)}, - respCountExpected: &count{closes: 2, seeks: []seek{}}, + expectations: map[string]expected{ + "Do": { + attempts: 3, + reqCount: &count{closes: 0, seeks: make([]seek, 2)}, + respCount: &count{closes: 2, seeks: []seek{}}, + }, + "Watch": { + attempts: 3, + reqCount: &count{closes: 0, seeks: make([]seek, 2)}, + // we don't close the the Body of the final successful response + respCount: &count{closes: 1, seeks: []seek{}}, + }, + "Stream": { + attempts: 2, + reqCount: &count{closes: 0, seeks: make([]seek, 1)}, + respCount: &count{closes: 1, seeks: []seek{}}, + }, + }, }, } @@ -2580,7 +2908,8 @@ func testRequestWithRetry(t *testing.T, doFunc func(ctx context.Context, r *Requ verb: test.verb, body: reqRecorder, c: &RESTClient{ - Client: client, + content: defaultContentConfig(), + Client: client, }, backoff: &noSleepBackOff{}, retry: &withRetry{maxRetries: test.maxRetries}, @@ -2588,15 +2917,19 @@ func testRequestWithRetry(t *testing.T, doFunc func(ctx context.Context, r *Requ doFunc(context.Background(), req) - attemptsExpected := test.maxRetries + 1 - if attemptsExpected != attempts { - t.Errorf("Expected retries: %d, but got: %d", attemptsExpected, attempts) + expected, ok := test.expectations[key] + if !ok { + t.Fatalf("Wrong test setup - did not find expected for: %s", key) } - if !reflect.DeepEqual(test.reqCountExpected.seeks, reqCountGot.seeks) { - t.Errorf("Expected request body to have seek invocation: %v, but got: %v", test.reqCountExpected.seeks, reqCountGot.seeks) + if expected.attempts != attempts { + t.Errorf("Expected retries: %d, but got: %d", expected.attempts, attempts) } - if test.respCountExpected.closes != respCountGot.closes { - t.Errorf("Expected response body Close to be invoked %d times, but got: %d", test.respCountExpected.closes, respCountGot.closes) + + if !reflect.DeepEqual(expected.reqCount.seeks, reqCountGot.seeks) { + t.Errorf("Expected request body to have seek invocation: %v, but got: %v", expected.reqCount.seeks, reqCountGot.seeks) + } + if expected.respCount.closes != respCountGot.getCloseCount() { + t.Errorf("Expected response body Close to be invoked %d times, but got: %d", expected.respCount.closes, respCountGot.getCloseCount()) } }) } diff --git a/rest/with_retry.go b/rest/with_retry.go index aadbeb28..1b7360b5 100644 --- a/rest/with_retry.go +++ b/rest/with_retry.go @@ -43,6 +43,10 @@ func (r IsRetryableErrorFunc) IsErrorRetryable(request *http.Request, err error) return r(request, err) } +var neverRetryError = IsRetryableErrorFunc(func(_ *http.Request, _ error) bool { + return false +}) + // WithRetry allows the client to retry a request up to a certain number of times // Note that WithRetry is not safe for concurrent use by multiple // goroutines without additional locking or coordination. diff --git a/rest/with_retry_test.go b/rest/with_retry_test.go index 127746c9..25b9016d 100644 --- a/rest/with_retry_test.go +++ b/rest/with_retry_test.go @@ -30,10 +30,6 @@ var alwaysRetryError = IsRetryableErrorFunc(func(_ *http.Request, _ error) bool return true }) -var neverRetryError = IsRetryableErrorFunc(func(_ *http.Request, _ error) bool { - return false -}) - func TestNextRetry(t *testing.T) { fakeError := errors.New("fake error") tests := []struct {