diff --git a/pkg/controlplane/apiserver/config.go b/pkg/controlplane/apiserver/config.go index fa2ac67f5e2..c204e5058ef 100644 --- a/pkg/controlplane/apiserver/config.go +++ b/pkg/controlplane/apiserver/config.go @@ -94,9 +94,9 @@ type Extra struct { ExtendExpiration bool // ServiceAccountIssuerDiscovery - ServiceAccountIssuerURL string - ServiceAccountJWKSURI string - ServiceAccountPublicKeys []interface{} + ServiceAccountIssuerURL string + ServiceAccountJWKSURI string + ServiceAccountPublicKeysGetter serviceaccount.PublicKeysGetter SystemNamespaces []string @@ -368,18 +368,24 @@ func CreateConfig( return nil, nil, fmt.Errorf("failed to apply admission: %w", err) } - // Load and set the public keys. - var pubKeys []interface{} - for _, f := range opts.Authentication.ServiceAccounts.KeyFiles { - keys, err := keyutil.PublicKeysFromFile(f) - if err != nil { - return nil, nil, fmt.Errorf("failed to parse key file %q: %w", f, err) + if len(opts.Authentication.ServiceAccounts.KeyFiles) > 0 { + // Load and set the public keys. + var pubKeys []interface{} + for _, f := range opts.Authentication.ServiceAccounts.KeyFiles { + keys, err := keyutil.PublicKeysFromFile(f) + if err != nil { + return nil, nil, fmt.Errorf("failed to parse key file %q: %w", f, err) + } + pubKeys = append(pubKeys, keys...) } - pubKeys = append(pubKeys, keys...) + keysGetter, err := serviceaccount.StaticPublicKeysGetter(pubKeys) + if err != nil { + return nil, nil, fmt.Errorf("failed to set up public service account keys: %w", err) + } + config.ServiceAccountPublicKeysGetter = keysGetter } config.ServiceAccountIssuerURL = opts.Authentication.ServiceAccounts.Issuers[0] config.ServiceAccountJWKSURI = opts.Authentication.ServiceAccounts.JWKSURI - config.ServiceAccountPublicKeys = pubKeys return config, genericInitializers, nil } diff --git a/pkg/controlplane/apiserver/server.go b/pkg/controlplane/apiserver/server.go index 9e426062240..b22eb8ba562 100644 --- a/pkg/controlplane/apiserver/server.go +++ b/pkg/controlplane/apiserver/server.go @@ -93,13 +93,11 @@ func (c completedConfig) New(name string, delegationTarget genericapiserver.Dele routes.Logs{}.Install(generic.Handler.GoRestfulContainer) } - // Metadata and keys are expected to only change across restarts at present, - // so we just marshal immediately and serve the cached JSON bytes. - md, err := serviceaccount.NewOpenIDMetadata( + md, err := serviceaccount.NewOpenIDMetadataProvider( c.ServiceAccountIssuerURL, c.ServiceAccountJWKSURI, c.Generic.ExternalAddress, - c.ServiceAccountPublicKeys, + c.ServiceAccountPublicKeysGetter, ) if err != nil { // If there was an error, skip installing the endpoints and log the @@ -120,8 +118,7 @@ func (c completedConfig) New(name string, delegationTarget genericapiserver.Dele klog.Info(msg) } } else { - routes.NewOpenIDMetadataServer(md.ConfigJSON, md.PublicKeysetJSON). - Install(generic.Handler.GoRestfulContainer) + routes.NewOpenIDMetadataServer(md).Install(generic.Handler.GoRestfulContainer) } s := &Server{ diff --git a/pkg/kubeapiserver/authenticator/config.go b/pkg/kubeapiserver/authenticator/config.go index bd7ed0e1496..d0b1414c84b 100644 --- a/pkg/kubeapiserver/authenticator/config.go +++ b/pkg/kubeapiserver/authenticator/config.go @@ -62,7 +62,6 @@ type Config struct { AuthenticationConfig *apiserver.AuthenticationConfiguration AuthenticationConfigData string OIDCSigningAlgs []string - ServiceAccountKeyFiles []string ServiceAccountLookup bool ServiceAccountIssuers []string APIAudiences authenticator.Audiences @@ -79,7 +78,9 @@ type Config struct { RequestHeaderConfig *authenticatorfactory.RequestHeaderConfig - // TODO, this is the only non-serializable part of the entire config. Factor it out into a clientconfig + // ServiceAccountPublicKeysGetter returns public keys for verifying service account tokens. + ServiceAccountPublicKeysGetter serviceaccount.PublicKeysGetter + // ServiceAccountTokenGetter fetches API objects used to verify bound objects in service account token claims. ServiceAccountTokenGetter serviceaccount.ServiceAccountTokenGetter SecretsWriter typedv1core.SecretsGetter BootstrapTokenAuthenticator authenticator.Token @@ -127,15 +128,15 @@ func (config Config) New(serverLifecycle context.Context) (authenticator.Request } tokenAuthenticators = append(tokenAuthenticators, authenticator.WrapAudienceAgnosticToken(config.APIAudiences, tokenAuth)) } - if len(config.ServiceAccountKeyFiles) > 0 { - serviceAccountAuth, err := newLegacyServiceAccountAuthenticator(config.ServiceAccountKeyFiles, config.ServiceAccountLookup, config.APIAudiences, config.ServiceAccountTokenGetter, config.SecretsWriter) + if config.ServiceAccountPublicKeysGetter != nil { + serviceAccountAuth, err := newLegacyServiceAccountAuthenticator(config.ServiceAccountPublicKeysGetter, config.ServiceAccountLookup, config.APIAudiences, config.ServiceAccountTokenGetter, config.SecretsWriter) if err != nil { return nil, nil, nil, nil, err } tokenAuthenticators = append(tokenAuthenticators, serviceAccountAuth) } if len(config.ServiceAccountIssuers) > 0 { - serviceAccountAuth, err := newServiceAccountAuthenticator(config.ServiceAccountIssuers, config.ServiceAccountKeyFiles, config.APIAudiences, config.ServiceAccountTokenGetter) + serviceAccountAuth, err := newServiceAccountAuthenticator(config.ServiceAccountIssuers, config.ServiceAccountPublicKeysGetter, config.APIAudiences, config.ServiceAccountTokenGetter) if err != nil { return nil, nil, nil, nil, err } @@ -338,36 +339,25 @@ func newAuthenticatorFromTokenFile(tokenAuthFile string) (authenticator.Token, e } // newLegacyServiceAccountAuthenticator returns an authenticator.Token or an error -func newLegacyServiceAccountAuthenticator(keyfiles []string, lookup bool, apiAudiences authenticator.Audiences, serviceAccountGetter serviceaccount.ServiceAccountTokenGetter, secretsWriter typedv1core.SecretsGetter) (authenticator.Token, error) { - allPublicKeys := []interface{}{} - for _, keyfile := range keyfiles { - publicKeys, err := keyutil.PublicKeysFromFile(keyfile) - if err != nil { - return nil, err - } - allPublicKeys = append(allPublicKeys, publicKeys...) +func newLegacyServiceAccountAuthenticator(publicKeysGetter serviceaccount.PublicKeysGetter, lookup bool, apiAudiences authenticator.Audiences, serviceAccountGetter serviceaccount.ServiceAccountTokenGetter, secretsWriter typedv1core.SecretsGetter) (authenticator.Token, error) { + if publicKeysGetter == nil { + return nil, fmt.Errorf("no public key getter provided") } validator, err := serviceaccount.NewLegacyValidator(lookup, serviceAccountGetter, secretsWriter) if err != nil { return nil, fmt.Errorf("while creating legacy validator, err: %w", err) } - tokenAuthenticator := serviceaccount.JWTTokenAuthenticator([]string{serviceaccount.LegacyIssuer}, allPublicKeys, apiAudiences, validator) + tokenAuthenticator := serviceaccount.JWTTokenAuthenticator([]string{serviceaccount.LegacyIssuer}, publicKeysGetter, apiAudiences, validator) return tokenAuthenticator, nil } // newServiceAccountAuthenticator returns an authenticator.Token or an error -func newServiceAccountAuthenticator(issuers []string, keyfiles []string, apiAudiences authenticator.Audiences, serviceAccountGetter serviceaccount.ServiceAccountTokenGetter) (authenticator.Token, error) { - allPublicKeys := []interface{}{} - for _, keyfile := range keyfiles { - publicKeys, err := keyutil.PublicKeysFromFile(keyfile) - if err != nil { - return nil, err - } - allPublicKeys = append(allPublicKeys, publicKeys...) +func newServiceAccountAuthenticator(issuers []string, publicKeysGetter serviceaccount.PublicKeysGetter, apiAudiences authenticator.Audiences, serviceAccountGetter serviceaccount.ServiceAccountTokenGetter) (authenticator.Token, error) { + if publicKeysGetter == nil { + return nil, fmt.Errorf("no public key getter provided") } - - tokenAuthenticator := serviceaccount.JWTTokenAuthenticator(issuers, allPublicKeys, apiAudiences, serviceaccount.NewValidator(serviceAccountGetter)) + tokenAuthenticator := serviceaccount.JWTTokenAuthenticator(issuers, publicKeysGetter, apiAudiences, serviceaccount.NewValidator(serviceAccountGetter)) return tokenAuthenticator, nil } diff --git a/pkg/kubeapiserver/options/authentication.go b/pkg/kubeapiserver/options/authentication.go index 4b24f28e01f..6ad4af1e100 100644 --- a/pkg/kubeapiserver/options/authentication.go +++ b/pkg/kubeapiserver/options/authentication.go @@ -47,6 +47,7 @@ import ( "k8s.io/client-go/informers" "k8s.io/client-go/kubernetes" v1listers "k8s.io/client-go/listers/core/v1" + "k8s.io/client-go/util/keyutil" cliflag "k8s.io/component-base/cli/flag" "k8s.io/klog/v2" openapicommon "k8s.io/kube-openapi/pkg/common" @@ -54,6 +55,7 @@ import ( "k8s.io/kubernetes/pkg/features" kubeauthenticator "k8s.io/kubernetes/pkg/kubeapiserver/authenticator" authzmodes "k8s.io/kubernetes/pkg/kubeapiserver/authorizer/modes" + "k8s.io/kubernetes/pkg/serviceaccount" "k8s.io/kubernetes/pkg/util/filesystem" "k8s.io/kubernetes/plugin/pkg/auth/authenticator/token/bootstrap" "k8s.io/utils/pointer" @@ -559,7 +561,21 @@ func (o *BuiltInAuthenticationOptions) ToAuthenticationConfig() (kubeauthenticat if len(o.ServiceAccounts.Issuers) != 0 && len(o.APIAudiences) == 0 { ret.APIAudiences = authenticator.Audiences(o.ServiceAccounts.Issuers) } - ret.ServiceAccountKeyFiles = o.ServiceAccounts.KeyFiles + if len(o.ServiceAccounts.KeyFiles) > 0 { + allPublicKeys := []interface{}{} + for _, keyfile := range o.ServiceAccounts.KeyFiles { + publicKeys, err := keyutil.PublicKeysFromFile(keyfile) + if err != nil { + return kubeauthenticator.Config{}, err + } + allPublicKeys = append(allPublicKeys, publicKeys...) + } + keysGetter, err := serviceaccount.StaticPublicKeysGetter(allPublicKeys) + if err != nil { + return kubeauthenticator.Config{}, fmt.Errorf("failed to set up public service account keys: %w", err) + } + ret.ServiceAccountPublicKeysGetter = keysGetter + } ret.ServiceAccountIssuers = o.ServiceAccounts.Issuers ret.ServiceAccountLookup = o.ServiceAccounts.Lookup } diff --git a/pkg/routes/openidmetadata.go b/pkg/routes/openidmetadata.go index aafd76e1069..374cf4af9e9 100644 --- a/pkg/routes/openidmetadata.go +++ b/pkg/routes/openidmetadata.go @@ -17,6 +17,7 @@ limitations under the License. package routes import ( + "fmt" "net/http" restful "github.com/emicklei/go-restful/v3" @@ -34,7 +35,8 @@ const ( // cacheControl is the value of the Cache-Control header. Overrides the // global `private, no-cache` setting. headerCacheControl = "Cache-Control" - cacheControl = "public, max-age=3600" // 1 hour + + cacheControlTemplate = "public, max-age=%d" // mimeJWKS is the content type of the keyset response mimeJWKS = "application/jwk-set+json" @@ -42,18 +44,14 @@ const ( // OpenIDMetadataServer is an HTTP server for metadata of the KSA token issuer. type OpenIDMetadataServer struct { - configJSON []byte - keysetJSON []byte + provider serviceaccount.OpenIDMetadataProvider } // NewOpenIDMetadataServer creates a new OpenIDMetadataServer. // The issuer is the OIDC issuer; keys are the keys that may be used to sign // KSA tokens. -func NewOpenIDMetadataServer(configJSON, keysetJSON []byte) *OpenIDMetadataServer { - return &OpenIDMetadataServer{ - configJSON: configJSON, - keysetJSON: keysetJSON, - } +func NewOpenIDMetadataServer(provider serviceaccount.OpenIDMetadataProvider) *OpenIDMetadataServer { + return &OpenIDMetadataServer{provider: provider} } // Install adds this server to the request router c. @@ -95,19 +93,21 @@ func fromStandard(h http.HandlerFunc) restful.RouteFunction { } func (s *OpenIDMetadataServer) serveConfiguration(w http.ResponseWriter, req *http.Request) { + configJSON, maxAge := s.provider.GetConfigJSON() w.Header().Set(restful.HEADER_ContentType, restful.MIME_JSON) - w.Header().Set(headerCacheControl, cacheControl) - if _, err := w.Write(s.configJSON); err != nil { + w.Header().Set(headerCacheControl, fmt.Sprintf(cacheControlTemplate, maxAge)) + if _, err := w.Write(configJSON); err != nil { klog.Errorf("failed to write service account issuer metadata response: %v", err) return } } func (s *OpenIDMetadataServer) serveKeys(w http.ResponseWriter, req *http.Request) { + keysetJSON, maxAge := s.provider.GetKeysetJSON() // Per RFC7517 : https://tools.ietf.org/html/rfc7517#section-8.5.1 w.Header().Set(restful.HEADER_ContentType, mimeJWKS) - w.Header().Set(headerCacheControl, cacheControl) - if _, err := w.Write(s.keysetJSON); err != nil { + w.Header().Set(headerCacheControl, fmt.Sprintf(cacheControlTemplate, maxAge)) + if _, err := w.Write(keysetJSON); err != nil { klog.Errorf("failed to write service account issuer JWKS response: %v", err) return } diff --git a/pkg/serviceaccount/jwt.go b/pkg/serviceaccount/jwt.go index d6168e3411b..ac075bc6c90 100644 --- a/pkg/serviceaccount/jwt.go +++ b/pkg/serviceaccount/jwt.go @@ -225,22 +225,97 @@ func (j *jwtTokenGenerator) GenerateToken(claims *jwt.Claims, privateClaims inte // JWTTokenAuthenticator authenticates tokens as JWT tokens produced by JWTTokenGenerator // Token signatures are verified using each of the given public keys until one works (allowing key rotation) // If lookup is true, the service account and secret referenced as claims inside the token are retrieved and verified with the provided ServiceAccountTokenGetter -func JWTTokenAuthenticator[PrivateClaims any](issuers []string, keys []interface{}, implicitAuds authenticator.Audiences, validator Validator[PrivateClaims]) authenticator.Token { +func JWTTokenAuthenticator[PrivateClaims any](issuers []string, publicKeysGetter PublicKeysGetter, implicitAuds authenticator.Audiences, validator Validator[PrivateClaims]) authenticator.Token { issuersMap := make(map[string]bool) for _, issuer := range issuers { issuersMap[issuer] = true } return &jwtTokenAuthenticator[PrivateClaims]{ issuers: issuersMap, - keys: keys, + keysGetter: publicKeysGetter, implicitAuds: implicitAuds, validator: validator, } } +// Listener is an interface to use to notify interested parties of a change. +type Listener interface { + // Enqueue should be called when an input may have changed + Enqueue() +} + +// PublicKeysGetter returns public keys for a given key id. +type PublicKeysGetter interface { + // AddListener is adds a listener to be notified of potential input changes. + // This is a noop on static providers. + AddListener(listener Listener) + + // GetCacheAgeMaxSeconds returns the seconds a call to GetPublicKeys() can be cached for. + // If the results of GetPublicKeys() can be dynamic, this means a new key must be included in the results + // for at least this long before it is used to sign new tokens. + GetCacheAgeMaxSeconds() int + + // GetPublicKeys returns public keys to use for verifying a token with the given key id. + // keyIDHint may be empty if the token did not have a kid header, or if all public keys are desired. + GetPublicKeys(keyIDHint string) []PublicKey +} + +type PublicKey struct { + KeyID string + PublicKey interface{} +} + +type staticPublicKeysGetter struct { + allPublicKeys []PublicKey + publicKeysByID map[string][]PublicKey +} + +// StaticPublicKeysGetter constructs an implementation of PublicKeysGetter +// which returns all public keys when key id is unspecified, and returns +// the public keys matching the keyIDFromPublicKey-derived key id when +// a key id is specified. +func StaticPublicKeysGetter(keys []interface{}) (PublicKeysGetter, error) { + allPublicKeys := []PublicKey{} + publicKeysByID := map[string][]PublicKey{} + for _, key := range keys { + if privateKey, isPrivateKey := key.(publicKeyGetter); isPrivateKey { + // This is a private key. Extract its public key. + key = privateKey.Public() + } + + keyID, err := keyIDFromPublicKey(key) + if err != nil { + return nil, err + } + pk := PublicKey{PublicKey: key, KeyID: keyID} + publicKeysByID[keyID] = append(publicKeysByID[keyID], pk) + allPublicKeys = append(allPublicKeys, pk) + } + return &staticPublicKeysGetter{ + allPublicKeys: allPublicKeys, + publicKeysByID: publicKeysByID, + }, nil +} + +func (s staticPublicKeysGetter) AddListener(listener Listener) { + // no-op, static key content never changes +} + +func (s staticPublicKeysGetter) GetCacheAgeMaxSeconds() int { + // hard-coded to match cache max-age set in OIDC discovery + return 3600 +} + +func (s staticPublicKeysGetter) GetPublicKeys(keyID string) []PublicKey { + if len(keyID) == 0 { + return s.allPublicKeys + } + return s.publicKeysByID[keyID] +} + type jwtTokenAuthenticator[PrivateClaims any] struct { issuers map[string]bool - keys []interface{} + keysGetter PublicKeysGetter validator Validator[PrivateClaims] implicitAuds authenticator.Audiences } @@ -269,13 +344,25 @@ func (j *jwtTokenAuthenticator[PrivateClaims]) AuthenticateToken(ctx context.Con public := &jwt.Claims{} private := new(PrivateClaims) - // TODO: Pick the key that has the same key ID as `tok`, if one exists. + // Pick the key that has the same key ID as `tok`, if one exists. + var kid string + for _, header := range tok.Headers { + if header.KeyID != "" { + kid = header.KeyID + break + } + } + var ( found bool errlist []error ) - for _, key := range j.keys { - if err := tok.Claims(key, public, private); err != nil { + keys := j.keysGetter.GetPublicKeys(kid) + if len(keys) == 0 { + return nil, false, fmt.Errorf("invalid signature, no keys found") + } + for _, key := range keys { + if err := tok.Claims(key.PublicKey, public, private); err != nil { errlist = append(errlist, err) continue } diff --git a/pkg/serviceaccount/jwt_test.go b/pkg/serviceaccount/jwt_test.go index f8e41079eb4..446ed567fa9 100644 --- a/pkg/serviceaccount/jwt_test.go +++ b/pkg/serviceaccount/jwt_test.go @@ -247,7 +247,7 @@ func TestTokenGenerateAndValidate(t *testing.T) { Token: rsaToken, Client: nil, Keys: []interface{}{}, - ExpectedErr: false, + ExpectedErr: true, ExpectedOK: false, }, "invalid keys (rsa)": { @@ -385,7 +385,13 @@ func TestTokenGenerateAndValidate(t *testing.T) { if err != nil { t.Fatalf("While creating legacy validator, err: %v", err) } - authn := serviceaccount.JWTTokenAuthenticator([]string{serviceaccount.LegacyIssuer, "bar"}, tc.Keys, auds, validator) + staticKeysGetter, err := serviceaccount.StaticPublicKeysGetter(tc.Keys) + if err != nil { + t.Fatal(err) + } + keysGetter := &keyIDPrefixer{PublicKeysGetter: staticKeysGetter} + + authn := serviceaccount.JWTTokenAuthenticator([]string{serviceaccount.LegacyIssuer, "bar"}, keysGetter, auds, validator) // An invalid, non-JWT token should always fail ctx := authenticator.WithAudiences(context.Background(), auds) @@ -394,6 +400,16 @@ func TestTokenGenerateAndValidate(t *testing.T) { continue } + if tc.ExpectedOK { + // if authentication is otherwise expected to succeed, demonstrate changing key ids makes it fail + keysGetter.keyIDPrefix = "bogus" + if _, ok, err := authn.AuthenticateToken(ctx, tc.Token); err == nil || !strings.Contains(err.Error(), "no keys found") || ok { + t.Errorf("%s: Expected err containing 'no keys found', ok=false when key lookup by ID fails", k) + continue + } + keysGetter.keyIDPrefix = "" + } + resp, ok, err := authn.AuthenticateToken(ctx, tc.Token) if (err != nil) != tc.ExpectedErr { t.Errorf("%s: Expected error=%v, got %v", k, tc.ExpectedErr, err) @@ -424,6 +440,26 @@ func TestTokenGenerateAndValidate(t *testing.T) { } } +type keyIDPrefixer struct { + serviceaccount.PublicKeysGetter + keyIDPrefix string +} + +func (k *keyIDPrefixer) GetPublicKeys(keyIDHint string) []serviceaccount.PublicKey { + if k.keyIDPrefix == "" { + return k.PublicKeysGetter.GetPublicKeys(keyIDHint) + } + if keyIDHint != "" { + keyIDHint = k.keyIDPrefix + keyIDHint + } + var retval []serviceaccount.PublicKey + for _, key := range k.PublicKeysGetter.GetPublicKeys(keyIDHint) { + key.KeyID = k.keyIDPrefix + key.KeyID + retval = append(retval, key) + } + return retval +} + func checkJSONWebSignatureHasKeyID(t *testing.T, jwsString string, expectedKeyID string) { jws, err := jose.ParseSigned(jwsString) if err != nil { @@ -502,3 +538,76 @@ func generateECDSATokenWithMalformedIss(t *testing.T, serviceAccount *v1.Service return string(out) } + +func TestStaticPublicKeysGetter(t *testing.T) { + ecPrivate := getPrivateKey(ecdsaPrivateKey) + ecPublic := getPublicKey(ecdsaPublicKey) + rsaPublic := getPublicKey(rsaPublicKey) + + testcases := []struct { + Name string + Keys []interface{} + ExpectErr bool + ExpectKeys []serviceaccount.PublicKey + }{ + { + Name: "empty", + Keys: nil, + ExpectKeys: []serviceaccount.PublicKey{}, + }, + { + Name: "simple", + Keys: []interface{}{ecPublic, rsaPublic}, + ExpectKeys: []serviceaccount.PublicKey{ + {KeyID: "SoABiieYuNx4UdqYvZRVeuC6SihxgLrhLy9peHMHpTc", PublicKey: ecPublic}, + {KeyID: "JHJehTTTZlsspKHT-GaJxK7Kd1NQgZJu3fyK6K_QDYU", PublicKey: rsaPublic}, + }, + }, + { + Name: "private --> public", + Keys: []interface{}{ecPrivate}, + ExpectKeys: []serviceaccount.PublicKey{ + {KeyID: "SoABiieYuNx4UdqYvZRVeuC6SihxgLrhLy9peHMHpTc", PublicKey: ecPublic}, + }, + }, + { + Name: "invalid", + Keys: []interface{}{"bogus"}, + ExpectErr: true, + }, + } + + for _, tc := range testcases { + t.Run(tc.Name, func(t *testing.T) { + getter, err := serviceaccount.StaticPublicKeysGetter(tc.Keys) + if tc.ExpectErr { + if err == nil { + t.Fatal("expected construction error, got none") + } + return + } + if err != nil { + t.Fatalf("unexpected construction error: %v", err) + } + + bogusKeys := getter.GetPublicKeys("bogus") + if len(bogusKeys) != 0 { + t.Fatalf("unexpected bogus keys: %#v", bogusKeys) + } + + allKeys := getter.GetPublicKeys("") + if !reflect.DeepEqual(tc.ExpectKeys, allKeys) { + t.Fatalf("unexpected keys: %#v", allKeys) + } + for _, key := range allKeys { + keysByID := getter.GetPublicKeys(key.KeyID) + if len(keysByID) != 1 { + t.Fatalf("expected 1 key for id %s, got %d", key.KeyID, len(keysByID)) + } + if !reflect.DeepEqual(key, keysByID[0]) { + t.Fatalf("unexpected key for id %s", key.KeyID) + } + } + }) + } +} diff --git a/pkg/serviceaccount/keyid_test.go b/pkg/serviceaccount/keyid_test.go new file mode 100644 index 00000000000..4d19c2cd709 --- /dev/null +++ b/pkg/serviceaccount/keyid_test.go @@ -0,0 +1,49 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package serviceaccount + +import ( + "testing" + + "k8s.io/client-go/util/keyutil" +) + +const rsaPublicKey = `-----BEGIN PUBLIC KEY----- +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA249XwEo9k4tM8fMxV7zx +OhcrP+WvXn917koM5Qr2ZXs4vo26e4ytdlrV0bQ9SlcLpQVSYjIxNfhTZdDt+ecI +zshKuv1gKIxbbLQMOuK1eA/4HALyEkFgmS/tleLJrhc65tKPMGD+pKQ/xhmzRuCG +51RoiMgbQxaCyYxGfNLpLAZK9L0Tctv9a0mJmGIYnIOQM4kC1A1I1n3EsXMWmeJU +j7OTh/AjjCnMnkgvKT2tpKxYQ59PgDgU8Ssc7RDSmSkLxnrv+OrN80j6xrw0OjEi +B4Ycr0PqfzZcvy8efTtFQ/Jnc4Bp1zUtFXt7+QeevePtQ2EcyELXE0i63T1CujRM +WwIDAQAB +-----END PUBLIC KEY----- +` + +func TestKeyIDStability(t *testing.T) { + keys, err := keyutil.ParsePublicKeysPEM([]byte(rsaPublicKey)) + if err != nil { + t.Fatal(err) + } + keyID, err := keyIDFromPublicKey(keys[0]) + if err != nil { + t.Fatal(err) + } + // The derived key id for a given public key must not change or validation of previously issued tokens will fail to find associated keys + if expected, actual := "JHJehTTTZlsspKHT-GaJxK7Kd1NQgZJu3fyK6K_QDYU", keyID; expected != actual { + t.Fatalf("expected stable key id %q, got %q", expected, actual) + } +} diff --git a/pkg/serviceaccount/openidmetadata.go b/pkg/serviceaccount/openidmetadata.go index 56ec23d118a..9a58f967be2 100644 --- a/pkg/serviceaccount/openidmetadata.go +++ b/pkg/serviceaccount/openidmetadata.go @@ -24,11 +24,13 @@ import ( "encoding/json" "fmt" "net/url" + "sync/atomic" jose "gopkg.in/square/go-jose.v2" "k8s.io/apimachinery/pkg/util/errors" "k8s.io/apimachinery/pkg/util/sets" + "k8s.io/klog/v2" ) const ( @@ -44,26 +46,68 @@ const ( JWKSPath = "/openid/v1/jwks" ) -// OpenIDMetadata contains the pre-rendered responses for OIDC discovery endpoints. -type OpenIDMetadata struct { - ConfigJSON []byte - PublicKeysetJSON []byte +// OpenIDMetadataProvider returns pre-rendered responses for OIDC discovery endpoints. +type OpenIDMetadataProvider interface { + GetConfigJSON() (json []byte, maxAge int) + GetKeysetJSON() (json []byte, maxAge int) } -// NewOpenIDMetadata returns the pre-rendered JSON responses for the OIDC discovery +type openidConfigProvider struct { + issuerURL, jwksURI string + pubKeyGetter PublicKeysGetter + config atomic.Pointer[openidConfig] +} +type openidConfig struct { + configJSON []byte + keysetJSON []byte +} + +func (p *openidConfigProvider) GetConfigJSON() ([]byte, int) { + return p.config.Load().configJSON, p.pubKeyGetter.GetCacheAgeMaxSeconds() +} +func (p *openidConfigProvider) GetKeysetJSON() ([]byte, int) { + return p.config.Load().keysetJSON, p.pubKeyGetter.GetCacheAgeMaxSeconds() +} +func (p *openidConfigProvider) Enqueue() { + err := p.Update() + if err != nil { + klog.ErrorS(err, "failed to update openid config metadata") + } +} +func (p *openidConfigProvider) Update() error { + pubKeys := p.pubKeyGetter.GetPublicKeys("") + if len(pubKeys) == 0 { + return fmt.Errorf("no keys provided for validating keyset") + } + configJSON, err := openIDConfigJSON(p.issuerURL, p.jwksURI, pubKeys) + if err != nil { + return fmt.Errorf("could not marshal issuer discovery JSON, error: %w", err) + } + keysetJSON, err := openIDKeysetJSON(pubKeys) + if err != nil { + return fmt.Errorf("could not marshal issuer keys JSON, error: %w", err) + } + p.config.Store(&openidConfig{ + configJSON: configJSON, + keysetJSON: keysetJSON, + }) + return nil +} + +// NewOpenIDMetadataProvider returns a provider for the OIDC discovery // endpoints, or an error if they could not be constructed. Callers should note // that this function may perform additional validation on inputs that is not // backwards-compatible with all command-line validation. The recommendation is // to log the error and skip installing the OIDC discovery endpoints. -func NewOpenIDMetadata(issuerURL, jwksURI, defaultExternalAddress string, pubKeys []interface{}) (*OpenIDMetadata, error) { +func NewOpenIDMetadataProvider(issuerURL, jwksURI, defaultExternalAddress string, pubKeyGetter PublicKeysGetter) (OpenIDMetadataProvider, error) { if issuerURL == "" { return nil, fmt.Errorf("empty issuer URL") } if jwksURI == "" && defaultExternalAddress == "" { return nil, fmt.Errorf("either the JWKS URI or the default external address, or both, must be set") } - if len(pubKeys) == 0 { - return nil, fmt.Errorf("no keys provided for validating keyset") + if pubKeyGetter == nil { + return nil, fmt.Errorf("no public key getter provided") } // Ensure the issuer URL meets the OIDC spec (this is the additional @@ -126,20 +170,18 @@ func NewOpenIDMetadata(issuerURL, jwksURI, defaultExternalAddress string, pubKey } } - configJSON, err := openIDConfigJSON(issuerURL, jwksURI, pubKeys) - if err != nil { - return nil, fmt.Errorf("could not marshal issuer discovery JSON, error: %v", err) + provider := &openidConfigProvider{ + issuerURL: issuerURL, + jwksURI: jwksURI, + pubKeyGetter: pubKeyGetter, } - - keysetJSON, err := openIDKeysetJSON(pubKeys) - if err != nil { - return nil, fmt.Errorf("could not marshal issuer keys JSON, error: %v", err) + // Register to be notified if public keys change + pubKeyGetter.AddListener(provider) + // Synchronously construct the config / keyset json once at startup to ensure a successful starting point + if err := provider.Update(); err != nil { + return nil, err } - - return &OpenIDMetadata{ - ConfigJSON: configJSON, - PublicKeysetJSON: keysetJSON, - }, nil + return provider, nil } // openIDMetadata provides a minimal subset of OIDC provider metadata: @@ -159,7 +201,7 @@ type openIDMetadata struct { // openIDConfigJSON returns the JSON OIDC Discovery Doc for the service // account issuer. -func openIDConfigJSON(iss, jwksURI string, keys []interface{}) ([]byte, error) { +func openIDConfigJSON(iss, jwksURI string, keys []PublicKey) ([]byte, error) { keyset, errs := publicJWKSFromKeys(keys) if errs != nil { return nil, errs @@ -183,7 +225,7 @@ func openIDConfigJSON(iss, jwksURI string, keys []interface{}) ([]byte, error) { // openIDKeysetJSON returns the JSON Web Key Set for the service account // issuer's keys. -func openIDKeysetJSON(keys []interface{}) ([]byte, error) { +func openIDKeysetJSON(keys []PublicKey) ([]byte, error) { keyset, errs := publicJWKSFromKeys(keys) if errs != nil { return nil, errs @@ -212,21 +254,12 @@ type publicKeyGetter interface { // publicJWKSFromKeys constructs a JSONWebKeySet from a list of keys. The key // set will only contain the public keys associated with the input keys. -func publicJWKSFromKeys(in []interface{}) (*jose.JSONWebKeySet, errors.Aggregate) { +func publicJWKSFromKeys(in []PublicKey) (*jose.JSONWebKeySet, errors.Aggregate) { // Decode keys into a JWKS. var keys jose.JSONWebKeySet var errs []error for i, key := range in { - var pubkey *jose.JSONWebKey - var err error - - switch k := key.(type) { - case publicKeyGetter: - // This is a private key. Get its public key - pubkey, err = jwkFromPublicKey(k.Public()) - default: - pubkey, err = jwkFromPublicKey(k) - } + pubkey, err := jwkFromPublicKey(key) if err != nil { errs = append(errs, fmt.Errorf("error constructing JWK for key #%d: %v", i, err)) continue @@ -244,21 +277,16 @@ func publicJWKSFromKeys(in []interface{}) (*jose.JSONWebKeySet, errors.Aggregate return &keys, nil } -func jwkFromPublicKey(publicKey crypto.PublicKey) (*jose.JSONWebKey, error) { - alg, err := algorithmFromPublicKey(publicKey) - if err != nil { - return nil, err - } - - keyID, err := keyIDFromPublicKey(publicKey) +func jwkFromPublicKey(publicKey PublicKey) (*jose.JSONWebKey, error) { + alg, err := algorithmFromPublicKey(publicKey.PublicKey) if err != nil { return nil, err } jwk := &jose.JSONWebKey{ Algorithm: string(alg), - Key: publicKey, - KeyID: keyID, + Key: publicKey.PublicKey, + KeyID: publicKey.KeyID, Use: "sig", } diff --git a/pkg/serviceaccount/openidmetadata_test.go b/pkg/serviceaccount/openidmetadata_test.go index 00331e8560b..365af2c7241 100644 --- a/pkg/serviceaccount/openidmetadata_test.go +++ b/pkg/serviceaccount/openidmetadata_test.go @@ -39,7 +39,7 @@ const ( exampleIssuer = "https://issuer.example.com" ) -func setupServer(t *testing.T, iss string, keys []interface{}) (*httptest.Server, string) { +func setupServer(t *testing.T, iss string, keys serviceaccount.PublicKeysGetter) (*httptest.Server, string) { t.Helper() c := restful.NewContainer() @@ -53,13 +53,13 @@ func setupServer(t *testing.T, iss string, keys []interface{}) (*httptest.Server jwksURI.Scheme = "https" jwksURI.Path = serviceaccount.JWKSPath - md, err := serviceaccount.NewOpenIDMetadata( + md, err := serviceaccount.NewOpenIDMetadataProvider( iss, jwksURI.String(), "", keys) if err != nil { t.Fatal(err) } - srv := routes.NewOpenIDMetadataServer(md.ConfigJSON, md.PublicKeysetJSON) + srv := routes.NewOpenIDMetadataServer(md) srv.Install(c) return s, jwksURI.String() @@ -77,20 +77,59 @@ type Configuration struct { SubjectTypes []string `json:"subject_types_supported"` } +type proxyKeyGetter struct { + serviceaccount.PublicKeysGetter + listeners []serviceaccount.Listener +} + +func (p *proxyKeyGetter) AddListener(listener serviceaccount.Listener) { + p.listeners = append(p.listeners, listener) + p.PublicKeysGetter.AddListener(listener) +} + func TestServeConfiguration(t *testing.T) { - s, jwksURI := setupServer(t, exampleIssuer, defaultKeys) + ecKeysGetter, err := serviceaccount.StaticPublicKeysGetter([]interface{}{getPublicKey(ecdsaPublicKey)}) + if err != nil { + t.Fatal(err) + } + rsaKeysGetter, err := serviceaccount.StaticPublicKeysGetter([]interface{}{getPublicKey(rsaPublicKey)}) + if err != nil { + t.Fatal(err) + } + keysGetter := &proxyKeyGetter{PublicKeysGetter: ecKeysGetter} + s, jwksURI := setupServer(t, exampleIssuer, keysGetter) defer s.Close() - want := Configuration{ + wantEC := Configuration{ Issuer: exampleIssuer, JWKSURI: jwksURI, ResponseTypes: []string{"id_token"}, SubjectTypes: []string{"public"}, - SigningAlgs: []string{"ES256", "RS256"}, + SigningAlgs: []string{"ES256"}, + } + wantRSA := Configuration{ + Issuer: exampleIssuer, + JWKSURI: jwksURI, + ResponseTypes: []string{"id_token"}, + SubjectTypes: []string{"public"}, + SigningAlgs: []string{"RS256"}, } - reqURL := s.URL + "/.well-known/openid-configuration" + expectConfiguration(t, reqURL, wantEC) + + // modify the underlying keys, expect the same response + keysGetter.PublicKeysGetter = rsaKeysGetter + expectConfiguration(t, reqURL, wantEC) + + // notify the metadata the keys changed, expected a modified response + for _, listener := range keysGetter.listeners { + listener.Enqueue() + } + expectConfiguration(t, reqURL, wantRSA) +} + +func expectConfiguration(t *testing.T, reqURL string, want Configuration) { resp, err := http.Get(reqURL) if err != nil { t.Fatalf("Get(%s) = %v, %v want: , ", reqURL, resp, err) @@ -185,47 +224,82 @@ func TestServeKeys(t *testing.T) { for _, tt := range serveKeysTests { t.Run(tt.Name, func(t *testing.T) { - s, _ := setupServer(t, exampleIssuer, tt.Keys) + initialKeysGetter, err := serviceaccount.StaticPublicKeysGetter(tt.Keys) + if err != nil { + t.Fatal(err) + } + updatedKeysGetter, err := serviceaccount.StaticPublicKeysGetter([]interface{}{wantPubRSA}) + if err != nil { + t.Fatal(err) + } + keysGetter := &proxyKeyGetter{PublicKeysGetter: initialKeysGetter} + s, _ := setupServer(t, exampleIssuer, keysGetter) defer s.Close() reqURL := s.URL + "/openid/v1/jwks" + expectKeys(t, reqURL, tt.WantKeys) - resp, err := http.Get(reqURL) - if err != nil { - t.Fatalf("Get(%s) = %v, %v want: , ", reqURL, resp, err) - } - defer resp.Body.Close() + // modify the underlying keys, expect the same response + keysGetter.PublicKeysGetter = updatedKeysGetter + expectKeys(t, reqURL, tt.WantKeys) - if resp.StatusCode != http.StatusOK { - t.Errorf("Get(%s) = %v, _ want: %v, _", reqURL, resp.StatusCode, http.StatusOK) - } - - if got, want := resp.Header.Get("Content-Type"), "application/jwk-set+json"; got != want { - t.Errorf("Get(%s) Content-Type = %q, _ want: %q, _", reqURL, got, want) - } - if got, want := resp.Header.Get("Cache-Control"), "public, max-age=3600"; got != want { - t.Errorf("Get(%s) Cache-Control = %q, _ want: %q, _", reqURL, got, want) - } - - ks := &jose.JSONWebKeySet{} - if err := json.NewDecoder(resp.Body).Decode(ks); err != nil { - t.Fatalf("Decode(_) = %v, want: ", err) - } - - bigIntComparer := cmp.Comparer( - func(x, y *big.Int) bool { - return x.Cmp(y) == 0 - }) - if !cmp.Equal(tt.WantKeys, ks.Keys, bigIntComparer) { - t.Errorf("unexpected diff in JWKS keys (-want, +got): %v", - cmp.Diff(tt.WantKeys, ks.Keys, bigIntComparer)) + // notify the metadata the keys changed, expected a modified response + for _, listener := range keysGetter.listeners { + listener.Enqueue() } + expectKeys(t, reqURL, []jose.JSONWebKey{{ + Algorithm: "RS256", + Key: wantPubRSA, + KeyID: rsaKeyID, + Use: "sig", + Certificates: []*x509.Certificate{}, + CertificateThumbprintSHA1: []uint8{}, + CertificateThumbprintSHA256: []uint8{}, + }}) }) } } +func expectKeys(t *testing.T, reqURL string, wantKeys []jose.JSONWebKey) { + resp, err := http.Get(reqURL) + if err != nil { + t.Fatalf("Get(%s) = %v, %v want: , ", reqURL, resp, err) + } + defer func() { + _ = resp.Body.Close() + }() + + if resp.StatusCode != http.StatusOK { + t.Errorf("Get(%s) = %v, _ want: %v, _", reqURL, resp.StatusCode, http.StatusOK) + } + + if got, want := resp.Header.Get("Content-Type"), "application/jwk-set+json"; got != want { + t.Errorf("Get(%s) Content-Type = %q, _ want: %q, _", reqURL, got, want) + } + if got, want := resp.Header.Get("Cache-Control"), "public, max-age=3600"; got != want { + t.Errorf("Get(%s) Cache-Control = %q, _ want: %q, _", reqURL, got, want) + } + + ks := &jose.JSONWebKeySet{} + if err := json.NewDecoder(resp.Body).Decode(ks); err != nil { + t.Fatalf("Decode(_) = %v, want: ", err) + } + + bigIntComparer := cmp.Comparer( + func(x, y *big.Int) bool { + return x.Cmp(y) == 0 + }) + if !cmp.Equal(wantKeys, ks.Keys, bigIntComparer) { + t.Errorf("unexpected diff in JWKS keys (-want, +got): %v", + cmp.Diff(wantKeys, ks.Keys, bigIntComparer)) + } +} func TestURLBoundaries(t *testing.T) { - s, _ := setupServer(t, exampleIssuer, defaultKeys) + keysGetter, err := serviceaccount.StaticPublicKeysGetter(defaultKeys) + if err != nil { + t.Fatal(err) + } + s, _ := setupServer(t, exampleIssuer, keysGetter) defer s.Close() for _, tt := range []struct { @@ -380,7 +454,11 @@ func TestNewOpenIDMetadata(t *testing.T) { } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - md, err := serviceaccount.NewOpenIDMetadata(tc.issuerURL, tc.jwksURI, tc.externalAddress, tc.keys) + keysGetter, err := serviceaccount.StaticPublicKeysGetter(tc.keys) + if err != nil { + t.Fatal(err) + } + md, err := serviceaccount.NewOpenIDMetadataProvider(tc.issuerURL, tc.jwksURI, tc.externalAddress, keysGetter) if tc.err { if err == nil { t.Fatalf("got , want error") @@ -390,13 +468,13 @@ func TestNewOpenIDMetadata(t *testing.T) { t.Fatalf("got error %v, want ", err) } - config := string(md.ConfigJSON) - keyset := string(md.PublicKeysetJSON) - if config != tc.wantConfig { - t.Errorf("got metadata %s, want %s", config, tc.wantConfig) + config, _ := md.GetConfigJSON() + keyset, _ := md.GetKeysetJSON() + if string(config) != tc.wantConfig { + t.Errorf("got metadata %s, want %s", string(config), tc.wantConfig) } - if keyset != tc.wantKeyset { - t.Errorf("got keyset %s, want %s", keyset, tc.wantKeyset) + if string(keyset) != tc.wantKeyset { + t.Errorf("got keyset %s, want %s", string(keyset), tc.wantKeyset) } }) }