mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-01 17:17:05 +00:00
[Inference] Adapt to Fused rotary (#5348)
* revise rotary embedding * remove useless print * adapt * fix * add * fix * modeling * fix * fix * fix
This commit is contained in:
@@ -3,7 +3,7 @@ import torch
|
||||
from packaging import version
|
||||
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb
|
||||
|
||||
from colossalai.kernel.triton import rotary_embedding
|
||||
from colossalai.kernel.triton import copy_kv_to_blocked_cache, rotary_embedding
|
||||
from tests.test_infer.test_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache_v2
|
||||
|
||||
try:
|
||||
@@ -94,8 +94,8 @@ configs = [
|
||||
x_names=["num_tokens"],
|
||||
x_vals=[2**i for i in range(4, 11)],
|
||||
line_arg="provider",
|
||||
line_vals=["torch_rotary_emb_func", "triton_rotary_emb_func"],
|
||||
line_names=["torch_rotary_emb_func", "triton_rotary_emb_func"],
|
||||
line_vals=["no_fused_rotary_emb_func", "fused_triton_rotary_emb_func"],
|
||||
line_names=["no_fused_rotary_emb_func", "fused_triton_rotary_emb_func"],
|
||||
styles=[("red", "-"), ("blue", "-")],
|
||||
ylabel="ms",
|
||||
plot_name=f"rotary_emb-batch-{BATCH}",
|
||||
@@ -110,11 +110,16 @@ def benchmark_rotary_emb(
|
||||
num_tokens: int,
|
||||
num_kv_heads: int,
|
||||
):
|
||||
BATCH_SIZE = 4
|
||||
SEQ_LEN = num_tokens // BATCH_SIZE
|
||||
max_num_blocks_per_seq = 8
|
||||
block_size = 64
|
||||
warmup = 10
|
||||
rep = 100
|
||||
|
||||
head_dim = 128
|
||||
head_dim = 256
|
||||
dtype = torch.float16
|
||||
|
||||
q_shape = (num_tokens, num_kv_heads, head_dim)
|
||||
q = -2.3 + 0.5 * torch.randn(q_shape, dtype=dtype, device="cuda")
|
||||
k_shape = (num_tokens, num_kv_heads, head_dim)
|
||||
@@ -122,11 +127,26 @@ def benchmark_rotary_emb(
|
||||
cos_shape = (num_tokens, head_dim // 2)
|
||||
cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
|
||||
sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
|
||||
cache_shape = (BATCH_SIZE * max_num_blocks_per_seq, num_kv_heads, block_size, head_dim)
|
||||
k_cache = torch.zeros(size=cache_shape, dtype=dtype, device="cuda")
|
||||
v = torch.randn_like(k)
|
||||
v_cache = torch.zeros_like(k_cache)
|
||||
past_kv_seq_lengths = torch.tensor([SEQ_LEN - 1 for _ in range(BATCH_SIZE)], dtype=torch.int32, device="cuda")
|
||||
block_tables = mock_alloc_block_table_and_kvcache_v2(
|
||||
k, v, k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_num_blocks_per_seq, block_size
|
||||
)
|
||||
new_k = torch.randn((BATCH_SIZE, num_kv_heads, head_dim), dtype=dtype, device="cuda")
|
||||
new_q = torch.randn_like(new_k)
|
||||
kv_seq_lengths = past_kv_seq_lengths + 1
|
||||
block_tables = block_tables.to(device="cuda")
|
||||
|
||||
if provider == "torch_rotary_emb_func":
|
||||
fn = lambda: torch_rotary_emb(q, cos, sin)
|
||||
elif provider == "triton_rotary_emb_func":
|
||||
fn = lambda: rotary_embedding(q, k, cos, sin)
|
||||
if provider == "no_fused_rotary_emb_func":
|
||||
fn = lambda: [
|
||||
rotary_embedding(new_q, new_k, cos, sin),
|
||||
copy_kv_to_blocked_cache(new_k, k_cache, kv_lengths=kv_seq_lengths, block_tables=block_tables),
|
||||
]
|
||||
elif provider == "fused_triton_rotary_emb_func":
|
||||
fn = lambda: rotary_embedding(new_q, new_k, cos, sin, k_cache, block_tables, kv_seq_lengths)
|
||||
else:
|
||||
raise ValueError("Undefined provider")
|
||||
|
||||
@@ -135,5 +155,5 @@ def benchmark_rotary_emb(
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_rotary_emb(4, 64, 32, 64, torch.float32)
|
||||
# benchmark_rotary_emb.run(save_path=".",print_data=True)
|
||||
# test_rotary_emb(4, 64, 32, 64, torch.float32)
|
||||
benchmark_rotary_emb.run(save_path=".", print_data=True)
|
||||
|
Reference in New Issue
Block a user