Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-08 12:30:42 +00:00)
[inference] Adapted to Rotary Embedding and RMS Norm (#5283)
* adapted to rotary_embedding
* adapted to nopad rms norm
* fix bugs in benchmark
* fix flash_decoding.py
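For context, below is a minimal PyTorch sketch of the two operations the decoding path is being adapted to, RMSNorm and rotary position embedding. It is illustrative only and not the repository's Triton implementation; the function names, tensor shapes, and the interleaved rotary convention are assumptions.

```python
import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # RMSNorm: scale x by the reciprocal root-mean-square over the last dim, then apply a learned weight.
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * weight

def apply_rotary(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # Rotary embedding: rotate each (even, odd) pair of head-dim channels by a position-dependent angle.
    # x: [bsz, num_heads, head_dim]; cos/sin: [bsz, head_dim // 2] for the current decoding positions.
    x1, x2 = x[..., 0::2], x[..., 1::2]
    cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)  # broadcast over the head dimension
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out
```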
@@ -18,7 +18,6 @@ def _flash_decoding_fwd_kernel(
     kv_seq_len,  # [batch_size]
     stride_qt,
     stride_qh,
-    stride_ql,
     stride_qd,
     stride_cacheb,
     stride_cacheh,
@@ -199,7 +198,7 @@ def flash_decoding_attention(
     Flash decoding implemented with a blocked KV Cache (PagedAttention) during decoding stage.
 
     Args:
-        q (torch.Tensor): [bsz, num_heads, q_len(1), head_dim]
+        q (torch.Tensor): [bsz, num_heads, head_dim]
         k_cache (torch.Tensor): [num_blocks, num_kv_heads, head_dim, block_size]
         v_cache (torch.Tensor): [num_blocks, num_kv_heads, head_dim, block_size]
         kv_seq_len (torch.Tensor): [batch_size]
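The docstring change above reflects the new calling convention: at decoding time the query length is always 1, so the singleton q_len dimension is dropped from the expected query layout. A hedged sketch of the caller-side adaptation (tensor names and sizes are illustrative, not taken from the repository):

```python
import torch

bsz, num_heads, head_dim = 4, 32, 128  # illustrative sizes

# Old layout: decoding queries carried an explicit q_len(1) dimension.
q_old = torch.randn(bsz, num_heads, 1, head_dim)

# New layout: squeeze the singleton q_len dimension before calling flash_decoding_attention.
q_new = q_old.squeeze(2)  # [bsz, num_heads, head_dim]
assert q_new.shape == (bsz, num_heads, head_dim)
```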
@@ -216,7 +215,10 @@ def flash_decoding_attention(
     Returns:
         Output tensor with shape [bsz, num_heads, q_len, head_dim]
     """
-    bsz, num_heads, _, head_dim = q.shape
+    if q.dim() == 3:
+        bsz, num_heads, head_dim = q.shape
+    else:
+        raise ValueError(f"The query dim should be 3, but got {q.dim()}.")
 
     assert head_dim in {32, 64, 128, 256}
     assert kv_seq_len.shape[0] == block_tables.shape[0] == bsz, (
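The unpacking above now fails fast on anything other than a 3-D query instead of silently destructuring the wrong dimensions. The same guard, isolated as a hypothetical helper for illustration (not code from the repository):

```python
import torch

def unpack_decoding_query(q: torch.Tensor):
    # Hypothetical helper mirroring the guard added above: accept only 3-D decoding queries.
    if q.dim() == 3:
        bsz, num_heads, head_dim = q.shape
        return bsz, num_heads, head_dim
    raise ValueError(f"The query dim should be 3, but got {q.dim()}.")

print(unpack_decoding_query(torch.randn(4, 32, 128)))   # (4, 32, 128)
# unpack_decoding_query(torch.randn(4, 32, 1, 128))     # raises ValueError: The query dim should be 3, but got 4.
```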
@@ -262,7 +264,6 @@ def flash_decoding_attention(
         q.stride(0),
         q.stride(1),
         q.stride(2),
-        q.stride(3),
         k_cache.stride(0),
         k_cache.stride(1),
         k_cache.stride(2),
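Dropping the fourth query stride follows from the query now being 3-D: a contiguous [bsz, num_heads, head_dim] tensor exposes exactly three strides, which line up with the kernel's remaining stride_qt, stride_qh, and stride_qd parameters. A quick illustration (sizes are arbitrary):

```python
import torch

bsz, num_heads, head_dim = 4, 32, 128
q = torch.randn(bsz, num_heads, head_dim)

# For a contiguous 3-D query the strides are (num_heads * head_dim, head_dim, 1),
# so only three stride arguments are passed to the Triton kernel.
print(q.stride())  # (4096, 128, 1)
print(q.stride(0), q.stride(1), q.stride(2))
```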