diff --git a/plugin/pkg/client/auth/oidc/oidc.go b/plugin/pkg/client/auth/oidc/oidc.go index 1383a97c..94021201 100644 --- a/plugin/pkg/client/auth/oidc/oidc.go +++ b/plugin/pkg/client/auth/oidc/oidc.go @@ -76,24 +76,25 @@ func newClientCache() *clientCache { } type cacheKey struct { + clusterAddress string // Canonical issuer URL string of the provider. issuerURL string clientID string } -func (c *clientCache) getClient(issuer, clientID string) (*oidcAuthProvider, bool) { +func (c *clientCache) getClient(clusterAddress, issuer, clientID string) (*oidcAuthProvider, bool) { c.mu.RLock() defer c.mu.RUnlock() - client, ok := c.cache[cacheKey{issuer, clientID}] + client, ok := c.cache[cacheKey{clusterAddress: clusterAddress, issuerURL: issuer, clientID: clientID}] return client, ok } // setClient attempts to put the client in the cache but may return any clients // with the same keys set before. This is so there's only ever one client for a provider. -func (c *clientCache) setClient(issuer, clientID string, client *oidcAuthProvider) *oidcAuthProvider { +func (c *clientCache) setClient(clusterAddress, issuer, clientID string, client *oidcAuthProvider) *oidcAuthProvider { c.mu.Lock() defer c.mu.Unlock() - key := cacheKey{issuer, clientID} + key := cacheKey{clusterAddress: clusterAddress, issuerURL: issuer, clientID: clientID} // If another client has already initialized a client for the given provider we want // to use that client instead of the one we're trying to set. This is so all transports @@ -107,7 +108,7 @@ func (c *clientCache) setClient(issuer, clientID string, client *oidcAuthProvide return client } -func newOIDCAuthProvider(_ string, cfg map[string]string, persister restclient.AuthProviderConfigPersister) (restclient.AuthProvider, error) { +func newOIDCAuthProvider(clusterAddress string, cfg map[string]string, persister restclient.AuthProviderConfigPersister) (restclient.AuthProvider, error) { issuer := cfg[cfgIssuerUrl] if issuer == "" { return nil, fmt.Errorf("Must provide %s", cfgIssuerUrl) @@ -119,7 +120,7 @@ func newOIDCAuthProvider(_ string, cfg map[string]string, persister restclient.A } // Check cache for existing provider. - if provider, ok := cache.getClient(issuer, clientID); ok { + if provider, ok := cache.getClient(clusterAddress, issuer, clientID); ok { return provider, nil } @@ -157,7 +158,7 @@ func newOIDCAuthProvider(_ string, cfg map[string]string, persister restclient.A persister: persister, } - return cache.setClient(issuer, clientID, provider), nil + return cache.setClient(clusterAddress, issuer, clientID, provider), nil } type oidcAuthProvider struct { diff --git a/plugin/pkg/client/auth/oidc/oidc_test.go b/plugin/pkg/client/auth/oidc/oidc_test.go index c14a3a84..a6407d06 100644 --- a/plugin/pkg/client/auth/oidc/oidc_test.go +++ b/plugin/pkg/client/auth/oidc/oidc_test.go @@ -119,20 +119,40 @@ func TestExpired(t *testing.T) { func TestClientCache(t *testing.T) { cache := newClientCache() - if _, ok := cache.getClient("issuer1", "id1"); ok { + if _, ok := cache.getClient("cluster1", "issuer1", "id1"); ok { t.Fatalf("got client before putting one in the cache") } + assertCacheLen(t, cache, 0) cli1 := new(oidcAuthProvider) cli2 := new(oidcAuthProvider) + cli3 := new(oidcAuthProvider) - gotcli := cache.setClient("issuer1", "id1", cli1) + gotcli := cache.setClient("cluster1", "issuer1", "id1", cli1) if cli1 != gotcli { t.Fatalf("set first client and got a different one") } + assertCacheLen(t, cache, 1) - gotcli = cache.setClient("issuer1", "id1", cli2) + gotcli = cache.setClient("cluster1", "issuer1", "id1", cli2) if cli1 != gotcli { t.Fatalf("set a second client and didn't get the first") } + assertCacheLen(t, cache, 1) + + gotcli = cache.setClient("cluster2", "issuer1", "id1", cli3) + if cli1 == gotcli { + t.Fatalf("set a third client and got the first") + } + if cli3 != gotcli { + t.Fatalf("set third client and got a different one") + } + assertCacheLen(t, cache, 2) +} + +func assertCacheLen(t *testing.T, cache *clientCache, length int) { + t.Helper() + if len(cache.cache) != length { + t.Errorf("expected cache length %d got %d", length, len(cache.cache)) + } }