[shardformer] fix flash attention: when the mask is causal, just don't unpad it (#5084)

* fix flash attn
* fix fix
@@ -106,7 +106,10 @@ def get_whisper_flash_attention_forward():
                     f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                 )
             flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool).contiguous())
-            attn_type = AttnMaskType.paddedcausal
+            if not torch.all(flash_attention_mask):
+                attn_type = AttnMaskType.paddedcausal
+            else:
+                attn_type = AttnMaskType.causal
 
         attention = ColoAttention(
             embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.dropout, scale=self.scaling
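To see why the new branch can skip unpadding, below is a minimal, self-contained sketch of the selection logic in the hunk above. It assumes an HF-style 4D additive attention mask (0.0 where a token may be attended, a large negative value where it is padded out); the local AttnMaskType enum and the select_attn_mask_type helper are illustrative stand-ins, not ColossalAI API.

import torch
from enum import Enum, auto

class AttnMaskType(Enum):
    # Illustrative stand-in for ColossalAI's AttnMaskType; only the two
    # members used in the hunk above are reproduced here.
    causal = auto()
    paddedcausal = auto()

def select_attn_mask_type(attention_mask: torch.Tensor) -> AttnMaskType:
    # attention_mask: (bsz, 1, tgt_len, src_len) additive mask, 0.0 where a
    # token may be attended and a large negative value where it is padded out.
    # The last query row of a causal mask attends to every key position, so any
    # False entries left after the negation below can only come from padding.
    flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool).contiguous())
    if not torch.all(flash_attention_mask):
        # At least one padded token: the kernel still has to unpad the batch.
        return AttnMaskType.paddedcausal
    # No padding anywhere: a plain causal mask, so unpadding can be skipped.
    return AttnMaskType.causal

# Example: an all-zero mask is purely causal; masking one key position in one
# sample forces the padded-causal path.
mask = torch.zeros(2, 1, 4, 4)
assert select_attn_mask_type(mask) is AttnMaskType.causal
mask[1, :, :, -1] = torch.finfo(torch.float32).min
assert select_attn_mask_type(mask) is AttnMaskType.paddedcausal

In other words, the patch only changes which AttnMaskType is handed to ColoAttention: padded-causal (with unpadding) when any padding is detected, plain causal otherwise.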
|