mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-10 13:30:19 +00:00
[Infer] Optimize Blocked KVCache And Kernels Using It (#5325)
* revise shape of kvcache (context attn kernel) * revise shape of kvcache (flash decoding kernel) * revise shape of kvcache (kvcache copy) and attn func * init of kvcache in kvcache manager * revise llama modeling * revise block size retrieval * use torch for rms_norm benchmarking * revise block size retrieval
This commit is contained in:
@@ -79,10 +79,10 @@ class KVCacheManager:
|
||||
self.num_blocks = self.max_blocks_per_sequence * self.max_batch_size * self.beam_width
|
||||
|
||||
# Physical cache allocation
|
||||
alloc_shape = (self.num_blocks, self.head_num, self.block_size, self.head_size)
|
||||
if verbose:
|
||||
alloc_shape = (self.num_blocks, self.head_num, self.head_size, self.block_size)
|
||||
self.logger.info(f"Allocating KV cache with shape: {alloc_shape} consisting of {self.num_blocks} blocks.")
|
||||
self._kv_caches = self._init_device_caches()
|
||||
self._kv_caches = self._init_device_caches(alloc_shape)
|
||||
self.total_physical_cache_size_in_bytes = (
|
||||
self.elem_size_in_bytes
|
||||
* self.num_layers
|
||||
@@ -297,15 +297,12 @@ class KVCacheManager:
|
||||
blocks.append(cache_block)
|
||||
return blocks
|
||||
|
||||
def _init_device_caches(self) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
def _init_device_caches(self, alloc_shape: Tuple[int, ...]) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Initialize the physical cache on the device.
|
||||
|
||||
For each layer of the model, we allocate two tensors for key and value respectively,
|
||||
with shape of [num_blocks, num_kv_heads, head_size, block_size]
|
||||
with shape of [num_blocks, num_kv_heads, block_size, head_size]
|
||||
"""
|
||||
alloc_shape = (self.num_blocks, self.head_num, self.head_size, self.block_size)
|
||||
# TODO: Explore the performance when using difference shapes with kernel-related optimizations
|
||||
# e.g. [num_blocks, num_kv_heads // x, head_size, block_size, x]
|
||||
k_cache: List[torch.Tensor] = []
|
||||
v_cache: List[torch.Tensor] = []
|
||||
for _ in range(self.num_layers):
|
||||
|
@@ -16,7 +16,7 @@ def copy_to_cache(source, cache, lengths, block_tables, type: str = "prefill"):
|
||||
lengths: key/value lengths
|
||||
block_tables
|
||||
"""
|
||||
num_blocks, num_heads, head_size, block_size = cache.shape
|
||||
num_blocks, num_heads, block_size, head_size = cache.shape
|
||||
bsz, max_blocks_per_seq = block_tables.shape
|
||||
needed_blocks = (lengths + block_size - 1) // block_size
|
||||
|
||||
@@ -26,17 +26,17 @@ def copy_to_cache(source, cache, lengths, block_tables, type: str = "prefill"):
|
||||
block_num = needed_blocks[i]
|
||||
token_id = 0
|
||||
for block_idx in range(block_num - 1):
|
||||
cache[block_tables[i][block_idx]] = source[i][token_id : token_id + block_size].permute(1, 2, 0)
|
||||
cache[block_tables[i][block_idx]] = source[i][token_id : token_id + block_size].permute(1, 0, 2)
|
||||
token_id += block_size
|
||||
cache[block_tables[i][block_num - 1], :, :, : seq_len - token_id] = source[i][token_id:seq_len].permute(
|
||||
1, 2, 0
|
||||
cache[block_tables[i][block_num - 1], :, : seq_len - token_id, :] = source[i][token_id:seq_len].permute(
|
||||
1, 0, 2
|
||||
)
|
||||
elif type == "decoding":
|
||||
assert source.size(1) == 1, "seq_len should be equal to 1 when decoding."
|
||||
source = source.squeeze(1)
|
||||
slot_idx = (lengths + block_size - 1) % block_size
|
||||
for i in range(bsz):
|
||||
cache[block_tables[i, needed_blocks[i] - 1], :, :, slot_idx[i]] = source[i]
|
||||
cache[block_tables[i, needed_blocks[i] - 1], :, slot_idx[i], :] = source[i]
|
||||
|
||||
return cache
|
||||
|
||||
@@ -46,12 +46,12 @@ def convert_kvcache(cache, lengths, block_tables, pad_id=0):
|
||||
"""
|
||||
Func: convert key/value cache for calculation
|
||||
|
||||
Args: cache: shape [num_blocks, num_heads, head_size, block_size]
|
||||
Args: cache: shape [num_blocks, num_heads, block_size, head_size]
|
||||
lengths: key/value length
|
||||
block_tables
|
||||
pad_id: padded_id
|
||||
"""
|
||||
num_blocks, num_heads, head_size, block_size = cache.shape
|
||||
num_blocks, num_heads, block_size, head_size = cache.shape
|
||||
|
||||
needed_blocks = (lengths + block_size - 1) // block_size
|
||||
num_remaing_tokens = lengths % block_size
|
||||
@@ -62,8 +62,8 @@ def convert_kvcache(cache, lengths, block_tables, pad_id=0):
|
||||
for i in range(bsz):
|
||||
_cache = torch.cat(
|
||||
(
|
||||
cache[block_tables[i][: needed_blocks[i] - 1]].permute((0, 3, 1, 2)).reshape(-1, num_heads, head_size),
|
||||
cache[block_tables[i][needed_blocks[i] - 1], :, :, : num_remaing_tokens[i]].permute(2, 0, 1),
|
||||
cache[block_tables[i][: needed_blocks[i] - 1]].permute((0, 2, 1, 3)).reshape(-1, num_heads, head_size),
|
||||
cache[block_tables[i][needed_blocks[i] - 1], :, : num_remaing_tokens[i], :].permute(1, 0, 2),
|
||||
),
|
||||
dim=0,
|
||||
)
|
||||
@@ -127,7 +127,7 @@ class PagedAttention:
|
||||
q: torch.Tensor, # [num_tokens, num_heads, head_size]
|
||||
k: torch.Tensor, # [num_tokens, num_kv_heads, head_size]
|
||||
v: torch.Tensor,
|
||||
k_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size]
|
||||
k_cache: torch.Tensor, # [num_blocks, num_heads, block_size, head_size]
|
||||
v_cache: torch.Tensor,
|
||||
context_lengths: torch.Tensor, # [num_seqs]
|
||||
block_tables: torch.Tensor, # [num_seqs,max_blocks_per_sequence]
|
||||
@@ -142,7 +142,7 @@ class PagedAttention:
|
||||
assert num_heads % num_kv_heads == 0, "num_kv_heads should be divisible by num_heads"
|
||||
num_kv_groups = num_heads // num_kv_heads
|
||||
|
||||
block_size = k_cache.shape[-1]
|
||||
block_size = k_cache.size(-2)
|
||||
bsz, max_blocks_per_sequence = block_tables.shape
|
||||
max_seq_len = max_blocks_per_sequence * block_size
|
||||
assert q.shape[-1] == k.shape[-1] == v.shape[-1]
|
||||
@@ -196,7 +196,7 @@ class PagedAttention:
|
||||
q: torch.Tensor, # [batch_size, seq_len, num_heads, head_size]
|
||||
k: torch.Tensor, # [batch_size, seq_len, num_kv_heads, head_size]
|
||||
v: torch.Tensor,
|
||||
k_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size]
|
||||
k_cache: torch.Tensor, # [num_blocks, num_heads, block_size, head_size]
|
||||
v_cache: torch.Tensor,
|
||||
context_lengths: torch.Tensor, # [num_seqs]
|
||||
block_tables: torch.Tensor, # [num_seqs,max_blocks_per_sequence]
|
||||
@@ -207,7 +207,7 @@ class PagedAttention:
|
||||
num_kv_heads = k.shape[-2]
|
||||
assert num_heads % num_kv_heads == 0, "num_kv_heads should be divisible by num_heads"
|
||||
num_kv_groups = num_heads // num_kv_heads
|
||||
block_size = k_cache.shape[-1]
|
||||
block_size = k_cache.size(-2)
|
||||
assert q.shape[0] == k.shape[0] == v.shape[0] == block_tables.shape[0]
|
||||
block_tables.shape[-1] * block_size
|
||||
|
||||
@@ -254,7 +254,7 @@ class PagedAttention:
|
||||
q: torch.Tensor, # [bsz, 1, num_heads, head_size]
|
||||
k: torch.Tensor, # [bsz, 1, num_kv_heads, head_size]
|
||||
v: torch.Tensor,
|
||||
k_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size]
|
||||
k_cache: torch.Tensor, # [num_blocks, num_heads, block_size, head_size]
|
||||
v_cache: torch.Tensor,
|
||||
lengths: torch.Tensor, # [num_seqs]: input_lengths + output_lengths
|
||||
block_tables: torch.Tensor, # [num_seqs,max_blocks_per_sequence]
|
||||
|
@@ -171,7 +171,7 @@ def llama_attn_forward(
|
||||
|
||||
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
|
||||
|
||||
_, _, _, block_size = k_cache.shape
|
||||
block_size = k_cache.size(-2)
|
||||
|
||||
if is_prompts:
|
||||
attn_output = context_attention_unpadded(
|
||||
|
@@ -226,7 +226,7 @@ def llama_attn_forward(
|
||||
|
||||
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
|
||||
|
||||
_, _, _, block_size = k_cache.shape
|
||||
block_size = k_cache.size(-2)
|
||||
|
||||
if is_prompts:
|
||||
attn_output = context_attention_unpadded(
|
||||
|
@@ -36,8 +36,8 @@ def _fwd_context_paged_attention_kernel(
|
||||
stride_od,
|
||||
stride_cacheb,
|
||||
stride_cacheh,
|
||||
stride_cached,
|
||||
stride_cachebs,
|
||||
stride_cached,
|
||||
stride_bts,
|
||||
stride_btb,
|
||||
context_lengths,
|
||||
@@ -158,29 +158,29 @@ def _fwd_context_paged_attention_kernel(
|
||||
# Copy k to corresponding cache block
|
||||
offsets_dmodel = tl.arange(0, HEAD_DIM)
|
||||
offsets_kt = block_start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
offsets_k = K + offset_kv + offsets_dmodel[:, None] * stride_kd + offsets_kt[None, :] * stride_kt
|
||||
k = tl.load(offsets_k, mask=offsets_kt[None, :] < cur_seq_len, other=0.0)
|
||||
offsets_k = K + offset_kv + offsets_dmodel[None, :] * stride_kd + offsets_kt[:, None] * stride_kt
|
||||
k = tl.load(offsets_k, mask=offsets_kt[:, None] < cur_seq_len, other=0.0)
|
||||
offsets_kcachebs = tl.arange(0, BLOCK_SIZE)
|
||||
offsets_kcache = (
|
||||
KCache
|
||||
+ offset_kvcache
|
||||
+ offsets_dmodel[:, None] * stride_cached
|
||||
+ offsets_kcachebs[None, :] * stride_cachebs
|
||||
+ offsets_dmodel[None, :] * stride_cached
|
||||
+ offsets_kcachebs[:, None] * stride_cachebs
|
||||
)
|
||||
tl.store(offsets_kcache, k, mask=offsets_kcachebs[None, :] < cur_seq_len - block_start_m * BLOCK_SIZE)
|
||||
tl.store(offsets_kcache, k, mask=offsets_kcachebs[:, None] < cur_seq_len - block_start_m * BLOCK_SIZE)
|
||||
# Copy v to corresponding cache block
|
||||
offsets_vd = offsets_dmodel
|
||||
offsets_vt = block_start_m * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||
offsets_v = V + offset_kv + offsets_vt[:, None] * stride_vt + offsets_vd[None, :] * stride_vd
|
||||
v = tl.load(offsets_v, mask=offsets_vt[:, None] < cur_seq_len, other=0.0)
|
||||
offsets_v = V + offset_kv + offsets_vt[None, :] * stride_vt + offsets_vd[:, None] * stride_vd
|
||||
v = tl.load(offsets_v, mask=offsets_vt[None, :] < cur_seq_len, other=0.0)
|
||||
offsets_vcachebs = offsets_kcachebs # same block size range, just to notify here
|
||||
offsets_vcache = (
|
||||
VCache
|
||||
+ offset_kvcache
|
||||
+ offsets_vcachebs[:, None] * stride_cachebs
|
||||
+ offsets_dmodel[None, :] * stride_cached
|
||||
+ offsets_vcachebs[None, :] * stride_cachebs
|
||||
+ offsets_dmodel[:, None] * stride_cached
|
||||
)
|
||||
tl.store(offsets_vcache, v, mask=offsets_vcachebs[:, None] < cur_seq_len - block_start_m * BLOCK_SIZE)
|
||||
tl.store(offsets_vcache, v, mask=offsets_vcachebs[None, :] < cur_seq_len - block_start_m * BLOCK_SIZE)
|
||||
|
||||
return
|
||||
|
||||
|
@@ -10,8 +10,8 @@ import triton.language as tl
|
||||
@triton.jit
|
||||
def _flash_decoding_fwd_kernel(
|
||||
Q, # [batch_size, head_num, q_len(1), head_dim]
|
||||
KCache, # [num_blocks, num_kv_heads, head_dim, block_size]
|
||||
VCache, # [num_blocks, num_kv_heads, head_dim, block_size]
|
||||
KCache, # [num_blocks, num_kv_heads, block_size, head_dim]
|
||||
VCache, # [num_blocks, num_kv_heads, block_size, head_dim]
|
||||
block_tables, # [batch_size, max_blocks_per_sequence]
|
||||
mid_o, # [batch_size, head_num, kv_split_num, head_dim]
|
||||
mid_o_lse, # [batch_size, head_num, kv_split_num]
|
||||
@@ -22,8 +22,8 @@ def _flash_decoding_fwd_kernel(
|
||||
stride_qd,
|
||||
stride_cacheb,
|
||||
stride_cacheh,
|
||||
stride_cached,
|
||||
stride_cachebs,
|
||||
stride_cached,
|
||||
stride_bts,
|
||||
stride_btb,
|
||||
stride_mid_ot,
|
||||
@@ -79,18 +79,18 @@ def _flash_decoding_fwd_kernel(
|
||||
|
||||
K_block_ptr = tl.make_block_ptr(
|
||||
base=KCache + offset_kvcache,
|
||||
shape=(HEAD_DIM, cur_occupied_size),
|
||||
strides=(stride_cached, stride_cachebs),
|
||||
shape=(cur_occupied_size, HEAD_DIM),
|
||||
strides=(stride_cachebs, stride_cached),
|
||||
offsets=(0, 0),
|
||||
block_shape=(HEAD_DIM, BLOCK_SIZE),
|
||||
block_shape=(BLOCK_SIZE, HEAD_DIM),
|
||||
order=(0, 1),
|
||||
)
|
||||
V_block_ptr = tl.make_block_ptr(
|
||||
base=VCache + offset_kvcache,
|
||||
shape=(HEAD_DIM, cur_occupied_size),
|
||||
strides=(stride_cached, stride_cachebs),
|
||||
shape=(cur_occupied_size, HEAD_DIM),
|
||||
strides=(stride_cachebs, stride_cached),
|
||||
offsets=(0, 0),
|
||||
block_shape=(HEAD_DIM, BLOCK_SIZE),
|
||||
block_shape=(BLOCK_SIZE, HEAD_DIM),
|
||||
order=(0, 1),
|
||||
)
|
||||
k_cur_block = tl.load(K_block_ptr)
|
||||
@@ -102,7 +102,7 @@ def _flash_decoding_fwd_kernel(
|
||||
# NOTE a trick to come across triton's requirement that values in both first and second input shapes must be >= 16,
|
||||
# Multiplying two tensors with shapes [1, d] * [d, block_size] will fail.
|
||||
# Refer to https://github.com/openai/triton/discussions/895
|
||||
S_ij += tl.sum(q[:, None] * k_cur_block, 0)
|
||||
S_ij += tl.sum(q[None, :] * k_cur_block, 1)
|
||||
S_ij *= sm_scale
|
||||
S_ij += tl.where(block_start_kv * BLOCK_KV + tl.arange(0, BLOCK_SIZE) < cur_kv_seq_len, 0, float("-inf"))
|
||||
|
||||
@@ -111,7 +111,7 @@ def _flash_decoding_fwd_kernel(
|
||||
p_ij_hat = tl.exp(S_ij)
|
||||
l = tl.sum(p_ij_hat, 0)
|
||||
p_ij_hat = p_ij_hat.to(v_cur_block.type.element_ty)
|
||||
acc += tl.sum(v_cur_block * p_ij_hat[None, :], 1)
|
||||
acc += tl.sum(v_cur_block * p_ij_hat[:, None], 0)
|
||||
acc = acc / l
|
||||
|
||||
offsets_mid_o = (
|
||||
@@ -206,8 +206,8 @@ def flash_decoding_attention(
|
||||
|
||||
Args:
|
||||
q (torch.Tensor): [bsz, num_heads, head_dim]
|
||||
k_cache (torch.Tensor): [num_blocks, num_kv_heads, head_dim, block_size]
|
||||
v_cache (torch.Tensor): [num_blocks, num_kv_heads, head_dim, block_size]
|
||||
k_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim]
|
||||
v_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim]
|
||||
kv_seq_len (torch.Tensor): [batch_size]
|
||||
records the (kv) sequence lengths incorporating past kv sequence lengths.
|
||||
block_tables (torch.Tensor): [batch_size, max_blocks_per_sequence]
|
||||
@@ -230,13 +230,13 @@ def flash_decoding_attention(
|
||||
assert head_dim in {32, 64, 128, 256}
|
||||
assert kv_seq_len.shape[0] == block_tables.shape[0] == bsz, (
|
||||
f"Got incompatible batch size (number of seqs):\n"
|
||||
f" KV seq lengths bsz {kv_seq_len.shape[0]}, Block tables bsz {block_tables.shape[0]}, "
|
||||
f" KV seq lengths bsz {kv_seq_len.size(0)}, Block tables bsz {block_tables.size(0)}, "
|
||||
f"batch size {bsz}"
|
||||
)
|
||||
assert k_cache.size(-1) == v_cache.size(-1) == block_size, (
|
||||
assert k_cache.size(-2) == v_cache.size(-2) == block_size, (
|
||||
f"Got incompatible block size on kv caches:\n"
|
||||
f" assigned block_size {block_size}, k_cache block_size {k_cache.size(-1)}, "
|
||||
f"v_cache block_size {v_cache.size(-1)}"
|
||||
f" assigned block_size {block_size}, k_cache block_size {k_cache.size(-2)}, "
|
||||
f"v_cache block_size {v_cache.size(-2)}"
|
||||
)
|
||||
|
||||
# NOTE BLOCK_KV could be considered as block splitting the sequence on k/v
|
||||
|
@@ -15,8 +15,8 @@ def _copy_to_kvcache_seqlen1_kernel(
|
||||
stride_kd,
|
||||
stride_cacheb,
|
||||
stride_cacheh,
|
||||
stride_cached,
|
||||
stride_cachebs,
|
||||
stride_cached,
|
||||
stride_bts,
|
||||
stride_btb,
|
||||
block_size,
|
||||
@@ -29,15 +29,15 @@ def _copy_to_kvcache_seqlen1_kernel(
|
||||
last_bt_block_idx = past_kv_seq_len // block_size
|
||||
block_table_ptr = BLOCK_TABLES + cur_seq_idx * stride_bts
|
||||
block_id = tl.load(block_table_ptr + last_bt_block_idx * stride_btb)
|
||||
offsets_in_last_block = (past_kv_seq_len % block_size) * stride_cachebs
|
||||
offsets_in_last_block = past_kv_seq_len % block_size
|
||||
offsets_dmodel = tl.arange(0, HEAD_DIM)
|
||||
offsets_kv = cur_seq_idx * stride_kt + cur_kv_head_idx * stride_kh + offsets_dmodel * stride_kd
|
||||
kv = tl.load(KV + offsets_kv)
|
||||
offsets_kvcache = (
|
||||
block_id * stride_cacheb
|
||||
+ cur_kv_head_idx * stride_cacheh
|
||||
+ offsets_in_last_block * stride_cachebs
|
||||
+ offsets_dmodel * stride_cached
|
||||
+ offsets_in_last_block
|
||||
)
|
||||
tl.store(KVCache + offsets_kvcache, kv)
|
||||
return
|
||||
@@ -52,23 +52,18 @@ def copy_kv_to_blocked_cache(
|
||||
"""
|
||||
Copy keys or values to the blocked key/value cache during decoding stage.
|
||||
|
||||
Parameters:
|
||||
- k (torch.Tensor): [bsz, 1, num_kv_heads, head_dim]/[bsz, num_kv_heads, head_dim] - Keys or values during decoding with seq len 1.
|
||||
- k_cache (torch.Tensor): [num_blocks, num_kv_heads, head_dim, block_size] - Blocked key or value cache.
|
||||
- kv_lengths (torch.Tensor): [bsz] - Past key/value sequence lengths plus current sequence length for each sequence.
|
||||
- block_tables (torch.Tensor): [bsz, max_blocks_per_sequence] - Block tables for each sequence.
|
||||
Args:
|
||||
k (torch.Tensor): [bsz, 1, num_kv_heads, head_dim]/[bsz, num_kv_heads, head_dim] - Keys or values during decoding with seq len 1.
|
||||
k_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim] - Blocked key or value cache.
|
||||
kv_lengths (torch.Tensor): [bsz] - Past key/value sequence lengths plus current sequence length for each sequence.
|
||||
block_tables (torch.Tensor): [bsz, max_blocks_per_sequence] - Block tables for each sequence.
|
||||
"""
|
||||
assert k.size(-1) == k_cache.size(-2), "Incompatible head dim"
|
||||
assert k.size(-1) == k_cache.size(-1), "Incompatible head dim"
|
||||
assert k.dtype == k_cache.dtype, "Expected consistent dtype for tensor and cache."
|
||||
if k.dim() == 4:
|
||||
assert k.size(1) == 1, "Unsupported kv seq len (supposed to be used for decoding stage)"
|
||||
bsz, _, num_kv_heads, head_dim = k.shape
|
||||
# [bsz, 1, num_kv_heads, head_dim] -> [bsz, num_kv_heads, head_dim]
|
||||
k = k.squeeze(dim=1)
|
||||
elif k.dim() == 3:
|
||||
bsz, num_kv_heads, head_dim = k.shape
|
||||
else:
|
||||
raise ValueError(f"The key dim should be 3 or 4, but got {k.dim()}.")
|
||||
|
||||
k = k.squeeze(1) if k.dim() == 4 else k
|
||||
assert k.dim() == 3, f"Incompatible k dim {k.dim()}"
|
||||
bsz, num_kv_heads, head_dim = k.shape
|
||||
|
||||
assert kv_lengths.shape[0] == block_tables.shape[0] == bsz, (
|
||||
f"Got incompatible batch size (number of seqs):\n"
|
||||
@@ -77,7 +72,7 @@ def copy_kv_to_blocked_cache(
|
||||
)
|
||||
|
||||
# Modify if the shape of kv cahce is changed.
|
||||
block_size = k_cache.size(-1)
|
||||
block_size = k_cache.size(-2)
|
||||
|
||||
num_warps = 8 if head_dim > 128 else 4
|
||||
|
||||
|
Reference in New Issue
Block a user