mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-10 20:52:33 +00:00
bugfix(ChatData): ChatData Use AntV Table
1.Merge ChatData and ChatDB
This commit is contained in:
parent
baeaf1933f
commit
ad7964f4ab
@ -229,7 +229,6 @@ async def dialogue_scenes():
|
|||||||
new_modes: List[ChatScene] = [
|
new_modes: List[ChatScene] = [
|
||||||
ChatScene.ChatWithDbExecute,
|
ChatScene.ChatWithDbExecute,
|
||||||
ChatScene.ChatExcel,
|
ChatScene.ChatExcel,
|
||||||
ChatScene.ChatWithDbQA,
|
|
||||||
ChatScene.ChatKnowledge,
|
ChatScene.ChatKnowledge,
|
||||||
ChatScene.ChatDashboard,
|
ChatScene.ChatDashboard,
|
||||||
ChatScene.ChatAgent,
|
ChatScene.ChatAgent,
|
||||||
|
@ -215,7 +215,7 @@ class BaseOutputParser(ABC):
|
|||||||
cleaned_output = self.__illegal_json_ends(cleaned_output)
|
cleaned_output = self.__illegal_json_ends(cleaned_output)
|
||||||
return cleaned_output
|
return cleaned_output
|
||||||
|
|
||||||
def parse_view_response(self, ai_text, data) -> str:
|
def parse_view_response(self, ai_text, data, parse_prompt_response:Any=None) -> str:
|
||||||
"""
|
"""
|
||||||
parse the ai response info to user view
|
parse the ai response info to user view
|
||||||
Args:
|
Args:
|
||||||
|
@ -25,8 +25,8 @@ class Scene:
|
|||||||
class ChatScene(Enum):
|
class ChatScene(Enum):
|
||||||
ChatWithDbExecute = Scene(
|
ChatWithDbExecute = Scene(
|
||||||
code="chat_with_db_execute",
|
code="chat_with_db_execute",
|
||||||
name="Chat Data",
|
name="Chat DB",
|
||||||
describe="Dialogue with your private data through natural language.",
|
describe="Dialogue with your private databse data through natural language.",
|
||||||
param_types=["DB Select"],
|
param_types=["DB Select"],
|
||||||
)
|
)
|
||||||
ExcelLearning = Scene(
|
ExcelLearning = Scene(
|
||||||
|
@ -227,7 +227,6 @@ class BaseChat(ABC):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
### run
|
### run
|
||||||
# result = self.do_action(prompt_define_response)
|
|
||||||
result = await blocking_func_to_async(
|
result = await blocking_func_to_async(
|
||||||
self._executor, self.do_action, prompt_define_response
|
self._executor, self.do_action, prompt_define_response
|
||||||
)
|
)
|
||||||
@ -243,6 +242,7 @@ class BaseChat(ABC):
|
|||||||
self.prompt_template.output_parser.parse_view_response,
|
self.prompt_template.output_parser.parse_view_response,
|
||||||
speak_to_user,
|
speak_to_user,
|
||||||
result,
|
result,
|
||||||
|
prompt_define_response
|
||||||
)
|
)
|
||||||
|
|
||||||
view_message = view_message.replace("\n", "\\n")
|
view_message = view_message.replace("\n", "\\n")
|
||||||
|
@ -50,17 +50,28 @@ class ChatExcel(BaseChat):
|
|||||||
super().__init__(chat_param=chat_param)
|
super().__init__(chat_param=chat_param)
|
||||||
|
|
||||||
def _generate_numbered_list(self) -> str:
|
def _generate_numbered_list(self) -> str:
|
||||||
command_strings = []
|
antv_charts = [{"line_chart":"used to display comparative trend analysis data"},
|
||||||
if CFG.command_disply:
|
{"pie_chart":"suitable for scenarios such as proportion and distribution statistics"},
|
||||||
for name, item in CFG.command_disply.commands.items():
|
{"response_table":"suitable for display with many display columns or non-numeric columns"},
|
||||||
if item.enabled:
|
{"data_text":" the default display method, suitable for single-line or simple content display"},
|
||||||
command_strings.append(f"{name}:{item.description}")
|
{"scatter_plot":"Suitable for exploring relationships between variables, detecting outliers, etc."},
|
||||||
|
{"bubble_chart":"Suitable for relationships between multiple variables, highlighting outliers or special situations, etc."},
|
||||||
|
{"donut_chart":"Suitable for hierarchical structure representation, category proportion display and highlighting key categories, etc."},
|
||||||
|
{"area_chart":"Suitable for visualization of time series data, comparison of multiple groups of data, analysis of data change trends, etc."},
|
||||||
|
{"heatmap":"Suitable for visual analysis of time series data, large-scale data sets, distribution of classified data, etc."}
|
||||||
|
]
|
||||||
|
|
||||||
|
# command_strings = []
|
||||||
|
# if CFG.command_disply:
|
||||||
|
# for name, item in CFG.command_disply.commands.items():
|
||||||
|
# if item.enabled:
|
||||||
|
# command_strings.append(f"{name}:{item.description}")
|
||||||
# command_strings += [
|
# command_strings += [
|
||||||
# str(item)
|
# str(item)
|
||||||
# for item in CFG.command_disply.commands.values()
|
# for item in CFG.command_disply.commands.values()
|
||||||
# if item.enabled
|
# if item.enabled
|
||||||
# ]
|
# ]
|
||||||
return "\n".join(f"{i+1}. {item}" for i, item in enumerate(command_strings))
|
return "\n".join(f"{key}:{value}" for dict_item in antv_charts for key, value in dict_item.items())
|
||||||
|
|
||||||
async def generate_input_values(self) -> Dict:
|
async def generate_input_values(self) -> Dict:
|
||||||
input_values = {
|
input_values = {
|
||||||
|
@ -22,7 +22,7 @@ Constraint:
|
|||||||
5.The <api-call></api-call> part of the required output format needs to be parsed by the code. Please ensure that this part of the content is output as required.
|
5.The <api-call></api-call> part of the required output format needs to be parsed by the code. Please ensure that this part of the content is output as required.
|
||||||
|
|
||||||
Please respond in the following format:
|
Please respond in the following format:
|
||||||
Summary of your analytical thinking.<api-call><name>[Data display method]</name><args><sql>[Correct duckdb data analysis sql]</sql></args></api-call>
|
thoughts summary to say to user.<api-call><name>[Data display method]</name><args><sql>[Correct duckdb data analysis sql]</sql></args></api-call>
|
||||||
|
|
||||||
User Questions:
|
User Questions:
|
||||||
{user_input}
|
{user_input}
|
||||||
@ -38,7 +38,7 @@ _DEFAULT_TEMPLATE_ZH = """
|
|||||||
4.优先使用数据分析的方式回答,如果用户问题不涉及数据分析内容,你可以按你的理解进行回答
|
4.优先使用数据分析的方式回答,如果用户问题不涉及数据分析内容,你可以按你的理解进行回答
|
||||||
5.要求的输出格式中<api-call></api-call>部分需要被代码解析执行,请确保这部分内容按要求输出
|
5.要求的输出格式中<api-call></api-call>部分需要被代码解析执行,请确保这部分内容按要求输出
|
||||||
请确保你的输出格式如下:
|
请确保你的输出格式如下:
|
||||||
分析思路总结.<api-call><name>[数据展示方式]</name><args><sql>[正确的duckdb数据分析sql]</sql></args></api-call>
|
对用户说的想法摘要.<api-call><name>[数据展示方式]</name><args><sql>[正确的duckdb数据分析sql]</sql></args></api-call>
|
||||||
|
|
||||||
用户问题:{user_input}
|
用户问题:{user_input}
|
||||||
"""
|
"""
|
||||||
|
@ -247,17 +247,12 @@ if __name__ == "__main__":
|
|||||||
# sql = add_quotes_to_chinese_columns(sql_3)
|
# sql = add_quotes_to_chinese_columns(sql_3)
|
||||||
# print(f"excute sql:{sql}")
|
# print(f"excute sql:{sql}")
|
||||||
|
|
||||||
data = {
|
my_list = [{'name': 'John', 'age': 30}, {'name': 'Alice', 'age': 25}, {'name': 'Bob', 'age': 35}]
|
||||||
'name': ['John', 'Alice', 'Bob'],
|
|
||||||
'age': [30, 25, 35],
|
|
||||||
'timestamp': [pd.Timestamp('2022-01-01'), pd.Timestamp('2022-01-02'), pd.Timestamp('2022-01-03')],
|
|
||||||
'category': pd.Categorical(['A', 'B', 'C'])
|
|
||||||
}
|
|
||||||
|
|
||||||
df = pd.DataFrame(data)
|
for dict_item in my_list:
|
||||||
|
for key, value in dict_item.items():
|
||||||
|
print(key, value)
|
||||||
|
|
||||||
json_data = df.to_json(orient='records', date_format='iso', date_unit='s')
|
|
||||||
print(json_data)
|
|
||||||
|
|
||||||
class ExcelReader:
|
class ExcelReader:
|
||||||
def __init__(self, file_path):
|
def __init__(self, file_path):
|
||||||
|
@ -37,7 +37,7 @@ class ChatWithDbAutoExecute(BaseChat):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.database = CFG.LOCAL_DB_MANAGE.get_connect(self.db_name)
|
self.database = CFG.LOCAL_DB_MANAGE.get_connect(self.db_name)
|
||||||
self.top_k: int = 200
|
self.top_k: int = 50
|
||||||
self.api_call = ApiCall(display_registry=CFG.command_disply)
|
self.api_call = ApiCall(display_registry=CFG.command_disply)
|
||||||
|
|
||||||
async def generate_input_values(self):
|
async def generate_input_values(self):
|
||||||
@ -64,7 +64,7 @@ class ChatWithDbAutoExecute(BaseChat):
|
|||||||
# table_infos = self.database.table_simple_info()
|
# table_infos = self.database.table_simple_info()
|
||||||
|
|
||||||
input_values = {
|
input_values = {
|
||||||
"input": self.current_user_input,
|
# "input": self.current_user_input,
|
||||||
"top_k": str(self.top_k),
|
"top_k": str(self.top_k),
|
||||||
"dialect": self.database.dialect,
|
"dialect": self.database.dialect,
|
||||||
"table_info": table_infos,
|
"table_info": table_infos,
|
||||||
@ -79,4 +79,4 @@ class ChatWithDbAutoExecute(BaseChat):
|
|||||||
|
|
||||||
def do_action(self, prompt_response):
|
def do_action(self, prompt_response):
|
||||||
print(f"do_action:{prompt_response}")
|
print(f"do_action:{prompt_response}")
|
||||||
return self.database.run(prompt_response.sql)
|
return self.database.run_to_df
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
import json
|
import json
|
||||||
from typing import Dict, NamedTuple
|
from typing import Dict, NamedTuple
|
||||||
import logging
|
import logging
|
||||||
|
import xml.etree.ElementTree as ET
|
||||||
|
from pilot.common.json_utils import serialize
|
||||||
from pilot.out_parser.base import BaseOutputParser, T
|
from pilot.out_parser.base import BaseOutputParser, T
|
||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
from pilot.scene.chat_db.data_loader import DbDataLoader
|
from pilot.scene.chat_db.data_loader import DbDataLoader
|
||||||
@ -31,21 +33,28 @@ class DbChatOutputParser(BaseOutputParser):
|
|||||||
thoughts = response[key]
|
thoughts = response[key]
|
||||||
return SqlAction(sql, thoughts)
|
return SqlAction(sql, thoughts)
|
||||||
|
|
||||||
def parse_view_response(self, speak, data) -> str:
|
def parse_view_response(self, speak, data, prompt_response) -> str:
|
||||||
import pandas as pd
|
|
||||||
|
param = {}
|
||||||
|
api_call_element = ET.Element("chart-view")
|
||||||
|
try:
|
||||||
|
df = data(prompt_response.sql)
|
||||||
|
param["type"] = "response_table"
|
||||||
|
param["sql"] = prompt_response.sql
|
||||||
|
param["data"] = json.loads(df.to_json(orient='records', date_format='iso', date_unit='s'))
|
||||||
|
view_json_str = json.dumps(param, default=serialize)
|
||||||
|
except Exception as e:
|
||||||
|
err_param ={}
|
||||||
|
param["sql"] = prompt_response.sql
|
||||||
|
err_param["type"] = "response_table"
|
||||||
|
err_param["err_msg"] = str(e)
|
||||||
|
view_json_str = json.dumps(err_param, default=serialize)
|
||||||
|
|
||||||
|
api_call_element.text = view_json_str
|
||||||
|
result = ET.tostring(api_call_element, encoding="utf-8")
|
||||||
|
|
||||||
|
return speak + "\n" + result.decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
### tool out data to table view
|
|
||||||
data_loader = DbDataLoader()
|
|
||||||
if len(data) < 1:
|
|
||||||
data.insert(0, [])
|
|
||||||
df = pd.DataFrame(data[1:], columns=data[0])
|
|
||||||
if not CFG.NEW_SERVER_MODE and not CFG.SERVER_LIGHT_MODE:
|
|
||||||
table_style = """<style>
|
|
||||||
table{border-collapse:collapse;width:100%;height:80%;margin:0 auto;float:center;border: 1px solid #007bff; background-color:#333; color:#fff}th,td{border:1px solid #ddd;padding:3px;text-align:center}th{background-color:#C9C3C7;color: #fff;font-weight: bold;}tr:nth-child(even){background-color:#444}tr:hover{background-color:#444}
|
|
||||||
</style>"""
|
|
||||||
html_table = df.to_html(index=False, escape=False)
|
|
||||||
html = f"<html><head>{table_style}</head><body>{html_table}</body></html>"
|
|
||||||
view_text = f"##### {str(speak)}" + "\n" + html.replace("\n", " ")
|
|
||||||
return view_text
|
|
||||||
else:
|
|
||||||
return data_loader.get_table_view_by_conn(data, speak)
|
|
||||||
|
@ -13,36 +13,35 @@ _PROMPT_SCENE_DEFINE_EN = "You are a database expert. "
|
|||||||
_PROMPT_SCENE_DEFINE_ZH = "你是一个数据库专家. "
|
_PROMPT_SCENE_DEFINE_ZH = "你是一个数据库专家. "
|
||||||
|
|
||||||
_DEFAULT_TEMPLATE_EN = """
|
_DEFAULT_TEMPLATE_EN = """
|
||||||
Given an input question, create a syntactically correct {dialect} sql.
|
Please create a syntactically correct {dialect} sql based on the user question, use the following tables schema to generate sql:
|
||||||
Table structure information:
|
|
||||||
{table_info}
|
{table_info}
|
||||||
|
|
||||||
Constraint:
|
Constraint:
|
||||||
1. You can only use the table provided in the table structure information to generate sql. If you cannot generate sql based on the provided table structure, please say: "The table structure information provided is not enough to generate sql query." It is prohibited to fabricate information at will.
|
1.Unless the user specifies in his question a specific number of examples he wishes to obtain, always limit your query to at most {top_k} results.
|
||||||
2. Do not query columns that do not exist. Pay attention to which column is in which table.
|
2.Please do not use columns that do not appear in the tables schema. Also be careful not to misunderstand the relationship between fields and tables in SQL.
|
||||||
3. Unless the user specifies in the question a specific number of examples he wishes to obtain, always limit the query to a maximum of {top_k} results.
|
3.Use as few tables as possible when querying.
|
||||||
4. Please ensure that the output result contains: <api-call><name>response_table</name><args><sql>SQL Query to run</sql></args></api-call>,and replace the generated sql into the parameter sql.
|
4.Please check the correctness of the SQL and ensure that the query performance is optimized under correct conditions.
|
||||||
|
|
||||||
Please make sure to respond as following format:
|
Please think step by step and respond according to the following JSON format:
|
||||||
thoughts summary to say to user.<api-call><name>response_table</name><args><sql>SQL Query to run</sql></args></api-call>
|
{response}
|
||||||
|
Ensure the response is correct json and can be parsed by Python json.loads.
|
||||||
|
|
||||||
Question: {input}
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_DEFAULT_TEMPLATE_ZH = """
|
_DEFAULT_TEMPLATE_ZH = """
|
||||||
给定一个输入问题,创建一个语法正确的 {dialect} sql。
|
请根据用户输入问题,使用如下的表结构定义创建一个语法正确的 {dialect} sql:
|
||||||
已知表结构信息:
|
|
||||||
{table_info}
|
{table_info}
|
||||||
|
|
||||||
约束:
|
约束:
|
||||||
1. 只能使用表结构信息中提供的表来生成 sql,如果无法根据提供的表结构中生成 sql ,请说:“提供的表结构信息不足以生成 sql 查询。” 禁止随意捏造信息。
|
1. 除非用户在问题中指定了他希望获得的具体数据行数,否则始终将查询限制为最多 {top_k} 个结果。
|
||||||
2. 不要查询不存在的列,注意哪一列位于哪张表中。
|
2. 只能使用表结构信息中提供的表来生成 sql,如果无法根据提供的表结构中生成 sql ,请说:“提供的表结构信息不足以生成 sql 查询。” 禁止随意捏造信息。
|
||||||
3. 请确保输出结果包含<api-call><name>response_table</name><args><sql>要运行的SQL</sql></args></api-call>, 并将对应的sql替换到sql参数中
|
3. 请注意生成SQL时不要弄错表和列的关系
|
||||||
4. 除非用户在问题中指定了他希望获得的具体示例数量,否则始终将查询限制为最多 {top_k} 个结果。
|
4. 请检查SQL的正确性,并保证正确的情况下优化查询性能
|
||||||
|
|
||||||
请务必按照以下格式回复:
|
请一步步思考并按照以下JSON格式回复:
|
||||||
对用户说的想法摘要。<api-call><name>response_table</name><args><sql>要运行的 SQL</sql></args></api-call>
|
{response}
|
||||||
|
确保返回正确的json并且可以被Python json.loads方法解析.
|
||||||
|
|
||||||
问题:{input}
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_DEFAULT_TEMPLATE = (
|
_DEFAULT_TEMPLATE = (
|
||||||
@ -60,7 +59,7 @@ RESPONSE_FORMAT_SIMPLE = {
|
|||||||
|
|
||||||
PROMPT_SEP = SeparatorStyle.SINGLE.value
|
PROMPT_SEP = SeparatorStyle.SINGLE.value
|
||||||
|
|
||||||
PROMPT_NEED_NEED_STREAM_OUT = True
|
PROMPT_NEED_NEED_STREAM_OUT = False
|
||||||
|
|
||||||
# Temperature is a configuration hyperparameter that controls the randomness of language model output.
|
# Temperature is a configuration hyperparameter that controls the randomness of language model output.
|
||||||
# A high temperature produces more unpredictable and creative results, while a low temperature produces more common and conservative output.
|
# A high temperature produces more unpredictable and creative results, while a low temperature produces more common and conservative output.
|
||||||
@ -69,8 +68,8 @@ PROMPT_TEMPERATURE = 0.5
|
|||||||
|
|
||||||
prompt = PromptTemplate(
|
prompt = PromptTemplate(
|
||||||
template_scene=ChatScene.ChatWithDbExecute.value(),
|
template_scene=ChatScene.ChatWithDbExecute.value(),
|
||||||
input_variables=["input", "table_info", "dialect", "top_k", "response"],
|
input_variables=["table_info", "dialect", "top_k", "response"],
|
||||||
# response_format=json.dumps(RESPONSE_FORMAT_SIMPLE, ensure_ascii=False, indent=4),
|
response_format=json.dumps(RESPONSE_FORMAT_SIMPLE, ensure_ascii=False, indent=4),
|
||||||
template_define=PROMPT_SCENE_DEFINE,
|
template_define=PROMPT_SCENE_DEFINE,
|
||||||
template=_DEFAULT_TEMPLATE,
|
template=_DEFAULT_TEMPLATE,
|
||||||
stream_out=PROMPT_NEED_NEED_STREAM_OUT,
|
stream_out=PROMPT_NEED_NEED_STREAM_OUT,
|
||||||
@ -79,6 +78,7 @@ prompt = PromptTemplate(
|
|||||||
),
|
),
|
||||||
# example_selector=sql_data_example,
|
# example_selector=sql_data_example,
|
||||||
temperature=PROMPT_TEMPERATURE,
|
temperature=PROMPT_TEMPERATURE,
|
||||||
|
need_historical_messages=True
|
||||||
)
|
)
|
||||||
CFG.prompt_template_registry.register(prompt, is_default=True)
|
CFG.prompt_template_registry.register(prompt, is_default=True)
|
||||||
from . import prompt_baichuan
|
from . import prompt_baichuan
|
||||||
|
Loading…
Reference in New Issue
Block a user