kubectl: fix memory leaks in port forwarding client

Signed-off-by: LiHui <andrewli@kubesphere.io>
This commit is contained in:
LiHui 2022-08-29 14:15:01 +08:00
parent 50097acf15
commit 1df24569a0
2 changed files with 11 additions and 2 deletions

View File

@ -347,6 +347,7 @@ func (pf *PortForwarder) handleConnection(conn net.Conn, port ForwardedPort) {
} }
// we're not writing to this stream // we're not writing to this stream
errorStream.Close() errorStream.Close()
defer pf.streamConn.RemoveStreams(errorStream)
errorChan := make(chan error) errorChan := make(chan error)
go func() { go func() {
@ -367,6 +368,7 @@ func (pf *PortForwarder) handleConnection(conn net.Conn, port ForwardedPort) {
runtime.HandleError(fmt.Errorf("error creating forwarding stream for port %d -> %d: %v", port.Local, port.Remote, err)) runtime.HandleError(fmt.Errorf("error creating forwarding stream for port %d -> %d: %v", port.Local, port.Remote, err))
return return
} }
defer pf.streamConn.RemoveStreams(dataStream)
localError := make(chan struct{}) localError := make(chan struct{})
remoteDone := make(chan struct{}) remoteDone := make(chan struct{})

View File

@ -51,6 +51,7 @@ type fakeConnection struct {
closeChan chan bool closeChan chan bool
dataStream *fakeStream dataStream *fakeStream
errorStream *fakeStream errorStream *fakeStream
streamCount int
} }
func newFakeConnection() *fakeConnection { func newFakeConnection() *fakeConnection {
@ -64,8 +65,10 @@ func newFakeConnection() *fakeConnection {
func (c *fakeConnection) CreateStream(headers http.Header) (httpstream.Stream, error) { func (c *fakeConnection) CreateStream(headers http.Header) (httpstream.Stream, error) {
switch headers.Get(v1.StreamType) { switch headers.Get(v1.StreamType) {
case v1.StreamTypeData: case v1.StreamTypeData:
c.streamCount++
return c.dataStream, nil return c.dataStream, nil
case v1.StreamTypeError: case v1.StreamTypeError:
c.streamCount++
return c.errorStream, nil return c.errorStream, nil
default: default:
return nil, fmt.Errorf("fakeStream creation not supported for stream type %s", headers.Get(v1.StreamType)) return nil, fmt.Errorf("fakeStream creation not supported for stream type %s", headers.Get(v1.StreamType))
@ -84,7 +87,10 @@ func (c *fakeConnection) CloseChan() <-chan bool {
return c.closeChan return c.closeChan
} }
func (c *fakeConnection) RemoveStreams(_ ...httpstream.Stream) { func (c *fakeConnection) RemoveStreams(streams ...httpstream.Stream) {
for range streams {
c.streamCount--
}
} }
func (c *fakeConnection) SetIdleTimeout(timeout time.Duration) { func (c *fakeConnection) SetIdleTimeout(timeout time.Duration) {
@ -504,7 +510,7 @@ func TestHandleConnection(t *testing.T) {
// Test handleConnection // Test handleConnection
pf.handleConnection(localConnection, ForwardedPort{Local: 1111, Remote: 2222}) pf.handleConnection(localConnection, ForwardedPort{Local: 1111, Remote: 2222})
assert.Equal(t, 0, remoteConnection.streamCount, "stream count should be zero")
assert.Equal(t, "test data from local", remoteDataReceived.String()) assert.Equal(t, "test data from local", remoteDataReceived.String())
assert.Equal(t, "test data from remote", localConnection.receiveBuffer.String()) assert.Equal(t, "test data from remote", localConnection.receiveBuffer.String())
assert.Equal(t, "Handling connection for 1111\n", out.String()) assert.Equal(t, "Handling connection for 1111\n", out.String())
@ -538,6 +544,7 @@ func TestHandleConnectionSendsRemoteError(t *testing.T) {
// Test handleConnection, using go-routine because it needs to be able to write to unbuffered pf.errorChan // Test handleConnection, using go-routine because it needs to be able to write to unbuffered pf.errorChan
pf.handleConnection(localConnection, ForwardedPort{Local: 1111, Remote: 2222}) pf.handleConnection(localConnection, ForwardedPort{Local: 1111, Remote: 2222})
assert.Equal(t, 0, remoteConnection.streamCount, "stream count should be zero")
assert.Equal(t, "", remoteDataReceived.String()) assert.Equal(t, "", remoteDataReceived.String())
assert.Equal(t, "", localConnection.receiveBuffer.String()) assert.Equal(t, "", localConnection.receiveBuffer.String())
assert.Equal(t, "Handling connection for 1111\n", out.String()) assert.Equal(t, "Handling connection for 1111\n", out.String())