diff --git a/pkg/scheduler/framework/cycle_state.go b/pkg/scheduler/framework/cycle_state.go index 4674878377c..7f4a1351258 100644 --- a/pkg/scheduler/framework/cycle_state.go +++ b/pkg/scheduler/framework/cycle_state.go @@ -82,6 +82,7 @@ func (c *CycleState) Clone() *CycleState { copy.storage.Store(k, v.(StateData).Clone()) return true }) + copy.recordPluginMetrics = c.recordPluginMetrics return copy } diff --git a/pkg/scheduler/framework/cycle_state_test.go b/pkg/scheduler/framework/cycle_state_test.go index 37f0d09b7e8..a98fb538b13 100644 --- a/pkg/scheduler/framework/cycle_state_test.go +++ b/pkg/scheduler/framework/cycle_state_test.go @@ -17,6 +17,7 @@ limitations under the License. package framework import ( + "fmt" "testing" ) @@ -31,44 +32,120 @@ func (f *fakeData) Clone() StateData { return copy } +var key StateKey = "fakedata_key" + +// createCycleStateWithFakeData creates *CycleState with fakeData. +// The given data is used in stored fakeData. +func createCycleStateWithFakeData(data string, recordPluginMetrics bool) *CycleState { + c := NewCycleState() + c.Write(key, &fakeData{ + data: data, + }) + c.SetRecordPluginMetrics(recordPluginMetrics) + return c +} + +// isCycleStateEqual returns whether two CycleState, which has fakeData in storage, is equal or not. +// And if they are not equal, returns message which shows why not equal. +func isCycleStateEqual(a, b *CycleState) (bool, string) { + if a == nil && b == nil { + return true, "" + } + if a == nil || b == nil { + return false, fmt.Sprintf("one CycleState is nil, but another one is not nil. A: %v, B: %v", a, b) + } + + if a.recordPluginMetrics != b.recordPluginMetrics { + return false, fmt.Sprintf("CycleState A and B have a different recordPluginMetrics. A: %v, B: %v", a.recordPluginMetrics, b.recordPluginMetrics) + } + + var msg string + isEqual := true + countA := 0 + a.storage.Range(func(k, v1 interface{}) bool { + countA++ + v2, ok := b.storage.Load(k) + if !ok { + isEqual = false + msg = fmt.Sprintf("CycleState B doesn't have the data which CycleState A has. key: %v, data: %v", k, v1) + return false + } + + typed1, ok1 := v1.(*fakeData) + typed2, ok2 := v2.(*fakeData) + if !ok1 || !ok2 { + isEqual = false + msg = fmt.Sprintf("CycleState has the data which is not type *fakeData.") + return false + } + + if typed1.data != typed2.data { + isEqual = false + msg = fmt.Sprintf("CycleState B has a different data on key %v. A: %v, B: %v", k, typed1.data, typed2.data) + return false + } + + return true + }) + + if !isEqual { + return false, msg + } + + countB := 0 + b.storage.Range(func(k, _ interface{}) bool { + countB++ + return true + }) + + if countA != countB { + return false, fmt.Sprintf("two Cyclestates have different numbers of data. A: %v, B: %v", countA, countB) + } + + return true, "" +} + func TestCycleStateClone(t *testing.T) { - var key StateKey = "key" - data1 := "value1" - data2 := "value2" - - state := NewCycleState() - originalValue := &fakeData{ - data: data1, - } - state.Write(key, originalValue) - stateCopy := state.Clone() - - valueCopy, err := stateCopy.Read(key) - if err != nil { - t.Errorf("failed to read copied value: %v", err) - } - if v, ok := valueCopy.(*fakeData); ok && v.data != data1 { - t.Errorf("clone failed, got %q, expected %q", v.data, data1) + tests := []struct { + name string + state *CycleState + wantClonedState *CycleState + }{ + { + name: "clone with recordPluginMetrics true", + state: createCycleStateWithFakeData("data", true), + wantClonedState: createCycleStateWithFakeData("data", true), + }, + { + name: "clone with recordPluginMetrics false", + state: createCycleStateWithFakeData("data", false), + wantClonedState: createCycleStateWithFakeData("data", false), + }, + { + name: "clone with nil CycleState", + state: nil, + wantClonedState: nil, + }, } - originalValue.data = data2 - original, err := state.Read(key) - if err != nil { - t.Errorf("failed to read original value: %v", err) - } - if v, ok := original.(*fakeData); ok && v.data != data2 { - t.Errorf("original value should change, got %q, expected %q", v.data, data2) - } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + state := tt.state + stateCopy := state.Clone() - if v, ok := valueCopy.(*fakeData); ok && v.data != data1 { - t.Errorf("cloned copy should not change, got %q, expected %q", v.data, data1) - } -} - -func TestCycleStateCloneNil(t *testing.T) { - var state *CycleState - stateCopy := state.Clone() - if stateCopy != nil { - t.Errorf("clone expected to be nil") + if isEqual, msg := isCycleStateEqual(stateCopy, tt.wantClonedState); !isEqual { + t.Errorf("unexpected cloned state: %v", msg) + } + + if state == nil || stateCopy == nil { + // not need to run the rest check in this case. + return + } + + stateCopy.Write(key, &fakeData{data: "modified"}) + if isEqual, _ := isCycleStateEqual(state, stateCopy); isEqual { + t.Errorf("the change for a cloned state should not affect the original state.") + } + }) } }