From de3f67d128eaae2116fb7665d4d8333bd2e960aa Mon Sep 17 00:00:00 2001
From: Wang Binluo <32676639+wangbluo@users.noreply.github.com>
Date: Wed, 26 Jun 2024 10:15:13 +0800
Subject: [PATCH] fix llama (#5856)

---
 colossalai/shardformer/modeling/llama.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index 60bc8b711..5855dcc4f 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -150,10 +150,8 @@ class LlamaPipelineForwards:
             if shard_config.gradient_checkpoint_config is not None:
                 num_ckpt_layers = shard_config.gradient_checkpoint_config.get_num_ckpt_layers(
                     stage=stage_manager.stage,
-                    num_stages=stage_manager.num_stages,
                     num_layers=end_idx - start_idx,
-                    model_chunk_id=(stage_manager.model_chunk_id if stage_manager.is_interleave else 0),
-                    num_model_chunks=stage_manager.num_model_chunks,
+                    model_chunk_id=stage_manager.model_chunk_id if stage_manager.is_interleave else 0,
                 )
             assert num_ckpt_layers <= end_idx - start_idx
 
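For context: the patch drops the num_stages and num_model_chunks keyword
arguments, so the call site now matches a get_num_ckpt_layers signature
that takes only stage, num_layers, and model_chunk_id. Below is a minimal
sketch of a config object compatible with that call site. The class name
SketchGradientCheckpointConfig, its constructor, and its flat
(chunk, stage) indexing are hypothetical illustrations, not ColossalAI's
actual GradientCheckpointConfig; only the method signature mirrors the
diff.

    # Sketch only: names and indexing below are assumptions, not the
    # library's implementation.
    class SketchGradientCheckpointConfig:
        def __init__(self, num_ckpt_layers_per_stage, num_stages):
            # One checkpointing budget per (model_chunk_id, stage) pair,
            # flattened chunk-major, e.g. [2, 1, 2, 1] for 2 chunks x 2 stages.
            self.num_ckpt_layers_per_stage = num_ckpt_layers_per_stage
            self.num_stages = num_stages

        def get_num_ckpt_layers(self, stage: int, num_layers: int, model_chunk_id: int = 0) -> int:
            # The stage geometry lives on the config itself, which is why the
            # patched call site no longer passes num_stages/num_model_chunks.
            budget = self.num_ckpt_layers_per_stage[model_chunk_id * self.num_stages + stage]
            # Never checkpoint more layers than this stage actually holds,
            # consistent with `assert num_ckpt_layers <= end_idx - start_idx`.
            return min(budget, num_layers)

    # Mirrors the patched call: with an interleaved schedule the chunk id is
    # forwarded, otherwise 0 is passed.
    cfg = SketchGradientCheckpointConfig([2, 1, 2, 1], num_stages=2)
    assert cfg.get_num_ckpt_layers(stage=1, num_layers=4, model_chunk_id=1) == 1

Under these assumptions, the change reads as keeping pipeline geometry in
one place (the config) instead of passing it redundantly from the stage
manager at every call, so the two cannot drift out of sync.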