Files
client-go/plugin/pkg/client/auth/azure/azure_test.go
Ping He a6c937f0cd Azure auth fallback to real auth if refresh token fails, refactor and add more tests.
Signed-off-by: Ping He <tdihp@hotmail.com>

Kubernetes-commit: 26c97fa1b40a7939ca26084c819af4794df34406
2020-03-22 17:04:20 +08:00

467 lines
13 KiB
Go

/*
Copyright 2017 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package azure
import (
"encoding/json"
"errors"
"net/http"
"strconv"
"strings"
"sync"
"testing"
"time"
"github.com/Azure/go-autorest/autorest/adal"
"github.com/Azure/go-autorest/autorest/azure"
)
// TestAzureAuthProvider exercises newAzureAuthProvider with different
// config-mode settings. Invalid modes must yield an error containing the
// expected message; valid (or absent) modes must yield an *azureAuthProvider
// whose token source is an *azureTokenSource carrying the expected configMode.
func TestAzureAuthProvider(t *testing.T) {
	t.Run("validate against invalid configurations", func(t *testing.T) {
		vectors := []struct {
			cfg           map[string]string
			expectedError string
		}{
			{
				cfg: map[string]string{
					cfgClientID:    "foo",
					cfgApiserverID: "foo",
					cfgTenantID:    "foo",
					cfgConfigMode:  "-1",
				},
				expectedError: "config-mode:-1 is not a valid mode",
			},
			{
				cfg: map[string]string{
					cfgClientID:    "foo",
					cfgApiserverID: "foo",
					cfgTenantID:    "foo",
					cfgConfigMode:  "2",
				},
				expectedError: "config-mode:2 is not a valid mode",
			},
			{
				cfg: map[string]string{
					cfgClientID:    "foo",
					cfgApiserverID: "foo",
					cfgTenantID:    "foo",
					cfgConfigMode:  "foo",
				},
				expectedError: "failed to parse config-mode, error: strconv.Atoi: parsing \"foo\": invalid syntax",
			},
		}

		for _, v := range vectors {
			persister := &fakePersister{}
			_, err := newAzureAuthProvider("", v.cfg, persister)
			// Guard against a nil error: calling err.Error() on it would panic
			// instead of reporting a readable test failure.
			if err == nil {
				t.Errorf("cfg %v should fail with message containing '%s', but no error was returned", v.cfg, v.expectedError)
				continue
			}
			if !strings.Contains(err.Error(), v.expectedError) {
				t.Errorf("cfg %v should fail with message containing '%s'. actual: '%s'", v.cfg, v.expectedError, err)
			}
		}
	})

	t.Run("it should return non-nil provider in happy cases", func(t *testing.T) {
		vectors := []struct {
			cfg                map[string]string
			expectedConfigMode configMode
		}{
			{
				cfg: map[string]string{
					cfgClientID:    "foo",
					cfgApiserverID: "foo",
					cfgTenantID:    "foo",
				},
				expectedConfigMode: configModeDefault,
			},
			{
				cfg: map[string]string{
					cfgClientID:    "foo",
					cfgApiserverID: "foo",
					cfgTenantID:    "foo",
					cfgConfigMode:  "0",
				},
				expectedConfigMode: configModeDefault,
			},
			{
				cfg: map[string]string{
					cfgClientID:    "foo",
					cfgApiserverID: "foo",
					cfgTenantID:    "foo",
					cfgConfigMode:  "1",
				},
				expectedConfigMode: configModeOmitSPNPrefix,
			},
		}

		for _, v := range vectors {
			persister := &fakePersister{}
			provider, err := newAzureAuthProvider("", v.cfg, persister)
			if err != nil {
				t.Fatalf("newAzureAuthProvider should not fail with '%s'", err)
			}
			if provider == nil {
				t.Fatalf("newAzureAuthProvider should return non-nil provider")
			}

			// Use the comma-ok form so an unexpected concrete type is reported
			// as a test failure rather than a panic.
			azureProvider, ok := provider.(*azureAuthProvider)
			if !ok || azureProvider == nil {
				t.Fatalf("newAzureAuthProvider should return an instance of type azureAuthProvider")
			}

			ts, ok := azureProvider.tokenSource.(*azureTokenSource)
			if !ok || ts == nil {
				t.Fatalf("azureAuthProvider.tokenSource should be an instance of azureTokenSource")
			}

			if ts.configMode != v.expectedConfigMode {
				t.Errorf("expected configMode: %d, actual: %d", v.expectedConfigMode, ts.configMode)
			}
		}
	})
}
// TestTokenSourceDeviceCode checks newAzureTokenSourceDeviceCode argument
// validation: it must succeed for both config modes when all IDs are set, and
// must fail with a specific message when the client, tenant or apiserver ID is
// empty.
func TestTokenSourceDeviceCode(t *testing.T) {
	var (
		clientID    = "clientID"
		tenantID    = "tenantID"
		apiserverID = "apiserverID"
		// Named mode (not configMode) so the local does not shadow the type.
		mode     = configModeDefault
		azureEnv = azure.Environment{}
	)

	t.Run("validate to create azureTokenSourceDeviceCode", func(t *testing.T) {
		if _, err := newAzureTokenSourceDeviceCode(azureEnv, clientID, tenantID, apiserverID, configModeDefault); err != nil {
			t.Errorf("newAzureTokenSourceDeviceCode should not have failed. err: %s", err)
		}

		if _, err := newAzureTokenSourceDeviceCode(azureEnv, clientID, tenantID, apiserverID, configModeOmitSPNPrefix); err != nil {
			t.Errorf("newAzureTokenSourceDeviceCode should not have failed. err: %s", err)
		}

		failures := []struct {
			clientID    string
			tenantID    string
			apiserverID string
			expected    string
		}{
			{clientID: "", tenantID: tenantID, apiserverID: apiserverID, expected: "client-id is empty"},
			{clientID: clientID, tenantID: "", apiserverID: apiserverID, expected: "tenant-id is empty"},
			{clientID: clientID, tenantID: tenantID, apiserverID: "", expected: "apiserver-id is empty"},
		}

		for _, f := range failures {
			_, err := newAzureTokenSourceDeviceCode(azureEnv, f.clientID, f.tenantID, f.apiserverID, mode)
			// A missing error must be reported, not dereferenced.
			if err == nil {
				t.Errorf("newAzureTokenSourceDeviceCode should have failed. expected: %s, actual: <nil>", f.expected)
				continue
			}
			if err.Error() != f.expected {
				t.Errorf("newAzureTokenSourceDeviceCode should have failed. expected: %s, actual: %s", f.expected, err)
			}
		}
	})
}
// TestAzureTokenSource runs once per config mode and verifies that the first
// Token() call stores the token in the token cache and persists its
// configuration (including the config-mode value), and that a second Token()
// call returns the cached token even after the underlying fake source has been
// switched to a different token.
func TestAzureTokenSource(t *testing.T) {
	configModes := []configMode{configModeOmitSPNPrefix, configModeDefault}
	expectedConfigModes := []string{"1", "0"}

	for i, mode := range configModes {
		// Include the mode in the name so the two subtests are distinguishable.
		name := "validate token against cache (config-mode " + expectedConfigModes[i] + ")"
		t.Run(name, func(t *testing.T) {
			fakeAccessToken := "fake token 1"
			fakeSource := fakeTokenSource{token: newFakeAzureToken(fakeAccessToken, time.Now().Add(3600*time.Second))}
			cfg := make(map[string]string)
			persister := &fakePersister{cache: make(map[string]string)}
			tokenCache := newAzureTokenCache()
			tokenSource := newAzureTokenSource(&fakeSource, tokenCache, cfg, mode, persister)

			token, err := tokenSource.Token()
			// Fatal: the checks below dereference token.
			if err != nil {
				t.Fatalf("failed to retrieve the token from cache: %v", err)
			}
			if token == nil {
				t.Fatalf("Token() returned a nil token without an error")
			}

			wantCacheLen := 1
			if len(tokenCache.cache) != wantCacheLen {
				t.Errorf("Token() cache length error: got %v, want %v", len(tokenCache.cache), wantCacheLen)
			}

			if token != tokenCache.cache[azureTokenKey] {
				t.Error("Token() returned token != cached token")
			}

			wantCfg := token2Cfg(token)
			wantCfg[cfgConfigMode] = expectedConfigModes[i]
			persistedCfg := persister.Cache()

			wantCfgLen := len(wantCfg)
			persistedCfgLen := len(persistedCfg)
			if wantCfgLen != persistedCfgLen {
				t.Errorf("wantCfgLen and persistedCfgLen do not match, wantCfgLen=%v, persistedCfgLen=%v", wantCfgLen, persistedCfgLen)
			}

			for k, v := range persistedCfg {
				if v != wantCfg[k] {
					t.Errorf("Token() persisted cfg %s: got %v, want %v", k, v, wantCfg[k])
				}
			}

			// Swap the source token; the still-valid cached token must win.
			fakeSource.token = newFakeAzureToken("fake token 2", time.Now().Add(3600*time.Second))
			token, err = tokenSource.Token()
			if err != nil {
				t.Fatalf("failed to retrieve the cached token: %v", err)
			}
			if token == nil {
				t.Fatalf("Token() returned a nil token without an error")
			}

			if token.token.AccessToken != fakeAccessToken {
				t.Errorf("Token() didn't return the cached token")
			}
		})
	}
}
// TestAzureTokenSourceScenarios is a table-driven test of the azure token
// source built by newAzureTokenSource. Each case seeds the fake source, the
// token cache and the config with tokens and/or errors, calls Token() once, and
// checks how many times the fake source's Token and Refresh methods and the
// persister were invoked, as well as the resulting token or error.
func TestAzureTokenSourceScenarios(t *testing.T) {
	// Named mode (not configMode) so the local does not shadow the type.
	mode := configModeDefault
	expiredToken := newFakeAzureToken("expired token", time.Now().Add(-time.Second))
	extendedToken := newFakeAzureToken("extend token", time.Now().Add(1000*time.Second))
	fakeToken := newFakeAzureToken("fake token", time.Now().Add(1000*time.Second))
	wrongToken := newFakeAzureToken("wrong token", time.Now().Add(1000*time.Second))

	tests := []struct {
		name         string
		sourceToken  *azureToken // returned by the fake source's Token()
		refreshToken *azureToken // returned by the fake source's Refresh()
		cachedToken  *azureToken // pre-populated in the token cache
		configToken  *azureToken // pre-populated in the config map
		expectToken  *azureToken
		tokenErr     error
		refreshErr   error
		expectErr    string
		tokenCalls   uint
		refreshCalls uint
		persistCalls uint
	}{
		{
			name:         "new config",
			sourceToken:  fakeToken,
			expectToken:  fakeToken,
			tokenCalls:   1,
			persistCalls: 1,
		},
		{
			name:        "load token from cache",
			sourceToken: wrongToken,
			cachedToken: fakeToken,
			configToken: wrongToken,
			expectToken: fakeToken,
		},
		{
			name:        "load token from config",
			sourceToken: wrongToken,
			configToken: fakeToken,
			expectToken: fakeToken,
		},
		{
			name:         "cached token timeout, extend success, config token should never load",
			cachedToken:  expiredToken,
			refreshToken: extendedToken,
			configToken:  wrongToken,
			expectToken:  extendedToken,
			refreshCalls: 1,
			persistCalls: 1,
		},
		{
			name:         "config token timeout, extend failure, acquire new token",
			configToken:  expiredToken,
			refreshErr:   fakeTokenRefreshError{message: "FakeError happened when refreshing"},
			sourceToken:  fakeToken,
			expectToken:  fakeToken,
			refreshCalls: 1,
			tokenCalls:   1,
			persistCalls: 1,
		},
		{
			name:         "unexpected error when extend",
			configToken:  expiredToken,
			refreshErr:   errors.New("unexpected refresh error"),
			sourceToken:  fakeToken,
			expectErr:    "unexpected refresh error",
			refreshCalls: 1,
		},
		{
			name:       "token error",
			tokenErr:   errors.New("tokenerr"),
			expectErr:  "tokenerr",
			tokenCalls: 1,
		},
		{
			name:        "Token() got expired token",
			sourceToken: expiredToken,
			expectErr:   "newly acquired token is expired",
			tokenCalls:  1,
		},
		{
			name:        "Token() got nil but no error",
			sourceToken: nil,
			expectErr:   "unable to acquire token",
			tokenCalls:  1,
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			persister := newFakePersister()

			cfg := map[string]string{}
			if tc.configToken != nil {
				cfg = token2Cfg(tc.configToken)
			}

			tokenCache := newAzureTokenCache()
			if tc.cachedToken != nil {
				tokenCache.setToken(azureTokenKey, tc.cachedToken)
			}

			fakeSource := fakeTokenSource{
				token:        tc.sourceToken,
				tokenErr:     tc.tokenErr,
				refreshToken: tc.refreshToken,
				refreshErr:   tc.refreshErr,
			}

			tokenSource := newAzureTokenSource(&fakeSource, tokenCache, cfg, mode, &persister)
			token, err := tokenSource.Token()

			if fakeSource.tokenCalls != tc.tokenCalls {
				t.Errorf("expecting tokenCalls: %v, got: %v", tc.tokenCalls, fakeSource.tokenCalls)
			}

			if fakeSource.refreshCalls != tc.refreshCalls {
				t.Errorf("expecting refreshCalls: %v, got: %v", tc.refreshCalls, fakeSource.refreshCalls)
			}

			if persister.calls != tc.persistCalls {
				t.Errorf("expecting persister calls: %v, got: %v", tc.persistCalls, persister.calls)
			}

			if tc.expectErr != "" {
				// Check for nil first: err.Error() on a nil error would panic
				// and hide the real failure.
				if err == nil {
					t.Errorf("expecting error %v, got nil", tc.expectErr)
				} else if !strings.Contains(err.Error(), tc.expectErr) {
					t.Errorf("expecting error %v, got %v", tc.expectErr, err)
				}
				if token != nil {
					t.Errorf("token should be nil in err situation, got %v", token)
				}
				return
			}

			if err != nil {
				t.Fatalf("error should be nil, got %v", err)
			}
			if token == nil {
				t.Fatalf("token should not be nil when no error is returned")
			}
			if token.token.AccessToken != tc.expectToken.token.AccessToken {
				t.Errorf("token should have accessToken %v, got %v", tc.expectToken.token.AccessToken, token.token.AccessToken)
			}
		})
	}
}
// fakePersister is an in-memory stand-in for a config persister. It remembers
// a private copy of the most recently persisted map and counts how many times
// Persist has been called. Persist and Cache take lock around their access to
// the stored state.
type fakePersister struct {
	lock  sync.Mutex
	cache map[string]string
	calls uint
}

// newFakePersister returns a fakePersister with an empty, non-nil cache and a
// call counter of zero.
func newFakePersister() fakePersister {
	return fakePersister{cache: map[string]string{}}
}

// Persist replaces the stored configuration with a copy of cache and bumps the
// call counter. It always returns nil.
func (p *fakePersister) Persist(cache map[string]string) error {
	// Copy so that later mutations of the caller's map are not observed.
	snapshot := make(map[string]string, len(cache))
	for key, value := range cache {
		snapshot[key] = value
	}

	p.lock.Lock()
	defer p.lock.Unlock()
	p.cache = snapshot
	p.calls++
	return nil
}

// Cache returns a copy of the stored configuration. The result is never nil,
// and modifying it does not affect the persister.
func (p *fakePersister) Cache() map[string]string {
	p.lock.Lock()
	defer p.lock.Unlock()

	snapshot := make(map[string]string, len(p.cache))
	for key, value := range p.cache {
		snapshot[key] = value
	}
	return snapshot
}
// fakeTokenSource is a scripted token source for tests. Token and Refresh
// return the pre-configured token/error pairs unchanged and record how many
// times each method has been called.
type fakeTokenSource struct {
	// Result and call counter for Token.
	token      *azureToken
	tokenErr   error
	tokenCalls uint

	// Result and call counter for Refresh.
	refreshToken *azureToken
	refreshErr   error
	refreshCalls uint
}

// Token counts the call and returns the configured token and tokenErr.
func (s *fakeTokenSource) Token() (*azureToken, error) {
	s.tokenCalls++
	return s.token, s.tokenErr
}

// Refresh counts the call and returns the configured refreshToken and
// refreshErr; the token passed in is ignored.
func (s *fakeTokenSource) Refresh(_ *azureToken) (*azureToken, error) {
	s.refreshCalls++
	return s.refreshToken, s.refreshErr
}
// token2Cfg flattens token into a string map keyed by the cfg* configuration
// keys: access and refresh tokens, environment, client/tenant/apiserver IDs,
// and the expires-in / expires-on values.
func token2Cfg(token *azureToken) map[string]string {
	inner := token.token
	return map[string]string{
		cfgAccessToken:  inner.AccessToken,
		cfgRefreshToken: inner.RefreshToken,
		cfgEnvironment:  token.environment,
		cfgClientID:     token.clientID,
		cfgTenantID:     token.tenantID,
		cfgApiserverID:  token.apiserverID,
		cfgExpiresIn:    string(inner.ExpiresIn),
		cfgExpiresOn:    string(inner.ExpiresOn),
	}
}
// newFakeAzureToken builds an azureToken for tests. The wrapped ADAL token
// carries accessToken and expires at expiresOnTime (encoded as decimal Unix
// seconds); the environment is "testenv" and all IDs are "fake".
func newFakeAzureToken(accessToken string, expiresOnTime time.Time) *azureToken {
	const placeholderID = "fake"

	expiresOn := strconv.FormatInt(expiresOnTime.Unix(), 10)
	fake := &azureToken{
		token:       newFakeADALToken(accessToken, expiresOn),
		environment: "testenv",
		clientID:    placeholderID,
		tenantID:    placeholderID,
		apiserverID: placeholderID,
	}
	return fake
}
// newFakeADALToken returns an adal.Token populated with fixed placeholder
// values, the supplied accessToken, and expiresOn used for both the ExpiresOn
// and NotBefore fields.
func newFakeADALToken(accessToken string, expiresOn string) adal.Token {
	expiry := json.Number(expiresOn)

	var fake adal.Token
	fake.AccessToken = accessToken
	fake.RefreshToken = "fake"
	fake.ExpiresIn = "3600"
	fake.ExpiresOn = expiry
	fake.NotBefore = expiry
	fake.Resource = "fake"
	fake.Type = "fake"
	return fake
}
// fakeTokenRefreshError is a test double for a token refresh failure; its
// shape was copied from go-autorest/adal. It carries an error message and an
// optional raw HTTP response.
type fakeTokenRefreshError struct {
	message string
	resp    *http.Response
}

// Error returns the stored message, satisfying the error interface (which is
// part of the TokenRefreshError interface).
func (e fakeTokenRefreshError) Error() string {
	return e.message
}

// Response returns the stored raw HTTP response from the refresh operation,
// as required by the TokenRefreshError interface. It is nil when none was set.
func (e fakeTokenRefreshError) Response() *http.Response {
	return e.resp
}