diff --git a/pkg/serviceaccount/BUILD b/pkg/serviceaccount/BUILD index 8bd9e629607..bd314ae6a26 100644 --- a/pkg/serviceaccount/BUILD +++ b/pkg/serviceaccount/BUILD @@ -60,6 +60,7 @@ go_test( "//staging/src/k8s.io/client-go/listers/core/v1:go_default_library", "//staging/src/k8s.io/client-go/tools/cache:go_default_library", "//staging/src/k8s.io/client-go/util/keyutil:go_default_library", + "//vendor/gopkg.in/square/go-jose.v2:go_default_library", "//vendor/gopkg.in/square/go-jose.v2/jwt:go_default_library", ], ) diff --git a/pkg/serviceaccount/jwt.go b/pkg/serviceaccount/jwt.go index 233fdee2d68..ba03a22233d 100644 --- a/pkg/serviceaccount/jwt.go +++ b/pkg/serviceaccount/jwt.go @@ -18,9 +18,11 @@ package serviceaccount import ( "context" + "crypto" "crypto/ecdsa" "crypto/elliptic" "crypto/rsa" + "crypto/x509" "encoding/base64" "encoding/json" "fmt" @@ -29,7 +31,7 @@ import ( jose "gopkg.in/square/go-jose.v2" "gopkg.in/square/go-jose.v2/jwt" - "k8s.io/api/core/v1" + v1 "k8s.io/api/core/v1" utilerrors "k8s.io/apimachinery/pkg/util/errors" "k8s.io/apiserver/pkg/authentication/authenticator" ) @@ -53,43 +55,148 @@ type TokenGenerator interface { // JWTTokenGenerator returns a TokenGenerator that generates signed JWT tokens, using the given privateKey. // privateKey is a PEM-encoded byte array of a private RSA key. -// JWTTokenAuthenticator() func JWTTokenGenerator(iss string, privateKey interface{}) (TokenGenerator, error) { - var alg jose.SignatureAlgorithm + var signer jose.Signer + var err error switch pk := privateKey.(type) { case *rsa.PrivateKey: - alg = jose.RS256 + signer, err = signerFromRSAPrivateKey(pk) + if err != nil { + return nil, fmt.Errorf("could not generate signer for RSA keypair: %v", err) + } case *ecdsa.PrivateKey: - switch pk.Curve { - case elliptic.P256(): - alg = jose.ES256 - case elliptic.P384(): - alg = jose.ES384 - case elliptic.P521(): - alg = jose.ES512 - default: - return nil, fmt.Errorf("unknown private key curve, must be 256, 384, or 521") + signer, err = signerFromECDSAPrivateKey(pk) + if err != nil { + return nil, fmt.Errorf("could not generate signer for ECDSA keypair: %v", err) } case jose.OpaqueSigner: - alg = jose.SignatureAlgorithm(pk.Public().Algorithm) + signer, err = signerFromOpaqueSigner(pk) + if err != nil { + return nil, fmt.Errorf("could not generate signer for OpaqueSigner: %v", err) + } default: return nil, fmt.Errorf("unknown private key type %T, must be *rsa.PrivateKey, *ecdsa.PrivateKey, or jose.OpaqueSigner", privateKey) } + return &jwtTokenGenerator{ + iss: iss, + signer: signer, + }, nil +} + +// keyIDFromPublicKey derives a key ID non-reversibly from a public key. +// +// The Key ID is field on a given on JWTs and JWKs that help relying parties +// pick the correct key for verification when the identity party advertises +// multiple keys. +// +// Making the derivation non-reversible makes it impossible for someone to +// accidentally obtain the real key from the key ID and use it for token +// validation. +func keyIDFromPublicKey(publicKey interface{}) (string, error) { + publicKeyDERBytes, err := x509.MarshalPKIXPublicKey(publicKey) + if err != nil { + return "", fmt.Errorf("failed to serialize public key to DER format: %v", err) + } + + hasher := crypto.SHA256.New() + hasher.Write(publicKeyDERBytes) + publicKeyDERHash := hasher.Sum(nil) + + keyID := base64.RawURLEncoding.EncodeToString(publicKeyDERHash) + + return keyID, nil +} + +func signerFromRSAPrivateKey(keyPair *rsa.PrivateKey) (jose.Signer, error) { + keyID, err := keyIDFromPublicKey(&keyPair.PublicKey) + if err != nil { + return nil, fmt.Errorf("failed to derive keyID: %v", err) + } + + // Wrap the RSA keypair in a JOSE JWK with the designated key ID. + privateJWK := &jose.JSONWebKey{ + Algorithm: string(jose.RS256), + Key: keyPair, + KeyID: keyID, + Use: "sig", + } + + signer, err := jose.NewSigner( + jose.SigningKey{ + Algorithm: jose.RS256, + Key: privateJWK, + }, + nil, + ) + + if err != nil { + return nil, fmt.Errorf("failed to create signer: %v", err) + } + + return signer, nil +} + +func signerFromECDSAPrivateKey(keyPair *ecdsa.PrivateKey) (jose.Signer, error) { + var alg jose.SignatureAlgorithm + switch keyPair.Curve { + case elliptic.P256(): + alg = jose.ES256 + case elliptic.P384(): + alg = jose.ES384 + case elliptic.P521(): + alg = jose.ES512 + default: + return nil, fmt.Errorf("unknown private key curve, must be 256, 384, or 521") + } + + keyID, err := keyIDFromPublicKey(&keyPair.PublicKey) + if err != nil { + return nil, fmt.Errorf("failed to derive keyID: %v", err) + } + + // Wrap the ECDSA keypair in a JOSE JWK with the designated key ID. + privateJWK := &jose.JSONWebKey{ + Algorithm: string(alg), + Key: keyPair, + KeyID: keyID, + Use: "sig", + } + signer, err := jose.NewSigner( jose.SigningKey{ Algorithm: alg, - Key: privateKey, + Key: privateJWK, }, nil, ) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to create signer: %v", err) } - return &jwtTokenGenerator{ - iss: iss, - signer: signer, - }, nil + + return signer, nil +} + +func signerFromOpaqueSigner(opaqueSigner jose.OpaqueSigner) (jose.Signer, error) { + alg := jose.SignatureAlgorithm(opaqueSigner.Public().Algorithm) + + signer, err := jose.NewSigner( + jose.SigningKey{ + Algorithm: alg, + Key: &jose.JSONWebKey{ + Algorithm: string(alg), + Key: opaqueSigner, + KeyID: opaqueSigner.Public().KeyID, + Use: "sig", + }, + }, + nil, + ) + if err != nil { + return nil, fmt.Errorf("failed to create signer: %v", err) + } + + return signer, nil } type jwtTokenGenerator struct { @@ -155,6 +262,7 @@ func (j *jwtTokenAuthenticator) AuthenticateToken(ctx context.Context, tokenData public := &jwt.Claims{} private := j.validator.NewPrivateClaims() + // TODO: Pick the key that has the same key ID as `tok`, if one exists. var ( found bool errlist []error diff --git a/pkg/serviceaccount/jwt_test.go b/pkg/serviceaccount/jwt_test.go index a4fe25bc823..6e5b979811d 100644 --- a/pkg/serviceaccount/jwt_test.go +++ b/pkg/serviceaccount/jwt_test.go @@ -22,6 +22,8 @@ import ( "strings" "testing" + jose "gopkg.in/square/go-jose.v2" + v1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apiserver/pkg/authentication/authenticator" @@ -55,6 +57,13 @@ WwIDAQAB -----END PUBLIC KEY----- ` +// Obtained by: +// +// 1. Serializing rsaPublicKey as DER +// 2. Taking the SHA256 of the DER bytes +// 3. URLSafe Base64-encoding the sha bytes +const rsaKeyID = "JHJehTTTZlsspKHT-GaJxK7Kd1NQgZJu3fyK6K_QDYU" + const rsaPrivateKey = `-----BEGIN RSA PRIVATE KEY----- MIIEowIBAAKCAQEA249XwEo9k4tM8fMxV7zxOhcrP+WvXn917koM5Qr2ZXs4vo26 e4ytdlrV0bQ9SlcLpQVSYjIxNfhTZdDt+ecIzshKuv1gKIxbbLQMOuK1eA/4HALy @@ -97,6 +106,13 @@ MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEH6cuzP8XuD5wal6wf9M6xDljTOPL X2i8uIp/C/ASqiIGUeeKQtX0/IR3qCXyThP/dbCiHrF3v1cuhBOHY8CLVg== -----END PUBLIC KEY-----` +// Obtained by: +// +// 1. Serializing ecdsaPublicKey as DER +// 2. Taking the SHA256 of the DER bytes +// 3. URLSafe Base64-encoding the sha bytes +const ecdsaKeyID = "SoABiieYuNx4UdqYvZRVeuC6SihxgLrhLy9peHMHpTc" + func getPrivateKey(data string) interface{} { key, _ := keyutil.ParsePrivateKeyPEM([]byte(data)) return key @@ -106,6 +122,7 @@ func getPublicKey(data string) interface{} { keys, _ := keyutil.ParsePublicKeysPEM([]byte(data)) return keys[0] } + func TestTokenGenerateAndValidate(t *testing.T) { expectedUserName := "system:serviceaccount:test:my-service-account" expectedUserUID := "12345" @@ -147,6 +164,8 @@ func TestTokenGenerateAndValidate(t *testing.T) { "token": []byte(rsaToken), } + checkJSONWebSignatureHasKeyID(t, rsaToken, rsaKeyID) + // Generate the ECDSA token ecdsaGenerator, err := serviceaccount.JWTTokenGenerator(serviceaccount.LegacyIssuer, getPrivateKey(ecdsaPrivateKey)) if err != nil { @@ -163,6 +182,8 @@ func TestTokenGenerateAndValidate(t *testing.T) { "token": []byte(ecdsaToken), } + checkJSONWebSignatureHasKeyID(t, ecdsaToken, ecdsaKeyID) + // Generate signer with same keys as RSA signer but different issuer badIssuerGenerator, err := serviceaccount.JWTTokenGenerator("foo", getPrivateKey(rsaPrivateKey)) if err != nil { @@ -331,6 +352,17 @@ func TestTokenGenerateAndValidate(t *testing.T) { } } +func checkJSONWebSignatureHasKeyID(t *testing.T, jwsString string, expectedKeyID string) { + jws, err := jose.ParseSigned(jwsString) + if err != nil { + t.Fatalf("Error checking for key ID: couldn't parse token: %v", err) + } + + if jws.Signatures[0].Header.KeyID != expectedKeyID { + t.Errorf("Token %q has the wrong KeyID (got %q, want %q)", jwsString, jws.Signatures[0].Header.KeyID, expectedKeyID) + } +} + func newIndexer(get func(namespace, name string) (interface{}, error)) cache.Indexer { return &fakeIndexer{get: get} }