mirror of
https://github.com/k3s-io/kubernetes.git
synced 2025-08-10 12:32:03 +00:00
Keep streams from being set up after closeAllStreamReaders is called
This commit is contained in:
parent
26484df210
commit
6c1a935da2
@ -187,6 +187,9 @@ type wsStreamCreator struct {
|
||||
// map of stream id to stream; multiple streams read/write the connection
|
||||
streams map[byte]*stream
|
||||
streamsMu sync.Mutex
|
||||
// setStreamErr holds the error to return to anyone calling setStreams.
|
||||
// this is populated in closeAllStreamReaders
|
||||
setStreamErr error
|
||||
}
|
||||
|
||||
func newWSStreamCreator(conn *gwebsocket.Conn) *wsStreamCreator {
|
||||
@ -202,10 +205,14 @@ func (c *wsStreamCreator) getStream(id byte) *stream {
|
||||
return c.streams[id]
|
||||
}
|
||||
|
||||
func (c *wsStreamCreator) setStream(id byte, s *stream) {
|
||||
func (c *wsStreamCreator) setStream(id byte, s *stream) error {
|
||||
c.streamsMu.Lock()
|
||||
defer c.streamsMu.Unlock()
|
||||
if c.setStreamErr != nil {
|
||||
return c.setStreamErr
|
||||
}
|
||||
c.streams[id] = s
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateStream uses id from passed headers to create a stream over "c.conn" connection.
|
||||
@ -228,7 +235,11 @@ func (c *wsStreamCreator) CreateStream(headers http.Header) (httpstream.Stream,
|
||||
connWriteLock: &c.connWriteLock,
|
||||
id: id,
|
||||
}
|
||||
c.setStream(id, s)
|
||||
if err := c.setStream(id, s); err != nil {
|
||||
_ = s.writePipe.Close()
|
||||
_ = s.readPipe.Close()
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
@ -312,7 +323,7 @@ func (c *wsStreamCreator) readDemuxLoop(bufferSize int, period time.Duration, de
|
||||
}
|
||||
|
||||
// closeAllStreamReaders closes readers in all streams.
|
||||
// This unblocks all stream.Read() calls.
|
||||
// This unblocks all stream.Read() calls, and keeps any future streams from being created.
|
||||
func (c *wsStreamCreator) closeAllStreamReaders(err error) {
|
||||
c.streamsMu.Lock()
|
||||
defer c.streamsMu.Unlock()
|
||||
@ -320,6 +331,12 @@ func (c *wsStreamCreator) closeAllStreamReaders(err error) {
|
||||
// Closing writePipe unblocks all readPipe.Read() callers and prevents any future writes.
|
||||
_ = s.writePipe.CloseWithError(err)
|
||||
}
|
||||
// ensure callers to setStreams receive an error after this point
|
||||
if err != nil {
|
||||
c.setStreamErr = err
|
||||
} else {
|
||||
c.setStreamErr = fmt.Errorf("closed all streams")
|
||||
}
|
||||
}
|
||||
|
||||
type stream struct {
|
||||
|
@ -1116,6 +1116,14 @@ func TestWebSocketClient_HeartbeatSucceeds(t *testing.T) {
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestLateStreamCreation(t *testing.T) {
|
||||
c := newWSStreamCreator(nil)
|
||||
c.closeAllStreamReaders(nil)
|
||||
if err := c.setStream(0, nil); err == nil {
|
||||
t.Fatal("expected error adding stream after closeAllStreamReaders")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebSocketClient_StreamsAndExpectedErrors(t *testing.T) {
|
||||
// Validate Stream functions.
|
||||
c := newWSStreamCreator(nil)
|
||||
|
Loading…
Reference in New Issue
Block a user