fix flash attn (#5209)

Author: flybird11111
Date: 2024-01-03 14:39:53 +08:00
Committed by: GitHub
Parent: 365671be10
Commit: 451e9142b8
2 changed files with 6 additions and 5 deletions

@@ -130,7 +130,7 @@ class LlamaPolicy(Policy):
         if self.shard_config.enable_flash_attention:
             self.append_or_create_method_replacement(
                 description={
-                    "forward": get_llama_flash_attention_forward(),
+                    "forward": get_llama_flash_attention_forward(self.shard_config),
                 },
                 policy=policy,
                 target_key=LlamaAttention,
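
The functional change in this hunk is that the policy now passes self.shard_config into get_llama_flash_attention_forward, so the forward it builds can consult sharding options instead of being constructed without them. Below is a minimal sketch of that factory pattern, assuming a hypothetical ShardConfig stand-in and PyTorch's scaled_dot_product_attention; it is an illustration, not ColossalAI's actual implementation.

# Minimal sketch (not ColossalAI's implementation) of the pattern this change
# enables: a factory closes over a shard config and returns a forward whose
# behaviour depends on config fields. ShardConfig here is a hypothetical
# stand-in; its field names are assumptions for illustration only.
from dataclasses import dataclass
from typing import Callable

import torch
import torch.nn.functional as F


@dataclass
class ShardConfig:
    enable_flash_attention: bool = True
    causal_lm: bool = False


def get_flash_attention_forward(shard_config: ShardConfig) -> Callable:
    """Build a forward whose masking behaviour is driven by the shard config."""

    def forward(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
        # A causal LM gets a causal mask; otherwise every position may attend everywhere.
        return F.scaled_dot_product_attention(
            query, key, value, is_causal=shard_config.causal_lm
        )

    return forward


# Usage: inputs are (batch, heads, seq_len, head_dim).
fwd = get_flash_attention_forward(ShardConfig(causal_lm=True))
q = k = v = torch.randn(1, 8, 16, 64)
print(fwd(q, k, v).shape)  # torch.Size([1, 8, 16, 64])

The returned callable is the kind of object the policy then registers for the attention class via the method replacement shown above.
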
@@ -250,6 +250,8 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
         policy = super().module_policy()
         setattr(self.shard_config, "causal_lm", True)
         if self.shard_config.enable_tensor_parallelism:
             # add a new item for casual lm
             new_item = {
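
For context, append_or_create_method_replacement registers a replacement forward for the target class (LlamaAttention in the first hunk), and the sharder later binds it onto matching modules. The sketch below shows one simple way such a swap can be carried out, using a hypothetical ToyAttention module and replace_forward helper rather than the real Policy API.

# Rough sketch of one way a forward method can be swapped onto a module class.
# ToyAttention and replace_forward are hypothetical names standing in for
# LlamaAttention and the policy/sharder machinery.
from typing import Callable, Type

import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyAttention(nn.Module):
    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        # Original eager-mode attention.
        scores = q @ k.transpose(-2, -1) / (q.size(-1) ** 0.5)
        return scores.softmax(dim=-1) @ v


def replace_forward(target_cls: Type[nn.Module], new_forward: Callable) -> None:
    # Rebind forward on the class itself; all instances now dispatch to new_forward.
    target_cls.forward = new_forward


def flash_forward(self, q, k, v):
    # Replacement forward that routes through PyTorch's fused SDPA kernel.
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)


replace_forward(ToyAttention, flash_forward)

attn = ToyAttention()
q = k = v = torch.randn(1, 4, 16, 32)
print(attn(q, k, v).shape)  # torch.Size([1, 4, 16, 32])

Whether the rebinding happens on the class or on each matching instance is an implementation detail of the sharder; the sketch shows the class-level variant for brevity.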