feat: add RMSNorm CUDA kernel, unit test, and benchmark script (#5417)

This commit is contained in:
Steve Luo
2024-03-08 16:21:12 +08:00
committed by GitHub
parent 2b28b54ac6
commit f7aecc0c6b
8 changed files with 244 additions and 49 deletions

View File

@@ -1,14 +1,14 @@
import torch
import triton
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import rms_layernorm
try:
import triton # noqa
except ImportError:
print("please install triton from https://github.com/openai/triton")
inference_ops = InferenceOpsLoader().load()
# Triton benchmark plot attributions
configs = [
@@ -19,16 +19,20 @@ configs = [
line_vals=[
"vllm_rms_layernorm",
"triton_rms_layernorm",
"triton_rms_layernorm_with_residual",
"cuda_rms_layernorm",
"vllm_rms_layernorm_with_residual",
"triton_rms_layernorm_with_residual",
"cuda_rms_layernorm_with_residual",
],
line_names=[
"vllm_rms_layernorm",
"triton_rms_layernorm",
"triton_rms_layernorm_with_residual",
"cuda_rms_layernorm",
"vllm_rms_layernorm_with_residual",
"triton_rms_layernorm_with_residual",
"cuda_rms_layernorm_with_residual",
],
styles=[("red", "-"), ("blue", "-"), ("yellow", "-"), ("green", "-")],
styles=[("red", "-"), ("blue", "-"), ("yellow", "-"), ("red", "--"), ("blue", "--"), ("yellow", "--")],
ylabel="ms",
plot_name=f"RMSNorm benchmarking results",
args={"HIDDEN_SIZE": 1024},
@@ -62,10 +66,15 @@ def benchmark_rms_layernorm(
fn = lambda: vllm_norm(x)
elif provider == "triton_rms_layernorm":
fn = lambda: rms_layernorm(x, weight, eps=eps)
elif provider == "cuda_rms_layernorm":
out = torch.empty_like(x)
fn = lambda: inference_ops.rms_layernorm(out, x, weight, eps)
elif provider == "vllm_rms_layernorm_with_residual":
fn = lambda: vllm_norm(x, residual=residual)
elif provider == "triton_rms_layernorm_with_residual":
fn = lambda: rms_layernorm(x, weight, eps=eps, residual=residual)
elif provider == "cuda_rms_layernorm_with_residual":
fn = lambda: inference_ops.fused_add_rms_layernorm(x, residual, weight, eps)
else:
raise ValueError("Undefined provider.")