kubectl oidc auth-provider: include cluster address in cache key

This change includes the cluster address in the cache key so that
using the same issuer and client ID with different tokens across
multiple clusters does not result in the wrong token being used for
authentication.

Signed-off-by: Monis Khan <mok@vmware.com>

Kubernetes-commit: 96fe76a9ed4fde16f449995cc698dca3719ed546
This commit is contained in:
Monis Khan 2019-12-06 20:26:25 -05:00 committed by Kubernetes Publisher
parent d528d16a5d
commit 98b61416aa
2 changed files with 31 additions and 10 deletions

View File

@ -76,24 +76,25 @@ func newClientCache() *clientCache {
}
type cacheKey struct {
clusterAddress string
// Canonical issuer URL string of the provider.
issuerURL string
clientID string
}
func (c *clientCache) getClient(issuer, clientID string) (*oidcAuthProvider, bool) {
func (c *clientCache) getClient(clusterAddress, issuer, clientID string) (*oidcAuthProvider, bool) {
c.mu.RLock()
defer c.mu.RUnlock()
client, ok := c.cache[cacheKey{issuer, clientID}]
client, ok := c.cache[cacheKey{clusterAddress: clusterAddress, issuerURL: issuer, clientID: clientID}]
return client, ok
}
// setClient attempts to put the client in the cache but may return any clients
// with the same keys set before. This is so there's only ever one client for a provider.
func (c *clientCache) setClient(issuer, clientID string, client *oidcAuthProvider) *oidcAuthProvider {
func (c *clientCache) setClient(clusterAddress, issuer, clientID string, client *oidcAuthProvider) *oidcAuthProvider {
c.mu.Lock()
defer c.mu.Unlock()
key := cacheKey{issuer, clientID}
key := cacheKey{clusterAddress: clusterAddress, issuerURL: issuer, clientID: clientID}
// If another client has already initialized a client for the given provider we want
// to use that client instead of the one we're trying to set. This is so all transports
@ -107,7 +108,7 @@ func (c *clientCache) setClient(issuer, clientID string, client *oidcAuthProvide
return client
}
func newOIDCAuthProvider(_ string, cfg map[string]string, persister restclient.AuthProviderConfigPersister) (restclient.AuthProvider, error) {
func newOIDCAuthProvider(clusterAddress string, cfg map[string]string, persister restclient.AuthProviderConfigPersister) (restclient.AuthProvider, error) {
issuer := cfg[cfgIssuerUrl]
if issuer == "" {
return nil, fmt.Errorf("Must provide %s", cfgIssuerUrl)
@ -119,7 +120,7 @@ func newOIDCAuthProvider(_ string, cfg map[string]string, persister restclient.A
}
// Check cache for existing provider.
if provider, ok := cache.getClient(issuer, clientID); ok {
if provider, ok := cache.getClient(clusterAddress, issuer, clientID); ok {
return provider, nil
}
@ -157,7 +158,7 @@ func newOIDCAuthProvider(_ string, cfg map[string]string, persister restclient.A
persister: persister,
}
return cache.setClient(issuer, clientID, provider), nil
return cache.setClient(clusterAddress, issuer, clientID, provider), nil
}
type oidcAuthProvider struct {

View File

@ -119,20 +119,40 @@ func TestExpired(t *testing.T) {
func TestClientCache(t *testing.T) {
cache := newClientCache()
if _, ok := cache.getClient("issuer1", "id1"); ok {
if _, ok := cache.getClient("cluster1", "issuer1", "id1"); ok {
t.Fatalf("got client before putting one in the cache")
}
assertCacheLen(t, cache, 0)
cli1 := new(oidcAuthProvider)
cli2 := new(oidcAuthProvider)
cli3 := new(oidcAuthProvider)
gotcli := cache.setClient("issuer1", "id1", cli1)
gotcli := cache.setClient("cluster1", "issuer1", "id1", cli1)
if cli1 != gotcli {
t.Fatalf("set first client and got a different one")
}
assertCacheLen(t, cache, 1)
gotcli = cache.setClient("issuer1", "id1", cli2)
gotcli = cache.setClient("cluster1", "issuer1", "id1", cli2)
if cli1 != gotcli {
t.Fatalf("set a second client and didn't get the first")
}
assertCacheLen(t, cache, 1)
gotcli = cache.setClient("cluster2", "issuer1", "id1", cli3)
if cli1 == gotcli {
t.Fatalf("set a third client and got the first")
}
if cli3 != gotcli {
t.Fatalf("set third client and got a different one")
}
assertCacheLen(t, cache, 2)
}
func assertCacheLen(t *testing.T, cache *clientCache, length int) {
t.Helper()
if len(cache.cache) != length {
t.Errorf("expected cache length %d got %d", length, len(cache.cache))
}
}