mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-10 12:42:34 +00:00
feat(editor): ChatExcel
🔥ChatExcel Mode Complete
This commit is contained in:
parent
a06e9b29ad
commit
d47c27f7dd
@ -3,99 +3,152 @@ from pandas import DataFrame
|
||||
from pilot.commands.command_mange import command
|
||||
from pilot.configs.config import Config
|
||||
import pandas as pd
|
||||
import base64
|
||||
import uuid
|
||||
import io
|
||||
import os
|
||||
import matplotlib
|
||||
import seaborn as sns
|
||||
matplotlib.use("Agg")
|
||||
|
||||
# matplotlib.use("Agg")
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.font_manager import FontManager
|
||||
|
||||
from pilot.configs.model_config import LOGDIR
|
||||
from pilot.utils import build_logger
|
||||
|
||||
CFG = Config()
|
||||
logger = build_logger("show_chart_gen", LOGDIR + "show_chart_gen.log")
|
||||
static_message_img_path = os.path.join(os.getcwd(), "message/img")
|
||||
|
||||
def zh_font_set():
|
||||
font_names = ['Heiti TC', 'Songti SC', 'STHeiti Light', 'Microsoft YaHei', 'SimSun', 'SimHei', 'KaiTi']
|
||||
fm = FontManager()
|
||||
mat_fonts = set(f.name for f in fm.ttflist)
|
||||
can_use_fonts = []
|
||||
for font_name in font_names:
|
||||
if font_name in mat_fonts:
|
||||
can_use_fonts.append(font_name)
|
||||
if len(can_use_fonts) > 0:
|
||||
plt.rcParams['font.sans-serif'] = can_use_fonts
|
||||
|
||||
|
||||
@command("response_line_chart", "Line chart display, used to display comparative trend analysis data", '"speak": "<speak>", "df":"<data frame>"')
|
||||
def response_line_chart(speak: str, df: DataFrame) -> str:
|
||||
@command("response_line_chart", "Line chart display, used to display comparative trend analysis data",
|
||||
'"speak": "<speak>", "df":"<data frame>"')
|
||||
def response_line_chart(speak: str, df: DataFrame) -> str:
|
||||
logger.info(f"response_line_chart:{speak},")
|
||||
|
||||
columns = df.columns.tolist()
|
||||
|
||||
if df.size <= 0:
|
||||
raise ValueError("No Data!")
|
||||
plt.rcParams["font.family"] = ["sans-serif"]
|
||||
rc = {"font.sans-serif": "SimHei", "axes.unicode_minus": False}
|
||||
sns.set_style(rc={'font.sans-serif': "Microsoft Yahei"})
|
||||
sns.set(context="notebook", style="ticks", color_codes=True, rc=rc)
|
||||
|
||||
# set font
|
||||
# zh_font_set()
|
||||
font_names = ['Heiti TC', 'Songti SC', 'STHeiti Light', 'Microsoft YaHei', 'SimSun', 'SimHei', 'KaiTi']
|
||||
fm = FontManager()
|
||||
mat_fonts = set(f.name for f in fm.ttflist)
|
||||
can_use_fonts = []
|
||||
for font_name in font_names:
|
||||
if font_name in mat_fonts:
|
||||
can_use_fonts.append(font_name)
|
||||
if len(can_use_fonts) > 0:
|
||||
plt.rcParams['font.sans-serif'] = can_use_fonts
|
||||
|
||||
rc = {'font.sans-serif': can_use_fonts}
|
||||
plt.rcParams['axes.unicode_minus'] = False # 解决无法显示符号的问题
|
||||
|
||||
sns.set(font=can_use_fonts[0], font_scale=0.8) # 解决Seaborn中文显示问题
|
||||
sns.set_palette("Set3") # 设置颜色主题
|
||||
sns.set_style("dark")
|
||||
sns.color_palette("hls", 10)
|
||||
sns.hls_palette(8, l=.5, s=.7)
|
||||
sns.set(context='notebook', style='ticks', rc=rc)
|
||||
|
||||
fig, ax = plt.subplots(figsize=(8, 5), dpi=100)
|
||||
sns.lineplot(df, x=columns[0], y=columns[1], ax=ax)
|
||||
plt.title("")
|
||||
|
||||
buf = io.BytesIO()
|
||||
plt.savefig(buf, format="png", dpi=100)
|
||||
buf.seek(0)
|
||||
data = base64.b64encode(buf.getvalue()).decode("ascii")
|
||||
chart_name = "line_" + str(uuid.uuid1()) + ".png"
|
||||
chart_path = static_message_img_path + "/" + chart_name
|
||||
plt.savefig(chart_path, bbox_inches='tight', dpi=100)
|
||||
|
||||
html_img = f"""<h5>{speak}</h5><img style='max-width: 100%; max-height: 70%;' src="data:image/png;base64,{data}" />"""
|
||||
html_img = f"""<h5>{speak}</h5><img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />"""
|
||||
return html_img
|
||||
|
||||
|
||||
|
||||
@command("response_bar_chart", "Histogram, suitable for comparative analysis of multiple target values", '"speak": "<speak>", "df":"<data frame>"')
|
||||
def response_bar_chart(speak: str, df: DataFrame) -> str:
|
||||
@command("response_bar_chart", "Histogram, suitable for comparative analysis of multiple target values",
|
||||
'"speak": "<speak>", "df":"<data frame>"')
|
||||
def response_bar_chart(speak: str, df: DataFrame) -> str:
|
||||
logger.info(f"response_bar_chart:{speak},")
|
||||
columns = df.columns.tolist()
|
||||
if df.size <= 0:
|
||||
raise ValueError("No Data!")
|
||||
plt.rcParams["font.family"] = ["sans-serif"]
|
||||
rc = {'font.sans-serif': "Microsoft Yahei"}
|
||||
sns.set(context="notebook", color_codes=True, rc=rc)
|
||||
|
||||
# set font
|
||||
# zh_font_set()
|
||||
font_names = ['Heiti TC', 'Songti SC', 'STHeiti Light', 'Microsoft YaHei', 'SimSun', 'SimHei', 'KaiTi']
|
||||
fm = FontManager()
|
||||
mat_fonts = set(f.name for f in fm.ttflist)
|
||||
can_use_fonts = []
|
||||
for font_name in font_names:
|
||||
if font_name in mat_fonts:
|
||||
can_use_fonts.append(font_name)
|
||||
if len(can_use_fonts) > 0:
|
||||
plt.rcParams['font.sans-serif'] = can_use_fonts
|
||||
|
||||
rc = {'font.sans-serif': can_use_fonts}
|
||||
plt.rcParams['axes.unicode_minus'] = False # 解决无法显示符号的问题
|
||||
sns.set(font=can_use_fonts[0], font_scale=0.8) # 解决Seaborn中文显示问题
|
||||
sns.set_palette("Set3") # 设置颜色主题
|
||||
sns.set_style("dark")
|
||||
sns.color_palette("hls", 10)
|
||||
sns.hls_palette(8, l=.5, s=.7)
|
||||
sns.set(context='notebook', style='ticks', rc=rc)
|
||||
|
||||
fig, ax = plt.subplots(figsize=(8, 5), dpi=100)
|
||||
plt.ticklabel_format(style='plain')
|
||||
sns.barplot(df, x=df[columns[0]], y=df[columns[1]], ax=ax)
|
||||
|
||||
plt.title("")
|
||||
|
||||
buf = io.BytesIO()
|
||||
plt.savefig(buf, format="png", dpi=100)
|
||||
buf.seek(0)
|
||||
data = base64.b64encode(buf.getvalue()).decode("ascii")
|
||||
|
||||
html_img = f"""<h5>{speak}</h5><img style='max-width: 100%; max-height: 70%;' src="data:image/png;base64,{data}" />"""
|
||||
chart_name = "pie_" + str(uuid.uuid1()) + ".png"
|
||||
chart_path = static_message_img_path + "/" + chart_name
|
||||
plt.savefig(chart_path, bbox_inches='tight', dpi=100)
|
||||
html_img = f"""<h5>{speak}</h5><img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />"""
|
||||
return html_img
|
||||
|
||||
|
||||
|
||||
@command("response_pie_chart", "Pie chart, suitable for scenarios such as proportion and distribution statistics", '"speak": "<speak>", "df":"<data frame>"')
|
||||
def response_pie_chart(speak: str, df: DataFrame) -> str:
|
||||
@command("response_pie_chart", "Pie chart, suitable for scenarios such as proportion and distribution statistics",
|
||||
'"speak": "<speak>", "df":"<data frame>"')
|
||||
def response_pie_chart(speak: str, df: DataFrame) -> str:
|
||||
logger.info(f"response_pie_chart:{speak},")
|
||||
columns = df.columns.tolist()
|
||||
if df.size <= 0:
|
||||
raise ValueError("No Data!")
|
||||
plt.rcParams["font.family"] = ["sans-serif"]
|
||||
rc = {"font.sans-serif": "SimHei", "axes.unicode_minus": False}
|
||||
sns.set_style(rc={'font.sans-serif': "Microsoft Yahei"})
|
||||
sns.set(context="notebook", style="ticks", color_codes=True, rc=rc)
|
||||
# set font
|
||||
# zh_font_set()
|
||||
font_names = ['Heiti TC', 'Songti SC', 'STHeiti Light', 'Microsoft YaHei', 'SimSun', 'SimHei', 'KaiTi']
|
||||
fm = FontManager()
|
||||
mat_fonts = set(f.name for f in fm.ttflist)
|
||||
can_use_fonts = []
|
||||
for font_name in font_names:
|
||||
if font_name in mat_fonts:
|
||||
can_use_fonts.append(font_name)
|
||||
if len(can_use_fonts) > 0:
|
||||
plt.rcParams['font.sans-serif'] = can_use_fonts
|
||||
plt.rcParams['axes.unicode_minus'] = False # 解决无法显示符号的问题
|
||||
|
||||
sns.set_palette("Set3") # 设置颜色主题
|
||||
|
||||
# fig, ax = plt.pie(df[columns[1]], labels=df[columns[0]], autopct='%1.1f%%', startangle=90)
|
||||
fig, ax = plt.subplots(figsize=(8, 5), dpi=100)
|
||||
ax = df.plot(kind='pie', y=columns[1], ax=ax, labels=df[columns[0]].values, startangle=90, autopct='%1.1f%%')
|
||||
ax = df.plot(kind='pie', y=columns[1], ax=ax, labels=df[columns[0]].values, startangle=90, autopct='%1.1f%%')
|
||||
# 手动设置 labels 的位置和大小
|
||||
ax.legend(loc='upper right', bbox_to_anchor=(1, 1, 1, 1), labels=df[columns[0]].values, fontsize=10)
|
||||
ax.legend(loc='upper right', bbox_to_anchor=(0, 0, 1, 1), labels=df[columns[0]].values, fontsize=10)
|
||||
|
||||
plt.axis('equal') # 使饼图为正圆形
|
||||
# plt.title(columns[0])
|
||||
|
||||
buf = io.BytesIO()
|
||||
plt.savefig(buf, format="png", dpi=100)
|
||||
buf.seek(0)
|
||||
data = base64.b64encode(buf.getvalue()).decode("ascii")
|
||||
chart_name = "pie_" + str(uuid.uuid1()) + ".png"
|
||||
chart_path = static_message_img_path + "/" + chart_name
|
||||
plt.savefig(chart_path, bbox_inches='tight', dpi=100)
|
||||
|
||||
html_img = f"""<h5>{speak}</h5><img style='max-width: 100%; max-height: 70%;' src="data:image/png;base64,{data}" />"""
|
||||
return html_img
|
||||
html_img = f"""<h5>{speak.replace("`", '"')}</h5><img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />"""
|
||||
|
||||
return html_img
|
||||
|
@ -9,7 +9,7 @@ CFG = Config()
|
||||
|
||||
PROMPT_SCENE_DEFINE = "You are a data analysis expert. "
|
||||
|
||||
_DEFAULT_TEMPLATE = """
|
||||
_DEFAULT_TEMPLATE_EN = """
|
||||
Please use the data structure information of the above historical dialogue, make sure not to use column names that are not in the data structure.
|
||||
According to the user goal: {user_input},give the correct duckdb SQL for data analysis.
|
||||
Use the table name: {table_name}
|
||||
@ -24,12 +24,29 @@ Ensure the response is correct json and can be parsed by Python json.loads
|
||||
|
||||
"""
|
||||
|
||||
_DEFAULT_TEMPLATE_ZH = """
|
||||
请使用上述历史对话中的数据结构和列信息,根据用户目标:{user_input},给出正确的duckdb SQL进行数据分析和问题回答。
|
||||
请确保不要使用不在数据结构中的列名。
|
||||
SQL中需要使用的表名是: {table_name}
|
||||
|
||||
根据用户目标得到的分析SQL,请从以下显示类型中选择最合适的一种用来展示结果数据,如果无法确定,则使用'Text'作为显示。
|
||||
显示类型如下:
|
||||
{disply_type}
|
||||
|
||||
以以下 json 格式响应::
|
||||
{response}
|
||||
确保响应是正确的json,并且可以被Python的json.loads方法解析.
|
||||
"""
|
||||
|
||||
RESPONSE_FORMAT_SIMPLE = {
|
||||
"sql": "analysis SQL",
|
||||
"thoughts": "Current thinking and value of data analysis",
|
||||
"display": "display type name"
|
||||
}
|
||||
|
||||
_DEFAULT_TEMPLATE = (
|
||||
_DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH
|
||||
)
|
||||
|
||||
PROMPT_SEP = SeparatorStyle.SINGLE.value
|
||||
|
||||
|
@ -9,7 +9,7 @@ CFG = Config()
|
||||
|
||||
PROMPT_SCENE_DEFINE = "You are a data analysis expert. "
|
||||
|
||||
_DEFAULT_TEMPLATE = """
|
||||
_DEFAULT_TEMPLATE_EN = """
|
||||
This is an example data,please learn to understand the structure and content of this data:
|
||||
{data_example}
|
||||
Explain the meaning and function of each column, and give a simple and clear explanation of the technical terms.
|
||||
@ -19,12 +19,25 @@ Please return your answer in JSON format, the return format is as follows:
|
||||
{response}
|
||||
"""
|
||||
|
||||
_DEFAULT_TEMPLATE_ZH = """
|
||||
下面是一份示例数据,请学习理解该数据的结构和内容:
|
||||
{data_example}
|
||||
分析各列数据的含义和作用,并对专业术语进行简单明了的解释。
|
||||
提供一些分析方案思路,请一步一步思考。
|
||||
|
||||
请以JSON格式返回您的答案,返回格式如下:
|
||||
{response}
|
||||
"""
|
||||
|
||||
RESPONSE_FORMAT_SIMPLE = {
|
||||
"DataAnalysis": "数据内容分析总结",
|
||||
"ColumnAnalysis": [{"column name1": "字段1介绍,专业术语解释(请尽量简单明了)"}],
|
||||
"AnalysisProgram": ["1.分析方案1,图表展示方式1", "2.分析方案2,图表展示方式2"],
|
||||
}
|
||||
|
||||
_DEFAULT_TEMPLATE = (
|
||||
_DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH
|
||||
)
|
||||
|
||||
|
||||
PROMPT_SEP = SeparatorStyle.SINGLE.value
|
||||
@ -51,3 +64,4 @@ prompt = PromptTemplate(
|
||||
)
|
||||
CFG.prompt_template_registry.register(prompt, is_default=True)
|
||||
|
||||
|
||||
|
@ -13,6 +13,7 @@ def excel_colunm_format(old_name:str)->str:
|
||||
return new_column
|
||||
|
||||
def add_quotes(sql, column_names=[]):
|
||||
sql = sql.replace("`", "")
|
||||
parsed = sqlparse.parse(sql)
|
||||
for stmt in parsed:
|
||||
for token in stmt.tokens:
|
||||
@ -76,8 +77,10 @@ class ExcelReader:
|
||||
self.db.register(self.table_name, self.df)
|
||||
|
||||
def run(self, sql):
|
||||
sql = sql.replace(self.table_name, f'"{self.table_name}"')
|
||||
if f'"{self.table_name}"' not in sql:
|
||||
sql = sql.replace(self.table_name, f'"{self.table_name}"')
|
||||
sql = add_quotes(sql, self.columns_map.values())
|
||||
print(f"excute sql:{sql}")
|
||||
results = self.db.execute(sql)
|
||||
colunms = []
|
||||
for descrip in results.description:
|
||||
|
@ -32,6 +32,7 @@ from pilot.server.knowledge.api import router as knowledge_router
|
||||
from pilot.openapi.api_v1.api_v1 import router as api_v1
|
||||
from pilot.openapi.base import validation_exception_handler
|
||||
from pilot.openapi.api_v1.editor.api_editor_v1 import router as api_editor_route_v1
|
||||
from pilot.commands.disply_type.show_chart_gen import static_message_img_path
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
@ -79,10 +80,9 @@ app.include_router(api_editor_route_v1, prefix="/api")
|
||||
app.include_router(knowledge_router)
|
||||
# app.include_router(api_editor_route_v1)
|
||||
|
||||
app.mount("/images", StaticFiles(directory=static_message_img_path, html=True), name="static2")
|
||||
app.mount("/_next/static", StaticFiles(directory=static_file_path + "/_next/static"))
|
||||
app.mount("/", StaticFiles(directory=static_file_path, html=True), name="static")
|
||||
# app.mount("/chat", StaticFiles(directory=static_file_path + "/chat.html", html=True), name="chat")
|
||||
|
||||
|
||||
app.add_exception_handler(RequestValidationError, validation_exception_handler)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user