Refactor exec code to support version skew testing

Refactor exec/attach client and server code to better support interoperability testing of different
client and server subprotocol versions.
This commit is contained in:
Andy Goldstein 2016-03-22 09:38:42 -04:00
parent d124deeb2f
commit 4551ba6b53
14 changed files with 894 additions and 804 deletions

View File

@ -26,6 +26,7 @@ import (
"k8s.io/kubernetes/pkg/client/restclient"
"k8s.io/kubernetes/pkg/client/transport"
"k8s.io/kubernetes/pkg/kubelet/server/remotecommand"
"k8s.io/kubernetes/pkg/util/httpstream"
"k8s.io/kubernetes/pkg/util/httpstream/spdy"
)
@ -36,7 +37,7 @@ type Executor interface {
// non-nil stream to a remote system, and return an error if a problem occurs. If tty
// is set, the stderr stream is not used (raw TTY manages stdout and stderr over the
// stdout stream).
Stream(stdin io.Reader, stdout, stderr io.Writer, tty bool) error
Stream(supportedProtocols []string, stdin io.Reader, stdout, stderr io.Writer, tty bool) error
}
// StreamExecutor supports the ability to dial an httpstream connection and the ability to
@ -128,26 +129,13 @@ func (e *streamExecutor) Dial(protocols ...string) (httpstream.Connection, strin
return conn, resp.Header.Get(httpstream.HeaderProtocolVersion), nil
}
const (
// The SPDY subprotocol "channel.k8s.io" is used for remote command
// attachment/execution. This represents the initial unversioned subprotocol,
// which has the known bugs http://issues.k8s.io/13394 and
// http://issues.k8s.io/13395.
StreamProtocolV1Name = "channel.k8s.io"
// The SPDY subprotocol "v2.channel.k8s.io" is used for remote command
// attachment/execution. It is the second version of the subprotocol and
// resolves the issues present in the first version.
StreamProtocolV2Name = "v2.channel.k8s.io"
)
type streamProtocolHandler interface {
stream(httpstream.Connection) error
}
// Stream opens a protocol streamer to the server and streams until a client closes
// the connection or the server disconnects.
func (e *streamExecutor) Stream(stdin io.Reader, stdout, stderr io.Writer, tty bool) error {
supportedProtocols := []string{StreamProtocolV2Name, StreamProtocolV1Name}
func (e *streamExecutor) Stream(supportedProtocols []string, stdin io.Reader, stdout, stderr io.Writer, tty bool) error {
conn, protocol, err := e.Dial(supportedProtocols...)
if err != nil {
return err
@ -157,7 +145,7 @@ func (e *streamExecutor) Stream(stdin io.Reader, stdout, stderr io.Writer, tty b
var streamer streamProtocolHandler
switch protocol {
case StreamProtocolV2Name:
case remotecommand.StreamProtocolV2Name:
streamer = &streamProtocolV2{
stdin: stdin,
stdout: stdout,
@ -165,9 +153,9 @@ func (e *streamExecutor) Stream(stdin io.Reader, stdout, stderr io.Writer, tty b
tty: tty,
}
case "":
glog.V(4).Infof("The server did not negotiate a streaming protocol version. Falling back to %s", StreamProtocolV1Name)
glog.V(4).Infof("The server did not negotiate a streaming protocol version. Falling back to %s", remotecommand.StreamProtocolV1Name)
fallthrough
case StreamProtocolV1Name:
case remotecommand.StreamProtocolV1Name:
streamer = &streamProtocolV1{
stdin: stdin,
stdout: stdout,

View File

@ -18,6 +18,7 @@ package remotecommand
import (
"bytes"
"errors"
"fmt"
"io"
"io/ioutil"
@ -26,325 +27,263 @@ import (
"net/url"
"strings"
"testing"
"time"
"k8s.io/kubernetes/pkg/api"
"k8s.io/kubernetes/pkg/api/unversioned"
"k8s.io/kubernetes/pkg/client/restclient"
"k8s.io/kubernetes/pkg/kubelet/server/remotecommand"
"k8s.io/kubernetes/pkg/types"
"k8s.io/kubernetes/pkg/util/httpstream"
"k8s.io/kubernetes/pkg/util/httpstream/spdy"
)
type streamAndReply struct {
httpstream.Stream
replySent <-chan struct{}
type fakeExecutor struct {
t *testing.T
testName string
errorData string
stdoutData string
stderrData string
expectStdin bool
stdinReceived bytes.Buffer
tty bool
messageCount int
command []string
exec bool
}
func waitStreamReply(replySent <-chan struct{}, notify chan<- struct{}, stop <-chan struct{}) {
select {
case <-replySent:
notify <- struct{}{}
case <-stop:
}
func (ex *fakeExecutor) ExecInContainer(name string, uid types.UID, container string, cmd []string, in io.Reader, out, err io.WriteCloser, tty bool) error {
return ex.run(name, uid, container, cmd, in, out, err, tty)
}
func fakeExecServer(t *testing.T, i int, stdinData, stdoutData, stderrData, errorData string, tty bool, messageCount int) http.HandlerFunc {
// error + stdin + stdout
expectedStreams := 3
if !tty {
// stderr
expectedStreams++
func (ex *fakeExecutor) AttachContainer(name string, uid types.UID, container string, in io.Reader, out, err io.WriteCloser, tty bool) error {
return ex.run(name, uid, container, nil, in, out, err, tty)
}
func (ex *fakeExecutor) run(name string, uid types.UID, container string, cmd []string, in io.Reader, out, err io.WriteCloser, tty bool) error {
ex.command = cmd
ex.tty = tty
if e, a := "pod", name; e != a {
ex.t.Errorf("%s: pod: expected %q, got %q", ex.testName, e, a)
}
if e, a := "uid", uid; e != string(a) {
ex.t.Errorf("%s: uid: expected %q, got %q", ex.testName, e, a)
}
if ex.exec {
if e, a := "ls /", strings.Join(ex.command, " "); e != a {
ex.t.Errorf("%s: command: expected %q, got %q", ex.testName, e, a)
}
} else {
if len(ex.command) > 0 {
ex.t.Errorf("%s: command: expected nothing, got %v", ex.testName, ex.command)
}
}
if len(ex.errorData) > 0 {
return errors.New(ex.errorData)
}
if len(ex.stdoutData) > 0 {
for i := 0; i < ex.messageCount; i++ {
fmt.Fprint(out, ex.stdoutData)
}
}
if len(ex.stderrData) > 0 {
for i := 0; i < ex.messageCount; i++ {
fmt.Fprint(err, ex.stderrData)
}
}
if ex.expectStdin {
io.Copy(&ex.stdinReceived, in)
}
return nil
}
func fakeServer(t *testing.T, testName string, exec bool, stdinData, stdoutData, stderrData, errorData string, tty bool, messageCount int, serverProtocols []string) http.HandlerFunc {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
protocol, err := httpstream.Handshake(req, w, []string{StreamProtocolV2Name}, StreamProtocolV1Name)
if err != nil {
t.Fatal(err)
}
if protocol != StreamProtocolV2Name {
t.Fatalf("unexpected protocol: %s", protocol)
}
streamCh := make(chan streamAndReply)
upgrader := spdy.NewResponseUpgrader()
conn := upgrader.UpgradeResponse(w, req, func(stream httpstream.Stream, replySent <-chan struct{}) error {
streamCh <- streamAndReply{Stream: stream, replySent: replySent}
return nil
})
// from this point on, we can no longer call methods on w
if conn == nil {
// The upgrader is responsible for notifying the client of any errors that
// occurred during upgrading. All we can do is return here at this point
// if we weren't successful in upgrading.
return
}
defer conn.Close()
var errorStream, stdinStream, stdoutStream, stderrStream httpstream.Stream
receivedStreams := 0
replyChan := make(chan struct{})
stop := make(chan struct{})
defer close(stop)
WaitForStreams:
for {
select {
case stream := <-streamCh:
streamType := stream.Headers().Get(api.StreamType)
switch streamType {
case api.StreamTypeError:
errorStream = stream
go waitStreamReply(stream.replySent, replyChan, stop)
case api.StreamTypeStdin:
stdinStream = stream
go waitStreamReply(stream.replySent, replyChan, stop)
case api.StreamTypeStdout:
stdoutStream = stream
go waitStreamReply(stream.replySent, replyChan, stop)
case api.StreamTypeStderr:
stderrStream = stream
go waitStreamReply(stream.replySent, replyChan, stop)
default:
t.Errorf("%d: unexpected stream type: %q", i, streamType)
}
if receivedStreams == expectedStreams {
break WaitForStreams
}
case <-replyChan:
receivedStreams++
if receivedStreams == expectedStreams {
break WaitForStreams
}
}
executor := &fakeExecutor{
t: t,
testName: testName,
errorData: errorData,
stdoutData: stdoutData,
stderrData: stderrData,
expectStdin: len(stdinData) > 0,
tty: tty,
messageCount: messageCount,
exec: exec,
}
if len(errorData) > 0 {
n, err := fmt.Fprint(errorStream, errorData)
if err != nil {
t.Errorf("%d: error writing to errorStream: %v", i, err)
}
if e, a := len(errorData), n; e != a {
t.Errorf("%d: expected to write %d bytes to errorStream, but only wrote %d", i, e, a)
}
errorStream.Close()
if exec {
remotecommand.ServeExec(w, req, executor, "pod", "uid", "container", 0, 10*time.Second, serverProtocols)
} else {
remotecommand.ServeAttach(w, req, executor, "pod", "uid", "container", 0, 10*time.Second, serverProtocols)
}
if len(stdoutData) > 0 {
for j := 0; j < messageCount; j++ {
n, err := fmt.Fprint(stdoutStream, stdoutData)
if err != nil {
t.Errorf("%d: error writing to stdoutStream: %v", i, err)
}
if e, a := len(stdoutData), n; e != a {
t.Errorf("%d: expected to write %d bytes to stdoutStream, but only wrote %d", i, e, a)
}
}
stdoutStream.Close()
}
if len(stderrData) > 0 {
for j := 0; j < messageCount; j++ {
n, err := fmt.Fprint(stderrStream, stderrData)
if err != nil {
t.Errorf("%d: error writing to stderrStream: %v", i, err)
}
if e, a := len(stderrData), n; e != a {
t.Errorf("%d: expected to write %d bytes to stderrStream, but only wrote %d", i, e, a)
}
}
stderrStream.Close()
}
if len(stdinData) > 0 {
data := make([]byte, len(stdinData))
for j := 0; j < messageCount; j++ {
n, err := io.ReadFull(stdinStream, data)
if err != nil {
t.Errorf("%d: error reading stdin stream: %v", i, err)
}
if e, a := len(stdinData), n; e != a {
t.Errorf("%d: expected to read %d bytes from stdinStream, but only read %d", i, e, a)
}
if e, a := stdinData, string(data); e != a {
t.Errorf("%d: stdin: expected %q, got %q", i, e, a)
}
}
stdinStream.Close()
if e, a := strings.Repeat(stdinData, messageCount), executor.stdinReceived.String(); e != a {
t.Errorf("%s: stdin: expected %q, got %q", testName, e, a)
}
})
}
func TestRequestExecuteRemoteCommand(t *testing.T) {
func TestStream(t *testing.T) {
testCases := []struct {
Stdin string
Stdout string
Stderr string
Error string
Tty bool
MessageCount int
TestName string
Stdin string
Stdout string
Stderr string
Error string
Tty bool
MessageCount int
ClientProtocols []string
ServerProtocols []string
}{
{
Error: "bail",
TestName: "error",
Error: "bail",
Stdout: "a",
ClientProtocols: []string{remotecommand.StreamProtocolV2Name},
ServerProtocols: []string{remotecommand.StreamProtocolV2Name},
},
{
Stdin: "a",
Stdout: "b",
Stderr: "c",
// TODO bump this to a larger number such as 100 once
// https://github.com/docker/spdystream/issues/55 is fixed and the Godep
// is bumped. Sending multiple messages over stdin/stdout/stderr results
// in more frames being spread across multiple spdystream frame workers.
// This makes it more likely that the spdystream bug will be encountered,
// where streams are closed as soon as a goaway frame is received, and
// any pending frames that haven't been processed yet may not be
// delivered (it's a race).
MessageCount: 1,
TestName: "in/out/err",
Stdin: "a",
Stdout: "b",
Stderr: "c",
MessageCount: 100,
ClientProtocols: []string{remotecommand.StreamProtocolV2Name},
ServerProtocols: []string{remotecommand.StreamProtocolV2Name},
},
{
Stdin: "a",
Stdout: "b",
Tty: true,
TestName: "in/out/tty",
Stdin: "a",
Stdout: "b",
Tty: true,
MessageCount: 100,
ClientProtocols: []string{remotecommand.StreamProtocolV2Name},
ServerProtocols: []string{remotecommand.StreamProtocolV2Name},
},
{
// 1.0 kubectl, 1.0 kubelet
TestName: "unversioned client, unversioned server",
Stdout: "b",
Stderr: "c",
MessageCount: 1,
ClientProtocols: []string{},
ServerProtocols: []string{},
},
{
// 1.0 kubectl, 1.1+ kubelet
TestName: "unversioned client, versioned server",
Stdout: "b",
Stderr: "c",
MessageCount: 1,
ClientProtocols: []string{},
ServerProtocols: []string{remotecommand.StreamProtocolV2Name, remotecommand.StreamProtocolV1Name},
},
{
// 1.1+ kubectl, 1.0 kubelet
TestName: "versioned client, unversioned server",
Stdout: "b",
Stderr: "c",
MessageCount: 1,
ClientProtocols: []string{remotecommand.StreamProtocolV2Name, remotecommand.StreamProtocolV1Name},
ServerProtocols: []string{},
},
}
for i, testCase := range testCases {
localOut := &bytes.Buffer{}
localErr := &bytes.Buffer{}
server := httptest.NewServer(fakeExecServer(t, i, testCase.Stdin, testCase.Stdout, testCase.Stderr, testCase.Error, testCase.Tty, testCase.MessageCount))
url, _ := url.ParseRequestURI(server.URL)
c := restclient.NewRESTClient(url, "", restclient.ContentConfig{GroupVersion: &unversioned.GroupVersion{Group: "x"}}, -1, -1, nil)
req := c.Post().Resource("testing")
req.SetHeader(httpstream.HeaderProtocolVersion, StreamProtocolV2Name)
req.Param("command", "ls")
req.Param("command", "/")
conf := &restclient.Config{
Host: server.URL,
}
e, err := NewExecutor(conf, "POST", req.URL())
if err != nil {
t.Errorf("%d: unexpected error: %v", i, err)
continue
}
err = e.Stream(strings.NewReader(strings.Repeat(testCase.Stdin, testCase.MessageCount)), localOut, localErr, testCase.Tty)
hasErr := err != nil
if len(testCase.Error) > 0 {
if !hasErr {
t.Errorf("%d: expected an error", i)
for _, testCase := range testCases {
for _, exec := range []bool{true, false} {
var name string
if exec {
name = testCase.TestName + " (exec)"
} else {
if e, a := testCase.Error, err.Error(); !strings.Contains(a, e) {
t.Errorf("%d: expected error stream read '%v', got '%v'", i, e, a)
name = testCase.TestName + " (attach)"
}
var (
streamIn io.Reader
streamOut, streamErr io.Writer
)
localOut := &bytes.Buffer{}
localErr := &bytes.Buffer{}
server := httptest.NewServer(fakeServer(t, name, exec, testCase.Stdin, testCase.Stdout, testCase.Stderr, testCase.Error, testCase.Tty, testCase.MessageCount, testCase.ServerProtocols))
url, _ := url.ParseRequestURI(server.URL)
c := restclient.NewRESTClient(url, "", restclient.ContentConfig{GroupVersion: &unversioned.GroupVersion{Group: "x"}}, -1, -1, nil)
req := c.Post().Resource("testing")
if exec {
req.Param("command", "ls")
req.Param("command", "/")
}
if len(testCase.Stdin) > 0 {
req.Param(api.ExecStdinParam, "1")
streamIn = strings.NewReader(strings.Repeat(testCase.Stdin, testCase.MessageCount))
}
if len(testCase.Stdout) > 0 {
req.Param(api.ExecStdoutParam, "1")
streamOut = localOut
}
if testCase.Tty {
req.Param(api.ExecTTYParam, "1")
} else if len(testCase.Stderr) > 0 {
req.Param(api.ExecStderrParam, "1")
streamErr = localErr
}
conf := &restclient.Config{
Host: server.URL,
}
e, err := NewExecutor(conf, "POST", req.URL())
if err != nil {
t.Errorf("%s: unexpected error: %v", name, err)
continue
}
err = e.Stream(testCase.ClientProtocols, streamIn, streamOut, streamErr, testCase.Tty)
hasErr := err != nil
if len(testCase.Error) > 0 {
if !hasErr {
t.Errorf("%s: expected an error", name)
} else {
if e, a := testCase.Error, err.Error(); !strings.Contains(a, e) {
t.Errorf("%s: expected error stream read %q, got %q", name, e, a)
}
}
// TODO: Uncomment when fix #19254
// server.Close()
continue
}
if hasErr {
t.Errorf("%s: unexpected error: %v", name, err)
// TODO: Uncomment when fix #19254
// server.Close()
continue
}
if len(testCase.Stdout) > 0 {
if e, a := strings.Repeat(testCase.Stdout, testCase.MessageCount), localOut; e != a.String() {
t.Errorf("%s: expected stdout data '%s', got '%s'", name, e, a)
}
}
if testCase.Stderr != "" {
if e, a := strings.Repeat(testCase.Stderr, testCase.MessageCount), localErr; e != a.String() {
t.Errorf("%s: expected stderr data '%s', got '%s'", name, e, a)
}
}
// TODO: Uncomment when fix #19254
// server.Close()
continue
}
if hasErr {
t.Errorf("%d: unexpected error: %v", i, err)
// TODO: Uncomment when fix #19254
// server.Close()
continue
}
if len(testCase.Stdout) > 0 {
if e, a := strings.Repeat(testCase.Stdout, testCase.MessageCount), localOut; e != a.String() {
t.Errorf("%d: expected stdout data '%s', got '%s'", i, e, a)
}
}
if testCase.Stderr != "" {
if e, a := strings.Repeat(testCase.Stderr, testCase.MessageCount), localErr; e != a.String() {
t.Errorf("%d: expected stderr data '%s', got '%s'", i, e, a)
}
}
// TODO: Uncomment when fix #19254
// server.Close()
}
}
// TODO: this test is largely cut and paste, refactor to share code
func TestRequestAttachRemoteCommand(t *testing.T) {
testCases := []struct {
Stdin string
Stdout string
Stderr string
Error string
Tty bool
}{
{
Error: "bail",
},
{
Stdin: "a",
Stdout: "b",
Stderr: "c",
},
{
Stdin: "a",
Stdout: "b",
Tty: true,
},
}
for i, testCase := range testCases {
localOut := &bytes.Buffer{}
localErr := &bytes.Buffer{}
server := httptest.NewServer(fakeExecServer(t, i, testCase.Stdin, testCase.Stdout, testCase.Stderr, testCase.Error, testCase.Tty, 1))
url, _ := url.ParseRequestURI(server.URL)
c := restclient.NewRESTClient(url, "", restclient.ContentConfig{GroupVersion: &unversioned.GroupVersion{Group: "x"}}, -1, -1, nil)
req := c.Post().Resource("testing")
conf := &restclient.Config{
Host: server.URL,
}
e, err := NewExecutor(conf, "POST", req.URL())
if err != nil {
t.Errorf("%d: unexpected error: %v", i, err)
continue
}
err = e.Stream(strings.NewReader(testCase.Stdin), localOut, localErr, testCase.Tty)
hasErr := err != nil
if len(testCase.Error) > 0 {
if !hasErr {
t.Errorf("%d: expected an error", i)
} else {
if e, a := testCase.Error, err.Error(); !strings.Contains(a, e) {
t.Errorf("%d: expected error stream read '%v', got '%v'", i, e, a)
}
}
// TODO: Uncomment when fix #19254
// server.Close()
continue
}
if hasErr {
t.Errorf("%d: unexpected error: %v", i, err)
// TODO: Uncomment when fix #19254
// server.Close()
continue
}
if len(testCase.Stdout) > 0 {
if e, a := testCase.Stdout, localOut; e != a.String() {
t.Errorf("%d: expected stdout data '%s', got '%s'", i, e, a)
}
}
if testCase.Stderr != "" {
if e, a := testCase.Stderr, localErr; e != a.String() {
t.Errorf("%d: expected stderr data '%s', got '%s'", i, e, a)
}
}
// TODO: Uncomment when fix #19254
// server.Close()
}
}

View File

@ -29,6 +29,7 @@ import (
client "k8s.io/kubernetes/pkg/client/unversioned"
"k8s.io/kubernetes/pkg/client/unversioned/remotecommand"
cmdutil "k8s.io/kubernetes/pkg/kubectl/cmd/util"
remotecommandserver "k8s.io/kubernetes/pkg/kubelet/server/remotecommand"
utilerrors "k8s.io/kubernetes/pkg/util/errors"
"k8s.io/kubernetes/pkg/util/interrupt"
"k8s.io/kubernetes/pkg/util/term"
@ -87,7 +88,7 @@ func (*DefaultRemoteAttach) Attach(method string, url *url.URL, config *restclie
if err != nil {
return err
}
return exec.Stream(stdin, stdout, stderr, tty)
return exec.Stream(remotecommandserver.SupportedStreamingProtocols, stdin, stdout, stderr, tty)
}
// AttachOptions declare the arguments accepted by the Exec command

View File

@ -32,6 +32,7 @@ import (
client "k8s.io/kubernetes/pkg/client/unversioned"
"k8s.io/kubernetes/pkg/client/unversioned/remotecommand"
cmdutil "k8s.io/kubernetes/pkg/kubectl/cmd/util"
remotecommandserver "k8s.io/kubernetes/pkg/kubelet/server/remotecommand"
)
const (
@ -87,7 +88,7 @@ func (*DefaultRemoteExecutor) Execute(method string, url *url.URL, config *restc
if err != nil {
return err
}
return exec.Stream(stdin, stdout, stderr, tty)
return exec.Stream(remotecommandserver.SupportedStreamingProtocols, stdin, stdout, stderr, tty)
}
// ExecOptions declare the arguments accepted by the Exec command

View File

@ -0,0 +1,53 @@
/*
Copyright 2016 The Kubernetes Authors All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package remotecommand
import (
"errors"
"fmt"
"io"
"net/http"
"time"
"k8s.io/kubernetes/pkg/types"
"k8s.io/kubernetes/pkg/util/runtime"
)
// Attacher knows how to attach to a running container in a pod.
type Attacher interface {
// AttachContainer attaches to the running container in the pod, copying data between in/out/err
// and the container's stdin/stdout/stderr.
AttachContainer(name string, uid types.UID, container string, in io.Reader, out, err io.WriteCloser, tty bool) error
}
// ServeAttach handles requests to attach to a container. After creating/receiving the required
// streams, it delegates the actual attaching to attacher.
func ServeAttach(w http.ResponseWriter, req *http.Request, attacher Attacher, podName string, uid types.UID, container string, idleTimeout, streamCreationTimeout time.Duration, supportedProtocols []string) {
ctx, ok := createStreams(req, w, supportedProtocols, idleTimeout, streamCreationTimeout)
if !ok {
// error is handled by createStreams
return
}
defer ctx.conn.Close()
err := attacher.AttachContainer(podName, uid, container, ctx.stdinStream, ctx.stdoutStream, ctx.stderrStream, ctx.tty)
if err != nil {
msg := fmt.Sprintf("error attaching to container: %v", err)
runtime.HandleError(errors.New(msg))
fmt.Fprint(ctx.errorStream, msg)
}
}

View File

@ -0,0 +1,36 @@
/*
Copyright 2016 The Kubernetes Authors All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package remotecommand
import "time"
const (
DefaultStreamCreationTimeout = 30 * time.Second
// The SPDY subprotocol "channel.k8s.io" is used for remote command
// attachment/execution. This represents the initial unversioned subprotocol,
// which has the known bugs http://issues.k8s.io/13394 and
// http://issues.k8s.io/13395.
StreamProtocolV1Name = "channel.k8s.io"
// The SPDY subprotocol "v2.channel.k8s.io" is used for remote command
// attachment/execution. It is the second version of the subprotocol and
// resolves the issues present in the first version.
StreamProtocolV2Name = "v2.channel.k8s.io"
)
var SupportedStreamingProtocols = []string{StreamProtocolV2Name, StreamProtocolV1Name}

View File

@ -0,0 +1,18 @@
/*
Copyright 2016 The Kubernetes Authors All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// package remotecommand contains functions related to executing commands in and attaching to pods.
package remotecommand

View File

@ -0,0 +1,57 @@
/*
Copyright 2016 The Kubernetes Authors All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package remotecommand
import (
"errors"
"fmt"
"io"
"net/http"
"time"
"k8s.io/kubernetes/pkg/api"
"k8s.io/kubernetes/pkg/types"
"k8s.io/kubernetes/pkg/util/runtime"
)
// Executor knows how to execute a command in a container in a pod.
type Executor interface {
// ExecInContainer executes a command in a container in the pod, copying data
// between in/out/err and the container's stdin/stdout/stderr.
ExecInContainer(name string, uid types.UID, container string, cmd []string, in io.Reader, out, err io.WriteCloser, tty bool) error
}
// ServeExec handles requests to execute a command in a container. After
// creating/receiving the required streams, it delegates the actual execution
// to the executor.
func ServeExec(w http.ResponseWriter, req *http.Request, executor Executor, podName string, uid types.UID, container string, idleTimeout, streamCreationTimeout time.Duration, supportedProtocols []string) {
ctx, ok := createStreams(req, w, supportedProtocols, idleTimeout, streamCreationTimeout)
if !ok {
// error is handled by createStreams
return
}
defer ctx.conn.Close()
cmd := req.URL.Query()[api.ExecCommandParamm]
err := executor.ExecInContainer(podName, uid, container, cmd, ctx.stdinStream, ctx.stdoutStream, ctx.stderrStream, ctx.tty)
if err != nil {
msg := fmt.Sprintf("error executing command in container: %v", err)
runtime.HandleError(errors.New(msg))
fmt.Fprint(ctx.errorStream, msg)
}
}

View File

@ -0,0 +1,277 @@
/*
Copyright 2016 The Kubernetes Authors All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package remotecommand
import (
"errors"
"fmt"
"io"
"net/http"
"time"
"k8s.io/kubernetes/pkg/api"
"k8s.io/kubernetes/pkg/util/httpstream"
"k8s.io/kubernetes/pkg/util/httpstream/spdy"
"k8s.io/kubernetes/pkg/util/runtime"
"k8s.io/kubernetes/pkg/util/wsstream"
"github.com/golang/glog"
)
// options contains details about which streams are required for
// remote command execution.
type options struct {
stdin bool
stdout bool
stderr bool
tty bool
expectedStreams int
}
// newOptions creates a new options from the Request.
func newOptions(req *http.Request) (*options, error) {
tty := req.FormValue(api.ExecTTYParam) == "1"
stdin := req.FormValue(api.ExecStdinParam) == "1"
stdout := req.FormValue(api.ExecStdoutParam) == "1"
stderr := req.FormValue(api.ExecStderrParam) == "1"
if tty && stderr {
// TODO: make this an error before we reach this method
glog.V(4).Infof("Access to exec with tty and stderr is not supported, bypassing stderr")
stderr = false
}
// count the streams client asked for, starting with 1
expectedStreams := 1
if stdin {
expectedStreams++
}
if stdout {
expectedStreams++
}
if stderr {
expectedStreams++
}
if expectedStreams == 1 {
return nil, fmt.Errorf("you must specify at least 1 of stdin, stdout, stderr")
}
return &options{
stdin: stdin,
stdout: stdout,
stderr: stderr,
tty: tty,
expectedStreams: expectedStreams,
}, nil
}
// context contains the connection and streams used when
// forwarding an attach or execute session into a container.
type context struct {
conn io.Closer
stdinStream io.ReadCloser
stdoutStream io.WriteCloser
stderrStream io.WriteCloser
errorStream io.WriteCloser
tty bool
}
// streamAndReply holds both a Stream and a channel that is closed when the stream's reply frame is
// enqueued. Consumers can wait for replySent to be closed prior to proceeding, to ensure that the
// replyFrame is enqueued before the connection's goaway frame is sent (e.g. if a stream was
// received and right after, the connection gets closed).
type streamAndReply struct {
httpstream.Stream
replySent <-chan struct{}
}
// waitStreamReply waits until either replySent or stop is closed. If replySent is closed, it sends
// an empty struct to the notify channel.
func waitStreamReply(replySent <-chan struct{}, notify chan<- struct{}, stop <-chan struct{}) {
select {
case <-replySent:
notify <- struct{}{}
case <-stop:
}
}
func createStreams(req *http.Request, w http.ResponseWriter, supportedStreamProtocols []string, idleTimeout, streamCreationTimeout time.Duration) (*context, bool) {
opts, err := newOptions(req)
if err != nil {
runtime.HandleError(err)
w.WriteHeader(http.StatusBadRequest)
fmt.Fprint(w, err.Error())
return nil, false
}
if wsstream.IsWebSocketRequest(req) {
return createWebSocketStreams(req, w, opts, idleTimeout)
}
protocol, err := httpstream.Handshake(req, w, supportedStreamProtocols)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
fmt.Fprint(w, err.Error())
return nil, false
}
streamCh := make(chan streamAndReply)
upgrader := spdy.NewResponseUpgrader()
conn := upgrader.UpgradeResponse(w, req, func(stream httpstream.Stream, replySent <-chan struct{}) error {
streamCh <- streamAndReply{Stream: stream, replySent: replySent}
return nil
})
// from this point on, we can no longer call methods on response
if conn == nil {
// The upgrader is responsible for notifying the client of any errors that
// occurred during upgrading. All we can do is return here at this point
// if we weren't successful in upgrading.
return nil, false
}
conn.SetIdleTimeout(idleTimeout)
var handler protocolHandler
switch protocol {
case StreamProtocolV2Name:
handler = &v2ProtocolHandler{}
case "":
glog.V(4).Infof("Client did not request protocol negotiaion. Falling back to %q", StreamProtocolV1Name)
fallthrough
case StreamProtocolV1Name:
handler = &v1ProtocolHandler{}
}
expired := time.NewTimer(streamCreationTimeout)
ctx, err := handler.waitForStreams(streamCh, opts.expectedStreams, expired.C)
if err != nil {
runtime.HandleError(err)
return nil, false
}
ctx.conn = conn
ctx.tty = opts.tty
return ctx, true
}
type protocolHandler interface {
// waitForStreams waits for the expected streams or a timeout, returning a
// remoteCommandContext if all the streams were received, or an error if not.
waitForStreams(streams <-chan streamAndReply, expectedStreams int, expired <-chan time.Time) (*context, error)
}
// v2ProtocolHandler implements the V2 protocol version for streaming command execution.
type v2ProtocolHandler struct{}
func (*v2ProtocolHandler) waitForStreams(streams <-chan streamAndReply, expectedStreams int, expired <-chan time.Time) (*context, error) {
ctx := &context{}
receivedStreams := 0
replyChan := make(chan struct{})
stop := make(chan struct{})
defer close(stop)
WaitForStreams:
for {
select {
case stream := <-streams:
streamType := stream.Headers().Get(api.StreamType)
switch streamType {
case api.StreamTypeError:
ctx.errorStream = stream
go waitStreamReply(stream.replySent, replyChan, stop)
case api.StreamTypeStdin:
ctx.stdinStream = stream
go waitStreamReply(stream.replySent, replyChan, stop)
case api.StreamTypeStdout:
ctx.stdoutStream = stream
go waitStreamReply(stream.replySent, replyChan, stop)
case api.StreamTypeStderr:
ctx.stderrStream = stream
go waitStreamReply(stream.replySent, replyChan, stop)
default:
runtime.HandleError(fmt.Errorf("Unexpected stream type: %q", streamType))
}
case <-replyChan:
receivedStreams++
if receivedStreams == expectedStreams {
break WaitForStreams
}
case <-expired:
// TODO find a way to return the error to the user. Maybe use a separate
// stream to report errors?
return nil, errors.New("timed out waiting for client to create streams")
}
}
return ctx, nil
}
// v1ProtocolHandler implements the V1 protocol version for streaming command execution.
type v1ProtocolHandler struct{}
func (*v1ProtocolHandler) waitForStreams(streams <-chan streamAndReply, expectedStreams int, expired <-chan time.Time) (*context, error) {
ctx := &context{}
receivedStreams := 0
replyChan := make(chan struct{})
stop := make(chan struct{})
defer close(stop)
WaitForStreams:
for {
select {
case stream := <-streams:
streamType := stream.Headers().Get(api.StreamType)
switch streamType {
case api.StreamTypeError:
ctx.errorStream = stream
// This defer statement shouldn't be here, but due to previous refactoring, it ended up in
// here. This is what 1.0.x kubelets do, so we're retaining that behavior. This is fixed in
// the v2ProtocolHandler.
defer stream.Reset()
go waitStreamReply(stream.replySent, replyChan, stop)
case api.StreamTypeStdin:
ctx.stdinStream = stream
go waitStreamReply(stream.replySent, replyChan, stop)
case api.StreamTypeStdout:
ctx.stdoutStream = stream
go waitStreamReply(stream.replySent, replyChan, stop)
case api.StreamTypeStderr:
ctx.stderrStream = stream
go waitStreamReply(stream.replySent, replyChan, stop)
default:
runtime.HandleError(fmt.Errorf("Unexpected stream type: %q", streamType))
}
case <-replyChan:
receivedStreams++
if receivedStreams == expectedStreams {
break WaitForStreams
}
case <-expired:
// TODO find a way to return the error to the user. Maybe use a separate
// stream to report errors?
return nil, errors.New("timed out waiting for client to create streams")
}
}
if ctx.stdinStream != nil {
ctx.stdinStream.Close()
}
return ctx, nil
}

View File

@ -0,0 +1,77 @@
/*
Copyright 2016 The Kubernetes Authors All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package remotecommand
import (
"net/http"
"time"
"k8s.io/kubernetes/pkg/httplog"
"k8s.io/kubernetes/pkg/util/wsstream"
"github.com/golang/glog"
)
// standardShellChannels returns the standard channel types for a shell connection (STDIN 0, STDOUT 1, STDERR 2)
// along with the approximate duplex value. Supported subprotocols are "channel.k8s.io" and
// "base64.channel.k8s.io".
func standardShellChannels(stdin, stdout, stderr bool) []wsstream.ChannelType {
// open three half-duplex channels
channels := []wsstream.ChannelType{wsstream.ReadChannel, wsstream.WriteChannel, wsstream.WriteChannel}
if !stdin {
channels[0] = wsstream.IgnoreChannel
}
if !stdout {
channels[1] = wsstream.IgnoreChannel
}
if !stderr {
channels[2] = wsstream.IgnoreChannel
}
return channels
}
// createWebSocketStreams returns a remoteCommandContext containing the websocket connection and
// streams needed to perform an exec or an attach.
func createWebSocketStreams(req *http.Request, w http.ResponseWriter, opts *options, idleTimeout time.Duration) (*context, bool) {
// open the requested channels, and always open the error channel
channels := append(standardShellChannels(opts.stdin, opts.stdout, opts.stderr), wsstream.WriteChannel)
conn := wsstream.NewConn(channels...)
conn.SetIdleTimeout(idleTimeout)
streams, err := conn.Open(httplog.Unlogged(w), req)
if err != nil {
glog.Errorf("Unable to upgrade websocket connection: %v", err)
return nil, false
}
// Send an empty message to the lowest writable channel to notify the client the connection is established
// TODO: make generic to SPDY and WebSockets and do it outside of this method?
switch {
case opts.stdout:
streams[1].Write([]byte{})
case opts.stderr:
streams[2].Write([]byte{})
default:
streams[3].Write([]byte{})
}
return &context{
conn: conn,
stdinStream: streams[0],
stdoutStream: streams[1],
stderrStream: streams[2],
errorStream: streams[3],
tty: opts.tty,
}, true
}

View File

@ -43,12 +43,12 @@ import (
"k8s.io/kubernetes/pkg/api/validation"
"k8s.io/kubernetes/pkg/auth/authenticator"
"k8s.io/kubernetes/pkg/auth/authorizer"
"k8s.io/kubernetes/pkg/client/unversioned/remotecommand"
"k8s.io/kubernetes/pkg/healthz"
"k8s.io/kubernetes/pkg/httplog"
"k8s.io/kubernetes/pkg/kubelet/cm"
kubecontainer "k8s.io/kubernetes/pkg/kubelet/container"
"k8s.io/kubernetes/pkg/kubelet/server/portforward"
"k8s.io/kubernetes/pkg/kubelet/server/remotecommand"
"k8s.io/kubernetes/pkg/kubelet/server/stats"
"k8s.io/kubernetes/pkg/runtime"
"k8s.io/kubernetes/pkg/types"
@ -58,7 +58,6 @@ import (
"k8s.io/kubernetes/pkg/util/httpstream/spdy"
"k8s.io/kubernetes/pkg/util/limitwriter"
utilruntime "k8s.io/kubernetes/pkg/util/runtime"
"k8s.io/kubernetes/pkg/util/wsstream"
"k8s.io/kubernetes/pkg/volume"
)
@ -540,12 +539,7 @@ func getContainerCoordinates(request *restful.Request) (namespace, pod string, u
return
}
const defaultStreamCreationTimeout = 30 * time.Second
type Closer interface {
Close() error
}
// getAttach handles requests to attach to a container.
func (s *Server) getAttach(request *restful.Request, response *restful.Response) {
podNamespace, podID, uid, container := getContainerCoordinates(request)
pod, ok := s.host.GetPodByName(podNamespace, podID)
@ -554,21 +548,35 @@ func (s *Server) getAttach(request *restful.Request, response *restful.Response)
return
}
stdinStream, stdoutStream, stderrStream, errorStream, conn, tty, ok := s.createStreams(request, response)
if conn != nil {
defer conn.Close()
}
remotecommand.ServeAttach(response.ResponseWriter,
request.Request,
s.host,
kubecontainer.GetPodFullName(pod),
uid,
container,
s.host.StreamingConnectionIdleTimeout(),
remotecommand.DefaultStreamCreationTimeout,
remotecommand.SupportedStreamingProtocols)
}
// getExec handles requests to run a command inside a container.
func (s *Server) getExec(request *restful.Request, response *restful.Response) {
podNamespace, podID, uid, container := getContainerCoordinates(request)
pod, ok := s.host.GetPodByName(podNamespace, podID)
if !ok {
// error is handled in the createStreams function
response.WriteError(http.StatusNotFound, fmt.Errorf("pod does not exist"))
return
}
err := s.host.AttachContainer(kubecontainer.GetPodFullName(pod), uid, container, stdinStream, stdoutStream, stderrStream, tty)
if err != nil {
msg := fmt.Sprintf("Error executing command in container: %v", err)
glog.Error(msg)
errorStream.Write([]byte(msg))
}
remotecommand.ServeExec(response.ResponseWriter,
request.Request,
s.host,
kubecontainer.GetPodFullName(pod),
uid,
container,
s.host.StreamingConnectionIdleTimeout(),
remotecommand.DefaultStreamCreationTimeout,
remotecommand.SupportedStreamingProtocols)
}
// getRun handles requests to run a command inside a container.
@ -588,187 +596,6 @@ func (s *Server) getRun(request *restful.Request, response *restful.Response) {
writeJsonResponse(response, data)
}
// getExec handles requests to run a command inside a container.
func (s *Server) getExec(request *restful.Request, response *restful.Response) {
podNamespace, podID, uid, container := getContainerCoordinates(request)
pod, ok := s.host.GetPodByName(podNamespace, podID)
if !ok {
response.WriteError(http.StatusNotFound, fmt.Errorf("pod does not exist"))
return
}
stdinStream, stdoutStream, stderrStream, errorStream, conn, tty, ok := s.createStreams(request, response)
if conn != nil {
defer conn.Close()
}
if !ok {
// error is handled in the createStreams function
return
}
cmd := request.Request.URL.Query()[api.ExecCommandParamm]
err := s.host.ExecInContainer(kubecontainer.GetPodFullName(pod), uid, container, cmd, stdinStream, stdoutStream, stderrStream, tty)
if err != nil {
msg := fmt.Sprintf("Error executing command in container: %v", err)
glog.Error(msg)
errorStream.Write([]byte(msg))
}
}
// standardShellChannels returns the standard channel types for a shell connection (STDIN 0, STDOUT 1, STDERR 2)
// along with the approprxate duplex value
func standardShellChannels(stdin, stdout, stderr bool) []wsstream.ChannelType {
// open three half-duplex channels
channels := []wsstream.ChannelType{wsstream.ReadChannel, wsstream.WriteChannel, wsstream.WriteChannel}
if !stdin {
channels[0] = wsstream.IgnoreChannel
}
if !stdout {
channels[1] = wsstream.IgnoreChannel
}
if !stderr {
channels[2] = wsstream.IgnoreChannel
}
return channels
}
// streamAndReply holds both a Stream and a channel that is closed when the stream's reply frame is
// enqueued. Consumers can wait for replySent to be closed prior to proceeding, to ensure that the
// replyFrame is enqueued before the connection's goaway frame is sent (e.g. if a stream was
// received and right after, the connection gets closed).
type streamAndReply struct {
httpstream.Stream
replySent <-chan struct{}
}
func (s *Server) createStreams(request *restful.Request, response *restful.Response) (io.Reader, io.WriteCloser, io.WriteCloser, io.WriteCloser, Closer, bool, bool) {
tty := request.QueryParameter(api.ExecTTYParam) == "1"
stdin := request.QueryParameter(api.ExecStdinParam) == "1"
stdout := request.QueryParameter(api.ExecStdoutParam) == "1"
stderr := request.QueryParameter(api.ExecStderrParam) == "1"
if tty && stderr {
// TODO: make this an error before we reach this method
glog.V(4).Infof("Access to exec with tty and stderr is not supported, bypassing stderr")
stderr = false
}
// count the streams client asked for, starting with 1
expectedStreams := 1
if stdin {
expectedStreams++
}
if stdout {
expectedStreams++
}
if stderr {
expectedStreams++
}
if expectedStreams == 1 {
response.WriteError(http.StatusBadRequest, fmt.Errorf("you must specify at least 1 of stdin, stdout, stderr"))
return nil, nil, nil, nil, nil, false, false
}
if wsstream.IsWebSocketRequest(request.Request) {
// open the requested channels, and always open the error channel
channels := append(standardShellChannels(stdin, stdout, stderr), wsstream.WriteChannel)
conn := wsstream.NewConn(channels...)
conn.SetIdleTimeout(s.host.StreamingConnectionIdleTimeout())
streams, err := conn.Open(httplog.Unlogged(response.ResponseWriter), request.Request)
if err != nil {
glog.Errorf("Unable to upgrade websocket connection: %v", err)
return nil, nil, nil, nil, nil, false, false
}
// Send an empty message to the lowest writable channel to notify the client the connection is established
// TODO: make generic to SDPY and WebSockets and do it outside of this method?
switch {
case stdout:
streams[1].Write([]byte{})
case stderr:
streams[2].Write([]byte{})
default:
streams[3].Write([]byte{})
}
return streams[0], streams[1], streams[2], streams[3], conn, tty, true
}
supportedStreamProtocols := []string{remotecommand.StreamProtocolV2Name, remotecommand.StreamProtocolV1Name}
_, err := httpstream.Handshake(request.Request, response.ResponseWriter, supportedStreamProtocols, remotecommand.StreamProtocolV1Name)
// negotiated protocol isn't used server side at the moment, but could be in the future
if err != nil {
return nil, nil, nil, nil, nil, false, false
}
streamCh := make(chan streamAndReply)
upgrader := spdy.NewResponseUpgrader()
conn := upgrader.UpgradeResponse(response.ResponseWriter, request.Request, func(stream httpstream.Stream, replySent <-chan struct{}) error {
streamCh <- streamAndReply{Stream: stream, replySent: replySent}
return nil
})
// from this point on, we can no longer call methods on response
if conn == nil {
// The upgrader is responsible for notifying the client of any errors that
// occurred during upgrading. All we can do is return here at this point
// if we weren't successful in upgrading.
return nil, nil, nil, nil, nil, false, false
}
conn.SetIdleTimeout(s.host.StreamingConnectionIdleTimeout())
// TODO make it configurable?
expired := time.NewTimer(defaultStreamCreationTimeout)
var errorStream, stdinStream, stdoutStream, stderrStream httpstream.Stream
receivedStreams := 0
replyChan := make(chan struct{})
stop := make(chan struct{})
defer close(stop)
WaitForStreams:
for {
select {
case stream := <-streamCh:
streamType := stream.Headers().Get(api.StreamType)
switch streamType {
case api.StreamTypeError:
errorStream = stream
go waitStreamReply(stream.replySent, replyChan, stop)
case api.StreamTypeStdin:
stdinStream = stream
go waitStreamReply(stream.replySent, replyChan, stop)
case api.StreamTypeStdout:
stdoutStream = stream
go waitStreamReply(stream.replySent, replyChan, stop)
case api.StreamTypeStderr:
stderrStream = stream
go waitStreamReply(stream.replySent, replyChan, stop)
default:
glog.Errorf("Unexpected stream type: '%s'", streamType)
}
case <-replyChan:
receivedStreams++
if receivedStreams == expectedStreams {
break WaitForStreams
}
case <-expired.C:
// TODO find a way to return the error to the user. Maybe use a separate
// stream to report errors?
glog.Error("Timed out waiting for client to create streams")
return nil, nil, nil, nil, nil, false, false
}
}
return stdinStream, stdoutStream, stderrStream, errorStream, conn, tty, true
}
// waitStreamReply waits until either replySent or stop is closed. If replySent is closed, it sends
// an empty struct to the notify channel.
func waitStreamReply(replySent <-chan struct{}, notify chan<- struct{}, stop <-chan struct{}) {
select {
case <-replySent:
notify <- struct{}{}
case <-stop:
}
}
func getPodCoordinates(request *restful.Request) (namespace, pod string, uid types.UID) {
namespace = request.PathParameter("podNamespace")
pod = request.PathParameter("podID")
@ -811,7 +638,7 @@ func (s *Server) getPortForward(request *restful.Request, response *restful.Resp
podName := kubecontainer.GetPodFullName(pod)
ServePortForward(response.ResponseWriter, request.Request, s.host, podName, uid, s.host.StreamingConnectionIdleTimeout(), defaultStreamCreationTimeout)
ServePortForward(response.ResponseWriter, request.Request, s.host, podName, uid, s.host.StreamingConnectionIdleTimeout(), remotecommand.DefaultStreamCreationTimeout)
}
// ServePortForward handles a port forwarding request. A single request is
@ -821,7 +648,7 @@ func (s *Server) getPortForward(request *restful.Request, response *restful.Resp
// handled by a single invocation of ServePortForward.
func ServePortForward(w http.ResponseWriter, req *http.Request, portForwarder PortForwarder, podName string, uid types.UID, idleTimeout time.Duration, streamCreationTimeout time.Duration) {
supportedPortForwardProtocols := []string{portforward.PortForwardProtocolV1Name}
_, err := httpstream.Handshake(req, w, supportedPortForwardProtocols, portforward.PortForwardProtocolV1Name)
_, err := httpstream.Handshake(req, w, supportedPortForwardProtocols)
// negotiated protocol isn't currently used server side, but could be in the future
if err != nil {
// Handshake writes the error to the client

View File

@ -1019,7 +1019,7 @@ func TestServeExecInContainerIdleTimeout(t *testing.T) {
<-conn.CloseChan()
}
func TestServeExecInContainer(t *testing.T) {
func testExecAttach(t *testing.T, verb string) {
tests := []struct {
stdin bool
stdout bool
@ -1053,12 +1053,15 @@ func TestServeExecInContainer(t *testing.T) {
expectedStdin := "stdin"
expectedStdout := "stdout"
expectedStderr := "stderr"
execFuncDone := make(chan struct{})
done := make(chan struct{})
clientStdoutReadDone := make(chan struct{})
clientStderrReadDone := make(chan struct{})
execInvoked := false
attachInvoked := false
testStreamFunc := func(podFullName string, uid types.UID, containerName string, cmd []string, in io.Reader, out, stderr io.WriteCloser, tty bool, done chan struct{}) error {
defer close(done)
fw.fakeKubelet.execFunc = func(podFullName string, uid types.UID, containerName string, cmd []string, in io.Reader, out, stderr io.WriteCloser, tty bool) error {
defer close(execFuncDone)
if podFullName != expectedPodName {
t.Fatalf("%d: podFullName: expected %s, got %s", i, expectedPodName, podFullName)
}
@ -1068,66 +1071,79 @@ func TestServeExecInContainer(t *testing.T) {
if containerName != expectedContainerName {
t.Fatalf("%d: containerName: expected %s, got %s", i, expectedContainerName, containerName)
}
if test.stdin {
if in == nil {
t.Fatalf("%d: stdin: expected non-nil", i)
}
b := make([]byte, 10)
n, err := in.Read(b)
if err != nil {
t.Fatalf("%d: error reading from stdin: %v", i, err)
}
if e, a := expectedStdin, string(b[0:n]); e != a {
t.Fatalf("%d: stdin: expected to read %v, got %v", i, e, a)
}
} else if in != nil {
t.Fatalf("%d: stdin: expected nil: %#v", i, in)
}
if test.stdout {
if out == nil {
t.Fatalf("%d: stdout: expected non-nil", i)
}
_, err := out.Write([]byte(expectedStdout))
if err != nil {
t.Fatalf("%d:, error writing to stdout: %v", i, err)
}
out.Close()
<-clientStdoutReadDone
} else if out != nil {
t.Fatalf("%d: stdout: expected nil: %#v", i, out)
}
if tty {
if stderr != nil {
t.Fatalf("%d: tty set but received non-nil stderr: %v", i, stderr)
}
} else if test.stderr {
if stderr == nil {
t.Fatalf("%d: stderr: expected non-nil", i)
}
_, err := stderr.Write([]byte(expectedStderr))
if err != nil {
t.Fatalf("%d:, error writing to stderr: %v", i, err)
}
stderr.Close()
<-clientStderrReadDone
} else if stderr != nil {
t.Fatalf("%d: stderr: expected nil: %#v", i, stderr)
}
return nil
}
fw.fakeKubelet.execFunc = func(podFullName string, uid types.UID, containerName string, cmd []string, in io.Reader, out, stderr io.WriteCloser, tty bool) error {
execInvoked = true
if strings.Join(cmd, " ") != expectedCommand {
t.Fatalf("%d: cmd: expected: %s, got %v", i, expectedCommand, cmd)
}
return testStreamFunc(podFullName, uid, containerName, cmd, in, out, stderr, tty, done)
}
if test.stdin {
if in == nil {
t.Fatalf("%d: stdin: expected non-nil", i)
}
b := make([]byte, 10)
n, err := in.Read(b)
if err != nil {
t.Fatalf("%d: error reading from stdin: %v", i, err)
}
if e, a := expectedStdin, string(b[0:n]); e != a {
t.Fatalf("%d: stdin: expected to read %v, got %v", i, e, a)
}
} else if in != nil {
t.Fatalf("%d: stdin: expected nil: %#v", i, in)
}
if test.stdout {
if out == nil {
t.Fatalf("%d: stdout: expected non-nil", i)
}
_, err := out.Write([]byte(expectedStdout))
if err != nil {
t.Fatalf("%d:, error writing to stdout: %v", i, err)
}
out.Close()
<-clientStdoutReadDone
} else if out != nil {
t.Fatalf("%d: stdout: expected nil: %#v", i, out)
}
if tty {
if stderr != nil {
t.Fatalf("%d: tty set but received non-nil stderr: %v", i, stderr)
}
} else if test.stderr {
if stderr == nil {
t.Fatalf("%d: stderr: expected non-nil", i)
}
_, err := stderr.Write([]byte(expectedStderr))
if err != nil {
t.Fatalf("%d:, error writing to stderr: %v", i, err)
}
stderr.Close()
<-clientStderrReadDone
} else if stderr != nil {
t.Fatalf("%d: stderr: expected nil: %#v", i, stderr)
}
return nil
fw.fakeKubelet.attachFunc = func(podFullName string, uid types.UID, containerName string, in io.Reader, out, stderr io.WriteCloser, tty bool) error {
attachInvoked = true
return testStreamFunc(podFullName, uid, containerName, nil, in, out, stderr, tty, done)
}
var url string
if test.uid {
url = fw.testHTTPServer.URL + "/exec/" + podNamespace + "/" + podName + "/" + expectedUid + "/" + expectedContainerName + "?command=ls&command=-a"
url = fw.testHTTPServer.URL + "/" + verb + "/" + podNamespace + "/" + podName + "/" + expectedUid + "/" + expectedContainerName + "?ignore=1"
} else {
url = fw.testHTTPServer.URL + "/exec/" + podNamespace + "/" + podName + "/" + expectedContainerName + "?command=ls&command=-a"
url = fw.testHTTPServer.URL + "/" + verb + "/" + podNamespace + "/" + podName + "/" + expectedContainerName + "?ignore=1"
}
if verb == "exec" {
url += "&command=ls&command=-a"
}
if test.stdin {
url += "&" + api.ExecStdinParam + "=1"
@ -1186,11 +1202,9 @@ func TestServeExecInContainer(t *testing.T) {
h := http.Header{}
h.Set(api.StreamType, api.StreamTypeError)
errorStream, err := conn.CreateStream(h)
if err != nil {
if _, err := conn.CreateStream(h); err != nil {
t.Fatalf("%d: error creating error stream: %v", i, err)
}
defer errorStream.Reset()
if test.stdin {
h.Set(api.StreamType, api.StreamTypeStdin)
@ -1198,7 +1212,6 @@ func TestServeExecInContainer(t *testing.T) {
if err != nil {
t.Fatalf("%d: error creating stdin stream: %v", i, err)
}
defer stream.Reset()
_, err = stream.Write([]byte(expectedStdin))
if err != nil {
t.Fatalf("%d: error writing to stdin stream: %v", i, err)
@ -1212,7 +1225,6 @@ func TestServeExecInContainer(t *testing.T) {
if err != nil {
t.Fatalf("%d: error creating stdout stream: %v", i, err)
}
defer stdoutStream.Reset()
}
var stderrStream httpstream.Stream
@ -1222,7 +1234,6 @@ func TestServeExecInContainer(t *testing.T) {
if err != nil {
t.Fatalf("%d: error creating stderr stream: %v", i, err)
}
defer stderrStream.Reset()
}
if test.stdout {
@ -1249,239 +1260,33 @@ func TestServeExecInContainer(t *testing.T) {
}
}
<-execFuncDone
// wait for the server to finish before checking if the attach/exec funcs were invoked
<-done
if verb == "exec" {
if !execInvoked {
t.Errorf("%d: exec was not invoked", i)
}
if attachInvoked {
t.Errorf("%d: attach should not have been invoked", i)
}
} else {
if !attachInvoked {
t.Errorf("%d: attach was not invoked", i)
}
if execInvoked {
t.Errorf("%d: exec should not have been invoked", i)
}
}
}
}
// TODO: largely cloned from TestServeExecContainer, refactor and re-use code
func TestServeExecInContainer(t *testing.T) {
testExecAttach(t, "exec")
}
func TestServeAttachContainer(t *testing.T) {
tests := []struct {
stdin bool
stdout bool
stderr bool
tty bool
responseStatusCode int
uid bool
}{
{responseStatusCode: http.StatusBadRequest},
{stdin: true, responseStatusCode: http.StatusSwitchingProtocols},
{stdout: true, responseStatusCode: http.StatusSwitchingProtocols},
{stderr: true, responseStatusCode: http.StatusSwitchingProtocols},
{stdout: true, stderr: true, responseStatusCode: http.StatusSwitchingProtocols},
{stdout: true, stderr: true, tty: true, responseStatusCode: http.StatusSwitchingProtocols},
{stdin: true, stdout: true, stderr: true, responseStatusCode: http.StatusSwitchingProtocols},
}
for i, test := range tests {
fw := newServerTest()
fw.fakeKubelet.streamingConnectionIdleTimeoutFunc = func() time.Duration {
return 0
}
podNamespace := "other"
podName := "foo"
expectedPodName := getPodName(podName, podNamespace)
expectedUid := "9b01b80f-8fb4-11e4-95ab-4200af06647"
expectedContainerName := "baz"
expectedStdin := "stdin"
expectedStdout := "stdout"
expectedStderr := "stderr"
attachFuncDone := make(chan struct{})
clientStdoutReadDone := make(chan struct{})
clientStderrReadDone := make(chan struct{})
fw.fakeKubelet.attachFunc = func(podFullName string, uid types.UID, containerName string, in io.Reader, out, stderr io.WriteCloser, tty bool) error {
defer close(attachFuncDone)
if podFullName != expectedPodName {
t.Fatalf("%d: podFullName: expected %s, got %s", i, expectedPodName, podFullName)
}
if test.uid && string(uid) != expectedUid {
t.Fatalf("%d: uid: expected %v, got %v", i, expectedUid, uid)
}
if containerName != expectedContainerName {
t.Fatalf("%d: containerName: expected %s, got %s", i, expectedContainerName, containerName)
}
if test.stdin {
if in == nil {
t.Fatalf("%d: stdin: expected non-nil", i)
}
b := make([]byte, 10)
n, err := in.Read(b)
if err != nil {
t.Fatalf("%d: error reading from stdin: %v", i, err)
}
if e, a := expectedStdin, string(b[0:n]); e != a {
t.Fatalf("%d: stdin: expected to read %v, got %v", i, e, a)
}
} else if in != nil {
t.Fatalf("%d: stdin: expected nil: %#v", i, in)
}
if test.stdout {
if out == nil {
t.Fatalf("%d: stdout: expected non-nil", i)
}
_, err := out.Write([]byte(expectedStdout))
if err != nil {
t.Fatalf("%d:, error writing to stdout: %v", i, err)
}
out.Close()
<-clientStdoutReadDone
} else if out != nil {
t.Fatalf("%d: stdout: expected nil: %#v", i, out)
}
if tty {
if stderr != nil {
t.Fatalf("%d: tty set but received non-nil stderr: %v", i, stderr)
}
} else if test.stderr {
if stderr == nil {
t.Fatalf("%d: stderr: expected non-nil", i)
}
_, err := stderr.Write([]byte(expectedStderr))
if err != nil {
t.Fatalf("%d:, error writing to stderr: %v", i, err)
}
stderr.Close()
<-clientStderrReadDone
} else if stderr != nil {
t.Fatalf("%d: stderr: expected nil: %#v", i, stderr)
}
return nil
}
var url string
if test.uid {
url = fw.testHTTPServer.URL + "/attach/" + podNamespace + "/" + podName + "/" + expectedUid + "/" + expectedContainerName + "?"
} else {
url = fw.testHTTPServer.URL + "/attach/" + podNamespace + "/" + podName + "/" + expectedContainerName + "?"
}
if test.stdin {
url += "&" + api.ExecStdinParam + "=1"
}
if test.stdout {
url += "&" + api.ExecStdoutParam + "=1"
}
if test.stderr && !test.tty {
url += "&" + api.ExecStderrParam + "=1"
}
if test.tty {
url += "&" + api.ExecTTYParam + "=1"
}
var (
resp *http.Response
err error
upgradeRoundTripper httpstream.UpgradeRoundTripper
c *http.Client
)
if test.responseStatusCode != http.StatusSwitchingProtocols {
c = &http.Client{}
} else {
upgradeRoundTripper = spdy.NewRoundTripper(nil)
c = &http.Client{Transport: upgradeRoundTripper}
}
resp, err = c.Post(url, "", nil)
if err != nil {
t.Fatalf("%d: Got error POSTing: %v", i, err)
}
defer resp.Body.Close()
_, err = ioutil.ReadAll(resp.Body)
if err != nil {
t.Errorf("%d: Error reading response body: %v", i, err)
}
if e, a := test.responseStatusCode, resp.StatusCode; e != a {
t.Fatalf("%d: response status: expected %v, got %v", i, e, a)
}
if test.responseStatusCode != http.StatusSwitchingProtocols {
continue
}
conn, err := upgradeRoundTripper.NewConnection(resp)
if err != nil {
t.Fatalf("Unexpected error creating streaming connection: %s", err)
}
if conn == nil {
t.Fatalf("%d: unexpected nil conn", i)
}
defer conn.Close()
h := http.Header{}
h.Set(api.StreamType, api.StreamTypeError)
errorStream, err := conn.CreateStream(h)
if err != nil {
t.Fatalf("%d: error creating error stream: %v", i, err)
}
defer errorStream.Reset()
if test.stdin {
h.Set(api.StreamType, api.StreamTypeStdin)
stream, err := conn.CreateStream(h)
if err != nil {
t.Fatalf("%d: error creating stdin stream: %v", i, err)
}
defer stream.Reset()
_, err = stream.Write([]byte(expectedStdin))
if err != nil {
t.Fatalf("%d: error writing to stdin stream: %v", i, err)
}
}
var stdoutStream httpstream.Stream
if test.stdout {
h.Set(api.StreamType, api.StreamTypeStdout)
stdoutStream, err = conn.CreateStream(h)
if err != nil {
t.Fatalf("%d: error creating stdout stream: %v", i, err)
}
defer stdoutStream.Reset()
}
var stderrStream httpstream.Stream
if test.stderr && !test.tty {
h.Set(api.StreamType, api.StreamTypeStderr)
stderrStream, err = conn.CreateStream(h)
if err != nil {
t.Fatalf("%d: error creating stderr stream: %v", i, err)
}
defer stderrStream.Reset()
}
if test.stdout {
output := make([]byte, 10)
n, err := stdoutStream.Read(output)
close(clientStdoutReadDone)
if err != nil {
t.Fatalf("%d: error reading from stdout stream: %v", i, err)
}
if e, a := expectedStdout, string(output[0:n]); e != a {
t.Fatalf("%d: stdout: expected '%v', got '%v'", i, e, a)
}
}
if test.stderr && !test.tty {
output := make([]byte, 10)
n, err := stderrStream.Read(output)
close(clientStderrReadDone)
if err != nil {
t.Fatalf("%d: error reading from stderr stream: %v", i, err)
}
if e, a := expectedStderr, string(output[0:n]); e != a {
t.Fatalf("%d: stderr: expected '%v', got '%v'", i, e, a)
}
}
<-attachFuncDone
}
testExecAttach(t, "attach")
}
func TestServePortForwardIdleTimeout(t *testing.T) {

View File

@ -114,20 +114,24 @@ func negotiateProtocol(clientProtocols, serverProtocols []string) string {
return ""
}
// Handshake performs a subprotocol negotiation. If the client did not request
// a specific subprotocol, defaultProtocol is used. If the client did request a
// Handshake performs a subprotocol negotiation. If the client did request a
// subprotocol, Handshake will select the first common value found in
// serverProtocols. If a match is found, Handshake adds a response header
// indicating the chosen subprotocol. If no match is found, HTTP forbidden is
// returned, along with a response header containing the list of protocols the
// server can accept.
func Handshake(req *http.Request, w http.ResponseWriter, serverProtocols []string, defaultProtocol string) (string, error) {
func Handshake(req *http.Request, w http.ResponseWriter, serverProtocols []string) (string, error) {
clientProtocols := req.Header[http.CanonicalHeaderKey(HeaderProtocolVersion)]
if len(clientProtocols) == 0 {
// Kube 1.0 client that didn't support subprotocol negotiation
// TODO remove this defaulting logic once Kube 1.0 is no longer supported
w.Header().Add(HeaderProtocolVersion, defaultProtocol)
return defaultProtocol, nil
// Kube 1.0 clients didn't support subprotocol negotiation.
// TODO require clientProtocols once Kube 1.0 is no longer supported
return "", nil
}
if len(serverProtocols) == 0 {
// Kube 1.0 servers didn't support subprotocol negotiation. This is mainly for testing.
// TODO require serverProtocols once Kube 1.0 is no longer supported
return "", nil
}
negotiatedProtocol := negotiateProtocol(clientProtocols, serverProtocols)

View File

@ -20,6 +20,8 @@ import (
"net/http"
"reflect"
"testing"
"k8s.io/kubernetes/pkg/api"
)
type responseWriter struct {
@ -46,8 +48,6 @@ func (r *responseWriter) Write([]byte) (int, error) {
}
func TestHandshake(t *testing.T) {
defaultProtocol := "default"
tests := map[string]struct {
clientProtocols []string
serverProtocols []string
@ -57,7 +57,7 @@ func TestHandshake(t *testing.T) {
"no client protocols": {
clientProtocols: []string{},
serverProtocols: []string{"a", "b"},
expectedProtocol: defaultProtocol,
expectedProtocol: "",
},
"no common protocol": {
clientProtocols: []string{"c"},
@ -83,7 +83,7 @@ func TestHandshake(t *testing.T) {
}
w := newResponseWriter()
negotiated, err := Handshake(req, w, test.serverProtocols, defaultProtocol)
negotiated, err := Handshake(req, w, test.serverProtocols)
// verify negotiated protocol
if e, a := test.expectedProtocol, negotiated; e != a {
@ -112,8 +112,15 @@ func TestHandshake(t *testing.T) {
t.Errorf("%s: unexpected non-nil w.statusCode: %d", w.statusCode)
}
if len(test.expectedProtocol) == 0 {
if len(w.Header()[HeaderProtocolVersion]) > 0 {
t.Errorf("%s: unexpected protocol version response header: %s", w.Header()[HeaderProtocolVersion])
}
continue
}
// verify response headers
if e, a := []string{test.expectedProtocol}, w.Header()[HeaderProtocolVersion]; !reflect.DeepEqual(e, a) {
if e, a := []string{test.expectedProtocol}, w.Header()[HeaderProtocolVersion]; !api.Semantic.DeepEqual(e, a) {
t.Errorf("%s: protocol response header: expected %v, got %v", name, e, a)
}
}