[Fix] Fix & Update Inference Tests (compatibility w/ main)

This commit is contained in:
Yuanheng Zhao
2024-05-05 16:28:56 +00:00
parent 56ed09aba5
commit 8754abae24
30 changed files with 32 additions and 30 deletions

View File

@@ -14,7 +14,6 @@ from colossalai.inference.core.engine import InferenceEngine
from colossalai.inference.modeling.policy import NoPaddingBaichuanModelInferPolicy
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
# BAICHUAN_MODEL_NAME_OR_PATH = "baichuan-inc/Baichuan2-7B-Base"
BAICHUAN_MODEL_NAME_OR_PATH = "baichuan-inc/Baichuan2-13B-Base"
@@ -87,7 +86,7 @@ def run_engine(world_size, **kwargs):
def run_dist(rank, world_size, port, func_to_run, ret=None, **kwargs):
colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
if ret:
ret[rank] = func_to_run(**kwargs)
@@ -99,7 +98,7 @@ def run_dist(rank, world_size, port, func_to_run, ret=None, **kwargs):
@parameterize("prompt_template", [None, "baichuan"])
@parameterize("do_sample", [False])
@parameterize("use_cuda_kernel", [True])
def test_tp_engine(prompt_template, do_sample, use_cuda_kernel):
def check_tp_engine(prompt_template, do_sample, use_cuda_kernel):
kwargs1 = {
"use_engine": True,
"prompt_template": prompt_template,
@@ -132,7 +131,7 @@ def test_tp_engine(prompt_template, do_sample, use_cuda_kernel):
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_inference_engine():
test_tp_engine()
check_tp_engine()
if __name__ == "__main__":