mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-05-31 11:25:27 +00:00
[kernel] Add RMSLayerNorm triton kernel (#5262)
* add layerrmsnorm triton kernel * add layerrmsnorm kernel * modify the atol and rtol in test file * Remove the logics of mean computations, and update the name of ther kernel functions and files * add benchmark of rms norm
This commit is contained in:
parent
86b63f720c
commit
5ae9099f92
@ -10,7 +10,7 @@ except ImportError:
|
|||||||
if HAS_TRITON:
|
if HAS_TRITON:
|
||||||
from .context_attn_unpad import context_attention_unpadded
|
from .context_attn_unpad import context_attention_unpadded
|
||||||
from .flash_decoding import flash_decoding_fwd
|
from .flash_decoding import flash_decoding_fwd
|
||||||
from .fused_layernorm import layer_norm
|
from .rms_layernorm import rms_layernorm
|
||||||
from .gptq_triton import gptq_fused_linear_triton
|
from .gptq_triton import gptq_fused_linear_triton
|
||||||
from .kvcache_copy import copy_kv_to_blocked_cache
|
from .kvcache_copy import copy_kv_to_blocked_cache
|
||||||
from .no_pad_rotary_embedding import rotary_embedding
|
from .no_pad_rotary_embedding import rotary_embedding
|
||||||
@ -21,7 +21,7 @@ if HAS_TRITON:
|
|||||||
"flash_decoding_fwd",
|
"flash_decoding_fwd",
|
||||||
"copy_kv_to_blocked_cache",
|
"copy_kv_to_blocked_cache",
|
||||||
"softmax",
|
"softmax",
|
||||||
"layer_norm",
|
"rms_layernorm",
|
||||||
"gptq_fused_linear_triton",
|
"gptq_fused_linear_triton",
|
||||||
"rotary_embedding",
|
"rotary_embedding",
|
||||||
]
|
]
|
||||||
|
@ -14,34 +14,28 @@ if HAS_TRITON:
|
|||||||
# https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
|
# https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def _layer_norm_fwd_fused(
|
def _rmsnorm_kernel(
|
||||||
X, # pointer to the input
|
X, # pointer to the input
|
||||||
Y, # pointer to the output
|
Y, # pointer to the output
|
||||||
W, # pointer to the weights
|
W, # pointer to the weights
|
||||||
B, # pointer to the biases
|
|
||||||
stride, # how much to increase the pointer when moving by 1 row
|
stride, # how much to increase the pointer when moving by 1 row
|
||||||
N, # number of columns in X
|
N, # number of columns in X
|
||||||
eps, # epsilon to avoid division by zero
|
eps, # epsilon to avoid division by zero
|
||||||
BLOCK_SIZE: tl.constexpr,
|
BLOCK_SIZE: tl.constexpr,
|
||||||
):
|
):
|
||||||
|
|
||||||
|
# This triton kernel implements Root Mean Square Layer Norm (RMSNorm).
|
||||||
|
|
||||||
# Map the program id to the row of X and Y it should compute.
|
# Map the program id to the row of X and Y it should compute.
|
||||||
row = tl.program_id(0)
|
row = tl.program_id(0)
|
||||||
Y += row * stride
|
Y += row * stride
|
||||||
X += row * stride
|
X += row * stride
|
||||||
# Compute mean
|
|
||||||
mean = 0
|
|
||||||
_mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
|
|
||||||
for off in range(0, N, BLOCK_SIZE):
|
|
||||||
cols = off + tl.arange(0, BLOCK_SIZE)
|
|
||||||
a = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
|
|
||||||
_mean += a
|
|
||||||
mean = tl.sum(_mean, axis=0) / N
|
|
||||||
# Compute variance
|
# Compute variance
|
||||||
_var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
|
_var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
|
||||||
for off in range(0, N, BLOCK_SIZE):
|
for off in range(0, N, BLOCK_SIZE):
|
||||||
cols = off + tl.arange(0, BLOCK_SIZE)
|
cols = off + tl.arange(0, BLOCK_SIZE)
|
||||||
x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
|
x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
|
||||||
x = tl.where(cols < N, x - mean, 0.0)
|
x = tl.where(cols < N, x, 0.0)
|
||||||
_var += x * x
|
_var += x * x
|
||||||
var = tl.sum(_var, axis=0) / N
|
var = tl.sum(_var, axis=0) / N
|
||||||
rstd = 1 / tl.sqrt(var + eps)
|
rstd = 1 / tl.sqrt(var + eps)
|
||||||
@ -50,15 +44,14 @@ if HAS_TRITON:
|
|||||||
cols = off + tl.arange(0, BLOCK_SIZE)
|
cols = off + tl.arange(0, BLOCK_SIZE)
|
||||||
mask = cols < N
|
mask = cols < N
|
||||||
w = tl.load(W + cols, mask=mask)
|
w = tl.load(W + cols, mask=mask)
|
||||||
b = tl.load(B + cols, mask=mask)
|
|
||||||
x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32)
|
x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32)
|
||||||
x_hat = (x - mean) * rstd
|
x_hat = x * rstd
|
||||||
y = x_hat * w + b
|
y = x_hat * w
|
||||||
# Write output
|
# Write output
|
||||||
tl.store(Y + cols, y.to(tl.float16), mask=mask)
|
tl.store(Y + cols, y.to(tl.float16), mask=mask)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def layer_norm(x, weight, bias, eps):
|
def rms_layernorm(x, weight, eps):
|
||||||
# allocate output
|
# allocate output
|
||||||
y = torch.empty_like(x)
|
y = torch.empty_like(x)
|
||||||
# reshape input data into 2D tensor
|
# reshape input data into 2D tensor
|
||||||
@ -72,7 +65,7 @@ if HAS_TRITON:
|
|||||||
# heuristics for number of warps
|
# heuristics for number of warps
|
||||||
num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
|
num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
|
||||||
# enqueue kernel
|
# enqueue kernel
|
||||||
_layer_norm_fwd_fused[(M,)](
|
_rmsnorm_kernel[(M,)](
|
||||||
x_arg, y, weight, bias, x_arg.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps
|
x_arg, y, weight, x_arg.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps
|
||||||
)
|
)
|
||||||
return y
|
return y
|
@ -1,43 +0,0 @@
|
|||||||
import pytest
|
|
||||||
import torch
|
|
||||||
from packaging import version
|
|
||||||
|
|
||||||
from colossalai.kernel.triton import layer_norm
|
|
||||||
from colossalai.testing.utils import parameterize
|
|
||||||
|
|
||||||
try:
|
|
||||||
pass
|
|
||||||
|
|
||||||
HAS_TRITON = True
|
|
||||||
except ImportError:
|
|
||||||
HAS_TRITON = False
|
|
||||||
print("please install triton from https://github.com/openai/triton")
|
|
||||||
|
|
||||||
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
|
||||||
not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
|
|
||||||
)
|
|
||||||
@parameterize("M", [2, 4, 8, 16])
|
|
||||||
@parameterize("N", [64, 128])
|
|
||||||
def test_layer_norm(M, N):
|
|
||||||
dtype = torch.float16
|
|
||||||
eps = 1e-5
|
|
||||||
x_shape = (M, N)
|
|
||||||
w_shape = (x_shape[-1],)
|
|
||||||
weight = torch.rand(w_shape, dtype=dtype, device="cuda")
|
|
||||||
bias = torch.rand(w_shape, dtype=dtype, device="cuda")
|
|
||||||
x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
|
|
||||||
|
|
||||||
y_triton = layer_norm(x, weight, bias, eps)
|
|
||||||
y_torch = torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps).to(dtype)
|
|
||||||
|
|
||||||
assert y_triton.shape == y_torch.shape
|
|
||||||
assert y_triton.dtype == y_torch.dtype
|
|
||||||
print("max delta: ", torch.max(torch.abs(y_triton - y_torch)))
|
|
||||||
assert torch.allclose(y_triton, y_torch, atol=1e-2, rtol=0)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
test_layer_norm()
|
|
91
tests/test_infer_ops/triton/test_rmsnorm_triton.py
Normal file
91
tests/test_infer_ops/triton/test_rmsnorm_triton.py
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from packaging import version
|
||||||
|
import triton
|
||||||
|
|
||||||
|
from colossalai.kernel.triton import rms_layernorm
|
||||||
|
from colossalai.testing.utils import parameterize
|
||||||
|
from transformers.models.llama.modeling_llama import LlamaRMSNorm
|
||||||
|
|
||||||
|
try:
|
||||||
|
pass
|
||||||
|
|
||||||
|
HAS_TRITON = True
|
||||||
|
except ImportError:
|
||||||
|
HAS_TRITON = False
|
||||||
|
print("please install triton from https://github.com/openai/triton")
|
||||||
|
|
||||||
|
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
|
||||||
|
)
|
||||||
|
@parameterize("M", [2, 4, 8, 16])
|
||||||
|
@parameterize("N", [64, 128])
|
||||||
|
def test_layer_norm(M, N):
|
||||||
|
|
||||||
|
dtype = torch.float16
|
||||||
|
eps = 1e-5
|
||||||
|
x_shape = (M, N)
|
||||||
|
w_shape = (x_shape[-1],)
|
||||||
|
weight = torch.ones(w_shape, dtype=dtype, device="cuda")
|
||||||
|
rms_norm = LlamaRMSNorm(hidden_size=N, eps=eps).cuda()
|
||||||
|
x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
|
||||||
|
|
||||||
|
y_triton = rms_layernorm(x, weight, eps=eps)
|
||||||
|
y_llama = rms_norm.forward(x).to(dtype)
|
||||||
|
|
||||||
|
assert torch.allclose(y_triton, y_llama, atol=1e-5, rtol=1e-5)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# Triton benchmark plot attributions
|
||||||
|
configs = [
|
||||||
|
triton.testing.Benchmark(
|
||||||
|
x_names=["SEQUENCE_TOTAL"],
|
||||||
|
x_vals=[i for i in range(128, 1025, 128)],
|
||||||
|
line_arg="provider",
|
||||||
|
line_vals=["llama_rms_layernorm", "triton_rms_layernorm"],
|
||||||
|
line_names=["llama_rms_layernorm", "triton_rms_layernorm"],
|
||||||
|
styles=[("red", "-"), ("blue", "-")],
|
||||||
|
ylabel="ms",
|
||||||
|
plot_name=f"RMSNorm benchmarking results",
|
||||||
|
args={"HIDDEN_SIZE": 1024},
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@triton.testing.perf_report(configs)
|
||||||
|
def benchmark_rms_layernorm(
|
||||||
|
provider: str,
|
||||||
|
SEQUENCE_TOTAL: int,
|
||||||
|
HIDDEN_SIZE: int,
|
||||||
|
):
|
||||||
|
warmup = 10
|
||||||
|
rep = 100
|
||||||
|
|
||||||
|
dtype = torch.float16
|
||||||
|
eps = 1e-5
|
||||||
|
x_shape = (SEQUENCE_TOTAL, HIDDEN_SIZE)
|
||||||
|
w_shape = (x_shape[-1],)
|
||||||
|
weight = torch.ones(w_shape, dtype=dtype, device="cuda")
|
||||||
|
rms_norm = LlamaRMSNorm(hidden_size=HIDDEN_SIZE, eps=eps).cuda()
|
||||||
|
x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
|
||||||
|
|
||||||
|
if provider == "llama_rms_layernorm":
|
||||||
|
fn = lambda: rms_norm.forward(x).to(dtype)
|
||||||
|
elif provider == "triton_rms_layernorm":
|
||||||
|
fn = lambda: rms_layernorm(x, weight, eps=eps)
|
||||||
|
else:
|
||||||
|
raise ValueError("Undefined provider.")
|
||||||
|
|
||||||
|
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
|
||||||
|
|
||||||
|
return ms
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_layer_norm()
|
||||||
|
# benchmark_rms_layernorm.run(save_path=".")
|
Loading…
Reference in New Issue
Block a user