Add protocol versions to pkg/util/wsstream

This commit is contained in:
bindata-mockuser 2016-08-04 18:39:12 +02:00 committed by Dr. Stefan Schimanski
parent 7b3c08d7d3
commit ce7f003f57
5 changed files with 302 additions and 80 deletions

View File

@ -451,7 +451,7 @@ func write(statusCode int, gv unversioned.GroupVersion, s runtime.NegotiatedSeri
defer out.Close() defer out.Close()
if wsstream.IsWebSocketRequest(req) { if wsstream.IsWebSocketRequest(req) {
r := wsstream.NewReader(out, true) r := wsstream.NewReader(out, true, wsstream.NewDefaultReaderProtocols())
if err := r.Copy(w, req); err != nil { if err := r.Copy(w, req); err != nil {
utilruntime.HandleError(fmt.Errorf("error encountered while streaming results via websocket: %v", err)) utilruntime.HandleError(fmt.Errorf("error encountered while streaming results via websocket: %v", err))
} }

View File

@ -27,6 +27,7 @@ import (
"github.com/golang/glog" "github.com/golang/glog"
"golang.org/x/net/websocket" "golang.org/x/net/websocket"
"k8s.io/kubernetes/pkg/util/runtime" "k8s.io/kubernetes/pkg/util/runtime"
) )
@ -44,7 +45,7 @@ import (
// READ []byte{1, 10} # receive "\n" on channel 1 (STDOUT) // READ []byte{1, 10} # receive "\n" on channel 1 (STDOUT)
// CLOSE // CLOSE
// //
const channelWebSocketProtocol = "channel.k8s.io" const ChannelWebSocketProtocol = "channel.k8s.io"
// The Websocket subprotocol "base64.channel.k8s.io" base64 encodes each message with a character // The Websocket subprotocol "base64.channel.k8s.io" base64 encodes each message with a character
// indicating the channel number (zero indexed) the message was sent on. Messages in both directions // indicating the channel number (zero indexed) the message was sent on. Messages in both directions
@ -60,7 +61,7 @@ const channelWebSocketProtocol = "channel.k8s.io"
// READ []byte{49, 67, 103, 61, 61} # receive "\n" (base64: "Cg==") on channel '1' (STDOUT) // READ []byte{49, 67, 103, 61, 61} # receive "\n" (base64: "Cg==") on channel '1' (STDOUT)
// CLOSE // CLOSE
// //
const base64ChannelWebSocketProtocol = "base64.channel.k8s.io" const Base64ChannelWebSocketProtocol = "base64.channel.k8s.io"
type codecType int type codecType int
@ -107,8 +108,9 @@ func IgnoreReceives(ws *websocket.Conn, timeout time.Duration) {
func handshake(config *websocket.Config, req *http.Request, allowed []string) error { func handshake(config *websocket.Config, req *http.Request, allowed []string) error {
protocols := config.Protocol protocols := config.Protocol
if len(protocols) == 0 { if len(protocols) == 0 {
return nil protocols = []string{""}
} }
for _, protocol := range protocols { for _, protocol := range protocols {
for _, allow := range allowed { for _, allow := range allowed {
if allow == protocol { if allow == protocol {
@ -117,41 +119,50 @@ func handshake(config *websocket.Config, req *http.Request, allowed []string) er
} }
} }
} }
return fmt.Errorf("requested protocol(s) are not supported: %v; supports %v", config.Protocol, allowed) return fmt.Errorf("requested protocol(s) are not supported: %v; supports %v", config.Protocol, allowed)
} }
// ChannelProtocolConfig describes a websocket subprotocol with channels.
type ChannelProtocolConfig struct {
Binary bool
Channels []ChannelType
}
// NewDefaultChannelProtocols returns a channel protocol map with the
// subprotocols "", "channel.k8s.io", "base64.channel.k8s.io" and the given
// channels.
func NewDefaultChannelProtocols(channels []ChannelType) map[string]ChannelProtocolConfig {
return map[string]ChannelProtocolConfig{
"": {Binary: true, Channels: channels},
ChannelWebSocketProtocol: {Binary: true, Channels: channels},
Base64ChannelWebSocketProtocol: {Binary: false, Channels: channels},
}
}
// Conn supports sending multiple binary channels over a websocket connection. // Conn supports sending multiple binary channels over a websocket connection.
// Supports only the "channel.k8s.io" subprotocol.
type Conn struct { type Conn struct {
channels []*websocketChannel protocols map[string]ChannelProtocolConfig
codec codecType selectedProtocol string
ready chan struct{} channels []*websocketChannel
ws *websocket.Conn codec codecType
timeout time.Duration ready chan struct{}
ws *websocket.Conn
timeout time.Duration
} }
// NewConn creates a WebSocket connection that supports a set of channels. Channels begin each // NewConn creates a WebSocket connection that supports a set of channels. Channels begin each
// web socket message with a single byte indicating the channel number (0-N). 255 is reserved for // web socket message with a single byte indicating the channel number (0-N). 255 is reserved for
// future use. The channel types for each channel are passed as an array, supporting the different // future use. The channel types for each channel are passed as an array, supporting the different
// duplex modes. Read and Write refer to whether the channel can be used as a Reader or Writer. // duplex modes. Read and Write refer to whether the channel can be used as a Reader or Writer.
func NewConn(channels ...ChannelType) *Conn { //
conn := &Conn{ // The protocols parameter maps subprotocol names to ChannelProtocols. The empty string subprotocol
ready: make(chan struct{}), // name is used if websocket.Config.Protocol is empty.
channels: make([]*websocketChannel, len(channels)), func NewConn(protocols map[string]ChannelProtocolConfig) *Conn {
return &Conn{
ready: make(chan struct{}),
protocols: protocols,
} }
for i := range conn.channels {
switch channels[i] {
case ReadChannel:
conn.channels[i] = newWebsocketChannel(conn, byte(i), true, false)
case WriteChannel:
conn.channels[i] = newWebsocketChannel(conn, byte(i), false, true)
case ReadWriteChannel:
conn.channels[i] = newWebsocketChannel(conn, byte(i), true, true)
case IgnoreChannel:
conn.channels[i] = newWebsocketChannel(conn, byte(i), false, false)
}
}
return conn
} }
// SetIdleTimeout sets the interval for both reads and writes before timeout. If not specified, // SetIdleTimeout sets the interval for both reads and writes before timeout. If not specified,
@ -160,8 +171,9 @@ func (conn *Conn) SetIdleTimeout(duration time.Duration) {
conn.timeout = duration conn.timeout = duration
} }
// Open the connection and create channels for reading and writing. // Open the connection and create channels for reading and writing. It returns
func (conn *Conn) Open(w http.ResponseWriter, req *http.Request) ([]io.ReadWriteCloser, error) { // the selected subprotocol, a slice of channels and an error.
func (conn *Conn) Open(w http.ResponseWriter, req *http.Request) (string, []io.ReadWriteCloser, error) {
go func() { go func() {
defer runtime.HandleCrash() defer runtime.HandleCrash()
defer conn.Close() defer conn.Close()
@ -172,23 +184,42 @@ func (conn *Conn) Open(w http.ResponseWriter, req *http.Request) ([]io.ReadWrite
for i := range conn.channels { for i := range conn.channels {
rwc[i] = conn.channels[i] rwc[i] = conn.channels[i]
} }
return rwc, nil return conn.selectedProtocol, rwc, nil
} }
func (conn *Conn) initialize(ws *websocket.Conn) { func (conn *Conn) initialize(ws *websocket.Conn) {
protocols := ws.Config().Protocol negotiated := ws.Config().Protocol
switch { conn.selectedProtocol = negotiated[0]
case len(protocols) == 0, protocols[0] == channelWebSocketProtocol: p := conn.protocols[conn.selectedProtocol]
if p.Binary {
conn.codec = rawCodec conn.codec = rawCodec
case protocols[0] == base64ChannelWebSocketProtocol: } else {
conn.codec = base64Codec conn.codec = base64Codec
} }
conn.ws = ws conn.ws = ws
conn.channels = make([]*websocketChannel, len(p.Channels))
for i, t := range p.Channels {
switch t {
case ReadChannel:
conn.channels[i] = newWebsocketChannel(conn, byte(i), true, false)
case WriteChannel:
conn.channels[i] = newWebsocketChannel(conn, byte(i), false, true)
case ReadWriteChannel:
conn.channels[i] = newWebsocketChannel(conn, byte(i), true, true)
case IgnoreChannel:
conn.channels[i] = newWebsocketChannel(conn, byte(i), false, false)
}
}
close(conn.ready) close(conn.ready)
} }
func (conn *Conn) handshake(config *websocket.Config, req *http.Request) error { func (conn *Conn) handshake(config *websocket.Config, req *http.Request) error {
return handshake(config, req, []string{channelWebSocketProtocol, base64ChannelWebSocketProtocol}) supportedProtocols := make([]string, 0, len(conn.protocols))
for p := range conn.protocols {
supportedProtocols = append(supportedProtocols, p)
}
return handshake(config, req, supportedProtocols)
} }
func (conn *Conn) resetTimeout() { func (conn *Conn) resetTimeout() {

View File

@ -20,6 +20,7 @@ import (
"encoding/base64" "encoding/base64"
"io" "io"
"io/ioutil" "io/ioutil"
"net/http"
"net/http/httptest" "net/http/httptest"
"reflect" "reflect"
"sync" "sync"
@ -28,15 +29,19 @@ import (
"golang.org/x/net/websocket" "golang.org/x/net/websocket"
) )
func newServer(handler websocket.Handler) (*httptest.Server, string) { func newServer(handler http.Handler) (*httptest.Server, string) {
server := httptest.NewServer(handler) server := httptest.NewServer(handler)
serverAddr := server.Listener.Addr().String() serverAddr := server.Listener.Addr().String()
return server, serverAddr return server, serverAddr
} }
func TestRawConn(t *testing.T) { func TestRawConn(t *testing.T) {
conn := NewConn(ReadWriteChannel, ReadWriteChannel, IgnoreChannel, ReadChannel, WriteChannel) channels := []ChannelType{ReadWriteChannel, ReadWriteChannel, IgnoreChannel, ReadChannel, WriteChannel}
s, addr := newServer(conn.handle) conn := NewConn(NewDefaultChannelProtocols(channels))
s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
conn.Open(w, req)
}))
defer s.Close() defer s.Close()
client, err := websocket.Dial("ws://"+addr, "", "http://localhost/") client, err := websocket.Dial("ws://"+addr, "", "http://localhost/")
@ -112,8 +117,10 @@ func TestRawConn(t *testing.T) {
} }
func TestBase64Conn(t *testing.T) { func TestBase64Conn(t *testing.T) {
conn := NewConn(ReadWriteChannel, ReadWriteChannel) conn := NewConn(NewDefaultChannelProtocols([]ChannelType{ReadWriteChannel, ReadWriteChannel}))
s, addr := newServer(conn.handle) s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
conn.Open(w, req)
}))
defer s.Close() defer s.Close()
config, err := websocket.NewConfig("ws://"+addr, "http://localhost/") config, err := websocket.NewConfig("ws://"+addr, "http://localhost/")
@ -167,3 +174,99 @@ func TestBase64Conn(t *testing.T) {
client.Close() client.Close()
wg.Wait() wg.Wait()
} }
type versionTest struct {
supported map[string]bool // protocol -> binary
requested []string
error bool
expected string
}
func versionTests() []versionTest {
const (
binary = true
base64 = false
)
return []versionTest{
{
supported: nil,
requested: []string{"raw"},
error: true,
},
{
supported: map[string]bool{"": binary, "raw": binary, "base64": base64},
requested: nil,
expected: "",
},
{
supported: map[string]bool{"": binary, "raw": binary, "base64": base64},
requested: []string{"v1.raw"},
error: true,
},
{
supported: map[string]bool{"": binary, "raw": binary, "base64": base64},
requested: []string{"v1.raw", "v1.base64"},
error: true,
}, {
supported: map[string]bool{"": binary, "raw": binary, "base64": base64},
requested: []string{"v1.raw", "raw"},
expected: "raw",
},
{
supported: map[string]bool{"": binary, "v1.raw": binary, "v1.base64": base64, "v2.raw": binary, "v2.base64": base64},
requested: []string{"v1.raw"},
expected: "v1.raw",
},
{
supported: map[string]bool{"": binary, "v1.raw": binary, "v1.base64": base64, "v2.raw": binary, "v2.base64": base64},
requested: []string{"v2.base64"},
expected: "v2.base64",
},
}
}
func TestVersionedConn(t *testing.T) {
for i, test := range versionTests() {
func() {
supportedProtocols := map[string]ChannelProtocolConfig{}
for p, binary := range test.supported {
supportedProtocols[p] = ChannelProtocolConfig{
Binary: binary,
Channels: []ChannelType{ReadWriteChannel},
}
}
conn := NewConn(supportedProtocols)
// note that it's not enough to wait for conn.ready to avoid a race here. Hence,
// we use a channel.
selectedProtocol := make(chan string, 0)
s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
p, _, _ := conn.Open(w, req)
selectedProtocol <- p
}))
defer s.Close()
config, err := websocket.NewConfig("ws://"+addr, "http://localhost/")
if err != nil {
t.Fatal(err)
}
config.Protocol = test.requested
client, err := websocket.DialConfig(config)
if err != nil {
if !test.error {
t.Fatalf("test %d: didn't expect error: %v", i, err)
} else {
return
}
}
defer client.Close()
if test.error && err == nil {
t.Fatalf("test %d: expected an error", i)
}
<-conn.ready
if got, expected := <-selectedProtocol, test.expected; got != expected {
t.Fatalf("test %d: unexpected protocol version: got=%s expected=%s", i, got, expected)
}
}()
}
}

View File

@ -23,6 +23,7 @@ import (
"time" "time"
"golang.org/x/net/websocket" "golang.org/x/net/websocket"
"k8s.io/kubernetes/pkg/util/runtime" "k8s.io/kubernetes/pkg/util/runtime"
) )
@ -37,23 +38,46 @@ const binaryWebSocketProtocol = "binary.k8s.io"
// possible. // possible.
const base64BinaryWebSocketProtocol = "base64.binary.k8s.io" const base64BinaryWebSocketProtocol = "base64.binary.k8s.io"
// ReaderProtocolConfig describes a websocket subprotocol with one stream.
type ReaderProtocolConfig struct {
Binary bool
}
// NewDefaultReaderProtocols returns a stream protocol map with the
// subprotocols "", "channel.k8s.io", "base64.channel.k8s.io".
func NewDefaultReaderProtocols() map[string]ReaderProtocolConfig {
return map[string]ReaderProtocolConfig{
"": {Binary: true},
binaryWebSocketProtocol: {Binary: true},
base64BinaryWebSocketProtocol: {Binary: false},
}
}
// Reader supports returning an arbitrary byte stream over a websocket channel. // Reader supports returning an arbitrary byte stream over a websocket channel.
// Supports the "binary.k8s.io" and "base64.binary.k8s.io" subprotocols.
type Reader struct { type Reader struct {
err chan error err chan error
r io.Reader r io.Reader
ping bool ping bool
timeout time.Duration timeout time.Duration
protocols map[string]ReaderProtocolConfig
selectedProtocol string
handleCrash func() // overridable for testing
} }
// NewReader creates a WebSocket pipe that will copy the contents of r to a provided // NewReader creates a WebSocket pipe that will copy the contents of r to a provided
// WebSocket connection. If ping is true, a zero length message will be sent to the client // WebSocket connection. If ping is true, a zero length message will be sent to the client
// before the stream begins reading. // before the stream begins reading.
func NewReader(r io.Reader, ping bool) *Reader { //
// The protocols parameter maps subprotocol names to StreamProtocols. The empty string
// subprotocol name is used if websocket.Config.Protocol is empty.
func NewReader(r io.Reader, ping bool, protocols map[string]ReaderProtocolConfig) *Reader {
return &Reader{ return &Reader{
r: r, r: r,
err: make(chan error), err: make(chan error),
ping: ping, ping: ping,
protocols: protocols,
handleCrash: func() { runtime.HandleCrash() },
} }
} }
@ -64,14 +88,18 @@ func (r *Reader) SetIdleTimeout(duration time.Duration) {
} }
func (r *Reader) handshake(config *websocket.Config, req *http.Request) error { func (r *Reader) handshake(config *websocket.Config, req *http.Request) error {
return handshake(config, req, []string{binaryWebSocketProtocol, base64BinaryWebSocketProtocol}) supportedProtocols := make([]string, 0, len(r.protocols))
for p := range r.protocols {
supportedProtocols = append(supportedProtocols, p)
}
return handshake(config, req, supportedProtocols)
} }
// Copy the reader to the response. The created WebSocket is closed after this // Copy the reader to the response. The created WebSocket is closed after this
// method completes. // method completes.
func (r *Reader) Copy(w http.ResponseWriter, req *http.Request) error { func (r *Reader) Copy(w http.ResponseWriter, req *http.Request) error {
go func() { go func() {
defer runtime.HandleCrash() defer r.handleCrash()
websocket.Server{Handshake: r.handshake, Handler: r.handle}.ServeHTTP(w, req) websocket.Server{Handshake: r.handshake, Handler: r.handle}.ServeHTTP(w, req)
}() }()
return <-r.err return <-r.err
@ -79,11 +107,12 @@ func (r *Reader) Copy(w http.ResponseWriter, req *http.Request) error {
// handle implements a WebSocket handler. // handle implements a WebSocket handler.
func (r *Reader) handle(ws *websocket.Conn) { func (r *Reader) handle(ws *websocket.Conn) {
encode := len(ws.Config().Protocol) > 0 && ws.Config().Protocol[0] == base64BinaryWebSocketProtocol negotiated := ws.Config().Protocol
r.selectedProtocol = negotiated[0]
defer close(r.err) defer close(r.err)
defer ws.Close() defer ws.Close()
go IgnoreReceives(ws, r.timeout) go IgnoreReceives(ws, r.timeout)
r.err <- messageCopy(ws, r.r, encode, r.ping, r.timeout) r.err <- messageCopy(ws, r.r, !r.protocols[r.selectedProtocol].Binary, r.ping, r.timeout)
} }
func resetTimeout(ws *websocket.Conn, timeout time.Duration) { func resetTimeout(ws *websocket.Conn, timeout time.Duration) {

View File

@ -22,6 +22,7 @@ import (
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"net/http"
"reflect" "reflect"
"strings" "strings"
"testing" "testing"
@ -32,7 +33,7 @@ import (
func TestStream(t *testing.T) { func TestStream(t *testing.T) {
input := "some random text" input := "some random text"
r := NewReader(bytes.NewBuffer([]byte(input)), true) r := NewReader(bytes.NewBuffer([]byte(input)), true, NewDefaultReaderProtocols())
r.SetIdleTimeout(time.Second) r.SetIdleTimeout(time.Second)
data, err := readWebSocket(r, t, nil) data, err := readWebSocket(r, t, nil)
if !reflect.DeepEqual(data, []byte(input)) { if !reflect.DeepEqual(data, []byte(input)) {
@ -45,7 +46,7 @@ func TestStream(t *testing.T) {
func TestStreamPing(t *testing.T) { func TestStreamPing(t *testing.T) {
input := "some random text" input := "some random text"
r := NewReader(bytes.NewBuffer([]byte(input)), true) r := NewReader(bytes.NewBuffer([]byte(input)), true, NewDefaultReaderProtocols())
r.SetIdleTimeout(time.Second) r.SetIdleTimeout(time.Second)
err := expectWebSocketFrames(r, t, nil, [][]byte{ err := expectWebSocketFrames(r, t, nil, [][]byte{
{}, {},
@ -59,8 +60,8 @@ func TestStreamPing(t *testing.T) {
func TestStreamBase64(t *testing.T) { func TestStreamBase64(t *testing.T) {
input := "some random text" input := "some random text"
encoded := base64.StdEncoding.EncodeToString([]byte(input)) encoded := base64.StdEncoding.EncodeToString([]byte(input))
r := NewReader(bytes.NewBuffer([]byte(input)), true) r := NewReader(bytes.NewBuffer([]byte(input)), true, NewDefaultReaderProtocols())
data, err := readWebSocket(r, t, nil, base64BinaryWebSocketProtocol) data, err := readWebSocket(r, t, nil, "base64.binary.k8s.io")
if !reflect.DeepEqual(data, []byte(encoded)) { if !reflect.DeepEqual(data, []byte(encoded)) {
t.Errorf("unexpected server read: %v\n%v", data, []byte(encoded)) t.Errorf("unexpected server read: %v\n%v", data, []byte(encoded))
} }
@ -69,6 +70,73 @@ func TestStreamBase64(t *testing.T) {
} }
} }
func TestStreamVersionedBase64(t *testing.T) {
input := "some random text"
encoded := base64.StdEncoding.EncodeToString([]byte(input))
r := NewReader(bytes.NewBuffer([]byte(input)), true, map[string]ReaderProtocolConfig{
"": {Binary: true},
"binary.k8s.io": {Binary: true},
"base64.binary.k8s.io": {Binary: false},
"v1.binary.k8s.io": {Binary: true},
"v1.base64.binary.k8s.io": {Binary: false},
"v2.binary.k8s.io": {Binary: true},
"v2.base64.binary.k8s.io": {Binary: false},
})
data, err := readWebSocket(r, t, nil, "v2.base64.binary.k8s.io")
if !reflect.DeepEqual(data, []byte(encoded)) {
t.Errorf("unexpected server read: %v\n%v", data, []byte(encoded))
}
if err != nil {
t.Fatal(err)
}
}
func TestStreamVersionedCopy(t *testing.T) {
for i, test := range versionTests() {
func() {
supportedProtocols := map[string]ReaderProtocolConfig{}
for p, binary := range test.supported {
supportedProtocols[p] = ReaderProtocolConfig{
Binary: binary,
}
}
input := "some random text"
r := NewReader(bytes.NewBuffer([]byte(input)), true, supportedProtocols)
s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
err := r.Copy(w, req)
if err != nil {
w.WriteHeader(503)
}
}))
defer s.Close()
config, err := websocket.NewConfig("ws://"+addr, "http://localhost/")
if err != nil {
t.Error(err)
return
}
config.Protocol = test.requested
client, err := websocket.DialConfig(config)
if err != nil {
if !test.error {
t.Errorf("test %d: didn't expect error: %v", i, err)
}
return
}
defer client.Close()
if test.error && err == nil {
t.Errorf("test %d: expected an error", i)
return
}
<-r.err
if got, expected := r.selectedProtocol, test.expected; got != expected {
t.Errorf("test %d: unexpected protocol version: got=%s expected=%s", i, got, expected)
}
}()
}
}
func TestStreamError(t *testing.T) { func TestStreamError(t *testing.T) {
input := "some random text" input := "some random text"
errs := &errorReader{ errs := &errorReader{
@ -78,7 +146,7 @@ func TestStreamError(t *testing.T) {
}, },
err: fmt.Errorf("bad read"), err: fmt.Errorf("bad read"),
} }
r := NewReader(errs, false) r := NewReader(errs, false, NewDefaultReaderProtocols())
data, err := readWebSocket(r, t, nil) data, err := readWebSocket(r, t, nil)
if !reflect.DeepEqual(data, []byte(input)) { if !reflect.DeepEqual(data, []byte(input)) {
@ -98,7 +166,10 @@ func TestStreamSurvivesPanic(t *testing.T) {
}, },
panicMessage: "bad read", panicMessage: "bad read",
} }
r := NewReader(errs, false) r := NewReader(errs, false, NewDefaultReaderProtocols())
// do not call runtime.HandleCrash() in handler. Otherwise, the tests are interrupted.
r.handleCrash = func() { recover() }
data, err := readWebSocket(r, t, nil) data, err := readWebSocket(r, t, nil)
if !reflect.DeepEqual(data, []byte(input)) { if !reflect.DeepEqual(data, []byte(input)) {
@ -121,7 +192,7 @@ func TestStreamClosedDuringRead(t *testing.T) {
err: fmt.Errorf("stuff"), err: fmt.Errorf("stuff"),
pause: ch, pause: ch,
} }
r := NewReader(errs, false) r := NewReader(errs, false, NewDefaultReaderProtocols())
data, err := readWebSocket(r, t, func(c *websocket.Conn) { data, err := readWebSocket(r, t, func(c *websocket.Conn) {
c.Close() c.Close()
@ -163,19 +234,13 @@ func (r *errorReader) Read(p []byte) (int, error) {
func readWebSocket(r *Reader, t *testing.T, fn func(*websocket.Conn), protocols ...string) ([]byte, error) { func readWebSocket(r *Reader, t *testing.T, fn func(*websocket.Conn), protocols ...string) ([]byte, error) {
errCh := make(chan error, 1) errCh := make(chan error, 1)
s, addr := newServer(func(ws *websocket.Conn) { s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
cfg := ws.Config() errCh <- r.Copy(w, req)
cfg.Protocol = protocols }))
go IgnoreReceives(ws, 0)
go func() {
err := <-r.err
errCh <- err
}()
r.handle(ws)
})
defer s.Close() defer s.Close()
config, _ := websocket.NewConfig("ws://"+addr, "http://"+addr) config, _ := websocket.NewConfig("ws://"+addr, "http://"+addr)
config.Protocol = protocols
client, err := websocket.DialConfig(config) client, err := websocket.DialConfig(config)
if err != nil { if err != nil {
return nil, err return nil, err
@ -195,19 +260,13 @@ func readWebSocket(r *Reader, t *testing.T, fn func(*websocket.Conn), protocols
func expectWebSocketFrames(r *Reader, t *testing.T, fn func(*websocket.Conn), frames [][]byte, protocols ...string) error { func expectWebSocketFrames(r *Reader, t *testing.T, fn func(*websocket.Conn), frames [][]byte, protocols ...string) error {
errCh := make(chan error, 1) errCh := make(chan error, 1)
s, addr := newServer(func(ws *websocket.Conn) { s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
cfg := ws.Config() errCh <- r.Copy(w, req)
cfg.Protocol = protocols }))
go IgnoreReceives(ws, 0)
go func() {
err := <-r.err
errCh <- err
}()
r.handle(ws)
})
defer s.Close() defer s.Close()
config, _ := websocket.NewConfig("ws://"+addr, "http://"+addr) config, _ := websocket.NewConfig("ws://"+addr, "http://"+addr)
config.Protocol = protocols
ws, err := websocket.DialConfig(config) ws, err := websocket.DialConfig(config)
if err != nil { if err != nil {
return err return err