[shardformer] refactor pipeline grad ckpt config (#5646)

* [shardformer] refactor pipeline grad ckpt config

* [shardformer] refactor pipeline grad ckpt config

* [pipeline] fix stage manager
This commit is contained in:
Hongxin Liu
2024-04-25 15:19:30 +08:00
committed by GitHub
parent 7ef91606e1
commit 1b387ca9fe
11 changed files with 59 additions and 102 deletions

View File

@@ -88,16 +88,15 @@ def main():
     pass
     # ckpt config for LLaMA3-70B on 64 H100 GPUs
-    ckpt_config = (
-        PipelineGradientCheckpointConfig(
-            num_stages=args.pp,
-            num_model_chunks=1,
-            num_model_layers=80,
-            num_layers_per_stage=[19, 20, 20, 21],
-            num_ckpt_layers_per_stage=[19, 19, 19, 13],
-        )
+    hybrid_kwargs = (
+        {
+            "gradient_checkpoint_config": PipelineGradientCheckpointConfig(
+                num_ckpt_layers_per_stage=[19, 19, 19, 13],
+            ),
+            "num_layers_per_stage": [19, 20, 20, 21],
+        }
         if args.custom_ckpt
-        else None
+        else {}
     )
     # ==============================
# ==============================
@@ -173,7 +172,7 @@ def main():
             microbatch_size=args.mbs,
             precision="bf16",
             dp_outside=False,
-            gradient_checkpoint_config=ckpt_config,
+            **hybrid_kwargs,
         )
elif args.plugin == "3d_cpu":
plugin = HybridParallelPlugin(