From 4551ba6b538aacb4e6de5869c042579113ed4663 Mon Sep 17 00:00:00 2001 From: Andy Goldstein Date: Tue, 22 Mar 2016 09:38:42 -0400 Subject: [PATCH] Refactor exec code to support version skew testing Refactor exec/attach client and server code to better support interoperability testing of different client and server subprotocol versions. --- .../remotecommand/remotecommand.go | 24 +- .../remotecommand/remotecommand_test.go | 499 ++++++++---------- pkg/kubectl/cmd/attach.go | 3 +- pkg/kubectl/cmd/exec.go | 3 +- pkg/kubelet/server/remotecommand/attach.go | 53 ++ pkg/kubelet/server/remotecommand/contants.go | 36 ++ pkg/kubelet/server/remotecommand/doc.go | 18 + pkg/kubelet/server/remotecommand/exec.go | 57 ++ .../server/remotecommand/httpstream.go | 277 ++++++++++ pkg/kubelet/server/remotecommand/websocket.go | 77 +++ pkg/kubelet/server/server.go | 231 +------- pkg/kubelet/server/server_test.go | 385 ++++---------- pkg/util/httpstream/httpstream.go | 18 +- pkg/util/httpstream/httpstream_test.go | 17 +- 14 files changed, 894 insertions(+), 804 deletions(-) create mode 100644 pkg/kubelet/server/remotecommand/attach.go create mode 100644 pkg/kubelet/server/remotecommand/contants.go create mode 100644 pkg/kubelet/server/remotecommand/doc.go create mode 100644 pkg/kubelet/server/remotecommand/exec.go create mode 100644 pkg/kubelet/server/remotecommand/httpstream.go create mode 100644 pkg/kubelet/server/remotecommand/websocket.go diff --git a/pkg/client/unversioned/remotecommand/remotecommand.go b/pkg/client/unversioned/remotecommand/remotecommand.go index d05e6ba5b3e..7144f3093c1 100644 --- a/pkg/client/unversioned/remotecommand/remotecommand.go +++ b/pkg/client/unversioned/remotecommand/remotecommand.go @@ -26,6 +26,7 @@ import ( "k8s.io/kubernetes/pkg/client/restclient" "k8s.io/kubernetes/pkg/client/transport" + "k8s.io/kubernetes/pkg/kubelet/server/remotecommand" "k8s.io/kubernetes/pkg/util/httpstream" "k8s.io/kubernetes/pkg/util/httpstream/spdy" ) @@ -36,7 +37,7 @@ type Executor interface { // non-nil stream to a remote system, and return an error if a problem occurs. If tty // is set, the stderr stream is not used (raw TTY manages stdout and stderr over the // stdout stream). - Stream(stdin io.Reader, stdout, stderr io.Writer, tty bool) error + Stream(supportedProtocols []string, stdin io.Reader, stdout, stderr io.Writer, tty bool) error } // StreamExecutor supports the ability to dial an httpstream connection and the ability to @@ -128,26 +129,13 @@ func (e *streamExecutor) Dial(protocols ...string) (httpstream.Connection, strin return conn, resp.Header.Get(httpstream.HeaderProtocolVersion), nil } -const ( - // The SPDY subprotocol "channel.k8s.io" is used for remote command - // attachment/execution. This represents the initial unversioned subprotocol, - // which has the known bugs http://issues.k8s.io/13394 and - // http://issues.k8s.io/13395. - StreamProtocolV1Name = "channel.k8s.io" - // The SPDY subprotocol "v2.channel.k8s.io" is used for remote command - // attachment/execution. It is the second version of the subprotocol and - // resolves the issues present in the first version. - StreamProtocolV2Name = "v2.channel.k8s.io" -) - type streamProtocolHandler interface { stream(httpstream.Connection) error } // Stream opens a protocol streamer to the server and streams until a client closes // the connection or the server disconnects. -func (e *streamExecutor) Stream(stdin io.Reader, stdout, stderr io.Writer, tty bool) error { - supportedProtocols := []string{StreamProtocolV2Name, StreamProtocolV1Name} +func (e *streamExecutor) Stream(supportedProtocols []string, stdin io.Reader, stdout, stderr io.Writer, tty bool) error { conn, protocol, err := e.Dial(supportedProtocols...) if err != nil { return err @@ -157,7 +145,7 @@ func (e *streamExecutor) Stream(stdin io.Reader, stdout, stderr io.Writer, tty b var streamer streamProtocolHandler switch protocol { - case StreamProtocolV2Name: + case remotecommand.StreamProtocolV2Name: streamer = &streamProtocolV2{ stdin: stdin, stdout: stdout, @@ -165,9 +153,9 @@ func (e *streamExecutor) Stream(stdin io.Reader, stdout, stderr io.Writer, tty b tty: tty, } case "": - glog.V(4).Infof("The server did not negotiate a streaming protocol version. Falling back to %s", StreamProtocolV1Name) + glog.V(4).Infof("The server did not negotiate a streaming protocol version. Falling back to %s", remotecommand.StreamProtocolV1Name) fallthrough - case StreamProtocolV1Name: + case remotecommand.StreamProtocolV1Name: streamer = &streamProtocolV1{ stdin: stdin, stdout: stdout, diff --git a/pkg/client/unversioned/remotecommand/remotecommand_test.go b/pkg/client/unversioned/remotecommand/remotecommand_test.go index 2812d3fb315..f8010035c26 100644 --- a/pkg/client/unversioned/remotecommand/remotecommand_test.go +++ b/pkg/client/unversioned/remotecommand/remotecommand_test.go @@ -18,6 +18,7 @@ package remotecommand import ( "bytes" + "errors" "fmt" "io" "io/ioutil" @@ -26,325 +27,263 @@ import ( "net/url" "strings" "testing" + "time" "k8s.io/kubernetes/pkg/api" "k8s.io/kubernetes/pkg/api/unversioned" "k8s.io/kubernetes/pkg/client/restclient" + "k8s.io/kubernetes/pkg/kubelet/server/remotecommand" + "k8s.io/kubernetes/pkg/types" "k8s.io/kubernetes/pkg/util/httpstream" - "k8s.io/kubernetes/pkg/util/httpstream/spdy" ) -type streamAndReply struct { - httpstream.Stream - replySent <-chan struct{} +type fakeExecutor struct { + t *testing.T + testName string + errorData string + stdoutData string + stderrData string + expectStdin bool + stdinReceived bytes.Buffer + tty bool + messageCount int + command []string + exec bool } -func waitStreamReply(replySent <-chan struct{}, notify chan<- struct{}, stop <-chan struct{}) { - select { - case <-replySent: - notify <- struct{}{} - case <-stop: - } +func (ex *fakeExecutor) ExecInContainer(name string, uid types.UID, container string, cmd []string, in io.Reader, out, err io.WriteCloser, tty bool) error { + return ex.run(name, uid, container, cmd, in, out, err, tty) } -func fakeExecServer(t *testing.T, i int, stdinData, stdoutData, stderrData, errorData string, tty bool, messageCount int) http.HandlerFunc { - // error + stdin + stdout - expectedStreams := 3 - if !tty { - // stderr - expectedStreams++ +func (ex *fakeExecutor) AttachContainer(name string, uid types.UID, container string, in io.Reader, out, err io.WriteCloser, tty bool) error { + return ex.run(name, uid, container, nil, in, out, err, tty) +} + +func (ex *fakeExecutor) run(name string, uid types.UID, container string, cmd []string, in io.Reader, out, err io.WriteCloser, tty bool) error { + ex.command = cmd + ex.tty = tty + + if e, a := "pod", name; e != a { + ex.t.Errorf("%s: pod: expected %q, got %q", ex.testName, e, a) + } + if e, a := "uid", uid; e != string(a) { + ex.t.Errorf("%s: uid: expected %q, got %q", ex.testName, e, a) + } + if ex.exec { + if e, a := "ls /", strings.Join(ex.command, " "); e != a { + ex.t.Errorf("%s: command: expected %q, got %q", ex.testName, e, a) + } + } else { + if len(ex.command) > 0 { + ex.t.Errorf("%s: command: expected nothing, got %v", ex.testName, ex.command) + } } + if len(ex.errorData) > 0 { + return errors.New(ex.errorData) + } + + if len(ex.stdoutData) > 0 { + for i := 0; i < ex.messageCount; i++ { + fmt.Fprint(out, ex.stdoutData) + } + } + + if len(ex.stderrData) > 0 { + for i := 0; i < ex.messageCount; i++ { + fmt.Fprint(err, ex.stderrData) + } + } + + if ex.expectStdin { + io.Copy(&ex.stdinReceived, in) + } + + return nil +} + +func fakeServer(t *testing.T, testName string, exec bool, stdinData, stdoutData, stderrData, errorData string, tty bool, messageCount int, serverProtocols []string) http.HandlerFunc { return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - protocol, err := httpstream.Handshake(req, w, []string{StreamProtocolV2Name}, StreamProtocolV1Name) - if err != nil { - t.Fatal(err) - } - if protocol != StreamProtocolV2Name { - t.Fatalf("unexpected protocol: %s", protocol) - } - streamCh := make(chan streamAndReply) - - upgrader := spdy.NewResponseUpgrader() - conn := upgrader.UpgradeResponse(w, req, func(stream httpstream.Stream, replySent <-chan struct{}) error { - streamCh <- streamAndReply{Stream: stream, replySent: replySent} - return nil - }) - // from this point on, we can no longer call methods on w - if conn == nil { - // The upgrader is responsible for notifying the client of any errors that - // occurred during upgrading. All we can do is return here at this point - // if we weren't successful in upgrading. - return - } - defer conn.Close() - - var errorStream, stdinStream, stdoutStream, stderrStream httpstream.Stream - receivedStreams := 0 - replyChan := make(chan struct{}) - stop := make(chan struct{}) - defer close(stop) - WaitForStreams: - for { - select { - case stream := <-streamCh: - streamType := stream.Headers().Get(api.StreamType) - switch streamType { - case api.StreamTypeError: - errorStream = stream - go waitStreamReply(stream.replySent, replyChan, stop) - case api.StreamTypeStdin: - stdinStream = stream - go waitStreamReply(stream.replySent, replyChan, stop) - case api.StreamTypeStdout: - stdoutStream = stream - go waitStreamReply(stream.replySent, replyChan, stop) - case api.StreamTypeStderr: - stderrStream = stream - go waitStreamReply(stream.replySent, replyChan, stop) - default: - t.Errorf("%d: unexpected stream type: %q", i, streamType) - } - - if receivedStreams == expectedStreams { - break WaitForStreams - } - case <-replyChan: - receivedStreams++ - if receivedStreams == expectedStreams { - break WaitForStreams - } - } + executor := &fakeExecutor{ + t: t, + testName: testName, + errorData: errorData, + stdoutData: stdoutData, + stderrData: stderrData, + expectStdin: len(stdinData) > 0, + tty: tty, + messageCount: messageCount, + exec: exec, } - if len(errorData) > 0 { - n, err := fmt.Fprint(errorStream, errorData) - if err != nil { - t.Errorf("%d: error writing to errorStream: %v", i, err) - } - if e, a := len(errorData), n; e != a { - t.Errorf("%d: expected to write %d bytes to errorStream, but only wrote %d", i, e, a) - } - errorStream.Close() + if exec { + remotecommand.ServeExec(w, req, executor, "pod", "uid", "container", 0, 10*time.Second, serverProtocols) + } else { + remotecommand.ServeAttach(w, req, executor, "pod", "uid", "container", 0, 10*time.Second, serverProtocols) } - if len(stdoutData) > 0 { - for j := 0; j < messageCount; j++ { - n, err := fmt.Fprint(stdoutStream, stdoutData) - if err != nil { - t.Errorf("%d: error writing to stdoutStream: %v", i, err) - } - if e, a := len(stdoutData), n; e != a { - t.Errorf("%d: expected to write %d bytes to stdoutStream, but only wrote %d", i, e, a) - } - } - stdoutStream.Close() - } - if len(stderrData) > 0 { - for j := 0; j < messageCount; j++ { - n, err := fmt.Fprint(stderrStream, stderrData) - if err != nil { - t.Errorf("%d: error writing to stderrStream: %v", i, err) - } - if e, a := len(stderrData), n; e != a { - t.Errorf("%d: expected to write %d bytes to stderrStream, but only wrote %d", i, e, a) - } - } - stderrStream.Close() - } - if len(stdinData) > 0 { - data := make([]byte, len(stdinData)) - for j := 0; j < messageCount; j++ { - n, err := io.ReadFull(stdinStream, data) - if err != nil { - t.Errorf("%d: error reading stdin stream: %v", i, err) - } - if e, a := len(stdinData), n; e != a { - t.Errorf("%d: expected to read %d bytes from stdinStream, but only read %d", i, e, a) - } - if e, a := stdinData, string(data); e != a { - t.Errorf("%d: stdin: expected %q, got %q", i, e, a) - } - } - stdinStream.Close() + if e, a := strings.Repeat(stdinData, messageCount), executor.stdinReceived.String(); e != a { + t.Errorf("%s: stdin: expected %q, got %q", testName, e, a) } }) } -func TestRequestExecuteRemoteCommand(t *testing.T) { +func TestStream(t *testing.T) { testCases := []struct { - Stdin string - Stdout string - Stderr string - Error string - Tty bool - MessageCount int + TestName string + Stdin string + Stdout string + Stderr string + Error string + Tty bool + MessageCount int + ClientProtocols []string + ServerProtocols []string }{ { - Error: "bail", + TestName: "error", + Error: "bail", + Stdout: "a", + ClientProtocols: []string{remotecommand.StreamProtocolV2Name}, + ServerProtocols: []string{remotecommand.StreamProtocolV2Name}, }, { - Stdin: "a", - Stdout: "b", - Stderr: "c", - // TODO bump this to a larger number such as 100 once - // https://github.com/docker/spdystream/issues/55 is fixed and the Godep - // is bumped. Sending multiple messages over stdin/stdout/stderr results - // in more frames being spread across multiple spdystream frame workers. - // This makes it more likely that the spdystream bug will be encountered, - // where streams are closed as soon as a goaway frame is received, and - // any pending frames that haven't been processed yet may not be - // delivered (it's a race). - MessageCount: 1, + TestName: "in/out/err", + Stdin: "a", + Stdout: "b", + Stderr: "c", + MessageCount: 100, + ClientProtocols: []string{remotecommand.StreamProtocolV2Name}, + ServerProtocols: []string{remotecommand.StreamProtocolV2Name}, }, { - Stdin: "a", - Stdout: "b", - Tty: true, + TestName: "in/out/tty", + Stdin: "a", + Stdout: "b", + Tty: true, + MessageCount: 100, + ClientProtocols: []string{remotecommand.StreamProtocolV2Name}, + ServerProtocols: []string{remotecommand.StreamProtocolV2Name}, + }, + { + // 1.0 kubectl, 1.0 kubelet + TestName: "unversioned client, unversioned server", + Stdout: "b", + Stderr: "c", + MessageCount: 1, + ClientProtocols: []string{}, + ServerProtocols: []string{}, + }, + { + // 1.0 kubectl, 1.1+ kubelet + TestName: "unversioned client, versioned server", + Stdout: "b", + Stderr: "c", + MessageCount: 1, + ClientProtocols: []string{}, + ServerProtocols: []string{remotecommand.StreamProtocolV2Name, remotecommand.StreamProtocolV1Name}, + }, + { + // 1.1+ kubectl, 1.0 kubelet + TestName: "versioned client, unversioned server", + Stdout: "b", + Stderr: "c", + MessageCount: 1, + ClientProtocols: []string{remotecommand.StreamProtocolV2Name, remotecommand.StreamProtocolV1Name}, + ServerProtocols: []string{}, }, } - for i, testCase := range testCases { - localOut := &bytes.Buffer{} - localErr := &bytes.Buffer{} - - server := httptest.NewServer(fakeExecServer(t, i, testCase.Stdin, testCase.Stdout, testCase.Stderr, testCase.Error, testCase.Tty, testCase.MessageCount)) - - url, _ := url.ParseRequestURI(server.URL) - c := restclient.NewRESTClient(url, "", restclient.ContentConfig{GroupVersion: &unversioned.GroupVersion{Group: "x"}}, -1, -1, nil) - req := c.Post().Resource("testing") - req.SetHeader(httpstream.HeaderProtocolVersion, StreamProtocolV2Name) - req.Param("command", "ls") - req.Param("command", "/") - conf := &restclient.Config{ - Host: server.URL, - } - e, err := NewExecutor(conf, "POST", req.URL()) - if err != nil { - t.Errorf("%d: unexpected error: %v", i, err) - continue - } - err = e.Stream(strings.NewReader(strings.Repeat(testCase.Stdin, testCase.MessageCount)), localOut, localErr, testCase.Tty) - hasErr := err != nil - - if len(testCase.Error) > 0 { - if !hasErr { - t.Errorf("%d: expected an error", i) + for _, testCase := range testCases { + for _, exec := range []bool{true, false} { + var name string + if exec { + name = testCase.TestName + " (exec)" } else { - if e, a := testCase.Error, err.Error(); !strings.Contains(a, e) { - t.Errorf("%d: expected error stream read '%v', got '%v'", i, e, a) + name = testCase.TestName + " (attach)" + } + var ( + streamIn io.Reader + streamOut, streamErr io.Writer + ) + localOut := &bytes.Buffer{} + localErr := &bytes.Buffer{} + + server := httptest.NewServer(fakeServer(t, name, exec, testCase.Stdin, testCase.Stdout, testCase.Stderr, testCase.Error, testCase.Tty, testCase.MessageCount, testCase.ServerProtocols)) + + url, _ := url.ParseRequestURI(server.URL) + c := restclient.NewRESTClient(url, "", restclient.ContentConfig{GroupVersion: &unversioned.GroupVersion{Group: "x"}}, -1, -1, nil) + req := c.Post().Resource("testing") + + if exec { + req.Param("command", "ls") + req.Param("command", "/") + } + + if len(testCase.Stdin) > 0 { + req.Param(api.ExecStdinParam, "1") + streamIn = strings.NewReader(strings.Repeat(testCase.Stdin, testCase.MessageCount)) + } + + if len(testCase.Stdout) > 0 { + req.Param(api.ExecStdoutParam, "1") + streamOut = localOut + } + + if testCase.Tty { + req.Param(api.ExecTTYParam, "1") + } else if len(testCase.Stderr) > 0 { + req.Param(api.ExecStderrParam, "1") + streamErr = localErr + } + + conf := &restclient.Config{ + Host: server.URL, + } + e, err := NewExecutor(conf, "POST", req.URL()) + if err != nil { + t.Errorf("%s: unexpected error: %v", name, err) + continue + } + err = e.Stream(testCase.ClientProtocols, streamIn, streamOut, streamErr, testCase.Tty) + hasErr := err != nil + + if len(testCase.Error) > 0 { + if !hasErr { + t.Errorf("%s: expected an error", name) + } else { + if e, a := testCase.Error, err.Error(); !strings.Contains(a, e) { + t.Errorf("%s: expected error stream read %q, got %q", name, e, a) + } + } + + // TODO: Uncomment when fix #19254 + // server.Close() + continue + } + + if hasErr { + t.Errorf("%s: unexpected error: %v", name, err) + // TODO: Uncomment when fix #19254 + // server.Close() + continue + } + + if len(testCase.Stdout) > 0 { + if e, a := strings.Repeat(testCase.Stdout, testCase.MessageCount), localOut; e != a.String() { + t.Errorf("%s: expected stdout data '%s', got '%s'", name, e, a) + } + } + + if testCase.Stderr != "" { + if e, a := strings.Repeat(testCase.Stderr, testCase.MessageCount), localErr; e != a.String() { + t.Errorf("%s: expected stderr data '%s', got '%s'", name, e, a) } } // TODO: Uncomment when fix #19254 // server.Close() - continue } - - if hasErr { - t.Errorf("%d: unexpected error: %v", i, err) - // TODO: Uncomment when fix #19254 - // server.Close() - continue - } - - if len(testCase.Stdout) > 0 { - if e, a := strings.Repeat(testCase.Stdout, testCase.MessageCount), localOut; e != a.String() { - t.Errorf("%d: expected stdout data '%s', got '%s'", i, e, a) - } - } - - if testCase.Stderr != "" { - if e, a := strings.Repeat(testCase.Stderr, testCase.MessageCount), localErr; e != a.String() { - t.Errorf("%d: expected stderr data '%s', got '%s'", i, e, a) - } - } - - // TODO: Uncomment when fix #19254 - // server.Close() - } -} - -// TODO: this test is largely cut and paste, refactor to share code -func TestRequestAttachRemoteCommand(t *testing.T) { - testCases := []struct { - Stdin string - Stdout string - Stderr string - Error string - Tty bool - }{ - { - Error: "bail", - }, - { - Stdin: "a", - Stdout: "b", - Stderr: "c", - }, - { - Stdin: "a", - Stdout: "b", - Tty: true, - }, - } - - for i, testCase := range testCases { - localOut := &bytes.Buffer{} - localErr := &bytes.Buffer{} - - server := httptest.NewServer(fakeExecServer(t, i, testCase.Stdin, testCase.Stdout, testCase.Stderr, testCase.Error, testCase.Tty, 1)) - - url, _ := url.ParseRequestURI(server.URL) - c := restclient.NewRESTClient(url, "", restclient.ContentConfig{GroupVersion: &unversioned.GroupVersion{Group: "x"}}, -1, -1, nil) - req := c.Post().Resource("testing") - - conf := &restclient.Config{ - Host: server.URL, - } - e, err := NewExecutor(conf, "POST", req.URL()) - if err != nil { - t.Errorf("%d: unexpected error: %v", i, err) - continue - } - err = e.Stream(strings.NewReader(testCase.Stdin), localOut, localErr, testCase.Tty) - hasErr := err != nil - - if len(testCase.Error) > 0 { - if !hasErr { - t.Errorf("%d: expected an error", i) - } else { - if e, a := testCase.Error, err.Error(); !strings.Contains(a, e) { - t.Errorf("%d: expected error stream read '%v', got '%v'", i, e, a) - } - } - - // TODO: Uncomment when fix #19254 - // server.Close() - continue - } - - if hasErr { - t.Errorf("%d: unexpected error: %v", i, err) - // TODO: Uncomment when fix #19254 - // server.Close() - continue - } - - if len(testCase.Stdout) > 0 { - if e, a := testCase.Stdout, localOut; e != a.String() { - t.Errorf("%d: expected stdout data '%s', got '%s'", i, e, a) - } - } - - if testCase.Stderr != "" { - if e, a := testCase.Stderr, localErr; e != a.String() { - t.Errorf("%d: expected stderr data '%s', got '%s'", i, e, a) - } - } - - // TODO: Uncomment when fix #19254 - // server.Close() } } diff --git a/pkg/kubectl/cmd/attach.go b/pkg/kubectl/cmd/attach.go index 75254f9d9c5..3150bccaee5 100644 --- a/pkg/kubectl/cmd/attach.go +++ b/pkg/kubectl/cmd/attach.go @@ -29,6 +29,7 @@ import ( client "k8s.io/kubernetes/pkg/client/unversioned" "k8s.io/kubernetes/pkg/client/unversioned/remotecommand" cmdutil "k8s.io/kubernetes/pkg/kubectl/cmd/util" + remotecommandserver "k8s.io/kubernetes/pkg/kubelet/server/remotecommand" utilerrors "k8s.io/kubernetes/pkg/util/errors" "k8s.io/kubernetes/pkg/util/interrupt" "k8s.io/kubernetes/pkg/util/term" @@ -87,7 +88,7 @@ func (*DefaultRemoteAttach) Attach(method string, url *url.URL, config *restclie if err != nil { return err } - return exec.Stream(stdin, stdout, stderr, tty) + return exec.Stream(remotecommandserver.SupportedStreamingProtocols, stdin, stdout, stderr, tty) } // AttachOptions declare the arguments accepted by the Exec command diff --git a/pkg/kubectl/cmd/exec.go b/pkg/kubectl/cmd/exec.go index ca981b44e5d..e5992587353 100644 --- a/pkg/kubectl/cmd/exec.go +++ b/pkg/kubectl/cmd/exec.go @@ -32,6 +32,7 @@ import ( client "k8s.io/kubernetes/pkg/client/unversioned" "k8s.io/kubernetes/pkg/client/unversioned/remotecommand" cmdutil "k8s.io/kubernetes/pkg/kubectl/cmd/util" + remotecommandserver "k8s.io/kubernetes/pkg/kubelet/server/remotecommand" ) const ( @@ -87,7 +88,7 @@ func (*DefaultRemoteExecutor) Execute(method string, url *url.URL, config *restc if err != nil { return err } - return exec.Stream(stdin, stdout, stderr, tty) + return exec.Stream(remotecommandserver.SupportedStreamingProtocols, stdin, stdout, stderr, tty) } // ExecOptions declare the arguments accepted by the Exec command diff --git a/pkg/kubelet/server/remotecommand/attach.go b/pkg/kubelet/server/remotecommand/attach.go new file mode 100644 index 00000000000..0f9ba7ff5e3 --- /dev/null +++ b/pkg/kubelet/server/remotecommand/attach.go @@ -0,0 +1,53 @@ +/* +Copyright 2016 The Kubernetes Authors All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package remotecommand + +import ( + "errors" + "fmt" + "io" + "net/http" + "time" + + "k8s.io/kubernetes/pkg/types" + "k8s.io/kubernetes/pkg/util/runtime" +) + +// Attacher knows how to attach to a running container in a pod. +type Attacher interface { + // AttachContainer attaches to the running container in the pod, copying data between in/out/err + // and the container's stdin/stdout/stderr. + AttachContainer(name string, uid types.UID, container string, in io.Reader, out, err io.WriteCloser, tty bool) error +} + +// ServeAttach handles requests to attach to a container. After creating/receiving the required +// streams, it delegates the actual attaching to attacher. +func ServeAttach(w http.ResponseWriter, req *http.Request, attacher Attacher, podName string, uid types.UID, container string, idleTimeout, streamCreationTimeout time.Duration, supportedProtocols []string) { + ctx, ok := createStreams(req, w, supportedProtocols, idleTimeout, streamCreationTimeout) + if !ok { + // error is handled by createStreams + return + } + defer ctx.conn.Close() + + err := attacher.AttachContainer(podName, uid, container, ctx.stdinStream, ctx.stdoutStream, ctx.stderrStream, ctx.tty) + if err != nil { + msg := fmt.Sprintf("error attaching to container: %v", err) + runtime.HandleError(errors.New(msg)) + fmt.Fprint(ctx.errorStream, msg) + } +} diff --git a/pkg/kubelet/server/remotecommand/contants.go b/pkg/kubelet/server/remotecommand/contants.go new file mode 100644 index 00000000000..f45cc644032 --- /dev/null +++ b/pkg/kubelet/server/remotecommand/contants.go @@ -0,0 +1,36 @@ +/* +Copyright 2016 The Kubernetes Authors All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package remotecommand + +import "time" + +const ( + DefaultStreamCreationTimeout = 30 * time.Second + + // The SPDY subprotocol "channel.k8s.io" is used for remote command + // attachment/execution. This represents the initial unversioned subprotocol, + // which has the known bugs http://issues.k8s.io/13394 and + // http://issues.k8s.io/13395. + StreamProtocolV1Name = "channel.k8s.io" + + // The SPDY subprotocol "v2.channel.k8s.io" is used for remote command + // attachment/execution. It is the second version of the subprotocol and + // resolves the issues present in the first version. + StreamProtocolV2Name = "v2.channel.k8s.io" +) + +var SupportedStreamingProtocols = []string{StreamProtocolV2Name, StreamProtocolV1Name} diff --git a/pkg/kubelet/server/remotecommand/doc.go b/pkg/kubelet/server/remotecommand/doc.go new file mode 100644 index 00000000000..482e9afc1f6 --- /dev/null +++ b/pkg/kubelet/server/remotecommand/doc.go @@ -0,0 +1,18 @@ +/* +Copyright 2016 The Kubernetes Authors All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// package remotecommand contains functions related to executing commands in and attaching to pods. +package remotecommand diff --git a/pkg/kubelet/server/remotecommand/exec.go b/pkg/kubelet/server/remotecommand/exec.go new file mode 100644 index 00000000000..df9a4b5854e --- /dev/null +++ b/pkg/kubelet/server/remotecommand/exec.go @@ -0,0 +1,57 @@ +/* +Copyright 2016 The Kubernetes Authors All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package remotecommand + +import ( + "errors" + "fmt" + "io" + "net/http" + "time" + + "k8s.io/kubernetes/pkg/api" + "k8s.io/kubernetes/pkg/types" + "k8s.io/kubernetes/pkg/util/runtime" +) + +// Executor knows how to execute a command in a container in a pod. +type Executor interface { + // ExecInContainer executes a command in a container in the pod, copying data + // between in/out/err and the container's stdin/stdout/stderr. + ExecInContainer(name string, uid types.UID, container string, cmd []string, in io.Reader, out, err io.WriteCloser, tty bool) error +} + +// ServeExec handles requests to execute a command in a container. After +// creating/receiving the required streams, it delegates the actual execution +// to the executor. +func ServeExec(w http.ResponseWriter, req *http.Request, executor Executor, podName string, uid types.UID, container string, idleTimeout, streamCreationTimeout time.Duration, supportedProtocols []string) { + ctx, ok := createStreams(req, w, supportedProtocols, idleTimeout, streamCreationTimeout) + if !ok { + // error is handled by createStreams + return + } + defer ctx.conn.Close() + + cmd := req.URL.Query()[api.ExecCommandParamm] + + err := executor.ExecInContainer(podName, uid, container, cmd, ctx.stdinStream, ctx.stdoutStream, ctx.stderrStream, ctx.tty) + if err != nil { + msg := fmt.Sprintf("error executing command in container: %v", err) + runtime.HandleError(errors.New(msg)) + fmt.Fprint(ctx.errorStream, msg) + } +} diff --git a/pkg/kubelet/server/remotecommand/httpstream.go b/pkg/kubelet/server/remotecommand/httpstream.go new file mode 100644 index 00000000000..4b0c588e9fa --- /dev/null +++ b/pkg/kubelet/server/remotecommand/httpstream.go @@ -0,0 +1,277 @@ +/* +Copyright 2016 The Kubernetes Authors All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package remotecommand + +import ( + "errors" + "fmt" + "io" + "net/http" + "time" + + "k8s.io/kubernetes/pkg/api" + "k8s.io/kubernetes/pkg/util/httpstream" + "k8s.io/kubernetes/pkg/util/httpstream/spdy" + "k8s.io/kubernetes/pkg/util/runtime" + "k8s.io/kubernetes/pkg/util/wsstream" + + "github.com/golang/glog" +) + +// options contains details about which streams are required for +// remote command execution. +type options struct { + stdin bool + stdout bool + stderr bool + tty bool + expectedStreams int +} + +// newOptions creates a new options from the Request. +func newOptions(req *http.Request) (*options, error) { + tty := req.FormValue(api.ExecTTYParam) == "1" + stdin := req.FormValue(api.ExecStdinParam) == "1" + stdout := req.FormValue(api.ExecStdoutParam) == "1" + stderr := req.FormValue(api.ExecStderrParam) == "1" + if tty && stderr { + // TODO: make this an error before we reach this method + glog.V(4).Infof("Access to exec with tty and stderr is not supported, bypassing stderr") + stderr = false + } + + // count the streams client asked for, starting with 1 + expectedStreams := 1 + if stdin { + expectedStreams++ + } + if stdout { + expectedStreams++ + } + if stderr { + expectedStreams++ + } + + if expectedStreams == 1 { + return nil, fmt.Errorf("you must specify at least 1 of stdin, stdout, stderr") + } + + return &options{ + stdin: stdin, + stdout: stdout, + stderr: stderr, + tty: tty, + expectedStreams: expectedStreams, + }, nil +} + +// context contains the connection and streams used when +// forwarding an attach or execute session into a container. +type context struct { + conn io.Closer + stdinStream io.ReadCloser + stdoutStream io.WriteCloser + stderrStream io.WriteCloser + errorStream io.WriteCloser + tty bool +} + +// streamAndReply holds both a Stream and a channel that is closed when the stream's reply frame is +// enqueued. Consumers can wait for replySent to be closed prior to proceeding, to ensure that the +// replyFrame is enqueued before the connection's goaway frame is sent (e.g. if a stream was +// received and right after, the connection gets closed). +type streamAndReply struct { + httpstream.Stream + replySent <-chan struct{} +} + +// waitStreamReply waits until either replySent or stop is closed. If replySent is closed, it sends +// an empty struct to the notify channel. +func waitStreamReply(replySent <-chan struct{}, notify chan<- struct{}, stop <-chan struct{}) { + select { + case <-replySent: + notify <- struct{}{} + case <-stop: + } +} + +func createStreams(req *http.Request, w http.ResponseWriter, supportedStreamProtocols []string, idleTimeout, streamCreationTimeout time.Duration) (*context, bool) { + opts, err := newOptions(req) + if err != nil { + runtime.HandleError(err) + w.WriteHeader(http.StatusBadRequest) + fmt.Fprint(w, err.Error()) + return nil, false + } + + if wsstream.IsWebSocketRequest(req) { + return createWebSocketStreams(req, w, opts, idleTimeout) + } + + protocol, err := httpstream.Handshake(req, w, supportedStreamProtocols) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + fmt.Fprint(w, err.Error()) + return nil, false + } + + streamCh := make(chan streamAndReply) + + upgrader := spdy.NewResponseUpgrader() + conn := upgrader.UpgradeResponse(w, req, func(stream httpstream.Stream, replySent <-chan struct{}) error { + streamCh <- streamAndReply{Stream: stream, replySent: replySent} + return nil + }) + // from this point on, we can no longer call methods on response + if conn == nil { + // The upgrader is responsible for notifying the client of any errors that + // occurred during upgrading. All we can do is return here at this point + // if we weren't successful in upgrading. + return nil, false + } + + conn.SetIdleTimeout(idleTimeout) + + var handler protocolHandler + switch protocol { + case StreamProtocolV2Name: + handler = &v2ProtocolHandler{} + case "": + glog.V(4).Infof("Client did not request protocol negotiaion. Falling back to %q", StreamProtocolV1Name) + fallthrough + case StreamProtocolV1Name: + handler = &v1ProtocolHandler{} + } + + expired := time.NewTimer(streamCreationTimeout) + + ctx, err := handler.waitForStreams(streamCh, opts.expectedStreams, expired.C) + if err != nil { + runtime.HandleError(err) + return nil, false + } + + ctx.conn = conn + ctx.tty = opts.tty + return ctx, true +} + +type protocolHandler interface { + // waitForStreams waits for the expected streams or a timeout, returning a + // remoteCommandContext if all the streams were received, or an error if not. + waitForStreams(streams <-chan streamAndReply, expectedStreams int, expired <-chan time.Time) (*context, error) +} + +// v2ProtocolHandler implements the V2 protocol version for streaming command execution. +type v2ProtocolHandler struct{} + +func (*v2ProtocolHandler) waitForStreams(streams <-chan streamAndReply, expectedStreams int, expired <-chan time.Time) (*context, error) { + ctx := &context{} + receivedStreams := 0 + replyChan := make(chan struct{}) + stop := make(chan struct{}) + defer close(stop) +WaitForStreams: + for { + select { + case stream := <-streams: + streamType := stream.Headers().Get(api.StreamType) + switch streamType { + case api.StreamTypeError: + ctx.errorStream = stream + go waitStreamReply(stream.replySent, replyChan, stop) + case api.StreamTypeStdin: + ctx.stdinStream = stream + go waitStreamReply(stream.replySent, replyChan, stop) + case api.StreamTypeStdout: + ctx.stdoutStream = stream + go waitStreamReply(stream.replySent, replyChan, stop) + case api.StreamTypeStderr: + ctx.stderrStream = stream + go waitStreamReply(stream.replySent, replyChan, stop) + default: + runtime.HandleError(fmt.Errorf("Unexpected stream type: %q", streamType)) + } + case <-replyChan: + receivedStreams++ + if receivedStreams == expectedStreams { + break WaitForStreams + } + case <-expired: + // TODO find a way to return the error to the user. Maybe use a separate + // stream to report errors? + return nil, errors.New("timed out waiting for client to create streams") + } + } + + return ctx, nil +} + +// v1ProtocolHandler implements the V1 protocol version for streaming command execution. +type v1ProtocolHandler struct{} + +func (*v1ProtocolHandler) waitForStreams(streams <-chan streamAndReply, expectedStreams int, expired <-chan time.Time) (*context, error) { + ctx := &context{} + receivedStreams := 0 + replyChan := make(chan struct{}) + stop := make(chan struct{}) + defer close(stop) +WaitForStreams: + for { + select { + case stream := <-streams: + streamType := stream.Headers().Get(api.StreamType) + switch streamType { + case api.StreamTypeError: + ctx.errorStream = stream + + // This defer statement shouldn't be here, but due to previous refactoring, it ended up in + // here. This is what 1.0.x kubelets do, so we're retaining that behavior. This is fixed in + // the v2ProtocolHandler. + defer stream.Reset() + + go waitStreamReply(stream.replySent, replyChan, stop) + case api.StreamTypeStdin: + ctx.stdinStream = stream + go waitStreamReply(stream.replySent, replyChan, stop) + case api.StreamTypeStdout: + ctx.stdoutStream = stream + go waitStreamReply(stream.replySent, replyChan, stop) + case api.StreamTypeStderr: + ctx.stderrStream = stream + go waitStreamReply(stream.replySent, replyChan, stop) + default: + runtime.HandleError(fmt.Errorf("Unexpected stream type: %q", streamType)) + } + case <-replyChan: + receivedStreams++ + if receivedStreams == expectedStreams { + break WaitForStreams + } + case <-expired: + // TODO find a way to return the error to the user. Maybe use a separate + // stream to report errors? + return nil, errors.New("timed out waiting for client to create streams") + } + } + + if ctx.stdinStream != nil { + ctx.stdinStream.Close() + } + + return ctx, nil +} diff --git a/pkg/kubelet/server/remotecommand/websocket.go b/pkg/kubelet/server/remotecommand/websocket.go new file mode 100644 index 00000000000..06a84c8e7d9 --- /dev/null +++ b/pkg/kubelet/server/remotecommand/websocket.go @@ -0,0 +1,77 @@ +/* +Copyright 2016 The Kubernetes Authors All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package remotecommand + +import ( + "net/http" + "time" + + "k8s.io/kubernetes/pkg/httplog" + "k8s.io/kubernetes/pkg/util/wsstream" + + "github.com/golang/glog" +) + +// standardShellChannels returns the standard channel types for a shell connection (STDIN 0, STDOUT 1, STDERR 2) +// along with the approximate duplex value. Supported subprotocols are "channel.k8s.io" and +// "base64.channel.k8s.io". +func standardShellChannels(stdin, stdout, stderr bool) []wsstream.ChannelType { + // open three half-duplex channels + channels := []wsstream.ChannelType{wsstream.ReadChannel, wsstream.WriteChannel, wsstream.WriteChannel} + if !stdin { + channels[0] = wsstream.IgnoreChannel + } + if !stdout { + channels[1] = wsstream.IgnoreChannel + } + if !stderr { + channels[2] = wsstream.IgnoreChannel + } + return channels +} + +// createWebSocketStreams returns a remoteCommandContext containing the websocket connection and +// streams needed to perform an exec or an attach. +func createWebSocketStreams(req *http.Request, w http.ResponseWriter, opts *options, idleTimeout time.Duration) (*context, bool) { + // open the requested channels, and always open the error channel + channels := append(standardShellChannels(opts.stdin, opts.stdout, opts.stderr), wsstream.WriteChannel) + conn := wsstream.NewConn(channels...) + conn.SetIdleTimeout(idleTimeout) + streams, err := conn.Open(httplog.Unlogged(w), req) + if err != nil { + glog.Errorf("Unable to upgrade websocket connection: %v", err) + return nil, false + } + // Send an empty message to the lowest writable channel to notify the client the connection is established + // TODO: make generic to SPDY and WebSockets and do it outside of this method? + switch { + case opts.stdout: + streams[1].Write([]byte{}) + case opts.stderr: + streams[2].Write([]byte{}) + default: + streams[3].Write([]byte{}) + } + return &context{ + conn: conn, + stdinStream: streams[0], + stdoutStream: streams[1], + stderrStream: streams[2], + errorStream: streams[3], + tty: opts.tty, + }, true +} diff --git a/pkg/kubelet/server/server.go b/pkg/kubelet/server/server.go index 44539aee01d..aef80c33974 100644 --- a/pkg/kubelet/server/server.go +++ b/pkg/kubelet/server/server.go @@ -43,12 +43,12 @@ import ( "k8s.io/kubernetes/pkg/api/validation" "k8s.io/kubernetes/pkg/auth/authenticator" "k8s.io/kubernetes/pkg/auth/authorizer" - "k8s.io/kubernetes/pkg/client/unversioned/remotecommand" "k8s.io/kubernetes/pkg/healthz" "k8s.io/kubernetes/pkg/httplog" "k8s.io/kubernetes/pkg/kubelet/cm" kubecontainer "k8s.io/kubernetes/pkg/kubelet/container" "k8s.io/kubernetes/pkg/kubelet/server/portforward" + "k8s.io/kubernetes/pkg/kubelet/server/remotecommand" "k8s.io/kubernetes/pkg/kubelet/server/stats" "k8s.io/kubernetes/pkg/runtime" "k8s.io/kubernetes/pkg/types" @@ -58,7 +58,6 @@ import ( "k8s.io/kubernetes/pkg/util/httpstream/spdy" "k8s.io/kubernetes/pkg/util/limitwriter" utilruntime "k8s.io/kubernetes/pkg/util/runtime" - "k8s.io/kubernetes/pkg/util/wsstream" "k8s.io/kubernetes/pkg/volume" ) @@ -540,12 +539,7 @@ func getContainerCoordinates(request *restful.Request) (namespace, pod string, u return } -const defaultStreamCreationTimeout = 30 * time.Second - -type Closer interface { - Close() error -} - +// getAttach handles requests to attach to a container. func (s *Server) getAttach(request *restful.Request, response *restful.Response) { podNamespace, podID, uid, container := getContainerCoordinates(request) pod, ok := s.host.GetPodByName(podNamespace, podID) @@ -554,21 +548,35 @@ func (s *Server) getAttach(request *restful.Request, response *restful.Response) return } - stdinStream, stdoutStream, stderrStream, errorStream, conn, tty, ok := s.createStreams(request, response) - if conn != nil { - defer conn.Close() - } + remotecommand.ServeAttach(response.ResponseWriter, + request.Request, + s.host, + kubecontainer.GetPodFullName(pod), + uid, + container, + s.host.StreamingConnectionIdleTimeout(), + remotecommand.DefaultStreamCreationTimeout, + remotecommand.SupportedStreamingProtocols) +} + +// getExec handles requests to run a command inside a container. +func (s *Server) getExec(request *restful.Request, response *restful.Response) { + podNamespace, podID, uid, container := getContainerCoordinates(request) + pod, ok := s.host.GetPodByName(podNamespace, podID) if !ok { - // error is handled in the createStreams function + response.WriteError(http.StatusNotFound, fmt.Errorf("pod does not exist")) return } - err := s.host.AttachContainer(kubecontainer.GetPodFullName(pod), uid, container, stdinStream, stdoutStream, stderrStream, tty) - if err != nil { - msg := fmt.Sprintf("Error executing command in container: %v", err) - glog.Error(msg) - errorStream.Write([]byte(msg)) - } + remotecommand.ServeExec(response.ResponseWriter, + request.Request, + s.host, + kubecontainer.GetPodFullName(pod), + uid, + container, + s.host.StreamingConnectionIdleTimeout(), + remotecommand.DefaultStreamCreationTimeout, + remotecommand.SupportedStreamingProtocols) } // getRun handles requests to run a command inside a container. @@ -588,187 +596,6 @@ func (s *Server) getRun(request *restful.Request, response *restful.Response) { writeJsonResponse(response, data) } -// getExec handles requests to run a command inside a container. -func (s *Server) getExec(request *restful.Request, response *restful.Response) { - podNamespace, podID, uid, container := getContainerCoordinates(request) - pod, ok := s.host.GetPodByName(podNamespace, podID) - if !ok { - response.WriteError(http.StatusNotFound, fmt.Errorf("pod does not exist")) - return - } - stdinStream, stdoutStream, stderrStream, errorStream, conn, tty, ok := s.createStreams(request, response) - if conn != nil { - defer conn.Close() - } - if !ok { - // error is handled in the createStreams function - return - } - cmd := request.Request.URL.Query()[api.ExecCommandParamm] - err := s.host.ExecInContainer(kubecontainer.GetPodFullName(pod), uid, container, cmd, stdinStream, stdoutStream, stderrStream, tty) - if err != nil { - msg := fmt.Sprintf("Error executing command in container: %v", err) - glog.Error(msg) - errorStream.Write([]byte(msg)) - } -} - -// standardShellChannels returns the standard channel types for a shell connection (STDIN 0, STDOUT 1, STDERR 2) -// along with the approprxate duplex value -func standardShellChannels(stdin, stdout, stderr bool) []wsstream.ChannelType { - // open three half-duplex channels - channels := []wsstream.ChannelType{wsstream.ReadChannel, wsstream.WriteChannel, wsstream.WriteChannel} - if !stdin { - channels[0] = wsstream.IgnoreChannel - } - if !stdout { - channels[1] = wsstream.IgnoreChannel - } - if !stderr { - channels[2] = wsstream.IgnoreChannel - } - return channels -} - -// streamAndReply holds both a Stream and a channel that is closed when the stream's reply frame is -// enqueued. Consumers can wait for replySent to be closed prior to proceeding, to ensure that the -// replyFrame is enqueued before the connection's goaway frame is sent (e.g. if a stream was -// received and right after, the connection gets closed). -type streamAndReply struct { - httpstream.Stream - replySent <-chan struct{} -} - -func (s *Server) createStreams(request *restful.Request, response *restful.Response) (io.Reader, io.WriteCloser, io.WriteCloser, io.WriteCloser, Closer, bool, bool) { - tty := request.QueryParameter(api.ExecTTYParam) == "1" - stdin := request.QueryParameter(api.ExecStdinParam) == "1" - stdout := request.QueryParameter(api.ExecStdoutParam) == "1" - stderr := request.QueryParameter(api.ExecStderrParam) == "1" - if tty && stderr { - // TODO: make this an error before we reach this method - glog.V(4).Infof("Access to exec with tty and stderr is not supported, bypassing stderr") - stderr = false - } - - // count the streams client asked for, starting with 1 - expectedStreams := 1 - if stdin { - expectedStreams++ - } - if stdout { - expectedStreams++ - } - if stderr { - expectedStreams++ - } - - if expectedStreams == 1 { - response.WriteError(http.StatusBadRequest, fmt.Errorf("you must specify at least 1 of stdin, stdout, stderr")) - return nil, nil, nil, nil, nil, false, false - } - - if wsstream.IsWebSocketRequest(request.Request) { - // open the requested channels, and always open the error channel - channels := append(standardShellChannels(stdin, stdout, stderr), wsstream.WriteChannel) - conn := wsstream.NewConn(channels...) - conn.SetIdleTimeout(s.host.StreamingConnectionIdleTimeout()) - streams, err := conn.Open(httplog.Unlogged(response.ResponseWriter), request.Request) - if err != nil { - glog.Errorf("Unable to upgrade websocket connection: %v", err) - return nil, nil, nil, nil, nil, false, false - } - // Send an empty message to the lowest writable channel to notify the client the connection is established - // TODO: make generic to SDPY and WebSockets and do it outside of this method? - switch { - case stdout: - streams[1].Write([]byte{}) - case stderr: - streams[2].Write([]byte{}) - default: - streams[3].Write([]byte{}) - } - return streams[0], streams[1], streams[2], streams[3], conn, tty, true - } - - supportedStreamProtocols := []string{remotecommand.StreamProtocolV2Name, remotecommand.StreamProtocolV1Name} - _, err := httpstream.Handshake(request.Request, response.ResponseWriter, supportedStreamProtocols, remotecommand.StreamProtocolV1Name) - // negotiated protocol isn't used server side at the moment, but could be in the future - if err != nil { - return nil, nil, nil, nil, nil, false, false - } - - streamCh := make(chan streamAndReply) - - upgrader := spdy.NewResponseUpgrader() - conn := upgrader.UpgradeResponse(response.ResponseWriter, request.Request, func(stream httpstream.Stream, replySent <-chan struct{}) error { - streamCh <- streamAndReply{Stream: stream, replySent: replySent} - return nil - }) - // from this point on, we can no longer call methods on response - if conn == nil { - // The upgrader is responsible for notifying the client of any errors that - // occurred during upgrading. All we can do is return here at this point - // if we weren't successful in upgrading. - return nil, nil, nil, nil, nil, false, false - } - - conn.SetIdleTimeout(s.host.StreamingConnectionIdleTimeout()) - - // TODO make it configurable? - expired := time.NewTimer(defaultStreamCreationTimeout) - - var errorStream, stdinStream, stdoutStream, stderrStream httpstream.Stream - receivedStreams := 0 - replyChan := make(chan struct{}) - stop := make(chan struct{}) - defer close(stop) -WaitForStreams: - for { - select { - case stream := <-streamCh: - streamType := stream.Headers().Get(api.StreamType) - switch streamType { - case api.StreamTypeError: - errorStream = stream - go waitStreamReply(stream.replySent, replyChan, stop) - case api.StreamTypeStdin: - stdinStream = stream - go waitStreamReply(stream.replySent, replyChan, stop) - case api.StreamTypeStdout: - stdoutStream = stream - go waitStreamReply(stream.replySent, replyChan, stop) - case api.StreamTypeStderr: - stderrStream = stream - go waitStreamReply(stream.replySent, replyChan, stop) - default: - glog.Errorf("Unexpected stream type: '%s'", streamType) - } - case <-replyChan: - receivedStreams++ - if receivedStreams == expectedStreams { - break WaitForStreams - } - case <-expired.C: - // TODO find a way to return the error to the user. Maybe use a separate - // stream to report errors? - glog.Error("Timed out waiting for client to create streams") - return nil, nil, nil, nil, nil, false, false - } - } - - return stdinStream, stdoutStream, stderrStream, errorStream, conn, tty, true -} - -// waitStreamReply waits until either replySent or stop is closed. If replySent is closed, it sends -// an empty struct to the notify channel. -func waitStreamReply(replySent <-chan struct{}, notify chan<- struct{}, stop <-chan struct{}) { - select { - case <-replySent: - notify <- struct{}{} - case <-stop: - } -} - func getPodCoordinates(request *restful.Request) (namespace, pod string, uid types.UID) { namespace = request.PathParameter("podNamespace") pod = request.PathParameter("podID") @@ -811,7 +638,7 @@ func (s *Server) getPortForward(request *restful.Request, response *restful.Resp podName := kubecontainer.GetPodFullName(pod) - ServePortForward(response.ResponseWriter, request.Request, s.host, podName, uid, s.host.StreamingConnectionIdleTimeout(), defaultStreamCreationTimeout) + ServePortForward(response.ResponseWriter, request.Request, s.host, podName, uid, s.host.StreamingConnectionIdleTimeout(), remotecommand.DefaultStreamCreationTimeout) } // ServePortForward handles a port forwarding request. A single request is @@ -821,7 +648,7 @@ func (s *Server) getPortForward(request *restful.Request, response *restful.Resp // handled by a single invocation of ServePortForward. func ServePortForward(w http.ResponseWriter, req *http.Request, portForwarder PortForwarder, podName string, uid types.UID, idleTimeout time.Duration, streamCreationTimeout time.Duration) { supportedPortForwardProtocols := []string{portforward.PortForwardProtocolV1Name} - _, err := httpstream.Handshake(req, w, supportedPortForwardProtocols, portforward.PortForwardProtocolV1Name) + _, err := httpstream.Handshake(req, w, supportedPortForwardProtocols) // negotiated protocol isn't currently used server side, but could be in the future if err != nil { // Handshake writes the error to the client diff --git a/pkg/kubelet/server/server_test.go b/pkg/kubelet/server/server_test.go index 7b5b4fc1e74..c52413b6f9f 100644 --- a/pkg/kubelet/server/server_test.go +++ b/pkg/kubelet/server/server_test.go @@ -1019,7 +1019,7 @@ func TestServeExecInContainerIdleTimeout(t *testing.T) { <-conn.CloseChan() } -func TestServeExecInContainer(t *testing.T) { +func testExecAttach(t *testing.T, verb string) { tests := []struct { stdin bool stdout bool @@ -1053,12 +1053,15 @@ func TestServeExecInContainer(t *testing.T) { expectedStdin := "stdin" expectedStdout := "stdout" expectedStderr := "stderr" - execFuncDone := make(chan struct{}) + done := make(chan struct{}) clientStdoutReadDone := make(chan struct{}) clientStderrReadDone := make(chan struct{}) + execInvoked := false + attachInvoked := false + + testStreamFunc := func(podFullName string, uid types.UID, containerName string, cmd []string, in io.Reader, out, stderr io.WriteCloser, tty bool, done chan struct{}) error { + defer close(done) - fw.fakeKubelet.execFunc = func(podFullName string, uid types.UID, containerName string, cmd []string, in io.Reader, out, stderr io.WriteCloser, tty bool) error { - defer close(execFuncDone) if podFullName != expectedPodName { t.Fatalf("%d: podFullName: expected %s, got %s", i, expectedPodName, podFullName) } @@ -1068,66 +1071,79 @@ func TestServeExecInContainer(t *testing.T) { if containerName != expectedContainerName { t.Fatalf("%d: containerName: expected %s, got %s", i, expectedContainerName, containerName) } + + if test.stdin { + if in == nil { + t.Fatalf("%d: stdin: expected non-nil", i) + } + b := make([]byte, 10) + n, err := in.Read(b) + if err != nil { + t.Fatalf("%d: error reading from stdin: %v", i, err) + } + if e, a := expectedStdin, string(b[0:n]); e != a { + t.Fatalf("%d: stdin: expected to read %v, got %v", i, e, a) + } + } else if in != nil { + t.Fatalf("%d: stdin: expected nil: %#v", i, in) + } + + if test.stdout { + if out == nil { + t.Fatalf("%d: stdout: expected non-nil", i) + } + _, err := out.Write([]byte(expectedStdout)) + if err != nil { + t.Fatalf("%d:, error writing to stdout: %v", i, err) + } + out.Close() + <-clientStdoutReadDone + } else if out != nil { + t.Fatalf("%d: stdout: expected nil: %#v", i, out) + } + + if tty { + if stderr != nil { + t.Fatalf("%d: tty set but received non-nil stderr: %v", i, stderr) + } + } else if test.stderr { + if stderr == nil { + t.Fatalf("%d: stderr: expected non-nil", i) + } + _, err := stderr.Write([]byte(expectedStderr)) + if err != nil { + t.Fatalf("%d:, error writing to stderr: %v", i, err) + } + stderr.Close() + <-clientStderrReadDone + } else if stderr != nil { + t.Fatalf("%d: stderr: expected nil: %#v", i, stderr) + } + + return nil + } + + fw.fakeKubelet.execFunc = func(podFullName string, uid types.UID, containerName string, cmd []string, in io.Reader, out, stderr io.WriteCloser, tty bool) error { + execInvoked = true if strings.Join(cmd, " ") != expectedCommand { t.Fatalf("%d: cmd: expected: %s, got %v", i, expectedCommand, cmd) } + return testStreamFunc(podFullName, uid, containerName, cmd, in, out, stderr, tty, done) + } - if test.stdin { - if in == nil { - t.Fatalf("%d: stdin: expected non-nil", i) - } - b := make([]byte, 10) - n, err := in.Read(b) - if err != nil { - t.Fatalf("%d: error reading from stdin: %v", i, err) - } - if e, a := expectedStdin, string(b[0:n]); e != a { - t.Fatalf("%d: stdin: expected to read %v, got %v", i, e, a) - } - } else if in != nil { - t.Fatalf("%d: stdin: expected nil: %#v", i, in) - } - - if test.stdout { - if out == nil { - t.Fatalf("%d: stdout: expected non-nil", i) - } - _, err := out.Write([]byte(expectedStdout)) - if err != nil { - t.Fatalf("%d:, error writing to stdout: %v", i, err) - } - out.Close() - <-clientStdoutReadDone - } else if out != nil { - t.Fatalf("%d: stdout: expected nil: %#v", i, out) - } - - if tty { - if stderr != nil { - t.Fatalf("%d: tty set but received non-nil stderr: %v", i, stderr) - } - } else if test.stderr { - if stderr == nil { - t.Fatalf("%d: stderr: expected non-nil", i) - } - _, err := stderr.Write([]byte(expectedStderr)) - if err != nil { - t.Fatalf("%d:, error writing to stderr: %v", i, err) - } - stderr.Close() - <-clientStderrReadDone - } else if stderr != nil { - t.Fatalf("%d: stderr: expected nil: %#v", i, stderr) - } - - return nil + fw.fakeKubelet.attachFunc = func(podFullName string, uid types.UID, containerName string, in io.Reader, out, stderr io.WriteCloser, tty bool) error { + attachInvoked = true + return testStreamFunc(podFullName, uid, containerName, nil, in, out, stderr, tty, done) } var url string if test.uid { - url = fw.testHTTPServer.URL + "/exec/" + podNamespace + "/" + podName + "/" + expectedUid + "/" + expectedContainerName + "?command=ls&command=-a" + url = fw.testHTTPServer.URL + "/" + verb + "/" + podNamespace + "/" + podName + "/" + expectedUid + "/" + expectedContainerName + "?ignore=1" } else { - url = fw.testHTTPServer.URL + "/exec/" + podNamespace + "/" + podName + "/" + expectedContainerName + "?command=ls&command=-a" + url = fw.testHTTPServer.URL + "/" + verb + "/" + podNamespace + "/" + podName + "/" + expectedContainerName + "?ignore=1" + } + if verb == "exec" { + url += "&command=ls&command=-a" } if test.stdin { url += "&" + api.ExecStdinParam + "=1" @@ -1186,11 +1202,9 @@ func TestServeExecInContainer(t *testing.T) { h := http.Header{} h.Set(api.StreamType, api.StreamTypeError) - errorStream, err := conn.CreateStream(h) - if err != nil { + if _, err := conn.CreateStream(h); err != nil { t.Fatalf("%d: error creating error stream: %v", i, err) } - defer errorStream.Reset() if test.stdin { h.Set(api.StreamType, api.StreamTypeStdin) @@ -1198,7 +1212,6 @@ func TestServeExecInContainer(t *testing.T) { if err != nil { t.Fatalf("%d: error creating stdin stream: %v", i, err) } - defer stream.Reset() _, err = stream.Write([]byte(expectedStdin)) if err != nil { t.Fatalf("%d: error writing to stdin stream: %v", i, err) @@ -1212,7 +1225,6 @@ func TestServeExecInContainer(t *testing.T) { if err != nil { t.Fatalf("%d: error creating stdout stream: %v", i, err) } - defer stdoutStream.Reset() } var stderrStream httpstream.Stream @@ -1222,7 +1234,6 @@ func TestServeExecInContainer(t *testing.T) { if err != nil { t.Fatalf("%d: error creating stderr stream: %v", i, err) } - defer stderrStream.Reset() } if test.stdout { @@ -1249,239 +1260,33 @@ func TestServeExecInContainer(t *testing.T) { } } - <-execFuncDone + // wait for the server to finish before checking if the attach/exec funcs were invoked + <-done + + if verb == "exec" { + if !execInvoked { + t.Errorf("%d: exec was not invoked", i) + } + if attachInvoked { + t.Errorf("%d: attach should not have been invoked", i) + } + } else { + if !attachInvoked { + t.Errorf("%d: attach was not invoked", i) + } + if execInvoked { + t.Errorf("%d: exec should not have been invoked", i) + } + } } } -// TODO: largely cloned from TestServeExecContainer, refactor and re-use code +func TestServeExecInContainer(t *testing.T) { + testExecAttach(t, "exec") +} + func TestServeAttachContainer(t *testing.T) { - tests := []struct { - stdin bool - stdout bool - stderr bool - tty bool - responseStatusCode int - uid bool - }{ - {responseStatusCode: http.StatusBadRequest}, - {stdin: true, responseStatusCode: http.StatusSwitchingProtocols}, - {stdout: true, responseStatusCode: http.StatusSwitchingProtocols}, - {stderr: true, responseStatusCode: http.StatusSwitchingProtocols}, - {stdout: true, stderr: true, responseStatusCode: http.StatusSwitchingProtocols}, - {stdout: true, stderr: true, tty: true, responseStatusCode: http.StatusSwitchingProtocols}, - {stdin: true, stdout: true, stderr: true, responseStatusCode: http.StatusSwitchingProtocols}, - } - - for i, test := range tests { - fw := newServerTest() - - fw.fakeKubelet.streamingConnectionIdleTimeoutFunc = func() time.Duration { - return 0 - } - - podNamespace := "other" - podName := "foo" - expectedPodName := getPodName(podName, podNamespace) - expectedUid := "9b01b80f-8fb4-11e4-95ab-4200af06647" - expectedContainerName := "baz" - expectedStdin := "stdin" - expectedStdout := "stdout" - expectedStderr := "stderr" - attachFuncDone := make(chan struct{}) - clientStdoutReadDone := make(chan struct{}) - clientStderrReadDone := make(chan struct{}) - - fw.fakeKubelet.attachFunc = func(podFullName string, uid types.UID, containerName string, in io.Reader, out, stderr io.WriteCloser, tty bool) error { - defer close(attachFuncDone) - if podFullName != expectedPodName { - t.Fatalf("%d: podFullName: expected %s, got %s", i, expectedPodName, podFullName) - } - if test.uid && string(uid) != expectedUid { - t.Fatalf("%d: uid: expected %v, got %v", i, expectedUid, uid) - } - if containerName != expectedContainerName { - t.Fatalf("%d: containerName: expected %s, got %s", i, expectedContainerName, containerName) - } - - if test.stdin { - if in == nil { - t.Fatalf("%d: stdin: expected non-nil", i) - } - b := make([]byte, 10) - n, err := in.Read(b) - if err != nil { - t.Fatalf("%d: error reading from stdin: %v", i, err) - } - if e, a := expectedStdin, string(b[0:n]); e != a { - t.Fatalf("%d: stdin: expected to read %v, got %v", i, e, a) - } - } else if in != nil { - t.Fatalf("%d: stdin: expected nil: %#v", i, in) - } - - if test.stdout { - if out == nil { - t.Fatalf("%d: stdout: expected non-nil", i) - } - _, err := out.Write([]byte(expectedStdout)) - if err != nil { - t.Fatalf("%d:, error writing to stdout: %v", i, err) - } - out.Close() - <-clientStdoutReadDone - } else if out != nil { - t.Fatalf("%d: stdout: expected nil: %#v", i, out) - } - - if tty { - if stderr != nil { - t.Fatalf("%d: tty set but received non-nil stderr: %v", i, stderr) - } - } else if test.stderr { - if stderr == nil { - t.Fatalf("%d: stderr: expected non-nil", i) - } - _, err := stderr.Write([]byte(expectedStderr)) - if err != nil { - t.Fatalf("%d:, error writing to stderr: %v", i, err) - } - stderr.Close() - <-clientStderrReadDone - } else if stderr != nil { - t.Fatalf("%d: stderr: expected nil: %#v", i, stderr) - } - - return nil - } - - var url string - if test.uid { - url = fw.testHTTPServer.URL + "/attach/" + podNamespace + "/" + podName + "/" + expectedUid + "/" + expectedContainerName + "?" - } else { - url = fw.testHTTPServer.URL + "/attach/" + podNamespace + "/" + podName + "/" + expectedContainerName + "?" - } - if test.stdin { - url += "&" + api.ExecStdinParam + "=1" - } - if test.stdout { - url += "&" + api.ExecStdoutParam + "=1" - } - if test.stderr && !test.tty { - url += "&" + api.ExecStderrParam + "=1" - } - if test.tty { - url += "&" + api.ExecTTYParam + "=1" - } - - var ( - resp *http.Response - err error - upgradeRoundTripper httpstream.UpgradeRoundTripper - c *http.Client - ) - - if test.responseStatusCode != http.StatusSwitchingProtocols { - c = &http.Client{} - } else { - upgradeRoundTripper = spdy.NewRoundTripper(nil) - c = &http.Client{Transport: upgradeRoundTripper} - } - - resp, err = c.Post(url, "", nil) - if err != nil { - t.Fatalf("%d: Got error POSTing: %v", i, err) - } - defer resp.Body.Close() - - _, err = ioutil.ReadAll(resp.Body) - if err != nil { - t.Errorf("%d: Error reading response body: %v", i, err) - } - - if e, a := test.responseStatusCode, resp.StatusCode; e != a { - t.Fatalf("%d: response status: expected %v, got %v", i, e, a) - } - - if test.responseStatusCode != http.StatusSwitchingProtocols { - continue - } - - conn, err := upgradeRoundTripper.NewConnection(resp) - if err != nil { - t.Fatalf("Unexpected error creating streaming connection: %s", err) - } - if conn == nil { - t.Fatalf("%d: unexpected nil conn", i) - } - defer conn.Close() - - h := http.Header{} - h.Set(api.StreamType, api.StreamTypeError) - errorStream, err := conn.CreateStream(h) - if err != nil { - t.Fatalf("%d: error creating error stream: %v", i, err) - } - defer errorStream.Reset() - - if test.stdin { - h.Set(api.StreamType, api.StreamTypeStdin) - stream, err := conn.CreateStream(h) - if err != nil { - t.Fatalf("%d: error creating stdin stream: %v", i, err) - } - defer stream.Reset() - _, err = stream.Write([]byte(expectedStdin)) - if err != nil { - t.Fatalf("%d: error writing to stdin stream: %v", i, err) - } - } - - var stdoutStream httpstream.Stream - if test.stdout { - h.Set(api.StreamType, api.StreamTypeStdout) - stdoutStream, err = conn.CreateStream(h) - if err != nil { - t.Fatalf("%d: error creating stdout stream: %v", i, err) - } - defer stdoutStream.Reset() - } - - var stderrStream httpstream.Stream - if test.stderr && !test.tty { - h.Set(api.StreamType, api.StreamTypeStderr) - stderrStream, err = conn.CreateStream(h) - if err != nil { - t.Fatalf("%d: error creating stderr stream: %v", i, err) - } - defer stderrStream.Reset() - } - - if test.stdout { - output := make([]byte, 10) - n, err := stdoutStream.Read(output) - close(clientStdoutReadDone) - if err != nil { - t.Fatalf("%d: error reading from stdout stream: %v", i, err) - } - if e, a := expectedStdout, string(output[0:n]); e != a { - t.Fatalf("%d: stdout: expected '%v', got '%v'", i, e, a) - } - } - - if test.stderr && !test.tty { - output := make([]byte, 10) - n, err := stderrStream.Read(output) - close(clientStderrReadDone) - if err != nil { - t.Fatalf("%d: error reading from stderr stream: %v", i, err) - } - if e, a := expectedStderr, string(output[0:n]); e != a { - t.Fatalf("%d: stderr: expected '%v', got '%v'", i, e, a) - } - } - - <-attachFuncDone - } + testExecAttach(t, "attach") } func TestServePortForwardIdleTimeout(t *testing.T) { diff --git a/pkg/util/httpstream/httpstream.go b/pkg/util/httpstream/httpstream.go index 4f6b608ce7a..3ce3b02a019 100644 --- a/pkg/util/httpstream/httpstream.go +++ b/pkg/util/httpstream/httpstream.go @@ -114,20 +114,24 @@ func negotiateProtocol(clientProtocols, serverProtocols []string) string { return "" } -// Handshake performs a subprotocol negotiation. If the client did not request -// a specific subprotocol, defaultProtocol is used. If the client did request a +// Handshake performs a subprotocol negotiation. If the client did request a // subprotocol, Handshake will select the first common value found in // serverProtocols. If a match is found, Handshake adds a response header // indicating the chosen subprotocol. If no match is found, HTTP forbidden is // returned, along with a response header containing the list of protocols the // server can accept. -func Handshake(req *http.Request, w http.ResponseWriter, serverProtocols []string, defaultProtocol string) (string, error) { +func Handshake(req *http.Request, w http.ResponseWriter, serverProtocols []string) (string, error) { clientProtocols := req.Header[http.CanonicalHeaderKey(HeaderProtocolVersion)] if len(clientProtocols) == 0 { - // Kube 1.0 client that didn't support subprotocol negotiation - // TODO remove this defaulting logic once Kube 1.0 is no longer supported - w.Header().Add(HeaderProtocolVersion, defaultProtocol) - return defaultProtocol, nil + // Kube 1.0 clients didn't support subprotocol negotiation. + // TODO require clientProtocols once Kube 1.0 is no longer supported + return "", nil + } + + if len(serverProtocols) == 0 { + // Kube 1.0 servers didn't support subprotocol negotiation. This is mainly for testing. + // TODO require serverProtocols once Kube 1.0 is no longer supported + return "", nil } negotiatedProtocol := negotiateProtocol(clientProtocols, serverProtocols) diff --git a/pkg/util/httpstream/httpstream_test.go b/pkg/util/httpstream/httpstream_test.go index 7a1bbaefb89..7b0a3bd0a5d 100644 --- a/pkg/util/httpstream/httpstream_test.go +++ b/pkg/util/httpstream/httpstream_test.go @@ -20,6 +20,8 @@ import ( "net/http" "reflect" "testing" + + "k8s.io/kubernetes/pkg/api" ) type responseWriter struct { @@ -46,8 +48,6 @@ func (r *responseWriter) Write([]byte) (int, error) { } func TestHandshake(t *testing.T) { - defaultProtocol := "default" - tests := map[string]struct { clientProtocols []string serverProtocols []string @@ -57,7 +57,7 @@ func TestHandshake(t *testing.T) { "no client protocols": { clientProtocols: []string{}, serverProtocols: []string{"a", "b"}, - expectedProtocol: defaultProtocol, + expectedProtocol: "", }, "no common protocol": { clientProtocols: []string{"c"}, @@ -83,7 +83,7 @@ func TestHandshake(t *testing.T) { } w := newResponseWriter() - negotiated, err := Handshake(req, w, test.serverProtocols, defaultProtocol) + negotiated, err := Handshake(req, w, test.serverProtocols) // verify negotiated protocol if e, a := test.expectedProtocol, negotiated; e != a { @@ -112,8 +112,15 @@ func TestHandshake(t *testing.T) { t.Errorf("%s: unexpected non-nil w.statusCode: %d", w.statusCode) } + if len(test.expectedProtocol) == 0 { + if len(w.Header()[HeaderProtocolVersion]) > 0 { + t.Errorf("%s: unexpected protocol version response header: %s", w.Header()[HeaderProtocolVersion]) + } + continue + } + // verify response headers - if e, a := []string{test.expectedProtocol}, w.Header()[HeaderProtocolVersion]; !reflect.DeepEqual(e, a) { + if e, a := []string{test.expectedProtocol}, w.Header()[HeaderProtocolVersion]; !api.Semantic.DeepEqual(e, a) { t.Errorf("%s: protocol response header: expected %v, got %v", name, e, a) } }