[shardformer] refactor pipeline grad ckpt config (#5646)

* [shardformer] refactor pipeline grad ckpt config

* [shardformer] refactor pipeline grad ckpt config

* [pipeline] fix stage manager
This commit is contained in:
Hongxin Liu
2024-04-25 15:19:30 +08:00
committed by GitHub
parent 7ef91606e1
commit 1b387ca9fe
11 changed files with 59 additions and 102 deletions

View File

@@ -27,16 +27,18 @@ class PipelineStageManager:
pipeline_axis: int,
enable_interleave: bool = False,
num_model_chunks: int = 1,
num_layers_per_stage: Optional[List[int]] = None,
) -> None:
assert enable_interleave or num_model_chunks == 1, "num_model_chunks must be 1 when enable_interleave is False"
self.num_layers_per_stage = None
self.pg_mesh = pg_mesh
self.pipeline_axis = pipeline_axis
self.prev_rank: Optional[Tuple[int, ...]] = None
self.next_rank: Optional[Tuple[int, ...]] = None
self.p2p_groups: Dict[Tuple[int, int], ProcessGroup] = {}
if num_layers_per_stage is not None:
assert len(num_layers_per_stage) == self.num_stages
self.num_layers_per_stage = num_layers_per_stage
# init prev and next coord
coord = self.pg_mesh.coordinate()
@@ -56,6 +58,8 @@ class PipelineStageManager:
self.p2p_groups[tuple(ranks_in_group)] = group
self.is_interleave = enable_interleave
# for interleaved pipeline parallel, each device is responsible for multiple chunks of layers
self.num_model_chunks: int = num_model_chunks
if enable_interleave:
# use circle p2p communication
# add the process group of the first rank and the last rank
@@ -64,59 +68,11 @@ class PipelineStageManager:
ranks_in_group = self.pg_mesh.get_ranks_in_group(group)
self.p2p_groups[tuple(ranks_in_group)] = group
# for interleaved pipeline parallel, each device is responsible for multiple chunks of layers
self.num_model_chunks: int = num_model_chunks
# for shardformer, hold stage indices of model
self.stage_indices: List[Tuple[int, int]]
# for shardformer, hold model chunk id
self.model_chunk_id: Optional[int] = None
@property
def control_distribute_layers(self) -> bool:
    """Report whether an explicit per-stage layer split has been stored.

    Returns:
        bool: ``True`` when ``self.num_layers_per_stage`` is not ``None``,
        ``False`` otherwise.
    """
    custom_split = self.num_layers_per_stage
    return custom_split is not None
def set_distribution_config(self, num_model_layers: int, num_layers_per_stage: List[int]) -> None:
    """Set the distribution configuration.

    This allows the user to customize the number of layers for each stage.

    Args:
        num_model_layers (int): Number of layers in the model.
        num_layers_per_stage (List[int]): Number of layers for each stage. It must
            hold ``num_stages`` entries, or ``num_stages * num_model_chunks`` entries
            when interleaving is enabled.

    Raises:
        AssertionError: If the list has the wrong length, contains a non-positive
            entry or an entry larger than ``num_model_layers``, or does not sum to
            ``num_model_layers``.
    """
    chunks_per_stage = self.num_model_chunks if self.is_interleave else 1
    expected_entries = self.num_stages * chunks_per_stage
    assert (
        len(num_layers_per_stage) == expected_entries
    ), f"num_layers_per_stage must have {expected_entries} entries, got {len(num_layers_per_stage)}"
    # The upper bound is inclusive: with a single stage and a single chunk the only
    # valid split is [num_model_layers], which a strict "<" would reject.
    assert all(
        0 < num_layers <= num_model_layers for num_layers in num_layers_per_stage
    ), "each entry of num_layers_per_stage must be in the range [1, num_model_layers]"
    assert (
        sum(num_layers_per_stage) == num_model_layers
    ), "num_layers_per_stage must sum to num_model_layers"
    self.num_model_layers = num_model_layers
    self.num_layers_per_stage = num_layers_per_stage
def distribute_layers(
    self, num_layers: int, num_stages: Optional[int] = None, num_model_chunks: Optional[int] = None
) -> List[int]:
    """Split ``num_layers`` into one layer count per stage (per model chunk).

    Args:
        num_layers (int): Total number of layers to distribute.
        num_stages (Optional[int]): Number of stages; falls back to ``self.num_stages``.
        num_model_chunks (Optional[int]): Number of model chunks; falls back to
            ``self.num_model_chunks`` when interleaving is enabled, otherwise to 1.

    Returns:
        List[int]: The stored ``self.num_layers_per_stage`` when
        ``self.control_distribute_layers`` is true. Otherwise a list of
        ``num_stages * num_model_chunks`` counts in which every slot holds the
        integer quotient and the leftover layers are added one each to a
        contiguous run of slots around the middle.
    """
    if num_stages is None:
        num_stages = self.num_stages
    if num_model_chunks is None:
        num_model_chunks = self.num_model_chunks if self.is_interleave else 1

    # A user-supplied split takes precedence over the even split below.
    if self.control_distribute_layers:
        assert num_layers == self.num_model_layers
        return self.num_layers_per_stage

    slot_count = num_stages * num_model_chunks
    base, leftover = divmod(num_layers, slot_count)
    layers_per_stage = [base] * num_stages * num_model_chunks
    # The leftover layers go to a window of slots centred on the middle of the list.
    window_start = slot_count // 2 - leftover // 2
    for slot in range(window_start, window_start + leftover):
        layers_per_stage[slot] += 1
    return layers_per_stage
def get_stage_index(
self,
layers_per_stage: List[int],
@@ -139,9 +95,7 @@ class PipelineStageManager:
"""
stage = self.stage if stage is None else stage
num_model_chunks = (
(self.num_model_chunks if self.is_interleave else 1) if num_model_chunks is None else num_model_chunks
)
num_model_chunks = self.num_model_chunks if num_model_chunks is None else num_model_chunks
num_stages = self.num_stages if num_stages is None else num_stages
num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0)
@@ -261,3 +215,25 @@ class PipelineStageManager:
self.model_chunk_id = model_chunk_id
yield
self.model_chunk_id = old_model_chunk_id
def distribute_layers(
    self, num_layers: int, num_stages: Optional[int] = None, num_model_chunks: Optional[int] = None
) -> List[int]:
    """Split ``num_layers`` into one layer count per stage (per model chunk).

    Args:
        num_layers (int): Total number of layers to distribute.
        num_stages (Optional[int]): Number of stages; falls back to ``self.num_stages``.
        num_model_chunks (Optional[int]): Number of model chunks; falls back to
            ``self.num_model_chunks``.

    Returns:
        List[int]: ``self.num_layers_per_stage`` when it is set (it must sum to
        ``num_layers``). Otherwise a list of ``num_stages * num_model_chunks``
        counts in which every slot holds the integer quotient and the leftover
        layers are added one each to a contiguous run of slots around the middle.
    """
    preset = self.num_layers_per_stage
    if preset is not None:
        # A configured split is returned as-is; the explicit arguments are not consulted.
        assert sum(self.num_layers_per_stage) == num_layers
        return self.num_layers_per_stage

    if num_stages is None:
        num_stages = self.num_stages
    if num_model_chunks is None:
        num_model_chunks = self.num_model_chunks

    slot_count = num_stages * num_model_chunks
    base, leftover = divmod(num_layers, slot_count)
    layers_per_stage = [base] * num_stages * num_model_chunks
    # The leftover layers go to a window of slots centred on the middle of the list.
    window_start = slot_count // 2 - leftover // 2
    for slot in range(window_start, window_start + leftover):
        layers_per_stage[slot] += 1
    return layers_per_stage