Merge remote-tracking branch 'origin/Agent_Hub_Dev' into Agent_Hub_Dev

This commit is contained in:
aries_ckt 2023-10-20 16:25:14 +08:00
commit 8db497f6c0
4 changed files with 75 additions and 35 deletions

View File

@ -227,7 +227,7 @@ class ApiCall:
i += 1
return False
def __check_last_plugin_call_ready(self, all_context):
def check_last_plugin_call_ready(self, all_context):
start_agent_count = all_context.count(self.agent_prefix)
end_agent_count = all_context.count(self.agent_end)
@ -359,7 +359,7 @@ class ApiCall:
def run(self, llm_text):
if self.__is_need_wait_plugin_call(llm_text):
# wait api call generate complete
if self.__check_last_plugin_call_ready(llm_text):
if self.check_last_plugin_call_ready(llm_text):
self.update_from_context(llm_text)
for key, value in self.plugin_status_map.items():
if value.status == Status.TODO.value:
@ -379,7 +379,7 @@ class ApiCall:
def run_display_sql(self, llm_text, sql_run_func):
if self.__is_need_wait_plugin_call(llm_text):
# wait api call generate complete
if self.__check_last_plugin_call_ready(llm_text):
if self.check_last_plugin_call_ready(llm_text):
self.update_from_context(llm_text)
for key, value in self.plugin_status_map.items():
if value.status == Status.TODO.value:

View File

@ -29,6 +29,8 @@ from pilot.scene.message import OnceConversation
from pilot.scene.chat_dashboard.data_loader import DashboardDataLoader
from pilot.scene.chat_db.data_loader import DbDataLoader
from pilot.memory.chat_history.chat_hisotry_factory import ChatHistory
from pilot.base_modules.agent.commands.command_mange import ApiCall
router = APIRouter()
CFG = Config()
@ -101,12 +103,15 @@ async def get_editor_sql(con_uid: str, round: int):
logger.info(
f'history ai json resp:{element["data"]["content"]}'
)
context = (
element["data"]["content"]
.replace("\\n", " ")
.replace("\n", " ")
)
return Result.succ(json.loads(context))
api_call = ApiCall()
result = {}
result['thoughts'] = element["data"]["content"]
if api_call.check_last_plugin_call_ready(element["data"]["content"]):
api_call.update_from_context(element["data"]["content"])
if len(api_call.plugin_status_map) > 0:
first_item = next(iter(api_call.plugin_status_map.items()))[1]
result['sql'] = first_item.args["sql"]
return Result.succ(result)
return Result.faild(msg="not have sql!")
@ -156,17 +161,18 @@ async def sql_editor_submit(sql_edit_context: ChatSqlEditContext = Body()):
)
)[0]
if edit_round:
new_ai_text = ""
for element in edit_round["messages"]:
if element["type"] == "ai":
db_resp = json.loads(element["data"]["content"])
db_resp["thoughts"] = sql_edit_context.new_speak
db_resp["sql"] = sql_edit_context.new_sql
element["data"]["content"] = json.dumps(db_resp)
new_ai_text = element["data"]["content"]
new_ai_text.replace(sql_edit_context.old_sql, sql_edit_context.new_sql)
element["data"]["content"] = new_ai_text
for element in edit_round["messages"]:
if element["type"] == "view":
data_loader = DbDataLoader()
element["data"]["content"] = data_loader.get_table_view_by_conn(
conn.run(sql_edit_context.new_sql), sql_edit_context.new_speak
)
api_call = ApiCall()
new_view_text = api_call.run_display_sql(new_ai_text, conn.run_to_df)
element["data"]["content"] = new_view_text
history_mem.update(history_messages)
return Result.succ(None)
return Result.faild(msg="Edit Faild!")

View File

@ -5,6 +5,7 @@ from pilot.scene.base import ChatScene
from pilot.common.sql_database import Database
from pilot.configs.config import Config
from pilot.scene.chat_db.auto_execute.prompt import prompt
from pilot.base_modules.agent.commands.command_mange import ApiCall
CFG = Config()
@ -37,6 +38,7 @@ class ChatWithDbAutoExecute(BaseChat):
self.database = CFG.LOCAL_DB_MANAGE.get_connect(self.db_name)
self.top_k: int = 200
self.api_call = ApiCall(display_registry=CFG.command_disply)
def generate_input_values(self):
"""
@ -69,6 +71,11 @@ class ChatWithDbAutoExecute(BaseChat):
}
return input_values
def do_action(self, prompt_response):
print(f"do_action:{prompt_response}")
return self.database.run(prompt_response.sql)
def stream_plugin_call(self, text):
text = text.replace("\n", " ")
print(f"stream_plugin_call:{text}")
return self.api_call.run_display_sql(text, self.database.run_to_df)
#
# def do_action(self, prompt_response):
# print(f"do_action:{prompt_response}")
# return self.database.run(prompt_response.sql)

View File

@ -8,24 +8,51 @@ from pilot.scene.chat_db.auto_execute.example import sql_data_example
CFG = Config()
PROMPT_SCENE_DEFINE = "You are a SQL expert. "
_DEFAULT_TEMPLATE = """
_PROMPT_SCENE_DEFINE_EN = "You are a database expert. "
_PROMPT_SCENE_DEFINE_ZH = "你是一个数据库专家. "
_DEFAULT_TEMPLATE_EN = """
Given an input question, create a syntactically correct {dialect} sql.
Unless the user specifies in his question a specific number of examples he wishes to obtain, always limit your query to at most {top_k} results.
Use as few tables as possible when querying.
Only use the following tables schema to generate sql:
{table_info}
Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Table structure information:
{table_info}
Constraint:
1. You can only use the table provided in the table structure information to generate sql. If you cannot generate sql based on the provided table structure, please say: "The table structure information provided is not enough to generate sql query." It is prohibited to fabricate information at will.
2. Do not query columns that do not exist. Pay attention to which column is in which table.
3. Replace the corresponding sql into the sql field in the returned result
4. Unless the user specifies in the question a specific number of examples he wishes to obtain, always limit the query to a maximum of {top_k} results.
5. Please output the Sql content in the following format to execute the corresponding SQL to display the data:<api-call><name>response_table</name><args><sql>SQL Query to run</sql></args></api-call>
Please make sure to respond as following format:
thoughts summary to say to user.<api-call><name>response_table</name><args><sql>SQL Query to run</sql></args></api-call>
Question: {input}
Respond in JSON format as following format:
{response}
Ensure the response is correct json and can be parsed by Python json.loads
"""
_DEFAULT_TEMPLATE_ZH = """
给定一个输入问题创建一个语法正确的 {dialect} sql
已知表结构信息:
{table_info}
约束:
1. 只能使用表结构信息中提供的表来生成 sql如果无法根据提供的表结构中生成 sql 请说提供的表结构信息不足以生成 sql 查询 禁止随意捏造信息
2. 不要查询不存在的列注意哪一列位于哪张表中
3.将对应的sql替换到返回结果中的sql字段中
4.除非用户在问题中指定了他希望获得的具体示例数量否则始终将查询限制为最多 {top_k} 个结果
请务必按照以下格式回复
对用户说的想法摘要<api-call><name>response_table</name><args><sql>要运行的 SQL</sql></args></api-call>
问题:{input}
"""
_DEFAULT_TEMPLATE = (
_DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH
)
PROMPT_SCENE_DEFINE = (
_PROMPT_SCENE_DEFINE_EN if CFG.LANGUAGE == "en" else _PROMPT_SCENE_DEFINE_ZH
)
RESPONSE_FORMAT_SIMPLE = {
"thoughts": "thoughts summary to say to user",
"sql": "SQL Query to run",
@ -33,7 +60,7 @@ RESPONSE_FORMAT_SIMPLE = {
PROMPT_SEP = SeparatorStyle.SINGLE.value
PROMPT_NEED_NEED_STREAM_OUT = False
PROMPT_NEED_NEED_STREAM_OUT = True
# Temperature is a configuration hyperparameter that controls the randomness of language model output.
# A high temperature produces more unpredictable and creative results, while a low temperature produces more common and conservative output.
@ -43,7 +70,7 @@ PROMPT_TEMPERATURE = 0.5
prompt = PromptTemplate(
template_scene=ChatScene.ChatWithDbExecute.value(),
input_variables=["input", "table_info", "dialect", "top_k", "response"],
response_format=json.dumps(RESPONSE_FORMAT_SIMPLE, ensure_ascii=False, indent=4),
# response_format=json.dumps(RESPONSE_FORMAT_SIMPLE, ensure_ascii=False, indent=4),
template_define=PROMPT_SCENE_DEFINE,
template=_DEFAULT_TEMPLATE,
stream_out=PROMPT_NEED_NEED_STREAM_OUT,