mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-15 22:19:38 +00:00
[Fix] Fix & Update Inference Tests (compatibility w/ main)
This commit is contained in:
@@ -14,7 +14,6 @@ from colossalai.inference.core.engine import InferenceEngine
|
||||
from colossalai.inference.modeling.policy import NoPaddingBaichuanModelInferPolicy
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||
|
||||
# BAICHUAN_MODEL_NAME_OR_PATH = "baichuan-inc/Baichuan2-7B-Base"
|
||||
BAICHUAN_MODEL_NAME_OR_PATH = "baichuan-inc/Baichuan2-13B-Base"
|
||||
|
||||
|
||||
@@ -87,7 +86,7 @@ def run_engine(world_size, **kwargs):
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port, func_to_run, ret=None, **kwargs):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
|
||||
colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
|
||||
|
||||
if ret:
|
||||
ret[rank] = func_to_run(**kwargs)
|
||||
@@ -99,7 +98,7 @@ def run_dist(rank, world_size, port, func_to_run, ret=None, **kwargs):
|
||||
@parameterize("prompt_template", [None, "baichuan"])
|
||||
@parameterize("do_sample", [False])
|
||||
@parameterize("use_cuda_kernel", [True])
|
||||
def test_tp_engine(prompt_template, do_sample, use_cuda_kernel):
|
||||
def check_tp_engine(prompt_template, do_sample, use_cuda_kernel):
|
||||
kwargs1 = {
|
||||
"use_engine": True,
|
||||
"prompt_template": prompt_template,
|
||||
@@ -132,7 +131,7 @@ def test_tp_engine(prompt_template, do_sample, use_cuda_kernel):
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_inference_engine():
|
||||
test_tp_engine()
|
||||
check_tp_engine()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Reference in New Issue
Block a user