From ba28f5cc8eb9738a322628953273802da189e165 Mon Sep 17 00:00:00 2001 From: Dong Liu Date: Tue, 18 Dec 2018 13:50:56 +0800 Subject: [PATCH] Fix aad support in kubectl for sovereign cloud Kubernetes-commit: 092f3988255801ce7a97de5448384c50c400a3a4 --- plugin/pkg/client/auth/azure/azure.go | 16 +++++++++++++++- plugin/pkg/client/auth/azure/azure_test.go | 9 +++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/plugin/pkg/client/auth/azure/azure.go b/plugin/pkg/client/auth/azure/azure.go index d42449fc..e583100c 100644 --- a/plugin/pkg/client/auth/azure/azure.go +++ b/plugin/pkg/client/auth/azure/azure.go @@ -145,6 +145,7 @@ func (r *azureRoundTripper) WrappedRoundTripper() http.RoundTripper { return r.r type azureToken struct { token adal.Token + environment string clientID string tenantID string apiserverID string @@ -219,6 +220,10 @@ func (ts *azureTokenSource) retrieveTokenFromCfg() (*azureToken, error) { if refreshToken == "" { return nil, fmt.Errorf("no refresh token in cfg: %s", cfgRefreshToken) } + environment := ts.cfg[cfgEnvironment] + if environment == "" { + return nil, fmt.Errorf("no environment in cfg: %s", cfgEnvironment) + } clientID := ts.cfg[cfgClientID] if clientID == "" { return nil, fmt.Errorf("no client ID in cfg: %s", cfgClientID) @@ -250,6 +255,7 @@ func (ts *azureTokenSource) retrieveTokenFromCfg() (*azureToken, error) { Resource: fmt.Sprintf("spn:%s", apiserverID), Type: tokenType, }, + environment: environment, clientID: clientID, tenantID: tenantID, apiserverID: apiserverID, @@ -260,6 +266,7 @@ func (ts *azureTokenSource) storeTokenInCfg(token *azureToken) error { newCfg := make(map[string]string) newCfg[cfgAccessToken] = token.token.AccessToken newCfg[cfgRefreshToken] = token.token.RefreshToken + newCfg[cfgEnvironment] = token.environment newCfg[cfgClientID] = token.clientID newCfg[cfgTenantID] = token.tenantID newCfg[cfgApiserverID] = token.apiserverID @@ -275,7 +282,12 @@ func (ts *azureTokenSource) storeTokenInCfg(token *azureToken) error { } func (ts *azureTokenSource) refreshToken(token *azureToken) (*azureToken, error) { - oauthConfig, err := adal.NewOAuthConfig(azure.PublicCloud.ActiveDirectoryEndpoint, token.tenantID) + env, err := azure.EnvironmentFromName(token.environment) + if err != nil { + return nil, err + } + + oauthConfig, err := adal.NewOAuthConfig(env.ActiveDirectoryEndpoint, token.tenantID) if err != nil { return nil, fmt.Errorf("building the OAuth configuration for token refresh: %v", err) } @@ -299,6 +311,7 @@ func (ts *azureTokenSource) refreshToken(token *azureToken) (*azureToken, error) return &azureToken{ token: spt.Token(), + environment: token.environment, clientID: token.clientID, tenantID: token.tenantID, apiserverID: token.apiserverID, @@ -353,6 +366,7 @@ func (ts *azureTokenSourceDeviceCode) Token() (*azureToken, error) { return &azureToken{ token: *token, + environment: ts.environment.Name, clientID: ts.clientID, tenantID: ts.tenantID, apiserverID: ts.apiserverID, diff --git a/plugin/pkg/client/auth/azure/azure_test.go b/plugin/pkg/client/auth/azure/azure_test.go index 810a097d..c2cac563 100644 --- a/plugin/pkg/client/auth/azure/azure_test.go +++ b/plugin/pkg/client/auth/azure/azure_test.go @@ -53,6 +53,13 @@ func TestAzureTokenSource(t *testing.T) { wantCfg := token2Cfg(token) persistedCfg := persiter.Cache() + + wantCfgLen := len(wantCfg) + persistedCfgLen := len(persistedCfg) + if wantCfgLen != persistedCfgLen { + t.Errorf("wantCfgLen and persistedCfgLen do not match, wantCfgLen=%v, persistedCfgLen=%v", wantCfgLen, persistedCfgLen) + } + for k, v := range persistedCfg { if strings.Compare(v, wantCfg[k]) != 0 { t.Errorf("Token() persisted cfg %s: got %v, want %v", k, v, wantCfg[k]) @@ -103,6 +110,7 @@ type fakeTokenSource struct { func (ts *fakeTokenSource) Token() (*azureToken, error) { return &azureToken{ token: newFackeAzureToken(ts.accessToken, ts.expiresOn), + environment: "testenv", clientID: "fake", tenantID: "fake", apiserverID: "fake", @@ -113,6 +121,7 @@ func token2Cfg(token *azureToken) map[string]string { cfg := make(map[string]string) cfg[cfgAccessToken] = token.token.AccessToken cfg[cfgRefreshToken] = token.token.RefreshToken + cfg[cfgEnvironment] = token.environment cfg[cfgClientID] = token.clientID cfg[cfgTenantID] = token.tenantID cfg[cfgApiserverID] = token.apiserverID