From 6af81d8c0db205a7466e6b0d9ccc1855834e6056 Mon Sep 17 00:00:00 2001
From: duanjunwen <935724073@qq.com>
Date: Fri, 30 Aug 2024 02:47:52 +0000
Subject: [PATCH] [feat] add fwd_bwd_step, run_fwd_only;

---
 .../pipeline/schedule/zero_bubble_pp.py | 86 ++++++++++++++++++-
 .../test_schedule/test_zerobubble_pp.py | 29 +++++--
 2 files changed, 108 insertions(+), 7 deletions(-)

diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py
index 23039af6d..ee6ad3227 100644
--- a/colossalai/pipeline/schedule/zero_bubble_pp.py
+++ b/colossalai/pipeline/schedule/zero_bubble_pp.py
@@ -1,5 +1,5 @@
 from functools import partial
-from typing import Any, Callable, Iterable, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
 
 import torch
 import torch.cuda
@@ -696,6 +696,54 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             output_obj_grad=output_obj_grad,
         )
 
+    def run_forward_only(
+        self,
+        model_chunk: Union[ModuleList, Module],
+        data_iter: Iterable,
+        criterion: Callable[..., Any],
+        return_loss: bool = False,
+        return_outputs: bool = False,
+    ) -> Dict:
+        assert self.forward_only
+
+        # prepare batch
+        self.load_batch(data_iter)
+
+        # prepare accum loss & output
+        accum_loss = None
+
+        # reset accum loss at fwd end;
+        if return_loss and self.stage_manager.is_first_stage(ignore_chunk=True):
+            accum_loss = torch.scalar_tensor(0, device=get_accelerator().get_current_device())
+
+        outputs = [] if return_outputs and self.stage_manager.is_first_stage(ignore_chunk=True) else None
+
+        it = 0
+        # while we still have schedules_node in self.schedules
+        while it < len(self.schedules):
+            scheduled_node = self.schedules[it]
+
+            if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES:
+                # communication
+                if scheduled_node.type == "RECV_FORWARD":
+                    self.recv_forward(scheduled_node.chunk)
+                elif scheduled_node.type == "SEND_FORWARD":
+                    self.send_forward(scheduled_node.chunk)
+            if scheduled_node.type == "F":
+                self.schedule_f(
+                    scheduled_node=scheduled_node,
+                    model_chunk=model_chunk,
+                    model_chunk_id=scheduled_node.chunk,
+                    criterion=criterion,
+                    accum_loss=accum_loss,
+                    outputs=outputs,
+                )
+            it += 1
+        # return loss & output
+        if outputs is not None:
+            outputs = merge_batch(outputs)
+        return {"loss": accum_loss, "outputs": outputs}
+
     def run_forward_backward(
         self,
         model_chunk: Union[ModuleList, Module],
@@ -704,7 +752,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         optimizer: Optional[OptimizerWrapper] = None,
         return_loss: bool = False,
         return_outputs: bool = False,
-    ):
+    ) -> Dict:
         """
         Runs Zerobubble schedule, with communication between pipeline stages.
         """
@@ -770,3 +818,37 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         if outputs is not None:
             outputs = merge_batch(outputs)
         return {"loss": accum_loss, "outputs": outputs}
+
+    def forward_backward_step(
+        self,
+        model_chunk: Union[ModuleList, Module],
+        data_iter: Iterable,
+        criterion: Callable[..., Any],
+        optimizer: Optional[OptimizerWrapper] = None,
+        return_loss: bool = False,
+        return_outputs: bool = False,
+    ) -> dict:
+        """
+        Args:
+            model_chunk (ModuleList or Module): Model chunk to be trained. Original interleaved uses a module list whereas shardformer uses entire model + layer specification.
+            data_iter (Iterable): Data iterator.
+            criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and return a loss tensor.
+            optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None.
+            return_loss (bool, optional): Whether to return loss. Defaults to False.
+            return_outputs (bool, optional): Whether to return model outputs. Defaults to False.
+
+        Returns:
+            dict: A dict with keys: 'loss' and 'outputs'.
+        """
+        self.forward_only = not torch.is_grad_enabled()
+        if optimizer is None:
+            assert self.forward_only, "Optimizer should be passed when doing backward."
+
+        if self.forward_only:
+            result = self.run_forward_only(model_chunk, data_iter, criterion, return_loss, return_outputs)
+        else:
+            result = self.run_forward_backward(
+                model_chunk, data_iter, criterion, optimizer, return_loss, return_outputs
+            )
+
+        return result
diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
index 8086f4b7d..8c869ae52 100644
--- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
+++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
@@ -644,10 +644,10 @@ def run_fwd_bwd_vschedule_with_optim(
     graph = PipelineGraph(
         n_stage=world_size,
         n_micro=num_microbatch,
-        f_cost=6,
-        b_cost=6,
-        w_cost=6,
-        c_cost=6,
+        f_cost=1,
+        b_cost=1,
+        w_cost=1,
+        c_cost=1,
         f_mem=mem_f,
         b_mem=mem_b,
         w_mem=mem_w,
@@ -714,7 +714,7 @@ def run_fwd_bwd_vschedule_with_optim(
     )
     torch.cuda.synchronize()
 
-    result = scheduler.run_forward_backward(
+    result = scheduler.forward_backward_step(
         model_chunk=local_chunk,
         data_iter=iter(data_iter),
         criterion=criterion,
@@ -793,6 +793,25 @@ def run_fwd_bwd_vschedule_with_optim(
     assert val_base[:2] == val_pp
 
 
+# 4) support Hybrid base 3)
+def run_with_hybrid(
+    rank: int,
+    world_size: int,
+    port: int,
+    num_microbatch: int,
+    batch_size: int,
+    num_model_chunk: int,
+):
+    pass
+
+
+# 5) support MoE base 3)
+
+# 6) support booster & Hybrid base 4)
+
+# 7) support booster & MoE base 4)
+
+
 @pytest.mark.dist
 @pytest.mark.parametrize("num_microbatch", [4])
 @pytest.mark.parametrize("batch_size", [4])
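
Usage sketch (reviewer note, not part of the patch): a minimal illustration of how the new forward_backward_step entry point is meant to be driven. The names scheduler, local_chunk, data_iter, criterion, and optimizer are placeholders assumed to be constructed as in run_fwd_bwd_vschedule_with_optim above; per the patch, whether gradients are enabled is what selects between run_forward_backward and run_forward_only.

    import torch

    # Placeholders: scheduler (ZeroBubbleVPipeScheduler), local_chunk, data_iter,
    # criterion, and optimizer are assumed to be built as in the test above.

    # Training step: gradients are enabled, so forward_backward_step dispatches
    # to run_forward_backward and returns {"loss": ..., "outputs": ...}.
    result = scheduler.forward_backward_step(
        model_chunk=local_chunk,
        data_iter=iter(data_iter),
        criterion=criterion,
        optimizer=optimizer,
        return_loss=True,
        return_outputs=True,
    )

    # Evaluation step: under torch.no_grad() the scheduler sets
    # self.forward_only = True and dispatches to run_forward_only,
    # so the optimizer argument may be omitted.
    with torch.no_grad():
        eval_result = scheduler.forward_backward_step(
            model_chunk=local_chunk,
            data_iter=iter(data_iter),
            criterion=criterion,
            return_loss=True,
            return_outputs=True,
        )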