Azure auth fallback to real auth if refresh token fails, refactor and add more tests.

Signed-off-by: Ping He <tdihp@hotmail.com>

Kubernetes-commit: 26c97fa1b40a7939ca26084c819af4794df34406
This commit is contained in:
Ping He 2020-03-22 17:04:20 +08:00 committed by Kubernetes Publisher
parent 70eb484951
commit a6c937f0cd
2 changed files with 252 additions and 35 deletions

View File

@ -180,6 +180,7 @@ type azureToken struct {
type tokenSource interface {
Token() (*azureToken, error)
Refresh(*azureToken) (*azureToken, error)
}
type azureTokenSource struct {
@ -210,33 +211,66 @@ func (ts *azureTokenSource) Token() (*azureToken, error) {
var err error
token := ts.cache.getToken(azureTokenKey)
if token != nil && !token.token.IsExpired() {
return token, nil
}
// retrieve from config if no cache
if token == nil {
token, err = ts.retrieveTokenFromCfg()
if err != nil {
token, err = ts.source.Token()
if err != nil {
return nil, fmt.Errorf("acquiring a new fresh token: %v", err)
}
tokenFromCfg, err := ts.retrieveTokenFromCfg()
if err == nil {
token = tokenFromCfg
}
}
if token != nil {
// cache and return if the token is as good
// avoids frequent persistor calls
if !token.token.IsExpired() {
ts.cache.setToken(azureTokenKey, token)
err = ts.storeTokenInCfg(token)
if err != nil {
return nil, fmt.Errorf("storing the token in configuration: %v", err)
}
return token, nil
}
klog.V(4).Info("Refreshing token.")
tokenFromRefresh, err := ts.Refresh(token)
switch {
case err == nil:
token = tokenFromRefresh
case autorest.IsTokenRefreshError(err):
klog.V(4).Infof("Failed to refresh expired token, proceed to auth: %v", err)
// reset token to nil so that the token source will be used to acquire new
token = nil
default:
return nil, fmt.Errorf("unexpected error when refreshing token: %v", err)
}
}
if token == nil {
tokenFromSource, err := ts.source.Token()
if err != nil {
return nil, fmt.Errorf("failed acquiring new token: %v", err)
}
token = tokenFromSource
}
// sanity check
if token == nil {
return nil, fmt.Errorf("unable to acquire token")
}
// corner condition, newly got token is valid but expired
if token.token.IsExpired() {
token, err = ts.refreshToken(token)
if err != nil {
return nil, fmt.Errorf("refreshing the expired token: %v", err)
}
ts.cache.setToken(azureTokenKey, token)
err = ts.storeTokenInCfg(token)
if err != nil {
return nil, fmt.Errorf("storing the refreshed token in configuration: %v", err)
}
return nil, fmt.Errorf("newly acquired token is expired")
}
err = ts.storeTokenInCfg(token)
if err != nil {
return nil, fmt.Errorf("storing the refreshed token in configuration: %v", err)
}
ts.cache.setToken(azureTokenKey, token)
return token, nil
}
@ -314,7 +348,13 @@ func (ts *azureTokenSource) storeTokenInCfg(token *azureToken) error {
return nil
}
func (ts *azureTokenSource) refreshToken(token *azureToken) (*azureToken, error) {
func (ts *azureTokenSource) Refresh(token *azureToken) (*azureToken, error) {
return ts.source.Refresh(token)
}
// refresh outdated token with adal.
// adal.RefreshTokenError will be returned if error occur during refreshing.
func (ts *azureTokenSourceDeviceCode) Refresh(token *azureToken) (*azureToken, error) {
env, err := azure.EnvironmentFromName(token.environment)
if err != nil {
return nil, err

View File

@ -18,6 +18,8 @@ package azure
import (
"encoding/json"
"errors"
"net/http"
"strconv"
"strings"
"sync"
@ -172,10 +174,7 @@ func TestAzureTokenSource(t *testing.T) {
for i, configMode := range configModes {
t.Run("validate token against cache", func(t *testing.T) {
fakeAccessToken := "fake token 1"
fakeSource := fakeTokenSource{
accessToken: fakeAccessToken,
expiresOn: strconv.FormatInt(time.Now().Add(3600*time.Second).Unix(), 10),
}
fakeSource := fakeTokenSource{token: newFakeAzureToken(fakeAccessToken, time.Now().Add(3600*time.Second))}
cfg := make(map[string]string)
persiter := &fakePersister{cache: make(map[string]string)}
tokenCache := newAzureTokenCache()
@ -210,7 +209,7 @@ func TestAzureTokenSource(t *testing.T) {
}
}
fakeSource.accessToken = "fake token 2"
fakeSource.token = newFakeAzureToken("fake token 2", time.Now().Add(3600*time.Second))
token, err = tokenSource.Token()
if err != nil {
t.Errorf("failed to retrieve the cached token: %v", err)
@ -223,14 +222,161 @@ func TestAzureTokenSource(t *testing.T) {
}
}
func TestAzureTokenSourceScenarios(t *testing.T) {
configMode := configModeDefault
expiredToken := newFakeAzureToken("expired token", time.Now().Add(-time.Second))
extendedToken := newFakeAzureToken("extend token", time.Now().Add(1000*time.Second))
fakeToken := newFakeAzureToken("fake token", time.Now().Add(1000*time.Second))
wrongToken := newFakeAzureToken("wrong token", time.Now().Add(1000*time.Second))
tests := []struct {
name string
sourceToken *azureToken
refreshToken *azureToken
cachedToken *azureToken
configToken *azureToken
expectToken *azureToken
tokenErr error
refreshErr error
expectErr string
tokenCalls uint
refreshCalls uint
persistCalls uint
}{
{
name: "new config",
sourceToken: fakeToken,
expectToken: fakeToken,
tokenCalls: 1,
persistCalls: 1,
},
{
name: "load token from cache",
sourceToken: wrongToken,
cachedToken: fakeToken,
configToken: wrongToken,
expectToken: fakeToken,
},
{
name: "load token from config",
sourceToken: wrongToken,
configToken: fakeToken,
expectToken: fakeToken,
},
{
name: "cached token timeout, extend success, config token should never load",
cachedToken: expiredToken,
refreshToken: extendedToken,
configToken: wrongToken,
expectToken: extendedToken,
refreshCalls: 1,
persistCalls: 1,
},
{
name: "config token timeout, extend failure, acquire new token",
configToken: expiredToken,
refreshErr: fakeTokenRefreshError{message: "FakeError happened when refreshing"},
sourceToken: fakeToken,
expectToken: fakeToken,
refreshCalls: 1,
tokenCalls: 1,
persistCalls: 1,
},
{
name: "unexpected error when extend",
configToken: expiredToken,
refreshErr: errors.New("unexpected refresh error"),
sourceToken: fakeToken,
expectErr: "unexpected refresh error",
refreshCalls: 1,
},
{
name: "token error",
tokenErr: errors.New("tokenerr"),
expectErr: "tokenerr",
tokenCalls: 1,
},
{
name: "Token() got expired token",
sourceToken: expiredToken,
expectErr: "newly acquired token is expired",
tokenCalls: 1,
},
{
name: "Token() got nil but no error",
sourceToken: nil,
expectErr: "unable to acquire token",
tokenCalls: 1,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
persister := newFakePersister()
cfg := map[string]string{}
if tc.configToken != nil {
cfg = token2Cfg(tc.configToken)
}
tokenCache := newAzureTokenCache()
if tc.cachedToken != nil {
tokenCache.setToken(azureTokenKey, tc.cachedToken)
}
fakeSource := fakeTokenSource{
token: tc.sourceToken,
tokenErr: tc.tokenErr,
refreshToken: tc.refreshToken,
refreshErr: tc.refreshErr,
}
tokenSource := newAzureTokenSource(&fakeSource, tokenCache, cfg, configMode, &persister)
token, err := tokenSource.Token()
if fakeSource.tokenCalls != tc.tokenCalls {
t.Errorf("expecting tokenCalls: %v, got: %v", tc.tokenCalls, fakeSource.tokenCalls)
}
if fakeSource.refreshCalls != tc.refreshCalls {
t.Errorf("expecting refreshCalls: %v, got: %v", tc.refreshCalls, fakeSource.refreshCalls)
}
if persister.calls != tc.persistCalls {
t.Errorf("expecting persister calls: %v, got: %v", tc.persistCalls, persister.calls)
}
if tc.expectErr != "" {
if !strings.Contains(err.Error(), tc.expectErr) {
t.Errorf("expecting error %v, got %v", tc.expectErr, err)
}
if token != nil {
t.Errorf("token should be nil in err situation, got %v", token)
}
} else {
if err != nil {
t.Fatalf("error should be nil, got %v", err)
}
if token.token.AccessToken != tc.expectToken.token.AccessToken {
t.Errorf("token should have accessToken %v, got %v", token.token.AccessToken, tc.expectToken.token.AccessToken)
}
}
})
}
}
type fakePersister struct {
lock sync.Mutex
cache map[string]string
calls uint
}
func newFakePersister() fakePersister {
return fakePersister{cache: make(map[string]string), calls: 0}
}
func (p *fakePersister) Persist(cache map[string]string) error {
p.lock.Lock()
defer p.lock.Unlock()
p.calls++
p.cache = map[string]string{}
for k, v := range cache {
p.cache[k] = v
@ -248,19 +394,24 @@ func (p *fakePersister) Cache() map[string]string {
return ret
}
// a simple token source simply always returns the token property
type fakeTokenSource struct {
expiresOn string
accessToken string
token *azureToken
tokenCalls uint
tokenErr error
refreshToken *azureToken
refreshCalls uint
refreshErr error
}
func (ts *fakeTokenSource) Token() (*azureToken, error) {
return &azureToken{
token: newFackeAzureToken(ts.accessToken, ts.expiresOn),
environment: "testenv",
clientID: "fake",
tenantID: "fake",
apiserverID: "fake",
}, nil
ts.tokenCalls++
return ts.token, ts.tokenErr
}
func (ts *fakeTokenSource) Refresh(*azureToken) (*azureToken, error) {
ts.refreshCalls++
return ts.refreshToken, ts.refreshErr
}
func token2Cfg(token *azureToken) map[string]string {
@ -276,7 +427,17 @@ func token2Cfg(token *azureToken) map[string]string {
return cfg
}
func newFackeAzureToken(accessToken string, expiresOn string) adal.Token {
func newFakeAzureToken(accessToken string, expiresOnTime time.Time) *azureToken {
return &azureToken{
token: newFakeADALToken(accessToken, strconv.FormatInt(expiresOnTime.Unix(), 10)),
environment: "testenv",
clientID: "fake",
tenantID: "fake",
apiserverID: "fake",
}
}
func newFakeADALToken(accessToken string, expiresOn string) adal.Token {
return adal.Token{
AccessToken: accessToken,
RefreshToken: "fake",
@ -287,3 +448,19 @@ func newFackeAzureToken(accessToken string, expiresOn string) adal.Token {
Type: "fake",
}
}
// copied from go-autorest/adal
type fakeTokenRefreshError struct {
message string
resp *http.Response
}
// Error implements the error interface which is part of the TokenRefreshError interface.
func (tre fakeTokenRefreshError) Error() string {
return tre.message
}
// Response implements the TokenRefreshError interface, it returns the raw HTTP response from the refresh operation.
func (tre fakeTokenRefreshError) Response() *http.Response {
return tre.resp
}