mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-01 09:07:51 +00:00
feat rmsnorm cuda kernel and add unittest, benchmark script (#5417)
This commit is contained in:
@@ -1,14 +1,14 @@
|
||||
import torch
|
||||
import triton
|
||||
|
||||
from colossalai.kernel.kernel_loader import InferenceOpsLoader
|
||||
from colossalai.kernel.triton import rms_layernorm
|
||||
|
||||
try:
|
||||
import triton # noqa
|
||||
|
||||
except ImportError:
|
||||
print("please install triton from https://github.com/openai/triton")
|
||||
|
||||
inference_ops = InferenceOpsLoader().load()
|
||||
|
||||
# Triton benchmark plot attributions
|
||||
configs = [
|
||||
@@ -19,16 +19,20 @@ configs = [
|
||||
line_vals=[
|
||||
"vllm_rms_layernorm",
|
||||
"triton_rms_layernorm",
|
||||
"triton_rms_layernorm_with_residual",
|
||||
"cuda_rms_layernorm",
|
||||
"vllm_rms_layernorm_with_residual",
|
||||
"triton_rms_layernorm_with_residual",
|
||||
"cuda_rms_layernorm_with_residual",
|
||||
],
|
||||
line_names=[
|
||||
"vllm_rms_layernorm",
|
||||
"triton_rms_layernorm",
|
||||
"triton_rms_layernorm_with_residual",
|
||||
"cuda_rms_layernorm",
|
||||
"vllm_rms_layernorm_with_residual",
|
||||
"triton_rms_layernorm_with_residual",
|
||||
"cuda_rms_layernorm_with_residual",
|
||||
],
|
||||
styles=[("red", "-"), ("blue", "-"), ("yellow", "-"), ("green", "-")],
|
||||
styles=[("red", "-"), ("blue", "-"), ("yellow", "-"), ("red", "--"), ("blue", "--"), ("yellow", "--")],
|
||||
ylabel="ms",
|
||||
plot_name=f"RMSNorm benchmarking results",
|
||||
args={"HIDDEN_SIZE": 1024},
|
||||
@@ -62,10 +66,15 @@ def benchmark_rms_layernorm(
|
||||
fn = lambda: vllm_norm(x)
|
||||
elif provider == "triton_rms_layernorm":
|
||||
fn = lambda: rms_layernorm(x, weight, eps=eps)
|
||||
elif provider == "cuda_rms_layernorm":
|
||||
out = torch.empty_like(x)
|
||||
fn = lambda: inference_ops.rms_layernorm(out, x, weight, eps)
|
||||
elif provider == "vllm_rms_layernorm_with_residual":
|
||||
fn = lambda: vllm_norm(x, residual=residual)
|
||||
elif provider == "triton_rms_layernorm_with_residual":
|
||||
fn = lambda: rms_layernorm(x, weight, eps=eps, residual=residual)
|
||||
elif provider == "cuda_rms_layernorm_with_residual":
|
||||
fn = lambda: inference_ops.fused_add_rms_layernorm(x, residual, weight, eps)
|
||||
else:
|
||||
raise ValueError("Undefined provider.")
|
||||
|
Reference in New Issue
Block a user