diff --git a/staging/src/k8s.io/client-go/plugin/pkg/client/auth/gcp/gcp.go b/staging/src/k8s.io/client-go/plugin/pkg/client/auth/gcp/gcp.go index f5d2899d91f..59a27bea7a4 100644 --- a/staging/src/k8s.io/client-go/plugin/pkg/client/auth/gcp/gcp.go +++ b/staging/src/k8s.io/client-go/plugin/pkg/client/auth/gcp/gcp.go @@ -124,7 +124,7 @@ func newGCPAuthProvider(_ string, gcpConfig map[string]string, persister restcli } func (g *gcpAuthProvider) WrapTransport(rt http.RoundTripper) http.RoundTripper { - return &conditionalTransport{&oauth2.Transport{Source: g.tokenSource, Base: rt}} + return &conditionalTransport{&oauth2.Transport{Source: g.tokenSource, Base: rt}, g.persister} } func (g *gcpAuthProvider) Login() error { return nil } @@ -284,11 +284,25 @@ func parseJSONPath(input interface{}, name, template string) (string, error) { type conditionalTransport struct { oauthTransport *oauth2.Transport + persister restclient.AuthProviderConfigPersister } func (t *conditionalTransport) RoundTrip(req *http.Request) (*http.Response, error) { if len(req.Header.Get("Authorization")) != 0 { return t.oauthTransport.Base.RoundTrip(req) } - return t.oauthTransport.RoundTrip(req) + + res, err := t.oauthTransport.RoundTrip(req) + + if err != nil { + return nil, err + } + + if res.StatusCode == 401 { + glog.V(4).Infof("The credentials that were supplied are invalid for the target cluster") + emptyCache := make(map[string]string) + t.persister.Persist(emptyCache) + } + + return res, nil } diff --git a/staging/src/k8s.io/client-go/plugin/pkg/client/auth/gcp/gcp_test.go b/staging/src/k8s.io/client-go/plugin/pkg/client/auth/gcp/gcp_test.go index 755fbcd8f48..662d38b8f19 100644 --- a/staging/src/k8s.io/client-go/plugin/pkg/client/auth/gcp/gcp_test.go +++ b/staging/src/k8s.io/client-go/plugin/pkg/client/auth/gcp/gcp_test.go @@ -18,6 +18,7 @@ package gcp import ( "fmt" + "net/http" "os" "os/exec" "reflect" @@ -323,3 +324,65 @@ func TestCachedTokenSource(t *testing.T) { t.Errorf("got cache %v, want %v", got, cache) } } + +type MockTransport struct { + res *http.Response +} + +func (t *MockTransport) RoundTrip(req *http.Request) (*http.Response, error) { + return t.res, nil +} + +func TestClearingCredentials(t *testing.T) { + + fakeExpiry := time.Now().Add(time.Hour) + + cache := map[string]string{ + "access-token": "fakeToken", + "expiry": fakeExpiry.String(), + } + + cts := cachedTokenSource{ + source: nil, + accessToken: cache["access-token"], + expiry: fakeExpiry, + persister: nil, + cache: nil, + } + + tests := []struct { + name string + res http.Response + cache map[string]string + }{ + { + "Unauthorized", + http.Response{StatusCode: 401}, + make(map[string]string), + }, + { + "Authorized", + http.Response{StatusCode: 200}, + cache, + }, + } + + persister := &fakePersister{} + req := http.Request{Header: http.Header{}} + + for _, tc := range tests { + authProvider := gcpAuthProvider{&cts, persister} + + fakeTransport := MockTransport{&tc.res} + + transport := (authProvider.WrapTransport(&fakeTransport)) + persister.Persist(cache) + + transport.RoundTrip(&req) + + if got := persister.read(); !reflect.DeepEqual(got, tc.cache) { + t.Errorf("got cache %v, want %v", got, tc.cache) + } + } + +}