mirror of https://github.com/hpcaitech/ColossalAI.git
[npu] support triangle attention for llama (#5130)
* update fused attn
* update spda
* tri attn
* update triangle
* import
* fix
* fix
@@ -12,6 +12,7 @@ from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaForS
 from transformers.utils import logging

 from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer.layer.utils import get_attention_kernel

 try:
     from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask
@@ -404,7 +405,7 @@ class LlamaPipelineForwards:
 def get_llama_flash_attention_forward():
     from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb

-    from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
+    AttnMaskType, ColoAttention = get_attention_kernel()

     llama_version = 2
     try:
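The hunk above is the core of the change: instead of hard-wiring the CUDA import from colossalai.kernel.cuda_native, the wrapper now asks get_attention_kernel() for whatever AttnMaskType / attention class matches the current accelerator, which is what lets an NPU triangle-attention kernel slot in without touching the modeling code. The sketch below illustrates that dispatch idea only; the NPU module path and class name (colossalai.kernel.npu, NPUColoAttention) are assumptions for illustration, not the actual colossalai.shardformer.layer.utils implementation.

# Minimal sketch of a get_attention_kernel()-style dispatcher (illustration only).
import torch


def get_attention_kernel():
    """Return an (AttnMaskType, attention class) pair for the current accelerator."""
    if torch.cuda.is_available():
        # CUDA path: the fused kernels the old hard-coded import pointed at.
        from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention

        return AttnMaskType, ColoAttention

    # NPU path: module path and class name here are placeholder assumptions.
    from colossalai.kernel.npu import AttnMaskType, NPUColoAttention

    return AttnMaskType, NPUColoAttention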
@@ -468,7 +469,7 @@ def get_llama_flash_attention_forward():

         attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads)
         attn_output = attention(
-            query_states, key_states, value_states, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type
+            query_states, key_states, value_states, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type, origin_attn_mask=attention_mask,
         )

         attn_output = self.o_proj(attn_output)
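The last hunk threads the caller's original attention_mask through to the kernel as origin_attn_mask. A flash-style CUDA kernel only needs the preprocessed flash_attention_mask plus a mask-type flag, but a causal ("triangle") implementation can also make use of the unconverted mask. The reference below is a naive stand-in that shows what such a kernel receives, assuming the original mask is the additive 4D mask that transformers prepares (0 to keep, -inf to drop); it is not the fused NPU triangle-attention kernel and skips its memory optimizations.

# Naive causal-attention reference (illustration only, not the fused NPU kernel).
import math
import torch


def triangle_attention_reference(q, k, v, origin_attn_mask=None):
    # q, k, v: (batch, num_heads, seq_len, head_dim)
    seq_len = q.size(-2)
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))

    # Causal triangle mask: position i may only attend to positions <= i.
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=q.device))
    scores = scores.masked_fill(~causal, float("-inf"))

    # Fold in the original mask if one was provided (assumed additive: 0 keep, -inf drop).
    if origin_attn_mask is not None:
        scores = scores + origin_attn_mask

    probs = torch.softmax(scores, dim=-1)
    return torch.matmul(probs, v)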