mirror of
https://github.com/k3s-io/kubernetes.git
synced 2025-07-24 20:24:09 +00:00
Merge pull request #87630 from weinong/feat-87541
add a flag in azure auth module to omit spn: prefix in audience claim
This commit is contained in:
commit
913c9ef150
@ -10,7 +10,10 @@ go_test(
|
|||||||
name = "go_default_test",
|
name = "go_default_test",
|
||||||
srcs = ["azure_test.go"],
|
srcs = ["azure_test.go"],
|
||||||
embed = [":go_default_library"],
|
embed = [":go_default_library"],
|
||||||
deps = ["//vendor/github.com/Azure/go-autorest/autorest/adal:go_default_library"],
|
deps = [
|
||||||
|
"//vendor/github.com/Azure/go-autorest/autorest/adal:go_default_library",
|
||||||
|
"//vendor/github.com/Azure/go-autorest/autorest/azure:go_default_library",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
go_library(
|
go_library(
|
||||||
|
@ -38,7 +38,13 @@ This plugin provides an integration with Azure Active Directory device flow. If
|
|||||||
* Replace `APISERVER_APPLICATION_ID` with the application ID of your `apiserver` application ID
|
* Replace `APISERVER_APPLICATION_ID` with the application ID of your `apiserver` application ID
|
||||||
* Be sure to also (create and) select a context that uses above user
|
* Be sure to also (create and) select a context that uses above user
|
||||||
|
|
||||||
6. The access token is acquired when first `kubectl` command is executed
|
6. (Optionally) the AAD token has `aud` claim with `spn:` prefix. To omit that, add following auth configuration:
|
||||||
|
|
||||||
|
```
|
||||||
|
--auth-provider-arg=config-mode="1"
|
||||||
|
```
|
||||||
|
|
||||||
|
7. The access token is acquired when first `kubectl` command is executed
|
||||||
|
|
||||||
```
|
```
|
||||||
kubectl get pods
|
kubectl get pods
|
||||||
|
@ -22,6 +22,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/Azure/go-autorest/autorest"
|
"github.com/Azure/go-autorest/autorest"
|
||||||
@ -33,6 +34,8 @@ import (
|
|||||||
restclient "k8s.io/client-go/rest"
|
restclient "k8s.io/client-go/rest"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type configMode int
|
||||||
|
|
||||||
const (
|
const (
|
||||||
azureTokenKey = "azureTokenKey"
|
azureTokenKey = "azureTokenKey"
|
||||||
tokenType = "Bearer"
|
tokenType = "Bearer"
|
||||||
@ -46,6 +49,10 @@ const (
|
|||||||
cfgExpiresOn = "expires-on"
|
cfgExpiresOn = "expires-on"
|
||||||
cfgEnvironment = "environment"
|
cfgEnvironment = "environment"
|
||||||
cfgApiserverID = "apiserver-id"
|
cfgApiserverID = "apiserver-id"
|
||||||
|
cfgConfigMode = "config-mode"
|
||||||
|
|
||||||
|
configModeDefault configMode = 0
|
||||||
|
configModeOmitSPNPrefix configMode = 1
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
@ -78,17 +85,37 @@ func (c *azureTokenCache) setToken(tokenKey string, token *azureToken) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func newAzureAuthProvider(_ string, cfg map[string]string, persister restclient.AuthProviderConfigPersister) (restclient.AuthProvider, error) {
|
func newAzureAuthProvider(_ string, cfg map[string]string, persister restclient.AuthProviderConfigPersister) (restclient.AuthProvider, error) {
|
||||||
var ts tokenSource
|
var (
|
||||||
|
ts tokenSource
|
||||||
|
environment azure.Environment
|
||||||
|
err error
|
||||||
|
mode configMode
|
||||||
|
)
|
||||||
|
|
||||||
environment, err := azure.EnvironmentFromName(cfg[cfgEnvironment])
|
environment, err = azure.EnvironmentFromName(cfg[cfgEnvironment])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
environment = azure.PublicCloud
|
environment = azure.PublicCloud
|
||||||
}
|
}
|
||||||
ts, err = newAzureTokenSourceDeviceCode(environment, cfg[cfgClientID], cfg[cfgTenantID], cfg[cfgApiserverID])
|
|
||||||
|
mode = configModeDefault
|
||||||
|
if cfg[cfgConfigMode] != "" {
|
||||||
|
configModeInt, err := strconv.Atoi(cfg[cfgConfigMode])
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse %s, error: %s", cfgConfigMode, err)
|
||||||
|
}
|
||||||
|
mode = configMode(configModeInt)
|
||||||
|
switch mode {
|
||||||
|
case configModeOmitSPNPrefix:
|
||||||
|
case configModeDefault:
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("%s:%s is not a valid mode", cfgConfigMode, cfg[cfgConfigMode])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ts, err = newAzureTokenSourceDeviceCode(environment, cfg[cfgClientID], cfg[cfgTenantID], cfg[cfgApiserverID], mode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("creating a new azure token source for device code authentication: %v", err)
|
return nil, fmt.Errorf("creating a new azure token source for device code authentication: %v", err)
|
||||||
}
|
}
|
||||||
cacheSource := newAzureTokenSource(ts, cache, cfg, persister)
|
cacheSource := newAzureTokenSource(ts, cache, cfg, mode, persister)
|
||||||
|
|
||||||
return &azureAuthProvider{
|
return &azureAuthProvider{
|
||||||
tokenSource: cacheSource,
|
tokenSource: cacheSource,
|
||||||
@ -156,19 +183,21 @@ type tokenSource interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type azureTokenSource struct {
|
type azureTokenSource struct {
|
||||||
source tokenSource
|
source tokenSource
|
||||||
cache *azureTokenCache
|
cache *azureTokenCache
|
||||||
lock sync.Mutex
|
lock sync.Mutex
|
||||||
cfg map[string]string
|
configMode configMode
|
||||||
persister restclient.AuthProviderConfigPersister
|
cfg map[string]string
|
||||||
|
persister restclient.AuthProviderConfigPersister
|
||||||
}
|
}
|
||||||
|
|
||||||
func newAzureTokenSource(source tokenSource, cache *azureTokenCache, cfg map[string]string, persister restclient.AuthProviderConfigPersister) tokenSource {
|
func newAzureTokenSource(source tokenSource, cache *azureTokenCache, cfg map[string]string, configMode configMode, persister restclient.AuthProviderConfigPersister) tokenSource {
|
||||||
return &azureTokenSource{
|
return &azureTokenSource{
|
||||||
source: source,
|
source: source,
|
||||||
cache: cache,
|
cache: cache,
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
persister: persister,
|
persister: persister,
|
||||||
|
configMode: configMode,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -232,9 +261,9 @@ func (ts *azureTokenSource) retrieveTokenFromCfg() (*azureToken, error) {
|
|||||||
if tenantID == "" {
|
if tenantID == "" {
|
||||||
return nil, fmt.Errorf("no tenant ID in cfg: %s", cfgTenantID)
|
return nil, fmt.Errorf("no tenant ID in cfg: %s", cfgTenantID)
|
||||||
}
|
}
|
||||||
apiserverID := ts.cfg[cfgApiserverID]
|
resourceID := ts.cfg[cfgApiserverID]
|
||||||
if apiserverID == "" {
|
if resourceID == "" {
|
||||||
return nil, fmt.Errorf("no apiserver ID in cfg: %s", apiserverID)
|
return nil, fmt.Errorf("no apiserver ID in cfg: %s", cfgApiserverID)
|
||||||
}
|
}
|
||||||
expiresIn := ts.cfg[cfgExpiresIn]
|
expiresIn := ts.cfg[cfgExpiresIn]
|
||||||
if expiresIn == "" {
|
if expiresIn == "" {
|
||||||
@ -244,6 +273,9 @@ func (ts *azureTokenSource) retrieveTokenFromCfg() (*azureToken, error) {
|
|||||||
if expiresOn == "" {
|
if expiresOn == "" {
|
||||||
return nil, fmt.Errorf("no expiresOn in cfg: %s", cfgExpiresOn)
|
return nil, fmt.Errorf("no expiresOn in cfg: %s", cfgExpiresOn)
|
||||||
}
|
}
|
||||||
|
if ts.configMode == configModeDefault {
|
||||||
|
resourceID = fmt.Sprintf("spn:%s", resourceID)
|
||||||
|
}
|
||||||
|
|
||||||
return &azureToken{
|
return &azureToken{
|
||||||
token: adal.Token{
|
token: adal.Token{
|
||||||
@ -252,13 +284,13 @@ func (ts *azureTokenSource) retrieveTokenFromCfg() (*azureToken, error) {
|
|||||||
ExpiresIn: json.Number(expiresIn),
|
ExpiresIn: json.Number(expiresIn),
|
||||||
ExpiresOn: json.Number(expiresOn),
|
ExpiresOn: json.Number(expiresOn),
|
||||||
NotBefore: json.Number(expiresOn),
|
NotBefore: json.Number(expiresOn),
|
||||||
Resource: fmt.Sprintf("spn:%s", apiserverID),
|
Resource: resourceID,
|
||||||
Type: tokenType,
|
Type: tokenType,
|
||||||
},
|
},
|
||||||
environment: environment,
|
environment: environment,
|
||||||
clientID: clientID,
|
clientID: clientID,
|
||||||
tenantID: tenantID,
|
tenantID: tenantID,
|
||||||
apiserverID: apiserverID,
|
apiserverID: resourceID,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -272,6 +304,7 @@ func (ts *azureTokenSource) storeTokenInCfg(token *azureToken) error {
|
|||||||
newCfg[cfgApiserverID] = token.apiserverID
|
newCfg[cfgApiserverID] = token.apiserverID
|
||||||
newCfg[cfgExpiresIn] = string(token.token.ExpiresIn)
|
newCfg[cfgExpiresIn] = string(token.token.ExpiresIn)
|
||||||
newCfg[cfgExpiresOn] = string(token.token.ExpiresOn)
|
newCfg[cfgExpiresOn] = string(token.token.ExpiresOn)
|
||||||
|
newCfg[cfgConfigMode] = strconv.Itoa(int(ts.configMode))
|
||||||
|
|
||||||
err := ts.persister.Persist(newCfg)
|
err := ts.persister.Persist(newCfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -287,9 +320,17 @@ func (ts *azureTokenSource) refreshToken(token *azureToken) (*azureToken, error)
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
oauthConfig, err := adal.NewOAuthConfig(env.ActiveDirectoryEndpoint, token.tenantID)
|
var oauthConfig *adal.OAuthConfig
|
||||||
if err != nil {
|
if ts.configMode == configModeOmitSPNPrefix {
|
||||||
return nil, fmt.Errorf("building the OAuth configuration for token refresh: %v", err)
|
oauthConfig, err = adal.NewOAuthConfigWithAPIVersion(env.ActiveDirectoryEndpoint, token.tenantID, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("building the OAuth configuration without api-version for token refresh: %v", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
oauthConfig, err = adal.NewOAuthConfig(env.ActiveDirectoryEndpoint, token.tenantID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("building the OAuth configuration for token refresh: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
callback := func(t adal.Token) error {
|
callback := func(t adal.Token) error {
|
||||||
@ -323,9 +364,10 @@ type azureTokenSourceDeviceCode struct {
|
|||||||
clientID string
|
clientID string
|
||||||
tenantID string
|
tenantID string
|
||||||
apiserverID string
|
apiserverID string
|
||||||
|
configMode configMode
|
||||||
}
|
}
|
||||||
|
|
||||||
func newAzureTokenSourceDeviceCode(environment azure.Environment, clientID string, tenantID string, apiserverID string) (tokenSource, error) {
|
func newAzureTokenSourceDeviceCode(environment azure.Environment, clientID string, tenantID string, apiserverID string, configMode configMode) (tokenSource, error) {
|
||||||
if clientID == "" {
|
if clientID == "" {
|
||||||
return nil, errors.New("client-id is empty")
|
return nil, errors.New("client-id is empty")
|
||||||
}
|
}
|
||||||
@ -340,13 +382,25 @@ func newAzureTokenSourceDeviceCode(environment azure.Environment, clientID strin
|
|||||||
clientID: clientID,
|
clientID: clientID,
|
||||||
tenantID: tenantID,
|
tenantID: tenantID,
|
||||||
apiserverID: apiserverID,
|
apiserverID: apiserverID,
|
||||||
|
configMode: configMode,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ts *azureTokenSourceDeviceCode) Token() (*azureToken, error) {
|
func (ts *azureTokenSourceDeviceCode) Token() (*azureToken, error) {
|
||||||
oauthConfig, err := adal.NewOAuthConfig(ts.environment.ActiveDirectoryEndpoint, ts.tenantID)
|
var (
|
||||||
if err != nil {
|
oauthConfig *adal.OAuthConfig
|
||||||
return nil, fmt.Errorf("building the OAuth configuration for device code authentication: %v", err)
|
err error
|
||||||
|
)
|
||||||
|
if ts.configMode == configModeOmitSPNPrefix {
|
||||||
|
oauthConfig, err = adal.NewOAuthConfigWithAPIVersion(ts.environment.ActiveDirectoryEndpoint, ts.tenantID, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("building the OAuth configuration without api-version for device code authentication: %v", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
oauthConfig, err = adal.NewOAuthConfig(ts.environment.ActiveDirectoryEndpoint, ts.tenantID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("building the OAuth configuration for device code authentication: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
client := &autorest.Client{}
|
client := &autorest.Client{}
|
||||||
deviceCode, err := adal.InitiateDeviceAuth(client, *oauthConfig, ts.clientID, ts.apiserverID)
|
deviceCode, err := adal.InitiateDeviceAuth(client, *oauthConfig, ts.clientID, ts.apiserverID)
|
||||||
|
@ -25,55 +25,201 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/Azure/go-autorest/autorest/adal"
|
"github.com/Azure/go-autorest/autorest/adal"
|
||||||
|
"github.com/Azure/go-autorest/autorest/azure"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestAzureTokenSource(t *testing.T) {
|
func TestAzureAuthProvider(t *testing.T) {
|
||||||
fakeAccessToken := "fake token 1"
|
t.Run("validate against invalid configurations", func(t *testing.T) {
|
||||||
fakeSource := fakeTokenSource{
|
vectors := []struct {
|
||||||
accessToken: fakeAccessToken,
|
cfg map[string]string
|
||||||
expiresOn: strconv.FormatInt(time.Now().Add(3600*time.Second).Unix(), 10),
|
expectedError string
|
||||||
}
|
}{
|
||||||
cfg := make(map[string]string)
|
{
|
||||||
persiter := &fakePersister{cache: make(map[string]string)}
|
cfg: map[string]string{
|
||||||
tokenCache := newAzureTokenCache()
|
cfgClientID: "foo",
|
||||||
tokenSource := newAzureTokenSource(&fakeSource, tokenCache, cfg, persiter)
|
cfgApiserverID: "foo",
|
||||||
token, err := tokenSource.Token()
|
cfgTenantID: "foo",
|
||||||
if err != nil {
|
cfgConfigMode: "-1",
|
||||||
t.Errorf("failed to retrieve the token form cache: %v", err)
|
},
|
||||||
}
|
expectedError: "config-mode:-1 is not a valid mode",
|
||||||
|
},
|
||||||
wantCacheLen := 1
|
{
|
||||||
if len(tokenCache.cache) != wantCacheLen {
|
cfg: map[string]string{
|
||||||
t.Errorf("Token() cache length error: got %v, want %v", len(tokenCache.cache), wantCacheLen)
|
cfgClientID: "foo",
|
||||||
}
|
cfgApiserverID: "foo",
|
||||||
|
cfgTenantID: "foo",
|
||||||
if token != tokenCache.cache[azureTokenKey] {
|
cfgConfigMode: "2",
|
||||||
t.Error("Token() returned token != cached token")
|
},
|
||||||
}
|
expectedError: "config-mode:2 is not a valid mode",
|
||||||
|
},
|
||||||
wantCfg := token2Cfg(token)
|
{
|
||||||
persistedCfg := persiter.Cache()
|
cfg: map[string]string{
|
||||||
|
cfgClientID: "foo",
|
||||||
wantCfgLen := len(wantCfg)
|
cfgApiserverID: "foo",
|
||||||
persistedCfgLen := len(persistedCfg)
|
cfgTenantID: "foo",
|
||||||
if wantCfgLen != persistedCfgLen {
|
cfgConfigMode: "foo",
|
||||||
t.Errorf("wantCfgLen and persistedCfgLen do not match, wantCfgLen=%v, persistedCfgLen=%v", wantCfgLen, persistedCfgLen)
|
},
|
||||||
}
|
expectedError: "failed to parse config-mode, error: strconv.Atoi: parsing \"foo\": invalid syntax",
|
||||||
|
},
|
||||||
for k, v := range persistedCfg {
|
|
||||||
if strings.Compare(v, wantCfg[k]) != 0 {
|
|
||||||
t.Errorf("Token() persisted cfg %s: got %v, want %v", k, v, wantCfg[k])
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
fakeSource.accessToken = "fake token 2"
|
for _, v := range vectors {
|
||||||
token, err = tokenSource.Token()
|
persister := &fakePersister{}
|
||||||
if err != nil {
|
_, err := newAzureAuthProvider("", v.cfg, persister)
|
||||||
t.Errorf("failed to retrieve the cached token: %v", err)
|
if !strings.Contains(err.Error(), v.expectedError) {
|
||||||
}
|
t.Errorf("cfg %v should fail with message containing '%s'. actual: '%s'", v.cfg, v.expectedError, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
if token.token.AccessToken != fakeAccessToken {
|
t.Run("it should return non-nil provider in happy cases", func(t *testing.T) {
|
||||||
t.Errorf("Token() didn't return the cached token")
|
vectors := []struct {
|
||||||
|
cfg map[string]string
|
||||||
|
expectedConfigMode configMode
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
cfg: map[string]string{
|
||||||
|
cfgClientID: "foo",
|
||||||
|
cfgApiserverID: "foo",
|
||||||
|
cfgTenantID: "foo",
|
||||||
|
},
|
||||||
|
expectedConfigMode: configModeDefault,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
cfg: map[string]string{
|
||||||
|
cfgClientID: "foo",
|
||||||
|
cfgApiserverID: "foo",
|
||||||
|
cfgTenantID: "foo",
|
||||||
|
cfgConfigMode: "0",
|
||||||
|
},
|
||||||
|
expectedConfigMode: configModeDefault,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
cfg: map[string]string{
|
||||||
|
cfgClientID: "foo",
|
||||||
|
cfgApiserverID: "foo",
|
||||||
|
cfgTenantID: "foo",
|
||||||
|
cfgConfigMode: "1",
|
||||||
|
},
|
||||||
|
expectedConfigMode: configModeOmitSPNPrefix,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, v := range vectors {
|
||||||
|
persister := &fakePersister{}
|
||||||
|
provider, err := newAzureAuthProvider("", v.cfg, persister)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("newAzureAuthProvider should not fail with '%s'", err)
|
||||||
|
}
|
||||||
|
if provider == nil {
|
||||||
|
t.Fatalf("newAzureAuthProvider should return non-nil provider")
|
||||||
|
}
|
||||||
|
azureProvider := provider.(*azureAuthProvider)
|
||||||
|
if azureProvider == nil {
|
||||||
|
t.Errorf("newAzureAuthProvider should return an instance of type azureAuthProvider")
|
||||||
|
}
|
||||||
|
ts := azureProvider.tokenSource.(*azureTokenSource)
|
||||||
|
if ts == nil {
|
||||||
|
t.Errorf("azureAuthProvider should be an instance of azureTokenSource")
|
||||||
|
}
|
||||||
|
if ts.configMode != v.expectedConfigMode {
|
||||||
|
t.Errorf("expected configMode: %d, actual: %d", v.expectedConfigMode, ts.configMode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenSourceDeviceCode(t *testing.T) {
|
||||||
|
var (
|
||||||
|
clientID = "clientID"
|
||||||
|
tenantID = "tenantID"
|
||||||
|
apiserverID = "apiserverID"
|
||||||
|
configMode = configModeDefault
|
||||||
|
azureEnv = azure.Environment{}
|
||||||
|
)
|
||||||
|
t.Run("validate to create azureTokenSourceDeviceCode", func(t *testing.T) {
|
||||||
|
if _, err := newAzureTokenSourceDeviceCode(azureEnv, clientID, tenantID, apiserverID, configModeDefault); err != nil {
|
||||||
|
t.Errorf("newAzureTokenSourceDeviceCode should not have failed. err: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := newAzureTokenSourceDeviceCode(azureEnv, clientID, tenantID, apiserverID, configModeOmitSPNPrefix); err != nil {
|
||||||
|
t.Errorf("newAzureTokenSourceDeviceCode should not have failed. err: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := newAzureTokenSourceDeviceCode(azureEnv, "", tenantID, apiserverID, configMode)
|
||||||
|
actual := "client-id is empty"
|
||||||
|
if err.Error() != actual {
|
||||||
|
t.Errorf("newAzureTokenSourceDeviceCode should have failed. expected: %s, actual: %s", actual, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = newAzureTokenSourceDeviceCode(azureEnv, clientID, "", apiserverID, configMode)
|
||||||
|
actual = "tenant-id is empty"
|
||||||
|
if err.Error() != actual {
|
||||||
|
t.Errorf("newAzureTokenSourceDeviceCode should have failed. expected: %s, actual: %s", actual, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = newAzureTokenSourceDeviceCode(azureEnv, clientID, tenantID, "", configMode)
|
||||||
|
actual = "apiserver-id is empty"
|
||||||
|
if err.Error() != actual {
|
||||||
|
t.Errorf("newAzureTokenSourceDeviceCode should have failed. expected: %s, actual: %s", actual, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
func TestAzureTokenSource(t *testing.T) {
|
||||||
|
configModes := []configMode{configModeOmitSPNPrefix, configModeDefault}
|
||||||
|
expectedConfigModes := []string{"1", "0"}
|
||||||
|
|
||||||
|
for i, configMode := range configModes {
|
||||||
|
t.Run("validate token against cache", func(t *testing.T) {
|
||||||
|
fakeAccessToken := "fake token 1"
|
||||||
|
fakeSource := fakeTokenSource{
|
||||||
|
accessToken: fakeAccessToken,
|
||||||
|
expiresOn: strconv.FormatInt(time.Now().Add(3600*time.Second).Unix(), 10),
|
||||||
|
}
|
||||||
|
cfg := make(map[string]string)
|
||||||
|
persiter := &fakePersister{cache: make(map[string]string)}
|
||||||
|
tokenCache := newAzureTokenCache()
|
||||||
|
tokenSource := newAzureTokenSource(&fakeSource, tokenCache, cfg, configMode, persiter)
|
||||||
|
token, err := tokenSource.Token()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to retrieve the token form cache: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
wantCacheLen := 1
|
||||||
|
if len(tokenCache.cache) != wantCacheLen {
|
||||||
|
t.Errorf("Token() cache length error: got %v, want %v", len(tokenCache.cache), wantCacheLen)
|
||||||
|
}
|
||||||
|
|
||||||
|
if token != tokenCache.cache[azureTokenKey] {
|
||||||
|
t.Error("Token() returned token != cached token")
|
||||||
|
}
|
||||||
|
|
||||||
|
wantCfg := token2Cfg(token)
|
||||||
|
wantCfg[cfgConfigMode] = expectedConfigModes[i]
|
||||||
|
persistedCfg := persiter.Cache()
|
||||||
|
|
||||||
|
wantCfgLen := len(wantCfg)
|
||||||
|
persistedCfgLen := len(persistedCfg)
|
||||||
|
if wantCfgLen != persistedCfgLen {
|
||||||
|
t.Errorf("wantCfgLen and persistedCfgLen do not match, wantCfgLen=%v, persistedCfgLen=%v", wantCfgLen, persistedCfgLen)
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, v := range persistedCfg {
|
||||||
|
if strings.Compare(v, wantCfg[k]) != 0 {
|
||||||
|
t.Errorf("Token() persisted cfg %s: got %v, want %v", k, v, wantCfg[k])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fakeSource.accessToken = "fake token 2"
|
||||||
|
token, err = tokenSource.Token()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to retrieve the cached token: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if token.token.AccessToken != fakeAccessToken {
|
||||||
|
t.Errorf("Token() didn't return the cached token")
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user