Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-01 09:07:51 +00:00
Optimized the execution interval between CUDA kernels caused by view and memcopy (#5390)
* opt_view_and_memcopy
* fix bugs in ci
* fix ci bugs
* update benchmark scripts
* fix ci bugs
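The change these diffs exercise: `rms_layernorm` gains an optional `residual` argument and now returns a pair, so the residual add is folded into the Triton normalization kernel instead of running as a separate elementwise kernel (plus the attendant copy) between launches. A minimal reference sketch of the semantics implied by the updated test below (plain PyTorch; inferred from the test assertions, not taken from the kernel itself):

    import torch

    def rms_layernorm_ref(x, weight, eps, residual=None):
        # Semantics inferred from the test: with a residual, the kernel
        # normalizes (x + residual) and also hands the sum back to the caller.
        if residual is not None:
            x = x + residual  # fused inside the real Triton kernel
            residual = x      # updated residual, reused by the next layer
        variance = x.float().pow(2).mean(-1, keepdim=True)
        out = (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight
        return out, residual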
@@ -100,10 +100,14 @@ def test_context_attention(
     k_cache_triton = torch.zeros_like(k_cache_ref)
     v_cache_triton = torch.zeros_like(v_cache_ref)
 
+    _, num_heads, head_dim = q_unpad.shape
+
     out_triton = context_attention_unpadded(
         q_unpad, k_unpad, v_unpad, k_cache_triton, v_cache_triton, context_lengths, block_tables, block_size
     )
 
+    out_triton = out_triton.view(-1, num_heads, head_dim)
+
     out_torch = torch_attn_unpad(q_unpad, k_unpad, v_unpad, context_lengths, num_attn_heads, num_kv_heads)
 
     assert out_torch.shape == out_triton.shape
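The added `view` is the cheap half of the commit title: on a contiguous tensor, `view` only rewrites shape metadata, while materializing a new layout launches a copy kernel. A quick standalone illustration (plain PyTorch, not from this repo):

    import torch

    t = torch.arange(24).reshape(2, 3, 4)
    v = t.view(-1, 4)                    # metadata-only: shares storage, no copy
    assert v.data_ptr() == t.data_ptr()  # same underlying buffer

    c = t.transpose(0, 1).contiguous()   # materializes a copy (a kernel on GPU)
    assert c.data_ptr() != t.data_ptr()  # new buffer: the memcopy being avoided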
@@ -3,6 +3,7 @@ import torch
 import triton
 from packaging import version
 from transformers.models.llama.modeling_llama import LlamaRMSNorm
+from vllm.model_executor.layers.layernorm import RMSNorm
 
 from colossalai.kernel.triton import rms_layernorm
 from colossalai.testing.utils import parameterize
@@ -29,15 +30,28 @@ def test_layer_norm(M, N):
     x_shape = (M, N)
     w_shape = (x_shape[-1],)
     weight = torch.ones(w_shape, dtype=dtype, device="cuda")
+    residual = torch.rand(x_shape, dtype=dtype, device="cuda")
+    residual_copy = residual.clone()
     rms_norm = LlamaRMSNorm(hidden_size=N, eps=eps).cuda()
     x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
+    x_copy = x.clone()
 
-    y_triton = rms_layernorm(x, weight, eps=eps)
+    y_triton, _ = rms_layernorm(x, weight, eps=eps)
     y_llama = rms_norm.forward(x).to(dtype)
 
     assert y_triton.shape == y_llama.shape
     assert torch.allclose(y_triton, y_llama, atol=1e-5, rtol=1e-3)
 
+    y_triton, residual = rms_layernorm(x, weight, eps=eps, residual=residual)
+
+    x = x_copy + residual_copy
+
+    y_llama = rms_norm.forward(x).to(dtype)
+
+    assert y_triton.shape == y_llama.shape
+    assert torch.allclose(y_triton, y_llama, atol=1e-5, rtol=1e-3)
+    assert torch.allclose(x, residual, atol=1e-5, rtol=1e-3)
 
 
 # Triton benchmark plot attributions
 configs = [
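The reason the kernel returns the updated residual: in a decoder stack the add-then-normalize pattern repeats every layer, so keeping the sum inside one launch removes an elementwise kernel per call site. A hedged sketch of both patterns (hypothetical shapes; the import matches the test script above):

    import torch
    from colossalai.kernel.triton import rms_layernorm  # same import as above

    hidden = torch.randn(16, 1024, dtype=torch.float16, device="cuda")
    residual = torch.randn_like(hidden)
    weight = torch.ones(1024, dtype=torch.float16, device="cuda")

    # Unfused: the residual add is its own elementwise kernel, then the norm
    summed = hidden + residual
    out_unfused, _ = rms_layernorm(summed, weight, eps=1e-5)
    # Fused: one launch normalizes (hidden + residual) and returns the
    # updated sum to feed the next layer's residual input
    out_fused, residual = rms_layernorm(hidden, weight, eps=1e-5, residual=residual)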
@@ -45,9 +59,19 @@ configs = [
         x_names=["SEQUENCE_TOTAL"],
         x_vals=[i for i in range(128, 1025, 128)],
         line_arg="provider",
-        line_vals=["torch_rms_layernorm", "triton_rms_layernorm"],
-        line_names=["torch_rms_layernorm", "triton_rms_layernorm"],
-        styles=[("red", "-"), ("blue", "-")],
+        line_vals=[
+            "vllm_rms_layernorm",
+            "triton_rms_layernorm",
+            "triton_rms_layernorm_with_residual",
+            "vllm_rms_layernorm_with_residual",
+        ],
+        line_names=[
+            "vllm_rms_layernorm",
+            "triton_rms_layernorm",
+            "triton_rms_layernorm_with_residual",
+            "vllm_rms_layernorm_with_residual",
+        ],
+        styles=[("red", "-"), ("blue", "-"), ("yellow", "-"), ("green", "-")],
         ylabel="ms",
         plot_name=f"RMSNorm benchmarking results",
         args={"HIDDEN_SIZE": 1024},
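For reference, these keyword arguments are fields of `triton.testing.Benchmark`: `perf_report` calls the decorated function once per (`x_vals` entry, `line_vals` provider) pair and plots one line per provider. A minimal self-contained example of that API (illustrative only; the repo's script defines its own configs and function):

    import torch
    import triton

    example_configs = [
        triton.testing.Benchmark(
            x_names=["N"],
            x_vals=[2**i for i in range(10, 14)],
            line_arg="provider",
            line_vals=["torch_add"],
            line_names=["torch_add"],
            styles=[("blue", "-")],
            ylabel="ms",
            plot_name="example_add",
            args={},
        )
    ]

    @triton.testing.perf_report(example_configs)
    def bench(N, provider):
        x = torch.randn(N, device="cuda")
        return triton.testing.do_bench(lambda: x + x)  # runtime in ms

    bench.run(print_data=True)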
@@ -68,13 +92,18 @@ def benchmark_rms_layernorm(
     eps = 1e-5
     x_shape = (SEQUENCE_TOTAL, HIDDEN_SIZE)
     w_shape = (x_shape[-1],)
+    residual = torch.rand(x_shape, dtype=dtype, device="cuda")
     weight = torch.ones(w_shape, dtype=dtype, device="cuda")
-    torch_norm = LlamaRMSNorm(hidden_size=HIDDEN_SIZE).to(dtype=dtype, device="cuda")
+    vllm_norm = RMSNorm(hidden_size=HIDDEN_SIZE, eps=eps).to(dtype=dtype, device="cuda")
     x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
-    if provider == "torch_rms_layernorm":
-        fn = lambda: torch_norm(x)
+    if provider == "vllm_rms_layernorm":
+        fn = lambda: vllm_norm(x)
     elif provider == "triton_rms_layernorm":
         fn = lambda: rms_layernorm(x, weight, eps=eps)
+    elif provider == "vllm_rms_layernorm_with_residual":
+        fn = lambda: vllm_norm(x, residual=residual)
+    elif provider == "triton_rms_layernorm_with_residual":
+        fn = lambda: rms_layernorm(x, weight, eps=eps, residual=residual)
     else:
         raise ValueError("Undefined provider.")
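One measurement caveat: the `_with_residual` providers reuse the same `x` and `residual` buffers across timing repetitions, and the fused paths update the residual as they run (vLLM's op does so in place). That is fine for throughput numbers; if fresh inputs per call mattered, the closures would have to clone them, at the cost of timing the clones too, e.g.:

    # hypothetical variant with fresh inputs per call (clone cost is timed too)
    fn = lambda: rms_layernorm(x.clone(), weight, eps=eps, residual=residual.clone())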