diff --git a/pkg/controller/statefulset/stateful_set_utils.go b/pkg/controller/statefulset/stateful_set_utils.go index 3c470fac43f..27f72e0b50e 100644 --- a/pkg/controller/statefulset/stateful_set_utils.go +++ b/pkg/controller/statefulset/stateful_set_utils.go @@ -344,11 +344,12 @@ func ApplyRevision(set *apps.StatefulSet, revision *apps.ControllerRevision) (*a if err != nil { return nil, err } - err = json.Unmarshal(patched, clone) + restoredSet := &apps.StatefulSet{} + err = json.Unmarshal(patched, restoredSet) if err != nil { return nil, err } - return clone, nil + return restoredSet, nil } // nextRevision finds the next valid revision number based on revisions. If the length of revisions diff --git a/pkg/controller/statefulset/stateful_set_utils_test.go b/pkg/controller/statefulset/stateful_set_utils_test.go index ab4e9f04c6f..8dee5cadea6 100644 --- a/pkg/controller/statefulset/stateful_set_utils_test.go +++ b/pkg/controller/statefulset/stateful_set_utils_test.go @@ -289,6 +289,39 @@ func TestCreateApplyRevision(t *testing.T) { } } +func TestRollingUpdateApplyRevision(t *testing.T) { + set := newStatefulSet(1) + set.Status.CollisionCount = new(int32) + currentSet := set.DeepCopy() + currentRevision, err := newRevision(set, 1, set.Status.CollisionCount) + if err != nil { + t.Fatal(err) + } + + set.Spec.Template.Spec.Containers[0].Env = []v1.EnvVar{{Name: "foo", Value: "bar"}} + updateSet := set.DeepCopy() + updateRevision, err := newRevision(set, 2, set.Status.CollisionCount) + if err != nil { + t.Fatal(err) + } + + restoredCurrentSet, err := ApplyRevision(set, currentRevision) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(currentSet.Spec.Template, restoredCurrentSet.Spec.Template) { + t.Errorf("want %v got %v", currentSet.Spec.Template, restoredCurrentSet.Spec.Template) + } + + restoredUpdateSet, err := ApplyRevision(set, updateRevision) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(updateSet.Spec.Template, restoredUpdateSet.Spec.Template) { + t.Errorf("want %v got %v", updateSet.Spec.Template, restoredUpdateSet.Spec.Template) + } +} + func TestGetPersistentVolumeClaims(t *testing.T) { // nil inherits statefulset labels