[shardformer] update shardformer to use flash attention 2 (#4392)

* cherry-pick flash attention 2

* [shardformer] update shardformer to use flash attention 2, fix
Author: flybird1111
Date: 2023-08-09 14:32:19 +08:00
Committed by: Hongxin Liu
Parent: ed4c448488
Commit: 7a3dfd0c64
9 changed files with 10 additions and 11 deletions


@@ -65,7 +65,7 @@ def get_blip2_flash_attention_forward():
     from transformers.models.blip_2.modeling_blip_2 import Blip2Attention
-    from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention
+    from colossalai.kernel.cuda_native import ColoAttention
 
     def forward(
         self: Blip2Attention,
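
For context, the hunk above shows the pattern this commit relies on: get_blip2_flash_attention_forward() builds a drop-in replacement for Blip2Attention.forward that routes attention through ColoAttention. The sketch below illustrates only that forward-replacement pattern; it uses PyTorch's scaled_dot_product_attention as a stand-in for ColoAttention, and the attribute names it touches on self (qkv, num_heads, projection), plus the helper and class names, are assumptions for illustration, not code from this commit.

# A minimal, hypothetical sketch of the forward-replacement pattern.
# F.scaled_dot_product_attention stands in for colossalai's ColoAttention;
# the attribute names on `self` (qkv, num_heads, projection) are assumed.
from typing import Optional, Tuple

import torch
import torch.nn.functional as F


def get_flash_attention_forward():
    def forward(
        self,  # the attention module being patched
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        bsz, seq_len, _ = hidden_states.size()

        # Fused QKV projection, split into (batch, heads, seq, head_dim).
        qkv = self.qkv(hidden_states)
        qkv = qkv.reshape(bsz, seq_len, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        query, key, value = qkv.unbind(0)

        # A fused attention kernel replaces the naive softmax(QK^T)V path.
        context = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0)

        # Merge heads back and apply the output projection.
        context = context.transpose(1, 2).reshape(bsz, seq_len, -1)
        output = self.projection(context)

        # Fused kernels do not materialize the attention-weight matrix.
        return output, None

    return forward


class _DemoAttention(torch.nn.Module):
    """Toy attention module with a fused qkv layer and an output projection."""

    def __init__(self, embed_dim: int = 64, num_heads: int = 4):
        super().__init__()
        self.num_heads = num_heads
        self.qkv = torch.nn.Linear(embed_dim, 3 * embed_dim)
        self.projection = torch.nn.Linear(embed_dim, embed_dim)


# Patch the class so every instance uses the fused-attention forward.
_DemoAttention.forward = get_flash_attention_forward()
out, _ = _DemoAttention()(torch.randn(2, 8, 64))
assert out.shape == (2, 8, 64)

The returned function takes self as an explicit first argument because it is registered as a replacement for the module's original forward method, which is how shardformer swaps attention implementations without modifying the upstream transformers classes.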