mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-12 12:47:21 +00:00
[shardformer] refactor pipeline grad ckpt config (#5646)
* [shardformer] refactor pipeline grad ckpt config
* [shardformer] refactor pipeline grad ckpt config
* [pipeline] fix stage manager
This commit is contained in:
@@ -27,16 +27,18 @@ class PipelineStageManager:
|
||||
pipeline_axis: int,
|
||||
enable_interleave: bool = False,
|
||||
num_model_chunks: int = 1,
|
||||
num_layers_per_stage: Optional[List[int]] = None,
|
||||
) -> None:
|
||||
assert enable_interleave or num_model_chunks == 1, "num_model_chunks must be 1 when enable_interleave is False"
|
||||
|
||||
self.num_layers_per_stage = None
|
||||
|
||||
self.pg_mesh = pg_mesh
|
||||
self.pipeline_axis = pipeline_axis
|
||||
self.prev_rank: Optional[Tuple[int, ...]] = None
|
||||
self.next_rank: Optional[Tuple[int, ...]] = None
|
||||
self.p2p_groups: Dict[Tuple[int, int], ProcessGroup] = {}
|
||||
if num_layers_per_stage is not None:
|
||||
assert len(num_layers_per_stage) == self.num_stages
|
||||
self.num_layers_per_stage = num_layers_per_stage
|
||||
|
||||
# init prev and next coord
|
||||
coord = self.pg_mesh.coordinate()
|
||||
@@ -56,6 +58,8 @@ class PipelineStageManager:
|
||||
self.p2p_groups[tuple(ranks_in_group)] = group
|
||||
|
||||
self.is_interleave = enable_interleave
|
||||
# for interleaved pipeline parallel, each device is responsible for multiple chunk of layers
|
||||
self.num_model_chunks: int = num_model_chunks
|
||||
if enable_interleave:
|
||||
# use circle p2p communication
|
||||
# add the process group of the first rank and the last rank
|
||||
@@ -64,59 +68,11 @@ class PipelineStageManager:
|
||||
ranks_in_group = self.pg_mesh.get_ranks_in_group(group)
|
||||
self.p2p_groups[tuple(ranks_in_group)] = group
|
||||
|
||||
# for interleaved pipeline parallel, each device is responsible for multiple chunk of layers
|
||||
self.num_model_chunks: int = num_model_chunks
|
||||
|
||||
# for shardformer, hold stage indices of model
|
||||
self.stage_indices: List[Tuple[int, int]]
|
||||
# for shardformer, hold model chunk id
|
||||
self.model_chunk_id: Optional[int] = None
|
||||
|
||||
@property
def control_distribute_layers(self) -> bool:
    """Whether a user-supplied per-stage layer layout is in effect.

    Returns:
        bool: True once ``num_layers_per_stage`` has been configured.
    """
    layout = self.num_layers_per_stage
    return layout is not None
|
||||
|
||||
def set_distribution_config(self, num_model_layers: int, num_layers_per_stage: List[int]) -> None:
    """Set the distribution configuration.
    This allows user to customize the number of layers for each stage.

    Args:
        num_model_layers (int): Number of layers in the model.
        num_layers_per_stage (List[int]): Number of layers for each stage.
    """
    # Use a generator expression instead of materializing a throwaway list in all() (ruff C419).
    # Every stage must own at least one layer and strictly fewer than the whole model.
    assert all(0 < num_layers < num_model_layers for num_layers in num_layers_per_stage)
    # The per-stage counts must partition the model exactly.
    assert sum(num_layers_per_stage) == num_model_layers
    # One entry per stage; with interleaving, one entry per (stage, model chunk) pair.
    assert len(num_layers_per_stage) == self.num_stages * (self.num_model_chunks if self.is_interleave else 1)
    self.num_model_layers = num_model_layers
    self.num_layers_per_stage = num_layers_per_stage
|
||||
|
||||
def distribute_layers(
    self, num_layers: int, num_stages: Optional[int] = None, num_model_chunks: Optional[int] = None
) -> List[int]:
    """Divide layers into stages"""
    # Resolve defaults from the manager's own configuration.
    if num_stages is None:
        num_stages = self.num_stages
    if num_model_chunks is None:
        num_model_chunks = self.num_model_chunks if self.is_interleave else 1

    if self.control_distribute_layers:
        # A user-provided layout takes precedence; it must match the model size.
        assert num_layers == self.num_model_layers
        return self.num_layers_per_stage

    # Spread layers as evenly as possible over every (stage, chunk) slot.
    num_slots = num_stages * num_model_chunks
    base, leftover = divmod(num_layers, num_slots)
    layout = [base for _ in range(num_slots)]

    # Give the leftover layers to a centered run of slots.
    if leftover > 0:
        first_extra = num_slots // 2 - leftover // 2
        for slot in range(first_extra, first_extra + leftover):
            layout[slot] += 1
    return layout
|
||||
|
||||
def get_stage_index(
|
||||
self,
|
||||
layers_per_stage: List[int],
|
||||
@@ -139,9 +95,7 @@ class PipelineStageManager:
|
||||
|
||||
"""
|
||||
stage = self.stage if stage is None else stage
|
||||
num_model_chunks = (
|
||||
(self.num_model_chunks if self.is_interleave else 1) if num_model_chunks is None else num_model_chunks
|
||||
)
|
||||
num_model_chunks = self.num_model_chunks if num_model_chunks is None else num_model_chunks
|
||||
num_stages = self.num_stages if num_stages is None else num_stages
|
||||
|
||||
num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0)
|
||||
@@ -261,3 +215,25 @@ class PipelineStageManager:
|
||||
self.model_chunk_id = model_chunk_id
|
||||
yield
|
||||
self.model_chunk_id = old_model_chunk_id
|
||||
|
||||
def distribute_layers(
    self, num_layers: int, num_stages: Optional[int] = None, num_model_chunks: Optional[int] = None
) -> List[int]:
    """Split ``num_layers`` across every (stage, chunk) slot, honoring a preset layout if one exists."""
    preset = self.num_layers_per_stage
    if preset is not None:
        # A configured layout wins; it must account for exactly the given layers.
        assert sum(preset) == num_layers
        return preset

    stages = self.num_stages if num_stages is None else num_stages
    chunks = self.num_model_chunks if num_model_chunks is None else num_model_chunks
    slots = stages * chunks

    # Even split first, then hand the remainder to a centered run of slots.
    per_slot, extra = divmod(num_layers, slots)
    result = [per_slot] * slots
    if extra:
        begin = slots // 2 - extra // 2
        for idx in range(begin, begin + extra):
            result[idx] += 1
    return result
|
||||
|
Reference in New Issue
Block a user