From fdb7c93a9797482822f95326cd2c24e9100d00b3 Mon Sep 17 00:00:00 2001 From: Jordan Liggitt Date: Tue, 16 Aug 2016 23:33:13 -0400 Subject: [PATCH] Close websocket watch when client closes --- pkg/apiserver/watch.go | 10 ++++- pkg/apiserver/watch_test.go | 73 +++++++++++++++++++++++++++++++++++++ 2 files changed, 82 insertions(+), 1 deletion(-) diff --git a/pkg/apiserver/watch.go b/pkg/apiserver/watch.go index fb57f33119d..2aa5409a193 100755 --- a/pkg/apiserver/watch.go +++ b/pkg/apiserver/watch.go @@ -216,7 +216,15 @@ func (s *WatchServer) ServeHTTP(w http.ResponseWriter, req *http.Request) { func (s *WatchServer) HandleWS(ws *websocket.Conn) { defer ws.Close() done := make(chan struct{}) - go wsstream.IgnoreReceives(ws, 0) + + go func() { + defer utilruntime.HandleCrash() + // This blocks until the connection is closed. + // Client should not send anything. + wsstream.IgnoreReceives(ws, 0) + // Once the client closes, we should also close + close(done) + }() var unknown runtime.Unknown internalEvent := &versioned.InternalEvent{} diff --git a/pkg/apiserver/watch_test.go b/pkg/apiserver/watch_test.go index d54f6b8fbd9..7844deccf04 100644 --- a/pkg/apiserver/watch_test.go +++ b/pkg/apiserver/watch_test.go @@ -136,6 +136,79 @@ func TestWatchWebsocket(t *testing.T) { } } +func TestWatchWebsocketClientClose(t *testing.T) { + simpleStorage := &SimpleRESTStorage{} + _ = rest.Watcher(simpleStorage) // Give compile error if this doesn't work. + handler := handle(map[string]rest.Storage{"simples": simpleStorage}) + server := httptest.NewServer(handler) + defer server.Close() + + dest, _ := url.Parse(server.URL) + dest.Scheme = "ws" // Required by websocket, though the server never sees it. + dest.Path = "/" + prefix + "/" + testGroupVersion.Group + "/" + testGroupVersion.Version + "/watch/simples" + dest.RawQuery = "" + + ws, err := websocket.Dial(dest.String(), "", "http://localhost") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + try := func(action watch.EventType, object runtime.Object) { + // Send + simpleStorage.fakeWatch.Action(action, object) + // Test receive + var got watchJSON + err := websocket.JSON.Receive(ws, &got) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + if got.Type != action { + t.Errorf("Unexpected type: %v", got.Type) + } + gotObj, err := runtime.Decode(codec, got.Object) + if err != nil { + t.Fatalf("Decode error: %v\n%v", err, got) + } + if _, err := api.GetReference(gotObj); err != nil { + t.Errorf("Unable to construct reference: %v", err) + } + if e, a := object, gotObj; !reflect.DeepEqual(e, a) { + t.Errorf("Expected %#v, got %#v", e, a) + } + } + + // Send/receive should work + for _, item := range watchTestTable { + try(item.t, item.obj) + } + + // Sending normal data should be ignored + websocket.JSON.Send(ws, map[string]interface{}{"test": "data"}) + + // Send/receive should still work + for _, item := range watchTestTable { + try(item.t, item.obj) + } + + // Client requests a close + ws.Close() + + select { + case data, ok := <-simpleStorage.fakeWatch.ResultChan(): + if ok { + t.Errorf("expected a closed result channel, but got watch result %#v", data) + } + case <-time.After(5 * time.Second): + t.Errorf("watcher did not close when client closed") + } + + var got watchJSON + err = websocket.JSON.Receive(ws, &got) + if err == nil { + t.Errorf("Unexpected non-error") + } +} + func TestWatchRead(t *testing.T) { simpleStorage := &SimpleRESTStorage{} _ = rest.Watcher(simpleStorage) // Give compile error if this doesn't work.