diff --git a/colossalai/kernel/triton/__init__.py b/colossalai/kernel/triton/__init__.py
index 021ccb9c1..763522453 100644
--- a/colossalai/kernel/triton/__init__.py
+++ b/colossalai/kernel/triton/__init__.py
@@ -10,7 +10,7 @@ except ImportError:
 if HAS_TRITON:
     from .context_attn_unpad import context_attention_unpadded
     from .flash_decoding import flash_decoding_fwd
-    from .fused_layernorm import layer_norm
+    from .rms_layernorm import rms_layernorm
     from .gptq_triton import gptq_fused_linear_triton
     from .kvcache_copy import copy_kv_to_blocked_cache
     from .no_pad_rotary_embedding import rotary_embedding
@@ -21,7 +21,7 @@ if HAS_TRITON:
         "flash_decoding_fwd",
         "copy_kv_to_blocked_cache",
         "softmax",
-        "layer_norm",
+        "rms_layernorm",
         "gptq_fused_linear_triton",
         "rotary_embedding",
     ]
diff --git a/colossalai/kernel/triton/fused_layernorm.py b/colossalai/kernel/triton/rms_layernorm.py
similarity index 74%
rename from colossalai/kernel/triton/fused_layernorm.py
rename to colossalai/kernel/triton/rms_layernorm.py
index 24083b050..b514c7789 100644
--- a/colossalai/kernel/triton/fused_layernorm.py
+++ b/colossalai/kernel/triton/rms_layernorm.py
@@ -14,34 +14,28 @@ if HAS_TRITON:
     # https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
 
     @triton.jit
-    def _layer_norm_fwd_fused(
+    def _rmsnorm_kernel(
         X,  # pointer to the input
         Y,  # pointer to the output
         W,  # pointer to the weights
-        B,  # pointer to the biases
         stride,  # how much to increase the pointer when moving by 1 row
         N,  # number of columns in X
         eps,  # epsilon to avoid division by zero
         BLOCK_SIZE: tl.constexpr,
     ):
+
+        # This triton kernel implements Root Mean Square Layer Norm (RMSNorm).
+
         # Map the program id to the row of X and Y it should compute.
         row = tl.program_id(0)
         Y += row * stride
         X += row * stride
-        # Compute mean
-        mean = 0
-        _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
-        for off in range(0, N, BLOCK_SIZE):
-            cols = off + tl.arange(0, BLOCK_SIZE)
-            a = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
-            _mean += a
-        mean = tl.sum(_mean, axis=0) / N
         # Compute variance
         _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
         for off in range(0, N, BLOCK_SIZE):
             cols = off + tl.arange(0, BLOCK_SIZE)
             x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
-            x = tl.where(cols < N, x - mean, 0.0)
+            x = tl.where(cols < N, x, 0.0)
             _var += x * x
         var = tl.sum(_var, axis=0) / N
         rstd = 1 / tl.sqrt(var + eps)
@@ -50,15 +44,14 @@ if HAS_TRITON:
             cols = off + tl.arange(0, BLOCK_SIZE)
             mask = cols < N
             w = tl.load(W + cols, mask=mask)
-            b = tl.load(B + cols, mask=mask)
             x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32)
-            x_hat = (x - mean) * rstd
-            y = x_hat * w + b
+            x_hat = x * rstd
+            y = x_hat * w
             # Write output
             tl.store(Y + cols, y.to(tl.float16), mask=mask)
 
     @torch.no_grad()
-    def layer_norm(x, weight, bias, eps):
+    def rms_layernorm(x, weight, eps):
         # allocate output
         y = torch.empty_like(x)
         # reshape input data into 2D tensor
@@ -72,7 +65,7 @@ if HAS_TRITON:
         # heuristics for number of warps
         num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
         # enqueue kernel
-        _layer_norm_fwd_fused[(M,)](
-            x_arg, y, weight, bias, x_arg.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps
+        _rmsnorm_kernel[(M,)](
+            x_arg, y, weight, x_arg.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps
         )
         return y
diff --git a/tests/test_infer_ops/triton/test_layernorm_triton.py b/tests/test_infer_ops/triton/test_layernorm_triton.py
deleted file mode 100644
index 7f814e8c9..000000000
--- a/tests/test_infer_ops/triton/test_layernorm_triton.py
+++ /dev/null
@@ -1,43 +0,0 @@
-import pytest
-import torch
-from packaging import version
-
-from colossalai.kernel.triton import layer_norm
-from colossalai.testing.utils import parameterize
-
-try:
-    pass
-
-    HAS_TRITON = True
-except ImportError:
-    HAS_TRITON = False
-    print("please install triton from https://github.com/openai/triton")
-
-TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
-
-
-@pytest.mark.skipif(
-    not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
-)
-@parameterize("M", [2, 4, 8, 16])
-@parameterize("N", [64, 128])
-def test_layer_norm(M, N):
-    dtype = torch.float16
-    eps = 1e-5
-    x_shape = (M, N)
-    w_shape = (x_shape[-1],)
-    weight = torch.rand(w_shape, dtype=dtype, device="cuda")
-    bias = torch.rand(w_shape, dtype=dtype, device="cuda")
-    x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
-
-    y_triton = layer_norm(x, weight, bias, eps)
-    y_torch = torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps).to(dtype)
-
-    assert y_triton.shape == y_torch.shape
-    assert y_triton.dtype == y_torch.dtype
-    print("max delta: ", torch.max(torch.abs(y_triton - y_torch)))
-    assert torch.allclose(y_triton, y_torch, atol=1e-2, rtol=0)
-
-
-if __name__ == "__main__":
-    test_layer_norm()
diff --git a/tests/test_infer_ops/triton/test_rmsnorm_triton.py b/tests/test_infer_ops/triton/test_rmsnorm_triton.py
new file mode 100644
index 000000000..6828151ce
--- /dev/null
+++ b/tests/test_infer_ops/triton/test_rmsnorm_triton.py
@@ -0,0 +1,91 @@
+import pytest
+import torch
+from packaging import version
+import triton
+
+from colossalai.kernel.triton import rms_layernorm
+from colossalai.testing.utils import parameterize
+from transformers.models.llama.modeling_llama import LlamaRMSNorm
+
+try:
+    pass
+
+    HAS_TRITON = True
+except ImportError:
+    HAS_TRITON = False
+    print("please install triton from https://github.com/openai/triton")
+
+TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
+
+
+@pytest.mark.skipif(
+    not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
+)
+@parameterize("M", [2, 4, 8, 16])
+@parameterize("N", [64, 128])
+def test_layer_norm(M, N):
+
+    dtype = torch.float16
+    eps = 1e-5
+    x_shape = (M, N)
+    w_shape = (x_shape[-1],)
+    weight = torch.ones(w_shape, dtype=dtype, device="cuda")
+    rms_norm = LlamaRMSNorm(hidden_size=N, eps=eps).cuda()
+    x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
+
+    y_triton = rms_layernorm(x, weight, eps=eps)
+    y_llama = rms_norm.forward(x).to(dtype)
+
+    assert torch.allclose(y_triton, y_llama, atol=1e-5, rtol=1e-5)
+
+
+
+
+# Triton benchmark plot attributions
+configs = [
+    triton.testing.Benchmark(
+        x_names=["SEQUENCE_TOTAL"],
+        x_vals=[i for i in range(128, 1025, 128)],
+        line_arg="provider",
+        line_vals=["llama_rms_layernorm", "triton_rms_layernorm"],
+        line_names=["llama_rms_layernorm", "triton_rms_layernorm"],
+        styles=[("red", "-"), ("blue", "-")],
+        ylabel="ms",
+        plot_name=f"RMSNorm benchmarking results",
+        args={"HIDDEN_SIZE": 1024},
+    )
+]
+
+
+@triton.testing.perf_report(configs)
+def benchmark_rms_layernorm(
+    provider: str,
+    SEQUENCE_TOTAL: int,
+    HIDDEN_SIZE: int,
+):
+    warmup = 10
+    rep = 100
+
+    dtype = torch.float16
+    eps = 1e-5
+    x_shape = (SEQUENCE_TOTAL, HIDDEN_SIZE)
+    w_shape = (x_shape[-1],)
+    weight = torch.ones(w_shape, dtype=dtype, device="cuda")
+    rms_norm = LlamaRMSNorm(hidden_size=HIDDEN_SIZE, eps=eps).cuda()
+    x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
+
+    if provider == "llama_rms_layernorm":
+        fn = lambda: rms_norm.forward(x).to(dtype)
+    elif provider == "triton_rms_layernorm":
+        fn = lambda: rms_layernorm(x, weight, eps=eps)
+    else:
+        raise ValueError("Undefined provider.")
+
+    ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
+
+    return ms
+
+
+
+
+if __name__ == "__main__":
+    test_layer_norm()
+    # benchmark_rms_layernorm.run(save_path=".")
\ No newline at end of file
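For readers comparing the two kernels: the renamed function computes RMSNorm rather than LayerNorm, i.e. each row of `x` is scaled by `1 / sqrt(mean(x^2) + eps)` and multiplied by `weight`, with no mean subtraction and no bias term. Below is a minimal PyTorch reference sketch of that computation, for illustration only; the helper name `rms_layernorm_ref` is not part of this patch, and the patch itself validates the Triton kernel against `LlamaRMSNorm` from transformers instead.

```python
import torch

def rms_layernorm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """Illustrative RMSNorm reference mirroring what _rmsnorm_kernel computes (not part of the patch)."""
    # Accumulate the mean of squares in float32, as the Triton kernel does.
    x_fp32 = x.float()
    rstd = torch.rsqrt(x_fp32.pow(2).mean(dim=-1, keepdim=True) + eps)
    # No mean subtraction and no bias: y = x * rstd * w, cast back to the input dtype.
    return (x_fp32 * rstd * weight.float()).to(x.dtype)

# Example usage mirroring the new test (requires a CUDA device and the Triton kernel):
# x = -2.3 + 0.5 * torch.randn(16, 128, dtype=torch.float16, device="cuda")
# w = torch.ones(128, dtype=torch.float16, device="cuda")
# assert torch.allclose(rms_layernorm_ref(x, w), rms_layernorm(x, w, eps=1e-5), atol=1e-5, rtol=1e-5)
```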