mirror of
https://github.com/k3s-io/kubernetes.git
synced 2026-01-13 11:25:19 +00:00
Add streaming command execution & port forwarding
Add streaming command execution & port forwarding via HTTP connection upgrades (currently using SPDY).
This commit is contained in:
19
pkg/client/portforward/doc.go
Normal file
19
pkg/client/portforward/doc.go
Normal file
@@ -0,0 +1,19 @@
|
||||
/*
|
||||
Copyright 2015 Google Inc. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
// Package portforward adds support for SSH-like port forwarding from the client's
|
||||
// local host to remote containers.
|
||||
package portforward
|
||||
300
pkg/client/portforward/portforward.go
Normal file
300
pkg/client/portforward/portforward.go
Normal file
@@ -0,0 +1,300 @@
|
||||
/*
|
||||
Copyright 2015 Google Inc. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package portforward
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/api"
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/client"
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/util/httpstream"
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/util/httpstream/spdy"
|
||||
"github.com/golang/glog"
|
||||
)
|
||||
|
||||
type upgrader interface {
|
||||
upgrade(*client.Request, *client.Config) (httpstream.Connection, error)
|
||||
}
|
||||
|
||||
type defaultUpgrader struct{}
|
||||
|
||||
func (u *defaultUpgrader) upgrade(req *client.Request, config *client.Config) (httpstream.Connection, error) {
|
||||
return req.Upgrade(config, spdy.NewRoundTripper)
|
||||
}
|
||||
|
||||
// PortForwarder knows how to listen for local connections and forward them to
|
||||
// a remote pod via an upgraded HTTP request.
|
||||
type PortForwarder struct {
|
||||
req *client.Request
|
||||
config *client.Config
|
||||
ports []ForwardedPort
|
||||
stopChan <-chan struct{}
|
||||
|
||||
streamConn httpstream.Connection
|
||||
listeners []io.Closer
|
||||
upgrader upgrader
|
||||
Ready chan struct{}
|
||||
}
|
||||
|
||||
// ForwardedPort contains a Local:Remote port pairing.
|
||||
type ForwardedPort struct {
|
||||
Local uint16
|
||||
Remote uint16
|
||||
}
|
||||
|
||||
/*
|
||||
valid port specifications:
|
||||
|
||||
5000
|
||||
- forwards from localhost:5000 to pod:5000
|
||||
|
||||
8888:5000
|
||||
- forwards from localhost:8888 to pod:5000
|
||||
|
||||
0:5000
|
||||
:5000
|
||||
- selects a random available local port,
|
||||
forwards from localhost:<random port> to pod:5000
|
||||
*/
|
||||
func parsePorts(ports []string) ([]ForwardedPort, error) {
|
||||
var forwards []ForwardedPort
|
||||
for _, portString := range ports {
|
||||
parts := strings.Split(portString, ":")
|
||||
var localString, remoteString string
|
||||
if len(parts) == 1 {
|
||||
localString = parts[0]
|
||||
remoteString = parts[0]
|
||||
} else if len(parts) == 2 {
|
||||
localString = parts[0]
|
||||
if localString == "" {
|
||||
// support :5000
|
||||
localString = "0"
|
||||
}
|
||||
remoteString = parts[1]
|
||||
} else {
|
||||
return nil, fmt.Errorf("Invalid port format '%s'", portString)
|
||||
}
|
||||
|
||||
localPort, err := strconv.ParseUint(localString, 10, 16)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Error parsing local port '%s': %s", localString, err)
|
||||
}
|
||||
|
||||
remotePort, err := strconv.ParseUint(remoteString, 10, 16)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Error parsing remote port '%s': %s", remoteString, err)
|
||||
}
|
||||
if remotePort == 0 {
|
||||
return nil, fmt.Errorf("Remote port must be > 0")
|
||||
}
|
||||
|
||||
forwards = append(forwards, ForwardedPort{uint16(localPort), uint16(remotePort)})
|
||||
}
|
||||
|
||||
return forwards, nil
|
||||
}
|
||||
|
||||
// New creates a new PortForwarder.
|
||||
func New(req *client.Request, config *client.Config, ports []string, stopChan <-chan struct{}) (*PortForwarder, error) {
|
||||
if len(ports) == 0 {
|
||||
return nil, errors.New("You must specify at least 1 port")
|
||||
}
|
||||
parsedPorts, err := parsePorts(ports)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &PortForwarder{
|
||||
req: req,
|
||||
config: config,
|
||||
ports: parsedPorts,
|
||||
stopChan: stopChan,
|
||||
Ready: make(chan struct{}),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ForwardPorts formats and executes a port forwarding request. The connection will remain
|
||||
// open until stopChan is closed.
|
||||
func (pf *PortForwarder) ForwardPorts() error {
|
||||
defer pf.Close()
|
||||
|
||||
if pf.upgrader == nil {
|
||||
pf.upgrader = &defaultUpgrader{}
|
||||
}
|
||||
var err error
|
||||
pf.streamConn, err = pf.upgrader.upgrade(pf.req, pf.config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error upgrading connection: %s", err)
|
||||
}
|
||||
defer pf.streamConn.Close()
|
||||
|
||||
return pf.forward()
|
||||
}
|
||||
|
||||
// forward dials the remote host specific in req, upgrades the request, starts
|
||||
// listeners for each port specified in ports, and forwards local connections
|
||||
// to the remote host via streams.
|
||||
func (pf *PortForwarder) forward() error {
|
||||
var err error
|
||||
|
||||
listenSuccess := false
|
||||
for _, port := range pf.ports {
|
||||
err = pf.listenOnPort(&port)
|
||||
if err != nil {
|
||||
glog.Warningf("Unable to listen on port %d: %v", port, err)
|
||||
}
|
||||
listenSuccess = true
|
||||
}
|
||||
|
||||
if !listenSuccess {
|
||||
return fmt.Errorf("Unable to listen on any of the requested ports: %v", pf.ports)
|
||||
}
|
||||
|
||||
close(pf.Ready)
|
||||
|
||||
// wait for interrupt or conn closure
|
||||
select {
|
||||
case <-pf.stopChan:
|
||||
case <-pf.streamConn.CloseChan():
|
||||
glog.Errorf("Lost connection to pod")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// listenOnPort creates a new listener on port and waits for new connections
|
||||
// in the background.
|
||||
func (pf *PortForwarder) listenOnPort(port *ForwardedPort) error {
|
||||
listener, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", port.Local))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
parts := strings.Split(listener.Addr().String(), ":")
|
||||
localPort, err := strconv.ParseUint(parts[1], 10, 16)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error parsing local part: %s", err)
|
||||
}
|
||||
port.Local = uint16(localPort)
|
||||
glog.Infof("Forwarding from %d -> %d", localPort, port.Remote)
|
||||
|
||||
pf.listeners = append(pf.listeners, listener)
|
||||
|
||||
go pf.waitForConnection(listener, *port)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// waitForConnection waits for new connections to listener and handles them in
|
||||
// the background.
|
||||
func (pf *PortForwarder) waitForConnection(listener net.Listener, port ForwardedPort) {
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
// TODO consider using something like https://github.com/hydrogen18/stoppableListener?
|
||||
if !strings.Contains(strings.ToLower(err.Error()), "use of closed network connection") {
|
||||
glog.Errorf("Error accepting connection on port %d: %v", port.Local, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
go pf.handleConnection(conn, port)
|
||||
}
|
||||
}
|
||||
|
||||
// handleConnection copies data between the local connection and the stream to
|
||||
// the remote server.
|
||||
func (pf *PortForwarder) handleConnection(conn net.Conn, port ForwardedPort) {
|
||||
defer conn.Close()
|
||||
|
||||
glog.Infof("Handling connection for %d", port.Local)
|
||||
|
||||
errorChan := make(chan error)
|
||||
doneChan := make(chan struct{}, 2)
|
||||
|
||||
// create error stream
|
||||
headers := http.Header{}
|
||||
headers.Set(api.StreamType, api.StreamTypeError)
|
||||
headers.Set(api.PortHeader, fmt.Sprintf("%d", port.Remote))
|
||||
errorStream, err := pf.streamConn.CreateStream(headers)
|
||||
if err != nil {
|
||||
glog.Errorf("Error creating error stream for port %d -> %d: %v", port.Local, port.Remote, err)
|
||||
return
|
||||
}
|
||||
defer errorStream.Reset()
|
||||
go func() {
|
||||
message, err := ioutil.ReadAll(errorStream)
|
||||
if err != nil && err != io.EOF {
|
||||
errorChan <- fmt.Errorf("Error reading from error stream for port %d -> %d: %v", port.Local, port.Remote, err)
|
||||
}
|
||||
if len(message) > 0 {
|
||||
errorChan <- fmt.Errorf("An error occurred forwarding %d -> %d: %v", port.Local, port.Remote, string(message))
|
||||
}
|
||||
}()
|
||||
|
||||
// create data stream
|
||||
headers.Set(api.StreamType, api.StreamTypeData)
|
||||
dataStream, err := pf.streamConn.CreateStream(headers)
|
||||
if err != nil {
|
||||
glog.Errorf("Error creating forwarding stream for port %d -> %d: %v", port.Local, port.Remote, err)
|
||||
return
|
||||
}
|
||||
// Send a Reset when this function exits to completely tear down the stream here
|
||||
// and in the remote server.
|
||||
defer dataStream.Reset()
|
||||
|
||||
go func() {
|
||||
// Copy from the remote side to the local port. We won't get an EOF from
|
||||
// the server as it has no way of knowing when to close the stream. We'll
|
||||
// take care of closing both ends of the stream with the call to
|
||||
// stream.Reset() when this function exits.
|
||||
if _, err := io.Copy(conn, dataStream); err != nil && err != io.EOF && !strings.Contains(err.Error(), "use of closed network connection") {
|
||||
glog.Errorf("Error copying from remote stream to local connection: %v", err)
|
||||
}
|
||||
doneChan <- struct{}{}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
// Copy from the local port to the remote side. Here we will be able to know
|
||||
// when the Copy gets an EOF from conn, as that will happen as soon as conn is
|
||||
// closed (i.e. client disconnected).
|
||||
if _, err := io.Copy(dataStream, conn); err != nil && err != io.EOF && !strings.Contains(err.Error(), "use of closed network connection") {
|
||||
glog.Errorf("Error copying from local connection to remote stream: %v", err)
|
||||
}
|
||||
doneChan <- struct{}{}
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-errorChan:
|
||||
glog.Error(err)
|
||||
case <-doneChan:
|
||||
}
|
||||
}
|
||||
|
||||
func (pf *PortForwarder) Close() {
|
||||
// stop all listeners
|
||||
for _, l := range pf.listeners {
|
||||
if err := l.Close(); err != nil {
|
||||
glog.Errorf("Error closing listener: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
321
pkg/client/portforward/portforward_test.go
Normal file
321
pkg/client/portforward/portforward_test.go
Normal file
@@ -0,0 +1,321 @@
|
||||
/*
|
||||
Copyright 2015 Google Inc. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package portforward
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/api"
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/client"
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/util/httpstream"
|
||||
)
|
||||
|
||||
func TestParsePortsAndNew(t *testing.T) {
|
||||
tests := []struct {
|
||||
input []string
|
||||
expected []ForwardedPort
|
||||
expectParseError bool
|
||||
expectNewError bool
|
||||
}{
|
||||
{input: []string{}, expectNewError: true},
|
||||
{input: []string{"a"}, expectParseError: true, expectNewError: true},
|
||||
{input: []string{":a"}, expectParseError: true, expectNewError: true},
|
||||
{input: []string{"-1"}, expectParseError: true, expectNewError: true},
|
||||
{input: []string{"65536"}, expectParseError: true, expectNewError: true},
|
||||
{input: []string{"0"}, expectParseError: true, expectNewError: true},
|
||||
{input: []string{"0:0"}, expectParseError: true, expectNewError: true},
|
||||
{input: []string{"a:5000"}, expectParseError: true, expectNewError: true},
|
||||
{input: []string{"5000:a"}, expectParseError: true, expectNewError: true},
|
||||
{
|
||||
input: []string{"5000", "5000:5000", "8888:5000", "5000:8888", ":5000", "0:5000"},
|
||||
expected: []ForwardedPort{
|
||||
{5000, 5000},
|
||||
{5000, 5000},
|
||||
{8888, 5000},
|
||||
{5000, 8888},
|
||||
{0, 5000},
|
||||
{0, 5000},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
parsed, err := parsePorts(test.input)
|
||||
haveError := err != nil
|
||||
if e, a := test.expectParseError, haveError; e != a {
|
||||
t.Fatalf("%d: parsePorts: error expected=%t, got %t: %s", i, e, a, err)
|
||||
}
|
||||
|
||||
expectedRequest := &client.Request{}
|
||||
expectedConfig := &client.Config{}
|
||||
expectedStopChan := make(chan struct{})
|
||||
pf, err := New(expectedRequest, expectedConfig, test.input, expectedStopChan)
|
||||
haveError = err != nil
|
||||
if e, a := test.expectNewError, haveError; e != a {
|
||||
t.Fatalf("%d: New: error expected=%t, got %t: %s", i, e, a, err)
|
||||
}
|
||||
|
||||
if test.expectParseError || test.expectNewError {
|
||||
continue
|
||||
}
|
||||
|
||||
for pi, expectedPort := range test.expected {
|
||||
if e, a := expectedPort.Local, parsed[pi].Local; e != a {
|
||||
t.Fatalf("%d: local expected: %d, got: %d", i, e, a)
|
||||
}
|
||||
if e, a := expectedPort.Remote, parsed[pi].Remote; e != a {
|
||||
t.Fatalf("%d: remote expected: %d, got: %d", i, e, a)
|
||||
}
|
||||
}
|
||||
|
||||
if e, a := expectedRequest, pf.req; e != a {
|
||||
t.Fatalf("%d: req: expected %#v, got %#v", i, e, a)
|
||||
}
|
||||
if e, a := expectedConfig, pf.config; e != a {
|
||||
t.Fatalf("%d: config: expected %#v, got %#v", i, e, a)
|
||||
}
|
||||
if e, a := test.expected, pf.ports; !reflect.DeepEqual(e, a) {
|
||||
t.Fatalf("%d: ports: expected %#v, got %#v", i, e, a)
|
||||
}
|
||||
if e, a := expectedStopChan, pf.stopChan; e != a {
|
||||
t.Fatalf("%d: stopChan: expected %#v, got %#v", i, e, a)
|
||||
}
|
||||
if pf.Ready == nil {
|
||||
t.Fatalf("%d: Ready should be non-nil", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type fakeUpgrader struct {
|
||||
conn *fakeUpgradeConnection
|
||||
err error
|
||||
}
|
||||
|
||||
func (u *fakeUpgrader) upgrade(req *client.Request, config *client.Config) (httpstream.Connection, error) {
|
||||
return u.conn, u.err
|
||||
}
|
||||
|
||||
type fakeUpgradeConnection struct {
|
||||
closeCalled bool
|
||||
lock sync.Mutex
|
||||
streams map[string]*fakeUpgradeStream
|
||||
portData map[string]string
|
||||
}
|
||||
|
||||
func newFakeUpgradeConnection() *fakeUpgradeConnection {
|
||||
return &fakeUpgradeConnection{
|
||||
streams: make(map[string]*fakeUpgradeStream),
|
||||
portData: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *fakeUpgradeConnection) CreateStream(headers http.Header) (httpstream.Stream, error) {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
|
||||
stream := &fakeUpgradeStream{}
|
||||
c.streams[headers.Get(api.PortHeader)] = stream
|
||||
stream.data = c.portData[headers.Get(api.PortHeader)]
|
||||
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
func (c *fakeUpgradeConnection) Close() error {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
|
||||
c.closeCalled = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *fakeUpgradeConnection) CloseChan() <-chan bool {
|
||||
return make(chan bool)
|
||||
}
|
||||
|
||||
func (c *fakeUpgradeConnection) SetIdleTimeout(timeout time.Duration) {
|
||||
}
|
||||
|
||||
type fakeUpgradeStream struct {
|
||||
readCalled bool
|
||||
writeCalled bool
|
||||
dataWritten []byte
|
||||
closeCalled bool
|
||||
resetCalled bool
|
||||
data string
|
||||
lock sync.Mutex
|
||||
}
|
||||
|
||||
func (s *fakeUpgradeStream) Read(p []byte) (int, error) {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
s.readCalled = true
|
||||
b := []byte(s.data)
|
||||
n := copy(p, b)
|
||||
return n, io.EOF
|
||||
}
|
||||
|
||||
func (s *fakeUpgradeStream) Write(p []byte) (int, error) {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
s.writeCalled = true
|
||||
s.dataWritten = make([]byte, len(p))
|
||||
copy(s.dataWritten, p)
|
||||
return len(p), io.EOF
|
||||
}
|
||||
|
||||
func (s *fakeUpgradeStream) Close() error {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
s.closeCalled = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *fakeUpgradeStream) Reset() error {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
s.resetCalled = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *fakeUpgradeStream) Headers() http.Header {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
return http.Header{}
|
||||
}
|
||||
|
||||
func TestForwardPorts(t *testing.T) {
|
||||
testCases := []struct {
|
||||
Upgrader *fakeUpgrader
|
||||
Ports []string
|
||||
Send map[uint16]string
|
||||
Receive map[uint16]string
|
||||
Err bool
|
||||
}{
|
||||
{
|
||||
Upgrader: &fakeUpgrader{err: errors.New("bail")},
|
||||
Err: true,
|
||||
},
|
||||
{
|
||||
Upgrader: &fakeUpgrader{conn: newFakeUpgradeConnection()},
|
||||
Ports: []string{"5000"},
|
||||
},
|
||||
{
|
||||
Upgrader: &fakeUpgrader{conn: newFakeUpgradeConnection()},
|
||||
Ports: []string{"5000", "6000"},
|
||||
Send: map[uint16]string{
|
||||
5000: "abcd",
|
||||
6000: "ghij",
|
||||
},
|
||||
Receive: map[uint16]string{
|
||||
5000: "1234",
|
||||
6000: "5678",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for i, testCase := range testCases {
|
||||
stopChan := make(chan struct{}, 1)
|
||||
|
||||
pf, err := New(&client.Request{}, &client.Config{}, testCase.Ports, stopChan)
|
||||
hasErr := err != nil
|
||||
if hasErr != testCase.Err {
|
||||
t.Fatalf("%d: New: expected %t, got %t: %v", i, testCase.Err, hasErr, err)
|
||||
}
|
||||
if pf == nil {
|
||||
continue
|
||||
}
|
||||
pf.upgrader = testCase.Upgrader
|
||||
if testCase.Upgrader.err != nil {
|
||||
err := pf.ForwardPorts()
|
||||
hasErr := err != nil
|
||||
if hasErr != testCase.Err {
|
||||
t.Fatalf("%d: ForwardPorts: expected %t, got %t: %v", i, testCase.Err, hasErr, err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
doneChan := make(chan error)
|
||||
go func() {
|
||||
doneChan <- pf.ForwardPorts()
|
||||
}()
|
||||
select {
|
||||
case <-pf.Ready:
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
t.Fatalf("%d: timed out waiting for listeners", i)
|
||||
}
|
||||
|
||||
conn := testCase.Upgrader.conn
|
||||
|
||||
for port, data := range testCase.Send {
|
||||
conn.lock.Lock()
|
||||
conn.portData[fmt.Sprintf("%d", port)] = testCase.Receive[port]
|
||||
conn.lock.Unlock()
|
||||
|
||||
clientConn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
|
||||
if err != nil {
|
||||
t.Fatalf("%d: error dialing %d: %s", i, port, err)
|
||||
}
|
||||
defer clientConn.Close()
|
||||
|
||||
n, err := clientConn.Write([]byte(data))
|
||||
if err != nil && err != io.EOF {
|
||||
t.Fatalf("%d: Error sending data '%s': %s", i, data, err)
|
||||
}
|
||||
if n == 0 {
|
||||
t.Fatalf("%d: unexpected write of 0 bytes", i)
|
||||
}
|
||||
b := make([]byte, 4)
|
||||
n, err = clientConn.Read(b)
|
||||
if err != nil && err != io.EOF {
|
||||
t.Fatalf("%d: Error reading data: %s", i, err)
|
||||
}
|
||||
if !bytes.Equal([]byte(testCase.Receive[port]), b) {
|
||||
t.Fatalf("%d: expected to read '%s', got '%s'", i, testCase.Receive[port], b)
|
||||
}
|
||||
}
|
||||
|
||||
// tell r.ForwardPorts to stop
|
||||
close(stopChan)
|
||||
|
||||
// wait for r.ForwardPorts to actually return
|
||||
select {
|
||||
case err := <-doneChan:
|
||||
if err != nil {
|
||||
t.Fatalf("%d: unexpected error: %s", err)
|
||||
}
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
t.Fatalf("%d: timeout waiting for ForwardPorts to finish")
|
||||
}
|
||||
|
||||
if e, a := len(testCase.Send), len(conn.streams); e != a {
|
||||
t.Fatalf("%d: expected %d streams to be created, got %d", e, a)
|
||||
}
|
||||
|
||||
if !conn.closeCalled {
|
||||
t.Fatalf("%d: expected conn closure", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
20
pkg/client/remotecommand/doc.go
Normal file
20
pkg/client/remotecommand/doc.go
Normal file
@@ -0,0 +1,20 @@
|
||||
/*
|
||||
Copyright 2015 Google Inc. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
// Package remotecommand adds support for executing commands in containers,
|
||||
// with support for separate stdin, stdout, and stderr streams, as well as
|
||||
// TTY.
|
||||
package remotecommand
|
||||
186
pkg/client/remotecommand/remotecommand.go
Normal file
186
pkg/client/remotecommand/remotecommand.go
Normal file
@@ -0,0 +1,186 @@
|
||||
/*
|
||||
Copyright 2015 Google Inc. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package remotecommand
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/api"
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/client"
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/util/httpstream"
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/util/httpstream/spdy"
|
||||
"github.com/golang/glog"
|
||||
)
|
||||
|
||||
type upgrader interface {
|
||||
upgrade(*client.Request, *client.Config) (httpstream.Connection, error)
|
||||
}
|
||||
|
||||
type defaultUpgrader struct{}
|
||||
|
||||
func (u *defaultUpgrader) upgrade(req *client.Request, config *client.Config) (httpstream.Connection, error) {
|
||||
return req.Upgrade(config, spdy.NewRoundTripper)
|
||||
}
|
||||
|
||||
type RemoteCommandExecutor struct {
|
||||
req *client.Request
|
||||
config *client.Config
|
||||
command []string
|
||||
stdin io.Reader
|
||||
stdout io.Writer
|
||||
stderr io.Writer
|
||||
tty bool
|
||||
|
||||
upgrader upgrader
|
||||
}
|
||||
|
||||
func New(req *client.Request, config *client.Config, command []string, stdin io.Reader, stdout, stderr io.Writer, tty bool) *RemoteCommandExecutor {
|
||||
return &RemoteCommandExecutor{
|
||||
req: req,
|
||||
config: config,
|
||||
command: command,
|
||||
stdin: stdin,
|
||||
stdout: stdout,
|
||||
stderr: stderr,
|
||||
tty: tty,
|
||||
}
|
||||
}
|
||||
|
||||
// Execute sends a remote command execution request, upgrading the
|
||||
// connection and creating streams to represent stdin/stdout/stderr. Data is
|
||||
// copied between these streams and the supplied stdin/stdout/stderr parameters.
|
||||
func (e *RemoteCommandExecutor) Execute() error {
|
||||
doStdin := (e.stdin != nil)
|
||||
doStdout := (e.stdout != nil)
|
||||
doStderr := (!e.tty && e.stderr != nil)
|
||||
|
||||
if doStdin {
|
||||
e.req.Param(api.ExecStdinParam, "1")
|
||||
}
|
||||
if doStdout {
|
||||
e.req.Param(api.ExecStdoutParam, "1")
|
||||
}
|
||||
if doStderr {
|
||||
e.req.Param(api.ExecStderrParam, "1")
|
||||
}
|
||||
if e.tty {
|
||||
e.req.Param(api.ExecTTYParam, "1")
|
||||
}
|
||||
|
||||
for _, s := range e.command {
|
||||
e.req.Param(api.ExecCommandParamm, s)
|
||||
}
|
||||
|
||||
if e.upgrader == nil {
|
||||
e.upgrader = &defaultUpgrader{}
|
||||
}
|
||||
conn, err := e.upgrader.upgrade(e.req, e.config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
doneChan := make(chan struct{}, 2)
|
||||
errorChan := make(chan error)
|
||||
|
||||
cp := func(s string, dst io.Writer, src io.Reader) {
|
||||
glog.V(4).Infof("Copying %s", s)
|
||||
defer glog.V(4).Infof("Done copying %s", s)
|
||||
if _, err := io.Copy(dst, src); err != nil && err != io.EOF {
|
||||
glog.Errorf("Error copying %s: %v", s, err)
|
||||
}
|
||||
if s == api.StreamTypeStdout || s == api.StreamTypeStderr {
|
||||
doneChan <- struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
headers := http.Header{}
|
||||
headers.Set(api.StreamType, api.StreamTypeError)
|
||||
errorStream, err := conn.CreateStream(headers)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
go func() {
|
||||
message, err := ioutil.ReadAll(errorStream)
|
||||
if err != nil && err != io.EOF {
|
||||
errorChan <- fmt.Errorf("Error reading from error stream: %s", err)
|
||||
return
|
||||
}
|
||||
if len(message) > 0 {
|
||||
errorChan <- fmt.Errorf("Error executing remote command: %s", message)
|
||||
return
|
||||
}
|
||||
}()
|
||||
defer errorStream.Reset()
|
||||
|
||||
if doStdin {
|
||||
headers.Set(api.StreamType, api.StreamTypeStdin)
|
||||
remoteStdin, err := conn.CreateStream(headers)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer remoteStdin.Reset()
|
||||
// TODO this goroutine will never exit cleanly (the io.Copy never unblocks)
|
||||
// because stdin is not closed until the process exits. If we try to call
|
||||
// stdin.Close(), it returns no error but doesn't unblock the copy. It will
|
||||
// exit when the process exits, instead.
|
||||
go cp(api.StreamTypeStdin, remoteStdin, e.stdin)
|
||||
}
|
||||
|
||||
waitCount := 0
|
||||
completedStreams := 0
|
||||
|
||||
if doStdout {
|
||||
waitCount++
|
||||
headers.Set(api.StreamType, api.StreamTypeStdout)
|
||||
remoteStdout, err := conn.CreateStream(headers)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer remoteStdout.Reset()
|
||||
go cp(api.StreamTypeStdout, e.stdout, remoteStdout)
|
||||
}
|
||||
|
||||
if doStderr && !e.tty {
|
||||
waitCount++
|
||||
headers.Set(api.StreamType, api.StreamTypeStderr)
|
||||
remoteStderr, err := conn.CreateStream(headers)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer remoteStderr.Reset()
|
||||
go cp(api.StreamTypeStderr, e.stderr, remoteStderr)
|
||||
}
|
||||
|
||||
Loop:
|
||||
for {
|
||||
select {
|
||||
case <-doneChan:
|
||||
completedStreams++
|
||||
if completedStreams == waitCount {
|
||||
break Loop
|
||||
}
|
||||
case err := <-errorChan:
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
288
pkg/client/remotecommand/remotecommand_test.go
Normal file
288
pkg/client/remotecommand/remotecommand_test.go
Normal file
@@ -0,0 +1,288 @@
|
||||
/*
|
||||
Copyright 2015 Google Inc. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package remotecommand
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/api"
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/client"
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/util/httpstream"
|
||||
)
|
||||
|
||||
type fakeUpgrader struct {
|
||||
conn *fakeUpgradeConnection
|
||||
err error
|
||||
}
|
||||
|
||||
func (u *fakeUpgrader) upgrade(req *client.Request, config *client.Config) (httpstream.Connection, error) {
|
||||
return u.conn, u.err
|
||||
}
|
||||
|
||||
type fakeUpgradeConnection struct {
|
||||
closeCalled bool
|
||||
lock sync.Mutex
|
||||
|
||||
stdin *fakeUpgradeStream
|
||||
stdout *fakeUpgradeStream
|
||||
stdoutData string
|
||||
stderr *fakeUpgradeStream
|
||||
stderrData string
|
||||
errorStream *fakeUpgradeStream
|
||||
errorData string
|
||||
unexpectedStreamCreated bool
|
||||
}
|
||||
|
||||
func newFakeUpgradeConnection() *fakeUpgradeConnection {
|
||||
return &fakeUpgradeConnection{}
|
||||
}
|
||||
|
||||
func (c *fakeUpgradeConnection) CreateStream(headers http.Header) (httpstream.Stream, error) {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
|
||||
stream := &fakeUpgradeStream{}
|
||||
switch headers.Get(api.StreamType) {
|
||||
case api.StreamTypeStdin:
|
||||
c.stdin = stream
|
||||
case api.StreamTypeStdout:
|
||||
c.stdout = stream
|
||||
stream.data = c.stdoutData
|
||||
case api.StreamTypeStderr:
|
||||
c.stderr = stream
|
||||
stream.data = c.stderrData
|
||||
case api.StreamTypeError:
|
||||
c.errorStream = stream
|
||||
stream.data = c.errorData
|
||||
default:
|
||||
c.unexpectedStreamCreated = true
|
||||
}
|
||||
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
func (c *fakeUpgradeConnection) Close() error {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
|
||||
c.closeCalled = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *fakeUpgradeConnection) CloseChan() <-chan bool {
|
||||
return make(chan bool)
|
||||
}
|
||||
|
||||
func (c *fakeUpgradeConnection) SetIdleTimeout(timeout time.Duration) {
|
||||
}
|
||||
|
||||
type fakeUpgradeStream struct {
|
||||
readCalled bool
|
||||
writeCalled bool
|
||||
dataWritten []byte
|
||||
closeCalled bool
|
||||
resetCalled bool
|
||||
data string
|
||||
lock sync.Mutex
|
||||
}
|
||||
|
||||
func (s *fakeUpgradeStream) Read(p []byte) (int, error) {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
s.readCalled = true
|
||||
b := []byte(s.data)
|
||||
n := copy(p, b)
|
||||
return n, io.EOF
|
||||
}
|
||||
|
||||
func (s *fakeUpgradeStream) Write(p []byte) (int, error) {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
s.writeCalled = true
|
||||
s.dataWritten = make([]byte, len(p))
|
||||
copy(s.dataWritten, p)
|
||||
return len(p), io.EOF
|
||||
}
|
||||
|
||||
func (s *fakeUpgradeStream) Close() error {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
s.closeCalled = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *fakeUpgradeStream) Reset() error {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
s.resetCalled = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *fakeUpgradeStream) Headers() http.Header {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
return http.Header{}
|
||||
}
|
||||
|
||||
func TestRequestExecuteRemoteCommand(t *testing.T) {
|
||||
testCases := []struct {
|
||||
Upgrader *fakeUpgrader
|
||||
Stdin string
|
||||
Stdout string
|
||||
Stderr string
|
||||
Error string
|
||||
Tty bool
|
||||
ShouldError bool
|
||||
}{
|
||||
{
|
||||
Upgrader: &fakeUpgrader{err: errors.New("bail")},
|
||||
ShouldError: true,
|
||||
},
|
||||
{
|
||||
Upgrader: &fakeUpgrader{conn: newFakeUpgradeConnection()},
|
||||
Stdin: "a",
|
||||
Stdout: "b",
|
||||
Stderr: "c",
|
||||
Error: "bail",
|
||||
ShouldError: true,
|
||||
},
|
||||
{
|
||||
Upgrader: &fakeUpgrader{conn: newFakeUpgradeConnection()},
|
||||
Stdin: "a",
|
||||
Stdout: "b",
|
||||
Stderr: "c",
|
||||
},
|
||||
{
|
||||
Upgrader: &fakeUpgrader{conn: newFakeUpgradeConnection()},
|
||||
Stdin: "a",
|
||||
Stdout: "b",
|
||||
Stderr: "c",
|
||||
Tty: true,
|
||||
},
|
||||
}
|
||||
|
||||
for i, testCase := range testCases {
|
||||
if testCase.Error != "" {
|
||||
testCase.Upgrader.conn.errorData = testCase.Error
|
||||
}
|
||||
if testCase.Stdout != "" {
|
||||
testCase.Upgrader.conn.stdoutData = testCase.Stdout
|
||||
}
|
||||
if testCase.Stderr != "" {
|
||||
testCase.Upgrader.conn.stderrData = testCase.Stderr
|
||||
}
|
||||
var localOut, localErr *bytes.Buffer
|
||||
if testCase.Stdout != "" {
|
||||
localOut = &bytes.Buffer{}
|
||||
}
|
||||
if testCase.Stderr != "" {
|
||||
localErr = &bytes.Buffer{}
|
||||
}
|
||||
e := New(&client.Request{}, &client.Config{}, []string{"ls", "/"}, strings.NewReader(testCase.Stdin), localOut, localErr, testCase.Tty)
|
||||
e.upgrader = testCase.Upgrader
|
||||
err := e.Execute()
|
||||
hasErr := err != nil
|
||||
if hasErr != testCase.ShouldError {
|
||||
t.Fatalf("%d: expected %t, got %t: %v", i, testCase.ShouldError, hasErr, err)
|
||||
}
|
||||
|
||||
conn := testCase.Upgrader.conn
|
||||
if testCase.Error != "" {
|
||||
if conn.errorStream == nil {
|
||||
t.Fatalf("%d: expected error stream creation", i)
|
||||
}
|
||||
if !conn.errorStream.readCalled {
|
||||
t.Fatalf("%d: expected error stream read", i)
|
||||
}
|
||||
if e, a := testCase.Error, err.Error(); !strings.Contains(a, e) {
|
||||
t.Fatalf("%d: expected error stream read '%v', got '%v'", i, e, a)
|
||||
}
|
||||
if !conn.errorStream.resetCalled {
|
||||
t.Fatalf("%d: expected error reset", i)
|
||||
}
|
||||
}
|
||||
|
||||
if testCase.ShouldError {
|
||||
continue
|
||||
}
|
||||
|
||||
if testCase.Stdin != "" {
|
||||
if conn.stdin == nil {
|
||||
t.Fatalf("%d: expected stdin stream creation", i)
|
||||
}
|
||||
if !conn.stdin.writeCalled {
|
||||
t.Fatalf("%d: expected stdin stream write", i)
|
||||
}
|
||||
if e, a := testCase.Stdin, string(conn.stdin.dataWritten); e != a {
|
||||
t.Fatalf("%d: expected stdin write %v, got %v", i, e, a)
|
||||
}
|
||||
if !conn.stdin.resetCalled {
|
||||
t.Fatalf("%d: expected stdin reset", i)
|
||||
}
|
||||
}
|
||||
|
||||
if testCase.Stdout != "" {
|
||||
if conn.stdout == nil {
|
||||
t.Fatalf("%d: expected stdout stream creation", i)
|
||||
}
|
||||
if !conn.stdout.readCalled {
|
||||
t.Fatalf("%d: expected stdout stream read", i)
|
||||
}
|
||||
if e, a := testCase.Stdout, localOut; e != a.String() {
|
||||
t.Fatalf("%d: expected stdout data '%s', got '%s'", i, e, a)
|
||||
}
|
||||
if !conn.stdout.resetCalled {
|
||||
t.Fatalf("%d: expected stdout reset", i)
|
||||
}
|
||||
}
|
||||
|
||||
if testCase.Stderr != "" {
|
||||
if testCase.Tty {
|
||||
if conn.stderr != nil {
|
||||
t.Fatalf("%d: unexpected stderr stream creation", i)
|
||||
}
|
||||
if localErr.String() != "" {
|
||||
t.Fatalf("%d: unexpected stderr data '%s'", i, localErr)
|
||||
}
|
||||
} else {
|
||||
if conn.stderr == nil {
|
||||
t.Fatalf("%d: expected stderr stream creation", i)
|
||||
}
|
||||
if !conn.stderr.readCalled {
|
||||
t.Fatalf("%d: expected stderr stream read", i)
|
||||
}
|
||||
if e, a := testCase.Stderr, localErr; e != a.String() {
|
||||
t.Fatalf("%d: expected stderr data '%s', got '%s'", i, e, a)
|
||||
}
|
||||
if !conn.stderr.resetCalled {
|
||||
t.Fatalf("%d: expected stderr reset", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !conn.closeCalled {
|
||||
t.Fatalf("%d: expected upgraded connection to get closed")
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -18,6 +18,7 @@ package client
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
@@ -33,6 +34,7 @@ import (
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/labels"
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/runtime"
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/util"
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/util/httpstream"
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/watch"
|
||||
watchjson "github.com/GoogleCloudPlatform/kubernetes/pkg/watch/json"
|
||||
"github.com/golang/glog"
|
||||
@@ -277,7 +279,7 @@ func (r *Request) setParam(paramName, value string) *Request {
|
||||
if r.params == nil {
|
||||
r.params = make(url.Values)
|
||||
}
|
||||
r.params[paramName] = []string{value}
|
||||
r.params[paramName] = append(r.params[paramName], value)
|
||||
return r
|
||||
}
|
||||
|
||||
@@ -347,8 +349,10 @@ func (r *Request) finalURL() string {
|
||||
finalURL.Path = p
|
||||
|
||||
query := url.Values{}
|
||||
for key, value := range r.params {
|
||||
query[key] = value
|
||||
for key, values := range r.params {
|
||||
for _, value := range values {
|
||||
query.Add(key, value)
|
||||
}
|
||||
}
|
||||
|
||||
if r.namespaceSet && r.namespaceInQuery {
|
||||
@@ -434,6 +438,41 @@ func (r *Request) Stream() (io.ReadCloser, error) {
|
||||
return resp.Body, nil
|
||||
}
|
||||
|
||||
// Upgrade upgrades the request so that it supports multiplexed bidirectional
|
||||
// streams. The current implementation uses SPDY, but this could be replaced
|
||||
// with HTTP/2 once it's available, or something else.
|
||||
func (r *Request) Upgrade(config *Config, newRoundTripperFunc func(*tls.Config) httpstream.UpgradeRoundTripper) (httpstream.Connection, error) {
|
||||
if r.err != nil {
|
||||
return nil, r.err
|
||||
}
|
||||
|
||||
tlsConfig, err := TLSConfigFor(config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
upgradeRoundTripper := newRoundTripperFunc(tlsConfig)
|
||||
wrapper, err := HTTPWrappersForConfig(config, upgradeRoundTripper)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r.client = &http.Client{Transport: wrapper}
|
||||
|
||||
req, err := http.NewRequest(r.verb, r.finalURL(), nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Error creating request: %s", err)
|
||||
}
|
||||
|
||||
resp, err := r.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Error sending request: %s", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
return upgradeRoundTripper.NewConnection(resp)
|
||||
}
|
||||
|
||||
// Do formats and executes the request. Returns a Result object for easy response
|
||||
// processing.
|
||||
//
|
||||
@@ -513,6 +552,8 @@ func (r *Request) transformResponse(resp *http.Response, req *http.Request) ([]b
|
||||
}
|
||||
|
||||
switch {
|
||||
case resp.StatusCode == http.StatusSwitchingProtocols:
|
||||
// no-op, we've been upgraded
|
||||
case resp.StatusCode < http.StatusOK || resp.StatusCode > http.StatusPartialContent:
|
||||
if !isStatusResponse {
|
||||
var err error = &UnexpectedStatusError{
|
||||
|
||||
@@ -18,6 +18,7 @@ package client
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"io"
|
||||
@@ -40,6 +41,7 @@ import (
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/labels"
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/runtime"
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/util"
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/util/httpstream"
|
||||
"github.com/GoogleCloudPlatform/kubernetes/pkg/watch"
|
||||
watchjson "github.com/GoogleCloudPlatform/kubernetes/pkg/watch/json"
|
||||
)
|
||||
@@ -151,16 +153,22 @@ func TestRequestParam(t *testing.T) {
|
||||
if !api.Semantic.DeepDerivative(r.params, url.Values{"foo": []string{"a"}}) {
|
||||
t.Errorf("should have set a param: %#v", r)
|
||||
}
|
||||
|
||||
r.Param("bar", "1")
|
||||
r.Param("bar", "2")
|
||||
if !api.Semantic.DeepDerivative(r.params, url.Values{"foo": []string{"a"}, "bar": []string{"1", "2"}}) {
|
||||
t.Errorf("should have set a param: %#v", r)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestURI(t *testing.T) {
|
||||
r := (&Request{}).Param("foo", "a")
|
||||
r.Prefix("other")
|
||||
r.RequestURI("/test?foo=b&a=b")
|
||||
r.RequestURI("/test?foo=b&a=b&c=1&c=2")
|
||||
if r.path != "/test" {
|
||||
t.Errorf("path is wrong: %#v", r)
|
||||
}
|
||||
if !api.Semantic.DeepDerivative(r.params, url.Values{"a": []string{"b"}, "foo": []string{"b"}}) {
|
||||
if !api.Semantic.DeepDerivative(r.params, url.Values{"a": []string{"b"}, "foo": []string{"b"}, "c": []string{"1", "2"}}) {
|
||||
t.Errorf("should have set a param: %#v", r)
|
||||
}
|
||||
}
|
||||
@@ -443,6 +451,122 @@ func TestRequestStream(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
type fakeUpgradeConnection struct{}
|
||||
|
||||
func (c *fakeUpgradeConnection) CreateStream(headers http.Header) (httpstream.Stream, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (c *fakeUpgradeConnection) Close() error {
|
||||
return nil
|
||||
}
|
||||
func (c *fakeUpgradeConnection) CloseChan() <-chan bool {
|
||||
return make(chan bool)
|
||||
}
|
||||
func (c *fakeUpgradeConnection) SetIdleTimeout(timeout time.Duration) {
|
||||
}
|
||||
|
||||
type fakeUpgradeRoundTripper struct {
|
||||
req *http.Request
|
||||
conn httpstream.Connection
|
||||
}
|
||||
|
||||
func (f *fakeUpgradeRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
f.req = req
|
||||
b := []byte{}
|
||||
body := ioutil.NopCloser(bytes.NewReader(b))
|
||||
resp := &http.Response{
|
||||
StatusCode: 101,
|
||||
Body: body,
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (f *fakeUpgradeRoundTripper) NewConnection(resp *http.Response) (httpstream.Connection, error) {
|
||||
return f.conn, nil
|
||||
}
|
||||
|
||||
func TestRequestUpgrade(t *testing.T) {
|
||||
uri, _ := url.Parse("http://localhost/")
|
||||
testCases := []struct {
|
||||
Request *Request
|
||||
Config *Config
|
||||
RoundTripper *fakeUpgradeRoundTripper
|
||||
Err bool
|
||||
AuthBasicHeader bool
|
||||
AuthBearerHeader bool
|
||||
}{
|
||||
{
|
||||
Request: &Request{err: errors.New("bail")},
|
||||
Err: true,
|
||||
},
|
||||
{
|
||||
Request: &Request{},
|
||||
Config: &Config{
|
||||
TLSClientConfig: TLSClientConfig{
|
||||
CAFile: "foo",
|
||||
},
|
||||
Insecure: true,
|
||||
},
|
||||
Err: true,
|
||||
},
|
||||
{
|
||||
Request: &Request{},
|
||||
Config: &Config{
|
||||
Username: "u",
|
||||
Password: "p",
|
||||
BearerToken: "b",
|
||||
},
|
||||
Err: true,
|
||||
},
|
||||
{
|
||||
Request: NewRequest(nil, "", uri, testapi.Codec(), true, true),
|
||||
Config: &Config{
|
||||
Username: "u",
|
||||
Password: "p",
|
||||
},
|
||||
AuthBasicHeader: true,
|
||||
Err: false,
|
||||
},
|
||||
{
|
||||
Request: NewRequest(nil, "", uri, testapi.Codec(), true, true),
|
||||
Config: &Config{
|
||||
BearerToken: "b",
|
||||
},
|
||||
AuthBearerHeader: true,
|
||||
Err: false,
|
||||
},
|
||||
}
|
||||
for i, testCase := range testCases {
|
||||
r := testCase.Request
|
||||
rt := &fakeUpgradeRoundTripper{}
|
||||
expectedConn := &fakeUpgradeConnection{}
|
||||
conn, err := r.Upgrade(testCase.Config, func(config *tls.Config) httpstream.UpgradeRoundTripper {
|
||||
rt.conn = expectedConn
|
||||
return rt
|
||||
})
|
||||
_ = conn
|
||||
hasErr := err != nil
|
||||
if hasErr != testCase.Err {
|
||||
t.Errorf("%d: expected %t, got %t: %v", i, testCase.Err, hasErr, r.err)
|
||||
}
|
||||
if testCase.Err {
|
||||
continue
|
||||
}
|
||||
|
||||
if testCase.AuthBasicHeader && !strings.Contains(rt.req.Header.Get("Authorization"), "Basic") {
|
||||
t.Errorf("%d: expected basic auth header, got: %s", rt.req.Header.Get("Authorization"))
|
||||
}
|
||||
|
||||
if testCase.AuthBearerHeader && !strings.Contains(rt.req.Header.Get("Authorization"), "Bearer") {
|
||||
t.Errorf("%d: expected bearer auth header, got: %s", rt.req.Header.Get("Authorization"))
|
||||
}
|
||||
|
||||
if e, a := expectedConn, conn; e != a {
|
||||
t.Errorf("%d: conn: expected %#v, got %#v", i, e, a)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestDo(t *testing.T) {
|
||||
testCases := []struct {
|
||||
Request *Request
|
||||
|
||||
Reference in New Issue
Block a user