From ff9883d9ecd32f9db1077a4b6aeee3c112b0595d Mon Sep 17 00:00:00 2001 From: Andy Goldstein Date: Wed, 21 Oct 2015 20:42:40 -0400 Subject: [PATCH] Address code review comments --- .../unversioned/portforward/portforward.go | 6 +- .../portforward/portforward_test.go | 2 +- .../remotecommand/remotecommand.go | 11 ++-- .../remotecommand/remotecommand_test.go | 13 ++++- pkg/client/unversioned/remotecommand/v1.go | 8 ++- pkg/kubelet/server.go | 29 +++++++--- pkg/util/httpstream/httpstream.go | 55 +++++++++++++++++-- pkg/util/httpstream/spdy/roundtripper_test.go | 2 +- pkg/util/httpstream/spdy/upgrade.go | 42 +++----------- pkg/util/httpstream/spdy/upgrade_test.go | 3 +- 10 files changed, 107 insertions(+), 64 deletions(-) diff --git a/pkg/client/unversioned/portforward/portforward.go b/pkg/client/unversioned/portforward/portforward.go index a711c9f5883..3c279e7aa9b 100644 --- a/pkg/client/unversioned/portforward/portforward.go +++ b/pkg/client/unversioned/portforward/portforward.go @@ -29,6 +29,7 @@ import ( "github.com/golang/glog" "k8s.io/kubernetes/pkg/api" + "k8s.io/kubernetes/pkg/kubelet" "k8s.io/kubernetes/pkg/util" "k8s.io/kubernetes/pkg/util/httpstream" ) @@ -122,16 +123,13 @@ func New(dialer httpstream.Dialer, ports []string, stopChan <-chan struct{}) (*P }, nil } -// The SPDY subprotocol "portforward.k8s.io" is used for port forwarding. -const PortForwardProtocolV1Name = "portforward.k8s.io" - // ForwardPorts formats and executes a port forwarding request. The connection will remain // open until stopChan is closed. func (pf *PortForwarder) ForwardPorts() error { defer pf.Close() var err error - pf.streamConn, _, err = pf.dialer.Dial([]string{PortForwardProtocolV1Name}) + pf.streamConn, _, err = pf.dialer.Dial(kubelet.PortForwardProtocolV1Name) if err != nil { return fmt.Errorf("error upgrading connection: %s", err) } diff --git a/pkg/client/unversioned/portforward/portforward_test.go b/pkg/client/unversioned/portforward/portforward_test.go index 7a21a9cdcd1..e82b74f6497 100644 --- a/pkg/client/unversioned/portforward/portforward_test.go +++ b/pkg/client/unversioned/portforward/portforward_test.go @@ -44,7 +44,7 @@ type fakeDialer struct { negotiatedProtocol string } -func (d *fakeDialer) Dial(protocols []string) (httpstream.Connection, string, error) { +func (d *fakeDialer) Dial(protocols ...string) (httpstream.Connection, string, error) { d.dialed = true return d.conn, d.negotiatedProtocol, d.err } diff --git a/pkg/client/unversioned/remotecommand/remotecommand.go b/pkg/client/unversioned/remotecommand/remotecommand.go index 69f150b8630..99e914abebf 100644 --- a/pkg/client/unversioned/remotecommand/remotecommand.go +++ b/pkg/client/unversioned/remotecommand/remotecommand.go @@ -98,7 +98,7 @@ func NewStreamExecutor(upgrader httpstream.UpgradeRoundTripper, fn func(http.Rou // Dial opens a connection to a remote server and attempts to negotiate a SPDY // connection. Upon success, it returns the connection and the protocol // selected by the server. -func (e *streamExecutor) Dial(protocols []string) (httpstream.Connection, string, error) { +func (e *streamExecutor) Dial(protocols ...string) (httpstream.Connection, string, error) { transport := e.transport // TODO consider removing this and reusing client.TransportFor above to get this for free switch { @@ -111,6 +111,9 @@ func (e *streamExecutor) Dial(protocols []string) (httpstream.Connection, string case bool(glog.V(6)): transport = client.NewDebuggingRoundTripper(transport, client.URLTiming) } + + // TODO the client probably shouldn't be created here, as it doesn't allow + // flexibility to allow callers to configure it. client := &http.Client{Transport: transport} req, err := http.NewRequest(e.method, e.url.String(), nil) @@ -158,7 +161,8 @@ type streamProtocolHandler interface { // Stream opens a protocol streamer to the server and streams until a client closes // the connection or the server disconnects. func (e *streamExecutor) Stream(stdin io.Reader, stdout, stderr io.Writer, tty bool) error { - conn, protocol, err := e.Dial([]string{StreamProtocolV2Name, StreamProtocolV1Name}) + supportedProtocols := []string{StreamProtocolV2Name, StreamProtocolV1Name} + conn, protocol, err := e.Dial(supportedProtocols...) if err != nil { return err } @@ -175,8 +179,7 @@ func (e *streamExecutor) Stream(stdin io.Reader, stdout, stderr io.Writer, tty b tty: tty, } case "": - glog.Warning("The server did not negotiate a streaming protocol version. Falling back to unversioned") - // TODO restore v1 + glog.V(4).Infof("The server did not negotiate a streaming protocol version. Falling back to unversioned") streamer = &streamProtocolV1{ stdin: stdin, stdout: stdout, diff --git a/pkg/client/unversioned/remotecommand/remotecommand_test.go b/pkg/client/unversioned/remotecommand/remotecommand_test.go index 5403dacbacf..536c63ee080 100644 --- a/pkg/client/unversioned/remotecommand/remotecommand_test.go +++ b/pkg/client/unversioned/remotecommand/remotecommand_test.go @@ -42,10 +42,17 @@ func fakeExecServer(t *testing.T, i int, stdinData, stdoutData, stderrData, erro } return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + protocol, err := httpstream.Handshake(req, w, []string{StreamProtocolV2Name}, StreamProtocolV1Name) + if err != nil { + t.Fatal(err) + } + if protocol != StreamProtocolV2Name { + t.Fatalf("unexpected protocol: %s", protocol) + } streamCh := make(chan httpstream.Stream) upgrader := spdy.NewResponseUpgrader() - conn, protocol := upgrader.UpgradeResponse(w, req, []string{StreamProtocolV2Name, StreamProtocolV1Name}, func(stream httpstream.Stream) error { + conn := upgrader.UpgradeResponse(w, req, func(stream httpstream.Stream) error { streamCh <- stream return nil }) @@ -57,7 +64,6 @@ func fakeExecServer(t *testing.T, i int, stdinData, stdoutData, stderrData, erro return } defer conn.Close() - _ = protocol var errorStream, stdinStream, stdoutStream, stderrStream httpstream.Stream receivedStreams := 0 @@ -185,6 +191,7 @@ func TestRequestExecuteRemoteCommand(t *testing.T) { url, _ := url.ParseRequestURI(server.URL) c := client.NewRESTClient(url, "x", nil, -1, -1) req := c.Post().Resource("testing") + req.SetHeader(httpstream.HeaderProtocolVersion, StreamProtocolV2Name) req.Param("command", "ls") req.Param("command", "/") conf := &client.Config{ @@ -364,7 +371,7 @@ func TestDial(t *testing.T) { if err != nil { t.Fatal(err) } - conn, protocol, err := exec.Dial([]string{"a", "b"}) + conn, protocol, err := exec.Dial("protocol1") if err != nil { t.Fatal(err) } diff --git a/pkg/client/unversioned/remotecommand/v1.go b/pkg/client/unversioned/remotecommand/v1.go index 1a64ed048cc..b10e5e1f1e7 100644 --- a/pkg/client/unversioned/remotecommand/v1.go +++ b/pkg/client/unversioned/remotecommand/v1.go @@ -27,6 +27,10 @@ import ( "k8s.io/kubernetes/pkg/util/httpstream" ) +// streamProtocolV1 implements the first version of the streaming exec & attach +// protocol. This version has some bugs, such as not being able to detecte when +// non-interactive stdin data has ended. See http://issues.k8s.io/13394 and +// http://issues.k8s.io/13395 for more details. type streamProtocolV1 struct { stdin io.Reader stdout io.Writer @@ -41,8 +45,8 @@ func (e *streamProtocolV1) stream(conn httpstream.Connection) error { errorChan := make(chan error) cp := func(s string, dst io.Writer, src io.Reader) { - glog.V(4).Infof("Copying %s", s) - defer glog.V(4).Infof("Done copying %s", s) + glog.V(6).Infof("Copying %s", s) + defer glog.V(6).Infof("Done copying %s", s) if _, err := io.Copy(dst, src); err != nil && err != io.EOF { glog.Errorf("Error copying %s: %v", s, err) } diff --git a/pkg/kubelet/server.go b/pkg/kubelet/server.go index 04969592f65..b0dc60fdc82 100644 --- a/pkg/kubelet/server.go +++ b/pkg/kubelet/server.go @@ -44,7 +44,6 @@ import ( "k8s.io/kubernetes/pkg/api/validation" "k8s.io/kubernetes/pkg/auth/authenticator" "k8s.io/kubernetes/pkg/auth/authorizer" - "k8s.io/kubernetes/pkg/client/unversioned/portforward" "k8s.io/kubernetes/pkg/client/unversioned/remotecommand" "k8s.io/kubernetes/pkg/healthz" "k8s.io/kubernetes/pkg/httplog" @@ -687,10 +686,17 @@ func (s *Server) createStreams(request *restful.Request, response *restful.Respo return streams[0], streams[1], streams[2], streams[3], conn, tty, true } + supportedStreamProtocols := []string{remotecommand.StreamProtocolV2Name, remotecommand.StreamProtocolV1Name} + _, err := httpstream.Handshake(request.Request, response.ResponseWriter, supportedStreamProtocols, remotecommand.StreamProtocolV1Name) + // negotiated protocol isn't used server side at the moment, but could be in the future + if err != nil { + return nil, nil, nil, nil, nil, false, false + } + streamCh := make(chan httpstream.Stream) upgrader := spdy.NewResponseUpgrader() - conn, protocol := upgrader.UpgradeResponse(response.ResponseWriter, request.Request, []string{remotecommand.StreamProtocolV2Name, remotecommand.StreamProtocolV1Name}, func(stream httpstream.Stream) error { + conn := upgrader.UpgradeResponse(response.ResponseWriter, request.Request, func(stream httpstream.Stream) error { streamCh <- stream return nil }) @@ -701,9 +707,6 @@ func (s *Server) createStreams(request *restful.Request, response *restful.Respo // if we weren't successful in upgrading. return nil, nil, nil, nil, nil, false, false } - if len(protocol) == 0 { - protocol = remotecommand.StreamProtocolV1Name - } conn.SetIdleTimeout(s.host.StreamingConnectionIdleTimeout()) @@ -778,24 +781,34 @@ func (s *Server) getPortForward(request *restful.Request, response *restful.Resp ServePortForward(response.ResponseWriter, request.Request, s.host, podName, uid, s.host.StreamingConnectionIdleTimeout(), defaultStreamCreationTimeout) } +// The subprotocol "portforward.k8s.io" is used for port forwarding. +const PortForwardProtocolV1Name = "portforward.k8s.io" + // ServePortForward handles a port forwarding request. A single request is // kept alive as long as the client is still alive and the connection has not // been timed out due to idleness. This function handles multiple forwarded // connections; i.e., multiple `curl http://localhost:8888/` requests will be // handled by a single invocation of ServePortForward. func ServePortForward(w http.ResponseWriter, req *http.Request, portForwarder PortForwarder, podName string, uid types.UID, idleTimeout time.Duration, streamCreationTimeout time.Duration) { + supportedPortForwardProtocols := []string{PortForwardProtocolV1Name} + _, err := httpstream.Handshake(req, w, supportedPortForwardProtocols, PortForwardProtocolV1Name) + // negotiated protocol isn't currently used server side, but could be in the future + if err != nil { + // Handshake writes the error to the client + util.HandleError(err) + return + } + streamChan := make(chan httpstream.Stream, 1) glog.V(5).Infof("Upgrading port forward response") upgrader := spdy.NewResponseUpgrader() - conn, protocol := upgrader.UpgradeResponse(w, req, []string{portforward.PortForwardProtocolV1Name}, portForwardStreamReceived(streamChan)) + conn := upgrader.UpgradeResponse(w, req, portForwardStreamReceived(streamChan)) if conn == nil { return } defer conn.Close() - _ = protocol - glog.V(5).Infof("(conn=%p) setting port forwarding streaming connection idle timeout to %v", conn, idleTimeout) conn.SetIdleTimeout(idleTimeout) diff --git a/pkg/util/httpstream/httpstream.go b/pkg/util/httpstream/httpstream.go index 9f119b7b7ce..80c3cd78fc0 100644 --- a/pkg/util/httpstream/httpstream.go +++ b/pkg/util/httpstream/httpstream.go @@ -17,6 +17,7 @@ limitations under the License. package httpstream import ( + "fmt" "io" "net/http" "strings" @@ -24,9 +25,10 @@ import ( ) const ( - HeaderConnection = "Connection" - HeaderUpgrade = "Upgrade" - HeaderProtocolVersion = "X-Stream-Protocol-Version" + HeaderConnection = "Connection" + HeaderUpgrade = "Upgrade" + HeaderProtocolVersion = "X-Stream-Protocol-Version" + HeaderAcceptedProtocolVersions = "X-Accepted-Stream-Protocol-Versions" ) // NewStreamHandler defines a function that is called when a new Stream is @@ -43,7 +45,7 @@ type Dialer interface { // Dial opens a streaming connection to a server using one of the protocols // specified (in order of most preferred to least preferred). - Dial(protocols []string) (Connection, string, error) + Dial(protocols ...string) (Connection, string, error) } // UpgradeRoundTripper is a type of http.RoundTripper that is able to upgrade @@ -60,9 +62,9 @@ type UpgradeRoundTripper interface { // add streaming support to them. type ResponseUpgrader interface { // UpgradeResponse upgrades an HTTP response to one that supports multiplexed - // streams. newStreamHandler will be called synchronously whenever the + // streams. newStreamHandler will be called asynchronously whenever the // other end of the upgraded connection creates a new stream. - UpgradeResponse(w http.ResponseWriter, req *http.Request, protocols []string, newStreamHandler NewStreamHandler) (Connection, string) + UpgradeResponse(w http.ResponseWriter, req *http.Request, newStreamHandler NewStreamHandler) Connection } // Connection represents an upgraded HTTP connection. @@ -100,3 +102,44 @@ func IsUpgradeRequest(req *http.Request) bool { } return false } + +func negotiateProtocol(clientProtocols, serverProtocols []string) string { + for i := range clientProtocols { + for j := range serverProtocols { + if clientProtocols[i] == serverProtocols[j] { + return clientProtocols[i] + } + } + } + return "" +} + +// Handshake performs a subprotocol negotiation. If the client did not request +// a specific subprotocol, defaultProtocol is used. If the client did request a +// subprotocol, Handshake will select the first common value found in +// serverProtocols. If a match is found, Handshake adds a response header +// indicating the chosen subprotocol. If no match is found, HTTP forbidden is +// returned, along with a response header containing the list of protocols the +// server can accept. +func Handshake(req *http.Request, w http.ResponseWriter, serverProtocols []string, defaultProtocol string) (string, error) { + clientProtocols := req.Header[http.CanonicalHeaderKey(HeaderProtocolVersion)] + if len(clientProtocols) == 0 { + // Kube 1.0 client that didn't support subprotocol negotiation + // TODO remove this defaulting logic once Kube 1.0 is no longer supported + w.Header().Add(HeaderProtocolVersion, defaultProtocol) + return defaultProtocol, nil + } + + negotiatedProtocol := negotiateProtocol(clientProtocols, serverProtocols) + if len(negotiatedProtocol) == 0 { + w.WriteHeader(http.StatusForbidden) + for i := range serverProtocols { + w.Header().Add(HeaderAcceptedProtocolVersions, serverProtocols[i]) + } + fmt.Fprintf(w, "unable to upgrade: unable to negotiate protocol: client supports %v, server accepts %v", clientProtocols, serverProtocols) + return "", fmt.Errorf("unable to upgrade: unable to negotiate protocol: client supports %v, server supports %v", clientProtocols, serverProtocols) + } + + w.Header().Add(HeaderProtocolVersion, negotiatedProtocol) + return negotiatedProtocol, nil +} diff --git a/pkg/util/httpstream/spdy/roundtripper_test.go b/pkg/util/httpstream/spdy/roundtripper_test.go index f16a6e697f3..babd23c9011 100644 --- a/pkg/util/httpstream/spdy/roundtripper_test.go +++ b/pkg/util/httpstream/spdy/roundtripper_test.go @@ -120,7 +120,7 @@ func TestRoundTripAndNewConnection(t *testing.T) { streamCh := make(chan httpstream.Stream) responseUpgrader := NewResponseUpgrader() - spdyConn, _ := responseUpgrader.UpgradeResponse(w, req, []string{"protocol1"}, func(s httpstream.Stream) error { + spdyConn := responseUpgrader.UpgradeResponse(w, req, func(s httpstream.Stream) error { streamCh <- s return nil }) diff --git a/pkg/util/httpstream/spdy/upgrade.go b/pkg/util/httpstream/spdy/upgrade.go index c3e79aa3d17..4fd2a40521a 100644 --- a/pkg/util/httpstream/spdy/upgrade.go +++ b/pkg/util/httpstream/spdy/upgrade.go @@ -21,7 +21,7 @@ import ( "net/http" "strings" - "github.com/golang/glog" + "k8s.io/kubernetes/pkg/util" "k8s.io/kubernetes/pkg/util/httpstream" ) @@ -39,47 +39,23 @@ func NewResponseUpgrader() httpstream.ResponseUpgrader { return responseUpgrader{} } -func negotiateProtocol(clientProtocols, serverProtocols []string) string { - for i := range clientProtocols { - for j := range serverProtocols { - if clientProtocols[i] == serverProtocols[j] { - return clientProtocols[i] - } - } - } - return "" -} - // UpgradeResponse upgrades an HTTP response to one that supports multiplexed // streams. newStreamHandler will be called synchronously whenever the // other end of the upgraded connection creates a new stream. -func (u responseUpgrader) UpgradeResponse(w http.ResponseWriter, req *http.Request, protocols []string, newStreamHandler httpstream.NewStreamHandler) (httpstream.Connection, string) { +func (u responseUpgrader) UpgradeResponse(w http.ResponseWriter, req *http.Request, newStreamHandler httpstream.NewStreamHandler) httpstream.Connection { connectionHeader := strings.ToLower(req.Header.Get(httpstream.HeaderConnection)) upgradeHeader := strings.ToLower(req.Header.Get(httpstream.HeaderUpgrade)) if !strings.Contains(connectionHeader, strings.ToLower(httpstream.HeaderUpgrade)) || !strings.Contains(upgradeHeader, strings.ToLower(HeaderSpdy31)) { w.WriteHeader(http.StatusBadRequest) fmt.Fprintf(w, "unable to upgrade: missing upgrade headers in request: %#v", req.Header) - return nil, "" + return nil } hijacker, ok := w.(http.Hijacker) if !ok { w.WriteHeader(http.StatusInternalServerError) fmt.Fprintf(w, "unable to upgrade: unable to hijack response") - return nil, "" - } - - var negotiatedProtocol string - clientProtocols := req.Header[http.CanonicalHeaderKey(httpstream.HeaderProtocolVersion)] - if len(clientProtocols) > 0 { - negotiatedProtocol = negotiateProtocol(req.Header[http.CanonicalHeaderKey(httpstream.HeaderProtocolVersion)], protocols) - if len(negotiatedProtocol) > 0 { - w.Header().Add(httpstream.HeaderProtocolVersion, negotiatedProtocol) - } else { - w.WriteHeader(http.StatusForbidden) - fmt.Fprintf(w, "unable to upgrade: unable to negotiate protocol: server accepts %v", protocols) - return nil, "" - } + return nil } w.Header().Add(httpstream.HeaderConnection, httpstream.HeaderUpgrade) @@ -88,15 +64,15 @@ func (u responseUpgrader) UpgradeResponse(w http.ResponseWriter, req *http.Reque conn, _, err := hijacker.Hijack() if err != nil { - glog.Errorf("unable to upgrade: error hijacking response: %v", err) - return nil, "" + util.HandleError(fmt.Errorf("unable to upgrade: error hijacking response: %v", err)) + return nil } spdyConn, err := NewServerConnection(conn, newStreamHandler) if err != nil { - glog.Errorf("unable to upgrade: error creating SPDY server connection: %v", err) - return nil, "" + util.HandleError(fmt.Errorf("unable to upgrade: error creating SPDY server connection: %v", err)) + return nil } - return spdyConn, negotiatedProtocol + return spdyConn } diff --git a/pkg/util/httpstream/spdy/upgrade_test.go b/pkg/util/httpstream/spdy/upgrade_test.go index e82f3515ee1..4e111407e87 100644 --- a/pkg/util/httpstream/spdy/upgrade_test.go +++ b/pkg/util/httpstream/spdy/upgrade_test.go @@ -53,8 +53,7 @@ func TestUpgradeResponse(t *testing.T) { for i, testCase := range testCases { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { upgrader := NewResponseUpgrader() - conn, protocol := upgrader.UpgradeResponse(w, req, []string{"protocol1"}, nil) - _ = protocol + conn := upgrader.UpgradeResponse(w, req, nil) haveErr := conn == nil if e, a := testCase.shouldError, haveErr; e != a { t.Fatalf("%d: expected shouldErr=%t, got %t", i, testCase.shouldError, haveErr)