diff --git a/pkg/client/unversioned/portforward/portforward.go b/pkg/client/unversioned/portforward/portforward.go index 36916bdcf49..a711c9f5883 100644 --- a/pkg/client/unversioned/portforward/portforward.go +++ b/pkg/client/unversioned/portforward/portforward.go @@ -122,13 +122,16 @@ func New(dialer httpstream.Dialer, ports []string, stopChan <-chan struct{}) (*P }, nil } +// The SPDY subprotocol "portforward.k8s.io" is used for port forwarding. +const PortForwardProtocolV1Name = "portforward.k8s.io" + // ForwardPorts formats and executes a port forwarding request. The connection will remain // open until stopChan is closed. func (pf *PortForwarder) ForwardPorts() error { defer pf.Close() var err error - pf.streamConn, err = pf.dialer.Dial() + pf.streamConn, _, err = pf.dialer.Dial([]string{PortForwardProtocolV1Name}) if err != nil { return fmt.Errorf("error upgrading connection: %s", err) } diff --git a/pkg/client/unversioned/portforward/portforward_test.go b/pkg/client/unversioned/portforward/portforward_test.go index f208a9e8814..7a21a9cdcd1 100644 --- a/pkg/client/unversioned/portforward/portforward_test.go +++ b/pkg/client/unversioned/portforward/portforward_test.go @@ -38,14 +38,15 @@ import ( ) type fakeDialer struct { - dialed bool - conn httpstream.Connection - err error + dialed bool + conn httpstream.Connection + err error + negotiatedProtocol string } -func (d *fakeDialer) Dial() (httpstream.Connection, error) { +func (d *fakeDialer) Dial(protocols []string) (httpstream.Connection, string, error) { d.dialed = true - return d.conn, d.err + return d.conn, d.negotiatedProtocol, d.err } func TestParsePortsAndNew(t *testing.T) { diff --git a/pkg/client/unversioned/remotecommand/remotecommand.go b/pkg/client/unversioned/remotecommand/remotecommand.go index 505e3fc9c1b..21091c85250 100644 --- a/pkg/client/unversioned/remotecommand/remotecommand.go +++ b/pkg/client/unversioned/remotecommand/remotecommand.go @@ -24,6 +24,8 @@ import ( "net/url" "sync" + "github.com/golang/glog" + "k8s.io/kubernetes/pkg/api" client "k8s.io/kubernetes/pkg/client/unversioned" "k8s.io/kubernetes/pkg/util" @@ -97,51 +99,207 @@ func NewStreamExecutor(upgrader httpstream.UpgradeRoundTripper, fn func(http.Rou }, nil } -// Dial opens a connection to a remote server and attempts to negotiate a SPDY connection. -func (e *streamExecutor) Dial() (httpstream.Connection, error) { - client := &http.Client{Transport: e.transport} +// Dial opens a connection to a remote server and attempts to negotiate a SPDY +// connection. Upon success, it returns the connection and the protocol +// selected by the server. +func (e *streamExecutor) Dial(protocols []string) (httpstream.Connection, string, error) { + transport := e.transport + // TODO consider removing this and reusing client.TransportFor above to get this for free + switch { + case bool(glog.V(9)): + transport = client.NewDebuggingRoundTripper(transport, client.CurlCommand, client.URLTiming, client.ResponseHeaders) + case bool(glog.V(8)): + transport = client.NewDebuggingRoundTripper(transport, client.JustURL, client.RequestHeaders, client.ResponseStatus, client.ResponseHeaders) + case bool(glog.V(7)): + transport = client.NewDebuggingRoundTripper(transport, client.JustURL, client.RequestHeaders, client.ResponseStatus) + case bool(glog.V(6)): + transport = client.NewDebuggingRoundTripper(transport, client.URLTiming) + } + client := &http.Client{Transport: transport} req, err := http.NewRequest(e.method, e.url.String(), nil) if err != nil { - return nil, fmt.Errorf("error creating request: %s", err) + return nil, "", fmt.Errorf("error creating request: %v", err) + } + for i := range protocols { + req.Header.Add(httpstream.HeaderProtocolVersion, protocols[i]) } resp, err := client.Do(req) if err != nil { - return nil, fmt.Errorf("error sending request: %s", err) + return nil, "", fmt.Errorf("error sending request: %v", err) } defer resp.Body.Close() - // TODO: handle protocol selection in the future - return e.upgrader.NewConnection(resp) + if resp.StatusCode != http.StatusSwitchingProtocols { + return nil, "", fmt.Errorf("unexpected response status code %d (%s)", resp.StatusCode, http.StatusText(resp.StatusCode)) + } + + conn, err := e.upgrader.NewConnection(resp) + if err != nil { + return nil, "", err + } + + return conn, resp.Header.Get(httpstream.HeaderProtocolVersion), nil +} + +const ( + // The SPDY subprotocol "channel.k8s.io" is used for remote command + // attachment/execution. This represents the initial unversioned subprotocol, + // which has the known bugs http://issues.k8s.io/13394 and + // http://issues.k8s.io/13395. + StreamProtocolV1Name = "channel.k8s.io" + // The SPDY subprotocol "v2.channel.k8s.io" is used for remote command + // attachment/execution. It is the second version of the subprotocol and + // resolves the issues present in the first version. + StreamProtocolV2Name = "v2.channel.k8s.io" +) + +type streamProtocolHandler interface { + stream(httpstream.Connection) error } // Stream opens a protocol streamer to the server and streams until a client closes // the connection or the server disconnects. func (e *streamExecutor) Stream(stdin io.Reader, stdout, stderr io.Writer, tty bool) error { - conn, err := e.Dial() + conn, protocol, err := e.Dial([]string{StreamProtocolV2Name, StreamProtocolV1Name}) if err != nil { return err } defer conn.Close() - // TODO: negotiate protocols - streamer := &streamProtocol{ - stdin: stdin, - stdout: stdout, - stderr: stderr, - tty: tty, + + var streamer streamProtocolHandler + + switch protocol { + case StreamProtocolV2Name: + streamer = &streamProtocolV2{ + stdin: stdin, + stdout: stdout, + stderr: stderr, + tty: tty, + } + case "": + glog.Warning("The server did not negotiate a streaming protocol version. Falling back to unversioned") + // TODO restore v1 + streamer = &streamProtocolV1{ + stdin: stdin, + stdout: stdout, + stderr: stderr, + tty: tty, + } } + return streamer.stream(conn) } -type streamProtocol struct { +type streamProtocolV1 struct { stdin io.Reader stdout io.Writer stderr io.Writer tty bool } -func (e *streamProtocol) stream(conn httpstream.Connection) error { +func (e *streamProtocolV1) stream(conn httpstream.Connection) error { + doneChan := make(chan struct{}, 2) + errorChan := make(chan error) + + cp := func(s string, dst io.Writer, src io.Reader) { + glog.V(4).Infof("Copying %s", s) + defer glog.V(4).Infof("Done copying %s", s) + if _, err := io.Copy(dst, src); err != nil && err != io.EOF { + glog.Errorf("Error copying %s: %v", s, err) + } + if s == api.StreamTypeStdout || s == api.StreamTypeStderr { + doneChan <- struct{}{} + } + } + + headers := http.Header{} + headers.Set(api.StreamType, api.StreamTypeError) + errorStream, err := conn.CreateStream(headers) + if err != nil { + return err + } + go func() { + message, err := ioutil.ReadAll(errorStream) + if err != nil && err != io.EOF { + errorChan <- fmt.Errorf("Error reading from error stream: %s", err) + return + } + if len(message) > 0 { + errorChan <- fmt.Errorf("Error executing remote command: %s", message) + return + } + }() + defer errorStream.Reset() + + if e.stdin != nil { + headers.Set(api.StreamType, api.StreamTypeStdin) + remoteStdin, err := conn.CreateStream(headers) + if err != nil { + return err + } + defer remoteStdin.Reset() + // TODO this goroutine will never exit cleanly (the io.Copy never unblocks) + // because stdin is not closed until the process exits. If we try to call + // stdin.Close(), it returns no error but doesn't unblock the copy. It will + // exit when the process exits, instead. + go cp(api.StreamTypeStdin, remoteStdin, e.stdin) + } + + waitCount := 0 + completedStreams := 0 + + if e.stdout != nil { + waitCount++ + headers.Set(api.StreamType, api.StreamTypeStdout) + remoteStdout, err := conn.CreateStream(headers) + if err != nil { + return err + } + defer remoteStdout.Reset() + go cp(api.StreamTypeStdout, e.stdout, remoteStdout) + } + + if e.stderr != nil && !e.tty { + waitCount++ + headers.Set(api.StreamType, api.StreamTypeStderr) + remoteStderr, err := conn.CreateStream(headers) + if err != nil { + return err + } + defer remoteStderr.Reset() + go cp(api.StreamTypeStderr, e.stderr, remoteStderr) + } + +Loop: + for { + select { + case <-doneChan: + completedStreams++ + if completedStreams == waitCount { + break Loop + } + case err := <-errorChan: + return err + } + } + + return nil +} + +// streamProtocolV2 implements version 2 of the streaming protocol for attach +// and exec. The original streaming protocol was unversioned. As a result, this +// version is referred to as version 2, even though it is the first actual +// numbered version. +type streamProtocolV2 struct { + stdin io.Reader + stdout io.Writer + stderr io.Writer + tty bool +} + +func (e *streamProtocolV2) stream(conn httpstream.Connection) error { headers := http.Header{} // set up error stream diff --git a/pkg/client/unversioned/remotecommand/remotecommand_test.go b/pkg/client/unversioned/remotecommand/remotecommand_test.go index 56cb3372bfb..5403dacbacf 100644 --- a/pkg/client/unversioned/remotecommand/remotecommand_test.go +++ b/pkg/client/unversioned/remotecommand/remotecommand_test.go @@ -45,7 +45,7 @@ func fakeExecServer(t *testing.T, i int, stdinData, stdoutData, stderrData, erro streamCh := make(chan httpstream.Stream) upgrader := spdy.NewResponseUpgrader() - conn := upgrader.UpgradeResponse(w, req, func(stream httpstream.Stream) error { + conn, protocol := upgrader.UpgradeResponse(w, req, []string{StreamProtocolV2Name, StreamProtocolV1Name}, func(stream httpstream.Stream) error { streamCh <- stream return nil }) @@ -57,6 +57,7 @@ func fakeExecServer(t *testing.T, i int, stdinData, stdoutData, stderrData, erro return } defer conn.Close() + _ = protocol var errorStream, stdinStream, stdoutStream, stderrStream httpstream.Stream receivedStreams := 0 @@ -347,7 +348,7 @@ func TestDial(t *testing.T) { checkResponse: true, conn: &fakeConnection{}, resp: &http.Response{ - StatusCode: http.StatusOK, + StatusCode: http.StatusSwitchingProtocols, Body: ioutil.NopCloser(&bytes.Buffer{}), }, } @@ -363,7 +364,7 @@ func TestDial(t *testing.T) { if err != nil { t.Fatal(err) } - conn, err := exec.Dial() + conn, protocol, err := exec.Dial([]string{"a", "b"}) if err != nil { t.Fatal(err) } @@ -373,4 +374,5 @@ func TestDial(t *testing.T) { if !called { t.Errorf("wrapper not called") } + _ = protocol } diff --git a/pkg/kubelet/server.go b/pkg/kubelet/server.go index c8b2df9d613..04969592f65 100644 --- a/pkg/kubelet/server.go +++ b/pkg/kubelet/server.go @@ -44,6 +44,8 @@ import ( "k8s.io/kubernetes/pkg/api/validation" "k8s.io/kubernetes/pkg/auth/authenticator" "k8s.io/kubernetes/pkg/auth/authorizer" + "k8s.io/kubernetes/pkg/client/unversioned/portforward" + "k8s.io/kubernetes/pkg/client/unversioned/remotecommand" "k8s.io/kubernetes/pkg/healthz" "k8s.io/kubernetes/pkg/httplog" kubecontainer "k8s.io/kubernetes/pkg/kubelet/container" @@ -688,7 +690,7 @@ func (s *Server) createStreams(request *restful.Request, response *restful.Respo streamCh := make(chan httpstream.Stream) upgrader := spdy.NewResponseUpgrader() - conn := upgrader.UpgradeResponse(response.ResponseWriter, request.Request, func(stream httpstream.Stream) error { + conn, protocol := upgrader.UpgradeResponse(response.ResponseWriter, request.Request, []string{remotecommand.StreamProtocolV2Name, remotecommand.StreamProtocolV1Name}, func(stream httpstream.Stream) error { streamCh <- stream return nil }) @@ -699,6 +701,9 @@ func (s *Server) createStreams(request *restful.Request, response *restful.Respo // if we weren't successful in upgrading. return nil, nil, nil, nil, nil, false, false } + if len(protocol) == 0 { + protocol = remotecommand.StreamProtocolV1Name + } conn.SetIdleTimeout(s.host.StreamingConnectionIdleTimeout()) @@ -783,12 +788,14 @@ func ServePortForward(w http.ResponseWriter, req *http.Request, portForwarder Po glog.V(5).Infof("Upgrading port forward response") upgrader := spdy.NewResponseUpgrader() - conn := upgrader.UpgradeResponse(w, req, portForwardStreamReceived(streamChan)) + conn, protocol := upgrader.UpgradeResponse(w, req, []string{portforward.PortForwardProtocolV1Name}, portForwardStreamReceived(streamChan)) if conn == nil { return } defer conn.Close() + _ = protocol + glog.V(5).Infof("(conn=%p) setting port forwarding streaming connection idle timeout to %v", conn, idleTimeout) conn.SetIdleTimeout(idleTimeout) diff --git a/pkg/util/httpstream/httpstream.go b/pkg/util/httpstream/httpstream.go index fff6840f279..9f119b7b7ce 100644 --- a/pkg/util/httpstream/httpstream.go +++ b/pkg/util/httpstream/httpstream.go @@ -24,8 +24,9 @@ import ( ) const ( - HeaderConnection = "Connection" - HeaderUpgrade = "Upgrade" + HeaderConnection = "Connection" + HeaderUpgrade = "Upgrade" + HeaderProtocolVersion = "X-Stream-Protocol-Version" ) // NewStreamHandler defines a function that is called when a new Stream is @@ -39,7 +40,10 @@ func NoOpNewStreamHandler(stream Stream) error { return nil } // Dialer knows how to open a streaming connection to a server. type Dialer interface { - Dial() (Connection, error) + + // Dial opens a streaming connection to a server using one of the protocols + // specified (in order of most preferred to least preferred). + Dial(protocols []string) (Connection, string, error) } // UpgradeRoundTripper is a type of http.RoundTripper that is able to upgrade @@ -58,7 +62,7 @@ type ResponseUpgrader interface { // UpgradeResponse upgrades an HTTP response to one that supports multiplexed // streams. newStreamHandler will be called synchronously whenever the // other end of the upgraded connection creates a new stream. - UpgradeResponse(w http.ResponseWriter, req *http.Request, newStreamHandler NewStreamHandler) Connection + UpgradeResponse(w http.ResponseWriter, req *http.Request, protocols []string, newStreamHandler NewStreamHandler) (Connection, string) } // Connection represents an upgraded HTTP connection. diff --git a/pkg/util/httpstream/spdy/roundtripper_test.go b/pkg/util/httpstream/spdy/roundtripper_test.go index babd23c9011..f16a6e697f3 100644 --- a/pkg/util/httpstream/spdy/roundtripper_test.go +++ b/pkg/util/httpstream/spdy/roundtripper_test.go @@ -120,7 +120,7 @@ func TestRoundTripAndNewConnection(t *testing.T) { streamCh := make(chan httpstream.Stream) responseUpgrader := NewResponseUpgrader() - spdyConn := responseUpgrader.UpgradeResponse(w, req, func(s httpstream.Stream) error { + spdyConn, _ := responseUpgrader.UpgradeResponse(w, req, []string{"protocol1"}, func(s httpstream.Stream) error { streamCh <- s return nil }) diff --git a/pkg/util/httpstream/spdy/upgrade.go b/pkg/util/httpstream/spdy/upgrade.go index 70c90c8f499..c3e79aa3d17 100644 --- a/pkg/util/httpstream/spdy/upgrade.go +++ b/pkg/util/httpstream/spdy/upgrade.go @@ -39,23 +39,47 @@ func NewResponseUpgrader() httpstream.ResponseUpgrader { return responseUpgrader{} } +func negotiateProtocol(clientProtocols, serverProtocols []string) string { + for i := range clientProtocols { + for j := range serverProtocols { + if clientProtocols[i] == serverProtocols[j] { + return clientProtocols[i] + } + } + } + return "" +} + // UpgradeResponse upgrades an HTTP response to one that supports multiplexed // streams. newStreamHandler will be called synchronously whenever the // other end of the upgraded connection creates a new stream. -func (u responseUpgrader) UpgradeResponse(w http.ResponseWriter, req *http.Request, newStreamHandler httpstream.NewStreamHandler) httpstream.Connection { +func (u responseUpgrader) UpgradeResponse(w http.ResponseWriter, req *http.Request, protocols []string, newStreamHandler httpstream.NewStreamHandler) (httpstream.Connection, string) { connectionHeader := strings.ToLower(req.Header.Get(httpstream.HeaderConnection)) upgradeHeader := strings.ToLower(req.Header.Get(httpstream.HeaderUpgrade)) if !strings.Contains(connectionHeader, strings.ToLower(httpstream.HeaderUpgrade)) || !strings.Contains(upgradeHeader, strings.ToLower(HeaderSpdy31)) { - w.Write([]byte(fmt.Sprintf("Unable to upgrade: missing upgrade headers in request: %#v", req.Header))) w.WriteHeader(http.StatusBadRequest) - return nil + fmt.Fprintf(w, "unable to upgrade: missing upgrade headers in request: %#v", req.Header) + return nil, "" } hijacker, ok := w.(http.Hijacker) if !ok { - w.Write([]byte("Unable to upgrade: unable to hijack response")) w.WriteHeader(http.StatusInternalServerError) - return nil + fmt.Fprintf(w, "unable to upgrade: unable to hijack response") + return nil, "" + } + + var negotiatedProtocol string + clientProtocols := req.Header[http.CanonicalHeaderKey(httpstream.HeaderProtocolVersion)] + if len(clientProtocols) > 0 { + negotiatedProtocol = negotiateProtocol(req.Header[http.CanonicalHeaderKey(httpstream.HeaderProtocolVersion)], protocols) + if len(negotiatedProtocol) > 0 { + w.Header().Add(httpstream.HeaderProtocolVersion, negotiatedProtocol) + } else { + w.WriteHeader(http.StatusForbidden) + fmt.Fprintf(w, "unable to upgrade: unable to negotiate protocol: server accepts %v", protocols) + return nil, "" + } } w.Header().Add(httpstream.HeaderConnection, httpstream.HeaderUpgrade) @@ -64,15 +88,15 @@ func (u responseUpgrader) UpgradeResponse(w http.ResponseWriter, req *http.Reque conn, _, err := hijacker.Hijack() if err != nil { - glog.Errorf("Unable to upgrade: error hijacking response: %v", err) - return nil + glog.Errorf("unable to upgrade: error hijacking response: %v", err) + return nil, "" } spdyConn, err := NewServerConnection(conn, newStreamHandler) if err != nil { - glog.Errorf("Unable to upgrade: error creating SPDY server connection: %v", err) - return nil + glog.Errorf("unable to upgrade: error creating SPDY server connection: %v", err) + return nil, "" } - return spdyConn + return spdyConn, negotiatedProtocol } diff --git a/pkg/util/httpstream/spdy/upgrade_test.go b/pkg/util/httpstream/spdy/upgrade_test.go index 4e111407e87..e82f3515ee1 100644 --- a/pkg/util/httpstream/spdy/upgrade_test.go +++ b/pkg/util/httpstream/spdy/upgrade_test.go @@ -53,7 +53,8 @@ func TestUpgradeResponse(t *testing.T) { for i, testCase := range testCases { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { upgrader := NewResponseUpgrader() - conn := upgrader.UpgradeResponse(w, req, nil) + conn, protocol := upgrader.UpgradeResponse(w, req, []string{"protocol1"}, nil) + _ = protocol haveErr := conn == nil if e, a := testCase.shouldError, haveErr; e != a { t.Fatalf("%d: expected shouldErr=%t, got %t", i, testCase.shouldError, haveErr)