Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-01 17:17:05 +00:00)
[shardformer] fix flash attention: when the mask is causal, just don't unpad it (#5084)
* fix flash attn

* fix fix
@@ -771,11 +771,12 @@ def get_gpt2_flash_attention_forward():
         attn_mask_type = AttnMaskType.causal
         flash_attention_mask = None
         if attention_mask != None:
-            if attn_mask_type == AttnMaskType.causal:
-                attn_mask_type == AttnMaskType.paddedcausal
-            else:
-                attn_mask_type = AttnMaskType.padding
             flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
+            if not torch.all(flash_attention_mask):
+                if attn_mask_type == AttnMaskType.causal:
+                    attn_mask_type == AttnMaskType.paddedcausal
+                else:
+                    attn_mask_type = AttnMaskType.padding
 
         scale = value.size(-1) ** -0.5
         if self.scale_attn_by_inverse_layer_idx:
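To make the new branch easier to read outside diff markers, here is a minimal, self-contained sketch of the same mask-type selection. The local AttnMaskType enum, the helper name select_mask_type, and the toy extended masks are stand-ins invented for illustration (ColossalAI ships its own AttnMaskType), and the sketch uses a plain assignment where the patch writes `attn_mask_type == AttnMaskType.paddedcausal`, which as written is a comparison whose result is discarded.

import enum
from typing import Optional, Tuple

import torch


class AttnMaskType(enum.Enum):
    # Stand-in for ColossalAI's AttnMaskType; only the members the diff touches.
    causal = 1
    padding = 2
    paddedcausal = 3


def select_mask_type(attention_mask: Optional[torch.Tensor]) -> Tuple[AttnMaskType, Optional[torch.Tensor]]:
    # Mirrors the post-fix branch: start from the plain causal path and only
    # switch to a padded variant when the batch actually contains padding.
    attn_mask_type = AttnMaskType.causal
    flash_attention_mask = None
    if attention_mask is not None:
        # Under the usual HF convention the extended mask is 0 where attention is
        # allowed and a large negative value where it is not, so reducing the last
        # key row to bool and inverting it yields True for valid tokens and False
        # for padded ones.
        flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
        if not torch.all(flash_attention_mask):
            if attn_mask_type == AttnMaskType.causal:
                # The patch writes `attn_mask_type == AttnMaskType.paddedcausal`,
                # a no-op comparison; an assignment is assumed to be the intent.
                attn_mask_type = AttnMaskType.paddedcausal
            else:
                attn_mask_type = AttnMaskType.padding
    return attn_mask_type, flash_attention_mask


# Purely causal batch (no padding): the mask type stays causal, so nothing is unpadded.
no_padding = torch.zeros(2, 1, 4, 4)
print(select_mask_type(no_padding)[0])    # AttnMaskType.causal

# One padded key position in the second sample: the padded-causal path is chosen instead.
with_padding = no_padding.clone()
with_padding[1, :, :, -1] = torch.finfo(torch.float32).min
print(select_mask_type(with_padding)[0])  # AttnMaskType.paddedcausal

Checking torch.all(flash_attention_mask) first is what lets a purely causal batch keep the plain causal path and skip the unpad/pad round-trip that the pre-fix code took whenever any attention_mask was passed.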