[shardformer] refactor pipeline grad ckpt config (#5646)

* [shardformer] refactor pipeline grad ckpt config

* [shardformer] refactor pipeline grad ckpt config

* [pipeline] fix stage manager
This commit is contained in:
Hongxin Liu
2024-04-25 15:19:30 +08:00
committed by GitHub
parent 7ef91606e1
commit 1b387ca9fe
11 changed files with 59 additions and 102 deletions

View File

@@ -983,6 +983,7 @@ class HybridParallelPlugin(PipelinePluginBase):
custom_policy: Policy = None,
pp_style: str = "1f1b",
num_model_chunks: int = 1,
num_layers_per_stage: Optional[List[int]] = None,
gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None,
enable_metadata_cache: bool = True,
make_vocab_size_divisible_by: int = 64,
@@ -1056,6 +1057,7 @@ class HybridParallelPlugin(PipelinePluginBase):
pipeline_axis=self.pp_axis,
enable_interleave=pp_style == "interleaved",
num_model_chunks=num_model_chunks,
num_layers_per_stage=num_layers_per_stage,
)
if pp_style == "interleaved":