reset token if got Unauthorized in KCM

Kubernetes-commit: 5e7b60ba5fe218d8ac59496350dfdb9f43785d98
This commit is contained in:
Shihang Zhang 2021-02-22 15:38:18 -08:00 committed by Kubernetes Publisher
parent c3c3d406ff
commit 3d6ec322f2
2 changed files with 125 additions and 14 deletions

View File

@ -43,9 +43,29 @@ func TokenSourceWrapTransport(ts oauth2.TokenSource) func(http.RoundTripper) htt
}
}
// NewCachedFileTokenSource returns a oauth2.TokenSource reads a token from a
// file at a specified path and periodically reloads it.
func NewCachedFileTokenSource(path string) oauth2.TokenSource {
type ResettableTokenSource interface {
oauth2.TokenSource
ResetTokenOlderThan(time.Time)
}
// ResettableTokenSourceWrapTransport returns a WrapTransport that injects bearer tokens
// authentication from an ResettableTokenSource.
func ResettableTokenSourceWrapTransport(ts ResettableTokenSource) func(http.RoundTripper) http.RoundTripper {
return func(rt http.RoundTripper) http.RoundTripper {
return &tokenSourceTransport{
base: rt,
ort: &oauth2.Transport{
Source: ts,
Base: rt,
},
src: ts,
}
}
}
// NewCachedFileTokenSource returns a resettable token source which reads a
// token from a file at a specified path and periodically reloads it.
func NewCachedFileTokenSource(path string) *cachingTokenSource {
return &cachingTokenSource{
now: time.Now,
leeway: 10 * time.Second,
@ -60,9 +80,9 @@ func NewCachedFileTokenSource(path string) oauth2.TokenSource {
}
}
// NewCachedTokenSource returns a oauth2.TokenSource reads a token from a
// designed TokenSource. The ts would provide the source of token.
func NewCachedTokenSource(ts oauth2.TokenSource) oauth2.TokenSource {
// NewCachedTokenSource returns resettable token source with caching. It reads
// a token from a designed TokenSource if not in cache or expired.
func NewCachedTokenSource(ts oauth2.TokenSource) *cachingTokenSource {
return &cachingTokenSource{
now: time.Now,
base: ts,
@ -72,6 +92,7 @@ func NewCachedTokenSource(ts oauth2.TokenSource) oauth2.TokenSource {
type tokenSourceTransport struct {
base http.RoundTripper
ort http.RoundTripper
src ResettableTokenSource
}
func (tst *tokenSourceTransport) RoundTrip(req *http.Request) (*http.Response, error) {
@ -79,7 +100,15 @@ func (tst *tokenSourceTransport) RoundTrip(req *http.Request) (*http.Response, e
if req.Header.Get("Authorization") != "" {
return tst.base.RoundTrip(req)
}
return tst.ort.RoundTrip(req)
// record time before RoundTrip to make sure newly acquired Unauthorized
// token would not be reset. Another request from user is required to reset
// and proceed.
start := time.Now()
resp, err := tst.ort.RoundTrip(req)
if err == nil && resp != nil && resp.StatusCode == 401 && tst.src != nil {
tst.src.ResetTokenOlderThan(start)
}
return resp, err
}
func (tst *tokenSourceTransport) CancelRequest(req *http.Request) {
@ -119,13 +148,12 @@ type cachingTokenSource struct {
sync.RWMutex
tok *oauth2.Token
t time.Time
// for testing
now func() time.Time
}
var _ = oauth2.TokenSource(&cachingTokenSource{})
func (ts *cachingTokenSource) Token() (*oauth2.Token, error) {
now := ts.now()
// fast path
@ -153,6 +181,16 @@ func (ts *cachingTokenSource) Token() (*oauth2.Token, error) {
return ts.tok, nil
}
ts.t = ts.now()
ts.tok = tok
return tok, nil
}
func (ts *cachingTokenSource) ResetTokenOlderThan(t time.Time) {
ts.Lock()
defer ts.Unlock()
if ts.t.Before(t) {
ts.tok = nil
ts.t = time.Time{}
}
}

View File

@ -156,6 +156,76 @@ func TestCachingTokenSourceRace(t *testing.T) {
}
}
func TestTokenSourceTransportRoundTrip(t *testing.T) {
goodToken := &oauth2.Token{
AccessToken: "good",
Expiry: time.Now().Add(1000 * time.Hour),
}
badToken := &oauth2.Token{
AccessToken: "bad",
Expiry: time.Now().Add(1000 * time.Hour),
}
tests := []struct {
name string
header http.Header
token *oauth2.Token
cachedToken *oauth2.Token
wantCalls int
wantCaching bool
}{
{
name: "skip oauth rt if has authorization header",
header: map[string][]string{"Authorization": {"Bearer TOKEN"}},
token: goodToken,
},
{
name: "authorized on newly acquired good token",
token: goodToken,
wantCalls: 1,
wantCaching: true,
},
{
name: "authorized on cached good token",
token: goodToken,
cachedToken: goodToken,
wantCalls: 0,
wantCaching: true,
},
{
name: "unauthorized on newly acquired bad token",
token: badToken,
wantCalls: 1,
wantCaching: true,
},
{
name: "unauthorized on cached bad token",
token: badToken,
cachedToken: badToken,
wantCalls: 0,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
tts := &testTokenSource{
tok: test.token,
}
cachedTokenSource := NewCachedTokenSource(tts)
cachedTokenSource.tok = test.cachedToken
rt := ResettableTokenSourceWrapTransport(cachedTokenSource)(&testTransport{})
rt.RoundTrip(&http.Request{Header: test.header})
if tts.calls != test.wantCalls {
t.Errorf("RoundTrip() called Token() = %d times, want %d", tts.calls, test.wantCalls)
}
if (cachedTokenSource.tok != nil) != test.wantCaching {
t.Errorf("Got caching %v, want caching %v", cachedTokenSource != nil, test.wantCaching)
}
})
}
}
type uncancellableRT struct {
rt http.RoundTripper
}
@ -164,7 +234,7 @@ func (urt *uncancellableRT) RoundTrip(req *http.Request) (*http.Response, error)
return urt.rt.RoundTrip(req)
}
func TestCancellation(t *testing.T) {
func TestTokenSourceTransportCancelRequest(t *testing.T) {
tests := []struct {
name string
header http.Header
@ -186,7 +256,7 @@ func TestCancellation(t *testing.T) {
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
baseRecorder := &recordCancelRoundTripper{}
baseRecorder := &testTransport{}
var base http.RoundTripper = baseRecorder
if test.wrapTransport != nil {
@ -211,16 +281,19 @@ func TestCancellation(t *testing.T) {
}
}
type recordCancelRoundTripper struct {
type testTransport struct {
canceled bool
base http.RoundTripper
}
func (rt *recordCancelRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
func (rt *testTransport) RoundTrip(req *http.Request) (*http.Response, error) {
if req.Header["Authorization"][0] == "Bearer bad" {
return &http.Response{StatusCode: 401}, nil
}
return nil, nil
}
func (rt *recordCancelRoundTripper) CancelRequest(req *http.Request) {
func (rt *testTransport) CancelRequest(req *http.Request) {
rt.canceled = true
if rt.base != nil {
tryCancelRequest(rt.base, req)