fix flash attn (#5209)

flybird11111 authored 2024-01-03 14:39:53 +08:00, committed by GitHub
parent 365671be10
commit 451e9142b8
2 changed files with 6 additions and 5 deletions

View File

@@ -414,7 +414,7 @@ class LlamaPipelineForwards:
         return {"hidden_states": hidden_states}


-def get_llama_flash_attention_forward():
+def get_llama_flash_attention_forward(shard_config: ShardConfig):
     from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb

     from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
@@ -470,14 +470,13 @@ def get_llama_flash_attention_forward():
         flash_attention_mask = None
         attn_mask_type = AttnMaskType.causal

-        if attention_mask != None:
+        if not getattr(shard_config, "causal_lm", False) and attention_mask != None:
             if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                 raise ValueError(
                     f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                 )
             flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
-            if not torch.all(flash_attention_mask):
-                attn_mask_type = AttnMaskType.paddedcausal
+            attn_mask_type = AttnMaskType.paddedcausal

         attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads)
         attn_output = attention(
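
For readers skimming the diff, here is a minimal sketch of the mask-selection logic after this change, pulled out of the patched forward as a standalone helper. The helper name select_mask_type is invented for illustration; only AttnMaskType, the tensor shapes, and the gating expression come from the code above. Causal-LM policies set shard_config.causal_lm (see the policy diff below), so they now skip the padded-mask branch entirely, and whenever the branch is taken the mask type is always padded-causal.

import torch

from colossalai.kernel.cuda_native import AttnMaskType  # same import as in the patched forward


def select_mask_type(shard_config, attention_mask, bsz, q_len, kv_seq_len):
    # Hypothetical helper mirroring the patched forward, for illustration only.
    # Default: no explicit mask is handed to ColoAttention; the kernel runs purely causal.
    flash_attention_mask = None
    attn_mask_type = AttnMaskType.causal

    # Causal-LM shards have shard_config.causal_lm == True, so this branch is skipped
    # for them and the mask type stays causal.
    if not getattr(shard_config, "causal_lm", False) and attention_mask is not None:
        if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
            raise ValueError(
                f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, "
                f"but is {attention_mask.size()}"
            )
        # Collapse the 4-D HF mask to a (bsz, kv_seq_len) padding mask and mark it padded-causal.
        flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
        attn_mask_type = AttnMaskType.paddedcausal

    return flash_attention_mask, attn_mask_type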

View File

@@ -130,7 +130,7 @@ class LlamaPolicy(Policy):
         if self.shard_config.enable_flash_attention:
             self.append_or_create_method_replacement(
                 description={
-                    "forward": get_llama_flash_attention_forward(),
+                    "forward": get_llama_flash_attention_forward(self.shard_config),
                 },
                 policy=policy,
                 target_key=LlamaAttention,
@@ -250,6 +250,8 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
         policy = super().module_policy()

+        setattr(self.shard_config, "causal_lm", True)
+
         if self.shard_config.enable_tensor_parallelism:
             # add a new item for casual lm
             new_item = {
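
Taken together, the two policy edits are what make the new shard_config argument useful: LlamaPolicy now hands its shard_config to get_llama_flash_attention_forward, and LlamaForCausalLMPolicy tags that config with causal_lm = True before the forward is installed. Below is a small self-contained demonstration of the getattr-based gating; the _FakeShardConfig class is a stand-in invented for illustration and is not part of ColossalAI.

class _FakeShardConfig:
    # Stand-in for colossalai's ShardConfig, for illustration only.
    enable_flash_attention = True


cfg = _FakeShardConfig()

# Base LlamaPolicy path: the attribute was never set, so getattr falls back to False
# and the padded-mask branch in the attention forward is still reachable.
print(getattr(cfg, "causal_lm", False))   # False

# LlamaForCausalLMPolicy.module_policy() now does this before building its policy,
# so the attention forward always keeps AttnMaskType.causal for causal-LM models.
setattr(cfg, "causal_lm", True)
print(getattr(cfg, "causal_lm", False))   # True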