mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-04-28 03:43:01 +00:00
[Inference]Move benchmark-related code to the example directory. (#5408)
* move benchmark-related code to the example directory. * fix bugs in test_fused_rotary_embedding.py
This commit is contained in:
parent
600881a8ea
commit
0aa27f1961
113
examples/inference/benchmark_ops/benchmark_context_attn_unpad.py
Normal file
113
examples/inference/benchmark_ops/benchmark_context_attn_unpad.py
Normal file
@ -0,0 +1,113 @@
|
|||||||
|
import torch
|
||||||
|
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
||||||
|
|
||||||
|
from colossalai.inference.modeling.layers.attention import PagedAttention
|
||||||
|
from colossalai.kernel.triton import context_attention_unpadded
|
||||||
|
from colossalai.utils import get_current_device
|
||||||
|
from tests.test_infer.test_ops.triton.kernel_utils import generate_caches_and_block_tables_v2, torch_attn_ref
|
||||||
|
|
||||||
|
try:
|
||||||
|
import triton # noqa
|
||||||
|
|
||||||
|
except ImportError:
|
||||||
|
print("please install triton from https://github.com/openai/triton")
|
||||||
|
|
||||||
|
HEAD_DIM = 32
|
||||||
|
BATCH = 16
|
||||||
|
BLOCK_SIZE = 32
|
||||||
|
SAME_LEN = True
|
||||||
|
WARM_UPS = 10
|
||||||
|
REPS = 100
|
||||||
|
configs = [
|
||||||
|
triton.testing.Benchmark(
|
||||||
|
x_names=["KV_LEN"],
|
||||||
|
x_vals=[2**i for i in range(8, 13)],
|
||||||
|
# x_vals=[x for x in range(256, 8192, 256)],
|
||||||
|
line_arg="provider",
|
||||||
|
line_vals=["torch", "triton"],
|
||||||
|
line_names=["Torch", "Triton"],
|
||||||
|
styles=[("red", "-"), ("blue", "-")],
|
||||||
|
ylabel="ms",
|
||||||
|
plot_name=f"context_attn-block_size-{BLOCK_SIZE}-batch{BATCH}",
|
||||||
|
args={"bsz": BATCH, "block_size": BLOCK_SIZE, "same_context_len": SAME_LEN, "kv_group_num": 1},
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@triton.testing.perf_report(configs)
|
||||||
|
def bench_kernel(
|
||||||
|
bsz,
|
||||||
|
KV_LEN,
|
||||||
|
provider,
|
||||||
|
block_size: int,
|
||||||
|
kv_group_num: int,
|
||||||
|
same_context_len: bool,
|
||||||
|
):
|
||||||
|
num_attn_heads = 16
|
||||||
|
max_num_blocks_per_seq = triton.cdiv(KV_LEN, block_size)
|
||||||
|
max_seq_len = block_size * max_num_blocks_per_seq
|
||||||
|
|
||||||
|
num_kv_heads = num_attn_heads // kv_group_num
|
||||||
|
assert isinstance(num_kv_heads, int) and num_kv_heads > 0, "Invalid number of kv heads."
|
||||||
|
dtype = torch.float16
|
||||||
|
device = get_current_device()
|
||||||
|
|
||||||
|
if same_context_len:
|
||||||
|
context_lengths = torch.tensor([max_seq_len for _ in range(bsz)], dtype=torch.int32, device=device)
|
||||||
|
else:
|
||||||
|
context_lengths = torch.randint(low=1, high=max_seq_len, size=(bsz,), dtype=torch.int32, device=device)
|
||||||
|
num_tokens = torch.sum(context_lengths).item()
|
||||||
|
|
||||||
|
qkv_size = (num_tokens, num_attn_heads + 2 * num_kv_heads, HEAD_DIM)
|
||||||
|
qkv_unpad = torch.empty(size=qkv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
|
||||||
|
q_unpad, k_unpad, v_unpad = torch.split(qkv_unpad, [num_attn_heads, num_kv_heads, num_kv_heads], dim=-2)
|
||||||
|
q_unpad = q_unpad.contiguous()
|
||||||
|
k_cache_ref, v_cache_ref, block_tables = generate_caches_and_block_tables_v2(
|
||||||
|
k_unpad, v_unpad, context_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
|
||||||
|
)
|
||||||
|
block_tables = block_tables.to(device=device)
|
||||||
|
|
||||||
|
quantiles = [0.5, 0.2, 0.8]
|
||||||
|
if provider == "torch":
|
||||||
|
q_padded = PagedAttention.pad_and_reshape(q_unpad, context_lengths, max_seq_len, num_attn_heads, HEAD_DIM)
|
||||||
|
k_padded = PagedAttention.pad_and_reshape(k_unpad, context_lengths, max_seq_len, num_kv_heads, HEAD_DIM)
|
||||||
|
v_padded = PagedAttention.pad_and_reshape(v_unpad, context_lengths, max_seq_len, num_kv_heads, HEAD_DIM)
|
||||||
|
q_padded, k_padded, v_padded = (
|
||||||
|
q_padded.to(device=device),
|
||||||
|
k_padded.to(device=device),
|
||||||
|
v_padded.to(device=device),
|
||||||
|
)
|
||||||
|
q_padded = q_padded.transpose(1, 2)
|
||||||
|
k_padded = PagedAttention.repeat_kv(k_padded.transpose(1, 2), kv_group_num)
|
||||||
|
v_padded = PagedAttention.repeat_kv(v_padded.transpose(1, 2), kv_group_num)
|
||||||
|
# This benchmark ignores the padding mask. *Only* use the-same-length inputs for benchmarkings
|
||||||
|
attn_mask = AttentionMaskConverter._make_causal_mask(
|
||||||
|
(bsz, max_seq_len), q_padded.dtype, q_padded.device, past_key_values_length=0
|
||||||
|
)
|
||||||
|
attn_mask = attn_mask.to(device=q_padded.device)
|
||||||
|
fn = lambda: torch_attn_ref(
|
||||||
|
q_padded,
|
||||||
|
k_padded,
|
||||||
|
v_padded,
|
||||||
|
attn_mask,
|
||||||
|
bsz,
|
||||||
|
max_seq_len,
|
||||||
|
max_seq_len,
|
||||||
|
num_attn_heads,
|
||||||
|
num_kv_heads,
|
||||||
|
HEAD_DIM,
|
||||||
|
)
|
||||||
|
ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles)
|
||||||
|
if provider == "triton":
|
||||||
|
k_cache_triton = torch.zeros_like(k_cache_ref)
|
||||||
|
v_cache_triton = torch.zeros_like(v_cache_ref)
|
||||||
|
fn = lambda: context_attention_unpadded(
|
||||||
|
q_unpad, k_unpad, v_unpad, k_cache_triton, v_cache_triton, context_lengths, block_tables, block_size
|
||||||
|
)
|
||||||
|
ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles)
|
||||||
|
|
||||||
|
return ms, min_ms, max_ms
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
bench_kernel.run(save_path=".", print_data=True)
|
110
examples/inference/benchmark_ops/benchmark_decoding_attn.py
Normal file
110
examples/inference/benchmark_ops/benchmark_decoding_attn.py
Normal file
@ -0,0 +1,110 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
from colossalai.kernel.triton import flash_decoding_attention
|
||||||
|
from colossalai.utils import get_current_device
|
||||||
|
from tests.test_infer.test_ops.triton.kernel_utils import (
|
||||||
|
convert_kv_unpad_to_padded,
|
||||||
|
generate_caches_and_block_tables_v2,
|
||||||
|
prepare_padding_mask,
|
||||||
|
torch_attn_ref,
|
||||||
|
)
|
||||||
|
from tests.test_infer.test_ops.triton.test_decoding_attn import prepare_data
|
||||||
|
|
||||||
|
try:
|
||||||
|
import triton # noqa
|
||||||
|
|
||||||
|
except ImportError:
|
||||||
|
print("please install triton from https://github.com/openai/triton")
|
||||||
|
|
||||||
|
Q_LEN = 1
|
||||||
|
HEAD_DIM = 128
|
||||||
|
BATCH = 16
|
||||||
|
BLOCK_SIZE = 32
|
||||||
|
SAME_LEN = True
|
||||||
|
WARM_UPS = 10
|
||||||
|
REPS = 100
|
||||||
|
configs = [
|
||||||
|
triton.testing.Benchmark(
|
||||||
|
x_names=["KV_LEN"],
|
||||||
|
x_vals=[2**i for i in range(8, 14)],
|
||||||
|
# x_vals=[x for x in range(256, 8192, 256)],
|
||||||
|
line_arg="provider",
|
||||||
|
line_vals=["torch", "triton"],
|
||||||
|
line_names=["Torch", "Triton"],
|
||||||
|
styles=[("red", "-"), ("blue", "-")],
|
||||||
|
ylabel="ms",
|
||||||
|
plot_name=f"decoding-block_size-{BLOCK_SIZE}-batch{BATCH}",
|
||||||
|
args={"bsz": BATCH, "block_size": BLOCK_SIZE, "same_context_len": SAME_LEN, "kv_group_num": 1},
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@triton.testing.perf_report(configs)
|
||||||
|
def bench_kernel(
|
||||||
|
bsz,
|
||||||
|
KV_LEN,
|
||||||
|
provider,
|
||||||
|
block_size: int,
|
||||||
|
kv_group_num: int,
|
||||||
|
same_context_len: bool,
|
||||||
|
):
|
||||||
|
num_attn_heads = 16
|
||||||
|
max_num_blocks_per_seq = triton.cdiv(KV_LEN, block_size)
|
||||||
|
max_seq_len = block_size * max_num_blocks_per_seq
|
||||||
|
|
||||||
|
num_kv_heads = num_attn_heads // kv_group_num
|
||||||
|
assert isinstance(num_kv_heads, int) and num_kv_heads > 0, "Invalid number of kv heads."
|
||||||
|
block_size * max_num_blocks_per_seq
|
||||||
|
dtype = torch.float16
|
||||||
|
device = get_current_device()
|
||||||
|
|
||||||
|
q, k_unpad, v_unpad, kv_lengths = prepare_data(
|
||||||
|
bsz, num_attn_heads, num_kv_heads, HEAD_DIM, same_context_len, Q_LEN, max_seq_len, dtype, device
|
||||||
|
)
|
||||||
|
max_seq_len_in_b = kv_lengths.max().item() # for random lengths
|
||||||
|
|
||||||
|
quantiles = [0.5, 0.2, 0.8]
|
||||||
|
if provider == "torch":
|
||||||
|
k_torch = convert_kv_unpad_to_padded(k_unpad, kv_lengths, bsz, max_seq_len_in_b)
|
||||||
|
v_torch = convert_kv_unpad_to_padded(v_unpad, kv_lengths, bsz, max_seq_len_in_b)
|
||||||
|
torch_padding_mask = prepare_padding_mask(kv_lengths, bsz, max_seq_len_in_b, q.device)
|
||||||
|
fn = lambda: torch_attn_ref(
|
||||||
|
q, k_torch, v_torch, torch_padding_mask, bsz, 1, max_seq_len_in_b, num_attn_heads, num_kv_heads, HEAD_DIM
|
||||||
|
)
|
||||||
|
ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles)
|
||||||
|
if provider == "triton":
|
||||||
|
k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2(
|
||||||
|
k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
|
||||||
|
)
|
||||||
|
block_tables = block_tables.to(device=device)
|
||||||
|
# the maximum block length splitted on kv should be the kv cache block size
|
||||||
|
kv_max_split_num = (max_seq_len_in_b + block_size - 1) // block_size
|
||||||
|
output = torch.empty((bsz, num_attn_heads, HEAD_DIM), dtype=dtype, device=device)
|
||||||
|
mid_output = torch.empty(
|
||||||
|
size=(bsz, num_attn_heads, kv_max_split_num, HEAD_DIM), dtype=torch.float32, device=q.device
|
||||||
|
)
|
||||||
|
mid_output_lse = torch.empty(size=(bsz, num_attn_heads, kv_max_split_num), dtype=torch.float32, device=q.device)
|
||||||
|
sm_scale = 1.0 / (HEAD_DIM**0.5)
|
||||||
|
fn = lambda: flash_decoding_attention(
|
||||||
|
# Here we use q.squeeze(2) because we hide the q_len dimension (which is equivalent to 1),
|
||||||
|
# refer to attention forward in modeling.
|
||||||
|
q.squeeze(2),
|
||||||
|
k_cache,
|
||||||
|
v_cache,
|
||||||
|
kv_lengths,
|
||||||
|
block_tables,
|
||||||
|
block_size,
|
||||||
|
max_seq_len_in_b,
|
||||||
|
output,
|
||||||
|
mid_output,
|
||||||
|
mid_output_lse,
|
||||||
|
sm_scale=sm_scale,
|
||||||
|
kv_group_num=kv_group_num,
|
||||||
|
) # [bsz, 1, num_heads, head_dim]
|
||||||
|
ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles)
|
||||||
|
|
||||||
|
return ms, min_ms, max_ms
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
bench_kernel.run(save_path=".", print_data=True)
|
@ -0,0 +1,65 @@
|
|||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
|
||||||
|
from colossalai.kernel.triton.fused_rotary_embedding import fused_rotary_embedding
|
||||||
|
|
||||||
|
BATCH = 16
|
||||||
|
configs = [
|
||||||
|
triton.testing.Benchmark(
|
||||||
|
x_names=["num_tokens"],
|
||||||
|
x_vals=[2**i for i in range(4, 12)],
|
||||||
|
line_arg="provider",
|
||||||
|
line_vals=["torch_rotary_emb_func", "triton_rotary_emb_func"],
|
||||||
|
line_names=["torch_rotary_emb_func", "triton_rotary_emb_func"],
|
||||||
|
styles=[("red", "-"), ("blue", "-")],
|
||||||
|
ylabel="ms",
|
||||||
|
plot_name=f"rotary_emb-batch-{BATCH}",
|
||||||
|
args={"num_kv_heads": 16},
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def torch_rotary_emb(x, cos, sin):
|
||||||
|
seq_len, h, dim = x.shape
|
||||||
|
x0 = x[:, :, 0 : dim // 2]
|
||||||
|
x1 = x[:, :, dim // 2 : dim]
|
||||||
|
cos = cos.view((seq_len, 1, dim // 2))
|
||||||
|
sin = sin.view((seq_len, 1, dim // 2))
|
||||||
|
o0 = x0 * cos - x1 * sin
|
||||||
|
o1 = x0 * sin + x1 * cos
|
||||||
|
return torch.cat((o0, o1), dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
@triton.testing.perf_report(configs)
|
||||||
|
def benchmark_rotary_emb(
|
||||||
|
provider: str,
|
||||||
|
num_tokens: int,
|
||||||
|
num_kv_heads: int,
|
||||||
|
):
|
||||||
|
warmup = 10
|
||||||
|
rep = 100
|
||||||
|
|
||||||
|
head_dim = 128
|
||||||
|
dtype = torch.float16
|
||||||
|
q_shape = (num_tokens, num_kv_heads, head_dim)
|
||||||
|
q = -2.3 + 0.5 * torch.randn(q_shape, dtype=dtype, device="cuda")
|
||||||
|
k_shape = (num_tokens, num_kv_heads, head_dim)
|
||||||
|
k = -2.3 + 0.5 * torch.randn(k_shape, dtype=dtype, device="cuda")
|
||||||
|
cos_shape = (4096, head_dim // 2)
|
||||||
|
cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
|
||||||
|
sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
|
||||||
|
lengths = torch.tensor([3, 4, 6, 7], device="cuda")
|
||||||
|
|
||||||
|
if provider == "torch_rotary_emb_func":
|
||||||
|
fn = lambda: torch_rotary_emb(q, cos[:num_tokens], sin[:num_tokens])
|
||||||
|
elif provider == "triton_rotary_emb_func":
|
||||||
|
fn = lambda: fused_rotary_embedding(q, k, cos, sin, lengths)
|
||||||
|
else:
|
||||||
|
raise ValueError("Undefined provider")
|
||||||
|
|
||||||
|
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
|
||||||
|
return ms
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
benchmark_rotary_emb.run(save_path=".", print_data=True)
|
78
examples/inference/benchmark_ops/benchmark_rmsnorm_triton.py
Normal file
78
examples/inference/benchmark_ops/benchmark_rmsnorm_triton.py
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
|
||||||
|
from colossalai.kernel.triton import rms_layernorm
|
||||||
|
|
||||||
|
try:
|
||||||
|
import triton # noqa
|
||||||
|
|
||||||
|
except ImportError:
|
||||||
|
print("please install triton from https://github.com/openai/triton")
|
||||||
|
|
||||||
|
|
||||||
|
# Triton benchmark plot attributions
|
||||||
|
configs = [
|
||||||
|
triton.testing.Benchmark(
|
||||||
|
x_names=["SEQUENCE_TOTAL"],
|
||||||
|
x_vals=[i for i in range(128, 1025, 128)],
|
||||||
|
line_arg="provider",
|
||||||
|
line_vals=[
|
||||||
|
"vllm_rms_layernorm",
|
||||||
|
"triton_rms_layernorm",
|
||||||
|
"triton_rms_layernorm_with_residual",
|
||||||
|
"vllm_rms_layernorm_with_residual",
|
||||||
|
],
|
||||||
|
line_names=[
|
||||||
|
"vllm_rms_layernorm",
|
||||||
|
"triton_rms_layernorm",
|
||||||
|
"triton_rms_layernorm_with_residual",
|
||||||
|
"vllm_rms_layernorm_with_residual",
|
||||||
|
],
|
||||||
|
styles=[("red", "-"), ("blue", "-"), ("yellow", "-"), ("green", "-")],
|
||||||
|
ylabel="ms",
|
||||||
|
plot_name=f"RMSNorm benchmarking results",
|
||||||
|
args={"HIDDEN_SIZE": 1024},
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@triton.testing.perf_report(configs)
|
||||||
|
def benchmark_rms_layernorm(
|
||||||
|
provider: str,
|
||||||
|
SEQUENCE_TOTAL: int,
|
||||||
|
HIDDEN_SIZE: int,
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError("Please install vllm from https://github.com/vllm-project/vllm")
|
||||||
|
|
||||||
|
warmup = 10
|
||||||
|
rep = 1000
|
||||||
|
|
||||||
|
dtype = torch.float16
|
||||||
|
eps = 1e-5
|
||||||
|
x_shape = (SEQUENCE_TOTAL, HIDDEN_SIZE)
|
||||||
|
w_shape = (x_shape[-1],)
|
||||||
|
residual = torch.rand(x_shape, dtype=dtype, device="cuda")
|
||||||
|
weight = torch.ones(w_shape, dtype=dtype, device="cuda")
|
||||||
|
vllm_norm = RMSNorm(hidden_size=HIDDEN_SIZE, eps=eps).to(dtype=dtype, device="cuda")
|
||||||
|
x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
|
||||||
|
if provider == "vllm_rms_layernorm":
|
||||||
|
fn = lambda: vllm_norm(x)
|
||||||
|
elif provider == "triton_rms_layernorm":
|
||||||
|
fn = lambda: rms_layernorm(x, weight, eps=eps)
|
||||||
|
elif provider == "vllm_rms_layernorm_with_residual":
|
||||||
|
fn = lambda: vllm_norm(x, residual=residual)
|
||||||
|
elif provider == "triton_rms_layernorm_with_residual":
|
||||||
|
fn = lambda: rms_layernorm(x, weight, eps=eps, residual=residual)
|
||||||
|
else:
|
||||||
|
raise ValueError("Undefined provider.")
|
||||||
|
|
||||||
|
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
|
||||||
|
|
||||||
|
return ms
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
benchmark_rms_layernorm.run(save_path=".", print_data=True)
|
@ -0,0 +1,90 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
from colossalai.kernel.triton import copy_kv_to_blocked_cache, decoding_fused_rotary_embedding, rotary_embedding
|
||||||
|
from tests.test_infer.test_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache_v2, mock_alloc_single_token
|
||||||
|
|
||||||
|
try:
|
||||||
|
import triton # noqa
|
||||||
|
|
||||||
|
except ImportError:
|
||||||
|
print("please install triton from https://github.com/openai/triton")
|
||||||
|
|
||||||
|
|
||||||
|
BATCH = 16
|
||||||
|
configs = [
|
||||||
|
triton.testing.Benchmark(
|
||||||
|
x_names=["num_tokens"],
|
||||||
|
x_vals=[2**i for i in range(4, 11)],
|
||||||
|
line_arg="provider",
|
||||||
|
line_vals=["no_fused_rotary_emb_func", "fused_triton_rotary_emb_func"],
|
||||||
|
line_names=["no_fused_rotary_emb_func", "fused_triton_rotary_emb_func"],
|
||||||
|
styles=[("red", "-"), ("blue", "-")],
|
||||||
|
ylabel="ms",
|
||||||
|
plot_name=f"rotary_emb-batch-{BATCH}",
|
||||||
|
args={"num_kv_heads": 16},
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@triton.testing.perf_report(configs)
|
||||||
|
def benchmark_rotary_emb(
|
||||||
|
provider: str,
|
||||||
|
num_tokens: int,
|
||||||
|
num_kv_heads: int,
|
||||||
|
):
|
||||||
|
BATCH_SIZE = 4
|
||||||
|
SEQ_LEN = num_tokens // BATCH_SIZE
|
||||||
|
max_num_blocks_per_seq = 8
|
||||||
|
block_size = 64
|
||||||
|
warmup = 10
|
||||||
|
rep = 100
|
||||||
|
|
||||||
|
head_dim = 4096
|
||||||
|
dtype = torch.float16
|
||||||
|
|
||||||
|
q_shape = (num_tokens, num_kv_heads, head_dim)
|
||||||
|
q = -2.3 + 0.5 * torch.randn(q_shape, dtype=dtype, device="cuda")
|
||||||
|
k_shape = (num_tokens, num_kv_heads, head_dim)
|
||||||
|
k = -2.3 + 0.5 * torch.randn(k_shape, dtype=dtype, device="cuda")
|
||||||
|
v = -2.3 + 0.5 * torch.randn(k_shape, dtype=dtype, device="cuda")
|
||||||
|
|
||||||
|
cos_shape = (num_tokens, head_dim // 2)
|
||||||
|
|
||||||
|
cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
|
||||||
|
sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
|
||||||
|
cache_shape = (BATCH_SIZE * max_num_blocks_per_seq, num_kv_heads, block_size, head_dim)
|
||||||
|
k_cache = torch.zeros(size=cache_shape, dtype=dtype, device="cuda")
|
||||||
|
v_cache = torch.zeros(size=cache_shape, dtype=dtype, device="cuda")
|
||||||
|
|
||||||
|
past_kv_seq_lengths = torch.tensor([SEQ_LEN - 1 for _ in range(BATCH_SIZE)], dtype=torch.int32, device="cuda")
|
||||||
|
block_tables = mock_alloc_block_table_and_kvcache_v2(
|
||||||
|
k, v, k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_num_blocks_per_seq, block_size
|
||||||
|
)
|
||||||
|
new_k = torch.randn((BATCH_SIZE, num_kv_heads, head_dim), dtype=dtype, device="cuda")
|
||||||
|
new_q = torch.randn_like(new_k)
|
||||||
|
new_v = torch.randn_like(new_k)
|
||||||
|
|
||||||
|
mock_alloc_single_token(block_tables, past_kv_seq_lengths, block_size)
|
||||||
|
kv_seq_lengths = past_kv_seq_lengths + 1
|
||||||
|
block_tables = block_tables.to(device="cuda")
|
||||||
|
|
||||||
|
if provider == "no_fused_rotary_emb_func":
|
||||||
|
fn = lambda: [
|
||||||
|
rotary_embedding(new_q, new_k, cos, sin),
|
||||||
|
copy_kv_to_blocked_cache(
|
||||||
|
new_k, new_v, k_cache, v_cache, kv_lengths=kv_seq_lengths, block_tables=block_tables
|
||||||
|
),
|
||||||
|
]
|
||||||
|
elif provider == "fused_triton_rotary_emb_func":
|
||||||
|
fn = lambda: decoding_fused_rotary_embedding(
|
||||||
|
new_q, new_k, new_k, cos, sin, k_cache, k_cache, block_tables, kv_seq_lengths
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError("Undefined provider")
|
||||||
|
|
||||||
|
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
|
||||||
|
return ms
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
benchmark_rotary_emb.run(save_path=".", print_data=True)
|
@ -1,9 +1,7 @@
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
|
||||||
|
|
||||||
from colossalai.inference.modeling.layers.attention import PagedAttention
|
|
||||||
from colossalai.kernel.triton import context_attention_unpadded
|
from colossalai.kernel.triton import context_attention_unpadded
|
||||||
from colossalai.utils import get_current_device
|
from colossalai.utils import get_current_device
|
||||||
from tests.test_infer.test_ops.triton.kernel_utils import generate_caches_and_block_tables_v2, torch_attn_ref
|
from tests.test_infer.test_ops.triton.kernel_utils import generate_caches_and_block_tables_v2, torch_attn_ref
|
||||||
@ -92,7 +90,6 @@ def test_context_attention(
|
|||||||
qkv_unpad = torch.empty(size=qkv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
|
qkv_unpad = torch.empty(size=qkv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
|
||||||
q_unpad, k_unpad, v_unpad = torch.split(qkv_unpad, [num_attn_heads, num_kv_heads, num_kv_heads], dim=-2)
|
q_unpad, k_unpad, v_unpad = torch.split(qkv_unpad, [num_attn_heads, num_kv_heads, num_kv_heads], dim=-2)
|
||||||
q_unpad = q_unpad.contiguous()
|
q_unpad = q_unpad.contiguous()
|
||||||
|
|
||||||
k_cache_ref, v_cache_ref, block_tables = generate_caches_and_block_tables_v2(
|
k_cache_ref, v_cache_ref, block_tables = generate_caches_and_block_tables_v2(
|
||||||
k_unpad, v_unpad, context_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
|
k_unpad, v_unpad, context_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
|
||||||
)
|
)
|
||||||
@ -116,102 +113,5 @@ def test_context_attention(
|
|||||||
assert torch.equal(v_cache_ref, v_cache_triton)
|
assert torch.equal(v_cache_ref, v_cache_triton)
|
||||||
|
|
||||||
|
|
||||||
BATCH = 16
|
|
||||||
BLOCK_SIZE = 32
|
|
||||||
SAME_LEN = True
|
|
||||||
WARM_UPS = 10
|
|
||||||
REPS = 100
|
|
||||||
configs = [
|
|
||||||
triton.testing.Benchmark(
|
|
||||||
x_names=["KV_LEN"],
|
|
||||||
x_vals=[2**i for i in range(8, 13)],
|
|
||||||
# x_vals=[x for x in range(256, 8192, 256)],
|
|
||||||
line_arg="provider",
|
|
||||||
line_vals=["torch", "triton"],
|
|
||||||
line_names=["Torch", "Triton"],
|
|
||||||
styles=[("red", "-"), ("blue", "-")],
|
|
||||||
ylabel="ms",
|
|
||||||
plot_name=f"context_attn-block_size-{BLOCK_SIZE}-batch{BATCH}",
|
|
||||||
args={"bsz": BATCH, "block_size": BLOCK_SIZE, "same_context_len": SAME_LEN, "kv_group_num": 1},
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@triton.testing.perf_report(configs)
|
|
||||||
def bench_kernel(
|
|
||||||
bsz,
|
|
||||||
KV_LEN,
|
|
||||||
provider,
|
|
||||||
block_size: int,
|
|
||||||
kv_group_num: int,
|
|
||||||
same_context_len: bool,
|
|
||||||
):
|
|
||||||
num_attn_heads = 16
|
|
||||||
max_num_blocks_per_seq = triton.cdiv(KV_LEN, block_size)
|
|
||||||
max_seq_len = block_size * max_num_blocks_per_seq
|
|
||||||
|
|
||||||
num_kv_heads = num_attn_heads // kv_group_num
|
|
||||||
assert isinstance(num_kv_heads, int) and num_kv_heads > 0, "Invalid number of kv heads."
|
|
||||||
dtype = torch.float16
|
|
||||||
device = get_current_device()
|
|
||||||
|
|
||||||
if same_context_len:
|
|
||||||
context_lengths = torch.tensor([max_seq_len for _ in range(bsz)], dtype=torch.int32, device=device)
|
|
||||||
else:
|
|
||||||
context_lengths = torch.randint(low=1, high=max_seq_len, size=(bsz,), dtype=torch.int32, device=device)
|
|
||||||
num_tokens = torch.sum(context_lengths).item()
|
|
||||||
|
|
||||||
qkv_size = (num_tokens, num_attn_heads + 2 * num_kv_heads, HEAD_DIM)
|
|
||||||
qkv_unpad = torch.empty(size=qkv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
|
|
||||||
q_unpad, k_unpad, v_unpad = torch.split(qkv_unpad, [num_attn_heads, num_kv_heads, num_kv_heads], dim=-2)
|
|
||||||
q_unpad = q_unpad.contiguous()
|
|
||||||
k_cache_ref, v_cache_ref, block_tables = generate_caches_and_block_tables_v2(
|
|
||||||
k_unpad, v_unpad, context_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
|
|
||||||
)
|
|
||||||
block_tables = block_tables.to(device=device)
|
|
||||||
|
|
||||||
quantiles = [0.5, 0.2, 0.8]
|
|
||||||
if provider == "torch":
|
|
||||||
q_padded = PagedAttention.pad_and_reshape(q_unpad, context_lengths, max_seq_len, num_attn_heads, HEAD_DIM)
|
|
||||||
k_padded = PagedAttention.pad_and_reshape(k_unpad, context_lengths, max_seq_len, num_kv_heads, HEAD_DIM)
|
|
||||||
v_padded = PagedAttention.pad_and_reshape(v_unpad, context_lengths, max_seq_len, num_kv_heads, HEAD_DIM)
|
|
||||||
q_padded, k_padded, v_padded = (
|
|
||||||
q_padded.to(device=device),
|
|
||||||
k_padded.to(device=device),
|
|
||||||
v_padded.to(device=device),
|
|
||||||
)
|
|
||||||
q_padded = q_padded.transpose(1, 2)
|
|
||||||
k_padded = PagedAttention.repeat_kv(k_padded.transpose(1, 2), kv_group_num)
|
|
||||||
v_padded = PagedAttention.repeat_kv(v_padded.transpose(1, 2), kv_group_num)
|
|
||||||
# This benchmark ignores the padding mask. *Only* use the-same-length inputs for benchmarkings
|
|
||||||
attn_mask = AttentionMaskConverter._make_causal_mask(
|
|
||||||
(bsz, max_seq_len), q_padded.dtype, q_padded.device, past_key_values_length=0
|
|
||||||
)
|
|
||||||
attn_mask = attn_mask.to(device=q_padded.device)
|
|
||||||
fn = lambda: torch_attn_ref(
|
|
||||||
q_padded,
|
|
||||||
k_padded,
|
|
||||||
v_padded,
|
|
||||||
attn_mask,
|
|
||||||
bsz,
|
|
||||||
max_seq_len,
|
|
||||||
max_seq_len,
|
|
||||||
num_attn_heads,
|
|
||||||
num_kv_heads,
|
|
||||||
HEAD_DIM,
|
|
||||||
)
|
|
||||||
ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles)
|
|
||||||
if provider == "triton":
|
|
||||||
k_cache_triton = torch.zeros_like(k_cache_ref)
|
|
||||||
v_cache_triton = torch.zeros_like(v_cache_ref)
|
|
||||||
fn = lambda: context_attention_unpadded(
|
|
||||||
q_unpad, k_unpad, v_unpad, k_cache_triton, v_cache_triton, context_lengths, block_tables, block_size
|
|
||||||
)
|
|
||||||
ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles)
|
|
||||||
|
|
||||||
return ms, min_ms, max_ms
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_context_attention(4, 32, 8, 16, 1, True)
|
test_context_attention(4, 32, 8, 16, 1, True)
|
||||||
# bench_kernel.run(save_path=".", print_data=True)
|
|
||||||
|
@ -128,94 +128,5 @@ def test_flash_decoding(
|
|||||||
assert torch.allclose(out_torch, out_triton, atol=1e-3, rtol=1e-4)
|
assert torch.allclose(out_torch, out_triton, atol=1e-3, rtol=1e-4)
|
||||||
|
|
||||||
|
|
||||||
BATCH = 16
|
|
||||||
BLOCK_SIZE = 32
|
|
||||||
SAME_LEN = True
|
|
||||||
WARM_UPS = 10
|
|
||||||
REPS = 100
|
|
||||||
configs = [
|
|
||||||
triton.testing.Benchmark(
|
|
||||||
x_names=["KV_LEN"],
|
|
||||||
x_vals=[2**i for i in range(8, 14)],
|
|
||||||
# x_vals=[x for x in range(256, 8192, 256)],
|
|
||||||
line_arg="provider",
|
|
||||||
line_vals=["torch", "triton"],
|
|
||||||
line_names=["Torch", "Triton"],
|
|
||||||
styles=[("red", "-"), ("blue", "-")],
|
|
||||||
ylabel="ms",
|
|
||||||
plot_name=f"decoding-block_size-{BLOCK_SIZE}-batch{BATCH}",
|
|
||||||
args={"bsz": BATCH, "block_size": BLOCK_SIZE, "same_context_len": SAME_LEN, "kv_group_num": 1},
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@triton.testing.perf_report(configs)
|
|
||||||
def bench_kernel(
|
|
||||||
bsz,
|
|
||||||
KV_LEN,
|
|
||||||
provider,
|
|
||||||
block_size: int,
|
|
||||||
kv_group_num: int,
|
|
||||||
same_context_len: bool,
|
|
||||||
):
|
|
||||||
num_attn_heads = 16
|
|
||||||
max_num_blocks_per_seq = triton.cdiv(KV_LEN, block_size)
|
|
||||||
max_seq_len = block_size * max_num_blocks_per_seq
|
|
||||||
|
|
||||||
num_kv_heads = num_attn_heads // kv_group_num
|
|
||||||
assert isinstance(num_kv_heads, int) and num_kv_heads > 0, "Invalid number of kv heads."
|
|
||||||
block_size * max_num_blocks_per_seq
|
|
||||||
dtype = torch.float16
|
|
||||||
device = get_current_device()
|
|
||||||
|
|
||||||
q, k_unpad, v_unpad, kv_lengths = prepare_data(
|
|
||||||
bsz, num_attn_heads, num_kv_heads, HEAD_DIM, same_context_len, Q_LEN, max_seq_len, dtype, device
|
|
||||||
)
|
|
||||||
max_seq_len_in_b = kv_lengths.max().item() # for random lengths
|
|
||||||
|
|
||||||
quantiles = [0.5, 0.2, 0.8]
|
|
||||||
if provider == "torch":
|
|
||||||
k_torch = convert_kv_unpad_to_padded(k_unpad, kv_lengths, bsz, max_seq_len_in_b)
|
|
||||||
v_torch = convert_kv_unpad_to_padded(v_unpad, kv_lengths, bsz, max_seq_len_in_b)
|
|
||||||
torch_padding_mask = prepare_padding_mask(kv_lengths, bsz, max_seq_len_in_b, q.device)
|
|
||||||
fn = lambda: torch_attn_ref(
|
|
||||||
q, k_torch, v_torch, torch_padding_mask, bsz, 1, max_seq_len_in_b, num_attn_heads, num_kv_heads, HEAD_DIM
|
|
||||||
)
|
|
||||||
ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles)
|
|
||||||
if provider == "triton":
|
|
||||||
k_cache, v_cache, block_tables = generate_caches_and_block_tables_v2(
|
|
||||||
k_unpad, v_unpad, kv_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
|
|
||||||
)
|
|
||||||
block_tables = block_tables.to(device=device)
|
|
||||||
# the maximum block length splitted on kv should be the kv cache block size
|
|
||||||
kv_max_split_num = (max_seq_len_in_b + block_size - 1) // block_size
|
|
||||||
output = torch.empty((bsz, num_attn_heads, HEAD_DIM), dtype=dtype, device=device)
|
|
||||||
mid_output = torch.empty(
|
|
||||||
size=(bsz, num_attn_heads, kv_max_split_num, HEAD_DIM), dtype=torch.float32, device=q.device
|
|
||||||
)
|
|
||||||
mid_output_lse = torch.empty(size=(bsz, num_attn_heads, kv_max_split_num), dtype=torch.float32, device=q.device)
|
|
||||||
sm_scale = 1.0 / (HEAD_DIM**0.5)
|
|
||||||
fn = lambda: flash_decoding_attention(
|
|
||||||
# Here we use q.squeeze(2) because we hide the q_len dimension (which is equivalent to 1),
|
|
||||||
# refer to attention forward in modeling.
|
|
||||||
q.squeeze(2),
|
|
||||||
k_cache,
|
|
||||||
v_cache,
|
|
||||||
kv_lengths,
|
|
||||||
block_tables,
|
|
||||||
block_size,
|
|
||||||
max_seq_len_in_b,
|
|
||||||
output,
|
|
||||||
mid_output,
|
|
||||||
mid_output_lse,
|
|
||||||
sm_scale=sm_scale,
|
|
||||||
kv_group_num=kv_group_num,
|
|
||||||
) # [bsz, 1, num_heads, head_dim]
|
|
||||||
ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=WARM_UPS, rep=REPS, quantiles=quantiles)
|
|
||||||
|
|
||||||
return ms, min_ms, max_ms
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_flash_decoding(16, 32, 32, 16, 1, True)
|
test_flash_decoding(16, 32, 32, 16, 1, True)
|
||||||
# bench_kernel.run(save_path=".", print_data=True)
|
|
||||||
|
@ -1,70 +1,26 @@
|
|||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
|
||||||
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
import triton
|
from packaging import version
|
||||||
|
|
||||||
from colossalai.kernel.triton.fused_rotary_embedding import fused_rotary_embedding
|
from colossalai.kernel.triton.fused_rotary_embedding import fused_rotary_embedding
|
||||||
from colossalai.kernel.triton.no_pad_rotary_embedding import rotary_embedding
|
from colossalai.kernel.triton.no_pad_rotary_embedding import rotary_embedding
|
||||||
from colossalai.kernel.triton.rotary_cache_copy import get_xine_cache
|
from colossalai.kernel.triton.rotary_cache_copy import get_xine_cache
|
||||||
|
|
||||||
BATCH = 16
|
try:
|
||||||
configs = [
|
import triton # noqa
|
||||||
triton.testing.Benchmark(
|
|
||||||
x_names=["num_tokens"],
|
HAS_TRITON = True
|
||||||
x_vals=[2**i for i in range(4, 12)],
|
except ImportError:
|
||||||
line_arg="provider",
|
HAS_TRITON = False
|
||||||
line_vals=["torch_rotary_emb_func", "triton_rotary_emb_func"],
|
print("please install triton from https://github.com/openai/triton")
|
||||||
line_names=["torch_rotary_emb_func", "triton_rotary_emb_func"],
|
|
||||||
styles=[("red", "-"), ("blue", "-")],
|
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
|
||||||
ylabel="ms",
|
|
||||||
plot_name=f"rotary_emb-batch-{BATCH}",
|
|
||||||
args={"num_kv_heads": 16},
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def torch_rotary_emb(x, cos, sin):
|
@pytest.mark.skipif(not (HAS_TRITON and TRITON_CUDA_SUPPORT), reason="requires triton")
|
||||||
seq_len, h, dim = x.shape
|
def test_fused_rotary_emb():
|
||||||
x0 = x[:, :, 0 : dim // 2]
|
|
||||||
x1 = x[:, :, dim // 2 : dim]
|
|
||||||
cos = cos.view((seq_len, 1, dim // 2))
|
|
||||||
sin = sin.view((seq_len, 1, dim // 2))
|
|
||||||
o0 = x0 * cos - x1 * sin
|
|
||||||
o1 = x0 * sin + x1 * cos
|
|
||||||
return torch.cat((o0, o1), dim=-1)
|
|
||||||
|
|
||||||
|
|
||||||
@triton.testing.perf_report(configs)
|
|
||||||
def benchmark_rotary_emb(
|
|
||||||
provider: str,
|
|
||||||
num_tokens: int,
|
|
||||||
num_kv_heads: int,
|
|
||||||
):
|
|
||||||
warmup = 10
|
|
||||||
rep = 100
|
|
||||||
|
|
||||||
head_dim = 128
|
|
||||||
dtype = torch.float16
|
|
||||||
q_shape = (num_tokens, num_kv_heads, head_dim)
|
|
||||||
q = -2.3 + 0.5 * torch.randn(q_shape, dtype=dtype, device="cuda")
|
|
||||||
k_shape = (num_tokens, num_kv_heads, head_dim)
|
|
||||||
k = -2.3 + 0.5 * torch.randn(k_shape, dtype=dtype, device="cuda")
|
|
||||||
cos_shape = (4096, head_dim // 2)
|
|
||||||
cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
|
|
||||||
sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
|
|
||||||
|
|
||||||
if provider == "torch_rotary_emb_func":
|
|
||||||
fn = lambda: torch_rotary_emb(q, cos[:num_tokens], sin[:num_tokens])
|
|
||||||
elif provider == "triton_rotary_emb_func":
|
|
||||||
fn = lambda: fused_rotary_embedding(q, k, cos, sin, lengths)
|
|
||||||
else:
|
|
||||||
raise ValueError("Undefined provider")
|
|
||||||
|
|
||||||
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
|
|
||||||
return ms
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
num_tokens = 20
|
num_tokens = 20
|
||||||
num_kv_heads = 32
|
num_kv_heads = 32
|
||||||
head_dim = 64
|
head_dim = 64
|
||||||
@ -82,12 +38,13 @@ if __name__ == "__main__":
|
|||||||
cos_cache = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
|
cos_cache = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
|
||||||
sin_cache = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
|
sin_cache = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
|
||||||
|
|
||||||
cos = get_xine_cache(lengths, cos_cache[:, : head_dim // 2])
|
cos, sin = get_xine_cache(lengths, cos_cache[:, : head_dim // 2], sin_cache[:, : head_dim // 2])
|
||||||
sin = get_xine_cache(lengths, sin_cache[:, : head_dim // 2])
|
|
||||||
|
|
||||||
rotary_embedding(q, k, cos, sin)
|
rotary_embedding(q, k, cos, sin)
|
||||||
fused_rotary_embedding(q_copy, k_copy, cos_cache, sin_cache, lengths)
|
fused_rotary_embedding(q_copy, k_copy, cos_cache, sin_cache, lengths)
|
||||||
torch.allclose(q, q_copy)
|
torch.allclose(q, q_copy)
|
||||||
torch.allclose(k, k_copy)
|
torch.allclose(k, k_copy)
|
||||||
|
|
||||||
# benchmark_rotary_emb.run(save_path=".",print_data=True)
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_fused_rotary_emb()
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
import triton
|
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from transformers.models.llama.modeling_llama import LlamaRMSNorm
|
from transformers.models.llama.modeling_llama import LlamaRMSNorm
|
||||||
|
|
||||||
@ -52,70 +51,5 @@ def test_layer_norm(M, N):
|
|||||||
assert torch.allclose(x, residual, atol=1e-5, rtol=1e-3)
|
assert torch.allclose(x, residual, atol=1e-5, rtol=1e-3)
|
||||||
|
|
||||||
|
|
||||||
# Triton benchmark plot attributions
|
|
||||||
configs = [
|
|
||||||
triton.testing.Benchmark(
|
|
||||||
x_names=["SEQUENCE_TOTAL"],
|
|
||||||
x_vals=[i for i in range(128, 1025, 128)],
|
|
||||||
line_arg="provider",
|
|
||||||
line_vals=[
|
|
||||||
"vllm_rms_layernorm",
|
|
||||||
"triton_rms_layernorm",
|
|
||||||
"triton_rms_layernorm_with_residual",
|
|
||||||
"vllm_rms_layernorm_with_residual",
|
|
||||||
],
|
|
||||||
line_names=[
|
|
||||||
"vllm_rms_layernorm",
|
|
||||||
"triton_rms_layernorm",
|
|
||||||
"triton_rms_layernorm_with_residual",
|
|
||||||
"vllm_rms_layernorm_with_residual",
|
|
||||||
],
|
|
||||||
styles=[("red", "-"), ("blue", "-"), ("yellow", "-"), ("green", "-")],
|
|
||||||
ylabel="ms",
|
|
||||||
plot_name=f"RMSNorm benchmarking results",
|
|
||||||
args={"HIDDEN_SIZE": 1024},
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@triton.testing.perf_report(configs)
|
|
||||||
def benchmark_rms_layernorm(
|
|
||||||
provider: str,
|
|
||||||
SEQUENCE_TOTAL: int,
|
|
||||||
HIDDEN_SIZE: int,
|
|
||||||
):
|
|
||||||
try:
|
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
|
||||||
except ImportError:
|
|
||||||
raise ImportError("Please install vllm from https://github.com/vllm-project/vllm")
|
|
||||||
|
|
||||||
warmup = 10
|
|
||||||
rep = 1000
|
|
||||||
|
|
||||||
dtype = torch.float16
|
|
||||||
eps = 1e-5
|
|
||||||
x_shape = (SEQUENCE_TOTAL, HIDDEN_SIZE)
|
|
||||||
w_shape = (x_shape[-1],)
|
|
||||||
residual = torch.rand(x_shape, dtype=dtype, device="cuda")
|
|
||||||
weight = torch.ones(w_shape, dtype=dtype, device="cuda")
|
|
||||||
vllm_norm = RMSNorm(hidden_size=HIDDEN_SIZE, eps=eps).to(dtype=dtype, device="cuda")
|
|
||||||
x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
|
|
||||||
if provider == "vllm_rms_layernorm":
|
|
||||||
fn = lambda: vllm_norm(x)
|
|
||||||
elif provider == "triton_rms_layernorm":
|
|
||||||
fn = lambda: rms_layernorm(x, weight, eps=eps)
|
|
||||||
elif provider == "vllm_rms_layernorm_with_residual":
|
|
||||||
fn = lambda: vllm_norm(x, residual=residual)
|
|
||||||
elif provider == "triton_rms_layernorm_with_residual":
|
|
||||||
fn = lambda: rms_layernorm(x, weight, eps=eps, residual=residual)
|
|
||||||
else:
|
|
||||||
raise ValueError("Undefined provider.")
|
|
||||||
|
|
||||||
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
|
|
||||||
|
|
||||||
return ms
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_layer_norm()
|
test_layer_norm()
|
||||||
# benchmark_rms_layernorm.run(save_path=".", print_data=True)
|
|
||||||
|
@ -3,8 +3,8 @@ import torch
|
|||||||
from packaging import version
|
from packaging import version
|
||||||
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb
|
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb
|
||||||
|
|
||||||
from colossalai.kernel.triton import copy_kv_to_blocked_cache, decoding_fused_rotary_embedding, rotary_embedding
|
from colossalai.kernel.triton import decoding_fused_rotary_embedding
|
||||||
from tests.test_infer.test_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache_v2, mock_alloc_single_token
|
from tests.test_infer.test_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache_v2
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import triton # noqa
|
import triton # noqa
|
||||||
@ -28,6 +28,9 @@ def torch_rotary_emb(x, cos, sin):
|
|||||||
return torch.cat((o0, o1), dim=-1)
|
return torch.cat((o0, o1), dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
|
||||||
|
)
|
||||||
@pytest.mark.parametrize("BATCH_SIZE", [4])
|
@pytest.mark.parametrize("BATCH_SIZE", [4])
|
||||||
@pytest.mark.parametrize("SEQ_LEN", [64])
|
@pytest.mark.parametrize("SEQ_LEN", [64])
|
||||||
@pytest.mark.parametrize("H", [32])
|
@pytest.mark.parametrize("H", [32])
|
||||||
@ -77,82 +80,5 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype):
|
|||||||
assert torch.allclose(new_q, q_ref, atol=1e-4, rtol=1e-4)
|
assert torch.allclose(new_q, q_ref, atol=1e-4, rtol=1e-4)
|
||||||
|
|
||||||
|
|
||||||
BATCH = 16
|
|
||||||
configs = [
|
|
||||||
triton.testing.Benchmark(
|
|
||||||
x_names=["num_tokens"],
|
|
||||||
x_vals=[2**i for i in range(4, 11)],
|
|
||||||
line_arg="provider",
|
|
||||||
line_vals=["no_fused_rotary_emb_func", "fused_triton_rotary_emb_func"],
|
|
||||||
line_names=["no_fused_rotary_emb_func", "fused_triton_rotary_emb_func"],
|
|
||||||
styles=[("red", "-"), ("blue", "-")],
|
|
||||||
ylabel="ms",
|
|
||||||
plot_name=f"rotary_emb-batch-{BATCH}",
|
|
||||||
args={"num_kv_heads": 16},
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@triton.testing.perf_report(configs)
|
|
||||||
def benchmark_rotary_emb(
|
|
||||||
provider: str,
|
|
||||||
num_tokens: int,
|
|
||||||
num_kv_heads: int,
|
|
||||||
):
|
|
||||||
BATCH_SIZE = 4
|
|
||||||
SEQ_LEN = num_tokens // BATCH_SIZE
|
|
||||||
max_num_blocks_per_seq = 8
|
|
||||||
block_size = 64
|
|
||||||
warmup = 10
|
|
||||||
rep = 100
|
|
||||||
|
|
||||||
head_dim = 4096
|
|
||||||
dtype = torch.float16
|
|
||||||
|
|
||||||
q_shape = (num_tokens, num_kv_heads, head_dim)
|
|
||||||
q = -2.3 + 0.5 * torch.randn(q_shape, dtype=dtype, device="cuda")
|
|
||||||
k_shape = (num_tokens, num_kv_heads, head_dim)
|
|
||||||
k = -2.3 + 0.5 * torch.randn(k_shape, dtype=dtype, device="cuda")
|
|
||||||
v = -2.3 + 0.5 * torch.randn(k_shape, dtype=dtype, device="cuda")
|
|
||||||
|
|
||||||
cos_shape = (num_tokens, head_dim // 2)
|
|
||||||
|
|
||||||
cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
|
|
||||||
sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
|
|
||||||
cache_shape = (BATCH_SIZE * max_num_blocks_per_seq, num_kv_heads, block_size, head_dim)
|
|
||||||
k_cache = torch.zeros(size=cache_shape, dtype=dtype, device="cuda")
|
|
||||||
v_cache = torch.zeros(size=cache_shape, dtype=dtype, device="cuda")
|
|
||||||
|
|
||||||
past_kv_seq_lengths = torch.tensor([SEQ_LEN - 1 for _ in range(BATCH_SIZE)], dtype=torch.int32, device="cuda")
|
|
||||||
block_tables = mock_alloc_block_table_and_kvcache_v2(
|
|
||||||
k, v, k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_num_blocks_per_seq, block_size
|
|
||||||
)
|
|
||||||
new_k = torch.randn((BATCH_SIZE, num_kv_heads, head_dim), dtype=dtype, device="cuda")
|
|
||||||
new_q = torch.randn_like(new_k)
|
|
||||||
new_v = torch.randn_like(new_k)
|
|
||||||
|
|
||||||
mock_alloc_single_token(block_tables, past_kv_seq_lengths, block_size)
|
|
||||||
kv_seq_lengths = past_kv_seq_lengths + 1
|
|
||||||
block_tables = block_tables.to(device="cuda")
|
|
||||||
|
|
||||||
if provider == "no_fused_rotary_emb_func":
|
|
||||||
fn = lambda: [
|
|
||||||
rotary_embedding(new_q, new_k, cos, sin),
|
|
||||||
copy_kv_to_blocked_cache(
|
|
||||||
new_k, new_v, k_cache, v_cache, kv_lengths=kv_seq_lengths, block_tables=block_tables
|
|
||||||
),
|
|
||||||
]
|
|
||||||
elif provider == "fused_triton_rotary_emb_func":
|
|
||||||
fn = lambda: decoding_fused_rotary_embedding(
|
|
||||||
new_q, new_k, new_k, cos, sin, k_cache, k_cache, block_tables, kv_seq_lengths
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise ValueError("Undefined provider")
|
|
||||||
|
|
||||||
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
|
|
||||||
return ms
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_rotary_emb(4, 64, 32, 64, torch.float32)
|
test_rotary_emb(4, 64, 32, 64, torch.float32)
|
||||||
# benchmark_rotary_emb.run(save_path=".", print_data=True)
|
|
||||||
|
@ -38,6 +38,9 @@ def get_cos_sin(lengths, cos_cache, sin_cache, is_prompts, dtype):
|
|||||||
return (cos_output, sin_output)
|
return (cos_output, sin_output)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
|
||||||
|
)
|
||||||
@pytest.mark.parametrize("BATCH_SIZE", [4])
|
@pytest.mark.parametrize("BATCH_SIZE", [4])
|
||||||
@pytest.mark.parametrize("MAX_SEQ_LEN", [64])
|
@pytest.mark.parametrize("MAX_SEQ_LEN", [64])
|
||||||
@pytest.mark.parametrize("HEAD_DIM", [64])
|
@pytest.mark.parametrize("HEAD_DIM", [64])
|
||||||
@ -59,46 +62,5 @@ def test_get_xine_cache(BATCH_SIZE, MAX_SEQ_LEN, HEAD_DIM, dtype):
|
|||||||
assert torch.allclose(sin, nsin_ref)
|
assert torch.allclose(sin, nsin_ref)
|
||||||
|
|
||||||
|
|
||||||
configs = [
|
|
||||||
triton.testing.Benchmark(
|
|
||||||
x_names=["max_num_tokens"],
|
|
||||||
x_vals=[2**i for i in range(6, 12)],
|
|
||||||
line_arg="provider",
|
|
||||||
line_vals=["torch_get_cos_sin", "triton_get_cos_sin"],
|
|
||||||
line_names=["torch_get_cos_sin", "triton_get_cos_sin"],
|
|
||||||
styles=[("red", "-"), ("blue", "-")],
|
|
||||||
ylabel="ms",
|
|
||||||
plot_name="Get_cos-sin_func",
|
|
||||||
args={"batch_size": 16, "head_dim": 256},
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@triton.testing.perf_report(configs)
|
|
||||||
def benchmark_get_xine_cache(
|
|
||||||
provider: str,
|
|
||||||
max_num_tokens: int,
|
|
||||||
batch_size: int,
|
|
||||||
head_dim: int,
|
|
||||||
):
|
|
||||||
warmup = 10
|
|
||||||
rep = 1000
|
|
||||||
dtype = torch.float16
|
|
||||||
cos_cache = torch.randn((8912, head_dim), dtype=dtype, device="cuda")
|
|
||||||
sin_cache = torch.randn((8912, head_dim), dtype=dtype, device="cuda")
|
|
||||||
lengths = torch.randint(2, max_num_tokens, (batch_size,), device="cuda")
|
|
||||||
|
|
||||||
if provider == "torch_get_cos_sin":
|
|
||||||
fn = lambda: get_cos_sin(lengths, cos_cache, sin_cache, is_prompts=True, dtype=dtype)
|
|
||||||
elif provider == "triton_get_cos_sin":
|
|
||||||
fn = lambda: get_xine_cache(lengths, cos_cache, sin_cache, is_prompts=True)
|
|
||||||
else:
|
|
||||||
raise ValueError("Undefined provider")
|
|
||||||
|
|
||||||
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
|
|
||||||
return ms
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_get_xine_cache(4, 64, 256, torch.float32)
|
test_get_xine_cache(4, 64, 256, torch.float32)
|
||||||
# benchmark_get_xine_cache.run(save_path=".",print_data=True)
|
|
||||||
|
Loading…
Reference in New Issue
Block a user