Update unit test.

This commit is contained in:
Lantao Liu 2018-05-18 16:08:44 -07:00
parent 174b6d0e2f
commit 1eb721248b
3 changed files with 571 additions and 646 deletions

View File

@ -17,7 +17,6 @@ limitations under the License.
package kubelet package kubelet
import ( import (
"bytes"
"errors" "errors"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
@ -2095,7 +2094,7 @@ func (f *fakeReadWriteCloser) Close() error {
return nil return nil
} }
func TestExec(t *testing.T) { func TestGetExec(t *testing.T) {
const ( const (
podName = "podFoo" podName = "podFoo"
podNamespace = "nsFoo" podNamespace = "nsFoo"
@ -2106,9 +2105,6 @@ func TestExec(t *testing.T) {
var ( var (
podFullName = kubecontainer.GetPodFullName(podWithUIDNameNs(podUID, podName, podNamespace)) podFullName = kubecontainer.GetPodFullName(podWithUIDNameNs(podUID, podName, podNamespace))
command = []string{"ls"} command = []string{"ls"}
stdin = &bytes.Buffer{}
stdout = &fakeReadWriteCloser{}
stderr = &fakeReadWriteCloser{}
) )
testcases := []struct { testcases := []struct {
@ -2161,22 +2157,16 @@ func TestExec(t *testing.T) {
assert.NoError(t, err, description) assert.NoError(t, err, description)
assert.Equal(t, containertest.FakeHost, redirect.Host, description+": redirect") assert.Equal(t, containertest.FakeHost, redirect.Host, description+": redirect")
} }
err = kubelet.ExecInContainer(tc.podFullName, podUID, tc.container, command, stdin, stdout, stderr, tty, nil, 0)
assert.Error(t, err, description)
} }
} }
func TestPortForward(t *testing.T) { func TestGetPortForward(t *testing.T) {
const ( const (
podName = "podFoo" podName = "podFoo"
podNamespace = "nsFoo" podNamespace = "nsFoo"
podUID types.UID = "12345678" podUID types.UID = "12345678"
port int32 = 5000 port int32 = 5000
) )
var (
stream = &fakeReadWriteCloser{}
)
testcases := []struct { testcases := []struct {
description string description string
@ -2208,7 +2198,6 @@ func TestPortForward(t *testing.T) {
}}, }},
} }
podFullName := kubecontainer.GetPodFullName(podWithUIDNameNs(podUID, tc.podName, podNamespace))
description := "streaming - " + tc.description description := "streaming - " + tc.description
fakeRuntime := &containertest.FakeStreamingRuntime{FakeRuntime: testKubelet.fakeRuntime} fakeRuntime := &containertest.FakeStreamingRuntime{FakeRuntime: testKubelet.fakeRuntime}
kubelet.containerRuntime = fakeRuntime kubelet.containerRuntime = fakeRuntime
@ -2221,9 +2210,6 @@ func TestPortForward(t *testing.T) {
assert.NoError(t, err, description) assert.NoError(t, err, description)
assert.Equal(t, containertest.FakeHost, redirect.Host, description+": redirect") assert.Equal(t, containertest.FakeHost, redirect.Host, description+": redirect")
} }
err = kubelet.PortForward(podFullName, podUID, port, stream)
assert.Error(t, err, description)
} }
} }

View File

@ -46,21 +46,25 @@ import (
"k8s.io/apiserver/pkg/authentication/user" "k8s.io/apiserver/pkg/authentication/user"
"k8s.io/apiserver/pkg/authorization/authorizer" "k8s.io/apiserver/pkg/authorization/authorizer"
"k8s.io/client-go/tools/remotecommand" "k8s.io/client-go/tools/remotecommand"
utiltesting "k8s.io/client-go/util/testing"
api "k8s.io/kubernetes/pkg/apis/core" api "k8s.io/kubernetes/pkg/apis/core"
runtimeapi "k8s.io/kubernetes/pkg/kubelet/apis/cri/runtime/v1alpha2"
statsapi "k8s.io/kubernetes/pkg/kubelet/apis/stats/v1alpha1" statsapi "k8s.io/kubernetes/pkg/kubelet/apis/stats/v1alpha1"
// Do some initialization to decode the query parameters correctly. // Do some initialization to decode the query parameters correctly.
_ "k8s.io/kubernetes/pkg/apis/core/install" _ "k8s.io/kubernetes/pkg/apis/core/install"
"k8s.io/kubernetes/pkg/kubelet/cm" "k8s.io/kubernetes/pkg/kubelet/cm"
kubecontainer "k8s.io/kubernetes/pkg/kubelet/container" kubecontainer "k8s.io/kubernetes/pkg/kubelet/container"
kubecontainertesting "k8s.io/kubernetes/pkg/kubelet/container/testing"
"k8s.io/kubernetes/pkg/kubelet/server/portforward" "k8s.io/kubernetes/pkg/kubelet/server/portforward"
remotecommandserver "k8s.io/kubernetes/pkg/kubelet/server/remotecommand" remotecommandserver "k8s.io/kubernetes/pkg/kubelet/server/remotecommand"
"k8s.io/kubernetes/pkg/kubelet/server/stats" "k8s.io/kubernetes/pkg/kubelet/server/stats"
"k8s.io/kubernetes/pkg/kubelet/server/streaming"
"k8s.io/kubernetes/pkg/volume" "k8s.io/kubernetes/pkg/volume"
) )
const ( const (
testUID = "9b01b80f-8fb4-11e4-95ab-4200af06647" testUID = "9b01b80f-8fb4-11e4-95ab-4200af06647"
testContainerID = "container789"
testPodSandboxID = "pod0987"
) )
type fakeKubelet struct { type fakeKubelet struct {
@ -72,16 +76,16 @@ type fakeKubelet struct {
runningPodsFunc func() ([]*v1.Pod, error) runningPodsFunc func() ([]*v1.Pod, error)
logFunc func(w http.ResponseWriter, req *http.Request) logFunc func(w http.ResponseWriter, req *http.Request)
runFunc func(podFullName string, uid types.UID, containerName string, cmd []string) ([]byte, error) runFunc func(podFullName string, uid types.UID, containerName string, cmd []string) ([]byte, error)
execFunc func(pod string, uid types.UID, container string, cmd []string, in io.Reader, out, err io.WriteCloser, tty bool) error getExecCheck func(string, types.UID, string, []string, remotecommandserver.Options)
attachFunc func(pod string, uid types.UID, container string, in io.Reader, out, err io.WriteCloser, tty bool) error getAttachCheck func(string, types.UID, string, remotecommandserver.Options)
portForwardFunc func(name string, uid types.UID, port int32, stream io.ReadWriteCloser) error getPortForwardCheck func(string, string, types.UID, portforward.V4Options)
containerLogsFunc func(podFullName, containerName string, logOptions *v1.PodLogOptions, stdout, stderr io.Writer) error containerLogsFunc func(podFullName, containerName string, logOptions *v1.PodLogOptions, stdout, stderr io.Writer) error
streamingConnectionIdleTimeoutFunc func() time.Duration
hostnameFunc func() string hostnameFunc func() string
resyncInterval time.Duration resyncInterval time.Duration
loopEntryTime time.Time loopEntryTime time.Time
plegHealth bool plegHealth bool
redirectURL *url.URL streamingRuntime streaming.Server
} }
func (fk *fakeKubelet) ResyncInterval() time.Duration { func (fk *fakeKubelet) ResyncInterval() time.Duration {
@ -136,32 +140,109 @@ func (fk *fakeKubelet) RunInContainer(podFullName string, uid types.UID, contain
return fk.runFunc(podFullName, uid, containerName, cmd) return fk.runFunc(podFullName, uid, containerName, cmd)
} }
func (fk *fakeKubelet) ExecInContainer(name string, uid types.UID, container string, cmd []string, in io.Reader, out, err io.WriteCloser, tty bool, resize <-chan remotecommand.TerminalSize, timeout time.Duration) error { type fakeRuntime struct {
return fk.execFunc(name, uid, container, cmd, in, out, err, tty) execFunc func(string, []string, io.Reader, io.WriteCloser, io.WriteCloser, bool, <-chan remotecommand.TerminalSize) error
attachFunc func(string, io.Reader, io.WriteCloser, io.WriteCloser, bool, <-chan remotecommand.TerminalSize) error
portForwardFunc func(string, int32, io.ReadWriteCloser) error
} }
func (fk *fakeKubelet) AttachContainer(name string, uid types.UID, container string, in io.Reader, out, err io.WriteCloser, tty bool, resize <-chan remotecommand.TerminalSize) error { func (f *fakeRuntime) Exec(containerID string, cmd []string, stdin io.Reader, stdout, stderr io.WriteCloser, tty bool, resize <-chan remotecommand.TerminalSize) error {
return fk.attachFunc(name, uid, container, in, out, err, tty) return f.execFunc(containerID, cmd, stdin, stdout, stderr, tty, resize)
} }
func (fk *fakeKubelet) PortForward(name string, uid types.UID, port int32, stream io.ReadWriteCloser) error { func (f *fakeRuntime) Attach(containerID string, stdin io.Reader, stdout, stderr io.WriteCloser, tty bool, resize <-chan remotecommand.TerminalSize) error {
return fk.portForwardFunc(name, uid, port, stream) return f.attachFunc(containerID, stdin, stdout, stderr, tty, resize)
}
func (f *fakeRuntime) PortForward(podSandboxID string, port int32, stream io.ReadWriteCloser) error {
return f.portForwardFunc(podSandboxID, port, stream)
}
type testStreamingServer struct {
streaming.Server
fakeRuntime *fakeRuntime
testHTTPServer *httptest.Server
}
func newTestStreamingServer(streamIdleTimeout time.Duration) (s *testStreamingServer, err error) {
s = &testStreamingServer{}
s.testHTTPServer = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
s.ServeHTTP(w, r)
}))
defer func() {
if err != nil {
s.testHTTPServer.Close()
}
}()
testURL, err := url.Parse(s.testHTTPServer.URL)
if err != nil {
return nil, err
}
s.fakeRuntime = &fakeRuntime{}
config := streaming.DefaultConfig
config.BaseURL = testURL
if streamIdleTimeout != 0 {
config.StreamIdleTimeout = streamIdleTimeout
}
s.Server, err = streaming.NewServer(config, s.fakeRuntime)
if err != nil {
return nil, err
}
return s, nil
} }
func (fk *fakeKubelet) GetExec(podFullName string, podUID types.UID, containerName string, cmd []string, streamOpts remotecommandserver.Options) (*url.URL, error) { func (fk *fakeKubelet) GetExec(podFullName string, podUID types.UID, containerName string, cmd []string, streamOpts remotecommandserver.Options) (*url.URL, error) {
return fk.redirectURL, nil if fk.getExecCheck != nil {
fk.getExecCheck(podFullName, podUID, containerName, cmd, streamOpts)
}
// Always use testContainerID
resp, err := fk.streamingRuntime.GetExec(&runtimeapi.ExecRequest{
ContainerId: testContainerID,
Cmd: cmd,
Tty: streamOpts.TTY,
Stdin: streamOpts.Stdin,
Stdout: streamOpts.Stdout,
Stderr: streamOpts.Stderr,
})
if err != nil {
return nil, err
}
return url.Parse(resp.GetUrl())
} }
func (fk *fakeKubelet) GetAttach(podFullName string, podUID types.UID, containerName string, streamOpts remotecommandserver.Options) (*url.URL, error) { func (fk *fakeKubelet) GetAttach(podFullName string, podUID types.UID, containerName string, streamOpts remotecommandserver.Options) (*url.URL, error) {
return fk.redirectURL, nil if fk.getAttachCheck != nil {
fk.getAttachCheck(podFullName, podUID, containerName, streamOpts)
}
// Always use testContainerID
resp, err := fk.streamingRuntime.GetAttach(&runtimeapi.AttachRequest{
ContainerId: testContainerID,
Tty: streamOpts.TTY,
Stdin: streamOpts.Stdin,
Stdout: streamOpts.Stdout,
Stderr: streamOpts.Stderr,
})
if err != nil {
return nil, err
}
return url.Parse(resp.GetUrl())
} }
func (fk *fakeKubelet) GetPortForward(podName, podNamespace string, podUID types.UID, portForwardOpts portforward.V4Options) (*url.URL, error) { func (fk *fakeKubelet) GetPortForward(podName, podNamespace string, podUID types.UID, portForwardOpts portforward.V4Options) (*url.URL, error) {
return fk.redirectURL, nil if fk.getPortForwardCheck != nil {
fk.getPortForwardCheck(podName, podNamespace, podUID, portForwardOpts)
} }
// Always use testPodSandboxID
func (fk *fakeKubelet) StreamingConnectionIdleTimeout() time.Duration { resp, err := fk.streamingRuntime.GetPortForward(&runtimeapi.PortForwardRequest{
return fk.streamingConnectionIdleTimeoutFunc() PodSandboxId: testPodSandboxID,
Port: portForwardOpts.Ports,
})
if err != nil {
return nil, err
}
return url.Parse(resp.GetUrl())
} }
// Unused functions // Unused functions
@ -202,13 +283,16 @@ type serverTestFramework struct {
fakeKubelet *fakeKubelet fakeKubelet *fakeKubelet
fakeAuth *fakeAuth fakeAuth *fakeAuth
testHTTPServer *httptest.Server testHTTPServer *httptest.Server
fakeRuntime *fakeRuntime
testStreamingHTTPServer *httptest.Server
criHandler *utiltesting.FakeHandler
} }
func newServerTest() *serverTestFramework { func newServerTest() *serverTestFramework {
return newServerTestWithDebug(true) return newServerTestWithDebug(true, false, nil)
} }
func newServerTestWithDebug(enableDebugging bool) *serverTestFramework { func newServerTestWithDebug(enableDebugging, redirectContainerStreaming bool, streamingServer streaming.Server) *serverTestFramework {
fw := &serverTestFramework{} fw := &serverTestFramework{}
fw.fakeKubelet = &fakeKubelet{ fw.fakeKubelet = &fakeKubelet{
hostnameFunc: func() string { hostnameFunc: func() string {
@ -224,6 +308,7 @@ func newServerTestWithDebug(enableDebugging bool) *serverTestFramework {
}, true }, true
}, },
plegHealth: true, plegHealth: true,
streamingRuntime: streamingServer,
} }
fw.fakeAuth = &fakeAuth{ fw.fakeAuth = &fakeAuth{
authenticateFunc: func(req *http.Request) (user.Info, bool, error) { authenticateFunc: func(req *http.Request) (user.Info, bool, error) {
@ -236,13 +321,17 @@ func newServerTestWithDebug(enableDebugging bool) *serverTestFramework {
return authorizer.DecisionAllow, "", nil return authorizer.DecisionAllow, "", nil
}, },
} }
fw.criHandler = &utiltesting.FakeHandler{
StatusCode: http.StatusOK,
}
server := NewServer( server := NewServer(
fw.fakeKubelet, fw.fakeKubelet,
stats.NewResourceAnalyzer(fw.fakeKubelet, time.Minute), stats.NewResourceAnalyzer(fw.fakeKubelet, time.Minute),
fw.fakeAuth, fw.fakeAuth,
enableDebugging, enableDebugging,
false, false,
&kubecontainertesting.Mock{}) redirectContainerStreaming,
fw.criHandler)
fw.serverUnderTest = &server fw.serverUnderTest = &server
fw.testHTTPServer = httptest.NewServer(fw.serverUnderTest) fw.testHTTPServer = httptest.NewServer(fw.serverUnderTest)
return fw return fw
@ -1064,13 +1153,12 @@ func TestContainerLogsWithFollow(t *testing.T) {
} }
func TestServeExecInContainerIdleTimeout(t *testing.T) { func TestServeExecInContainerIdleTimeout(t *testing.T) {
fw := newServerTest() ss, err := newTestStreamingServer(100 * time.Millisecond)
require.NoError(t, err)
defer ss.testHTTPServer.Close()
fw := newServerTestWithDebug(true, false, ss)
defer fw.testHTTPServer.Close() defer fw.testHTTPServer.Close()
fw.fakeKubelet.streamingConnectionIdleTimeoutFunc = func() time.Duration {
return 100 * time.Millisecond
}
podNamespace := "other" podNamespace := "other"
podName := "foo" podName := "foo"
expectedContainerName := "baz" expectedContainerName := "baz"
@ -1102,38 +1190,35 @@ func TestServeExecInContainerIdleTimeout(t *testing.T) {
} }
func testExecAttach(t *testing.T, verb string) { func testExecAttach(t *testing.T, verb string) {
tests := []struct { tests := map[string]struct {
stdin bool stdin bool
stdout bool stdout bool
stderr bool stderr bool
tty bool tty bool
responseStatusCode int responseStatusCode int
uid bool uid bool
responseLocation string redirect bool
}{ }{
{responseStatusCode: http.StatusBadRequest}, "no input or output": {responseStatusCode: http.StatusBadRequest},
{stdin: true, responseStatusCode: http.StatusSwitchingProtocols}, "stdin": {stdin: true, responseStatusCode: http.StatusSwitchingProtocols},
{stdout: true, responseStatusCode: http.StatusSwitchingProtocols}, "stdout": {stdout: true, responseStatusCode: http.StatusSwitchingProtocols},
{stderr: true, responseStatusCode: http.StatusSwitchingProtocols}, "stderr": {stderr: true, responseStatusCode: http.StatusSwitchingProtocols},
{stdout: true, stderr: true, responseStatusCode: http.StatusSwitchingProtocols}, "stdout and stderr": {stdout: true, stderr: true, responseStatusCode: http.StatusSwitchingProtocols},
{stdout: true, stderr: true, tty: true, responseStatusCode: http.StatusSwitchingProtocols}, "stdout stderr and tty": {stdout: true, stderr: true, tty: true, responseStatusCode: http.StatusSwitchingProtocols},
{stdin: true, stdout: true, stderr: true, responseStatusCode: http.StatusSwitchingProtocols}, "stdin stdout and stderr": {stdin: true, stdout: true, stderr: true, responseStatusCode: http.StatusSwitchingProtocols},
{stdout: true, responseStatusCode: http.StatusFound, responseLocation: "http://localhost:12345/" + verb}, "stdin stdout stderr with uid": {stdin: true, stdout: true, stderr: true, responseStatusCode: http.StatusSwitchingProtocols, uid: true},
"stdout with redirect": {stdout: true, responseStatusCode: http.StatusFound, redirect: true},
} }
for i, test := range tests { for desc, test := range tests {
fw := newServerTest() test := test
defer fw.testHTTPServer.Close() t.Run(desc, func(t *testing.T) {
ss, err := newTestStreamingServer(0)
fw.fakeKubelet.streamingConnectionIdleTimeoutFunc = func() time.Duration {
return 0
}
if test.responseLocation != "" {
var err error
fw.fakeKubelet.redirectURL, err = url.Parse(test.responseLocation)
require.NoError(t, err) require.NoError(t, err)
} defer ss.testHTTPServer.Close()
fw := newServerTestWithDebug(true, test.redirect, ss)
defer fw.testHTTPServer.Close()
fmt.Println(desc)
podNamespace := "other" podNamespace := "other"
podName := "foo" podName := "foo"
@ -1149,81 +1234,67 @@ func testExecAttach(t *testing.T, verb string) {
execInvoked := false execInvoked := false
attachInvoked := false attachInvoked := false
testStreamFunc := func(podFullName string, uid types.UID, containerName string, cmd []string, in io.Reader, out, stderr io.WriteCloser, tty bool, done chan struct{}) error { checkStream := func(podFullName string, uid types.UID, containerName string, streamOpts remotecommandserver.Options) {
defer close(done) assert.Equal(t, expectedPodName, podFullName, "podFullName")
if test.uid {
assert.Equal(t, testUID, string(uid), "uid")
}
assert.Equal(t, expectedContainerName, containerName, "containerName")
assert.Equal(t, test.stdin, streamOpts.Stdin, "stdin")
assert.Equal(t, test.stdout, streamOpts.Stdout, "stdout")
assert.Equal(t, test.tty, streamOpts.TTY, "tty")
assert.Equal(t, !test.tty && test.stderr, streamOpts.Stderr, "stderr")
}
if podFullName != expectedPodName { fw.fakeKubelet.getExecCheck = func(podFullName string, uid types.UID, containerName string, cmd []string, streamOpts remotecommandserver.Options) {
t.Fatalf("%d: podFullName: expected %s, got %s", i, expectedPodName, podFullName) execInvoked = true
assert.Equal(t, expectedCommand, strings.Join(cmd, " "), "cmd")
checkStream(podFullName, uid, containerName, streamOpts)
} }
if test.uid && string(uid) != testUID {
t.Fatalf("%d: uid: expected %v, got %v", i, testUID, uid) fw.fakeKubelet.getAttachCheck = func(podFullName string, uid types.UID, containerName string, streamOpts remotecommandserver.Options) {
} attachInvoked = true
if containerName != expectedContainerName { checkStream(podFullName, uid, containerName, streamOpts)
t.Fatalf("%d: containerName: expected %s, got %s", i, expectedContainerName, containerName)
} }
testStream := func(containerID string, in io.Reader, out, stderr io.WriteCloser, tty bool, done chan struct{}) error {
close(done)
assert.Equal(t, testContainerID, containerID, "containerID")
assert.Equal(t, test.tty, tty, "tty")
require.Equal(t, test.stdin, in != nil, "in")
require.Equal(t, test.stdout, out != nil, "out")
require.Equal(t, !test.tty && test.stderr, stderr != nil, "err")
if test.stdin { if test.stdin {
if in == nil {
t.Fatalf("%d: stdin: expected non-nil", i)
}
b := make([]byte, 10) b := make([]byte, 10)
n, err := in.Read(b) n, err := in.Read(b)
if err != nil { assert.NoError(t, err, "reading from stdin")
t.Fatalf("%d: error reading from stdin: %v", i, err) assert.Equal(t, expectedStdin, string(b[0:n]), "content from stdin")
}
if e, a := expectedStdin, string(b[0:n]); e != a {
t.Fatalf("%d: stdin: expected to read %v, got %v", i, e, a)
}
} else if in != nil {
t.Fatalf("%d: stdin: expected nil: %#v", i, in)
} }
if test.stdout { if test.stdout {
if out == nil {
t.Fatalf("%d: stdout: expected non-nil", i)
}
_, err := out.Write([]byte(expectedStdout)) _, err := out.Write([]byte(expectedStdout))
if err != nil { assert.NoError(t, err, "writing to stdout")
t.Fatalf("%d:, error writing to stdout: %v", i, err)
}
out.Close() out.Close()
<-clientStdoutReadDone <-clientStdoutReadDone
} else if out != nil {
t.Fatalf("%d: stdout: expected nil: %#v", i, out)
} }
if tty { if !test.tty && test.stderr {
if stderr != nil {
t.Fatalf("%d: tty set but received non-nil stderr: %v", i, stderr)
}
} else if test.stderr {
if stderr == nil {
t.Fatalf("%d: stderr: expected non-nil", i)
}
_, err := stderr.Write([]byte(expectedStderr)) _, err := stderr.Write([]byte(expectedStderr))
if err != nil { assert.NoError(t, err, "writing to stderr")
t.Fatalf("%d:, error writing to stderr: %v", i, err)
}
stderr.Close() stderr.Close()
<-clientStderrReadDone <-clientStderrReadDone
} else if stderr != nil {
t.Fatalf("%d: stderr: expected nil: %#v", i, stderr)
} }
return nil return nil
} }
fw.fakeKubelet.execFunc = func(podFullName string, uid types.UID, containerName string, cmd []string, in io.Reader, out, stderr io.WriteCloser, tty bool) error { ss.fakeRuntime.execFunc = func(containerID string, cmd []string, stdin io.Reader, stdout, stderr io.WriteCloser, tty bool, resize <-chan remotecommand.TerminalSize) error {
execInvoked = true assert.Equal(t, expectedCommand, strings.Join(cmd, " "), "cmd")
if strings.Join(cmd, " ") != expectedCommand { return testStream(containerID, stdin, stdout, stderr, tty, done)
t.Fatalf("%d: cmd: expected: %s, got %v", i, expectedCommand, cmd)
}
return testStreamFunc(podFullName, uid, containerName, cmd, in, out, stderr, tty, done)
} }
fw.fakeKubelet.attachFunc = func(podFullName string, uid types.UID, containerName string, in io.Reader, out, stderr io.WriteCloser, tty bool) error { ss.fakeRuntime.attachFunc = func(containerID string, stdin io.Reader, stdout, stderr io.WriteCloser, tty bool, resize <-chan remotecommand.TerminalSize) error {
attachInvoked = true return testStream(containerID, stdin, stdout, stderr, tty, done)
return testStreamFunc(podFullName, uid, containerName, nil, in, out, stderr, tty, done)
} }
var url string var url string
@ -1250,12 +1321,10 @@ func testExecAttach(t *testing.T, verb string) {
var ( var (
resp *http.Response resp *http.Response
err error
upgradeRoundTripper httpstream.UpgradeRoundTripper upgradeRoundTripper httpstream.UpgradeRoundTripper
c *http.Client c *http.Client
) )
if test.redirect {
if test.responseStatusCode != http.StatusSwitchingProtocols {
c = &http.Client{} c = &http.Client{}
// Don't follow redirects, since we want to inspect the redirect response. // Don't follow redirects, since we want to inspect the redirect response.
c.CheckRedirect = func(*http.Request, []*http.Request) error { c.CheckRedirect = func(*http.Request, []*http.Request) error {
@ -1267,115 +1336,75 @@ func testExecAttach(t *testing.T, verb string) {
} }
resp, err = c.Post(url, "", nil) resp, err = c.Post(url, "", nil)
if err != nil { require.NoError(t, err, "POSTing")
t.Fatalf("%d: Got error POSTing: %v", i, err)
}
defer resp.Body.Close() defer resp.Body.Close()
_, err = ioutil.ReadAll(resp.Body) _, err = ioutil.ReadAll(resp.Body)
if err != nil { assert.NoError(t, err, "reading response body")
t.Errorf("%d: Error reading response body: %v", i, err)
}
if e, a := test.responseStatusCode, resp.StatusCode; e != a {
t.Fatalf("%d: response status: expected %v, got %v", i, e, a)
}
if e, a := test.responseLocation, resp.Header.Get("Location"); e != a {
t.Errorf("%d: response location: expected %v, got %v", i, e, a)
}
require.Equal(t, test.responseStatusCode, resp.StatusCode, "response status")
if test.responseStatusCode != http.StatusSwitchingProtocols { if test.responseStatusCode != http.StatusSwitchingProtocols {
continue return
} }
conn, err := upgradeRoundTripper.NewConnection(resp) conn, err := upgradeRoundTripper.NewConnection(resp)
if err != nil { require.NoError(t, err, "creating streaming connection")
t.Fatalf("Unexpected error creating streaming connection: %s", err)
}
if conn == nil {
t.Fatalf("%d: unexpected nil conn", i)
}
defer conn.Close() defer conn.Close()
h := http.Header{} h := http.Header{}
h.Set(api.StreamType, api.StreamTypeError) h.Set(api.StreamType, api.StreamTypeError)
if _, err := conn.CreateStream(h); err != nil { _, err = conn.CreateStream(h)
t.Fatalf("%d: error creating error stream: %v", i, err) require.NoError(t, err, "creating error stream")
}
if test.stdin { if test.stdin {
h.Set(api.StreamType, api.StreamTypeStdin) h.Set(api.StreamType, api.StreamTypeStdin)
stream, err := conn.CreateStream(h) stream, err := conn.CreateStream(h)
if err != nil { require.NoError(t, err, "creating stdin stream")
t.Fatalf("%d: error creating stdin stream: %v", i, err)
}
_, err = stream.Write([]byte(expectedStdin)) _, err = stream.Write([]byte(expectedStdin))
if err != nil { require.NoError(t, err, "writing to stdin stream")
t.Fatalf("%d: error writing to stdin stream: %v", i, err)
}
} }
var stdoutStream httpstream.Stream var stdoutStream httpstream.Stream
if test.stdout { if test.stdout {
h.Set(api.StreamType, api.StreamTypeStdout) h.Set(api.StreamType, api.StreamTypeStdout)
stdoutStream, err = conn.CreateStream(h) stdoutStream, err = conn.CreateStream(h)
if err != nil { require.NoError(t, err, "creating stdout stream")
t.Fatalf("%d: error creating stdout stream: %v", i, err)
}
} }
var stderrStream httpstream.Stream var stderrStream httpstream.Stream
if test.stderr && !test.tty { if test.stderr && !test.tty {
h.Set(api.StreamType, api.StreamTypeStderr) h.Set(api.StreamType, api.StreamTypeStderr)
stderrStream, err = conn.CreateStream(h) stderrStream, err = conn.CreateStream(h)
if err != nil { require.NoError(t, err, "creating stderr stream")
t.Fatalf("%d: error creating stderr stream: %v", i, err)
}
} }
if test.stdout { if test.stdout {
output := make([]byte, 10) output := make([]byte, 10)
n, err := stdoutStream.Read(output) n, err := stdoutStream.Read(output)
close(clientStdoutReadDone) close(clientStdoutReadDone)
if err != nil { assert.NoError(t, err, "reading from stdout stream")
t.Fatalf("%d: error reading from stdout stream: %v", i, err) assert.Equal(t, expectedStdout, string(output[0:n]), "stdout")
}
if e, a := expectedStdout, string(output[0:n]); e != a {
t.Fatalf("%d: stdout: expected '%v', got '%v'", i, e, a)
}
} }
if test.stderr && !test.tty { if test.stderr && !test.tty {
output := make([]byte, 10) output := make([]byte, 10)
n, err := stderrStream.Read(output) n, err := stderrStream.Read(output)
close(clientStderrReadDone) close(clientStderrReadDone)
if err != nil { assert.NoError(t, err, "reading from stderr stream")
t.Fatalf("%d: error reading from stderr stream: %v", i, err) assert.Equal(t, expectedStderr, string(output[0:n]), "stderr")
}
if e, a := expectedStderr, string(output[0:n]); e != a {
t.Fatalf("%d: stderr: expected '%v', got '%v'", i, e, a)
}
} }
// wait for the server to finish before checking if the attach/exec funcs were invoked // wait for the server to finish before checking if the attach/exec funcs were invoked
<-done <-done
if verb == "exec" { if verb == "exec" {
if !execInvoked { assert.True(t, execInvoked, "exec should be invoked")
t.Errorf("%d: exec was not invoked", i) assert.False(t, attachInvoked, "attach should not be invoked")
}
if attachInvoked {
t.Errorf("%d: attach should not have been invoked", i)
}
} else { } else {
if !attachInvoked { assert.True(t, attachInvoked, "attach should be invoked")
t.Errorf("%d: attach was not invoked", i) assert.False(t, execInvoked, "exec should not be invoked")
}
if execInvoked {
t.Errorf("%d: exec should not have been invoked", i)
}
} }
})
} }
} }
@ -1388,13 +1417,12 @@ func TestServeAttachContainer(t *testing.T) {
} }
func TestServePortForwardIdleTimeout(t *testing.T) { func TestServePortForwardIdleTimeout(t *testing.T) {
fw := newServerTest() ss, err := newTestStreamingServer(100 * time.Millisecond)
require.NoError(t, err)
defer ss.testHTTPServer.Close()
fw := newServerTestWithDebug(true, false, ss)
defer fw.testHTTPServer.Close() defer fw.testHTTPServer.Close()
fw.fakeKubelet.streamingConnectionIdleTimeoutFunc = func() time.Duration {
return 100 * time.Millisecond
}
podNamespace := "other" podNamespace := "other"
podName := "foo" podName := "foo"
@ -1422,82 +1450,67 @@ func TestServePortForwardIdleTimeout(t *testing.T) {
} }
func TestServePortForward(t *testing.T) { func TestServePortForward(t *testing.T) {
tests := []struct { tests := map[string]struct {
port string port string
uid bool uid bool
clientData string clientData string
containerData string containerData string
redirect bool
shouldError bool shouldError bool
responseLocation string
}{ }{
{port: "", shouldError: true}, "no port": {port: "", shouldError: true},
{port: "abc", shouldError: true}, "none number port": {port: "abc", shouldError: true},
{port: "-1", shouldError: true}, "negative port": {port: "-1", shouldError: true},
{port: "65536", shouldError: true}, "too large port": {port: "65536", shouldError: true},
{port: "0", shouldError: true}, "0 port": {port: "0", shouldError: true},
{port: "1", shouldError: false}, "min port": {port: "1", shouldError: false},
{port: "8000", shouldError: false}, "normal port": {port: "8000", shouldError: false},
{port: "8000", clientData: "client data", containerData: "container data", shouldError: false}, "normal port with data forward": {port: "8000", clientData: "client data", containerData: "container data", shouldError: false},
{port: "65535", shouldError: false}, "max port": {port: "65535", shouldError: false},
{port: "65535", uid: true, shouldError: false}, "normal port with uid": {port: "8000", uid: true, shouldError: false},
{port: "65535", responseLocation: "http://localhost:12345/portforward", shouldError: false}, "normal port with redirect": {port: "8000", redirect: true, shouldError: false},
} }
podNamespace := "other" podNamespace := "other"
podName := "foo" podName := "foo"
expectedPodName := getPodName(podName, podNamespace)
for i, test := range tests { for desc, test := range tests {
fw := newServerTest() test := test
defer fw.testHTTPServer.Close() t.Run(desc, func(t *testing.T) {
ss, err := newTestStreamingServer(0)
fw.fakeKubelet.streamingConnectionIdleTimeoutFunc = func() time.Duration {
return 0
}
if test.responseLocation != "" {
var err error
fw.fakeKubelet.redirectURL, err = url.Parse(test.responseLocation)
require.NoError(t, err) require.NoError(t, err)
} defer ss.testHTTPServer.Close()
fw := newServerTestWithDebug(true, test.redirect, ss)
defer fw.testHTTPServer.Close()
portForwardFuncDone := make(chan struct{}) portForwardFuncDone := make(chan struct{})
fw.fakeKubelet.portForwardFunc = func(name string, uid types.UID, port int32, stream io.ReadWriteCloser) error { fw.fakeKubelet.getPortForwardCheck = func(name, namespace string, uid types.UID, opts portforward.V4Options) {
assert.Equal(t, podName, name, "pod name")
assert.Equal(t, podNamespace, namespace, "pod namespace")
if test.uid {
assert.Equal(t, testUID, string(uid), "uid")
}
}
ss.fakeRuntime.portForwardFunc = func(podSandboxID string, port int32, stream io.ReadWriteCloser) error {
defer close(portForwardFuncDone) defer close(portForwardFuncDone)
assert.Equal(t, testPodSandboxID, podSandboxID, "pod sandbox id")
if e, a := expectedPodName, name; e != a { // The port should be valid if it reaches here.
t.Fatalf("%d: pod name: expected '%v', got '%v'", i, e, a) testPort, err := strconv.ParseInt(test.port, 10, 32)
} require.NoError(t, err, "parse port")
assert.Equal(t, int32(testPort), port, "port")
if e, a := testUID, uid; test.uid && e != string(a) {
t.Fatalf("%d: uid: expected '%v', got '%v'", i, e, a)
}
p, err := strconv.ParseInt(test.port, 10, 32)
if err != nil {
t.Fatalf("%d: error parsing port string '%s': %v", i, test.port, err)
}
if e, a := int32(p), port; e != a {
t.Fatalf("%d: port: expected '%v', got '%v'", i, e, a)
}
if test.clientData != "" { if test.clientData != "" {
fromClient := make([]byte, 32) fromClient := make([]byte, 32)
n, err := stream.Read(fromClient) n, err := stream.Read(fromClient)
if err != nil { assert.NoError(t, err, "reading client data")
t.Fatalf("%d: error reading client data: %v", i, err) assert.Equal(t, test.clientData, string(fromClient[0:n]), "client data")
}
if e, a := test.clientData, string(fromClient[0:n]); e != a {
t.Fatalf("%d: client data: expected to receive '%v', got '%v'", i, e, a)
}
} }
if test.containerData != "" { if test.containerData != "" {
_, err := stream.Write([]byte(test.containerData)) _, err := stream.Write([]byte(test.containerData))
if err != nil { assert.NoError(t, err, "writing container data")
t.Fatalf("%d: error writing container data: %v", i, err)
}
} }
return nil return nil
@ -1515,7 +1528,7 @@ func TestServePortForward(t *testing.T) {
c *http.Client c *http.Client
) )
if len(test.responseLocation) > 0 { if test.redirect {
c = &http.Client{} c = &http.Client{}
// Don't follow redirects, since we want to inspect the redirect response. // Don't follow redirects, since we want to inspect the redirect response.
c.CheckRedirect = func(*http.Request, []*http.Request) error { c.CheckRedirect = func(*http.Request, []*http.Request) error {
@ -1527,74 +1540,70 @@ func TestServePortForward(t *testing.T) {
} }
resp, err := c.Post(url, "", nil) resp, err := c.Post(url, "", nil)
if err != nil { require.NoError(t, err, "POSTing")
t.Fatalf("%d: Got error POSTing: %v", i, err)
}
defer resp.Body.Close() defer resp.Body.Close()
if test.responseLocation != "" { if test.redirect {
assert.Equal(t, http.StatusFound, resp.StatusCode, "%d: status code", i) assert.Equal(t, http.StatusFound, resp.StatusCode, "status code")
assert.Equal(t, test.responseLocation, resp.Header.Get("Location"), "%d: location", i) return
continue
} else { } else {
assert.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode, "%d: status code", i) assert.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode, "status code")
} }
conn, err := upgradeRoundTripper.NewConnection(resp) conn, err := upgradeRoundTripper.NewConnection(resp)
if err != nil { require.NoError(t, err, "creating streaming connection")
t.Fatalf("Unexpected error creating streaming connection: %s", err)
}
if conn == nil {
t.Fatalf("%d: Unexpected nil connection", i)
}
defer conn.Close() defer conn.Close()
headers := http.Header{} headers := http.Header{}
headers.Set("streamType", "error") headers.Set("streamType", "error")
headers.Set("port", test.port) headers.Set("port", test.port)
errorStream, err := conn.CreateStream(headers) _, err = conn.CreateStream(headers)
_ = errorStream assert.Equal(t, test.shouldError, err != nil, "expect error")
haveErr := err != nil
if e, a := test.shouldError, haveErr; e != a {
t.Fatalf("%d: create stream: expected err=%t, got %t: %v", i, e, a, err)
}
if test.shouldError { if test.shouldError {
continue return
} }
headers.Set("streamType", "data") headers.Set("streamType", "data")
headers.Set("port", test.port) headers.Set("port", test.port)
dataStream, err := conn.CreateStream(headers) dataStream, err := conn.CreateStream(headers)
haveErr = err != nil require.NoError(t, err, "create stream")
if e, a := test.shouldError, haveErr; e != a {
t.Fatalf("%d: create stream: expected err=%t, got %t: %v", i, e, a, err)
}
if test.clientData != "" { if test.clientData != "" {
_, err := dataStream.Write([]byte(test.clientData)) _, err := dataStream.Write([]byte(test.clientData))
if err != nil { assert.NoError(t, err, "writing client data")
t.Fatalf("%d: unexpected error writing client data: %v", i, err)
}
} }
if test.containerData != "" { if test.containerData != "" {
fromContainer := make([]byte, 32) fromContainer := make([]byte, 32)
n, err := dataStream.Read(fromContainer) n, err := dataStream.Read(fromContainer)
if err != nil { assert.NoError(t, err, "reading container data")
t.Fatalf("%d: unexpected error reading container data: %v", i, err) assert.Equal(t, test.containerData, string(fromContainer[0:n]), "container data")
}
if e, a := test.containerData, string(fromContainer[0:n]); e != a {
t.Fatalf("%d: expected to receive '%v' from container, got '%v'", i, e, a)
}
} }
<-portForwardFuncDone <-portForwardFuncDone
})
} }
} }
func TestCRIHandler(t *testing.T) {
fw := newServerTest()
defer fw.testHTTPServer.Close()
const (
path = "/cri/exec/123456abcdef"
query = "cmd=echo+foo"
)
resp, err := http.Get(fw.testHTTPServer.URL + path + "?" + query)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, "GET", fw.criHandler.RequestReceived.Method)
assert.Equal(t, path, fw.criHandler.RequestReceived.URL.Path)
assert.Equal(t, query, fw.criHandler.RequestReceived.URL.RawQuery)
}
func TestDebuggingDisabledHandlers(t *testing.T) { func TestDebuggingDisabledHandlers(t *testing.T) {
fw := newServerTestWithDebug(false) fw := newServerTestWithDebug(false, false, nil)
defer fw.testHTTPServer.Close() defer fw.testHTTPServer.Close()
paths := []string{ paths := []string{

View File

@ -23,11 +23,13 @@ import (
"strconv" "strconv"
"sync" "sync"
"testing" "testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/net/websocket" "golang.org/x/net/websocket"
"k8s.io/apimachinery/pkg/types" "k8s.io/apimachinery/pkg/types"
"k8s.io/kubernetes/pkg/kubelet/server/portforward"
) )
const ( const (
@ -36,75 +38,65 @@ const (
) )
func TestServeWSPortForward(t *testing.T) { func TestServeWSPortForward(t *testing.T) {
tests := []struct { tests := map[string]struct {
port string port string
uid bool uid bool
clientData string clientData string
containerData string containerData string
shouldError bool shouldError bool
}{ }{
{port: "", shouldError: true}, "no port": {port: "", shouldError: true},
{port: "abc", shouldError: true}, "none number port": {port: "abc", shouldError: true},
{port: "-1", shouldError: true}, "negative port": {port: "-1", shouldError: true},
{port: "65536", shouldError: true}, "too large port": {port: "65536", shouldError: true},
{port: "0", shouldError: true}, "0 port": {port: "0", shouldError: true},
{port: "1", shouldError: false}, "min port": {port: "1", shouldError: false},
{port: "8000", shouldError: false}, "normal port": {port: "8000", shouldError: false},
{port: "8000", clientData: "client data", containerData: "container data", shouldError: false}, "normal port with data forward": {port: "8000", clientData: "client data", containerData: "container data", shouldError: false},
{port: "65535", shouldError: false}, "max port": {port: "65535", shouldError: false},
{port: "65535", uid: true, shouldError: false}, "normal port with uid": {port: "8000", uid: true, shouldError: false},
} }
podNamespace := "other" podNamespace := "other"
podName := "foo" podName := "foo"
expectedPodName := getPodName(podName, podNamespace)
expectedUid := "9b01b80f-8fb4-11e4-95ab-4200af06647"
for i, test := range tests { for desc, test := range tests {
fw := newServerTest() test := test
t.Run(desc, func(t *testing.T) {
ss, err := newTestStreamingServer(0)
require.NoError(t, err)
defer ss.testHTTPServer.Close()
fw := newServerTestWithDebug(true, false, ss)
defer fw.testHTTPServer.Close() defer fw.testHTTPServer.Close()
fw.fakeKubelet.streamingConnectionIdleTimeoutFunc = func() time.Duration {
return 0
}
portForwardFuncDone := make(chan struct{}) portForwardFuncDone := make(chan struct{})
fw.fakeKubelet.portForwardFunc = func(name string, uid types.UID, port int32, stream io.ReadWriteCloser) error { fw.fakeKubelet.getPortForwardCheck = func(name, namespace string, uid types.UID, opts portforward.V4Options) {
assert.Equal(t, podName, name, "pod name")
assert.Equal(t, podNamespace, namespace, "pod namespace")
if test.uid {
assert.Equal(t, testUID, string(uid), "uid")
}
}
ss.fakeRuntime.portForwardFunc = func(podSandboxID string, port int32, stream io.ReadWriteCloser) error {
defer close(portForwardFuncDone) defer close(portForwardFuncDone)
assert.Equal(t, testPodSandboxID, podSandboxID, "pod sandbox id")
if e, a := expectedPodName, name; e != a { // The port should be valid if it reaches here.
t.Fatalf("%d: pod name: expected '%v', got '%v'", i, e, a) testPort, err := strconv.ParseInt(test.port, 10, 32)
} require.NoError(t, err, "parse port")
assert.Equal(t, int32(testPort), port, "port")
if e, a := expectedUid, uid; test.uid && e != string(a) {
t.Fatalf("%d: uid: expected '%v', got '%v'", i, e, a)
}
p, err := strconv.ParseInt(test.port, 10, 32)
if err != nil {
t.Fatalf("%d: error parsing port string '%s': %v", i, test.port, err)
}
if e, a := int32(p), port; e != a {
t.Fatalf("%d: port: expected '%v', got '%v'", i, e, a)
}
if test.clientData != "" { if test.clientData != "" {
fromClient := make([]byte, 32) fromClient := make([]byte, 32)
n, err := stream.Read(fromClient) n, err := stream.Read(fromClient)
if err != nil { assert.NoError(t, err, "reading client data")
t.Fatalf("%d: error reading client data: %v", i, err) assert.Equal(t, test.clientData, string(fromClient[0:n]), "client data")
}
if e, a := test.clientData, string(fromClient[0:n]); e != a {
t.Fatalf("%d: client data: expected to receive '%v', got '%v'", i, e, a)
}
} }
if test.containerData != "" { if test.containerData != "" {
_, err := stream.Write([]byte(test.containerData)) _, err := stream.Write([]byte(test.containerData))
if err != nil { assert.NoError(t, err, "writing container data")
t.Fatalf("%d: error writing container data: %v", i, err)
}
} }
return nil return nil
@ -112,76 +104,48 @@ func TestServeWSPortForward(t *testing.T) {
var url string var url string
if test.uid { if test.uid {
url = fmt.Sprintf("ws://%s/portForward/%s/%s/%s?port=%s", fw.testHTTPServer.Listener.Addr().String(), podNamespace, podName, expectedUid, test.port) url = fmt.Sprintf("ws://%s/portForward/%s/%s/%s?port=%s", fw.testHTTPServer.Listener.Addr().String(), podNamespace, podName, testUID, test.port)
} else { } else {
url = fmt.Sprintf("ws://%s/portForward/%s/%s?port=%s", fw.testHTTPServer.Listener.Addr().String(), podNamespace, podName, test.port) url = fmt.Sprintf("ws://%s/portForward/%s/%s?port=%s", fw.testHTTPServer.Listener.Addr().String(), podNamespace, podName, test.port)
} }
ws, err := websocket.Dial(url, "", "http://127.0.0.1/") ws, err := websocket.Dial(url, "", "http://127.0.0.1/")
assert.Equal(t, test.shouldError, err != nil, "websocket dial")
if test.shouldError { if test.shouldError {
if err == nil { return
t.Fatalf("%d: websocket dial expected err", i)
} }
continue
} else if err != nil {
t.Fatalf("%d: websocket dial unexpected err: %v", i, err)
}
defer ws.Close() defer ws.Close()
p, err := strconv.ParseUint(test.port, 10, 16) p, err := strconv.ParseUint(test.port, 10, 16)
if err != nil { require.NoError(t, err, "parse port")
t.Fatalf("%d: error parsing port string '%s': %v", i, test.port, err)
}
p16 := uint16(p) p16 := uint16(p)
channel, data, err := wsRead(ws) channel, data, err := wsRead(ws)
if err != nil { require.NoError(t, err, "read")
t.Fatalf("%d: read failed: expected no error: got %v", i, err) assert.Equal(t, dataChannel, int(channel), "channel")
} assert.Len(t, data, binary.Size(p16), "data size")
if channel != dataChannel { assert.Equal(t, p16, binary.LittleEndian.Uint16(data), "data")
t.Fatalf("%d: wrong channel: got %q: expected %q", i, channel, dataChannel)
}
if len(data) != binary.Size(p16) {
t.Fatalf("%d: wrong data size: got %q: expected %d", i, data, binary.Size(p16))
}
if e, a := p16, binary.LittleEndian.Uint16(data); e != a {
t.Fatalf("%d: wrong data: got %q: expected %s", i, data, test.port)
}
channel, data, err = wsRead(ws) channel, data, err = wsRead(ws)
if err != nil { assert.NoError(t, err, "read")
t.Fatalf("%d: read succeeded: expected no error: got %v", i, err) assert.Equal(t, errorChannel, int(channel), "channel")
} assert.Len(t, data, binary.Size(p16), "data size")
if channel != errorChannel { assert.Equal(t, p16, binary.LittleEndian.Uint16(data), "data")
t.Fatalf("%d: wrong channel: got %q: expected %q", i, channel, errorChannel)
}
if len(data) != binary.Size(p16) {
t.Fatalf("%d: wrong data size: got %q: expected %d", i, data, binary.Size(p16))
}
if e, a := p16, binary.LittleEndian.Uint16(data); e != a {
t.Fatalf("%d: wrong data: got %q: expected %s", i, data, test.port)
}
if test.clientData != "" { if test.clientData != "" {
println("writing the client data") println("writing the client data")
err := wsWrite(ws, dataChannel, []byte(test.clientData)) err := wsWrite(ws, dataChannel, []byte(test.clientData))
if err != nil { assert.NoError(t, err, "writing client data")
t.Fatalf("%d: unexpected error writing client data: %v", i, err)
}
} }
if test.containerData != "" { if test.containerData != "" {
_, data, err = wsRead(ws) _, data, err = wsRead(ws)
if err != nil { assert.NoError(t, err, "reading container data")
t.Fatalf("%d: unexpected error reading container data: %v", i, err) assert.Equal(t, test.containerData, string(data), "container data")
}
if e, a := test.containerData, string(data); e != a {
t.Fatalf("%d: expected to receive '%v' from container, got '%v'", i, e, a)
}
} }
<-portForwardFuncDone <-portForwardFuncDone
})
} }
} }
@ -190,45 +154,39 @@ func TestServeWSMultiplePortForward(t *testing.T) {
ports := []uint16{7000, 8000, 9000} ports := []uint16{7000, 8000, 9000}
podNamespace := "other" podNamespace := "other"
podName := "foo" podName := "foo"
expectedPodName := getPodName(podName, podNamespace)
fw := newServerTest() ss, err := newTestStreamingServer(0)
require.NoError(t, err)
defer ss.testHTTPServer.Close()
fw := newServerTestWithDebug(true, false, ss)
defer fw.testHTTPServer.Close() defer fw.testHTTPServer.Close()
fw.fakeKubelet.streamingConnectionIdleTimeoutFunc = func() time.Duration {
return 0
}
portForwardWG := sync.WaitGroup{} portForwardWG := sync.WaitGroup{}
portForwardWG.Add(len(ports)) portForwardWG.Add(len(ports))
portsMutex := sync.Mutex{} portsMutex := sync.Mutex{}
portsForwarded := map[int32]struct{}{} portsForwarded := map[int32]struct{}{}
fw.fakeKubelet.portForwardFunc = func(name string, uid types.UID, port int32, stream io.ReadWriteCloser) error { fw.fakeKubelet.getPortForwardCheck = func(name, namespace string, uid types.UID, opts portforward.V4Options) {
defer portForwardWG.Done() assert.Equal(t, podName, name, "pod name")
assert.Equal(t, podNamespace, namespace, "pod namespace")
if e, a := expectedPodName, name; e != a {
t.Fatalf("%d: pod name: expected '%v', got '%v'", port, e, a)
} }
ss.fakeRuntime.portForwardFunc = func(podSandboxID string, port int32, stream io.ReadWriteCloser) error {
defer portForwardWG.Done()
assert.Equal(t, testPodSandboxID, podSandboxID, "pod sandbox id")
portsMutex.Lock() portsMutex.Lock()
portsForwarded[port] = struct{}{} portsForwarded[port] = struct{}{}
portsMutex.Unlock() portsMutex.Unlock()
fromClient := make([]byte, 32) fromClient := make([]byte, 32)
n, err := stream.Read(fromClient) n, err := stream.Read(fromClient)
if err != nil { assert.NoError(t, err, "reading client data")
t.Fatalf("%d: error reading client data: %v", port, err) assert.Equal(t, fmt.Sprintf("client data on port %d", port), string(fromClient[0:n]), "client data")
}
if e, a := fmt.Sprintf("client data on port %d", port), string(fromClient[0:n]); e != a {
t.Fatalf("%d: client data: expected to receive '%v', got '%v'", port, e, a)
}
_, err = stream.Write([]byte(fmt.Sprintf("container data on port %d", port))) _, err = stream.Write([]byte(fmt.Sprintf("container data on port %d", port)))
if err != nil { assert.NoError(t, err, "writing container data")
t.Fatalf("%d: error writing container data: %v", port, err)
}
return nil return nil
} }
@ -239,70 +197,42 @@ func TestServeWSMultiplePortForward(t *testing.T) {
} }
ws, err := websocket.Dial(url, "", "http://127.0.0.1/") ws, err := websocket.Dial(url, "", "http://127.0.0.1/")
if err != nil { require.NoError(t, err, "websocket dial")
t.Fatalf("websocket dial unexpected err: %v", err)
}
defer ws.Close() defer ws.Close()
for i, port := range ports { for i, port := range ports {
channel, data, err := wsRead(ws) channel, data, err := wsRead(ws)
if err != nil { assert.NoError(t, err, "port %d read", port)
t.Fatalf("%d: read failed: expected no error: got %v", i, err) assert.Equal(t, i*2+dataChannel, int(channel), "port %d channel", port)
} assert.Len(t, data, binary.Size(port), "port %d data size", port)
if int(channel) != i*2+dataChannel { assert.Equal(t, binary.LittleEndian.Uint16(data), port, "port %d data", port)
t.Fatalf("%d: wrong channel: got %q: expected %q", i, channel, i*2+dataChannel)
}
if len(data) != binary.Size(port) {
t.Fatalf("%d: wrong data size: got %q: expected %d", i, data, binary.Size(port))
}
if e, a := port, binary.LittleEndian.Uint16(data); e != a {
t.Fatalf("%d: wrong data: got %q: expected %d", i, data, port)
}
channel, data, err = wsRead(ws) channel, data, err = wsRead(ws)
if err != nil { assert.NoError(t, err, "port %d read", port)
t.Fatalf("%d: read succeeded: expected no error: got %v", i, err) assert.Equal(t, i*2+errorChannel, int(channel), "port %d channel", port)
} assert.Len(t, data, binary.Size(port), "port %d data size", port)
if int(channel) != i*2+errorChannel { assert.Equal(t, binary.LittleEndian.Uint16(data), port, "port %d data", port)
t.Fatalf("%d: wrong channel: got %q: expected %q", i, channel, i*2+errorChannel)
}
if len(data) != binary.Size(port) {
t.Fatalf("%d: wrong data size: got %q: expected %d", i, data, binary.Size(port))
}
if e, a := port, binary.LittleEndian.Uint16(data); e != a {
t.Fatalf("%d: wrong data: got %q: expected %d", i, data, port)
}
} }
for i, port := range ports { for i, port := range ports {
println("writing the client data", port) t.Logf("port %d writing the client data", port)
err := wsWrite(ws, byte(i*2+dataChannel), []byte(fmt.Sprintf("client data on port %d", port))) err := wsWrite(ws, byte(i*2+dataChannel), []byte(fmt.Sprintf("client data on port %d", port)))
if err != nil { assert.NoError(t, err, "port %d write client data", port)
t.Fatalf("%d: unexpected error writing client data: %v", i, err)
}
channel, data, err := wsRead(ws) channel, data, err := wsRead(ws)
if err != nil { assert.NoError(t, err, "port %d read container data", port)
t.Fatalf("%d: unexpected error reading container data: %v", i, err) assert.Equal(t, i*2+dataChannel, int(channel), "port %d channel", port)
} assert.Equal(t, fmt.Sprintf("container data on port %d", port), string(data), "port %d container data", port)
if int(channel) != i*2+dataChannel {
t.Fatalf("%d: wrong channel: got %q: expected %q", port, channel, i*2+dataChannel)
}
if e, a := fmt.Sprintf("container data on port %d", port), string(data); e != a {
t.Fatalf("%d: expected to receive '%v' from container, got '%v'", i, e, a)
}
} }
portForwardWG.Wait() portForwardWG.Wait()
portsMutex.Lock() portsMutex.Lock()
defer portsMutex.Unlock() defer portsMutex.Unlock()
if len(ports) != len(portsForwarded) { assert.Len(t, portsForwarded, len(ports), "all ports forwarded")
t.Fatalf("expected to forward %d ports; got %v", len(ports), portsForwarded)
}
} }
func wsWrite(conn *websocket.Conn, channel byte, data []byte) error { func wsWrite(conn *websocket.Conn, channel byte, data []byte) error {
frame := make([]byte, len(data)+1) frame := make([]byte, len(data)+1)
frame[0] = channel frame[0] = channel