diff --git a/colossalai/shardformer/policies/qwen2.py b/colossalai/shardformer/policies/qwen2.py index 0c6d43524..d2798a23e 100644 --- a/colossalai/shardformer/policies/qwen2.py +++ b/colossalai/shardformer/policies/qwen2.py @@ -308,9 +308,11 @@ class Qwen2Policy(Policy): "self_attn.num_heads": self.model.config.num_attention_heads, } if getattr(self.model.config, "num_key_value_heads", False): - decoder_attribute_replacement["self_attn.num_key_value_heads"] = self.model.config.num_key_value_heads + decoder_attribute_replacement["self_attn.num_key_value_heads"] = ( + self.model.config.num_key_value_heads + ) policy[Qwen2DecoderLayer] = ModulePolicyDescription(attribute_replacement=decoder_attribute_replacement) - + self.append_or_create_method_replacement( description={ "forward": get_qwen2_flash_attention_npu_forward(self.shard_config, sp_mode, sp_size, sp_group),