diff --git a/pkg/scheduler/framework/BUILD b/pkg/scheduler/framework/BUILD index af8bc8b75a5..bfb4637ccee 100644 --- a/pkg/scheduler/framework/BUILD +++ b/pkg/scheduler/framework/BUILD @@ -27,6 +27,8 @@ go_library( "//staging/src/k8s.io/client-go/kubernetes:go_default_library", "//staging/src/k8s.io/client-go/tools/events:go_default_library", "//staging/src/k8s.io/kube-scheduler/extender/v1:go_default_library", + "//vendor/github.com/google/go-cmp/cmp:go_default_library", + "//vendor/github.com/google/go-cmp/cmp/cmpopts:go_default_library", "//vendor/k8s.io/klog/v2:go_default_library", ], ) @@ -63,5 +65,6 @@ go_test( "//staging/src/k8s.io/apimachinery/pkg/api/resource:go_default_library", "//staging/src/k8s.io/apimachinery/pkg/apis/meta/v1:go_default_library", "//staging/src/k8s.io/apimachinery/pkg/types:go_default_library", + "//vendor/github.com/google/go-cmp/cmp:go_default_library", ], ) diff --git a/pkg/scheduler/framework/interface.go b/pkg/scheduler/framework/interface.go index fbabdb7020e..61d52c0713d 100644 --- a/pkg/scheduler/framework/interface.go +++ b/pkg/scheduler/framework/interface.go @@ -25,6 +25,8 @@ import ( "strings" "time" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" v1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/types" "k8s.io/client-go/informers" @@ -180,6 +182,21 @@ func (s *Status) AsError() error { return errors.New(s.Message()) } +// Equal checks equality of two statuses. This is useful for testing with +// cmp.Equal. +func (s *Status) Equal(x *Status) bool { + if s == nil || x == nil { + return s.IsSuccess() && x.IsSuccess() + } + if s.code != x.code { + return false + } + if s.code == Error { + return cmp.Equal(s.err, x.err, cmpopts.EquateErrors()) + } + return cmp.Equal(s.reasons, x.reasons) +} + // NewStatus makes a Status out of the given arguments and returns its pointer. func NewStatus(code Code, reasons ...string) *Status { s := &Status{ diff --git a/pkg/scheduler/framework/interface_test.go b/pkg/scheduler/framework/interface_test.go index 50e8b3bbe56..9064ea3a5ad 100644 --- a/pkg/scheduler/framework/interface_test.go +++ b/pkg/scheduler/framework/interface_test.go @@ -18,9 +18,14 @@ package framework import ( "errors" + "fmt" "testing" + + "github.com/google/go-cmp/cmp" ) +var errorStatus = NewStatus(Error, "internal error") + func TestStatus(t *testing.T) { tests := []struct { name string @@ -133,3 +138,79 @@ func TestPluginToStatusMerge(t *testing.T) { }) } } + +func TestIsStatusEqual(t *testing.T) { + tests := []struct { + name string + x, y *Status + want bool + }{ + { + name: "two nil should be equal", + x: nil, + y: nil, + want: true, + }, + { + name: "nil should be equal to success status", + x: nil, + y: NewStatus(Success), + want: true, + }, + { + name: "nil should not be equal with status except success", + x: nil, + y: NewStatus(Error, "internal error"), + want: false, + }, + { + name: "one status should be equal to itself", + x: errorStatus, + y: errorStatus, + want: true, + }, + { + name: "same type statuses without reasons should be equal", + x: NewStatus(Success), + y: NewStatus(Success), + want: true, + }, + { + name: "statuses with same message should be equal", + x: NewStatus(Unschedulable, "unschedulable"), + y: NewStatus(Unschedulable, "unschedulable"), + want: true, + }, + { + name: "error statuses with same message should not be equal", + x: NewStatus(Error, "error"), + y: NewStatus(Error, "error"), + want: false, + }, + { + name: "statuses with different reasons should not be equal", + x: NewStatus(Unschedulable, "unschedulable"), + y: NewStatus(Unschedulable, "unschedulable", "injected filter status"), + want: false, + }, + { + name: "statuses with different codes should not be equal", + x: NewStatus(Error, "internal error"), + y: NewStatus(Unschedulable, "internal error"), + want: false, + }, + { + name: "wrap error status should be equal with original one", + x: errorStatus, + y: AsStatus(fmt.Errorf("error: %w", errorStatus.AsError())), + want: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := cmp.Equal(tt.x, tt.y); got != tt.want { + t.Errorf("cmp.Equal() = %v, want %v", got, tt.want) + } + }) + } +}