From 19061188c396d851ef17bc34b526e2f2b4fc1479 Mon Sep 17 00:00:00 2001 From: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com> Date: Mon, 26 Feb 2024 16:17:47 +0800 Subject: [PATCH] [Infer/Fix] Fix Dependency in test - RMSNorm kernel (#5399) fix dependency in pytest --- tests/test_infer/test_ops/triton/test_rmsnorm_triton.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/test_infer/test_ops/triton/test_rmsnorm_triton.py b/tests/test_infer/test_ops/triton/test_rmsnorm_triton.py index 5ce852164..66e1745d8 100644 --- a/tests/test_infer/test_ops/triton/test_rmsnorm_triton.py +++ b/tests/test_infer/test_ops/triton/test_rmsnorm_triton.py @@ -3,13 +3,12 @@ import torch import triton from packaging import version from transformers.models.llama.modeling_llama import LlamaRMSNorm -from vllm.model_executor.layers.layernorm import RMSNorm from colossalai.kernel.triton import rms_layernorm from colossalai.testing.utils import parameterize try: - pass + import triton # noqa HAS_TRITON = True except ImportError: @@ -85,6 +84,11 @@ def benchmark_rms_layernorm( SEQUENCE_TOTAL: int, HIDDEN_SIZE: int, ): + try: + from vllm.model_executor.layers.layernorm import RMSNorm + except ImportError: + raise ImportError("Please install vllm from https://github.com/vllm-project/vllm") + warmup = 10 rep = 1000