diff --git a/plugin/pkg/client/auth/gcp/gcp.go b/plugin/pkg/client/auth/gcp/gcp.go index 97f9912dadb..df5c623f947 100644 --- a/plugin/pkg/client/auth/gcp/gcp.go +++ b/plugin/pkg/client/auth/gcp/gcp.go @@ -23,6 +23,7 @@ import ( "net/http" "os/exec" "strings" + "sync" "time" "github.com/golang/glog" @@ -74,6 +75,7 @@ func (g *gcpAuthProvider) WrapTransport(rt http.RoundTripper) http.RoundTripper func (g *gcpAuthProvider) Login() error { return nil } type cachedTokenSource struct { + lk sync.Mutex source oauth2.TokenSource accessToken string expiry time.Time @@ -99,11 +101,7 @@ func newCachedTokenSource(accessToken, expiry string, persister restclient.AuthP } func (t *cachedTokenSource) Token() (*oauth2.Token, error) { - tok := &oauth2.Token{ - AccessToken: t.accessToken, - TokenType: "Bearer", - Expiry: t.expiry, - } + tok := t.cachedToken() if tok.Valid() && !tok.Expiry.IsZero() { return tok, nil } @@ -111,16 +109,39 @@ func (t *cachedTokenSource) Token() (*oauth2.Token, error) { if err != nil { return nil, err } + cache := t.update(tok) if t.persister != nil { - t.cache["access-token"] = tok.AccessToken - t.cache["expiry"] = tok.Expiry.Format(time.RFC3339Nano) - if err := t.persister.Persist(t.cache); err != nil { + if err := t.persister.Persist(cache); err != nil { glog.V(4).Infof("Failed to persist token: %v", err) } } return tok, nil } +func (t *cachedTokenSource) cachedToken() *oauth2.Token { + t.lk.Lock() + defer t.lk.Unlock() + return &oauth2.Token{ + AccessToken: t.accessToken, + TokenType: "Bearer", + Expiry: t.expiry, + } +} + +func (t *cachedTokenSource) update(tok *oauth2.Token) map[string]string { + t.lk.Lock() + defer t.lk.Unlock() + t.accessToken = tok.AccessToken + t.expiry = tok.Expiry + ret := map[string]string{} + for k, v := range t.cache { + ret[k] = v + } + ret["access-token"] = t.accessToken + ret["expiry"] = t.expiry.Format(time.RFC3339Nano) + return ret +} + type commandTokenSource struct { cmd string args []string diff --git a/plugin/pkg/client/auth/gcp/gcp_test.go b/plugin/pkg/client/auth/gcp/gcp_test.go index dfd25bbf1b7..cb64b8380ca 100644 --- a/plugin/pkg/client/auth/gcp/gcp_test.go +++ b/plugin/pkg/client/auth/gcp/gcp_test.go @@ -20,6 +20,7 @@ import ( "fmt" "reflect" "strings" + "sync" "testing" "time" @@ -141,3 +142,70 @@ func TestCmdTokenSource(t *testing.T) { } } } + +type fakePersister struct { + lk sync.Mutex + cache map[string]string +} + +func (f *fakePersister) Persist(cache map[string]string) error { + f.lk.Lock() + defer f.lk.Unlock() + f.cache = map[string]string{} + for k, v := range cache { + f.cache[k] = v + } + return nil +} + +func (f *fakePersister) read() map[string]string { + ret := map[string]string{} + f.lk.Lock() + for k, v := range f.cache { + ret[k] = v + } + return ret +} + +type fakeTokenSource struct { + token *oauth2.Token + err error +} + +func (f *fakeTokenSource) Token() (*oauth2.Token, error) { + return f.token, f.err +} + +func TestCachedTokenSource(t *testing.T) { + tok := &oauth2.Token{AccessToken: "fakeaccesstoken"} + persister := &fakePersister{} + source := &fakeTokenSource{ + token: tok, + err: nil, + } + cache := map[string]string{ + "foo": "bar", + "baz": "bazinga", + } + ts, err := newCachedTokenSource("fakeaccesstoken", "", persister, source, cache) + if err != nil { + t.Fatal(err) + } + var wg sync.WaitGroup + wg.Add(10) + for i := 0; i < 10; i++ { + go func() { + _, err := ts.Token() + if err != nil { + t.Errorf("unexpected error: %s", err) + } + wg.Done() + }() + } + wg.Wait() + cache["access-token"] = "fakeaccesstoken" + cache["expiry"] = tok.Expiry.Format(time.RFC3339Nano) + if got := persister.read(); !reflect.DeepEqual(got, cache) { + t.Errorf("got cache %v, want %v", got, cache) + } +}