apiserver: fix timeout handler

Protect access of the original writer. Panics if anything has wrote
into the original writer or the writer is hijacked when times out.
This commit is contained in:
Xiang Li 2016-07-25 20:43:47 -07:00
parent b39cde37c9
commit c995050ee3
2 changed files with 63 additions and 14 deletions

View File

@ -164,6 +164,8 @@ func RecoverPanics(handler http.Handler) http.Handler {
})
}
var errConnKilled = fmt.Errorf("kill connection/stream")
// TimeoutHandler returns an http.Handler that runs h with a timeout
// determined by timeoutFunc. The new http.Handler calls h.ServeHTTP to handle
// each request, but if a call runs for longer than its time limit, the
@ -188,11 +190,11 @@ func (t *timeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
done := make(chan struct{}, 1)
done := make(chan struct{})
tw := newTimeoutWriter(w)
go func() {
t.handler.ServeHTTP(tw, r)
done <- struct{}{}
close(done)
}()
select {
case <-done:
@ -228,26 +230,38 @@ func newTimeoutWriter(w http.ResponseWriter) timeoutWriter {
type baseTimeoutWriter struct {
w http.ResponseWriter
mu sync.Mutex
timedOut bool
mu sync.Mutex
// if the timeout handler has timedout
timedOut bool
// if this timeout writer has wrote header
wroteHeader bool
hijacked bool
// if this timeout writer has been hijacked
hijacked bool
}
func (tw *baseTimeoutWriter) Header() http.Header {
tw.mu.Lock()
defer tw.mu.Unlock()
if tw.timedOut {
return http.Header{}
}
return tw.w.Header()
}
func (tw *baseTimeoutWriter) Write(p []byte) (int, error) {
tw.mu.Lock()
defer tw.mu.Unlock()
tw.wroteHeader = true
if tw.hijacked {
return 0, http.ErrHijacked
}
if tw.timedOut {
return 0, http.ErrHandlerTimeout
}
if tw.hijacked {
return 0, http.ErrHijacked
}
tw.wroteHeader = true
return tw.w.Write(p)
}
@ -255,6 +269,10 @@ func (tw *baseTimeoutWriter) Flush() {
tw.mu.Lock()
defer tw.mu.Unlock()
if tw.timedOut {
return
}
if flusher, ok := tw.w.(http.Flusher); ok {
flusher.Flush()
}
@ -263,9 +281,11 @@ func (tw *baseTimeoutWriter) Flush() {
func (tw *baseTimeoutWriter) WriteHeader(code int) {
tw.mu.Lock()
defer tw.mu.Unlock()
if tw.timedOut || tw.wroteHeader || tw.hijacked {
return
}
tw.wroteHeader = true
tw.w.WriteHeader(code)
}
@ -273,6 +293,12 @@ func (tw *baseTimeoutWriter) WriteHeader(code int) {
func (tw *baseTimeoutWriter) timeout(msg string) {
tw.mu.Lock()
defer tw.mu.Unlock()
tw.timedOut = true
// The timeout writer has not been used by the inner handler.
// We can safely timeout the HTTP request by sending by a timeout
// handler
if !tw.wroteHeader && !tw.hijacked {
tw.w.WriteHeader(http.StatusGatewayTimeout)
if msg != "" {
@ -281,17 +307,40 @@ func (tw *baseTimeoutWriter) timeout(msg string) {
enc := json.NewEncoder(tw.w)
enc.Encode(errors.NewServerTimeout(api.Resource(""), "", 0))
}
} else {
// The timeout writer has been used by the inner handler. There is
// no way to timeout the HTTP request at the point. We have to shutdown
// the connection for HTTP1 or reset stream for HTTP2.
//
// Note from: Brad Fitzpatrick
// if the ServeHTTP goroutine panics, that will do the best possible thing for both
// HTTP/1 and HTTP/2. In HTTP/1, assuming you're replying with at least HTTP/1.1 and
// you've already flushed the headers so it's using HTTP chunking, it'll kill the TCP
// connection immediately without a proper 0-byte EOF chunk, so the peer will recognize
// the response as bogus. In HTTP/2 the server will just RST_STREAM the stream, leaving
// the TCP connection open, but resetting the stream to the peer so it'll have an error,
// like the HTTP/1 case.
panic(errConnKilled)
}
tw.timedOut = true
}
func (tw *baseTimeoutWriter) closeNotify() <-chan bool {
tw.mu.Lock()
defer tw.mu.Unlock()
if tw.timedOut {
done := make(chan bool)
close(done)
return done
}
return tw.w.(http.CloseNotifier).CloseNotify()
}
func (tw *baseTimeoutWriter) hijack() (net.Conn, *bufio.ReadWriter, error) {
tw.mu.Lock()
defer tw.mu.Unlock()
if tw.timedOut {
return nil, nil, http.ErrHandlerTimeout
}

View File

@ -633,10 +633,10 @@ func (s *GenericAPIServer) Run(options *options.ServerRunOptions) {
}
if secureLocation != "" {
handler := apiserver.TimeoutHandler(s.Handler, longRunningTimeout)
handler := apiserver.TimeoutHandler(apiserver.RecoverPanics(s.Handler), longRunningTimeout)
secureServer := &http.Server{
Addr: secureLocation,
Handler: apiserver.MaxInFlightLimit(sem, longRunningRequestCheck, apiserver.RecoverPanics(handler)),
Handler: apiserver.MaxInFlightLimit(sem, longRunningRequestCheck, handler),
MaxHeaderBytes: 1 << 20,
TLSConfig: &tls.Config{
// Can't use SSLv3 because of POODLE and BEAST
@ -696,10 +696,10 @@ func (s *GenericAPIServer) Run(options *options.ServerRunOptions) {
}
}
handler := apiserver.TimeoutHandler(s.InsecureHandler, longRunningTimeout)
handler := apiserver.TimeoutHandler(apiserver.RecoverPanics(s.InsecureHandler), longRunningTimeout)
http := &http.Server{
Addr: insecureLocation,
Handler: apiserver.RecoverPanics(handler),
Handler: handler,
MaxHeaderBytes: 1 << 20,
}