diff --git a/pkg/client/unversioned/portforward/portforward.go b/pkg/client/unversioned/portforward/portforward.go index 36916bdcf49..693ad930865 100644 --- a/pkg/client/unversioned/portforward/portforward.go +++ b/pkg/client/unversioned/portforward/portforward.go @@ -29,6 +29,7 @@ import ( "github.com/golang/glog" "k8s.io/kubernetes/pkg/api" + "k8s.io/kubernetes/pkg/kubelet/portforward" "k8s.io/kubernetes/pkg/util" "k8s.io/kubernetes/pkg/util/httpstream" ) @@ -128,7 +129,7 @@ func (pf *PortForwarder) ForwardPorts() error { defer pf.Close() var err error - pf.streamConn, err = pf.dialer.Dial() + pf.streamConn, _, err = pf.dialer.Dial(portforward.PortForwardProtocolV1Name) if err != nil { return fmt.Errorf("error upgrading connection: %s", err) } diff --git a/pkg/client/unversioned/portforward/portforward_test.go b/pkg/client/unversioned/portforward/portforward_test.go index f208a9e8814..e82b74f6497 100644 --- a/pkg/client/unversioned/portforward/portforward_test.go +++ b/pkg/client/unversioned/portforward/portforward_test.go @@ -38,14 +38,15 @@ import ( ) type fakeDialer struct { - dialed bool - conn httpstream.Connection - err error + dialed bool + conn httpstream.Connection + err error + negotiatedProtocol string } -func (d *fakeDialer) Dial() (httpstream.Connection, error) { +func (d *fakeDialer) Dial(protocols ...string) (httpstream.Connection, string, error) { d.dialed = true - return d.conn, d.err + return d.conn, d.negotiatedProtocol, d.err } func TestParsePortsAndNew(t *testing.T) { diff --git a/pkg/client/unversioned/remotecommand/remotecommand.go b/pkg/client/unversioned/remotecommand/remotecommand.go index 505e3fc9c1b..4feb953d95f 100644 --- a/pkg/client/unversioned/remotecommand/remotecommand.go +++ b/pkg/client/unversioned/remotecommand/remotecommand.go @@ -19,14 +19,12 @@ package remotecommand import ( "fmt" "io" - "io/ioutil" "net/http" "net/url" - "sync" - "k8s.io/kubernetes/pkg/api" + "github.com/golang/glog" + client "k8s.io/kubernetes/pkg/client/unversioned" - "k8s.io/kubernetes/pkg/util" "k8s.io/kubernetes/pkg/util/httpstream" "k8s.io/kubernetes/pkg/util/httpstream/spdy" ) @@ -97,155 +95,100 @@ func NewStreamExecutor(upgrader httpstream.UpgradeRoundTripper, fn func(http.Rou }, nil } -// Dial opens a connection to a remote server and attempts to negotiate a SPDY connection. -func (e *streamExecutor) Dial() (httpstream.Connection, error) { - client := &http.Client{Transport: e.transport} +// Dial opens a connection to a remote server and attempts to negotiate a SPDY +// connection. Upon success, it returns the connection and the protocol +// selected by the server. +func (e *streamExecutor) Dial(protocols ...string) (httpstream.Connection, string, error) { + transport := e.transport + // TODO consider removing this and reusing client.TransportFor above to get this for free + switch { + case bool(glog.V(9)): + transport = client.NewDebuggingRoundTripper(transport, client.CurlCommand, client.URLTiming, client.ResponseHeaders) + case bool(glog.V(8)): + transport = client.NewDebuggingRoundTripper(transport, client.JustURL, client.RequestHeaders, client.ResponseStatus, client.ResponseHeaders) + case bool(glog.V(7)): + transport = client.NewDebuggingRoundTripper(transport, client.JustURL, client.RequestHeaders, client.ResponseStatus) + case bool(glog.V(6)): + transport = client.NewDebuggingRoundTripper(transport, client.URLTiming) + } + + // TODO the client probably shouldn't be created here, as it doesn't allow + // flexibility to allow callers to configure it. + client := &http.Client{Transport: transport} req, err := http.NewRequest(e.method, e.url.String(), nil) if err != nil { - return nil, fmt.Errorf("error creating request: %s", err) + return nil, "", fmt.Errorf("error creating request: %v", err) + } + for i := range protocols { + req.Header.Add(httpstream.HeaderProtocolVersion, protocols[i]) } resp, err := client.Do(req) if err != nil { - return nil, fmt.Errorf("error sending request: %s", err) + return nil, "", fmt.Errorf("error sending request: %v", err) } defer resp.Body.Close() - // TODO: handle protocol selection in the future - return e.upgrader.NewConnection(resp) + if resp.StatusCode != http.StatusSwitchingProtocols { + return nil, "", fmt.Errorf("unexpected response status code %d (%s)", resp.StatusCode, http.StatusText(resp.StatusCode)) + } + + conn, err := e.upgrader.NewConnection(resp) + if err != nil { + return nil, "", err + } + + return conn, resp.Header.Get(httpstream.HeaderProtocolVersion), nil +} + +const ( + // The SPDY subprotocol "channel.k8s.io" is used for remote command + // attachment/execution. This represents the initial unversioned subprotocol, + // which has the known bugs http://issues.k8s.io/13394 and + // http://issues.k8s.io/13395. + StreamProtocolV1Name = "channel.k8s.io" + // The SPDY subprotocol "v2.channel.k8s.io" is used for remote command + // attachment/execution. It is the second version of the subprotocol and + // resolves the issues present in the first version. + StreamProtocolV2Name = "v2.channel.k8s.io" +) + +type streamProtocolHandler interface { + stream(httpstream.Connection) error } // Stream opens a protocol streamer to the server and streams until a client closes // the connection or the server disconnects. func (e *streamExecutor) Stream(stdin io.Reader, stdout, stderr io.Writer, tty bool) error { - conn, err := e.Dial() + supportedProtocols := []string{StreamProtocolV2Name, StreamProtocolV1Name} + conn, protocol, err := e.Dial(supportedProtocols...) if err != nil { return err } defer conn.Close() - // TODO: negotiate protocols - streamer := &streamProtocol{ - stdin: stdin, - stdout: stdout, - stderr: stderr, - tty: tty, + + var streamer streamProtocolHandler + + switch protocol { + case StreamProtocolV2Name: + streamer = &streamProtocolV2{ + stdin: stdin, + stdout: stdout, + stderr: stderr, + tty: tty, + } + case "": + glog.V(4).Infof("The server did not negotiate a streaming protocol version. Falling back to %s", StreamProtocolV1Name) + fallthrough + case StreamProtocolV1Name: + streamer = &streamProtocolV1{ + stdin: stdin, + stdout: stdout, + stderr: stderr, + tty: tty, + } } + return streamer.stream(conn) } - -type streamProtocol struct { - stdin io.Reader - stdout io.Writer - stderr io.Writer - tty bool -} - -func (e *streamProtocol) stream(conn httpstream.Connection) error { - headers := http.Header{} - - // set up error stream - errorChan := make(chan error) - headers.Set(api.StreamType, api.StreamTypeError) - errorStream, err := conn.CreateStream(headers) - if err != nil { - return err - } - - go func() { - message, err := ioutil.ReadAll(errorStream) - switch { - case err != nil && err != io.EOF: - errorChan <- fmt.Errorf("error reading from error stream: %s", err) - case len(message) > 0: - errorChan <- fmt.Errorf("error executing remote command: %s", message) - default: - errorChan <- nil - } - close(errorChan) - }() - - var wg sync.WaitGroup - var once sync.Once - - // set up stdin stream - if e.stdin != nil { - headers.Set(api.StreamType, api.StreamTypeStdin) - remoteStdin, err := conn.CreateStream(headers) - if err != nil { - return err - } - - // copy from client's stdin to container's stdin - go func() { - // if e.stdin is noninteractive, e.g. `echo abc | kubectl exec -i -- cat`, make sure - // we close remoteStdin as soon as the copy from e.stdin to remoteStdin finishes. Otherwise - // the executed command will remain running. - defer once.Do(func() { remoteStdin.Close() }) - - if _, err := io.Copy(remoteStdin, e.stdin); err != nil { - util.HandleError(err) - } - }() - - // read from remoteStdin until the stream is closed. this is essential to - // be able to exit interactive sessions cleanly and not leak goroutines or - // hang the client's terminal. - // - // go-dockerclient's current hijack implementation - // (https://github.com/fsouza/go-dockerclient/blob/89f3d56d93788dfe85f864a44f85d9738fca0670/client.go#L564) - // waits for all three streams (stdin/stdout/stderr) to finish copying - // before returning. When hijack finishes copying stdout/stderr, it calls - // Close() on its side of remoteStdin, which allows this copy to complete. - // When that happens, we must Close() on our side of remoteStdin, to - // allow the copy in hijack to complete, and hijack to return. - go func() { - defer once.Do(func() { remoteStdin.Close() }) - // this "copy" doesn't actually read anything - it's just here to wait for - // the server to close remoteStdin. - if _, err := io.Copy(ioutil.Discard, remoteStdin); err != nil { - util.HandleError(err) - } - }() - } - - // set up stdout stream - if e.stdout != nil { - headers.Set(api.StreamType, api.StreamTypeStdout) - remoteStdout, err := conn.CreateStream(headers) - if err != nil { - return err - } - - wg.Add(1) - go func() { - defer wg.Done() - if _, err := io.Copy(e.stdout, remoteStdout); err != nil { - util.HandleError(err) - } - }() - } - - // set up stderr stream - if e.stderr != nil && !e.tty { - headers.Set(api.StreamType, api.StreamTypeStderr) - remoteStderr, err := conn.CreateStream(headers) - if err != nil { - return err - } - - wg.Add(1) - go func() { - defer wg.Done() - if _, err := io.Copy(e.stderr, remoteStderr); err != nil { - util.HandleError(err) - } - }() - } - - // we're waiting for stdout/stderr to finish copying - wg.Wait() - - // waits for errorStream to finish reading with an error or nil - return <-errorChan -} diff --git a/pkg/client/unversioned/remotecommand/remotecommand_test.go b/pkg/client/unversioned/remotecommand/remotecommand_test.go index 56cb3372bfb..536c63ee080 100644 --- a/pkg/client/unversioned/remotecommand/remotecommand_test.go +++ b/pkg/client/unversioned/remotecommand/remotecommand_test.go @@ -42,6 +42,13 @@ func fakeExecServer(t *testing.T, i int, stdinData, stdoutData, stderrData, erro } return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + protocol, err := httpstream.Handshake(req, w, []string{StreamProtocolV2Name}, StreamProtocolV1Name) + if err != nil { + t.Fatal(err) + } + if protocol != StreamProtocolV2Name { + t.Fatalf("unexpected protocol: %s", protocol) + } streamCh := make(chan httpstream.Stream) upgrader := spdy.NewResponseUpgrader() @@ -184,6 +191,7 @@ func TestRequestExecuteRemoteCommand(t *testing.T) { url, _ := url.ParseRequestURI(server.URL) c := client.NewRESTClient(url, "x", nil, -1, -1) req := c.Post().Resource("testing") + req.SetHeader(httpstream.HeaderProtocolVersion, StreamProtocolV2Name) req.Param("command", "ls") req.Param("command", "/") conf := &client.Config{ @@ -347,7 +355,7 @@ func TestDial(t *testing.T) { checkResponse: true, conn: &fakeConnection{}, resp: &http.Response{ - StatusCode: http.StatusOK, + StatusCode: http.StatusSwitchingProtocols, Body: ioutil.NopCloser(&bytes.Buffer{}), }, } @@ -363,7 +371,7 @@ func TestDial(t *testing.T) { if err != nil { t.Fatal(err) } - conn, err := exec.Dial() + conn, protocol, err := exec.Dial("protocol1") if err != nil { t.Fatal(err) } @@ -373,4 +381,5 @@ func TestDial(t *testing.T) { if !called { t.Errorf("wrapper not called") } + _ = protocol } diff --git a/pkg/client/unversioned/remotecommand/v1.go b/pkg/client/unversioned/remotecommand/v1.go new file mode 100644 index 00000000000..b10e5e1f1e7 --- /dev/null +++ b/pkg/client/unversioned/remotecommand/v1.go @@ -0,0 +1,130 @@ +/* +Copyright 2015 The Kubernetes Authors All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package remotecommand + +import ( + "fmt" + "io" + "io/ioutil" + "net/http" + + "github.com/golang/glog" + "k8s.io/kubernetes/pkg/api" + "k8s.io/kubernetes/pkg/util/httpstream" +) + +// streamProtocolV1 implements the first version of the streaming exec & attach +// protocol. This version has some bugs, such as not being able to detecte when +// non-interactive stdin data has ended. See http://issues.k8s.io/13394 and +// http://issues.k8s.io/13395 for more details. +type streamProtocolV1 struct { + stdin io.Reader + stdout io.Writer + stderr io.Writer + tty bool +} + +var _ streamProtocolHandler = &streamProtocolV1{} + +func (e *streamProtocolV1) stream(conn httpstream.Connection) error { + doneChan := make(chan struct{}, 2) + errorChan := make(chan error) + + cp := func(s string, dst io.Writer, src io.Reader) { + glog.V(6).Infof("Copying %s", s) + defer glog.V(6).Infof("Done copying %s", s) + if _, err := io.Copy(dst, src); err != nil && err != io.EOF { + glog.Errorf("Error copying %s: %v", s, err) + } + if s == api.StreamTypeStdout || s == api.StreamTypeStderr { + doneChan <- struct{}{} + } + } + + headers := http.Header{} + headers.Set(api.StreamType, api.StreamTypeError) + errorStream, err := conn.CreateStream(headers) + if err != nil { + return err + } + go func() { + message, err := ioutil.ReadAll(errorStream) + if err != nil && err != io.EOF { + errorChan <- fmt.Errorf("Error reading from error stream: %s", err) + return + } + if len(message) > 0 { + errorChan <- fmt.Errorf("Error executing remote command: %s", message) + return + } + }() + defer errorStream.Reset() + + if e.stdin != nil { + headers.Set(api.StreamType, api.StreamTypeStdin) + remoteStdin, err := conn.CreateStream(headers) + if err != nil { + return err + } + defer remoteStdin.Reset() + // TODO this goroutine will never exit cleanly (the io.Copy never unblocks) + // because stdin is not closed until the process exits. If we try to call + // stdin.Close(), it returns no error but doesn't unblock the copy. It will + // exit when the process exits, instead. + go cp(api.StreamTypeStdin, remoteStdin, e.stdin) + } + + waitCount := 0 + completedStreams := 0 + + if e.stdout != nil { + waitCount++ + headers.Set(api.StreamType, api.StreamTypeStdout) + remoteStdout, err := conn.CreateStream(headers) + if err != nil { + return err + } + defer remoteStdout.Reset() + go cp(api.StreamTypeStdout, e.stdout, remoteStdout) + } + + if e.stderr != nil && !e.tty { + waitCount++ + headers.Set(api.StreamType, api.StreamTypeStderr) + remoteStderr, err := conn.CreateStream(headers) + if err != nil { + return err + } + defer remoteStderr.Reset() + go cp(api.StreamTypeStderr, e.stderr, remoteStderr) + } + +Loop: + for { + select { + case <-doneChan: + completedStreams++ + if completedStreams == waitCount { + break Loop + } + case err := <-errorChan: + return err + } + } + + return nil +} diff --git a/pkg/client/unversioned/remotecommand/v2.go b/pkg/client/unversioned/remotecommand/v2.go new file mode 100644 index 00000000000..ca10dda4956 --- /dev/null +++ b/pkg/client/unversioned/remotecommand/v2.go @@ -0,0 +1,151 @@ +/* +Copyright 2015 The Kubernetes Authors All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package remotecommand + +import ( + "fmt" + "io" + "io/ioutil" + "net/http" + "sync" + + "k8s.io/kubernetes/pkg/api" + "k8s.io/kubernetes/pkg/util" + "k8s.io/kubernetes/pkg/util/httpstream" +) + +// streamProtocolV2 implements version 2 of the streaming protocol for attach +// and exec. The original streaming protocol was unversioned. As a result, this +// version is referred to as version 2, even though it is the first actual +// numbered version. +type streamProtocolV2 struct { + stdin io.Reader + stdout io.Writer + stderr io.Writer + tty bool +} + +var _ streamProtocolHandler = &streamProtocolV2{} + +func (e *streamProtocolV2) stream(conn httpstream.Connection) error { + headers := http.Header{} + + // set up error stream + errorChan := make(chan error) + headers.Set(api.StreamType, api.StreamTypeError) + errorStream, err := conn.CreateStream(headers) + if err != nil { + return err + } + + go func() { + message, err := ioutil.ReadAll(errorStream) + switch { + case err != nil && err != io.EOF: + errorChan <- fmt.Errorf("error reading from error stream: %s", err) + case len(message) > 0: + errorChan <- fmt.Errorf("error executing remote command: %s", message) + default: + errorChan <- nil + } + close(errorChan) + }() + + var wg sync.WaitGroup + var once sync.Once + + // set up stdin stream + if e.stdin != nil { + headers.Set(api.StreamType, api.StreamTypeStdin) + remoteStdin, err := conn.CreateStream(headers) + if err != nil { + return err + } + + // copy from client's stdin to container's stdin + go func() { + // if e.stdin is noninteractive, e.g. `echo abc | kubectl exec -i -- cat`, make sure + // we close remoteStdin as soon as the copy from e.stdin to remoteStdin finishes. Otherwise + // the executed command will remain running. + defer once.Do(func() { remoteStdin.Close() }) + + if _, err := io.Copy(remoteStdin, e.stdin); err != nil { + util.HandleError(err) + } + }() + + // read from remoteStdin until the stream is closed. this is essential to + // be able to exit interactive sessions cleanly and not leak goroutines or + // hang the client's terminal. + // + // go-dockerclient's current hijack implementation + // (https://github.com/fsouza/go-dockerclient/blob/89f3d56d93788dfe85f864a44f85d9738fca0670/client.go#L564) + // waits for all three streams (stdin/stdout/stderr) to finish copying + // before returning. When hijack finishes copying stdout/stderr, it calls + // Close() on its side of remoteStdin, which allows this copy to complete. + // When that happens, we must Close() on our side of remoteStdin, to + // allow the copy in hijack to complete, and hijack to return. + go func() { + defer once.Do(func() { remoteStdin.Close() }) + // this "copy" doesn't actually read anything - it's just here to wait for + // the server to close remoteStdin. + if _, err := io.Copy(ioutil.Discard, remoteStdin); err != nil { + util.HandleError(err) + } + }() + } + + // set up stdout stream + if e.stdout != nil { + headers.Set(api.StreamType, api.StreamTypeStdout) + remoteStdout, err := conn.CreateStream(headers) + if err != nil { + return err + } + + wg.Add(1) + go func() { + defer wg.Done() + if _, err := io.Copy(e.stdout, remoteStdout); err != nil { + util.HandleError(err) + } + }() + } + + // set up stderr stream + if e.stderr != nil && !e.tty { + headers.Set(api.StreamType, api.StreamTypeStderr) + remoteStderr, err := conn.CreateStream(headers) + if err != nil { + return err + } + + wg.Add(1) + go func() { + defer wg.Done() + if _, err := io.Copy(e.stderr, remoteStderr); err != nil { + util.HandleError(err) + } + }() + } + + // we're waiting for stdout/stderr to finish copying + wg.Wait() + + // waits for errorStream to finish reading with an error or nil + return <-errorChan +} diff --git a/pkg/kubelet/portforward/constants.go b/pkg/kubelet/portforward/constants.go new file mode 100644 index 00000000000..f438670675f --- /dev/null +++ b/pkg/kubelet/portforward/constants.go @@ -0,0 +1,21 @@ +/* +Copyright 2015 The Kubernetes Authors All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// package portforward contains server-side logic for handling port forwarding requests. +package portforward + +// The subprotocol "portforward.k8s.io" is used for port forwarding. +const PortForwardProtocolV1Name = "portforward.k8s.io" diff --git a/pkg/kubelet/server.go b/pkg/kubelet/server.go index c8b2df9d613..83b95eb0914 100644 --- a/pkg/kubelet/server.go +++ b/pkg/kubelet/server.go @@ -44,9 +44,11 @@ import ( "k8s.io/kubernetes/pkg/api/validation" "k8s.io/kubernetes/pkg/auth/authenticator" "k8s.io/kubernetes/pkg/auth/authorizer" + "k8s.io/kubernetes/pkg/client/unversioned/remotecommand" "k8s.io/kubernetes/pkg/healthz" "k8s.io/kubernetes/pkg/httplog" kubecontainer "k8s.io/kubernetes/pkg/kubelet/container" + "k8s.io/kubernetes/pkg/kubelet/portforward" "k8s.io/kubernetes/pkg/types" "k8s.io/kubernetes/pkg/util" "k8s.io/kubernetes/pkg/util/flushwriter" @@ -685,6 +687,13 @@ func (s *Server) createStreams(request *restful.Request, response *restful.Respo return streams[0], streams[1], streams[2], streams[3], conn, tty, true } + supportedStreamProtocols := []string{remotecommand.StreamProtocolV2Name, remotecommand.StreamProtocolV1Name} + _, err := httpstream.Handshake(request.Request, response.ResponseWriter, supportedStreamProtocols, remotecommand.StreamProtocolV1Name) + // negotiated protocol isn't used server side at the moment, but could be in the future + if err != nil { + return nil, nil, nil, nil, nil, false, false + } + streamCh := make(chan httpstream.Stream) upgrader := spdy.NewResponseUpgrader() @@ -779,6 +788,15 @@ func (s *Server) getPortForward(request *restful.Request, response *restful.Resp // connections; i.e., multiple `curl http://localhost:8888/` requests will be // handled by a single invocation of ServePortForward. func ServePortForward(w http.ResponseWriter, req *http.Request, portForwarder PortForwarder, podName string, uid types.UID, idleTimeout time.Duration, streamCreationTimeout time.Duration) { + supportedPortForwardProtocols := []string{portforward.PortForwardProtocolV1Name} + _, err := httpstream.Handshake(req, w, supportedPortForwardProtocols, portforward.PortForwardProtocolV1Name) + // negotiated protocol isn't currently used server side, but could be in the future + if err != nil { + // Handshake writes the error to the client + util.HandleError(err) + return + } + streamChan := make(chan httpstream.Stream, 1) glog.V(5).Infof("Upgrading port forward response") diff --git a/pkg/util/httpstream/httpstream.go b/pkg/util/httpstream/httpstream.go index fff6840f279..80c3cd78fc0 100644 --- a/pkg/util/httpstream/httpstream.go +++ b/pkg/util/httpstream/httpstream.go @@ -17,6 +17,7 @@ limitations under the License. package httpstream import ( + "fmt" "io" "net/http" "strings" @@ -24,8 +25,10 @@ import ( ) const ( - HeaderConnection = "Connection" - HeaderUpgrade = "Upgrade" + HeaderConnection = "Connection" + HeaderUpgrade = "Upgrade" + HeaderProtocolVersion = "X-Stream-Protocol-Version" + HeaderAcceptedProtocolVersions = "X-Accepted-Stream-Protocol-Versions" ) // NewStreamHandler defines a function that is called when a new Stream is @@ -39,7 +42,10 @@ func NoOpNewStreamHandler(stream Stream) error { return nil } // Dialer knows how to open a streaming connection to a server. type Dialer interface { - Dial() (Connection, error) + + // Dial opens a streaming connection to a server using one of the protocols + // specified (in order of most preferred to least preferred). + Dial(protocols ...string) (Connection, string, error) } // UpgradeRoundTripper is a type of http.RoundTripper that is able to upgrade @@ -56,7 +62,7 @@ type UpgradeRoundTripper interface { // add streaming support to them. type ResponseUpgrader interface { // UpgradeResponse upgrades an HTTP response to one that supports multiplexed - // streams. newStreamHandler will be called synchronously whenever the + // streams. newStreamHandler will be called asynchronously whenever the // other end of the upgraded connection creates a new stream. UpgradeResponse(w http.ResponseWriter, req *http.Request, newStreamHandler NewStreamHandler) Connection } @@ -96,3 +102,44 @@ func IsUpgradeRequest(req *http.Request) bool { } return false } + +func negotiateProtocol(clientProtocols, serverProtocols []string) string { + for i := range clientProtocols { + for j := range serverProtocols { + if clientProtocols[i] == serverProtocols[j] { + return clientProtocols[i] + } + } + } + return "" +} + +// Handshake performs a subprotocol negotiation. If the client did not request +// a specific subprotocol, defaultProtocol is used. If the client did request a +// subprotocol, Handshake will select the first common value found in +// serverProtocols. If a match is found, Handshake adds a response header +// indicating the chosen subprotocol. If no match is found, HTTP forbidden is +// returned, along with a response header containing the list of protocols the +// server can accept. +func Handshake(req *http.Request, w http.ResponseWriter, serverProtocols []string, defaultProtocol string) (string, error) { + clientProtocols := req.Header[http.CanonicalHeaderKey(HeaderProtocolVersion)] + if len(clientProtocols) == 0 { + // Kube 1.0 client that didn't support subprotocol negotiation + // TODO remove this defaulting logic once Kube 1.0 is no longer supported + w.Header().Add(HeaderProtocolVersion, defaultProtocol) + return defaultProtocol, nil + } + + negotiatedProtocol := negotiateProtocol(clientProtocols, serverProtocols) + if len(negotiatedProtocol) == 0 { + w.WriteHeader(http.StatusForbidden) + for i := range serverProtocols { + w.Header().Add(HeaderAcceptedProtocolVersions, serverProtocols[i]) + } + fmt.Fprintf(w, "unable to upgrade: unable to negotiate protocol: client supports %v, server accepts %v", clientProtocols, serverProtocols) + return "", fmt.Errorf("unable to upgrade: unable to negotiate protocol: client supports %v, server supports %v", clientProtocols, serverProtocols) + } + + w.Header().Add(HeaderProtocolVersion, negotiatedProtocol) + return negotiatedProtocol, nil +} diff --git a/pkg/util/httpstream/httpstream_test.go b/pkg/util/httpstream/httpstream_test.go new file mode 100644 index 00000000000..7a1bbaefb89 --- /dev/null +++ b/pkg/util/httpstream/httpstream_test.go @@ -0,0 +1,120 @@ +/* +Copyright 2015 The Kubernetes Authors All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package httpstream + +import ( + "net/http" + "reflect" + "testing" +) + +type responseWriter struct { + header http.Header + statusCode *int +} + +func newResponseWriter() *responseWriter { + return &responseWriter{ + header: make(http.Header), + } +} + +func (r *responseWriter) Header() http.Header { + return r.header +} + +func (r *responseWriter) WriteHeader(code int) { + r.statusCode = &code +} + +func (r *responseWriter) Write([]byte) (int, error) { + return 0, nil +} + +func TestHandshake(t *testing.T) { + defaultProtocol := "default" + + tests := map[string]struct { + clientProtocols []string + serverProtocols []string + expectedProtocol string + expectError bool + }{ + "no client protocols": { + clientProtocols: []string{}, + serverProtocols: []string{"a", "b"}, + expectedProtocol: defaultProtocol, + }, + "no common protocol": { + clientProtocols: []string{"c"}, + serverProtocols: []string{"a", "b"}, + expectedProtocol: "", + expectError: true, + }, + "common protocol": { + clientProtocols: []string{"b"}, + serverProtocols: []string{"a", "b"}, + expectedProtocol: "b", + }, + } + + for name, test := range tests { + req, err := http.NewRequest("GET", "http://www.example.com/", nil) + if err != nil { + t.Fatalf("%s: error creating request: %v", name, err) + } + + for _, p := range test.clientProtocols { + req.Header.Add(HeaderProtocolVersion, p) + } + + w := newResponseWriter() + negotiated, err := Handshake(req, w, test.serverProtocols, defaultProtocol) + + // verify negotiated protocol + if e, a := test.expectedProtocol, negotiated; e != a { + t.Errorf("%s: protocol: expected %q, got %q", name, e, a) + } + + if test.expectError { + if err == nil { + t.Errorf("%s: expected error but did not get one", name) + } + if w.statusCode == nil { + t.Errorf("%s: expected w.statusCode to be set", name) + } else if e, a := http.StatusForbidden, *w.statusCode; e != a { + t.Errorf("%s: w.statusCode: expected %d, got %d", name, e, a) + } + if e, a := test.serverProtocols, w.Header()[HeaderAcceptedProtocolVersions]; !reflect.DeepEqual(e, a) { + t.Errorf("%s: accepted server protocols: expected %v, got %v", name, e, a) + } + continue + } + if !test.expectError && err != nil { + t.Errorf("%s: unexpected error: %v", name, err) + continue + } + if w.statusCode != nil { + t.Errorf("%s: unexpected non-nil w.statusCode: %d", w.statusCode) + } + + // verify response headers + if e, a := []string{test.expectedProtocol}, w.Header()[HeaderProtocolVersion]; !reflect.DeepEqual(e, a) { + t.Errorf("%s: protocol response header: expected %v, got %v", name, e, a) + } + } +} diff --git a/pkg/util/httpstream/spdy/upgrade.go b/pkg/util/httpstream/spdy/upgrade.go index 70c90c8f499..4fd2a40521a 100644 --- a/pkg/util/httpstream/spdy/upgrade.go +++ b/pkg/util/httpstream/spdy/upgrade.go @@ -21,7 +21,7 @@ import ( "net/http" "strings" - "github.com/golang/glog" + "k8s.io/kubernetes/pkg/util" "k8s.io/kubernetes/pkg/util/httpstream" ) @@ -46,15 +46,15 @@ func (u responseUpgrader) UpgradeResponse(w http.ResponseWriter, req *http.Reque connectionHeader := strings.ToLower(req.Header.Get(httpstream.HeaderConnection)) upgradeHeader := strings.ToLower(req.Header.Get(httpstream.HeaderUpgrade)) if !strings.Contains(connectionHeader, strings.ToLower(httpstream.HeaderUpgrade)) || !strings.Contains(upgradeHeader, strings.ToLower(HeaderSpdy31)) { - w.Write([]byte(fmt.Sprintf("Unable to upgrade: missing upgrade headers in request: %#v", req.Header))) w.WriteHeader(http.StatusBadRequest) + fmt.Fprintf(w, "unable to upgrade: missing upgrade headers in request: %#v", req.Header) return nil } hijacker, ok := w.(http.Hijacker) if !ok { - w.Write([]byte("Unable to upgrade: unable to hijack response")) w.WriteHeader(http.StatusInternalServerError) + fmt.Fprintf(w, "unable to upgrade: unable to hijack response") return nil } @@ -64,13 +64,13 @@ func (u responseUpgrader) UpgradeResponse(w http.ResponseWriter, req *http.Reque conn, _, err := hijacker.Hijack() if err != nil { - glog.Errorf("Unable to upgrade: error hijacking response: %v", err) + util.HandleError(fmt.Errorf("unable to upgrade: error hijacking response: %v", err)) return nil } spdyConn, err := NewServerConnection(conn, newStreamHandler) if err != nil { - glog.Errorf("Unable to upgrade: error creating SPDY server connection: %v", err) + util.HandleError(fmt.Errorf("unable to upgrade: error creating SPDY server connection: %v", err)) return nil }