diff --git a/pkg/scheduler/framework/runtime/framework.go b/pkg/scheduler/framework/runtime/framework.go index d0c4c2150cf..ed5579453be 100644 --- a/pkg/scheduler/framework/runtime/framework.go +++ b/pkg/scheduler/framework/runtime/framework.go @@ -1263,25 +1263,26 @@ func (f *frameworkImpl) SharedInformerFactory() informers.SharedInformerFactory return f.informerFactory } -func (f *frameworkImpl) pluginsNeeded(plugins *config.Plugins) map[string]config.Plugin { - pgMap := make(map[string]config.Plugin) +func (f *frameworkImpl) pluginsNeeded(plugins *config.Plugins) sets.String { + pgSet := sets.String{} if plugins == nil { - return pgMap + return pgSet } find := func(pgs *config.PluginSet) { for _, pg := range pgs.Enabled { - pgMap[pg.Name] = pg + pgSet.Insert(pg.Name) } } + for _, e := range f.getExtensionPoints(plugins) { find(e.plugins) } - // Parse MultiPoint separately since they are not returned by f.getExtensionPoints() find(&plugins.MultiPoint) - return pgMap + + return pgSet } // ProfileName returns the profile name associated to this framework. diff --git a/pkg/scheduler/framework/runtime/framework_test.go b/pkg/scheduler/framework/runtime/framework_test.go index 3ec5e7121d2..dd3e3601ba2 100644 --- a/pkg/scheduler/framework/runtime/framework_test.go +++ b/pkg/scheduler/framework/runtime/framework_test.go @@ -338,6 +338,8 @@ var registry = func() Registry { r.Register(pluginNotImplementingScore, newPluginNotImplementingScore) r.Register(duplicatePluginName, newDuplicatePlugin) r.Register(testPlugin, newTestPlugin) + r.Register(queueSortPlugin, newQueueSortPlugin) + r.Register(bindPlugin, newBindPlugin) return r }() @@ -757,12 +759,19 @@ func TestNewFrameworkMultiPointExpansion(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { fw, err := NewFramework(registry, &config.KubeSchedulerProfile{Plugins: tc.plugins}) - if (err != nil && tc.wantErr == "") || (err == nil && tc.wantErr != "") || (err != nil && !strings.Contains(err.Error(), tc.wantErr)) { - t.Errorf("Unexpected error, got %v, expect: %s", err, tc.wantErr) + if err != nil { + if tc.wantErr == "" || !strings.Contains(err.Error(), tc.wantErr) { + t.Fatalf("Unexpected error, got %v, expect: %s", err, tc.wantErr) + } + } else { + if tc.wantErr != "" { + t.Fatalf("Unexpected error, got %v, expect: %s", err, tc.wantErr) + } } + if tc.wantErr == "" { if diff := cmp.Diff(tc.wantPlugins, fw.ListPlugins()); diff != "" { - t.Errorf("Unexpected eventToPlugin map (-want,+got):%s", diff) + t.Fatalf("Unexpected eventToPlugin map (-want,+got):%s", diff) } } }) @@ -1083,6 +1092,33 @@ func TestRunScorePlugins(t *testing.T) { }, err: true, }, + { + name: "single Score plugin with MultiPointExpansion", + plugins: &config.Plugins{ + MultiPoint: config.PluginSet{ + Enabled: []config.Plugin{ + {Name: scorePlugin1}, + }, + }, + Score: config.PluginSet{ + Enabled: []config.Plugin{ + {Name: scorePlugin1, Weight: 3}, + }, + }, + }, + pluginConfigs: []config.PluginConfig{ + { + Name: scorePlugin1, + Args: &runtime.Unknown{ + Raw: []byte(`{ "scoreRes": 1 }`), + }, + }, + }, + // scorePlugin1 Score returns 1, weight=3, so want=3. + want: framework.PluginToNodeScores{ + scorePlugin1: {{Name: "node1", Score: 3}, {Name: "node2", Score: 3}}, + }, + }, } for _, tt := range tests { @@ -1110,7 +1146,7 @@ func TestRunScorePlugins(t *testing.T) { t.Errorf("Expected status to be success.") } if !reflect.DeepEqual(res, tt.want) { - t.Errorf("Score map after RunScorePlugin: %+v, want: %+v.", res, tt.want) + t.Errorf("Score map after RunScorePlugin. got: %+v, want: %+v.", res, tt.want) } }) }