[shardformer] update tests for all optimization (#4413)

Author: flybird11111
Date: 2023-08-11 16:40:06 +08:00
Committed by: Hongxin Liu
Parent: 7711bd524a
Commit: 1edc9b5fb3
3 changed files with 50 additions and 23 deletions


@@ -1048,9 +1048,12 @@ def get_bert_flash_attention_forward():
             final_attention_mask = final_attention_mask * scale + attention_mask
         else:
             final_attention_mask = attention_mask
         if final_attention_mask is not None:
             batch_size, src_len = query_layer.size()[0], query_layer.size()[2]
             tgt_len = key_layer.size()[2]
-            final_attention_mask = final_attention_mask.expand(batch_size, self.num_attention_heads, src_len, tgt_len)
+            final_attention_mask = final_attention_mask.expand(batch_size, self.num_attention_heads, src_len,
+                                                               tgt_len).contiguous()
         query_layer = query_layer.permute(0, 2, 1, 3).contiguous()
         key_layer = key_layer.permute(0, 2, 1, 3).contiguous()
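
The functional change in this hunk is that the expanded attention mask is materialized with .contiguous() before it reaches the fused attention path, presumably because the kernel used downstream expects a densely laid-out bias. Below is a minimal sketch of the PyTorch behavior the patch relies on; the shapes are illustrative only and not taken from the code above: expand() merely creates a strided view (no copy), so the result is non-contiguous until .contiguous() copies it into dense storage.

import torch

# Illustrative shapes, not from the patch.
batch_size, num_heads, src_len, tgt_len = 2, 8, 16, 16

# A broadcastable additive mask of shape (batch, 1, 1, tgt_len),
# similar in layout to BERT's extended attention mask.
mask = torch.zeros(batch_size, 1, 1, tgt_len)

expanded = mask.expand(batch_size, num_heads, src_len, tgt_len)
print(expanded.is_contiguous())               # False: expand() only adjusts strides, no data copy
print(expanded.contiguous().is_contiguous())  # True: contiguous() materializes a dense copy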