[Inference] Fix bugs and docs for feat/online-server (#5598)

* fix test bugs

* add do sample test

* del useless lines

* fix comments

* fix tests

* delete version tag

* delete version tag

* add

* del test server

* fix test

* fix

* Revert "add"

This reverts commit b9305fb024.
Jianghai authored on 2024-05-08 15:14:06 +08:00, committed by CjhHa1
commit 61a1b2e798 (parent 7bbb28e48b)
12 changed files with 98 additions and 172 deletions


@@ -37,7 +37,6 @@ def check_inference_engine(use_engine=False, prompt_template=None, do_sample=Tru
         )
     ).cuda()
     model = model.eval()
     inputs = [
         "介绍一下今天的北京,比如故宫,天安门,长城或者其他的一些景点,",
         "介绍一下武汉,",
@@ -60,7 +59,9 @@ def check_inference_engine(use_engine=False, prompt_template=None, do_sample=Tru
         assert inference_engine.generation_config.max_new_tokens == output_len
         inference_engine.add_request(prompts=inputs)
         assert inference_engine.request_handler._has_waiting()
-        generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k)
+        generation_config = GenerationConfig(
+            max_new_tokens=output_len, do_sample=do_sample, dtype="fp32", top_p=top_p, top_k=top_k
+        )
         outputs = inference_engine.generate(generation_config=generation_config)
     else:
         if prompt_template:
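
The hunk above pins max_new_tokens on the GenerationConfig so it matches what the preceding assert checks on inference_engine.generation_config. A minimal sketch of just that config construction, assuming transformers' GenerationConfig and illustrative stand-in values for output_len, top_p, and top_k (the real test parametrizes these):

    from transformers import GenerationConfig

    output_len = 128  # illustrative stand-in; the test fixes its own value
    # dtype="fp32" is not a standard transformers generation knob; GenerationConfig
    # stores unrecognized kwargs as extra attributes rather than rejecting them,
    # so engine-side code can read the value back later.
    generation_config = GenerationConfig(
        max_new_tokens=output_len,
        do_sample=True,  # the do_sample path this PR's tests exercise
        dtype="fp32",
        top_p=0.5,  # illustrative
        top_k=50,   # illustrative
    )
    assert generation_config.max_new_tokens == output_len
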
@@ -72,6 +73,7 @@ def check_inference_engine(use_engine=False, prompt_template=None, do_sample=Tru
         inputs = inputs.cuda()
         generation_config = GenerationConfig(
             do_sample=do_sample,
+            dtype="fp32",
             top_p=top_p,
             top_k=top_k,
             pad_token_id=tokenizer.pad_token_id,
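
This non-engine branch feeds an equivalent config through plain transformers generation. A self-contained sketch of that path, using gpt2 as an assumed stand-in for the test's small Llama model, illustrative sampling values, and CPU execution (the test's .cuda() calls are omitted so the sketch runs anywhere):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2").eval()
    # gpt2 ships without a pad token; reuse EOS so pad_token_id below is defined.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    inputs = tokenizer(["Introduce Wuhan,"], return_tensors="pt", padding=True)
    generation_config = GenerationConfig(
        do_sample=True,
        dtype="fp32",  # mirrored from the engine path; unused by transformers itself
        top_p=0.5,     # illustrative
        top_k=50,      # illustrative
        pad_token_id=tokenizer.pad_token_id,
        max_new_tokens=32,  # illustrative cap
    )
    with torch.no_grad():
        outputs = model.generate(**inputs, generation_config=generation_config)
    print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
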