From ccf72797e3bfafcbfc42870ce24ee484858d4852 Mon Sep 17 00:00:00 2001 From: Steve Luo <36296769+SunflowerAries@users.noreply.github.com> Date: Fri, 19 Apr 2024 15:34:53 +0800 Subject: [PATCH] feat baichuan2 rmsnorm whose hidden size equals to 5120 (#5611) --- examples/inference/benchmark_ops/benchmark_rmsnorm.py | 2 +- extensions/csrc/cuda/rms_layernorm_kernel.cu | 6 ++++++ tests/test_infer/test_ops/cuda/test_rms_layernorm.py | 4 ++-- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/examples/inference/benchmark_ops/benchmark_rmsnorm.py b/examples/inference/benchmark_ops/benchmark_rmsnorm.py index 3b5166af0..deddac8b1 100644 --- a/examples/inference/benchmark_ops/benchmark_rmsnorm.py +++ b/examples/inference/benchmark_ops/benchmark_rmsnorm.py @@ -35,7 +35,7 @@ configs = [ styles=[("red", "-"), ("blue", "-"), ("yellow", "-"), ("red", "--"), ("blue", "--"), ("yellow", "--")], ylabel="ms", plot_name=f"RMSNorm benchmarking results", - args={"HIDDEN_SIZE": 1024}, + args={"HIDDEN_SIZE": 5120}, ) ] diff --git a/extensions/csrc/cuda/rms_layernorm_kernel.cu b/extensions/csrc/cuda/rms_layernorm_kernel.cu index 9183462ad..f109edca4 100644 --- a/extensions/csrc/cuda/rms_layernorm_kernel.cu +++ b/extensions/csrc/cuda/rms_layernorm_kernel.cu @@ -277,6 +277,9 @@ void rms_layernorm( case 2: RMSNORM_LAUNCHER(2, block); break; + case 3: + RMSNORM_LAUNCHER(3, block); + break; case 4: RMSNORM_LAUNCHER(4, block); break; @@ -321,6 +324,9 @@ void fused_add_rms_layernorm( case 2: FUSED_ADD_RMSNORM_LAUNCHER(2, block); break; + case 3: + FUSED_ADD_RMSNORM_LAUNCHER(3, block); + break; case 4: FUSED_ADD_RMSNORM_LAUNCHER(4, block); break; diff --git a/tests/test_infer/test_ops/cuda/test_rms_layernorm.py b/tests/test_infer/test_ops/cuda/test_rms_layernorm.py index d14010600..0b677fff8 100644 --- a/tests/test_infer/test_ops/cuda/test_rms_layernorm.py +++ b/tests/test_infer/test_ops/cuda/test_rms_layernorm.py @@ -9,7 +9,7 @@ inference_ops = InferenceOpsLoader().load() @pytest.mark.parametrize("M", [2, 4, 8, 16]) -@pytest.mark.parametrize("N", [64, 128, 512]) +@pytest.mark.parametrize("N", [64, 128, 512, 5120]) def test_rms_layernorm(M: int, N: int): torch.manual_seed(123) torch.cuda.empty_cache() @@ -48,4 +48,4 @@ def test_rms_layernorm(M: int, N: int): if __name__ == "__main__": - test_rms_layernorm(16, 512) + test_rms_layernorm(16, 5120)