Add Context() to enable per-request cancellation

This commit is contained in:
Kris 2016-12-08 14:20:22 -08:00
parent 79f497bca7
commit 8070548ebe
2 changed files with 49 additions and 4 deletions

View File

@ -18,6 +18,7 @@ package restclient
import (
"bytes"
"context"
"encoding/hex"
"fmt"
"io"
@ -105,16 +106,14 @@ type Request struct {
resource string
resourceName string
subresource string
selector labels.Selector
timeout time.Duration
// output
err error
body io.Reader
// The constructed request and the response
req *http.Request
resp *http.Response
// This is only used for per-request timeouts, deadlines, and cancellations.
ctx context.Context
backoffMgr BackoffManager
throttle flowcontrol.RateLimiter
@ -566,6 +565,13 @@ func (r *Request) Body(obj interface{}) *Request {
return r
}
// Context adds a context to the request. Contexts are only used for
// timeouts, deadlines, and cancellations.
func (r *Request) Context(ctx context.Context) *Request {
r.ctx = ctx
return r
}
// URL returns the current working URL.
func (r *Request) URL() *url.URL {
p := r.pathPrefix
@ -651,6 +657,9 @@ func (r *Request) Watch() (watch.Interface, error) {
if err != nil {
return nil, err
}
if r.ctx != nil {
req = req.WithContext(r.ctx)
}
req.Header = r.headers
client := r.client
if client == nil {
@ -720,6 +729,9 @@ func (r *Request) Stream() (io.ReadCloser, error) {
if err != nil {
return nil, err
}
if r.ctx != nil {
req = req.WithContext(r.ctx)
}
req.Header = r.headers
client := r.client
if client == nil {
@ -794,6 +806,9 @@ func (r *Request) request(fn func(*http.Request, *http.Response)) error {
if err != nil {
return err
}
if r.ctx != nil {
req = req.WithContext(r.ctx)
}
req.Header = r.headers
r.backoffMgr.Sleep(r.backoffMgr.CalculateBackoff(r.URL()))

View File

@ -18,6 +18,7 @@ package restclient
import (
"bytes"
"context"
"errors"
"fmt"
"io"
@ -1621,3 +1622,32 @@ func testRESTClient(t testing.TB, srv *httptest.Server) *RESTClient {
}
return client
}
func TestDoContext(t *testing.T) {
receivedCh := make(chan struct{})
block := make(chan struct{})
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
close(receivedCh)
<-block
w.WriteHeader(http.StatusOK)
}))
defer testServer.Close()
defer close(block)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() {
<-receivedCh
cancel()
}()
c := testRESTClient(t, testServer)
_, err := c.Verb("GET").
Context(ctx).
Prefix("foo").
DoRaw()
if err == nil {
t.Fatal("Expected context cancellation error")
}
}