diff --git a/pkg/controller/statefulset/stateful_set_control_test.go b/pkg/controller/statefulset/stateful_set_control_test.go
index 8c77ae1c76f..0c19a554678 100644
--- a/pkg/controller/statefulset/stateful_set_control_test.go
+++ b/pkg/controller/statefulset/stateful_set_control_test.go
@@ -20,6 +20,7 @@ import (
 	"context"
 	"errors"
 	"fmt"
+	"math"
 	"math/rand"
 	"reflect"
 	"runtime"
@@ -2423,51 +2424,53 @@ type requestTracker struct {
 	err      error
 	after    int
 
-	parallelLock sync.Mutex
-	parallel     int
-	maxParallel  int
-
-	delay time.Duration
+	// this block should be updated consistently
+	parallelLock                sync.Mutex
+	shouldTrackParallelRequests bool
+	parallelRequests            int
+	maxParallelRequests         int
+	parallelRequestDelay        time.Duration
 }
 
-func (rt *requestTracker) errorReady() bool {
-	rt.Lock()
-	defer rt.Unlock()
-	return rt.err != nil && rt.requests >= rt.after
-}
-
-func (rt *requestTracker) inc() {
-	rt.parallelLock.Lock()
-	rt.parallel++
-	if rt.maxParallel < rt.parallel {
-		rt.maxParallel = rt.parallel
+func (rt *requestTracker) trackParallelRequests() {
+	if !rt.shouldTrackParallelRequests {
+		// do not track parallel requests unless specifically enabled
+		return
 	}
-	rt.parallelLock.Unlock()
+	if rt.parallelLock.TryLock() {
+		// lock acquired: we are the only or the first concurrent request
+		// initialize the next set of parallel requests
+		rt.parallelRequests = 1
+	} else {
+		// lock is held by other requests
+		// now wait for the lock to increase the parallelRequests
+		rt.parallelLock.Lock()
+		rt.parallelRequests++
+	}
+	defer rt.parallelLock.Unlock()
+	// update the local maximum of parallel collisions
+	if rt.maxParallelRequests < rt.parallelRequests {
+		rt.maxParallelRequests = rt.parallelRequests
+	}
+	// increase the chance of collisions
+	if rt.parallelRequestDelay > 0 {
+		time.Sleep(rt.parallelRequestDelay)
+	}
+}
 
+func (rt *requestTracker) incWithOptionalError() error {
 	rt.Lock()
 	defer rt.Unlock()
 	rt.requests++
-	if rt.delay != 0 {
-		time.Sleep(rt.delay)
+	if rt.err != nil && rt.requests >= rt.after {
+		// reset and pass the error
+		defer func() {
+			rt.err = nil
+			rt.after = 0
+		}()
+		return rt.err
 	}
-}
-
-func (rt *requestTracker) reset() {
-	rt.parallelLock.Lock()
-	rt.parallel = 0
-	rt.parallelLock.Unlock()
-
-	rt.Lock()
-	defer rt.Unlock()
-	rt.err = nil
-	rt.after = 0
-	rt.delay = 0
-}
-
-func (rt *requestTracker) getErr() error {
-	rt.Lock()
-	defer rt.Unlock()
-	return rt.err
+	return nil
 }
 
 func newRequestTracker(requests int, err error, after int) requestTracker {
@@ -2512,10 +2515,9 @@ func newFakeObjectManager(informerFactory informers.SharedInformerFactory) *fake
 }
 
 func (om *fakeObjectManager) CreatePod(ctx context.Context, pod *v1.Pod) error {
-	defer om.createPodTracker.inc()
-	if om.createPodTracker.errorReady() {
-		defer om.createPodTracker.reset()
-		return om.createPodTracker.getErr()
+	defer om.createPodTracker.trackParallelRequests()
+	if err := om.createPodTracker.incWithOptionalError(); err != nil {
+		return err
 	}
 	pod.SetUID(types.UID(pod.Name + "-uid"))
 	return om.podsIndexer.Update(pod)
@@ -2526,19 +2528,17 @@ func (om *fakeObjectManager) GetPod(namespace, podName string) (*v1.Pod, error)
 }
 
 func (om *fakeObjectManager) UpdatePod(pod *v1.Pod) error {
-	defer om.updatePodTracker.inc()
-	if om.updatePodTracker.errorReady() {
-		defer om.updatePodTracker.reset()
-		return om.updatePodTracker.getErr()
+	defer om.updatePodTracker.trackParallelRequests()
+	if err := om.updatePodTracker.incWithOptionalError(); err != nil {
+		return err
 	}
 	return om.podsIndexer.Update(pod)
 }
 
 func (om *fakeObjectManager) DeletePod(pod *v1.Pod) error {
-	defer om.deletePodTracker.inc()
-	if om.deletePodTracker.errorReady() {
-		defer om.deletePodTracker.reset()
-		return om.deletePodTracker.getErr()
+	defer om.deletePodTracker.trackParallelRequests()
+	if err := om.deletePodTracker.incWithOptionalError(); err != nil {
+		return err
 	}
 	if key, err := controller.KeyFunc(pod); err != nil {
 		return err
@@ -2733,10 +2733,9 @@ func newFakeStatefulSetStatusUpdater(setInformer appsinformers.StatefulSetInform
 }
 
 func (ssu *fakeStatefulSetStatusUpdater) UpdateStatefulSetStatus(ctx context.Context, set *apps.StatefulSet, status *apps.StatefulSetStatus) error {
-	defer ssu.updateStatusTracker.inc()
-	if ssu.updateStatusTracker.errorReady() {
-		defer ssu.updateStatusTracker.reset()
-		return ssu.updateStatusTracker.err
+	defer ssu.updateStatusTracker.trackParallelRequests()
+	if err := ssu.updateStatusTracker.incWithOptionalError(); err != nil {
+		return err
 	}
 	set.Status = *status
 	ssu.setsIndexer.Update(set)
@@ -2942,50 +2941,61 @@ func fakeResourceVersion(object interface{}) {
 		obj.SetResourceVersion(strconv.FormatInt(intValue+1, 10))
 	}
 }
-
 func TestParallelScale(t *testing.T) {
 	for _, tc := range []struct {
-		desc            string
-		replicas        int32
-		desiredReplicas int32
+		desc                        string
+		replicas                    int32
+		desiredReplicas             int32
+		expectedMinParallelRequests int
	}{
 		{
-			desc:            "scale up from 3 to 30",
-			replicas:        3,
-			desiredReplicas: 30,
+			desc:                        "scale up from 3 to 30",
+			replicas:                    3,
+			desiredReplicas:             30,
+			expectedMinParallelRequests: 2,
 		},
 		{
-			desc:            "scale down from 10 to 1",
-			replicas:        10,
-			desiredReplicas: 1,
+			desc:                        "scale down from 10 to 1",
+			replicas:                    10,
+			desiredReplicas:             1,
+			expectedMinParallelRequests: 2,
 		},
 		{
-			desc:            "scale down to 0",
-			replicas:        501,
-			desiredReplicas: 0,
+			desc:                        "scale down to 0",
+			replicas:                    501,
+			desiredReplicas:             0,
+			expectedMinParallelRequests: 10,
 		},
 		{
-			desc:            "scale up from 0",
-			replicas:        0,
-			desiredReplicas: 1000,
+			desc:                        "scale up from 0",
+			replicas:                    0,
+			desiredReplicas:             1000,
+			expectedMinParallelRequests: 20,
 		},
 	} {
 		t.Run(tc.desc, func(t *testing.T) {
 			set := burst(newStatefulSet(0))
 			set.Spec.VolumeClaimTemplates[0].ObjectMeta.Labels = map[string]string{"test": "test"}
-			parallelScale(t, set, tc.replicas, tc.desiredReplicas, assertBurstInvariants)
+			parallelScale(t, set, tc.replicas, tc.desiredReplicas, tc.expectedMinParallelRequests, assertBurstInvariants)
 		})
 	}
 }
 
-func parallelScale(t *testing.T, set *apps.StatefulSet, replicas, desiredReplicas int32, invariants invariantFunc) {
+func parallelScale(t *testing.T, set *apps.StatefulSet, replicas, desiredReplicas int32, expectedMinParallelRequests int, invariants invariantFunc) {
 	var err error
 	diff := desiredReplicas - replicas
+
+	// maxParallelRequests: MaxBatchSize of the controller is 500. We divide the diff by 4 to allow at most half of the last batch.
+	if maxParallelRequests := min(500, math.Abs(float64(diff))/4); expectedMinParallelRequests < 2 || float64(expectedMinParallelRequests) > maxParallelRequests {
+		t.Fatalf("expectedMinParallelRequests should be between 2 and %v. Batch size of the controller is exponentially increasing until 500. "+
"+ + "Got expectedMinParallelRequests %v, ", maxParallelRequests, expectedMinParallelRequests) + } client := fake.NewSimpleClientset(set) om, _, ssc := setupController(client) - om.createPodTracker.delay = time.Millisecond + om.createPodTracker.shouldTrackParallelRequests = true + om.createPodTracker.parallelRequestDelay = time.Millisecond *set.Spec.Replicas = replicas if err := parallelScaleUpStatefulSetControl(set, ssc, om, invariants); err != nil { @@ -3017,8 +3027,8 @@ func parallelScale(t *testing.T, set *apps.StatefulSet, replicas, desiredReplica t.Errorf("Failed to scale statefulset to %v replicas, got %v replicas", desiredReplicas, set.Status.Replicas) } - if (diff < -1 || diff > 1) && om.createPodTracker.maxParallel <= 1 { - t.Errorf("want max parallel requests > 1, got %v", om.createPodTracker.maxParallel) + if om.createPodTracker.maxParallelRequests < expectedMinParallelRequests { + t.Errorf("want max parallelRequests requests >= %v, got %v", expectedMinParallelRequests, om.createPodTracker.maxParallelRequests) } }