gcp client auth plugin: persist default cache on unauthorized

The default cache for a cachedTokenSource is not always empty. In the
case of commandTokenSource, it contains calling details for the
external command that is used to generate refresh tokens. Persisting
a completely empty cache will thus break ability for the plugin to
obtain refresh tokens. This changes the roundtripper to persist
the default cache instead of assuming an empty map.
This commit is contained in:
Jeff Lowdermilk 2018-07-17 14:06:11 -07:00
parent 25cbd1c753
commit 73e5e43711
2 changed files with 72 additions and 23 deletions

View File

@ -174,7 +174,13 @@ func parseScopes(gcpConfig map[string]string) []string {
}
func (g *gcpAuthProvider) WrapTransport(rt http.RoundTripper) http.RoundTripper {
return &conditionalTransport{&oauth2.Transport{Source: g.tokenSource, Base: rt}, g.persister}
var resetCache map[string]string
if cts, ok := g.tokenSource.(*cachedTokenSource); ok {
resetCache = cts.baseCache()
} else {
resetCache = make(map[string]string)
}
return &conditionalTransport{&oauth2.Transport{Source: g.tokenSource, Base: rt}, g.persister, resetCache}
}
func (g *gcpAuthProvider) Login() error { return nil }
@ -247,6 +253,19 @@ func (t *cachedTokenSource) update(tok *oauth2.Token) map[string]string {
return ret
}
// baseCache is the base configuration value for this TokenSource, without any cached ephemeral tokens.
func (t *cachedTokenSource) baseCache() map[string]string {
t.lk.Lock()
defer t.lk.Unlock()
ret := map[string]string{}
for k, v := range t.cache {
ret[k] = v
}
delete(ret, "access-token")
delete(ret, "expiry")
return ret
}
type commandTokenSource struct {
cmd string
args []string
@ -337,6 +356,7 @@ func parseJSONPath(input interface{}, name, template string) (string, error) {
type conditionalTransport struct {
oauthTransport *oauth2.Transport
persister restclient.AuthProviderConfigPersister
resetCache map[string]string
}
var _ net.RoundTripperWrapper = &conditionalTransport{}
@ -354,8 +374,7 @@ func (t *conditionalTransport) RoundTrip(req *http.Request) (*http.Response, err
if res.StatusCode == 401 {
glog.V(4).Infof("The credentials that were supplied are invalid for the target cluster")
emptyCache := make(map[string]string)
t.persister.Persist(emptyCache)
t.persister.Persist(t.resetCache)
}
return res, nil

View File

@ -442,37 +442,61 @@ func (t *MockTransport) RoundTrip(req *http.Request) (*http.Response, error) {
return t.res, nil
}
func TestClearingCredentials(t *testing.T) {
func Test_cmdTokenSource_roundTrip(t *testing.T) {
accessToken := "fakeToken"
fakeExpiry := time.Now().Add(time.Hour)
cache := map[string]string{
"access-token": "fakeToken",
"expiry": fakeExpiry.String(),
fakeExpiryStr := fakeExpiry.Format(time.RFC3339Nano)
fs := &fakeTokenSource{
token: &oauth2.Token{
AccessToken: accessToken,
Expiry: fakeExpiry,
},
}
cts := cachedTokenSource{
source: nil,
accessToken: cache["access-token"],
expiry: fakeExpiry,
persister: nil,
cache: nil,
cmdCache := map[string]string{
"cmd-path": "/path/to/tokensource/cmd",
"cmd-args": "--output=json",
}
cmdCacheUpdated := map[string]string{
"cmd-path": "/path/to/tokensource/cmd",
"cmd-args": "--output=json",
"access-token": accessToken,
"expiry": fakeExpiryStr,
}
simpleCacheUpdated := map[string]string{
"access-token": accessToken,
"expiry": fakeExpiryStr,
}
tests := []struct {
name string
res http.Response
cache map[string]string
name string
res http.Response
baseCache, expectedCache map[string]string
}{
{
"Unauthorized",
http.Response{StatusCode: 401},
make(map[string]string),
make(map[string]string),
},
{
"Unauthorized, nonempty defaultCache",
http.Response{StatusCode: 401},
cmdCache,
cmdCache,
},
{
"Authorized",
http.Response{StatusCode: 200},
cache,
make(map[string]string),
simpleCacheUpdated,
},
{
"Authorized, nonempty defaultCache",
http.Response{StatusCode: 200},
cmdCache,
cmdCacheUpdated,
},
}
@ -480,17 +504,23 @@ func TestClearingCredentials(t *testing.T) {
req := http.Request{Header: http.Header{}}
for _, tc := range tests {
authProvider := gcpAuthProvider{&cts, persister}
cts, err := newCachedTokenSource(accessToken, fakeExpiry.String(), persister, fs, tc.baseCache)
if err != nil {
t.Fatalf("unexpected error from newCachedTokenSource: %v", err)
}
authProvider := gcpAuthProvider{cts, persister}
fakeTransport := MockTransport{&tc.res}
transport := (authProvider.WrapTransport(&fakeTransport))
persister.Persist(cache)
// call Token to persist/update cache
if _, err := cts.Token(); err != nil {
t.Fatalf("unexpected error from cachedTokenSource.Token(): %v", err)
}
transport.RoundTrip(&req)
if got := persister.read(); !reflect.DeepEqual(got, tc.cache) {
t.Errorf("got cache %v, want %v", got, tc.cache)
if got := persister.read(); !reflect.DeepEqual(got, tc.expectedCache) {
t.Errorf("got cache %v, want %v", got, tc.expectedCache)
}
}