diff --git a/plugin/pkg/client/auth/azure/azure.go b/plugin/pkg/client/auth/azure/azure.go index 11ef7afb2..f59f2f59e 100644 --- a/plugin/pkg/client/auth/azure/azure.go +++ b/plugin/pkg/client/auth/azure/azure.go @@ -273,8 +273,9 @@ func (ts *azureTokenSource) retrieveTokenFromCfg() (*azureToken, error) { if expiresOn == "" { return nil, fmt.Errorf("no expiresOn in cfg: %s", cfgExpiresOn) } + tokenAudience := resourceID if ts.configMode == configModeDefault { - resourceID = fmt.Sprintf("spn:%s", resourceID) + tokenAudience = fmt.Sprintf("spn:%s", resourceID) } return &azureToken{ @@ -284,7 +285,7 @@ func (ts *azureTokenSource) retrieveTokenFromCfg() (*azureToken, error) { ExpiresIn: json.Number(expiresIn), ExpiresOn: json.Number(expiresOn), NotBefore: json.Number(expiresOn), - Resource: resourceID, + Resource: tokenAudience, Type: tokenType, }, environment: environment, diff --git a/plugin/pkg/client/auth/azure/azure_test.go b/plugin/pkg/client/auth/azure/azure_test.go index 75fb6c3ad..ba093add2 100644 --- a/plugin/pkg/client/auth/azure/azure_test.go +++ b/plugin/pkg/client/auth/azure/azure_test.go @@ -18,6 +18,7 @@ package azure import ( "encoding/json" + "fmt" "strconv" "strings" "sync" @@ -170,6 +171,55 @@ func TestAzureTokenSource(t *testing.T) { expectedConfigModes := []string{"1", "0"} for i, configMode := range configModes { + t.Run(fmt.Sprintf("validate token from cfg with configMode %v", configMode), func(t *testing.T) { + const ( + serverID = "fakeServerID" + clientID = "fakeClientID" + tenantID = "fakeTenantID" + accessToken = "fakeToken" + environment = "fakeEnvironment" + refreshToken = "fakeToken" + expiresIn = "foo" + expiresOn = "foo" + ) + cfg := map[string]string{ + cfgConfigMode: string(configMode), + cfgApiserverID: serverID, + cfgClientID: clientID, + cfgTenantID: tenantID, + cfgEnvironment: environment, + cfgAccessToken: accessToken, + cfgRefreshToken: refreshToken, + cfgExpiresIn: expiresIn, + cfgExpiresOn: expiresOn, + } + fakeSource := fakeTokenSource{} + persiter := &fakePersister{cache: make(map[string]string)} + tokenCache := newAzureTokenCache() + tokenSource := newAzureTokenSource(&fakeSource, tokenCache, cfg, configMode, persiter) + azTokenSource := tokenSource.(*azureTokenSource) + token, err := azTokenSource.retrieveTokenFromCfg() + if err != nil { + t.Errorf("failed to retrieve the token form cfg: %s", err) + } + if token.apiserverID != serverID { + t.Errorf("expecting token.apiserverID: %s, actual: %s", serverID, token.apiserverID) + } + if token.clientID != clientID { + t.Errorf("expecting token.clientID: %s, actual: %s", clientID, token.clientID) + } + if token.tenantID != tenantID { + t.Errorf("expecting token.tenantID: %s, actual: %s", tenantID, token.tenantID) + } + expectedAudience := serverID + if configMode == configModeDefault { + expectedAudience = fmt.Sprintf("spn:%s", serverID) + } + if token.token.Resource != expectedAudience { + t.Errorf("expecting adal token.Resource: %s, actual: %s", expectedAudience, token.token.Resource) + } + }) + t.Run("validate token against cache", func(t *testing.T) { fakeAccessToken := "fake token 1" fakeSource := fakeTokenSource{