mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 20:10:17 +00:00
[hotfix] fix boundary check in batch (#5306)
This commit is contained in:
@@ -22,6 +22,7 @@ def _fwd_context_paged_attention_kernel(
|
||||
KCache,
|
||||
VCache,
|
||||
BLOCK_TABLES, # [num_seqs, max_blocks_per_sequence]
|
||||
batch_size,
|
||||
stride_qt,
|
||||
stride_qh,
|
||||
stride_qd,
|
||||
@@ -49,6 +50,8 @@ def _fwd_context_paged_attention_kernel(
|
||||
BLOCK_N: tl.constexpr,
|
||||
):
|
||||
cur_seq_idx = tl.program_id(0)
|
||||
if cur_seq_idx >= batch_size:
|
||||
return
|
||||
cur_head_idx = tl.program_id(1)
|
||||
block_start_m = tl.program_id(2) # Br, max_input_len // Block_M
|
||||
cur_kv_head_idx = cur_head_idx // KV_GROUPS
|
||||
@@ -217,6 +220,8 @@ def context_attention_unpadded(
|
||||
assert block_size in {16, 32, 64, 128}
|
||||
BLOCK_M = BLOCK_N = block_size
|
||||
|
||||
# NOTE use `triton.next_power_of_2` here to utilize the cache mechanism of triton
|
||||
# To optimize, revise batching/scheduling to batch 2^n sequences in a batch (preferred)
|
||||
grid = (triton.next_power_of_2(num_seqs), num_heads, triton.cdiv(max_seq_len, BLOCK_M))
|
||||
|
||||
_fwd_context_paged_attention_kernel[grid](
|
||||
@@ -227,6 +232,7 @@ def context_attention_unpadded(
|
||||
k_cache,
|
||||
v_cache,
|
||||
block_tables,
|
||||
num_seqs,
|
||||
q.stride(0),
|
||||
q.stride(1),
|
||||
q.stride(2),
|
||||
|
Reference in New Issue
Block a user