Mirror of https://github.com/hpcaitech/ColossalAI.git
fix flash attn (#5209)
@@ -130,7 +130,7 @@ class LlamaPolicy(Policy):
         if self.shard_config.enable_flash_attention:
             self.append_or_create_method_replacement(
                 description={
-                    "forward": get_llama_flash_attention_forward(),
+                    "forward": get_llama_flash_attention_forward(self.shard_config),
                 },
                 policy=policy,
                 target_key=LlamaAttention,
@@ -250,6 +250,8 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
 
         policy = super().module_policy()
 
+        setattr(self.shard_config, "causal_lm", True)
+
         if self.shard_config.enable_tensor_parallelism:
             # add a new item for casual lm
             new_item = {
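For context, the first hunk passes the shard config into the forward factory, so the replacement forward can read settings such as the causal_lm flag set in the second hunk. Below is a minimal, self-contained sketch of that pattern; ShardConfig, DummyAttention, and get_flash_attention_forward here are simplified stand-ins for illustration and are not ColossalAI's actual classes or signatures.

# Sketch of a config-aware forward factory plus method replacement.
# All names below are hypothetical stand-ins, not ColossalAI internals.
from dataclasses import dataclass
from types import MethodType


@dataclass
class ShardConfig:
    enable_flash_attention: bool = True
    causal_lm: bool = False  # set by the CausalLM policy, as in the second hunk


def get_flash_attention_forward(shard_config: ShardConfig):
    """Return a forward function whose behavior depends on the shard config."""

    def forward(self, hidden_states):
        # A real implementation would call a flash-attention kernel here;
        # the point is that shard_config (e.g. causal_lm) is visible inside forward.
        mode = "causal" if shard_config.causal_lm else "bidirectional"
        return f"flash-attn forward ({mode}) on {hidden_states}"

    return forward


class DummyAttention:
    def forward(self, hidden_states):
        return f"eager forward on {hidden_states}"


if __name__ == "__main__":
    cfg = ShardConfig(enable_flash_attention=True, causal_lm=True)
    attn = DummyAttention()
    if cfg.enable_flash_attention:
        # Analogous in spirit to append_or_create_method_replacement(target_key=LlamaAttention, ...)
        attn.forward = MethodType(get_flash_attention_forward(cfg), attn)
    print(attn.forward("x"))  # -> flash-attn forward (causal) on x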