Merge pull request #125177 from liggitt/dynamic-public-key

Move public key serviceaccount getter to interface, filter by key id
This commit is contained in:
Kubernetes Prow Robot 2024-06-27 11:57:06 -07:00 committed by GitHub
commit ef1d28aa52
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 508 additions and 148 deletions

View File

@ -96,7 +96,7 @@ type Extra struct {
// ServiceAccountIssuerDiscovery // ServiceAccountIssuerDiscovery
ServiceAccountIssuerURL string ServiceAccountIssuerURL string
ServiceAccountJWKSURI string ServiceAccountJWKSURI string
ServiceAccountPublicKeys []interface{} ServiceAccountPublicKeysGetter serviceaccount.PublicKeysGetter
SystemNamespaces []string SystemNamespaces []string
@ -368,6 +368,7 @@ func CreateConfig(
return nil, nil, fmt.Errorf("failed to apply admission: %w", err) return nil, nil, fmt.Errorf("failed to apply admission: %w", err)
} }
if len(opts.Authentication.ServiceAccounts.KeyFiles) > 0 {
// Load and set the public keys. // Load and set the public keys.
var pubKeys []interface{} var pubKeys []interface{}
for _, f := range opts.Authentication.ServiceAccounts.KeyFiles { for _, f := range opts.Authentication.ServiceAccounts.KeyFiles {
@ -377,9 +378,14 @@ func CreateConfig(
} }
pubKeys = append(pubKeys, keys...) pubKeys = append(pubKeys, keys...)
} }
keysGetter, err := serviceaccount.StaticPublicKeysGetter(pubKeys)
if err != nil {
return nil, nil, fmt.Errorf("failed to set up public service account keys: %w", err)
}
config.ServiceAccountPublicKeysGetter = keysGetter
}
config.ServiceAccountIssuerURL = opts.Authentication.ServiceAccounts.Issuers[0] config.ServiceAccountIssuerURL = opts.Authentication.ServiceAccounts.Issuers[0]
config.ServiceAccountJWKSURI = opts.Authentication.ServiceAccounts.JWKSURI config.ServiceAccountJWKSURI = opts.Authentication.ServiceAccounts.JWKSURI
config.ServiceAccountPublicKeys = pubKeys
return config, genericInitializers, nil return config, genericInitializers, nil
} }

View File

@ -93,13 +93,11 @@ func (c completedConfig) New(name string, delegationTarget genericapiserver.Dele
routes.Logs{}.Install(generic.Handler.GoRestfulContainer) routes.Logs{}.Install(generic.Handler.GoRestfulContainer)
} }
// Metadata and keys are expected to only change across restarts at present, md, err := serviceaccount.NewOpenIDMetadataProvider(
// so we just marshal immediately and serve the cached JSON bytes.
md, err := serviceaccount.NewOpenIDMetadata(
c.ServiceAccountIssuerURL, c.ServiceAccountIssuerURL,
c.ServiceAccountJWKSURI, c.ServiceAccountJWKSURI,
c.Generic.ExternalAddress, c.Generic.ExternalAddress,
c.ServiceAccountPublicKeys, c.ServiceAccountPublicKeysGetter,
) )
if err != nil { if err != nil {
// If there was an error, skip installing the endpoints and log the // If there was an error, skip installing the endpoints and log the
@ -120,8 +118,7 @@ func (c completedConfig) New(name string, delegationTarget genericapiserver.Dele
klog.Info(msg) klog.Info(msg)
} }
} else { } else {
routes.NewOpenIDMetadataServer(md.ConfigJSON, md.PublicKeysetJSON). routes.NewOpenIDMetadataServer(md).Install(generic.Handler.GoRestfulContainer)
Install(generic.Handler.GoRestfulContainer)
} }
s := &Server{ s := &Server{

View File

@ -62,7 +62,6 @@ type Config struct {
AuthenticationConfig *apiserver.AuthenticationConfiguration AuthenticationConfig *apiserver.AuthenticationConfiguration
AuthenticationConfigData string AuthenticationConfigData string
OIDCSigningAlgs []string OIDCSigningAlgs []string
ServiceAccountKeyFiles []string
ServiceAccountLookup bool ServiceAccountLookup bool
ServiceAccountIssuers []string ServiceAccountIssuers []string
APIAudiences authenticator.Audiences APIAudiences authenticator.Audiences
@ -79,7 +78,9 @@ type Config struct {
RequestHeaderConfig *authenticatorfactory.RequestHeaderConfig RequestHeaderConfig *authenticatorfactory.RequestHeaderConfig
// TODO, this is the only non-serializable part of the entire config. Factor it out into a clientconfig // ServiceAccountPublicKeysGetter returns public keys for verifying service account tokens.
ServiceAccountPublicKeysGetter serviceaccount.PublicKeysGetter
// ServiceAccountTokenGetter fetches API objects used to verify bound objects in service account token claims.
ServiceAccountTokenGetter serviceaccount.ServiceAccountTokenGetter ServiceAccountTokenGetter serviceaccount.ServiceAccountTokenGetter
SecretsWriter typedv1core.SecretsGetter SecretsWriter typedv1core.SecretsGetter
BootstrapTokenAuthenticator authenticator.Token BootstrapTokenAuthenticator authenticator.Token
@ -127,15 +128,15 @@ func (config Config) New(serverLifecycle context.Context) (authenticator.Request
} }
tokenAuthenticators = append(tokenAuthenticators, authenticator.WrapAudienceAgnosticToken(config.APIAudiences, tokenAuth)) tokenAuthenticators = append(tokenAuthenticators, authenticator.WrapAudienceAgnosticToken(config.APIAudiences, tokenAuth))
} }
if len(config.ServiceAccountKeyFiles) > 0 { if config.ServiceAccountPublicKeysGetter != nil {
serviceAccountAuth, err := newLegacyServiceAccountAuthenticator(config.ServiceAccountKeyFiles, config.ServiceAccountLookup, config.APIAudiences, config.ServiceAccountTokenGetter, config.SecretsWriter) serviceAccountAuth, err := newLegacyServiceAccountAuthenticator(config.ServiceAccountPublicKeysGetter, config.ServiceAccountLookup, config.APIAudiences, config.ServiceAccountTokenGetter, config.SecretsWriter)
if err != nil { if err != nil {
return nil, nil, nil, nil, err return nil, nil, nil, nil, err
} }
tokenAuthenticators = append(tokenAuthenticators, serviceAccountAuth) tokenAuthenticators = append(tokenAuthenticators, serviceAccountAuth)
} }
if len(config.ServiceAccountIssuers) > 0 { if len(config.ServiceAccountIssuers) > 0 {
serviceAccountAuth, err := newServiceAccountAuthenticator(config.ServiceAccountIssuers, config.ServiceAccountKeyFiles, config.APIAudiences, config.ServiceAccountTokenGetter) serviceAccountAuth, err := newServiceAccountAuthenticator(config.ServiceAccountIssuers, config.ServiceAccountPublicKeysGetter, config.APIAudiences, config.ServiceAccountTokenGetter)
if err != nil { if err != nil {
return nil, nil, nil, nil, err return nil, nil, nil, nil, err
} }
@ -338,36 +339,25 @@ func newAuthenticatorFromTokenFile(tokenAuthFile string) (authenticator.Token, e
} }
// newLegacyServiceAccountAuthenticator returns an authenticator.Token or an error // newLegacyServiceAccountAuthenticator returns an authenticator.Token or an error
func newLegacyServiceAccountAuthenticator(keyfiles []string, lookup bool, apiAudiences authenticator.Audiences, serviceAccountGetter serviceaccount.ServiceAccountTokenGetter, secretsWriter typedv1core.SecretsGetter) (authenticator.Token, error) { func newLegacyServiceAccountAuthenticator(publicKeysGetter serviceaccount.PublicKeysGetter, lookup bool, apiAudiences authenticator.Audiences, serviceAccountGetter serviceaccount.ServiceAccountTokenGetter, secretsWriter typedv1core.SecretsGetter) (authenticator.Token, error) {
allPublicKeys := []interface{}{} if publicKeysGetter == nil {
for _, keyfile := range keyfiles { return nil, fmt.Errorf("no public key getter provided")
publicKeys, err := keyutil.PublicKeysFromFile(keyfile)
if err != nil {
return nil, err
}
allPublicKeys = append(allPublicKeys, publicKeys...)
} }
validator, err := serviceaccount.NewLegacyValidator(lookup, serviceAccountGetter, secretsWriter) validator, err := serviceaccount.NewLegacyValidator(lookup, serviceAccountGetter, secretsWriter)
if err != nil { if err != nil {
return nil, fmt.Errorf("while creating legacy validator, err: %w", err) return nil, fmt.Errorf("while creating legacy validator, err: %w", err)
} }
tokenAuthenticator := serviceaccount.JWTTokenAuthenticator([]string{serviceaccount.LegacyIssuer}, allPublicKeys, apiAudiences, validator) tokenAuthenticator := serviceaccount.JWTTokenAuthenticator([]string{serviceaccount.LegacyIssuer}, publicKeysGetter, apiAudiences, validator)
return tokenAuthenticator, nil return tokenAuthenticator, nil
} }
// newServiceAccountAuthenticator returns an authenticator.Token or an error // newServiceAccountAuthenticator returns an authenticator.Token or an error
func newServiceAccountAuthenticator(issuers []string, keyfiles []string, apiAudiences authenticator.Audiences, serviceAccountGetter serviceaccount.ServiceAccountTokenGetter) (authenticator.Token, error) { func newServiceAccountAuthenticator(issuers []string, publicKeysGetter serviceaccount.PublicKeysGetter, apiAudiences authenticator.Audiences, serviceAccountGetter serviceaccount.ServiceAccountTokenGetter) (authenticator.Token, error) {
allPublicKeys := []interface{}{} if publicKeysGetter == nil {
for _, keyfile := range keyfiles { return nil, fmt.Errorf("no public key getter provided")
publicKeys, err := keyutil.PublicKeysFromFile(keyfile)
if err != nil {
return nil, err
} }
allPublicKeys = append(allPublicKeys, publicKeys...) tokenAuthenticator := serviceaccount.JWTTokenAuthenticator(issuers, publicKeysGetter, apiAudiences, serviceaccount.NewValidator(serviceAccountGetter))
}
tokenAuthenticator := serviceaccount.JWTTokenAuthenticator(issuers, allPublicKeys, apiAudiences, serviceaccount.NewValidator(serviceAccountGetter))
return tokenAuthenticator, nil return tokenAuthenticator, nil
} }

View File

@ -47,6 +47,7 @@ import (
"k8s.io/client-go/informers" "k8s.io/client-go/informers"
"k8s.io/client-go/kubernetes" "k8s.io/client-go/kubernetes"
v1listers "k8s.io/client-go/listers/core/v1" v1listers "k8s.io/client-go/listers/core/v1"
"k8s.io/client-go/util/keyutil"
cliflag "k8s.io/component-base/cli/flag" cliflag "k8s.io/component-base/cli/flag"
"k8s.io/klog/v2" "k8s.io/klog/v2"
openapicommon "k8s.io/kube-openapi/pkg/common" openapicommon "k8s.io/kube-openapi/pkg/common"
@ -54,6 +55,7 @@ import (
"k8s.io/kubernetes/pkg/features" "k8s.io/kubernetes/pkg/features"
kubeauthenticator "k8s.io/kubernetes/pkg/kubeapiserver/authenticator" kubeauthenticator "k8s.io/kubernetes/pkg/kubeapiserver/authenticator"
authzmodes "k8s.io/kubernetes/pkg/kubeapiserver/authorizer/modes" authzmodes "k8s.io/kubernetes/pkg/kubeapiserver/authorizer/modes"
"k8s.io/kubernetes/pkg/serviceaccount"
"k8s.io/kubernetes/pkg/util/filesystem" "k8s.io/kubernetes/pkg/util/filesystem"
"k8s.io/kubernetes/plugin/pkg/auth/authenticator/token/bootstrap" "k8s.io/kubernetes/plugin/pkg/auth/authenticator/token/bootstrap"
"k8s.io/utils/pointer" "k8s.io/utils/pointer"
@ -559,7 +561,21 @@ func (o *BuiltInAuthenticationOptions) ToAuthenticationConfig() (kubeauthenticat
if len(o.ServiceAccounts.Issuers) != 0 && len(o.APIAudiences) == 0 { if len(o.ServiceAccounts.Issuers) != 0 && len(o.APIAudiences) == 0 {
ret.APIAudiences = authenticator.Audiences(o.ServiceAccounts.Issuers) ret.APIAudiences = authenticator.Audiences(o.ServiceAccounts.Issuers)
} }
ret.ServiceAccountKeyFiles = o.ServiceAccounts.KeyFiles if len(o.ServiceAccounts.KeyFiles) > 0 {
allPublicKeys := []interface{}{}
for _, keyfile := range o.ServiceAccounts.KeyFiles {
publicKeys, err := keyutil.PublicKeysFromFile(keyfile)
if err != nil {
return kubeauthenticator.Config{}, err
}
allPublicKeys = append(allPublicKeys, publicKeys...)
}
keysGetter, err := serviceaccount.StaticPublicKeysGetter(allPublicKeys)
if err != nil {
return kubeauthenticator.Config{}, fmt.Errorf("failed to set up public service account keys: %w", err)
}
ret.ServiceAccountPublicKeysGetter = keysGetter
}
ret.ServiceAccountIssuers = o.ServiceAccounts.Issuers ret.ServiceAccountIssuers = o.ServiceAccounts.Issuers
ret.ServiceAccountLookup = o.ServiceAccounts.Lookup ret.ServiceAccountLookup = o.ServiceAccounts.Lookup
} }

View File

@ -17,6 +17,7 @@ limitations under the License.
package routes package routes
import ( import (
"fmt"
"net/http" "net/http"
restful "github.com/emicklei/go-restful/v3" restful "github.com/emicklei/go-restful/v3"
@ -34,7 +35,8 @@ const (
// cacheControl is the value of the Cache-Control header. Overrides the // cacheControl is the value of the Cache-Control header. Overrides the
// global `private, no-cache` setting. // global `private, no-cache` setting.
headerCacheControl = "Cache-Control" headerCacheControl = "Cache-Control"
cacheControl = "public, max-age=3600" // 1 hour
cacheControlTemplate = "public, max-age=%d"
// mimeJWKS is the content type of the keyset response // mimeJWKS is the content type of the keyset response
mimeJWKS = "application/jwk-set+json" mimeJWKS = "application/jwk-set+json"
@ -42,18 +44,14 @@ const (
// OpenIDMetadataServer is an HTTP server for metadata of the KSA token issuer. // OpenIDMetadataServer is an HTTP server for metadata of the KSA token issuer.
type OpenIDMetadataServer struct { type OpenIDMetadataServer struct {
configJSON []byte provider serviceaccount.OpenIDMetadataProvider
keysetJSON []byte
} }
// NewOpenIDMetadataServer creates a new OpenIDMetadataServer. // NewOpenIDMetadataServer creates a new OpenIDMetadataServer.
// The issuer is the OIDC issuer; keys are the keys that may be used to sign // The issuer is the OIDC issuer; keys are the keys that may be used to sign
// KSA tokens. // KSA tokens.
func NewOpenIDMetadataServer(configJSON, keysetJSON []byte) *OpenIDMetadataServer { func NewOpenIDMetadataServer(provider serviceaccount.OpenIDMetadataProvider) *OpenIDMetadataServer {
return &OpenIDMetadataServer{ return &OpenIDMetadataServer{provider: provider}
configJSON: configJSON,
keysetJSON: keysetJSON,
}
} }
// Install adds this server to the request router c. // Install adds this server to the request router c.
@ -95,19 +93,21 @@ func fromStandard(h http.HandlerFunc) restful.RouteFunction {
} }
func (s *OpenIDMetadataServer) serveConfiguration(w http.ResponseWriter, req *http.Request) { func (s *OpenIDMetadataServer) serveConfiguration(w http.ResponseWriter, req *http.Request) {
configJSON, maxAge := s.provider.GetConfigJSON()
w.Header().Set(restful.HEADER_ContentType, restful.MIME_JSON) w.Header().Set(restful.HEADER_ContentType, restful.MIME_JSON)
w.Header().Set(headerCacheControl, cacheControl) w.Header().Set(headerCacheControl, fmt.Sprintf(cacheControlTemplate, maxAge))
if _, err := w.Write(s.configJSON); err != nil { if _, err := w.Write(configJSON); err != nil {
klog.Errorf("failed to write service account issuer metadata response: %v", err) klog.Errorf("failed to write service account issuer metadata response: %v", err)
return return
} }
} }
func (s *OpenIDMetadataServer) serveKeys(w http.ResponseWriter, req *http.Request) { func (s *OpenIDMetadataServer) serveKeys(w http.ResponseWriter, req *http.Request) {
keysetJSON, maxAge := s.provider.GetKeysetJSON()
// Per RFC7517 : https://tools.ietf.org/html/rfc7517#section-8.5.1 // Per RFC7517 : https://tools.ietf.org/html/rfc7517#section-8.5.1
w.Header().Set(restful.HEADER_ContentType, mimeJWKS) w.Header().Set(restful.HEADER_ContentType, mimeJWKS)
w.Header().Set(headerCacheControl, cacheControl) w.Header().Set(headerCacheControl, fmt.Sprintf(cacheControlTemplate, maxAge))
if _, err := w.Write(s.keysetJSON); err != nil { if _, err := w.Write(keysetJSON); err != nil {
klog.Errorf("failed to write service account issuer JWKS response: %v", err) klog.Errorf("failed to write service account issuer JWKS response: %v", err)
return return
} }

View File

@ -225,22 +225,97 @@ func (j *jwtTokenGenerator) GenerateToken(claims *jwt.Claims, privateClaims inte
// JWTTokenAuthenticator authenticates tokens as JWT tokens produced by JWTTokenGenerator // JWTTokenAuthenticator authenticates tokens as JWT tokens produced by JWTTokenGenerator
// Token signatures are verified using each of the given public keys until one works (allowing key rotation) // Token signatures are verified using each of the given public keys until one works (allowing key rotation)
// If lookup is true, the service account and secret referenced as claims inside the token are retrieved and verified with the provided ServiceAccountTokenGetter // If lookup is true, the service account and secret referenced as claims inside the token are retrieved and verified with the provided ServiceAccountTokenGetter
func JWTTokenAuthenticator[PrivateClaims any](issuers []string, keys []interface{}, implicitAuds authenticator.Audiences, validator Validator[PrivateClaims]) authenticator.Token { func JWTTokenAuthenticator[PrivateClaims any](issuers []string, publicKeysGetter PublicKeysGetter, implicitAuds authenticator.Audiences, validator Validator[PrivateClaims]) authenticator.Token {
issuersMap := make(map[string]bool) issuersMap := make(map[string]bool)
for _, issuer := range issuers { for _, issuer := range issuers {
issuersMap[issuer] = true issuersMap[issuer] = true
} }
return &jwtTokenAuthenticator[PrivateClaims]{ return &jwtTokenAuthenticator[PrivateClaims]{
issuers: issuersMap, issuers: issuersMap,
keys: keys, keysGetter: publicKeysGetter,
implicitAuds: implicitAuds, implicitAuds: implicitAuds,
validator: validator, validator: validator,
} }
} }
// Listener is an interface to use to notify interested parties of a change.
type Listener interface {
// Enqueue should be called when an input may have changed
Enqueue()
}
// PublicKeysGetter returns public keys for a given key id.
type PublicKeysGetter interface {
// AddListener is adds a listener to be notified of potential input changes.
// This is a noop on static providers.
AddListener(listener Listener)
// GetCacheAgeMaxSeconds returns the seconds a call to GetPublicKeys() can be cached for.
// If the results of GetPublicKeys() can be dynamic, this means a new key must be included in the results
// for at least this long before it is used to sign new tokens.
GetCacheAgeMaxSeconds() int
// GetPublicKeys returns public keys to use for verifying a token with the given key id.
// keyIDHint may be empty if the token did not have a kid header, or if all public keys are desired.
GetPublicKeys(keyIDHint string) []PublicKey
}
type PublicKey struct {
KeyID string
PublicKey interface{}
}
type staticPublicKeysGetter struct {
allPublicKeys []PublicKey
publicKeysByID map[string][]PublicKey
}
// StaticPublicKeysGetter constructs an implementation of PublicKeysGetter
// which returns all public keys when key id is unspecified, and returns
// the public keys matching the keyIDFromPublicKey-derived key id when
// a key id is specified.
func StaticPublicKeysGetter(keys []interface{}) (PublicKeysGetter, error) {
allPublicKeys := []PublicKey{}
publicKeysByID := map[string][]PublicKey{}
for _, key := range keys {
if privateKey, isPrivateKey := key.(publicKeyGetter); isPrivateKey {
// This is a private key. Extract its public key.
key = privateKey.Public()
}
keyID, err := keyIDFromPublicKey(key)
if err != nil {
return nil, err
}
pk := PublicKey{PublicKey: key, KeyID: keyID}
publicKeysByID[keyID] = append(publicKeysByID[keyID], pk)
allPublicKeys = append(allPublicKeys, pk)
}
return &staticPublicKeysGetter{
allPublicKeys: allPublicKeys,
publicKeysByID: publicKeysByID,
}, nil
}
func (s staticPublicKeysGetter) AddListener(listener Listener) {
// no-op, static key content never changes
}
func (s staticPublicKeysGetter) GetCacheAgeMaxSeconds() int {
// hard-coded to match cache max-age set in OIDC discovery
return 3600
}
func (s staticPublicKeysGetter) GetPublicKeys(keyID string) []PublicKey {
if len(keyID) == 0 {
return s.allPublicKeys
}
return s.publicKeysByID[keyID]
}
type jwtTokenAuthenticator[PrivateClaims any] struct { type jwtTokenAuthenticator[PrivateClaims any] struct {
issuers map[string]bool issuers map[string]bool
keys []interface{} keysGetter PublicKeysGetter
validator Validator[PrivateClaims] validator Validator[PrivateClaims]
implicitAuds authenticator.Audiences implicitAuds authenticator.Audiences
} }
@ -269,13 +344,25 @@ func (j *jwtTokenAuthenticator[PrivateClaims]) AuthenticateToken(ctx context.Con
public := &jwt.Claims{} public := &jwt.Claims{}
private := new(PrivateClaims) private := new(PrivateClaims)
// TODO: Pick the key that has the same key ID as `tok`, if one exists. // Pick the key that has the same key ID as `tok`, if one exists.
var kid string
for _, header := range tok.Headers {
if header.KeyID != "" {
kid = header.KeyID
break
}
}
var ( var (
found bool found bool
errlist []error errlist []error
) )
for _, key := range j.keys { keys := j.keysGetter.GetPublicKeys(kid)
if err := tok.Claims(key, public, private); err != nil { if len(keys) == 0 {
return nil, false, fmt.Errorf("invalid signature, no keys found")
}
for _, key := range keys {
if err := tok.Claims(key.PublicKey, public, private); err != nil {
errlist = append(errlist, err) errlist = append(errlist, err)
continue continue
} }

View File

@ -247,7 +247,7 @@ func TestTokenGenerateAndValidate(t *testing.T) {
Token: rsaToken, Token: rsaToken,
Client: nil, Client: nil,
Keys: []interface{}{}, Keys: []interface{}{},
ExpectedErr: false, ExpectedErr: true,
ExpectedOK: false, ExpectedOK: false,
}, },
"invalid keys (rsa)": { "invalid keys (rsa)": {
@ -385,7 +385,13 @@ func TestTokenGenerateAndValidate(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("While creating legacy validator, err: %v", err) t.Fatalf("While creating legacy validator, err: %v", err)
} }
authn := serviceaccount.JWTTokenAuthenticator([]string{serviceaccount.LegacyIssuer, "bar"}, tc.Keys, auds, validator) staticKeysGetter, err := serviceaccount.StaticPublicKeysGetter(tc.Keys)
if err != nil {
t.Fatal(err)
}
keysGetter := &keyIDPrefixer{PublicKeysGetter: staticKeysGetter}
authn := serviceaccount.JWTTokenAuthenticator([]string{serviceaccount.LegacyIssuer, "bar"}, keysGetter, auds, validator)
// An invalid, non-JWT token should always fail // An invalid, non-JWT token should always fail
ctx := authenticator.WithAudiences(context.Background(), auds) ctx := authenticator.WithAudiences(context.Background(), auds)
@ -394,6 +400,16 @@ func TestTokenGenerateAndValidate(t *testing.T) {
continue continue
} }
if tc.ExpectedOK {
// if authentication is otherwise expected to succeed, demonstrate changing key ids makes it fail
keysGetter.keyIDPrefix = "bogus"
if _, ok, err := authn.AuthenticateToken(ctx, tc.Token); err == nil || !strings.Contains(err.Error(), "no keys found") || ok {
t.Errorf("%s: Expected err containing 'no keys found', ok=false when key lookup by ID fails", k)
continue
}
keysGetter.keyIDPrefix = ""
}
resp, ok, err := authn.AuthenticateToken(ctx, tc.Token) resp, ok, err := authn.AuthenticateToken(ctx, tc.Token)
if (err != nil) != tc.ExpectedErr { if (err != nil) != tc.ExpectedErr {
t.Errorf("%s: Expected error=%v, got %v", k, tc.ExpectedErr, err) t.Errorf("%s: Expected error=%v, got %v", k, tc.ExpectedErr, err)
@ -424,6 +440,26 @@ func TestTokenGenerateAndValidate(t *testing.T) {
} }
} }
type keyIDPrefixer struct {
serviceaccount.PublicKeysGetter
keyIDPrefix string
}
func (k *keyIDPrefixer) GetPublicKeys(keyIDHint string) []serviceaccount.PublicKey {
if k.keyIDPrefix == "" {
return k.PublicKeysGetter.GetPublicKeys(keyIDHint)
}
if keyIDHint != "" {
keyIDHint = k.keyIDPrefix + keyIDHint
}
var retval []serviceaccount.PublicKey
for _, key := range k.PublicKeysGetter.GetPublicKeys(keyIDHint) {
key.KeyID = k.keyIDPrefix + key.KeyID
retval = append(retval, key)
}
return retval
}
func checkJSONWebSignatureHasKeyID(t *testing.T, jwsString string, expectedKeyID string) { func checkJSONWebSignatureHasKeyID(t *testing.T, jwsString string, expectedKeyID string) {
jws, err := jose.ParseSigned(jwsString) jws, err := jose.ParseSigned(jwsString)
if err != nil { if err != nil {
@ -502,3 +538,76 @@ func generateECDSATokenWithMalformedIss(t *testing.T, serviceAccount *v1.Service
return string(out) return string(out)
} }
func TestStaticPublicKeysGetter(t *testing.T) {
ecPrivate := getPrivateKey(ecdsaPrivateKey)
ecPublic := getPublicKey(ecdsaPublicKey)
rsaPublic := getPublicKey(rsaPublicKey)
testcases := []struct {
Name string
Keys []interface{}
ExpectErr bool
ExpectKeys []serviceaccount.PublicKey
}{
{
Name: "empty",
Keys: nil,
ExpectKeys: []serviceaccount.PublicKey{},
},
{
Name: "simple",
Keys: []interface{}{ecPublic, rsaPublic},
ExpectKeys: []serviceaccount.PublicKey{
{KeyID: "SoABiieYuNx4UdqYvZRVeuC6SihxgLrhLy9peHMHpTc", PublicKey: ecPublic},
{KeyID: "JHJehTTTZlsspKHT-GaJxK7Kd1NQgZJu3fyK6K_QDYU", PublicKey: rsaPublic},
},
},
{
Name: "private --> public",
Keys: []interface{}{ecPrivate},
ExpectKeys: []serviceaccount.PublicKey{
{KeyID: "SoABiieYuNx4UdqYvZRVeuC6SihxgLrhLy9peHMHpTc", PublicKey: ecPublic},
},
},
{
Name: "invalid",
Keys: []interface{}{"bogus"},
ExpectErr: true,
},
}
for _, tc := range testcases {
t.Run(tc.Name, func(t *testing.T) {
getter, err := serviceaccount.StaticPublicKeysGetter(tc.Keys)
if tc.ExpectErr {
if err == nil {
t.Fatal("expected construction error, got none")
}
return
}
if err != nil {
t.Fatalf("unexpected construction error: %v", err)
}
bogusKeys := getter.GetPublicKeys("bogus")
if len(bogusKeys) != 0 {
t.Fatalf("unexpected bogus keys: %#v", bogusKeys)
}
allKeys := getter.GetPublicKeys("")
if !reflect.DeepEqual(tc.ExpectKeys, allKeys) {
t.Fatalf("unexpected keys: %#v", allKeys)
}
for _, key := range allKeys {
keysByID := getter.GetPublicKeys(key.KeyID)
if len(keysByID) != 1 {
t.Fatalf("expected 1 key for id %s, got %d", key.KeyID, len(keysByID))
}
if !reflect.DeepEqual(key, keysByID[0]) {
t.Fatalf("unexpected key for id %s", key.KeyID)
}
}
})
}
}

View File

@ -0,0 +1,49 @@
/*
Copyright 2024 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package serviceaccount
import (
"testing"
"k8s.io/client-go/util/keyutil"
)
const rsaPublicKey = `-----BEGIN PUBLIC KEY-----
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA249XwEo9k4tM8fMxV7zx
OhcrP+WvXn917koM5Qr2ZXs4vo26e4ytdlrV0bQ9SlcLpQVSYjIxNfhTZdDt+ecI
zshKuv1gKIxbbLQMOuK1eA/4HALyEkFgmS/tleLJrhc65tKPMGD+pKQ/xhmzRuCG
51RoiMgbQxaCyYxGfNLpLAZK9L0Tctv9a0mJmGIYnIOQM4kC1A1I1n3EsXMWmeJU
j7OTh/AjjCnMnkgvKT2tpKxYQ59PgDgU8Ssc7RDSmSkLxnrv+OrN80j6xrw0OjEi
B4Ycr0PqfzZcvy8efTtFQ/Jnc4Bp1zUtFXt7+QeevePtQ2EcyELXE0i63T1CujRM
WwIDAQAB
-----END PUBLIC KEY-----
`
func TestKeyIDStability(t *testing.T) {
keys, err := keyutil.ParsePublicKeysPEM([]byte(rsaPublicKey))
if err != nil {
t.Fatal(err)
}
keyID, err := keyIDFromPublicKey(keys[0])
if err != nil {
t.Fatal(err)
}
// The derived key id for a given public key must not change or validation of previously issued tokens will fail to find associated keys
if expected, actual := "JHJehTTTZlsspKHT-GaJxK7Kd1NQgZJu3fyK6K_QDYU", keyID; expected != actual {
t.Fatalf("expected stable key id %q, got %q", expected, actual)
}
}

View File

@ -24,11 +24,13 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/url" "net/url"
"sync/atomic"
jose "gopkg.in/square/go-jose.v2" jose "gopkg.in/square/go-jose.v2"
"k8s.io/apimachinery/pkg/util/errors" "k8s.io/apimachinery/pkg/util/errors"
"k8s.io/apimachinery/pkg/util/sets" "k8s.io/apimachinery/pkg/util/sets"
"k8s.io/klog/v2"
) )
const ( const (
@ -44,26 +46,68 @@ const (
JWKSPath = "/openid/v1/jwks" JWKSPath = "/openid/v1/jwks"
) )
// OpenIDMetadata contains the pre-rendered responses for OIDC discovery endpoints. // OpenIDMetadataProvider returns pre-rendered responses for OIDC discovery endpoints.
type OpenIDMetadata struct { type OpenIDMetadataProvider interface {
ConfigJSON []byte GetConfigJSON() (json []byte, maxAge int)
PublicKeysetJSON []byte GetKeysetJSON() (json []byte, maxAge int)
} }
// NewOpenIDMetadata returns the pre-rendered JSON responses for the OIDC discovery type openidConfigProvider struct {
issuerURL, jwksURI string
pubKeyGetter PublicKeysGetter
config atomic.Pointer[openidConfig]
}
type openidConfig struct {
configJSON []byte
keysetJSON []byte
}
func (p *openidConfigProvider) GetConfigJSON() ([]byte, int) {
return p.config.Load().configJSON, p.pubKeyGetter.GetCacheAgeMaxSeconds()
}
func (p *openidConfigProvider) GetKeysetJSON() ([]byte, int) {
return p.config.Load().keysetJSON, p.pubKeyGetter.GetCacheAgeMaxSeconds()
}
func (p *openidConfigProvider) Enqueue() {
err := p.Update()
if err != nil {
klog.ErrorS(err, "failed to update openid config metadata")
}
}
func (p *openidConfigProvider) Update() error {
pubKeys := p.pubKeyGetter.GetPublicKeys("")
if len(pubKeys) == 0 {
return fmt.Errorf("no keys provided for validating keyset")
}
configJSON, err := openIDConfigJSON(p.issuerURL, p.jwksURI, pubKeys)
if err != nil {
return fmt.Errorf("could not marshal issuer discovery JSON, error: %w", err)
}
keysetJSON, err := openIDKeysetJSON(pubKeys)
if err != nil {
return fmt.Errorf("could not marshal issuer keys JSON, error: %w", err)
}
p.config.Store(&openidConfig{
configJSON: configJSON,
keysetJSON: keysetJSON,
})
return nil
}
// NewOpenIDMetadataProvider returns a provider for the OIDC discovery
// endpoints, or an error if they could not be constructed. Callers should note // endpoints, or an error if they could not be constructed. Callers should note
// that this function may perform additional validation on inputs that is not // that this function may perform additional validation on inputs that is not
// backwards-compatible with all command-line validation. The recommendation is // backwards-compatible with all command-line validation. The recommendation is
// to log the error and skip installing the OIDC discovery endpoints. // to log the error and skip installing the OIDC discovery endpoints.
func NewOpenIDMetadata(issuerURL, jwksURI, defaultExternalAddress string, pubKeys []interface{}) (*OpenIDMetadata, error) { func NewOpenIDMetadataProvider(issuerURL, jwksURI, defaultExternalAddress string, pubKeyGetter PublicKeysGetter) (OpenIDMetadataProvider, error) {
if issuerURL == "" { if issuerURL == "" {
return nil, fmt.Errorf("empty issuer URL") return nil, fmt.Errorf("empty issuer URL")
} }
if jwksURI == "" && defaultExternalAddress == "" { if jwksURI == "" && defaultExternalAddress == "" {
return nil, fmt.Errorf("either the JWKS URI or the default external address, or both, must be set") return nil, fmt.Errorf("either the JWKS URI or the default external address, or both, must be set")
} }
if len(pubKeys) == 0 { if pubKeyGetter == nil {
return nil, fmt.Errorf("no keys provided for validating keyset") return nil, fmt.Errorf("no public key getter provided")
} }
// Ensure the issuer URL meets the OIDC spec (this is the additional // Ensure the issuer URL meets the OIDC spec (this is the additional
@ -126,20 +170,18 @@ func NewOpenIDMetadata(issuerURL, jwksURI, defaultExternalAddress string, pubKey
} }
} }
configJSON, err := openIDConfigJSON(issuerURL, jwksURI, pubKeys) provider := &openidConfigProvider{
if err != nil { issuerURL: issuerURL,
return nil, fmt.Errorf("could not marshal issuer discovery JSON, error: %v", err) jwksURI: jwksURI,
pubKeyGetter: pubKeyGetter,
} }
// Register to be notified if public keys change
keysetJSON, err := openIDKeysetJSON(pubKeys) pubKeyGetter.AddListener(provider)
if err != nil { // Synchronously construct the config / keyset json once at startup to ensure a successful starting point
return nil, fmt.Errorf("could not marshal issuer keys JSON, error: %v", err) if err := provider.Update(); err != nil {
return nil, err
} }
return provider, nil
return &OpenIDMetadata{
ConfigJSON: configJSON,
PublicKeysetJSON: keysetJSON,
}, nil
} }
// openIDMetadata provides a minimal subset of OIDC provider metadata: // openIDMetadata provides a minimal subset of OIDC provider metadata:
@ -159,7 +201,7 @@ type openIDMetadata struct {
// openIDConfigJSON returns the JSON OIDC Discovery Doc for the service // openIDConfigJSON returns the JSON OIDC Discovery Doc for the service
// account issuer. // account issuer.
func openIDConfigJSON(iss, jwksURI string, keys []interface{}) ([]byte, error) { func openIDConfigJSON(iss, jwksURI string, keys []PublicKey) ([]byte, error) {
keyset, errs := publicJWKSFromKeys(keys) keyset, errs := publicJWKSFromKeys(keys)
if errs != nil { if errs != nil {
return nil, errs return nil, errs
@ -183,7 +225,7 @@ func openIDConfigJSON(iss, jwksURI string, keys []interface{}) ([]byte, error) {
// openIDKeysetJSON returns the JSON Web Key Set for the service account // openIDKeysetJSON returns the JSON Web Key Set for the service account
// issuer's keys. // issuer's keys.
func openIDKeysetJSON(keys []interface{}) ([]byte, error) { func openIDKeysetJSON(keys []PublicKey) ([]byte, error) {
keyset, errs := publicJWKSFromKeys(keys) keyset, errs := publicJWKSFromKeys(keys)
if errs != nil { if errs != nil {
return nil, errs return nil, errs
@ -212,21 +254,12 @@ type publicKeyGetter interface {
// publicJWKSFromKeys constructs a JSONWebKeySet from a list of keys. The key // publicJWKSFromKeys constructs a JSONWebKeySet from a list of keys. The key
// set will only contain the public keys associated with the input keys. // set will only contain the public keys associated with the input keys.
func publicJWKSFromKeys(in []interface{}) (*jose.JSONWebKeySet, errors.Aggregate) { func publicJWKSFromKeys(in []PublicKey) (*jose.JSONWebKeySet, errors.Aggregate) {
// Decode keys into a JWKS. // Decode keys into a JWKS.
var keys jose.JSONWebKeySet var keys jose.JSONWebKeySet
var errs []error var errs []error
for i, key := range in { for i, key := range in {
var pubkey *jose.JSONWebKey pubkey, err := jwkFromPublicKey(key)
var err error
switch k := key.(type) {
case publicKeyGetter:
// This is a private key. Get its public key
pubkey, err = jwkFromPublicKey(k.Public())
default:
pubkey, err = jwkFromPublicKey(k)
}
if err != nil { if err != nil {
errs = append(errs, fmt.Errorf("error constructing JWK for key #%d: %v", i, err)) errs = append(errs, fmt.Errorf("error constructing JWK for key #%d: %v", i, err))
continue continue
@ -244,21 +277,16 @@ func publicJWKSFromKeys(in []interface{}) (*jose.JSONWebKeySet, errors.Aggregate
return &keys, nil return &keys, nil
} }
func jwkFromPublicKey(publicKey crypto.PublicKey) (*jose.JSONWebKey, error) { func jwkFromPublicKey(publicKey PublicKey) (*jose.JSONWebKey, error) {
alg, err := algorithmFromPublicKey(publicKey) alg, err := algorithmFromPublicKey(publicKey.PublicKey)
if err != nil {
return nil, err
}
keyID, err := keyIDFromPublicKey(publicKey)
if err != nil { if err != nil {
return nil, err return nil, err
} }
jwk := &jose.JSONWebKey{ jwk := &jose.JSONWebKey{
Algorithm: string(alg), Algorithm: string(alg),
Key: publicKey, Key: publicKey.PublicKey,
KeyID: keyID, KeyID: publicKey.KeyID,
Use: "sig", Use: "sig",
} }

View File

@ -39,7 +39,7 @@ const (
exampleIssuer = "https://issuer.example.com" exampleIssuer = "https://issuer.example.com"
) )
func setupServer(t *testing.T, iss string, keys []interface{}) (*httptest.Server, string) { func setupServer(t *testing.T, iss string, keys serviceaccount.PublicKeysGetter) (*httptest.Server, string) {
t.Helper() t.Helper()
c := restful.NewContainer() c := restful.NewContainer()
@ -53,13 +53,13 @@ func setupServer(t *testing.T, iss string, keys []interface{}) (*httptest.Server
jwksURI.Scheme = "https" jwksURI.Scheme = "https"
jwksURI.Path = serviceaccount.JWKSPath jwksURI.Path = serviceaccount.JWKSPath
md, err := serviceaccount.NewOpenIDMetadata( md, err := serviceaccount.NewOpenIDMetadataProvider(
iss, jwksURI.String(), "", keys) iss, jwksURI.String(), "", keys)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
srv := routes.NewOpenIDMetadataServer(md.ConfigJSON, md.PublicKeysetJSON) srv := routes.NewOpenIDMetadataServer(md)
srv.Install(c) srv.Install(c)
return s, jwksURI.String() return s, jwksURI.String()
@ -77,20 +77,59 @@ type Configuration struct {
SubjectTypes []string `json:"subject_types_supported"` SubjectTypes []string `json:"subject_types_supported"`
} }
type proxyKeyGetter struct {
serviceaccount.PublicKeysGetter
listeners []serviceaccount.Listener
}
func (p *proxyKeyGetter) AddListener(listener serviceaccount.Listener) {
p.listeners = append(p.listeners, listener)
p.PublicKeysGetter.AddListener(listener)
}
func TestServeConfiguration(t *testing.T) { func TestServeConfiguration(t *testing.T) {
s, jwksURI := setupServer(t, exampleIssuer, defaultKeys) ecKeysGetter, err := serviceaccount.StaticPublicKeysGetter([]interface{}{getPublicKey(ecdsaPublicKey)})
if err != nil {
t.Fatal(err)
}
rsaKeysGetter, err := serviceaccount.StaticPublicKeysGetter([]interface{}{getPublicKey(rsaPublicKey)})
if err != nil {
t.Fatal(err)
}
keysGetter := &proxyKeyGetter{PublicKeysGetter: ecKeysGetter}
s, jwksURI := setupServer(t, exampleIssuer, keysGetter)
defer s.Close() defer s.Close()
want := Configuration{ wantEC := Configuration{
Issuer: exampleIssuer, Issuer: exampleIssuer,
JWKSURI: jwksURI, JWKSURI: jwksURI,
ResponseTypes: []string{"id_token"}, ResponseTypes: []string{"id_token"},
SubjectTypes: []string{"public"}, SubjectTypes: []string{"public"},
SigningAlgs: []string{"ES256", "RS256"}, SigningAlgs: []string{"ES256"},
}
wantRSA := Configuration{
Issuer: exampleIssuer,
JWKSURI: jwksURI,
ResponseTypes: []string{"id_token"},
SubjectTypes: []string{"public"},
SigningAlgs: []string{"RS256"},
} }
reqURL := s.URL + "/.well-known/openid-configuration" reqURL := s.URL + "/.well-known/openid-configuration"
expectConfiguration(t, reqURL, wantEC)
// modify the underlying keys, expect the same response
keysGetter.PublicKeysGetter = rsaKeysGetter
expectConfiguration(t, reqURL, wantEC)
// notify the metadata the keys changed, expected a modified response
for _, listener := range keysGetter.listeners {
listener.Enqueue()
}
expectConfiguration(t, reqURL, wantRSA)
}
func expectConfiguration(t *testing.T, reqURL string, want Configuration) {
resp, err := http.Get(reqURL) resp, err := http.Get(reqURL)
if err != nil { if err != nil {
t.Fatalf("Get(%s) = %v, %v want: <response>, <nil>", reqURL, resp, err) t.Fatalf("Get(%s) = %v, %v want: <response>, <nil>", reqURL, resp, err)
@ -185,16 +224,49 @@ func TestServeKeys(t *testing.T) {
for _, tt := range serveKeysTests { for _, tt := range serveKeysTests {
t.Run(tt.Name, func(t *testing.T) { t.Run(tt.Name, func(t *testing.T) {
s, _ := setupServer(t, exampleIssuer, tt.Keys) initialKeysGetter, err := serviceaccount.StaticPublicKeysGetter(tt.Keys)
if err != nil {
t.Fatal(err)
}
updatedKeysGetter, err := serviceaccount.StaticPublicKeysGetter([]interface{}{wantPubRSA})
if err != nil {
t.Fatal(err)
}
keysGetter := &proxyKeyGetter{PublicKeysGetter: initialKeysGetter}
s, _ := setupServer(t, exampleIssuer, keysGetter)
defer s.Close() defer s.Close()
reqURL := s.URL + "/openid/v1/jwks" reqURL := s.URL + "/openid/v1/jwks"
expectKeys(t, reqURL, tt.WantKeys)
// modify the underlying keys, expect the same response
keysGetter.PublicKeysGetter = updatedKeysGetter
expectKeys(t, reqURL, tt.WantKeys)
// notify the metadata the keys changed, expected a modified response
for _, listener := range keysGetter.listeners {
listener.Enqueue()
}
expectKeys(t, reqURL, []jose.JSONWebKey{{
Algorithm: "RS256",
Key: wantPubRSA,
KeyID: rsaKeyID,
Use: "sig",
Certificates: []*x509.Certificate{},
CertificateThumbprintSHA1: []uint8{},
CertificateThumbprintSHA256: []uint8{},
}})
})
}
}
func expectKeys(t *testing.T, reqURL string, wantKeys []jose.JSONWebKey) {
resp, err := http.Get(reqURL) resp, err := http.Get(reqURL)
if err != nil { if err != nil {
t.Fatalf("Get(%s) = %v, %v want: <response>, <nil>", reqURL, resp, err) t.Fatalf("Get(%s) = %v, %v want: <response>, <nil>", reqURL, resp, err)
} }
defer resp.Body.Close() defer func() {
_ = resp.Body.Close()
}()
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
t.Errorf("Get(%s) = %v, _ want: %v, _", reqURL, resp.StatusCode, http.StatusOK) t.Errorf("Get(%s) = %v, _ want: %v, _", reqURL, resp.StatusCode, http.StatusOK)
@ -216,16 +288,18 @@ func TestServeKeys(t *testing.T) {
func(x, y *big.Int) bool { func(x, y *big.Int) bool {
return x.Cmp(y) == 0 return x.Cmp(y) == 0
}) })
if !cmp.Equal(tt.WantKeys, ks.Keys, bigIntComparer) { if !cmp.Equal(wantKeys, ks.Keys, bigIntComparer) {
t.Errorf("unexpected diff in JWKS keys (-want, +got): %v", t.Errorf("unexpected diff in JWKS keys (-want, +got): %v",
cmp.Diff(tt.WantKeys, ks.Keys, bigIntComparer)) cmp.Diff(wantKeys, ks.Keys, bigIntComparer))
}
})
} }
} }
func TestURLBoundaries(t *testing.T) { func TestURLBoundaries(t *testing.T) {
s, _ := setupServer(t, exampleIssuer, defaultKeys) keysGetter, err := serviceaccount.StaticPublicKeysGetter(defaultKeys)
if err != nil {
t.Fatal(err)
}
s, _ := setupServer(t, exampleIssuer, keysGetter)
defer s.Close() defer s.Close()
for _, tt := range []struct { for _, tt := range []struct {
@ -380,7 +454,11 @@ func TestNewOpenIDMetadata(t *testing.T) {
} }
for _, tc := range cases { for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
md, err := serviceaccount.NewOpenIDMetadata(tc.issuerURL, tc.jwksURI, tc.externalAddress, tc.keys) keysGetter, err := serviceaccount.StaticPublicKeysGetter(tc.keys)
if err != nil {
t.Fatal(err)
}
md, err := serviceaccount.NewOpenIDMetadataProvider(tc.issuerURL, tc.jwksURI, tc.externalAddress, keysGetter)
if tc.err { if tc.err {
if err == nil { if err == nil {
t.Fatalf("got <nil>, want error") t.Fatalf("got <nil>, want error")
@ -390,13 +468,13 @@ func TestNewOpenIDMetadata(t *testing.T) {
t.Fatalf("got error %v, want <nil>", err) t.Fatalf("got error %v, want <nil>", err)
} }
config := string(md.ConfigJSON) config, _ := md.GetConfigJSON()
keyset := string(md.PublicKeysetJSON) keyset, _ := md.GetKeysetJSON()
if config != tc.wantConfig { if string(config) != tc.wantConfig {
t.Errorf("got metadata %s, want %s", config, tc.wantConfig) t.Errorf("got metadata %s, want %s", string(config), tc.wantConfig)
} }
if keyset != tc.wantKeyset { if string(keyset) != tc.wantKeyset {
t.Errorf("got keyset %s, want %s", keyset, tc.wantKeyset) t.Errorf("got keyset %s, want %s", string(keyset), tc.wantKeyset)
} }
}) })
} }