diff --git a/colossalai/pipeline/schedule/_utils.py b/colossalai/pipeline/schedule/_utils.py
index 271b3238f..a2215d0fc 100644
--- a/colossalai/pipeline/schedule/_utils.py
+++ b/colossalai/pipeline/schedule/_utils.py
@@ -131,6 +131,16 @@ def retain_grad(x: Any) -> None:
         x.retain_grad()
 
 
+def require_grad(x: Any) -> None:
+    """Call requires_grad_() on a tensor.
+
+    Args:
+        x (Any): Object to be called.
+    """
+    if isinstance(x, torch.Tensor) and not x.requires_grad:
+        x.requires_grad_()
+
+
 def detach(x: Any) -> Any:
     """Call detach() on a tensor.
 
@@ -145,6 +155,34 @@ def detach(x: Any) -> Any:
     return x
 
 
+def clone(x: Any) -> Any:
+    """Call clone() on a tensor.
+
+    Args:
+        x (Any): Object to be called.
+
+    Returns:
+        Any: The cloned object.
+    """
+    if isinstance(x, torch.Tensor):
+        return x.clone()
+    return x
+
+
+def deallocate(x: Any) -> Any:
+    """Release the storage of a tensor's .data field.
+
+    Args:
+        x (Any): Object to be called.
+
+    Returns:
+        Any: The object with its .data storage freed.
+    """
+    if isinstance(x, torch.Tensor):
+        return x.data.untyped_storage().resize_(0)
+    return x
+
+
 def merge_batch(data: List[Any], batch_size_dim=0) -> Any:
     """Merge micro batches into a batch.
 
diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py
index c1c4f13c6..365125ba3 100644
--- a/colossalai/pipeline/schedule/zero_bubble_pp.py
+++ b/colossalai/pipeline/schedule/zero_bubble_pp.py
@@ -12,7 +12,7 @@ from colossalai.pipeline.p2p import PipelineP2PCommunication
 from colossalai.pipeline.schedule.v_schedule import ScheduledNode
 from colossalai.pipeline.stage_manager import PipelineStageManager
 
-from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, retain_grad, to_device
+from ._utils import clone, detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device
 from .base import PipelineSchedule
 
 AUTO_SCHEDULE_COMMUNICATION_TYPES = {"RECV_FORWARD", "RECV_BACKWARD", "SEND_FORWARD", "SEND_BACKWARD"}
@@ -39,6 +39,20 @@ def deallocate_output_tensor(out, deallocate_pipeline_outputs=False):
     out.data.untyped_storage().resize_(0)
 
 
+def require_grad(tensor):
+    """Mark a received tensor as requiring grad.
+
+    This method should be called on an input tensor right after it has been
+    received from the previous pipeline stage, so that the backward b step can
+    compute gradients with respect to it.
+    """
+    if tensor is None:
+        return
+    assert isinstance(tensor, torch.Tensor), "expected Tensor, found %s." % type(tensor).__name__
+    assert tensor._base is None, "should not set requires_grad on a view of another tensor."
+    tensor.requires_grad_()
+
+
 class ZeroBubbleVPipeScheduler(PipelineSchedule):
     def __init__(
         self,
@@ -409,6 +423,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         self,
         model_chunk: Union[ModuleList, Module],
         model_chunk_id: int,
+        micro_batch: Optional[dict],
         input_obj: Optional[dict],
         criterion: Callable,
         accum_loss: Optional[torch.Tensor] = None,
@@ -427,18 +442,27 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor).
         """
         # Load input ids, attention mask and labels
-        # micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id)
-
-        # for the first stage, input_obj is None
+        # for the first stage, input_obj is None; so we use micro_batch as input_obj
         # for other stages, input_obj is the output of the previous/next stage containing hidden_states etc.
# Only attention_mask from micro_batch is used - with self.stage_manager.switch_model_chunk_id(model_chunk_id): - # fwd calculate - output_obj = model_chunk[model_chunk_id](input_obj) + # fwd calculate + if isinstance(model_chunk, ModuleList): + # fwd for ModuleList model + if input_obj is None: + output_obj = model_chunk[model_chunk_id](**micro_batch) + else: + output_obj = model_chunk[model_chunk_id](**input_obj) + else: + # fwd for shardformer + # NOTE: in shardformer, each device still has the entire model, so we need to use relevant stage layers + internal_inputs = {} if input_obj is None else input_obj + # internal_inputs["stage_index"] = self.stage_manager.stage_indices[model_chunk_id] + output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, internal_inputs) + # last layer in model if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): - loss = criterion(output_obj) / self.num_microbatch + loss = criterion(output_obj, micro_batch) / self.num_microbatch if accum_loss is not None: accum_loss.add_(loss.detach()) if outputs is not None: @@ -472,19 +496,25 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): # calculate bwd b step ; only dx = w*dy; # Retain the grad on the input_obj. - tree_map(retain_grad, input_obj) + if input_obj is None: + return None + else: + tree_map(retain_grad, input_obj) + input_obj_ = input_obj["hidden_states"] if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): # loss backward; output_obj is loss; so output_obj_grad should be None assert output_obj_grad is None - + output_obj_ = output_obj + else: + output_obj_ = output_obj["hidden_states"] optimizer.backward_by_grad( - tensor=output_obj, + tensor=output_obj_, grad=output_obj_grad, - inputs=input_obj, + inputs=input_obj_, retain_graph=True, ) - return input_obj.grad + return input_obj_.grad def backward_w_step( self, @@ -511,8 +541,11 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): # loss backward; output_obj is loss output_obj_grad = None + output_obj_ = output_obj + else: + output_obj_ = output_obj["hidden_states"] optimizer.backward_by_grad( - tensor=output_obj, + tensor=output_obj_, grad=output_obj_grad, inputs=list(model_chunk[model_chunk_id].parameters()), retain_graph=False, @@ -543,9 +576,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id) # Step1: recv fwd if model_chunk_id == 0: - # is first stage; get input from func param + # is first stage; get input from microbatch if self.stage_manager.is_first_stage(ignore_chunk=True): - input_obj = micro_batch + input_obj = None else: input_obj = self.recv_forward_buffer[model_chunk_id].pop(0) else: @@ -557,45 +590,68 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): input_obj = self.recv_forward_buffer[model_chunk_id].pop(0) # Here, let input_obj.requires_grad_() - tree_map(torch.Tensor.requires_grad_, input_obj) + if input_obj is not None: + tree_map(require_grad, input_obj) + + # Also requires_grad_ for micro_batch in stage 0 chunk 0 fwd, + # tree_map(torch.Tensor.requires_grad_, micro_batch) # Step2: fwd step output_obj = self.forward_step( model_chunk=model_chunk, model_chunk_id=model_chunk_id, + micro_batch=micro_batch, input_obj=input_obj, criterion=criterion, accum_loss=accum_loss, outputs=outputs, ) + + # Step3: deallocate output for bwd b & w; (do not detach output) + deallocate_output_obj = tree_map(clone, output_obj) + if 
model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): + # We should not deallocate bwd LOSS + pass + else: + # deallocate output + tree_map(partial(deallocate_output_tensor, deallocate_pipeline_outputs=True), deallocate_output_obj) + + # add input and output object for backward b + if input_obj is not None: + self.input_tensors[model_chunk_id].append(input_obj) + else: + self.input_tensors[model_chunk_id].append(micro_batch) + + # for bwd b&w, we only need the graph(grad_fn) of output_obj + # Do not deallocate loss, deallocate other output_obj; + if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): + self.output_tensors[model_chunk_id].append(deallocate_output_obj) + self.output_tensors_dw[model_chunk_id].append(deallocate_output_obj) + else: + self.output_tensors[model_chunk_id].append(deallocate_output_obj) + self.output_tensors_dw[model_chunk_id].append(deallocate_output_obj) + + # Step4: detach output for send fwd; if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): # We should not detach bwd LOSS pass else: - detached_output_obj = output_obj.clone().detach() + # detach output + output_obj = tree_map(detach, output_obj) - # Step3: send fwd # add output to send_fwd_buffer - if model_chunk_id == 0: + if model_chunk_id == 0: # chunk 0 # is last stage; send to local_send_forward_buffer if self.stage_manager.is_last_stage(ignore_chunk=True): - self.local_send_forward_buffer.append(detached_output_obj) + self.local_send_forward_buffer.append(output_obj) else: - self.send_forward_buffer[model_chunk_id].append(detached_output_obj) - else: - # is first stage; end of fwd; append LOSS to local_send_backward_buffer + self.send_forward_buffer[model_chunk_id].append(output_obj) + else: # chunk 1 + # is first stage; end of fwd; do nothing if self.stage_manager.is_first_stage(ignore_chunk=True): pass else: - self.send_forward_buffer[model_chunk_id].append(detached_output_obj) - - # add input and output object for backward b - self.input_tensors[model_chunk_id].append(input_obj) - # detached output; for bwd b&w, we only need the graph(grad_fn) of output_obj - deallocate_output_tensor(output_obj, deallocate_pipeline_outputs=True) - self.output_tensors[model_chunk_id].append(output_obj) - # add output object for backward w - self.output_tensors_dw[model_chunk_id].append(output_obj) + self.send_forward_buffer[model_chunk_id].append(output_obj) def schedule_b( self, @@ -603,9 +659,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): model_chunk: Union[ModuleList, Module], model_chunk_id: int, optimizer: OptimizerWrapper, - # input_obj: Optional[dict], - # output_obj: Union[dict, torch.Tensor], - # output_obj_grad: Optional[dict], ): """A complete backward b schedule; Include recv bwd --> cal bwd step --> send bwd; @@ -616,20 +669,19 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): Returns: Nothing. 
""" - # Step1: recv bwd if model_chunk_id == 0: # chunk0 is last stage; recv output_grad from local_send_backward_buffer if self.stage_manager.is_last_stage(ignore_chunk=True): output_tensor_grad = self.local_send_backward_buffer.pop(0) - # chunk 0 not last stage; recv output_grad from recv_backward_buffer + # chunk0 not last stage; recv output_grad from recv_backward_buffer else: output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0) else: # chunk1, is first stage; recv LOSS from local send bwd buffer if self.stage_manager.is_first_stage(ignore_chunk=True): output_tensor_grad = None - # chunk1, not first stage; recv output_grad from recv_backward_buffer + # chunk1, not first stage; recv output_grad from recv_backward_buffer else: output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0) @@ -645,7 +697,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): # we save output_tensor_grad here self.output_tensors_grad_dw[model_chunk_id].append(output_tensor_grad) - # _wait_p2p(recv_bwd_handles) # Step2: bwd step input_object_grad = self.backward_b_step( model_chunk=model_chunk, @@ -777,8 +828,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): # communication communication_func = self.communication_map[scheduled_node.type] communication_func(scheduled_node.chunk) - - if scheduled_node.type == "F": + elif scheduled_node.type == "F": self.schedule_f( scheduled_node=scheduled_node, model_chunk=model_chunk, diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py index 1e5cdb3e5..43c6293c6 100644 --- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py +++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py @@ -1,4 +1,6 @@ from copy import deepcopy +from functools import partial +from types import MethodType from typing import Tuple import pytest @@ -16,7 +18,8 @@ from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.tensor.d_tensor.api import clear_layout_converter from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -from tests.kit.model_zoo import model_zoo + +# from tests.kit.model_zoo import model_zoo class MlpModel(nn.Module): @@ -24,10 +27,32 @@ class MlpModel(nn.Module): super().__init__() self.layers = nn.ModuleList([nn.Linear(in_dim, out_dim, bias=None) for _ in range(num_layers)]) - def forward(self, x): + def forward( + self, + hidden_states, + ): for layer in self.layers: - x = layer(x) - return x + hidden_states = layer(hidden_states) + return hidden_states + + +def pp_linear_fwd( + forward, + data: torch.Tensor = None, + hidden_states: torch.Tensor = None, + stage_mgr: PipelineStageManager = None, + model_chunk_id: int = None, +): + with stage_mgr.switch_model_chunk_id(model_chunk_id): + # fwd end + if stage_mgr.is_first_stage() and model_chunk_id == 1: + return forward(hidden_states) + # fwd start + elif stage_mgr.is_first_stage() and model_chunk_id == 0: + return {"hidden_states": forward(hidden_states)} + # fwd middle + else: + return {"hidden_states": forward(hidden_states)} def get_model_numel(model: torch.nn.Module) -> Tuple[int, int]: @@ -510,15 +535,15 @@ def run_fwd_bwd_iter_input(test_config): "precision": "bf16", "num_model_chunk": 2, }, - { - "batch_size": 8, - "tp_size": 1, - "pp_size": 4, - "num_microbatches": 8, - "zero_stage": 1, - "precision": "bf16", - "num_model_chunk": 2, - }, + # { + # "batch_size": 8, + # "tp_size": 1, 
+ # "pp_size": 4, + # "num_microbatches": 8, + # "zero_stage": 1, + # "precision": "bf16", + # "num_model_chunk": 2, + # }, ], ) def run_fwd_bwd_vschedule_with_optim(test_config): @@ -562,6 +587,10 @@ def run_fwd_bwd_vschedule_with_optim(test_config): # init loss func def criterion(x, *args, **kwargs): + x = x["hidden_states"] + return (x * x).mean() + + def criterion_base(x, *args, **kwargs): return (x * x).mean() # init model and input @@ -572,9 +601,10 @@ def run_fwd_bwd_vschedule_with_optim(test_config): before_init_memory = torch.cuda.memory_allocated() / 1024**3 print(f"Before init Model: {before_init_memory :.3f} GB on device {stage_manager.get_rank()};") model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank) - data_iter = [torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)] - - input_base = [t.clone() for t in data_iter] + # data_iter = [torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)] + data_iter = {"hidden_states": torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)} + # input_base = [t.clone() for t in data_iter] + input_base = {k: v.clone() for k, v in data_iter.items()} model_base = deepcopy(model) if rank == 0: @@ -582,24 +612,44 @@ def run_fwd_bwd_vschedule_with_optim(test_config): local_chunk = torch.nn.ModuleList().to(rank) for idx, sub_model in enumerate(model.layers): if idx == 0 or idx == 7: + sub_model._forward = sub_model.forward + sub_model.forward = MethodType( + partial(pp_linear_fwd, stage_mgr=stage_manager, model_chunk_id=len(local_chunk)), + sub_model._forward, + ) local_chunk.append(sub_model) elif rank == 1: # layer 1 & 6 to chunk 1 on rank1 local_chunk = torch.nn.ModuleList().to(rank) for idx, sub_model in enumerate(model.layers): if idx == 1 or idx == 6: + sub_model._forward = sub_model.forward + sub_model.forward = MethodType( + partial(pp_linear_fwd, stage_mgr=stage_manager, model_chunk_id=len(local_chunk)), + sub_model._forward, + ) local_chunk.append(sub_model) elif rank == 2: # layer 2 & 5 to chunk 2 on rank2 local_chunk = torch.nn.ModuleList().to(rank) for idx, sub_model in enumerate(model.layers): if idx == 2 or idx == 5: + sub_model._forward = sub_model.forward + sub_model.forward = MethodType( + partial(pp_linear_fwd, stage_mgr=stage_manager, model_chunk_id=len(local_chunk)), + sub_model._forward, + ) local_chunk.append(sub_model) else: # layer 3 & 4 to chunk 3 on rank3 local_chunk = torch.nn.ModuleList().to(rank) for idx, sub_model in enumerate(model.layers): if idx == 3 or idx == 4: + sub_model._forward = sub_model.forward + sub_model.forward = MethodType( + partial(pp_linear_fwd, stage_mgr=stage_manager, model_chunk_id=len(local_chunk)), + sub_model._forward, + ) local_chunk.append(sub_model) # init optimizer @@ -612,7 +662,7 @@ def run_fwd_bwd_vschedule_with_optim(test_config): torch.cuda.synchronize() result = scheduler.forward_backward_step( model_chunk=local_chunk, - data_iter=iter(data_iter), + data_iter=iter([data_iter]), criterion=criterion, optimizer=optimizer_pp, return_loss=True, @@ -643,8 +693,8 @@ def run_fwd_bwd_vschedule_with_optim(test_config): # Fwd bwd for base ########################## # fwd & bwd - output_base = model_base(input_base[0]) - loss_base = criterion(output_base) + output_base = model_base(input_base["hidden_states"]) + loss_base = criterion_base(output_base) loss_base.backward() optimizer_base.step() @@ -654,7 +704,7 @@ def run_fwd_bwd_vschedule_with_optim(test_config): # only chunk 1 stage 0 hold loss and output if rank == 0: 
assert_close(result["loss"], loss_base) - assert_close(result["outputs"], output_base) + assert_close(result["outputs"]["hidden_states"], output_base) # print(f"pp result {result}; base result loss:{loss_base} output_base:{output_base} ") ########################## @@ -727,6 +777,7 @@ def run_with_hybridplugin(test_config): { "pp_style": "zbv", "tp_size": 1, + "ep_size": 1, "pp_size": 4, "num_microbatches": 4, "zero_stage": 1, @@ -737,7 +788,7 @@ def run_with_hybridplugin(test_config): ) def run_with_moehybridplugin(test_config): sub_model_zoo = model_zoo.get_sub_registry("transformers_bert") - test_config["use_lazy_init"] = False + # test_config["use_lazy_init"] = False test_config["initial_scale"] = 2**16 model_list = [ "transformers_bert", @@ -749,6 +800,7 @@ def run_with_moehybridplugin(test_config): # base param model = model_fn() data = data_gen_fn() + print(f"data {data}") criterion = loss_fn optimizer = torch.optim.SGD(model.parameters(), momentum=0.1, lr=1e-5) @@ -787,7 +839,7 @@ def run_with_moehybridplugin(test_config): # plugin = MoeHybridParallelPlugin( # **test_config # ) - # model_pp, optimizer_pp, criterion, data_pp = plugin.configure( + # model_pp, optimizer_pp, criterion, data_pp, _ = plugin.configure( # model = model_pp, # optimizer = optimizer_pp, # criterion = criterion, @@ -806,16 +858,34 @@ def run_with_moehybridplugin(test_config): # TODO:6) support booster & Hybrid base 4) + # TODO:7) support booster & MoEHybrid base 4) +@parameterize( + "test_config", + [ + { + "pp_style": "zbv", + "tp_size": 1, + "ep_size": 1, + "pp_size": 4, + "num_microbatches": 4, + "zero_stage": 1, + "precision": "bf16", + "num_model_chunks": 2, + }, + ], +) +def run_with_booster_moehybridplugin(test_config): + pass def run_dist(rank, world_size, port): disable_existing_loggers() colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") # run_fwd_bwd_iter_input() - # run_fwd_bwd_vschedule_with_optim() + run_fwd_bwd_vschedule_with_optim() # run_with_moehybridplugin() - run_with_moehybridplugin() + # run_with_booster_moehybridplugin() @pytest.mark.dist
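Note on the rebinding pattern used in the updated test: schedule_f now exchanges {"hidden_states": ...} dicts between model chunks, so each plain nn.Linear in MlpModel has its forward rebound to a dict-in/dict-out wrapper via functools.partial and types.MethodType (see pp_linear_fwd above). Below is a minimal standalone sketch of that pattern; the dict_io_forward helper name is hypothetical and the stage-manager branching of pp_linear_fwd is omitted.

from functools import partial
from types import MethodType

import torch
import torch.nn as nn


# Hypothetical wrapper mirroring pp_linear_fwd: consume and produce
# {"hidden_states": ...} dicts, which is what schedule_f passes between chunks.
def dict_io_forward(forward, hidden_states: torch.Tensor = None):
    # `forward` is the sub-module's original bound forward (attached via MethodType below).
    return {"hidden_states": forward(hidden_states)}


layer = nn.Linear(8, 8, bias=None)
layer._forward = layer.forward  # keep a handle on the original forward
# partial() is kept to mirror the test, which also binds stage_mgr and model_chunk_id here.
layer.forward = MethodType(partial(dict_io_forward), layer._forward)

out = layer.forward(hidden_states=torch.rand(2, 8))
assert set(out) == {"hidden_states"} and out["hidden_states"].shape == (2, 8)

In the test itself, pp_linear_fwd additionally returns the raw tensor (not a dict) for chunk 1 on the first stage, since that output feeds the loss criterion directly.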