[kernel/fix] Performance Optimization for Decoding Kernel and Benchmarking (#5274)

* prevent re-creating intermediate tensors

* add singleton class holding intermediate values

* fix triton kernel api

* add benchmark in pytest

* fix kernel api and add benchmark

* revise flash decoding triton kernel in/out shapes

* fix calling of triton kernel in modeling

* fix pytest: extract to util functions
Yuanheng Zhao
2024-01-19 15:47:16 +08:00
committed by GitHub
parent 9e2342bde2
commit 6e487e7d3c
7 changed files with 382 additions and 152 deletions


@@ -9,7 +9,9 @@ except ImportError:
# There may exist import error even if we have triton installed.
if HAS_TRITON:
from .context_attn_unpad import context_attention_unpadded
from .flash_decoding import flash_decoding_fwd
from .flash_decoding import flash_decoding_attention
from .flash_decoding_utils import FDIntermTensors
from .rms_layernorm import rms_layernorm
from .gptq_triton import gptq_fused_linear_triton
from .kvcache_copy import copy_kv_to_blocked_cache
@@ -18,10 +20,11 @@ if HAS_TRITON:
__all__ = [
"context_attention_unpadded",
"flash_decoding_fwd",
"flash_decoding_attention",
"copy_kv_to_blocked_cache",
"softmax",
"rms_layernorm",
"gptq_fused_linear_triton",
"rotary_embedding",
"FDIntermTensors",
]
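
For orientation, the re-exported names above would then be importable from the kernel package itself. A minimal sketch (the full package path is not shown in this diff and is assumed here):

# Hypothetical package path, inferred from the relative imports in this __init__ hunk.
from colossalai.kernel.triton import FDIntermTensors, flash_decoding_attention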


@@ -9,15 +9,16 @@ import triton.language as tl
# Triton 2.1.0
@triton.jit
def _flash_decoding_fwd_kernel(
Q, # [batch_size, head_num, head_dim]
Q, # [batch_size, head_num, q_len(1), head_dim]
KCache, # [num_blocks, num_kv_heads, head_dim, block_size]
VCache, # [num_blocks, num_kv_heads, head_dim, block_size]
block_tables, # [batch_size, max_blocks_per_sequence]
mid_o, # [batch_size, head_num, kv_split_num, head_dim]
mid_o_lse, # [batch_size, head_num, kv_split_num]
context_lengths, # [batch_size]
kv_seq_len, # [batch_size]
stride_qt,
stride_qh,
stride_ql,
stride_qd,
stride_cacheb,
stride_cacheh,
@@ -51,7 +52,7 @@ def _flash_decoding_fwd_kernel(
tl.static_assert(BLOCK_KV == BLOCK_SIZE)
# get the current (kv) sequence length from provided context lengths tensor
cur_kv_seq_len = tl.load(context_lengths + cur_seq_idx)
cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx)
offsets_q = cur_seq_idx * stride_qt + cur_head_idx * stride_qh + offsets_dmodel * stride_qd
q = tl.load(Q + offsets_q)
@@ -65,7 +66,6 @@ def _flash_decoding_fwd_kernel(
cur_block_id = tl.load(block_table_ptr + cur_bt_start_idx * stride_btb)
if block_start_kv * BLOCK_KV >= cur_kv_seq_len:
# TODO might want to remove if-else block?
return
cur_occupied_size = tl.where(
@@ -132,7 +132,7 @@ def _flash_decoding_fwd_reduce_kernel(
mid_o, # [batch_size, head_num, kv_split_num, head_dim]
mid_o_lse, # [batch_size, head_num, kv_split_num]
O, # [batch_size, num_heads, head_dim] or [batch_size, 1, num_heads, head_dim]
context_lengths,
kv_seq_len,
stride_mid_ot,
stride_mid_oh,
stride_mid_ob,
@@ -141,6 +141,7 @@ def _flash_decoding_fwd_reduce_kernel(
stride_o_lseh,
stride_o_lseb,
stride_ob,
stride_ol,
stride_oh,
stride_od,
BLOCK_KV: tl.constexpr,
@@ -149,7 +150,7 @@ def _flash_decoding_fwd_reduce_kernel(
cur_seq_idx = tl.program_id(0)
cur_head_idx = tl.program_id(1)
cur_kv_seq_len = tl.load(context_lengths + cur_seq_idx)
cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx)
offsets_dmodel = tl.arange(0, HEAD_DIM)
# NOTE currently the block size BLOCK_KV splitting kv is relatively small as we have
@@ -181,21 +182,46 @@ def _flash_decoding_fwd_reduce_kernel(
# Decoding Stage
# Used with blocked KV Cache (PagedAttention)
def flash_decoding_fwd(
q: torch.Tensor, # [bsz(e.g.num_tokens), 1, num_heads, head_dim]
k_cache: torch.Tensor, # [num_blocks, num_kv_heads, head_dim, block_size]
v_cache: torch.Tensor, # [num_blocks, num_kv_heads, head_dim, block_size]
context_lengths: torch.Tensor, # [batch_size]
block_tables: torch.Tensor, # [batch_size, max_blocks_per_sequence]
def flash_decoding_attention(
q: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
kv_seq_len: torch.Tensor,
block_tables: torch.Tensor,
block_size: int,
num_kv_group: int = 1,
max_seq_len_in_batch: int = None,
mid_output: torch.Tensor = None,
mid_output_lse: torch.Tensor = None,
sm_scale: float = None,
kv_group_num: int = 1,
):
bsz, _, num_heads, head_dim = q.shape
"""
Flash decoding implemented with a blocked KV Cache (PagedAttention) during decoding stage.
Args:
q (torch.Tensor): [bsz, num_heads, q_len(1), head_dim]
k_cache (torch.Tensor): [num_blocks, num_kv_heads, head_dim, block_size]
v_cache (torch.Tensor): [num_blocks, num_kv_heads, head_dim, block_size]
kv_seq_len (torch.Tensor): [batch_size]
records the (kv) sequence lengths incorporating past kv sequence lengths.
block_tables (torch.Tensor): [batch_size, max_blocks_per_sequence]
max_seq_len_in_batch (int): Maximum sequence length in the batch.
mid_output (torch.Tensor): [max_bsz, num_heads, kv_max_split_num, head_dim]
Intermediate output tensor. `max_bsz` should be greater than or equal to `bsz`.
mid_output_lse (torch.Tensor): [max_bsz, num_heads, kv_max_split_num]
Log-sum-exp of intermediate output. `max_bsz` should be greater than or equal to `bsz`.
block_size (int): Size of each block in the blocked key/value cache.
kv_group_num (int, optional): Number of key/value groups. Defaults to 1.
Returns:
Output tensor with shape [bsz, num_heads, q_len, head_dim]
"""
bsz, num_heads, _, head_dim = q.shape
assert head_dim in {32, 64, 128, 256}
assert context_lengths.shape[0] == block_tables.shape[0] == bsz, (
assert kv_seq_len.shape[0] == block_tables.shape[0] == bsz, (
f"Got incompatible batch size (number of seqs):\n"
f" Conext lengths bsz {context_lengths.shape[0]}, Block tables bsz {block_tables.shape[0]}, "
f" KV seq lengths bsz {kv_seq_len.shape[0]}, Block tables bsz {block_tables.shape[0]}, "
f"batch size {bsz}"
)
assert k_cache.size(-1) == v_cache.size(-1) == block_size, (
@@ -203,75 +229,79 @@ def flash_decoding_fwd(
f" assigned block_size {block_size}, k_cache block_size {k_cache.size(-1)}, "
f"v_cache block_size {v_cache.size(-1)}"
)
# NOTE `context_lengths` records the (kv) sequence lengths incorporating past kv sequence lengths.
bsz = context_lengths.size(0) # e.g. the number of seqs
max_seq_len = context_lengths.max().item()
sm_scale = 1.0 / (head_dim**0.5)
# NOTE BLOCK_KV could be considered as block splitting the sequence on k/v
# For now, BLOCK_KV is supposed to be equivalent with the size of physical cache block (i.e.`block_size`)
assert block_size in {16, 32, 64, 128}
BLOCK_KV = block_size
kv_max_split_num = (max_seq_len + BLOCK_KV - 1) // BLOCK_KV
mid_o = torch.zeros(size=(bsz, num_heads, kv_max_split_num, head_dim), dtype=torch.float32, device=q.device)
mid_o_lse = torch.zeros(size=(bsz, num_heads, kv_max_split_num), dtype=torch.float32, device=q.device)
sm_scale = 1.0 / (head_dim**0.5) if sm_scale is None else sm_scale
max_seq_len_in_batch = kv_seq_len.max().item() if max_seq_len_in_batch is None else max_seq_len_in_batch
# For compatibility (TODO revise modeling in future)
kv_max_split_num = (max_seq_len_in_batch + BLOCK_KV - 1) // BLOCK_KV
mid_output = (
torch.zeros(size=(bsz, num_heads, kv_max_split_num, head_dim), dtype=torch.float32, device=q.device)
if mid_output is None
else mid_output
)
mid_output_lse = (
torch.zeros(size=(bsz, num_heads, kv_max_split_num), dtype=torch.float32, device=q.device)
if mid_output_lse is None
else mid_output_lse
)
if q.dim() == 4:
assert q.size(1) == 1, f"q_len is supposed to be 1 but is {q.size(1)}"
q = q.squeeze(1)
grid = (bsz, num_heads, triton.cdiv(max_seq_len, BLOCK_KV))
grid = (triton.next_power_of_2(bsz), num_heads, triton.cdiv(triton.next_power_of_2(max_seq_len_in_batch), BLOCK_KV))
_flash_decoding_fwd_kernel[grid](
q,
k_cache,
v_cache,
block_tables,
mid_o,
mid_o_lse,
context_lengths,
mid_output,
mid_output_lse,
kv_seq_len,
q.stride(0),
q.stride(1),
q.stride(2),
q.stride(3),
k_cache.stride(0),
k_cache.stride(1),
k_cache.stride(2),
k_cache.stride(3),
block_tables.stride(0),
block_tables.stride(1),
mid_o.stride(0),
mid_o.stride(1),
mid_o.stride(2),
mid_o.stride(3),
mid_o_lse.stride(0),
mid_o_lse.stride(1),
mid_o_lse.stride(2),
mid_output.stride(0),
mid_output.stride(1),
mid_output.stride(2),
mid_output.stride(3),
mid_output_lse.stride(0),
mid_output_lse.stride(1),
mid_output_lse.stride(2),
sm_scale,
KV_GROUPS=num_kv_group,
KV_GROUPS=kv_group_num,
BLOCK_KV=block_size,
BLOCK_SIZE=block_size,
HEAD_DIM=head_dim,
)
output = torch.zeros_like(q)
output = output.view(-1, output.size(-2), output.size(-1))
output = torch.empty((bsz, 1, num_heads, head_dim), dtype=q.dtype, device=q.device) # already overlapped
grid = (bsz, num_heads)
_flash_decoding_fwd_reduce_kernel[grid](
mid_o,
mid_o_lse,
mid_output,
mid_output_lse,
output,
context_lengths,
mid_o.stride(0),
mid_o.stride(1),
mid_o.stride(2),
mid_o.stride(3),
mid_o_lse.stride(0),
mid_o_lse.stride(1),
mid_o_lse.stride(2),
kv_seq_len,
mid_output.stride(0),
mid_output.stride(1),
mid_output.stride(2),
mid_output.stride(3),
mid_output_lse.stride(0),
mid_output_lse.stride(1),
mid_output_lse.stride(2),
output.stride(0),
output.stride(1),
output.stride(2),
output.stride(3),
BLOCK_KV=block_size,
HEAD_DIM=head_dim,
)
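
To make the revised calling convention concrete, here is a minimal usage sketch of flash_decoding_attention based on the shapes documented in its docstring. All sizes are illustrative, the package path is assumed, and it presumes a CUDA device with Triton installed; it is not the benchmark added by this PR.

import torch

# Hypothetical import path; the diff only shows the function definition.
from colossalai.kernel.triton import flash_decoding_attention

bsz, num_heads, num_kv_heads, head_dim = 4, 32, 32, 128  # illustrative sizes
block_size, max_seq_len = 32, 1024
max_blocks_per_seq = max_seq_len // block_size
num_blocks = bsz * max_blocks_per_seq
device = "cuda"

# q is [bsz, num_heads, q_len(1), head_dim]; caches are [num_blocks, num_kv_heads, head_dim, block_size].
q = torch.randn(bsz, num_heads, 1, head_dim, dtype=torch.float16, device=device)
k_cache = torch.randn(num_blocks, num_kv_heads, head_dim, block_size, dtype=torch.float16, device=device)
v_cache = torch.randn_like(k_cache)

# Toy block layout: each sequence owns a contiguous range of cache blocks.
block_tables = torch.arange(num_blocks, dtype=torch.int32, device=device).view(bsz, max_blocks_per_seq)
kv_seq_len = torch.randint(1, max_seq_len + 1, (bsz,), dtype=torch.int32, device=device)

# Pre-allocate the intermediate buffers once so they are not re-created every decoding step.
kv_max_split_num = (max_seq_len + block_size - 1) // block_size
mid_output = torch.zeros(bsz, num_heads, kv_max_split_num, head_dim, dtype=torch.float32, device=device)
mid_output_lse = torch.zeros(bsz, num_heads, kv_max_split_num, dtype=torch.float32, device=device)

out = flash_decoding_attention(
    q, k_cache, v_cache, kv_seq_len, block_tables, block_size,
    max_seq_len_in_batch=max_seq_len,
    mid_output=mid_output,
    mid_output_lse=mid_output_lse,
    sm_scale=1.0 / (head_dim**0.5),
    kv_group_num=num_heads // num_kv_heads,
)
# Per the docstring above, the output is [bsz, num_heads, q_len(1), head_dim].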


@@ -0,0 +1,58 @@
import torch

from colossalai.context.singleton_meta import SingletonMeta
from colossalai.utils import get_current_device


class FDIntermTensors(metaclass=SingletonMeta):
    """Singleton class to hold tensors used for storing intermediate values in flash-decoding.
    For now, it holds the intermediate output and logsumexp (which will be used in the reduction step along kv).
    """

    def __init__(self):
        self._tensors_initialized = False

    @property
    def is_initialized(self):
        return self._tensors_initialized

    @property
    def mid_output(self):
        assert self.is_initialized, "Intermediate tensors not initialized yet"
        return self._mid_output

    @property
    def mid_output_lse(self):
        assert self.is_initialized, "Intermediate tensors not initialized yet"
        return self._mid_output_lse

    def initialize(
        self,
        max_batch_size: int,
        num_attn_heads: int,
        kv_max_split_num: int,
        head_dim: int,
        dtype: torch.dtype = torch.float32,
        device: torch.device = get_current_device(),
    ) -> None:
        """Initialize tensors.

        Args:
            max_batch_size (int): The maximum batch size over all model forward passes.
                This could be greater than the batch size in the attention forward func when using dynamic batch size.
            num_attn_heads (int): Number of attention heads.
            kv_max_split_num (int): The maximum number of blocks split on kv in the flash-decoding algorithm.
                **The maximum length/size of blocks split on kv should be the kv cache block size.**
            head_dim (int): Head dimension.
            dtype (torch.dtype, optional): Data type to be assigned to the intermediate tensors.
            device (torch.device, optional): Device used to initialize the intermediate tensors.
        """
        assert not self.is_initialized, "Intermediate tensors used for Flash-Decoding have already been initialized."

        self._mid_output = torch.empty(
            size=(max_batch_size, num_attn_heads, kv_max_split_num, head_dim), dtype=dtype, device=device
        )
        self._mid_output_lse = torch.empty(
            size=(max_batch_size, num_attn_heads, kv_max_split_num), dtype=dtype, device=device
        )

        self._tensors_initialized = True
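
Finally, a short sketch of how the singleton above might be used so the buffers are created once and then reused for every decoding step. Sizes are illustrative and the package path is assumed; the buffers would be handed to flash_decoding_attention via its mid_output / mid_output_lse arguments.

import torch

# Hypothetical import path for the new module shown above.
from colossalai.kernel.triton import FDIntermTensors

fd_interm = FDIntermTensors()
if not fd_interm.is_initialized:
    fd_interm.initialize(
        max_batch_size=8,                        # illustrative: largest batch the server expects
        num_attn_heads=32,
        kv_max_split_num=(1024 + 32 - 1) // 32,  # ceil(max_seq_len / block_size) with block_size = 32
        head_dim=128,
        dtype=torch.float32,
    )

# SingletonMeta guarantees later constructions return the same instance,
# so every forward pass reuses the same intermediate tensors instead of re-creating them.
assert FDIntermTensors() is fd_interm
mid_output, mid_output_lse = fd_interm.mid_output, fd_interm.mid_output_lse
# ...then: flash_decoding_attention(..., mid_output=mid_output, mid_output_lse=mid_output_lse)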