mirror of
https://github.com/k3s-io/kubernetes.git
synced 2025-07-25 04:33:26 +00:00
fix a bug where spn: prefix is unexpectedly added to kubeconfig apiserver-id setting
This commit is contained in:
parent
1168b4b812
commit
77bd7c8a8b
@ -307,8 +307,9 @@ func (ts *azureTokenSource) retrieveTokenFromCfg() (*azureToken, error) {
|
|||||||
if expiresOn == "" {
|
if expiresOn == "" {
|
||||||
return nil, fmt.Errorf("no expiresOn in cfg: %s", cfgExpiresOn)
|
return nil, fmt.Errorf("no expiresOn in cfg: %s", cfgExpiresOn)
|
||||||
}
|
}
|
||||||
|
tokenAudience := resourceID
|
||||||
if ts.configMode == configModeDefault {
|
if ts.configMode == configModeDefault {
|
||||||
resourceID = fmt.Sprintf("spn:%s", resourceID)
|
tokenAudience = fmt.Sprintf("spn:%s", resourceID)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &azureToken{
|
return &azureToken{
|
||||||
@ -318,7 +319,7 @@ func (ts *azureTokenSource) retrieveTokenFromCfg() (*azureToken, error) {
|
|||||||
ExpiresIn: json.Number(expiresIn),
|
ExpiresIn: json.Number(expiresIn),
|
||||||
ExpiresOn: json.Number(expiresOn),
|
ExpiresOn: json.Number(expiresOn),
|
||||||
NotBefore: json.Number(expiresOn),
|
NotBefore: json.Number(expiresOn),
|
||||||
Resource: resourceID,
|
Resource: tokenAudience,
|
||||||
Type: tokenType,
|
Type: tokenType,
|
||||||
},
|
},
|
||||||
environment: environment,
|
environment: environment,
|
||||||
|
@ -19,6 +19,7 @@ package azure
|
|||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@ -172,6 +173,55 @@ func TestAzureTokenSource(t *testing.T) {
|
|||||||
expectedConfigModes := []string{"1", "0"}
|
expectedConfigModes := []string{"1", "0"}
|
||||||
|
|
||||||
for i, configMode := range configModes {
|
for i, configMode := range configModes {
|
||||||
|
t.Run(fmt.Sprintf("validate token from cfg with configMode %v", configMode), func(t *testing.T) {
|
||||||
|
const (
|
||||||
|
serverID = "fakeServerID"
|
||||||
|
clientID = "fakeClientID"
|
||||||
|
tenantID = "fakeTenantID"
|
||||||
|
accessToken = "fakeToken"
|
||||||
|
environment = "fakeEnvironment"
|
||||||
|
refreshToken = "fakeToken"
|
||||||
|
expiresIn = "foo"
|
||||||
|
expiresOn = "foo"
|
||||||
|
)
|
||||||
|
cfg := map[string]string{
|
||||||
|
cfgConfigMode: string(configMode),
|
||||||
|
cfgApiserverID: serverID,
|
||||||
|
cfgClientID: clientID,
|
||||||
|
cfgTenantID: tenantID,
|
||||||
|
cfgEnvironment: environment,
|
||||||
|
cfgAccessToken: accessToken,
|
||||||
|
cfgRefreshToken: refreshToken,
|
||||||
|
cfgExpiresIn: expiresIn,
|
||||||
|
cfgExpiresOn: expiresOn,
|
||||||
|
}
|
||||||
|
fakeSource := fakeTokenSource{token: newFakeAzureToken("fakeToken", time.Now().Add(3600*time.Second))}
|
||||||
|
persiter := &fakePersister{cache: make(map[string]string)}
|
||||||
|
tokenCache := newAzureTokenCache()
|
||||||
|
tokenSource := newAzureTokenSource(&fakeSource, tokenCache, cfg, configMode, persiter)
|
||||||
|
azTokenSource := tokenSource.(*azureTokenSource)
|
||||||
|
token, err := azTokenSource.retrieveTokenFromCfg()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to retrieve the token form cfg: %s", err)
|
||||||
|
}
|
||||||
|
if token.apiserverID != serverID {
|
||||||
|
t.Errorf("expecting token.apiserverID: %s, actual: %s", serverID, token.apiserverID)
|
||||||
|
}
|
||||||
|
if token.clientID != clientID {
|
||||||
|
t.Errorf("expecting token.clientID: %s, actual: %s", clientID, token.clientID)
|
||||||
|
}
|
||||||
|
if token.tenantID != tenantID {
|
||||||
|
t.Errorf("expecting token.tenantID: %s, actual: %s", tenantID, token.tenantID)
|
||||||
|
}
|
||||||
|
expectedAudience := serverID
|
||||||
|
if configMode == configModeDefault {
|
||||||
|
expectedAudience = fmt.Sprintf("spn:%s", serverID)
|
||||||
|
}
|
||||||
|
if token.token.Resource != expectedAudience {
|
||||||
|
t.Errorf("expecting adal token.Resource: %s, actual: %s", expectedAudience, token.token.Resource)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("validate token against cache", func(t *testing.T) {
|
t.Run("validate token against cache", func(t *testing.T) {
|
||||||
fakeAccessToken := "fake token 1"
|
fakeAccessToken := "fake token 1"
|
||||||
fakeSource := fakeTokenSource{token: newFakeAzureToken(fakeAccessToken, time.Now().Add(3600*time.Second))}
|
fakeSource := fakeTokenSource{token: newFakeAzureToken(fakeAccessToken, time.Now().Add(3600*time.Second))}
|
||||||
@ -223,7 +273,6 @@ func TestAzureTokenSource(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestAzureTokenSourceScenarios(t *testing.T) {
|
func TestAzureTokenSourceScenarios(t *testing.T) {
|
||||||
configMode := configModeDefault
|
|
||||||
expiredToken := newFakeAzureToken("expired token", time.Now().Add(-time.Second))
|
expiredToken := newFakeAzureToken("expired token", time.Now().Add(-time.Second))
|
||||||
extendedToken := newFakeAzureToken("extend token", time.Now().Add(1000*time.Second))
|
extendedToken := newFakeAzureToken("extend token", time.Now().Add(1000*time.Second))
|
||||||
fakeToken := newFakeAzureToken("fake token", time.Now().Add(1000*time.Second))
|
fakeToken := newFakeAzureToken("fake token", time.Now().Add(1000*time.Second))
|
||||||
@ -309,57 +358,66 @@ func TestAzureTokenSourceScenarios(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
for _, tc := range tests {
|
for _, tc := range tests {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
configModes := []configMode{configModeOmitSPNPrefix, configModeDefault}
|
||||||
persister := newFakePersister()
|
|
||||||
|
|
||||||
cfg := map[string]string{}
|
for _, configMode := range configModes {
|
||||||
if tc.configToken != nil {
|
t.Run(fmt.Sprintf("%s with configMode: %v", tc.name, configMode), func(t *testing.T) {
|
||||||
cfg = token2Cfg(tc.configToken)
|
persister := newFakePersister()
|
||||||
}
|
|
||||||
|
|
||||||
tokenCache := newAzureTokenCache()
|
cfg := map[string]string{
|
||||||
if tc.cachedToken != nil {
|
cfgConfigMode: string(configMode),
|
||||||
tokenCache.setToken(azureTokenKey, tc.cachedToken)
|
|
||||||
}
|
|
||||||
|
|
||||||
fakeSource := fakeTokenSource{
|
|
||||||
token: tc.sourceToken,
|
|
||||||
tokenErr: tc.tokenErr,
|
|
||||||
refreshToken: tc.refreshToken,
|
|
||||||
refreshErr: tc.refreshErr,
|
|
||||||
}
|
|
||||||
|
|
||||||
tokenSource := newAzureTokenSource(&fakeSource, tokenCache, cfg, configMode, &persister)
|
|
||||||
token, err := tokenSource.Token()
|
|
||||||
|
|
||||||
if fakeSource.tokenCalls != tc.tokenCalls {
|
|
||||||
t.Errorf("expecting tokenCalls: %v, got: %v", tc.tokenCalls, fakeSource.tokenCalls)
|
|
||||||
}
|
|
||||||
|
|
||||||
if fakeSource.refreshCalls != tc.refreshCalls {
|
|
||||||
t.Errorf("expecting refreshCalls: %v, got: %v", tc.refreshCalls, fakeSource.refreshCalls)
|
|
||||||
}
|
|
||||||
|
|
||||||
if persister.calls != tc.persistCalls {
|
|
||||||
t.Errorf("expecting persister calls: %v, got: %v", tc.persistCalls, persister.calls)
|
|
||||||
}
|
|
||||||
|
|
||||||
if tc.expectErr != "" {
|
|
||||||
if !strings.Contains(err.Error(), tc.expectErr) {
|
|
||||||
t.Errorf("expecting error %v, got %v", tc.expectErr, err)
|
|
||||||
}
|
}
|
||||||
if token != nil {
|
if tc.configToken != nil {
|
||||||
t.Errorf("token should be nil in err situation, got %v", token)
|
cfg = token2Cfg(tc.configToken)
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
if err != nil {
|
tokenCache := newAzureTokenCache()
|
||||||
t.Fatalf("error should be nil, got %v", err)
|
if tc.cachedToken != nil {
|
||||||
|
tokenCache.setToken(azureTokenKey, tc.cachedToken)
|
||||||
}
|
}
|
||||||
if token.token.AccessToken != tc.expectToken.token.AccessToken {
|
|
||||||
t.Errorf("token should have accessToken %v, got %v", token.token.AccessToken, tc.expectToken.token.AccessToken)
|
fakeSource := fakeTokenSource{
|
||||||
|
token: tc.sourceToken,
|
||||||
|
tokenErr: tc.tokenErr,
|
||||||
|
refreshToken: tc.refreshToken,
|
||||||
|
refreshErr: tc.refreshErr,
|
||||||
}
|
}
|
||||||
}
|
|
||||||
})
|
tokenSource := newAzureTokenSource(&fakeSource, tokenCache, cfg, configMode, &persister)
|
||||||
|
token, err := tokenSource.Token()
|
||||||
|
|
||||||
|
if token != nil && fakeSource.token != nil && token.apiserverID != fakeSource.token.apiserverID {
|
||||||
|
t.Errorf("expecting apiservierID: %s, got: %s", fakeSource.token.apiserverID, token.apiserverID)
|
||||||
|
}
|
||||||
|
if fakeSource.tokenCalls != tc.tokenCalls {
|
||||||
|
t.Errorf("expecting tokenCalls: %v, got: %v", tc.tokenCalls, fakeSource.tokenCalls)
|
||||||
|
}
|
||||||
|
|
||||||
|
if fakeSource.refreshCalls != tc.refreshCalls {
|
||||||
|
t.Errorf("expecting refreshCalls: %v, got: %v", tc.refreshCalls, fakeSource.refreshCalls)
|
||||||
|
}
|
||||||
|
|
||||||
|
if persister.calls != tc.persistCalls {
|
||||||
|
t.Errorf("expecting persister calls: %v, got: %v", tc.persistCalls, persister.calls)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tc.expectErr != "" {
|
||||||
|
if !strings.Contains(err.Error(), tc.expectErr) {
|
||||||
|
t.Errorf("expecting error %v, got %v", tc.expectErr, err)
|
||||||
|
}
|
||||||
|
if token != nil {
|
||||||
|
t.Errorf("token should be nil in err situation, got %v", token)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("error should be nil, got %v", err)
|
||||||
|
}
|
||||||
|
if token.token.AccessToken != tc.expectToken.token.AccessToken {
|
||||||
|
t.Errorf("token should have accessToken %v, got %v", token.token.AccessToken, tc.expectToken.token.AccessToken)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user