diff --git a/staging/src/k8s.io/apiserver/pkg/endpoints/handlers/watch.go b/staging/src/k8s.io/apiserver/pkg/endpoints/handlers/watch.go index 2433ea2facb..fb201ba161e 100644 --- a/staging/src/k8s.io/apiserver/pkg/endpoints/handlers/watch.go +++ b/staging/src/k8s.io/apiserver/pkg/endpoints/handlers/watch.go @@ -173,13 +173,6 @@ func (s *WatchServer) ServeHTTP(w http.ResponseWriter, req *http.Request) { return } - cn, ok := w.(http.CloseNotifier) - if !ok { - err := fmt.Errorf("unable to start watch - can't get http.CloseNotifier: %#v", w) - utilruntime.HandleError(err) - s.Scope.err(errors.NewInternalError(err), w, req) - return - } flusher, ok := w.(http.Flusher) if !ok { err := fmt.Errorf("unable to start watch - can't get http.Flusher: %#v", w) @@ -214,9 +207,11 @@ func (s *WatchServer) ServeHTTP(w http.ResponseWriter, req *http.Request) { outEvent := &metav1.WatchEvent{} buf := &bytes.Buffer{} ch := s.Watching.ResultChan() + done := req.Context().Done() + for { select { - case <-cn.CloseNotify(): + case <-done: return case <-timeoutCh: return diff --git a/staging/src/k8s.io/apiserver/pkg/endpoints/watch_test.go b/staging/src/k8s.io/apiserver/pkg/endpoints/watch_test.go index 915d793adf1..5629089b55c 100644 --- a/staging/src/k8s.io/apiserver/pkg/endpoints/watch_test.go +++ b/staging/src/k8s.io/apiserver/pkg/endpoints/watch_test.go @@ -212,6 +212,48 @@ func TestWatchWebsocketClientClose(t *testing.T) { } } +func TestWatchClientClose(t *testing.T) { + simpleStorage := &SimpleRESTStorage{} + _ = rest.Watcher(simpleStorage) // Give compile error if this doesn't work. + handler := handle(map[string]rest.Storage{"simples": simpleStorage}) + server := httptest.NewServer(handler) + defer server.Close() + + dest, _ := url.Parse(server.URL) + dest.Path = "/" + prefix + "/" + testGroupVersion.Group + "/" + testGroupVersion.Version + "/simples" + dest.RawQuery = "watch=1" + + request, err := http.NewRequest("GET", dest.String(), nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + request.Header.Add("Accept", "application/json") + + response, err := http.DefaultClient.Do(request) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if response.StatusCode != http.StatusOK { + b, _ := ioutil.ReadAll(response.Body) + t.Fatalf("Unexpected response: %#v\n%s", response, string(b)) + } + + // Close response to cause a cancel on the server + if err := response.Body.Close(); err != nil { + t.Fatalf("Unexpected close client err: %v", err) + } + + select { + case data, ok := <-simpleStorage.fakeWatch.ResultChan(): + if ok { + t.Errorf("expected a closed result channel, but got watch result %#v", data) + } + case <-time.After(5 * time.Second): + t.Errorf("watcher did not close when client closed") + } +} + func TestWatchRead(t *testing.T) { simpleStorage := &SimpleRESTStorage{} _ = rest.Watcher(simpleStorage) // Give compile error if this doesn't work.