diff --git a/rest/request.go b/rest/request.go index 6f4c0b05..9864dbdb 100644 --- a/rest/request.go +++ b/rest/request.go @@ -93,7 +93,6 @@ type Request struct { rateLimiter flowcontrol.RateLimiter backoff BackoffManager timeout time.Duration - maxRetries int // generic components accessible via method setters verb string @@ -110,8 +109,9 @@ type Request struct { subresource string // output - err error - body io.Reader + err error + body io.Reader + retry WithRetry } // NewRequest creates a new request helper object for accessing runtime.Objects on a server. @@ -142,7 +142,7 @@ func NewRequest(c *RESTClient) *Request { backoff: backoff, timeout: timeout, pathPrefix: pathPrefix, - maxRetries: 10, + retry: &withRetry{maxRetries: 10}, warningHandler: c.warningHandler, } @@ -408,10 +408,7 @@ func (r *Request) Timeout(d time.Duration) *Request { // function is specifically called with a different value. // A zero maxRetries prevent it from doing retires and return an error immediately. func (r *Request) MaxRetries(maxRetries int) *Request { - if maxRetries < 0 { - maxRetries = 0 - } - r.maxRetries = maxRetries + r.retry.SetMaxRetries(maxRetries) return r } @@ -842,6 +839,17 @@ func (r *Request) requestPreflightCheck() error { return nil } +func (r *Request) newHTTPRequest(ctx context.Context) (*http.Request, error) { + url := r.URL().String() + req, err := http.NewRequest(r.verb, url, r.body) + if err != nil { + return nil, err + } + req = req.WithContext(ctx) + req.Header = r.headers + return req, nil +} + // request connects to the server and invokes the provided function when a server response is // received. It handles retry behavior and up front validation of requests. It will invoke // fn at most once. It will return an error if a problem occurred prior to connecting to the @@ -881,27 +889,22 @@ func (r *Request) request(ctx context.Context, fn func(*http.Request, *http.Resp } // Right now we make about ten retry attempts if we get a Retry-After response. - retries := 0 - var retryInfo string + var retryAfter *RetryAfter for { - - url := r.URL().String() - req, err := http.NewRequest(r.verb, url, r.body) + req, err := r.newHTTPRequest(ctx) if err != nil { return err } - req = req.WithContext(ctx) - req.Header = r.headers r.backoff.Sleep(r.backoff.CalculateBackoff(r.URL())) - if retries > 0 { + if retryAfter != nil { // We are retrying the request that we already send to apiserver // at least once before. // This request should also be throttled with the client-internal rate limiter. - if err := r.tryThrottleWithInfo(ctx, retryInfo); err != nil { + if err := r.tryThrottleWithInfo(ctx, retryAfter.Reason); err != nil { return err } - retryInfo = "" + retryAfter = nil } resp, err := client.Do(req) updateURLMetrics(ctx, r, resp, err) @@ -910,61 +913,46 @@ func (r *Request) request(ctx context.Context, fn func(*http.Request, *http.Resp } else { r.backoff.UpdateBackoff(r.URL(), err, resp.StatusCode) } - if err != nil { - // "Connection reset by peer" or "apiserver is shutting down" are usually a transient errors. - // Thus in case of "GET" operations, we simply retry it. - // We are not automatically retrying "write" operations, as - // they are not idempotent. - if r.verb != "GET" { - return err - } - // For connection errors and apiserver shutdown errors retry. - if net.IsConnectionReset(err) || net.IsProbableEOF(err) { - // For the purpose of retry, we set the artificial "retry-after" response. - // TODO: Should we clean the original response if it exists? 
-			resp = &http.Response{
-				StatusCode: http.StatusInternalServerError,
-				Header:     http.Header{"Retry-After": []string{"1"}},
-				Body:       ioutil.NopCloser(bytes.NewReader([]byte{})),
-			}
-		} else {
-			return err
-		}
-	}
 		done := func() bool {
-			// Ensure the response body is fully read and closed
-			// before we reconnect, so that we reuse the same TCP
-			// connection.
-			defer func() {
-				const maxBodySlurpSize = 2 << 10
-				if resp.ContentLength <= maxBodySlurpSize {
-					io.Copy(ioutil.Discard, &io.LimitedReader{R: resp.Body, N: maxBodySlurpSize})
-				}
-				resp.Body.Close()
-			}()
+			defer readAndCloseResponseBody(resp)

-			retries++
-			if seconds, wait := checkWait(resp); wait && retries <= r.maxRetries {
-				retryInfo = getRetryReason(retries, seconds, resp, err)
-				if seeker, ok := r.body.(io.Seeker); ok && r.body != nil {
-					_, err := seeker.Seek(0, 0)
-					if err != nil {
-						klog.V(4).Infof("Could not retry request, can't Seek() back to beginning of body for %T", r.body)
-						fn(req, resp)
-						return true
-					}
+			// if the server returns an error in err, the response will be nil.
+			f := func(req *http.Request, resp *http.Response) {
+				if resp == nil {
+					return
 				}
+				fn(req, resp)
+			}

-				klog.V(4).Infof("Got a Retry-After %ds response for attempt %d to %v", seconds, retries, url)
-				r.backoff.Sleep(time.Duration(seconds) * time.Second)
+			var retry bool
+			retryAfter, retry = r.retry.NextRetry(req, resp, err, func(req *http.Request, err error) bool {
+				// "Connection reset by peer" or "apiserver is shutting down" are usually transient errors.
+				// Thus in case of "GET" operations, we simply retry it.
+				// We are not automatically retrying "write" operations, as they are not idempotent.
+				if r.verb != "GET" {
+					return false
+				}
+				// For connection errors and apiserver shutdown errors retry.
+				if net.IsConnectionReset(err) || net.IsProbableEOF(err) {
+					return true
+				}
+				return false
+			})
+			if retry {
+				if err := r.retry.BeforeNextRetry(ctx, r.backoff, retryAfter, req.URL.String(), r.body); err != nil {
+					klog.V(4).Infof("Could not retry request - %v", err)
+					f(req, resp)
+					return true
+				}
 				return false
 			}
-			fn(req, resp)
+
+			f(req, resp)
 			return true
 		}()
 		if done {
-			return nil
+			return err
 		}
 	}
 }
@@ -1196,19 +1184,6 @@ func isTextResponse(resp *http.Response) bool {
 	return strings.HasPrefix(media, "text/")
 }

-// checkWait returns true along with a number of seconds if the server instructed us to wait
-// before retrying.
-func checkWait(resp *http.Response) (int, bool) {
-	switch r := resp.StatusCode; {
-	// any 500 error code and 429 can trigger a wait
-	case r == http.StatusTooManyRequests, r >= 500:
-	default:
-		return 0, false
-	}
-	i, ok := retryAfterSeconds(resp)
-	return i, ok
-}
-
 // retryAfterSeconds returns the value of the Retry-After header and true, or 0 and false if
 // the header was missing or not a valid number.
 func retryAfterSeconds(resp *http.Response) (int, bool) {
@@ -1220,26 +1195,6 @@ func retryAfterSeconds(resp *http.Response) (int, bool) {
 	return 0, false
 }

-func getRetryReason(retries, seconds int, resp *http.Response, err error) string {
-	// priority and fairness sets the UID of the FlowSchema associated with a request
-	// in the following response Header.
- const responseHeaderMatchedFlowSchemaUID = "X-Kubernetes-PF-FlowSchema-UID" - - message := fmt.Sprintf("retries: %d, retry-after: %ds", retries, seconds) - - switch { - case resp.StatusCode == http.StatusTooManyRequests: - // it is server-side throttling from priority and fairness - flowSchemaUID := resp.Header.Get(responseHeaderMatchedFlowSchemaUID) - return fmt.Sprintf("%s - retry-reason: due to server-side throttling, FlowSchema UID: %q", message, flowSchemaUID) - case err != nil: - // it's a retriable error - return fmt.Sprintf("%s - retry-reason: due to retriable error, error: %v", message, err) - default: - return fmt.Sprintf("%s - retry-reason: %d", message, resp.StatusCode) - } -} - // Result contains the result of calling Request.Do(). type Result struct { body []byte diff --git a/rest/request_test.go b/rest/request_test.go index 4527a8af..421f7f3a 100644 --- a/rest/request_test.go +++ b/rest/request_test.go @@ -1093,6 +1093,7 @@ func TestRequestWatch(t *testing.T) { for _, testCase := range testCases { t.Run("", func(t *testing.T) { testCase.Request.backoff = &NoBackoff{} + testCase.Request.retry = &withRetry{} watch, err := testCase.Request.Watch(context.Background()) hasErr := err != nil if hasErr != testCase.Err { @@ -1194,8 +1195,10 @@ func TestRequestStream(t *testing.T) { }, }, } + for i, testCase := range testCases { testCase.Request.backoff = &NoBackoff{} + testCase.Request.retry = &withRetry{maxRetries: 0} body, err := testCase.Request.Stream(context.Background()) hasErr := err != nil if hasErr != testCase.Err { @@ -1274,6 +1277,7 @@ func TestRequestDo(t *testing.T) { } for i, testCase := range testCases { testCase.Request.backoff = &NoBackoff{} + testCase.Request.retry = &withRetry{} body, err := testCase.Request.Do(context.Background()).Raw() hasErr := err != nil if hasErr != testCase.Err { @@ -1436,8 +1440,8 @@ func TestConnectionResetByPeerIsRetried(t *testing.T) { return nil, &net.OpError{Err: syscall.ECONNRESET} }), }, - backoff: backoff, - maxRetries: 10, + backoff: backoff, + retry: &withRetry{maxRetries: 10}, } // We expect two retries of "connection reset by peer" and the success. _, err := req.Do(context.Background()).Raw() @@ -2315,3 +2319,285 @@ func TestRequestMaxRetries(t *testing.T) { }) } } + +type responseErr struct { + response *http.Response + err error +} + +type seek struct { + offset int64 + whence int +} + +type count struct { + // keeps track of the number of Seek(offset, whence) calls. 
seeks []seek
+	// how many times {Request|Response}.Body.Close() has been invoked
+	closes int
+}
+
+// used to track {Request|Response}.Body
+type readTracker struct {
+	count     *count
+	delegated io.Reader
+}
+
+func (r *readTracker) Seek(offset int64, whence int) (int64, error) {
+	if seeker, ok := r.delegated.(io.Seeker); ok {
+		r.count.seeks = append(r.count.seeks, seek{offset: offset, whence: whence})
+		return seeker.Seek(offset, whence)
+	}
+	return 0, io.EOF
+}
+
+func (r *readTracker) Read(p []byte) (n int, err error) {
+	return r.delegated.Read(p)
+}
+
+func (r *readTracker) Close() error {
+	if closer, ok := r.delegated.(io.Closer); ok {
+		r.count.closes++
+		return closer.Close()
+	}
+	return nil
+}
+
+func newReadTracker(count *count) *readTracker {
+	return &readTracker{
+		count: count,
+	}
+}
+
+func newCount() *count {
+	return &count{
+		closes: 0,
+		seeks:  make([]seek, 0),
+	}
+}
+
+type readSeeker struct{ err error }
+
+func (rs readSeeker) Read([]byte) (int, error)       { return 0, rs.err }
+func (rs readSeeker) Seek(int64, int) (int64, error) { return 0, rs.err }
+
+func unWrap(err error) error {
+	if uerr, ok := err.(*url.Error); ok {
+		return uerr.Err
+	}
+	return err
+}
+
+// noSleepBackOff is a NoBackoff except it does not sleep,
+// used for faster execution of the unit tests.
+type noSleepBackOff struct {
+	*NoBackoff
+}
+
+func (n *noSleepBackOff) Sleep(d time.Duration) {}
+
+func TestRequestWithRetry(t *testing.T) {
+	tests := []struct {
+		name                         string
+		body                         io.Reader
+		serverReturns                responseErr
+		errExpected                  error
+		transformFuncInvokedExpected int
+		roundTripInvokedExpected     int
+	}{
+		{
+			name:                         "server returns retry-after response, request body is not io.Seeker, retry goes ahead",
+			body:                         ioutil.NopCloser(bytes.NewReader([]byte{})),
+			serverReturns:                responseErr{response: retryAfterResponse(), err: nil},
+			errExpected:                  nil,
+			transformFuncInvokedExpected: 1,
+			roundTripInvokedExpected:     2,
+		},
+		{
+			name:                         "server returns retry-after response, request body Seek returns error, retry aborted",
+			body:                         &readSeeker{err: io.EOF},
+			serverReturns:                responseErr{response: retryAfterResponse(), err: nil},
+			errExpected:                  nil,
+			transformFuncInvokedExpected: 1,
+			roundTripInvokedExpected:     1,
+		},
+		{
+			name:                         "server returns retry-after response, request body Seek returns no error, retry goes ahead",
+			body:                         &readSeeker{err: nil},
+			serverReturns:                responseErr{response: retryAfterResponse(), err: nil},
+			errExpected:                  nil,
+			transformFuncInvokedExpected: 1,
+			roundTripInvokedExpected:     2,
+		},
+		{
+			name:                         "server returns retryable err, request body is not io.Seeker, retry goes ahead",
+			body:                         ioutil.NopCloser(bytes.NewReader([]byte{})),
+			serverReturns:                responseErr{response: nil, err: io.ErrUnexpectedEOF},
+			errExpected:                  io.ErrUnexpectedEOF,
+			transformFuncInvokedExpected: 0,
+			roundTripInvokedExpected:     2,
+		},
+		{
+			name:                         "server returns retryable err, request body Seek returns error, retry aborted",
+			body:                         &readSeeker{err: io.EOF},
+			serverReturns:                responseErr{response: nil, err: io.ErrUnexpectedEOF},
+			errExpected:                  io.ErrUnexpectedEOF,
+			transformFuncInvokedExpected: 0,
+			roundTripInvokedExpected:     1,
+		},
+		{
+			name:                         "server returns retryable err, request body Seek returns no err, retry goes ahead",
+			body:                         &readSeeker{err: nil},
+			serverReturns:                responseErr{response: nil, err: io.ErrUnexpectedEOF},
+			errExpected:                  io.ErrUnexpectedEOF,
+			transformFuncInvokedExpected: 0,
+			roundTripInvokedExpected:     2,
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
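+			// roundTripInvoked counts the calls made to the stubbed transport:
+			// the initial attempt plus any retries that Request.request performs.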
var roundTripInvoked int + client := clientForFunc(func(req *http.Request) (*http.Response, error) { + roundTripInvoked++ + return test.serverReturns.response, test.serverReturns.err + }) + + req := &Request{ + verb: "GET", + body: test.body, + c: &RESTClient{ + Client: client, + }, + backoff: &noSleepBackOff{}, + retry: &withRetry{maxRetries: 1}, + } + + var transformFuncInvoked int + err := req.request(context.Background(), func(request *http.Request, response *http.Response) { + transformFuncInvoked++ + }) + + if test.roundTripInvokedExpected != roundTripInvoked { + t.Errorf("Expected RoundTrip to be invoked %d times, but got: %d", test.roundTripInvokedExpected, roundTripInvoked) + } + if test.transformFuncInvokedExpected != transformFuncInvoked { + t.Errorf("Expected transform func to be invoked %d times, but got: %d", test.transformFuncInvokedExpected, transformFuncInvoked) + } + if test.errExpected != unWrap(err) { + t.Errorf("Expected error: %v, but got: %v", test.errExpected, unWrap(err)) + } + }) + } +} + +func TestRequestDoWithRetry(t *testing.T) { + testRequestWithRetry(t, func(ctx context.Context, r *Request) { + r.Do(ctx) + }) +} + +func TestRequestDORawWithRetry(t *testing.T) { + testRequestWithRetry(t, func(ctx context.Context, r *Request) { + r.DoRaw(ctx) + }) +} + +func testRequestWithRetry(t *testing.T, doFunc func(ctx context.Context, r *Request)) { + tests := []struct { + name string + verb string + body func() io.Reader + maxRetries int + serverReturns []responseErr + reqCountExpected *count + respCountExpected *count + }{ + { + name: "server always returns retry-after response", + verb: "GET", + body: func() io.Reader { return bytes.NewReader([]byte{}) }, + maxRetries: 2, + serverReturns: []responseErr{ + {response: retryAfterResponse(), err: nil}, + {response: retryAfterResponse(), err: nil}, + {response: retryAfterResponse(), err: nil}, + }, + reqCountExpected: &count{closes: 0, seeks: make([]seek, 2)}, + respCountExpected: &count{closes: 3, seeks: []seek{}}, + }, + { + name: "server always returns retryable error", + verb: "GET", + body: func() io.Reader { return bytes.NewReader([]byte{}) }, + maxRetries: 2, + serverReturns: []responseErr{ + {response: nil, err: io.EOF}, + {response: nil, err: io.EOF}, + {response: nil, err: io.EOF}, + }, + reqCountExpected: &count{closes: 0, seeks: make([]seek, 2)}, + respCountExpected: &count{closes: 0, seeks: []seek{}}, + }, + { + name: "server returns success on the final retry", + verb: "GET", + body: func() io.Reader { return bytes.NewReader([]byte{}) }, + maxRetries: 2, + serverReturns: []responseErr{ + {response: retryAfterResponse(), err: nil}, + {response: nil, err: io.EOF}, + {response: &http.Response{StatusCode: http.StatusOK}, err: nil}, + }, + reqCountExpected: &count{closes: 0, seeks: make([]seek, 2)}, + respCountExpected: &count{closes: 2, seeks: []seek{}}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + respCountGot := newCount() + responseRecorder := newReadTracker(respCountGot) + var attempts int + client := clientForFunc(func(req *http.Request) (*http.Response, error) { + defer func() { + attempts++ + }() + + resp := test.serverReturns[attempts].response + if resp != nil { + responseRecorder.delegated = ioutil.NopCloser(bytes.NewReader([]byte{})) + resp.Body = responseRecorder + } + return resp, test.serverReturns[attempts].err + }) + + reqCountGot := newCount() + reqRecorder := newReadTracker(reqCountGot) + reqRecorder.delegated = test.body() + + req := &Request{ + verb: 
test.verb,
+				body: reqRecorder,
+				c: &RESTClient{
+					Client: client,
+				},
+				backoff: &noSleepBackOff{},
+				retry:   &withRetry{maxRetries: test.maxRetries},
+			}
+
+			doFunc(context.Background(), req)
+
+			attemptsExpected := test.maxRetries + 1
+			if attemptsExpected != attempts {
+				t.Errorf("Expected attempts: %d, but got: %d", attemptsExpected, attempts)
+			}
+			if !reflect.DeepEqual(test.reqCountExpected.seeks, reqCountGot.seeks) {
+				t.Errorf("Expected request body to have seek invocations: %v, but got: %v", test.reqCountExpected.seeks, reqCountGot.seeks)
+			}
+			if test.respCountExpected.closes != respCountGot.closes {
+				t.Errorf("Expected response body Close to be invoked %d times, but got: %d", test.respCountExpected.closes, respCountGot.closes)
+			}
+		})
+	}
+}
diff --git a/rest/with_retry.go b/rest/with_retry.go
new file mode 100644
index 00000000..aadbeb28
--- /dev/null
+++ b/rest/with_retry.go
@@ -0,0 +1,228 @@
+/*
+Copyright 2021 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package rest
+
+import (
+	"context"
+	"fmt"
+	"io"
+	"io/ioutil"
+	"net/http"
+	"time"
+
+	"k8s.io/klog/v2"
+)
+
+// IsRetryableErrorFunc allows the client to provide its own function
+// that determines whether the specified err from the server is retryable.
+//
+// request: the original request sent to the server
+// err: the server sent this error to us
+//
+// The function returns true if the error is retryable and the request
+// can be retried; otherwise it returns false.
+// We have four modes of communication - 'Stream', 'Watch', 'Do' and 'DoRaw';
+// this function allows us to customize the retryability aspect of each.
+type IsRetryableErrorFunc func(request *http.Request, err error) bool
+
+func (r IsRetryableErrorFunc) IsErrorRetryable(request *http.Request, err error) bool {
+	return r(request, err)
+}
+
+// WithRetry allows the client to retry a request up to a certain number of times.
+// Note that WithRetry is not safe for concurrent use by multiple
+// goroutines without additional locking or coordination.
+type WithRetry interface {
+	// SetMaxRetries makes the request use the specified integer as a ceiling
+	// for retries upon receiving a 429 status code and the "Retry-After" header
+	// in the response.
+	// A zero maxRetries should prevent any retries and make the request return immediately.
+	SetMaxRetries(maxRetries int)
+
+	// NextRetry advances the retry counter appropriately and returns true if the
+	// request should be retried; otherwise it returns false if:
+	//  - we have already reached the maximum retry threshold.
+	//  - the error does not fall into the retryable category.
+	//  - the server has not sent us a 429 or a 5xx status code, or the
+	//    'Retry-After' response header is not set with a value.
+	//
+	// If retry is set to true, retryAfter will contain the information
+	// regarding the next retry.
+	//
+	// request: the original request sent to the server
+	// resp: the response sent from the server; it is set only if err is nil
+	// err: the error the server sent us; if err is set then resp is nil
+	// f: an IsRetryableErrorFunc function provided by the client that determines
+	//    if the err sent by the server is retryable.
+	NextRetry(req *http.Request, resp *http.Response, err error, f IsRetryableErrorFunc) (*RetryAfter, bool)
+
+	// BeforeNextRetry is responsible for carrying out operations that need
+	// to be completed before the next retry is initiated:
+	//  - if the request context is already canceled there is no need to
+	//    retry; the function will return ctx.Err().
+	//  - we need to seek to the beginning of the request body before we
+	//    initiate the next retry; the function should return an error if
+	//    it fails to do so.
+	//  - we should wait the number of seconds the server has asked us to
+	//    in the 'Retry-After' response header.
+	//
+	// If BeforeNextRetry returns an error the client should abort the retry,
+	// otherwise it is safe to initiate the next retry.
+	BeforeNextRetry(ctx context.Context, backoff BackoffManager, retryAfter *RetryAfter, url string, body io.Reader) error
+}
+
+// RetryAfter holds information associated with the next retry.
+type RetryAfter struct {
+	// Wait is the duration the server has asked us to wait before
+	// the next retry is initiated.
+	// This is the value of the 'Retry-After' response header in seconds.
+	Wait time.Duration
+
+	// Attempt is the Nth attempt after which we have received a retryable
+	// error or a 'Retry-After' response header from the server.
+	Attempt int
+
+	// Reason describes why we are retrying the request
+	Reason string
+}
+
+type withRetry struct {
+	maxRetries int
+	attempts   int
+}
+
+func (r *withRetry) SetMaxRetries(maxRetries int) {
+	if maxRetries < 0 {
+		maxRetries = 0
+	}
+	r.maxRetries = maxRetries
+}
+
+func (r *withRetry) NextRetry(req *http.Request, resp *http.Response, err error, f IsRetryableErrorFunc) (*RetryAfter, bool) {
+	if req == nil || (resp == nil && err == nil) {
+		// bad input, we do nothing.
+		return nil, false
+	}
+
+	r.attempts++
+	retryAfter := &RetryAfter{Attempt: r.attempts}
+	if r.attempts > r.maxRetries {
+		return retryAfter, false
+	}
+
+	// if the server returned an error, it takes precedence over the http response.
+	var errIsRetryable bool
+	if f != nil && err != nil && f.IsErrorRetryable(req, err) {
+		errIsRetryable = true
+		// we have a retryable error, for which we will create an
+		// artificial "Retry-After" response.
+		resp = retryAfterResponse()
+	}
+	if err != nil && !errIsRetryable {
+		return retryAfter, false
+	}
+
+	// if we are here, we have either a or b:
+	//  a: we have a retryable error, for which we already
+	//     have an artificial "Retry-After" response.
+	//  b: we have a response from the server for which we
+	//     need to check if it is retryable.
+	seconds, wait := checkWait(resp)
+	if !wait {
+		return retryAfter, false
+	}
+
+	retryAfter.Wait = time.Duration(seconds) * time.Second
+	retryAfter.Reason = getRetryReason(r.attempts, seconds, resp, err)
+	return retryAfter, true
+}
+
+func (r *withRetry) BeforeNextRetry(ctx context.Context, backoff BackoffManager, retryAfter *RetryAfter, url string, body io.Reader) error {
+	// Do not initiate a retry if the request context has already been canceled.
+	if ctx.Err() != nil {
+		return ctx.Err()
+	}
+
+	if seeker, ok := body.(io.Seeker); ok && body != nil {
+		if _, err := seeker.Seek(0, 0); err != nil {
+			return fmt.Errorf("can't Seek() back to beginning of body for %T", body)
+		}
+	}
+
+	klog.V(4).Infof("Got a Retry-After %s response for attempt %d to %v", retryAfter.Wait, retryAfter.Attempt, url)
+	if backoff != nil {
+		backoff.Sleep(retryAfter.Wait)
+	}
+	return nil
+}
+
+// checkWait returns true along with a number of seconds if
+// the server instructed us to wait before retrying.
+func checkWait(resp *http.Response) (int, bool) {
+	switch r := resp.StatusCode; {
+	// any 500 error code and 429 can trigger a wait
+	case r == http.StatusTooManyRequests, r >= 500:
+	default:
+		return 0, false
+	}
+	i, ok := retryAfterSeconds(resp)
+	return i, ok
+}
+
+func getRetryReason(retries, seconds int, resp *http.Response, err error) string {
+	// priority and fairness sets the UID of the FlowSchema
+	// associated with a request in the following response Header.
+	const responseHeaderMatchedFlowSchemaUID = "X-Kubernetes-PF-FlowSchema-UID"
+
+	message := fmt.Sprintf("retries: %d, retry-after: %ds", retries, seconds)
+
+	switch {
+	case resp.StatusCode == http.StatusTooManyRequests:
+		// it is server-side throttling from priority and fairness
+		flowSchemaUID := resp.Header.Get(responseHeaderMatchedFlowSchemaUID)
+		return fmt.Sprintf("%s - retry-reason: due to server-side throttling, FlowSchema UID: %q", message, flowSchemaUID)
+	case err != nil:
+		// it's a retryable error
+		return fmt.Sprintf("%s - retry-reason: due to retryable error, error: %v", message, err)
+	default:
+		return fmt.Sprintf("%s - retry-reason: %d", message, resp.StatusCode)
+	}
+}
+
+func readAndCloseResponseBody(resp *http.Response) {
+	if resp == nil {
+		return
+	}
+
+	// Ensure the response body is fully read and closed
+	// before we reconnect, so that we reuse the same TCP
+	// connection.
+	const maxBodySlurpSize = 2 << 10
+	defer resp.Body.Close()
+
+	if resp.ContentLength <= maxBodySlurpSize {
+		io.Copy(ioutil.Discard, &io.LimitedReader{R: resp.Body, N: maxBodySlurpSize})
+	}
+}
+
+func retryAfterResponse() *http.Response {
+	return &http.Response{
+		StatusCode: http.StatusInternalServerError,
+		Header:     http.Header{"Retry-After": []string{"1"}},
+	}
+}
diff --git a/rest/with_retry_test.go b/rest/with_retry_test.go
new file mode 100644
index 00000000..127746c9
--- /dev/null
+++ b/rest/with_retry_test.go
@@ -0,0 +1,230 @@
+/*
+Copyright 2021 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package rest
+
+import (
+	"errors"
+	"net/http"
+	"reflect"
+	"testing"
+	"time"
+
+	"github.com/google/go-cmp/cmp"
+)
+
+var alwaysRetryError = IsRetryableErrorFunc(func(_ *http.Request, _ error) bool {
+	return true
+})
+
+var neverRetryError = IsRetryableErrorFunc(func(_ *http.Request, _ error) bool {
+	return false
+})
+
+func TestNextRetry(t *testing.T) {
+	fakeError := errors.New("fake error")
+	tests := []struct {
+		name               string
+		attempts           int
+		maxRetries         int
+		request            *http.Request
+		response           *http.Response
+		err                error
+		retryableErrFunc   IsRetryableErrorFunc
+		retryExpected      []bool
+		retryAfterExpected []*RetryAfter
+	}{
+		{
+			name:               "bad input, response and err are nil",
+			maxRetries:         2,
+			attempts:           1,
+			request:            &http.Request{},
+			response:           nil,
+			err:                nil,
+			retryExpected:      []bool{false},
+			retryAfterExpected: []*RetryAfter{nil},
+		},
+		{
+			name:          "zero maximum retry",
+			maxRetries:    0,
+			attempts:      1,
+			request:       &http.Request{},
+			response:      retryAfterResponse(),
+			err:           nil,
+			retryExpected: []bool{false},
+			retryAfterExpected: []*RetryAfter{
+				{
+					Attempt: 1,
+				},
+			},
+		},
+		{
+			name:       "server returned a retryable error",
+			maxRetries: 3,
+			attempts:   1,
+			request:    &http.Request{},
+			response:   nil,
+			err:        fakeError,
+			retryableErrFunc: func(_ *http.Request, err error) bool {
+				if err == fakeError {
+					return true
+				}
+				return false
+			},
+			retryExpected: []bool{true},
+			retryAfterExpected: []*RetryAfter{
+				{
+					Attempt: 1,
+					Wait:    time.Second,
+					Reason:  "retries: 1, retry-after: 1s - retry-reason: due to retryable error, error: fake error",
+				},
+			},
+		},
+		{
+			name:       "server returned a retryable HTTP 429 response",
+			maxRetries: 3,
+			attempts:   1,
+			request:    &http.Request{},
+			response: &http.Response{
+				StatusCode: http.StatusTooManyRequests,
+				Header: http.Header{
+					"Retry-After":                    []string{"2"},
+					"X-Kubernetes-Pf-Flowschema-Uid": []string{"fs-1"},
+				},
+			},
+			err:           nil,
+			retryExpected: []bool{true},
+			retryAfterExpected: []*RetryAfter{
+				{
+					Attempt: 1,
+					Wait:    2 * time.Second,
+					Reason:  `retries: 1, retry-after: 2s - retry-reason: due to server-side throttling, FlowSchema UID: "fs-1"`,
+				},
+			},
+		},
+		{
+			name:       "server returned a retryable HTTP 5xx response",
+			maxRetries: 3,
+			attempts:   1,
+			request:    &http.Request{},
+			response: &http.Response{
+				StatusCode: http.StatusServiceUnavailable,
+				Header: http.Header{
+					"Retry-After": []string{"3"},
+				},
+			},
+			err:           nil,
+			retryExpected: []bool{true},
+			retryAfterExpected: []*RetryAfter{
+				{
+					Attempt: 1,
+					Wait:    3 * time.Second,
+					Reason:  "retries: 1, retry-after: 3s - retry-reason: 503",
+				},
+			},
+		},
+		{
+			name:       "server returned an HTTP 429 response without a Retry-After header",
+			maxRetries: 1,
+			attempts:   1,
+			request:    &http.Request{},
+			response: &http.Response{
+				StatusCode: http.StatusTooManyRequests,
+				Header:     http.Header{},
+			},
+			err:           nil,
+			retryExpected: []bool{false},
+			retryAfterExpected: []*RetryAfter{
+				{
+					Attempt: 1,
+				},
+			},
+		},
+		{
+			name:       "both response and err are set, err takes precedence",
+			maxRetries: 1,
+			attempts:   1,
+			request:    &http.Request{},
+			response:   retryAfterResponse(),
+			err:        fakeError,
+			retryableErrFunc: func(_ *http.Request, err error) bool {
+				if err == fakeError {
+					return true
+				}
+				return false
+			},
+			retryExpected: []bool{true},
+			retryAfterExpected: []*RetryAfter{
+				{
+					Attempt: 1,
+					Wait:    time.Second,
+					Reason:  "retries: 1, retry-after: 1s - retry-reason: due to retryable error, error: fake error",
+				},
+			},
+		},
+		{
+			name:       "all retries are exhausted",
+			maxRetries: 3,
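+			// NextRetry is invoked once more than maxRetries allows;
+			// the fourth call must report exhaustion.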
attempts: 4, + request: &http.Request{}, + response: nil, + err: fakeError, + retryableErrFunc: alwaysRetryError, + retryExpected: []bool{true, true, true, false}, + retryAfterExpected: []*RetryAfter{ + { + Attempt: 1, + Wait: time.Second, + Reason: "retries: 1, retry-after: 1s - retry-reason: due to retryable error, error: fake error", + }, + { + Attempt: 2, + Wait: time.Second, + Reason: "retries: 2, retry-after: 1s - retry-reason: due to retryable error, error: fake error", + }, + { + Attempt: 3, + Wait: time.Second, + Reason: "retries: 3, retry-after: 1s - retry-reason: due to retryable error, error: fake error", + }, + { + Attempt: 4, + }, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + r := &withRetry{maxRetries: test.maxRetries} + + retryGot := make([]bool, 0) + retryAfterGot := make([]*RetryAfter, 0) + for i := 0; i < test.attempts; i++ { + retryAfter, retry := r.NextRetry(test.request, test.response, test.err, test.retryableErrFunc) + retryGot = append(retryGot, retry) + retryAfterGot = append(retryAfterGot, retryAfter) + } + + if !reflect.DeepEqual(test.retryExpected, retryGot) { + t.Errorf("Expected retry: %t, but got: %t", test.retryExpected, retryGot) + } + if !reflect.DeepEqual(test.retryAfterExpected, retryAfterGot) { + t.Errorf("Expected retry-after parameters to match, but got: %s", cmp.Diff(test.retryAfterExpected, retryAfterGot)) + } + }) + } +}
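A minimal sketch of how the new pieces compose (illustrative, not part of the patch): NextRetry decides whether another attempt is allowed and returns the RetryAfter bookkeeping for it, while BeforeNextRetry honors the server-requested wait and rewinds a seekable request body before the caller re-sends. The snippet below drives withRetry by hand against a synthetic 429 response; it assumes it compiles inside the rest package (withRetry and retryAfterResponse are unexported), and the name exampleNextRetryLoop is invented for this sketch; the real call site is Request.request in rest/request.go above.

package rest

import (
	"context"
	"fmt"
	"net/http"
)

// exampleNextRetryLoop is a hypothetical, illustration-only function showing
// the contract between NextRetry and BeforeNextRetry. It loops against a
// canned 429 response instead of a live server.
func exampleNextRetryLoop() {
	retry := &withRetry{maxRetries: 2}
	req, _ := http.NewRequest("GET", "https://example.invalid/apis", nil)

	// A synthetic server-side throttling response carrying a Retry-After hint.
	resp := &http.Response{
		StatusCode: http.StatusTooManyRequests,
		Header:     http.Header{"Retry-After": []string{"1"}},
	}

	// No transport error occurs in this sketch, so the retryable-error
	// predicate never fires.
	isRetryable := IsRetryableErrorFunc(func(*http.Request, error) bool { return false })

	for {
		retryAfter, ok := retry.NextRetry(req, resp, nil, isRetryable)
		if !ok {
			// The third call exceeds maxRetries of 2 and reports exhaustion.
			fmt.Printf("attempt %d: retries exhausted\n", retryAfter.Attempt)
			return
		}
		fmt.Printf("attempt %d: %s\n", retryAfter.Attempt, retryAfter.Reason)

		// A nil backoff and a nil body are tolerated by BeforeNextRetry; the
		// real call site passes r.backoff and r.body so that the Retry-After
		// wait is honored and the request body is rewound before re-sending.
		if err := retry.BeforeNextRetry(context.Background(), nil, retryAfter, req.URL.String(), nil); err != nil {
			return
		}
		// The real code would re-send the request here; this sketch just
		// loops against the same canned response.
	}
}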