Merge pull request #86481 from tdihp/feature/aad-fallback-real-auth

aad fallback to real auth if refresh token fails, fixes #82776

Kubernetes-commit: e7852bff43b358bcce7c77a352e171eca200c005
This commit is contained in:
Kubernetes Publisher 2020-03-24 14:09:02 -07:00
commit 615fa027f2
5 changed files with 256 additions and 39 deletions

2
Godeps/Godeps.json generated
View File

@ -348,7 +348,7 @@
},
{
"ImportPath": "k8s.io/apimachinery",
"Rev": "48159c651603"
"Rev": "1aec6bc431a9"
},
{
"ImportPath": "k8s.io/gengo",

4
go.mod
View File

@ -28,7 +28,7 @@ require (
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4
google.golang.org/appengine v1.5.0 // indirect
k8s.io/api v0.0.0-20200320042356-1fc28ea2498c
k8s.io/apimachinery v0.0.0-20200320122144-48159c651603
k8s.io/apimachinery v0.0.0-20200324202305-1aec6bc431a9
k8s.io/klog v1.0.0
k8s.io/utils v0.0.0-20200322164244-327a8059b905
sigs.k8s.io/yaml v1.2.0
@ -38,5 +38,5 @@ replace (
golang.org/x/sys => golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a // pinned to release-branch.go1.13
golang.org/x/tools => golang.org/x/tools v0.0.0-20190821162956-65e3620a7ae7 // pinned to release-branch.go1.13
k8s.io/api => k8s.io/api v0.0.0-20200320042356-1fc28ea2498c
k8s.io/apimachinery => k8s.io/apimachinery v0.0.0-20200320122144-48159c651603
k8s.io/apimachinery => k8s.io/apimachinery v0.0.0-20200324202305-1aec6bc431a9
)

2
go.sum
View File

@ -188,7 +188,7 @@ gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
k8s.io/api v0.0.0-20200320042356-1fc28ea2498c/go.mod h1:5nMyHS4bWX496fulniJ+Sws3P6GLvaP43GadMObLf58=
k8s.io/apimachinery v0.0.0-20200320122144-48159c651603/go.mod h1:yKN3QjQfKl8UdUL9RQ+/1VkR7nIUs7w02zC5CXhD+G0=
k8s.io/apimachinery v0.0.0-20200324202305-1aec6bc431a9/go.mod h1:yKN3QjQfKl8UdUL9RQ+/1VkR7nIUs7w02zC5CXhD+G0=
k8s.io/gengo v0.0.0-20190128074634-0689ccc1d7d6/go.mod h1:ezvh/TsK7cY6rbqRK0oQQ8IAqLxYwwyPxAX1Pzy0ii0=
k8s.io/klog v0.0.0-20181102134211-b9b56d5dfc92/go.mod h1:Gq+BEi5rUBO/HRz0bTSXDUcqjScdoY3a9IHpCEIOOfk=
k8s.io/klog v0.3.0/go.mod h1:Gq+BEi5rUBO/HRz0bTSXDUcqjScdoY3a9IHpCEIOOfk=

View File

@ -180,6 +180,7 @@ type azureToken struct {
type tokenSource interface {
Token() (*azureToken, error)
Refresh(*azureToken) (*azureToken, error)
}
type azureTokenSource struct {
@ -210,33 +211,66 @@ func (ts *azureTokenSource) Token() (*azureToken, error) {
var err error
token := ts.cache.getToken(azureTokenKey)
if token != nil && !token.token.IsExpired() {
return token, nil
}
// retrieve from config if no cache
if token == nil {
token, err = ts.retrieveTokenFromCfg()
if err != nil {
token, err = ts.source.Token()
if err != nil {
return nil, fmt.Errorf("acquiring a new fresh token: %v", err)
}
tokenFromCfg, err := ts.retrieveTokenFromCfg()
if err == nil {
token = tokenFromCfg
}
}
if token != nil {
// cache and return if the token is as good
// avoids frequent persistor calls
if !token.token.IsExpired() {
ts.cache.setToken(azureTokenKey, token)
err = ts.storeTokenInCfg(token)
if err != nil {
return nil, fmt.Errorf("storing the token in configuration: %v", err)
}
return token, nil
}
klog.V(4).Info("Refreshing token.")
tokenFromRefresh, err := ts.Refresh(token)
switch {
case err == nil:
token = tokenFromRefresh
case autorest.IsTokenRefreshError(err):
klog.V(4).Infof("Failed to refresh expired token, proceed to auth: %v", err)
// reset token to nil so that the token source will be used to acquire new
token = nil
default:
return nil, fmt.Errorf("unexpected error when refreshing token: %v", err)
}
}
if token == nil {
tokenFromSource, err := ts.source.Token()
if err != nil {
return nil, fmt.Errorf("failed acquiring new token: %v", err)
}
token = tokenFromSource
}
// sanity check
if token == nil {
return nil, fmt.Errorf("unable to acquire token")
}
// corner condition, newly got token is valid but expired
if token.token.IsExpired() {
token, err = ts.refreshToken(token)
if err != nil {
return nil, fmt.Errorf("refreshing the expired token: %v", err)
}
ts.cache.setToken(azureTokenKey, token)
err = ts.storeTokenInCfg(token)
if err != nil {
return nil, fmt.Errorf("storing the refreshed token in configuration: %v", err)
}
return nil, fmt.Errorf("newly acquired token is expired")
}
err = ts.storeTokenInCfg(token)
if err != nil {
return nil, fmt.Errorf("storing the refreshed token in configuration: %v", err)
}
ts.cache.setToken(azureTokenKey, token)
return token, nil
}
@ -314,7 +348,13 @@ func (ts *azureTokenSource) storeTokenInCfg(token *azureToken) error {
return nil
}
func (ts *azureTokenSource) refreshToken(token *azureToken) (*azureToken, error) {
func (ts *azureTokenSource) Refresh(token *azureToken) (*azureToken, error) {
return ts.source.Refresh(token)
}
// refresh outdated token with adal.
// adal.RefreshTokenError will be returned if error occur during refreshing.
func (ts *azureTokenSourceDeviceCode) Refresh(token *azureToken) (*azureToken, error) {
env, err := azure.EnvironmentFromName(token.environment)
if err != nil {
return nil, err

View File

@ -18,6 +18,8 @@ package azure
import (
"encoding/json"
"errors"
"net/http"
"strconv"
"strings"
"sync"
@ -172,10 +174,7 @@ func TestAzureTokenSource(t *testing.T) {
for i, configMode := range configModes {
t.Run("validate token against cache", func(t *testing.T) {
fakeAccessToken := "fake token 1"
fakeSource := fakeTokenSource{
accessToken: fakeAccessToken,
expiresOn: strconv.FormatInt(time.Now().Add(3600*time.Second).Unix(), 10),
}
fakeSource := fakeTokenSource{token: newFakeAzureToken(fakeAccessToken, time.Now().Add(3600*time.Second))}
cfg := make(map[string]string)
persiter := &fakePersister{cache: make(map[string]string)}
tokenCache := newAzureTokenCache()
@ -210,7 +209,7 @@ func TestAzureTokenSource(t *testing.T) {
}
}
fakeSource.accessToken = "fake token 2"
fakeSource.token = newFakeAzureToken("fake token 2", time.Now().Add(3600*time.Second))
token, err = tokenSource.Token()
if err != nil {
t.Errorf("failed to retrieve the cached token: %v", err)
@ -223,14 +222,161 @@ func TestAzureTokenSource(t *testing.T) {
}
}
func TestAzureTokenSourceScenarios(t *testing.T) {
configMode := configModeDefault
expiredToken := newFakeAzureToken("expired token", time.Now().Add(-time.Second))
extendedToken := newFakeAzureToken("extend token", time.Now().Add(1000*time.Second))
fakeToken := newFakeAzureToken("fake token", time.Now().Add(1000*time.Second))
wrongToken := newFakeAzureToken("wrong token", time.Now().Add(1000*time.Second))
tests := []struct {
name string
sourceToken *azureToken
refreshToken *azureToken
cachedToken *azureToken
configToken *azureToken
expectToken *azureToken
tokenErr error
refreshErr error
expectErr string
tokenCalls uint
refreshCalls uint
persistCalls uint
}{
{
name: "new config",
sourceToken: fakeToken,
expectToken: fakeToken,
tokenCalls: 1,
persistCalls: 1,
},
{
name: "load token from cache",
sourceToken: wrongToken,
cachedToken: fakeToken,
configToken: wrongToken,
expectToken: fakeToken,
},
{
name: "load token from config",
sourceToken: wrongToken,
configToken: fakeToken,
expectToken: fakeToken,
},
{
name: "cached token timeout, extend success, config token should never load",
cachedToken: expiredToken,
refreshToken: extendedToken,
configToken: wrongToken,
expectToken: extendedToken,
refreshCalls: 1,
persistCalls: 1,
},
{
name: "config token timeout, extend failure, acquire new token",
configToken: expiredToken,
refreshErr: fakeTokenRefreshError{message: "FakeError happened when refreshing"},
sourceToken: fakeToken,
expectToken: fakeToken,
refreshCalls: 1,
tokenCalls: 1,
persistCalls: 1,
},
{
name: "unexpected error when extend",
configToken: expiredToken,
refreshErr: errors.New("unexpected refresh error"),
sourceToken: fakeToken,
expectErr: "unexpected refresh error",
refreshCalls: 1,
},
{
name: "token error",
tokenErr: errors.New("tokenerr"),
expectErr: "tokenerr",
tokenCalls: 1,
},
{
name: "Token() got expired token",
sourceToken: expiredToken,
expectErr: "newly acquired token is expired",
tokenCalls: 1,
},
{
name: "Token() got nil but no error",
sourceToken: nil,
expectErr: "unable to acquire token",
tokenCalls: 1,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
persister := newFakePersister()
cfg := map[string]string{}
if tc.configToken != nil {
cfg = token2Cfg(tc.configToken)
}
tokenCache := newAzureTokenCache()
if tc.cachedToken != nil {
tokenCache.setToken(azureTokenKey, tc.cachedToken)
}
fakeSource := fakeTokenSource{
token: tc.sourceToken,
tokenErr: tc.tokenErr,
refreshToken: tc.refreshToken,
refreshErr: tc.refreshErr,
}
tokenSource := newAzureTokenSource(&fakeSource, tokenCache, cfg, configMode, &persister)
token, err := tokenSource.Token()
if fakeSource.tokenCalls != tc.tokenCalls {
t.Errorf("expecting tokenCalls: %v, got: %v", tc.tokenCalls, fakeSource.tokenCalls)
}
if fakeSource.refreshCalls != tc.refreshCalls {
t.Errorf("expecting refreshCalls: %v, got: %v", tc.refreshCalls, fakeSource.refreshCalls)
}
if persister.calls != tc.persistCalls {
t.Errorf("expecting persister calls: %v, got: %v", tc.persistCalls, persister.calls)
}
if tc.expectErr != "" {
if !strings.Contains(err.Error(), tc.expectErr) {
t.Errorf("expecting error %v, got %v", tc.expectErr, err)
}
if token != nil {
t.Errorf("token should be nil in err situation, got %v", token)
}
} else {
if err != nil {
t.Fatalf("error should be nil, got %v", err)
}
if token.token.AccessToken != tc.expectToken.token.AccessToken {
t.Errorf("token should have accessToken %v, got %v", token.token.AccessToken, tc.expectToken.token.AccessToken)
}
}
})
}
}
type fakePersister struct {
lock sync.Mutex
cache map[string]string
calls uint
}
func newFakePersister() fakePersister {
return fakePersister{cache: make(map[string]string), calls: 0}
}
func (p *fakePersister) Persist(cache map[string]string) error {
p.lock.Lock()
defer p.lock.Unlock()
p.calls++
p.cache = map[string]string{}
for k, v := range cache {
p.cache[k] = v
@ -248,19 +394,24 @@ func (p *fakePersister) Cache() map[string]string {
return ret
}
// a simple token source simply always returns the token property
type fakeTokenSource struct {
expiresOn string
accessToken string
token *azureToken
tokenCalls uint
tokenErr error
refreshToken *azureToken
refreshCalls uint
refreshErr error
}
func (ts *fakeTokenSource) Token() (*azureToken, error) {
return &azureToken{
token: newFackeAzureToken(ts.accessToken, ts.expiresOn),
environment: "testenv",
clientID: "fake",
tenantID: "fake",
apiserverID: "fake",
}, nil
ts.tokenCalls++
return ts.token, ts.tokenErr
}
func (ts *fakeTokenSource) Refresh(*azureToken) (*azureToken, error) {
ts.refreshCalls++
return ts.refreshToken, ts.refreshErr
}
func token2Cfg(token *azureToken) map[string]string {
@ -276,7 +427,17 @@ func token2Cfg(token *azureToken) map[string]string {
return cfg
}
func newFackeAzureToken(accessToken string, expiresOn string) adal.Token {
func newFakeAzureToken(accessToken string, expiresOnTime time.Time) *azureToken {
return &azureToken{
token: newFakeADALToken(accessToken, strconv.FormatInt(expiresOnTime.Unix(), 10)),
environment: "testenv",
clientID: "fake",
tenantID: "fake",
apiserverID: "fake",
}
}
func newFakeADALToken(accessToken string, expiresOn string) adal.Token {
return adal.Token{
AccessToken: accessToken,
RefreshToken: "fake",
@ -287,3 +448,19 @@ func newFackeAzureToken(accessToken string, expiresOn string) adal.Token {
Type: "fake",
}
}
// copied from go-autorest/adal
type fakeTokenRefreshError struct {
message string
resp *http.Response
}
// Error implements the error interface which is part of the TokenRefreshError interface.
func (tre fakeTokenRefreshError) Error() string {
return tre.message
}
// Response implements the TokenRefreshError interface, it returns the raw HTTP response from the refresh operation.
func (tre fakeTokenRefreshError) Response() *http.Response {
return tre.resp
}