Add protocol versions to pkg/util/wsstream

This commit is contained in:
bindata-mockuser 2016-08-04 18:39:12 +02:00 committed by Dr. Stefan Schimanski
parent 7b3c08d7d3
commit ce7f003f57
5 changed files with 302 additions and 80 deletions

View File

@ -451,7 +451,7 @@ func write(statusCode int, gv unversioned.GroupVersion, s runtime.NegotiatedSeri
defer out.Close()
if wsstream.IsWebSocketRequest(req) {
r := wsstream.NewReader(out, true)
r := wsstream.NewReader(out, true, wsstream.NewDefaultReaderProtocols())
if err := r.Copy(w, req); err != nil {
utilruntime.HandleError(fmt.Errorf("error encountered while streaming results via websocket: %v", err))
}

View File

@ -27,6 +27,7 @@ import (
"github.com/golang/glog"
"golang.org/x/net/websocket"
"k8s.io/kubernetes/pkg/util/runtime"
)
@ -44,7 +45,7 @@ import (
// READ []byte{1, 10} # receive "\n" on channel 1 (STDOUT)
// CLOSE
//
const channelWebSocketProtocol = "channel.k8s.io"
const ChannelWebSocketProtocol = "channel.k8s.io"
// The Websocket subprotocol "base64.channel.k8s.io" base64 encodes each message with a character
// indicating the channel number (zero indexed) the message was sent on. Messages in both directions
@ -60,7 +61,7 @@ const channelWebSocketProtocol = "channel.k8s.io"
// READ []byte{49, 67, 103, 61, 61} # receive "\n" (base64: "Cg==") on channel '1' (STDOUT)
// CLOSE
//
const base64ChannelWebSocketProtocol = "base64.channel.k8s.io"
const Base64ChannelWebSocketProtocol = "base64.channel.k8s.io"
type codecType int
@ -107,8 +108,9 @@ func IgnoreReceives(ws *websocket.Conn, timeout time.Duration) {
func handshake(config *websocket.Config, req *http.Request, allowed []string) error {
protocols := config.Protocol
if len(protocols) == 0 {
return nil
protocols = []string{""}
}
for _, protocol := range protocols {
for _, allow := range allowed {
if allow == protocol {
@ -117,41 +119,50 @@ func handshake(config *websocket.Config, req *http.Request, allowed []string) er
}
}
}
return fmt.Errorf("requested protocol(s) are not supported: %v; supports %v", config.Protocol, allowed)
}
// ChannelProtocolConfig describes a websocket subprotocol with channels.
type ChannelProtocolConfig struct {
Binary bool
Channels []ChannelType
}
// NewDefaultChannelProtocols returns a channel protocol map with the
// subprotocols "", "channel.k8s.io", "base64.channel.k8s.io" and the given
// channels.
func NewDefaultChannelProtocols(channels []ChannelType) map[string]ChannelProtocolConfig {
return map[string]ChannelProtocolConfig{
"": {Binary: true, Channels: channels},
ChannelWebSocketProtocol: {Binary: true, Channels: channels},
Base64ChannelWebSocketProtocol: {Binary: false, Channels: channels},
}
}
// Conn supports sending multiple binary channels over a websocket connection.
// Supports only the "channel.k8s.io" subprotocol.
type Conn struct {
channels []*websocketChannel
codec codecType
ready chan struct{}
ws *websocket.Conn
timeout time.Duration
protocols map[string]ChannelProtocolConfig
selectedProtocol string
channels []*websocketChannel
codec codecType
ready chan struct{}
ws *websocket.Conn
timeout time.Duration
}
// NewConn creates a WebSocket connection that supports a set of channels. Channels begin each
// web socket message with a single byte indicating the channel number (0-N). 255 is reserved for
// future use. The channel types for each channel are passed as an array, supporting the different
// duplex modes. Read and Write refer to whether the channel can be used as a Reader or Writer.
func NewConn(channels ...ChannelType) *Conn {
conn := &Conn{
ready: make(chan struct{}),
channels: make([]*websocketChannel, len(channels)),
//
// The protocols parameter maps subprotocol names to ChannelProtocols. The empty string subprotocol
// name is used if websocket.Config.Protocol is empty.
func NewConn(protocols map[string]ChannelProtocolConfig) *Conn {
return &Conn{
ready: make(chan struct{}),
protocols: protocols,
}
for i := range conn.channels {
switch channels[i] {
case ReadChannel:
conn.channels[i] = newWebsocketChannel(conn, byte(i), true, false)
case WriteChannel:
conn.channels[i] = newWebsocketChannel(conn, byte(i), false, true)
case ReadWriteChannel:
conn.channels[i] = newWebsocketChannel(conn, byte(i), true, true)
case IgnoreChannel:
conn.channels[i] = newWebsocketChannel(conn, byte(i), false, false)
}
}
return conn
}
// SetIdleTimeout sets the interval for both reads and writes before timeout. If not specified,
@ -160,8 +171,9 @@ func (conn *Conn) SetIdleTimeout(duration time.Duration) {
conn.timeout = duration
}
// Open the connection and create channels for reading and writing.
func (conn *Conn) Open(w http.ResponseWriter, req *http.Request) ([]io.ReadWriteCloser, error) {
// Open the connection and create channels for reading and writing. It returns
// the selected subprotocol, a slice of channels and an error.
func (conn *Conn) Open(w http.ResponseWriter, req *http.Request) (string, []io.ReadWriteCloser, error) {
go func() {
defer runtime.HandleCrash()
defer conn.Close()
@ -172,23 +184,42 @@ func (conn *Conn) Open(w http.ResponseWriter, req *http.Request) ([]io.ReadWrite
for i := range conn.channels {
rwc[i] = conn.channels[i]
}
return rwc, nil
return conn.selectedProtocol, rwc, nil
}
func (conn *Conn) initialize(ws *websocket.Conn) {
protocols := ws.Config().Protocol
switch {
case len(protocols) == 0, protocols[0] == channelWebSocketProtocol:
negotiated := ws.Config().Protocol
conn.selectedProtocol = negotiated[0]
p := conn.protocols[conn.selectedProtocol]
if p.Binary {
conn.codec = rawCodec
case protocols[0] == base64ChannelWebSocketProtocol:
} else {
conn.codec = base64Codec
}
conn.ws = ws
conn.channels = make([]*websocketChannel, len(p.Channels))
for i, t := range p.Channels {
switch t {
case ReadChannel:
conn.channels[i] = newWebsocketChannel(conn, byte(i), true, false)
case WriteChannel:
conn.channels[i] = newWebsocketChannel(conn, byte(i), false, true)
case ReadWriteChannel:
conn.channels[i] = newWebsocketChannel(conn, byte(i), true, true)
case IgnoreChannel:
conn.channels[i] = newWebsocketChannel(conn, byte(i), false, false)
}
}
close(conn.ready)
}
func (conn *Conn) handshake(config *websocket.Config, req *http.Request) error {
return handshake(config, req, []string{channelWebSocketProtocol, base64ChannelWebSocketProtocol})
supportedProtocols := make([]string, 0, len(conn.protocols))
for p := range conn.protocols {
supportedProtocols = append(supportedProtocols, p)
}
return handshake(config, req, supportedProtocols)
}
func (conn *Conn) resetTimeout() {

View File

@ -20,6 +20,7 @@ import (
"encoding/base64"
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"reflect"
"sync"
@ -28,15 +29,19 @@ import (
"golang.org/x/net/websocket"
)
func newServer(handler websocket.Handler) (*httptest.Server, string) {
func newServer(handler http.Handler) (*httptest.Server, string) {
server := httptest.NewServer(handler)
serverAddr := server.Listener.Addr().String()
return server, serverAddr
}
func TestRawConn(t *testing.T) {
conn := NewConn(ReadWriteChannel, ReadWriteChannel, IgnoreChannel, ReadChannel, WriteChannel)
s, addr := newServer(conn.handle)
channels := []ChannelType{ReadWriteChannel, ReadWriteChannel, IgnoreChannel, ReadChannel, WriteChannel}
conn := NewConn(NewDefaultChannelProtocols(channels))
s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
conn.Open(w, req)
}))
defer s.Close()
client, err := websocket.Dial("ws://"+addr, "", "http://localhost/")
@ -112,8 +117,10 @@ func TestRawConn(t *testing.T) {
}
func TestBase64Conn(t *testing.T) {
conn := NewConn(ReadWriteChannel, ReadWriteChannel)
s, addr := newServer(conn.handle)
conn := NewConn(NewDefaultChannelProtocols([]ChannelType{ReadWriteChannel, ReadWriteChannel}))
s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
conn.Open(w, req)
}))
defer s.Close()
config, err := websocket.NewConfig("ws://"+addr, "http://localhost/")
@ -167,3 +174,99 @@ func TestBase64Conn(t *testing.T) {
client.Close()
wg.Wait()
}
type versionTest struct {
supported map[string]bool // protocol -> binary
requested []string
error bool
expected string
}
func versionTests() []versionTest {
const (
binary = true
base64 = false
)
return []versionTest{
{
supported: nil,
requested: []string{"raw"},
error: true,
},
{
supported: map[string]bool{"": binary, "raw": binary, "base64": base64},
requested: nil,
expected: "",
},
{
supported: map[string]bool{"": binary, "raw": binary, "base64": base64},
requested: []string{"v1.raw"},
error: true,
},
{
supported: map[string]bool{"": binary, "raw": binary, "base64": base64},
requested: []string{"v1.raw", "v1.base64"},
error: true,
}, {
supported: map[string]bool{"": binary, "raw": binary, "base64": base64},
requested: []string{"v1.raw", "raw"},
expected: "raw",
},
{
supported: map[string]bool{"": binary, "v1.raw": binary, "v1.base64": base64, "v2.raw": binary, "v2.base64": base64},
requested: []string{"v1.raw"},
expected: "v1.raw",
},
{
supported: map[string]bool{"": binary, "v1.raw": binary, "v1.base64": base64, "v2.raw": binary, "v2.base64": base64},
requested: []string{"v2.base64"},
expected: "v2.base64",
},
}
}
func TestVersionedConn(t *testing.T) {
for i, test := range versionTests() {
func() {
supportedProtocols := map[string]ChannelProtocolConfig{}
for p, binary := range test.supported {
supportedProtocols[p] = ChannelProtocolConfig{
Binary: binary,
Channels: []ChannelType{ReadWriteChannel},
}
}
conn := NewConn(supportedProtocols)
// note that it's not enough to wait for conn.ready to avoid a race here. Hence,
// we use a channel.
selectedProtocol := make(chan string, 0)
s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
p, _, _ := conn.Open(w, req)
selectedProtocol <- p
}))
defer s.Close()
config, err := websocket.NewConfig("ws://"+addr, "http://localhost/")
if err != nil {
t.Fatal(err)
}
config.Protocol = test.requested
client, err := websocket.DialConfig(config)
if err != nil {
if !test.error {
t.Fatalf("test %d: didn't expect error: %v", i, err)
} else {
return
}
}
defer client.Close()
if test.error && err == nil {
t.Fatalf("test %d: expected an error", i)
}
<-conn.ready
if got, expected := <-selectedProtocol, test.expected; got != expected {
t.Fatalf("test %d: unexpected protocol version: got=%s expected=%s", i, got, expected)
}
}()
}
}

View File

@ -23,6 +23,7 @@ import (
"time"
"golang.org/x/net/websocket"
"k8s.io/kubernetes/pkg/util/runtime"
)
@ -37,23 +38,46 @@ const binaryWebSocketProtocol = "binary.k8s.io"
// possible.
const base64BinaryWebSocketProtocol = "base64.binary.k8s.io"
// ReaderProtocolConfig describes a websocket subprotocol with one stream.
type ReaderProtocolConfig struct {
Binary bool
}
// NewDefaultReaderProtocols returns a stream protocol map with the
// subprotocols "", "channel.k8s.io", "base64.channel.k8s.io".
func NewDefaultReaderProtocols() map[string]ReaderProtocolConfig {
return map[string]ReaderProtocolConfig{
"": {Binary: true},
binaryWebSocketProtocol: {Binary: true},
base64BinaryWebSocketProtocol: {Binary: false},
}
}
// Reader supports returning an arbitrary byte stream over a websocket channel.
// Supports the "binary.k8s.io" and "base64.binary.k8s.io" subprotocols.
type Reader struct {
err chan error
r io.Reader
ping bool
timeout time.Duration
err chan error
r io.Reader
ping bool
timeout time.Duration
protocols map[string]ReaderProtocolConfig
selectedProtocol string
handleCrash func() // overridable for testing
}
// NewReader creates a WebSocket pipe that will copy the contents of r to a provided
// WebSocket connection. If ping is true, a zero length message will be sent to the client
// before the stream begins reading.
func NewReader(r io.Reader, ping bool) *Reader {
//
// The protocols parameter maps subprotocol names to StreamProtocols. The empty string
// subprotocol name is used if websocket.Config.Protocol is empty.
func NewReader(r io.Reader, ping bool, protocols map[string]ReaderProtocolConfig) *Reader {
return &Reader{
r: r,
err: make(chan error),
ping: ping,
r: r,
err: make(chan error),
ping: ping,
protocols: protocols,
handleCrash: func() { runtime.HandleCrash() },
}
}
@ -64,14 +88,18 @@ func (r *Reader) SetIdleTimeout(duration time.Duration) {
}
func (r *Reader) handshake(config *websocket.Config, req *http.Request) error {
return handshake(config, req, []string{binaryWebSocketProtocol, base64BinaryWebSocketProtocol})
supportedProtocols := make([]string, 0, len(r.protocols))
for p := range r.protocols {
supportedProtocols = append(supportedProtocols, p)
}
return handshake(config, req, supportedProtocols)
}
// Copy the reader to the response. The created WebSocket is closed after this
// method completes.
func (r *Reader) Copy(w http.ResponseWriter, req *http.Request) error {
go func() {
defer runtime.HandleCrash()
defer r.handleCrash()
websocket.Server{Handshake: r.handshake, Handler: r.handle}.ServeHTTP(w, req)
}()
return <-r.err
@ -79,11 +107,12 @@ func (r *Reader) Copy(w http.ResponseWriter, req *http.Request) error {
// handle implements a WebSocket handler.
func (r *Reader) handle(ws *websocket.Conn) {
encode := len(ws.Config().Protocol) > 0 && ws.Config().Protocol[0] == base64BinaryWebSocketProtocol
negotiated := ws.Config().Protocol
r.selectedProtocol = negotiated[0]
defer close(r.err)
defer ws.Close()
go IgnoreReceives(ws, r.timeout)
r.err <- messageCopy(ws, r.r, encode, r.ping, r.timeout)
r.err <- messageCopy(ws, r.r, !r.protocols[r.selectedProtocol].Binary, r.ping, r.timeout)
}
func resetTimeout(ws *websocket.Conn, timeout time.Duration) {

View File

@ -22,6 +22,7 @@ import (
"fmt"
"io"
"io/ioutil"
"net/http"
"reflect"
"strings"
"testing"
@ -32,7 +33,7 @@ import (
func TestStream(t *testing.T) {
input := "some random text"
r := NewReader(bytes.NewBuffer([]byte(input)), true)
r := NewReader(bytes.NewBuffer([]byte(input)), true, NewDefaultReaderProtocols())
r.SetIdleTimeout(time.Second)
data, err := readWebSocket(r, t, nil)
if !reflect.DeepEqual(data, []byte(input)) {
@ -45,7 +46,7 @@ func TestStream(t *testing.T) {
func TestStreamPing(t *testing.T) {
input := "some random text"
r := NewReader(bytes.NewBuffer([]byte(input)), true)
r := NewReader(bytes.NewBuffer([]byte(input)), true, NewDefaultReaderProtocols())
r.SetIdleTimeout(time.Second)
err := expectWebSocketFrames(r, t, nil, [][]byte{
{},
@ -59,8 +60,8 @@ func TestStreamPing(t *testing.T) {
func TestStreamBase64(t *testing.T) {
input := "some random text"
encoded := base64.StdEncoding.EncodeToString([]byte(input))
r := NewReader(bytes.NewBuffer([]byte(input)), true)
data, err := readWebSocket(r, t, nil, base64BinaryWebSocketProtocol)
r := NewReader(bytes.NewBuffer([]byte(input)), true, NewDefaultReaderProtocols())
data, err := readWebSocket(r, t, nil, "base64.binary.k8s.io")
if !reflect.DeepEqual(data, []byte(encoded)) {
t.Errorf("unexpected server read: %v\n%v", data, []byte(encoded))
}
@ -69,6 +70,73 @@ func TestStreamBase64(t *testing.T) {
}
}
func TestStreamVersionedBase64(t *testing.T) {
input := "some random text"
encoded := base64.StdEncoding.EncodeToString([]byte(input))
r := NewReader(bytes.NewBuffer([]byte(input)), true, map[string]ReaderProtocolConfig{
"": {Binary: true},
"binary.k8s.io": {Binary: true},
"base64.binary.k8s.io": {Binary: false},
"v1.binary.k8s.io": {Binary: true},
"v1.base64.binary.k8s.io": {Binary: false},
"v2.binary.k8s.io": {Binary: true},
"v2.base64.binary.k8s.io": {Binary: false},
})
data, err := readWebSocket(r, t, nil, "v2.base64.binary.k8s.io")
if !reflect.DeepEqual(data, []byte(encoded)) {
t.Errorf("unexpected server read: %v\n%v", data, []byte(encoded))
}
if err != nil {
t.Fatal(err)
}
}
func TestStreamVersionedCopy(t *testing.T) {
for i, test := range versionTests() {
func() {
supportedProtocols := map[string]ReaderProtocolConfig{}
for p, binary := range test.supported {
supportedProtocols[p] = ReaderProtocolConfig{
Binary: binary,
}
}
input := "some random text"
r := NewReader(bytes.NewBuffer([]byte(input)), true, supportedProtocols)
s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
err := r.Copy(w, req)
if err != nil {
w.WriteHeader(503)
}
}))
defer s.Close()
config, err := websocket.NewConfig("ws://"+addr, "http://localhost/")
if err != nil {
t.Error(err)
return
}
config.Protocol = test.requested
client, err := websocket.DialConfig(config)
if err != nil {
if !test.error {
t.Errorf("test %d: didn't expect error: %v", i, err)
}
return
}
defer client.Close()
if test.error && err == nil {
t.Errorf("test %d: expected an error", i)
return
}
<-r.err
if got, expected := r.selectedProtocol, test.expected; got != expected {
t.Errorf("test %d: unexpected protocol version: got=%s expected=%s", i, got, expected)
}
}()
}
}
func TestStreamError(t *testing.T) {
input := "some random text"
errs := &errorReader{
@ -78,7 +146,7 @@ func TestStreamError(t *testing.T) {
},
err: fmt.Errorf("bad read"),
}
r := NewReader(errs, false)
r := NewReader(errs, false, NewDefaultReaderProtocols())
data, err := readWebSocket(r, t, nil)
if !reflect.DeepEqual(data, []byte(input)) {
@ -98,7 +166,10 @@ func TestStreamSurvivesPanic(t *testing.T) {
},
panicMessage: "bad read",
}
r := NewReader(errs, false)
r := NewReader(errs, false, NewDefaultReaderProtocols())
// do not call runtime.HandleCrash() in handler. Otherwise, the tests are interrupted.
r.handleCrash = func() { recover() }
data, err := readWebSocket(r, t, nil)
if !reflect.DeepEqual(data, []byte(input)) {
@ -121,7 +192,7 @@ func TestStreamClosedDuringRead(t *testing.T) {
err: fmt.Errorf("stuff"),
pause: ch,
}
r := NewReader(errs, false)
r := NewReader(errs, false, NewDefaultReaderProtocols())
data, err := readWebSocket(r, t, func(c *websocket.Conn) {
c.Close()
@ -163,19 +234,13 @@ func (r *errorReader) Read(p []byte) (int, error) {
func readWebSocket(r *Reader, t *testing.T, fn func(*websocket.Conn), protocols ...string) ([]byte, error) {
errCh := make(chan error, 1)
s, addr := newServer(func(ws *websocket.Conn) {
cfg := ws.Config()
cfg.Protocol = protocols
go IgnoreReceives(ws, 0)
go func() {
err := <-r.err
errCh <- err
}()
r.handle(ws)
})
s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
errCh <- r.Copy(w, req)
}))
defer s.Close()
config, _ := websocket.NewConfig("ws://"+addr, "http://"+addr)
config.Protocol = protocols
client, err := websocket.DialConfig(config)
if err != nil {
return nil, err
@ -195,19 +260,13 @@ func readWebSocket(r *Reader, t *testing.T, fn func(*websocket.Conn), protocols
func expectWebSocketFrames(r *Reader, t *testing.T, fn func(*websocket.Conn), frames [][]byte, protocols ...string) error {
errCh := make(chan error, 1)
s, addr := newServer(func(ws *websocket.Conn) {
cfg := ws.Config()
cfg.Protocol = protocols
go IgnoreReceives(ws, 0)
go func() {
err := <-r.err
errCh <- err
}()
r.handle(ws)
})
s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
errCh <- r.Copy(w, req)
}))
defer s.Close()
config, _ := websocket.NewConfig("ws://"+addr, "http://"+addr)
config.Protocol = protocols
ws, err := websocket.DialConfig(config)
if err != nil {
return err