[shardformer]fix flash attention, when mask is casual, just don't unpad it (#5084)

* fix flash attn

* fix

fix
This commit is contained in:
flybird11111
2023-11-22 16:00:07 +08:00
committed by GitHub
parent 75af66cd81
commit aae496631c
6 changed files with 16 additions and 8 deletions

View File

@@ -771,11 +771,12 @@ def get_gpt2_flash_attention_forward():
attn_mask_type = AttnMaskType.causal
flash_attention_mask = None
if attention_mask != None:
if attn_mask_type == AttnMaskType.causal:
attn_mask_type == AttnMaskType.paddedcausal
else:
attn_mask_type = AttnMaskType.padding
flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
if not torch.all(flash_attention_mask):
if attn_mask_type == AttnMaskType.causal:
attn_mask_type == AttnMaskType.paddedcausal
else:
attn_mask_type = AttnMaskType.padding
scale = value.size(-1) ** -0.5
if self.scale_attn_by_inverse_layer_idx: