mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-16 14:41:53 +00:00
[kernel/fix] Performance Optimization for Decoding Kernel and Benchmarking (#5274)
* prevent re-creating intermediate tensors * add singleton class holding intermediate values * fix triton kernel api * add benchmark in pytest * fix kernel api and add benchmark * revise flash decoding triton kernel in/out shapes * fix calling of triton kernel in modeling * fix pytest: extract to util functions
This commit is contained in:
@@ -6,7 +6,7 @@ from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecode
|
||||
|
||||
from colossalai.inference.modeling.layers.attention import PagedAttention
|
||||
from colossalai.inference.struct import BatchInfo
|
||||
from colossalai.kernel.triton import context_attention_unpadded, copy_kv_to_blocked_cache, flash_decoding_fwd
|
||||
from colossalai.kernel.triton import context_attention_unpadded, copy_kv_to_blocked_cache, flash_decoding_attention
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
from flash_attn.bert_padding import index_first_axis, pad_input # noqa
|
||||
@@ -209,7 +209,15 @@ def llama_attn_forward(
|
||||
if HAS_TRITON:
|
||||
copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
|
||||
copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
|
||||
attn_output = flash_decoding_fwd(query_states, k_cache, v_cache, sequence_lengths, block_tables, block_size)
|
||||
# TODO Add dummy transpose and squeeze on in/out tensors of the triton flash decoding kernel
|
||||
# in order to maintain compatibility. This part as well as the logics of handling query_states and attn_output
|
||||
# should be revised, as we could see in previous part of `llama_attn_forward` we still have
|
||||
# redundant transpose and the in/out of torch- and triton-version decoding kernel are not consistent.
|
||||
query_states = query_states.transpose(1, 2)
|
||||
attn_output = flash_decoding_attention(
|
||||
query_states, k_cache, v_cache, sequence_lengths, block_tables, block_size
|
||||
)
|
||||
attn_output = attn_output.squeeze(1)
|
||||
else:
|
||||
attn_output = PagedAttention.pad_decoding_forward(
|
||||
query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, attention_mask
|
||||
|
@@ -9,7 +9,9 @@ except ImportError:
|
||||
# There may exist import error even if we have triton installed.
|
||||
if HAS_TRITON:
|
||||
from .context_attn_unpad import context_attention_unpadded
|
||||
from .flash_decoding import flash_decoding_fwd
|
||||
from .flash_decoding import flash_decoding_attention
|
||||
from .flash_decoding_utils import FDIntermTensors
|
||||
|
||||
from .rms_layernorm import rms_layernorm
|
||||
from .gptq_triton import gptq_fused_linear_triton
|
||||
from .kvcache_copy import copy_kv_to_blocked_cache
|
||||
@@ -18,10 +20,11 @@ if HAS_TRITON:
|
||||
|
||||
__all__ = [
|
||||
"context_attention_unpadded",
|
||||
"flash_decoding_fwd",
|
||||
"flash_decoding_attention",
|
||||
"copy_kv_to_blocked_cache",
|
||||
"softmax",
|
||||
"rms_layernorm",
|
||||
"gptq_fused_linear_triton",
|
||||
"rotary_embedding",
|
||||
"FDIntermTensors",
|
||||
]
|
||||
|
@@ -9,15 +9,16 @@ import triton.language as tl
|
||||
# Triton 2.1.0
|
||||
@triton.jit
|
||||
def _flash_decoding_fwd_kernel(
|
||||
Q, # [batch_size, head_num, head_dim]
|
||||
Q, # [batch_size, head_num, q_len(1), head_dim]
|
||||
KCache, # [num_blocks, num_kv_heads, head_dim, block_size]
|
||||
VCache, # [num_blocks, num_kv_heads, head_dim, block_size]
|
||||
block_tables, # [batch_size, max_blocks_per_sequence]
|
||||
mid_o, # [batch_size, head_num, kv_split_num, head_dim]
|
||||
mid_o_lse, # [batch_size, head_num, kv_split_num]
|
||||
context_lengths, # [batch_size]
|
||||
kv_seq_len, # [batch_size]
|
||||
stride_qt,
|
||||
stride_qh,
|
||||
stride_ql,
|
||||
stride_qd,
|
||||
stride_cacheb,
|
||||
stride_cacheh,
|
||||
@@ -51,7 +52,7 @@ def _flash_decoding_fwd_kernel(
|
||||
tl.static_assert(BLOCK_KV == BLOCK_SIZE)
|
||||
|
||||
# get the current (kv) sequence length from provided context lengths tensor
|
||||
cur_kv_seq_len = tl.load(context_lengths + cur_seq_idx)
|
||||
cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx)
|
||||
|
||||
offsets_q = cur_seq_idx * stride_qt + cur_head_idx * stride_qh + offsets_dmodel * stride_qd
|
||||
q = tl.load(Q + offsets_q)
|
||||
@@ -65,7 +66,6 @@ def _flash_decoding_fwd_kernel(
|
||||
cur_block_id = tl.load(block_table_ptr + cur_bt_start_idx * stride_btb)
|
||||
|
||||
if block_start_kv * BLOCK_KV >= cur_kv_seq_len:
|
||||
# TODO might want to remove if-else block?
|
||||
return
|
||||
|
||||
cur_occupied_size = tl.where(
|
||||
@@ -132,7 +132,7 @@ def _flash_decoding_fwd_reduce_kernel(
|
||||
mid_o, # [batch_size, head_num, kv_split_num, head_dim]
|
||||
mid_o_lse, # [batch_size, head_num, kv_split_num]
|
||||
O, # [batch_size, num_heads, head_dim] or [batch_size, 1, num_heads, head_dim]
|
||||
context_lengths,
|
||||
kv_seq_len,
|
||||
stride_mid_ot,
|
||||
stride_mid_oh,
|
||||
stride_mid_ob,
|
||||
@@ -141,6 +141,7 @@ def _flash_decoding_fwd_reduce_kernel(
|
||||
stride_o_lseh,
|
||||
stride_o_lseb,
|
||||
stride_ob,
|
||||
stride_ol,
|
||||
stride_oh,
|
||||
stride_od,
|
||||
BLOCK_KV: tl.constexpr,
|
||||
@@ -149,7 +150,7 @@ def _flash_decoding_fwd_reduce_kernel(
|
||||
cur_seq_idx = tl.program_id(0)
|
||||
cur_head_idx = tl.program_id(1)
|
||||
|
||||
cur_kv_seq_len = tl.load(context_lengths + cur_seq_idx)
|
||||
cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx)
|
||||
offsets_dmodel = tl.arange(0, HEAD_DIM)
|
||||
|
||||
# NOTE currently the block size BLOCK_KV splitting kv is relatively small as we have
|
||||
@@ -181,21 +182,46 @@ def _flash_decoding_fwd_reduce_kernel(
|
||||
|
||||
# Decoding Stage
|
||||
# Used with blocked KV Cache (PagedAttention)
|
||||
def flash_decoding_fwd(
|
||||
q: torch.Tensor, # [bsz(e.g.num_tokens), 1, num_heads, head_dim]
|
||||
k_cache: torch.Tensor, # [num_blocks, num_kv_heads, head_dim, block_size]
|
||||
v_cache: torch.Tensor, # [num_blocks, num_kv_heads, head_dim, block_size]
|
||||
context_lengths: torch.Tensor, # [batch_size]
|
||||
block_tables: torch.Tensor, # [batch_size, max_blocks_per_sequence]
|
||||
def flash_decoding_attention(
|
||||
q: torch.Tensor,
|
||||
k_cache: torch.Tensor,
|
||||
v_cache: torch.Tensor,
|
||||
kv_seq_len: torch.Tensor,
|
||||
block_tables: torch.Tensor,
|
||||
block_size: int,
|
||||
num_kv_group: int = 1,
|
||||
max_seq_len_in_batch: int = None,
|
||||
mid_output: torch.Tensor = None,
|
||||
mid_output_lse: torch.Tensor = None,
|
||||
sm_scale: int = None,
|
||||
kv_group_num: int = 1,
|
||||
):
|
||||
bsz, _, num_heads, head_dim = q.shape
|
||||
"""
|
||||
Flash decoding implemented with a blocked KV Cache (PagedAttention) during decoding stage.
|
||||
|
||||
Args:
|
||||
q (torch.Tensor): [bsz, num_heads, q_len(1), head_dim]
|
||||
k_cache (torch.Tensor): [num_blocks, num_kv_heads, head_dim, block_size]
|
||||
v_cache (torch.Tensor): [num_blocks, num_kv_heads, head_dim, block_size]
|
||||
kv_seq_len (torch.Tensor): [batch_size]
|
||||
records the (kv) sequence lengths incorporating past kv sequence lengths.
|
||||
block_tables (torch.Tensor): [batch_size, max_blocks_per_sequence]
|
||||
max_seq_len_in_batch (int): Maximum sequence length in the batch.
|
||||
mid_output (torch.Tensor): [ max_bsz , num_heads, kv_max_split_num, head_dim]
|
||||
Intermediate output tensor. `max_bsz` should be greater than or equal to `bsz`.
|
||||
mid_output_lse (torch.Tensor): [ max_bsz , num_heads, kv_max_split_num]
|
||||
Log-sum-exp of intermediate output. `max_bsz` should be greater than or equal to `bsz`.
|
||||
block_size (int): Size of each block in the blocked key/value cache.
|
||||
num_kv_group (int, optional): Number of key/value groups. Defaults to 1.
|
||||
|
||||
Returns:
|
||||
Output tensor with shape [bsz, num_heads, q_len, head_dim]
|
||||
"""
|
||||
bsz, num_heads, _, head_dim = q.shape
|
||||
|
||||
assert head_dim in {32, 64, 128, 256}
|
||||
assert context_lengths.shape[0] == block_tables.shape[0] == bsz, (
|
||||
assert kv_seq_len.shape[0] == block_tables.shape[0] == bsz, (
|
||||
f"Got incompatible batch size (number of seqs):\n"
|
||||
f" Conext lengths bsz {context_lengths.shape[0]}, Block tables bsz {block_tables.shape[0]}, "
|
||||
f" KV seq lengths bsz {kv_seq_len.shape[0]}, Block tables bsz {block_tables.shape[0]}, "
|
||||
f"batch size {bsz}"
|
||||
)
|
||||
assert k_cache.size(-1) == v_cache.size(-1) == block_size, (
|
||||
@@ -203,75 +229,79 @@ def flash_decoding_fwd(
|
||||
f" assigned block_size {block_size}, k_cache block_size {k_cache.size(-1)}, "
|
||||
f"v_cache block_size {v_cache.size(-1)}"
|
||||
)
|
||||
# NOTE `context_lengths` records the (kv) sequence lengths incorporating past kv sequence lengths.
|
||||
bsz = context_lengths.size(0) # e.g. the number of seqs
|
||||
max_seq_len = context_lengths.max().item()
|
||||
sm_scale = 1.0 / (head_dim**0.5)
|
||||
|
||||
# NOTE BLOCK_KV could be considered as block splitting the sequence on k/v
|
||||
# For now, BLOCK_KV is supposed to be equivalent with the size of physical cache block (i.e.`block_size`)
|
||||
assert block_size in {16, 32, 64, 128}
|
||||
BLOCK_KV = block_size
|
||||
|
||||
kv_max_split_num = (max_seq_len + BLOCK_KV - 1) // BLOCK_KV
|
||||
mid_o = torch.zeros(size=(bsz, num_heads, kv_max_split_num, head_dim), dtype=torch.float32, device=q.device)
|
||||
mid_o_lse = torch.zeros(size=(bsz, num_heads, kv_max_split_num), dtype=torch.float32, device=q.device)
|
||||
sm_scale = 1.0 / (head_dim**0.5) if sm_scale is None else sm_scale
|
||||
max_seq_len_in_batch = kv_seq_len.max().item() if max_seq_len_in_batch is None else max_seq_len_in_batch
|
||||
# For compatibility (TODO revise modeling in future)
|
||||
kv_max_split_num = (max_seq_len_in_batch + BLOCK_KV - 1) // BLOCK_KV
|
||||
mid_output = (
|
||||
torch.zeros(size=(bsz, num_heads, kv_max_split_num, head_dim), dtype=torch.float32, device=q.device)
|
||||
if mid_output is None
|
||||
else mid_output
|
||||
)
|
||||
mid_output_lse = (
|
||||
torch.zeros(size=(bsz, num_heads, kv_max_split_num), dtype=torch.float32, device=q.device)
|
||||
if mid_output_lse is None
|
||||
else mid_output_lse
|
||||
)
|
||||
|
||||
if q.dim() == 4:
|
||||
assert q.size(1) == 1, f"q_len is supposed to be 1 but is {q.size(1)}"
|
||||
q = q.squeeze(1)
|
||||
|
||||
grid = (bsz, num_heads, triton.cdiv(max_seq_len, BLOCK_KV))
|
||||
grid = (triton.next_power_of_2(bsz), num_heads, triton.cdiv(triton.next_power_of_2(max_seq_len_in_batch), BLOCK_KV))
|
||||
_flash_decoding_fwd_kernel[grid](
|
||||
q,
|
||||
k_cache,
|
||||
v_cache,
|
||||
block_tables,
|
||||
mid_o,
|
||||
mid_o_lse,
|
||||
context_lengths,
|
||||
mid_output,
|
||||
mid_output_lse,
|
||||
kv_seq_len,
|
||||
q.stride(0),
|
||||
q.stride(1),
|
||||
q.stride(2),
|
||||
q.stride(3),
|
||||
k_cache.stride(0),
|
||||
k_cache.stride(1),
|
||||
k_cache.stride(2),
|
||||
k_cache.stride(3),
|
||||
block_tables.stride(0),
|
||||
block_tables.stride(1),
|
||||
mid_o.stride(0),
|
||||
mid_o.stride(1),
|
||||
mid_o.stride(2),
|
||||
mid_o.stride(3),
|
||||
mid_o_lse.stride(0),
|
||||
mid_o_lse.stride(1),
|
||||
mid_o_lse.stride(2),
|
||||
mid_output.stride(0),
|
||||
mid_output.stride(1),
|
||||
mid_output.stride(2),
|
||||
mid_output.stride(3),
|
||||
mid_output_lse.stride(0),
|
||||
mid_output_lse.stride(1),
|
||||
mid_output_lse.stride(2),
|
||||
sm_scale,
|
||||
KV_GROUPS=num_kv_group,
|
||||
KV_GROUPS=kv_group_num,
|
||||
BLOCK_KV=block_size,
|
||||
BLOCK_SIZE=block_size,
|
||||
HEAD_DIM=head_dim,
|
||||
)
|
||||
|
||||
output = torch.zeros_like(q)
|
||||
output = output.view(-1, output.size(-2), output.size(-1))
|
||||
output = torch.empty((bsz, 1, num_heads, head_dim), dtype=q.dtype, device=q.device) # already overlapped
|
||||
|
||||
grid = (bsz, num_heads)
|
||||
_flash_decoding_fwd_reduce_kernel[grid](
|
||||
mid_o,
|
||||
mid_o_lse,
|
||||
mid_output,
|
||||
mid_output_lse,
|
||||
output,
|
||||
context_lengths,
|
||||
mid_o.stride(0),
|
||||
mid_o.stride(1),
|
||||
mid_o.stride(2),
|
||||
mid_o.stride(3),
|
||||
mid_o_lse.stride(0),
|
||||
mid_o_lse.stride(1),
|
||||
mid_o_lse.stride(2),
|
||||
kv_seq_len,
|
||||
mid_output.stride(0),
|
||||
mid_output.stride(1),
|
||||
mid_output.stride(2),
|
||||
mid_output.stride(3),
|
||||
mid_output_lse.stride(0),
|
||||
mid_output_lse.stride(1),
|
||||
mid_output_lse.stride(2),
|
||||
output.stride(0),
|
||||
output.stride(1),
|
||||
output.stride(2),
|
||||
output.stride(3),
|
||||
BLOCK_KV=block_size,
|
||||
HEAD_DIM=head_dim,
|
||||
)
|
||||
|
58
colossalai/kernel/triton/flash_decoding_utils.py
Normal file
58
colossalai/kernel/triton/flash_decoding_utils.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import torch
|
||||
|
||||
from colossalai.context.singleton_meta import SingletonMeta
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
class FDIntermTensors(metaclass=SingletonMeta):
|
||||
"""Singleton class to hold tensors used for storing intermediate values in flash-decoding.
|
||||
For now, it holds intermediate output and logsumexp (which will be used in reduction step along kv)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._tensors_initialized = False
|
||||
|
||||
@property
|
||||
def is_initialized(self):
|
||||
return self._tensors_initialized
|
||||
|
||||
@property
|
||||
def mid_output(self):
|
||||
assert self.is_initialized, "Intermediate tensors not initialized yet"
|
||||
return self._mid_output
|
||||
|
||||
@property
|
||||
def mid_output_lse(self):
|
||||
assert self.is_initialized, "Intermediate tensors not initialized yet"
|
||||
return self._mid_output_lse
|
||||
|
||||
def initialize(
|
||||
self,
|
||||
max_batch_size: int,
|
||||
num_attn_heads: int,
|
||||
kv_max_split_num: int,
|
||||
head_dim: int,
|
||||
dtype: torch.dtype = torch.float32,
|
||||
device: torch.device = get_current_device(),
|
||||
) -> None:
|
||||
"""Initialize tensors.
|
||||
|
||||
Args:
|
||||
max_batch_size (int): The maximum batch size over all the model forward.
|
||||
This could be greater than the batch size in attention forward func when using dynamic batch size.
|
||||
num_attn_heads (int)): Number of attention heads.
|
||||
kv_max_split_num (int): The maximum number of blocks splitted on kv in flash-decoding algorithm.
|
||||
**The maximum length/size of blocks splitted on kv should be the kv cache block size.**
|
||||
head_dim (int): Head dimension.
|
||||
dtype (torch.dtype, optional): Data type to be assigned to intermediate tensors.
|
||||
device (torch.device, optional): Device used to initialize intermediate tensors.
|
||||
"""
|
||||
assert not self.is_initialized, "Intermediate tensors used for Flash-Decoding have been initialized."
|
||||
|
||||
self._mid_output = torch.empty(
|
||||
size=(max_batch_size, num_attn_heads, kv_max_split_num, head_dim), dtype=dtype, device=device
|
||||
)
|
||||
self._mid_output_lse = torch.empty(
|
||||
size=(max_batch_size, num_attn_heads, kv_max_split_num), dtype=dtype, device=device
|
||||
)
|
||||
self._tensors_initialized = True
|
Reference in New Issue
Block a user