From 451e9142b8b8b77ed3138fb03ad54494c3c57126 Mon Sep 17 00:00:00 2001
From: flybird11111 <1829166702@qq.com>
Date: Wed, 3 Jan 2024 14:39:53 +0800
Subject: [PATCH] fix flash attn (#5209)

---
 colossalai/shardformer/modeling/llama.py | 7 +++----
 colossalai/shardformer/policies/llama.py | 4 +++-
 2 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index 286852899..1b53ce4af 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -414,7 +414,7 @@ class LlamaPipelineForwards:
             return {"hidden_states": hidden_states}
 
 
-def get_llama_flash_attention_forward():
+def get_llama_flash_attention_forward(shard_config: ShardConfig):
     from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
 
     from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
@@ -470,14 +470,13 @@ def get_llama_flash_attention_forward():
 
         flash_attention_mask = None
         attn_mask_type = AttnMaskType.causal
-        if attention_mask != None:
+        if not getattr(shard_config, "causal_lm", False) and attention_mask != None:
             if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                 raise ValueError(
                     f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                 )
             flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
-            if not torch.all(flash_attention_mask):
-                attn_mask_type = AttnMaskType.paddedcausal
+            attn_mask_type = AttnMaskType.paddedcausal
 
         attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads)
         attn_output = attention(
diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py
index 39a4d4023..1faa24f71 100644
--- a/colossalai/shardformer/policies/llama.py
+++ b/colossalai/shardformer/policies/llama.py
@@ -130,7 +130,7 @@ class LlamaPolicy(Policy):
         if self.shard_config.enable_flash_attention:
             self.append_or_create_method_replacement(
                 description={
-                    "forward": get_llama_flash_attention_forward(),
+                    "forward": get_llama_flash_attention_forward(self.shard_config),
                 },
                 policy=policy,
                 target_key=LlamaAttention,
@@ -250,6 +250,8 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
 
         policy = super().module_policy()
 
+        setattr(self.shard_config, "causal_lm", True)
+
         if self.shard_config.enable_tensor_parallelism:
             # add a new item for casual lm
             new_item = {
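
Note on the behaviour change (not part of the patch): LlamaForCausalLMPolicy now sets shard_config.causal_lm = True, so causal-LM models always take the plain AttnMaskType.causal path, while any other model that passes an attention mask now always gets AttnMaskType.paddedcausal; the old `if not torch.all(flash_attention_mask)` shortcut is removed. Below is a minimal, runnable sketch of that selection logic, assuming hypothetical stand-ins (pick_mask_type, the SimpleNamespace configs, and the toy AttnMaskType class) rather than ColossalAI's real APIs:

# Sketch only -- mirrors the patched branch in get_llama_flash_attention_forward.
from types import SimpleNamespace

import torch


class AttnMaskType:  # toy stand-in for colossalai.kernel.cuda_native.AttnMaskType
    causal = "causal"
    paddedcausal = "paddedcausal"


def pick_mask_type(shard_config, attention_mask):
    # Causal-LM policies set shard_config.causal_lm = True, so they skip the
    # custom mask entirely; other models that supply a padding mask always get
    # paddedcausal now that the `torch.all` check is gone.
    if not getattr(shard_config, "causal_lm", False) and attention_mask is not None:
        return AttnMaskType.paddedcausal
    return AttnMaskType.causal


mask = torch.ones(2, 1, 8, 8, dtype=torch.bool)
assert pick_mask_type(SimpleNamespace(causal_lm=True), mask) == AttnMaskType.causal
assert pick_mask_type(SimpleNamespace(), mask) == AttnMaskType.paddedcausal
assert pick_mask_type(SimpleNamespace(), None) == AttnMaskType.causal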