From 9093f126b87cb686784bb27b08be9eb12b4d5453 Mon Sep 17 00:00:00 2001 From: Abu Kashem Date: Tue, 10 Jan 2023 15:55:19 -0500 Subject: [PATCH] apiserver: refactor WithWaitGroup handler --- .../src/k8s.io/apiserver/pkg/server/config.go | 72 ++++++++++--------- .../apiserver/pkg/server/config_test.go | 12 ++-- .../apiserver/pkg/server/filters/waitgroup.go | 36 +++++++--- .../apiserver/pkg/server/genericapiserver.go | 12 ++-- ...ericapiserver_graceful_termination_test.go | 4 +- .../pkg/server/genericapiserver_test.go | 4 +- 6 files changed, 81 insertions(+), 59 deletions(-) diff --git a/staging/src/k8s.io/apiserver/pkg/server/config.go b/staging/src/k8s.io/apiserver/pkg/server/config.go index 61a13ddd55c..8c5da2db524 100644 --- a/staging/src/k8s.io/apiserver/pkg/server/config.go +++ b/staging/src/k8s.io/apiserver/pkg/server/config.go @@ -156,8 +156,10 @@ type Config struct { // BuildHandlerChainFunc allows you to build custom handler chains by decorating the apiHandler. BuildHandlerChainFunc func(apiHandler http.Handler, c *Config) (secure http.Handler) - // HandlerChainWaitGroup allows you to wait for all chain handlers exit after the server shutdown. - HandlerChainWaitGroup *utilwaitgroup.SafeWaitGroup + // NonLongRunningRequestWaitGroup allows you to wait for all chain + // handlers associated with non long-running requests + // to complete while the server is shutting down. + NonLongRunningRequestWaitGroup *utilwaitgroup.SafeWaitGroup // DiscoveryAddresses is used to build the IPs pass to discovery.
If nil, the ExternalAddress is // always reported DiscoveryAddresses discovery.Addresses @@ -349,26 +351,26 @@ func NewConfig(codecs serializer.CodecFactory) *Config { lifecycleSignals := newLifecycleSignals() return &Config{ - Serializer: codecs, - BuildHandlerChainFunc: DefaultBuildHandlerChain, - HandlerChainWaitGroup: new(utilwaitgroup.SafeWaitGroup), - LegacyAPIGroupPrefixes: sets.NewString(DefaultLegacyAPIPrefix), - DisabledPostStartHooks: sets.NewString(), - PostStartHooks: map[string]PostStartHookConfigEntry{}, - HealthzChecks: append([]healthz.HealthChecker{}, defaultHealthChecks...), - ReadyzChecks: append([]healthz.HealthChecker{}, defaultHealthChecks...), - LivezChecks: append([]healthz.HealthChecker{}, defaultHealthChecks...), - EnableIndex: true, - EnableDiscovery: true, - EnableProfiling: true, - DebugSocketPath: "", - EnableMetrics: true, - MaxRequestsInFlight: 400, - MaxMutatingRequestsInFlight: 200, - RequestTimeout: time.Duration(60) * time.Second, - MinRequestTimeout: 1800, - LivezGracePeriod: time.Duration(0), - ShutdownDelayDuration: time.Duration(0), + Serializer: codecs, + BuildHandlerChainFunc: DefaultBuildHandlerChain, + NonLongRunningRequestWaitGroup: new(utilwaitgroup.SafeWaitGroup), + LegacyAPIGroupPrefixes: sets.NewString(DefaultLegacyAPIPrefix), + DisabledPostStartHooks: sets.NewString(), + PostStartHooks: map[string]PostStartHookConfigEntry{}, + HealthzChecks: append([]healthz.HealthChecker{}, defaultHealthChecks...), + ReadyzChecks: append([]healthz.HealthChecker{}, defaultHealthChecks...), + LivezChecks: append([]healthz.HealthChecker{}, defaultHealthChecks...), + EnableIndex: true, + EnableDiscovery: true, + EnableProfiling: true, + DebugSocketPath: "", + EnableMetrics: true, + MaxRequestsInFlight: 400, + MaxMutatingRequestsInFlight: 200, + RequestTimeout: time.Duration(60) * time.Second, + MinRequestTimeout: 1800, + LivezGracePeriod: time.Duration(0), + ShutdownDelayDuration: time.Duration(0), // 1.5MB is the default client 
request size in bytes // the etcd server should accept. See // https://github.com/etcd-io/etcd/blob/release-3.4/embed/config.go#L56. @@ -641,18 +643,18 @@ func (c completedConfig) New(name string, delegationTarget DelegationTarget) (*G apiServerHandler := NewAPIServerHandler(name, c.Serializer, handlerChainBuilder, delegationTarget.UnprotectedHandler()) s := &GenericAPIServer{ - discoveryAddresses: c.DiscoveryAddresses, - LoopbackClientConfig: c.LoopbackClientConfig, - legacyAPIGroupPrefixes: c.LegacyAPIGroupPrefixes, - admissionControl: c.AdmissionControl, - Serializer: c.Serializer, - AuditBackend: c.AuditBackend, - Authorizer: c.Authorization.Authorizer, - delegationTarget: delegationTarget, - EquivalentResourceRegistry: c.EquivalentResourceRegistry, - HandlerChainWaitGroup: c.HandlerChainWaitGroup, - Handler: apiServerHandler, - UnprotectedDebugSocket: debugSocket, + discoveryAddresses: c.DiscoveryAddresses, + LoopbackClientConfig: c.LoopbackClientConfig, + legacyAPIGroupPrefixes: c.LegacyAPIGroupPrefixes, + admissionControl: c.AdmissionControl, + Serializer: c.Serializer, + AuditBackend: c.AuditBackend, + Authorizer: c.Authorization.Authorizer, + delegationTarget: delegationTarget, + EquivalentResourceRegistry: c.EquivalentResourceRegistry, + NonLongRunningRequestWaitGroup: c.NonLongRunningRequestWaitGroup, + Handler: apiServerHandler, + UnprotectedDebugSocket: debugSocket, listedPathProvider: apiServerHandler, @@ -887,7 +889,7 @@ func DefaultBuildHandlerChain(apiHandler http.Handler, c *Config) http.Handler { handler = genericapifilters.WithRequestDeadline(handler, c.AuditBackend, c.AuditPolicyRuleEvaluator, c.LongRunningFunc, c.Serializer, c.RequestTimeout) - handler = genericfilters.WithWaitGroup(handler, c.LongRunningFunc, c.HandlerChainWaitGroup) + handler = genericfilters.WithWaitGroup(handler, c.LongRunningFunc, c.NonLongRunningRequestWaitGroup) if c.SecureServing != nil && !c.SecureServing.DisableHTTP2 && c.GoawayChance > 0 { handler = 
genericfilters.WithProbabilisticGoaway(handler, c.GoawayChance) } diff --git a/staging/src/k8s.io/apiserver/pkg/server/config_test.go b/staging/src/k8s.io/apiserver/pkg/server/config_test.go index 6da9f3bb683..fcbc15324d3 100644 --- a/staging/src/k8s.io/apiserver/pkg/server/config_test.go +++ b/staging/src/k8s.io/apiserver/pkg/server/config_test.go @@ -298,12 +298,12 @@ func TestAuthenticationAuditAnnotationsDefaultChain(t *testing.T) { AuditPolicyRuleEvaluator: policy.NewFakePolicyRuleEvaluator(auditinternal.LevelMetadata, nil), // avoid nil panics - HandlerChainWaitGroup: &waitgroup.SafeWaitGroup{}, - RequestInfoResolver: &request.RequestInfoFactory{}, - RequestTimeout: 10 * time.Second, - LongRunningFunc: func(_ *http.Request, _ *request.RequestInfo) bool { return false }, - lifecycleSignals: newLifecycleSignals(), - TracerProvider: tracing.NewNoopTracerProvider(), + NonLongRunningRequestWaitGroup: &waitgroup.SafeWaitGroup{}, + RequestInfoResolver: &request.RequestInfoFactory{}, + RequestTimeout: 10 * time.Second, + LongRunningFunc: func(_ *http.Request, _ *request.RequestInfo) bool { return false }, + lifecycleSignals: newLifecycleSignals(), + TracerProvider: tracing.NewNoopTracerProvider(), } h := DefaultBuildHandlerChain(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { diff --git a/staging/src/k8s.io/apiserver/pkg/server/filters/waitgroup.go b/staging/src/k8s.io/apiserver/pkg/server/filters/waitgroup.go index 70b32c76697..4cab1f86d8b 100644 --- a/staging/src/k8s.io/apiserver/pkg/server/filters/waitgroup.go +++ b/staging/src/k8s.io/apiserver/pkg/server/filters/waitgroup.go @@ -24,20 +24,34 @@ import ( "k8s.io/api/core/v1" apierrors "k8s.io/apimachinery/pkg/api/errors" "k8s.io/apimachinery/pkg/runtime" - utilwaitgroup "k8s.io/apimachinery/pkg/util/waitgroup" "k8s.io/apiserver/pkg/endpoints/handlers/responsewriters" apirequest "k8s.io/apiserver/pkg/endpoints/request" "k8s.io/client-go/kubernetes/scheme" ) +// RequestWaitGroup helps with the 
accounting of request(s) that are in +// flight: the caller is expected to invoke Add(1) before executing the +// request handler and then invoke Done() when the handler finishes. +// NOTE: implementations must ensure that it is thread-safe +// when invoked from multiple goroutines. +type RequestWaitGroup interface { + // Add adds delta, which may be negative, similar to sync.WaitGroup. + // If Add with a positive delta happens after Wait, it will return an error, + // which prevents unsafe Add. + Add(delta int) error + + // Done decrements the WaitGroup counter. + Done() +} + // WithWaitGroup adds all non long-running requests to wait group, which is used for graceful shutdown. -func WithWaitGroup(handler http.Handler, longRunning apirequest.LongRunningRequestCheck, wg *utilwaitgroup.SafeWaitGroup) http.Handler { +func WithWaitGroup(handler http.Handler, longRunning apirequest.LongRunningRequestCheck, wg RequestWaitGroup) http.Handler { // NOTE: both WithWaitGroup and WithRetryAfter must use the same exact isRequestExemptFunc 'isRequestExemptFromRetryAfter, // otherwise SafeWaitGroup might wait indefinitely and will prevent the server from shutting down gracefully.
return withWaitGroup(handler, longRunning, wg, isRequestExemptFromRetryAfter) } -func withWaitGroup(handler http.Handler, longRunning apirequest.LongRunningRequestCheck, wg *utilwaitgroup.SafeWaitGroup, isRequestExemptFn isRequestExemptFunc) http.Handler { +func withWaitGroup(handler http.Handler, longRunning apirequest.LongRunningRequestCheck, wg RequestWaitGroup, isRequestExemptFn isRequestExemptFunc) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { ctx := req.Context() requestInfo, ok := apirequest.RequestInfoFrom(ctx) @@ -64,12 +78,7 @@ func withWaitGroup(handler http.Handler, longRunning apirequest.LongRunningReque // When apiserver is shutting down, signal clients to retry // There is a good chance the client hit a different server, so a tight retry is good for client responsiveness. - w.Header().Add("Retry-After", "1") - w.Header().Set("Content-Type", runtime.ContentTypeJSON) - w.Header().Set("X-Content-Type-Options", "nosniff") - statusErr := apierrors.NewServiceUnavailable("apiserver is shutting down").Status() - w.WriteHeader(int(statusErr.Code)) - fmt.Fprintln(w, runtime.EncodeOrDie(scheme.Codecs.LegacyCodec(v1.SchemeGroupVersion), &statusErr)) + waitGroupWriteRetryAfterToResponse(w) return } @@ -77,3 +86,12 @@ func withWaitGroup(handler http.Handler, longRunning apirequest.LongRunningReque handler.ServeHTTP(w, req) }) } + +func waitGroupWriteRetryAfterToResponse(w http.ResponseWriter) { + w.Header().Add("Retry-After", "1") + w.Header().Set("Content-Type", runtime.ContentTypeJSON) + w.Header().Set("X-Content-Type-Options", "nosniff") + statusErr := apierrors.NewServiceUnavailable("apiserver is shutting down").Status() + w.WriteHeader(int(statusErr.Code)) + fmt.Fprintln(w, runtime.EncodeOrDie(scheme.Codecs.LegacyCodec(v1.SchemeGroupVersion), &statusErr)) +} diff --git a/staging/src/k8s.io/apiserver/pkg/server/genericapiserver.go b/staging/src/k8s.io/apiserver/pkg/server/genericapiserver.go index 
51a3e542a36..cc1afd0ba4b 100644 --- a/staging/src/k8s.io/apiserver/pkg/server/genericapiserver.go +++ b/staging/src/k8s.io/apiserver/pkg/server/genericapiserver.go @@ -217,8 +217,10 @@ type GenericAPIServer struct { // delegationTarget is the next delegate in the chain. This is never nil. delegationTarget DelegationTarget - // HandlerChainWaitGroup allows you to wait for all chain handlers finish after the server shutdown. - HandlerChainWaitGroup *utilwaitgroup.SafeWaitGroup + // NonLongRunningRequestWaitGroup allows you to wait for all chain + // handlers associated with non long-running requests + // to complete while the server is shutting down. + NonLongRunningRequestWaitGroup *utilwaitgroup.SafeWaitGroup // ShutdownDelayDuration allows to block shutdown for some time, e.g. until endpoints pointing to this API server // have converged on all node. During this time, the API server keeps serving, /healthz will return 200, @@ -452,7 +454,7 @@ func (s *GenericAPIServer) PrepareRun() preparedGenericAPIServer { // | | | | | // | | ---------------| | // | | | | -// | | (HandlerChainWaitGroup::Wait) | +// | | (NonLongRunningRequestWaitGroup::Wait) | // | | | | // | | InFlightRequestsDrained (drainedCh) | // | | | | @@ -582,7 +584,7 @@ func (s preparedGenericAPIServer) Run(stopCh <-chan struct{}) error { <-notAcceptingNewRequestCh.Signaled() // Wait for all requests to finish, which are bounded by the RequestTimeout variable. - // once HandlerChainWaitGroup.Wait is invoked, the apiserver is + // once NonLongRunningRequestWaitGroup.Wait is invoked, the apiserver is // expected to reject any incoming request with a {503, Retry-After} // response via the WithWaitGroup filter. On the contrary, we observe // that incoming request(s) get a 'connection refused' error, this is @@ -594,7 +596,7 @@ func (s preparedGenericAPIServer) Run(stopCh <-chan struct{}) error { // 'Server.Shutdown' will be invoked only after in-flight requests // have been drained.
// TODO: can we consolidate these two modes of graceful termination? - s.HandlerChainWaitGroup.Wait() + s.NonLongRunningRequestWaitGroup.Wait() }() klog.V(1).Info("[graceful-termination] waiting for shutdown to be initiated") diff --git a/staging/src/k8s.io/apiserver/pkg/server/genericapiserver_graceful_termination_test.go b/staging/src/k8s.io/apiserver/pkg/server/genericapiserver_graceful_termination_test.go index 373ce82d52d..ee8371ddd8c 100644 --- a/staging/src/k8s.io/apiserver/pkg/server/genericapiserver_graceful_termination_test.go +++ b/staging/src/k8s.io/apiserver/pkg/server/genericapiserver_graceful_termination_test.go @@ -144,7 +144,7 @@ func newSignalInterceptingTestStep() *signalInterceptingTestStep { // | | // | |-------------------------------------------------| // | | | -// | close(stopHttpServerCh) HandlerChainWaitGroup.Wait() +// | close(stopHttpServerCh) NonLongRunningRequestWaitGroup.Wait() // | | | // | server.Shutdown(timeout=60s) | // | | | @@ -357,7 +357,7 @@ func TestGracefulTerminationWithKeepListeningDuringGracefulTerminationDisabled(t // | | // | (NotAcceptingNewRequest) // | | -// | HandlerChainWaitGroup.Wait() +// | NonLongRunningRequestWaitGroup.Wait() // | | // | (InFlightRequestsDrained) // | | diff --git a/staging/src/k8s.io/apiserver/pkg/server/genericapiserver_test.go b/staging/src/k8s.io/apiserver/pkg/server/genericapiserver_test.go index ccffc8276ce..96c4e145620 100644 --- a/staging/src/k8s.io/apiserver/pkg/server/genericapiserver_test.go +++ b/staging/src/k8s.io/apiserver/pkg/server/genericapiserver_test.go @@ -602,7 +602,7 @@ func TestGracefulShutdown(t *testing.T) { wg.Add(1) config.BuildHandlerChainFunc = func(apiHandler http.Handler, c *Config) http.Handler { - handler := genericfilters.WithWaitGroup(apiHandler, c.LongRunningFunc, c.HandlerChainWaitGroup) + handler := genericfilters.WithWaitGroup(apiHandler, c.LongRunningFunc, c.NonLongRunningRequestWaitGroup) handler = genericapifilters.WithRequestInfo(handler, 
c.RequestInfoResolver) return handler } @@ -666,7 +666,7 @@ func TestGracefulShutdown(t *testing.T) { } // wait for wait group handler finish - s.HandlerChainWaitGroup.Wait() + s.NonLongRunningRequestWaitGroup.Wait() <-stoppedCh // check server all handlers finished.