[npu] use extension for op builder (#5172)

* update extension

* update cpu adam

* update is

* add doc for cpu adam

* update kernel

* update commit

* update flash

* update memory efficient

* update flash attn

* update flash attention loader

* update api

* fix

* update doc

* update example time limit

* reverse change

* fix doc

* remove useless kernel

* fix

* not use warning

* update

* update
Author: Xuanlei Zhao
Date: 2024-01-08 11:39:16 +08:00
Committed by: GitHub
Parent: d6df19bae7
Commit: dd2c28a323
35 changed files with 1067 additions and 274 deletions
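
The core of the change is that device-specific kernels (flash attention, memory-efficient attention, CPU Adam) are now built and selected through ColossalAI's op-builder extensions behind the single `colossalai.kernel` entry point, instead of through per-call-site helpers such as `get_attention_kernel`. The sketch below is only a mental model of that kind of per-device dispatch; every name in it is an illustrative placeholder, not ColossalAI's actual loader API.

import torch


class _CudaFlashAttention:
    """Placeholder for a kernel object built by a CUDA op-builder extension."""


class _NpuFlashAttention:
    """Placeholder for a kernel object built by an NPU (Ascend) op-builder extension."""


def load_flash_attention_kernel():
    """Return a flash-attention implementation matching the current accelerator."""
    if hasattr(torch, "npu") and torch.npu.is_available():  # torch_npu present
        return _NpuFlashAttention
    if torch.cuda.is_available():
        return _CudaFlashAttention
    raise RuntimeError("no accelerator with a flash-attention extension found")

The diff below shows the consumer side of this change in the shardformer Llama modeling code.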


@@ -1,5 +1,5 @@
 import warnings
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple
 
 import torch
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -12,14 +12,15 @@ from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaForS
 from transformers.utils import logging
 
 from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.shardformer.layer.utils import get_attention_kernel
 
 try:
     from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask
     LATEST_VERSION = True
 except ImportError:
     LATEST_VERSION = False
 
 
 class LlamaPipelineForwards:
     """
     This class serves as a micro library for forward function substitution of Llama models
@@ -405,7 +406,7 @@ class LlamaPipelineForwards:
 def get_llama_flash_attention_forward():
     from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
 
-    AttnMaskType, ColoAttention = get_attention_kernel()
+    from colossalai.kernel import AttnMaskType, ColoAttention
 
     llama_version = 2
     try:
@@ -469,7 +470,12 @@ def get_llama_flash_attention_forward():
         attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads)
 
         attn_output = attention(
-            query_states, key_states, value_states, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type, origin_attn_mask=attention_mask,
+            query_states,
+            key_states,
+            value_states,
+            attn_mask=flash_attention_mask,
+            attn_mask_type=attn_mask_type,
+            origin_attn_mask=attention_mask,
         )
 
         attn_output = self.o_proj(attn_output)
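
After this change, call sites go through the unified `colossalai.kernel` entry point and no longer need the per-device `get_attention_kernel()` helper. The snippet below is a minimal usage sketch built only from what the hunks above show; the tensor layout (batch, seq_len, num_heads, head_dim), half precision, the causal mask type, and the `None` masks are illustrative assumptions, not values taken from the commit.

import torch

from colossalai.kernel import AttnMaskType, ColoAttention  # unified entry point after this commit

# Assumed, illustrative dimensions; a real model reads them from its config.
batch_size, seq_len, num_heads, head_dim = 2, 128, 32, 128
hidden_size = num_heads * head_dim

# Assumed (batch, seq_len, num_heads, head_dim) layout in half precision.
query_states = torch.randn(batch_size, seq_len, num_heads, head_dim, device="cuda", dtype=torch.float16)
key_states = torch.randn_like(query_states)
value_states = torch.randn_like(query_states)

attention = ColoAttention(embed_dim=hidden_size, num_heads=num_heads)

# Keyword arguments mirror the reformatted call in the diff; plain causal
# self-attention without padding, so no explicit mask tensors are passed.
attn_output = attention(
    query_states,
    key_states,
    value_states,
    attn_mask=None,
    attn_mask_type=AttnMaskType.causal,
    origin_attn_mask=None,
)
print(attn_output.shape)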