mirror of
https://github.com/k3s-io/kubernetes.git
synced 2025-07-20 10:20:51 +00:00
Merge pull request #107452 from liggitt/timeout_headers
Fix header mutation race in timeout filter
This commit is contained in:
commit
3cec1d1a13
@ -149,7 +149,7 @@ type timeoutWriter interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func newTimeoutWriter(w http.ResponseWriter) (timeoutWriter, http.ResponseWriter) {
|
func newTimeoutWriter(w http.ResponseWriter) (timeoutWriter, http.ResponseWriter) {
|
||||||
base := &baseTimeoutWriter{w: w}
|
base := &baseTimeoutWriter{w: w, handlerHeaders: w.Header().Clone()}
|
||||||
wrapped := responsewriter.WrapForHTTP1Or2(base)
|
wrapped := responsewriter.WrapForHTTP1Or2(base)
|
||||||
|
|
||||||
return base, wrapped
|
return base, wrapped
|
||||||
@ -161,6 +161,9 @@ var _ responsewriter.UserProvidedDecorator = &baseTimeoutWriter{}
|
|||||||
type baseTimeoutWriter struct {
|
type baseTimeoutWriter struct {
|
||||||
w http.ResponseWriter
|
w http.ResponseWriter
|
||||||
|
|
||||||
|
// headers written by the normal handler
|
||||||
|
handlerHeaders http.Header
|
||||||
|
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
// if the timeout handler has timeout
|
// if the timeout handler has timeout
|
||||||
timedOut bool
|
timedOut bool
|
||||||
@ -182,7 +185,7 @@ func (tw *baseTimeoutWriter) Header() http.Header {
|
|||||||
return http.Header{}
|
return http.Header{}
|
||||||
}
|
}
|
||||||
|
|
||||||
return tw.w.Header()
|
return tw.handlerHeaders
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tw *baseTimeoutWriter) Write(p []byte) (int, error) {
|
func (tw *baseTimeoutWriter) Write(p []byte) (int, error) {
|
||||||
@ -196,7 +199,10 @@ func (tw *baseTimeoutWriter) Write(p []byte) (int, error) {
|
|||||||
return 0, http.ErrHijacked
|
return 0, http.ErrHijacked
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !tw.wroteHeader {
|
||||||
|
copyHeaders(tw.w.Header(), tw.handlerHeaders)
|
||||||
tw.wroteHeader = true
|
tw.wroteHeader = true
|
||||||
|
}
|
||||||
return tw.w.Write(p)
|
return tw.w.Write(p)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -221,10 +227,17 @@ func (tw *baseTimeoutWriter) WriteHeader(code int) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
copyHeaders(tw.w.Header(), tw.handlerHeaders)
|
||||||
tw.wroteHeader = true
|
tw.wroteHeader = true
|
||||||
tw.w.WriteHeader(code)
|
tw.w.WriteHeader(code)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func copyHeaders(dst, src http.Header) {
|
||||||
|
for k, v := range src {
|
||||||
|
dst[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (tw *baseTimeoutWriter) timeout(err *apierrors.StatusError) {
|
func (tw *baseTimeoutWriter) timeout(err *apierrors.StatusError) {
|
||||||
tw.mu.Lock()
|
tw.mu.Lock()
|
||||||
defer tw.mu.Unlock()
|
defer tw.mu.Unlock()
|
||||||
|
@ -203,6 +203,52 @@ func TestTimeout(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTimeoutHeaders(t *testing.T) {
|
||||||
|
origReallyCrash := runtime.ReallyCrash
|
||||||
|
runtime.ReallyCrash = false
|
||||||
|
defer func() {
|
||||||
|
runtime.ReallyCrash = origReallyCrash
|
||||||
|
}()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
withDeadline := func(handler http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||||
|
handler.ServeHTTP(w, req.WithContext(ctx))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
ts := httptest.NewServer(
|
||||||
|
withDeadline(
|
||||||
|
WithTimeout(
|
||||||
|
http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||||
|
h := w.Header()
|
||||||
|
// trigger the timeout
|
||||||
|
cancel()
|
||||||
|
// mutate response Headers
|
||||||
|
for j := 0; j < 1000; j++ {
|
||||||
|
h.Set("Test", "post")
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
func(req *http.Request) (*http.Request, bool, func(), *apierrors.StatusError) {
|
||||||
|
return req, false, func() {}, apierrors.NewServerTimeout(schema.GroupResource{Group: "foo", Resource: "bar"}, "get", 0)
|
||||||
|
},
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
res, err := http.Get(ts.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if res.StatusCode != http.StatusGatewayTimeout {
|
||||||
|
t.Errorf("got res.StatusCode %d; expected %d", res.StatusCode, http.StatusServiceUnavailable)
|
||||||
|
}
|
||||||
|
res.Body.Close()
|
||||||
|
}
|
||||||
|
|
||||||
func captureStdErr() (func() string, func(), error) {
|
func captureStdErr() (func() string, func(), error) {
|
||||||
var buf bytes.Buffer
|
var buf bytes.Buffer
|
||||||
reader, writer, err := os.Pipe()
|
reader, writer, err := os.Pipe()
|
||||||
|
Loading…
Reference in New Issue
Block a user