Merge pull request #116884 from mengjiao-liu/contextual-logging-scheduler-plugin-nodevolumelimits

Change the scheduler plugins FactoryAdapter function to use context parameter to pass logger
This commit is contained in:
Kubernetes Prow Robot 2023-09-20 11:26:00 -07:00 committed by GitHub
commit 89b4153d4d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
53 changed files with 287 additions and 218 deletions

View File

@ -450,7 +450,7 @@ func (*foo) Name() string {
return "Foo" return "Foo"
} }
func newFoo(_ runtime.Object, _ framework.Handle) (framework.Plugin, error) { func newFoo(_ context.Context, _ runtime.Object, _ framework.Handle) (framework.Plugin, error) {
return &foo{}, nil return &foo{}, nil
} }

View File

@ -38,7 +38,7 @@ type DefaultBinder struct {
var _ framework.BindPlugin = &DefaultBinder{} var _ framework.BindPlugin = &DefaultBinder{}
// New creates a DefaultBinder. // New creates a DefaultBinder.
func New(_ runtime.Object, handle framework.Handle) (framework.Plugin, error) { func New(_ context.Context, _ runtime.Object, handle framework.Handle) (framework.Plugin, error) {
return &DefaultBinder{handle: handle}, nil return &DefaultBinder{handle: handle}, nil
} }

View File

@ -63,7 +63,7 @@ func (pl *DefaultPreemption) Name() string {
} }
// New initializes a new plugin and returns it. // New initializes a new plugin and returns it.
func New(dpArgs runtime.Object, fh framework.Handle, fts feature.Features) (framework.Plugin, error) { func New(_ context.Context, dpArgs runtime.Object, fh framework.Handle, fts feature.Features) (framework.Plugin, error) {
args, ok := dpArgs.(*config.DefaultPreemptionArgs) args, ok := dpArgs.(*config.DefaultPreemptionArgs)
if !ok { if !ok {
return nil, fmt.Errorf("got args of type %T, want *DefaultPreemptionArgs", dpArgs) return nil, fmt.Errorf("got args of type %T, want *DefaultPreemptionArgs", dpArgs)

View File

@ -104,7 +104,7 @@ type TestPlugin struct {
name string name string
} }
func newTestPlugin(injArgs runtime.Object, f framework.Handle) (framework.Plugin, error) { func newTestPlugin(_ context.Context, injArgs runtime.Object, f framework.Handle) (framework.Plugin, error) {
return &TestPlugin{name: "test-plugin"}, nil return &TestPlugin{name: "test-plugin"}, nil
} }
@ -1066,7 +1066,7 @@ func TestDryRunPreemption(t *testing.T) {
registeredPlugins := append([]tf.RegisterPluginFunc{ registeredPlugins := append([]tf.RegisterPluginFunc{
tf.RegisterFilterPlugin( tf.RegisterFilterPlugin(
"FakeFilter", "FakeFilter",
func(_ runtime.Object, fh framework.Handle) (framework.Plugin, error) { func(_ context.Context, _ runtime.Object, fh framework.Handle) (framework.Plugin, error) {
return &fakePlugin, nil return &fakePlugin, nil
}, },
)}, )},

View File

@ -241,7 +241,7 @@ type dynamicResources struct {
} }
// New initializes a new plugin and returns it. // New initializes a new plugin and returns it.
func New(plArgs runtime.Object, fh framework.Handle, fts feature.Features) (framework.Plugin, error) { func New(_ context.Context, plArgs runtime.Object, fh framework.Handle, fts feature.Features) (framework.Plugin, error) {
if !fts.EnableDynamicResourceAllocation { if !fts.EnableDynamicResourceAllocation {
// Disabled, won't do anything. // Disabled, won't do anything.
return &dynamicResources{}, nil return &dynamicResources{}, nil

View File

@ -780,7 +780,7 @@ func setup(t *testing.T, nodes []*v1.Node, claims []*resourcev1alpha2.ResourceCl
t.Fatal(err) t.Fatal(err)
} }
pl, err := New(nil, fh, feature.Features{EnableDynamicResourceAllocation: true}) pl, err := New(ctx, nil, fh, feature.Features{EnableDynamicResourceAllocation: true})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@ -74,7 +74,7 @@ func (pl *ImageLocality) ScoreExtensions() framework.ScoreExtensions {
} }
// New initializes a new plugin and returns it. // New initializes a new plugin and returns it.
func New(_ runtime.Object, h framework.Handle) (framework.Plugin, error) { func New(_ context.Context, _ runtime.Object, h framework.Handle) (framework.Plugin, error) {
return &ImageLocality{handle: h}, nil return &ImageLocality{handle: h}, nil
} }

View File

@ -340,7 +340,10 @@ func TestImageLocalityPriority(t *testing.T) {
state := framework.NewCycleState() state := framework.NewCycleState()
fh, _ := runtime.NewFramework(ctx, nil, nil, runtime.WithSnapshotSharedLister(snapshot)) fh, _ := runtime.NewFramework(ctx, nil, nil, runtime.WithSnapshotSharedLister(snapshot))
p, _ := New(nil, fh) p, err := New(ctx, nil, fh)
if err != nil {
t.Fatalf("creating plugin: %v", err)
}
var gotList framework.NodeScoreList var gotList framework.NodeScoreList
for _, n := range test.nodes { for _, n := range test.nodes {
nodeName := n.ObjectMeta.Name nodeName := n.ObjectMeta.Name

View File

@ -17,6 +17,7 @@ limitations under the License.
package interpodaffinity package interpodaffinity
import ( import (
"context"
"fmt" "fmt"
"k8s.io/apimachinery/pkg/labels" "k8s.io/apimachinery/pkg/labels"
@ -69,7 +70,7 @@ func (pl *InterPodAffinity) EventsToRegister() []framework.ClusterEventWithHint
} }
// New initializes a new plugin and returns it. // New initializes a new plugin and returns it.
func New(plArgs runtime.Object, h framework.Handle) (framework.Plugin, error) { func New(_ context.Context, plArgs runtime.Object, h framework.Handle) (framework.Plugin, error) {
if h.SnapshotSharedLister() == nil { if h.SnapshotSharedLister() == nil {
return nil, fmt.Errorf("SnapshotSharedlister is nil") return nil, fmt.Errorf("SnapshotSharedlister is nil")
} }

View File

@ -243,7 +243,7 @@ func (pl *NodeAffinity) ScoreExtensions() framework.ScoreExtensions {
} }
// New initializes a new plugin and returns it. // New initializes a new plugin and returns it.
func New(plArgs runtime.Object, h framework.Handle) (framework.Plugin, error) { func New(_ context.Context, plArgs runtime.Object, h framework.Handle) (framework.Plugin, error) {
args, err := getArgs(plArgs) args, err := getArgs(plArgs)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -896,6 +896,7 @@ func TestNodeAffinity(t *testing.T) {
for _, test := range tests { for _, test := range tests {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
_, ctx := ktesting.NewTestContext(t)
node := v1.Node{ObjectMeta: metav1.ObjectMeta{ node := v1.Node{ObjectMeta: metav1.ObjectMeta{
Name: test.nodeName, Name: test.nodeName,
Labels: test.labels, Labels: test.labels,
@ -903,7 +904,7 @@ func TestNodeAffinity(t *testing.T) {
nodeInfo := framework.NewNodeInfo() nodeInfo := framework.NewNodeInfo()
nodeInfo.SetNode(&node) nodeInfo.SetNode(&node)
p, err := New(&test.args, nil) p, err := New(ctx, &test.args, nil)
if err != nil { if err != nil {
t.Fatalf("Creating plugin: %v", err) t.Fatalf("Creating plugin: %v", err)
} }
@ -1141,7 +1142,7 @@ func TestNodeAffinityPriority(t *testing.T) {
state := framework.NewCycleState() state := framework.NewCycleState()
fh, _ := runtime.NewFramework(ctx, nil, nil, runtime.WithSnapshotSharedLister(cache.NewSnapshot(nil, test.nodes))) fh, _ := runtime.NewFramework(ctx, nil, nil, runtime.WithSnapshotSharedLister(cache.NewSnapshot(nil, test.nodes)))
p, err := New(&test.args, fh) p, err := New(ctx, &test.args, fh)
if err != nil { if err != nil {
t.Fatalf("Creating plugin: %v", err) t.Fatalf("Creating plugin: %v", err)
} }

View File

@ -67,6 +67,6 @@ func Fits(pod *v1.Pod, nodeInfo *framework.NodeInfo) bool {
} }
// New initializes a new plugin and returns it. // New initializes a new plugin and returns it.
func New(_ runtime.Object, _ framework.Handle) (framework.Plugin, error) { func New(_ context.Context, _ runtime.Object, _ framework.Handle) (framework.Plugin, error) {
return &NodeName{}, nil return &NodeName{}, nil
} }

View File

@ -17,13 +17,13 @@ limitations under the License.
package nodename package nodename
import ( import (
"context"
"reflect" "reflect"
"testing" "testing"
v1 "k8s.io/api/core/v1" v1 "k8s.io/api/core/v1"
"k8s.io/kubernetes/pkg/scheduler/framework" "k8s.io/kubernetes/pkg/scheduler/framework"
st "k8s.io/kubernetes/pkg/scheduler/testing" st "k8s.io/kubernetes/pkg/scheduler/testing"
"k8s.io/kubernetes/test/utils/ktesting"
) )
func TestNodeName(t *testing.T) { func TestNodeName(t *testing.T) {
@ -55,9 +55,12 @@ func TestNodeName(t *testing.T) {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
nodeInfo := framework.NewNodeInfo() nodeInfo := framework.NewNodeInfo()
nodeInfo.SetNode(test.node) nodeInfo.SetNode(test.node)
_, ctx := ktesting.NewTestContext(t)
p, _ := New(nil, nil) p, err := New(ctx, nil, nil)
gotStatus := p.(framework.FilterPlugin).Filter(context.Background(), nil, test.pod, nodeInfo) if err != nil {
t.Fatalf("creating plugin: %v", err)
}
gotStatus := p.(framework.FilterPlugin).Filter(ctx, nil, test.pod, nodeInfo)
if !reflect.DeepEqual(gotStatus, test.wantStatus) { if !reflect.DeepEqual(gotStatus, test.wantStatus) {
t.Errorf("status does not match: %v, want: %v", gotStatus, test.wantStatus) t.Errorf("status does not match: %v, want: %v", gotStatus, test.wantStatus)
} }

View File

@ -149,6 +149,6 @@ func fitsPorts(wantPorts []*v1.ContainerPort, nodeInfo *framework.NodeInfo) bool
} }
// New initializes a new plugin and returns it. // New initializes a new plugin and returns it.
func New(_ runtime.Object, _ framework.Handle) (framework.Plugin, error) { func New(_ context.Context, _ runtime.Object, _ framework.Handle) (framework.Plugin, error) {
return &NodePorts{}, nil return &NodePorts{}, nil
} }

View File

@ -17,7 +17,6 @@ limitations under the License.
package nodeports package nodeports
import ( import (
"context"
"fmt" "fmt"
"reflect" "reflect"
"strconv" "strconv"
@ -26,6 +25,7 @@ import (
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
v1 "k8s.io/api/core/v1" v1 "k8s.io/api/core/v1"
"k8s.io/klog/v2/ktesting"
"k8s.io/kubernetes/pkg/scheduler/framework" "k8s.io/kubernetes/pkg/scheduler/framework"
st "k8s.io/kubernetes/pkg/scheduler/testing" st "k8s.io/kubernetes/pkg/scheduler/testing"
) )
@ -143,9 +143,13 @@ func TestNodePorts(t *testing.T) {
for _, test := range tests { for _, test := range tests {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
p, _ := New(nil, nil) _, ctx := ktesting.NewTestContext(t)
p, err := New(ctx, nil, nil)
if err != nil {
t.Fatalf("creating plugin: %v", err)
}
cycleState := framework.NewCycleState() cycleState := framework.NewCycleState()
_, preFilterStatus := p.(framework.PreFilterPlugin).PreFilter(context.Background(), cycleState, test.pod) _, preFilterStatus := p.(framework.PreFilterPlugin).PreFilter(ctx, cycleState, test.pod)
if diff := cmp.Diff(test.wantPreFilterStatus, preFilterStatus); diff != "" { if diff := cmp.Diff(test.wantPreFilterStatus, preFilterStatus); diff != "" {
t.Errorf("preFilter status does not match (-want,+got): %s", diff) t.Errorf("preFilter status does not match (-want,+got): %s", diff)
} }
@ -155,7 +159,7 @@ func TestNodePorts(t *testing.T) {
if !preFilterStatus.IsSuccess() { if !preFilterStatus.IsSuccess() {
t.Errorf("prefilter failed with status: %v", preFilterStatus) t.Errorf("prefilter failed with status: %v", preFilterStatus)
} }
gotStatus := p.(framework.FilterPlugin).Filter(context.Background(), cycleState, test.pod, test.nodeInfo) gotStatus := p.(framework.FilterPlugin).Filter(ctx, cycleState, test.pod, test.nodeInfo)
if diff := cmp.Diff(test.wantFilterStatus, gotStatus); diff != "" { if diff := cmp.Diff(test.wantFilterStatus, gotStatus); diff != "" {
t.Errorf("filter status does not match (-want, +got): %s", diff) t.Errorf("filter status does not match (-want, +got): %s", diff)
} }
@ -164,13 +168,17 @@ func TestNodePorts(t *testing.T) {
} }
func TestPreFilterDisabled(t *testing.T) { func TestPreFilterDisabled(t *testing.T) {
_, ctx := ktesting.NewTestContext(t)
pod := &v1.Pod{} pod := &v1.Pod{}
nodeInfo := framework.NewNodeInfo() nodeInfo := framework.NewNodeInfo()
node := v1.Node{} node := v1.Node{}
nodeInfo.SetNode(&node) nodeInfo.SetNode(&node)
p, _ := New(nil, nil) p, err := New(ctx, nil, nil)
if err != nil {
t.Fatalf("creating plugin: %v", err)
}
cycleState := framework.NewCycleState() cycleState := framework.NewCycleState()
gotStatus := p.(framework.FilterPlugin).Filter(context.Background(), cycleState, pod, nodeInfo) gotStatus := p.(framework.FilterPlugin).Filter(ctx, cycleState, pod, nodeInfo)
wantStatus := framework.AsStatus(fmt.Errorf(`reading "PreFilterNodePorts" from cycleState: %w`, framework.ErrNotFound)) wantStatus := framework.AsStatus(fmt.Errorf(`reading "PreFilterNodePorts" from cycleState: %w`, framework.ErrNotFound))
if !reflect.DeepEqual(gotStatus, wantStatus) { if !reflect.DeepEqual(gotStatus, wantStatus) {
t.Errorf("status does not match: %v, want: %v", gotStatus, wantStatus) t.Errorf("status does not match: %v, want: %v", gotStatus, wantStatus)

View File

@ -114,7 +114,7 @@ func (ba *BalancedAllocation) ScoreExtensions() framework.ScoreExtensions {
} }
// NewBalancedAllocation initializes a new plugin and returns it. // NewBalancedAllocation initializes a new plugin and returns it.
func NewBalancedAllocation(baArgs runtime.Object, h framework.Handle, fts feature.Features) (framework.Plugin, error) { func NewBalancedAllocation(_ context.Context, baArgs runtime.Object, h framework.Handle, fts feature.Features) (framework.Plugin, error) {
args, ok := baArgs.(*config.NodeResourcesBalancedAllocationArgs) args, ok := baArgs.(*config.NodeResourcesBalancedAllocationArgs)
if !ok { if !ok {
return nil, fmt.Errorf("want args to be of type NodeResourcesBalancedAllocationArgs, got %T", baArgs) return nil, fmt.Errorf("want args to be of type NodeResourcesBalancedAllocationArgs, got %T", baArgs)

View File

@ -389,7 +389,7 @@ func TestNodeResourcesBalancedAllocation(t *testing.T) {
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
defer cancel() defer cancel()
fh, _ := runtime.NewFramework(ctx, nil, nil, runtime.WithSnapshotSharedLister(snapshot)) fh, _ := runtime.NewFramework(ctx, nil, nil, runtime.WithSnapshotSharedLister(snapshot))
p, _ := NewBalancedAllocation(&test.args, fh, feature.Features{}) p, _ := NewBalancedAllocation(ctx, &test.args, fh, feature.Features{})
state := framework.NewCycleState() state := framework.NewCycleState()
for i := range test.nodes { for i := range test.nodes {
if test.runPreScore { if test.runPreScore {

View File

@ -145,7 +145,7 @@ func (f *Fit) Name() string {
} }
// NewFit initializes a new plugin and returns it. // NewFit initializes a new plugin and returns it.
func NewFit(plArgs runtime.Object, h framework.Handle, fts feature.Features) (framework.Plugin, error) { func NewFit(_ context.Context, plArgs runtime.Object, h framework.Handle, fts feature.Features) (framework.Plugin, error) {
args, ok := plArgs.(*config.NodeResourcesFitArgs) args, ok := plArgs.(*config.NodeResourcesFitArgs)
if !ok { if !ok {
return nil, fmt.Errorf("want args to be of type NodeResourcesFitArgs, got %T", plArgs) return nil, fmt.Errorf("want args to be of type NodeResourcesFitArgs, got %T", plArgs)

View File

@ -496,17 +496,20 @@ func TestEnoughRequests(t *testing.T) {
test.args.ScoringStrategy = defaultScoringStrategy test.args.ScoringStrategy = defaultScoringStrategy
} }
p, err := NewFit(&test.args, nil, plfeature.Features{}) _, ctx := ktesting.NewTestContext(t)
ctx, cancel := context.WithCancel(ctx)
defer cancel()
p, err := NewFit(ctx, &test.args, nil, plfeature.Features{})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
cycleState := framework.NewCycleState() cycleState := framework.NewCycleState()
_, preFilterStatus := p.(framework.PreFilterPlugin).PreFilter(context.Background(), cycleState, test.pod) _, preFilterStatus := p.(framework.PreFilterPlugin).PreFilter(ctx, cycleState, test.pod)
if !preFilterStatus.IsSuccess() { if !preFilterStatus.IsSuccess() {
t.Errorf("prefilter failed with status: %v", preFilterStatus) t.Errorf("prefilter failed with status: %v", preFilterStatus)
} }
gotStatus := p.(framework.FilterPlugin).Filter(context.Background(), cycleState, test.pod, test.nodeInfo) gotStatus := p.(framework.FilterPlugin).Filter(ctx, cycleState, test.pod, test.nodeInfo)
if !reflect.DeepEqual(gotStatus, test.wantStatus) { if !reflect.DeepEqual(gotStatus, test.wantStatus) {
t.Errorf("status does not match: %v, want: %v", gotStatus, test.wantStatus) t.Errorf("status does not match: %v, want: %v", gotStatus, test.wantStatus)
} }
@ -520,16 +523,19 @@ func TestEnoughRequests(t *testing.T) {
} }
func TestPreFilterDisabled(t *testing.T) { func TestPreFilterDisabled(t *testing.T) {
_, ctx := ktesting.NewTestContext(t)
ctx, cancel := context.WithCancel(ctx)
defer cancel()
pod := &v1.Pod{} pod := &v1.Pod{}
nodeInfo := framework.NewNodeInfo() nodeInfo := framework.NewNodeInfo()
node := v1.Node{} node := v1.Node{}
nodeInfo.SetNode(&node) nodeInfo.SetNode(&node)
p, err := NewFit(&config.NodeResourcesFitArgs{ScoringStrategy: defaultScoringStrategy}, nil, plfeature.Features{}) p, err := NewFit(ctx, &config.NodeResourcesFitArgs{ScoringStrategy: defaultScoringStrategy}, nil, plfeature.Features{})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
cycleState := framework.NewCycleState() cycleState := framework.NewCycleState()
gotStatus := p.(framework.FilterPlugin).Filter(context.Background(), cycleState, pod, nodeInfo) gotStatus := p.(framework.FilterPlugin).Filter(ctx, cycleState, pod, nodeInfo)
wantStatus := framework.AsStatus(fmt.Errorf(`error reading "PreFilterNodeResourcesFit" from cycleState: %w`, framework.ErrNotFound)) wantStatus := framework.AsStatus(fmt.Errorf(`error reading "PreFilterNodeResourcesFit" from cycleState: %w`, framework.ErrNotFound))
if !reflect.DeepEqual(gotStatus, wantStatus) { if !reflect.DeepEqual(gotStatus, wantStatus) {
t.Errorf("status does not match: %v, want: %v", gotStatus, wantStatus) t.Errorf("status does not match: %v, want: %v", gotStatus, wantStatus)
@ -571,20 +577,23 @@ func TestNotEnoughRequests(t *testing.T) {
} }
for _, test := range notEnoughPodsTests { for _, test := range notEnoughPodsTests {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
_, ctx := ktesting.NewTestContext(t)
ctx, cancel := context.WithCancel(ctx)
defer cancel()
node := v1.Node{Status: v1.NodeStatus{Capacity: v1.ResourceList{}, Allocatable: makeAllocatableResources(10, 20, 1, 0, 0, 0)}} node := v1.Node{Status: v1.NodeStatus{Capacity: v1.ResourceList{}, Allocatable: makeAllocatableResources(10, 20, 1, 0, 0, 0)}}
test.nodeInfo.SetNode(&node) test.nodeInfo.SetNode(&node)
p, err := NewFit(&config.NodeResourcesFitArgs{ScoringStrategy: defaultScoringStrategy}, nil, plfeature.Features{}) p, err := NewFit(ctx, &config.NodeResourcesFitArgs{ScoringStrategy: defaultScoringStrategy}, nil, plfeature.Features{})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
cycleState := framework.NewCycleState() cycleState := framework.NewCycleState()
_, preFilterStatus := p.(framework.PreFilterPlugin).PreFilter(context.Background(), cycleState, test.pod) _, preFilterStatus := p.(framework.PreFilterPlugin).PreFilter(ctx, cycleState, test.pod)
if !preFilterStatus.IsSuccess() { if !preFilterStatus.IsSuccess() {
t.Errorf("prefilter failed with status: %v", preFilterStatus) t.Errorf("prefilter failed with status: %v", preFilterStatus)
} }
gotStatus := p.(framework.FilterPlugin).Filter(context.Background(), cycleState, test.pod, test.nodeInfo) gotStatus := p.(framework.FilterPlugin).Filter(ctx, cycleState, test.pod, test.nodeInfo)
if !reflect.DeepEqual(gotStatus, test.wantStatus) { if !reflect.DeepEqual(gotStatus, test.wantStatus) {
t.Errorf("status does not match: %v, want: %v", gotStatus, test.wantStatus) t.Errorf("status does not match: %v, want: %v", gotStatus, test.wantStatus)
} }
@ -629,20 +638,23 @@ func TestStorageRequests(t *testing.T) {
for _, test := range storagePodsTests { for _, test := range storagePodsTests {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
_, ctx := ktesting.NewTestContext(t)
ctx, cancel := context.WithCancel(ctx)
defer cancel()
node := v1.Node{Status: v1.NodeStatus{Capacity: makeResources(10, 20, 32, 5, 20, 5).Capacity, Allocatable: makeAllocatableResources(10, 20, 32, 5, 20, 5)}} node := v1.Node{Status: v1.NodeStatus{Capacity: makeResources(10, 20, 32, 5, 20, 5).Capacity, Allocatable: makeAllocatableResources(10, 20, 32, 5, 20, 5)}}
test.nodeInfo.SetNode(&node) test.nodeInfo.SetNode(&node)
p, err := NewFit(&config.NodeResourcesFitArgs{ScoringStrategy: defaultScoringStrategy}, nil, plfeature.Features{}) p, err := NewFit(ctx, &config.NodeResourcesFitArgs{ScoringStrategy: defaultScoringStrategy}, nil, plfeature.Features{})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
cycleState := framework.NewCycleState() cycleState := framework.NewCycleState()
_, preFilterStatus := p.(framework.PreFilterPlugin).PreFilter(context.Background(), cycleState, test.pod) _, preFilterStatus := p.(framework.PreFilterPlugin).PreFilter(ctx, cycleState, test.pod)
if !preFilterStatus.IsSuccess() { if !preFilterStatus.IsSuccess() {
t.Errorf("prefilter failed with status: %v", preFilterStatus) t.Errorf("prefilter failed with status: %v", preFilterStatus)
} }
gotStatus := p.(framework.FilterPlugin).Filter(context.Background(), cycleState, test.pod, test.nodeInfo) gotStatus := p.(framework.FilterPlugin).Filter(ctx, cycleState, test.pod, test.nodeInfo)
if !reflect.DeepEqual(gotStatus, test.wantStatus) { if !reflect.DeepEqual(gotStatus, test.wantStatus) {
t.Errorf("status does not match: %v, want: %v", gotStatus, test.wantStatus) t.Errorf("status does not match: %v, want: %v", gotStatus, test.wantStatus)
} }
@ -707,11 +719,14 @@ func TestRestartableInitContainers(t *testing.T) {
for _, test := range testCases { for _, test := range testCases {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
_, ctx := ktesting.NewTestContext(t)
ctx, cancel := context.WithCancel(ctx)
defer cancel()
node := v1.Node{Status: v1.NodeStatus{Capacity: v1.ResourceList{}, Allocatable: makeAllocatableResources(0, 0, 1, 0, 0, 0)}} node := v1.Node{Status: v1.NodeStatus{Capacity: v1.ResourceList{}, Allocatable: makeAllocatableResources(0, 0, 1, 0, 0, 0)}}
nodeInfo := framework.NewNodeInfo() nodeInfo := framework.NewNodeInfo()
nodeInfo.SetNode(&node) nodeInfo.SetNode(&node)
p, err := NewFit(&config.NodeResourcesFitArgs{ScoringStrategy: defaultScoringStrategy}, nil, plfeature.Features{EnableSidecarContainers: test.enableSidecarContainers}) p, err := NewFit(ctx, &config.NodeResourcesFitArgs{ScoringStrategy: defaultScoringStrategy}, nil, plfeature.Features{EnableSidecarContainers: test.enableSidecarContainers})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -924,7 +939,7 @@ func TestFitScore(t *testing.T) {
snapshot := cache.NewSnapshot(test.existingPods, test.nodes) snapshot := cache.NewSnapshot(test.existingPods, test.nodes)
fh, _ := runtime.NewFramework(ctx, nil, nil, runtime.WithSnapshotSharedLister(snapshot)) fh, _ := runtime.NewFramework(ctx, nil, nil, runtime.WithSnapshotSharedLister(snapshot))
args := test.nodeResourcesFitArgs args := test.nodeResourcesFitArgs
p, err := NewFit(&args, fh, plfeature.Features{}) p, err := NewFit(ctx, &args, fh, plfeature.Features{})
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }

View File

@ -395,6 +395,7 @@ func TestLeastAllocatedScoringStrategy(t *testing.T) {
fh, _ := runtime.NewFramework(ctx, nil, nil, runtime.WithSnapshotSharedLister(snapshot)) fh, _ := runtime.NewFramework(ctx, nil, nil, runtime.WithSnapshotSharedLister(snapshot))
p, err := NewFit( p, err := NewFit(
ctx,
&config.NodeResourcesFitArgs{ &config.NodeResourcesFitArgs{
ScoringStrategy: &config.ScoringStrategy{ ScoringStrategy: &config.ScoringStrategy{
Type: config.LeastAllocated, Type: config.LeastAllocated,

View File

@ -351,7 +351,7 @@ func TestMostAllocatedScoringStrategy(t *testing.T) {
snapshot := cache.NewSnapshot(test.existingPods, test.nodes) snapshot := cache.NewSnapshot(test.existingPods, test.nodes)
fh, _ := runtime.NewFramework(ctx, nil, nil, runtime.WithSnapshotSharedLister(snapshot)) fh, _ := runtime.NewFramework(ctx, nil, nil, runtime.WithSnapshotSharedLister(snapshot))
p, err := NewFit( p, err := NewFit(ctx,
&config.NodeResourcesFitArgs{ &config.NodeResourcesFitArgs{
ScoringStrategy: &config.ScoringStrategy{ ScoringStrategy: &config.ScoringStrategy{
Type: config.MostAllocated, Type: config.MostAllocated,

View File

@ -111,7 +111,7 @@ func TestRequestedToCapacityRatioScoringStrategy(t *testing.T) {
snapshot := cache.NewSnapshot(test.existingPods, test.nodes) snapshot := cache.NewSnapshot(test.existingPods, test.nodes)
fh, _ := runtime.NewFramework(ctx, nil, nil, runtime.WithSnapshotSharedLister(snapshot)) fh, _ := runtime.NewFramework(ctx, nil, nil, runtime.WithSnapshotSharedLister(snapshot))
p, err := NewFit(&config.NodeResourcesFitArgs{ p, err := NewFit(ctx, &config.NodeResourcesFitArgs{
ScoringStrategy: &config.ScoringStrategy{ ScoringStrategy: &config.ScoringStrategy{
Type: config.RequestedToCapacityRatio, Type: config.RequestedToCapacityRatio,
Resources: test.resources, Resources: test.resources,
@ -320,7 +320,7 @@ func TestResourceBinPackingSingleExtended(t *testing.T) {
}, },
}, },
} }
p, err := NewFit(&args, fh, plfeature.Features{}) p, err := NewFit(ctx, &args, fh, plfeature.Features{})
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
@ -548,7 +548,7 @@ func TestResourceBinPackingMultipleExtended(t *testing.T) {
}, },
} }
p, err := NewFit(&args, fh, plfeature.Features{}) p, err := NewFit(ctx, &args, fh, plfeature.Features{})
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }

View File

@ -104,6 +104,6 @@ func (pl *NodeUnschedulable) Filter(ctx context.Context, _ *framework.CycleState
} }
// New initializes a new plugin and returns it. // New initializes a new plugin and returns it.
func New(_ runtime.Object, _ framework.Handle) (framework.Plugin, error) { func New(_ context.Context, _ runtime.Object, _ framework.Handle) (framework.Plugin, error) {
return &NodeUnschedulable{}, nil return &NodeUnschedulable{}, nil
} }

View File

@ -17,13 +17,12 @@ limitations under the License.
package nodeunschedulable package nodeunschedulable
import ( import (
"context"
"reflect" "reflect"
"testing" "testing"
v1 "k8s.io/api/core/v1" v1 "k8s.io/api/core/v1"
"k8s.io/klog/v2/ktesting"
"k8s.io/kubernetes/pkg/scheduler/framework" "k8s.io/kubernetes/pkg/scheduler/framework"
"k8s.io/kubernetes/test/utils/ktesting"
) )
func TestNodeUnschedulable(t *testing.T) { func TestNodeUnschedulable(t *testing.T) {
@ -75,9 +74,12 @@ func TestNodeUnschedulable(t *testing.T) {
for _, test := range testCases { for _, test := range testCases {
nodeInfo := framework.NewNodeInfo() nodeInfo := framework.NewNodeInfo()
nodeInfo.SetNode(test.node) nodeInfo.SetNode(test.node)
_, ctx := ktesting.NewTestContext(t)
p, _ := New(nil, nil) p, err := New(ctx, nil, nil)
gotStatus := p.(framework.FilterPlugin).Filter(context.Background(), nil, test.pod, nodeInfo) if err != nil {
t.Fatalf("creating plugin: %v", err)
}
gotStatus := p.(framework.FilterPlugin).Filter(ctx, nil, test.pod, nodeInfo)
if !reflect.DeepEqual(gotStatus, test.wantStatus) { if !reflect.DeepEqual(gotStatus, test.wantStatus) {
t.Errorf("status does not match: %v, want: %v", gotStatus, test.wantStatus) t.Errorf("status does not match: %v, want: %v", gotStatus, test.wantStatus)
} }

View File

@ -110,15 +110,17 @@ func (pl *CSILimits) Filter(ctx context.Context, _ *framework.CycleState, pod *v
node := nodeInfo.Node() node := nodeInfo.Node()
logger := klog.FromContext(ctx)
// If CSINode doesn't exist, the predicate may read the limits from Node object // If CSINode doesn't exist, the predicate may read the limits from Node object
csiNode, err := pl.csiNodeLister.Get(node.Name) csiNode, err := pl.csiNodeLister.Get(node.Name)
if err != nil { if err != nil {
// TODO: return the error once CSINode is created by default (2 releases) // TODO: return the error once CSINode is created by default (2 releases)
klog.V(5).InfoS("Could not get a CSINode object for the node", "node", klog.KObj(node), "err", err) logger.V(5).Info("Could not get a CSINode object for the node", "node", klog.KObj(node), "err", err)
} }
newVolumes := make(map[string]string) newVolumes := make(map[string]string)
if err := pl.filterAttachableVolumes(pod, csiNode, true /* new pod */, newVolumes); err != nil { if err := pl.filterAttachableVolumes(logger, pod, csiNode, true /* new pod */, newVolumes); err != nil {
return framework.AsStatus(err) return framework.AsStatus(err)
} }
@ -135,7 +137,7 @@ func (pl *CSILimits) Filter(ctx context.Context, _ *framework.CycleState, pod *v
attachedVolumes := make(map[string]string) attachedVolumes := make(map[string]string)
for _, existingPod := range nodeInfo.Pods { for _, existingPod := range nodeInfo.Pods {
if err := pl.filterAttachableVolumes(existingPod.Pod, csiNode, false /* existing pod */, attachedVolumes); err != nil { if err := pl.filterAttachableVolumes(logger, existingPod.Pod, csiNode, false /* existing pod */, attachedVolumes); err != nil {
return framework.AsStatus(err) return framework.AsStatus(err)
} }
} }
@ -156,7 +158,7 @@ func (pl *CSILimits) Filter(ctx context.Context, _ *framework.CycleState, pod *v
maxVolumeLimit, ok := nodeVolumeLimits[v1.ResourceName(volumeLimitKey)] maxVolumeLimit, ok := nodeVolumeLimits[v1.ResourceName(volumeLimitKey)]
if ok { if ok {
currentVolumeCount := attachedVolumeCount[volumeLimitKey] currentVolumeCount := attachedVolumeCount[volumeLimitKey]
klog.V(5).InfoS("Found plugin volume limits", "node", node.Name, "volumeLimitKey", volumeLimitKey, logger.V(5).Info("Found plugin volume limits", "node", node.Name, "volumeLimitKey", volumeLimitKey,
"maxLimits", maxVolumeLimit, "currentVolumeCount", currentVolumeCount, "newVolumeCount", count, "maxLimits", maxVolumeLimit, "currentVolumeCount", currentVolumeCount, "newVolumeCount", count,
"pod", klog.KObj(pod)) "pod", klog.KObj(pod))
if currentVolumeCount+count > int(maxVolumeLimit) { if currentVolumeCount+count > int(maxVolumeLimit) {
@ -169,7 +171,7 @@ func (pl *CSILimits) Filter(ctx context.Context, _ *framework.CycleState, pod *v
} }
func (pl *CSILimits) filterAttachableVolumes( func (pl *CSILimits) filterAttachableVolumes(
pod *v1.Pod, csiNode *storagev1.CSINode, newPod bool, result map[string]string) error { logger klog.Logger, pod *v1.Pod, csiNode *storagev1.CSINode, newPod bool, result map[string]string) error {
for _, vol := range pod.Spec.Volumes { for _, vol := range pod.Spec.Volumes {
pvcName := "" pvcName := ""
isEphemeral := false isEphemeral := false
@ -190,7 +192,7 @@ func (pl *CSILimits) filterAttachableVolumes(
// - If the volume is migratable and CSI migration is enabled, need to count it // - If the volume is migratable and CSI migration is enabled, need to count it
// as well. // as well.
// - If the volume is not migratable, it will be count in non_csi filter. // - If the volume is not migratable, it will be count in non_csi filter.
if err := pl.checkAttachableInlineVolume(&vol, csiNode, pod, result); err != nil { if err := pl.checkAttachableInlineVolume(logger, &vol, csiNode, pod, result); err != nil {
return err return err
} }
@ -212,7 +214,7 @@ func (pl *CSILimits) filterAttachableVolumes(
} }
// If the PVC is invalid, we don't count the volume because // If the PVC is invalid, we don't count the volume because
// there's no guarantee that it belongs to the running predicate. // there's no guarantee that it belongs to the running predicate.
klog.V(5).InfoS("Unable to look up PVC info", "pod", klog.KObj(pod), "PVC", klog.KRef(pod.Namespace, pvcName)) logger.V(5).Info("Unable to look up PVC info", "pod", klog.KObj(pod), "PVC", klog.KRef(pod.Namespace, pvcName))
continue continue
} }
@ -223,9 +225,9 @@ func (pl *CSILimits) filterAttachableVolumes(
} }
} }
driverName, volumeHandle := pl.getCSIDriverInfo(csiNode, pvc) driverName, volumeHandle := pl.getCSIDriverInfo(logger, csiNode, pvc)
if driverName == "" || volumeHandle == "" { if driverName == "" || volumeHandle == "" {
klog.V(5).InfoS("Could not find a CSI driver name or volume handle, not counting volume") logger.V(5).Info("Could not find a CSI driver name or volume handle, not counting volume")
continue continue
} }
@ -238,7 +240,7 @@ func (pl *CSILimits) filterAttachableVolumes(
// checkAttachableInlineVolume takes an inline volume and add to the result map if the // checkAttachableInlineVolume takes an inline volume and add to the result map if the
// volume is migratable and CSI migration for this plugin has been enabled. // volume is migratable and CSI migration for this plugin has been enabled.
func (pl *CSILimits) checkAttachableInlineVolume(vol *v1.Volume, csiNode *storagev1.CSINode, func (pl *CSILimits) checkAttachableInlineVolume(logger klog.Logger, vol *v1.Volume, csiNode *storagev1.CSINode,
pod *v1.Pod, result map[string]string) error { pod *v1.Pod, result map[string]string) error {
if !pl.translator.IsInlineMigratable(vol) { if !pl.translator.IsInlineMigratable(vol) {
return nil return nil
@ -253,7 +255,7 @@ func (pl *CSILimits) checkAttachableInlineVolume(vol *v1.Volume, csiNode *storag
if csiNode != nil { if csiNode != nil {
csiNodeName = csiNode.Name csiNodeName = csiNode.Name
} }
klog.V(5).InfoS("CSI Migration is not enabled for provisioner", "provisioner", inTreeProvisionerName, logger.V(5).Info("CSI Migration is not enabled for provisioner", "provisioner", inTreeProvisionerName,
"pod", klog.KObj(pod), "csiNode", csiNodeName) "pod", klog.KObj(pod), "csiNode", csiNodeName)
return nil return nil
} }
@ -280,21 +282,21 @@ func (pl *CSILimits) checkAttachableInlineVolume(vol *v1.Volume, csiNode *storag
// getCSIDriverInfo returns the CSI driver name and volume ID of a given PVC. // getCSIDriverInfo returns the CSI driver name and volume ID of a given PVC.
// If the PVC is from a migrated in-tree plugin, this function will return // If the PVC is from a migrated in-tree plugin, this function will return
// the information of the CSI driver that the plugin has been migrated to. // the information of the CSI driver that the plugin has been migrated to.
func (pl *CSILimits) getCSIDriverInfo(csiNode *storagev1.CSINode, pvc *v1.PersistentVolumeClaim) (string, string) { func (pl *CSILimits) getCSIDriverInfo(logger klog.Logger, csiNode *storagev1.CSINode, pvc *v1.PersistentVolumeClaim) (string, string) {
pvName := pvc.Spec.VolumeName pvName := pvc.Spec.VolumeName
if pvName == "" { if pvName == "" {
klog.V(5).InfoS("Persistent volume had no name for claim", "PVC", klog.KObj(pvc)) logger.V(5).Info("Persistent volume had no name for claim", "PVC", klog.KObj(pvc))
return pl.getCSIDriverInfoFromSC(csiNode, pvc) return pl.getCSIDriverInfoFromSC(logger, csiNode, pvc)
} }
pv, err := pl.pvLister.Get(pvName) pv, err := pl.pvLister.Get(pvName)
if err != nil { if err != nil {
klog.V(5).InfoS("Unable to look up PV info for PVC and PV", "PVC", klog.KObj(pvc), "PV", klog.KRef("", pvName)) logger.V(5).Info("Unable to look up PV info for PVC and PV", "PVC", klog.KObj(pvc), "PV", klog.KRef("", pvName))
// If we can't fetch PV associated with PVC, may be it got deleted // If we can't fetch PV associated with PVC, may be it got deleted
// or PVC was prebound to a PVC that hasn't been created yet. // or PVC was prebound to a PVC that hasn't been created yet.
// fallback to using StorageClass for volume counting // fallback to using StorageClass for volume counting
return pl.getCSIDriverInfoFromSC(csiNode, pvc) return pl.getCSIDriverInfoFromSC(logger, csiNode, pvc)
} }
csiSource := pv.Spec.PersistentVolumeSource.CSI csiSource := pv.Spec.PersistentVolumeSource.CSI
@ -306,23 +308,23 @@ func (pl *CSILimits) getCSIDriverInfo(csiNode *storagev1.CSINode, pvc *v1.Persis
pluginName, err := pl.translator.GetInTreePluginNameFromSpec(pv, nil) pluginName, err := pl.translator.GetInTreePluginNameFromSpec(pv, nil)
if err != nil { if err != nil {
klog.V(5).InfoS("Unable to look up plugin name from PV spec", "err", err) logger.V(5).Info("Unable to look up plugin name from PV spec", "err", err)
return "", "" return "", ""
} }
if !isCSIMigrationOn(csiNode, pluginName) { if !isCSIMigrationOn(csiNode, pluginName) {
klog.V(5).InfoS("CSI Migration of plugin is not enabled", "plugin", pluginName) logger.V(5).Info("CSI Migration of plugin is not enabled", "plugin", pluginName)
return "", "" return "", ""
} }
csiPV, err := pl.translator.TranslateInTreePVToCSI(pv) csiPV, err := pl.translator.TranslateInTreePVToCSI(pv)
if err != nil { if err != nil {
klog.V(5).InfoS("Unable to translate in-tree volume to CSI", "err", err) logger.V(5).Info("Unable to translate in-tree volume to CSI", "err", err)
return "", "" return "", ""
} }
if csiPV.Spec.PersistentVolumeSource.CSI == nil { if csiPV.Spec.PersistentVolumeSource.CSI == nil {
klog.V(5).InfoS("Unable to get a valid volume source for translated PV", "PV", pvName) logger.V(5).Info("Unable to get a valid volume source for translated PV", "PV", pvName)
return "", "" return "", ""
} }
@ -333,7 +335,7 @@ func (pl *CSILimits) getCSIDriverInfo(csiNode *storagev1.CSINode, pvc *v1.Persis
} }
// getCSIDriverInfoFromSC returns the CSI driver name and a random volume ID of a given PVC's StorageClass. // getCSIDriverInfoFromSC returns the CSI driver name and a random volume ID of a given PVC's StorageClass.
func (pl *CSILimits) getCSIDriverInfoFromSC(csiNode *storagev1.CSINode, pvc *v1.PersistentVolumeClaim) (string, string) { func (pl *CSILimits) getCSIDriverInfoFromSC(logger klog.Logger, csiNode *storagev1.CSINode, pvc *v1.PersistentVolumeClaim) (string, string) {
namespace := pvc.Namespace namespace := pvc.Namespace
pvcName := pvc.Name pvcName := pvc.Name
scName := storagehelpers.GetPersistentVolumeClaimClass(pvc) scName := storagehelpers.GetPersistentVolumeClaimClass(pvc)
@ -341,13 +343,13 @@ func (pl *CSILimits) getCSIDriverInfoFromSC(csiNode *storagev1.CSINode, pvc *v1.
// If StorageClass is not set or not found, then PVC must be using immediate binding mode // If StorageClass is not set or not found, then PVC must be using immediate binding mode
// and hence it must be bound before scheduling. So it is safe to not count it. // and hence it must be bound before scheduling. So it is safe to not count it.
if scName == "" { if scName == "" {
klog.V(5).InfoS("PVC has no StorageClass", "PVC", klog.KObj(pvc)) logger.V(5).Info("PVC has no StorageClass", "PVC", klog.KObj(pvc))
return "", "" return "", ""
} }
storageClass, err := pl.scLister.Get(scName) storageClass, err := pl.scLister.Get(scName)
if err != nil { if err != nil {
klog.V(5).InfoS("Could not get StorageClass for PVC", "PVC", klog.KObj(pvc), "err", err) logger.V(5).Info("Could not get StorageClass for PVC", "PVC", klog.KObj(pvc), "err", err)
return "", "" return "", ""
} }
@ -359,13 +361,13 @@ func (pl *CSILimits) getCSIDriverInfoFromSC(csiNode *storagev1.CSINode, pvc *v1.
provisioner := storageClass.Provisioner provisioner := storageClass.Provisioner
if pl.translator.IsMigratableIntreePluginByName(provisioner) { if pl.translator.IsMigratableIntreePluginByName(provisioner) {
if !isCSIMigrationOn(csiNode, provisioner) { if !isCSIMigrationOn(csiNode, provisioner) {
klog.V(5).InfoS("CSI Migration of provisioner is not enabled", "provisioner", provisioner) logger.V(5).Info("CSI Migration of provisioner is not enabled", "provisioner", provisioner)
return "", "" return "", ""
} }
driverName, err := pl.translator.GetCSINameFromInTreeName(provisioner) driverName, err := pl.translator.GetCSINameFromInTreeName(provisioner)
if err != nil { if err != nil {
klog.V(5).InfoS("Unable to look up driver name from provisioner name", "provisioner", provisioner, "err", err) logger.V(5).Info("Unable to look up driver name from provisioner name", "provisioner", provisioner, "err", err)
return "", "" return "", ""
} }
return driverName, volumeHandle return driverName, volumeHandle
@ -375,7 +377,7 @@ func (pl *CSILimits) getCSIDriverInfoFromSC(csiNode *storagev1.CSINode, pvc *v1.
} }
// NewCSI initializes a new plugin and returns it. // NewCSI initializes a new plugin and returns it.
func NewCSI(_ runtime.Object, handle framework.Handle, fts feature.Features) (framework.Plugin, error) { func NewCSI(_ context.Context, _ runtime.Object, handle framework.Handle, fts feature.Features) (framework.Plugin, error) {
informerFactory := handle.SharedInformerFactory() informerFactory := handle.SharedInformerFactory()
pvLister := informerFactory.Core().V1().PersistentVolumes().Lister() pvLister := informerFactory.Core().V1().PersistentVolumes().Lister()
pvcLister := informerFactory.Core().V1().PersistentVolumeClaims().Lister() pvcLister := informerFactory.Core().V1().PersistentVolumeClaims().Lister()

View File

@ -70,36 +70,36 @@ const (
const AzureDiskName = names.AzureDiskLimits const AzureDiskName = names.AzureDiskLimits
// NewAzureDisk returns function that initializes a new plugin and returns it. // NewAzureDisk returns function that initializes a new plugin and returns it.
func NewAzureDisk(_ runtime.Object, handle framework.Handle, fts feature.Features) (framework.Plugin, error) { func NewAzureDisk(ctx context.Context, _ runtime.Object, handle framework.Handle, fts feature.Features) (framework.Plugin, error) {
informerFactory := handle.SharedInformerFactory() informerFactory := handle.SharedInformerFactory()
return newNonCSILimitsWithInformerFactory(azureDiskVolumeFilterType, informerFactory, fts), nil return newNonCSILimitsWithInformerFactory(ctx, azureDiskVolumeFilterType, informerFactory, fts), nil
} }
// CinderName is the name of the plugin used in the plugin registry and configurations. // CinderName is the name of the plugin used in the plugin registry and configurations.
const CinderName = names.CinderLimits const CinderName = names.CinderLimits
// NewCinder returns function that initializes a new plugin and returns it. // NewCinder returns function that initializes a new plugin and returns it.
func NewCinder(_ runtime.Object, handle framework.Handle, fts feature.Features) (framework.Plugin, error) { func NewCinder(ctx context.Context, _ runtime.Object, handle framework.Handle, fts feature.Features) (framework.Plugin, error) {
informerFactory := handle.SharedInformerFactory() informerFactory := handle.SharedInformerFactory()
return newNonCSILimitsWithInformerFactory(cinderVolumeFilterType, informerFactory, fts), nil return newNonCSILimitsWithInformerFactory(ctx, cinderVolumeFilterType, informerFactory, fts), nil
} }
// EBSName is the name of the plugin used in the plugin registry and configurations. // EBSName is the name of the plugin used in the plugin registry and configurations.
const EBSName = names.EBSLimits const EBSName = names.EBSLimits
// NewEBS returns function that initializes a new plugin and returns it. // NewEBS returns function that initializes a new plugin and returns it.
func NewEBS(_ runtime.Object, handle framework.Handle, fts feature.Features) (framework.Plugin, error) { func NewEBS(ctx context.Context, _ runtime.Object, handle framework.Handle, fts feature.Features) (framework.Plugin, error) {
informerFactory := handle.SharedInformerFactory() informerFactory := handle.SharedInformerFactory()
return newNonCSILimitsWithInformerFactory(ebsVolumeFilterType, informerFactory, fts), nil return newNonCSILimitsWithInformerFactory(ctx, ebsVolumeFilterType, informerFactory, fts), nil
} }
// GCEPDName is the name of the plugin used in the plugin registry and configurations. // GCEPDName is the name of the plugin used in the plugin registry and configurations.
const GCEPDName = names.GCEPDLimits const GCEPDName = names.GCEPDLimits
// NewGCEPD returns function that initializes a new plugin and returns it. // NewGCEPD returns function that initializes a new plugin and returns it.
func NewGCEPD(_ runtime.Object, handle framework.Handle, fts feature.Features) (framework.Plugin, error) { func NewGCEPD(ctx context.Context, _ runtime.Object, handle framework.Handle, fts feature.Features) (framework.Plugin, error) {
informerFactory := handle.SharedInformerFactory() informerFactory := handle.SharedInformerFactory()
return newNonCSILimitsWithInformerFactory(gcePDVolumeFilterType, informerFactory, fts), nil return newNonCSILimitsWithInformerFactory(ctx, gcePDVolumeFilterType, informerFactory, fts), nil
} }
// nonCSILimits contains information to check the max number of volumes for a plugin. // nonCSILimits contains information to check the max number of volumes for a plugin.
@ -125,6 +125,7 @@ var _ framework.EnqueueExtensions = &nonCSILimits{}
// newNonCSILimitsWithInformerFactory returns a plugin with filter name and informer factory. // newNonCSILimitsWithInformerFactory returns a plugin with filter name and informer factory.
func newNonCSILimitsWithInformerFactory( func newNonCSILimitsWithInformerFactory(
ctx context.Context,
filterName string, filterName string,
informerFactory informers.SharedInformerFactory, informerFactory informers.SharedInformerFactory,
fts feature.Features, fts feature.Features,
@ -134,7 +135,7 @@ func newNonCSILimitsWithInformerFactory(
csiNodesLister := informerFactory.Storage().V1().CSINodes().Lister() csiNodesLister := informerFactory.Storage().V1().CSINodes().Lister()
scLister := informerFactory.Storage().V1().StorageClasses().Lister() scLister := informerFactory.Storage().V1().StorageClasses().Lister()
return newNonCSILimits(filterName, csiNodesLister, scLister, pvLister, pvcLister, fts) return newNonCSILimits(ctx, filterName, csiNodesLister, scLister, pvLister, pvcLister, fts)
} }
// newNonCSILimits creates a plugin which evaluates whether a pod can fit based on the // newNonCSILimits creates a plugin which evaluates whether a pod can fit based on the
@ -148,6 +149,7 @@ func newNonCSILimitsWithInformerFactory(
// types, counts the number of unique volumes, and rejects the new pod if it would place the total count over // types, counts the number of unique volumes, and rejects the new pod if it would place the total count over
// the maximum. // the maximum.
func newNonCSILimits( func newNonCSILimits(
ctx context.Context,
filterName string, filterName string,
csiNodeLister storagelisters.CSINodeLister, csiNodeLister storagelisters.CSINodeLister,
scLister storagelisters.StorageClassLister, scLister storagelisters.StorageClassLister,
@ -155,6 +157,7 @@ func newNonCSILimits(
pvcLister corelisters.PersistentVolumeClaimLister, pvcLister corelisters.PersistentVolumeClaimLister,
fts feature.Features, fts feature.Features,
) framework.Plugin { ) framework.Plugin {
logger := klog.FromContext(ctx)
var filter VolumeFilter var filter VolumeFilter
var volumeLimitKey v1.ResourceName var volumeLimitKey v1.ResourceName
var name string var name string
@ -177,14 +180,14 @@ func newNonCSILimits(
filter = cinderVolumeFilter filter = cinderVolumeFilter
volumeLimitKey = v1.ResourceName(volumeutil.CinderVolumeLimitKey) volumeLimitKey = v1.ResourceName(volumeutil.CinderVolumeLimitKey)
default: default:
klog.ErrorS(errors.New("wrong filterName"), "Cannot create nonCSILimits plugin") logger.Error(errors.New("wrong filterName"), "Cannot create nonCSILimits plugin")
return nil return nil
} }
pl := &nonCSILimits{ pl := &nonCSILimits{
name: name, name: name,
filter: filter, filter: filter,
volumeLimitKey: volumeLimitKey, volumeLimitKey: volumeLimitKey,
maxVolumeFunc: getMaxVolumeFunc(filterName), maxVolumeFunc: getMaxVolumeFunc(logger, filterName),
csiNodeLister: csiNodeLister, csiNodeLister: csiNodeLister,
pvLister: pvLister, pvLister: pvLister,
pvcLister: pvcLister, pvcLister: pvcLister,
@ -238,8 +241,9 @@ func (pl *nonCSILimits) Filter(ctx context.Context, _ *framework.CycleState, pod
return nil return nil
} }
logger := klog.FromContext(ctx)
newVolumes := sets.New[string]() newVolumes := sets.New[string]()
if err := pl.filterVolumes(pod, true /* new pod */, newVolumes); err != nil { if err := pl.filterVolumes(logger, pod, true /* new pod */, newVolumes); err != nil {
return framework.AsStatus(err) return framework.AsStatus(err)
} }
@ -257,7 +261,7 @@ func (pl *nonCSILimits) Filter(ctx context.Context, _ *framework.CycleState, pod
if err != nil { if err != nil {
// we don't fail here because the CSINode object is only necessary // we don't fail here because the CSINode object is only necessary
// for determining whether the migration is enabled or not // for determining whether the migration is enabled or not
klog.V(5).InfoS("Could not get a CSINode object for the node", "node", klog.KObj(node), "err", err) logger.V(5).Info("Could not get a CSINode object for the node", "node", klog.KObj(node), "err", err)
} }
} }
@ -269,7 +273,7 @@ func (pl *nonCSILimits) Filter(ctx context.Context, _ *framework.CycleState, pod
// count unique volumes // count unique volumes
existingVolumes := sets.New[string]() existingVolumes := sets.New[string]()
for _, existingPod := range nodeInfo.Pods { for _, existingPod := range nodeInfo.Pods {
if err := pl.filterVolumes(existingPod.Pod, false /* existing pod */, existingVolumes); err != nil { if err := pl.filterVolumes(logger, existingPod.Pod, false /* existing pod */, existingVolumes); err != nil {
return framework.AsStatus(err) return framework.AsStatus(err)
} }
} }
@ -293,7 +297,7 @@ func (pl *nonCSILimits) Filter(ctx context.Context, _ *framework.CycleState, pod
return nil return nil
} }
func (pl *nonCSILimits) filterVolumes(pod *v1.Pod, newPod bool, filteredVolumes sets.Set[string]) error { func (pl *nonCSILimits) filterVolumes(logger klog.Logger, pod *v1.Pod, newPod bool, filteredVolumes sets.Set[string]) error {
volumes := pod.Spec.Volumes volumes := pod.Spec.Volumes
for i := range volumes { for i := range volumes {
vol := &volumes[i] vol := &volumes[i]
@ -336,7 +340,7 @@ func (pl *nonCSILimits) filterVolumes(pod *v1.Pod, newPod bool, filteredVolumes
} }
// If the PVC is invalid, we don't count the volume because // If the PVC is invalid, we don't count the volume because
// there's no guarantee that it belongs to the running predicate. // there's no guarantee that it belongs to the running predicate.
klog.V(4).InfoS("Unable to look up PVC info, assuming PVC doesn't match predicate when counting limits", "pod", klog.KObj(pod), "PVC", klog.KRef(pod.Namespace, pvcName), "err", err) logger.V(4).Info("Unable to look up PVC info, assuming PVC doesn't match predicate when counting limits", "pod", klog.KObj(pod), "PVC", klog.KRef(pod.Namespace, pvcName), "err", err)
continue continue
} }
@ -354,7 +358,7 @@ func (pl *nonCSILimits) filterVolumes(pod *v1.Pod, newPod bool, filteredVolumes
// original PV where it was bound to, so we count the volume if // original PV where it was bound to, so we count the volume if
// it belongs to the running predicate. // it belongs to the running predicate.
if pl.matchProvisioner(pvc) { if pl.matchProvisioner(pvc) {
klog.V(4).InfoS("PVC is not bound, assuming PVC matches predicate when counting limits", "pod", klog.KObj(pod), "PVC", klog.KRef(pod.Namespace, pvcName)) logger.V(4).Info("PVC is not bound, assuming PVC matches predicate when counting limits", "pod", klog.KObj(pod), "PVC", klog.KRef(pod.Namespace, pvcName))
filteredVolumes.Insert(pvID) filteredVolumes.Insert(pvID)
} }
continue continue
@ -365,7 +369,7 @@ func (pl *nonCSILimits) filterVolumes(pod *v1.Pod, newPod bool, filteredVolumes
// If the PV is invalid and PVC belongs to the running predicate, // If the PV is invalid and PVC belongs to the running predicate,
// log the error and count the PV towards the PV limit. // log the error and count the PV towards the PV limit.
if pl.matchProvisioner(pvc) { if pl.matchProvisioner(pvc) {
klog.V(4).InfoS("Unable to look up PV, assuming PV matches predicate when counting limits", "pod", klog.KObj(pod), "PVC", klog.KRef(pod.Namespace, pvcName), "PV", klog.KRef("", pvName), "err", err) logger.V(4).Info("Unable to look up PV, assuming PV matches predicate when counting limits", "pod", klog.KObj(pod), "PVC", klog.KRef(pod.Namespace, pvcName), "PV", klog.KRef("", pvName), "err", err)
filteredVolumes.Insert(pvID) filteredVolumes.Insert(pvID)
} }
continue continue
@ -394,12 +398,12 @@ func (pl *nonCSILimits) matchProvisioner(pvc *v1.PersistentVolumeClaim) bool {
} }
// getMaxVolLimitFromEnv checks the max PD volumes environment variable, otherwise returning a default value. // getMaxVolLimitFromEnv checks the max PD volumes environment variable, otherwise returning a default value.
func getMaxVolLimitFromEnv() int { func getMaxVolLimitFromEnv(logger klog.Logger) int {
if rawMaxVols := os.Getenv(KubeMaxPDVols); rawMaxVols != "" { if rawMaxVols := os.Getenv(KubeMaxPDVols); rawMaxVols != "" {
if parsedMaxVols, err := strconv.Atoi(rawMaxVols); err != nil { if parsedMaxVols, err := strconv.Atoi(rawMaxVols); err != nil {
klog.ErrorS(err, "Unable to parse maximum PD volumes value, using default") logger.Error(err, "Unable to parse maximum PD volumes value, using default")
} else if parsedMaxVols <= 0 { } else if parsedMaxVols <= 0 {
klog.ErrorS(errors.New("maximum PD volumes is negative"), "Unable to parse maximum PD volumes value, using default") logger.Error(errors.New("maximum PD volumes is negative"), "Unable to parse maximum PD volumes value, using default")
} else { } else {
return parsedMaxVols return parsedMaxVols
} }
@ -520,9 +524,9 @@ var cinderVolumeFilter = VolumeFilter{
}, },
} }
func getMaxVolumeFunc(filterName string) func(node *v1.Node) int { func getMaxVolumeFunc(logger klog.Logger, filterName string) func(node *v1.Node) int {
return func(node *v1.Node) int { return func(node *v1.Node) int {
maxVolumesFromEnv := getMaxVolLimitFromEnv() maxVolumesFromEnv := getMaxVolLimitFromEnv(logger)
if maxVolumesFromEnv > 0 { if maxVolumesFromEnv > 0 {
return maxVolumesFromEnv return maxVolumesFromEnv
} }

View File

@ -28,6 +28,7 @@ import (
v1 "k8s.io/api/core/v1" v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
csilibplugins "k8s.io/csi-translation-lib/plugins" csilibplugins "k8s.io/csi-translation-lib/plugins"
"k8s.io/klog/v2/ktesting"
"k8s.io/kubernetes/pkg/scheduler/framework" "k8s.io/kubernetes/pkg/scheduler/framework"
"k8s.io/kubernetes/pkg/scheduler/framework/plugins/feature" "k8s.io/kubernetes/pkg/scheduler/framework/plugins/feature"
st "k8s.io/kubernetes/pkg/scheduler/testing" st "k8s.io/kubernetes/pkg/scheduler/testing"
@ -181,16 +182,17 @@ func TestEphemeralLimits(t *testing.T) {
for _, test := range tests { for _, test := range tests {
t.Run(test.test, func(t *testing.T) { t.Run(test.test, func(t *testing.T) {
_, ctx := ktesting.NewTestContext(t)
fts := feature.Features{} fts := feature.Features{}
node, csiNode := getNodeWithPodAndVolumeLimits("node", test.existingPods, test.maxVols, filterName) node, csiNode := getNodeWithPodAndVolumeLimits("node", test.existingPods, test.maxVols, filterName)
p := newNonCSILimits(filterName, getFakeCSINodeLister(csiNode), getFakeCSIStorageClassLister(filterName, driverName), getFakePVLister(filterName), append(getFakePVCLister(filterName), test.extraClaims...), fts).(framework.FilterPlugin) p := newNonCSILimits(ctx, filterName, getFakeCSINodeLister(csiNode), getFakeCSIStorageClassLister(filterName, driverName), getFakePVLister(filterName), append(getFakePVCLister(filterName), test.extraClaims...), fts).(framework.FilterPlugin)
_, gotPreFilterStatus := p.(*nonCSILimits).PreFilter(context.Background(), nil, test.newPod) _, gotPreFilterStatus := p.(*nonCSILimits).PreFilter(ctx, nil, test.newPod)
if diff := cmp.Diff(test.wantPreFilterStatus, gotPreFilterStatus); diff != "" { if diff := cmp.Diff(test.wantPreFilterStatus, gotPreFilterStatus); diff != "" {
t.Errorf("PreFilter status does not match (-want, +got): %s", diff) t.Errorf("PreFilter status does not match (-want, +got): %s", diff)
} }
if gotPreFilterStatus.Code() != framework.Skip { if gotPreFilterStatus.Code() != framework.Skip {
gotStatus := p.Filter(context.Background(), nil, test.newPod, node) gotStatus := p.Filter(ctx, nil, test.newPod, node)
if !reflect.DeepEqual(gotStatus, test.wantStatus) { if !reflect.DeepEqual(gotStatus, test.wantStatus) {
t.Errorf("Filter status does not match: %v, want: %v", gotStatus, test.wantStatus) t.Errorf("Filter status does not match: %v, want: %v", gotStatus, test.wantStatus)
} }
@ -412,8 +414,9 @@ func TestAzureDiskLimits(t *testing.T) {
for _, test := range tests { for _, test := range tests {
t.Run(test.test, func(t *testing.T) { t.Run(test.test, func(t *testing.T) {
_, ctx := ktesting.NewTestContext(t)
node, csiNode := getNodeWithPodAndVolumeLimits("node", test.existingPods, test.maxVols, test.filterName) node, csiNode := getNodeWithPodAndVolumeLimits("node", test.existingPods, test.maxVols, test.filterName)
p := newNonCSILimits(test.filterName, getFakeCSINodeLister(csiNode), getFakeCSIStorageClassLister(test.filterName, test.driverName), getFakePVLister(test.filterName), getFakePVCLister(test.filterName), feature.Features{}).(framework.FilterPlugin) p := newNonCSILimits(ctx, test.filterName, getFakeCSINodeLister(csiNode), getFakeCSIStorageClassLister(test.filterName, test.driverName), getFakePVLister(test.filterName), getFakePVCLister(test.filterName), feature.Features{}).(framework.FilterPlugin)
_, gotPreFilterStatus := p.(*nonCSILimits).PreFilter(context.Background(), nil, test.newPod) _, gotPreFilterStatus := p.(*nonCSILimits).PreFilter(context.Background(), nil, test.newPod)
if diff := cmp.Diff(test.wantPreFilterStatus, gotPreFilterStatus); diff != "" { if diff := cmp.Diff(test.wantPreFilterStatus, gotPreFilterStatus); diff != "" {
t.Errorf("PreFilter status does not match (-want, +got): %s", diff) t.Errorf("PreFilter status does not match (-want, +got): %s", diff)
@ -693,15 +696,16 @@ func TestEBSLimits(t *testing.T) {
for _, test := range tests { for _, test := range tests {
t.Run(test.test, func(t *testing.T) { t.Run(test.test, func(t *testing.T) {
_, ctx := ktesting.NewTestContext(t)
node, csiNode := getNodeWithPodAndVolumeLimits("node", test.existingPods, test.maxVols, test.filterName) node, csiNode := getNodeWithPodAndVolumeLimits("node", test.existingPods, test.maxVols, test.filterName)
p := newNonCSILimits(test.filterName, getFakeCSINodeLister(csiNode), getFakeCSIStorageClassLister(test.filterName, test.driverName), getFakePVLister(test.filterName), getFakePVCLister(test.filterName), feature.Features{}).(framework.FilterPlugin) p := newNonCSILimits(ctx, test.filterName, getFakeCSINodeLister(csiNode), getFakeCSIStorageClassLister(test.filterName, test.driverName), getFakePVLister(test.filterName), getFakePVCLister(test.filterName), feature.Features{}).(framework.FilterPlugin)
_, gotPreFilterStatus := p.(*nonCSILimits).PreFilter(context.Background(), nil, test.newPod) _, gotPreFilterStatus := p.(*nonCSILimits).PreFilter(ctx, nil, test.newPod)
if diff := cmp.Diff(test.wantPreFilterStatus, gotPreFilterStatus); diff != "" { if diff := cmp.Diff(test.wantPreFilterStatus, gotPreFilterStatus); diff != "" {
t.Errorf("PreFilter status does not match (-want, +got): %s", diff) t.Errorf("PreFilter status does not match (-want, +got): %s", diff)
} }
if gotPreFilterStatus.Code() != framework.Skip { if gotPreFilterStatus.Code() != framework.Skip {
gotStatus := p.Filter(context.Background(), nil, test.newPod, node) gotStatus := p.Filter(ctx, nil, test.newPod, node)
if !reflect.DeepEqual(gotStatus, test.wantStatus) { if !reflect.DeepEqual(gotStatus, test.wantStatus) {
t.Errorf("Filter status does not match: %v, want: %v", gotStatus, test.wantStatus) t.Errorf("Filter status does not match: %v, want: %v", gotStatus, test.wantStatus)
} }
@ -923,8 +927,9 @@ func TestGCEPDLimits(t *testing.T) {
for _, test := range tests { for _, test := range tests {
t.Run(test.test, func(t *testing.T) { t.Run(test.test, func(t *testing.T) {
_, ctx := ktesting.NewTestContext(t)
node, csiNode := getNodeWithPodAndVolumeLimits("node", test.existingPods, test.maxVols, test.filterName) node, csiNode := getNodeWithPodAndVolumeLimits("node", test.existingPods, test.maxVols, test.filterName)
p := newNonCSILimits(test.filterName, getFakeCSINodeLister(csiNode), getFakeCSIStorageClassLister(test.filterName, test.driverName), getFakePVLister(test.filterName), getFakePVCLister(test.filterName), feature.Features{}).(framework.FilterPlugin) p := newNonCSILimits(ctx, test.filterName, getFakeCSINodeLister(csiNode), getFakeCSIStorageClassLister(test.filterName, test.driverName), getFakePVLister(test.filterName), getFakePVCLister(test.filterName), feature.Features{}).(framework.FilterPlugin)
_, gotPreFilterStatus := p.(*nonCSILimits).PreFilter(context.Background(), nil, test.newPod) _, gotPreFilterStatus := p.(*nonCSILimits).PreFilter(context.Background(), nil, test.newPod)
if diff := cmp.Diff(test.wantPreFilterStatus, gotPreFilterStatus); diff != "" { if diff := cmp.Diff(test.wantPreFilterStatus, gotPreFilterStatus); diff != "" {
t.Errorf("PreFilter status does not match (-want, +got): %s", diff) t.Errorf("PreFilter status does not match (-want, +got): %s", diff)
@ -965,8 +970,9 @@ func TestGetMaxVols(t *testing.T) {
for _, test := range tests { for _, test := range tests {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
logger, _ := ktesting.NewTestContext(t)
t.Setenv(KubeMaxPDVols, test.rawMaxVols) t.Setenv(KubeMaxPDVols, test.rawMaxVols)
result := getMaxVolLimitFromEnv() result := getMaxVolLimitFromEnv(logger)
if result != test.expected { if result != test.expected {
t.Errorf("expected %v got %v", test.expected, result) t.Errorf("expected %v got %v", test.expected, result)
} }

View File

@ -17,6 +17,7 @@ limitations under the License.
package podtopologyspread package podtopologyspread
import ( import (
"context"
"fmt" "fmt"
v1 "k8s.io/api/core/v1" v1 "k8s.io/api/core/v1"
@ -82,7 +83,7 @@ func (pl *PodTopologySpread) Name() string {
} }
// New initializes a new plugin and returns it. // New initializes a new plugin and returns it.
func New(plArgs runtime.Object, h framework.Handle, fts feature.Features) (framework.Plugin, error) { func New(_ context.Context, plArgs runtime.Object, h framework.Handle, fts feature.Features) (framework.Plugin, error) {
if h.SnapshotSharedLister() == nil { if h.SnapshotSharedLister() == nil {
return nil, fmt.Errorf("SnapshotSharedlister is nil") return nil, fmt.Errorf("SnapshotSharedlister is nil")
} }

View File

@ -95,7 +95,7 @@ func TestPreScoreSkip(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("Failed creating framework runtime: %v", err) t.Fatalf("Failed creating framework runtime: %v", err)
} }
pl, err := New(&tt.config, f, feature.Features{}) pl, err := New(ctx, &tt.config, f, feature.Features{})
if err != nil { if err != nil {
t.Fatalf("Failed creating plugin: %v", err) t.Fatalf("Failed creating plugin: %v", err)
} }
@ -103,7 +103,7 @@ func TestPreScoreSkip(t *testing.T) {
informerFactory.WaitForCacheSync(ctx.Done()) informerFactory.WaitForCacheSync(ctx.Done())
p := pl.(*PodTopologySpread) p := pl.(*PodTopologySpread)
cs := framework.NewCycleState() cs := framework.NewCycleState()
if s := p.PreScore(context.Background(), cs, tt.pod, tt.nodes); !s.IsSkip() { if s := p.PreScore(ctx, cs, tt.pod, tt.nodes); !s.IsSkip() {
t.Fatalf("Expected skip but got %v", s.AsError()) t.Fatalf("Expected skip but got %v", s.AsError())
} }
}) })
@ -582,7 +582,7 @@ func TestPreScoreStateEmptyNodes(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("Failed creating framework runtime: %v", err) t.Fatalf("Failed creating framework runtime: %v", err)
} }
pl, err := New(&tt.config, f, feature.Features{EnableNodeInclusionPolicyInPodTopologySpread: tt.enableNodeInclusionPolicy}) pl, err := New(ctx, &tt.config, f, feature.Features{EnableNodeInclusionPolicyInPodTopologySpread: tt.enableNodeInclusionPolicy})
if err != nil { if err != nil {
t.Fatalf("Failed creating plugin: %v", err) t.Fatalf("Failed creating plugin: %v", err)
} }
@ -1336,7 +1336,8 @@ func TestPodTopologySpreadScore(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) _, ctx := ktesting.NewTestContext(t)
ctx, cancel := context.WithCancel(ctx)
t.Cleanup(cancel) t.Cleanup(cancel)
allNodes := append([]*v1.Node{}, tt.nodes...) allNodes := append([]*v1.Node{}, tt.nodes...)
allNodes = append(allNodes, tt.failedNodes...) allNodes = append(allNodes, tt.failedNodes...)
@ -1346,7 +1347,7 @@ func TestPodTopologySpreadScore(t *testing.T) {
p.enableNodeInclusionPolicyInPodTopologySpread = tt.enableNodeInclusionPolicy p.enableNodeInclusionPolicyInPodTopologySpread = tt.enableNodeInclusionPolicy
p.enableMatchLabelKeysInPodTopologySpread = tt.enableMatchLabelKeys p.enableMatchLabelKeysInPodTopologySpread = tt.enableMatchLabelKeys
status := p.PreScore(context.Background(), state, tt.pod, tt.nodes) status := p.PreScore(ctx, state, tt.pod, tt.nodes)
if !status.IsSuccess() { if !status.IsSuccess() {
t.Errorf("unexpected error: %v", status) t.Errorf("unexpected error: %v", status)
} }

View File

@ -17,6 +17,7 @@ limitations under the License.
package queuesort package queuesort
import ( import (
"context"
"k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/runtime"
corev1helpers "k8s.io/component-helpers/scheduling/corev1" corev1helpers "k8s.io/component-helpers/scheduling/corev1"
"k8s.io/kubernetes/pkg/scheduler/framework" "k8s.io/kubernetes/pkg/scheduler/framework"
@ -46,6 +47,6 @@ func (pl *PrioritySort) Less(pInfo1, pInfo2 *framework.QueuedPodInfo) bool {
} }
// New initializes a new plugin and returns it. // New initializes a new plugin and returns it.
func New(_ runtime.Object, handle framework.Handle) (framework.Plugin, error) { func New(_ context.Context, _ runtime.Object, handle framework.Handle) (framework.Plugin, error) {
return &PrioritySort{}, nil return &PrioritySort{}, nil
} }

View File

@ -62,6 +62,6 @@ func (pl *SchedulingGates) EventsToRegister() []framework.ClusterEventWithHint {
} }
// New initializes a new plugin and returns it. // New initializes a new plugin and returns it.
func New(_ runtime.Object, _ framework.Handle, fts feature.Features) (framework.Plugin, error) { func New(_ context.Context, _ runtime.Object, _ framework.Handle, fts feature.Features) (framework.Plugin, error) {
return &SchedulingGates{enablePodSchedulingReadiness: fts.EnablePodSchedulingReadiness}, nil return &SchedulingGates{enablePodSchedulingReadiness: fts.EnablePodSchedulingReadiness}, nil
} }

View File

@ -17,7 +17,6 @@ limitations under the License.
package schedulinggates package schedulinggates
import ( import (
"context"
"testing" "testing"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
@ -26,6 +25,7 @@ import (
"k8s.io/kubernetes/pkg/scheduler/framework" "k8s.io/kubernetes/pkg/scheduler/framework"
"k8s.io/kubernetes/pkg/scheduler/framework/plugins/feature" "k8s.io/kubernetes/pkg/scheduler/framework/plugins/feature"
st "k8s.io/kubernetes/pkg/scheduler/testing" st "k8s.io/kubernetes/pkg/scheduler/testing"
"k8s.io/kubernetes/test/utils/ktesting"
) )
func TestPreEnqueue(t *testing.T) { func TestPreEnqueue(t *testing.T) {
@ -63,12 +63,13 @@ func TestPreEnqueue(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
p, err := New(nil, nil, feature.Features{EnablePodSchedulingReadiness: tt.enablePodSchedulingReadiness}) _, ctx := ktesting.NewTestContext(t)
p, err := New(ctx, nil, nil, feature.Features{EnablePodSchedulingReadiness: tt.enablePodSchedulingReadiness})
if err != nil { if err != nil {
t.Fatalf("Creating plugin: %v", err) t.Fatalf("Creating plugin: %v", err)
} }
got := p.(framework.PreEnqueuePlugin).PreEnqueue(context.Background(), tt.pod) got := p.(framework.PreEnqueuePlugin).PreEnqueue(ctx, tt.pod)
if diff := cmp.Diff(tt.want, got); diff != "" { if diff := cmp.Diff(tt.want, got); diff != "" {
t.Errorf("unexpected status (-want, +got):\n%s", diff) t.Errorf("unexpected status (-want, +got):\n%s", diff)
} }

View File

@ -164,6 +164,6 @@ func (pl *TaintToleration) ScoreExtensions() framework.ScoreExtensions {
} }
// New initializes a new plugin and returns it. // New initializes a new plugin and returns it.
func New(_ runtime.Object, h framework.Handle) (framework.Plugin, error) { func New(_ context.Context, _ runtime.Object, h framework.Handle) (framework.Plugin, error) {
return &TaintToleration{handle: h}, nil return &TaintToleration{handle: h}, nil
} }

View File

@ -237,7 +237,10 @@ func TestTaintTolerationScore(t *testing.T) {
snapshot := cache.NewSnapshot(nil, test.nodes) snapshot := cache.NewSnapshot(nil, test.nodes)
fh, _ := runtime.NewFramework(ctx, nil, nil, runtime.WithSnapshotSharedLister(snapshot)) fh, _ := runtime.NewFramework(ctx, nil, nil, runtime.WithSnapshotSharedLister(snapshot))
p, _ := New(nil, fh) p, err := New(ctx, nil, fh)
if err != nil {
t.Fatalf("creating plugin: %v", err)
}
status := p.(framework.PreScorePlugin).PreScore(ctx, state, test.pod, test.nodes) status := p.(framework.PreScorePlugin).PreScore(ctx, state, test.pod, test.nodes)
if !status.IsSuccess() { if !status.IsSuccess() {
t.Errorf("unexpected error: %v", status) t.Errorf("unexpected error: %v", status)
@ -335,10 +338,14 @@ func TestTaintTolerationFilter(t *testing.T) {
} }
for _, test := range tests { for _, test := range tests {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
_, ctx := ktesting.NewTestContext(t)
nodeInfo := framework.NewNodeInfo() nodeInfo := framework.NewNodeInfo()
nodeInfo.SetNode(test.node) nodeInfo.SetNode(test.node)
p, _ := New(nil, nil) p, err := New(ctx, nil, nil)
gotStatus := p.(framework.FilterPlugin).Filter(context.Background(), nil, test.pod, nodeInfo) if err != nil {
t.Fatalf("creating plugin: %v", err)
}
gotStatus := p.(framework.FilterPlugin).Filter(ctx, nil, test.pod, nodeInfo)
if !reflect.DeepEqual(gotStatus, test.wantStatus) { if !reflect.DeepEqual(gotStatus, test.wantStatus) {
t.Errorf("status does not match: %v, want: %v", gotStatus, test.wantStatus) t.Errorf("status does not match: %v, want: %v", gotStatus, test.wantStatus)
} }

View File

@ -49,7 +49,7 @@ func SetupPluginWithInformers(
if err != nil { if err != nil {
tb.Fatalf("Failed creating framework runtime: %v", err) tb.Fatalf("Failed creating framework runtime: %v", err)
} }
p, err := pf(config, fh) p, err := pf(ctx, config, fh)
if err != nil { if err != nil {
tb.Fatal(err) tb.Fatal(err)
} }
@ -72,7 +72,7 @@ func SetupPlugin(
if err != nil { if err != nil {
tb.Fatalf("Failed creating framework runtime: %v", err) tb.Fatalf("Failed creating framework runtime: %v", err)
} }
p, err := pf(config, fh) p, err := pf(ctx, config, fh)
if err != nil { if err != nil {
tb.Fatal(err) tb.Fatal(err)
} }

View File

@ -361,7 +361,7 @@ func (pl *VolumeBinding) Unreserve(ctx context.Context, cs *framework.CycleState
} }
// New initializes a new plugin and returns it. // New initializes a new plugin and returns it.
func New(plArgs runtime.Object, fh framework.Handle, fts feature.Features) (framework.Plugin, error) { func New(_ context.Context, plArgs runtime.Object, fh framework.Handle, fts feature.Features) (framework.Plugin, error) {
args, ok := plArgs.(*config.VolumeBindingArgs) args, ok := plArgs.(*config.VolumeBindingArgs)
if !ok { if !ok {
return nil, fmt.Errorf("want args to be of type VolumeBindingArgs, got %T", plArgs) return nil, fmt.Errorf("want args to be of type VolumeBindingArgs, got %T", plArgs)

View File

@ -806,7 +806,7 @@ func TestVolumeBinding(t *testing.T) {
} }
} }
pl, err := New(args, fh, item.fts) pl, err := New(ctx, args, fh, item.fts)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@ -348,7 +348,7 @@ func (pl *VolumeRestrictions) EventsToRegister() []framework.ClusterEventWithHin
} }
// New initializes a new plugin and returns it. // New initializes a new plugin and returns it.
func New(_ runtime.Object, handle framework.Handle, fts feature.Features) (framework.Plugin, error) { func New(_ context.Context, _ runtime.Object, handle framework.Handle, fts feature.Features) (framework.Plugin, error) {
informerFactory := handle.SharedInformerFactory() informerFactory := handle.SharedInformerFactory()
pvcLister := informerFactory.Core().V1().PersistentVolumeClaims().Lister() pvcLister := informerFactory.Core().V1().PersistentVolumeClaims().Lister()
sharedLister := handle.SnapshotSharedLister() sharedLister := handle.SnapshotSharedLister()

View File

@ -505,8 +505,8 @@ func newPlugin(ctx context.Context, t *testing.T) framework.Plugin {
} }
func newPluginWithListers(ctx context.Context, t *testing.T, pods []*v1.Pod, nodes []*v1.Node, pvcs []*v1.PersistentVolumeClaim, enableReadWriteOncePod bool) framework.Plugin { func newPluginWithListers(ctx context.Context, t *testing.T, pods []*v1.Pod, nodes []*v1.Node, pvcs []*v1.PersistentVolumeClaim, enableReadWriteOncePod bool) framework.Plugin {
pluginFactory := func(plArgs runtime.Object, fh framework.Handle) (framework.Plugin, error) { pluginFactory := func(ctx context.Context, plArgs runtime.Object, fh framework.Handle) (framework.Plugin, error) {
return New(plArgs, fh, feature.Features{ return New(ctx, plArgs, fh, feature.Features{
EnableReadWriteOncePod: enableReadWriteOncePod, EnableReadWriteOncePod: enableReadWriteOncePod,
}) })
} }

View File

@ -290,7 +290,7 @@ func (pl *VolumeZone) EventsToRegister() []framework.ClusterEventWithHint {
} }
// New initializes a new plugin and returns it. // New initializes a new plugin and returns it.
func New(_ runtime.Object, handle framework.Handle) (framework.Plugin, error) { func New(_ context.Context, _ runtime.Object, handle framework.Handle) (framework.Plugin, error) {
informerFactory := handle.SharedInformerFactory() informerFactory := handle.SharedInformerFactory()
pvLister := informerFactory.Core().V1().PersistentVolumes().Lister() pvLister := informerFactory.Core().V1().PersistentVolumes().Lister()
pvcLister := informerFactory.Core().V1().PersistentVolumeClaims().Lister() pvcLister := informerFactory.Core().V1().PersistentVolumeClaims().Lister()

View File

@ -302,7 +302,7 @@ func NewFramework(ctx context.Context, r Registry, profile *config.KubeScheduler
Args: args, Args: args,
}) })
} }
p, err := factory(args, f) p, err := factory(ctx, args, f)
if err != nil { if err != nil {
return nil, fmt.Errorf("initializing plugin %q: %w", name, err) return nil, fmt.Errorf("initializing plugin %q: %w", name, err)
} }

View File

@ -76,7 +76,7 @@ var cmpOpts = []cmp.Option{
}), }),
} }
func newScoreWithNormalizePlugin1(injArgs runtime.Object, f framework.Handle) (framework.Plugin, error) { func newScoreWithNormalizePlugin1(_ context.Context, injArgs runtime.Object, f framework.Handle) (framework.Plugin, error) {
var inj injectedResult var inj injectedResult
if err := DecodeInto(injArgs, &inj); err != nil { if err := DecodeInto(injArgs, &inj); err != nil {
return nil, err return nil, err
@ -84,7 +84,7 @@ func newScoreWithNormalizePlugin1(injArgs runtime.Object, f framework.Handle) (f
return &TestScoreWithNormalizePlugin{scoreWithNormalizePlugin1, inj}, nil return &TestScoreWithNormalizePlugin{scoreWithNormalizePlugin1, inj}, nil
} }
func newScoreWithNormalizePlugin2(injArgs runtime.Object, f framework.Handle) (framework.Plugin, error) { func newScoreWithNormalizePlugin2(_ context.Context, injArgs runtime.Object, f framework.Handle) (framework.Plugin, error) {
var inj injectedResult var inj injectedResult
if err := DecodeInto(injArgs, &inj); err != nil { if err := DecodeInto(injArgs, &inj); err != nil {
return nil, err return nil, err
@ -92,7 +92,7 @@ func newScoreWithNormalizePlugin2(injArgs runtime.Object, f framework.Handle) (f
return &TestScoreWithNormalizePlugin{scoreWithNormalizePlugin2, inj}, nil return &TestScoreWithNormalizePlugin{scoreWithNormalizePlugin2, inj}, nil
} }
func newScorePlugin1(injArgs runtime.Object, f framework.Handle) (framework.Plugin, error) { func newScorePlugin1(_ context.Context, injArgs runtime.Object, f framework.Handle) (framework.Plugin, error) {
var inj injectedResult var inj injectedResult
if err := DecodeInto(injArgs, &inj); err != nil { if err := DecodeInto(injArgs, &inj); err != nil {
return nil, err return nil, err
@ -100,7 +100,7 @@ func newScorePlugin1(injArgs runtime.Object, f framework.Handle) (framework.Plug
return &TestScorePlugin{scorePlugin1, inj}, nil return &TestScorePlugin{scorePlugin1, inj}, nil
} }
func newPluginNotImplementingScore(_ runtime.Object, _ framework.Handle) (framework.Plugin, error) { func newPluginNotImplementingScore(_ context.Context, _ runtime.Object, _ framework.Handle) (framework.Plugin, error) {
return &PluginNotImplementingScore{}, nil return &PluginNotImplementingScore{}, nil
} }
@ -154,7 +154,7 @@ func (pl *PluginNotImplementingScore) Name() string {
return pluginNotImplementingScore return pluginNotImplementingScore
} }
func newTestPlugin(injArgs runtime.Object, f framework.Handle) (framework.Plugin, error) { func newTestPlugin(_ context.Context, injArgs runtime.Object, f framework.Handle) (framework.Plugin, error) {
return &TestPlugin{name: testPlugin}, nil return &TestPlugin{name: testPlugin}, nil
} }
@ -296,7 +296,7 @@ func (dp *TestDuplicatePlugin) PreFilterExtensions() framework.PreFilterExtensio
var _ framework.PreFilterPlugin = &TestDuplicatePlugin{} var _ framework.PreFilterPlugin = &TestDuplicatePlugin{}
func newDuplicatePlugin(_ runtime.Object, _ framework.Handle) (framework.Plugin, error) { func newDuplicatePlugin(_ context.Context, _ runtime.Object, _ framework.Handle) (framework.Plugin, error) {
return &TestDuplicatePlugin{}, nil return &TestDuplicatePlugin{}, nil
} }
@ -326,7 +326,7 @@ func (pl *TestPreEnqueuePlugin) PreEnqueue(ctx context.Context, p *v1.Pod) *fram
var _ framework.QueueSortPlugin = &TestQueueSortPlugin{} var _ framework.QueueSortPlugin = &TestQueueSortPlugin{}
func newQueueSortPlugin(_ runtime.Object, _ framework.Handle) (framework.Plugin, error) { func newQueueSortPlugin(_ context.Context, _ runtime.Object, _ framework.Handle) (framework.Plugin, error) {
return &TestQueueSortPlugin{}, nil return &TestQueueSortPlugin{}, nil
} }
@ -343,7 +343,7 @@ func (pl *TestQueueSortPlugin) Less(_, _ *framework.QueuedPodInfo) bool {
var _ framework.BindPlugin = &TestBindPlugin{} var _ framework.BindPlugin = &TestBindPlugin{}
func newBindPlugin(_ runtime.Object, _ framework.Handle) (framework.Plugin, error) { func newBindPlugin(_ context.Context, _ runtime.Object, _ framework.Handle) (framework.Plugin, error) {
return &TestBindPlugin{}, nil return &TestBindPlugin{}, nil
} }
@ -881,7 +881,7 @@ func TestPreEnqueuePlugins(t *testing.T) {
// register all plugins // register all plugins
tmpPl := pl tmpPl := pl
if err := registry.Register(pl.Name(), if err := registry.Register(pl.Name(),
func(_ runtime.Object, _ framework.Handle) (framework.Plugin, error) { func(_ context.Context, _ runtime.Object, _ framework.Handle) (framework.Plugin, error) {
return tmpPl, nil return tmpPl, nil
}); err != nil { }); err != nil {
t.Fatalf("fail to register preEnqueue plugin (%s)", pl.Name()) t.Fatalf("fail to register preEnqueue plugin (%s)", pl.Name())
@ -1004,9 +1004,11 @@ func TestRunPreScorePlugins(t *testing.T) {
for i, p := range tt.plugins { for i, p := range tt.plugins {
p := p p := p
enabled[i].Name = p.name enabled[i].Name = p.name
r.Register(p.name, func(_ runtime.Object, fh framework.Handle) (framework.Plugin, error) { if err := r.Register(p.name, func(_ context.Context, _ runtime.Object, fh framework.Handle) (framework.Plugin, error) {
return p, nil return p, nil
}) }); err != nil {
t.Fatalf("fail to register PreScorePlugins plugin (%s)", p.Name())
}
} }
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
@ -1425,11 +1427,11 @@ func TestPreFilterPlugins(t *testing.T) {
preFilter2 := &TestPreFilterWithExtensionsPlugin{} preFilter2 := &TestPreFilterWithExtensionsPlugin{}
r := make(Registry) r := make(Registry)
r.Register(preFilterPluginName, r.Register(preFilterPluginName,
func(_ runtime.Object, fh framework.Handle) (framework.Plugin, error) { func(_ context.Context, _ runtime.Object, fh framework.Handle) (framework.Plugin, error) {
return preFilter1, nil return preFilter1, nil
}) })
r.Register(preFilterWithExtensionsPluginName, r.Register(preFilterWithExtensionsPluginName,
func(_ runtime.Object, fh framework.Handle) (framework.Plugin, error) { func(_ context.Context, _ runtime.Object, fh framework.Handle) (framework.Plugin, error) {
return preFilter2, nil return preFilter2, nil
}) })
plugins := &config.Plugins{PreFilter: config.PluginSet{Enabled: []config.Plugin{{Name: preFilterWithExtensionsPluginName}, {Name: preFilterPluginName}}}} plugins := &config.Plugins{PreFilter: config.PluginSet{Enabled: []config.Plugin{{Name: preFilterWithExtensionsPluginName}, {Name: preFilterPluginName}}}}
@ -1563,9 +1565,11 @@ func TestRunPreFilterPlugins(t *testing.T) {
for i, p := range tt.plugins { for i, p := range tt.plugins {
p := p p := p
enabled[i].Name = p.name enabled[i].Name = p.name
r.Register(p.name, func(_ runtime.Object, fh framework.Handle) (framework.Plugin, error) { if err := r.Register(p.name, func(_ context.Context, _ runtime.Object, fh framework.Handle) (framework.Plugin, error) {
return p, nil return p, nil
}) }); err != nil {
t.Fatalf("fail to register PreFilter plugin (%s)", p.Name())
}
} }
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
@ -1651,9 +1655,11 @@ func TestRunPreFilterExtensionRemovePod(t *testing.T) {
for i, p := range tt.plugins { for i, p := range tt.plugins {
p := p p := p
enabled[i].Name = p.name enabled[i].Name = p.name
r.Register(p.name, func(_ runtime.Object, fh framework.Handle) (framework.Plugin, error) { if err := r.Register(p.name, func(_ context.Context, _ runtime.Object, fh framework.Handle) (framework.Plugin, error) {
return p, nil return p, nil
}) }); err != nil {
t.Fatalf("fail to register PreFilterExtension plugin (%s)", p.Name())
}
} }
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
@ -1733,9 +1739,11 @@ func TestRunPreFilterExtensionAddPod(t *testing.T) {
for i, p := range tt.plugins { for i, p := range tt.plugins {
p := p p := p
enabled[i].Name = p.name enabled[i].Name = p.name
r.Register(p.name, func(_ runtime.Object, fh framework.Handle) (framework.Plugin, error) { if err := r.Register(p.name, func(_ context.Context, _ runtime.Object, fh framework.Handle) (framework.Plugin, error) {
return p, nil return p, nil
}) }); err != nil {
t.Fatalf("fail to register PreFilterExtension plugin (%s)", p.Name())
}
} }
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
@ -1934,7 +1942,7 @@ func TestFilterPlugins(t *testing.T) {
// register all plugins // register all plugins
tmpPl := pl tmpPl := pl
if err := registry.Register(pl.name, if err := registry.Register(pl.name,
func(_ runtime.Object, _ framework.Handle) (framework.Plugin, error) { func(_ context.Context, _ runtime.Object, _ framework.Handle) (framework.Plugin, error) {
return tmpPl, nil return tmpPl, nil
}); err != nil { }); err != nil {
t.Fatalf("fail to register filter plugin (%s)", pl.name) t.Fatalf("fail to register filter plugin (%s)", pl.name)
@ -2058,7 +2066,7 @@ func TestPostFilterPlugins(t *testing.T) {
// register all plugins // register all plugins
tmpPl := pl tmpPl := pl
if err := registry.Register(pl.name, if err := registry.Register(pl.name,
func(_ runtime.Object, _ framework.Handle) (framework.Plugin, error) { func(_ context.Context, _ runtime.Object, _ framework.Handle) (framework.Plugin, error) {
return tmpPl, nil return tmpPl, nil
}); err != nil { }); err != nil {
t.Fatalf("fail to register postFilter plugin (%s)", pl.name) t.Fatalf("fail to register postFilter plugin (%s)", pl.name)
@ -2190,7 +2198,7 @@ func TestFilterPluginsWithNominatedPods(t *testing.T) {
if tt.preFilterPlugin != nil { if tt.preFilterPlugin != nil {
if err := registry.Register(tt.preFilterPlugin.name, if err := registry.Register(tt.preFilterPlugin.name,
func(_ runtime.Object, _ framework.Handle) (framework.Plugin, error) { func(_ context.Context, _ runtime.Object, _ framework.Handle) (framework.Plugin, error) {
return tt.preFilterPlugin, nil return tt.preFilterPlugin, nil
}); err != nil { }); err != nil {
t.Fatalf("fail to register preFilter plugin (%s)", tt.preFilterPlugin.name) t.Fatalf("fail to register preFilter plugin (%s)", tt.preFilterPlugin.name)
@ -2202,7 +2210,7 @@ func TestFilterPluginsWithNominatedPods(t *testing.T) {
} }
if tt.filterPlugin != nil { if tt.filterPlugin != nil {
if err := registry.Register(tt.filterPlugin.name, if err := registry.Register(tt.filterPlugin.name,
func(_ runtime.Object, _ framework.Handle) (framework.Plugin, error) { func(_ context.Context, _ runtime.Object, _ framework.Handle) (framework.Plugin, error) {
return tt.filterPlugin, nil return tt.filterPlugin, nil
}); err != nil { }); err != nil {
t.Fatalf("fail to register filter plugin (%s)", tt.filterPlugin.name) t.Fatalf("fail to register filter plugin (%s)", tt.filterPlugin.name)
@ -2366,7 +2374,7 @@ func TestPreBindPlugins(t *testing.T) {
for _, pl := range tt.plugins { for _, pl := range tt.plugins {
tmpPl := pl tmpPl := pl
if err := registry.Register(pl.name, func(_ runtime.Object, _ framework.Handle) (framework.Plugin, error) { if err := registry.Register(pl.name, func(_ context.Context, _ runtime.Object, _ framework.Handle) (framework.Plugin, error) {
return tmpPl, nil return tmpPl, nil
}); err != nil { }); err != nil {
t.Fatalf("Unable to register pre bind plugins: %s", pl.name) t.Fatalf("Unable to register pre bind plugins: %s", pl.name)
@ -2524,7 +2532,7 @@ func TestReservePlugins(t *testing.T) {
for _, pl := range tt.plugins { for _, pl := range tt.plugins {
tmpPl := pl tmpPl := pl
if err := registry.Register(pl.name, func(_ runtime.Object, _ framework.Handle) (framework.Plugin, error) { if err := registry.Register(pl.name, func(_ context.Context, _ runtime.Object, _ framework.Handle) (framework.Plugin, error) {
return tmpPl, nil return tmpPl, nil
}); err != nil { }); err != nil {
t.Fatalf("Unable to register pre bind plugins: %s", pl.name) t.Fatalf("Unable to register pre bind plugins: %s", pl.name)
@ -2650,7 +2658,7 @@ func TestPermitPlugins(t *testing.T) {
for _, pl := range tt.plugins { for _, pl := range tt.plugins {
tmpPl := pl tmpPl := pl
if err := registry.Register(pl.name, func(_ runtime.Object, _ framework.Handle) (framework.Plugin, error) { if err := registry.Register(pl.name, func(_ context.Context, _ runtime.Object, _ framework.Handle) (framework.Plugin, error) {
return tmpPl, nil return tmpPl, nil
}); err != nil { }); err != nil {
t.Fatalf("Unable to register Permit plugin: %s", pl.name) t.Fatalf("Unable to register Permit plugin: %s", pl.name)
@ -2817,7 +2825,7 @@ func TestRecordingMetrics(t *testing.T) {
plugin := &TestPlugin{name: testPlugin, inj: tt.inject} plugin := &TestPlugin{name: testPlugin, inj: tt.inject}
r := make(Registry) r := make(Registry)
r.Register(testPlugin, r.Register(testPlugin,
func(_ runtime.Object, fh framework.Handle) (framework.Plugin, error) { func(_ context.Context, _ runtime.Object, fh framework.Handle) (framework.Plugin, error) {
return plugin, nil return plugin, nil
}) })
pluginSet := config.PluginSet{Enabled: []config.Plugin{{Name: testPlugin, Weight: 1}}} pluginSet := config.PluginSet{Enabled: []config.Plugin{{Name: testPlugin, Weight: 1}}}
@ -2941,7 +2949,7 @@ func TestRunBindPlugins(t *testing.T) {
name := fmt.Sprintf("bind-%d", i) name := fmt.Sprintf("bind-%d", i)
plugin := &TestPlugin{name: name, inj: injectedResult{BindStatus: int(inj)}} plugin := &TestPlugin{name: name, inj: injectedResult{BindStatus: int(inj)}}
r.Register(name, r.Register(name,
func(_ runtime.Object, fh framework.Handle) (framework.Plugin, error) { func(_ context.Context, _ runtime.Object, fh framework.Handle) (framework.Plugin, error) {
return plugin, nil return plugin, nil
}) })
pluginSet.Enabled = append(pluginSet.Enabled, config.Plugin{Name: name}) pluginSet.Enabled = append(pluginSet.Enabled, config.Plugin{Name: name})
@ -3000,7 +3008,7 @@ func TestPermitWaitDurationMetric(t *testing.T) {
plugin := &TestPlugin{name: testPlugin, inj: tt.inject} plugin := &TestPlugin{name: testPlugin, inj: tt.inject}
r := make(Registry) r := make(Registry)
err := r.Register(testPlugin, err := r.Register(testPlugin,
func(_ runtime.Object, fh framework.Handle) (framework.Plugin, error) { func(_ context.Context, _ runtime.Object, fh framework.Handle) (framework.Plugin, error) {
return plugin, nil return plugin, nil
}) })
if err != nil { if err != nil {
@ -3059,7 +3067,7 @@ func TestWaitOnPermit(t *testing.T) {
testPermitPlugin := &TestPermitPlugin{} testPermitPlugin := &TestPermitPlugin{}
r := make(Registry) r := make(Registry)
r.Register(permitPlugin, r.Register(permitPlugin,
func(_ runtime.Object, fh framework.Handle) (framework.Plugin, error) { func(_ context.Context, _ runtime.Object, fh framework.Handle) (framework.Plugin, error) {
return testPermitPlugin, nil return testPermitPlugin, nil
}) })
plugins := &config.Plugins{ plugins := &config.Plugins{

View File

@ -17,6 +17,7 @@ limitations under the License.
package runtime package runtime
import ( import (
"context"
"fmt" "fmt"
"k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/runtime"
@ -27,16 +28,16 @@ import (
) )
// PluginFactory is a function that builds a plugin. // PluginFactory is a function that builds a plugin.
type PluginFactory = func(configuration runtime.Object, f framework.Handle) (framework.Plugin, error) type PluginFactory = func(ctx context.Context, configuration runtime.Object, f framework.Handle) (framework.Plugin, error)
// PluginFactoryWithFts is a function that builds a plugin with certain feature gates. // PluginFactoryWithFts is a function that builds a plugin with certain feature gates.
type PluginFactoryWithFts func(runtime.Object, framework.Handle, plfeature.Features) (framework.Plugin, error) type PluginFactoryWithFts func(context.Context, runtime.Object, framework.Handle, plfeature.Features) (framework.Plugin, error)
// FactoryAdapter can be used to inject feature gates for a plugin that needs // FactoryAdapter can be used to inject feature gates for a plugin that needs
// them when the caller expects the older PluginFactory method. // them when the caller expects the older PluginFactory method.
func FactoryAdapter(fts plfeature.Features, withFts PluginFactoryWithFts) PluginFactory { func FactoryAdapter(fts plfeature.Features, withFts PluginFactoryWithFts) PluginFactory {
return func(plArgs runtime.Object, fh framework.Handle) (framework.Plugin, error) { return func(ctx context.Context, plArgs runtime.Object, fh framework.Handle) (framework.Plugin, error) {
return withFts(plArgs, fh, fts) return withFts(ctx, plArgs, fh, fts)
} }
} }

View File

@ -17,6 +17,7 @@ limitations under the License.
package runtime package runtime
import ( import (
"context"
"reflect" "reflect"
"testing" "testing"
@ -78,8 +79,8 @@ func TestDecodeInto(t *testing.T) {
func isRegistryEqual(registryX, registryY Registry) bool { func isRegistryEqual(registryX, registryY Registry) bool {
for name, pluginFactory := range registryY { for name, pluginFactory := range registryY {
if val, ok := registryX[name]; ok { if val, ok := registryX[name]; ok {
p1, _ := pluginFactory(nil, nil) p1, _ := pluginFactory(nil, nil, nil)
p2, _ := val(nil, nil) p2, _ := val(nil, nil, nil)
if p1.Name() != p2.Name() { if p1.Name() != p2.Name() {
// pluginFactory functions are not the same. // pluginFactory functions are not the same.
return false return false
@ -110,7 +111,7 @@ func (p *mockNoopPlugin) Name() string {
func NewMockNoopPluginFactory() PluginFactory { func NewMockNoopPluginFactory() PluginFactory {
uuid := uuid.New().String() uuid := uuid.New().String()
return func(_ runtime.Object, _ framework.Handle) (framework.Plugin, error) { return func(_ context.Context, _ runtime.Object, _ framework.Handle) (framework.Plugin, error) {
return &mockNoopPlugin{uuid}, nil return &mockNoopPlugin{uuid}, nil
} }
} }

View File

@ -280,8 +280,8 @@ func (p *fakePlugin) Bind(context.Context, *framework.CycleState, *v1.Pod, strin
return nil return nil
} }
func newFakePlugin(name string) func(object runtime.Object, handle framework.Handle) (framework.Plugin, error) { func newFakePlugin(name string) func(ctx context.Context, object runtime.Object, handle framework.Handle) (framework.Plugin, error) {
return func(_ runtime.Object, _ framework.Handle) (framework.Plugin, error) { return func(_ context.Context, _ runtime.Object, _ framework.Handle) (framework.Plugin, error) {
return &fakePlugin{name: name}, nil return &fakePlugin{name: name}, nil
} }
} }

View File

@ -141,7 +141,7 @@ func (f *fakeExtender) IsInterested(pod *v1.Pod) bool {
type falseMapPlugin struct{} type falseMapPlugin struct{}
func newFalseMapPlugin() frameworkruntime.PluginFactory { func newFalseMapPlugin() frameworkruntime.PluginFactory {
return func(_ runtime.Object, _ framework.Handle) (framework.Plugin, error) { return func(_ context.Context, _ runtime.Object, _ framework.Handle) (framework.Plugin, error) {
return &falseMapPlugin{}, nil return &falseMapPlugin{}, nil
} }
} }
@ -161,7 +161,7 @@ func (pl *falseMapPlugin) ScoreExtensions() framework.ScoreExtensions {
type numericMapPlugin struct{} type numericMapPlugin struct{}
func newNumericMapPlugin() frameworkruntime.PluginFactory { func newNumericMapPlugin() frameworkruntime.PluginFactory {
return func(_ runtime.Object, _ framework.Handle) (framework.Plugin, error) { return func(_ context.Context, _ runtime.Object, _ framework.Handle) (framework.Plugin, error) {
return &numericMapPlugin{}, nil return &numericMapPlugin{}, nil
} }
} }
@ -183,7 +183,7 @@ func (pl *numericMapPlugin) ScoreExtensions() framework.ScoreExtensions {
} }
// NewNoPodsFilterPlugin initializes a noPodsFilterPlugin and returns it. // NewNoPodsFilterPlugin initializes a noPodsFilterPlugin and returns it.
func NewNoPodsFilterPlugin(_ runtime.Object, _ framework.Handle) (framework.Plugin, error) { func NewNoPodsFilterPlugin(_ context.Context, _ runtime.Object, _ framework.Handle) (framework.Plugin, error) {
return &noPodsFilterPlugin{}, nil return &noPodsFilterPlugin{}, nil
} }
@ -223,7 +223,7 @@ func (pl *reverseNumericMapPlugin) NormalizeScore(_ context.Context, _ *framewor
} }
func newReverseNumericMapPlugin() frameworkruntime.PluginFactory { func newReverseNumericMapPlugin() frameworkruntime.PluginFactory {
return func(_ runtime.Object, _ framework.Handle) (framework.Plugin, error) { return func(_ context.Context, _ runtime.Object, _ framework.Handle) (framework.Plugin, error) {
return &reverseNumericMapPlugin{}, nil return &reverseNumericMapPlugin{}, nil
} }
} }
@ -252,7 +252,7 @@ func (pl *trueMapPlugin) NormalizeScore(_ context.Context, _ *framework.CycleSta
} }
func newTrueMapPlugin() frameworkruntime.PluginFactory { func newTrueMapPlugin() frameworkruntime.PluginFactory {
return func(_ runtime.Object, _ framework.Handle) (framework.Plugin, error) { return func(_ context.Context, _ runtime.Object, _ framework.Handle) (framework.Plugin, error) {
return &trueMapPlugin{}, nil return &trueMapPlugin{}, nil
} }
} }
@ -291,7 +291,7 @@ func (s *fakeNodeSelector) Filter(_ context.Context, _ *framework.CycleState, _
return nil return nil
} }
func newFakeNodeSelector(args runtime.Object, _ framework.Handle) (framework.Plugin, error) { func newFakeNodeSelector(_ context.Context, args runtime.Object, _ framework.Handle) (framework.Plugin, error) {
pl := &fakeNodeSelector{} pl := &fakeNodeSelector{}
if err := frameworkruntime.DecodeInto(args, &pl.fakeNodeSelectorArgs); err != nil { if err := frameworkruntime.DecodeInto(args, &pl.fakeNodeSelectorArgs); err != nil {
return nil, err return nil, err
@ -333,7 +333,7 @@ func (f *fakeNodeSelectorDependOnPodAnnotation) Filter(_ context.Context, _ *fra
return nil return nil
} }
func newFakeNodeSelectorDependOnPodAnnotation(_ runtime.Object, _ framework.Handle) (framework.Plugin, error) { func newFakeNodeSelectorDependOnPodAnnotation(_ context.Context, _ runtime.Object, _ framework.Handle) (framework.Plugin, error) {
return &fakeNodeSelectorDependOnPodAnnotation{}, nil return &fakeNodeSelectorDependOnPodAnnotation{}, nil
} }
@ -2257,7 +2257,7 @@ func TestSchedulerSchedulePod(t *testing.T) {
"node1": framework.Unschedulable, "node1": framework.Unschedulable,
}), }),
), ),
tf.RegisterPluginAsExtensions("FakeFilter2", func(configuration runtime.Object, f framework.Handle) (framework.Plugin, error) { tf.RegisterPluginAsExtensions("FakeFilter2", func(_ context.Context, configuration runtime.Object, f framework.Handle) (framework.Plugin, error) {
return tf.FakePreFilterAndFilterPlugin{ return tf.FakePreFilterAndFilterPlugin{
FakePreFilterPlugin: &tf.FakePreFilterPlugin{ FakePreFilterPlugin: &tf.FakePreFilterPlugin{
Result: nil, Result: nil,
@ -2488,7 +2488,7 @@ func TestFindFitPredicateCallCounts(t *testing.T) {
plugin := tf.FakeFilterPlugin{} plugin := tf.FakeFilterPlugin{}
registerFakeFilterFunc := tf.RegisterFilterPlugin( registerFakeFilterFunc := tf.RegisterFilterPlugin(
"FakeFilter", "FakeFilter",
func(_ runtime.Object, fh framework.Handle) (framework.Plugin, error) { func(_ context.Context, _ runtime.Object, fh framework.Handle) (framework.Plugin, error) {
return &plugin, nil return &plugin, nil
}, },
) )
@ -3063,7 +3063,7 @@ func TestPreferNominatedNodeFilterCallCounts(t *testing.T) {
plugin := tf.FakeFilterPlugin{FailedNodeReturnCodeMap: test.nodeReturnCodeMap} plugin := tf.FakeFilterPlugin{FailedNodeReturnCodeMap: test.nodeReturnCodeMap}
registerFakeFilterFunc := tf.RegisterFilterPlugin( registerFakeFilterFunc := tf.RegisterFilterPlugin(
"FakeFilter", "FakeFilter",
func(_ runtime.Object, fh framework.Handle) (framework.Plugin, error) { func(_ context.Context, _ runtime.Object, fh framework.Handle) (framework.Plugin, error) {
return &plugin, nil return &plugin, nil
}, },
) )
@ -3279,7 +3279,7 @@ func setupTestSchedulerWithVolumeBinding(ctx context.Context, t *testing.T, volu
fns := []tf.RegisterPluginFunc{ fns := []tf.RegisterPluginFunc{
tf.RegisterQueueSortPlugin(queuesort.Name, queuesort.New), tf.RegisterQueueSortPlugin(queuesort.Name, queuesort.New),
tf.RegisterBindPlugin(defaultbinder.Name, defaultbinder.New), tf.RegisterBindPlugin(defaultbinder.Name, defaultbinder.New),
tf.RegisterPluginAsExtensions(volumebinding.Name, func(plArgs runtime.Object, handle framework.Handle) (framework.Plugin, error) { tf.RegisterPluginAsExtensions(volumebinding.Name, func(ctx context.Context, plArgs runtime.Object, handle framework.Handle) (framework.Plugin, error) {
return &volumebinding.VolumeBinding{Binder: volumeBinder, PVCLister: pvcInformer.Lister()}, nil return &volumebinding.VolumeBinding{Binder: volumeBinder, PVCLister: pvcInformer.Lister()}, nil
}, "PreFilter", "Filter", "Reserve", "PreBind"), }, "PreFilter", "Filter", "Reserve", "PreBind"),
} }

View File

@ -538,7 +538,7 @@ func TestInitPluginsWithIndexers(t *testing.T) {
{ {
name: "register indexer, no conflicts", name: "register indexer, no conflicts",
entrypoints: map[string]frameworkruntime.PluginFactory{ entrypoints: map[string]frameworkruntime.PluginFactory{
"AddIndexer": func(obj runtime.Object, handle framework.Handle) (framework.Plugin, error) { "AddIndexer": func(ctx context.Context, obj runtime.Object, handle framework.Handle) (framework.Plugin, error) {
podInformer := handle.SharedInformerFactory().Core().V1().Pods() podInformer := handle.SharedInformerFactory().Core().V1().Pods()
err := podInformer.Informer().GetIndexer().AddIndexers(cache.Indexers{ err := podInformer.Informer().GetIndexer().AddIndexers(cache.Indexers{
"nodeName": indexByPodSpecNodeName, "nodeName": indexByPodSpecNodeName,
@ -551,14 +551,14 @@ func TestInitPluginsWithIndexers(t *testing.T) {
name: "register the same indexer name multiple times, conflict", name: "register the same indexer name multiple times, conflict",
// order of registration doesn't matter // order of registration doesn't matter
entrypoints: map[string]frameworkruntime.PluginFactory{ entrypoints: map[string]frameworkruntime.PluginFactory{
"AddIndexer1": func(obj runtime.Object, handle framework.Handle) (framework.Plugin, error) { "AddIndexer1": func(ctx context.Context, obj runtime.Object, handle framework.Handle) (framework.Plugin, error) {
podInformer := handle.SharedInformerFactory().Core().V1().Pods() podInformer := handle.SharedInformerFactory().Core().V1().Pods()
err := podInformer.Informer().GetIndexer().AddIndexers(cache.Indexers{ err := podInformer.Informer().GetIndexer().AddIndexers(cache.Indexers{
"nodeName": indexByPodSpecNodeName, "nodeName": indexByPodSpecNodeName,
}) })
return &TestPlugin{name: "AddIndexer1"}, err return &TestPlugin{name: "AddIndexer1"}, err
}, },
"AddIndexer2": func(obj runtime.Object, handle framework.Handle) (framework.Plugin, error) { "AddIndexer2": func(ctx context.Context, obj runtime.Object, handle framework.Handle) (framework.Plugin, error) {
podInformer := handle.SharedInformerFactory().Core().V1().Pods() podInformer := handle.SharedInformerFactory().Core().V1().Pods()
err := podInformer.Informer().GetIndexer().AddIndexers(cache.Indexers{ err := podInformer.Informer().GetIndexer().AddIndexers(cache.Indexers{
"nodeName": indexByPodAnnotationNodeName, "nodeName": indexByPodAnnotationNodeName,
@ -572,14 +572,14 @@ func TestInitPluginsWithIndexers(t *testing.T) {
name: "register the same indexer body with different names, no conflicts", name: "register the same indexer body with different names, no conflicts",
// order of registration doesn't matter // order of registration doesn't matter
entrypoints: map[string]frameworkruntime.PluginFactory{ entrypoints: map[string]frameworkruntime.PluginFactory{
"AddIndexer1": func(obj runtime.Object, handle framework.Handle) (framework.Plugin, error) { "AddIndexer1": func(ctx context.Context, obj runtime.Object, handle framework.Handle) (framework.Plugin, error) {
podInformer := handle.SharedInformerFactory().Core().V1().Pods() podInformer := handle.SharedInformerFactory().Core().V1().Pods()
err := podInformer.Informer().GetIndexer().AddIndexers(cache.Indexers{ err := podInformer.Informer().GetIndexer().AddIndexers(cache.Indexers{
"nodeName1": indexByPodSpecNodeName, "nodeName1": indexByPodSpecNodeName,
}) })
return &TestPlugin{name: "AddIndexer1"}, err return &TestPlugin{name: "AddIndexer1"}, err
}, },
"AddIndexer2": func(obj runtime.Object, handle framework.Handle) (framework.Plugin, error) { "AddIndexer2": func(ctx context.Context, obj runtime.Object, handle framework.Handle) (framework.Plugin, error) {
podInformer := handle.SharedInformerFactory().Core().V1().Pods() podInformer := handle.SharedInformerFactory().Core().V1().Pods()
err := podInformer.Informer().GetIndexer().AddIndexers(cache.Indexers{ err := podInformer.Informer().GetIndexer().AddIndexers(cache.Indexers{
"nodeName2": indexByPodAnnotationNodeName, "nodeName2": indexByPodAnnotationNodeName,
@ -819,13 +819,15 @@ func Test_buildQueueingHintMap(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
defer featuregatetesting.SetFeatureGateDuringTest(t, utilfeature.DefaultFeatureGate, features.SchedulerQueueingHints, !tt.featuregateDisabled)() defer featuregatetesting.SetFeatureGateDuringTest(t, utilfeature.DefaultFeatureGate, features.SchedulerQueueingHints, !tt.featuregateDisabled)()
logger, _ := ktesting.NewTestContext(t) logger, ctx := ktesting.NewTestContext(t)
ctx, cancel := context.WithCancel(ctx)
defer cancel()
registry := frameworkruntime.Registry{} registry := frameworkruntime.Registry{}
cfgPls := &schedulerapi.Plugins{} cfgPls := &schedulerapi.Plugins{}
plugins := append(tt.plugins, &fakebindPlugin{}, &fakeQueueSortPlugin{}) plugins := append(tt.plugins, &fakebindPlugin{}, &fakeQueueSortPlugin{})
for _, pl := range plugins { for _, pl := range plugins {
tmpPl := pl tmpPl := pl
if err := registry.Register(pl.Name(), func(_ runtime.Object, _ framework.Handle) (framework.Plugin, error) { if err := registry.Register(pl.Name(), func(_ context.Context, _ runtime.Object, _ framework.Handle) (framework.Plugin, error) {
return tmpPl, nil return tmpPl, nil
}); err != nil { }); err != nil {
t.Fatalf("fail to register filter plugin (%s)", pl.Name()) t.Fatalf("fail to register filter plugin (%s)", pl.Name())
@ -834,9 +836,7 @@ func Test_buildQueueingHintMap(t *testing.T) {
} }
profile := schedulerapi.KubeSchedulerProfile{Plugins: cfgPls} profile := schedulerapi.KubeSchedulerProfile{Plugins: cfgPls}
stopCh := make(chan struct{}) fwk, err := newFramework(ctx, registry, profile)
defer close(stopCh)
fwk, err := newFramework(registry, profile, stopCh)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -1010,13 +1010,16 @@ func Test_UnionedGVKs(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
_, ctx := ktesting.NewTestContext(t)
ctx, cancel := context.WithCancel(ctx)
defer cancel()
registry := plugins.NewInTreeRegistry() registry := plugins.NewInTreeRegistry()
cfgPls := &schedulerapi.Plugins{MultiPoint: tt.plugins} cfgPls := &schedulerapi.Plugins{MultiPoint: tt.plugins}
plugins := []framework.Plugin{&fakeNodePlugin{}, &fakePodPlugin{}, &fakeNoopPlugin{}, &fakeNoopRuntimePlugin{}, &fakeQueueSortPlugin{}, &fakebindPlugin{}} plugins := []framework.Plugin{&fakeNodePlugin{}, &fakePodPlugin{}, &fakeNoopPlugin{}, &fakeNoopRuntimePlugin{}, &fakeQueueSortPlugin{}, &fakebindPlugin{}}
for _, pl := range plugins { for _, pl := range plugins {
tmpPl := pl tmpPl := pl
if err := registry.Register(pl.Name(), func(_ runtime.Object, _ framework.Handle) (framework.Plugin, error) { if err := registry.Register(pl.Name(), func(_ context.Context, _ runtime.Object, _ framework.Handle) (framework.Plugin, error) {
return tmpPl, nil return tmpPl, nil
}); err != nil { }); err != nil {
t.Fatalf("fail to register filter plugin (%s)", pl.Name()) t.Fatalf("fail to register filter plugin (%s)", pl.Name())
@ -1024,9 +1027,7 @@ func Test_UnionedGVKs(t *testing.T) {
} }
profile := schedulerapi.KubeSchedulerProfile{Plugins: cfgPls, PluginConfig: defaults.PluginConfigsV1} profile := schedulerapi.KubeSchedulerProfile{Plugins: cfgPls, PluginConfig: defaults.PluginConfigsV1}
stopCh := make(chan struct{}) fwk, err := newFramework(ctx, registry, profile)
defer close(stopCh)
fwk, err := newFramework(registry, profile, stopCh)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -1043,8 +1044,8 @@ func Test_UnionedGVKs(t *testing.T) {
} }
} }
func newFramework(r frameworkruntime.Registry, profile schedulerapi.KubeSchedulerProfile, stopCh <-chan struct{}) (framework.Framework, error) { func newFramework(ctx context.Context, r frameworkruntime.Registry, profile schedulerapi.KubeSchedulerProfile) (framework.Framework, error) {
return frameworkruntime.NewFramework(context.Background(), r, &profile, return frameworkruntime.NewFramework(ctx, r, &profile,
frameworkruntime.WithSnapshotSharedLister(internalcache.NewSnapshot(nil, nil)), frameworkruntime.WithSnapshotSharedLister(internalcache.NewSnapshot(nil, nil)),
frameworkruntime.WithInformerFactory(informers.NewSharedInformerFactory(fake.NewSimpleClientset(), 0)), frameworkruntime.WithInformerFactory(informers.NewSharedInformerFactory(fake.NewSimpleClientset(), 0)),
) )

View File

@ -112,7 +112,7 @@ type node2PrioritizerPlugin struct{}
// NewNode2PrioritizerPlugin returns a factory function to build node2PrioritizerPlugin. // NewNode2PrioritizerPlugin returns a factory function to build node2PrioritizerPlugin.
func NewNode2PrioritizerPlugin() frameworkruntime.PluginFactory { func NewNode2PrioritizerPlugin() frameworkruntime.PluginFactory {
return func(_ runtime.Object, _ framework.Handle) (framework.Plugin, error) { return func(_ context.Context, _ runtime.Object, _ framework.Handle) (framework.Plugin, error) {
return &node2PrioritizerPlugin{}, nil return &node2PrioritizerPlugin{}, nil
} }
} }

View File

@ -45,7 +45,7 @@ func (pl *FalseFilterPlugin) Filter(_ context.Context, _ *framework.CycleState,
} }
// NewFalseFilterPlugin initializes a FalseFilterPlugin and returns it. // NewFalseFilterPlugin initializes a FalseFilterPlugin and returns it.
func NewFalseFilterPlugin(_ runtime.Object, _ framework.Handle) (framework.Plugin, error) { func NewFalseFilterPlugin(_ context.Context, _ runtime.Object, _ framework.Handle) (framework.Plugin, error) {
return &FalseFilterPlugin{}, nil return &FalseFilterPlugin{}, nil
} }
@ -63,7 +63,7 @@ func (pl *TrueFilterPlugin) Filter(_ context.Context, _ *framework.CycleState, p
} }
// NewTrueFilterPlugin initializes a TrueFilterPlugin and returns it. // NewTrueFilterPlugin initializes a TrueFilterPlugin and returns it.
func NewTrueFilterPlugin(_ runtime.Object, _ framework.Handle) (framework.Plugin, error) { func NewTrueFilterPlugin(_ context.Context, _ runtime.Object, _ framework.Handle) (framework.Plugin, error) {
return &TrueFilterPlugin{}, nil return &TrueFilterPlugin{}, nil
} }
@ -102,7 +102,7 @@ func (pl *FakeFilterPlugin) Filter(_ context.Context, _ *framework.CycleState, p
// NewFakeFilterPlugin initializes a fakeFilterPlugin and returns it. // NewFakeFilterPlugin initializes a fakeFilterPlugin and returns it.
func NewFakeFilterPlugin(failedNodeReturnCodeMap map[string]framework.Code) frameworkruntime.PluginFactory { func NewFakeFilterPlugin(failedNodeReturnCodeMap map[string]framework.Code) frameworkruntime.PluginFactory {
return func(_ runtime.Object, _ framework.Handle) (framework.Plugin, error) { return func(_ context.Context, _ runtime.Object, _ framework.Handle) (framework.Plugin, error) {
return &FakeFilterPlugin{ return &FakeFilterPlugin{
FailedNodeReturnCodeMap: failedNodeReturnCodeMap, FailedNodeReturnCodeMap: failedNodeReturnCodeMap,
}, nil }, nil
@ -131,7 +131,7 @@ func (pl *MatchFilterPlugin) Filter(_ context.Context, _ *framework.CycleState,
} }
// NewMatchFilterPlugin initializes a MatchFilterPlugin and returns it. // NewMatchFilterPlugin initializes a MatchFilterPlugin and returns it.
func NewMatchFilterPlugin(_ runtime.Object, _ framework.Handle) (framework.Plugin, error) { func NewMatchFilterPlugin(_ context.Context, _ runtime.Object, _ framework.Handle) (framework.Plugin, error) {
return &MatchFilterPlugin{}, nil return &MatchFilterPlugin{}, nil
} }
@ -159,7 +159,7 @@ func (pl *FakePreFilterPlugin) PreFilterExtensions() framework.PreFilterExtensio
// NewFakePreFilterPlugin initializes a fakePreFilterPlugin and returns it. // NewFakePreFilterPlugin initializes a fakePreFilterPlugin and returns it.
func NewFakePreFilterPlugin(name string, result *framework.PreFilterResult, status *framework.Status) frameworkruntime.PluginFactory { func NewFakePreFilterPlugin(name string, result *framework.PreFilterResult, status *framework.Status) frameworkruntime.PluginFactory {
return func(_ runtime.Object, _ framework.Handle) (framework.Plugin, error) { return func(_ context.Context, _ runtime.Object, _ framework.Handle) (framework.Plugin, error) {
return &FakePreFilterPlugin{ return &FakePreFilterPlugin{
Result: result, Result: result,
Status: status, Status: status,
@ -189,7 +189,7 @@ func (pl *FakeReservePlugin) Unreserve(_ context.Context, _ *framework.CycleStat
// NewFakeReservePlugin initializes a fakeReservePlugin and returns it. // NewFakeReservePlugin initializes a fakeReservePlugin and returns it.
func NewFakeReservePlugin(status *framework.Status) frameworkruntime.PluginFactory { func NewFakeReservePlugin(status *framework.Status) frameworkruntime.PluginFactory {
return func(_ runtime.Object, _ framework.Handle) (framework.Plugin, error) { return func(_ context.Context, _ runtime.Object, _ framework.Handle) (framework.Plugin, error) {
return &FakeReservePlugin{ return &FakeReservePlugin{
Status: status, Status: status,
}, nil }, nil
@ -213,7 +213,7 @@ func (pl *FakePreBindPlugin) PreBind(_ context.Context, _ *framework.CycleState,
// NewFakePreBindPlugin initializes a fakePreBindPlugin and returns it. // NewFakePreBindPlugin initializes a fakePreBindPlugin and returns it.
func NewFakePreBindPlugin(status *framework.Status) frameworkruntime.PluginFactory { func NewFakePreBindPlugin(status *framework.Status) frameworkruntime.PluginFactory {
return func(_ runtime.Object, _ framework.Handle) (framework.Plugin, error) { return func(_ context.Context, _ runtime.Object, _ framework.Handle) (framework.Plugin, error) {
return &FakePreBindPlugin{ return &FakePreBindPlugin{
Status: status, Status: status,
}, nil }, nil
@ -238,7 +238,7 @@ func (pl *FakePermitPlugin) Permit(_ context.Context, _ *framework.CycleState, _
// NewFakePermitPlugin initializes a fakePermitPlugin and returns it. // NewFakePermitPlugin initializes a fakePermitPlugin and returns it.
func NewFakePermitPlugin(status *framework.Status, timeout time.Duration) frameworkruntime.PluginFactory { func NewFakePermitPlugin(status *framework.Status, timeout time.Duration) frameworkruntime.PluginFactory {
return func(_ runtime.Object, _ framework.Handle) (framework.Plugin, error) { return func(_ context.Context, _ runtime.Object, _ framework.Handle) (framework.Plugin, error) {
return &FakePermitPlugin{ return &FakePermitPlugin{
Status: status, Status: status,
Timeout: timeout, Timeout: timeout,
@ -271,7 +271,7 @@ func (pl *FakePreScoreAndScorePlugin) PreScore(ctx context.Context, state *frame
} }
func NewFakePreScoreAndScorePlugin(name string, score int64, preScoreStatus, scoreStatus *framework.Status) frameworkruntime.PluginFactory { func NewFakePreScoreAndScorePlugin(name string, score int64, preScoreStatus, scoreStatus *framework.Status) frameworkruntime.PluginFactory {
return func(_ runtime.Object, _ framework.Handle) (framework.Plugin, error) { return func(_ context.Context, _ runtime.Object, _ framework.Handle) (framework.Plugin, error) {
return &FakePreScoreAndScorePlugin{ return &FakePreScoreAndScorePlugin{
name: name, name: name,
score: score, score: score,

View File

@ -62,7 +62,7 @@ var (
// newPlugin returns a plugin factory with specified Plugin. // newPlugin returns a plugin factory with specified Plugin.
func newPlugin(plugin framework.Plugin) frameworkruntime.PluginFactory { func newPlugin(plugin framework.Plugin) frameworkruntime.PluginFactory {
return func(_ runtime.Object, fh framework.Handle) (framework.Plugin, error) { return func(_ context.Context, _ runtime.Object, fh framework.Handle) (framework.Plugin, error) {
switch pl := plugin.(type) { switch pl := plugin.(type) {
case *PermitPlugin: case *PermitPlugin:
pl.fh = fh pl.fh = fh
@ -2518,7 +2518,7 @@ func (j *JobPlugin) PostBind(_ context.Context, state *framework.CycleState, p *
func TestActivatePods(t *testing.T) { func TestActivatePods(t *testing.T) {
var jobPlugin *JobPlugin var jobPlugin *JobPlugin
// Create a plugin registry for testing. Register a Job plugin. // Create a plugin registry for testing. Register a Job plugin.
registry := frameworkruntime.Registry{jobPluginName: func(_ runtime.Object, fh framework.Handle) (framework.Plugin, error) { registry := frameworkruntime.Registry{jobPluginName: func(_ context.Context, _ runtime.Object, fh framework.Handle) (framework.Plugin, error) {
jobPlugin = &JobPlugin{podLister: fh.SharedInformerFactory().Core().V1().Pods().Lister()} jobPlugin = &JobPlugin{podLister: fh.SharedInformerFactory().Core().V1().Pods().Lister()}
return jobPlugin, nil return jobPlugin, nil
}} }}

View File

@ -159,7 +159,7 @@ func TestPreemption(t *testing.T) {
// Initialize scheduler with a filter plugin. // Initialize scheduler with a filter plugin.
var filter tokenFilter var filter tokenFilter
registry := make(frameworkruntime.Registry) registry := make(frameworkruntime.Registry)
err := registry.Register(filterPluginName, func(_ runtime.Object, fh framework.Handle) (framework.Plugin, error) { err := registry.Register(filterPluginName, func(_ context.Context, _ runtime.Object, fh framework.Handle) (framework.Plugin, error) {
return &filter, nil return &filter, nil
}) })
if err != nil { if err != nil {
@ -1046,7 +1046,7 @@ func (af *alwaysFail) PreBind(_ context.Context, _ *framework.CycleState, p *v1.
return framework.NewStatus(framework.Unschedulable) return framework.NewStatus(framework.Unschedulable)
} }
func newAlwaysFail(_ runtime.Object, _ framework.Handle) (framework.Plugin, error) { func newAlwaysFail(_ context.Context, _ runtime.Object, _ framework.Handle) (framework.Plugin, error) {
return &alwaysFail{}, nil return &alwaysFail{}, nil
} }

View File

@ -340,7 +340,7 @@ func TestCustomResourceEnqueue(t *testing.T) {
} }
registry := frameworkruntime.Registry{ registry := frameworkruntime.Registry{
"fakeCRPlugin": func(_ runtime.Object, fh framework.Handle) (framework.Plugin, error) { "fakeCRPlugin": func(_ context.Context, _ runtime.Object, fh framework.Handle) (framework.Plugin, error) {
return &fakeCRPlugin{}, nil return &fakeCRPlugin{}, nil
}, },
} }
@ -447,8 +447,8 @@ func TestCustomResourceEnqueue(t *testing.T) {
func TestRequeueByBindFailure(t *testing.T) { func TestRequeueByBindFailure(t *testing.T) {
fakeBind := &firstFailBindPlugin{} fakeBind := &firstFailBindPlugin{}
registry := frameworkruntime.Registry{ registry := frameworkruntime.Registry{
"firstFailBindPlugin": func(o runtime.Object, fh framework.Handle) (framework.Plugin, error) { "firstFailBindPlugin": func(ctx context.Context, o runtime.Object, fh framework.Handle) (framework.Plugin, error) {
binder, err := defaultbinder.New(nil, fh) binder, err := defaultbinder.New(ctx, nil, fh)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -539,7 +539,7 @@ func TestRequeueByPermitRejection(t *testing.T) {
queueingHintCalledCounter := 0 queueingHintCalledCounter := 0
fakePermit := &fakePermitPlugin{} fakePermit := &fakePermitPlugin{}
registry := frameworkruntime.Registry{ registry := frameworkruntime.Registry{
fakePermitPluginName: func(o runtime.Object, fh framework.Handle) (framework.Plugin, error) { fakePermitPluginName: func(ctx context.Context, o runtime.Object, fh framework.Handle) (framework.Plugin, error) {
fakePermit = &fakePermitPlugin{ fakePermit = &fakePermitPlugin{
frameworkHandler: fh, frameworkHandler: fh,
schedulingHint: func(logger klog.Logger, pod *v1.Pod, oldObj, newObj interface{}) framework.QueueingHint { schedulingHint: func(logger klog.Logger, pod *v1.Pod, oldObj, newObj interface{}) framework.QueueingHint {

View File

@ -17,6 +17,7 @@ limitations under the License.
package scheduler package scheduler
import ( import (
"context"
"testing" "testing"
"time" "time"
@ -81,7 +82,7 @@ func InitTestSchedulerForFrameworkTest(t *testing.T, testCtx *testutils.TestCont
// NewPlugin returns a plugin factory with specified Plugin. // NewPlugin returns a plugin factory with specified Plugin.
func NewPlugin(plugin framework.Plugin) frameworkruntime.PluginFactory { func NewPlugin(plugin framework.Plugin) frameworkruntime.PluginFactory {
return func(_ runtime.Object, fh framework.Handle) (framework.Plugin, error) { return func(_ context.Context, _ runtime.Object, fh framework.Handle) (framework.Plugin, error) {
return plugin, nil return plugin, nil
} }
} }