apiserver: fix timeout handler

Protect access of the original writer. Panics if anything has wrote
into the original writer or the writer is hijacked when times out.
This commit is contained in:
Xiang Li 2016-07-25 20:43:47 -07:00
parent b39cde37c9
commit c995050ee3
2 changed files with 63 additions and 14 deletions

View File

@ -164,6 +164,8 @@ func RecoverPanics(handler http.Handler) http.Handler {
}) })
} }
var errConnKilled = fmt.Errorf("kill connection/stream")
// TimeoutHandler returns an http.Handler that runs h with a timeout // TimeoutHandler returns an http.Handler that runs h with a timeout
// determined by timeoutFunc. The new http.Handler calls h.ServeHTTP to handle // determined by timeoutFunc. The new http.Handler calls h.ServeHTTP to handle
// each request, but if a call runs for longer than its time limit, the // each request, but if a call runs for longer than its time limit, the
@ -188,11 +190,11 @@ func (t *timeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return return
} }
done := make(chan struct{}, 1) done := make(chan struct{})
tw := newTimeoutWriter(w) tw := newTimeoutWriter(w)
go func() { go func() {
t.handler.ServeHTTP(tw, r) t.handler.ServeHTTP(tw, r)
done <- struct{}{} close(done)
}() }()
select { select {
case <-done: case <-done:
@ -228,26 +230,38 @@ func newTimeoutWriter(w http.ResponseWriter) timeoutWriter {
type baseTimeoutWriter struct { type baseTimeoutWriter struct {
w http.ResponseWriter w http.ResponseWriter
mu sync.Mutex mu sync.Mutex
timedOut bool // if the timeout handler has timedout
timedOut bool
// if this timeout writer has wrote header
wroteHeader bool wroteHeader bool
hijacked bool // if this timeout writer has been hijacked
hijacked bool
} }
func (tw *baseTimeoutWriter) Header() http.Header { func (tw *baseTimeoutWriter) Header() http.Header {
tw.mu.Lock()
defer tw.mu.Unlock()
if tw.timedOut {
return http.Header{}
}
return tw.w.Header() return tw.w.Header()
} }
func (tw *baseTimeoutWriter) Write(p []byte) (int, error) { func (tw *baseTimeoutWriter) Write(p []byte) (int, error) {
tw.mu.Lock() tw.mu.Lock()
defer tw.mu.Unlock() defer tw.mu.Unlock()
tw.wroteHeader = true
if tw.hijacked {
return 0, http.ErrHijacked
}
if tw.timedOut { if tw.timedOut {
return 0, http.ErrHandlerTimeout return 0, http.ErrHandlerTimeout
} }
if tw.hijacked {
return 0, http.ErrHijacked
}
tw.wroteHeader = true
return tw.w.Write(p) return tw.w.Write(p)
} }
@ -255,6 +269,10 @@ func (tw *baseTimeoutWriter) Flush() {
tw.mu.Lock() tw.mu.Lock()
defer tw.mu.Unlock() defer tw.mu.Unlock()
if tw.timedOut {
return
}
if flusher, ok := tw.w.(http.Flusher); ok { if flusher, ok := tw.w.(http.Flusher); ok {
flusher.Flush() flusher.Flush()
} }
@ -263,9 +281,11 @@ func (tw *baseTimeoutWriter) Flush() {
func (tw *baseTimeoutWriter) WriteHeader(code int) { func (tw *baseTimeoutWriter) WriteHeader(code int) {
tw.mu.Lock() tw.mu.Lock()
defer tw.mu.Unlock() defer tw.mu.Unlock()
if tw.timedOut || tw.wroteHeader || tw.hijacked { if tw.timedOut || tw.wroteHeader || tw.hijacked {
return return
} }
tw.wroteHeader = true tw.wroteHeader = true
tw.w.WriteHeader(code) tw.w.WriteHeader(code)
} }
@ -273,6 +293,12 @@ func (tw *baseTimeoutWriter) WriteHeader(code int) {
func (tw *baseTimeoutWriter) timeout(msg string) { func (tw *baseTimeoutWriter) timeout(msg string) {
tw.mu.Lock() tw.mu.Lock()
defer tw.mu.Unlock() defer tw.mu.Unlock()
tw.timedOut = true
// The timeout writer has not been used by the inner handler.
// We can safely timeout the HTTP request by sending by a timeout
// handler
if !tw.wroteHeader && !tw.hijacked { if !tw.wroteHeader && !tw.hijacked {
tw.w.WriteHeader(http.StatusGatewayTimeout) tw.w.WriteHeader(http.StatusGatewayTimeout)
if msg != "" { if msg != "" {
@ -281,17 +307,40 @@ func (tw *baseTimeoutWriter) timeout(msg string) {
enc := json.NewEncoder(tw.w) enc := json.NewEncoder(tw.w)
enc.Encode(errors.NewServerTimeout(api.Resource(""), "", 0)) enc.Encode(errors.NewServerTimeout(api.Resource(""), "", 0))
} }
} else {
// The timeout writer has been used by the inner handler. There is
// no way to timeout the HTTP request at the point. We have to shutdown
// the connection for HTTP1 or reset stream for HTTP2.
//
// Note from: Brad Fitzpatrick
// if the ServeHTTP goroutine panics, that will do the best possible thing for both
// HTTP/1 and HTTP/2. In HTTP/1, assuming you're replying with at least HTTP/1.1 and
// you've already flushed the headers so it's using HTTP chunking, it'll kill the TCP
// connection immediately without a proper 0-byte EOF chunk, so the peer will recognize
// the response as bogus. In HTTP/2 the server will just RST_STREAM the stream, leaving
// the TCP connection open, but resetting the stream to the peer so it'll have an error,
// like the HTTP/1 case.
panic(errConnKilled)
} }
tw.timedOut = true
} }
func (tw *baseTimeoutWriter) closeNotify() <-chan bool { func (tw *baseTimeoutWriter) closeNotify() <-chan bool {
tw.mu.Lock()
defer tw.mu.Unlock()
if tw.timedOut {
done := make(chan bool)
close(done)
return done
}
return tw.w.(http.CloseNotifier).CloseNotify() return tw.w.(http.CloseNotifier).CloseNotify()
} }
func (tw *baseTimeoutWriter) hijack() (net.Conn, *bufio.ReadWriter, error) { func (tw *baseTimeoutWriter) hijack() (net.Conn, *bufio.ReadWriter, error) {
tw.mu.Lock() tw.mu.Lock()
defer tw.mu.Unlock() defer tw.mu.Unlock()
if tw.timedOut { if tw.timedOut {
return nil, nil, http.ErrHandlerTimeout return nil, nil, http.ErrHandlerTimeout
} }

View File

@ -633,10 +633,10 @@ func (s *GenericAPIServer) Run(options *options.ServerRunOptions) {
} }
if secureLocation != "" { if secureLocation != "" {
handler := apiserver.TimeoutHandler(s.Handler, longRunningTimeout) handler := apiserver.TimeoutHandler(apiserver.RecoverPanics(s.Handler), longRunningTimeout)
secureServer := &http.Server{ secureServer := &http.Server{
Addr: secureLocation, Addr: secureLocation,
Handler: apiserver.MaxInFlightLimit(sem, longRunningRequestCheck, apiserver.RecoverPanics(handler)), Handler: apiserver.MaxInFlightLimit(sem, longRunningRequestCheck, handler),
MaxHeaderBytes: 1 << 20, MaxHeaderBytes: 1 << 20,
TLSConfig: &tls.Config{ TLSConfig: &tls.Config{
// Can't use SSLv3 because of POODLE and BEAST // Can't use SSLv3 because of POODLE and BEAST
@ -696,10 +696,10 @@ func (s *GenericAPIServer) Run(options *options.ServerRunOptions) {
} }
} }
handler := apiserver.TimeoutHandler(s.InsecureHandler, longRunningTimeout) handler := apiserver.TimeoutHandler(apiserver.RecoverPanics(s.InsecureHandler), longRunningTimeout)
http := &http.Server{ http := &http.Server{
Addr: insecureLocation, Addr: insecureLocation,
Handler: apiserver.RecoverPanics(handler), Handler: handler,
MaxHeaderBytes: 1 << 20, MaxHeaderBytes: 1 << 20,
} }