[Kernels] added flash-decoding of triton (#5063)

* added flash-decoding of triton based on lightllm kernel

* add req

* clean

* clean

* delete build.sh

---------

Co-authored-by: cuiqing.li <lixx336@gmail.com>
Cuiqing Li (李崔卿)
2023-11-20 13:58:29 +08:00
committed by GitHub
parent fd6482ad8c
commit bce919708f
6 changed files with 82 additions and 43 deletions

View File

@@ -1,24 +0,0 @@
#!/usr/bin/env bash
# install triton
pip install triton
pip install transformers
# install lightllm and flash-attention
mkdir 3rdParty
cd 3rdParty
git clone https://github.com/ModelTC/lightllm
cd lightllm
git checkout 28c1267cfca536b7b4f28e921e03de735b003039
pip install -e .
cd ..
git clone --recursive https://github.com/Dao-AILab/flash-attention
cd flash-attention
pip install -e .
cd ../../

View File

@@ -27,9 +27,15 @@ except:
    print("please install lightllm from source to run inference: https://github.com/ModelTC/lightllm")
    HAS_LIGHTLLM_KERNEL = False

try:
    from colossalai.kernel.triton.flash_decoding import token_flash_decoding
    HAS_TRITON_FLASH_DECODING_KERNEL = True
except:
    print("no triton flash decoding support, please install lightllm from https://github.com/ModelTC/lightllm/blob/ece7b43f8a6dfa74027adc77c2c176cff28c76c8")
    HAS_TRITON_FLASH_DECODING_KERNEL = False

try:
    from flash_attn import flash_attn_with_kvcache
    HAS_FLASH_KERNEL = True
except:
    HAS_FLASH_KERNEL = False
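
The HAS_* flags above are set purely by optional-import probing, so a quick way to see which attention backends a given environment will actually enable is to try the same imports directly. A minimal sketch (it mirrors the try/except blocks in this hunk and assumes colossalai itself is importable):

import importlib

# Probe the same optional modules the hunk above guards with try/except.
for name in ("colossalai.kernel.triton.flash_decoding", "flash_attn"):
    try:
        importlib.import_module(name)
        print(f"{name}: available")
    except Exception as exc:  # broad on purpose, matching the bare except above
        print(f"{name}: unavailable ({exc})")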
@@ -42,7 +48,6 @@ def rotate_half(x):
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
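
For readers skimming the context lines: rotate_half splits the last dimension in half and swaps the halves with a sign flip, which apply_rotary_pos_emb then combines with the cos/sin tables. A tiny numeric sketch (shapes are illustrative only, not the model's real head dimensions):

import torch

# Illustrative only: reproduce rotate_half on a 4-element vector.
x = torch.tensor([0.0, 1.0, 2.0, 3.0])
x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
print(torch.cat((-x2, x1), dim=-1))  # tensor([-2., -3.,  0.,  1.])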
@@ -67,7 +72,6 @@ def llama_triton_context_attention(
                attn_output,
                infer_state.start_loc,
                infer_state.seq_len,
                # infer_state.cache_manager.past_key_values_length,
                infer_state.max_len_in_batch,
            )
        else:
@@ -78,7 +82,6 @@ def llama_triton_context_attention(
                attn_output,
                infer_state.start_loc,
                infer_state.seq_len,
                # infer_state.cache_manager.past_key_values_length,
                infer_state.max_len_in_batch,
            )
    else:
@@ -90,13 +93,20 @@ def llama_triton_context_attention(
            attn_output,
            infer_state.start_loc,
            infer_state.seq_len,
            # infer_state.cache_manager.past_key_values_length,
            infer_state.max_len_in_batch,
        )


def llama_triton_token_attention(query_states, attn_output, infer_state, num_key_value_groups=1):
    assert HAS_LIGHTLLM_KERNEL is True, "You have to install lightllm kernel to run token attention for llama models"
def llama_triton_token_attention(query_states, attn_output, infer_state, num_key_value_groups=1, q_head_num = -1, head_dim = -1):
    if HAS_TRITON_FLASH_DECODING_KERNEL and q_head_num != -1 and head_dim != -1:
        token_flash_decoding(q = query_states,
                             o_tensor = attn_output,
                             infer_state = infer_state,
                             q_head_num = q_head_num,
                             head_dim = head_dim,
                             cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
                             cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id])
        return

    if num_key_value_groups == 1:
        token_attention_fwd(
            query_states,
@@ -106,7 +116,6 @@ def llama_triton_token_attention(query_states, attn_output, infer_state, num_key
            infer_state.block_loc,
            infer_state.start_loc,
            infer_state.seq_len,
            # infer_state.cache_manager.past_key_values_length,
            infer_state.max_len_in_batch,
        )
    else:
@@ -118,7 +127,6 @@ def llama_triton_token_attention(query_states, attn_output, infer_state, num_key
            infer_state.block_loc,
            infer_state.start_loc,
            infer_state.seq_len,
            # infer_state.cache_manager.past_key_values_length,
            infer_state.max_len_in_batch,
            infer_state.other_kv_index,
        )
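
The net effect of the new signature is that q_head_num and head_dim act as an opt-in switch: they default to -1, so existing callers keep hitting token_attention_fwd, while callers that pass real head geometry (and have the triton kernel installed) take the flash-decoding fast path. A simplified restatement of that dispatch, not the actual implementation:

# Simplified sketch of the branch added to llama_triton_token_attention above.
def uses_flash_decoding(has_triton_flash_decoding_kernel: bool,
                        q_head_num: int = -1,
                        head_dim: int = -1) -> bool:
    # The fast path needs both the optional kernel and explicit head geometry;
    # otherwise the function falls back to lightllm's token_attention_fwd.
    return has_triton_flash_decoding_kernel and q_head_num != -1 and head_dim != -1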
@@ -451,10 +459,14 @@ class LlamaInferenceForwards:
            )

        if HAS_LIGHTLLM_KERNEL:
            attn_output = torch.empty_like(query_states)
            llama_triton_token_attention(
                query_states, attn_output, infer_state, num_key_value_groups=self.num_key_value_groups
            )
            llama_triton_token_attention(query_states = query_states,
                                         attn_output = attn_output,
                                         infer_state = infer_state,
                                         num_key_value_groups = self.num_key_value_groups,
                                         q_head_num = q_len * self.num_heads,
                                         head_dim = self.head_dim)
        else:
            self.num_heads // self.num_key_value_heads
            cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id]