From 8b37323f16a5329742066b466088f8ab9cf66a47 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Tue, 27 Aug 2024 09:31:38 +0000 Subject: [PATCH] [feat] add run_fwd_bwd_with_microbatch (replace input) & test; add p&p.grad assert close test & all pass; --- .../pipeline/schedule/zero_bubble_pp.py | 10 +- .../test_schedule/test_zerobubble_pp.py | 265 ++++++++++++++++-- 2 files changed, 247 insertions(+), 28 deletions(-) diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py index b2d9f00cf..02ecf5b19 100644 --- a/colossalai/pipeline/schedule/zero_bubble_pp.py +++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -495,7 +495,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): scheduled_node, model_chunk: torch.nn.ModuleList, model_chunk_id: int, - input_obj: Optional[dict], criterion: Callable, accum_loss: Optional[torch.Tensor] = None, outputs: Optional[List[Any]] = None, @@ -506,7 +505,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): scheduled_node: model_chunk (ModuleList or Module): Model Chunk to be run; model_chunk_id (int): The current model chunk idx; - input_obj (Optional[dict]): x; criterion (Callable): loss function; accum_loss (Optional[torch.Tensor], optional): Accumulated loss. Defaults to None. outputs (Optional[List[Any]], optional): List to store the output of the last stage (final output). Defaults to None. @@ -518,7 +516,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): if model_chunk_id == 0: # is first stage; get input from func param if self.stage_manager.is_first_stage(ignore_chunk=True): - input_obj = input_obj + input_obj = self.load_micro_batch(model_chunk_id=model_chunk_id) else: input_obj = self.recv_forward_buffer[model_chunk_id].pop(0) else: @@ -671,7 +669,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): def run_forward_backward( self, model_chunk: Union[ModuleList, Module], - input_obj: Optional[dict], data_iter: Iterable, criterion: Callable[..., Any], optimizer: Optional[OptimizerWrapper] = None, @@ -683,7 +680,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): """ # # prepare batch self.load_batch(data_iter) - # print(f"self.batch {self.batch}; self.batch_size {self.batch_size}; self.num_microbatch {self.num_microbatch}; self.microbatch_size {self.microbatch_size}") + print( + f"self.batch_size {self.batch_size}; self.batch shape {self.batch.shape}; self.num_microbatch {self.num_microbatch}; self.microbatch_size {self.microbatch_size}" + ) it = 0 # while we still have schedules_node in self.schedules @@ -707,7 +706,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): scheduled_node=scheduled_node, model_chunk=model_chunk, model_chunk_id=scheduled_node.chunk, - input_obj=input_obj, criterion=criterion, accum_loss=return_loss, outputs=return_outputs, diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py index 99c8fcf0f..40aedfa47 100644 --- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py +++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py @@ -36,8 +36,8 @@ def get_model_numel(model: torch.nn.Module) -> Tuple[int, int]: return num_params, num_params_trainable -# Test run_forward_backward with baseline; -def test_run_fwd_bwd_base( +# Test iter input & multiple microbatch +def test_run_fwd_bwd_iter_input( rank: int, world_size: int, port: int, @@ -47,7 +47,7 @@ def test_run_fwd_bwd_base( rank = dist.get_rank() pp_size = world_size pg_mesh = ProcessGroupMesh(pp_size) - + num_microbatch = 4 # stage_manager stage_manager = PipelineStageManager(pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=pp_size) @@ -55,6 +55,7 @@ def test_run_fwd_bwd_base( zbv_schedule = [ # stage 0 [ + # microbatch 0 # chunk 0 fwd ScheduledNode(type="RECV_FORWARD", chunk=0, stage=0, minibatch=0), ScheduledNode(type="F", chunk=0, stage=0, minibatch=0), @@ -73,9 +74,67 @@ def test_run_fwd_bwd_base( ScheduledNode(type="B", chunk=0, stage=0, minibatch=0), ScheduledNode(type="W", chunk=0, stage=0, minibatch=0), ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=0), + # microbatch 1 + # chunk 0 fwd + ScheduledNode(type="RECV_FORWARD", chunk=0, stage=0, minibatch=1), + ScheduledNode(type="F", chunk=0, stage=0, minibatch=1), + ScheduledNode(type="SEND_FORWARD", chunk=0, stage=0, minibatch=1), + # chunk 1 fwd + ScheduledNode(type="RECV_FORWARD", chunk=1, stage=0, minibatch=1), + ScheduledNode(type="F", chunk=1, stage=0, minibatch=1), + ScheduledNode(type="SEND_FORWARD", chunk=1, stage=0, minibatch=1), + # chunk 1 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=0, minibatch=1), + ScheduledNode(type="B", chunk=1, stage=0, minibatch=1), + ScheduledNode(type="W", chunk=1, stage=0, minibatch=1), + ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=0, minibatch=1), + # chunk 0 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=0, minibatch=1), + ScheduledNode(type="B", chunk=0, stage=0, minibatch=1), + ScheduledNode(type="W", chunk=0, stage=0, minibatch=1), + ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=1), + # microbatch 2 + # chunk 0 fwd + ScheduledNode(type="RECV_FORWARD", chunk=0, stage=0, minibatch=2), + ScheduledNode(type="F", chunk=0, stage=0, minibatch=2), + ScheduledNode(type="SEND_FORWARD", chunk=0, stage=0, minibatch=2), + # chunk 1 fwd + ScheduledNode(type="RECV_FORWARD", chunk=1, stage=0, minibatch=2), + ScheduledNode(type="F", chunk=1, stage=0, minibatch=2), + ScheduledNode(type="SEND_FORWARD", chunk=1, stage=0, minibatch=2), + # chunk 1 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=0, minibatch=2), + ScheduledNode(type="B", chunk=1, stage=0, minibatch=2), + ScheduledNode(type="W", chunk=1, stage=0, minibatch=2), + ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=0, minibatch=2), + # chunk 0 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=0, minibatch=2), + ScheduledNode(type="B", chunk=0, stage=0, minibatch=2), + ScheduledNode(type="W", chunk=0, stage=0, minibatch=2), + ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=2), + # microbatch 3 + # chunk 0 fwd + ScheduledNode(type="RECV_FORWARD", chunk=0, stage=0, minibatch=3), + ScheduledNode(type="F", chunk=0, stage=0, minibatch=3), + ScheduledNode(type="SEND_FORWARD", chunk=0, stage=0, minibatch=3), + # chunk 1 fwd + ScheduledNode(type="RECV_FORWARD", chunk=1, stage=0, minibatch=3), + ScheduledNode(type="F", chunk=1, stage=0, minibatch=3), + ScheduledNode(type="SEND_FORWARD", chunk=1, stage=0, minibatch=3), + # chunk 1 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=0, minibatch=3), + ScheduledNode(type="B", chunk=1, stage=0, minibatch=3), + ScheduledNode(type="W", chunk=1, stage=0, minibatch=3), + ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=0, minibatch=3), + # chunk 0 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=0, minibatch=3), + ScheduledNode(type="B", chunk=0, stage=0, minibatch=3), + ScheduledNode(type="W", chunk=0, stage=0, minibatch=3), + ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=3), ], # stage 1 [ + # microbatch 0 # chunk 0 fwd ScheduledNode(type="RECV_FORWARD", chunk=0, stage=1, minibatch=0), ScheduledNode(type="F", chunk=0, stage=1, minibatch=0), @@ -94,9 +153,67 @@ def test_run_fwd_bwd_base( ScheduledNode(type="B", chunk=0, stage=1, minibatch=0), ScheduledNode(type="W", chunk=0, stage=1, minibatch=0), ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=0), + # microbatch 1 + # chunk 0 fwd + ScheduledNode(type="RECV_FORWARD", chunk=0, stage=1, minibatch=1), + ScheduledNode(type="F", chunk=0, stage=1, minibatch=1), + ScheduledNode(type="SEND_FORWARD", chunk=0, stage=1, minibatch=1), + # chunk 1 fwd + ScheduledNode(type="RECV_FORWARD", chunk=1, stage=1, minibatch=1), + ScheduledNode(type="F", chunk=1, stage=1, minibatch=1), + ScheduledNode(type="SEND_FORWARD", chunk=1, stage=1, minibatch=1), + # chunk 1 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=1, minibatch=1), + ScheduledNode(type="B", chunk=1, stage=1, minibatch=1), + ScheduledNode(type="W", chunk=1, stage=1, minibatch=1), + ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=1, minibatch=1), + # chunk 0 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=1, minibatch=1), + ScheduledNode(type="B", chunk=0, stage=1, minibatch=1), + ScheduledNode(type="W", chunk=0, stage=1, minibatch=1), + ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=1), + # microbatch 2 + # chunk 0 fwd + ScheduledNode(type="RECV_FORWARD", chunk=0, stage=1, minibatch=2), + ScheduledNode(type="F", chunk=0, stage=1, minibatch=2), + ScheduledNode(type="SEND_FORWARD", chunk=0, stage=1, minibatch=2), + # chunk 1 fwd + ScheduledNode(type="RECV_FORWARD", chunk=1, stage=1, minibatch=2), + ScheduledNode(type="F", chunk=1, stage=1, minibatch=2), + ScheduledNode(type="SEND_FORWARD", chunk=1, stage=1, minibatch=2), + # chunk 1 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=1, minibatch=2), + ScheduledNode(type="B", chunk=1, stage=1, minibatch=2), + ScheduledNode(type="W", chunk=1, stage=1, minibatch=2), + ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=1, minibatch=2), + # chunk 0 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=1, minibatch=2), + ScheduledNode(type="B", chunk=0, stage=1, minibatch=2), + ScheduledNode(type="W", chunk=0, stage=1, minibatch=2), + ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=2), + # microbatch 3 + # chunk 0 fwd + ScheduledNode(type="RECV_FORWARD", chunk=0, stage=1, minibatch=3), + ScheduledNode(type="F", chunk=0, stage=1, minibatch=3), + ScheduledNode(type="SEND_FORWARD", chunk=0, stage=1, minibatch=3), + # chunk 1 fwd + ScheduledNode(type="RECV_FORWARD", chunk=1, stage=1, minibatch=3), + ScheduledNode(type="F", chunk=1, stage=1, minibatch=3), + ScheduledNode(type="SEND_FORWARD", chunk=1, stage=1, minibatch=3), + # chunk 1 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=1, minibatch=3), + ScheduledNode(type="B", chunk=1, stage=1, minibatch=3), + ScheduledNode(type="W", chunk=1, stage=1, minibatch=3), + ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=1, minibatch=3), + # chunk 0 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=1, minibatch=3), + ScheduledNode(type="B", chunk=0, stage=1, minibatch=3), + ScheduledNode(type="W", chunk=0, stage=1, minibatch=3), + ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=0, minibatch=3), ], # stage 2 [ + # microbatch 0 # chunk 0 fwd ScheduledNode(type="RECV_FORWARD", chunk=0, stage=2, minibatch=0), ScheduledNode(type="F", chunk=0, stage=2, minibatch=0), @@ -114,10 +231,68 @@ def test_run_fwd_bwd_base( ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=2, minibatch=0), ScheduledNode(type="B", chunk=0, stage=2, minibatch=0), ScheduledNode(type="W", chunk=0, stage=2, minibatch=0), - ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=2, minibatch=0), # Send nothing + ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=2, minibatch=0), + # microbatch 1 + # chunk 0 fwd + ScheduledNode(type="RECV_FORWARD", chunk=0, stage=2, minibatch=1), + ScheduledNode(type="F", chunk=0, stage=2, minibatch=1), + ScheduledNode(type="SEND_FORWARD", chunk=0, stage=2, minibatch=1), + # chunk 1 fwd + ScheduledNode(type="RECV_FORWARD", chunk=1, stage=2, minibatch=1), + ScheduledNode(type="F", chunk=1, stage=2, minibatch=1), + ScheduledNode(type="SEND_FORWARD", chunk=1, stage=2, minibatch=1), + # chunk 1 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=2, minibatch=1), + ScheduledNode(type="B", chunk=1, stage=2, minibatch=1), + ScheduledNode(type="W", chunk=1, stage=2, minibatch=1), + ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=2, minibatch=1), + # chunk 0 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=2, minibatch=1), + ScheduledNode(type="B", chunk=0, stage=2, minibatch=1), + ScheduledNode(type="W", chunk=0, stage=2, minibatch=1), + ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=2, minibatch=1), + # microbatch 2 + # chunk 0 fwd + ScheduledNode(type="RECV_FORWARD", chunk=0, stage=2, minibatch=2), + ScheduledNode(type="F", chunk=0, stage=2, minibatch=2), + ScheduledNode(type="SEND_FORWARD", chunk=0, stage=2, minibatch=2), + # chunk 1 fwd + ScheduledNode(type="RECV_FORWARD", chunk=1, stage=2, minibatch=2), + ScheduledNode(type="F", chunk=1, stage=2, minibatch=2), + ScheduledNode(type="SEND_FORWARD", chunk=1, stage=2, minibatch=2), + # chunk 1 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=2, minibatch=2), + ScheduledNode(type="B", chunk=1, stage=2, minibatch=2), + ScheduledNode(type="W", chunk=1, stage=2, minibatch=2), + ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=2, minibatch=2), + # chunk 0 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=2, minibatch=2), + ScheduledNode(type="B", chunk=0, stage=2, minibatch=2), + ScheduledNode(type="W", chunk=0, stage=2, minibatch=2), + ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=2, minibatch=2), + # microbatch 3 + # chunk 0 fwd + ScheduledNode(type="RECV_FORWARD", chunk=0, stage=2, minibatch=3), + ScheduledNode(type="F", chunk=0, stage=2, minibatch=3), + ScheduledNode(type="SEND_FORWARD", chunk=0, stage=2, minibatch=3), + # chunk 1 fwd + ScheduledNode(type="RECV_FORWARD", chunk=1, stage=2, minibatch=3), + ScheduledNode(type="F", chunk=1, stage=2, minibatch=3), + ScheduledNode(type="SEND_FORWARD", chunk=1, stage=2, minibatch=3), + # chunk 1 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=2, minibatch=3), + ScheduledNode(type="B", chunk=1, stage=2, minibatch=3), + ScheduledNode(type="W", chunk=1, stage=2, minibatch=3), + ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=2, minibatch=3), + # chunk 0 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=2, minibatch=3), + ScheduledNode(type="B", chunk=0, stage=2, minibatch=3), + ScheduledNode(type="W", chunk=0, stage=2, minibatch=3), + ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=2, minibatch=3), ], # stage 3 [ + # microbatch 0 # chunk 0 fwd ScheduledNode(type="RECV_FORWARD", chunk=0, stage=3, minibatch=0), ScheduledNode(type="F", chunk=0, stage=3, minibatch=0), @@ -136,6 +311,63 @@ def test_run_fwd_bwd_base( ScheduledNode(type="B", chunk=0, stage=3, minibatch=0), ScheduledNode(type="W", chunk=0, stage=3, minibatch=0), ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=3, minibatch=0), + # microbatch 1 + # chunk 0 fwd + ScheduledNode(type="RECV_FORWARD", chunk=0, stage=3, minibatch=1), + ScheduledNode(type="F", chunk=0, stage=3, minibatch=1), + ScheduledNode(type="SEND_FORWARD", chunk=0, stage=3, minibatch=1), + # chunk 1 fwd + ScheduledNode(type="RECV_FORWARD", chunk=1, stage=3, minibatch=1), + ScheduledNode(type="F", chunk=1, stage=3, minibatch=1), + ScheduledNode(type="SEND_FORWARD", chunk=1, stage=3, minibatch=1), + # chunk 1 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=3, minibatch=1), + ScheduledNode(type="B", chunk=1, stage=3, minibatch=1), + ScheduledNode(type="W", chunk=1, stage=3, minibatch=1), + ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=3, minibatch=1), + # chunk 0 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=3, minibatch=1), + ScheduledNode(type="B", chunk=0, stage=3, minibatch=1), + ScheduledNode(type="W", chunk=0, stage=3, minibatch=1), + ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=3, minibatch=1), + # microbatch 2 + # chunk 0 fwd + ScheduledNode(type="RECV_FORWARD", chunk=0, stage=3, minibatch=2), + ScheduledNode(type="F", chunk=0, stage=3, minibatch=2), + ScheduledNode(type="SEND_FORWARD", chunk=0, stage=3, minibatch=2), + # chunk 1 fwd + ScheduledNode(type="RECV_FORWARD", chunk=1, stage=3, minibatch=2), + ScheduledNode(type="F", chunk=1, stage=3, minibatch=2), + ScheduledNode(type="SEND_FORWARD", chunk=1, stage=3, minibatch=2), + # chunk 1 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=3, minibatch=2), + ScheduledNode(type="B", chunk=1, stage=3, minibatch=2), + ScheduledNode(type="W", chunk=1, stage=3, minibatch=2), + ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=3, minibatch=2), + # chunk 0 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=3, minibatch=2), + ScheduledNode(type="B", chunk=0, stage=3, minibatch=2), + ScheduledNode(type="W", chunk=0, stage=3, minibatch=2), + ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=3, minibatch=2), + # microbatch 3 + # chunk 0 fwd + ScheduledNode(type="RECV_FORWARD", chunk=0, stage=3, minibatch=3), + ScheduledNode(type="F", chunk=0, stage=3, minibatch=3), + ScheduledNode(type="SEND_FORWARD", chunk=0, stage=3, minibatch=3), + # chunk 1 fwd + ScheduledNode(type="RECV_FORWARD", chunk=1, stage=3, minibatch=3), + ScheduledNode(type="F", chunk=1, stage=3, minibatch=3), + ScheduledNode(type="SEND_FORWARD", chunk=1, stage=3, minibatch=3), + # chunk 1 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=1, stage=3, minibatch=3), + ScheduledNode(type="B", chunk=1, stage=3, minibatch=3), + ScheduledNode(type="W", chunk=1, stage=3, minibatch=3), + ScheduledNode(type="SEND_BACKWARD", chunk=1, stage=3, minibatch=3), + # chunk 0 bwd + ScheduledNode(type="RECV_BACKWARD", chunk=0, stage=3, minibatch=3), + ScheduledNode(type="B", chunk=0, stage=3, minibatch=3), + ScheduledNode(type="W", chunk=0, stage=3, minibatch=3), + ScheduledNode(type="SEND_BACKWARD", chunk=0, stage=3, minibatch=3), ], ] @@ -143,7 +375,7 @@ def test_run_fwd_bwd_base( schedule=zbv_schedule[rank], # hint: send whole schedule or local schedule only ? stage_manager=stage_manager, num_model_chunks=pp_size, - num_microbatch=1, + num_microbatch=num_microbatch, overlap_p2p=False, ) @@ -152,14 +384,15 @@ def test_run_fwd_bwd_base( return (x * x).mean() # init model and input + batch_size = 4 num_layers = 8 in_dim = out_dim = 8 print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};") model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank) input0 = torch.rand(in_dim, out_dim, requires_grad=True).to(rank) - # data_iter = [input0] + data_iter = [torch.rand(batch_size, in_dim, out_dim, requires_grad=True).to(rank)] - input_base = input0.clone() + [t.clone() for t in data_iter] model_base = deepcopy(model) if rank == 0: @@ -193,9 +426,7 @@ def test_run_fwd_bwd_base( torch.cuda.synchronize() scheduler.run_forward_backward( model_chunk=local_chunk, - input_obj=input0, - # data_iter=iter(data_iter), - data_iter=None, + data_iter=iter(data_iter), criterion=criterion, optimizer=None, return_loss=None, @@ -206,8 +437,7 @@ def test_run_fwd_bwd_base( # Fwd bwd for base ########################## # fwd & bwd - output_base = model_base(input_base) - # loss_base = output_base.mean() + output_base = model_base(data_iter[0]) loss_base = criterion(output_base) loss_base.backward() print(f"After base fwd & bwd: {torch.cuda.memory_allocated()/1024**3 :.3f} GB;") @@ -245,15 +475,6 @@ def test_run_fwd_bwd_base( assert_close(local_chunk[1].weight.grad, model_base.layers[4].weight.grad) -# Test iter input & multiple microbatch -def test_run_fwd_bwd_iter_input( - rank: int, - world_size: int, - port: int, -): - pass - - @pytest.mark.dist # @pytest.mark.parametrize("num_microbatch", [4]) # @pytest.mark.parametrize("batch_size", [4]) @@ -261,7 +482,7 @@ def test_run_fwd_bwd_iter_input( @rerun_if_address_is_in_use() def test_pp(): spawn( - test_run_fwd_bwd_base, + test_run_fwd_bwd_iter_input, nprocs=4, )