[Infer/Fix] Fix Dependency in test - RMSNorm kernel (#5399)

fix dependency in pytest
This commit is contained in:
Yuanheng Zhao 2024-02-26 16:17:47 +08:00 committed by GitHub
parent bc1da87366
commit 19061188c3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -3,13 +3,12 @@ import torch
import triton
from packaging import version
from transformers.models.llama.modeling_llama import LlamaRMSNorm
from vllm.model_executor.layers.layernorm import RMSNorm
from colossalai.kernel.triton import rms_layernorm
from colossalai.testing.utils import parameterize
try:
pass
import triton # noqa
HAS_TRITON = True
except ImportError:
@ -85,6 +84,11 @@ def benchmark_rms_layernorm(
SEQUENCE_TOTAL: int,
HIDDEN_SIZE: int,
):
try:
from vllm.model_executor.layers.layernorm import RMSNorm
except ImportError:
raise ImportError("Please install vllm from https://github.com/vllm-project/vllm")
warmup = 10
rep = 1000