Signed-off-by: Miloslav Trmač <mitr@redhat.com>
This commit is contained in:
Miloslav Trmač
2024-05-09 21:49:41 +02:00
parent dcf937e170
commit 7649059a0d
474 changed files with 30496 additions and 30923 deletions

View File

@@ -27,10 +27,8 @@ package retryablehttp
import (
"bytes"
"context"
"crypto/x509"
"fmt"
"io"
"io/ioutil"
"log"
"math"
"math/rand"
@@ -63,6 +61,10 @@ var (
// limit the size we consume to respReadLimit.
respReadLimit = int64(4096)
// timeNow sets the function that returns the current time.
// This defaults to time.Now. Changes to this should only be done in tests.
timeNow = time.Now
// A regular expression to match the error returned by net/http when the
// configured number of redirects is exhausted. This error isn't typed
// specifically so we resort to matching on the error string.
@@ -73,6 +75,11 @@ var (
// specifically so we resort to matching on the error string.
schemeErrorRe = regexp.MustCompile(`unsupported protocol scheme`)
// A regular expression to match the error returned by net/http when a
// request header or value is invalid. This error isn't typed
// specifically so we resort to matching on the error string.
invalidHeaderErrorRe = regexp.MustCompile(`invalid header`)
// A regular expression to match the error returned by net/http when the
// TLS certificate is not trusted. This error isn't typed
// specifically so we resort to matching on the error string.
@@ -248,21 +255,19 @@ func getBodyReaderAndContentLength(rawBody interface{}) (ReaderFunc, int64, erro
// deal with it seeking so want it to match here instead of the
// io.ReadSeeker case.
case *bytes.Reader:
buf, err := ioutil.ReadAll(body)
if err != nil {
return nil, 0, err
}
snapshot := *body
bodyReader = func() (io.Reader, error) {
return bytes.NewReader(buf), nil
r := snapshot
return &r, nil
}
contentLength = int64(len(buf))
contentLength = int64(body.Len())
// Compat case
case io.ReadSeeker:
raw := body
bodyReader = func() (io.Reader, error) {
_, err := raw.Seek(0, 0)
return ioutil.NopCloser(raw), err
return io.NopCloser(raw), err
}
if lr, ok := raw.(LenReader); ok {
contentLength = int64(lr.Len())
@@ -270,7 +275,7 @@ func getBodyReaderAndContentLength(rawBody interface{}) (ReaderFunc, int64, erro
// Read all in so we can reset
case io.Reader:
buf, err := ioutil.ReadAll(body)
buf, err := io.ReadAll(body)
if err != nil {
return nil, 0, err
}
@@ -393,6 +398,9 @@ type Backoff func(min, max time.Duration, attemptNum int, resp *http.Response) t
// attempted. If overriding this, be sure to close the body if needed.
type ErrorHandler func(resp *http.Response, err error, numTries int) (*http.Response, error)
// PrepareRetry is called before retry operation. It can be used for example to re-sign the request
type PrepareRetry func(req *http.Request) error
// Client is used to make HTTP requests. It adds additional functionality
// like automatic retries to tolerate minor outages.
type Client struct {
@@ -421,6 +429,9 @@ type Client struct {
// ErrorHandler specifies the custom error handler to use, if any
ErrorHandler ErrorHandler
// PrepareRetry can prepare the request for retry operation, for example re-sign it
PrepareRetry PrepareRetry
loggerInit sync.Once
clientInit sync.Once
}
@@ -494,11 +505,16 @@ func baseRetryPolicy(resp *http.Response, err error) (bool, error) {
return false, v
}
// Don't retry if the error was due to an invalid header.
if invalidHeaderErrorRe.MatchString(v.Error()) {
return false, v
}
// Don't retry if the error was due to TLS cert verification failure.
if notTrustedErrorRe.MatchString(v.Error()) {
return false, v
}
if _, ok := v.Err.(x509.UnknownAuthorityError); ok {
if isCertError(v.Err) {
return false, v
}
}
@@ -535,10 +551,8 @@ func baseRetryPolicy(resp *http.Response, err error) (bool, error) {
func DefaultBackoff(min, max time.Duration, attemptNum int, resp *http.Response) time.Duration {
if resp != nil {
if resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode == http.StatusServiceUnavailable {
if s, ok := resp.Header["Retry-After"]; ok {
if sleep, err := strconv.ParseInt(s[0], 10, 64); err == nil {
return time.Second * time.Duration(sleep)
}
if sleep, ok := parseRetryAfterHeader(resp.Header["Retry-After"]); ok {
return sleep
}
}
}
@@ -551,6 +565,41 @@ func DefaultBackoff(min, max time.Duration, attemptNum int, resp *http.Response)
return sleep
}
// parseRetryAfterHeader parses the Retry-After header and returns the
// delay duration according to the spec: https://httpwg.org/specs/rfc7231.html#header.retry-after
// The bool returned will be true if the header was successfully parsed.
// Otherwise, the header was either not present, or was not parseable according to the spec.
//
// Retry-After headers come in two flavors: Seconds or HTTP-Date
//
// Examples:
// * Retry-After: Fri, 31 Dec 1999 23:59:59 GMT
// * Retry-After: 120
func parseRetryAfterHeader(headers []string) (time.Duration, bool) {
if len(headers) == 0 || headers[0] == "" {
return 0, false
}
header := headers[0]
// Retry-After: 120
if sleep, err := strconv.ParseInt(header, 10, 64); err == nil {
if sleep < 0 { // a negative sleep doesn't make sense
return 0, false
}
return time.Second * time.Duration(sleep), true
}
// Retry-After: Fri, 31 Dec 1999 23:59:59 GMT
retryTime, err := time.Parse(time.RFC1123, header)
if err != nil {
return 0, false
}
if until := retryTime.Sub(timeNow()); until > 0 {
return until, true
}
// date is in the past
return 0, true
}
// LinearJitterBackoff provides a callback for Client.Backoff which will
// perform linear backoff based on the attempt number and with jitter to
// prevent a thundering herd.
@@ -578,13 +627,13 @@ func LinearJitterBackoff(min, max time.Duration, attemptNum int, resp *http.Resp
}
// Seed rand; doing this every time is fine
rand := rand.New(rand.NewSource(int64(time.Now().Nanosecond())))
source := rand.New(rand.NewSource(int64(time.Now().Nanosecond())))
// Pick a random number that lies somewhere between the min and max and
// multiply by the attemptNum. attemptNum starts at zero so we always
// increment here. We first get a random percentage, then apply that to the
// difference between min and max, and add to min.
jitter := rand.Float64() * float64(max-min)
jitter := source.Float64() * float64(max-min)
jitterMin := int64(jitter) + int64(min)
return time.Duration(jitterMin * int64(attemptNum))
}
@@ -618,10 +667,10 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
var resp *http.Response
var attempt int
var shouldRetry bool
var doErr, respErr, checkErr error
var doErr, respErr, checkErr, prepareErr error
for i := 0; ; i++ {
doErr, respErr = nil, nil
doErr, respErr, prepareErr = nil, nil, nil
attempt++
// Always rewind the request body when non-nil.
@@ -634,7 +683,7 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
if c, ok := body.(io.ReadCloser); ok {
req.Body = c
} else {
req.Body = ioutil.NopCloser(body)
req.Body = io.NopCloser(body)
}
}
@@ -728,17 +777,26 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
// without racing against the closeBody call in persistConn.writeLoop.
httpreq := *req.Request
req.Request = &httpreq
if c.PrepareRetry != nil {
if err := c.PrepareRetry(req.Request); err != nil {
prepareErr = err
break
}
}
}
// this is the closest we have to success criteria
if doErr == nil && respErr == nil && checkErr == nil && !shouldRetry {
if doErr == nil && respErr == nil && checkErr == nil && prepareErr == nil && !shouldRetry {
return resp, nil
}
defer c.HTTPClient.CloseIdleConnections()
var err error
if checkErr != nil {
if prepareErr != nil {
err = prepareErr
} else if checkErr != nil {
err = checkErr
} else if respErr != nil {
err = respErr
@@ -770,7 +828,7 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
// Try to read the response body so we can reuse this connection.
func (c *Client) drainBody(body io.ReadCloser) {
defer body.Close()
_, err := io.Copy(ioutil.Discard, io.LimitReader(body, respReadLimit))
_, err := io.Copy(io.Discard, io.LimitReader(body, respReadLimit))
if err != nil {
if c.logger() != nil {
switch v := c.logger().(type) {