[inference] Adapted to Rotary Embedding and RMS Norm (#5283)

* adapted to rotary_embedding

* adapted to nopad rms norm

* fix bugs in benchmark

* fix flash_decoding.py
Author: yuehuayingxueluo
Date: 2024-01-22 10:55:34 +08:00
Committed by: GitHub
Parent commit: 6e487e7d3c
Commit: bfff9254ac
5 changed files with 140 additions and 43 deletions


@@ -18,7 +18,6 @@ def _flash_decoding_fwd_kernel(
     kv_seq_len,  # [batch_size]
     stride_qt,
     stride_qh,
-    stride_ql,
     stride_qd,
     stride_cacheb,
     stride_cacheh,
@@ -199,7 +198,7 @@ def flash_decoding_attention(
     Flash decoding implemented with a blocked KV Cache (PagedAttention) during decoding stage.

     Args:
-        q (torch.Tensor): [bsz, num_heads, q_len(1), head_dim]
+        q (torch.Tensor): [bsz, num_heads, head_dim]
         k_cache (torch.Tensor): [num_blocks, num_kv_heads, head_dim, block_size]
         v_cache (torch.Tensor): [num_blocks, num_kv_heads, head_dim, block_size]
         kv_seq_len (torch.Tensor): [batch_size]
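
The docstring change above is the caller-visible part of this commit: during decoding each sequence contributes exactly one query token, so the singleton q_len(1) axis is dropped and a 3-D query is passed instead. A minimal sketch of the caller-side adjustment, assuming only the query layout changes (sizes are illustrative; shapes follow the docstring):

import torch

bsz, num_heads, head_dim = 4, 32, 128

# Previous layout: a singleton q_len axis for the single decoding-stage token.
q_old = torch.randn(bsz, num_heads, 1, head_dim)

# New layout expected by flash_decoding_attention: [bsz, num_heads, head_dim].
q_new = q_old.squeeze(2)
assert q_new.shape == (bsz, num_heads, head_dim)
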
@@ -216,7 +215,10 @@ def flash_decoding_attention(
     Returns:
         Output tensor with shape [bsz, num_heads, q_len, head_dim]
     """
-    bsz, num_heads, _, head_dim = q.shape
+    if q.dim() == 3:
+        bsz, num_heads, head_dim = q.shape
+    else:
+        raise ValueError(f"The query dim should be 3, but got {q.dim()}.")

     assert head_dim in {32, 64, 128, 256}
     assert kv_seq_len.shape[0] == block_tables.shape[0] == bsz, (
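
Instead of the old four-way unpack, the function now rejects a query that still carries the 4-D layout rather than silently mis-indexing it. A standalone sketch of the same guard, with a hypothetical helper name; the check and the error message mirror the diff:

import torch

def _unpack_decoding_query(q: torch.Tensor):
    # Accept only the new 3-D layout [bsz, num_heads, head_dim].
    if q.dim() == 3:
        return q.shape
    raise ValueError(f"The query dim should be 3, but got {q.dim()}.")

bsz, num_heads, head_dim = _unpack_decoding_query(torch.randn(4, 32, 128))
# _unpack_decoding_query(torch.randn(4, 32, 1, 128))  # old 4-D layout -> ValueError
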
@@ -262,7 +264,6 @@ def flash_decoding_attention(
         q.stride(0),
         q.stride(1),
         q.stride(2),
-        q.stride(3),
         k_cache.stride(0),
         k_cache.stride(1),
         k_cache.stride(2),
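
With the q_len axis gone, the query carries only three strides, so the q.stride(3) argument disappears and the remaining values line up one-to-one with the kernel's stride_qt, stride_qh, and stride_qd parameters from the first hunk. A quick sketch of that correspondence (sizes are illustrative):

import torch

q = torch.randn(4, 32, 128)  # [bsz, num_heads, head_dim]

# The three remaining query strides forwarded to _flash_decoding_fwd_kernel.
stride_qt, stride_qh, stride_qd = q.stride(0), q.stride(1), q.stride(2)
print(stride_qt, stride_qh, stride_qd)  # 4096 128 1 for this contiguous tensor
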