diff --git a/tests/test_infer/test_inference_engine.py b/tests/test_infer/test_inference_engine.py
index 6c7604629..a0ac738c0 100644
--- a/tests/test_infer/test_inference_engine.py
+++ b/tests/test_infer/test_inference_engine.py
@@ -20,7 +20,7 @@ def setup_seed(seed):
     random.seed(seed)
 
 
-def check_inference_engine(use_engine=False, prompt_template=None):
+def check_inference_engine(use_engine=False, prompt_template=None, do_sample=False):
     setup_seed(20)
     tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
     model = LlamaForCausalLM(
@@ -29,14 +29,13 @@ def check_inference_engine(use_engine=False, prompt_template=None):
         )
     ).cuda()
     model = model.eval()
-
     inputs = [
         "介绍一下今天的北京,比如故宫,天安门,长城或者其他的一些景点,",
         "介绍一下武汉,",
     ]
 
     output_len = 38
-    do_sample = False
+    do_sample = do_sample
     top_p = 0.5
     top_k = 50
 
@@ -81,9 +80,10 @@ def check_inference_engine(use_engine=False, prompt_template=None):
 
 
 @parameterize("prompt_template", [None, "llama"])
-def check_output_consistency(prompt_template):
-    cai_outputs = check_inference_engine(use_engine=True, prompt_template=prompt_template)
-    transformer_outputs = check_inference_engine(use_engine=False, prompt_template=prompt_template)
+@parameterize("do_sample", [True, False])
+def check_output_consistency(prompt_template, do_sample):
+    cai_outputs = check_inference_engine(use_engine=True, prompt_template=prompt_template, do_sample=do_sample)
+    transformer_outputs = check_inference_engine(use_engine=False, prompt_template=prompt_template, do_sample=do_sample)
     for s1, s2 in zip(cai_outputs, transformer_outputs):
         assert s1 == s2, f"\nColossalAI Output: {s1}\nTransformers Output: {s2}"
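Note: stacking the two `@parameterize` decorators runs `check_output_consistency` over every combination of `prompt_template` and `do_sample`. A minimal, self-contained sketch of that matrix (assuming colossalai.testing's `parameterize` expands stacked decorators into the full Cartesian product, which is not shown in this diff):

```python
from itertools import product

# Hypothetical illustration of the combinations the stacked decorators cover;
# the real expansion is done by colossalai.testing.parameterize, not itertools.
prompt_templates = [None, "llama"]
do_sample_values = [True, False]

for prompt_template, do_sample in product(prompt_templates, do_sample_values):
    # Each pairing compares ColossalAI engine output against Transformers output.
    print(f"check_output_consistency(prompt_template={prompt_template!r}, do_sample={do_sample})")
```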