mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-28 05:49:27 +00:00
fix pp
This commit is contained in:
parent
7891cc7ddb
commit
bd8b57156b
@@ -302,6 +302,15 @@ class Qwen2Policy(Policy):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if self.shard_config.enable_flash_attention or self.shard_config.enable_sequence_parallelism:
|
if self.shard_config.enable_flash_attention or self.shard_config.enable_sequence_parallelism:
|
||||||
|
if not self.shard_config.enable_tensor_parallelism:
|
||||||
|
decoder_attribute_replacement = {
|
||||||
|
"self_attn.hidden_size": self.model.config.hidden_size,
|
||||||
|
"self_attn.num_heads": self.model.config.num_attention_heads,
|
||||||
|
}
|
||||||
|
if getattr(self.model.config, "num_key_value_heads", False):
|
||||||
|
decoder_attribute_replacement["self_attn.num_key_value_heads"] = self.model.config.num_key_value_heads
|
||||||
|
policy[Qwen2DecoderLayer] = ModulePolicyDescription(attribute_replacement=decoder_attribute_replacement)
|
||||||
|
|
||||||
self.append_or_create_method_replacement(
|
self.append_or_create_method_replacement(
|
||||||
description={
|
description={
|
||||||
"forward": get_qwen2_flash_attention_npu_forward(self.shard_config, sp_mode, sp_size, sp_group),
|
"forward": get_qwen2_flash_attention_npu_forward(self.shard_config, sp_mode, sp_size, sp_group),
|
||||||
|
Loading…
Reference in New Issue
Block a user