[shardformer] update shardformer to use flash attention 2 (#4392)

* cherry-pick flash attention 2

* [shardformer] update shardformer to use flash attention 2

[shardformer] update shardformer to use flash attention 2, fix

flybird1111
2023-08-09 14:32:19 +08:00
committed by Hongxin Liu
parent ed4c448488
commit 7a3dfd0c64
9 changed files with 10 additions and 11 deletions


@@ -13,7 +13,6 @@ if HAS_MEM_EFF_ATTN or HAS_FLASH_ATTN:
from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType
DTYPE = [torch.float16, torch.bfloat16, torch.float32]
FLASH_DTYPE = [torch.float16, torch.bfloat16]
def attention_ref(q, k, v, attn_mask=None, causal=False):
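For context, attention_ref in this test is the naive (unfused) attention used as ground truth when validating the flash attention 2 kernels. Below is a minimal sketch of such a reference, assuming q, k and v are shaped (batch, seqlen, nheads, headdim) and that attn_mask is a (batch, seqlen) key-padding mask with True marking valid positions; the actual implementation in the test file may differ in these details.

import math

import torch


def attention_ref(q, k, v, attn_mask=None, causal=False):
    """Naive scaled dot-product attention used as a ground truth
    against fused flash-attention kernels.

    Assumes q, k, v: (batch, seqlen, nheads, headdim);
    attn_mask (hypothetical layout): (batch, seqlen), True = keep.
    """
    batch, seqlen, nheads, d = q.shape
    # Compute (batch, nheads, seqlen, seqlen) scores in fp32 for accuracy.
    scores = torch.einsum("bqhd,bkhd->bhqk", q.float(), k.float()) / math.sqrt(d)
    if attn_mask is not None:
        # Mask out padded key positions.
        pad_mask = ~attn_mask.bool().view(batch, 1, 1, seqlen)
        scores = scores.masked_fill(pad_mask, float("-inf"))
    if causal:
        # Forbid attending to future positions.
        causal_mask = torch.triu(
            torch.ones(seqlen, seqlen, dtype=torch.bool, device=q.device), diagonal=1)
        scores = scores.masked_fill(causal_mask, float("-inf"))
    attn = torch.softmax(scores, dim=-1)
    out = torch.einsum("bhqk,bkhd->bqhd", attn, v.float())
    # Cast back to the input dtype (see DTYPE / FLASH_DTYPE above) before comparison.
    return out.to(q.dtype)

The test then runs the fused kernel and the reference on the same inputs and compares the outputs with a dtype-dependent tolerance, which is why only float16 and bfloat16 appear in FLASH_DTYPE.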