[Inference/Kernel] Optimize paged attention: Refactor key cache layout (#5643)

* optimize flashdecodingattention: refactor code with different key cache layout(from [num_blocks, num_kv_heads, block_size, head_size] to [num_blocks, num_kv_heads, head_size/x, block_size, x])

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Steve Luo
2024-04-25 14:24:02 +08:00
committed by GitHub
parent 90cd5227a3
commit a8fd3b0342
8 changed files with 152 additions and 49 deletions

View File

@@ -12,7 +12,7 @@ inference_ops = InferenceOpsLoader().load()
from tests.test_infer.test_ops.triton.kernel_utils import (
convert_kv_unpad_to_padded,
create_attention_mask,
generate_caches_and_block_tables_v2,
generate_caches_and_block_tables_v3,
generate_caches_and_block_tables_vllm,
torch_attn_ref,
)
@@ -77,7 +77,7 @@ def test_flash_decoding_attention(
BATCH_SIZE, HEAD_SIZE, NUM_ATTN_HEADS, NUM_KV_HEADS, MAX_SEQ_LEN, dtype, device
)
k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2(
k_cache, v_cache, block_tables = generate_caches_and_block_tables_v3(
k_unpad, v_unpad, kv_seq_lengths, BATCH_SIZE, MAX_NUM_BLOCKS_PER_SEQ, BLOCK_SIZE, dtype, device
)

View File

@@ -150,6 +150,50 @@ def mock_alloc_block_table_and_kvcache_v2(
return block_tables
def mock_alloc_block_table_and_kvcache_v3(
k: torch.Tensor,
v: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
context_lengths: torch.Tensor,
num_seqs: int,
max_num_blocks_per_seq: int,
block_size: int,
) -> torch.Tensor:
"""Allocate block tables based on provided context lengths; and copy KV to blocked KV Cache."""
block_id = 0
block_tables = torch.full(size=(num_seqs, max_num_blocks_per_seq), fill_value=-1, dtype=torch.int32)
num_tokens_processed = 0
_, num_kv_heads, head_dim = k.shape
x = 16 // torch.tensor([], dtype=k.dtype).element_size()
for i, seq_len in enumerate(context_lengths.tolist()):
right_bound = (seq_len + block_size - 1) // block_size # open bound
block_tables[i, :right_bound] = torch.arange(block_id, block_id + right_bound, dtype=torch.int32)
# Manually fill kv caches by copying from k and v
for i in range(right_bound):
if i == right_bound - 1:
allocated_locs = seq_len % block_size or block_size
else:
allocated_locs = block_size
# [block_size, num_kv_heads, head_dim/x, x]->[num_kv_heads, head_dim/x, block_size,x]
k_block = (
k[num_tokens_processed : num_tokens_processed + allocated_locs, :, :]
.reshape(allocated_locs, num_kv_heads, head_dim // x, x)
.permute(1, 2, 0, 3)
)
v_block = v[num_tokens_processed : num_tokens_processed + allocated_locs, :, :].permute(1, 0, 2)
k_cache[block_id, :, :, :allocated_locs, :] = k_block
v_cache[block_id, :, :allocated_locs, :] = v_block
num_tokens_processed += allocated_locs
block_id += 1
return block_tables
def mock_alloc_block_table_and_kvcache_vllm(
k: torch.Tensor,
v: torch.Tensor,
@@ -251,6 +295,26 @@ def generate_caches_and_block_tables_v2(
return k_cache, v_cache, block_tables
def generate_caches_and_block_tables_v3(
k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype=torch.float16, device="cuda"
) -> Tuple[torch.Tensor, ...]:
# Mock generation of k/v blocked caches and block tables from providied kv unpad and seq lengths
# k_unpad/v_unpad [num_total_tokens, num_kv_heads, head_dim]
_, num_kv_heads, head_dim = k_unpad.shape
x = 16 // torch.tensor([], dtype=dtype).element_size()
k_cache_shape = (bsz * max_num_blocks_per_seq, num_kv_heads, head_dim // x, block_size, x)
v_cache_shape = (bsz * max_num_blocks_per_seq, num_kv_heads, block_size, head_dim)
k_cache = torch.zeros(size=k_cache_shape, dtype=dtype, device=device)
v_cache = torch.zeros(size=v_cache_shape, dtype=dtype, device=device)
# Mock allocation on block tables as well as blocked kv caches
block_tables = mock_alloc_block_table_and_kvcache_v3(
k_unpad, v_unpad, k_cache, v_cache, kv_lengths, bsz, max_num_blocks_per_seq, block_size
)
return k_cache, v_cache, block_tables
def generate_caches_and_block_tables_vllm(
k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype=torch.float16, device="cuda"
) -> Tuple[torch.Tensor, ...]: