diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py
index 7534435a4..b2d9f00cf 100644
--- a/colossalai/pipeline/schedule/zero_bubble_pp.py
+++ b/colossalai/pipeline/schedule/zero_bubble_pp.py
@@ -37,7 +37,15 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         overlap_p2p: bool = True,
     ):
         super().__init__(stage_manager)
+        # batch info
         self.num_microbatch = num_microbatch
+        self.microbatch_size = microbatch_size
+        self.num_model_chunks = num_model_chunks
+        self.batch: Any
+        self.batch_size: int
+        self.last_batch_size: Optional[int] = None
+        self.microbatch_offset: List[int]
+
         self.collect_non_loss_data = None
         self.forward_only = None
         self.schedules = schedule
@@ -45,7 +53,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         self.do_post_validation = False
         self.is_first_run = True
         self.optimizer = None
-        self.num_model_chunks = num_model_chunks

         # P2PMeta cache
         # self.enable_metadata_cache = enable_metadata_cache
@@ -674,6 +681,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         """
         Runs Zerobubble schedule, with communication between pipeline stages.
         """
+        # prepare batch
+        self.load_batch(data_iter)
+        # print(f"self.batch {self.batch}; self.batch_size {self.batch_size}; self.num_microbatch {self.num_microbatch}; self.microbatch_size {self.microbatch_size}")
+
         it = 0
         # while we still have schedules_node in self.schedules
         while it < len(self.schedules):
diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
index 15897f73d..99c8fcf0f 100644
--- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
+++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
@@ -157,6 +157,7 @@ def test_run_fwd_bwd_base(
     print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};")
     model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank)
     input0 = torch.rand(in_dim, out_dim, requires_grad=True).to(rank)
+    # data_iter = [input0]

     input_base = input0.clone()
     model_base = deepcopy(model)
@@ -193,6 +194,7 @@ def test_run_fwd_bwd_base(
     scheduler.run_forward_backward(
         model_chunk=local_chunk,
         input_obj=input0,
+        # data_iter=iter(data_iter),
         data_iter=None,
         criterion=criterion,
         optimizer=None,
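
Note on the new batch bookkeeping: the attributes added to __init__ above (self.batch, self.batch_size, self.last_batch_size, self.microbatch_offset, ...) are the state that self.load_batch(data_iter), now called at the top of run_forward_backward, is expected to populate before the schedule loop starts. The sketch below is an illustration only, not the actual ColossalAI method; the class name _BatchLoadingSketch and the assumption of a plain-tensor batch are made up here for demonstration.

# Hypothetical sketch of the load_batch pattern the new attributes support;
# not the real zero_bubble_pp.py implementation, which may differ.
from typing import Any, Iterable, List, Optional

import torch


class _BatchLoadingSketch:
    def __init__(self, num_microbatch: int, num_model_chunks: int, microbatch_size: Optional[int] = None):
        self.num_microbatch = num_microbatch
        self.num_model_chunks = num_model_chunks
        self.microbatch_size = microbatch_size
        self.batch: Any = None
        self.batch_size: int = 0
        self.last_batch_size: Optional[int] = None
        self.microbatch_offset: List[int] = []

    def load_batch(self, data_iter: Iterable) -> None:
        """Pull one batch from the iterator and reset per-chunk microbatch offsets."""
        self.batch = next(data_iter)
        # assumes the batch is a single tensor whose first dimension is the batch dimension
        self.batch_size = self.batch.size(0)
        if self.microbatch_size is None:
            assert self.batch_size % self.num_microbatch == 0, "batch must split evenly into microbatches"
            self.microbatch_size = self.batch_size // self.num_microbatch
        # one read offset per model chunk, advanced as microbatches are consumed
        self.microbatch_offset = [0] * self.num_model_chunks
        self.last_batch_size = self.batch_size


if __name__ == "__main__":
    sched = _BatchLoadingSketch(num_microbatch=4, num_model_chunks=2)
    sched.load_batch(iter([torch.rand(8, 16)]))
    print(sched.batch_size, sched.microbatch_size, sched.microbatch_offset)  # 8 2 [0, 0]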