mirror of
https://github.com/kubernetes/client-go.git
synced 2025-06-24 22:23:02 +00:00
Merge pull request #103526 from brianpursley/kubectl-686
Close connection and stop listening when port forwarding errors occur so that kubectl can exit Kubernetes-commit: cd6ffff85d257ff9067d59339f2ffdbcc66dc164
This commit is contained in:
commit
cc8a98c5db
@ -300,15 +300,20 @@ func (pf *PortForwarder) getListener(protocol string, hostname string, port *For
|
||||
// the background.
|
||||
func (pf *PortForwarder) waitForConnection(listener net.Listener, port ForwardedPort) {
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
// TODO consider using something like https://github.com/hydrogen18/stoppableListener?
|
||||
if !strings.Contains(strings.ToLower(err.Error()), "use of closed network connection") {
|
||||
runtime.HandleError(fmt.Errorf("error accepting connection on port %d: %v", port.Local, err))
|
||||
}
|
||||
select {
|
||||
case <-pf.streamConn.CloseChan():
|
||||
return
|
||||
default:
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
// TODO consider using something like https://github.com/hydrogen18/stoppableListener?
|
||||
if !strings.Contains(strings.ToLower(err.Error()), "use of closed network connection") {
|
||||
runtime.HandleError(fmt.Errorf("error accepting connection on port %d: %v", port.Local, err))
|
||||
}
|
||||
return
|
||||
}
|
||||
go pf.handleConnection(conn, port)
|
||||
}
|
||||
go pf.handleConnection(conn, port)
|
||||
}
|
||||
}
|
||||
|
||||
@ -399,6 +404,7 @@ func (pf *PortForwarder) handleConnection(conn net.Conn, port ForwardedPort) {
|
||||
err = <-errorChan
|
||||
if err != nil {
|
||||
runtime.HandleError(err)
|
||||
pf.streamConn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
package portforward
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
@ -27,6 +28,9 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
v1 "k8s.io/api/core/v1"
|
||||
"k8s.io/apimachinery/pkg/util/httpstream"
|
||||
)
|
||||
|
||||
@ -43,18 +47,29 @@ func (d *fakeDialer) Dial(protocols ...string) (httpstream.Connection, string, e
|
||||
}
|
||||
|
||||
type fakeConnection struct {
|
||||
closed bool
|
||||
closeChan chan bool
|
||||
closed bool
|
||||
closeChan chan bool
|
||||
dataStream *fakeStream
|
||||
errorStream *fakeStream
|
||||
}
|
||||
|
||||
func newFakeConnection() httpstream.Connection {
|
||||
func newFakeConnection() *fakeConnection {
|
||||
return &fakeConnection{
|
||||
closeChan: make(chan bool),
|
||||
closeChan: make(chan bool),
|
||||
dataStream: &fakeStream{},
|
||||
errorStream: &fakeStream{},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *fakeConnection) CreateStream(headers http.Header) (httpstream.Stream, error) {
|
||||
return nil, nil
|
||||
switch headers.Get(v1.StreamType) {
|
||||
case v1.StreamTypeData:
|
||||
return c.dataStream, nil
|
||||
case v1.StreamTypeError:
|
||||
return c.errorStream, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("fakeStream creation not supported for stream type %s", headers.Get(v1.StreamType))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *fakeConnection) Close() error {
|
||||
@ -76,6 +91,65 @@ func (c *fakeConnection) SetIdleTimeout(timeout time.Duration) {
|
||||
// no-op
|
||||
}
|
||||
|
||||
type fakeListener struct {
|
||||
net.Listener
|
||||
closeChan chan bool
|
||||
}
|
||||
|
||||
func newFakeListener() fakeListener {
|
||||
return fakeListener{
|
||||
closeChan: make(chan bool),
|
||||
}
|
||||
}
|
||||
|
||||
func (l *fakeListener) Accept() (net.Conn, error) {
|
||||
select {
|
||||
case <-l.closeChan:
|
||||
return nil, fmt.Errorf("listener closed")
|
||||
}
|
||||
}
|
||||
|
||||
func (l *fakeListener) Close() error {
|
||||
close(l.closeChan)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *fakeListener) Addr() net.Addr {
|
||||
return fakeAddr{}
|
||||
}
|
||||
|
||||
type fakeAddr struct{}
|
||||
|
||||
func (fakeAddr) Network() string { return "fake" }
|
||||
func (fakeAddr) String() string { return "fake" }
|
||||
|
||||
type fakeStream struct {
|
||||
headers http.Header
|
||||
readFunc func(p []byte) (int, error)
|
||||
writeFunc func(p []byte) (int, error)
|
||||
}
|
||||
|
||||
func (s *fakeStream) Read(p []byte) (n int, err error) { return s.readFunc(p) }
|
||||
func (s *fakeStream) Write(p []byte) (n int, err error) { return s.writeFunc(p) }
|
||||
func (*fakeStream) Close() error { return nil }
|
||||
func (*fakeStream) Reset() error { return nil }
|
||||
func (s *fakeStream) Headers() http.Header { return s.headers }
|
||||
func (*fakeStream) Identifier() uint32 { return 0 }
|
||||
|
||||
type fakeConn struct {
|
||||
sendBuffer *bytes.Buffer
|
||||
receiveBuffer *bytes.Buffer
|
||||
}
|
||||
|
||||
func (f fakeConn) Read(p []byte) (int, error) { return f.sendBuffer.Read(p) }
|
||||
func (f fakeConn) Write(p []byte) (int, error) { return f.receiveBuffer.Write(p) }
|
||||
func (fakeConn) Close() error { return nil }
|
||||
func (fakeConn) LocalAddr() net.Addr { return nil }
|
||||
func (fakeConn) RemoteAddr() net.Addr { return nil }
|
||||
func (fakeConn) SetDeadline(t time.Time) error { return nil }
|
||||
func (fakeConn) SetReadDeadline(t time.Time) error { return nil }
|
||||
func (fakeConn) SetWriteDeadline(t time.Time) error { return nil }
|
||||
|
||||
func TestParsePortsAndNew(t *testing.T) {
|
||||
tests := []struct {
|
||||
input []string
|
||||
@ -393,3 +467,96 @@ func TestGetPortsReturnsDynamicallyAssignedLocalPort(t *testing.T) {
|
||||
t.Fatalf("local port is 0, expected != 0")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleConnection(t *testing.T) {
|
||||
out := bytes.NewBufferString("")
|
||||
|
||||
pf, err := New(&fakeDialer{}, []string{":2222"}, nil, nil, out, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("error while calling New: %s", err)
|
||||
}
|
||||
|
||||
// Setup fake local connection
|
||||
localConnection := &fakeConn{
|
||||
sendBuffer: bytes.NewBufferString("test data from local"),
|
||||
receiveBuffer: bytes.NewBufferString(""),
|
||||
}
|
||||
|
||||
// Setup fake remote connection to send data on the data stream after it receives data from the local connection
|
||||
remoteDataToSend := bytes.NewBufferString("test data from remote")
|
||||
remoteDataReceived := bytes.NewBufferString("")
|
||||
remoteErrorToSend := bytes.NewBufferString("")
|
||||
blockRemoteSend := make(chan struct{})
|
||||
remoteConnection := newFakeConnection()
|
||||
remoteConnection.dataStream.readFunc = func(p []byte) (int, error) {
|
||||
<-blockRemoteSend // Wait for the expected data to be received before responding
|
||||
return remoteDataToSend.Read(p)
|
||||
}
|
||||
remoteConnection.dataStream.writeFunc = func(p []byte) (int, error) {
|
||||
n, err := remoteDataReceived.Write(p)
|
||||
if remoteDataReceived.String() == "test data from local" {
|
||||
close(blockRemoteSend)
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
remoteConnection.errorStream.readFunc = remoteErrorToSend.Read
|
||||
pf.streamConn = remoteConnection
|
||||
|
||||
// Test handleConnection
|
||||
pf.handleConnection(localConnection, ForwardedPort{Local: 1111, Remote: 2222})
|
||||
|
||||
assert.Equal(t, "test data from local", remoteDataReceived.String())
|
||||
assert.Equal(t, "test data from remote", localConnection.receiveBuffer.String())
|
||||
assert.Equal(t, "Handling connection for 1111\n", out.String())
|
||||
}
|
||||
|
||||
func TestHandleConnectionSendsRemoteError(t *testing.T) {
|
||||
out := bytes.NewBufferString("")
|
||||
errOut := bytes.NewBufferString("")
|
||||
|
||||
pf, err := New(&fakeDialer{}, []string{":2222"}, nil, nil, out, errOut)
|
||||
if err != nil {
|
||||
t.Fatalf("error while calling New: %s", err)
|
||||
}
|
||||
|
||||
// Setup fake local connection
|
||||
localConnection := &fakeConn{
|
||||
sendBuffer: bytes.NewBufferString(""),
|
||||
receiveBuffer: bytes.NewBufferString(""),
|
||||
}
|
||||
|
||||
// Setup fake remote connection to return an error message on the error stream
|
||||
remoteDataToSend := bytes.NewBufferString("")
|
||||
remoteDataReceived := bytes.NewBufferString("")
|
||||
remoteErrorToSend := bytes.NewBufferString("error")
|
||||
remoteConnection := newFakeConnection()
|
||||
remoteConnection.dataStream.readFunc = remoteDataToSend.Read
|
||||
remoteConnection.dataStream.writeFunc = remoteDataReceived.Write
|
||||
remoteConnection.errorStream.readFunc = remoteErrorToSend.Read
|
||||
pf.streamConn = remoteConnection
|
||||
|
||||
// Test handleConnection, using go-routine because it needs to be able to write to unbuffered pf.errorChan
|
||||
pf.handleConnection(localConnection, ForwardedPort{Local: 1111, Remote: 2222})
|
||||
|
||||
assert.Equal(t, "", remoteDataReceived.String())
|
||||
assert.Equal(t, "", localConnection.receiveBuffer.String())
|
||||
assert.Equal(t, "Handling connection for 1111\n", out.String())
|
||||
}
|
||||
|
||||
func TestWaitForConnectionExitsOnStreamConnClosed(t *testing.T) {
|
||||
out := bytes.NewBufferString("")
|
||||
errOut := bytes.NewBufferString("")
|
||||
|
||||
pf, err := New(&fakeDialer{}, []string{":2222"}, nil, nil, out, errOut)
|
||||
if err != nil {
|
||||
t.Fatalf("error while calling New: %s", err)
|
||||
}
|
||||
|
||||
listener := newFakeListener()
|
||||
|
||||
pf.streamConn = newFakeConnection()
|
||||
pf.streamConn.Close()
|
||||
|
||||
port := ForwardedPort{}
|
||||
pf.waitForConnection(&listener, port)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user