mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-10 12:42:34 +00:00
Merge remote-tracking branch 'origin/Agent_Hub_Dev' into Agent_Hub_Dev
This commit is contained in:
commit
8db497f6c0
@ -227,7 +227,7 @@ class ApiCall:
|
|||||||
i += 1
|
i += 1
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def __check_last_plugin_call_ready(self, all_context):
|
def check_last_plugin_call_ready(self, all_context):
|
||||||
start_agent_count = all_context.count(self.agent_prefix)
|
start_agent_count = all_context.count(self.agent_prefix)
|
||||||
end_agent_count = all_context.count(self.agent_end)
|
end_agent_count = all_context.count(self.agent_end)
|
||||||
|
|
||||||
@ -359,7 +359,7 @@ class ApiCall:
|
|||||||
def run(self, llm_text):
|
def run(self, llm_text):
|
||||||
if self.__is_need_wait_plugin_call(llm_text):
|
if self.__is_need_wait_plugin_call(llm_text):
|
||||||
# wait api call generate complete
|
# wait api call generate complete
|
||||||
if self.__check_last_plugin_call_ready(llm_text):
|
if self.check_last_plugin_call_ready(llm_text):
|
||||||
self.update_from_context(llm_text)
|
self.update_from_context(llm_text)
|
||||||
for key, value in self.plugin_status_map.items():
|
for key, value in self.plugin_status_map.items():
|
||||||
if value.status == Status.TODO.value:
|
if value.status == Status.TODO.value:
|
||||||
@ -379,7 +379,7 @@ class ApiCall:
|
|||||||
def run_display_sql(self, llm_text, sql_run_func):
|
def run_display_sql(self, llm_text, sql_run_func):
|
||||||
if self.__is_need_wait_plugin_call(llm_text):
|
if self.__is_need_wait_plugin_call(llm_text):
|
||||||
# wait api call generate complete
|
# wait api call generate complete
|
||||||
if self.__check_last_plugin_call_ready(llm_text):
|
if self.check_last_plugin_call_ready(llm_text):
|
||||||
self.update_from_context(llm_text)
|
self.update_from_context(llm_text)
|
||||||
for key, value in self.plugin_status_map.items():
|
for key, value in self.plugin_status_map.items():
|
||||||
if value.status == Status.TODO.value:
|
if value.status == Status.TODO.value:
|
||||||
|
@ -29,6 +29,8 @@ from pilot.scene.message import OnceConversation
|
|||||||
from pilot.scene.chat_dashboard.data_loader import DashboardDataLoader
|
from pilot.scene.chat_dashboard.data_loader import DashboardDataLoader
|
||||||
from pilot.scene.chat_db.data_loader import DbDataLoader
|
from pilot.scene.chat_db.data_loader import DbDataLoader
|
||||||
from pilot.memory.chat_history.chat_hisotry_factory import ChatHistory
|
from pilot.memory.chat_history.chat_hisotry_factory import ChatHistory
|
||||||
|
from pilot.base_modules.agent.commands.command_mange import ApiCall
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
CFG = Config()
|
CFG = Config()
|
||||||
@ -101,12 +103,15 @@ async def get_editor_sql(con_uid: str, round: int):
|
|||||||
logger.info(
|
logger.info(
|
||||||
f'history ai json resp:{element["data"]["content"]}'
|
f'history ai json resp:{element["data"]["content"]}'
|
||||||
)
|
)
|
||||||
context = (
|
api_call = ApiCall()
|
||||||
element["data"]["content"]
|
result = {}
|
||||||
.replace("\\n", " ")
|
result['thoughts'] = element["data"]["content"]
|
||||||
.replace("\n", " ")
|
if api_call.check_last_plugin_call_ready(element["data"]["content"]):
|
||||||
)
|
api_call.update_from_context(element["data"]["content"])
|
||||||
return Result.succ(json.loads(context))
|
if len(api_call.plugin_status_map) > 0:
|
||||||
|
first_item = next(iter(api_call.plugin_status_map.items()))[1]
|
||||||
|
result['sql'] = first_item.args["sql"]
|
||||||
|
return Result.succ(result)
|
||||||
return Result.faild(msg="not have sql!")
|
return Result.faild(msg="not have sql!")
|
||||||
|
|
||||||
|
|
||||||
@ -156,17 +161,18 @@ async def sql_editor_submit(sql_edit_context: ChatSqlEditContext = Body()):
|
|||||||
)
|
)
|
||||||
)[0]
|
)[0]
|
||||||
if edit_round:
|
if edit_round:
|
||||||
|
new_ai_text = ""
|
||||||
for element in edit_round["messages"]:
|
for element in edit_round["messages"]:
|
||||||
if element["type"] == "ai":
|
if element["type"] == "ai":
|
||||||
db_resp = json.loads(element["data"]["content"])
|
new_ai_text = element["data"]["content"]
|
||||||
db_resp["thoughts"] = sql_edit_context.new_speak
|
new_ai_text.replace(sql_edit_context.old_sql, sql_edit_context.new_sql)
|
||||||
db_resp["sql"] = sql_edit_context.new_sql
|
element["data"]["content"] = new_ai_text
|
||||||
element["data"]["content"] = json.dumps(db_resp)
|
|
||||||
|
for element in edit_round["messages"]:
|
||||||
if element["type"] == "view":
|
if element["type"] == "view":
|
||||||
data_loader = DbDataLoader()
|
api_call = ApiCall()
|
||||||
element["data"]["content"] = data_loader.get_table_view_by_conn(
|
new_view_text = api_call.run_display_sql(new_ai_text, conn.run_to_df)
|
||||||
conn.run(sql_edit_context.new_sql), sql_edit_context.new_speak
|
element["data"]["content"] = new_view_text
|
||||||
)
|
|
||||||
history_mem.update(history_messages)
|
history_mem.update(history_messages)
|
||||||
return Result.succ(None)
|
return Result.succ(None)
|
||||||
return Result.faild(msg="Edit Faild!")
|
return Result.faild(msg="Edit Faild!")
|
||||||
|
@ -5,6 +5,7 @@ from pilot.scene.base import ChatScene
|
|||||||
from pilot.common.sql_database import Database
|
from pilot.common.sql_database import Database
|
||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
from pilot.scene.chat_db.auto_execute.prompt import prompt
|
from pilot.scene.chat_db.auto_execute.prompt import prompt
|
||||||
|
from pilot.base_modules.agent.commands.command_mange import ApiCall
|
||||||
|
|
||||||
CFG = Config()
|
CFG = Config()
|
||||||
|
|
||||||
@ -37,6 +38,7 @@ class ChatWithDbAutoExecute(BaseChat):
|
|||||||
|
|
||||||
self.database = CFG.LOCAL_DB_MANAGE.get_connect(self.db_name)
|
self.database = CFG.LOCAL_DB_MANAGE.get_connect(self.db_name)
|
||||||
self.top_k: int = 200
|
self.top_k: int = 200
|
||||||
|
self.api_call = ApiCall(display_registry=CFG.command_disply)
|
||||||
|
|
||||||
def generate_input_values(self):
|
def generate_input_values(self):
|
||||||
"""
|
"""
|
||||||
@ -69,6 +71,11 @@ class ChatWithDbAutoExecute(BaseChat):
|
|||||||
}
|
}
|
||||||
return input_values
|
return input_values
|
||||||
|
|
||||||
def do_action(self, prompt_response):
|
def stream_plugin_call(self, text):
|
||||||
print(f"do_action:{prompt_response}")
|
text = text.replace("\n", " ")
|
||||||
return self.database.run(prompt_response.sql)
|
print(f"stream_plugin_call:{text}")
|
||||||
|
return self.api_call.run_display_sql(text, self.database.run_to_df)
|
||||||
|
#
|
||||||
|
# def do_action(self, prompt_response):
|
||||||
|
# print(f"do_action:{prompt_response}")
|
||||||
|
# return self.database.run(prompt_response.sql)
|
||||||
|
@ -8,24 +8,51 @@ from pilot.scene.chat_db.auto_execute.example import sql_data_example
|
|||||||
|
|
||||||
CFG = Config()
|
CFG = Config()
|
||||||
|
|
||||||
PROMPT_SCENE_DEFINE = "You are a SQL expert. "
|
|
||||||
|
|
||||||
_DEFAULT_TEMPLATE = """
|
_PROMPT_SCENE_DEFINE_EN = "You are a database expert. "
|
||||||
|
_PROMPT_SCENE_DEFINE_ZH = "你是一个数据库专家. "
|
||||||
|
|
||||||
|
_DEFAULT_TEMPLATE_EN = """
|
||||||
Given an input question, create a syntactically correct {dialect} sql.
|
Given an input question, create a syntactically correct {dialect} sql.
|
||||||
|
Table structure information:
|
||||||
Unless the user specifies in his question a specific number of examples he wishes to obtain, always limit your query to at most {top_k} results.
|
|
||||||
Use as few tables as possible when querying.
|
|
||||||
Only use the following tables schema to generate sql:
|
|
||||||
{table_info}
|
{table_info}
|
||||||
Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
|
Constraint:
|
||||||
|
1. You can only use the table provided in the table structure information to generate sql. If you cannot generate sql based on the provided table structure, please say: "The table structure information provided is not enough to generate sql query." It is prohibited to fabricate information at will.
|
||||||
|
2. Do not query columns that do not exist. Pay attention to which column is in which table.
|
||||||
|
3. Replace the corresponding sql into the sql field in the returned result
|
||||||
|
4. Unless the user specifies in the question a specific number of examples he wishes to obtain, always limit the query to a maximum of {top_k} results.
|
||||||
|
5. Please output the Sql content in the following format to execute the corresponding SQL to display the data:<api-call><name>response_table</name><args><sql>SQL Query to run</sql></args></api-call>
|
||||||
|
Please make sure to respond as following format:
|
||||||
|
thoughts summary to say to user.<api-call><name>response_table</name><args><sql>SQL Query to run</sql></args></api-call>
|
||||||
|
|
||||||
Question: {input}
|
Question: {input}
|
||||||
|
|
||||||
Respond in JSON format as following format:
|
|
||||||
{response}
|
|
||||||
Ensure the response is correct json and can be parsed by Python json.loads
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
_DEFAULT_TEMPLATE_ZH = """
|
||||||
|
给定一个输入问题,创建一个语法正确的 {dialect} sql。
|
||||||
|
已知表结构信息:
|
||||||
|
{table_info}
|
||||||
|
|
||||||
|
约束:
|
||||||
|
1. 只能使用表结构信息中提供的表来生成 sql,如果无法根据提供的表结构中生成 sql ,请说:“提供的表结构信息不足以生成 sql 查询。” 禁止随意捏造信息。
|
||||||
|
2. 不要查询不存在的列,注意哪一列位于哪张表中。
|
||||||
|
3.将对应的sql替换到返回结果中的sql字段中
|
||||||
|
4.除非用户在问题中指定了他希望获得的具体示例数量,否则始终将查询限制为最多 {top_k} 个结果。
|
||||||
|
|
||||||
|
请务必按照以下格式回复:
|
||||||
|
对用户说的想法摘要。<api-call><name>response_table</name><args><sql>要运行的 SQL</sql></args></api-call>
|
||||||
|
|
||||||
|
问题:{input}
|
||||||
|
"""
|
||||||
|
|
||||||
|
_DEFAULT_TEMPLATE = (
|
||||||
|
_DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH
|
||||||
|
)
|
||||||
|
|
||||||
|
PROMPT_SCENE_DEFINE = (
|
||||||
|
_PROMPT_SCENE_DEFINE_EN if CFG.LANGUAGE == "en" else _PROMPT_SCENE_DEFINE_ZH
|
||||||
|
)
|
||||||
|
|
||||||
RESPONSE_FORMAT_SIMPLE = {
|
RESPONSE_FORMAT_SIMPLE = {
|
||||||
"thoughts": "thoughts summary to say to user",
|
"thoughts": "thoughts summary to say to user",
|
||||||
"sql": "SQL Query to run",
|
"sql": "SQL Query to run",
|
||||||
@ -33,7 +60,7 @@ RESPONSE_FORMAT_SIMPLE = {
|
|||||||
|
|
||||||
PROMPT_SEP = SeparatorStyle.SINGLE.value
|
PROMPT_SEP = SeparatorStyle.SINGLE.value
|
||||||
|
|
||||||
PROMPT_NEED_NEED_STREAM_OUT = False
|
PROMPT_NEED_NEED_STREAM_OUT = True
|
||||||
|
|
||||||
# Temperature is a configuration hyperparameter that controls the randomness of language model output.
|
# Temperature is a configuration hyperparameter that controls the randomness of language model output.
|
||||||
# A high temperature produces more unpredictable and creative results, while a low temperature produces more common and conservative output.
|
# A high temperature produces more unpredictable and creative results, while a low temperature produces more common and conservative output.
|
||||||
@ -43,7 +70,7 @@ PROMPT_TEMPERATURE = 0.5
|
|||||||
prompt = PromptTemplate(
|
prompt = PromptTemplate(
|
||||||
template_scene=ChatScene.ChatWithDbExecute.value(),
|
template_scene=ChatScene.ChatWithDbExecute.value(),
|
||||||
input_variables=["input", "table_info", "dialect", "top_k", "response"],
|
input_variables=["input", "table_info", "dialect", "top_k", "response"],
|
||||||
response_format=json.dumps(RESPONSE_FORMAT_SIMPLE, ensure_ascii=False, indent=4),
|
# response_format=json.dumps(RESPONSE_FORMAT_SIMPLE, ensure_ascii=False, indent=4),
|
||||||
template_define=PROMPT_SCENE_DEFINE,
|
template_define=PROMPT_SCENE_DEFINE,
|
||||||
template=_DEFAULT_TEMPLATE,
|
template=_DEFAULT_TEMPLATE,
|
||||||
stream_out=PROMPT_NEED_NEED_STREAM_OUT,
|
stream_out=PROMPT_NEED_NEED_STREAM_OUT,
|
||||||
|
Loading…
Reference in New Issue
Block a user