[Infer] Optimize Blocked KVCache And Kernels Using It (#5325)

* revise shape of kvcache (context attn kernel)

* revise shape of kvcache (flash decoding kernel)

* revise shape of kvcache (kvcache copy) and attn func

* init of kvcache in kvcache manager

* revise llama modeling

* revise block size retrieval

* use torch for rms_norm benchmarking

* revise block size retrieval
This commit is contained in:
Yuanheng Zhao
2024-01-30 16:06:09 +08:00
committed by GitHub
parent e8f0642f28
commit 5f98a9d68a
14 changed files with 171 additions and 145 deletions

View File

@@ -79,10 +79,10 @@ class KVCacheManager:
self.num_blocks = self.max_blocks_per_sequence * self.max_batch_size * self.beam_width
# Physical cache allocation
alloc_shape = (self.num_blocks, self.head_num, self.block_size, self.head_size)
if verbose:
alloc_shape = (self.num_blocks, self.head_num, self.head_size, self.block_size)
self.logger.info(f"Allocating KV cache with shape: {alloc_shape} consisting of {self.num_blocks} blocks.")
self._kv_caches = self._init_device_caches()
self._kv_caches = self._init_device_caches(alloc_shape)
self.total_physical_cache_size_in_bytes = (
self.elem_size_in_bytes
* self.num_layers
@@ -297,15 +297,12 @@ class KVCacheManager:
blocks.append(cache_block)
return blocks
def _init_device_caches(self) -> Tuple[torch.Tensor, torch.Tensor]:
def _init_device_caches(self, alloc_shape: Tuple[int, ...]) -> Tuple[torch.Tensor, torch.Tensor]:
"""Initialize the physical cache on the device.
For each layer of the model, we allocate two tensors for key and value respectively,
with shape of [num_blocks, num_kv_heads, head_size, block_size]
with shape of [num_blocks, num_kv_heads, block_size, head_size]
"""
alloc_shape = (self.num_blocks, self.head_num, self.head_size, self.block_size)
# TODO: Explore the performance when using difference shapes with kernel-related optimizations
# e.g. [num_blocks, num_kv_heads // x, head_size, block_size, x]
k_cache: List[torch.Tensor] = []
v_cache: List[torch.Tensor] = []
for _ in range(self.num_layers):

View File

@@ -16,7 +16,7 @@ def copy_to_cache(source, cache, lengths, block_tables, type: str = "prefill"):
lengths: key/value lengths
block_tables
"""
num_blocks, num_heads, head_size, block_size = cache.shape
num_blocks, num_heads, block_size, head_size = cache.shape
bsz, max_blocks_per_seq = block_tables.shape
needed_blocks = (lengths + block_size - 1) // block_size
@@ -26,17 +26,17 @@ def copy_to_cache(source, cache, lengths, block_tables, type: str = "prefill"):
block_num = needed_blocks[i]
token_id = 0
for block_idx in range(block_num - 1):
cache[block_tables[i][block_idx]] = source[i][token_id : token_id + block_size].permute(1, 2, 0)
cache[block_tables[i][block_idx]] = source[i][token_id : token_id + block_size].permute(1, 0, 2)
token_id += block_size
cache[block_tables[i][block_num - 1], :, :, : seq_len - token_id] = source[i][token_id:seq_len].permute(
1, 2, 0
cache[block_tables[i][block_num - 1], :, : seq_len - token_id, :] = source[i][token_id:seq_len].permute(
1, 0, 2
)
elif type == "decoding":
assert source.size(1) == 1, "seq_len should be equal to 1 when decoding."
source = source.squeeze(1)
slot_idx = (lengths + block_size - 1) % block_size
for i in range(bsz):
cache[block_tables[i, needed_blocks[i] - 1], :, :, slot_idx[i]] = source[i]
cache[block_tables[i, needed_blocks[i] - 1], :, slot_idx[i], :] = source[i]
return cache
@@ -46,12 +46,12 @@ def convert_kvcache(cache, lengths, block_tables, pad_id=0):
"""
Func: convert key/value cache for calculation
Args: cache: shape [num_blocks, num_heads, head_size, block_size]
Args: cache: shape [num_blocks, num_heads, block_size, head_size]
lengths: key/value length
block_tables
pad_id: padded_id
"""
num_blocks, num_heads, head_size, block_size = cache.shape
num_blocks, num_heads, block_size, head_size = cache.shape
needed_blocks = (lengths + block_size - 1) // block_size
num_remaing_tokens = lengths % block_size
@@ -62,8 +62,8 @@ def convert_kvcache(cache, lengths, block_tables, pad_id=0):
for i in range(bsz):
_cache = torch.cat(
(
cache[block_tables[i][: needed_blocks[i] - 1]].permute((0, 3, 1, 2)).reshape(-1, num_heads, head_size),
cache[block_tables[i][needed_blocks[i] - 1], :, :, : num_remaing_tokens[i]].permute(2, 0, 1),
cache[block_tables[i][: needed_blocks[i] - 1]].permute((0, 2, 1, 3)).reshape(-1, num_heads, head_size),
cache[block_tables[i][needed_blocks[i] - 1], :, : num_remaing_tokens[i], :].permute(1, 0, 2),
),
dim=0,
)
@@ -127,7 +127,7 @@ class PagedAttention:
q: torch.Tensor, # [num_tokens, num_heads, head_size]
k: torch.Tensor, # [num_tokens, num_kv_heads, head_size]
v: torch.Tensor,
k_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size]
k_cache: torch.Tensor, # [num_blocks, num_heads, block_size, head_size]
v_cache: torch.Tensor,
context_lengths: torch.Tensor, # [num_seqs]
block_tables: torch.Tensor, # [num_seqs,max_blocks_per_sequence]
@@ -142,7 +142,7 @@ class PagedAttention:
assert num_heads % num_kv_heads == 0, "num_kv_heads should be divisible by num_heads"
num_kv_groups = num_heads // num_kv_heads
block_size = k_cache.shape[-1]
block_size = k_cache.size(-2)
bsz, max_blocks_per_sequence = block_tables.shape
max_seq_len = max_blocks_per_sequence * block_size
assert q.shape[-1] == k.shape[-1] == v.shape[-1]
@@ -196,7 +196,7 @@ class PagedAttention:
q: torch.Tensor, # [batch_size, seq_len, num_heads, head_size]
k: torch.Tensor, # [batch_size, seq_len, num_kv_heads, head_size]
v: torch.Tensor,
k_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size]
k_cache: torch.Tensor, # [num_blocks, num_heads, block_size, head_size]
v_cache: torch.Tensor,
context_lengths: torch.Tensor, # [num_seqs]
block_tables: torch.Tensor, # [num_seqs,max_blocks_per_sequence]
@@ -207,7 +207,7 @@ class PagedAttention:
num_kv_heads = k.shape[-2]
assert num_heads % num_kv_heads == 0, "num_kv_heads should be divisible by num_heads"
num_kv_groups = num_heads // num_kv_heads
block_size = k_cache.shape[-1]
block_size = k_cache.size(-2)
assert q.shape[0] == k.shape[0] == v.shape[0] == block_tables.shape[0]
block_tables.shape[-1] * block_size
@@ -254,7 +254,7 @@ class PagedAttention:
q: torch.Tensor, # [bsz, 1, num_heads, head_size]
k: torch.Tensor, # [bsz, 1, num_kv_heads, head_size]
v: torch.Tensor,
k_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size]
k_cache: torch.Tensor, # [num_blocks, num_heads, block_size, head_size]
v_cache: torch.Tensor,
lengths: torch.Tensor, # [num_seqs]: input_lengths + output_lengths
block_tables: torch.Tensor, # [num_seqs,max_blocks_per_sequence]

View File

@@ -171,7 +171,7 @@ def llama_attn_forward(
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
_, _, _, block_size = k_cache.shape
block_size = k_cache.size(-2)
if is_prompts:
attn_output = context_attention_unpadded(

View File

@@ -226,7 +226,7 @@ def llama_attn_forward(
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
_, _, _, block_size = k_cache.shape
block_size = k_cache.size(-2)
if is_prompts:
attn_output = context_attention_unpadded(