[kernel] Support New KCache Layout - Triton Kernel (#5677)

* kvmemcpy triton for new kcache layout

* revise tests for new kcache layout

* naive triton flash decoding - new kcache layout

* rotary triton kernel - new kcache layout

* remove redundancy - triton decoding

* remove redundancy - triton kvcache copy

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Author: Yuanheng Zhao
Date: 2024-05-03 17:20:45 +08:00
Committed by: GitHub
Parent: 9df016fc45
Commit: 537a3cbc4d
10 changed files with 428 additions and 206 deletions
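
For context, the "new kcache layout" referenced in the commit message blocks the key cache along the head dimension so each kernel load touches a contiguous x-wide slice. A minimal PyTorch sketch of the idea, assuming a layout of [num_blocks, num_kv_heads, head_dim // x, block_size, x] with the split factor x corresponding to KCACHE_X in the diff below; the shapes and the copy_k_to_cache helper are illustrative, not taken from the repository:

    import torch

    # Assumed layout (hypothetical, inferred from KCACHE_X / HEAD_DIM // KCACHE_X
    # in the diff below; not taken verbatim from the repository):
    #   k_cache: [num_blocks, num_kv_heads, head_dim // x, block_size, x]
    num_blocks, num_kv_heads, head_dim, block_size, x = 4, 2, 64, 16, 8
    k_cache = torch.zeros(num_blocks, num_kv_heads, head_dim // x, block_size, x)

    def copy_k_to_cache(k, cache, block_id, slot_id):
        # Regroup the head dimension into (head_dim // x, x) chunks so a
        # token's key lands in x-wide contiguous slices of the cache.
        heads, dim = k.shape
        cache[block_id, :, :, slot_id, :] = k.view(heads, dim // x, x)

    k = torch.randn(num_kv_heads, head_dim)
    copy_k_to_cache(k, k_cache, block_id=0, slot_id=3)
    # Reading the slot back reproduces the original key.
    assert torch.equal(k_cache[0, :, :, 3, :].reshape(num_kv_heads, head_dim), k)
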


@@ -338,8 +338,8 @@ def _fwd_context_paged_attention_kernel_v2(
X_range = tl.arange(0, KCACHE_X)
# unroll the loop aggressively
for split_x in tl.static_range(HEAD_DIM // KCACHE_X):
-offsets_dmodel_x_partion = tl.arange(split_x * KCACHE_X, (split_x + 1) * KCACHE_X)
-offsets_k = K + offset_kv + offsets_dmodel_x_partion[None, :] * stride_kd + offsets_m[:, None] * stride_kt
+offsets_dmodel_x_partition = tl.arange(split_x * KCACHE_X, (split_x + 1) * KCACHE_X)
+offsets_k = K + offset_kv + offsets_dmodel_x_partition[None, :] * stride_kd + offsets_m[:, None] * stride_kt
k = tl.load(offsets_k, mask=offsets_m[:, None] < cur_seq_len, other=0.0)
# HACK: KCache must be contiguous in order to apply the following offsets calculation
offsets_kcache = (
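
The HACK note above exists because the kernel computes flat memory offsets by hand from stride values; that arithmetic only matches the tensor's actual layout when KCache is contiguous. A minimal sketch of why, reusing the hypothetical blocked shape from the sketch above (names are illustrative, not from the diff):

    import torch

    # Hypothetical shape: [num_blocks, num_kv_heads, head_dim // x, block_size, x]
    cache = torch.randn(4, 2, 8, 16, 8)
    B, H, DX, BS, X = cache.shape

    def flat_offset(b, h, dx, s, i):
        # Hand-rolled row-major offset, analogous to what the kernel builds
        # from stride_* values; only valid for a contiguous tensor.
        return (((b * H + h) * DX + dx) * BS + s) * X + i

    idx = (1, 0, 3, 5, 2)
    assert cache.is_contiguous()
    # For a contiguous tensor the formula agrees with torch's strides ...
    assert flat_offset(*idx) == sum(i * s for i, s in zip(idx, cache.stride()))
    # ... and indexes the right element in flat memory.
    assert cache.flatten()[flat_offset(*idx)] == cache[idx]

A transposed or otherwise non-contiguous cache would carry different strides, and the hand-rolled formula would silently address the wrong elements, which is why the kernel insists on a contiguous KCache.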