[Infer] Optimize Blocked KVCache And Kernels Using It (#5325)

* revise shape of kvcache (context attn kernel)

* revise shape of kvcache (flash decoding kernel)

* revise shape of kvcache (kvcache copy) and attn func

* init of kvcache in kvcache manager

* revise llama modeling

* revise block size retrieval

* use torch for rms_norm benchmarking (a plain-torch reference sketch follows this list)

* revise block size retrieval
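For the rms_norm benchmarking item above, the natural plain-torch baseline is the standard RMSNorm formula. A minimal sketch, assuming the usual definition; the function name and `eps` default are illustrative and not taken from this commit:

```python
import torch

def torch_rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # RMSNorm: scale by the reciprocal root-mean-square over the hidden dim,
    # then apply the learned per-channel weight.
    rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x * rms * weight
```

Timing the Triton kernel against a reference like this on identical inputs would give both a correctness check and a latency baseline.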
Author: Yuanheng Zhao
Date:   2024-01-30 16:06:09 +08:00 (committed via GitHub)
Parent: e8f0642f28
Commit: 5f98a9d68a

14 changed files with 171 additions and 145 deletions
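The thread running through these changes, as the reordered stride arguments in the diff below indicate, is a transposed blocked-cache layout: `block_size` moves ahead of `head_dim`, so each cached token's K/V vector becomes contiguous in memory. A minimal allocation sketch, with the shape order inferred from the strides and all sizes illustrative:

```python
import torch

# Illustrative sizes only; none of these values come from the commit.
num_blocks, num_kv_heads, block_size, head_dim = 256, 8, 16, 64

# Previous layout (inferred from the removed stride order), head_dim-major:
# a token's K vector was scattered with stride block_size.
#   k_cache = torch.empty(num_blocks, num_kv_heads, head_dim, block_size)

# Revised layout (inferred from the added stride order), block_size-major:
# a token's K vector occupies head_dim contiguous elements.
k_cache = torch.empty(num_blocks, num_kv_heads, block_size, head_dim, dtype=torch.float16)
v_cache = torch.empty_like(k_cache)
```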


@@ -36,8 +36,8 @@ def _fwd_context_paged_attention_kernel(
     stride_od,
     stride_cacheb,
     stride_cacheh,
-    stride_cached,
     stride_cachebs,
+    stride_cached,
     stride_bts,
     stride_btb,
     context_lengths,
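Under that layout, the cache strides read off a contiguous tensor line up with the revised argument order above. A small sketch with hypothetical sizes:

```python
import torch

# Strides of a contiguous cache tensor in the revised argument order:
# (cache batch/block, head, block slot, head dim).
k_cache = torch.empty(256, 8, 16, 64)  # (num_blocks, num_kv_heads, block_size, head_dim)
stride_cacheb, stride_cacheh, stride_cachebs, stride_cached = k_cache.stride()
assert stride_cached == 1 and stride_cachebs == 64  # a token's vector is contiguous
```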
@@ -158,29 +158,29 @@ def _fwd_context_paged_attention_kernel(
     # Copy k to corresponding cache block
     offsets_dmodel = tl.arange(0, HEAD_DIM)
     offsets_kt = block_start_m * BLOCK_M + tl.arange(0, BLOCK_M)
-    offsets_k = K + offset_kv + offsets_dmodel[:, None] * stride_kd + offsets_kt[None, :] * stride_kt
-    k = tl.load(offsets_k, mask=offsets_kt[None, :] < cur_seq_len, other=0.0)
+    offsets_k = K + offset_kv + offsets_dmodel[None, :] * stride_kd + offsets_kt[:, None] * stride_kt
+    k = tl.load(offsets_k, mask=offsets_kt[:, None] < cur_seq_len, other=0.0)
     offsets_kcachebs = tl.arange(0, BLOCK_SIZE)
     offsets_kcache = (
         KCache
         + offset_kvcache
-        + offsets_dmodel[:, None] * stride_cached
-        + offsets_kcachebs[None, :] * stride_cachebs
+        + offsets_dmodel[None, :] * stride_cached
+        + offsets_kcachebs[:, None] * stride_cachebs
     )
-    tl.store(offsets_kcache, k, mask=offsets_kcachebs[None, :] < cur_seq_len - block_start_m * BLOCK_SIZE)
+    tl.store(offsets_kcache, k, mask=offsets_kcachebs[:, None] < cur_seq_len - block_start_m * BLOCK_SIZE)
     # Copy v to corresponding cache block
     offsets_vd = offsets_dmodel
     offsets_vt = block_start_m * BLOCK_N + tl.arange(0, BLOCK_N)
-    offsets_v = V + offset_kv + offsets_vt[:, None] * stride_vt + offsets_vd[None, :] * stride_vd
-    v = tl.load(offsets_v, mask=offsets_vt[:, None] < cur_seq_len, other=0.0)
+    offsets_v = V + offset_kv + offsets_vt[None, :] * stride_vt + offsets_vd[:, None] * stride_vd
+    v = tl.load(offsets_v, mask=offsets_vt[None, :] < cur_seq_len, other=0.0)
     offsets_vcachebs = offsets_kcachebs  # same block size range, just to notify here
     offsets_vcache = (
         VCache
         + offset_kvcache
-        + offsets_vcachebs[:, None] * stride_cachebs
-        + offsets_dmodel[None, :] * stride_cached
+        + offsets_vcachebs[None, :] * stride_cachebs
+        + offsets_dmodel[:, None] * stride_cached
     )
-    tl.store(offsets_vcache, v, mask=offsets_vcachebs[:, None] < cur_seq_len - block_start_m * BLOCK_SIZE)
+    tl.store(offsets_vcache, v, mask=offsets_vcachebs[None, :] < cur_seq_len - block_start_m * BLOCK_SIZE)
     return
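Every swapped pair above follows the same broadcasting pattern: `rows[:, None] * stride_rows + cols[None, :] * stride_cols` builds a 2-D grid of element offsets, so exchanging the `[:, None]`/`[None, :]` axes transposes the tile being read or written to match the new cache layout. A NumPy illustration of the grid the revised K-store builds; NumPy broadcasting mirrors Triton's here, and the sizes are toy values:

```python
import numpy as np

BLOCK_SIZE, HEAD_DIM = 4, 8
stride_cachebs, stride_cached = HEAD_DIM, 1  # revised layout: head_dim contiguous

offsets_kcachebs = np.arange(BLOCK_SIZE)
offsets_dmodel = np.arange(HEAD_DIM)

# (BLOCK_SIZE, HEAD_DIM) grid of element offsets: row i addresses the
# contiguous head_dim slots of cached token i, so a (tokens, head_dim)
# K tile stores directly, with no transpose.
addrs = offsets_kcachebs[:, None] * stride_cachebs + offsets_dmodel[None, :] * stride_cached
assert (addrs[:, 0] == offsets_kcachebs * HEAD_DIM).all()
```

The store mask `offsets_kcachebs[:, None] < cur_seq_len - block_start_m * BLOCK_SIZE` then drops the rows beyond the tokens that actually fall into this cache block, so a partially filled block is written only up to its valid length.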