From 5b2a31f375755386b5cb2541b912f3561f7d6431 Mon Sep 17 00:00:00 2001 From: Jordan Liggitt Date: Tue, 4 Jan 2022 22:57:29 -0500 Subject: [PATCH] Fix header mutation race in timeout filter --- .../apiserver/pkg/server/filters/timeout.go | 19 ++++++-- .../pkg/server/filters/timeout_test.go | 46 +++++++++++++++++++ 2 files changed, 62 insertions(+), 3 deletions(-) diff --git a/staging/src/k8s.io/apiserver/pkg/server/filters/timeout.go b/staging/src/k8s.io/apiserver/pkg/server/filters/timeout.go index 7487fbf1ec7..a485dfb22b3 100644 --- a/staging/src/k8s.io/apiserver/pkg/server/filters/timeout.go +++ b/staging/src/k8s.io/apiserver/pkg/server/filters/timeout.go @@ -149,7 +149,7 @@ type timeoutWriter interface { } func newTimeoutWriter(w http.ResponseWriter) (timeoutWriter, http.ResponseWriter) { - base := &baseTimeoutWriter{w: w} + base := &baseTimeoutWriter{w: w, handlerHeaders: w.Header().Clone()} wrapped := responsewriter.WrapForHTTP1Or2(base) return base, wrapped @@ -161,6 +161,9 @@ var _ responsewriter.UserProvidedDecorator = &baseTimeoutWriter{} type baseTimeoutWriter struct { w http.ResponseWriter + // headers written by the normal handler + handlerHeaders http.Header + mu sync.Mutex // if the timeout handler has timeout timedOut bool @@ -182,7 +185,7 @@ func (tw *baseTimeoutWriter) Header() http.Header { return http.Header{} } - return tw.w.Header() + return tw.handlerHeaders } func (tw *baseTimeoutWriter) Write(p []byte) (int, error) { @@ -196,7 +199,10 @@ func (tw *baseTimeoutWriter) Write(p []byte) (int, error) { return 0, http.ErrHijacked } - tw.wroteHeader = true + if !tw.wroteHeader { + copyHeaders(tw.w.Header(), tw.handlerHeaders) + tw.wroteHeader = true + } return tw.w.Write(p) } @@ -221,10 +227,17 @@ func (tw *baseTimeoutWriter) WriteHeader(code int) { return } + copyHeaders(tw.w.Header(), tw.handlerHeaders) tw.wroteHeader = true tw.w.WriteHeader(code) } +func copyHeaders(dst, src http.Header) { + for k, v := range src { + dst[k] = v + } +} + func (tw *baseTimeoutWriter) timeout(err *apierrors.StatusError) { tw.mu.Lock() defer tw.mu.Unlock() diff --git a/staging/src/k8s.io/apiserver/pkg/server/filters/timeout_test.go b/staging/src/k8s.io/apiserver/pkg/server/filters/timeout_test.go index 32682c713cb..1bca00f2dbc 100644 --- a/staging/src/k8s.io/apiserver/pkg/server/filters/timeout_test.go +++ b/staging/src/k8s.io/apiserver/pkg/server/filters/timeout_test.go @@ -203,6 +203,52 @@ func TestTimeout(t *testing.T) { } } +func TestTimeoutHeaders(t *testing.T) { + origReallyCrash := runtime.ReallyCrash + runtime.ReallyCrash = false + defer func() { + runtime.ReallyCrash = origReallyCrash + }() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + withDeadline := func(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + handler.ServeHTTP(w, req.WithContext(ctx)) + }) + } + + ts := httptest.NewServer( + withDeadline( + WithTimeout( + http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + h := w.Header() + // trigger the timeout + cancel() + // mutate response Headers + for j := 0; j < 1000; j++ { + h.Set("Test", "post") + } + }), + func(req *http.Request) (*http.Request, bool, func(), *apierrors.StatusError) { + return req, false, func() {}, apierrors.NewServerTimeout(schema.GroupResource{Group: "foo", Resource: "bar"}, "get", 0) + }, + ), + ), + ) + defer ts.Close() + + res, err := http.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + if res.StatusCode != http.StatusGatewayTimeout { + t.Errorf("got res.StatusCode %d; expected %d", res.StatusCode, http.StatusServiceUnavailable) + } + res.Body.Close() +} + func captureStdErr() (func() string, func(), error) { var buf bytes.Buffer reader, writer, err := os.Pipe()