Merge pull request #112091 from xyz-li/master

kubectl: fix memory leaks in port forwarding client

Kubernetes-commit: 7a68c8a21a29787adf9e959271b8f955a68d3d82
This commit is contained in:
Kubernetes Publisher 2022-09-15 06:47:25 -07:00
commit eecd3e52a3
2 changed files with 11 additions and 2 deletions

View File

@ -347,6 +347,7 @@ func (pf *PortForwarder) handleConnection(conn net.Conn, port ForwardedPort) {
} }
// we're not writing to this stream // we're not writing to this stream
errorStream.Close() errorStream.Close()
defer pf.streamConn.RemoveStreams(errorStream)
errorChan := make(chan error) errorChan := make(chan error)
go func() { go func() {
@ -367,6 +368,7 @@ func (pf *PortForwarder) handleConnection(conn net.Conn, port ForwardedPort) {
runtime.HandleError(fmt.Errorf("error creating forwarding stream for port %d -> %d: %v", port.Local, port.Remote, err)) runtime.HandleError(fmt.Errorf("error creating forwarding stream for port %d -> %d: %v", port.Local, port.Remote, err))
return return
} }
defer pf.streamConn.RemoveStreams(dataStream)
localError := make(chan struct{}) localError := make(chan struct{})
remoteDone := make(chan struct{}) remoteDone := make(chan struct{})

View File

@ -51,6 +51,7 @@ type fakeConnection struct {
closeChan chan bool closeChan chan bool
dataStream *fakeStream dataStream *fakeStream
errorStream *fakeStream errorStream *fakeStream
streamCount int
} }
func newFakeConnection() *fakeConnection { func newFakeConnection() *fakeConnection {
@ -64,8 +65,10 @@ func newFakeConnection() *fakeConnection {
func (c *fakeConnection) CreateStream(headers http.Header) (httpstream.Stream, error) { func (c *fakeConnection) CreateStream(headers http.Header) (httpstream.Stream, error) {
switch headers.Get(v1.StreamType) { switch headers.Get(v1.StreamType) {
case v1.StreamTypeData: case v1.StreamTypeData:
c.streamCount++
return c.dataStream, nil return c.dataStream, nil
case v1.StreamTypeError: case v1.StreamTypeError:
c.streamCount++
return c.errorStream, nil return c.errorStream, nil
default: default:
return nil, fmt.Errorf("fakeStream creation not supported for stream type %s", headers.Get(v1.StreamType)) return nil, fmt.Errorf("fakeStream creation not supported for stream type %s", headers.Get(v1.StreamType))
@ -84,7 +87,10 @@ func (c *fakeConnection) CloseChan() <-chan bool {
return c.closeChan return c.closeChan
} }
func (c *fakeConnection) RemoveStreams(_ ...httpstream.Stream) { func (c *fakeConnection) RemoveStreams(streams ...httpstream.Stream) {
for range streams {
c.streamCount--
}
} }
func (c *fakeConnection) SetIdleTimeout(timeout time.Duration) { func (c *fakeConnection) SetIdleTimeout(timeout time.Duration) {
@ -504,7 +510,7 @@ func TestHandleConnection(t *testing.T) {
// Test handleConnection // Test handleConnection
pf.handleConnection(localConnection, ForwardedPort{Local: 1111, Remote: 2222}) pf.handleConnection(localConnection, ForwardedPort{Local: 1111, Remote: 2222})
assert.Equal(t, 0, remoteConnection.streamCount, "stream count should be zero")
assert.Equal(t, "test data from local", remoteDataReceived.String()) assert.Equal(t, "test data from local", remoteDataReceived.String())
assert.Equal(t, "test data from remote", localConnection.receiveBuffer.String()) assert.Equal(t, "test data from remote", localConnection.receiveBuffer.String())
assert.Equal(t, "Handling connection for 1111\n", out.String()) assert.Equal(t, "Handling connection for 1111\n", out.String())
@ -538,6 +544,7 @@ func TestHandleConnectionSendsRemoteError(t *testing.T) {
// Test handleConnection, using go-routine because it needs to be able to write to unbuffered pf.errorChan // Test handleConnection, using go-routine because it needs to be able to write to unbuffered pf.errorChan
pf.handleConnection(localConnection, ForwardedPort{Local: 1111, Remote: 2222}) pf.handleConnection(localConnection, ForwardedPort{Local: 1111, Remote: 2222})
assert.Equal(t, 0, remoteConnection.streamCount, "stream count should be zero")
assert.Equal(t, "", remoteDataReceived.String()) assert.Equal(t, "", remoteDataReceived.String())
assert.Equal(t, "", localConnection.receiveBuffer.String()) assert.Equal(t, "", localConnection.receiveBuffer.String())
assert.Equal(t, "Handling connection for 1111\n", out.String()) assert.Equal(t, "Handling connection for 1111\n", out.String())