From 6385b86a9b124eb03848af9a3029e8bc9058d72f Mon Sep 17 00:00:00 2001 From: Abu Kashem Date: Fri, 13 Jan 2023 18:04:13 -0500 Subject: [PATCH] apiserver: terminate watch with a rate limiter during shutdown --- .../src/k8s.io/apiserver/pkg/server/config.go | 7 + .../pkg/server/filters/watch_termination.go | 62 +++++++ .../server/filters/watch_termination_test.go | 166 ++++++++++++++++++ .../apiserver/pkg/server/genericapiserver.go | 84 +++++++-- ...ericapiserver_graceful_termination_test.go | 109 +++++++++--- 5 files changed, 387 insertions(+), 41 deletions(-) create mode 100644 staging/src/k8s.io/apiserver/pkg/server/filters/watch_termination.go create mode 100644 staging/src/k8s.io/apiserver/pkg/server/filters/watch_termination_test.go diff --git a/staging/src/k8s.io/apiserver/pkg/server/config.go b/staging/src/k8s.io/apiserver/pkg/server/config.go index 86f8eca6d8e..0d348370ab6 100644 --- a/staging/src/k8s.io/apiserver/pkg/server/config.go +++ b/staging/src/k8s.io/apiserver/pkg/server/config.go @@ -161,6 +161,10 @@ type Config struct { // handlers associated with non long-running requests // to complete while the server is shuting down. NonLongRunningRequestWaitGroup *utilwaitgroup.SafeWaitGroup + // WatchRequestWaitGroup allows us to wait for all chain + // handlers associated with active watch requests to + // complete while the server is shuting down. + WatchRequestWaitGroup *utilwaitgroup.RateLimitedSafeWaitGroup // DiscoveryAddresses is used to build the IPs pass to discovery. If nil, the ExternalAddress is // always reported DiscoveryAddresses discovery.Addresses @@ -371,6 +375,7 @@ func NewConfig(codecs serializer.CodecFactory) *Config { Serializer: codecs, BuildHandlerChainFunc: DefaultBuildHandlerChain, NonLongRunningRequestWaitGroup: new(utilwaitgroup.SafeWaitGroup), + WatchRequestWaitGroup: &utilwaitgroup.RateLimitedSafeWaitGroup{}, LegacyAPIGroupPrefixes: sets.NewString(DefaultLegacyAPIPrefix), DisabledPostStartHooks: sets.NewString(), PostStartHooks: map[string]PostStartHookConfigEntry{}, @@ -670,6 +675,7 @@ func (c completedConfig) New(name string, delegationTarget DelegationTarget) (*G delegationTarget: delegationTarget, EquivalentResourceRegistry: c.EquivalentResourceRegistry, NonLongRunningRequestWaitGroup: c.NonLongRunningRequestWaitGroup, + WatchRequestWaitGroup: c.WatchRequestWaitGroup, Handler: apiServerHandler, UnprotectedDebugSocket: debugSocket, @@ -907,6 +913,7 @@ func DefaultBuildHandlerChain(apiHandler http.Handler, c *Config) http.Handler { handler = genericapifilters.WithRequestDeadline(handler, c.AuditBackend, c.AuditPolicyRuleEvaluator, c.LongRunningFunc, c.Serializer, c.RequestTimeout) handler = genericfilters.WithWaitGroup(handler, c.LongRunningFunc, c.NonLongRunningRequestWaitGroup) + handler = genericfilters.WithWatchTerminationDuringShutdown(handler, c.lifecycleSignals, c.WatchRequestWaitGroup) if c.SecureServing != nil && !c.SecureServing.DisableHTTP2 && c.GoawayChance > 0 { handler = genericfilters.WithProbabilisticGoaway(handler, c.GoawayChance) } diff --git a/staging/src/k8s.io/apiserver/pkg/server/filters/watch_termination.go b/staging/src/k8s.io/apiserver/pkg/server/filters/watch_termination.go new file mode 100644 index 00000000000..515f38e516f --- /dev/null +++ b/staging/src/k8s.io/apiserver/pkg/server/filters/watch_termination.go @@ -0,0 +1,62 @@ +/* +Copyright 2023 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package filters + +import ( + "errors" + "net/http" + + "k8s.io/apiserver/pkg/endpoints/handlers/responsewriters" + apirequest "k8s.io/apiserver/pkg/endpoints/request" + "k8s.io/klog/v2" +) + +func WithWatchTerminationDuringShutdown(handler http.Handler, termination apirequest.ServerShutdownSignal, wg RequestWaitGroup) http.Handler { + if termination == nil || wg == nil { + klog.Warningf("watch termination during shutdown not attached to the handler chain") + return handler + } + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + ctx := req.Context() + requestInfo, ok := apirequest.RequestInfoFrom(ctx) + if !ok { + // if this happens, the handler chain isn't setup correctly because there is no request info + responsewriters.InternalError(w, req, errors.New("no RequestInfo found in the context")) + return + } + if !watchVerbs.Has(requestInfo.Verb) { + handler.ServeHTTP(w, req) + return + } + + if err := wg.Add(1); err != nil { + // When apiserver is shutting down, signal clients to retry + // There is a good chance the client hit a different server, so a tight retry is good for client responsiveness. + waitGroupWriteRetryAfterToResponse(w) + return + } + + // attach ServerShutdownSignal to the watch request so that the + // watch handler loop can return as soon as the server signals + // that it is shutting down. + ctx = apirequest.WithServerShutdownSignal(req.Context(), termination) + req = req.WithContext(ctx) + + defer wg.Done() + handler.ServeHTTP(w, req) + }) +} diff --git a/staging/src/k8s.io/apiserver/pkg/server/filters/watch_termination_test.go b/staging/src/k8s.io/apiserver/pkg/server/filters/watch_termination_test.go new file mode 100644 index 00000000000..af69250e95c --- /dev/null +++ b/staging/src/k8s.io/apiserver/pkg/server/filters/watch_termination_test.go @@ -0,0 +1,166 @@ +/* +Copyright 2021 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package filters + +import ( + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "k8s.io/apimachinery/pkg/runtime" + apirequest "k8s.io/apiserver/pkg/endpoints/request" +) + +func TestWithWatchTerminationDuringShutdown(t *testing.T) { + tests := []struct { + name string + requestInfo *apirequest.RequestInfo + signal *fakeServerShutdownSignal + wg *fakeRequestWaitGroup + handlerInvoked int + statusCodeExpected int + retryAfterExpected bool + wgInvokedExpected int + signalAttachedToContext bool + }{ + { + name: "no RequestInfo attached to request context", + handlerInvoked: 0, + statusCodeExpected: http.StatusInternalServerError, + }, + { + name: "request is not a WATCH, not added into wait group", + requestInfo: &apirequest.RequestInfo{Verb: "get"}, + handlerInvoked: 1, + statusCodeExpected: http.StatusOK, + }, + { + name: "request is a WATCH, wait group is in waiting mode", + requestInfo: &apirequest.RequestInfo{Verb: "watch"}, + wg: &fakeRequestWaitGroup{waiting: true}, + handlerInvoked: 0, + signalAttachedToContext: false, + wgInvokedExpected: 1, + retryAfterExpected: true, + statusCodeExpected: http.StatusServiceUnavailable, + }, + { + name: "request is a WATCH, wait group is accepting", + requestInfo: &apirequest.RequestInfo{Verb: "watch"}, + wg: &fakeRequestWaitGroup{}, + signal: &fakeServerShutdownSignal{}, + wgInvokedExpected: 1, + signalAttachedToContext: true, + handlerInvoked: 1, + statusCodeExpected: http.StatusOK, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var ( + handlerInvokedGot int + signalGot *fakeServerShutdownSignal + ) + delegate := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + handlerInvokedGot++ + if signal := apirequest.ServerShutdownSignalFrom(req.Context()); signal != nil { + signalGot, _ = signal.(*fakeServerShutdownSignal) + } + w.WriteHeader(http.StatusOK) + }) + + handler := WithWatchTerminationDuringShutdown(delegate, test.signal, test.wg) + + req, err := http.NewRequest(http.MethodGet, "/apis/groups.k8s.io/v1/namespaces", nil) + if err != nil { + t.Fatalf("failed to create new http request - %v", err) + } + if test.requestInfo != nil { + req = req.WithContext(apirequest.WithRequestInfo(req.Context(), test.requestInfo)) + } + + w := httptest.NewRecorder() + w.Code = 0 + handler.ServeHTTP(w, req) + responseGot := w.Result() + + if test.handlerInvoked != handlerInvokedGot { + t.Errorf("expected the handler to be invoked: %d timed, but got: %d", test.handlerInvoked, handlerInvokedGot) + } + if test.statusCodeExpected != responseGot.StatusCode { + t.Errorf("expected status code: %d, but got: %d", test.statusCodeExpected, w.Result().StatusCode) + } + retryAfterGot := retryAfterSent(responseGot) + if test.retryAfterExpected != retryAfterGot { + t.Errorf("expected retry-after: %t, but got: %t, response: %v#", test.retryAfterExpected, retryAfterGot, responseGot) + } + + switch { + case test.signalAttachedToContext: + if test.signal == nil || test.signal != signalGot { + t.Errorf("expected request context to have server shutdown signal: %p, but got: %p", test.signal, signalGot) + } + default: + if signalGot != nil { + t.Errorf("expected request context to not have server shutdown signal: %p, but got: %p", test.signal, signalGot) + } + } + if test.wg == nil { + return + } + if test.wg.inflight != 0 { + t.Errorf("expected wait group inflight to be zero, but got: %d", test.wg.inflight) + } + if test.wgInvokedExpected != test.wg.invoked { + t.Errorf("expected wait group Add to be invoked: %d times, but got: %d", test.wgInvokedExpected, test.wg.invoked) + } + }) + } +} + +type fakeServerShutdownSignal struct{} + +func (fakeServerShutdownSignal) ShuttingDown() <-chan struct{} { return nil } + +type fakeRequestWaitGroup struct { + waiting bool + invoked, inflight int +} + +func (f *fakeRequestWaitGroup) Add(delta int) error { + f.invoked++ + if f.waiting { + return fmt.Errorf("waitgroup is in waiting mode") + } + f.inflight += delta + return nil +} +func (f *fakeRequestWaitGroup) Done() { f.inflight-- } + +func retryAfterSent(resp *http.Response) bool { + switch { + case resp.StatusCode == http.StatusServiceUnavailable && + resp.Header.Get("Retry-After") == "1" && + resp.Header.Get("Content-Type") == runtime.ContentTypeJSON && + resp.Header.Get("X-Content-Type-Options") == "nosniff": + return true + default: + return false + } +} diff --git a/staging/src/k8s.io/apiserver/pkg/server/genericapiserver.go b/staging/src/k8s.io/apiserver/pkg/server/genericapiserver.go index 1a2554048bc..2226b1e954a 100644 --- a/staging/src/k8s.io/apiserver/pkg/server/genericapiserver.go +++ b/staging/src/k8s.io/apiserver/pkg/server/genericapiserver.go @@ -17,6 +17,7 @@ limitations under the License. package server import ( + "context" "fmt" "net/http" gpath "path" @@ -26,6 +27,7 @@ import ( systemd "github.com/coreos/go-systemd/v22/daemon" + "golang.org/x/time/rate" apidiscoveryv2beta1 "k8s.io/api/apidiscovery/v2beta1" "k8s.io/apimachinery/pkg/api/meta" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -221,6 +223,10 @@ type GenericAPIServer struct { // handlers associated with non long-running requests // to complete while the server is shuting down. NonLongRunningRequestWaitGroup *utilwaitgroup.SafeWaitGroup + // WatchRequestWaitGroup allows us to wait for all chain + // handlers associated with active watch requests to + // complete while the server is shuting down. + WatchRequestWaitGroup *utilwaitgroup.RateLimitedSafeWaitGroup // ShutdownDelayDuration allows to block shutdown for some time, e.g. until endpoints pointing to this API server // have converged on all node. During this time, the API server keeps serving, /healthz will return 200, @@ -447,23 +453,27 @@ func (s *GenericAPIServer) PrepareRun() preparedGenericAPIServer { // | NotAcceptingNewRequest (notAcceptingNewRequestCh) // | | // | | -// | |---------------------------------------------------------| -// | | | | | -// | [without [with | | -// | ShutdownSendRetryAfter] ShutdownSendRetryAfter] | | -// | | | | | -// | | ---------------| | -// | | | | -// | | (NonLongRunningRequestWaitGroup::Wait) | -// | | | | -// | | InFlightRequestsDrained (drainedCh) | -// | | | | -// | ----------------------------------------|-----------------| -// | | | +// | |----------------------------------------------------------------------------------| +// | | | | | +// | [without [with | | +// | ShutdownSendRetryAfter] ShutdownSendRetryAfter] | | +// | | | | | +// | | ---------------| | +// | | | | +// | | |----------------|-----------------------| | +// | | | | | +// | | (NonLongRunningRequestWaitGroup::Wait) (WatchRequestWaitGroup::Wait) | +// | | | | | +// | | |------------------|---------------------| | +// | | | | +// | | InFlightRequestsDrained (drainedCh) | +// | | | | +// | |-------------------|---------------------|----------------------------------------| +// | | | // | stopHttpServerCh (AuditBackend::Shutdown()) -// | | +// | | // | listenerStoppedCh -// | | +// | | // | HTTPServerStoppedListening (httpServerStoppedListeningCh) func (s preparedGenericAPIServer) Run(stopCh <-chan struct{}) error { delayedStopCh := s.lifecycleSignals.AfterShutdownDelayDuration @@ -576,9 +586,11 @@ func (s preparedGenericAPIServer) Run(stopCh <-chan struct{}) error { <-preShutdownHooksHasStoppedCh.Signaled() }() + // wait for all in-flight non-long running requests to finish + nonLongRunningRequestDrainedCh := make(chan struct{}) go func() { - defer klog.V(1).InfoS("[graceful-termination] shutdown event", "name", drainedCh.Name()) - defer drainedCh.Signal() + defer close(nonLongRunningRequestDrainedCh) + defer klog.V(1).Info("[graceful-termination] in-flight non long-running request(s) have drained") // wait for the delayed stopCh before closing the handler chain (it rejects everything after Wait has been called). <-notAcceptingNewRequestCh.Signaled() @@ -599,6 +611,44 @@ func (s preparedGenericAPIServer) Run(stopCh <-chan struct{}) error { s.NonLongRunningRequestWaitGroup.Wait() }() + // wait for all in-flight watches to finish + activeWatchesDrainedCh := make(chan struct{}) + go func() { + defer close(activeWatchesDrainedCh) + + <-notAcceptingNewRequestCh.Signaled() + + // Wait for all active watches to finish + // TODO(tkashem): make the grace period configurable + grace := 10 * time.Second + activeBefore, activeAfter, err := s.WatchRequestWaitGroup.Wait(func(count int) (utilwaitgroup.RateLimiter, context.Context, context.CancelFunc) { + qps := float64(count) / grace.Seconds() + // TODO: we don't want the QPS (max requests drained per second) to + // get below a certain floor value, since we want the server to + // drain the active watch requests as soon as possible. + // For now, it's hard coded to 200, and it is subject to change + // based on the result from the scale testing. + if qps < 200 { + qps = 200 + } + + ctx, cancel := context.WithTimeout(context.Background(), grace) + // We don't expect more than one token to be consumed + // in a single Wait call, so setting burst to 1. + return rate.NewLimiter(rate.Limit(qps), 1), ctx, cancel + }) + klog.V(1).InfoS("[graceful-termination] active watch request(s) have drained", + "duration", grace, "activeWatchesBefore", activeBefore, "activeWatchesAfter", activeAfter, "error", err) + }() + + go func() { + defer klog.V(1).InfoS("[graceful-termination] shutdown event", "name", drainedCh.Name()) + defer drainedCh.Signal() + + <-nonLongRunningRequestDrainedCh + <-activeWatchesDrainedCh + }() + klog.V(1).Info("[graceful-termination] waiting for shutdown to be initiated") <-stopCh diff --git a/staging/src/k8s.io/apiserver/pkg/server/genericapiserver_graceful_termination_test.go b/staging/src/k8s.io/apiserver/pkg/server/genericapiserver_graceful_termination_test.go index edab4983931..2ee5411db5e 100644 --- a/staging/src/k8s.io/apiserver/pkg/server/genericapiserver_graceful_termination_test.go +++ b/staging/src/k8s.io/apiserver/pkg/server/genericapiserver_graceful_termination_test.go @@ -38,6 +38,7 @@ import ( auditinternal "k8s.io/apiserver/pkg/apis/audit" "k8s.io/apiserver/pkg/audit" "k8s.io/apiserver/pkg/authorization/authorizer" + apirequest "k8s.io/apiserver/pkg/endpoints/request" "k8s.io/apiserver/pkg/server/dynamiccertificates" "k8s.io/klog/v2" @@ -147,7 +148,7 @@ func newSignalInterceptingTestStep() *signalInterceptingTestStep { // | close(stopHttpServerCh) NonLongRunningRequestWaitGroup.Wait() // | | | // | server.Shutdown(timeout=60s) | -// | | | +// | | WatchRequestWaitGroup.Wait() // | stop listener (net/http) | // | | | // | |-------------------------------------| | @@ -176,8 +177,10 @@ func TestGracefulTerminationWithKeepListeningDuringGracefulTerminationDisabled(t connReusingClient := newClient(false) doer := setupDoer(t, s.SecureServingInfo) - // handler for a request that we want to keep in flight through to the end - inflightRequest := setupInFlightReuestHandler(s) + // handler for a non long-running and a watch request that + // we want to keep in flight through to the end. + inflightNonLongRunning := setupInFlightNonLongRunningRequestHandler(s) + inflightWatch := setupInFlightWatchRequestHandler(s) // API calls from the pre-shutdown hook(s) must succeed up to // the point where the HTTP server is shut down. @@ -204,10 +207,13 @@ func TestGracefulTerminationWithKeepListeningDuringGracefulTerminationDisabled(t }() waitForAPIServerStarted(t, doer) - // fire a request now so it is in-flight on the server now, and - // we will unblock it after ShutdownDelayDuration elapses - inflightRequest.launch(doer, connReusingClient) - waitForeverUntil(t, inflightRequest.startedCh, "in-flight request did not reach the server") + // fire the non long-running and the watch request so it is + // in-flight on the server now, and we will unblock them + // after ShutdownDelayDuration elapses. + inflightNonLongRunning.launch(doer, connReusingClient) + waitForeverUntil(t, inflightNonLongRunning.startedCh, "in-flight non long-running request did not reach the server") + inflightWatch.launch(doer, connReusingClient) + waitForeverUntil(t, inflightWatch.startedCh, "in-flight watch request did not reach the server") // /readyz should return OK resultGot := doer.Do(newClient(true), func(httptrace.GotConnInfo) {}, "/readyz", time.Second) @@ -300,13 +306,21 @@ func TestGracefulTerminationWithKeepListeningDuringGracefulTerminationDisabled(t t.Errorf("Expected error %v, but got: %v %v", syscall.ECONNREFUSED, resultGot.err, resultGot.response) } - // the server has stopped listening but we still have a request - // in flight, let it unblock and we expect the request to succeed. - inFlightResultGot := inflightRequest.unblockAndWaitForResult(t) - if err := assertResponseStatusCode(inFlightResultGot, http.StatusOK); err != nil { + // the server has stopped listening but we still have a non long-running, + // and a watch request in flight, unblock both of these, and we expect + // the requests to return appropriate response to the caller. + inflightNonLongRunningResultGot := inflightNonLongRunning.unblockAndWaitForResult(t) + if err := assertResponseStatusCode(inflightNonLongRunningResultGot, http.StatusOK); err != nil { t.Errorf("%s", err.Error()) } - if err := assertRequestAudited(inFlightResultGot, fakeAudit); err != nil { + if err := assertRequestAudited(inflightNonLongRunningResultGot, fakeAudit); err != nil { + t.Errorf("%s", err.Error()) + } + inflightWatchResultGot := inflightWatch.unblockAndWaitForResult(t) + if err := assertResponseStatusCode(inflightWatchResultGot, http.StatusOK); err != nil { + t.Errorf("%s", err.Error()) + } + if err := assertRequestAudited(inflightWatchResultGot, fakeAudit); err != nil { t.Errorf("%s", err.Error()) } @@ -359,6 +373,8 @@ func TestGracefulTerminationWithKeepListeningDuringGracefulTerminationDisabled(t // | | // | NonLongRunningRequestWaitGroup.Wait() // | | +// | WatchRequestWaitGroup.Wait() +// | | // | (InFlightRequestsDrained) // | | // | | @@ -384,8 +400,10 @@ func TestGracefulTerminationWithKeepListeningDuringGracefulTerminationEnabled(t connReusingClient := newClient(false) doer := setupDoer(t, s.SecureServingInfo) - // handler for a request that we want to keep in flight through to the end - inflightRequest := setupInFlightReuestHandler(s) + // handler for a non long-running and a watch request that + // we want to keep in flight through to the end. + inflightNonLongRunning := setupInFlightNonLongRunningRequestHandler(s) + inflightWatch := setupInFlightWatchRequestHandler(s) // API calls from the pre-shutdown hook(s) must succeed up to // the point where the HTTP server is shut down. @@ -412,10 +430,13 @@ func TestGracefulTerminationWithKeepListeningDuringGracefulTerminationEnabled(t }() waitForAPIServerStarted(t, doer) - // fire a request now so it is in-flight on the server now, and - // we will unblock it after ShutdownDelayDuration elapses - inflightRequest.launch(doer, connReusingClient) - waitForeverUntil(t, inflightRequest.startedCh, "in-flight request did not reach the server") + // fire the non long-running and the watch request so it is + // in-flight on the server now, and we will unblock them + // after ShutdownDelayDuration elapses. + inflightNonLongRunning.launch(doer, connReusingClient) + waitForeverUntil(t, inflightNonLongRunning.startedCh, "in-flight request did not reach the server") + inflightWatch.launch(doer, connReusingClient) + waitForeverUntil(t, inflightWatch.startedCh, "in-flight watch request did not reach the server") // /readyz should return OK resultGot := doer.Do(newClient(true), func(httptrace.GotConnInfo) {}, "/readyz", time.Second) @@ -487,12 +508,21 @@ func TestGracefulTerminationWithKeepListeningDuringGracefulTerminationEnabled(t t.Errorf("%s", err.Error()) } - // we still have a request in flight, let it unblock and we expect the request to succeed. - inFlightResultGot := inflightRequest.unblockAndWaitForResult(t) - if err := assertResponseStatusCode(inFlightResultGot, http.StatusOK); err != nil { + // we still have a non long-running, and a watch request in flight, + // unblock both of these, and we expect the requests + // to return appropriate response to the caller. + inflightNonLongRunningResultGot := inflightNonLongRunning.unblockAndWaitForResult(t) + if err := assertResponseStatusCode(inflightNonLongRunningResultGot, http.StatusOK); err != nil { t.Errorf("%s", err.Error()) } - if err := assertRequestAudited(inFlightResultGot, fakeAudit); err != nil { + if err := assertRequestAudited(inflightNonLongRunningResultGot, fakeAudit); err != nil { + t.Errorf("%s", err.Error()) + } + inflightWatchResultGot := inflightWatch.unblockAndWaitForResult(t) + if err := assertResponseStatusCode(inflightWatchResultGot, http.StatusOK); err != nil { + t.Errorf("%s", err.Error()) + } + if err := assertRequestAudited(inflightWatchResultGot, fakeAudit); err != nil { t.Errorf("%s", err.Error()) } @@ -663,12 +693,12 @@ type inFlightRequest struct { url string } -func setupInFlightReuestHandler(s *GenericAPIServer) *inFlightRequest { +func setupInFlightNonLongRunningRequestHandler(s *GenericAPIServer) *inFlightRequest { inflight := &inFlightRequest{ blockedCh: make(chan struct{}), startedCh: make(chan struct{}), resultCh: make(chan result), - url: "/in-flight-request-as-designed", + url: "/in-flight-non-long-running-request-as-designed", } handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { close(inflight.startedCh) @@ -680,6 +710,37 @@ func setupInFlightReuestHandler(s *GenericAPIServer) *inFlightRequest { return inflight } +func setupInFlightWatchRequestHandler(s *GenericAPIServer) *inFlightRequest { + inflight := &inFlightRequest{ + blockedCh: make(chan struct{}), + startedCh: make(chan struct{}), + resultCh: make(chan result), + url: "/apis/watches.group/v1/namespaces/foo/bar?watch=true", + } + + handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + close(inflight.startedCh) + // this request handler blocks until we deliberately unblock it. + <-inflight.blockedCh + + // this simulates a watch well enough for our test + signals := apirequest.ServerShutdownSignalFrom(req.Context()) + if signals == nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + select { + case <-signals.ShuttingDown(): + w.WriteHeader(http.StatusOK) + return + } + + w.WriteHeader(http.StatusInternalServerError) + }) + s.Handler.NonGoRestfulMux.Handle("/apis/watches.group/v1/namespaces/foo/bar", handler) + return inflight +} + func (ifr *inFlightRequest) launch(doer doer, client *http.Client) { go func() { result := doer.Do(client, func(httptrace.GotConnInfo) {}, ifr.url, 0)