mirror of
https://github.com/k3s-io/kubernetes.git
synced 2025-07-21 10:51:29 +00:00
Merge pull request #107452 from liggitt/timeout_headers
Fix header mutation race in timeout filter
This commit is contained in:
commit
3cec1d1a13
@ -149,7 +149,7 @@ type timeoutWriter interface {
|
||||
}
|
||||
|
||||
func newTimeoutWriter(w http.ResponseWriter) (timeoutWriter, http.ResponseWriter) {
|
||||
base := &baseTimeoutWriter{w: w}
|
||||
base := &baseTimeoutWriter{w: w, handlerHeaders: w.Header().Clone()}
|
||||
wrapped := responsewriter.WrapForHTTP1Or2(base)
|
||||
|
||||
return base, wrapped
|
||||
@ -161,6 +161,9 @@ var _ responsewriter.UserProvidedDecorator = &baseTimeoutWriter{}
|
||||
type baseTimeoutWriter struct {
|
||||
w http.ResponseWriter
|
||||
|
||||
// headers written by the normal handler
|
||||
handlerHeaders http.Header
|
||||
|
||||
mu sync.Mutex
|
||||
// if the timeout handler has timeout
|
||||
timedOut bool
|
||||
@ -182,7 +185,7 @@ func (tw *baseTimeoutWriter) Header() http.Header {
|
||||
return http.Header{}
|
||||
}
|
||||
|
||||
return tw.w.Header()
|
||||
return tw.handlerHeaders
|
||||
}
|
||||
|
||||
func (tw *baseTimeoutWriter) Write(p []byte) (int, error) {
|
||||
@ -196,7 +199,10 @@ func (tw *baseTimeoutWriter) Write(p []byte) (int, error) {
|
||||
return 0, http.ErrHijacked
|
||||
}
|
||||
|
||||
tw.wroteHeader = true
|
||||
if !tw.wroteHeader {
|
||||
copyHeaders(tw.w.Header(), tw.handlerHeaders)
|
||||
tw.wroteHeader = true
|
||||
}
|
||||
return tw.w.Write(p)
|
||||
}
|
||||
|
||||
@ -221,10 +227,17 @@ func (tw *baseTimeoutWriter) WriteHeader(code int) {
|
||||
return
|
||||
}
|
||||
|
||||
copyHeaders(tw.w.Header(), tw.handlerHeaders)
|
||||
tw.wroteHeader = true
|
||||
tw.w.WriteHeader(code)
|
||||
}
|
||||
|
||||
func copyHeaders(dst, src http.Header) {
|
||||
for k, v := range src {
|
||||
dst[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
func (tw *baseTimeoutWriter) timeout(err *apierrors.StatusError) {
|
||||
tw.mu.Lock()
|
||||
defer tw.mu.Unlock()
|
||||
|
@ -203,6 +203,52 @@ func TestTimeout(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestTimeoutHeaders(t *testing.T) {
|
||||
origReallyCrash := runtime.ReallyCrash
|
||||
runtime.ReallyCrash = false
|
||||
defer func() {
|
||||
runtime.ReallyCrash = origReallyCrash
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
withDeadline := func(handler http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
handler.ServeHTTP(w, req.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
ts := httptest.NewServer(
|
||||
withDeadline(
|
||||
WithTimeout(
|
||||
http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
h := w.Header()
|
||||
// trigger the timeout
|
||||
cancel()
|
||||
// mutate response Headers
|
||||
for j := 0; j < 1000; j++ {
|
||||
h.Set("Test", "post")
|
||||
}
|
||||
}),
|
||||
func(req *http.Request) (*http.Request, bool, func(), *apierrors.StatusError) {
|
||||
return req, false, func() {}, apierrors.NewServerTimeout(schema.GroupResource{Group: "foo", Resource: "bar"}, "get", 0)
|
||||
},
|
||||
),
|
||||
),
|
||||
)
|
||||
defer ts.Close()
|
||||
|
||||
res, err := http.Get(ts.URL)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if res.StatusCode != http.StatusGatewayTimeout {
|
||||
t.Errorf("got res.StatusCode %d; expected %d", res.StatusCode, http.StatusServiceUnavailable)
|
||||
}
|
||||
res.Body.Close()
|
||||
}
|
||||
|
||||
func captureStdErr() (func() string, func(), error) {
|
||||
var buf bytes.Buffer
|
||||
reader, writer, err := os.Pipe()
|
||||
|
Loading…
Reference in New Issue
Block a user