mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 20:10:17 +00:00
[Infer] Optimize Blocked KVCache And Kernels Using It (#5325)
* revise shape of kvcache (context attn kernel) * revise shape of kvcache (flash decoding kernel) * revise shape of kvcache (kvcache copy) and attn func * init of kvcache in kvcache manager * revise llama modeling * revise block size retrieval * use torch for rms_norm benchmarking * revise block size retrieval
This commit is contained in:
@@ -36,8 +36,8 @@ def _fwd_context_paged_attention_kernel(
|
||||
stride_od,
|
||||
stride_cacheb,
|
||||
stride_cacheh,
|
||||
stride_cached,
|
||||
stride_cachebs,
|
||||
stride_cached,
|
||||
stride_bts,
|
||||
stride_btb,
|
||||
context_lengths,
|
||||
@@ -158,29 +158,29 @@ def _fwd_context_paged_attention_kernel(
|
||||
# Copy k to corresponding cache block
|
||||
offsets_dmodel = tl.arange(0, HEAD_DIM)
|
||||
offsets_kt = block_start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
offsets_k = K + offset_kv + offsets_dmodel[:, None] * stride_kd + offsets_kt[None, :] * stride_kt
|
||||
k = tl.load(offsets_k, mask=offsets_kt[None, :] < cur_seq_len, other=0.0)
|
||||
offsets_k = K + offset_kv + offsets_dmodel[None, :] * stride_kd + offsets_kt[:, None] * stride_kt
|
||||
k = tl.load(offsets_k, mask=offsets_kt[:, None] < cur_seq_len, other=0.0)
|
||||
offsets_kcachebs = tl.arange(0, BLOCK_SIZE)
|
||||
offsets_kcache = (
|
||||
KCache
|
||||
+ offset_kvcache
|
||||
+ offsets_dmodel[:, None] * stride_cached
|
||||
+ offsets_kcachebs[None, :] * stride_cachebs
|
||||
+ offsets_dmodel[None, :] * stride_cached
|
||||
+ offsets_kcachebs[:, None] * stride_cachebs
|
||||
)
|
||||
tl.store(offsets_kcache, k, mask=offsets_kcachebs[None, :] < cur_seq_len - block_start_m * BLOCK_SIZE)
|
||||
tl.store(offsets_kcache, k, mask=offsets_kcachebs[:, None] < cur_seq_len - block_start_m * BLOCK_SIZE)
|
||||
# Copy v to corresponding cache block
|
||||
offsets_vd = offsets_dmodel
|
||||
offsets_vt = block_start_m * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||
offsets_v = V + offset_kv + offsets_vt[:, None] * stride_vt + offsets_vd[None, :] * stride_vd
|
||||
v = tl.load(offsets_v, mask=offsets_vt[:, None] < cur_seq_len, other=0.0)
|
||||
offsets_v = V + offset_kv + offsets_vt[None, :] * stride_vt + offsets_vd[:, None] * stride_vd
|
||||
v = tl.load(offsets_v, mask=offsets_vt[None, :] < cur_seq_len, other=0.0)
|
||||
offsets_vcachebs = offsets_kcachebs # same block size range, just to notify here
|
||||
offsets_vcache = (
|
||||
VCache
|
||||
+ offset_kvcache
|
||||
+ offsets_vcachebs[:, None] * stride_cachebs
|
||||
+ offsets_dmodel[None, :] * stride_cached
|
||||
+ offsets_vcachebs[None, :] * stride_cachebs
|
||||
+ offsets_dmodel[:, None] * stride_cached
|
||||
)
|
||||
tl.store(offsets_vcache, v, mask=offsets_vcachebs[:, None] < cur_seq_len - block_start_m * BLOCK_SIZE)
|
||||
tl.store(offsets_vcache, v, mask=offsets_vcachebs[None, :] < cur_seq_len - block_start_m * BLOCK_SIZE)
|
||||
|
||||
return
|
||||
|
||||
|
Reference in New Issue
Block a user