diff --git a/pilot/openapi/api_v1/api_v1.py b/pilot/openapi/api_v1/api_v1.py
index fdbfcf2b1..2594fdde4 100644
--- a/pilot/openapi/api_v1/api_v1.py
+++ b/pilot/openapi/api_v1/api_v1.py
@@ -215,8 +215,8 @@ async def params_list(chat_mode: str = ChatScene.ChatNormal.value()):
 
 
 @router.post("/v1/chat/mode/params/file/load")
-async def params_load(conv_uid: str, chat_mode: str, doc_file: UploadFile = File(...)):
-    print(f"params_load: {conv_uid},{chat_mode}")
+async def params_load(conv_uid: str, chat_mode: str, model_name: str, doc_file: UploadFile = File(...)):
+    print(f"params_load: {conv_uid},{chat_mode},{model_name}")
     try:
         if doc_file:
             ## file save
@@ -235,7 +235,7 @@ async def params_load(conv_uid: str, chat_mode: str, doc_file: UploadFile = File
             )
         ## chat prepare
         dialogue = ConversationVo(
-            conv_uid=conv_uid, chat_mode=chat_mode, select_param=doc_file.filename
+            conv_uid=conv_uid, chat_mode=chat_mode, select_param=doc_file.filename, model_name=model_name
         )
         chat: BaseChat = get_chat_instance(dialogue)
         resp = await chat.prepare()
@@ -287,15 +287,17 @@ def get_chat_instance(dialogue: ConversationVo = Body()) -> BaseChat:
 
     chat_param = {
         "chat_session_id": dialogue.conv_uid,
-        "user_input": dialogue.user_input,
+        "current_user_input": dialogue.user_input,
         "select_param": dialogue.select_param,
+        "model_name": dialogue.model_name
     }
-    chat: BaseChat = CHAT_FACTORY.get_implementation(dialogue.chat_mode, **chat_param)
+    chat: BaseChat = CHAT_FACTORY.get_implementation(dialogue.chat_mode, **{"chat_param": chat_param})
     return chat
 
 
 @router.post("/v1/chat/prepare")
 async def chat_prepare(dialogue: ConversationVo = Body()):
+    # dialogue.model_name = CFG.LLM_MODEL
     logger.info(f"chat_prepare:{dialogue}")
     ## check conv_uid
     chat: BaseChat = get_chat_instance(dialogue)
@@ -307,7 +309,8 @@ async def chat_prepare(dialogue: ConversationVo = Body()):
 
 @router.post("/v1/chat/completions")
 async def chat_completions(dialogue: ConversationVo = Body()):
-    print(f"chat_completions:{dialogue.chat_mode},{dialogue.select_param}")
+    # dialogue.model_name = CFG.LLM_MODEL
+    print(f"chat_completions:{dialogue.chat_mode},{dialogue.select_param},{dialogue.model_name}")
     chat: BaseChat = get_chat_instance(dialogue)
     # background_tasks = BackgroundTasks()
     # background_tasks.add_task(release_model_semaphore)
@@ -332,6 +335,28 @@ async def chat_completions(dialogue: ConversationVo = Body()):
     )
 
 
+@router.get("/v1/model/types")
+async def model_types():
+    print(f"/controller/model/types")
+    try:
+        import requests
+
+        response = requests.get(
+            f"{CFG.MODEL_SERVER}/api/controller/models?healthy_only=true"
+        )
+        types = set()
+        if response.status_code == 200:
+            models = json.loads(response.text)
+            for model in models:
+                worker_type = model["model_name"].split("@")[1]
+                if worker_type == "llm":
+                    types.add(model["model_name"].split("@")[0])
+            return Result.succ(list(types))
+
+    except Exception as e:
+        return Result.faild(code="E000X", msg=f"controller model types error {e}")
+
+
 async def no_stream_generator(chat):
     msg = await chat.nostream_call()
     msg = msg.replace("\n", "\\n")
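
---

Note on the new `model_name` flow (illustration, not part of the patch): the sketch below shows how a client might exercise the new `/v1/model/types` route and thread the selected model through `/v1/chat/completions`. The base URL, the `/api` mount prefix, the `Result` envelope shape (`{"success": ..., "data": ...}`), and the `chat_normal` mode string are assumptions based on the surrounding project, not confirmed by this diff.

```python
# Minimal client sketch, illustrative only. Assumptions (not confirmed by the
# diff): the router is mounted under /api, the server listens on
# http://127.0.0.1:5000, Result serializes as {"success": ..., "data": ...},
# and "chat_normal" is a valid chat_mode value.
import requests

BASE_URL = "http://127.0.0.1:5000"  # hypothetical deployment address

# 1. Ask the new endpoint which healthy LLM workers are registered. Server-side,
#    each controller entry "name@worker_type" is split and only names whose
#    worker_type is "llm" are returned.
types_resp = requests.get(f"{BASE_URL}/api/v1/model/types")
model_names = types_resp.json().get("data") or []

# 2. Thread the chosen model through chat completions via the new
#    ConversationVo.model_name field that get_chat_instance now forwards
#    inside the chat_param dict.
payload = {
    "conv_uid": "demo-conversation",  # placeholder conversation id
    "chat_mode": "chat_normal",       # assumed mode string
    "user_input": "hello",
    "model_name": model_names[0] if model_names else None,
}
completion = requests.post(
    f"{BASE_URL}/api/v1/chat/completions", json=payload, stream=True
)
```

A side effect of this diff worth calling out: `get_chat_instance` now passes one `chat_param` dict (via `**{"chat_param": chat_param}`) instead of spreading the keys, so every `BaseChat` implementation resolved by `CHAT_FACTORY` is expected to accept a single `chat_param` constructor argument; subclasses still expecting the old keyword spread need the matching factory-side change.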