diff --git a/pilot/conversation.py b/pilot/conversation.py
index 8a758dd51..3fe648529 100644
--- a/pilot/conversation.py
+++ b/pilot/conversation.py
@@ -108,8 +108,8 @@ class Conversation:
 conv_default = Conversation(
     system = None,
     roles=("human", "ai"),
-    messages= (),
-    offset=2,
+    messages=[],
+    offset=0,
     sep_style=SeparatorStyle.SINGLE,
     sep="###",
 )
diff --git a/pilot/memory/chat_history/mem_history.py b/pilot/memory/chat_history/mem_history.py
new file mode 100644
index 000000000..8ff0d08eb
--- /dev/null
+++ b/pilot/memory/chat_history/mem_history.py
@@ -0,0 +1,37 @@
+from typing import List
+import json
+import os
+import datetime
+from pilot.memory.chat_history.base import BaseChatHistoryMemory
+from pathlib import Path
+
+from pilot.configs.config import Config
+from pilot.scene.message import (
+    OnceConversation,
+    conversation_from_dict,
+    conversations_to_dict,
+)
+
+
+CFG = Config()
+
+
+class MemHistoryMemory(BaseChatHistoryMemory):
+    """In-process chat history store; histories live only for the process lifetime."""
+
+    # chat_session_id -> list of OnceConversation; class-level, shared by all instances.
+    histories_map = {}
+
+    def __init__(self, chat_session_id: str):
+        self.chat_session_id = chat_session_id
+        self.histories_map.update({chat_session_id: []})
+
+    def messages(self) -> List[OnceConversation]:
+        return self.histories_map.get(self.chat_session_id)
+
+    def append(self, once_message: OnceConversation) -> None:
+        self.histories_map.get(self.chat_session_id).append(once_message)
+
+    def clear(self) -> None:
+        # Default avoids a KeyError when clear() is called twice for a session.
+        self.histories_map.pop(self.chat_session_id, None)
diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py
index 798e071e3..dce25bec4 100644
--- a/pilot/scene/base_chat.py
+++ b/pilot/scene/base_chat.py
@@ -23,6 +23,7 @@
 from pilot.scene.message import OnceConversation
 from pilot.prompts.prompt_new import PromptTemplate
 from pilot.memory.chat_history.base import BaseChatHistoryMemory
 from pilot.memory.chat_history.file_history import FileHistoryMemory
+from pilot.memory.chat_history.mem_history import MemHistoryMemory
 from pilot.configs.model_config import LOGDIR, DATASETS_DIR
 from pilot.utils import (
@@ -61,7 +62,10 @@ class BaseChat(ABC):
         self.chat_mode = chat_mode
         self.current_user_input: str = current_user_input
         self.llm_model = CFG.LLM_MODEL
-        ### TODO
+        ### TODO: make the history storage backend configurable.
+        # self.memory = MemHistoryMemory(chat_session_id)
+
+        ## TEST
         self.memory = FileHistoryMemory(chat_session_id)
         ### load prompt template
         self.prompt_template: PromptTemplate = CFG.prompt_templates[self.chat_mode.value]
diff --git a/pilot/scene/chat_db/auto_execute/prompt.py b/pilot/scene/chat_db/auto_execute/prompt.py
index 9a381345f..22cb46846 100644
--- a/pilot/scene/chat_db/auto_execute/prompt.py
+++ b/pilot/scene/chat_db/auto_execute/prompt.py
@@ -14,14 +14,12 @@ PROMPT_SCENE_DEFINE = """You are an AI designed to answer human questions, pleas
 _DEFAULT_TEMPLATE = """
 You are a SQL expert. Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
 Unless the user specifies in his question a specific number of examples he wishes to obtain, always limit your query to at most {top_k} results.
-You can order the results by a relevant column to return the most interesting examples in the database.
-Never query for all the columns from a specific table, only ask for a the few relevant columns given the question.
-If the given table is beyond the scope of use, do not use it forcibly.
+Use as few tables as possible when querying.
 Pay attention to use only the column names that you can see in the schema description.
 Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
 
 """
-PROMPT_SUFFIX = """Only use the following tables:
+PROMPT_SUFFIX = """Only use the following tables to generate sql:
 {table_info}
 
 Question: {input}