From 3d1cafc2c38396f7f55e5927b694f402834cb5af Mon Sep 17 00:00:00 2001 From: Andy Goldstein Date: Tue, 20 Oct 2015 08:21:07 -0400 Subject: [PATCH 1/5] Add streaming subprotocol negotiation Add streaming subprotocol negotiation for exec, attach, and port forwarding. Restore previous (buggy) exec functionality as an unspecified/unversioned subprotocol so newer kubectl clients can work against 1.0.x kubelets. --- .../unversioned/portforward/portforward.go | 5 +- .../portforward/portforward_test.go | 11 +- .../remotecommand/remotecommand.go | 190 ++++++++++++++++-- .../remotecommand/remotecommand_test.go | 8 +- pkg/kubelet/server.go | 11 +- pkg/util/httpstream/httpstream.go | 12 +- pkg/util/httpstream/spdy/roundtripper_test.go | 2 +- pkg/util/httpstream/spdy/upgrade.go | 44 +++- pkg/util/httpstream/spdy/upgrade_test.go | 3 +- 9 files changed, 243 insertions(+), 43 deletions(-) diff --git a/pkg/client/unversioned/portforward/portforward.go b/pkg/client/unversioned/portforward/portforward.go index 36916bdcf49..a711c9f5883 100644 --- a/pkg/client/unversioned/portforward/portforward.go +++ b/pkg/client/unversioned/portforward/portforward.go @@ -122,13 +122,16 @@ func New(dialer httpstream.Dialer, ports []string, stopChan <-chan struct{}) (*P }, nil } +// The SPDY subprotocol "portforward.k8s.io" is used for port forwarding. +const PortForwardProtocolV1Name = "portforward.k8s.io" + // ForwardPorts formats and executes a port forwarding request. The connection will remain // open until stopChan is closed. func (pf *PortForwarder) ForwardPorts() error { defer pf.Close() var err error - pf.streamConn, err = pf.dialer.Dial() + pf.streamConn, _, err = pf.dialer.Dial([]string{PortForwardProtocolV1Name}) if err != nil { return fmt.Errorf("error upgrading connection: %s", err) } diff --git a/pkg/client/unversioned/portforward/portforward_test.go b/pkg/client/unversioned/portforward/portforward_test.go index f208a9e8814..7a21a9cdcd1 100644 --- a/pkg/client/unversioned/portforward/portforward_test.go +++ b/pkg/client/unversioned/portforward/portforward_test.go @@ -38,14 +38,15 @@ import ( ) type fakeDialer struct { - dialed bool - conn httpstream.Connection - err error + dialed bool + conn httpstream.Connection + err error + negotiatedProtocol string } -func (d *fakeDialer) Dial() (httpstream.Connection, error) { +func (d *fakeDialer) Dial(protocols []string) (httpstream.Connection, string, error) { d.dialed = true - return d.conn, d.err + return d.conn, d.negotiatedProtocol, d.err } func TestParsePortsAndNew(t *testing.T) { diff --git a/pkg/client/unversioned/remotecommand/remotecommand.go b/pkg/client/unversioned/remotecommand/remotecommand.go index 505e3fc9c1b..21091c85250 100644 --- a/pkg/client/unversioned/remotecommand/remotecommand.go +++ b/pkg/client/unversioned/remotecommand/remotecommand.go @@ -24,6 +24,8 @@ import ( "net/url" "sync" + "github.com/golang/glog" + "k8s.io/kubernetes/pkg/api" client "k8s.io/kubernetes/pkg/client/unversioned" "k8s.io/kubernetes/pkg/util" @@ -97,51 +99,207 @@ func NewStreamExecutor(upgrader httpstream.UpgradeRoundTripper, fn func(http.Rou }, nil } -// Dial opens a connection to a remote server and attempts to negotiate a SPDY connection. -func (e *streamExecutor) Dial() (httpstream.Connection, error) { - client := &http.Client{Transport: e.transport} +// Dial opens a connection to a remote server and attempts to negotiate a SPDY +// connection. Upon success, it returns the connection and the protocol +// selected by the server. +func (e *streamExecutor) Dial(protocols []string) (httpstream.Connection, string, error) { + transport := e.transport + // TODO consider removing this and reusing client.TransportFor above to get this for free + switch { + case bool(glog.V(9)): + transport = client.NewDebuggingRoundTripper(transport, client.CurlCommand, client.URLTiming, client.ResponseHeaders) + case bool(glog.V(8)): + transport = client.NewDebuggingRoundTripper(transport, client.JustURL, client.RequestHeaders, client.ResponseStatus, client.ResponseHeaders) + case bool(glog.V(7)): + transport = client.NewDebuggingRoundTripper(transport, client.JustURL, client.RequestHeaders, client.ResponseStatus) + case bool(glog.V(6)): + transport = client.NewDebuggingRoundTripper(transport, client.URLTiming) + } + client := &http.Client{Transport: transport} req, err := http.NewRequest(e.method, e.url.String(), nil) if err != nil { - return nil, fmt.Errorf("error creating request: %s", err) + return nil, "", fmt.Errorf("error creating request: %v", err) + } + for i := range protocols { + req.Header.Add(httpstream.HeaderProtocolVersion, protocols[i]) } resp, err := client.Do(req) if err != nil { - return nil, fmt.Errorf("error sending request: %s", err) + return nil, "", fmt.Errorf("error sending request: %v", err) } defer resp.Body.Close() - // TODO: handle protocol selection in the future - return e.upgrader.NewConnection(resp) + if resp.StatusCode != http.StatusSwitchingProtocols { + return nil, "", fmt.Errorf("unexpected response status code %d (%s)", resp.StatusCode, http.StatusText(resp.StatusCode)) + } + + conn, err := e.upgrader.NewConnection(resp) + if err != nil { + return nil, "", err + } + + return conn, resp.Header.Get(httpstream.HeaderProtocolVersion), nil +} + +const ( + // The SPDY subprotocol "channel.k8s.io" is used for remote command + // attachment/execution. This represents the initial unversioned subprotocol, + // which has the known bugs http://issues.k8s.io/13394 and + // http://issues.k8s.io/13395. + StreamProtocolV1Name = "channel.k8s.io" + // The SPDY subprotocol "v2.channel.k8s.io" is used for remote command + // attachment/execution. It is the second version of the subprotocol and + // resolves the issues present in the first version. + StreamProtocolV2Name = "v2.channel.k8s.io" +) + +type streamProtocolHandler interface { + stream(httpstream.Connection) error } // Stream opens a protocol streamer to the server and streams until a client closes // the connection or the server disconnects. func (e *streamExecutor) Stream(stdin io.Reader, stdout, stderr io.Writer, tty bool) error { - conn, err := e.Dial() + conn, protocol, err := e.Dial([]string{StreamProtocolV2Name, StreamProtocolV1Name}) if err != nil { return err } defer conn.Close() - // TODO: negotiate protocols - streamer := &streamProtocol{ - stdin: stdin, - stdout: stdout, - stderr: stderr, - tty: tty, + + var streamer streamProtocolHandler + + switch protocol { + case StreamProtocolV2Name: + streamer = &streamProtocolV2{ + stdin: stdin, + stdout: stdout, + stderr: stderr, + tty: tty, + } + case "": + glog.Warning("The server did not negotiate a streaming protocol version. Falling back to unversioned") + // TODO restore v1 + streamer = &streamProtocolV1{ + stdin: stdin, + stdout: stdout, + stderr: stderr, + tty: tty, + } } + return streamer.stream(conn) } -type streamProtocol struct { +type streamProtocolV1 struct { stdin io.Reader stdout io.Writer stderr io.Writer tty bool } -func (e *streamProtocol) stream(conn httpstream.Connection) error { +func (e *streamProtocolV1) stream(conn httpstream.Connection) error { + doneChan := make(chan struct{}, 2) + errorChan := make(chan error) + + cp := func(s string, dst io.Writer, src io.Reader) { + glog.V(4).Infof("Copying %s", s) + defer glog.V(4).Infof("Done copying %s", s) + if _, err := io.Copy(dst, src); err != nil && err != io.EOF { + glog.Errorf("Error copying %s: %v", s, err) + } + if s == api.StreamTypeStdout || s == api.StreamTypeStderr { + doneChan <- struct{}{} + } + } + + headers := http.Header{} + headers.Set(api.StreamType, api.StreamTypeError) + errorStream, err := conn.CreateStream(headers) + if err != nil { + return err + } + go func() { + message, err := ioutil.ReadAll(errorStream) + if err != nil && err != io.EOF { + errorChan <- fmt.Errorf("Error reading from error stream: %s", err) + return + } + if len(message) > 0 { + errorChan <- fmt.Errorf("Error executing remote command: %s", message) + return + } + }() + defer errorStream.Reset() + + if e.stdin != nil { + headers.Set(api.StreamType, api.StreamTypeStdin) + remoteStdin, err := conn.CreateStream(headers) + if err != nil { + return err + } + defer remoteStdin.Reset() + // TODO this goroutine will never exit cleanly (the io.Copy never unblocks) + // because stdin is not closed until the process exits. If we try to call + // stdin.Close(), it returns no error but doesn't unblock the copy. It will + // exit when the process exits, instead. + go cp(api.StreamTypeStdin, remoteStdin, e.stdin) + } + + waitCount := 0 + completedStreams := 0 + + if e.stdout != nil { + waitCount++ + headers.Set(api.StreamType, api.StreamTypeStdout) + remoteStdout, err := conn.CreateStream(headers) + if err != nil { + return err + } + defer remoteStdout.Reset() + go cp(api.StreamTypeStdout, e.stdout, remoteStdout) + } + + if e.stderr != nil && !e.tty { + waitCount++ + headers.Set(api.StreamType, api.StreamTypeStderr) + remoteStderr, err := conn.CreateStream(headers) + if err != nil { + return err + } + defer remoteStderr.Reset() + go cp(api.StreamTypeStderr, e.stderr, remoteStderr) + } + +Loop: + for { + select { + case <-doneChan: + completedStreams++ + if completedStreams == waitCount { + break Loop + } + case err := <-errorChan: + return err + } + } + + return nil +} + +// streamProtocolV2 implements version 2 of the streaming protocol for attach +// and exec. The original streaming protocol was unversioned. As a result, this +// version is referred to as version 2, even though it is the first actual +// numbered version. +type streamProtocolV2 struct { + stdin io.Reader + stdout io.Writer + stderr io.Writer + tty bool +} + +func (e *streamProtocolV2) stream(conn httpstream.Connection) error { headers := http.Header{} // set up error stream diff --git a/pkg/client/unversioned/remotecommand/remotecommand_test.go b/pkg/client/unversioned/remotecommand/remotecommand_test.go index 56cb3372bfb..5403dacbacf 100644 --- a/pkg/client/unversioned/remotecommand/remotecommand_test.go +++ b/pkg/client/unversioned/remotecommand/remotecommand_test.go @@ -45,7 +45,7 @@ func fakeExecServer(t *testing.T, i int, stdinData, stdoutData, stderrData, erro streamCh := make(chan httpstream.Stream) upgrader := spdy.NewResponseUpgrader() - conn := upgrader.UpgradeResponse(w, req, func(stream httpstream.Stream) error { + conn, protocol := upgrader.UpgradeResponse(w, req, []string{StreamProtocolV2Name, StreamProtocolV1Name}, func(stream httpstream.Stream) error { streamCh <- stream return nil }) @@ -57,6 +57,7 @@ func fakeExecServer(t *testing.T, i int, stdinData, stdoutData, stderrData, erro return } defer conn.Close() + _ = protocol var errorStream, stdinStream, stdoutStream, stderrStream httpstream.Stream receivedStreams := 0 @@ -347,7 +348,7 @@ func TestDial(t *testing.T) { checkResponse: true, conn: &fakeConnection{}, resp: &http.Response{ - StatusCode: http.StatusOK, + StatusCode: http.StatusSwitchingProtocols, Body: ioutil.NopCloser(&bytes.Buffer{}), }, } @@ -363,7 +364,7 @@ func TestDial(t *testing.T) { if err != nil { t.Fatal(err) } - conn, err := exec.Dial() + conn, protocol, err := exec.Dial([]string{"a", "b"}) if err != nil { t.Fatal(err) } @@ -373,4 +374,5 @@ func TestDial(t *testing.T) { if !called { t.Errorf("wrapper not called") } + _ = protocol } diff --git a/pkg/kubelet/server.go b/pkg/kubelet/server.go index c8b2df9d613..04969592f65 100644 --- a/pkg/kubelet/server.go +++ b/pkg/kubelet/server.go @@ -44,6 +44,8 @@ import ( "k8s.io/kubernetes/pkg/api/validation" "k8s.io/kubernetes/pkg/auth/authenticator" "k8s.io/kubernetes/pkg/auth/authorizer" + "k8s.io/kubernetes/pkg/client/unversioned/portforward" + "k8s.io/kubernetes/pkg/client/unversioned/remotecommand" "k8s.io/kubernetes/pkg/healthz" "k8s.io/kubernetes/pkg/httplog" kubecontainer "k8s.io/kubernetes/pkg/kubelet/container" @@ -688,7 +690,7 @@ func (s *Server) createStreams(request *restful.Request, response *restful.Respo streamCh := make(chan httpstream.Stream) upgrader := spdy.NewResponseUpgrader() - conn := upgrader.UpgradeResponse(response.ResponseWriter, request.Request, func(stream httpstream.Stream) error { + conn, protocol := upgrader.UpgradeResponse(response.ResponseWriter, request.Request, []string{remotecommand.StreamProtocolV2Name, remotecommand.StreamProtocolV1Name}, func(stream httpstream.Stream) error { streamCh <- stream return nil }) @@ -699,6 +701,9 @@ func (s *Server) createStreams(request *restful.Request, response *restful.Respo // if we weren't successful in upgrading. return nil, nil, nil, nil, nil, false, false } + if len(protocol) == 0 { + protocol = remotecommand.StreamProtocolV1Name + } conn.SetIdleTimeout(s.host.StreamingConnectionIdleTimeout()) @@ -783,12 +788,14 @@ func ServePortForward(w http.ResponseWriter, req *http.Request, portForwarder Po glog.V(5).Infof("Upgrading port forward response") upgrader := spdy.NewResponseUpgrader() - conn := upgrader.UpgradeResponse(w, req, portForwardStreamReceived(streamChan)) + conn, protocol := upgrader.UpgradeResponse(w, req, []string{portforward.PortForwardProtocolV1Name}, portForwardStreamReceived(streamChan)) if conn == nil { return } defer conn.Close() + _ = protocol + glog.V(5).Infof("(conn=%p) setting port forwarding streaming connection idle timeout to %v", conn, idleTimeout) conn.SetIdleTimeout(idleTimeout) diff --git a/pkg/util/httpstream/httpstream.go b/pkg/util/httpstream/httpstream.go index fff6840f279..9f119b7b7ce 100644 --- a/pkg/util/httpstream/httpstream.go +++ b/pkg/util/httpstream/httpstream.go @@ -24,8 +24,9 @@ import ( ) const ( - HeaderConnection = "Connection" - HeaderUpgrade = "Upgrade" + HeaderConnection = "Connection" + HeaderUpgrade = "Upgrade" + HeaderProtocolVersion = "X-Stream-Protocol-Version" ) // NewStreamHandler defines a function that is called when a new Stream is @@ -39,7 +40,10 @@ func NoOpNewStreamHandler(stream Stream) error { return nil } // Dialer knows how to open a streaming connection to a server. type Dialer interface { - Dial() (Connection, error) + + // Dial opens a streaming connection to a server using one of the protocols + // specified (in order of most preferred to least preferred). + Dial(protocols []string) (Connection, string, error) } // UpgradeRoundTripper is a type of http.RoundTripper that is able to upgrade @@ -58,7 +62,7 @@ type ResponseUpgrader interface { // UpgradeResponse upgrades an HTTP response to one that supports multiplexed // streams. newStreamHandler will be called synchronously whenever the // other end of the upgraded connection creates a new stream. - UpgradeResponse(w http.ResponseWriter, req *http.Request, newStreamHandler NewStreamHandler) Connection + UpgradeResponse(w http.ResponseWriter, req *http.Request, protocols []string, newStreamHandler NewStreamHandler) (Connection, string) } // Connection represents an upgraded HTTP connection. diff --git a/pkg/util/httpstream/spdy/roundtripper_test.go b/pkg/util/httpstream/spdy/roundtripper_test.go index babd23c9011..f16a6e697f3 100644 --- a/pkg/util/httpstream/spdy/roundtripper_test.go +++ b/pkg/util/httpstream/spdy/roundtripper_test.go @@ -120,7 +120,7 @@ func TestRoundTripAndNewConnection(t *testing.T) { streamCh := make(chan httpstream.Stream) responseUpgrader := NewResponseUpgrader() - spdyConn := responseUpgrader.UpgradeResponse(w, req, func(s httpstream.Stream) error { + spdyConn, _ := responseUpgrader.UpgradeResponse(w, req, []string{"protocol1"}, func(s httpstream.Stream) error { streamCh <- s return nil }) diff --git a/pkg/util/httpstream/spdy/upgrade.go b/pkg/util/httpstream/spdy/upgrade.go index 70c90c8f499..c3e79aa3d17 100644 --- a/pkg/util/httpstream/spdy/upgrade.go +++ b/pkg/util/httpstream/spdy/upgrade.go @@ -39,23 +39,47 @@ func NewResponseUpgrader() httpstream.ResponseUpgrader { return responseUpgrader{} } +func negotiateProtocol(clientProtocols, serverProtocols []string) string { + for i := range clientProtocols { + for j := range serverProtocols { + if clientProtocols[i] == serverProtocols[j] { + return clientProtocols[i] + } + } + } + return "" +} + // UpgradeResponse upgrades an HTTP response to one that supports multiplexed // streams. newStreamHandler will be called synchronously whenever the // other end of the upgraded connection creates a new stream. -func (u responseUpgrader) UpgradeResponse(w http.ResponseWriter, req *http.Request, newStreamHandler httpstream.NewStreamHandler) httpstream.Connection { +func (u responseUpgrader) UpgradeResponse(w http.ResponseWriter, req *http.Request, protocols []string, newStreamHandler httpstream.NewStreamHandler) (httpstream.Connection, string) { connectionHeader := strings.ToLower(req.Header.Get(httpstream.HeaderConnection)) upgradeHeader := strings.ToLower(req.Header.Get(httpstream.HeaderUpgrade)) if !strings.Contains(connectionHeader, strings.ToLower(httpstream.HeaderUpgrade)) || !strings.Contains(upgradeHeader, strings.ToLower(HeaderSpdy31)) { - w.Write([]byte(fmt.Sprintf("Unable to upgrade: missing upgrade headers in request: %#v", req.Header))) w.WriteHeader(http.StatusBadRequest) - return nil + fmt.Fprintf(w, "unable to upgrade: missing upgrade headers in request: %#v", req.Header) + return nil, "" } hijacker, ok := w.(http.Hijacker) if !ok { - w.Write([]byte("Unable to upgrade: unable to hijack response")) w.WriteHeader(http.StatusInternalServerError) - return nil + fmt.Fprintf(w, "unable to upgrade: unable to hijack response") + return nil, "" + } + + var negotiatedProtocol string + clientProtocols := req.Header[http.CanonicalHeaderKey(httpstream.HeaderProtocolVersion)] + if len(clientProtocols) > 0 { + negotiatedProtocol = negotiateProtocol(req.Header[http.CanonicalHeaderKey(httpstream.HeaderProtocolVersion)], protocols) + if len(negotiatedProtocol) > 0 { + w.Header().Add(httpstream.HeaderProtocolVersion, negotiatedProtocol) + } else { + w.WriteHeader(http.StatusForbidden) + fmt.Fprintf(w, "unable to upgrade: unable to negotiate protocol: server accepts %v", protocols) + return nil, "" + } } w.Header().Add(httpstream.HeaderConnection, httpstream.HeaderUpgrade) @@ -64,15 +88,15 @@ func (u responseUpgrader) UpgradeResponse(w http.ResponseWriter, req *http.Reque conn, _, err := hijacker.Hijack() if err != nil { - glog.Errorf("Unable to upgrade: error hijacking response: %v", err) - return nil + glog.Errorf("unable to upgrade: error hijacking response: %v", err) + return nil, "" } spdyConn, err := NewServerConnection(conn, newStreamHandler) if err != nil { - glog.Errorf("Unable to upgrade: error creating SPDY server connection: %v", err) - return nil + glog.Errorf("unable to upgrade: error creating SPDY server connection: %v", err) + return nil, "" } - return spdyConn + return spdyConn, negotiatedProtocol } diff --git a/pkg/util/httpstream/spdy/upgrade_test.go b/pkg/util/httpstream/spdy/upgrade_test.go index 4e111407e87..e82f3515ee1 100644 --- a/pkg/util/httpstream/spdy/upgrade_test.go +++ b/pkg/util/httpstream/spdy/upgrade_test.go @@ -53,7 +53,8 @@ func TestUpgradeResponse(t *testing.T) { for i, testCase := range testCases { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { upgrader := NewResponseUpgrader() - conn := upgrader.UpgradeResponse(w, req, nil) + conn, protocol := upgrader.UpgradeResponse(w, req, []string{"protocol1"}, nil) + _ = protocol haveErr := conn == nil if e, a := testCase.shouldError, haveErr; e != a { t.Fatalf("%d: expected shouldErr=%t, got %t", i, testCase.shouldError, haveErr) From 6c7b5196197b5c834be900255c4a46646c79344b Mon Sep 17 00:00:00 2001 From: Andy Goldstein Date: Tue, 20 Oct 2015 14:38:06 -0400 Subject: [PATCH 2/5] Move protocol handlers to separate files --- .../remotecommand/remotecommand.go | 220 ------------------ pkg/client/unversioned/remotecommand/v1.go | 126 ++++++++++ pkg/client/unversioned/remotecommand/v2.go | 151 ++++++++++++ 3 files changed, 277 insertions(+), 220 deletions(-) create mode 100644 pkg/client/unversioned/remotecommand/v1.go create mode 100644 pkg/client/unversioned/remotecommand/v2.go diff --git a/pkg/client/unversioned/remotecommand/remotecommand.go b/pkg/client/unversioned/remotecommand/remotecommand.go index 21091c85250..69f150b8630 100644 --- a/pkg/client/unversioned/remotecommand/remotecommand.go +++ b/pkg/client/unversioned/remotecommand/remotecommand.go @@ -19,16 +19,12 @@ package remotecommand import ( "fmt" "io" - "io/ioutil" "net/http" "net/url" - "sync" "github.com/golang/glog" - "k8s.io/kubernetes/pkg/api" client "k8s.io/kubernetes/pkg/client/unversioned" - "k8s.io/kubernetes/pkg/util" "k8s.io/kubernetes/pkg/util/httpstream" "k8s.io/kubernetes/pkg/util/httpstream/spdy" ) @@ -191,219 +187,3 @@ func (e *streamExecutor) Stream(stdin io.Reader, stdout, stderr io.Writer, tty b return streamer.stream(conn) } - -type streamProtocolV1 struct { - stdin io.Reader - stdout io.Writer - stderr io.Writer - tty bool -} - -func (e *streamProtocolV1) stream(conn httpstream.Connection) error { - doneChan := make(chan struct{}, 2) - errorChan := make(chan error) - - cp := func(s string, dst io.Writer, src io.Reader) { - glog.V(4).Infof("Copying %s", s) - defer glog.V(4).Infof("Done copying %s", s) - if _, err := io.Copy(dst, src); err != nil && err != io.EOF { - glog.Errorf("Error copying %s: %v", s, err) - } - if s == api.StreamTypeStdout || s == api.StreamTypeStderr { - doneChan <- struct{}{} - } - } - - headers := http.Header{} - headers.Set(api.StreamType, api.StreamTypeError) - errorStream, err := conn.CreateStream(headers) - if err != nil { - return err - } - go func() { - message, err := ioutil.ReadAll(errorStream) - if err != nil && err != io.EOF { - errorChan <- fmt.Errorf("Error reading from error stream: %s", err) - return - } - if len(message) > 0 { - errorChan <- fmt.Errorf("Error executing remote command: %s", message) - return - } - }() - defer errorStream.Reset() - - if e.stdin != nil { - headers.Set(api.StreamType, api.StreamTypeStdin) - remoteStdin, err := conn.CreateStream(headers) - if err != nil { - return err - } - defer remoteStdin.Reset() - // TODO this goroutine will never exit cleanly (the io.Copy never unblocks) - // because stdin is not closed until the process exits. If we try to call - // stdin.Close(), it returns no error but doesn't unblock the copy. It will - // exit when the process exits, instead. - go cp(api.StreamTypeStdin, remoteStdin, e.stdin) - } - - waitCount := 0 - completedStreams := 0 - - if e.stdout != nil { - waitCount++ - headers.Set(api.StreamType, api.StreamTypeStdout) - remoteStdout, err := conn.CreateStream(headers) - if err != nil { - return err - } - defer remoteStdout.Reset() - go cp(api.StreamTypeStdout, e.stdout, remoteStdout) - } - - if e.stderr != nil && !e.tty { - waitCount++ - headers.Set(api.StreamType, api.StreamTypeStderr) - remoteStderr, err := conn.CreateStream(headers) - if err != nil { - return err - } - defer remoteStderr.Reset() - go cp(api.StreamTypeStderr, e.stderr, remoteStderr) - } - -Loop: - for { - select { - case <-doneChan: - completedStreams++ - if completedStreams == waitCount { - break Loop - } - case err := <-errorChan: - return err - } - } - - return nil -} - -// streamProtocolV2 implements version 2 of the streaming protocol for attach -// and exec. The original streaming protocol was unversioned. As a result, this -// version is referred to as version 2, even though it is the first actual -// numbered version. -type streamProtocolV2 struct { - stdin io.Reader - stdout io.Writer - stderr io.Writer - tty bool -} - -func (e *streamProtocolV2) stream(conn httpstream.Connection) error { - headers := http.Header{} - - // set up error stream - errorChan := make(chan error) - headers.Set(api.StreamType, api.StreamTypeError) - errorStream, err := conn.CreateStream(headers) - if err != nil { - return err - } - - go func() { - message, err := ioutil.ReadAll(errorStream) - switch { - case err != nil && err != io.EOF: - errorChan <- fmt.Errorf("error reading from error stream: %s", err) - case len(message) > 0: - errorChan <- fmt.Errorf("error executing remote command: %s", message) - default: - errorChan <- nil - } - close(errorChan) - }() - - var wg sync.WaitGroup - var once sync.Once - - // set up stdin stream - if e.stdin != nil { - headers.Set(api.StreamType, api.StreamTypeStdin) - remoteStdin, err := conn.CreateStream(headers) - if err != nil { - return err - } - - // copy from client's stdin to container's stdin - go func() { - // if e.stdin is noninteractive, e.g. `echo abc | kubectl exec -i -- cat`, make sure - // we close remoteStdin as soon as the copy from e.stdin to remoteStdin finishes. Otherwise - // the executed command will remain running. - defer once.Do(func() { remoteStdin.Close() }) - - if _, err := io.Copy(remoteStdin, e.stdin); err != nil { - util.HandleError(err) - } - }() - - // read from remoteStdin until the stream is closed. this is essential to - // be able to exit interactive sessions cleanly and not leak goroutines or - // hang the client's terminal. - // - // go-dockerclient's current hijack implementation - // (https://github.com/fsouza/go-dockerclient/blob/89f3d56d93788dfe85f864a44f85d9738fca0670/client.go#L564) - // waits for all three streams (stdin/stdout/stderr) to finish copying - // before returning. When hijack finishes copying stdout/stderr, it calls - // Close() on its side of remoteStdin, which allows this copy to complete. - // When that happens, we must Close() on our side of remoteStdin, to - // allow the copy in hijack to complete, and hijack to return. - go func() { - defer once.Do(func() { remoteStdin.Close() }) - // this "copy" doesn't actually read anything - it's just here to wait for - // the server to close remoteStdin. - if _, err := io.Copy(ioutil.Discard, remoteStdin); err != nil { - util.HandleError(err) - } - }() - } - - // set up stdout stream - if e.stdout != nil { - headers.Set(api.StreamType, api.StreamTypeStdout) - remoteStdout, err := conn.CreateStream(headers) - if err != nil { - return err - } - - wg.Add(1) - go func() { - defer wg.Done() - if _, err := io.Copy(e.stdout, remoteStdout); err != nil { - util.HandleError(err) - } - }() - } - - // set up stderr stream - if e.stderr != nil && !e.tty { - headers.Set(api.StreamType, api.StreamTypeStderr) - remoteStderr, err := conn.CreateStream(headers) - if err != nil { - return err - } - - wg.Add(1) - go func() { - defer wg.Done() - if _, err := io.Copy(e.stderr, remoteStderr); err != nil { - util.HandleError(err) - } - }() - } - - // we're waiting for stdout/stderr to finish copying - wg.Wait() - - // waits for errorStream to finish reading with an error or nil - return <-errorChan -} diff --git a/pkg/client/unversioned/remotecommand/v1.go b/pkg/client/unversioned/remotecommand/v1.go new file mode 100644 index 00000000000..1a64ed048cc --- /dev/null +++ b/pkg/client/unversioned/remotecommand/v1.go @@ -0,0 +1,126 @@ +/* +Copyright 2015 The Kubernetes Authors All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package remotecommand + +import ( + "fmt" + "io" + "io/ioutil" + "net/http" + + "github.com/golang/glog" + "k8s.io/kubernetes/pkg/api" + "k8s.io/kubernetes/pkg/util/httpstream" +) + +type streamProtocolV1 struct { + stdin io.Reader + stdout io.Writer + stderr io.Writer + tty bool +} + +var _ streamProtocolHandler = &streamProtocolV1{} + +func (e *streamProtocolV1) stream(conn httpstream.Connection) error { + doneChan := make(chan struct{}, 2) + errorChan := make(chan error) + + cp := func(s string, dst io.Writer, src io.Reader) { + glog.V(4).Infof("Copying %s", s) + defer glog.V(4).Infof("Done copying %s", s) + if _, err := io.Copy(dst, src); err != nil && err != io.EOF { + glog.Errorf("Error copying %s: %v", s, err) + } + if s == api.StreamTypeStdout || s == api.StreamTypeStderr { + doneChan <- struct{}{} + } + } + + headers := http.Header{} + headers.Set(api.StreamType, api.StreamTypeError) + errorStream, err := conn.CreateStream(headers) + if err != nil { + return err + } + go func() { + message, err := ioutil.ReadAll(errorStream) + if err != nil && err != io.EOF { + errorChan <- fmt.Errorf("Error reading from error stream: %s", err) + return + } + if len(message) > 0 { + errorChan <- fmt.Errorf("Error executing remote command: %s", message) + return + } + }() + defer errorStream.Reset() + + if e.stdin != nil { + headers.Set(api.StreamType, api.StreamTypeStdin) + remoteStdin, err := conn.CreateStream(headers) + if err != nil { + return err + } + defer remoteStdin.Reset() + // TODO this goroutine will never exit cleanly (the io.Copy never unblocks) + // because stdin is not closed until the process exits. If we try to call + // stdin.Close(), it returns no error but doesn't unblock the copy. It will + // exit when the process exits, instead. + go cp(api.StreamTypeStdin, remoteStdin, e.stdin) + } + + waitCount := 0 + completedStreams := 0 + + if e.stdout != nil { + waitCount++ + headers.Set(api.StreamType, api.StreamTypeStdout) + remoteStdout, err := conn.CreateStream(headers) + if err != nil { + return err + } + defer remoteStdout.Reset() + go cp(api.StreamTypeStdout, e.stdout, remoteStdout) + } + + if e.stderr != nil && !e.tty { + waitCount++ + headers.Set(api.StreamType, api.StreamTypeStderr) + remoteStderr, err := conn.CreateStream(headers) + if err != nil { + return err + } + defer remoteStderr.Reset() + go cp(api.StreamTypeStderr, e.stderr, remoteStderr) + } + +Loop: + for { + select { + case <-doneChan: + completedStreams++ + if completedStreams == waitCount { + break Loop + } + case err := <-errorChan: + return err + } + } + + return nil +} diff --git a/pkg/client/unversioned/remotecommand/v2.go b/pkg/client/unversioned/remotecommand/v2.go new file mode 100644 index 00000000000..ca10dda4956 --- /dev/null +++ b/pkg/client/unversioned/remotecommand/v2.go @@ -0,0 +1,151 @@ +/* +Copyright 2015 The Kubernetes Authors All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package remotecommand + +import ( + "fmt" + "io" + "io/ioutil" + "net/http" + "sync" + + "k8s.io/kubernetes/pkg/api" + "k8s.io/kubernetes/pkg/util" + "k8s.io/kubernetes/pkg/util/httpstream" +) + +// streamProtocolV2 implements version 2 of the streaming protocol for attach +// and exec. The original streaming protocol was unversioned. As a result, this +// version is referred to as version 2, even though it is the first actual +// numbered version. +type streamProtocolV2 struct { + stdin io.Reader + stdout io.Writer + stderr io.Writer + tty bool +} + +var _ streamProtocolHandler = &streamProtocolV2{} + +func (e *streamProtocolV2) stream(conn httpstream.Connection) error { + headers := http.Header{} + + // set up error stream + errorChan := make(chan error) + headers.Set(api.StreamType, api.StreamTypeError) + errorStream, err := conn.CreateStream(headers) + if err != nil { + return err + } + + go func() { + message, err := ioutil.ReadAll(errorStream) + switch { + case err != nil && err != io.EOF: + errorChan <- fmt.Errorf("error reading from error stream: %s", err) + case len(message) > 0: + errorChan <- fmt.Errorf("error executing remote command: %s", message) + default: + errorChan <- nil + } + close(errorChan) + }() + + var wg sync.WaitGroup + var once sync.Once + + // set up stdin stream + if e.stdin != nil { + headers.Set(api.StreamType, api.StreamTypeStdin) + remoteStdin, err := conn.CreateStream(headers) + if err != nil { + return err + } + + // copy from client's stdin to container's stdin + go func() { + // if e.stdin is noninteractive, e.g. `echo abc | kubectl exec -i -- cat`, make sure + // we close remoteStdin as soon as the copy from e.stdin to remoteStdin finishes. Otherwise + // the executed command will remain running. + defer once.Do(func() { remoteStdin.Close() }) + + if _, err := io.Copy(remoteStdin, e.stdin); err != nil { + util.HandleError(err) + } + }() + + // read from remoteStdin until the stream is closed. this is essential to + // be able to exit interactive sessions cleanly and not leak goroutines or + // hang the client's terminal. + // + // go-dockerclient's current hijack implementation + // (https://github.com/fsouza/go-dockerclient/blob/89f3d56d93788dfe85f864a44f85d9738fca0670/client.go#L564) + // waits for all three streams (stdin/stdout/stderr) to finish copying + // before returning. When hijack finishes copying stdout/stderr, it calls + // Close() on its side of remoteStdin, which allows this copy to complete. + // When that happens, we must Close() on our side of remoteStdin, to + // allow the copy in hijack to complete, and hijack to return. + go func() { + defer once.Do(func() { remoteStdin.Close() }) + // this "copy" doesn't actually read anything - it's just here to wait for + // the server to close remoteStdin. + if _, err := io.Copy(ioutil.Discard, remoteStdin); err != nil { + util.HandleError(err) + } + }() + } + + // set up stdout stream + if e.stdout != nil { + headers.Set(api.StreamType, api.StreamTypeStdout) + remoteStdout, err := conn.CreateStream(headers) + if err != nil { + return err + } + + wg.Add(1) + go func() { + defer wg.Done() + if _, err := io.Copy(e.stdout, remoteStdout); err != nil { + util.HandleError(err) + } + }() + } + + // set up stderr stream + if e.stderr != nil && !e.tty { + headers.Set(api.StreamType, api.StreamTypeStderr) + remoteStderr, err := conn.CreateStream(headers) + if err != nil { + return err + } + + wg.Add(1) + go func() { + defer wg.Done() + if _, err := io.Copy(e.stderr, remoteStderr); err != nil { + util.HandleError(err) + } + }() + } + + // we're waiting for stdout/stderr to finish copying + wg.Wait() + + // waits for errorStream to finish reading with an error or nil + return <-errorChan +} From ff9883d9ecd32f9db1077a4b6aeee3c112b0595d Mon Sep 17 00:00:00 2001 From: Andy Goldstein Date: Wed, 21 Oct 2015 20:42:40 -0400 Subject: [PATCH 3/5] Address code review comments --- .../unversioned/portforward/portforward.go | 6 +- .../portforward/portforward_test.go | 2 +- .../remotecommand/remotecommand.go | 11 ++-- .../remotecommand/remotecommand_test.go | 13 ++++- pkg/client/unversioned/remotecommand/v1.go | 8 ++- pkg/kubelet/server.go | 29 +++++++--- pkg/util/httpstream/httpstream.go | 55 +++++++++++++++++-- pkg/util/httpstream/spdy/roundtripper_test.go | 2 +- pkg/util/httpstream/spdy/upgrade.go | 42 +++----------- pkg/util/httpstream/spdy/upgrade_test.go | 3 +- 10 files changed, 107 insertions(+), 64 deletions(-) diff --git a/pkg/client/unversioned/portforward/portforward.go b/pkg/client/unversioned/portforward/portforward.go index a711c9f5883..3c279e7aa9b 100644 --- a/pkg/client/unversioned/portforward/portforward.go +++ b/pkg/client/unversioned/portforward/portforward.go @@ -29,6 +29,7 @@ import ( "github.com/golang/glog" "k8s.io/kubernetes/pkg/api" + "k8s.io/kubernetes/pkg/kubelet" "k8s.io/kubernetes/pkg/util" "k8s.io/kubernetes/pkg/util/httpstream" ) @@ -122,16 +123,13 @@ func New(dialer httpstream.Dialer, ports []string, stopChan <-chan struct{}) (*P }, nil } -// The SPDY subprotocol "portforward.k8s.io" is used for port forwarding. -const PortForwardProtocolV1Name = "portforward.k8s.io" - // ForwardPorts formats and executes a port forwarding request. The connection will remain // open until stopChan is closed. func (pf *PortForwarder) ForwardPorts() error { defer pf.Close() var err error - pf.streamConn, _, err = pf.dialer.Dial([]string{PortForwardProtocolV1Name}) + pf.streamConn, _, err = pf.dialer.Dial(kubelet.PortForwardProtocolV1Name) if err != nil { return fmt.Errorf("error upgrading connection: %s", err) } diff --git a/pkg/client/unversioned/portforward/portforward_test.go b/pkg/client/unversioned/portforward/portforward_test.go index 7a21a9cdcd1..e82b74f6497 100644 --- a/pkg/client/unversioned/portforward/portforward_test.go +++ b/pkg/client/unversioned/portforward/portforward_test.go @@ -44,7 +44,7 @@ type fakeDialer struct { negotiatedProtocol string } -func (d *fakeDialer) Dial(protocols []string) (httpstream.Connection, string, error) { +func (d *fakeDialer) Dial(protocols ...string) (httpstream.Connection, string, error) { d.dialed = true return d.conn, d.negotiatedProtocol, d.err } diff --git a/pkg/client/unversioned/remotecommand/remotecommand.go b/pkg/client/unversioned/remotecommand/remotecommand.go index 69f150b8630..99e914abebf 100644 --- a/pkg/client/unversioned/remotecommand/remotecommand.go +++ b/pkg/client/unversioned/remotecommand/remotecommand.go @@ -98,7 +98,7 @@ func NewStreamExecutor(upgrader httpstream.UpgradeRoundTripper, fn func(http.Rou // Dial opens a connection to a remote server and attempts to negotiate a SPDY // connection. Upon success, it returns the connection and the protocol // selected by the server. -func (e *streamExecutor) Dial(protocols []string) (httpstream.Connection, string, error) { +func (e *streamExecutor) Dial(protocols ...string) (httpstream.Connection, string, error) { transport := e.transport // TODO consider removing this and reusing client.TransportFor above to get this for free switch { @@ -111,6 +111,9 @@ func (e *streamExecutor) Dial(protocols []string) (httpstream.Connection, string case bool(glog.V(6)): transport = client.NewDebuggingRoundTripper(transport, client.URLTiming) } + + // TODO the client probably shouldn't be created here, as it doesn't allow + // flexibility to allow callers to configure it. client := &http.Client{Transport: transport} req, err := http.NewRequest(e.method, e.url.String(), nil) @@ -158,7 +161,8 @@ type streamProtocolHandler interface { // Stream opens a protocol streamer to the server and streams until a client closes // the connection or the server disconnects. func (e *streamExecutor) Stream(stdin io.Reader, stdout, stderr io.Writer, tty bool) error { - conn, protocol, err := e.Dial([]string{StreamProtocolV2Name, StreamProtocolV1Name}) + supportedProtocols := []string{StreamProtocolV2Name, StreamProtocolV1Name} + conn, protocol, err := e.Dial(supportedProtocols...) if err != nil { return err } @@ -175,8 +179,7 @@ func (e *streamExecutor) Stream(stdin io.Reader, stdout, stderr io.Writer, tty b tty: tty, } case "": - glog.Warning("The server did not negotiate a streaming protocol version. Falling back to unversioned") - // TODO restore v1 + glog.V(4).Infof("The server did not negotiate a streaming protocol version. Falling back to unversioned") streamer = &streamProtocolV1{ stdin: stdin, stdout: stdout, diff --git a/pkg/client/unversioned/remotecommand/remotecommand_test.go b/pkg/client/unversioned/remotecommand/remotecommand_test.go index 5403dacbacf..536c63ee080 100644 --- a/pkg/client/unversioned/remotecommand/remotecommand_test.go +++ b/pkg/client/unversioned/remotecommand/remotecommand_test.go @@ -42,10 +42,17 @@ func fakeExecServer(t *testing.T, i int, stdinData, stdoutData, stderrData, erro } return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + protocol, err := httpstream.Handshake(req, w, []string{StreamProtocolV2Name}, StreamProtocolV1Name) + if err != nil { + t.Fatal(err) + } + if protocol != StreamProtocolV2Name { + t.Fatalf("unexpected protocol: %s", protocol) + } streamCh := make(chan httpstream.Stream) upgrader := spdy.NewResponseUpgrader() - conn, protocol := upgrader.UpgradeResponse(w, req, []string{StreamProtocolV2Name, StreamProtocolV1Name}, func(stream httpstream.Stream) error { + conn := upgrader.UpgradeResponse(w, req, func(stream httpstream.Stream) error { streamCh <- stream return nil }) @@ -57,7 +64,6 @@ func fakeExecServer(t *testing.T, i int, stdinData, stdoutData, stderrData, erro return } defer conn.Close() - _ = protocol var errorStream, stdinStream, stdoutStream, stderrStream httpstream.Stream receivedStreams := 0 @@ -185,6 +191,7 @@ func TestRequestExecuteRemoteCommand(t *testing.T) { url, _ := url.ParseRequestURI(server.URL) c := client.NewRESTClient(url, "x", nil, -1, -1) req := c.Post().Resource("testing") + req.SetHeader(httpstream.HeaderProtocolVersion, StreamProtocolV2Name) req.Param("command", "ls") req.Param("command", "/") conf := &client.Config{ @@ -364,7 +371,7 @@ func TestDial(t *testing.T) { if err != nil { t.Fatal(err) } - conn, protocol, err := exec.Dial([]string{"a", "b"}) + conn, protocol, err := exec.Dial("protocol1") if err != nil { t.Fatal(err) } diff --git a/pkg/client/unversioned/remotecommand/v1.go b/pkg/client/unversioned/remotecommand/v1.go index 1a64ed048cc..b10e5e1f1e7 100644 --- a/pkg/client/unversioned/remotecommand/v1.go +++ b/pkg/client/unversioned/remotecommand/v1.go @@ -27,6 +27,10 @@ import ( "k8s.io/kubernetes/pkg/util/httpstream" ) +// streamProtocolV1 implements the first version of the streaming exec & attach +// protocol. This version has some bugs, such as not being able to detecte when +// non-interactive stdin data has ended. See http://issues.k8s.io/13394 and +// http://issues.k8s.io/13395 for more details. type streamProtocolV1 struct { stdin io.Reader stdout io.Writer @@ -41,8 +45,8 @@ func (e *streamProtocolV1) stream(conn httpstream.Connection) error { errorChan := make(chan error) cp := func(s string, dst io.Writer, src io.Reader) { - glog.V(4).Infof("Copying %s", s) - defer glog.V(4).Infof("Done copying %s", s) + glog.V(6).Infof("Copying %s", s) + defer glog.V(6).Infof("Done copying %s", s) if _, err := io.Copy(dst, src); err != nil && err != io.EOF { glog.Errorf("Error copying %s: %v", s, err) } diff --git a/pkg/kubelet/server.go b/pkg/kubelet/server.go index 04969592f65..b0dc60fdc82 100644 --- a/pkg/kubelet/server.go +++ b/pkg/kubelet/server.go @@ -44,7 +44,6 @@ import ( "k8s.io/kubernetes/pkg/api/validation" "k8s.io/kubernetes/pkg/auth/authenticator" "k8s.io/kubernetes/pkg/auth/authorizer" - "k8s.io/kubernetes/pkg/client/unversioned/portforward" "k8s.io/kubernetes/pkg/client/unversioned/remotecommand" "k8s.io/kubernetes/pkg/healthz" "k8s.io/kubernetes/pkg/httplog" @@ -687,10 +686,17 @@ func (s *Server) createStreams(request *restful.Request, response *restful.Respo return streams[0], streams[1], streams[2], streams[3], conn, tty, true } + supportedStreamProtocols := []string{remotecommand.StreamProtocolV2Name, remotecommand.StreamProtocolV1Name} + _, err := httpstream.Handshake(request.Request, response.ResponseWriter, supportedStreamProtocols, remotecommand.StreamProtocolV1Name) + // negotiated protocol isn't used server side at the moment, but could be in the future + if err != nil { + return nil, nil, nil, nil, nil, false, false + } + streamCh := make(chan httpstream.Stream) upgrader := spdy.NewResponseUpgrader() - conn, protocol := upgrader.UpgradeResponse(response.ResponseWriter, request.Request, []string{remotecommand.StreamProtocolV2Name, remotecommand.StreamProtocolV1Name}, func(stream httpstream.Stream) error { + conn := upgrader.UpgradeResponse(response.ResponseWriter, request.Request, func(stream httpstream.Stream) error { streamCh <- stream return nil }) @@ -701,9 +707,6 @@ func (s *Server) createStreams(request *restful.Request, response *restful.Respo // if we weren't successful in upgrading. return nil, nil, nil, nil, nil, false, false } - if len(protocol) == 0 { - protocol = remotecommand.StreamProtocolV1Name - } conn.SetIdleTimeout(s.host.StreamingConnectionIdleTimeout()) @@ -778,24 +781,34 @@ func (s *Server) getPortForward(request *restful.Request, response *restful.Resp ServePortForward(response.ResponseWriter, request.Request, s.host, podName, uid, s.host.StreamingConnectionIdleTimeout(), defaultStreamCreationTimeout) } +// The subprotocol "portforward.k8s.io" is used for port forwarding. +const PortForwardProtocolV1Name = "portforward.k8s.io" + // ServePortForward handles a port forwarding request. A single request is // kept alive as long as the client is still alive and the connection has not // been timed out due to idleness. This function handles multiple forwarded // connections; i.e., multiple `curl http://localhost:8888/` requests will be // handled by a single invocation of ServePortForward. func ServePortForward(w http.ResponseWriter, req *http.Request, portForwarder PortForwarder, podName string, uid types.UID, idleTimeout time.Duration, streamCreationTimeout time.Duration) { + supportedPortForwardProtocols := []string{PortForwardProtocolV1Name} + _, err := httpstream.Handshake(req, w, supportedPortForwardProtocols, PortForwardProtocolV1Name) + // negotiated protocol isn't currently used server side, but could be in the future + if err != nil { + // Handshake writes the error to the client + util.HandleError(err) + return + } + streamChan := make(chan httpstream.Stream, 1) glog.V(5).Infof("Upgrading port forward response") upgrader := spdy.NewResponseUpgrader() - conn, protocol := upgrader.UpgradeResponse(w, req, []string{portforward.PortForwardProtocolV1Name}, portForwardStreamReceived(streamChan)) + conn := upgrader.UpgradeResponse(w, req, portForwardStreamReceived(streamChan)) if conn == nil { return } defer conn.Close() - _ = protocol - glog.V(5).Infof("(conn=%p) setting port forwarding streaming connection idle timeout to %v", conn, idleTimeout) conn.SetIdleTimeout(idleTimeout) diff --git a/pkg/util/httpstream/httpstream.go b/pkg/util/httpstream/httpstream.go index 9f119b7b7ce..80c3cd78fc0 100644 --- a/pkg/util/httpstream/httpstream.go +++ b/pkg/util/httpstream/httpstream.go @@ -17,6 +17,7 @@ limitations under the License. package httpstream import ( + "fmt" "io" "net/http" "strings" @@ -24,9 +25,10 @@ import ( ) const ( - HeaderConnection = "Connection" - HeaderUpgrade = "Upgrade" - HeaderProtocolVersion = "X-Stream-Protocol-Version" + HeaderConnection = "Connection" + HeaderUpgrade = "Upgrade" + HeaderProtocolVersion = "X-Stream-Protocol-Version" + HeaderAcceptedProtocolVersions = "X-Accepted-Stream-Protocol-Versions" ) // NewStreamHandler defines a function that is called when a new Stream is @@ -43,7 +45,7 @@ type Dialer interface { // Dial opens a streaming connection to a server using one of the protocols // specified (in order of most preferred to least preferred). - Dial(protocols []string) (Connection, string, error) + Dial(protocols ...string) (Connection, string, error) } // UpgradeRoundTripper is a type of http.RoundTripper that is able to upgrade @@ -60,9 +62,9 @@ type UpgradeRoundTripper interface { // add streaming support to them. type ResponseUpgrader interface { // UpgradeResponse upgrades an HTTP response to one that supports multiplexed - // streams. newStreamHandler will be called synchronously whenever the + // streams. newStreamHandler will be called asynchronously whenever the // other end of the upgraded connection creates a new stream. - UpgradeResponse(w http.ResponseWriter, req *http.Request, protocols []string, newStreamHandler NewStreamHandler) (Connection, string) + UpgradeResponse(w http.ResponseWriter, req *http.Request, newStreamHandler NewStreamHandler) Connection } // Connection represents an upgraded HTTP connection. @@ -100,3 +102,44 @@ func IsUpgradeRequest(req *http.Request) bool { } return false } + +func negotiateProtocol(clientProtocols, serverProtocols []string) string { + for i := range clientProtocols { + for j := range serverProtocols { + if clientProtocols[i] == serverProtocols[j] { + return clientProtocols[i] + } + } + } + return "" +} + +// Handshake performs a subprotocol negotiation. If the client did not request +// a specific subprotocol, defaultProtocol is used. If the client did request a +// subprotocol, Handshake will select the first common value found in +// serverProtocols. If a match is found, Handshake adds a response header +// indicating the chosen subprotocol. If no match is found, HTTP forbidden is +// returned, along with a response header containing the list of protocols the +// server can accept. +func Handshake(req *http.Request, w http.ResponseWriter, serverProtocols []string, defaultProtocol string) (string, error) { + clientProtocols := req.Header[http.CanonicalHeaderKey(HeaderProtocolVersion)] + if len(clientProtocols) == 0 { + // Kube 1.0 client that didn't support subprotocol negotiation + // TODO remove this defaulting logic once Kube 1.0 is no longer supported + w.Header().Add(HeaderProtocolVersion, defaultProtocol) + return defaultProtocol, nil + } + + negotiatedProtocol := negotiateProtocol(clientProtocols, serverProtocols) + if len(negotiatedProtocol) == 0 { + w.WriteHeader(http.StatusForbidden) + for i := range serverProtocols { + w.Header().Add(HeaderAcceptedProtocolVersions, serverProtocols[i]) + } + fmt.Fprintf(w, "unable to upgrade: unable to negotiate protocol: client supports %v, server accepts %v", clientProtocols, serverProtocols) + return "", fmt.Errorf("unable to upgrade: unable to negotiate protocol: client supports %v, server supports %v", clientProtocols, serverProtocols) + } + + w.Header().Add(HeaderProtocolVersion, negotiatedProtocol) + return negotiatedProtocol, nil +} diff --git a/pkg/util/httpstream/spdy/roundtripper_test.go b/pkg/util/httpstream/spdy/roundtripper_test.go index f16a6e697f3..babd23c9011 100644 --- a/pkg/util/httpstream/spdy/roundtripper_test.go +++ b/pkg/util/httpstream/spdy/roundtripper_test.go @@ -120,7 +120,7 @@ func TestRoundTripAndNewConnection(t *testing.T) { streamCh := make(chan httpstream.Stream) responseUpgrader := NewResponseUpgrader() - spdyConn, _ := responseUpgrader.UpgradeResponse(w, req, []string{"protocol1"}, func(s httpstream.Stream) error { + spdyConn := responseUpgrader.UpgradeResponse(w, req, func(s httpstream.Stream) error { streamCh <- s return nil }) diff --git a/pkg/util/httpstream/spdy/upgrade.go b/pkg/util/httpstream/spdy/upgrade.go index c3e79aa3d17..4fd2a40521a 100644 --- a/pkg/util/httpstream/spdy/upgrade.go +++ b/pkg/util/httpstream/spdy/upgrade.go @@ -21,7 +21,7 @@ import ( "net/http" "strings" - "github.com/golang/glog" + "k8s.io/kubernetes/pkg/util" "k8s.io/kubernetes/pkg/util/httpstream" ) @@ -39,47 +39,23 @@ func NewResponseUpgrader() httpstream.ResponseUpgrader { return responseUpgrader{} } -func negotiateProtocol(clientProtocols, serverProtocols []string) string { - for i := range clientProtocols { - for j := range serverProtocols { - if clientProtocols[i] == serverProtocols[j] { - return clientProtocols[i] - } - } - } - return "" -} - // UpgradeResponse upgrades an HTTP response to one that supports multiplexed // streams. newStreamHandler will be called synchronously whenever the // other end of the upgraded connection creates a new stream. -func (u responseUpgrader) UpgradeResponse(w http.ResponseWriter, req *http.Request, protocols []string, newStreamHandler httpstream.NewStreamHandler) (httpstream.Connection, string) { +func (u responseUpgrader) UpgradeResponse(w http.ResponseWriter, req *http.Request, newStreamHandler httpstream.NewStreamHandler) httpstream.Connection { connectionHeader := strings.ToLower(req.Header.Get(httpstream.HeaderConnection)) upgradeHeader := strings.ToLower(req.Header.Get(httpstream.HeaderUpgrade)) if !strings.Contains(connectionHeader, strings.ToLower(httpstream.HeaderUpgrade)) || !strings.Contains(upgradeHeader, strings.ToLower(HeaderSpdy31)) { w.WriteHeader(http.StatusBadRequest) fmt.Fprintf(w, "unable to upgrade: missing upgrade headers in request: %#v", req.Header) - return nil, "" + return nil } hijacker, ok := w.(http.Hijacker) if !ok { w.WriteHeader(http.StatusInternalServerError) fmt.Fprintf(w, "unable to upgrade: unable to hijack response") - return nil, "" - } - - var negotiatedProtocol string - clientProtocols := req.Header[http.CanonicalHeaderKey(httpstream.HeaderProtocolVersion)] - if len(clientProtocols) > 0 { - negotiatedProtocol = negotiateProtocol(req.Header[http.CanonicalHeaderKey(httpstream.HeaderProtocolVersion)], protocols) - if len(negotiatedProtocol) > 0 { - w.Header().Add(httpstream.HeaderProtocolVersion, negotiatedProtocol) - } else { - w.WriteHeader(http.StatusForbidden) - fmt.Fprintf(w, "unable to upgrade: unable to negotiate protocol: server accepts %v", protocols) - return nil, "" - } + return nil } w.Header().Add(httpstream.HeaderConnection, httpstream.HeaderUpgrade) @@ -88,15 +64,15 @@ func (u responseUpgrader) UpgradeResponse(w http.ResponseWriter, req *http.Reque conn, _, err := hijacker.Hijack() if err != nil { - glog.Errorf("unable to upgrade: error hijacking response: %v", err) - return nil, "" + util.HandleError(fmt.Errorf("unable to upgrade: error hijacking response: %v", err)) + return nil } spdyConn, err := NewServerConnection(conn, newStreamHandler) if err != nil { - glog.Errorf("unable to upgrade: error creating SPDY server connection: %v", err) - return nil, "" + util.HandleError(fmt.Errorf("unable to upgrade: error creating SPDY server connection: %v", err)) + return nil } - return spdyConn, negotiatedProtocol + return spdyConn } diff --git a/pkg/util/httpstream/spdy/upgrade_test.go b/pkg/util/httpstream/spdy/upgrade_test.go index e82f3515ee1..4e111407e87 100644 --- a/pkg/util/httpstream/spdy/upgrade_test.go +++ b/pkg/util/httpstream/spdy/upgrade_test.go @@ -53,8 +53,7 @@ func TestUpgradeResponse(t *testing.T) { for i, testCase := range testCases { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { upgrader := NewResponseUpgrader() - conn, protocol := upgrader.UpgradeResponse(w, req, []string{"protocol1"}, nil) - _ = protocol + conn := upgrader.UpgradeResponse(w, req, nil) haveErr := conn == nil if e, a := testCase.shouldError, haveErr; e != a { t.Fatalf("%d: expected shouldErr=%t, got %t", i, testCase.shouldError, haveErr) From ad4f108bfa81c9c33047fe8111063e833c1afdf7 Mon Sep 17 00:00:00 2001 From: Andy Goldstein Date: Wed, 21 Oct 2015 22:37:26 -0400 Subject: [PATCH 4/5] Move port forward protocol constant to subpackage Move port forward protocol name constant to a subpackage underneath pkg/kubelet to avoid flags applicable to the kubelet leaking into kubectl. Eventually, handlers for specific protocol versions will move into the new subpackage as well. --- .../unversioned/portforward/portforward.go | 4 ++-- pkg/kubelet/portforward/constants.go | 21 +++++++++++++++++++ pkg/kubelet/server.go | 8 +++---- 3 files changed, 26 insertions(+), 7 deletions(-) create mode 100644 pkg/kubelet/portforward/constants.go diff --git a/pkg/client/unversioned/portforward/portforward.go b/pkg/client/unversioned/portforward/portforward.go index 3c279e7aa9b..693ad930865 100644 --- a/pkg/client/unversioned/portforward/portforward.go +++ b/pkg/client/unversioned/portforward/portforward.go @@ -29,7 +29,7 @@ import ( "github.com/golang/glog" "k8s.io/kubernetes/pkg/api" - "k8s.io/kubernetes/pkg/kubelet" + "k8s.io/kubernetes/pkg/kubelet/portforward" "k8s.io/kubernetes/pkg/util" "k8s.io/kubernetes/pkg/util/httpstream" ) @@ -129,7 +129,7 @@ func (pf *PortForwarder) ForwardPorts() error { defer pf.Close() var err error - pf.streamConn, _, err = pf.dialer.Dial(kubelet.PortForwardProtocolV1Name) + pf.streamConn, _, err = pf.dialer.Dial(portforward.PortForwardProtocolV1Name) if err != nil { return fmt.Errorf("error upgrading connection: %s", err) } diff --git a/pkg/kubelet/portforward/constants.go b/pkg/kubelet/portforward/constants.go new file mode 100644 index 00000000000..f438670675f --- /dev/null +++ b/pkg/kubelet/portforward/constants.go @@ -0,0 +1,21 @@ +/* +Copyright 2015 The Kubernetes Authors All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// package portforward contains server-side logic for handling port forwarding requests. +package portforward + +// The subprotocol "portforward.k8s.io" is used for port forwarding. +const PortForwardProtocolV1Name = "portforward.k8s.io" diff --git a/pkg/kubelet/server.go b/pkg/kubelet/server.go index b0dc60fdc82..83b95eb0914 100644 --- a/pkg/kubelet/server.go +++ b/pkg/kubelet/server.go @@ -48,6 +48,7 @@ import ( "k8s.io/kubernetes/pkg/healthz" "k8s.io/kubernetes/pkg/httplog" kubecontainer "k8s.io/kubernetes/pkg/kubelet/container" + "k8s.io/kubernetes/pkg/kubelet/portforward" "k8s.io/kubernetes/pkg/types" "k8s.io/kubernetes/pkg/util" "k8s.io/kubernetes/pkg/util/flushwriter" @@ -781,17 +782,14 @@ func (s *Server) getPortForward(request *restful.Request, response *restful.Resp ServePortForward(response.ResponseWriter, request.Request, s.host, podName, uid, s.host.StreamingConnectionIdleTimeout(), defaultStreamCreationTimeout) } -// The subprotocol "portforward.k8s.io" is used for port forwarding. -const PortForwardProtocolV1Name = "portforward.k8s.io" - // ServePortForward handles a port forwarding request. A single request is // kept alive as long as the client is still alive and the connection has not // been timed out due to idleness. This function handles multiple forwarded // connections; i.e., multiple `curl http://localhost:8888/` requests will be // handled by a single invocation of ServePortForward. func ServePortForward(w http.ResponseWriter, req *http.Request, portForwarder PortForwarder, podName string, uid types.UID, idleTimeout time.Duration, streamCreationTimeout time.Duration) { - supportedPortForwardProtocols := []string{PortForwardProtocolV1Name} - _, err := httpstream.Handshake(req, w, supportedPortForwardProtocols, PortForwardProtocolV1Name) + supportedPortForwardProtocols := []string{portforward.PortForwardProtocolV1Name} + _, err := httpstream.Handshake(req, w, supportedPortForwardProtocols, portforward.PortForwardProtocolV1Name) // negotiated protocol isn't currently used server side, but could be in the future if err != nil { // Handshake writes the error to the client From 6fddb0e83ac582f7a33c43ea7ac2a073f9ff98c5 Mon Sep 17 00:00:00 2001 From: Andy Goldstein Date: Fri, 23 Oct 2015 14:09:41 -0400 Subject: [PATCH 5/5] Add httpstream.Handshake unit test --- .../remotecommand/remotecommand.go | 4 +- pkg/util/httpstream/httpstream_test.go | 120 ++++++++++++++++++ 2 files changed, 123 insertions(+), 1 deletion(-) create mode 100644 pkg/util/httpstream/httpstream_test.go diff --git a/pkg/client/unversioned/remotecommand/remotecommand.go b/pkg/client/unversioned/remotecommand/remotecommand.go index 99e914abebf..4feb953d95f 100644 --- a/pkg/client/unversioned/remotecommand/remotecommand.go +++ b/pkg/client/unversioned/remotecommand/remotecommand.go @@ -179,7 +179,9 @@ func (e *streamExecutor) Stream(stdin io.Reader, stdout, stderr io.Writer, tty b tty: tty, } case "": - glog.V(4).Infof("The server did not negotiate a streaming protocol version. Falling back to unversioned") + glog.V(4).Infof("The server did not negotiate a streaming protocol version. Falling back to %s", StreamProtocolV1Name) + fallthrough + case StreamProtocolV1Name: streamer = &streamProtocolV1{ stdin: stdin, stdout: stdout, diff --git a/pkg/util/httpstream/httpstream_test.go b/pkg/util/httpstream/httpstream_test.go new file mode 100644 index 00000000000..7a1bbaefb89 --- /dev/null +++ b/pkg/util/httpstream/httpstream_test.go @@ -0,0 +1,120 @@ +/* +Copyright 2015 The Kubernetes Authors All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package httpstream + +import ( + "net/http" + "reflect" + "testing" +) + +type responseWriter struct { + header http.Header + statusCode *int +} + +func newResponseWriter() *responseWriter { + return &responseWriter{ + header: make(http.Header), + } +} + +func (r *responseWriter) Header() http.Header { + return r.header +} + +func (r *responseWriter) WriteHeader(code int) { + r.statusCode = &code +} + +func (r *responseWriter) Write([]byte) (int, error) { + return 0, nil +} + +func TestHandshake(t *testing.T) { + defaultProtocol := "default" + + tests := map[string]struct { + clientProtocols []string + serverProtocols []string + expectedProtocol string + expectError bool + }{ + "no client protocols": { + clientProtocols: []string{}, + serverProtocols: []string{"a", "b"}, + expectedProtocol: defaultProtocol, + }, + "no common protocol": { + clientProtocols: []string{"c"}, + serverProtocols: []string{"a", "b"}, + expectedProtocol: "", + expectError: true, + }, + "common protocol": { + clientProtocols: []string{"b"}, + serverProtocols: []string{"a", "b"}, + expectedProtocol: "b", + }, + } + + for name, test := range tests { + req, err := http.NewRequest("GET", "http://www.example.com/", nil) + if err != nil { + t.Fatalf("%s: error creating request: %v", name, err) + } + + for _, p := range test.clientProtocols { + req.Header.Add(HeaderProtocolVersion, p) + } + + w := newResponseWriter() + negotiated, err := Handshake(req, w, test.serverProtocols, defaultProtocol) + + // verify negotiated protocol + if e, a := test.expectedProtocol, negotiated; e != a { + t.Errorf("%s: protocol: expected %q, got %q", name, e, a) + } + + if test.expectError { + if err == nil { + t.Errorf("%s: expected error but did not get one", name) + } + if w.statusCode == nil { + t.Errorf("%s: expected w.statusCode to be set", name) + } else if e, a := http.StatusForbidden, *w.statusCode; e != a { + t.Errorf("%s: w.statusCode: expected %d, got %d", name, e, a) + } + if e, a := test.serverProtocols, w.Header()[HeaderAcceptedProtocolVersions]; !reflect.DeepEqual(e, a) { + t.Errorf("%s: accepted server protocols: expected %v, got %v", name, e, a) + } + continue + } + if !test.expectError && err != nil { + t.Errorf("%s: unexpected error: %v", name, err) + continue + } + if w.statusCode != nil { + t.Errorf("%s: unexpected non-nil w.statusCode: %d", w.statusCode) + } + + // verify response headers + if e, a := []string{test.expectedProtocol}, w.Header()[HeaderProtocolVersion]; !reflect.DeepEqual(e, a) { + t.Errorf("%s: protocol response header: expected %v, got %v", name, e, a) + } + } +}