Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-09 04:50:17 +00:00)
[kernel] Support New KCache Layout - Triton Kernel (#5677)
* kvmemcpy triton for new kcache layout
* revise tests for new kcache layout
* naive triton flash decoding - new kcache layout
* rotary triton kernel - new kcache layout
* remove redundancy - triton decoding
* remove redundancy - triton kvcache copy
* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
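The bullets above list new Triton kernels written against the new KCache layout, but the layout itself is not spelled out on this page. As a rough, non-authoritative illustration, the sketch below assumes an x-split blocked layout of shape (num_blocks, num_kv_heads, head_dim // x, block_size, x) and shows, in plain PyTorch, the kind of copy a kvcache-copy kernel performs at decode time. All names, shapes, and the function signature here are assumptions for illustration, not the PR's actual API.

# Hypothetical PyTorch reference for appending one new key token per sequence into
# a blocked KCache, under the ASSUMED layout (num_blocks, num_kv_heads, head_dim // x, block_size, x).
import torch

def copy_k_to_blocked_cache_ref(
    k: torch.Tensor,             # (num_seqs, num_kv_heads, head_dim): the newly appended keys
    k_cache: torch.Tensor,       # (num_blocks, num_kv_heads, head_dim // x, block_size, x)
    block_tables: torch.Tensor,  # (num_seqs, max_blocks_per_seq): logical -> physical block ids
    seq_lengths: torch.Tensor,   # (num_seqs,): sequence lengths after appending the new token
) -> None:
    x = k_cache.size(-1)
    block_size = k_cache.size(-2)
    for seq_id in range(k.size(0)):
        pos = int(seq_lengths[seq_id]) - 1                      # slot of the newly appended token
        block_id = int(block_tables[seq_id, pos // block_size])  # physical block holding that slot
        offset_in_block = pos % block_size
        # Split head_dim into (head_dim // x, x) so the cache's last dimension stays contiguous.
        k_split = k[seq_id].view(k.size(1), -1, x)               # (num_kv_heads, head_dim // x, x)
        k_cache[block_id, :, :, offset_in_block, :] = k_split

A Triton version of this copy would launch one program per (sequence, kv head) and vectorize over the x-sized last dimension; the PyTorch loop above only pins down the indexing, not the parallelization.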
@@ -338,8 +338,8 @@ def _fwd_context_paged_attention_kernel_v2(
         X_range = tl.arange(0, KCACHE_X)
         # unroll the loop aggressively
         for split_x in tl.static_range(HEAD_DIM // KCACHE_X):
-            offsets_dmodel_x_partion = tl.arange(split_x * KCACHE_X, (split_x + 1) * KCACHE_X)
-            offsets_k = K + offset_kv + offsets_dmodel_x_partion[None, :] * stride_kd + offsets_m[:, None] * stride_kt
+            offsets_dmodel_x_partition = tl.arange(split_x * KCACHE_X, (split_x + 1) * KCACHE_X)
+            offsets_k = K + offset_kv + offsets_dmodel_x_partition[None, :] * stride_kd + offsets_m[:, None] * stride_kt
             k = tl.load(offsets_k, mask=offsets_m[:, None] < cur_seq_len, other=0.0)
             # HACK: KCache must be contiguous in order to apply the following offsets calculation
             offsets_kcache = (
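The hunk is cut off at `offsets_kcache = (`, so the actual offset expression is not visible here. As a hedged sketch only: the "KCache must be contiguous" comment suggests the kernel collapses the multi-dimensional cache index into a single flat pointer offset. Under the same assumed layout as above, that flat offset would look like the following; `split_x` corresponds to the `tl.static_range` loop variable in the hunk and `x_idx` ranges over `X_range = tl.arange(0, KCACHE_X)`.

# Illustrative row-major offset for a contiguous cache of ASSUMED shape
# (num_blocks, num_kv_heads, head_dim // x, block_size, x); not the kernel's actual expression.
def kcache_flat_offset(block_id, kv_head_id, split_x, slot_in_block, x_idx,
                       num_kv_heads, head_dim, block_size, x):
    return ((((block_id * num_kv_heads + kv_head_id) * (head_dim // x)
              + split_x) * block_size + slot_in_block) * x + x_idx)

Contiguity matters because this arithmetic is only valid when the strides are exactly the row-major strides of that shape; a non-contiguous cache would need per-dimension strides passed into the kernel instead.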