Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-02 09:38:05 +00:00)
[npu] use extension for op builder (#5172)
* update extension
* update cpu adam
* update is
* add doc for cpu adam
* update kernel
* update commit
* update flash
* update memory efficient
* update flash attn
* update flash attention loader
* update api
* fix
* update doc
* update example time limit
* reverse change
* fix doc
* remove useless kernel
* fix
* not use warning
* update
* update
@@ -1,5 +1,5 @@
 import warnings
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple

 import torch
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -12,14 +12,15 @@ from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaForSequenceClassification
 from transformers.utils import logging

 from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer.layer.utils import get_attention_kernel

 try:
     from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask

     LATEST_VERSION = True
 except ImportError:
     LATEST_VERSION = False


 class LlamaPipelineForwards:
     """
     This class serves as a micro library for forward function substitution of Llama models
@@ -405,7 +406,7 @@ class LlamaPipelineForwards:
 def get_llama_flash_attention_forward():
     from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb

-    from colossalai.kernel import AttnMaskType, ColoAttention
+    AttnMaskType, ColoAttention = get_attention_kernel()

     llama_version = 2
     try:
@@ -469,7 +470,12 @@ def get_llama_flash_attention_forward():

         attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads)
         attn_output = attention(
-            query_states, key_states, value_states, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type, origin_attn_mask=attention_mask,
+            query_states,
+            key_states,
+            value_states,
+            attn_mask=flash_attention_mask,
+            attn_mask_type=attn_mask_type,
+            origin_attn_mask=attention_mask,
         )

         attn_output = self.o_proj(attn_output)