mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-04-28 11:45:23 +00:00
[Infer/Fix] Fix Dependency in test - RMSNorm kernel (#5399)
fix dependency in pytest
This commit is contained in:
parent
bc1da87366
commit
19061188c3
@ -3,13 +3,12 @@ import torch
|
||||
import triton
|
||||
from packaging import version
|
||||
from transformers.models.llama.modeling_llama import LlamaRMSNorm
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
|
||||
from colossalai.kernel.triton import rms_layernorm
|
||||
from colossalai.testing.utils import parameterize
|
||||
|
||||
try:
|
||||
pass
|
||||
import triton # noqa
|
||||
|
||||
HAS_TRITON = True
|
||||
except ImportError:
|
||||
@ -85,6 +84,11 @@ def benchmark_rms_layernorm(
|
||||
SEQUENCE_TOTAL: int,
|
||||
HIDDEN_SIZE: int,
|
||||
):
|
||||
try:
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
except ImportError:
|
||||
raise ImportError("Please install vllm from https://github.com/vllm-project/vllm")
|
||||
|
||||
warmup = 10
|
||||
rep = 1000
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user