diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py
index da5320cf3..b589579c3 100644
--- a/colossalai/pipeline/schedule/zero_bubble_pp.py
+++ b/colossalai/pipeline/schedule/zero_bubble_pp.py
@@ -40,9 +40,8 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         self.num_microbatch = num_microbatch
         self.collect_non_loss_data = None
         self.forward_only = None
-        self.schedules = schedule
-        self.it = 0  # curr iteration
+        # TODO: optim post valid
         self.do_post_validation = False
         self.is_first_run = True
         self.optimizer = None
@@ -69,16 +68,19 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         self.input_tensors = [[], []]
         self.output_tensors = [[], []]
-        # y & dy buffer for schedule b
+        # y & dy buffer for schedule w
         self.output_tensors_dw = [[], []]
         self.output_tensors_grad_dw = [[], []]
+        # buffer for communication
         self.send_forward_buffer = [[], []]
         self.recv_forward_buffer = [[], []]
         self.send_backward_buffer = [[], []]
         self.recv_backward_buffer = [[], []]
-        self.forward_data_store = []
+
+        # y buffer for local send fwd
         self.local_send_forward_buffer = []
+        # dy buffer for local send bwd
         self.local_send_backward_buffer = []
 
     def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
@@ -263,7 +265,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         Args:
             model_chunk_id (int): The current model chunk idx.
-            output_object (Any): Object to be sent.
            next_rank (int, optional): The rank of the recipient of the tensor.
 
         Returns:
@@ -313,7 +314,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         Args:
             model_chunk_id (int): The current model chunk idx.
-            input_object (Any): Object to be sent.
            prev_rank (int, optional): The rank of the recipient of the tensor
 
         Returns:
@@ -371,9 +371,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
     ) -> Union[torch.Tensor, dict]:
         """Forward one step of the pipeline
         Args:
-            model (ModuleList or Module): Model Chunk to be run
-            input_obj (Optional[dict]): The output from the previous stage. If it is the first stage, the `input_obj` is None.
-            criterion (Callable): Criterion to calculate loss.
+            model_chunk (ModuleList or Module): Model Chunk to be run;
+            model_chunk_id (int): The current model chunk idx;
+            input_obj (Optional[dict]): x;
+            criterion (Callable): loss function;
             accum_loss (Optional[torch.Tensor], optional): Accumulated loss. Defaults to None.
             outputs (Optional[List[Any]], optional): List to store the output of the last stage (final output). Defaults to None.
 
@@ -410,16 +411,18 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         output_obj: Union[dict, torch.Tensor],
         output_obj_grad: Optional[dict],
     ) -> Optional[dict]:
-        """Backward one step of the pipeline
+        """Backward dx step of the pipeline; we calculate "dx = w*dy" here.
 
         Args:
+            model_chunk (ModuleList or Module): Model Chunk to be run;
+            model_chunk_id (int): The current model chunk idx;
             optimizer (OptimizerWrapper): Optimizer to update the model
-            input_obj (Optional[dict]): Output of the previous stage. If it is the first stage, the `input_obj` is None.
-            output_obj (Union[dict, torch.Tensor]): Output of the current stage. If it is the last stage, the output is the loss (Tensor).
-            output_obj_grad (dict): Gradient of the `output_obj`. If it is the last stage, the `output_obj_grad` is None.
+            input_obj (Optional[dict]): x.
+            output_obj (Union[dict, torch.Tensor]): y.
+            output_obj_grad (dict): dy.
 
         Returns:
-            Optional[dict]: Gradient of the `input_obj`. If it is the first stage, the `input_obj_grad` is None.
+            Optional[dict]: dx.
         """
         # calculate bwd b step ; only dx = w*dy;
@@ -451,10 +454,21 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         model_chunk: Union[ModuleList, Module],
         model_chunk_id: int,
         # optimizer: OptimizerWrapper,
-        # input_obj: Optional[dict],
         output_obj: Union[dict, torch.Tensor],
         output_obj_grad: Optional[dict],
     ):
+        """Backward dw step of the pipeline; we calculate "dw = x*dy" here.
+
+        Args:
+            model_chunk (ModuleList or Module): Model Chunk to be run;
+            model_chunk_id (int): The current model chunk idx;
+            optimizer (OptimizerWrapper): Optimizer to update the model
+            output_obj (Union[dict, torch.Tensor]): y.
+            output_obj_grad (dict): dy.
+
+        Returns:
+            Nothing to return; we only calculate dw and then update w.
+        """
         # calculate bwd w step ; only dw = x*dy;
         if model_chunk_id == 0:
             torch.autograd.backward(
@@ -481,6 +495,20 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         accum_loss: Optional[torch.Tensor] = None,
         outputs: Optional[List[Any]] = None,
     ):
+        """A complete forward schedule; includes recv fwd --> cal fwd --> send fwd.
+
+        Args:
+            scheduled_node:
+            model_chunk (ModuleList or Module): Model Chunk to be run;
+            model_chunk_id (int): The current model chunk idx;
+            input_obj (Optional[dict]): x;
+            criterion (Callable): loss function;
+            accum_loss (Optional[torch.Tensor], optional): Accumulated loss. Defaults to None.
+            outputs (Optional[List[Any]], optional): List to store the output of the last stage (final output). Defaults to None.
+
+        Returns:
+            Nothing.
+        """
         # Step1: recv fwd
         if model_chunk_id == 0:
             # is first stage; get input from func param
@@ -541,6 +569,16 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         # output_obj: Union[dict, torch.Tensor],
         # output_obj_grad: Optional[dict],
     ):
+        """A complete backward b schedule; includes recv bwd --> cal bwd step --> send bwd.
+
+        Args:
+            scheduled_node:
+            model_chunk (ModuleList or Module): Model Chunk to be run;
+            model_chunk_id (int): The current model chunk idx;
+        Returns:
+            Nothing.
+        """
+
         # Step1: recv bwd
         if model_chunk_id == 0:
             # chunk0 is last stage; recv output_grad from local_send_backward_buffer
@@ -606,6 +644,15 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         model_chunk_id: int,
         # optimizer: OptimizerWrapper,
     ):
+        """A complete backward w schedule; includes get y & dy from buffer --> cal bwd w step (cal dw & update w).
+
+        Args:
+            scheduled_node:
+            model_chunk (ModuleList or Module): Model Chunk to be run;
+            model_chunk_id (int): The current model chunk idx;
+        Returns:
+            Nothing.
+        """
         # get y & dy from buffer
         output_obj = self.output_tensors_dw[model_chunk_id].pop(0)
@@ -629,7 +676,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         return_loss: bool = False,
         return_outputs: bool = False,
     ):
-        it = self.it
+        """
+        Runs Zerobubble schedule, with communication between pipeline stages.
+        """
+        it = 0
         # while we still have schedules_node in self.schedules
         # print(f"manger_stage {self.stage_manager.stage} schedule {self.schedules} \n")
         while it < len(self.schedules):
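Note (not part of the patch): the B/W split documented above — `backward_b_step` computing only "dx = w*dy" and `backward_w_step` computing only "dw = x*dy" — can be reproduced in isolation with plain `torch.autograd`. A minimal sketch, assuming a single `nn.Linear` chunk with dummy tensors; `layer`, `x`, and `dy` are illustrative stand-ins, not names from the patch, and the scheduler itself drives the same idea through `torch.autograd.backward` over its buffers:

```python
import torch
import torch.nn as nn

# Stand-in for one pipeline model chunk and its input (illustrative only).
layer = nn.Linear(4, 4)
x = torch.randn(2, 4, requires_grad=True)
y = layer(x)                      # y: forward output kept for the deferred W step
dy = torch.ones_like(y)           # dy: gradient arriving from the next stage

# B step: compute only dx = w*dy and keep the graph alive for the later W step.
dx = torch.autograd.grad(y, x, grad_outputs=dy, retain_graph=True)[0]

# W step: compute only dw = x*dy, accumulating into the chunk's parameter .grad
# fields by restricting autograd to the weights.
torch.autograd.backward(y, grad_tensors=dy, inputs=list(layer.parameters()))

print(dx.shape, layer.weight.grad.shape)
```

Deferring the W step is why the scheduler parks y and dy in `output_tensors_dw` / `output_tensors_grad_dw`: the graph has to stay retained between the two calls until `backward_w_step` pops them and updates the weights.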