mirror of
https://github.com/kubernetes/client-go.git
synced 2025-06-23 21:57:30 +00:00
Merge pull request #123542 from liggitt/websocket-round-tripper-protocol
Use the websocket protocol header, verify selected protocol Kubernetes-commit: e21a2f5d4f010e49cea1b954bd9b31d94e712c5b
This commit is contained in:
commit
d99a76c51e
@ -18,6 +18,7 @@ package websocket
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
@ -25,6 +26,7 @@ import (
|
|||||||
gwebsocket "github.com/gorilla/websocket"
|
gwebsocket "github.com/gorilla/websocket"
|
||||||
|
|
||||||
"k8s.io/apimachinery/pkg/util/httpstream"
|
"k8s.io/apimachinery/pkg/util/httpstream"
|
||||||
|
"k8s.io/apimachinery/pkg/util/httpstream/wsstream"
|
||||||
utilnet "k8s.io/apimachinery/pkg/util/net"
|
utilnet "k8s.io/apimachinery/pkg/util/net"
|
||||||
restclient "k8s.io/client-go/rest"
|
restclient "k8s.io/client-go/rest"
|
||||||
"k8s.io/client-go/transport"
|
"k8s.io/client-go/transport"
|
||||||
@ -88,8 +90,8 @@ func (rt *RoundTripper) RoundTrip(request *http.Request) (retResp *http.Response
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
// set the protocol version directly on the dialer from the header
|
// set the protocol version directly on the dialer from the header
|
||||||
protocolVersions := request.Header[httpstream.HeaderProtocolVersion]
|
protocolVersions := request.Header[wsstream.WebSocketProtocolHeader]
|
||||||
delete(request.Header, httpstream.HeaderProtocolVersion)
|
delete(request.Header, wsstream.WebSocketProtocolHeader)
|
||||||
|
|
||||||
dialer := gwebsocket.Dialer{
|
dialer := gwebsocket.Dialer{
|
||||||
Proxy: rt.Proxier,
|
Proxy: rt.Proxier,
|
||||||
@ -108,8 +110,24 @@ func (rt *RoundTripper) RoundTrip(request *http.Request) (retResp *http.Response
|
|||||||
}
|
}
|
||||||
wsConn, resp, err := dialer.DialContext(request.Context(), request.URL.String(), request.Header)
|
wsConn, resp, err := dialer.DialContext(request.Context(), request.URL.String(), request.Header)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if errors.Is(err, gwebsocket.ErrBadHandshake) {
|
||||||
return nil, &httpstream.UpgradeFailureError{Cause: err}
|
return nil, &httpstream.UpgradeFailureError{Cause: err}
|
||||||
}
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure we got back a protocol we understand
|
||||||
|
foundProtocol := false
|
||||||
|
for _, protocolVersion := range protocolVersions {
|
||||||
|
if protocolVersion == wsConn.Subprotocol() {
|
||||||
|
foundProtocol = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !foundProtocol {
|
||||||
|
wsConn.Close() // nolint:errcheck
|
||||||
|
return nil, &httpstream.UpgradeFailureError{Cause: fmt.Errorf("invalid protocol, expected one of %q, got %q", protocolVersions, wsConn.Subprotocol())}
|
||||||
|
}
|
||||||
|
|
||||||
rt.Conn = wsConn
|
rt.Conn = wsConn
|
||||||
|
|
||||||
@ -149,7 +167,8 @@ func RoundTripperFor(config *restclient.Config) (http.RoundTripper, ConnectionHo
|
|||||||
// a WebSocket connection. Upon success, it returns the negotiated connection.
|
// a WebSocket connection. Upon success, it returns the negotiated connection.
|
||||||
// The round tripper rt must use the WebSocket round tripper wsRt - see RoundTripperFor.
|
// The round tripper rt must use the WebSocket round tripper wsRt - see RoundTripperFor.
|
||||||
func Negotiate(rt http.RoundTripper, connectionInfo ConnectionHolder, req *http.Request, protocols ...string) (*gwebsocket.Conn, error) {
|
func Negotiate(rt http.RoundTripper, connectionInfo ConnectionHolder, req *http.Request, protocols ...string) (*gwebsocket.Conn, error) {
|
||||||
req.Header[httpstream.HeaderProtocolVersion] = protocols
|
// Plumb protocols to RoundTripper#RoundTrip
|
||||||
|
req.Header[wsstream.WebSocketProtocolHeader] = protocols
|
||||||
resp, err := rt.RoundTrip(req)
|
resp, err := rt.RoundTrip(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -54,7 +54,7 @@ func TestWebSocketRoundTripper_RoundTripperSucceeds(t *testing.T) {
|
|||||||
rt, wsRt, err := RoundTripperFor(&restclient.Config{Host: websocketLocation.Host})
|
rt, wsRt, err := RoundTripperFor(&restclient.Config{Host: websocketLocation.Host})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
requestedProtocol := remotecommand.StreamProtocolV5Name
|
requestedProtocol := remotecommand.StreamProtocolV5Name
|
||||||
req.Header[httpstream.HeaderProtocolVersion] = []string{requestedProtocol}
|
req.Header[wsstream.WebSocketProtocolHeader] = []string{requestedProtocol}
|
||||||
_, err = rt.RoundTrip(req)
|
_, err = rt.RoundTrip(req)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
// WebSocket Connection is stored in websocket RoundTripper.
|
// WebSocket Connection is stored in websocket RoundTripper.
|
||||||
@ -83,11 +83,12 @@ func TestWebSocketRoundTripper_RoundTripperFails(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
// Requested subprotocol version 1 is not supported by test websocket server.
|
// Requested subprotocol version 1 is not supported by test websocket server.
|
||||||
requestedProtocol := remotecommand.StreamProtocolV1Name
|
requestedProtocol := remotecommand.StreamProtocolV1Name
|
||||||
req.Header[httpstream.HeaderProtocolVersion] = []string{requestedProtocol}
|
req.Header[wsstream.WebSocketProtocolHeader] = []string{requestedProtocol}
|
||||||
_, err = rt.RoundTrip(req)
|
_, err = rt.RoundTrip(req)
|
||||||
// Ensure a "bad handshake" error is returned, since requested protocol is not supported.
|
// Ensure a "bad handshake" error is returned, since requested protocol is not supported.
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
assert.True(t, strings.Contains(err.Error(), "bad handshake"))
|
assert.True(t, strings.Contains(err.Error(), "bad handshake"))
|
||||||
|
assert.True(t, httpstream.IsUpgradeFailure(err))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestWebSocketRoundTripper_NegotiateCreatesConnection(t *testing.T) {
|
func TestWebSocketRoundTripper_NegotiateCreatesConnection(t *testing.T) {
|
||||||
|
Loading…
Reference in New Issue
Block a user