diff --git a/pkg/genericapiserver/server/filters/BUILD b/pkg/genericapiserver/server/filters/BUILD index 8a16796f7ba..1b6063b45d0 100644 --- a/pkg/genericapiserver/server/filters/BUILD +++ b/pkg/genericapiserver/server/filters/BUILD @@ -20,10 +20,10 @@ go_library( ], tags = ["automanaged"], deps = [ - "//pkg/api:go_default_library", "//pkg/util:go_default_library", "//vendor:github.com/golang/glog", "//vendor:k8s.io/apimachinery/pkg/api/errors", + "//vendor:k8s.io/apimachinery/pkg/runtime/schema", "//vendor:k8s.io/apimachinery/pkg/util/runtime", "//vendor:k8s.io/apimachinery/pkg/util/sets", "//vendor:k8s.io/apiserver/pkg/endpoints/request", @@ -43,6 +43,7 @@ go_test( deps = [ "//pkg/genericapiserver/endpoints/filters:go_default_library", "//vendor:k8s.io/apimachinery/pkg/api/errors", + "//vendor:k8s.io/apimachinery/pkg/runtime/schema", "//vendor:k8s.io/apimachinery/pkg/util/sets", "//vendor:k8s.io/apiserver/pkg/endpoints/request", ], diff --git a/pkg/genericapiserver/server/filters/timeout.go b/pkg/genericapiserver/server/filters/timeout.go index 69617f1520f..9232fcb516a 100644 --- a/pkg/genericapiserver/server/filters/timeout.go +++ b/pkg/genericapiserver/server/filters/timeout.go @@ -25,9 +25,9 @@ import ( "sync" "time" - "k8s.io/apimachinery/pkg/api/errors" + apierrors "k8s.io/apimachinery/pkg/api/errors" + "k8s.io/apimachinery/pkg/runtime/schema" apirequest "k8s.io/apiserver/pkg/endpoints/request" - "k8s.io/kubernetes/pkg/api" ) const globalTimeout = time.Minute @@ -39,22 +39,24 @@ func WithTimeoutForNonLongRunningRequests(handler http.Handler, requestContextMa if longRunning == nil { return handler } - timeoutFunc := func(req *http.Request) (<-chan time.Time, string) { + timeoutFunc := func(req *http.Request) (<-chan time.Time, *apierrors.StatusError) { // TODO unify this with apiserver.MaxInFlightLimit ctx, ok := requestContextMapper.Get(req) if !ok { - return time.After(globalTimeout), "" + // if this happens, the handler chain isn't setup correctly because there is no context mapper + return time.After(globalTimeout), apierrors.NewInternalError(fmt.Errorf("no context found for request during timeout")) } requestInfo, ok := apirequest.RequestInfoFrom(ctx) if !ok { - return time.After(globalTimeout), "" + // if this happens, the handler chain isn't setup correctly because there is no request info + return time.After(globalTimeout), apierrors.NewInternalError(fmt.Errorf("no request info found for request during timeout")) } if longRunning(req, requestInfo) { - return nil, "" + return nil, nil } - return time.After(globalTimeout), "" + return time.After(globalTimeout), apierrors.NewServerTimeout(schema.GroupResource{Group: requestInfo.APIGroup, Resource: requestInfo.Resource}, requestInfo.Verb, 0) } return WithTimeout(handler, timeoutFunc) } @@ -67,17 +69,17 @@ func WithTimeoutForNonLongRunningRequests(handler http.Handler, requestContextMa // the handler times out, writes by h to its http.ResponseWriter will return // http.ErrHandlerTimeout. If timeoutFunc returns a nil timeout channel, no // timeout will be enforced. -func WithTimeout(h http.Handler, timeoutFunc func(*http.Request) (timeout <-chan time.Time, msg string)) http.Handler { +func WithTimeout(h http.Handler, timeoutFunc func(*http.Request) (timeout <-chan time.Time, err *apierrors.StatusError)) http.Handler { return &timeoutHandler{h, timeoutFunc} } type timeoutHandler struct { handler http.Handler - timeout func(*http.Request) (<-chan time.Time, string) + timeout func(*http.Request) (<-chan time.Time, *apierrors.StatusError) } func (t *timeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - after, msg := t.timeout(r) + after, err := t.timeout(r) if after == nil { t.handler.ServeHTTP(w, r) return @@ -93,13 +95,13 @@ func (t *timeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { case <-done: return case <-after: - tw.timeout(msg) + tw.timeout(err) } } type timeoutWriter interface { http.ResponseWriter - timeout(string) + timeout(*apierrors.StatusError) } func newTimeoutWriter(w http.ResponseWriter) timeoutWriter { @@ -183,7 +185,7 @@ func (tw *baseTimeoutWriter) WriteHeader(code int) { tw.w.WriteHeader(code) } -func (tw *baseTimeoutWriter) timeout(msg string) { +func (tw *baseTimeoutWriter) timeout(err *apierrors.StatusError) { tw.mu.Lock() defer tw.mu.Unlock() @@ -194,12 +196,8 @@ func (tw *baseTimeoutWriter) timeout(msg string) { // handler if !tw.wroteHeader && !tw.hijacked { tw.w.WriteHeader(http.StatusGatewayTimeout) - if msg != "" { - tw.w.Write([]byte(msg)) - } else { - enc := json.NewEncoder(tw.w) - enc.Encode(errors.NewServerTimeout(api.Resource(""), "", 0)) - } + enc := json.NewEncoder(tw.w) + enc.Encode(err) } else { // The timeout writer has been used by the inner handler. There is // no way to timeout the HTTP request at the point. We have to shutdown diff --git a/pkg/genericapiserver/server/filters/timeout_test.go b/pkg/genericapiserver/server/filters/timeout_test.go index 989ce331c2f..4498122414c 100644 --- a/pkg/genericapiserver/server/filters/timeout_test.go +++ b/pkg/genericapiserver/server/filters/timeout_test.go @@ -22,6 +22,10 @@ import ( "net/http/httptest" "testing" "time" + + apierrors "k8s.io/apimachinery/pkg/api/errors" + "k8s.io/apimachinery/pkg/runtime/schema" + "strings" ) func TestTimeout(t *testing.T) { @@ -29,7 +33,7 @@ func TestTimeout(t *testing.T) { writeErrors := make(chan error, 1) timeout := make(chan time.Time, 1) resp := "test response" - timeoutResp := "test timeout" + timeoutErr := apierrors.NewServerTimeout(schema.GroupResource{Group: "foo", Resource: "bar"}, "get", 0) ts := httptest.NewServer(WithTimeout(http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { @@ -37,8 +41,8 @@ func TestTimeout(t *testing.T) { _, err := w.Write([]byte(resp)) writeErrors <- err }), - func(*http.Request) (<-chan time.Time, string) { - return timeout, timeoutResp + func(*http.Request) (<-chan time.Time, *apierrors.StatusError) { + return timeout, timeoutErr })) defer ts.Close() @@ -69,8 +73,8 @@ func TestTimeout(t *testing.T) { t.Errorf("got res.StatusCode %d; expected %d", res.StatusCode, http.StatusServiceUnavailable) } body, _ = ioutil.ReadAll(res.Body) - if string(body) != timeoutResp { - t.Errorf("got body %q; expected %q", string(body), timeoutResp) + if !strings.Contains(string(body), timeoutErr.Error()) { + t.Errorf("got body %q; expected it to contain %q", string(body), timeoutErr.Error()) } // Now try to send a response