From 52ed57ec3b6a0e47a59d23197475bb6fc5a13b97 Mon Sep 17 00:00:00 2001 From: "Tim St. Clair" Date: Tue, 25 Oct 2016 10:54:45 -0700 Subject: [PATCH] Refactor PortForward server methods into the portforward package --- pkg/client/unversioned/portforward/BUILD | 2 +- .../portforward/portforward_test.go | 8 +- pkg/kubelet/server/BUILD | 3 - pkg/kubelet/server/portforward/BUILD | 24 +- pkg/kubelet/server/portforward/portforward.go | 323 ++++++++++++++++++ .../server/portforward/portforward_test.go | 246 +++++++++++++ pkg/kubelet/server/server.go | 295 +--------------- pkg/kubelet/server/server_test.go | 220 ------------ 8 files changed, 598 insertions(+), 523 deletions(-) create mode 100644 pkg/kubelet/server/portforward/portforward.go create mode 100644 pkg/kubelet/server/portforward/portforward_test.go diff --git a/pkg/client/unversioned/portforward/BUILD b/pkg/client/unversioned/portforward/BUILD index 6a8bfc0834f..e15084d4ab4 100644 --- a/pkg/client/unversioned/portforward/BUILD +++ b/pkg/client/unversioned/portforward/BUILD @@ -33,7 +33,7 @@ go_test( deps = [ "//pkg/client/restclient:go_default_library", "//pkg/client/unversioned/remotecommand:go_default_library", - "//pkg/kubelet/server:go_default_library", + "//pkg/kubelet/server/portforward:go_default_library", "//pkg/types:go_default_library", "//pkg/util/httpstream:go_default_library", ], diff --git a/pkg/client/unversioned/portforward/portforward_test.go b/pkg/client/unversioned/portforward/portforward_test.go index 4bdd222c087..8a0d57b0952 100644 --- a/pkg/client/unversioned/portforward/portforward_test.go +++ b/pkg/client/unversioned/portforward/portforward_test.go @@ -33,7 +33,7 @@ import ( "k8s.io/kubernetes/pkg/client/restclient" "k8s.io/kubernetes/pkg/client/unversioned/remotecommand" - kubeletserver "k8s.io/kubernetes/pkg/kubelet/server" + "k8s.io/kubernetes/pkg/kubelet/server/portforward" "k8s.io/kubernetes/pkg/types" "k8s.io/kubernetes/pkg/util/httpstream" ) @@ -206,7 +206,7 @@ func TestGetListener(t *testing.T) { } // fakePortForwarder simulates port forwarding for testing. It implements -// kubeletserver.PortForwarder. +// portforward.PortForwarder. type fakePortForwarder struct { lock sync.Mutex // stores data expected from the stream per port @@ -217,7 +217,7 @@ type fakePortForwarder struct { send map[uint16]string } -var _ kubeletserver.PortForwarder = &fakePortForwarder{} +var _ portforward.PortForwarder = &fakePortForwarder{} func (pf *fakePortForwarder) PortForward(name string, uid types.UID, port uint16, stream io.ReadWriteCloser) error { defer stream.Close() @@ -252,7 +252,7 @@ func fakePortForwardServer(t *testing.T, testName string, serverSends, expectedF received: make(map[uint16]string), send: serverSends, } - kubeletserver.ServePortForward(w, req, pf, "pod", "uid", 0, 10*time.Second) + portforward.ServePortForward(w, req, pf, "pod", "uid", 0, 10*time.Second) for port, expected := range expectedFromClient { actual, ok := pf.received[port] diff --git a/pkg/kubelet/server/BUILD b/pkg/kubelet/server/BUILD index d6b3f9e3651..0a8f51fb2d6 100644 --- a/pkg/kubelet/server/BUILD +++ b/pkg/kubelet/server/BUILD @@ -37,10 +37,7 @@ go_library( "//pkg/types:go_default_library", "//pkg/util/configz:go_default_library", "//pkg/util/flushwriter:go_default_library", - "//pkg/util/httpstream:go_default_library", - "//pkg/util/httpstream/spdy:go_default_library", "//pkg/util/limitwriter:go_default_library", - "//pkg/util/runtime:go_default_library", "//pkg/util/term:go_default_library", "//pkg/volume:go_default_library", "//vendor:github.com/emicklei/go-restful", diff --git a/pkg/kubelet/server/portforward/BUILD b/pkg/kubelet/server/portforward/BUILD index f4180b9ab7f..ca09ebb98cc 100644 --- a/pkg/kubelet/server/portforward/BUILD +++ b/pkg/kubelet/server/portforward/BUILD @@ -12,6 +12,28 @@ load( go_library( name = "go_default_library", - srcs = ["constants.go"], + srcs = [ + "constants.go", + "portforward.go", + ], tags = ["automanaged"], + deps = [ + "//pkg/api:go_default_library", + "//pkg/types:go_default_library", + "//pkg/util/httpstream:go_default_library", + "//pkg/util/httpstream/spdy:go_default_library", + "//pkg/util/runtime:go_default_library", + "//vendor:github.com/golang/glog", + ], +) + +go_test( + name = "go_default_test", + srcs = ["portforward_test.go"], + library = "go_default_library", + tags = ["automanaged"], + deps = [ + "//pkg/api:go_default_library", + "//pkg/util/httpstream:go_default_library", + ], ) diff --git a/pkg/kubelet/server/portforward/portforward.go b/pkg/kubelet/server/portforward/portforward.go new file mode 100644 index 00000000000..c7ca790ff64 --- /dev/null +++ b/pkg/kubelet/server/portforward/portforward.go @@ -0,0 +1,323 @@ +/* +Copyright 2016 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package portforward + +import ( + "errors" + "fmt" + "io" + "net/http" + "strconv" + "sync" + "time" + + "github.com/golang/glog" + + "k8s.io/kubernetes/pkg/api" + "k8s.io/kubernetes/pkg/types" + "k8s.io/kubernetes/pkg/util/httpstream" + "k8s.io/kubernetes/pkg/util/httpstream/spdy" + utilruntime "k8s.io/kubernetes/pkg/util/runtime" +) + +// PortForwarder knows how to forward content from a data stream to/from a port +// in a pod. +type PortForwarder interface { + // PortForwarder copies data between a data stream and a port in a pod. + PortForward(name string, uid types.UID, port uint16, stream io.ReadWriteCloser) error +} + +// ServePortForward handles a port forwarding request. A single request is +// kept alive as long as the client is still alive and the connection has not +// been timed out due to idleness. This function handles multiple forwarded +// connections; i.e., multiple `curl http://localhost:8888/` requests will be +// handled by a single invocation of ServePortForward. +func ServePortForward(w http.ResponseWriter, req *http.Request, portForwarder PortForwarder, podName string, uid types.UID, idleTimeout time.Duration, streamCreationTimeout time.Duration) { + supportedPortForwardProtocols := []string{PortForwardProtocolV1Name} + _, err := httpstream.Handshake(req, w, supportedPortForwardProtocols) + // negotiated protocol isn't currently used server side, but could be in the future + if err != nil { + // Handshake writes the error to the client + utilruntime.HandleError(err) + return + } + + streamChan := make(chan httpstream.Stream, 1) + + glog.V(5).Infof("Upgrading port forward response") + upgrader := spdy.NewResponseUpgrader() + conn := upgrader.UpgradeResponse(w, req, portForwardStreamReceived(streamChan)) + if conn == nil { + return + } + defer conn.Close() + + glog.V(5).Infof("(conn=%p) setting port forwarding streaming connection idle timeout to %v", conn, idleTimeout) + conn.SetIdleTimeout(idleTimeout) + + h := &portForwardStreamHandler{ + conn: conn, + streamChan: streamChan, + streamPairs: make(map[string]*portForwardStreamPair), + streamCreationTimeout: streamCreationTimeout, + pod: podName, + uid: uid, + forwarder: portForwarder, + } + h.run() +} + +// portForwardStreamReceived is the httpstream.NewStreamHandler for port +// forward streams. It checks each stream's port and stream type headers, +// rejecting any streams that with missing or invalid values. Each valid +// stream is sent to the streams channel. +func portForwardStreamReceived(streams chan httpstream.Stream) func(httpstream.Stream, <-chan struct{}) error { + return func(stream httpstream.Stream, replySent <-chan struct{}) error { + // make sure it has a valid port header + portString := stream.Headers().Get(api.PortHeader) + if len(portString) == 0 { + return fmt.Errorf("%q header is required", api.PortHeader) + } + port, err := strconv.ParseUint(portString, 10, 16) + if err != nil { + return fmt.Errorf("unable to parse %q as a port: %v", portString, err) + } + if port < 1 { + return fmt.Errorf("port %q must be > 0", portString) + } + + // make sure it has a valid stream type header + streamType := stream.Headers().Get(api.StreamType) + if len(streamType) == 0 { + return fmt.Errorf("%q header is required", api.StreamType) + } + if streamType != api.StreamTypeError && streamType != api.StreamTypeData { + return fmt.Errorf("invalid stream type %q", streamType) + } + + streams <- stream + return nil + } +} + +// portForwardStreamHandler is capable of processing multiple port forward +// requests over a single httpstream.Connection. +type portForwardStreamHandler struct { + conn httpstream.Connection + streamChan chan httpstream.Stream + streamPairsLock sync.RWMutex + streamPairs map[string]*portForwardStreamPair + streamCreationTimeout time.Duration + pod string + uid types.UID + forwarder PortForwarder +} + +// getStreamPair returns a portForwardStreamPair for requestID. This creates a +// new pair if one does not yet exist for the requestID. The returned bool is +// true if the pair was created. +func (h *portForwardStreamHandler) getStreamPair(requestID string) (*portForwardStreamPair, bool) { + h.streamPairsLock.Lock() + defer h.streamPairsLock.Unlock() + + if p, ok := h.streamPairs[requestID]; ok { + glog.V(5).Infof("(conn=%p, request=%s) found existing stream pair", h.conn, requestID) + return p, false + } + + glog.V(5).Infof("(conn=%p, request=%s) creating new stream pair", h.conn, requestID) + + p := newPortForwardPair(requestID) + h.streamPairs[requestID] = p + + return p, true +} + +// monitorStreamPair waits for the pair to receive both its error and data +// streams, or for the timeout to expire (whichever happens first), and then +// removes the pair. +func (h *portForwardStreamHandler) monitorStreamPair(p *portForwardStreamPair, timeout <-chan time.Time) { + select { + case <-timeout: + err := fmt.Errorf("(conn=%v, request=%s) timed out waiting for streams", h.conn, p.requestID) + utilruntime.HandleError(err) + p.printError(err.Error()) + case <-p.complete: + glog.V(5).Infof("(conn=%v, request=%s) successfully received error and data streams", h.conn, p.requestID) + } + h.removeStreamPair(p.requestID) +} + +// hasStreamPair returns a bool indicating if a stream pair for requestID +// exists. +func (h *portForwardStreamHandler) hasStreamPair(requestID string) bool { + h.streamPairsLock.RLock() + defer h.streamPairsLock.RUnlock() + + _, ok := h.streamPairs[requestID] + return ok +} + +// removeStreamPair removes the stream pair identified by requestID from streamPairs. +func (h *portForwardStreamHandler) removeStreamPair(requestID string) { + h.streamPairsLock.Lock() + defer h.streamPairsLock.Unlock() + + delete(h.streamPairs, requestID) +} + +// requestID returns the request id for stream. +func (h *portForwardStreamHandler) requestID(stream httpstream.Stream) string { + requestID := stream.Headers().Get(api.PortForwardRequestIDHeader) + if len(requestID) == 0 { + glog.V(5).Infof("(conn=%p) stream received without %s header", h.conn, api.PortForwardRequestIDHeader) + // If we get here, it's because the connection came from an older client + // that isn't generating the request id header + // (https://github.com/kubernetes/kubernetes/blob/843134885e7e0b360eb5441e85b1410a8b1a7a0c/pkg/client/unversioned/portforward/portforward.go#L258-L287) + // + // This is a best-effort attempt at supporting older clients. + // + // When there aren't concurrent new forwarded connections, each connection + // will have a pair of streams (data, error), and the stream IDs will be + // consecutive odd numbers, e.g. 1 and 3 for the first connection. Convert + // the stream ID into a pseudo-request id by taking the stream type and + // using id = stream.Identifier() when the stream type is error, + // and id = stream.Identifier() - 2 when it's data. + // + // NOTE: this only works when there are not concurrent new streams from + // multiple forwarded connections; it's a best-effort attempt at supporting + // old clients that don't generate request ids. If there are concurrent + // new connections, it's possible that 1 connection gets streams whose IDs + // are not consecutive (e.g. 5 and 9 instead of 5 and 7). + streamType := stream.Headers().Get(api.StreamType) + switch streamType { + case api.StreamTypeError: + requestID = strconv.Itoa(int(stream.Identifier())) + case api.StreamTypeData: + requestID = strconv.Itoa(int(stream.Identifier()) - 2) + } + + glog.V(5).Infof("(conn=%p) automatically assigning request ID=%q from stream type=%s, stream ID=%d", h.conn, requestID, streamType, stream.Identifier()) + } + return requestID +} + +// run is the main loop for the portForwardStreamHandler. It processes new +// streams, invoking portForward for each complete stream pair. The loop exits +// when the httpstream.Connection is closed. +func (h *portForwardStreamHandler) run() { + glog.V(5).Infof("(conn=%p) waiting for port forward streams", h.conn) +Loop: + for { + select { + case <-h.conn.CloseChan(): + glog.V(5).Infof("(conn=%p) upgraded connection closed", h.conn) + break Loop + case stream := <-h.streamChan: + requestID := h.requestID(stream) + streamType := stream.Headers().Get(api.StreamType) + glog.V(5).Infof("(conn=%p, request=%s) received new stream of type %s", h.conn, requestID, streamType) + + p, created := h.getStreamPair(requestID) + if created { + go h.monitorStreamPair(p, time.After(h.streamCreationTimeout)) + } + if complete, err := p.add(stream); err != nil { + msg := fmt.Sprintf("error processing stream for request %s: %v", requestID, err) + utilruntime.HandleError(errors.New(msg)) + p.printError(msg) + } else if complete { + go h.portForward(p) + } + } + } +} + +// portForward invokes the portForwardStreamHandler's forwarder.PortForward +// function for the given stream pair. +func (h *portForwardStreamHandler) portForward(p *portForwardStreamPair) { + defer p.dataStream.Close() + defer p.errorStream.Close() + + portString := p.dataStream.Headers().Get(api.PortHeader) + port, _ := strconv.ParseUint(portString, 10, 16) + + glog.V(5).Infof("(conn=%p, request=%s) invoking forwarder.PortForward for port %s", h.conn, p.requestID, portString) + err := h.forwarder.PortForward(h.pod, h.uid, uint16(port), p.dataStream) + glog.V(5).Infof("(conn=%p, request=%s) done invoking forwarder.PortForward for port %s", h.conn, p.requestID, portString) + + if err != nil { + msg := fmt.Errorf("error forwarding port %d to pod %s, uid %v: %v", port, h.pod, h.uid, err) + utilruntime.HandleError(msg) + fmt.Fprint(p.errorStream, msg.Error()) + } +} + +// portForwardStreamPair represents the error and data streams for a port +// forwarding request. +type portForwardStreamPair struct { + lock sync.RWMutex + requestID string + dataStream httpstream.Stream + errorStream httpstream.Stream + complete chan struct{} +} + +// newPortForwardPair creates a new portForwardStreamPair. +func newPortForwardPair(requestID string) *portForwardStreamPair { + return &portForwardStreamPair{ + requestID: requestID, + complete: make(chan struct{}), + } +} + +// add adds the stream to the portForwardStreamPair. If the pair already +// contains a stream for the new stream's type, an error is returned. add +// returns true if both the data and error streams for this pair have been +// received. +func (p *portForwardStreamPair) add(stream httpstream.Stream) (bool, error) { + p.lock.Lock() + defer p.lock.Unlock() + + switch stream.Headers().Get(api.StreamType) { + case api.StreamTypeError: + if p.errorStream != nil { + return false, errors.New("error stream already assigned") + } + p.errorStream = stream + case api.StreamTypeData: + if p.dataStream != nil { + return false, errors.New("data stream already assigned") + } + p.dataStream = stream + } + + complete := p.errorStream != nil && p.dataStream != nil + if complete { + close(p.complete) + } + return complete, nil +} + +// printError writes s to p.errorStream if p.errorStream has been set. +func (p *portForwardStreamPair) printError(s string) { + p.lock.RLock() + defer p.lock.RUnlock() + if p.errorStream != nil { + fmt.Fprint(p.errorStream, s) + } +} diff --git a/pkg/kubelet/server/portforward/portforward_test.go b/pkg/kubelet/server/portforward/portforward_test.go new file mode 100644 index 00000000000..cfaed560823 --- /dev/null +++ b/pkg/kubelet/server/portforward/portforward_test.go @@ -0,0 +1,246 @@ +/* +Copyright 2016 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package portforward + +import ( + "net/http" + "testing" + "time" + + "k8s.io/kubernetes/pkg/api" + "k8s.io/kubernetes/pkg/util/httpstream" +) + +func TestPortForwardStreamReceived(t *testing.T) { + tests := map[string]struct { + port string + streamType string + expectedError string + }{ + "missing port": { + expectedError: `"port" header is required`, + }, + "unable to parse port": { + port: "abc", + expectedError: `unable to parse "abc" as a port: strconv.ParseUint: parsing "abc": invalid syntax`, + }, + "negative port": { + port: "-1", + expectedError: `unable to parse "-1" as a port: strconv.ParseUint: parsing "-1": invalid syntax`, + }, + "missing stream type": { + port: "80", + expectedError: `"streamType" header is required`, + }, + "valid port with error stream": { + port: "80", + streamType: "error", + }, + "valid port with data stream": { + port: "80", + streamType: "data", + }, + "invalid stream type": { + port: "80", + streamType: "foo", + expectedError: `invalid stream type "foo"`, + }, + } + for name, test := range tests { + streams := make(chan httpstream.Stream, 1) + f := portForwardStreamReceived(streams) + stream := newFakeHttpStream() + if len(test.port) > 0 { + stream.headers.Set("port", test.port) + } + if len(test.streamType) > 0 { + stream.headers.Set("streamType", test.streamType) + } + replySent := make(chan struct{}) + err := f(stream, replySent) + close(replySent) + if len(test.expectedError) > 0 { + if err == nil { + t.Errorf("%s: expected err=%q, but it was nil", name, test.expectedError) + } + if e, a := test.expectedError, err.Error(); e != a { + t.Errorf("%s: expected err=%q, got %q", name, e, a) + } + continue + } + if err != nil { + t.Errorf("%s: unexpected error %v", name, err) + continue + } + if s := <-streams; s != stream { + t.Errorf("%s: expected stream %#v, got %#v", name, stream, s) + } + } +} + +type fakeHttpStream struct { + headers http.Header + id uint32 +} + +func newFakeHttpStream() *fakeHttpStream { + return &fakeHttpStream{ + headers: make(http.Header), + } +} + +var _ httpstream.Stream = &fakeHttpStream{} + +func (s *fakeHttpStream) Read(data []byte) (int, error) { + return 0, nil +} + +func (s *fakeHttpStream) Write(data []byte) (int, error) { + return 0, nil +} + +func (s *fakeHttpStream) Close() error { + return nil +} + +func (s *fakeHttpStream) Reset() error { + return nil +} + +func (s *fakeHttpStream) Headers() http.Header { + return s.headers +} + +func (s *fakeHttpStream) Identifier() uint32 { + return s.id +} + +func TestGetStreamPair(t *testing.T) { + timeout := make(chan time.Time) + + h := &portForwardStreamHandler{ + streamPairs: make(map[string]*portForwardStreamPair), + } + + // test adding a new entry + p, created := h.getStreamPair("1") + if p == nil { + t.Fatalf("unexpected nil pair") + } + if !created { + t.Fatal("expected created=true") + } + if p.dataStream != nil { + t.Errorf("unexpected non-nil data stream") + } + if p.errorStream != nil { + t.Errorf("unexpected non-nil error stream") + } + + // start the monitor for this pair + monitorDone := make(chan struct{}) + go func() { + h.monitorStreamPair(p, timeout) + close(monitorDone) + }() + + if !h.hasStreamPair("1") { + t.Fatal("This should still be true") + } + + // make sure we can retrieve an existing entry + p2, created := h.getStreamPair("1") + if created { + t.Fatal("expected created=false") + } + if p != p2 { + t.Fatalf("retrieving an existing pair: expected %#v, got %#v", p, p2) + } + + // removed via complete + dataStream := newFakeHttpStream() + dataStream.headers.Set(api.StreamType, api.StreamTypeData) + complete, err := p.add(dataStream) + if err != nil { + t.Fatalf("unexpected error adding data stream to pair: %v", err) + } + if complete { + t.Fatalf("unexpected complete") + } + + errorStream := newFakeHttpStream() + errorStream.headers.Set(api.StreamType, api.StreamTypeError) + complete, err = p.add(errorStream) + if err != nil { + t.Fatalf("unexpected error adding error stream to pair: %v", err) + } + if !complete { + t.Fatal("unexpected incomplete") + } + + // make sure monitorStreamPair completed + <-monitorDone + + // make sure the pair was removed + if h.hasStreamPair("1") { + t.Fatal("expected removal of pair after both data and error streams received") + } + + // removed via timeout + p, created = h.getStreamPair("2") + if !created { + t.Fatal("expected created=true") + } + if p == nil { + t.Fatal("expected p not to be nil") + } + monitorDone = make(chan struct{}) + go func() { + h.monitorStreamPair(p, timeout) + close(monitorDone) + }() + // cause the timeout + close(timeout) + // make sure monitorStreamPair completed + <-monitorDone + if h.hasStreamPair("2") { + t.Fatal("expected stream pair to be removed") + } +} + +func TestRequestID(t *testing.T) { + h := &portForwardStreamHandler{} + + s := newFakeHttpStream() + s.headers.Set(api.StreamType, api.StreamTypeError) + s.id = 1 + if e, a := "1", h.requestID(s); e != a { + t.Errorf("expected %q, got %q", e, a) + } + + s.headers.Set(api.StreamType, api.StreamTypeData) + s.id = 3 + if e, a := "1", h.requestID(s); e != a { + t.Errorf("expected %q, got %q", e, a) + } + + s.id = 7 + s.headers.Set(api.PortForwardRequestIDHeader, "2") + if e, a := "2", h.requestID(s); e != a { + t.Errorf("expected %q, got %q", e, a) + } +} diff --git a/pkg/kubelet/server/server.go b/pkg/kubelet/server/server.go index 917ae890560..0aeac18e9d3 100644 --- a/pkg/kubelet/server/server.go +++ b/pkg/kubelet/server/server.go @@ -18,7 +18,6 @@ package server import ( "crypto/tls" - "errors" "fmt" "io" "net" @@ -27,7 +26,6 @@ import ( "reflect" "strconv" "strings" - "sync" "time" restful "github.com/emicklei/go-restful" @@ -54,10 +52,7 @@ import ( "k8s.io/kubernetes/pkg/types" "k8s.io/kubernetes/pkg/util/configz" "k8s.io/kubernetes/pkg/util/flushwriter" - "k8s.io/kubernetes/pkg/util/httpstream" - "k8s.io/kubernetes/pkg/util/httpstream/spdy" "k8s.io/kubernetes/pkg/util/limitwriter" - utilruntime "k8s.io/kubernetes/pkg/util/runtime" "k8s.io/kubernetes/pkg/util/term" "k8s.io/kubernetes/pkg/volume" ) @@ -650,13 +645,6 @@ func writeJsonResponse(response *restful.Response, data []byte) { } } -// PortForwarder knows how to forward content from a data stream to/from a port -// in a pod. -type PortForwarder interface { - // PortForwarder copies data between a data stream and a port in a pod. - PortForward(name string, uid types.UID, port uint16, stream io.ReadWriteCloser) error -} - // getPortForward handles a new restful port forward request. It determines the // pod name and uid and then calls ServePortForward. func (s *Server) getPortForward(request *restful.Request, response *restful.Response) { @@ -669,288 +657,7 @@ func (s *Server) getPortForward(request *restful.Request, response *restful.Resp podName := kubecontainer.GetPodFullName(pod) - ServePortForward(response.ResponseWriter, request.Request, s.host, podName, uid, s.host.StreamingConnectionIdleTimeout(), remotecommand.DefaultStreamCreationTimeout) -} - -// ServePortForward handles a port forwarding request. A single request is -// kept alive as long as the client is still alive and the connection has not -// been timed out due to idleness. This function handles multiple forwarded -// connections; i.e., multiple `curl http://localhost:8888/` requests will be -// handled by a single invocation of ServePortForward. -func ServePortForward(w http.ResponseWriter, req *http.Request, portForwarder PortForwarder, podName string, uid types.UID, idleTimeout time.Duration, streamCreationTimeout time.Duration) { - supportedPortForwardProtocols := []string{portforward.PortForwardProtocolV1Name} - _, err := httpstream.Handshake(req, w, supportedPortForwardProtocols) - // negotiated protocol isn't currently used server side, but could be in the future - if err != nil { - // Handshake writes the error to the client - utilruntime.HandleError(err) - return - } - - streamChan := make(chan httpstream.Stream, 1) - - glog.V(5).Infof("Upgrading port forward response") - upgrader := spdy.NewResponseUpgrader() - conn := upgrader.UpgradeResponse(w, req, portForwardStreamReceived(streamChan)) - if conn == nil { - return - } - defer conn.Close() - - glog.V(5).Infof("(conn=%p) setting port forwarding streaming connection idle timeout to %v", conn, idleTimeout) - conn.SetIdleTimeout(idleTimeout) - - h := &portForwardStreamHandler{ - conn: conn, - streamChan: streamChan, - streamPairs: make(map[string]*portForwardStreamPair), - streamCreationTimeout: streamCreationTimeout, - pod: podName, - uid: uid, - forwarder: portForwarder, - } - h.run() -} - -// portForwardStreamReceived is the httpstream.NewStreamHandler for port -// forward streams. It checks each stream's port and stream type headers, -// rejecting any streams that with missing or invalid values. Each valid -// stream is sent to the streams channel. -func portForwardStreamReceived(streams chan httpstream.Stream) func(httpstream.Stream, <-chan struct{}) error { - return func(stream httpstream.Stream, replySent <-chan struct{}) error { - // make sure it has a valid port header - portString := stream.Headers().Get(api.PortHeader) - if len(portString) == 0 { - return fmt.Errorf("%q header is required", api.PortHeader) - } - port, err := strconv.ParseUint(portString, 10, 16) - if err != nil { - return fmt.Errorf("unable to parse %q as a port: %v", portString, err) - } - if port < 1 { - return fmt.Errorf("port %q must be > 0", portString) - } - - // make sure it has a valid stream type header - streamType := stream.Headers().Get(api.StreamType) - if len(streamType) == 0 { - return fmt.Errorf("%q header is required", api.StreamType) - } - if streamType != api.StreamTypeError && streamType != api.StreamTypeData { - return fmt.Errorf("invalid stream type %q", streamType) - } - - streams <- stream - return nil - } -} - -// portForwardStreamHandler is capable of processing multiple port forward -// requests over a single httpstream.Connection. -type portForwardStreamHandler struct { - conn httpstream.Connection - streamChan chan httpstream.Stream - streamPairsLock sync.RWMutex - streamPairs map[string]*portForwardStreamPair - streamCreationTimeout time.Duration - pod string - uid types.UID - forwarder PortForwarder -} - -// getStreamPair returns a portForwardStreamPair for requestID. This creates a -// new pair if one does not yet exist for the requestID. The returned bool is -// true if the pair was created. -func (h *portForwardStreamHandler) getStreamPair(requestID string) (*portForwardStreamPair, bool) { - h.streamPairsLock.Lock() - defer h.streamPairsLock.Unlock() - - if p, ok := h.streamPairs[requestID]; ok { - glog.V(5).Infof("(conn=%p, request=%s) found existing stream pair", h.conn, requestID) - return p, false - } - - glog.V(5).Infof("(conn=%p, request=%s) creating new stream pair", h.conn, requestID) - - p := newPortForwardPair(requestID) - h.streamPairs[requestID] = p - - return p, true -} - -// monitorStreamPair waits for the pair to receive both its error and data -// streams, or for the timeout to expire (whichever happens first), and then -// removes the pair. -func (h *portForwardStreamHandler) monitorStreamPair(p *portForwardStreamPair, timeout <-chan time.Time) { - select { - case <-timeout: - err := fmt.Errorf("(conn=%v, request=%s) timed out waiting for streams", h.conn, p.requestID) - utilruntime.HandleError(err) - p.printError(err.Error()) - case <-p.complete: - glog.V(5).Infof("(conn=%v, request=%s) successfully received error and data streams", h.conn, p.requestID) - } - h.removeStreamPair(p.requestID) -} - -// hasStreamPair returns a bool indicating if a stream pair for requestID -// exists. -func (h *portForwardStreamHandler) hasStreamPair(requestID string) bool { - h.streamPairsLock.RLock() - defer h.streamPairsLock.RUnlock() - - _, ok := h.streamPairs[requestID] - return ok -} - -// removeStreamPair removes the stream pair identified by requestID from streamPairs. -func (h *portForwardStreamHandler) removeStreamPair(requestID string) { - h.streamPairsLock.Lock() - defer h.streamPairsLock.Unlock() - - delete(h.streamPairs, requestID) -} - -// requestID returns the request id for stream. -func (h *portForwardStreamHandler) requestID(stream httpstream.Stream) string { - requestID := stream.Headers().Get(api.PortForwardRequestIDHeader) - if len(requestID) == 0 { - glog.V(5).Infof("(conn=%p) stream received without %s header", h.conn, api.PortForwardRequestIDHeader) - // If we get here, it's because the connection came from an older client - // that isn't generating the request id header - // (https://github.com/kubernetes/kubernetes/blob/843134885e7e0b360eb5441e85b1410a8b1a7a0c/pkg/client/unversioned/portforward/portforward.go#L258-L287) - // - // This is a best-effort attempt at supporting older clients. - // - // When there aren't concurrent new forwarded connections, each connection - // will have a pair of streams (data, error), and the stream IDs will be - // consecutive odd numbers, e.g. 1 and 3 for the first connection. Convert - // the stream ID into a pseudo-request id by taking the stream type and - // using id = stream.Identifier() when the stream type is error, - // and id = stream.Identifier() - 2 when it's data. - // - // NOTE: this only works when there are not concurrent new streams from - // multiple forwarded connections; it's a best-effort attempt at supporting - // old clients that don't generate request ids. If there are concurrent - // new connections, it's possible that 1 connection gets streams whose IDs - // are not consecutive (e.g. 5 and 9 instead of 5 and 7). - streamType := stream.Headers().Get(api.StreamType) - switch streamType { - case api.StreamTypeError: - requestID = strconv.Itoa(int(stream.Identifier())) - case api.StreamTypeData: - requestID = strconv.Itoa(int(stream.Identifier()) - 2) - } - - glog.V(5).Infof("(conn=%p) automatically assigning request ID=%q from stream type=%s, stream ID=%d", h.conn, requestID, streamType, stream.Identifier()) - } - return requestID -} - -// run is the main loop for the portForwardStreamHandler. It processes new -// streams, invoking portForward for each complete stream pair. The loop exits -// when the httpstream.Connection is closed. -func (h *portForwardStreamHandler) run() { - glog.V(5).Infof("(conn=%p) waiting for port forward streams", h.conn) -Loop: - for { - select { - case <-h.conn.CloseChan(): - glog.V(5).Infof("(conn=%p) upgraded connection closed", h.conn) - break Loop - case stream := <-h.streamChan: - requestID := h.requestID(stream) - streamType := stream.Headers().Get(api.StreamType) - glog.V(5).Infof("(conn=%p, request=%s) received new stream of type %s", h.conn, requestID, streamType) - - p, created := h.getStreamPair(requestID) - if created { - go h.monitorStreamPair(p, time.After(h.streamCreationTimeout)) - } - if complete, err := p.add(stream); err != nil { - msg := fmt.Sprintf("error processing stream for request %s: %v", requestID, err) - utilruntime.HandleError(errors.New(msg)) - p.printError(msg) - } else if complete { - go h.portForward(p) - } - } - } -} - -// portForward invokes the portForwardStreamHandler's forwarder.PortForward -// function for the given stream pair. -func (h *portForwardStreamHandler) portForward(p *portForwardStreamPair) { - defer p.dataStream.Close() - defer p.errorStream.Close() - - portString := p.dataStream.Headers().Get(api.PortHeader) - port, _ := strconv.ParseUint(portString, 10, 16) - - glog.V(5).Infof("(conn=%p, request=%s) invoking forwarder.PortForward for port %s", h.conn, p.requestID, portString) - err := h.forwarder.PortForward(h.pod, h.uid, uint16(port), p.dataStream) - glog.V(5).Infof("(conn=%p, request=%s) done invoking forwarder.PortForward for port %s", h.conn, p.requestID, portString) - - if err != nil { - msg := fmt.Errorf("error forwarding port %d to pod %s, uid %v: %v", port, h.pod, h.uid, err) - utilruntime.HandleError(msg) - fmt.Fprint(p.errorStream, msg.Error()) - } -} - -// portForwardStreamPair represents the error and data streams for a port -// forwarding request. -type portForwardStreamPair struct { - lock sync.RWMutex - requestID string - dataStream httpstream.Stream - errorStream httpstream.Stream - complete chan struct{} -} - -// newPortForwardPair creates a new portForwardStreamPair. -func newPortForwardPair(requestID string) *portForwardStreamPair { - return &portForwardStreamPair{ - requestID: requestID, - complete: make(chan struct{}), - } -} - -// add adds the stream to the portForwardStreamPair. If the pair already -// contains a stream for the new stream's type, an error is returned. add -// returns true if both the data and error streams for this pair have been -// received. -func (p *portForwardStreamPair) add(stream httpstream.Stream) (bool, error) { - p.lock.Lock() - defer p.lock.Unlock() - - switch stream.Headers().Get(api.StreamType) { - case api.StreamTypeError: - if p.errorStream != nil { - return false, errors.New("error stream already assigned") - } - p.errorStream = stream - case api.StreamTypeData: - if p.dataStream != nil { - return false, errors.New("data stream already assigned") - } - p.dataStream = stream - } - - complete := p.errorStream != nil && p.dataStream != nil - if complete { - close(p.complete) - } - return complete, nil -} - -// printError writes s to p.errorStream if p.errorStream has been set. -func (p *portForwardStreamPair) printError(s string) { - p.lock.RLock() - defer p.lock.RUnlock() - if p.errorStream != nil { - fmt.Fprint(p.errorStream, s) - } + portforward.ServePortForward(response.ResponseWriter, request.Request, s.host, podName, uid, s.host.StreamingConnectionIdleTimeout(), remotecommand.DefaultStreamCreationTimeout) } // ServeHTTP responds to HTTP requests on the Kubelet. diff --git a/pkg/kubelet/server/server_test.go b/pkg/kubelet/server/server_test.go index 4674a6ae4e8..67586204a93 100644 --- a/pkg/kubelet/server/server_test.go +++ b/pkg/kubelet/server/server_test.go @@ -1530,223 +1530,3 @@ func TestServePortForward(t *testing.T) { <-portForwardFuncDone } } - -type fakeHttpStream struct { - headers http.Header - id uint32 -} - -func newFakeHttpStream() *fakeHttpStream { - return &fakeHttpStream{ - headers: make(http.Header), - } -} - -var _ httpstream.Stream = &fakeHttpStream{} - -func (s *fakeHttpStream) Read(data []byte) (int, error) { - return 0, nil -} - -func (s *fakeHttpStream) Write(data []byte) (int, error) { - return 0, nil -} - -func (s *fakeHttpStream) Close() error { - return nil -} - -func (s *fakeHttpStream) Reset() error { - return nil -} - -func (s *fakeHttpStream) Headers() http.Header { - return s.headers -} - -func (s *fakeHttpStream) Identifier() uint32 { - return s.id -} - -func TestPortForwardStreamReceived(t *testing.T) { - tests := map[string]struct { - port string - streamType string - expectedError string - }{ - "missing port": { - expectedError: `"port" header is required`, - }, - "unable to parse port": { - port: "abc", - expectedError: `unable to parse "abc" as a port: strconv.ParseUint: parsing "abc": invalid syntax`, - }, - "negative port": { - port: "-1", - expectedError: `unable to parse "-1" as a port: strconv.ParseUint: parsing "-1": invalid syntax`, - }, - "missing stream type": { - port: "80", - expectedError: `"streamType" header is required`, - }, - "valid port with error stream": { - port: "80", - streamType: "error", - }, - "valid port with data stream": { - port: "80", - streamType: "data", - }, - "invalid stream type": { - port: "80", - streamType: "foo", - expectedError: `invalid stream type "foo"`, - }, - } - for name, test := range tests { - streams := make(chan httpstream.Stream, 1) - f := portForwardStreamReceived(streams) - stream := newFakeHttpStream() - if len(test.port) > 0 { - stream.headers.Set("port", test.port) - } - if len(test.streamType) > 0 { - stream.headers.Set("streamType", test.streamType) - } - replySent := make(chan struct{}) - err := f(stream, replySent) - close(replySent) - if len(test.expectedError) > 0 { - if err == nil { - t.Errorf("%s: expected err=%q, but it was nil", name, test.expectedError) - } - if e, a := test.expectedError, err.Error(); e != a { - t.Errorf("%s: expected err=%q, got %q", name, e, a) - } - continue - } - if err != nil { - t.Errorf("%s: unexpected error %v", name, err) - continue - } - if s := <-streams; s != stream { - t.Errorf("%s: expected stream %#v, got %#v", name, stream, s) - } - } -} - -func TestGetStreamPair(t *testing.T) { - timeout := make(chan time.Time) - - h := &portForwardStreamHandler{ - streamPairs: make(map[string]*portForwardStreamPair), - } - - // test adding a new entry - p, created := h.getStreamPair("1") - if p == nil { - t.Fatalf("unexpected nil pair") - } - if !created { - t.Fatal("expected created=true") - } - if p.dataStream != nil { - t.Errorf("unexpected non-nil data stream") - } - if p.errorStream != nil { - t.Errorf("unexpected non-nil error stream") - } - - // start the monitor for this pair - monitorDone := make(chan struct{}) - go func() { - h.monitorStreamPair(p, timeout) - close(monitorDone) - }() - - if !h.hasStreamPair("1") { - t.Fatal("This should still be true") - } - - // make sure we can retrieve an existing entry - p2, created := h.getStreamPair("1") - if created { - t.Fatal("expected created=false") - } - if p != p2 { - t.Fatalf("retrieving an existing pair: expected %#v, got %#v", p, p2) - } - - // removed via complete - dataStream := newFakeHttpStream() - dataStream.headers.Set(api.StreamType, api.StreamTypeData) - complete, err := p.add(dataStream) - if err != nil { - t.Fatalf("unexpected error adding data stream to pair: %v", err) - } - if complete { - t.Fatalf("unexpected complete") - } - - errorStream := newFakeHttpStream() - errorStream.headers.Set(api.StreamType, api.StreamTypeError) - complete, err = p.add(errorStream) - if err != nil { - t.Fatalf("unexpected error adding error stream to pair: %v", err) - } - if !complete { - t.Fatal("unexpected incomplete") - } - - // make sure monitorStreamPair completed - <-monitorDone - - // make sure the pair was removed - if h.hasStreamPair("1") { - t.Fatal("expected removal of pair after both data and error streams received") - } - - // removed via timeout - p, created = h.getStreamPair("2") - if !created { - t.Fatal("expected created=true") - } - if p == nil { - t.Fatal("expected p not to be nil") - } - monitorDone = make(chan struct{}) - go func() { - h.monitorStreamPair(p, timeout) - close(monitorDone) - }() - // cause the timeout - close(timeout) - // make sure monitorStreamPair completed - <-monitorDone - if h.hasStreamPair("2") { - t.Fatal("expected stream pair to be removed") - } -} - -func TestRequestID(t *testing.T) { - h := &portForwardStreamHandler{} - - s := newFakeHttpStream() - s.headers.Set(api.StreamType, api.StreamTypeError) - s.id = 1 - if e, a := "1", h.requestID(s); e != a { - t.Errorf("expected %q, got %q", e, a) - } - - s.headers.Set(api.StreamType, api.StreamTypeData) - s.id = 3 - if e, a := "1", h.requestID(s); e != a { - t.Errorf("expected %q, got %q", e, a) - } - - s.id = 7 - s.headers.Set(api.PortForwardRequestIDHeader, "2") - if e, a := "2", h.requestID(s); e != a { - t.Errorf("expected %q, got %q", e, a) - } -}