From c08db5feacc78aeaf104f2c0eb0c4b038fba1df5 Mon Sep 17 00:00:00 2001 From: Weinong Wang Date: Tue, 28 Jan 2020 14:29:11 -0800 Subject: [PATCH] add a flag in azure auth module to omit spn: prefix in audience claim --- .../plugin/pkg/client/auth/azure/BUILD | 5 +- .../plugin/pkg/client/auth/azure/README.md | 8 +- .../plugin/pkg/client/auth/azure/azure.go | 106 ++++++-- .../pkg/client/auth/azure/azure_test.go | 234 ++++++++++++++---- 4 files changed, 281 insertions(+), 72 deletions(-) diff --git a/staging/src/k8s.io/client-go/plugin/pkg/client/auth/azure/BUILD b/staging/src/k8s.io/client-go/plugin/pkg/client/auth/azure/BUILD index 6a28475a76d..13f5894939d 100644 --- a/staging/src/k8s.io/client-go/plugin/pkg/client/auth/azure/BUILD +++ b/staging/src/k8s.io/client-go/plugin/pkg/client/auth/azure/BUILD @@ -10,7 +10,10 @@ go_test( name = "go_default_test", srcs = ["azure_test.go"], embed = [":go_default_library"], - deps = ["//vendor/github.com/Azure/go-autorest/autorest/adal:go_default_library"], + deps = [ + "//vendor/github.com/Azure/go-autorest/autorest/adal:go_default_library", + "//vendor/github.com/Azure/go-autorest/autorest/azure:go_default_library", + ], ) go_library( diff --git a/staging/src/k8s.io/client-go/plugin/pkg/client/auth/azure/README.md b/staging/src/k8s.io/client-go/plugin/pkg/client/auth/azure/README.md index 21a39ae86ea..096f99e1d2d 100644 --- a/staging/src/k8s.io/client-go/plugin/pkg/client/auth/azure/README.md +++ b/staging/src/k8s.io/client-go/plugin/pkg/client/auth/azure/README.md @@ -38,7 +38,13 @@ This plugin provides an integration with Azure Active Directory device flow. If * Replace `APISERVER_APPLICATION_ID` with the application ID of your `apiserver` application ID * Be sure to also (create and) select a context that uses above user - 6. The access token is acquired when first `kubectl` command is executed +6. (Optionally) the AAD token has `aud` claim with `spn:` prefix. To omit that, add following auth configuration: + + ``` + --auth-provider-arg=config-mode="1" + ``` + + 7. The access token is acquired when first `kubectl` command is executed ``` kubectl get pods diff --git a/staging/src/k8s.io/client-go/plugin/pkg/client/auth/azure/azure.go b/staging/src/k8s.io/client-go/plugin/pkg/client/auth/azure/azure.go index e583100cc9c..11ef7afb213 100644 --- a/staging/src/k8s.io/client-go/plugin/pkg/client/auth/azure/azure.go +++ b/staging/src/k8s.io/client-go/plugin/pkg/client/auth/azure/azure.go @@ -22,6 +22,7 @@ import ( "fmt" "net/http" "os" + "strconv" "sync" "github.com/Azure/go-autorest/autorest" @@ -33,6 +34,8 @@ import ( restclient "k8s.io/client-go/rest" ) +type configMode int + const ( azureTokenKey = "azureTokenKey" tokenType = "Bearer" @@ -46,6 +49,10 @@ const ( cfgExpiresOn = "expires-on" cfgEnvironment = "environment" cfgApiserverID = "apiserver-id" + cfgConfigMode = "config-mode" + + configModeDefault configMode = 0 + configModeOmitSPNPrefix configMode = 1 ) func init() { @@ -78,17 +85,37 @@ func (c *azureTokenCache) setToken(tokenKey string, token *azureToken) { } func newAzureAuthProvider(_ string, cfg map[string]string, persister restclient.AuthProviderConfigPersister) (restclient.AuthProvider, error) { - var ts tokenSource + var ( + ts tokenSource + environment azure.Environment + err error + mode configMode + ) - environment, err := azure.EnvironmentFromName(cfg[cfgEnvironment]) + environment, err = azure.EnvironmentFromName(cfg[cfgEnvironment]) if err != nil { environment = azure.PublicCloud } - ts, err = newAzureTokenSourceDeviceCode(environment, cfg[cfgClientID], cfg[cfgTenantID], cfg[cfgApiserverID]) + + mode = configModeDefault + if cfg[cfgConfigMode] != "" { + configModeInt, err := strconv.Atoi(cfg[cfgConfigMode]) + if err != nil { + return nil, fmt.Errorf("failed to parse %s, error: %s", cfgConfigMode, err) + } + mode = configMode(configModeInt) + switch mode { + case configModeOmitSPNPrefix: + case configModeDefault: + default: + return nil, fmt.Errorf("%s:%s is not a valid mode", cfgConfigMode, cfg[cfgConfigMode]) + } + } + ts, err = newAzureTokenSourceDeviceCode(environment, cfg[cfgClientID], cfg[cfgTenantID], cfg[cfgApiserverID], mode) if err != nil { return nil, fmt.Errorf("creating a new azure token source for device code authentication: %v", err) } - cacheSource := newAzureTokenSource(ts, cache, cfg, persister) + cacheSource := newAzureTokenSource(ts, cache, cfg, mode, persister) return &azureAuthProvider{ tokenSource: cacheSource, @@ -156,19 +183,21 @@ type tokenSource interface { } type azureTokenSource struct { - source tokenSource - cache *azureTokenCache - lock sync.Mutex - cfg map[string]string - persister restclient.AuthProviderConfigPersister + source tokenSource + cache *azureTokenCache + lock sync.Mutex + configMode configMode + cfg map[string]string + persister restclient.AuthProviderConfigPersister } -func newAzureTokenSource(source tokenSource, cache *azureTokenCache, cfg map[string]string, persister restclient.AuthProviderConfigPersister) tokenSource { +func newAzureTokenSource(source tokenSource, cache *azureTokenCache, cfg map[string]string, configMode configMode, persister restclient.AuthProviderConfigPersister) tokenSource { return &azureTokenSource{ - source: source, - cache: cache, - cfg: cfg, - persister: persister, + source: source, + cache: cache, + cfg: cfg, + persister: persister, + configMode: configMode, } } @@ -232,9 +261,9 @@ func (ts *azureTokenSource) retrieveTokenFromCfg() (*azureToken, error) { if tenantID == "" { return nil, fmt.Errorf("no tenant ID in cfg: %s", cfgTenantID) } - apiserverID := ts.cfg[cfgApiserverID] - if apiserverID == "" { - return nil, fmt.Errorf("no apiserver ID in cfg: %s", apiserverID) + resourceID := ts.cfg[cfgApiserverID] + if resourceID == "" { + return nil, fmt.Errorf("no apiserver ID in cfg: %s", cfgApiserverID) } expiresIn := ts.cfg[cfgExpiresIn] if expiresIn == "" { @@ -244,6 +273,9 @@ func (ts *azureTokenSource) retrieveTokenFromCfg() (*azureToken, error) { if expiresOn == "" { return nil, fmt.Errorf("no expiresOn in cfg: %s", cfgExpiresOn) } + if ts.configMode == configModeDefault { + resourceID = fmt.Sprintf("spn:%s", resourceID) + } return &azureToken{ token: adal.Token{ @@ -252,13 +284,13 @@ func (ts *azureTokenSource) retrieveTokenFromCfg() (*azureToken, error) { ExpiresIn: json.Number(expiresIn), ExpiresOn: json.Number(expiresOn), NotBefore: json.Number(expiresOn), - Resource: fmt.Sprintf("spn:%s", apiserverID), + Resource: resourceID, Type: tokenType, }, environment: environment, clientID: clientID, tenantID: tenantID, - apiserverID: apiserverID, + apiserverID: resourceID, }, nil } @@ -272,6 +304,7 @@ func (ts *azureTokenSource) storeTokenInCfg(token *azureToken) error { newCfg[cfgApiserverID] = token.apiserverID newCfg[cfgExpiresIn] = string(token.token.ExpiresIn) newCfg[cfgExpiresOn] = string(token.token.ExpiresOn) + newCfg[cfgConfigMode] = strconv.Itoa(int(ts.configMode)) err := ts.persister.Persist(newCfg) if err != nil { @@ -287,9 +320,17 @@ func (ts *azureTokenSource) refreshToken(token *azureToken) (*azureToken, error) return nil, err } - oauthConfig, err := adal.NewOAuthConfig(env.ActiveDirectoryEndpoint, token.tenantID) - if err != nil { - return nil, fmt.Errorf("building the OAuth configuration for token refresh: %v", err) + var oauthConfig *adal.OAuthConfig + if ts.configMode == configModeOmitSPNPrefix { + oauthConfig, err = adal.NewOAuthConfigWithAPIVersion(env.ActiveDirectoryEndpoint, token.tenantID, nil) + if err != nil { + return nil, fmt.Errorf("building the OAuth configuration without api-version for token refresh: %v", err) + } + } else { + oauthConfig, err = adal.NewOAuthConfig(env.ActiveDirectoryEndpoint, token.tenantID) + if err != nil { + return nil, fmt.Errorf("building the OAuth configuration for token refresh: %v", err) + } } callback := func(t adal.Token) error { @@ -323,9 +364,10 @@ type azureTokenSourceDeviceCode struct { clientID string tenantID string apiserverID string + configMode configMode } -func newAzureTokenSourceDeviceCode(environment azure.Environment, clientID string, tenantID string, apiserverID string) (tokenSource, error) { +func newAzureTokenSourceDeviceCode(environment azure.Environment, clientID string, tenantID string, apiserverID string, configMode configMode) (tokenSource, error) { if clientID == "" { return nil, errors.New("client-id is empty") } @@ -340,13 +382,25 @@ func newAzureTokenSourceDeviceCode(environment azure.Environment, clientID strin clientID: clientID, tenantID: tenantID, apiserverID: apiserverID, + configMode: configMode, }, nil } func (ts *azureTokenSourceDeviceCode) Token() (*azureToken, error) { - oauthConfig, err := adal.NewOAuthConfig(ts.environment.ActiveDirectoryEndpoint, ts.tenantID) - if err != nil { - return nil, fmt.Errorf("building the OAuth configuration for device code authentication: %v", err) + var ( + oauthConfig *adal.OAuthConfig + err error + ) + if ts.configMode == configModeOmitSPNPrefix { + oauthConfig, err = adal.NewOAuthConfigWithAPIVersion(ts.environment.ActiveDirectoryEndpoint, ts.tenantID, nil) + if err != nil { + return nil, fmt.Errorf("building the OAuth configuration without api-version for device code authentication: %v", err) + } + } else { + oauthConfig, err = adal.NewOAuthConfig(ts.environment.ActiveDirectoryEndpoint, ts.tenantID) + if err != nil { + return nil, fmt.Errorf("building the OAuth configuration for device code authentication: %v", err) + } } client := &autorest.Client{} deviceCode, err := adal.InitiateDeviceAuth(client, *oauthConfig, ts.clientID, ts.apiserverID) diff --git a/staging/src/k8s.io/client-go/plugin/pkg/client/auth/azure/azure_test.go b/staging/src/k8s.io/client-go/plugin/pkg/client/auth/azure/azure_test.go index c2cac5633e4..75fb6c3ad7c 100644 --- a/staging/src/k8s.io/client-go/plugin/pkg/client/auth/azure/azure_test.go +++ b/staging/src/k8s.io/client-go/plugin/pkg/client/auth/azure/azure_test.go @@ -25,55 +25,201 @@ import ( "time" "github.com/Azure/go-autorest/autorest/adal" + "github.com/Azure/go-autorest/autorest/azure" ) -func TestAzureTokenSource(t *testing.T) { - fakeAccessToken := "fake token 1" - fakeSource := fakeTokenSource{ - accessToken: fakeAccessToken, - expiresOn: strconv.FormatInt(time.Now().Add(3600*time.Second).Unix(), 10), - } - cfg := make(map[string]string) - persiter := &fakePersister{cache: make(map[string]string)} - tokenCache := newAzureTokenCache() - tokenSource := newAzureTokenSource(&fakeSource, tokenCache, cfg, persiter) - token, err := tokenSource.Token() - if err != nil { - t.Errorf("failed to retrieve the token form cache: %v", err) - } - - wantCacheLen := 1 - if len(tokenCache.cache) != wantCacheLen { - t.Errorf("Token() cache length error: got %v, want %v", len(tokenCache.cache), wantCacheLen) - } - - if token != tokenCache.cache[azureTokenKey] { - t.Error("Token() returned token != cached token") - } - - wantCfg := token2Cfg(token) - persistedCfg := persiter.Cache() - - wantCfgLen := len(wantCfg) - persistedCfgLen := len(persistedCfg) - if wantCfgLen != persistedCfgLen { - t.Errorf("wantCfgLen and persistedCfgLen do not match, wantCfgLen=%v, persistedCfgLen=%v", wantCfgLen, persistedCfgLen) - } - - for k, v := range persistedCfg { - if strings.Compare(v, wantCfg[k]) != 0 { - t.Errorf("Token() persisted cfg %s: got %v, want %v", k, v, wantCfg[k]) +func TestAzureAuthProvider(t *testing.T) { + t.Run("validate against invalid configurations", func(t *testing.T) { + vectors := []struct { + cfg map[string]string + expectedError string + }{ + { + cfg: map[string]string{ + cfgClientID: "foo", + cfgApiserverID: "foo", + cfgTenantID: "foo", + cfgConfigMode: "-1", + }, + expectedError: "config-mode:-1 is not a valid mode", + }, + { + cfg: map[string]string{ + cfgClientID: "foo", + cfgApiserverID: "foo", + cfgTenantID: "foo", + cfgConfigMode: "2", + }, + expectedError: "config-mode:2 is not a valid mode", + }, + { + cfg: map[string]string{ + cfgClientID: "foo", + cfgApiserverID: "foo", + cfgTenantID: "foo", + cfgConfigMode: "foo", + }, + expectedError: "failed to parse config-mode, error: strconv.Atoi: parsing \"foo\": invalid syntax", + }, } - } - fakeSource.accessToken = "fake token 2" - token, err = tokenSource.Token() - if err != nil { - t.Errorf("failed to retrieve the cached token: %v", err) - } + for _, v := range vectors { + persister := &fakePersister{} + _, err := newAzureAuthProvider("", v.cfg, persister) + if !strings.Contains(err.Error(), v.expectedError) { + t.Errorf("cfg %v should fail with message containing '%s'. actual: '%s'", v.cfg, v.expectedError, err) + } + } + }) - if token.token.AccessToken != fakeAccessToken { - t.Errorf("Token() didn't return the cached token") + t.Run("it should return non-nil provider in happy cases", func(t *testing.T) { + vectors := []struct { + cfg map[string]string + expectedConfigMode configMode + }{ + { + cfg: map[string]string{ + cfgClientID: "foo", + cfgApiserverID: "foo", + cfgTenantID: "foo", + }, + expectedConfigMode: configModeDefault, + }, + { + cfg: map[string]string{ + cfgClientID: "foo", + cfgApiserverID: "foo", + cfgTenantID: "foo", + cfgConfigMode: "0", + }, + expectedConfigMode: configModeDefault, + }, + { + cfg: map[string]string{ + cfgClientID: "foo", + cfgApiserverID: "foo", + cfgTenantID: "foo", + cfgConfigMode: "1", + }, + expectedConfigMode: configModeOmitSPNPrefix, + }, + } + + for _, v := range vectors { + persister := &fakePersister{} + provider, err := newAzureAuthProvider("", v.cfg, persister) + if err != nil { + t.Errorf("newAzureAuthProvider should not fail with '%s'", err) + } + if provider == nil { + t.Fatalf("newAzureAuthProvider should return non-nil provider") + } + azureProvider := provider.(*azureAuthProvider) + if azureProvider == nil { + t.Errorf("newAzureAuthProvider should return an instance of type azureAuthProvider") + } + ts := azureProvider.tokenSource.(*azureTokenSource) + if ts == nil { + t.Errorf("azureAuthProvider should be an instance of azureTokenSource") + } + if ts.configMode != v.expectedConfigMode { + t.Errorf("expected configMode: %d, actual: %d", v.expectedConfigMode, ts.configMode) + } + } + }) +} + +func TestTokenSourceDeviceCode(t *testing.T) { + var ( + clientID = "clientID" + tenantID = "tenantID" + apiserverID = "apiserverID" + configMode = configModeDefault + azureEnv = azure.Environment{} + ) + t.Run("validate to create azureTokenSourceDeviceCode", func(t *testing.T) { + if _, err := newAzureTokenSourceDeviceCode(azureEnv, clientID, tenantID, apiserverID, configModeDefault); err != nil { + t.Errorf("newAzureTokenSourceDeviceCode should not have failed. err: %s", err) + } + + if _, err := newAzureTokenSourceDeviceCode(azureEnv, clientID, tenantID, apiserverID, configModeOmitSPNPrefix); err != nil { + t.Errorf("newAzureTokenSourceDeviceCode should not have failed. err: %s", err) + } + + _, err := newAzureTokenSourceDeviceCode(azureEnv, "", tenantID, apiserverID, configMode) + actual := "client-id is empty" + if err.Error() != actual { + t.Errorf("newAzureTokenSourceDeviceCode should have failed. expected: %s, actual: %s", actual, err) + } + + _, err = newAzureTokenSourceDeviceCode(azureEnv, clientID, "", apiserverID, configMode) + actual = "tenant-id is empty" + if err.Error() != actual { + t.Errorf("newAzureTokenSourceDeviceCode should have failed. expected: %s, actual: %s", actual, err) + } + + _, err = newAzureTokenSourceDeviceCode(azureEnv, clientID, tenantID, "", configMode) + actual = "apiserver-id is empty" + if err.Error() != actual { + t.Errorf("newAzureTokenSourceDeviceCode should have failed. expected: %s, actual: %s", actual, err) + } + }) +} +func TestAzureTokenSource(t *testing.T) { + configModes := []configMode{configModeOmitSPNPrefix, configModeDefault} + expectedConfigModes := []string{"1", "0"} + + for i, configMode := range configModes { + t.Run("validate token against cache", func(t *testing.T) { + fakeAccessToken := "fake token 1" + fakeSource := fakeTokenSource{ + accessToken: fakeAccessToken, + expiresOn: strconv.FormatInt(time.Now().Add(3600*time.Second).Unix(), 10), + } + cfg := make(map[string]string) + persiter := &fakePersister{cache: make(map[string]string)} + tokenCache := newAzureTokenCache() + tokenSource := newAzureTokenSource(&fakeSource, tokenCache, cfg, configMode, persiter) + token, err := tokenSource.Token() + if err != nil { + t.Errorf("failed to retrieve the token form cache: %v", err) + } + + wantCacheLen := 1 + if len(tokenCache.cache) != wantCacheLen { + t.Errorf("Token() cache length error: got %v, want %v", len(tokenCache.cache), wantCacheLen) + } + + if token != tokenCache.cache[azureTokenKey] { + t.Error("Token() returned token != cached token") + } + + wantCfg := token2Cfg(token) + wantCfg[cfgConfigMode] = expectedConfigModes[i] + persistedCfg := persiter.Cache() + + wantCfgLen := len(wantCfg) + persistedCfgLen := len(persistedCfg) + if wantCfgLen != persistedCfgLen { + t.Errorf("wantCfgLen and persistedCfgLen do not match, wantCfgLen=%v, persistedCfgLen=%v", wantCfgLen, persistedCfgLen) + } + + for k, v := range persistedCfg { + if strings.Compare(v, wantCfg[k]) != 0 { + t.Errorf("Token() persisted cfg %s: got %v, want %v", k, v, wantCfg[k]) + } + } + + fakeSource.accessToken = "fake token 2" + token, err = tokenSource.Token() + if err != nil { + t.Errorf("failed to retrieve the cached token: %v", err) + } + + if token.token.AccessToken != fakeAccessToken { + t.Errorf("Token() didn't return the cached token") + } + }) } }