Merge pull request #103526 from brianpursley/kubectl-686

Close connection and stop listening when port forwarding errors occur so that kubectl can exit

Kubernetes-commit: cd6ffff85d257ff9067d59339f2ffdbcc66dc164
This commit is contained in:
Kubernetes Publisher 2021-11-17 07:59:54 -08:00
commit cc8a98c5db
2 changed files with 185 additions and 12 deletions

View File

@ -300,6 +300,10 @@ func (pf *PortForwarder) getListener(protocol string, hostname string, port *For
// the background. // the background.
func (pf *PortForwarder) waitForConnection(listener net.Listener, port ForwardedPort) { func (pf *PortForwarder) waitForConnection(listener net.Listener, port ForwardedPort) {
for { for {
select {
case <-pf.streamConn.CloseChan():
return
default:
conn, err := listener.Accept() conn, err := listener.Accept()
if err != nil { if err != nil {
// TODO consider using something like https://github.com/hydrogen18/stoppableListener? // TODO consider using something like https://github.com/hydrogen18/stoppableListener?
@ -310,6 +314,7 @@ func (pf *PortForwarder) waitForConnection(listener net.Listener, port Forwarded
} }
go pf.handleConnection(conn, port) go pf.handleConnection(conn, port)
} }
}
} }
func (pf *PortForwarder) nextRequestID() int { func (pf *PortForwarder) nextRequestID() int {
@ -399,6 +404,7 @@ func (pf *PortForwarder) handleConnection(conn net.Conn, port ForwardedPort) {
err = <-errorChan err = <-errorChan
if err != nil { if err != nil {
runtime.HandleError(err) runtime.HandleError(err)
pf.streamConn.Close()
} }
} }

View File

@ -17,6 +17,7 @@ limitations under the License.
package portforward package portforward
import ( import (
"bytes"
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
@ -27,6 +28,9 @@ import (
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/assert"
v1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/util/httpstream" "k8s.io/apimachinery/pkg/util/httpstream"
) )
@ -45,16 +49,27 @@ func (d *fakeDialer) Dial(protocols ...string) (httpstream.Connection, string, e
type fakeConnection struct { type fakeConnection struct {
closed bool closed bool
closeChan chan bool closeChan chan bool
dataStream *fakeStream
errorStream *fakeStream
} }
func newFakeConnection() httpstream.Connection { func newFakeConnection() *fakeConnection {
return &fakeConnection{ return &fakeConnection{
closeChan: make(chan bool), closeChan: make(chan bool),
dataStream: &fakeStream{},
errorStream: &fakeStream{},
} }
} }
func (c *fakeConnection) CreateStream(headers http.Header) (httpstream.Stream, error) { func (c *fakeConnection) CreateStream(headers http.Header) (httpstream.Stream, error) {
return nil, nil switch headers.Get(v1.StreamType) {
case v1.StreamTypeData:
return c.dataStream, nil
case v1.StreamTypeError:
return c.errorStream, nil
default:
return nil, fmt.Errorf("fakeStream creation not supported for stream type %s", headers.Get(v1.StreamType))
}
} }
func (c *fakeConnection) Close() error { func (c *fakeConnection) Close() error {
@ -76,6 +91,65 @@ func (c *fakeConnection) SetIdleTimeout(timeout time.Duration) {
// no-op // no-op
} }
type fakeListener struct {
net.Listener
closeChan chan bool
}
func newFakeListener() fakeListener {
return fakeListener{
closeChan: make(chan bool),
}
}
func (l *fakeListener) Accept() (net.Conn, error) {
select {
case <-l.closeChan:
return nil, fmt.Errorf("listener closed")
}
}
func (l *fakeListener) Close() error {
close(l.closeChan)
return nil
}
func (l *fakeListener) Addr() net.Addr {
return fakeAddr{}
}
type fakeAddr struct{}
func (fakeAddr) Network() string { return "fake" }
func (fakeAddr) String() string { return "fake" }
type fakeStream struct {
headers http.Header
readFunc func(p []byte) (int, error)
writeFunc func(p []byte) (int, error)
}
func (s *fakeStream) Read(p []byte) (n int, err error) { return s.readFunc(p) }
func (s *fakeStream) Write(p []byte) (n int, err error) { return s.writeFunc(p) }
func (*fakeStream) Close() error { return nil }
func (*fakeStream) Reset() error { return nil }
func (s *fakeStream) Headers() http.Header { return s.headers }
func (*fakeStream) Identifier() uint32 { return 0 }
type fakeConn struct {
sendBuffer *bytes.Buffer
receiveBuffer *bytes.Buffer
}
func (f fakeConn) Read(p []byte) (int, error) { return f.sendBuffer.Read(p) }
func (f fakeConn) Write(p []byte) (int, error) { return f.receiveBuffer.Write(p) }
func (fakeConn) Close() error { return nil }
func (fakeConn) LocalAddr() net.Addr { return nil }
func (fakeConn) RemoteAddr() net.Addr { return nil }
func (fakeConn) SetDeadline(t time.Time) error { return nil }
func (fakeConn) SetReadDeadline(t time.Time) error { return nil }
func (fakeConn) SetWriteDeadline(t time.Time) error { return nil }
func TestParsePortsAndNew(t *testing.T) { func TestParsePortsAndNew(t *testing.T) {
tests := []struct { tests := []struct {
input []string input []string
@ -393,3 +467,96 @@ func TestGetPortsReturnsDynamicallyAssignedLocalPort(t *testing.T) {
t.Fatalf("local port is 0, expected != 0") t.Fatalf("local port is 0, expected != 0")
} }
} }
func TestHandleConnection(t *testing.T) {
out := bytes.NewBufferString("")
pf, err := New(&fakeDialer{}, []string{":2222"}, nil, nil, out, nil)
if err != nil {
t.Fatalf("error while calling New: %s", err)
}
// Setup fake local connection
localConnection := &fakeConn{
sendBuffer: bytes.NewBufferString("test data from local"),
receiveBuffer: bytes.NewBufferString(""),
}
// Setup fake remote connection to send data on the data stream after it receives data from the local connection
remoteDataToSend := bytes.NewBufferString("test data from remote")
remoteDataReceived := bytes.NewBufferString("")
remoteErrorToSend := bytes.NewBufferString("")
blockRemoteSend := make(chan struct{})
remoteConnection := newFakeConnection()
remoteConnection.dataStream.readFunc = func(p []byte) (int, error) {
<-blockRemoteSend // Wait for the expected data to be received before responding
return remoteDataToSend.Read(p)
}
remoteConnection.dataStream.writeFunc = func(p []byte) (int, error) {
n, err := remoteDataReceived.Write(p)
if remoteDataReceived.String() == "test data from local" {
close(blockRemoteSend)
}
return n, err
}
remoteConnection.errorStream.readFunc = remoteErrorToSend.Read
pf.streamConn = remoteConnection
// Test handleConnection
pf.handleConnection(localConnection, ForwardedPort{Local: 1111, Remote: 2222})
assert.Equal(t, "test data from local", remoteDataReceived.String())
assert.Equal(t, "test data from remote", localConnection.receiveBuffer.String())
assert.Equal(t, "Handling connection for 1111\n", out.String())
}
func TestHandleConnectionSendsRemoteError(t *testing.T) {
out := bytes.NewBufferString("")
errOut := bytes.NewBufferString("")
pf, err := New(&fakeDialer{}, []string{":2222"}, nil, nil, out, errOut)
if err != nil {
t.Fatalf("error while calling New: %s", err)
}
// Setup fake local connection
localConnection := &fakeConn{
sendBuffer: bytes.NewBufferString(""),
receiveBuffer: bytes.NewBufferString(""),
}
// Setup fake remote connection to return an error message on the error stream
remoteDataToSend := bytes.NewBufferString("")
remoteDataReceived := bytes.NewBufferString("")
remoteErrorToSend := bytes.NewBufferString("error")
remoteConnection := newFakeConnection()
remoteConnection.dataStream.readFunc = remoteDataToSend.Read
remoteConnection.dataStream.writeFunc = remoteDataReceived.Write
remoteConnection.errorStream.readFunc = remoteErrorToSend.Read
pf.streamConn = remoteConnection
// Test handleConnection, using go-routine because it needs to be able to write to unbuffered pf.errorChan
pf.handleConnection(localConnection, ForwardedPort{Local: 1111, Remote: 2222})
assert.Equal(t, "", remoteDataReceived.String())
assert.Equal(t, "", localConnection.receiveBuffer.String())
assert.Equal(t, "Handling connection for 1111\n", out.String())
}
func TestWaitForConnectionExitsOnStreamConnClosed(t *testing.T) {
out := bytes.NewBufferString("")
errOut := bytes.NewBufferString("")
pf, err := New(&fakeDialer{}, []string{":2222"}, nil, nil, out, errOut)
if err != nil {
t.Fatalf("error while calling New: %s", err)
}
listener := newFakeListener()
pf.streamConn = newFakeConnection()
pf.streamConn.Close()
port := ForwardedPort{}
pf.waitForConnection(&listener, port)
}