mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 19:13:01 +00:00
[shardformer] update tests for all optimization (#4413)
[shardformer] update tests for all optimization
This commit is contained in:
committed by
Hongxin Liu
parent
7711bd524a
commit
1edc9b5fb3
@@ -1048,9 +1048,12 @@ def get_bert_flash_attention_forward():
|
||||
final_attention_mask = final_attention_mask * scale + attention_mask
|
||||
else:
|
||||
final_attention_mask = attention_mask
|
||||
|
||||
if final_attention_mask is not None:
|
||||
batch_size, src_len = query_layer.size()[0], query_layer.size()[2]
|
||||
tgt_len = key_layer.size()[2]
|
||||
final_attention_mask = final_attention_mask.expand(batch_size, self.num_attention_heads, src_len, tgt_len)
|
||||
final_attention_mask = final_attention_mask.expand(batch_size, self.num_attention_heads, src_len,
|
||||
tgt_len).contiguous()
|
||||
|
||||
query_layer = query_layer.permute(0, 2, 1, 3).contiguous()
|
||||
key_layer = key_layer.permute(0, 2, 1, 3).contiguous()
|
||||
|
Reference in New Issue
Block a user