diff --git a/pkg/scheduler/framework/plugins/defaultpodtopologyspread/BUILD b/pkg/scheduler/framework/plugins/defaultpodtopologyspread/BUILD index 1b4742f7704..8369b967740 100644 --- a/pkg/scheduler/framework/plugins/defaultpodtopologyspread/BUILD +++ b/pkg/scheduler/framework/plugins/defaultpodtopologyspread/BUILD @@ -26,6 +26,7 @@ go_test( deps = [ "//pkg/scheduler/framework/v1alpha1:go_default_library", "//pkg/scheduler/internal/cache:go_default_library", + "//pkg/scheduler/internal/parallelize:go_default_library", "//pkg/scheduler/testing:go_default_library", "//staging/src/k8s.io/api/apps/v1:go_default_library", "//staging/src/k8s.io/api/core/v1:go_default_library", diff --git a/pkg/scheduler/framework/plugins/defaultpodtopologyspread/default_pod_topology_spread_perf_test.go b/pkg/scheduler/framework/plugins/defaultpodtopologyspread/default_pod_topology_spread_perf_test.go index 5e03d3bce9d..212e70a1f26 100644 --- a/pkg/scheduler/framework/plugins/defaultpodtopologyspread/default_pod_topology_spread_perf_test.go +++ b/pkg/scheduler/framework/plugins/defaultpodtopologyspread/default_pod_topology_spread_perf_test.go @@ -25,6 +25,7 @@ import ( "k8s.io/client-go/kubernetes/fake" framework "k8s.io/kubernetes/pkg/scheduler/framework/v1alpha1" "k8s.io/kubernetes/pkg/scheduler/internal/cache" + "k8s.io/kubernetes/pkg/scheduler/internal/parallelize" st "k8s.io/kubernetes/pkg/scheduler/testing" ) @@ -76,15 +77,14 @@ func BenchmarkTestSelectorSpreadPriority(b *testing.B) { if !status.IsSuccess() { b.Fatalf("unexpected error: %v", status) } - var gotList framework.NodeScoreList - for _, node := range filteredNodes { - score, status := plugin.Score(ctx, state, pod, node.Name) - if !status.IsSuccess() { - b.Errorf("unexpected error: %v", status) - } - gotList = append(gotList, framework.NodeScore{Name: node.Name, Score: score}) + gotList := make(framework.NodeScoreList, len(filteredNodes)) + scoreNode := func(i int) { + n := filteredNodes[i] + score, _ := plugin.Score(ctx, 
state, pod, n.Name) + gotList[i] = framework.NodeScore{Name: n.Name, Score: score} } - status = plugin.NormalizeScore(context.Background(), state, pod, gotList) + parallelize.Until(ctx, len(filteredNodes), scoreNode) + status = plugin.NormalizeScore(ctx, state, pod, gotList) if !status.IsSuccess() { b.Fatal(status) } diff --git a/pkg/scheduler/framework/plugins/podtopologyspread/BUILD b/pkg/scheduler/framework/plugins/podtopologyspread/BUILD index 78a644051bc..d24ad02b7dc 100644 --- a/pkg/scheduler/framework/plugins/podtopologyspread/BUILD +++ b/pkg/scheduler/framework/plugins/podtopologyspread/BUILD @@ -41,6 +41,7 @@ go_test( deps = [ "//pkg/scheduler/framework/v1alpha1:go_default_library", "//pkg/scheduler/internal/cache:go_default_library", + "//pkg/scheduler/internal/parallelize:go_default_library", "//pkg/scheduler/nodeinfo:go_default_library", "//pkg/scheduler/testing:go_default_library", "//staging/src/k8s.io/api/apps/v1:go_default_library", diff --git a/pkg/scheduler/framework/plugins/podtopologyspread/common.go b/pkg/scheduler/framework/plugins/podtopologyspread/common.go index b87af00c88e..fa53a9ac48b 100644 --- a/pkg/scheduler/framework/plugins/podtopologyspread/common.go +++ b/pkg/scheduler/framework/plugins/podtopologyspread/common.go @@ -82,3 +82,17 @@ func filterTopologySpreadConstraints(constraints []v1.TopologySpreadConstraint, } return result, nil } + +func countPodsMatchSelector(pods []*v1.Pod, selector labels.Selector, ns string) int { + count := 0 + for _, p := range pods { + // Bypass terminating Pod (see #87621). 
+ if p.DeletionTimestamp != nil || p.Namespace != ns { + continue + } + if selector.Matches(labels.Set(p.Labels)) { + count++ + } + } + return count +} diff --git a/pkg/scheduler/framework/plugins/podtopologyspread/scoring.go b/pkg/scheduler/framework/plugins/podtopologyspread/scoring.go index 8e347bb3dbc..5bd9e40b95e 100644 --- a/pkg/scheduler/framework/plugins/podtopologyspread/scoring.go +++ b/pkg/scheduler/framework/plugins/podtopologyspread/scoring.go @@ -23,7 +23,6 @@ import ( "sync/atomic" v1 "k8s.io/api/core/v1" - "k8s.io/apimachinery/pkg/labels" "k8s.io/apimachinery/pkg/util/sets" pluginhelper "k8s.io/kubernetes/pkg/scheduler/framework/plugins/helper" framework "k8s.io/kubernetes/pkg/scheduler/framework/v1alpha1" @@ -73,6 +72,10 @@ func (pl *PodTopologySpread) initPreScoreState(s *preScoreState, pod *v1.Pod, fi continue } for _, constraint := range s.Constraints { + // per-node counts are calculated during Score. + if constraint.TopologyKey == v1.LabelHostname { + continue + } pair := topologyPair{key: constraint.TopologyKey, value: node.Labels[constraint.TopologyKey]} if s.TopologyPairToPodCounts[pair] == nil { s.TopologyPairToPodCounts[pair] = new(int64) @@ -103,7 +106,7 @@ func (pl *PodTopologySpread) PreScore( } state := &preScoreState{ - NodeNameSet: sets.String{}, + NodeNameSet: make(sets.String, len(filteredNodes)), TopologyPairToPodCounts: make(map[topologyPair]*int64), } err = pl.initPreScoreState(state, pod, filteredNodes) @@ -134,22 +137,13 @@ func (pl *PodTopologySpread) PreScore( pair := topologyPair{key: c.TopologyKey, value: node.Labels[c.TopologyKey]} // If current topology pair is not associated with any candidate node, // continue to avoid unnecessary calculation. - if state.TopologyPairToPodCounts[pair] == nil { + // Per-node counts are also skipped, as they are done during Score. + tpCount := state.TopologyPairToPodCounts[pair] + if tpCount == nil { continue } - - // <matchSum> indicates how many pods (on current node) match the <constraint>. 
- matchSum := int64(0) - for _, existingPod := range nodeInfo.Pods() { - // Bypass terminating Pod (see #87621). - if existingPod.DeletionTimestamp != nil || existingPod.Namespace != pod.Namespace { - continue - } - if c.Selector.Matches(labels.Set(existingPod.Labels)) { - matchSum++ - } - } - atomic.AddInt64(state.TopologyPairToPodCounts[pair], matchSum) + count := countPodsMatchSelector(nodeInfo.Pods(), c.Selector, pod.Namespace) + atomic.AddInt64(tpCount, int64(count)) } } parallelize.Until(ctx, len(allNodes), processAllNode) @@ -183,9 +177,14 @@ func (pl *PodTopologySpread) Score(ctx context.Context, cycleState *framework.Cy var score int64 for _, c := range s.Constraints { if tpVal, ok := node.Labels[c.TopologyKey]; ok { - pair := topologyPair{key: c.TopologyKey, value: tpVal} - matchSum := *s.TopologyPairToPodCounts[pair] - score += matchSum + if c.TopologyKey == v1.LabelHostname { + count := countPodsMatchSelector(nodeInfo.Pods(), c.Selector, pod.Namespace) + score += int64(count) + } else { + pair := topologyPair{key: c.TopologyKey, value: tpVal} + matchSum := *s.TopologyPairToPodCounts[pair] + score += matchSum + } } } return score, nil diff --git a/pkg/scheduler/framework/plugins/podtopologyspread/scoring_test.go b/pkg/scheduler/framework/plugins/podtopologyspread/scoring_test.go index 60583b7f021..c34503f02b5 100644 --- a/pkg/scheduler/framework/plugins/podtopologyspread/scoring_test.go +++ b/pkg/scheduler/framework/plugins/podtopologyspread/scoring_test.go @@ -29,6 +29,7 @@ import ( "k8s.io/client-go/kubernetes/fake" framework "k8s.io/kubernetes/pkg/scheduler/framework/v1alpha1" "k8s.io/kubernetes/pkg/scheduler/internal/cache" + "k8s.io/kubernetes/pkg/scheduler/internal/parallelize" st "k8s.io/kubernetes/pkg/scheduler/testing" "k8s.io/utils/pointer" ) @@ -746,19 +747,18 @@ func BenchmarkTestDefaultEvenPodsSpreadPriority(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - var gotList framework.NodeScoreList status := p.PreScore(ctx, state, pod, 
filteredNodes) if !status.IsSuccess() { b.Fatalf("unexpected error: %v", status) } - for _, n := range filteredNodes { - score, status := p.Score(context.Background(), state, pod, n.Name) - if !status.IsSuccess() { - b.Fatalf("unexpected error: %v", status) - } - gotList = append(gotList, framework.NodeScore{Name: n.Name, Score: score}) + gotList := make(framework.NodeScoreList, len(filteredNodes)) + scoreNode := func(i int) { + n := filteredNodes[i] + score, _ := p.Score(ctx, state, pod, n.Name) + gotList[i] = framework.NodeScore{Name: n.Name, Score: score} } - status = p.NormalizeScore(context.Background(), state, pod, gotList) + parallelize.Until(ctx, len(filteredNodes), scoreNode) + status = p.NormalizeScore(ctx, state, pod, gotList) if !status.IsSuccess() { b.Fatal(status) }