Merge pull request #108455 from Argh4k/race-conditions

Copy request in timeout handler
This commit is contained in:
Kubernetes Prow Robot 2022-03-24 14:00:41 -07:00 committed by GitHub
commit 9bb5823b83
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 112 additions and 1 deletions

View File

@ -93,6 +93,10 @@ func (t *timeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
resultCh := make(chan interface{})
var tw timeoutWriter
tw, w = newTimeoutWriter(w)
// Make a copy of request and work on it in new goroutine
// to avoid race condition when accessing/modifying request (e.g. headers)
rCopy := r.Clone(r.Context())
go func() {
defer func() {
err := recover()
@ -107,7 +111,7 @@ func (t *timeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
resultCh <- err
}()
t.handler.ServeHTTP(w, r)
t.handler.ServeHTTP(w, rCopy)
}()
select {
case err := <-resultCh:

View File

@ -253,6 +253,113 @@ func TestTimeoutHeaders(t *testing.T) {
res.Body.Close()
}
func TestTimeoutRequestHeaders(t *testing.T) {
origReallyCrash := runtime.ReallyCrash
runtime.ReallyCrash = false
defer func() {
runtime.ReallyCrash = origReallyCrash
}()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Add dummy request info, otherwise we skip postTimeoutFn
ctx = request.WithRequestInfo(ctx, &request.RequestInfo{})
withDeadline := func(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
handler.ServeHTTP(w, req.WithContext(ctx))
})
}
ts := httptest.NewServer(
withDeadline(
WithTimeoutForNonLongRunningRequests(
http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
// trigger the timeout
cancel()
// mutate request Headers
// Authorization filter does it for example
for j := 0; j < 10000; j++ {
req.Header.Set("Test", "post")
}
}),
func(r *http.Request, requestInfo *request.RequestInfo) bool {
return false
},
),
),
)
defer ts.Close()
client := &http.Client{}
req, err := http.NewRequest(http.MethodPatch, ts.URL, nil)
if err != nil {
t.Fatal(err)
}
res, err := client.Do(req)
if err != nil {
t.Fatal(err)
}
if res.StatusCode != http.StatusGatewayTimeout {
t.Errorf("got res.StatusCde %d; expected %d", res.StatusCode, http.StatusServiceUnavailable)
}
res.Body.Close()
}
func TestTimeoutWithLogging(t *testing.T) {
origReallyCrash := runtime.ReallyCrash
runtime.ReallyCrash = false
defer func() {
runtime.ReallyCrash = origReallyCrash
}()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
withDeadline := func(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
handler.ServeHTTP(w, req.WithContext(ctx))
})
}
ts := httptest.NewServer(
WithHTTPLogging(
withDeadline(
WithTimeoutForNonLongRunningRequests(
http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
// trigger the timeout
cancel()
// mutate request Headers
// Authorization filter does it for example
for j := 0; j < 10000; j++ {
req.Header.Set("Test", "post")
}
}),
func(r *http.Request, requestInfo *request.RequestInfo) bool {
return false
},
),
),
),
)
defer ts.Close()
client := &http.Client{}
req, err := http.NewRequest(http.MethodPatch, ts.URL, nil)
if err != nil {
t.Fatal(err)
}
res, err := client.Do(req)
if err != nil {
t.Fatal(err)
}
if res.StatusCode != http.StatusGatewayTimeout {
t.Errorf("got res.StatusCode %d; expected %d", res.StatusCode, http.StatusServiceUnavailable)
}
res.Body.Close()
}
func TestErrConnKilled(t *testing.T) {
var buf bytes.Buffer
klog.SetOutput(&buf)