diff --git a/pilot/memory/chat_history/duckdb_history.py b/pilot/memory/chat_history/duckdb_history.py index b24546d19..42177be75 100644 --- a/pilot/memory/chat_history/duckdb_history.py +++ b/pilot/memory/chat_history/duckdb_history.py @@ -8,6 +8,7 @@ from pilot.memory.chat_history.base import BaseChatHistoryMemory from pilot.scene.message import ( OnceConversation, conversation_from_dict, + _conversation_to_dic, conversations_to_dict, ) from pilot.common.formatting import MyEncoder @@ -41,12 +42,15 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory): def __get_messages_by_conv_uid(self, conv_uid: str): cursor = self.connect.cursor() cursor.execute("SELECT messages FROM chat_history where conv_uid=?", [conv_uid]) - return cursor.fetchone() - + content = cursor.fetchone() + if content: + return cursor.fetchone()[0] + else: + return None def messages(self) -> List[OnceConversation]: context = self.__get_messages_by_conv_uid(self.chat_seesion_id) if context: - conversations: List[OnceConversation] = json.loads(context[0]) + conversations: List[OnceConversation] = json.loads(context) return conversations return [] @@ -54,15 +58,15 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory): context = self.__get_messages_by_conv_uid(self.chat_seesion_id) conversations: List[OnceConversation] = [] if context: - conversations = json.load(context) - conversations.append(once_message) + conversations = json.loads(context) + conversations.append(_conversation_to_dic(once_message)) cursor = self.connect.cursor() if context: cursor.execute("UPDATE chat_history set messages=? where conv_uid=?", - [json.dumps(conversations_to_dict(conversations), ensure_ascii=False, indent=4), self.chat_seesion_id]) + [json.dumps(conversations, ensure_ascii=False), self.chat_seesion_id]) else: cursor.execute("INSERT INTO chat_history(conv_uid, user_name, messages)VALUES(?,?,?)", - [self.chat_seesion_id, "", json.dumps(conversations_to_dict(conversations), ensure_ascii=False, indent=4)]) + [self.chat_seesion_id, "", json.dumps(conversations_to_dict(conversations), ensure_ascii=False)]) cursor.commit() self.connect.commit() diff --git a/pilot/openapi/api_v1/api_v1.py b/pilot/openapi/api_v1/api_v1.py index b75179eaf..f04c20145 100644 --- a/pilot/openapi/api_v1/api_v1.py +++ b/pilot/openapi/api_v1/api_v1.py @@ -12,7 +12,7 @@ from fastapi.responses import JSONResponse from sse_starlette.sse import EventSourceResponse from typing import List -from pilot.server.api_v1.api_view_model import Result, ConversationVo, MessageVo, ChatSceneVo +from pilot.openapi.api_v1.api_view_model import Result, ConversationVo, MessageVo, ChatSceneVo from pilot.configs.config import Config from pilot.scene.base_chat import BaseChat from pilot.scene.base import ChatScene @@ -70,7 +70,8 @@ async def dialogue_list(response: Response, user_id: str = None): @router.post('/v1/chat/dialogue/scenes', response_model=Result[List[ChatSceneVo]]) async def dialogue_scenes(): scene_vos: List[ChatSceneVo] = [] - new_modes:List[ChatScene] = [ChatScene.ChatDb, ChatScene.ChatData, ChatScene.ChatDashboard, ChatScene.ChatKnowledge, ChatScene.ChatExecution] + new_modes: List[ChatScene] = [ChatScene.ChatWithDbQA, ChatScene.ChatWithDbExecute, ChatScene.ChatDashboard, + ChatScene.ChatKnowledge, ChatScene.ChatExecution] for scene in new_modes: if not scene.value in [ChatScene.ChatNormal.value, ChatScene.InnerChatDBSummary.value]: scene_vo = ChatSceneVo(chat_scene=scene.value, scene_name=scene.name, param_title="Selection Param") @@ -87,7 +88,7 @@ async def dialogue_new(chat_mode: str = ChatScene.ChatNormal.value, user_id: str def get_db_list(): db = CFG.local_db dbs = db.get_database_list() - params:dict = {} + params: dict = {} for name in dbs: params.update({name: name}) return params @@ -108,9 +109,9 @@ def knowledge_list(): @router.post('/v1/chat/mode/params/list', response_model=Result[dict]) async def params_list(chat_mode: str = ChatScene.ChatNormal.value): - if ChatScene.ChatDb.value == chat_mode: + if ChatScene.ChatWithDbQA.value == chat_mode: return Result.succ(get_db_list()) - elif ChatScene.ChatData.value == chat_mode: + elif ChatScene.ChatWithDbExecute.value == chat_mode: return Result.succ(get_db_list()) elif ChatScene.ChatDashboard.value == chat_mode: return Result.succ(get_db_list()) @@ -155,24 +156,21 @@ async def chat_completions(dialogue: ConversationVo = Body()): "user_input": dialogue.user_input, } - if ChatScene.ChatDb == dialogue.chat_mode: - chat_param.update("db_name", dialogue.select_param) - elif ChatScene.ChatData == dialogue.chat_mode: - chat_param.update("db_name", dialogue.select_param) - elif ChatScene.ChatDashboard == dialogue.chat_mode: - chat_param.update("db_name", dialogue.select_param) - elif ChatScene.ChatExecution == dialogue.chat_mode: - chat_param.update("plugin_selector", dialogue.select_param) - elif ChatScene.ChatKnowledge == dialogue.chat_mode: - chat_param.update("knowledge_name", dialogue.select_param) + if ChatScene.ChatWithDbQA.value == dialogue.chat_mode: + chat_param.update({"db_name": dialogue.select_param}) + elif ChatScene.ChatWithDbExecute.value == dialogue.chat_mode: + chat_param.update({"db_name": dialogue.select_param}) + elif ChatScene.ChatDashboard.value == dialogue.chat_mode: + chat_param.update({"db_name": dialogue.select_param}) + elif ChatScene.ChatExecution.value == dialogue.chat_mode: + chat_param.update({"plugin_selector": dialogue.select_param}) + elif ChatScene.ChatKnowledge.value == dialogue.chat_mode: + chat_param.update({"knowledge_name": dialogue.select_param}) chat: BaseChat = CHAT_FACTORY.get_implementation(dialogue.chat_mode, **chat_param) if not chat.prompt_template.stream_out: return non_stream_response(chat) else: - # generator = stream_generator(chat) - # result = Result.succ(data=StreamingResponse(stream_test(), media_type='text/plain')) - # return result return StreamingResponse(stream_generator(chat), media_type="text/plain") @@ -196,12 +194,6 @@ def stream_generator(chat): chat.memory.append(chat.current_message) -# def stream_response(chat): -# logger.info("stream out start!") -# api_response = StreamingResponse(stream_generator(chat), media_type="application/json") -# return api_response - - def message2Vo(message: dict, order) -> MessageVo: # message.additional_kwargs['time_stamp'] if message.additional_kwargs["time_stamp"] else 0 return MessageVo(role=message['type'], context=message['data']['content'], order=order) diff --git a/pilot/scene/base.py b/pilot/scene/base.py index cec443beb..b5e15b4e8 100644 --- a/pilot/scene/base.py +++ b/pilot/scene/base.py @@ -18,8 +18,8 @@ class ChatScene(Enum): ChatNormal = "chat_normal" ChatDashboard = "chat_dashboard" ChatKnowledge = "chat_knowledge" - ChatDb = "chat_db" - ChatData= "chat_data" + # ChatDb = "chat_db" + # ChatData= "chat_data" @staticmethod def is_valid_mode(mode): diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py index d6628a19a..042d6af59 100644 --- a/pilot/scene/base_chat.py +++ b/pilot/scene/base_chat.py @@ -75,7 +75,7 @@ class BaseChat(ABC): self.prompt_template: PromptTemplate = CFG.prompt_templates[ self.chat_mode.value ] - self.history_message: List[OnceConversation] = [] + self.history_message: List[OnceConversation] = self.memory.messages() self.current_message: OnceConversation = OnceConversation(chat_mode.value) self.current_tokens_used: int = 0 ### load chat_session_id's chat historys @@ -95,7 +95,7 @@ class BaseChat(ABC): def generate_input_values(self): pass - @abstractmethod + def do_action(self, prompt_response): return prompt_response diff --git a/pilot/server/webserver.py b/pilot/server/webserver.py index c2cd9d434..86dc0d426 100644 --- a/pilot/server/webserver.py +++ b/pilot/server/webserver.py @@ -58,7 +58,7 @@ from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from fastapi.staticfiles import StaticFiles -from pilot.server.api_v1.api_v1 import router as api_v1, validation_exception_handler +from pilot.openapi.api_v1.api_v1 import router as api_v1, validation_exception_handler # 加载插件 CFG = Config()