mirror of
https://github.com/kubernetes/client-go.git
synced 2025-06-24 14:12:18 +00:00
Merge pull request #86481 from tdihp/feature/aad-fallback-real-auth
aad fallback to real auth if refresh token fails, fixes #82776 Kubernetes-commit: e7852bff43b358bcce7c77a352e171eca200c005
This commit is contained in:
commit
615fa027f2
2
Godeps/Godeps.json
generated
2
Godeps/Godeps.json
generated
@ -348,7 +348,7 @@
|
||||
},
|
||||
{
|
||||
"ImportPath": "k8s.io/apimachinery",
|
||||
"Rev": "48159c651603"
|
||||
"Rev": "1aec6bc431a9"
|
||||
},
|
||||
{
|
||||
"ImportPath": "k8s.io/gengo",
|
||||
|
4
go.mod
4
go.mod
@ -28,7 +28,7 @@ require (
|
||||
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4
|
||||
google.golang.org/appengine v1.5.0 // indirect
|
||||
k8s.io/api v0.0.0-20200320042356-1fc28ea2498c
|
||||
k8s.io/apimachinery v0.0.0-20200320122144-48159c651603
|
||||
k8s.io/apimachinery v0.0.0-20200324202305-1aec6bc431a9
|
||||
k8s.io/klog v1.0.0
|
||||
k8s.io/utils v0.0.0-20200322164244-327a8059b905
|
||||
sigs.k8s.io/yaml v1.2.0
|
||||
@ -38,5 +38,5 @@ replace (
|
||||
golang.org/x/sys => golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a // pinned to release-branch.go1.13
|
||||
golang.org/x/tools => golang.org/x/tools v0.0.0-20190821162956-65e3620a7ae7 // pinned to release-branch.go1.13
|
||||
k8s.io/api => k8s.io/api v0.0.0-20200320042356-1fc28ea2498c
|
||||
k8s.io/apimachinery => k8s.io/apimachinery v0.0.0-20200320122144-48159c651603
|
||||
k8s.io/apimachinery => k8s.io/apimachinery v0.0.0-20200324202305-1aec6bc431a9
|
||||
)
|
||||
|
2
go.sum
2
go.sum
@ -188,7 +188,7 @@ gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
|
||||
honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
|
||||
k8s.io/api v0.0.0-20200320042356-1fc28ea2498c/go.mod h1:5nMyHS4bWX496fulniJ+Sws3P6GLvaP43GadMObLf58=
|
||||
k8s.io/apimachinery v0.0.0-20200320122144-48159c651603/go.mod h1:yKN3QjQfKl8UdUL9RQ+/1VkR7nIUs7w02zC5CXhD+G0=
|
||||
k8s.io/apimachinery v0.0.0-20200324202305-1aec6bc431a9/go.mod h1:yKN3QjQfKl8UdUL9RQ+/1VkR7nIUs7w02zC5CXhD+G0=
|
||||
k8s.io/gengo v0.0.0-20190128074634-0689ccc1d7d6/go.mod h1:ezvh/TsK7cY6rbqRK0oQQ8IAqLxYwwyPxAX1Pzy0ii0=
|
||||
k8s.io/klog v0.0.0-20181102134211-b9b56d5dfc92/go.mod h1:Gq+BEi5rUBO/HRz0bTSXDUcqjScdoY3a9IHpCEIOOfk=
|
||||
k8s.io/klog v0.3.0/go.mod h1:Gq+BEi5rUBO/HRz0bTSXDUcqjScdoY3a9IHpCEIOOfk=
|
||||
|
@ -180,6 +180,7 @@ type azureToken struct {
|
||||
|
||||
type tokenSource interface {
|
||||
Token() (*azureToken, error)
|
||||
Refresh(*azureToken) (*azureToken, error)
|
||||
}
|
||||
|
||||
type azureTokenSource struct {
|
||||
@ -210,33 +211,66 @@ func (ts *azureTokenSource) Token() (*azureToken, error) {
|
||||
|
||||
var err error
|
||||
token := ts.cache.getToken(azureTokenKey)
|
||||
|
||||
if token != nil && !token.token.IsExpired() {
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// retrieve from config if no cache
|
||||
if token == nil {
|
||||
token, err = ts.retrieveTokenFromCfg()
|
||||
if err != nil {
|
||||
token, err = ts.source.Token()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("acquiring a new fresh token: %v", err)
|
||||
}
|
||||
tokenFromCfg, err := ts.retrieveTokenFromCfg()
|
||||
|
||||
if err == nil {
|
||||
token = tokenFromCfg
|
||||
}
|
||||
}
|
||||
|
||||
if token != nil {
|
||||
// cache and return if the token is as good
|
||||
// avoids frequent persistor calls
|
||||
if !token.token.IsExpired() {
|
||||
ts.cache.setToken(azureTokenKey, token)
|
||||
err = ts.storeTokenInCfg(token)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("storing the token in configuration: %v", err)
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
klog.V(4).Info("Refreshing token.")
|
||||
tokenFromRefresh, err := ts.Refresh(token)
|
||||
switch {
|
||||
case err == nil:
|
||||
token = tokenFromRefresh
|
||||
case autorest.IsTokenRefreshError(err):
|
||||
klog.V(4).Infof("Failed to refresh expired token, proceed to auth: %v", err)
|
||||
// reset token to nil so that the token source will be used to acquire new
|
||||
token = nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected error when refreshing token: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if token == nil {
|
||||
tokenFromSource, err := ts.source.Token()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed acquiring new token: %v", err)
|
||||
}
|
||||
token = tokenFromSource
|
||||
}
|
||||
|
||||
// sanity check
|
||||
if token == nil {
|
||||
return nil, fmt.Errorf("unable to acquire token")
|
||||
}
|
||||
|
||||
// corner condition, newly got token is valid but expired
|
||||
if token.token.IsExpired() {
|
||||
token, err = ts.refreshToken(token)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("refreshing the expired token: %v", err)
|
||||
}
|
||||
ts.cache.setToken(azureTokenKey, token)
|
||||
err = ts.storeTokenInCfg(token)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("storing the refreshed token in configuration: %v", err)
|
||||
}
|
||||
return nil, fmt.Errorf("newly acquired token is expired")
|
||||
}
|
||||
|
||||
err = ts.storeTokenInCfg(token)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("storing the refreshed token in configuration: %v", err)
|
||||
}
|
||||
ts.cache.setToken(azureTokenKey, token)
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
@ -314,7 +348,13 @@ func (ts *azureTokenSource) storeTokenInCfg(token *azureToken) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ts *azureTokenSource) refreshToken(token *azureToken) (*azureToken, error) {
|
||||
func (ts *azureTokenSource) Refresh(token *azureToken) (*azureToken, error) {
|
||||
return ts.source.Refresh(token)
|
||||
}
|
||||
|
||||
// refresh outdated token with adal.
|
||||
// adal.RefreshTokenError will be returned if error occur during refreshing.
|
||||
func (ts *azureTokenSourceDeviceCode) Refresh(token *azureToken) (*azureToken, error) {
|
||||
env, err := azure.EnvironmentFromName(token.environment)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -18,6 +18,8 @@ package azure
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
@ -172,10 +174,7 @@ func TestAzureTokenSource(t *testing.T) {
|
||||
for i, configMode := range configModes {
|
||||
t.Run("validate token against cache", func(t *testing.T) {
|
||||
fakeAccessToken := "fake token 1"
|
||||
fakeSource := fakeTokenSource{
|
||||
accessToken: fakeAccessToken,
|
||||
expiresOn: strconv.FormatInt(time.Now().Add(3600*time.Second).Unix(), 10),
|
||||
}
|
||||
fakeSource := fakeTokenSource{token: newFakeAzureToken(fakeAccessToken, time.Now().Add(3600*time.Second))}
|
||||
cfg := make(map[string]string)
|
||||
persiter := &fakePersister{cache: make(map[string]string)}
|
||||
tokenCache := newAzureTokenCache()
|
||||
@ -210,7 +209,7 @@ func TestAzureTokenSource(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
fakeSource.accessToken = "fake token 2"
|
||||
fakeSource.token = newFakeAzureToken("fake token 2", time.Now().Add(3600*time.Second))
|
||||
token, err = tokenSource.Token()
|
||||
if err != nil {
|
||||
t.Errorf("failed to retrieve the cached token: %v", err)
|
||||
@ -223,14 +222,161 @@ func TestAzureTokenSource(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAzureTokenSourceScenarios(t *testing.T) {
|
||||
configMode := configModeDefault
|
||||
expiredToken := newFakeAzureToken("expired token", time.Now().Add(-time.Second))
|
||||
extendedToken := newFakeAzureToken("extend token", time.Now().Add(1000*time.Second))
|
||||
fakeToken := newFakeAzureToken("fake token", time.Now().Add(1000*time.Second))
|
||||
wrongToken := newFakeAzureToken("wrong token", time.Now().Add(1000*time.Second))
|
||||
tests := []struct {
|
||||
name string
|
||||
sourceToken *azureToken
|
||||
refreshToken *azureToken
|
||||
cachedToken *azureToken
|
||||
configToken *azureToken
|
||||
expectToken *azureToken
|
||||
tokenErr error
|
||||
refreshErr error
|
||||
expectErr string
|
||||
tokenCalls uint
|
||||
refreshCalls uint
|
||||
persistCalls uint
|
||||
}{
|
||||
{
|
||||
name: "new config",
|
||||
sourceToken: fakeToken,
|
||||
expectToken: fakeToken,
|
||||
tokenCalls: 1,
|
||||
persistCalls: 1,
|
||||
},
|
||||
{
|
||||
name: "load token from cache",
|
||||
sourceToken: wrongToken,
|
||||
cachedToken: fakeToken,
|
||||
configToken: wrongToken,
|
||||
expectToken: fakeToken,
|
||||
},
|
||||
{
|
||||
name: "load token from config",
|
||||
sourceToken: wrongToken,
|
||||
configToken: fakeToken,
|
||||
expectToken: fakeToken,
|
||||
},
|
||||
{
|
||||
name: "cached token timeout, extend success, config token should never load",
|
||||
cachedToken: expiredToken,
|
||||
refreshToken: extendedToken,
|
||||
configToken: wrongToken,
|
||||
expectToken: extendedToken,
|
||||
refreshCalls: 1,
|
||||
persistCalls: 1,
|
||||
},
|
||||
{
|
||||
name: "config token timeout, extend failure, acquire new token",
|
||||
configToken: expiredToken,
|
||||
refreshErr: fakeTokenRefreshError{message: "FakeError happened when refreshing"},
|
||||
sourceToken: fakeToken,
|
||||
expectToken: fakeToken,
|
||||
refreshCalls: 1,
|
||||
tokenCalls: 1,
|
||||
persistCalls: 1,
|
||||
},
|
||||
{
|
||||
name: "unexpected error when extend",
|
||||
configToken: expiredToken,
|
||||
refreshErr: errors.New("unexpected refresh error"),
|
||||
sourceToken: fakeToken,
|
||||
expectErr: "unexpected refresh error",
|
||||
refreshCalls: 1,
|
||||
},
|
||||
{
|
||||
name: "token error",
|
||||
tokenErr: errors.New("tokenerr"),
|
||||
expectErr: "tokenerr",
|
||||
tokenCalls: 1,
|
||||
},
|
||||
{
|
||||
name: "Token() got expired token",
|
||||
sourceToken: expiredToken,
|
||||
expectErr: "newly acquired token is expired",
|
||||
tokenCalls: 1,
|
||||
},
|
||||
{
|
||||
name: "Token() got nil but no error",
|
||||
sourceToken: nil,
|
||||
expectErr: "unable to acquire token",
|
||||
tokenCalls: 1,
|
||||
},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
persister := newFakePersister()
|
||||
|
||||
cfg := map[string]string{}
|
||||
if tc.configToken != nil {
|
||||
cfg = token2Cfg(tc.configToken)
|
||||
}
|
||||
|
||||
tokenCache := newAzureTokenCache()
|
||||
if tc.cachedToken != nil {
|
||||
tokenCache.setToken(azureTokenKey, tc.cachedToken)
|
||||
}
|
||||
|
||||
fakeSource := fakeTokenSource{
|
||||
token: tc.sourceToken,
|
||||
tokenErr: tc.tokenErr,
|
||||
refreshToken: tc.refreshToken,
|
||||
refreshErr: tc.refreshErr,
|
||||
}
|
||||
|
||||
tokenSource := newAzureTokenSource(&fakeSource, tokenCache, cfg, configMode, &persister)
|
||||
token, err := tokenSource.Token()
|
||||
|
||||
if fakeSource.tokenCalls != tc.tokenCalls {
|
||||
t.Errorf("expecting tokenCalls: %v, got: %v", tc.tokenCalls, fakeSource.tokenCalls)
|
||||
}
|
||||
|
||||
if fakeSource.refreshCalls != tc.refreshCalls {
|
||||
t.Errorf("expecting refreshCalls: %v, got: %v", tc.refreshCalls, fakeSource.refreshCalls)
|
||||
}
|
||||
|
||||
if persister.calls != tc.persistCalls {
|
||||
t.Errorf("expecting persister calls: %v, got: %v", tc.persistCalls, persister.calls)
|
||||
}
|
||||
|
||||
if tc.expectErr != "" {
|
||||
if !strings.Contains(err.Error(), tc.expectErr) {
|
||||
t.Errorf("expecting error %v, got %v", tc.expectErr, err)
|
||||
}
|
||||
if token != nil {
|
||||
t.Errorf("token should be nil in err situation, got %v", token)
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Fatalf("error should be nil, got %v", err)
|
||||
}
|
||||
if token.token.AccessToken != tc.expectToken.token.AccessToken {
|
||||
t.Errorf("token should have accessToken %v, got %v", token.token.AccessToken, tc.expectToken.token.AccessToken)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type fakePersister struct {
|
||||
lock sync.Mutex
|
||||
cache map[string]string
|
||||
calls uint
|
||||
}
|
||||
|
||||
func newFakePersister() fakePersister {
|
||||
return fakePersister{cache: make(map[string]string), calls: 0}
|
||||
}
|
||||
|
||||
func (p *fakePersister) Persist(cache map[string]string) error {
|
||||
p.lock.Lock()
|
||||
defer p.lock.Unlock()
|
||||
p.calls++
|
||||
p.cache = map[string]string{}
|
||||
for k, v := range cache {
|
||||
p.cache[k] = v
|
||||
@ -248,19 +394,24 @@ func (p *fakePersister) Cache() map[string]string {
|
||||
return ret
|
||||
}
|
||||
|
||||
// a simple token source simply always returns the token property
|
||||
type fakeTokenSource struct {
|
||||
expiresOn string
|
||||
accessToken string
|
||||
token *azureToken
|
||||
tokenCalls uint
|
||||
tokenErr error
|
||||
refreshToken *azureToken
|
||||
refreshCalls uint
|
||||
refreshErr error
|
||||
}
|
||||
|
||||
func (ts *fakeTokenSource) Token() (*azureToken, error) {
|
||||
return &azureToken{
|
||||
token: newFackeAzureToken(ts.accessToken, ts.expiresOn),
|
||||
environment: "testenv",
|
||||
clientID: "fake",
|
||||
tenantID: "fake",
|
||||
apiserverID: "fake",
|
||||
}, nil
|
||||
ts.tokenCalls++
|
||||
return ts.token, ts.tokenErr
|
||||
}
|
||||
|
||||
func (ts *fakeTokenSource) Refresh(*azureToken) (*azureToken, error) {
|
||||
ts.refreshCalls++
|
||||
return ts.refreshToken, ts.refreshErr
|
||||
}
|
||||
|
||||
func token2Cfg(token *azureToken) map[string]string {
|
||||
@ -276,7 +427,17 @@ func token2Cfg(token *azureToken) map[string]string {
|
||||
return cfg
|
||||
}
|
||||
|
||||
func newFackeAzureToken(accessToken string, expiresOn string) adal.Token {
|
||||
func newFakeAzureToken(accessToken string, expiresOnTime time.Time) *azureToken {
|
||||
return &azureToken{
|
||||
token: newFakeADALToken(accessToken, strconv.FormatInt(expiresOnTime.Unix(), 10)),
|
||||
environment: "testenv",
|
||||
clientID: "fake",
|
||||
tenantID: "fake",
|
||||
apiserverID: "fake",
|
||||
}
|
||||
}
|
||||
|
||||
func newFakeADALToken(accessToken string, expiresOn string) adal.Token {
|
||||
return adal.Token{
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: "fake",
|
||||
@ -287,3 +448,19 @@ func newFackeAzureToken(accessToken string, expiresOn string) adal.Token {
|
||||
Type: "fake",
|
||||
}
|
||||
}
|
||||
|
||||
// copied from go-autorest/adal
|
||||
type fakeTokenRefreshError struct {
|
||||
message string
|
||||
resp *http.Response
|
||||
}
|
||||
|
||||
// Error implements the error interface which is part of the TokenRefreshError interface.
|
||||
func (tre fakeTokenRefreshError) Error() string {
|
||||
return tre.message
|
||||
}
|
||||
|
||||
// Response implements the TokenRefreshError interface, it returns the raw HTTP response from the refresh operation.
|
||||
func (tre fakeTokenRefreshError) Response() *http.Response {
|
||||
return tre.resp
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user