mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-12 12:47:21 +00:00
[kernel] Add RMSLayerNorm triton kernel (#5262)
* add layerrmsnorm triton kernel * add layerrmsnorm kernel * modify the atol and rtol in test file * Remove the logics of mean computations, and update the name of ther kernel functions and files * add benchmark of rms norm
This commit is contained in:
@@ -1,43 +0,0 @@
|
||||
import pytest
|
||||
import torch
|
||||
from packaging import version
|
||||
|
||||
from colossalai.kernel.triton import layer_norm
|
||||
from colossalai.testing.utils import parameterize
|
||||
|
||||
try:
|
||||
pass
|
||||
|
||||
HAS_TRITON = True
|
||||
except ImportError:
|
||||
HAS_TRITON = False
|
||||
print("please install triton from https://github.com/openai/triton")
|
||||
|
||||
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
|
||||
)
|
||||
@parameterize("M", [2, 4, 8, 16])
|
||||
@parameterize("N", [64, 128])
|
||||
def test_layer_norm(M, N):
|
||||
dtype = torch.float16
|
||||
eps = 1e-5
|
||||
x_shape = (M, N)
|
||||
w_shape = (x_shape[-1],)
|
||||
weight = torch.rand(w_shape, dtype=dtype, device="cuda")
|
||||
bias = torch.rand(w_shape, dtype=dtype, device="cuda")
|
||||
x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
|
||||
|
||||
y_triton = layer_norm(x, weight, bias, eps)
|
||||
y_torch = torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps).to(dtype)
|
||||
|
||||
assert y_triton.shape == y_torch.shape
|
||||
assert y_triton.dtype == y_torch.dtype
|
||||
print("max delta: ", torch.max(torch.abs(y_triton - y_torch)))
|
||||
assert torch.allclose(y_triton, y_torch, atol=1e-2, rtol=0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_layer_norm()
|
91
tests/test_infer_ops/triton/test_rmsnorm_triton.py
Normal file
91
tests/test_infer_ops/triton/test_rmsnorm_triton.py
Normal file
@@ -0,0 +1,91 @@
|
||||
import pytest
|
||||
import torch
|
||||
from packaging import version
|
||||
import triton
|
||||
|
||||
from colossalai.kernel.triton import rms_layernorm
|
||||
from colossalai.testing.utils import parameterize
|
||||
from transformers.models.llama.modeling_llama import LlamaRMSNorm
|
||||
|
||||
try:
|
||||
pass
|
||||
|
||||
HAS_TRITON = True
|
||||
except ImportError:
|
||||
HAS_TRITON = False
|
||||
print("please install triton from https://github.com/openai/triton")
|
||||
|
||||
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
|
||||
)
|
||||
@parameterize("M", [2, 4, 8, 16])
|
||||
@parameterize("N", [64, 128])
|
||||
def test_layer_norm(M, N):
|
||||
|
||||
dtype = torch.float16
|
||||
eps = 1e-5
|
||||
x_shape = (M, N)
|
||||
w_shape = (x_shape[-1],)
|
||||
weight = torch.ones(w_shape, dtype=dtype, device="cuda")
|
||||
rms_norm = LlamaRMSNorm(hidden_size=N, eps=eps).cuda()
|
||||
x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
|
||||
|
||||
y_triton = rms_layernorm(x, weight, eps=eps)
|
||||
y_llama = rms_norm.forward(x).to(dtype)
|
||||
|
||||
assert torch.allclose(y_triton, y_llama, atol=1e-5, rtol=1e-5)
|
||||
|
||||
|
||||
|
||||
# Triton benchmark plot attributions
|
||||
configs = [
|
||||
triton.testing.Benchmark(
|
||||
x_names=["SEQUENCE_TOTAL"],
|
||||
x_vals=[i for i in range(128, 1025, 128)],
|
||||
line_arg="provider",
|
||||
line_vals=["llama_rms_layernorm", "triton_rms_layernorm"],
|
||||
line_names=["llama_rms_layernorm", "triton_rms_layernorm"],
|
||||
styles=[("red", "-"), ("blue", "-")],
|
||||
ylabel="ms",
|
||||
plot_name=f"RMSNorm benchmarking results",
|
||||
args={"HIDDEN_SIZE": 1024},
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@triton.testing.perf_report(configs)
|
||||
def benchmark_rms_layernorm(
|
||||
provider: str,
|
||||
SEQUENCE_TOTAL: int,
|
||||
HIDDEN_SIZE: int,
|
||||
):
|
||||
warmup = 10
|
||||
rep = 100
|
||||
|
||||
dtype = torch.float16
|
||||
eps = 1e-5
|
||||
x_shape = (SEQUENCE_TOTAL, HIDDEN_SIZE)
|
||||
w_shape = (x_shape[-1],)
|
||||
weight = torch.ones(w_shape, dtype=dtype, device="cuda")
|
||||
rms_norm = LlamaRMSNorm(hidden_size=HIDDEN_SIZE, eps=eps).cuda()
|
||||
x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
|
||||
|
||||
if provider == "llama_rms_layernorm":
|
||||
fn = lambda: rms_norm.forward(x).to(dtype)
|
||||
elif provider == "triton_rms_layernorm":
|
||||
fn = lambda: rms_layernorm(x, weight, eps=eps)
|
||||
else:
|
||||
raise ValueError("Undefined provider.")
|
||||
|
||||
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
|
||||
|
||||
return ms
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_layer_norm()
|
||||
# benchmark_rms_layernorm.run(save_path=".")
|
Reference in New Issue
Block a user