mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-20 12:43:55 +00:00
[Infer/Fix] Fix Dependency in test - RMSNorm kernel (#5399)
fix dependency in pytest
This commit is contained in:
parent
bc1da87366
commit
19061188c3
@ -3,13 +3,12 @@ import torch
|
|||||||
import triton
|
import triton
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from transformers.models.llama.modeling_llama import LlamaRMSNorm
|
from transformers.models.llama.modeling_llama import LlamaRMSNorm
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
|
||||||
|
|
||||||
from colossalai.kernel.triton import rms_layernorm
|
from colossalai.kernel.triton import rms_layernorm
|
||||||
from colossalai.testing.utils import parameterize
|
from colossalai.testing.utils import parameterize
|
||||||
|
|
||||||
try:
|
try:
|
||||||
pass
|
import triton # noqa
|
||||||
|
|
||||||
HAS_TRITON = True
|
HAS_TRITON = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@ -85,6 +84,11 @@ def benchmark_rms_layernorm(
|
|||||||
SEQUENCE_TOTAL: int,
|
SEQUENCE_TOTAL: int,
|
||||||
HIDDEN_SIZE: int,
|
HIDDEN_SIZE: int,
|
||||||
):
|
):
|
||||||
|
try:
|
||||||
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError("Please install vllm from https://github.com/vllm-project/vllm")
|
||||||
|
|
||||||
warmup = 10
|
warmup = 10
|
||||||
rep = 1000
|
rep = 1000
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user