mirror of
https://github.com/k3s-io/kubernetes.git
synced 2025-08-05 10:19:50 +00:00
Add protocol versions to pkg/util/wsstream
This commit is contained in:
parent
7b3c08d7d3
commit
ce7f003f57
@ -451,7 +451,7 @@ func write(statusCode int, gv unversioned.GroupVersion, s runtime.NegotiatedSeri
|
|||||||
defer out.Close()
|
defer out.Close()
|
||||||
|
|
||||||
if wsstream.IsWebSocketRequest(req) {
|
if wsstream.IsWebSocketRequest(req) {
|
||||||
r := wsstream.NewReader(out, true)
|
r := wsstream.NewReader(out, true, wsstream.NewDefaultReaderProtocols())
|
||||||
if err := r.Copy(w, req); err != nil {
|
if err := r.Copy(w, req); err != nil {
|
||||||
utilruntime.HandleError(fmt.Errorf("error encountered while streaming results via websocket: %v", err))
|
utilruntime.HandleError(fmt.Errorf("error encountered while streaming results via websocket: %v", err))
|
||||||
}
|
}
|
||||||
|
@ -27,6 +27,7 @@ import (
|
|||||||
|
|
||||||
"github.com/golang/glog"
|
"github.com/golang/glog"
|
||||||
"golang.org/x/net/websocket"
|
"golang.org/x/net/websocket"
|
||||||
|
|
||||||
"k8s.io/kubernetes/pkg/util/runtime"
|
"k8s.io/kubernetes/pkg/util/runtime"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -44,7 +45,7 @@ import (
|
|||||||
// READ []byte{1, 10} # receive "\n" on channel 1 (STDOUT)
|
// READ []byte{1, 10} # receive "\n" on channel 1 (STDOUT)
|
||||||
// CLOSE
|
// CLOSE
|
||||||
//
|
//
|
||||||
const channelWebSocketProtocol = "channel.k8s.io"
|
const ChannelWebSocketProtocol = "channel.k8s.io"
|
||||||
|
|
||||||
// The Websocket subprotocol "base64.channel.k8s.io" base64 encodes each message with a character
|
// The Websocket subprotocol "base64.channel.k8s.io" base64 encodes each message with a character
|
||||||
// indicating the channel number (zero indexed) the message was sent on. Messages in both directions
|
// indicating the channel number (zero indexed) the message was sent on. Messages in both directions
|
||||||
@ -60,7 +61,7 @@ const channelWebSocketProtocol = "channel.k8s.io"
|
|||||||
// READ []byte{49, 67, 103, 61, 61} # receive "\n" (base64: "Cg==") on channel '1' (STDOUT)
|
// READ []byte{49, 67, 103, 61, 61} # receive "\n" (base64: "Cg==") on channel '1' (STDOUT)
|
||||||
// CLOSE
|
// CLOSE
|
||||||
//
|
//
|
||||||
const base64ChannelWebSocketProtocol = "base64.channel.k8s.io"
|
const Base64ChannelWebSocketProtocol = "base64.channel.k8s.io"
|
||||||
|
|
||||||
type codecType int
|
type codecType int
|
||||||
|
|
||||||
@ -107,8 +108,9 @@ func IgnoreReceives(ws *websocket.Conn, timeout time.Duration) {
|
|||||||
func handshake(config *websocket.Config, req *http.Request, allowed []string) error {
|
func handshake(config *websocket.Config, req *http.Request, allowed []string) error {
|
||||||
protocols := config.Protocol
|
protocols := config.Protocol
|
||||||
if len(protocols) == 0 {
|
if len(protocols) == 0 {
|
||||||
return nil
|
protocols = []string{""}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, protocol := range protocols {
|
for _, protocol := range protocols {
|
||||||
for _, allow := range allowed {
|
for _, allow := range allowed {
|
||||||
if allow == protocol {
|
if allow == protocol {
|
||||||
@ -117,41 +119,50 @@ func handshake(config *websocket.Config, req *http.Request, allowed []string) er
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return fmt.Errorf("requested protocol(s) are not supported: %v; supports %v", config.Protocol, allowed)
|
return fmt.Errorf("requested protocol(s) are not supported: %v; supports %v", config.Protocol, allowed)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ChannelProtocolConfig describes a websocket subprotocol with channels.
|
||||||
|
type ChannelProtocolConfig struct {
|
||||||
|
Binary bool
|
||||||
|
Channels []ChannelType
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDefaultChannelProtocols returns a channel protocol map with the
|
||||||
|
// subprotocols "", "channel.k8s.io", "base64.channel.k8s.io" and the given
|
||||||
|
// channels.
|
||||||
|
func NewDefaultChannelProtocols(channels []ChannelType) map[string]ChannelProtocolConfig {
|
||||||
|
return map[string]ChannelProtocolConfig{
|
||||||
|
"": {Binary: true, Channels: channels},
|
||||||
|
ChannelWebSocketProtocol: {Binary: true, Channels: channels},
|
||||||
|
Base64ChannelWebSocketProtocol: {Binary: false, Channels: channels},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Conn supports sending multiple binary channels over a websocket connection.
|
// Conn supports sending multiple binary channels over a websocket connection.
|
||||||
// Supports only the "channel.k8s.io" subprotocol.
|
|
||||||
type Conn struct {
|
type Conn struct {
|
||||||
channels []*websocketChannel
|
protocols map[string]ChannelProtocolConfig
|
||||||
codec codecType
|
selectedProtocol string
|
||||||
ready chan struct{}
|
channels []*websocketChannel
|
||||||
ws *websocket.Conn
|
codec codecType
|
||||||
timeout time.Duration
|
ready chan struct{}
|
||||||
|
ws *websocket.Conn
|
||||||
|
timeout time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewConn creates a WebSocket connection that supports a set of channels. Channels begin each
|
// NewConn creates a WebSocket connection that supports a set of channels. Channels begin each
|
||||||
// web socket message with a single byte indicating the channel number (0-N). 255 is reserved for
|
// web socket message with a single byte indicating the channel number (0-N). 255 is reserved for
|
||||||
// future use. The channel types for each channel are passed as an array, supporting the different
|
// future use. The channel types for each channel are passed as an array, supporting the different
|
||||||
// duplex modes. Read and Write refer to whether the channel can be used as a Reader or Writer.
|
// duplex modes. Read and Write refer to whether the channel can be used as a Reader or Writer.
|
||||||
func NewConn(channels ...ChannelType) *Conn {
|
//
|
||||||
conn := &Conn{
|
// The protocols parameter maps subprotocol names to ChannelProtocols. The empty string subprotocol
|
||||||
ready: make(chan struct{}),
|
// name is used if websocket.Config.Protocol is empty.
|
||||||
channels: make([]*websocketChannel, len(channels)),
|
func NewConn(protocols map[string]ChannelProtocolConfig) *Conn {
|
||||||
|
return &Conn{
|
||||||
|
ready: make(chan struct{}),
|
||||||
|
protocols: protocols,
|
||||||
}
|
}
|
||||||
for i := range conn.channels {
|
|
||||||
switch channels[i] {
|
|
||||||
case ReadChannel:
|
|
||||||
conn.channels[i] = newWebsocketChannel(conn, byte(i), true, false)
|
|
||||||
case WriteChannel:
|
|
||||||
conn.channels[i] = newWebsocketChannel(conn, byte(i), false, true)
|
|
||||||
case ReadWriteChannel:
|
|
||||||
conn.channels[i] = newWebsocketChannel(conn, byte(i), true, true)
|
|
||||||
case IgnoreChannel:
|
|
||||||
conn.channels[i] = newWebsocketChannel(conn, byte(i), false, false)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return conn
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetIdleTimeout sets the interval for both reads and writes before timeout. If not specified,
|
// SetIdleTimeout sets the interval for both reads and writes before timeout. If not specified,
|
||||||
@ -160,8 +171,9 @@ func (conn *Conn) SetIdleTimeout(duration time.Duration) {
|
|||||||
conn.timeout = duration
|
conn.timeout = duration
|
||||||
}
|
}
|
||||||
|
|
||||||
// Open the connection and create channels for reading and writing.
|
// Open the connection and create channels for reading and writing. It returns
|
||||||
func (conn *Conn) Open(w http.ResponseWriter, req *http.Request) ([]io.ReadWriteCloser, error) {
|
// the selected subprotocol, a slice of channels and an error.
|
||||||
|
func (conn *Conn) Open(w http.ResponseWriter, req *http.Request) (string, []io.ReadWriteCloser, error) {
|
||||||
go func() {
|
go func() {
|
||||||
defer runtime.HandleCrash()
|
defer runtime.HandleCrash()
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
@ -172,23 +184,42 @@ func (conn *Conn) Open(w http.ResponseWriter, req *http.Request) ([]io.ReadWrite
|
|||||||
for i := range conn.channels {
|
for i := range conn.channels {
|
||||||
rwc[i] = conn.channels[i]
|
rwc[i] = conn.channels[i]
|
||||||
}
|
}
|
||||||
return rwc, nil
|
return conn.selectedProtocol, rwc, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (conn *Conn) initialize(ws *websocket.Conn) {
|
func (conn *Conn) initialize(ws *websocket.Conn) {
|
||||||
protocols := ws.Config().Protocol
|
negotiated := ws.Config().Protocol
|
||||||
switch {
|
conn.selectedProtocol = negotiated[0]
|
||||||
case len(protocols) == 0, protocols[0] == channelWebSocketProtocol:
|
p := conn.protocols[conn.selectedProtocol]
|
||||||
|
if p.Binary {
|
||||||
conn.codec = rawCodec
|
conn.codec = rawCodec
|
||||||
case protocols[0] == base64ChannelWebSocketProtocol:
|
} else {
|
||||||
conn.codec = base64Codec
|
conn.codec = base64Codec
|
||||||
}
|
}
|
||||||
conn.ws = ws
|
conn.ws = ws
|
||||||
|
conn.channels = make([]*websocketChannel, len(p.Channels))
|
||||||
|
for i, t := range p.Channels {
|
||||||
|
switch t {
|
||||||
|
case ReadChannel:
|
||||||
|
conn.channels[i] = newWebsocketChannel(conn, byte(i), true, false)
|
||||||
|
case WriteChannel:
|
||||||
|
conn.channels[i] = newWebsocketChannel(conn, byte(i), false, true)
|
||||||
|
case ReadWriteChannel:
|
||||||
|
conn.channels[i] = newWebsocketChannel(conn, byte(i), true, true)
|
||||||
|
case IgnoreChannel:
|
||||||
|
conn.channels[i] = newWebsocketChannel(conn, byte(i), false, false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
close(conn.ready)
|
close(conn.ready)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (conn *Conn) handshake(config *websocket.Config, req *http.Request) error {
|
func (conn *Conn) handshake(config *websocket.Config, req *http.Request) error {
|
||||||
return handshake(config, req, []string{channelWebSocketProtocol, base64ChannelWebSocketProtocol})
|
supportedProtocols := make([]string, 0, len(conn.protocols))
|
||||||
|
for p := range conn.protocols {
|
||||||
|
supportedProtocols = append(supportedProtocols, p)
|
||||||
|
}
|
||||||
|
return handshake(config, req, supportedProtocols)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (conn *Conn) resetTimeout() {
|
func (conn *Conn) resetTimeout() {
|
||||||
|
@ -20,6 +20,7 @@ import (
|
|||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"reflect"
|
"reflect"
|
||||||
"sync"
|
"sync"
|
||||||
@ -28,15 +29,19 @@ import (
|
|||||||
"golang.org/x/net/websocket"
|
"golang.org/x/net/websocket"
|
||||||
)
|
)
|
||||||
|
|
||||||
func newServer(handler websocket.Handler) (*httptest.Server, string) {
|
func newServer(handler http.Handler) (*httptest.Server, string) {
|
||||||
server := httptest.NewServer(handler)
|
server := httptest.NewServer(handler)
|
||||||
serverAddr := server.Listener.Addr().String()
|
serverAddr := server.Listener.Addr().String()
|
||||||
return server, serverAddr
|
return server, serverAddr
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRawConn(t *testing.T) {
|
func TestRawConn(t *testing.T) {
|
||||||
conn := NewConn(ReadWriteChannel, ReadWriteChannel, IgnoreChannel, ReadChannel, WriteChannel)
|
channels := []ChannelType{ReadWriteChannel, ReadWriteChannel, IgnoreChannel, ReadChannel, WriteChannel}
|
||||||
s, addr := newServer(conn.handle)
|
conn := NewConn(NewDefaultChannelProtocols(channels))
|
||||||
|
|
||||||
|
s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||||
|
conn.Open(w, req)
|
||||||
|
}))
|
||||||
defer s.Close()
|
defer s.Close()
|
||||||
|
|
||||||
client, err := websocket.Dial("ws://"+addr, "", "http://localhost/")
|
client, err := websocket.Dial("ws://"+addr, "", "http://localhost/")
|
||||||
@ -112,8 +117,10 @@ func TestRawConn(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestBase64Conn(t *testing.T) {
|
func TestBase64Conn(t *testing.T) {
|
||||||
conn := NewConn(ReadWriteChannel, ReadWriteChannel)
|
conn := NewConn(NewDefaultChannelProtocols([]ChannelType{ReadWriteChannel, ReadWriteChannel}))
|
||||||
s, addr := newServer(conn.handle)
|
s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||||
|
conn.Open(w, req)
|
||||||
|
}))
|
||||||
defer s.Close()
|
defer s.Close()
|
||||||
|
|
||||||
config, err := websocket.NewConfig("ws://"+addr, "http://localhost/")
|
config, err := websocket.NewConfig("ws://"+addr, "http://localhost/")
|
||||||
@ -167,3 +174,99 @@ func TestBase64Conn(t *testing.T) {
|
|||||||
client.Close()
|
client.Close()
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type versionTest struct {
|
||||||
|
supported map[string]bool // protocol -> binary
|
||||||
|
requested []string
|
||||||
|
error bool
|
||||||
|
expected string
|
||||||
|
}
|
||||||
|
|
||||||
|
func versionTests() []versionTest {
|
||||||
|
const (
|
||||||
|
binary = true
|
||||||
|
base64 = false
|
||||||
|
)
|
||||||
|
return []versionTest{
|
||||||
|
{
|
||||||
|
supported: nil,
|
||||||
|
requested: []string{"raw"},
|
||||||
|
error: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
supported: map[string]bool{"": binary, "raw": binary, "base64": base64},
|
||||||
|
requested: nil,
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
supported: map[string]bool{"": binary, "raw": binary, "base64": base64},
|
||||||
|
requested: []string{"v1.raw"},
|
||||||
|
error: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
supported: map[string]bool{"": binary, "raw": binary, "base64": base64},
|
||||||
|
requested: []string{"v1.raw", "v1.base64"},
|
||||||
|
error: true,
|
||||||
|
}, {
|
||||||
|
supported: map[string]bool{"": binary, "raw": binary, "base64": base64},
|
||||||
|
requested: []string{"v1.raw", "raw"},
|
||||||
|
expected: "raw",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
supported: map[string]bool{"": binary, "v1.raw": binary, "v1.base64": base64, "v2.raw": binary, "v2.base64": base64},
|
||||||
|
requested: []string{"v1.raw"},
|
||||||
|
expected: "v1.raw",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
supported: map[string]bool{"": binary, "v1.raw": binary, "v1.base64": base64, "v2.raw": binary, "v2.base64": base64},
|
||||||
|
requested: []string{"v2.base64"},
|
||||||
|
expected: "v2.base64",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVersionedConn(t *testing.T) {
|
||||||
|
for i, test := range versionTests() {
|
||||||
|
func() {
|
||||||
|
supportedProtocols := map[string]ChannelProtocolConfig{}
|
||||||
|
for p, binary := range test.supported {
|
||||||
|
supportedProtocols[p] = ChannelProtocolConfig{
|
||||||
|
Binary: binary,
|
||||||
|
Channels: []ChannelType{ReadWriteChannel},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
conn := NewConn(supportedProtocols)
|
||||||
|
// note that it's not enough to wait for conn.ready to avoid a race here. Hence,
|
||||||
|
// we use a channel.
|
||||||
|
selectedProtocol := make(chan string, 0)
|
||||||
|
s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||||
|
p, _, _ := conn.Open(w, req)
|
||||||
|
selectedProtocol <- p
|
||||||
|
}))
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
config, err := websocket.NewConfig("ws://"+addr, "http://localhost/")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
config.Protocol = test.requested
|
||||||
|
client, err := websocket.DialConfig(config)
|
||||||
|
if err != nil {
|
||||||
|
if !test.error {
|
||||||
|
t.Fatalf("test %d: didn't expect error: %v", i, err)
|
||||||
|
} else {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
defer client.Close()
|
||||||
|
if test.error && err == nil {
|
||||||
|
t.Fatalf("test %d: expected an error", i)
|
||||||
|
}
|
||||||
|
|
||||||
|
<-conn.ready
|
||||||
|
if got, expected := <-selectedProtocol, test.expected; got != expected {
|
||||||
|
t.Fatalf("test %d: unexpected protocol version: got=%s expected=%s", i, got, expected)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -23,6 +23,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.org/x/net/websocket"
|
"golang.org/x/net/websocket"
|
||||||
|
|
||||||
"k8s.io/kubernetes/pkg/util/runtime"
|
"k8s.io/kubernetes/pkg/util/runtime"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -37,23 +38,46 @@ const binaryWebSocketProtocol = "binary.k8s.io"
|
|||||||
// possible.
|
// possible.
|
||||||
const base64BinaryWebSocketProtocol = "base64.binary.k8s.io"
|
const base64BinaryWebSocketProtocol = "base64.binary.k8s.io"
|
||||||
|
|
||||||
|
// ReaderProtocolConfig describes a websocket subprotocol with one stream.
|
||||||
|
type ReaderProtocolConfig struct {
|
||||||
|
Binary bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDefaultReaderProtocols returns a stream protocol map with the
|
||||||
|
// subprotocols "", "channel.k8s.io", "base64.channel.k8s.io".
|
||||||
|
func NewDefaultReaderProtocols() map[string]ReaderProtocolConfig {
|
||||||
|
return map[string]ReaderProtocolConfig{
|
||||||
|
"": {Binary: true},
|
||||||
|
binaryWebSocketProtocol: {Binary: true},
|
||||||
|
base64BinaryWebSocketProtocol: {Binary: false},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Reader supports returning an arbitrary byte stream over a websocket channel.
|
// Reader supports returning an arbitrary byte stream over a websocket channel.
|
||||||
// Supports the "binary.k8s.io" and "base64.binary.k8s.io" subprotocols.
|
|
||||||
type Reader struct {
|
type Reader struct {
|
||||||
err chan error
|
err chan error
|
||||||
r io.Reader
|
r io.Reader
|
||||||
ping bool
|
ping bool
|
||||||
timeout time.Duration
|
timeout time.Duration
|
||||||
|
protocols map[string]ReaderProtocolConfig
|
||||||
|
selectedProtocol string
|
||||||
|
|
||||||
|
handleCrash func() // overridable for testing
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewReader creates a WebSocket pipe that will copy the contents of r to a provided
|
// NewReader creates a WebSocket pipe that will copy the contents of r to a provided
|
||||||
// WebSocket connection. If ping is true, a zero length message will be sent to the client
|
// WebSocket connection. If ping is true, a zero length message will be sent to the client
|
||||||
// before the stream begins reading.
|
// before the stream begins reading.
|
||||||
func NewReader(r io.Reader, ping bool) *Reader {
|
//
|
||||||
|
// The protocols parameter maps subprotocol names to StreamProtocols. The empty string
|
||||||
|
// subprotocol name is used if websocket.Config.Protocol is empty.
|
||||||
|
func NewReader(r io.Reader, ping bool, protocols map[string]ReaderProtocolConfig) *Reader {
|
||||||
return &Reader{
|
return &Reader{
|
||||||
r: r,
|
r: r,
|
||||||
err: make(chan error),
|
err: make(chan error),
|
||||||
ping: ping,
|
ping: ping,
|
||||||
|
protocols: protocols,
|
||||||
|
handleCrash: func() { runtime.HandleCrash() },
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -64,14 +88,18 @@ func (r *Reader) SetIdleTimeout(duration time.Duration) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *Reader) handshake(config *websocket.Config, req *http.Request) error {
|
func (r *Reader) handshake(config *websocket.Config, req *http.Request) error {
|
||||||
return handshake(config, req, []string{binaryWebSocketProtocol, base64BinaryWebSocketProtocol})
|
supportedProtocols := make([]string, 0, len(r.protocols))
|
||||||
|
for p := range r.protocols {
|
||||||
|
supportedProtocols = append(supportedProtocols, p)
|
||||||
|
}
|
||||||
|
return handshake(config, req, supportedProtocols)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Copy the reader to the response. The created WebSocket is closed after this
|
// Copy the reader to the response. The created WebSocket is closed after this
|
||||||
// method completes.
|
// method completes.
|
||||||
func (r *Reader) Copy(w http.ResponseWriter, req *http.Request) error {
|
func (r *Reader) Copy(w http.ResponseWriter, req *http.Request) error {
|
||||||
go func() {
|
go func() {
|
||||||
defer runtime.HandleCrash()
|
defer r.handleCrash()
|
||||||
websocket.Server{Handshake: r.handshake, Handler: r.handle}.ServeHTTP(w, req)
|
websocket.Server{Handshake: r.handshake, Handler: r.handle}.ServeHTTP(w, req)
|
||||||
}()
|
}()
|
||||||
return <-r.err
|
return <-r.err
|
||||||
@ -79,11 +107,12 @@ func (r *Reader) Copy(w http.ResponseWriter, req *http.Request) error {
|
|||||||
|
|
||||||
// handle implements a WebSocket handler.
|
// handle implements a WebSocket handler.
|
||||||
func (r *Reader) handle(ws *websocket.Conn) {
|
func (r *Reader) handle(ws *websocket.Conn) {
|
||||||
encode := len(ws.Config().Protocol) > 0 && ws.Config().Protocol[0] == base64BinaryWebSocketProtocol
|
negotiated := ws.Config().Protocol
|
||||||
|
r.selectedProtocol = negotiated[0]
|
||||||
defer close(r.err)
|
defer close(r.err)
|
||||||
defer ws.Close()
|
defer ws.Close()
|
||||||
go IgnoreReceives(ws, r.timeout)
|
go IgnoreReceives(ws, r.timeout)
|
||||||
r.err <- messageCopy(ws, r.r, encode, r.ping, r.timeout)
|
r.err <- messageCopy(ws, r.r, !r.protocols[r.selectedProtocol].Binary, r.ping, r.timeout)
|
||||||
}
|
}
|
||||||
|
|
||||||
func resetTimeout(ws *websocket.Conn, timeout time.Duration) {
|
func resetTimeout(ws *websocket.Conn, timeout time.Duration) {
|
||||||
|
@ -22,6 +22,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
@ -32,7 +33,7 @@ import (
|
|||||||
|
|
||||||
func TestStream(t *testing.T) {
|
func TestStream(t *testing.T) {
|
||||||
input := "some random text"
|
input := "some random text"
|
||||||
r := NewReader(bytes.NewBuffer([]byte(input)), true)
|
r := NewReader(bytes.NewBuffer([]byte(input)), true, NewDefaultReaderProtocols())
|
||||||
r.SetIdleTimeout(time.Second)
|
r.SetIdleTimeout(time.Second)
|
||||||
data, err := readWebSocket(r, t, nil)
|
data, err := readWebSocket(r, t, nil)
|
||||||
if !reflect.DeepEqual(data, []byte(input)) {
|
if !reflect.DeepEqual(data, []byte(input)) {
|
||||||
@ -45,7 +46,7 @@ func TestStream(t *testing.T) {
|
|||||||
|
|
||||||
func TestStreamPing(t *testing.T) {
|
func TestStreamPing(t *testing.T) {
|
||||||
input := "some random text"
|
input := "some random text"
|
||||||
r := NewReader(bytes.NewBuffer([]byte(input)), true)
|
r := NewReader(bytes.NewBuffer([]byte(input)), true, NewDefaultReaderProtocols())
|
||||||
r.SetIdleTimeout(time.Second)
|
r.SetIdleTimeout(time.Second)
|
||||||
err := expectWebSocketFrames(r, t, nil, [][]byte{
|
err := expectWebSocketFrames(r, t, nil, [][]byte{
|
||||||
{},
|
{},
|
||||||
@ -59,8 +60,8 @@ func TestStreamPing(t *testing.T) {
|
|||||||
func TestStreamBase64(t *testing.T) {
|
func TestStreamBase64(t *testing.T) {
|
||||||
input := "some random text"
|
input := "some random text"
|
||||||
encoded := base64.StdEncoding.EncodeToString([]byte(input))
|
encoded := base64.StdEncoding.EncodeToString([]byte(input))
|
||||||
r := NewReader(bytes.NewBuffer([]byte(input)), true)
|
r := NewReader(bytes.NewBuffer([]byte(input)), true, NewDefaultReaderProtocols())
|
||||||
data, err := readWebSocket(r, t, nil, base64BinaryWebSocketProtocol)
|
data, err := readWebSocket(r, t, nil, "base64.binary.k8s.io")
|
||||||
if !reflect.DeepEqual(data, []byte(encoded)) {
|
if !reflect.DeepEqual(data, []byte(encoded)) {
|
||||||
t.Errorf("unexpected server read: %v\n%v", data, []byte(encoded))
|
t.Errorf("unexpected server read: %v\n%v", data, []byte(encoded))
|
||||||
}
|
}
|
||||||
@ -69,6 +70,73 @@ func TestStreamBase64(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestStreamVersionedBase64(t *testing.T) {
|
||||||
|
input := "some random text"
|
||||||
|
encoded := base64.StdEncoding.EncodeToString([]byte(input))
|
||||||
|
r := NewReader(bytes.NewBuffer([]byte(input)), true, map[string]ReaderProtocolConfig{
|
||||||
|
"": {Binary: true},
|
||||||
|
"binary.k8s.io": {Binary: true},
|
||||||
|
"base64.binary.k8s.io": {Binary: false},
|
||||||
|
"v1.binary.k8s.io": {Binary: true},
|
||||||
|
"v1.base64.binary.k8s.io": {Binary: false},
|
||||||
|
"v2.binary.k8s.io": {Binary: true},
|
||||||
|
"v2.base64.binary.k8s.io": {Binary: false},
|
||||||
|
})
|
||||||
|
data, err := readWebSocket(r, t, nil, "v2.base64.binary.k8s.io")
|
||||||
|
if !reflect.DeepEqual(data, []byte(encoded)) {
|
||||||
|
t.Errorf("unexpected server read: %v\n%v", data, []byte(encoded))
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStreamVersionedCopy(t *testing.T) {
|
||||||
|
for i, test := range versionTests() {
|
||||||
|
func() {
|
||||||
|
supportedProtocols := map[string]ReaderProtocolConfig{}
|
||||||
|
for p, binary := range test.supported {
|
||||||
|
supportedProtocols[p] = ReaderProtocolConfig{
|
||||||
|
Binary: binary,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
input := "some random text"
|
||||||
|
r := NewReader(bytes.NewBuffer([]byte(input)), true, supportedProtocols)
|
||||||
|
s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||||
|
err := r.Copy(w, req)
|
||||||
|
if err != nil {
|
||||||
|
w.WriteHeader(503)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
config, err := websocket.NewConfig("ws://"+addr, "http://localhost/")
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
config.Protocol = test.requested
|
||||||
|
client, err := websocket.DialConfig(config)
|
||||||
|
if err != nil {
|
||||||
|
if !test.error {
|
||||||
|
t.Errorf("test %d: didn't expect error: %v", i, err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer client.Close()
|
||||||
|
if test.error && err == nil {
|
||||||
|
t.Errorf("test %d: expected an error", i)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
<-r.err
|
||||||
|
if got, expected := r.selectedProtocol, test.expected; got != expected {
|
||||||
|
t.Errorf("test %d: unexpected protocol version: got=%s expected=%s", i, got, expected)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestStreamError(t *testing.T) {
|
func TestStreamError(t *testing.T) {
|
||||||
input := "some random text"
|
input := "some random text"
|
||||||
errs := &errorReader{
|
errs := &errorReader{
|
||||||
@ -78,7 +146,7 @@ func TestStreamError(t *testing.T) {
|
|||||||
},
|
},
|
||||||
err: fmt.Errorf("bad read"),
|
err: fmt.Errorf("bad read"),
|
||||||
}
|
}
|
||||||
r := NewReader(errs, false)
|
r := NewReader(errs, false, NewDefaultReaderProtocols())
|
||||||
|
|
||||||
data, err := readWebSocket(r, t, nil)
|
data, err := readWebSocket(r, t, nil)
|
||||||
if !reflect.DeepEqual(data, []byte(input)) {
|
if !reflect.DeepEqual(data, []byte(input)) {
|
||||||
@ -98,7 +166,10 @@ func TestStreamSurvivesPanic(t *testing.T) {
|
|||||||
},
|
},
|
||||||
panicMessage: "bad read",
|
panicMessage: "bad read",
|
||||||
}
|
}
|
||||||
r := NewReader(errs, false)
|
r := NewReader(errs, false, NewDefaultReaderProtocols())
|
||||||
|
|
||||||
|
// do not call runtime.HandleCrash() in handler. Otherwise, the tests are interrupted.
|
||||||
|
r.handleCrash = func() { recover() }
|
||||||
|
|
||||||
data, err := readWebSocket(r, t, nil)
|
data, err := readWebSocket(r, t, nil)
|
||||||
if !reflect.DeepEqual(data, []byte(input)) {
|
if !reflect.DeepEqual(data, []byte(input)) {
|
||||||
@ -121,7 +192,7 @@ func TestStreamClosedDuringRead(t *testing.T) {
|
|||||||
err: fmt.Errorf("stuff"),
|
err: fmt.Errorf("stuff"),
|
||||||
pause: ch,
|
pause: ch,
|
||||||
}
|
}
|
||||||
r := NewReader(errs, false)
|
r := NewReader(errs, false, NewDefaultReaderProtocols())
|
||||||
|
|
||||||
data, err := readWebSocket(r, t, func(c *websocket.Conn) {
|
data, err := readWebSocket(r, t, func(c *websocket.Conn) {
|
||||||
c.Close()
|
c.Close()
|
||||||
@ -163,19 +234,13 @@ func (r *errorReader) Read(p []byte) (int, error) {
|
|||||||
|
|
||||||
func readWebSocket(r *Reader, t *testing.T, fn func(*websocket.Conn), protocols ...string) ([]byte, error) {
|
func readWebSocket(r *Reader, t *testing.T, fn func(*websocket.Conn), protocols ...string) ([]byte, error) {
|
||||||
errCh := make(chan error, 1)
|
errCh := make(chan error, 1)
|
||||||
s, addr := newServer(func(ws *websocket.Conn) {
|
s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||||
cfg := ws.Config()
|
errCh <- r.Copy(w, req)
|
||||||
cfg.Protocol = protocols
|
}))
|
||||||
go IgnoreReceives(ws, 0)
|
|
||||||
go func() {
|
|
||||||
err := <-r.err
|
|
||||||
errCh <- err
|
|
||||||
}()
|
|
||||||
r.handle(ws)
|
|
||||||
})
|
|
||||||
defer s.Close()
|
defer s.Close()
|
||||||
|
|
||||||
config, _ := websocket.NewConfig("ws://"+addr, "http://"+addr)
|
config, _ := websocket.NewConfig("ws://"+addr, "http://"+addr)
|
||||||
|
config.Protocol = protocols
|
||||||
client, err := websocket.DialConfig(config)
|
client, err := websocket.DialConfig(config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -195,19 +260,13 @@ func readWebSocket(r *Reader, t *testing.T, fn func(*websocket.Conn), protocols
|
|||||||
|
|
||||||
func expectWebSocketFrames(r *Reader, t *testing.T, fn func(*websocket.Conn), frames [][]byte, protocols ...string) error {
|
func expectWebSocketFrames(r *Reader, t *testing.T, fn func(*websocket.Conn), frames [][]byte, protocols ...string) error {
|
||||||
errCh := make(chan error, 1)
|
errCh := make(chan error, 1)
|
||||||
s, addr := newServer(func(ws *websocket.Conn) {
|
s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||||
cfg := ws.Config()
|
errCh <- r.Copy(w, req)
|
||||||
cfg.Protocol = protocols
|
}))
|
||||||
go IgnoreReceives(ws, 0)
|
|
||||||
go func() {
|
|
||||||
err := <-r.err
|
|
||||||
errCh <- err
|
|
||||||
}()
|
|
||||||
r.handle(ws)
|
|
||||||
})
|
|
||||||
defer s.Close()
|
defer s.Close()
|
||||||
|
|
||||||
config, _ := websocket.NewConfig("ws://"+addr, "http://"+addr)
|
config, _ := websocket.NewConfig("ws://"+addr, "http://"+addr)
|
||||||
|
config.Protocol = protocols
|
||||||
ws, err := websocket.DialConfig(config)
|
ws, err := websocket.DialConfig(config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
Loading…
Reference in New Issue
Block a user