diff --git a/cluster/saltbase/salt/e2e-image-puller/e2e-image-puller.manifest b/cluster/saltbase/salt/e2e-image-puller/e2e-image-puller.manifest index e30103cb8d0..cfeed5a8ff3 100644 --- a/cluster/saltbase/salt/e2e-image-puller/e2e-image-puller.manifest +++ b/cluster/saltbase/salt/e2e-image-puller/e2e-image-puller.manifest @@ -27,7 +27,7 @@ spec: command: - /bin/sh - -c - - "for i in gcr.io/google_containers/busybox gcr.io/google_containers/busybox:1.24 gcr.io/google_containers/dnsutils:e2e gcr.io/google_containers/eptest:0.1 gcr.io/google_containers/fakegitserver:0.1 gcr.io/google_containers/hostexec:1.2 gcr.io/google_containers/iperf:e2e gcr.io/google_containers/jessie-dnsutils:e2e gcr.io/google_containers/liveness:e2e gcr.io/google_containers/mounttest:0.7 gcr.io/google_containers/mounttest-user:0.3 gcr.io/google_containers/netexec:1.4 gcr.io/google_containers/netexec:1.7 gcr.io/google_containers/nettest:1.7 gcr.io/google_containers/nettest:1.8 gcr.io/google_containers/nginx-slim:0.7 gcr.io/google_containers/nginx-slim:0.8 gcr.io/google_containers/n-way-http:1.0 gcr.io/google_containers/pause:2.0 gcr.io/google_containers/pause-amd64:3.0 gcr.io/google_containers/porter:cd5cb5791ebaa8641955f0e8c2a9bed669b1eaab gcr.io/google_containers/portforwardtester:1.0 gcr.io/google_containers/redis:e2e gcr.io/google_containers/resource_consumer:beta4 gcr.io/google_containers/resource_consumer/controller:beta4 gcr.io/google_containers/serve_hostname:v1.4 gcr.io/google_containers/test-webserver:e2e gcr.io/google_containers/ubuntu:14.04 gcr.io/google_containers/update-demo:kitten gcr.io/google_containers/update-demo:nautilus gcr.io/google_containers/volume-ceph:0.1 gcr.io/google_containers/volume-gluster:0.2 gcr.io/google_containers/volume-iscsi:0.1 gcr.io/google_containers/volume-nfs:0.6 gcr.io/google_containers/volume-rbd:0.1 gcr.io/google_samples/gb-redisslave:v1 gcr.io/google_containers/redis:v1; do echo $(date '+%X') pulling $i; docker pull $i 1>/dev/null; done; exit 0;" + - "for i in gcr.io/google_containers/busybox gcr.io/google_containers/busybox:1.24 gcr.io/google_containers/dnsutils:e2e gcr.io/google_containers/eptest:0.1 gcr.io/google_containers/fakegitserver:0.1 gcr.io/google_containers/hostexec:1.2 gcr.io/google_containers/iperf:e2e gcr.io/google_containers/jessie-dnsutils:e2e gcr.io/google_containers/liveness:e2e gcr.io/google_containers/mounttest:0.7 gcr.io/google_containers/mounttest-user:0.3 gcr.io/google_containers/netexec:1.4 gcr.io/google_containers/netexec:1.7 gcr.io/google_containers/nettest:1.7 gcr.io/google_containers/nettest:1.8 gcr.io/google_containers/nginx-slim:0.7 gcr.io/google_containers/nginx-slim:0.8 gcr.io/google_containers/n-way-http:1.0 gcr.io/google_containers/pause:2.0 gcr.io/google_containers/pause-amd64:3.0 gcr.io/google_containers/porter:cd5cb5791ebaa8641955f0e8c2a9bed669b1eaab gcr.io/google_containers/portforwardtester:1.2 gcr.io/google_containers/redis:e2e gcr.io/google_containers/resource_consumer:beta4 gcr.io/google_containers/resource_consumer/controller:beta4 gcr.io/google_containers/serve_hostname:v1.4 gcr.io/google_containers/test-webserver:e2e gcr.io/google_containers/ubuntu:14.04 gcr.io/google_containers/update-demo:kitten gcr.io/google_containers/update-demo:nautilus gcr.io/google_containers/volume-ceph:0.1 gcr.io/google_containers/volume-gluster:0.2 gcr.io/google_containers/volume-iscsi:0.1 gcr.io/google_containers/volume-nfs:0.6 gcr.io/google_containers/volume-rbd:0.1 gcr.io/google_samples/gb-redisslave:v1 gcr.io/google_containers/redis:v1; do echo $(date '+%X') pulling $i; docker pull $i 1>/dev/null; done; exit 0;" securityContext: privileged: true volumeMounts: diff --git a/pkg/client/tests/portfoward_test.go b/pkg/client/tests/portfoward_test.go index c2c1cb083aa..aaf7dba8bcf 100644 --- a/pkg/client/tests/portfoward_test.go +++ b/pkg/client/tests/portfoward_test.go @@ -84,7 +84,7 @@ func fakePortForwardServer(t *testing.T, testName string, serverSends, expectedF received: make(map[uint16]string), send: serverSends, } - portforward.ServePortForward(w, req, pf, "pod", "uid", 0, 10*time.Second) + portforward.ServePortForward(w, req, pf, "pod", "uid", 0, 10*time.Second, portforward.SupportedProtocols) for port, expected := range expectedFromClient { actual, ok := pf.received[port] diff --git a/pkg/genericapiserver/endpoints/installer.go b/pkg/genericapiserver/endpoints/installer.go index 43416ae67e4..4176106bff8 100644 --- a/pkg/genericapiserver/endpoints/installer.go +++ b/pkg/genericapiserver/endpoints/installer.go @@ -1047,10 +1047,14 @@ func typeToJSON(typeName string) string { return "string" case "byte", "*byte": return "string" + + // TODO: Fix these when go-restful supports a way to specify an array query param: + // https://github.com/emicklei/go-restful/issues/225 case "[]string", "[]*string": - // TODO: Fix this when go-restful supports a way to specify an array query param: - // https://github.com/emicklei/go-restful/issues/225 return "string" + case "[]int32", "[]*int32": + return "integer" + default: return typeName } diff --git a/pkg/kubelet/kubelet.go b/pkg/kubelet/kubelet.go index 9f951eb7e32..d229e7413f4 100644 --- a/pkg/kubelet/kubelet.go +++ b/pkg/kubelet/kubelet.go @@ -2171,9 +2171,10 @@ func getStreamingConfig(kubeCfg *componentconfig.KubeletConfiguration, kubeDeps BaseURL: &url.URL{ Path: "/cri/", }, - StreamIdleTimeout: kubeCfg.StreamingConnectionIdleTimeout.Duration, - StreamCreationTimeout: streaming.DefaultConfig.StreamCreationTimeout, - SupportedProtocols: streaming.DefaultConfig.SupportedProtocols, + StreamIdleTimeout: kubeCfg.StreamingConnectionIdleTimeout.Duration, + StreamCreationTimeout: streaming.DefaultConfig.StreamCreationTimeout, + SupportedRemoteCommandProtocols: streaming.DefaultConfig.SupportedRemoteCommandProtocols, + SupportedPortForwardProtocols: streaming.DefaultConfig.SupportedPortForwardProtocols, } if kubeDeps.TLSOptions != nil { config.TLSConfig = kubeDeps.TLSOptions.Config diff --git a/pkg/kubelet/server/portforward/constants.go b/pkg/kubelet/server/portforward/constants.go index 1b73be299b4..e7ccd58ae29 100644 --- a/pkg/kubelet/server/portforward/constants.go +++ b/pkg/kubelet/server/portforward/constants.go @@ -18,4 +18,6 @@ limitations under the License. package portforward // The subprotocol "portforward.k8s.io" is used for port forwarding. -const PortForwardProtocolV1Name = "portforward.k8s.io" +const ProtocolV1Name = "portforward.k8s.io" + +var SupportedProtocols = []string{ProtocolV1Name} diff --git a/pkg/kubelet/server/portforward/httpstream.go b/pkg/kubelet/server/portforward/httpstream.go new file mode 100644 index 00000000000..8af23b966ef --- /dev/null +++ b/pkg/kubelet/server/portforward/httpstream.go @@ -0,0 +1,309 @@ +/* +Copyright 2016 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package portforward + +import ( + "errors" + "fmt" + "net/http" + "strconv" + "sync" + "time" + + "k8s.io/apimachinery/pkg/types" + "k8s.io/apimachinery/pkg/util/httpstream" + "k8s.io/apimachinery/pkg/util/httpstream/spdy" + utilruntime "k8s.io/apimachinery/pkg/util/runtime" + "k8s.io/kubernetes/pkg/api" + + "github.com/golang/glog" +) + +func handleHttpStreams(req *http.Request, w http.ResponseWriter, portForwarder PortForwarder, podName string, uid types.UID, supportedPortForwardProtocols []string, idleTimeout, streamCreationTimeout time.Duration) error { + _, err := httpstream.Handshake(req, w, supportedPortForwardProtocols) + // negotiated protocol isn't currently used server side, but could be in the future + if err != nil { + // Handshake writes the error to the client + return err + } + streamChan := make(chan httpstream.Stream, 1) + + glog.V(5).Infof("Upgrading port forward response") + upgrader := spdy.NewResponseUpgrader() + conn := upgrader.UpgradeResponse(w, req, httpStreamReceived(streamChan)) + if conn == nil { + return errors.New("Unable to upgrade websocket connection") + } + defer conn.Close() + + glog.V(5).Infof("(conn=%p) setting port forwarding streaming connection idle timeout to %v", conn, idleTimeout) + conn.SetIdleTimeout(idleTimeout) + + h := &httpStreamHandler{ + conn: conn, + streamChan: streamChan, + streamPairs: make(map[string]*httpStreamPair), + streamCreationTimeout: streamCreationTimeout, + pod: podName, + uid: uid, + forwarder: portForwarder, + } + h.run() + + return nil +} + +// httpStreamReceived is the httpstream.NewStreamHandler for port +// forward streams. It checks each stream's port and stream type headers, +// rejecting any streams that with missing or invalid values. Each valid +// stream is sent to the streams channel. +func httpStreamReceived(streams chan httpstream.Stream) func(httpstream.Stream, <-chan struct{}) error { + return func(stream httpstream.Stream, replySent <-chan struct{}) error { + // make sure it has a valid port header + portString := stream.Headers().Get(api.PortHeader) + if len(portString) == 0 { + return fmt.Errorf("%q header is required", api.PortHeader) + } + port, err := strconv.ParseUint(portString, 10, 16) + if err != nil { + return fmt.Errorf("unable to parse %q as a port: %v", portString, err) + } + if port < 1 { + return fmt.Errorf("port %q must be > 0", portString) + } + + // make sure it has a valid stream type header + streamType := stream.Headers().Get(api.StreamType) + if len(streamType) == 0 { + return fmt.Errorf("%q header is required", api.StreamType) + } + if streamType != api.StreamTypeError && streamType != api.StreamTypeData { + return fmt.Errorf("invalid stream type %q", streamType) + } + + streams <- stream + return nil + } +} + +// httpStreamHandler is capable of processing multiple port forward +// requests over a single httpstream.Connection. +type httpStreamHandler struct { + conn httpstream.Connection + streamChan chan httpstream.Stream + streamPairsLock sync.RWMutex + streamPairs map[string]*httpStreamPair + streamCreationTimeout time.Duration + pod string + uid types.UID + forwarder PortForwarder +} + +// getStreamPair returns a httpStreamPair for requestID. This creates a +// new pair if one does not yet exist for the requestID. The returned bool is +// true if the pair was created. +func (h *httpStreamHandler) getStreamPair(requestID string) (*httpStreamPair, bool) { + h.streamPairsLock.Lock() + defer h.streamPairsLock.Unlock() + + if p, ok := h.streamPairs[requestID]; ok { + glog.V(5).Infof("(conn=%p, request=%s) found existing stream pair", h.conn, requestID) + return p, false + } + + glog.V(5).Infof("(conn=%p, request=%s) creating new stream pair", h.conn, requestID) + + p := newPortForwardPair(requestID) + h.streamPairs[requestID] = p + + return p, true +} + +// monitorStreamPair waits for the pair to receive both its error and data +// streams, or for the timeout to expire (whichever happens first), and then +// removes the pair. +func (h *httpStreamHandler) monitorStreamPair(p *httpStreamPair, timeout <-chan time.Time) { + select { + case <-timeout: + err := fmt.Errorf("(conn=%v, request=%s) timed out waiting for streams", h.conn, p.requestID) + utilruntime.HandleError(err) + p.printError(err.Error()) + case <-p.complete: + glog.V(5).Infof("(conn=%v, request=%s) successfully received error and data streams", h.conn, p.requestID) + } + h.removeStreamPair(p.requestID) +} + +// hasStreamPair returns a bool indicating if a stream pair for requestID +// exists. +func (h *httpStreamHandler) hasStreamPair(requestID string) bool { + h.streamPairsLock.RLock() + defer h.streamPairsLock.RUnlock() + + _, ok := h.streamPairs[requestID] + return ok +} + +// removeStreamPair removes the stream pair identified by requestID from streamPairs. +func (h *httpStreamHandler) removeStreamPair(requestID string) { + h.streamPairsLock.Lock() + defer h.streamPairsLock.Unlock() + + delete(h.streamPairs, requestID) +} + +// requestID returns the request id for stream. +func (h *httpStreamHandler) requestID(stream httpstream.Stream) string { + requestID := stream.Headers().Get(api.PortForwardRequestIDHeader) + if len(requestID) == 0 { + glog.V(5).Infof("(conn=%p) stream received without %s header", h.conn, api.PortForwardRequestIDHeader) + // If we get here, it's because the connection came from an older client + // that isn't generating the request id header + // (https://github.com/kubernetes/kubernetes/blob/843134885e7e0b360eb5441e85b1410a8b1a7a0c/pkg/client/unversioned/portforward/portforward.go#L258-L287) + // + // This is a best-effort attempt at supporting older clients. + // + // When there aren't concurrent new forwarded connections, each connection + // will have a pair of streams (data, error), and the stream IDs will be + // consecutive odd numbers, e.g. 1 and 3 for the first connection. Convert + // the stream ID into a pseudo-request id by taking the stream type and + // using id = stream.Identifier() when the stream type is error, + // and id = stream.Identifier() - 2 when it's data. + // + // NOTE: this only works when there are not concurrent new streams from + // multiple forwarded connections; it's a best-effort attempt at supporting + // old clients that don't generate request ids. If there are concurrent + // new connections, it's possible that 1 connection gets streams whose IDs + // are not consecutive (e.g. 5 and 9 instead of 5 and 7). + streamType := stream.Headers().Get(api.StreamType) + switch streamType { + case api.StreamTypeError: + requestID = strconv.Itoa(int(stream.Identifier())) + case api.StreamTypeData: + requestID = strconv.Itoa(int(stream.Identifier()) - 2) + } + + glog.V(5).Infof("(conn=%p) automatically assigning request ID=%q from stream type=%s, stream ID=%d", h.conn, requestID, streamType, stream.Identifier()) + } + return requestID +} + +// run is the main loop for the httpStreamHandler. It processes new +// streams, invoking portForward for each complete stream pair. The loop exits +// when the httpstream.Connection is closed. +func (h *httpStreamHandler) run() { + glog.V(5).Infof("(conn=%p) waiting for port forward streams", h.conn) +Loop: + for { + select { + case <-h.conn.CloseChan(): + glog.V(5).Infof("(conn=%p) upgraded connection closed", h.conn) + break Loop + case stream := <-h.streamChan: + requestID := h.requestID(stream) + streamType := stream.Headers().Get(api.StreamType) + glog.V(5).Infof("(conn=%p, request=%s) received new stream of type %s", h.conn, requestID, streamType) + + p, created := h.getStreamPair(requestID) + if created { + go h.monitorStreamPair(p, time.After(h.streamCreationTimeout)) + } + if complete, err := p.add(stream); err != nil { + msg := fmt.Sprintf("error processing stream for request %s: %v", requestID, err) + utilruntime.HandleError(errors.New(msg)) + p.printError(msg) + } else if complete { + go h.portForward(p) + } + } + } +} + +// portForward invokes the httpStreamHandler's forwarder.PortForward +// function for the given stream pair. +func (h *httpStreamHandler) portForward(p *httpStreamPair) { + defer p.dataStream.Close() + defer p.errorStream.Close() + + portString := p.dataStream.Headers().Get(api.PortHeader) + port, _ := strconv.ParseUint(portString, 10, 16) + + glog.V(5).Infof("(conn=%p, request=%s) invoking forwarder.PortForward for port %s", h.conn, p.requestID, portString) + err := h.forwarder.PortForward(h.pod, h.uid, uint16(port), p.dataStream) + glog.V(5).Infof("(conn=%p, request=%s) done invoking forwarder.PortForward for port %s", h.conn, p.requestID, portString) + + if err != nil { + msg := fmt.Errorf("error forwarding port %d to pod %s, uid %v: %v", port, h.pod, h.uid, err) + utilruntime.HandleError(msg) + fmt.Fprint(p.errorStream, msg.Error()) + } +} + +// httpStreamPair represents the error and data streams for a port +// forwarding request. +type httpStreamPair struct { + lock sync.RWMutex + requestID string + dataStream httpstream.Stream + errorStream httpstream.Stream + complete chan struct{} +} + +// newPortForwardPair creates a new httpStreamPair. +func newPortForwardPair(requestID string) *httpStreamPair { + return &httpStreamPair{ + requestID: requestID, + complete: make(chan struct{}), + } +} + +// add adds the stream to the httpStreamPair. If the pair already +// contains a stream for the new stream's type, an error is returned. add +// returns true if both the data and error streams for this pair have been +// received. +func (p *httpStreamPair) add(stream httpstream.Stream) (bool, error) { + p.lock.Lock() + defer p.lock.Unlock() + + switch stream.Headers().Get(api.StreamType) { + case api.StreamTypeError: + if p.errorStream != nil { + return false, errors.New("error stream already assigned") + } + p.errorStream = stream + case api.StreamTypeData: + if p.dataStream != nil { + return false, errors.New("data stream already assigned") + } + p.dataStream = stream + } + + complete := p.errorStream != nil && p.dataStream != nil + if complete { + close(p.complete) + } + return complete, nil +} + +// printError writes s to p.errorStream if p.errorStream has been set. +func (p *httpStreamPair) printError(s string) { + p.lock.RLock() + defer p.lock.RUnlock() + if p.errorStream != nil { + fmt.Fprint(p.errorStream, s) + } +} diff --git a/pkg/kubelet/server/portforward/portforward_test.go b/pkg/kubelet/server/portforward/httpstream_test.go similarity index 96% rename from pkg/kubelet/server/portforward/portforward_test.go rename to pkg/kubelet/server/portforward/httpstream_test.go index 542ad71c1d3..ee5696d5a17 100644 --- a/pkg/kubelet/server/portforward/portforward_test.go +++ b/pkg/kubelet/server/portforward/httpstream_test.go @@ -25,7 +25,7 @@ import ( "k8s.io/kubernetes/pkg/api" ) -func TestPortForwardStreamReceived(t *testing.T) { +func TestHTTPStreamReceived(t *testing.T) { tests := map[string]struct { port string streamType string @@ -62,7 +62,7 @@ func TestPortForwardStreamReceived(t *testing.T) { } for name, test := range tests { streams := make(chan httpstream.Stream, 1) - f := portForwardStreamReceived(streams) + f := httpStreamReceived(streams) stream := newFakeHttpStream() if len(test.port) > 0 { stream.headers.Set("port", test.port) @@ -92,48 +92,11 @@ func TestPortForwardStreamReceived(t *testing.T) { } } -type fakeHttpStream struct { - headers http.Header - id uint32 -} - -func newFakeHttpStream() *fakeHttpStream { - return &fakeHttpStream{ - headers: make(http.Header), - } -} - -var _ httpstream.Stream = &fakeHttpStream{} - -func (s *fakeHttpStream) Read(data []byte) (int, error) { - return 0, nil -} - -func (s *fakeHttpStream) Write(data []byte) (int, error) { - return 0, nil -} - -func (s *fakeHttpStream) Close() error { - return nil -} - -func (s *fakeHttpStream) Reset() error { - return nil -} - -func (s *fakeHttpStream) Headers() http.Header { - return s.headers -} - -func (s *fakeHttpStream) Identifier() uint32 { - return s.id -} - func TestGetStreamPair(t *testing.T) { timeout := make(chan time.Time) - h := &portForwardStreamHandler{ - streamPairs: make(map[string]*portForwardStreamPair), + h := &httpStreamHandler{ + streamPairs: make(map[string]*httpStreamPair), } // test adding a new entry @@ -223,7 +186,7 @@ func TestGetStreamPair(t *testing.T) { } func TestRequestID(t *testing.T) { - h := &portForwardStreamHandler{} + h := &httpStreamHandler{} s := newFakeHttpStream() s.headers.Set(api.StreamType, api.StreamTypeError) @@ -244,3 +207,40 @@ func TestRequestID(t *testing.T) { t.Errorf("expected %q, got %q", e, a) } } + +type fakeHttpStream struct { + headers http.Header + id uint32 +} + +func newFakeHttpStream() *fakeHttpStream { + return &fakeHttpStream{ + headers: make(http.Header), + } +} + +var _ httpstream.Stream = &fakeHttpStream{} + +func (s *fakeHttpStream) Read(data []byte) (int, error) { + return 0, nil +} + +func (s *fakeHttpStream) Write(data []byte) (int, error) { + return 0, nil +} + +func (s *fakeHttpStream) Close() error { + return nil +} + +func (s *fakeHttpStream) Reset() error { + return nil +} + +func (s *fakeHttpStream) Headers() http.Header { + return s.headers +} + +func (s *fakeHttpStream) Identifier() uint32 { + return s.id +} diff --git a/pkg/kubelet/server/portforward/portforward.go b/pkg/kubelet/server/portforward/portforward.go index 2737752c36b..a812d61420d 100644 --- a/pkg/kubelet/server/portforward/portforward.go +++ b/pkg/kubelet/server/portforward/portforward.go @@ -17,21 +17,13 @@ limitations under the License. package portforward import ( - "errors" - "fmt" "io" "net/http" - "strconv" - "sync" "time" - "github.com/golang/glog" - "k8s.io/apimachinery/pkg/types" - "k8s.io/apimachinery/pkg/util/httpstream" - "k8s.io/apimachinery/pkg/util/httpstream/spdy" - utilruntime "k8s.io/apimachinery/pkg/util/runtime" - "k8s.io/kubernetes/pkg/api" + "k8s.io/apimachinery/pkg/util/runtime" + "k8s.io/apiserver/pkg/util/wsstream" ) // PortForwarder knows how to forward content from a data stream to/from a port @@ -46,278 +38,16 @@ type PortForwarder interface { // been timed out due to idleness. This function handles multiple forwarded // connections; i.e., multiple `curl http://localhost:8888/` requests will be // handled by a single invocation of ServePortForward. -func ServePortForward(w http.ResponseWriter, req *http.Request, portForwarder PortForwarder, podName string, uid types.UID, idleTimeout time.Duration, streamCreationTimeout time.Duration) { - supportedPortForwardProtocols := []string{PortForwardProtocolV1Name} - _, err := httpstream.Handshake(req, w, supportedPortForwardProtocols) - // negotiated protocol isn't currently used server side, but could be in the future - if err != nil { - // Handshake writes the error to the client - utilruntime.HandleError(err) - return +func ServePortForward(w http.ResponseWriter, req *http.Request, portForwarder PortForwarder, podName string, uid types.UID, idleTimeout time.Duration, streamCreationTimeout time.Duration, supportedProtocols []string) { + var err error + if wsstream.IsWebSocketRequest(req) { + err = handleWebSocketStreams(req, w, portForwarder, podName, uid, supportedProtocols, idleTimeout, streamCreationTimeout) + } else { + err = handleHttpStreams(req, w, portForwarder, podName, uid, supportedProtocols, idleTimeout, streamCreationTimeout) } - streamChan := make(chan httpstream.Stream, 1) - - glog.V(5).Infof("Upgrading port forward response") - upgrader := spdy.NewResponseUpgrader() - conn := upgrader.UpgradeResponse(w, req, portForwardStreamReceived(streamChan)) - if conn == nil { - return - } - defer conn.Close() - - glog.V(5).Infof("(conn=%p) setting port forwarding streaming connection idle timeout to %v", conn, idleTimeout) - conn.SetIdleTimeout(idleTimeout) - - h := &portForwardStreamHandler{ - conn: conn, - streamChan: streamChan, - streamPairs: make(map[string]*portForwardStreamPair), - streamCreationTimeout: streamCreationTimeout, - pod: podName, - uid: uid, - forwarder: portForwarder, - } - h.run() -} - -// portForwardStreamReceived is the httpstream.NewStreamHandler for port -// forward streams. It checks each stream's port and stream type headers, -// rejecting any streams that with missing or invalid values. Each valid -// stream is sent to the streams channel. -func portForwardStreamReceived(streams chan httpstream.Stream) func(httpstream.Stream, <-chan struct{}) error { - return func(stream httpstream.Stream, replySent <-chan struct{}) error { - // make sure it has a valid port header - portString := stream.Headers().Get(api.PortHeader) - if len(portString) == 0 { - return fmt.Errorf("%q header is required", api.PortHeader) - } - port, err := strconv.ParseUint(portString, 10, 16) - if err != nil { - return fmt.Errorf("unable to parse %q as a port: %v", portString, err) - } - if port < 1 { - return fmt.Errorf("port %q must be > 0", portString) - } - - // make sure it has a valid stream type header - streamType := stream.Headers().Get(api.StreamType) - if len(streamType) == 0 { - return fmt.Errorf("%q header is required", api.StreamType) - } - if streamType != api.StreamTypeError && streamType != api.StreamTypeData { - return fmt.Errorf("invalid stream type %q", streamType) - } - - streams <- stream - return nil - } -} - -// portForwardStreamHandler is capable of processing multiple port forward -// requests over a single httpstream.Connection. -type portForwardStreamHandler struct { - conn httpstream.Connection - streamChan chan httpstream.Stream - streamPairsLock sync.RWMutex - streamPairs map[string]*portForwardStreamPair - streamCreationTimeout time.Duration - pod string - uid types.UID - forwarder PortForwarder -} - -// getStreamPair returns a portForwardStreamPair for requestID. This creates a -// new pair if one does not yet exist for the requestID. The returned bool is -// true if the pair was created. -func (h *portForwardStreamHandler) getStreamPair(requestID string) (*portForwardStreamPair, bool) { - h.streamPairsLock.Lock() - defer h.streamPairsLock.Unlock() - - if p, ok := h.streamPairs[requestID]; ok { - glog.V(5).Infof("(conn=%p, request=%s) found existing stream pair", h.conn, requestID) - return p, false - } - - glog.V(5).Infof("(conn=%p, request=%s) creating new stream pair", h.conn, requestID) - - p := newPortForwardPair(requestID) - h.streamPairs[requestID] = p - - return p, true -} - -// monitorStreamPair waits for the pair to receive both its error and data -// streams, or for the timeout to expire (whichever happens first), and then -// removes the pair. -func (h *portForwardStreamHandler) monitorStreamPair(p *portForwardStreamPair, timeout <-chan time.Time) { - select { - case <-timeout: - err := fmt.Errorf("(conn=%v, request=%s) timed out waiting for streams", h.conn, p.requestID) - utilruntime.HandleError(err) - p.printError(err.Error()) - case <-p.complete: - glog.V(5).Infof("(conn=%v, request=%s) successfully received error and data streams", h.conn, p.requestID) - } - h.removeStreamPair(p.requestID) -} - -// hasStreamPair returns a bool indicating if a stream pair for requestID -// exists. -func (h *portForwardStreamHandler) hasStreamPair(requestID string) bool { - h.streamPairsLock.RLock() - defer h.streamPairsLock.RUnlock() - - _, ok := h.streamPairs[requestID] - return ok -} - -// removeStreamPair removes the stream pair identified by requestID from streamPairs. -func (h *portForwardStreamHandler) removeStreamPair(requestID string) { - h.streamPairsLock.Lock() - defer h.streamPairsLock.Unlock() - - delete(h.streamPairs, requestID) -} - -// requestID returns the request id for stream. -func (h *portForwardStreamHandler) requestID(stream httpstream.Stream) string { - requestID := stream.Headers().Get(api.PortForwardRequestIDHeader) - if len(requestID) == 0 { - glog.V(5).Infof("(conn=%p) stream received without %s header", h.conn, api.PortForwardRequestIDHeader) - // If we get here, it's because the connection came from an older client - // that isn't generating the request id header - // (https://github.com/kubernetes/kubernetes/blob/843134885e7e0b360eb5441e85b1410a8b1a7a0c/pkg/client/unversioned/portforward/portforward.go#L258-L287) - // - // This is a best-effort attempt at supporting older clients. - // - // When there aren't concurrent new forwarded connections, each connection - // will have a pair of streams (data, error), and the stream IDs will be - // consecutive odd numbers, e.g. 1 and 3 for the first connection. Convert - // the stream ID into a pseudo-request id by taking the stream type and - // using id = stream.Identifier() when the stream type is error, - // and id = stream.Identifier() - 2 when it's data. - // - // NOTE: this only works when there are not concurrent new streams from - // multiple forwarded connections; it's a best-effort attempt at supporting - // old clients that don't generate request ids. If there are concurrent - // new connections, it's possible that 1 connection gets streams whose IDs - // are not consecutive (e.g. 5 and 9 instead of 5 and 7). - streamType := stream.Headers().Get(api.StreamType) - switch streamType { - case api.StreamTypeError: - requestID = strconv.Itoa(int(stream.Identifier())) - case api.StreamTypeData: - requestID = strconv.Itoa(int(stream.Identifier()) - 2) - } - - glog.V(5).Infof("(conn=%p) automatically assigning request ID=%q from stream type=%s, stream ID=%d", h.conn, requestID, streamType, stream.Identifier()) - } - return requestID -} - -// run is the main loop for the portForwardStreamHandler. It processes new -// streams, invoking portForward for each complete stream pair. The loop exits -// when the httpstream.Connection is closed. -func (h *portForwardStreamHandler) run() { - glog.V(5).Infof("(conn=%p) waiting for port forward streams", h.conn) -Loop: - for { - select { - case <-h.conn.CloseChan(): - glog.V(5).Infof("(conn=%p) upgraded connection closed", h.conn) - break Loop - case stream := <-h.streamChan: - requestID := h.requestID(stream) - streamType := stream.Headers().Get(api.StreamType) - glog.V(5).Infof("(conn=%p, request=%s) received new stream of type %s", h.conn, requestID, streamType) - - p, created := h.getStreamPair(requestID) - if created { - go h.monitorStreamPair(p, time.After(h.streamCreationTimeout)) - } - if complete, err := p.add(stream); err != nil { - msg := fmt.Sprintf("error processing stream for request %s: %v", requestID, err) - utilruntime.HandleError(errors.New(msg)) - p.printError(msg) - } else if complete { - go h.portForward(p) - } - } - } -} - -// portForward invokes the portForwardStreamHandler's forwarder.PortForward -// function for the given stream pair. -func (h *portForwardStreamHandler) portForward(p *portForwardStreamPair) { - defer p.dataStream.Close() - defer p.errorStream.Close() - - portString := p.dataStream.Headers().Get(api.PortHeader) - port, _ := strconv.ParseUint(portString, 10, 16) - - glog.V(5).Infof("(conn=%p, request=%s) invoking forwarder.PortForward for port %s", h.conn, p.requestID, portString) - err := h.forwarder.PortForward(h.pod, h.uid, uint16(port), p.dataStream) - glog.V(5).Infof("(conn=%p, request=%s) done invoking forwarder.PortForward for port %s", h.conn, p.requestID, portString) - if err != nil { - msg := fmt.Errorf("error forwarding port %d to pod %s, uid %v: %v", port, h.pod, h.uid, err) - utilruntime.HandleError(msg) - fmt.Fprint(p.errorStream, msg.Error()) - } -} - -// portForwardStreamPair represents the error and data streams for a port -// forwarding request. -type portForwardStreamPair struct { - lock sync.RWMutex - requestID string - dataStream httpstream.Stream - errorStream httpstream.Stream - complete chan struct{} -} - -// newPortForwardPair creates a new portForwardStreamPair. -func newPortForwardPair(requestID string) *portForwardStreamPair { - return &portForwardStreamPair{ - requestID: requestID, - complete: make(chan struct{}), - } -} - -// add adds the stream to the portForwardStreamPair. If the pair already -// contains a stream for the new stream's type, an error is returned. add -// returns true if both the data and error streams for this pair have been -// received. -func (p *portForwardStreamPair) add(stream httpstream.Stream) (bool, error) { - p.lock.Lock() - defer p.lock.Unlock() - - switch stream.Headers().Get(api.StreamType) { - case api.StreamTypeError: - if p.errorStream != nil { - return false, errors.New("error stream already assigned") - } - p.errorStream = stream - case api.StreamTypeData: - if p.dataStream != nil { - return false, errors.New("data stream already assigned") - } - p.dataStream = stream - } - - complete := p.errorStream != nil && p.dataStream != nil - if complete { - close(p.complete) - } - return complete, nil -} - -// printError writes s to p.errorStream if p.errorStream has been set. -func (p *portForwardStreamPair) printError(s string) { - p.lock.RLock() - defer p.lock.RUnlock() - if p.errorStream != nil { - fmt.Fprint(p.errorStream, s) + runtime.HandleError(err) + return } } diff --git a/pkg/kubelet/server/portforward/websocket.go b/pkg/kubelet/server/portforward/websocket.go new file mode 100644 index 00000000000..f201062735c --- /dev/null +++ b/pkg/kubelet/server/portforward/websocket.go @@ -0,0 +1,189 @@ +/* +Copyright 2016 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package portforward + +import ( + "encoding/binary" + "fmt" + "io" + "net/http" + "strconv" + "strings" + "sync" + "time" + + "github.com/golang/glog" + + "k8s.io/apimachinery/pkg/types" + "k8s.io/apimachinery/pkg/util/runtime" + "k8s.io/apiserver/pkg/server/httplog" + "k8s.io/apiserver/pkg/util/wsstream" + "k8s.io/kubernetes/pkg/api" +) + +const ( + dataChannel = iota + errorChannel + + v4BinaryWebsocketProtocol = "v4." + wsstream.ChannelWebSocketProtocol + v4Base64WebsocketProtocol = "v4." + wsstream.Base64ChannelWebSocketProtocol +) + +// options contains details about which streams are required for +// port forwarding. +type v4Options struct { + ports []uint16 +} + +// newOptions creates a new options from the Request. +func newV4Options(req *http.Request) (*v4Options, error) { + portStrings := req.URL.Query()[api.PortHeader] + if len(portStrings) == 0 { + return nil, fmt.Errorf("%q is required", api.PortHeader) + } + + ports := make([]uint16, 0, len(portStrings)) + for _, portString := range portStrings { + if len(portString) == 0 { + return nil, fmt.Errorf("%q is cannot be empty", api.PortHeader) + } + for _, p := range strings.Split(portString, ",") { + port, err := strconv.ParseUint(p, 10, 16) + if err != nil { + return nil, fmt.Errorf("unable to parse %q as a port: %v", portString, err) + } + if port < 1 { + return nil, fmt.Errorf("port %q must be > 0", portString) + } + ports = append(ports, uint16(port)) + } + } + + return &v4Options{ + ports: ports, + }, nil +} + +func handleWebSocketStreams(req *http.Request, w http.ResponseWriter, portForwarder PortForwarder, podName string, uid types.UID, supportedPortForwardProtocols []string, idleTimeout, streamCreationTimeout time.Duration) error { + opts, err := newV4Options(req) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + fmt.Fprint(w, err.Error()) + return err + } + + channels := make([]wsstream.ChannelType, 0, len(opts.ports)*2) + for i := 0; i < len(opts.ports); i++ { + channels = append(channels, wsstream.ReadWriteChannel, wsstream.WriteChannel) + } + conn := wsstream.NewConn(map[string]wsstream.ChannelProtocolConfig{ + "": { + Binary: true, + Channels: channels, + }, + v4BinaryWebsocketProtocol: { + Binary: true, + Channels: channels, + }, + v4Base64WebsocketProtocol: { + Binary: false, + Channels: channels, + }, + }) + conn.SetIdleTimeout(idleTimeout) + _, streams, err := conn.Open(httplog.Unlogged(w), req) + if err != nil { + err = fmt.Errorf("Unable to upgrade websocket connection: %v", err) + return err + } + defer conn.Close() + streamPairs := make([]*websocketStreamPair, len(opts.ports)) + for i := range streamPairs { + streamPair := websocketStreamPair{ + port: opts.ports[i], + dataStream: streams[i*2+dataChannel], + errorStream: streams[i*2+errorChannel], + } + streamPairs[i] = &streamPair + + portBytes := make([]byte, 2) + binary.LittleEndian.PutUint16(portBytes, streamPair.port) + streamPair.dataStream.Write(portBytes) + streamPair.errorStream.Write(portBytes) + } + h := &websocketStreamHandler{ + conn: conn, + streamPairs: streamPairs, + pod: podName, + uid: uid, + forwarder: portForwarder, + } + h.run() + + return nil +} + +// websocketStreamPair represents the error and data streams for a port +// forwarding request. +type websocketStreamPair struct { + port uint16 + dataStream io.ReadWriteCloser + errorStream io.WriteCloser +} + +// websocketStreamHandler is capable of processing a single port forward +// request over a websocket connection +type websocketStreamHandler struct { + conn *wsstream.Conn + ports []uint16 + streamPairs []*websocketStreamPair + pod string + uid types.UID + forwarder PortForwarder +} + +// run invokes the websocketStreamHandler's forwarder.PortForward +// function for the given stream pair. +func (h *websocketStreamHandler) run() { + wg := sync.WaitGroup{} + wg.Add(len(h.streamPairs)) + + for _, pair := range h.streamPairs { + p := pair + go func() { + defer wg.Done() + h.portForward(p) + }() + } + + wg.Wait() +} + +func (h *websocketStreamHandler) portForward(p *websocketStreamPair) { + defer p.dataStream.Close() + defer p.errorStream.Close() + + glog.V(5).Infof("(conn=%p) invoking forwarder.PortForward for port %d", h.conn, p.port) + err := h.forwarder.PortForward(h.pod, h.uid, p.port, p.dataStream) + glog.V(5).Infof("(conn=%p) done invoking forwarder.PortForward for port %d", h.conn, p.port) + + if err != nil { + msg := fmt.Errorf("error forwarding port %d to pod %s, uid %v: %v", p.port, h.pod, h.uid, err) + runtime.HandleError(msg) + fmt.Fprint(p.errorStream, msg.Error()) + } +} diff --git a/pkg/kubelet/server/server.go b/pkg/kubelet/server/server.go index 76eb3ef13d5..6918b8e68fb 100644 --- a/pkg/kubelet/server/server.go +++ b/pkg/kubelet/server/server.go @@ -335,9 +335,15 @@ func (s *Server) InstallDebuggingHandlers(criHandler http.Handler) { ws = new(restful.WebService) ws. Path("/portForward") + ws.Route(ws.GET("/{podNamespace}/{podID}"). + To(s.getPortForward). + Operation("getPortForward")) ws.Route(ws.POST("/{podNamespace}/{podID}"). To(s.getPortForward). Operation("getPortForward")) + ws.Route(ws.GET("/{podNamespace}/{podID}/{uid}"). + To(s.getPortForward). + Operation("getPortForward")) ws.Route(ws.POST("/{podNamespace}/{podID}/{uid}"). To(s.getPortForward). Operation("getPortForward")) @@ -720,7 +726,8 @@ func (s *Server) getPortForward(request *restful.Request, response *restful.Resp kubecontainer.GetPodFullName(pod), params.podUID, s.host.StreamingConnectionIdleTimeout(), - remotecommand.DefaultStreamCreationTimeout) + remotecommand.DefaultStreamCreationTimeout, + portforward.SupportedProtocols) } // ServeHTTP responds to HTTP requests on the Kubelet. diff --git a/pkg/kubelet/server/server_websocket_test.go b/pkg/kubelet/server/server_websocket_test.go new file mode 100644 index 00000000000..2d55100bfec --- /dev/null +++ b/pkg/kubelet/server/server_websocket_test.go @@ -0,0 +1,331 @@ +/* +Copyright 2016 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package server + +import ( + "encoding/binary" + "fmt" + "io" + "strconv" + "sync" + "testing" + "time" + + "golang.org/x/net/websocket" + + "k8s.io/apimachinery/pkg/types" +) + +const ( + dataChannel = iota + errorChannel +) + +func TestServeWSPortForward(t *testing.T) { + tests := []struct { + port string + uid bool + clientData string + containerData string + shouldError bool + }{ + {port: "", shouldError: true}, + {port: "abc", shouldError: true}, + {port: "-1", shouldError: true}, + {port: "65536", shouldError: true}, + {port: "0", shouldError: true}, + {port: "1", shouldError: false}, + {port: "8000", shouldError: false}, + {port: "8000", clientData: "client data", containerData: "container data", shouldError: false}, + {port: "65535", shouldError: false}, + {port: "65535", uid: true, shouldError: false}, + } + + podNamespace := "other" + podName := "foo" + expectedPodName := getPodName(podName, podNamespace) + expectedUid := "9b01b80f-8fb4-11e4-95ab-4200af06647" + + for i, test := range tests { + fw := newServerTest() + defer fw.testHTTPServer.Close() + + fw.fakeKubelet.streamingConnectionIdleTimeoutFunc = func() time.Duration { + return 0 + } + + portForwardFuncDone := make(chan struct{}) + + fw.fakeKubelet.portForwardFunc = func(name string, uid types.UID, port uint16, stream io.ReadWriteCloser) error { + defer close(portForwardFuncDone) + + if e, a := expectedPodName, name; e != a { + t.Fatalf("%d: pod name: expected '%v', got '%v'", i, e, a) + } + + if e, a := expectedUid, uid; test.uid && e != string(a) { + t.Fatalf("%d: uid: expected '%v', got '%v'", i, e, a) + } + + p, err := strconv.ParseUint(test.port, 10, 16) + if err != nil { + t.Fatalf("%d: error parsing port string '%s': %v", i, test.port, err) + } + if e, a := uint16(p), port; e != a { + t.Fatalf("%d: port: expected '%v', got '%v'", i, e, a) + } + + if test.clientData != "" { + fromClient := make([]byte, 32) + n, err := stream.Read(fromClient) + if err != nil { + t.Fatalf("%d: error reading client data: %v", i, err) + } + if e, a := test.clientData, string(fromClient[0:n]); e != a { + t.Fatalf("%d: client data: expected to receive '%v', got '%v'", i, e, a) + } + } + + if test.containerData != "" { + _, err := stream.Write([]byte(test.containerData)) + if err != nil { + t.Fatalf("%d: error writing container data: %v", i, err) + } + } + + return nil + } + + var url string + if test.uid { + url = fmt.Sprintf("ws://%s/portForward/%s/%s/%s?port=%s", fw.testHTTPServer.Listener.Addr().String(), podNamespace, podName, expectedUid, test.port) + } else { + url = fmt.Sprintf("ws://%s/portForward/%s/%s?port=%s", fw.testHTTPServer.Listener.Addr().String(), podNamespace, podName, test.port) + } + + ws, err := websocket.Dial(url, "", "http://127.0.0.1/") + if test.shouldError { + if err == nil { + t.Fatalf("%d: websocket dial expected err", i) + } + continue + } else if err != nil { + t.Fatalf("%d: websocket dial unexpected err: %v", i, err) + } + + defer ws.Close() + + p, err := strconv.ParseUint(test.port, 10, 16) + if err != nil { + t.Fatalf("%d: error parsing port string '%s': %v", i, test.port, err) + } + p16 := uint16(p) + + channel, data, err := wsRead(ws) + if err != nil { + t.Fatalf("%d: read failed: expected no error: got %v", i, err) + } + if channel != dataChannel { + t.Fatalf("%d: wrong channel: got %q: expected %q", i, channel, dataChannel) + } + if len(data) != binary.Size(p16) { + t.Fatalf("%d: wrong data size: got %q: expected %d", i, data, binary.Size(p16)) + } + if e, a := p16, binary.LittleEndian.Uint16(data); e != a { + t.Fatalf("%d: wrong data: got %q: expected %s", i, data, test.port) + } + + channel, data, err = wsRead(ws) + if err != nil { + t.Fatalf("%d: read succeeded: expected no error: got %v", i, err) + } + if channel != errorChannel { + t.Fatalf("%d: wrong channel: got %q: expected %q", i, channel, errorChannel) + } + if len(data) != binary.Size(p16) { + t.Fatalf("%d: wrong data size: got %q: expected %d", i, data, binary.Size(p16)) + } + if e, a := p16, binary.LittleEndian.Uint16(data); e != a { + t.Fatalf("%d: wrong data: got %q: expected %s", i, data, test.port) + } + + if test.clientData != "" { + println("writing the client data") + err := wsWrite(ws, dataChannel, []byte(test.clientData)) + if err != nil { + t.Fatalf("%d: unexpected error writing client data: %v", i, err) + } + } + + if test.containerData != "" { + channel, data, err = wsRead(ws) + if err != nil { + t.Fatalf("%d: unexpected error reading container data: %v", i, err) + } + if e, a := test.containerData, string(data); e != a { + t.Fatalf("%d: expected to receive '%v' from container, got '%v'", i, e, a) + } + } + + <-portForwardFuncDone + } +} + +func TestServeWSMultiplePortForward(t *testing.T) { + portsText := []string{"7000,8000", "9000"} + ports := []uint16{7000, 8000, 9000} + podNamespace := "other" + podName := "foo" + expectedPodName := getPodName(podName, podNamespace) + + fw := newServerTest() + defer fw.testHTTPServer.Close() + + fw.fakeKubelet.streamingConnectionIdleTimeoutFunc = func() time.Duration { + return 0 + } + + portForwardWG := sync.WaitGroup{} + portForwardWG.Add(len(ports)) + + portsMutex := sync.Mutex{} + portsForwarded := map[uint16]struct{}{} + + fw.fakeKubelet.portForwardFunc = func(name string, uid types.UID, port uint16, stream io.ReadWriteCloser) error { + defer portForwardWG.Done() + + if e, a := expectedPodName, name; e != a { + t.Fatalf("%d: pod name: expected '%v', got '%v'", port, e, a) + } + + portsMutex.Lock() + portsForwarded[port] = struct{}{} + portsMutex.Unlock() + + fromClient := make([]byte, 32) + n, err := stream.Read(fromClient) + if err != nil { + t.Fatalf("%d: error reading client data: %v", port, err) + } + if e, a := fmt.Sprintf("client data on port %d", port), string(fromClient[0:n]); e != a { + t.Fatalf("%d: client data: expected to receive '%v', got '%v'", port, e, a) + } + + _, err = stream.Write([]byte(fmt.Sprintf("container data on port %d", port))) + if err != nil { + t.Fatalf("%d: error writing container data: %v", port, err) + } + + return nil + } + + url := fmt.Sprintf("ws://%s/portForward/%s/%s?", fw.testHTTPServer.Listener.Addr().String(), podNamespace, podName) + for _, port := range portsText { + url = url + fmt.Sprintf("port=%s&", port) + } + + ws, err := websocket.Dial(url, "", "http://127.0.0.1/") + if err != nil { + t.Fatalf("websocket dial unexpected err: %v", err) + } + + defer ws.Close() + + for i, port := range ports { + channel, data, err := wsRead(ws) + if err != nil { + t.Fatalf("%d: read failed: expected no error: got %v", i, err) + } + if int(channel) != i*2+dataChannel { + t.Fatalf("%d: wrong channel: got %q: expected %q", i, channel, i*2+dataChannel) + } + if len(data) != binary.Size(port) { + t.Fatalf("%d: wrong data size: got %q: expected %d", i, data, binary.Size(port)) + } + if e, a := port, binary.LittleEndian.Uint16(data); e != a { + t.Fatalf("%d: wrong data: got %q: expected %d", i, data, port) + } + + channel, data, err = wsRead(ws) + if err != nil { + t.Fatalf("%d: read succeeded: expected no error: got %v", i, err) + } + if int(channel) != i*2+errorChannel { + t.Fatalf("%d: wrong channel: got %q: expected %q", i, channel, i*2+errorChannel) + } + if len(data) != binary.Size(port) { + t.Fatalf("%d: wrong data size: got %q: expected %d", i, data, binary.Size(port)) + } + if e, a := port, binary.LittleEndian.Uint16(data); e != a { + t.Fatalf("%d: wrong data: got %q: expected %d", i, data, port) + } + } + + for i, port := range ports { + println("writing the client data", port) + err := wsWrite(ws, byte(i*2+dataChannel), []byte(fmt.Sprintf("client data on port %d", port))) + if err != nil { + t.Fatalf("%d: unexpected error writing client data: %v", i, err) + } + + channel, data, err := wsRead(ws) + if err != nil { + t.Fatalf("%d: unexpected error reading container data: %v", i, err) + } + + if int(channel) != i*2+dataChannel { + t.Fatalf("%d: wrong channel: got %q: expected %q", port, channel, i*2+dataChannel) + } + if e, a := fmt.Sprintf("container data on port %d", port), string(data); e != a { + t.Fatalf("%d: expected to receive '%v' from container, got '%v'", i, e, a) + } + } + + portForwardWG.Wait() + + portsMutex.Lock() + defer portsMutex.Unlock() + if len(ports) != len(portsForwarded) { + t.Fatalf("expected to forward %d ports; got %v", len(ports), portsForwarded) + } +} +func wsWrite(conn *websocket.Conn, channel byte, data []byte) error { + frame := make([]byte, len(data)+1) + frame[0] = channel + copy(frame[1:], data) + err := websocket.Message.Send(conn, frame) + return err +} + +func wsRead(conn *websocket.Conn) (byte, []byte, error) { + for { + var data []byte + err := websocket.Message.Receive(conn, &data) + if err != nil { + return 0, nil, err + } + + if len(data) == 0 { + continue + } + + channel := data[0] + data = data[1:] + + return channel, data, err + } +} diff --git a/pkg/kubelet/server/streaming/server.go b/pkg/kubelet/server/streaming/server.go index b17ee1f035f..9d70d4cb21b 100644 --- a/pkg/kubelet/server/streaming/server.go +++ b/pkg/kubelet/server/streaming/server.go @@ -80,7 +80,12 @@ type Config struct { // The streaming protocols the server supports (understands and permits). See // k8s.io/kubernetes/pkg/kubelet/server/remotecommand/constants.go for available protocols. // Only used for SPDY streaming. - SupportedProtocols []string + SupportedRemoteCommandProtocols []string + + // The streaming protocols the server supports (understands and permits). See + // k8s.io/kubernetes/pkg/kubelet/server/portforward/constants.go for available protocols. + // Only used for SPDY streaming. + SupportedPortForwardProtocols []string // The config for serving over TLS. If nil, TLS will not be used. TLSConfig *tls.Config @@ -89,9 +94,10 @@ type Config struct { // DefaultConfig provides default values for server Config. The DefaultConfig is partial, so // some fields like Addr must still be provided. var DefaultConfig = Config{ - StreamIdleTimeout: 4 * time.Hour, - StreamCreationTimeout: remotecommand.DefaultStreamCreationTimeout, - SupportedProtocols: remotecommand.SupportedStreamingProtocols, + StreamIdleTimeout: 4 * time.Hour, + StreamCreationTimeout: remotecommand.DefaultStreamCreationTimeout, + SupportedRemoteCommandProtocols: remotecommand.SupportedStreamingProtocols, + SupportedPortForwardProtocols: portforward.SupportedProtocols, } // TODO(timstclair): Add auth(n/z) interface & handling. @@ -248,7 +254,7 @@ func (s *server) serveExec(req *restful.Request, resp *restful.Response) { streamOpts, s.config.StreamIdleTimeout, s.config.StreamCreationTimeout, - s.config.SupportedProtocols) + s.config.SupportedRemoteCommandProtocols) } func (s *server) serveAttach(req *restful.Request, resp *restful.Response) { @@ -280,7 +286,7 @@ func (s *server) serveAttach(req *restful.Request, resp *restful.Response) { streamOpts, s.config.StreamIdleTimeout, s.config.StreamCreationTimeout, - s.config.SupportedProtocols) + s.config.SupportedRemoteCommandProtocols) } func (s *server) servePortForward(req *restful.Request, resp *restful.Response) { @@ -303,7 +309,8 @@ func (s *server) servePortForward(req *restful.Request, resp *restful.Response) pf.PodSandboxId, "", // unused: podUID s.config.StreamIdleTimeout, - s.config.StreamCreationTimeout) + s.config.StreamCreationTimeout, + s.config.SupportedPortForwardProtocols) } // criAdapter wraps the Runtime functions to conform to the remotecommand interfaces. diff --git a/pkg/kubelet/server/streaming/server_test.go b/pkg/kubelet/server/streaming/server_test.go index c9581fffd7a..e8b26090ff8 100644 --- a/pkg/kubelet/server/streaming/server_test.go +++ b/pkg/kubelet/server/streaming/server_test.go @@ -240,7 +240,7 @@ func TestServePortForward(t *testing.T) { exec, err := remotecommand.NewExecutor(&restclient.Config{}, "POST", reqURL) require.NoError(t, err) - streamConn, _, err := exec.Dial(kubeletportforward.PortForwardProtocolV1Name) + streamConn, _, err := exec.Dial(kubeletportforward.ProtocolV1Name) require.NoError(t, err) defer streamConn.Close() diff --git a/pkg/registry/core/pod/rest/subresources.go b/pkg/registry/core/pod/rest/subresources.go index 9534d995385..0b3e27443b2 100644 --- a/pkg/registry/core/pod/rest/subresources.go +++ b/pkg/registry/core/pod/rest/subresources.go @@ -165,9 +165,10 @@ func (r *PortForwardREST) New() runtime.Object { return &api.Pod{} } -// NewConnectOptions returns nil since portforward doesn't take additional parameters +// NewConnectOptions returns the versioned object that represents the +// portforward parameters func (r *PortForwardREST) NewConnectOptions() (runtime.Object, bool, string) { - return nil, false, "" + return &api.PodPortForwardOptions{}, false, "" } // ConnectMethods returns the methods supported by portforward @@ -177,7 +178,11 @@ func (r *PortForwardREST) ConnectMethods() []string { // Connect returns a handler for the pod portforward proxy func (r *PortForwardREST) Connect(ctx genericapirequest.Context, name string, opts runtime.Object, responder rest.Responder) (http.Handler, error) { - location, transport, err := pod.PortForwardLocation(r.Store, r.KubeletConn, ctx, name) + portForwardOpts, ok := opts.(*api.PodPortForwardOptions) + if !ok { + return nil, fmt.Errorf("invalid options object: %#v", opts) + } + location, transport, err := pod.PortForwardLocation(r.Store, r.KubeletConn, ctx, name, portForwardOpts) if err != nil { return nil, err } diff --git a/pkg/registry/core/pod/strategy.go b/pkg/registry/core/pod/strategy.go index cc2ab68d57a..5ef5de96d24 100644 --- a/pkg/registry/core/pod/strategy.go +++ b/pkg/registry/core/pod/strategy.go @@ -383,6 +383,15 @@ func streamParams(params url.Values, opts runtime.Object) error { if opts.TTY { params.Add(api.ExecTTYParam, "1") } + case *api.PodPortForwardOptions: + if len(opts.Ports) == 0 { + return errors.NewBadRequest("at least one port must be specified") + } + ports := make([]string, len(opts.Ports)) + for i, p := range opts.Ports { + ports[i] = strconv.FormatInt(int64(p), 10) + } + params.Add(api.PortHeader, strings.Join(ports, ",")) default: return fmt.Errorf("Unknown object for streaming: %v", opts) } @@ -477,6 +486,7 @@ func PortForwardLocation( connInfo client.ConnectionInfoGetter, ctx genericapirequest.Context, name string, + opts *api.PodPortForwardOptions, ) (*url.URL, http.RoundTripper, error) { pod, err := getPod(getter, ctx, name) if err != nil { @@ -492,10 +502,15 @@ func PortForwardLocation( if err != nil { return nil, nil, err } + params := url.Values{} + if err := streamParams(params, opts); err != nil { + return nil, nil, err + } loc := &url.URL{ - Scheme: nodeInfo.Scheme, - Host: net.JoinHostPort(nodeInfo.Hostname, nodeInfo.Port), - Path: fmt.Sprintf("/portForward/%s/%s", pod.Namespace, pod.Name), + Scheme: nodeInfo.Scheme, + Host: net.JoinHostPort(nodeInfo.Hostname, nodeInfo.Port), + Path: fmt.Sprintf("/portForward/%s/%s", pod.Namespace, pod.Name), + RawQuery: params.Encode(), } return loc, nodeInfo.Transport, nil } diff --git a/pkg/registry/core/pod/strategy_test.go b/pkg/registry/core/pod/strategy_test.go index 7c6a369e8d1..6c5d140fc68 100644 --- a/pkg/registry/core/pod/strategy_test.go +++ b/pkg/registry/core/pod/strategy_test.go @@ -17,6 +17,7 @@ limitations under the License. package pod import ( + "net/url" "reflect" "testing" @@ -26,9 +27,11 @@ import ( "k8s.io/apimachinery/pkg/fields" "k8s.io/apimachinery/pkg/labels" "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/types" genericapirequest "k8s.io/apiserver/pkg/endpoints/request" "k8s.io/kubernetes/pkg/api" apitesting "k8s.io/kubernetes/pkg/api/testing" + "k8s.io/kubernetes/pkg/kubelet/client" ) func TestMatchPod(t *testing.T) { @@ -333,3 +336,64 @@ func TestSelectableFieldLabelConversions(t *testing.T) { nil, ) } + +type mockConnectionInfoGetter struct { + info *client.ConnectionInfo +} + +func (g mockConnectionInfoGetter) GetConnectionInfo(nodeName types.NodeName) (*client.ConnectionInfo, error) { + return g.info, nil +} + +func TestPortForwardLocation(t *testing.T) { + ctx := genericapirequest.NewDefaultContext() + tcs := []struct { + in *api.Pod + info *client.ConnectionInfo + opts *api.PodPortForwardOptions + expectedErr error + expectedURL *url.URL + }{ + { + in: &api.Pod{ + Spec: api.PodSpec{}, + }, + opts: &api.PodPortForwardOptions{}, + expectedErr: errors.NewBadRequest("pod test does not have a host assigned"), + }, + { + in: &api.Pod{ + Spec: api.PodSpec{ + NodeName: "node1", + }, + }, + opts: &api.PodPortForwardOptions{}, + expectedErr: errors.NewBadRequest("at least one port must be specified"), + }, + { + in: &api.Pod{ + ObjectMeta: api.ObjectMeta{ + Namespace: "ns", + Name: "pod1", + }, + Spec: api.PodSpec{ + NodeName: "node1", + }, + }, + info: &client.ConnectionInfo{}, + opts: &api.PodPortForwardOptions{Ports: []int32{80}}, + expectedURL: &url.URL{Host: ":", Path: "/portForward/ns/pod1", RawQuery: "port=80"}, + }, + } + for _, tc := range tcs { + getter := &mockPodGetter{tc.in} + connectionGetter := &mockConnectionInfoGetter{tc.info} + loc, _, err := PortForwardLocation(getter, connectionGetter, ctx, "test", tc.opts) + if !reflect.DeepEqual(err, tc.expectedErr) { + t.Errorf("expected %v, got %v", tc.expectedErr, err) + } + if !reflect.DeepEqual(loc, tc.expectedURL) { + t.Errorf("expected %v, got %v", tc.expectedURL, loc) + } + } +} diff --git a/test/e2e/portforward.go b/test/e2e/portforward.go index 55af281ef50..1a5e3fa615e 100644 --- a/test/e2e/portforward.go +++ b/test/e2e/portforward.go @@ -17,6 +17,8 @@ limitations under the License. package e2e import ( + "bytes" + "encoding/binary" "fmt" "io" "io/ioutil" @@ -28,6 +30,7 @@ import ( "syscall" "time" + "golang.org/x/net/websocket" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/util/wait" "k8s.io/kubernetes/pkg/api/v1" @@ -36,6 +39,7 @@ import ( testutils "k8s.io/kubernetes/test/utils" . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" ) const ( @@ -368,36 +372,144 @@ func doTestMustConnectSendDisconnect(bindAddress string, f *framework.Framework) verifyLogMessage(logOutput, "^Done$") } +func doTestOverWebSockets(bindAddress string, f *framework.Framework) { + config, err := framework.LoadConfig() + Expect(err).NotTo(HaveOccurred(), "unable to get base config") + + By("creating the pod") + pod := pfPod("def", "10", "10", "100", fmt.Sprintf("%s", bindAddress)) + if _, err := f.ClientSet.Core().Pods(f.Namespace.Name).Create(pod); err != nil { + framework.Failf("Couldn't create pod: %v", err) + } + if err := f.WaitForPodReady(pod.Name); err != nil { + framework.Failf("Pod did not start running: %v", err) + } + defer func() { + logs, err := framework.GetPodLogs(f.ClientSet, f.Namespace.Name, pod.Name, "portforwardtester") + if err != nil { + framework.Logf("Error getting pod log: %v", err) + } else { + framework.Logf("Pod log:\n%s", logs) + } + }() + + req := f.ClientSet.Core().RESTClient().Get(). + Namespace(f.Namespace.Name). + Resource("pods"). + Name(pod.Name). + Suffix("portforward"). + Param("ports", "80") + + url := req.URL() + ws, err := framework.OpenWebSocketForURL(url, config, []string{"v4.channel.k8s.io"}) + if err != nil { + framework.Failf("Failed to open websocket to %s: %v", url.String(), err) + } + defer ws.Close() + + Eventually(func() error { + channel, msg, err := wsRead(ws) + if err != nil { + return fmt.Errorf("Failed to read completely from websocket %s: %v", url.String(), err) + } + if channel != 0 { + return fmt.Errorf("Got message from server that didn't start with channel 0 (data): %v", msg) + } + if p := binary.LittleEndian.Uint16(msg); p != 80 { + return fmt.Errorf("Received the wrong port: %d", p) + } + return nil + }, time.Minute, 10*time.Second).Should(BeNil()) + + Eventually(func() error { + channel, msg, err := wsRead(ws) + if err != nil { + return fmt.Errorf("Failed to read completely from websocket %s: %v", url.String(), err) + } + if channel != 1 { + return fmt.Errorf("Got message from server that didn't start with channel 1 (error): %v", msg) + } + if p := binary.LittleEndian.Uint16(msg); p != 80 { + return fmt.Errorf("Received the wrong port: %d", p) + } + return nil + }, time.Minute, 10*time.Second).Should(BeNil()) + + By("sending the expected data to the local port") + err = wsWrite(ws, 0, []byte("def")) + if err != nil { + framework.Failf("Failed to write to websocket %s: %v", url.String(), err) + } + + By("reading data from the local port") + buf := bytes.Buffer{} + expectedData := bytes.Repeat([]byte("x"), 100) + Eventually(func() error { + channel, msg, err := wsRead(ws) + if err != nil { + return fmt.Errorf("Failed to read completely from websocket %s: %v", url.String(), err) + } + if channel != 0 { + return fmt.Errorf("Got message from server that didn't start with channel 0 (data): %v", msg) + } + buf.Write(msg) + if bytes.Equal(expectedData, buf.Bytes()) { + return fmt.Errorf("Expected %q from server, got %q", expectedData, buf.Bytes()) + } + return nil + }, time.Minute, 10*time.Second).Should(BeNil()) + + By("verifying logs") + logOutput, err := framework.GetPodLogs(f.ClientSet, f.Namespace.Name, pod.Name, "portforwardtester") + if err != nil { + framework.Failf("Error retrieving pod logs: %v", err) + } + verifyLogMessage(logOutput, "^Accepted client connection$") + verifyLogMessage(logOutput, "^Received expected client data$") +} + var _ = framework.KubeDescribe("Port forwarding", func() { f := framework.NewDefaultFramework("port-forwarding") - framework.KubeDescribe("With a server listening on 0.0.0.0 that expects a client request", func() { - It("should support a client that connects, sends no data, and disconnects", func() { - doTestMustConnectSendNothing("0.0.0.0", f) + framework.KubeDescribe("With a server listening on 0.0.0.0", func() { + framework.KubeDescribe("that expects a client request", func() { + It("should support a client that connects, sends no data, and disconnects", func() { + doTestMustConnectSendNothing("0.0.0.0", f) + }) + It("should support a client that connects, sends data, and disconnects", func() { + doTestMustConnectSendDisconnect("0.0.0.0", f) + }) }) - It("should support a client that connects, sends data, and disconnects", func() { - doTestMustConnectSendDisconnect("0.0.0.0", f) + + framework.KubeDescribe("that expects no client request", func() { + It("should support a client that connects, sends data, and disconnects", func() { + doTestConnectSendDisconnect("0.0.0.0", f) + }) + }) + + It("should support forwarding over websockets", func() { + doTestOverWebSockets("0.0.0.0", f) }) }) - framework.KubeDescribe("With a server listening on 0.0.0.0 that expects no client request", func() { - It("should support a client that connects, sends data, and disconnects", func() { - doTestConnectSendDisconnect("0.0.0.0", f) + framework.KubeDescribe("With a server listening on localhost", func() { + framework.KubeDescribe("that expects a client request", func() { + It("should support a client that connects, sends no data, and disconnects [Conformance]", func() { + doTestMustConnectSendNothing("localhost", f) + }) + It("should support a client that connects, sends data, and disconnects [Conformance]", func() { + doTestMustConnectSendDisconnect("localhost", f) + }) }) - }) - framework.KubeDescribe("With a server listening on localhost that expects a client request", func() { - It("should support a client that connects, sends no data, and disconnects [Conformance]", func() { - doTestMustConnectSendNothing("localhost", f) + framework.KubeDescribe("that expects no client request", func() { + It("should support a client that connects, sends data, and disconnects [Conformance]", func() { + doTestConnectSendDisconnect("localhost", f) + }) }) - It("should support a client that connects, sends data, and disconnects [Conformance]", func() { - doTestMustConnectSendDisconnect("localhost", f) - }) - }) - framework.KubeDescribe("With a server listening on localhost that expects no client request", func() { - It("should support a client that connects, sends data, and disconnects [Conformance]", func() { - doTestConnectSendDisconnect("localhost", f) + It("should support forwarding over websockets", func() { + doTestOverWebSockets("localhost", f) }) }) }) @@ -412,3 +524,30 @@ func verifyLogMessage(log, expected string) { } framework.Failf("Missing %q from log: %s", expected, log) } + +func wsRead(conn *websocket.Conn) (byte, []byte, error) { + for { + var data []byte + err := websocket.Message.Receive(conn, &data) + if err != nil { + return 0, nil, err + } + + if len(data) == 0 { + continue + } + + channel := data[0] + data = data[1:] + + return channel, data, err + } +} + +func wsWrite(conn *websocket.Conn, channel byte, data []byte) error { + frame := make([]byte, len(data)+1) + frame[0] = channel + copy(frame[1:], data) + err := websocket.Message.Send(conn, frame) + return err +}