mirror of
https://github.com/k3s-io/kubernetes.git
synced 2025-07-31 23:37:01 +00:00
Port forward over websockets
- split out port forwarding into its own package Allow multiple port forwarding ports - Make it easy to determine which port is tied to which channel - odd channels are for data - even channels are for errors - allow comma separated ports to specify multiple ports Add portfowardtester 1.2 to whitelist
This commit is contained in:
parent
96cfe7b938
commit
beb53fb71a
@ -27,7 +27,7 @@ spec:
|
||||
command:
|
||||
- /bin/sh
|
||||
- -c
|
||||
- "for i in gcr.io/google_containers/busybox gcr.io/google_containers/busybox:1.24 gcr.io/google_containers/dnsutils:e2e gcr.io/google_containers/eptest:0.1 gcr.io/google_containers/fakegitserver:0.1 gcr.io/google_containers/hostexec:1.2 gcr.io/google_containers/iperf:e2e gcr.io/google_containers/jessie-dnsutils:e2e gcr.io/google_containers/liveness:e2e gcr.io/google_containers/mounttest:0.7 gcr.io/google_containers/mounttest-user:0.3 gcr.io/google_containers/netexec:1.4 gcr.io/google_containers/netexec:1.7 gcr.io/google_containers/nettest:1.7 gcr.io/google_containers/nettest:1.8 gcr.io/google_containers/nginx-slim:0.7 gcr.io/google_containers/nginx-slim:0.8 gcr.io/google_containers/n-way-http:1.0 gcr.io/google_containers/pause:2.0 gcr.io/google_containers/pause-amd64:3.0 gcr.io/google_containers/porter:cd5cb5791ebaa8641955f0e8c2a9bed669b1eaab gcr.io/google_containers/portforwardtester:1.0 gcr.io/google_containers/redis:e2e gcr.io/google_containers/resource_consumer:beta4 gcr.io/google_containers/resource_consumer/controller:beta4 gcr.io/google_containers/serve_hostname:v1.4 gcr.io/google_containers/test-webserver:e2e gcr.io/google_containers/ubuntu:14.04 gcr.io/google_containers/update-demo:kitten gcr.io/google_containers/update-demo:nautilus gcr.io/google_containers/volume-ceph:0.1 gcr.io/google_containers/volume-gluster:0.2 gcr.io/google_containers/volume-iscsi:0.1 gcr.io/google_containers/volume-nfs:0.6 gcr.io/google_containers/volume-rbd:0.1 gcr.io/google_samples/gb-redisslave:v1 gcr.io/google_containers/redis:v1; do echo $(date '+%X') pulling $i; docker pull $i 1>/dev/null; done; exit 0;"
|
||||
- "for i in gcr.io/google_containers/busybox gcr.io/google_containers/busybox:1.24 gcr.io/google_containers/dnsutils:e2e gcr.io/google_containers/eptest:0.1 gcr.io/google_containers/fakegitserver:0.1 gcr.io/google_containers/hostexec:1.2 gcr.io/google_containers/iperf:e2e gcr.io/google_containers/jessie-dnsutils:e2e gcr.io/google_containers/liveness:e2e gcr.io/google_containers/mounttest:0.7 gcr.io/google_containers/mounttest-user:0.3 gcr.io/google_containers/netexec:1.4 gcr.io/google_containers/netexec:1.7 gcr.io/google_containers/nettest:1.7 gcr.io/google_containers/nettest:1.8 gcr.io/google_containers/nginx-slim:0.7 gcr.io/google_containers/nginx-slim:0.8 gcr.io/google_containers/n-way-http:1.0 gcr.io/google_containers/pause:2.0 gcr.io/google_containers/pause-amd64:3.0 gcr.io/google_containers/porter:cd5cb5791ebaa8641955f0e8c2a9bed669b1eaab gcr.io/google_containers/portforwardtester:1.2 gcr.io/google_containers/redis:e2e gcr.io/google_containers/resource_consumer:beta4 gcr.io/google_containers/resource_consumer/controller:beta4 gcr.io/google_containers/serve_hostname:v1.4 gcr.io/google_containers/test-webserver:e2e gcr.io/google_containers/ubuntu:14.04 gcr.io/google_containers/update-demo:kitten gcr.io/google_containers/update-demo:nautilus gcr.io/google_containers/volume-ceph:0.1 gcr.io/google_containers/volume-gluster:0.2 gcr.io/google_containers/volume-iscsi:0.1 gcr.io/google_containers/volume-nfs:0.6 gcr.io/google_containers/volume-rbd:0.1 gcr.io/google_samples/gb-redisslave:v1 gcr.io/google_containers/redis:v1; do echo $(date '+%X') pulling $i; docker pull $i 1>/dev/null; done; exit 0;"
|
||||
securityContext:
|
||||
privileged: true
|
||||
volumeMounts:
|
||||
|
@ -84,7 +84,7 @@ func fakePortForwardServer(t *testing.T, testName string, serverSends, expectedF
|
||||
received: make(map[uint16]string),
|
||||
send: serverSends,
|
||||
}
|
||||
portforward.ServePortForward(w, req, pf, "pod", "uid", 0, 10*time.Second)
|
||||
portforward.ServePortForward(w, req, pf, "pod", "uid", 0, 10*time.Second, portforward.SupportedProtocols)
|
||||
|
||||
for port, expected := range expectedFromClient {
|
||||
actual, ok := pf.received[port]
|
||||
|
@ -1047,10 +1047,14 @@ func typeToJSON(typeName string) string {
|
||||
return "string"
|
||||
case "byte", "*byte":
|
||||
return "string"
|
||||
|
||||
// TODO: Fix these when go-restful supports a way to specify an array query param:
|
||||
// https://github.com/emicklei/go-restful/issues/225
|
||||
case "[]string", "[]*string":
|
||||
// TODO: Fix this when go-restful supports a way to specify an array query param:
|
||||
// https://github.com/emicklei/go-restful/issues/225
|
||||
return "string"
|
||||
case "[]int32", "[]*int32":
|
||||
return "integer"
|
||||
|
||||
default:
|
||||
return typeName
|
||||
}
|
||||
|
@ -2171,9 +2171,10 @@ func getStreamingConfig(kubeCfg *componentconfig.KubeletConfiguration, kubeDeps
|
||||
BaseURL: &url.URL{
|
||||
Path: "/cri/",
|
||||
},
|
||||
StreamIdleTimeout: kubeCfg.StreamingConnectionIdleTimeout.Duration,
|
||||
StreamCreationTimeout: streaming.DefaultConfig.StreamCreationTimeout,
|
||||
SupportedProtocols: streaming.DefaultConfig.SupportedProtocols,
|
||||
StreamIdleTimeout: kubeCfg.StreamingConnectionIdleTimeout.Duration,
|
||||
StreamCreationTimeout: streaming.DefaultConfig.StreamCreationTimeout,
|
||||
SupportedRemoteCommandProtocols: streaming.DefaultConfig.SupportedRemoteCommandProtocols,
|
||||
SupportedPortForwardProtocols: streaming.DefaultConfig.SupportedPortForwardProtocols,
|
||||
}
|
||||
if kubeDeps.TLSOptions != nil {
|
||||
config.TLSConfig = kubeDeps.TLSOptions.Config
|
||||
|
@ -18,4 +18,6 @@ limitations under the License.
|
||||
package portforward
|
||||
|
||||
// The subprotocol "portforward.k8s.io" is used for port forwarding.
|
||||
const PortForwardProtocolV1Name = "portforward.k8s.io"
|
||||
const ProtocolV1Name = "portforward.k8s.io"
|
||||
|
||||
var SupportedProtocols = []string{ProtocolV1Name}
|
||||
|
309
pkg/kubelet/server/portforward/httpstream.go
Normal file
309
pkg/kubelet/server/portforward/httpstream.go
Normal file
@ -0,0 +1,309 @@
|
||||
/*
|
||||
Copyright 2016 The Kubernetes Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package portforward
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"k8s.io/apimachinery/pkg/types"
|
||||
"k8s.io/apimachinery/pkg/util/httpstream"
|
||||
"k8s.io/apimachinery/pkg/util/httpstream/spdy"
|
||||
utilruntime "k8s.io/apimachinery/pkg/util/runtime"
|
||||
"k8s.io/kubernetes/pkg/api"
|
||||
|
||||
"github.com/golang/glog"
|
||||
)
|
||||
|
||||
func handleHttpStreams(req *http.Request, w http.ResponseWriter, portForwarder PortForwarder, podName string, uid types.UID, supportedPortForwardProtocols []string, idleTimeout, streamCreationTimeout time.Duration) error {
|
||||
_, err := httpstream.Handshake(req, w, supportedPortForwardProtocols)
|
||||
// negotiated protocol isn't currently used server side, but could be in the future
|
||||
if err != nil {
|
||||
// Handshake writes the error to the client
|
||||
return err
|
||||
}
|
||||
streamChan := make(chan httpstream.Stream, 1)
|
||||
|
||||
glog.V(5).Infof("Upgrading port forward response")
|
||||
upgrader := spdy.NewResponseUpgrader()
|
||||
conn := upgrader.UpgradeResponse(w, req, httpStreamReceived(streamChan))
|
||||
if conn == nil {
|
||||
return errors.New("Unable to upgrade websocket connection")
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
glog.V(5).Infof("(conn=%p) setting port forwarding streaming connection idle timeout to %v", conn, idleTimeout)
|
||||
conn.SetIdleTimeout(idleTimeout)
|
||||
|
||||
h := &httpStreamHandler{
|
||||
conn: conn,
|
||||
streamChan: streamChan,
|
||||
streamPairs: make(map[string]*httpStreamPair),
|
||||
streamCreationTimeout: streamCreationTimeout,
|
||||
pod: podName,
|
||||
uid: uid,
|
||||
forwarder: portForwarder,
|
||||
}
|
||||
h.run()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// httpStreamReceived is the httpstream.NewStreamHandler for port
|
||||
// forward streams. It checks each stream's port and stream type headers,
|
||||
// rejecting any streams that with missing or invalid values. Each valid
|
||||
// stream is sent to the streams channel.
|
||||
func httpStreamReceived(streams chan httpstream.Stream) func(httpstream.Stream, <-chan struct{}) error {
|
||||
return func(stream httpstream.Stream, replySent <-chan struct{}) error {
|
||||
// make sure it has a valid port header
|
||||
portString := stream.Headers().Get(api.PortHeader)
|
||||
if len(portString) == 0 {
|
||||
return fmt.Errorf("%q header is required", api.PortHeader)
|
||||
}
|
||||
port, err := strconv.ParseUint(portString, 10, 16)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to parse %q as a port: %v", portString, err)
|
||||
}
|
||||
if port < 1 {
|
||||
return fmt.Errorf("port %q must be > 0", portString)
|
||||
}
|
||||
|
||||
// make sure it has a valid stream type header
|
||||
streamType := stream.Headers().Get(api.StreamType)
|
||||
if len(streamType) == 0 {
|
||||
return fmt.Errorf("%q header is required", api.StreamType)
|
||||
}
|
||||
if streamType != api.StreamTypeError && streamType != api.StreamTypeData {
|
||||
return fmt.Errorf("invalid stream type %q", streamType)
|
||||
}
|
||||
|
||||
streams <- stream
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// httpStreamHandler is capable of processing multiple port forward
|
||||
// requests over a single httpstream.Connection.
|
||||
type httpStreamHandler struct {
|
||||
conn httpstream.Connection
|
||||
streamChan chan httpstream.Stream
|
||||
streamPairsLock sync.RWMutex
|
||||
streamPairs map[string]*httpStreamPair
|
||||
streamCreationTimeout time.Duration
|
||||
pod string
|
||||
uid types.UID
|
||||
forwarder PortForwarder
|
||||
}
|
||||
|
||||
// getStreamPair returns a httpStreamPair for requestID. This creates a
|
||||
// new pair if one does not yet exist for the requestID. The returned bool is
|
||||
// true if the pair was created.
|
||||
func (h *httpStreamHandler) getStreamPair(requestID string) (*httpStreamPair, bool) {
|
||||
h.streamPairsLock.Lock()
|
||||
defer h.streamPairsLock.Unlock()
|
||||
|
||||
if p, ok := h.streamPairs[requestID]; ok {
|
||||
glog.V(5).Infof("(conn=%p, request=%s) found existing stream pair", h.conn, requestID)
|
||||
return p, false
|
||||
}
|
||||
|
||||
glog.V(5).Infof("(conn=%p, request=%s) creating new stream pair", h.conn, requestID)
|
||||
|
||||
p := newPortForwardPair(requestID)
|
||||
h.streamPairs[requestID] = p
|
||||
|
||||
return p, true
|
||||
}
|
||||
|
||||
// monitorStreamPair waits for the pair to receive both its error and data
|
||||
// streams, or for the timeout to expire (whichever happens first), and then
|
||||
// removes the pair.
|
||||
func (h *httpStreamHandler) monitorStreamPair(p *httpStreamPair, timeout <-chan time.Time) {
|
||||
select {
|
||||
case <-timeout:
|
||||
err := fmt.Errorf("(conn=%v, request=%s) timed out waiting for streams", h.conn, p.requestID)
|
||||
utilruntime.HandleError(err)
|
||||
p.printError(err.Error())
|
||||
case <-p.complete:
|
||||
glog.V(5).Infof("(conn=%v, request=%s) successfully received error and data streams", h.conn, p.requestID)
|
||||
}
|
||||
h.removeStreamPair(p.requestID)
|
||||
}
|
||||
|
||||
// hasStreamPair returns a bool indicating if a stream pair for requestID
|
||||
// exists.
|
||||
func (h *httpStreamHandler) hasStreamPair(requestID string) bool {
|
||||
h.streamPairsLock.RLock()
|
||||
defer h.streamPairsLock.RUnlock()
|
||||
|
||||
_, ok := h.streamPairs[requestID]
|
||||
return ok
|
||||
}
|
||||
|
||||
// removeStreamPair removes the stream pair identified by requestID from streamPairs.
|
||||
func (h *httpStreamHandler) removeStreamPair(requestID string) {
|
||||
h.streamPairsLock.Lock()
|
||||
defer h.streamPairsLock.Unlock()
|
||||
|
||||
delete(h.streamPairs, requestID)
|
||||
}
|
||||
|
||||
// requestID returns the request id for stream.
|
||||
func (h *httpStreamHandler) requestID(stream httpstream.Stream) string {
|
||||
requestID := stream.Headers().Get(api.PortForwardRequestIDHeader)
|
||||
if len(requestID) == 0 {
|
||||
glog.V(5).Infof("(conn=%p) stream received without %s header", h.conn, api.PortForwardRequestIDHeader)
|
||||
// If we get here, it's because the connection came from an older client
|
||||
// that isn't generating the request id header
|
||||
// (https://github.com/kubernetes/kubernetes/blob/843134885e7e0b360eb5441e85b1410a8b1a7a0c/pkg/client/unversioned/portforward/portforward.go#L258-L287)
|
||||
//
|
||||
// This is a best-effort attempt at supporting older clients.
|
||||
//
|
||||
// When there aren't concurrent new forwarded connections, each connection
|
||||
// will have a pair of streams (data, error), and the stream IDs will be
|
||||
// consecutive odd numbers, e.g. 1 and 3 for the first connection. Convert
|
||||
// the stream ID into a pseudo-request id by taking the stream type and
|
||||
// using id = stream.Identifier() when the stream type is error,
|
||||
// and id = stream.Identifier() - 2 when it's data.
|
||||
//
|
||||
// NOTE: this only works when there are not concurrent new streams from
|
||||
// multiple forwarded connections; it's a best-effort attempt at supporting
|
||||
// old clients that don't generate request ids. If there are concurrent
|
||||
// new connections, it's possible that 1 connection gets streams whose IDs
|
||||
// are not consecutive (e.g. 5 and 9 instead of 5 and 7).
|
||||
streamType := stream.Headers().Get(api.StreamType)
|
||||
switch streamType {
|
||||
case api.StreamTypeError:
|
||||
requestID = strconv.Itoa(int(stream.Identifier()))
|
||||
case api.StreamTypeData:
|
||||
requestID = strconv.Itoa(int(stream.Identifier()) - 2)
|
||||
}
|
||||
|
||||
glog.V(5).Infof("(conn=%p) automatically assigning request ID=%q from stream type=%s, stream ID=%d", h.conn, requestID, streamType, stream.Identifier())
|
||||
}
|
||||
return requestID
|
||||
}
|
||||
|
||||
// run is the main loop for the httpStreamHandler. It processes new
|
||||
// streams, invoking portForward for each complete stream pair. The loop exits
|
||||
// when the httpstream.Connection is closed.
|
||||
func (h *httpStreamHandler) run() {
|
||||
glog.V(5).Infof("(conn=%p) waiting for port forward streams", h.conn)
|
||||
Loop:
|
||||
for {
|
||||
select {
|
||||
case <-h.conn.CloseChan():
|
||||
glog.V(5).Infof("(conn=%p) upgraded connection closed", h.conn)
|
||||
break Loop
|
||||
case stream := <-h.streamChan:
|
||||
requestID := h.requestID(stream)
|
||||
streamType := stream.Headers().Get(api.StreamType)
|
||||
glog.V(5).Infof("(conn=%p, request=%s) received new stream of type %s", h.conn, requestID, streamType)
|
||||
|
||||
p, created := h.getStreamPair(requestID)
|
||||
if created {
|
||||
go h.monitorStreamPair(p, time.After(h.streamCreationTimeout))
|
||||
}
|
||||
if complete, err := p.add(stream); err != nil {
|
||||
msg := fmt.Sprintf("error processing stream for request %s: %v", requestID, err)
|
||||
utilruntime.HandleError(errors.New(msg))
|
||||
p.printError(msg)
|
||||
} else if complete {
|
||||
go h.portForward(p)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// portForward invokes the httpStreamHandler's forwarder.PortForward
|
||||
// function for the given stream pair.
|
||||
func (h *httpStreamHandler) portForward(p *httpStreamPair) {
|
||||
defer p.dataStream.Close()
|
||||
defer p.errorStream.Close()
|
||||
|
||||
portString := p.dataStream.Headers().Get(api.PortHeader)
|
||||
port, _ := strconv.ParseUint(portString, 10, 16)
|
||||
|
||||
glog.V(5).Infof("(conn=%p, request=%s) invoking forwarder.PortForward for port %s", h.conn, p.requestID, portString)
|
||||
err := h.forwarder.PortForward(h.pod, h.uid, uint16(port), p.dataStream)
|
||||
glog.V(5).Infof("(conn=%p, request=%s) done invoking forwarder.PortForward for port %s", h.conn, p.requestID, portString)
|
||||
|
||||
if err != nil {
|
||||
msg := fmt.Errorf("error forwarding port %d to pod %s, uid %v: %v", port, h.pod, h.uid, err)
|
||||
utilruntime.HandleError(msg)
|
||||
fmt.Fprint(p.errorStream, msg.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// httpStreamPair represents the error and data streams for a port
|
||||
// forwarding request.
|
||||
type httpStreamPair struct {
|
||||
lock sync.RWMutex
|
||||
requestID string
|
||||
dataStream httpstream.Stream
|
||||
errorStream httpstream.Stream
|
||||
complete chan struct{}
|
||||
}
|
||||
|
||||
// newPortForwardPair creates a new httpStreamPair.
|
||||
func newPortForwardPair(requestID string) *httpStreamPair {
|
||||
return &httpStreamPair{
|
||||
requestID: requestID,
|
||||
complete: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// add adds the stream to the httpStreamPair. If the pair already
|
||||
// contains a stream for the new stream's type, an error is returned. add
|
||||
// returns true if both the data and error streams for this pair have been
|
||||
// received.
|
||||
func (p *httpStreamPair) add(stream httpstream.Stream) (bool, error) {
|
||||
p.lock.Lock()
|
||||
defer p.lock.Unlock()
|
||||
|
||||
switch stream.Headers().Get(api.StreamType) {
|
||||
case api.StreamTypeError:
|
||||
if p.errorStream != nil {
|
||||
return false, errors.New("error stream already assigned")
|
||||
}
|
||||
p.errorStream = stream
|
||||
case api.StreamTypeData:
|
||||
if p.dataStream != nil {
|
||||
return false, errors.New("data stream already assigned")
|
||||
}
|
||||
p.dataStream = stream
|
||||
}
|
||||
|
||||
complete := p.errorStream != nil && p.dataStream != nil
|
||||
if complete {
|
||||
close(p.complete)
|
||||
}
|
||||
return complete, nil
|
||||
}
|
||||
|
||||
// printError writes s to p.errorStream if p.errorStream has been set.
|
||||
func (p *httpStreamPair) printError(s string) {
|
||||
p.lock.RLock()
|
||||
defer p.lock.RUnlock()
|
||||
if p.errorStream != nil {
|
||||
fmt.Fprint(p.errorStream, s)
|
||||
}
|
||||
}
|
@ -25,7 +25,7 @@ import (
|
||||
"k8s.io/kubernetes/pkg/api"
|
||||
)
|
||||
|
||||
func TestPortForwardStreamReceived(t *testing.T) {
|
||||
func TestHTTPStreamReceived(t *testing.T) {
|
||||
tests := map[string]struct {
|
||||
port string
|
||||
streamType string
|
||||
@ -62,7 +62,7 @@ func TestPortForwardStreamReceived(t *testing.T) {
|
||||
}
|
||||
for name, test := range tests {
|
||||
streams := make(chan httpstream.Stream, 1)
|
||||
f := portForwardStreamReceived(streams)
|
||||
f := httpStreamReceived(streams)
|
||||
stream := newFakeHttpStream()
|
||||
if len(test.port) > 0 {
|
||||
stream.headers.Set("port", test.port)
|
||||
@ -92,48 +92,11 @@ func TestPortForwardStreamReceived(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
type fakeHttpStream struct {
|
||||
headers http.Header
|
||||
id uint32
|
||||
}
|
||||
|
||||
func newFakeHttpStream() *fakeHttpStream {
|
||||
return &fakeHttpStream{
|
||||
headers: make(http.Header),
|
||||
}
|
||||
}
|
||||
|
||||
var _ httpstream.Stream = &fakeHttpStream{}
|
||||
|
||||
func (s *fakeHttpStream) Read(data []byte) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *fakeHttpStream) Write(data []byte) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *fakeHttpStream) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *fakeHttpStream) Reset() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *fakeHttpStream) Headers() http.Header {
|
||||
return s.headers
|
||||
}
|
||||
|
||||
func (s *fakeHttpStream) Identifier() uint32 {
|
||||
return s.id
|
||||
}
|
||||
|
||||
func TestGetStreamPair(t *testing.T) {
|
||||
timeout := make(chan time.Time)
|
||||
|
||||
h := &portForwardStreamHandler{
|
||||
streamPairs: make(map[string]*portForwardStreamPair),
|
||||
h := &httpStreamHandler{
|
||||
streamPairs: make(map[string]*httpStreamPair),
|
||||
}
|
||||
|
||||
// test adding a new entry
|
||||
@ -223,7 +186,7 @@ func TestGetStreamPair(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRequestID(t *testing.T) {
|
||||
h := &portForwardStreamHandler{}
|
||||
h := &httpStreamHandler{}
|
||||
|
||||
s := newFakeHttpStream()
|
||||
s.headers.Set(api.StreamType, api.StreamTypeError)
|
||||
@ -244,3 +207,40 @@ func TestRequestID(t *testing.T) {
|
||||
t.Errorf("expected %q, got %q", e, a)
|
||||
}
|
||||
}
|
||||
|
||||
type fakeHttpStream struct {
|
||||
headers http.Header
|
||||
id uint32
|
||||
}
|
||||
|
||||
func newFakeHttpStream() *fakeHttpStream {
|
||||
return &fakeHttpStream{
|
||||
headers: make(http.Header),
|
||||
}
|
||||
}
|
||||
|
||||
var _ httpstream.Stream = &fakeHttpStream{}
|
||||
|
||||
func (s *fakeHttpStream) Read(data []byte) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *fakeHttpStream) Write(data []byte) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *fakeHttpStream) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *fakeHttpStream) Reset() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *fakeHttpStream) Headers() http.Header {
|
||||
return s.headers
|
||||
}
|
||||
|
||||
func (s *fakeHttpStream) Identifier() uint32 {
|
||||
return s.id
|
||||
}
|
@ -17,21 +17,13 @@ limitations under the License.
|
||||
package portforward
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/golang/glog"
|
||||
|
||||
"k8s.io/apimachinery/pkg/types"
|
||||
"k8s.io/apimachinery/pkg/util/httpstream"
|
||||
"k8s.io/apimachinery/pkg/util/httpstream/spdy"
|
||||
utilruntime "k8s.io/apimachinery/pkg/util/runtime"
|
||||
"k8s.io/kubernetes/pkg/api"
|
||||
"k8s.io/apimachinery/pkg/util/runtime"
|
||||
"k8s.io/apiserver/pkg/util/wsstream"
|
||||
)
|
||||
|
||||
// PortForwarder knows how to forward content from a data stream to/from a port
|
||||
@ -46,278 +38,16 @@ type PortForwarder interface {
|
||||
// been timed out due to idleness. This function handles multiple forwarded
|
||||
// connections; i.e., multiple `curl http://localhost:8888/` requests will be
|
||||
// handled by a single invocation of ServePortForward.
|
||||
func ServePortForward(w http.ResponseWriter, req *http.Request, portForwarder PortForwarder, podName string, uid types.UID, idleTimeout time.Duration, streamCreationTimeout time.Duration) {
|
||||
supportedPortForwardProtocols := []string{PortForwardProtocolV1Name}
|
||||
_, err := httpstream.Handshake(req, w, supportedPortForwardProtocols)
|
||||
// negotiated protocol isn't currently used server side, but could be in the future
|
||||
if err != nil {
|
||||
// Handshake writes the error to the client
|
||||
utilruntime.HandleError(err)
|
||||
return
|
||||
func ServePortForward(w http.ResponseWriter, req *http.Request, portForwarder PortForwarder, podName string, uid types.UID, idleTimeout time.Duration, streamCreationTimeout time.Duration, supportedProtocols []string) {
|
||||
var err error
|
||||
if wsstream.IsWebSocketRequest(req) {
|
||||
err = handleWebSocketStreams(req, w, portForwarder, podName, uid, supportedProtocols, idleTimeout, streamCreationTimeout)
|
||||
} else {
|
||||
err = handleHttpStreams(req, w, portForwarder, podName, uid, supportedProtocols, idleTimeout, streamCreationTimeout)
|
||||
}
|
||||
|
||||
streamChan := make(chan httpstream.Stream, 1)
|
||||
|
||||
glog.V(5).Infof("Upgrading port forward response")
|
||||
upgrader := spdy.NewResponseUpgrader()
|
||||
conn := upgrader.UpgradeResponse(w, req, portForwardStreamReceived(streamChan))
|
||||
if conn == nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
glog.V(5).Infof("(conn=%p) setting port forwarding streaming connection idle timeout to %v", conn, idleTimeout)
|
||||
conn.SetIdleTimeout(idleTimeout)
|
||||
|
||||
h := &portForwardStreamHandler{
|
||||
conn: conn,
|
||||
streamChan: streamChan,
|
||||
streamPairs: make(map[string]*portForwardStreamPair),
|
||||
streamCreationTimeout: streamCreationTimeout,
|
||||
pod: podName,
|
||||
uid: uid,
|
||||
forwarder: portForwarder,
|
||||
}
|
||||
h.run()
|
||||
}
|
||||
|
||||
// portForwardStreamReceived is the httpstream.NewStreamHandler for port
|
||||
// forward streams. It checks each stream's port and stream type headers,
|
||||
// rejecting any streams that with missing or invalid values. Each valid
|
||||
// stream is sent to the streams channel.
|
||||
func portForwardStreamReceived(streams chan httpstream.Stream) func(httpstream.Stream, <-chan struct{}) error {
|
||||
return func(stream httpstream.Stream, replySent <-chan struct{}) error {
|
||||
// make sure it has a valid port header
|
||||
portString := stream.Headers().Get(api.PortHeader)
|
||||
if len(portString) == 0 {
|
||||
return fmt.Errorf("%q header is required", api.PortHeader)
|
||||
}
|
||||
port, err := strconv.ParseUint(portString, 10, 16)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to parse %q as a port: %v", portString, err)
|
||||
}
|
||||
if port < 1 {
|
||||
return fmt.Errorf("port %q must be > 0", portString)
|
||||
}
|
||||
|
||||
// make sure it has a valid stream type header
|
||||
streamType := stream.Headers().Get(api.StreamType)
|
||||
if len(streamType) == 0 {
|
||||
return fmt.Errorf("%q header is required", api.StreamType)
|
||||
}
|
||||
if streamType != api.StreamTypeError && streamType != api.StreamTypeData {
|
||||
return fmt.Errorf("invalid stream type %q", streamType)
|
||||
}
|
||||
|
||||
streams <- stream
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// portForwardStreamHandler is capable of processing multiple port forward
|
||||
// requests over a single httpstream.Connection.
|
||||
type portForwardStreamHandler struct {
|
||||
conn httpstream.Connection
|
||||
streamChan chan httpstream.Stream
|
||||
streamPairsLock sync.RWMutex
|
||||
streamPairs map[string]*portForwardStreamPair
|
||||
streamCreationTimeout time.Duration
|
||||
pod string
|
||||
uid types.UID
|
||||
forwarder PortForwarder
|
||||
}
|
||||
|
||||
// getStreamPair returns a portForwardStreamPair for requestID. This creates a
|
||||
// new pair if one does not yet exist for the requestID. The returned bool is
|
||||
// true if the pair was created.
|
||||
func (h *portForwardStreamHandler) getStreamPair(requestID string) (*portForwardStreamPair, bool) {
|
||||
h.streamPairsLock.Lock()
|
||||
defer h.streamPairsLock.Unlock()
|
||||
|
||||
if p, ok := h.streamPairs[requestID]; ok {
|
||||
glog.V(5).Infof("(conn=%p, request=%s) found existing stream pair", h.conn, requestID)
|
||||
return p, false
|
||||
}
|
||||
|
||||
glog.V(5).Infof("(conn=%p, request=%s) creating new stream pair", h.conn, requestID)
|
||||
|
||||
p := newPortForwardPair(requestID)
|
||||
h.streamPairs[requestID] = p
|
||||
|
||||
return p, true
|
||||
}
|
||||
|
||||
// monitorStreamPair waits for the pair to receive both its error and data
|
||||
// streams, or for the timeout to expire (whichever happens first), and then
|
||||
// removes the pair.
|
||||
func (h *portForwardStreamHandler) monitorStreamPair(p *portForwardStreamPair, timeout <-chan time.Time) {
|
||||
select {
|
||||
case <-timeout:
|
||||
err := fmt.Errorf("(conn=%v, request=%s) timed out waiting for streams", h.conn, p.requestID)
|
||||
utilruntime.HandleError(err)
|
||||
p.printError(err.Error())
|
||||
case <-p.complete:
|
||||
glog.V(5).Infof("(conn=%v, request=%s) successfully received error and data streams", h.conn, p.requestID)
|
||||
}
|
||||
h.removeStreamPair(p.requestID)
|
||||
}
|
||||
|
||||
// hasStreamPair returns a bool indicating if a stream pair for requestID
|
||||
// exists.
|
||||
func (h *portForwardStreamHandler) hasStreamPair(requestID string) bool {
|
||||
h.streamPairsLock.RLock()
|
||||
defer h.streamPairsLock.RUnlock()
|
||||
|
||||
_, ok := h.streamPairs[requestID]
|
||||
return ok
|
||||
}
|
||||
|
||||
// removeStreamPair removes the stream pair identified by requestID from streamPairs.
|
||||
func (h *portForwardStreamHandler) removeStreamPair(requestID string) {
|
||||
h.streamPairsLock.Lock()
|
||||
defer h.streamPairsLock.Unlock()
|
||||
|
||||
delete(h.streamPairs, requestID)
|
||||
}
|
||||
|
||||
// requestID returns the request id for stream.
|
||||
func (h *portForwardStreamHandler) requestID(stream httpstream.Stream) string {
|
||||
requestID := stream.Headers().Get(api.PortForwardRequestIDHeader)
|
||||
if len(requestID) == 0 {
|
||||
glog.V(5).Infof("(conn=%p) stream received without %s header", h.conn, api.PortForwardRequestIDHeader)
|
||||
// If we get here, it's because the connection came from an older client
|
||||
// that isn't generating the request id header
|
||||
// (https://github.com/kubernetes/kubernetes/blob/843134885e7e0b360eb5441e85b1410a8b1a7a0c/pkg/client/unversioned/portforward/portforward.go#L258-L287)
|
||||
//
|
||||
// This is a best-effort attempt at supporting older clients.
|
||||
//
|
||||
// When there aren't concurrent new forwarded connections, each connection
|
||||
// will have a pair of streams (data, error), and the stream IDs will be
|
||||
// consecutive odd numbers, e.g. 1 and 3 for the first connection. Convert
|
||||
// the stream ID into a pseudo-request id by taking the stream type and
|
||||
// using id = stream.Identifier() when the stream type is error,
|
||||
// and id = stream.Identifier() - 2 when it's data.
|
||||
//
|
||||
// NOTE: this only works when there are not concurrent new streams from
|
||||
// multiple forwarded connections; it's a best-effort attempt at supporting
|
||||
// old clients that don't generate request ids. If there are concurrent
|
||||
// new connections, it's possible that 1 connection gets streams whose IDs
|
||||
// are not consecutive (e.g. 5 and 9 instead of 5 and 7).
|
||||
streamType := stream.Headers().Get(api.StreamType)
|
||||
switch streamType {
|
||||
case api.StreamTypeError:
|
||||
requestID = strconv.Itoa(int(stream.Identifier()))
|
||||
case api.StreamTypeData:
|
||||
requestID = strconv.Itoa(int(stream.Identifier()) - 2)
|
||||
}
|
||||
|
||||
glog.V(5).Infof("(conn=%p) automatically assigning request ID=%q from stream type=%s, stream ID=%d", h.conn, requestID, streamType, stream.Identifier())
|
||||
}
|
||||
return requestID
|
||||
}
|
||||
|
||||
// run is the main loop for the portForwardStreamHandler. It processes new
|
||||
// streams, invoking portForward for each complete stream pair. The loop exits
|
||||
// when the httpstream.Connection is closed.
|
||||
func (h *portForwardStreamHandler) run() {
|
||||
glog.V(5).Infof("(conn=%p) waiting for port forward streams", h.conn)
|
||||
Loop:
|
||||
for {
|
||||
select {
|
||||
case <-h.conn.CloseChan():
|
||||
glog.V(5).Infof("(conn=%p) upgraded connection closed", h.conn)
|
||||
break Loop
|
||||
case stream := <-h.streamChan:
|
||||
requestID := h.requestID(stream)
|
||||
streamType := stream.Headers().Get(api.StreamType)
|
||||
glog.V(5).Infof("(conn=%p, request=%s) received new stream of type %s", h.conn, requestID, streamType)
|
||||
|
||||
p, created := h.getStreamPair(requestID)
|
||||
if created {
|
||||
go h.monitorStreamPair(p, time.After(h.streamCreationTimeout))
|
||||
}
|
||||
if complete, err := p.add(stream); err != nil {
|
||||
msg := fmt.Sprintf("error processing stream for request %s: %v", requestID, err)
|
||||
utilruntime.HandleError(errors.New(msg))
|
||||
p.printError(msg)
|
||||
} else if complete {
|
||||
go h.portForward(p)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// portForward invokes the portForwardStreamHandler's forwarder.PortForward
|
||||
// function for the given stream pair.
|
||||
func (h *portForwardStreamHandler) portForward(p *portForwardStreamPair) {
|
||||
defer p.dataStream.Close()
|
||||
defer p.errorStream.Close()
|
||||
|
||||
portString := p.dataStream.Headers().Get(api.PortHeader)
|
||||
port, _ := strconv.ParseUint(portString, 10, 16)
|
||||
|
||||
glog.V(5).Infof("(conn=%p, request=%s) invoking forwarder.PortForward for port %s", h.conn, p.requestID, portString)
|
||||
err := h.forwarder.PortForward(h.pod, h.uid, uint16(port), p.dataStream)
|
||||
glog.V(5).Infof("(conn=%p, request=%s) done invoking forwarder.PortForward for port %s", h.conn, p.requestID, portString)
|
||||
|
||||
if err != nil {
|
||||
msg := fmt.Errorf("error forwarding port %d to pod %s, uid %v: %v", port, h.pod, h.uid, err)
|
||||
utilruntime.HandleError(msg)
|
||||
fmt.Fprint(p.errorStream, msg.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// portForwardStreamPair represents the error and data streams for a port
|
||||
// forwarding request.
|
||||
type portForwardStreamPair struct {
|
||||
lock sync.RWMutex
|
||||
requestID string
|
||||
dataStream httpstream.Stream
|
||||
errorStream httpstream.Stream
|
||||
complete chan struct{}
|
||||
}
|
||||
|
||||
// newPortForwardPair creates a new portForwardStreamPair.
|
||||
func newPortForwardPair(requestID string) *portForwardStreamPair {
|
||||
return &portForwardStreamPair{
|
||||
requestID: requestID,
|
||||
complete: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// add adds the stream to the portForwardStreamPair. If the pair already
|
||||
// contains a stream for the new stream's type, an error is returned. add
|
||||
// returns true if both the data and error streams for this pair have been
|
||||
// received.
|
||||
func (p *portForwardStreamPair) add(stream httpstream.Stream) (bool, error) {
|
||||
p.lock.Lock()
|
||||
defer p.lock.Unlock()
|
||||
|
||||
switch stream.Headers().Get(api.StreamType) {
|
||||
case api.StreamTypeError:
|
||||
if p.errorStream != nil {
|
||||
return false, errors.New("error stream already assigned")
|
||||
}
|
||||
p.errorStream = stream
|
||||
case api.StreamTypeData:
|
||||
if p.dataStream != nil {
|
||||
return false, errors.New("data stream already assigned")
|
||||
}
|
||||
p.dataStream = stream
|
||||
}
|
||||
|
||||
complete := p.errorStream != nil && p.dataStream != nil
|
||||
if complete {
|
||||
close(p.complete)
|
||||
}
|
||||
return complete, nil
|
||||
}
|
||||
|
||||
// printError writes s to p.errorStream if p.errorStream has been set.
|
||||
func (p *portForwardStreamPair) printError(s string) {
|
||||
p.lock.RLock()
|
||||
defer p.lock.RUnlock()
|
||||
if p.errorStream != nil {
|
||||
fmt.Fprint(p.errorStream, s)
|
||||
runtime.HandleError(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
189
pkg/kubelet/server/portforward/websocket.go
Normal file
189
pkg/kubelet/server/portforward/websocket.go
Normal file
@ -0,0 +1,189 @@
|
||||
/*
|
||||
Copyright 2016 The Kubernetes Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package portforward
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/golang/glog"
|
||||
|
||||
"k8s.io/apimachinery/pkg/types"
|
||||
"k8s.io/apimachinery/pkg/util/runtime"
|
||||
"k8s.io/apiserver/pkg/server/httplog"
|
||||
"k8s.io/apiserver/pkg/util/wsstream"
|
||||
"k8s.io/kubernetes/pkg/api"
|
||||
)
|
||||
|
||||
const (
|
||||
dataChannel = iota
|
||||
errorChannel
|
||||
|
||||
v4BinaryWebsocketProtocol = "v4." + wsstream.ChannelWebSocketProtocol
|
||||
v4Base64WebsocketProtocol = "v4." + wsstream.Base64ChannelWebSocketProtocol
|
||||
)
|
||||
|
||||
// options contains details about which streams are required for
|
||||
// port forwarding.
|
||||
type v4Options struct {
|
||||
ports []uint16
|
||||
}
|
||||
|
||||
// newOptions creates a new options from the Request.
|
||||
func newV4Options(req *http.Request) (*v4Options, error) {
|
||||
portStrings := req.URL.Query()[api.PortHeader]
|
||||
if len(portStrings) == 0 {
|
||||
return nil, fmt.Errorf("%q is required", api.PortHeader)
|
||||
}
|
||||
|
||||
ports := make([]uint16, 0, len(portStrings))
|
||||
for _, portString := range portStrings {
|
||||
if len(portString) == 0 {
|
||||
return nil, fmt.Errorf("%q is cannot be empty", api.PortHeader)
|
||||
}
|
||||
for _, p := range strings.Split(portString, ",") {
|
||||
port, err := strconv.ParseUint(p, 10, 16)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to parse %q as a port: %v", portString, err)
|
||||
}
|
||||
if port < 1 {
|
||||
return nil, fmt.Errorf("port %q must be > 0", portString)
|
||||
}
|
||||
ports = append(ports, uint16(port))
|
||||
}
|
||||
}
|
||||
|
||||
return &v4Options{
|
||||
ports: ports,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func handleWebSocketStreams(req *http.Request, w http.ResponseWriter, portForwarder PortForwarder, podName string, uid types.UID, supportedPortForwardProtocols []string, idleTimeout, streamCreationTimeout time.Duration) error {
|
||||
opts, err := newV4Options(req)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
fmt.Fprint(w, err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
channels := make([]wsstream.ChannelType, 0, len(opts.ports)*2)
|
||||
for i := 0; i < len(opts.ports); i++ {
|
||||
channels = append(channels, wsstream.ReadWriteChannel, wsstream.WriteChannel)
|
||||
}
|
||||
conn := wsstream.NewConn(map[string]wsstream.ChannelProtocolConfig{
|
||||
"": {
|
||||
Binary: true,
|
||||
Channels: channels,
|
||||
},
|
||||
v4BinaryWebsocketProtocol: {
|
||||
Binary: true,
|
||||
Channels: channels,
|
||||
},
|
||||
v4Base64WebsocketProtocol: {
|
||||
Binary: false,
|
||||
Channels: channels,
|
||||
},
|
||||
})
|
||||
conn.SetIdleTimeout(idleTimeout)
|
||||
_, streams, err := conn.Open(httplog.Unlogged(w), req)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("Unable to upgrade websocket connection: %v", err)
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
streamPairs := make([]*websocketStreamPair, len(opts.ports))
|
||||
for i := range streamPairs {
|
||||
streamPair := websocketStreamPair{
|
||||
port: opts.ports[i],
|
||||
dataStream: streams[i*2+dataChannel],
|
||||
errorStream: streams[i*2+errorChannel],
|
||||
}
|
||||
streamPairs[i] = &streamPair
|
||||
|
||||
portBytes := make([]byte, 2)
|
||||
binary.LittleEndian.PutUint16(portBytes, streamPair.port)
|
||||
streamPair.dataStream.Write(portBytes)
|
||||
streamPair.errorStream.Write(portBytes)
|
||||
}
|
||||
h := &websocketStreamHandler{
|
||||
conn: conn,
|
||||
streamPairs: streamPairs,
|
||||
pod: podName,
|
||||
uid: uid,
|
||||
forwarder: portForwarder,
|
||||
}
|
||||
h.run()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// websocketStreamPair represents the error and data streams for a port
|
||||
// forwarding request.
|
||||
type websocketStreamPair struct {
|
||||
port uint16
|
||||
dataStream io.ReadWriteCloser
|
||||
errorStream io.WriteCloser
|
||||
}
|
||||
|
||||
// websocketStreamHandler is capable of processing a single port forward
|
||||
// request over a websocket connection
|
||||
type websocketStreamHandler struct {
|
||||
conn *wsstream.Conn
|
||||
ports []uint16
|
||||
streamPairs []*websocketStreamPair
|
||||
pod string
|
||||
uid types.UID
|
||||
forwarder PortForwarder
|
||||
}
|
||||
|
||||
// run invokes the websocketStreamHandler's forwarder.PortForward
|
||||
// function for the given stream pair.
|
||||
func (h *websocketStreamHandler) run() {
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(len(h.streamPairs))
|
||||
|
||||
for _, pair := range h.streamPairs {
|
||||
p := pair
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
h.portForward(p)
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func (h *websocketStreamHandler) portForward(p *websocketStreamPair) {
|
||||
defer p.dataStream.Close()
|
||||
defer p.errorStream.Close()
|
||||
|
||||
glog.V(5).Infof("(conn=%p) invoking forwarder.PortForward for port %d", h.conn, p.port)
|
||||
err := h.forwarder.PortForward(h.pod, h.uid, p.port, p.dataStream)
|
||||
glog.V(5).Infof("(conn=%p) done invoking forwarder.PortForward for port %d", h.conn, p.port)
|
||||
|
||||
if err != nil {
|
||||
msg := fmt.Errorf("error forwarding port %d to pod %s, uid %v: %v", p.port, h.pod, h.uid, err)
|
||||
runtime.HandleError(msg)
|
||||
fmt.Fprint(p.errorStream, msg.Error())
|
||||
}
|
||||
}
|
@ -335,9 +335,15 @@ func (s *Server) InstallDebuggingHandlers(criHandler http.Handler) {
|
||||
ws = new(restful.WebService)
|
||||
ws.
|
||||
Path("/portForward")
|
||||
ws.Route(ws.GET("/{podNamespace}/{podID}").
|
||||
To(s.getPortForward).
|
||||
Operation("getPortForward"))
|
||||
ws.Route(ws.POST("/{podNamespace}/{podID}").
|
||||
To(s.getPortForward).
|
||||
Operation("getPortForward"))
|
||||
ws.Route(ws.GET("/{podNamespace}/{podID}/{uid}").
|
||||
To(s.getPortForward).
|
||||
Operation("getPortForward"))
|
||||
ws.Route(ws.POST("/{podNamespace}/{podID}/{uid}").
|
||||
To(s.getPortForward).
|
||||
Operation("getPortForward"))
|
||||
@ -720,7 +726,8 @@ func (s *Server) getPortForward(request *restful.Request, response *restful.Resp
|
||||
kubecontainer.GetPodFullName(pod),
|
||||
params.podUID,
|
||||
s.host.StreamingConnectionIdleTimeout(),
|
||||
remotecommand.DefaultStreamCreationTimeout)
|
||||
remotecommand.DefaultStreamCreationTimeout,
|
||||
portforward.SupportedProtocols)
|
||||
}
|
||||
|
||||
// ServeHTTP responds to HTTP requests on the Kubelet.
|
||||
|
331
pkg/kubelet/server/server_websocket_test.go
Normal file
331
pkg/kubelet/server/server_websocket_test.go
Normal file
@ -0,0 +1,331 @@
|
||||
/*
|
||||
Copyright 2016 The Kubernetes Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/websocket"
|
||||
|
||||
"k8s.io/apimachinery/pkg/types"
|
||||
)
|
||||
|
||||
const (
|
||||
dataChannel = iota
|
||||
errorChannel
|
||||
)
|
||||
|
||||
func TestServeWSPortForward(t *testing.T) {
|
||||
tests := []struct {
|
||||
port string
|
||||
uid bool
|
||||
clientData string
|
||||
containerData string
|
||||
shouldError bool
|
||||
}{
|
||||
{port: "", shouldError: true},
|
||||
{port: "abc", shouldError: true},
|
||||
{port: "-1", shouldError: true},
|
||||
{port: "65536", shouldError: true},
|
||||
{port: "0", shouldError: true},
|
||||
{port: "1", shouldError: false},
|
||||
{port: "8000", shouldError: false},
|
||||
{port: "8000", clientData: "client data", containerData: "container data", shouldError: false},
|
||||
{port: "65535", shouldError: false},
|
||||
{port: "65535", uid: true, shouldError: false},
|
||||
}
|
||||
|
||||
podNamespace := "other"
|
||||
podName := "foo"
|
||||
expectedPodName := getPodName(podName, podNamespace)
|
||||
expectedUid := "9b01b80f-8fb4-11e4-95ab-4200af06647"
|
||||
|
||||
for i, test := range tests {
|
||||
fw := newServerTest()
|
||||
defer fw.testHTTPServer.Close()
|
||||
|
||||
fw.fakeKubelet.streamingConnectionIdleTimeoutFunc = func() time.Duration {
|
||||
return 0
|
||||
}
|
||||
|
||||
portForwardFuncDone := make(chan struct{})
|
||||
|
||||
fw.fakeKubelet.portForwardFunc = func(name string, uid types.UID, port uint16, stream io.ReadWriteCloser) error {
|
||||
defer close(portForwardFuncDone)
|
||||
|
||||
if e, a := expectedPodName, name; e != a {
|
||||
t.Fatalf("%d: pod name: expected '%v', got '%v'", i, e, a)
|
||||
}
|
||||
|
||||
if e, a := expectedUid, uid; test.uid && e != string(a) {
|
||||
t.Fatalf("%d: uid: expected '%v', got '%v'", i, e, a)
|
||||
}
|
||||
|
||||
p, err := strconv.ParseUint(test.port, 10, 16)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: error parsing port string '%s': %v", i, test.port, err)
|
||||
}
|
||||
if e, a := uint16(p), port; e != a {
|
||||
t.Fatalf("%d: port: expected '%v', got '%v'", i, e, a)
|
||||
}
|
||||
|
||||
if test.clientData != "" {
|
||||
fromClient := make([]byte, 32)
|
||||
n, err := stream.Read(fromClient)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: error reading client data: %v", i, err)
|
||||
}
|
||||
if e, a := test.clientData, string(fromClient[0:n]); e != a {
|
||||
t.Fatalf("%d: client data: expected to receive '%v', got '%v'", i, e, a)
|
||||
}
|
||||
}
|
||||
|
||||
if test.containerData != "" {
|
||||
_, err := stream.Write([]byte(test.containerData))
|
||||
if err != nil {
|
||||
t.Fatalf("%d: error writing container data: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var url string
|
||||
if test.uid {
|
||||
url = fmt.Sprintf("ws://%s/portForward/%s/%s/%s?port=%s", fw.testHTTPServer.Listener.Addr().String(), podNamespace, podName, expectedUid, test.port)
|
||||
} else {
|
||||
url = fmt.Sprintf("ws://%s/portForward/%s/%s?port=%s", fw.testHTTPServer.Listener.Addr().String(), podNamespace, podName, test.port)
|
||||
}
|
||||
|
||||
ws, err := websocket.Dial(url, "", "http://127.0.0.1/")
|
||||
if test.shouldError {
|
||||
if err == nil {
|
||||
t.Fatalf("%d: websocket dial expected err", i)
|
||||
}
|
||||
continue
|
||||
} else if err != nil {
|
||||
t.Fatalf("%d: websocket dial unexpected err: %v", i, err)
|
||||
}
|
||||
|
||||
defer ws.Close()
|
||||
|
||||
p, err := strconv.ParseUint(test.port, 10, 16)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: error parsing port string '%s': %v", i, test.port, err)
|
||||
}
|
||||
p16 := uint16(p)
|
||||
|
||||
channel, data, err := wsRead(ws)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: read failed: expected no error: got %v", i, err)
|
||||
}
|
||||
if channel != dataChannel {
|
||||
t.Fatalf("%d: wrong channel: got %q: expected %q", i, channel, dataChannel)
|
||||
}
|
||||
if len(data) != binary.Size(p16) {
|
||||
t.Fatalf("%d: wrong data size: got %q: expected %d", i, data, binary.Size(p16))
|
||||
}
|
||||
if e, a := p16, binary.LittleEndian.Uint16(data); e != a {
|
||||
t.Fatalf("%d: wrong data: got %q: expected %s", i, data, test.port)
|
||||
}
|
||||
|
||||
channel, data, err = wsRead(ws)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: read succeeded: expected no error: got %v", i, err)
|
||||
}
|
||||
if channel != errorChannel {
|
||||
t.Fatalf("%d: wrong channel: got %q: expected %q", i, channel, errorChannel)
|
||||
}
|
||||
if len(data) != binary.Size(p16) {
|
||||
t.Fatalf("%d: wrong data size: got %q: expected %d", i, data, binary.Size(p16))
|
||||
}
|
||||
if e, a := p16, binary.LittleEndian.Uint16(data); e != a {
|
||||
t.Fatalf("%d: wrong data: got %q: expected %s", i, data, test.port)
|
||||
}
|
||||
|
||||
if test.clientData != "" {
|
||||
println("writing the client data")
|
||||
err := wsWrite(ws, dataChannel, []byte(test.clientData))
|
||||
if err != nil {
|
||||
t.Fatalf("%d: unexpected error writing client data: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
if test.containerData != "" {
|
||||
channel, data, err = wsRead(ws)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: unexpected error reading container data: %v", i, err)
|
||||
}
|
||||
if e, a := test.containerData, string(data); e != a {
|
||||
t.Fatalf("%d: expected to receive '%v' from container, got '%v'", i, e, a)
|
||||
}
|
||||
}
|
||||
|
||||
<-portForwardFuncDone
|
||||
}
|
||||
}
|
||||
|
||||
func TestServeWSMultiplePortForward(t *testing.T) {
|
||||
portsText := []string{"7000,8000", "9000"}
|
||||
ports := []uint16{7000, 8000, 9000}
|
||||
podNamespace := "other"
|
||||
podName := "foo"
|
||||
expectedPodName := getPodName(podName, podNamespace)
|
||||
|
||||
fw := newServerTest()
|
||||
defer fw.testHTTPServer.Close()
|
||||
|
||||
fw.fakeKubelet.streamingConnectionIdleTimeoutFunc = func() time.Duration {
|
||||
return 0
|
||||
}
|
||||
|
||||
portForwardWG := sync.WaitGroup{}
|
||||
portForwardWG.Add(len(ports))
|
||||
|
||||
portsMutex := sync.Mutex{}
|
||||
portsForwarded := map[uint16]struct{}{}
|
||||
|
||||
fw.fakeKubelet.portForwardFunc = func(name string, uid types.UID, port uint16, stream io.ReadWriteCloser) error {
|
||||
defer portForwardWG.Done()
|
||||
|
||||
if e, a := expectedPodName, name; e != a {
|
||||
t.Fatalf("%d: pod name: expected '%v', got '%v'", port, e, a)
|
||||
}
|
||||
|
||||
portsMutex.Lock()
|
||||
portsForwarded[port] = struct{}{}
|
||||
portsMutex.Unlock()
|
||||
|
||||
fromClient := make([]byte, 32)
|
||||
n, err := stream.Read(fromClient)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: error reading client data: %v", port, err)
|
||||
}
|
||||
if e, a := fmt.Sprintf("client data on port %d", port), string(fromClient[0:n]); e != a {
|
||||
t.Fatalf("%d: client data: expected to receive '%v', got '%v'", port, e, a)
|
||||
}
|
||||
|
||||
_, err = stream.Write([]byte(fmt.Sprintf("container data on port %d", port)))
|
||||
if err != nil {
|
||||
t.Fatalf("%d: error writing container data: %v", port, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("ws://%s/portForward/%s/%s?", fw.testHTTPServer.Listener.Addr().String(), podNamespace, podName)
|
||||
for _, port := range portsText {
|
||||
url = url + fmt.Sprintf("port=%s&", port)
|
||||
}
|
||||
|
||||
ws, err := websocket.Dial(url, "", "http://127.0.0.1/")
|
||||
if err != nil {
|
||||
t.Fatalf("websocket dial unexpected err: %v", err)
|
||||
}
|
||||
|
||||
defer ws.Close()
|
||||
|
||||
for i, port := range ports {
|
||||
channel, data, err := wsRead(ws)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: read failed: expected no error: got %v", i, err)
|
||||
}
|
||||
if int(channel) != i*2+dataChannel {
|
||||
t.Fatalf("%d: wrong channel: got %q: expected %q", i, channel, i*2+dataChannel)
|
||||
}
|
||||
if len(data) != binary.Size(port) {
|
||||
t.Fatalf("%d: wrong data size: got %q: expected %d", i, data, binary.Size(port))
|
||||
}
|
||||
if e, a := port, binary.LittleEndian.Uint16(data); e != a {
|
||||
t.Fatalf("%d: wrong data: got %q: expected %d", i, data, port)
|
||||
}
|
||||
|
||||
channel, data, err = wsRead(ws)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: read succeeded: expected no error: got %v", i, err)
|
||||
}
|
||||
if int(channel) != i*2+errorChannel {
|
||||
t.Fatalf("%d: wrong channel: got %q: expected %q", i, channel, i*2+errorChannel)
|
||||
}
|
||||
if len(data) != binary.Size(port) {
|
||||
t.Fatalf("%d: wrong data size: got %q: expected %d", i, data, binary.Size(port))
|
||||
}
|
||||
if e, a := port, binary.LittleEndian.Uint16(data); e != a {
|
||||
t.Fatalf("%d: wrong data: got %q: expected %d", i, data, port)
|
||||
}
|
||||
}
|
||||
|
||||
for i, port := range ports {
|
||||
println("writing the client data", port)
|
||||
err := wsWrite(ws, byte(i*2+dataChannel), []byte(fmt.Sprintf("client data on port %d", port)))
|
||||
if err != nil {
|
||||
t.Fatalf("%d: unexpected error writing client data: %v", i, err)
|
||||
}
|
||||
|
||||
channel, data, err := wsRead(ws)
|
||||
if err != nil {
|
||||
t.Fatalf("%d: unexpected error reading container data: %v", i, err)
|
||||
}
|
||||
|
||||
if int(channel) != i*2+dataChannel {
|
||||
t.Fatalf("%d: wrong channel: got %q: expected %q", port, channel, i*2+dataChannel)
|
||||
}
|
||||
if e, a := fmt.Sprintf("container data on port %d", port), string(data); e != a {
|
||||
t.Fatalf("%d: expected to receive '%v' from container, got '%v'", i, e, a)
|
||||
}
|
||||
}
|
||||
|
||||
portForwardWG.Wait()
|
||||
|
||||
portsMutex.Lock()
|
||||
defer portsMutex.Unlock()
|
||||
if len(ports) != len(portsForwarded) {
|
||||
t.Fatalf("expected to forward %d ports; got %v", len(ports), portsForwarded)
|
||||
}
|
||||
}
|
||||
func wsWrite(conn *websocket.Conn, channel byte, data []byte) error {
|
||||
frame := make([]byte, len(data)+1)
|
||||
frame[0] = channel
|
||||
copy(frame[1:], data)
|
||||
err := websocket.Message.Send(conn, frame)
|
||||
return err
|
||||
}
|
||||
|
||||
func wsRead(conn *websocket.Conn) (byte, []byte, error) {
|
||||
for {
|
||||
var data []byte
|
||||
err := websocket.Message.Receive(conn, &data)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
|
||||
if len(data) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
channel := data[0]
|
||||
data = data[1:]
|
||||
|
||||
return channel, data, err
|
||||
}
|
||||
}
|
@ -80,7 +80,12 @@ type Config struct {
|
||||
// The streaming protocols the server supports (understands and permits). See
|
||||
// k8s.io/kubernetes/pkg/kubelet/server/remotecommand/constants.go for available protocols.
|
||||
// Only used for SPDY streaming.
|
||||
SupportedProtocols []string
|
||||
SupportedRemoteCommandProtocols []string
|
||||
|
||||
// The streaming protocols the server supports (understands and permits). See
|
||||
// k8s.io/kubernetes/pkg/kubelet/server/portforward/constants.go for available protocols.
|
||||
// Only used for SPDY streaming.
|
||||
SupportedPortForwardProtocols []string
|
||||
|
||||
// The config for serving over TLS. If nil, TLS will not be used.
|
||||
TLSConfig *tls.Config
|
||||
@ -89,9 +94,10 @@ type Config struct {
|
||||
// DefaultConfig provides default values for server Config. The DefaultConfig is partial, so
|
||||
// some fields like Addr must still be provided.
|
||||
var DefaultConfig = Config{
|
||||
StreamIdleTimeout: 4 * time.Hour,
|
||||
StreamCreationTimeout: remotecommand.DefaultStreamCreationTimeout,
|
||||
SupportedProtocols: remotecommand.SupportedStreamingProtocols,
|
||||
StreamIdleTimeout: 4 * time.Hour,
|
||||
StreamCreationTimeout: remotecommand.DefaultStreamCreationTimeout,
|
||||
SupportedRemoteCommandProtocols: remotecommand.SupportedStreamingProtocols,
|
||||
SupportedPortForwardProtocols: portforward.SupportedProtocols,
|
||||
}
|
||||
|
||||
// TODO(timstclair): Add auth(n/z) interface & handling.
|
||||
@ -248,7 +254,7 @@ func (s *server) serveExec(req *restful.Request, resp *restful.Response) {
|
||||
streamOpts,
|
||||
s.config.StreamIdleTimeout,
|
||||
s.config.StreamCreationTimeout,
|
||||
s.config.SupportedProtocols)
|
||||
s.config.SupportedRemoteCommandProtocols)
|
||||
}
|
||||
|
||||
func (s *server) serveAttach(req *restful.Request, resp *restful.Response) {
|
||||
@ -280,7 +286,7 @@ func (s *server) serveAttach(req *restful.Request, resp *restful.Response) {
|
||||
streamOpts,
|
||||
s.config.StreamIdleTimeout,
|
||||
s.config.StreamCreationTimeout,
|
||||
s.config.SupportedProtocols)
|
||||
s.config.SupportedRemoteCommandProtocols)
|
||||
}
|
||||
|
||||
func (s *server) servePortForward(req *restful.Request, resp *restful.Response) {
|
||||
@ -303,7 +309,8 @@ func (s *server) servePortForward(req *restful.Request, resp *restful.Response)
|
||||
pf.PodSandboxId,
|
||||
"", // unused: podUID
|
||||
s.config.StreamIdleTimeout,
|
||||
s.config.StreamCreationTimeout)
|
||||
s.config.StreamCreationTimeout,
|
||||
s.config.SupportedPortForwardProtocols)
|
||||
}
|
||||
|
||||
// criAdapter wraps the Runtime functions to conform to the remotecommand interfaces.
|
||||
|
@ -240,7 +240,7 @@ func TestServePortForward(t *testing.T) {
|
||||
|
||||
exec, err := remotecommand.NewExecutor(&restclient.Config{}, "POST", reqURL)
|
||||
require.NoError(t, err)
|
||||
streamConn, _, err := exec.Dial(kubeletportforward.PortForwardProtocolV1Name)
|
||||
streamConn, _, err := exec.Dial(kubeletportforward.ProtocolV1Name)
|
||||
require.NoError(t, err)
|
||||
defer streamConn.Close()
|
||||
|
||||
|
@ -165,9 +165,10 @@ func (r *PortForwardREST) New() runtime.Object {
|
||||
return &api.Pod{}
|
||||
}
|
||||
|
||||
// NewConnectOptions returns nil since portforward doesn't take additional parameters
|
||||
// NewConnectOptions returns the versioned object that represents the
|
||||
// portforward parameters
|
||||
func (r *PortForwardREST) NewConnectOptions() (runtime.Object, bool, string) {
|
||||
return nil, false, ""
|
||||
return &api.PodPortForwardOptions{}, false, ""
|
||||
}
|
||||
|
||||
// ConnectMethods returns the methods supported by portforward
|
||||
@ -177,7 +178,11 @@ func (r *PortForwardREST) ConnectMethods() []string {
|
||||
|
||||
// Connect returns a handler for the pod portforward proxy
|
||||
func (r *PortForwardREST) Connect(ctx genericapirequest.Context, name string, opts runtime.Object, responder rest.Responder) (http.Handler, error) {
|
||||
location, transport, err := pod.PortForwardLocation(r.Store, r.KubeletConn, ctx, name)
|
||||
portForwardOpts, ok := opts.(*api.PodPortForwardOptions)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid options object: %#v", opts)
|
||||
}
|
||||
location, transport, err := pod.PortForwardLocation(r.Store, r.KubeletConn, ctx, name, portForwardOpts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -383,6 +383,15 @@ func streamParams(params url.Values, opts runtime.Object) error {
|
||||
if opts.TTY {
|
||||
params.Add(api.ExecTTYParam, "1")
|
||||
}
|
||||
case *api.PodPortForwardOptions:
|
||||
if len(opts.Ports) == 0 {
|
||||
return errors.NewBadRequest("at least one port must be specified")
|
||||
}
|
||||
ports := make([]string, len(opts.Ports))
|
||||
for i, p := range opts.Ports {
|
||||
ports[i] = strconv.FormatInt(int64(p), 10)
|
||||
}
|
||||
params.Add(api.PortHeader, strings.Join(ports, ","))
|
||||
default:
|
||||
return fmt.Errorf("Unknown object for streaming: %v", opts)
|
||||
}
|
||||
@ -477,6 +486,7 @@ func PortForwardLocation(
|
||||
connInfo client.ConnectionInfoGetter,
|
||||
ctx genericapirequest.Context,
|
||||
name string,
|
||||
opts *api.PodPortForwardOptions,
|
||||
) (*url.URL, http.RoundTripper, error) {
|
||||
pod, err := getPod(getter, ctx, name)
|
||||
if err != nil {
|
||||
@ -492,10 +502,15 @@ func PortForwardLocation(
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
params := url.Values{}
|
||||
if err := streamParams(params, opts); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
loc := &url.URL{
|
||||
Scheme: nodeInfo.Scheme,
|
||||
Host: net.JoinHostPort(nodeInfo.Hostname, nodeInfo.Port),
|
||||
Path: fmt.Sprintf("/portForward/%s/%s", pod.Namespace, pod.Name),
|
||||
Scheme: nodeInfo.Scheme,
|
||||
Host: net.JoinHostPort(nodeInfo.Hostname, nodeInfo.Port),
|
||||
Path: fmt.Sprintf("/portForward/%s/%s", pod.Namespace, pod.Name),
|
||||
RawQuery: params.Encode(),
|
||||
}
|
||||
return loc, nodeInfo.Transport, nil
|
||||
}
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
package pod
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
@ -26,9 +27,11 @@ import (
|
||||
"k8s.io/apimachinery/pkg/fields"
|
||||
"k8s.io/apimachinery/pkg/labels"
|
||||
"k8s.io/apimachinery/pkg/runtime"
|
||||
"k8s.io/apimachinery/pkg/types"
|
||||
genericapirequest "k8s.io/apiserver/pkg/endpoints/request"
|
||||
"k8s.io/kubernetes/pkg/api"
|
||||
apitesting "k8s.io/kubernetes/pkg/api/testing"
|
||||
"k8s.io/kubernetes/pkg/kubelet/client"
|
||||
)
|
||||
|
||||
func TestMatchPod(t *testing.T) {
|
||||
@ -333,3 +336,64 @@ func TestSelectableFieldLabelConversions(t *testing.T) {
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
type mockConnectionInfoGetter struct {
|
||||
info *client.ConnectionInfo
|
||||
}
|
||||
|
||||
func (g mockConnectionInfoGetter) GetConnectionInfo(nodeName types.NodeName) (*client.ConnectionInfo, error) {
|
||||
return g.info, nil
|
||||
}
|
||||
|
||||
func TestPortForwardLocation(t *testing.T) {
|
||||
ctx := genericapirequest.NewDefaultContext()
|
||||
tcs := []struct {
|
||||
in *api.Pod
|
||||
info *client.ConnectionInfo
|
||||
opts *api.PodPortForwardOptions
|
||||
expectedErr error
|
||||
expectedURL *url.URL
|
||||
}{
|
||||
{
|
||||
in: &api.Pod{
|
||||
Spec: api.PodSpec{},
|
||||
},
|
||||
opts: &api.PodPortForwardOptions{},
|
||||
expectedErr: errors.NewBadRequest("pod test does not have a host assigned"),
|
||||
},
|
||||
{
|
||||
in: &api.Pod{
|
||||
Spec: api.PodSpec{
|
||||
NodeName: "node1",
|
||||
},
|
||||
},
|
||||
opts: &api.PodPortForwardOptions{},
|
||||
expectedErr: errors.NewBadRequest("at least one port must be specified"),
|
||||
},
|
||||
{
|
||||
in: &api.Pod{
|
||||
ObjectMeta: api.ObjectMeta{
|
||||
Namespace: "ns",
|
||||
Name: "pod1",
|
||||
},
|
||||
Spec: api.PodSpec{
|
||||
NodeName: "node1",
|
||||
},
|
||||
},
|
||||
info: &client.ConnectionInfo{},
|
||||
opts: &api.PodPortForwardOptions{Ports: []int32{80}},
|
||||
expectedURL: &url.URL{Host: ":", Path: "/portForward/ns/pod1", RawQuery: "port=80"},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
getter := &mockPodGetter{tc.in}
|
||||
connectionGetter := &mockConnectionInfoGetter{tc.info}
|
||||
loc, _, err := PortForwardLocation(getter, connectionGetter, ctx, "test", tc.opts)
|
||||
if !reflect.DeepEqual(err, tc.expectedErr) {
|
||||
t.Errorf("expected %v, got %v", tc.expectedErr, err)
|
||||
}
|
||||
if !reflect.DeepEqual(loc, tc.expectedURL) {
|
||||
t.Errorf("expected %v, got %v", tc.expectedURL, loc)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -17,6 +17,8 @@ limitations under the License.
|
||||
package e2e
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
@ -28,6 +30,7 @@ import (
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/websocket"
|
||||
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
||||
"k8s.io/apimachinery/pkg/util/wait"
|
||||
"k8s.io/kubernetes/pkg/api/v1"
|
||||
@ -36,6 +39,7 @@ import (
|
||||
testutils "k8s.io/kubernetes/test/utils"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -368,36 +372,144 @@ func doTestMustConnectSendDisconnect(bindAddress string, f *framework.Framework)
|
||||
verifyLogMessage(logOutput, "^Done$")
|
||||
}
|
||||
|
||||
func doTestOverWebSockets(bindAddress string, f *framework.Framework) {
|
||||
config, err := framework.LoadConfig()
|
||||
Expect(err).NotTo(HaveOccurred(), "unable to get base config")
|
||||
|
||||
By("creating the pod")
|
||||
pod := pfPod("def", "10", "10", "100", fmt.Sprintf("%s", bindAddress))
|
||||
if _, err := f.ClientSet.Core().Pods(f.Namespace.Name).Create(pod); err != nil {
|
||||
framework.Failf("Couldn't create pod: %v", err)
|
||||
}
|
||||
if err := f.WaitForPodReady(pod.Name); err != nil {
|
||||
framework.Failf("Pod did not start running: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
logs, err := framework.GetPodLogs(f.ClientSet, f.Namespace.Name, pod.Name, "portforwardtester")
|
||||
if err != nil {
|
||||
framework.Logf("Error getting pod log: %v", err)
|
||||
} else {
|
||||
framework.Logf("Pod log:\n%s", logs)
|
||||
}
|
||||
}()
|
||||
|
||||
req := f.ClientSet.Core().RESTClient().Get().
|
||||
Namespace(f.Namespace.Name).
|
||||
Resource("pods").
|
||||
Name(pod.Name).
|
||||
Suffix("portforward").
|
||||
Param("ports", "80")
|
||||
|
||||
url := req.URL()
|
||||
ws, err := framework.OpenWebSocketForURL(url, config, []string{"v4.channel.k8s.io"})
|
||||
if err != nil {
|
||||
framework.Failf("Failed to open websocket to %s: %v", url.String(), err)
|
||||
}
|
||||
defer ws.Close()
|
||||
|
||||
Eventually(func() error {
|
||||
channel, msg, err := wsRead(ws)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to read completely from websocket %s: %v", url.String(), err)
|
||||
}
|
||||
if channel != 0 {
|
||||
return fmt.Errorf("Got message from server that didn't start with channel 0 (data): %v", msg)
|
||||
}
|
||||
if p := binary.LittleEndian.Uint16(msg); p != 80 {
|
||||
return fmt.Errorf("Received the wrong port: %d", p)
|
||||
}
|
||||
return nil
|
||||
}, time.Minute, 10*time.Second).Should(BeNil())
|
||||
|
||||
Eventually(func() error {
|
||||
channel, msg, err := wsRead(ws)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to read completely from websocket %s: %v", url.String(), err)
|
||||
}
|
||||
if channel != 1 {
|
||||
return fmt.Errorf("Got message from server that didn't start with channel 1 (error): %v", msg)
|
||||
}
|
||||
if p := binary.LittleEndian.Uint16(msg); p != 80 {
|
||||
return fmt.Errorf("Received the wrong port: %d", p)
|
||||
}
|
||||
return nil
|
||||
}, time.Minute, 10*time.Second).Should(BeNil())
|
||||
|
||||
By("sending the expected data to the local port")
|
||||
err = wsWrite(ws, 0, []byte("def"))
|
||||
if err != nil {
|
||||
framework.Failf("Failed to write to websocket %s: %v", url.String(), err)
|
||||
}
|
||||
|
||||
By("reading data from the local port")
|
||||
buf := bytes.Buffer{}
|
||||
expectedData := bytes.Repeat([]byte("x"), 100)
|
||||
Eventually(func() error {
|
||||
channel, msg, err := wsRead(ws)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to read completely from websocket %s: %v", url.String(), err)
|
||||
}
|
||||
if channel != 0 {
|
||||
return fmt.Errorf("Got message from server that didn't start with channel 0 (data): %v", msg)
|
||||
}
|
||||
buf.Write(msg)
|
||||
if bytes.Equal(expectedData, buf.Bytes()) {
|
||||
return fmt.Errorf("Expected %q from server, got %q", expectedData, buf.Bytes())
|
||||
}
|
||||
return nil
|
||||
}, time.Minute, 10*time.Second).Should(BeNil())
|
||||
|
||||
By("verifying logs")
|
||||
logOutput, err := framework.GetPodLogs(f.ClientSet, f.Namespace.Name, pod.Name, "portforwardtester")
|
||||
if err != nil {
|
||||
framework.Failf("Error retrieving pod logs: %v", err)
|
||||
}
|
||||
verifyLogMessage(logOutput, "^Accepted client connection$")
|
||||
verifyLogMessage(logOutput, "^Received expected client data$")
|
||||
}
|
||||
|
||||
var _ = framework.KubeDescribe("Port forwarding", func() {
|
||||
f := framework.NewDefaultFramework("port-forwarding")
|
||||
|
||||
framework.KubeDescribe("With a server listening on 0.0.0.0 that expects a client request", func() {
|
||||
It("should support a client that connects, sends no data, and disconnects", func() {
|
||||
doTestMustConnectSendNothing("0.0.0.0", f)
|
||||
framework.KubeDescribe("With a server listening on 0.0.0.0", func() {
|
||||
framework.KubeDescribe("that expects a client request", func() {
|
||||
It("should support a client that connects, sends no data, and disconnects", func() {
|
||||
doTestMustConnectSendNothing("0.0.0.0", f)
|
||||
})
|
||||
It("should support a client that connects, sends data, and disconnects", func() {
|
||||
doTestMustConnectSendDisconnect("0.0.0.0", f)
|
||||
})
|
||||
})
|
||||
It("should support a client that connects, sends data, and disconnects", func() {
|
||||
doTestMustConnectSendDisconnect("0.0.0.0", f)
|
||||
|
||||
framework.KubeDescribe("that expects no client request", func() {
|
||||
It("should support a client that connects, sends data, and disconnects", func() {
|
||||
doTestConnectSendDisconnect("0.0.0.0", f)
|
||||
})
|
||||
})
|
||||
|
||||
It("should support forwarding over websockets", func() {
|
||||
doTestOverWebSockets("0.0.0.0", f)
|
||||
})
|
||||
})
|
||||
|
||||
framework.KubeDescribe("With a server listening on 0.0.0.0 that expects no client request", func() {
|
||||
It("should support a client that connects, sends data, and disconnects", func() {
|
||||
doTestConnectSendDisconnect("0.0.0.0", f)
|
||||
framework.KubeDescribe("With a server listening on localhost", func() {
|
||||
framework.KubeDescribe("that expects a client request", func() {
|
||||
It("should support a client that connects, sends no data, and disconnects [Conformance]", func() {
|
||||
doTestMustConnectSendNothing("localhost", f)
|
||||
})
|
||||
It("should support a client that connects, sends data, and disconnects [Conformance]", func() {
|
||||
doTestMustConnectSendDisconnect("localhost", f)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
framework.KubeDescribe("With a server listening on localhost that expects a client request", func() {
|
||||
It("should support a client that connects, sends no data, and disconnects [Conformance]", func() {
|
||||
doTestMustConnectSendNothing("localhost", f)
|
||||
framework.KubeDescribe("that expects no client request", func() {
|
||||
It("should support a client that connects, sends data, and disconnects [Conformance]", func() {
|
||||
doTestConnectSendDisconnect("localhost", f)
|
||||
})
|
||||
})
|
||||
It("should support a client that connects, sends data, and disconnects [Conformance]", func() {
|
||||
doTestMustConnectSendDisconnect("localhost", f)
|
||||
})
|
||||
})
|
||||
|
||||
framework.KubeDescribe("With a server listening on localhost that expects no client request", func() {
|
||||
It("should support a client that connects, sends data, and disconnects [Conformance]", func() {
|
||||
doTestConnectSendDisconnect("localhost", f)
|
||||
It("should support forwarding over websockets", func() {
|
||||
doTestOverWebSockets("localhost", f)
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -412,3 +524,30 @@ func verifyLogMessage(log, expected string) {
|
||||
}
|
||||
framework.Failf("Missing %q from log: %s", expected, log)
|
||||
}
|
||||
|
||||
func wsRead(conn *websocket.Conn) (byte, []byte, error) {
|
||||
for {
|
||||
var data []byte
|
||||
err := websocket.Message.Receive(conn, &data)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
|
||||
if len(data) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
channel := data[0]
|
||||
data = data[1:]
|
||||
|
||||
return channel, data, err
|
||||
}
|
||||
}
|
||||
|
||||
func wsWrite(conn *websocket.Conn, channel byte, data []byte) error {
|
||||
frame := make([]byte, len(data)+1)
|
||||
frame[0] = channel
|
||||
copy(frame[1:], data)
|
||||
err := websocket.Message.Send(conn, frame)
|
||||
return err
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user