From 42b76979a3a22d4b95133513eed8c51648ab9e29 Mon Sep 17 00:00:00 2001 From: csunny Date: Sun, 21 May 2023 11:22:56 +0800 Subject: [PATCH] llms: multi model support --- examples/embdserver.py | 4 ++-- pilot/server/llmserver.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/embdserver.py b/examples/embdserver.py index 79140ba66..bb0016f00 100644 --- a/examples/embdserver.py +++ b/examples/embdserver.py @@ -12,7 +12,7 @@ from pilot.conversation import conv_qa_prompt_template, conv_templates from langchain.prompts import PromptTemplate -vicuna_stream_path = "generate_stream" +llmstream_stream_path = "generate_stream" CFG = Config() @@ -44,7 +44,7 @@ def generate(query): } response = requests.post( - url=urljoin(CFG.MODEL_SERVER, vicuna_stream_path), data=json.dumps(params) + url=urljoin(CFG.MODEL_SERVER, llmstream_stream_path), data=json.dumps(params) ) skip_echo_len = len(params["prompt"]) + 1 - params["prompt"].count("") * 3 diff --git a/pilot/server/llmserver.py b/pilot/server/llmserver.py index fa1da5608..79b3450d3 100644 --- a/pilot/server/llmserver.py +++ b/pilot/server/llmserver.py @@ -27,8 +27,6 @@ from pilot.server.chat_adapter import get_llm_chat_adapter CFG = Config() -model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL] -print(model_path) class ModelWorker: @@ -154,6 +152,8 @@ def embeddings(prompt_request: EmbeddingRequest): if __name__ == "__main__": + model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL] + print(model_path) worker = ModelWorker( model_path=model_path,