From cdbf1c4b6218d8e709f80e8d14e7e726288721ab Mon Sep 17 00:00:00 2001 From: Mike Danese Date: Wed, 5 Dec 2018 12:36:48 -0800 Subject: [PATCH] implement request cancellation in token transport Kubernetes-commit: a42e029e6905bee5b9d5489610c4fbe5988eeac6 --- transport/round_trippers.go | 40 +++------------- transport/token_source.go | 9 ++++ transport/token_source_test.go | 83 ++++++++++++++++++++++++++++++++++ transport/transport.go | 17 +++++++ 4 files changed, 115 insertions(+), 34 deletions(-) diff --git a/transport/round_trippers.go b/transport/round_trippers.go index 844ee9a2..a272753a 100644 --- a/transport/round_trippers.go +++ b/transport/round_trippers.go @@ -80,10 +80,6 @@ func DebugWrappers(rt http.RoundTripper) http.RoundTripper { return rt } -type requestCanceler interface { - CancelRequest(*http.Request) -} - type authProxyRoundTripper struct { username string groups []string @@ -140,11 +136,7 @@ func SetAuthProxyHeaders(req *http.Request, username string, groups []string, ex } func (rt *authProxyRoundTripper) CancelRequest(req *http.Request) { - if canceler, ok := rt.rt.(requestCanceler); ok { - canceler.CancelRequest(req) - } else { - klog.Errorf("CancelRequest not implemented by %T", rt.rt) - } + tryCancelRequest(rt.WrappedRoundTripper(), req) } func (rt *authProxyRoundTripper) WrappedRoundTripper() http.RoundTripper { return rt.rt } @@ -168,11 +160,7 @@ func (rt *userAgentRoundTripper) RoundTrip(req *http.Request) (*http.Response, e } func (rt *userAgentRoundTripper) CancelRequest(req *http.Request) { - if canceler, ok := rt.rt.(requestCanceler); ok { - canceler.CancelRequest(req) - } else { - klog.Errorf("CancelRequest not implemented by %T", rt.rt) - } + tryCancelRequest(rt.WrappedRoundTripper(), req) } func (rt *userAgentRoundTripper) WrappedRoundTripper() http.RoundTripper { return rt.rt } @@ -199,11 +187,7 @@ func (rt *basicAuthRoundTripper) RoundTrip(req *http.Request) (*http.Response, e } func (rt *basicAuthRoundTripper) CancelRequest(req *http.Request) { - if canceler, ok := rt.rt.(requestCanceler); ok { - canceler.CancelRequest(req) - } else { - klog.Errorf("CancelRequest not implemented by %T", rt.rt) - } + tryCancelRequest(rt.WrappedRoundTripper(), req) } func (rt *basicAuthRoundTripper) WrappedRoundTripper() http.RoundTripper { return rt.rt } @@ -259,11 +243,7 @@ func (rt *impersonatingRoundTripper) RoundTrip(req *http.Request) (*http.Respons } func (rt *impersonatingRoundTripper) CancelRequest(req *http.Request) { - if canceler, ok := rt.delegate.(requestCanceler); ok { - canceler.CancelRequest(req) - } else { - klog.Errorf("CancelRequest not implemented by %T", rt.delegate) - } + tryCancelRequest(rt.WrappedRoundTripper(), req) } func (rt *impersonatingRoundTripper) WrappedRoundTripper() http.RoundTripper { return rt.delegate } @@ -318,11 +298,7 @@ func (rt *bearerAuthRoundTripper) RoundTrip(req *http.Request) (*http.Response, } func (rt *bearerAuthRoundTripper) CancelRequest(req *http.Request) { - if canceler, ok := rt.rt.(requestCanceler); ok { - canceler.CancelRequest(req) - } else { - klog.Errorf("CancelRequest not implemented by %T", rt.rt) - } + tryCancelRequest(rt.WrappedRoundTripper(), req) } func (rt *bearerAuthRoundTripper) WrappedRoundTripper() http.RoundTripper { return rt.rt } @@ -402,11 +378,7 @@ func newDebuggingRoundTripper(rt http.RoundTripper, levels ...debugLevel) *debug } func (rt *debuggingRoundTripper) CancelRequest(req *http.Request) { - if canceler, ok := rt.delegatedRoundTripper.(requestCanceler); ok { - canceler.CancelRequest(req) - } else { - klog.Errorf("CancelRequest not implemented by %T", rt.delegatedRoundTripper) - } + tryCancelRequest(rt.WrappedRoundTripper(), req) } var knownAuthTypes = map[string]bool{ diff --git a/transport/token_source.go b/transport/token_source.go index b8cadd38..bb32c3b4 100644 --- a/transport/token_source.go +++ b/transport/token_source.go @@ -25,6 +25,7 @@ import ( "time" "golang.org/x/oauth2" + "k8s.io/klog" ) @@ -81,6 +82,14 @@ func (tst *tokenSourceTransport) RoundTrip(req *http.Request) (*http.Response, e return tst.ort.RoundTrip(req) } +func (tst *tokenSourceTransport) CancelRequest(req *http.Request) { + if req.Header.Get("Authorization") != "" { + tryCancelRequest(tst.base, req) + return + } + tryCancelRequest(tst.ort, req) +} + type fileTokenSource struct { path string period time.Duration diff --git a/transport/token_source_test.go b/transport/token_source_test.go index a222495b..6d61caca 100644 --- a/transport/token_source_test.go +++ b/transport/token_source_test.go @@ -18,6 +18,7 @@ package transport import ( "fmt" + "net/http" "reflect" "sync" "testing" @@ -154,3 +155,85 @@ func TestCachingTokenSourceRace(t *testing.T) { } } } + +type uncancellableRT struct { + rt http.RoundTripper +} + +func (urt *uncancellableRT) RoundTrip(req *http.Request) (*http.Response, error) { + return urt.rt.RoundTrip(req) +} + +func TestCancellation(t *testing.T) { + tests := []struct { + name string + header http.Header + wrapTransport func(http.RoundTripper) http.RoundTripper + expectCancel bool + }{ + { + name: "cancel req with bearer token skips oauth rt", + header: map[string][]string{"Authorization": {"Bearer TOKEN"}}, + expectCancel: true, + }, + { + name: "cancel req without bearer token hits both rts", + expectCancel: true, + }, + { + name: "cancel req without bearer token hits both wrapped rts", + wrapTransport: func(rt http.RoundTripper) http.RoundTripper { + return NewUserAgentRoundTripper("testing testing", rt) + }, + expectCancel: true, + }, + { + name: "can't cancel request with rts that doesn't implent unwrap or cancel", + wrapTransport: func(rt http.RoundTripper) http.RoundTripper { + return &uncancellableRT{rt: rt} + }, + expectCancel: false, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + baseRecorder := &recordCancelRoundTripper{} + + var base http.RoundTripper = baseRecorder + if test.wrapTransport != nil { + base = test.wrapTransport(base) + } + + rt := &tokenSourceTransport{ + base: base, + ort: &oauth2.Transport{ + Base: base, + }, + } + + rt.CancelRequest(&http.Request{ + Header: test.header, + }) + + if baseRecorder.canceled != test.expectCancel { + t.Errorf("unexpected cancel: got=%v, want=%v", baseRecorder.canceled, test.expectCancel) + } + }) + } +} + +type recordCancelRoundTripper struct { + canceled bool + base http.RoundTripper +} + +func (rt *recordCancelRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + return nil, nil +} + +func (rt *recordCancelRoundTripper) CancelRequest(req *http.Request) { + rt.canceled = true + if rt.base != nil { + tryCancelRequest(rt.base, req) + } +} diff --git a/transport/transport.go b/transport/transport.go index 2a145c97..1815c11f 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -23,6 +23,9 @@ import ( "fmt" "io/ioutil" "net/http" + + utilnet "k8s.io/apimachinery/pkg/util/net" + "k8s.io/klog" ) // New returns an http.RoundTripper that will provide the authentication @@ -225,3 +228,17 @@ func (b *contextCanceller) RoundTrip(req *http.Request) (*http.Response, error) return b.rt.RoundTrip(req) } } + +func tryCancelRequest(rt http.RoundTripper, req *http.Request) { + type canceler interface { + CancelRequest(*http.Request) + } + switch rt := rt.(type) { + case canceler: + rt.CancelRequest(req) + case utilnet.RoundTripperWrapper: + tryCancelRequest(rt.WrappedRoundTripper(), req) + default: + klog.Warningf("Unable to cancel request for %T", rt) + } +}