Merge pull request #15961 from ncdc/stream-protocol-negotiation

Auto commit by PR queue bot
This commit is contained in:
k8s-merge-robot 2015-10-25 07:26:41 -07:00
commit 4f17b4b39c
11 changed files with 592 additions and 151 deletions

View File

@ -29,6 +29,7 @@ import (
"github.com/golang/glog"
"k8s.io/kubernetes/pkg/api"
"k8s.io/kubernetes/pkg/kubelet/portforward"
"k8s.io/kubernetes/pkg/util"
"k8s.io/kubernetes/pkg/util/httpstream"
)
@ -128,7 +129,7 @@ func (pf *PortForwarder) ForwardPorts() error {
defer pf.Close()
var err error
pf.streamConn, err = pf.dialer.Dial()
pf.streamConn, _, err = pf.dialer.Dial(portforward.PortForwardProtocolV1Name)
if err != nil {
return fmt.Errorf("error upgrading connection: %s", err)
}

View File

@ -38,14 +38,15 @@ import (
)
type fakeDialer struct {
dialed bool
conn httpstream.Connection
err error
dialed bool
conn httpstream.Connection
err error
negotiatedProtocol string
}
func (d *fakeDialer) Dial() (httpstream.Connection, error) {
func (d *fakeDialer) Dial(protocols ...string) (httpstream.Connection, string, error) {
d.dialed = true
return d.conn, d.err
return d.conn, d.negotiatedProtocol, d.err
}
func TestParsePortsAndNew(t *testing.T) {

View File

@ -19,14 +19,12 @@ package remotecommand
import (
"fmt"
"io"
"io/ioutil"
"net/http"
"net/url"
"sync"
"k8s.io/kubernetes/pkg/api"
"github.com/golang/glog"
client "k8s.io/kubernetes/pkg/client/unversioned"
"k8s.io/kubernetes/pkg/util"
"k8s.io/kubernetes/pkg/util/httpstream"
"k8s.io/kubernetes/pkg/util/httpstream/spdy"
)
@ -97,155 +95,100 @@ func NewStreamExecutor(upgrader httpstream.UpgradeRoundTripper, fn func(http.Rou
}, nil
}
// Dial opens a connection to a remote server and attempts to negotiate a SPDY connection.
func (e *streamExecutor) Dial() (httpstream.Connection, error) {
client := &http.Client{Transport: e.transport}
// Dial opens a connection to a remote server and attempts to negotiate a SPDY
// connection. Upon success, it returns the connection and the protocol
// selected by the server.
func (e *streamExecutor) Dial(protocols ...string) (httpstream.Connection, string, error) {
transport := e.transport
// TODO consider removing this and reusing client.TransportFor above to get this for free
switch {
case bool(glog.V(9)):
transport = client.NewDebuggingRoundTripper(transport, client.CurlCommand, client.URLTiming, client.ResponseHeaders)
case bool(glog.V(8)):
transport = client.NewDebuggingRoundTripper(transport, client.JustURL, client.RequestHeaders, client.ResponseStatus, client.ResponseHeaders)
case bool(glog.V(7)):
transport = client.NewDebuggingRoundTripper(transport, client.JustURL, client.RequestHeaders, client.ResponseStatus)
case bool(glog.V(6)):
transport = client.NewDebuggingRoundTripper(transport, client.URLTiming)
}
// TODO the client probably shouldn't be created here, as it doesn't allow
// flexibility to allow callers to configure it.
client := &http.Client{Transport: transport}
req, err := http.NewRequest(e.method, e.url.String(), nil)
if err != nil {
return nil, fmt.Errorf("error creating request: %s", err)
return nil, "", fmt.Errorf("error creating request: %v", err)
}
for i := range protocols {
req.Header.Add(httpstream.HeaderProtocolVersion, protocols[i])
}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("error sending request: %s", err)
return nil, "", fmt.Errorf("error sending request: %v", err)
}
defer resp.Body.Close()
// TODO: handle protocol selection in the future
return e.upgrader.NewConnection(resp)
if resp.StatusCode != http.StatusSwitchingProtocols {
return nil, "", fmt.Errorf("unexpected response status code %d (%s)", resp.StatusCode, http.StatusText(resp.StatusCode))
}
conn, err := e.upgrader.NewConnection(resp)
if err != nil {
return nil, "", err
}
return conn, resp.Header.Get(httpstream.HeaderProtocolVersion), nil
}
const (
// The SPDY subprotocol "channel.k8s.io" is used for remote command
// attachment/execution. This represents the initial unversioned subprotocol,
// which has the known bugs http://issues.k8s.io/13394 and
// http://issues.k8s.io/13395.
StreamProtocolV1Name = "channel.k8s.io"
// The SPDY subprotocol "v2.channel.k8s.io" is used for remote command
// attachment/execution. It is the second version of the subprotocol and
// resolves the issues present in the first version.
StreamProtocolV2Name = "v2.channel.k8s.io"
)
type streamProtocolHandler interface {
stream(httpstream.Connection) error
}
// Stream opens a protocol streamer to the server and streams until a client closes
// the connection or the server disconnects.
func (e *streamExecutor) Stream(stdin io.Reader, stdout, stderr io.Writer, tty bool) error {
conn, err := e.Dial()
supportedProtocols := []string{StreamProtocolV2Name, StreamProtocolV1Name}
conn, protocol, err := e.Dial(supportedProtocols...)
if err != nil {
return err
}
defer conn.Close()
// TODO: negotiate protocols
streamer := &streamProtocol{
stdin: stdin,
stdout: stdout,
stderr: stderr,
tty: tty,
var streamer streamProtocolHandler
switch protocol {
case StreamProtocolV2Name:
streamer = &streamProtocolV2{
stdin: stdin,
stdout: stdout,
stderr: stderr,
tty: tty,
}
case "":
glog.V(4).Infof("The server did not negotiate a streaming protocol version. Falling back to %s", StreamProtocolV1Name)
fallthrough
case StreamProtocolV1Name:
streamer = &streamProtocolV1{
stdin: stdin,
stdout: stdout,
stderr: stderr,
tty: tty,
}
}
return streamer.stream(conn)
}
type streamProtocol struct {
stdin io.Reader
stdout io.Writer
stderr io.Writer
tty bool
}
func (e *streamProtocol) stream(conn httpstream.Connection) error {
headers := http.Header{}
// set up error stream
errorChan := make(chan error)
headers.Set(api.StreamType, api.StreamTypeError)
errorStream, err := conn.CreateStream(headers)
if err != nil {
return err
}
go func() {
message, err := ioutil.ReadAll(errorStream)
switch {
case err != nil && err != io.EOF:
errorChan <- fmt.Errorf("error reading from error stream: %s", err)
case len(message) > 0:
errorChan <- fmt.Errorf("error executing remote command: %s", message)
default:
errorChan <- nil
}
close(errorChan)
}()
var wg sync.WaitGroup
var once sync.Once
// set up stdin stream
if e.stdin != nil {
headers.Set(api.StreamType, api.StreamTypeStdin)
remoteStdin, err := conn.CreateStream(headers)
if err != nil {
return err
}
// copy from client's stdin to container's stdin
go func() {
// if e.stdin is noninteractive, e.g. `echo abc | kubectl exec -i <pod> -- cat`, make sure
// we close remoteStdin as soon as the copy from e.stdin to remoteStdin finishes. Otherwise
// the executed command will remain running.
defer once.Do(func() { remoteStdin.Close() })
if _, err := io.Copy(remoteStdin, e.stdin); err != nil {
util.HandleError(err)
}
}()
// read from remoteStdin until the stream is closed. this is essential to
// be able to exit interactive sessions cleanly and not leak goroutines or
// hang the client's terminal.
//
// go-dockerclient's current hijack implementation
// (https://github.com/fsouza/go-dockerclient/blob/89f3d56d93788dfe85f864a44f85d9738fca0670/client.go#L564)
// waits for all three streams (stdin/stdout/stderr) to finish copying
// before returning. When hijack finishes copying stdout/stderr, it calls
// Close() on its side of remoteStdin, which allows this copy to complete.
// When that happens, we must Close() on our side of remoteStdin, to
// allow the copy in hijack to complete, and hijack to return.
go func() {
defer once.Do(func() { remoteStdin.Close() })
// this "copy" doesn't actually read anything - it's just here to wait for
// the server to close remoteStdin.
if _, err := io.Copy(ioutil.Discard, remoteStdin); err != nil {
util.HandleError(err)
}
}()
}
// set up stdout stream
if e.stdout != nil {
headers.Set(api.StreamType, api.StreamTypeStdout)
remoteStdout, err := conn.CreateStream(headers)
if err != nil {
return err
}
wg.Add(1)
go func() {
defer wg.Done()
if _, err := io.Copy(e.stdout, remoteStdout); err != nil {
util.HandleError(err)
}
}()
}
// set up stderr stream
if e.stderr != nil && !e.tty {
headers.Set(api.StreamType, api.StreamTypeStderr)
remoteStderr, err := conn.CreateStream(headers)
if err != nil {
return err
}
wg.Add(1)
go func() {
defer wg.Done()
if _, err := io.Copy(e.stderr, remoteStderr); err != nil {
util.HandleError(err)
}
}()
}
// we're waiting for stdout/stderr to finish copying
wg.Wait()
// waits for errorStream to finish reading with an error or nil
return <-errorChan
}

View File

@ -42,6 +42,13 @@ func fakeExecServer(t *testing.T, i int, stdinData, stdoutData, stderrData, erro
}
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
protocol, err := httpstream.Handshake(req, w, []string{StreamProtocolV2Name}, StreamProtocolV1Name)
if err != nil {
t.Fatal(err)
}
if protocol != StreamProtocolV2Name {
t.Fatalf("unexpected protocol: %s", protocol)
}
streamCh := make(chan httpstream.Stream)
upgrader := spdy.NewResponseUpgrader()
@ -184,6 +191,7 @@ func TestRequestExecuteRemoteCommand(t *testing.T) {
url, _ := url.ParseRequestURI(server.URL)
c := client.NewRESTClient(url, "x", nil, -1, -1)
req := c.Post().Resource("testing")
req.SetHeader(httpstream.HeaderProtocolVersion, StreamProtocolV2Name)
req.Param("command", "ls")
req.Param("command", "/")
conf := &client.Config{
@ -347,7 +355,7 @@ func TestDial(t *testing.T) {
checkResponse: true,
conn: &fakeConnection{},
resp: &http.Response{
StatusCode: http.StatusOK,
StatusCode: http.StatusSwitchingProtocols,
Body: ioutil.NopCloser(&bytes.Buffer{}),
},
}
@ -363,7 +371,7 @@ func TestDial(t *testing.T) {
if err != nil {
t.Fatal(err)
}
conn, err := exec.Dial()
conn, protocol, err := exec.Dial("protocol1")
if err != nil {
t.Fatal(err)
}
@ -373,4 +381,5 @@ func TestDial(t *testing.T) {
if !called {
t.Errorf("wrapper not called")
}
_ = protocol
}

View File

@ -0,0 +1,130 @@
/*
Copyright 2015 The Kubernetes Authors All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package remotecommand
import (
"fmt"
"io"
"io/ioutil"
"net/http"
"github.com/golang/glog"
"k8s.io/kubernetes/pkg/api"
"k8s.io/kubernetes/pkg/util/httpstream"
)
// streamProtocolV1 implements the first version of the streaming exec & attach
// protocol. This version has some bugs, such as not being able to detecte when
// non-interactive stdin data has ended. See http://issues.k8s.io/13394 and
// http://issues.k8s.io/13395 for more details.
type streamProtocolV1 struct {
stdin io.Reader
stdout io.Writer
stderr io.Writer
tty bool
}
var _ streamProtocolHandler = &streamProtocolV1{}
func (e *streamProtocolV1) stream(conn httpstream.Connection) error {
doneChan := make(chan struct{}, 2)
errorChan := make(chan error)
cp := func(s string, dst io.Writer, src io.Reader) {
glog.V(6).Infof("Copying %s", s)
defer glog.V(6).Infof("Done copying %s", s)
if _, err := io.Copy(dst, src); err != nil && err != io.EOF {
glog.Errorf("Error copying %s: %v", s, err)
}
if s == api.StreamTypeStdout || s == api.StreamTypeStderr {
doneChan <- struct{}{}
}
}
headers := http.Header{}
headers.Set(api.StreamType, api.StreamTypeError)
errorStream, err := conn.CreateStream(headers)
if err != nil {
return err
}
go func() {
message, err := ioutil.ReadAll(errorStream)
if err != nil && err != io.EOF {
errorChan <- fmt.Errorf("Error reading from error stream: %s", err)
return
}
if len(message) > 0 {
errorChan <- fmt.Errorf("Error executing remote command: %s", message)
return
}
}()
defer errorStream.Reset()
if e.stdin != nil {
headers.Set(api.StreamType, api.StreamTypeStdin)
remoteStdin, err := conn.CreateStream(headers)
if err != nil {
return err
}
defer remoteStdin.Reset()
// TODO this goroutine will never exit cleanly (the io.Copy never unblocks)
// because stdin is not closed until the process exits. If we try to call
// stdin.Close(), it returns no error but doesn't unblock the copy. It will
// exit when the process exits, instead.
go cp(api.StreamTypeStdin, remoteStdin, e.stdin)
}
waitCount := 0
completedStreams := 0
if e.stdout != nil {
waitCount++
headers.Set(api.StreamType, api.StreamTypeStdout)
remoteStdout, err := conn.CreateStream(headers)
if err != nil {
return err
}
defer remoteStdout.Reset()
go cp(api.StreamTypeStdout, e.stdout, remoteStdout)
}
if e.stderr != nil && !e.tty {
waitCount++
headers.Set(api.StreamType, api.StreamTypeStderr)
remoteStderr, err := conn.CreateStream(headers)
if err != nil {
return err
}
defer remoteStderr.Reset()
go cp(api.StreamTypeStderr, e.stderr, remoteStderr)
}
Loop:
for {
select {
case <-doneChan:
completedStreams++
if completedStreams == waitCount {
break Loop
}
case err := <-errorChan:
return err
}
}
return nil
}

View File

@ -0,0 +1,151 @@
/*
Copyright 2015 The Kubernetes Authors All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package remotecommand
import (
"fmt"
"io"
"io/ioutil"
"net/http"
"sync"
"k8s.io/kubernetes/pkg/api"
"k8s.io/kubernetes/pkg/util"
"k8s.io/kubernetes/pkg/util/httpstream"
)
// streamProtocolV2 implements version 2 of the streaming protocol for attach
// and exec. The original streaming protocol was unversioned. As a result, this
// version is referred to as version 2, even though it is the first actual
// numbered version.
type streamProtocolV2 struct {
stdin io.Reader
stdout io.Writer
stderr io.Writer
tty bool
}
var _ streamProtocolHandler = &streamProtocolV2{}
func (e *streamProtocolV2) stream(conn httpstream.Connection) error {
headers := http.Header{}
// set up error stream
errorChan := make(chan error)
headers.Set(api.StreamType, api.StreamTypeError)
errorStream, err := conn.CreateStream(headers)
if err != nil {
return err
}
go func() {
message, err := ioutil.ReadAll(errorStream)
switch {
case err != nil && err != io.EOF:
errorChan <- fmt.Errorf("error reading from error stream: %s", err)
case len(message) > 0:
errorChan <- fmt.Errorf("error executing remote command: %s", message)
default:
errorChan <- nil
}
close(errorChan)
}()
var wg sync.WaitGroup
var once sync.Once
// set up stdin stream
if e.stdin != nil {
headers.Set(api.StreamType, api.StreamTypeStdin)
remoteStdin, err := conn.CreateStream(headers)
if err != nil {
return err
}
// copy from client's stdin to container's stdin
go func() {
// if e.stdin is noninteractive, e.g. `echo abc | kubectl exec -i <pod> -- cat`, make sure
// we close remoteStdin as soon as the copy from e.stdin to remoteStdin finishes. Otherwise
// the executed command will remain running.
defer once.Do(func() { remoteStdin.Close() })
if _, err := io.Copy(remoteStdin, e.stdin); err != nil {
util.HandleError(err)
}
}()
// read from remoteStdin until the stream is closed. this is essential to
// be able to exit interactive sessions cleanly and not leak goroutines or
// hang the client's terminal.
//
// go-dockerclient's current hijack implementation
// (https://github.com/fsouza/go-dockerclient/blob/89f3d56d93788dfe85f864a44f85d9738fca0670/client.go#L564)
// waits for all three streams (stdin/stdout/stderr) to finish copying
// before returning. When hijack finishes copying stdout/stderr, it calls
// Close() on its side of remoteStdin, which allows this copy to complete.
// When that happens, we must Close() on our side of remoteStdin, to
// allow the copy in hijack to complete, and hijack to return.
go func() {
defer once.Do(func() { remoteStdin.Close() })
// this "copy" doesn't actually read anything - it's just here to wait for
// the server to close remoteStdin.
if _, err := io.Copy(ioutil.Discard, remoteStdin); err != nil {
util.HandleError(err)
}
}()
}
// set up stdout stream
if e.stdout != nil {
headers.Set(api.StreamType, api.StreamTypeStdout)
remoteStdout, err := conn.CreateStream(headers)
if err != nil {
return err
}
wg.Add(1)
go func() {
defer wg.Done()
if _, err := io.Copy(e.stdout, remoteStdout); err != nil {
util.HandleError(err)
}
}()
}
// set up stderr stream
if e.stderr != nil && !e.tty {
headers.Set(api.StreamType, api.StreamTypeStderr)
remoteStderr, err := conn.CreateStream(headers)
if err != nil {
return err
}
wg.Add(1)
go func() {
defer wg.Done()
if _, err := io.Copy(e.stderr, remoteStderr); err != nil {
util.HandleError(err)
}
}()
}
// we're waiting for stdout/stderr to finish copying
wg.Wait()
// waits for errorStream to finish reading with an error or nil
return <-errorChan
}

View File

@ -0,0 +1,21 @@
/*
Copyright 2015 The Kubernetes Authors All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// package portforward contains server-side logic for handling port forwarding requests.
package portforward
// The subprotocol "portforward.k8s.io" is used for port forwarding.
const PortForwardProtocolV1Name = "portforward.k8s.io"

View File

@ -44,9 +44,11 @@ import (
"k8s.io/kubernetes/pkg/api/validation"
"k8s.io/kubernetes/pkg/auth/authenticator"
"k8s.io/kubernetes/pkg/auth/authorizer"
"k8s.io/kubernetes/pkg/client/unversioned/remotecommand"
"k8s.io/kubernetes/pkg/healthz"
"k8s.io/kubernetes/pkg/httplog"
kubecontainer "k8s.io/kubernetes/pkg/kubelet/container"
"k8s.io/kubernetes/pkg/kubelet/portforward"
"k8s.io/kubernetes/pkg/types"
"k8s.io/kubernetes/pkg/util"
"k8s.io/kubernetes/pkg/util/flushwriter"
@ -685,6 +687,13 @@ func (s *Server) createStreams(request *restful.Request, response *restful.Respo
return streams[0], streams[1], streams[2], streams[3], conn, tty, true
}
supportedStreamProtocols := []string{remotecommand.StreamProtocolV2Name, remotecommand.StreamProtocolV1Name}
_, err := httpstream.Handshake(request.Request, response.ResponseWriter, supportedStreamProtocols, remotecommand.StreamProtocolV1Name)
// negotiated protocol isn't used server side at the moment, but could be in the future
if err != nil {
return nil, nil, nil, nil, nil, false, false
}
streamCh := make(chan httpstream.Stream)
upgrader := spdy.NewResponseUpgrader()
@ -779,6 +788,15 @@ func (s *Server) getPortForward(request *restful.Request, response *restful.Resp
// connections; i.e., multiple `curl http://localhost:8888/` requests will be
// handled by a single invocation of ServePortForward.
func ServePortForward(w http.ResponseWriter, req *http.Request, portForwarder PortForwarder, podName string, uid types.UID, idleTimeout time.Duration, streamCreationTimeout time.Duration) {
supportedPortForwardProtocols := []string{portforward.PortForwardProtocolV1Name}
_, err := httpstream.Handshake(req, w, supportedPortForwardProtocols, portforward.PortForwardProtocolV1Name)
// negotiated protocol isn't currently used server side, but could be in the future
if err != nil {
// Handshake writes the error to the client
util.HandleError(err)
return
}
streamChan := make(chan httpstream.Stream, 1)
glog.V(5).Infof("Upgrading port forward response")

View File

@ -17,6 +17,7 @@ limitations under the License.
package httpstream
import (
"fmt"
"io"
"net/http"
"strings"
@ -24,8 +25,10 @@ import (
)
const (
HeaderConnection = "Connection"
HeaderUpgrade = "Upgrade"
HeaderConnection = "Connection"
HeaderUpgrade = "Upgrade"
HeaderProtocolVersion = "X-Stream-Protocol-Version"
HeaderAcceptedProtocolVersions = "X-Accepted-Stream-Protocol-Versions"
)
// NewStreamHandler defines a function that is called when a new Stream is
@ -39,7 +42,10 @@ func NoOpNewStreamHandler(stream Stream) error { return nil }
// Dialer knows how to open a streaming connection to a server.
type Dialer interface {
Dial() (Connection, error)
// Dial opens a streaming connection to a server using one of the protocols
// specified (in order of most preferred to least preferred).
Dial(protocols ...string) (Connection, string, error)
}
// UpgradeRoundTripper is a type of http.RoundTripper that is able to upgrade
@ -56,7 +62,7 @@ type UpgradeRoundTripper interface {
// add streaming support to them.
type ResponseUpgrader interface {
// UpgradeResponse upgrades an HTTP response to one that supports multiplexed
// streams. newStreamHandler will be called synchronously whenever the
// streams. newStreamHandler will be called asynchronously whenever the
// other end of the upgraded connection creates a new stream.
UpgradeResponse(w http.ResponseWriter, req *http.Request, newStreamHandler NewStreamHandler) Connection
}
@ -96,3 +102,44 @@ func IsUpgradeRequest(req *http.Request) bool {
}
return false
}
func negotiateProtocol(clientProtocols, serverProtocols []string) string {
for i := range clientProtocols {
for j := range serverProtocols {
if clientProtocols[i] == serverProtocols[j] {
return clientProtocols[i]
}
}
}
return ""
}
// Handshake performs a subprotocol negotiation. If the client did not request
// a specific subprotocol, defaultProtocol is used. If the client did request a
// subprotocol, Handshake will select the first common value found in
// serverProtocols. If a match is found, Handshake adds a response header
// indicating the chosen subprotocol. If no match is found, HTTP forbidden is
// returned, along with a response header containing the list of protocols the
// server can accept.
func Handshake(req *http.Request, w http.ResponseWriter, serverProtocols []string, defaultProtocol string) (string, error) {
clientProtocols := req.Header[http.CanonicalHeaderKey(HeaderProtocolVersion)]
if len(clientProtocols) == 0 {
// Kube 1.0 client that didn't support subprotocol negotiation
// TODO remove this defaulting logic once Kube 1.0 is no longer supported
w.Header().Add(HeaderProtocolVersion, defaultProtocol)
return defaultProtocol, nil
}
negotiatedProtocol := negotiateProtocol(clientProtocols, serverProtocols)
if len(negotiatedProtocol) == 0 {
w.WriteHeader(http.StatusForbidden)
for i := range serverProtocols {
w.Header().Add(HeaderAcceptedProtocolVersions, serverProtocols[i])
}
fmt.Fprintf(w, "unable to upgrade: unable to negotiate protocol: client supports %v, server accepts %v", clientProtocols, serverProtocols)
return "", fmt.Errorf("unable to upgrade: unable to negotiate protocol: client supports %v, server supports %v", clientProtocols, serverProtocols)
}
w.Header().Add(HeaderProtocolVersion, negotiatedProtocol)
return negotiatedProtocol, nil
}

View File

@ -0,0 +1,120 @@
/*
Copyright 2015 The Kubernetes Authors All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package httpstream
import (
"net/http"
"reflect"
"testing"
)
type responseWriter struct {
header http.Header
statusCode *int
}
func newResponseWriter() *responseWriter {
return &responseWriter{
header: make(http.Header),
}
}
func (r *responseWriter) Header() http.Header {
return r.header
}
func (r *responseWriter) WriteHeader(code int) {
r.statusCode = &code
}
func (r *responseWriter) Write([]byte) (int, error) {
return 0, nil
}
func TestHandshake(t *testing.T) {
defaultProtocol := "default"
tests := map[string]struct {
clientProtocols []string
serverProtocols []string
expectedProtocol string
expectError bool
}{
"no client protocols": {
clientProtocols: []string{},
serverProtocols: []string{"a", "b"},
expectedProtocol: defaultProtocol,
},
"no common protocol": {
clientProtocols: []string{"c"},
serverProtocols: []string{"a", "b"},
expectedProtocol: "",
expectError: true,
},
"common protocol": {
clientProtocols: []string{"b"},
serverProtocols: []string{"a", "b"},
expectedProtocol: "b",
},
}
for name, test := range tests {
req, err := http.NewRequest("GET", "http://www.example.com/", nil)
if err != nil {
t.Fatalf("%s: error creating request: %v", name, err)
}
for _, p := range test.clientProtocols {
req.Header.Add(HeaderProtocolVersion, p)
}
w := newResponseWriter()
negotiated, err := Handshake(req, w, test.serverProtocols, defaultProtocol)
// verify negotiated protocol
if e, a := test.expectedProtocol, negotiated; e != a {
t.Errorf("%s: protocol: expected %q, got %q", name, e, a)
}
if test.expectError {
if err == nil {
t.Errorf("%s: expected error but did not get one", name)
}
if w.statusCode == nil {
t.Errorf("%s: expected w.statusCode to be set", name)
} else if e, a := http.StatusForbidden, *w.statusCode; e != a {
t.Errorf("%s: w.statusCode: expected %d, got %d", name, e, a)
}
if e, a := test.serverProtocols, w.Header()[HeaderAcceptedProtocolVersions]; !reflect.DeepEqual(e, a) {
t.Errorf("%s: accepted server protocols: expected %v, got %v", name, e, a)
}
continue
}
if !test.expectError && err != nil {
t.Errorf("%s: unexpected error: %v", name, err)
continue
}
if w.statusCode != nil {
t.Errorf("%s: unexpected non-nil w.statusCode: %d", w.statusCode)
}
// verify response headers
if e, a := []string{test.expectedProtocol}, w.Header()[HeaderProtocolVersion]; !reflect.DeepEqual(e, a) {
t.Errorf("%s: protocol response header: expected %v, got %v", name, e, a)
}
}
}

View File

@ -21,7 +21,7 @@ import (
"net/http"
"strings"
"github.com/golang/glog"
"k8s.io/kubernetes/pkg/util"
"k8s.io/kubernetes/pkg/util/httpstream"
)
@ -46,15 +46,15 @@ func (u responseUpgrader) UpgradeResponse(w http.ResponseWriter, req *http.Reque
connectionHeader := strings.ToLower(req.Header.Get(httpstream.HeaderConnection))
upgradeHeader := strings.ToLower(req.Header.Get(httpstream.HeaderUpgrade))
if !strings.Contains(connectionHeader, strings.ToLower(httpstream.HeaderUpgrade)) || !strings.Contains(upgradeHeader, strings.ToLower(HeaderSpdy31)) {
w.Write([]byte(fmt.Sprintf("Unable to upgrade: missing upgrade headers in request: %#v", req.Header)))
w.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(w, "unable to upgrade: missing upgrade headers in request: %#v", req.Header)
return nil
}
hijacker, ok := w.(http.Hijacker)
if !ok {
w.Write([]byte("Unable to upgrade: unable to hijack response"))
w.WriteHeader(http.StatusInternalServerError)
fmt.Fprintf(w, "unable to upgrade: unable to hijack response")
return nil
}
@ -64,13 +64,13 @@ func (u responseUpgrader) UpgradeResponse(w http.ResponseWriter, req *http.Reque
conn, _, err := hijacker.Hijack()
if err != nil {
glog.Errorf("Unable to upgrade: error hijacking response: %v", err)
util.HandleError(fmt.Errorf("unable to upgrade: error hijacking response: %v", err))
return nil
}
spdyConn, err := NewServerConnection(conn, newStreamHandler)
if err != nil {
glog.Errorf("Unable to upgrade: error creating SPDY server connection: %v", err)
util.HandleError(fmt.Errorf("unable to upgrade: error creating SPDY server connection: %v", err))
return nil
}