diff --git a/rest/config_test.go b/rest/config_test.go index e832b323..fa58f087 100644 --- a/rest/config_test.go +++ b/rest/config_test.go @@ -172,6 +172,10 @@ func (t *fakeLimiter) QPS() float32 { return t.FakeQPS } +func (t *fakeLimiter) Wait(ctx context.Context) error { + return nil +} + func (t *fakeLimiter) Stop() {} func (t *fakeLimiter) Accept() {} diff --git a/rest/request.go b/rest/request.go index 0570615f..0bdb0b55 100644 --- a/rest/request.go +++ b/rest/request.go @@ -521,14 +521,24 @@ func (r Request) finalURLTemplate() url.URL { return *url } -func (r *Request) tryThrottle() { +func (r *Request) tryThrottle() error { + if r.throttle == nil { + return nil + } + now := time.Now() - if r.throttle != nil { + var err error + if r.ctx != nil { + err = r.throttle.Wait(r.ctx) + } else { r.throttle.Accept() } + if latency := time.Since(now); latency > longThrottleLatency { klog.V(4).Infof("Throttling request took %v, request: %s:%s", latency, r.verb, r.URL().String()) } + + return err } // Watch attempts to begin watching the requested location. @@ -630,7 +640,9 @@ func (r *Request) Stream() (io.ReadCloser, error) { return nil, r.err } - r.tryThrottle() + if err := r.tryThrottle(); err != nil { + return nil, err + } url := r.URL().String() req, err := http.NewRequest(r.verb, url, nil) @@ -732,7 +744,9 @@ func (r *Request) request(fn func(*http.Request, *http.Response)) error { // We are retrying the request that we already send to apiserver // at least once before. // This request should also be throttled with the client-internal throttler. - r.tryThrottle() + if err := r.tryThrottle(); err != nil { + return err + } } resp, err := client.Do(req) updateURLMetrics(r, resp, err) @@ -803,7 +817,9 @@ func (r *Request) request(fn func(*http.Request, *http.Response)) error { // * If the server responds with a status: *errors.StatusError or *errors.UnexpectedObjectError // * http.Client.Do errors are returned directly. func (r *Request) Do() Result { - r.tryThrottle() + if err := r.tryThrottle(); err != nil { + return Result{err: err} + } var result Result err := r.request(func(req *http.Request, resp *http.Response) { @@ -817,7 +833,9 @@ func (r *Request) Do() Result { // DoRaw executes the request but does not process the response body. func (r *Request) DoRaw() ([]byte, error) { - r.tryThrottle() + if err := r.tryThrottle(); err != nil { + return nil, err + } var result Result err := r.request(func(req *http.Request, resp *http.Response) { diff --git a/util/flowcontrol/throttle.go b/util/flowcontrol/throttle.go index e671c044..ffd912c5 100644 --- a/util/flowcontrol/throttle.go +++ b/util/flowcontrol/throttle.go @@ -17,6 +17,8 @@ limitations under the License. package flowcontrol import ( + "context" + "errors" "sync" "time" @@ -33,6 +35,8 @@ type RateLimiter interface { Stop() // QPS returns QPS of this rate limiter QPS() float32 + // Wait returns nil if a token is taken before the Context is done. + Wait(ctx context.Context) error } type tokenBucketRateLimiter struct { @@ -98,6 +102,10 @@ func (t *tokenBucketRateLimiter) QPS() float32 { return t.qps } +func (t *tokenBucketRateLimiter) Wait(ctx context.Context) error { + return t.limiter.Wait(ctx) +} + type fakeAlwaysRateLimiter struct{} func NewFakeAlwaysRateLimiter() RateLimiter { @@ -116,6 +124,10 @@ func (t *fakeAlwaysRateLimiter) QPS() float32 { return 1 } +func (t *fakeAlwaysRateLimiter) Wait(ctx context.Context) error { + return nil +} + type fakeNeverRateLimiter struct { wg sync.WaitGroup } @@ -141,3 +153,7 @@ func (t *fakeNeverRateLimiter) Accept() { func (t *fakeNeverRateLimiter) QPS() float32 { return 1 } + +func (t *fakeNeverRateLimiter) Wait(ctx context.Context) error { + return errors.New("can not be accept") +} diff --git a/util/flowcontrol/throttle_test.go b/util/flowcontrol/throttle_test.go index 99cf64d6..e48ad51b 100644 --- a/util/flowcontrol/throttle_test.go +++ b/util/flowcontrol/throttle_test.go @@ -17,6 +17,8 @@ limitations under the License. package flowcontrol import ( + "context" + "fmt" "sync" "testing" "time" @@ -151,3 +153,21 @@ func TestNeverFake(t *testing.T) { t.Error("Stop should make Accept unblock in NeverFake.") } } + +func TestWait(t *testing.T) { + r := NewTokenBucketRateLimiter(0.0001, 1) + + ctx, cancelFn := context.WithTimeout(context.Background(), time.Second) + defer cancelFn() + if err := r.Wait(ctx); err != nil { + t.Errorf("unexpected wait failed, err: %v", err) + } + + ctx2, cancelFn2 := context.WithTimeout(context.Background(), time.Second) + defer cancelFn2() + if err := r.Wait(ctx2); err == nil { + t.Errorf("unexpected wait success") + } else { + t.Log(fmt.Sprintf("wait err: %v", err)) + } +}