Merge pull request #123598 from liggitt/remotecommand-cleanup

Remotecommand test flake cleanup
This commit is contained in:
Kubernetes Prow Robot 2024-02-29 13:40:48 -08:00 committed by GitHub
commit 0d50a398df
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 35 additions and 10 deletions

View File

@ -344,7 +344,7 @@ func (conn *Conn) handle(ws *websocket.Conn) {
continue continue
} }
if _, err := conn.channels[channel].DataFromSocket(data); err != nil { if _, err := conn.channels[channel].DataFromSocket(data); err != nil {
klog.Errorf("Unable to write frame to %d: %v\n%s", channel, err, string(data)) klog.Errorf("Unable to write frame (%d bytes) to %d: %v", len(data), channel, err)
continue continue
} }
} }

View File

@ -187,6 +187,9 @@ type wsStreamCreator struct {
// map of stream id to stream; multiple streams read/write the connection // map of stream id to stream; multiple streams read/write the connection
streams map[byte]*stream streams map[byte]*stream
streamsMu sync.Mutex streamsMu sync.Mutex
// setStreamErr holds the error to return to anyone calling setStreams.
// this is populated in closeAllStreamReaders
setStreamErr error
} }
func newWSStreamCreator(conn *gwebsocket.Conn) *wsStreamCreator { func newWSStreamCreator(conn *gwebsocket.Conn) *wsStreamCreator {
@ -202,10 +205,14 @@ func (c *wsStreamCreator) getStream(id byte) *stream {
return c.streams[id] return c.streams[id]
} }
func (c *wsStreamCreator) setStream(id byte, s *stream) { func (c *wsStreamCreator) setStream(id byte, s *stream) error {
c.streamsMu.Lock() c.streamsMu.Lock()
defer c.streamsMu.Unlock() defer c.streamsMu.Unlock()
if c.setStreamErr != nil {
return c.setStreamErr
}
c.streams[id] = s c.streams[id] = s
return nil
} }
// CreateStream uses id from passed headers to create a stream over "c.conn" connection. // CreateStream uses id from passed headers to create a stream over "c.conn" connection.
@ -228,7 +235,11 @@ func (c *wsStreamCreator) CreateStream(headers http.Header) (httpstream.Stream,
connWriteLock: &c.connWriteLock, connWriteLock: &c.connWriteLock,
id: id, id: id,
} }
c.setStream(id, s) if err := c.setStream(id, s); err != nil {
_ = s.writePipe.Close()
_ = s.readPipe.Close()
return nil, err
}
return s, nil return s, nil
} }
@ -312,7 +323,7 @@ func (c *wsStreamCreator) readDemuxLoop(bufferSize int, period time.Duration, de
} }
// closeAllStreamReaders closes readers in all streams. // closeAllStreamReaders closes readers in all streams.
// This unblocks all stream.Read() calls. // This unblocks all stream.Read() calls, and keeps any future streams from being created.
func (c *wsStreamCreator) closeAllStreamReaders(err error) { func (c *wsStreamCreator) closeAllStreamReaders(err error) {
c.streamsMu.Lock() c.streamsMu.Lock()
defer c.streamsMu.Unlock() defer c.streamsMu.Unlock()
@ -320,6 +331,12 @@ func (c *wsStreamCreator) closeAllStreamReaders(err error) {
// Closing writePipe unblocks all readPipe.Read() callers and prevents any future writes. // Closing writePipe unblocks all readPipe.Read() callers and prevents any future writes.
_ = s.writePipe.CloseWithError(err) _ = s.writePipe.CloseWithError(err)
} }
// ensure callers to setStreams receive an error after this point
if err != nil {
c.setStreamErr = err
} else {
c.setStreamErr = fmt.Errorf("closed all streams")
}
} }
type stream struct { type stream struct {

View File

@ -817,6 +817,8 @@ func TestWebSocketClient_BadHandshake(t *testing.T) {
// TestWebSocketClient_HeartbeatTimeout tests the heartbeat by forcing a // TestWebSocketClient_HeartbeatTimeout tests the heartbeat by forcing a
// timeout by setting the ping period greater than the deadline. // timeout by setting the ping period greater than the deadline.
func TestWebSocketClient_HeartbeatTimeout(t *testing.T) { func TestWebSocketClient_HeartbeatTimeout(t *testing.T) {
blockRequestCtx, unblockRequest := context.WithCancel(context.Background())
defer unblockRequest()
// Create fake WebSocket server which blocks. // Create fake WebSocket server which blocks.
websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
conns, err := webSocketServerStreams(req, w, streamOptionsFromRequest(req)) conns, err := webSocketServerStreams(req, w, streamOptionsFromRequest(req))
@ -824,8 +826,7 @@ func TestWebSocketClient_HeartbeatTimeout(t *testing.T) {
t.Fatalf("error on webSocketServerStreams: %v", err) t.Fatalf("error on webSocketServerStreams: %v", err)
} }
defer conns.conn.Close() defer conns.conn.Close()
// Block server; heartbeat timeout (or test timeout) will fire before this returns. <-blockRequestCtx.Done()
time.Sleep(1 * time.Second)
})) }))
defer websocketServer.Close() defer websocketServer.Close()
// Create websocket client connecting to fake server. // Create websocket client connecting to fake server.
@ -840,8 +841,8 @@ func TestWebSocketClient_HeartbeatTimeout(t *testing.T) {
} }
streamExec := exec.(*wsStreamExecutor) streamExec := exec.(*wsStreamExecutor)
// Ping period is greater than the ping deadline, forcing the timeout to fire. // Ping period is greater than the ping deadline, forcing the timeout to fire.
pingPeriod := 20 * time.Millisecond pingPeriod := wait.ForeverTestTimeout // this lets the heartbeat deadline expire without renewing it
pingDeadline := 5 * time.Millisecond pingDeadline := time.Second // this gives setup 1 second to establish streams
streamExec.heartbeatPeriod = pingPeriod streamExec.heartbeatPeriod = pingPeriod
streamExec.heartbeatDeadline = pingDeadline streamExec.heartbeatDeadline = pingDeadline
// Send some random data to the websocket server through STDIN. // Send some random data to the websocket server through STDIN.
@ -859,8 +860,7 @@ func TestWebSocketClient_HeartbeatTimeout(t *testing.T) {
}() }()
select { select {
case <-time.After(pingPeriod * 5): case <-time.After(wait.ForeverTestTimeout):
// Give up after about five ping attempts
t.Fatalf("expected heartbeat timeout, got none.") t.Fatalf("expected heartbeat timeout, got none.")
case err := <-errorChan: case err := <-errorChan:
// Expecting heartbeat timeout error. // Expecting heartbeat timeout error.
@ -1116,6 +1116,14 @@ func TestWebSocketClient_HeartbeatSucceeds(t *testing.T) {
wg.Wait() wg.Wait()
} }
func TestLateStreamCreation(t *testing.T) {
c := newWSStreamCreator(nil)
c.closeAllStreamReaders(nil)
if err := c.setStream(0, nil); err == nil {
t.Fatal("expected error adding stream after closeAllStreamReaders")
}
}
func TestWebSocketClient_StreamsAndExpectedErrors(t *testing.T) { func TestWebSocketClient_StreamsAndExpectedErrors(t *testing.T) {
// Validate Stream functions. // Validate Stream functions.
c := newWSStreamCreator(nil) c := newWSStreamCreator(nil)