feat:multi-llm select

This commit is contained in:
aries_ckt 2023-09-08 10:51:12 +08:00
parent 1a2bf96767
commit 63ad842daa

View File

@ -215,8 +215,8 @@ async def params_list(chat_mode: str = ChatScene.ChatNormal.value()):
@router.post("/v1/chat/mode/params/file/load")
async def params_load(conv_uid: str, chat_mode: str, doc_file: UploadFile = File(...)):
print(f"params_load: {conv_uid},{chat_mode}")
async def params_load(conv_uid: str, chat_mode: str, model_name: str, doc_file: UploadFile = File(...)):
print(f"params_load: {conv_uid},{chat_mode},{model_name}")
try:
if doc_file:
## file save
@ -235,7 +235,7 @@ async def params_load(conv_uid: str, chat_mode: str, doc_file: UploadFile = File
)
## chat prepare
dialogue = ConversationVo(
conv_uid=conv_uid, chat_mode=chat_mode, select_param=doc_file.filename
conv_uid=conv_uid, chat_mode=chat_mode, select_param=doc_file.filename, model_name=model_name
)
chat: BaseChat = get_chat_instance(dialogue)
resp = await chat.prepare()
@ -287,15 +287,17 @@ def get_chat_instance(dialogue: ConversationVo = Body()) -> BaseChat:
chat_param = {
"chat_session_id": dialogue.conv_uid,
"user_input": dialogue.user_input,
"current_user_input": dialogue.user_input,
"select_param": dialogue.select_param,
"model_name": dialogue.model_name
}
chat: BaseChat = CHAT_FACTORY.get_implementation(dialogue.chat_mode, **chat_param)
chat: BaseChat = CHAT_FACTORY.get_implementation(dialogue.chat_mode, **{"chat_param": chat_param})
return chat
@router.post("/v1/chat/prepare")
async def chat_prepare(dialogue: ConversationVo = Body()):
# dialogue.model_name = CFG.LLM_MODEL
logger.info(f"chat_prepare:{dialogue}")
## check conv_uid
chat: BaseChat = get_chat_instance(dialogue)
@ -307,7 +309,8 @@ async def chat_prepare(dialogue: ConversationVo = Body()):
@router.post("/v1/chat/completions")
async def chat_completions(dialogue: ConversationVo = Body()):
print(f"chat_completions:{dialogue.chat_mode},{dialogue.select_param}")
# dialogue.model_name = CFG.LLM_MODEL
print(f"chat_completions:{dialogue.chat_mode},{dialogue.select_param},{dialogue.model_name}")
chat: BaseChat = get_chat_instance(dialogue)
# background_tasks = BackgroundTasks()
# background_tasks.add_task(release_model_semaphore)
@ -332,6 +335,28 @@ async def chat_completions(dialogue: ConversationVo = Body()):
)
@router.get("/v1/model/types")
async def model_types():
print(f"/controller/model/types")
try:
import requests
response = requests.get(
f"{CFG.MODEL_SERVER}/api/controller/models?healthy_only=true"
)
types = set()
if response.status_code == 200:
models = json.loads(response.text)
for model in models:
worker_type = model["model_name"].split("@")[1]
if worker_type == "llm":
types.add(model["model_name"].split("@")[0])
return Result.succ(list(types))
except Exception as e:
return Result.faild(code="E000X", msg=f"controller model types error {e}")
async def no_stream_generator(chat):
msg = await chat.nostream_call()
msg = msg.replace("\n", "\\n")