[shardformer] update shardformer to use flash attention 2 (#4392)

* cherry-pick flash attention 2

* [shardformer] update shardformer to use flash attention 2, fix
flybird1111
2023-08-09 14:32:19 +08:00
committed by Hongxin Liu
parent ed4c448488
commit 7a3dfd0c64
9 changed files with 10 additions and 11 deletions


@@ -392,7 +392,7 @@ def get_llama_flash_attention_forward():
     from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
-    from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention
+    from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
     def forward(
         self: LlamaAttention,
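
For context, the hunk above only swaps the import path: AttnMaskType and ColoAttention are now imported from colossalai.kernel.cuda_native rather than its flash_attention submodule. Below is a minimal sketch of how the replacement forward might wire these symbols together; the ColoAttention constructor and call signature shown (embed_dim/num_heads, q/k/v plus an attn_mask_type) are assumptions for illustration, and the actual shardformer forward additionally handles rotary embeddings, attention masks, and KV caching.

# Sketch only: illustrates use of the re-exported symbols after this change.
# The ColoAttention constructor/call signature here is an assumption, not a
# verbatim copy of the shardformer code in this commit.
from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention

def sketch_flash_forward(self, query_states, key_states, value_states):
    # `self` is assumed to be a transformers LlamaAttention module,
    # providing hidden_size and num_heads.
    attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads)
    # A causal mask type lets the fused flash-attention kernel build the mask.
    return attention(query_states, key_states, value_states,
                     attn_mask_type=AttnMaskType.causal)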