diff --git a/pkg/scheduler/framework/interface.go b/pkg/scheduler/framework/interface.go index 65a115dc0d1..756362cce57 100644 --- a/pkg/scheduler/framework/interface.go +++ b/pkg/scheduler/framework/interface.go @@ -649,6 +649,9 @@ type Framework interface { // SetPodNominator sets the PodNominator SetPodNominator(nominator PodNominator) + + // Close calls Close method of each plugin. + Close() error } // Handle provides data and some tools that plugins can use. It is diff --git a/pkg/scheduler/framework/runtime/framework.go b/pkg/scheduler/framework/runtime/framework.go index 705d2b42472..aea68fe2e55 100644 --- a/pkg/scheduler/framework/runtime/framework.go +++ b/pkg/scheduler/framework/runtime/framework.go @@ -18,7 +18,9 @@ package runtime import ( "context" + "errors" "fmt" + "io" "reflect" "sort" "time" @@ -66,6 +68,9 @@ type frameworkImpl struct { postBindPlugins []framework.PostBindPlugin permitPlugins []framework.PermitPlugin + // pluginsMap contains all plugins, by name. + pluginsMap map[string]framework.Plugin + clientSet clientset.Interface kubeConfig *restclient.Config eventRecorder events.EventRecorder @@ -297,7 +302,7 @@ func NewFramework(ctx context.Context, r Registry, profile *config.KubeScheduler PluginConfig: make([]config.PluginConfig, 0, len(pg)), } - pluginsMap := make(map[string]framework.Plugin) + f.pluginsMap = make(map[string]framework.Plugin) for name, factory := range r { // initialize only needed plugins. if !pg.Has(name) { @@ -315,21 +320,21 @@ func NewFramework(ctx context.Context, r Registry, profile *config.KubeScheduler if err != nil { return nil, fmt.Errorf("initializing plugin %q: %w", name, err) } - pluginsMap[name] = p + f.pluginsMap[name] = p f.fillEnqueueExtensions(p) } // initialize plugins per individual extension points for _, e := range f.getExtensionPoints(profile.Plugins) { - if err := updatePluginList(e.slicePtr, *e.plugins, pluginsMap); err != nil { + if err := updatePluginList(e.slicePtr, *e.plugins, f.pluginsMap); err != nil { return nil, err } } // initialize multiPoint plugins to their expanded extension points if len(profile.Plugins.MultiPoint.Enabled) > 0 { - if err := f.expandMultiPointPlugins(logger, profile, pluginsMap); err != nil { + if err := f.expandMultiPointPlugins(logger, profile); err != nil { return nil, err } } @@ -341,7 +346,7 @@ func NewFramework(ctx context.Context, r Registry, profile *config.KubeScheduler return nil, fmt.Errorf("at least one bind plugin is needed for profile with scheduler name %q", profile.SchedulerName) } - if err := getScoreWeights(f, pluginsMap, append(profile.Plugins.Score.Enabled, profile.Plugins.MultiPoint.Enabled...)); err != nil { + if err := getScoreWeights(f, append(profile.Plugins.Score.Enabled, profile.Plugins.MultiPoint.Enabled...)); err != nil { return nil, err } @@ -405,14 +410,29 @@ func (f *frameworkImpl) SetPodNominator(n framework.PodNominator) { f.PodNominator = n } +// Close closes each plugin, when they implement io.Closer interface. +func (f *frameworkImpl) Close() error { + var errs []error + for name, plugin := range f.pluginsMap { + if closer, ok := plugin.(io.Closer); ok { + err := closer.Close() + if err != nil { + errs = append(errs, fmt.Errorf("%s failed to close: %w", name, err)) + // We try to close all plugins even if we got errors from some. + } + } + } + return errors.Join(errs...) +} + // getScoreWeights makes sure that, between MultiPoint-Score plugin weights and individual Score // plugin weights there is not an overflow of MaxTotalScore. -func getScoreWeights(f *frameworkImpl, pluginsMap map[string]framework.Plugin, plugins []config.Plugin) error { +func getScoreWeights(f *frameworkImpl, plugins []config.Plugin) error { var totalPriority int64 scorePlugins := reflect.ValueOf(&f.scorePlugins).Elem() pluginType := scorePlugins.Type().Elem() for _, e := range plugins { - pg := pluginsMap[e.Name] + pg := f.pluginsMap[e.Name] if !reflect.TypeOf(pg).Implements(pluginType) { continue } @@ -469,7 +489,7 @@ func (os *orderedSet) delete(s string) { } } -func (f *frameworkImpl) expandMultiPointPlugins(logger klog.Logger, profile *config.KubeSchedulerProfile, pluginsMap map[string]framework.Plugin) error { +func (f *frameworkImpl) expandMultiPointPlugins(logger klog.Logger, profile *config.KubeSchedulerProfile) error { // initialize MultiPoint plugins for _, e := range f.getExtensionPoints(profile.Plugins) { plugins := reflect.ValueOf(e.slicePtr).Elem() @@ -495,7 +515,7 @@ func (f *frameworkImpl) expandMultiPointPlugins(logger klog.Logger, profile *con multiPointEnabled := newOrderedSet() overridePlugins := newOrderedSet() for _, ep := range profile.Plugins.MultiPoint.Enabled { - pg, ok := pluginsMap[ep.Name] + pg, ok := f.pluginsMap[ep.Name] if !ok { return fmt.Errorf("%s %q does not exist", pluginType.Name(), ep.Name) } @@ -539,17 +559,17 @@ func (f *frameworkImpl) expandMultiPointPlugins(logger klog.Logger, profile *con // part 1 for _, name := range slice.CopyStrings(enabledSet.list) { if overridePlugins.has(name) { - newPlugins = reflect.Append(newPlugins, reflect.ValueOf(pluginsMap[name])) + newPlugins = reflect.Append(newPlugins, reflect.ValueOf(f.pluginsMap[name])) enabledSet.delete(name) } } // part 2 for _, name := range multiPointEnabled.list { - newPlugins = reflect.Append(newPlugins, reflect.ValueOf(pluginsMap[name])) + newPlugins = reflect.Append(newPlugins, reflect.ValueOf(f.pluginsMap[name])) } // part 3 for _, name := range enabledSet.list { - newPlugins = reflect.Append(newPlugins, reflect.ValueOf(pluginsMap[name])) + newPlugins = reflect.Append(newPlugins, reflect.ValueOf(f.pluginsMap[name])) } plugins.Set(newPlugins) } diff --git a/pkg/scheduler/framework/runtime/framework_test.go b/pkg/scheduler/framework/runtime/framework_test.go index 8a16de8b182..1a7ddc66606 100644 --- a/pkg/scheduler/framework/runtime/framework_test.go +++ b/pkg/scheduler/framework/runtime/framework_test.go @@ -54,6 +54,7 @@ const ( testPlugin = "test-plugin" permitPlugin = "permit-plugin" bindPlugin = "bind-plugin" + testCloseErrorPlugin = "test-close-error-plugin" testProfileName = "test-profile" testPercentageOfNodesToScore = 35 @@ -238,6 +239,25 @@ func (pl *TestPlugin) Bind(ctx context.Context, state *framework.CycleState, p * return framework.NewStatus(framework.Code(pl.inj.BindStatus), injectReason) } +func newTestCloseErrorPlugin(_ context.Context, injArgs runtime.Object, f framework.Handle) (framework.Plugin, error) { + return &TestCloseErrorPlugin{name: testCloseErrorPlugin}, nil +} + +// TestCloseErrorPlugin implements for Close test. +type TestCloseErrorPlugin struct { + name string +} + +func (pl *TestCloseErrorPlugin) Name() string { + return pl.name +} + +var errClose = errors.New("close err") + +func (pl *TestCloseErrorPlugin) Close() error { + return errClose +} + // TestPreFilterPlugin only implements PreFilterPlugin interface. type TestPreFilterPlugin struct { PreFilterCalled int @@ -379,6 +399,7 @@ var registry = func() Registry { r.Register(testPlugin, newTestPlugin) r.Register(queueSortPlugin, newQueueSortPlugin) r.Register(bindPlugin, newBindPlugin) + r.Register(testCloseErrorPlugin, newTestCloseErrorPlugin) return r }() @@ -3211,6 +3232,53 @@ func TestListPlugins(t *testing.T) { } } +func TestClose(t *testing.T) { + tests := []struct { + name string + plugins *config.Plugins + wantErr error + }{ + { + name: "close doesn't return error", + plugins: &config.Plugins{ + MultiPoint: config.PluginSet{ + Enabled: []config.Plugin{ + {Name: testPlugin, Weight: 5}, + }, + }, + }, + }, + { + name: "close returns error", + plugins: &config.Plugins{ + MultiPoint: config.PluginSet{ + Enabled: []config.Plugin{ + {Name: testPlugin, Weight: 5}, + {Name: testCloseErrorPlugin}, + }, + }, + }, + wantErr: errClose, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + _, ctx := ktesting.NewTestContext(t) + ctx, cancel := context.WithCancel(ctx) + defer cancel() + fw, err := NewFramework(ctx, registry, &config.KubeSchedulerProfile{Plugins: tc.plugins}) + if err != nil { + t.Fatalf("Unexpected error during calling NewFramework, got %v", err) + } + err = fw.Close() + if !errors.Is(err, tc.wantErr) { + t.Fatalf("Unexpected error from Close(), got: %v, want: %v", err, tc.wantErr) + } + }) + } +} + func buildScoreConfigDefaultWeights(ps ...string) *config.Plugins { return buildScoreConfigWithWeights(defaultWeights, ps...) } diff --git a/pkg/scheduler/profile/profile.go b/pkg/scheduler/profile/profile.go index 3846d9720df..e794411a119 100644 --- a/pkg/scheduler/profile/profile.go +++ b/pkg/scheduler/profile/profile.go @@ -70,6 +70,18 @@ func (m Map) HandlesSchedulerName(name string) bool { return ok } +// Close closes all frameworks registered in this map. +func (m Map) Close() error { + var errs []error + for name, f := range m { + err := f.Close() + if err != nil { + errs = append(errs, fmt.Errorf("framework %s failed to close: %w", name, err)) + } + } + return errors.Join(errs...) +} + // NewRecorderFactory returns a RecorderFactory for the broadcaster. func NewRecorderFactory(b events.EventBroadcaster) RecorderFactory { return func(name string) events.EventRecorder { diff --git a/pkg/scheduler/scheduler.go b/pkg/scheduler/scheduler.go index 84b73c44954..217b340b88a 100644 --- a/pkg/scheduler/scheduler.go +++ b/pkg/scheduler/scheduler.go @@ -415,6 +415,12 @@ func (sched *Scheduler) Run(ctx context.Context) { <-ctx.Done() sched.SchedulingQueue.Close() + + // If the plugins satisfy the io.Closer interface, they are closed. + err := sched.Profiles.Close() + if err != nil { + logger.Error(err, "Failed to close plugins") + } } // NewInformerFactory creates a SharedInformerFactory and initializes a scheduler specific