diff --git a/dbgpt/model/adapter/vllm_adapter.py b/dbgpt/model/adapter/vllm_adapter.py
index 268ce82c6..3aaa391e4 100644
--- a/dbgpt/model/adapter/vllm_adapter.py
+++ b/dbgpt/model/adapter/vllm_adapter.py
@@ -40,7 +40,7 @@ class VLLMModelAdapterWrapper(LLMModelAdapter):
             help="local model path of the huggingface model to use",
         )
         parser.add_argument("--model_type", type=str, help="model type")
-        parser.add_argument("--device", type=str, default=None, help="device")
+        # parser.add_argument("--device", type=str, default=None, help="device")
         # TODO parse prompt templete from `model_name` and `model_path`
         parser.add_argument(
             "--prompt_template",
@@ -76,7 +76,11 @@ class VLLMModelAdapterWrapper(LLMModelAdapter):
         # Set the attributes from the parsed arguments.
         engine_args = AsyncEngineArgs(**vllm_engine_args_dict)
         engine = AsyncLLMEngine.from_engine_args(engine_args)
-        return engine, engine.engine.tokenizer
+        tokenizer = engine.engine.tokenizer
+        if hasattr(tokenizer, "tokenizer"):
+            # vllm >= 0.2.7
+            tokenizer = tokenizer.tokenizer
+        return engine, tokenizer
 
     def support_async(self) -> bool:
         return True
diff --git a/dbgpt/model/llm_out/vllm_llm.py b/dbgpt/model/llm_out/vllm_llm.py
index 838bcc35a..54ed73483 100644
--- a/dbgpt/model/llm_out/vllm_llm.py
+++ b/dbgpt/model/llm_out/vllm_llm.py
@@ -61,9 +61,7 @@ async def generate_stream(
         **gen_params
     )
 
-    results_generator = model.generate(
-        prompt, sampling_params, request_id, prompt_token_ids=prompt_token_ids
-    )
+    results_generator = model.generate(prompt, sampling_params, request_id)
     async for request_output in results_generator:
         prompt = request_output.prompt
         if echo:
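
Note on the tokenizer change: starting with vLLM 0.2.7, `engine.engine.tokenizer` is a TokenizerGroup that wraps the underlying Hugging Face tokenizer in its `.tokenizer` attribute, while older releases expose the tokenizer directly; the `hasattr` check in the first file keeps both versions working. A minimal standalone sketch of the same shim, assuming only vLLM's public AsyncEngineArgs/AsyncLLMEngine API (`load_engine_and_tokenizer` is a hypothetical helper for illustration, not part of dbgpt):

    from vllm import AsyncEngineArgs, AsyncLLMEngine

    def load_engine_and_tokenizer(model_path: str):
        engine_args = AsyncEngineArgs(model=model_path)
        engine = AsyncLLMEngine.from_engine_args(engine_args)
        tokenizer = engine.engine.tokenizer
        # vllm >= 0.2.7 wraps the HF tokenizer in a TokenizerGroup;
        # unwrap it so callers always receive the plain tokenizer.
        if hasattr(tokenizer, "tokenizer"):
            tokenizer = tokenizer.tokenizer
        return engine, tokenizer

The second file's hunk simplifies the `model.generate(...)` call to pass only the prompt text, letting the engine tokenize internally instead of receiving precomputed `prompt_token_ids`, presumably for compatibility with the same vLLM upgrade.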