diff --git a/tools/portforward/portforward.go b/tools/portforward/portforward.go index 6b5e3076..b581043f 100644 --- a/tools/portforward/portforward.go +++ b/tools/portforward/portforward.go @@ -37,6 +37,8 @@ import ( // TODO move to API machinery and re-unify with kubelet/server/portfoward const PortForwardProtocolV1Name = "portforward.k8s.io" +var ErrLostConnectionToPod = errors.New("lost connection to pod") + // PortForwarder knows how to listen for local connections and forward them to // a remote pod via an upgraded HTTP request. type PortForwarder struct { @@ -230,7 +232,7 @@ func (pf *PortForwarder) forward() error { select { case <-pf.stopChan: case <-pf.streamConn.CloseChan(): - runtime.HandleError(errors.New("lost connection to pod")) + return ErrLostConnectionToPod } return nil diff --git a/tools/portforward/portforward_test.go b/tools/portforward/portforward_test.go index ada70339..3c90a3fd 100644 --- a/tools/portforward/portforward_test.go +++ b/tools/portforward/portforward_test.go @@ -567,3 +567,64 @@ func TestWaitForConnectionExitsOnStreamConnClosed(t *testing.T) { port := ForwardedPort{} pf.waitForConnection(&listener, port) } + +func TestForwardPortsReturnsErrorWhenConnectionIsLost(t *testing.T) { + dialer := &fakeDialer{ + conn: newFakeConnection(), + } + + stopChan := make(chan struct{}) + readyChan := make(chan struct{}) + errChan := make(chan error) + + pf, err := New(dialer, []string{":5000"}, stopChan, readyChan, os.Stdout, os.Stderr) + if err != nil { + t.Fatalf("failed to create new PortForwarder: %s", err) + } + + go func() { + errChan <- pf.ForwardPorts() + }() + + <-pf.Ready + + // Simulate lost pod connection by closing streamConn, which should result in pf.ForwardPorts() returning an error. + pf.streamConn.Close() + + err = <-errChan + if err == nil { + t.Fatalf("unexpected non-error from pf.ForwardPorts()") + } else if err != ErrLostConnectionToPod { + t.Fatalf("unexpected error from pf.ForwardPorts(): %s", err) + } +} + +func TestForwardPortsReturnsNilWhenStopChanIsClosed(t *testing.T) { + dialer := &fakeDialer{ + conn: newFakeConnection(), + } + + stopChan := make(chan struct{}) + readyChan := make(chan struct{}) + errChan := make(chan error) + + pf, err := New(dialer, []string{":5000"}, stopChan, readyChan, os.Stdout, os.Stderr) + if err != nil { + t.Fatalf("failed to create new PortForwarder: %s", err) + } + + go func() { + errChan <- pf.ForwardPorts() + }() + + <-pf.Ready + + // Closing (or sending to) stopChan indicates a stop request by the caller, which should result in pf.ForwardPorts() + // returning nil. + close(stopChan) + + err = <-errChan + if err != nil { + t.Fatalf("unexpected error from pf.ForwardPorts(): %s", err) + } +}