Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-02 01:28:31 +00:00)
fix (#5311)
This commit is contained in:
@@ -39,8 +39,8 @@ configs = [
|
||||
x_names=["max_num_tokens"],
|
||||
x_vals=[2**i for i in range(6, 12)],
|
||||
line_arg="provider",
|
||||
line_vals=["torch_get_cos_sin_func", "triton_get_xine_func"],
|
||||
line_names=["torch_get_cos_sin_func", "triton_get_xine_func"],
|
||||
line_vals=["torch_get_cos_sin", "triton_get_cos_sin"],
|
||||
line_names=["torch_get_cos_sin", "triton_get_cos_sin"],
|
||||
styles=[("red", "-"), ("blue", "-")],
|
||||
ylabel="ms",
|
||||
plot_name="Get_cos-sin_func",
|
||||
@@ -58,19 +58,15 @@ def benchmark_get_xine_cache(
|
||||
):
|
||||
warmup = 10
|
||||
rep = 1000
|
||||
max_token_per_seq = max_num_tokens // batch_size
|
||||
dtype = torch.float16
|
||||
cos_cache = torch.randn((max_num_tokens, head_dim), dtype=dtype, device="cuda")
|
||||
sin_cache = torch.randn((max_num_tokens, head_dim), dtype=dtype, device="cuda")
|
||||
lengths = torch.randint(2, max_token_per_seq, (batch_size,), device="cuda")
|
||||
cos_cache = torch.randn((8912, head_dim), dtype=dtype, device="cuda")
|
||||
sin_cache = torch.randn((8912, head_dim), dtype=dtype, device="cuda")
|
||||
lengths = torch.randint(2, max_num_tokens, (batch_size,), device="cuda")
|
||||
|
||||
if provider == "torch_get_cos_sin_func":
|
||||
fn = lambda: get_cos_sin(lengths, cos_cache, sin_cache, is_prompts=False, dtype=dtype)
|
||||
elif provider == "triton_get_xine_func":
|
||||
fn = lambda: [
|
||||
get_xine_cache(lengths, cos_cache, is_prompts=False),
|
||||
get_xine_cache(lengths, sin_cache, is_prompts=False),
|
||||
]
|
||||
if provider == "torch_get_cos_sin":
|
||||
fn = lambda: get_cos_sin(lengths, cos_cache, sin_cache, is_prompts=True, dtype=dtype)
|
||||
elif provider == "triton_get_cos_sin":
|
||||
fn = lambda: get_xine_cache(lengths, cos_cache, sin_cache, is_prompts=True)
|
||||
else:
|
||||
raise ValueError("Undefined provider")
|
||||
|
||||
|
Reference in New Issue
Block a user