mirror of
https://github.com/k3s-io/kubernetes.git
synced 2025-07-30 15:05:27 +00:00
gcp client auth plugin: persist default cache on unauthorized
The default cache for a cachedTokenSource is not always empty. In the case of commandTokenSource, it contains calling details for the external command that is used to generate refresh tokens. Persisting a completely empty cache will thus break ability for the plugin to obtain refresh tokens. This changes the roundtripper to persist the default cache instead of assuming an empty map.
This commit is contained in:
parent
25cbd1c753
commit
73e5e43711
@ -174,7 +174,13 @@ func parseScopes(gcpConfig map[string]string) []string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (g *gcpAuthProvider) WrapTransport(rt http.RoundTripper) http.RoundTripper {
|
func (g *gcpAuthProvider) WrapTransport(rt http.RoundTripper) http.RoundTripper {
|
||||||
return &conditionalTransport{&oauth2.Transport{Source: g.tokenSource, Base: rt}, g.persister}
|
var resetCache map[string]string
|
||||||
|
if cts, ok := g.tokenSource.(*cachedTokenSource); ok {
|
||||||
|
resetCache = cts.baseCache()
|
||||||
|
} else {
|
||||||
|
resetCache = make(map[string]string)
|
||||||
|
}
|
||||||
|
return &conditionalTransport{&oauth2.Transport{Source: g.tokenSource, Base: rt}, g.persister, resetCache}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *gcpAuthProvider) Login() error { return nil }
|
func (g *gcpAuthProvider) Login() error { return nil }
|
||||||
@ -247,6 +253,19 @@ func (t *cachedTokenSource) update(tok *oauth2.Token) map[string]string {
|
|||||||
return ret
|
return ret
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// baseCache is the base configuration value for this TokenSource, without any cached ephemeral tokens.
|
||||||
|
func (t *cachedTokenSource) baseCache() map[string]string {
|
||||||
|
t.lk.Lock()
|
||||||
|
defer t.lk.Unlock()
|
||||||
|
ret := map[string]string{}
|
||||||
|
for k, v := range t.cache {
|
||||||
|
ret[k] = v
|
||||||
|
}
|
||||||
|
delete(ret, "access-token")
|
||||||
|
delete(ret, "expiry")
|
||||||
|
return ret
|
||||||
|
}
|
||||||
|
|
||||||
type commandTokenSource struct {
|
type commandTokenSource struct {
|
||||||
cmd string
|
cmd string
|
||||||
args []string
|
args []string
|
||||||
@ -337,6 +356,7 @@ func parseJSONPath(input interface{}, name, template string) (string, error) {
|
|||||||
type conditionalTransport struct {
|
type conditionalTransport struct {
|
||||||
oauthTransport *oauth2.Transport
|
oauthTransport *oauth2.Transport
|
||||||
persister restclient.AuthProviderConfigPersister
|
persister restclient.AuthProviderConfigPersister
|
||||||
|
resetCache map[string]string
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ net.RoundTripperWrapper = &conditionalTransport{}
|
var _ net.RoundTripperWrapper = &conditionalTransport{}
|
||||||
@ -354,8 +374,7 @@ func (t *conditionalTransport) RoundTrip(req *http.Request) (*http.Response, err
|
|||||||
|
|
||||||
if res.StatusCode == 401 {
|
if res.StatusCode == 401 {
|
||||||
glog.V(4).Infof("The credentials that were supplied are invalid for the target cluster")
|
glog.V(4).Infof("The credentials that were supplied are invalid for the target cluster")
|
||||||
emptyCache := make(map[string]string)
|
t.persister.Persist(t.resetCache)
|
||||||
t.persister.Persist(emptyCache)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return res, nil
|
return res, nil
|
||||||
|
@ -442,37 +442,61 @@ func (t *MockTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
|||||||
return t.res, nil
|
return t.res, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestClearingCredentials(t *testing.T) {
|
func Test_cmdTokenSource_roundTrip(t *testing.T) {
|
||||||
|
|
||||||
|
accessToken := "fakeToken"
|
||||||
fakeExpiry := time.Now().Add(time.Hour)
|
fakeExpiry := time.Now().Add(time.Hour)
|
||||||
|
fakeExpiryStr := fakeExpiry.Format(time.RFC3339Nano)
|
||||||
cache := map[string]string{
|
fs := &fakeTokenSource{
|
||||||
"access-token": "fakeToken",
|
token: &oauth2.Token{
|
||||||
"expiry": fakeExpiry.String(),
|
AccessToken: accessToken,
|
||||||
|
Expiry: fakeExpiry,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
cts := cachedTokenSource{
|
cmdCache := map[string]string{
|
||||||
source: nil,
|
"cmd-path": "/path/to/tokensource/cmd",
|
||||||
accessToken: cache["access-token"],
|
"cmd-args": "--output=json",
|
||||||
expiry: fakeExpiry,
|
}
|
||||||
persister: nil,
|
cmdCacheUpdated := map[string]string{
|
||||||
cache: nil,
|
"cmd-path": "/path/to/tokensource/cmd",
|
||||||
|
"cmd-args": "--output=json",
|
||||||
|
"access-token": accessToken,
|
||||||
|
"expiry": fakeExpiryStr,
|
||||||
|
}
|
||||||
|
simpleCacheUpdated := map[string]string{
|
||||||
|
"access-token": accessToken,
|
||||||
|
"expiry": fakeExpiryStr,
|
||||||
}
|
}
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
res http.Response
|
res http.Response
|
||||||
cache map[string]string
|
baseCache, expectedCache map[string]string
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
"Unauthorized",
|
"Unauthorized",
|
||||||
http.Response{StatusCode: 401},
|
http.Response{StatusCode: 401},
|
||||||
make(map[string]string),
|
make(map[string]string),
|
||||||
|
make(map[string]string),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"Unauthorized, nonempty defaultCache",
|
||||||
|
http.Response{StatusCode: 401},
|
||||||
|
cmdCache,
|
||||||
|
cmdCache,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"Authorized",
|
"Authorized",
|
||||||
http.Response{StatusCode: 200},
|
http.Response{StatusCode: 200},
|
||||||
cache,
|
make(map[string]string),
|
||||||
|
simpleCacheUpdated,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"Authorized, nonempty defaultCache",
|
||||||
|
http.Response{StatusCode: 200},
|
||||||
|
cmdCache,
|
||||||
|
cmdCacheUpdated,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -480,17 +504,23 @@ func TestClearingCredentials(t *testing.T) {
|
|||||||
req := http.Request{Header: http.Header{}}
|
req := http.Request{Header: http.Header{}}
|
||||||
|
|
||||||
for _, tc := range tests {
|
for _, tc := range tests {
|
||||||
authProvider := gcpAuthProvider{&cts, persister}
|
cts, err := newCachedTokenSource(accessToken, fakeExpiry.String(), persister, fs, tc.baseCache)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error from newCachedTokenSource: %v", err)
|
||||||
|
}
|
||||||
|
authProvider := gcpAuthProvider{cts, persister}
|
||||||
|
|
||||||
fakeTransport := MockTransport{&tc.res}
|
fakeTransport := MockTransport{&tc.res}
|
||||||
|
|
||||||
transport := (authProvider.WrapTransport(&fakeTransport))
|
transport := (authProvider.WrapTransport(&fakeTransport))
|
||||||
persister.Persist(cache)
|
// call Token to persist/update cache
|
||||||
|
if _, err := cts.Token(); err != nil {
|
||||||
|
t.Fatalf("unexpected error from cachedTokenSource.Token(): %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
transport.RoundTrip(&req)
|
transport.RoundTrip(&req)
|
||||||
|
|
||||||
if got := persister.read(); !reflect.DeepEqual(got, tc.cache) {
|
if got := persister.read(); !reflect.DeepEqual(got, tc.expectedCache) {
|
||||||
t.Errorf("got cache %v, want %v", got, tc.cache)
|
t.Errorf("got cache %v, want %v", got, tc.expectedCache)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user