From e8b5ff9ea3088fae0f88a5a7e28f9b5be3e7bde3 Mon Sep 17 00:00:00 2001 From: Jordan Liggitt Date: Tue, 27 Feb 2024 18:00:45 -0500 Subject: [PATCH] Use the websocket protocol header, verify selected protocol Kubernetes-commit: b394aac4ce36457bd37459a58b4c3536d2f43d86 --- transport/websocket/roundtripper.go | 27 ++++++++++++++++++++---- transport/websocket/roundtripper_test.go | 5 +++-- 2 files changed, 26 insertions(+), 6 deletions(-) diff --git a/transport/websocket/roundtripper.go b/transport/websocket/roundtripper.go index 010f916b..624dd547 100644 --- a/transport/websocket/roundtripper.go +++ b/transport/websocket/roundtripper.go @@ -18,6 +18,7 @@ package websocket import ( "crypto/tls" + "errors" "fmt" "net/http" "net/url" @@ -25,6 +26,7 @@ import ( gwebsocket "github.com/gorilla/websocket" "k8s.io/apimachinery/pkg/util/httpstream" + "k8s.io/apimachinery/pkg/util/httpstream/wsstream" utilnet "k8s.io/apimachinery/pkg/util/net" restclient "k8s.io/client-go/rest" "k8s.io/client-go/transport" @@ -88,8 +90,8 @@ func (rt *RoundTripper) RoundTrip(request *http.Request) (retResp *http.Response }() // set the protocol version directly on the dialer from the header - protocolVersions := request.Header[httpstream.HeaderProtocolVersion] - delete(request.Header, httpstream.HeaderProtocolVersion) + protocolVersions := request.Header[wsstream.WebSocketProtocolHeader] + delete(request.Header, wsstream.WebSocketProtocolHeader) dialer := gwebsocket.Dialer{ Proxy: rt.Proxier, @@ -108,7 +110,23 @@ func (rt *RoundTripper) RoundTrip(request *http.Request) (retResp *http.Response } wsConn, resp, err := dialer.DialContext(request.Context(), request.URL.String(), request.Header) if err != nil { - return nil, &httpstream.UpgradeFailureError{Cause: err} + if errors.Is(err, gwebsocket.ErrBadHandshake) { + return nil, &httpstream.UpgradeFailureError{Cause: err} + } + return nil, err + } + + // Ensure we got back a protocol we understand + foundProtocol := false + for _, protocolVersion := range protocolVersions { + if protocolVersion == wsConn.Subprotocol() { + foundProtocol = true + break + } + } + if !foundProtocol { + wsConn.Close() // nolint:errcheck + return nil, &httpstream.UpgradeFailureError{Cause: fmt.Errorf("invalid protocol, expected one of %q, got %q", protocolVersions, wsConn.Subprotocol())} } rt.Conn = wsConn @@ -149,7 +167,8 @@ func RoundTripperFor(config *restclient.Config) (http.RoundTripper, ConnectionHo // a WebSocket connection. Upon success, it returns the negotiated connection. // The round tripper rt must use the WebSocket round tripper wsRt - see RoundTripperFor. func Negotiate(rt http.RoundTripper, connectionInfo ConnectionHolder, req *http.Request, protocols ...string) (*gwebsocket.Conn, error) { - req.Header[httpstream.HeaderProtocolVersion] = protocols + // Plumb protocols to RoundTripper#RoundTrip + req.Header[wsstream.WebSocketProtocolHeader] = protocols resp, err := rt.RoundTrip(req) if err != nil { return nil, err diff --git a/transport/websocket/roundtripper_test.go b/transport/websocket/roundtripper_test.go index 16bfbf57..39baba4b 100644 --- a/transport/websocket/roundtripper_test.go +++ b/transport/websocket/roundtripper_test.go @@ -54,7 +54,7 @@ func TestWebSocketRoundTripper_RoundTripperSucceeds(t *testing.T) { rt, wsRt, err := RoundTripperFor(&restclient.Config{Host: websocketLocation.Host}) require.NoError(t, err) requestedProtocol := remotecommand.StreamProtocolV5Name - req.Header[httpstream.HeaderProtocolVersion] = []string{requestedProtocol} + req.Header[wsstream.WebSocketProtocolHeader] = []string{requestedProtocol} _, err = rt.RoundTrip(req) require.NoError(t, err) // WebSocket Connection is stored in websocket RoundTripper. @@ -83,11 +83,12 @@ func TestWebSocketRoundTripper_RoundTripperFails(t *testing.T) { require.NoError(t, err) // Requested subprotocol version 1 is not supported by test websocket server. requestedProtocol := remotecommand.StreamProtocolV1Name - req.Header[httpstream.HeaderProtocolVersion] = []string{requestedProtocol} + req.Header[wsstream.WebSocketProtocolHeader] = []string{requestedProtocol} _, err = rt.RoundTrip(req) // Ensure a "bad handshake" error is returned, since requested protocol is not supported. require.Error(t, err) assert.True(t, strings.Contains(err.Error(), "bad handshake")) + assert.True(t, httpstream.IsUpgradeFailure(err)) } func TestWebSocketRoundTripper_NegotiateCreatesConnection(t *testing.T) {