[npu] support triangle attention for llama (#5130)

* update fused attn

* update sdpa

* tri attn

* update triangle

* import

* fix

* fix
Author: Xuanlei Zhao
Date: 2023-11-30 14:21:30 +08:00
Committed by: GitHub
Parent: f4e72c9992
Commit: d6df19bae7
9 changed files with 264 additions and 3 deletions


@@ -12,6 +12,7 @@ from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaForS
 from transformers.utils import logging
 from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer.layer.utils import get_attention_kernel
 try:
     from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask
@@ -404,7 +405,7 @@ class LlamaPipelineForwards:
 def get_llama_flash_attention_forward():
     from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
-    from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
+    AttnMaskType, ColoAttention = get_attention_kernel()
     llama_version = 2
     try:
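
With this change the Llama forward no longer hard-codes the CUDA-native kernel: get_attention_kernel() (imported in the first hunk) selects the attention implementation at runtime, so the same code path can serve both CUDA GPUs and Ascend NPUs. The snippet below is only a minimal sketch of that dispatch idea, assuming a torch_npu-style torch.npu.is_available() check; select_attention_kernel and the NPU placeholder branch are illustrative, and only the CUDA fallback import is taken from the diff.

import torch


def select_attention_kernel():
    """Sketch: return (AttnMaskType, attention class) for the current accelerator."""
    if hasattr(torch, "npu") and torch.npu.is_available():
        # Placeholder for the NPU fused/triangle-attention classes that
        # ColossalAI provides; they are not shown in this diff.
        raise NotImplementedError("plug in the NPU AttnMaskType / ColoAttention here")
    # Default/CUDA path, matching the import this PR removes from llama.py.
    from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
    return AttnMaskType, ColoAttention


# Usage mirrors the hunk above:
# AttnMaskType, ColoAttention = select_attention_kernel()
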
@@ -468,7 +469,7 @@ def get_llama_flash_attention_forward():
         attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads)
         attn_output = attention(
-            query_states, key_states, value_states, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type
+            query_states, key_states, value_states, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type, origin_attn_mask=attention_mask,
         )
         attn_output = self.o_proj(attn_output)
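
The extra origin_attn_mask keyword hands the attention wrapper the original HuggingFace attention_mask alongside the preprocessed flash-attention mask, so a backend such as the NPU triangle-attention kernel can derive whatever mask layout it needs. The stub below only demonstrates that widened call signature; AttentionStub and its plain SDPA fallback are assumptions for illustration, not ColossalAI's ColoAttention.

import torch


class AttentionStub(torch.nn.Module):
    """Sketch of an attention wrapper that accepts both mask variants."""

    def __init__(self, embed_dim: int, num_heads: int):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads

    def forward(self, q, k, v, attn_mask=None, attn_mask_type=None, origin_attn_mask=None):
        # A real backend would pick attn_mask (flash path) or origin_attn_mask
        # (e.g. an NPU triangle-attention path); this stub just falls back to
        # unmasked PyTorch SDPA.
        return torch.nn.functional.scaled_dot_product_attention(q, k, v)


# Call pattern matching the hunk above (shapes: batch, heads, seq, head_dim).
q = k = v = torch.randn(1, 8, 16, 64)
attention = AttentionStub(embed_dim=512, num_heads=8)
attn_output = attention(q, k, v, attn_mask=None, attn_mask_type=None, origin_attn_mask=None)
print(attn_output.shape)  # torch.Size([1, 8, 16, 64])
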