Optimized the execution gap between CUDA kernels caused by view and memcopy (#5390)

* opt_view_and_memcopy

* fix bugs in ci

* fix ci bugs

* update benchmark scripts

* fix ci bugs
yuehuayingxueluo, 2024-02-21 13:23:57 +08:00 (committed by GitHub)
parent 730103819d
commit 2a718c8be8
8 changed files with 141 additions and 55 deletions
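Taken together, the hunks below change two call patterns: the Triton rms_layernorm now returns an (output, residual) pair and accepts an optional residual, apparently fusing the residual addition into the normalization launch, and the context-attention test reshapes the kernel output itself rather than relying on an intermediate copy. A minimal sketch of the new rms_layernorm interface as the updated test exercises it (shapes and values here are illustrative, not taken from the diff):

import torch

from colossalai.kernel.triton import rms_layernorm

x = torch.randn(16, 1024, dtype=torch.float16, device="cuda")
weight = torch.ones(1024, dtype=torch.float16, device="cuda")
residual = torch.rand_like(x)

# Without a residual the kernel still returns a tuple, so the output is unpacked.
y, _ = rms_layernorm(x, weight, eps=1e-5)

# With a residual the addition is fused into the same kernel, so no separate
# elementwise-add (and the copies around it) runs between launches; the updated
# residual (x + residual) comes back for reuse, as the test below asserts.
y, residual = rms_layernorm(x, weight, eps=1e-5, residual=residual)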


@@ -100,10 +100,14 @@ def test_context_attention(
     k_cache_triton = torch.zeros_like(k_cache_ref)
     v_cache_triton = torch.zeros_like(v_cache_ref)
+    _, num_heads, head_dim = q_unpad.shape
     out_triton = context_attention_unpadded(
         q_unpad, k_unpad, v_unpad, k_cache_triton, v_cache_triton, context_lengths, block_tables, block_size
     )
+    out_triton = out_triton.view(-1, num_heads, head_dim)
     out_torch = torch_attn_unpad(q_unpad, k_unpad, v_unpad, context_lengths, num_attn_heads, num_kv_heads)
     assert out_torch.shape == out_triton.shape
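The test now unpacks the head layout and reshapes the Triton output before the comparison. As a point of reference (not part of this PR's code), a view on a contiguous CUDA tensor only rewrites shape metadata and shares the original storage, so the reshape adds no copy kernel between the attention launch and the check:

import torch

out = torch.randn(256, 8 * 64, device="cuda")  # stand-in for a flattened kernel output
reshaped = out.view(-1, 8, 64)                 # metadata-only reshape, no device copy
assert reshaped.data_ptr() == out.data_ptr()   # both names alias the same storage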


@@ -3,6 +3,7 @@ import torch
 import triton
 from packaging import version
 from transformers.models.llama.modeling_llama import LlamaRMSNorm
+from vllm.model_executor.layers.layernorm import RMSNorm
 from colossalai.kernel.triton import rms_layernorm
 from colossalai.testing.utils import parameterize
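The newly imported vLLM RMSNorm serves as the fused-residual baseline in the benchmark below. Assuming the vLLM version used at the time (the benchmark relies on this), its forward also takes an optional residual and returns an (output, residual) pair when one is given, mirroring the Triton kernel's interface; a small usage sketch:

import torch
from vllm.model_executor.layers.layernorm import RMSNorm

norm = RMSNorm(hidden_size=1024, eps=1e-5).to(dtype=torch.float16, device="cuda")
x = torch.randn(4, 1024, dtype=torch.float16, device="cuda")
residual = torch.rand_like(x)

y = norm(x)                               # plain RMSNorm
y, residual = norm(x, residual=residual)  # fused add + norm, returns the updated residual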
@@ -29,15 +30,28 @@ def test_layer_norm(M, N):
     x_shape = (M, N)
     w_shape = (x_shape[-1],)
     weight = torch.ones(w_shape, dtype=dtype, device="cuda")
+    residual = torch.rand(x_shape, dtype=dtype, device="cuda")
+    residual_copy = residual.clone()
     rms_norm = LlamaRMSNorm(hidden_size=N, eps=eps).cuda()
     x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
+    x_copy = x.clone()
-    y_triton = rms_layernorm(x, weight, eps=eps)
+    y_triton, _ = rms_layernorm(x, weight, eps=eps)
     y_llama = rms_norm.forward(x).to(dtype)
     assert y_triton.shape == y_llama.shape
     assert torch.allclose(y_triton, y_llama, atol=1e-5, rtol=1e-3)
+    y_triton, residual = rms_layernorm(x, weight, eps=eps, residual=residual)
+    x = x_copy + residual_copy
+    y_llama = rms_norm.forward(x).to(dtype)
+    assert y_triton.shape == y_llama.shape
+    assert torch.allclose(y_triton, y_llama, atol=1e-5, rtol=1e-3)
+    assert torch.allclose(x, residual, atol=1e-5, rtol=1e-3)
 # Triton benchmark plot attributions
 configs = [
@@ -45,9 +59,19 @@ configs = [
         x_names=["SEQUENCE_TOTAL"],
         x_vals=[i for i in range(128, 1025, 128)],
         line_arg="provider",
-        line_vals=["torch_rms_layernorm", "triton_rms_layernorm"],
-        line_names=["torch_rms_layernorm", "triton_rms_layernorm"],
-        styles=[("red", "-"), ("blue", "-")],
+        line_vals=[
+            "vllm_rms_layernorm",
+            "triton_rms_layernorm",
+            "triton_rms_layernorm_with_residual",
+            "vllm_rms_layernorm_with_residual",
+        ],
+        line_names=[
+            "vllm_rms_layernorm",
+            "triton_rms_layernorm",
+            "triton_rms_layernorm_with_residual",
+            "vllm_rms_layernorm_with_residual",
+        ],
+        styles=[("red", "-"), ("blue", "-"), ("yellow", "-"), ("green", "-")],
         ylabel="ms",
         plot_name=f"RMSNorm benchmarking results",
         args={"HIDDEN_SIZE": 1024},
@@ -68,13 +92,18 @@ def benchmark_rms_layernorm(
     eps = 1e-5
     x_shape = (SEQUENCE_TOTAL, HIDDEN_SIZE)
     w_shape = (x_shape[-1],)
+    residual = torch.rand(x_shape, dtype=dtype, device="cuda")
     weight = torch.ones(w_shape, dtype=dtype, device="cuda")
-    torch_norm = LlamaRMSNorm(hidden_size=HIDDEN_SIZE).to(dtype=dtype, device="cuda")
+    vllm_norm = RMSNorm(hidden_size=HIDDEN_SIZE, eps=eps).to(dtype=dtype, device="cuda")
     x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
-    if provider == "torch_rms_layernorm":
-        fn = lambda: torch_norm(x)
+    if provider == "vllm_rms_layernorm":
+        fn = lambda: vllm_norm(x)
     elif provider == "triton_rms_layernorm":
         fn = lambda: rms_layernorm(x, weight, eps=eps)
+    elif provider == "vllm_rms_layernorm_with_residual":
+        fn = lambda: vllm_norm(x, residual=residual)
+    elif provider == "triton_rms_layernorm_with_residual":
+        fn = lambda: rms_layernorm(x, weight, eps=eps, residual=residual)
     else:
         raise ValueError("Undefined provider.")
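The hunk stops at the provider dispatch; the timing loop itself is outside this diff. Scripts built on triton.testing.perf_report typically hand the selected fn to triton.testing.do_bench and return the measured milliseconds, which is what ylabel="ms" above refers to. A self-contained sketch of that pattern (the workload and names here are placeholders, not the file's actual tail):

import torch
import triton
import triton.testing  # do_bench lives here


def time_rmsnorm_ms(seq_total: int = 1024, hidden: int = 1024):
    x = torch.randn(seq_total, hidden, dtype=torch.float16, device="cuda")
    weight = torch.ones(hidden, dtype=torch.float16, device="cuda")
    eps = 1e-5

    def fn():
        # Plain PyTorch RMSNorm as a placeholder workload.
        variance = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(variance + eps) * weight

    # Runtime in milliseconds; the exact return convention (scalar vs. quantiles)
    # depends on the installed Triton version.
    return triton.testing.do_bench(fn)


if __name__ == "__main__":
    print(time_rmsnorm_ms())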