kubectl: fix memory leaks in port forwarding client

Signed-off-by: LiHui <andrewli@kubesphere.io>

Kubernetes-commit: 1df24569a0bf62a528c49f73fdb236fd56eb05ee
This commit is contained in:
LiHui 2022-08-29 14:15:01 +08:00 committed by Kubernetes Publisher
parent f6b8521807
commit cc3cc93e6a
2 changed files with 11 additions and 2 deletions

View File

@ -347,6 +347,7 @@ func (pf *PortForwarder) handleConnection(conn net.Conn, port ForwardedPort) {
} }
// we're not writing to this stream // we're not writing to this stream
errorStream.Close() errorStream.Close()
defer pf.streamConn.RemoveStreams(errorStream)
errorChan := make(chan error) errorChan := make(chan error)
go func() { go func() {
@ -367,6 +368,7 @@ func (pf *PortForwarder) handleConnection(conn net.Conn, port ForwardedPort) {
runtime.HandleError(fmt.Errorf("error creating forwarding stream for port %d -> %d: %v", port.Local, port.Remote, err)) runtime.HandleError(fmt.Errorf("error creating forwarding stream for port %d -> %d: %v", port.Local, port.Remote, err))
return return
} }
defer pf.streamConn.RemoveStreams(dataStream)
localError := make(chan struct{}) localError := make(chan struct{})
remoteDone := make(chan struct{}) remoteDone := make(chan struct{})

View File

@ -51,6 +51,7 @@ type fakeConnection struct {
closeChan chan bool closeChan chan bool
dataStream *fakeStream dataStream *fakeStream
errorStream *fakeStream errorStream *fakeStream
streamCount int
} }
func newFakeConnection() *fakeConnection { func newFakeConnection() *fakeConnection {
@ -64,8 +65,10 @@ func newFakeConnection() *fakeConnection {
func (c *fakeConnection) CreateStream(headers http.Header) (httpstream.Stream, error) { func (c *fakeConnection) CreateStream(headers http.Header) (httpstream.Stream, error) {
switch headers.Get(v1.StreamType) { switch headers.Get(v1.StreamType) {
case v1.StreamTypeData: case v1.StreamTypeData:
c.streamCount++
return c.dataStream, nil return c.dataStream, nil
case v1.StreamTypeError: case v1.StreamTypeError:
c.streamCount++
return c.errorStream, nil return c.errorStream, nil
default: default:
return nil, fmt.Errorf("fakeStream creation not supported for stream type %s", headers.Get(v1.StreamType)) return nil, fmt.Errorf("fakeStream creation not supported for stream type %s", headers.Get(v1.StreamType))
@ -84,7 +87,10 @@ func (c *fakeConnection) CloseChan() <-chan bool {
return c.closeChan return c.closeChan
} }
func (c *fakeConnection) RemoveStreams(_ ...httpstream.Stream) { func (c *fakeConnection) RemoveStreams(streams ...httpstream.Stream) {
for range streams {
c.streamCount--
}
} }
func (c *fakeConnection) SetIdleTimeout(timeout time.Duration) { func (c *fakeConnection) SetIdleTimeout(timeout time.Duration) {
@ -504,7 +510,7 @@ func TestHandleConnection(t *testing.T) {
// Test handleConnection // Test handleConnection
pf.handleConnection(localConnection, ForwardedPort{Local: 1111, Remote: 2222}) pf.handleConnection(localConnection, ForwardedPort{Local: 1111, Remote: 2222})
assert.Equal(t, 0, remoteConnection.streamCount, "stream count should be zero")
assert.Equal(t, "test data from local", remoteDataReceived.String()) assert.Equal(t, "test data from local", remoteDataReceived.String())
assert.Equal(t, "test data from remote", localConnection.receiveBuffer.String()) assert.Equal(t, "test data from remote", localConnection.receiveBuffer.String())
assert.Equal(t, "Handling connection for 1111\n", out.String()) assert.Equal(t, "Handling connection for 1111\n", out.String())
@ -538,6 +544,7 @@ func TestHandleConnectionSendsRemoteError(t *testing.T) {
// Test handleConnection, using go-routine because it needs to be able to write to unbuffered pf.errorChan // Test handleConnection, using go-routine because it needs to be able to write to unbuffered pf.errorChan
pf.handleConnection(localConnection, ForwardedPort{Local: 1111, Remote: 2222}) pf.handleConnection(localConnection, ForwardedPort{Local: 1111, Remote: 2222})
assert.Equal(t, 0, remoteConnection.streamCount, "stream count should be zero")
assert.Equal(t, "", remoteDataReceived.String()) assert.Equal(t, "", remoteDataReceived.String())
assert.Equal(t, "", localConnection.receiveBuffer.String()) assert.Equal(t, "", localConnection.receiveBuffer.String())
assert.Equal(t, "Handling connection for 1111\n", out.String()) assert.Equal(t, "Handling connection for 1111\n", out.String())