diff --git a/transport/token_source.go b/transport/token_source.go index f730c397..fea02e61 100644 --- a/transport/token_source.go +++ b/transport/token_source.go @@ -43,9 +43,29 @@ func TokenSourceWrapTransport(ts oauth2.TokenSource) func(http.RoundTripper) htt } } -// NewCachedFileTokenSource returns a oauth2.TokenSource reads a token from a -// file at a specified path and periodically reloads it. -func NewCachedFileTokenSource(path string) oauth2.TokenSource { +type ResettableTokenSource interface { + oauth2.TokenSource + ResetTokenOlderThan(time.Time) +} + +// ResettableTokenSourceWrapTransport returns a WrapTransport that injects bearer tokens +// authentication from an ResettableTokenSource. +func ResettableTokenSourceWrapTransport(ts ResettableTokenSource) func(http.RoundTripper) http.RoundTripper { + return func(rt http.RoundTripper) http.RoundTripper { + return &tokenSourceTransport{ + base: rt, + ort: &oauth2.Transport{ + Source: ts, + Base: rt, + }, + src: ts, + } + } +} + +// NewCachedFileTokenSource returns a resettable token source which reads a +// token from a file at a specified path and periodically reloads it. +func NewCachedFileTokenSource(path string) *cachingTokenSource { return &cachingTokenSource{ now: time.Now, leeway: 10 * time.Second, @@ -60,9 +80,9 @@ func NewCachedFileTokenSource(path string) oauth2.TokenSource { } } -// NewCachedTokenSource returns a oauth2.TokenSource reads a token from a -// designed TokenSource. The ts would provide the source of token. -func NewCachedTokenSource(ts oauth2.TokenSource) oauth2.TokenSource { +// NewCachedTokenSource returns resettable token source with caching. It reads +// a token from a designed TokenSource if not in cache or expired. +func NewCachedTokenSource(ts oauth2.TokenSource) *cachingTokenSource { return &cachingTokenSource{ now: time.Now, base: ts, @@ -72,6 +92,7 @@ func NewCachedTokenSource(ts oauth2.TokenSource) oauth2.TokenSource { type tokenSourceTransport struct { base http.RoundTripper ort http.RoundTripper + src ResettableTokenSource } func (tst *tokenSourceTransport) RoundTrip(req *http.Request) (*http.Response, error) { @@ -79,7 +100,15 @@ func (tst *tokenSourceTransport) RoundTrip(req *http.Request) (*http.Response, e if req.Header.Get("Authorization") != "" { return tst.base.RoundTrip(req) } - return tst.ort.RoundTrip(req) + // record time before RoundTrip to make sure newly acquired Unauthorized + // token would not be reset. Another request from user is required to reset + // and proceed. + start := time.Now() + resp, err := tst.ort.RoundTrip(req) + if err == nil && resp != nil && resp.StatusCode == 401 && tst.src != nil { + tst.src.ResetTokenOlderThan(start) + } + return resp, err } func (tst *tokenSourceTransport) CancelRequest(req *http.Request) { @@ -119,13 +148,12 @@ type cachingTokenSource struct { sync.RWMutex tok *oauth2.Token + t time.Time // for testing now func() time.Time } -var _ = oauth2.TokenSource(&cachingTokenSource{}) - func (ts *cachingTokenSource) Token() (*oauth2.Token, error) { now := ts.now() // fast path @@ -153,6 +181,16 @@ func (ts *cachingTokenSource) Token() (*oauth2.Token, error) { return ts.tok, nil } + ts.t = ts.now() ts.tok = tok return tok, nil } + +func (ts *cachingTokenSource) ResetTokenOlderThan(t time.Time) { + ts.Lock() + defer ts.Unlock() + if ts.t.Before(t) { + ts.tok = nil + ts.t = time.Time{} + } +} diff --git a/transport/token_source_test.go b/transport/token_source_test.go index 7c2ff6e2..2c55b2df 100644 --- a/transport/token_source_test.go +++ b/transport/token_source_test.go @@ -156,6 +156,76 @@ func TestCachingTokenSourceRace(t *testing.T) { } } +func TestTokenSourceTransportRoundTrip(t *testing.T) { + goodToken := &oauth2.Token{ + AccessToken: "good", + Expiry: time.Now().Add(1000 * time.Hour), + } + badToken := &oauth2.Token{ + AccessToken: "bad", + Expiry: time.Now().Add(1000 * time.Hour), + } + tests := []struct { + name string + header http.Header + token *oauth2.Token + cachedToken *oauth2.Token + wantCalls int + wantCaching bool + }{ + { + name: "skip oauth rt if has authorization header", + header: map[string][]string{"Authorization": {"Bearer TOKEN"}}, + token: goodToken, + }, + { + name: "authorized on newly acquired good token", + token: goodToken, + wantCalls: 1, + wantCaching: true, + }, + { + name: "authorized on cached good token", + token: goodToken, + cachedToken: goodToken, + wantCalls: 0, + wantCaching: true, + }, + { + name: "unauthorized on newly acquired bad token", + token: badToken, + wantCalls: 1, + wantCaching: true, + }, + { + name: "unauthorized on cached bad token", + token: badToken, + cachedToken: badToken, + wantCalls: 0, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + tts := &testTokenSource{ + tok: test.token, + } + cachedTokenSource := NewCachedTokenSource(tts) + cachedTokenSource.tok = test.cachedToken + + rt := ResettableTokenSourceWrapTransport(cachedTokenSource)(&testTransport{}) + + rt.RoundTrip(&http.Request{Header: test.header}) + if tts.calls != test.wantCalls { + t.Errorf("RoundTrip() called Token() = %d times, want %d", tts.calls, test.wantCalls) + } + + if (cachedTokenSource.tok != nil) != test.wantCaching { + t.Errorf("Got caching %v, want caching %v", cachedTokenSource != nil, test.wantCaching) + } + }) + } +} + type uncancellableRT struct { rt http.RoundTripper } @@ -164,7 +234,7 @@ func (urt *uncancellableRT) RoundTrip(req *http.Request) (*http.Response, error) return urt.rt.RoundTrip(req) } -func TestCancellation(t *testing.T) { +func TestTokenSourceTransportCancelRequest(t *testing.T) { tests := []struct { name string header http.Header @@ -186,7 +256,7 @@ func TestCancellation(t *testing.T) { } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - baseRecorder := &recordCancelRoundTripper{} + baseRecorder := &testTransport{} var base http.RoundTripper = baseRecorder if test.wrapTransport != nil { @@ -211,16 +281,19 @@ func TestCancellation(t *testing.T) { } } -type recordCancelRoundTripper struct { +type testTransport struct { canceled bool base http.RoundTripper } -func (rt *recordCancelRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { +func (rt *testTransport) RoundTrip(req *http.Request) (*http.Response, error) { + if req.Header["Authorization"][0] == "Bearer bad" { + return &http.Response{StatusCode: 401}, nil + } return nil, nil } -func (rt *recordCancelRoundTripper) CancelRequest(req *http.Request) { +func (rt *testTransport) CancelRequest(req *http.Request) { rt.canceled = true if rt.base != nil { tryCancelRequest(rt.base, req)