[pipeline] add stage manager (#4093)

* [pipeline] add stage manager

* [test] add pipeline stage manager test

* [pipeline] add docstring for stage manager
This commit is contained in:
Hongxin Liu
2023-06-27 16:17:01 +08:00
parent 5e1a9d48dd
commit 422544222f
2 changed files with 262 additions and 0 deletions

View File

@@ -0,0 +1,86 @@
import pytest
import torch.distributed as dist
import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.testing import spawn
def check_stage_manager():
    """Compare ``PipelineStageManager`` against hand-written expectations.

    Meant to run on every rank of a 4-process job: the lookup tables below only
    cover ranks 0-3, and the function calls ``dist.get_rank()`` and
    ``dist.barrier``, so the default process group has to exist already.

    Checks performed, in order: stage count and stage index, the first/last
    stage predicates, previous/next rank lookup, virtual-stage bookkeeping,
    p2p process groups, and a process group built from a list of stages.
    """
    # Axis 0 of the mesh has dp_size entries and axis 1 has pp_size entries.
    # Only pp_axis is read below; dp_axis just names the other axis.
    dp_axis, pp_axis = 0, 1
    dp_size, pp_size = 2, 2

    # Expected mesh coordinate of every global rank, as (axis 0, axis 1).
    coordinate_of = {
        0: (0, 0),
        1: (0, 1),
        2: (1, 0),
        3: (1, 1),
    }
    # Expected global ranks of the pipeline each rank belongs to; the list
    # position is used below as the expected stage position.
    pipeline_peers_of = {
        0: [0, 1],
        1: [0, 1],
        2: [2, 3],
        3: [2, 3],
    }

    mesh = ProcessGroupMesh(dp_size, pp_size)
    manager = PipelineStageManager(mesh, pp_axis)
    me = dist.get_rank()

    # Stage count and this rank's stage index.
    assert manager.num_stages == pp_size
    assert manager.stage == coordinate_of[me][pp_axis]

    # Where this rank sits inside its pipeline decides every boundary
    # expectation that follows.
    peers = pipeline_peers_of[me]
    position = peers.index(me)
    expect_first = position == 0
    expect_last = position == len(peers) - 1
    assert manager.is_first_stage() == expect_first
    assert manager.is_last_stage() == expect_last

    # Neighbour lookups are only checked where a neighbour is expected.
    if not expect_first:
        assert manager.get_prev_rank() == peers[position - 1]
    if not expect_last:
        assert manager.get_next_rank() == peers[position + 1]

    # Virtual stages: twice as many as real stages, with this rank selecting
    # the even one, temporarily switching to the odd one, then back.
    manager.set_num_virtual_stages(pp_size * 2)
    assert manager.num_virtual_stages == pp_size * 2
    manager.set_virtual_stage(manager.stage * 2)
    assert manager.virtual_stage == manager.stage * 2
    with manager.switch_virtual_stage(manager.stage * 2 + 1):
        assert manager.virtual_stage == manager.stage * 2 + 1
    # After the context exits the earlier selection is expected to be back.
    assert manager.virtual_stage == manager.stage * 2

    # P2P groups: for each adjacent pair in this pipeline, the two ranks of the
    # pair synchronise on the pair's group.
    for prev, cur in zip(peers, peers[1:]):
        if me not in (prev, cur):
            continue
        pg = manager.get_p2p_process_group(prev, cur)
        dist.barrier(group=pg)

    # Stage groups on a one-axis mesh of size 4. The names are rebound on
    # purpose, so the 2 x 2 mesh and manager stop being referenced from here.
    mesh = ProcessGroupMesh(4)
    manager = PipelineStageManager(mesh, 0)
    # Every rank makes this call; only ranks 0 and 2 synchronise on the result.
    pg = manager.init_process_group_by_stages([0, 2])
    if me in (0, 2):
        dist.barrier(group=pg)
def run_dist(rank, world_size, port):
    """Launch colossalai for one process, then run the stage-manager checks.

    Args:
        rank: forwarded to ``colossalai.launch`` as ``rank``.
        world_size: forwarded to ``colossalai.launch`` as ``world_size``.
        port: forwarded to ``colossalai.launch`` as ``port``; the host is
            always ``localhost``.
    """
    launch_kwargs = dict(
        config={},
        rank=rank,
        world_size=world_size,
        port=port,
        host='localhost',
    )
    colossalai.launch(**launch_kwargs)
    check_stage_manager()
@pytest.mark.dist
def test_process_group_mesh():
    """Drive ``run_dist`` through ``spawn`` with a count of 4.

    NOTE(review): the name mentions the process group mesh, but ``run_dist``
    runs ``check_stage_manager`` - consider renaming once callers are checked.
    """
    # check_stage_manager's expectation tables cover ranks 0-3.
    nprocs = 4
    spawn(run_dist, nprocs)
# Running this file directly as a script calls the test function itself,
# without going through the pytest runner.
if __name__ == '__main__':
    test_process_group_mesh()