mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-07 04:18:55 +00:00
add do sample test
This commit is contained in:
parent
719fb81667
commit
e7a789e97b
@ -20,7 +20,7 @@ def setup_seed(seed):
|
||||
random.seed(seed)
|
||||
|
||||
|
||||
def check_inference_engine(use_engine=False, prompt_template=None):
|
||||
def check_inference_engine(use_engine=False, prompt_template=None, do_sample=False):
|
||||
setup_seed(20)
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
|
||||
model = LlamaForCausalLM(
|
||||
@ -29,14 +29,13 @@ def check_inference_engine(use_engine=False, prompt_template=None):
|
||||
)
|
||||
).cuda()
|
||||
model = model.eval()
|
||||
|
||||
inputs = [
|
||||
"介绍一下今天的北京,比如故宫,天安门,长城或者其他的一些景点,",
|
||||
"介绍一下武汉,",
|
||||
]
|
||||
|
||||
output_len = 38
|
||||
do_sample = False
|
||||
do_sample = do_sample
|
||||
top_p = 0.5
|
||||
top_k = 50
|
||||
|
||||
@ -81,9 +80,10 @@ def check_inference_engine(use_engine=False, prompt_template=None):
|
||||
|
||||
|
||||
@parameterize("prompt_template", [None, "llama"])
|
||||
def check_output_consistency(prompt_template):
|
||||
cai_outputs = check_inference_engine(use_engine=True, prompt_template=prompt_template)
|
||||
transformer_outputs = check_inference_engine(use_engine=False, prompt_template=prompt_template)
|
||||
@parameterize("do_sample", [True, False])
|
||||
def check_output_consistency(prompt_template, do_sample):
|
||||
cai_outputs = check_inference_engine(use_engine=True, prompt_template=prompt_template, do_sample=do_sample)
|
||||
transformer_outputs = check_inference_engine(use_engine=False, prompt_template=prompt_template, do_sample=do_sample)
|
||||
|
||||
for s1, s2 in zip(cai_outputs, transformer_outputs):
|
||||
assert s1 == s2, f"\nColossalAI Output: {s1}\nTransformers Output: {s2}"
|
||||
|
Loading…
Reference in New Issue
Block a user