[Refactor] Integrated some lightllm kernels into token-attention (#4946)

* add some req for inference

* clean codes

* add codes

* add some lightllm deps

* clean codes

* hello

* delete rms files

* add some comments

* add comments

* add doc

* add lightllm deps

* add lightllm chatglm2 kernels

* add lightllm chatglm2 kernels

* replace rotary embedding with lightllm kernel

* add some comments

* add some comments

* add some comments

* add

* replace fwd kernel att1

* fix an arg

* add

* add

* fix token attention

* add some comments

* clean codes

* modify comments

* fix readme

* fix bug

* fix bug

---------

Co-authored-by: cuiqing.li <lixx336@gmail.com>
Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>
Author: Cuiqing Li
Date: 2023-10-19 22:22:47 +08:00
Committed by: GitHub
Parent: 11009103be
Commit: 3a41e8304e
20 changed files with 160 additions and 1555 deletions


@@ -4,7 +4,7 @@
## Introduction
`Colossal Inference` is a module that contains the Colossal-AI inference framework, featuring high performance, stability and ease of use. It incorporates the advantages of the latest open-source inference systems, including TGI, vLLM, FasterTransformer, LightLLM and flash attention, while combining the design of Colossal-AI, especially Shardformer, to reduce the learning curve for users.
`Colossal Inference` is a module that contains the Colossal-AI inference framework, featuring high performance, stability and ease of use. It incorporates the advantages of the latest open-source inference systems, including LightLLM, TGI, vLLM, FasterTransformer and flash attention, while combining the design of Colossal-AI, especially Shardformer, to reduce the learning curve for users.
## Design
@@ -62,6 +62,12 @@ triton==2.0.0.dev20221202
vllm
# to install flash-attention, please use commit hash: 67ae6fd74b4bc99c36b2ce524cf139c35663793c
flash-attention
# install lightllm since we depend on lightllm triton kernels
git clone https://github.com/ModelTC/lightllm
cd lightllm
git checkout 28c1267cfca536b7b4f28e921e03de735b003039
pip3 install -e .
```
### Docker
@@ -73,6 +79,17 @@ You can use docker run to use docker container to set-up environment
docker pull hpcaitech/colossalai-inference:v2
docker run -it --gpus all --name ANY_NAME -v $PWD:/workspace -w /workspace hpcaitech/colossalai-inference:v2 /bin/bash
# enter into docker container
cd /path/to/ColossalAI
pip install -e .
# install lightllm
git clone https://github.com/ModelTC/lightllm
cd lightllm
git checkout 28c1267cfca536b7b4f28e921e03de735b003039
pip3 install -e .
```
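
If the editable install succeeds, a quick import check can confirm the lightllm triton kernels are visible to Python. The imports below are the kernel entry points this commit actually uses; the check itself is just an optional sanity step, not part of the official setup.

```python
# Optional sanity check: these are the lightllm kernel entry points this commit relies on.
# If any import fails, redo the editable install from the pinned lightllm commit above.
from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd as llama_rotary_emb_fwd
from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward as lightllm_rmsnorm_forward
from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import context_attention_fwd
from lightllm.models.chatglm2.triton_kernel.rotary_emb import rotary_emb_fwd as chatglm2_rotary_emb_fwd

print("lightllm triton kernels found")
```
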
### Dive into fast-inference!


@@ -5,7 +5,7 @@ import torch
from .kvcache_manager import MemoryManager
# adapted from: lightllm/server/router/model_infer/infer_batch.py
@dataclass
class BatchInferState:
r"""
@@ -41,6 +41,7 @@ class BatchInferState:
def set_cache_manager(self, manager: MemoryManager):
self.cache_manager = manager
# adapted from: https://github.com/ModelTC/lightllm/blob/28c1267cfca536b7b4f28e921e03de735b003039/lightllm/common/infer_utils.py#L1
@staticmethod
def init_block_loc(
b_loc: torch.Tensor, seq_len: torch.Tensor, max_len_in_batch: int, alloc_mem_index: torch.Tensor
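
The hunk stops at the `init_block_loc` signature. For context, a rough sketch of what an initializer with this signature does in lightllm-style engines: the freshly allocated KV-cache slot indices are written right-aligned into each row of `b_loc`, so the last `seq_len[i]` columns of row `i` point at sequence `i`'s cache slots. The function name and layout below are assumptions for illustration, not the verbatim upstream code.

```python
import torch

def init_block_loc_ref(
    b_loc: torch.Tensor, seq_len: torch.Tensor, max_len_in_batch: int, alloc_mem_index: torch.Tensor
) -> None:
    """Right-align each sequence's allocated cache-slot indices in its row of b_loc (sketch)."""
    start = 0
    for i, cur_len in enumerate(seq_len.tolist()):
        # the last `cur_len` columns of row i hold this sequence's KV-cache slot indices
        b_loc[i, max_len_in_batch - cur_len : max_len_in_batch] = alloc_mem_index[start : start + cur_len]
        start += cur_len
```
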


@@ -1,7 +1,9 @@
# Adapted from lightllm/common/mem_manager.py
# of the ModelTC/lightllm GitHub repository
# https://github.com/ModelTC/lightllm/blob/050af3ce65edca617e2f30ec2479397d5bb248c9/lightllm/common/mem_manager.py
"""
Referred/Modified from lightllm/common/mem_manager.py
of the ModelTC/lightllm GitHub repository
https://github.com/ModelTC/lightllm/blob/050af3ce65edca617e2f30ec2479397d5bb248c9/lightllm/common/mem_manager.py
we slightly changed it to make it suitable for our colossal-ai shardformer TP-engine design.
"""
import torch
from transformers.utils import logging
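
The memory manager adapted in this file hands out token-granularity KV-cache slots that `b_loc` then refers to. As a hedged illustration of that interface (class and attribute names here are invented for the sketch, not the actual colossal-ai or lightllm implementation):

```python
import torch

class TinyKVCacheManager:
    """Toy token-level KV-cache slot allocator, for illustration only."""

    def __init__(self, size: int, device: str = "cpu"):
        # True marks a free cache slot
        self.mem_state = torch.ones(size, dtype=torch.bool, device=device)

    def alloc(self, need_size: int) -> torch.Tensor:
        free_idx = torch.nonzero(self.mem_state).squeeze(1)
        if free_idx.numel() < need_size:
            raise RuntimeError(f"need {need_size} cache slots, only {free_idx.numel()} free")
        chosen = free_idx[:need_size]
        self.mem_state[chosen] = False
        return chosen  # slot indices, e.g. to scatter into b_loc

    def free(self, indices: torch.Tensor) -> None:
        self.mem_state[indices] = True
```
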


@@ -6,8 +6,6 @@ from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
from colossalai.kernel.triton.context_attention import llama2_context_attn_fwd
from colossalai.kernel.triton.rotary_embedding_kernel import Llama2Forwards
from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (
ChatGLMForConditionalGeneration,
@@ -20,6 +18,14 @@ from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (
from ._utils import copy_kv_to_mem_cache
try:
from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import context_attention_fwd as lightllm_llama2_context_attention_fwd
from lightllm.models.chatglm2.triton_kernel.rotary_emb import rotary_emb_fwd as chatglm2_rotary_emb_fwd
HAS_LIGHTLLM_KERNEL = True
except:
print("please install lightllm from source to run inference: https://github.com/ModelTC/lightllm")
HAS_LIGHTLLM_KERNEL = False
# This func is the same as the Llama model's init_to_get_rotary; we should move both into _utils.py
def _init_to_get_rotary(self, base=10000):
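
The body of `_init_to_get_rotary` is not shown in this hunk. The usual shape of such an initializer is to precompute per-position cos/sin tables once, which the rotary kernels later index by position id; the sketch below assumes a default `max_position_embeddings` and returns the tables instead of storing them on `self`.

```python
import torch

def init_to_get_rotary_ref(head_dim: int, max_position_embeddings: int = 2048, base: float = 10000.0):
    """Precompute the cos/sin caches consumed by rotary-embedding kernels (reference sketch)."""
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
    t = torch.arange(max_position_embeddings, dtype=torch.float32)
    freqs = torch.outer(t, inv_freq)          # (seq_len, head_dim // 2)
    return torch.cos(freqs), torch.sin(freqs)  # later gathered into position_cos / position_sin
```
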
@@ -433,17 +439,17 @@ class ChatGLM2InferenceForwards:
cos, sin = infer_state.position_cos, infer_state.position_sin
Llama2Forwards.rotary_emb_fwd(
chatglm2_rotary_emb_fwd(
query_layer.view(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head), cos, sin
)
if self.multi_query_attention:
Llama2Forwards.rotary_emb_fwd(
chatglm2_rotary_emb_fwd(
key_layer.view(-1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head),
cos,
sin,
)
else:
Llama2Forwards.rotary_emb_fwd(
chatglm2_rotary_emb_fwd(
key_layer.view(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head),
cos,
sin,
@@ -474,7 +480,7 @@ class ChatGLM2InferenceForwards:
attn_output = torch.empty_like(query_layer.view(-1, self.projection_size))
# NOTE: no bug in context attn fwd (del it )
llama2_context_attn_fwd(
lightllm_llama2_context_attention_fwd(
query_layer,
key_layer,
value_layer,
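
`chatglm2_rotary_emb_fwd` rotates the query/key tensors in place on the GPU. For readability, an out-of-place PyTorch equivalent is sketched below; the half-split (GPT-NeoX style) layout and the cos/sin shapes are assumptions of this sketch rather than facts taken from the kernel source.

```python
import torch

def rotary_emb_ref(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Out-of-place reference for an in-place rotary kernel (half-split layout assumed).

    x:        (num_tokens, num_heads, head_dim)
    cos, sin: (num_tokens, head_dim // 2), already gathered by position id
    """
    d = x.shape[-1] // 2
    x0, x1 = x[..., :d], x[..., d:]
    cos = cos.unsqueeze(1)  # broadcast over the head dimension
    sin = sin.unsqueeze(1)
    return torch.cat((x0 * cos - x1 * sin, x0 * sin + x1 * cos), dim=-1)
```
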


@@ -5,12 +5,7 @@ from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm
from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
from colossalai.kernel.triton import (
llama2_context_attn_fwd,
llama_context_attn_fwd,
rotary_embedding_fwd,
token_attention_fwd,
)
from colossalai.kernel.triton import llama_context_attn_fwd, token_attention_fwd
from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards
from ._utils import copy_kv_to_mem_cache
@@ -29,6 +24,17 @@ except:
)
HAS_VLLM_KERNERL = False
try:
from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import (
context_attention_fwd as lightllm_llama2_context_attention_fwd,
)
from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd as llama_rotary_embedding_fwd
HAS_LIGHTLLM_KERNEL = True
except:
print("please install lightllm from source to run inference: https://github.com/ModelTC/lightllm")
HAS_LIGHTLLM_KERNEL = False
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
@@ -280,8 +286,8 @@ class LlamaInferenceForwards:
cos, sin = infer_state.position_cos, infer_state.position_sin
# print("shape ", cos.shape, query_states.view(-1, self.num_heads, self.head_dim).shape, )
rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin)
rotary_embedding_fwd(key_states.view(-1, self.num_key_value_heads, self.head_dim), cos, sin)
llama_rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin)
llama_rotary_embedding_fwd(key_states.view(-1, self.num_key_value_heads, self.head_dim), cos, sin)
query_states = query_states.reshape(-1, self.num_heads, self.head_dim)
key_states = key_states.reshape(-1, self.num_key_value_heads, self.head_dim)
@@ -312,7 +318,7 @@ class LlamaInferenceForwards:
infer_state.cache_manager.past_key_values_length,
)
else:
llama2_context_attn_fwd(
lightllm_llama2_context_attention_fwd(
query_states,
key_states,
value_states,
@@ -371,6 +377,7 @@ class LlamaInferenceForwards:
infer_state.cache_manager.past_key_values_length,
infer_state.other_kv_index,
)
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
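
The decode path above (`Llama2TokenAttentionForwards` / `token_attention_fwd`) computes attention for a single new token per sequence against that sequence's cached keys and values, gathered through `b_loc`. A naive PyTorch reference follows; the right-aligned `b_loc` layout is the same assumption used in the earlier sketch, and equal query/KV head counts are assumed (the real kernels also handle grouped KV heads).

```python
import torch

def token_attention_ref(q, k_cache, v_cache, b_loc, b_seq_len, max_len_in_batch):
    """Naive decode-phase attention over a slot-indexed KV cache (reference sketch).

    q:                (batch, num_heads, head_dim)   one new token per sequence
    k_cache, v_cache: (num_slots, num_heads, head_dim)
    b_loc:            (batch, max_len_in_batch)      cache-slot indices, right-aligned per row
    """
    head_dim = q.shape[-1]
    outs = []
    for i in range(q.shape[0]):
        n = int(b_seq_len[i])
        slots = b_loc[i, max_len_in_batch - n : max_len_in_batch]
        k, v = k_cache[slots], v_cache[slots]                   # (n, num_heads, head_dim)
        scores = torch.einsum("hd,nhd->hn", q[i], k) / head_dim ** 0.5
        probs = torch.softmax(scores, dim=-1)
        outs.append(torch.einsum("hn,nhd->hd", probs, v))
    return torch.stack(outs)                                    # (batch, num_heads, head_dim)
```
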


@@ -12,8 +12,7 @@ from ..modeling._utils import init_to_get_rotary
from ..modeling.llama import LlamaInferenceForwards, get_llama_vllm_rmsnorm_forward
try:
from colossalai.kernel.triton import rmsnorm_forward
from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward as lightllm_rmsnorm_forward
HAS_TRITON_RMSNORM = True
except:
print("you should install triton from https://github.com/openai/triton")
@@ -22,9 +21,8 @@ except:
def get_triton_rmsnorm_forward():
if HAS_TRITON_RMSNORM:
def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
return rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon)
return lightllm_rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon)
return _triton_rmsnorm_forward
else:
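
The swapped-in `lightllm_rmsnorm_forward` is called with `(hidden_states, weight, eps)`, the same arguments `LlamaRMSNorm` uses. For reference, a plain PyTorch equivalent of that signature (a sketch of standard RMSNorm, not the triton kernel itself):

```python
import torch

def rmsnorm_forward_ref(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    """RMSNorm: scale each vector by the reciprocal root-mean-square of its last dimension."""
    input_dtype = hidden_states.dtype
    x = hidden_states.to(torch.float32)
    variance = x.pow(2).mean(-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    return (weight * x).to(input_dtype)
```
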