From 422544222fad45990a99fbd96617062cbd6b542a Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Tue, 27 Jun 2023 16:17:01 +0800 Subject: [PATCH] [pipeline] add stage manager (#4093) * [pipeline] add stage manager * [test] add pipeline stage manager test * [pipeline] add docstring for stage manager --- colossalai/pipeline/stage_manager.py | 176 ++++++++++++++++++++++ tests/test_pipeline/test_stage_manager.py | 86 +++++++++++ 2 files changed, 262 insertions(+) create mode 100644 colossalai/pipeline/stage_manager.py create mode 100644 tests/test_pipeline/test_stage_manager.py diff --git a/colossalai/pipeline/stage_manager.py b/colossalai/pipeline/stage_manager.py new file mode 100644 index 000000000..fe228e227 --- /dev/null +++ b/colossalai/pipeline/stage_manager.py @@ -0,0 +1,176 @@ +from contextlib import contextmanager +from typing import Dict, List, Optional, Tuple + +import torch.distributed as dist +from torch.distributed import ProcessGroup + +from colossalai.cluster import ProcessGroupMesh + + +class PipelineStageManager: + """PipelineStageManager is a helper class to manage pipeline stages. + + Args: + pg_mesh (ProcessGroupMesh): Process group mesh. + pipeline_axis (int): The axis along which the pipeline is constructed. + + Attributes: + num_stages (int): Number of stages in the pipeline. + stage (int): The current stage. + num_virtual_stages (int): Number of virtual stages in the pipeline. + virtual_stage (int): The current virtual stage. + """ + + def __init__(self, pg_mesh: ProcessGroupMesh, pipeline_axis: int) -> None: + self.pg_mesh = pg_mesh + self.pipeline_axis = pipeline_axis + self.num_virtual_stages: Optional[int] = None + self.virtual_stage: Optional[int] = None + self.prev_rank: Optional[Tuple[int, ...]] = None + self.next_rank: Optional[Tuple[int, ...]] = None + self.p2p_groups: Dict[Tuple[int, int], ProcessGroup] = {} + # init prev and next coord + coord = self.pg_mesh.coordinate() + if self.stage > 0: + prev_coord = coord[: self.pipeline_axis] + \ + (coord[self.pipeline_axis] - 1,) + coord[self.pipeline_axis + 1:] + self.prev_rank = self.pg_mesh.ravel(prev_coord, self.pg_mesh.shape) + if self.stage < self.num_stages - 1: + next_coord = coord[: self.pipeline_axis] + \ + (coord[self.pipeline_axis] + 1,) + coord[self.pipeline_axis + 1:] + self.next_rank = self.pg_mesh.ravel(next_coord, self.pg_mesh.shape) + + # init p2p process groups + stages = list(range(self.num_stages)) + for prev, cur in zip(stages[:-1], stages[1:]): + group = self.pg_mesh.get_group_along_axis(self.pipeline_axis, [prev, cur]) + if self.stage in [prev, cur]: + ranks_in_group = self.pg_mesh.get_ranks_in_group(group) + self.p2p_groups[tuple(ranks_in_group)] = group + + def is_first_stage(self, virtual: bool = False) -> bool: + """Is the current stage the first stage. + + Args: + virtual (bool, optional): Whether to consider virtual stages. Defaults to False. + + Returns: + bool: Whether the current stage is the first stage. + """ + if virtual: + assert self.num_virtual_stages is not None + return self.virtual_stage == 0 + return self.stage == 0 + + def is_last_stage(self, virtual: bool = False) -> bool: + """Is the current stage the last stage. + + Args: + virtual (bool, optional): Whether to consider virtual stages. Defaults to False. + + Returns: + bool: Whether the current stage is the last stage. + """ + if virtual: + assert self.num_virtual_stages is not None + return self.virtual_stage == self.num_virtual_stages - 1 + return self.stage == self.num_stages - 1 + + @property + def num_stages(self) -> int: + """Number of stages in the pipeline. + + Returns: + int: Number of stages in the pipeline. + """ + return self.pg_mesh.size(self.pipeline_axis) + + @property + def stage(self) -> int: + """Current stage. + + Returns: + int: Current stage. + """ + return self.pg_mesh.coordinate(self.pipeline_axis) + + def get_rank(self) -> int: + """Get the rank of the current process. + + Returns: + int: Rank of the current process. + """ + return dist.get_rank() + + def get_prev_rank(self) -> int: + """Get the rank of the previous stage. + + Returns: + int: Rank of the previous stage. + """ + assert not self.is_first_stage(), "Cannot get previous rank in the first stage." + return self.prev_rank + + def get_next_rank(self) -> int: + """Get the rank of the next stage. + + Returns: + int: Rank of the next stage. + """ + assert not self.is_last_stage(), "Cannot get next rank in the last stage." + return self.next_rank + + def set_num_virtual_stages(self, num_virtual_stages: int) -> None: + """Set the number of virtual stages. + + Args: + num_virtual_stages (int): Number of virtual stages. + """ + self.num_virtual_stages = num_virtual_stages + + def set_virtual_stage(self, virtual_stage: int) -> None: + """Set the virtual stage. + + Args: + virtual_stage (int): Virtual stage. + """ + self.virtual_stage = virtual_stage + + @contextmanager + def switch_virtual_stage(self, virtual_stage: int) -> None: + """A context manager to switch virtual stage. + + Args: + virtual_stage (int): Target virtual stage. + """ + old_stage = self.virtual_stage + try: + self.set_virtual_stage(virtual_stage) + yield + finally: + self.set_virtual_stage(old_stage) + + def get_p2p_process_group(self, first_rank: int, second_rank: int) -> ProcessGroup: + """Get the p2p process group between two ranks. The order of the two ranks does not matter. + + Args: + first_rank (int): The first rank. + second_rank (int): The second rank. + + Returns: + ProcessGroup: P2P process group between the two ranks. + """ + if first_rank > second_rank: + first_rank, second_rank = second_rank, first_rank + return self.p2p_groups[(first_rank, second_rank)] + + def init_process_group_by_stages(self, stages: List[int]) -> ProcessGroup: + """Get the process group of the given stages. + + Args: + stages (List[int]): List of stages. + + Returns: + ProcessGroup: Process group of the given stages. + """ + return self.pg_mesh.get_group_along_axis(self.pipeline_axis, stages) diff --git a/tests/test_pipeline/test_stage_manager.py b/tests/test_pipeline/test_stage_manager.py new file mode 100644 index 000000000..b920f88db --- /dev/null +++ b/tests/test_pipeline/test_stage_manager.py @@ -0,0 +1,86 @@ +import pytest +import torch.distributed as dist + +import colossalai +from colossalai.cluster import ProcessGroupMesh +from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.testing import spawn + + +def check_stage_manager(): + DP_DIM, PP_DIM = 0, 1 + DP_SIZE, PP_SIZE = 2, 2 + RANK_TO_COORDINATE = { + 0: (0, 0), + 1: (0, 1), + 2: (1, 0), + 3: (1, 1), + } + PP_RANKS_IN_GROUP = { + 0: [0, 1], + 1: [0, 1], + 2: [2, 3], + 3: [2, 3], + } + pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE) + stage_manager = PipelineStageManager(pg_mesh, PP_DIM) + rank = dist.get_rank() + + # check stage info + assert stage_manager.num_stages == PP_SIZE + assert stage_manager.stage == RANK_TO_COORDINATE[rank][PP_DIM] + + # check is_first_stage + ranks_in_group = PP_RANKS_IN_GROUP[rank] + is_first_stage = ranks_in_group.index(rank) == 0 + assert stage_manager.is_first_stage() == is_first_stage + + # check is_last_stage + is_last_stage = ranks_in_group.index(rank) == len(ranks_in_group) - 1 + assert stage_manager.is_last_stage() == is_last_stage + + # check prev rank + if not is_first_stage: + prev_rank = ranks_in_group[ranks_in_group.index(rank) - 1] + assert stage_manager.get_prev_rank() == prev_rank + + # check next rank + if not is_last_stage: + next_rank = ranks_in_group[ranks_in_group.index(rank) + 1] + assert stage_manager.get_next_rank() == next_rank + + # check virtual stage + stage_manager.set_num_virtual_stages(PP_SIZE * 2) + assert stage_manager.num_virtual_stages == PP_SIZE * 2 + stage_manager.set_virtual_stage(stage_manager.stage * 2) + assert stage_manager.virtual_stage == stage_manager.stage * 2 + with stage_manager.switch_virtual_stage(stage_manager.stage * 2 + 1): + assert stage_manager.virtual_stage == stage_manager.stage * 2 + 1 + assert stage_manager.virtual_stage == stage_manager.stage * 2 + + # check p2p groups + for prev, cur in zip(ranks_in_group[:-1], ranks_in_group[1:]): + if rank in [prev, cur]: + group = stage_manager.get_p2p_process_group(prev, cur) + dist.barrier(group=group) + + # check stage groups + pg_mesh = ProcessGroupMesh(4) + stage_manager = PipelineStageManager(pg_mesh, 0) + group = stage_manager.init_process_group_by_stages([0, 2]) + if rank in [0, 2]: + dist.barrier(group=group) + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost') + check_stage_manager() + + +@pytest.mark.dist +def test_process_group_mesh(): + spawn(run_dist, 4) + + +if __name__ == '__main__': + test_process_group_mesh()