fix(model): Fix vllm new tokenizer error (#1601)

Author: Fangyin Cheng
Date: 2024-06-05 15:27:58 +08:00
Committed by: GitHub
parent c3c063683c
commit 43b5821ce4
2 changed files with 7 additions and 5 deletions

Changed file 1 of 2:

@@ -40,7 +40,7 @@ class VLLMModelAdapterWrapper(LLMModelAdapter):
             help="local model path of the huggingface model to use",
         )
         parser.add_argument("--model_type", type=str, help="model type")
-        parser.add_argument("--device", type=str, default=None, help="device")
+        # parser.add_argument("--device", type=str, default=None, help="device")
         # TODO parse prompt templete from `model_name` and `model_path`
         parser.add_argument(
             "--prompt_template",
@@ -76,7 +76,11 @@ class VLLMModelAdapterWrapper(LLMModelAdapter):
         # Set the attributes from the parsed arguments.
         engine_args = AsyncEngineArgs(**vllm_engine_args_dict)
         engine = AsyncLLMEngine.from_engine_args(engine_args)
-        return engine, engine.engine.tokenizer
+        tokenizer = engine.engine.tokenizer
+        if hasattr(tokenizer, "tokenizer"):
+            # vllm >= 0.2.7
+            tokenizer = tokenizer.tokenizer
+        return engine, tokenizer
 
     def support_async(self) -> bool:
         return True
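
Why the tokenizer is unwrapped: in newer vLLM releases, `engine.engine.tokenizer` is no longer the Hugging Face tokenizer itself but a wrapper object (a TokenizerGroup in recent releases) that exposes the real tokenizer through its own `tokenizer` attribute, which can break callers that expect tokenizer methods directly. Below is a minimal, self-contained sketch of the resulting loading path; the helper name and the `model` argument are illustrative, and only the unwrapping pattern mirrors the diff.

from vllm import AsyncEngineArgs, AsyncLLMEngine


def load_vllm_engine(model_path: str):
    # Build the async engine the same way the adapter does above.
    engine_args = AsyncEngineArgs(model=model_path)
    engine = AsyncLLMEngine.from_engine_args(engine_args)

    tokenizer = engine.engine.tokenizer
    # Newer vLLM wraps the Hugging Face tokenizer; the wrapped tokenizer
    # lives on the wrapper's `tokenizer` attribute. The hasattr check keeps
    # older vLLM versions, which return the tokenizer directly, working.
    if hasattr(tokenizer, "tokenizer"):
        tokenizer = tokenizer.tokenizer
    return engine, tokenizer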

Changed file 2 of 2:

@@ -61,9 +61,7 @@ async def generate_stream(
         **gen_params
     )
-    results_generator = model.generate(
-        prompt, sampling_params, request_id, prompt_token_ids=prompt_token_ids
-    )
+    results_generator = model.generate(prompt, sampling_params, request_id)
     async for request_output in results_generator:
         prompt = request_output.prompt
         if echo:
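
The second file drops the explicit `prompt_token_ids` keyword and lets the engine tokenize the prompt internally. A minimal sketch of consuming the stream under that calling convention follows; the function name and the `SamplingParams` values are illustrative and not part of the commit.

import uuid

from vllm import AsyncLLMEngine, SamplingParams


async def stream_completion(model: AsyncLLMEngine, prompt: str):
    sampling_params = SamplingParams(temperature=0.7, max_tokens=256)
    request_id = uuid.uuid4().hex
    # Positional arguments match the call in the diff: prompt, sampling
    # parameters, request id. No prompt_token_ids is passed.
    results_generator = model.generate(prompt, sampling_params, request_id)
    async for request_output in results_generator:
        # Each RequestOutput carries the prompt and the text generated so far.
        yield "".join(output.text for output in request_output.outputs)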