add do_sample test

CjhHa1 2024-04-15 17:46:22 +08:00
parent 719fb81667
commit e7a789e97b

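In short: the test helper check_inference_engine gains a do_sample flag (defaulting to False, so greedy decoding stays the default for existing callers), and check_output_consistency is parameterized over do_sample as well, so the ColossalAI-vs-Transformers consistency check now runs under both greedy decoding and top-k/top-p sampling.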

@@ -20,7 +20,7 @@ def setup_seed(seed):
     random.seed(seed)
 
 
-def check_inference_engine(use_engine=False, prompt_template=None):
+def check_inference_engine(use_engine=False, prompt_template=None, do_sample=False):
     setup_seed(20)
     tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
     model = LlamaForCausalLM(
@@ -29,14 +29,13 @@ def check_inference_engine(use_engine=False, prompt_template=None):
         )
     ).cuda()
     model = model.eval()
 
     inputs = [
         "介绍一下今天的北京,比如故宫,天安门,长城或者其他的一些景点,",
         "介绍一下武汉,",
     ]
 
     output_len = 38
-    do_sample = False
+    do_sample = do_sample
     top_p = 0.5
     top_k = 50
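The hunk above only sets the sampling knobs; the code that consumes them is outside this diff. As a rough sketch, assuming the Transformers baseline packs them into a standard Hugging Face GenerationConfig (an assumption, not shown in this commit), the mapping would look like:

from transformers import GenerationConfig

do_sample = False  # the flag threaded through the test above

# Assumed mapping of the test's knobs onto a standard GenerationConfig;
# this commit does not show the actual plumbing.
generation_config = GenerationConfig(
    do_sample=do_sample,  # False -> greedy decoding, True -> stochastic sampling
    top_k=50,             # when sampling, draw only from the 50 most likely tokens
    top_p=0.5,            # nucleus sampling: smallest token set with cumulative prob >= 0.5
    max_new_tokens=38,    # corresponds to output_len above
)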
@@ -81,9 +80,10 @@ def check_inference_engine(use_engine=False, prompt_template=None):
 @parameterize("prompt_template", [None, "llama"])
-def check_output_consistency(prompt_template):
-    cai_outputs = check_inference_engine(use_engine=True, prompt_template=prompt_template)
-    transformer_outputs = check_inference_engine(use_engine=False, prompt_template=prompt_template)
+@parameterize("do_sample", [True, False])
+def check_output_consistency(prompt_template, do_sample):
+    cai_outputs = check_inference_engine(use_engine=True, prompt_template=prompt_template, do_sample=do_sample)
+    transformer_outputs = check_inference_engine(use_engine=False, prompt_template=prompt_template, do_sample=do_sample)
     for s1, s2 in zip(cai_outputs, transformer_outputs):
         assert s1 == s2, f"\nColossalAI Output: {s1}\nTransformers Output: {s2}"
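The interesting case is do_sample=True: the assertion still demands token-for-token equality, which only holds because setup_seed(20) reseeds every RNG at the top of each run, making the sampling draws (and the random model weights) identical across the two engines. Below is a minimal self-contained sketch of that property, assuming the torch/random seeding shown in the test; the tiny model sizes are illustrative, not the test's configuration.

import random

import torch
from transformers import AutoTokenizer, GenerationConfig, LlamaConfig, LlamaForCausalLM


def setup_seed(seed):
    # Same idea as the test helper: pin every RNG that sampling draws from.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)


def sample_once(seed=20):
    setup_seed(seed)  # reseed BEFORE building the model so weights match too
    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
    # Deliberately tiny random-weight Llama so the sketch runs on CPU;
    # these sizes are illustrative, not the test's.
    model = LlamaForCausalLM(
        LlamaConfig(
            vocab_size=50000, hidden_size=64, intermediate_size=128, num_attention_heads=4, num_hidden_layers=2
        )
    ).eval()
    batch = tokenizer(["介绍一下武汉,"], return_tensors="pt")
    out = model.generate(
        **batch,
        generation_config=GenerationConfig(
            do_sample=True, top_k=50, top_p=0.5, max_new_tokens=38, pad_token_id=tokenizer.eos_token_id
        ),
    )
    return tokenizer.batch_decode(out, skip_special_tokens=True)


# Reseeding makes sampled decoding bit-for-bit reproducible, which is what
# lets the test assert exact string equality even when do_sample=True.
assert sample_once() == sample_once()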