From 03fa79a55c327d411ed1f7af7c3fc88007708d60 Mon Sep 17 00:00:00 2001
From: duanjunwen <935724073@qq.com>
Date: Fri, 25 Oct 2024 10:17:06 +0000
Subject: [PATCH] [fix] fix llama modeling policy;

Invert the pipeline_stage_manager check in LlamaPolicy so the pipeline
forward replacement is registered when no stage manager is set, and
print each sub-model name while iterating the model zoo in
run_llama_test.
---
 colossalai/shardformer/policies/llama.py              | 3 ++-
 tests/test_shardformer/test_model/test_shard_llama.py | 1 +
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py
index 28ac2dc7f..bef39a6ca 100644
--- a/colossalai/shardformer/policies/llama.py
+++ b/colossalai/shardformer/policies/llama.py
@@ -96,7 +96,8 @@ class LlamaPolicy(Policy):
                 target_key=attn_cls,
             )
 
-        if self.pipeline_stage_manager is not None:
+        # if self.pipeline_stage_manager is not None:
+        if self.pipeline_stage_manager is None:
             self.append_or_create_method_replacement(
                 description={
                     "forward": partial(
diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py
index 33707a4f6..b43e45bcf 100644
--- a/tests/test_shardformer/test_model/test_shard_llama.py
+++ b/tests/test_shardformer/test_model/test_shard_llama.py
@@ -325,6 +325,7 @@ def run_llama_test(test_config):
     ).get_v_schedule()
     test_config["scheduler_nodes"] = scheduler_nodes
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        print(f"name {name}")
         if test_config.get("sequence_parallelism_mode", None) == "ring_attn" and "causal" not in name:
             continue
         try:
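
The effect of the flipped check, shown in isolation: a minimal,
self-contained sketch of the decision the first hunk changes. ToyPolicy,
fake_pipeline_forward, and the "LlamaModel" string key are hypothetical
stand-ins for LlamaPolicy, LlamaPipelineForwards.llama_model_forward, and
the real target class; only the append_or_create_method_replacement call
shape is taken from the hunk above. With the patch applied, the pipeline
forward is installed only when pipeline_stage_manager is None:

from functools import partial


def fake_pipeline_forward(stage_manager=None):
    # Stand-in for LlamaPipelineForwards.llama_model_forward.
    return f"pipeline forward, stage_manager={stage_manager!r}"


class ToyPolicy:
    # Stand-in for LlamaPolicy; only the flipped check is modeled.
    def __init__(self, pipeline_stage_manager=None):
        self.pipeline_stage_manager = pipeline_stage_manager
        self.replacements = {}

    def append_or_create_method_replacement(self, description, target_key):
        # Record which methods of target_key should be swapped out,
        # mirroring the call shape visible in the hunk above.
        self.replacements.setdefault(target_key, {}).update(description)

    def apply(self):
        # Patched condition: install the pipeline forward only when NO
        # stage manager is set (before the patch: only when one was set).
        if self.pipeline_stage_manager is None:
            self.append_or_create_method_replacement(
                description={
                    "forward": partial(
                        fake_pipeline_forward,
                        stage_manager=self.pipeline_stage_manager,
                    )
                },
                target_key="LlamaModel",
            )
        return self.replacements


print(ToyPolicy(pipeline_stage_manager=None).apply())  # forward replaced
print(ToyPolicy(pipeline_stage_manager="pp").apply())  # {} (left untouched)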