[Infer] Optimize Blocked KVCache And Kernels Using It (#5325)

* revise shape of kvcache (context attn kernel)

* revise shape of kvcache (flash decoding kernel)

* revise shape of kvcache (kvcache copy) and attn func

* init of kvcache in kvcache manager

* revise llama modeling

* revise block size retrieval

* use torch for rms_norm benchmarking

* revise block size retrieval
Author: Yuanheng Zhao (committed by GitHub)
Date: 2024-01-30 16:06:09 +08:00
Commit: 5f98a9d68a (parent: e8f0642f28)
14 changed files with 171 additions and 145 deletions
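
The central change across these kernels is the physical layout of the blocked KV cache: the last two dimensions swap from (head_dim, block_size) to (block_size, head_dim), so each token slot in a block now owns a contiguous head_dim vector. A minimal sketch of the two layouts (sizes are illustrative, and the tensor names are not the actual KVCacheManager attributes):

import torch

num_blocks, num_kv_heads, block_size, head_dim = 128, 8, 16, 64

# Old layout (before this commit): [num_blocks, num_kv_heads, head_dim, block_size]
k_cache_old = torch.empty(num_blocks, num_kv_heads, head_dim, block_size, dtype=torch.float16)

# New layout (after this commit): [num_blocks, num_kv_heads, block_size, head_dim]
k_cache = torch.empty(num_blocks, num_kv_heads, block_size, head_dim, dtype=torch.float16)
v_cache = torch.empty_like(k_cache)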

@@ -36,8 +36,8 @@ def _fwd_context_paged_attention_kernel(
stride_od,
stride_cacheb,
stride_cacheh,
stride_cached,
stride_cachebs,
stride_cached,
stride_bts,
stride_btb,
context_lengths,
@@ -158,29 +158,29 @@ def _fwd_context_paged_attention_kernel(
# Copy k to corresponding cache block
offsets_dmodel = tl.arange(0, HEAD_DIM)
offsets_kt = block_start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offsets_k = K + offset_kv + offsets_dmodel[:, None] * stride_kd + offsets_kt[None, :] * stride_kt
k = tl.load(offsets_k, mask=offsets_kt[None, :] < cur_seq_len, other=0.0)
offsets_k = K + offset_kv + offsets_dmodel[None, :] * stride_kd + offsets_kt[:, None] * stride_kt
k = tl.load(offsets_k, mask=offsets_kt[:, None] < cur_seq_len, other=0.0)
offsets_kcachebs = tl.arange(0, BLOCK_SIZE)
offsets_kcache = (
KCache
+ offset_kvcache
+ offsets_dmodel[:, None] * stride_cached
+ offsets_kcachebs[None, :] * stride_cachebs
+ offsets_dmodel[None, :] * stride_cached
+ offsets_kcachebs[:, None] * stride_cachebs
)
tl.store(offsets_kcache, k, mask=offsets_kcachebs[None, :] < cur_seq_len - block_start_m * BLOCK_SIZE)
tl.store(offsets_kcache, k, mask=offsets_kcachebs[:, None] < cur_seq_len - block_start_m * BLOCK_SIZE)
# Copy v to corresponding cache block
offsets_vd = offsets_dmodel
offsets_vt = block_start_m * BLOCK_N + tl.arange(0, BLOCK_N)
offsets_v = V + offset_kv + offsets_vt[:, None] * stride_vt + offsets_vd[None, :] * stride_vd
v = tl.load(offsets_v, mask=offsets_vt[:, None] < cur_seq_len, other=0.0)
offsets_v = V + offset_kv + offsets_vt[None, :] * stride_vt + offsets_vd[:, None] * stride_vd
v = tl.load(offsets_v, mask=offsets_vt[None, :] < cur_seq_len, other=0.0)
offsets_vcachebs = offsets_kcachebs  # same block size range; aliased only to make the mapping explicit
offsets_vcache = (
VCache
+ offset_kvcache
+ offsets_vcachebs[:, None] * stride_cachebs
+ offsets_dmodel[None, :] * stride_cached
+ offsets_vcachebs[None, :] * stride_cachebs
+ offsets_dmodel[:, None] * stride_cached
)
tl.store(offsets_vcache, v, mask=offsets_vcachebs[:, None] < cur_seq_len - block_start_m * BLOCK_SIZE)
tl.store(offsets_vcache, v, mask=offsets_vcachebs[None, :] < cur_seq_len - block_start_m * BLOCK_SIZE)
return
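
In PyTorch terms, the K-copy this hunk performs for one (sequence, kv head, block) tile now boils down to the following rough sketch; block_id, block_start_m, and the other names are illustrative stand-ins for values the kernel derives from block_tables, context_lengths, and its program ids:

import torch

BLOCK_SIZE, HEAD_DIM = 16, 64
num_blocks, num_kv_heads = 8, 4
cur_seq_len, block_start_m = 21, 1      # this tile covers tokens 16..20 of the prompt
block_id, kv_head_idx = 5, 2            # physical block and kv head handled by this program

k = torch.randn(cur_seq_len, HEAD_DIM)  # unpadded K of one sequence for one kv head
k_cache = torch.zeros(num_blocks, num_kv_heads, BLOCK_SIZE, HEAD_DIM)

# Token slots now run along dim -2 and head_dim along dim -1 (the two were swapped before).
n = min(BLOCK_SIZE, cur_seq_len - block_start_m * BLOCK_SIZE)
k_cache[block_id, kv_head_idx, :n] = k[block_start_m * BLOCK_SIZE : block_start_m * BLOCK_SIZE + n]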

@@ -10,8 +10,8 @@ import triton.language as tl
@triton.jit
def _flash_decoding_fwd_kernel(
Q, # [batch_size, head_num, q_len(1), head_dim]
KCache, # [num_blocks, num_kv_heads, head_dim, block_size]
VCache, # [num_blocks, num_kv_heads, head_dim, block_size]
KCache, # [num_blocks, num_kv_heads, block_size, head_dim]
VCache, # [num_blocks, num_kv_heads, block_size, head_dim]
block_tables, # [batch_size, max_blocks_per_sequence]
mid_o, # [batch_size, head_num, kv_split_num, head_dim]
mid_o_lse, # [batch_size, head_num, kv_split_num]
@@ -22,8 +22,8 @@ def _flash_decoding_fwd_kernel(
stride_qd,
stride_cacheb,
stride_cacheh,
stride_cached,
stride_cachebs,
stride_cached,
stride_bts,
stride_btb,
stride_mid_ot,
@@ -79,18 +79,18 @@ def _flash_decoding_fwd_kernel(
K_block_ptr = tl.make_block_ptr(
base=KCache + offset_kvcache,
shape=(HEAD_DIM, cur_occupied_size),
strides=(stride_cached, stride_cachebs),
shape=(cur_occupied_size, HEAD_DIM),
strides=(stride_cachebs, stride_cached),
offsets=(0, 0),
block_shape=(HEAD_DIM, BLOCK_SIZE),
block_shape=(BLOCK_SIZE, HEAD_DIM),
order=(0, 1),
)
V_block_ptr = tl.make_block_ptr(
base=VCache + offset_kvcache,
shape=(HEAD_DIM, cur_occupied_size),
strides=(stride_cached, stride_cachebs),
shape=(cur_occupied_size, HEAD_DIM),
strides=(stride_cachebs, stride_cached),
offsets=(0, 0),
block_shape=(HEAD_DIM, BLOCK_SIZE),
block_shape=(BLOCK_SIZE, HEAD_DIM),
order=(0, 1),
)
k_cur_block = tl.load(K_block_ptr)
@@ -102,7 +102,7 @@ def _flash_decoding_fwd_kernel(
# NOTE a trick to work around triton's requirement that values in both the first and second input shapes must be >= 16,
# since multiplying two tensors with shapes [1, d] * [d, block_size] would fail.
# Refer to https://github.com/openai/triton/discussions/895
S_ij += tl.sum(q[:, None] * k_cur_block, 0)
S_ij += tl.sum(q[None, :] * k_cur_block, 1)
S_ij *= sm_scale
S_ij += tl.where(block_start_kv * BLOCK_KV + tl.arange(0, BLOCK_SIZE) < cur_kv_seq_len, 0, float("-inf"))
@@ -111,7 +111,7 @@ def _flash_decoding_fwd_kernel(
p_ij_hat = tl.exp(S_ij)
l = tl.sum(p_ij_hat, 0)
p_ij_hat = p_ij_hat.to(v_cur_block.type.element_ty)
acc += tl.sum(v_cur_block * p_ij_hat[None, :], 1)
acc += tl.sum(v_cur_block * p_ij_hat[:, None], 0)
acc = acc / l
offsets_mid_o = (
@@ -206,8 +206,8 @@ def flash_decoding_attention(
Args:
q (torch.Tensor): [bsz, num_heads, head_dim]
k_cache (torch.Tensor): [num_blocks, num_kv_heads, head_dim, block_size]
v_cache (torch.Tensor): [num_blocks, num_kv_heads, head_dim, block_size]
k_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim]
v_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim]
kv_seq_len (torch.Tensor): [batch_size]
records the (kv) sequence lengths incorporating past kv sequence lengths.
block_tables (torch.Tensor): [batch_size, max_blocks_per_sequence]
@@ -230,13 +230,13 @@ def flash_decoding_attention(
assert head_dim in {32, 64, 128, 256}
assert kv_seq_len.shape[0] == block_tables.shape[0] == bsz, (
f"Got incompatible batch size (number of seqs):\n"
f" KV seq lengths bsz {kv_seq_len.shape[0]}, Block tables bsz {block_tables.shape[0]}, "
f" KV seq lengths bsz {kv_seq_len.size(0)}, Block tables bsz {block_tables.size(0)}, "
f"batch size {bsz}"
)
assert k_cache.size(-1) == v_cache.size(-1) == block_size, (
assert k_cache.size(-2) == v_cache.size(-2) == block_size, (
f"Got incompatible block size on kv caches:\n"
f" assigned block_size {block_size}, k_cache block_size {k_cache.size(-1)}, "
f"v_cache block_size {v_cache.size(-1)}"
f" assigned block_size {block_size}, k_cache block_size {k_cache.size(-2)}, "
f"v_cache block_size {v_cache.size(-2)}"
)
# NOTE BLOCK_KV could be considered as block splitting the sequence on k/v
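
With the new layout, each K/V tile loaded through tl.make_block_ptr is a [BLOCK_SIZE, HEAD_DIM] slice of one physical block: scores are reduced along the head dimension and the weighted values along the block dimension. A rough single-head, single-block PyTorch equivalent of what one KV split computes (illustrative only; the real kernel keeps unnormalized exponentials plus a log-sum-exp per split and combines splits in a second kernel):

import torch

HEAD_DIM, BLOCK_SIZE = 64, 16
sm_scale = HEAD_DIM ** -0.5

q = torch.randn(HEAD_DIM)                    # decoding query for one head (q_len == 1)
k_block = torch.randn(BLOCK_SIZE, HEAD_DIM)  # one occupied cache block, new layout
v_block = torch.randn(BLOCK_SIZE, HEAD_DIM)

# S_ij += tl.sum(q[None, :] * k_cur_block, 1): reduce over head_dim -> one score per token slot
s = (q[None, :] * k_block).sum(dim=1) * sm_scale        # [BLOCK_SIZE]
p = torch.softmax(s, dim=0)                             # kernel defers normalization via lse
# acc += tl.sum(v_cur_block * p_ij_hat[:, None], 0): weighted sum over the block's value rows
out = (v_block * p[:, None]).sum(dim=0)                 # [HEAD_DIM]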

@@ -15,8 +15,8 @@ def _copy_to_kvcache_seqlen1_kernel(
stride_kd,
stride_cacheb,
stride_cacheh,
stride_cached,
stride_cachebs,
stride_cached,
stride_bts,
stride_btb,
block_size,
@@ -29,15 +29,15 @@ def _copy_to_kvcache_seqlen1_kernel(
last_bt_block_idx = past_kv_seq_len // block_size
block_table_ptr = BLOCK_TABLES + cur_seq_idx * stride_bts
block_id = tl.load(block_table_ptr + last_bt_block_idx * stride_btb)
offsets_in_last_block = (past_kv_seq_len % block_size) * stride_cachebs
offsets_in_last_block = past_kv_seq_len % block_size
offsets_dmodel = tl.arange(0, HEAD_DIM)
offsets_kv = cur_seq_idx * stride_kt + cur_kv_head_idx * stride_kh + offsets_dmodel * stride_kd
kv = tl.load(KV + offsets_kv)
offsets_kvcache = (
block_id * stride_cacheb
+ cur_kv_head_idx * stride_cacheh
+ offsets_in_last_block * stride_cachebs
+ offsets_dmodel * stride_cached
+ offsets_in_last_block
)
tl.store(KVCache + offsets_kvcache, kv)
return
@@ -52,23 +52,18 @@ def copy_kv_to_blocked_cache(
"""
Copy keys or values to the blocked key/value cache during decoding stage.
Parameters:
- k (torch.Tensor): [bsz, 1, num_kv_heads, head_dim]/[bsz, num_kv_heads, head_dim] - Keys or values during decoding with seq len 1.
- k_cache (torch.Tensor): [num_blocks, num_kv_heads, head_dim, block_size] - Blocked key or value cache.
- kv_lengths (torch.Tensor): [bsz] - Past key/value sequence lengths plus current sequence length for each sequence.
- block_tables (torch.Tensor): [bsz, max_blocks_per_sequence] - Block tables for each sequence.
Args:
k (torch.Tensor): [bsz, 1, num_kv_heads, head_dim]/[bsz, num_kv_heads, head_dim] - Keys or values during decoding with seq len 1.
k_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim] - Blocked key or value cache.
kv_lengths (torch.Tensor): [bsz] - Past key/value sequence lengths plus current sequence length for each sequence.
block_tables (torch.Tensor): [bsz, max_blocks_per_sequence] - Block tables for each sequence.
"""
assert k.size(-1) == k_cache.size(-2), "Incompatible head dim"
assert k.size(-1) == k_cache.size(-1), "Incompatible head dim"
assert k.dtype == k_cache.dtype, "Expected consistent dtype for tensor and cache."
if k.dim() == 4:
assert k.size(1) == 1, "Unsupported kv seq len (supposed to be used for decoding stage)"
bsz, _, num_kv_heads, head_dim = k.shape
# [bsz, 1, num_kv_heads, head_dim] -> [bsz, num_kv_heads, head_dim]
k = k.squeeze(dim=1)
elif k.dim() == 3:
bsz, num_kv_heads, head_dim = k.shape
else:
raise ValueError(f"The key dim should be 3 or 4, but got {k.dim()}.")
k = k.squeeze(1) if k.dim() == 4 else k
assert k.dim() == 3, f"Incompatible k dim {k.dim()}"
bsz, num_kv_heads, head_dim = k.shape
assert kv_lengths.shape[0] == block_tables.shape[0] == bsz, (
f"Got incompatible batch size (number of seqs):\n"
@@ -77,7 +72,7 @@ def copy_kv_to_blocked_cache(
)
# Modify this if the shape of the kv cache is changed.
block_size = k_cache.size(-1)
block_size = k_cache.size(-2)
num_warps = 8 if head_dim > 128 else 4
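
A minimal usage sketch of the revised helper during a decoding step, assuming caches already allocated in the new layout and that the function is importable from colossalai.kernel.triton (sizes, dtypes, and the import path are illustrative assumptions):

import torch
from colossalai.kernel.triton import copy_kv_to_blocked_cache  # import path assumed

bsz, num_kv_heads, head_dim = 2, 4, 64
num_blocks, block_size, max_blocks_per_seq = 32, 16, 4

# Blocked key cache in the new layout: [num_blocks, num_kv_heads, block_size, head_dim]
k_cache = torch.zeros(num_blocks, num_kv_heads, block_size, head_dim, dtype=torch.float16, device="cuda")
# One new key per sequence for this step: [bsz, 1, num_kv_heads, head_dim] (3-D also accepted)
k = torch.randn(bsz, 1, num_kv_heads, head_dim, dtype=torch.float16, device="cuda")
# Sequence lengths including the current token, and the per-sequence block tables
kv_lengths = torch.tensor([5, 17], dtype=torch.int32, device="cuda")
block_tables = torch.arange(bsz * max_blocks_per_seq, dtype=torch.int32, device="cuda").view(bsz, max_blocks_per_seq)

copy_kv_to_blocked_cache(k, k_cache, kv_lengths, block_tables)  # block_size is now read from k_cache.size(-2)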