Revert "Remove gcp and azure auth plugins"

This reverts commit 916cf16cf14928702f3f90b655ddddab2c85fcec.

Kubernetes-commit: 651b4f5b647a205d12fad4d0edc489d97109cccc
This commit is contained in:
Jordan Liggitt
2022-08-18 14:16:23 -04:00
committed by Kubernetes Publisher
parent a890e7bc14
commit ef26118838
10 changed files with 2194 additions and 80 deletions

View File

@@ -0,0 +1,56 @@
# Azure Active Directory plugin for client authentication
This plugin provides an integration with Azure Active Directory device flow. If no tokens are present in the kubectl configuration, it will prompt a device code which can be used to login in a browser. After login it will automatically fetch the tokens and store them in the kubectl configuration. In addition it will refresh and update the tokens in the configuration when expired.
## Usage
1. Create an Azure Active Directory *Web App / API* application for `apiserver` following these [instructions](https://docs.microsoft.com/en-us/azure/active-directory/active-directory-app-registration). The callback URL does not matter (just cannot be empty).
2. Create a second Azure Active Directory native application for `kubectl`. The callback URL does not matter (just cannot be empty).
3. On `kubectl` application's configuration page in Azure portal grant permissions to `apiserver` application by clicking on *Required Permissions*, click the *Add* button and search for the apiserver application created in step 1. Select "Access apiserver" under the *DELEGATED PERMISSIONS*. Once added click the *Grant Permissions* button to apply the changes.
4. Configure the `apiserver` to use the Azure Active Directory as an OIDC provider with following options
```
--oidc-client-id="spn:APISERVER_APPLICATION_ID" \
--oidc-issuer-url="https://sts.windows.net/TENANT_ID/"
--oidc-username-claim="sub"
```
* Replace the `APISERVER_APPLICATION_ID` with the application ID of `apiserver` application
* Replace `TENANT_ID` with your tenant ID.
  * For a list of alternative username claims that are supported by the OIDC issuer check the JSON response at `https://sts.windows.net/TENANT_ID/.well-known/openid-configuration`.
5. Configure `kubectl` to use the `azure` authentication provider
```
kubectl config set-credentials "USER_NAME" --auth-provider=azure \
--auth-provider-arg=environment=AzurePublicCloud \
--auth-provider-arg=client-id=APPLICATION_ID \
--auth-provider-arg=tenant-id=TENANT_ID \
--auth-provider-arg=apiserver-id=APISERVER_APPLICATION_ID
```
* Supported environments: `AzurePublicCloud`, `AzureUSGovernmentCloud`, `AzureChinaCloud`, `AzureGermanCloud`
* Replace `USER_NAME` and `TENANT_ID` with your user name and tenant ID
* Replace `APPLICATION_ID` with the application ID of your`kubectl` application ID
* Replace `APISERVER_APPLICATION_ID` with the application ID of your `apiserver` application ID
* Be sure to also (create and) select a context that uses above user
6. (Optionally) the AAD token has `aud` claim with `spn:` prefix. To omit that, add following auth configuration:
```
--auth-provider-arg=config-mode="1"
```
7. The access token is acquired when first `kubectl` command is executed
```
kubectl get pods
To sign in, use a web browser to open the page https://aka.ms/devicelogin and enter the code DEC7D48GA to authenticate.
```
* After signing in a web browser, the token is stored in the configuration, and it will be reused when executing further commands.
* The resulting username in Kubernetes depends on your [configuration of the `--oidc-username-claim` and `--oidc-username-prefix` flags on the API server](https://kubernetes.io/docs/admin/authentication/#configuring-the-api-server). If you are using any authorization method you need to give permissions to that user, e.g. by binding the user to a role in the case of RBAC.

View File

@@ -0,0 +1,477 @@
/*
Copyright 2017 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package azure
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"os"
"strconv"
"sync"
"github.com/Azure/go-autorest/autorest"
"github.com/Azure/go-autorest/autorest/adal"
"github.com/Azure/go-autorest/autorest/azure"
"k8s.io/klog/v2"
"k8s.io/apimachinery/pkg/util/net"
restclient "k8s.io/client-go/rest"
)
type configMode int
const (
azureTokenKey = "azureTokenKey"
tokenType = "Bearer"
authHeader = "Authorization"
cfgClientID = "client-id"
cfgTenantID = "tenant-id"
cfgAccessToken = "access-token"
cfgRefreshToken = "refresh-token"
cfgExpiresIn = "expires-in"
cfgExpiresOn = "expires-on"
cfgEnvironment = "environment"
cfgApiserverID = "apiserver-id"
cfgConfigMode = "config-mode"
configModeDefault configMode = 0
configModeOmitSPNPrefix configMode = 1
)
func init() {
if err := restclient.RegisterAuthProviderPlugin("azure", newAzureAuthProvider); err != nil {
klog.Fatalf("Failed to register azure auth plugin: %v", err)
}
}
var cache = newAzureTokenCache()
type azureTokenCache struct {
lock sync.Mutex
cache map[string]*azureToken
}
func newAzureTokenCache() *azureTokenCache {
return &azureTokenCache{cache: make(map[string]*azureToken)}
}
func (c *azureTokenCache) getToken(tokenKey string) *azureToken {
c.lock.Lock()
defer c.lock.Unlock()
return c.cache[tokenKey]
}
func (c *azureTokenCache) setToken(tokenKey string, token *azureToken) {
c.lock.Lock()
defer c.lock.Unlock()
c.cache[tokenKey] = token
}
var warnOnce sync.Once
func newAzureAuthProvider(_ string, cfg map[string]string, persister restclient.AuthProviderConfigPersister) (restclient.AuthProvider, error) {
// deprecated in v1.22, remove in v1.25
warnOnce.Do(func() {
klog.Warningf(`WARNING: the azure auth plugin is deprecated in v1.22+, unavailable in v1.25+; use https://github.com/Azure/kubelogin instead.
To learn more, consult https://kubernetes.io/docs/reference/access-authn-authz/authentication/#client-go-credential-plugins`)
})
var (
ts tokenSource
environment azure.Environment
err error
mode configMode
)
environment, err = azure.EnvironmentFromName(cfg[cfgEnvironment])
if err != nil {
environment = azure.PublicCloud
}
mode = configModeDefault
if cfg[cfgConfigMode] != "" {
configModeInt, err := strconv.Atoi(cfg[cfgConfigMode])
if err != nil {
return nil, fmt.Errorf("failed to parse %s, error: %s", cfgConfigMode, err)
}
mode = configMode(configModeInt)
switch mode {
case configModeOmitSPNPrefix:
case configModeDefault:
default:
return nil, fmt.Errorf("%s:%s is not a valid mode", cfgConfigMode, cfg[cfgConfigMode])
}
}
ts, err = newAzureTokenSourceDeviceCode(environment, cfg[cfgClientID], cfg[cfgTenantID], cfg[cfgApiserverID], mode)
if err != nil {
return nil, fmt.Errorf("creating a new azure token source for device code authentication: %v", err)
}
cacheSource := newAzureTokenSource(ts, cache, cfg, mode, persister)
return &azureAuthProvider{
tokenSource: cacheSource,
}, nil
}
type azureAuthProvider struct {
tokenSource tokenSource
}
func (p *azureAuthProvider) Login() error {
return errors.New("not yet implemented")
}
func (p *azureAuthProvider) WrapTransport(rt http.RoundTripper) http.RoundTripper {
return &azureRoundTripper{
tokenSource: p.tokenSource,
roundTripper: rt,
}
}
type azureRoundTripper struct {
tokenSource tokenSource
roundTripper http.RoundTripper
}
var _ net.RoundTripperWrapper = &azureRoundTripper{}
func (r *azureRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
if len(req.Header.Get(authHeader)) != 0 {
return r.roundTripper.RoundTrip(req)
}
token, err := r.tokenSource.Token()
if err != nil {
klog.Errorf("Failed to acquire a token: %v", err)
return nil, fmt.Errorf("acquiring a token for authorization header: %v", err)
}
// clone the request in order to avoid modifying the headers of the original request
req2 := new(http.Request)
*req2 = *req
req2.Header = make(http.Header, len(req.Header))
for k, s := range req.Header {
req2.Header[k] = append([]string(nil), s...)
}
req2.Header.Set(authHeader, fmt.Sprintf("%s %s", tokenType, token.token.AccessToken))
return r.roundTripper.RoundTrip(req2)
}
func (r *azureRoundTripper) WrappedRoundTripper() http.RoundTripper { return r.roundTripper }
type azureToken struct {
token adal.Token
environment string
clientID string
tenantID string
apiserverID string
}
type tokenSource interface {
Token() (*azureToken, error)
Refresh(*azureToken) (*azureToken, error)
}
type azureTokenSource struct {
source tokenSource
cache *azureTokenCache
lock sync.Mutex
configMode configMode
cfg map[string]string
persister restclient.AuthProviderConfigPersister
}
func newAzureTokenSource(source tokenSource, cache *azureTokenCache, cfg map[string]string, configMode configMode, persister restclient.AuthProviderConfigPersister) tokenSource {
return &azureTokenSource{
source: source,
cache: cache,
cfg: cfg,
persister: persister,
configMode: configMode,
}
}
// Token fetches a token from the cache of configuration if present otherwise
// acquires a new token from the configured source. Automatically refreshes
// the token if expired.
func (ts *azureTokenSource) Token() (*azureToken, error) {
ts.lock.Lock()
defer ts.lock.Unlock()
var err error
token := ts.cache.getToken(azureTokenKey)
if token != nil && !token.token.IsExpired() {
return token, nil
}
// retrieve from config if no cache
if token == nil {
tokenFromCfg, err := ts.retrieveTokenFromCfg()
if err == nil {
token = tokenFromCfg
}
}
if token != nil {
// cache and return if the token is as good
// avoids frequent persistor calls
if !token.token.IsExpired() {
ts.cache.setToken(azureTokenKey, token)
return token, nil
}
klog.V(4).Info("Refreshing token.")
tokenFromRefresh, err := ts.Refresh(token)
switch {
case err == nil:
token = tokenFromRefresh
case autorest.IsTokenRefreshError(err):
klog.V(4).Infof("Failed to refresh expired token, proceed to auth: %v", err)
// reset token to nil so that the token source will be used to acquire new
token = nil
default:
return nil, fmt.Errorf("unexpected error when refreshing token: %v", err)
}
}
if token == nil {
tokenFromSource, err := ts.source.Token()
if err != nil {
return nil, fmt.Errorf("failed acquiring new token: %v", err)
}
token = tokenFromSource
}
// sanity check
if token == nil {
return nil, fmt.Errorf("unable to acquire token")
}
// corner condition, newly got token is valid but expired
if token.token.IsExpired() {
return nil, fmt.Errorf("newly acquired token is expired")
}
err = ts.storeTokenInCfg(token)
if err != nil {
return nil, fmt.Errorf("storing the refreshed token in configuration: %v", err)
}
ts.cache.setToken(azureTokenKey, token)
return token, nil
}
func (ts *azureTokenSource) retrieveTokenFromCfg() (*azureToken, error) {
accessToken := ts.cfg[cfgAccessToken]
if accessToken == "" {
return nil, fmt.Errorf("no access token in cfg: %s", cfgAccessToken)
}
refreshToken := ts.cfg[cfgRefreshToken]
if refreshToken == "" {
return nil, fmt.Errorf("no refresh token in cfg: %s", cfgRefreshToken)
}
environment := ts.cfg[cfgEnvironment]
if environment == "" {
return nil, fmt.Errorf("no environment in cfg: %s", cfgEnvironment)
}
clientID := ts.cfg[cfgClientID]
if clientID == "" {
return nil, fmt.Errorf("no client ID in cfg: %s", cfgClientID)
}
tenantID := ts.cfg[cfgTenantID]
if tenantID == "" {
return nil, fmt.Errorf("no tenant ID in cfg: %s", cfgTenantID)
}
resourceID := ts.cfg[cfgApiserverID]
if resourceID == "" {
return nil, fmt.Errorf("no apiserver ID in cfg: %s", cfgApiserverID)
}
expiresIn := ts.cfg[cfgExpiresIn]
if expiresIn == "" {
return nil, fmt.Errorf("no expiresIn in cfg: %s", cfgExpiresIn)
}
expiresOn := ts.cfg[cfgExpiresOn]
if expiresOn == "" {
return nil, fmt.Errorf("no expiresOn in cfg: %s", cfgExpiresOn)
}
tokenAudience := resourceID
if ts.configMode == configModeDefault {
tokenAudience = fmt.Sprintf("spn:%s", resourceID)
}
return &azureToken{
token: adal.Token{
AccessToken: accessToken,
RefreshToken: refreshToken,
ExpiresIn: json.Number(expiresIn),
ExpiresOn: json.Number(expiresOn),
NotBefore: json.Number(expiresOn),
Resource: tokenAudience,
Type: tokenType,
},
environment: environment,
clientID: clientID,
tenantID: tenantID,
apiserverID: resourceID,
}, nil
}
func (ts *azureTokenSource) storeTokenInCfg(token *azureToken) error {
newCfg := make(map[string]string)
newCfg[cfgAccessToken] = token.token.AccessToken
newCfg[cfgRefreshToken] = token.token.RefreshToken
newCfg[cfgEnvironment] = token.environment
newCfg[cfgClientID] = token.clientID
newCfg[cfgTenantID] = token.tenantID
newCfg[cfgApiserverID] = token.apiserverID
newCfg[cfgExpiresIn] = string(token.token.ExpiresIn)
newCfg[cfgExpiresOn] = string(token.token.ExpiresOn)
newCfg[cfgConfigMode] = strconv.Itoa(int(ts.configMode))
err := ts.persister.Persist(newCfg)
if err != nil {
return fmt.Errorf("persisting the configuration: %v", err)
}
ts.cfg = newCfg
return nil
}
func (ts *azureTokenSource) Refresh(token *azureToken) (*azureToken, error) {
return ts.source.Refresh(token)
}
// refresh outdated token with adal.
func (ts *azureTokenSourceDeviceCode) Refresh(token *azureToken) (*azureToken, error) {
env, err := azure.EnvironmentFromName(token.environment)
if err != nil {
return nil, err
}
var oauthConfig *adal.OAuthConfig
if ts.configMode == configModeOmitSPNPrefix {
oauthConfig, err = adal.NewOAuthConfigWithAPIVersion(env.ActiveDirectoryEndpoint, token.tenantID, nil)
if err != nil {
return nil, fmt.Errorf("building the OAuth configuration without api-version for token refresh: %v", err)
}
} else {
oauthConfig, err = adal.NewOAuthConfig(env.ActiveDirectoryEndpoint, token.tenantID)
if err != nil {
return nil, fmt.Errorf("building the OAuth configuration for token refresh: %v", err)
}
}
callback := func(t adal.Token) error {
return nil
}
spt, err := adal.NewServicePrincipalTokenFromManualToken(
*oauthConfig,
token.clientID,
token.apiserverID,
token.token,
callback)
if err != nil {
return nil, fmt.Errorf("creating new service principal for token refresh: %v", err)
}
if err := spt.Refresh(); err != nil {
// Caller expects IsTokenRefreshError(err) to trigger prompt.
return nil, fmt.Errorf("refreshing token: %w", err)
}
return &azureToken{
token: spt.Token(),
environment: token.environment,
clientID: token.clientID,
tenantID: token.tenantID,
apiserverID: token.apiserverID,
}, nil
}
type azureTokenSourceDeviceCode struct {
environment azure.Environment
clientID string
tenantID string
apiserverID string
configMode configMode
}
func newAzureTokenSourceDeviceCode(environment azure.Environment, clientID string, tenantID string, apiserverID string, configMode configMode) (tokenSource, error) {
if clientID == "" {
return nil, errors.New("client-id is empty")
}
if tenantID == "" {
return nil, errors.New("tenant-id is empty")
}
if apiserverID == "" {
return nil, errors.New("apiserver-id is empty")
}
return &azureTokenSourceDeviceCode{
environment: environment,
clientID: clientID,
tenantID: tenantID,
apiserverID: apiserverID,
configMode: configMode,
}, nil
}
func (ts *azureTokenSourceDeviceCode) Token() (*azureToken, error) {
var (
oauthConfig *adal.OAuthConfig
err error
)
if ts.configMode == configModeOmitSPNPrefix {
oauthConfig, err = adal.NewOAuthConfigWithAPIVersion(ts.environment.ActiveDirectoryEndpoint, ts.tenantID, nil)
if err != nil {
return nil, fmt.Errorf("building the OAuth configuration without api-version for device code authentication: %v", err)
}
} else {
oauthConfig, err = adal.NewOAuthConfig(ts.environment.ActiveDirectoryEndpoint, ts.tenantID)
if err != nil {
return nil, fmt.Errorf("building the OAuth configuration for device code authentication: %v", err)
}
}
client := &autorest.Client{}
deviceCode, err := adal.InitiateDeviceAuth(client, *oauthConfig, ts.clientID, ts.apiserverID)
if err != nil {
return nil, fmt.Errorf("initialing the device code authentication: %v", err)
}
_, err = fmt.Fprintln(os.Stderr, *deviceCode.Message)
if err != nil {
return nil, fmt.Errorf("prompting the device code message: %v", err)
}
token, err := adal.WaitForUserCompletion(client, deviceCode)
if err != nil {
return nil, fmt.Errorf("waiting for device code authentication to complete: %v", err)
}
return &azureToken{
token: *token,
environment: ts.environment.Name,
clientID: ts.clientID,
tenantID: ts.tenantID,
apiserverID: ts.apiserverID,
}, nil
}

View File

@@ -1,36 +0,0 @@
/*
Copyright 2022 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package azure
import (
"errors"
"k8s.io/client-go/rest"
"k8s.io/klog/v2"
)
func init() {
if err := rest.RegisterAuthProviderPlugin("azure", newAzureAuthProvider); err != nil {
klog.Fatalf("Failed to register azure auth plugin: %v", err)
}
}
func newAzureAuthProvider(_ string, _ map[string]string, _ rest.AuthProviderConfigPersister) (rest.AuthProvider, error) {
return nil, errors.New(`The azure auth plugin has been removed.
Please use the https://github.com/Azure/kubelogin kubectl/client-go credential plugin instead.
See https://kubernetes.io/docs/reference/access-authn-authz/authentication/#client-go-credential-plugins for further details`)
}

View File

@@ -0,0 +1,534 @@
/*
Copyright 2017 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package azure
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"strconv"
"strings"
"sync"
"testing"
"time"
"github.com/Azure/go-autorest/autorest/adal"
"github.com/Azure/go-autorest/autorest/azure"
)
func TestAzureAuthProvider(t *testing.T) {
t.Run("validate against invalid configurations", func(t *testing.T) {
vectors := []struct {
cfg map[string]string
expectedError string
}{
{
cfg: map[string]string{
cfgClientID: "foo",
cfgApiserverID: "foo",
cfgTenantID: "foo",
cfgConfigMode: "-1",
},
expectedError: "config-mode:-1 is not a valid mode",
},
{
cfg: map[string]string{
cfgClientID: "foo",
cfgApiserverID: "foo",
cfgTenantID: "foo",
cfgConfigMode: "2",
},
expectedError: "config-mode:2 is not a valid mode",
},
{
cfg: map[string]string{
cfgClientID: "foo",
cfgApiserverID: "foo",
cfgTenantID: "foo",
cfgConfigMode: "foo",
},
expectedError: "failed to parse config-mode, error: strconv.Atoi: parsing \"foo\": invalid syntax",
},
}
for _, v := range vectors {
persister := &fakePersister{}
_, err := newAzureAuthProvider("", v.cfg, persister)
if !strings.Contains(err.Error(), v.expectedError) {
t.Errorf("cfg %v should fail with message containing '%s'. actual: '%s'", v.cfg, v.expectedError, err)
}
}
})
t.Run("it should return non-nil provider in happy cases", func(t *testing.T) {
vectors := []struct {
cfg map[string]string
expectedConfigMode configMode
}{
{
cfg: map[string]string{
cfgClientID: "foo",
cfgApiserverID: "foo",
cfgTenantID: "foo",
},
expectedConfigMode: configModeDefault,
},
{
cfg: map[string]string{
cfgClientID: "foo",
cfgApiserverID: "foo",
cfgTenantID: "foo",
cfgConfigMode: "0",
},
expectedConfigMode: configModeDefault,
},
{
cfg: map[string]string{
cfgClientID: "foo",
cfgApiserverID: "foo",
cfgTenantID: "foo",
cfgConfigMode: "1",
},
expectedConfigMode: configModeOmitSPNPrefix,
},
}
for _, v := range vectors {
persister := &fakePersister{}
provider, err := newAzureAuthProvider("", v.cfg, persister)
if err != nil {
t.Errorf("newAzureAuthProvider should not fail with '%s'", err)
}
if provider == nil {
t.Fatalf("newAzureAuthProvider should return non-nil provider")
}
azureProvider := provider.(*azureAuthProvider)
if azureProvider == nil {
t.Fatalf("newAzureAuthProvider should return an instance of type azureAuthProvider")
}
ts := azureProvider.tokenSource.(*azureTokenSource)
if ts == nil {
t.Fatalf("azureAuthProvider should be an instance of azureTokenSource")
}
if ts.configMode != v.expectedConfigMode {
t.Errorf("expected configMode: %d, actual: %d", v.expectedConfigMode, ts.configMode)
}
}
})
}
func TestTokenSourceDeviceCode(t *testing.T) {
var (
clientID = "clientID"
tenantID = "tenantID"
apiserverID = "apiserverID"
configMode = configModeDefault
azureEnv = azure.Environment{}
)
t.Run("validate to create azureTokenSourceDeviceCode", func(t *testing.T) {
if _, err := newAzureTokenSourceDeviceCode(azureEnv, clientID, tenantID, apiserverID, configModeDefault); err != nil {
t.Errorf("newAzureTokenSourceDeviceCode should not have failed. err: %s", err)
}
if _, err := newAzureTokenSourceDeviceCode(azureEnv, clientID, tenantID, apiserverID, configModeOmitSPNPrefix); err != nil {
t.Errorf("newAzureTokenSourceDeviceCode should not have failed. err: %s", err)
}
_, err := newAzureTokenSourceDeviceCode(azureEnv, "", tenantID, apiserverID, configMode)
actual := "client-id is empty"
if err.Error() != actual {
t.Errorf("newAzureTokenSourceDeviceCode should have failed. expected: %s, actual: %s", actual, err)
}
_, err = newAzureTokenSourceDeviceCode(azureEnv, clientID, "", apiserverID, configMode)
actual = "tenant-id is empty"
if err.Error() != actual {
t.Errorf("newAzureTokenSourceDeviceCode should have failed. expected: %s, actual: %s", actual, err)
}
_, err = newAzureTokenSourceDeviceCode(azureEnv, clientID, tenantID, "", configMode)
actual = "apiserver-id is empty"
if err.Error() != actual {
t.Errorf("newAzureTokenSourceDeviceCode should have failed. expected: %s, actual: %s", actual, err)
}
})
}
func TestAzureTokenSource(t *testing.T) {
configModes := []configMode{configModeOmitSPNPrefix, configModeDefault}
expectedConfigModes := []string{"1", "0"}
for i, configMode := range configModes {
t.Run(fmt.Sprintf("validate token from cfg with configMode %v", configMode), func(t *testing.T) {
const (
serverID = "fakeServerID"
clientID = "fakeClientID"
tenantID = "fakeTenantID"
accessToken = "fakeToken"
environment = "fakeEnvironment"
refreshToken = "fakeToken"
expiresIn = "foo"
expiresOn = "foo"
)
cfg := map[string]string{
cfgConfigMode: strconv.Itoa(int(configMode)),
cfgApiserverID: serverID,
cfgClientID: clientID,
cfgTenantID: tenantID,
cfgEnvironment: environment,
cfgAccessToken: accessToken,
cfgRefreshToken: refreshToken,
cfgExpiresIn: expiresIn,
cfgExpiresOn: expiresOn,
}
fakeSource := fakeTokenSource{token: newFakeAzureToken("fakeToken", time.Now().Add(3600*time.Second))}
persiter := &fakePersister{cache: make(map[string]string)}
tokenCache := newAzureTokenCache()
tokenSource := newAzureTokenSource(&fakeSource, tokenCache, cfg, configMode, persiter)
azTokenSource := tokenSource.(*azureTokenSource)
token, err := azTokenSource.retrieveTokenFromCfg()
if err != nil {
t.Errorf("failed to retrieve the token form cfg: %s", err)
}
if token.apiserverID != serverID {
t.Errorf("expecting token.apiserverID: %s, actual: %s", serverID, token.apiserverID)
}
if token.clientID != clientID {
t.Errorf("expecting token.clientID: %s, actual: %s", clientID, token.clientID)
}
if token.tenantID != tenantID {
t.Errorf("expecting token.tenantID: %s, actual: %s", tenantID, token.tenantID)
}
expectedAudience := serverID
if configMode == configModeDefault {
expectedAudience = fmt.Sprintf("spn:%s", serverID)
}
if token.token.Resource != expectedAudience {
t.Errorf("expecting adal token.Resource: %s, actual: %s", expectedAudience, token.token.Resource)
}
})
t.Run("validate token against cache", func(t *testing.T) {
fakeAccessToken := "fake token 1"
fakeSource := fakeTokenSource{token: newFakeAzureToken(fakeAccessToken, time.Now().Add(3600*time.Second))}
cfg := make(map[string]string)
persiter := &fakePersister{cache: make(map[string]string)}
tokenCache := newAzureTokenCache()
tokenSource := newAzureTokenSource(&fakeSource, tokenCache, cfg, configMode, persiter)
token, err := tokenSource.Token()
if err != nil {
t.Errorf("failed to retrieve the token form cache: %v", err)
}
wantCacheLen := 1
if len(tokenCache.cache) != wantCacheLen {
t.Errorf("Token() cache length error: got %v, want %v", len(tokenCache.cache), wantCacheLen)
}
if token != tokenCache.cache[azureTokenKey] {
t.Error("Token() returned token != cached token")
}
wantCfg := token2Cfg(token)
wantCfg[cfgConfigMode] = expectedConfigModes[i]
persistedCfg := persiter.Cache()
wantCfgLen := len(wantCfg)
persistedCfgLen := len(persistedCfg)
if wantCfgLen != persistedCfgLen {
t.Errorf("wantCfgLen and persistedCfgLen do not match, wantCfgLen=%v, persistedCfgLen=%v", wantCfgLen, persistedCfgLen)
}
for k, v := range persistedCfg {
if strings.Compare(v, wantCfg[k]) != 0 {
t.Errorf("Token() persisted cfg %s: got %v, want %v", k, v, wantCfg[k])
}
}
fakeSource.token = newFakeAzureToken("fake token 2", time.Now().Add(3600*time.Second))
token, err = tokenSource.Token()
if err != nil {
t.Errorf("failed to retrieve the cached token: %v", err)
}
if token.token.AccessToken != fakeAccessToken {
t.Errorf("Token() didn't return the cached token")
}
})
}
}
func TestAzureTokenSourceScenarios(t *testing.T) {
expiredToken := newFakeAzureToken("expired token", time.Now().Add(-time.Second))
extendedToken := newFakeAzureToken("extend token", time.Now().Add(1000*time.Second))
fakeToken := newFakeAzureToken("fake token", time.Now().Add(1000*time.Second))
wrongToken := newFakeAzureToken("wrong token", time.Now().Add(1000*time.Second))
tests := []struct {
name string
sourceToken *azureToken
refreshToken *azureToken
cachedToken *azureToken
configToken *azureToken
expectToken *azureToken
tokenErr error
refreshErr error
expectErr string
tokenCalls uint
refreshCalls uint
persistCalls uint
}{
{
name: "new config",
sourceToken: fakeToken,
expectToken: fakeToken,
tokenCalls: 1,
persistCalls: 1,
},
{
name: "load token from cache",
sourceToken: wrongToken,
cachedToken: fakeToken,
configToken: wrongToken,
expectToken: fakeToken,
},
{
name: "load token from config",
sourceToken: wrongToken,
configToken: fakeToken,
expectToken: fakeToken,
},
{
name: "cached token timeout, extend success, config token should never load",
cachedToken: expiredToken,
refreshToken: extendedToken,
configToken: wrongToken,
expectToken: extendedToken,
refreshCalls: 1,
persistCalls: 1,
},
{
name: "config token timeout, extend failure, acquire new token",
configToken: expiredToken,
refreshErr: fakeTokenRefreshError{message: "FakeError happened when refreshing"},
sourceToken: fakeToken,
expectToken: fakeToken,
refreshCalls: 1,
tokenCalls: 1,
persistCalls: 1,
},
{
name: "extend failure with fmt.Errorf nested tokenRefreshError",
configToken: expiredToken,
refreshErr: fmt.Errorf("refreshing token: %w", fakeTokenRefreshError{message: "nested FakeError happened when refreshing"}),
sourceToken: fakeToken,
expectToken: fakeToken,
refreshCalls: 1,
tokenCalls: 1,
persistCalls: 1,
},
{
name: "unexpected error when extend",
configToken: expiredToken,
refreshErr: errors.New("unexpected refresh error"),
sourceToken: fakeToken,
expectErr: "unexpected refresh error",
refreshCalls: 1,
},
{
name: "token error",
tokenErr: errors.New("tokenerr"),
expectErr: "tokenerr",
tokenCalls: 1,
},
{
name: "Token() got expired token",
sourceToken: expiredToken,
expectErr: "newly acquired token is expired",
tokenCalls: 1,
},
{
name: "Token() got nil but no error",
sourceToken: nil,
expectErr: "unable to acquire token",
tokenCalls: 1,
},
}
for _, tc := range tests {
configModes := []configMode{configModeOmitSPNPrefix, configModeDefault}
for _, configMode := range configModes {
t.Run(fmt.Sprintf("%s with configMode: %v", tc.name, configMode), func(t *testing.T) {
persister := newFakePersister()
cfg := map[string]string{
cfgConfigMode: strconv.Itoa(int(configMode)),
}
if tc.configToken != nil {
cfg = token2Cfg(tc.configToken)
}
tokenCache := newAzureTokenCache()
if tc.cachedToken != nil {
tokenCache.setToken(azureTokenKey, tc.cachedToken)
}
fakeSource := fakeTokenSource{
token: tc.sourceToken,
tokenErr: tc.tokenErr,
refreshToken: tc.refreshToken,
refreshErr: tc.refreshErr,
}
tokenSource := newAzureTokenSource(&fakeSource, tokenCache, cfg, configMode, &persister)
token, err := tokenSource.Token()
if token != nil && fakeSource.token != nil && token.apiserverID != fakeSource.token.apiserverID {
t.Errorf("expecting apiservierID: %s, got: %s", fakeSource.token.apiserverID, token.apiserverID)
}
if fakeSource.tokenCalls != tc.tokenCalls {
t.Errorf("expecting tokenCalls: %v, got: %v", tc.tokenCalls, fakeSource.tokenCalls)
}
if fakeSource.refreshCalls != tc.refreshCalls {
t.Errorf("expecting refreshCalls: %v, got: %v", tc.refreshCalls, fakeSource.refreshCalls)
}
if persister.calls != tc.persistCalls {
t.Errorf("expecting persister calls: %v, got: %v", tc.persistCalls, persister.calls)
}
if tc.expectErr != "" {
if !strings.Contains(err.Error(), tc.expectErr) {
t.Errorf("expecting error %v, got %v", tc.expectErr, err)
}
if token != nil {
t.Errorf("token should be nil in err situation, got %v", token)
}
} else {
if err != nil {
t.Fatalf("error should be nil, got %v", err)
}
if token.token.AccessToken != tc.expectToken.token.AccessToken {
t.Errorf("token should have accessToken %v, got %v", token.token.AccessToken, tc.expectToken.token.AccessToken)
}
}
})
}
}
}
type fakePersister struct {
lock sync.Mutex
cache map[string]string
calls uint
}
func newFakePersister() fakePersister {
return fakePersister{cache: make(map[string]string), calls: 0}
}
func (p *fakePersister) Persist(cache map[string]string) error {
p.lock.Lock()
defer p.lock.Unlock()
p.calls++
p.cache = map[string]string{}
for k, v := range cache {
p.cache[k] = v
}
return nil
}
func (p *fakePersister) Cache() map[string]string {
ret := map[string]string{}
p.lock.Lock()
defer p.lock.Unlock()
for k, v := range p.cache {
ret[k] = v
}
return ret
}
// a simple token source simply always returns the token property
type fakeTokenSource struct {
token *azureToken
tokenCalls uint
tokenErr error
refreshToken *azureToken
refreshCalls uint
refreshErr error
}
func (ts *fakeTokenSource) Token() (*azureToken, error) {
ts.tokenCalls++
return ts.token, ts.tokenErr
}
func (ts *fakeTokenSource) Refresh(*azureToken) (*azureToken, error) {
ts.refreshCalls++
return ts.refreshToken, ts.refreshErr
}
func token2Cfg(token *azureToken) map[string]string {
cfg := make(map[string]string)
cfg[cfgAccessToken] = token.token.AccessToken
cfg[cfgRefreshToken] = token.token.RefreshToken
cfg[cfgEnvironment] = token.environment
cfg[cfgClientID] = token.clientID
cfg[cfgTenantID] = token.tenantID
cfg[cfgApiserverID] = token.apiserverID
cfg[cfgExpiresIn] = string(token.token.ExpiresIn)
cfg[cfgExpiresOn] = string(token.token.ExpiresOn)
return cfg
}
func newFakeAzureToken(accessToken string, expiresOnTime time.Time) *azureToken {
return &azureToken{
token: newFakeADALToken(accessToken, strconv.FormatInt(expiresOnTime.Unix(), 10)),
environment: "testenv",
clientID: "fake",
tenantID: "fake",
apiserverID: "fake",
}
}
func newFakeADALToken(accessToken string, expiresOn string) adal.Token {
return adal.Token{
AccessToken: accessToken,
RefreshToken: "fake",
ExpiresIn: "3600",
ExpiresOn: json.Number(expiresOn),
NotBefore: json.Number(expiresOn),
Resource: "fake",
Type: "fake",
}
}
// copied from go-autorest/adal
type fakeTokenRefreshError struct {
message string
resp *http.Response
}
// Error implements the error interface which is part of the TokenRefreshError interface.
func (tre fakeTokenRefreshError) Error() string {
return tre.message
}
// Response implements the TokenRefreshError interface, it returns the raw HTTP response from the refresh operation.
func (tre fakeTokenRefreshError) Response() *http.Response {
return tre.resp
}