diff --git a/staging/src/k8s.io/apiserver/pkg/endpoints/handlers/watch.go b/staging/src/k8s.io/apiserver/pkg/endpoints/handlers/watch.go index 2433ea2facb..59a265b53c4 100644 --- a/staging/src/k8s.io/apiserver/pkg/endpoints/handlers/watch.go +++ b/staging/src/k8s.io/apiserver/pkg/endpoints/handlers/watch.go @@ -64,6 +64,8 @@ func (w *realTimeoutFactory) TimeoutCh() (<-chan time.Time, func() bool) { // serveWatch will serve a watch response. // TODO: the functionality in this method and in WatchServer.Serve is not cleanly decoupled. func serveWatch(watcher watch.Interface, scope *RequestScope, mediaTypeOptions negotiation.MediaTypeOptions, req *http.Request, w http.ResponseWriter, timeout time.Duration) { + defer watcher.Stop() + options, err := optionsForTransform(mediaTypeOptions, req) if err != nil { scope.err(err, w, req) @@ -201,7 +203,6 @@ func (s *WatchServer) ServeHTTP(w http.ResponseWriter, req *http.Request) { // ensure the connection times out timeoutCh, cleanup := s.TimeoutFactory.TimeoutCh() defer cleanup() - defer s.Watching.Stop() // begin the stream w.Header().Set("Content-Type", s.MediaType) @@ -286,8 +287,6 @@ func (s *WatchServer) HandleWS(ws *websocket.Conn) { streamBuf := &bytes.Buffer{} ch := s.Watching.ResultChan() - defer s.Watching.Stop() - for { select { case <-done: diff --git a/staging/src/k8s.io/apiserver/pkg/endpoints/watch_test.go b/staging/src/k8s.io/apiserver/pkg/endpoints/watch_test.go index 915d793adf1..10ad79f123d 100644 --- a/staging/src/k8s.io/apiserver/pkg/endpoints/watch_test.go +++ b/staging/src/k8s.io/apiserver/pkg/endpoints/watch_test.go @@ -43,6 +43,7 @@ import ( "k8s.io/apimachinery/pkg/watch" example "k8s.io/apiserver/pkg/apis/example" "k8s.io/apiserver/pkg/endpoints/handlers" + "k8s.io/apiserver/pkg/endpoints/handlers/responsewriters" apitesting "k8s.io/apiserver/pkg/endpoints/testing" "k8s.io/apiserver/pkg/registry/rest" "k8s.io/client-go/dynamic" @@ -565,6 +566,21 @@ func (t *fakeTimeoutFactory) TimeoutCh() (<-chan time.Time, func() bool) { } } +// serveWatch will serve a watch response according to the watcher and watchServer. +// Before watchServer.ServeHTTP, an error may occur like k8s.io/apiserver/pkg/endpoints/handlers/watch.go#serveWatch does. +func serveWatch(watcher watch.Interface, watchServer *handlers.WatchServer, preServeErr error) http.HandlerFunc { + return func(w http.ResponseWriter, req *http.Request) { + defer watcher.Stop() + + if preServeErr != nil { + responsewriters.ErrorNegotiated(preServeErr, watchServer.Scope.Serializer, watchServer.Scope.Kind.GroupVersion(), w, req) + return + } + + watchServer.ServeHTTP(w, req) + } +} + func TestWatchHTTPErrors(t *testing.T) { watcher := watch.NewFake() timeoutCh := make(chan time.Time) @@ -590,9 +606,7 @@ func TestWatchHTTPErrors(t *testing.T) { TimeoutFactory: &fakeTimeoutFactory{timeoutCh, done}, } - s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - watchServer.ServeHTTP(w, req) - })) + s := httptest.NewServer(serveWatch(watcher, watchServer, nil)) defer s.Close() // Setup a client @@ -629,6 +643,68 @@ func TestWatchHTTPErrors(t *testing.T) { } } +func TestWatchHTTPErrorsBeforeServe(t *testing.T) { + watcher := watch.NewFake() + timeoutCh := make(chan time.Time) + done := make(chan struct{}) + + info, ok := runtime.SerializerInfoForMediaType(codecs.SupportedMediaTypes(), runtime.ContentTypeJSON) + if !ok || info.StreamSerializer == nil { + t.Fatal(info) + } + serializer := info.StreamSerializer + + // Setup a new watchserver + watchServer := &handlers.WatchServer{ + Scope: &handlers.RequestScope{ + Serializer: runtime.NewSimpleNegotiatedSerializer(info), + Kind: testGroupVersion.WithKind("test"), + }, + Watching: watcher, + + MediaType: "testcase/json", + Framer: serializer.Framer, + Encoder: newCodec, + EmbeddedEncoder: newCodec, + + Fixup: func(obj runtime.Object) runtime.Object { return obj }, + TimeoutFactory: &fakeTimeoutFactory{timeoutCh, done}, + } + + errStatus := errors.NewInternalError(fmt.Errorf("we got an error")) + + s := httptest.NewServer(serveWatch(watcher, watchServer, errStatus)) + defer s.Close() + + // Setup a client + dest, _ := url.Parse(s.URL) + dest.Path = "/" + prefix + "/" + newGroupVersion.Group + "/" + newGroupVersion.Version + "/simple" + dest.RawQuery = "watch=true" + + req, _ := http.NewRequest("GET", dest.String(), nil) + client := http.Client{} + resp, err := client.Do(req) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + // We had already got an error before watch serve started + decoder := json.NewDecoder(resp.Body) + var status *metav1.Status + err = decoder.Decode(&status) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + if status.Kind != "Status" || status.APIVersion != "v1" || status.Code != 500 || status.Status != "Failure" || !strings.Contains(status.Message, "we got an error") { + t.Fatalf("error: %#v", status) + } + + // check for leaks + if !watcher.IsStopped() { + t.Errorf("Leaked watcher goruntine after request done") + } +} + func TestWatchHTTPDynamicClientErrors(t *testing.T) { watcher := watch.NewFake() timeoutCh := make(chan time.Time) @@ -654,9 +730,7 @@ func TestWatchHTTPDynamicClientErrors(t *testing.T) { TimeoutFactory: &fakeTimeoutFactory{timeoutCh, done}, } - s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - watchServer.ServeHTTP(w, req) - })) + s := httptest.NewServer(serveWatch(watcher, watchServer, nil)) defer s.Close() defer s.CloseClientConnections() @@ -699,9 +773,7 @@ func TestWatchHTTPTimeout(t *testing.T) { TimeoutFactory: &fakeTimeoutFactory{timeoutCh, done}, } - s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - watchServer.ServeHTTP(w, req) - })) + s := httptest.NewServer(serveWatch(watcher, watchServer, nil)) defer s.Close() // Setup a client @@ -729,7 +801,7 @@ func TestWatchHTTPTimeout(t *testing.T) { close(timeoutCh) select { case <-done: - if !watcher.Stopped { + if !watcher.IsStopped() { t.Errorf("Leaked watch on timeout") } case <-time.After(wait.ForeverTestTimeout):