From bd8b57156bb4de3fb40ff7edb06fac72757825a7 Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Wed, 25 Jun 2025 09:48:50 +0800
Subject: [PATCH] fix pp

---
 colossalai/shardformer/policies/qwen2.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/colossalai/shardformer/policies/qwen2.py b/colossalai/shardformer/policies/qwen2.py
index 823527df6..0c6d43524 100644
--- a/colossalai/shardformer/policies/qwen2.py
+++ b/colossalai/shardformer/policies/qwen2.py
@@ -302,6 +302,15 @@ class Qwen2Policy(Policy):
             )
 
         if self.shard_config.enable_flash_attention or self.shard_config.enable_sequence_parallelism:
+            if not self.shard_config.enable_tensor_parallelism:
+                decoder_attribute_replacement = {
+                    "self_attn.hidden_size": self.model.config.hidden_size,
+                    "self_attn.num_heads": self.model.config.num_attention_heads,
+                }
+                if getattr(self.model.config, "num_key_value_heads", False):
+                    decoder_attribute_replacement["self_attn.num_key_value_heads"] = self.model.config.num_key_value_heads
+                policy[Qwen2DecoderLayer] = ModulePolicyDescription(attribute_replacement=decoder_attribute_replacement)
+
             self.append_or_create_method_replacement(
                 description={
                     "forward": get_qwen2_flash_attention_npu_forward(self.shard_config, sp_mode, sp_size, sp_group),
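
For context, a minimal standalone sketch of the mechanism this hunk relies on:
ModulePolicyDescription.attribute_replacement maps dotted attribute paths on each
matched Qwen2DecoderLayer to new values, which shardformer applies via setattr when
the model is sharded. The mock below illustrates the effect of the new branch when
tensor parallelism is disabled; DummyAttention, DummyLayer, and
apply_attribute_replacement are hypothetical stand-ins for shardformer internals,
and the numeric values are placeholders, not real Qwen2 sizes.

    # Hypothetical mock of shardformer's attribute_replacement handling;
    # not a ColossalAI API, just an illustration of the semantics.
    from dataclasses import dataclass, field
    from typing import Any, Dict

    @dataclass
    class DummyAttention:
        hidden_size: int = 0
        num_heads: int = 0
        num_key_value_heads: int = 0

    @dataclass
    class DummyLayer:
        self_attn: DummyAttention = field(default_factory=DummyAttention)

    def apply_attribute_replacement(module: Any, replacement: Dict[str, Any]) -> None:
        # Resolve each dotted key against `module` and setattr the final
        # segment, mirroring how attribute_replacement entries are consumed.
        for dotted_key, value in replacement.items():
            *parents, leaf = dotted_key.split(".")
            target = module
            for name in parents:
                target = getattr(target, name)
            setattr(target, leaf, value)

    # Placeholder values standing in for self.model.config fields.
    replacement = {
        "self_attn.hidden_size": 3584,
        "self_attn.num_heads": 28,
        "self_attn.num_key_value_heads": 4,  # only added when the config defines it
    }

    layer = DummyLayer()
    apply_attribute_replacement(layer, replacement)
    assert layer.self_attn.num_key_value_heads == 4

As far as the diff shows, the intent is to keep the no-tensor-parallel path
consistent with the tensor-parallel one, so the replacement flash-attention
forward can presumably read self_attn.num_heads and self_attn.num_key_value_heads
regardless of the parallelism configuration.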