Add streaming command execution & port forwarding

Add streaming command execution & port forwarding via HTTP connection
upgrades (currently using SPDY).
This commit is contained in:
Andy Goldstein
2015-01-08 15:41:38 -05:00
parent 25d38c175b
commit 5bd0e9ab05
45 changed files with 4439 additions and 157 deletions

View File

@@ -0,0 +1,19 @@
/*
Copyright 2015 Google Inc. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Package portforward adds support for SSH-like port forwarding from the client's
// local host to remote containers.
package portforward

View File

@@ -0,0 +1,300 @@
/*
Copyright 2015 Google Inc. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package portforward
import (
"errors"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"strconv"
"strings"
"github.com/GoogleCloudPlatform/kubernetes/pkg/api"
"github.com/GoogleCloudPlatform/kubernetes/pkg/client"
"github.com/GoogleCloudPlatform/kubernetes/pkg/util/httpstream"
"github.com/GoogleCloudPlatform/kubernetes/pkg/util/httpstream/spdy"
"github.com/golang/glog"
)
type upgrader interface {
upgrade(*client.Request, *client.Config) (httpstream.Connection, error)
}
type defaultUpgrader struct{}
func (u *defaultUpgrader) upgrade(req *client.Request, config *client.Config) (httpstream.Connection, error) {
return req.Upgrade(config, spdy.NewRoundTripper)
}
// PortForwarder knows how to listen for local connections and forward them to
// a remote pod via an upgraded HTTP request.
type PortForwarder struct {
req *client.Request
config *client.Config
ports []ForwardedPort
stopChan <-chan struct{}
streamConn httpstream.Connection
listeners []io.Closer
upgrader upgrader
Ready chan struct{}
}
// ForwardedPort contains a Local:Remote port pairing.
type ForwardedPort struct {
Local uint16
Remote uint16
}
/*
valid port specifications:
5000
- forwards from localhost:5000 to pod:5000
8888:5000
- forwards from localhost:8888 to pod:5000
0:5000
:5000
- selects a random available local port,
forwards from localhost:<random port> to pod:5000
*/
func parsePorts(ports []string) ([]ForwardedPort, error) {
var forwards []ForwardedPort
for _, portString := range ports {
parts := strings.Split(portString, ":")
var localString, remoteString string
if len(parts) == 1 {
localString = parts[0]
remoteString = parts[0]
} else if len(parts) == 2 {
localString = parts[0]
if localString == "" {
// support :5000
localString = "0"
}
remoteString = parts[1]
} else {
return nil, fmt.Errorf("Invalid port format '%s'", portString)
}
localPort, err := strconv.ParseUint(localString, 10, 16)
if err != nil {
return nil, fmt.Errorf("Error parsing local port '%s': %s", localString, err)
}
remotePort, err := strconv.ParseUint(remoteString, 10, 16)
if err != nil {
return nil, fmt.Errorf("Error parsing remote port '%s': %s", remoteString, err)
}
if remotePort == 0 {
return nil, fmt.Errorf("Remote port must be > 0")
}
forwards = append(forwards, ForwardedPort{uint16(localPort), uint16(remotePort)})
}
return forwards, nil
}
// New creates a new PortForwarder.
func New(req *client.Request, config *client.Config, ports []string, stopChan <-chan struct{}) (*PortForwarder, error) {
if len(ports) == 0 {
return nil, errors.New("You must specify at least 1 port")
}
parsedPorts, err := parsePorts(ports)
if err != nil {
return nil, err
}
return &PortForwarder{
req: req,
config: config,
ports: parsedPorts,
stopChan: stopChan,
Ready: make(chan struct{}),
}, nil
}
// ForwardPorts formats and executes a port forwarding request. The connection will remain
// open until stopChan is closed.
func (pf *PortForwarder) ForwardPorts() error {
defer pf.Close()
if pf.upgrader == nil {
pf.upgrader = &defaultUpgrader{}
}
var err error
pf.streamConn, err = pf.upgrader.upgrade(pf.req, pf.config)
if err != nil {
return fmt.Errorf("Error upgrading connection: %s", err)
}
defer pf.streamConn.Close()
return pf.forward()
}
// forward dials the remote host specific in req, upgrades the request, starts
// listeners for each port specified in ports, and forwards local connections
// to the remote host via streams.
func (pf *PortForwarder) forward() error {
var err error
listenSuccess := false
for _, port := range pf.ports {
err = pf.listenOnPort(&port)
if err != nil {
glog.Warningf("Unable to listen on port %d: %v", port, err)
}
listenSuccess = true
}
if !listenSuccess {
return fmt.Errorf("Unable to listen on any of the requested ports: %v", pf.ports)
}
close(pf.Ready)
// wait for interrupt or conn closure
select {
case <-pf.stopChan:
case <-pf.streamConn.CloseChan():
glog.Errorf("Lost connection to pod")
}
return nil
}
// listenOnPort creates a new listener on port and waits for new connections
// in the background.
func (pf *PortForwarder) listenOnPort(port *ForwardedPort) error {
listener, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", port.Local))
if err != nil {
return err
}
parts := strings.Split(listener.Addr().String(), ":")
localPort, err := strconv.ParseUint(parts[1], 10, 16)
if err != nil {
return fmt.Errorf("Error parsing local part: %s", err)
}
port.Local = uint16(localPort)
glog.Infof("Forwarding from %d -> %d", localPort, port.Remote)
pf.listeners = append(pf.listeners, listener)
go pf.waitForConnection(listener, *port)
return nil
}
// waitForConnection waits for new connections to listener and handles them in
// the background.
func (pf *PortForwarder) waitForConnection(listener net.Listener, port ForwardedPort) {
for {
conn, err := listener.Accept()
if err != nil {
// TODO consider using something like https://github.com/hydrogen18/stoppableListener?
if !strings.Contains(strings.ToLower(err.Error()), "use of closed network connection") {
glog.Errorf("Error accepting connection on port %d: %v", port.Local, err)
}
return
}
go pf.handleConnection(conn, port)
}
}
// handleConnection copies data between the local connection and the stream to
// the remote server.
func (pf *PortForwarder) handleConnection(conn net.Conn, port ForwardedPort) {
defer conn.Close()
glog.Infof("Handling connection for %d", port.Local)
errorChan := make(chan error)
doneChan := make(chan struct{}, 2)
// create error stream
headers := http.Header{}
headers.Set(api.StreamType, api.StreamTypeError)
headers.Set(api.PortHeader, fmt.Sprintf("%d", port.Remote))
errorStream, err := pf.streamConn.CreateStream(headers)
if err != nil {
glog.Errorf("Error creating error stream for port %d -> %d: %v", port.Local, port.Remote, err)
return
}
defer errorStream.Reset()
go func() {
message, err := ioutil.ReadAll(errorStream)
if err != nil && err != io.EOF {
errorChan <- fmt.Errorf("Error reading from error stream for port %d -> %d: %v", port.Local, port.Remote, err)
}
if len(message) > 0 {
errorChan <- fmt.Errorf("An error occurred forwarding %d -> %d: %v", port.Local, port.Remote, string(message))
}
}()
// create data stream
headers.Set(api.StreamType, api.StreamTypeData)
dataStream, err := pf.streamConn.CreateStream(headers)
if err != nil {
glog.Errorf("Error creating forwarding stream for port %d -> %d: %v", port.Local, port.Remote, err)
return
}
// Send a Reset when this function exits to completely tear down the stream here
// and in the remote server.
defer dataStream.Reset()
go func() {
// Copy from the remote side to the local port. We won't get an EOF from
// the server as it has no way of knowing when to close the stream. We'll
// take care of closing both ends of the stream with the call to
// stream.Reset() when this function exits.
if _, err := io.Copy(conn, dataStream); err != nil && err != io.EOF && !strings.Contains(err.Error(), "use of closed network connection") {
glog.Errorf("Error copying from remote stream to local connection: %v", err)
}
doneChan <- struct{}{}
}()
go func() {
// Copy from the local port to the remote side. Here we will be able to know
// when the Copy gets an EOF from conn, as that will happen as soon as conn is
// closed (i.e. client disconnected).
if _, err := io.Copy(dataStream, conn); err != nil && err != io.EOF && !strings.Contains(err.Error(), "use of closed network connection") {
glog.Errorf("Error copying from local connection to remote stream: %v", err)
}
doneChan <- struct{}{}
}()
select {
case err := <-errorChan:
glog.Error(err)
case <-doneChan:
}
}
func (pf *PortForwarder) Close() {
// stop all listeners
for _, l := range pf.listeners {
if err := l.Close(); err != nil {
glog.Errorf("Error closing listener: %v", err)
}
}
}

View File

@@ -0,0 +1,321 @@
/*
Copyright 2015 Google Inc. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package portforward
import (
"bytes"
"errors"
"fmt"
"io"
"net"
"net/http"
"reflect"
"sync"
"testing"
"time"
"github.com/GoogleCloudPlatform/kubernetes/pkg/api"
"github.com/GoogleCloudPlatform/kubernetes/pkg/client"
"github.com/GoogleCloudPlatform/kubernetes/pkg/util/httpstream"
)
func TestParsePortsAndNew(t *testing.T) {
tests := []struct {
input []string
expected []ForwardedPort
expectParseError bool
expectNewError bool
}{
{input: []string{}, expectNewError: true},
{input: []string{"a"}, expectParseError: true, expectNewError: true},
{input: []string{":a"}, expectParseError: true, expectNewError: true},
{input: []string{"-1"}, expectParseError: true, expectNewError: true},
{input: []string{"65536"}, expectParseError: true, expectNewError: true},
{input: []string{"0"}, expectParseError: true, expectNewError: true},
{input: []string{"0:0"}, expectParseError: true, expectNewError: true},
{input: []string{"a:5000"}, expectParseError: true, expectNewError: true},
{input: []string{"5000:a"}, expectParseError: true, expectNewError: true},
{
input: []string{"5000", "5000:5000", "8888:5000", "5000:8888", ":5000", "0:5000"},
expected: []ForwardedPort{
{5000, 5000},
{5000, 5000},
{8888, 5000},
{5000, 8888},
{0, 5000},
{0, 5000},
},
},
}
for i, test := range tests {
parsed, err := parsePorts(test.input)
haveError := err != nil
if e, a := test.expectParseError, haveError; e != a {
t.Fatalf("%d: parsePorts: error expected=%t, got %t: %s", i, e, a, err)
}
expectedRequest := &client.Request{}
expectedConfig := &client.Config{}
expectedStopChan := make(chan struct{})
pf, err := New(expectedRequest, expectedConfig, test.input, expectedStopChan)
haveError = err != nil
if e, a := test.expectNewError, haveError; e != a {
t.Fatalf("%d: New: error expected=%t, got %t: %s", i, e, a, err)
}
if test.expectParseError || test.expectNewError {
continue
}
for pi, expectedPort := range test.expected {
if e, a := expectedPort.Local, parsed[pi].Local; e != a {
t.Fatalf("%d: local expected: %d, got: %d", i, e, a)
}
if e, a := expectedPort.Remote, parsed[pi].Remote; e != a {
t.Fatalf("%d: remote expected: %d, got: %d", i, e, a)
}
}
if e, a := expectedRequest, pf.req; e != a {
t.Fatalf("%d: req: expected %#v, got %#v", i, e, a)
}
if e, a := expectedConfig, pf.config; e != a {
t.Fatalf("%d: config: expected %#v, got %#v", i, e, a)
}
if e, a := test.expected, pf.ports; !reflect.DeepEqual(e, a) {
t.Fatalf("%d: ports: expected %#v, got %#v", i, e, a)
}
if e, a := expectedStopChan, pf.stopChan; e != a {
t.Fatalf("%d: stopChan: expected %#v, got %#v", i, e, a)
}
if pf.Ready == nil {
t.Fatalf("%d: Ready should be non-nil", i)
}
}
}
type fakeUpgrader struct {
conn *fakeUpgradeConnection
err error
}
func (u *fakeUpgrader) upgrade(req *client.Request, config *client.Config) (httpstream.Connection, error) {
return u.conn, u.err
}
type fakeUpgradeConnection struct {
closeCalled bool
lock sync.Mutex
streams map[string]*fakeUpgradeStream
portData map[string]string
}
func newFakeUpgradeConnection() *fakeUpgradeConnection {
return &fakeUpgradeConnection{
streams: make(map[string]*fakeUpgradeStream),
portData: make(map[string]string),
}
}
func (c *fakeUpgradeConnection) CreateStream(headers http.Header) (httpstream.Stream, error) {
c.lock.Lock()
defer c.lock.Unlock()
stream := &fakeUpgradeStream{}
c.streams[headers.Get(api.PortHeader)] = stream
stream.data = c.portData[headers.Get(api.PortHeader)]
return stream, nil
}
func (c *fakeUpgradeConnection) Close() error {
c.lock.Lock()
defer c.lock.Unlock()
c.closeCalled = true
return nil
}
func (c *fakeUpgradeConnection) CloseChan() <-chan bool {
return make(chan bool)
}
func (c *fakeUpgradeConnection) SetIdleTimeout(timeout time.Duration) {
}
type fakeUpgradeStream struct {
readCalled bool
writeCalled bool
dataWritten []byte
closeCalled bool
resetCalled bool
data string
lock sync.Mutex
}
func (s *fakeUpgradeStream) Read(p []byte) (int, error) {
s.lock.Lock()
defer s.lock.Unlock()
s.readCalled = true
b := []byte(s.data)
n := copy(p, b)
return n, io.EOF
}
func (s *fakeUpgradeStream) Write(p []byte) (int, error) {
s.lock.Lock()
defer s.lock.Unlock()
s.writeCalled = true
s.dataWritten = make([]byte, len(p))
copy(s.dataWritten, p)
return len(p), io.EOF
}
func (s *fakeUpgradeStream) Close() error {
s.lock.Lock()
defer s.lock.Unlock()
s.closeCalled = true
return nil
}
func (s *fakeUpgradeStream) Reset() error {
s.lock.Lock()
defer s.lock.Unlock()
s.resetCalled = true
return nil
}
func (s *fakeUpgradeStream) Headers() http.Header {
s.lock.Lock()
defer s.lock.Unlock()
return http.Header{}
}
func TestForwardPorts(t *testing.T) {
testCases := []struct {
Upgrader *fakeUpgrader
Ports []string
Send map[uint16]string
Receive map[uint16]string
Err bool
}{
{
Upgrader: &fakeUpgrader{err: errors.New("bail")},
Err: true,
},
{
Upgrader: &fakeUpgrader{conn: newFakeUpgradeConnection()},
Ports: []string{"5000"},
},
{
Upgrader: &fakeUpgrader{conn: newFakeUpgradeConnection()},
Ports: []string{"5000", "6000"},
Send: map[uint16]string{
5000: "abcd",
6000: "ghij",
},
Receive: map[uint16]string{
5000: "1234",
6000: "5678",
},
},
}
for i, testCase := range testCases {
stopChan := make(chan struct{}, 1)
pf, err := New(&client.Request{}, &client.Config{}, testCase.Ports, stopChan)
hasErr := err != nil
if hasErr != testCase.Err {
t.Fatalf("%d: New: expected %t, got %t: %v", i, testCase.Err, hasErr, err)
}
if pf == nil {
continue
}
pf.upgrader = testCase.Upgrader
if testCase.Upgrader.err != nil {
err := pf.ForwardPorts()
hasErr := err != nil
if hasErr != testCase.Err {
t.Fatalf("%d: ForwardPorts: expected %t, got %t: %v", i, testCase.Err, hasErr, err)
}
continue
}
doneChan := make(chan error)
go func() {
doneChan <- pf.ForwardPorts()
}()
select {
case <-pf.Ready:
case <-time.After(500 * time.Millisecond):
t.Fatalf("%d: timed out waiting for listeners", i)
}
conn := testCase.Upgrader.conn
for port, data := range testCase.Send {
conn.lock.Lock()
conn.portData[fmt.Sprintf("%d", port)] = testCase.Receive[port]
conn.lock.Unlock()
clientConn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
if err != nil {
t.Fatalf("%d: error dialing %d: %s", i, port, err)
}
defer clientConn.Close()
n, err := clientConn.Write([]byte(data))
if err != nil && err != io.EOF {
t.Fatalf("%d: Error sending data '%s': %s", i, data, err)
}
if n == 0 {
t.Fatalf("%d: unexpected write of 0 bytes", i)
}
b := make([]byte, 4)
n, err = clientConn.Read(b)
if err != nil && err != io.EOF {
t.Fatalf("%d: Error reading data: %s", i, err)
}
if !bytes.Equal([]byte(testCase.Receive[port]), b) {
t.Fatalf("%d: expected to read '%s', got '%s'", i, testCase.Receive[port], b)
}
}
// tell r.ForwardPorts to stop
close(stopChan)
// wait for r.ForwardPorts to actually return
select {
case err := <-doneChan:
if err != nil {
t.Fatalf("%d: unexpected error: %s", err)
}
case <-time.After(200 * time.Millisecond):
t.Fatalf("%d: timeout waiting for ForwardPorts to finish")
}
if e, a := len(testCase.Send), len(conn.streams); e != a {
t.Fatalf("%d: expected %d streams to be created, got %d", e, a)
}
if !conn.closeCalled {
t.Fatalf("%d: expected conn closure", i)
}
}
}

View File

@@ -0,0 +1,20 @@
/*
Copyright 2015 Google Inc. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Package remotecommand adds support for executing commands in containers,
// with support for separate stdin, stdout, and stderr streams, as well as
// TTY.
package remotecommand

View File

@@ -0,0 +1,186 @@
/*
Copyright 2015 Google Inc. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package remotecommand
import (
"fmt"
"io"
"io/ioutil"
"net/http"
"github.com/GoogleCloudPlatform/kubernetes/pkg/api"
"github.com/GoogleCloudPlatform/kubernetes/pkg/client"
"github.com/GoogleCloudPlatform/kubernetes/pkg/util/httpstream"
"github.com/GoogleCloudPlatform/kubernetes/pkg/util/httpstream/spdy"
"github.com/golang/glog"
)
type upgrader interface {
upgrade(*client.Request, *client.Config) (httpstream.Connection, error)
}
type defaultUpgrader struct{}
func (u *defaultUpgrader) upgrade(req *client.Request, config *client.Config) (httpstream.Connection, error) {
return req.Upgrade(config, spdy.NewRoundTripper)
}
type RemoteCommandExecutor struct {
req *client.Request
config *client.Config
command []string
stdin io.Reader
stdout io.Writer
stderr io.Writer
tty bool
upgrader upgrader
}
func New(req *client.Request, config *client.Config, command []string, stdin io.Reader, stdout, stderr io.Writer, tty bool) *RemoteCommandExecutor {
return &RemoteCommandExecutor{
req: req,
config: config,
command: command,
stdin: stdin,
stdout: stdout,
stderr: stderr,
tty: tty,
}
}
// Execute sends a remote command execution request, upgrading the
// connection and creating streams to represent stdin/stdout/stderr. Data is
// copied between these streams and the supplied stdin/stdout/stderr parameters.
func (e *RemoteCommandExecutor) Execute() error {
doStdin := (e.stdin != nil)
doStdout := (e.stdout != nil)
doStderr := (!e.tty && e.stderr != nil)
if doStdin {
e.req.Param(api.ExecStdinParam, "1")
}
if doStdout {
e.req.Param(api.ExecStdoutParam, "1")
}
if doStderr {
e.req.Param(api.ExecStderrParam, "1")
}
if e.tty {
e.req.Param(api.ExecTTYParam, "1")
}
for _, s := range e.command {
e.req.Param(api.ExecCommandParamm, s)
}
if e.upgrader == nil {
e.upgrader = &defaultUpgrader{}
}
conn, err := e.upgrader.upgrade(e.req, e.config)
if err != nil {
return err
}
defer conn.Close()
doneChan := make(chan struct{}, 2)
errorChan := make(chan error)
cp := func(s string, dst io.Writer, src io.Reader) {
glog.V(4).Infof("Copying %s", s)
defer glog.V(4).Infof("Done copying %s", s)
if _, err := io.Copy(dst, src); err != nil && err != io.EOF {
glog.Errorf("Error copying %s: %v", s, err)
}
if s == api.StreamTypeStdout || s == api.StreamTypeStderr {
doneChan <- struct{}{}
}
}
headers := http.Header{}
headers.Set(api.StreamType, api.StreamTypeError)
errorStream, err := conn.CreateStream(headers)
if err != nil {
return err
}
go func() {
message, err := ioutil.ReadAll(errorStream)
if err != nil && err != io.EOF {
errorChan <- fmt.Errorf("Error reading from error stream: %s", err)
return
}
if len(message) > 0 {
errorChan <- fmt.Errorf("Error executing remote command: %s", message)
return
}
}()
defer errorStream.Reset()
if doStdin {
headers.Set(api.StreamType, api.StreamTypeStdin)
remoteStdin, err := conn.CreateStream(headers)
if err != nil {
return err
}
defer remoteStdin.Reset()
// TODO this goroutine will never exit cleanly (the io.Copy never unblocks)
// because stdin is not closed until the process exits. If we try to call
// stdin.Close(), it returns no error but doesn't unblock the copy. It will
// exit when the process exits, instead.
go cp(api.StreamTypeStdin, remoteStdin, e.stdin)
}
waitCount := 0
completedStreams := 0
if doStdout {
waitCount++
headers.Set(api.StreamType, api.StreamTypeStdout)
remoteStdout, err := conn.CreateStream(headers)
if err != nil {
return err
}
defer remoteStdout.Reset()
go cp(api.StreamTypeStdout, e.stdout, remoteStdout)
}
if doStderr && !e.tty {
waitCount++
headers.Set(api.StreamType, api.StreamTypeStderr)
remoteStderr, err := conn.CreateStream(headers)
if err != nil {
return err
}
defer remoteStderr.Reset()
go cp(api.StreamTypeStderr, e.stderr, remoteStderr)
}
Loop:
for {
select {
case <-doneChan:
completedStreams++
if completedStreams == waitCount {
break Loop
}
case err := <-errorChan:
return err
}
}
return nil
}

View File

@@ -0,0 +1,288 @@
/*
Copyright 2015 Google Inc. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package remotecommand
import (
"bytes"
"errors"
"io"
"net/http"
"strings"
"sync"
"testing"
"time"
"github.com/GoogleCloudPlatform/kubernetes/pkg/api"
"github.com/GoogleCloudPlatform/kubernetes/pkg/client"
"github.com/GoogleCloudPlatform/kubernetes/pkg/util/httpstream"
)
type fakeUpgrader struct {
conn *fakeUpgradeConnection
err error
}
func (u *fakeUpgrader) upgrade(req *client.Request, config *client.Config) (httpstream.Connection, error) {
return u.conn, u.err
}
type fakeUpgradeConnection struct {
closeCalled bool
lock sync.Mutex
stdin *fakeUpgradeStream
stdout *fakeUpgradeStream
stdoutData string
stderr *fakeUpgradeStream
stderrData string
errorStream *fakeUpgradeStream
errorData string
unexpectedStreamCreated bool
}
func newFakeUpgradeConnection() *fakeUpgradeConnection {
return &fakeUpgradeConnection{}
}
func (c *fakeUpgradeConnection) CreateStream(headers http.Header) (httpstream.Stream, error) {
c.lock.Lock()
defer c.lock.Unlock()
stream := &fakeUpgradeStream{}
switch headers.Get(api.StreamType) {
case api.StreamTypeStdin:
c.stdin = stream
case api.StreamTypeStdout:
c.stdout = stream
stream.data = c.stdoutData
case api.StreamTypeStderr:
c.stderr = stream
stream.data = c.stderrData
case api.StreamTypeError:
c.errorStream = stream
stream.data = c.errorData
default:
c.unexpectedStreamCreated = true
}
return stream, nil
}
func (c *fakeUpgradeConnection) Close() error {
c.lock.Lock()
defer c.lock.Unlock()
c.closeCalled = true
return nil
}
func (c *fakeUpgradeConnection) CloseChan() <-chan bool {
return make(chan bool)
}
func (c *fakeUpgradeConnection) SetIdleTimeout(timeout time.Duration) {
}
type fakeUpgradeStream struct {
readCalled bool
writeCalled bool
dataWritten []byte
closeCalled bool
resetCalled bool
data string
lock sync.Mutex
}
func (s *fakeUpgradeStream) Read(p []byte) (int, error) {
s.lock.Lock()
defer s.lock.Unlock()
s.readCalled = true
b := []byte(s.data)
n := copy(p, b)
return n, io.EOF
}
func (s *fakeUpgradeStream) Write(p []byte) (int, error) {
s.lock.Lock()
defer s.lock.Unlock()
s.writeCalled = true
s.dataWritten = make([]byte, len(p))
copy(s.dataWritten, p)
return len(p), io.EOF
}
func (s *fakeUpgradeStream) Close() error {
s.lock.Lock()
defer s.lock.Unlock()
s.closeCalled = true
return nil
}
func (s *fakeUpgradeStream) Reset() error {
s.lock.Lock()
defer s.lock.Unlock()
s.resetCalled = true
return nil
}
func (s *fakeUpgradeStream) Headers() http.Header {
s.lock.Lock()
defer s.lock.Unlock()
return http.Header{}
}
func TestRequestExecuteRemoteCommand(t *testing.T) {
testCases := []struct {
Upgrader *fakeUpgrader
Stdin string
Stdout string
Stderr string
Error string
Tty bool
ShouldError bool
}{
{
Upgrader: &fakeUpgrader{err: errors.New("bail")},
ShouldError: true,
},
{
Upgrader: &fakeUpgrader{conn: newFakeUpgradeConnection()},
Stdin: "a",
Stdout: "b",
Stderr: "c",
Error: "bail",
ShouldError: true,
},
{
Upgrader: &fakeUpgrader{conn: newFakeUpgradeConnection()},
Stdin: "a",
Stdout: "b",
Stderr: "c",
},
{
Upgrader: &fakeUpgrader{conn: newFakeUpgradeConnection()},
Stdin: "a",
Stdout: "b",
Stderr: "c",
Tty: true,
},
}
for i, testCase := range testCases {
if testCase.Error != "" {
testCase.Upgrader.conn.errorData = testCase.Error
}
if testCase.Stdout != "" {
testCase.Upgrader.conn.stdoutData = testCase.Stdout
}
if testCase.Stderr != "" {
testCase.Upgrader.conn.stderrData = testCase.Stderr
}
var localOut, localErr *bytes.Buffer
if testCase.Stdout != "" {
localOut = &bytes.Buffer{}
}
if testCase.Stderr != "" {
localErr = &bytes.Buffer{}
}
e := New(&client.Request{}, &client.Config{}, []string{"ls", "/"}, strings.NewReader(testCase.Stdin), localOut, localErr, testCase.Tty)
e.upgrader = testCase.Upgrader
err := e.Execute()
hasErr := err != nil
if hasErr != testCase.ShouldError {
t.Fatalf("%d: expected %t, got %t: %v", i, testCase.ShouldError, hasErr, err)
}
conn := testCase.Upgrader.conn
if testCase.Error != "" {
if conn.errorStream == nil {
t.Fatalf("%d: expected error stream creation", i)
}
if !conn.errorStream.readCalled {
t.Fatalf("%d: expected error stream read", i)
}
if e, a := testCase.Error, err.Error(); !strings.Contains(a, e) {
t.Fatalf("%d: expected error stream read '%v', got '%v'", i, e, a)
}
if !conn.errorStream.resetCalled {
t.Fatalf("%d: expected error reset", i)
}
}
if testCase.ShouldError {
continue
}
if testCase.Stdin != "" {
if conn.stdin == nil {
t.Fatalf("%d: expected stdin stream creation", i)
}
if !conn.stdin.writeCalled {
t.Fatalf("%d: expected stdin stream write", i)
}
if e, a := testCase.Stdin, string(conn.stdin.dataWritten); e != a {
t.Fatalf("%d: expected stdin write %v, got %v", i, e, a)
}
if !conn.stdin.resetCalled {
t.Fatalf("%d: expected stdin reset", i)
}
}
if testCase.Stdout != "" {
if conn.stdout == nil {
t.Fatalf("%d: expected stdout stream creation", i)
}
if !conn.stdout.readCalled {
t.Fatalf("%d: expected stdout stream read", i)
}
if e, a := testCase.Stdout, localOut; e != a.String() {
t.Fatalf("%d: expected stdout data '%s', got '%s'", i, e, a)
}
if !conn.stdout.resetCalled {
t.Fatalf("%d: expected stdout reset", i)
}
}
if testCase.Stderr != "" {
if testCase.Tty {
if conn.stderr != nil {
t.Fatalf("%d: unexpected stderr stream creation", i)
}
if localErr.String() != "" {
t.Fatalf("%d: unexpected stderr data '%s'", i, localErr)
}
} else {
if conn.stderr == nil {
t.Fatalf("%d: expected stderr stream creation", i)
}
if !conn.stderr.readCalled {
t.Fatalf("%d: expected stderr stream read", i)
}
if e, a := testCase.Stderr, localErr; e != a.String() {
t.Fatalf("%d: expected stderr data '%s', got '%s'", i, e, a)
}
if !conn.stderr.resetCalled {
t.Fatalf("%d: expected stderr reset", i)
}
}
}
if !conn.closeCalled {
t.Fatalf("%d: expected upgraded connection to get closed")
}
}
}

View File

@@ -18,6 +18,7 @@ package client
import (
"bytes"
"crypto/tls"
"fmt"
"io"
"io/ioutil"
@@ -33,6 +34,7 @@ import (
"github.com/GoogleCloudPlatform/kubernetes/pkg/labels"
"github.com/GoogleCloudPlatform/kubernetes/pkg/runtime"
"github.com/GoogleCloudPlatform/kubernetes/pkg/util"
"github.com/GoogleCloudPlatform/kubernetes/pkg/util/httpstream"
"github.com/GoogleCloudPlatform/kubernetes/pkg/watch"
watchjson "github.com/GoogleCloudPlatform/kubernetes/pkg/watch/json"
"github.com/golang/glog"
@@ -277,7 +279,7 @@ func (r *Request) setParam(paramName, value string) *Request {
if r.params == nil {
r.params = make(url.Values)
}
r.params[paramName] = []string{value}
r.params[paramName] = append(r.params[paramName], value)
return r
}
@@ -347,8 +349,10 @@ func (r *Request) finalURL() string {
finalURL.Path = p
query := url.Values{}
for key, value := range r.params {
query[key] = value
for key, values := range r.params {
for _, value := range values {
query.Add(key, value)
}
}
if r.namespaceSet && r.namespaceInQuery {
@@ -434,6 +438,41 @@ func (r *Request) Stream() (io.ReadCloser, error) {
return resp.Body, nil
}
// Upgrade upgrades the request so that it supports multiplexed bidirectional
// streams. The current implementation uses SPDY, but this could be replaced
// with HTTP/2 once it's available, or something else.
func (r *Request) Upgrade(config *Config, newRoundTripperFunc func(*tls.Config) httpstream.UpgradeRoundTripper) (httpstream.Connection, error) {
if r.err != nil {
return nil, r.err
}
tlsConfig, err := TLSConfigFor(config)
if err != nil {
return nil, err
}
upgradeRoundTripper := newRoundTripperFunc(tlsConfig)
wrapper, err := HTTPWrappersForConfig(config, upgradeRoundTripper)
if err != nil {
return nil, err
}
r.client = &http.Client{Transport: wrapper}
req, err := http.NewRequest(r.verb, r.finalURL(), nil)
if err != nil {
return nil, fmt.Errorf("Error creating request: %s", err)
}
resp, err := r.client.Do(req)
if err != nil {
return nil, fmt.Errorf("Error sending request: %s", err)
}
defer resp.Body.Close()
return upgradeRoundTripper.NewConnection(resp)
}
// Do formats and executes the request. Returns a Result object for easy response
// processing.
//
@@ -513,6 +552,8 @@ func (r *Request) transformResponse(resp *http.Response, req *http.Request) ([]b
}
switch {
case resp.StatusCode == http.StatusSwitchingProtocols:
// no-op, we've been upgraded
case resp.StatusCode < http.StatusOK || resp.StatusCode > http.StatusPartialContent:
if !isStatusResponse {
var err error = &UnexpectedStatusError{

View File

@@ -18,6 +18,7 @@ package client
import (
"bytes"
"crypto/tls"
"encoding/base64"
"errors"
"io"
@@ -40,6 +41,7 @@ import (
"github.com/GoogleCloudPlatform/kubernetes/pkg/labels"
"github.com/GoogleCloudPlatform/kubernetes/pkg/runtime"
"github.com/GoogleCloudPlatform/kubernetes/pkg/util"
"github.com/GoogleCloudPlatform/kubernetes/pkg/util/httpstream"
"github.com/GoogleCloudPlatform/kubernetes/pkg/watch"
watchjson "github.com/GoogleCloudPlatform/kubernetes/pkg/watch/json"
)
@@ -151,16 +153,22 @@ func TestRequestParam(t *testing.T) {
if !api.Semantic.DeepDerivative(r.params, url.Values{"foo": []string{"a"}}) {
t.Errorf("should have set a param: %#v", r)
}
r.Param("bar", "1")
r.Param("bar", "2")
if !api.Semantic.DeepDerivative(r.params, url.Values{"foo": []string{"a"}, "bar": []string{"1", "2"}}) {
t.Errorf("should have set a param: %#v", r)
}
}
func TestRequestURI(t *testing.T) {
r := (&Request{}).Param("foo", "a")
r.Prefix("other")
r.RequestURI("/test?foo=b&a=b")
r.RequestURI("/test?foo=b&a=b&c=1&c=2")
if r.path != "/test" {
t.Errorf("path is wrong: %#v", r)
}
if !api.Semantic.DeepDerivative(r.params, url.Values{"a": []string{"b"}, "foo": []string{"b"}}) {
if !api.Semantic.DeepDerivative(r.params, url.Values{"a": []string{"b"}, "foo": []string{"b"}, "c": []string{"1", "2"}}) {
t.Errorf("should have set a param: %#v", r)
}
}
@@ -443,6 +451,122 @@ func TestRequestStream(t *testing.T) {
}
}
type fakeUpgradeConnection struct{}
func (c *fakeUpgradeConnection) CreateStream(headers http.Header) (httpstream.Stream, error) {
return nil, nil
}
func (c *fakeUpgradeConnection) Close() error {
return nil
}
func (c *fakeUpgradeConnection) CloseChan() <-chan bool {
return make(chan bool)
}
func (c *fakeUpgradeConnection) SetIdleTimeout(timeout time.Duration) {
}
type fakeUpgradeRoundTripper struct {
req *http.Request
conn httpstream.Connection
}
func (f *fakeUpgradeRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
f.req = req
b := []byte{}
body := ioutil.NopCloser(bytes.NewReader(b))
resp := &http.Response{
StatusCode: 101,
Body: body,
}
return resp, nil
}
func (f *fakeUpgradeRoundTripper) NewConnection(resp *http.Response) (httpstream.Connection, error) {
return f.conn, nil
}
func TestRequestUpgrade(t *testing.T) {
uri, _ := url.Parse("http://localhost/")
testCases := []struct {
Request *Request
Config *Config
RoundTripper *fakeUpgradeRoundTripper
Err bool
AuthBasicHeader bool
AuthBearerHeader bool
}{
{
Request: &Request{err: errors.New("bail")},
Err: true,
},
{
Request: &Request{},
Config: &Config{
TLSClientConfig: TLSClientConfig{
CAFile: "foo",
},
Insecure: true,
},
Err: true,
},
{
Request: &Request{},
Config: &Config{
Username: "u",
Password: "p",
BearerToken: "b",
},
Err: true,
},
{
Request: NewRequest(nil, "", uri, testapi.Codec(), true, true),
Config: &Config{
Username: "u",
Password: "p",
},
AuthBasicHeader: true,
Err: false,
},
{
Request: NewRequest(nil, "", uri, testapi.Codec(), true, true),
Config: &Config{
BearerToken: "b",
},
AuthBearerHeader: true,
Err: false,
},
}
for i, testCase := range testCases {
r := testCase.Request
rt := &fakeUpgradeRoundTripper{}
expectedConn := &fakeUpgradeConnection{}
conn, err := r.Upgrade(testCase.Config, func(config *tls.Config) httpstream.UpgradeRoundTripper {
rt.conn = expectedConn
return rt
})
_ = conn
hasErr := err != nil
if hasErr != testCase.Err {
t.Errorf("%d: expected %t, got %t: %v", i, testCase.Err, hasErr, r.err)
}
if testCase.Err {
continue
}
if testCase.AuthBasicHeader && !strings.Contains(rt.req.Header.Get("Authorization"), "Basic") {
t.Errorf("%d: expected basic auth header, got: %s", rt.req.Header.Get("Authorization"))
}
if testCase.AuthBearerHeader && !strings.Contains(rt.req.Header.Get("Authorization"), "Bearer") {
t.Errorf("%d: expected bearer auth header, got: %s", rt.req.Header.Get("Authorization"))
}
if e, a := expectedConn, conn; e != a {
t.Errorf("%d: conn: expected %#v, got %#v", i, e, a)
}
}
}
func TestRequestDo(t *testing.T) {
testCases := []struct {
Request *Request