diff --git a/staging/src/k8s.io/apiserver/pkg/authentication/request/websocket/BUILD b/staging/src/k8s.io/apiserver/pkg/authentication/request/websocket/BUILD new file mode 100644 index 00000000000..e2275409403 --- /dev/null +++ b/staging/src/k8s.io/apiserver/pkg/authentication/request/websocket/BUILD @@ -0,0 +1,31 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) + +load( + "@io_bazel_rules_go//go:def.bzl", + "go_library", + "go_test", +) + +go_library( + name = "go_default_library", + srcs = ["protocol.go"], + tags = ["automanaged"], + deps = [ + "//vendor/k8s.io/apiserver/pkg/authentication/authenticator:go_default_library", + "//vendor/k8s.io/apiserver/pkg/authentication/user:go_default_library", + "//vendor/k8s.io/apiserver/pkg/util/wsstream:go_default_library", + ], +) + +go_test( + name = "go_default_test", + srcs = ["protocol_test.go"], + library = ":go_default_library", + tags = ["automanaged"], + deps = [ + "//vendor/k8s.io/apiserver/pkg/authentication/authenticator:go_default_library", + "//vendor/k8s.io/apiserver/pkg/authentication/user:go_default_library", + ], +) diff --git a/staging/src/k8s.io/apiserver/pkg/authentication/request/websocket/protocol.go b/staging/src/k8s.io/apiserver/pkg/authentication/request/websocket/protocol.go new file mode 100644 index 00000000000..e007bf2d57f --- /dev/null +++ b/staging/src/k8s.io/apiserver/pkg/authentication/request/websocket/protocol.go @@ -0,0 +1,109 @@ +/* +Copyright 2017 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package websocket + +import ( + "encoding/base64" + "errors" + "net/http" + "net/textproto" + "strings" + "unicode/utf8" + + "k8s.io/apiserver/pkg/authentication/authenticator" + "k8s.io/apiserver/pkg/authentication/user" + "k8s.io/apiserver/pkg/util/wsstream" +) + +const bearerProtocolPrefix = "base64url.bearer.authorization.k8s.io." + +var protocolHeader = textproto.CanonicalMIMEHeaderKey("Sec-WebSocket-Protocol") + +var invalidToken = errors.New("invalid bearer token") + +// ProtocolAuthenticator allows a websocket connection to provide a bearer token as a subprotocol +// in the format "base64url.bearer.authorization." +type ProtocolAuthenticator struct { + // auth is the token authenticator to use to validate the token + auth authenticator.Token +} + +func NewProtocolAuthenticator(auth authenticator.Token) *ProtocolAuthenticator { + return &ProtocolAuthenticator{auth} +} + +func (a *ProtocolAuthenticator) AuthenticateRequest(req *http.Request) (user.Info, bool, error) { + // Only accept websocket connections + if !wsstream.IsWebSocketRequest(req) { + return nil, false, nil + } + + token := "" + sawTokenProtocol := false + filteredProtocols := []string{} + for _, protocolHeader := range req.Header[protocolHeader] { + for _, protocol := range strings.Split(protocolHeader, ",") { + protocol = strings.TrimSpace(protocol) + + if !strings.HasPrefix(protocol, bearerProtocolPrefix) { + filteredProtocols = append(filteredProtocols, protocol) + continue + } + + if sawTokenProtocol { + return nil, false, errors.New("multiple base64.bearer.authorization tokens specified") + } + sawTokenProtocol = true + + encodedToken := strings.TrimPrefix(protocol, bearerProtocolPrefix) + decodedToken, err := base64.RawURLEncoding.DecodeString(encodedToken) + if err != nil { + return nil, false, errors.New("invalid base64.bearer.authorization token encoding") + } + if !utf8.Valid(decodedToken) { + return nil, false, errors.New("invalid base64.bearer.authorization token") + } + token = string(decodedToken) + } + } + + // Must pass at least one other subprotocol so that we can remove the one containing the bearer token, + // and there is at least one to echo back to the client + if len(token) > 0 && len(filteredProtocols) == 0 { + return nil, false, errors.New("missing additional subprotocol") + } + + if len(token) == 0 { + return nil, false, nil + } + + user, ok, err := a.auth.AuthenticateToken(token) + + // on success, remove the protocol with the token + if ok { + // https://tools.ietf.org/html/rfc6455#section-11.3.4 indicates the Sec-WebSocket-Protocol header may appear multiple times + // in a request, and is logically the same as a single Sec-WebSocket-Protocol header field that contains all values + req.Header.Set(protocolHeader, strings.Join(filteredProtocols, ",")) + } + + // If the token authenticator didn't error, provide a default error + if !ok && err == nil { + err = invalidToken + } + + return user, ok, err +} diff --git a/staging/src/k8s.io/apiserver/pkg/authentication/request/websocket/protocol_test.go b/staging/src/k8s.io/apiserver/pkg/authentication/request/websocket/protocol_test.go new file mode 100644 index 00000000000..2a21aa65d92 --- /dev/null +++ b/staging/src/k8s.io/apiserver/pkg/authentication/request/websocket/protocol_test.go @@ -0,0 +1,222 @@ +/* +Copyright 2017 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package websocket + +import ( + "errors" + "net/http" + "reflect" + "testing" + + "k8s.io/apiserver/pkg/authentication/authenticator" + "k8s.io/apiserver/pkg/authentication/user" +) + +func TestAuthenticateRequest(t *testing.T) { + auth := NewProtocolAuthenticator(authenticator.TokenFunc(func(token string) (user.Info, bool, error) { + if token != "token" { + t.Errorf("unexpected token: %s", token) + } + return &user.DefaultInfo{Name: "user"}, true, nil + })) + user, ok, err := auth.AuthenticateRequest(&http.Request{ + Header: http.Header{ + "Connection": []string{"upgrade"}, + "Upgrade": []string{"websocket"}, + "Sec-Websocket-Protocol": []string{"base64url.bearer.authorization.k8s.io.dG9rZW4,dummy"}, + }, + }) + if !ok || user == nil || err != nil { + t.Errorf("expected valid user") + } +} + +func TestAuthenticateRequestTokenInvalid(t *testing.T) { + auth := NewProtocolAuthenticator(authenticator.TokenFunc(func(token string) (user.Info, bool, error) { + return nil, false, nil + })) + user, ok, err := auth.AuthenticateRequest(&http.Request{ + Header: http.Header{ + "Connection": []string{"upgrade"}, + "Upgrade": []string{"websocket"}, + "Sec-Websocket-Protocol": []string{"base64url.bearer.authorization.k8s.io.dG9rZW4,dummy"}, + }, + }) + if ok || user != nil { + t.Errorf("expected not authenticated user") + } + if err != invalidToken { + t.Errorf("expected invalidToken error, got %v", err) + } +} + +func TestAuthenticateRequestTokenInvalidCustomError(t *testing.T) { + customError := errors.New("custom") + auth := NewProtocolAuthenticator(authenticator.TokenFunc(func(token string) (user.Info, bool, error) { + return nil, false, customError + })) + user, ok, err := auth.AuthenticateRequest(&http.Request{ + Header: http.Header{ + "Connection": []string{"upgrade"}, + "Upgrade": []string{"websocket"}, + "Sec-Websocket-Protocol": []string{"base64url.bearer.authorization.k8s.io.dG9rZW4,dummy"}, + }, + }) + if ok || user != nil { + t.Errorf("expected not authenticated user") + } + if err != customError { + t.Errorf("expected custom error, got %v", err) + } +} + +func TestAuthenticateRequestTokenError(t *testing.T) { + auth := NewProtocolAuthenticator(authenticator.TokenFunc(func(token string) (user.Info, bool, error) { + return nil, false, errors.New("error") + })) + user, ok, err := auth.AuthenticateRequest(&http.Request{ + Header: http.Header{ + "Connection": []string{"upgrade"}, + "Upgrade": []string{"websocket"}, + "Sec-Websocket-Protocol": []string{"base64url.bearer.authorization.k8s.io.dG9rZW4,dummy"}, + }, + }) + if ok || user != nil || err == nil { + t.Errorf("expected error") + } +} + +func TestAuthenticateRequestBadValue(t *testing.T) { + testCases := []struct { + Req *http.Request + }{ + {Req: &http.Request{}}, + {Req: &http.Request{Header: http.Header{ + "Connection": []string{"upgrade"}, + "Upgrade": []string{"websocket"}, + "Sec-Websocket-Protocol": []string{"other-protocol"}}}, + }, + {Req: &http.Request{Header: http.Header{ + "Connection": []string{"upgrade"}, + "Upgrade": []string{"websocket"}, + "Sec-Websocket-Protocol": []string{"base64url.bearer.authorization.k8s.io."}}}, + }, + } + for i, testCase := range testCases { + auth := NewProtocolAuthenticator(authenticator.TokenFunc(func(token string) (user.Info, bool, error) { + t.Errorf("authentication should not have been called") + return nil, false, nil + })) + user, ok, err := auth.AuthenticateRequest(testCase.Req) + if ok || user != nil || err != nil { + t.Errorf("%d: expected not authenticated (no token)", i) + } + } +} + +func TestBearerToken(t *testing.T) { + tests := map[string]struct { + ProtocolHeaders []string + TokenAuth authenticator.Token + + ExpectedUserName string + ExpectedOK bool + ExpectedErr bool + ExpectedProtocolHeaders []string + }{ + "no header": { + ProtocolHeaders: nil, + ExpectedUserName: "", + ExpectedOK: false, + ExpectedErr: false, + ExpectedProtocolHeaders: nil, + }, + "empty header": { + ProtocolHeaders: []string{""}, + ExpectedUserName: "", + ExpectedOK: false, + ExpectedErr: false, + ExpectedProtocolHeaders: []string{""}, + }, + "non-bearer header": { + ProtocolHeaders: []string{"undefined"}, + ExpectedUserName: "", + ExpectedOK: false, + ExpectedErr: false, + ExpectedProtocolHeaders: []string{"undefined"}, + }, + "empty bearer token": { + ProtocolHeaders: []string{"base64url.bearer.authorization.k8s.io."}, + ExpectedUserName: "", + ExpectedOK: false, + ExpectedErr: false, + ExpectedProtocolHeaders: []string{"base64url.bearer.authorization.k8s.io."}, + }, + "valid bearer token removing header": { + ProtocolHeaders: []string{"base64url.bearer.authorization.k8s.io.dG9rZW4", "dummy, dummy2"}, + TokenAuth: authenticator.TokenFunc(func(t string) (user.Info, bool, error) { return &user.DefaultInfo{Name: "myuser"}, true, nil }), + ExpectedUserName: "myuser", + ExpectedOK: true, + ExpectedErr: false, + ExpectedProtocolHeaders: []string{"dummy,dummy2"}, + }, + "invalid bearer token": { + ProtocolHeaders: []string{"base64url.bearer.authorization.k8s.io.dG9rZW4,dummy"}, + TokenAuth: authenticator.TokenFunc(func(t string) (user.Info, bool, error) { return nil, false, nil }), + ExpectedUserName: "", + ExpectedOK: false, + ExpectedErr: true, + ExpectedProtocolHeaders: []string{"base64url.bearer.authorization.k8s.io.dG9rZW4,dummy"}, + }, + "error bearer token": { + ProtocolHeaders: []string{"base64url.bearer.authorization.k8s.io.dG9rZW4,dummy"}, + TokenAuth: authenticator.TokenFunc(func(t string) (user.Info, bool, error) { return nil, false, errors.New("error") }), + ExpectedUserName: "", + ExpectedOK: false, + ExpectedErr: true, + ExpectedProtocolHeaders: []string{"base64url.bearer.authorization.k8s.io.dG9rZW4,dummy"}, + }, + } + + for k, tc := range tests { + req, _ := http.NewRequest("GET", "/", nil) + req.Header.Set("Connection", "upgrade") + req.Header.Set("Upgrade", "websocket") + for _, h := range tc.ProtocolHeaders { + req.Header.Add("Sec-Websocket-Protocol", h) + } + + bearerAuth := NewProtocolAuthenticator(tc.TokenAuth) + u, ok, err := bearerAuth.AuthenticateRequest(req) + if tc.ExpectedErr != (err != nil) { + t.Errorf("%s: Expected err=%v, got %v", k, tc.ExpectedErr, err) + continue + } + if ok != tc.ExpectedOK { + t.Errorf("%s: Expected ok=%v, got %v", k, tc.ExpectedOK, ok) + continue + } + if ok && u.GetName() != tc.ExpectedUserName { + t.Errorf("%s: Expected username=%v, got %v", k, tc.ExpectedUserName, u.GetName()) + continue + } + if !reflect.DeepEqual(req.Header["Sec-Websocket-Protocol"], tc.ExpectedProtocolHeaders) { + t.Errorf("%s: Expected headers=%#v, got %#v", k, tc.ExpectedProtocolHeaders, req.Header["Sec-Websocket-Protocol"]) + continue + } + } +} diff --git a/staging/src/k8s.io/apiserver/pkg/util/wsstream/conn.go b/staging/src/k8s.io/apiserver/pkg/util/wsstream/conn.go index f01638ad6d3..6f26b227579 100644 --- a/staging/src/k8s.io/apiserver/pkg/util/wsstream/conn.go +++ b/staging/src/k8s.io/apiserver/pkg/util/wsstream/conn.go @@ -87,7 +87,10 @@ var ( // IsWebSocketRequest returns true if the incoming request contains connection upgrade headers // for WebSockets. func IsWebSocketRequest(req *http.Request) bool { - return connectionUpgradeRegex.MatchString(strings.ToLower(req.Header.Get("Connection"))) && strings.ToLower(req.Header.Get("Upgrade")) == "websocket" + if !strings.EqualFold(req.Header.Get("Upgrade"), "websocket") { + return false + } + return connectionUpgradeRegex.MatchString(strings.ToLower(req.Header.Get("Connection"))) } // IgnoreReceives reads from a WebSocket until it is closed, then returns. If timeout is set, the