mirror of
https://github.com/k3s-io/kubernetes.git
synced 2025-07-30 15:05:27 +00:00
Refactor exec code to support version skew testing
Refactor exec/attach client and server code to better support interoperability testing of different client and server subprotocol versions.
This commit is contained in:
parent
d124deeb2f
commit
4551ba6b53
@ -26,6 +26,7 @@ import (
|
||||
|
||||
"k8s.io/kubernetes/pkg/client/restclient"
|
||||
"k8s.io/kubernetes/pkg/client/transport"
|
||||
"k8s.io/kubernetes/pkg/kubelet/server/remotecommand"
|
||||
"k8s.io/kubernetes/pkg/util/httpstream"
|
||||
"k8s.io/kubernetes/pkg/util/httpstream/spdy"
|
||||
)
|
||||
@ -36,7 +37,7 @@ type Executor interface {
|
||||
// non-nil stream to a remote system, and return an error if a problem occurs. If tty
|
||||
// is set, the stderr stream is not used (raw TTY manages stdout and stderr over the
|
||||
// stdout stream).
|
||||
Stream(stdin io.Reader, stdout, stderr io.Writer, tty bool) error
|
||||
Stream(supportedProtocols []string, stdin io.Reader, stdout, stderr io.Writer, tty bool) error
|
||||
}
|
||||
|
||||
// StreamExecutor supports the ability to dial an httpstream connection and the ability to
|
||||
@ -128,26 +129,13 @@ func (e *streamExecutor) Dial(protocols ...string) (httpstream.Connection, strin
|
||||
return conn, resp.Header.Get(httpstream.HeaderProtocolVersion), nil
|
||||
}
|
||||
|
||||
const (
|
||||
// The SPDY subprotocol "channel.k8s.io" is used for remote command
|
||||
// attachment/execution. This represents the initial unversioned subprotocol,
|
||||
// which has the known bugs http://issues.k8s.io/13394 and
|
||||
// http://issues.k8s.io/13395.
|
||||
StreamProtocolV1Name = "channel.k8s.io"
|
||||
// The SPDY subprotocol "v2.channel.k8s.io" is used for remote command
|
||||
// attachment/execution. It is the second version of the subprotocol and
|
||||
// resolves the issues present in the first version.
|
||||
StreamProtocolV2Name = "v2.channel.k8s.io"
|
||||
)
|
||||
|
||||
type streamProtocolHandler interface {
|
||||
stream(httpstream.Connection) error
|
||||
}
|
||||
|
||||
// Stream opens a protocol streamer to the server and streams until a client closes
|
||||
// the connection or the server disconnects.
|
||||
func (e *streamExecutor) Stream(stdin io.Reader, stdout, stderr io.Writer, tty bool) error {
|
||||
supportedProtocols := []string{StreamProtocolV2Name, StreamProtocolV1Name}
|
||||
func (e *streamExecutor) Stream(supportedProtocols []string, stdin io.Reader, stdout, stderr io.Writer, tty bool) error {
|
||||
conn, protocol, err := e.Dial(supportedProtocols...)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -157,7 +145,7 @@ func (e *streamExecutor) Stream(stdin io.Reader, stdout, stderr io.Writer, tty b
|
||||
var streamer streamProtocolHandler
|
||||
|
||||
switch protocol {
|
||||
case StreamProtocolV2Name:
|
||||
case remotecommand.StreamProtocolV2Name:
|
||||
streamer = &streamProtocolV2{
|
||||
stdin: stdin,
|
||||
stdout: stdout,
|
||||
@ -165,9 +153,9 @@ func (e *streamExecutor) Stream(stdin io.Reader, stdout, stderr io.Writer, tty b
|
||||
tty: tty,
|
||||
}
|
||||
case "":
|
||||
glog.V(4).Infof("The server did not negotiate a streaming protocol version. Falling back to %s", StreamProtocolV1Name)
|
||||
glog.V(4).Infof("The server did not negotiate a streaming protocol version. Falling back to %s", remotecommand.StreamProtocolV1Name)
|
||||
fallthrough
|
||||
case StreamProtocolV1Name:
|
||||
case remotecommand.StreamProtocolV1Name:
|
||||
streamer = &streamProtocolV1{
|
||||
stdin: stdin,
|
||||
stdout: stdout,
|
||||
|
@ -18,6 +18,7 @@ package remotecommand
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
@ -26,325 +27,263 @@ import (
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"k8s.io/kubernetes/pkg/api"
|
||||
"k8s.io/kubernetes/pkg/api/unversioned"
|
||||
"k8s.io/kubernetes/pkg/client/restclient"
|
||||
"k8s.io/kubernetes/pkg/kubelet/server/remotecommand"
|
||||
"k8s.io/kubernetes/pkg/types"
|
||||
"k8s.io/kubernetes/pkg/util/httpstream"
|
||||
"k8s.io/kubernetes/pkg/util/httpstream/spdy"
|
||||
)
|
||||
|
||||
type streamAndReply struct {
|
||||
httpstream.Stream
|
||||
replySent <-chan struct{}
|
||||
type fakeExecutor struct {
|
||||
t *testing.T
|
||||
testName string
|
||||
errorData string
|
||||
stdoutData string
|
||||
stderrData string
|
||||
expectStdin bool
|
||||
stdinReceived bytes.Buffer
|
||||
tty bool
|
||||
messageCount int
|
||||
command []string
|
||||
exec bool
|
||||
}
|
||||
|
||||
func waitStreamReply(replySent <-chan struct{}, notify chan<- struct{}, stop <-chan struct{}) {
|
||||
select {
|
||||
case <-replySent:
|
||||
notify <- struct{}{}
|
||||
case <-stop:
|
||||
}
|
||||
func (ex *fakeExecutor) ExecInContainer(name string, uid types.UID, container string, cmd []string, in io.Reader, out, err io.WriteCloser, tty bool) error {
|
||||
return ex.run(name, uid, container, cmd, in, out, err, tty)
|
||||
}
|
||||
|
||||
func fakeExecServer(t *testing.T, i int, stdinData, stdoutData, stderrData, errorData string, tty bool, messageCount int) http.HandlerFunc {
|
||||
// error + stdin + stdout
|
||||
expectedStreams := 3
|
||||
if !tty {
|
||||
// stderr
|
||||
expectedStreams++
|
||||
func (ex *fakeExecutor) AttachContainer(name string, uid types.UID, container string, in io.Reader, out, err io.WriteCloser, tty bool) error {
|
||||
return ex.run(name, uid, container, nil, in, out, err, tty)
|
||||
}
|
||||
|
||||
func (ex *fakeExecutor) run(name string, uid types.UID, container string, cmd []string, in io.Reader, out, err io.WriteCloser, tty bool) error {
|
||||
ex.command = cmd
|
||||
ex.tty = tty
|
||||
|
||||
if e, a := "pod", name; e != a {
|
||||
ex.t.Errorf("%s: pod: expected %q, got %q", ex.testName, e, a)
|
||||
}
|
||||
if e, a := "uid", uid; e != string(a) {
|
||||
ex.t.Errorf("%s: uid: expected %q, got %q", ex.testName, e, a)
|
||||
}
|
||||
if ex.exec {
|
||||
if e, a := "ls /", strings.Join(ex.command, " "); e != a {
|
||||
ex.t.Errorf("%s: command: expected %q, got %q", ex.testName, e, a)
|
||||
}
|
||||
} else {
|
||||
if len(ex.command) > 0 {
|
||||
ex.t.Errorf("%s: command: expected nothing, got %v", ex.testName, ex.command)
|
||||
}
|
||||
}
|
||||
|
||||
if len(ex.errorData) > 0 {
|
||||
return errors.New(ex.errorData)
|
||||
}
|
||||
|
||||
if len(ex.stdoutData) > 0 {
|
||||
for i := 0; i < ex.messageCount; i++ {
|
||||
fmt.Fprint(out, ex.stdoutData)
|
||||
}
|
||||
}
|
||||
|
||||
if len(ex.stderrData) > 0 {
|
||||
for i := 0; i < ex.messageCount; i++ {
|
||||
fmt.Fprint(err, ex.stderrData)
|
||||
}
|
||||
}
|
||||
|
||||
if ex.expectStdin {
|
||||
io.Copy(&ex.stdinReceived, in)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func fakeServer(t *testing.T, testName string, exec bool, stdinData, stdoutData, stderrData, errorData string, tty bool, messageCount int, serverProtocols []string) http.HandlerFunc {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
protocol, err := httpstream.Handshake(req, w, []string{StreamProtocolV2Name}, StreamProtocolV1Name)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if protocol != StreamProtocolV2Name {
|
||||
t.Fatalf("unexpected protocol: %s", protocol)
|
||||
}
|
||||
streamCh := make(chan streamAndReply)
|
||||
|
||||
upgrader := spdy.NewResponseUpgrader()
|
||||
conn := upgrader.UpgradeResponse(w, req, func(stream httpstream.Stream, replySent <-chan struct{}) error {
|
||||
streamCh <- streamAndReply{Stream: stream, replySent: replySent}
|
||||
return nil
|
||||
})
|
||||
// from this point on, we can no longer call methods on w
|
||||
if conn == nil {
|
||||
// The upgrader is responsible for notifying the client of any errors that
|
||||
// occurred during upgrading. All we can do is return here at this point
|
||||
// if we weren't successful in upgrading.
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
var errorStream, stdinStream, stdoutStream, stderrStream httpstream.Stream
|
||||
receivedStreams := 0
|
||||
replyChan := make(chan struct{})
|
||||
stop := make(chan struct{})
|
||||
defer close(stop)
|
||||
WaitForStreams:
|
||||
for {
|
||||
select {
|
||||
case stream := <-streamCh:
|
||||
streamType := stream.Headers().Get(api.StreamType)
|
||||
switch streamType {
|
||||
case api.StreamTypeError:
|
||||
errorStream = stream
|
||||
go waitStreamReply(stream.replySent, replyChan, stop)
|
||||
case api.StreamTypeStdin:
|
||||
stdinStream = stream
|
||||
go waitStreamReply(stream.replySent, replyChan, stop)
|
||||
case api.StreamTypeStdout:
|
||||
stdoutStream = stream
|
||||
go waitStreamReply(stream.replySent, replyChan, stop)
|
||||
case api.StreamTypeStderr:
|
||||
stderrStream = stream
|
||||
go waitStreamReply(stream.replySent, replyChan, stop)
|
||||
default:
|
||||
t.Errorf("%d: unexpected stream type: %q", i, streamType)
|
||||
}
|
||||
|
||||
if receivedStreams == expectedStreams {
|
||||
break WaitForStreams
|
||||
}
|
||||
case <-replyChan:
|
||||
receivedStreams++
|
||||
if receivedStreams == expectedStreams {
|
||||
break WaitForStreams
|
||||
}
|
||||
}
|
||||
executor := &fakeExecutor{
|
||||
t: t,
|
||||
testName: testName,
|
||||
errorData: errorData,
|
||||
stdoutData: stdoutData,
|
||||
stderrData: stderrData,
|
||||
expectStdin: len(stdinData) > 0,
|
||||
tty: tty,
|
||||
messageCount: messageCount,
|
||||
exec: exec,
|
||||
}
|
||||
|
||||
if len(errorData) > 0 {
|
||||
n, err := fmt.Fprint(errorStream, errorData)
|
||||
if err != nil {
|
||||
t.Errorf("%d: error writing to errorStream: %v", i, err)
|
||||
}
|
||||
if e, a := len(errorData), n; e != a {
|
||||
t.Errorf("%d: expected to write %d bytes to errorStream, but only wrote %d", i, e, a)
|
||||
}
|
||||
errorStream.Close()
|
||||
if exec {
|
||||
remotecommand.ServeExec(w, req, executor, "pod", "uid", "container", 0, 10*time.Second, serverProtocols)
|
||||
} else {
|
||||
remotecommand.ServeAttach(w, req, executor, "pod", "uid", "container", 0, 10*time.Second, serverProtocols)
|
||||
}
|
||||
|
||||
if len(stdoutData) > 0 {
|
||||
for j := 0; j < messageCount; j++ {
|
||||
n, err := fmt.Fprint(stdoutStream, stdoutData)
|
||||
if err != nil {
|
||||
t.Errorf("%d: error writing to stdoutStream: %v", i, err)
|
||||
}
|
||||
if e, a := len(stdoutData), n; e != a {
|
||||
t.Errorf("%d: expected to write %d bytes to stdoutStream, but only wrote %d", i, e, a)
|
||||
}
|
||||
}
|
||||
stdoutStream.Close()
|
||||
}
|
||||
if len(stderrData) > 0 {
|
||||
for j := 0; j < messageCount; j++ {
|
||||
n, err := fmt.Fprint(stderrStream, stderrData)
|
||||
if err != nil {
|
||||
t.Errorf("%d: error writing to stderrStream: %v", i, err)
|
||||
}
|
||||
if e, a := len(stderrData), n; e != a {
|
||||
t.Errorf("%d: expected to write %d bytes to stderrStream, but only wrote %d", i, e, a)
|
||||
}
|
||||
}
|
||||
stderrStream.Close()
|
||||
}
|
||||
if len(stdinData) > 0 {
|
||||
data := make([]byte, len(stdinData))
|
||||
for j := 0; j < messageCount; j++ {
|
||||
n, err := io.ReadFull(stdinStream, data)
|
||||
if err != nil {
|
||||
t.Errorf("%d: error reading stdin stream: %v", i, err)
|
||||
}
|
||||
if e, a := len(stdinData), n; e != a {
|
||||
t.Errorf("%d: expected to read %d bytes from stdinStream, but only read %d", i, e, a)
|
||||
}
|
||||
if e, a := stdinData, string(data); e != a {
|
||||
t.Errorf("%d: stdin: expected %q, got %q", i, e, a)
|
||||
}
|
||||
}
|
||||
stdinStream.Close()
|
||||
if e, a := strings.Repeat(stdinData, messageCount), executor.stdinReceived.String(); e != a {
|
||||
t.Errorf("%s: stdin: expected %q, got %q", testName, e, a)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRequestExecuteRemoteCommand(t *testing.T) {
|
||||
func TestStream(t *testing.T) {
|
||||
testCases := []struct {
|
||||
Stdin string
|
||||
Stdout string
|
||||
Stderr string
|
||||
Error string
|
||||
Tty bool
|
||||
MessageCount int
|
||||
TestName string
|
||||
Stdin string
|
||||
Stdout string
|
||||
Stderr string
|
||||
Error string
|
||||
Tty bool
|
||||
MessageCount int
|
||||
ClientProtocols []string
|
||||
ServerProtocols []string
|
||||
}{
|
||||
{
|
||||
Error: "bail",
|
||||
TestName: "error",
|
||||
Error: "bail",
|
||||
Stdout: "a",
|
||||
ClientProtocols: []string{remotecommand.StreamProtocolV2Name},
|
||||
ServerProtocols: []string{remotecommand.StreamProtocolV2Name},
|
||||
},
|
||||
{
|
||||
Stdin: "a",
|
||||
Stdout: "b",
|
||||
Stderr: "c",
|
||||
// TODO bump this to a larger number such as 100 once
|
||||
// https://github.com/docker/spdystream/issues/55 is fixed and the Godep
|
||||
// is bumped. Sending multiple messages over stdin/stdout/stderr results
|
||||
// in more frames being spread across multiple spdystream frame workers.
|
||||
// This makes it more likely that the spdystream bug will be encountered,
|
||||
// where streams are closed as soon as a goaway frame is received, and
|
||||
// any pending frames that haven't been processed yet may not be
|
||||
// delivered (it's a race).
|
||||
MessageCount: 1,
|
||||
TestName: "in/out/err",
|
||||
Stdin: "a",
|
||||
Stdout: "b",
|
||||
Stderr: "c",
|
||||
MessageCount: 100,
|
||||
ClientProtocols: []string{remotecommand.StreamProtocolV2Name},
|
||||
ServerProtocols: []string{remotecommand.StreamProtocolV2Name},
|
||||
},
|
||||
{
|
||||
Stdin: "a",
|
||||
Stdout: "b",
|
||||
Tty: true,
|
||||
TestName: "in/out/tty",
|
||||
Stdin: "a",
|
||||
Stdout: "b",
|
||||
Tty: true,
|
||||
MessageCount: 100,
|
||||
ClientProtocols: []string{remotecommand.StreamProtocolV2Name},
|
||||
ServerProtocols: []string{remotecommand.StreamProtocolV2Name},
|
||||
},
|
||||
{
|
||||
// 1.0 kubectl, 1.0 kubelet
|
||||
TestName: "unversioned client, unversioned server",
|
||||
Stdout: "b",
|
||||
Stderr: "c",
|
||||
MessageCount: 1,
|
||||
ClientProtocols: []string{},
|
||||
ServerProtocols: []string{},
|
||||
},
|
||||
{
|
||||
// 1.0 kubectl, 1.1+ kubelet
|
||||
TestName: "unversioned client, versioned server",
|
||||
Stdout: "b",
|
||||
Stderr: "c",
|
||||
MessageCount: 1,
|
||||
ClientProtocols: []string{},
|
||||
ServerProtocols: []string{remotecommand.StreamProtocolV2Name, remotecommand.StreamProtocolV1Name},
|
||||
},
|
||||
{
|
||||
// 1.1+ kubectl, 1.0 kubelet
|
||||
TestName: "versioned client, unversioned server",
|
||||
Stdout: "b",
|
||||
Stderr: "c",
|
||||
MessageCount: 1,
|
||||
ClientProtocols: []string{remotecommand.StreamProtocolV2Name, remotecommand.StreamProtocolV1Name},
|
||||
ServerProtocols: []string{},
|
||||
},
|
||||
}
|
||||
|
||||
for i, testCase := range testCases {
|
||||
localOut := &bytes.Buffer{}
|
||||
localErr := &bytes.Buffer{}
|
||||
|
||||
server := httptest.NewServer(fakeExecServer(t, i, testCase.Stdin, testCase.Stdout, testCase.Stderr, testCase.Error, testCase.Tty, testCase.MessageCount))
|
||||
|
||||
url, _ := url.ParseRequestURI(server.URL)
|
||||
c := restclient.NewRESTClient(url, "", restclient.ContentConfig{GroupVersion: &unversioned.GroupVersion{Group: "x"}}, -1, -1, nil)
|
||||
req := c.Post().Resource("testing")
|
||||
req.SetHeader(httpstream.HeaderProtocolVersion, StreamProtocolV2Name)
|
||||
req.Param("command", "ls")
|
||||
req.Param("command", "/")
|
||||
conf := &restclient.Config{
|
||||
Host: server.URL,
|
||||
}
|
||||
e, err := NewExecutor(conf, "POST", req.URL())
|
||||
if err != nil {
|
||||
t.Errorf("%d: unexpected error: %v", i, err)
|
||||
continue
|
||||
}
|
||||
err = e.Stream(strings.NewReader(strings.Repeat(testCase.Stdin, testCase.MessageCount)), localOut, localErr, testCase.Tty)
|
||||
hasErr := err != nil
|
||||
|
||||
if len(testCase.Error) > 0 {
|
||||
if !hasErr {
|
||||
t.Errorf("%d: expected an error", i)
|
||||
for _, testCase := range testCases {
|
||||
for _, exec := range []bool{true, false} {
|
||||
var name string
|
||||
if exec {
|
||||
name = testCase.TestName + " (exec)"
|
||||
} else {
|
||||
if e, a := testCase.Error, err.Error(); !strings.Contains(a, e) {
|
||||
t.Errorf("%d: expected error stream read '%v', got '%v'", i, e, a)
|
||||
name = testCase.TestName + " (attach)"
|
||||
}
|
||||
var (
|
||||
streamIn io.Reader
|
||||
streamOut, streamErr io.Writer
|
||||
)
|
||||
localOut := &bytes.Buffer{}
|
||||
localErr := &bytes.Buffer{}
|
||||
|
||||
server := httptest.NewServer(fakeServer(t, name, exec, testCase.Stdin, testCase.Stdout, testCase.Stderr, testCase.Error, testCase.Tty, testCase.MessageCount, testCase.ServerProtocols))
|
||||
|
||||
url, _ := url.ParseRequestURI(server.URL)
|
||||
c := restclient.NewRESTClient(url, "", restclient.ContentConfig{GroupVersion: &unversioned.GroupVersion{Group: "x"}}, -1, -1, nil)
|
||||
req := c.Post().Resource("testing")
|
||||
|
||||
if exec {
|
||||
req.Param("command", "ls")
|
||||
req.Param("command", "/")
|
||||
}
|
||||
|
||||
if len(testCase.Stdin) > 0 {
|
||||
req.Param(api.ExecStdinParam, "1")
|
||||
streamIn = strings.NewReader(strings.Repeat(testCase.Stdin, testCase.MessageCount))
|
||||
}
|
||||
|
||||
if len(testCase.Stdout) > 0 {
|
||||
req.Param(api.ExecStdoutParam, "1")
|
||||
streamOut = localOut
|
||||
}
|
||||
|
||||
if testCase.Tty {
|
||||
req.Param(api.ExecTTYParam, "1")
|
||||
} else if len(testCase.Stderr) > 0 {
|
||||
req.Param(api.ExecStderrParam, "1")
|
||||
streamErr = localErr
|
||||
}
|
||||
|
||||
conf := &restclient.Config{
|
||||
Host: server.URL,
|
||||
}
|
||||
e, err := NewExecutor(conf, "POST", req.URL())
|
||||
if err != nil {
|
||||
t.Errorf("%s: unexpected error: %v", name, err)
|
||||
continue
|
||||
}
|
||||
err = e.Stream(testCase.ClientProtocols, streamIn, streamOut, streamErr, testCase.Tty)
|
||||
hasErr := err != nil
|
||||
|
||||
if len(testCase.Error) > 0 {
|
||||
if !hasErr {
|
||||
t.Errorf("%s: expected an error", name)
|
||||
} else {
|
||||
if e, a := testCase.Error, err.Error(); !strings.Contains(a, e) {
|
||||
t.Errorf("%s: expected error stream read %q, got %q", name, e, a)
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Uncomment when fix #19254
|
||||
// server.Close()
|
||||
continue
|
||||
}
|
||||
|
||||
if hasErr {
|
||||
t.Errorf("%s: unexpected error: %v", name, err)
|
||||
// TODO: Uncomment when fix #19254
|
||||
// server.Close()
|
||||
continue
|
||||
}
|
||||
|
||||
if len(testCase.Stdout) > 0 {
|
||||
if e, a := strings.Repeat(testCase.Stdout, testCase.MessageCount), localOut; e != a.String() {
|
||||
t.Errorf("%s: expected stdout data '%s', got '%s'", name, e, a)
|
||||
}
|
||||
}
|
||||
|
||||
if testCase.Stderr != "" {
|
||||
if e, a := strings.Repeat(testCase.Stderr, testCase.MessageCount), localErr; e != a.String() {
|
||||
t.Errorf("%s: expected stderr data '%s', got '%s'", name, e, a)
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Uncomment when fix #19254
|
||||
// server.Close()
|
||||
continue
|
||||
}
|
||||
|
||||
if hasErr {
|
||||
t.Errorf("%d: unexpected error: %v", i, err)
|
||||
// TODO: Uncomment when fix #19254
|
||||
// server.Close()
|
||||
continue
|
||||
}
|
||||
|
||||
if len(testCase.Stdout) > 0 {
|
||||
if e, a := strings.Repeat(testCase.Stdout, testCase.MessageCount), localOut; e != a.String() {
|
||||
t.Errorf("%d: expected stdout data '%s', got '%s'", i, e, a)
|
||||
}
|
||||
}
|
||||
|
||||
if testCase.Stderr != "" {
|
||||
if e, a := strings.Repeat(testCase.Stderr, testCase.MessageCount), localErr; e != a.String() {
|
||||
t.Errorf("%d: expected stderr data '%s', got '%s'", i, e, a)
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Uncomment when fix #19254
|
||||
// server.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: this test is largely cut and paste, refactor to share code
|
||||
func TestRequestAttachRemoteCommand(t *testing.T) {
|
||||
testCases := []struct {
|
||||
Stdin string
|
||||
Stdout string
|
||||
Stderr string
|
||||
Error string
|
||||
Tty bool
|
||||
}{
|
||||
{
|
||||
Error: "bail",
|
||||
},
|
||||
{
|
||||
Stdin: "a",
|
||||
Stdout: "b",
|
||||
Stderr: "c",
|
||||
},
|
||||
{
|
||||
Stdin: "a",
|
||||
Stdout: "b",
|
||||
Tty: true,
|
||||
},
|
||||
}
|
||||
|
||||
for i, testCase := range testCases {
|
||||
localOut := &bytes.Buffer{}
|
||||
localErr := &bytes.Buffer{}
|
||||
|
||||
server := httptest.NewServer(fakeExecServer(t, i, testCase.Stdin, testCase.Stdout, testCase.Stderr, testCase.Error, testCase.Tty, 1))
|
||||
|
||||
url, _ := url.ParseRequestURI(server.URL)
|
||||
c := restclient.NewRESTClient(url, "", restclient.ContentConfig{GroupVersion: &unversioned.GroupVersion{Group: "x"}}, -1, -1, nil)
|
||||
req := c.Post().Resource("testing")
|
||||
|
||||
conf := &restclient.Config{
|
||||
Host: server.URL,
|
||||
}
|
||||
e, err := NewExecutor(conf, "POST", req.URL())
|
||||
if err != nil {
|
||||
t.Errorf("%d: unexpected error: %v", i, err)
|
||||
continue
|
||||
}
|
||||
err = e.Stream(strings.NewReader(testCase.Stdin), localOut, localErr, testCase.Tty)
|
||||
hasErr := err != nil
|
||||
|
||||
if len(testCase.Error) > 0 {
|
||||
if !hasErr {
|
||||
t.Errorf("%d: expected an error", i)
|
||||
} else {
|
||||
if e, a := testCase.Error, err.Error(); !strings.Contains(a, e) {
|
||||
t.Errorf("%d: expected error stream read '%v', got '%v'", i, e, a)
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Uncomment when fix #19254
|
||||
// server.Close()
|
||||
continue
|
||||
}
|
||||
|
||||
if hasErr {
|
||||
t.Errorf("%d: unexpected error: %v", i, err)
|
||||
// TODO: Uncomment when fix #19254
|
||||
// server.Close()
|
||||
continue
|
||||
}
|
||||
|
||||
if len(testCase.Stdout) > 0 {
|
||||
if e, a := testCase.Stdout, localOut; e != a.String() {
|
||||
t.Errorf("%d: expected stdout data '%s', got '%s'", i, e, a)
|
||||
}
|
||||
}
|
||||
|
||||
if testCase.Stderr != "" {
|
||||
if e, a := testCase.Stderr, localErr; e != a.String() {
|
||||
t.Errorf("%d: expected stderr data '%s', got '%s'", i, e, a)
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Uncomment when fix #19254
|
||||
// server.Close()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -29,6 +29,7 @@ import (
|
||||
client "k8s.io/kubernetes/pkg/client/unversioned"
|
||||
"k8s.io/kubernetes/pkg/client/unversioned/remotecommand"
|
||||
cmdutil "k8s.io/kubernetes/pkg/kubectl/cmd/util"
|
||||
remotecommandserver "k8s.io/kubernetes/pkg/kubelet/server/remotecommand"
|
||||
utilerrors "k8s.io/kubernetes/pkg/util/errors"
|
||||
"k8s.io/kubernetes/pkg/util/interrupt"
|
||||
"k8s.io/kubernetes/pkg/util/term"
|
||||
@ -87,7 +88,7 @@ func (*DefaultRemoteAttach) Attach(method string, url *url.URL, config *restclie
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return exec.Stream(stdin, stdout, stderr, tty)
|
||||
return exec.Stream(remotecommandserver.SupportedStreamingProtocols, stdin, stdout, stderr, tty)
|
||||
}
|
||||
|
||||
// AttachOptions declare the arguments accepted by the Exec command
|
||||
|
@ -32,6 +32,7 @@ import (
|
||||
client "k8s.io/kubernetes/pkg/client/unversioned"
|
||||
"k8s.io/kubernetes/pkg/client/unversioned/remotecommand"
|
||||
cmdutil "k8s.io/kubernetes/pkg/kubectl/cmd/util"
|
||||
remotecommandserver "k8s.io/kubernetes/pkg/kubelet/server/remotecommand"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -87,7 +88,7 @@ func (*DefaultRemoteExecutor) Execute(method string, url *url.URL, config *restc
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return exec.Stream(stdin, stdout, stderr, tty)
|
||||
return exec.Stream(remotecommandserver.SupportedStreamingProtocols, stdin, stdout, stderr, tty)
|
||||
}
|
||||
|
||||
// ExecOptions declare the arguments accepted by the Exec command
|
||||
|
53
pkg/kubelet/server/remotecommand/attach.go
Normal file
53
pkg/kubelet/server/remotecommand/attach.go
Normal file
@ -0,0 +1,53 @@
|
||||
/*
|
||||
Copyright 2016 The Kubernetes Authors All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package remotecommand
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"k8s.io/kubernetes/pkg/types"
|
||||
"k8s.io/kubernetes/pkg/util/runtime"
|
||||
)
|
||||
|
||||
// Attacher knows how to attach to a running container in a pod.
|
||||
type Attacher interface {
|
||||
// AttachContainer attaches to the running container in the pod, copying data between in/out/err
|
||||
// and the container's stdin/stdout/stderr.
|
||||
AttachContainer(name string, uid types.UID, container string, in io.Reader, out, err io.WriteCloser, tty bool) error
|
||||
}
|
||||
|
||||
// ServeAttach handles requests to attach to a container. After creating/receiving the required
|
||||
// streams, it delegates the actual attaching to attacher.
|
||||
func ServeAttach(w http.ResponseWriter, req *http.Request, attacher Attacher, podName string, uid types.UID, container string, idleTimeout, streamCreationTimeout time.Duration, supportedProtocols []string) {
|
||||
ctx, ok := createStreams(req, w, supportedProtocols, idleTimeout, streamCreationTimeout)
|
||||
if !ok {
|
||||
// error is handled by createStreams
|
||||
return
|
||||
}
|
||||
defer ctx.conn.Close()
|
||||
|
||||
err := attacher.AttachContainer(podName, uid, container, ctx.stdinStream, ctx.stdoutStream, ctx.stderrStream, ctx.tty)
|
||||
if err != nil {
|
||||
msg := fmt.Sprintf("error attaching to container: %v", err)
|
||||
runtime.HandleError(errors.New(msg))
|
||||
fmt.Fprint(ctx.errorStream, msg)
|
||||
}
|
||||
}
|
36
pkg/kubelet/server/remotecommand/contants.go
Normal file
36
pkg/kubelet/server/remotecommand/contants.go
Normal file
@ -0,0 +1,36 @@
|
||||
/*
|
||||
Copyright 2016 The Kubernetes Authors All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package remotecommand
|
||||
|
||||
import "time"
|
||||
|
||||
const (
|
||||
DefaultStreamCreationTimeout = 30 * time.Second
|
||||
|
||||
// The SPDY subprotocol "channel.k8s.io" is used for remote command
|
||||
// attachment/execution. This represents the initial unversioned subprotocol,
|
||||
// which has the known bugs http://issues.k8s.io/13394 and
|
||||
// http://issues.k8s.io/13395.
|
||||
StreamProtocolV1Name = "channel.k8s.io"
|
||||
|
||||
// The SPDY subprotocol "v2.channel.k8s.io" is used for remote command
|
||||
// attachment/execution. It is the second version of the subprotocol and
|
||||
// resolves the issues present in the first version.
|
||||
StreamProtocolV2Name = "v2.channel.k8s.io"
|
||||
)
|
||||
|
||||
var SupportedStreamingProtocols = []string{StreamProtocolV2Name, StreamProtocolV1Name}
|
18
pkg/kubelet/server/remotecommand/doc.go
Normal file
18
pkg/kubelet/server/remotecommand/doc.go
Normal file
@ -0,0 +1,18 @@
|
||||
/*
|
||||
Copyright 2016 The Kubernetes Authors All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
// package remotecommand contains functions related to executing commands in and attaching to pods.
|
||||
package remotecommand
|
57
pkg/kubelet/server/remotecommand/exec.go
Normal file
57
pkg/kubelet/server/remotecommand/exec.go
Normal file
@ -0,0 +1,57 @@
|
||||
/*
|
||||
Copyright 2016 The Kubernetes Authors All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package remotecommand
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"k8s.io/kubernetes/pkg/api"
|
||||
"k8s.io/kubernetes/pkg/types"
|
||||
"k8s.io/kubernetes/pkg/util/runtime"
|
||||
)
|
||||
|
||||
// Executor knows how to execute a command in a container in a pod.
|
||||
type Executor interface {
|
||||
// ExecInContainer executes a command in a container in the pod, copying data
|
||||
// between in/out/err and the container's stdin/stdout/stderr.
|
||||
ExecInContainer(name string, uid types.UID, container string, cmd []string, in io.Reader, out, err io.WriteCloser, tty bool) error
|
||||
}
|
||||
|
||||
// ServeExec handles requests to execute a command in a container. After
|
||||
// creating/receiving the required streams, it delegates the actual execution
|
||||
// to the executor.
|
||||
func ServeExec(w http.ResponseWriter, req *http.Request, executor Executor, podName string, uid types.UID, container string, idleTimeout, streamCreationTimeout time.Duration, supportedProtocols []string) {
|
||||
ctx, ok := createStreams(req, w, supportedProtocols, idleTimeout, streamCreationTimeout)
|
||||
if !ok {
|
||||
// error is handled by createStreams
|
||||
return
|
||||
}
|
||||
defer ctx.conn.Close()
|
||||
|
||||
cmd := req.URL.Query()[api.ExecCommandParamm]
|
||||
|
||||
err := executor.ExecInContainer(podName, uid, container, cmd, ctx.stdinStream, ctx.stdoutStream, ctx.stderrStream, ctx.tty)
|
||||
if err != nil {
|
||||
msg := fmt.Sprintf("error executing command in container: %v", err)
|
||||
runtime.HandleError(errors.New(msg))
|
||||
fmt.Fprint(ctx.errorStream, msg)
|
||||
}
|
||||
}
|
277
pkg/kubelet/server/remotecommand/httpstream.go
Normal file
277
pkg/kubelet/server/remotecommand/httpstream.go
Normal file
@ -0,0 +1,277 @@
|
||||
/*
|
||||
Copyright 2016 The Kubernetes Authors All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package remotecommand
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"k8s.io/kubernetes/pkg/api"
|
||||
"k8s.io/kubernetes/pkg/util/httpstream"
|
||||
"k8s.io/kubernetes/pkg/util/httpstream/spdy"
|
||||
"k8s.io/kubernetes/pkg/util/runtime"
|
||||
"k8s.io/kubernetes/pkg/util/wsstream"
|
||||
|
||||
"github.com/golang/glog"
|
||||
)
|
||||
|
||||
// options contains details about which streams are required for
|
||||
// remote command execution.
|
||||
type options struct {
|
||||
stdin bool
|
||||
stdout bool
|
||||
stderr bool
|
||||
tty bool
|
||||
expectedStreams int
|
||||
}
|
||||
|
||||
// newOptions creates a new options from the Request.
|
||||
func newOptions(req *http.Request) (*options, error) {
|
||||
tty := req.FormValue(api.ExecTTYParam) == "1"
|
||||
stdin := req.FormValue(api.ExecStdinParam) == "1"
|
||||
stdout := req.FormValue(api.ExecStdoutParam) == "1"
|
||||
stderr := req.FormValue(api.ExecStderrParam) == "1"
|
||||
if tty && stderr {
|
||||
// TODO: make this an error before we reach this method
|
||||
glog.V(4).Infof("Access to exec with tty and stderr is not supported, bypassing stderr")
|
||||
stderr = false
|
||||
}
|
||||
|
||||
// count the streams client asked for, starting with 1
|
||||
expectedStreams := 1
|
||||
if stdin {
|
||||
expectedStreams++
|
||||
}
|
||||
if stdout {
|
||||
expectedStreams++
|
||||
}
|
||||
if stderr {
|
||||
expectedStreams++
|
||||
}
|
||||
|
||||
if expectedStreams == 1 {
|
||||
return nil, fmt.Errorf("you must specify at least 1 of stdin, stdout, stderr")
|
||||
}
|
||||
|
||||
return &options{
|
||||
stdin: stdin,
|
||||
stdout: stdout,
|
||||
stderr: stderr,
|
||||
tty: tty,
|
||||
expectedStreams: expectedStreams,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// context contains the connection and streams used when
|
||||
// forwarding an attach or execute session into a container.
|
||||
type context struct {
|
||||
conn io.Closer
|
||||
stdinStream io.ReadCloser
|
||||
stdoutStream io.WriteCloser
|
||||
stderrStream io.WriteCloser
|
||||
errorStream io.WriteCloser
|
||||
tty bool
|
||||
}
|
||||
|
||||
// streamAndReply holds both a Stream and a channel that is closed when the stream's reply frame is
|
||||
// enqueued. Consumers can wait for replySent to be closed prior to proceeding, to ensure that the
|
||||
// replyFrame is enqueued before the connection's goaway frame is sent (e.g. if a stream was
|
||||
// received and right after, the connection gets closed).
|
||||
type streamAndReply struct {
|
||||
httpstream.Stream
|
||||
replySent <-chan struct{}
|
||||
}
|
||||
|
||||
// waitStreamReply waits until either replySent or stop is closed. If replySent is closed, it sends
|
||||
// an empty struct to the notify channel.
|
||||
func waitStreamReply(replySent <-chan struct{}, notify chan<- struct{}, stop <-chan struct{}) {
|
||||
select {
|
||||
case <-replySent:
|
||||
notify <- struct{}{}
|
||||
case <-stop:
|
||||
}
|
||||
}
|
||||
|
||||
func createStreams(req *http.Request, w http.ResponseWriter, supportedStreamProtocols []string, idleTimeout, streamCreationTimeout time.Duration) (*context, bool) {
|
||||
opts, err := newOptions(req)
|
||||
if err != nil {
|
||||
runtime.HandleError(err)
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
fmt.Fprint(w, err.Error())
|
||||
return nil, false
|
||||
}
|
||||
|
||||
if wsstream.IsWebSocketRequest(req) {
|
||||
return createWebSocketStreams(req, w, opts, idleTimeout)
|
||||
}
|
||||
|
||||
protocol, err := httpstream.Handshake(req, w, supportedStreamProtocols)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
fmt.Fprint(w, err.Error())
|
||||
return nil, false
|
||||
}
|
||||
|
||||
streamCh := make(chan streamAndReply)
|
||||
|
||||
upgrader := spdy.NewResponseUpgrader()
|
||||
conn := upgrader.UpgradeResponse(w, req, func(stream httpstream.Stream, replySent <-chan struct{}) error {
|
||||
streamCh <- streamAndReply{Stream: stream, replySent: replySent}
|
||||
return nil
|
||||
})
|
||||
// from this point on, we can no longer call methods on response
|
||||
if conn == nil {
|
||||
// The upgrader is responsible for notifying the client of any errors that
|
||||
// occurred during upgrading. All we can do is return here at this point
|
||||
// if we weren't successful in upgrading.
|
||||
return nil, false
|
||||
}
|
||||
|
||||
conn.SetIdleTimeout(idleTimeout)
|
||||
|
||||
var handler protocolHandler
|
||||
switch protocol {
|
||||
case StreamProtocolV2Name:
|
||||
handler = &v2ProtocolHandler{}
|
||||
case "":
|
||||
glog.V(4).Infof("Client did not request protocol negotiaion. Falling back to %q", StreamProtocolV1Name)
|
||||
fallthrough
|
||||
case StreamProtocolV1Name:
|
||||
handler = &v1ProtocolHandler{}
|
||||
}
|
||||
|
||||
expired := time.NewTimer(streamCreationTimeout)
|
||||
|
||||
ctx, err := handler.waitForStreams(streamCh, opts.expectedStreams, expired.C)
|
||||
if err != nil {
|
||||
runtime.HandleError(err)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
ctx.conn = conn
|
||||
ctx.tty = opts.tty
|
||||
return ctx, true
|
||||
}
|
||||
|
||||
type protocolHandler interface {
|
||||
// waitForStreams waits for the expected streams or a timeout, returning a
|
||||
// remoteCommandContext if all the streams were received, or an error if not.
|
||||
waitForStreams(streams <-chan streamAndReply, expectedStreams int, expired <-chan time.Time) (*context, error)
|
||||
}
|
||||
|
||||
// v2ProtocolHandler implements the V2 protocol version for streaming command execution.
|
||||
type v2ProtocolHandler struct{}
|
||||
|
||||
func (*v2ProtocolHandler) waitForStreams(streams <-chan streamAndReply, expectedStreams int, expired <-chan time.Time) (*context, error) {
|
||||
ctx := &context{}
|
||||
receivedStreams := 0
|
||||
replyChan := make(chan struct{})
|
||||
stop := make(chan struct{})
|
||||
defer close(stop)
|
||||
WaitForStreams:
|
||||
for {
|
||||
select {
|
||||
case stream := <-streams:
|
||||
streamType := stream.Headers().Get(api.StreamType)
|
||||
switch streamType {
|
||||
case api.StreamTypeError:
|
||||
ctx.errorStream = stream
|
||||
go waitStreamReply(stream.replySent, replyChan, stop)
|
||||
case api.StreamTypeStdin:
|
||||
ctx.stdinStream = stream
|
||||
go waitStreamReply(stream.replySent, replyChan, stop)
|
||||
case api.StreamTypeStdout:
|
||||
ctx.stdoutStream = stream
|
||||
go waitStreamReply(stream.replySent, replyChan, stop)
|
||||
case api.StreamTypeStderr:
|
||||
ctx.stderrStream = stream
|
||||
go waitStreamReply(stream.replySent, replyChan, stop)
|
||||
default:
|
||||
runtime.HandleError(fmt.Errorf("Unexpected stream type: %q", streamType))
|
||||
}
|
||||
case <-replyChan:
|
||||
receivedStreams++
|
||||
if receivedStreams == expectedStreams {
|
||||
break WaitForStreams
|
||||
}
|
||||
case <-expired:
|
||||
// TODO find a way to return the error to the user. Maybe use a separate
|
||||
// stream to report errors?
|
||||
return nil, errors.New("timed out waiting for client to create streams")
|
||||
}
|
||||
}
|
||||
|
||||
return ctx, nil
|
||||
}
|
||||
|
||||
// v1ProtocolHandler implements the V1 protocol version for streaming command execution.
|
||||
type v1ProtocolHandler struct{}
|
||||
|
||||
func (*v1ProtocolHandler) waitForStreams(streams <-chan streamAndReply, expectedStreams int, expired <-chan time.Time) (*context, error) {
|
||||
ctx := &context{}
|
||||
receivedStreams := 0
|
||||
replyChan := make(chan struct{})
|
||||
stop := make(chan struct{})
|
||||
defer close(stop)
|
||||
WaitForStreams:
|
||||
for {
|
||||
select {
|
||||
case stream := <-streams:
|
||||
streamType := stream.Headers().Get(api.StreamType)
|
||||
switch streamType {
|
||||
case api.StreamTypeError:
|
||||
ctx.errorStream = stream
|
||||
|
||||
// This defer statement shouldn't be here, but due to previous refactoring, it ended up in
|
||||
// here. This is what 1.0.x kubelets do, so we're retaining that behavior. This is fixed in
|
||||
// the v2ProtocolHandler.
|
||||
defer stream.Reset()
|
||||
|
||||
go waitStreamReply(stream.replySent, replyChan, stop)
|
||||
case api.StreamTypeStdin:
|
||||
ctx.stdinStream = stream
|
||||
go waitStreamReply(stream.replySent, replyChan, stop)
|
||||
case api.StreamTypeStdout:
|
||||
ctx.stdoutStream = stream
|
||||
go waitStreamReply(stream.replySent, replyChan, stop)
|
||||
case api.StreamTypeStderr:
|
||||
ctx.stderrStream = stream
|
||||
go waitStreamReply(stream.replySent, replyChan, stop)
|
||||
default:
|
||||
runtime.HandleError(fmt.Errorf("Unexpected stream type: %q", streamType))
|
||||
}
|
||||
case <-replyChan:
|
||||
receivedStreams++
|
||||
if receivedStreams == expectedStreams {
|
||||
break WaitForStreams
|
||||
}
|
||||
case <-expired:
|
||||
// TODO find a way to return the error to the user. Maybe use a separate
|
||||
// stream to report errors?
|
||||
return nil, errors.New("timed out waiting for client to create streams")
|
||||
}
|
||||
}
|
||||
|
||||
if ctx.stdinStream != nil {
|
||||
ctx.stdinStream.Close()
|
||||
}
|
||||
|
||||
return ctx, nil
|
||||
}
|
77
pkg/kubelet/server/remotecommand/websocket.go
Normal file
77
pkg/kubelet/server/remotecommand/websocket.go
Normal file
@ -0,0 +1,77 @@
|
||||
/*
|
||||
Copyright 2016 The Kubernetes Authors All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package remotecommand
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"k8s.io/kubernetes/pkg/httplog"
|
||||
"k8s.io/kubernetes/pkg/util/wsstream"
|
||||
|
||||
"github.com/golang/glog"
|
||||
)
|
||||
|
||||
// standardShellChannels returns the standard channel types for a shell connection (STDIN 0, STDOUT 1, STDERR 2)
|
||||
// along with the approximate duplex value. Supported subprotocols are "channel.k8s.io" and
|
||||
// "base64.channel.k8s.io".
|
||||
func standardShellChannels(stdin, stdout, stderr bool) []wsstream.ChannelType {
|
||||
// open three half-duplex channels
|
||||
channels := []wsstream.ChannelType{wsstream.ReadChannel, wsstream.WriteChannel, wsstream.WriteChannel}
|
||||
if !stdin {
|
||||
channels[0] = wsstream.IgnoreChannel
|
||||
}
|
||||
if !stdout {
|
||||
channels[1] = wsstream.IgnoreChannel
|
||||
}
|
||||
if !stderr {
|
||||
channels[2] = wsstream.IgnoreChannel
|
||||
}
|
||||
return channels
|
||||
}
|
||||
|
||||
// createWebSocketStreams returns a remoteCommandContext containing the websocket connection and
|
||||
// streams needed to perform an exec or an attach.
|
||||
func createWebSocketStreams(req *http.Request, w http.ResponseWriter, opts *options, idleTimeout time.Duration) (*context, bool) {
|
||||
// open the requested channels, and always open the error channel
|
||||
channels := append(standardShellChannels(opts.stdin, opts.stdout, opts.stderr), wsstream.WriteChannel)
|
||||
conn := wsstream.NewConn(channels...)
|
||||
conn.SetIdleTimeout(idleTimeout)
|
||||
streams, err := conn.Open(httplog.Unlogged(w), req)
|
||||
if err != nil {
|
||||
glog.Errorf("Unable to upgrade websocket connection: %v", err)
|
||||
return nil, false
|
||||
}
|
||||
// Send an empty message to the lowest writable channel to notify the client the connection is established
|
||||
// TODO: make generic to SPDY and WebSockets and do it outside of this method?
|
||||
switch {
|
||||
case opts.stdout:
|
||||
streams[1].Write([]byte{})
|
||||
case opts.stderr:
|
||||
streams[2].Write([]byte{})
|
||||
default:
|
||||
streams[3].Write([]byte{})
|
||||
}
|
||||
return &context{
|
||||
conn: conn,
|
||||
stdinStream: streams[0],
|
||||
stdoutStream: streams[1],
|
||||
stderrStream: streams[2],
|
||||
errorStream: streams[3],
|
||||
tty: opts.tty,
|
||||
}, true
|
||||
}
|
@ -43,12 +43,12 @@ import (
|
||||
"k8s.io/kubernetes/pkg/api/validation"
|
||||
"k8s.io/kubernetes/pkg/auth/authenticator"
|
||||
"k8s.io/kubernetes/pkg/auth/authorizer"
|
||||
"k8s.io/kubernetes/pkg/client/unversioned/remotecommand"
|
||||
"k8s.io/kubernetes/pkg/healthz"
|
||||
"k8s.io/kubernetes/pkg/httplog"
|
||||
"k8s.io/kubernetes/pkg/kubelet/cm"
|
||||
kubecontainer "k8s.io/kubernetes/pkg/kubelet/container"
|
||||
"k8s.io/kubernetes/pkg/kubelet/server/portforward"
|
||||
"k8s.io/kubernetes/pkg/kubelet/server/remotecommand"
|
||||
"k8s.io/kubernetes/pkg/kubelet/server/stats"
|
||||
"k8s.io/kubernetes/pkg/runtime"
|
||||
"k8s.io/kubernetes/pkg/types"
|
||||
@ -58,7 +58,6 @@ import (
|
||||
"k8s.io/kubernetes/pkg/util/httpstream/spdy"
|
||||
"k8s.io/kubernetes/pkg/util/limitwriter"
|
||||
utilruntime "k8s.io/kubernetes/pkg/util/runtime"
|
||||
"k8s.io/kubernetes/pkg/util/wsstream"
|
||||
"k8s.io/kubernetes/pkg/volume"
|
||||
)
|
||||
|
||||
@ -540,12 +539,7 @@ func getContainerCoordinates(request *restful.Request) (namespace, pod string, u
|
||||
return
|
||||
}
|
||||
|
||||
const defaultStreamCreationTimeout = 30 * time.Second
|
||||
|
||||
type Closer interface {
|
||||
Close() error
|
||||
}
|
||||
|
||||
// getAttach handles requests to attach to a container.
|
||||
func (s *Server) getAttach(request *restful.Request, response *restful.Response) {
|
||||
podNamespace, podID, uid, container := getContainerCoordinates(request)
|
||||
pod, ok := s.host.GetPodByName(podNamespace, podID)
|
||||
@ -554,21 +548,35 @@ func (s *Server) getAttach(request *restful.Request, response *restful.Response)
|
||||
return
|
||||
}
|
||||
|
||||
stdinStream, stdoutStream, stderrStream, errorStream, conn, tty, ok := s.createStreams(request, response)
|
||||
if conn != nil {
|
||||
defer conn.Close()
|
||||
}
|
||||
remotecommand.ServeAttach(response.ResponseWriter,
|
||||
request.Request,
|
||||
s.host,
|
||||
kubecontainer.GetPodFullName(pod),
|
||||
uid,
|
||||
container,
|
||||
s.host.StreamingConnectionIdleTimeout(),
|
||||
remotecommand.DefaultStreamCreationTimeout,
|
||||
remotecommand.SupportedStreamingProtocols)
|
||||
}
|
||||
|
||||
// getExec handles requests to run a command inside a container.
|
||||
func (s *Server) getExec(request *restful.Request, response *restful.Response) {
|
||||
podNamespace, podID, uid, container := getContainerCoordinates(request)
|
||||
pod, ok := s.host.GetPodByName(podNamespace, podID)
|
||||
if !ok {
|
||||
// error is handled in the createStreams function
|
||||
response.WriteError(http.StatusNotFound, fmt.Errorf("pod does not exist"))
|
||||
return
|
||||
}
|
||||
|
||||
err := s.host.AttachContainer(kubecontainer.GetPodFullName(pod), uid, container, stdinStream, stdoutStream, stderrStream, tty)
|
||||
if err != nil {
|
||||
msg := fmt.Sprintf("Error executing command in container: %v", err)
|
||||
glog.Error(msg)
|
||||
errorStream.Write([]byte(msg))
|
||||
}
|
||||
remotecommand.ServeExec(response.ResponseWriter,
|
||||
request.Request,
|
||||
s.host,
|
||||
kubecontainer.GetPodFullName(pod),
|
||||
uid,
|
||||
container,
|
||||
s.host.StreamingConnectionIdleTimeout(),
|
||||
remotecommand.DefaultStreamCreationTimeout,
|
||||
remotecommand.SupportedStreamingProtocols)
|
||||
}
|
||||
|
||||
// getRun handles requests to run a command inside a container.
|
||||
@ -588,187 +596,6 @@ func (s *Server) getRun(request *restful.Request, response *restful.Response) {
|
||||
writeJsonResponse(response, data)
|
||||
}
|
||||
|
||||
// getExec handles requests to run a command inside a container.
|
||||
func (s *Server) getExec(request *restful.Request, response *restful.Response) {
|
||||
podNamespace, podID, uid, container := getContainerCoordinates(request)
|
||||
pod, ok := s.host.GetPodByName(podNamespace, podID)
|
||||
if !ok {
|
||||
response.WriteError(http.StatusNotFound, fmt.Errorf("pod does not exist"))
|
||||
return
|
||||
}
|
||||
stdinStream, stdoutStream, stderrStream, errorStream, conn, tty, ok := s.createStreams(request, response)
|
||||
if conn != nil {
|
||||
defer conn.Close()
|
||||
}
|
||||
if !ok {
|
||||
// error is handled in the createStreams function
|
||||
return
|
||||
}
|
||||
cmd := request.Request.URL.Query()[api.ExecCommandParamm]
|
||||
err := s.host.ExecInContainer(kubecontainer.GetPodFullName(pod), uid, container, cmd, stdinStream, stdoutStream, stderrStream, tty)
|
||||
if err != nil {
|
||||
msg := fmt.Sprintf("Error executing command in container: %v", err)
|
||||
glog.Error(msg)
|
||||
errorStream.Write([]byte(msg))
|
||||
}
|
||||
}
|
||||
|
||||
// standardShellChannels returns the standard channel types for a shell connection (STDIN 0, STDOUT 1, STDERR 2)
|
||||
// along with the approprxate duplex value
|
||||
func standardShellChannels(stdin, stdout, stderr bool) []wsstream.ChannelType {
|
||||
// open three half-duplex channels
|
||||
channels := []wsstream.ChannelType{wsstream.ReadChannel, wsstream.WriteChannel, wsstream.WriteChannel}
|
||||
if !stdin {
|
||||
channels[0] = wsstream.IgnoreChannel
|
||||
}
|
||||
if !stdout {
|
||||
channels[1] = wsstream.IgnoreChannel
|
||||
}
|
||||
if !stderr {
|
||||
channels[2] = wsstream.IgnoreChannel
|
||||
}
|
||||
return channels
|
||||
}
|
||||
|
||||
// streamAndReply holds both a Stream and a channel that is closed when the stream's reply frame is
|
||||
// enqueued. Consumers can wait for replySent to be closed prior to proceeding, to ensure that the
|
||||
// replyFrame is enqueued before the connection's goaway frame is sent (e.g. if a stream was
|
||||
// received and right after, the connection gets closed).
|
||||
type streamAndReply struct {
|
||||
httpstream.Stream
|
||||
replySent <-chan struct{}
|
||||
}
|
||||
|
||||
func (s *Server) createStreams(request *restful.Request, response *restful.Response) (io.Reader, io.WriteCloser, io.WriteCloser, io.WriteCloser, Closer, bool, bool) {
|
||||
tty := request.QueryParameter(api.ExecTTYParam) == "1"
|
||||
stdin := request.QueryParameter(api.ExecStdinParam) == "1"
|
||||
stdout := request.QueryParameter(api.ExecStdoutParam) == "1"
|
||||
stderr := request.QueryParameter(api.ExecStderrParam) == "1"
|
||||
if tty && stderr {
|
||||
// TODO: make this an error before we reach this method
|
||||
glog.V(4).Infof("Access to exec with tty and stderr is not supported, bypassing stderr")
|
||||
stderr = false
|
||||
}
|
||||
|
||||
// count the streams client asked for, starting with 1
|
||||
expectedStreams := 1
|
||||
if stdin {
|
||||
expectedStreams++
|
||||
}
|
||||
if stdout {
|
||||
expectedStreams++
|
||||
}
|
||||
if stderr {
|
||||
expectedStreams++
|
||||
}
|
||||
|
||||
if expectedStreams == 1 {
|
||||
response.WriteError(http.StatusBadRequest, fmt.Errorf("you must specify at least 1 of stdin, stdout, stderr"))
|
||||
return nil, nil, nil, nil, nil, false, false
|
||||
}
|
||||
|
||||
if wsstream.IsWebSocketRequest(request.Request) {
|
||||
// open the requested channels, and always open the error channel
|
||||
channels := append(standardShellChannels(stdin, stdout, stderr), wsstream.WriteChannel)
|
||||
conn := wsstream.NewConn(channels...)
|
||||
conn.SetIdleTimeout(s.host.StreamingConnectionIdleTimeout())
|
||||
streams, err := conn.Open(httplog.Unlogged(response.ResponseWriter), request.Request)
|
||||
if err != nil {
|
||||
glog.Errorf("Unable to upgrade websocket connection: %v", err)
|
||||
return nil, nil, nil, nil, nil, false, false
|
||||
}
|
||||
// Send an empty message to the lowest writable channel to notify the client the connection is established
|
||||
// TODO: make generic to SDPY and WebSockets and do it outside of this method?
|
||||
switch {
|
||||
case stdout:
|
||||
streams[1].Write([]byte{})
|
||||
case stderr:
|
||||
streams[2].Write([]byte{})
|
||||
default:
|
||||
streams[3].Write([]byte{})
|
||||
}
|
||||
return streams[0], streams[1], streams[2], streams[3], conn, tty, true
|
||||
}
|
||||
|
||||
supportedStreamProtocols := []string{remotecommand.StreamProtocolV2Name, remotecommand.StreamProtocolV1Name}
|
||||
_, err := httpstream.Handshake(request.Request, response.ResponseWriter, supportedStreamProtocols, remotecommand.StreamProtocolV1Name)
|
||||
// negotiated protocol isn't used server side at the moment, but could be in the future
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, nil, false, false
|
||||
}
|
||||
|
||||
streamCh := make(chan streamAndReply)
|
||||
|
||||
upgrader := spdy.NewResponseUpgrader()
|
||||
conn := upgrader.UpgradeResponse(response.ResponseWriter, request.Request, func(stream httpstream.Stream, replySent <-chan struct{}) error {
|
||||
streamCh <- streamAndReply{Stream: stream, replySent: replySent}
|
||||
return nil
|
||||
})
|
||||
// from this point on, we can no longer call methods on response
|
||||
if conn == nil {
|
||||
// The upgrader is responsible for notifying the client of any errors that
|
||||
// occurred during upgrading. All we can do is return here at this point
|
||||
// if we weren't successful in upgrading.
|
||||
return nil, nil, nil, nil, nil, false, false
|
||||
}
|
||||
|
||||
conn.SetIdleTimeout(s.host.StreamingConnectionIdleTimeout())
|
||||
|
||||
// TODO make it configurable?
|
||||
expired := time.NewTimer(defaultStreamCreationTimeout)
|
||||
|
||||
var errorStream, stdinStream, stdoutStream, stderrStream httpstream.Stream
|
||||
receivedStreams := 0
|
||||
replyChan := make(chan struct{})
|
||||
stop := make(chan struct{})
|
||||
defer close(stop)
|
||||
WaitForStreams:
|
||||
for {
|
||||
select {
|
||||
case stream := <-streamCh:
|
||||
streamType := stream.Headers().Get(api.StreamType)
|
||||
switch streamType {
|
||||
case api.StreamTypeError:
|
||||
errorStream = stream
|
||||
go waitStreamReply(stream.replySent, replyChan, stop)
|
||||
case api.StreamTypeStdin:
|
||||
stdinStream = stream
|
||||
go waitStreamReply(stream.replySent, replyChan, stop)
|
||||
case api.StreamTypeStdout:
|
||||
stdoutStream = stream
|
||||
go waitStreamReply(stream.replySent, replyChan, stop)
|
||||
case api.StreamTypeStderr:
|
||||
stderrStream = stream
|
||||
go waitStreamReply(stream.replySent, replyChan, stop)
|
||||
default:
|
||||
glog.Errorf("Unexpected stream type: '%s'", streamType)
|
||||
}
|
||||
case <-replyChan:
|
||||
receivedStreams++
|
||||
if receivedStreams == expectedStreams {
|
||||
break WaitForStreams
|
||||
}
|
||||
case <-expired.C:
|
||||
// TODO find a way to return the error to the user. Maybe use a separate
|
||||
// stream to report errors?
|
||||
glog.Error("Timed out waiting for client to create streams")
|
||||
return nil, nil, nil, nil, nil, false, false
|
||||
}
|
||||
}
|
||||
|
||||
return stdinStream, stdoutStream, stderrStream, errorStream, conn, tty, true
|
||||
}
|
||||
|
||||
// waitStreamReply waits until either replySent or stop is closed. If replySent is closed, it sends
|
||||
// an empty struct to the notify channel.
|
||||
func waitStreamReply(replySent <-chan struct{}, notify chan<- struct{}, stop <-chan struct{}) {
|
||||
select {
|
||||
case <-replySent:
|
||||
notify <- struct{}{}
|
||||
case <-stop:
|
||||
}
|
||||
}
|
||||
|
||||
func getPodCoordinates(request *restful.Request) (namespace, pod string, uid types.UID) {
|
||||
namespace = request.PathParameter("podNamespace")
|
||||
pod = request.PathParameter("podID")
|
||||
@ -811,7 +638,7 @@ func (s *Server) getPortForward(request *restful.Request, response *restful.Resp
|
||||
|
||||
podName := kubecontainer.GetPodFullName(pod)
|
||||
|
||||
ServePortForward(response.ResponseWriter, request.Request, s.host, podName, uid, s.host.StreamingConnectionIdleTimeout(), defaultStreamCreationTimeout)
|
||||
ServePortForward(response.ResponseWriter, request.Request, s.host, podName, uid, s.host.StreamingConnectionIdleTimeout(), remotecommand.DefaultStreamCreationTimeout)
|
||||
}
|
||||
|
||||
// ServePortForward handles a port forwarding request. A single request is
|
||||
@ -821,7 +648,7 @@ func (s *Server) getPortForward(request *restful.Request, response *restful.Resp
|
||||
// handled by a single invocation of ServePortForward.
|
||||
func ServePortForward(w http.ResponseWriter, req *http.Request, portForwarder PortForwarder, podName string, uid types.UID, idleTimeout time.Duration, streamCreationTimeout time.Duration) {
|
||||
supportedPortForwardProtocols := []string{portforward.PortForwardProtocolV1Name}
|
||||
_, err := httpstream.Handshake(req, w, supportedPortForwardProtocols, portforward.PortForwardProtocolV1Name)
|
||||
_, err := httpstream.Handshake(req, w, supportedPortForwardProtocols)
|
||||
// negotiated protocol isn't currently used server side, but could be in the future
|
||||
if err != nil {
|
||||
// Handshake writes the error to the client
|
||||
|
@ -1019,7 +1019,7 @@ func TestServeExecInContainerIdleTimeout(t *testing.T) {
|
||||
<-conn.CloseChan()
|
||||
}
|
||||
|
||||
func TestServeExecInContainer(t *testing.T) {
|
||||
func testExecAttach(t *testing.T, verb string) {
|
||||
tests := []struct {
|
||||
stdin bool
|
||||
stdout bool
|
||||
@ -1053,12 +1053,15 @@ func TestServeExecInContainer(t *testing.T) {
|
||||
expectedStdin := "stdin"
|
||||
expectedStdout := "stdout"
|
||||
expectedStderr := "stderr"
|
||||
execFuncDone := make(chan struct{})
|
||||
done := make(chan struct{})
|
||||
clientStdoutReadDone := make(chan struct{})
|
||||
clientStderrReadDone := make(chan struct{})
|
||||
execInvoked := false
|
||||
attachInvoked := false
|
||||
|
||||
testStreamFunc := func(podFullName string, uid types.UID, containerName string, cmd []string, in io.Reader, out, stderr io.WriteCloser, tty bool, done chan struct{}) error {
|
||||
defer close(done)
|
||||
|
||||
fw.fakeKubelet.execFunc = func(podFullName string, uid types.UID, containerName string, cmd []string, in io.Reader, out, stderr io.WriteCloser, tty bool) error {
|
||||
defer close(execFuncDone)
|
||||
if podFullName != expectedPodName {
|
||||
t.Fatalf("%d: podFullName: expected %s, got %s", i, expectedPodName, podFullName)
|
||||
}
|
||||
@ -1068,66 +1071,79 @@ func TestServeExecInContainer(t *testing.T) {
|
||||
if containerName != expectedContainerName {
|
||||
t.Fatalf("%d: containerName: expected %s, got %s", i, expectedContainerName, containerName)
|
||||
}
|
||||
|
||||
if test.stdin {
|
||||
if in == nil {
|
||||
t.Fatalf("%d: stdin: expected non-nil", i)
|
||||
}
|
||||
b := make([]byte, 10)
|
||||
n, err := in.Read(b)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: error reading from stdin: %v", i, err)
|
||||
}
|
||||
if e, a := expectedStdin, string(b[0:n]); e != a {
|
||||
t.Fatalf("%d: stdin: expected to read %v, got %v", i, e, a)
|
||||
}
|
||||
} else if in != nil {
|
||||
t.Fatalf("%d: stdin: expected nil: %#v", i, in)
|
||||
}
|
||||
|
||||
if test.stdout {
|
||||
if out == nil {
|
||||
t.Fatalf("%d: stdout: expected non-nil", i)
|
||||
}
|
||||
_, err := out.Write([]byte(expectedStdout))
|
||||
if err != nil {
|
||||
t.Fatalf("%d:, error writing to stdout: %v", i, err)
|
||||
}
|
||||
out.Close()
|
||||
<-clientStdoutReadDone
|
||||
} else if out != nil {
|
||||
t.Fatalf("%d: stdout: expected nil: %#v", i, out)
|
||||
}
|
||||
|
||||
if tty {
|
||||
if stderr != nil {
|
||||
t.Fatalf("%d: tty set but received non-nil stderr: %v", i, stderr)
|
||||
}
|
||||
} else if test.stderr {
|
||||
if stderr == nil {
|
||||
t.Fatalf("%d: stderr: expected non-nil", i)
|
||||
}
|
||||
_, err := stderr.Write([]byte(expectedStderr))
|
||||
if err != nil {
|
||||
t.Fatalf("%d:, error writing to stderr: %v", i, err)
|
||||
}
|
||||
stderr.Close()
|
||||
<-clientStderrReadDone
|
||||
} else if stderr != nil {
|
||||
t.Fatalf("%d: stderr: expected nil: %#v", i, stderr)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
fw.fakeKubelet.execFunc = func(podFullName string, uid types.UID, containerName string, cmd []string, in io.Reader, out, stderr io.WriteCloser, tty bool) error {
|
||||
execInvoked = true
|
||||
if strings.Join(cmd, " ") != expectedCommand {
|
||||
t.Fatalf("%d: cmd: expected: %s, got %v", i, expectedCommand, cmd)
|
||||
}
|
||||
return testStreamFunc(podFullName, uid, containerName, cmd, in, out, stderr, tty, done)
|
||||
}
|
||||
|
||||
if test.stdin {
|
||||
if in == nil {
|
||||
t.Fatalf("%d: stdin: expected non-nil", i)
|
||||
}
|
||||
b := make([]byte, 10)
|
||||
n, err := in.Read(b)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: error reading from stdin: %v", i, err)
|
||||
}
|
||||
if e, a := expectedStdin, string(b[0:n]); e != a {
|
||||
t.Fatalf("%d: stdin: expected to read %v, got %v", i, e, a)
|
||||
}
|
||||
} else if in != nil {
|
||||
t.Fatalf("%d: stdin: expected nil: %#v", i, in)
|
||||
}
|
||||
|
||||
if test.stdout {
|
||||
if out == nil {
|
||||
t.Fatalf("%d: stdout: expected non-nil", i)
|
||||
}
|
||||
_, err := out.Write([]byte(expectedStdout))
|
||||
if err != nil {
|
||||
t.Fatalf("%d:, error writing to stdout: %v", i, err)
|
||||
}
|
||||
out.Close()
|
||||
<-clientStdoutReadDone
|
||||
} else if out != nil {
|
||||
t.Fatalf("%d: stdout: expected nil: %#v", i, out)
|
||||
}
|
||||
|
||||
if tty {
|
||||
if stderr != nil {
|
||||
t.Fatalf("%d: tty set but received non-nil stderr: %v", i, stderr)
|
||||
}
|
||||
} else if test.stderr {
|
||||
if stderr == nil {
|
||||
t.Fatalf("%d: stderr: expected non-nil", i)
|
||||
}
|
||||
_, err := stderr.Write([]byte(expectedStderr))
|
||||
if err != nil {
|
||||
t.Fatalf("%d:, error writing to stderr: %v", i, err)
|
||||
}
|
||||
stderr.Close()
|
||||
<-clientStderrReadDone
|
||||
} else if stderr != nil {
|
||||
t.Fatalf("%d: stderr: expected nil: %#v", i, stderr)
|
||||
}
|
||||
|
||||
return nil
|
||||
fw.fakeKubelet.attachFunc = func(podFullName string, uid types.UID, containerName string, in io.Reader, out, stderr io.WriteCloser, tty bool) error {
|
||||
attachInvoked = true
|
||||
return testStreamFunc(podFullName, uid, containerName, nil, in, out, stderr, tty, done)
|
||||
}
|
||||
|
||||
var url string
|
||||
if test.uid {
|
||||
url = fw.testHTTPServer.URL + "/exec/" + podNamespace + "/" + podName + "/" + expectedUid + "/" + expectedContainerName + "?command=ls&command=-a"
|
||||
url = fw.testHTTPServer.URL + "/" + verb + "/" + podNamespace + "/" + podName + "/" + expectedUid + "/" + expectedContainerName + "?ignore=1"
|
||||
} else {
|
||||
url = fw.testHTTPServer.URL + "/exec/" + podNamespace + "/" + podName + "/" + expectedContainerName + "?command=ls&command=-a"
|
||||
url = fw.testHTTPServer.URL + "/" + verb + "/" + podNamespace + "/" + podName + "/" + expectedContainerName + "?ignore=1"
|
||||
}
|
||||
if verb == "exec" {
|
||||
url += "&command=ls&command=-a"
|
||||
}
|
||||
if test.stdin {
|
||||
url += "&" + api.ExecStdinParam + "=1"
|
||||
@ -1186,11 +1202,9 @@ func TestServeExecInContainer(t *testing.T) {
|
||||
|
||||
h := http.Header{}
|
||||
h.Set(api.StreamType, api.StreamTypeError)
|
||||
errorStream, err := conn.CreateStream(h)
|
||||
if err != nil {
|
||||
if _, err := conn.CreateStream(h); err != nil {
|
||||
t.Fatalf("%d: error creating error stream: %v", i, err)
|
||||
}
|
||||
defer errorStream.Reset()
|
||||
|
||||
if test.stdin {
|
||||
h.Set(api.StreamType, api.StreamTypeStdin)
|
||||
@ -1198,7 +1212,6 @@ func TestServeExecInContainer(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("%d: error creating stdin stream: %v", i, err)
|
||||
}
|
||||
defer stream.Reset()
|
||||
_, err = stream.Write([]byte(expectedStdin))
|
||||
if err != nil {
|
||||
t.Fatalf("%d: error writing to stdin stream: %v", i, err)
|
||||
@ -1212,7 +1225,6 @@ func TestServeExecInContainer(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("%d: error creating stdout stream: %v", i, err)
|
||||
}
|
||||
defer stdoutStream.Reset()
|
||||
}
|
||||
|
||||
var stderrStream httpstream.Stream
|
||||
@ -1222,7 +1234,6 @@ func TestServeExecInContainer(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("%d: error creating stderr stream: %v", i, err)
|
||||
}
|
||||
defer stderrStream.Reset()
|
||||
}
|
||||
|
||||
if test.stdout {
|
||||
@ -1249,239 +1260,33 @@ func TestServeExecInContainer(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
<-execFuncDone
|
||||
// wait for the server to finish before checking if the attach/exec funcs were invoked
|
||||
<-done
|
||||
|
||||
if verb == "exec" {
|
||||
if !execInvoked {
|
||||
t.Errorf("%d: exec was not invoked", i)
|
||||
}
|
||||
if attachInvoked {
|
||||
t.Errorf("%d: attach should not have been invoked", i)
|
||||
}
|
||||
} else {
|
||||
if !attachInvoked {
|
||||
t.Errorf("%d: attach was not invoked", i)
|
||||
}
|
||||
if execInvoked {
|
||||
t.Errorf("%d: exec should not have been invoked", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: largely cloned from TestServeExecContainer, refactor and re-use code
|
||||
func TestServeExecInContainer(t *testing.T) {
|
||||
testExecAttach(t, "exec")
|
||||
}
|
||||
|
||||
func TestServeAttachContainer(t *testing.T) {
|
||||
tests := []struct {
|
||||
stdin bool
|
||||
stdout bool
|
||||
stderr bool
|
||||
tty bool
|
||||
responseStatusCode int
|
||||
uid bool
|
||||
}{
|
||||
{responseStatusCode: http.StatusBadRequest},
|
||||
{stdin: true, responseStatusCode: http.StatusSwitchingProtocols},
|
||||
{stdout: true, responseStatusCode: http.StatusSwitchingProtocols},
|
||||
{stderr: true, responseStatusCode: http.StatusSwitchingProtocols},
|
||||
{stdout: true, stderr: true, responseStatusCode: http.StatusSwitchingProtocols},
|
||||
{stdout: true, stderr: true, tty: true, responseStatusCode: http.StatusSwitchingProtocols},
|
||||
{stdin: true, stdout: true, stderr: true, responseStatusCode: http.StatusSwitchingProtocols},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
fw := newServerTest()
|
||||
|
||||
fw.fakeKubelet.streamingConnectionIdleTimeoutFunc = func() time.Duration {
|
||||
return 0
|
||||
}
|
||||
|
||||
podNamespace := "other"
|
||||
podName := "foo"
|
||||
expectedPodName := getPodName(podName, podNamespace)
|
||||
expectedUid := "9b01b80f-8fb4-11e4-95ab-4200af06647"
|
||||
expectedContainerName := "baz"
|
||||
expectedStdin := "stdin"
|
||||
expectedStdout := "stdout"
|
||||
expectedStderr := "stderr"
|
||||
attachFuncDone := make(chan struct{})
|
||||
clientStdoutReadDone := make(chan struct{})
|
||||
clientStderrReadDone := make(chan struct{})
|
||||
|
||||
fw.fakeKubelet.attachFunc = func(podFullName string, uid types.UID, containerName string, in io.Reader, out, stderr io.WriteCloser, tty bool) error {
|
||||
defer close(attachFuncDone)
|
||||
if podFullName != expectedPodName {
|
||||
t.Fatalf("%d: podFullName: expected %s, got %s", i, expectedPodName, podFullName)
|
||||
}
|
||||
if test.uid && string(uid) != expectedUid {
|
||||
t.Fatalf("%d: uid: expected %v, got %v", i, expectedUid, uid)
|
||||
}
|
||||
if containerName != expectedContainerName {
|
||||
t.Fatalf("%d: containerName: expected %s, got %s", i, expectedContainerName, containerName)
|
||||
}
|
||||
|
||||
if test.stdin {
|
||||
if in == nil {
|
||||
t.Fatalf("%d: stdin: expected non-nil", i)
|
||||
}
|
||||
b := make([]byte, 10)
|
||||
n, err := in.Read(b)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: error reading from stdin: %v", i, err)
|
||||
}
|
||||
if e, a := expectedStdin, string(b[0:n]); e != a {
|
||||
t.Fatalf("%d: stdin: expected to read %v, got %v", i, e, a)
|
||||
}
|
||||
} else if in != nil {
|
||||
t.Fatalf("%d: stdin: expected nil: %#v", i, in)
|
||||
}
|
||||
|
||||
if test.stdout {
|
||||
if out == nil {
|
||||
t.Fatalf("%d: stdout: expected non-nil", i)
|
||||
}
|
||||
_, err := out.Write([]byte(expectedStdout))
|
||||
if err != nil {
|
||||
t.Fatalf("%d:, error writing to stdout: %v", i, err)
|
||||
}
|
||||
out.Close()
|
||||
<-clientStdoutReadDone
|
||||
} else if out != nil {
|
||||
t.Fatalf("%d: stdout: expected nil: %#v", i, out)
|
||||
}
|
||||
|
||||
if tty {
|
||||
if stderr != nil {
|
||||
t.Fatalf("%d: tty set but received non-nil stderr: %v", i, stderr)
|
||||
}
|
||||
} else if test.stderr {
|
||||
if stderr == nil {
|
||||
t.Fatalf("%d: stderr: expected non-nil", i)
|
||||
}
|
||||
_, err := stderr.Write([]byte(expectedStderr))
|
||||
if err != nil {
|
||||
t.Fatalf("%d:, error writing to stderr: %v", i, err)
|
||||
}
|
||||
stderr.Close()
|
||||
<-clientStderrReadDone
|
||||
} else if stderr != nil {
|
||||
t.Fatalf("%d: stderr: expected nil: %#v", i, stderr)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var url string
|
||||
if test.uid {
|
||||
url = fw.testHTTPServer.URL + "/attach/" + podNamespace + "/" + podName + "/" + expectedUid + "/" + expectedContainerName + "?"
|
||||
} else {
|
||||
url = fw.testHTTPServer.URL + "/attach/" + podNamespace + "/" + podName + "/" + expectedContainerName + "?"
|
||||
}
|
||||
if test.stdin {
|
||||
url += "&" + api.ExecStdinParam + "=1"
|
||||
}
|
||||
if test.stdout {
|
||||
url += "&" + api.ExecStdoutParam + "=1"
|
||||
}
|
||||
if test.stderr && !test.tty {
|
||||
url += "&" + api.ExecStderrParam + "=1"
|
||||
}
|
||||
if test.tty {
|
||||
url += "&" + api.ExecTTYParam + "=1"
|
||||
}
|
||||
|
||||
var (
|
||||
resp *http.Response
|
||||
err error
|
||||
upgradeRoundTripper httpstream.UpgradeRoundTripper
|
||||
c *http.Client
|
||||
)
|
||||
|
||||
if test.responseStatusCode != http.StatusSwitchingProtocols {
|
||||
c = &http.Client{}
|
||||
} else {
|
||||
upgradeRoundTripper = spdy.NewRoundTripper(nil)
|
||||
c = &http.Client{Transport: upgradeRoundTripper}
|
||||
}
|
||||
|
||||
resp, err = c.Post(url, "", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: Got error POSTing: %v", i, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
_, err = ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Errorf("%d: Error reading response body: %v", i, err)
|
||||
}
|
||||
|
||||
if e, a := test.responseStatusCode, resp.StatusCode; e != a {
|
||||
t.Fatalf("%d: response status: expected %v, got %v", i, e, a)
|
||||
}
|
||||
|
||||
if test.responseStatusCode != http.StatusSwitchingProtocols {
|
||||
continue
|
||||
}
|
||||
|
||||
conn, err := upgradeRoundTripper.NewConnection(resp)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error creating streaming connection: %s", err)
|
||||
}
|
||||
if conn == nil {
|
||||
t.Fatalf("%d: unexpected nil conn", i)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
h := http.Header{}
|
||||
h.Set(api.StreamType, api.StreamTypeError)
|
||||
errorStream, err := conn.CreateStream(h)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: error creating error stream: %v", i, err)
|
||||
}
|
||||
defer errorStream.Reset()
|
||||
|
||||
if test.stdin {
|
||||
h.Set(api.StreamType, api.StreamTypeStdin)
|
||||
stream, err := conn.CreateStream(h)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: error creating stdin stream: %v", i, err)
|
||||
}
|
||||
defer stream.Reset()
|
||||
_, err = stream.Write([]byte(expectedStdin))
|
||||
if err != nil {
|
||||
t.Fatalf("%d: error writing to stdin stream: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
var stdoutStream httpstream.Stream
|
||||
if test.stdout {
|
||||
h.Set(api.StreamType, api.StreamTypeStdout)
|
||||
stdoutStream, err = conn.CreateStream(h)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: error creating stdout stream: %v", i, err)
|
||||
}
|
||||
defer stdoutStream.Reset()
|
||||
}
|
||||
|
||||
var stderrStream httpstream.Stream
|
||||
if test.stderr && !test.tty {
|
||||
h.Set(api.StreamType, api.StreamTypeStderr)
|
||||
stderrStream, err = conn.CreateStream(h)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: error creating stderr stream: %v", i, err)
|
||||
}
|
||||
defer stderrStream.Reset()
|
||||
}
|
||||
|
||||
if test.stdout {
|
||||
output := make([]byte, 10)
|
||||
n, err := stdoutStream.Read(output)
|
||||
close(clientStdoutReadDone)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: error reading from stdout stream: %v", i, err)
|
||||
}
|
||||
if e, a := expectedStdout, string(output[0:n]); e != a {
|
||||
t.Fatalf("%d: stdout: expected '%v', got '%v'", i, e, a)
|
||||
}
|
||||
}
|
||||
|
||||
if test.stderr && !test.tty {
|
||||
output := make([]byte, 10)
|
||||
n, err := stderrStream.Read(output)
|
||||
close(clientStderrReadDone)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: error reading from stderr stream: %v", i, err)
|
||||
}
|
||||
if e, a := expectedStderr, string(output[0:n]); e != a {
|
||||
t.Fatalf("%d: stderr: expected '%v', got '%v'", i, e, a)
|
||||
}
|
||||
}
|
||||
|
||||
<-attachFuncDone
|
||||
}
|
||||
testExecAttach(t, "attach")
|
||||
}
|
||||
|
||||
func TestServePortForwardIdleTimeout(t *testing.T) {
|
||||
|
@ -114,20 +114,24 @@ func negotiateProtocol(clientProtocols, serverProtocols []string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Handshake performs a subprotocol negotiation. If the client did not request
|
||||
// a specific subprotocol, defaultProtocol is used. If the client did request a
|
||||
// Handshake performs a subprotocol negotiation. If the client did request a
|
||||
// subprotocol, Handshake will select the first common value found in
|
||||
// serverProtocols. If a match is found, Handshake adds a response header
|
||||
// indicating the chosen subprotocol. If no match is found, HTTP forbidden is
|
||||
// returned, along with a response header containing the list of protocols the
|
||||
// server can accept.
|
||||
func Handshake(req *http.Request, w http.ResponseWriter, serverProtocols []string, defaultProtocol string) (string, error) {
|
||||
func Handshake(req *http.Request, w http.ResponseWriter, serverProtocols []string) (string, error) {
|
||||
clientProtocols := req.Header[http.CanonicalHeaderKey(HeaderProtocolVersion)]
|
||||
if len(clientProtocols) == 0 {
|
||||
// Kube 1.0 client that didn't support subprotocol negotiation
|
||||
// TODO remove this defaulting logic once Kube 1.0 is no longer supported
|
||||
w.Header().Add(HeaderProtocolVersion, defaultProtocol)
|
||||
return defaultProtocol, nil
|
||||
// Kube 1.0 clients didn't support subprotocol negotiation.
|
||||
// TODO require clientProtocols once Kube 1.0 is no longer supported
|
||||
return "", nil
|
||||
}
|
||||
|
||||
if len(serverProtocols) == 0 {
|
||||
// Kube 1.0 servers didn't support subprotocol negotiation. This is mainly for testing.
|
||||
// TODO require serverProtocols once Kube 1.0 is no longer supported
|
||||
return "", nil
|
||||
}
|
||||
|
||||
negotiatedProtocol := negotiateProtocol(clientProtocols, serverProtocols)
|
||||
|
@ -20,6 +20,8 @@ import (
|
||||
"net/http"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"k8s.io/kubernetes/pkg/api"
|
||||
)
|
||||
|
||||
type responseWriter struct {
|
||||
@ -46,8 +48,6 @@ func (r *responseWriter) Write([]byte) (int, error) {
|
||||
}
|
||||
|
||||
func TestHandshake(t *testing.T) {
|
||||
defaultProtocol := "default"
|
||||
|
||||
tests := map[string]struct {
|
||||
clientProtocols []string
|
||||
serverProtocols []string
|
||||
@ -57,7 +57,7 @@ func TestHandshake(t *testing.T) {
|
||||
"no client protocols": {
|
||||
clientProtocols: []string{},
|
||||
serverProtocols: []string{"a", "b"},
|
||||
expectedProtocol: defaultProtocol,
|
||||
expectedProtocol: "",
|
||||
},
|
||||
"no common protocol": {
|
||||
clientProtocols: []string{"c"},
|
||||
@ -83,7 +83,7 @@ func TestHandshake(t *testing.T) {
|
||||
}
|
||||
|
||||
w := newResponseWriter()
|
||||
negotiated, err := Handshake(req, w, test.serverProtocols, defaultProtocol)
|
||||
negotiated, err := Handshake(req, w, test.serverProtocols)
|
||||
|
||||
// verify negotiated protocol
|
||||
if e, a := test.expectedProtocol, negotiated; e != a {
|
||||
@ -112,8 +112,15 @@ func TestHandshake(t *testing.T) {
|
||||
t.Errorf("%s: unexpected non-nil w.statusCode: %d", w.statusCode)
|
||||
}
|
||||
|
||||
if len(test.expectedProtocol) == 0 {
|
||||
if len(w.Header()[HeaderProtocolVersion]) > 0 {
|
||||
t.Errorf("%s: unexpected protocol version response header: %s", w.Header()[HeaderProtocolVersion])
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// verify response headers
|
||||
if e, a := []string{test.expectedProtocol}, w.Header()[HeaderProtocolVersion]; !reflect.DeepEqual(e, a) {
|
||||
if e, a := []string{test.expectedProtocol}, w.Header()[HeaderProtocolVersion]; !api.Semantic.DeepEqual(e, a) {
|
||||
t.Errorf("%s: protocol response header: expected %v, got %v", name, e, a)
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user