diff --git a/pkg/kubelet/cm/dra/plugin/registration.go b/pkg/kubelet/cm/dra/plugin/registration.go index 1d95e582e0a..0f410c4290b 100644 --- a/pkg/kubelet/cm/dra/plugin/registration.go +++ b/pkg/kubelet/cm/dra/plugin/registration.go @@ -52,10 +52,12 @@ type RegistrationHandler struct { // This is necessary because it implements APIs which don't // provide a context. backgroundCtx context.Context + cancel func(err error) kubeClient kubernetes.Interface getNode func() (*v1.Node, error) wipingDelay time.Duration + wg sync.WaitGroup mutex sync.Mutex // pendingWipes maps a plugin name to a cancel function for @@ -76,9 +78,15 @@ var _ cache.PluginHandler = &RegistrationHandler{} // If a kubeClient is provided, then it synchronizes ResourceSlices // with the resource information provided by plugins. func NewRegistrationHandler(kubeClient kubernetes.Interface, getNode func() (*v1.Node, error), wipingDelay time.Duration) *RegistrationHandler { + // The context and thus logger should come from the caller. + return newRegistrationHandler(context.TODO(), kubeClient, getNode, wipingDelay) +} + +func newRegistrationHandler(ctx context.Context, kubeClient kubernetes.Interface, getNode func() (*v1.Node, error), wipingDelay time.Duration) *RegistrationHandler { + ctx, cancel := context.WithCancelCause(ctx) handler := &RegistrationHandler{ - // The context and thus logger should come from the caller. - backgroundCtx: klog.NewContext(context.TODO(), klog.LoggerWithName(klog.TODO(), "DRA registration handler")), + backgroundCtx: klog.NewContext(ctx, klog.LoggerWithName(klog.FromContext(ctx), "DRA registration handler")), + cancel: cancel, kubeClient: kubeClient, getNode: getNode, wipingDelay: wipingDelay, @@ -92,13 +100,24 @@ func NewRegistrationHandler(kubeClient kubernetes.Interface, getNode func() (*v1 // to start up. // // This has to run in the background. - logger := klog.LoggerWithName(klog.FromContext(handler.backgroundCtx), "startup") - ctx := klog.NewContext(handler.backgroundCtx, logger) - go handler.wipeResourceSlices(ctx, 0 /* no delay */, "" /* all drivers */) + handler.wg.Add(1) + go func() { + defer handler.wg.Done() + + logger := klog.LoggerWithName(klog.FromContext(handler.backgroundCtx), "startup") + ctx := klog.NewContext(handler.backgroundCtx, logger) + handler.wipeResourceSlices(ctx, 0 /* no delay */, "" /* all drivers */) + }() return handler } +// Stop cancels any remaining background activities and blocks until all goroutines have stopped. +func (h *RegistrationHandler) Stop() { + h.cancel(errors.New("Stop was called")) + h.wg.Wait() +} + // wipeResourceSlices deletes ResourceSlices of the node, optionally just for a specific driver. // Wiping will delay for a while and can be canceled by canceling the context. func (h *RegistrationHandler) wipeResourceSlices(ctx context.Context, delay time.Duration, driver string) { @@ -291,7 +310,9 @@ func (h *RegistrationHandler) DeRegisterPlugin(pluginName, endpoint string) { } h.pendingWipes[pluginName] = &cancel + h.wg.Add(1) go func() { + defer h.wg.Done() defer func() { h.mutex.Lock() defer h.mutex.Unlock() diff --git a/pkg/kubelet/cm/dra/plugin/registration_test.go b/pkg/kubelet/cm/dra/plugin/registration_test.go index 59f3099ab65..2013563b744 100644 --- a/pkg/kubelet/cm/dra/plugin/registration_test.go +++ b/pkg/kubelet/cm/dra/plugin/registration_test.go @@ -149,15 +149,26 @@ func TestRegistrationHandler(t *testing.T) { // Set expected slice fields for the next call of this reactor. // The reactor will be called next time when resourceslices object is deleted // by the kubelet after plugin deregistration. - expectedSliceFields = fields.Set{"spec.nodeName": nodeName, "spec.driver": test.pluginName} - + switch len(expectedSliceFields) { + case 1: + // Startup cleanup done, now expect cleanup for test plugin. + expectedSliceFields = fields.Set{"spec.nodeName": nodeName, "spec.driver": test.pluginName} + case 2: + // Test plugin cleanup done, now expect cleanup for the other plugin. + otherPlugin := pluginA + if otherPlugin == test.pluginName { + otherPlugin = pluginB + } + expectedSliceFields = fields.Set{"spec.nodeName": nodeName, "spec.driver": otherPlugin} + } return true, nil, err }) client = fakeClient } // The handler wipes all slices at startup. - handler := NewRegistrationHandler(client, getFakeNode, time.Second /* very short wiping delay for testing */) + handler := newRegistrationHandler(tCtx, client, getFakeNode, time.Second /* very short wiping delay for testing */) + tCtx.Cleanup(handler.Stop) requireNoSlices := func() { t.Helper() if client == nil {