mirror of
https://github.com/kubernetes/client-go.git
synced 2025-07-10 21:54:01 +00:00
reset token if got Unauthorized in KCM
Kubernetes-commit: 5e7b60ba5fe218d8ac59496350dfdb9f43785d98
This commit is contained in:
parent
c3c3d406ff
commit
3d6ec322f2
@ -43,9 +43,29 @@ func TokenSourceWrapTransport(ts oauth2.TokenSource) func(http.RoundTripper) htt
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewCachedFileTokenSource returns a oauth2.TokenSource reads a token from a
|
type ResettableTokenSource interface {
|
||||||
// file at a specified path and periodically reloads it.
|
oauth2.TokenSource
|
||||||
func NewCachedFileTokenSource(path string) oauth2.TokenSource {
|
ResetTokenOlderThan(time.Time)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResettableTokenSourceWrapTransport returns a WrapTransport that injects bearer tokens
|
||||||
|
// authentication from an ResettableTokenSource.
|
||||||
|
func ResettableTokenSourceWrapTransport(ts ResettableTokenSource) func(http.RoundTripper) http.RoundTripper {
|
||||||
|
return func(rt http.RoundTripper) http.RoundTripper {
|
||||||
|
return &tokenSourceTransport{
|
||||||
|
base: rt,
|
||||||
|
ort: &oauth2.Transport{
|
||||||
|
Source: ts,
|
||||||
|
Base: rt,
|
||||||
|
},
|
||||||
|
src: ts,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewCachedFileTokenSource returns a resettable token source which reads a
|
||||||
|
// token from a file at a specified path and periodically reloads it.
|
||||||
|
func NewCachedFileTokenSource(path string) *cachingTokenSource {
|
||||||
return &cachingTokenSource{
|
return &cachingTokenSource{
|
||||||
now: time.Now,
|
now: time.Now,
|
||||||
leeway: 10 * time.Second,
|
leeway: 10 * time.Second,
|
||||||
@ -60,9 +80,9 @@ func NewCachedFileTokenSource(path string) oauth2.TokenSource {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewCachedTokenSource returns a oauth2.TokenSource reads a token from a
|
// NewCachedTokenSource returns resettable token source with caching. It reads
|
||||||
// designed TokenSource. The ts would provide the source of token.
|
// a token from a designed TokenSource if not in cache or expired.
|
||||||
func NewCachedTokenSource(ts oauth2.TokenSource) oauth2.TokenSource {
|
func NewCachedTokenSource(ts oauth2.TokenSource) *cachingTokenSource {
|
||||||
return &cachingTokenSource{
|
return &cachingTokenSource{
|
||||||
now: time.Now,
|
now: time.Now,
|
||||||
base: ts,
|
base: ts,
|
||||||
@ -72,6 +92,7 @@ func NewCachedTokenSource(ts oauth2.TokenSource) oauth2.TokenSource {
|
|||||||
type tokenSourceTransport struct {
|
type tokenSourceTransport struct {
|
||||||
base http.RoundTripper
|
base http.RoundTripper
|
||||||
ort http.RoundTripper
|
ort http.RoundTripper
|
||||||
|
src ResettableTokenSource
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tst *tokenSourceTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
func (tst *tokenSourceTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
@ -79,7 +100,15 @@ func (tst *tokenSourceTransport) RoundTrip(req *http.Request) (*http.Response, e
|
|||||||
if req.Header.Get("Authorization") != "" {
|
if req.Header.Get("Authorization") != "" {
|
||||||
return tst.base.RoundTrip(req)
|
return tst.base.RoundTrip(req)
|
||||||
}
|
}
|
||||||
return tst.ort.RoundTrip(req)
|
// record time before RoundTrip to make sure newly acquired Unauthorized
|
||||||
|
// token would not be reset. Another request from user is required to reset
|
||||||
|
// and proceed.
|
||||||
|
start := time.Now()
|
||||||
|
resp, err := tst.ort.RoundTrip(req)
|
||||||
|
if err == nil && resp != nil && resp.StatusCode == 401 && tst.src != nil {
|
||||||
|
tst.src.ResetTokenOlderThan(start)
|
||||||
|
}
|
||||||
|
return resp, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tst *tokenSourceTransport) CancelRequest(req *http.Request) {
|
func (tst *tokenSourceTransport) CancelRequest(req *http.Request) {
|
||||||
@ -119,13 +148,12 @@ type cachingTokenSource struct {
|
|||||||
|
|
||||||
sync.RWMutex
|
sync.RWMutex
|
||||||
tok *oauth2.Token
|
tok *oauth2.Token
|
||||||
|
t time.Time
|
||||||
|
|
||||||
// for testing
|
// for testing
|
||||||
now func() time.Time
|
now func() time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ = oauth2.TokenSource(&cachingTokenSource{})
|
|
||||||
|
|
||||||
func (ts *cachingTokenSource) Token() (*oauth2.Token, error) {
|
func (ts *cachingTokenSource) Token() (*oauth2.Token, error) {
|
||||||
now := ts.now()
|
now := ts.now()
|
||||||
// fast path
|
// fast path
|
||||||
@ -153,6 +181,16 @@ func (ts *cachingTokenSource) Token() (*oauth2.Token, error) {
|
|||||||
return ts.tok, nil
|
return ts.tok, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ts.t = ts.now()
|
||||||
ts.tok = tok
|
ts.tok = tok
|
||||||
return tok, nil
|
return tok, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ts *cachingTokenSource) ResetTokenOlderThan(t time.Time) {
|
||||||
|
ts.Lock()
|
||||||
|
defer ts.Unlock()
|
||||||
|
if ts.t.Before(t) {
|
||||||
|
ts.tok = nil
|
||||||
|
ts.t = time.Time{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -156,6 +156,76 @@ func TestCachingTokenSourceRace(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTokenSourceTransportRoundTrip(t *testing.T) {
|
||||||
|
goodToken := &oauth2.Token{
|
||||||
|
AccessToken: "good",
|
||||||
|
Expiry: time.Now().Add(1000 * time.Hour),
|
||||||
|
}
|
||||||
|
badToken := &oauth2.Token{
|
||||||
|
AccessToken: "bad",
|
||||||
|
Expiry: time.Now().Add(1000 * time.Hour),
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
header http.Header
|
||||||
|
token *oauth2.Token
|
||||||
|
cachedToken *oauth2.Token
|
||||||
|
wantCalls int
|
||||||
|
wantCaching bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "skip oauth rt if has authorization header",
|
||||||
|
header: map[string][]string{"Authorization": {"Bearer TOKEN"}},
|
||||||
|
token: goodToken,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "authorized on newly acquired good token",
|
||||||
|
token: goodToken,
|
||||||
|
wantCalls: 1,
|
||||||
|
wantCaching: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "authorized on cached good token",
|
||||||
|
token: goodToken,
|
||||||
|
cachedToken: goodToken,
|
||||||
|
wantCalls: 0,
|
||||||
|
wantCaching: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unauthorized on newly acquired bad token",
|
||||||
|
token: badToken,
|
||||||
|
wantCalls: 1,
|
||||||
|
wantCaching: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unauthorized on cached bad token",
|
||||||
|
token: badToken,
|
||||||
|
cachedToken: badToken,
|
||||||
|
wantCalls: 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, test := range tests {
|
||||||
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
tts := &testTokenSource{
|
||||||
|
tok: test.token,
|
||||||
|
}
|
||||||
|
cachedTokenSource := NewCachedTokenSource(tts)
|
||||||
|
cachedTokenSource.tok = test.cachedToken
|
||||||
|
|
||||||
|
rt := ResettableTokenSourceWrapTransport(cachedTokenSource)(&testTransport{})
|
||||||
|
|
||||||
|
rt.RoundTrip(&http.Request{Header: test.header})
|
||||||
|
if tts.calls != test.wantCalls {
|
||||||
|
t.Errorf("RoundTrip() called Token() = %d times, want %d", tts.calls, test.wantCalls)
|
||||||
|
}
|
||||||
|
|
||||||
|
if (cachedTokenSource.tok != nil) != test.wantCaching {
|
||||||
|
t.Errorf("Got caching %v, want caching %v", cachedTokenSource != nil, test.wantCaching)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type uncancellableRT struct {
|
type uncancellableRT struct {
|
||||||
rt http.RoundTripper
|
rt http.RoundTripper
|
||||||
}
|
}
|
||||||
@ -164,7 +234,7 @@ func (urt *uncancellableRT) RoundTrip(req *http.Request) (*http.Response, error)
|
|||||||
return urt.rt.RoundTrip(req)
|
return urt.rt.RoundTrip(req)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCancellation(t *testing.T) {
|
func TestTokenSourceTransportCancelRequest(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
header http.Header
|
header http.Header
|
||||||
@ -186,7 +256,7 @@ func TestCancellation(t *testing.T) {
|
|||||||
}
|
}
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
t.Run(test.name, func(t *testing.T) {
|
t.Run(test.name, func(t *testing.T) {
|
||||||
baseRecorder := &recordCancelRoundTripper{}
|
baseRecorder := &testTransport{}
|
||||||
|
|
||||||
var base http.RoundTripper = baseRecorder
|
var base http.RoundTripper = baseRecorder
|
||||||
if test.wrapTransport != nil {
|
if test.wrapTransport != nil {
|
||||||
@ -211,16 +281,19 @@ func TestCancellation(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type recordCancelRoundTripper struct {
|
type testTransport struct {
|
||||||
canceled bool
|
canceled bool
|
||||||
base http.RoundTripper
|
base http.RoundTripper
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rt *recordCancelRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
func (rt *testTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
if req.Header["Authorization"][0] == "Bearer bad" {
|
||||||
|
return &http.Response{StatusCode: 401}, nil
|
||||||
|
}
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rt *recordCancelRoundTripper) CancelRequest(req *http.Request) {
|
func (rt *testTransport) CancelRequest(req *http.Request) {
|
||||||
rt.canceled = true
|
rt.canceled = true
|
||||||
if rt.base != nil {
|
if rt.base != nil {
|
||||||
tryCancelRequest(rt.base, req)
|
tryCancelRequest(rt.base, req)
|
||||||
|
Loading…
Reference in New Issue
Block a user