feat(editor): ChatExcel

🔥ChatExcel Mode Complete
This commit is contained in:
yhjun1026 2023-08-28 21:25:05 +08:00
parent a06e9b29ad
commit d47c27f7dd
5 changed files with 136 additions and 49 deletions

View File

@ -3,99 +3,152 @@ from pandas import DataFrame
from pilot.commands.command_mange import command from pilot.commands.command_mange import command
from pilot.configs.config import Config from pilot.configs.config import Config
import pandas as pd import pandas as pd
import base64 import uuid
import io import io
import os
import matplotlib import matplotlib
import seaborn as sns import seaborn as sns
matplotlib.use("Agg")
# matplotlib.use("Agg")
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from matplotlib.font_manager import FontManager
from pilot.configs.model_config import LOGDIR from pilot.configs.model_config import LOGDIR
from pilot.utils import build_logger from pilot.utils import build_logger
CFG = Config() CFG = Config()
logger = build_logger("show_chart_gen", LOGDIR + "show_chart_gen.log") logger = build_logger("show_chart_gen", LOGDIR + "show_chart_gen.log")
# Directory that receives the generated chart PNGs. The web server imports this
# name and mounts the directory for static serving.
static_message_img_path = os.path.join(os.getcwd(), "message/img")
# Create it at import time: saving a chart into a missing directory fails, and
# Starlette's StaticFiles validates the directory when it is mounted (default
# check_dir behaviour), so a fresh checkout would otherwise not start.
os.makedirs(static_message_img_path, exist_ok=True)
def zh_font_set():
    """Point matplotlib's sans-serif family at the installed CJK fonts.

    Probes the fonts known to a fresh ``FontManager`` for a fixed list of
    Chinese-capable families. When at least one is present,
    ``plt.rcParams['font.sans-serif']`` is replaced with the matches, kept in
    preference order. When none is present, rcParams is left untouched.
    """
    preferred_families = ('Heiti TC', 'Songti SC', 'STHeiti Light', 'Microsoft YaHei', 'SimSun', 'SimHei', 'KaiTi')
    installed_families = {font.name for font in FontManager().ttflist}
    usable_families = [family for family in preferred_families if family in installed_families]
    if usable_families:
        plt.rcParams['font.sans-serif'] = usable_families
@command("response_line_chart", "Line chart display, used to display comparative trend analysis data",
         '"speak": "<speak>", "df":"<data frame>"')
def response_line_chart(speak: str, df: DataFrame) -> str:
    """Render ``df`` as a line chart PNG and return HTML that embeds it.

    The first column of ``df`` is used for the x axis and the second for the
    y axis. The image is saved under ``static_message_img_path`` with a unique
    ``line_<uuid>.png`` name and referenced through the ``/images/`` URL prefix.

    Args:
        speak: Caption rendered in the ``<h5>`` heading above the image.
        df: Data to plot; must contain at least one cell.

    Returns:
        An HTML fragment with the caption and an ``<img>`` tag for the chart.

    Raises:
        ValueError: If ``df`` is empty.
    """
    logger.info(f"response_line_chart:{speak},")
    columns = df.columns.tolist()
    if df.size <= 0:
        raise ValueError("No Data")

    # Font setup: zh_font_set() switches rcParams['font.sans-serif'] to the
    # installed CJK fonts when any exist, otherwise leaves it unchanged.
    zh_font_set()
    # Snapshot the active list now and hand it back through ``rc`` below;
    # seaborn's style presets carry their own font.sans-serif entry and would
    # otherwise replace it.
    sans_serif_fonts = list(plt.rcParams['font.sans-serif'])
    rc = {'font.sans-serif': sans_serif_fonts}
    plt.rcParams['axes.unicode_minus'] = False  # keep the minus sign renderable with CJK fonts

    if sans_serif_fonts:
        # Guarded on purpose: indexing an empty font list raised IndexError on
        # hosts where none of the preferred CJK fonts is installed.
        sns.set(font=sans_serif_fonts[0], font_scale=0.8)  # let seaborn render Chinese text
    sns.set_palette("Set3")  # colour theme
    sns.set_style("dark")
    # NOTE(review): this sns.set() call appears to re-apply seaborn's default
    # palette and font scale, so the three calls above probably have no lasting
    # effect - confirm against the seaborn docs before removing them.
    sns.set(context='notebook', style='ticks', rc=rc)

    fig, ax = plt.subplots(figsize=(8, 5), dpi=100)
    try:
        sns.lineplot(df, x=columns[0], y=columns[1], ax=ax)

        chart_name = "line_" + str(uuid.uuid1()) + ".png"
        # The target directory may not exist yet on a fresh deployment.
        os.makedirs(static_message_img_path, exist_ok=True)
        chart_path = os.path.join(static_message_img_path, chart_name)
        # Save through the figure object rather than pyplot's "current figure".
        fig.savefig(chart_path, bbox_inches='tight', dpi=100)
    finally:
        # pyplot keeps every figure referenced until it is closed; without this
        # a long-running server accumulates one figure per request.
        plt.close(fig)

    html_img = f"""<h5>{speak}</h5><img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />"""
    return html_img
@command("response_bar_chart", "Histogram, suitable for comparative analysis of multiple target values",
         '"speak": "<speak>", "df":"<data frame>"')
def response_bar_chart(speak: str, df: DataFrame) -> str:
    """Render ``df`` as a bar chart PNG and return HTML that embeds it.

    The first column of ``df`` supplies the x values and the second the bar
    heights. The image is saved under ``static_message_img_path`` with a unique
    ``bar_<uuid>.png`` name and referenced through the ``/images/`` URL prefix.

    Args:
        speak: Caption rendered in the ``<h5>`` heading above the image.
        df: Data to plot; must contain at least one cell.

    Returns:
        An HTML fragment with the caption and an ``<img>`` tag for the chart.

    Raises:
        ValueError: If ``df`` is empty.
    """
    logger.info(f"response_bar_chart:{speak},")
    columns = df.columns.tolist()
    if df.size <= 0:
        raise ValueError("No Data")

    # Font setup: zh_font_set() switches rcParams['font.sans-serif'] to the
    # installed CJK fonts when any exist, otherwise leaves it unchanged.
    zh_font_set()
    # Snapshot the active list now and hand it back through ``rc`` below;
    # seaborn's style presets carry their own font.sans-serif entry and would
    # otherwise replace it.
    sans_serif_fonts = list(plt.rcParams['font.sans-serif'])
    rc = {'font.sans-serif': sans_serif_fonts}
    plt.rcParams['axes.unicode_minus'] = False  # keep the minus sign renderable with CJK fonts

    if sans_serif_fonts:
        # Guarded on purpose: indexing an empty font list raised IndexError on
        # hosts where none of the preferred CJK fonts is installed.
        sns.set(font=sans_serif_fonts[0], font_scale=0.8)  # let seaborn render Chinese text
    sns.set_palette("Set3")  # colour theme
    sns.set_style("dark")
    # NOTE(review): this sns.set() call appears to re-apply seaborn's default
    # palette and font scale, so the three calls above probably have no lasting
    # effect - confirm against the seaborn docs before removing them.
    sns.set(context='notebook', style='ticks', rc=rc)

    fig, ax = plt.subplots(figsize=(8, 5), dpi=100)
    try:
        # Applied to the freshly created axes, before the bars are drawn.
        ax.ticklabel_format(style='plain')
        sns.barplot(df, x=df[columns[0]], y=df[columns[1]], ax=ax)

        # Bar charts previously reused the "pie_" prefix, which made the saved
        # files indistinguishable from pie charts.
        chart_name = "bar_" + str(uuid.uuid1()) + ".png"
        # The target directory may not exist yet on a fresh deployment.
        os.makedirs(static_message_img_path, exist_ok=True)
        chart_path = os.path.join(static_message_img_path, chart_name)
        # Save through the figure object rather than pyplot's "current figure".
        fig.savefig(chart_path, bbox_inches='tight', dpi=100)
    finally:
        # pyplot keeps every figure referenced until it is closed; without this
        # a long-running server accumulates one figure per request.
        plt.close(fig)

    html_img = f"""<h5>{speak}</h5><img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />"""
    return html_img
@command("response_pie_chart", "Pie chart, suitable for scenarios such as proportion and distribution statistics",
         '"speak": "<speak>", "df":"<data frame>"')
def response_pie_chart(speak: str, df: DataFrame) -> str:
    """Render ``df`` as a pie chart PNG and return HTML that embeds it.

    The first column of ``df`` supplies the slice labels and the second the
    slice values. The image is saved under ``static_message_img_path`` with a
    unique ``pie_<uuid>.png`` name and referenced through the ``/images/`` URL
    prefix. Backticks in ``speak`` are replaced with double quotes in the
    caption.

    Args:
        speak: Caption rendered in the ``<h5>`` heading above the image.
        df: Data to plot; must contain at least one cell.

    Returns:
        An HTML fragment with the caption and an ``<img>`` tag for the chart.

    Raises:
        ValueError: If ``df`` is empty.
    """
    logger.info(f"response_pie_chart:{speak},")
    columns = df.columns.tolist()
    if df.size <= 0:
        raise ValueError("No Data")

    # Font setup: zh_font_set() switches rcParams['font.sans-serif'] to the
    # installed CJK fonts when any exist, otherwise leaves it unchanged.
    zh_font_set()
    plt.rcParams['axes.unicode_minus'] = False  # keep the minus sign renderable with CJK fonts
    sns.set_palette("Set3")  # colour theme

    slice_labels = df[columns[0]].values
    fig, ax = plt.subplots(figsize=(8, 5), dpi=100)
    try:
        ax = df.plot(kind='pie', y=columns[1], ax=ax, labels=slice_labels, startangle=90, autopct='%1.1f%%')
        # Place and size the legend explicitly.
        ax.legend(loc='upper right', bbox_to_anchor=(0, 0, 1, 1), labels=slice_labels, fontsize=10)
        ax.axis('equal')  # equal aspect so the pie is drawn as a circle

        chart_name = "pie_" + str(uuid.uuid1()) + ".png"
        # The target directory may not exist yet on a fresh deployment.
        os.makedirs(static_message_img_path, exist_ok=True)
        chart_path = os.path.join(static_message_img_path, chart_name)
        # Save through the figure object rather than pyplot's "current figure".
        fig.savefig(chart_path, bbox_inches='tight', dpi=100)
    finally:
        # pyplot keeps every figure referenced until it is closed; without this
        # a long-running server accumulates one figure per request.
        plt.close(fig)

    caption = speak.replace("`", '"')
    html_img = f"""<h5>{caption}</h5><img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />"""
    return html_img

View File

@ -9,7 +9,7 @@ CFG = Config()
PROMPT_SCENE_DEFINE = "You are a data analysis expert. " PROMPT_SCENE_DEFINE = "You are a data analysis expert. "
_DEFAULT_TEMPLATE = """ _DEFAULT_TEMPLATE_EN = """
Please use the data structure information of the above historical dialogue, make sure not to use column names that are not in the data structure. Please use the data structure information of the above historical dialogue, make sure not to use column names that are not in the data structure.
According to the user goal: {user_input}give the correct duckdb SQL for data analysis. According to the user goal: {user_input}give the correct duckdb SQL for data analysis.
Use the table name: {table_name} Use the table name: {table_name}
@ -24,12 +24,29 @@ Ensure the response is correct json and can be parsed by Python json.loads
""" """
# Chinese-language variant of the SQL-analysis prompt. It carries the
# {user_input}, {table_name}, {disply_type} and {response} placeholders.
# NOTE(review): "disply_type" is spelled this way inside the placeholder;
# presumably it has to match the variable name supplied when the prompt is
# formatted - verify against the caller before renaming it.
_DEFAULT_TEMPLATE_ZH = """
请使用上述历史对话中的数据结构和列信息根据用户目标{user_input}给出正确的duckdb SQL进行数据分析和问题回答
请确保不要使用不在数据结构中的列名
SQL中需要使用的表名是: {table_name}
根据用户目标得到的分析SQL请从以下显示类型中选择最合适的一种用来展示结果数据如果无法确定则使用'Text'作为显示
显示类型如下:
{disply_type}
以以下 json 格式响应:
{response}
确保响应是正确的json,并且可以被Python的json.loads方法解析.
"""
RESPONSE_FORMAT_SIMPLE = { RESPONSE_FORMAT_SIMPLE = {
"sql": "analysis SQL", "sql": "analysis SQL",
"thoughts": "Current thinking and value of data analysis", "thoughts": "Current thinking and value of data analysis",
"display": "display type name" "display": "display type name"
} }
# Choose the prompt language from configuration: the English template when
# CFG.LANGUAGE is exactly "en", the Chinese template for any other value.
if CFG.LANGUAGE == "en":
    _DEFAULT_TEMPLATE = _DEFAULT_TEMPLATE_EN
else:
    _DEFAULT_TEMPLATE = _DEFAULT_TEMPLATE_ZH
PROMPT_SEP = SeparatorStyle.SINGLE.value PROMPT_SEP = SeparatorStyle.SINGLE.value

View File

@ -9,7 +9,7 @@ CFG = Config()
PROMPT_SCENE_DEFINE = "You are a data analysis expert. " PROMPT_SCENE_DEFINE = "You are a data analysis expert. "
_DEFAULT_TEMPLATE = """ _DEFAULT_TEMPLATE_EN = """
This is an example dataplease learn to understand the structure and content of this data: This is an example dataplease learn to understand the structure and content of this data:
{data_example} {data_example}
Explain the meaning and function of each column, and give a simple and clear explanation of the technical terms. Explain the meaning and function of each column, and give a simple and clear explanation of the technical terms.
@ -19,12 +19,25 @@ Please return your answer in JSON format, the return format is as follows:
{response} {response}
""" """
# Chinese-language variant of the data-learning prompt. It carries the
# {data_example} and {response} placeholders.
_DEFAULT_TEMPLATE_ZH = """
下面是一份示例数据请学习理解该数据的结构和内容:
{data_example}
分析各列数据的含义和作用并对专业术语进行简单明了的解释
提供一些分析方案思路请一步一步思考
请以JSON格式返回您的答案返回格式如下
{response}
"""
RESPONSE_FORMAT_SIMPLE = { RESPONSE_FORMAT_SIMPLE = {
"DataAnalysis": "数据内容分析总结", "DataAnalysis": "数据内容分析总结",
"ColumnAnalysis": [{"column name1": "字段1介绍专业术语解释(请尽量简单明了)"}], "ColumnAnalysis": [{"column name1": "字段1介绍专业术语解释(请尽量简单明了)"}],
"AnalysisProgram": ["1.分析方案1图表展示方式1", "2.分析方案2图表展示方式2"], "AnalysisProgram": ["1.分析方案1图表展示方式1", "2.分析方案2图表展示方式2"],
} }
# Choose the prompt language from configuration: the English template when
# CFG.LANGUAGE is exactly "en", the Chinese template for any other value.
if CFG.LANGUAGE == "en":
    _DEFAULT_TEMPLATE = _DEFAULT_TEMPLATE_EN
else:
    _DEFAULT_TEMPLATE = _DEFAULT_TEMPLATE_ZH
PROMPT_SEP = SeparatorStyle.SINGLE.value PROMPT_SEP = SeparatorStyle.SINGLE.value
@ -51,3 +64,4 @@ prompt = PromptTemplate(
) )
CFG.prompt_template_registry.register(prompt, is_default=True) CFG.prompt_template_registry.register(prompt, is_default=True)

View File

@ -13,6 +13,7 @@ def excel_colunm_format(old_name:str)->str:
return new_column return new_column
def add_quotes(sql, column_names=[]): def add_quotes(sql, column_names=[]):
sql = sql.replace("`", "")
parsed = sqlparse.parse(sql) parsed = sqlparse.parse(sql)
for stmt in parsed: for stmt in parsed:
for token in stmt.tokens: for token in stmt.tokens:
@ -76,8 +77,10 @@ class ExcelReader:
self.db.register(self.table_name, self.df) self.db.register(self.table_name, self.df)
def run(self, sql): def run(self, sql):
sql = sql.replace(self.table_name, f'"{self.table_name}"') if f'"{self.table_name}"' not in sql:
sql = sql.replace(self.table_name, f'"{self.table_name}"')
sql = add_quotes(sql, self.columns_map.values()) sql = add_quotes(sql, self.columns_map.values())
print(f"excute sql:{sql}")
results = self.db.execute(sql) results = self.db.execute(sql)
colunms = [] colunms = []
for descrip in results.description: for descrip in results.description:

View File

@ -32,6 +32,7 @@ from pilot.server.knowledge.api import router as knowledge_router
from pilot.openapi.api_v1.api_v1 import router as api_v1 from pilot.openapi.api_v1.api_v1 import router as api_v1
from pilot.openapi.base import validation_exception_handler from pilot.openapi.base import validation_exception_handler
from pilot.openapi.api_v1.editor.api_editor_v1 import router as api_editor_route_v1 from pilot.openapi.api_v1.editor.api_editor_v1 import router as api_editor_route_v1
from pilot.commands.disply_type.show_chart_gen import static_message_img_path
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
@ -79,10 +80,9 @@ app.include_router(api_editor_route_v1, prefix="/api")
app.include_router(knowledge_router) app.include_router(knowledge_router)
# app.include_router(api_editor_route_v1) # app.include_router(api_editor_route_v1)
app.mount("/images", StaticFiles(directory=static_message_img_path, html=True), name="static2")
app.mount("/_next/static", StaticFiles(directory=static_file_path + "/_next/static")) app.mount("/_next/static", StaticFiles(directory=static_file_path + "/_next/static"))
app.mount("/", StaticFiles(directory=static_file_path, html=True), name="static") app.mount("/", StaticFiles(directory=static_file_path, html=True), name="static")
# app.mount("/chat", StaticFiles(directory=static_file_path + "/chat.html", html=True), name="chat")
app.add_exception_handler(RequestValidationError, validation_exception_handler) app.add_exception_handler(RequestValidationError, validation_exception_handler)