llms: multi-model support

csunny 2023-05-21 11:22:56 +08:00
parent 8e127b3863
commit 42b76979a3
2 changed files with 4 additions and 4 deletions

View File

@@ -12,7 +12,7 @@ from pilot.conversation import conv_qa_prompt_template, conv_templates
 from langchain.prompts import PromptTemplate
-vicuna_stream_path = "generate_stream"
+llmstream_stream_path = "generate_stream"
 CFG = Config()
@@ -44,7 +44,7 @@ def generate(query):
     }
     response = requests.post(
-        url=urljoin(CFG.MODEL_SERVER, vicuna_stream_path), data=json.dumps(params)
+        url=urljoin(CFG.MODEL_SERVER, llmstream_stream_path), data=json.dumps(params)
     )
     skip_echo_len = len(params["prompt"]) + 1 - params["prompt"].count("</s>") * 3
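
For context, a minimal sketch of the client-side call after the rename: the streaming endpoint is no longer named after vicuna specifically, so any model served behind the model server is reached through the same generate_stream path. The MODEL_SERVER URL, the sampling parameters, and the line-based way the response is consumed below are illustrative assumptions, not the project's exact code.

# Hedged sketch: post a prompt to the model-agnostic streaming endpoint.
# MODEL_SERVER and the parameter names are stand-ins for CFG.MODEL_SERVER
# and the real request schema.
import json
from urllib.parse import urljoin

import requests

llmstream_stream_path = "generate_stream"
MODEL_SERVER = "http://127.0.0.1:8000"  # placeholder for CFG.MODEL_SERVER

def generate(query: str):
    params = {"prompt": query, "temperature": 0.7, "max_new_tokens": 512}
    response = requests.post(
        url=urljoin(MODEL_SERVER, llmstream_stream_path),
        data=json.dumps(params),
        stream=True,
    )
    # The original code computes skip_echo_len to strip the echoed prompt
    # from the decoded output; this sketch just streams raw chunks.
    for chunk in response.iter_lines(decode_unicode=True):
        if chunk:
            yield chunk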

View File

@@ -27,8 +27,6 @@ from pilot.server.chat_adapter import get_llm_chat_adapter
 CFG = Config()
-model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
-print(model_path)
 class ModelWorker:
@@ -154,6 +152,8 @@ def embeddings(prompt_request: EmbeddingRequest):
 if __name__ == "__main__":
+    model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
+    print(model_path)
     worker = ModelWorker(
         model_path=model_path,
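
This hunk moves model-path resolution out of module import time and into the worker's entry point. A minimal sketch of that pattern is below, assuming a name-to-path registry; the contents of LLM_MODEL_CONFIG, the LLM_MODEL environment variable, and the ModelWorker signature are illustrative stand-ins, not the repository's actual definitions.

# Hedged sketch of the worker entry point: the model is picked by name
# only when the process is launched, which keeps imports model-agnostic.
import os

# Assumed shape of the registry: model name -> local checkpoint path.
LLM_MODEL_CONFIG = {
    "vicuna-13b": os.path.expanduser("~/models/vicuna-13b"),
    "chatglm-6b": os.path.expanduser("~/models/chatglm-6b"),
}

class ModelWorker:
    """Placeholder for the real worker, which would load and serve the model."""

    def __init__(self, model_path: str, model_name: str):
        self.model_path = model_path
        self.model_name = model_name

if __name__ == "__main__":
    # Stand-in for CFG.LLM_MODEL: choose the model by name at startup.
    llm_model = os.environ.get("LLM_MODEL", "vicuna-13b")
    model_path = LLM_MODEL_CONFIG[llm_model]
    print(model_path)
    worker = ModelWorker(model_path=model_path, model_name=llm_model)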