adapted to pad_context_forward

yuehuayingxueluo
2024-01-09 13:52:53 +08:00
committed by FrankLeeeee
parent 47e53eaa1c
commit fa4fbdbffb
9 changed files with 42 additions and 41 deletions


@@ -28,20 +28,24 @@ def check_inference_engine(test_cai=False):
         )
     ).cuda()
     model = model.eval()
     inputs = [
-        "介绍一下今天的北京,",
+        "介绍一下今天的北京,比如故宫,天安门,长城或者其他的一些景点,",
         "介绍一下武汉,",
     ]
-    output_len = 16
+    output_len = 128
     do_sample = True
+    top_p = 0.5
+    top_k = 50
     if test_cai:
         inference_config = InferenceConfig(max_output_len=output_len)
         inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
         inference_engine.add_request(prompts=inputs)
         assert inference_engine.request_handler._has_waiting()
-        generation_config = GenerationConfig(do_sample=do_sample, top_p=0.5, top_k=50)
+        generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k)
         outputs = inference_engine.generate(generation_config)
     else:
         tokenizer.pad_token = tokenizer.eos_token
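Note on this hunk: it swaps in a longer prompt and a longer output_len (so padding across unequal prompt lengths is actually exercised), and hoists top_p/top_k into variables so the ColossalAI engine branch and the transformers baseline in the next hunk sample with identical settings. Read standalone, the engine branch amounts to the sketch below; the helper name is hypothetical and the import paths are assumptions, since the diff does not show the test's import block.

from transformers import GenerationConfig

# Assumed import paths; not visible in this diff.
from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine

def generate_with_cai(model, tokenizer, prompts, output_len=128):
    # Cap the number of generated tokens, mirroring max_output_len above.
    inference_config = InferenceConfig(max_output_len=output_len)
    engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
    engine.add_request(prompts=prompts)  # raw strings; the engine tokenizes internally
    assert engine.request_handler._has_waiting()  # requests queued, not yet executed
    # Same sampling settings as the transformers baseline.
    generation_config = GenerationConfig(do_sample=True, top_p=0.5, top_k=50)
    return engine.generate(generation_config)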
@@ -49,7 +53,11 @@ def check_inference_engine(test_cai=False):
         inputs = tokenizer.batch_encode_plus(inputs, padding=True, return_tensors="pt")["input_ids"]
         inputs = inputs.cuda()
         generation_config = GenerationConfig(
-            do_sample=do_sample, top_p=0.5, top_k=50, pad_token_id=tokenizer.pad_token_id, max_new_tokens=output_len
+            do_sample=do_sample,
+            top_p=top_p,
+            top_k=top_k,
+            pad_token_id=tokenizer.pad_token_id,
+            max_new_tokens=output_len,
         )
         outputs = model.generate(inputs, generation_config=generation_config)
         outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
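The baseline hunk only reflows the GenerationConfig call onto one argument per line and reuses the hoisted top_p/top_k. This branch pads all prompts to a common length before calling model.generate, which is presumably the behavior the engine's pad_context_forward path (named in the commit title) is being checked against. A standalone sketch of the branch, again with a hypothetical helper name:

from transformers import GenerationConfig

def generate_with_hf(model, tokenizer, prompts, output_len=128):
    # Llama tokenizers ship without a pad token; reuse EOS for padding.
    tokenizer.pad_token = tokenizer.eos_token
    enc = tokenizer.batch_encode_plus(prompts, padding=True, return_tensors="pt")
    input_ids = enc["input_ids"].cuda()
    generation_config = GenerationConfig(
        do_sample=True,
        top_p=0.5,  # same sampling settings as the engine branch
        top_k=50,
        pad_token_id=tokenizer.pad_token_id,
        max_new_tokens=output_len,
    )
    outputs = model.generate(input_ids, generation_config=generation_config)
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)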