[Refactor] Integrated some lightllm kernels into token-attention (#4946)
* add some requirements for inference
* add lightllm deps
* add lightllm chatglm2 kernels
* replace rotary embedding with lightllm kernel
* replace fwd kernel att1
* fix an arg
* fix token attention
* delete rms files
* add doc and comments
* clean codes
* modify comments
* fix readme
* fix bugs

---------

Co-authored-by: cuiqing.li <lixx336@gmail.com>
Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>
@@ -4,7 +4,7 @@

## Introduction

-`Colossal Inference` is a module that contains colossal-ai designed inference framework, featuring high performance, steady and easy usability. `Colossal Inference` incorporated the advantages of the latest open-source inference systems, including TGI, vLLM, FasterTransformer, LightLLM and flash attention. while combining the design of Colossal AI, especially Shardformer, to reduce the learning curve for users.
+`Colossal Inference` is a module that contains the inference framework designed for colossal-ai, featuring high performance, stability and ease of use. It incorporates the advantages of the latest open-source inference systems, including LightLLM, TGI, vLLM, FasterTransformer and flash attention, while combining the design of Colossal AI, especially Shardformer, to reduce the learning curve for users.

## Design
@@ -62,6 +62,12 @@ triton==2.0.0.dev20221202
vllm
# to install flash-attention, please use commit hash: 67ae6fd74b4bc99c36b2ce524cf139c35663793c
flash-attention

+# install lightllm since we depend on lightllm triton kernels
+git clone https://github.com/ModelTC/lightllm
+cd lightllm
+git checkout 28c1267cfca536b7b4f28e921e03de735b003039
+pip3 install -e .
```
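After installing lightllm from source, a quick way to confirm the triton kernels this commit relies on are importable (a minimal check; the module paths are the same ones imported in the diffs below):

```python
# Minimal sanity check that the pinned lightllm checkout exposes the triton
# kernels this commit depends on.
try:
    from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd
    from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import context_attention_fwd
    print("lightllm triton kernels are available")
except ImportError as err:
    print(f"lightllm triton kernels are missing: {err}")
```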
### Docker
@@ -73,6 +79,17 @@ You can use docker run to use docker container to set-up environment
docker pull hpcaitech/colossalai-inference:v2
docker run -it --gpus all --name ANY_NAME -v $PWD:/workspace -w /workspace hpcaitech/colossalai-inference:v2 /bin/bash

+# enter into docker container
+cd /path/to/ColossalAI
+pip install -e .
+
+# install lightllm
+git clone https://github.com/ModelTC/lightllm
+cd lightllm
+git checkout 28c1267cfca536b7b4f28e921e03de735b003039
+pip3 install -e .

```
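Once inside the container, a quick check that the GPUs passed through with `--gpus all` are actually visible to PyTorch (a minimal sketch):

```python
# Quick check inside the container that CUDA devices are visible to PyTorch.
import torch

print("CUDA available:", torch.cuda.is_available())
print("Device count:", torch.cuda.device_count())
```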
### Dive into fast-inference!
@@ -5,7 +5,7 @@ import torch

from .kvcache_manager import MemoryManager


+# adapted from: lightllm/server/router/model_infer/infer_batch.py
@dataclass
class BatchInferState:
    r"""
@@ -41,6 +41,7 @@ class BatchInferState:
    def set_cache_manager(self, manager: MemoryManager):
        self.cache_manager = manager

+    # adapted from: https://github.com/ModelTC/lightllm/blob/28c1267cfca536b7b4f28e921e03de735b003039/lightllm/common/infer_utils.py#L1
    @staticmethod
    def init_block_loc(
        b_loc: torch.Tensor, seq_len: torch.Tensor, max_len_in_batch: int, alloc_mem_index: torch.Tensor
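The body of `init_block_loc` is cut off by the hunk; conceptually it fills the per-sequence block-location table with freshly allocated KV-cache slot indices, right-aligned to `max_len_in_batch` as in lightllm. A rough, hedged sketch of that behavior (not the committed implementation):

```python
import torch

def init_block_loc_sketch(b_loc: torch.Tensor, seq_len: torch.Tensor,
                          max_len_in_batch: int, alloc_mem_index: torch.Tensor) -> None:
    # For every sequence in the batch, place its allocated cache indices at the
    # tail of its row so each sequence's last token lines up at the same column.
    start = 0
    for i in range(seq_len.shape[0]):
        cur_len = int(seq_len[i])
        b_loc[i, max_len_in_batch - cur_len : max_len_in_batch] = alloc_mem_index[start : start + cur_len]
        start += cur_len
```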
@@ -1,7 +1,9 @@
-# Adapted from lightllm/common/mem_manager.py
-# of the ModelTC/lightllm GitHub repository
-# https://github.com/ModelTC/lightllm/blob/050af3ce65edca617e2f30ec2479397d5bb248c9/lightllm/common/mem_manager.py

+"""
+Referred/modified from lightllm/common/mem_manager.py
+of the ModelTC/lightllm GitHub repository
+https://github.com/ModelTC/lightllm/blob/050af3ce65edca617e2f30ec2479397d5bb248c9/lightllm/common/mem_manager.py
+We slightly changed it to make it suitable for our colossal-ai shardformer TP-engine design.
+"""
import torch
from transformers.utils import logging
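The `MemoryManager` referenced above hands out token-sized slots in pre-allocated key/value buffers. A loose, assumption-laden sketch of such a slot allocator (not the file's actual implementation):

```python
import torch

class KVCacheSlotAllocatorSketch:
    """Hedged sketch: tracks which of `size` KV-cache token slots are free."""

    def __init__(self, size: int):
        self.mem_state = torch.ones(size, dtype=torch.bool)  # True = free slot

    def alloc(self, need_size: int) -> torch.Tensor:
        free_idx = torch.nonzero(self.mem_state).flatten()
        assert free_idx.numel() >= need_size, "not enough free KV-cache slots"
        chosen = free_idx[:need_size]
        self.mem_state[chosen] = False
        return chosen

    def free(self, indices: torch.Tensor) -> None:
        self.mem_state[indices] = True
```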
@@ -6,8 +6,6 @@ from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast

from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
-from colossalai.kernel.triton.context_attention import llama2_context_attn_fwd
-from colossalai.kernel.triton.rotary_embedding_kernel import Llama2Forwards
from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (
    ChatGLMForConditionalGeneration,
@@ -20,6 +18,14 @@ from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (

from ._utils import copy_kv_to_mem_cache

+try:
+    from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import context_attention_fwd as lightllm_llama2_context_attention_fwd
+    from lightllm.models.chatglm2.triton_kernel.rotary_emb import rotary_emb_fwd as chatglm2_rotary_emb_fwd
+    HAS_LIGHTLLM_KERNEL = True
+except ImportError:
+    print("please install lightllm from source to run inference: https://github.com/ModelTC/lightllm")
+    HAS_LIGHTLLM_KERNEL = False


# This func is the same as the Llama model's init_to_get_rotary; we should move them into _utils.py
def _init_to_get_rotary(self, base=10000):
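The guarded import above lets the module load even when lightllm is absent; callers can then gate the kernel path on `HAS_LIGHTLLM_KERNEL`. An illustrative pattern (not committed code):

```python
# Illustrative only: fail with a clear error when the fused kernel is unavailable.
def apply_chatglm2_rotary(query_layer, cos, sin):
    if not HAS_LIGHTLLM_KERNEL:
        raise RuntimeError("lightllm triton kernels are required for this inference path")
    # in-place fused rotary embedding from lightllm
    chatglm2_rotary_emb_fwd(query_layer, cos, sin)
```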
@@ -433,17 +439,17 @@ class ChatGLM2InferenceForwards:

        cos, sin = infer_state.position_cos, infer_state.position_sin

-        Llama2Forwards.rotary_emb_fwd(
+        chatglm2_rotary_emb_fwd(
            query_layer.view(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head), cos, sin
        )
        if self.multi_query_attention:
-            Llama2Forwards.rotary_emb_fwd(
+            chatglm2_rotary_emb_fwd(
                key_layer.view(-1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head),
                cos,
                sin,
            )
        else:
-            Llama2Forwards.rotary_emb_fwd(
+            chatglm2_rotary_emb_fwd(
                key_layer.view(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head),
                cos,
                sin,
@@ -474,7 +480,7 @@ class ChatGLM2InferenceForwards:
        attn_output = torch.empty_like(query_layer.view(-1, self.projection_size))

        # NOTE: no bug in context attn fwd (del it )
-        llama2_context_attn_fwd(
+        lightllm_llama2_context_attention_fwd(
            query_layer,
            key_layer,
            value_layer,
@@ -5,12 +5,7 @@ from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm

from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
-from colossalai.kernel.triton import (
-    llama2_context_attn_fwd,
-    llama_context_attn_fwd,
-    rotary_embedding_fwd,
-    token_attention_fwd,
-)
+from colossalai.kernel.triton import llama_context_attn_fwd, token_attention_fwd
from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards

from ._utils import copy_kv_to_mem_cache
@@ -29,6 +24,17 @@ except:
    )
    HAS_VLLM_KERNERL = False

+try:
+    from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import (
+        context_attention_fwd as lightllm_llama2_context_attention_fwd,
+    )
+    from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd as llama_rotary_embedding_fwd
+
+    HAS_LIGHTLLM_KERNEL = True
+except ImportError:
+    print("please install lightllm from source to run inference: https://github.com/ModelTC/lightllm")
+    HAS_LIGHTLLM_KERNEL = False


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
@@ -280,8 +286,8 @@ class LlamaInferenceForwards:
        cos, sin = infer_state.position_cos, infer_state.position_sin
        # print("shape ", cos.shape, query_states.view(-1, self.num_heads, self.head_dim).shape, )

-        rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin)
-        rotary_embedding_fwd(key_states.view(-1, self.num_key_value_heads, self.head_dim), cos, sin)
+        llama_rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin)
+        llama_rotary_embedding_fwd(key_states.view(-1, self.num_key_value_heads, self.head_dim), cos, sin)

        query_states = query_states.reshape(-1, self.num_heads, self.head_dim)
        key_states = key_states.reshape(-1, self.num_key_value_heads, self.head_dim)
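For intuition, the fused `llama_rotary_embedding_fwd` kernel applies the standard rotary transform in place over `(tokens, heads, head_dim)` views; an unfused reference of the same math (the `cos`/`sin` shapes here are an assumption, half of `head_dim` per token):

```python
import torch

def rotary_reference(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: (num_tokens, num_heads, head_dim); cos/sin: (num_tokens, head_dim // 2), broadcast over heads.
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    cos, sin = cos[:, None, :], sin[:, None, :]
    return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)
```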
@@ -312,7 +318,7 @@ class LlamaInferenceForwards:
                infer_state.cache_manager.past_key_values_length,
            )
        else:
-            llama2_context_attn_fwd(
+            lightllm_llama2_context_attention_fwd(
                query_states,
                key_states,
                value_states,
@@ -371,6 +377,7 @@ class LlamaInferenceForwards:
                infer_state.cache_manager.past_key_values_length,
                infer_state.other_kv_index,
            )

        attn_output = attn_output.view(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)
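The token-attention kernel invoked above performs a single decode step of attention over the cached keys/values of each sequence. A naive per-sequence reference of that computation, with the cache already gathered (a sketch, not the triton kernel's actual interface):

```python
import torch

def token_attention_reference(q: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor) -> torch.Tensor:
    # q: (num_heads, head_dim) for the current token of one sequence
    # k_cache, v_cache: (past_len, num_heads, head_dim) gathered from the KV cache
    scale = q.shape[-1] ** -0.5
    scores = torch.einsum("hd,shd->hs", q.float(), k_cache.float()) * scale
    probs = torch.softmax(scores, dim=-1)
    return torch.einsum("hs,shd->hd", probs, v_cache.float())
```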
@@ -12,8 +12,7 @@ from ..modeling._utils import init_to_get_rotary
from ..modeling.llama import LlamaInferenceForwards, get_llama_vllm_rmsnorm_forward

try:
-    from colossalai.kernel.triton import rmsnorm_forward
-
+    from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward as lightllm_rmsnorm_forward
    HAS_TRITON_RMSNORM = True
except:
    print("you should install triton from https://github.com/openai/triton")
@@ -22,9 +21,8 @@ except:

def get_triton_rmsnorm_forward():
    if HAS_TRITON_RMSNORM:

        def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
-            return rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon)
+            return lightllm_rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon)

        return _triton_rmsnorm_forward
    else:
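`lightllm_rmsnorm_forward` replaces the in-house triton RMSNorm here; a plain-PyTorch reference of the same math, handy for numerically checking either kernel (a sketch, not part of the commit):

```python
import torch

def rmsnorm_reference(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Scale each vector by the reciprocal root-mean-square of its last dimension,
    # then apply the learned per-channel weight.
    variance = hidden_states.float().pow(2).mean(-1, keepdim=True)
    normed = hidden_states.float() * torch.rsqrt(variance + eps)
    return weight * normed.to(hidden_states.dtype)
```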