mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-13 05:55:54 +00:00
feat(editor): ChatExcel
🔥ChatExcel Mode Complete
This commit is contained in:
parent
a06e9b29ad
commit
d47c27f7dd
@ -3,99 +3,152 @@ from pandas import DataFrame
|
|||||||
from pilot.commands.command_mange import command
|
from pilot.commands.command_mange import command
|
||||||
from pilot.configs.config import Config
|
from pilot.configs.config import Config
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import base64
|
import uuid
|
||||||
import io
|
import io
|
||||||
|
import os
|
||||||
import matplotlib
|
import matplotlib
|
||||||
import seaborn as sns
|
import seaborn as sns
|
||||||
matplotlib.use("Agg")
|
|
||||||
|
# matplotlib.use("Agg")
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
from matplotlib.font_manager import FontManager
|
||||||
|
|
||||||
from pilot.configs.model_config import LOGDIR
|
from pilot.configs.model_config import LOGDIR
|
||||||
from pilot.utils import build_logger
|
from pilot.utils import build_logger
|
||||||
|
|
||||||
CFG = Config()
|
CFG = Config()
|
||||||
logger = build_logger("show_chart_gen", LOGDIR + "show_chart_gen.log")
|
logger = build_logger("show_chart_gen", LOGDIR + "show_chart_gen.log")
|
||||||
|
static_message_img_path = os.path.join(os.getcwd(), "message/img")
|
||||||
|
|
||||||
|
def zh_font_set():
|
||||||
|
font_names = ['Heiti TC', 'Songti SC', 'STHeiti Light', 'Microsoft YaHei', 'SimSun', 'SimHei', 'KaiTi']
|
||||||
|
fm = FontManager()
|
||||||
|
mat_fonts = set(f.name for f in fm.ttflist)
|
||||||
|
can_use_fonts = []
|
||||||
|
for font_name in font_names:
|
||||||
|
if font_name in mat_fonts:
|
||||||
|
can_use_fonts.append(font_name)
|
||||||
|
if len(can_use_fonts) > 0:
|
||||||
|
plt.rcParams['font.sans-serif'] = can_use_fonts
|
||||||
|
|
||||||
|
|
||||||
@command("response_line_chart", "Line chart display, used to display comparative trend analysis data", '"speak": "<speak>", "df":"<data frame>"')
|
@command("response_line_chart", "Line chart display, used to display comparative trend analysis data",
|
||||||
def response_line_chart(speak: str, df: DataFrame) -> str:
|
'"speak": "<speak>", "df":"<data frame>"')
|
||||||
|
def response_line_chart(speak: str, df: DataFrame) -> str:
|
||||||
logger.info(f"response_line_chart:{speak},")
|
logger.info(f"response_line_chart:{speak},")
|
||||||
|
|
||||||
columns = df.columns.tolist()
|
columns = df.columns.tolist()
|
||||||
|
|
||||||
if df.size <= 0:
|
if df.size <= 0:
|
||||||
raise ValueError("No Data!")
|
raise ValueError("No Data!")
|
||||||
plt.rcParams["font.family"] = ["sans-serif"]
|
|
||||||
rc = {"font.sans-serif": "SimHei", "axes.unicode_minus": False}
|
# set font
|
||||||
sns.set_style(rc={'font.sans-serif': "Microsoft Yahei"})
|
# zh_font_set()
|
||||||
sns.set(context="notebook", style="ticks", color_codes=True, rc=rc)
|
font_names = ['Heiti TC', 'Songti SC', 'STHeiti Light', 'Microsoft YaHei', 'SimSun', 'SimHei', 'KaiTi']
|
||||||
|
fm = FontManager()
|
||||||
|
mat_fonts = set(f.name for f in fm.ttflist)
|
||||||
|
can_use_fonts = []
|
||||||
|
for font_name in font_names:
|
||||||
|
if font_name in mat_fonts:
|
||||||
|
can_use_fonts.append(font_name)
|
||||||
|
if len(can_use_fonts) > 0:
|
||||||
|
plt.rcParams['font.sans-serif'] = can_use_fonts
|
||||||
|
|
||||||
|
rc = {'font.sans-serif': can_use_fonts}
|
||||||
|
plt.rcParams['axes.unicode_minus'] = False # 解决无法显示符号的问题
|
||||||
|
|
||||||
|
sns.set(font=can_use_fonts[0], font_scale=0.8) # 解决Seaborn中文显示问题
|
||||||
|
sns.set_palette("Set3") # 设置颜色主题
|
||||||
|
sns.set_style("dark")
|
||||||
|
sns.color_palette("hls", 10)
|
||||||
|
sns.hls_palette(8, l=.5, s=.7)
|
||||||
|
sns.set(context='notebook', style='ticks', rc=rc)
|
||||||
|
|
||||||
fig, ax = plt.subplots(figsize=(8, 5), dpi=100)
|
fig, ax = plt.subplots(figsize=(8, 5), dpi=100)
|
||||||
sns.lineplot(df, x=columns[0], y=columns[1], ax=ax)
|
sns.lineplot(df, x=columns[0], y=columns[1], ax=ax)
|
||||||
plt.title("")
|
|
||||||
|
|
||||||
buf = io.BytesIO()
|
chart_name = "line_" + str(uuid.uuid1()) + ".png"
|
||||||
plt.savefig(buf, format="png", dpi=100)
|
chart_path = static_message_img_path + "/" + chart_name
|
||||||
buf.seek(0)
|
plt.savefig(chart_path, bbox_inches='tight', dpi=100)
|
||||||
data = base64.b64encode(buf.getvalue()).decode("ascii")
|
|
||||||
|
|
||||||
html_img = f"""<h5>{speak}</h5><img style='max-width: 100%; max-height: 70%;' src="data:image/png;base64,{data}" />"""
|
html_img = f"""<h5>{speak}</h5><img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />"""
|
||||||
return html_img
|
return html_img
|
||||||
|
|
||||||
|
|
||||||
|
@command("response_bar_chart", "Histogram, suitable for comparative analysis of multiple target values",
|
||||||
@command("response_bar_chart", "Histogram, suitable for comparative analysis of multiple target values", '"speak": "<speak>", "df":"<data frame>"')
|
'"speak": "<speak>", "df":"<data frame>"')
|
||||||
def response_bar_chart(speak: str, df: DataFrame) -> str:
|
def response_bar_chart(speak: str, df: DataFrame) -> str:
|
||||||
logger.info(f"response_bar_chart:{speak},")
|
logger.info(f"response_bar_chart:{speak},")
|
||||||
columns = df.columns.tolist()
|
columns = df.columns.tolist()
|
||||||
if df.size <= 0:
|
if df.size <= 0:
|
||||||
raise ValueError("No Data!")
|
raise ValueError("No Data!")
|
||||||
plt.rcParams["font.family"] = ["sans-serif"]
|
|
||||||
rc = {'font.sans-serif': "Microsoft Yahei"}
|
# set font
|
||||||
sns.set(context="notebook", color_codes=True, rc=rc)
|
# zh_font_set()
|
||||||
|
font_names = ['Heiti TC', 'Songti SC', 'STHeiti Light', 'Microsoft YaHei', 'SimSun', 'SimHei', 'KaiTi']
|
||||||
|
fm = FontManager()
|
||||||
|
mat_fonts = set(f.name for f in fm.ttflist)
|
||||||
|
can_use_fonts = []
|
||||||
|
for font_name in font_names:
|
||||||
|
if font_name in mat_fonts:
|
||||||
|
can_use_fonts.append(font_name)
|
||||||
|
if len(can_use_fonts) > 0:
|
||||||
|
plt.rcParams['font.sans-serif'] = can_use_fonts
|
||||||
|
|
||||||
|
rc = {'font.sans-serif': can_use_fonts}
|
||||||
|
plt.rcParams['axes.unicode_minus'] = False # 解决无法显示符号的问题
|
||||||
|
sns.set(font=can_use_fonts[0], font_scale=0.8) # 解决Seaborn中文显示问题
|
||||||
|
sns.set_palette("Set3") # 设置颜色主题
|
||||||
sns.set_style("dark")
|
sns.set_style("dark")
|
||||||
sns.color_palette("hls", 10)
|
sns.color_palette("hls", 10)
|
||||||
sns.hls_palette(8, l=.5, s=.7)
|
sns.hls_palette(8, l=.5, s=.7)
|
||||||
|
sns.set(context='notebook', style='ticks', rc=rc)
|
||||||
|
|
||||||
fig, ax = plt.subplots(figsize=(8, 5), dpi=100)
|
fig, ax = plt.subplots(figsize=(8, 5), dpi=100)
|
||||||
plt.ticklabel_format(style='plain')
|
|
||||||
sns.barplot(df, x=df[columns[0]], y=df[columns[1]], ax=ax)
|
sns.barplot(df, x=df[columns[0]], y=df[columns[1]], ax=ax)
|
||||||
|
|
||||||
plt.title("")
|
chart_name = "pie_" + str(uuid.uuid1()) + ".png"
|
||||||
|
chart_path = static_message_img_path + "/" + chart_name
|
||||||
buf = io.BytesIO()
|
plt.savefig(chart_path, bbox_inches='tight', dpi=100)
|
||||||
plt.savefig(buf, format="png", dpi=100)
|
html_img = f"""<h5>{speak}</h5><img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />"""
|
||||||
buf.seek(0)
|
|
||||||
data = base64.b64encode(buf.getvalue()).decode("ascii")
|
|
||||||
|
|
||||||
html_img = f"""<h5>{speak}</h5><img style='max-width: 100%; max-height: 70%;' src="data:image/png;base64,{data}" />"""
|
|
||||||
return html_img
|
return html_img
|
||||||
|
|
||||||
|
|
||||||
|
@command("response_pie_chart", "Pie chart, suitable for scenarios such as proportion and distribution statistics",
|
||||||
@command("response_pie_chart", "Pie chart, suitable for scenarios such as proportion and distribution statistics", '"speak": "<speak>", "df":"<data frame>"')
|
'"speak": "<speak>", "df":"<data frame>"')
|
||||||
def response_pie_chart(speak: str, df: DataFrame) -> str:
|
def response_pie_chart(speak: str, df: DataFrame) -> str:
|
||||||
logger.info(f"response_pie_chart:{speak},")
|
logger.info(f"response_pie_chart:{speak},")
|
||||||
columns = df.columns.tolist()
|
columns = df.columns.tolist()
|
||||||
if df.size <= 0:
|
if df.size <= 0:
|
||||||
raise ValueError("No Data!")
|
raise ValueError("No Data!")
|
||||||
plt.rcParams["font.family"] = ["sans-serif"]
|
# set font
|
||||||
rc = {"font.sans-serif": "SimHei", "axes.unicode_minus": False}
|
# zh_font_set()
|
||||||
sns.set_style(rc={'font.sans-serif': "Microsoft Yahei"})
|
font_names = ['Heiti TC', 'Songti SC', 'STHeiti Light', 'Microsoft YaHei', 'SimSun', 'SimHei', 'KaiTi']
|
||||||
sns.set(context="notebook", style="ticks", color_codes=True, rc=rc)
|
fm = FontManager()
|
||||||
|
mat_fonts = set(f.name for f in fm.ttflist)
|
||||||
|
can_use_fonts = []
|
||||||
|
for font_name in font_names:
|
||||||
|
if font_name in mat_fonts:
|
||||||
|
can_use_fonts.append(font_name)
|
||||||
|
if len(can_use_fonts) > 0:
|
||||||
|
plt.rcParams['font.sans-serif'] = can_use_fonts
|
||||||
|
plt.rcParams['axes.unicode_minus'] = False # 解决无法显示符号的问题
|
||||||
|
|
||||||
sns.set_palette("Set3") # 设置颜色主题
|
sns.set_palette("Set3") # 设置颜色主题
|
||||||
|
|
||||||
# fig, ax = plt.pie(df[columns[1]], labels=df[columns[0]], autopct='%1.1f%%', startangle=90)
|
# fig, ax = plt.pie(df[columns[1]], labels=df[columns[0]], autopct='%1.1f%%', startangle=90)
|
||||||
fig, ax = plt.subplots(figsize=(8, 5), dpi=100)
|
fig, ax = plt.subplots(figsize=(8, 5), dpi=100)
|
||||||
ax = df.plot(kind='pie', y=columns[1], ax=ax, labels=df[columns[0]].values, startangle=90, autopct='%1.1f%%')
|
ax = df.plot(kind='pie', y=columns[1], ax=ax, labels=df[columns[0]].values, startangle=90, autopct='%1.1f%%')
|
||||||
# 手动设置 labels 的位置和大小
|
# 手动设置 labels 的位置和大小
|
||||||
ax.legend(loc='upper right', bbox_to_anchor=(1, 1, 1, 1), labels=df[columns[0]].values, fontsize=10)
|
ax.legend(loc='upper right', bbox_to_anchor=(0, 0, 1, 1), labels=df[columns[0]].values, fontsize=10)
|
||||||
|
|
||||||
plt.axis('equal') # 使饼图为正圆形
|
plt.axis('equal') # 使饼图为正圆形
|
||||||
# plt.title(columns[0])
|
# plt.title(columns[0])
|
||||||
|
|
||||||
buf = io.BytesIO()
|
chart_name = "pie_" + str(uuid.uuid1()) + ".png"
|
||||||
plt.savefig(buf, format="png", dpi=100)
|
chart_path = static_message_img_path + "/" + chart_name
|
||||||
buf.seek(0)
|
plt.savefig(chart_path, bbox_inches='tight', dpi=100)
|
||||||
data = base64.b64encode(buf.getvalue()).decode("ascii")
|
|
||||||
|
|
||||||
html_img = f"""<h5>{speak}</h5><img style='max-width: 100%; max-height: 70%;' src="data:image/png;base64,{data}" />"""
|
html_img = f"""<h5>{speak.replace("`", '"')}</h5><img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />"""
|
||||||
return html_img
|
|
||||||
|
return html_img
|
||||||
|
@ -9,7 +9,7 @@ CFG = Config()
|
|||||||
|
|
||||||
PROMPT_SCENE_DEFINE = "You are a data analysis expert. "
|
PROMPT_SCENE_DEFINE = "You are a data analysis expert. "
|
||||||
|
|
||||||
_DEFAULT_TEMPLATE = """
|
_DEFAULT_TEMPLATE_EN = """
|
||||||
Please use the data structure information of the above historical dialogue, make sure not to use column names that are not in the data structure.
|
Please use the data structure information of the above historical dialogue, make sure not to use column names that are not in the data structure.
|
||||||
According to the user goal: {user_input},give the correct duckdb SQL for data analysis.
|
According to the user goal: {user_input},give the correct duckdb SQL for data analysis.
|
||||||
Use the table name: {table_name}
|
Use the table name: {table_name}
|
||||||
@ -24,12 +24,29 @@ Ensure the response is correct json and can be parsed by Python json.loads
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
_DEFAULT_TEMPLATE_ZH = """
|
||||||
|
请使用上述历史对话中的数据结构和列信息,根据用户目标:{user_input},给出正确的duckdb SQL进行数据分析和问题回答。
|
||||||
|
请确保不要使用不在数据结构中的列名。
|
||||||
|
SQL中需要使用的表名是: {table_name}
|
||||||
|
|
||||||
|
根据用户目标得到的分析SQL,请从以下显示类型中选择最合适的一种用来展示结果数据,如果无法确定,则使用'Text'作为显示。
|
||||||
|
显示类型如下:
|
||||||
|
{disply_type}
|
||||||
|
|
||||||
|
以以下 json 格式响应::
|
||||||
|
{response}
|
||||||
|
确保响应是正确的json,并且可以被Python的json.loads方法解析.
|
||||||
|
"""
|
||||||
|
|
||||||
RESPONSE_FORMAT_SIMPLE = {
|
RESPONSE_FORMAT_SIMPLE = {
|
||||||
"sql": "analysis SQL",
|
"sql": "analysis SQL",
|
||||||
"thoughts": "Current thinking and value of data analysis",
|
"thoughts": "Current thinking and value of data analysis",
|
||||||
"display": "display type name"
|
"display": "display type name"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
_DEFAULT_TEMPLATE = (
|
||||||
|
_DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH
|
||||||
|
)
|
||||||
|
|
||||||
PROMPT_SEP = SeparatorStyle.SINGLE.value
|
PROMPT_SEP = SeparatorStyle.SINGLE.value
|
||||||
|
|
||||||
|
@ -9,7 +9,7 @@ CFG = Config()
|
|||||||
|
|
||||||
PROMPT_SCENE_DEFINE = "You are a data analysis expert. "
|
PROMPT_SCENE_DEFINE = "You are a data analysis expert. "
|
||||||
|
|
||||||
_DEFAULT_TEMPLATE = """
|
_DEFAULT_TEMPLATE_EN = """
|
||||||
This is an example data,please learn to understand the structure and content of this data:
|
This is an example data,please learn to understand the structure and content of this data:
|
||||||
{data_example}
|
{data_example}
|
||||||
Explain the meaning and function of each column, and give a simple and clear explanation of the technical terms.
|
Explain the meaning and function of each column, and give a simple and clear explanation of the technical terms.
|
||||||
@ -19,12 +19,25 @@ Please return your answer in JSON format, the return format is as follows:
|
|||||||
{response}
|
{response}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
_DEFAULT_TEMPLATE_ZH = """
|
||||||
|
下面是一份示例数据,请学习理解该数据的结构和内容:
|
||||||
|
{data_example}
|
||||||
|
分析各列数据的含义和作用,并对专业术语进行简单明了的解释。
|
||||||
|
提供一些分析方案思路,请一步一步思考。
|
||||||
|
|
||||||
|
请以JSON格式返回您的答案,返回格式如下:
|
||||||
|
{response}
|
||||||
|
"""
|
||||||
|
|
||||||
RESPONSE_FORMAT_SIMPLE = {
|
RESPONSE_FORMAT_SIMPLE = {
|
||||||
"DataAnalysis": "数据内容分析总结",
|
"DataAnalysis": "数据内容分析总结",
|
||||||
"ColumnAnalysis": [{"column name1": "字段1介绍,专业术语解释(请尽量简单明了)"}],
|
"ColumnAnalysis": [{"column name1": "字段1介绍,专业术语解释(请尽量简单明了)"}],
|
||||||
"AnalysisProgram": ["1.分析方案1,图表展示方式1", "2.分析方案2,图表展示方式2"],
|
"AnalysisProgram": ["1.分析方案1,图表展示方式1", "2.分析方案2,图表展示方式2"],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
_DEFAULT_TEMPLATE = (
|
||||||
|
_DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
PROMPT_SEP = SeparatorStyle.SINGLE.value
|
PROMPT_SEP = SeparatorStyle.SINGLE.value
|
||||||
@ -51,3 +64,4 @@ prompt = PromptTemplate(
|
|||||||
)
|
)
|
||||||
CFG.prompt_template_registry.register(prompt, is_default=True)
|
CFG.prompt_template_registry.register(prompt, is_default=True)
|
||||||
|
|
||||||
|
|
||||||
|
@ -13,6 +13,7 @@ def excel_colunm_format(old_name:str)->str:
|
|||||||
return new_column
|
return new_column
|
||||||
|
|
||||||
def add_quotes(sql, column_names=[]):
|
def add_quotes(sql, column_names=[]):
|
||||||
|
sql = sql.replace("`", "")
|
||||||
parsed = sqlparse.parse(sql)
|
parsed = sqlparse.parse(sql)
|
||||||
for stmt in parsed:
|
for stmt in parsed:
|
||||||
for token in stmt.tokens:
|
for token in stmt.tokens:
|
||||||
@ -76,8 +77,10 @@ class ExcelReader:
|
|||||||
self.db.register(self.table_name, self.df)
|
self.db.register(self.table_name, self.df)
|
||||||
|
|
||||||
def run(self, sql):
|
def run(self, sql):
|
||||||
sql = sql.replace(self.table_name, f'"{self.table_name}"')
|
if f'"{self.table_name}"' not in sql:
|
||||||
|
sql = sql.replace(self.table_name, f'"{self.table_name}"')
|
||||||
sql = add_quotes(sql, self.columns_map.values())
|
sql = add_quotes(sql, self.columns_map.values())
|
||||||
|
print(f"excute sql:{sql}")
|
||||||
results = self.db.execute(sql)
|
results = self.db.execute(sql)
|
||||||
colunms = []
|
colunms = []
|
||||||
for descrip in results.description:
|
for descrip in results.description:
|
||||||
|
@ -32,6 +32,7 @@ from pilot.server.knowledge.api import router as knowledge_router
|
|||||||
from pilot.openapi.api_v1.api_v1 import router as api_v1
|
from pilot.openapi.api_v1.api_v1 import router as api_v1
|
||||||
from pilot.openapi.base import validation_exception_handler
|
from pilot.openapi.base import validation_exception_handler
|
||||||
from pilot.openapi.api_v1.editor.api_editor_v1 import router as api_editor_route_v1
|
from pilot.openapi.api_v1.editor.api_editor_v1 import router as api_editor_route_v1
|
||||||
|
from pilot.commands.disply_type.show_chart_gen import static_message_img_path
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
@ -79,10 +80,9 @@ app.include_router(api_editor_route_v1, prefix="/api")
|
|||||||
app.include_router(knowledge_router)
|
app.include_router(knowledge_router)
|
||||||
# app.include_router(api_editor_route_v1)
|
# app.include_router(api_editor_route_v1)
|
||||||
|
|
||||||
|
app.mount("/images", StaticFiles(directory=static_message_img_path, html=True), name="static2")
|
||||||
app.mount("/_next/static", StaticFiles(directory=static_file_path + "/_next/static"))
|
app.mount("/_next/static", StaticFiles(directory=static_file_path + "/_next/static"))
|
||||||
app.mount("/", StaticFiles(directory=static_file_path, html=True), name="static")
|
app.mount("/", StaticFiles(directory=static_file_path, html=True), name="static")
|
||||||
# app.mount("/chat", StaticFiles(directory=static_file_path + "/chat.html", html=True), name="chat")
|
|
||||||
|
|
||||||
|
|
||||||
app.add_exception_handler(RequestValidationError, validation_exception_handler)
|
app.add_exception_handler(RequestValidationError, validation_exception_handler)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user