From 1eb721248b8e70112cd2b118b435570eef1f1172 Mon Sep 17 00:00:00 2001 From: Lantao Liu Date: Fri, 18 May 2018 16:08:44 -0700 Subject: [PATCH] Update unit test. --- pkg/kubelet/kubelet_pods_test.go | 18 +- pkg/kubelet/server/server_test.go | 899 ++++++++++---------- pkg/kubelet/server/server_websocket_test.go | 300 +++---- 3 files changed, 571 insertions(+), 646 deletions(-) diff --git a/pkg/kubelet/kubelet_pods_test.go b/pkg/kubelet/kubelet_pods_test.go index 146ef3056ef..7fd9d984701 100644 --- a/pkg/kubelet/kubelet_pods_test.go +++ b/pkg/kubelet/kubelet_pods_test.go @@ -17,7 +17,6 @@ limitations under the License. package kubelet import ( - "bytes" "errors" "fmt" "io/ioutil" @@ -2095,7 +2094,7 @@ func (f *fakeReadWriteCloser) Close() error { return nil } -func TestExec(t *testing.T) { +func TestGetExec(t *testing.T) { const ( podName = "podFoo" podNamespace = "nsFoo" @@ -2106,9 +2105,6 @@ func TestExec(t *testing.T) { var ( podFullName = kubecontainer.GetPodFullName(podWithUIDNameNs(podUID, podName, podNamespace)) command = []string{"ls"} - stdin = &bytes.Buffer{} - stdout = &fakeReadWriteCloser{} - stderr = &fakeReadWriteCloser{} ) testcases := []struct { @@ -2161,22 +2157,16 @@ func TestExec(t *testing.T) { assert.NoError(t, err, description) assert.Equal(t, containertest.FakeHost, redirect.Host, description+": redirect") } - - err = kubelet.ExecInContainer(tc.podFullName, podUID, tc.container, command, stdin, stdout, stderr, tty, nil, 0) - assert.Error(t, err, description) } } -func TestPortForward(t *testing.T) { +func TestGetPortForward(t *testing.T) { const ( podName = "podFoo" podNamespace = "nsFoo" podUID types.UID = "12345678" port int32 = 5000 ) - var ( - stream = &fakeReadWriteCloser{} - ) testcases := []struct { description string @@ -2208,7 +2198,6 @@ func TestPortForward(t *testing.T) { }}, } - podFullName := kubecontainer.GetPodFullName(podWithUIDNameNs(podUID, tc.podName, podNamespace)) description := "streaming - " + tc.description fakeRuntime := &containertest.FakeStreamingRuntime{FakeRuntime: testKubelet.fakeRuntime} kubelet.containerRuntime = fakeRuntime @@ -2221,9 +2210,6 @@ func TestPortForward(t *testing.T) { assert.NoError(t, err, description) assert.Equal(t, containertest.FakeHost, redirect.Host, description+": redirect") } - - err = kubelet.PortForward(podFullName, podUID, port, stream) - assert.Error(t, err, description) } } diff --git a/pkg/kubelet/server/server_test.go b/pkg/kubelet/server/server_test.go index 25776273a1b..e84bec4d649 100644 --- a/pkg/kubelet/server/server_test.go +++ b/pkg/kubelet/server/server_test.go @@ -46,42 +46,46 @@ import ( "k8s.io/apiserver/pkg/authentication/user" "k8s.io/apiserver/pkg/authorization/authorizer" "k8s.io/client-go/tools/remotecommand" + utiltesting "k8s.io/client-go/util/testing" api "k8s.io/kubernetes/pkg/apis/core" + runtimeapi "k8s.io/kubernetes/pkg/kubelet/apis/cri/runtime/v1alpha2" statsapi "k8s.io/kubernetes/pkg/kubelet/apis/stats/v1alpha1" // Do some initialization to decode the query parameters correctly. _ "k8s.io/kubernetes/pkg/apis/core/install" "k8s.io/kubernetes/pkg/kubelet/cm" kubecontainer "k8s.io/kubernetes/pkg/kubelet/container" - kubecontainertesting "k8s.io/kubernetes/pkg/kubelet/container/testing" "k8s.io/kubernetes/pkg/kubelet/server/portforward" remotecommandserver "k8s.io/kubernetes/pkg/kubelet/server/remotecommand" "k8s.io/kubernetes/pkg/kubelet/server/stats" + "k8s.io/kubernetes/pkg/kubelet/server/streaming" "k8s.io/kubernetes/pkg/volume" ) const ( - testUID = "9b01b80f-8fb4-11e4-95ab-4200af06647" + testUID = "9b01b80f-8fb4-11e4-95ab-4200af06647" + testContainerID = "container789" + testPodSandboxID = "pod0987" ) type fakeKubelet struct { - podByNameFunc func(namespace, name string) (*v1.Pod, bool) - containerInfoFunc func(podFullName string, uid types.UID, containerName string, req *cadvisorapi.ContainerInfoRequest) (*cadvisorapi.ContainerInfo, error) - rawInfoFunc func(query *cadvisorapi.ContainerInfoRequest) (map[string]*cadvisorapi.ContainerInfo, error) - machineInfoFunc func() (*cadvisorapi.MachineInfo, error) - podsFunc func() []*v1.Pod - runningPodsFunc func() ([]*v1.Pod, error) - logFunc func(w http.ResponseWriter, req *http.Request) - runFunc func(podFullName string, uid types.UID, containerName string, cmd []string) ([]byte, error) - execFunc func(pod string, uid types.UID, container string, cmd []string, in io.Reader, out, err io.WriteCloser, tty bool) error - attachFunc func(pod string, uid types.UID, container string, in io.Reader, out, err io.WriteCloser, tty bool) error - portForwardFunc func(name string, uid types.UID, port int32, stream io.ReadWriteCloser) error - containerLogsFunc func(podFullName, containerName string, logOptions *v1.PodLogOptions, stdout, stderr io.Writer) error - streamingConnectionIdleTimeoutFunc func() time.Duration - hostnameFunc func() string - resyncInterval time.Duration - loopEntryTime time.Time - plegHealth bool - redirectURL *url.URL + podByNameFunc func(namespace, name string) (*v1.Pod, bool) + containerInfoFunc func(podFullName string, uid types.UID, containerName string, req *cadvisorapi.ContainerInfoRequest) (*cadvisorapi.ContainerInfo, error) + rawInfoFunc func(query *cadvisorapi.ContainerInfoRequest) (map[string]*cadvisorapi.ContainerInfo, error) + machineInfoFunc func() (*cadvisorapi.MachineInfo, error) + podsFunc func() []*v1.Pod + runningPodsFunc func() ([]*v1.Pod, error) + logFunc func(w http.ResponseWriter, req *http.Request) + runFunc func(podFullName string, uid types.UID, containerName string, cmd []string) ([]byte, error) + getExecCheck func(string, types.UID, string, []string, remotecommandserver.Options) + getAttachCheck func(string, types.UID, string, remotecommandserver.Options) + getPortForwardCheck func(string, string, types.UID, portforward.V4Options) + + containerLogsFunc func(podFullName, containerName string, logOptions *v1.PodLogOptions, stdout, stderr io.Writer) error + hostnameFunc func() string + resyncInterval time.Duration + loopEntryTime time.Time + plegHealth bool + streamingRuntime streaming.Server } func (fk *fakeKubelet) ResyncInterval() time.Duration { @@ -136,32 +140,109 @@ func (fk *fakeKubelet) RunInContainer(podFullName string, uid types.UID, contain return fk.runFunc(podFullName, uid, containerName, cmd) } -func (fk *fakeKubelet) ExecInContainer(name string, uid types.UID, container string, cmd []string, in io.Reader, out, err io.WriteCloser, tty bool, resize <-chan remotecommand.TerminalSize, timeout time.Duration) error { - return fk.execFunc(name, uid, container, cmd, in, out, err, tty) +type fakeRuntime struct { + execFunc func(string, []string, io.Reader, io.WriteCloser, io.WriteCloser, bool, <-chan remotecommand.TerminalSize) error + attachFunc func(string, io.Reader, io.WriteCloser, io.WriteCloser, bool, <-chan remotecommand.TerminalSize) error + portForwardFunc func(string, int32, io.ReadWriteCloser) error } -func (fk *fakeKubelet) AttachContainer(name string, uid types.UID, container string, in io.Reader, out, err io.WriteCloser, tty bool, resize <-chan remotecommand.TerminalSize) error { - return fk.attachFunc(name, uid, container, in, out, err, tty) +func (f *fakeRuntime) Exec(containerID string, cmd []string, stdin io.Reader, stdout, stderr io.WriteCloser, tty bool, resize <-chan remotecommand.TerminalSize) error { + return f.execFunc(containerID, cmd, stdin, stdout, stderr, tty, resize) } -func (fk *fakeKubelet) PortForward(name string, uid types.UID, port int32, stream io.ReadWriteCloser) error { - return fk.portForwardFunc(name, uid, port, stream) +func (f *fakeRuntime) Attach(containerID string, stdin io.Reader, stdout, stderr io.WriteCloser, tty bool, resize <-chan remotecommand.TerminalSize) error { + return f.attachFunc(containerID, stdin, stdout, stderr, tty, resize) +} + +func (f *fakeRuntime) PortForward(podSandboxID string, port int32, stream io.ReadWriteCloser) error { + return f.portForwardFunc(podSandboxID, port, stream) +} + +type testStreamingServer struct { + streaming.Server + fakeRuntime *fakeRuntime + testHTTPServer *httptest.Server +} + +func newTestStreamingServer(streamIdleTimeout time.Duration) (s *testStreamingServer, err error) { + s = &testStreamingServer{} + s.testHTTPServer = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + s.ServeHTTP(w, r) + })) + defer func() { + if err != nil { + s.testHTTPServer.Close() + } + }() + + testURL, err := url.Parse(s.testHTTPServer.URL) + if err != nil { + return nil, err + } + + s.fakeRuntime = &fakeRuntime{} + config := streaming.DefaultConfig + config.BaseURL = testURL + if streamIdleTimeout != 0 { + config.StreamIdleTimeout = streamIdleTimeout + } + s.Server, err = streaming.NewServer(config, s.fakeRuntime) + if err != nil { + return nil, err + } + return s, nil } func (fk *fakeKubelet) GetExec(podFullName string, podUID types.UID, containerName string, cmd []string, streamOpts remotecommandserver.Options) (*url.URL, error) { - return fk.redirectURL, nil + if fk.getExecCheck != nil { + fk.getExecCheck(podFullName, podUID, containerName, cmd, streamOpts) + } + // Always use testContainerID + resp, err := fk.streamingRuntime.GetExec(&runtimeapi.ExecRequest{ + ContainerId: testContainerID, + Cmd: cmd, + Tty: streamOpts.TTY, + Stdin: streamOpts.Stdin, + Stdout: streamOpts.Stdout, + Stderr: streamOpts.Stderr, + }) + if err != nil { + return nil, err + } + return url.Parse(resp.GetUrl()) } func (fk *fakeKubelet) GetAttach(podFullName string, podUID types.UID, containerName string, streamOpts remotecommandserver.Options) (*url.URL, error) { - return fk.redirectURL, nil + if fk.getAttachCheck != nil { + fk.getAttachCheck(podFullName, podUID, containerName, streamOpts) + } + // Always use testContainerID + resp, err := fk.streamingRuntime.GetAttach(&runtimeapi.AttachRequest{ + ContainerId: testContainerID, + Tty: streamOpts.TTY, + Stdin: streamOpts.Stdin, + Stdout: streamOpts.Stdout, + Stderr: streamOpts.Stderr, + }) + if err != nil { + return nil, err + } + return url.Parse(resp.GetUrl()) } func (fk *fakeKubelet) GetPortForward(podName, podNamespace string, podUID types.UID, portForwardOpts portforward.V4Options) (*url.URL, error) { - return fk.redirectURL, nil -} - -func (fk *fakeKubelet) StreamingConnectionIdleTimeout() time.Duration { - return fk.streamingConnectionIdleTimeoutFunc() + if fk.getPortForwardCheck != nil { + fk.getPortForwardCheck(podName, podNamespace, podUID, portForwardOpts) + } + // Always use testPodSandboxID + resp, err := fk.streamingRuntime.GetPortForward(&runtimeapi.PortForwardRequest{ + PodSandboxId: testPodSandboxID, + Port: portForwardOpts.Ports, + }) + if err != nil { + return nil, err + } + return url.Parse(resp.GetUrl()) } // Unused functions @@ -198,17 +279,20 @@ func (f *fakeAuth) Authorize(a authorizer.Attributes) (authorized authorizer.Dec } type serverTestFramework struct { - serverUnderTest *Server - fakeKubelet *fakeKubelet - fakeAuth *fakeAuth - testHTTPServer *httptest.Server + serverUnderTest *Server + fakeKubelet *fakeKubelet + fakeAuth *fakeAuth + testHTTPServer *httptest.Server + fakeRuntime *fakeRuntime + testStreamingHTTPServer *httptest.Server + criHandler *utiltesting.FakeHandler } func newServerTest() *serverTestFramework { - return newServerTestWithDebug(true) + return newServerTestWithDebug(true, false, nil) } -func newServerTestWithDebug(enableDebugging bool) *serverTestFramework { +func newServerTestWithDebug(enableDebugging, redirectContainerStreaming bool, streamingServer streaming.Server) *serverTestFramework { fw := &serverTestFramework{} fw.fakeKubelet = &fakeKubelet{ hostnameFunc: func() string { @@ -223,7 +307,8 @@ func newServerTestWithDebug(enableDebugging bool) *serverTestFramework { }, }, true }, - plegHealth: true, + plegHealth: true, + streamingRuntime: streamingServer, } fw.fakeAuth = &fakeAuth{ authenticateFunc: func(req *http.Request) (user.Info, bool, error) { @@ -236,13 +321,17 @@ func newServerTestWithDebug(enableDebugging bool) *serverTestFramework { return authorizer.DecisionAllow, "", nil }, } + fw.criHandler = &utiltesting.FakeHandler{ + StatusCode: http.StatusOK, + } server := NewServer( fw.fakeKubelet, stats.NewResourceAnalyzer(fw.fakeKubelet, time.Minute), fw.fakeAuth, enableDebugging, false, - &kubecontainertesting.Mock{}) + redirectContainerStreaming, + fw.criHandler) fw.serverUnderTest = &server fw.testHTTPServer = httptest.NewServer(fw.serverUnderTest) return fw @@ -1064,13 +1153,12 @@ func TestContainerLogsWithFollow(t *testing.T) { } func TestServeExecInContainerIdleTimeout(t *testing.T) { - fw := newServerTest() + ss, err := newTestStreamingServer(100 * time.Millisecond) + require.NoError(t, err) + defer ss.testHTTPServer.Close() + fw := newServerTestWithDebug(true, false, ss) defer fw.testHTTPServer.Close() - fw.fakeKubelet.streamingConnectionIdleTimeoutFunc = func() time.Duration { - return 100 * time.Millisecond - } - podNamespace := "other" podName := "foo" expectedContainerName := "baz" @@ -1102,280 +1190,221 @@ func TestServeExecInContainerIdleTimeout(t *testing.T) { } func testExecAttach(t *testing.T, verb string) { - tests := []struct { + tests := map[string]struct { stdin bool stdout bool stderr bool tty bool responseStatusCode int uid bool - responseLocation string + redirect bool }{ - {responseStatusCode: http.StatusBadRequest}, - {stdin: true, responseStatusCode: http.StatusSwitchingProtocols}, - {stdout: true, responseStatusCode: http.StatusSwitchingProtocols}, - {stderr: true, responseStatusCode: http.StatusSwitchingProtocols}, - {stdout: true, stderr: true, responseStatusCode: http.StatusSwitchingProtocols}, - {stdout: true, stderr: true, tty: true, responseStatusCode: http.StatusSwitchingProtocols}, - {stdin: true, stdout: true, stderr: true, responseStatusCode: http.StatusSwitchingProtocols}, - {stdout: true, responseStatusCode: http.StatusFound, responseLocation: "http://localhost:12345/" + verb}, + "no input or output": {responseStatusCode: http.StatusBadRequest}, + "stdin": {stdin: true, responseStatusCode: http.StatusSwitchingProtocols}, + "stdout": {stdout: true, responseStatusCode: http.StatusSwitchingProtocols}, + "stderr": {stderr: true, responseStatusCode: http.StatusSwitchingProtocols}, + "stdout and stderr": {stdout: true, stderr: true, responseStatusCode: http.StatusSwitchingProtocols}, + "stdout stderr and tty": {stdout: true, stderr: true, tty: true, responseStatusCode: http.StatusSwitchingProtocols}, + "stdin stdout and stderr": {stdin: true, stdout: true, stderr: true, responseStatusCode: http.StatusSwitchingProtocols}, + "stdin stdout stderr with uid": {stdin: true, stdout: true, stderr: true, responseStatusCode: http.StatusSwitchingProtocols, uid: true}, + "stdout with redirect": {stdout: true, responseStatusCode: http.StatusFound, redirect: true}, } - for i, test := range tests { - fw := newServerTest() - defer fw.testHTTPServer.Close() - - fw.fakeKubelet.streamingConnectionIdleTimeoutFunc = func() time.Duration { - return 0 - } - - if test.responseLocation != "" { - var err error - fw.fakeKubelet.redirectURL, err = url.Parse(test.responseLocation) + for desc, test := range tests { + test := test + t.Run(desc, func(t *testing.T) { + ss, err := newTestStreamingServer(0) require.NoError(t, err) - } + defer ss.testHTTPServer.Close() + fw := newServerTestWithDebug(true, test.redirect, ss) + defer fw.testHTTPServer.Close() + fmt.Println(desc) - podNamespace := "other" - podName := "foo" - expectedPodName := getPodName(podName, podNamespace) - expectedContainerName := "baz" - expectedCommand := "ls -a" - expectedStdin := "stdin" - expectedStdout := "stdout" - expectedStderr := "stderr" - done := make(chan struct{}) - clientStdoutReadDone := make(chan struct{}) - clientStderrReadDone := make(chan struct{}) - execInvoked := false - attachInvoked := false + podNamespace := "other" + podName := "foo" + expectedPodName := getPodName(podName, podNamespace) + expectedContainerName := "baz" + expectedCommand := "ls -a" + expectedStdin := "stdin" + expectedStdout := "stdout" + expectedStderr := "stderr" + done := make(chan struct{}) + clientStdoutReadDone := make(chan struct{}) + clientStderrReadDone := make(chan struct{}) + execInvoked := false + attachInvoked := false - testStreamFunc := func(podFullName string, uid types.UID, containerName string, cmd []string, in io.Reader, out, stderr io.WriteCloser, tty bool, done chan struct{}) error { - defer close(done) + checkStream := func(podFullName string, uid types.UID, containerName string, streamOpts remotecommandserver.Options) { + assert.Equal(t, expectedPodName, podFullName, "podFullName") + if test.uid { + assert.Equal(t, testUID, string(uid), "uid") + } + assert.Equal(t, expectedContainerName, containerName, "containerName") + assert.Equal(t, test.stdin, streamOpts.Stdin, "stdin") + assert.Equal(t, test.stdout, streamOpts.Stdout, "stdout") + assert.Equal(t, test.tty, streamOpts.TTY, "tty") + assert.Equal(t, !test.tty && test.stderr, streamOpts.Stderr, "stderr") + } - if podFullName != expectedPodName { - t.Fatalf("%d: podFullName: expected %s, got %s", i, expectedPodName, podFullName) + fw.fakeKubelet.getExecCheck = func(podFullName string, uid types.UID, containerName string, cmd []string, streamOpts remotecommandserver.Options) { + execInvoked = true + assert.Equal(t, expectedCommand, strings.Join(cmd, " "), "cmd") + checkStream(podFullName, uid, containerName, streamOpts) } - if test.uid && string(uid) != testUID { - t.Fatalf("%d: uid: expected %v, got %v", i, testUID, uid) + + fw.fakeKubelet.getAttachCheck = func(podFullName string, uid types.UID, containerName string, streamOpts remotecommandserver.Options) { + attachInvoked = true + checkStream(podFullName, uid, containerName, streamOpts) } - if containerName != expectedContainerName { - t.Fatalf("%d: containerName: expected %s, got %s", i, expectedContainerName, containerName) + + testStream := func(containerID string, in io.Reader, out, stderr io.WriteCloser, tty bool, done chan struct{}) error { + close(done) + assert.Equal(t, testContainerID, containerID, "containerID") + assert.Equal(t, test.tty, tty, "tty") + require.Equal(t, test.stdin, in != nil, "in") + require.Equal(t, test.stdout, out != nil, "out") + require.Equal(t, !test.tty && test.stderr, stderr != nil, "err") + + if test.stdin { + b := make([]byte, 10) + n, err := in.Read(b) + assert.NoError(t, err, "reading from stdin") + assert.Equal(t, expectedStdin, string(b[0:n]), "content from stdin") + } + + if test.stdout { + _, err := out.Write([]byte(expectedStdout)) + assert.NoError(t, err, "writing to stdout") + out.Close() + <-clientStdoutReadDone + } + + if !test.tty && test.stderr { + _, err := stderr.Write([]byte(expectedStderr)) + assert.NoError(t, err, "writing to stderr") + stderr.Close() + <-clientStderrReadDone + } + return nil } + ss.fakeRuntime.execFunc = func(containerID string, cmd []string, stdin io.Reader, stdout, stderr io.WriteCloser, tty bool, resize <-chan remotecommand.TerminalSize) error { + assert.Equal(t, expectedCommand, strings.Join(cmd, " "), "cmd") + return testStream(containerID, stdin, stdout, stderr, tty, done) + } + + ss.fakeRuntime.attachFunc = func(containerID string, stdin io.Reader, stdout, stderr io.WriteCloser, tty bool, resize <-chan remotecommand.TerminalSize) error { + return testStream(containerID, stdin, stdout, stderr, tty, done) + } + + var url string + if test.uid { + url = fw.testHTTPServer.URL + "/" + verb + "/" + podNamespace + "/" + podName + "/" + testUID + "/" + expectedContainerName + "?ignore=1" + } else { + url = fw.testHTTPServer.URL + "/" + verb + "/" + podNamespace + "/" + podName + "/" + expectedContainerName + "?ignore=1" + } + if verb == "exec" { + url += "&command=ls&command=-a" + } + if test.stdin { + url += "&" + api.ExecStdinParam + "=1" + } + if test.stdout { + url += "&" + api.ExecStdoutParam + "=1" + } + if test.stderr && !test.tty { + url += "&" + api.ExecStderrParam + "=1" + } + if test.tty { + url += "&" + api.ExecTTYParam + "=1" + } + + var ( + resp *http.Response + upgradeRoundTripper httpstream.UpgradeRoundTripper + c *http.Client + ) + if test.redirect { + c = &http.Client{} + // Don't follow redirects, since we want to inspect the redirect response. + c.CheckRedirect = func(*http.Request, []*http.Request) error { + return http.ErrUseLastResponse + } + } else { + upgradeRoundTripper = spdy.NewRoundTripper(nil, true) + c = &http.Client{Transport: upgradeRoundTripper} + } + + resp, err = c.Post(url, "", nil) + require.NoError(t, err, "POSTing") + defer resp.Body.Close() + + _, err = ioutil.ReadAll(resp.Body) + assert.NoError(t, err, "reading response body") + + require.Equal(t, test.responseStatusCode, resp.StatusCode, "response status") + if test.responseStatusCode != http.StatusSwitchingProtocols { + return + } + + conn, err := upgradeRoundTripper.NewConnection(resp) + require.NoError(t, err, "creating streaming connection") + defer conn.Close() + + h := http.Header{} + h.Set(api.StreamType, api.StreamTypeError) + _, err = conn.CreateStream(h) + require.NoError(t, err, "creating error stream") + if test.stdin { - if in == nil { - t.Fatalf("%d: stdin: expected non-nil", i) - } - b := make([]byte, 10) - n, err := in.Read(b) - if err != nil { - t.Fatalf("%d: error reading from stdin: %v", i, err) - } - if e, a := expectedStdin, string(b[0:n]); e != a { - t.Fatalf("%d: stdin: expected to read %v, got %v", i, e, a) - } - } else if in != nil { - t.Fatalf("%d: stdin: expected nil: %#v", i, in) + h.Set(api.StreamType, api.StreamTypeStdin) + stream, err := conn.CreateStream(h) + require.NoError(t, err, "creating stdin stream") + _, err = stream.Write([]byte(expectedStdin)) + require.NoError(t, err, "writing to stdin stream") + } + + var stdoutStream httpstream.Stream + if test.stdout { + h.Set(api.StreamType, api.StreamTypeStdout) + stdoutStream, err = conn.CreateStream(h) + require.NoError(t, err, "creating stdout stream") + } + + var stderrStream httpstream.Stream + if test.stderr && !test.tty { + h.Set(api.StreamType, api.StreamTypeStderr) + stderrStream, err = conn.CreateStream(h) + require.NoError(t, err, "creating stderr stream") } if test.stdout { - if out == nil { - t.Fatalf("%d: stdout: expected non-nil", i) - } - _, err := out.Write([]byte(expectedStdout)) - if err != nil { - t.Fatalf("%d:, error writing to stdout: %v", i, err) - } - out.Close() - <-clientStdoutReadDone - } else if out != nil { - t.Fatalf("%d: stdout: expected nil: %#v", i, out) + output := make([]byte, 10) + n, err := stdoutStream.Read(output) + close(clientStdoutReadDone) + assert.NoError(t, err, "reading from stdout stream") + assert.Equal(t, expectedStdout, string(output[0:n]), "stdout") } - if tty { - if stderr != nil { - t.Fatalf("%d: tty set but received non-nil stderr: %v", i, stderr) - } - } else if test.stderr { - if stderr == nil { - t.Fatalf("%d: stderr: expected non-nil", i) - } - _, err := stderr.Write([]byte(expectedStderr)) - if err != nil { - t.Fatalf("%d:, error writing to stderr: %v", i, err) - } - stderr.Close() - <-clientStderrReadDone - } else if stderr != nil { - t.Fatalf("%d: stderr: expected nil: %#v", i, stderr) + if test.stderr && !test.tty { + output := make([]byte, 10) + n, err := stderrStream.Read(output) + close(clientStderrReadDone) + assert.NoError(t, err, "reading from stderr stream") + assert.Equal(t, expectedStderr, string(output[0:n]), "stderr") } - return nil - } + // wait for the server to finish before checking if the attach/exec funcs were invoked + <-done - fw.fakeKubelet.execFunc = func(podFullName string, uid types.UID, containerName string, cmd []string, in io.Reader, out, stderr io.WriteCloser, tty bool) error { - execInvoked = true - if strings.Join(cmd, " ") != expectedCommand { - t.Fatalf("%d: cmd: expected: %s, got %v", i, expectedCommand, cmd) + if verb == "exec" { + assert.True(t, execInvoked, "exec should be invoked") + assert.False(t, attachInvoked, "attach should not be invoked") + } else { + assert.True(t, attachInvoked, "attach should be invoked") + assert.False(t, execInvoked, "exec should not be invoked") } - return testStreamFunc(podFullName, uid, containerName, cmd, in, out, stderr, tty, done) - } - - fw.fakeKubelet.attachFunc = func(podFullName string, uid types.UID, containerName string, in io.Reader, out, stderr io.WriteCloser, tty bool) error { - attachInvoked = true - return testStreamFunc(podFullName, uid, containerName, nil, in, out, stderr, tty, done) - } - - var url string - if test.uid { - url = fw.testHTTPServer.URL + "/" + verb + "/" + podNamespace + "/" + podName + "/" + testUID + "/" + expectedContainerName + "?ignore=1" - } else { - url = fw.testHTTPServer.URL + "/" + verb + "/" + podNamespace + "/" + podName + "/" + expectedContainerName + "?ignore=1" - } - if verb == "exec" { - url += "&command=ls&command=-a" - } - if test.stdin { - url += "&" + api.ExecStdinParam + "=1" - } - if test.stdout { - url += "&" + api.ExecStdoutParam + "=1" - } - if test.stderr && !test.tty { - url += "&" + api.ExecStderrParam + "=1" - } - if test.tty { - url += "&" + api.ExecTTYParam + "=1" - } - - var ( - resp *http.Response - err error - upgradeRoundTripper httpstream.UpgradeRoundTripper - c *http.Client - ) - - if test.responseStatusCode != http.StatusSwitchingProtocols { - c = &http.Client{} - // Don't follow redirects, since we want to inspect the redirect response. - c.CheckRedirect = func(*http.Request, []*http.Request) error { - return http.ErrUseLastResponse - } - } else { - upgradeRoundTripper = spdy.NewRoundTripper(nil, true) - c = &http.Client{Transport: upgradeRoundTripper} - } - - resp, err = c.Post(url, "", nil) - if err != nil { - t.Fatalf("%d: Got error POSTing: %v", i, err) - } - defer resp.Body.Close() - - _, err = ioutil.ReadAll(resp.Body) - if err != nil { - t.Errorf("%d: Error reading response body: %v", i, err) - } - - if e, a := test.responseStatusCode, resp.StatusCode; e != a { - t.Fatalf("%d: response status: expected %v, got %v", i, e, a) - } - - if e, a := test.responseLocation, resp.Header.Get("Location"); e != a { - t.Errorf("%d: response location: expected %v, got %v", i, e, a) - } - - if test.responseStatusCode != http.StatusSwitchingProtocols { - continue - } - - conn, err := upgradeRoundTripper.NewConnection(resp) - if err != nil { - t.Fatalf("Unexpected error creating streaming connection: %s", err) - } - if conn == nil { - t.Fatalf("%d: unexpected nil conn", i) - } - defer conn.Close() - - h := http.Header{} - h.Set(api.StreamType, api.StreamTypeError) - if _, err := conn.CreateStream(h); err != nil { - t.Fatalf("%d: error creating error stream: %v", i, err) - } - - if test.stdin { - h.Set(api.StreamType, api.StreamTypeStdin) - stream, err := conn.CreateStream(h) - if err != nil { - t.Fatalf("%d: error creating stdin stream: %v", i, err) - } - _, err = stream.Write([]byte(expectedStdin)) - if err != nil { - t.Fatalf("%d: error writing to stdin stream: %v", i, err) - } - } - - var stdoutStream httpstream.Stream - if test.stdout { - h.Set(api.StreamType, api.StreamTypeStdout) - stdoutStream, err = conn.CreateStream(h) - if err != nil { - t.Fatalf("%d: error creating stdout stream: %v", i, err) - } - } - - var stderrStream httpstream.Stream - if test.stderr && !test.tty { - h.Set(api.StreamType, api.StreamTypeStderr) - stderrStream, err = conn.CreateStream(h) - if err != nil { - t.Fatalf("%d: error creating stderr stream: %v", i, err) - } - } - - if test.stdout { - output := make([]byte, 10) - n, err := stdoutStream.Read(output) - close(clientStdoutReadDone) - if err != nil { - t.Fatalf("%d: error reading from stdout stream: %v", i, err) - } - if e, a := expectedStdout, string(output[0:n]); e != a { - t.Fatalf("%d: stdout: expected '%v', got '%v'", i, e, a) - } - } - - if test.stderr && !test.tty { - output := make([]byte, 10) - n, err := stderrStream.Read(output) - close(clientStderrReadDone) - if err != nil { - t.Fatalf("%d: error reading from stderr stream: %v", i, err) - } - if e, a := expectedStderr, string(output[0:n]); e != a { - t.Fatalf("%d: stderr: expected '%v', got '%v'", i, e, a) - } - } - - // wait for the server to finish before checking if the attach/exec funcs were invoked - <-done - - if verb == "exec" { - if !execInvoked { - t.Errorf("%d: exec was not invoked", i) - } - if attachInvoked { - t.Errorf("%d: attach should not have been invoked", i) - } - } else { - if !attachInvoked { - t.Errorf("%d: attach was not invoked", i) - } - if execInvoked { - t.Errorf("%d: exec should not have been invoked", i) - } - } + }) } } @@ -1388,13 +1417,12 @@ func TestServeAttachContainer(t *testing.T) { } func TestServePortForwardIdleTimeout(t *testing.T) { - fw := newServerTest() + ss, err := newTestStreamingServer(100 * time.Millisecond) + require.NoError(t, err) + defer ss.testHTTPServer.Close() + fw := newServerTestWithDebug(true, false, ss) defer fw.testHTTPServer.Close() - fw.fakeKubelet.streamingConnectionIdleTimeoutFunc = func() time.Duration { - return 100 * time.Millisecond - } - podNamespace := "other" podName := "foo" @@ -1422,179 +1450,160 @@ func TestServePortForwardIdleTimeout(t *testing.T) { } func TestServePortForward(t *testing.T) { - tests := []struct { - port string - uid bool - clientData string - containerData string - shouldError bool - responseLocation string + tests := map[string]struct { + port string + uid bool + clientData string + containerData string + redirect bool + shouldError bool }{ - {port: "", shouldError: true}, - {port: "abc", shouldError: true}, - {port: "-1", shouldError: true}, - {port: "65536", shouldError: true}, - {port: "0", shouldError: true}, - {port: "1", shouldError: false}, - {port: "8000", shouldError: false}, - {port: "8000", clientData: "client data", containerData: "container data", shouldError: false}, - {port: "65535", shouldError: false}, - {port: "65535", uid: true, shouldError: false}, - {port: "65535", responseLocation: "http://localhost:12345/portforward", shouldError: false}, + "no port": {port: "", shouldError: true}, + "none number port": {port: "abc", shouldError: true}, + "negative port": {port: "-1", shouldError: true}, + "too large port": {port: "65536", shouldError: true}, + "0 port": {port: "0", shouldError: true}, + "min port": {port: "1", shouldError: false}, + "normal port": {port: "8000", shouldError: false}, + "normal port with data forward": {port: "8000", clientData: "client data", containerData: "container data", shouldError: false}, + "max port": {port: "65535", shouldError: false}, + "normal port with uid": {port: "8000", uid: true, shouldError: false}, + "normal port with redirect": {port: "8000", redirect: true, shouldError: false}, } podNamespace := "other" podName := "foo" - expectedPodName := getPodName(podName, podNamespace) - for i, test := range tests { - fw := newServerTest() - defer fw.testHTTPServer.Close() - - fw.fakeKubelet.streamingConnectionIdleTimeoutFunc = func() time.Duration { - return 0 - } - - if test.responseLocation != "" { - var err error - fw.fakeKubelet.redirectURL, err = url.Parse(test.responseLocation) + for desc, test := range tests { + test := test + t.Run(desc, func(t *testing.T) { + ss, err := newTestStreamingServer(0) require.NoError(t, err) - } + defer ss.testHTTPServer.Close() + fw := newServerTestWithDebug(true, test.redirect, ss) + defer fw.testHTTPServer.Close() - portForwardFuncDone := make(chan struct{}) + portForwardFuncDone := make(chan struct{}) - fw.fakeKubelet.portForwardFunc = func(name string, uid types.UID, port int32, stream io.ReadWriteCloser) error { - defer close(portForwardFuncDone) - - if e, a := expectedPodName, name; e != a { - t.Fatalf("%d: pod name: expected '%v', got '%v'", i, e, a) + fw.fakeKubelet.getPortForwardCheck = func(name, namespace string, uid types.UID, opts portforward.V4Options) { + assert.Equal(t, podName, name, "pod name") + assert.Equal(t, podNamespace, namespace, "pod namespace") + if test.uid { + assert.Equal(t, testUID, string(uid), "uid") + } } - if e, a := testUID, uid; test.uid && e != string(a) { - t.Fatalf("%d: uid: expected '%v', got '%v'", i, e, a) + ss.fakeRuntime.portForwardFunc = func(podSandboxID string, port int32, stream io.ReadWriteCloser) error { + defer close(portForwardFuncDone) + assert.Equal(t, testPodSandboxID, podSandboxID, "pod sandbox id") + // The port should be valid if it reaches here. + testPort, err := strconv.ParseInt(test.port, 10, 32) + require.NoError(t, err, "parse port") + assert.Equal(t, int32(testPort), port, "port") + + if test.clientData != "" { + fromClient := make([]byte, 32) + n, err := stream.Read(fromClient) + assert.NoError(t, err, "reading client data") + assert.Equal(t, test.clientData, string(fromClient[0:n]), "client data") + } + + if test.containerData != "" { + _, err := stream.Write([]byte(test.containerData)) + assert.NoError(t, err, "writing container data") + } + + return nil } - p, err := strconv.ParseInt(test.port, 10, 32) - if err != nil { - t.Fatalf("%d: error parsing port string '%s': %v", i, test.port, err) + var url string + if test.uid { + url = fmt.Sprintf("%s/portForward/%s/%s/%s", fw.testHTTPServer.URL, podNamespace, podName, testUID) + } else { + url = fmt.Sprintf("%s/portForward/%s/%s", fw.testHTTPServer.URL, podNamespace, podName) } - if e, a := int32(p), port; e != a { - t.Fatalf("%d: port: expected '%v', got '%v'", i, e, a) + + var ( + upgradeRoundTripper httpstream.UpgradeRoundTripper + c *http.Client + ) + + if test.redirect { + c = &http.Client{} + // Don't follow redirects, since we want to inspect the redirect response. + c.CheckRedirect = func(*http.Request, []*http.Request) error { + return http.ErrUseLastResponse + } + } else { + upgradeRoundTripper = spdy.NewRoundTripper(nil, true) + c = &http.Client{Transport: upgradeRoundTripper} } + resp, err := c.Post(url, "", nil) + require.NoError(t, err, "POSTing") + defer resp.Body.Close() + + if test.redirect { + assert.Equal(t, http.StatusFound, resp.StatusCode, "status code") + return + } else { + assert.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode, "status code") + } + + conn, err := upgradeRoundTripper.NewConnection(resp) + require.NoError(t, err, "creating streaming connection") + defer conn.Close() + + headers := http.Header{} + headers.Set("streamType", "error") + headers.Set("port", test.port) + _, err = conn.CreateStream(headers) + assert.Equal(t, test.shouldError, err != nil, "expect error") + + if test.shouldError { + return + } + + headers.Set("streamType", "data") + headers.Set("port", test.port) + dataStream, err := conn.CreateStream(headers) + require.NoError(t, err, "create stream") + if test.clientData != "" { - fromClient := make([]byte, 32) - n, err := stream.Read(fromClient) - if err != nil { - t.Fatalf("%d: error reading client data: %v", i, err) - } - if e, a := test.clientData, string(fromClient[0:n]); e != a { - t.Fatalf("%d: client data: expected to receive '%v', got '%v'", i, e, a) - } + _, err := dataStream.Write([]byte(test.clientData)) + assert.NoError(t, err, "writing client data") } if test.containerData != "" { - _, err := stream.Write([]byte(test.containerData)) - if err != nil { - t.Fatalf("%d: error writing container data: %v", i, err) - } + fromContainer := make([]byte, 32) + n, err := dataStream.Read(fromContainer) + assert.NoError(t, err, "reading container data") + assert.Equal(t, test.containerData, string(fromContainer[0:n]), "container data") } - return nil - } - - var url string - if test.uid { - url = fmt.Sprintf("%s/portForward/%s/%s/%s", fw.testHTTPServer.URL, podNamespace, podName, testUID) - } else { - url = fmt.Sprintf("%s/portForward/%s/%s", fw.testHTTPServer.URL, podNamespace, podName) - } - - var ( - upgradeRoundTripper httpstream.UpgradeRoundTripper - c *http.Client - ) - - if len(test.responseLocation) > 0 { - c = &http.Client{} - // Don't follow redirects, since we want to inspect the redirect response. - c.CheckRedirect = func(*http.Request, []*http.Request) error { - return http.ErrUseLastResponse - } - } else { - upgradeRoundTripper = spdy.NewRoundTripper(nil, true) - c = &http.Client{Transport: upgradeRoundTripper} - } - - resp, err := c.Post(url, "", nil) - if err != nil { - t.Fatalf("%d: Got error POSTing: %v", i, err) - } - defer resp.Body.Close() - - if test.responseLocation != "" { - assert.Equal(t, http.StatusFound, resp.StatusCode, "%d: status code", i) - assert.Equal(t, test.responseLocation, resp.Header.Get("Location"), "%d: location", i) - continue - } else { - assert.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode, "%d: status code", i) - } - - conn, err := upgradeRoundTripper.NewConnection(resp) - if err != nil { - t.Fatalf("Unexpected error creating streaming connection: %s", err) - } - if conn == nil { - t.Fatalf("%d: Unexpected nil connection", i) - } - defer conn.Close() - - headers := http.Header{} - headers.Set("streamType", "error") - headers.Set("port", test.port) - errorStream, err := conn.CreateStream(headers) - _ = errorStream - haveErr := err != nil - if e, a := test.shouldError, haveErr; e != a { - t.Fatalf("%d: create stream: expected err=%t, got %t: %v", i, e, a, err) - } - - if test.shouldError { - continue - } - - headers.Set("streamType", "data") - headers.Set("port", test.port) - dataStream, err := conn.CreateStream(headers) - haveErr = err != nil - if e, a := test.shouldError, haveErr; e != a { - t.Fatalf("%d: create stream: expected err=%t, got %t: %v", i, e, a, err) - } - - if test.clientData != "" { - _, err := dataStream.Write([]byte(test.clientData)) - if err != nil { - t.Fatalf("%d: unexpected error writing client data: %v", i, err) - } - } - - if test.containerData != "" { - fromContainer := make([]byte, 32) - n, err := dataStream.Read(fromContainer) - if err != nil { - t.Fatalf("%d: unexpected error reading container data: %v", i, err) - } - if e, a := test.containerData, string(fromContainer[0:n]); e != a { - t.Fatalf("%d: expected to receive '%v' from container, got '%v'", i, e, a) - } - } - - <-portForwardFuncDone + <-portForwardFuncDone + }) } } +func TestCRIHandler(t *testing.T) { + fw := newServerTest() + defer fw.testHTTPServer.Close() + + const ( + path = "/cri/exec/123456abcdef" + query = "cmd=echo+foo" + ) + resp, err := http.Get(fw.testHTTPServer.URL + path + "?" + query) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, "GET", fw.criHandler.RequestReceived.Method) + assert.Equal(t, path, fw.criHandler.RequestReceived.URL.Path) + assert.Equal(t, query, fw.criHandler.RequestReceived.URL.RawQuery) +} + func TestDebuggingDisabledHandlers(t *testing.T) { - fw := newServerTestWithDebug(false) + fw := newServerTestWithDebug(false, false, nil) defer fw.testHTTPServer.Close() paths := []string{ diff --git a/pkg/kubelet/server/server_websocket_test.go b/pkg/kubelet/server/server_websocket_test.go index 058b67d978a..daf6d356b63 100644 --- a/pkg/kubelet/server/server_websocket_test.go +++ b/pkg/kubelet/server/server_websocket_test.go @@ -23,11 +23,13 @@ import ( "strconv" "sync" "testing" - "time" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "golang.org/x/net/websocket" "k8s.io/apimachinery/pkg/types" + "k8s.io/kubernetes/pkg/kubelet/server/portforward" ) const ( @@ -36,152 +38,114 @@ const ( ) func TestServeWSPortForward(t *testing.T) { - tests := []struct { + tests := map[string]struct { port string uid bool clientData string containerData string shouldError bool }{ - {port: "", shouldError: true}, - {port: "abc", shouldError: true}, - {port: "-1", shouldError: true}, - {port: "65536", shouldError: true}, - {port: "0", shouldError: true}, - {port: "1", shouldError: false}, - {port: "8000", shouldError: false}, - {port: "8000", clientData: "client data", containerData: "container data", shouldError: false}, - {port: "65535", shouldError: false}, - {port: "65535", uid: true, shouldError: false}, + "no port": {port: "", shouldError: true}, + "none number port": {port: "abc", shouldError: true}, + "negative port": {port: "-1", shouldError: true}, + "too large port": {port: "65536", shouldError: true}, + "0 port": {port: "0", shouldError: true}, + "min port": {port: "1", shouldError: false}, + "normal port": {port: "8000", shouldError: false}, + "normal port with data forward": {port: "8000", clientData: "client data", containerData: "container data", shouldError: false}, + "max port": {port: "65535", shouldError: false}, + "normal port with uid": {port: "8000", uid: true, shouldError: false}, } podNamespace := "other" podName := "foo" - expectedPodName := getPodName(podName, podNamespace) - expectedUid := "9b01b80f-8fb4-11e4-95ab-4200af06647" - for i, test := range tests { - fw := newServerTest() - defer fw.testHTTPServer.Close() + for desc, test := range tests { + test := test + t.Run(desc, func(t *testing.T) { + ss, err := newTestStreamingServer(0) + require.NoError(t, err) + defer ss.testHTTPServer.Close() + fw := newServerTestWithDebug(true, false, ss) + defer fw.testHTTPServer.Close() - fw.fakeKubelet.streamingConnectionIdleTimeoutFunc = func() time.Duration { - return 0 - } + portForwardFuncDone := make(chan struct{}) - portForwardFuncDone := make(chan struct{}) - - fw.fakeKubelet.portForwardFunc = func(name string, uid types.UID, port int32, stream io.ReadWriteCloser) error { - defer close(portForwardFuncDone) - - if e, a := expectedPodName, name; e != a { - t.Fatalf("%d: pod name: expected '%v', got '%v'", i, e, a) + fw.fakeKubelet.getPortForwardCheck = func(name, namespace string, uid types.UID, opts portforward.V4Options) { + assert.Equal(t, podName, name, "pod name") + assert.Equal(t, podNamespace, namespace, "pod namespace") + if test.uid { + assert.Equal(t, testUID, string(uid), "uid") + } } - if e, a := expectedUid, uid; test.uid && e != string(a) { - t.Fatalf("%d: uid: expected '%v', got '%v'", i, e, a) + ss.fakeRuntime.portForwardFunc = func(podSandboxID string, port int32, stream io.ReadWriteCloser) error { + defer close(portForwardFuncDone) + assert.Equal(t, testPodSandboxID, podSandboxID, "pod sandbox id") + // The port should be valid if it reaches here. + testPort, err := strconv.ParseInt(test.port, 10, 32) + require.NoError(t, err, "parse port") + assert.Equal(t, int32(testPort), port, "port") + + if test.clientData != "" { + fromClient := make([]byte, 32) + n, err := stream.Read(fromClient) + assert.NoError(t, err, "reading client data") + assert.Equal(t, test.clientData, string(fromClient[0:n]), "client data") + } + + if test.containerData != "" { + _, err := stream.Write([]byte(test.containerData)) + assert.NoError(t, err, "writing container data") + } + + return nil } - p, err := strconv.ParseInt(test.port, 10, 32) - if err != nil { - t.Fatalf("%d: error parsing port string '%s': %v", i, test.port, err) + var url string + if test.uid { + url = fmt.Sprintf("ws://%s/portForward/%s/%s/%s?port=%s", fw.testHTTPServer.Listener.Addr().String(), podNamespace, podName, testUID, test.port) + } else { + url = fmt.Sprintf("ws://%s/portForward/%s/%s?port=%s", fw.testHTTPServer.Listener.Addr().String(), podNamespace, podName, test.port) } - if e, a := int32(p), port; e != a { - t.Fatalf("%d: port: expected '%v', got '%v'", i, e, a) + + ws, err := websocket.Dial(url, "", "http://127.0.0.1/") + assert.Equal(t, test.shouldError, err != nil, "websocket dial") + if test.shouldError { + return } + defer ws.Close() + + p, err := strconv.ParseUint(test.port, 10, 16) + require.NoError(t, err, "parse port") + p16 := uint16(p) + + channel, data, err := wsRead(ws) + require.NoError(t, err, "read") + assert.Equal(t, dataChannel, int(channel), "channel") + assert.Len(t, data, binary.Size(p16), "data size") + assert.Equal(t, p16, binary.LittleEndian.Uint16(data), "data") + + channel, data, err = wsRead(ws) + assert.NoError(t, err, "read") + assert.Equal(t, errorChannel, int(channel), "channel") + assert.Len(t, data, binary.Size(p16), "data size") + assert.Equal(t, p16, binary.LittleEndian.Uint16(data), "data") if test.clientData != "" { - fromClient := make([]byte, 32) - n, err := stream.Read(fromClient) - if err != nil { - t.Fatalf("%d: error reading client data: %v", i, err) - } - if e, a := test.clientData, string(fromClient[0:n]); e != a { - t.Fatalf("%d: client data: expected to receive '%v', got '%v'", i, e, a) - } + println("writing the client data") + err := wsWrite(ws, dataChannel, []byte(test.clientData)) + assert.NoError(t, err, "writing client data") } if test.containerData != "" { - _, err := stream.Write([]byte(test.containerData)) - if err != nil { - t.Fatalf("%d: error writing container data: %v", i, err) - } + _, data, err = wsRead(ws) + assert.NoError(t, err, "reading container data") + assert.Equal(t, test.containerData, string(data), "container data") } - return nil - } - - var url string - if test.uid { - url = fmt.Sprintf("ws://%s/portForward/%s/%s/%s?port=%s", fw.testHTTPServer.Listener.Addr().String(), podNamespace, podName, expectedUid, test.port) - } else { - url = fmt.Sprintf("ws://%s/portForward/%s/%s?port=%s", fw.testHTTPServer.Listener.Addr().String(), podNamespace, podName, test.port) - } - - ws, err := websocket.Dial(url, "", "http://127.0.0.1/") - if test.shouldError { - if err == nil { - t.Fatalf("%d: websocket dial expected err", i) - } - continue - } else if err != nil { - t.Fatalf("%d: websocket dial unexpected err: %v", i, err) - } - - defer ws.Close() - - p, err := strconv.ParseUint(test.port, 10, 16) - if err != nil { - t.Fatalf("%d: error parsing port string '%s': %v", i, test.port, err) - } - p16 := uint16(p) - - channel, data, err := wsRead(ws) - if err != nil { - t.Fatalf("%d: read failed: expected no error: got %v", i, err) - } - if channel != dataChannel { - t.Fatalf("%d: wrong channel: got %q: expected %q", i, channel, dataChannel) - } - if len(data) != binary.Size(p16) { - t.Fatalf("%d: wrong data size: got %q: expected %d", i, data, binary.Size(p16)) - } - if e, a := p16, binary.LittleEndian.Uint16(data); e != a { - t.Fatalf("%d: wrong data: got %q: expected %s", i, data, test.port) - } - - channel, data, err = wsRead(ws) - if err != nil { - t.Fatalf("%d: read succeeded: expected no error: got %v", i, err) - } - if channel != errorChannel { - t.Fatalf("%d: wrong channel: got %q: expected %q", i, channel, errorChannel) - } - if len(data) != binary.Size(p16) { - t.Fatalf("%d: wrong data size: got %q: expected %d", i, data, binary.Size(p16)) - } - if e, a := p16, binary.LittleEndian.Uint16(data); e != a { - t.Fatalf("%d: wrong data: got %q: expected %s", i, data, test.port) - } - - if test.clientData != "" { - println("writing the client data") - err := wsWrite(ws, dataChannel, []byte(test.clientData)) - if err != nil { - t.Fatalf("%d: unexpected error writing client data: %v", i, err) - } - } - - if test.containerData != "" { - _, data, err = wsRead(ws) - if err != nil { - t.Fatalf("%d: unexpected error reading container data: %v", i, err) - } - if e, a := test.containerData, string(data); e != a { - t.Fatalf("%d: expected to receive '%v' from container, got '%v'", i, e, a) - } - } - - <-portForwardFuncDone + <-portForwardFuncDone + }) } } @@ -190,27 +154,27 @@ func TestServeWSMultiplePortForward(t *testing.T) { ports := []uint16{7000, 8000, 9000} podNamespace := "other" podName := "foo" - expectedPodName := getPodName(podName, podNamespace) - fw := newServerTest() + ss, err := newTestStreamingServer(0) + require.NoError(t, err) + defer ss.testHTTPServer.Close() + fw := newServerTestWithDebug(true, false, ss) defer fw.testHTTPServer.Close() - fw.fakeKubelet.streamingConnectionIdleTimeoutFunc = func() time.Duration { - return 0 - } - portForwardWG := sync.WaitGroup{} portForwardWG.Add(len(ports)) portsMutex := sync.Mutex{} portsForwarded := map[int32]struct{}{} - fw.fakeKubelet.portForwardFunc = func(name string, uid types.UID, port int32, stream io.ReadWriteCloser) error { - defer portForwardWG.Done() + fw.fakeKubelet.getPortForwardCheck = func(name, namespace string, uid types.UID, opts portforward.V4Options) { + assert.Equal(t, podName, name, "pod name") + assert.Equal(t, podNamespace, namespace, "pod namespace") + } - if e, a := expectedPodName, name; e != a { - t.Fatalf("%d: pod name: expected '%v', got '%v'", port, e, a) - } + ss.fakeRuntime.portForwardFunc = func(podSandboxID string, port int32, stream io.ReadWriteCloser) error { + defer portForwardWG.Done() + assert.Equal(t, testPodSandboxID, podSandboxID, "pod sandbox id") portsMutex.Lock() portsForwarded[port] = struct{}{} @@ -218,17 +182,11 @@ func TestServeWSMultiplePortForward(t *testing.T) { fromClient := make([]byte, 32) n, err := stream.Read(fromClient) - if err != nil { - t.Fatalf("%d: error reading client data: %v", port, err) - } - if e, a := fmt.Sprintf("client data on port %d", port), string(fromClient[0:n]); e != a { - t.Fatalf("%d: client data: expected to receive '%v', got '%v'", port, e, a) - } + assert.NoError(t, err, "reading client data") + assert.Equal(t, fmt.Sprintf("client data on port %d", port), string(fromClient[0:n]), "client data") _, err = stream.Write([]byte(fmt.Sprintf("container data on port %d", port))) - if err != nil { - t.Fatalf("%d: error writing container data: %v", port, err) - } + assert.NoError(t, err, "writing container data") return nil } @@ -239,70 +197,42 @@ func TestServeWSMultiplePortForward(t *testing.T) { } ws, err := websocket.Dial(url, "", "http://127.0.0.1/") - if err != nil { - t.Fatalf("websocket dial unexpected err: %v", err) - } + require.NoError(t, err, "websocket dial") defer ws.Close() for i, port := range ports { channel, data, err := wsRead(ws) - if err != nil { - t.Fatalf("%d: read failed: expected no error: got %v", i, err) - } - if int(channel) != i*2+dataChannel { - t.Fatalf("%d: wrong channel: got %q: expected %q", i, channel, i*2+dataChannel) - } - if len(data) != binary.Size(port) { - t.Fatalf("%d: wrong data size: got %q: expected %d", i, data, binary.Size(port)) - } - if e, a := port, binary.LittleEndian.Uint16(data); e != a { - t.Fatalf("%d: wrong data: got %q: expected %d", i, data, port) - } + assert.NoError(t, err, "port %d read", port) + assert.Equal(t, i*2+dataChannel, int(channel), "port %d channel", port) + assert.Len(t, data, binary.Size(port), "port %d data size", port) + assert.Equal(t, binary.LittleEndian.Uint16(data), port, "port %d data", port) channel, data, err = wsRead(ws) - if err != nil { - t.Fatalf("%d: read succeeded: expected no error: got %v", i, err) - } - if int(channel) != i*2+errorChannel { - t.Fatalf("%d: wrong channel: got %q: expected %q", i, channel, i*2+errorChannel) - } - if len(data) != binary.Size(port) { - t.Fatalf("%d: wrong data size: got %q: expected %d", i, data, binary.Size(port)) - } - if e, a := port, binary.LittleEndian.Uint16(data); e != a { - t.Fatalf("%d: wrong data: got %q: expected %d", i, data, port) - } + assert.NoError(t, err, "port %d read", port) + assert.Equal(t, i*2+errorChannel, int(channel), "port %d channel", port) + assert.Len(t, data, binary.Size(port), "port %d data size", port) + assert.Equal(t, binary.LittleEndian.Uint16(data), port, "port %d data", port) } for i, port := range ports { - println("writing the client data", port) + t.Logf("port %d writing the client data", port) err := wsWrite(ws, byte(i*2+dataChannel), []byte(fmt.Sprintf("client data on port %d", port))) - if err != nil { - t.Fatalf("%d: unexpected error writing client data: %v", i, err) - } + assert.NoError(t, err, "port %d write client data", port) channel, data, err := wsRead(ws) - if err != nil { - t.Fatalf("%d: unexpected error reading container data: %v", i, err) - } - - if int(channel) != i*2+dataChannel { - t.Fatalf("%d: wrong channel: got %q: expected %q", port, channel, i*2+dataChannel) - } - if e, a := fmt.Sprintf("container data on port %d", port), string(data); e != a { - t.Fatalf("%d: expected to receive '%v' from container, got '%v'", i, e, a) - } + assert.NoError(t, err, "port %d read container data", port) + assert.Equal(t, i*2+dataChannel, int(channel), "port %d channel", port) + assert.Equal(t, fmt.Sprintf("container data on port %d", port), string(data), "port %d container data", port) } portForwardWG.Wait() portsMutex.Lock() defer portsMutex.Unlock() - if len(ports) != len(portsForwarded) { - t.Fatalf("expected to forward %d ports; got %v", len(ports), portsForwarded) - } + assert.Len(t, portsForwarded, len(ports), "all ports forwarded") } + func wsWrite(conn *websocket.Conn, channel byte, data []byte) error { frame := make([]byte, len(data)+1) frame[0] = channel