From 2f016580ef9a7d4d7cf0688acd844fe7150cd95d Mon Sep 17 00:00:00 2001 From: xuzhenglun Date: Tue, 8 Jul 2025 18:31:30 +0800 Subject: [PATCH] make sure all streams are created before starting demux websocket Kubernetes-commit: 3379d5ac4b6a1afbbaead06689a8584ce546a275 --- tools/remotecommand/remotecommand.go | 2 +- tools/remotecommand/spdy.go | 6 +++++- tools/remotecommand/spdy_test.go | 2 +- tools/remotecommand/v1.go | 7 ++++++- tools/remotecommand/v2.go | 7 ++++++- tools/remotecommand/v3.go | 7 ++++++- tools/remotecommand/v4.go | 7 ++++++- tools/remotecommand/v5.go | 4 ++-- tools/remotecommand/websocket.go | 26 ++++++++++++++++++++------ 9 files changed, 53 insertions(+), 15 deletions(-) diff --git a/tools/remotecommand/remotecommand.go b/tools/remotecommand/remotecommand.go index 1ae67729..4cff05cd 100644 --- a/tools/remotecommand/remotecommand.go +++ b/tools/remotecommand/remotecommand.go @@ -54,5 +54,5 @@ type streamCreator interface { } type streamProtocolHandler interface { - stream(conn streamCreator) error + stream(conn streamCreator, ready chan<- struct{}) error } diff --git a/tools/remotecommand/spdy.go b/tools/remotecommand/spdy.go index c2bfcf8a..34825771 100644 --- a/tools/remotecommand/spdy.go +++ b/tools/remotecommand/spdy.go @@ -157,7 +157,11 @@ func (e *spdyStreamExecutor) StreamWithContext(ctx context.Context, options Stre panicChan <- p } }() - errorChan <- streamer.stream(conn) + + // The SPDY executor does not need to synchronize stream creation, so we pass a nil + // ready channel. The underlying spdystream library handles stream multiplexing + // without a race condition. + errorChan <- streamer.stream(conn, nil) }() select { diff --git a/tools/remotecommand/spdy_test.go b/tools/remotecommand/spdy_test.go index 1b1cf749..9948832a 100644 --- a/tools/remotecommand/spdy_test.go +++ b/tools/remotecommand/spdy_test.go @@ -352,7 +352,7 @@ func TestStreamExitsAfterConnectionIsClosed(t *testing.T) { errorChan := make(chan error) go func() { - errorChan <- streamer.stream(conn) + errorChan <- streamer.stream(conn, nil) }() // Wait until stream goroutine starts. diff --git a/tools/remotecommand/v1.go b/tools/remotecommand/v1.go index efa9a6c9..293d809d 100644 --- a/tools/remotecommand/v1.go +++ b/tools/remotecommand/v1.go @@ -47,7 +47,7 @@ func newStreamProtocolV1(options StreamOptions) streamProtocolHandler { } } -func (p *streamProtocolV1) stream(conn streamCreator) error { +func (p *streamProtocolV1) stream(conn streamCreator, ready chan<- struct{}) error { doneChan := make(chan struct{}, 2) errorChan := make(chan error) @@ -106,6 +106,11 @@ func (p *streamProtocolV1) stream(conn streamCreator) error { defer p.remoteStderr.Reset() } + // Signal that all streams have been created. + if ready != nil { + close(ready) + } + // now that all the streams have been created, proceed with reading & copying // always read from errorStream diff --git a/tools/remotecommand/v2.go b/tools/remotecommand/v2.go index d54612f4..a81538a0 100644 --- a/tools/remotecommand/v2.go +++ b/tools/remotecommand/v2.go @@ -169,11 +169,16 @@ func (p *streamProtocolV2) copyStderr(wg *sync.WaitGroup) { }() } -func (p *streamProtocolV2) stream(conn streamCreator) error { +func (p *streamProtocolV2) stream(conn streamCreator, ready chan<- struct{}) error { if err := p.createStreams(conn); err != nil { return err } + // Signal that all streams have been created. + if ready != nil { + close(ready) + } + // now that all the streams have been created, proceed with reading & copying errorChan := watchErrorStream(p.errorStream, &errorDecoderV2{}) diff --git a/tools/remotecommand/v3.go b/tools/remotecommand/v3.go index 846dd24a..ece4cfaf 100644 --- a/tools/remotecommand/v3.go +++ b/tools/remotecommand/v3.go @@ -82,11 +82,16 @@ func (p *streamProtocolV3) handleResizes() { }() } -func (p *streamProtocolV3) stream(conn streamCreator) error { +func (p *streamProtocolV3) stream(conn streamCreator, ready chan<- struct{}) error { if err := p.createStreams(conn); err != nil { return err } + // Signal that all streams have been created. + if ready != nil { + close(ready) + } + // now that all the streams have been created, proceed with reading & copying errorChan := watchErrorStream(p.errorStream, &errorDecoderV3{}) diff --git a/tools/remotecommand/v4.go b/tools/remotecommand/v4.go index 6146bdf1..ecedad07 100644 --- a/tools/remotecommand/v4.go +++ b/tools/remotecommand/v4.go @@ -51,11 +51,16 @@ func (p *streamProtocolV4) handleResizes() { p.streamProtocolV3.handleResizes() } -func (p *streamProtocolV4) stream(conn streamCreator) error { +func (p *streamProtocolV4) stream(conn streamCreator, ready chan<- struct{}) error { if err := p.createStreams(conn); err != nil { return err } + // Signal that all streams have been created. + if ready != nil { + close(ready) + } + // now that all the streams have been created, proceed with reading & copying errorChan := watchErrorStream(p.errorStream, &errorDecoderV4{}) diff --git a/tools/remotecommand/v5.go b/tools/remotecommand/v5.go index 4da7bfb1..edfd3ccb 100644 --- a/tools/remotecommand/v5.go +++ b/tools/remotecommand/v5.go @@ -30,6 +30,6 @@ func newStreamProtocolV5(options StreamOptions) streamProtocolHandler { } } -func (p *streamProtocolV5) stream(conn streamCreator) error { - return p.streamProtocolV4.stream(conn) +func (p *streamProtocolV5) stream(conn streamCreator, ready chan<- struct{}) error { + return p.streamProtocolV4.stream(conn, ready) } diff --git a/tools/remotecommand/websocket.go b/tools/remotecommand/websocket.go index cea26a0b..e0433198 100644 --- a/tools/remotecommand/websocket.go +++ b/tools/remotecommand/websocket.go @@ -157,13 +157,27 @@ func (e *wsStreamExecutor) StreamWithContext(ctx context.Context, options Stream panicChan <- p } }() + + readyChan := make(chan struct{}) creator := newWSStreamCreator(conn) - go creator.readDemuxLoop( - e.upgrader.DataBufferSize(), - e.heartbeatPeriod, - e.heartbeatDeadline, - ) - errorChan <- streamer.stream(creator) + go func() { + select { + // Wait until all streams have been created before starting the readDemuxLoop. + // This is to avoid a race condition where the readDemuxLoop receives a message + // for a stream that has not yet been created. + case <-readyChan: + case <-ctx.Done(): + creator.closeAllStreamReaders(ctx.Err()) + return + } + + creator.readDemuxLoop( + e.upgrader.DataBufferSize(), + e.heartbeatPeriod, + e.heartbeatDeadline, + ) + }() + errorChan <- streamer.stream(creator, readyChan) }() select {