mirror of
https://github.com/kubernetes/client-go.git
synced 2025-06-22 21:27:07 +00:00
Merge pull request #123542 from liggitt/websocket-round-tripper-protocol
Use the websocket protocol header, verify selected protocol Kubernetes-commit: e21a2f5d4f010e49cea1b954bd9b31d94e712c5b
This commit is contained in:
commit
d99a76c51e
@ -18,6 +18,7 @@ package websocket
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
@ -25,6 +26,7 @@ import (
|
||||
gwebsocket "github.com/gorilla/websocket"
|
||||
|
||||
"k8s.io/apimachinery/pkg/util/httpstream"
|
||||
"k8s.io/apimachinery/pkg/util/httpstream/wsstream"
|
||||
utilnet "k8s.io/apimachinery/pkg/util/net"
|
||||
restclient "k8s.io/client-go/rest"
|
||||
"k8s.io/client-go/transport"
|
||||
@ -88,8 +90,8 @@ func (rt *RoundTripper) RoundTrip(request *http.Request) (retResp *http.Response
|
||||
}()
|
||||
|
||||
// set the protocol version directly on the dialer from the header
|
||||
protocolVersions := request.Header[httpstream.HeaderProtocolVersion]
|
||||
delete(request.Header, httpstream.HeaderProtocolVersion)
|
||||
protocolVersions := request.Header[wsstream.WebSocketProtocolHeader]
|
||||
delete(request.Header, wsstream.WebSocketProtocolHeader)
|
||||
|
||||
dialer := gwebsocket.Dialer{
|
||||
Proxy: rt.Proxier,
|
||||
@ -108,7 +110,23 @@ func (rt *RoundTripper) RoundTrip(request *http.Request) (retResp *http.Response
|
||||
}
|
||||
wsConn, resp, err := dialer.DialContext(request.Context(), request.URL.String(), request.Header)
|
||||
if err != nil {
|
||||
return nil, &httpstream.UpgradeFailureError{Cause: err}
|
||||
if errors.Is(err, gwebsocket.ErrBadHandshake) {
|
||||
return nil, &httpstream.UpgradeFailureError{Cause: err}
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Ensure we got back a protocol we understand
|
||||
foundProtocol := false
|
||||
for _, protocolVersion := range protocolVersions {
|
||||
if protocolVersion == wsConn.Subprotocol() {
|
||||
foundProtocol = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundProtocol {
|
||||
wsConn.Close() // nolint:errcheck
|
||||
return nil, &httpstream.UpgradeFailureError{Cause: fmt.Errorf("invalid protocol, expected one of %q, got %q", protocolVersions, wsConn.Subprotocol())}
|
||||
}
|
||||
|
||||
rt.Conn = wsConn
|
||||
@ -149,7 +167,8 @@ func RoundTripperFor(config *restclient.Config) (http.RoundTripper, ConnectionHo
|
||||
// a WebSocket connection. Upon success, it returns the negotiated connection.
|
||||
// The round tripper rt must use the WebSocket round tripper wsRt - see RoundTripperFor.
|
||||
func Negotiate(rt http.RoundTripper, connectionInfo ConnectionHolder, req *http.Request, protocols ...string) (*gwebsocket.Conn, error) {
|
||||
req.Header[httpstream.HeaderProtocolVersion] = protocols
|
||||
// Plumb protocols to RoundTripper#RoundTrip
|
||||
req.Header[wsstream.WebSocketProtocolHeader] = protocols
|
||||
resp, err := rt.RoundTrip(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -54,7 +54,7 @@ func TestWebSocketRoundTripper_RoundTripperSucceeds(t *testing.T) {
|
||||
rt, wsRt, err := RoundTripperFor(&restclient.Config{Host: websocketLocation.Host})
|
||||
require.NoError(t, err)
|
||||
requestedProtocol := remotecommand.StreamProtocolV5Name
|
||||
req.Header[httpstream.HeaderProtocolVersion] = []string{requestedProtocol}
|
||||
req.Header[wsstream.WebSocketProtocolHeader] = []string{requestedProtocol}
|
||||
_, err = rt.RoundTrip(req)
|
||||
require.NoError(t, err)
|
||||
// WebSocket Connection is stored in websocket RoundTripper.
|
||||
@ -83,11 +83,12 @@ func TestWebSocketRoundTripper_RoundTripperFails(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
// Requested subprotocol version 1 is not supported by test websocket server.
|
||||
requestedProtocol := remotecommand.StreamProtocolV1Name
|
||||
req.Header[httpstream.HeaderProtocolVersion] = []string{requestedProtocol}
|
||||
req.Header[wsstream.WebSocketProtocolHeader] = []string{requestedProtocol}
|
||||
_, err = rt.RoundTrip(req)
|
||||
// Ensure a "bad handshake" error is returned, since requested protocol is not supported.
|
||||
require.Error(t, err)
|
||||
assert.True(t, strings.Contains(err.Error(), "bad handshake"))
|
||||
assert.True(t, httpstream.IsUpgradeFailure(err))
|
||||
}
|
||||
|
||||
func TestWebSocketRoundTripper_NegotiateCreatesConnection(t *testing.T) {
|
||||
|
Loading…
Reference in New Issue
Block a user