diff --git a/Godeps/Godeps.json b/Godeps/Godeps.json index 736ac9c9992..3a265b7e87d 100644 --- a/Godeps/Godeps.json +++ b/Godeps/Godeps.json @@ -403,7 +403,7 @@ }, { "ImportPath": "github.com/docker/spdystream", - "Rev": "c33989bcb56748d2473194d11f8ac3fc563688eb" + "Rev": "106e140db2cb50923efe088bf2906b2ee5a45fec" }, { "ImportPath": "github.com/elazarl/go-bindata-assetfs", diff --git a/Godeps/_workspace/src/github.com/docker/spdystream/connection.go b/Godeps/_workspace/src/github.com/docker/spdystream/connection.go index aa6c75202e3..6031a0db1ab 100644 --- a/Godeps/_workspace/src/github.com/docker/spdystream/connection.go +++ b/Godeps/_workspace/src/github.com/docker/spdystream/connection.go @@ -320,6 +320,7 @@ func (s *Connection) Serve(newHandler StreamHandler) { partitionRoundRobin int goAwayFrame *spdy.GoAwayFrame ) +Loop: for { readFrame, err := s.framer.ReadFrame() if err != nil { @@ -362,7 +363,7 @@ func (s *Connection) Serve(newHandler StreamHandler) { case *spdy.GoAwayFrame: // hold on to the go away frame and exit the loop goAwayFrame = frame - break + break Loop default: priority = 7 partition = partitionRoundRobin diff --git a/pkg/client/unversioned/remotecommand/remotecommand_test.go b/pkg/client/unversioned/remotecommand/remotecommand_test.go index 68fb9eb129b..71acd239d33 100644 --- a/pkg/client/unversioned/remotecommand/remotecommand_test.go +++ b/pkg/client/unversioned/remotecommand/remotecommand_test.go @@ -34,6 +34,19 @@ import ( "k8s.io/kubernetes/pkg/util/httpstream/spdy" ) +type streamAndReply struct { + httpstream.Stream + replySent <-chan struct{} +} + +func waitStreamReply(replySent <-chan struct{}, notify chan<- struct{}, stop <-chan struct{}) { + select { + case <-replySent: + notify <- struct{}{} + case <-stop: + } +} + func fakeExecServer(t *testing.T, i int, stdinData, stdoutData, stderrData, errorData string, tty bool, messageCount int) http.HandlerFunc { // error + stdin + stdout expectedStreams := 3 @@ -50,11 +63,11 @@ func fakeExecServer(t *testing.T, i int, stdinData, stdoutData, stderrData, erro if protocol != StreamProtocolV2Name { t.Fatalf("unexpected protocol: %s", protocol) } - streamCh := make(chan httpstream.Stream) + streamCh := make(chan streamAndReply) upgrader := spdy.NewResponseUpgrader() - conn := upgrader.UpgradeResponse(w, req, func(stream httpstream.Stream) error { - streamCh <- stream + conn := upgrader.UpgradeResponse(w, req, func(stream httpstream.Stream, replySent <-chan struct{}) error { + streamCh <- streamAndReply{Stream: stream, replySent: replySent} return nil }) // from this point on, we can no longer call methods on w @@ -68,6 +81,9 @@ func fakeExecServer(t *testing.T, i int, stdinData, stdoutData, stderrData, erro var errorStream, stdinStream, stdoutStream, stderrStream httpstream.Stream receivedStreams := 0 + replyChan := make(chan struct{}) + stop := make(chan struct{}) + defer close(stop) WaitForStreams: for { select { @@ -76,20 +92,25 @@ func fakeExecServer(t *testing.T, i int, stdinData, stdoutData, stderrData, erro switch streamType { case api.StreamTypeError: errorStream = stream - receivedStreams++ + go waitStreamReply(stream.replySent, replyChan, stop) case api.StreamTypeStdin: stdinStream = stream - receivedStreams++ + go waitStreamReply(stream.replySent, replyChan, stop) case api.StreamTypeStdout: stdoutStream = stream - receivedStreams++ + go waitStreamReply(stream.replySent, replyChan, stop) case api.StreamTypeStderr: stderrStream = stream - receivedStreams++ + go waitStreamReply(stream.replySent, replyChan, stop) default: t.Errorf("%d: unexpected stream type: %q", i, streamType) } + if receivedStreams == expectedStreams { + break WaitForStreams + } + case <-replyChan: + receivedStreams++ if receivedStreams == expectedStreams { break WaitForStreams } diff --git a/pkg/kubelet/server/server.go b/pkg/kubelet/server/server.go index a5f6cfda0d8..88b7ee9f5fc 100644 --- a/pkg/kubelet/server/server.go +++ b/pkg/kubelet/server/server.go @@ -617,6 +617,15 @@ func standardShellChannels(stdin, stdout, stderr bool) []wsstream.ChannelType { return channels } +// streamAndReply holds both a Stream and a channel that is closed when the stream's reply frame is +// enqueued. Consumers can wait for replySent to be closed prior to proceeding, to ensure that the +// replyFrame is enqueued before the connection's goaway frame is sent (e.g. if a stream was +// received and right after, the connection gets closed). +type streamAndReply struct { + httpstream.Stream + replySent <-chan struct{} +} + func (s *Server) createStreams(request *restful.Request, response *restful.Response) (io.Reader, io.WriteCloser, io.WriteCloser, io.WriteCloser, Closer, bool, bool) { tty := request.QueryParameter(api.ExecTTYParam) == "1" stdin := request.QueryParameter(api.ExecStdinParam) == "1" @@ -675,11 +684,11 @@ func (s *Server) createStreams(request *restful.Request, response *restful.Respo return nil, nil, nil, nil, nil, false, false } - streamCh := make(chan httpstream.Stream) + streamCh := make(chan streamAndReply) upgrader := spdy.NewResponseUpgrader() - conn := upgrader.UpgradeResponse(response.ResponseWriter, request.Request, func(stream httpstream.Stream) error { - streamCh <- stream + conn := upgrader.UpgradeResponse(response.ResponseWriter, request.Request, func(stream httpstream.Stream, replySent <-chan struct{}) error { + streamCh <- streamAndReply{Stream: stream, replySent: replySent} return nil }) // from this point on, we can no longer call methods on response @@ -697,6 +706,9 @@ func (s *Server) createStreams(request *restful.Request, response *restful.Respo var errorStream, stdinStream, stdoutStream, stderrStream httpstream.Stream receivedStreams := 0 + replyChan := make(chan struct{}) + stop := make(chan struct{}) + defer close(stop) WaitForStreams: for { select { @@ -705,19 +717,21 @@ WaitForStreams: switch streamType { case api.StreamTypeError: errorStream = stream - receivedStreams++ + go waitStreamReply(stream.replySent, replyChan, stop) case api.StreamTypeStdin: stdinStream = stream - receivedStreams++ + go waitStreamReply(stream.replySent, replyChan, stop) case api.StreamTypeStdout: stdoutStream = stream - receivedStreams++ + go waitStreamReply(stream.replySent, replyChan, stop) case api.StreamTypeStderr: stderrStream = stream - receivedStreams++ + go waitStreamReply(stream.replySent, replyChan, stop) default: glog.Errorf("Unexpected stream type: '%s'", streamType) } + case <-replyChan: + receivedStreams++ if receivedStreams == expectedStreams { break WaitForStreams } @@ -732,6 +746,16 @@ WaitForStreams: return stdinStream, stdoutStream, stderrStream, errorStream, conn, tty, true } +// waitStreamReply waits until either replySent or stop is closed. If replySent is closed, it sends +// an empty struct to the notify channel. +func waitStreamReply(replySent <-chan struct{}, notify chan<- struct{}, stop <-chan struct{}) { + select { + case <-replySent: + notify <- struct{}{} + case <-stop: + } +} + func getPodCoordinates(request *restful.Request) (namespace, pod string, uid types.UID) { namespace = request.PathParameter("podNamespace") pod = request.PathParameter("podID") @@ -807,8 +831,8 @@ func ServePortForward(w http.ResponseWriter, req *http.Request, portForwarder Po // forward streams. It checks each stream's port and stream type headers, // rejecting any streams that with missing or invalid values. Each valid // stream is sent to the streams channel. -func portForwardStreamReceived(streams chan httpstream.Stream) func(httpstream.Stream) error { - return func(stream httpstream.Stream) error { +func portForwardStreamReceived(streams chan httpstream.Stream) func(httpstream.Stream, <-chan struct{}) error { + return func(stream httpstream.Stream, replySent <-chan struct{}) error { // make sure it has a valid port header portString := stream.Headers().Get(api.PortHeader) if len(portString) == 0 { diff --git a/pkg/kubelet/server/server_test.go b/pkg/kubelet/server/server_test.go index 5854c4ed522..cf5dd1ea7ee 100644 --- a/pkg/kubelet/server/server_test.go +++ b/pkg/kubelet/server/server_test.go @@ -1727,7 +1727,9 @@ func TestPortForwardStreamReceived(t *testing.T) { if len(test.streamType) > 0 { stream.headers.Set("streamType", test.streamType) } - err := f(stream) + replySent := make(chan struct{}) + err := f(stream, replySent) + close(replySent) if len(test.expectedError) > 0 { if err == nil { t.Errorf("%s: expected err=%q, but it was nil", name, test.expectedError) diff --git a/pkg/util/httpstream/httpstream.go b/pkg/util/httpstream/httpstream.go index 80c3cd78fc0..4f6b608ce7a 100644 --- a/pkg/util/httpstream/httpstream.go +++ b/pkg/util/httpstream/httpstream.go @@ -33,12 +33,12 @@ const ( // NewStreamHandler defines a function that is called when a new Stream is // received. If no error is returned, the Stream is accepted; otherwise, -// the stream is rejected. -type NewStreamHandler func(Stream) error +// the stream is rejected. After the reply frame has been sent, replySent is closed. +type NewStreamHandler func(stream Stream, replySent <-chan struct{}) error // NoOpNewStreamHandler is a stream handler that accepts a new stream and // performs no other logic. -func NoOpNewStreamHandler(stream Stream) error { return nil } +func NoOpNewStreamHandler(stream Stream, replySent <-chan struct{}) error { return nil } // Dialer knows how to open a streaming connection to a server. type Dialer interface { diff --git a/pkg/util/httpstream/spdy/connection.go b/pkg/util/httpstream/spdy/connection.go index 6d4855d195f..884c6e20360 100644 --- a/pkg/util/httpstream/spdy/connection.go +++ b/pkg/util/httpstream/spdy/connection.go @@ -120,7 +120,8 @@ func (c *connection) CloseChan() <-chan bool { // the stream. If newStreamHandler returns an error, the stream is rejected. If not, the // stream is accepted and registered with the connection. func (c *connection) newSpdyStream(stream *spdystream.Stream) { - err := c.newStreamHandler(stream) + replySent := make(chan struct{}) + err := c.newStreamHandler(stream, replySent) rejectStream := (err != nil) if rejectStream { glog.Warningf("Stream rejected: %v", err) @@ -130,6 +131,7 @@ func (c *connection) newSpdyStream(stream *spdystream.Stream) { c.registerStream(stream) stream.SendReply(http.Header{}, rejectStream) + close(replySent) } // SetIdleTimeout sets the amount of time the connection may remain idle before diff --git a/pkg/util/httpstream/spdy/roundtripper_test.go b/pkg/util/httpstream/spdy/roundtripper_test.go index 651dd89f347..7ca1d42ad34 100644 --- a/pkg/util/httpstream/spdy/roundtripper_test.go +++ b/pkg/util/httpstream/spdy/roundtripper_test.go @@ -134,7 +134,7 @@ func TestRoundTripAndNewConnection(t *testing.T) { streamCh := make(chan httpstream.Stream) responseUpgrader := NewResponseUpgrader() - spdyConn := responseUpgrader.UpgradeResponse(w, req, func(s httpstream.Stream) error { + spdyConn := responseUpgrader.UpgradeResponse(w, req, func(s httpstream.Stream, replySent <-chan struct{}) error { streamCh <- s return nil })