plumb context with request deadline

- as soon as a request is received by the apiserver, determine the
timeout of the request and set a new request context with the deadline.
- the timeout filter that times out non-long-running requests should
use the request context as opposed to a fixed 60s wait today.
- admission and storage layer uses the same request context with the
deadline specified.
This commit is contained in:
Abu Kashem 2020-10-30 16:30:05 -04:00
parent d20e3246ba
commit 83f869ee13
No known key found for this signature in database
GPG Key ID: 76146D1A14E658ED
16 changed files with 392 additions and 67 deletions

View File

@ -39,7 +39,14 @@ func BuildInsecureHandlerChain(apiHandler http.Handler, c *server.Config) http.H
handler = genericapifilters.WithAudit(handler, c.AuditBackend, c.AuditPolicyChecker, c.LongRunningFunc) handler = genericapifilters.WithAudit(handler, c.AuditBackend, c.AuditPolicyChecker, c.LongRunningFunc)
handler = genericapifilters.WithAuthentication(handler, server.InsecureSuperuser{}, nil, nil) handler = genericapifilters.WithAuthentication(handler, server.InsecureSuperuser{}, nil, nil)
handler = genericfilters.WithCORS(handler, c.CorsAllowedOriginList, nil, nil, nil, "true") handler = genericfilters.WithCORS(handler, c.CorsAllowedOriginList, nil, nil, nil, "true")
handler = genericfilters.WithTimeoutForNonLongRunningRequests(handler, c.LongRunningFunc, c.RequestTimeout)
// WithTimeoutForNonLongRunningRequests will call the rest of the request handling in a go-routine with the
// context with deadline. The go-routine can keep running, while the timeout logic will return a timeout to the client.
handler = genericfilters.WithTimeoutForNonLongRunningRequests(handler, c.LongRunningFunc)
// WithRequestDeadline sets a deadline for the request context appropriately
handler = genericapifilters.WithRequestDeadline(handler, c.LongRunningFunc, c.RequestTimeout)
handler = genericfilters.WithWaitGroup(handler, c.LongRunningFunc, c.HandlerChainWaitGroup) handler = genericfilters.WithWaitGroup(handler, c.LongRunningFunc, c.HandlerChainWaitGroup)
handler = genericapifilters.WithRequestInfo(handler, requestInfoResolver) handler = genericapifilters.WithRequestInfo(handler, requestInfoResolver)
handler = genericapifilters.WithWarningRecorder(handler) handler = genericapifilters.WithWarningRecorder(handler)

View File

@ -53,6 +53,7 @@ go_test(
"//staging/src/k8s.io/apiserver/pkg/endpoints/testing:go_default_library", "//staging/src/k8s.io/apiserver/pkg/endpoints/testing:go_default_library",
"//staging/src/k8s.io/apiserver/pkg/features:go_default_library", "//staging/src/k8s.io/apiserver/pkg/features:go_default_library",
"//staging/src/k8s.io/apiserver/pkg/registry/rest:go_default_library", "//staging/src/k8s.io/apiserver/pkg/registry/rest:go_default_library",
"//staging/src/k8s.io/apiserver/pkg/server/filters:go_default_library",
"//staging/src/k8s.io/apiserver/pkg/util/feature:go_default_library", "//staging/src/k8s.io/apiserver/pkg/util/feature:go_default_library",
"//staging/src/k8s.io/client-go/dynamic:go_default_library", "//staging/src/k8s.io/client-go/dynamic:go_default_library",
"//staging/src/k8s.io/client-go/rest:go_default_library", "//staging/src/k8s.io/client-go/rest:go_default_library",

View File

@ -73,6 +73,7 @@ import (
genericapitesting "k8s.io/apiserver/pkg/endpoints/testing" genericapitesting "k8s.io/apiserver/pkg/endpoints/testing"
"k8s.io/apiserver/pkg/features" "k8s.io/apiserver/pkg/features"
"k8s.io/apiserver/pkg/registry/rest" "k8s.io/apiserver/pkg/registry/rest"
"k8s.io/apiserver/pkg/server/filters"
utilfeature "k8s.io/apiserver/pkg/util/feature" utilfeature "k8s.io/apiserver/pkg/util/feature"
featuregatetesting "k8s.io/component-base/featuregate/testing" featuregatetesting "k8s.io/component-base/featuregate/testing"
) )
@ -286,6 +287,7 @@ func handleInternal(storage map[string]rest.Storage, admissionControl admission.
// simplified long-running check // simplified long-running check
return requestInfo.Verb == "watch" || requestInfo.Verb == "proxy" return requestInfo.Verb == "watch" || requestInfo.Verb == "proxy"
}) })
handler = genericapifilters.WithRequestDeadline(handler, testLongRunningCheck, 60*time.Second)
handler = genericapifilters.WithRequestInfo(handler, testRequestInfoResolver()) handler = genericapifilters.WithRequestInfo(handler, testRequestInfoResolver())
return &defaultAPIServer{handler, container} return &defaultAPIServer{handler, container}
@ -298,6 +300,11 @@ func testRequestInfoResolver() *request.RequestInfoFactory {
} }
} }
var testLongRunningCheck = filters.BasicLongRunningRequestCheck(
sets.NewString("watch", "proxy"),
sets.NewString("attach", "exec", "proxy", "log", "portforward"),
)
func TestSimpleSetupRight(t *testing.T) { func TestSimpleSetupRight(t *testing.T) {
s := &genericapitesting.Simple{ObjectMeta: metav1.ObjectMeta{Name: "aName"}} s := &genericapitesting.Simple{ObjectMeta: metav1.ObjectMeta{Name: "aName"}}
wire, err := runtime.Encode(codec, s) wire, err := runtime.Encode(codec, s)

View File

@ -16,6 +16,7 @@ go_test(
"cachecontrol_test.go", "cachecontrol_test.go",
"impersonation_test.go", "impersonation_test.go",
"metrics_test.go", "metrics_test.go",
"request_deadline_test.go",
"request_received_time_test.go", "request_received_time_test.go",
"requestinfo_test.go", "requestinfo_test.go",
"warning_test.go", "warning_test.go",
@ -56,6 +57,7 @@ go_library(
"doc.go", "doc.go",
"impersonation.go", "impersonation.go",
"metrics.go", "metrics.go",
"request_deadline.go",
"request_received_time.go", "request_received_time.go",
"requestinfo.go", "requestinfo.go",
"storageversion.go", "storageversion.go",

View File

@ -0,0 +1,105 @@
/*
Copyright 2020 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package filters
import (
"context"
"errors"
"fmt"
"k8s.io/klog/v2"
"net/http"
"time"
"k8s.io/apiserver/pkg/endpoints/request"
)
var (
// The 'timeout' query parameter in the request URL has an invalid timeout specifier
errInvalidTimeoutInURL = errors.New("invalid timeout specified in the request URL")
// The timeout specified in the request URL exceeds the global maximum timeout allowed by the apiserver.
errTimeoutExceedsMaximumAllowed = errors.New("timeout specified in the request URL exceeds the maximum timeout allowed by the server")
)
// WithRequestDeadline determines the deadline of the given request and sets a new context with the appropriate timeout.
// requestTimeoutMaximum specifies the default request timeout value
func WithRequestDeadline(handler http.Handler, longRunning request.LongRunningRequestCheck, requestTimeoutMaximum time.Duration) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
ctx := req.Context()
requestInfo, ok := request.RequestInfoFrom(ctx)
if !ok {
handleError(w, req, http.StatusInternalServerError, fmt.Errorf("no RequestInfo found in context, handler chain must be wrong"))
return
}
if longRunning(req, requestInfo) {
handler.ServeHTTP(w, req)
return
}
userSpecifiedTimeout, ok, err := parseTimeout(req)
if err != nil {
statusCode := http.StatusInternalServerError
if err == errInvalidTimeoutInURL {
statusCode = http.StatusBadRequest
}
handleError(w, req, statusCode, err)
return
}
timeout := requestTimeoutMaximum
if ok {
if userSpecifiedTimeout > requestTimeoutMaximum {
handleError(w, req, http.StatusBadRequest, errTimeoutExceedsMaximumAllowed)
return
}
timeout = userSpecifiedTimeout
}
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
req = req.WithContext(ctx)
handler.ServeHTTP(w, req)
})
}
// parseTimeout parses the given HTTP request URL and extracts the timeout query parameter
// value if specified by the user.
// If a timeout is not specified the function returns false and err is set to nil
// If the value specified is malformed then the function returns false and err is set
func parseTimeout(req *http.Request) (time.Duration, bool, error) {
value := req.URL.Query().Get("timeout")
if value == "" {
return 0, false, nil
}
timeout, err := time.ParseDuration(value)
if err != nil {
return 0, false, errInvalidTimeoutInURL
}
return timeout, true, nil
}
func handleError(w http.ResponseWriter, r *http.Request, code int, err error) {
errorMsg := fmt.Sprintf("Error - %s: %#v", err.Error(), r.RequestURI)
http.Error(w, errorMsg, code)
klog.Errorf(err.Error())
}

View File

@ -0,0 +1,192 @@
/*
Copyright 2020 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package filters
import (
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
"k8s.io/apiserver/pkg/endpoints/request"
)
func TestParseTimeout(t *testing.T) {
tests := []struct {
name string
url string
expected bool
timeoutExpected time.Duration
errExpected error
}{
{
name: "the user does not specify a timeout",
url: "/api/v1/namespaces",
},
{
name: "the user specifies a valid timeout",
url: "/api/v1/namespaces?timeout=10s",
expected: true,
timeoutExpected: 10 * time.Second,
},
{
name: "the use specifies an invalid timeout",
url: "/api/v1/namespaces?timeout=foo",
errExpected: errInvalidTimeoutInURL,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
request, err := http.NewRequest(http.MethodGet, test.url, nil)
if err != nil {
t.Fatalf("failed to create new http request - %v", err)
}
timeoutGot, ok, err := parseTimeout(request)
if test.expected != ok {
t.Errorf("expected: %t, but got: %t", test.expected, ok)
}
if test.errExpected != err {
t.Errorf("expected err: %v, but got: %v", test.errExpected, err)
}
if test.timeoutExpected != timeoutGot {
t.Errorf("expected timeout: %s, but got: %s", test.timeoutExpected, timeoutGot)
}
})
}
}
func TestWithRequestDeadline(t *testing.T) {
const requestTimeoutMaximum = 60 * time.Second
tests := []struct {
name string
requestURL string
longRunning bool
hasDeadlineExpected bool
deadlineExpected time.Duration
handlerCallCountExpected int
statusCodeExpected int
}{
{
name: "the user specifies a valid request timeout",
requestURL: "/api/v1/namespaces?timeout=15s",
longRunning: false,
handlerCallCountExpected: 1,
hasDeadlineExpected: true,
deadlineExpected: 14 * time.Second, // to account for the delay in verification
statusCodeExpected: http.StatusOK,
},
{
name: "the user does not specify any request timeout, default deadline is expected to be set",
requestURL: "/api/v1/namespaces?timeout=",
longRunning: false,
handlerCallCountExpected: 1,
hasDeadlineExpected: true,
deadlineExpected: requestTimeoutMaximum - time.Second, // to account for the delay in verification
statusCodeExpected: http.StatusOK,
},
{
name: "the request is long running, no deadline is expected to be set",
requestURL: "/api/v1/namespaces?timeout=10s",
longRunning: true,
hasDeadlineExpected: false,
handlerCallCountExpected: 1,
statusCodeExpected: http.StatusOK,
},
{
name: "the timeout specified is malformed, the request is aborted with HTTP 400",
requestURL: "/api/v1/namespaces?timeout=foo",
longRunning: false,
statusCodeExpected: http.StatusBadRequest,
},
{
name: "the timeout specified exceeds the maximum deadline allowed, the request is aborted with HTTP 400",
requestURL: fmt.Sprintf("/api/v1/namespaces?timeout=%s", requestTimeoutMaximum+time.Second),
longRunning: false,
statusCodeExpected: http.StatusBadRequest,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
var (
callCount int
hasDeadlineGot bool
deadlineGot time.Duration
)
handler := http.HandlerFunc(func(_ http.ResponseWriter, req *http.Request) {
callCount++
deadlineGot, hasDeadlineGot = deadline(req)
})
withDeadline := WithRequestDeadline(
handler, func(_ *http.Request, _ *request.RequestInfo) bool { return test.longRunning }, requestTimeoutMaximum)
withDeadline = WithRequestInfo(withDeadline, &fakeRequestResolver{})
testRequest, err := http.NewRequest(http.MethodGet, test.requestURL, nil)
if err != nil {
t.Fatalf("failed to create new http request - %v", err)
}
// make sure a default request does not have any deadline set
remaning, ok := deadline(testRequest)
if ok {
t.Fatalf("test setup failed, expected the new HTTP request context to have no deadline but got: %s", remaning)
}
w := httptest.NewRecorder()
withDeadline.ServeHTTP(w, testRequest)
if test.handlerCallCountExpected != callCount {
t.Errorf("expected the request handler to be invoked %d times, but was actually invoked %d times", test.handlerCallCountExpected, callCount)
}
if test.hasDeadlineExpected != hasDeadlineGot {
t.Errorf("expected the request context to have deadline set: %t but got: %t", test.hasDeadlineExpected, hasDeadlineGot)
}
deadlineGot = deadlineGot.Truncate(time.Second)
if test.deadlineExpected != deadlineGot {
t.Errorf("expected a request context with a deadline of %s but got: %s", test.deadlineExpected, deadlineGot)
}
statusCodeGot := w.Result().StatusCode
if test.statusCodeExpected != statusCodeGot {
t.Errorf("expected status code %d but got: %d", test.statusCodeExpected, statusCodeGot)
}
})
}
}
type fakeRequestResolver struct{}
func (r fakeRequestResolver) NewRequestInfo(req *http.Request) (*request.RequestInfo, error) {
return &request.RequestInfo{}, nil
}
func deadline(r *http.Request) (time.Duration, bool) {
if deadline, ok := r.Context().Deadline(); ok {
remaining := time.Until(deadline)
return remaining, ok
}
return 0, false
}

View File

@ -57,9 +57,6 @@ func createHandler(r rest.NamedCreater, scope *RequestScope, admit admission.Int
return return
} }
// TODO: we either want to remove timeout or document it (if we document, move timeout out of this function and declare it in api_installer)
timeout := parseTimeout(req.URL.Query().Get("timeout"))
namespace, name, err := scope.Namer.Name(req) namespace, name, err := scope.Namer.Name(req)
if err != nil { if err != nil {
if includeName { if includeName {
@ -76,7 +73,7 @@ func createHandler(r rest.NamedCreater, scope *RequestScope, admit admission.Int
} }
} }
ctx, cancel := context.WithTimeout(req.Context(), timeout) ctx, cancel := context.WithTimeout(req.Context(), requestTimeout)
defer cancel() defer cancel()
outputMediaType, _, err := negotiation.NegotiateOutputMediaType(req, scope.Serializer, scope) outputMediaType, _, err := negotiation.NegotiateOutputMediaType(req, scope.Serializer, scope)
if err != nil { if err != nil {
@ -155,7 +152,7 @@ func createHandler(r rest.NamedCreater, scope *RequestScope, admit admission.Int
options, options,
) )
} }
result, err := finishRequest(timeout, func() (runtime.Object, error) { result, err := finishRequest(ctx, func() (runtime.Object, error) {
if scope.FieldManager != nil { if scope.FieldManager != nil {
liveObj, err := scope.Creater.New(scope.Kind) liveObj, err := scope.Creater.New(scope.Kind)
if err != nil { if err != nil {

View File

@ -54,15 +54,12 @@ func DeleteResource(r rest.GracefulDeleter, allowsOptions bool, scope *RequestSc
return return
} }
// TODO: we either want to remove timeout or document it (if we document, move timeout out of this function and declare it in api_installer)
timeout := parseTimeout(req.URL.Query().Get("timeout"))
namespace, name, err := scope.Namer.Name(req) namespace, name, err := scope.Namer.Name(req)
if err != nil { if err != nil {
scope.err(err, w, req) scope.err(err, w, req)
return return
} }
ctx, cancel := context.WithTimeout(req.Context(), timeout) ctx, cancel := context.WithTimeout(req.Context(), requestTimeout)
defer cancel() defer cancel()
ctx = request.WithNamespace(ctx, namespace) ctx = request.WithNamespace(ctx, namespace)
ae := request.AuditEventFrom(ctx) ae := request.AuditEventFrom(ctx)
@ -123,7 +120,7 @@ func DeleteResource(r rest.GracefulDeleter, allowsOptions bool, scope *RequestSc
wasDeleted := true wasDeleted := true
userInfo, _ := request.UserFrom(ctx) userInfo, _ := request.UserFrom(ctx)
staticAdmissionAttrs := admission.NewAttributesRecord(nil, nil, scope.Kind, namespace, name, scope.Resource, scope.Subresource, admission.Delete, options, dryrun.IsDryRun(options.DryRun), userInfo) staticAdmissionAttrs := admission.NewAttributesRecord(nil, nil, scope.Kind, namespace, name, scope.Resource, scope.Subresource, admission.Delete, options, dryrun.IsDryRun(options.DryRun), userInfo)
result, err := finishRequest(timeout, func() (runtime.Object, error) { result, err := finishRequest(ctx, func() (runtime.Object, error) {
obj, deleted, err := r.Delete(ctx, name, rest.AdmissionToValidateObjectDeleteFunc(admit, staticAdmissionAttrs, scope), options) obj, deleted, err := r.Delete(ctx, name, rest.AdmissionToValidateObjectDeleteFunc(admit, staticAdmissionAttrs, scope), options)
wasDeleted = deleted wasDeleted = deleted
return obj, err return obj, err
@ -172,16 +169,13 @@ func DeleteCollection(r rest.CollectionDeleter, checkBody bool, scope *RequestSc
return return
} }
// TODO: we either want to remove timeout or document it (if we document, move timeout out of this function and declare it in api_installer)
timeout := parseTimeout(req.URL.Query().Get("timeout"))
namespace, err := scope.Namer.Namespace(req) namespace, err := scope.Namer.Namespace(req)
if err != nil { if err != nil {
scope.err(err, w, req) scope.err(err, w, req)
return return
} }
ctx, cancel := context.WithTimeout(req.Context(), timeout) ctx, cancel := context.WithTimeout(req.Context(), requestTimeout)
defer cancel() defer cancel()
ctx = request.WithNamespace(ctx, namespace) ctx = request.WithNamespace(ctx, namespace)
ae := request.AuditEventFrom(ctx) ae := request.AuditEventFrom(ctx)
@ -265,7 +259,7 @@ func DeleteCollection(r rest.CollectionDeleter, checkBody bool, scope *RequestSc
admit = admission.WithAudit(admit, ae) admit = admission.WithAudit(admit, ae)
userInfo, _ := request.UserFrom(ctx) userInfo, _ := request.UserFrom(ctx)
staticAdmissionAttrs := admission.NewAttributesRecord(nil, nil, scope.Kind, namespace, "", scope.Resource, scope.Subresource, admission.Delete, options, dryrun.IsDryRun(options.DryRun), userInfo) staticAdmissionAttrs := admission.NewAttributesRecord(nil, nil, scope.Kind, namespace, "", scope.Resource, scope.Subresource, admission.Delete, options, dryrun.IsDryRun(options.DryRun), userInfo)
result, err := finishRequest(timeout, func() (runtime.Object, error) { result, err := finishRequest(ctx, func() (runtime.Object, error) {
return r.DeleteCollection(ctx, rest.AdmissionToValidateObjectDeleteFunc(admit, staticAdmissionAttrs, scope), options, &listOptions) return r.DeleteCollection(ctx, rest.AdmissionToValidateObjectDeleteFunc(admit, staticAdmissionAttrs, scope), options, &listOptions)
}) })
if err != nil { if err != nil {

View File

@ -84,18 +84,13 @@ func PatchResource(r rest.Patcher, scope *RequestScope, admit admission.Interfac
return return
} }
// TODO: we either want to remove timeout or document it (if we
// document, move timeout out of this function and declare it in
// api_installer)
timeout := parseTimeout(req.URL.Query().Get("timeout"))
namespace, name, err := scope.Namer.Name(req) namespace, name, err := scope.Namer.Name(req)
if err != nil { if err != nil {
scope.err(err, w, req) scope.err(err, w, req)
return return
} }
ctx, cancel := context.WithTimeout(req.Context(), timeout) ctx, cancel := context.WithTimeout(req.Context(), requestTimeout)
defer cancel() defer cancel()
ctx = request.WithNamespace(ctx, namespace) ctx = request.WithNamespace(ctx, namespace)
@ -208,7 +203,6 @@ func PatchResource(r rest.Patcher, scope *RequestScope, admit admission.Interfac
codec: codec, codec: codec,
timeout: timeout,
options: options, options: options,
restPatcher: r, restPatcher: r,
@ -271,7 +265,6 @@ type patcher struct {
codec runtime.Codec codec runtime.Codec
timeout time.Duration
options *metav1.PatchOptions options *metav1.PatchOptions
// Operation information // Operation information
@ -586,7 +579,7 @@ func (p *patcher) patchResource(ctx context.Context, scope *RequestScope) (runti
wasCreated = created wasCreated = created
return updateObject, updateErr return updateObject, updateErr
} }
result, err := finishRequest(p.timeout, func() (runtime.Object, error) { result, err := finishRequest(ctx, func() (runtime.Object, error) {
result, err := requestFunc() result, err := requestFunc()
// If the object wasn't committed to storage because it's serialized size was too large, // If the object wasn't committed to storage because it's serialized size was too large,
// it is safe to remove managedFields (which can be large) and try again. // it is safe to remove managedFields (which can be large) and try again.

View File

@ -49,6 +49,12 @@ import (
"k8s.io/klog/v2" "k8s.io/klog/v2"
) )
const (
// 34 chose as a number close to 30 that is likely to be unique enough to jump out at me the next time I see a timeout.
// Everyone chooses 30.
requestTimeout = 34 * time.Second
)
// RequestScope encapsulates common fields across all RESTful handler methods. // RequestScope encapsulates common fields across all RESTful handler methods.
type RequestScope struct { type RequestScope struct {
Namer ScopeNamer Namer ScopeNamer
@ -213,7 +219,7 @@ type resultFunc func() (runtime.Object, error)
// finishRequest makes a given resultFunc asynchronous and handles errors returned by the response. // finishRequest makes a given resultFunc asynchronous and handles errors returned by the response.
// An api.Status object with status != success is considered an "error", which interrupts the normal response flow. // An api.Status object with status != success is considered an "error", which interrupts the normal response flow.
func finishRequest(timeout time.Duration, fn resultFunc) (result runtime.Object, err error) { func finishRequest(ctx context.Context, fn resultFunc) (result runtime.Object, err error) {
// these channels need to be buffered to prevent the goroutine below from hanging indefinitely // these channels need to be buffered to prevent the goroutine below from hanging indefinitely
// when the select statement reads something other than the one the goroutine sends on. // when the select statement reads something other than the one the goroutine sends on.
ch := make(chan runtime.Object, 1) ch := make(chan runtime.Object, 1)
@ -257,8 +263,8 @@ func finishRequest(timeout time.Duration, fn resultFunc) (result runtime.Object,
return nil, err return nil, err
case p := <-panicCh: case p := <-panicCh:
panic(p) panic(p)
case <-time.After(timeout): case <-ctx.Done():
return nil, errors.NewTimeoutError(fmt.Sprintf("request did not complete within requested timeout %s", timeout), 0) return nil, errors.NewTimeoutError(fmt.Sprintf("request did not complete within requested timeout %s", ctx.Err()), 0)
} }
} }

View File

@ -456,8 +456,6 @@ func (tc *patchTestCase) Run(t *testing.T) {
codec: codec, codec: codec,
timeout: 1 * time.Second,
restPatcher: testPatcher, restPatcher: testPatcher,
name: name, name: name,
patchType: patchType, patchType: patchType,
@ -466,7 +464,10 @@ func (tc *patchTestCase) Run(t *testing.T) {
trace: utiltrace.New("Patch", utiltrace.Field{"name", name}), trace: utiltrace.New("Patch", utiltrace.Field{"name", name}),
} }
ctx, cancel := context.WithTimeout(ctx, time.Second)
resultObj, _, err := p.patchResource(ctx, &RequestScope{}) resultObj, _, err := p.patchResource(ctx, &RequestScope{})
cancel()
if len(tc.expectedError) != 0 { if len(tc.expectedError) != 0 {
if err == nil || err.Error() != tc.expectedError { if err == nil || err.Error() != tc.expectedError {
t.Errorf("%s: expected error %v, but got %v", tc.name, tc.expectedError, err) t.Errorf("%s: expected error %v, but got %v", tc.name, tc.expectedError, err)
@ -842,9 +843,13 @@ func TestFinishRequest(t *testing.T) {
exampleErr := fmt.Errorf("error") exampleErr := fmt.Errorf("error")
successStatusObj := &metav1.Status{Status: metav1.StatusSuccess, Message: "success message"} successStatusObj := &metav1.Status{Status: metav1.StatusSuccess, Message: "success message"}
errorStatusObj := &metav1.Status{Status: metav1.StatusFailure, Message: "error message"} errorStatusObj := &metav1.Status{Status: metav1.StatusFailure, Message: "error message"}
timeoutFunc := func() (context.Context, context.CancelFunc) {
return context.WithTimeout(context.TODO(), time.Second)
}
testcases := []struct { testcases := []struct {
name string name string
timeout time.Duration timeout func() (context.Context, context.CancelFunc)
fn resultFunc fn resultFunc
expectedObj runtime.Object expectedObj runtime.Object
expectedErr error expectedErr error
@ -854,7 +859,7 @@ func TestFinishRequest(t *testing.T) {
}{ }{
{ {
name: "Expected obj is returned", name: "Expected obj is returned",
timeout: time.Second, timeout: timeoutFunc,
fn: func() (runtime.Object, error) { fn: func() (runtime.Object, error) {
return exampleObj, nil return exampleObj, nil
}, },
@ -863,7 +868,7 @@ func TestFinishRequest(t *testing.T) {
}, },
{ {
name: "Expected error is returned", name: "Expected error is returned",
timeout: time.Second, timeout: timeoutFunc,
fn: func() (runtime.Object, error) { fn: func() (runtime.Object, error) {
return nil, exampleErr return nil, exampleErr
}, },
@ -872,7 +877,7 @@ func TestFinishRequest(t *testing.T) {
}, },
{ {
name: "Successful status object is returned as expected", name: "Successful status object is returned as expected",
timeout: time.Second, timeout: timeoutFunc,
fn: func() (runtime.Object, error) { fn: func() (runtime.Object, error) {
return successStatusObj, nil return successStatusObj, nil
}, },
@ -881,7 +886,7 @@ func TestFinishRequest(t *testing.T) {
}, },
{ {
name: "Error status object is converted to StatusError", name: "Error status object is converted to StatusError",
timeout: time.Second, timeout: timeoutFunc,
fn: func() (runtime.Object, error) { fn: func() (runtime.Object, error) {
return errorStatusObj, nil return errorStatusObj, nil
}, },
@ -890,7 +895,7 @@ func TestFinishRequest(t *testing.T) {
}, },
{ {
name: "Panic is propagated up", name: "Panic is propagated up",
timeout: time.Second, timeout: timeoutFunc,
fn: func() (runtime.Object, error) { fn: func() (runtime.Object, error) {
panic("my panic") panic("my panic")
}, },
@ -900,7 +905,7 @@ func TestFinishRequest(t *testing.T) {
}, },
{ {
name: "Panic is propagated with stack", name: "Panic is propagated with stack",
timeout: time.Second, timeout: timeoutFunc,
fn: func() (runtime.Object, error) { fn: func() (runtime.Object, error) {
panic("my panic") panic("my panic")
}, },
@ -910,7 +915,7 @@ func TestFinishRequest(t *testing.T) {
}, },
{ {
name: "http.ErrAbortHandler panic is propagated without wrapping with stack", name: "http.ErrAbortHandler panic is propagated without wrapping with stack",
timeout: time.Second, timeout: timeoutFunc,
fn: func() (runtime.Object, error) { fn: func() (runtime.Object, error) {
panic(http.ErrAbortHandler) panic(http.ErrAbortHandler)
}, },
@ -922,7 +927,10 @@ func TestFinishRequest(t *testing.T) {
} }
for i, tc := range testcases { for i, tc := range testcases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
ctx, cancel := tc.timeout()
defer func() { defer func() {
cancel()
r := recover() r := recover()
switch { switch {
case r == nil && len(tc.expectedPanic) > 0: case r == nil && len(tc.expectedPanic) > 0:
@ -937,7 +945,7 @@ func TestFinishRequest(t *testing.T) {
t.Errorf("expected panic obj %#v, got %#v", tc.expectedPanicObj, r) t.Errorf("expected panic obj %#v, got %#v", tc.expectedPanicObj, r)
} }
}() }()
obj, err := finishRequest(tc.timeout, tc.fn) obj, err := finishRequest(ctx, tc.fn)
if (err == nil && tc.expectedErr != nil) || (err != nil && tc.expectedErr == nil) || (err != nil && tc.expectedErr != nil && err.Error() != tc.expectedErr.Error()) { if (err == nil && tc.expectedErr != nil) || (err != nil && tc.expectedErr == nil) || (err != nil && tc.expectedErr != nil && err.Error() != tc.expectedErr.Error()) {
t.Errorf("%d: unexpected err. expected: %v, got: %v", i, tc.expectedErr, err) t.Errorf("%d: unexpected err. expected: %v, got: %v", i, tc.expectedErr, err)
} }

View File

@ -54,15 +54,12 @@ func UpdateResource(r rest.Updater, scope *RequestScope, admit admission.Interfa
return return
} }
// TODO: we either want to remove timeout or document it (if we document, move timeout out of this function and declare it in api_installer)
timeout := parseTimeout(req.URL.Query().Get("timeout"))
namespace, name, err := scope.Namer.Name(req) namespace, name, err := scope.Namer.Name(req)
if err != nil { if err != nil {
scope.err(err, w, req) scope.err(err, w, req)
return return
} }
ctx, cancel := context.WithTimeout(req.Context(), timeout) ctx, cancel := context.WithTimeout(req.Context(), requestTimeout)
defer cancel() defer cancel()
ctx = request.WithNamespace(ctx, namespace) ctx = request.WithNamespace(ctx, namespace)
@ -188,7 +185,7 @@ func UpdateResource(r rest.Updater, scope *RequestScope, admit admission.Interfa
wasCreated = created wasCreated = created
return obj, err return obj, err
} }
result, err := finishRequest(timeout, func() (runtime.Object, error) { result, err := finishRequest(ctx, func() (runtime.Object, error) {
result, err := requestFunc() result, err := requestFunc()
// If the object wasn't committed to storage because it's serialized size was too large, // If the object wasn't committed to storage because it's serialized size was too large,
// it is safe to remove managedFields (which can be large) and try again. // it is safe to remove managedFields (which can be large) and try again.

View File

@ -746,7 +746,14 @@ func DefaultBuildHandlerChain(apiHandler http.Handler, c *Config) http.Handler {
handler = filterlatency.TrackStarted(handler, "authentication") handler = filterlatency.TrackStarted(handler, "authentication")
handler = genericfilters.WithCORS(handler, c.CorsAllowedOriginList, nil, nil, nil, "true") handler = genericfilters.WithCORS(handler, c.CorsAllowedOriginList, nil, nil, nil, "true")
handler = genericfilters.WithTimeoutForNonLongRunningRequests(handler, c.LongRunningFunc, c.RequestTimeout)
// WithTimeoutForNonLongRunningRequests will call the rest of the request handling in a go-routine with the
// context with deadline. The go-routine can keep running, while the timeout logic will return a timeout to the client.
handler = genericfilters.WithTimeoutForNonLongRunningRequests(handler, c.LongRunningFunc)
// WithRequestDeadline sets a deadline for the request context appropriately
handler = genericapifilters.WithRequestDeadline(handler, c.LongRunningFunc, c.RequestTimeout)
handler = genericfilters.WithWaitGroup(handler, c.LongRunningFunc, c.HandlerChainWaitGroup) handler = genericfilters.WithWaitGroup(handler, c.LongRunningFunc, c.HandlerChainWaitGroup)
handler = genericapifilters.WithRequestInfo(handler, c.RequestInfoResolver) handler = genericapifilters.WithRequestInfo(handler, c.RequestInfoResolver)
if c.SecureServing != nil && !c.SecureServing.DisableHTTP2 && c.GoawayChance > 0 { if c.SecureServing != nil && !c.SecureServing.DisableHTTP2 && c.GoawayChance > 0 {

View File

@ -18,14 +18,12 @@ package filters
import ( import (
"bufio" "bufio"
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
"runtime" "runtime"
"sync" "sync"
"time"
apierrors "k8s.io/apimachinery/pkg/api/errors" apierrors "k8s.io/apimachinery/pkg/api/errors"
utilruntime "k8s.io/apimachinery/pkg/util/runtime" utilruntime "k8s.io/apimachinery/pkg/util/runtime"
@ -34,37 +32,33 @@ import (
) )
// WithTimeoutForNonLongRunningRequests times out non-long-running requests after the time given by timeout. // WithTimeoutForNonLongRunningRequests times out non-long-running requests after the time given by timeout.
func WithTimeoutForNonLongRunningRequests(handler http.Handler, longRunning apirequest.LongRunningRequestCheck, timeout time.Duration) http.Handler { func WithTimeoutForNonLongRunningRequests(handler http.Handler, longRunning apirequest.LongRunningRequestCheck) http.Handler {
if longRunning == nil { if longRunning == nil {
return handler return handler
} }
timeoutFunc := func(req *http.Request) (*http.Request, <-chan time.Time, func(), *apierrors.StatusError) { timeoutFunc := func(req *http.Request) (*http.Request, bool, func(), *apierrors.StatusError) {
// TODO unify this with apiserver.MaxInFlightLimit // TODO unify this with apiserver.MaxInFlightLimit
ctx := req.Context() ctx := req.Context()
requestInfo, ok := apirequest.RequestInfoFrom(ctx) requestInfo, ok := apirequest.RequestInfoFrom(ctx)
if !ok { if !ok {
// if this happens, the handler chain isn't setup correctly because there is no request info // if this happens, the handler chain isn't setup correctly because there is no request info
return req, time.After(timeout), func() {}, apierrors.NewInternalError(fmt.Errorf("no request info found for request during timeout")) return req, false, func() {}, apierrors.NewInternalError(fmt.Errorf("no request info found for request during timeout"))
} }
if longRunning(req, requestInfo) { if longRunning(req, requestInfo) {
return req, nil, nil, nil return req, true, nil, nil
} }
ctx, cancel := context.WithCancel(ctx)
req = req.WithContext(ctx)
postTimeoutFn := func() { postTimeoutFn := func() {
cancel()
metrics.RecordRequestTermination(req, requestInfo, metrics.APIServerComponent, http.StatusGatewayTimeout) metrics.RecordRequestTermination(req, requestInfo, metrics.APIServerComponent, http.StatusGatewayTimeout)
} }
return req, time.After(timeout), postTimeoutFn, apierrors.NewTimeoutError(fmt.Sprintf("request did not complete within %s", timeout), 0) return req, false, postTimeoutFn, apierrors.NewTimeoutError("request did not complete within the allotted timeout", 0)
} }
return WithTimeout(handler, timeoutFunc) return WithTimeout(handler, timeoutFunc)
} }
type timeoutFunc = func(*http.Request) (req *http.Request, timeout <-chan time.Time, postTimeoutFunc func(), err *apierrors.StatusError) type timeoutFunc = func(*http.Request) (req *http.Request, longRunning bool, postTimeoutFunc func(), err *apierrors.StatusError)
// WithTimeout returns an http.Handler that runs h with a timeout // WithTimeout returns an http.Handler that runs h with a timeout
// determined by timeoutFunc. The new http.Handler calls h.ServeHTTP to handle // determined by timeoutFunc. The new http.Handler calls h.ServeHTTP to handle
@ -85,12 +79,14 @@ type timeoutHandler struct {
} }
func (t *timeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (t *timeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
r, after, postTimeoutFn, err := t.timeout(r) r, longRunning, postTimeoutFn, err := t.timeout(r)
if after == nil { if longRunning {
t.handler.ServeHTTP(w, r) t.handler.ServeHTTP(w, r)
return return
} }
timeout := r.Context().Done()
// resultCh is used as both errCh and stopCh // resultCh is used as both errCh and stopCh
resultCh := make(chan interface{}) resultCh := make(chan interface{})
tw := newTimeoutWriter(w) tw := newTimeoutWriter(w)
@ -117,7 +113,7 @@ func (t *timeoutHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
panic(err) panic(err)
} }
return return
case <-after: case <-timeout:
defer func() { defer func() {
// resultCh needs to have a reader, since the function doing // resultCh needs to have a reader, since the function doing
// the work needs to send to it. This is defer'd to ensure it runs // the work needs to send to it. This is defer'd to ensure it runs

View File

@ -18,6 +18,7 @@ package filters
import ( import (
"bytes" "bytes"
"context"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"encoding/json" "encoding/json"
@ -92,18 +93,27 @@ func TestTimeout(t *testing.T) {
timeoutErr := apierrors.NewServerTimeout(schema.GroupResource{Group: "foo", Resource: "bar"}, "get", 0) timeoutErr := apierrors.NewServerTimeout(schema.GroupResource{Group: "foo", Resource: "bar"}, "get", 0)
record := &recorder{} record := &recorder{}
var ctx context.Context
withDeadline := func(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
req = req.WithContext(ctx)
handler.ServeHTTP(w, req)
})
}
handler := newHandler(sendResponse, doPanic, writeErrors) handler := newHandler(sendResponse, doPanic, writeErrors)
ts := httptest.NewServer(withPanicRecovery( ts := httptest.NewServer(withDeadline(withPanicRecovery(
WithTimeout(handler, func(req *http.Request) (*http.Request, <-chan time.Time, func(), *apierrors.StatusError) { WithTimeout(handler, func(req *http.Request) (*http.Request, bool, func(), *apierrors.StatusError) {
return req, timeout, record.Record, timeoutErr return req, false, record.Record, timeoutErr
}), func(w http.ResponseWriter, req *http.Request, err interface{}) { }), func(w http.ResponseWriter, req *http.Request, err interface{}) {
gotPanic <- err gotPanic <- err
http.Error(w, "This request caused apiserver to panic. Look in the logs for details.", http.StatusInternalServerError) http.Error(w, "This request caused apiserver to panic. Look in the logs for details.", http.StatusInternalServerError)
}), }),
) ))
defer ts.Close() defer ts.Close()
// No timeouts // No timeouts
ctx = context.Background()
sendResponse <- resp sendResponse <- resp
res, err := http.Get(ts.URL) res, err := http.Get(ts.URL)
if err != nil { if err != nil {
@ -124,6 +134,8 @@ func TestTimeout(t *testing.T) {
} }
// Times out // Times out
ctx, cancel := context.WithCancel(context.Background())
cancel()
timeout <- time.Time{} timeout <- time.Time{}
res, err = http.Get(ts.URL) res, err = http.Get(ts.URL)
if err != nil { if err != nil {
@ -145,6 +157,7 @@ func TestTimeout(t *testing.T) {
} }
// Now try to send a response // Now try to send a response
ctx = context.Background()
sendResponse <- resp sendResponse <- resp
if err := <-writeErrors; err != http.ErrHandlerTimeout { if err := <-writeErrors; err != http.ErrHandlerTimeout {
t.Errorf("got Write error of %v; expected %v", err, http.ErrHandlerTimeout) t.Errorf("got Write error of %v; expected %v", err, http.ErrHandlerTimeout)
@ -170,6 +183,7 @@ func TestTimeout(t *testing.T) {
} }
// Panics with http.ErrAbortHandler // Panics with http.ErrAbortHandler
ctx = context.Background()
doPanic <- http.ErrAbortHandler doPanic <- http.ErrAbortHandler
res, err = http.Get(ts.URL) res, err = http.Get(ts.URL)
if err != nil { if err != nil {

View File

@ -124,7 +124,7 @@ func testWebhookTimeout(t *testing.T, watchCache bool) {
}, },
{ {
name: "timed out client requests skip later mutating webhooks (regardless of failure policy) and fail", name: "timed out client requests skip later mutating webhooks (regardless of failure policy) and fail",
timeoutSeconds: 3, timeoutSeconds: 4,
mutatingWebhooks: []testWebhook{ mutatingWebhooks: []testWebhook{
{path: "/mutating/1/5s", policy: admissionv1beta1.Ignore, timeoutSeconds: 4}, {path: "/mutating/1/5s", policy: admissionv1beta1.Ignore, timeoutSeconds: 4},
{path: "/mutating/2/1s", policy: admissionv1beta1.Ignore, timeoutSeconds: 5}, {path: "/mutating/2/1s", policy: admissionv1beta1.Ignore, timeoutSeconds: 5},
@ -133,8 +133,7 @@ func testWebhookTimeout(t *testing.T, watchCache bool) {
expectInvocations: []invocation{ expectInvocations: []invocation{
{path: "/mutating/1/5s", timeoutSeconds: 3}, // from request {path: "/mutating/1/5s", timeoutSeconds: 3}, // from request
}, },
expectError: true, expectError: true,
errorContains: "request did not complete within requested timeout",
}, },
} }