From eec2be81681c5ad147df13aad60a003af0c4c48f Mon Sep 17 00:00:00 2001 From: Sean Sullivan Date: Thu, 19 Sep 2024 15:47:19 -0700 Subject: [PATCH] Upgrade websocket failure add extra error info --- .../transport/websocket/roundtripper.go | 44 ++++++- .../transport/websocket/roundtripper_test.go | 107 ++++++++++++++---- 2 files changed, 124 insertions(+), 27 deletions(-) diff --git a/staging/src/k8s.io/client-go/transport/websocket/roundtripper.go b/staging/src/k8s.io/client-go/transport/websocket/roundtripper.go index 8286a8eb529..924518e8bbd 100644 --- a/staging/src/k8s.io/client-go/transport/websocket/roundtripper.go +++ b/staging/src/k8s.io/client-go/transport/websocket/roundtripper.go @@ -20,11 +20,17 @@ import ( "crypto/tls" "errors" "fmt" + "io" "net/http" "net/url" + "strings" gwebsocket "github.com/gorilla/websocket" + apierrors "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/runtime/serializer" "k8s.io/apimachinery/pkg/util/httpstream" "k8s.io/apimachinery/pkg/util/httpstream/wsstream" utilnet "k8s.io/apimachinery/pkg/util/net" @@ -37,6 +43,17 @@ var ( _ http.RoundTripper = &RoundTripper{} ) +var ( + statusScheme = runtime.NewScheme() + statusCodecs = serializer.NewCodecFactory(statusScheme) +) + +func init() { + statusScheme.AddUnversionedTypes(metav1.SchemeGroupVersion, + &metav1.Status{}, + ) +} + // ConnectionHolder defines functions for structure providing // access to the websocket connection. type ConnectionHolder interface { @@ -110,12 +127,33 @@ func (rt *RoundTripper) RoundTrip(request *http.Request) (retResp *http.Response } wsConn, resp, err := dialer.DialContext(request.Context(), request.URL.String(), request.Header) if err != nil { + // BadHandshake error becomes an "UpgradeFailureError" (used for streaming fallback). if errors.Is(err, gwebsocket.ErrBadHandshake) { - // Enhance the error message with the response status if possible. + cause := err + // Enhance the error message with the error response if possible. if resp != nil && len(resp.Status) > 0 { - err = fmt.Errorf("%w (%s)", err, resp.Status) + defer resp.Body.Close() //nolint:errcheck + cause = fmt.Errorf("%w (%s)", err, resp.Status) // Always add the response status + responseError := "" + responseErrorBytes, readErr := io.ReadAll(io.LimitReader(resp.Body, 64*1024)) + if readErr != nil { + cause = fmt.Errorf("%w: unable to read error from server response", cause) + } else { + // If returned error can be decoded as "metav1.Status", return a "StatusError". + responseError = strings.TrimSpace(string(responseErrorBytes)) + if len(responseError) > 0 { + if obj, _, decodeErr := statusCodecs.UniversalDecoder().Decode(responseErrorBytes, nil, &metav1.Status{}); decodeErr == nil { + if status, ok := obj.(*metav1.Status); ok { + cause = &apierrors.StatusError{ErrStatus: *status} + } + } else { + // Otherwise, append the responseError string. + cause = fmt.Errorf("%w: %s", cause, responseError) + } + } + } } - return nil, &httpstream.UpgradeFailureError{Cause: err} + return nil, &httpstream.UpgradeFailureError{Cause: cause} } return nil, err } diff --git a/staging/src/k8s.io/client-go/transport/websocket/roundtripper_test.go b/staging/src/k8s.io/client-go/transport/websocket/roundtripper_test.go index 2010c73c34d..6c71b97976b 100644 --- a/staging/src/k8s.io/client-go/transport/websocket/roundtripper_test.go +++ b/staging/src/k8s.io/client-go/transport/websocket/roundtripper_test.go @@ -18,6 +18,7 @@ package websocket import ( "context" + "errors" "io" "net/http" "net/http/httptest" @@ -28,6 +29,9 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + apierrors "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/util/httpstream" "k8s.io/apimachinery/pkg/util/httpstream/wsstream" "k8s.io/apimachinery/pkg/util/remotecommand" @@ -64,31 +68,86 @@ func TestWebSocketRoundTripper_RoundTripperSucceeds(t *testing.T) { } func TestWebSocketRoundTripper_RoundTripperFails(t *testing.T) { - // Create fake WebSocket server. - websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - // Bad handshake means websocket server will not completely initialize. - _, err := webSocketServerStreams(req, w) - require.Error(t, err) - assert.ErrorContains(t, err, "websocket server finished before becoming ready") - })) - defer websocketServer.Close() + testCases := map[string]struct { + statusCode int + body string + status *metav1.Status + expectedError string + }{ + "Empty response status still returns basic websocket error": { + statusCode: -1, + body: "", + expectedError: "websocket: bad handshake", + }, + "Empty response body still returns status": { + statusCode: http.StatusForbidden, + body: "", + expectedError: "(403 Forbidden)", + }, + "Error response body returned as string when can not be cast as metav1.Status": { + statusCode: http.StatusForbidden, + body: "RBAC violated", + expectedError: "(403 Forbidden): RBAC violated", + }, + "Error returned as metav1.Status within response body": { + statusCode: http.StatusBadRequest, + body: "", + status: &metav1.Status{ + TypeMeta: metav1.TypeMeta{ + APIVersion: "meta.k8s.io/v1", + Kind: "Status", + }, + Status: "Failure", + Reason: "Unable to negotiate sub-protocol", + Code: http.StatusBadRequest, + }, + }, + } + encoder := statusCodecs.LegacyCodec(metav1.SchemeGroupVersion) + for testName, testCase := range testCases { + t.Run(testName, func(t *testing.T) { + // Create fake WebSocket server. + websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + if testCase.statusCode > 0 { + w.WriteHeader(testCase.statusCode) + } + if testCase.status != nil { + statusBytes, err := runtime.Encode(encoder, testCase.status) + require.NoError(t, err) + _, err = w.Write(statusBytes) + require.NoError(t, err) + } else if len(testCase.body) > 0 { + _, err := w.Write([]byte(testCase.body)) + require.NoError(t, err) + } + })) + defer websocketServer.Close() - // Create the wrapped roundtripper and websocket upgrade roundtripper and call "RoundTrip()". - websocketLocation, err := url.Parse(websocketServer.URL) - require.NoError(t, err) - req, err := http.NewRequestWithContext(context.Background(), "GET", websocketServer.URL, nil) - require.NoError(t, err) - rt, _, err := RoundTripperFor(&restclient.Config{Host: websocketLocation.Host}) - require.NoError(t, err) - // Requested subprotocol version 1 is not supported by test websocket server. - requestedProtocol := remotecommand.StreamProtocolV1Name - req.Header[wsstream.WebSocketProtocolHeader] = []string{requestedProtocol} - _, err = rt.RoundTrip(req) - // Ensure a "bad handshake" error is returned, since requested protocol is not supported. - require.Error(t, err) - assert.ErrorContains(t, err, "websocket: bad handshake") - assert.ErrorContains(t, err, "403 Forbidden") - assert.True(t, httpstream.IsUpgradeFailure(err)) + // Create the wrapped roundtripper and websocket upgrade roundtripper and call "RoundTrip()". + websocketLocation, err := url.Parse(websocketServer.URL) + require.NoError(t, err) + req, err := http.NewRequestWithContext(context.Background(), "GET", websocketServer.URL, nil) + require.NoError(t, err) + rt, _, err := RoundTripperFor(&restclient.Config{Host: websocketLocation.Host}) + require.NoError(t, err) + _, err = rt.RoundTrip(req) + require.Error(t, err) + assert.True(t, httpstream.IsUpgradeFailure(err)) + if testCase.status != nil { + upgradeErr := &httpstream.UpgradeFailureError{} + validErr := errors.As(err, &upgradeErr) + assert.True(t, validErr, "could not cast error as httpstream.UpgradeFailureError") + statusErr := upgradeErr.Cause + apiErr := &apierrors.StatusError{} + validErr = errors.As(statusErr, &apiErr) + assert.True(t, validErr, "could not cast error as apierrors.StatusError") + assert.Equal(t, *testCase.status, apiErr.ErrStatus) + } else { + assert.Contains(t, err.Error(), testCase.expectedError, + "expected (%s), got (%s)", testCase.expectedError, err.Error()) + } + }) + } } func TestWebSocketRoundTripper_NegotiateCreatesConnection(t *testing.T) {