From b32e0396e9b4c3b82b892335bdbfc42e1c5e43a2 Mon Sep 17 00:00:00 2001
From: aries_ckt <916701291@qq.com>
Date: Mon, 11 Sep 2023 17:59:25 +0800
Subject: [PATCH] feat:chat_history add model_name

---
 pilot/openapi/api_v1/api_v1.py  | 7 ++++---
 pilot/openapi/api_view_model.py | 5 +++++
 pilot/scene/base_chat.py        | 2 +-
 pilot/scene/message.py          | 3 +++
 4 files changed, 13 insertions(+), 4 deletions(-)

diff --git a/pilot/openapi/api_v1/api_v1.py b/pilot/openapi/api_v1/api_v1.py
index 2594fdde4..ae23503fb 100644
--- a/pilot/openapi/api_v1/api_v1.py
+++ b/pilot/openapi/api_v1/api_v1.py
@@ -259,8 +259,9 @@ def get_hist_messages(conv_uid: str):
     history_messages: List[OnceConversation] = history_mem.get_messages()
     if history_messages:
         for once in history_messages:
+            model_name = once.get("model_name", CFG.LLM_MODEL)
             once_message_vos = [
-                message2Vo(element, once["chat_order"]) for element in once["messages"]
+                message2Vo(element, once["chat_order"], model_name) for element in once["messages"]
             ]
             message_vos.extend(once_message_vos)
     return message_vos
@@ -381,7 +382,7 @@ async def stream_generator(chat):
     chat.memory.append(chat.current_message)


-def message2Vo(message: dict, order) -> MessageVo:
+def message2Vo(message: dict, order, model_name) -> MessageVo:
     return MessageVo(
-        role=message["type"], context=message["data"]["content"], order=order
+        role=message["type"], context=message["data"]["content"], order=order, model_name=model_name
     )
diff --git a/pilot/openapi/api_view_model.py b/pilot/openapi/api_view_model.py
index 57b438879..d03beec8d 100644
--- a/pilot/openapi/api_view_model.py
+++ b/pilot/openapi/api_view_model.py
@@ -78,3 +78,8 @@ class MessageVo(BaseModel):
     time the current message was sent
     """
     time_stamp: Any = None
+
+    """
+    model_name
+    """
+    model_name: str
diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py
index a774ba573..d3fa7b474 100644
--- a/pilot/scene/base_chat.py
+++ b/pilot/scene/base_chat.py
@@ -38,7 +38,7 @@ class BaseChat(ABC):
         self.chat_session_id = chat_param["chat_session_id"]
         self.chat_mode = chat_param["chat_mode"]
         self.current_user_input: str = chat_param["current_user_input"]
-        self.llm_model = chat_param["model_name"]
+        self.llm_model = chat_param["model_name"] if chat_param["model_name"] else CFG.LLM_MODEL
         self.llm_echo = False

         ### load prompt template
diff --git a/pilot/scene/message.py b/pilot/scene/message.py
index af52780f9..d787729e6 100644
--- a/pilot/scene/message.py
+++ b/pilot/scene/message.py
@@ -23,6 +23,7 @@ class OnceConversation:
         self.messages: List[BaseMessage] = []
         self.start_date: str = ""
         self.chat_order: int = 0
+        self.model_name: str = ""
         self.param_type: str = ""
         self.param_value: str = ""
         self.cost: int = 0
@@ -104,6 +105,7 @@ def _conversation_to_dic(once: OnceConversation) -> dict:

     return {
         "chat_mode": once.chat_mode,
+        "model_name": once.model_name if once.model_name else "proxyllm",
         "chat_order": once.chat_order,
         "start_date": start_str,
         "cost": once.cost if once.cost else 0,
@@ -127,6 +129,7 @@ def conversation_from_dict(once: dict) -> OnceConversation:
     conversation.chat_order = int(once.get("chat_order"))
     conversation.param_type = once.get("param_type", "")
     conversation.param_value = once.get("param_value", "")
+    conversation.model_name = once.get("model_name", "proxyllm")
     print(once.get("messages"))
     conversation.messages = messages_from_dict(once.get("messages", []))
     return conversation
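
Below is a minimal, self-contained sketch of how the new model_name value is expected to flow from a stored conversation dict into MessageVo after this patch. It re-declares MessageVo and message2Vo in the shape the diff shows; the role/context/order field types, the DEFAULT_MODEL stand-in for CFG.LLM_MODEL, and the sample conversation dict are assumptions for illustration, not code taken from the repository.

# Sketch only: mirrors the patched message2Vo/MessageVo shape; field types and
# DEFAULT_MODEL are assumed for illustration, not taken from the repository.
from typing import Any, List

from pydantic import BaseModel


class MessageVo(BaseModel):
    role: str          # assumed type; the patch only shows the field being set
    context: str       # assumed type
    order: int         # assumed type
    time_stamp: Any = None
    model_name: str    # new field added by this patch


def message2Vo(message: dict, order, model_name) -> MessageVo:
    return MessageVo(
        role=message["type"],
        context=message["data"]["content"],
        order=order,
        model_name=model_name,
    )


DEFAULT_MODEL = "proxyllm"  # stand-in for CFG.LLM_MODEL

# A conversation dict as produced by _conversation_to_dic; conversations saved
# before this patch have no "model_name" key, so the .get() fallback matters.
once = {
    "chat_order": 1,
    "model_name": "vicuna-13b",
    "messages": [{"type": "human", "data": {"content": "hello"}}],
}

model_name = once.get("model_name", DEFAULT_MODEL)
message_vos: List[MessageVo] = [
    message2Vo(element, once["chat_order"], model_name) for element in once["messages"]
]
print(message_vos[0].model_name)  # -> "vicuna-13b"

The fallbacks on both sides (CFG.LLM_MODEL in get_hist_messages, "proxyllm" in conversation_from_dict and _conversation_to_dic) keep history records written before this patch readable, since those records carry no model_name key.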