From c18ef060cfcf868c78d22a132cb144e039050446 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Fri, 23 Aug 2024 06:04:12 +0000 Subject: [PATCH] [feat] add dw test; --- .../pipeline/schedule/zero_bubble_pp.py | 36 ++++-- .../test_schedule/test_zerobubble_pp.py | 108 +++++++++++++++++- 2 files changed, 132 insertions(+), 12 deletions(-) diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py index 0cf9bf67a..0fef29446 100644 --- a/colossalai/pipeline/schedule/zero_bubble_pp.py +++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -64,8 +64,15 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): def _free_buffers(self): # free local buffer # two dim array, first dim is the model chunk, second dim is the microbatch queue + + # x & y buffer for schedule b self.input_tensors = [[], []] self.output_tensors = [[], []] + + # y & dy buffer for schedule b + self.output_tensors_dw = [[], []] + self.output_tensors_grad_dw = [[], []] + self.send_forward_buffer = [[], []] self.recv_forward_buffer = [[], []] self.send_backward_buffer = [[], []] @@ -467,7 +474,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): model_chunk: Union[ModuleList, Module], model_chunk_id: int, # optimizer: OptimizerWrapper, - input_obj: Optional[dict], + # input_obj: Optional[dict], output_obj: Union[dict, torch.Tensor], output_obj_grad: Optional[dict], ): @@ -479,8 +486,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): else: if self.stage_manager.is_first_stage(ignore_chunk=True): - torch.autograd.backward(output_obj_grad, inputs=list(model=model_chunk[model_chunk_id].parameters())) - + torch.autograd.backward(output_obj_grad, inputs=list(model_chunk[model_chunk_id].parameters())) else: torch.autograd.backward( tensors=output_obj, @@ -518,10 +524,13 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): ) # print(f"model_chunk_id {model_chunk_id} fwd output_obj {output_obj}") - # add input and output object for backward + # add input and output object for backward b self.input_tensors[model_chunk_id].append(input_obj) self.output_tensors[model_chunk_id].append(output_obj) + # add output object for backward w + self.output_tensors_dw[model_chunk_id].append(output_obj) + # Step3: send fwd send_handles = self.send_forward(model_chunk_id=model_chunk_id, output_tensor=output_obj) @@ -544,10 +553,18 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): output_tensor_grad, recv_bwd_handles = self.recv_backward(model_chunk_id=model_chunk_id) # print(f"recv output_tensor_grad {output_tensor_grad}") - # get input and output object from buffer + # get input and output object from buffer; input_obj = self.input_tensors[model_chunk_id].pop() output_obj = self.output_tensors[model_chunk_id].pop() + # save output_tensor_grad for dw + if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): + # we save loss here + self.output_tensors_grad_dw[model_chunk_id].append(output_obj) + else: + # we save output_tensor_grad here + self.output_tensors_grad_dw[model_chunk_id].append(output_tensor_grad) + _wait_p2p(recv_bwd_handles) # print(f"input_obj {input_obj} output_obj {output_obj} output_tensor_grad {output_tensor_grad}") # Step2: bwd step @@ -571,15 +588,16 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): model_chunk: Union[ModuleList, Module], model_chunk_id: int, # optimizer: OptimizerWrapper, - input_obj: Optional[dict], - output_obj: Union[dict, torch.Tensor], - output_obj_grad: Optional[dict], ): + + # get y & dy from buffer + output_obj = self.output_tensors_dw[model_chunk_id].pop() + output_obj_grad = self.output_tensors_grad_dw[model_chunk_id].pop() + self.backward_w_step( model_chunk=model_chunk, model_chunk_id=model_chunk_id, # optimizer: OptimizerWrapper, - input_obj=input_obj, output_obj=output_obj, output_obj_grad=output_obj_grad, ) diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py index fbc4df3ac..bf1fba3c6 100644 --- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py +++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py @@ -4,6 +4,7 @@ from typing import Tuple import torch import torch.distributed as dist import torch.nn as nn +from torch.testing import assert_close import colossalai from colossalai.cluster import ProcessGroupMesh @@ -56,13 +57,13 @@ def test_zerobubble_pipeline_base( # init model and input num_layers = 8 - in_dim = out_dim = 2048 + in_dim = out_dim = 8 print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};") model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank) input0 = torch.rand(in_dim, out_dim, requires_grad=True).to(rank) - input0.clone() - deepcopy(model) + input_base = input0.clone() + model_base = deepcopy(model) if rank == 0: # layer 0 & 7 to chunk 0 on rank0 @@ -245,6 +246,13 @@ def test_zerobubble_pipeline_base( model_chunk_id=chunk_id, # optimizer: OptimizerWrapper, ) + scheduler.schedule_w( + scheduled_node=None, + non_w_pending=None, + model_chunk=chunk_0, + model_chunk_id=chunk_id, + # optimizer: OptimizerWrapper, + ) # # chunk 1 id 1 (layer 6) bwd if rank == 1: @@ -255,6 +263,13 @@ def test_zerobubble_pipeline_base( model_chunk_id=chunk_id, # optimizer: OptimizerWrapper, ) + scheduler.schedule_w( + scheduled_node=None, + non_w_pending=None, + model_chunk=chunk_1, + model_chunk_id=chunk_id, + # optimizer: OptimizerWrapper, + ) # chunk 2 id 1 (layer 5) bwd if rank == 2: @@ -266,6 +281,14 @@ def test_zerobubble_pipeline_base( # optimizer: OptimizerWrapper, ) + scheduler.schedule_w( + scheduled_node=None, + non_w_pending=None, + model_chunk=chunk_2, + model_chunk_id=chunk_id, + # optimizer: OptimizerWrapper, + ) + # chunk 3 id 1 (layer 4) bwd if rank == 3: chunk_id = 1 @@ -276,6 +299,14 @@ def test_zerobubble_pipeline_base( # optimizer: OptimizerWrapper, ) + scheduler.schedule_w( + scheduled_node=None, + non_w_pending=None, + model_chunk=chunk_3, + model_chunk_id=chunk_id, + # optimizer: OptimizerWrapper, + ) + # ###### # # bwd rank 1->4 # ###### @@ -290,6 +321,13 @@ def test_zerobubble_pipeline_base( # optimizer: OptimizerWrapper, ) # print(f"input_grad3 {input_grad3}") + scheduler.schedule_w( + scheduled_node=None, + non_w_pending=None, + model_chunk=chunk_3, + model_chunk_id=chunk_id, + # optimizer: OptimizerWrapper, + ) # chunk 2 id 0 (layer 2) bwd if rank == 2: @@ -301,6 +339,13 @@ def test_zerobubble_pipeline_base( # optimizer: OptimizerWrapper, ) # print(f"input_grad2 {input_grad2}") + scheduler.schedule_w( + scheduled_node=None, + non_w_pending=None, + model_chunk=chunk_2, + model_chunk_id=chunk_id, + # optimizer: OptimizerWrapper, + ) # chunk 1 id 0 (layer 1) bwd if rank == 1: @@ -312,6 +357,14 @@ def test_zerobubble_pipeline_base( # optimizer: OptimizerWrapper, ) + scheduler.schedule_w( + scheduled_node=None, + non_w_pending=None, + model_chunk=chunk_1, + model_chunk_id=chunk_id, + # optimizer: OptimizerWrapper, + ) + # chunk 0 id 0 (layer 0) bwd if rank == 0: chunk_id = 0 @@ -323,6 +376,55 @@ def test_zerobubble_pipeline_base( ) # print(f"input_grad0 {input_grad0}") + scheduler.schedule_w( + scheduled_node=None, + non_w_pending=None, + model_chunk=chunk_0, + model_chunk_id=chunk_id, + # optimizer: OptimizerWrapper, + ) + + ########################## + # Fwd bwd for base + ########################## + # fwd & bwd + output_base = model_base(input_base) + loss_base = output_base.mean() + loss_base.backward() + print(f"After base fwd & bwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") + + # assert weight + if rank == 0: + # layer 0 + assert_close(chunk_0[0].weight, model_base.layers[0].weight) + assert_close(chunk_0[0].weight.grad, model_base.layers[0].weight.grad) + # layer 7 + assert_close(chunk_0[1].weight, model_base.layers[7].weight) + assert_close(chunk_0[1].weight.grad, model_base.layers[7].weight.grad) + if rank == 1: + # layer 1 + assert_close(chunk_1[0].weight, model_base.layers[1].weight) + assert_close(chunk_1[0].weight.grad, model_base.layers[1].weight.grad) + # layer 6 + assert_close(chunk_1[1].weight, model_base.layers[6].weight) + assert_close(chunk_1[1].weight.grad, model_base.layers[6].weight.grad) + + if rank == 2: + # layer 2 + assert_close(chunk_2[0].weight, model_base.layers[2].weight) + assert_close(chunk_2[0].weight.grad, model_base.layers[2].weight.grad) + # layer 5 + assert_close(chunk_2[1].weight, model_base.layers[5].weight) + assert_close(chunk_2[1].weight.grad, model_base.layers[5].weight.grad) + + if rank == 3: + # layer 3 + assert_close(chunk_3[0].weight, model_base.layers[3].weight) + assert_close(chunk_3[0].weight.grad, model_base.layers[3].weight.grad) + # layer 4 + assert_close(chunk_3[1].weight, model_base.layers[4].weight) + assert_close(chunk_3[1].weight.grad, model_base.layers[4].weight.grad) + # @pytest.mark.dist # @pytest.mark.parametrize("num_microbatch", [4])