This commit is contained in:
Jianghai
2024-01-26 15:02:12 +08:00
committed by GitHub
parent 4f28cb43c0
commit 7ddd8b37f0
4 changed files with 149 additions and 75 deletions

View File

@@ -39,8 +39,8 @@ configs = [
x_names=["max_num_tokens"],
x_vals=[2**i for i in range(6, 12)],
line_arg="provider",
line_vals=["torch_get_cos_sin_func", "triton_get_xine_func"],
line_names=["torch_get_cos_sin_func", "triton_get_xine_func"],
line_vals=["torch_get_cos_sin", "triton_get_cos_sin"],
line_names=["torch_get_cos_sin", "triton_get_cos_sin"],
styles=[("red", "-"), ("blue", "-")],
ylabel="ms",
plot_name="Get_cos-sin_func",
@@ -58,19 +58,15 @@ def benchmark_get_xine_cache(
):
warmup = 10
rep = 1000
max_token_per_seq = max_num_tokens // batch_size
dtype = torch.float16
cos_cache = torch.randn((max_num_tokens, head_dim), dtype=dtype, device="cuda")
sin_cache = torch.randn((max_num_tokens, head_dim), dtype=dtype, device="cuda")
lengths = torch.randint(2, max_token_per_seq, (batch_size,), device="cuda")
cos_cache = torch.randn((8912, head_dim), dtype=dtype, device="cuda")
sin_cache = torch.randn((8912, head_dim), dtype=dtype, device="cuda")
lengths = torch.randint(2, max_num_tokens, (batch_size,), device="cuda")
if provider == "torch_get_cos_sin_func":
fn = lambda: get_cos_sin(lengths, cos_cache, sin_cache, is_prompts=False, dtype=dtype)
elif provider == "triton_get_xine_func":
fn = lambda: [
get_xine_cache(lengths, cos_cache, is_prompts=False),
get_xine_cache(lengths, sin_cache, is_prompts=False),
]
if provider == "torch_get_cos_sin":
fn = lambda: get_cos_sin(lengths, cos_cache, sin_cache, is_prompts=True, dtype=dtype)
elif provider == "triton_get_cos_sin":
fn = lambda: get_xine_cache(lengths, cos_cache, sin_cache, is_prompts=True)
else:
raise ValueError("Undefined provider")