Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-07-18 17:31:53 +00:00)

[Refactor] Integrated some lightllm kernels into token-attention (#4946)

* add some requirements for inference
* clean code
* add some lightllm dependencies
* delete rms files
* add comments and docs
* add lightllm chatglm2 kernels
* replace rotary embedding with the lightllm kernel
* replace fwd kernel att1
* fix an arg
* fix token attention
* fix readme
* fix bugs

Co-authored-by: cuiqing.li <lixx336@gmail.com>
Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>

This commit is contained in:
parent 11009103be
commit 3a41e8304e
@@ -4,7 +4,7 @@

 ## Introduction

-`Colossal Inference` is a module that contains colossal-ai designed inference framework, featuring high performance, steady and easy usability. `Colossal Inference` incorporated the advantages of the latest open-source inference systems, including TGI, vLLM, FasterTransformer, LightLLM and flash attention. while combining the design of Colossal AI, especially Shardformer, to reduce the learning curve for users.
+`Colossal Inference` is a module that contains the Colossal-AI-designed inference framework, featuring high performance, stability and ease of use. `Colossal Inference` incorporates the advantages of the latest open-source inference systems, including LightLLM, TGI, vLLM, FasterTransformer and flash attention, while combining the design of Colossal-AI, especially Shardformer, to reduce the learning curve for users.

 ## Design

@@ -62,6 +62,12 @@ triton==2.0.0.dev20221202
 vllm
 # to install flash-attention, please use commit hash: 67ae6fd74b4bc99c36b2ce524cf139c35663793c
 flash-attention

+# install lightllm since we depend on lightllm triton kernels
+git clone https://github.com/ModelTC/lightllm
+cd lightllm
+git checkout 28c1267cfca536b7b4f28e921e03de735b003039
+pip3 install -e .
+
 ```

 ### Docker
@@ -73,6 +79,17 @@ You can use docker run to use docker container to set-up environment
 docker pull hpcaitech/colossalai-inference:v2
 docker run -it --gpus all --name ANY_NAME -v $PWD:/workspace -w /workspace hpcaitech/colossalai-inference:v2 /bin/bash

+# enter the docker container
+cd /path/to/ColossalAI
+pip install -e .
+
+# install lightllm
+git clone https://github.com/ModelTC/lightllm
+cd lightllm
+git checkout 28c1267cfca536b7b4f28e921e03de735b003039
+pip3 install -e .
+
 ```

 ### Dive into fast-inference!
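To confirm an environment matches what the README above expects, a minimal check (a sketch, assuming lightllm was installed from the pinned commit as shown) is to import the Triton kernels this commit relies on:

```python
# Minimal sanity check: verify the lightllm Triton kernels used by this commit are importable.
# Illustrative only; it assumes lightllm was installed from source as described above.
def check_lightllm_kernels() -> bool:
    try:
        from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd  # noqa: F401
        from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import (  # noqa: F401
            context_attention_fwd,
        )
        return True
    except ImportError as err:
        print(f"lightllm kernels unavailable: {err}")
        return False

if __name__ == "__main__":
    print("lightllm kernels available:", check_lightllm_kernels())
```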
@@ -5,7 +5,7 @@ import torch

 from .kvcache_manager import MemoryManager

+# adapted from: lightllm/server/router/model_infer/infer_batch.py
 @dataclass
 class BatchInferState:
     r"""
@@ -41,6 +41,7 @@ class BatchInferState:
     def set_cache_manager(self, manager: MemoryManager):
         self.cache_manager = manager

+    # adapted from: https://github.com/ModelTC/lightllm/blob/28c1267cfca536b7b4f28e921e03de735b003039/lightllm/common/infer_utils.py#L1
     @staticmethod
     def init_block_loc(
         b_loc: torch.Tensor, seq_len: torch.Tensor, max_len_in_batch: int, alloc_mem_index: torch.Tensor
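`init_block_loc` fills the per-sequence block-location table from the flat list of freshly allocated KV-cache slots. A rough PyTorch sketch of that right-aligned fill (an illustration based on the signature above, not the exact lightllm implementation) is:

```python
import torch

# Sketch: right-align each sequence's allocated cache-slot indices inside b_loc.
# Assumes b_loc has shape [batch, max_len_in_batch] and alloc_mem_index holds
# sum(seq_len) slot ids laid out sequence after sequence. Illustrative only.
def init_block_loc_reference(b_loc: torch.Tensor, seq_len: torch.Tensor,
                             max_len_in_batch: int, alloc_mem_index: torch.Tensor) -> None:
    start = 0
    for i, length in enumerate(seq_len.tolist()):
        # place the slots for sequence i in the last `length` columns of its row
        b_loc[i, max_len_in_batch - length:] = alloc_mem_index[start:start + length]
        start += length
```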
@@ -1,7 +1,9 @@
-# Adapted from lightllm/common/mem_manager.py
-# of the ModelTC/lightllm GitHub repository
-# https://github.com/ModelTC/lightllm/blob/050af3ce65edca617e2f30ec2479397d5bb248c9/lightllm/common/mem_manager.py
+"""
+Referred/Modified from lightllm/common/mem_manager.py
+of the ModelTC/lightllm GitHub repository
+https://github.com/ModelTC/lightllm/blob/050af3ce65edca617e2f30ec2479397d5bb248c9/lightllm/common/mem_manager.py
+we slightly changed it to make it suitable for our colossal-ai shardformer TP-engine design.
+"""
 import torch
 from transformers.utils import logging

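For context, the memory manager referenced above hands out token-level KV-cache slots to sequences. A minimal sketch of such an allocator (hypothetical names, simplified from the general idea rather than copied from lightllm or Colossal-AI) could be:

```python
import torch

# Simplified illustration of a token-slot KV-cache allocator: a fixed pool of `size`
# token slots is tracked with a boolean mask; alloc() hands out free slot indices and
# free() returns them. Real managers also own the per-layer K/V buffers themselves.
class SimpleTokenSlotManager:
    def __init__(self, size: int, device: str = "cuda"):
        self.free_mask = torch.ones(size, dtype=torch.bool, device=device)

    def alloc(self, need: int) -> torch.Tensor:
        free_idx = torch.nonzero(self.free_mask).flatten()
        if free_idx.numel() < need:
            raise RuntimeError("KV cache is full")
        chosen = free_idx[:need]
        self.free_mask[chosen] = False
        return chosen

    def free(self, idx: torch.Tensor) -> None:
        self.free_mask[idx] = True
```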
@@ -6,8 +6,6 @@ from torch.nn import CrossEntropyLoss
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast

 from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
-from colossalai.kernel.triton.context_attention import llama2_context_attn_fwd
-from colossalai.kernel.triton.rotary_embedding_kernel import Llama2Forwards
 from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards
 from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (
     ChatGLMForConditionalGeneration,
@@ -20,6 +18,14 @@ from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (

 from ._utils import copy_kv_to_mem_cache

+try:
+    from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import context_attention_fwd as lightllm_llama2_context_attention_fwd
+    from lightllm.models.chatglm2.triton_kernel.rotary_emb import rotary_emb_fwd as chatglm2_rotary_emb_fwd
+
+    HAS_LIGHTLLM_KERNEL = True
+except:
+    print("please install lightllm from source to run inference: https://github.com/ModelTC/lightllm")
+    HAS_LIGHTLLM_KERNEL = False

 # This func is same as Llama model init_to_get_rotary, we should move them into _utils.py
 def _init_to_get_rotary(self, base=10000):
@@ -433,17 +439,17 @@ class ChatGLM2InferenceForwards:

         cos, sin = infer_state.position_cos, infer_state.position_sin

-        Llama2Forwards.rotary_emb_fwd(
+        chatglm2_rotary_emb_fwd(
             query_layer.view(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head), cos, sin
         )
         if self.multi_query_attention:
-            Llama2Forwards.rotary_emb_fwd(
+            chatglm2_rotary_emb_fwd(
                 key_layer.view(-1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head),
                 cos,
                 sin,
             )
         else:
-            Llama2Forwards.rotary_emb_fwd(
+            chatglm2_rotary_emb_fwd(
                 key_layer.view(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head),
                 cos,
                 sin,
@@ -474,7 +480,7 @@ class ChatGLM2InferenceForwards:
         attn_output = torch.empty_like(query_layer.view(-1, self.projection_size))

         # NOTE: no bug in context attn fwd (del it )
-        llama2_context_attn_fwd(
+        lightllm_llama2_context_attention_fwd(
             query_layer,
             key_layer,
             value_layer,
@@ -5,12 +5,7 @@ from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm

 from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
-from colossalai.kernel.triton import (
-    llama2_context_attn_fwd,
-    llama_context_attn_fwd,
-    rotary_embedding_fwd,
-    token_attention_fwd,
-)
+from colossalai.kernel.triton import llama_context_attn_fwd, token_attention_fwd
 from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards

 from ._utils import copy_kv_to_mem_cache
@@ -29,6 +24,17 @@ except:
     )
     HAS_VLLM_KERNERL = False

+try:
+    from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import (
+        context_attention_fwd as lightllm_llama2_context_attention_fwd,
+    )
+    from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd as llama_rotary_embedding_fwd
+
+    HAS_LIGHTLLM_KERNEL = True
+except:
+    print("please install lightllm from source to run inference: https://github.com/ModelTC/lightllm")
+    HAS_LIGHTLLM_KERNEL = False
+

 def rotate_half(x):
     """Rotates half the hidden dims of the input."""
@@ -280,8 +286,8 @@ class LlamaInferenceForwards:
         cos, sin = infer_state.position_cos, infer_state.position_sin
         # print("shape ", cos.shape, query_states.view(-1, self.num_heads, self.head_dim).shape, )

-        rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin)
-        rotary_embedding_fwd(key_states.view(-1, self.num_key_value_heads, self.head_dim), cos, sin)
+        llama_rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin)
+        llama_rotary_embedding_fwd(key_states.view(-1, self.num_key_value_heads, self.head_dim), cos, sin)

         query_states = query_states.reshape(-1, self.num_heads, self.head_dim)
         key_states = key_states.reshape(-1, self.num_key_value_heads, self.head_dim)
@@ -312,7 +318,7 @@ class LlamaInferenceForwards:
                 infer_state.cache_manager.past_key_values_length,
             )
         else:
-            llama2_context_attn_fwd(
+            lightllm_llama2_context_attention_fwd(
                 query_states,
                 key_states,
                 value_states,
@@ -371,6 +377,7 @@ class LlamaInferenceForwards:
                 infer_state.cache_manager.past_key_values_length,
                 infer_state.other_kv_index,
             )

         attn_output = attn_output.view(bsz, q_len, self.hidden_size)

         attn_output = self.o_proj(attn_output)
@@ -12,8 +12,7 @@ from ..modeling._utils import init_to_get_rotary
 from ..modeling.llama import LlamaInferenceForwards, get_llama_vllm_rmsnorm_forward

 try:
-    from colossalai.kernel.triton import rmsnorm_forward
-
+    from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward as lightllm_rmsnorm_forward
     HAS_TRITON_RMSNORM = True
 except:
     print("you should install triton from https://github.com/openai/triton")
@@ -22,9 +21,8 @@ except:

 def get_triton_rmsnorm_forward():
     if HAS_TRITON_RMSNORM:

         def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
-            return rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon)
+            return lightllm_rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon)

         return _triton_rmsnorm_forward
     else:
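The policy above swaps `LlamaRMSNorm.forward` for the `_triton_rmsnorm_forward` wrapper at runtime. A generic sketch of that kind of per-instance forward replacement (plain `types.MethodType` binding, not the actual Shardformer policy machinery) looks like:

```python
import types
import torch.nn as nn

# Generic illustration of replacing a module's forward with a drop-in kernel wrapper.
# The real code routes this through Shardformer policies; this sketch only shows the
# binding step for a single module instance.
def bind_forward(module: nn.Module, new_forward) -> None:
    module.forward = types.MethodType(new_forward, module)

# usage sketch (hypothetical):
#   rmsnorm_fwd = get_triton_rmsnorm_forward()
#   if rmsnorm_fwd is not None:
#       for mod in model.modules():
#           if mod.__class__.__name__ == "LlamaRMSNorm":
#               bind_forward(mod, rmsnorm_fwd)
```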
@@ -9,26 +9,21 @@ except ImportError:

 # There may exist import error even if we have triton installed.
 if HAS_TRITON:
-    from .context_attention import bloom_context_attn_fwd, llama2_context_attn_fwd, llama_context_attn_fwd
+    from .context_attention import bloom_context_attn_fwd, llama_context_attn_fwd
     from .copy_kv_cache_dest import copy_kv_cache_to_dest
     from .fused_layernorm import layer_norm
     from .gptq_triton import gptq_fused_linear_triton
     from .int8_rotary_embedding_kernel import int8_rotary_embedding_fwd
-    from .rms_norm import rmsnorm_forward
-    from .rotary_embedding_kernel import rotary_embedding_fwd
     from .smooth_attention import smooth_llama_context_attn_fwd, smooth_token_attention_fwd
     from .softmax import softmax
     from .token_attention_kernel import token_attention_fwd

     __all__ = [
         "llama_context_attn_fwd",
-        "llama2_context_attn_fwd",
         "bloom_context_attn_fwd",
         "softmax",
         "layer_norm",
-        "rmsnorm_forward",
         "copy_kv_cache_to_dest",
-        "rotary_embedding_fwd",
         "token_attention_fwd",
         "gptq_fused_linear_triton",
         "int8_rotary_embedding_fwd",
@ -238,329 +238,5 @@ if HAS_TRITON:
|
|||||||
num_warps=num_warps,
|
num_warps=num_warps,
|
||||||
num_stages=1,
|
num_stages=1,
|
||||||
)
|
)
|
||||||
return
|
|
||||||
|
|
||||||
@triton.jit
|
|
||||||
def _fwd_kernel_latest(
|
|
||||||
Q,
|
|
||||||
K,
|
|
||||||
V,
|
|
||||||
sm_scale,
|
|
||||||
B_Start_Loc,
|
|
||||||
B_Seqlen,
|
|
||||||
Out,
|
|
||||||
stride_qbs,
|
|
||||||
stride_qh,
|
|
||||||
stride_qd,
|
|
||||||
stride_kbs,
|
|
||||||
stride_kh,
|
|
||||||
stride_kd,
|
|
||||||
stride_vbs,
|
|
||||||
stride_vh,
|
|
||||||
stride_vd,
|
|
||||||
stride_obs,
|
|
||||||
stride_oh,
|
|
||||||
stride_od,
|
|
||||||
kv_group_num,
|
|
||||||
BLOCK_M: tl.constexpr,
|
|
||||||
BLOCK_DMODEL: tl.constexpr,
|
|
||||||
BLOCK_N: tl.constexpr,
|
|
||||||
):
|
|
||||||
cur_batch = tl.program_id(0)
|
|
||||||
cur_head = tl.program_id(1)
|
|
||||||
start_m = tl.program_id(2)
|
|
||||||
|
|
||||||
cur_kv_head = cur_head // kv_group_num
|
|
||||||
|
|
||||||
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
|
|
||||||
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
|
|
||||||
|
|
||||||
block_start_loc = BLOCK_M * start_m
|
|
||||||
|
|
||||||
# initialize offsets
|
|
||||||
offs_n = tl.arange(0, BLOCK_N)
|
|
||||||
offs_d = tl.arange(0, BLOCK_DMODEL)
|
|
||||||
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
|
||||||
off_q = (
|
|
||||||
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs
|
|
||||||
+ cur_head * stride_qh
|
|
||||||
+ offs_d[None, :] * stride_qd
|
|
||||||
)
|
|
||||||
off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd
|
|
||||||
off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd
|
|
||||||
|
|
||||||
q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0)
|
|
||||||
|
|
||||||
k_ptrs = K + off_k
|
|
||||||
v_ptrs = V + off_v
|
|
||||||
|
|
||||||
# initialize pointer to m and l
|
|
||||||
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
|
||||||
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
|
|
||||||
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
|
||||||
|
|
||||||
block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)
|
|
||||||
|
|
||||||
for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
|
|
||||||
start_n = tl.multiple_of(start_n, BLOCK_N)
|
|
||||||
# -- compute qk ----
|
|
||||||
k = tl.load(
|
|
||||||
k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
|
|
||||||
mask=(start_n + offs_n[None, :]) < cur_batch_seq_len,
|
|
||||||
other=0.0,
|
|
||||||
)
|
|
||||||
# mask = tl.load(mask_ptrs + start_n, mask=start_n + offs_n < cur_batch_end_loc, other=0.0)
|
|
||||||
|
|
||||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
|
||||||
qk += tl.dot(q, k)
|
|
||||||
qk *= sm_scale
|
|
||||||
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
|
|
||||||
|
|
||||||
# -- compute m_ij, p, l_ij
|
|
||||||
m_ij = tl.max(qk, 1)
|
|
||||||
p = tl.exp(qk - m_ij[:, None])
|
|
||||||
l_ij = tl.sum(p, 1)
|
|
||||||
# -- update m_i and l_i
|
|
||||||
m_i_new = tl.maximum(m_i, m_ij)
|
|
||||||
alpha = tl.exp(m_i - m_i_new)
|
|
||||||
beta = tl.exp(m_ij - m_i_new)
|
|
||||||
l_i_new = alpha * l_i + beta * l_ij
|
|
||||||
# -- update output accumulator --
|
|
||||||
# scale p
|
|
||||||
p_scale = beta / l_i_new
|
|
||||||
p = p * p_scale[:, None]
|
|
||||||
# scale acc
|
|
||||||
acc_scale = l_i / l_i_new * alpha
|
|
||||||
acc = acc * acc_scale[:, None]
|
|
||||||
# update acc
|
|
||||||
v = tl.load(
|
|
||||||
v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
|
|
||||||
mask=(start_n + offs_n[:, None]) < cur_batch_seq_len,
|
|
||||||
other=0.0,
|
|
||||||
)
|
|
||||||
|
|
||||||
p = p.to(v.dtype)
|
|
||||||
acc += tl.dot(p, v)
|
|
||||||
# update m_i and l_i
|
|
||||||
l_i = l_i_new
|
|
||||||
m_i = m_i_new
|
|
||||||
# initialize pointers to output
|
|
||||||
off_o = (
|
|
||||||
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs
|
|
||||||
+ cur_head * stride_oh
|
|
||||||
+ offs_d[None, :] * stride_od
|
|
||||||
)
|
|
||||||
out_ptrs = Out + off_o
|
|
||||||
tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
|
|
||||||
return
|
|
||||||
|
|
||||||
@triton.jit
|
|
||||||
def _fwd_kernel_old(
|
|
||||||
Q,
|
|
||||||
K,
|
|
||||||
V,
|
|
||||||
sm_scale,
|
|
||||||
B_Start_Loc,
|
|
||||||
B_Seqlen,
|
|
||||||
TMP, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
|
|
||||||
Out,
|
|
||||||
stride_qbs,
|
|
||||||
stride_qh,
|
|
||||||
stride_qd,
|
|
||||||
stride_kbs,
|
|
||||||
stride_kh,
|
|
||||||
stride_kd,
|
|
||||||
stride_vbs,
|
|
||||||
stride_vh,
|
|
||||||
stride_vd,
|
|
||||||
stride_obs,
|
|
||||||
stride_oh,
|
|
||||||
stride_od,
|
|
||||||
stride_tmp_b,
|
|
||||||
stride_tmp_h,
|
|
||||||
stride_tmp_s,
|
|
||||||
kv_group_num,
|
|
||||||
BLOCK_M: tl.constexpr,
|
|
||||||
BLOCK_DMODEL: tl.constexpr,
|
|
||||||
BLOCK_N: tl.constexpr,
|
|
||||||
):
|
|
||||||
cur_batch = tl.program_id(0)
|
|
||||||
cur_head = tl.program_id(1)
|
|
||||||
start_m = tl.program_id(2)
|
|
||||||
|
|
||||||
cur_kv_head = cur_head // kv_group_num
|
|
||||||
|
|
||||||
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
|
|
||||||
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
|
|
||||||
|
|
||||||
block_start_loc = BLOCK_M * start_m
|
|
||||||
|
|
||||||
# initialize offsets
|
|
||||||
offs_n = tl.arange(0, BLOCK_N)
|
|
||||||
offs_d = tl.arange(0, BLOCK_DMODEL)
|
|
||||||
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
|
||||||
off_q = (
|
|
||||||
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs
|
|
||||||
+ cur_head * stride_qh
|
|
||||||
+ offs_d[None, :] * stride_qd
|
|
||||||
)
|
|
||||||
off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd
|
|
||||||
off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd
|
|
||||||
q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0)
|
|
||||||
|
|
||||||
k_ptrs = K + off_k
|
|
||||||
v_ptrs = V + off_v
|
|
||||||
|
|
||||||
t_ptrs = TMP + cur_batch * stride_tmp_b + cur_head * stride_tmp_h + offs_m * stride_tmp_s
|
|
||||||
# t_ptrs = TMP + offs_m
|
|
||||||
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
|
||||||
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
|
|
||||||
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
|
||||||
|
|
||||||
block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)
|
|
||||||
|
|
||||||
for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
|
|
||||||
start_n = tl.multiple_of(start_n, BLOCK_N)
|
|
||||||
# -- compute qk ----
|
|
||||||
k = tl.load(
|
|
||||||
k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
|
|
||||||
mask=(start_n + offs_n[None, :]) < cur_batch_seq_len,
|
|
||||||
other=0.0,
|
|
||||||
)
|
|
||||||
|
|
||||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
|
||||||
qk += tl.dot(q, k)
|
|
||||||
qk *= sm_scale
|
|
||||||
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
|
|
||||||
|
|
||||||
m_ij = tl.max(qk, 1)
|
|
||||||
p = tl.exp(qk - m_ij[:, None])
|
|
||||||
l_ij = tl.sum(p, 1)
|
|
||||||
# -- update m_i and l_i
|
|
||||||
m_i_new = tl.maximum(m_i, m_ij)
|
|
||||||
alpha = tl.exp(m_i - m_i_new)
|
|
||||||
beta = tl.exp(m_ij - m_i_new)
|
|
||||||
l_i_new = alpha * l_i + beta * l_ij
|
|
||||||
# -- update output accumulator --
|
|
||||||
# scale p
|
|
||||||
p_scale = beta / l_i_new
|
|
||||||
p = p * p_scale[:, None]
|
|
||||||
# scale acc
|
|
||||||
acc_scale = l_i / l_i_new * alpha
|
|
||||||
tl.store(t_ptrs, acc_scale)
|
|
||||||
acc_scale = tl.load(t_ptrs) # BUG: have to store and immediately load
|
|
||||||
acc = acc * acc_scale[:, None]
|
|
||||||
# update acc
|
|
||||||
v = tl.load(
|
|
||||||
v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
|
|
||||||
mask=(start_n + offs_n[:, None]) < cur_batch_seq_len,
|
|
||||||
other=0.0,
|
|
||||||
)
|
|
||||||
|
|
||||||
p = p.to(v.dtype)
|
|
||||||
acc += tl.dot(p, v)
|
|
||||||
# update m_i and l_i
|
|
||||||
l_i = l_i_new
|
|
||||||
m_i = m_i_new
|
|
||||||
# initialize pointers to output
|
|
||||||
off_o = (
|
|
||||||
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs
|
|
||||||
+ cur_head * stride_oh
|
|
||||||
+ offs_d[None, :] * stride_od
|
|
||||||
)
|
|
||||||
out_ptrs = Out + off_o
|
|
||||||
tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
|
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def llama2_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
|
|
||||||
if triton.__version__ >= "2.1.0":
|
|
||||||
BLOCK = 128
|
|
||||||
# shape constraints
|
|
||||||
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
|
|
||||||
assert Lq == Lk and Lk == Lv
|
|
||||||
assert Lk in {16, 32, 64, 128}
|
|
||||||
sm_scale = 1.0 / (Lq**0.5) # 计算scale系数
|
|
||||||
batch, head = b_seq_len.shape[0], q.shape[1]
|
|
||||||
kv_group_num = q.shape[1] // k.shape[1]
|
|
||||||
|
|
||||||
grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head,
|
|
||||||
|
|
||||||
num_warps = 4 if Lk <= 64 else 8
|
|
||||||
_fwd_kernel_latest[grid](
|
|
||||||
q,
|
|
||||||
k,
|
|
||||||
v,
|
|
||||||
sm_scale,
|
|
||||||
b_start_loc,
|
|
||||||
b_seq_len,
|
|
||||||
o,
|
|
||||||
q.stride(0),
|
|
||||||
q.stride(1),
|
|
||||||
q.stride(2),
|
|
||||||
k.stride(0),
|
|
||||||
k.stride(1),
|
|
||||||
k.stride(2),
|
|
||||||
v.stride(0),
|
|
||||||
v.stride(1),
|
|
||||||
v.stride(2),
|
|
||||||
o.stride(0),
|
|
||||||
o.stride(1),
|
|
||||||
o.stride(2),
|
|
||||||
kv_group_num=kv_group_num,
|
|
||||||
BLOCK_M=BLOCK,
|
|
||||||
BLOCK_DMODEL=Lk,
|
|
||||||
BLOCK_N=BLOCK,
|
|
||||||
num_warps=num_warps,
|
|
||||||
num_stages=1,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
elif triton.__version__ == "2.0.0":
|
|
||||||
BLOCK = 128
|
|
||||||
# shape constraints
|
|
||||||
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
|
|
||||||
assert Lq == Lk and Lk == Lv
|
|
||||||
assert Lk in {16, 32, 64, 128}
|
|
||||||
|
|
||||||
sm_scale = 1.0 / (Lq**0.5)
|
|
||||||
batch, head = b_seq_len.shape[0], q.shape[1]
|
|
||||||
kv_group_num = q.shape[1] // k.shape[1]
|
|
||||||
|
|
||||||
grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
|
|
||||||
tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32)
|
|
||||||
num_warps = 4 if Lk <= 64 else 8
|
|
||||||
# num_warps = 4
|
|
||||||
_fwd_kernel_old[grid](
|
|
||||||
q,
|
|
||||||
k,
|
|
||||||
v,
|
|
||||||
sm_scale,
|
|
||||||
b_start_loc,
|
|
||||||
b_seq_len,
|
|
||||||
tmp,
|
|
||||||
o,
|
|
||||||
q.stride(0),
|
|
||||||
q.stride(1),
|
|
||||||
q.stride(2),
|
|
||||||
k.stride(0),
|
|
||||||
k.stride(1),
|
|
||||||
k.stride(2),
|
|
||||||
v.stride(0),
|
|
||||||
v.stride(1),
|
|
||||||
v.stride(2),
|
|
||||||
o.stride(0),
|
|
||||||
o.stride(1),
|
|
||||||
o.stride(2),
|
|
||||||
tmp.stride(0),
|
|
||||||
tmp.stride(1),
|
|
||||||
tmp.stride(2),
|
|
||||||
kv_group_num=kv_group_num,
|
|
||||||
BLOCK_M=BLOCK,
|
|
||||||
BLOCK_DMODEL=Lk,
|
|
||||||
BLOCK_N=BLOCK,
|
|
||||||
num_warps=num_warps,
|
|
||||||
num_stages=1,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
@@ -11,6 +11,7 @@ except ImportError:

 if HAS_TRITON:

+    # adapted from https://github.com/ModelTC/lightllm/blob/5c559dd7981ed67679a08a1e09a88fb4c1550b3a/lightllm/common/triton_kernel/destindex_copy_kv.py
     @triton.jit
     def _fwd_copy_kv_cache_dest(
         kv_cache_ptr,
@@ -42,6 +43,7 @@ if HAS_TRITON:
         tl.store(o_ptrs, k, mask=offs_h[:, None] < head_num)
         return

+    # adapted from https://github.com/ModelTC/lightllm/blob/5c559dd7981ed67679a08a1e09a88fb4c1550b3a/lightllm/common/triton_kernel/destindex_copy_kv.py
    @torch.no_grad()
    def copy_kv_cache_to_dest(k_ptr, dest_index_ptr, out):
        seq_len = dest_index_ptr.shape[0]
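Functionally, `copy_kv_cache_to_dest` scatters freshly computed K/V rows into their assigned cache slots. A plain PyTorch equivalent (a sketch of the semantics, not the Triton implementation) is essentially an index assignment:

```python
import torch

# Reference semantics for copy_kv_cache_to_dest: write each of the seq_len new
# key/value rows (k_ptr) into the cache buffer (out) at the slot given by dest_index_ptr.
def copy_kv_cache_to_dest_reference(k_ptr: torch.Tensor,
                                    dest_index_ptr: torch.Tensor,
                                    out: torch.Tensor) -> None:
    # k_ptr: [seq_len, head_num, head_dim], dest_index_ptr: [seq_len] (int64 slot ids)
    # out:   [cache_size, head_num, head_dim]
    out[dest_index_ptr] = k_ptr.to(out.dtype)
```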
@@ -1,71 +0,0 @@
-import torch
-
-try:
-    import triton
-    import triton.language as tl
-
-    HAS_TRITON = True
-except ImportError:
-    HAS_TRITON = False
-    print("please install triton from https://github.com/openai/triton")
-
-
-if HAS_TRITON:
-    """
-    this kernel function is modified from
-    https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/rmsnorm.py
-    """
-
-    @triton.jit
-    def _rms_norm_fwd_fused(
-        X,  # pointer to the input
-        Y,  # pointer to the output
-        W,  # pointer to the weights
-        stride,  # how much to increase the pointer when moving by 1 row
-        N,  # number of columns in X
-        eps,  # epsilon to avoid division by zero
-        BLOCK_SIZE: tl.constexpr,
-    ):
-        # Map the program id to the row of X and Y it should compute.
-        row = tl.program_id(0)
-        Y += row * stride
-        X += row * stride
-        # Compute variance
-        _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
-        for off in range(0, N, BLOCK_SIZE):
-            cols = off + tl.arange(0, BLOCK_SIZE)
-            x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
-            _var += x * x
-        var = tl.sum(_var, axis=0) / N
-        rstd = 1 / tl.sqrt(var + eps)
-        # Normalize and apply linear transformation
-        for off in range(0, N, BLOCK_SIZE):
-            cols = off + tl.arange(0, BLOCK_SIZE)
-            mask = cols < N
-            w = tl.load(W + cols, mask=mask).to(tl.float32)
-            x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32)
-            x_hat = x * rstd
-            y = x_hat * w
-            # Write output
-            tl.store(Y + cols, y.to(tl.float16), mask=mask)
-
-    def rmsnorm_forward(x, weight, eps):
-        # allocate output
-        y = torch.empty_like(x)
-        # reshape input data into 2D tensor
-        x_arg = x.view(-1, x.shape[-1])
-        M, N = x_arg.shape
-        # Less than 64KB per feature: enqueue fused kernel
-        MAX_FUSED_SIZE = 65536 // x.element_size()
-        BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
-        # print("BLOCK_SIZE:", BLOCK_SIZE)
-        if N > BLOCK_SIZE:
-            raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
-        # heuristics for number of warps
-        num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
-        # print(BLOCK_SIZE, num_warps, "block_size, numwarps")
-        BLOCK_SIZE = 128 * 2 * 2 * 2 * 2 * 2 * 2 * 2
-        num_warps = 8
-        # enqueue kernel
-        _rms_norm_fwd_fused[(M,)](x_arg, y, weight, x_arg.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)
-        return y
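The deleted `rmsnorm_forward` above computes standard RMSNorm, y = x / sqrt(mean(x^2) + eps) * w, row by row. A plain PyTorch reference (a sketch of the same math, useful for checking a replacement kernel against) is:

```python
import torch

# Reference RMSNorm matching the deleted Triton kernel's math:
# normalize each row by the root-mean-square of its elements, then scale by `weight`.
def rmsnorm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    x32 = x.float()
    rstd = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + eps)
    return (x32 * rstd * weight.float()).to(x.dtype)

# usage sketch:
#   y_ref = rmsnorm_reference(hidden_states, layer.weight.data, layer.variance_epsilon)
```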
@ -1,212 +0,0 @@
|
|||||||
# Adapted from ModelTC https://github.com/ModelTC/lightllm
|
|
||||||
import torch
|
|
||||||
import triton
|
|
||||||
import triton.language as tl
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
|
||||||
def _rotary_kernel(
|
|
||||||
q,
|
|
||||||
Cos,
|
|
||||||
Sin,
|
|
||||||
q_bs_stride,
|
|
||||||
q_h_stride,
|
|
||||||
q_d_stride,
|
|
||||||
cos_bs_stride,
|
|
||||||
cos_d_stride,
|
|
||||||
total_len,
|
|
||||||
HEAD_NUM: tl.constexpr,
|
|
||||||
BLOCK_HEAD: tl.constexpr,
|
|
||||||
BLOCK_SEQ: tl.constexpr,
|
|
||||||
HEAD_DIM: tl.constexpr,
|
|
||||||
):
|
|
||||||
current_head_index = tl.program_id(0)
|
|
||||||
current_seq_index = tl.program_id(1)
|
|
||||||
|
|
||||||
current_head_range = current_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD)
|
|
||||||
current_seq_range = current_seq_index * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ)
|
|
||||||
|
|
||||||
dim_range0 = tl.arange(0, HEAD_DIM // 2)
|
|
||||||
dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM)
|
|
||||||
|
|
||||||
off_q0 = (
|
|
||||||
current_seq_range[:, None, None] * q_bs_stride
|
|
||||||
+ current_head_range[None, :, None] * q_h_stride
|
|
||||||
+ dim_range0[None, None, :] * q_d_stride
|
|
||||||
)
|
|
||||||
off_q1 = (
|
|
||||||
current_seq_range[:, None, None] * q_bs_stride
|
|
||||||
+ current_head_range[None, :, None] * q_h_stride
|
|
||||||
+ dim_range1[None, None, :] * q_d_stride
|
|
||||||
)
|
|
||||||
|
|
||||||
off_dimcos_sin = current_seq_range[:, None, None] * cos_bs_stride + dim_range0[None, None, :] * cos_d_stride
|
|
||||||
|
|
||||||
q0 = tl.load(
|
|
||||||
q + off_q0,
|
|
||||||
mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
|
|
||||||
other=0.0,
|
|
||||||
)
|
|
||||||
q1 = tl.load(
|
|
||||||
q + off_q1,
|
|
||||||
mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
|
|
||||||
other=0.0,
|
|
||||||
)
|
|
||||||
|
|
||||||
cos = tl.load(Cos + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0)
|
|
||||||
sin = tl.load(Sin + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0)
|
|
||||||
|
|
||||||
out0 = q0 * cos - q1 * sin
|
|
||||||
out1 = q0 * sin + q1 * cos
|
|
||||||
|
|
||||||
tl.store(
|
|
||||||
q + off_q0,
|
|
||||||
out0,
|
|
||||||
mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
|
|
||||||
)
|
|
||||||
tl.store(
|
|
||||||
q + off_q1,
|
|
||||||
out1,
|
|
||||||
mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
|
|
||||||
)
|
|
||||||
|
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def rotary_embedding_fwd(q, cos, sin):
|
|
||||||
total_len = q.shape[0]
|
|
||||||
head_num = q.shape[1]
|
|
||||||
head_dim = q.shape[2]
|
|
||||||
assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}"
|
|
||||||
BLOCK_HEAD = 4
|
|
||||||
BLOCK_SEQ = 32
|
|
||||||
grid = (triton.cdiv(head_num, BLOCK_HEAD), triton.cdiv(total_len, BLOCK_SEQ))
|
|
||||||
if head_dim >= 128:
|
|
||||||
num_warps = 8
|
|
||||||
else:
|
|
||||||
num_warps = 4
|
|
||||||
|
|
||||||
_rotary_kernel[grid](
|
|
||||||
q,
|
|
||||||
cos,
|
|
||||||
sin,
|
|
||||||
q.stride(0),
|
|
||||||
q.stride(1),
|
|
||||||
q.stride(2),
|
|
||||||
cos.stride(0),
|
|
||||||
cos.stride(1),
|
|
||||||
total_len,
|
|
||||||
HEAD_NUM=head_num,
|
|
||||||
BLOCK_HEAD=BLOCK_HEAD,
|
|
||||||
BLOCK_SEQ=BLOCK_SEQ,
|
|
||||||
HEAD_DIM=head_dim,
|
|
||||||
num_warps=num_warps,
|
|
||||||
num_stages=1,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
class Llama2Forwards:
|
|
||||||
@staticmethod
|
|
||||||
@triton.jit
|
|
||||||
def _rotary_kernel(
|
|
||||||
Q,
|
|
||||||
Cos,
|
|
||||||
Sin,
|
|
||||||
stride_qbs,
|
|
||||||
stride_qh,
|
|
||||||
stride_qd,
|
|
||||||
stride_cosbs,
|
|
||||||
stride_cosd,
|
|
||||||
stride_sinbs,
|
|
||||||
stride_sind,
|
|
||||||
max_total_len,
|
|
||||||
H, # N_CTX
|
|
||||||
BLOCK_HEAD: tl.constexpr,
|
|
||||||
BLOCK_SEQ: tl.constexpr,
|
|
||||||
BLOCK_DMODEL: tl.constexpr,
|
|
||||||
):
|
|
||||||
cur_head_index = tl.program_id(0)
|
|
||||||
cur_seq_index = tl.program_id(1)
|
|
||||||
|
|
||||||
cur_head_range = cur_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD)
|
|
||||||
cur_seq_range = cur_seq_index * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ)
|
|
||||||
|
|
||||||
dim_range0 = tl.arange(0, BLOCK_DMODEL // 2) * 2
|
|
||||||
dim_range1 = dim_range0 + 1
|
|
||||||
off_q0 = (
|
|
||||||
cur_seq_range[:, None, None] * stride_qbs
|
|
||||||
+ cur_head_range[None, :, None] * stride_qh
|
|
||||||
+ dim_range0[None, None, :] * stride_qd
|
|
||||||
)
|
|
||||||
off_q1 = (
|
|
||||||
cur_seq_range[:, None, None] * stride_qbs
|
|
||||||
+ cur_head_range[None, :, None] * stride_qh
|
|
||||||
+ dim_range1[None, None, :] * stride_qd
|
|
||||||
)
|
|
||||||
|
|
||||||
cos_range = tl.arange(0, BLOCK_DMODEL // 2)
|
|
||||||
off_dimcos_sin = cur_seq_range[:, None, None] * stride_cosbs + cos_range[None, None, :] * stride_cosd
|
|
||||||
|
|
||||||
q0 = tl.load(
|
|
||||||
Q + off_q0,
|
|
||||||
mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < H),
|
|
||||||
other=0.0,
|
|
||||||
)
|
|
||||||
q1 = tl.load(
|
|
||||||
Q + off_q1,
|
|
||||||
mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < H),
|
|
||||||
other=0.0,
|
|
||||||
)
|
|
||||||
|
|
||||||
cos = tl.load(Cos + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0)
|
|
||||||
sin = tl.load(Sin + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0)
|
|
||||||
|
|
||||||
out0 = q0 * cos - q1 * sin
|
|
||||||
out1 = q0 * sin + q1 * cos
|
|
||||||
|
|
||||||
tl.store(
|
|
||||||
Q + off_q0, out0, mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < H)
|
|
||||||
)
|
|
||||||
tl.store(
|
|
||||||
Q + off_q1, out1, mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < H)
|
|
||||||
)
|
|
||||||
|
|
||||||
return
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@torch.no_grad()
|
|
||||||
def rotary_emb_fwd(q, cos, sin):
|
|
||||||
total_len = q.shape[0]
|
|
||||||
head_num = q.shape[1]
|
|
||||||
head_dim = q.shape[2] // 2
|
|
||||||
assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}"
|
|
||||||
BLOCK_HEAD = 4
|
|
||||||
BLOCK_SEQ = 32
|
|
||||||
grid = (triton.cdiv(head_num, BLOCK_HEAD), triton.cdiv(total_len, BLOCK_SEQ))
|
|
||||||
if head_dim >= 128:
|
|
||||||
num_warps = 8
|
|
||||||
else:
|
|
||||||
num_warps = 4
|
|
||||||
|
|
||||||
Llama2Forwards._rotary_kernel[grid](
|
|
||||||
q,
|
|
||||||
cos,
|
|
||||||
sin,
|
|
||||||
q.stride(0),
|
|
||||||
q.stride(1),
|
|
||||||
q.stride(2),
|
|
||||||
cos.stride(0),
|
|
||||||
cos.stride(1),
|
|
||||||
sin.stride(0),
|
|
||||||
sin.stride(1),
|
|
||||||
total_len,
|
|
||||||
head_num,
|
|
||||||
BLOCK_HEAD=BLOCK_HEAD,
|
|
||||||
BLOCK_SEQ=BLOCK_SEQ,
|
|
||||||
BLOCK_DMODEL=head_dim,
|
|
||||||
num_warps=num_warps,
|
|
||||||
num_stages=1,
|
|
||||||
)
|
|
||||||
return
|
|
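The deleted rotary-embedding kernels above apply, per position, the rotation out0 = q0*cos - q1*sin and out1 = q0*sin + q1*cos to paired feature dimensions (the `Llama2Forwards` variant pairs adjacent dims; `rotary_embedding_fwd` pairs the two halves of each head). A small PyTorch sketch of the half-split variant (illustrative, in-place like the kernel) is:

```python
import torch

# Sketch of the rotary transform the deleted rotary_embedding_fwd applied in place:
# split each head's dims into two halves (q0, q1) and rotate them by (cos, sin).
@torch.no_grad()
def rotary_embedding_reference(q: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> None:
    # q:   [total_tokens, head_num, head_dim]
    # cos: [total_tokens, head_dim // 2], sin: same shape
    half = q.shape[-1] // 2
    q0, q1 = q[..., :half], q[..., half:]
    c = cos.unsqueeze(1)  # broadcast over heads
    s = sin.unsqueeze(1)
    out0 = q0 * c - q1 * s
    out1 = q0 * s + q1 * c
    q[..., :half] = out0
    q[..., half:] = out1
```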
@@ -12,6 +12,7 @@ if HAS_TRITON:
     from .qkv_matmul_kernel import qkv_gemm_4d_kernel
     from .softmax import softmax_kernel

+    # adapted from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/ops/transformer/inference/triton/triton_matmul_kernel.py#L312
     def self_attention_forward_without_fusion(
         q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, input_mask: torch.Tensor, scale: float
     ):
@@ -141,6 +142,7 @@ if HAS_TRITON:
         )
         return output.view(batches, -1, d_model)

+    # modified from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/ops/transformer/inference/triton/attention.py#L212
     def self_attention_compute_using_triton(
         qkv, input_mask, layer_past, alibi, scale, head_size, triangular=False, use_flash=False
     ):
@@ -12,363 +12,29 @@ except ImportError:
     HAS_TRITON = False
     print("please install triton from https://github.com/openai/triton")

+try:
+    from lightllm.models.llama2.triton_kernel.token_attention_nopad_att1 import (
+        token_att_fwd as lightllm_llama2_token_att_fwd,
+    )
+    from lightllm.models.llama2.triton_kernel.token_attention_nopad_reduceV import (
+        token_att_fwd2 as lightllm_llama2_token_att_fwd2,
+    )
+    from lightllm.models.llama2.triton_kernel.token_attention_nopad_softmax import (
+        token_softmax_fwd as lightllm_llama2_token_softmax_fwd,
+    )
+
+    from lightllm.models.llama.triton_kernel.token_attention_nopad_reduceV import token_att_fwd2 as lightllm_llama_token_att_fw2
+    from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd as lightllm_llama_token_att_fwd
+    from lightllm.models.llama.triton_kernel.token_attention_nopad_softmax import token_softmax_fwd as lightllm_llama_token_softmax_fwd
+    from lightllm.models.bloom.triton_kernel.token_attention_nopad_att1 import token_att_fwd as lightllm_bloom_token_att_fwd
+
+    HAS_TRITON_TOKEN_ATTENTION = True
+except ImportError:
+    print("unable to import lightllm kernels")
+    HAS_TRITON_TOKEN_ATTENTION = False
+
 if HAS_TRITON:
|
|
||||||
@triton.jit
|
|
||||||
def _token_attn_1_kernel(
|
|
||||||
Q,
|
|
||||||
K,
|
|
||||||
sm_scale,
|
|
||||||
kv_cache_loc,
|
|
||||||
kv_cache_start_loc,
|
|
||||||
kv_cache_seqlen,
|
|
||||||
max_kv_cache_len,
|
|
||||||
attn_out,
|
|
||||||
kv_cache_loc_b_stride,
|
|
||||||
kv_cache_loc_s_stride,
|
|
||||||
q_batch_stride,
|
|
||||||
q_head_stride,
|
|
||||||
q_head_dim_stride,
|
|
||||||
k_batch_stride,
|
|
||||||
k_head_stride,
|
|
||||||
k_head_dim_stride,
|
|
||||||
attn_head_stride,
|
|
||||||
attn_batch_stride,
|
|
||||||
HEAD_DIM: tl.constexpr,
|
|
||||||
BLOCK_N: tl.constexpr,
|
|
||||||
):
|
|
||||||
current_batch = tl.program_id(0)
|
|
||||||
current_head = tl.program_id(1)
|
|
||||||
start_n = tl.program_id(2)
|
|
||||||
|
|
||||||
offs_d = tl.arange(0, HEAD_DIM)
|
|
||||||
current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch)
|
|
||||||
current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch)
|
|
||||||
|
|
||||||
current_batch_start_index = max_kv_cache_len - current_batch_seq_len
|
|
||||||
current_batch_end_index = max_kv_cache_len
|
|
||||||
|
|
||||||
off_q = current_batch * q_batch_stride + current_head * q_head_stride + offs_d * q_head_dim_stride
|
|
||||||
|
|
||||||
offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
|
||||||
|
|
||||||
block_stard_index = start_n * BLOCK_N
|
|
||||||
block_mask = tl.where(block_stard_index < current_batch_seq_len, 1, 0)
|
|
||||||
|
|
||||||
for start_mark in range(0, block_mask, 1):
|
|
||||||
q = tl.load(Q + off_q + start_mark)
|
|
||||||
offs_n_new = current_batch_start_index + offs_n
|
|
||||||
k_loc = tl.load(
|
|
||||||
kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new,
|
|
||||||
mask=offs_n_new < current_batch_end_index,
|
|
||||||
other=0,
|
|
||||||
)
|
|
||||||
off_k = k_loc[:, None] * k_batch_stride + current_head * k_head_stride + offs_d[None, :] * k_head_dim_stride
|
|
||||||
k = tl.load(K + off_k, mask=offs_n_new[:, None] < current_batch_end_index, other=0.0)
|
|
||||||
att_value = tl.sum(q[None, :] * k, 1)
|
|
||||||
att_value *= sm_scale
|
|
||||||
off_o = current_head * attn_head_stride + (current_batch_in_all_start_index + offs_n) * attn_batch_stride
|
|
||||||
tl.store(attn_out + off_o, att_value, mask=offs_n_new < current_batch_end_index)
|
|
||||||
return
|
|
||||||
|
|
||||||
@triton.jit
|
|
||||||
def _token_attn_1_alibi_kernel(
|
|
||||||
Q,
|
|
||||||
K,
|
|
||||||
sm_scale,
|
|
||||||
alibi,
|
|
||||||
kv_cache_loc,
|
|
||||||
kv_cache_start_loc,
|
|
||||||
kv_cache_seqlen,
|
|
||||||
max_kv_cache_len,
|
|
||||||
attn_out,
|
|
||||||
kv_cache_loc_b_stride,
|
|
||||||
kv_cache_loc_s_stride,
|
|
||||||
q_batch_stride,
|
|
||||||
q_head_stride,
|
|
||||||
q_head_dim_stride,
|
|
||||||
k_batch_stride,
|
|
||||||
k_head_stride,
|
|
||||||
k_head_dim_stride,
|
|
||||||
attn_head_stride,
|
|
||||||
attn_batch_stride,
|
|
||||||
HEAD_DIM: tl.constexpr,
|
|
||||||
BLOCK_N: tl.constexpr,
|
|
||||||
):
|
|
||||||
current_batch = tl.program_id(0)
|
|
||||||
current_head = tl.program_id(1)
|
|
||||||
start_n = tl.program_id(2)
|
|
||||||
|
|
||||||
offs_d = tl.arange(0, HEAD_DIM)
|
|
||||||
current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch)
|
|
||||||
current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch)
|
|
||||||
|
|
||||||
current_batch_start_index = max_kv_cache_len - current_batch_seq_len
|
|
||||||
current_batch_end_index = max_kv_cache_len
|
|
||||||
|
|
||||||
off_q = current_batch * q_batch_stride + current_head * q_head_stride + offs_d * q_head_dim_stride
|
|
||||||
|
|
||||||
offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
|
||||||
|
|
||||||
block_stard_index = start_n * BLOCK_N
|
|
||||||
block_mask = tl.where(block_stard_index < current_batch_seq_len, 1, 0)
|
|
||||||
|
|
||||||
for start_mark in range(0, block_mask, 1):
|
|
||||||
alibi_m = tl.load(alibi + current_head)
|
|
||||||
q = tl.load(Q + off_q + start_mark)
|
|
||||||
offs_n_new = current_batch_start_index + offs_n
|
|
||||||
k_loc = tl.load(
|
|
||||||
kv_cache_loc + kv_cache_loc_b_stride * current_batch + kv_cache_loc_s_stride * offs_n_new,
|
|
||||||
mask=offs_n_new < current_batch_end_index,
|
|
||||||
other=0,
|
|
||||||
)
|
|
||||||
off_k = k_loc[:, None] * k_batch_stride + current_head * k_head_stride + offs_d[None, :] * k_head_dim_stride
|
|
||||||
k = tl.load(K + off_k, mask=offs_n_new[:, None] < current_batch_end_index, other=0.0)
|
|
||||||
att_value = tl.sum(q[None, :] * k, 1)
|
|
||||||
att_value *= sm_scale
|
|
||||||
att_value -= alibi_m * (current_batch_seq_len - 1 - offs_n)
|
|
||||||
off_o = current_head * attn_head_stride + (current_batch_in_all_start_index + offs_n) * attn_batch_stride
|
|
||||||
tl.store(attn_out + off_o, att_value, mask=offs_n_new < current_batch_end_index)
|
|
||||||
return
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def token_attn_fwd_1(
|
|
||||||
q, k, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seqlen, max_kv_cache_len, alibi=None
|
|
||||||
):
|
|
||||||
BLOCK = 32
|
|
||||||
# shape constraints
|
|
||||||
q_head_dim, k_head_dim = q.shape[-1], k.shape[-1]
|
|
||||||
assert q_head_dim == k_head_dim
|
|
||||||
assert k_head_dim in {16, 32, 64, 128}
|
|
||||||
sm_scale = 1.0 / (k_head_dim**0.5)
|
|
||||||
|
|
||||||
batch, head_num = kv_cache_loc.shape[0], q.shape[1]
|
|
||||||
|
|
||||||
grid = (batch, head_num, triton.cdiv(max_kv_cache_len, BLOCK))
|
|
||||||
|
|
||||||
num_warps = 4 if k_head_dim <= 64 else 8
|
|
||||||
num_warps = 2
|
|
||||||
|
|
||||||
if alibi is not None:
|
|
||||||
_token_attn_1_alibi_kernel[grid](
|
|
||||||
q,
|
|
||||||
k,
|
|
||||||
sm_scale,
|
|
||||||
alibi,
|
|
||||||
kv_cache_loc,
|
|
||||||
kv_cache_start_loc,
|
|
||||||
kv_cache_seqlen,
|
|
||||||
max_kv_cache_len,
|
|
||||||
attn_out,
|
|
||||||
kv_cache_loc.stride(0),
|
|
||||||
kv_cache_loc.stride(1),
|
|
||||||
q.stride(0),
|
|
||||||
q.stride(1),
|
|
||||||
q.stride(2),
|
|
||||||
k.stride(0),
|
|
||||||
k.stride(1),
|
|
||||||
k.stride(2),
|
|
||||||
attn_out.stride(0),
|
|
||||||
attn_out.stride(1),
|
|
||||||
HEAD_DIM=k_head_dim,
|
|
||||||
BLOCK_N=BLOCK,
|
|
||||||
num_warps=num_warps,
|
|
||||||
num_stages=1,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
_token_attn_1_kernel[grid](
|
|
||||||
q,
|
|
||||||
k,
|
|
||||||
sm_scale,
|
|
||||||
kv_cache_loc,
|
|
||||||
kv_cache_start_loc,
|
|
||||||
kv_cache_seqlen,
|
|
||||||
max_kv_cache_len,
|
|
||||||
attn_out,
|
|
||||||
kv_cache_loc.stride(0),
|
|
||||||
kv_cache_loc.stride(1),
|
|
||||||
q.stride(0),
|
|
||||||
q.stride(1),
|
|
||||||
q.stride(2),
|
|
||||||
k.stride(0),
|
|
||||||
k.stride(1),
|
|
||||||
k.stride(2),
|
|
||||||
attn_out.stride(0),
|
|
||||||
attn_out.stride(1),
|
|
||||||
HEAD_DIM=k_head_dim,
|
|
||||||
BLOCK_N=BLOCK,
|
|
||||||
num_warps=num_warps,
|
|
||||||
num_stages=1,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
@triton.jit
|
|
||||||
def _token_attn_softmax_fwd(
|
|
||||||
softmax_logics,
|
|
||||||
kv_cache_start_loc,
|
|
||||||
kv_cache_seqlen,
|
|
||||||
softmax_prob_out,
|
|
||||||
logics_head_dim_stride,
|
|
||||||
logics_batch_stride,
|
|
||||||
prob_head_dim_stride,
|
|
||||||
prob_batch_stride,
|
|
||||||
BLOCK_SIZE: tl.constexpr,
|
|
||||||
):
|
|
||||||
current_batch = tl.program_id(0)
|
|
||||||
current_head = tl.program_id(1)
|
|
||||||
|
|
||||||
col_offsets = tl.arange(0, BLOCK_SIZE)
|
|
||||||
current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch)
|
|
||||||
current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch)
|
|
||||||
|
|
||||||
row = tl.load(
|
|
||||||
softmax_logics
|
|
||||||
+ current_head * logics_head_dim_stride
|
|
||||||
+ (current_batch_in_all_start_index + col_offsets) * logics_batch_stride,
|
|
||||||
mask=col_offsets < current_batch_seq_len,
|
|
||||||
other=-float("inf"),
|
|
||||||
).to(tl.float32)
|
|
||||||
|
|
||||||
row_minus_max = row - tl.max(row, axis=0)
|
|
||||||
numerator = tl.exp(row_minus_max)
|
|
||||||
denominator = tl.sum(numerator, axis=0)
|
|
||||||
softmax_output = numerator / denominator
|
|
||||||
|
|
||||||
tl.store(
|
|
||||||
softmax_prob_out
|
|
||||||
+ current_head * prob_head_dim_stride
|
|
||||||
+ (current_batch_in_all_start_index + col_offsets) * prob_batch_stride,
|
|
||||||
softmax_output,
|
|
||||||
mask=col_offsets < current_batch_seq_len,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def token_attn_softmax_fwd(softmax_logics, kv_cache_start_loc, kv_cache_seqlen, softmax_prob_out, max_kv_cache_len):
|
|
||||||
BLOCK_SIZE = triton.next_power_of_2(max_kv_cache_len)
|
|
||||||
batch, head_num = kv_cache_start_loc.shape[0], softmax_logics.shape[0]
|
|
||||||
|
|
||||||
num_warps = 4
|
|
||||||
if BLOCK_SIZE >= 2048:
|
|
||||||
num_warps = 8
|
|
||||||
if BLOCK_SIZE >= 4096:
|
|
||||||
num_warps = 16
|
|
||||||
|
|
||||||
_token_attn_softmax_fwd[(batch, head_num)](
|
|
||||||
softmax_logics,
|
|
||||||
kv_cache_start_loc,
|
|
||||||
kv_cache_seqlen,
|
|
||||||
softmax_prob_out,
|
|
||||||
softmax_logics.stride(0),
|
|
||||||
softmax_logics.stride(1),
|
|
||||||
softmax_prob_out.stride(0),
|
|
||||||
softmax_prob_out.stride(1),
|
|
||||||
num_warps=num_warps,
|
|
||||||
BLOCK_SIZE=BLOCK_SIZE,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
@triton.jit
|
|
||||||
def _token_attn_2_kernel(
|
|
||||||
Prob,
|
|
||||||
V,
|
|
||||||
attn_out,
|
|
||||||
kv_cache_loc,
|
|
||||||
kv_cache_start_loc,
|
|
||||||
kv_cache_seqlen,
|
|
||||||
max_kv_cache_len,
|
|
||||||
kv_cache_loc_b_stride,
|
|
||||||
kv_cache_loc_s_stride,
|
|
||||||
prob_head_dim_stride,
|
|
||||||
prob_batch_stride,
|
|
||||||
v_batch_stride,
|
|
||||||
v_head_stride,
|
|
||||||
v_head_dim_stride,
|
|
||||||
attn_out_batch_stride,
|
|
||||||
attn_out_head_stride,
|
|
||||||
attn_out_head_dim_stride,
|
|
||||||
HEAD_DIM: tl.constexpr,
|
|
||||||
BLOCK_N: tl.constexpr,
|
|
||||||
):
|
|
||||||
current_batch = tl.program_id(0)
|
|
||||||
current_head = tl.program_id(1)
|
|
||||||
|
|
||||||
offs_n = tl.arange(0, BLOCK_N)
|
|
||||||
offs_d = tl.arange(0, HEAD_DIM)
|
|
||||||
current_batch_seq_len = tl.load(kv_cache_seqlen + current_batch)
|
|
||||||
current_batch_start_index = max_kv_cache_len - current_batch_seq_len
|
|
||||||
current_batch_in_all_start_index = tl.load(kv_cache_start_loc + current_batch)
|
|
||||||
|
|
||||||
v_loc_off = current_batch * kv_cache_loc_b_stride + (current_batch_start_index + offs_n) * kv_cache_loc_s_stride
|
|
||||||
p_offs = current_head * prob_head_dim_stride + (current_batch_in_all_start_index + offs_n) * prob_batch_stride
|
|
||||||
v_offs = current_head * v_head_stride + offs_d[None, :] * v_head_dim_stride
|
|
||||||
|
|
||||||
acc = tl.zeros([HEAD_DIM], dtype=tl.float32)
|
|
||||||
for start_n in range(0, current_batch_seq_len, BLOCK_N):
|
|
||||||
start_n = tl.multiple_of(start_n, BLOCK_N)
|
|
||||||
p_value = tl.load(
|
|
||||||
Prob + p_offs + start_n * kv_cache_loc_s_stride,
|
|
||||||
mask=(start_n + offs_n) < current_batch_seq_len,
|
|
||||||
other=0.0,
|
|
||||||
)
|
|
||||||
v_loc = tl.load(
|
|
||||||
kv_cache_loc + v_loc_off + start_n * kv_cache_loc_s_stride,
|
|
||||||
mask=(start_n + offs_n) < current_batch_seq_len,
|
|
||||||
other=0.0,
|
|
||||||
)
|
|
||||||
v_value = tl.load(
|
|
||||||
V + v_offs + v_loc[:, None] * v_batch_stride,
|
|
||||||
mask=(start_n + offs_n[:, None]) < current_batch_seq_len,
|
|
||||||
other=0.0,
|
|
||||||
)
|
|
||||||
acc += tl.sum(p_value[:, None] * v_value, 0)
|
|
||||||
|
|
||||||
acc = acc.to(tl.float16)
|
|
||||||
off_o = (
|
|
||||||
current_batch * attn_out_batch_stride
|
|
||||||
+ current_head * attn_out_head_stride
|
|
||||||
+ offs_d * attn_out_head_dim_stride
|
|
||||||
)
|
|
||||||
out_ptrs = attn_out + off_o
|
|
||||||
tl.store(out_ptrs, acc)
|
|
||||||
return
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def token_attn_fwd_2(prob, v, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seqlen, max_kv_cache_len):
|
|
||||||
if triton.__version__ >= "2.1.0":
|
|
||||||
BLOCK = 128
|
|
||||||
else:
|
|
||||||
BLOCK = 64
|
|
||||||
batch, head = kv_cache_loc.shape[0], v.shape[1]
|
|
||||||
grid = (batch, head)
|
|
||||||
num_warps = 4
|
|
||||||
dim = v.shape[-1]
|
|
||||||
|
|
||||||
_token_attn_2_kernel[grid](
|
|
||||||
prob,
|
|
||||||
v,
|
|
||||||
attn_out,
|
|
||||||
kv_cache_loc,
|
|
||||||
kv_cache_start_loc,
|
|
||||||
kv_cache_seqlen,
|
|
||||||
max_kv_cache_len,
|
|
||||||
kv_cache_loc.stride(0),
|
|
||||||
kv_cache_loc.stride(1),
|
|
||||||
prob.stride(0),
|
|
||||||
prob.stride(1),
|
|
||||||
v.stride(0),
|
|
||||||
v.stride(1),
|
|
||||||
v.stride(2),
|
|
||||||
attn_out.stride(0),
|
|
||||||
attn_out.stride(1),
|
|
||||||
attn_out.stride(2),
|
|
||||||
HEAD_DIM=dim,
|
|
||||||
BLOCK_N=BLOCK,
|
|
||||||
num_warps=num_warps,
|
|
||||||
num_stages=1,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def token_attention_fwd(
|
def token_attention_fwd(
|
||||||
q, k, v, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_len_in_batch, alibi=None
|
q, k, v, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_len_in_batch, alibi=None
|
||||||
@@ -380,7 +46,8 @@ if HAS_TRITON:
         att_m_tensor = torch.empty((head_num, total_token_num), dtype=q.dtype, device="cuda")
 
-        token_attn_fwd_1(
-            q.view(calcu_shape1),
-            k,
-            att_m_tensor,
+        if alibi is None:
+            lightllm_llama_token_att_fwd(
+                q.view(calcu_shape1),
+                k,
+                att_m_tensor,
@@ -388,25 +55,35 @@
-            kv_cache_start_loc,
-            kv_cache_seq_len,
-            max_len_in_batch,
-            alibi=alibi,
-        )
+                kv_cache_start_loc,
+                kv_cache_seq_len,
+                max_len_in_batch,
+            )
+        else:
+            lightllm_bloom_token_att_fwd(
+                q.view(calcu_shape1),
+                k,
+                att_m_tensor,
+                alibi,
+                kv_cache_loc,
+                kv_cache_start_loc,
+                kv_cache_seq_len,
+                max_len_in_batch,
+            )
 
         prob = torch.empty_like(att_m_tensor)
 
-        token_attn_softmax_fwd(att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch)
+        lightllm_llama_token_softmax_fwd(att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch)
         att_m_tensor = None
-        token_attn_fwd_2(
+        lightllm_llama_token_att_fw2(
             prob, v, attn_out.view(calcu_shape1), kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_len_in_batch
         )
 
         prob = None
 
         return
 
 
 class Llama2TokenAttentionForwards:
     @staticmethod
     @triton.jit
+    # this function is adapted from https://github.com/ModelTC/lightllm/blob/5c559dd7981ed67679a08a1e09a88fb4c1550b3a/lightllm/models/llama2/triton_kernel/token_attention_nopad_softmax.py#L8
     def _fwd_kernel(
         Logics,
         V,
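With this change `token_attention_fwd` no longer owns any Triton code of its own; it only dispatches to the lightllm kernels (`lightllm_llama_token_att_fwd` → `lightllm_llama_token_softmax_fwd` → `lightllm_llama_token_att_fw2`, or `lightllm_bloom_token_att_fwd` when ALiBi slopes are supplied). A minimal driver for the non-ALiBi path, with shapes and the contiguous cache layout borrowed from the unit tests touched later in this commit (one query token per sequence), might look like:

```python
import torch
from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd

batch, head_num, head_dim, seq_len = 2, 32, 128, 2048
dtype = torch.float16

q = torch.randn(batch, head_num, head_dim, dtype=dtype, device="cuda")
k = torch.randn(batch * seq_len, head_num, head_dim, dtype=dtype, device="cuda")
v = torch.randn(batch * seq_len, head_num, head_dim, dtype=dtype, device="cuda")
attn_out = torch.empty_like(q)

# one contiguous cache region per sequence
kv_cache_loc = torch.arange(batch * seq_len, dtype=torch.int32, device="cuda").view(batch, seq_len)
kv_cache_start_loc = torch.arange(0, batch * seq_len, seq_len, dtype=torch.int32, device="cuda")
kv_cache_seq_len = torch.full((batch,), seq_len, dtype=torch.int32, device="cuda")

token_attention_fwd(q, k, v, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, seq_len)
```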
@@ -478,6 +155,7 @@ class Llama2TokenAttentionForwards:
         tl.store(out_ptrs, acc)
         return
 
+    # this function is adapted from https://github.com/ModelTC/lightllm/blob/5c559dd7981ed67679a08a1e09a88fb4c1550b3a/lightllm/models/llama2/triton_kernel/token_attention_nopad_softmax.py#L36
     @staticmethod
     @torch.no_grad()
     def token_softmax_reducev_fwd(logics, v, o, b_loc, b_start_loc, b_seq_len, max_input_len, other_kv_index):
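`token_softmax_reducev_fwd` is the fused alternative to the separate softmax and reduce-over-V launches: a single kernel normalizes each row of `logics` and immediately accumulates the weighted V rows. A dense reference for what it computes (a sketch under the same contiguous-cache assumptions as above; `other_kv_index`, which the kernel appears to use only as a safe dummy slot for padded positions, is ignored here):

```python
import torch

def softmax_reducev_reference(logics, v, b_loc, b_start_loc, b_seq_len, max_input_len):
    # logics: (head_num, total_tokens), v: (total_tokens, head_num, head_dim)
    head_num, head_dim = v.shape[1], v.shape[2]
    out = torch.empty((b_loc.shape[0], head_num, head_dim), dtype=v.dtype, device=v.device)
    for b in range(b_loc.shape[0]):
        start, length = int(b_start_loc[b]), int(b_seq_len[b])
        loc = b_loc[b, max_input_len - length:].long()                      # cache rows of sequence b
        p = torch.softmax(logics[:, start:start + length].float(), dim=-1)  # (head_num, length)
        out[b] = torch.einsum("hs,shd->hd", p.to(v.dtype), v[loc])
    return out
```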
@@ -514,277 +192,6 @@ class Llama2TokenAttentionForwards:
         )
         return
 
-    @staticmethod
-    @triton.jit
-    def _fwd_kernel_token_softmax(
-        Logics,
-        B_Start_Loc,
-        B_Seqlen,
-        Prob_Out,
-        stride_logic_h,
-        stride_logic_bs,
-        stride_prob_h,
-        stride_prob_bs,
-        BLOCK_SIZE: tl.constexpr,
-    ):
-        cur_batch = tl.program_id(0)
-        cur_head = tl.program_id(1)
-
-        col_offsets = tl.arange(0, BLOCK_SIZE)
-        cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
-        cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
-
-        row = tl.load(
-            Logics + cur_head * stride_logic_h + (cur_batch_in_all_start_index + col_offsets) * stride_logic_bs,
-            mask=col_offsets < cur_batch_seq_len,
-            other=-float("inf"),
-        ).to(tl.float32)
-
-        row_minus_max = row - tl.max(row, axis=0)
-        numerator = tl.exp(row_minus_max)
-        denominator = tl.sum(numerator, axis=0)
-        softmax_output = numerator / denominator
-
-        tl.store(
-            Prob_Out + cur_head * stride_prob_h + (cur_batch_in_all_start_index + col_offsets) * stride_prob_bs,
-            softmax_output,
-            mask=col_offsets < cur_batch_seq_len,
-        )
-        return
-
-    @staticmethod
-    @torch.no_grad()
-    def token_softmax_fwd(Logics, B_Start_Loc, B_Seqlen, Prob_Out, max_input_len):
-        BLOCK_SIZE = triton.next_power_of_2(max_input_len)
-        batch, head_num = B_Start_Loc.shape[0], Logics.shape[0]
-
-        num_warps = 4
-        if BLOCK_SIZE >= 2048:
-            num_warps = 8
-        if BLOCK_SIZE >= 4096:
-            num_warps = 16
-
-        Llama2TokenAttentionForwards._fwd_kernel_token_softmax[(batch, head_num)](
-            Logics,
-            B_Start_Loc,
-            B_Seqlen,
-            Prob_Out,
-            Logics.stride(0),
-            Logics.stride(1),
-            Prob_Out.stride(0),
-            Prob_Out.stride(1),
-            num_warps=num_warps,
-            BLOCK_SIZE=BLOCK_SIZE,
-        )
-        return
-
-    @staticmethod
-    @triton.jit
-    def _fwd_kernel_token_att1(
-        Q,
-        K,
-        sm_scale,
-        B_Loc,
-        B_Start_Loc,
-        B_Seqlen,
-        max_input_len,
-        Att_Out,
-        stride_b_loc_b,
-        stride_b_loc_s,
-        stride_qbs,
-        stride_qh,
-        stride_qd,
-        stride_kbs,
-        stride_kh,
-        stride_kd,
-        att_stride_h,
-        att_stride_bs,
-        kv_group_num,
-        BLOCK_DMODEL: tl.constexpr,
-        BLOCK_N: tl.constexpr,
-    ):
-        cur_batch = tl.program_id(0)
-        cur_head = tl.program_id(1)
-        start_n = tl.program_id(2)
-
-        cur_kv_head = cur_head // kv_group_num
-
-        offs_d = tl.arange(0, BLOCK_DMODEL)
-        cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
-        cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
-
-        cur_batch_start_index = max_input_len - cur_batch_seq_len
-        cur_batch_end_index = max_input_len
-
-        off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d * stride_qd
-
-        offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
-
-        block_stard_index = start_n * BLOCK_N
-        block_mask = tl.where(block_stard_index < cur_batch_seq_len, 1, 0)
-
-        for start_mark in range(0, block_mask, 1):
-            q = tl.load(Q + off_q + start_mark)
-            offs_n_new = cur_batch_start_index + offs_n
-            k_loc = tl.load(
-                B_Loc + stride_b_loc_b * cur_batch + stride_b_loc_s * offs_n_new,
-                mask=offs_n_new < cur_batch_end_index,
-                other=0,
-            )
-            off_k = k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :] * stride_kd
-            k = tl.load(K + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0)
-            att_value = tl.sum(q[None, :] * k, 1)
-            att_value *= sm_scale
-            off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n) * att_stride_bs
-            tl.store(Att_Out + off_o, att_value, mask=offs_n_new < cur_batch_end_index)
-        return
-
-    @staticmethod
-    @torch.no_grad()
-    def token_att_fwd(q, k, att_out, B_Loc, B_Start_Loc, B_Seqlen, max_input_len):
-        BLOCK = 32
-        # shape constraints
-        Lq, Lk = q.shape[-1], k.shape[-1]
-        assert Lq == Lk
-        assert Lk in {16, 32, 64, 128}
-        sm_scale = 1.0 / (Lk**0.5)
-
-        batch, head_num = B_Loc.shape[0], q.shape[1]
-
-        grid = (batch, head_num, triton.cdiv(max_input_len, BLOCK))
-        kv_group_num = q.shape[1] // k.shape[1]
-
-        num_warps = 4 if Lk <= 64 else 8
-        num_warps = 2
-
-        Llama2TokenAttentionForwards._fwd_kernel_token_att1[grid](
-            q,
-            k,
-            sm_scale,
-            B_Loc,
-            B_Start_Loc,
-            B_Seqlen,
-            max_input_len,
-            att_out,
-            B_Loc.stride(0),
-            B_Loc.stride(1),
-            q.stride(0),
-            q.stride(1),
-            q.stride(2),
-            k.stride(0),
-            k.stride(1),
-            k.stride(2),
-            att_out.stride(0),
-            att_out.stride(1),
-            kv_group_num=kv_group_num,
-            BLOCK_DMODEL=Lk,
-            BLOCK_N=BLOCK,
-            num_warps=num_warps,
-            num_stages=1,
-        )
-        return
-
-    @staticmethod
-    @triton.jit
-    def _fwd_kernel_token_att2(
-        Prob,
-        V,
-        Out,
-        B_Loc,
-        B_Start_Loc,
-        B_Seqlen,
-        max_input_len,  # B_Start_Loc cumsum of input lens if continuous
-        stride_b_loc_b,
-        stride_b_loc_s,
-        stride_ph,
-        stride_pbs,
-        stride_vbs,
-        stride_vh,
-        stride_vd,
-        stride_obs,
-        stride_oh,
-        stride_od,
-        kv_group_num,
-        BLOCK_DMODEL: tl.constexpr,
-        BLOCK_N: tl.constexpr,
-    ):
-        cur_batch = tl.program_id(0)
-        cur_head = tl.program_id(1)
-
-        cur_kv_head = cur_head // kv_group_num
-
-        offs_n = tl.arange(0, BLOCK_N)
-        offs_d = tl.arange(0, BLOCK_DMODEL)
-        cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
-        cur_batch_start_index = max_input_len - cur_batch_seq_len
-        cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
-
-        v_loc_off = cur_batch * stride_b_loc_b + (cur_batch_start_index + offs_n) * stride_b_loc_s
-        p_offs = cur_head * stride_ph + (cur_batch_in_all_start_index + offs_n) * stride_pbs
-        v_offs = cur_kv_head * stride_vh + offs_d[None, :] * stride_vd
-
-        acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)
-        for start_n in range(0, cur_batch_seq_len, BLOCK_N):
-            start_n = tl.multiple_of(start_n, BLOCK_N)
-            p_value = tl.load(
-                Prob + p_offs + start_n * stride_b_loc_s, mask=(start_n + offs_n) < cur_batch_seq_len, other=0.0
-            )
-            v_loc = tl.load(
-                B_Loc + v_loc_off + start_n * stride_b_loc_s, mask=(start_n + offs_n) < cur_batch_seq_len, other=0.0
-            )
-            v_value = tl.load(
-                V + v_offs + v_loc[:, None] * stride_vbs,
-                mask=(start_n + offs_n[:, None]) < cur_batch_seq_len,
-                other=0.0,
-            )
-            acc += tl.sum(p_value[:, None] * v_value, 0)
-
-        acc = acc.to(tl.float16)
-        off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d * stride_od
-        out_ptrs = Out + off_o
-        tl.store(out_ptrs, acc)
-        return
-
-    @staticmethod
-    @torch.no_grad()
-    def token_att_fwd2(prob, v, out, B_Loc, B_Start_Loc, B_Seqlen, max_input_len):
-        if triton.__version__ >= "2.1.0":
-            BLOCK = 128
-        else:
-            BLOCK = 64
-        batch, head = B_Loc.shape[0], prob.shape[0]
-        grid = (batch, head)
-        num_warps = 4
-        dim = v.shape[-1]
-
-        kv_group_num = prob.shape[0] // v.shape[1]
-
-        Llama2TokenAttentionForwards._fwd_kernel_token_att2[grid](
-            prob,
-            v,
-            out,
-            B_Loc,
-            B_Start_Loc,
-            B_Seqlen,
-            max_input_len,
-            B_Loc.stride(0),
-            B_Loc.stride(1),
-            prob.stride(0),
-            prob.stride(1),
-            v.stride(0),
-            v.stride(1),
-            v.stride(2),
-            out.stride(0),
-            out.stride(1),
-            out.stride(2),
-            kv_group_num=kv_group_num,
-            BLOCK_DMODEL=dim,
-            BLOCK_N=BLOCK,
-            num_warps=num_warps,
-            num_stages=1,
-        )
-        return
-
     # this is the interface of llama2 attn forward
     @staticmethod
     @torch.no_grad()
@@ -796,7 +203,7 @@ class Llama2TokenAttentionForwards:
         calcu_shape1 = (batch_size, head_num, head_dim)
         att_m_tensor = torch.empty((head_num, total_token_num), dtype=q.dtype, device="cuda")
 
-        Llama2TokenAttentionForwards.token_att_fwd(
+        lightllm_llama2_token_att_fwd(
             q,
             k,
             att_m_tensor,
@@ -808,12 +215,12 @@ class Llama2TokenAttentionForwards:
 
         if triton.__version__ == "2.0.0":
             prob = torch.empty_like(att_m_tensor)
-            Llama2TokenAttentionForwards.token_softmax_fwd(
+            lightllm_llama2_token_softmax_fwd(
                 att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch
             )
             att_m_tensor = None
 
-            Llama2TokenAttentionForwards.token_att_fwd2(
+            lightllm_llama2_token_att_fwd2(
                 prob,
                 v,
                 attn_out.view(calcu_shape1),
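The public entry point for this path is `Llama2TokenAttentionForwards.token_attn`, which now routes into the lightllm `llama2` kernels referenced above. A minimal driver, reusing the shapes and contiguous cache layout of the unit test removed later in this commit, would be roughly:

```python
import torch
from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards

Z, head_num, seq_len, head_dim = 2, 32, 2048, 128
dtype = torch.float16

q = torch.randn(Z, head_num, head_dim, dtype=dtype, device="cuda")
k = torch.randn(Z * seq_len, head_num, head_dim, dtype=dtype, device="cuda")
v = torch.randn(Z * seq_len, head_num, head_dim, dtype=dtype, device="cuda")
o = torch.empty(Z, head_num, head_dim, dtype=dtype, device="cuda")

kv_cache_loc = torch.arange(Z * seq_len, dtype=torch.int32, device="cuda").view(Z, seq_len)
kv_cache_start_loc = torch.arange(0, Z * seq_len, seq_len, dtype=torch.int32, device="cuda")
kv_cache_seq_len = torch.full((Z,), seq_len, dtype=torch.int32, device="cuda")

Llama2TokenAttentionForwards.token_attn(
    q, k, v, o, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, seq_len, 2048  # max_kv_cache_len, other_kv_index
)
```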
@@ -3,7 +3,6 @@ import os
 import time
 
 import torch
-from torch.profiler import ProfilerActivity, profile, record_function
 from transformers import LlamaForCausalLM, LlamaTokenizer
 
 import colossalai
@@ -16,6 +15,7 @@ os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
 
 
 def print_perf_stats(latency_set, config, bs, warmup=3):
+    torch.cuda.empty_cache()
     # trim warmup queries
     latency_set = list(latency_set)
     latency_set = latency_set[warmup:]
@@ -38,24 +38,29 @@ def run_llama_test(args):
     max_batch_size = args.batch_size
     max_input_len = args.input_len
     max_output_len = args.output_len
+    args.test_mode
+
+    print("max_batch_size : " + str(max_batch_size))
 
     tokenizer = LlamaTokenizer.from_pretrained(llama_model_path)
     tokenizer.pad_token_id = tokenizer.unk_token_id
     model = LlamaForCausalLM.from_pretrained(llama_model_path, pad_token_id=tokenizer.eos_token_id)
     model = model.half()
-    model_config = model.config
+    model.config
 
     shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True)
     infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
 
-    generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False)
+    generate_kwargs = dict(max_new_tokens=1, do_sample=False)
     input_tokens = {
         "input_ids": torch.randint(1, 1000, (max_batch_size, max_input_len), device="cuda"),
         "attention_mask": torch.ones((max_batch_size, max_input_len), device="cuda"),
     }
 
     iters = 10
-    times = []
+    prefill_times = []
 
+    warmup = 3
 
     for i in range(iters):
         torch.cuda.synchronize()
@@ -65,17 +70,39 @@ def run_llama_test(args):
         end = time.time()
         out_len = outputs.shape[1]
         print("generation time {} s".format(str(end - start)))
-        times.append((end - start) / (out_len - max_input_len))
+        print(out_len - max_input_len)
+        prefill_times.append((end - start) / (out_len - max_input_len))
 
-    print("outputs, ", len(outputs))
-    print_perf_stats(times, model_config, max_batch_size)
+    prefill_times = prefill_times[warmup:]
+    prefill_time_avg = sum(prefill_times) / len(prefill_times)
+    generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False)
 
-    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
-        with record_function("model_inference"):
+    times = []
+    decoder_times = []
+    for i in range(iters):
         torch.cuda.synchronize()
+        start = time.time()
         outputs = infer_engine.generate(input_tokens, **generate_kwargs)
         torch.cuda.synchronize()
-    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
+        end = time.time()
+        out_len = outputs.shape[1]
+        print("generation time {} s".format(str(end - start)))
+        print(out_len - max_input_len)
+        times.append((end - start) / (out_len - max_input_len))
+        if args.test_mode == "decoder_test":
+            decoder_times.append((end - start - prefill_time_avg) / (out_len - max_input_len - 1))
+
+    times = times[warmup:]
+    latency = sum(times) / len(times)
+    print("total process latency is : " + str(latency) + " s")
+    print("total throughput is : " + str(1 / latency * max_batch_size))
+
+    if args.test_mode == "decoder_test":
+        decoder_times = decoder_times[warmup:]
+        latency = sum(decoder_times) / len(decoder_times)
+
+        print("decoder process latency is : " + str(latency) + " s")
+        print("decoder throughput is : " + str(1 / latency * max_batch_size))
 
 
 def check_llama(rank, world_size, port, args):
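The reworked benchmark reports prefill and decode latency separately: the first loop runs with `max_new_tokens=1` to average the prefill cost, and in `decoder_test` mode that average is subtracted from each full-generation run before dividing by the remaining tokens. The per-iteration arithmetic is simply the following (a standalone restatement of the formulas above, with hypothetical numbers):

```python
def split_latency(total_s, prefill_avg_s, new_tokens):
    """Per-token latencies implied by the benchmark's bookkeeping."""
    e2e_per_token = total_s / new_tokens                               # times.append(...)
    decode_per_token = (total_s - prefill_avg_s) / (new_tokens - 1)    # decoder_times.append(...)
    return e2e_per_token, decode_per_token

# e.g. a 2.0 s run that produced 128 new tokens, with a 0.4 s average prefill:
print(split_latency(2.0, 0.4, 128))
```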
@@ -95,8 +122,11 @@ if __name__ == "__main__":
     parser.add_argument("-p", "--path", type=str, help="Model path", required=True)
     parser.add_argument("-tp", "--tp_size", type=int, default=1, help="Tensor parallel size")
    parser.add_argument("-b", "--batch_size", type=int, default=16, help="Maximum batch size")
-    parser.add_argument("--input_len", type=int, default=1024, help="Maximum input length")
+    parser.add_argument("--input_len", type=int, default=256, help="Maximum input length")
     parser.add_argument("--output_len", type=int, default=128, help="Maximum output length")
+    parser.add_argument(
+        "--test_mode", type=str, help="Test mode", default="e2e_test", choices=["e2e_test", "decoder_test"]
+    )
 
     args = parser.parse_args()
 
@@ -11,3 +11,6 @@ ninja
 torch>=1.12
 safetensors
 einops
+sentencepiece
+google
+protobuf
@@ -1,63 +0,0 @@
-import pytest
-import torch
-from packaging import version
-
-try:
-    pass
-
-    from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards
-
-    HAS_TRITON = True
-except ImportError:
-    HAS_TRITON = False
-    print("please install triton from https://github.com/openai/triton")
-
-TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
-
-
-def torch_att(xq, xk, xv, bs, seqlen, num_head, head_dim):
-    xq = xq.view(bs, 1, num_head, head_dim)
-    xk = xk.view(bs, seqlen, num_head, head_dim)
-    xv = xv.view(bs, seqlen, num_head, head_dim)
-
-    logics = torch.sum(xq * xk, dim=3, keepdim=False) * 1 / (head_dim**0.5)
-    prob = torch.softmax(logics, dim=1)
-    prob = prob.view(bs, seqlen, num_head, 1)
-
-    return torch.sum(prob * xv, dim=1, keepdim=False)
-
-
-@pytest.mark.skipif(
-    not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
-)
-def test():
-    Z, head_num, seq_len, head_dim = 2, 32, 2048, 128
-    dtype = torch.float16
-
-    # attn out: 2,4096
-    q = torch.empty((Z, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2)
-    k = torch.empty((Z * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2)
-    v = torch.empty((Z * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2)
-    o = torch.empty((Z, head_num, head_dim), dtype=dtype, device="cuda")
-    max_kv_cache_len = seq_len
-    kv_cache_start_loc = torch.zeros((Z,), dtype=torch.int32, device="cuda")
-    kv_cache_loc = torch.zeros((Z, seq_len), dtype=torch.int32, device="cuda")
-    kv_cache_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda")
-    other_kv_index = 2048
-
-    kv_cache_seq_len[:] = seq_len
-    kv_cache_start_loc[0] = 0
-    kv_cache_start_loc[1] = seq_len
-
-    for i in range(Z):
-        kv_cache_loc[i, :] = torch.arange(i * seq_len, (i + 1) * seq_len, dtype=torch.int32, device="cuda")
-
-    Llama2TokenAttentionForwards.token_attn(
-        q, k, v, o, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_kv_cache_len, other_kv_index
-    )
-    torch_out = torch_att(q, k, v, Z, seq_len, head_num, head_dim)
-    assert torch.allclose(torch_out, o, atol=1e-3, rtol=0)
-
-
-if __name__ == "__main__":
-    test()
@@ -1,55 +0,0 @@
-# Adapted from ModelTC https://github.com/ModelTC/lightllm
-
-
-import pytest
-import torch
-from packaging import version
-
-try:
-    pass
-
-    from colossalai.kernel.triton.rotary_embedding_kernel import rotary_embedding_fwd
-
-    HAS_TRITON = True
-except ImportError:
-    HAS_TRITON = False
-    print("please install triton from https://github.com/openai/triton")
-
-TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
-
-
-def torch_rotary_emb(x, cos, sin):
-    seq_len, h, dim = x.shape
-    x0 = x[:, :, 0 : dim // 2]
-    x1 = x[:, :, dim // 2 : dim]
-    cos = cos.view((seq_len, 1, dim // 2))
-    sin = sin.view((seq_len, 1, dim // 2))
-    o0 = x0 * cos - x1 * sin
-    o1 = x0 * sin + x1 * cos
-    return torch.cat((o0, o1), dim=-1)
-
-
-@pytest.mark.skipif(
-    not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
-)
-def test_rotary_emb():
-    SEQ_LEN = 1
-    HEAD_NUM = 32
-    HEAD_DIM = 128
-    dtype = torch.half
-    # create data
-    x_shape = (SEQ_LEN, HEAD_NUM, HEAD_DIM)
-    x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
-    cos_shape = (SEQ_LEN, HEAD_DIM // 2)
-    cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
-    sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
-    # forward pass
-    y_torch = torch_rotary_emb(x, cos, sin)
-    rotary_embedding_fwd(x, cos, sin)
-    y_triton = x
-    # compare
-    assert torch.allclose(y_torch, y_triton, atol=1e-2, rtol=0)
-
-
-if __name__ == "__main__":
-    test_rotary_emb()
@@ -1,74 +0,0 @@
-import math
-
-import pytest
-import torch
-from packaging import version
-
-try:
-    pass
-
-    from colossalai.kernel.triton.token_attention_kernel import token_attn_fwd_1
-
-    HAS_TRITON = True
-except ImportError:
-    HAS_TRITON = False
-    print("please install triton from https://github.com/openai/triton")
-
-TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
-
-
-def torch_attn(xq, xk, bs, seqlen, num_head, head_dim):
-    xq = xq.view(bs, 1, num_head, head_dim)
-    xk = xk.view(bs, seqlen, num_head, head_dim)
-    keys = xk
-    xq = xq.transpose(1, 2)
-    keys = keys.transpose(1, 2)
-    scores = (
-        (torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(head_dim)).squeeze().transpose(0, 1).reshape(num_head, -1)
-    )
-    return scores
-
-
-def torch_attn_1(xq, xk, seqlen, num_head, head_dim):
-    xq = xq.view(1, num_head, head_dim)
-    xk = xk.view(seqlen, num_head, head_dim)
-    logics = torch.sum(xq * xk, dim=-1, keepdim=False)
-
-    logics = logics.transpose(0, 1) / math.sqrt(head_dim)
-    return logics
-
-
-@pytest.mark.skipif(
-    not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
-)
-def test_attn_1():
-    pass
-
-    batch_size, seq_len, head_num, head_dim = 17, 1025, 12, 128
-
-    dtype = torch.float16
-
-    q = torch.empty((batch_size, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2)
-    k = torch.empty((batch_size * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2)
-    attn_out = torch.empty((head_num, batch_size * seq_len), dtype=dtype, device="cuda")
-
-    b_loc = torch.zeros((batch_size, seq_len), dtype=torch.int32, device="cuda")
-    kv_cache_start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
-    kv_cache_seq_len = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
-
-    for i in range(batch_size):
-        kv_cache_start_loc[i] = i * seq_len
-        kv_cache_seq_len[i] = seq_len
-        b_loc[i] = i * seq_len + torch.arange(0, seq_len, dtype=torch.int32, device="cuda")
-
-    token_attn_fwd_1(q, k, attn_out, b_loc, kv_cache_start_loc, kv_cache_seq_len, seq_len)
-
-    torch_out = torch_attn(q, k, batch_size, seq_len, head_num, head_dim).squeeze()
-    o = attn_out.squeeze()
-    print("max ", torch.max(torch.abs(torch_out - o)))
-    print("mean ", torch.mean(torch.abs(torch_out - o)))
-    assert torch.allclose(torch_out, o, atol=1e-2, rtol=0)
-
-
-if __name__ == "__main__":
-    test_attn_1()
@@ -1,63 +0,0 @@
-import pytest
-import torch
-from packaging import version
-
-try:
-    pass
-
-    from colossalai.kernel.triton.token_attention_kernel import token_attn_fwd_2
-
-    HAS_TRITON = True
-except ImportError:
-    HAS_TRITON = False
-    print("please install triton from https://github.com/openai/triton")
-
-TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
-
-
-def torch_attn(V, P, bs, seqlen, num_head, head_dim):
-    V = V.view(bs, seqlen, num_head, head_dim).transpose(1, 2)
-    P = P.reshape(num_head, bs, 1, seqlen).transpose(0, 1)
-    attn_out = torch.matmul(P, V)
-
-    return attn_out
-
-
-@pytest.mark.skipif(
-    not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
-)
-def test_token_attn_2():
-    pass
-
-    batch_size, seq_len, head_num, head_dim = 17, 1025, 12, 128
-    dtype = torch.float16
-
-    V = torch.empty((batch_size * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=10)
-    Prob = (
-        torch.empty((head_num, batch_size * seq_len), dtype=dtype, device="cuda")
-        .normal_(mean=0.4, std=0.2)
-        .reshape(head_num, batch_size, seq_len)
-        .softmax(-1)
-        .reshape(head_num, batch_size * seq_len)
-    )
-    attn_out = torch.empty((batch_size, head_num, head_dim), dtype=dtype, device="cuda")
-
-    kv_cache_start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
-    kv_cache_seq_len = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
-    kv_cache_loc = torch.zeros((batch_size, seq_len), dtype=torch.int32, device="cuda")
-    for i in range(batch_size):
-        kv_cache_start_loc[i] = i * seq_len
-        kv_cache_seq_len[i] = seq_len
-        kv_cache_loc[i] = i * seq_len + torch.arange(0, seq_len, dtype=torch.int32, device="cuda")
-
-    token_attn_fwd_2(Prob, V, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, seq_len)
-
-    torch_out = torch_attn(V, Prob, batch_size, seq_len, head_num, head_dim).squeeze()
-    o = attn_out
-    print("max ", torch.max(torch.abs(torch_out - o)))
-    print("mean ", torch.mean(torch.abs(torch_out - o)))
-    assert torch.allclose(torch_out, o, atol=1e-2, rtol=0)
-
-
-if __name__ == "__main__":
-    test_token_attn_2()
@@ -3,16 +3,13 @@ import torch
 from packaging import version
 
 try:
-    pass
-
     from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd
 
     HAS_TRITON = True
 except ImportError:
     HAS_TRITON = False
     print("please install triton from https://github.com/openai/triton")
 
-TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
+TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) >= version.parse("11.6")
 
 
 def torch_att(xq, xk, xv, bs, seqlen, num_head, head_dim):