Merge pull request #125177 from liggitt/dynamic-public-key

Move public key serviceaccount getter to interface, filter by key id
This commit is contained in:
Kubernetes Prow Robot 2024-06-27 11:57:06 -07:00 committed by GitHub
commit ef1d28aa52
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 508 additions and 148 deletions

View File

@ -94,9 +94,9 @@ type Extra struct {
ExtendExpiration bool
// ServiceAccountIssuerDiscovery
ServiceAccountIssuerURL string
ServiceAccountJWKSURI string
ServiceAccountPublicKeys []interface{}
ServiceAccountIssuerURL string
ServiceAccountJWKSURI string
ServiceAccountPublicKeysGetter serviceaccount.PublicKeysGetter
SystemNamespaces []string
@ -368,18 +368,24 @@ func CreateConfig(
return nil, nil, fmt.Errorf("failed to apply admission: %w", err)
}
// Load and set the public keys.
var pubKeys []interface{}
for _, f := range opts.Authentication.ServiceAccounts.KeyFiles {
keys, err := keyutil.PublicKeysFromFile(f)
if err != nil {
return nil, nil, fmt.Errorf("failed to parse key file %q: %w", f, err)
if len(opts.Authentication.ServiceAccounts.KeyFiles) > 0 {
// Load and set the public keys.
var pubKeys []interface{}
for _, f := range opts.Authentication.ServiceAccounts.KeyFiles {
keys, err := keyutil.PublicKeysFromFile(f)
if err != nil {
return nil, nil, fmt.Errorf("failed to parse key file %q: %w", f, err)
}
pubKeys = append(pubKeys, keys...)
}
pubKeys = append(pubKeys, keys...)
keysGetter, err := serviceaccount.StaticPublicKeysGetter(pubKeys)
if err != nil {
return nil, nil, fmt.Errorf("failed to set up public service account keys: %w", err)
}
config.ServiceAccountPublicKeysGetter = keysGetter
}
config.ServiceAccountIssuerURL = opts.Authentication.ServiceAccounts.Issuers[0]
config.ServiceAccountJWKSURI = opts.Authentication.ServiceAccounts.JWKSURI
config.ServiceAccountPublicKeys = pubKeys
return config, genericInitializers, nil
}

View File

@ -93,13 +93,11 @@ func (c completedConfig) New(name string, delegationTarget genericapiserver.Dele
routes.Logs{}.Install(generic.Handler.GoRestfulContainer)
}
// Metadata and keys are expected to only change across restarts at present,
// so we just marshal immediately and serve the cached JSON bytes.
md, err := serviceaccount.NewOpenIDMetadata(
md, err := serviceaccount.NewOpenIDMetadataProvider(
c.ServiceAccountIssuerURL,
c.ServiceAccountJWKSURI,
c.Generic.ExternalAddress,
c.ServiceAccountPublicKeys,
c.ServiceAccountPublicKeysGetter,
)
if err != nil {
// If there was an error, skip installing the endpoints and log the
@ -120,8 +118,7 @@ func (c completedConfig) New(name string, delegationTarget genericapiserver.Dele
klog.Info(msg)
}
} else {
routes.NewOpenIDMetadataServer(md.ConfigJSON, md.PublicKeysetJSON).
Install(generic.Handler.GoRestfulContainer)
routes.NewOpenIDMetadataServer(md).Install(generic.Handler.GoRestfulContainer)
}
s := &Server{

View File

@ -62,7 +62,6 @@ type Config struct {
AuthenticationConfig *apiserver.AuthenticationConfiguration
AuthenticationConfigData string
OIDCSigningAlgs []string
ServiceAccountKeyFiles []string
ServiceAccountLookup bool
ServiceAccountIssuers []string
APIAudiences authenticator.Audiences
@ -79,7 +78,9 @@ type Config struct {
RequestHeaderConfig *authenticatorfactory.RequestHeaderConfig
// TODO, this is the only non-serializable part of the entire config. Factor it out into a clientconfig
// ServiceAccountPublicKeysGetter returns public keys for verifying service account tokens.
ServiceAccountPublicKeysGetter serviceaccount.PublicKeysGetter
// ServiceAccountTokenGetter fetches API objects used to verify bound objects in service account token claims.
ServiceAccountTokenGetter serviceaccount.ServiceAccountTokenGetter
SecretsWriter typedv1core.SecretsGetter
BootstrapTokenAuthenticator authenticator.Token
@ -127,15 +128,15 @@ func (config Config) New(serverLifecycle context.Context) (authenticator.Request
}
tokenAuthenticators = append(tokenAuthenticators, authenticator.WrapAudienceAgnosticToken(config.APIAudiences, tokenAuth))
}
if len(config.ServiceAccountKeyFiles) > 0 {
serviceAccountAuth, err := newLegacyServiceAccountAuthenticator(config.ServiceAccountKeyFiles, config.ServiceAccountLookup, config.APIAudiences, config.ServiceAccountTokenGetter, config.SecretsWriter)
if config.ServiceAccountPublicKeysGetter != nil {
serviceAccountAuth, err := newLegacyServiceAccountAuthenticator(config.ServiceAccountPublicKeysGetter, config.ServiceAccountLookup, config.APIAudiences, config.ServiceAccountTokenGetter, config.SecretsWriter)
if err != nil {
return nil, nil, nil, nil, err
}
tokenAuthenticators = append(tokenAuthenticators, serviceAccountAuth)
}
if len(config.ServiceAccountIssuers) > 0 {
serviceAccountAuth, err := newServiceAccountAuthenticator(config.ServiceAccountIssuers, config.ServiceAccountKeyFiles, config.APIAudiences, config.ServiceAccountTokenGetter)
serviceAccountAuth, err := newServiceAccountAuthenticator(config.ServiceAccountIssuers, config.ServiceAccountPublicKeysGetter, config.APIAudiences, config.ServiceAccountTokenGetter)
if err != nil {
return nil, nil, nil, nil, err
}
@ -338,36 +339,25 @@ func newAuthenticatorFromTokenFile(tokenAuthFile string) (authenticator.Token, e
}
// newLegacyServiceAccountAuthenticator returns an authenticator.Token or an error
func newLegacyServiceAccountAuthenticator(keyfiles []string, lookup bool, apiAudiences authenticator.Audiences, serviceAccountGetter serviceaccount.ServiceAccountTokenGetter, secretsWriter typedv1core.SecretsGetter) (authenticator.Token, error) {
allPublicKeys := []interface{}{}
for _, keyfile := range keyfiles {
publicKeys, err := keyutil.PublicKeysFromFile(keyfile)
if err != nil {
return nil, err
}
allPublicKeys = append(allPublicKeys, publicKeys...)
func newLegacyServiceAccountAuthenticator(publicKeysGetter serviceaccount.PublicKeysGetter, lookup bool, apiAudiences authenticator.Audiences, serviceAccountGetter serviceaccount.ServiceAccountTokenGetter, secretsWriter typedv1core.SecretsGetter) (authenticator.Token, error) {
if publicKeysGetter == nil {
return nil, fmt.Errorf("no public key getter provided")
}
validator, err := serviceaccount.NewLegacyValidator(lookup, serviceAccountGetter, secretsWriter)
if err != nil {
return nil, fmt.Errorf("while creating legacy validator, err: %w", err)
}
tokenAuthenticator := serviceaccount.JWTTokenAuthenticator([]string{serviceaccount.LegacyIssuer}, allPublicKeys, apiAudiences, validator)
tokenAuthenticator := serviceaccount.JWTTokenAuthenticator([]string{serviceaccount.LegacyIssuer}, publicKeysGetter, apiAudiences, validator)
return tokenAuthenticator, nil
}
// newServiceAccountAuthenticator returns an authenticator.Token or an error
func newServiceAccountAuthenticator(issuers []string, keyfiles []string, apiAudiences authenticator.Audiences, serviceAccountGetter serviceaccount.ServiceAccountTokenGetter) (authenticator.Token, error) {
allPublicKeys := []interface{}{}
for _, keyfile := range keyfiles {
publicKeys, err := keyutil.PublicKeysFromFile(keyfile)
if err != nil {
return nil, err
}
allPublicKeys = append(allPublicKeys, publicKeys...)
func newServiceAccountAuthenticator(issuers []string, publicKeysGetter serviceaccount.PublicKeysGetter, apiAudiences authenticator.Audiences, serviceAccountGetter serviceaccount.ServiceAccountTokenGetter) (authenticator.Token, error) {
if publicKeysGetter == nil {
return nil, fmt.Errorf("no public key getter provided")
}
tokenAuthenticator := serviceaccount.JWTTokenAuthenticator(issuers, allPublicKeys, apiAudiences, serviceaccount.NewValidator(serviceAccountGetter))
tokenAuthenticator := serviceaccount.JWTTokenAuthenticator(issuers, publicKeysGetter, apiAudiences, serviceaccount.NewValidator(serviceAccountGetter))
return tokenAuthenticator, nil
}

View File

@ -47,6 +47,7 @@ import (
"k8s.io/client-go/informers"
"k8s.io/client-go/kubernetes"
v1listers "k8s.io/client-go/listers/core/v1"
"k8s.io/client-go/util/keyutil"
cliflag "k8s.io/component-base/cli/flag"
"k8s.io/klog/v2"
openapicommon "k8s.io/kube-openapi/pkg/common"
@ -54,6 +55,7 @@ import (
"k8s.io/kubernetes/pkg/features"
kubeauthenticator "k8s.io/kubernetes/pkg/kubeapiserver/authenticator"
authzmodes "k8s.io/kubernetes/pkg/kubeapiserver/authorizer/modes"
"k8s.io/kubernetes/pkg/serviceaccount"
"k8s.io/kubernetes/pkg/util/filesystem"
"k8s.io/kubernetes/plugin/pkg/auth/authenticator/token/bootstrap"
"k8s.io/utils/pointer"
@ -559,7 +561,21 @@ func (o *BuiltInAuthenticationOptions) ToAuthenticationConfig() (kubeauthenticat
if len(o.ServiceAccounts.Issuers) != 0 && len(o.APIAudiences) == 0 {
ret.APIAudiences = authenticator.Audiences(o.ServiceAccounts.Issuers)
}
ret.ServiceAccountKeyFiles = o.ServiceAccounts.KeyFiles
if len(o.ServiceAccounts.KeyFiles) > 0 {
allPublicKeys := []interface{}{}
for _, keyfile := range o.ServiceAccounts.KeyFiles {
publicKeys, err := keyutil.PublicKeysFromFile(keyfile)
if err != nil {
return kubeauthenticator.Config{}, err
}
allPublicKeys = append(allPublicKeys, publicKeys...)
}
keysGetter, err := serviceaccount.StaticPublicKeysGetter(allPublicKeys)
if err != nil {
return kubeauthenticator.Config{}, fmt.Errorf("failed to set up public service account keys: %w", err)
}
ret.ServiceAccountPublicKeysGetter = keysGetter
}
ret.ServiceAccountIssuers = o.ServiceAccounts.Issuers
ret.ServiceAccountLookup = o.ServiceAccounts.Lookup
}

View File

@ -17,6 +17,7 @@ limitations under the License.
package routes
import (
"fmt"
"net/http"
restful "github.com/emicklei/go-restful/v3"
@ -34,7 +35,8 @@ const (
// cacheControl is the value of the Cache-Control header. Overrides the
// global `private, no-cache` setting.
headerCacheControl = "Cache-Control"
cacheControl = "public, max-age=3600" // 1 hour
cacheControlTemplate = "public, max-age=%d"
// mimeJWKS is the content type of the keyset response
mimeJWKS = "application/jwk-set+json"
@ -42,18 +44,14 @@ const (
// OpenIDMetadataServer is an HTTP server for metadata of the KSA token issuer.
type OpenIDMetadataServer struct {
configJSON []byte
keysetJSON []byte
provider serviceaccount.OpenIDMetadataProvider
}
// NewOpenIDMetadataServer creates a new OpenIDMetadataServer.
// The issuer is the OIDC issuer; keys are the keys that may be used to sign
// KSA tokens.
func NewOpenIDMetadataServer(configJSON, keysetJSON []byte) *OpenIDMetadataServer {
return &OpenIDMetadataServer{
configJSON: configJSON,
keysetJSON: keysetJSON,
}
func NewOpenIDMetadataServer(provider serviceaccount.OpenIDMetadataProvider) *OpenIDMetadataServer {
return &OpenIDMetadataServer{provider: provider}
}
// Install adds this server to the request router c.
@ -95,19 +93,21 @@ func fromStandard(h http.HandlerFunc) restful.RouteFunction {
}
func (s *OpenIDMetadataServer) serveConfiguration(w http.ResponseWriter, req *http.Request) {
configJSON, maxAge := s.provider.GetConfigJSON()
w.Header().Set(restful.HEADER_ContentType, restful.MIME_JSON)
w.Header().Set(headerCacheControl, cacheControl)
if _, err := w.Write(s.configJSON); err != nil {
w.Header().Set(headerCacheControl, fmt.Sprintf(cacheControlTemplate, maxAge))
if _, err := w.Write(configJSON); err != nil {
klog.Errorf("failed to write service account issuer metadata response: %v", err)
return
}
}
func (s *OpenIDMetadataServer) serveKeys(w http.ResponseWriter, req *http.Request) {
keysetJSON, maxAge := s.provider.GetKeysetJSON()
// Per RFC7517 : https://tools.ietf.org/html/rfc7517#section-8.5.1
w.Header().Set(restful.HEADER_ContentType, mimeJWKS)
w.Header().Set(headerCacheControl, cacheControl)
if _, err := w.Write(s.keysetJSON); err != nil {
w.Header().Set(headerCacheControl, fmt.Sprintf(cacheControlTemplate, maxAge))
if _, err := w.Write(keysetJSON); err != nil {
klog.Errorf("failed to write service account issuer JWKS response: %v", err)
return
}

View File

@ -225,22 +225,97 @@ func (j *jwtTokenGenerator) GenerateToken(claims *jwt.Claims, privateClaims inte
// JWTTokenAuthenticator authenticates tokens as JWT tokens produced by JWTTokenGenerator
// Token signatures are verified using each of the given public keys until one works (allowing key rotation)
// If lookup is true, the service account and secret referenced as claims inside the token are retrieved and verified with the provided ServiceAccountTokenGetter
func JWTTokenAuthenticator[PrivateClaims any](issuers []string, keys []interface{}, implicitAuds authenticator.Audiences, validator Validator[PrivateClaims]) authenticator.Token {
func JWTTokenAuthenticator[PrivateClaims any](issuers []string, publicKeysGetter PublicKeysGetter, implicitAuds authenticator.Audiences, validator Validator[PrivateClaims]) authenticator.Token {
issuersMap := make(map[string]bool)
for _, issuer := range issuers {
issuersMap[issuer] = true
}
return &jwtTokenAuthenticator[PrivateClaims]{
issuers: issuersMap,
keys: keys,
keysGetter: publicKeysGetter,
implicitAuds: implicitAuds,
validator: validator,
}
}
// Listener is an interface to use to notify interested parties of a change.
type Listener interface {
// Enqueue should be called when an input may have changed
Enqueue()
}
// PublicKeysGetter returns public keys for a given key id.
type PublicKeysGetter interface {
// AddListener is adds a listener to be notified of potential input changes.
// This is a noop on static providers.
AddListener(listener Listener)
// GetCacheAgeMaxSeconds returns the seconds a call to GetPublicKeys() can be cached for.
// If the results of GetPublicKeys() can be dynamic, this means a new key must be included in the results
// for at least this long before it is used to sign new tokens.
GetCacheAgeMaxSeconds() int
// GetPublicKeys returns public keys to use for verifying a token with the given key id.
// keyIDHint may be empty if the token did not have a kid header, or if all public keys are desired.
GetPublicKeys(keyIDHint string) []PublicKey
}
type PublicKey struct {
KeyID string
PublicKey interface{}
}
type staticPublicKeysGetter struct {
allPublicKeys []PublicKey
publicKeysByID map[string][]PublicKey
}
// StaticPublicKeysGetter constructs an implementation of PublicKeysGetter
// which returns all public keys when key id is unspecified, and returns
// the public keys matching the keyIDFromPublicKey-derived key id when
// a key id is specified.
func StaticPublicKeysGetter(keys []interface{}) (PublicKeysGetter, error) {
allPublicKeys := []PublicKey{}
publicKeysByID := map[string][]PublicKey{}
for _, key := range keys {
if privateKey, isPrivateKey := key.(publicKeyGetter); isPrivateKey {
// This is a private key. Extract its public key.
key = privateKey.Public()
}
keyID, err := keyIDFromPublicKey(key)
if err != nil {
return nil, err
}
pk := PublicKey{PublicKey: key, KeyID: keyID}
publicKeysByID[keyID] = append(publicKeysByID[keyID], pk)
allPublicKeys = append(allPublicKeys, pk)
}
return &staticPublicKeysGetter{
allPublicKeys: allPublicKeys,
publicKeysByID: publicKeysByID,
}, nil
}
func (s staticPublicKeysGetter) AddListener(listener Listener) {
// no-op, static key content never changes
}
func (s staticPublicKeysGetter) GetCacheAgeMaxSeconds() int {
// hard-coded to match cache max-age set in OIDC discovery
return 3600
}
func (s staticPublicKeysGetter) GetPublicKeys(keyID string) []PublicKey {
if len(keyID) == 0 {
return s.allPublicKeys
}
return s.publicKeysByID[keyID]
}
type jwtTokenAuthenticator[PrivateClaims any] struct {
issuers map[string]bool
keys []interface{}
keysGetter PublicKeysGetter
validator Validator[PrivateClaims]
implicitAuds authenticator.Audiences
}
@ -269,13 +344,25 @@ func (j *jwtTokenAuthenticator[PrivateClaims]) AuthenticateToken(ctx context.Con
public := &jwt.Claims{}
private := new(PrivateClaims)
// TODO: Pick the key that has the same key ID as `tok`, if one exists.
// Pick the key that has the same key ID as `tok`, if one exists.
var kid string
for _, header := range tok.Headers {
if header.KeyID != "" {
kid = header.KeyID
break
}
}
var (
found bool
errlist []error
)
for _, key := range j.keys {
if err := tok.Claims(key, public, private); err != nil {
keys := j.keysGetter.GetPublicKeys(kid)
if len(keys) == 0 {
return nil, false, fmt.Errorf("invalid signature, no keys found")
}
for _, key := range keys {
if err := tok.Claims(key.PublicKey, public, private); err != nil {
errlist = append(errlist, err)
continue
}

View File

@ -247,7 +247,7 @@ func TestTokenGenerateAndValidate(t *testing.T) {
Token: rsaToken,
Client: nil,
Keys: []interface{}{},
ExpectedErr: false,
ExpectedErr: true,
ExpectedOK: false,
},
"invalid keys (rsa)": {
@ -385,7 +385,13 @@ func TestTokenGenerateAndValidate(t *testing.T) {
if err != nil {
t.Fatalf("While creating legacy validator, err: %v", err)
}
authn := serviceaccount.JWTTokenAuthenticator([]string{serviceaccount.LegacyIssuer, "bar"}, tc.Keys, auds, validator)
staticKeysGetter, err := serviceaccount.StaticPublicKeysGetter(tc.Keys)
if err != nil {
t.Fatal(err)
}
keysGetter := &keyIDPrefixer{PublicKeysGetter: staticKeysGetter}
authn := serviceaccount.JWTTokenAuthenticator([]string{serviceaccount.LegacyIssuer, "bar"}, keysGetter, auds, validator)
// An invalid, non-JWT token should always fail
ctx := authenticator.WithAudiences(context.Background(), auds)
@ -394,6 +400,16 @@ func TestTokenGenerateAndValidate(t *testing.T) {
continue
}
if tc.ExpectedOK {
// if authentication is otherwise expected to succeed, demonstrate changing key ids makes it fail
keysGetter.keyIDPrefix = "bogus"
if _, ok, err := authn.AuthenticateToken(ctx, tc.Token); err == nil || !strings.Contains(err.Error(), "no keys found") || ok {
t.Errorf("%s: Expected err containing 'no keys found', ok=false when key lookup by ID fails", k)
continue
}
keysGetter.keyIDPrefix = ""
}
resp, ok, err := authn.AuthenticateToken(ctx, tc.Token)
if (err != nil) != tc.ExpectedErr {
t.Errorf("%s: Expected error=%v, got %v", k, tc.ExpectedErr, err)
@ -424,6 +440,26 @@ func TestTokenGenerateAndValidate(t *testing.T) {
}
}
type keyIDPrefixer struct {
serviceaccount.PublicKeysGetter
keyIDPrefix string
}
func (k *keyIDPrefixer) GetPublicKeys(keyIDHint string) []serviceaccount.PublicKey {
if k.keyIDPrefix == "" {
return k.PublicKeysGetter.GetPublicKeys(keyIDHint)
}
if keyIDHint != "" {
keyIDHint = k.keyIDPrefix + keyIDHint
}
var retval []serviceaccount.PublicKey
for _, key := range k.PublicKeysGetter.GetPublicKeys(keyIDHint) {
key.KeyID = k.keyIDPrefix + key.KeyID
retval = append(retval, key)
}
return retval
}
func checkJSONWebSignatureHasKeyID(t *testing.T, jwsString string, expectedKeyID string) {
jws, err := jose.ParseSigned(jwsString)
if err != nil {
@ -502,3 +538,76 @@ func generateECDSATokenWithMalformedIss(t *testing.T, serviceAccount *v1.Service
return string(out)
}
func TestStaticPublicKeysGetter(t *testing.T) {
ecPrivate := getPrivateKey(ecdsaPrivateKey)
ecPublic := getPublicKey(ecdsaPublicKey)
rsaPublic := getPublicKey(rsaPublicKey)
testcases := []struct {
Name string
Keys []interface{}
ExpectErr bool
ExpectKeys []serviceaccount.PublicKey
}{
{
Name: "empty",
Keys: nil,
ExpectKeys: []serviceaccount.PublicKey{},
},
{
Name: "simple",
Keys: []interface{}{ecPublic, rsaPublic},
ExpectKeys: []serviceaccount.PublicKey{
{KeyID: "SoABiieYuNx4UdqYvZRVeuC6SihxgLrhLy9peHMHpTc", PublicKey: ecPublic},
{KeyID: "JHJehTTTZlsspKHT-GaJxK7Kd1NQgZJu3fyK6K_QDYU", PublicKey: rsaPublic},
},
},
{
Name: "private --> public",
Keys: []interface{}{ecPrivate},
ExpectKeys: []serviceaccount.PublicKey{
{KeyID: "SoABiieYuNx4UdqYvZRVeuC6SihxgLrhLy9peHMHpTc", PublicKey: ecPublic},
},
},
{
Name: "invalid",
Keys: []interface{}{"bogus"},
ExpectErr: true,
},
}
for _, tc := range testcases {
t.Run(tc.Name, func(t *testing.T) {
getter, err := serviceaccount.StaticPublicKeysGetter(tc.Keys)
if tc.ExpectErr {
if err == nil {
t.Fatal("expected construction error, got none")
}
return
}
if err != nil {
t.Fatalf("unexpected construction error: %v", err)
}
bogusKeys := getter.GetPublicKeys("bogus")
if len(bogusKeys) != 0 {
t.Fatalf("unexpected bogus keys: %#v", bogusKeys)
}
allKeys := getter.GetPublicKeys("")
if !reflect.DeepEqual(tc.ExpectKeys, allKeys) {
t.Fatalf("unexpected keys: %#v", allKeys)
}
for _, key := range allKeys {
keysByID := getter.GetPublicKeys(key.KeyID)
if len(keysByID) != 1 {
t.Fatalf("expected 1 key for id %s, got %d", key.KeyID, len(keysByID))
}
if !reflect.DeepEqual(key, keysByID[0]) {
t.Fatalf("unexpected key for id %s", key.KeyID)
}
}
})
}
}

View File

@ -0,0 +1,49 @@
/*
Copyright 2024 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package serviceaccount
import (
"testing"
"k8s.io/client-go/util/keyutil"
)
const rsaPublicKey = `-----BEGIN PUBLIC KEY-----
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA249XwEo9k4tM8fMxV7zx
OhcrP+WvXn917koM5Qr2ZXs4vo26e4ytdlrV0bQ9SlcLpQVSYjIxNfhTZdDt+ecI
zshKuv1gKIxbbLQMOuK1eA/4HALyEkFgmS/tleLJrhc65tKPMGD+pKQ/xhmzRuCG
51RoiMgbQxaCyYxGfNLpLAZK9L0Tctv9a0mJmGIYnIOQM4kC1A1I1n3EsXMWmeJU
j7OTh/AjjCnMnkgvKT2tpKxYQ59PgDgU8Ssc7RDSmSkLxnrv+OrN80j6xrw0OjEi
B4Ycr0PqfzZcvy8efTtFQ/Jnc4Bp1zUtFXt7+QeevePtQ2EcyELXE0i63T1CujRM
WwIDAQAB
-----END PUBLIC KEY-----
`
func TestKeyIDStability(t *testing.T) {
keys, err := keyutil.ParsePublicKeysPEM([]byte(rsaPublicKey))
if err != nil {
t.Fatal(err)
}
keyID, err := keyIDFromPublicKey(keys[0])
if err != nil {
t.Fatal(err)
}
// The derived key id for a given public key must not change or validation of previously issued tokens will fail to find associated keys
if expected, actual := "JHJehTTTZlsspKHT-GaJxK7Kd1NQgZJu3fyK6K_QDYU", keyID; expected != actual {
t.Fatalf("expected stable key id %q, got %q", expected, actual)
}
}

View File

@ -24,11 +24,13 @@ import (
"encoding/json"
"fmt"
"net/url"
"sync/atomic"
jose "gopkg.in/square/go-jose.v2"
"k8s.io/apimachinery/pkg/util/errors"
"k8s.io/apimachinery/pkg/util/sets"
"k8s.io/klog/v2"
)
const (
@ -44,26 +46,68 @@ const (
JWKSPath = "/openid/v1/jwks"
)
// OpenIDMetadata contains the pre-rendered responses for OIDC discovery endpoints.
type OpenIDMetadata struct {
ConfigJSON []byte
PublicKeysetJSON []byte
// OpenIDMetadataProvider returns pre-rendered responses for OIDC discovery endpoints.
type OpenIDMetadataProvider interface {
GetConfigJSON() (json []byte, maxAge int)
GetKeysetJSON() (json []byte, maxAge int)
}
// NewOpenIDMetadata returns the pre-rendered JSON responses for the OIDC discovery
type openidConfigProvider struct {
issuerURL, jwksURI string
pubKeyGetter PublicKeysGetter
config atomic.Pointer[openidConfig]
}
type openidConfig struct {
configJSON []byte
keysetJSON []byte
}
func (p *openidConfigProvider) GetConfigJSON() ([]byte, int) {
return p.config.Load().configJSON, p.pubKeyGetter.GetCacheAgeMaxSeconds()
}
func (p *openidConfigProvider) GetKeysetJSON() ([]byte, int) {
return p.config.Load().keysetJSON, p.pubKeyGetter.GetCacheAgeMaxSeconds()
}
func (p *openidConfigProvider) Enqueue() {
err := p.Update()
if err != nil {
klog.ErrorS(err, "failed to update openid config metadata")
}
}
func (p *openidConfigProvider) Update() error {
pubKeys := p.pubKeyGetter.GetPublicKeys("")
if len(pubKeys) == 0 {
return fmt.Errorf("no keys provided for validating keyset")
}
configJSON, err := openIDConfigJSON(p.issuerURL, p.jwksURI, pubKeys)
if err != nil {
return fmt.Errorf("could not marshal issuer discovery JSON, error: %w", err)
}
keysetJSON, err := openIDKeysetJSON(pubKeys)
if err != nil {
return fmt.Errorf("could not marshal issuer keys JSON, error: %w", err)
}
p.config.Store(&openidConfig{
configJSON: configJSON,
keysetJSON: keysetJSON,
})
return nil
}
// NewOpenIDMetadataProvider returns a provider for the OIDC discovery
// endpoints, or an error if they could not be constructed. Callers should note
// that this function may perform additional validation on inputs that is not
// backwards-compatible with all command-line validation. The recommendation is
// to log the error and skip installing the OIDC discovery endpoints.
func NewOpenIDMetadata(issuerURL, jwksURI, defaultExternalAddress string, pubKeys []interface{}) (*OpenIDMetadata, error) {
func NewOpenIDMetadataProvider(issuerURL, jwksURI, defaultExternalAddress string, pubKeyGetter PublicKeysGetter) (OpenIDMetadataProvider, error) {
if issuerURL == "" {
return nil, fmt.Errorf("empty issuer URL")
}
if jwksURI == "" && defaultExternalAddress == "" {
return nil, fmt.Errorf("either the JWKS URI or the default external address, or both, must be set")
}
if len(pubKeys) == 0 {
return nil, fmt.Errorf("no keys provided for validating keyset")
if pubKeyGetter == nil {
return nil, fmt.Errorf("no public key getter provided")
}
// Ensure the issuer URL meets the OIDC spec (this is the additional
@ -126,20 +170,18 @@ func NewOpenIDMetadata(issuerURL, jwksURI, defaultExternalAddress string, pubKey
}
}
configJSON, err := openIDConfigJSON(issuerURL, jwksURI, pubKeys)
if err != nil {
return nil, fmt.Errorf("could not marshal issuer discovery JSON, error: %v", err)
provider := &openidConfigProvider{
issuerURL: issuerURL,
jwksURI: jwksURI,
pubKeyGetter: pubKeyGetter,
}
keysetJSON, err := openIDKeysetJSON(pubKeys)
if err != nil {
return nil, fmt.Errorf("could not marshal issuer keys JSON, error: %v", err)
// Register to be notified if public keys change
pubKeyGetter.AddListener(provider)
// Synchronously construct the config / keyset json once at startup to ensure a successful starting point
if err := provider.Update(); err != nil {
return nil, err
}
return &OpenIDMetadata{
ConfigJSON: configJSON,
PublicKeysetJSON: keysetJSON,
}, nil
return provider, nil
}
// openIDMetadata provides a minimal subset of OIDC provider metadata:
@ -159,7 +201,7 @@ type openIDMetadata struct {
// openIDConfigJSON returns the JSON OIDC Discovery Doc for the service
// account issuer.
func openIDConfigJSON(iss, jwksURI string, keys []interface{}) ([]byte, error) {
func openIDConfigJSON(iss, jwksURI string, keys []PublicKey) ([]byte, error) {
keyset, errs := publicJWKSFromKeys(keys)
if errs != nil {
return nil, errs
@ -183,7 +225,7 @@ func openIDConfigJSON(iss, jwksURI string, keys []interface{}) ([]byte, error) {
// openIDKeysetJSON returns the JSON Web Key Set for the service account
// issuer's keys.
func openIDKeysetJSON(keys []interface{}) ([]byte, error) {
func openIDKeysetJSON(keys []PublicKey) ([]byte, error) {
keyset, errs := publicJWKSFromKeys(keys)
if errs != nil {
return nil, errs
@ -212,21 +254,12 @@ type publicKeyGetter interface {
// publicJWKSFromKeys constructs a JSONWebKeySet from a list of keys. The key
// set will only contain the public keys associated with the input keys.
func publicJWKSFromKeys(in []interface{}) (*jose.JSONWebKeySet, errors.Aggregate) {
func publicJWKSFromKeys(in []PublicKey) (*jose.JSONWebKeySet, errors.Aggregate) {
// Decode keys into a JWKS.
var keys jose.JSONWebKeySet
var errs []error
for i, key := range in {
var pubkey *jose.JSONWebKey
var err error
switch k := key.(type) {
case publicKeyGetter:
// This is a private key. Get its public key
pubkey, err = jwkFromPublicKey(k.Public())
default:
pubkey, err = jwkFromPublicKey(k)
}
pubkey, err := jwkFromPublicKey(key)
if err != nil {
errs = append(errs, fmt.Errorf("error constructing JWK for key #%d: %v", i, err))
continue
@ -244,21 +277,16 @@ func publicJWKSFromKeys(in []interface{}) (*jose.JSONWebKeySet, errors.Aggregate
return &keys, nil
}
func jwkFromPublicKey(publicKey crypto.PublicKey) (*jose.JSONWebKey, error) {
alg, err := algorithmFromPublicKey(publicKey)
if err != nil {
return nil, err
}
keyID, err := keyIDFromPublicKey(publicKey)
func jwkFromPublicKey(publicKey PublicKey) (*jose.JSONWebKey, error) {
alg, err := algorithmFromPublicKey(publicKey.PublicKey)
if err != nil {
return nil, err
}
jwk := &jose.JSONWebKey{
Algorithm: string(alg),
Key: publicKey,
KeyID: keyID,
Key: publicKey.PublicKey,
KeyID: publicKey.KeyID,
Use: "sig",
}

View File

@ -39,7 +39,7 @@ const (
exampleIssuer = "https://issuer.example.com"
)
func setupServer(t *testing.T, iss string, keys []interface{}) (*httptest.Server, string) {
func setupServer(t *testing.T, iss string, keys serviceaccount.PublicKeysGetter) (*httptest.Server, string) {
t.Helper()
c := restful.NewContainer()
@ -53,13 +53,13 @@ func setupServer(t *testing.T, iss string, keys []interface{}) (*httptest.Server
jwksURI.Scheme = "https"
jwksURI.Path = serviceaccount.JWKSPath
md, err := serviceaccount.NewOpenIDMetadata(
md, err := serviceaccount.NewOpenIDMetadataProvider(
iss, jwksURI.String(), "", keys)
if err != nil {
t.Fatal(err)
}
srv := routes.NewOpenIDMetadataServer(md.ConfigJSON, md.PublicKeysetJSON)
srv := routes.NewOpenIDMetadataServer(md)
srv.Install(c)
return s, jwksURI.String()
@ -77,20 +77,59 @@ type Configuration struct {
SubjectTypes []string `json:"subject_types_supported"`
}
type proxyKeyGetter struct {
serviceaccount.PublicKeysGetter
listeners []serviceaccount.Listener
}
func (p *proxyKeyGetter) AddListener(listener serviceaccount.Listener) {
p.listeners = append(p.listeners, listener)
p.PublicKeysGetter.AddListener(listener)
}
func TestServeConfiguration(t *testing.T) {
s, jwksURI := setupServer(t, exampleIssuer, defaultKeys)
ecKeysGetter, err := serviceaccount.StaticPublicKeysGetter([]interface{}{getPublicKey(ecdsaPublicKey)})
if err != nil {
t.Fatal(err)
}
rsaKeysGetter, err := serviceaccount.StaticPublicKeysGetter([]interface{}{getPublicKey(rsaPublicKey)})
if err != nil {
t.Fatal(err)
}
keysGetter := &proxyKeyGetter{PublicKeysGetter: ecKeysGetter}
s, jwksURI := setupServer(t, exampleIssuer, keysGetter)
defer s.Close()
want := Configuration{
wantEC := Configuration{
Issuer: exampleIssuer,
JWKSURI: jwksURI,
ResponseTypes: []string{"id_token"},
SubjectTypes: []string{"public"},
SigningAlgs: []string{"ES256", "RS256"},
SigningAlgs: []string{"ES256"},
}
wantRSA := Configuration{
Issuer: exampleIssuer,
JWKSURI: jwksURI,
ResponseTypes: []string{"id_token"},
SubjectTypes: []string{"public"},
SigningAlgs: []string{"RS256"},
}
reqURL := s.URL + "/.well-known/openid-configuration"
expectConfiguration(t, reqURL, wantEC)
// modify the underlying keys, expect the same response
keysGetter.PublicKeysGetter = rsaKeysGetter
expectConfiguration(t, reqURL, wantEC)
// notify the metadata the keys changed, expected a modified response
for _, listener := range keysGetter.listeners {
listener.Enqueue()
}
expectConfiguration(t, reqURL, wantRSA)
}
func expectConfiguration(t *testing.T, reqURL string, want Configuration) {
resp, err := http.Get(reqURL)
if err != nil {
t.Fatalf("Get(%s) = %v, %v want: <response>, <nil>", reqURL, resp, err)
@ -185,47 +224,82 @@ func TestServeKeys(t *testing.T) {
for _, tt := range serveKeysTests {
t.Run(tt.Name, func(t *testing.T) {
s, _ := setupServer(t, exampleIssuer, tt.Keys)
initialKeysGetter, err := serviceaccount.StaticPublicKeysGetter(tt.Keys)
if err != nil {
t.Fatal(err)
}
updatedKeysGetter, err := serviceaccount.StaticPublicKeysGetter([]interface{}{wantPubRSA})
if err != nil {
t.Fatal(err)
}
keysGetter := &proxyKeyGetter{PublicKeysGetter: initialKeysGetter}
s, _ := setupServer(t, exampleIssuer, keysGetter)
defer s.Close()
reqURL := s.URL + "/openid/v1/jwks"
expectKeys(t, reqURL, tt.WantKeys)
resp, err := http.Get(reqURL)
if err != nil {
t.Fatalf("Get(%s) = %v, %v want: <response>, <nil>", reqURL, resp, err)
}
defer resp.Body.Close()
// modify the underlying keys, expect the same response
keysGetter.PublicKeysGetter = updatedKeysGetter
expectKeys(t, reqURL, tt.WantKeys)
if resp.StatusCode != http.StatusOK {
t.Errorf("Get(%s) = %v, _ want: %v, _", reqURL, resp.StatusCode, http.StatusOK)
}
if got, want := resp.Header.Get("Content-Type"), "application/jwk-set+json"; got != want {
t.Errorf("Get(%s) Content-Type = %q, _ want: %q, _", reqURL, got, want)
}
if got, want := resp.Header.Get("Cache-Control"), "public, max-age=3600"; got != want {
t.Errorf("Get(%s) Cache-Control = %q, _ want: %q, _", reqURL, got, want)
}
ks := &jose.JSONWebKeySet{}
if err := json.NewDecoder(resp.Body).Decode(ks); err != nil {
t.Fatalf("Decode(_) = %v, want: <nil>", err)
}
bigIntComparer := cmp.Comparer(
func(x, y *big.Int) bool {
return x.Cmp(y) == 0
})
if !cmp.Equal(tt.WantKeys, ks.Keys, bigIntComparer) {
t.Errorf("unexpected diff in JWKS keys (-want, +got): %v",
cmp.Diff(tt.WantKeys, ks.Keys, bigIntComparer))
// notify the metadata the keys changed, expected a modified response
for _, listener := range keysGetter.listeners {
listener.Enqueue()
}
expectKeys(t, reqURL, []jose.JSONWebKey{{
Algorithm: "RS256",
Key: wantPubRSA,
KeyID: rsaKeyID,
Use: "sig",
Certificates: []*x509.Certificate{},
CertificateThumbprintSHA1: []uint8{},
CertificateThumbprintSHA256: []uint8{},
}})
})
}
}
func expectKeys(t *testing.T, reqURL string, wantKeys []jose.JSONWebKey) {
resp, err := http.Get(reqURL)
if err != nil {
t.Fatalf("Get(%s) = %v, %v want: <response>, <nil>", reqURL, resp, err)
}
defer func() {
_ = resp.Body.Close()
}()
if resp.StatusCode != http.StatusOK {
t.Errorf("Get(%s) = %v, _ want: %v, _", reqURL, resp.StatusCode, http.StatusOK)
}
if got, want := resp.Header.Get("Content-Type"), "application/jwk-set+json"; got != want {
t.Errorf("Get(%s) Content-Type = %q, _ want: %q, _", reqURL, got, want)
}
if got, want := resp.Header.Get("Cache-Control"), "public, max-age=3600"; got != want {
t.Errorf("Get(%s) Cache-Control = %q, _ want: %q, _", reqURL, got, want)
}
ks := &jose.JSONWebKeySet{}
if err := json.NewDecoder(resp.Body).Decode(ks); err != nil {
t.Fatalf("Decode(_) = %v, want: <nil>", err)
}
bigIntComparer := cmp.Comparer(
func(x, y *big.Int) bool {
return x.Cmp(y) == 0
})
if !cmp.Equal(wantKeys, ks.Keys, bigIntComparer) {
t.Errorf("unexpected diff in JWKS keys (-want, +got): %v",
cmp.Diff(wantKeys, ks.Keys, bigIntComparer))
}
}
func TestURLBoundaries(t *testing.T) {
s, _ := setupServer(t, exampleIssuer, defaultKeys)
keysGetter, err := serviceaccount.StaticPublicKeysGetter(defaultKeys)
if err != nil {
t.Fatal(err)
}
s, _ := setupServer(t, exampleIssuer, keysGetter)
defer s.Close()
for _, tt := range []struct {
@ -380,7 +454,11 @@ func TestNewOpenIDMetadata(t *testing.T) {
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
md, err := serviceaccount.NewOpenIDMetadata(tc.issuerURL, tc.jwksURI, tc.externalAddress, tc.keys)
keysGetter, err := serviceaccount.StaticPublicKeysGetter(tc.keys)
if err != nil {
t.Fatal(err)
}
md, err := serviceaccount.NewOpenIDMetadataProvider(tc.issuerURL, tc.jwksURI, tc.externalAddress, keysGetter)
if tc.err {
if err == nil {
t.Fatalf("got <nil>, want error")
@ -390,13 +468,13 @@ func TestNewOpenIDMetadata(t *testing.T) {
t.Fatalf("got error %v, want <nil>", err)
}
config := string(md.ConfigJSON)
keyset := string(md.PublicKeysetJSON)
if config != tc.wantConfig {
t.Errorf("got metadata %s, want %s", config, tc.wantConfig)
config, _ := md.GetConfigJSON()
keyset, _ := md.GetKeysetJSON()
if string(config) != tc.wantConfig {
t.Errorf("got metadata %s, want %s", string(config), tc.wantConfig)
}
if keyset != tc.wantKeyset {
t.Errorf("got keyset %s, want %s", keyset, tc.wantKeyset)
if string(keyset) != tc.wantKeyset {
t.Errorf("got keyset %s, want %s", string(keyset), tc.wantKeyset)
}
})
}