mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-07 20:39:48 +00:00
add do sample test
This commit is contained in:
parent
719fb81667
commit
e7a789e97b
@ -20,7 +20,7 @@ def setup_seed(seed):
|
|||||||
random.seed(seed)
|
random.seed(seed)
|
||||||
|
|
||||||
|
|
||||||
def check_inference_engine(use_engine=False, prompt_template=None):
|
def check_inference_engine(use_engine=False, prompt_template=None, do_sample=False):
|
||||||
setup_seed(20)
|
setup_seed(20)
|
||||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
|
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
|
||||||
model = LlamaForCausalLM(
|
model = LlamaForCausalLM(
|
||||||
@ -29,14 +29,13 @@ def check_inference_engine(use_engine=False, prompt_template=None):
|
|||||||
)
|
)
|
||||||
).cuda()
|
).cuda()
|
||||||
model = model.eval()
|
model = model.eval()
|
||||||
|
|
||||||
inputs = [
|
inputs = [
|
||||||
"介绍一下今天的北京,比如故宫,天安门,长城或者其他的一些景点,",
|
"介绍一下今天的北京,比如故宫,天安门,长城或者其他的一些景点,",
|
||||||
"介绍一下武汉,",
|
"介绍一下武汉,",
|
||||||
]
|
]
|
||||||
|
|
||||||
output_len = 38
|
output_len = 38
|
||||||
do_sample = False
|
do_sample = do_sample
|
||||||
top_p = 0.5
|
top_p = 0.5
|
||||||
top_k = 50
|
top_k = 50
|
||||||
|
|
||||||
@ -81,9 +80,10 @@ def check_inference_engine(use_engine=False, prompt_template=None):
|
|||||||
|
|
||||||
|
|
||||||
@parameterize("prompt_template", [None, "llama"])
|
@parameterize("prompt_template", [None, "llama"])
|
||||||
def check_output_consistency(prompt_template):
|
@parameterize("do_sample", [True, False])
|
||||||
cai_outputs = check_inference_engine(use_engine=True, prompt_template=prompt_template)
|
def check_output_consistency(prompt_template, do_sample):
|
||||||
transformer_outputs = check_inference_engine(use_engine=False, prompt_template=prompt_template)
|
cai_outputs = check_inference_engine(use_engine=True, prompt_template=prompt_template, do_sample=do_sample)
|
||||||
|
transformer_outputs = check_inference_engine(use_engine=False, prompt_template=prompt_template, do_sample=do_sample)
|
||||||
|
|
||||||
for s1, s2 in zip(cai_outputs, transformer_outputs):
|
for s1, s2 in zip(cai_outputs, transformer_outputs):
|
||||||
assert s1 == s2, f"\nColossalAI Output: {s1}\nTransformers Output: {s2}"
|
assert s1 == s2, f"\nColossalAI Output: {s1}\nTransformers Output: {s2}"
|
||||||
|
Loading…
Reference in New Issue
Block a user