From f5d063cf7a3a6eed86a9d07c903f5ffa368f8156 Mon Sep 17 00:00:00 2001 From: brianpursley Date: Tue, 16 Nov 2021 21:05:43 -0500 Subject: [PATCH] close streamConn and stop listening when an error occurs while port forwarding Kubernetes-commit: 4a35e6c5641ee36291548c9c24b5b85663d5cd07 --- tools/portforward/portforward.go | 20 ++- tools/portforward/portforward_test.go | 177 +++++++++++++++++++++++++- 2 files changed, 185 insertions(+), 12 deletions(-) diff --git a/tools/portforward/portforward.go b/tools/portforward/portforward.go index 1c3985f3..6f1d12b6 100644 --- a/tools/portforward/portforward.go +++ b/tools/portforward/portforward.go @@ -300,15 +300,20 @@ func (pf *PortForwarder) getListener(protocol string, hostname string, port *For // the background. func (pf *PortForwarder) waitForConnection(listener net.Listener, port ForwardedPort) { for { - conn, err := listener.Accept() - if err != nil { - // TODO consider using something like https://github.com/hydrogen18/stoppableListener? - if !strings.Contains(strings.ToLower(err.Error()), "use of closed network connection") { - runtime.HandleError(fmt.Errorf("error accepting connection on port %d: %v", port.Local, err)) - } + select { + case <-pf.streamConn.CloseChan(): return + default: + conn, err := listener.Accept() + if err != nil { + // TODO consider using something like https://github.com/hydrogen18/stoppableListener? + if !strings.Contains(strings.ToLower(err.Error()), "use of closed network connection") { + runtime.HandleError(fmt.Errorf("error accepting connection on port %d: %v", port.Local, err)) + } + return + } + go pf.handleConnection(conn, port) } - go pf.handleConnection(conn, port) } } @@ -399,6 +404,7 @@ func (pf *PortForwarder) handleConnection(conn net.Conn, port ForwardedPort) { err = <-errorChan if err != nil { runtime.HandleError(err) + pf.streamConn.Close() } } diff --git a/tools/portforward/portforward_test.go b/tools/portforward/portforward_test.go index 551d97e9..04427e12 100644 --- a/tools/portforward/portforward_test.go +++ b/tools/portforward/portforward_test.go @@ -17,6 +17,7 @@ limitations under the License. package portforward import ( + "bytes" "fmt" "net" "net/http" @@ -27,6 +28,9 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" + + v1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/util/httpstream" ) @@ -43,18 +47,29 @@ func (d *fakeDialer) Dial(protocols ...string) (httpstream.Connection, string, e } type fakeConnection struct { - closed bool - closeChan chan bool + closed bool + closeChan chan bool + dataStream *fakeStream + errorStream *fakeStream } -func newFakeConnection() httpstream.Connection { +func newFakeConnection() *fakeConnection { return &fakeConnection{ - closeChan: make(chan bool), + closeChan: make(chan bool), + dataStream: &fakeStream{}, + errorStream: &fakeStream{}, } } func (c *fakeConnection) CreateStream(headers http.Header) (httpstream.Stream, error) { - return nil, nil + switch headers.Get(v1.StreamType) { + case v1.StreamTypeData: + return c.dataStream, nil + case v1.StreamTypeError: + return c.errorStream, nil + default: + return nil, fmt.Errorf("fakeStream creation not supported for stream type %s", headers.Get(v1.StreamType)) + } } func (c *fakeConnection) Close() error { @@ -76,6 +91,65 @@ func (c *fakeConnection) SetIdleTimeout(timeout time.Duration) { // no-op } +type fakeListener struct { + net.Listener + closeChan chan bool +} + +func newFakeListener() fakeListener { + return fakeListener{ + closeChan: make(chan bool), + } +} + +func (l *fakeListener) Accept() (net.Conn, error) { + select { + case <-l.closeChan: + return nil, fmt.Errorf("listener closed") + } +} + +func (l *fakeListener) Close() error { + close(l.closeChan) + return nil +} + +func (l *fakeListener) Addr() net.Addr { + return fakeAddr{} +} + +type fakeAddr struct{} + +func (fakeAddr) Network() string { return "fake" } +func (fakeAddr) String() string { return "fake" } + +type fakeStream struct { + headers http.Header + readFunc func(p []byte) (int, error) + writeFunc func(p []byte) (int, error) +} + +func (s *fakeStream) Read(p []byte) (n int, err error) { return s.readFunc(p) } +func (s *fakeStream) Write(p []byte) (n int, err error) { return s.writeFunc(p) } +func (*fakeStream) Close() error { return nil } +func (*fakeStream) Reset() error { return nil } +func (s *fakeStream) Headers() http.Header { return s.headers } +func (*fakeStream) Identifier() uint32 { return 0 } + +type fakeConn struct { + sendBuffer *bytes.Buffer + receiveBuffer *bytes.Buffer +} + +func (f fakeConn) Read(p []byte) (int, error) { return f.sendBuffer.Read(p) } +func (f fakeConn) Write(p []byte) (int, error) { return f.receiveBuffer.Write(p) } +func (fakeConn) Close() error { return nil } +func (fakeConn) LocalAddr() net.Addr { return nil } +func (fakeConn) RemoteAddr() net.Addr { return nil } +func (fakeConn) SetDeadline(t time.Time) error { return nil } +func (fakeConn) SetReadDeadline(t time.Time) error { return nil } +func (fakeConn) SetWriteDeadline(t time.Time) error { return nil } + func TestParsePortsAndNew(t *testing.T) { tests := []struct { input []string @@ -393,3 +467,96 @@ func TestGetPortsReturnsDynamicallyAssignedLocalPort(t *testing.T) { t.Fatalf("local port is 0, expected != 0") } } + +func TestHandleConnection(t *testing.T) { + out := bytes.NewBufferString("") + + pf, err := New(&fakeDialer{}, []string{":2222"}, nil, nil, out, nil) + if err != nil { + t.Fatalf("error while calling New: %s", err) + } + + // Setup fake local connection + localConnection := &fakeConn{ + sendBuffer: bytes.NewBufferString("test data from local"), + receiveBuffer: bytes.NewBufferString(""), + } + + // Setup fake remote connection to send data on the data stream after it receives data from the local connection + remoteDataToSend := bytes.NewBufferString("test data from remote") + remoteDataReceived := bytes.NewBufferString("") + remoteErrorToSend := bytes.NewBufferString("") + blockRemoteSend := make(chan struct{}) + remoteConnection := newFakeConnection() + remoteConnection.dataStream.readFunc = func(p []byte) (int, error) { + <-blockRemoteSend // Wait for the expected data to be received before responding + return remoteDataToSend.Read(p) + } + remoteConnection.dataStream.writeFunc = func(p []byte) (int, error) { + n, err := remoteDataReceived.Write(p) + if remoteDataReceived.String() == "test data from local" { + close(blockRemoteSend) + } + return n, err + } + remoteConnection.errorStream.readFunc = remoteErrorToSend.Read + pf.streamConn = remoteConnection + + // Test handleConnection + pf.handleConnection(localConnection, ForwardedPort{Local: 1111, Remote: 2222}) + + assert.Equal(t, "test data from local", remoteDataReceived.String()) + assert.Equal(t, "test data from remote", localConnection.receiveBuffer.String()) + assert.Equal(t, "Handling connection for 1111\n", out.String()) +} + +func TestHandleConnectionSendsRemoteError(t *testing.T) { + out := bytes.NewBufferString("") + errOut := bytes.NewBufferString("") + + pf, err := New(&fakeDialer{}, []string{":2222"}, nil, nil, out, errOut) + if err != nil { + t.Fatalf("error while calling New: %s", err) + } + + // Setup fake local connection + localConnection := &fakeConn{ + sendBuffer: bytes.NewBufferString(""), + receiveBuffer: bytes.NewBufferString(""), + } + + // Setup fake remote connection to return an error message on the error stream + remoteDataToSend := bytes.NewBufferString("") + remoteDataReceived := bytes.NewBufferString("") + remoteErrorToSend := bytes.NewBufferString("error") + remoteConnection := newFakeConnection() + remoteConnection.dataStream.readFunc = remoteDataToSend.Read + remoteConnection.dataStream.writeFunc = remoteDataReceived.Write + remoteConnection.errorStream.readFunc = remoteErrorToSend.Read + pf.streamConn = remoteConnection + + // Test handleConnection, using go-routine because it needs to be able to write to unbuffered pf.errorChan + pf.handleConnection(localConnection, ForwardedPort{Local: 1111, Remote: 2222}) + + assert.Equal(t, "", remoteDataReceived.String()) + assert.Equal(t, "", localConnection.receiveBuffer.String()) + assert.Equal(t, "Handling connection for 1111\n", out.String()) +} + +func TestWaitForConnectionExitsOnStreamConnClosed(t *testing.T) { + out := bytes.NewBufferString("") + errOut := bytes.NewBufferString("") + + pf, err := New(&fakeDialer{}, []string{":2222"}, nil, nil, out, errOut) + if err != nil { + t.Fatalf("error while calling New: %s", err) + } + + listener := newFakeListener() + + pf.streamConn = newFakeConnection() + pf.streamConn.Close() + + port := ForwardedPort{} + pf.waitForConnection(&listener, port) +}