Merge pull request #112091 from xyz-li/master

kubectl: fix memory leaks in port forwarding client
This commit is contained in:
Kubernetes Prow Robot 2022-09-15 06:47:25 -07:00 committed by GitHub
commit 7a68c8a21a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 11 additions and 2 deletions

View File

@ -347,6 +347,7 @@ func (pf *PortForwarder) handleConnection(conn net.Conn, port ForwardedPort) {
} }
// we're not writing to this stream // we're not writing to this stream
errorStream.Close() errorStream.Close()
defer pf.streamConn.RemoveStreams(errorStream)
errorChan := make(chan error) errorChan := make(chan error)
go func() { go func() {
@ -367,6 +368,7 @@ func (pf *PortForwarder) handleConnection(conn net.Conn, port ForwardedPort) {
runtime.HandleError(fmt.Errorf("error creating forwarding stream for port %d -> %d: %v", port.Local, port.Remote, err)) runtime.HandleError(fmt.Errorf("error creating forwarding stream for port %d -> %d: %v", port.Local, port.Remote, err))
return return
} }
defer pf.streamConn.RemoveStreams(dataStream)
localError := make(chan struct{}) localError := make(chan struct{})
remoteDone := make(chan struct{}) remoteDone := make(chan struct{})

View File

@ -51,6 +51,7 @@ type fakeConnection struct {
closeChan chan bool closeChan chan bool
dataStream *fakeStream dataStream *fakeStream
errorStream *fakeStream errorStream *fakeStream
streamCount int
} }
func newFakeConnection() *fakeConnection { func newFakeConnection() *fakeConnection {
@ -64,8 +65,10 @@ func newFakeConnection() *fakeConnection {
func (c *fakeConnection) CreateStream(headers http.Header) (httpstream.Stream, error) { func (c *fakeConnection) CreateStream(headers http.Header) (httpstream.Stream, error) {
switch headers.Get(v1.StreamType) { switch headers.Get(v1.StreamType) {
case v1.StreamTypeData: case v1.StreamTypeData:
c.streamCount++
return c.dataStream, nil return c.dataStream, nil
case v1.StreamTypeError: case v1.StreamTypeError:
c.streamCount++
return c.errorStream, nil return c.errorStream, nil
default: default:
return nil, fmt.Errorf("fakeStream creation not supported for stream type %s", headers.Get(v1.StreamType)) return nil, fmt.Errorf("fakeStream creation not supported for stream type %s", headers.Get(v1.StreamType))
@ -84,7 +87,10 @@ func (c *fakeConnection) CloseChan() <-chan bool {
return c.closeChan return c.closeChan
} }
func (c *fakeConnection) RemoveStreams(_ ...httpstream.Stream) { func (c *fakeConnection) RemoveStreams(streams ...httpstream.Stream) {
for range streams {
c.streamCount--
}
} }
func (c *fakeConnection) SetIdleTimeout(timeout time.Duration) { func (c *fakeConnection) SetIdleTimeout(timeout time.Duration) {
@ -504,7 +510,7 @@ func TestHandleConnection(t *testing.T) {
// Test handleConnection // Test handleConnection
pf.handleConnection(localConnection, ForwardedPort{Local: 1111, Remote: 2222}) pf.handleConnection(localConnection, ForwardedPort{Local: 1111, Remote: 2222})
assert.Equal(t, 0, remoteConnection.streamCount, "stream count should be zero")
assert.Equal(t, "test data from local", remoteDataReceived.String()) assert.Equal(t, "test data from local", remoteDataReceived.String())
assert.Equal(t, "test data from remote", localConnection.receiveBuffer.String()) assert.Equal(t, "test data from remote", localConnection.receiveBuffer.String())
assert.Equal(t, "Handling connection for 1111\n", out.String()) assert.Equal(t, "Handling connection for 1111\n", out.String())
@ -538,6 +544,7 @@ func TestHandleConnectionSendsRemoteError(t *testing.T) {
// Test handleConnection, using go-routine because it needs to be able to write to unbuffered pf.errorChan // Test handleConnection, using go-routine because it needs to be able to write to unbuffered pf.errorChan
pf.handleConnection(localConnection, ForwardedPort{Local: 1111, Remote: 2222}) pf.handleConnection(localConnection, ForwardedPort{Local: 1111, Remote: 2222})
assert.Equal(t, 0, remoteConnection.streamCount, "stream count should be zero")
assert.Equal(t, "", remoteDataReceived.String()) assert.Equal(t, "", remoteDataReceived.String())
assert.Equal(t, "", localConnection.receiveBuffer.String()) assert.Equal(t, "", localConnection.receiveBuffer.String())
assert.Equal(t, "Handling connection for 1111\n", out.String()) assert.Equal(t, "Handling connection for 1111\n", out.String())