gcp client auth plugin: persist default cache on unauthorized

The default cache for a cachedTokenSource is not always empty. In the
case of commandTokenSource, it contains calling details for the
external command that is used to generate refresh tokens. Persisting
a completely empty cache will thus break ability for the plugin to
obtain refresh tokens. This changes the roundtripper to persist
the default cache instead of assuming an empty map.
This commit is contained in:
Jeff Lowdermilk 2018-07-17 14:06:11 -07:00
parent 25cbd1c753
commit 73e5e43711
2 changed files with 72 additions and 23 deletions

View File

@ -174,7 +174,13 @@ func parseScopes(gcpConfig map[string]string) []string {
} }
func (g *gcpAuthProvider) WrapTransport(rt http.RoundTripper) http.RoundTripper { func (g *gcpAuthProvider) WrapTransport(rt http.RoundTripper) http.RoundTripper {
return &conditionalTransport{&oauth2.Transport{Source: g.tokenSource, Base: rt}, g.persister} var resetCache map[string]string
if cts, ok := g.tokenSource.(*cachedTokenSource); ok {
resetCache = cts.baseCache()
} else {
resetCache = make(map[string]string)
}
return &conditionalTransport{&oauth2.Transport{Source: g.tokenSource, Base: rt}, g.persister, resetCache}
} }
func (g *gcpAuthProvider) Login() error { return nil } func (g *gcpAuthProvider) Login() error { return nil }
@ -247,6 +253,19 @@ func (t *cachedTokenSource) update(tok *oauth2.Token) map[string]string {
return ret return ret
} }
// baseCache is the base configuration value for this TokenSource, without any cached ephemeral tokens.
func (t *cachedTokenSource) baseCache() map[string]string {
t.lk.Lock()
defer t.lk.Unlock()
ret := map[string]string{}
for k, v := range t.cache {
ret[k] = v
}
delete(ret, "access-token")
delete(ret, "expiry")
return ret
}
type commandTokenSource struct { type commandTokenSource struct {
cmd string cmd string
args []string args []string
@ -337,6 +356,7 @@ func parseJSONPath(input interface{}, name, template string) (string, error) {
type conditionalTransport struct { type conditionalTransport struct {
oauthTransport *oauth2.Transport oauthTransport *oauth2.Transport
persister restclient.AuthProviderConfigPersister persister restclient.AuthProviderConfigPersister
resetCache map[string]string
} }
var _ net.RoundTripperWrapper = &conditionalTransport{} var _ net.RoundTripperWrapper = &conditionalTransport{}
@ -354,8 +374,7 @@ func (t *conditionalTransport) RoundTrip(req *http.Request) (*http.Response, err
if res.StatusCode == 401 { if res.StatusCode == 401 {
glog.V(4).Infof("The credentials that were supplied are invalid for the target cluster") glog.V(4).Infof("The credentials that were supplied are invalid for the target cluster")
emptyCache := make(map[string]string) t.persister.Persist(t.resetCache)
t.persister.Persist(emptyCache)
} }
return res, nil return res, nil

View File

@ -442,37 +442,61 @@ func (t *MockTransport) RoundTrip(req *http.Request) (*http.Response, error) {
return t.res, nil return t.res, nil
} }
func TestClearingCredentials(t *testing.T) { func Test_cmdTokenSource_roundTrip(t *testing.T) {
accessToken := "fakeToken"
fakeExpiry := time.Now().Add(time.Hour) fakeExpiry := time.Now().Add(time.Hour)
fakeExpiryStr := fakeExpiry.Format(time.RFC3339Nano)
cache := map[string]string{ fs := &fakeTokenSource{
"access-token": "fakeToken", token: &oauth2.Token{
"expiry": fakeExpiry.String(), AccessToken: accessToken,
Expiry: fakeExpiry,
},
} }
cts := cachedTokenSource{ cmdCache := map[string]string{
source: nil, "cmd-path": "/path/to/tokensource/cmd",
accessToken: cache["access-token"], "cmd-args": "--output=json",
expiry: fakeExpiry, }
persister: nil, cmdCacheUpdated := map[string]string{
cache: nil, "cmd-path": "/path/to/tokensource/cmd",
"cmd-args": "--output=json",
"access-token": accessToken,
"expiry": fakeExpiryStr,
}
simpleCacheUpdated := map[string]string{
"access-token": accessToken,
"expiry": fakeExpiryStr,
} }
tests := []struct { tests := []struct {
name string name string
res http.Response res http.Response
cache map[string]string baseCache, expectedCache map[string]string
}{ }{
{ {
"Unauthorized", "Unauthorized",
http.Response{StatusCode: 401}, http.Response{StatusCode: 401},
make(map[string]string), make(map[string]string),
make(map[string]string),
},
{
"Unauthorized, nonempty defaultCache",
http.Response{StatusCode: 401},
cmdCache,
cmdCache,
}, },
{ {
"Authorized", "Authorized",
http.Response{StatusCode: 200}, http.Response{StatusCode: 200},
cache, make(map[string]string),
simpleCacheUpdated,
},
{
"Authorized, nonempty defaultCache",
http.Response{StatusCode: 200},
cmdCache,
cmdCacheUpdated,
}, },
} }
@ -480,17 +504,23 @@ func TestClearingCredentials(t *testing.T) {
req := http.Request{Header: http.Header{}} req := http.Request{Header: http.Header{}}
for _, tc := range tests { for _, tc := range tests {
authProvider := gcpAuthProvider{&cts, persister} cts, err := newCachedTokenSource(accessToken, fakeExpiry.String(), persister, fs, tc.baseCache)
if err != nil {
t.Fatalf("unexpected error from newCachedTokenSource: %v", err)
}
authProvider := gcpAuthProvider{cts, persister}
fakeTransport := MockTransport{&tc.res} fakeTransport := MockTransport{&tc.res}
transport := (authProvider.WrapTransport(&fakeTransport)) transport := (authProvider.WrapTransport(&fakeTransport))
persister.Persist(cache) // call Token to persist/update cache
if _, err := cts.Token(); err != nil {
t.Fatalf("unexpected error from cachedTokenSource.Token(): %v", err)
}
transport.RoundTrip(&req) transport.RoundTrip(&req)
if got := persister.read(); !reflect.DeepEqual(got, tc.cache) { if got := persister.read(); !reflect.DeepEqual(got, tc.expectedCache) {
t.Errorf("got cache %v, want %v", got, tc.cache) t.Errorf("got cache %v, want %v", got, tc.expectedCache)
} }
} }