diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py index 97ad9d5f5..0a97c466a 100644 --- a/colossalai/pipeline/schedule/zero_bubble_pp.py +++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -115,10 +115,16 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): self.output_tensors_grad_dw = [[], []] # buffer for communication - self.send_forward_buffer = [[], []] - self.recv_forward_buffer = [[], []] - self.send_backward_buffer = [[], []] - self.recv_backward_buffer = [[], []] + self.send_forward_buffer = [[], []] # [chunk0:[torch.Tensor], chunk1:[torch.Tensor]] + self.recv_forward_buffer = [ + [], + [], + ] # [chunk0:[(torch.Tensor, wait_handle)], chunk1:[(torch.Tensor, wait_handle)]] + self.send_backward_buffer = [[], []] # [chunk0:[torch.Tensor], chunk1:[torch.Tensor]] + self.recv_backward_buffer = [ + [], + [], + ] # [chunk0:[(torch.Tensor, wait_handle)], chunk1:[(torch.Tensor, wait_handle)]] # y buffer for local send fwd self.local_send_forward_buffer = [] @@ -257,7 +263,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): ) if self.enable_metadata_cache and self.tensor_metadata_recv[model_chunk_id] is None: self.tensor_metadata_recv[model_chunk_id] = create_send_metadata(input_tensor) - self.recv_forward_buffer[model_chunk_id].append(input_tensor) + self.recv_forward_buffer[model_chunk_id].append((input_tensor, wait_handles)) return wait_handles else: @@ -280,7 +286,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): ) if self.enable_metadata_cache and self.tensor_metadata_recv[model_chunk_id] is None: self.tensor_metadata_recv[model_chunk_id] = create_send_metadata(input_tensor) - self.recv_forward_buffer[model_chunk_id].append(input_tensor) + self.recv_forward_buffer[model_chunk_id].append((input_tensor, wait_handles)) return wait_handles def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> List: @@ -316,7 +322,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): ) if self.enable_metadata_cache and self.grad_metadata_recv[model_chunk_id] is None: self.grad_metadata_recv[model_chunk_id] = create_send_metadata(output_tensor_grad) - self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad) + self.recv_backward_buffer[model_chunk_id].append((output_tensor_grad, wait_handles)) return wait_handles else: @@ -339,7 +345,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): ) if self.enable_metadata_cache and self.grad_metadata_recv[model_chunk_id] is None: self.grad_metadata_recv[model_chunk_id] = create_send_metadata(output_tensor_grad) - self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad) + self.recv_backward_buffer[model_chunk_id].append((output_tensor_grad, wait_handles)) return wait_handles def send_forward(self, model_chunk_id: int, next_rank: int = None) -> List: @@ -651,9 +657,12 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): if model_chunk_id == 0: # is first stage; get input from microbatch if self.stage_manager.is_first_stage(ignore_chunk=True): - input_obj = None + input_obj = None # (tensor, wait_handle) else: input_obj = self.recv_forward_buffer[model_chunk_id].pop(0) + for h in input_obj[1]: + h.wait() + input_obj = input_obj[0] else: # is last stage; recv from local if self.stage_manager.is_last_stage(ignore_chunk=True): @@ -661,7 +670,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): # not last stage; recv from next else: input_obj = self.recv_forward_buffer[model_chunk_id].pop(0) - + for h in input_obj[1]: + h.wait() + input_obj = input_obj[0] # Here, let input_obj.requires_grad_() # if input_obj is not None: if not isinstance(input_obj, torch.Tensor): @@ -751,6 +762,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): # chunk0 not last stage; recv output_grad from recv_backward_buffer else: output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0) + for h in output_tensor_grad[1]: + h.wait() + output_tensor_grad = output_tensor_grad[0] else: # chunk1, is first stage; recv LOSS from local send bwd buffer if self.stage_manager.is_first_stage(ignore_chunk=True): @@ -758,6 +772,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): # chunk1, not first stage; recv output_grad from recv_backward_buffer else: output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0) + for h in output_tensor_grad[1]: + h.wait() + output_tensor_grad = output_tensor_grad[0] # get input and output object from buffer; input_obj = self.input_tensors[model_chunk_id].pop(0) @@ -886,9 +903,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): wait_handle = communication_func(scheduled_node.chunk) self.wait_handles.append(wait_handle) elif scheduled_node.type == "F": - for h in self.wait_handles: - for hh in h: - hh.wait() self.schedule_f( scheduled_node=scheduled_node, model_chunk=model_chunk, @@ -898,9 +912,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): outputs=outputs, ) elif scheduled_node.type == "B": - for h in self.wait_handles: - for hh in h: - hh.wait() self.schedule_b( scheduled_node=scheduled_node, model_chunk=model_chunk, @@ -914,9 +925,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): model_chunk_id=scheduled_node.chunk, optimizer=optimizer, ) - for h in self.wait_handles: - for hh in h: - hh.wait() # return loss & output if outputs is not None: outputs = merge_batch(outputs)