mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-01 01:06:00 +00:00
[shardformer] refactor pipeline grad ckpt config (#5646)
* [shardformer] refactor pipeline grad ckpt config
* [shardformer] refactor pipeline grad ckpt config
* [pipeline] fix stage manager
This commit is contained in:
@@ -88,16 +88,15 @@ def main():
         pass

     # ckpt config for LLaMA3-70B on 64 H100 GPUs
-    ckpt_config = (
-        PipelineGradientCheckpointConfig(
-            num_stages=args.pp,
-            num_model_chunks=1,
-            num_model_layers=80,
-            num_layers_per_stage=[19, 20, 20, 21],
-            num_ckpt_layers_per_stage=[19, 19, 19, 13],
-        )
+    hybrid_kwargs = (
+        {
+            "gradient_checkpoint_config": PipelineGradientCheckpointConfig(
+                num_ckpt_layers_per_stage=[19, 19, 19, 13],
+            ),
+            "num_layers_per_stage": [19, 20, 20, 21],
+        }
         if args.custom_ckpt
-        else None
+        else {}
     )

     # ==============================
@@ -173,7 +172,7 @@ def main():
             microbatch_size=args.mbs,
             precision="bf16",
             dp_outside=False,
-            gradient_checkpoint_config=ckpt_config,
+            **hybrid_kwargs,
         )
     elif args.plugin == "3d_cpu":
         plugin = HybridParallelPlugin(
Reference in New Issue
Block a user