diff --git a/Godeps/Godeps.json b/Godeps/Godeps.json index 01bb2c95..5a4ab2ea 100644 --- a/Godeps/Godeps.json +++ b/Godeps/Godeps.json @@ -348,7 +348,7 @@ }, { "ImportPath": "k8s.io/apimachinery", - "Rev": "48159c651603" + "Rev": "1aec6bc431a9" }, { "ImportPath": "k8s.io/gengo", diff --git a/go.mod b/go.mod index 6060f3fa..4d6410a1 100644 --- a/go.mod +++ b/go.mod @@ -28,7 +28,7 @@ require ( golang.org/x/time v0.0.0-20190308202827-9d24e82272b4 google.golang.org/appengine v1.5.0 // indirect k8s.io/api v0.0.0-20200320042356-1fc28ea2498c - k8s.io/apimachinery v0.0.0-20200320122144-48159c651603 + k8s.io/apimachinery v0.0.0-20200324202305-1aec6bc431a9 k8s.io/klog v1.0.0 k8s.io/utils v0.0.0-20200322164244-327a8059b905 sigs.k8s.io/yaml v1.2.0 @@ -38,5 +38,5 @@ replace ( golang.org/x/sys => golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a // pinned to release-branch.go1.13 golang.org/x/tools => golang.org/x/tools v0.0.0-20190821162956-65e3620a7ae7 // pinned to release-branch.go1.13 k8s.io/api => k8s.io/api v0.0.0-20200320042356-1fc28ea2498c - k8s.io/apimachinery => k8s.io/apimachinery v0.0.0-20200320122144-48159c651603 + k8s.io/apimachinery => k8s.io/apimachinery v0.0.0-20200324202305-1aec6bc431a9 ) diff --git a/go.sum b/go.sum index 8982c8f1..3b04fdae 100644 --- a/go.sum +++ b/go.sum @@ -188,7 +188,7 @@ gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= k8s.io/api v0.0.0-20200320042356-1fc28ea2498c/go.mod h1:5nMyHS4bWX496fulniJ+Sws3P6GLvaP43GadMObLf58= -k8s.io/apimachinery v0.0.0-20200320122144-48159c651603/go.mod h1:yKN3QjQfKl8UdUL9RQ+/1VkR7nIUs7w02zC5CXhD+G0= +k8s.io/apimachinery v0.0.0-20200324202305-1aec6bc431a9/go.mod h1:yKN3QjQfKl8UdUL9RQ+/1VkR7nIUs7w02zC5CXhD+G0= k8s.io/gengo v0.0.0-20190128074634-0689ccc1d7d6/go.mod h1:ezvh/TsK7cY6rbqRK0oQQ8IAqLxYwwyPxAX1Pzy0ii0= k8s.io/klog v0.0.0-20181102134211-b9b56d5dfc92/go.mod h1:Gq+BEi5rUBO/HRz0bTSXDUcqjScdoY3a9IHpCEIOOfk= k8s.io/klog v0.3.0/go.mod h1:Gq+BEi5rUBO/HRz0bTSXDUcqjScdoY3a9IHpCEIOOfk= diff --git a/plugin/pkg/client/auth/azure/azure.go b/plugin/pkg/client/auth/azure/azure.go index 11ef7afb..2746536b 100644 --- a/plugin/pkg/client/auth/azure/azure.go +++ b/plugin/pkg/client/auth/azure/azure.go @@ -180,6 +180,7 @@ type azureToken struct { type tokenSource interface { Token() (*azureToken, error) + Refresh(*azureToken) (*azureToken, error) } type azureTokenSource struct { @@ -210,33 +211,66 @@ func (ts *azureTokenSource) Token() (*azureToken, error) { var err error token := ts.cache.getToken(azureTokenKey) + + if token != nil && !token.token.IsExpired() { + return token, nil + } + + // retrieve from config if no cache if token == nil { - token, err = ts.retrieveTokenFromCfg() - if err != nil { - token, err = ts.source.Token() - if err != nil { - return nil, fmt.Errorf("acquiring a new fresh token: %v", err) - } + tokenFromCfg, err := ts.retrieveTokenFromCfg() + + if err == nil { + token = tokenFromCfg } + } + + if token != nil { + // cache and return if the token is as good + // avoids frequent persistor calls if !token.token.IsExpired() { ts.cache.setToken(azureTokenKey, token) - err = ts.storeTokenInCfg(token) - if err != nil { - return nil, fmt.Errorf("storing the token in configuration: %v", err) - } + return token, nil + } + + klog.V(4).Info("Refreshing token.") + tokenFromRefresh, err := ts.Refresh(token) + switch { + case err == nil: + token = tokenFromRefresh + case autorest.IsTokenRefreshError(err): + klog.V(4).Infof("Failed to refresh expired token, proceed to auth: %v", err) + // reset token to nil so that the token source will be used to acquire new + token = nil + default: + return nil, fmt.Errorf("unexpected error when refreshing token: %v", err) } } + + if token == nil { + tokenFromSource, err := ts.source.Token() + if err != nil { + return nil, fmt.Errorf("failed acquiring new token: %v", err) + } + token = tokenFromSource + } + + // sanity check + if token == nil { + return nil, fmt.Errorf("unable to acquire token") + } + + // corner condition, newly got token is valid but expired if token.token.IsExpired() { - token, err = ts.refreshToken(token) - if err != nil { - return nil, fmt.Errorf("refreshing the expired token: %v", err) - } - ts.cache.setToken(azureTokenKey, token) - err = ts.storeTokenInCfg(token) - if err != nil { - return nil, fmt.Errorf("storing the refreshed token in configuration: %v", err) - } + return nil, fmt.Errorf("newly acquired token is expired") } + + err = ts.storeTokenInCfg(token) + if err != nil { + return nil, fmt.Errorf("storing the refreshed token in configuration: %v", err) + } + ts.cache.setToken(azureTokenKey, token) + return token, nil } @@ -314,7 +348,13 @@ func (ts *azureTokenSource) storeTokenInCfg(token *azureToken) error { return nil } -func (ts *azureTokenSource) refreshToken(token *azureToken) (*azureToken, error) { +func (ts *azureTokenSource) Refresh(token *azureToken) (*azureToken, error) { + return ts.source.Refresh(token) +} + +// refresh outdated token with adal. +// adal.RefreshTokenError will be returned if error occur during refreshing. +func (ts *azureTokenSourceDeviceCode) Refresh(token *azureToken) (*azureToken, error) { env, err := azure.EnvironmentFromName(token.environment) if err != nil { return nil, err diff --git a/plugin/pkg/client/auth/azure/azure_test.go b/plugin/pkg/client/auth/azure/azure_test.go index 75fb6c3a..05612e29 100644 --- a/plugin/pkg/client/auth/azure/azure_test.go +++ b/plugin/pkg/client/auth/azure/azure_test.go @@ -18,6 +18,8 @@ package azure import ( "encoding/json" + "errors" + "net/http" "strconv" "strings" "sync" @@ -172,10 +174,7 @@ func TestAzureTokenSource(t *testing.T) { for i, configMode := range configModes { t.Run("validate token against cache", func(t *testing.T) { fakeAccessToken := "fake token 1" - fakeSource := fakeTokenSource{ - accessToken: fakeAccessToken, - expiresOn: strconv.FormatInt(time.Now().Add(3600*time.Second).Unix(), 10), - } + fakeSource := fakeTokenSource{token: newFakeAzureToken(fakeAccessToken, time.Now().Add(3600*time.Second))} cfg := make(map[string]string) persiter := &fakePersister{cache: make(map[string]string)} tokenCache := newAzureTokenCache() @@ -210,7 +209,7 @@ func TestAzureTokenSource(t *testing.T) { } } - fakeSource.accessToken = "fake token 2" + fakeSource.token = newFakeAzureToken("fake token 2", time.Now().Add(3600*time.Second)) token, err = tokenSource.Token() if err != nil { t.Errorf("failed to retrieve the cached token: %v", err) @@ -223,14 +222,161 @@ func TestAzureTokenSource(t *testing.T) { } } +func TestAzureTokenSourceScenarios(t *testing.T) { + configMode := configModeDefault + expiredToken := newFakeAzureToken("expired token", time.Now().Add(-time.Second)) + extendedToken := newFakeAzureToken("extend token", time.Now().Add(1000*time.Second)) + fakeToken := newFakeAzureToken("fake token", time.Now().Add(1000*time.Second)) + wrongToken := newFakeAzureToken("wrong token", time.Now().Add(1000*time.Second)) + tests := []struct { + name string + sourceToken *azureToken + refreshToken *azureToken + cachedToken *azureToken + configToken *azureToken + expectToken *azureToken + tokenErr error + refreshErr error + expectErr string + tokenCalls uint + refreshCalls uint + persistCalls uint + }{ + { + name: "new config", + sourceToken: fakeToken, + expectToken: fakeToken, + tokenCalls: 1, + persistCalls: 1, + }, + { + name: "load token from cache", + sourceToken: wrongToken, + cachedToken: fakeToken, + configToken: wrongToken, + expectToken: fakeToken, + }, + { + name: "load token from config", + sourceToken: wrongToken, + configToken: fakeToken, + expectToken: fakeToken, + }, + { + name: "cached token timeout, extend success, config token should never load", + cachedToken: expiredToken, + refreshToken: extendedToken, + configToken: wrongToken, + expectToken: extendedToken, + refreshCalls: 1, + persistCalls: 1, + }, + { + name: "config token timeout, extend failure, acquire new token", + configToken: expiredToken, + refreshErr: fakeTokenRefreshError{message: "FakeError happened when refreshing"}, + sourceToken: fakeToken, + expectToken: fakeToken, + refreshCalls: 1, + tokenCalls: 1, + persistCalls: 1, + }, + { + name: "unexpected error when extend", + configToken: expiredToken, + refreshErr: errors.New("unexpected refresh error"), + sourceToken: fakeToken, + expectErr: "unexpected refresh error", + refreshCalls: 1, + }, + { + name: "token error", + tokenErr: errors.New("tokenerr"), + expectErr: "tokenerr", + tokenCalls: 1, + }, + { + name: "Token() got expired token", + sourceToken: expiredToken, + expectErr: "newly acquired token is expired", + tokenCalls: 1, + }, + { + name: "Token() got nil but no error", + sourceToken: nil, + expectErr: "unable to acquire token", + tokenCalls: 1, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + persister := newFakePersister() + + cfg := map[string]string{} + if tc.configToken != nil { + cfg = token2Cfg(tc.configToken) + } + + tokenCache := newAzureTokenCache() + if tc.cachedToken != nil { + tokenCache.setToken(azureTokenKey, tc.cachedToken) + } + + fakeSource := fakeTokenSource{ + token: tc.sourceToken, + tokenErr: tc.tokenErr, + refreshToken: tc.refreshToken, + refreshErr: tc.refreshErr, + } + + tokenSource := newAzureTokenSource(&fakeSource, tokenCache, cfg, configMode, &persister) + token, err := tokenSource.Token() + + if fakeSource.tokenCalls != tc.tokenCalls { + t.Errorf("expecting tokenCalls: %v, got: %v", tc.tokenCalls, fakeSource.tokenCalls) + } + + if fakeSource.refreshCalls != tc.refreshCalls { + t.Errorf("expecting refreshCalls: %v, got: %v", tc.refreshCalls, fakeSource.refreshCalls) + } + + if persister.calls != tc.persistCalls { + t.Errorf("expecting persister calls: %v, got: %v", tc.persistCalls, persister.calls) + } + + if tc.expectErr != "" { + if !strings.Contains(err.Error(), tc.expectErr) { + t.Errorf("expecting error %v, got %v", tc.expectErr, err) + } + if token != nil { + t.Errorf("token should be nil in err situation, got %v", token) + } + } else { + if err != nil { + t.Fatalf("error should be nil, got %v", err) + } + if token.token.AccessToken != tc.expectToken.token.AccessToken { + t.Errorf("token should have accessToken %v, got %v", token.token.AccessToken, tc.expectToken.token.AccessToken) + } + } + }) + } +} + type fakePersister struct { lock sync.Mutex cache map[string]string + calls uint +} + +func newFakePersister() fakePersister { + return fakePersister{cache: make(map[string]string), calls: 0} } func (p *fakePersister) Persist(cache map[string]string) error { p.lock.Lock() defer p.lock.Unlock() + p.calls++ p.cache = map[string]string{} for k, v := range cache { p.cache[k] = v @@ -248,19 +394,24 @@ func (p *fakePersister) Cache() map[string]string { return ret } +// a simple token source simply always returns the token property type fakeTokenSource struct { - expiresOn string - accessToken string + token *azureToken + tokenCalls uint + tokenErr error + refreshToken *azureToken + refreshCalls uint + refreshErr error } func (ts *fakeTokenSource) Token() (*azureToken, error) { - return &azureToken{ - token: newFackeAzureToken(ts.accessToken, ts.expiresOn), - environment: "testenv", - clientID: "fake", - tenantID: "fake", - apiserverID: "fake", - }, nil + ts.tokenCalls++ + return ts.token, ts.tokenErr +} + +func (ts *fakeTokenSource) Refresh(*azureToken) (*azureToken, error) { + ts.refreshCalls++ + return ts.refreshToken, ts.refreshErr } func token2Cfg(token *azureToken) map[string]string { @@ -276,7 +427,17 @@ func token2Cfg(token *azureToken) map[string]string { return cfg } -func newFackeAzureToken(accessToken string, expiresOn string) adal.Token { +func newFakeAzureToken(accessToken string, expiresOnTime time.Time) *azureToken { + return &azureToken{ + token: newFakeADALToken(accessToken, strconv.FormatInt(expiresOnTime.Unix(), 10)), + environment: "testenv", + clientID: "fake", + tenantID: "fake", + apiserverID: "fake", + } +} + +func newFakeADALToken(accessToken string, expiresOn string) adal.Token { return adal.Token{ AccessToken: accessToken, RefreshToken: "fake", @@ -287,3 +448,19 @@ func newFackeAzureToken(accessToken string, expiresOn string) adal.Token { Type: "fake", } } + +// copied from go-autorest/adal +type fakeTokenRefreshError struct { + message string + resp *http.Response +} + +// Error implements the error interface which is part of the TokenRefreshError interface. +func (tre fakeTokenRefreshError) Error() string { + return tre.message +} + +// Response implements the TokenRefreshError interface, it returns the raw HTTP response from the refresh operation. +func (tre fakeTokenRefreshError) Response() *http.Response { + return tre.resp +}