From 77bd7c8a8b29dced5a06c232485ab6de1306c087 Mon Sep 17 00:00:00 2001 From: Weinong Wang Date: Tue, 31 Mar 2020 15:59:37 -0700 Subject: [PATCH] fix a bug where spn: prefix is unexpectedly added to kubeconfig apiserver-id setting --- .../plugin/pkg/client/auth/azure/azure.go | 5 +- .../pkg/client/auth/azure/azure_test.go | 150 ++++++++++++------ 2 files changed, 107 insertions(+), 48 deletions(-) diff --git a/staging/src/k8s.io/client-go/plugin/pkg/client/auth/azure/azure.go b/staging/src/k8s.io/client-go/plugin/pkg/client/auth/azure/azure.go index 2746536b2a6..fded604a3c7 100644 --- a/staging/src/k8s.io/client-go/plugin/pkg/client/auth/azure/azure.go +++ b/staging/src/k8s.io/client-go/plugin/pkg/client/auth/azure/azure.go @@ -307,8 +307,9 @@ func (ts *azureTokenSource) retrieveTokenFromCfg() (*azureToken, error) { if expiresOn == "" { return nil, fmt.Errorf("no expiresOn in cfg: %s", cfgExpiresOn) } + tokenAudience := resourceID if ts.configMode == configModeDefault { - resourceID = fmt.Sprintf("spn:%s", resourceID) + tokenAudience = fmt.Sprintf("spn:%s", resourceID) } return &azureToken{ @@ -318,7 +319,7 @@ func (ts *azureTokenSource) retrieveTokenFromCfg() (*azureToken, error) { ExpiresIn: json.Number(expiresIn), ExpiresOn: json.Number(expiresOn), NotBefore: json.Number(expiresOn), - Resource: resourceID, + Resource: tokenAudience, Type: tokenType, }, environment: environment, diff --git a/staging/src/k8s.io/client-go/plugin/pkg/client/auth/azure/azure_test.go b/staging/src/k8s.io/client-go/plugin/pkg/client/auth/azure/azure_test.go index 05612e296e4..e3418c712cb 100644 --- a/staging/src/k8s.io/client-go/plugin/pkg/client/auth/azure/azure_test.go +++ b/staging/src/k8s.io/client-go/plugin/pkg/client/auth/azure/azure_test.go @@ -19,6 +19,7 @@ package azure import ( "encoding/json" "errors" + "fmt" "net/http" "strconv" "strings" @@ -172,6 +173,55 @@ func TestAzureTokenSource(t *testing.T) { expectedConfigModes := []string{"1", "0"} for i, configMode := range configModes { + t.Run(fmt.Sprintf("validate token from cfg with configMode %v", configMode), func(t *testing.T) { + const ( + serverID = "fakeServerID" + clientID = "fakeClientID" + tenantID = "fakeTenantID" + accessToken = "fakeToken" + environment = "fakeEnvironment" + refreshToken = "fakeToken" + expiresIn = "foo" + expiresOn = "foo" + ) + cfg := map[string]string{ + cfgConfigMode: string(configMode), + cfgApiserverID: serverID, + cfgClientID: clientID, + cfgTenantID: tenantID, + cfgEnvironment: environment, + cfgAccessToken: accessToken, + cfgRefreshToken: refreshToken, + cfgExpiresIn: expiresIn, + cfgExpiresOn: expiresOn, + } + fakeSource := fakeTokenSource{token: newFakeAzureToken("fakeToken", time.Now().Add(3600*time.Second))} + persiter := &fakePersister{cache: make(map[string]string)} + tokenCache := newAzureTokenCache() + tokenSource := newAzureTokenSource(&fakeSource, tokenCache, cfg, configMode, persiter) + azTokenSource := tokenSource.(*azureTokenSource) + token, err := azTokenSource.retrieveTokenFromCfg() + if err != nil { + t.Errorf("failed to retrieve the token form cfg: %s", err) + } + if token.apiserverID != serverID { + t.Errorf("expecting token.apiserverID: %s, actual: %s", serverID, token.apiserverID) + } + if token.clientID != clientID { + t.Errorf("expecting token.clientID: %s, actual: %s", clientID, token.clientID) + } + if token.tenantID != tenantID { + t.Errorf("expecting token.tenantID: %s, actual: %s", tenantID, token.tenantID) + } + expectedAudience := serverID + if configMode == configModeDefault { + expectedAudience = fmt.Sprintf("spn:%s", serverID) + } + if token.token.Resource != expectedAudience { + t.Errorf("expecting adal token.Resource: %s, actual: %s", expectedAudience, token.token.Resource) + } + }) + t.Run("validate token against cache", func(t *testing.T) { fakeAccessToken := "fake token 1" fakeSource := fakeTokenSource{token: newFakeAzureToken(fakeAccessToken, time.Now().Add(3600*time.Second))} @@ -223,7 +273,6 @@ func TestAzureTokenSource(t *testing.T) { } func TestAzureTokenSourceScenarios(t *testing.T) { - configMode := configModeDefault expiredToken := newFakeAzureToken("expired token", time.Now().Add(-time.Second)) extendedToken := newFakeAzureToken("extend token", time.Now().Add(1000*time.Second)) fakeToken := newFakeAzureToken("fake token", time.Now().Add(1000*time.Second)) @@ -309,57 +358,66 @@ func TestAzureTokenSourceScenarios(t *testing.T) { }, } for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - persister := newFakePersister() + configModes := []configMode{configModeOmitSPNPrefix, configModeDefault} - cfg := map[string]string{} - if tc.configToken != nil { - cfg = token2Cfg(tc.configToken) - } + for _, configMode := range configModes { + t.Run(fmt.Sprintf("%s with configMode: %v", tc.name, configMode), func(t *testing.T) { + persister := newFakePersister() - tokenCache := newAzureTokenCache() - if tc.cachedToken != nil { - tokenCache.setToken(azureTokenKey, tc.cachedToken) - } - - fakeSource := fakeTokenSource{ - token: tc.sourceToken, - tokenErr: tc.tokenErr, - refreshToken: tc.refreshToken, - refreshErr: tc.refreshErr, - } - - tokenSource := newAzureTokenSource(&fakeSource, tokenCache, cfg, configMode, &persister) - token, err := tokenSource.Token() - - if fakeSource.tokenCalls != tc.tokenCalls { - t.Errorf("expecting tokenCalls: %v, got: %v", tc.tokenCalls, fakeSource.tokenCalls) - } - - if fakeSource.refreshCalls != tc.refreshCalls { - t.Errorf("expecting refreshCalls: %v, got: %v", tc.refreshCalls, fakeSource.refreshCalls) - } - - if persister.calls != tc.persistCalls { - t.Errorf("expecting persister calls: %v, got: %v", tc.persistCalls, persister.calls) - } - - if tc.expectErr != "" { - if !strings.Contains(err.Error(), tc.expectErr) { - t.Errorf("expecting error %v, got %v", tc.expectErr, err) + cfg := map[string]string{ + cfgConfigMode: string(configMode), } - if token != nil { - t.Errorf("token should be nil in err situation, got %v", token) + if tc.configToken != nil { + cfg = token2Cfg(tc.configToken) } - } else { - if err != nil { - t.Fatalf("error should be nil, got %v", err) + + tokenCache := newAzureTokenCache() + if tc.cachedToken != nil { + tokenCache.setToken(azureTokenKey, tc.cachedToken) } - if token.token.AccessToken != tc.expectToken.token.AccessToken { - t.Errorf("token should have accessToken %v, got %v", token.token.AccessToken, tc.expectToken.token.AccessToken) + + fakeSource := fakeTokenSource{ + token: tc.sourceToken, + tokenErr: tc.tokenErr, + refreshToken: tc.refreshToken, + refreshErr: tc.refreshErr, } - } - }) + + tokenSource := newAzureTokenSource(&fakeSource, tokenCache, cfg, configMode, &persister) + token, err := tokenSource.Token() + + if token != nil && fakeSource.token != nil && token.apiserverID != fakeSource.token.apiserverID { + t.Errorf("expecting apiservierID: %s, got: %s", fakeSource.token.apiserverID, token.apiserverID) + } + if fakeSource.tokenCalls != tc.tokenCalls { + t.Errorf("expecting tokenCalls: %v, got: %v", tc.tokenCalls, fakeSource.tokenCalls) + } + + if fakeSource.refreshCalls != tc.refreshCalls { + t.Errorf("expecting refreshCalls: %v, got: %v", tc.refreshCalls, fakeSource.refreshCalls) + } + + if persister.calls != tc.persistCalls { + t.Errorf("expecting persister calls: %v, got: %v", tc.persistCalls, persister.calls) + } + + if tc.expectErr != "" { + if !strings.Contains(err.Error(), tc.expectErr) { + t.Errorf("expecting error %v, got %v", tc.expectErr, err) + } + if token != nil { + t.Errorf("token should be nil in err situation, got %v", token) + } + } else { + if err != nil { + t.Fatalf("error should be nil, got %v", err) + } + if token.token.AccessToken != tc.expectToken.token.AccessToken { + t.Errorf("token should have accessToken %v, got %v", token.token.AccessToken, tc.expectToken.token.AccessToken) + } + } + }) + } } }