mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 12:30:42 +00:00
[shardformer] refactor pipeline grad ckpt config (#5646)
* [shardformer] refactor pipeline grad ckpt config * [shardformer] refactor pipeline grad ckpt config * [pipeline] fix stage manager
This commit is contained in:
@@ -217,9 +217,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
||||
"use_lazy_init": False,
|
||||
"precision": "fp32",
|
||||
"enable_gradient_checkpointing": True,
|
||||
"gradient_checkpoint_config": PipelineGradientCheckpointConfig(
|
||||
num_stages=2, num_model_chunks=1, num_model_layers=8, num_ckpt_layers_per_stage=[4, 0]
|
||||
),
|
||||
"gradient_checkpoint_config": PipelineGradientCheckpointConfig(num_ckpt_layers_per_stage=[4, 0]),
|
||||
},
|
||||
{
|
||||
"tp_size": 4,
|
||||
@@ -303,9 +301,6 @@ def run_llama_test(test_config):
|
||||
"initial_scale": 1,
|
||||
"enable_gradient_checkpointing": True,
|
||||
"gradient_checkpoint_config": PipelineGradientCheckpointConfig(
|
||||
num_stages=2,
|
||||
num_model_chunks=2,
|
||||
num_model_layers=8,
|
||||
num_ckpt_layers_per_stage=[0, 1, 2, 2],
|
||||
),
|
||||
},
|
||||
|
Reference in New Issue
Block a user