diff --git a/staging/src/k8s.io/apiserver/pkg/server/filters/timeout.go b/staging/src/k8s.io/apiserver/pkg/server/filters/timeout.go index 117dc8042f3..dff4a01d197 100644 --- a/staging/src/k8s.io/apiserver/pkg/server/filters/timeout.go +++ b/staging/src/k8s.io/apiserver/pkg/server/filters/timeout.go @@ -93,6 +93,10 @@ func (t *timeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { resultCh := make(chan interface{}) var tw timeoutWriter tw, w = newTimeoutWriter(w) + + // Make a copy of request and work on it in new goroutine + // to avoid race condition when accessing/modifying request (e.g. headers) + rCopy := r.Clone(r.Context()) go func() { defer func() { err := recover() @@ -107,7 +111,7 @@ func (t *timeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } resultCh <- err }() - t.handler.ServeHTTP(w, r) + t.handler.ServeHTTP(w, rCopy) }() select { case err := <-resultCh: diff --git a/staging/src/k8s.io/apiserver/pkg/server/filters/timeout_test.go b/staging/src/k8s.io/apiserver/pkg/server/filters/timeout_test.go index 0e76fd7c828..cf4203f72a6 100644 --- a/staging/src/k8s.io/apiserver/pkg/server/filters/timeout_test.go +++ b/staging/src/k8s.io/apiserver/pkg/server/filters/timeout_test.go @@ -253,6 +253,113 @@ func TestTimeoutHeaders(t *testing.T) { res.Body.Close() } +func TestTimeoutRequestHeaders(t *testing.T) { + origReallyCrash := runtime.ReallyCrash + runtime.ReallyCrash = false + defer func() { + runtime.ReallyCrash = origReallyCrash + }() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Add dummy request info, otherwise we skip postTimeoutFn + ctx = request.WithRequestInfo(ctx, &request.RequestInfo{}) + + withDeadline := func(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + handler.ServeHTTP(w, req.WithContext(ctx)) + }) + } + + ts := httptest.NewServer( + withDeadline( + WithTimeoutForNonLongRunningRequests( + http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + // trigger the timeout + cancel() + // mutate request Headers + // Authorization filter does it for example + for j := 0; j < 10000; j++ { + req.Header.Set("Test", "post") + } + }), + func(r *http.Request, requestInfo *request.RequestInfo) bool { + return false + }, + ), + ), + ) + defer ts.Close() + + client := &http.Client{} + req, err := http.NewRequest(http.MethodPatch, ts.URL, nil) + if err != nil { + t.Fatal(err) + } + res, err := client.Do(req) + if err != nil { + t.Fatal(err) + } + if res.StatusCode != http.StatusGatewayTimeout { + t.Errorf("got res.StatusCde %d; expected %d", res.StatusCode, http.StatusServiceUnavailable) + } + res.Body.Close() +} + +func TestTimeoutWithLogging(t *testing.T) { + origReallyCrash := runtime.ReallyCrash + runtime.ReallyCrash = false + defer func() { + runtime.ReallyCrash = origReallyCrash + }() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + withDeadline := func(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + handler.ServeHTTP(w, req.WithContext(ctx)) + }) + } + + ts := httptest.NewServer( + WithHTTPLogging( + withDeadline( + WithTimeoutForNonLongRunningRequests( + http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + // trigger the timeout + cancel() + // mutate request Headers + // Authorization filter does it for example + for j := 0; j < 10000; j++ { + req.Header.Set("Test", "post") + } + }), + func(r *http.Request, requestInfo *request.RequestInfo) bool { + return false + }, + ), + ), + ), + ) + defer ts.Close() + + client := &http.Client{} + req, err := http.NewRequest(http.MethodPatch, ts.URL, nil) + if err != nil { + t.Fatal(err) + } + res, err := client.Do(req) + if err != nil { + t.Fatal(err) + } + if res.StatusCode != http.StatusGatewayTimeout { + t.Errorf("got res.StatusCode %d; expected %d", res.StatusCode, http.StatusServiceUnavailable) + } + res.Body.Close() +} + func TestErrConnKilled(t *testing.T) { var buf bytes.Buffer klog.SetOutput(&buf)