wsstream: use a single approach to detect connection upgrade

Signed-off-by: Monis Khan <mok@microsoft.com>
This commit is contained in:
Monis Khan 2023-08-01 18:37:34 -04:00
parent f0dcf06140
commit 62b063b74b
No known key found for this signature in database
2 changed files with 22 additions and 8 deletions

View File

@ -21,14 +21,14 @@ import (
"fmt"
"io"
"net/http"
"regexp"
"strings"
"time"
"golang.org/x/net/websocket"
"k8s.io/klog/v2"
"k8s.io/apimachinery/pkg/util/httpstream"
"k8s.io/apimachinery/pkg/util/runtime"
"k8s.io/klog/v2"
)
// The Websocket subprotocol "channel.k8s.io" prepends each binary message with a byte indicating
@ -77,18 +77,13 @@ const (
ReadWriteChannel
)
var (
// connectionUpgradeRegex matches any Connection header value that includes upgrade
connectionUpgradeRegex = regexp.MustCompile("(^|.*,\\s*)upgrade($|\\s*,)")
)
// IsWebSocketRequest returns true if the incoming request contains connection upgrade headers
// for WebSockets.
func IsWebSocketRequest(req *http.Request) bool {
if !strings.EqualFold(req.Header.Get("Upgrade"), "websocket") {
return false
}
return connectionUpgradeRegex.MatchString(strings.ToLower(req.Header.Get("Connection")))
return httpstream.IsUpgradeRequest(req)
}
// IgnoreReceives reads from a WebSocket until it is closed, then returns. If timeout is set, the

View File

@ -46,6 +46,25 @@ func TestAuthenticateRequest(t *testing.T) {
}
}
func TestAuthenticateRequestMultipleConnectionHeaders(t *testing.T) {
auth := NewProtocolAuthenticator(authenticator.TokenFunc(func(ctx context.Context, token string) (*authenticator.Response, bool, error) {
if token != "token" {
t.Errorf("unexpected token: %s", token)
}
return &authenticator.Response{User: &user.DefaultInfo{Name: "user"}}, true, nil
}))
resp, ok, err := auth.AuthenticateRequest(&http.Request{
Header: http.Header{
"Connection": []string{"not", "upgrade"},
"Upgrade": []string{"websocket"},
"Sec-Websocket-Protocol": []string{"base64url.bearer.authorization.k8s.io.dG9rZW4,dummy"},
},
})
if !ok || resp == nil || err != nil {
t.Errorf("expected valid user")
}
}
func TestAuthenticateRequestTokenInvalid(t *testing.T) {
auth := NewProtocolAuthenticator(authenticator.TokenFunc(func(ctx context.Context, token string) (*authenticator.Response, bool, error) {
return nil, false, nil