mirror of
https://github.com/kubernetes/client-go.git
synced 2025-07-05 11:16:23 +00:00
oidc client plugin: reduce round trips and switch to golang.org/x/oauth2
This PR attempts to simplify the OpenID Connect client plugin to reduce round trips. The steps taken by the client are now: * If ID Token isn't expired: * Do nothing. * If ID Token is expired: * Query /.well-known discovery URL to find token_endpoint. * Use an OAuth2 client and refresh token to request new ID token. This avoids the previous pattern of always initializing a client, which would hit the /.well-known endpoint several times. The client no longer does token validation since the server already does this. As a result, this code no longer imports github.com/coreos/go-oidc, instead just using golang.org/x/oauth2 for refreshing. Kubernetes-commit: 6915f857574505a2cd2072c32d9d6da66ce6f55a
This commit is contained in:
parent
a5e25d7218
commit
1309db5ec6
@ -13,12 +13,6 @@ go_test(
|
||||
srcs = ["oidc_test.go"],
|
||||
library = ":go_default_library",
|
||||
tags = ["automanaged"],
|
||||
deps = [
|
||||
"//vendor/github.com/coreos/go-oidc/jose:go_default_library",
|
||||
"//vendor/github.com/coreos/go-oidc/key:go_default_library",
|
||||
"//vendor/github.com/coreos/go-oidc/oauth2:go_default_library",
|
||||
"//vendor/k8s.io/client-go/plugin/pkg/auth/authenticator/token/oidc/testing:go_default_library",
|
||||
],
|
||||
)
|
||||
|
||||
go_library(
|
||||
@ -26,10 +20,8 @@ go_library(
|
||||
srcs = ["oidc.go"],
|
||||
tags = ["automanaged"],
|
||||
deps = [
|
||||
"//vendor/github.com/coreos/go-oidc/jose:go_default_library",
|
||||
"//vendor/github.com/coreos/go-oidc/oauth2:go_default_library",
|
||||
"//vendor/github.com/coreos/go-oidc/oidc:go_default_library",
|
||||
"//vendor/github.com/golang/glog:go_default_library",
|
||||
"//vendor/golang.org/x/oauth2:go_default_library",
|
||||
"//vendor/k8s.io/client-go/rest:go_default_library",
|
||||
],
|
||||
)
|
||||
|
@ -17,19 +17,19 @@ limitations under the License.
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/jose"
|
||||
"github.com/coreos/go-oidc/oauth2"
|
||||
"github.com/coreos/go-oidc/oidc"
|
||||
"github.com/golang/glog"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
restclient "k8s.io/client-go/rest"
|
||||
)
|
||||
|
||||
@ -39,9 +39,11 @@ const (
|
||||
cfgClientSecret = "client-secret"
|
||||
cfgCertificateAuthority = "idp-certificate-authority"
|
||||
cfgCertificateAuthorityData = "idp-certificate-authority-data"
|
||||
cfgExtraScopes = "extra-scopes"
|
||||
cfgIDToken = "id-token"
|
||||
cfgRefreshToken = "refresh-token"
|
||||
|
||||
// Unused. Scopes aren't sent during refreshing.
|
||||
cfgExtraScopes = "extra-scopes"
|
||||
)
|
||||
|
||||
func init() {
|
||||
@ -59,9 +61,12 @@ const expiryDelta = 10 * time.Second
|
||||
|
||||
var cache = newClientCache()
|
||||
|
||||
// Like TLS transports, keep a cache of OIDC clients indexed by issuer URL.
|
||||
// Like TLS transports, keep a cache of OIDC clients indexed by issuer URL. This ensures
|
||||
// current requests from different clients don't concurrently attempt to refresh the same
|
||||
// set of credentials.
|
||||
type clientCache struct {
|
||||
mu sync.RWMutex
|
||||
mu sync.RWMutex
|
||||
|
||||
cache map[cacheKey]*oidcAuthProvider
|
||||
}
|
||||
|
||||
@ -72,27 +77,22 @@ func newClientCache() *clientCache {
|
||||
type cacheKey struct {
|
||||
// Canonical issuer URL string of the provider.
|
||||
issuerURL string
|
||||
|
||||
clientID string
|
||||
clientSecret string
|
||||
|
||||
// Don't use CA as cache key because we only add a cache entry if we can connect
|
||||
// to the issuer in the first place. A valid CA is a prerequisite.
|
||||
clientID string
|
||||
}
|
||||
|
||||
func (c *clientCache) getClient(issuer, clientID, clientSecret string) (*oidcAuthProvider, bool) {
|
||||
func (c *clientCache) getClient(issuer, clientID string) (*oidcAuthProvider, bool) {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
client, ok := c.cache[cacheKey{issuer, clientID, clientSecret}]
|
||||
client, ok := c.cache[cacheKey{issuer, clientID}]
|
||||
return client, ok
|
||||
}
|
||||
|
||||
// setClient attempts to put the client in the cache but may return any clients
|
||||
// with the same keys set before. This is so there's only ever one client for a provider.
|
||||
func (c *clientCache) setClient(issuer, clientID, clientSecret string, client *oidcAuthProvider) *oidcAuthProvider {
|
||||
func (c *clientCache) setClient(issuer, clientID string, client *oidcAuthProvider) *oidcAuthProvider {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
key := cacheKey{issuer, clientID, clientSecret}
|
||||
key := cacheKey{issuer, clientID}
|
||||
|
||||
// If another client has already initialized a client for the given provider we want
|
||||
// to use that client instead of the one we're trying to set. This is so all transports
|
||||
@ -117,14 +117,14 @@ func newOIDCAuthProvider(_ string, cfg map[string]string, persister restclient.A
|
||||
return nil, fmt.Errorf("Must provide %s", cfgClientID)
|
||||
}
|
||||
|
||||
clientSecret := cfg[cfgClientSecret]
|
||||
if clientSecret == "" {
|
||||
return nil, fmt.Errorf("Must provide %s", cfgClientSecret)
|
||||
// Check cache for existing provider.
|
||||
if provider, ok := cache.getClient(issuer, clientID); ok {
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
// Check cache for existing provider.
|
||||
if provider, ok := cache.getClient(issuer, clientID, clientSecret); ok {
|
||||
return provider, nil
|
||||
if len(cfg[cfgExtraScopes]) > 0 {
|
||||
glog.V(2).Infof("%s auth provider field depricated, refresh request don't send scopes",
|
||||
cfgExtraScopes)
|
||||
}
|
||||
|
||||
var certAuthData []byte
|
||||
@ -149,41 +149,20 @@ func newOIDCAuthProvider(_ string, cfg map[string]string, persister restclient.A
|
||||
}
|
||||
hc := &http.Client{Transport: trans}
|
||||
|
||||
providerCfg, err := oidc.FetchProviderConfig(hc, issuer)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error fetching provider config: %v", err)
|
||||
}
|
||||
|
||||
scopes := strings.Split(cfg[cfgExtraScopes], ",")
|
||||
oidcCfg := oidc.ClientConfig{
|
||||
HTTPClient: hc,
|
||||
Credentials: oidc.ClientCredentials{
|
||||
ID: clientID,
|
||||
Secret: clientSecret,
|
||||
},
|
||||
ProviderConfig: providerCfg,
|
||||
Scope: append(scopes, oidc.DefaultScope...),
|
||||
}
|
||||
client, err := oidc.NewClient(oidcCfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating OIDC Client: %v", err)
|
||||
}
|
||||
|
||||
provider := &oidcAuthProvider{
|
||||
client: &oidcClient{client},
|
||||
client: hc,
|
||||
now: time.Now,
|
||||
cfg: cfg,
|
||||
persister: persister,
|
||||
now: time.Now,
|
||||
}
|
||||
|
||||
return cache.setClient(issuer, clientID, clientSecret, provider), nil
|
||||
return cache.setClient(issuer, clientID, provider), nil
|
||||
}
|
||||
|
||||
type oidcAuthProvider struct {
|
||||
// Interface rather than a raw *oidc.Client for testing.
|
||||
client OIDCClient
|
||||
client *http.Client
|
||||
|
||||
// Stubbed out for testing.
|
||||
// Method for determining the current time.
|
||||
now func() time.Time
|
||||
|
||||
// Mutex guards persisting to the kubeconfig file and allows synchronized
|
||||
@ -205,11 +184,6 @@ func (p *oidcAuthProvider) Login() error {
|
||||
return errors.New("not yet implemented")
|
||||
}
|
||||
|
||||
type OIDCClient interface {
|
||||
refreshToken(rt string) (oauth2.TokenResponse, error)
|
||||
verifyJWT(jwt *jose.JWT) error
|
||||
}
|
||||
|
||||
type roundTripper struct {
|
||||
provider *oidcAuthProvider
|
||||
wrapped http.RoundTripper
|
||||
@ -243,7 +217,7 @@ func (p *oidcAuthProvider) idToken() (string, error) {
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if idToken, ok := p.cfg[cfgIDToken]; ok && len(idToken) > 0 {
|
||||
valid, err := verifyJWTExpiry(p.now(), idToken)
|
||||
valid, err := idTokenExpired(p.now, idToken)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@ -259,17 +233,27 @@ func (p *oidcAuthProvider) idToken() (string, error) {
|
||||
return "", errors.New("No valid id-token, and cannot refresh without refresh-token")
|
||||
}
|
||||
|
||||
tokens, err := p.client.refreshToken(rt)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("could not refresh token: %v", err)
|
||||
}
|
||||
jwt, err := jose.ParseJWT(tokens.IDToken)
|
||||
// Determine provider's OAuth2 token endpoint.
|
||||
tokenURL, err := tokenEndpoint(p.client, p.cfg[cfgIssuerUrl])
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if err := p.client.verifyJWT(&jwt); err != nil {
|
||||
return "", err
|
||||
config := oauth2.Config{
|
||||
ClientID: p.cfg[cfgClientID],
|
||||
ClientSecret: p.cfg[cfgClientSecret],
|
||||
Endpoint: oauth2.Endpoint{TokenURL: tokenURL},
|
||||
}
|
||||
|
||||
ctx := context.WithValue(context.Background(), oauth2.HTTPClient, p.client)
|
||||
token, err := config.TokenSource(ctx, &oauth2.Token{RefreshToken: rt}).Token()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to refresh token: %v", err)
|
||||
}
|
||||
|
||||
idToken, ok := token.Extra("id_token").(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("token response did not contain an id_token")
|
||||
}
|
||||
|
||||
// Create a new config to persist.
|
||||
@ -278,59 +262,109 @@ func (p *oidcAuthProvider) idToken() (string, error) {
|
||||
newCfg[key] = val
|
||||
}
|
||||
|
||||
if tokens.RefreshToken != "" && tokens.RefreshToken != rt {
|
||||
newCfg[cfgRefreshToken] = tokens.RefreshToken
|
||||
// Update the refresh token if the server returned another one.
|
||||
if token.RefreshToken != "" && token.RefreshToken != rt {
|
||||
newCfg[cfgRefreshToken] = token.RefreshToken
|
||||
}
|
||||
newCfg[cfgIDToken] = idToken
|
||||
|
||||
newCfg[cfgIDToken] = tokens.IDToken
|
||||
// Persist new config and if successful, update the in memory config.
|
||||
if err = p.persister.Persist(newCfg); err != nil {
|
||||
return "", fmt.Errorf("could not perist new tokens: %v", err)
|
||||
}
|
||||
|
||||
// Update the in memory config to reflect the on disk one.
|
||||
p.cfg = newCfg
|
||||
|
||||
return tokens.IDToken, nil
|
||||
return idToken, nil
|
||||
}
|
||||
|
||||
// oidcClient is the real implementation of the OIDCClient interface, which is
|
||||
// used for testing.
|
||||
type oidcClient struct {
|
||||
client *oidc.Client
|
||||
}
|
||||
|
||||
func (o *oidcClient) refreshToken(rt string) (oauth2.TokenResponse, error) {
|
||||
oac, err := o.client.OAuthClient()
|
||||
// tokenEndpoint uses OpenID Connect discovery to determine the OAuth2 token
|
||||
// endpoint for the provider, the endpoint the client will use the refresh
|
||||
// token against.
|
||||
func tokenEndpoint(client *http.Client, issuer string) (string, error) {
|
||||
// Well known URL for getting OpenID Connect metadata.
|
||||
//
|
||||
// https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderConfig
|
||||
wellKnown := strings.TrimSuffix(issuer, "/") + "/.well-known/openid-configuration"
|
||||
resp, err := client.Get(wellKnown)
|
||||
if err != nil {
|
||||
return oauth2.TokenResponse{}, err
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
// Don't produce an error that's too huge (e.g. if we get HTML back for some reason).
|
||||
const n = 80
|
||||
if len(body) > n {
|
||||
body = append(body[:n], []byte("...")...)
|
||||
}
|
||||
return "", fmt.Errorf("oidc: failed to query metadata endpoint %s: %q", resp.Status, body)
|
||||
}
|
||||
|
||||
return oac.RequestToken(oauth2.GrantTypeRefreshToken, rt)
|
||||
}
|
||||
|
||||
func (o *oidcClient) verifyJWT(jwt *jose.JWT) error {
|
||||
return o.client.VerifyJWT(*jwt)
|
||||
}
|
||||
|
||||
func verifyJWTExpiry(now time.Time, s string) (valid bool, err error) {
|
||||
jwt, err := jose.ParseJWT(s)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("invalid %q", cfgIDToken)
|
||||
// Metadata object. We only care about the token_endpoint, the thing endpoint
|
||||
// we'll be refreshing against.
|
||||
//
|
||||
// https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata
|
||||
var metadata struct {
|
||||
TokenURL string `json:"token_endpoint"`
|
||||
}
|
||||
claims, err := jwt.Claims()
|
||||
if err := json.Unmarshal(body, &metadata); err != nil {
|
||||
return "", fmt.Errorf("oidc: failed to decode provider discovery object: %v", err)
|
||||
}
|
||||
if metadata.TokenURL == "" {
|
||||
return "", fmt.Errorf("oidc: discovery object doesn't contain a token_endpoint")
|
||||
}
|
||||
return metadata.TokenURL, nil
|
||||
}
|
||||
|
||||
func idTokenExpired(now func() time.Time, idToken string) (bool, error) {
|
||||
parts := strings.Split(idToken, ".")
|
||||
if len(parts) != 3 {
|
||||
return false, fmt.Errorf("ID Token is not a valid JWT")
|
||||
}
|
||||
|
||||
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
exp, ok, err := claims.TimeClaim("exp")
|
||||
switch {
|
||||
case err != nil:
|
||||
return false, fmt.Errorf("failed to parse 'exp' claim: %v", err)
|
||||
case !ok:
|
||||
return false, errors.New("missing required 'exp' claim")
|
||||
case exp.After(now.Add(expiryDelta)):
|
||||
return true, nil
|
||||
var claims struct {
|
||||
Expiry jsonTime `json:"exp"`
|
||||
}
|
||||
if err := json.Unmarshal(payload, &claims); err != nil {
|
||||
return false, fmt.Errorf("parsing claims: %v", err)
|
||||
}
|
||||
|
||||
return false, nil
|
||||
return now().Add(expiryDelta).Before(time.Time(claims.Expiry)), nil
|
||||
}
|
||||
|
||||
// jsonTime is a json.Unmarshaler that parses a unix timestamp.
|
||||
// Because JSON numbers don't differentiate between ints and floats,
|
||||
// we want to ensure we can parse either.
|
||||
type jsonTime time.Time
|
||||
|
||||
func (j *jsonTime) UnmarshalJSON(b []byte) error {
|
||||
var n json.Number
|
||||
if err := json.Unmarshal(b, &n); err != nil {
|
||||
return err
|
||||
}
|
||||
var unix int64
|
||||
|
||||
if t, err := n.Int64(); err == nil {
|
||||
unix = t
|
||||
} else {
|
||||
f, err := n.Float64()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
unix = int64(f)
|
||||
}
|
||||
*j = jsonTime(time.Unix(unix, 0))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (j jsonTime) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(time.Time(j).Unix())
|
||||
}
|
||||
|
@ -18,366 +18,120 @@ package oidc
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path"
|
||||
"reflect"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/jose"
|
||||
"github.com/coreos/go-oidc/key"
|
||||
"github.com/coreos/go-oidc/oauth2"
|
||||
|
||||
oidctesting "k8s.io/client-go/plugin/pkg/auth/authenticator/token/oidc/testing"
|
||||
)
|
||||
|
||||
func clearCache() {
|
||||
cache = newClientCache()
|
||||
}
|
||||
func TestJSONTime(t *testing.T) {
|
||||
data := `{
|
||||
"t1": 1493851263,
|
||||
"t2": 1.493851263e9
|
||||
}`
|
||||
|
||||
type persister struct{}
|
||||
|
||||
// we don't need to actually persist anything because there's no way for us to
|
||||
// read from a persister.
|
||||
func (p *persister) Persist(map[string]string) error { return nil }
|
||||
|
||||
type noRefreshOIDCClient struct{}
|
||||
|
||||
func (c *noRefreshOIDCClient) refreshToken(rt string) (oauth2.TokenResponse, error) {
|
||||
return oauth2.TokenResponse{}, errors.New("alwaysErrOIDCClient: cannot refresh token")
|
||||
}
|
||||
|
||||
func (c *noRefreshOIDCClient) verifyJWT(jwt *jose.JWT) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type mockOIDCClient struct {
|
||||
tokenResponse oauth2.TokenResponse
|
||||
}
|
||||
|
||||
func (c *mockOIDCClient) refreshToken(rt string) (oauth2.TokenResponse, error) {
|
||||
return c.tokenResponse, nil
|
||||
}
|
||||
|
||||
func (c *mockOIDCClient) verifyJWT(jwt *jose.JWT) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestNewOIDCAuthProvider(t *testing.T) {
|
||||
tempDir, err := ioutil.TempDir(os.TempDir(), "oidc_test")
|
||||
if err != nil {
|
||||
t.Fatalf("Cannot make temp dir %v", err)
|
||||
var v struct {
|
||||
T1 jsonTime `json:"t1"`
|
||||
T2 jsonTime `json:"t2"`
|
||||
}
|
||||
cert := path.Join(tempDir, "oidc-cert")
|
||||
key := path.Join(tempDir, "oidc-key")
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
oidctesting.GenerateSelfSignedCert(t, "127.0.0.1", cert, key)
|
||||
op := oidctesting.NewOIDCProvider(t, "")
|
||||
srv, err := op.ServeTLSWithKeyPair(cert, key)
|
||||
if err != nil {
|
||||
t.Fatalf("Cannot start server %v", err)
|
||||
if err := json.Unmarshal([]byte(data), &v); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer srv.Close()
|
||||
wantT1 := time.Unix(1493851263, 0)
|
||||
wantT2 := time.Unix(1493851263, 0)
|
||||
gotT1 := time.Time(v.T1)
|
||||
gotT2 := time.Time(v.T2)
|
||||
|
||||
certData, err := ioutil.ReadFile(cert)
|
||||
if err != nil {
|
||||
t.Fatalf("Could not read cert bytes %v", err)
|
||||
if !wantT1.Equal(gotT1) {
|
||||
t.Errorf("t1 value: wanted %s got %s", wantT1, gotT1)
|
||||
}
|
||||
|
||||
makeToken := func(exp time.Time) *jose.JWT {
|
||||
jwt, err := jose.NewSignedJWT(jose.Claims(map[string]interface{}{
|
||||
"exp": exp.UTC().Unix(),
|
||||
}), op.PrivKey.Signer())
|
||||
if err != nil {
|
||||
t.Fatalf("Could not create signed JWT %v", err)
|
||||
}
|
||||
return jwt
|
||||
}
|
||||
|
||||
t0 := time.Now()
|
||||
|
||||
goodToken := makeToken(t0.Add(time.Hour)).Encode()
|
||||
expiredToken := makeToken(t0.Add(-time.Hour)).Encode()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
cfg map[string]string
|
||||
wantInitErr bool
|
||||
|
||||
client OIDCClient
|
||||
wantCfg map[string]string
|
||||
wantTokenErr bool
|
||||
}{
|
||||
{
|
||||
// A Valid configuration
|
||||
name: "no id token and no refresh token",
|
||||
cfg: map[string]string{
|
||||
cfgIssuerUrl: srv.URL,
|
||||
cfgCertificateAuthority: cert,
|
||||
cfgClientID: "client-id",
|
||||
cfgClientSecret: "client-secret",
|
||||
},
|
||||
wantTokenErr: true,
|
||||
},
|
||||
{
|
||||
name: "valid config with an initial token",
|
||||
cfg: map[string]string{
|
||||
cfgIssuerUrl: srv.URL,
|
||||
cfgCertificateAuthority: cert,
|
||||
cfgClientID: "client-id",
|
||||
cfgClientSecret: "client-secret",
|
||||
cfgIDToken: goodToken,
|
||||
},
|
||||
client: new(noRefreshOIDCClient),
|
||||
wantCfg: map[string]string{
|
||||
cfgIssuerUrl: srv.URL,
|
||||
cfgCertificateAuthority: cert,
|
||||
cfgClientID: "client-id",
|
||||
cfgClientSecret: "client-secret",
|
||||
cfgIDToken: goodToken,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid ID token with a refresh token",
|
||||
cfg: map[string]string{
|
||||
cfgIssuerUrl: srv.URL,
|
||||
cfgCertificateAuthority: cert,
|
||||
cfgClientID: "client-id",
|
||||
cfgClientSecret: "client-secret",
|
||||
cfgRefreshToken: "foo",
|
||||
cfgIDToken: expiredToken,
|
||||
},
|
||||
client: &mockOIDCClient{
|
||||
tokenResponse: oauth2.TokenResponse{
|
||||
IDToken: goodToken,
|
||||
},
|
||||
},
|
||||
wantCfg: map[string]string{
|
||||
cfgIssuerUrl: srv.URL,
|
||||
cfgCertificateAuthority: cert,
|
||||
cfgClientID: "client-id",
|
||||
cfgClientSecret: "client-secret",
|
||||
cfgRefreshToken: "foo",
|
||||
cfgIDToken: goodToken,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid ID token with a refresh token, server returns new refresh token",
|
||||
cfg: map[string]string{
|
||||
cfgIssuerUrl: srv.URL,
|
||||
cfgCertificateAuthority: cert,
|
||||
cfgClientID: "client-id",
|
||||
cfgClientSecret: "client-secret",
|
||||
cfgRefreshToken: "foo",
|
||||
cfgIDToken: expiredToken,
|
||||
},
|
||||
client: &mockOIDCClient{
|
||||
tokenResponse: oauth2.TokenResponse{
|
||||
IDToken: goodToken,
|
||||
RefreshToken: "bar",
|
||||
},
|
||||
},
|
||||
wantCfg: map[string]string{
|
||||
cfgIssuerUrl: srv.URL,
|
||||
cfgCertificateAuthority: cert,
|
||||
cfgClientID: "client-id",
|
||||
cfgClientSecret: "client-secret",
|
||||
cfgRefreshToken: "bar",
|
||||
cfgIDToken: goodToken,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "expired token and no refresh otken",
|
||||
cfg: map[string]string{
|
||||
cfgIssuerUrl: srv.URL,
|
||||
cfgCertificateAuthority: cert,
|
||||
cfgClientID: "client-id",
|
||||
cfgClientSecret: "client-secret",
|
||||
cfgIDToken: expiredToken,
|
||||
},
|
||||
wantTokenErr: true,
|
||||
},
|
||||
{
|
||||
name: "valid base64d ca",
|
||||
cfg: map[string]string{
|
||||
cfgIssuerUrl: srv.URL,
|
||||
cfgCertificateAuthorityData: base64.StdEncoding.EncodeToString(certData),
|
||||
cfgClientID: "client-id",
|
||||
cfgClientSecret: "client-secret",
|
||||
},
|
||||
client: new(noRefreshOIDCClient),
|
||||
wantTokenErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing client ID",
|
||||
cfg: map[string]string{
|
||||
cfgIssuerUrl: srv.URL,
|
||||
cfgCertificateAuthority: cert,
|
||||
cfgClientSecret: "client-secret",
|
||||
},
|
||||
wantInitErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing client secret",
|
||||
cfg: map[string]string{
|
||||
cfgIssuerUrl: srv.URL,
|
||||
cfgCertificateAuthority: cert,
|
||||
cfgClientID: "client-id",
|
||||
},
|
||||
wantInitErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing issuer URL",
|
||||
cfg: map[string]string{
|
||||
cfgCertificateAuthority: cert,
|
||||
cfgClientID: "client-id",
|
||||
cfgClientSecret: "secret",
|
||||
},
|
||||
wantInitErr: true,
|
||||
},
|
||||
{
|
||||
name: "missing TLS config",
|
||||
cfg: map[string]string{
|
||||
cfgIssuerUrl: srv.URL,
|
||||
cfgClientID: "client-id",
|
||||
cfgClientSecret: "secret",
|
||||
},
|
||||
wantInitErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
clearCache()
|
||||
|
||||
p, err := newOIDCAuthProvider("cluster.example.com", tt.cfg, new(persister))
|
||||
if tt.wantInitErr {
|
||||
if err == nil {
|
||||
t.Errorf("%s: want non-nil err", tt.name)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("%s: unexpected error on newOIDCAuthProvider: %v", tt.name, err)
|
||||
continue
|
||||
}
|
||||
|
||||
provider := p.(*oidcAuthProvider)
|
||||
provider.client = tt.client
|
||||
provider.now = func() time.Time { return t0 }
|
||||
|
||||
if _, err := provider.idToken(); err != nil {
|
||||
if !tt.wantTokenErr {
|
||||
t.Errorf("%s: failed to get id token: %v", tt.name, err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if tt.wantTokenErr {
|
||||
t.Errorf("%s: expected to not get id token: %v", tt.name, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(tt.wantCfg, provider.cfg) {
|
||||
t.Errorf("%s: expected config %#v got %#v", tt.name, tt.wantCfg, provider.cfg)
|
||||
}
|
||||
if !wantT2.Equal(gotT2) {
|
||||
t.Errorf("t2 value: wanted %s got %s", wantT2, gotT2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyJWTExpiry(t *testing.T) {
|
||||
privKey, err := key.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Fatalf("can't generate private key: %v", err)
|
||||
}
|
||||
makeToken := func(s string, exp time.Time, count int) *jose.JWT {
|
||||
jwt, err := jose.NewSignedJWT(jose.Claims(map[string]interface{}{
|
||||
"test": s,
|
||||
"exp": exp.UTC().Unix(),
|
||||
"count": count,
|
||||
}), privKey.Signer())
|
||||
if err != nil {
|
||||
t.Fatalf("Could not create signed JWT %v", err)
|
||||
}
|
||||
return jwt
|
||||
func encodeJWT(header, payload, sig string) string {
|
||||
e := func(s string) string {
|
||||
return base64.RawURLEncoding.EncodeToString([]byte(s))
|
||||
}
|
||||
return e(header) + "." + e(payload) + "." + e(sig)
|
||||
}
|
||||
|
||||
t0 := time.Now()
|
||||
func TestExpired(t *testing.T) {
|
||||
now := time.Now()
|
||||
|
||||
nowFunc := func() time.Time { return now }
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
jwt *jose.JWT
|
||||
now time.Time
|
||||
idToken string
|
||||
wantErr bool
|
||||
wantExpired bool
|
||||
}{
|
||||
{
|
||||
name: "valid jwt",
|
||||
jwt: makeToken("foo", t0.Add(time.Hour), 1),
|
||||
now: t0,
|
||||
name: "valid",
|
||||
idToken: encodeJWT(
|
||||
"{}",
|
||||
fmt.Sprintf(`{"exp":%d}`, now.Add(time.Hour).Unix()),
|
||||
"blah", // signature isn't veified.
|
||||
),
|
||||
},
|
||||
{
|
||||
name: "invalid jwt",
|
||||
jwt: &jose.JWT{},
|
||||
now: t0,
|
||||
name: "expired",
|
||||
idToken: encodeJWT(
|
||||
"{}",
|
||||
fmt.Sprintf(`{"exp":%d}`, now.Add(-time.Hour).Unix()),
|
||||
"blah", // signature isn't veified.
|
||||
),
|
||||
wantExpired: true,
|
||||
},
|
||||
{
|
||||
name: "bad exp claim",
|
||||
idToken: encodeJWT(
|
||||
"{}",
|
||||
`{"exp":"foobar"}`,
|
||||
"blah", // signature isn't veified.
|
||||
),
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "expired jwt",
|
||||
jwt: makeToken("foo", t0.Add(-time.Hour), 1),
|
||||
now: t0,
|
||||
wantExpired: true,
|
||||
},
|
||||
{
|
||||
name: "jwt expires soon enough to be marked expired",
|
||||
jwt: makeToken("foo", t0, 1),
|
||||
now: t0,
|
||||
wantExpired: true,
|
||||
name: "not an id token",
|
||||
idToken: "notanidtoken",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
func() {
|
||||
valid, err := verifyJWTExpiry(tc.now, tc.jwt.Encode())
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
valid, err := idTokenExpired(nowFunc, test.idToken)
|
||||
if err != nil {
|
||||
if !tc.wantErr {
|
||||
t.Errorf("%s: %v", tc.name, err)
|
||||
if !test.wantErr {
|
||||
t.Errorf("parse error: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
if tc.wantErr {
|
||||
t.Errorf("%s: expected error", tc.name)
|
||||
return
|
||||
if test.wantExpired == valid {
|
||||
t.Errorf("wanted expired %t, got %", test.wantExpired, !valid)
|
||||
}
|
||||
|
||||
if valid && tc.wantExpired {
|
||||
t.Errorf("%s: expected token to be expired", tc.name)
|
||||
}
|
||||
if !valid && !tc.wantExpired {
|
||||
t.Errorf("%s: expected token to be valid", tc.name)
|
||||
}
|
||||
}()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientCache(t *testing.T) {
|
||||
cache := newClientCache()
|
||||
|
||||
if _, ok := cache.getClient("issuer1", "id1", "secret1"); ok {
|
||||
if _, ok := cache.getClient("issuer1", "id1"); ok {
|
||||
t.Fatalf("got client before putting one in the cache")
|
||||
}
|
||||
|
||||
cli1 := new(oidcAuthProvider)
|
||||
cli2 := new(oidcAuthProvider)
|
||||
|
||||
gotcli := cache.setClient("issuer1", "id1", "secret1", cli1)
|
||||
gotcli := cache.setClient("issuer1", "id1", cli1)
|
||||
if cli1 != gotcli {
|
||||
t.Fatalf("set first client and got a different one")
|
||||
}
|
||||
|
||||
gotcli = cache.setClient("issuer1", "id1", "secret1", cli2)
|
||||
gotcli = cache.setClient("issuer1", "id1", cli2)
|
||||
if cli1 != gotcli {
|
||||
t.Fatalf("set a second client and didn't get the first")
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user