[pipeline] 1f1b schedule receive microbatch size (#4589)

This commit is contained in:
Hongxin Liu
2023-09-01 21:45:14 +08:00
committed by GitHub
parent 38ccb8b1a3
commit 508ca36fe3
3 changed files with 30 additions and 7 deletions

View File

@@ -61,7 +61,7 @@ def examine_pp():
DP_DIM, PP_DIM, TP_DIM = 0, 1, 2
pg_mesh = ProcessGroupMesh(1, world_size, 1)
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
schedule = OneForwardOneBackwardSchedule(NUM_MICRO_BATCHS, stage_manager)
schedule = OneForwardOneBackwardSchedule(stage_manager, num_microbatches=NUM_MICRO_BATCHS)
for idx, (_, sub_model) in enumerate(pp_model.named_children()):
if idx % (world_size) == local_rank: