Port forward over websockets

- split out port forwarding into its own package

Allow multiple port forwarding ports
- Make it easy to determine which port is tied to which channel
- odd channels are for data
- even channels are for errors

- allow comma separated ports to specify multiple ports

Add  portfowardtester 1.2 to whitelist
This commit is contained in:
Michael Fraenkel 2016-10-17 16:50:20 +08:00
parent 96cfe7b938
commit beb53fb71a
17 changed files with 1167 additions and 364 deletions

View File

@ -27,7 +27,7 @@ spec:
command:
- /bin/sh
- -c
- "for i in gcr.io/google_containers/busybox gcr.io/google_containers/busybox:1.24 gcr.io/google_containers/dnsutils:e2e gcr.io/google_containers/eptest:0.1 gcr.io/google_containers/fakegitserver:0.1 gcr.io/google_containers/hostexec:1.2 gcr.io/google_containers/iperf:e2e gcr.io/google_containers/jessie-dnsutils:e2e gcr.io/google_containers/liveness:e2e gcr.io/google_containers/mounttest:0.7 gcr.io/google_containers/mounttest-user:0.3 gcr.io/google_containers/netexec:1.4 gcr.io/google_containers/netexec:1.7 gcr.io/google_containers/nettest:1.7 gcr.io/google_containers/nettest:1.8 gcr.io/google_containers/nginx-slim:0.7 gcr.io/google_containers/nginx-slim:0.8 gcr.io/google_containers/n-way-http:1.0 gcr.io/google_containers/pause:2.0 gcr.io/google_containers/pause-amd64:3.0 gcr.io/google_containers/porter:cd5cb5791ebaa8641955f0e8c2a9bed669b1eaab gcr.io/google_containers/portforwardtester:1.0 gcr.io/google_containers/redis:e2e gcr.io/google_containers/resource_consumer:beta4 gcr.io/google_containers/resource_consumer/controller:beta4 gcr.io/google_containers/serve_hostname:v1.4 gcr.io/google_containers/test-webserver:e2e gcr.io/google_containers/ubuntu:14.04 gcr.io/google_containers/update-demo:kitten gcr.io/google_containers/update-demo:nautilus gcr.io/google_containers/volume-ceph:0.1 gcr.io/google_containers/volume-gluster:0.2 gcr.io/google_containers/volume-iscsi:0.1 gcr.io/google_containers/volume-nfs:0.6 gcr.io/google_containers/volume-rbd:0.1 gcr.io/google_samples/gb-redisslave:v1 gcr.io/google_containers/redis:v1; do echo $(date '+%X') pulling $i; docker pull $i 1>/dev/null; done; exit 0;"
- "for i in gcr.io/google_containers/busybox gcr.io/google_containers/busybox:1.24 gcr.io/google_containers/dnsutils:e2e gcr.io/google_containers/eptest:0.1 gcr.io/google_containers/fakegitserver:0.1 gcr.io/google_containers/hostexec:1.2 gcr.io/google_containers/iperf:e2e gcr.io/google_containers/jessie-dnsutils:e2e gcr.io/google_containers/liveness:e2e gcr.io/google_containers/mounttest:0.7 gcr.io/google_containers/mounttest-user:0.3 gcr.io/google_containers/netexec:1.4 gcr.io/google_containers/netexec:1.7 gcr.io/google_containers/nettest:1.7 gcr.io/google_containers/nettest:1.8 gcr.io/google_containers/nginx-slim:0.7 gcr.io/google_containers/nginx-slim:0.8 gcr.io/google_containers/n-way-http:1.0 gcr.io/google_containers/pause:2.0 gcr.io/google_containers/pause-amd64:3.0 gcr.io/google_containers/porter:cd5cb5791ebaa8641955f0e8c2a9bed669b1eaab gcr.io/google_containers/portforwardtester:1.2 gcr.io/google_containers/redis:e2e gcr.io/google_containers/resource_consumer:beta4 gcr.io/google_containers/resource_consumer/controller:beta4 gcr.io/google_containers/serve_hostname:v1.4 gcr.io/google_containers/test-webserver:e2e gcr.io/google_containers/ubuntu:14.04 gcr.io/google_containers/update-demo:kitten gcr.io/google_containers/update-demo:nautilus gcr.io/google_containers/volume-ceph:0.1 gcr.io/google_containers/volume-gluster:0.2 gcr.io/google_containers/volume-iscsi:0.1 gcr.io/google_containers/volume-nfs:0.6 gcr.io/google_containers/volume-rbd:0.1 gcr.io/google_samples/gb-redisslave:v1 gcr.io/google_containers/redis:v1; do echo $(date '+%X') pulling $i; docker pull $i 1>/dev/null; done; exit 0;"
securityContext:
privileged: true
volumeMounts:

View File

@ -84,7 +84,7 @@ func fakePortForwardServer(t *testing.T, testName string, serverSends, expectedF
received: make(map[uint16]string),
send: serverSends,
}
portforward.ServePortForward(w, req, pf, "pod", "uid", 0, 10*time.Second)
portforward.ServePortForward(w, req, pf, "pod", "uid", 0, 10*time.Second, portforward.SupportedProtocols)
for port, expected := range expectedFromClient {
actual, ok := pf.received[port]

View File

@ -1047,10 +1047,14 @@ func typeToJSON(typeName string) string {
return "string"
case "byte", "*byte":
return "string"
// TODO: Fix these when go-restful supports a way to specify an array query param:
// https://github.com/emicklei/go-restful/issues/225
case "[]string", "[]*string":
// TODO: Fix this when go-restful supports a way to specify an array query param:
// https://github.com/emicklei/go-restful/issues/225
return "string"
case "[]int32", "[]*int32":
return "integer"
default:
return typeName
}

View File

@ -2171,9 +2171,10 @@ func getStreamingConfig(kubeCfg *componentconfig.KubeletConfiguration, kubeDeps
BaseURL: &url.URL{
Path: "/cri/",
},
StreamIdleTimeout: kubeCfg.StreamingConnectionIdleTimeout.Duration,
StreamCreationTimeout: streaming.DefaultConfig.StreamCreationTimeout,
SupportedProtocols: streaming.DefaultConfig.SupportedProtocols,
StreamIdleTimeout: kubeCfg.StreamingConnectionIdleTimeout.Duration,
StreamCreationTimeout: streaming.DefaultConfig.StreamCreationTimeout,
SupportedRemoteCommandProtocols: streaming.DefaultConfig.SupportedRemoteCommandProtocols,
SupportedPortForwardProtocols: streaming.DefaultConfig.SupportedPortForwardProtocols,
}
if kubeDeps.TLSOptions != nil {
config.TLSConfig = kubeDeps.TLSOptions.Config

View File

@ -18,4 +18,6 @@ limitations under the License.
package portforward
// The subprotocol "portforward.k8s.io" is used for port forwarding.
const PortForwardProtocolV1Name = "portforward.k8s.io"
const ProtocolV1Name = "portforward.k8s.io"
var SupportedProtocols = []string{ProtocolV1Name}

View File

@ -0,0 +1,309 @@
/*
Copyright 2016 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package portforward
import (
"errors"
"fmt"
"net/http"
"strconv"
"sync"
"time"
"k8s.io/apimachinery/pkg/types"
"k8s.io/apimachinery/pkg/util/httpstream"
"k8s.io/apimachinery/pkg/util/httpstream/spdy"
utilruntime "k8s.io/apimachinery/pkg/util/runtime"
"k8s.io/kubernetes/pkg/api"
"github.com/golang/glog"
)
func handleHttpStreams(req *http.Request, w http.ResponseWriter, portForwarder PortForwarder, podName string, uid types.UID, supportedPortForwardProtocols []string, idleTimeout, streamCreationTimeout time.Duration) error {
_, err := httpstream.Handshake(req, w, supportedPortForwardProtocols)
// negotiated protocol isn't currently used server side, but could be in the future
if err != nil {
// Handshake writes the error to the client
return err
}
streamChan := make(chan httpstream.Stream, 1)
glog.V(5).Infof("Upgrading port forward response")
upgrader := spdy.NewResponseUpgrader()
conn := upgrader.UpgradeResponse(w, req, httpStreamReceived(streamChan))
if conn == nil {
return errors.New("Unable to upgrade websocket connection")
}
defer conn.Close()
glog.V(5).Infof("(conn=%p) setting port forwarding streaming connection idle timeout to %v", conn, idleTimeout)
conn.SetIdleTimeout(idleTimeout)
h := &httpStreamHandler{
conn: conn,
streamChan: streamChan,
streamPairs: make(map[string]*httpStreamPair),
streamCreationTimeout: streamCreationTimeout,
pod: podName,
uid: uid,
forwarder: portForwarder,
}
h.run()
return nil
}
// httpStreamReceived is the httpstream.NewStreamHandler for port
// forward streams. It checks each stream's port and stream type headers,
// rejecting any streams that with missing or invalid values. Each valid
// stream is sent to the streams channel.
func httpStreamReceived(streams chan httpstream.Stream) func(httpstream.Stream, <-chan struct{}) error {
return func(stream httpstream.Stream, replySent <-chan struct{}) error {
// make sure it has a valid port header
portString := stream.Headers().Get(api.PortHeader)
if len(portString) == 0 {
return fmt.Errorf("%q header is required", api.PortHeader)
}
port, err := strconv.ParseUint(portString, 10, 16)
if err != nil {
return fmt.Errorf("unable to parse %q as a port: %v", portString, err)
}
if port < 1 {
return fmt.Errorf("port %q must be > 0", portString)
}
// make sure it has a valid stream type header
streamType := stream.Headers().Get(api.StreamType)
if len(streamType) == 0 {
return fmt.Errorf("%q header is required", api.StreamType)
}
if streamType != api.StreamTypeError && streamType != api.StreamTypeData {
return fmt.Errorf("invalid stream type %q", streamType)
}
streams <- stream
return nil
}
}
// httpStreamHandler is capable of processing multiple port forward
// requests over a single httpstream.Connection.
type httpStreamHandler struct {
conn httpstream.Connection
streamChan chan httpstream.Stream
streamPairsLock sync.RWMutex
streamPairs map[string]*httpStreamPair
streamCreationTimeout time.Duration
pod string
uid types.UID
forwarder PortForwarder
}
// getStreamPair returns a httpStreamPair for requestID. This creates a
// new pair if one does not yet exist for the requestID. The returned bool is
// true if the pair was created.
func (h *httpStreamHandler) getStreamPair(requestID string) (*httpStreamPair, bool) {
h.streamPairsLock.Lock()
defer h.streamPairsLock.Unlock()
if p, ok := h.streamPairs[requestID]; ok {
glog.V(5).Infof("(conn=%p, request=%s) found existing stream pair", h.conn, requestID)
return p, false
}
glog.V(5).Infof("(conn=%p, request=%s) creating new stream pair", h.conn, requestID)
p := newPortForwardPair(requestID)
h.streamPairs[requestID] = p
return p, true
}
// monitorStreamPair waits for the pair to receive both its error and data
// streams, or for the timeout to expire (whichever happens first), and then
// removes the pair.
func (h *httpStreamHandler) monitorStreamPair(p *httpStreamPair, timeout <-chan time.Time) {
select {
case <-timeout:
err := fmt.Errorf("(conn=%v, request=%s) timed out waiting for streams", h.conn, p.requestID)
utilruntime.HandleError(err)
p.printError(err.Error())
case <-p.complete:
glog.V(5).Infof("(conn=%v, request=%s) successfully received error and data streams", h.conn, p.requestID)
}
h.removeStreamPair(p.requestID)
}
// hasStreamPair returns a bool indicating if a stream pair for requestID
// exists.
func (h *httpStreamHandler) hasStreamPair(requestID string) bool {
h.streamPairsLock.RLock()
defer h.streamPairsLock.RUnlock()
_, ok := h.streamPairs[requestID]
return ok
}
// removeStreamPair removes the stream pair identified by requestID from streamPairs.
func (h *httpStreamHandler) removeStreamPair(requestID string) {
h.streamPairsLock.Lock()
defer h.streamPairsLock.Unlock()
delete(h.streamPairs, requestID)
}
// requestID returns the request id for stream.
func (h *httpStreamHandler) requestID(stream httpstream.Stream) string {
requestID := stream.Headers().Get(api.PortForwardRequestIDHeader)
if len(requestID) == 0 {
glog.V(5).Infof("(conn=%p) stream received without %s header", h.conn, api.PortForwardRequestIDHeader)
// If we get here, it's because the connection came from an older client
// that isn't generating the request id header
// (https://github.com/kubernetes/kubernetes/blob/843134885e7e0b360eb5441e85b1410a8b1a7a0c/pkg/client/unversioned/portforward/portforward.go#L258-L287)
//
// This is a best-effort attempt at supporting older clients.
//
// When there aren't concurrent new forwarded connections, each connection
// will have a pair of streams (data, error), and the stream IDs will be
// consecutive odd numbers, e.g. 1 and 3 for the first connection. Convert
// the stream ID into a pseudo-request id by taking the stream type and
// using id = stream.Identifier() when the stream type is error,
// and id = stream.Identifier() - 2 when it's data.
//
// NOTE: this only works when there are not concurrent new streams from
// multiple forwarded connections; it's a best-effort attempt at supporting
// old clients that don't generate request ids. If there are concurrent
// new connections, it's possible that 1 connection gets streams whose IDs
// are not consecutive (e.g. 5 and 9 instead of 5 and 7).
streamType := stream.Headers().Get(api.StreamType)
switch streamType {
case api.StreamTypeError:
requestID = strconv.Itoa(int(stream.Identifier()))
case api.StreamTypeData:
requestID = strconv.Itoa(int(stream.Identifier()) - 2)
}
glog.V(5).Infof("(conn=%p) automatically assigning request ID=%q from stream type=%s, stream ID=%d", h.conn, requestID, streamType, stream.Identifier())
}
return requestID
}
// run is the main loop for the httpStreamHandler. It processes new
// streams, invoking portForward for each complete stream pair. The loop exits
// when the httpstream.Connection is closed.
func (h *httpStreamHandler) run() {
glog.V(5).Infof("(conn=%p) waiting for port forward streams", h.conn)
Loop:
for {
select {
case <-h.conn.CloseChan():
glog.V(5).Infof("(conn=%p) upgraded connection closed", h.conn)
break Loop
case stream := <-h.streamChan:
requestID := h.requestID(stream)
streamType := stream.Headers().Get(api.StreamType)
glog.V(5).Infof("(conn=%p, request=%s) received new stream of type %s", h.conn, requestID, streamType)
p, created := h.getStreamPair(requestID)
if created {
go h.monitorStreamPair(p, time.After(h.streamCreationTimeout))
}
if complete, err := p.add(stream); err != nil {
msg := fmt.Sprintf("error processing stream for request %s: %v", requestID, err)
utilruntime.HandleError(errors.New(msg))
p.printError(msg)
} else if complete {
go h.portForward(p)
}
}
}
}
// portForward invokes the httpStreamHandler's forwarder.PortForward
// function for the given stream pair.
func (h *httpStreamHandler) portForward(p *httpStreamPair) {
defer p.dataStream.Close()
defer p.errorStream.Close()
portString := p.dataStream.Headers().Get(api.PortHeader)
port, _ := strconv.ParseUint(portString, 10, 16)
glog.V(5).Infof("(conn=%p, request=%s) invoking forwarder.PortForward for port %s", h.conn, p.requestID, portString)
err := h.forwarder.PortForward(h.pod, h.uid, uint16(port), p.dataStream)
glog.V(5).Infof("(conn=%p, request=%s) done invoking forwarder.PortForward for port %s", h.conn, p.requestID, portString)
if err != nil {
msg := fmt.Errorf("error forwarding port %d to pod %s, uid %v: %v", port, h.pod, h.uid, err)
utilruntime.HandleError(msg)
fmt.Fprint(p.errorStream, msg.Error())
}
}
// httpStreamPair represents the error and data streams for a port
// forwarding request.
type httpStreamPair struct {
lock sync.RWMutex
requestID string
dataStream httpstream.Stream
errorStream httpstream.Stream
complete chan struct{}
}
// newPortForwardPair creates a new httpStreamPair.
func newPortForwardPair(requestID string) *httpStreamPair {
return &httpStreamPair{
requestID: requestID,
complete: make(chan struct{}),
}
}
// add adds the stream to the httpStreamPair. If the pair already
// contains a stream for the new stream's type, an error is returned. add
// returns true if both the data and error streams for this pair have been
// received.
func (p *httpStreamPair) add(stream httpstream.Stream) (bool, error) {
p.lock.Lock()
defer p.lock.Unlock()
switch stream.Headers().Get(api.StreamType) {
case api.StreamTypeError:
if p.errorStream != nil {
return false, errors.New("error stream already assigned")
}
p.errorStream = stream
case api.StreamTypeData:
if p.dataStream != nil {
return false, errors.New("data stream already assigned")
}
p.dataStream = stream
}
complete := p.errorStream != nil && p.dataStream != nil
if complete {
close(p.complete)
}
return complete, nil
}
// printError writes s to p.errorStream if p.errorStream has been set.
func (p *httpStreamPair) printError(s string) {
p.lock.RLock()
defer p.lock.RUnlock()
if p.errorStream != nil {
fmt.Fprint(p.errorStream, s)
}
}

View File

@ -25,7 +25,7 @@ import (
"k8s.io/kubernetes/pkg/api"
)
func TestPortForwardStreamReceived(t *testing.T) {
func TestHTTPStreamReceived(t *testing.T) {
tests := map[string]struct {
port string
streamType string
@ -62,7 +62,7 @@ func TestPortForwardStreamReceived(t *testing.T) {
}
for name, test := range tests {
streams := make(chan httpstream.Stream, 1)
f := portForwardStreamReceived(streams)
f := httpStreamReceived(streams)
stream := newFakeHttpStream()
if len(test.port) > 0 {
stream.headers.Set("port", test.port)
@ -92,48 +92,11 @@ func TestPortForwardStreamReceived(t *testing.T) {
}
}
type fakeHttpStream struct {
headers http.Header
id uint32
}
func newFakeHttpStream() *fakeHttpStream {
return &fakeHttpStream{
headers: make(http.Header),
}
}
var _ httpstream.Stream = &fakeHttpStream{}
func (s *fakeHttpStream) Read(data []byte) (int, error) {
return 0, nil
}
func (s *fakeHttpStream) Write(data []byte) (int, error) {
return 0, nil
}
func (s *fakeHttpStream) Close() error {
return nil
}
func (s *fakeHttpStream) Reset() error {
return nil
}
func (s *fakeHttpStream) Headers() http.Header {
return s.headers
}
func (s *fakeHttpStream) Identifier() uint32 {
return s.id
}
func TestGetStreamPair(t *testing.T) {
timeout := make(chan time.Time)
h := &portForwardStreamHandler{
streamPairs: make(map[string]*portForwardStreamPair),
h := &httpStreamHandler{
streamPairs: make(map[string]*httpStreamPair),
}
// test adding a new entry
@ -223,7 +186,7 @@ func TestGetStreamPair(t *testing.T) {
}
func TestRequestID(t *testing.T) {
h := &portForwardStreamHandler{}
h := &httpStreamHandler{}
s := newFakeHttpStream()
s.headers.Set(api.StreamType, api.StreamTypeError)
@ -244,3 +207,40 @@ func TestRequestID(t *testing.T) {
t.Errorf("expected %q, got %q", e, a)
}
}
type fakeHttpStream struct {
headers http.Header
id uint32
}
func newFakeHttpStream() *fakeHttpStream {
return &fakeHttpStream{
headers: make(http.Header),
}
}
var _ httpstream.Stream = &fakeHttpStream{}
func (s *fakeHttpStream) Read(data []byte) (int, error) {
return 0, nil
}
func (s *fakeHttpStream) Write(data []byte) (int, error) {
return 0, nil
}
func (s *fakeHttpStream) Close() error {
return nil
}
func (s *fakeHttpStream) Reset() error {
return nil
}
func (s *fakeHttpStream) Headers() http.Header {
return s.headers
}
func (s *fakeHttpStream) Identifier() uint32 {
return s.id
}

View File

@ -17,21 +17,13 @@ limitations under the License.
package portforward
import (
"errors"
"fmt"
"io"
"net/http"
"strconv"
"sync"
"time"
"github.com/golang/glog"
"k8s.io/apimachinery/pkg/types"
"k8s.io/apimachinery/pkg/util/httpstream"
"k8s.io/apimachinery/pkg/util/httpstream/spdy"
utilruntime "k8s.io/apimachinery/pkg/util/runtime"
"k8s.io/kubernetes/pkg/api"
"k8s.io/apimachinery/pkg/util/runtime"
"k8s.io/apiserver/pkg/util/wsstream"
)
// PortForwarder knows how to forward content from a data stream to/from a port
@ -46,278 +38,16 @@ type PortForwarder interface {
// been timed out due to idleness. This function handles multiple forwarded
// connections; i.e., multiple `curl http://localhost:8888/` requests will be
// handled by a single invocation of ServePortForward.
func ServePortForward(w http.ResponseWriter, req *http.Request, portForwarder PortForwarder, podName string, uid types.UID, idleTimeout time.Duration, streamCreationTimeout time.Duration) {
supportedPortForwardProtocols := []string{PortForwardProtocolV1Name}
_, err := httpstream.Handshake(req, w, supportedPortForwardProtocols)
// negotiated protocol isn't currently used server side, but could be in the future
if err != nil {
// Handshake writes the error to the client
utilruntime.HandleError(err)
return
func ServePortForward(w http.ResponseWriter, req *http.Request, portForwarder PortForwarder, podName string, uid types.UID, idleTimeout time.Duration, streamCreationTimeout time.Duration, supportedProtocols []string) {
var err error
if wsstream.IsWebSocketRequest(req) {
err = handleWebSocketStreams(req, w, portForwarder, podName, uid, supportedProtocols, idleTimeout, streamCreationTimeout)
} else {
err = handleHttpStreams(req, w, portForwarder, podName, uid, supportedProtocols, idleTimeout, streamCreationTimeout)
}
streamChan := make(chan httpstream.Stream, 1)
glog.V(5).Infof("Upgrading port forward response")
upgrader := spdy.NewResponseUpgrader()
conn := upgrader.UpgradeResponse(w, req, portForwardStreamReceived(streamChan))
if conn == nil {
return
}
defer conn.Close()
glog.V(5).Infof("(conn=%p) setting port forwarding streaming connection idle timeout to %v", conn, idleTimeout)
conn.SetIdleTimeout(idleTimeout)
h := &portForwardStreamHandler{
conn: conn,
streamChan: streamChan,
streamPairs: make(map[string]*portForwardStreamPair),
streamCreationTimeout: streamCreationTimeout,
pod: podName,
uid: uid,
forwarder: portForwarder,
}
h.run()
}
// portForwardStreamReceived is the httpstream.NewStreamHandler for port
// forward streams. It checks each stream's port and stream type headers,
// rejecting any streams that with missing or invalid values. Each valid
// stream is sent to the streams channel.
func portForwardStreamReceived(streams chan httpstream.Stream) func(httpstream.Stream, <-chan struct{}) error {
return func(stream httpstream.Stream, replySent <-chan struct{}) error {
// make sure it has a valid port header
portString := stream.Headers().Get(api.PortHeader)
if len(portString) == 0 {
return fmt.Errorf("%q header is required", api.PortHeader)
}
port, err := strconv.ParseUint(portString, 10, 16)
if err != nil {
return fmt.Errorf("unable to parse %q as a port: %v", portString, err)
}
if port < 1 {
return fmt.Errorf("port %q must be > 0", portString)
}
// make sure it has a valid stream type header
streamType := stream.Headers().Get(api.StreamType)
if len(streamType) == 0 {
return fmt.Errorf("%q header is required", api.StreamType)
}
if streamType != api.StreamTypeError && streamType != api.StreamTypeData {
return fmt.Errorf("invalid stream type %q", streamType)
}
streams <- stream
return nil
}
}
// portForwardStreamHandler is capable of processing multiple port forward
// requests over a single httpstream.Connection.
type portForwardStreamHandler struct {
conn httpstream.Connection
streamChan chan httpstream.Stream
streamPairsLock sync.RWMutex
streamPairs map[string]*portForwardStreamPair
streamCreationTimeout time.Duration
pod string
uid types.UID
forwarder PortForwarder
}
// getStreamPair returns a portForwardStreamPair for requestID. This creates a
// new pair if one does not yet exist for the requestID. The returned bool is
// true if the pair was created.
func (h *portForwardStreamHandler) getStreamPair(requestID string) (*portForwardStreamPair, bool) {
h.streamPairsLock.Lock()
defer h.streamPairsLock.Unlock()
if p, ok := h.streamPairs[requestID]; ok {
glog.V(5).Infof("(conn=%p, request=%s) found existing stream pair", h.conn, requestID)
return p, false
}
glog.V(5).Infof("(conn=%p, request=%s) creating new stream pair", h.conn, requestID)
p := newPortForwardPair(requestID)
h.streamPairs[requestID] = p
return p, true
}
// monitorStreamPair waits for the pair to receive both its error and data
// streams, or for the timeout to expire (whichever happens first), and then
// removes the pair.
func (h *portForwardStreamHandler) monitorStreamPair(p *portForwardStreamPair, timeout <-chan time.Time) {
select {
case <-timeout:
err := fmt.Errorf("(conn=%v, request=%s) timed out waiting for streams", h.conn, p.requestID)
utilruntime.HandleError(err)
p.printError(err.Error())
case <-p.complete:
glog.V(5).Infof("(conn=%v, request=%s) successfully received error and data streams", h.conn, p.requestID)
}
h.removeStreamPair(p.requestID)
}
// hasStreamPair returns a bool indicating if a stream pair for requestID
// exists.
func (h *portForwardStreamHandler) hasStreamPair(requestID string) bool {
h.streamPairsLock.RLock()
defer h.streamPairsLock.RUnlock()
_, ok := h.streamPairs[requestID]
return ok
}
// removeStreamPair removes the stream pair identified by requestID from streamPairs.
func (h *portForwardStreamHandler) removeStreamPair(requestID string) {
h.streamPairsLock.Lock()
defer h.streamPairsLock.Unlock()
delete(h.streamPairs, requestID)
}
// requestID returns the request id for stream.
func (h *portForwardStreamHandler) requestID(stream httpstream.Stream) string {
requestID := stream.Headers().Get(api.PortForwardRequestIDHeader)
if len(requestID) == 0 {
glog.V(5).Infof("(conn=%p) stream received without %s header", h.conn, api.PortForwardRequestIDHeader)
// If we get here, it's because the connection came from an older client
// that isn't generating the request id header
// (https://github.com/kubernetes/kubernetes/blob/843134885e7e0b360eb5441e85b1410a8b1a7a0c/pkg/client/unversioned/portforward/portforward.go#L258-L287)
//
// This is a best-effort attempt at supporting older clients.
//
// When there aren't concurrent new forwarded connections, each connection
// will have a pair of streams (data, error), and the stream IDs will be
// consecutive odd numbers, e.g. 1 and 3 for the first connection. Convert
// the stream ID into a pseudo-request id by taking the stream type and
// using id = stream.Identifier() when the stream type is error,
// and id = stream.Identifier() - 2 when it's data.
//
// NOTE: this only works when there are not concurrent new streams from
// multiple forwarded connections; it's a best-effort attempt at supporting
// old clients that don't generate request ids. If there are concurrent
// new connections, it's possible that 1 connection gets streams whose IDs
// are not consecutive (e.g. 5 and 9 instead of 5 and 7).
streamType := stream.Headers().Get(api.StreamType)
switch streamType {
case api.StreamTypeError:
requestID = strconv.Itoa(int(stream.Identifier()))
case api.StreamTypeData:
requestID = strconv.Itoa(int(stream.Identifier()) - 2)
}
glog.V(5).Infof("(conn=%p) automatically assigning request ID=%q from stream type=%s, stream ID=%d", h.conn, requestID, streamType, stream.Identifier())
}
return requestID
}
// run is the main loop for the portForwardStreamHandler. It processes new
// streams, invoking portForward for each complete stream pair. The loop exits
// when the httpstream.Connection is closed.
func (h *portForwardStreamHandler) run() {
glog.V(5).Infof("(conn=%p) waiting for port forward streams", h.conn)
Loop:
for {
select {
case <-h.conn.CloseChan():
glog.V(5).Infof("(conn=%p) upgraded connection closed", h.conn)
break Loop
case stream := <-h.streamChan:
requestID := h.requestID(stream)
streamType := stream.Headers().Get(api.StreamType)
glog.V(5).Infof("(conn=%p, request=%s) received new stream of type %s", h.conn, requestID, streamType)
p, created := h.getStreamPair(requestID)
if created {
go h.monitorStreamPair(p, time.After(h.streamCreationTimeout))
}
if complete, err := p.add(stream); err != nil {
msg := fmt.Sprintf("error processing stream for request %s: %v", requestID, err)
utilruntime.HandleError(errors.New(msg))
p.printError(msg)
} else if complete {
go h.portForward(p)
}
}
}
}
// portForward invokes the portForwardStreamHandler's forwarder.PortForward
// function for the given stream pair.
func (h *portForwardStreamHandler) portForward(p *portForwardStreamPair) {
defer p.dataStream.Close()
defer p.errorStream.Close()
portString := p.dataStream.Headers().Get(api.PortHeader)
port, _ := strconv.ParseUint(portString, 10, 16)
glog.V(5).Infof("(conn=%p, request=%s) invoking forwarder.PortForward for port %s", h.conn, p.requestID, portString)
err := h.forwarder.PortForward(h.pod, h.uid, uint16(port), p.dataStream)
glog.V(5).Infof("(conn=%p, request=%s) done invoking forwarder.PortForward for port %s", h.conn, p.requestID, portString)
if err != nil {
msg := fmt.Errorf("error forwarding port %d to pod %s, uid %v: %v", port, h.pod, h.uid, err)
utilruntime.HandleError(msg)
fmt.Fprint(p.errorStream, msg.Error())
}
}
// portForwardStreamPair represents the error and data streams for a port
// forwarding request.
type portForwardStreamPair struct {
lock sync.RWMutex
requestID string
dataStream httpstream.Stream
errorStream httpstream.Stream
complete chan struct{}
}
// newPortForwardPair creates a new portForwardStreamPair.
func newPortForwardPair(requestID string) *portForwardStreamPair {
return &portForwardStreamPair{
requestID: requestID,
complete: make(chan struct{}),
}
}
// add adds the stream to the portForwardStreamPair. If the pair already
// contains a stream for the new stream's type, an error is returned. add
// returns true if both the data and error streams for this pair have been
// received.
func (p *portForwardStreamPair) add(stream httpstream.Stream) (bool, error) {
p.lock.Lock()
defer p.lock.Unlock()
switch stream.Headers().Get(api.StreamType) {
case api.StreamTypeError:
if p.errorStream != nil {
return false, errors.New("error stream already assigned")
}
p.errorStream = stream
case api.StreamTypeData:
if p.dataStream != nil {
return false, errors.New("data stream already assigned")
}
p.dataStream = stream
}
complete := p.errorStream != nil && p.dataStream != nil
if complete {
close(p.complete)
}
return complete, nil
}
// printError writes s to p.errorStream if p.errorStream has been set.
func (p *portForwardStreamPair) printError(s string) {
p.lock.RLock()
defer p.lock.RUnlock()
if p.errorStream != nil {
fmt.Fprint(p.errorStream, s)
runtime.HandleError(err)
return
}
}

View File

@ -0,0 +1,189 @@
/*
Copyright 2016 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package portforward
import (
"encoding/binary"
"fmt"
"io"
"net/http"
"strconv"
"strings"
"sync"
"time"
"github.com/golang/glog"
"k8s.io/apimachinery/pkg/types"
"k8s.io/apimachinery/pkg/util/runtime"
"k8s.io/apiserver/pkg/server/httplog"
"k8s.io/apiserver/pkg/util/wsstream"
"k8s.io/kubernetes/pkg/api"
)
const (
dataChannel = iota
errorChannel
v4BinaryWebsocketProtocol = "v4." + wsstream.ChannelWebSocketProtocol
v4Base64WebsocketProtocol = "v4." + wsstream.Base64ChannelWebSocketProtocol
)
// options contains details about which streams are required for
// port forwarding.
type v4Options struct {
ports []uint16
}
// newOptions creates a new options from the Request.
func newV4Options(req *http.Request) (*v4Options, error) {
portStrings := req.URL.Query()[api.PortHeader]
if len(portStrings) == 0 {
return nil, fmt.Errorf("%q is required", api.PortHeader)
}
ports := make([]uint16, 0, len(portStrings))
for _, portString := range portStrings {
if len(portString) == 0 {
return nil, fmt.Errorf("%q is cannot be empty", api.PortHeader)
}
for _, p := range strings.Split(portString, ",") {
port, err := strconv.ParseUint(p, 10, 16)
if err != nil {
return nil, fmt.Errorf("unable to parse %q as a port: %v", portString, err)
}
if port < 1 {
return nil, fmt.Errorf("port %q must be > 0", portString)
}
ports = append(ports, uint16(port))
}
}
return &v4Options{
ports: ports,
}, nil
}
func handleWebSocketStreams(req *http.Request, w http.ResponseWriter, portForwarder PortForwarder, podName string, uid types.UID, supportedPortForwardProtocols []string, idleTimeout, streamCreationTimeout time.Duration) error {
opts, err := newV4Options(req)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
fmt.Fprint(w, err.Error())
return err
}
channels := make([]wsstream.ChannelType, 0, len(opts.ports)*2)
for i := 0; i < len(opts.ports); i++ {
channels = append(channels, wsstream.ReadWriteChannel, wsstream.WriteChannel)
}
conn := wsstream.NewConn(map[string]wsstream.ChannelProtocolConfig{
"": {
Binary: true,
Channels: channels,
},
v4BinaryWebsocketProtocol: {
Binary: true,
Channels: channels,
},
v4Base64WebsocketProtocol: {
Binary: false,
Channels: channels,
},
})
conn.SetIdleTimeout(idleTimeout)
_, streams, err := conn.Open(httplog.Unlogged(w), req)
if err != nil {
err = fmt.Errorf("Unable to upgrade websocket connection: %v", err)
return err
}
defer conn.Close()
streamPairs := make([]*websocketStreamPair, len(opts.ports))
for i := range streamPairs {
streamPair := websocketStreamPair{
port: opts.ports[i],
dataStream: streams[i*2+dataChannel],
errorStream: streams[i*2+errorChannel],
}
streamPairs[i] = &streamPair
portBytes := make([]byte, 2)
binary.LittleEndian.PutUint16(portBytes, streamPair.port)
streamPair.dataStream.Write(portBytes)
streamPair.errorStream.Write(portBytes)
}
h := &websocketStreamHandler{
conn: conn,
streamPairs: streamPairs,
pod: podName,
uid: uid,
forwarder: portForwarder,
}
h.run()
return nil
}
// websocketStreamPair represents the error and data streams for a port
// forwarding request.
type websocketStreamPair struct {
port uint16
dataStream io.ReadWriteCloser
errorStream io.WriteCloser
}
// websocketStreamHandler is capable of processing a single port forward
// request over a websocket connection
type websocketStreamHandler struct {
conn *wsstream.Conn
ports []uint16
streamPairs []*websocketStreamPair
pod string
uid types.UID
forwarder PortForwarder
}
// run invokes the websocketStreamHandler's forwarder.PortForward
// function for the given stream pair.
func (h *websocketStreamHandler) run() {
wg := sync.WaitGroup{}
wg.Add(len(h.streamPairs))
for _, pair := range h.streamPairs {
p := pair
go func() {
defer wg.Done()
h.portForward(p)
}()
}
wg.Wait()
}
func (h *websocketStreamHandler) portForward(p *websocketStreamPair) {
defer p.dataStream.Close()
defer p.errorStream.Close()
glog.V(5).Infof("(conn=%p) invoking forwarder.PortForward for port %d", h.conn, p.port)
err := h.forwarder.PortForward(h.pod, h.uid, p.port, p.dataStream)
glog.V(5).Infof("(conn=%p) done invoking forwarder.PortForward for port %d", h.conn, p.port)
if err != nil {
msg := fmt.Errorf("error forwarding port %d to pod %s, uid %v: %v", p.port, h.pod, h.uid, err)
runtime.HandleError(msg)
fmt.Fprint(p.errorStream, msg.Error())
}
}

View File

@ -335,9 +335,15 @@ func (s *Server) InstallDebuggingHandlers(criHandler http.Handler) {
ws = new(restful.WebService)
ws.
Path("/portForward")
ws.Route(ws.GET("/{podNamespace}/{podID}").
To(s.getPortForward).
Operation("getPortForward"))
ws.Route(ws.POST("/{podNamespace}/{podID}").
To(s.getPortForward).
Operation("getPortForward"))
ws.Route(ws.GET("/{podNamespace}/{podID}/{uid}").
To(s.getPortForward).
Operation("getPortForward"))
ws.Route(ws.POST("/{podNamespace}/{podID}/{uid}").
To(s.getPortForward).
Operation("getPortForward"))
@ -720,7 +726,8 @@ func (s *Server) getPortForward(request *restful.Request, response *restful.Resp
kubecontainer.GetPodFullName(pod),
params.podUID,
s.host.StreamingConnectionIdleTimeout(),
remotecommand.DefaultStreamCreationTimeout)
remotecommand.DefaultStreamCreationTimeout,
portforward.SupportedProtocols)
}
// ServeHTTP responds to HTTP requests on the Kubelet.

View File

@ -0,0 +1,331 @@
/*
Copyright 2016 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package server
import (
"encoding/binary"
"fmt"
"io"
"strconv"
"sync"
"testing"
"time"
"golang.org/x/net/websocket"
"k8s.io/apimachinery/pkg/types"
)
const (
dataChannel = iota
errorChannel
)
func TestServeWSPortForward(t *testing.T) {
tests := []struct {
port string
uid bool
clientData string
containerData string
shouldError bool
}{
{port: "", shouldError: true},
{port: "abc", shouldError: true},
{port: "-1", shouldError: true},
{port: "65536", shouldError: true},
{port: "0", shouldError: true},
{port: "1", shouldError: false},
{port: "8000", shouldError: false},
{port: "8000", clientData: "client data", containerData: "container data", shouldError: false},
{port: "65535", shouldError: false},
{port: "65535", uid: true, shouldError: false},
}
podNamespace := "other"
podName := "foo"
expectedPodName := getPodName(podName, podNamespace)
expectedUid := "9b01b80f-8fb4-11e4-95ab-4200af06647"
for i, test := range tests {
fw := newServerTest()
defer fw.testHTTPServer.Close()
fw.fakeKubelet.streamingConnectionIdleTimeoutFunc = func() time.Duration {
return 0
}
portForwardFuncDone := make(chan struct{})
fw.fakeKubelet.portForwardFunc = func(name string, uid types.UID, port uint16, stream io.ReadWriteCloser) error {
defer close(portForwardFuncDone)
if e, a := expectedPodName, name; e != a {
t.Fatalf("%d: pod name: expected '%v', got '%v'", i, e, a)
}
if e, a := expectedUid, uid; test.uid && e != string(a) {
t.Fatalf("%d: uid: expected '%v', got '%v'", i, e, a)
}
p, err := strconv.ParseUint(test.port, 10, 16)
if err != nil {
t.Fatalf("%d: error parsing port string '%s': %v", i, test.port, err)
}
if e, a := uint16(p), port; e != a {
t.Fatalf("%d: port: expected '%v', got '%v'", i, e, a)
}
if test.clientData != "" {
fromClient := make([]byte, 32)
n, err := stream.Read(fromClient)
if err != nil {
t.Fatalf("%d: error reading client data: %v", i, err)
}
if e, a := test.clientData, string(fromClient[0:n]); e != a {
t.Fatalf("%d: client data: expected to receive '%v', got '%v'", i, e, a)
}
}
if test.containerData != "" {
_, err := stream.Write([]byte(test.containerData))
if err != nil {
t.Fatalf("%d: error writing container data: %v", i, err)
}
}
return nil
}
var url string
if test.uid {
url = fmt.Sprintf("ws://%s/portForward/%s/%s/%s?port=%s", fw.testHTTPServer.Listener.Addr().String(), podNamespace, podName, expectedUid, test.port)
} else {
url = fmt.Sprintf("ws://%s/portForward/%s/%s?port=%s", fw.testHTTPServer.Listener.Addr().String(), podNamespace, podName, test.port)
}
ws, err := websocket.Dial(url, "", "http://127.0.0.1/")
if test.shouldError {
if err == nil {
t.Fatalf("%d: websocket dial expected err", i)
}
continue
} else if err != nil {
t.Fatalf("%d: websocket dial unexpected err: %v", i, err)
}
defer ws.Close()
p, err := strconv.ParseUint(test.port, 10, 16)
if err != nil {
t.Fatalf("%d: error parsing port string '%s': %v", i, test.port, err)
}
p16 := uint16(p)
channel, data, err := wsRead(ws)
if err != nil {
t.Fatalf("%d: read failed: expected no error: got %v", i, err)
}
if channel != dataChannel {
t.Fatalf("%d: wrong channel: got %q: expected %q", i, channel, dataChannel)
}
if len(data) != binary.Size(p16) {
t.Fatalf("%d: wrong data size: got %q: expected %d", i, data, binary.Size(p16))
}
if e, a := p16, binary.LittleEndian.Uint16(data); e != a {
t.Fatalf("%d: wrong data: got %q: expected %s", i, data, test.port)
}
channel, data, err = wsRead(ws)
if err != nil {
t.Fatalf("%d: read succeeded: expected no error: got %v", i, err)
}
if channel != errorChannel {
t.Fatalf("%d: wrong channel: got %q: expected %q", i, channel, errorChannel)
}
if len(data) != binary.Size(p16) {
t.Fatalf("%d: wrong data size: got %q: expected %d", i, data, binary.Size(p16))
}
if e, a := p16, binary.LittleEndian.Uint16(data); e != a {
t.Fatalf("%d: wrong data: got %q: expected %s", i, data, test.port)
}
if test.clientData != "" {
println("writing the client data")
err := wsWrite(ws, dataChannel, []byte(test.clientData))
if err != nil {
t.Fatalf("%d: unexpected error writing client data: %v", i, err)
}
}
if test.containerData != "" {
channel, data, err = wsRead(ws)
if err != nil {
t.Fatalf("%d: unexpected error reading container data: %v", i, err)
}
if e, a := test.containerData, string(data); e != a {
t.Fatalf("%d: expected to receive '%v' from container, got '%v'", i, e, a)
}
}
<-portForwardFuncDone
}
}
func TestServeWSMultiplePortForward(t *testing.T) {
portsText := []string{"7000,8000", "9000"}
ports := []uint16{7000, 8000, 9000}
podNamespace := "other"
podName := "foo"
expectedPodName := getPodName(podName, podNamespace)
fw := newServerTest()
defer fw.testHTTPServer.Close()
fw.fakeKubelet.streamingConnectionIdleTimeoutFunc = func() time.Duration {
return 0
}
portForwardWG := sync.WaitGroup{}
portForwardWG.Add(len(ports))
portsMutex := sync.Mutex{}
portsForwarded := map[uint16]struct{}{}
fw.fakeKubelet.portForwardFunc = func(name string, uid types.UID, port uint16, stream io.ReadWriteCloser) error {
defer portForwardWG.Done()
if e, a := expectedPodName, name; e != a {
t.Fatalf("%d: pod name: expected '%v', got '%v'", port, e, a)
}
portsMutex.Lock()
portsForwarded[port] = struct{}{}
portsMutex.Unlock()
fromClient := make([]byte, 32)
n, err := stream.Read(fromClient)
if err != nil {
t.Fatalf("%d: error reading client data: %v", port, err)
}
if e, a := fmt.Sprintf("client data on port %d", port), string(fromClient[0:n]); e != a {
t.Fatalf("%d: client data: expected to receive '%v', got '%v'", port, e, a)
}
_, err = stream.Write([]byte(fmt.Sprintf("container data on port %d", port)))
if err != nil {
t.Fatalf("%d: error writing container data: %v", port, err)
}
return nil
}
url := fmt.Sprintf("ws://%s/portForward/%s/%s?", fw.testHTTPServer.Listener.Addr().String(), podNamespace, podName)
for _, port := range portsText {
url = url + fmt.Sprintf("port=%s&", port)
}
ws, err := websocket.Dial(url, "", "http://127.0.0.1/")
if err != nil {
t.Fatalf("websocket dial unexpected err: %v", err)
}
defer ws.Close()
for i, port := range ports {
channel, data, err := wsRead(ws)
if err != nil {
t.Fatalf("%d: read failed: expected no error: got %v", i, err)
}
if int(channel) != i*2+dataChannel {
t.Fatalf("%d: wrong channel: got %q: expected %q", i, channel, i*2+dataChannel)
}
if len(data) != binary.Size(port) {
t.Fatalf("%d: wrong data size: got %q: expected %d", i, data, binary.Size(port))
}
if e, a := port, binary.LittleEndian.Uint16(data); e != a {
t.Fatalf("%d: wrong data: got %q: expected %d", i, data, port)
}
channel, data, err = wsRead(ws)
if err != nil {
t.Fatalf("%d: read succeeded: expected no error: got %v", i, err)
}
if int(channel) != i*2+errorChannel {
t.Fatalf("%d: wrong channel: got %q: expected %q", i, channel, i*2+errorChannel)
}
if len(data) != binary.Size(port) {
t.Fatalf("%d: wrong data size: got %q: expected %d", i, data, binary.Size(port))
}
if e, a := port, binary.LittleEndian.Uint16(data); e != a {
t.Fatalf("%d: wrong data: got %q: expected %d", i, data, port)
}
}
for i, port := range ports {
println("writing the client data", port)
err := wsWrite(ws, byte(i*2+dataChannel), []byte(fmt.Sprintf("client data on port %d", port)))
if err != nil {
t.Fatalf("%d: unexpected error writing client data: %v", i, err)
}
channel, data, err := wsRead(ws)
if err != nil {
t.Fatalf("%d: unexpected error reading container data: %v", i, err)
}
if int(channel) != i*2+dataChannel {
t.Fatalf("%d: wrong channel: got %q: expected %q", port, channel, i*2+dataChannel)
}
if e, a := fmt.Sprintf("container data on port %d", port), string(data); e != a {
t.Fatalf("%d: expected to receive '%v' from container, got '%v'", i, e, a)
}
}
portForwardWG.Wait()
portsMutex.Lock()
defer portsMutex.Unlock()
if len(ports) != len(portsForwarded) {
t.Fatalf("expected to forward %d ports; got %v", len(ports), portsForwarded)
}
}
func wsWrite(conn *websocket.Conn, channel byte, data []byte) error {
frame := make([]byte, len(data)+1)
frame[0] = channel
copy(frame[1:], data)
err := websocket.Message.Send(conn, frame)
return err
}
func wsRead(conn *websocket.Conn) (byte, []byte, error) {
for {
var data []byte
err := websocket.Message.Receive(conn, &data)
if err != nil {
return 0, nil, err
}
if len(data) == 0 {
continue
}
channel := data[0]
data = data[1:]
return channel, data, err
}
}

View File

@ -80,7 +80,12 @@ type Config struct {
// The streaming protocols the server supports (understands and permits). See
// k8s.io/kubernetes/pkg/kubelet/server/remotecommand/constants.go for available protocols.
// Only used for SPDY streaming.
SupportedProtocols []string
SupportedRemoteCommandProtocols []string
// The streaming protocols the server supports (understands and permits). See
// k8s.io/kubernetes/pkg/kubelet/server/portforward/constants.go for available protocols.
// Only used for SPDY streaming.
SupportedPortForwardProtocols []string
// The config for serving over TLS. If nil, TLS will not be used.
TLSConfig *tls.Config
@ -89,9 +94,10 @@ type Config struct {
// DefaultConfig provides default values for server Config. The DefaultConfig is partial, so
// some fields like Addr must still be provided.
var DefaultConfig = Config{
StreamIdleTimeout: 4 * time.Hour,
StreamCreationTimeout: remotecommand.DefaultStreamCreationTimeout,
SupportedProtocols: remotecommand.SupportedStreamingProtocols,
StreamIdleTimeout: 4 * time.Hour,
StreamCreationTimeout: remotecommand.DefaultStreamCreationTimeout,
SupportedRemoteCommandProtocols: remotecommand.SupportedStreamingProtocols,
SupportedPortForwardProtocols: portforward.SupportedProtocols,
}
// TODO(timstclair): Add auth(n/z) interface & handling.
@ -248,7 +254,7 @@ func (s *server) serveExec(req *restful.Request, resp *restful.Response) {
streamOpts,
s.config.StreamIdleTimeout,
s.config.StreamCreationTimeout,
s.config.SupportedProtocols)
s.config.SupportedRemoteCommandProtocols)
}
func (s *server) serveAttach(req *restful.Request, resp *restful.Response) {
@ -280,7 +286,7 @@ func (s *server) serveAttach(req *restful.Request, resp *restful.Response) {
streamOpts,
s.config.StreamIdleTimeout,
s.config.StreamCreationTimeout,
s.config.SupportedProtocols)
s.config.SupportedRemoteCommandProtocols)
}
func (s *server) servePortForward(req *restful.Request, resp *restful.Response) {
@ -303,7 +309,8 @@ func (s *server) servePortForward(req *restful.Request, resp *restful.Response)
pf.PodSandboxId,
"", // unused: podUID
s.config.StreamIdleTimeout,
s.config.StreamCreationTimeout)
s.config.StreamCreationTimeout,
s.config.SupportedPortForwardProtocols)
}
// criAdapter wraps the Runtime functions to conform to the remotecommand interfaces.

View File

@ -240,7 +240,7 @@ func TestServePortForward(t *testing.T) {
exec, err := remotecommand.NewExecutor(&restclient.Config{}, "POST", reqURL)
require.NoError(t, err)
streamConn, _, err := exec.Dial(kubeletportforward.PortForwardProtocolV1Name)
streamConn, _, err := exec.Dial(kubeletportforward.ProtocolV1Name)
require.NoError(t, err)
defer streamConn.Close()

View File

@ -165,9 +165,10 @@ func (r *PortForwardREST) New() runtime.Object {
return &api.Pod{}
}
// NewConnectOptions returns nil since portforward doesn't take additional parameters
// NewConnectOptions returns the versioned object that represents the
// portforward parameters
func (r *PortForwardREST) NewConnectOptions() (runtime.Object, bool, string) {
return nil, false, ""
return &api.PodPortForwardOptions{}, false, ""
}
// ConnectMethods returns the methods supported by portforward
@ -177,7 +178,11 @@ func (r *PortForwardREST) ConnectMethods() []string {
// Connect returns a handler for the pod portforward proxy
func (r *PortForwardREST) Connect(ctx genericapirequest.Context, name string, opts runtime.Object, responder rest.Responder) (http.Handler, error) {
location, transport, err := pod.PortForwardLocation(r.Store, r.KubeletConn, ctx, name)
portForwardOpts, ok := opts.(*api.PodPortForwardOptions)
if !ok {
return nil, fmt.Errorf("invalid options object: %#v", opts)
}
location, transport, err := pod.PortForwardLocation(r.Store, r.KubeletConn, ctx, name, portForwardOpts)
if err != nil {
return nil, err
}

View File

@ -383,6 +383,15 @@ func streamParams(params url.Values, opts runtime.Object) error {
if opts.TTY {
params.Add(api.ExecTTYParam, "1")
}
case *api.PodPortForwardOptions:
if len(opts.Ports) == 0 {
return errors.NewBadRequest("at least one port must be specified")
}
ports := make([]string, len(opts.Ports))
for i, p := range opts.Ports {
ports[i] = strconv.FormatInt(int64(p), 10)
}
params.Add(api.PortHeader, strings.Join(ports, ","))
default:
return fmt.Errorf("Unknown object for streaming: %v", opts)
}
@ -477,6 +486,7 @@ func PortForwardLocation(
connInfo client.ConnectionInfoGetter,
ctx genericapirequest.Context,
name string,
opts *api.PodPortForwardOptions,
) (*url.URL, http.RoundTripper, error) {
pod, err := getPod(getter, ctx, name)
if err != nil {
@ -492,10 +502,15 @@ func PortForwardLocation(
if err != nil {
return nil, nil, err
}
params := url.Values{}
if err := streamParams(params, opts); err != nil {
return nil, nil, err
}
loc := &url.URL{
Scheme: nodeInfo.Scheme,
Host: net.JoinHostPort(nodeInfo.Hostname, nodeInfo.Port),
Path: fmt.Sprintf("/portForward/%s/%s", pod.Namespace, pod.Name),
Scheme: nodeInfo.Scheme,
Host: net.JoinHostPort(nodeInfo.Hostname, nodeInfo.Port),
Path: fmt.Sprintf("/portForward/%s/%s", pod.Namespace, pod.Name),
RawQuery: params.Encode(),
}
return loc, nodeInfo.Transport, nil
}

View File

@ -17,6 +17,7 @@ limitations under the License.
package pod
import (
"net/url"
"reflect"
"testing"
@ -26,9 +27,11 @@ import (
"k8s.io/apimachinery/pkg/fields"
"k8s.io/apimachinery/pkg/labels"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/types"
genericapirequest "k8s.io/apiserver/pkg/endpoints/request"
"k8s.io/kubernetes/pkg/api"
apitesting "k8s.io/kubernetes/pkg/api/testing"
"k8s.io/kubernetes/pkg/kubelet/client"
)
func TestMatchPod(t *testing.T) {
@ -333,3 +336,64 @@ func TestSelectableFieldLabelConversions(t *testing.T) {
nil,
)
}
type mockConnectionInfoGetter struct {
info *client.ConnectionInfo
}
func (g mockConnectionInfoGetter) GetConnectionInfo(nodeName types.NodeName) (*client.ConnectionInfo, error) {
return g.info, nil
}
func TestPortForwardLocation(t *testing.T) {
ctx := genericapirequest.NewDefaultContext()
tcs := []struct {
in *api.Pod
info *client.ConnectionInfo
opts *api.PodPortForwardOptions
expectedErr error
expectedURL *url.URL
}{
{
in: &api.Pod{
Spec: api.PodSpec{},
},
opts: &api.PodPortForwardOptions{},
expectedErr: errors.NewBadRequest("pod test does not have a host assigned"),
},
{
in: &api.Pod{
Spec: api.PodSpec{
NodeName: "node1",
},
},
opts: &api.PodPortForwardOptions{},
expectedErr: errors.NewBadRequest("at least one port must be specified"),
},
{
in: &api.Pod{
ObjectMeta: api.ObjectMeta{
Namespace: "ns",
Name: "pod1",
},
Spec: api.PodSpec{
NodeName: "node1",
},
},
info: &client.ConnectionInfo{},
opts: &api.PodPortForwardOptions{Ports: []int32{80}},
expectedURL: &url.URL{Host: ":", Path: "/portForward/ns/pod1", RawQuery: "port=80"},
},
}
for _, tc := range tcs {
getter := &mockPodGetter{tc.in}
connectionGetter := &mockConnectionInfoGetter{tc.info}
loc, _, err := PortForwardLocation(getter, connectionGetter, ctx, "test", tc.opts)
if !reflect.DeepEqual(err, tc.expectedErr) {
t.Errorf("expected %v, got %v", tc.expectedErr, err)
}
if !reflect.DeepEqual(loc, tc.expectedURL) {
t.Errorf("expected %v, got %v", tc.expectedURL, loc)
}
}
}

View File

@ -17,6 +17,8 @@ limitations under the License.
package e2e
import (
"bytes"
"encoding/binary"
"fmt"
"io"
"io/ioutil"
@ -28,6 +30,7 @@ import (
"syscall"
"time"
"golang.org/x/net/websocket"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/util/wait"
"k8s.io/kubernetes/pkg/api/v1"
@ -36,6 +39,7 @@ import (
testutils "k8s.io/kubernetes/test/utils"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
const (
@ -368,36 +372,144 @@ func doTestMustConnectSendDisconnect(bindAddress string, f *framework.Framework)
verifyLogMessage(logOutput, "^Done$")
}
func doTestOverWebSockets(bindAddress string, f *framework.Framework) {
config, err := framework.LoadConfig()
Expect(err).NotTo(HaveOccurred(), "unable to get base config")
By("creating the pod")
pod := pfPod("def", "10", "10", "100", fmt.Sprintf("%s", bindAddress))
if _, err := f.ClientSet.Core().Pods(f.Namespace.Name).Create(pod); err != nil {
framework.Failf("Couldn't create pod: %v", err)
}
if err := f.WaitForPodReady(pod.Name); err != nil {
framework.Failf("Pod did not start running: %v", err)
}
defer func() {
logs, err := framework.GetPodLogs(f.ClientSet, f.Namespace.Name, pod.Name, "portforwardtester")
if err != nil {
framework.Logf("Error getting pod log: %v", err)
} else {
framework.Logf("Pod log:\n%s", logs)
}
}()
req := f.ClientSet.Core().RESTClient().Get().
Namespace(f.Namespace.Name).
Resource("pods").
Name(pod.Name).
Suffix("portforward").
Param("ports", "80")
url := req.URL()
ws, err := framework.OpenWebSocketForURL(url, config, []string{"v4.channel.k8s.io"})
if err != nil {
framework.Failf("Failed to open websocket to %s: %v", url.String(), err)
}
defer ws.Close()
Eventually(func() error {
channel, msg, err := wsRead(ws)
if err != nil {
return fmt.Errorf("Failed to read completely from websocket %s: %v", url.String(), err)
}
if channel != 0 {
return fmt.Errorf("Got message from server that didn't start with channel 0 (data): %v", msg)
}
if p := binary.LittleEndian.Uint16(msg); p != 80 {
return fmt.Errorf("Received the wrong port: %d", p)
}
return nil
}, time.Minute, 10*time.Second).Should(BeNil())
Eventually(func() error {
channel, msg, err := wsRead(ws)
if err != nil {
return fmt.Errorf("Failed to read completely from websocket %s: %v", url.String(), err)
}
if channel != 1 {
return fmt.Errorf("Got message from server that didn't start with channel 1 (error): %v", msg)
}
if p := binary.LittleEndian.Uint16(msg); p != 80 {
return fmt.Errorf("Received the wrong port: %d", p)
}
return nil
}, time.Minute, 10*time.Second).Should(BeNil())
By("sending the expected data to the local port")
err = wsWrite(ws, 0, []byte("def"))
if err != nil {
framework.Failf("Failed to write to websocket %s: %v", url.String(), err)
}
By("reading data from the local port")
buf := bytes.Buffer{}
expectedData := bytes.Repeat([]byte("x"), 100)
Eventually(func() error {
channel, msg, err := wsRead(ws)
if err != nil {
return fmt.Errorf("Failed to read completely from websocket %s: %v", url.String(), err)
}
if channel != 0 {
return fmt.Errorf("Got message from server that didn't start with channel 0 (data): %v", msg)
}
buf.Write(msg)
if bytes.Equal(expectedData, buf.Bytes()) {
return fmt.Errorf("Expected %q from server, got %q", expectedData, buf.Bytes())
}
return nil
}, time.Minute, 10*time.Second).Should(BeNil())
By("verifying logs")
logOutput, err := framework.GetPodLogs(f.ClientSet, f.Namespace.Name, pod.Name, "portforwardtester")
if err != nil {
framework.Failf("Error retrieving pod logs: %v", err)
}
verifyLogMessage(logOutput, "^Accepted client connection$")
verifyLogMessage(logOutput, "^Received expected client data$")
}
var _ = framework.KubeDescribe("Port forwarding", func() {
f := framework.NewDefaultFramework("port-forwarding")
framework.KubeDescribe("With a server listening on 0.0.0.0 that expects a client request", func() {
It("should support a client that connects, sends no data, and disconnects", func() {
doTestMustConnectSendNothing("0.0.0.0", f)
framework.KubeDescribe("With a server listening on 0.0.0.0", func() {
framework.KubeDescribe("that expects a client request", func() {
It("should support a client that connects, sends no data, and disconnects", func() {
doTestMustConnectSendNothing("0.0.0.0", f)
})
It("should support a client that connects, sends data, and disconnects", func() {
doTestMustConnectSendDisconnect("0.0.0.0", f)
})
})
It("should support a client that connects, sends data, and disconnects", func() {
doTestMustConnectSendDisconnect("0.0.0.0", f)
framework.KubeDescribe("that expects no client request", func() {
It("should support a client that connects, sends data, and disconnects", func() {
doTestConnectSendDisconnect("0.0.0.0", f)
})
})
It("should support forwarding over websockets", func() {
doTestOverWebSockets("0.0.0.0", f)
})
})
framework.KubeDescribe("With a server listening on 0.0.0.0 that expects no client request", func() {
It("should support a client that connects, sends data, and disconnects", func() {
doTestConnectSendDisconnect("0.0.0.0", f)
framework.KubeDescribe("With a server listening on localhost", func() {
framework.KubeDescribe("that expects a client request", func() {
It("should support a client that connects, sends no data, and disconnects [Conformance]", func() {
doTestMustConnectSendNothing("localhost", f)
})
It("should support a client that connects, sends data, and disconnects [Conformance]", func() {
doTestMustConnectSendDisconnect("localhost", f)
})
})
})
framework.KubeDescribe("With a server listening on localhost that expects a client request", func() {
It("should support a client that connects, sends no data, and disconnects [Conformance]", func() {
doTestMustConnectSendNothing("localhost", f)
framework.KubeDescribe("that expects no client request", func() {
It("should support a client that connects, sends data, and disconnects [Conformance]", func() {
doTestConnectSendDisconnect("localhost", f)
})
})
It("should support a client that connects, sends data, and disconnects [Conformance]", func() {
doTestMustConnectSendDisconnect("localhost", f)
})
})
framework.KubeDescribe("With a server listening on localhost that expects no client request", func() {
It("should support a client that connects, sends data, and disconnects [Conformance]", func() {
doTestConnectSendDisconnect("localhost", f)
It("should support forwarding over websockets", func() {
doTestOverWebSockets("localhost", f)
})
})
})
@ -412,3 +524,30 @@ func verifyLogMessage(log, expected string) {
}
framework.Failf("Missing %q from log: %s", expected, log)
}
func wsRead(conn *websocket.Conn) (byte, []byte, error) {
for {
var data []byte
err := websocket.Message.Receive(conn, &data)
if err != nil {
return 0, nil, err
}
if len(data) == 0 {
continue
}
channel := data[0]
data = data[1:]
return channel, data, err
}
}
func wsWrite(conn *websocket.Conn, channel byte, data []byte) error {
frame := make([]byte, len(data)+1)
frame[0] = channel
copy(frame[1:], data)
err := websocket.Message.Send(conn, frame)
return err
}