[hotfix] fix boundary check in batch (#5306)

This commit is contained in:
Yuanheng Zhao
2024-01-25 10:23:12 +08:00
committed by GitHub
parent c647e00e3c
commit af8359c430
2 changed files with 18 additions and 1 deletions

View File

@@ -22,6 +22,7 @@ def _fwd_context_paged_attention_kernel(
KCache,
VCache,
BLOCK_TABLES, # [num_seqs, max_blocks_per_sequence]
batch_size,
stride_qt,
stride_qh,
stride_qd,
@@ -49,6 +50,8 @@ def _fwd_context_paged_attention_kernel(
BLOCK_N: tl.constexpr,
):
cur_seq_idx = tl.program_id(0)
if cur_seq_idx >= batch_size:
return
cur_head_idx = tl.program_id(1)
block_start_m = tl.program_id(2) # Br, max_input_len // Block_M
cur_kv_head_idx = cur_head_idx // KV_GROUPS
@@ -217,6 +220,8 @@ def context_attention_unpadded(
assert block_size in {16, 32, 64, 128}
BLOCK_M = BLOCK_N = block_size
# NOTE use `triton.next_power_of_2` here to utilize the cache mechanism of triton
# To optimize, revise batching/scheduling to batch 2^n sequences in a batch (preferred)
grid = (triton.next_power_of_2(num_seqs), num_heads, triton.cdiv(max_seq_len, BLOCK_M))
_fwd_context_paged_attention_kernel[grid](
@@ -227,6 +232,7 @@ def context_attention_unpadded(
k_cache,
v_cache,
block_tables,
num_seqs,
q.stride(0),
q.stride(1),
q.stride(2),