diff --git a/pkg/client/restclient/request.go b/pkg/client/restclient/request.go index d5c4384fc97..93c18ce9ead 100644 --- a/pkg/client/restclient/request.go +++ b/pkg/client/restclient/request.go @@ -18,6 +18,7 @@ package restclient import ( "bytes" + "context" "encoding/hex" "fmt" "io" @@ -105,16 +106,14 @@ type Request struct { resource string resourceName string subresource string - selector labels.Selector timeout time.Duration // output err error body io.Reader - // The constructed request and the response - req *http.Request - resp *http.Response + // This is only used for per-request timeouts, deadlines, and cancellations. + ctx context.Context backoffMgr BackoffManager throttle flowcontrol.RateLimiter @@ -566,6 +565,13 @@ func (r *Request) Body(obj interface{}) *Request { return r } +// Context adds a context to the request. Contexts are only used for +// timeouts, deadlines, and cancellations. +func (r *Request) Context(ctx context.Context) *Request { + r.ctx = ctx + return r +} + // URL returns the current working URL. func (r *Request) URL() *url.URL { p := r.pathPrefix @@ -651,6 +657,9 @@ func (r *Request) Watch() (watch.Interface, error) { if err != nil { return nil, err } + if r.ctx != nil { + req = req.WithContext(r.ctx) + } req.Header = r.headers client := r.client if client == nil { @@ -720,6 +729,9 @@ func (r *Request) Stream() (io.ReadCloser, error) { if err != nil { return nil, err } + if r.ctx != nil { + req = req.WithContext(r.ctx) + } req.Header = r.headers client := r.client if client == nil { @@ -794,6 +806,9 @@ func (r *Request) request(fn func(*http.Request, *http.Response)) error { if err != nil { return err } + if r.ctx != nil { + req = req.WithContext(r.ctx) + } req.Header = r.headers r.backoffMgr.Sleep(r.backoffMgr.CalculateBackoff(r.URL())) diff --git a/pkg/client/restclient/request_test.go b/pkg/client/restclient/request_test.go index 4493537a88c..8f8ddcd154f 100755 --- a/pkg/client/restclient/request_test.go +++ b/pkg/client/restclient/request_test.go @@ -18,6 +18,7 @@ package restclient import ( "bytes" + "context" "errors" "fmt" "io" @@ -1621,3 +1622,32 @@ func testRESTClient(t testing.TB, srv *httptest.Server) *RESTClient { } return client } + +func TestDoContext(t *testing.T) { + receivedCh := make(chan struct{}) + block := make(chan struct{}) + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + close(receivedCh) + <-block + w.WriteHeader(http.StatusOK) + })) + defer testServer.Close() + defer close(block) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { + <-receivedCh + cancel() + }() + + c := testRESTClient(t, testServer) + _, err := c.Verb("GET"). + Context(ctx). + Prefix("foo"). + DoRaw() + if err == nil { + t.Fatal("Expected context cancellation error") + } +}