diff --git a/plugin/pkg/client/auth/oidc/BUILD b/plugin/pkg/client/auth/oidc/BUILD index d53698f4ac9..fdac035a6d0 100644 --- a/plugin/pkg/client/auth/oidc/BUILD +++ b/plugin/pkg/client/auth/oidc/BUILD @@ -14,7 +14,6 @@ go_library( tags = ["automanaged"], deps = [ "//pkg/client/restclient:go_default_library", - "//pkg/util/wait:go_default_library", "//vendor:github.com/coreos/go-oidc/jose", "//vendor:github.com/coreos/go-oidc/oauth2", "//vendor:github.com/coreos/go-oidc/oidc", @@ -28,8 +27,6 @@ go_test( library = ":go_default_library", tags = ["automanaged"], deps = [ - "//pkg/util/diff:go_default_library", - "//pkg/util/wait:go_default_library", "//plugin/pkg/auth/authenticator/token/oidc/testing:go_default_library", "//vendor:github.com/coreos/go-oidc/jose", "//vendor:github.com/coreos/go-oidc/key", diff --git a/plugin/pkg/client/auth/oidc/oidc.go b/plugin/pkg/client/auth/oidc/oidc.go index 155bb486733..68168bf9d73 100644 --- a/plugin/pkg/client/auth/oidc/oidc.go +++ b/plugin/pkg/client/auth/oidc/oidc.go @@ -22,6 +22,7 @@ import ( "fmt" "net/http" "strings" + "sync" "time" "github.com/coreos/go-oidc/jose" @@ -30,7 +31,6 @@ import ( "github.com/golang/glog" "k8s.io/kubernetes/pkg/client/restclient" - "k8s.io/kubernetes/pkg/util/wait" ) const ( @@ -44,21 +44,68 @@ const ( cfgRefreshToken = "refresh-token" ) -var ( - backoff = wait.Backoff{ - Duration: 1 * time.Second, - Factor: 2, - Jitter: .1, - Steps: 5, - } -) - func init() { if err := restclient.RegisterAuthProviderPlugin("oidc", newOIDCAuthProvider); err != nil { glog.Fatalf("Failed to register oidc auth plugin: %v", err) } } +// expiryDelta determines how earlier a token should be considered +// expired than its actual expiration time. It is used to avoid late +// expirations due to client-server time mismatches. +// +// NOTE(ericchiang): this is take from golang.org/x/oauth2 +const expiryDelta = 10 * time.Second + +var cache = newClientCache() + +// Like TLS transports, keep a cache of OIDC clients indexed by issuer URL. +type clientCache struct { + mu sync.RWMutex + cache map[cacheKey]*oidcAuthProvider +} + +func newClientCache() *clientCache { + return &clientCache{cache: make(map[cacheKey]*oidcAuthProvider)} +} + +type cacheKey struct { + // Canonical issuer URL string of the provider. + issuerURL string + + clientID string + clientSecret string + + // Don't use CA as cache key because we only add a cache entry if we can connect + // to the issuer in the first place. A valid CA is a prerequisite. +} + +func (c *clientCache) getClient(issuer, clientID, clientSecret string) (*oidcAuthProvider, bool) { + c.mu.RLock() + defer c.mu.RUnlock() + client, ok := c.cache[cacheKey{issuer, clientID, clientSecret}] + return client, ok +} + +// setClient attempts to put the client in the cache but may return any clients +// with the same keys set before. This is so there's only ever one client for a provider. +func (c *clientCache) setClient(issuer, clientID, clientSecret string, client *oidcAuthProvider) *oidcAuthProvider { + c.mu.Lock() + defer c.mu.Unlock() + key := cacheKey{issuer, clientID, clientSecret} + + // If another client has already initialized a client for the given provider we want + // to use that client instead of the one we're trying to set. This is so all transports + // share a client and can coordinate around the same mutex when refreshing and writing + // to the kubeconfig. + if oldClient, ok := c.cache[key]; ok { + return oldClient + } + + c.cache[key] = client + return client +} + func newOIDCAuthProvider(_ string, cfg map[string]string, persister restclient.AuthProviderConfigPersister) (restclient.AuthProvider, error) { issuer := cfg[cfgIssuerUrl] if issuer == "" { @@ -75,6 +122,11 @@ func newOIDCAuthProvider(_ string, cfg map[string]string, persister restclient.A return nil, fmt.Errorf("Must provide %s", cfgClientSecret) } + // Check cache for existing provider. + if provider, ok := cache.getClient(issuer, clientID, clientSecret); ok { + return provider, nil + } + var certAuthData []byte var err error if cfg[cfgCertificateAuthorityData] != "" { @@ -112,146 +164,134 @@ func newOIDCAuthProvider(_ string, cfg map[string]string, persister restclient.A ProviderConfig: providerCfg, Scope: append(scopes, oidc.DefaultScope...), } - client, err := oidc.NewClient(oidcCfg) if err != nil { return nil, fmt.Errorf("error creating OIDC Client: %v", err) } - oClient := &oidcClient{client} - - var initialIDToken jose.JWT - if cfg[cfgIDToken] != "" { - initialIDToken, err = jose.ParseJWT(cfg[cfgIDToken]) - if err != nil { - return nil, err - } + provider := &oidcAuthProvider{ + client: &oidcClient{client}, + cfg: cfg, + persister: persister, + now: time.Now, } - return &oidcAuthProvider{ - initialIDToken: initialIDToken, - refresher: &idTokenRefresher{ - client: oClient, - cfg: cfg, - persister: persister, - }, - }, nil + return cache.setClient(issuer, clientID, clientSecret, provider), nil } type oidcAuthProvider struct { - refresher *idTokenRefresher - initialIDToken jose.JWT + // Interface rather than a raw *oidc.Client for testing. + client OIDCClient + + // Stubbed out for testing. + now func() time.Time + + // Mutex guards persisting to the kubeconfig file and allows synchronized + // updates to the in-memory config. It also ensures concurrent calls to + // the RoundTripper only trigger a single refresh request. + mu sync.Mutex + cfg map[string]string + persister restclient.AuthProviderConfigPersister } -func (g *oidcAuthProvider) WrapTransport(rt http.RoundTripper) http.RoundTripper { - at := &oidc.AuthenticatedTransport{ - TokenRefresher: g.refresher, - RoundTripper: rt, - } - at.SetJWT(g.initialIDToken) +func (p *oidcAuthProvider) WrapTransport(rt http.RoundTripper) http.RoundTripper { return &roundTripper{ - wrapped: at, - refresher: g.refresher, + wrapped: rt, + provider: p, } } -func (g *oidcAuthProvider) Login() error { +func (p *oidcAuthProvider) Login() error { return errors.New("not yet implemented") } type OIDCClient interface { refreshToken(rt string) (oauth2.TokenResponse, error) - verifyJWT(jwt jose.JWT) error + verifyJWT(jwt *jose.JWT) error } type roundTripper struct { - refresher *idTokenRefresher - wrapped *oidc.AuthenticatedTransport + provider *oidcAuthProvider + wrapped http.RoundTripper } func (r *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - var res *http.Response - var err error - firstTime := true - wait.ExponentialBackoff(backoff, func() (bool, error) { - if !firstTime { - var jwt jose.JWT - jwt, err = r.refresher.Refresh() - if err != nil { - return true, nil - } - r.wrapped.SetJWT(jwt) - } else { - firstTime = false - } + token, err := r.provider.idToken() + if err != nil { + return nil, err + } - res, err = r.wrapped.RoundTrip(req) + // shallow copy of the struct + r2 := new(http.Request) + *r2 = *req + // deep copy of the Header so we don't modify the original + // request's Header (as per RoundTripper contract). + r2.Header = make(http.Header) + for k, s := range req.Header { + r2.Header[k] = s + } + r2.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + + return r.wrapped.RoundTrip(r2) +} + +func (p *oidcAuthProvider) idToken() (string, error) { + p.mu.Lock() + defer p.mu.Unlock() + + if idToken, ok := p.cfg[cfgIDToken]; ok && len(idToken) > 0 { + valid, err := verifyJWTExpiry(p.now(), idToken) if err != nil { - return true, nil + return "", err } - if res.StatusCode == http.StatusUnauthorized { - return false, nil + if valid { + // If the cached id token is still valid use it. + return idToken, nil } - return true, nil - }) - return res, err -} + } -type idTokenRefresher struct { - cfg map[string]string - client OIDCClient - persister restclient.AuthProviderConfigPersister - intialIDToken jose.JWT -} + // Try to request a new token using the refresh token. + rt, ok := p.cfg[cfgRefreshToken] + if !ok || len(rt) == 0 { + return "", errors.New("No valid id-token, and cannot refresh without refresh-token") + } -func (r *idTokenRefresher) Verify(jwt jose.JWT) error { - claims, err := jwt.Claims() + tokens, err := p.client.refreshToken(rt) if err != nil { - return err - } - - now := time.Now() - exp, ok, err := claims.TimeClaim("exp") - switch { - case err != nil: - return fmt.Errorf("failed to parse 'exp' claim: %v", err) - case !ok: - return errors.New("missing required 'exp' claim") - case exp.Before(now): - return fmt.Errorf("token already expired at: %v", exp) - } - - return nil -} - -func (r *idTokenRefresher) Refresh() (jose.JWT, error) { - rt, ok := r.cfg[cfgRefreshToken] - if !ok { - return jose.JWT{}, errors.New("No valid id-token, and cannot refresh without refresh-token") - } - - tokens, err := r.client.refreshToken(rt) - if err != nil { - return jose.JWT{}, fmt.Errorf("could not refresh token: %v", err) + return "", fmt.Errorf("could not refresh token: %v", err) } jwt, err := jose.ParseJWT(tokens.IDToken) if err != nil { - return jose.JWT{}, err + return "", err + } + + if err := p.client.verifyJWT(&jwt); err != nil { + return "", err + } + + // Create a new config to persist. + newCfg := make(map[string]string) + for key, val := range p.cfg { + newCfg[key] = val } if tokens.RefreshToken != "" && tokens.RefreshToken != rt { - r.cfg[cfgRefreshToken] = tokens.RefreshToken - } - r.cfg[cfgIDToken] = jwt.Encode() - - err = r.persister.Persist(r.cfg) - if err != nil { - return jose.JWT{}, fmt.Errorf("could not perist new tokens: %v", err) + newCfg[cfgRefreshToken] = tokens.RefreshToken } - return jwt, r.client.verifyJWT(jwt) + newCfg[cfgIDToken] = tokens.IDToken + if err = p.persister.Persist(newCfg); err != nil { + return "", fmt.Errorf("could not perist new tokens: %v", err) + } + + // Update the in memory config to reflect the on disk one. + p.cfg = newCfg + + return tokens.IDToken, nil } +// oidcClient is the real implementation of the OIDCClient interface, which is +// used for testing. type oidcClient struct { client *oidc.Client } @@ -265,6 +305,29 @@ func (o *oidcClient) refreshToken(rt string) (oauth2.TokenResponse, error) { return oac.RequestToken(oauth2.GrantTypeRefreshToken, rt) } -func (o *oidcClient) verifyJWT(jwt jose.JWT) error { - return o.client.VerifyJWT(jwt) +func (o *oidcClient) verifyJWT(jwt *jose.JWT) error { + return o.client.VerifyJWT(*jwt) +} + +func verifyJWTExpiry(now time.Time, s string) (valid bool, err error) { + jwt, err := jose.ParseJWT(s) + if err != nil { + return false, fmt.Errorf("invalid %q", cfgIDToken) + } + claims, err := jwt.Claims() + if err != nil { + return false, err + } + + exp, ok, err := claims.TimeClaim("exp") + switch { + case err != nil: + return false, fmt.Errorf("failed to parse 'exp' claim: %v", err) + case !ok: + return false, errors.New("missing required 'exp' claim") + case exp.After(now.Add(expiryDelta)): + return true, nil + } + + return false, nil } diff --git a/plugin/pkg/client/auth/oidc/oidc_test.go b/plugin/pkg/client/auth/oidc/oidc_test.go index 48bbc5a72b6..16e11492818 100644 --- a/plugin/pkg/client/auth/oidc/oidc_test.go +++ b/plugin/pkg/client/auth/oidc/oidc_test.go @@ -19,13 +19,10 @@ package oidc import ( "encoding/base64" "errors" - "fmt" "io/ioutil" - "net/http" "os" "path" "reflect" - "strings" "testing" "time" @@ -33,11 +30,41 @@ import ( "github.com/coreos/go-oidc/key" "github.com/coreos/go-oidc/oauth2" - "k8s.io/kubernetes/pkg/util/diff" - "k8s.io/kubernetes/pkg/util/wait" oidctesting "k8s.io/kubernetes/plugin/pkg/auth/authenticator/token/oidc/testing" ) +func clearCache() { + cache = newClientCache() +} + +type persister struct{} + +// we don't need to actually persist anything because there's no way for us to +// read from a persister. +func (p *persister) Persist(map[string]string) error { return nil } + +type noRefreshOIDCClient struct{} + +func (c *noRefreshOIDCClient) refreshToken(rt string) (oauth2.TokenResponse, error) { + return oauth2.TokenResponse{}, errors.New("alwaysErrOIDCClient: cannot refresh token") +} + +func (c *noRefreshOIDCClient) verifyJWT(jwt *jose.JWT) error { + return nil +} + +type mockOIDCClient struct { + tokenResponse oauth2.TokenResponse +} + +func (c *mockOIDCClient) refreshToken(rt string) (oauth2.TokenResponse, error) { + return c.tokenResponse, nil +} + +func (c *mockOIDCClient) verifyJWT(jwt *jose.JWT) error { + return nil +} + func TestNewOIDCAuthProvider(t *testing.T) { tempDir, err := ioutil.TempDir(os.TempDir(), "oidc_test") if err != nil { @@ -60,127 +87,211 @@ func TestNewOIDCAuthProvider(t *testing.T) { t.Fatalf("Could not read cert bytes %v", err) } - jwt, err := jose.NewSignedJWT(jose.Claims(map[string]interface{}{ - "test": "jwt", - }), op.PrivKey.Signer()) - if err != nil { - t.Fatalf("Could not create signed JWT %v", err) + makeToken := func(exp time.Time) *jose.JWT { + jwt, err := jose.NewSignedJWT(jose.Claims(map[string]interface{}{ + "exp": exp.UTC().Unix(), + }), op.PrivKey.Signer()) + if err != nil { + t.Fatalf("Could not create signed JWT %v", err) + } + return jwt } - tests := []struct { - cfg map[string]string + t0 := time.Now() - wantErr bool - wantInitialIDToken jose.JWT + goodToken := makeToken(t0.Add(time.Hour)).Encode() + expiredToken := makeToken(t0.Add(-time.Hour)).Encode() + + tests := []struct { + name string + + cfg map[string]string + wantInitErr bool + + client OIDCClient + wantCfg map[string]string + wantTokenErr bool }{ { // A Valid configuration + name: "no id token and no refresh token", cfg: map[string]string{ cfgIssuerUrl: srv.URL, cfgCertificateAuthority: cert, cfgClientID: "client-id", cfgClientSecret: "client-secret", }, + wantTokenErr: true, }, { - // A Valid configuration with an Initial JWT + name: "valid config with an initial token", cfg: map[string]string{ cfgIssuerUrl: srv.URL, cfgCertificateAuthority: cert, cfgClientID: "client-id", cfgClientSecret: "client-secret", - cfgIDToken: jwt.Encode(), + cfgIDToken: goodToken, + }, + client: new(noRefreshOIDCClient), + wantCfg: map[string]string{ + cfgIssuerUrl: srv.URL, + cfgCertificateAuthority: cert, + cfgClientID: "client-id", + cfgClientSecret: "client-secret", + cfgIDToken: goodToken, }, - wantInitialIDToken: *jwt, }, { - // Valid config, but using cfgCertificateAuthorityData + name: "invalid ID token with a refresh token", + cfg: map[string]string{ + cfgIssuerUrl: srv.URL, + cfgCertificateAuthority: cert, + cfgClientID: "client-id", + cfgClientSecret: "client-secret", + cfgRefreshToken: "foo", + cfgIDToken: expiredToken, + }, + client: &mockOIDCClient{ + tokenResponse: oauth2.TokenResponse{ + IDToken: goodToken, + }, + }, + wantCfg: map[string]string{ + cfgIssuerUrl: srv.URL, + cfgCertificateAuthority: cert, + cfgClientID: "client-id", + cfgClientSecret: "client-secret", + cfgRefreshToken: "foo", + cfgIDToken: goodToken, + }, + }, + { + name: "invalid ID token with a refresh token, server returns new refresh token", + cfg: map[string]string{ + cfgIssuerUrl: srv.URL, + cfgCertificateAuthority: cert, + cfgClientID: "client-id", + cfgClientSecret: "client-secret", + cfgRefreshToken: "foo", + cfgIDToken: expiredToken, + }, + client: &mockOIDCClient{ + tokenResponse: oauth2.TokenResponse{ + IDToken: goodToken, + RefreshToken: "bar", + }, + }, + wantCfg: map[string]string{ + cfgIssuerUrl: srv.URL, + cfgCertificateAuthority: cert, + cfgClientID: "client-id", + cfgClientSecret: "client-secret", + cfgRefreshToken: "bar", + cfgIDToken: goodToken, + }, + }, + { + name: "expired token and no refresh otken", + cfg: map[string]string{ + cfgIssuerUrl: srv.URL, + cfgCertificateAuthority: cert, + cfgClientID: "client-id", + cfgClientSecret: "client-secret", + cfgIDToken: expiredToken, + }, + wantTokenErr: true, + }, + { + name: "valid base64d ca", cfg: map[string]string{ cfgIssuerUrl: srv.URL, cfgCertificateAuthorityData: base64.StdEncoding.EncodeToString(certData), cfgClientID: "client-id", cfgClientSecret: "client-secret", }, + client: new(noRefreshOIDCClient), + wantTokenErr: true, }, { - // Missing client id + name: "missing client ID", cfg: map[string]string{ cfgIssuerUrl: srv.URL, cfgCertificateAuthority: cert, cfgClientSecret: "client-secret", }, - wantErr: true, + wantInitErr: true, }, { - // Missing client secret + name: "missing client secret", cfg: map[string]string{ cfgIssuerUrl: srv.URL, cfgCertificateAuthority: cert, cfgClientID: "client-id", }, - wantErr: true, + wantInitErr: true, }, { - // Missing issuer url. + name: "missing issuer URL", cfg: map[string]string{ cfgCertificateAuthority: cert, cfgClientID: "client-id", cfgClientSecret: "secret", }, - wantErr: true, + wantInitErr: true, }, { - // No TLS config + name: "missing TLS config", cfg: map[string]string{ cfgIssuerUrl: srv.URL, cfgClientID: "client-id", cfgClientSecret: "secret", }, - wantErr: true, + wantInitErr: true, }, } - for i, tt := range tests { - ap, err := newOIDCAuthProvider("cluster.example.com", tt.cfg, nil) - if tt.wantErr { + for _, tt := range tests { + clearCache() + + p, err := newOIDCAuthProvider("cluster.example.com", tt.cfg, new(persister)) + if tt.wantInitErr { if err == nil { - t.Errorf("case %d: want non-nil err", i) + t.Errorf("%s: want non-nil err", tt.name) } continue } if err != nil { - t.Errorf("case %d: unexpected error on newOIDCAuthProvider: %v", i, err) + t.Errorf("%s: unexpected error on newOIDCAuthProvider: %v", tt.name, err) continue } - oidcAP, ok := ap.(*oidcAuthProvider) - if !ok { - t.Errorf("case %d: expected ap to be an oidcAuthProvider", i) + provider := p.(*oidcAuthProvider) + provider.client = tt.client + provider.now = func() time.Time { return t0 } + + if _, err := provider.idToken(); err != nil { + if !tt.wantTokenErr { + t.Errorf("%s: failed to get id token: %v", tt.name, err) + } + continue + } + if tt.wantTokenErr { + t.Errorf("%s: expected to not get id token: %v", tt.name, err) continue } - if diff := compareJWTs(tt.wantInitialIDToken, oidcAP.initialIDToken); diff != "" { - t.Errorf("case %d: compareJWTs(tt.wantInitialIDToken, oidcAP.initialIDToken)=%v", i, diff) + if !reflect.DeepEqual(tt.wantCfg, provider.cfg) { + t.Errorf("%s: expected config %#v got %#v", tt.name, tt.wantCfg, provider.cfg) } } } -func TestWrapTranport(t *testing.T) { - oldBackoff := backoff - defer func() { - backoff = oldBackoff - }() - backoff = wait.Backoff{ - Duration: 1 * time.Nanosecond, - Steps: 3, - } - +func TestVerifyJWTExpiry(t *testing.T) { privKey, err := key.GeneratePrivateKey() if err != nil { t.Fatalf("can't generate private key: %v", err) } - makeToken := func(s string, exp time.Time, count int) *jose.JWT { jwt, err := jose.NewSignedJWT(jose.Claims(map[string]interface{}{ "test": s, @@ -193,451 +304,81 @@ func TestWrapTranport(t *testing.T) { return jwt } - goodToken := makeToken("good", time.Now().Add(time.Hour), 0) - goodToken2 := makeToken("good", time.Now().Add(time.Hour), 1) - expiredToken := makeToken("good", time.Now().Add(-time.Hour), 0) + t0 := time.Now() - str := func(s string) *string { - return &s - } tests := []struct { - cfgIDToken *jose.JWT - cfgRefreshToken *string - - expectRequests []testRoundTrip - - expectRefreshes []testRefresh - - expectPersists []testPersist - - wantStatus int - wantErr bool + name string + jwt *jose.JWT + now time.Time + wantErr bool + wantExpired bool }{ { - // Initial JWT is set, it is good, it is set as bearer. - cfgIDToken: goodToken, - - expectRequests: []testRoundTrip{ - { - expectBearerToken: goodToken.Encode(), - returnHTTPStatus: 200, - }, - }, - - wantStatus: 200, + name: "valid jwt", + jwt: makeToken("foo", t0.Add(time.Hour), 1), + now: t0, }, { - // Initial JWT is set, but it's expired, so it gets refreshed. - cfgIDToken: expiredToken, - cfgRefreshToken: str("rt1"), - - expectRefreshes: []testRefresh{ - { - expectRefreshToken: "rt1", - returnTokens: oauth2.TokenResponse{ - IDToken: goodToken.Encode(), - }, - }, - }, - - expectRequests: []testRoundTrip{ - { - expectBearerToken: goodToken.Encode(), - returnHTTPStatus: 200, - }, - }, - - expectPersists: []testPersist{ - { - cfg: map[string]string{ - cfgIDToken: goodToken.Encode(), - cfgRefreshToken: "rt1", - }, - }, - }, - - wantStatus: 200, - }, - { - // Initial JWT is set, but it's expired, so it gets refreshed - this - // time the refresh token itself is also refreshed - cfgIDToken: expiredToken, - cfgRefreshToken: str("rt1"), - - expectRefreshes: []testRefresh{ - { - expectRefreshToken: "rt1", - returnTokens: oauth2.TokenResponse{ - IDToken: goodToken.Encode(), - RefreshToken: "rt2", - }, - }, - }, - - expectRequests: []testRoundTrip{ - { - expectBearerToken: goodToken.Encode(), - returnHTTPStatus: 200, - }, - }, - - expectPersists: []testPersist{ - { - cfg: map[string]string{ - cfgIDToken: goodToken.Encode(), - cfgRefreshToken: "rt2", - }, - }, - }, - - wantStatus: 200, - }, - { - // Initial JWT is not set, so it gets refreshed. - cfgRefreshToken: str("rt1"), - - expectRefreshes: []testRefresh{ - { - expectRefreshToken: "rt1", - returnTokens: oauth2.TokenResponse{ - IDToken: goodToken.Encode(), - }, - }, - }, - - expectRequests: []testRoundTrip{ - { - expectBearerToken: goodToken.Encode(), - returnHTTPStatus: 200, - }, - }, - - expectPersists: []testPersist{ - { - cfg: map[string]string{ - cfgIDToken: goodToken.Encode(), - cfgRefreshToken: "rt1", - }, - }, - }, - - wantStatus: 200, - }, - { - // Expired token, but no refresh token. - cfgIDToken: expiredToken, - + name: "invalid jwt", + jwt: &jose.JWT{}, + now: t0, wantErr: true, }, { - // Initial JWT is not set, so it gets refreshed, but the server - // rejects it when it is used, so it refreshes again, which - // succeeds. - cfgRefreshToken: str("rt1"), - - expectRefreshes: []testRefresh{ - { - expectRefreshToken: "rt1", - returnTokens: oauth2.TokenResponse{ - IDToken: goodToken.Encode(), - }, - }, - { - expectRefreshToken: "rt1", - returnTokens: oauth2.TokenResponse{ - IDToken: goodToken2.Encode(), - }, - }, - }, - - expectRequests: []testRoundTrip{ - { - expectBearerToken: goodToken.Encode(), - returnHTTPStatus: http.StatusUnauthorized, - }, - { - expectBearerToken: goodToken2.Encode(), - returnHTTPStatus: http.StatusOK, - }, - }, - - expectPersists: []testPersist{ - { - cfg: map[string]string{ - cfgIDToken: goodToken.Encode(), - cfgRefreshToken: "rt1", - }, - }, - { - cfg: map[string]string{ - cfgIDToken: goodToken2.Encode(), - cfgRefreshToken: "rt1", - }, - }, - }, - - wantStatus: 200, + name: "expired jwt", + jwt: makeToken("foo", t0.Add(-time.Hour), 1), + now: t0, + wantExpired: true, }, { - // Initial JWT is but the server rejects it when it is used, so it - // refreshes again, which succeeds. - cfgRefreshToken: str("rt1"), - cfgIDToken: goodToken, - - expectRefreshes: []testRefresh{ - { - expectRefreshToken: "rt1", - returnTokens: oauth2.TokenResponse{ - IDToken: goodToken2.Encode(), - }, - }, - }, - - expectRequests: []testRoundTrip{ - { - expectBearerToken: goodToken.Encode(), - returnHTTPStatus: http.StatusUnauthorized, - }, - { - expectBearerToken: goodToken2.Encode(), - returnHTTPStatus: http.StatusOK, - }, - }, - - expectPersists: []testPersist{ - { - cfg: map[string]string{ - cfgIDToken: goodToken2.Encode(), - cfgRefreshToken: "rt1", - }, - }, - }, - wantStatus: 200, + name: "jwt expires soon enough to be marked expired", + jwt: makeToken("foo", t0, 1), + now: t0, + wantExpired: true, }, } - for i, tt := range tests { - client := &testOIDCClient{ - refreshes: tt.expectRefreshes, - } - - persister := &testPersister{ - tt.expectPersists, - } - - cfg := map[string]string{} - if tt.cfgIDToken != nil { - cfg[cfgIDToken] = tt.cfgIDToken.Encode() - } - - if tt.cfgRefreshToken != nil { - cfg[cfgRefreshToken] = *tt.cfgRefreshToken - } - - ap := &oidcAuthProvider{ - refresher: &idTokenRefresher{ - client: client, - cfg: cfg, - persister: persister, - }, - } - - if tt.cfgIDToken != nil { - ap.initialIDToken = *tt.cfgIDToken - } - - tstRT := &testRoundTripper{ - tt.expectRequests, - } - - rt := ap.WrapTransport(tstRT) - - req, err := http.NewRequest("GET", "http://cluster.example.com", nil) - if err != nil { - t.Errorf("case %d: unexpected error making request: %v", i, err) - } - - res, err := rt.RoundTrip(req) - if tt.wantErr { - if err == nil { - t.Errorf("case %d: Expected non-nil error", i) + for _, tc := range tests { + func() { + valid, err := verifyJWTExpiry(tc.now, tc.jwt.Encode()) + if err != nil { + if !tc.wantErr { + t.Errorf("%s: %v", tc.name, err) + } + return } - } else if err != nil { - t.Errorf("case %d: unexpected error making round trip: %v", i, err) - - } else { - if res.StatusCode != tt.wantStatus { - t.Errorf("case %d: want=%d, got=%d", i, tt.wantStatus, res.StatusCode) + if tc.wantErr { + t.Errorf("%s: expected error", tc.name) + return } - } - - if err = client.verify(); err != nil { - t.Errorf("case %d: %v", i, err) - } - - if err = persister.verify(); err != nil { - t.Errorf("case %d: %v", i, err) - } - - if err = tstRT.verify(); err != nil { - t.Errorf("case %d: %v", i, err) - continue - } + if valid && tc.wantExpired { + t.Errorf("%s: expected token to be expired", tc.name) + } + if !valid && !tc.wantExpired { + t.Errorf("%s: expected token to be valid", tc.name) + } + }() } } -type testRoundTrip struct { - expectBearerToken string - returnHTTPStatus int -} - -type testRoundTripper struct { - trips []testRoundTrip -} - -func (t *testRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - if len(t.trips) == 0 { - return nil, errors.New("unexpected RoundTrip call") - } - - var trip testRoundTrip - trip, t.trips = t.trips[0], t.trips[1:] - - var bt string - var parts []string - auth := strings.TrimSpace(req.Header.Get("Authorization")) - if auth == "" { - goto Compare - } - - parts = strings.Split(auth, " ") - if len(parts) < 2 || strings.ToLower(parts[0]) != "bearer" { - goto Compare - } - - bt = parts[1] - -Compare: - if trip.expectBearerToken != bt { - return nil, fmt.Errorf("want bearerToken=%v, got=%v", trip.expectBearerToken, bt) - } - return &http.Response{ - StatusCode: trip.returnHTTPStatus, - }, nil -} - -func (t *testRoundTripper) verify() error { - if l := len(t.trips); l > 0 { - return fmt.Errorf("%d uncalled round trips", l) - } - return nil -} - -type testPersist struct { - cfg map[string]string - returnErr error -} - -type testPersister struct { - persists []testPersist -} - -func (t *testPersister) Persist(cfg map[string]string) error { - if len(t.persists) == 0 { - return errors.New("unexpected persist call") - } - - var persist testPersist - persist, t.persists = t.persists[0], t.persists[1:] - - if !reflect.DeepEqual(persist.cfg, cfg) { - return fmt.Errorf("Unexpected cfg: %v", diff.ObjectDiff(persist.cfg, cfg)) - } - - return persist.returnErr -} - -func (t *testPersister) verify() error { - if l := len(t.persists); l > 0 { - return fmt.Errorf("%d uncalled persists", l) - } - return nil -} - -type testRefresh struct { - expectRefreshToken string - - returnErr error - returnTokens oauth2.TokenResponse -} - -type testOIDCClient struct { - refreshes []testRefresh -} - -func (o *testOIDCClient) refreshToken(rt string) (oauth2.TokenResponse, error) { - if len(o.refreshes) == 0 { - return oauth2.TokenResponse{}, errors.New("unexpected refresh request") - } - - var refresh testRefresh - refresh, o.refreshes = o.refreshes[0], o.refreshes[1:] - - if rt != refresh.expectRefreshToken { - return oauth2.TokenResponse{}, fmt.Errorf("want rt=%v, got=%v", - refresh.expectRefreshToken, - rt) - } - - if refresh.returnErr != nil { - return oauth2.TokenResponse{}, refresh.returnErr - } - - return refresh.returnTokens, nil -} - -func (o *testOIDCClient) verifyJWT(jwt jose.JWT) error { - claims, err := jwt.Claims() - if err != nil { - return err - } - claim, _, _ := claims.StringClaim("test") - if claim != "good" { - return errors.New("bad token") - } - return nil -} - -func (t *testOIDCClient) verify() error { - if l := len(t.refreshes); l > 0 { - return fmt.Errorf("%d uncalled refreshes", l) - } - return nil -} - -func compareJWTs(a, b jose.JWT) string { - if a.Encode() == b.Encode() { - return "" - } - - var aClaims, bClaims jose.Claims - for _, j := range []struct { - claims *jose.Claims - jwt jose.JWT - }{ - {&aClaims, a}, - {&bClaims, b}, - } { - var err error - *j.claims, err = j.jwt.Claims() - if err != nil { - *j.claims = jose.Claims(map[string]interface{}{ - "msg": "bad claims", - "err": err, - }) - } - } - - return diff.ObjectDiff(aClaims, bClaims) +func TestClientCache(t *testing.T) { + cache := newClientCache() + + if _, ok := cache.getClient("issuer1", "id1", "secret1"); ok { + t.Fatalf("got client before putting one in the cache") + } + + cli1 := new(oidcAuthProvider) + cli2 := new(oidcAuthProvider) + + gotcli := cache.setClient("issuer1", "id1", "secret1", cli1) + if cli1 != gotcli { + t.Fatalf("set first client and got a different one") + } + + gotcli = cache.setClient("issuer1", "id1", "secret1", cli2) + if cli1 != gotcli { + t.Fatalf("set a second client and didn't get the first") + } }