reset token if got Unauthorized in KCM

Kubernetes-commit: 5e7b60ba5fe218d8ac59496350dfdb9f43785d98
This commit is contained in:
Shihang Zhang 2021-02-22 15:38:18 -08:00 committed by Kubernetes Publisher
parent c3c3d406ff
commit 3d6ec322f2
2 changed files with 125 additions and 14 deletions

View File

@ -43,9 +43,29 @@ func TokenSourceWrapTransport(ts oauth2.TokenSource) func(http.RoundTripper) htt
} }
} }
// NewCachedFileTokenSource returns a oauth2.TokenSource reads a token from a type ResettableTokenSource interface {
// file at a specified path and periodically reloads it. oauth2.TokenSource
func NewCachedFileTokenSource(path string) oauth2.TokenSource { ResetTokenOlderThan(time.Time)
}
// ResettableTokenSourceWrapTransport returns a WrapTransport that injects bearer tokens
// authentication from an ResettableTokenSource.
func ResettableTokenSourceWrapTransport(ts ResettableTokenSource) func(http.RoundTripper) http.RoundTripper {
return func(rt http.RoundTripper) http.RoundTripper {
return &tokenSourceTransport{
base: rt,
ort: &oauth2.Transport{
Source: ts,
Base: rt,
},
src: ts,
}
}
}
// NewCachedFileTokenSource returns a resettable token source which reads a
// token from a file at a specified path and periodically reloads it.
func NewCachedFileTokenSource(path string) *cachingTokenSource {
return &cachingTokenSource{ return &cachingTokenSource{
now: time.Now, now: time.Now,
leeway: 10 * time.Second, leeway: 10 * time.Second,
@ -60,9 +80,9 @@ func NewCachedFileTokenSource(path string) oauth2.TokenSource {
} }
} }
// NewCachedTokenSource returns a oauth2.TokenSource reads a token from a // NewCachedTokenSource returns resettable token source with caching. It reads
// designed TokenSource. The ts would provide the source of token. // a token from a designed TokenSource if not in cache or expired.
func NewCachedTokenSource(ts oauth2.TokenSource) oauth2.TokenSource { func NewCachedTokenSource(ts oauth2.TokenSource) *cachingTokenSource {
return &cachingTokenSource{ return &cachingTokenSource{
now: time.Now, now: time.Now,
base: ts, base: ts,
@ -72,6 +92,7 @@ func NewCachedTokenSource(ts oauth2.TokenSource) oauth2.TokenSource {
type tokenSourceTransport struct { type tokenSourceTransport struct {
base http.RoundTripper base http.RoundTripper
ort http.RoundTripper ort http.RoundTripper
src ResettableTokenSource
} }
func (tst *tokenSourceTransport) RoundTrip(req *http.Request) (*http.Response, error) { func (tst *tokenSourceTransport) RoundTrip(req *http.Request) (*http.Response, error) {
@ -79,7 +100,15 @@ func (tst *tokenSourceTransport) RoundTrip(req *http.Request) (*http.Response, e
if req.Header.Get("Authorization") != "" { if req.Header.Get("Authorization") != "" {
return tst.base.RoundTrip(req) return tst.base.RoundTrip(req)
} }
return tst.ort.RoundTrip(req) // record time before RoundTrip to make sure newly acquired Unauthorized
// token would not be reset. Another request from user is required to reset
// and proceed.
start := time.Now()
resp, err := tst.ort.RoundTrip(req)
if err == nil && resp != nil && resp.StatusCode == 401 && tst.src != nil {
tst.src.ResetTokenOlderThan(start)
}
return resp, err
} }
func (tst *tokenSourceTransport) CancelRequest(req *http.Request) { func (tst *tokenSourceTransport) CancelRequest(req *http.Request) {
@ -119,13 +148,12 @@ type cachingTokenSource struct {
sync.RWMutex sync.RWMutex
tok *oauth2.Token tok *oauth2.Token
t time.Time
// for testing // for testing
now func() time.Time now func() time.Time
} }
var _ = oauth2.TokenSource(&cachingTokenSource{})
func (ts *cachingTokenSource) Token() (*oauth2.Token, error) { func (ts *cachingTokenSource) Token() (*oauth2.Token, error) {
now := ts.now() now := ts.now()
// fast path // fast path
@ -153,6 +181,16 @@ func (ts *cachingTokenSource) Token() (*oauth2.Token, error) {
return ts.tok, nil return ts.tok, nil
} }
ts.t = ts.now()
ts.tok = tok ts.tok = tok
return tok, nil return tok, nil
} }
func (ts *cachingTokenSource) ResetTokenOlderThan(t time.Time) {
ts.Lock()
defer ts.Unlock()
if ts.t.Before(t) {
ts.tok = nil
ts.t = time.Time{}
}
}

View File

@ -156,6 +156,76 @@ func TestCachingTokenSourceRace(t *testing.T) {
} }
} }
func TestTokenSourceTransportRoundTrip(t *testing.T) {
goodToken := &oauth2.Token{
AccessToken: "good",
Expiry: time.Now().Add(1000 * time.Hour),
}
badToken := &oauth2.Token{
AccessToken: "bad",
Expiry: time.Now().Add(1000 * time.Hour),
}
tests := []struct {
name string
header http.Header
token *oauth2.Token
cachedToken *oauth2.Token
wantCalls int
wantCaching bool
}{
{
name: "skip oauth rt if has authorization header",
header: map[string][]string{"Authorization": {"Bearer TOKEN"}},
token: goodToken,
},
{
name: "authorized on newly acquired good token",
token: goodToken,
wantCalls: 1,
wantCaching: true,
},
{
name: "authorized on cached good token",
token: goodToken,
cachedToken: goodToken,
wantCalls: 0,
wantCaching: true,
},
{
name: "unauthorized on newly acquired bad token",
token: badToken,
wantCalls: 1,
wantCaching: true,
},
{
name: "unauthorized on cached bad token",
token: badToken,
cachedToken: badToken,
wantCalls: 0,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
tts := &testTokenSource{
tok: test.token,
}
cachedTokenSource := NewCachedTokenSource(tts)
cachedTokenSource.tok = test.cachedToken
rt := ResettableTokenSourceWrapTransport(cachedTokenSource)(&testTransport{})
rt.RoundTrip(&http.Request{Header: test.header})
if tts.calls != test.wantCalls {
t.Errorf("RoundTrip() called Token() = %d times, want %d", tts.calls, test.wantCalls)
}
if (cachedTokenSource.tok != nil) != test.wantCaching {
t.Errorf("Got caching %v, want caching %v", cachedTokenSource != nil, test.wantCaching)
}
})
}
}
type uncancellableRT struct { type uncancellableRT struct {
rt http.RoundTripper rt http.RoundTripper
} }
@ -164,7 +234,7 @@ func (urt *uncancellableRT) RoundTrip(req *http.Request) (*http.Response, error)
return urt.rt.RoundTrip(req) return urt.rt.RoundTrip(req)
} }
func TestCancellation(t *testing.T) { func TestTokenSourceTransportCancelRequest(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
header http.Header header http.Header
@ -186,7 +256,7 @@ func TestCancellation(t *testing.T) {
} }
for _, test := range tests { for _, test := range tests {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
baseRecorder := &recordCancelRoundTripper{} baseRecorder := &testTransport{}
var base http.RoundTripper = baseRecorder var base http.RoundTripper = baseRecorder
if test.wrapTransport != nil { if test.wrapTransport != nil {
@ -211,16 +281,19 @@ func TestCancellation(t *testing.T) {
} }
} }
type recordCancelRoundTripper struct { type testTransport struct {
canceled bool canceled bool
base http.RoundTripper base http.RoundTripper
} }
func (rt *recordCancelRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { func (rt *testTransport) RoundTrip(req *http.Request) (*http.Response, error) {
if req.Header["Authorization"][0] == "Bearer bad" {
return &http.Response{StatusCode: 401}, nil
}
return nil, nil return nil, nil
} }
func (rt *recordCancelRoundTripper) CancelRequest(req *http.Request) { func (rt *testTransport) CancelRequest(req *http.Request) {
rt.canceled = true rt.canceled = true
if rt.base != nil { if rt.base != nil {
tryCancelRequest(rt.base, req) tryCancelRequest(rt.base, req)