From 3da48466d61a9a5a2b7aefe1bafaa1ebd0e22b49 Mon Sep 17 00:00:00 2001 From: Monis Khan Date: Thu, 23 May 2024 15:14:19 -0400 Subject: [PATCH] serviceaccount: use generics to remove runtime type checks during validation Signed-off-by: Monis Khan --- pkg/serviceaccount/claims.go | 17 ++++------------- pkg/serviceaccount/jwt.go | 26 +++++++++++--------------- pkg/serviceaccount/legacy.go | 17 ++++------------- 3 files changed, 19 insertions(+), 41 deletions(-) diff --git a/pkg/serviceaccount/claims.go b/pkg/serviceaccount/claims.go index 61128579bce..64feaca7ed3 100644 --- a/pkg/serviceaccount/claims.go +++ b/pkg/serviceaccount/claims.go @@ -24,11 +24,11 @@ import ( "github.com/google/uuid" "gopkg.in/square/go-jose.v2/jwt" - "k8s.io/klog/v2" "k8s.io/apiserver/pkg/audit" apiserverserviceaccount "k8s.io/apiserver/pkg/authentication/serviceaccount" utilfeature "k8s.io/apiserver/pkg/util/feature" + "k8s.io/klog/v2" "k8s.io/kubernetes/pkg/apis/core" "k8s.io/kubernetes/pkg/features" ) @@ -128,7 +128,7 @@ func Claims(sa core.ServiceAccount, pod *core.Pod, secret *core.Secret, node *co return sc, pc, nil } -func NewValidator(getter ServiceAccountTokenGetter) Validator { +func NewValidator(getter ServiceAccountTokenGetter) Validator[privateClaims] { return &validator{ getter: getter, } @@ -138,14 +138,9 @@ type validator struct { getter ServiceAccountTokenGetter } -var _ = Validator(&validator{}) +var _ = Validator[privateClaims](&validator{}) -func (v *validator) Validate(ctx context.Context, _ string, public *jwt.Claims, privateObj interface{}) (*apiserverserviceaccount.ServiceAccountInfo, error) { - private, ok := privateObj.(*privateClaims) - if !ok { - klog.Errorf("service account jwt validator expected private claim of type *privateClaims but got: %T", privateObj) - return nil, errors.New("service account token claims could not be validated due to unexpected private claim") - } +func (v *validator) Validate(ctx context.Context, _ string, public *jwt.Claims, private *privateClaims) (*apiserverserviceaccount.ServiceAccountInfo, error) { nowTime := now() err := public.Validate(jwt.Expected{ Time: nowTime, @@ -294,7 +289,3 @@ func (v *validator) Validate(ctx context.Context, _ string, public *jwt.Claims, CredentialID: apiserverserviceaccount.CredentialIDForJTI(jti), }, nil } - -func (v *validator) NewPrivateClaims() interface{} { - return &privateClaims{} -} diff --git a/pkg/serviceaccount/jwt.go b/pkg/serviceaccount/jwt.go index 45bf3828074..d6168e3411b 100644 --- a/pkg/serviceaccount/jwt.go +++ b/pkg/serviceaccount/jwt.go @@ -225,12 +225,12 @@ func (j *jwtTokenGenerator) GenerateToken(claims *jwt.Claims, privateClaims inte // JWTTokenAuthenticator authenticates tokens as JWT tokens produced by JWTTokenGenerator // Token signatures are verified using each of the given public keys until one works (allowing key rotation) // If lookup is true, the service account and secret referenced as claims inside the token are retrieved and verified with the provided ServiceAccountTokenGetter -func JWTTokenAuthenticator(issuers []string, keys []interface{}, implicitAuds authenticator.Audiences, validator Validator) authenticator.Token { +func JWTTokenAuthenticator[PrivateClaims any](issuers []string, keys []interface{}, implicitAuds authenticator.Audiences, validator Validator[PrivateClaims]) authenticator.Token { issuersMap := make(map[string]bool) for _, issuer := range issuers { issuersMap[issuer] = true } - return &jwtTokenAuthenticator{ + return &jwtTokenAuthenticator[PrivateClaims]{ issuers: issuersMap, keys: keys, implicitAuds: implicitAuds, @@ -238,29 +238,25 @@ func JWTTokenAuthenticator(issuers []string, keys []interface{}, implicitAuds au } } -type jwtTokenAuthenticator struct { +type jwtTokenAuthenticator[PrivateClaims any] struct { issuers map[string]bool keys []interface{} - validator Validator + validator Validator[PrivateClaims] implicitAuds authenticator.Audiences } // Validator is called by the JWT token authenticator to apply domain specific // validation to a token and extract user information. -type Validator interface { +// PrivateClaims is the struct that the authenticator should deserialize the JWT payload into, thus +// it should contain fields for any private claims that the Validator requires to validate the JWT. +type Validator[PrivateClaims any] interface { // Validate validates a token and returns user information or an error. // Validator can assume that the issuer and signature of a token are already // verified when this function is called. - Validate(ctx context.Context, tokenData string, public *jwt.Claims, private interface{}) (*apiserverserviceaccount.ServiceAccountInfo, error) - // NewPrivateClaims returns a struct that the authenticator should - // deserialize the JWT payload into. The authenticator may then pass this - // struct back to the Validator as the 'private' argument to a Validate() - // call. This struct should contain fields for any private claims that the - // Validator requires to validate the JWT. - NewPrivateClaims() interface{} + Validate(ctx context.Context, tokenData string, public *jwt.Claims, private *PrivateClaims) (*apiserverserviceaccount.ServiceAccountInfo, error) } -func (j *jwtTokenAuthenticator) AuthenticateToken(ctx context.Context, tokenData string) (*authenticator.Response, bool, error) { +func (j *jwtTokenAuthenticator[PrivateClaims]) AuthenticateToken(ctx context.Context, tokenData string) (*authenticator.Response, bool, error) { if !j.hasCorrectIssuer(tokenData) { return nil, false, nil } @@ -271,7 +267,7 @@ func (j *jwtTokenAuthenticator) AuthenticateToken(ctx context.Context, tokenData } public := &jwt.Claims{} - private := j.validator.NewPrivateClaims() + private := new(PrivateClaims) // TODO: Pick the key that has the same key ID as `tok`, if one exists. var ( @@ -334,7 +330,7 @@ func (j *jwtTokenAuthenticator) AuthenticateToken(ctx context.Context, tokenData // // Note: go-jose currently does not allow access to unverified JWS payloads. // See https://github.com/square/go-jose/issues/169 -func (j *jwtTokenAuthenticator) hasCorrectIssuer(tokenData string) bool { +func (j *jwtTokenAuthenticator[PrivateClaims]) hasCorrectIssuer(tokenData string) bool { if strings.HasPrefix(strings.TrimSpace(tokenData), "{") { return false } diff --git a/pkg/serviceaccount/legacy.go b/pkg/serviceaccount/legacy.go index f05c4d1b171..52951d58ec2 100644 --- a/pkg/serviceaccount/legacy.go +++ b/pkg/serviceaccount/legacy.go @@ -25,6 +25,7 @@ import ( "time" "gopkg.in/square/go-jose.v2/jwt" + v1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" @@ -61,7 +62,7 @@ type legacyPrivateClaims struct { Namespace string `json:"kubernetes.io/serviceaccount/namespace"` } -func NewLegacyValidator(lookup bool, getter ServiceAccountTokenGetter, secretsWriter typedv1core.SecretsGetter) (Validator, error) { +func NewLegacyValidator(lookup bool, getter ServiceAccountTokenGetter, secretsWriter typedv1core.SecretsGetter) (Validator[legacyPrivateClaims], error) { if lookup && getter == nil { return nil, errors.New("ServiceAccountTokenGetter must be provided") } @@ -81,15 +82,9 @@ type legacyValidator struct { secretsWriter typedv1core.SecretsGetter } -var _ = Validator(&legacyValidator{}) - -func (v *legacyValidator) Validate(ctx context.Context, tokenData string, public *jwt.Claims, privateObj interface{}) (*apiserverserviceaccount.ServiceAccountInfo, error) { - private, ok := privateObj.(*legacyPrivateClaims) - if !ok { - klog.Errorf("jwt validator expected private claim of type *legacyPrivateClaims but got: %T", privateObj) - return nil, errors.New("Token could not be validated.") - } +var _ = Validator[legacyPrivateClaims](&legacyValidator{}) +func (v *legacyValidator) Validate(ctx context.Context, tokenData string, public *jwt.Claims, private *legacyPrivateClaims) (*apiserverserviceaccount.ServiceAccountInfo, error) { // Make sure the claims we need exist if len(public.Subject) == 0 { return nil, errors.New("sub claim is missing") @@ -201,7 +196,3 @@ func (v *legacyValidator) patchSecretWithLastUsedDate(ctx context.Context, secre } } } - -func (v *legacyValidator) NewPrivateClaims() interface{} { - return &legacyPrivateClaims{} -}