mirror of
https://github.com/kubernetes/client-go.git
synced 2025-06-25 22:51:40 +00:00
Merge pull request #103526 from brianpursley/kubectl-686
Close connection and stop listening when port forwarding errors occur so that kubectl can exit Kubernetes-commit: cd6ffff85d257ff9067d59339f2ffdbcc66dc164
This commit is contained in:
commit
cc8a98c5db
@ -300,6 +300,10 @@ func (pf *PortForwarder) getListener(protocol string, hostname string, port *For
|
|||||||
// the background.
|
// the background.
|
||||||
func (pf *PortForwarder) waitForConnection(listener net.Listener, port ForwardedPort) {
|
func (pf *PortForwarder) waitForConnection(listener net.Listener, port ForwardedPort) {
|
||||||
for {
|
for {
|
||||||
|
select {
|
||||||
|
case <-pf.streamConn.CloseChan():
|
||||||
|
return
|
||||||
|
default:
|
||||||
conn, err := listener.Accept()
|
conn, err := listener.Accept()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// TODO consider using something like https://github.com/hydrogen18/stoppableListener?
|
// TODO consider using something like https://github.com/hydrogen18/stoppableListener?
|
||||||
@ -310,6 +314,7 @@ func (pf *PortForwarder) waitForConnection(listener net.Listener, port Forwarded
|
|||||||
}
|
}
|
||||||
go pf.handleConnection(conn, port)
|
go pf.handleConnection(conn, port)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pf *PortForwarder) nextRequestID() int {
|
func (pf *PortForwarder) nextRequestID() int {
|
||||||
@ -399,6 +404,7 @@ func (pf *PortForwarder) handleConnection(conn net.Conn, port ForwardedPort) {
|
|||||||
err = <-errorChan
|
err = <-errorChan
|
||||||
if err != nil {
|
if err != nil {
|
||||||
runtime.HandleError(err)
|
runtime.HandleError(err)
|
||||||
|
pf.streamConn.Close()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
|||||||
package portforward
|
package portforward
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
@ -27,6 +28,9 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
v1 "k8s.io/api/core/v1"
|
||||||
"k8s.io/apimachinery/pkg/util/httpstream"
|
"k8s.io/apimachinery/pkg/util/httpstream"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -45,16 +49,27 @@ func (d *fakeDialer) Dial(protocols ...string) (httpstream.Connection, string, e
|
|||||||
type fakeConnection struct {
|
type fakeConnection struct {
|
||||||
closed bool
|
closed bool
|
||||||
closeChan chan bool
|
closeChan chan bool
|
||||||
|
dataStream *fakeStream
|
||||||
|
errorStream *fakeStream
|
||||||
}
|
}
|
||||||
|
|
||||||
func newFakeConnection() httpstream.Connection {
|
func newFakeConnection() *fakeConnection {
|
||||||
return &fakeConnection{
|
return &fakeConnection{
|
||||||
closeChan: make(chan bool),
|
closeChan: make(chan bool),
|
||||||
|
dataStream: &fakeStream{},
|
||||||
|
errorStream: &fakeStream{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *fakeConnection) CreateStream(headers http.Header) (httpstream.Stream, error) {
|
func (c *fakeConnection) CreateStream(headers http.Header) (httpstream.Stream, error) {
|
||||||
return nil, nil
|
switch headers.Get(v1.StreamType) {
|
||||||
|
case v1.StreamTypeData:
|
||||||
|
return c.dataStream, nil
|
||||||
|
case v1.StreamTypeError:
|
||||||
|
return c.errorStream, nil
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("fakeStream creation not supported for stream type %s", headers.Get(v1.StreamType))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *fakeConnection) Close() error {
|
func (c *fakeConnection) Close() error {
|
||||||
@ -76,6 +91,65 @@ func (c *fakeConnection) SetIdleTimeout(timeout time.Duration) {
|
|||||||
// no-op
|
// no-op
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type fakeListener struct {
|
||||||
|
net.Listener
|
||||||
|
closeChan chan bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func newFakeListener() fakeListener {
|
||||||
|
return fakeListener{
|
||||||
|
closeChan: make(chan bool),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *fakeListener) Accept() (net.Conn, error) {
|
||||||
|
select {
|
||||||
|
case <-l.closeChan:
|
||||||
|
return nil, fmt.Errorf("listener closed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *fakeListener) Close() error {
|
||||||
|
close(l.closeChan)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *fakeListener) Addr() net.Addr {
|
||||||
|
return fakeAddr{}
|
||||||
|
}
|
||||||
|
|
||||||
|
type fakeAddr struct{}
|
||||||
|
|
||||||
|
func (fakeAddr) Network() string { return "fake" }
|
||||||
|
func (fakeAddr) String() string { return "fake" }
|
||||||
|
|
||||||
|
type fakeStream struct {
|
||||||
|
headers http.Header
|
||||||
|
readFunc func(p []byte) (int, error)
|
||||||
|
writeFunc func(p []byte) (int, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fakeStream) Read(p []byte) (n int, err error) { return s.readFunc(p) }
|
||||||
|
func (s *fakeStream) Write(p []byte) (n int, err error) { return s.writeFunc(p) }
|
||||||
|
func (*fakeStream) Close() error { return nil }
|
||||||
|
func (*fakeStream) Reset() error { return nil }
|
||||||
|
func (s *fakeStream) Headers() http.Header { return s.headers }
|
||||||
|
func (*fakeStream) Identifier() uint32 { return 0 }
|
||||||
|
|
||||||
|
type fakeConn struct {
|
||||||
|
sendBuffer *bytes.Buffer
|
||||||
|
receiveBuffer *bytes.Buffer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f fakeConn) Read(p []byte) (int, error) { return f.sendBuffer.Read(p) }
|
||||||
|
func (f fakeConn) Write(p []byte) (int, error) { return f.receiveBuffer.Write(p) }
|
||||||
|
func (fakeConn) Close() error { return nil }
|
||||||
|
func (fakeConn) LocalAddr() net.Addr { return nil }
|
||||||
|
func (fakeConn) RemoteAddr() net.Addr { return nil }
|
||||||
|
func (fakeConn) SetDeadline(t time.Time) error { return nil }
|
||||||
|
func (fakeConn) SetReadDeadline(t time.Time) error { return nil }
|
||||||
|
func (fakeConn) SetWriteDeadline(t time.Time) error { return nil }
|
||||||
|
|
||||||
func TestParsePortsAndNew(t *testing.T) {
|
func TestParsePortsAndNew(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
input []string
|
input []string
|
||||||
@ -393,3 +467,96 @@ func TestGetPortsReturnsDynamicallyAssignedLocalPort(t *testing.T) {
|
|||||||
t.Fatalf("local port is 0, expected != 0")
|
t.Fatalf("local port is 0, expected != 0")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestHandleConnection(t *testing.T) {
|
||||||
|
out := bytes.NewBufferString("")
|
||||||
|
|
||||||
|
pf, err := New(&fakeDialer{}, []string{":2222"}, nil, nil, out, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("error while calling New: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setup fake local connection
|
||||||
|
localConnection := &fakeConn{
|
||||||
|
sendBuffer: bytes.NewBufferString("test data from local"),
|
||||||
|
receiveBuffer: bytes.NewBufferString(""),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setup fake remote connection to send data on the data stream after it receives data from the local connection
|
||||||
|
remoteDataToSend := bytes.NewBufferString("test data from remote")
|
||||||
|
remoteDataReceived := bytes.NewBufferString("")
|
||||||
|
remoteErrorToSend := bytes.NewBufferString("")
|
||||||
|
blockRemoteSend := make(chan struct{})
|
||||||
|
remoteConnection := newFakeConnection()
|
||||||
|
remoteConnection.dataStream.readFunc = func(p []byte) (int, error) {
|
||||||
|
<-blockRemoteSend // Wait for the expected data to be received before responding
|
||||||
|
return remoteDataToSend.Read(p)
|
||||||
|
}
|
||||||
|
remoteConnection.dataStream.writeFunc = func(p []byte) (int, error) {
|
||||||
|
n, err := remoteDataReceived.Write(p)
|
||||||
|
if remoteDataReceived.String() == "test data from local" {
|
||||||
|
close(blockRemoteSend)
|
||||||
|
}
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
remoteConnection.errorStream.readFunc = remoteErrorToSend.Read
|
||||||
|
pf.streamConn = remoteConnection
|
||||||
|
|
||||||
|
// Test handleConnection
|
||||||
|
pf.handleConnection(localConnection, ForwardedPort{Local: 1111, Remote: 2222})
|
||||||
|
|
||||||
|
assert.Equal(t, "test data from local", remoteDataReceived.String())
|
||||||
|
assert.Equal(t, "test data from remote", localConnection.receiveBuffer.String())
|
||||||
|
assert.Equal(t, "Handling connection for 1111\n", out.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleConnectionSendsRemoteError(t *testing.T) {
|
||||||
|
out := bytes.NewBufferString("")
|
||||||
|
errOut := bytes.NewBufferString("")
|
||||||
|
|
||||||
|
pf, err := New(&fakeDialer{}, []string{":2222"}, nil, nil, out, errOut)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("error while calling New: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setup fake local connection
|
||||||
|
localConnection := &fakeConn{
|
||||||
|
sendBuffer: bytes.NewBufferString(""),
|
||||||
|
receiveBuffer: bytes.NewBufferString(""),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setup fake remote connection to return an error message on the error stream
|
||||||
|
remoteDataToSend := bytes.NewBufferString("")
|
||||||
|
remoteDataReceived := bytes.NewBufferString("")
|
||||||
|
remoteErrorToSend := bytes.NewBufferString("error")
|
||||||
|
remoteConnection := newFakeConnection()
|
||||||
|
remoteConnection.dataStream.readFunc = remoteDataToSend.Read
|
||||||
|
remoteConnection.dataStream.writeFunc = remoteDataReceived.Write
|
||||||
|
remoteConnection.errorStream.readFunc = remoteErrorToSend.Read
|
||||||
|
pf.streamConn = remoteConnection
|
||||||
|
|
||||||
|
// Test handleConnection, using go-routine because it needs to be able to write to unbuffered pf.errorChan
|
||||||
|
pf.handleConnection(localConnection, ForwardedPort{Local: 1111, Remote: 2222})
|
||||||
|
|
||||||
|
assert.Equal(t, "", remoteDataReceived.String())
|
||||||
|
assert.Equal(t, "", localConnection.receiveBuffer.String())
|
||||||
|
assert.Equal(t, "Handling connection for 1111\n", out.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWaitForConnectionExitsOnStreamConnClosed(t *testing.T) {
|
||||||
|
out := bytes.NewBufferString("")
|
||||||
|
errOut := bytes.NewBufferString("")
|
||||||
|
|
||||||
|
pf, err := New(&fakeDialer{}, []string{":2222"}, nil, nil, out, errOut)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("error while calling New: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
listener := newFakeListener()
|
||||||
|
|
||||||
|
pf.streamConn = newFakeConnection()
|
||||||
|
pf.streamConn.Close()
|
||||||
|
|
||||||
|
port := ForwardedPort{}
|
||||||
|
pf.waitForConnection(&listener, port)
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user