Merge pull request #107452 from liggitt/timeout_headers

Fix header mutation race in timeout filter
This commit is contained in:
Kubernetes Prow Robot 2022-01-10 14:36:37 -08:00 committed by GitHub
commit 3cec1d1a13
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 62 additions and 3 deletions

View File

@ -149,7 +149,7 @@ type timeoutWriter interface {
} }
func newTimeoutWriter(w http.ResponseWriter) (timeoutWriter, http.ResponseWriter) { func newTimeoutWriter(w http.ResponseWriter) (timeoutWriter, http.ResponseWriter) {
base := &baseTimeoutWriter{w: w} base := &baseTimeoutWriter{w: w, handlerHeaders: w.Header().Clone()}
wrapped := responsewriter.WrapForHTTP1Or2(base) wrapped := responsewriter.WrapForHTTP1Or2(base)
return base, wrapped return base, wrapped
@ -161,6 +161,9 @@ var _ responsewriter.UserProvidedDecorator = &baseTimeoutWriter{}
type baseTimeoutWriter struct { type baseTimeoutWriter struct {
w http.ResponseWriter w http.ResponseWriter
// headers written by the normal handler
handlerHeaders http.Header
mu sync.Mutex mu sync.Mutex
// if the timeout handler has timeout // if the timeout handler has timeout
timedOut bool timedOut bool
@ -182,7 +185,7 @@ func (tw *baseTimeoutWriter) Header() http.Header {
return http.Header{} return http.Header{}
} }
return tw.w.Header() return tw.handlerHeaders
} }
func (tw *baseTimeoutWriter) Write(p []byte) (int, error) { func (tw *baseTimeoutWriter) Write(p []byte) (int, error) {
@ -196,7 +199,10 @@ func (tw *baseTimeoutWriter) Write(p []byte) (int, error) {
return 0, http.ErrHijacked return 0, http.ErrHijacked
} }
if !tw.wroteHeader {
copyHeaders(tw.w.Header(), tw.handlerHeaders)
tw.wroteHeader = true tw.wroteHeader = true
}
return tw.w.Write(p) return tw.w.Write(p)
} }
@ -221,10 +227,17 @@ func (tw *baseTimeoutWriter) WriteHeader(code int) {
return return
} }
copyHeaders(tw.w.Header(), tw.handlerHeaders)
tw.wroteHeader = true tw.wroteHeader = true
tw.w.WriteHeader(code) tw.w.WriteHeader(code)
} }
func copyHeaders(dst, src http.Header) {
for k, v := range src {
dst[k] = v
}
}
func (tw *baseTimeoutWriter) timeout(err *apierrors.StatusError) { func (tw *baseTimeoutWriter) timeout(err *apierrors.StatusError) {
tw.mu.Lock() tw.mu.Lock()
defer tw.mu.Unlock() defer tw.mu.Unlock()

View File

@ -203,6 +203,52 @@ func TestTimeout(t *testing.T) {
} }
} }
func TestTimeoutHeaders(t *testing.T) {
origReallyCrash := runtime.ReallyCrash
runtime.ReallyCrash = false
defer func() {
runtime.ReallyCrash = origReallyCrash
}()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
withDeadline := func(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
handler.ServeHTTP(w, req.WithContext(ctx))
})
}
ts := httptest.NewServer(
withDeadline(
WithTimeout(
http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
h := w.Header()
// trigger the timeout
cancel()
// mutate response Headers
for j := 0; j < 1000; j++ {
h.Set("Test", "post")
}
}),
func(req *http.Request) (*http.Request, bool, func(), *apierrors.StatusError) {
return req, false, func() {}, apierrors.NewServerTimeout(schema.GroupResource{Group: "foo", Resource: "bar"}, "get", 0)
},
),
),
)
defer ts.Close()
res, err := http.Get(ts.URL)
if err != nil {
t.Fatal(err)
}
if res.StatusCode != http.StatusGatewayTimeout {
t.Errorf("got res.StatusCode %d; expected %d", res.StatusCode, http.StatusServiceUnavailable)
}
res.Body.Close()
}
func captureStdErr() (func() string, func(), error) { func captureStdErr() (func() string, func(), error) {
var buf bytes.Buffer var buf bytes.Buffer
reader, writer, err := os.Pipe() reader, writer, err := os.Pipe()