bugfix(ChatData): ChatData Use AntV Table

1.Merge ChatData and ChatDB
This commit is contained in:
yhjun1026 2023-11-10 11:08:20 +08:00
parent baeaf1933f
commit ad7964f4ab
10 changed files with 80 additions and 66 deletions

View File

@ -229,7 +229,6 @@ async def dialogue_scenes():
new_modes: List[ChatScene] = [
ChatScene.ChatWithDbExecute,
ChatScene.ChatExcel,
ChatScene.ChatWithDbQA,
ChatScene.ChatKnowledge,
ChatScene.ChatDashboard,
ChatScene.ChatAgent,

View File

@ -215,7 +215,7 @@ class BaseOutputParser(ABC):
cleaned_output = self.__illegal_json_ends(cleaned_output)
return cleaned_output
def parse_view_response(self, ai_text, data) -> str:
def parse_view_response(self, ai_text, data, parse_prompt_response:Any=None) -> str:
"""
parse the ai response info to user view
Args:

View File

@ -25,8 +25,8 @@ class Scene:
class ChatScene(Enum):
ChatWithDbExecute = Scene(
code="chat_with_db_execute",
name="Chat Data",
describe="Dialogue with your private data through natural language.",
name="Chat DB",
describe="Dialogue with your private databse data through natural language.",
param_types=["DB Select"],
)
ExcelLearning = Scene(

View File

@ -227,7 +227,6 @@ class BaseChat(ABC):
)
)
### run
# result = self.do_action(prompt_define_response)
result = await blocking_func_to_async(
self._executor, self.do_action, prompt_define_response
)
@ -243,6 +242,7 @@ class BaseChat(ABC):
self.prompt_template.output_parser.parse_view_response,
speak_to_user,
result,
prompt_define_response
)
view_message = view_message.replace("\n", "\\n")

View File

@ -50,17 +50,28 @@ class ChatExcel(BaseChat):
super().__init__(chat_param=chat_param)
def _generate_numbered_list(self) -> str:
command_strings = []
if CFG.command_disply:
for name, item in CFG.command_disply.commands.items():
if item.enabled:
command_strings.append(f"{name}:{item.description}")
antv_charts = [{"line_chart":"used to display comparative trend analysis data"},
{"pie_chart":"suitable for scenarios such as proportion and distribution statistics"},
{"response_table":"suitable for display with many display columns or non-numeric columns"},
{"data_text":" the default display method, suitable for single-line or simple content display"},
{"scatter_plot":"Suitable for exploring relationships between variables, detecting outliers, etc."},
{"bubble_chart":"Suitable for relationships between multiple variables, highlighting outliers or special situations, etc."},
{"donut_chart":"Suitable for hierarchical structure representation, category proportion display and highlighting key categories, etc."},
{"area_chart":"Suitable for visualization of time series data, comparison of multiple groups of data, analysis of data change trends, etc."},
{"heatmap":"Suitable for visual analysis of time series data, large-scale data sets, distribution of classified data, etc."}
]
# command_strings = []
# if CFG.command_disply:
# for name, item in CFG.command_disply.commands.items():
# if item.enabled:
# command_strings.append(f"{name}:{item.description}")
# command_strings += [
# str(item)
# for item in CFG.command_disply.commands.values()
# if item.enabled
# ]
return "\n".join(f"{i+1}. {item}" for i, item in enumerate(command_strings))
return "\n".join(f"{key}:{value}" for dict_item in antv_charts for key, value in dict_item.items())
async def generate_input_values(self) -> Dict:
input_values = {

View File

@ -22,7 +22,7 @@ Constraint:
5.The <api-call></api-call> part of the required output format needs to be parsed by the code. Please ensure that this part of the content is output as required.
Please respond in the following format:
Summary of your analytical thinking.<api-call><name>[Data display method]</name><args><sql>[Correct duckdb data analysis sql]</sql></args></api-call>
thoughts summary to say to user.<api-call><name>[Data display method]</name><args><sql>[Correct duckdb data analysis sql]</sql></args></api-call>
User Questions:
{user_input}
@ -38,7 +38,7 @@ _DEFAULT_TEMPLATE_ZH = """
4.优先使用数据分析的方式回答如果用户问题不涉及数据分析内容你可以按你的理解进行回答
5.要求的输出格式中<api-call></api-call>部分需要被代码解析执行请确保这部分内容按要求输出
请确保你的输出格式如下:
分析思路总结.<api-call><name>[数据展示方式]</name><args><sql>[正确的duckdb数据分析sql]</sql></args></api-call>
对用户说的想法摘要.<api-call><name>[数据展示方式]</name><args><sql>[正确的duckdb数据分析sql]</sql></args></api-call>
用户问题{user_input}
"""

View File

@ -247,17 +247,12 @@ if __name__ == "__main__":
# sql = add_quotes_to_chinese_columns(sql_3)
# print(f"excute sql:{sql}")
data = {
'name': ['John', 'Alice', 'Bob'],
'age': [30, 25, 35],
'timestamp': [pd.Timestamp('2022-01-01'), pd.Timestamp('2022-01-02'), pd.Timestamp('2022-01-03')],
'category': pd.Categorical(['A', 'B', 'C'])
}
my_list = [{'name': 'John', 'age': 30}, {'name': 'Alice', 'age': 25}, {'name': 'Bob', 'age': 35}]
df = pd.DataFrame(data)
for dict_item in my_list:
for key, value in dict_item.items():
print(key, value)
json_data = df.to_json(orient='records', date_format='iso', date_unit='s')
print(json_data)
class ExcelReader:
def __init__(self, file_path):

View File

@ -37,7 +37,7 @@ class ChatWithDbAutoExecute(BaseChat):
)
self.database = CFG.LOCAL_DB_MANAGE.get_connect(self.db_name)
self.top_k: int = 200
self.top_k: int = 50
self.api_call = ApiCall(display_registry=CFG.command_disply)
async def generate_input_values(self):
@ -64,7 +64,7 @@ class ChatWithDbAutoExecute(BaseChat):
# table_infos = self.database.table_simple_info()
input_values = {
"input": self.current_user_input,
# "input": self.current_user_input,
"top_k": str(self.top_k),
"dialect": self.database.dialect,
"table_info": table_infos,
@ -79,4 +79,4 @@ class ChatWithDbAutoExecute(BaseChat):
def do_action(self, prompt_response):
print(f"do_action:{prompt_response}")
return self.database.run(prompt_response.sql)
return self.database.run_to_df

View File

@ -1,6 +1,8 @@
import json
from typing import Dict, NamedTuple
import logging
import xml.etree.ElementTree as ET
from pilot.common.json_utils import serialize
from pilot.out_parser.base import BaseOutputParser, T
from pilot.configs.config import Config
from pilot.scene.chat_db.data_loader import DbDataLoader
@ -31,21 +33,28 @@ class DbChatOutputParser(BaseOutputParser):
thoughts = response[key]
return SqlAction(sql, thoughts)
def parse_view_response(self, speak, data) -> str:
import pandas as pd
def parse_view_response(self, speak, data, prompt_response) -> str:
param = {}
api_call_element = ET.Element("chart-view")
try:
df = data(prompt_response.sql)
param["type"] = "response_table"
param["sql"] = prompt_response.sql
param["data"] = json.loads(df.to_json(orient='records', date_format='iso', date_unit='s'))
view_json_str = json.dumps(param, default=serialize)
except Exception as e:
err_param ={}
param["sql"] = prompt_response.sql
err_param["type"] = "response_table"
err_param["err_msg"] = str(e)
view_json_str = json.dumps(err_param, default=serialize)
api_call_element.text = view_json_str
result = ET.tostring(api_call_element, encoding="utf-8")
return speak + "\n" + result.decode("utf-8")
### tool out data to table view
data_loader = DbDataLoader()
if len(data) < 1:
data.insert(0, [])
df = pd.DataFrame(data[1:], columns=data[0])
if not CFG.NEW_SERVER_MODE and not CFG.SERVER_LIGHT_MODE:
table_style = """<style>
table{border-collapse:collapse;width:100%;height:80%;margin:0 auto;float:center;border: 1px solid #007bff; background-color:#333; color:#fff}th,td{border:1px solid #ddd;padding:3px;text-align:center}th{background-color:#C9C3C7;color: #fff;font-weight: bold;}tr:nth-child(even){background-color:#444}tr:hover{background-color:#444}
</style>"""
html_table = df.to_html(index=False, escape=False)
html = f"<html><head>{table_style}</head><body>{html_table}</body></html>"
view_text = f"##### {str(speak)}" + "\n" + html.replace("\n", " ")
return view_text
else:
return data_loader.get_table_view_by_conn(data, speak)

View File

@ -13,36 +13,35 @@ _PROMPT_SCENE_DEFINE_EN = "You are a database expert. "
_PROMPT_SCENE_DEFINE_ZH = "你是一个数据库专家. "
_DEFAULT_TEMPLATE_EN = """
Given an input question, create a syntactically correct {dialect} sql.
Table structure information:
{table_info}
Constraint:
1. You can only use the table provided in the table structure information to generate sql. If you cannot generate sql based on the provided table structure, please say: "The table structure information provided is not enough to generate sql query." It is prohibited to fabricate information at will.
2. Do not query columns that do not exist. Pay attention to which column is in which table.
3. Unless the user specifies in the question a specific number of examples he wishes to obtain, always limit the query to a maximum of {top_k} results.
4. Please ensure that the output result contains: <api-call><name>response_table</name><args><sql>SQL Query to run</sql></args></api-call>and replace the generated sql into the parameter sql.
Please create a syntactically correct {dialect} sql based on the user question, use the following tables schema to generate sql:
{table_info}
Please make sure to respond as following format:
thoughts summary to say to user.<api-call><name>response_table</name><args><sql>SQL Query to run</sql></args></api-call>
Constraint:
1.Unless the user specifies in his question a specific number of examples he wishes to obtain, always limit your query to at most {top_k} results.
2.Please do not use columns that do not appear in the tables schema. Also be careful not to misunderstand the relationship between fields and tables in SQL.
3.Use as few tables as possible when querying.
4.Please check the correctness of the SQL and ensure that the query performance is optimized under correct conditions.
Question: {input}
Please think step by step and respond according to the following JSON format:
{response}
Ensure the response is correct json and can be parsed by Python json.loads.
"""
_DEFAULT_TEMPLATE_ZH = """
给定一个输入问题创建一个语法正确的 {dialect} sql
已知表结构信息:
请根据用户输入问题使用如下的表结构定义创建一个语法正确的 {dialect} sql:
{table_info}
约束:
1. 只能使用表结构信息中提供的表来生成 sql如果无法根据提供的表结构中生成 sql 请说提供的表结构信息不足以生成 sql 查询 禁止随意捏造信息
2. 不要查询不存在的列注意哪一列位于哪张表中
3. 请确保输出结果包含<api-call><name>response_table</name><args><sql>要运行的SQL</sql></args></api-call> 并将对应的sql替换到sql参数中
4. 除非用户在问题中指定了他希望获得的具体示例数量否则始终将查询限制为最多 {top_k} 个结果
1. 除非用户在问题中指定了他希望获得的具体数据行数否则始终将查询限制为最多 {top_k} 个结果
2. 只能使用表结构信息中提供的表来生成 sql如果无法根据提供的表结构中生成 sql 请说提供的表结构信息不足以生成 sql 查询 禁止随意捏造信息
3. 请注意生成SQL时不要弄错表和列的关系
4. 请检查SQL的正确性并保证正确的情况下优化查询性能
请一步步思考并按照以下JSON格式回复
{response}
确保返回正确的json并且可以被Python json.loads方法解析.
请务必按照以下格式回复
对用户说的想法摘要<api-call><name>response_table</name><args><sql>要运行的 SQL</sql></args></api-call>
问题:{input}
"""
_DEFAULT_TEMPLATE = (
@ -60,7 +59,7 @@ RESPONSE_FORMAT_SIMPLE = {
PROMPT_SEP = SeparatorStyle.SINGLE.value
PROMPT_NEED_NEED_STREAM_OUT = True
PROMPT_NEED_NEED_STREAM_OUT = False
# Temperature is a configuration hyperparameter that controls the randomness of language model output.
# A high temperature produces more unpredictable and creative results, while a low temperature produces more common and conservative output.
@ -69,8 +68,8 @@ PROMPT_TEMPERATURE = 0.5
prompt = PromptTemplate(
template_scene=ChatScene.ChatWithDbExecute.value(),
input_variables=["input", "table_info", "dialect", "top_k", "response"],
# response_format=json.dumps(RESPONSE_FORMAT_SIMPLE, ensure_ascii=False, indent=4),
input_variables=["table_info", "dialect", "top_k", "response"],
response_format=json.dumps(RESPONSE_FORMAT_SIMPLE, ensure_ascii=False, indent=4),
template_define=PROMPT_SCENE_DEFINE,
template=_DEFAULT_TEMPLATE,
stream_out=PROMPT_NEED_NEED_STREAM_OUT,
@ -79,6 +78,7 @@ prompt = PromptTemplate(
),
# example_selector=sql_data_example,
temperature=PROMPT_TEMPERATURE,
need_historical_messages=True
)
CFG.prompt_template_registry.register(prompt, is_default=True)
from . import prompt_baichuan