diff --git a/pkg/apiserver/apiserver.go b/pkg/apiserver/apiserver.go index 1d4b07c1680..54839a058d5 100644 --- a/pkg/apiserver/apiserver.go +++ b/pkg/apiserver/apiserver.go @@ -451,7 +451,7 @@ func write(statusCode int, gv unversioned.GroupVersion, s runtime.NegotiatedSeri defer out.Close() if wsstream.IsWebSocketRequest(req) { - r := wsstream.NewReader(out, true) + r := wsstream.NewReader(out, true, wsstream.NewDefaultReaderProtocols()) if err := r.Copy(w, req); err != nil { utilruntime.HandleError(fmt.Errorf("error encountered while streaming results via websocket: %v", err)) } diff --git a/pkg/util/wsstream/conn.go b/pkg/util/wsstream/conn.go index eb0b4b28bb1..94b75568eec 100644 --- a/pkg/util/wsstream/conn.go +++ b/pkg/util/wsstream/conn.go @@ -27,6 +27,7 @@ import ( "github.com/golang/glog" "golang.org/x/net/websocket" + "k8s.io/kubernetes/pkg/util/runtime" ) @@ -44,7 +45,7 @@ import ( // READ []byte{1, 10} # receive "\n" on channel 1 (STDOUT) // CLOSE // -const channelWebSocketProtocol = "channel.k8s.io" +const ChannelWebSocketProtocol = "channel.k8s.io" // The Websocket subprotocol "base64.channel.k8s.io" base64 encodes each message with a character // indicating the channel number (zero indexed) the message was sent on. Messages in both directions @@ -60,7 +61,7 @@ const channelWebSocketProtocol = "channel.k8s.io" // READ []byte{49, 67, 103, 61, 61} # receive "\n" (base64: "Cg==") on channel '1' (STDOUT) // CLOSE // -const base64ChannelWebSocketProtocol = "base64.channel.k8s.io" +const Base64ChannelWebSocketProtocol = "base64.channel.k8s.io" type codecType int @@ -107,8 +108,9 @@ func IgnoreReceives(ws *websocket.Conn, timeout time.Duration) { func handshake(config *websocket.Config, req *http.Request, allowed []string) error { protocols := config.Protocol if len(protocols) == 0 { - return nil + protocols = []string{""} } + for _, protocol := range protocols { for _, allow := range allowed { if allow == protocol { @@ -117,41 +119,50 @@ func handshake(config *websocket.Config, req *http.Request, allowed []string) er } } } + return fmt.Errorf("requested protocol(s) are not supported: %v; supports %v", config.Protocol, allowed) } +// ChannelProtocolConfig describes a websocket subprotocol with channels. +type ChannelProtocolConfig struct { + Binary bool + Channels []ChannelType +} + +// NewDefaultChannelProtocols returns a channel protocol map with the +// subprotocols "", "channel.k8s.io", "base64.channel.k8s.io" and the given +// channels. +func NewDefaultChannelProtocols(channels []ChannelType) map[string]ChannelProtocolConfig { + return map[string]ChannelProtocolConfig{ + "": {Binary: true, Channels: channels}, + ChannelWebSocketProtocol: {Binary: true, Channels: channels}, + Base64ChannelWebSocketProtocol: {Binary: false, Channels: channels}, + } +} + // Conn supports sending multiple binary channels over a websocket connection. -// Supports only the "channel.k8s.io" subprotocol. type Conn struct { - channels []*websocketChannel - codec codecType - ready chan struct{} - ws *websocket.Conn - timeout time.Duration + protocols map[string]ChannelProtocolConfig + selectedProtocol string + channels []*websocketChannel + codec codecType + ready chan struct{} + ws *websocket.Conn + timeout time.Duration } // NewConn creates a WebSocket connection that supports a set of channels. Channels begin each // web socket message with a single byte indicating the channel number (0-N). 255 is reserved for // future use. The channel types for each channel are passed as an array, supporting the different // duplex modes. Read and Write refer to whether the channel can be used as a Reader or Writer. -func NewConn(channels ...ChannelType) *Conn { - conn := &Conn{ - ready: make(chan struct{}), - channels: make([]*websocketChannel, len(channels)), +// +// The protocols parameter maps subprotocol names to ChannelProtocols. The empty string subprotocol +// name is used if websocket.Config.Protocol is empty. +func NewConn(protocols map[string]ChannelProtocolConfig) *Conn { + return &Conn{ + ready: make(chan struct{}), + protocols: protocols, } - for i := range conn.channels { - switch channels[i] { - case ReadChannel: - conn.channels[i] = newWebsocketChannel(conn, byte(i), true, false) - case WriteChannel: - conn.channels[i] = newWebsocketChannel(conn, byte(i), false, true) - case ReadWriteChannel: - conn.channels[i] = newWebsocketChannel(conn, byte(i), true, true) - case IgnoreChannel: - conn.channels[i] = newWebsocketChannel(conn, byte(i), false, false) - } - } - return conn } // SetIdleTimeout sets the interval for both reads and writes before timeout. If not specified, @@ -160,8 +171,9 @@ func (conn *Conn) SetIdleTimeout(duration time.Duration) { conn.timeout = duration } -// Open the connection and create channels for reading and writing. -func (conn *Conn) Open(w http.ResponseWriter, req *http.Request) ([]io.ReadWriteCloser, error) { +// Open the connection and create channels for reading and writing. It returns +// the selected subprotocol, a slice of channels and an error. +func (conn *Conn) Open(w http.ResponseWriter, req *http.Request) (string, []io.ReadWriteCloser, error) { go func() { defer runtime.HandleCrash() defer conn.Close() @@ -172,23 +184,42 @@ func (conn *Conn) Open(w http.ResponseWriter, req *http.Request) ([]io.ReadWrite for i := range conn.channels { rwc[i] = conn.channels[i] } - return rwc, nil + return conn.selectedProtocol, rwc, nil } func (conn *Conn) initialize(ws *websocket.Conn) { - protocols := ws.Config().Protocol - switch { - case len(protocols) == 0, protocols[0] == channelWebSocketProtocol: + negotiated := ws.Config().Protocol + conn.selectedProtocol = negotiated[0] + p := conn.protocols[conn.selectedProtocol] + if p.Binary { conn.codec = rawCodec - case protocols[0] == base64ChannelWebSocketProtocol: + } else { conn.codec = base64Codec } conn.ws = ws + conn.channels = make([]*websocketChannel, len(p.Channels)) + for i, t := range p.Channels { + switch t { + case ReadChannel: + conn.channels[i] = newWebsocketChannel(conn, byte(i), true, false) + case WriteChannel: + conn.channels[i] = newWebsocketChannel(conn, byte(i), false, true) + case ReadWriteChannel: + conn.channels[i] = newWebsocketChannel(conn, byte(i), true, true) + case IgnoreChannel: + conn.channels[i] = newWebsocketChannel(conn, byte(i), false, false) + } + } + close(conn.ready) } func (conn *Conn) handshake(config *websocket.Config, req *http.Request) error { - return handshake(config, req, []string{channelWebSocketProtocol, base64ChannelWebSocketProtocol}) + supportedProtocols := make([]string, 0, len(conn.protocols)) + for p := range conn.protocols { + supportedProtocols = append(supportedProtocols, p) + } + return handshake(config, req, supportedProtocols) } func (conn *Conn) resetTimeout() { diff --git a/pkg/util/wsstream/conn_test.go b/pkg/util/wsstream/conn_test.go index 88c84bf162f..1c049aad7a2 100644 --- a/pkg/util/wsstream/conn_test.go +++ b/pkg/util/wsstream/conn_test.go @@ -20,6 +20,7 @@ import ( "encoding/base64" "io" "io/ioutil" + "net/http" "net/http/httptest" "reflect" "sync" @@ -28,15 +29,19 @@ import ( "golang.org/x/net/websocket" ) -func newServer(handler websocket.Handler) (*httptest.Server, string) { +func newServer(handler http.Handler) (*httptest.Server, string) { server := httptest.NewServer(handler) serverAddr := server.Listener.Addr().String() return server, serverAddr } func TestRawConn(t *testing.T) { - conn := NewConn(ReadWriteChannel, ReadWriteChannel, IgnoreChannel, ReadChannel, WriteChannel) - s, addr := newServer(conn.handle) + channels := []ChannelType{ReadWriteChannel, ReadWriteChannel, IgnoreChannel, ReadChannel, WriteChannel} + conn := NewConn(NewDefaultChannelProtocols(channels)) + + s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + conn.Open(w, req) + })) defer s.Close() client, err := websocket.Dial("ws://"+addr, "", "http://localhost/") @@ -112,8 +117,10 @@ func TestRawConn(t *testing.T) { } func TestBase64Conn(t *testing.T) { - conn := NewConn(ReadWriteChannel, ReadWriteChannel) - s, addr := newServer(conn.handle) + conn := NewConn(NewDefaultChannelProtocols([]ChannelType{ReadWriteChannel, ReadWriteChannel})) + s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + conn.Open(w, req) + })) defer s.Close() config, err := websocket.NewConfig("ws://"+addr, "http://localhost/") @@ -167,3 +174,99 @@ func TestBase64Conn(t *testing.T) { client.Close() wg.Wait() } + +type versionTest struct { + supported map[string]bool // protocol -> binary + requested []string + error bool + expected string +} + +func versionTests() []versionTest { + const ( + binary = true + base64 = false + ) + return []versionTest{ + { + supported: nil, + requested: []string{"raw"}, + error: true, + }, + { + supported: map[string]bool{"": binary, "raw": binary, "base64": base64}, + requested: nil, + expected: "", + }, + { + supported: map[string]bool{"": binary, "raw": binary, "base64": base64}, + requested: []string{"v1.raw"}, + error: true, + }, + { + supported: map[string]bool{"": binary, "raw": binary, "base64": base64}, + requested: []string{"v1.raw", "v1.base64"}, + error: true, + }, { + supported: map[string]bool{"": binary, "raw": binary, "base64": base64}, + requested: []string{"v1.raw", "raw"}, + expected: "raw", + }, + { + supported: map[string]bool{"": binary, "v1.raw": binary, "v1.base64": base64, "v2.raw": binary, "v2.base64": base64}, + requested: []string{"v1.raw"}, + expected: "v1.raw", + }, + { + supported: map[string]bool{"": binary, "v1.raw": binary, "v1.base64": base64, "v2.raw": binary, "v2.base64": base64}, + requested: []string{"v2.base64"}, + expected: "v2.base64", + }, + } +} + +func TestVersionedConn(t *testing.T) { + for i, test := range versionTests() { + func() { + supportedProtocols := map[string]ChannelProtocolConfig{} + for p, binary := range test.supported { + supportedProtocols[p] = ChannelProtocolConfig{ + Binary: binary, + Channels: []ChannelType{ReadWriteChannel}, + } + } + conn := NewConn(supportedProtocols) + // note that it's not enough to wait for conn.ready to avoid a race here. Hence, + // we use a channel. + selectedProtocol := make(chan string, 0) + s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + p, _, _ := conn.Open(w, req) + selectedProtocol <- p + })) + defer s.Close() + + config, err := websocket.NewConfig("ws://"+addr, "http://localhost/") + if err != nil { + t.Fatal(err) + } + config.Protocol = test.requested + client, err := websocket.DialConfig(config) + if err != nil { + if !test.error { + t.Fatalf("test %d: didn't expect error: %v", i, err) + } else { + return + } + } + defer client.Close() + if test.error && err == nil { + t.Fatalf("test %d: expected an error", i) + } + + <-conn.ready + if got, expected := <-selectedProtocol, test.expected; got != expected { + t.Fatalf("test %d: unexpected protocol version: got=%s expected=%s", i, got, expected) + } + }() + } +} diff --git a/pkg/util/wsstream/stream.go b/pkg/util/wsstream/stream.go index d16d99817e3..2a8326cd8ab 100644 --- a/pkg/util/wsstream/stream.go +++ b/pkg/util/wsstream/stream.go @@ -23,6 +23,7 @@ import ( "time" "golang.org/x/net/websocket" + "k8s.io/kubernetes/pkg/util/runtime" ) @@ -37,23 +38,46 @@ const binaryWebSocketProtocol = "binary.k8s.io" // possible. const base64BinaryWebSocketProtocol = "base64.binary.k8s.io" +// ReaderProtocolConfig describes a websocket subprotocol with one stream. +type ReaderProtocolConfig struct { + Binary bool +} + +// NewDefaultReaderProtocols returns a stream protocol map with the +// subprotocols "", "channel.k8s.io", "base64.channel.k8s.io". +func NewDefaultReaderProtocols() map[string]ReaderProtocolConfig { + return map[string]ReaderProtocolConfig{ + "": {Binary: true}, + binaryWebSocketProtocol: {Binary: true}, + base64BinaryWebSocketProtocol: {Binary: false}, + } +} + // Reader supports returning an arbitrary byte stream over a websocket channel. -// Supports the "binary.k8s.io" and "base64.binary.k8s.io" subprotocols. type Reader struct { - err chan error - r io.Reader - ping bool - timeout time.Duration + err chan error + r io.Reader + ping bool + timeout time.Duration + protocols map[string]ReaderProtocolConfig + selectedProtocol string + + handleCrash func() // overridable for testing } // NewReader creates a WebSocket pipe that will copy the contents of r to a provided // WebSocket connection. If ping is true, a zero length message will be sent to the client // before the stream begins reading. -func NewReader(r io.Reader, ping bool) *Reader { +// +// The protocols parameter maps subprotocol names to StreamProtocols. The empty string +// subprotocol name is used if websocket.Config.Protocol is empty. +func NewReader(r io.Reader, ping bool, protocols map[string]ReaderProtocolConfig) *Reader { return &Reader{ - r: r, - err: make(chan error), - ping: ping, + r: r, + err: make(chan error), + ping: ping, + protocols: protocols, + handleCrash: func() { runtime.HandleCrash() }, } } @@ -64,14 +88,18 @@ func (r *Reader) SetIdleTimeout(duration time.Duration) { } func (r *Reader) handshake(config *websocket.Config, req *http.Request) error { - return handshake(config, req, []string{binaryWebSocketProtocol, base64BinaryWebSocketProtocol}) + supportedProtocols := make([]string, 0, len(r.protocols)) + for p := range r.protocols { + supportedProtocols = append(supportedProtocols, p) + } + return handshake(config, req, supportedProtocols) } // Copy the reader to the response. The created WebSocket is closed after this // method completes. func (r *Reader) Copy(w http.ResponseWriter, req *http.Request) error { go func() { - defer runtime.HandleCrash() + defer r.handleCrash() websocket.Server{Handshake: r.handshake, Handler: r.handle}.ServeHTTP(w, req) }() return <-r.err @@ -79,11 +107,12 @@ func (r *Reader) Copy(w http.ResponseWriter, req *http.Request) error { // handle implements a WebSocket handler. func (r *Reader) handle(ws *websocket.Conn) { - encode := len(ws.Config().Protocol) > 0 && ws.Config().Protocol[0] == base64BinaryWebSocketProtocol + negotiated := ws.Config().Protocol + r.selectedProtocol = negotiated[0] defer close(r.err) defer ws.Close() go IgnoreReceives(ws, r.timeout) - r.err <- messageCopy(ws, r.r, encode, r.ping, r.timeout) + r.err <- messageCopy(ws, r.r, !r.protocols[r.selectedProtocol].Binary, r.ping, r.timeout) } func resetTimeout(ws *websocket.Conn, timeout time.Duration) { diff --git a/pkg/util/wsstream/stream_test.go b/pkg/util/wsstream/stream_test.go index aa37b2ddfe5..09dda761f8c 100644 --- a/pkg/util/wsstream/stream_test.go +++ b/pkg/util/wsstream/stream_test.go @@ -22,6 +22,7 @@ import ( "fmt" "io" "io/ioutil" + "net/http" "reflect" "strings" "testing" @@ -32,7 +33,7 @@ import ( func TestStream(t *testing.T) { input := "some random text" - r := NewReader(bytes.NewBuffer([]byte(input)), true) + r := NewReader(bytes.NewBuffer([]byte(input)), true, NewDefaultReaderProtocols()) r.SetIdleTimeout(time.Second) data, err := readWebSocket(r, t, nil) if !reflect.DeepEqual(data, []byte(input)) { @@ -45,7 +46,7 @@ func TestStream(t *testing.T) { func TestStreamPing(t *testing.T) { input := "some random text" - r := NewReader(bytes.NewBuffer([]byte(input)), true) + r := NewReader(bytes.NewBuffer([]byte(input)), true, NewDefaultReaderProtocols()) r.SetIdleTimeout(time.Second) err := expectWebSocketFrames(r, t, nil, [][]byte{ {}, @@ -59,8 +60,8 @@ func TestStreamPing(t *testing.T) { func TestStreamBase64(t *testing.T) { input := "some random text" encoded := base64.StdEncoding.EncodeToString([]byte(input)) - r := NewReader(bytes.NewBuffer([]byte(input)), true) - data, err := readWebSocket(r, t, nil, base64BinaryWebSocketProtocol) + r := NewReader(bytes.NewBuffer([]byte(input)), true, NewDefaultReaderProtocols()) + data, err := readWebSocket(r, t, nil, "base64.binary.k8s.io") if !reflect.DeepEqual(data, []byte(encoded)) { t.Errorf("unexpected server read: %v\n%v", data, []byte(encoded)) } @@ -69,6 +70,73 @@ func TestStreamBase64(t *testing.T) { } } +func TestStreamVersionedBase64(t *testing.T) { + input := "some random text" + encoded := base64.StdEncoding.EncodeToString([]byte(input)) + r := NewReader(bytes.NewBuffer([]byte(input)), true, map[string]ReaderProtocolConfig{ + "": {Binary: true}, + "binary.k8s.io": {Binary: true}, + "base64.binary.k8s.io": {Binary: false}, + "v1.binary.k8s.io": {Binary: true}, + "v1.base64.binary.k8s.io": {Binary: false}, + "v2.binary.k8s.io": {Binary: true}, + "v2.base64.binary.k8s.io": {Binary: false}, + }) + data, err := readWebSocket(r, t, nil, "v2.base64.binary.k8s.io") + if !reflect.DeepEqual(data, []byte(encoded)) { + t.Errorf("unexpected server read: %v\n%v", data, []byte(encoded)) + } + if err != nil { + t.Fatal(err) + } +} + +func TestStreamVersionedCopy(t *testing.T) { + for i, test := range versionTests() { + func() { + supportedProtocols := map[string]ReaderProtocolConfig{} + for p, binary := range test.supported { + supportedProtocols[p] = ReaderProtocolConfig{ + Binary: binary, + } + } + input := "some random text" + r := NewReader(bytes.NewBuffer([]byte(input)), true, supportedProtocols) + s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + err := r.Copy(w, req) + if err != nil { + w.WriteHeader(503) + } + })) + defer s.Close() + + config, err := websocket.NewConfig("ws://"+addr, "http://localhost/") + if err != nil { + t.Error(err) + return + } + config.Protocol = test.requested + client, err := websocket.DialConfig(config) + if err != nil { + if !test.error { + t.Errorf("test %d: didn't expect error: %v", i, err) + } + return + } + defer client.Close() + if test.error && err == nil { + t.Errorf("test %d: expected an error", i) + return + } + + <-r.err + if got, expected := r.selectedProtocol, test.expected; got != expected { + t.Errorf("test %d: unexpected protocol version: got=%s expected=%s", i, got, expected) + } + }() + } +} + func TestStreamError(t *testing.T) { input := "some random text" errs := &errorReader{ @@ -78,7 +146,7 @@ func TestStreamError(t *testing.T) { }, err: fmt.Errorf("bad read"), } - r := NewReader(errs, false) + r := NewReader(errs, false, NewDefaultReaderProtocols()) data, err := readWebSocket(r, t, nil) if !reflect.DeepEqual(data, []byte(input)) { @@ -98,7 +166,10 @@ func TestStreamSurvivesPanic(t *testing.T) { }, panicMessage: "bad read", } - r := NewReader(errs, false) + r := NewReader(errs, false, NewDefaultReaderProtocols()) + + // do not call runtime.HandleCrash() in handler. Otherwise, the tests are interrupted. + r.handleCrash = func() { recover() } data, err := readWebSocket(r, t, nil) if !reflect.DeepEqual(data, []byte(input)) { @@ -121,7 +192,7 @@ func TestStreamClosedDuringRead(t *testing.T) { err: fmt.Errorf("stuff"), pause: ch, } - r := NewReader(errs, false) + r := NewReader(errs, false, NewDefaultReaderProtocols()) data, err := readWebSocket(r, t, func(c *websocket.Conn) { c.Close() @@ -163,19 +234,13 @@ func (r *errorReader) Read(p []byte) (int, error) { func readWebSocket(r *Reader, t *testing.T, fn func(*websocket.Conn), protocols ...string) ([]byte, error) { errCh := make(chan error, 1) - s, addr := newServer(func(ws *websocket.Conn) { - cfg := ws.Config() - cfg.Protocol = protocols - go IgnoreReceives(ws, 0) - go func() { - err := <-r.err - errCh <- err - }() - r.handle(ws) - }) + s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + errCh <- r.Copy(w, req) + })) defer s.Close() config, _ := websocket.NewConfig("ws://"+addr, "http://"+addr) + config.Protocol = protocols client, err := websocket.DialConfig(config) if err != nil { return nil, err @@ -195,19 +260,13 @@ func readWebSocket(r *Reader, t *testing.T, fn func(*websocket.Conn), protocols func expectWebSocketFrames(r *Reader, t *testing.T, fn func(*websocket.Conn), frames [][]byte, protocols ...string) error { errCh := make(chan error, 1) - s, addr := newServer(func(ws *websocket.Conn) { - cfg := ws.Config() - cfg.Protocol = protocols - go IgnoreReceives(ws, 0) - go func() { - err := <-r.err - errCh <- err - }() - r.handle(ws) - }) + s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + errCh <- r.Copy(w, req) + })) defer s.Close() config, _ := websocket.NewConfig("ws://"+addr, "http://"+addr) + config.Protocol = protocols ws, err := websocket.DialConfig(config) if err != nil { return err