serviceaccount: use generics to remove runtime type checks during validation

Signed-off-by: Monis Khan <mok@microsoft.com>
This commit is contained in:
Monis Khan 2024-05-23 15:14:19 -04:00
parent 5a121aad53
commit 3da48466d6
No known key found for this signature in database
3 changed files with 19 additions and 41 deletions

View File

@ -24,11 +24,11 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
"gopkg.in/square/go-jose.v2/jwt" "gopkg.in/square/go-jose.v2/jwt"
"k8s.io/klog/v2"
"k8s.io/apiserver/pkg/audit" "k8s.io/apiserver/pkg/audit"
apiserverserviceaccount "k8s.io/apiserver/pkg/authentication/serviceaccount" apiserverserviceaccount "k8s.io/apiserver/pkg/authentication/serviceaccount"
utilfeature "k8s.io/apiserver/pkg/util/feature" utilfeature "k8s.io/apiserver/pkg/util/feature"
"k8s.io/klog/v2"
"k8s.io/kubernetes/pkg/apis/core" "k8s.io/kubernetes/pkg/apis/core"
"k8s.io/kubernetes/pkg/features" "k8s.io/kubernetes/pkg/features"
) )
@ -128,7 +128,7 @@ func Claims(sa core.ServiceAccount, pod *core.Pod, secret *core.Secret, node *co
return sc, pc, nil return sc, pc, nil
} }
func NewValidator(getter ServiceAccountTokenGetter) Validator { func NewValidator(getter ServiceAccountTokenGetter) Validator[privateClaims] {
return &validator{ return &validator{
getter: getter, getter: getter,
} }
@ -138,14 +138,9 @@ type validator struct {
getter ServiceAccountTokenGetter getter ServiceAccountTokenGetter
} }
var _ = Validator(&validator{}) var _ = Validator[privateClaims](&validator{})
func (v *validator) Validate(ctx context.Context, _ string, public *jwt.Claims, privateObj interface{}) (*apiserverserviceaccount.ServiceAccountInfo, error) { func (v *validator) Validate(ctx context.Context, _ string, public *jwt.Claims, private *privateClaims) (*apiserverserviceaccount.ServiceAccountInfo, error) {
private, ok := privateObj.(*privateClaims)
if !ok {
klog.Errorf("service account jwt validator expected private claim of type *privateClaims but got: %T", privateObj)
return nil, errors.New("service account token claims could not be validated due to unexpected private claim")
}
nowTime := now() nowTime := now()
err := public.Validate(jwt.Expected{ err := public.Validate(jwt.Expected{
Time: nowTime, Time: nowTime,
@ -294,7 +289,3 @@ func (v *validator) Validate(ctx context.Context, _ string, public *jwt.Claims,
CredentialID: apiserverserviceaccount.CredentialIDForJTI(jti), CredentialID: apiserverserviceaccount.CredentialIDForJTI(jti),
}, nil }, nil
} }
func (v *validator) NewPrivateClaims() interface{} {
return &privateClaims{}
}

View File

@ -225,12 +225,12 @@ func (j *jwtTokenGenerator) GenerateToken(claims *jwt.Claims, privateClaims inte
// JWTTokenAuthenticator authenticates tokens as JWT tokens produced by JWTTokenGenerator // JWTTokenAuthenticator authenticates tokens as JWT tokens produced by JWTTokenGenerator
// Token signatures are verified using each of the given public keys until one works (allowing key rotation) // Token signatures are verified using each of the given public keys until one works (allowing key rotation)
// If lookup is true, the service account and secret referenced as claims inside the token are retrieved and verified with the provided ServiceAccountTokenGetter // If lookup is true, the service account and secret referenced as claims inside the token are retrieved and verified with the provided ServiceAccountTokenGetter
func JWTTokenAuthenticator(issuers []string, keys []interface{}, implicitAuds authenticator.Audiences, validator Validator) authenticator.Token { func JWTTokenAuthenticator[PrivateClaims any](issuers []string, keys []interface{}, implicitAuds authenticator.Audiences, validator Validator[PrivateClaims]) authenticator.Token {
issuersMap := make(map[string]bool) issuersMap := make(map[string]bool)
for _, issuer := range issuers { for _, issuer := range issuers {
issuersMap[issuer] = true issuersMap[issuer] = true
} }
return &jwtTokenAuthenticator{ return &jwtTokenAuthenticator[PrivateClaims]{
issuers: issuersMap, issuers: issuersMap,
keys: keys, keys: keys,
implicitAuds: implicitAuds, implicitAuds: implicitAuds,
@ -238,29 +238,25 @@ func JWTTokenAuthenticator(issuers []string, keys []interface{}, implicitAuds au
} }
} }
type jwtTokenAuthenticator struct { type jwtTokenAuthenticator[PrivateClaims any] struct {
issuers map[string]bool issuers map[string]bool
keys []interface{} keys []interface{}
validator Validator validator Validator[PrivateClaims]
implicitAuds authenticator.Audiences implicitAuds authenticator.Audiences
} }
// Validator is called by the JWT token authenticator to apply domain specific // Validator is called by the JWT token authenticator to apply domain specific
// validation to a token and extract user information. // validation to a token and extract user information.
type Validator interface { // PrivateClaims is the struct that the authenticator should deserialize the JWT payload into, thus
// it should contain fields for any private claims that the Validator requires to validate the JWT.
type Validator[PrivateClaims any] interface {
// Validate validates a token and returns user information or an error. // Validate validates a token and returns user information or an error.
// Validator can assume that the issuer and signature of a token are already // Validator can assume that the issuer and signature of a token are already
// verified when this function is called. // verified when this function is called.
Validate(ctx context.Context, tokenData string, public *jwt.Claims, private interface{}) (*apiserverserviceaccount.ServiceAccountInfo, error) Validate(ctx context.Context, tokenData string, public *jwt.Claims, private *PrivateClaims) (*apiserverserviceaccount.ServiceAccountInfo, error)
// NewPrivateClaims returns a struct that the authenticator should
// deserialize the JWT payload into. The authenticator may then pass this
// struct back to the Validator as the 'private' argument to a Validate()
// call. This struct should contain fields for any private claims that the
// Validator requires to validate the JWT.
NewPrivateClaims() interface{}
} }
func (j *jwtTokenAuthenticator) AuthenticateToken(ctx context.Context, tokenData string) (*authenticator.Response, bool, error) { func (j *jwtTokenAuthenticator[PrivateClaims]) AuthenticateToken(ctx context.Context, tokenData string) (*authenticator.Response, bool, error) {
if !j.hasCorrectIssuer(tokenData) { if !j.hasCorrectIssuer(tokenData) {
return nil, false, nil return nil, false, nil
} }
@ -271,7 +267,7 @@ func (j *jwtTokenAuthenticator) AuthenticateToken(ctx context.Context, tokenData
} }
public := &jwt.Claims{} public := &jwt.Claims{}
private := j.validator.NewPrivateClaims() private := new(PrivateClaims)
// TODO: Pick the key that has the same key ID as `tok`, if one exists. // TODO: Pick the key that has the same key ID as `tok`, if one exists.
var ( var (
@ -334,7 +330,7 @@ func (j *jwtTokenAuthenticator) AuthenticateToken(ctx context.Context, tokenData
// //
// Note: go-jose currently does not allow access to unverified JWS payloads. // Note: go-jose currently does not allow access to unverified JWS payloads.
// See https://github.com/square/go-jose/issues/169 // See https://github.com/square/go-jose/issues/169
func (j *jwtTokenAuthenticator) hasCorrectIssuer(tokenData string) bool { func (j *jwtTokenAuthenticator[PrivateClaims]) hasCorrectIssuer(tokenData string) bool {
if strings.HasPrefix(strings.TrimSpace(tokenData), "{") { if strings.HasPrefix(strings.TrimSpace(tokenData), "{") {
return false return false
} }

View File

@ -25,6 +25,7 @@ import (
"time" "time"
"gopkg.in/square/go-jose.v2/jwt" "gopkg.in/square/go-jose.v2/jwt"
v1 "k8s.io/api/core/v1" v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/types" "k8s.io/apimachinery/pkg/types"
@ -61,7 +62,7 @@ type legacyPrivateClaims struct {
Namespace string `json:"kubernetes.io/serviceaccount/namespace"` Namespace string `json:"kubernetes.io/serviceaccount/namespace"`
} }
func NewLegacyValidator(lookup bool, getter ServiceAccountTokenGetter, secretsWriter typedv1core.SecretsGetter) (Validator, error) { func NewLegacyValidator(lookup bool, getter ServiceAccountTokenGetter, secretsWriter typedv1core.SecretsGetter) (Validator[legacyPrivateClaims], error) {
if lookup && getter == nil { if lookup && getter == nil {
return nil, errors.New("ServiceAccountTokenGetter must be provided") return nil, errors.New("ServiceAccountTokenGetter must be provided")
} }
@ -81,15 +82,9 @@ type legacyValidator struct {
secretsWriter typedv1core.SecretsGetter secretsWriter typedv1core.SecretsGetter
} }
var _ = Validator(&legacyValidator{}) var _ = Validator[legacyPrivateClaims](&legacyValidator{})
func (v *legacyValidator) Validate(ctx context.Context, tokenData string, public *jwt.Claims, privateObj interface{}) (*apiserverserviceaccount.ServiceAccountInfo, error) {
private, ok := privateObj.(*legacyPrivateClaims)
if !ok {
klog.Errorf("jwt validator expected private claim of type *legacyPrivateClaims but got: %T", privateObj)
return nil, errors.New("Token could not be validated.")
}
func (v *legacyValidator) Validate(ctx context.Context, tokenData string, public *jwt.Claims, private *legacyPrivateClaims) (*apiserverserviceaccount.ServiceAccountInfo, error) {
// Make sure the claims we need exist // Make sure the claims we need exist
if len(public.Subject) == 0 { if len(public.Subject) == 0 {
return nil, errors.New("sub claim is missing") return nil, errors.New("sub claim is missing")
@ -201,7 +196,3 @@ func (v *legacyValidator) patchSecretWithLastUsedDate(ctx context.Context, secre
} }
} }
} }
func (v *legacyValidator) NewPrivateClaims() interface{} {
return &legacyPrivateClaims{}
}