diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py
index b589579c3..7534435a4 100644
--- a/colossalai/pipeline/schedule/zero_bubble_pp.py
+++ b/colossalai/pipeline/schedule/zero_bubble_pp.py
@@ -440,9 +440,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                 torch.autograd.backward(output_obj, inputs=input_obj, retain_graph=True)
             else:
                 # commom bwd step
-                # print(f"bwd output_obj {output_obj} output_obj_grad {output_obj_grad} input_obj {input_obj}")
                 # BUG:output_obj_grad is None
-                # print(f"model_chunk_id {model_chunk_id} stage {self.stage_manager.stage}; tensor {output_obj};\n grad_tensors {output_obj_grad};\n inputs {input_obj}\n")
                 torch.autograd.backward(
                     tensors=output_obj, grad_tensors=output_obj_grad, inputs=input_obj, retain_graph=True
                 )
@@ -516,7 +514,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                 input_obj = input_obj
             else:
                 input_obj = self.recv_forward_buffer[model_chunk_id].pop(0)
-
         else:
             # is last stage; recv from local
             if self.stage_manager.is_last_stage(ignore_chunk=True):
@@ -535,8 +532,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             outputs=outputs,
         )
 
-        # print(f"model_chunk_id {model_chunk_id} fwd output_obj {output_obj}")
-
         # add input and output object for backward b
         self.input_tensors[model_chunk_id].append(input_obj)
         self.output_tensors[model_chunk_id].append(output_obj)
@@ -681,7 +676,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         """
         it = 0
         # while we still have schedules_node in self.schedules
-        # print(f"manger_stage {self.stage_manager.stage} schedule {self.schedules} \n")
         while it < len(self.schedules):
             scheduled_node = self.schedules[it]
             print(
diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
index 74fa3358f..15897f73d 100644
--- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
+++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
@@ -1,6 +1,7 @@
 from copy import deepcopy
 from typing import Tuple
 
+import pytest
 import torch
 import torch.distributed as dist
 import torch.nn as nn
@@ -139,7 +140,7 @@ def test_run_fwd_bwd_base(
     ]
 
     scheduler = ZeroBubbleVPipeScheduler(
-        schedule=zbv_schedule[rank],
+        schedule=zbv_schedule[rank],  # hint: send whole schedule or local schedule only ?
         stage_manager=stage_manager,
         num_model_chunks=pp_size,
         num_microbatch=1,
@@ -226,7 +227,6 @@ def test_run_fwd_bwd_base(
         # layer 6
         assert_close(local_chunk[1].weight, model_base.layers[6].weight)
         assert_close(local_chunk[1].weight.grad, model_base.layers[6].weight.grad)
-
     if rank == 2:
         # layer 2
         assert_close(local_chunk[0].weight, model_base.layers[2].weight)
@@ -234,7 +234,6 @@ def test_run_fwd_bwd_base(
         # layer 5
         assert_close(local_chunk[1].weight, model_base.layers[5].weight)
         assert_close(local_chunk[1].weight.grad, model_base.layers[5].weight.grad)
-
    if rank == 3:
         # layer 3
         assert_close(local_chunk[0].weight, model_base.layers[3].weight)
@@ -244,7 +243,16 @@ def test_run_fwd_bwd_base(
         assert_close(local_chunk[1].weight.grad, model_base.layers[4].weight.grad)
 
 
-# @pytest.mark.dist
+# Test iter input & multiple microbatch
+def test_run_fwd_bwd_iter_input(
+    rank: int,
+    world_size: int,
+    port: int,
+):
+    pass
+
+
+@pytest.mark.dist
 # @pytest.mark.parametrize("num_microbatch", [4])
 # @pytest.mark.parametrize("batch_size", [4])
 # @pytest.mark.parametrize("num_model_chunk", [2])