From 1daaf23e70cb73a614911b6d7a76eb382c8afbf2 Mon Sep 17 00:00:00 2001 From: aries_ckt <916701291@qq.com> Date: Tue, 12 Sep 2023 20:49:55 +0800 Subject: [PATCH] fix:llm select add model_name --- pilot/openapi/api_v1/api_v1.py | 3 +++ pilot/scene/base_chat.py | 1 + pilot/scene/message.py | 2 +- 3 files changed, 5 insertions(+), 1 deletion(-) diff --git a/pilot/openapi/api_v1/api_v1.py b/pilot/openapi/api_v1/api_v1.py index 3da31a9ce..81d9fa01d 100644 --- a/pilot/openapi/api_v1/api_v1.py +++ b/pilot/openapi/api_v1/api_v1.py @@ -149,6 +149,7 @@ async def dialogue_list(user_id: str = None): conv_uid = item.get("conv_uid") summary = item.get("summary") chat_mode = item.get("chat_mode") + model_name = item.get("model_name", CFG.LLM_MODEL) messages = json.loads(item.get("messages")) last_round = max(messages, key=lambda x: x["chat_order"]) @@ -160,6 +161,7 @@ async def dialogue_list(user_id: str = None): conv_uid=conv_uid, user_input=summary, chat_mode=chat_mode, + model_name=model_name, select_param=select_param, ) dialogues.append(conv_vo) @@ -259,6 +261,7 @@ def get_hist_messages(conv_uid: str): history_messages: List[OnceConversation] = history_mem.get_messages() if history_messages: for once in history_messages: + print(f"once:{once}") model_name = once.get("model_name", CFG.LLM_MODEL) once_message_vos = [ message2Vo(element, once["chat_order"], model_name) for element in once["messages"] diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py index dc881f7dd..a2c284114 100644 --- a/pilot/scene/base_chat.py +++ b/pilot/scene/base_chat.py @@ -59,6 +59,7 @@ class BaseChat(ABC): self.history_message: List[OnceConversation] = self.memory.messages() self.current_message: OnceConversation = OnceConversation(self.chat_mode.value()) + self.current_message.model_name = self.llm_model if chat_param["select_param"]: if len(self.chat_mode.param_types()) > 0: self.current_message.param_type = self.chat_mode.param_types()[0] diff --git a/pilot/scene/message.py b/pilot/scene/message.py index d787729e6..4d5a5c383 100644 --- a/pilot/scene/message.py +++ b/pilot/scene/message.py @@ -105,7 +105,7 @@ def _conversation_to_dic(once: OnceConversation) -> dict: return { "chat_mode": once.chat_mode, - "model_name": once.model_name if once.model_name else "proxyllm", + "model_name": once.model_name, "chat_order": once.chat_order, "start_date": start_str, "cost": once.cost if once.cost else 0,