diff --git a/tests/test_infer/test_ops/triton/test_rmsnorm_triton.py b/tests/test_infer/test_ops/triton/test_rmsnorm_triton.py index 5ce852164..66e1745d8 100644 --- a/tests/test_infer/test_ops/triton/test_rmsnorm_triton.py +++ b/tests/test_infer/test_ops/triton/test_rmsnorm_triton.py @@ -3,13 +3,12 @@ import torch import triton from packaging import version from transformers.models.llama.modeling_llama import LlamaRMSNorm -from vllm.model_executor.layers.layernorm import RMSNorm from colossalai.kernel.triton import rms_layernorm from colossalai.testing.utils import parameterize try: - pass + import triton # noqa HAS_TRITON = True except ImportError: @@ -85,6 +84,11 @@ def benchmark_rms_layernorm( SEQUENCE_TOTAL: int, HIDDEN_SIZE: int, ): + try: + from vllm.model_executor.layers.layernorm import RMSNorm + except ImportError: + raise ImportError("Please install vllm from https://github.com/vllm-project/vllm") + warmup = 10 rep = 1000