mirror of
https://github.com/k3s-io/kubernetes.git
synced 2025-08-23 18:40:53 +00:00
Update unit test.
This commit is contained in:
parent
174b6d0e2f
commit
1eb721248b
@ -17,7 +17,6 @@ limitations under the License.
|
||||
package kubelet
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
@ -2095,7 +2094,7 @@ func (f *fakeReadWriteCloser) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestExec(t *testing.T) {
|
||||
func TestGetExec(t *testing.T) {
|
||||
const (
|
||||
podName = "podFoo"
|
||||
podNamespace = "nsFoo"
|
||||
@ -2106,9 +2105,6 @@ func TestExec(t *testing.T) {
|
||||
var (
|
||||
podFullName = kubecontainer.GetPodFullName(podWithUIDNameNs(podUID, podName, podNamespace))
|
||||
command = []string{"ls"}
|
||||
stdin = &bytes.Buffer{}
|
||||
stdout = &fakeReadWriteCloser{}
|
||||
stderr = &fakeReadWriteCloser{}
|
||||
)
|
||||
|
||||
testcases := []struct {
|
||||
@ -2161,22 +2157,16 @@ func TestExec(t *testing.T) {
|
||||
assert.NoError(t, err, description)
|
||||
assert.Equal(t, containertest.FakeHost, redirect.Host, description+": redirect")
|
||||
}
|
||||
|
||||
err = kubelet.ExecInContainer(tc.podFullName, podUID, tc.container, command, stdin, stdout, stderr, tty, nil, 0)
|
||||
assert.Error(t, err, description)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPortForward(t *testing.T) {
|
||||
func TestGetPortForward(t *testing.T) {
|
||||
const (
|
||||
podName = "podFoo"
|
||||
podNamespace = "nsFoo"
|
||||
podUID types.UID = "12345678"
|
||||
port int32 = 5000
|
||||
)
|
||||
var (
|
||||
stream = &fakeReadWriteCloser{}
|
||||
)
|
||||
|
||||
testcases := []struct {
|
||||
description string
|
||||
@ -2208,7 +2198,6 @@ func TestPortForward(t *testing.T) {
|
||||
}},
|
||||
}
|
||||
|
||||
podFullName := kubecontainer.GetPodFullName(podWithUIDNameNs(podUID, tc.podName, podNamespace))
|
||||
description := "streaming - " + tc.description
|
||||
fakeRuntime := &containertest.FakeStreamingRuntime{FakeRuntime: testKubelet.fakeRuntime}
|
||||
kubelet.containerRuntime = fakeRuntime
|
||||
@ -2221,9 +2210,6 @@ func TestPortForward(t *testing.T) {
|
||||
assert.NoError(t, err, description)
|
||||
assert.Equal(t, containertest.FakeHost, redirect.Host, description+": redirect")
|
||||
}
|
||||
|
||||
err = kubelet.PortForward(podFullName, podUID, port, stream)
|
||||
assert.Error(t, err, description)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -46,21 +46,25 @@ import (
|
||||
"k8s.io/apiserver/pkg/authentication/user"
|
||||
"k8s.io/apiserver/pkg/authorization/authorizer"
|
||||
"k8s.io/client-go/tools/remotecommand"
|
||||
utiltesting "k8s.io/client-go/util/testing"
|
||||
api "k8s.io/kubernetes/pkg/apis/core"
|
||||
runtimeapi "k8s.io/kubernetes/pkg/kubelet/apis/cri/runtime/v1alpha2"
|
||||
statsapi "k8s.io/kubernetes/pkg/kubelet/apis/stats/v1alpha1"
|
||||
// Do some initialization to decode the query parameters correctly.
|
||||
_ "k8s.io/kubernetes/pkg/apis/core/install"
|
||||
"k8s.io/kubernetes/pkg/kubelet/cm"
|
||||
kubecontainer "k8s.io/kubernetes/pkg/kubelet/container"
|
||||
kubecontainertesting "k8s.io/kubernetes/pkg/kubelet/container/testing"
|
||||
"k8s.io/kubernetes/pkg/kubelet/server/portforward"
|
||||
remotecommandserver "k8s.io/kubernetes/pkg/kubelet/server/remotecommand"
|
||||
"k8s.io/kubernetes/pkg/kubelet/server/stats"
|
||||
"k8s.io/kubernetes/pkg/kubelet/server/streaming"
|
||||
"k8s.io/kubernetes/pkg/volume"
|
||||
)
|
||||
|
||||
const (
|
||||
testUID = "9b01b80f-8fb4-11e4-95ab-4200af06647"
|
||||
testContainerID = "container789"
|
||||
testPodSandboxID = "pod0987"
|
||||
)
|
||||
|
||||
type fakeKubelet struct {
|
||||
@ -72,16 +76,16 @@ type fakeKubelet struct {
|
||||
runningPodsFunc func() ([]*v1.Pod, error)
|
||||
logFunc func(w http.ResponseWriter, req *http.Request)
|
||||
runFunc func(podFullName string, uid types.UID, containerName string, cmd []string) ([]byte, error)
|
||||
execFunc func(pod string, uid types.UID, container string, cmd []string, in io.Reader, out, err io.WriteCloser, tty bool) error
|
||||
attachFunc func(pod string, uid types.UID, container string, in io.Reader, out, err io.WriteCloser, tty bool) error
|
||||
portForwardFunc func(name string, uid types.UID, port int32, stream io.ReadWriteCloser) error
|
||||
getExecCheck func(string, types.UID, string, []string, remotecommandserver.Options)
|
||||
getAttachCheck func(string, types.UID, string, remotecommandserver.Options)
|
||||
getPortForwardCheck func(string, string, types.UID, portforward.V4Options)
|
||||
|
||||
containerLogsFunc func(podFullName, containerName string, logOptions *v1.PodLogOptions, stdout, stderr io.Writer) error
|
||||
streamingConnectionIdleTimeoutFunc func() time.Duration
|
||||
hostnameFunc func() string
|
||||
resyncInterval time.Duration
|
||||
loopEntryTime time.Time
|
||||
plegHealth bool
|
||||
redirectURL *url.URL
|
||||
streamingRuntime streaming.Server
|
||||
}
|
||||
|
||||
func (fk *fakeKubelet) ResyncInterval() time.Duration {
|
||||
@ -136,32 +140,109 @@ func (fk *fakeKubelet) RunInContainer(podFullName string, uid types.UID, contain
|
||||
return fk.runFunc(podFullName, uid, containerName, cmd)
|
||||
}
|
||||
|
||||
func (fk *fakeKubelet) ExecInContainer(name string, uid types.UID, container string, cmd []string, in io.Reader, out, err io.WriteCloser, tty bool, resize <-chan remotecommand.TerminalSize, timeout time.Duration) error {
|
||||
return fk.execFunc(name, uid, container, cmd, in, out, err, tty)
|
||||
type fakeRuntime struct {
|
||||
execFunc func(string, []string, io.Reader, io.WriteCloser, io.WriteCloser, bool, <-chan remotecommand.TerminalSize) error
|
||||
attachFunc func(string, io.Reader, io.WriteCloser, io.WriteCloser, bool, <-chan remotecommand.TerminalSize) error
|
||||
portForwardFunc func(string, int32, io.ReadWriteCloser) error
|
||||
}
|
||||
|
||||
func (fk *fakeKubelet) AttachContainer(name string, uid types.UID, container string, in io.Reader, out, err io.WriteCloser, tty bool, resize <-chan remotecommand.TerminalSize) error {
|
||||
return fk.attachFunc(name, uid, container, in, out, err, tty)
|
||||
func (f *fakeRuntime) Exec(containerID string, cmd []string, stdin io.Reader, stdout, stderr io.WriteCloser, tty bool, resize <-chan remotecommand.TerminalSize) error {
|
||||
return f.execFunc(containerID, cmd, stdin, stdout, stderr, tty, resize)
|
||||
}
|
||||
|
||||
func (fk *fakeKubelet) PortForward(name string, uid types.UID, port int32, stream io.ReadWriteCloser) error {
|
||||
return fk.portForwardFunc(name, uid, port, stream)
|
||||
func (f *fakeRuntime) Attach(containerID string, stdin io.Reader, stdout, stderr io.WriteCloser, tty bool, resize <-chan remotecommand.TerminalSize) error {
|
||||
return f.attachFunc(containerID, stdin, stdout, stderr, tty, resize)
|
||||
}
|
||||
|
||||
func (f *fakeRuntime) PortForward(podSandboxID string, port int32, stream io.ReadWriteCloser) error {
|
||||
return f.portForwardFunc(podSandboxID, port, stream)
|
||||
}
|
||||
|
||||
type testStreamingServer struct {
|
||||
streaming.Server
|
||||
fakeRuntime *fakeRuntime
|
||||
testHTTPServer *httptest.Server
|
||||
}
|
||||
|
||||
func newTestStreamingServer(streamIdleTimeout time.Duration) (s *testStreamingServer, err error) {
|
||||
s = &testStreamingServer{}
|
||||
s.testHTTPServer = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
s.ServeHTTP(w, r)
|
||||
}))
|
||||
defer func() {
|
||||
if err != nil {
|
||||
s.testHTTPServer.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
testURL, err := url.Parse(s.testHTTPServer.URL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.fakeRuntime = &fakeRuntime{}
|
||||
config := streaming.DefaultConfig
|
||||
config.BaseURL = testURL
|
||||
if streamIdleTimeout != 0 {
|
||||
config.StreamIdleTimeout = streamIdleTimeout
|
||||
}
|
||||
s.Server, err = streaming.NewServer(config, s.fakeRuntime)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (fk *fakeKubelet) GetExec(podFullName string, podUID types.UID, containerName string, cmd []string, streamOpts remotecommandserver.Options) (*url.URL, error) {
|
||||
return fk.redirectURL, nil
|
||||
if fk.getExecCheck != nil {
|
||||
fk.getExecCheck(podFullName, podUID, containerName, cmd, streamOpts)
|
||||
}
|
||||
// Always use testContainerID
|
||||
resp, err := fk.streamingRuntime.GetExec(&runtimeapi.ExecRequest{
|
||||
ContainerId: testContainerID,
|
||||
Cmd: cmd,
|
||||
Tty: streamOpts.TTY,
|
||||
Stdin: streamOpts.Stdin,
|
||||
Stdout: streamOpts.Stdout,
|
||||
Stderr: streamOpts.Stderr,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return url.Parse(resp.GetUrl())
|
||||
}
|
||||
|
||||
func (fk *fakeKubelet) GetAttach(podFullName string, podUID types.UID, containerName string, streamOpts remotecommandserver.Options) (*url.URL, error) {
|
||||
return fk.redirectURL, nil
|
||||
if fk.getAttachCheck != nil {
|
||||
fk.getAttachCheck(podFullName, podUID, containerName, streamOpts)
|
||||
}
|
||||
// Always use testContainerID
|
||||
resp, err := fk.streamingRuntime.GetAttach(&runtimeapi.AttachRequest{
|
||||
ContainerId: testContainerID,
|
||||
Tty: streamOpts.TTY,
|
||||
Stdin: streamOpts.Stdin,
|
||||
Stdout: streamOpts.Stdout,
|
||||
Stderr: streamOpts.Stderr,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return url.Parse(resp.GetUrl())
|
||||
}
|
||||
|
||||
func (fk *fakeKubelet) GetPortForward(podName, podNamespace string, podUID types.UID, portForwardOpts portforward.V4Options) (*url.URL, error) {
|
||||
return fk.redirectURL, nil
|
||||
if fk.getPortForwardCheck != nil {
|
||||
fk.getPortForwardCheck(podName, podNamespace, podUID, portForwardOpts)
|
||||
}
|
||||
|
||||
func (fk *fakeKubelet) StreamingConnectionIdleTimeout() time.Duration {
|
||||
return fk.streamingConnectionIdleTimeoutFunc()
|
||||
// Always use testPodSandboxID
|
||||
resp, err := fk.streamingRuntime.GetPortForward(&runtimeapi.PortForwardRequest{
|
||||
PodSandboxId: testPodSandboxID,
|
||||
Port: portForwardOpts.Ports,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return url.Parse(resp.GetUrl())
|
||||
}
|
||||
|
||||
// Unused functions
|
||||
@ -202,13 +283,16 @@ type serverTestFramework struct {
|
||||
fakeKubelet *fakeKubelet
|
||||
fakeAuth *fakeAuth
|
||||
testHTTPServer *httptest.Server
|
||||
fakeRuntime *fakeRuntime
|
||||
testStreamingHTTPServer *httptest.Server
|
||||
criHandler *utiltesting.FakeHandler
|
||||
}
|
||||
|
||||
func newServerTest() *serverTestFramework {
|
||||
return newServerTestWithDebug(true)
|
||||
return newServerTestWithDebug(true, false, nil)
|
||||
}
|
||||
|
||||
func newServerTestWithDebug(enableDebugging bool) *serverTestFramework {
|
||||
func newServerTestWithDebug(enableDebugging, redirectContainerStreaming bool, streamingServer streaming.Server) *serverTestFramework {
|
||||
fw := &serverTestFramework{}
|
||||
fw.fakeKubelet = &fakeKubelet{
|
||||
hostnameFunc: func() string {
|
||||
@ -224,6 +308,7 @@ func newServerTestWithDebug(enableDebugging bool) *serverTestFramework {
|
||||
}, true
|
||||
},
|
||||
plegHealth: true,
|
||||
streamingRuntime: streamingServer,
|
||||
}
|
||||
fw.fakeAuth = &fakeAuth{
|
||||
authenticateFunc: func(req *http.Request) (user.Info, bool, error) {
|
||||
@ -236,13 +321,17 @@ func newServerTestWithDebug(enableDebugging bool) *serverTestFramework {
|
||||
return authorizer.DecisionAllow, "", nil
|
||||
},
|
||||
}
|
||||
fw.criHandler = &utiltesting.FakeHandler{
|
||||
StatusCode: http.StatusOK,
|
||||
}
|
||||
server := NewServer(
|
||||
fw.fakeKubelet,
|
||||
stats.NewResourceAnalyzer(fw.fakeKubelet, time.Minute),
|
||||
fw.fakeAuth,
|
||||
enableDebugging,
|
||||
false,
|
||||
&kubecontainertesting.Mock{})
|
||||
redirectContainerStreaming,
|
||||
fw.criHandler)
|
||||
fw.serverUnderTest = &server
|
||||
fw.testHTTPServer = httptest.NewServer(fw.serverUnderTest)
|
||||
return fw
|
||||
@ -1064,13 +1153,12 @@ func TestContainerLogsWithFollow(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestServeExecInContainerIdleTimeout(t *testing.T) {
|
||||
fw := newServerTest()
|
||||
ss, err := newTestStreamingServer(100 * time.Millisecond)
|
||||
require.NoError(t, err)
|
||||
defer ss.testHTTPServer.Close()
|
||||
fw := newServerTestWithDebug(true, false, ss)
|
||||
defer fw.testHTTPServer.Close()
|
||||
|
||||
fw.fakeKubelet.streamingConnectionIdleTimeoutFunc = func() time.Duration {
|
||||
return 100 * time.Millisecond
|
||||
}
|
||||
|
||||
podNamespace := "other"
|
||||
podName := "foo"
|
||||
expectedContainerName := "baz"
|
||||
@ -1102,38 +1190,35 @@ func TestServeExecInContainerIdleTimeout(t *testing.T) {
|
||||
}
|
||||
|
||||
func testExecAttach(t *testing.T, verb string) {
|
||||
tests := []struct {
|
||||
tests := map[string]struct {
|
||||
stdin bool
|
||||
stdout bool
|
||||
stderr bool
|
||||
tty bool
|
||||
responseStatusCode int
|
||||
uid bool
|
||||
responseLocation string
|
||||
redirect bool
|
||||
}{
|
||||
{responseStatusCode: http.StatusBadRequest},
|
||||
{stdin: true, responseStatusCode: http.StatusSwitchingProtocols},
|
||||
{stdout: true, responseStatusCode: http.StatusSwitchingProtocols},
|
||||
{stderr: true, responseStatusCode: http.StatusSwitchingProtocols},
|
||||
{stdout: true, stderr: true, responseStatusCode: http.StatusSwitchingProtocols},
|
||||
{stdout: true, stderr: true, tty: true, responseStatusCode: http.StatusSwitchingProtocols},
|
||||
{stdin: true, stdout: true, stderr: true, responseStatusCode: http.StatusSwitchingProtocols},
|
||||
{stdout: true, responseStatusCode: http.StatusFound, responseLocation: "http://localhost:12345/" + verb},
|
||||
"no input or output": {responseStatusCode: http.StatusBadRequest},
|
||||
"stdin": {stdin: true, responseStatusCode: http.StatusSwitchingProtocols},
|
||||
"stdout": {stdout: true, responseStatusCode: http.StatusSwitchingProtocols},
|
||||
"stderr": {stderr: true, responseStatusCode: http.StatusSwitchingProtocols},
|
||||
"stdout and stderr": {stdout: true, stderr: true, responseStatusCode: http.StatusSwitchingProtocols},
|
||||
"stdout stderr and tty": {stdout: true, stderr: true, tty: true, responseStatusCode: http.StatusSwitchingProtocols},
|
||||
"stdin stdout and stderr": {stdin: true, stdout: true, stderr: true, responseStatusCode: http.StatusSwitchingProtocols},
|
||||
"stdin stdout stderr with uid": {stdin: true, stdout: true, stderr: true, responseStatusCode: http.StatusSwitchingProtocols, uid: true},
|
||||
"stdout with redirect": {stdout: true, responseStatusCode: http.StatusFound, redirect: true},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
fw := newServerTest()
|
||||
defer fw.testHTTPServer.Close()
|
||||
|
||||
fw.fakeKubelet.streamingConnectionIdleTimeoutFunc = func() time.Duration {
|
||||
return 0
|
||||
}
|
||||
|
||||
if test.responseLocation != "" {
|
||||
var err error
|
||||
fw.fakeKubelet.redirectURL, err = url.Parse(test.responseLocation)
|
||||
for desc, test := range tests {
|
||||
test := test
|
||||
t.Run(desc, func(t *testing.T) {
|
||||
ss, err := newTestStreamingServer(0)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
defer ss.testHTTPServer.Close()
|
||||
fw := newServerTestWithDebug(true, test.redirect, ss)
|
||||
defer fw.testHTTPServer.Close()
|
||||
fmt.Println(desc)
|
||||
|
||||
podNamespace := "other"
|
||||
podName := "foo"
|
||||
@ -1149,81 +1234,67 @@ func testExecAttach(t *testing.T, verb string) {
|
||||
execInvoked := false
|
||||
attachInvoked := false
|
||||
|
||||
testStreamFunc := func(podFullName string, uid types.UID, containerName string, cmd []string, in io.Reader, out, stderr io.WriteCloser, tty bool, done chan struct{}) error {
|
||||
defer close(done)
|
||||
checkStream := func(podFullName string, uid types.UID, containerName string, streamOpts remotecommandserver.Options) {
|
||||
assert.Equal(t, expectedPodName, podFullName, "podFullName")
|
||||
if test.uid {
|
||||
assert.Equal(t, testUID, string(uid), "uid")
|
||||
}
|
||||
assert.Equal(t, expectedContainerName, containerName, "containerName")
|
||||
assert.Equal(t, test.stdin, streamOpts.Stdin, "stdin")
|
||||
assert.Equal(t, test.stdout, streamOpts.Stdout, "stdout")
|
||||
assert.Equal(t, test.tty, streamOpts.TTY, "tty")
|
||||
assert.Equal(t, !test.tty && test.stderr, streamOpts.Stderr, "stderr")
|
||||
}
|
||||
|
||||
if podFullName != expectedPodName {
|
||||
t.Fatalf("%d: podFullName: expected %s, got %s", i, expectedPodName, podFullName)
|
||||
fw.fakeKubelet.getExecCheck = func(podFullName string, uid types.UID, containerName string, cmd []string, streamOpts remotecommandserver.Options) {
|
||||
execInvoked = true
|
||||
assert.Equal(t, expectedCommand, strings.Join(cmd, " "), "cmd")
|
||||
checkStream(podFullName, uid, containerName, streamOpts)
|
||||
}
|
||||
if test.uid && string(uid) != testUID {
|
||||
t.Fatalf("%d: uid: expected %v, got %v", i, testUID, uid)
|
||||
}
|
||||
if containerName != expectedContainerName {
|
||||
t.Fatalf("%d: containerName: expected %s, got %s", i, expectedContainerName, containerName)
|
||||
|
||||
fw.fakeKubelet.getAttachCheck = func(podFullName string, uid types.UID, containerName string, streamOpts remotecommandserver.Options) {
|
||||
attachInvoked = true
|
||||
checkStream(podFullName, uid, containerName, streamOpts)
|
||||
}
|
||||
|
||||
testStream := func(containerID string, in io.Reader, out, stderr io.WriteCloser, tty bool, done chan struct{}) error {
|
||||
close(done)
|
||||
assert.Equal(t, testContainerID, containerID, "containerID")
|
||||
assert.Equal(t, test.tty, tty, "tty")
|
||||
require.Equal(t, test.stdin, in != nil, "in")
|
||||
require.Equal(t, test.stdout, out != nil, "out")
|
||||
require.Equal(t, !test.tty && test.stderr, stderr != nil, "err")
|
||||
|
||||
if test.stdin {
|
||||
if in == nil {
|
||||
t.Fatalf("%d: stdin: expected non-nil", i)
|
||||
}
|
||||
b := make([]byte, 10)
|
||||
n, err := in.Read(b)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: error reading from stdin: %v", i, err)
|
||||
}
|
||||
if e, a := expectedStdin, string(b[0:n]); e != a {
|
||||
t.Fatalf("%d: stdin: expected to read %v, got %v", i, e, a)
|
||||
}
|
||||
} else if in != nil {
|
||||
t.Fatalf("%d: stdin: expected nil: %#v", i, in)
|
||||
assert.NoError(t, err, "reading from stdin")
|
||||
assert.Equal(t, expectedStdin, string(b[0:n]), "content from stdin")
|
||||
}
|
||||
|
||||
if test.stdout {
|
||||
if out == nil {
|
||||
t.Fatalf("%d: stdout: expected non-nil", i)
|
||||
}
|
||||
_, err := out.Write([]byte(expectedStdout))
|
||||
if err != nil {
|
||||
t.Fatalf("%d:, error writing to stdout: %v", i, err)
|
||||
}
|
||||
assert.NoError(t, err, "writing to stdout")
|
||||
out.Close()
|
||||
<-clientStdoutReadDone
|
||||
} else if out != nil {
|
||||
t.Fatalf("%d: stdout: expected nil: %#v", i, out)
|
||||
}
|
||||
|
||||
if tty {
|
||||
if stderr != nil {
|
||||
t.Fatalf("%d: tty set but received non-nil stderr: %v", i, stderr)
|
||||
}
|
||||
} else if test.stderr {
|
||||
if stderr == nil {
|
||||
t.Fatalf("%d: stderr: expected non-nil", i)
|
||||
}
|
||||
if !test.tty && test.stderr {
|
||||
_, err := stderr.Write([]byte(expectedStderr))
|
||||
if err != nil {
|
||||
t.Fatalf("%d:, error writing to stderr: %v", i, err)
|
||||
}
|
||||
assert.NoError(t, err, "writing to stderr")
|
||||
stderr.Close()
|
||||
<-clientStderrReadDone
|
||||
} else if stderr != nil {
|
||||
t.Fatalf("%d: stderr: expected nil: %#v", i, stderr)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
fw.fakeKubelet.execFunc = func(podFullName string, uid types.UID, containerName string, cmd []string, in io.Reader, out, stderr io.WriteCloser, tty bool) error {
|
||||
execInvoked = true
|
||||
if strings.Join(cmd, " ") != expectedCommand {
|
||||
t.Fatalf("%d: cmd: expected: %s, got %v", i, expectedCommand, cmd)
|
||||
}
|
||||
return testStreamFunc(podFullName, uid, containerName, cmd, in, out, stderr, tty, done)
|
||||
ss.fakeRuntime.execFunc = func(containerID string, cmd []string, stdin io.Reader, stdout, stderr io.WriteCloser, tty bool, resize <-chan remotecommand.TerminalSize) error {
|
||||
assert.Equal(t, expectedCommand, strings.Join(cmd, " "), "cmd")
|
||||
return testStream(containerID, stdin, stdout, stderr, tty, done)
|
||||
}
|
||||
|
||||
fw.fakeKubelet.attachFunc = func(podFullName string, uid types.UID, containerName string, in io.Reader, out, stderr io.WriteCloser, tty bool) error {
|
||||
attachInvoked = true
|
||||
return testStreamFunc(podFullName, uid, containerName, nil, in, out, stderr, tty, done)
|
||||
ss.fakeRuntime.attachFunc = func(containerID string, stdin io.Reader, stdout, stderr io.WriteCloser, tty bool, resize <-chan remotecommand.TerminalSize) error {
|
||||
return testStream(containerID, stdin, stdout, stderr, tty, done)
|
||||
}
|
||||
|
||||
var url string
|
||||
@ -1250,12 +1321,10 @@ func testExecAttach(t *testing.T, verb string) {
|
||||
|
||||
var (
|
||||
resp *http.Response
|
||||
err error
|
||||
upgradeRoundTripper httpstream.UpgradeRoundTripper
|
||||
c *http.Client
|
||||
)
|
||||
|
||||
if test.responseStatusCode != http.StatusSwitchingProtocols {
|
||||
if test.redirect {
|
||||
c = &http.Client{}
|
||||
// Don't follow redirects, since we want to inspect the redirect response.
|
||||
c.CheckRedirect = func(*http.Request, []*http.Request) error {
|
||||
@ -1267,115 +1336,75 @@ func testExecAttach(t *testing.T, verb string) {
|
||||
}
|
||||
|
||||
resp, err = c.Post(url, "", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: Got error POSTing: %v", i, err)
|
||||
}
|
||||
require.NoError(t, err, "POSTing")
|
||||
defer resp.Body.Close()
|
||||
|
||||
_, err = ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Errorf("%d: Error reading response body: %v", i, err)
|
||||
}
|
||||
|
||||
if e, a := test.responseStatusCode, resp.StatusCode; e != a {
|
||||
t.Fatalf("%d: response status: expected %v, got %v", i, e, a)
|
||||
}
|
||||
|
||||
if e, a := test.responseLocation, resp.Header.Get("Location"); e != a {
|
||||
t.Errorf("%d: response location: expected %v, got %v", i, e, a)
|
||||
}
|
||||
assert.NoError(t, err, "reading response body")
|
||||
|
||||
require.Equal(t, test.responseStatusCode, resp.StatusCode, "response status")
|
||||
if test.responseStatusCode != http.StatusSwitchingProtocols {
|
||||
continue
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := upgradeRoundTripper.NewConnection(resp)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error creating streaming connection: %s", err)
|
||||
}
|
||||
if conn == nil {
|
||||
t.Fatalf("%d: unexpected nil conn", i)
|
||||
}
|
||||
require.NoError(t, err, "creating streaming connection")
|
||||
defer conn.Close()
|
||||
|
||||
h := http.Header{}
|
||||
h.Set(api.StreamType, api.StreamTypeError)
|
||||
if _, err := conn.CreateStream(h); err != nil {
|
||||
t.Fatalf("%d: error creating error stream: %v", i, err)
|
||||
}
|
||||
_, err = conn.CreateStream(h)
|
||||
require.NoError(t, err, "creating error stream")
|
||||
|
||||
if test.stdin {
|
||||
h.Set(api.StreamType, api.StreamTypeStdin)
|
||||
stream, err := conn.CreateStream(h)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: error creating stdin stream: %v", i, err)
|
||||
}
|
||||
require.NoError(t, err, "creating stdin stream")
|
||||
_, err = stream.Write([]byte(expectedStdin))
|
||||
if err != nil {
|
||||
t.Fatalf("%d: error writing to stdin stream: %v", i, err)
|
||||
}
|
||||
require.NoError(t, err, "writing to stdin stream")
|
||||
}
|
||||
|
||||
var stdoutStream httpstream.Stream
|
||||
if test.stdout {
|
||||
h.Set(api.StreamType, api.StreamTypeStdout)
|
||||
stdoutStream, err = conn.CreateStream(h)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: error creating stdout stream: %v", i, err)
|
||||
}
|
||||
require.NoError(t, err, "creating stdout stream")
|
||||
}
|
||||
|
||||
var stderrStream httpstream.Stream
|
||||
if test.stderr && !test.tty {
|
||||
h.Set(api.StreamType, api.StreamTypeStderr)
|
||||
stderrStream, err = conn.CreateStream(h)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: error creating stderr stream: %v", i, err)
|
||||
}
|
||||
require.NoError(t, err, "creating stderr stream")
|
||||
}
|
||||
|
||||
if test.stdout {
|
||||
output := make([]byte, 10)
|
||||
n, err := stdoutStream.Read(output)
|
||||
close(clientStdoutReadDone)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: error reading from stdout stream: %v", i, err)
|
||||
}
|
||||
if e, a := expectedStdout, string(output[0:n]); e != a {
|
||||
t.Fatalf("%d: stdout: expected '%v', got '%v'", i, e, a)
|
||||
}
|
||||
assert.NoError(t, err, "reading from stdout stream")
|
||||
assert.Equal(t, expectedStdout, string(output[0:n]), "stdout")
|
||||
}
|
||||
|
||||
if test.stderr && !test.tty {
|
||||
output := make([]byte, 10)
|
||||
n, err := stderrStream.Read(output)
|
||||
close(clientStderrReadDone)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: error reading from stderr stream: %v", i, err)
|
||||
}
|
||||
if e, a := expectedStderr, string(output[0:n]); e != a {
|
||||
t.Fatalf("%d: stderr: expected '%v', got '%v'", i, e, a)
|
||||
}
|
||||
assert.NoError(t, err, "reading from stderr stream")
|
||||
assert.Equal(t, expectedStderr, string(output[0:n]), "stderr")
|
||||
}
|
||||
|
||||
// wait for the server to finish before checking if the attach/exec funcs were invoked
|
||||
<-done
|
||||
|
||||
if verb == "exec" {
|
||||
if !execInvoked {
|
||||
t.Errorf("%d: exec was not invoked", i)
|
||||
}
|
||||
if attachInvoked {
|
||||
t.Errorf("%d: attach should not have been invoked", i)
|
||||
}
|
||||
assert.True(t, execInvoked, "exec should be invoked")
|
||||
assert.False(t, attachInvoked, "attach should not be invoked")
|
||||
} else {
|
||||
if !attachInvoked {
|
||||
t.Errorf("%d: attach was not invoked", i)
|
||||
}
|
||||
if execInvoked {
|
||||
t.Errorf("%d: exec should not have been invoked", i)
|
||||
}
|
||||
assert.True(t, attachInvoked, "attach should be invoked")
|
||||
assert.False(t, execInvoked, "exec should not be invoked")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@ -1388,13 +1417,12 @@ func TestServeAttachContainer(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestServePortForwardIdleTimeout(t *testing.T) {
|
||||
fw := newServerTest()
|
||||
ss, err := newTestStreamingServer(100 * time.Millisecond)
|
||||
require.NoError(t, err)
|
||||
defer ss.testHTTPServer.Close()
|
||||
fw := newServerTestWithDebug(true, false, ss)
|
||||
defer fw.testHTTPServer.Close()
|
||||
|
||||
fw.fakeKubelet.streamingConnectionIdleTimeoutFunc = func() time.Duration {
|
||||
return 100 * time.Millisecond
|
||||
}
|
||||
|
||||
podNamespace := "other"
|
||||
podName := "foo"
|
||||
|
||||
@ -1422,82 +1450,67 @@ func TestServePortForwardIdleTimeout(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestServePortForward(t *testing.T) {
|
||||
tests := []struct {
|
||||
tests := map[string]struct {
|
||||
port string
|
||||
uid bool
|
||||
clientData string
|
||||
containerData string
|
||||
redirect bool
|
||||
shouldError bool
|
||||
responseLocation string
|
||||
}{
|
||||
{port: "", shouldError: true},
|
||||
{port: "abc", shouldError: true},
|
||||
{port: "-1", shouldError: true},
|
||||
{port: "65536", shouldError: true},
|
||||
{port: "0", shouldError: true},
|
||||
{port: "1", shouldError: false},
|
||||
{port: "8000", shouldError: false},
|
||||
{port: "8000", clientData: "client data", containerData: "container data", shouldError: false},
|
||||
{port: "65535", shouldError: false},
|
||||
{port: "65535", uid: true, shouldError: false},
|
||||
{port: "65535", responseLocation: "http://localhost:12345/portforward", shouldError: false},
|
||||
"no port": {port: "", shouldError: true},
|
||||
"none number port": {port: "abc", shouldError: true},
|
||||
"negative port": {port: "-1", shouldError: true},
|
||||
"too large port": {port: "65536", shouldError: true},
|
||||
"0 port": {port: "0", shouldError: true},
|
||||
"min port": {port: "1", shouldError: false},
|
||||
"normal port": {port: "8000", shouldError: false},
|
||||
"normal port with data forward": {port: "8000", clientData: "client data", containerData: "container data", shouldError: false},
|
||||
"max port": {port: "65535", shouldError: false},
|
||||
"normal port with uid": {port: "8000", uid: true, shouldError: false},
|
||||
"normal port with redirect": {port: "8000", redirect: true, shouldError: false},
|
||||
}
|
||||
|
||||
podNamespace := "other"
|
||||
podName := "foo"
|
||||
expectedPodName := getPodName(podName, podNamespace)
|
||||
|
||||
for i, test := range tests {
|
||||
fw := newServerTest()
|
||||
defer fw.testHTTPServer.Close()
|
||||
|
||||
fw.fakeKubelet.streamingConnectionIdleTimeoutFunc = func() time.Duration {
|
||||
return 0
|
||||
}
|
||||
|
||||
if test.responseLocation != "" {
|
||||
var err error
|
||||
fw.fakeKubelet.redirectURL, err = url.Parse(test.responseLocation)
|
||||
for desc, test := range tests {
|
||||
test := test
|
||||
t.Run(desc, func(t *testing.T) {
|
||||
ss, err := newTestStreamingServer(0)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
defer ss.testHTTPServer.Close()
|
||||
fw := newServerTestWithDebug(true, test.redirect, ss)
|
||||
defer fw.testHTTPServer.Close()
|
||||
|
||||
portForwardFuncDone := make(chan struct{})
|
||||
|
||||
fw.fakeKubelet.portForwardFunc = func(name string, uid types.UID, port int32, stream io.ReadWriteCloser) error {
|
||||
fw.fakeKubelet.getPortForwardCheck = func(name, namespace string, uid types.UID, opts portforward.V4Options) {
|
||||
assert.Equal(t, podName, name, "pod name")
|
||||
assert.Equal(t, podNamespace, namespace, "pod namespace")
|
||||
if test.uid {
|
||||
assert.Equal(t, testUID, string(uid), "uid")
|
||||
}
|
||||
}
|
||||
|
||||
ss.fakeRuntime.portForwardFunc = func(podSandboxID string, port int32, stream io.ReadWriteCloser) error {
|
||||
defer close(portForwardFuncDone)
|
||||
|
||||
if e, a := expectedPodName, name; e != a {
|
||||
t.Fatalf("%d: pod name: expected '%v', got '%v'", i, e, a)
|
||||
}
|
||||
|
||||
if e, a := testUID, uid; test.uid && e != string(a) {
|
||||
t.Fatalf("%d: uid: expected '%v', got '%v'", i, e, a)
|
||||
}
|
||||
|
||||
p, err := strconv.ParseInt(test.port, 10, 32)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: error parsing port string '%s': %v", i, test.port, err)
|
||||
}
|
||||
if e, a := int32(p), port; e != a {
|
||||
t.Fatalf("%d: port: expected '%v', got '%v'", i, e, a)
|
||||
}
|
||||
assert.Equal(t, testPodSandboxID, podSandboxID, "pod sandbox id")
|
||||
// The port should be valid if it reaches here.
|
||||
testPort, err := strconv.ParseInt(test.port, 10, 32)
|
||||
require.NoError(t, err, "parse port")
|
||||
assert.Equal(t, int32(testPort), port, "port")
|
||||
|
||||
if test.clientData != "" {
|
||||
fromClient := make([]byte, 32)
|
||||
n, err := stream.Read(fromClient)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: error reading client data: %v", i, err)
|
||||
}
|
||||
if e, a := test.clientData, string(fromClient[0:n]); e != a {
|
||||
t.Fatalf("%d: client data: expected to receive '%v', got '%v'", i, e, a)
|
||||
}
|
||||
assert.NoError(t, err, "reading client data")
|
||||
assert.Equal(t, test.clientData, string(fromClient[0:n]), "client data")
|
||||
}
|
||||
|
||||
if test.containerData != "" {
|
||||
_, err := stream.Write([]byte(test.containerData))
|
||||
if err != nil {
|
||||
t.Fatalf("%d: error writing container data: %v", i, err)
|
||||
}
|
||||
assert.NoError(t, err, "writing container data")
|
||||
}
|
||||
|
||||
return nil
|
||||
@ -1515,7 +1528,7 @@ func TestServePortForward(t *testing.T) {
|
||||
c *http.Client
|
||||
)
|
||||
|
||||
if len(test.responseLocation) > 0 {
|
||||
if test.redirect {
|
||||
c = &http.Client{}
|
||||
// Don't follow redirects, since we want to inspect the redirect response.
|
||||
c.CheckRedirect = func(*http.Request, []*http.Request) error {
|
||||
@ -1527,74 +1540,70 @@ func TestServePortForward(t *testing.T) {
|
||||
}
|
||||
|
||||
resp, err := c.Post(url, "", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: Got error POSTing: %v", i, err)
|
||||
}
|
||||
require.NoError(t, err, "POSTing")
|
||||
defer resp.Body.Close()
|
||||
|
||||
if test.responseLocation != "" {
|
||||
assert.Equal(t, http.StatusFound, resp.StatusCode, "%d: status code", i)
|
||||
assert.Equal(t, test.responseLocation, resp.Header.Get("Location"), "%d: location", i)
|
||||
continue
|
||||
if test.redirect {
|
||||
assert.Equal(t, http.StatusFound, resp.StatusCode, "status code")
|
||||
return
|
||||
} else {
|
||||
assert.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode, "%d: status code", i)
|
||||
assert.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode, "status code")
|
||||
}
|
||||
|
||||
conn, err := upgradeRoundTripper.NewConnection(resp)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error creating streaming connection: %s", err)
|
||||
}
|
||||
if conn == nil {
|
||||
t.Fatalf("%d: Unexpected nil connection", i)
|
||||
}
|
||||
require.NoError(t, err, "creating streaming connection")
|
||||
defer conn.Close()
|
||||
|
||||
headers := http.Header{}
|
||||
headers.Set("streamType", "error")
|
||||
headers.Set("port", test.port)
|
||||
errorStream, err := conn.CreateStream(headers)
|
||||
_ = errorStream
|
||||
haveErr := err != nil
|
||||
if e, a := test.shouldError, haveErr; e != a {
|
||||
t.Fatalf("%d: create stream: expected err=%t, got %t: %v", i, e, a, err)
|
||||
}
|
||||
_, err = conn.CreateStream(headers)
|
||||
assert.Equal(t, test.shouldError, err != nil, "expect error")
|
||||
|
||||
if test.shouldError {
|
||||
continue
|
||||
return
|
||||
}
|
||||
|
||||
headers.Set("streamType", "data")
|
||||
headers.Set("port", test.port)
|
||||
dataStream, err := conn.CreateStream(headers)
|
||||
haveErr = err != nil
|
||||
if e, a := test.shouldError, haveErr; e != a {
|
||||
t.Fatalf("%d: create stream: expected err=%t, got %t: %v", i, e, a, err)
|
||||
}
|
||||
require.NoError(t, err, "create stream")
|
||||
|
||||
if test.clientData != "" {
|
||||
_, err := dataStream.Write([]byte(test.clientData))
|
||||
if err != nil {
|
||||
t.Fatalf("%d: unexpected error writing client data: %v", i, err)
|
||||
}
|
||||
assert.NoError(t, err, "writing client data")
|
||||
}
|
||||
|
||||
if test.containerData != "" {
|
||||
fromContainer := make([]byte, 32)
|
||||
n, err := dataStream.Read(fromContainer)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: unexpected error reading container data: %v", i, err)
|
||||
}
|
||||
if e, a := test.containerData, string(fromContainer[0:n]); e != a {
|
||||
t.Fatalf("%d: expected to receive '%v' from container, got '%v'", i, e, a)
|
||||
}
|
||||
assert.NoError(t, err, "reading container data")
|
||||
assert.Equal(t, test.containerData, string(fromContainer[0:n]), "container data")
|
||||
}
|
||||
|
||||
<-portForwardFuncDone
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCRIHandler(t *testing.T) {
|
||||
fw := newServerTest()
|
||||
defer fw.testHTTPServer.Close()
|
||||
|
||||
const (
|
||||
path = "/cri/exec/123456abcdef"
|
||||
query = "cmd=echo+foo"
|
||||
)
|
||||
resp, err := http.Get(fw.testHTTPServer.URL + path + "?" + query)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
assert.Equal(t, "GET", fw.criHandler.RequestReceived.Method)
|
||||
assert.Equal(t, path, fw.criHandler.RequestReceived.URL.Path)
|
||||
assert.Equal(t, query, fw.criHandler.RequestReceived.URL.RawQuery)
|
||||
}
|
||||
|
||||
func TestDebuggingDisabledHandlers(t *testing.T) {
|
||||
fw := newServerTestWithDebug(false)
|
||||
fw := newServerTestWithDebug(false, false, nil)
|
||||
defer fw.testHTTPServer.Close()
|
||||
|
||||
paths := []string{
|
||||
|
@ -23,11 +23,13 @@ import (
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/net/websocket"
|
||||
|
||||
"k8s.io/apimachinery/pkg/types"
|
||||
"k8s.io/kubernetes/pkg/kubelet/server/portforward"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -36,75 +38,65 @@ const (
|
||||
)
|
||||
|
||||
func TestServeWSPortForward(t *testing.T) {
|
||||
tests := []struct {
|
||||
tests := map[string]struct {
|
||||
port string
|
||||
uid bool
|
||||
clientData string
|
||||
containerData string
|
||||
shouldError bool
|
||||
}{
|
||||
{port: "", shouldError: true},
|
||||
{port: "abc", shouldError: true},
|
||||
{port: "-1", shouldError: true},
|
||||
{port: "65536", shouldError: true},
|
||||
{port: "0", shouldError: true},
|
||||
{port: "1", shouldError: false},
|
||||
{port: "8000", shouldError: false},
|
||||
{port: "8000", clientData: "client data", containerData: "container data", shouldError: false},
|
||||
{port: "65535", shouldError: false},
|
||||
{port: "65535", uid: true, shouldError: false},
|
||||
"no port": {port: "", shouldError: true},
|
||||
"none number port": {port: "abc", shouldError: true},
|
||||
"negative port": {port: "-1", shouldError: true},
|
||||
"too large port": {port: "65536", shouldError: true},
|
||||
"0 port": {port: "0", shouldError: true},
|
||||
"min port": {port: "1", shouldError: false},
|
||||
"normal port": {port: "8000", shouldError: false},
|
||||
"normal port with data forward": {port: "8000", clientData: "client data", containerData: "container data", shouldError: false},
|
||||
"max port": {port: "65535", shouldError: false},
|
||||
"normal port with uid": {port: "8000", uid: true, shouldError: false},
|
||||
}
|
||||
|
||||
podNamespace := "other"
|
||||
podName := "foo"
|
||||
expectedPodName := getPodName(podName, podNamespace)
|
||||
expectedUid := "9b01b80f-8fb4-11e4-95ab-4200af06647"
|
||||
|
||||
for i, test := range tests {
|
||||
fw := newServerTest()
|
||||
for desc, test := range tests {
|
||||
test := test
|
||||
t.Run(desc, func(t *testing.T) {
|
||||
ss, err := newTestStreamingServer(0)
|
||||
require.NoError(t, err)
|
||||
defer ss.testHTTPServer.Close()
|
||||
fw := newServerTestWithDebug(true, false, ss)
|
||||
defer fw.testHTTPServer.Close()
|
||||
|
||||
fw.fakeKubelet.streamingConnectionIdleTimeoutFunc = func() time.Duration {
|
||||
return 0
|
||||
}
|
||||
|
||||
portForwardFuncDone := make(chan struct{})
|
||||
|
||||
fw.fakeKubelet.portForwardFunc = func(name string, uid types.UID, port int32, stream io.ReadWriteCloser) error {
|
||||
fw.fakeKubelet.getPortForwardCheck = func(name, namespace string, uid types.UID, opts portforward.V4Options) {
|
||||
assert.Equal(t, podName, name, "pod name")
|
||||
assert.Equal(t, podNamespace, namespace, "pod namespace")
|
||||
if test.uid {
|
||||
assert.Equal(t, testUID, string(uid), "uid")
|
||||
}
|
||||
}
|
||||
|
||||
ss.fakeRuntime.portForwardFunc = func(podSandboxID string, port int32, stream io.ReadWriteCloser) error {
|
||||
defer close(portForwardFuncDone)
|
||||
|
||||
if e, a := expectedPodName, name; e != a {
|
||||
t.Fatalf("%d: pod name: expected '%v', got '%v'", i, e, a)
|
||||
}
|
||||
|
||||
if e, a := expectedUid, uid; test.uid && e != string(a) {
|
||||
t.Fatalf("%d: uid: expected '%v', got '%v'", i, e, a)
|
||||
}
|
||||
|
||||
p, err := strconv.ParseInt(test.port, 10, 32)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: error parsing port string '%s': %v", i, test.port, err)
|
||||
}
|
||||
if e, a := int32(p), port; e != a {
|
||||
t.Fatalf("%d: port: expected '%v', got '%v'", i, e, a)
|
||||
}
|
||||
assert.Equal(t, testPodSandboxID, podSandboxID, "pod sandbox id")
|
||||
// The port should be valid if it reaches here.
|
||||
testPort, err := strconv.ParseInt(test.port, 10, 32)
|
||||
require.NoError(t, err, "parse port")
|
||||
assert.Equal(t, int32(testPort), port, "port")
|
||||
|
||||
if test.clientData != "" {
|
||||
fromClient := make([]byte, 32)
|
||||
n, err := stream.Read(fromClient)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: error reading client data: %v", i, err)
|
||||
}
|
||||
if e, a := test.clientData, string(fromClient[0:n]); e != a {
|
||||
t.Fatalf("%d: client data: expected to receive '%v', got '%v'", i, e, a)
|
||||
}
|
||||
assert.NoError(t, err, "reading client data")
|
||||
assert.Equal(t, test.clientData, string(fromClient[0:n]), "client data")
|
||||
}
|
||||
|
||||
if test.containerData != "" {
|
||||
_, err := stream.Write([]byte(test.containerData))
|
||||
if err != nil {
|
||||
t.Fatalf("%d: error writing container data: %v", i, err)
|
||||
}
|
||||
assert.NoError(t, err, "writing container data")
|
||||
}
|
||||
|
||||
return nil
|
||||
@ -112,76 +104,48 @@ func TestServeWSPortForward(t *testing.T) {
|
||||
|
||||
var url string
|
||||
if test.uid {
|
||||
url = fmt.Sprintf("ws://%s/portForward/%s/%s/%s?port=%s", fw.testHTTPServer.Listener.Addr().String(), podNamespace, podName, expectedUid, test.port)
|
||||
url = fmt.Sprintf("ws://%s/portForward/%s/%s/%s?port=%s", fw.testHTTPServer.Listener.Addr().String(), podNamespace, podName, testUID, test.port)
|
||||
} else {
|
||||
url = fmt.Sprintf("ws://%s/portForward/%s/%s?port=%s", fw.testHTTPServer.Listener.Addr().String(), podNamespace, podName, test.port)
|
||||
}
|
||||
|
||||
ws, err := websocket.Dial(url, "", "http://127.0.0.1/")
|
||||
assert.Equal(t, test.shouldError, err != nil, "websocket dial")
|
||||
if test.shouldError {
|
||||
if err == nil {
|
||||
t.Fatalf("%d: websocket dial expected err", i)
|
||||
return
|
||||
}
|
||||
continue
|
||||
} else if err != nil {
|
||||
t.Fatalf("%d: websocket dial unexpected err: %v", i, err)
|
||||
}
|
||||
|
||||
defer ws.Close()
|
||||
|
||||
p, err := strconv.ParseUint(test.port, 10, 16)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: error parsing port string '%s': %v", i, test.port, err)
|
||||
}
|
||||
require.NoError(t, err, "parse port")
|
||||
p16 := uint16(p)
|
||||
|
||||
channel, data, err := wsRead(ws)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: read failed: expected no error: got %v", i, err)
|
||||
}
|
||||
if channel != dataChannel {
|
||||
t.Fatalf("%d: wrong channel: got %q: expected %q", i, channel, dataChannel)
|
||||
}
|
||||
if len(data) != binary.Size(p16) {
|
||||
t.Fatalf("%d: wrong data size: got %q: expected %d", i, data, binary.Size(p16))
|
||||
}
|
||||
if e, a := p16, binary.LittleEndian.Uint16(data); e != a {
|
||||
t.Fatalf("%d: wrong data: got %q: expected %s", i, data, test.port)
|
||||
}
|
||||
require.NoError(t, err, "read")
|
||||
assert.Equal(t, dataChannel, int(channel), "channel")
|
||||
assert.Len(t, data, binary.Size(p16), "data size")
|
||||
assert.Equal(t, p16, binary.LittleEndian.Uint16(data), "data")
|
||||
|
||||
channel, data, err = wsRead(ws)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: read succeeded: expected no error: got %v", i, err)
|
||||
}
|
||||
if channel != errorChannel {
|
||||
t.Fatalf("%d: wrong channel: got %q: expected %q", i, channel, errorChannel)
|
||||
}
|
||||
if len(data) != binary.Size(p16) {
|
||||
t.Fatalf("%d: wrong data size: got %q: expected %d", i, data, binary.Size(p16))
|
||||
}
|
||||
if e, a := p16, binary.LittleEndian.Uint16(data); e != a {
|
||||
t.Fatalf("%d: wrong data: got %q: expected %s", i, data, test.port)
|
||||
}
|
||||
assert.NoError(t, err, "read")
|
||||
assert.Equal(t, errorChannel, int(channel), "channel")
|
||||
assert.Len(t, data, binary.Size(p16), "data size")
|
||||
assert.Equal(t, p16, binary.LittleEndian.Uint16(data), "data")
|
||||
|
||||
if test.clientData != "" {
|
||||
println("writing the client data")
|
||||
err := wsWrite(ws, dataChannel, []byte(test.clientData))
|
||||
if err != nil {
|
||||
t.Fatalf("%d: unexpected error writing client data: %v", i, err)
|
||||
}
|
||||
assert.NoError(t, err, "writing client data")
|
||||
}
|
||||
|
||||
if test.containerData != "" {
|
||||
_, data, err = wsRead(ws)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: unexpected error reading container data: %v", i, err)
|
||||
}
|
||||
if e, a := test.containerData, string(data); e != a {
|
||||
t.Fatalf("%d: expected to receive '%v' from container, got '%v'", i, e, a)
|
||||
}
|
||||
assert.NoError(t, err, "reading container data")
|
||||
assert.Equal(t, test.containerData, string(data), "container data")
|
||||
}
|
||||
|
||||
<-portForwardFuncDone
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@ -190,45 +154,39 @@ func TestServeWSMultiplePortForward(t *testing.T) {
|
||||
ports := []uint16{7000, 8000, 9000}
|
||||
podNamespace := "other"
|
||||
podName := "foo"
|
||||
expectedPodName := getPodName(podName, podNamespace)
|
||||
|
||||
fw := newServerTest()
|
||||
ss, err := newTestStreamingServer(0)
|
||||
require.NoError(t, err)
|
||||
defer ss.testHTTPServer.Close()
|
||||
fw := newServerTestWithDebug(true, false, ss)
|
||||
defer fw.testHTTPServer.Close()
|
||||
|
||||
fw.fakeKubelet.streamingConnectionIdleTimeoutFunc = func() time.Duration {
|
||||
return 0
|
||||
}
|
||||
|
||||
portForwardWG := sync.WaitGroup{}
|
||||
portForwardWG.Add(len(ports))
|
||||
|
||||
portsMutex := sync.Mutex{}
|
||||
portsForwarded := map[int32]struct{}{}
|
||||
|
||||
fw.fakeKubelet.portForwardFunc = func(name string, uid types.UID, port int32, stream io.ReadWriteCloser) error {
|
||||
defer portForwardWG.Done()
|
||||
|
||||
if e, a := expectedPodName, name; e != a {
|
||||
t.Fatalf("%d: pod name: expected '%v', got '%v'", port, e, a)
|
||||
fw.fakeKubelet.getPortForwardCheck = func(name, namespace string, uid types.UID, opts portforward.V4Options) {
|
||||
assert.Equal(t, podName, name, "pod name")
|
||||
assert.Equal(t, podNamespace, namespace, "pod namespace")
|
||||
}
|
||||
|
||||
ss.fakeRuntime.portForwardFunc = func(podSandboxID string, port int32, stream io.ReadWriteCloser) error {
|
||||
defer portForwardWG.Done()
|
||||
assert.Equal(t, testPodSandboxID, podSandboxID, "pod sandbox id")
|
||||
|
||||
portsMutex.Lock()
|
||||
portsForwarded[port] = struct{}{}
|
||||
portsMutex.Unlock()
|
||||
|
||||
fromClient := make([]byte, 32)
|
||||
n, err := stream.Read(fromClient)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: error reading client data: %v", port, err)
|
||||
}
|
||||
if e, a := fmt.Sprintf("client data on port %d", port), string(fromClient[0:n]); e != a {
|
||||
t.Fatalf("%d: client data: expected to receive '%v', got '%v'", port, e, a)
|
||||
}
|
||||
assert.NoError(t, err, "reading client data")
|
||||
assert.Equal(t, fmt.Sprintf("client data on port %d", port), string(fromClient[0:n]), "client data")
|
||||
|
||||
_, err = stream.Write([]byte(fmt.Sprintf("container data on port %d", port)))
|
||||
if err != nil {
|
||||
t.Fatalf("%d: error writing container data: %v", port, err)
|
||||
}
|
||||
assert.NoError(t, err, "writing container data")
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -239,70 +197,42 @@ func TestServeWSMultiplePortForward(t *testing.T) {
|
||||
}
|
||||
|
||||
ws, err := websocket.Dial(url, "", "http://127.0.0.1/")
|
||||
if err != nil {
|
||||
t.Fatalf("websocket dial unexpected err: %v", err)
|
||||
}
|
||||
require.NoError(t, err, "websocket dial")
|
||||
|
||||
defer ws.Close()
|
||||
|
||||
for i, port := range ports {
|
||||
channel, data, err := wsRead(ws)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: read failed: expected no error: got %v", i, err)
|
||||
}
|
||||
if int(channel) != i*2+dataChannel {
|
||||
t.Fatalf("%d: wrong channel: got %q: expected %q", i, channel, i*2+dataChannel)
|
||||
}
|
||||
if len(data) != binary.Size(port) {
|
||||
t.Fatalf("%d: wrong data size: got %q: expected %d", i, data, binary.Size(port))
|
||||
}
|
||||
if e, a := port, binary.LittleEndian.Uint16(data); e != a {
|
||||
t.Fatalf("%d: wrong data: got %q: expected %d", i, data, port)
|
||||
}
|
||||
assert.NoError(t, err, "port %d read", port)
|
||||
assert.Equal(t, i*2+dataChannel, int(channel), "port %d channel", port)
|
||||
assert.Len(t, data, binary.Size(port), "port %d data size", port)
|
||||
assert.Equal(t, binary.LittleEndian.Uint16(data), port, "port %d data", port)
|
||||
|
||||
channel, data, err = wsRead(ws)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: read succeeded: expected no error: got %v", i, err)
|
||||
}
|
||||
if int(channel) != i*2+errorChannel {
|
||||
t.Fatalf("%d: wrong channel: got %q: expected %q", i, channel, i*2+errorChannel)
|
||||
}
|
||||
if len(data) != binary.Size(port) {
|
||||
t.Fatalf("%d: wrong data size: got %q: expected %d", i, data, binary.Size(port))
|
||||
}
|
||||
if e, a := port, binary.LittleEndian.Uint16(data); e != a {
|
||||
t.Fatalf("%d: wrong data: got %q: expected %d", i, data, port)
|
||||
}
|
||||
assert.NoError(t, err, "port %d read", port)
|
||||
assert.Equal(t, i*2+errorChannel, int(channel), "port %d channel", port)
|
||||
assert.Len(t, data, binary.Size(port), "port %d data size", port)
|
||||
assert.Equal(t, binary.LittleEndian.Uint16(data), port, "port %d data", port)
|
||||
}
|
||||
|
||||
for i, port := range ports {
|
||||
println("writing the client data", port)
|
||||
t.Logf("port %d writing the client data", port)
|
||||
err := wsWrite(ws, byte(i*2+dataChannel), []byte(fmt.Sprintf("client data on port %d", port)))
|
||||
if err != nil {
|
||||
t.Fatalf("%d: unexpected error writing client data: %v", i, err)
|
||||
}
|
||||
assert.NoError(t, err, "port %d write client data", port)
|
||||
|
||||
channel, data, err := wsRead(ws)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: unexpected error reading container data: %v", i, err)
|
||||
}
|
||||
|
||||
if int(channel) != i*2+dataChannel {
|
||||
t.Fatalf("%d: wrong channel: got %q: expected %q", port, channel, i*2+dataChannel)
|
||||
}
|
||||
if e, a := fmt.Sprintf("container data on port %d", port), string(data); e != a {
|
||||
t.Fatalf("%d: expected to receive '%v' from container, got '%v'", i, e, a)
|
||||
}
|
||||
assert.NoError(t, err, "port %d read container data", port)
|
||||
assert.Equal(t, i*2+dataChannel, int(channel), "port %d channel", port)
|
||||
assert.Equal(t, fmt.Sprintf("container data on port %d", port), string(data), "port %d container data", port)
|
||||
}
|
||||
|
||||
portForwardWG.Wait()
|
||||
|
||||
portsMutex.Lock()
|
||||
defer portsMutex.Unlock()
|
||||
if len(ports) != len(portsForwarded) {
|
||||
t.Fatalf("expected to forward %d ports; got %v", len(ports), portsForwarded)
|
||||
}
|
||||
assert.Len(t, portsForwarded, len(ports), "all ports forwarded")
|
||||
}
|
||||
|
||||
func wsWrite(conn *websocket.Conn, channel byte, data []byte) error {
|
||||
frame := make([]byte, len(data)+1)
|
||||
frame[0] = channel
|
||||
|
Loading…
Reference in New Issue
Block a user