mirror of
https://github.com/kubernetes/client-go.git
synced 2025-07-10 13:43:37 +00:00
reset token if got Unauthorized in KCM
Kubernetes-commit: 5e7b60ba5fe218d8ac59496350dfdb9f43785d98
This commit is contained in:
parent
c3c3d406ff
commit
3d6ec322f2
@ -43,9 +43,29 @@ func TokenSourceWrapTransport(ts oauth2.TokenSource) func(http.RoundTripper) htt
|
||||
}
|
||||
}
|
||||
|
||||
// NewCachedFileTokenSource returns a oauth2.TokenSource reads a token from a
|
||||
// file at a specified path and periodically reloads it.
|
||||
func NewCachedFileTokenSource(path string) oauth2.TokenSource {
|
||||
type ResettableTokenSource interface {
|
||||
oauth2.TokenSource
|
||||
ResetTokenOlderThan(time.Time)
|
||||
}
|
||||
|
||||
// ResettableTokenSourceWrapTransport returns a WrapTransport that injects bearer tokens
|
||||
// authentication from an ResettableTokenSource.
|
||||
func ResettableTokenSourceWrapTransport(ts ResettableTokenSource) func(http.RoundTripper) http.RoundTripper {
|
||||
return func(rt http.RoundTripper) http.RoundTripper {
|
||||
return &tokenSourceTransport{
|
||||
base: rt,
|
||||
ort: &oauth2.Transport{
|
||||
Source: ts,
|
||||
Base: rt,
|
||||
},
|
||||
src: ts,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// NewCachedFileTokenSource returns a resettable token source which reads a
|
||||
// token from a file at a specified path and periodically reloads it.
|
||||
func NewCachedFileTokenSource(path string) *cachingTokenSource {
|
||||
return &cachingTokenSource{
|
||||
now: time.Now,
|
||||
leeway: 10 * time.Second,
|
||||
@ -60,9 +80,9 @@ func NewCachedFileTokenSource(path string) oauth2.TokenSource {
|
||||
}
|
||||
}
|
||||
|
||||
// NewCachedTokenSource returns a oauth2.TokenSource reads a token from a
|
||||
// designed TokenSource. The ts would provide the source of token.
|
||||
func NewCachedTokenSource(ts oauth2.TokenSource) oauth2.TokenSource {
|
||||
// NewCachedTokenSource returns resettable token source with caching. It reads
|
||||
// a token from a designed TokenSource if not in cache or expired.
|
||||
func NewCachedTokenSource(ts oauth2.TokenSource) *cachingTokenSource {
|
||||
return &cachingTokenSource{
|
||||
now: time.Now,
|
||||
base: ts,
|
||||
@ -72,6 +92,7 @@ func NewCachedTokenSource(ts oauth2.TokenSource) oauth2.TokenSource {
|
||||
type tokenSourceTransport struct {
|
||||
base http.RoundTripper
|
||||
ort http.RoundTripper
|
||||
src ResettableTokenSource
|
||||
}
|
||||
|
||||
func (tst *tokenSourceTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
@ -79,7 +100,15 @@ func (tst *tokenSourceTransport) RoundTrip(req *http.Request) (*http.Response, e
|
||||
if req.Header.Get("Authorization") != "" {
|
||||
return tst.base.RoundTrip(req)
|
||||
}
|
||||
return tst.ort.RoundTrip(req)
|
||||
// record time before RoundTrip to make sure newly acquired Unauthorized
|
||||
// token would not be reset. Another request from user is required to reset
|
||||
// and proceed.
|
||||
start := time.Now()
|
||||
resp, err := tst.ort.RoundTrip(req)
|
||||
if err == nil && resp != nil && resp.StatusCode == 401 && tst.src != nil {
|
||||
tst.src.ResetTokenOlderThan(start)
|
||||
}
|
||||
return resp, err
|
||||
}
|
||||
|
||||
func (tst *tokenSourceTransport) CancelRequest(req *http.Request) {
|
||||
@ -119,13 +148,12 @@ type cachingTokenSource struct {
|
||||
|
||||
sync.RWMutex
|
||||
tok *oauth2.Token
|
||||
t time.Time
|
||||
|
||||
// for testing
|
||||
now func() time.Time
|
||||
}
|
||||
|
||||
var _ = oauth2.TokenSource(&cachingTokenSource{})
|
||||
|
||||
func (ts *cachingTokenSource) Token() (*oauth2.Token, error) {
|
||||
now := ts.now()
|
||||
// fast path
|
||||
@ -153,6 +181,16 @@ func (ts *cachingTokenSource) Token() (*oauth2.Token, error) {
|
||||
return ts.tok, nil
|
||||
}
|
||||
|
||||
ts.t = ts.now()
|
||||
ts.tok = tok
|
||||
return tok, nil
|
||||
}
|
||||
|
||||
func (ts *cachingTokenSource) ResetTokenOlderThan(t time.Time) {
|
||||
ts.Lock()
|
||||
defer ts.Unlock()
|
||||
if ts.t.Before(t) {
|
||||
ts.tok = nil
|
||||
ts.t = time.Time{}
|
||||
}
|
||||
}
|
||||
|
@ -156,6 +156,76 @@ func TestCachingTokenSourceRace(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenSourceTransportRoundTrip(t *testing.T) {
|
||||
goodToken := &oauth2.Token{
|
||||
AccessToken: "good",
|
||||
Expiry: time.Now().Add(1000 * time.Hour),
|
||||
}
|
||||
badToken := &oauth2.Token{
|
||||
AccessToken: "bad",
|
||||
Expiry: time.Now().Add(1000 * time.Hour),
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
header http.Header
|
||||
token *oauth2.Token
|
||||
cachedToken *oauth2.Token
|
||||
wantCalls int
|
||||
wantCaching bool
|
||||
}{
|
||||
{
|
||||
name: "skip oauth rt if has authorization header",
|
||||
header: map[string][]string{"Authorization": {"Bearer TOKEN"}},
|
||||
token: goodToken,
|
||||
},
|
||||
{
|
||||
name: "authorized on newly acquired good token",
|
||||
token: goodToken,
|
||||
wantCalls: 1,
|
||||
wantCaching: true,
|
||||
},
|
||||
{
|
||||
name: "authorized on cached good token",
|
||||
token: goodToken,
|
||||
cachedToken: goodToken,
|
||||
wantCalls: 0,
|
||||
wantCaching: true,
|
||||
},
|
||||
{
|
||||
name: "unauthorized on newly acquired bad token",
|
||||
token: badToken,
|
||||
wantCalls: 1,
|
||||
wantCaching: true,
|
||||
},
|
||||
{
|
||||
name: "unauthorized on cached bad token",
|
||||
token: badToken,
|
||||
cachedToken: badToken,
|
||||
wantCalls: 0,
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
tts := &testTokenSource{
|
||||
tok: test.token,
|
||||
}
|
||||
cachedTokenSource := NewCachedTokenSource(tts)
|
||||
cachedTokenSource.tok = test.cachedToken
|
||||
|
||||
rt := ResettableTokenSourceWrapTransport(cachedTokenSource)(&testTransport{})
|
||||
|
||||
rt.RoundTrip(&http.Request{Header: test.header})
|
||||
if tts.calls != test.wantCalls {
|
||||
t.Errorf("RoundTrip() called Token() = %d times, want %d", tts.calls, test.wantCalls)
|
||||
}
|
||||
|
||||
if (cachedTokenSource.tok != nil) != test.wantCaching {
|
||||
t.Errorf("Got caching %v, want caching %v", cachedTokenSource != nil, test.wantCaching)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type uncancellableRT struct {
|
||||
rt http.RoundTripper
|
||||
}
|
||||
@ -164,7 +234,7 @@ func (urt *uncancellableRT) RoundTrip(req *http.Request) (*http.Response, error)
|
||||
return urt.rt.RoundTrip(req)
|
||||
}
|
||||
|
||||
func TestCancellation(t *testing.T) {
|
||||
func TestTokenSourceTransportCancelRequest(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
header http.Header
|
||||
@ -186,7 +256,7 @@ func TestCancellation(t *testing.T) {
|
||||
}
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
baseRecorder := &recordCancelRoundTripper{}
|
||||
baseRecorder := &testTransport{}
|
||||
|
||||
var base http.RoundTripper = baseRecorder
|
||||
if test.wrapTransport != nil {
|
||||
@ -211,16 +281,19 @@ func TestCancellation(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
type recordCancelRoundTripper struct {
|
||||
type testTransport struct {
|
||||
canceled bool
|
||||
base http.RoundTripper
|
||||
}
|
||||
|
||||
func (rt *recordCancelRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
func (rt *testTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
if req.Header["Authorization"][0] == "Bearer bad" {
|
||||
return &http.Response{StatusCode: 401}, nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (rt *recordCancelRoundTripper) CancelRequest(req *http.Request) {
|
||||
func (rt *testTransport) CancelRequest(req *http.Request) {
|
||||
rt.canceled = true
|
||||
if rt.base != nil {
|
||||
tryCancelRequest(rt.base, req)
|
||||
|
Loading…
Reference in New Issue
Block a user