[shardformer] update shardformer to use flash attention 2 (#4392)

* cherry-pick flash attention 2

* [shardformer] update shardformer to use flash attention 2

[shardformer] update shardformer to use flash attention 2, fix
Author: flybird1111
Date: 2023-08-09 14:32:19 +08:00
Committed by: Hongxin Liu
Parent: ed4c448488
Commit: 7a3dfd0c64
9 changed files with 10 additions and 11 deletions


@@ -8,7 +8,7 @@ def get_opt_flash_attention_forward():
     from transformers.models.opt.modeling_opt import OPTAttention

-    from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention
+    from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention

     def forward(
         self: OPTAttention,
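
The substantive change in this hunk is the import path: AttnMaskType and ColoAttention are now imported from colossalai.kernel.cuda_native directly rather than from its flash_attention submodule. For context, the sketch below shows one way the replacement forward returned by get_opt_flash_attention_forward() could be bound onto OPTAttention modules. It is a minimal illustration only: the module path colossalai.shardformer.modeling.opt and the manual monkey-patching loop are assumptions; shardformer itself applies this substitution through its policy system.

# Minimal sketch (assumptions noted above): bind the flash-attention-2-backed
# forward onto every OPTAttention module of a Hugging Face OPT model.
from types import MethodType

from transformers import OPTForCausalLM
from transformers.models.opt.modeling_opt import OPTAttention

# Assumed location of the helper shown in the diff above.
from colossalai.shardformer.modeling.opt import get_opt_flash_attention_forward

model = OPTForCausalLM.from_pretrained("facebook/opt-125m")

# get_opt_flash_attention_forward() returns a plain function whose first
# argument is the OPTAttention instance, so it can be bound like a method.
flash_forward = get_opt_flash_attention_forward()
for module in model.modules():
    if isinstance(module, OPTAttention):
        module.forward = MethodType(flash_forward, module)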