diff --git a/go.mod b/go.mod index a2bc4f91..5522226a 100644 --- a/go.mod +++ b/go.mod @@ -25,7 +25,7 @@ require ( golang.org/x/time v0.3.0 google.golang.org/protobuf v1.33.0 k8s.io/api v0.0.0-20240404161350-448db12cecfb - k8s.io/apimachinery v0.0.0-20240404161013-3e7c65a7bc4d + k8s.io/apimachinery v0.0.0-20240405121012-2bbf53022625 k8s.io/klog/v2 v2.110.1 k8s.io/kube-openapi v0.0.0-20231010175941-2dd684a91f00 k8s.io/utils v0.0.0-20230726121419-3b25d923346b diff --git a/go.sum b/go.sum index f617f1d1..dc319d87 100644 --- a/go.sum +++ b/go.sum @@ -155,8 +155,8 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= k8s.io/api v0.0.0-20240404161350-448db12cecfb h1:/uYBtjPF21kf1qLyVVXGQOcFcjjqGowAssqopfbjQB0= k8s.io/api v0.0.0-20240404161350-448db12cecfb/go.mod h1:dPhF9d7Kl/mI8T1SzZBp2/saFzES5J24O1PGAwNajB0= -k8s.io/apimachinery v0.0.0-20240404161013-3e7c65a7bc4d h1:bth65kuQyrmaqeI1unRAE9O4KtHWeblZ8mIZEjpOuc4= -k8s.io/apimachinery v0.0.0-20240404161013-3e7c65a7bc4d/go.mod h1:i3FJVwhvSp/6n8Fl4K97PJEP8C+MM+aoDq4+ZJBf70Y= +k8s.io/apimachinery v0.0.0-20240405121012-2bbf53022625 h1:itYkKr52W3ugNsYoBZwpwcezu/vPxg3GK7XVkX24GAo= +k8s.io/apimachinery v0.0.0-20240405121012-2bbf53022625/go.mod h1:i3FJVwhvSp/6n8Fl4K97PJEP8C+MM+aoDq4+ZJBf70Y= k8s.io/klog/v2 v2.110.1 h1:U/Af64HJf7FcwMcXyKm2RPM22WZzyR7OSpYj5tg3cL0= k8s.io/klog/v2 v2.110.1/go.mod h1:YGtd1984u+GgbuZ7e08/yBuAfKLSO0+uR1Fhi6ExXjo= k8s.io/kube-openapi v0.0.0-20231010175941-2dd684a91f00 h1:aVUu9fTY98ivBPKR9Y5w/AuzbMm96cd3YHRTU83I780= diff --git a/tools/remotecommand/websocket.go b/tools/remotecommand/websocket.go index a60986de..49ef4717 100644 --- a/tools/remotecommand/websocket.go +++ b/tools/remotecommand/websocket.go @@ -187,6 +187,9 @@ type wsStreamCreator struct { // map of stream id to stream; multiple streams read/write the connection streams map[byte]*stream streamsMu sync.Mutex + // setStreamErr holds the error to return to anyone calling setStreams. + // this is populated in closeAllStreamReaders + setStreamErr error } func newWSStreamCreator(conn *gwebsocket.Conn) *wsStreamCreator { @@ -202,10 +205,14 @@ func (c *wsStreamCreator) getStream(id byte) *stream { return c.streams[id] } -func (c *wsStreamCreator) setStream(id byte, s *stream) { +func (c *wsStreamCreator) setStream(id byte, s *stream) error { c.streamsMu.Lock() defer c.streamsMu.Unlock() + if c.setStreamErr != nil { + return c.setStreamErr + } c.streams[id] = s + return nil } // CreateStream uses id from passed headers to create a stream over "c.conn" connection. @@ -228,7 +235,11 @@ func (c *wsStreamCreator) CreateStream(headers http.Header) (httpstream.Stream, connWriteLock: &c.connWriteLock, id: id, } - c.setStream(id, s) + if err := c.setStream(id, s); err != nil { + _ = s.writePipe.Close() + _ = s.readPipe.Close() + return nil, err + } return s, nil } @@ -312,7 +323,7 @@ func (c *wsStreamCreator) readDemuxLoop(bufferSize int, period time.Duration, de } // closeAllStreamReaders closes readers in all streams. -// This unblocks all stream.Read() calls. +// This unblocks all stream.Read() calls, and keeps any future streams from being created. func (c *wsStreamCreator) closeAllStreamReaders(err error) { c.streamsMu.Lock() defer c.streamsMu.Unlock() @@ -320,6 +331,12 @@ func (c *wsStreamCreator) closeAllStreamReaders(err error) { // Closing writePipe unblocks all readPipe.Read() callers and prevents any future writes. _ = s.writePipe.CloseWithError(err) } + // ensure callers to setStreams receive an error after this point + if err != nil { + c.setStreamErr = err + } else { + c.setStreamErr = fmt.Errorf("closed all streams") + } } type stream struct { diff --git a/tools/remotecommand/websocket_test.go b/tools/remotecommand/websocket_test.go index 61df2b77..4a333b0b 100644 --- a/tools/remotecommand/websocket_test.go +++ b/tools/remotecommand/websocket_test.go @@ -817,6 +817,8 @@ func TestWebSocketClient_BadHandshake(t *testing.T) { // TestWebSocketClient_HeartbeatTimeout tests the heartbeat by forcing a // timeout by setting the ping period greater than the deadline. func TestWebSocketClient_HeartbeatTimeout(t *testing.T) { + blockRequestCtx, unblockRequest := context.WithCancel(context.Background()) + defer unblockRequest() // Create fake WebSocket server which blocks. websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { conns, err := webSocketServerStreams(req, w, streamOptionsFromRequest(req)) @@ -824,8 +826,7 @@ func TestWebSocketClient_HeartbeatTimeout(t *testing.T) { t.Fatalf("error on webSocketServerStreams: %v", err) } defer conns.conn.Close() - // Block server; heartbeat timeout (or test timeout) will fire before this returns. - time.Sleep(1 * time.Second) + <-blockRequestCtx.Done() })) defer websocketServer.Close() // Create websocket client connecting to fake server. @@ -840,8 +841,8 @@ func TestWebSocketClient_HeartbeatTimeout(t *testing.T) { } streamExec := exec.(*wsStreamExecutor) // Ping period is greater than the ping deadline, forcing the timeout to fire. - pingPeriod := 20 * time.Millisecond - pingDeadline := 5 * time.Millisecond + pingPeriod := wait.ForeverTestTimeout // this lets the heartbeat deadline expire without renewing it + pingDeadline := time.Second // this gives setup 1 second to establish streams streamExec.heartbeatPeriod = pingPeriod streamExec.heartbeatDeadline = pingDeadline // Send some random data to the websocket server through STDIN. @@ -859,8 +860,7 @@ func TestWebSocketClient_HeartbeatTimeout(t *testing.T) { }() select { - case <-time.After(pingPeriod * 5): - // Give up after about five ping attempts + case <-time.After(wait.ForeverTestTimeout): t.Fatalf("expected heartbeat timeout, got none.") case err := <-errorChan: // Expecting heartbeat timeout error. @@ -1116,6 +1116,14 @@ func TestWebSocketClient_HeartbeatSucceeds(t *testing.T) { wg.Wait() } +func TestLateStreamCreation(t *testing.T) { + c := newWSStreamCreator(nil) + c.closeAllStreamReaders(nil) + if err := c.setStream(0, nil); err == nil { + t.Fatal("expected error adding stream after closeAllStreamReaders") + } +} + func TestWebSocketClient_StreamsAndExpectedErrors(t *testing.T) { // Validate Stream functions. c := newWSStreamCreator(nil)