diff --git a/staging/src/k8s.io/apimachinery/pkg/util/httpstream/wsstream/conn.go b/staging/src/k8s.io/apimachinery/pkg/util/httpstream/wsstream/conn.go index 7cfdd063217..8a741936a3d 100644 --- a/staging/src/k8s.io/apimachinery/pkg/util/httpstream/wsstream/conn.go +++ b/staging/src/k8s.io/apimachinery/pkg/util/httpstream/wsstream/conn.go @@ -344,7 +344,7 @@ func (conn *Conn) handle(ws *websocket.Conn) { continue } if _, err := conn.channels[channel].DataFromSocket(data); err != nil { - klog.Errorf("Unable to write frame to %d: %v\n%s", channel, err, string(data)) + klog.Errorf("Unable to write frame (%d bytes) to %d: %v", len(data), channel, err) continue } } diff --git a/staging/src/k8s.io/client-go/tools/remotecommand/websocket.go b/staging/src/k8s.io/client-go/tools/remotecommand/websocket.go index a60986decca..49ef4717cd9 100644 --- a/staging/src/k8s.io/client-go/tools/remotecommand/websocket.go +++ b/staging/src/k8s.io/client-go/tools/remotecommand/websocket.go @@ -187,6 +187,9 @@ type wsStreamCreator struct { // map of stream id to stream; multiple streams read/write the connection streams map[byte]*stream streamsMu sync.Mutex + // setStreamErr holds the error to return to anyone calling setStreams. + // this is populated in closeAllStreamReaders + setStreamErr error } func newWSStreamCreator(conn *gwebsocket.Conn) *wsStreamCreator { @@ -202,10 +205,14 @@ func (c *wsStreamCreator) getStream(id byte) *stream { return c.streams[id] } -func (c *wsStreamCreator) setStream(id byte, s *stream) { +func (c *wsStreamCreator) setStream(id byte, s *stream) error { c.streamsMu.Lock() defer c.streamsMu.Unlock() + if c.setStreamErr != nil { + return c.setStreamErr + } c.streams[id] = s + return nil } // CreateStream uses id from passed headers to create a stream over "c.conn" connection. @@ -228,7 +235,11 @@ func (c *wsStreamCreator) CreateStream(headers http.Header) (httpstream.Stream, connWriteLock: &c.connWriteLock, id: id, } - c.setStream(id, s) + if err := c.setStream(id, s); err != nil { + _ = s.writePipe.Close() + _ = s.readPipe.Close() + return nil, err + } return s, nil } @@ -312,7 +323,7 @@ func (c *wsStreamCreator) readDemuxLoop(bufferSize int, period time.Duration, de } // closeAllStreamReaders closes readers in all streams. -// This unblocks all stream.Read() calls. +// This unblocks all stream.Read() calls, and keeps any future streams from being created. func (c *wsStreamCreator) closeAllStreamReaders(err error) { c.streamsMu.Lock() defer c.streamsMu.Unlock() @@ -320,6 +331,12 @@ func (c *wsStreamCreator) closeAllStreamReaders(err error) { // Closing writePipe unblocks all readPipe.Read() callers and prevents any future writes. _ = s.writePipe.CloseWithError(err) } + // ensure callers to setStreams receive an error after this point + if err != nil { + c.setStreamErr = err + } else { + c.setStreamErr = fmt.Errorf("closed all streams") + } } type stream struct { diff --git a/staging/src/k8s.io/client-go/tools/remotecommand/websocket_test.go b/staging/src/k8s.io/client-go/tools/remotecommand/websocket_test.go index 61df2b77a4c..4a333b0b24e 100644 --- a/staging/src/k8s.io/client-go/tools/remotecommand/websocket_test.go +++ b/staging/src/k8s.io/client-go/tools/remotecommand/websocket_test.go @@ -817,6 +817,8 @@ func TestWebSocketClient_BadHandshake(t *testing.T) { // TestWebSocketClient_HeartbeatTimeout tests the heartbeat by forcing a // timeout by setting the ping period greater than the deadline. func TestWebSocketClient_HeartbeatTimeout(t *testing.T) { + blockRequestCtx, unblockRequest := context.WithCancel(context.Background()) + defer unblockRequest() // Create fake WebSocket server which blocks. websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { conns, err := webSocketServerStreams(req, w, streamOptionsFromRequest(req)) @@ -824,8 +826,7 @@ func TestWebSocketClient_HeartbeatTimeout(t *testing.T) { t.Fatalf("error on webSocketServerStreams: %v", err) } defer conns.conn.Close() - // Block server; heartbeat timeout (or test timeout) will fire before this returns. - time.Sleep(1 * time.Second) + <-blockRequestCtx.Done() })) defer websocketServer.Close() // Create websocket client connecting to fake server. @@ -840,8 +841,8 @@ func TestWebSocketClient_HeartbeatTimeout(t *testing.T) { } streamExec := exec.(*wsStreamExecutor) // Ping period is greater than the ping deadline, forcing the timeout to fire. - pingPeriod := 20 * time.Millisecond - pingDeadline := 5 * time.Millisecond + pingPeriod := wait.ForeverTestTimeout // this lets the heartbeat deadline expire without renewing it + pingDeadline := time.Second // this gives setup 1 second to establish streams streamExec.heartbeatPeriod = pingPeriod streamExec.heartbeatDeadline = pingDeadline // Send some random data to the websocket server through STDIN. @@ -859,8 +860,7 @@ func TestWebSocketClient_HeartbeatTimeout(t *testing.T) { }() select { - case <-time.After(pingPeriod * 5): - // Give up after about five ping attempts + case <-time.After(wait.ForeverTestTimeout): t.Fatalf("expected heartbeat timeout, got none.") case err := <-errorChan: // Expecting heartbeat timeout error. @@ -1116,6 +1116,14 @@ func TestWebSocketClient_HeartbeatSucceeds(t *testing.T) { wg.Wait() } +func TestLateStreamCreation(t *testing.T) { + c := newWSStreamCreator(nil) + c.closeAllStreamReaders(nil) + if err := c.setStream(0, nil); err == nil { + t.Fatal("expected error adding stream after closeAllStreamReaders") + } +} + func TestWebSocketClient_StreamsAndExpectedErrors(t *testing.T) { // Validate Stream functions. c := newWSStreamCreator(nil)