mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-30 15:21:02 +00:00
bugfix(ChatExcel): ChatExcel Language confusion bug
1.Fix ChatExcel Language confusion bug
This commit is contained in:
parent
7359616f8c
commit
61f39a18dc
@ -26,6 +26,10 @@ class BaseOutputParser(ABC):
|
||||
def __init__(self, sep: str, is_stream_out: bool = True):
|
||||
self.sep = sep
|
||||
self.is_stream_out = is_stream_out
|
||||
self.data_schema = None
|
||||
|
||||
def update(self, data_schema):
|
||||
self.data_schema = data_schema
|
||||
|
||||
def __post_process_code(self, code):
|
||||
sep = "\n```"
|
||||
|
@ -44,11 +44,13 @@ class ExcelLearning(BaseChat):
|
||||
if parent_mode:
|
||||
self.current_message.chat_mode = parent_mode.value()
|
||||
|
||||
|
||||
async def generate_input_values(self) -> Dict:
|
||||
# colunms, datas = self.excel_reader.get_sample_data()
|
||||
colunms, datas = await blocking_func_to_async(
|
||||
self._executor, self.excel_reader.get_sample_data
|
||||
)
|
||||
self.prompt_template.output_parser.update(colunms)
|
||||
copy_datas = datas.copy()
|
||||
datas.insert(0, colunms)
|
||||
|
||||
|
@ -19,11 +19,13 @@ logger = logging.getLogger(__name__)
|
||||
class LearningExcelOutputParser(BaseOutputParser):
|
||||
def __init__(self, sep: str, is_stream_out: bool):
|
||||
super().__init__(sep=sep, is_stream_out=is_stream_out)
|
||||
self.is_downgraded = False
|
||||
|
||||
|
||||
def parse_prompt_response(self, model_out_text):
|
||||
clean_str = super().parse_prompt_response(model_out_text)
|
||||
print("clean prompt response:", clean_str)
|
||||
try:
|
||||
clean_str = super().parse_prompt_response(model_out_text)
|
||||
logger.info(f"parse_prompt_response:{model_out_text},{model_out_text}")
|
||||
response = json.loads(clean_str)
|
||||
for key in sorted(response):
|
||||
if key.strip() == "DataAnalysis":
|
||||
@ -34,27 +36,39 @@ class LearningExcelOutputParser(BaseOutputParser):
|
||||
plans = response[key]
|
||||
return ExcelResponse(desciption=desciption, clounms=clounms, plans=plans)
|
||||
except Exception as e:
|
||||
return model_out_text
|
||||
logger.error(f"parse_prompt_response Faild!{str(e)}")
|
||||
self.is_downgraded = True
|
||||
return ExcelResponse(desciption=model_out_text, clounms=self.data_schema, plans=None)
|
||||
|
||||
|
||||
def parse_view_response(self, speak, data) -> str:
|
||||
if data and not isinstance(data, str):
|
||||
### tool out data to table view
|
||||
html_title = f"### **Data Summary**\n{data.desciption} "
|
||||
html_colunms = f"### **Data Structure**\n"
|
||||
column_index = 0
|
||||
for item in data.clounms:
|
||||
column_index += 1
|
||||
keys = item.keys()
|
||||
for key in keys:
|
||||
if self.is_downgraded:
|
||||
column_index = 0
|
||||
for item in data.clounms:
|
||||
column_index += 1
|
||||
html_colunms = (
|
||||
html_colunms + f"- **{column_index}.[{key}]** _{item[key]}_\n"
|
||||
html_colunms + f"- **{column_index}.[{item}]** _未知_\n"
|
||||
)
|
||||
else:
|
||||
column_index = 0
|
||||
for item in data.clounms:
|
||||
column_index += 1
|
||||
keys = item.keys()
|
||||
for key in keys:
|
||||
html_colunms = (
|
||||
html_colunms + f"- **{column_index}.[{key}]** _{item[key]}_\n"
|
||||
)
|
||||
|
||||
html_plans = f"### **Recommended analysis plan**\n"
|
||||
index = 0
|
||||
for item in data.plans:
|
||||
index += 1
|
||||
html_plans = html_plans + f"{item} \n"
|
||||
if data.plans:
|
||||
for item in data.plans:
|
||||
index += 1
|
||||
html_plans = html_plans + f"{item} \n"
|
||||
html = f"""{html_title}\n{html_colunms}\n{html_plans}"""
|
||||
return html
|
||||
else:
|
||||
|
@ -1,271 +0,0 @@
|
||||
import os
|
||||
import duckdb
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import matplotlib
|
||||
import seaborn as sns
|
||||
import uuid
|
||||
|
||||
from pandas import DataFrame
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.ticker as mtick
|
||||
from matplotlib import font_manager
|
||||
from matplotlib.font_manager import FontManager
|
||||
|
||||
matplotlib.use("Agg")
|
||||
import time
|
||||
from fsspec import filesystem
|
||||
import spatial
|
||||
|
||||
from pilot.scene.chat_data.chat_excel.excel_reader import ExcelReader
|
||||
|
||||
|
||||
def data_pre_classification(df: DataFrame):
|
||||
## Data pre-classification
|
||||
columns = df.columns.tolist()
|
||||
|
||||
number_columns = []
|
||||
non_numeric_colums = []
|
||||
|
||||
# 收集数据分类小于10个的列
|
||||
non_numeric_colums_value_map = {}
|
||||
numeric_colums_value_map = {}
|
||||
df_filtered = df.dropna()
|
||||
for column_name in columns:
|
||||
print(np.issubdtype(df_filtered[column_name].dtype, np.number))
|
||||
# if pd.to_numeric(df[column_name], errors='coerce').notna().all():
|
||||
# if np.issubdtype(df_filtered[column_name].dtype, np.number):
|
||||
if pd.api.types.is_numeric_dtype(df[column_name].dtypes):
|
||||
number_columns.append(column_name)
|
||||
unique_values = df[column_name].unique()
|
||||
numeric_colums_value_map.update({column_name: len(unique_values)})
|
||||
else:
|
||||
non_numeric_colums.append(column_name)
|
||||
unique_values = df[column_name].unique()
|
||||
non_numeric_colums_value_map.update({column_name: len(unique_values)})
|
||||
|
||||
sorted_numeric_colums_value_map = dict(
|
||||
sorted(numeric_colums_value_map.items(), key=lambda x: x[1])
|
||||
)
|
||||
numeric_colums_sort_list = list(sorted_numeric_colums_value_map.keys())
|
||||
|
||||
sorted_colums_value_map = dict(
|
||||
sorted(non_numeric_colums_value_map.items(), key=lambda x: x[1])
|
||||
)
|
||||
non_numeric_colums_sort_list = list(sorted_colums_value_map.keys())
|
||||
|
||||
# Analyze x-coordinate
|
||||
if len(non_numeric_colums_sort_list) > 0:
|
||||
x_cloumn = non_numeric_colums_sort_list[-1]
|
||||
non_numeric_colums_sort_list.remove(x_cloumn)
|
||||
else:
|
||||
x_cloumn = number_columns[0]
|
||||
numeric_colums_sort_list.remove(x_cloumn)
|
||||
|
||||
# Analyze y-coordinate
|
||||
if len(numeric_colums_sort_list) > 0:
|
||||
y_column = numeric_colums_sort_list[0]
|
||||
numeric_colums_sort_list.remove(y_column)
|
||||
else:
|
||||
raise ValueError("Not enough numeric columns for chart!")
|
||||
|
||||
return x_cloumn, y_column, non_numeric_colums_sort_list, numeric_colums_sort_list
|
||||
|
||||
#
|
||||
# if len(non_numeric_colums) <=0:
|
||||
# sorted_colums_value_map = dict(sorted(numeric_colums_value_map.items(), key=lambda x: x[1]))
|
||||
# numeric_colums_sort_list = list(sorted_colums_value_map.keys())
|
||||
# x_column = number_columns[0]
|
||||
# hue_column = numeric_colums_sort_list[0]
|
||||
# y_column = numeric_colums_sort_list[1]
|
||||
# cols = numeric_colums_sort_list[2:]
|
||||
# elif len(number_columns) <=0:
|
||||
# raise ValueError("Have No numeric Column!")
|
||||
# else:
|
||||
# # 数字和非数字都存在多列,放弃部分数字列
|
||||
# x_column = non_numeric_colums[0]
|
||||
# y_column = number_columns[0]
|
||||
# if len(non_numeric_colums) > 1:
|
||||
# sorted_colums_value_map = dict(sorted(non_numeric_colums_value_map.items(), key=lambda x: x[1]))
|
||||
# non_numeric_colums_sort_list = list(sorted_colums_value_map.keys())
|
||||
# non_numeric_colums_sort_list.remove(non_numeric_colums[0])
|
||||
# hue_column = non_numeric_colums_sort_list[0]
|
||||
# if len(number_columns) > 1:
|
||||
# # try multiple charts
|
||||
# cols = number_columns.remove( number_columns[0])
|
||||
#
|
||||
# else:
|
||||
# sorted_colums_value_map = dict(sorted(numeric_colums_value_map.items(), key=lambda x: x[1]))
|
||||
# numeric_colums_sort_list = list(sorted_colums_value_map.keys())
|
||||
# numeric_colums_sort_list.remove(number_columns[0])
|
||||
# if sorted_colums_value_map[numeric_colums_sort_list[0]].value < 5:
|
||||
# hue_column = numeric_colums_sort_list[0]
|
||||
# if len(number_columns) > 2:
|
||||
# # try multiple charts
|
||||
# cols = numeric_colums_sort_list.remove(numeric_colums_sort_list[0])
|
||||
#
|
||||
# print(x_column, y_column, hue_column, cols)
|
||||
# return x_column, y_column, hue_column
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# connect = duckdb.connect("/Users/tuyang.yhj/Downloads/example.xlsx")
|
||||
#
|
||||
|
||||
# fonts = fm.findSystemFonts()
|
||||
# for font in fonts:
|
||||
# if 'Hei' in font:
|
||||
# print(font)
|
||||
|
||||
# fm = FontManager()
|
||||
# mat_fonts = set(f.name for f in fm.ttflist)
|
||||
# for i in mat_fonts:
|
||||
# print(i)
|
||||
# print(len(mat_fonts))
|
||||
# 获取系统中的默认中文字体名称
|
||||
# default_font = fm.fontManager.defaultFontProperties.get_family()
|
||||
|
||||
# 创建一个示例 DataFrame
|
||||
df = pd.DataFrame(
|
||||
{
|
||||
"A": [1, 2, 3, None, 5],
|
||||
"B": [10, 20, 30, 40, 50],
|
||||
"C": [1.1, 2.2, None, 4.4, 5.5],
|
||||
"D": ["a", "b", "c", "d", "e"],
|
||||
}
|
||||
)
|
||||
|
||||
# 判断列是否为数字列
|
||||
column_name = "A" # 要判断的列名
|
||||
is_numeric = pd.to_numeric(df[column_name], errors="coerce").notna().all()
|
||||
|
||||
if is_numeric:
|
||||
print(
|
||||
f"Column '{column_name}' is a numeric column (ignoring null and NaN values in some elements)."
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f"Column '{column_name}' is not a numeric column (ignoring null and NaN values in some elements)."
|
||||
)
|
||||
|
||||
#
|
||||
# excel_reader = ExcelReader("/Users/tuyang.yhj/Downloads/example.xlsx")
|
||||
excel_reader = ExcelReader("/Users/tuyang.yhj/Downloads/yhj-zx.csv")
|
||||
#
|
||||
# # colunms, datas = excel_reader.run( "SELECT CONCAT(Year, '-', Quarter) AS QuarterYear, SUM(Sales) AS TotalSales FROM example GROUP BY QuarterYear ORDER BY QuarterYear")
|
||||
# # colunms, datas = excel_reader.run( """ SELECT Year, SUM(Sales) AS Total_Sales FROM example GROUP BY Year ORDER BY Year; """)
|
||||
# df = excel_reader.get_df_by_sql_ex(""" SELECT Segment, Country, SUM(Sales) AS Total_Sales, SUM(Profit) AS Total_Profit FROM example GROUP BY Segment, Country """)
|
||||
df = excel_reader.get_df_by_sql_ex(
|
||||
""" SELECT 大项, AVG(实际) AS 平均实际支出, AVG(已支出) AS 平均已支出 FROM yhj-zx GROUP BY 大项"""
|
||||
)
|
||||
|
||||
for column_name in df.columns.tolist():
|
||||
print(column_name + ":" + str(df[column_name].dtypes))
|
||||
print(
|
||||
column_name
|
||||
+ ":"
|
||||
+ str(pd.api.types.is_numeric_dtype(df[column_name].dtypes))
|
||||
)
|
||||
|
||||
columns = df.columns.tolist()
|
||||
font_names = [
|
||||
"Heiti TC",
|
||||
"Songti SC",
|
||||
"STHeiti Light",
|
||||
"Microsoft YaHei",
|
||||
"SimSun",
|
||||
"SimHei",
|
||||
"KaiTi",
|
||||
]
|
||||
fm = FontManager()
|
||||
mat_fonts = set(f.name for f in fm.ttflist)
|
||||
can_use_fonts = []
|
||||
for font_name in font_names:
|
||||
if font_name in mat_fonts:
|
||||
can_use_fonts.append(font_name)
|
||||
if len(can_use_fonts) > 0:
|
||||
plt.rcParams["font.sans-serif"] = can_use_fonts
|
||||
|
||||
rc = {"font.sans-serif": can_use_fonts}
|
||||
plt.rcParams["axes.unicode_minus"] = False # 解决无法显示符号的问题
|
||||
sns.set(font="Heiti TC", font_scale=0.8) # 解决Seaborn中文显示问题
|
||||
sns.set_palette("Set3") # 设置颜色主题
|
||||
sns.set_style("dark")
|
||||
sns.color_palette("hls", 10)
|
||||
sns.hls_palette(8, l=0.5, s=0.7)
|
||||
sns.set(context="notebook", style="ticks", rc=rc)
|
||||
|
||||
fig, ax = plt.subplots(figsize=(8, 5), dpi=100)
|
||||
# plt.ticklabel_format(style='plain')
|
||||
# ax = df.plot(kind='bar', ax=ax)
|
||||
# sns.barplot(df, x=x, y="Total_Sales", hue='Country', ax=ax)
|
||||
# sns.barplot(df, x=x, y="Total_Profit", hue='Country', ax=ax)
|
||||
|
||||
# sns.catplot(data=df, x=x, y=y, hue='Country', kind='bar')
|
||||
# x, y, non_num_columns, num_colmns = data_pre_classification(df)
|
||||
# print(x, y, str(non_num_columns), str(num_colmns))
|
||||
## 复杂折线图实现
|
||||
# if len(num_colmns) > 0:
|
||||
# num_colmns.append(y)
|
||||
# df_melted = pd.melt(
|
||||
# df, id_vars=x, value_vars=num_colmns, var_name="line", value_name="Value"
|
||||
# )
|
||||
# sns.lineplot(data=df_melted, x=x, y="Value", hue="line", ax=ax, palette="Set2")
|
||||
# else:
|
||||
# sns.lineplot(data=df, x=x, y=y, ax=ax, palette="Set2")
|
||||
|
||||
hue = None
|
||||
## 复杂柱状图实现
|
||||
x, y, non_num_columns, num_colmns = data_pre_classification(df)
|
||||
|
||||
if len(non_num_columns) >= 1:
|
||||
hue = non_num_columns[0]
|
||||
|
||||
if len(num_colmns) >= 1:
|
||||
if hue:
|
||||
if len(num_colmns) >= 2:
|
||||
can_use_columns = num_colmns[:2]
|
||||
else:
|
||||
can_use_columns = num_colmns
|
||||
sns.barplot(data=df, x=x, y=y, hue=hue, palette="Set2", ax=ax)
|
||||
for sub_y_column in can_use_columns:
|
||||
sns.barplot(
|
||||
data=df, x=x, y=sub_y_column, hue=hue, palette="Set2", ax=ax
|
||||
)
|
||||
else:
|
||||
if len(num_colmns) > 5:
|
||||
can_use_columns = num_colmns[:5]
|
||||
else:
|
||||
can_use_columns = num_colmns
|
||||
can_use_columns.append(y)
|
||||
|
||||
df_melted = pd.melt(
|
||||
df,
|
||||
id_vars=x,
|
||||
value_vars=can_use_columns,
|
||||
var_name="line",
|
||||
value_name="Value",
|
||||
)
|
||||
sns.barplot(
|
||||
data=df_melted, x=x, y="Value", hue="line", palette="Set2", ax=ax
|
||||
)
|
||||
else:
|
||||
sns.barplot(data=df, x=x, y=y, hue=hue, palette="Set2", ax=ax)
|
||||
|
||||
# # 转换 DataFrame 格式
|
||||
# df_melted = pd.melt(df, id_vars=x, value_vars=['Total_Sales', 'Total_Profit'], var_name='line', value_name='y')
|
||||
#
|
||||
# # 绘制多列柱状图
|
||||
#
|
||||
# sns.barplot(data=df, x=x, y="Total_Sales", hue = "Country", palette="Set2", ax=ax)
|
||||
# sns.barplot(data=df, x=x, y="Total_Profit", hue = "Country", palette="Set1", ax=ax)
|
||||
|
||||
# 设置 y 轴刻度格式为普通数字格式
|
||||
ax.yaxis.set_major_formatter(mtick.FuncFormatter(lambda x, _: "{:,.0f}".format(x)))
|
||||
|
||||
chart_name = "bar_" + str(uuid.uuid1()) + ".png"
|
||||
chart_path = chart_name
|
||||
plt.savefig(chart_path, bbox_inches="tight", dpi=100)
|
||||
|
||||
#
|
Loading…
Reference in New Issue
Block a user