Merge pull request #123542 from liggitt/websocket-round-tripper-protocol

Use the websocket protocol header, verify selected protocol

Kubernetes-commit: e21a2f5d4f010e49cea1b954bd9b31d94e712c5b
This commit is contained in:
Kubernetes Publisher 2024-02-28 11:01:44 -08:00
commit d99a76c51e
2 changed files with 26 additions and 6 deletions

View File

@ -18,6 +18,7 @@ package websocket
import (
"crypto/tls"
"errors"
"fmt"
"net/http"
"net/url"
@ -25,6 +26,7 @@ import (
gwebsocket "github.com/gorilla/websocket"
"k8s.io/apimachinery/pkg/util/httpstream"
"k8s.io/apimachinery/pkg/util/httpstream/wsstream"
utilnet "k8s.io/apimachinery/pkg/util/net"
restclient "k8s.io/client-go/rest"
"k8s.io/client-go/transport"
@ -88,8 +90,8 @@ func (rt *RoundTripper) RoundTrip(request *http.Request) (retResp *http.Response
}()
// set the protocol version directly on the dialer from the header
protocolVersions := request.Header[httpstream.HeaderProtocolVersion]
delete(request.Header, httpstream.HeaderProtocolVersion)
protocolVersions := request.Header[wsstream.WebSocketProtocolHeader]
delete(request.Header, wsstream.WebSocketProtocolHeader)
dialer := gwebsocket.Dialer{
Proxy: rt.Proxier,
@ -108,7 +110,23 @@ func (rt *RoundTripper) RoundTrip(request *http.Request) (retResp *http.Response
}
wsConn, resp, err := dialer.DialContext(request.Context(), request.URL.String(), request.Header)
if err != nil {
return nil, &httpstream.UpgradeFailureError{Cause: err}
if errors.Is(err, gwebsocket.ErrBadHandshake) {
return nil, &httpstream.UpgradeFailureError{Cause: err}
}
return nil, err
}
// Ensure we got back a protocol we understand
foundProtocol := false
for _, protocolVersion := range protocolVersions {
if protocolVersion == wsConn.Subprotocol() {
foundProtocol = true
break
}
}
if !foundProtocol {
wsConn.Close() // nolint:errcheck
return nil, &httpstream.UpgradeFailureError{Cause: fmt.Errorf("invalid protocol, expected one of %q, got %q", protocolVersions, wsConn.Subprotocol())}
}
rt.Conn = wsConn
@ -149,7 +167,8 @@ func RoundTripperFor(config *restclient.Config) (http.RoundTripper, ConnectionHo
// a WebSocket connection. Upon success, it returns the negotiated connection.
// The round tripper rt must use the WebSocket round tripper wsRt - see RoundTripperFor.
func Negotiate(rt http.RoundTripper, connectionInfo ConnectionHolder, req *http.Request, protocols ...string) (*gwebsocket.Conn, error) {
req.Header[httpstream.HeaderProtocolVersion] = protocols
// Plumb protocols to RoundTripper#RoundTrip
req.Header[wsstream.WebSocketProtocolHeader] = protocols
resp, err := rt.RoundTrip(req)
if err != nil {
return nil, err

View File

@ -54,7 +54,7 @@ func TestWebSocketRoundTripper_RoundTripperSucceeds(t *testing.T) {
rt, wsRt, err := RoundTripperFor(&restclient.Config{Host: websocketLocation.Host})
require.NoError(t, err)
requestedProtocol := remotecommand.StreamProtocolV5Name
req.Header[httpstream.HeaderProtocolVersion] = []string{requestedProtocol}
req.Header[wsstream.WebSocketProtocolHeader] = []string{requestedProtocol}
_, err = rt.RoundTrip(req)
require.NoError(t, err)
// WebSocket Connection is stored in websocket RoundTripper.
@ -83,11 +83,12 @@ func TestWebSocketRoundTripper_RoundTripperFails(t *testing.T) {
require.NoError(t, err)
// Requested subprotocol version 1 is not supported by test websocket server.
requestedProtocol := remotecommand.StreamProtocolV1Name
req.Header[httpstream.HeaderProtocolVersion] = []string{requestedProtocol}
req.Header[wsstream.WebSocketProtocolHeader] = []string{requestedProtocol}
_, err = rt.RoundTrip(req)
// Ensure a "bad handshake" error is returned, since requested protocol is not supported.
require.Error(t, err)
assert.True(t, strings.Contains(err.Error(), "bad handshake"))
assert.True(t, httpstream.IsUpgradeFailure(err))
}
func TestWebSocketRoundTripper_NegotiateCreatesConnection(t *testing.T) {