From 7568b34626ff81e1c70c4dacc0a84d9ea11d5960 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Mon, 9 Sep 2024 08:04:28 +0000 Subject: [PATCH] [fix] fix redundant detach & clone; add buffer assertion at the end; --- .../pipeline/schedule/zero_bubble_pp.py | 26 +++++++++++++++++-- .../test_schedule/test_zerobubble_pp.py | 5 ++-- 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py index ad0adc7f7..622e7eb08 100644 --- a/colossalai/pipeline/schedule/zero_bubble_pp.py +++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -108,6 +108,27 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): # dy buffer for local send bwd self.local_send_backward_buffer = [] + def assert_buffer_empty(self): + # assert buffer is empty at end + assert len(self.input_tensors[0]) == 0 + assert len(self.input_tensors[1]) == 0 + assert len(self.output_tensors[0]) == 0 + assert len(self.output_tensors[1]) == 0 + assert len(self.output_tensors_dw[0]) == 0 + assert len(self.output_tensors_dw[1]) == 0 + assert len(self.output_tensors_grad_dw[0]) == 0 + assert len(self.output_tensors_grad_dw[1]) == 0 + assert len(self.send_forward_buffer[0]) == 0 + assert len(self.send_forward_buffer[1]) == 0 + assert len(self.recv_forward_buffer[0]) == 0 + assert len(self.recv_forward_buffer[1]) == 0 + assert len(self.send_backward_buffer[0]) == 0 + assert len(self.send_backward_buffer[1]) == 0 + assert len(self.recv_backward_buffer[0]) == 0 + assert len(self.recv_backward_buffer[1]) == 0 + assert len(self.local_send_forward_buffer) == 0 + assert len(self.local_send_backward_buffer) == 0 + def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None: """Load a batch from data iterator. 
@@ -546,7 +567,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): ) if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): # We should not detach bwd LOSS - detached_output_obj = output_obj.clone() + pass else: detached_output_obj = output_obj.clone().detach() @@ -555,7 +576,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): if model_chunk_id == 0: # is last stage; send to local_send_forward_buffer if self.stage_manager.is_last_stage(ignore_chunk=True): - detached_output_obj = detached_output_obj.detach() self.local_send_forward_buffer.append(detached_output_obj) else: self.send_forward_buffer[model_chunk_id].append(detached_output_obj) @@ -816,4 +836,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): model_chunk, data_iter, criterion, optimizer, return_loss, return_outputs ) + self.assert_buffer_empty() + return result diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py index 950424338..6ad93e6cb 100644 --- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py +++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py @@ -558,7 +558,7 @@ def run_fwd_bwd_vschedule_with_optim(test_config): batch_size = test_config["batch_size"] num_layers = 8 assert num_layers % num_model_chunk == 0, f"Model with {num_layers} layer can not dist on {num_model_chunk} chunk" - in_dim = out_dim = 8192 + in_dim = out_dim = 4096 before_init_memory = torch.cuda.memory_allocated() / 1024**3 print(f"Before init Model: {before_init_memory :.3f} GB on device {stage_manager.get_rank()};") model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank) @@ -619,7 +619,6 @@ def run_fwd_bwd_vschedule_with_optim(test_config): # output hid_dim * hid_dim * 4(fp32) / 1024**3 print(f"rank {rank}: {(after_pp_step_memory - after_init_memory)} <= {(in_dim * in_dim * 4 * 3 / 1024**3)}") assert (after_pp_step_memory - after_init_memory) <= (in_dim * in_dim * 4 * 3 / 
1024**3) - # pass else: # rank0 will also hold output; print( @@ -628,7 +627,7 @@ def run_fwd_bwd_vschedule_with_optim(test_config): assert round((after_pp_step_memory - after_init_memory), 5) <= round( (in_dim * in_dim * 4 * 3 / 1024**3 + batch_size * in_dim * in_dim * 4 / 1024**3), 5 ) - # pass + ########################## # Fwd bwd for base ##########################