fix a bug where spn: prefix is unexpectedly added to kubeconfig apiserver-id setting

Kubernetes-commit: 77bd7c8a8b29dced5a06c232485ab6de1306c087
This commit is contained in:
Weinong Wang
2020-03-31 15:59:37 -07:00
committed by Kubernetes Publisher
parent ed67da3a23
commit de57c8c011
2 changed files with 107 additions and 48 deletions

View File

@@ -307,8 +307,9 @@ func (ts *azureTokenSource) retrieveTokenFromCfg() (*azureToken, error) {
if expiresOn == "" {
return nil, fmt.Errorf("no expiresOn in cfg: %s", cfgExpiresOn)
}
tokenAudience := resourceID
if ts.configMode == configModeDefault {
resourceID = fmt.Sprintf("spn:%s", resourceID)
tokenAudience = fmt.Sprintf("spn:%s", resourceID)
}
return &azureToken{
@@ -318,7 +319,7 @@ func (ts *azureTokenSource) retrieveTokenFromCfg() (*azureToken, error) {
ExpiresIn: json.Number(expiresIn),
ExpiresOn: json.Number(expiresOn),
NotBefore: json.Number(expiresOn),
Resource: resourceID,
Resource: tokenAudience,
Type: tokenType,
},
environment: environment,

View File

@@ -19,6 +19,7 @@ package azure
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"strconv"
"strings"
@@ -172,6 +173,55 @@ func TestAzureTokenSource(t *testing.T) {
expectedConfigModes := []string{"1", "0"}
for i, configMode := range configModes {
t.Run(fmt.Sprintf("validate token from cfg with configMode %v", configMode), func(t *testing.T) {
const (
serverID = "fakeServerID"
clientID = "fakeClientID"
tenantID = "fakeTenantID"
accessToken = "fakeToken"
environment = "fakeEnvironment"
refreshToken = "fakeToken"
expiresIn = "foo"
expiresOn = "foo"
)
cfg := map[string]string{
cfgConfigMode: string(configMode),
cfgApiserverID: serverID,
cfgClientID: clientID,
cfgTenantID: tenantID,
cfgEnvironment: environment,
cfgAccessToken: accessToken,
cfgRefreshToken: refreshToken,
cfgExpiresIn: expiresIn,
cfgExpiresOn: expiresOn,
}
fakeSource := fakeTokenSource{token: newFakeAzureToken("fakeToken", time.Now().Add(3600*time.Second))}
persiter := &fakePersister{cache: make(map[string]string)}
tokenCache := newAzureTokenCache()
tokenSource := newAzureTokenSource(&fakeSource, tokenCache, cfg, configMode, persiter)
azTokenSource := tokenSource.(*azureTokenSource)
token, err := azTokenSource.retrieveTokenFromCfg()
if err != nil {
t.Errorf("failed to retrieve the token form cfg: %s", err)
}
if token.apiserverID != serverID {
t.Errorf("expecting token.apiserverID: %s, actual: %s", serverID, token.apiserverID)
}
if token.clientID != clientID {
t.Errorf("expecting token.clientID: %s, actual: %s", clientID, token.clientID)
}
if token.tenantID != tenantID {
t.Errorf("expecting token.tenantID: %s, actual: %s", tenantID, token.tenantID)
}
expectedAudience := serverID
if configMode == configModeDefault {
expectedAudience = fmt.Sprintf("spn:%s", serverID)
}
if token.token.Resource != expectedAudience {
t.Errorf("expecting adal token.Resource: %s, actual: %s", expectedAudience, token.token.Resource)
}
})
t.Run("validate token against cache", func(t *testing.T) {
fakeAccessToken := "fake token 1"
fakeSource := fakeTokenSource{token: newFakeAzureToken(fakeAccessToken, time.Now().Add(3600*time.Second))}
@@ -223,7 +273,6 @@ func TestAzureTokenSource(t *testing.T) {
}
func TestAzureTokenSourceScenarios(t *testing.T) {
configMode := configModeDefault
expiredToken := newFakeAzureToken("expired token", time.Now().Add(-time.Second))
extendedToken := newFakeAzureToken("extend token", time.Now().Add(1000*time.Second))
fakeToken := newFakeAzureToken("fake token", time.Now().Add(1000*time.Second))
@@ -309,10 +358,15 @@ func TestAzureTokenSourceScenarios(t *testing.T) {
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
configModes := []configMode{configModeOmitSPNPrefix, configModeDefault}
for _, configMode := range configModes {
t.Run(fmt.Sprintf("%s with configMode: %v", tc.name, configMode), func(t *testing.T) {
persister := newFakePersister()
cfg := map[string]string{}
cfg := map[string]string{
cfgConfigMode: string(configMode),
}
if tc.configToken != nil {
cfg = token2Cfg(tc.configToken)
}
@@ -332,6 +386,9 @@ func TestAzureTokenSourceScenarios(t *testing.T) {
tokenSource := newAzureTokenSource(&fakeSource, tokenCache, cfg, configMode, &persister)
token, err := tokenSource.Token()
if token != nil && fakeSource.token != nil && token.apiserverID != fakeSource.token.apiserverID {
t.Errorf("expecting apiservierID: %s, got: %s", fakeSource.token.apiserverID, token.apiserverID)
}
if fakeSource.tokenCalls != tc.tokenCalls {
t.Errorf("expecting tokenCalls: %v, got: %v", tc.tokenCalls, fakeSource.tokenCalls)
}
@@ -362,6 +419,7 @@ func TestAzureTokenSourceScenarios(t *testing.T) {
})
}
}
}
type fakePersister struct {
lock sync.Mutex