mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 19:40:28 +00:00
[inference] Adapted to Rotary Embedding and RMS Norm (#5283)
* adapted to rotary_embedding * adapted to nopad rms norm * fix bugs in benchmark * fix flash_decoding.py
This commit is contained in:
@@ -53,16 +53,23 @@ def copy_kv_to_blocked_cache(
|
||||
Copy keys or values to the blocked key/value cache during decoding stage.
|
||||
|
||||
Parameters:
|
||||
- k (torch.Tensor): [bsz, 1, num_kv_heads, head_dim] - Keys or values during decoding with seq len 1.
|
||||
- k (torch.Tensor): [bsz, 1, num_kv_heads, head_dim]/[bsz, num_kv_heads, head_dim] - Keys or values during decoding with seq len 1.
|
||||
- k_cache (torch.Tensor): [num_blocks, num_kv_heads, head_dim, block_size] - Blocked key or value cache.
|
||||
- kv_lengths (torch.Tensor): [bsz] - Past key/value sequence lengths plus current sequence length for each sequence.
|
||||
- block_tables (torch.Tensor): [bsz, max_blocks_per_sequence] - Block tables for each sequence.
|
||||
"""
|
||||
assert k.dim() == 4, "Unsupported shape of k (supposed to be used for decoding stage)"
|
||||
assert k.size(1) == 1, "Unsupported kv seq len (supposed to be used for decoding stage)"
|
||||
assert k.size(-1) == k_cache.size(-2), "Incompatible head dim"
|
||||
assert k.dtype == k_cache.dtype, "Expected consistent dtype for tensor and cache."
|
||||
bsz, _, num_kv_heads, head_dim = k.shape
|
||||
if k.dim() == 4:
|
||||
assert k.size(1) == 1, "Unsupported kv seq len (supposed to be used for decoding stage)"
|
||||
bsz, _, num_kv_heads, head_dim = k.shape
|
||||
# [bsz, 1, num_kv_heads, head_dim] -> [bsz, num_kv_heads, head_dim]
|
||||
k = k.squeeze(dim=1)
|
||||
elif k.dim() == 3:
|
||||
bsz, num_kv_heads, head_dim = k.shape
|
||||
else:
|
||||
raise ValueError(f"The key dim should be 3 or 4, but got {k.dim()}.")
|
||||
|
||||
assert kv_lengths.shape[0] == block_tables.shape[0] == bsz, (
|
||||
f"Got incompatible batch size (number of seqs):\n"
|
||||
f" Past kv sequence lengths bsz {kv_lengths.shape[0]}; "
|
||||
@@ -71,8 +78,6 @@ def copy_kv_to_blocked_cache(
|
||||
|
||||
# Modify if the shape of kv cahce is changed.
|
||||
block_size = k_cache.size(-1)
|
||||
# [bsz, 1, num_kv_heads, head_dim] -> [bsz, num_kv_heads, head_dim]
|
||||
k = k.squeeze(dim=1)
|
||||
|
||||
num_warps = 8 if head_dim > 128 else 4
|
||||
|
||||
|
Reference in New Issue
Block a user