fix flash attn (#5209)

Repository: https://github.com/hpcaitech/ColossalAI.git (mirror)
Parent commit: 365671be10
This commit: 451e9142b8
@@ -414,7 +414,7 @@ class LlamaPipelineForwards:
         return {"hidden_states": hidden_states}
 
 
-def get_llama_flash_attention_forward():
+def get_llama_flash_attention_forward(shard_config: ShardConfig):
     from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
 
     from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
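The signature change above works because the replaced forward is produced by a factory: the callable it returns is what the policy installs as LlamaAttention.forward, and it closes over the ShardConfig passed in, so policy-level flags become visible inside the kernel-selection logic without changing the argument list that transformers calls. A minimal sketch of that pattern, with the body elided and the inner signature simplified (the real forward takes the full LlamaAttention argument list):

# Sketch only, not the actual ColossalAI implementation.
def get_llama_flash_attention_forward_sketch(shard_config):
    def forward(self, hidden_states, attention_mask=None, **kwargs):
        # `shard_config` is captured from the enclosing scope, so flags such as
        # `causal_lm` (set later by LlamaForCausalLMPolicy) are readable here.
        causal_only = getattr(shard_config, "causal_lm", False)
        ...  # flash-attention computation elided

    return forward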
@@ -470,14 +470,13 @@ def get_llama_flash_attention_forward():
 
         flash_attention_mask = None
         attn_mask_type = AttnMaskType.causal
-        if attention_mask != None:
+        if not getattr(shard_config, "causal_lm", False) and attention_mask != None:
             if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                 raise ValueError(
                     f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                 )
             flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
-            if not torch.all(flash_attention_mask):
-                attn_mask_type = AttnMaskType.paddedcausal
+            attn_mask_type = AttnMaskType.paddedcausal
 
         attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads)
         attn_output = attention(
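Net effect of the hunk above: a shard tagged as a causal LM skips the padding-mask branch entirely, and when the branch does run, the mask type is promoted to paddedcausal as soon as an attention mask is supplied, rather than only when torch.all detects real padding. A standalone sketch of the new selection logic, using AttnMaskTypeStub as a stand-in for colossalai's AttnMaskType:

from enum import Enum, auto

import torch


class AttnMaskTypeStub(Enum):
    # stand-in for colossalai.kernel.cuda_native.AttnMaskType
    causal = auto()
    paddedcausal = auto()


def select_mask(attention_mask, causal_lm):
    # Mirrors the post-patch branch: causal-LM shards ignore the padding mask.
    flash_attention_mask = None
    attn_mask_type = AttnMaskTypeStub.causal
    if not causal_lm and attention_mask is not None:
        # Padding info comes from the last query row of the (bsz, 1, q_len, kv_seq_len)
        # mask; the type is promoted to paddedcausal whenever a mask is given.
        flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
        attn_mask_type = AttnMaskTypeStub.paddedcausal
    return flash_attention_mask, attn_mask_type

For example, select_mask(mask, causal_lm=True) returns (None, causal) even when the batch contains padding, which is the behaviour the causal-LM policy opts into below.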
@@ -130,7 +130,7 @@ class LlamaPolicy(Policy):
         if self.shard_config.enable_flash_attention:
             self.append_or_create_method_replacement(
                 description={
-                    "forward": get_llama_flash_attention_forward(),
+                    "forward": get_llama_flash_attention_forward(self.shard_config),
                 },
                 policy=policy,
                 target_key=LlamaAttention,
@@ -250,6 +250,8 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
 
         policy = super().module_policy()
 
+        setattr(self.shard_config, "causal_lm", True)
+
         if self.shard_config.enable_tensor_parallelism:
             # add a new item for casual lm
             new_item = {
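Taken together: LlamaForCausalLMPolicy tags its ShardConfig with causal_lm=True, the policy hands that same config to get_llama_flash_attention_forward, and the patched forward reads the tag back so causal-LM runs stay on the pure causal flash-attention path. A tiny sketch of that flow, using a SimpleNamespace stand-in because constructing a real ShardConfig may require a distributed setup:

from types import SimpleNamespace

# Stand-in for colossalai.shardformer.ShardConfig; attribute names follow the diff.
shard_config = SimpleNamespace(enable_flash_attention=True)

# What LlamaForCausalLMPolicy.module_policy() now does:
setattr(shard_config, "causal_lm", True)

# What the replaced attention forward checks before building a padding mask:
assert getattr(shard_config, "causal_lm", False)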