mirror of
https://github.com/k3s-io/kubernetes.git
synced 2025-07-25 12:43:23 +00:00
Merge pull request #12283 from ncdc/gh8766-port-forward-not-closing-correctly
Auto commit by PR queue bot
This commit is contained in:
commit
1bcdd56cf3
@ -1951,14 +1951,24 @@ const (
|
||||
// Command to run for remote command execution
|
||||
ExecCommandParamm = "command"
|
||||
|
||||
StreamType = "streamType"
|
||||
StreamTypeStdin = "stdin"
|
||||
// Name of header that specifies stream type
|
||||
StreamType = "streamType"
|
||||
// Value for streamType header for stdin stream
|
||||
StreamTypeStdin = "stdin"
|
||||
// Value for streamType header for stdout stream
|
||||
StreamTypeStdout = "stdout"
|
||||
// Value for streamType header for stderr stream
|
||||
StreamTypeStderr = "stderr"
|
||||
StreamTypeData = "data"
|
||||
StreamTypeError = "error"
|
||||
// Value for streamType header for data stream
|
||||
StreamTypeData = "data"
|
||||
// Value for streamType header for error stream
|
||||
StreamTypeError = "error"
|
||||
|
||||
// Name of header that specifies the port being forwarded
|
||||
PortHeader = "port"
|
||||
// Name of header that specifies a request ID used to associate the error
|
||||
// and data streams for a single forwarded connection
|
||||
PortForwardRequestIDHeader = "requestID"
|
||||
)
|
||||
|
||||
// Similarly to above, these are constants to support HTTP PATCH utilized by
|
||||
|
@ -25,10 +25,12 @@ import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/golang/glog"
|
||||
"k8s.io/kubernetes/pkg/api"
|
||||
client "k8s.io/kubernetes/pkg/client/unversioned"
|
||||
"k8s.io/kubernetes/pkg/util"
|
||||
"k8s.io/kubernetes/pkg/util/httpstream"
|
||||
"k8s.io/kubernetes/pkg/util/httpstream/spdy"
|
||||
)
|
||||
@ -51,10 +53,12 @@ type PortForwarder struct {
|
||||
ports []ForwardedPort
|
||||
stopChan <-chan struct{}
|
||||
|
||||
streamConn httpstream.Connection
|
||||
listeners []io.Closer
|
||||
upgrader upgrader
|
||||
Ready chan struct{}
|
||||
streamConn httpstream.Connection
|
||||
listeners []io.Closer
|
||||
upgrader upgrader
|
||||
Ready chan struct{}
|
||||
requestIDLock sync.Mutex
|
||||
requestID int
|
||||
}
|
||||
|
||||
// ForwardedPort contains a Local:Remote port pairing.
|
||||
@ -145,7 +149,7 @@ func (pf *PortForwarder) ForwardPorts() error {
|
||||
var err error
|
||||
pf.streamConn, err = pf.upgrader.upgrade(pf.req, pf.config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error upgrading connection: %s", err)
|
||||
return fmt.Errorf("error upgrading connection: %s", err)
|
||||
}
|
||||
defer pf.streamConn.Close()
|
||||
|
||||
@ -179,7 +183,7 @@ func (pf *PortForwarder) forward() error {
|
||||
select {
|
||||
case <-pf.stopChan:
|
||||
case <-pf.streamConn.CloseChan():
|
||||
glog.Errorf("Lost connection to pod")
|
||||
util.HandleError(errors.New("lost connection to pod"))
|
||||
}
|
||||
|
||||
return nil
|
||||
@ -213,7 +217,7 @@ func (pf *PortForwarder) listenOnPortAndAddress(port *ForwardedPort, protocol st
|
||||
func (pf *PortForwarder) getListener(protocol string, hostname string, port *ForwardedPort) (net.Listener, error) {
|
||||
listener, err := net.Listen(protocol, fmt.Sprintf("%s:%d", hostname, port.Local))
|
||||
if err != nil {
|
||||
glog.Errorf("Unable to create listener: Error %s", err)
|
||||
util.HandleError(fmt.Errorf("Unable to create listener: Error %s", err))
|
||||
return nil, err
|
||||
}
|
||||
listenerAddress := listener.Addr().String()
|
||||
@ -237,7 +241,7 @@ func (pf *PortForwarder) waitForConnection(listener net.Listener, port Forwarded
|
||||
if err != nil {
|
||||
// TODO consider using something like https://github.com/hydrogen18/stoppableListener?
|
||||
if !strings.Contains(strings.ToLower(err.Error()), "use of closed network connection") {
|
||||
glog.Errorf("Error accepting connection on port %d: %v", port.Local, err)
|
||||
util.HandleError(fmt.Errorf("Error accepting connection on port %d: %v", port.Local, err))
|
||||
}
|
||||
return
|
||||
}
|
||||
@ -245,6 +249,14 @@ func (pf *PortForwarder) waitForConnection(listener net.Listener, port Forwarded
|
||||
}
|
||||
}
|
||||
|
||||
func (pf *PortForwarder) nextRequestID() int {
|
||||
pf.requestIDLock.Lock()
|
||||
defer pf.requestIDLock.Unlock()
|
||||
id := pf.requestID
|
||||
pf.requestID++
|
||||
return id
|
||||
}
|
||||
|
||||
// handleConnection copies data between the local connection and the stream to
|
||||
// the remote server.
|
||||
func (pf *PortForwarder) handleConnection(conn net.Conn, port ForwardedPort) {
|
||||
@ -252,65 +264,76 @@ func (pf *PortForwarder) handleConnection(conn net.Conn, port ForwardedPort) {
|
||||
|
||||
glog.Infof("Handling connection for %d", port.Local)
|
||||
|
||||
errorChan := make(chan error)
|
||||
doneChan := make(chan struct{}, 2)
|
||||
requestID := pf.nextRequestID()
|
||||
|
||||
// create error stream
|
||||
headers := http.Header{}
|
||||
headers.Set(api.StreamType, api.StreamTypeError)
|
||||
headers.Set(api.PortHeader, fmt.Sprintf("%d", port.Remote))
|
||||
headers.Set(api.PortForwardRequestIDHeader, strconv.Itoa(requestID))
|
||||
errorStream, err := pf.streamConn.CreateStream(headers)
|
||||
if err != nil {
|
||||
glog.Errorf("Error creating error stream for port %d -> %d: %v", port.Local, port.Remote, err)
|
||||
util.HandleError(fmt.Errorf("error creating error stream for port %d -> %d: %v", port.Local, port.Remote, err))
|
||||
return
|
||||
}
|
||||
defer errorStream.Reset()
|
||||
// we're not writing to this stream
|
||||
errorStream.Close()
|
||||
|
||||
errorChan := make(chan error)
|
||||
go func() {
|
||||
message, err := ioutil.ReadAll(errorStream)
|
||||
if err != nil && err != io.EOF {
|
||||
errorChan <- fmt.Errorf("Error reading from error stream for port %d -> %d: %v", port.Local, port.Remote, err)
|
||||
}
|
||||
if len(message) > 0 {
|
||||
errorChan <- fmt.Errorf("An error occurred forwarding %d -> %d: %v", port.Local, port.Remote, string(message))
|
||||
switch {
|
||||
case err != nil:
|
||||
errorChan <- fmt.Errorf("error reading from error stream for port %d -> %d: %v", port.Local, port.Remote, err)
|
||||
case len(message) > 0:
|
||||
errorChan <- fmt.Errorf("an error occurred forwarding %d -> %d: %v", port.Local, port.Remote, string(message))
|
||||
}
|
||||
close(errorChan)
|
||||
}()
|
||||
|
||||
// create data stream
|
||||
headers.Set(api.StreamType, api.StreamTypeData)
|
||||
dataStream, err := pf.streamConn.CreateStream(headers)
|
||||
if err != nil {
|
||||
glog.Errorf("Error creating forwarding stream for port %d -> %d: %v", port.Local, port.Remote, err)
|
||||
util.HandleError(fmt.Errorf("error creating forwarding stream for port %d -> %d: %v", port.Local, port.Remote, err))
|
||||
return
|
||||
}
|
||||
// Send a Reset when this function exits to completely tear down the stream here
|
||||
// and in the remote server.
|
||||
defer dataStream.Reset()
|
||||
|
||||
localError := make(chan struct{})
|
||||
remoteDone := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
// Copy from the remote side to the local port. We won't get an EOF from
|
||||
// the server as it has no way of knowing when to close the stream. We'll
|
||||
// take care of closing both ends of the stream with the call to
|
||||
// stream.Reset() when this function exits.
|
||||
if _, err := io.Copy(conn, dataStream); err != nil && err != io.EOF && !strings.Contains(err.Error(), "use of closed network connection") {
|
||||
glog.Errorf("Error copying from remote stream to local connection: %v", err)
|
||||
// Copy from the remote side to the local port.
|
||||
if _, err := io.Copy(conn, dataStream); err != nil && !strings.Contains(err.Error(), "use of closed network connection") {
|
||||
util.HandleError(fmt.Errorf("error copying from remote stream to local connection: %v", err))
|
||||
}
|
||||
doneChan <- struct{}{}
|
||||
|
||||
// inform the select below that the remote copy is done
|
||||
close(remoteDone)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
// Copy from the local port to the remote side. Here we will be able to know
|
||||
// when the Copy gets an EOF from conn, as that will happen as soon as conn is
|
||||
// closed (i.e. client disconnected).
|
||||
if _, err := io.Copy(dataStream, conn); err != nil && err != io.EOF && !strings.Contains(err.Error(), "use of closed network connection") {
|
||||
glog.Errorf("Error copying from local connection to remote stream: %v", err)
|
||||
// inform server we're not sending any more data after copy unblocks
|
||||
defer dataStream.Close()
|
||||
|
||||
// Copy from the local port to the remote side.
|
||||
if _, err := io.Copy(dataStream, conn); err != nil && !strings.Contains(err.Error(), "use of closed network connection") {
|
||||
util.HandleError(fmt.Errorf("error copying from local connection to remote stream: %v", err))
|
||||
// break out of the select below without waiting for the other copy to finish
|
||||
close(localError)
|
||||
}
|
||||
doneChan <- struct{}{}
|
||||
}()
|
||||
|
||||
// wait for either a local->remote error or for copying from remote->local to finish
|
||||
select {
|
||||
case err := <-errorChan:
|
||||
glog.Error(err)
|
||||
case <-doneChan:
|
||||
case <-remoteDone:
|
||||
case <-localError:
|
||||
}
|
||||
|
||||
// always expect something on errorChan (it may be nil)
|
||||
err = <-errorChan
|
||||
if err != nil {
|
||||
util.HandleError(err)
|
||||
}
|
||||
}
|
||||
|
||||
@ -318,7 +341,7 @@ func (pf *PortForwarder) Close() {
|
||||
// stop all listeners
|
||||
for _, l := range pf.listeners {
|
||||
if err := l.Close(); err != nil {
|
||||
glog.Errorf("Error closing listener: %v", err)
|
||||
util.HandleError(fmt.Errorf("error closing listener: %v", err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -18,20 +18,21 @@ package portforward
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"k8s.io/kubernetes/pkg/api"
|
||||
client "k8s.io/kubernetes/pkg/client/unversioned"
|
||||
"k8s.io/kubernetes/pkg/util/httpstream"
|
||||
"k8s.io/kubernetes/pkg/kubelet"
|
||||
"k8s.io/kubernetes/pkg/types"
|
||||
)
|
||||
|
||||
func TestParsePortsAndNew(t *testing.T) {
|
||||
@ -110,109 +111,6 @@ func TestParsePortsAndNew(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
type fakeUpgrader struct {
|
||||
conn *fakeUpgradeConnection
|
||||
err error
|
||||
}
|
||||
|
||||
func (u *fakeUpgrader) upgrade(req *client.Request, config *client.Config) (httpstream.Connection, error) {
|
||||
return u.conn, u.err
|
||||
}
|
||||
|
||||
type fakeUpgradeConnection struct {
|
||||
closeCalled bool
|
||||
lock sync.Mutex
|
||||
streams map[string]*fakeUpgradeStream
|
||||
portData map[string]string
|
||||
}
|
||||
|
||||
func newFakeUpgradeConnection() *fakeUpgradeConnection {
|
||||
return &fakeUpgradeConnection{
|
||||
streams: make(map[string]*fakeUpgradeStream),
|
||||
portData: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *fakeUpgradeConnection) CreateStream(headers http.Header) (httpstream.Stream, error) {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
|
||||
stream := &fakeUpgradeStream{}
|
||||
c.streams[headers.Get(api.PortHeader)] = stream
|
||||
// only simulate data on the data stream for now, not the error stream
|
||||
if headers.Get(api.StreamType) == api.StreamTypeData {
|
||||
stream.data = c.portData[headers.Get(api.PortHeader)]
|
||||
}
|
||||
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
func (c *fakeUpgradeConnection) Close() error {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
|
||||
c.closeCalled = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *fakeUpgradeConnection) CloseChan() <-chan bool {
|
||||
return make(chan bool)
|
||||
}
|
||||
|
||||
func (c *fakeUpgradeConnection) SetIdleTimeout(timeout time.Duration) {
|
||||
}
|
||||
|
||||
type fakeUpgradeStream struct {
|
||||
readCalled bool
|
||||
writeCalled bool
|
||||
dataWritten []byte
|
||||
closeCalled bool
|
||||
resetCalled bool
|
||||
data string
|
||||
lock sync.Mutex
|
||||
}
|
||||
|
||||
func (s *fakeUpgradeStream) Read(p []byte) (int, error) {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
s.readCalled = true
|
||||
b := []byte(s.data)
|
||||
n := copy(p, b)
|
||||
// Indicate we returned all the data, and have no more data (EOF)
|
||||
// Returning an EOF here will cause the port forwarder to immediately terminate, which is correct when we have no more data to send
|
||||
return n, io.EOF
|
||||
}
|
||||
|
||||
func (s *fakeUpgradeStream) Write(p []byte) (int, error) {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
s.writeCalled = true
|
||||
s.dataWritten = append(s.dataWritten, p...)
|
||||
// Indicate the stream accepted all the data, and can accept more (no err)
|
||||
// Returning an EOF here will cause the port forwarder to immediately terminate, which is incorrect, in case someone writes more data
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (s *fakeUpgradeStream) Close() error {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
s.closeCalled = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *fakeUpgradeStream) Reset() error {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
s.resetCalled = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *fakeUpgradeStream) Headers() http.Header {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
return http.Header{}
|
||||
}
|
||||
|
||||
type GetListenerTestCase struct {
|
||||
Hostname string
|
||||
Protocol string
|
||||
@ -295,55 +193,119 @@ func TestGetListener(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// fakePortForwarder simulates port forwarding for testing. It implements
|
||||
// kubelet.PortForwarder.
|
||||
type fakePortForwarder struct {
|
||||
lock sync.Mutex
|
||||
// stores data received from the stream per port
|
||||
received map[uint16]string
|
||||
// data to be sent to the stream per port
|
||||
send map[uint16]string
|
||||
}
|
||||
|
||||
var _ kubelet.PortForwarder = &fakePortForwarder{}
|
||||
|
||||
func (pf *fakePortForwarder) PortForward(name string, uid types.UID, port uint16, stream io.ReadWriteCloser) error {
|
||||
defer stream.Close()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// client -> server
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
// copy from stream into a buffer
|
||||
received := new(bytes.Buffer)
|
||||
io.Copy(received, stream)
|
||||
|
||||
// store the received content
|
||||
pf.lock.Lock()
|
||||
pf.received[port] = received.String()
|
||||
pf.lock.Unlock()
|
||||
}()
|
||||
|
||||
// server -> client
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
// send the hardcoded data to the stream
|
||||
io.Copy(stream, strings.NewReader(pf.send[port]))
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// fakePortForwardServer creates an HTTP server that can handle port forwarding
|
||||
// requests.
|
||||
func fakePortForwardServer(t *testing.T, testName string, serverSends, expectedFromClient map[uint16]string) http.HandlerFunc {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
pf := &fakePortForwarder{
|
||||
received: make(map[uint16]string),
|
||||
send: serverSends,
|
||||
}
|
||||
kubelet.ServePortForward(w, req, pf, "pod", "uid", 0, 10*time.Second)
|
||||
|
||||
for port, expected := range expectedFromClient {
|
||||
actual, ok := pf.received[port]
|
||||
if !ok {
|
||||
t.Errorf("%s: server didn't receive any data for port %d", testName, port)
|
||||
continue
|
||||
}
|
||||
|
||||
if expected != actual {
|
||||
t.Errorf("%s: server expected to receive %q, got %q for port %d", testName, expected, actual, port)
|
||||
}
|
||||
}
|
||||
|
||||
for port, actual := range pf.received {
|
||||
if _, ok := expectedFromClient[port]; !ok {
|
||||
t.Errorf("%s: server unexpectedly received %q for port %d", testName, actual, port)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestForwardPorts(t *testing.T) {
|
||||
testCases := []struct {
|
||||
Upgrader *fakeUpgrader
|
||||
Ports []string
|
||||
Send map[uint16]string
|
||||
Receive map[uint16]string
|
||||
Err bool
|
||||
tests := map[string]struct {
|
||||
ports []string
|
||||
clientSends map[uint16]string
|
||||
serverSends map[uint16]string
|
||||
}{
|
||||
{
|
||||
Upgrader: &fakeUpgrader{err: errors.New("bail")},
|
||||
Err: true,
|
||||
"forward 1 port with no data either direction": {
|
||||
ports: []string{"5000"},
|
||||
},
|
||||
{
|
||||
Upgrader: &fakeUpgrader{conn: newFakeUpgradeConnection()},
|
||||
Ports: []string{"5000"},
|
||||
},
|
||||
{
|
||||
Upgrader: &fakeUpgrader{conn: newFakeUpgradeConnection()},
|
||||
Ports: []string{"5001", "6000"},
|
||||
Send: map[uint16]string{
|
||||
"forward 2 ports with bidirectional data": {
|
||||
ports: []string{"5001", "6000"},
|
||||
clientSends: map[uint16]string{
|
||||
5001: "abcd",
|
||||
6000: "ghij",
|
||||
},
|
||||
Receive: map[uint16]string{
|
||||
serverSends: map[uint16]string{
|
||||
5001: "1234",
|
||||
6000: "5678",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for i, testCase := range testCases {
|
||||
for testName, test := range tests {
|
||||
server := httptest.NewServer(fakePortForwardServer(t, testName, test.serverSends, test.clientSends))
|
||||
url, _ := url.ParseRequestURI(server.URL)
|
||||
c := client.NewRESTClient(url, "x", nil, -1, -1)
|
||||
req := c.Post().Resource("testing")
|
||||
|
||||
conf := &client.Config{
|
||||
Host: server.URL,
|
||||
}
|
||||
|
||||
stopChan := make(chan struct{}, 1)
|
||||
|
||||
pf, err := New(&client.Request{}, &client.Config{}, testCase.Ports, stopChan)
|
||||
hasErr := err != nil
|
||||
if hasErr != testCase.Err {
|
||||
t.Fatalf("%d: New: expected %t, got %t: %v", i, testCase.Err, hasErr, err)
|
||||
}
|
||||
if pf == nil {
|
||||
continue
|
||||
}
|
||||
pf.upgrader = testCase.Upgrader
|
||||
if testCase.Upgrader.err != nil {
|
||||
err := pf.ForwardPorts()
|
||||
hasErr := err != nil
|
||||
if hasErr != testCase.Err {
|
||||
t.Fatalf("%d: ForwardPorts: expected %t, got %t: %v", i, testCase.Err, hasErr, err)
|
||||
}
|
||||
continue
|
||||
pf, err := New(req, conf, test.ports, stopChan)
|
||||
if err != nil {
|
||||
t.Fatalf("%s: unexpected error calling New: %v", testName, err)
|
||||
}
|
||||
|
||||
doneChan := make(chan error)
|
||||
@ -352,65 +314,70 @@ func TestForwardPorts(t *testing.T) {
|
||||
}()
|
||||
<-pf.Ready
|
||||
|
||||
conn := testCase.Upgrader.conn
|
||||
|
||||
for port, data := range testCase.Send {
|
||||
conn.lock.Lock()
|
||||
conn.portData[fmt.Sprintf("%d", port)] = testCase.Receive[port]
|
||||
conn.lock.Unlock()
|
||||
|
||||
for port, data := range test.clientSends {
|
||||
clientConn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
if err != nil {
|
||||
t.Fatalf("%d: error dialing %d: %s", i, port, err)
|
||||
t.Errorf("%s: error dialing %d: %s", testName, port, err)
|
||||
server.Close()
|
||||
continue
|
||||
}
|
||||
defer clientConn.Close()
|
||||
|
||||
n, err := clientConn.Write([]byte(data))
|
||||
if err != nil && err != io.EOF {
|
||||
t.Fatalf("%d: Error sending data '%s': %s", i, data, err)
|
||||
t.Errorf("%s: Error sending data '%s': %s", testName, data, err)
|
||||
server.Close()
|
||||
continue
|
||||
}
|
||||
if n == 0 {
|
||||
t.Fatalf("%d: unexpected write of 0 bytes", i)
|
||||
t.Errorf("%s: unexpected write of 0 bytes", testName)
|
||||
server.Close()
|
||||
continue
|
||||
}
|
||||
b := make([]byte, 4)
|
||||
n, err = clientConn.Read(b)
|
||||
if err != nil && err != io.EOF {
|
||||
t.Fatalf("%d: Error reading data: %s", i, err)
|
||||
t.Errorf("%s: Error reading data: %s", testName, err)
|
||||
server.Close()
|
||||
continue
|
||||
}
|
||||
if !bytes.Equal([]byte(testCase.Receive[port]), b) {
|
||||
t.Fatalf("%d: expected to read '%s', got '%s'", i, testCase.Receive[port], b)
|
||||
if !bytes.Equal([]byte(test.serverSends[port]), b) {
|
||||
t.Errorf("%s: expected to read '%s', got '%s'", testName, test.serverSends[port], b)
|
||||
server.Close()
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// tell r.ForwardPorts to stop
|
||||
close(stopChan)
|
||||
|
||||
// wait for r.ForwardPorts to actually return
|
||||
err = <-doneChan
|
||||
if err != nil {
|
||||
t.Fatalf("%d: unexpected error: %s", i, err)
|
||||
}
|
||||
|
||||
if e, a := len(testCase.Send), len(conn.streams); e != a {
|
||||
t.Fatalf("%d: expected %d streams to be created, got %d", i, e, a)
|
||||
}
|
||||
|
||||
if !conn.closeCalled {
|
||||
t.Fatalf("%d: expected conn closure", i)
|
||||
t.Errorf("%s: unexpected error: %s", testName, err)
|
||||
}
|
||||
server.Close()
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestForwardPortsReturnsErrorWhenAllBindsFailed(t *testing.T) {
|
||||
server := httptest.NewServer(fakePortForwardServer(t, "allBindsFailed", nil, nil))
|
||||
defer server.Close()
|
||||
url, _ := url.ParseRequestURI(server.URL)
|
||||
c := client.NewRESTClient(url, "x", nil, -1, -1)
|
||||
req := c.Post().Resource("testing")
|
||||
|
||||
conf := &client.Config{
|
||||
Host: server.URL,
|
||||
}
|
||||
|
||||
stopChan1 := make(chan struct{}, 1)
|
||||
defer close(stopChan1)
|
||||
|
||||
pf1, err := New(&client.Request{}, &client.Config{}, []string{"5555"}, stopChan1)
|
||||
pf1, err := New(req, conf, []string{"5555"}, stopChan1)
|
||||
if err != nil {
|
||||
t.Fatalf("error creating pf1: %v", err)
|
||||
}
|
||||
pf1.upgrader = &fakeUpgrader{conn: newFakeUpgradeConnection()}
|
||||
go pf1.ForwardPorts()
|
||||
<-pf1.Ready
|
||||
|
||||
@ -419,7 +386,6 @@ func TestForwardPortsReturnsErrorWhenAllBindsFailed(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("error creating pf2: %v", err)
|
||||
}
|
||||
pf2.upgrader = &fakeUpgrader{conn: newFakeUpgradeConnection()}
|
||||
if err := pf2.ForwardPorts(); err == nil {
|
||||
t.Fatal("expected non-nil error for pf2.ForwardPorts")
|
||||
}
|
||||
|
@ -1179,16 +1179,39 @@ func (dm *DockerManager) PortForward(pod *kubecontainer.Pod, port uint16, stream
|
||||
}
|
||||
|
||||
containerPid := container.State.Pid
|
||||
// TODO what if the host doesn't have it???
|
||||
_, lookupErr := exec.LookPath("socat")
|
||||
socatPath, lookupErr := exec.LookPath("socat")
|
||||
if lookupErr != nil {
|
||||
return fmt.Errorf("Unable to do port forwarding: socat not found.")
|
||||
return fmt.Errorf("unable to do port forwarding: socat not found.")
|
||||
}
|
||||
args := []string{"-t", fmt.Sprintf("%d", containerPid), "-n", "socat", "-", fmt.Sprintf("TCP4:localhost:%d", port)}
|
||||
// TODO use exec.LookPath
|
||||
command := exec.Command("nsenter", args...)
|
||||
command.Stdin = stream
|
||||
|
||||
args := []string{"-t", fmt.Sprintf("%d", containerPid), "-n", socatPath, "-", fmt.Sprintf("TCP4:localhost:%d", port)}
|
||||
|
||||
nsenterPath, lookupErr := exec.LookPath("nsenter")
|
||||
if lookupErr != nil {
|
||||
return fmt.Errorf("unable to do port forwarding: nsenter not found.")
|
||||
}
|
||||
|
||||
command := exec.Command(nsenterPath, args...)
|
||||
command.Stdout = stream
|
||||
|
||||
// If we use Stdin, command.Run() won't return until the goroutine that's copying
|
||||
// from stream finishes. Unfortunately, if you have a client like telnet connected
|
||||
// via port forwarding, as long as the user's telnet client is connected to the user's
|
||||
// local listener that port forwarding sets up, the telnet session never exits. This
|
||||
// means that even if socat has finished running, command.Run() won't ever return
|
||||
// (because the client still has the connection and stream open).
|
||||
//
|
||||
// The work around is to use StdinPipe(), as Wait() (called by Run()) closes the pipe
|
||||
// when the command (socat) exits.
|
||||
inPipe, err := command.StdinPipe()
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to do port forwarding: error creating stdin pipe: %v", err)
|
||||
}
|
||||
go func() {
|
||||
io.Copy(inPipe, stream)
|
||||
inPipe.Close()
|
||||
}()
|
||||
|
||||
return command.Run()
|
||||
}
|
||||
|
||||
|
@ -1762,6 +1762,12 @@ func TestSyncPodEventHandlerFails(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
type fakeReadWriteCloser struct{}
|
||||
|
||||
func (*fakeReadWriteCloser) Read([]byte) (int, error) { return 0, nil }
|
||||
func (*fakeReadWriteCloser) Write([]byte) (int, error) { return 0, nil }
|
||||
func (*fakeReadWriteCloser) Close() error { return nil }
|
||||
|
||||
func TestPortForwardNoSuchContainer(t *testing.T) {
|
||||
dm, _ := newTestDockerManager()
|
||||
|
||||
@ -1774,7 +1780,8 @@ func TestPortForwardNoSuchContainer(t *testing.T) {
|
||||
Containers: nil,
|
||||
},
|
||||
5000,
|
||||
nil,
|
||||
// need a valid io.ReadWriteCloser here
|
||||
&fakeReadWriteCloser{},
|
||||
)
|
||||
if err == nil {
|
||||
t.Fatal("unexpected non-error")
|
||||
|
@ -1220,19 +1220,39 @@ func (r *runtime) PortForward(pod *kubecontainer.Pod, port uint16, stream io.Rea
|
||||
return err
|
||||
}
|
||||
|
||||
_, lookupErr := exec.LookPath("socat")
|
||||
socatPath, lookupErr := exec.LookPath("socat")
|
||||
if lookupErr != nil {
|
||||
return fmt.Errorf("unable to do port forwarding: socat not found.")
|
||||
}
|
||||
args := []string{"-t", fmt.Sprintf("%d", info.pid), "-n", "socat", "-", fmt.Sprintf("TCP4:localhost:%d", port)}
|
||||
|
||||
_, lookupErr = exec.LookPath("nsenter")
|
||||
args := []string{"-t", fmt.Sprintf("%d", info.pid), "-n", socatPath, "-", fmt.Sprintf("TCP4:localhost:%d", port)}
|
||||
|
||||
nsenterPath, lookupErr := exec.LookPath("nsenter")
|
||||
if lookupErr != nil {
|
||||
return fmt.Errorf("unable to do port forwarding: nsenter not found.")
|
||||
}
|
||||
command := exec.Command("nsenter", args...)
|
||||
command.Stdin = stream
|
||||
|
||||
command := exec.Command(nsenterPath, args...)
|
||||
command.Stdout = stream
|
||||
|
||||
// If we use Stdin, command.Run() won't return until the goroutine that's copying
|
||||
// from stream finishes. Unfortunately, if you have a client like telnet connected
|
||||
// via port forwarding, as long as the user's telnet client is connected to the user's
|
||||
// local listener that port forwarding sets up, the telnet session never exits. This
|
||||
// means that even if socat has finished running, command.Run() won't ever return
|
||||
// (because the client still has the connection and stream open).
|
||||
//
|
||||
// The work around is to use StdinPipe(), as Wait() (called by Run()) closes the pipe
|
||||
// when the command (socat) exits.
|
||||
inPipe, err := command.StdinPipe()
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to do port forwarding: error creating stdin pipe: %v", err)
|
||||
}
|
||||
go func() {
|
||||
io.Copy(inPipe, stream)
|
||||
inPipe.Close()
|
||||
}()
|
||||
|
||||
return command.Run()
|
||||
}
|
||||
|
||||
|
@ -45,6 +45,7 @@ import (
|
||||
"k8s.io/kubernetes/pkg/httplog"
|
||||
kubecontainer "k8s.io/kubernetes/pkg/kubelet/container"
|
||||
"k8s.io/kubernetes/pkg/types"
|
||||
"k8s.io/kubernetes/pkg/util"
|
||||
"k8s.io/kubernetes/pkg/util/flushwriter"
|
||||
"k8s.io/kubernetes/pkg/util/httpstream"
|
||||
"k8s.io/kubernetes/pkg/util/httpstream/spdy"
|
||||
@ -458,7 +459,7 @@ func getContainerCoordinates(request *restful.Request) (namespace, pod string, u
|
||||
return
|
||||
}
|
||||
|
||||
const streamCreationTimeout = 30 * time.Second
|
||||
const defaultStreamCreationTimeout = 30 * time.Second
|
||||
|
||||
func (s *Server) getAttach(request *restful.Request, response *restful.Response) {
|
||||
podNamespace, podID, uid, container := getContainerCoordinates(request)
|
||||
@ -564,7 +565,7 @@ func (s *Server) createStreams(request *restful.Request, response *restful.Respo
|
||||
conn.SetIdleTimeout(s.host.StreamingConnectionIdleTimeout())
|
||||
|
||||
// TODO make it configurable?
|
||||
expired := time.NewTimer(streamCreationTimeout)
|
||||
expired := time.NewTimer(defaultStreamCreationTimeout)
|
||||
|
||||
var errorStream, stdinStream, stdoutStream, stderrStream httpstream.Stream
|
||||
receivedStreams := 0
|
||||
@ -612,6 +613,15 @@ func getPodCoordinates(request *restful.Request) (namespace, pod string, uid typ
|
||||
return
|
||||
}
|
||||
|
||||
// PortForwarder knows how to forward content from a data stream to/from a port
|
||||
// in a pod.
|
||||
type PortForwarder interface {
|
||||
// PortForwarder copies data between a data stream and a port in a pod.
|
||||
PortForward(name string, uid types.UID, port uint16, stream io.ReadWriteCloser) error
|
||||
}
|
||||
|
||||
// getPortForward handles a new restful port forward request. It determines the
|
||||
// pod name and uid and then calls ServePortForward.
|
||||
func (s *Server) getPortForward(request *restful.Request, response *restful.Response) {
|
||||
podNamespace, podID, uid := getPodCoordinates(request)
|
||||
pod, ok := s.host.GetPodByName(podNamespace, podID)
|
||||
@ -620,80 +630,280 @@ func (s *Server) getPortForward(request *restful.Request, response *restful.Resp
|
||||
return
|
||||
}
|
||||
|
||||
podName := kubecontainer.GetPodFullName(pod)
|
||||
|
||||
ServePortForward(response.ResponseWriter, request.Request, s.host, podName, uid, s.host.StreamingConnectionIdleTimeout(), defaultStreamCreationTimeout)
|
||||
}
|
||||
|
||||
// ServePortForward handles a port forwarding request. A single request is
|
||||
// kept alive as long as the client is still alive and the connection has not
|
||||
// been timed out due to idleness. This function handles multiple forwarded
|
||||
// connections; i.e., multiple `curl http://localhost:8888/` requests will be
|
||||
// handled by a single invocation of ServePortForward.
|
||||
func ServePortForward(w http.ResponseWriter, req *http.Request, portForwarder PortForwarder, podName string, uid types.UID, idleTimeout time.Duration, streamCreationTimeout time.Duration) {
|
||||
streamChan := make(chan httpstream.Stream, 1)
|
||||
|
||||
glog.V(5).Infof("Upgrading port forward response")
|
||||
upgrader := spdy.NewResponseUpgrader()
|
||||
conn := upgrader.UpgradeResponse(response.ResponseWriter, request.Request, func(stream httpstream.Stream) error {
|
||||
portString := stream.Headers().Get(api.PortHeader)
|
||||
port, err := strconv.ParseUint(portString, 10, 16)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Unable to parse '%s' as a port: %v", portString, err)
|
||||
}
|
||||
if port < 1 {
|
||||
return fmt.Errorf("Port '%d' must be greater than 0", port)
|
||||
}
|
||||
streamChan <- stream
|
||||
return nil
|
||||
})
|
||||
conn := upgrader.UpgradeResponse(w, req, portForwardStreamReceived(streamChan))
|
||||
if conn == nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
conn.SetIdleTimeout(s.host.StreamingConnectionIdleTimeout())
|
||||
|
||||
var dataStreamLock sync.Mutex
|
||||
dataStreamChans := make(map[string]chan httpstream.Stream)
|
||||
glog.V(5).Infof("(conn=%p) setting port forwarding streaming connection idle timeout to %v", conn, idleTimeout)
|
||||
conn.SetIdleTimeout(idleTimeout)
|
||||
|
||||
h := &portForwardStreamHandler{
|
||||
conn: conn,
|
||||
streamChan: streamChan,
|
||||
streamPairs: make(map[string]*portForwardStreamPair),
|
||||
streamCreationTimeout: streamCreationTimeout,
|
||||
pod: podName,
|
||||
uid: uid,
|
||||
forwarder: portForwarder,
|
||||
}
|
||||
h.run()
|
||||
}
|
||||
|
||||
// portForwardStreamReceived is the httpstream.NewStreamHandler for port
|
||||
// forward streams. It checks each stream's port and stream type headers,
|
||||
// rejecting any streams that with missing or invalid values. Each valid
|
||||
// stream is sent to the streams channel.
|
||||
func portForwardStreamReceived(streams chan httpstream.Stream) func(httpstream.Stream) error {
|
||||
return func(stream httpstream.Stream) error {
|
||||
// make sure it has a valid port header
|
||||
portString := stream.Headers().Get(api.PortHeader)
|
||||
if len(portString) == 0 {
|
||||
return fmt.Errorf("%q header is required", api.PortHeader)
|
||||
}
|
||||
port, err := strconv.ParseUint(portString, 10, 16)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to parse %q as a port: %v", portString, err)
|
||||
}
|
||||
if port < 1 {
|
||||
return fmt.Errorf("port %q must be > 0", portString)
|
||||
}
|
||||
|
||||
// make sure it has a valid stream type header
|
||||
streamType := stream.Headers().Get(api.StreamType)
|
||||
if len(streamType) == 0 {
|
||||
return fmt.Errorf("%q header is required", api.StreamType)
|
||||
}
|
||||
if streamType != api.StreamTypeError && streamType != api.StreamTypeData {
|
||||
return fmt.Errorf("invalid stream type %q", streamType)
|
||||
}
|
||||
|
||||
streams <- stream
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// portForwardStreamHandler is capable of processing multiple port forward
|
||||
// requests over a single httpstream.Connection.
|
||||
type portForwardStreamHandler struct {
|
||||
conn httpstream.Connection
|
||||
streamChan chan httpstream.Stream
|
||||
streamPairsLock sync.RWMutex
|
||||
streamPairs map[string]*portForwardStreamPair
|
||||
streamCreationTimeout time.Duration
|
||||
pod string
|
||||
uid types.UID
|
||||
forwarder PortForwarder
|
||||
}
|
||||
|
||||
// getStreamPair returns a portForwardStreamPair for requestID. This creates a
|
||||
// new pair if one does not yet exist for the requestID. The returned bool is
|
||||
// true if the pair was created.
|
||||
func (h *portForwardStreamHandler) getStreamPair(requestID string) (*portForwardStreamPair, bool) {
|
||||
h.streamPairsLock.Lock()
|
||||
defer h.streamPairsLock.Unlock()
|
||||
|
||||
if p, ok := h.streamPairs[requestID]; ok {
|
||||
glog.V(5).Infof("(conn=%p, request=%s) found existing stream pair", h.conn, requestID)
|
||||
return p, false
|
||||
}
|
||||
|
||||
glog.V(5).Infof("(conn=%p, request=%s) creating new stream pair", h.conn, requestID)
|
||||
|
||||
p := newPortForwardPair(requestID)
|
||||
h.streamPairs[requestID] = p
|
||||
|
||||
return p, true
|
||||
}
|
||||
|
||||
// monitorStreamPair waits for the pair to receive both its error and data
|
||||
// streams, or for the timeout to expire (whichever happens first), and then
|
||||
// removes the pair.
|
||||
func (h *portForwardStreamHandler) monitorStreamPair(p *portForwardStreamPair, timeout <-chan time.Time) {
|
||||
select {
|
||||
case <-timeout:
|
||||
err := fmt.Errorf("(conn=%p, request=%s) timed out waiting for streams", h.conn, p.requestID)
|
||||
util.HandleError(err)
|
||||
p.printError(err.Error())
|
||||
case <-p.complete:
|
||||
glog.V(5).Infof("(conn=%p, request=%s) successfully received error and data streams", h.conn, p.requestID)
|
||||
}
|
||||
h.removeStreamPair(p.requestID)
|
||||
}
|
||||
|
||||
// hasStreamPair returns a bool indicating if a stream pair for requestID
|
||||
// exists.
|
||||
func (h *portForwardStreamHandler) hasStreamPair(requestID string) bool {
|
||||
h.streamPairsLock.RLock()
|
||||
defer h.streamPairsLock.RUnlock()
|
||||
|
||||
_, ok := h.streamPairs[requestID]
|
||||
return ok
|
||||
}
|
||||
|
||||
// removeStreamPair removes the stream pair identified by requestID from streamPairs.
|
||||
func (h *portForwardStreamHandler) removeStreamPair(requestID string) {
|
||||
h.streamPairsLock.Lock()
|
||||
defer h.streamPairsLock.Unlock()
|
||||
|
||||
delete(h.streamPairs, requestID)
|
||||
}
|
||||
|
||||
// requestID returns the request id for stream.
|
||||
func (h *portForwardStreamHandler) requestID(stream httpstream.Stream) string {
|
||||
requestID := stream.Headers().Get(api.PortForwardRequestIDHeader)
|
||||
if len(requestID) == 0 {
|
||||
glog.V(5).Infof("(conn=%p) stream received without %s header", h.conn, api.PortForwardRequestIDHeader)
|
||||
// If we get here, it's because the connection came from an older client
|
||||
// that isn't generating the request id header
|
||||
// (https://github.com/kubernetes/kubernetes/blob/843134885e7e0b360eb5441e85b1410a8b1a7a0c/pkg/client/unversioned/portforward/portforward.go#L258-L287)
|
||||
//
|
||||
// This is a best-effort attempt at supporting older clients.
|
||||
//
|
||||
// When there aren't concurrent new forwarded connections, each connection
|
||||
// will have a pair of streams (data, error), and the stream IDs will be
|
||||
// consecutive odd numbers, e.g. 1 and 3 for the first connection. Convert
|
||||
// the stream ID into a pseudo-request id by taking the stream type and
|
||||
// using id = stream.Identifier() when the stream type is error,
|
||||
// and id = stream.Identifier() - 2 when it's data.
|
||||
//
|
||||
// NOTE: this only works when there are not concurrent new streams from
|
||||
// multiple forwarded connections; it's a best-effort attempt at supporting
|
||||
// old clients that don't generate request ids. If there are concurrent
|
||||
// new connections, it's possible that 1 connection gets streams whose IDs
|
||||
// are not consecutive (e.g. 5 and 9 instead of 5 and 7).
|
||||
streamType := stream.Headers().Get(api.StreamType)
|
||||
switch streamType {
|
||||
case api.StreamTypeError:
|
||||
requestID = strconv.Itoa(int(stream.Identifier()))
|
||||
case api.StreamTypeData:
|
||||
requestID = strconv.Itoa(int(stream.Identifier()) - 2)
|
||||
}
|
||||
|
||||
glog.V(5).Infof("(conn=%p) automatically assigning request ID=%q from stream type=%s, stream ID=%d", h.conn, requestID, streamType, stream.Identifier())
|
||||
}
|
||||
return requestID
|
||||
}
|
||||
|
||||
// run is the main loop for the portForwardStreamHandler. It processes new
|
||||
// streams, invoking portForward for each complete stream pair. The loop exits
|
||||
// when the httpstream.Connection is closed.
|
||||
func (h *portForwardStreamHandler) run() {
|
||||
glog.V(5).Infof("(conn=%p) waiting for port forward streams", h.conn)
|
||||
Loop:
|
||||
for {
|
||||
select {
|
||||
case <-conn.CloseChan():
|
||||
case <-h.conn.CloseChan():
|
||||
glog.V(5).Infof("(conn=%p) upgraded connection closed", h.conn)
|
||||
break Loop
|
||||
case stream := <-streamChan:
|
||||
case stream := <-h.streamChan:
|
||||
requestID := h.requestID(stream)
|
||||
streamType := stream.Headers().Get(api.StreamType)
|
||||
port := stream.Headers().Get(api.PortHeader)
|
||||
dataStreamLock.Lock()
|
||||
switch streamType {
|
||||
case "error":
|
||||
ch := make(chan httpstream.Stream)
|
||||
dataStreamChans[port] = ch
|
||||
go waitForPortForwardDataStreamAndRun(kubecontainer.GetPodFullName(pod), uid, stream, ch, s.host)
|
||||
case "data":
|
||||
ch, ok := dataStreamChans[port]
|
||||
if ok {
|
||||
ch <- stream
|
||||
delete(dataStreamChans, port)
|
||||
} else {
|
||||
glog.Errorf("Unable to locate data stream channel for port %s", port)
|
||||
}
|
||||
default:
|
||||
glog.Errorf("streamType header must be 'error' or 'data', got: '%s'", streamType)
|
||||
stream.Reset()
|
||||
glog.V(5).Infof("(conn=%p, request=%s) received new stream of type %s", h.conn, requestID, streamType)
|
||||
|
||||
p, created := h.getStreamPair(requestID)
|
||||
if created {
|
||||
go h.monitorStreamPair(p, time.After(h.streamCreationTimeout))
|
||||
}
|
||||
if complete, err := p.add(stream); err != nil {
|
||||
msg := fmt.Sprintf("error processing stream for request %s: %v", requestID, err)
|
||||
util.HandleError(errors.New(msg))
|
||||
p.printError(msg)
|
||||
} else if complete {
|
||||
go h.portForward(p)
|
||||
}
|
||||
dataStreamLock.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func waitForPortForwardDataStreamAndRun(pod string, uid types.UID, errorStream httpstream.Stream, dataStreamChan chan httpstream.Stream, host HostInterface) {
|
||||
defer errorStream.Reset()
|
||||
// portForward invokes the portForwardStreamHandler's forwarder.PortForward
|
||||
// function for the given stream pair.
|
||||
func (h *portForwardStreamHandler) portForward(p *portForwardStreamPair) {
|
||||
defer p.dataStream.Close()
|
||||
defer p.errorStream.Close()
|
||||
|
||||
var dataStream httpstream.Stream
|
||||
portString := p.dataStream.Headers().Get(api.PortHeader)
|
||||
port, _ := strconv.ParseUint(portString, 10, 16)
|
||||
|
||||
select {
|
||||
case dataStream = <-dataStreamChan:
|
||||
case <-time.After(streamCreationTimeout):
|
||||
errorStream.Write([]byte("Timed out waiting for data stream"))
|
||||
//TODO delete from dataStreamChans[port]
|
||||
return
|
||||
glog.V(5).Infof("(conn=%p, request=%s) invoking forwarder.PortForward for port %s", h.conn, p.requestID, portString)
|
||||
err := h.forwarder.PortForward(h.pod, h.uid, uint16(port), p.dataStream)
|
||||
glog.V(5).Infof("(conn=%p, request=%s) done invoking forwarder.PortForward for port %s", h.conn, p.requestID, portString)
|
||||
|
||||
if err != nil {
|
||||
msg := fmt.Errorf("error forwarding port %d to pod %s, uid %v: %v", port, h.pod, h.uid, err)
|
||||
util.HandleError(msg)
|
||||
fmt.Fprint(p.errorStream, msg.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// portForwardStreamPair represents the error and data streams for a port
|
||||
// forwarding request.
|
||||
type portForwardStreamPair struct {
|
||||
lock sync.RWMutex
|
||||
requestID string
|
||||
dataStream httpstream.Stream
|
||||
errorStream httpstream.Stream
|
||||
complete chan struct{}
|
||||
}
|
||||
|
||||
// newPortForwardPair creates a new portForwardStreamPair.
|
||||
func newPortForwardPair(requestID string) *portForwardStreamPair {
|
||||
return &portForwardStreamPair{
|
||||
requestID: requestID,
|
||||
complete: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// add adds the stream to the portForwardStreamPair. If the pair already
|
||||
// contains a stream for the new stream's type, an error is returned. add
|
||||
// returns true if both the data and error streams for this pair have been
|
||||
// received.
|
||||
func (p *portForwardStreamPair) add(stream httpstream.Stream) (bool, error) {
|
||||
p.lock.Lock()
|
||||
defer p.lock.Unlock()
|
||||
|
||||
switch stream.Headers().Get(api.StreamType) {
|
||||
case api.StreamTypeError:
|
||||
if p.errorStream != nil {
|
||||
return false, errors.New("error stream already assigned")
|
||||
}
|
||||
p.errorStream = stream
|
||||
case api.StreamTypeData:
|
||||
if p.dataStream != nil {
|
||||
return false, errors.New("data stream already assigned")
|
||||
}
|
||||
p.dataStream = stream
|
||||
}
|
||||
|
||||
portString := dataStream.Headers().Get(api.PortHeader)
|
||||
port, _ := strconv.ParseUint(portString, 10, 16)
|
||||
err := host.PortForward(pod, uid, uint16(port), dataStream)
|
||||
if err != nil {
|
||||
msg := fmt.Errorf("Error forwarding port %d to pod %s, uid %v: %v", port, pod, uid, err)
|
||||
glog.Error(msg)
|
||||
errorStream.Write([]byte(msg.Error()))
|
||||
complete := p.errorStream != nil && p.dataStream != nil
|
||||
if complete {
|
||||
close(p.complete)
|
||||
}
|
||||
return complete, nil
|
||||
}
|
||||
|
||||
// printError writes s to p.errorStream if p.errorStream has been set.
|
||||
func (p *portForwardStreamPair) printError(s string) {
|
||||
p.lock.RLock()
|
||||
defer p.lock.RUnlock()
|
||||
if p.errorStream != nil {
|
||||
fmt.Fprint(p.errorStream, s)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1426,3 +1426,221 @@ func TestServePortForward(t *testing.T) {
|
||||
<-portForwardFuncDone
|
||||
}
|
||||
}
|
||||
|
||||
type fakeHttpStream struct {
|
||||
headers http.Header
|
||||
id uint32
|
||||
}
|
||||
|
||||
func newFakeHttpStream() *fakeHttpStream {
|
||||
return &fakeHttpStream{
|
||||
headers: make(http.Header),
|
||||
}
|
||||
}
|
||||
|
||||
var _ httpstream.Stream = &fakeHttpStream{}
|
||||
|
||||
func (s *fakeHttpStream) Read(data []byte) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *fakeHttpStream) Write(data []byte) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *fakeHttpStream) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *fakeHttpStream) Reset() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *fakeHttpStream) Headers() http.Header {
|
||||
return s.headers
|
||||
}
|
||||
|
||||
func (s *fakeHttpStream) Identifier() uint32 {
|
||||
return s.id
|
||||
}
|
||||
|
||||
func TestPortForwardStreamReceived(t *testing.T) {
|
||||
tests := map[string]struct {
|
||||
port string
|
||||
streamType string
|
||||
expectedError string
|
||||
}{
|
||||
"missing port": {
|
||||
expectedError: `"port" header is required`,
|
||||
},
|
||||
"unable to parse port": {
|
||||
port: "abc",
|
||||
expectedError: `unable to parse "abc" as a port: strconv.ParseUint: parsing "abc": invalid syntax`,
|
||||
},
|
||||
"negative port": {
|
||||
port: "-1",
|
||||
expectedError: `unable to parse "-1" as a port: strconv.ParseUint: parsing "-1": invalid syntax`,
|
||||
},
|
||||
"missing stream type": {
|
||||
port: "80",
|
||||
expectedError: `"streamType" header is required`,
|
||||
},
|
||||
"valid port with error stream": {
|
||||
port: "80",
|
||||
streamType: "error",
|
||||
},
|
||||
"valid port with data stream": {
|
||||
port: "80",
|
||||
streamType: "data",
|
||||
},
|
||||
"invalid stream type": {
|
||||
port: "80",
|
||||
streamType: "foo",
|
||||
expectedError: `invalid stream type "foo"`,
|
||||
},
|
||||
}
|
||||
for name, test := range tests {
|
||||
streams := make(chan httpstream.Stream, 1)
|
||||
f := portForwardStreamReceived(streams)
|
||||
stream := newFakeHttpStream()
|
||||
if len(test.port) > 0 {
|
||||
stream.headers.Set("port", test.port)
|
||||
}
|
||||
if len(test.streamType) > 0 {
|
||||
stream.headers.Set("streamType", test.streamType)
|
||||
}
|
||||
err := f(stream)
|
||||
if len(test.expectedError) > 0 {
|
||||
if err == nil {
|
||||
t.Errorf("%s: expected err=%q, but it was nil", name, test.expectedError)
|
||||
}
|
||||
if e, a := test.expectedError, err.Error(); e != a {
|
||||
t.Errorf("%s: expected err=%q, got %q", name, e, a)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("%s: unexpected error %v", name, err)
|
||||
continue
|
||||
}
|
||||
if s := <-streams; s != stream {
|
||||
t.Errorf("%s: expected stream %#v, got %#v", name, stream, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetStreamPair(t *testing.T) {
|
||||
timeout := make(chan time.Time)
|
||||
|
||||
h := &portForwardStreamHandler{
|
||||
streamPairs: make(map[string]*portForwardStreamPair),
|
||||
}
|
||||
|
||||
// test adding a new entry
|
||||
p, created := h.getStreamPair("1")
|
||||
if p == nil {
|
||||
t.Fatalf("unexpected nil pair")
|
||||
}
|
||||
if !created {
|
||||
t.Fatal("expected created=true")
|
||||
}
|
||||
if p.dataStream != nil {
|
||||
t.Errorf("unexpected non-nil data stream")
|
||||
}
|
||||
if p.errorStream != nil {
|
||||
t.Errorf("unexpected non-nil error stream")
|
||||
}
|
||||
|
||||
// start the monitor for this pair
|
||||
monitorDone := make(chan struct{})
|
||||
go func() {
|
||||
h.monitorStreamPair(p, timeout)
|
||||
close(monitorDone)
|
||||
}()
|
||||
|
||||
if !h.hasStreamPair("1") {
|
||||
t.Fatal("This should still be true")
|
||||
}
|
||||
|
||||
// make sure we can retrieve an existing entry
|
||||
p2, created := h.getStreamPair("1")
|
||||
if created {
|
||||
t.Fatal("expected created=false")
|
||||
}
|
||||
if p != p2 {
|
||||
t.Fatalf("retrieving an existing pair: expected %#v, got %#v", p, p2)
|
||||
}
|
||||
|
||||
// removed via complete
|
||||
dataStream := newFakeHttpStream()
|
||||
dataStream.headers.Set(api.StreamType, api.StreamTypeData)
|
||||
complete, err := p.add(dataStream)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error adding data stream to pair: %v", err)
|
||||
}
|
||||
if complete {
|
||||
t.Fatalf("unexpected complete")
|
||||
}
|
||||
|
||||
errorStream := newFakeHttpStream()
|
||||
errorStream.headers.Set(api.StreamType, api.StreamTypeError)
|
||||
complete, err = p.add(errorStream)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error adding error stream to pair: %v", err)
|
||||
}
|
||||
if !complete {
|
||||
t.Fatal("unexpected incomplete")
|
||||
}
|
||||
|
||||
// make sure monitorStreamPair completed
|
||||
<-monitorDone
|
||||
|
||||
// make sure the pair was removed
|
||||
if h.hasStreamPair("1") {
|
||||
t.Fatal("expected removal of pair after both data and error streams received")
|
||||
}
|
||||
|
||||
// removed via timeout
|
||||
p, created = h.getStreamPair("2")
|
||||
if !created {
|
||||
t.Fatal("expected created=true")
|
||||
}
|
||||
if p == nil {
|
||||
t.Fatal("expected p not to be nil")
|
||||
}
|
||||
monitorDone = make(chan struct{})
|
||||
go func() {
|
||||
h.monitorStreamPair(p, timeout)
|
||||
close(monitorDone)
|
||||
}()
|
||||
// cause the timeout
|
||||
close(timeout)
|
||||
// make sure monitorStreamPair completed
|
||||
<-monitorDone
|
||||
if h.hasStreamPair("2") {
|
||||
t.Fatal("expected stream pair to be removed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestID(t *testing.T) {
|
||||
h := &portForwardStreamHandler{}
|
||||
|
||||
s := newFakeHttpStream()
|
||||
s.headers.Set(api.StreamType, api.StreamTypeError)
|
||||
s.id = 1
|
||||
if e, a := "1", h.requestID(s); e != a {
|
||||
t.Errorf("expected %q, got %q", e, a)
|
||||
}
|
||||
|
||||
s.headers.Set(api.StreamType, api.StreamTypeData)
|
||||
s.id = 3
|
||||
if e, a := "1", h.requestID(s); e != a {
|
||||
t.Errorf("expected %q, got %q", e, a)
|
||||
}
|
||||
|
||||
s.id = 7
|
||||
s.headers.Set(api.PortForwardRequestIDHeader, "2")
|
||||
if e, a := "2", h.requestID(s); e != a {
|
||||
t.Errorf("expected %q, got %q", e, a)
|
||||
}
|
||||
}
|
||||
|
@ -78,6 +78,8 @@ type Stream interface {
|
||||
Reset() error
|
||||
// Headers returns the headers used to create the stream.
|
||||
Headers() http.Header
|
||||
// Identifier returns the stream's ID.
|
||||
Identifier() uint32
|
||||
}
|
||||
|
||||
// IsUpgradeRequest returns true if the given request is a connection upgrade request
|
||||
|
@ -60,10 +60,7 @@ const (
|
||||
simplePodPort = 80
|
||||
)
|
||||
|
||||
var (
|
||||
portForwardRegexp = regexp.MustCompile("Forwarding from 127.0.0.1:([0-9]+) -> 80")
|
||||
proxyRegexp = regexp.MustCompile("Starting to serve on 127.0.0.1:([0-9]+)")
|
||||
)
|
||||
var proxyRegexp = regexp.MustCompile("Starting to serve on 127.0.0.1:([0-9]+)")
|
||||
|
||||
var _ = Describe("Kubectl client", func() {
|
||||
defer GinkgoRecover()
|
||||
@ -200,32 +197,11 @@ var _ = Describe("Kubectl client", func() {
|
||||
|
||||
It("should support port-forward", func() {
|
||||
By("forwarding the container port to a local port")
|
||||
cmd := kubectlCmd("port-forward", fmt.Sprintf("--namespace=%v", ns), simplePodName, fmt.Sprintf(":%d", simplePodPort))
|
||||
cmd, listenPort := runPortForward(ns, simplePodName, simplePodPort)
|
||||
defer tryKill(cmd)
|
||||
// This is somewhat ugly but is the only way to retrieve the port that was picked
|
||||
// by the port-forward command. We don't want to hard code the port as we have no
|
||||
// way of guaranteeing we can pick one that isn't in use, particularly on Jenkins.
|
||||
Logf("starting port-forward command and streaming output")
|
||||
stdout, stderr, err := startCmdAndStreamOutput(cmd)
|
||||
if err != nil {
|
||||
Failf("Failed to start port-forward command: %v", err)
|
||||
}
|
||||
defer stdout.Close()
|
||||
defer stderr.Close()
|
||||
|
||||
buf := make([]byte, 128)
|
||||
var n int
|
||||
Logf("reading from `kubectl port-forward` command's stderr")
|
||||
if n, err = stderr.Read(buf); err != nil {
|
||||
Failf("Failed to read from kubectl port-forward stderr: %v", err)
|
||||
}
|
||||
portForwardOutput := string(buf[:n])
|
||||
match := portForwardRegexp.FindStringSubmatch(portForwardOutput)
|
||||
if len(match) != 2 {
|
||||
Failf("Failed to parse kubectl port-forward output: %s", portForwardOutput)
|
||||
}
|
||||
By("curling local port output")
|
||||
localAddr := fmt.Sprintf("http://localhost:%s", match[1])
|
||||
localAddr := fmt.Sprintf("http://localhost:%d", listenPort)
|
||||
body, err := curl(localAddr)
|
||||
Logf("got: %s", body)
|
||||
if err != nil {
|
||||
|
234
test/e2e/portforward.go
Normal file
234
test/e2e/portforward.go
Normal file
@ -0,0 +1,234 @@
|
||||
/*
|
||||
Copyright 2015 The Kubernetes Authors All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package e2e
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"os/exec"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"k8s.io/kubernetes/pkg/api"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
)
|
||||
|
||||
const (
|
||||
podName = "pfpod"
|
||||
)
|
||||
|
||||
// TODO support other ports besides 80
|
||||
var portForwardRegexp = regexp.MustCompile("Forwarding from 127.0.0.1:([0-9]+) -> 80")
|
||||
|
||||
func pfPod(expectedClientData, chunks, chunkSize, chunkIntervalMillis string) *api.Pod {
|
||||
return &api.Pod{
|
||||
ObjectMeta: api.ObjectMeta{
|
||||
Name: podName,
|
||||
Labels: map[string]string{"name": podName},
|
||||
},
|
||||
Spec: api.PodSpec{
|
||||
Containers: []api.Container{
|
||||
{
|
||||
Name: "portforwardtester",
|
||||
Image: "gcr.io/google_containers/portforwardtester:1.0",
|
||||
Env: []api.EnvVar{
|
||||
{
|
||||
Name: "BIND_PORT",
|
||||
Value: "80",
|
||||
},
|
||||
{
|
||||
Name: "EXPECTED_CLIENT_DATA",
|
||||
Value: expectedClientData,
|
||||
},
|
||||
{
|
||||
Name: "CHUNKS",
|
||||
Value: chunks,
|
||||
},
|
||||
{
|
||||
Name: "CHUNK_SIZE",
|
||||
Value: chunkSize,
|
||||
},
|
||||
{
|
||||
Name: "CHUNK_INTERVAL",
|
||||
Value: chunkIntervalMillis,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
RestartPolicy: api.RestartPolicyNever,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func runPortForward(ns, podName string, port int) (*exec.Cmd, int) {
|
||||
cmd := kubectlCmd("port-forward", fmt.Sprintf("--namespace=%v", ns), podName, fmt.Sprintf(":%d", port))
|
||||
// This is somewhat ugly but is the only way to retrieve the port that was picked
|
||||
// by the port-forward command. We don't want to hard code the port as we have no
|
||||
// way of guaranteeing we can pick one that isn't in use, particularly on Jenkins.
|
||||
Logf("starting port-forward command and streaming output")
|
||||
stdout, stderr, err := startCmdAndStreamOutput(cmd)
|
||||
if err != nil {
|
||||
Failf("Failed to start port-forward command: %v", err)
|
||||
}
|
||||
defer stdout.Close()
|
||||
defer stderr.Close()
|
||||
|
||||
buf := make([]byte, 128)
|
||||
var n int
|
||||
Logf("reading from `kubectl port-forward` command's stderr")
|
||||
if n, err = stderr.Read(buf); err != nil {
|
||||
Failf("Failed to read from kubectl port-forward stderr: %v", err)
|
||||
}
|
||||
portForwardOutput := string(buf[:n])
|
||||
match := portForwardRegexp.FindStringSubmatch(portForwardOutput)
|
||||
if len(match) != 2 {
|
||||
Failf("Failed to parse kubectl port-forward output: %s", portForwardOutput)
|
||||
}
|
||||
|
||||
listenPort, err := strconv.Atoi(match[1])
|
||||
if err != nil {
|
||||
Failf("Error converting %s to an int: %v", match[1], err)
|
||||
}
|
||||
|
||||
return cmd, listenPort
|
||||
}
|
||||
|
||||
var _ = Describe("Port forwarding", func() {
	framework := NewFramework("port-forwarding")

	Describe("With a server that expects a client request", func() {
		It("should support a client that connects, sends no data, and disconnects", func() {
			By("creating the target pod")
			pod := pfPod("abc", "1", "1", "1")
			framework.Client.Pods(framework.Namespace.Name).Create(pod)
			framework.WaitForPodRunning(pod.Name)

			By("Running 'kubectl port-forward'")
			cmd, listenPort := runPortForward(framework.Namespace.Name, pod.Name, 80)
			defer tryKill(cmd)

			By("Dialing the local port")
			conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", listenPort))
			if err != nil {
				Failf("Couldn't connect to port %d: %v", listenPort, err)
			}

			By("Closing the connection to the local port")
			conn.Close()

			logOutput := runKubectl("logs", fmt.Sprintf("--namespace=%v", framework.Namespace.Name), "-f", podName)
			verifyLogMessage(logOutput, "Accepted client connection")
			verifyLogMessage(logOutput, "Expected to read 3 bytes from client, but got 0 instead")
		})

		It("should support a client that connects, sends data, and disconnects", func() {
			By("creating the target pod")
			pod := pfPod("abc", "10", "10", "100")
			framework.Client.Pods(framework.Namespace.Name).Create(pod)
			framework.WaitForPodRunning(pod.Name)

			By("Running 'kubectl port-forward'")
			cmd, listenPort := runPortForward(framework.Namespace.Name, pod.Name, 80)
			defer tryKill(cmd)

			By("Dialing the local port")
			addr, err := net.ResolveTCPAddr("tcp", fmt.Sprintf("127.0.0.1:%d", listenPort))
			if err != nil {
				Failf("Error resolving tcp addr: %v", err)
			}
			conn, err := net.DialTCP("tcp", nil, addr)
			if err != nil {
				Failf("Couldn't connect to port %d: %v", listenPort, err)
			}
			defer func() {
				By("Closing the connection to the local port")
				conn.Close()
			}()

			By("Sending the expected data to the local port")
			fmt.Fprint(conn, "abc")

			By("Closing the write half of the client's connection")
			conn.CloseWrite()

			By("Reading data from the local port")
			fromServer, err := ioutil.ReadAll(conn)
			if err != nil {
				Failf("Unexpected error reading data from the server: %v", err)
			}

			if e, a := strings.Repeat("x", 100), string(fromServer); e != a {
				Failf("Expected %q from server, got %q", e, a)
			}

			logOutput := runKubectl("logs", fmt.Sprintf("--namespace=%v", framework.Namespace.Name), "-f", podName)
			verifyLogMessage(logOutput, "^Accepted client connection$")
			verifyLogMessage(logOutput, "^Received expected client data$")
			verifyLogMessage(logOutput, "^Done$")
		})
	})
	Describe("With a server that expects no client request", func() {
		It("should support a client that connects, sends no data, and disconnects", func() {
			By("creating the target pod")
			pod := pfPod("", "10", "10", "100")
			framework.Client.Pods(framework.Namespace.Name).Create(pod)
			framework.WaitForPodRunning(pod.Name)

			By("Running 'kubectl port-forward'")
			cmd, listenPort := runPortForward(framework.Namespace.Name, pod.Name, 80)
			defer tryKill(cmd)

			By("Dialing the local port")
			conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", listenPort))
			if err != nil {
				Failf("Couldn't connect to port %d: %v", listenPort, err)
			}
			defer func() {
				By("Closing the connection to the local port")
				conn.Close()
			}()

			By("Reading data from the local port")
			fromServer, err := ioutil.ReadAll(conn)
			if err != nil {
				Failf("Unexpected error reading data from the server: %v", err)
			}

			if e, a := strings.Repeat("x", 100), string(fromServer); e != a {
				Failf("Expected %q from server, got %q", e, a)
			}

			logOutput := runKubectl("logs", fmt.Sprintf("--namespace=%v", framework.Namespace.Name), "-f", podName)
			verifyLogMessage(logOutput, "Accepted client connection")
			verifyLogMessage(logOutput, "Done")
		})
	})
})

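The data-sending test above relies on TCP half-close semantics: after CloseWrite the client can no longer send, but it can still read everything the pod streams back until EOF. A self-contained sketch of that pattern (the local echo-style server and the hard-coded payload are illustrative, not part of the e2e test):

package main

import (
	"fmt"
	"io/ioutil"
	"net"
)

func main() {
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		panic(err)
	}
	defer ln.Close()

	go func() {
		conn, err := ln.Accept()
		if err != nil {
			return
		}
		defer conn.Close()
		// Read everything the client sent; EOF arrives when the client
		// calls CloseWrite.
		in, _ := ioutil.ReadAll(conn)
		fmt.Fprintf(conn, "server saw %d bytes", len(in))
	}()

	addr, err := net.ResolveTCPAddr("tcp", ln.Addr().String())
	if err != nil {
		panic(err)
	}
	conn, err := net.DialTCP("tcp", nil, addr)
	if err != nil {
		panic(err)
	}
	defer conn.Close()

	fmt.Fprint(conn, "abc")
	conn.CloseWrite() // signal EOF to the server; reading is still possible

	reply, err := ioutil.ReadAll(conn)
	if err != nil {
		panic(err)
	}
	fmt.Println(string(reply)) // "server saw 3 bytes"
}
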
func verifyLogMessage(log, expected string) {
	re := regexp.MustCompile(expected)
	lines := strings.Split(log, "\n")
	for i := range lines {
		if re.MatchString(lines[i]) {
			return
		}
	}
	Failf("Missing %q from log: %s", expected, log)
}
1  test/images/port-forward-tester/.gitignore  vendored  Normal file
@ -0,0 +1 @@
portforwardtester
18  test/images/port-forward-tester/Dockerfile  Normal file
@ -0,0 +1,18 @@
# Copyright 2015 The Kubernetes Authors All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

FROM scratch
ADD portforwardtester portforwardtester
ADD portforwardtester.go portforwardtester.go
ENTRYPOINT ["/portforwardtester"]
15  test/images/port-forward-tester/Makefile  Normal file
@ -0,0 +1,15 @@
all: push

TAG = 1.0

portforwardtester: portforwardtester.go
	CGO_ENABLED=0 GOOS=linux go build -a -installsuffix cgo -ldflags '-w' ./portforwardtester.go

image: portforwardtester
	docker build -t gcr.io/google_containers/portforwardtester:$(TAG) .

push: image
	gcloud docker push gcr.io/google_containers/portforwardtester:$(TAG)

clean:
	rm -f portforwardtester
106  test/images/port-forward-tester/portforwardtester.go  Normal file
@ -0,0 +1,106 @@
/*
Copyright 2015 The Kubernetes Authors All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

// A tiny binary for testing port forwarding. The following environment variables
// control the binary's logic:
//
// BIND_PORT - the TCP port to use for the listener
// EXPECTED_CLIENT_DATA - data that we expect to receive from the client; may be "".
// CHUNKS - how many chunks of data we should send to the client
// CHUNK_SIZE - how large each chunk should be
// CHUNK_INTERVAL - the delay in between sending each chunk
//
// Log messages are written to stdout at various stages of the binary's execution.
// Test code can retrieve this container's log and validate that the expected
// behavior is taking place.
package main

import (
	"fmt"
	"net"
	"os"
	"strconv"
	"strings"
	"time"
)

func getEnvInt(name string) int {
	s := os.Getenv(name)
	value, err := strconv.Atoi(s)
	if err != nil {
		fmt.Printf("Error parsing %s %q: %v\n", name, s, err)
		os.Exit(1)
	}
	return value
}

func main() {
	bindPort := os.Getenv("BIND_PORT")
	listener, err := net.Listen("tcp", fmt.Sprintf("localhost:%s", bindPort))
	if err != nil {
		fmt.Printf("Error listening: %v\n", err)
		os.Exit(1)
	}

	conn, err := listener.Accept()
	if err != nil {
		fmt.Printf("Error accepting connection: %v\n", err)
		os.Exit(1)
	}
	defer conn.Close()
	fmt.Println("Accepted client connection")

	expectedClientData := os.Getenv("EXPECTED_CLIENT_DATA")
	if len(expectedClientData) > 0 {
		buf := make([]byte, len(expectedClientData))
		read, err := conn.Read(buf)
		if read != len(expectedClientData) {
			fmt.Printf("Expected to read %d bytes from client, but got %d instead. err=%v\n", len(expectedClientData), read, err)
			os.Exit(2)
		}
		if expectedClientData != string(buf) {
			fmt.Printf("Expected to read %q, but got %q. err=%v\n", expectedClientData, string(buf), err)
			os.Exit(3)
		}
		if err != nil {
			fmt.Printf("Read err: %v\n", err)
		}
		fmt.Println("Received expected client data")
	}

	chunks := getEnvInt("CHUNKS")
	chunkSize := getEnvInt("CHUNK_SIZE")
	chunkInterval := getEnvInt("CHUNK_INTERVAL")

	stringData := strings.Repeat("x", chunkSize)
	data := []byte(stringData)

	for i := 0; i < chunks; i++ {
		written, err := conn.Write(data)
		if written != chunkSize {
			fmt.Printf("Expected to write %d bytes to the client, but wrote %d instead. err=%v\n", chunkSize, written, err)
			os.Exit(4)
		}
		if err != nil {
			fmt.Printf("Write err: %v\n", err)
		}
		if i+1 < chunks {
			time.Sleep(time.Duration(chunkInterval) * time.Millisecond)
		}
	}

	fmt.Println("Done")
}
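Taken together, the image gives the tester a very small contract: listen on BIND_PORT, optionally check EXPECTED_CLIENT_DATA, then stream the configured chunks. As a rough illustration only, here is a sketch of exercising the binary locally without a container or kubectl; the binary path, port, and fixed start-up sleep are assumptions for the example, not part of the repository:

package main

import (
	"fmt"
	"io/ioutil"
	"net"
	"os"
	"os/exec"
	"time"
)

// Run the tester binary directly with the same env vars the e2e pod sets,
// then act as the client that the e2e test would be.
func main() {
	cmd := exec.Command("./portforwardtester") // assumed to be built locally
	cmd.Env = append(os.Environ(),
		"BIND_PORT=8080",
		"EXPECTED_CLIENT_DATA=abc",
		"CHUNKS=10",
		"CHUNK_SIZE=10",
		"CHUNK_INTERVAL=100",
	)
	cmd.Stdout = os.Stdout
	if err := cmd.Start(); err != nil {
		panic(err)
	}
	defer cmd.Process.Kill()

	time.Sleep(500 * time.Millisecond) // crude wait for the listener to come up

	conn, err := net.DialTCP("tcp", nil, &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8080})
	if err != nil {
		panic(err)
	}
	defer conn.Close()

	fmt.Fprint(conn, "abc")
	conn.CloseWrite()

	reply, _ := ioutil.ReadAll(conn)
	fmt.Printf("received %d bytes back\n", len(reply)) // expect 100
}
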