diff --git a/pilot/commands/disply_type/show_chart_gen.py b/pilot/commands/disply_type/show_chart_gen.py
index 27f1639d5..6a3ee3e27 100644
--- a/pilot/commands/disply_type/show_chart_gen.py
+++ b/pilot/commands/disply_type/show_chart_gen.py
@@ -4,13 +4,13 @@ from pilot.commands.command_mange import command
from pilot.configs.config import Config
import pandas as pd
import uuid
-import io
import os
import matplotlib
import seaborn as sns
matplotlib.use("Agg")
import matplotlib.pyplot as plt
+import matplotlib.ticker as mtick
from matplotlib.font_manager import FontManager
from pilot.configs.model_config import LOGDIR
@@ -21,6 +21,54 @@ logger = build_logger("show_chart_gen", LOGDIR + "show_chart_gen.log")
static_message_img_path = os.path.join(os.getcwd(), "message/img")
+def data_pre_classification(df: DataFrame):
+    """Pick chart axes from ``df`` by splitting its columns on numeric dtype.
+
+    The x axis is the non-numeric column with the most distinct values, or the
+    first numeric column when every column is numeric. The y axis is the
+    remaining numeric column with the fewest distinct values.
+
+    Args:
+        df: Data to chart; needs one numeric column besides the x column.
+
+    Returns:
+        ``(x_column, y_column, other_non_numeric, other_numeric)``; both lists
+        are ordered by ascending distinct-value count and exclude x and y.
+
+    Raises:
+        ValueError: If no numeric column is left for the y axis.
+    """
+    # Distinct-value count per column, keyed by name, split by numeric dtype.
+    numeric_counts = {}
+    non_numeric_counts = {}
+    for column_name in df.columns.tolist():
+        distinct = len(df[column_name].unique())
+        if pd.api.types.is_numeric_dtype(df[column_name].dtypes):
+            numeric_counts[column_name] = distinct
+        else:
+            non_numeric_counts[column_name] = distinct
+
+    # sorted() is stable, so ties keep their original column order.
+    numeric_sorted = sorted(numeric_counts, key=numeric_counts.get)
+    non_numeric_sorted = sorted(non_numeric_counts, key=non_numeric_counts.get)
+
+    # x axis: prefer a categorical column; fall back to the first numeric one.
+    if non_numeric_sorted:
+        x_column = non_numeric_sorted.pop()
+    elif numeric_counts:
+        x_column = next(iter(numeric_counts))
+        numeric_sorted.remove(x_column)
+    else:
+        raise ValueError("Not enough numeric columns for chart!")
+
+    # y axis: whatever numeric column is left with the fewest distinct values.
+    if not numeric_sorted:
+        raise ValueError("Not enough numeric columns for chart!")
+    y_column = numeric_sorted.pop(0)
+
+    return x_column, y_column, non_numeric_sorted, numeric_sorted
+
+
def zh_font_set():
font_names = [
"Heiti TC",
@@ -48,9 +96,6 @@ def zh_font_set():
)
def response_line_chart(speak: str, df: DataFrame) -> str:
logger.info(f"response_line_chart:{speak},")
-
- columns = df.columns.tolist()
-
if df.size <= 0:
raise ValueError("No Data!")
@@ -85,7 +130,19 @@ def response_line_chart(speak: str, df: DataFrame) -> str:
sns.set(context="notebook", style="ticks", rc=rc)
fig, ax = plt.subplots(figsize=(8, 5), dpi=100)
- sns.lineplot(df, x=columns[0], y=columns[1], ax=ax)
+ x, y, non_num_columns, num_colmns = data_pre_classification(df)
+ # ## 复杂折线图实现
+ if len(num_colmns) > 0:
+ num_colmns.append(y)
+ df_melted = pd.melt(
+ df, id_vars=x, value_vars=num_colmns, var_name="line", value_name="Value"
+ )
+ sns.lineplot(data=df_melted, x=x, y="Value", hue="line", ax=ax, palette="Set2")
+ else:
+ sns.lineplot(data=df, x=x, y=y, ax=ax, palette="Set2")
+
+ ax.yaxis.set_major_formatter(mtick.FuncFormatter(lambda y, _: "{:,.0f}".format(y)))
+ ax.yaxis.set_major_formatter(mtick.FuncFormatter(lambda x, _: "{:,.0f}".format(x)))
chart_name = "line_" + str(uuid.uuid1()) + ".png"
chart_path = static_message_img_path + "/" + chart_name
@@ -102,7 +159,6 @@ def response_line_chart(speak: str, df: DataFrame) -> str:
)
def response_bar_chart(speak: str, df: DataFrame) -> str:
logger.info(f"response_bar_chart:{speak},")
- columns = df.columns.tolist()
if df.size <= 0:
raise ValueError("No Data!")
@@ -136,9 +192,44 @@ def response_bar_chart(speak: str, df: DataFrame) -> str:
sns.set(context="notebook", style="ticks", rc=rc)
fig, ax = plt.subplots(figsize=(8, 5), dpi=100)
- sns.barplot(df, x=df[columns[0]], y=df[columns[1]], ax=ax)
- chart_name = "pie_" + str(uuid.uuid1()) + ".png"
+ hue = None
+ x, y, non_num_columns, num_colmns = data_pre_classification(df)
+ if len(non_num_columns) >= 1:
+ hue = non_num_columns[0]
+
+ if len(num_colmns) >= 1:
+ if hue:
+ if len(num_colmns) >= 2:
+ can_use_columns = num_colmns[:2]
+ else:
+ can_use_columns = num_colmns
+ sns.barplot(data=df, x=x, y=y, hue=hue, palette="Set2", ax=ax)
+ for sub_y_column in can_use_columns:
+ sns.barplot(
+ data=df, x=x, y=sub_y_column, hue=hue, palette="Set2", ax=ax
+ )
+ else:
+ if len(num_colmns) >= 3:
+ can_use_columns = num_colmns[:3]
+ else:
+ can_use_columns = num_colmns
+ sns.barplot(
+ data=df, x=x, y=y, hue=can_use_columns[0], palette="Set2", ax=ax
+ )
+
+ for sub_y_column in can_use_columns[1:]:
+ sns.barplot(
+ data=df, x=x, y=sub_y_column, hue=hue, palette="Set2", ax=ax
+ )
+ else:
+ sns.barplot(data=df, x=x, y=y, hue=hue, palette="Set2", ax=ax)
+
+ # 设置 y 轴刻度格式为普通数字格式
+ ax.yaxis.set_major_formatter(mtick.FuncFormatter(lambda y, _: "{:,.0f}".format(y)))
+ ax.yaxis.set_major_formatter(mtick.FuncFormatter(lambda x, _: "{:,.0f}".format(x)))
+
+ chart_name = "bar_" + str(uuid.uuid1()) + ".png"
chart_path = static_message_img_path + "/" + chart_name
plt.savefig(chart_path, bbox_inches="tight", dpi=100)
html_img = f"""
{speak}
"""
@@ -188,13 +279,6 @@ def response_pie_chart(speak: str, df: DataFrame) -> str:
startangle=90,
autopct="%1.1f%%",
)
- # 手动设置 labels 的位置和大小
- ax.legend(
- loc="upper right",
- bbox_to_anchor=(0, 0, 1, 1),
- labels=df[columns[0]].values,
- fontsize=10,
- )
plt.axis("equal") # 使饼图为正圆形
# plt.title(columns[0])
diff --git a/pilot/scene/chat_data/chat_excel/excel_learning/bar_1d2c71b4-4637-11ee-86eb-b26789cc3e58.png b/pilot/scene/chat_data/chat_excel/excel_learning/bar_1d2c71b4-4637-11ee-86eb-b26789cc3e58.png
deleted file mode 100644
index e9bd360b4..000000000
Binary files a/pilot/scene/chat_data/chat_excel/excel_learning/bar_1d2c71b4-4637-11ee-86eb-b26789cc3e58.png and /dev/null differ
diff --git a/pilot/scene/chat_data/chat_excel/excel_learning/bar_6077945c-4638-11ee-9fb4-b26789cc3e58.png b/pilot/scene/chat_data/chat_excel/excel_learning/bar_6077945c-4638-11ee-9fb4-b26789cc3e58.png
deleted file mode 100644
index 09ad02b58..000000000
Binary files a/pilot/scene/chat_data/chat_excel/excel_learning/bar_6077945c-4638-11ee-9fb4-b26789cc3e58.png and /dev/null differ
diff --git a/pilot/scene/chat_data/chat_excel/excel_learning/test.py b/pilot/scene/chat_data/chat_excel/excel_learning/test.py
index 2d90f169d..7dc1a2746 100644
--- a/pilot/scene/chat_data/chat_excel/excel_learning/test.py
+++ b/pilot/scene/chat_data/chat_excel/excel_learning/test.py
@@ -1,9 +1,11 @@
import os
import duckdb
import pandas as pd
+import numpy as np
import matplotlib
import seaborn as sns
import uuid
+
from pandas import DataFrame
import matplotlib.pyplot as plt
@@ -29,8 +31,12 @@ def data_pre_classification(df: DataFrame):
# 收集数据分类小于10个的列
non_numeric_colums_value_map = {}
numeric_colums_value_map = {}
+ df_filtered = df.dropna()
for column_name in columns:
- if pd.to_numeric(df[column_name], errors="coerce").notna().all():
+ print(np.issubdtype(df_filtered[column_name].dtype, np.number))
+ # if pd.to_numeric(df[column_name], errors='coerce').notna().all():
+ # if np.issubdtype(df_filtered[column_name].dtype, np.number):
+ if pd.api.types.is_numeric_dtype(df[column_name].dtypes):
number_columns.append(column_name)
unique_values = df[column_name].unique()
numeric_colums_value_map.update({column_name: len(unique_values)})
@@ -39,27 +45,68 @@ def data_pre_classification(df: DataFrame):
unique_values = df[column_name].unique()
non_numeric_colums_value_map.update({column_name: len(unique_values)})
- if len(non_numeric_colums) <= 0:
- sorted_colums_value_map = dict(
- sorted(numeric_colums_value_map.items(), key=lambda x: x[1])
- )
- numeric_colums_sort_list = list(sorted_colums_value_map.keys())
- x_column = number_columns[0]
- hue_column = numeric_colums_sort_list[0]
- y_column = numeric_colums_sort_list[1]
- elif len(number_columns) <= 0:
- raise ValueError("Have No numeric Column!")
- else:
- # 数字和非数字都存在多列,放弃部分数字列
- y_column = number_columns[0]
- x_column = non_numeric_colums[0]
- # if len(non_numeric_colums) > 1:
- #
- # else:
+ sorted_numeric_colums_value_map = dict(
+ sorted(numeric_colums_value_map.items(), key=lambda x: x[1])
+ )
+ numeric_colums_sort_list = list(sorted_numeric_colums_value_map.keys())
- # non_numeric_colums_sort_list.remove(non_numeric_colums[0])
- # hue_column = non_numeric_colums_sort_list
- return x_column, y_column, hue_column
+ sorted_colums_value_map = dict(
+ sorted(non_numeric_colums_value_map.items(), key=lambda x: x[1])
+ )
+ non_numeric_colums_sort_list = list(sorted_colums_value_map.keys())
+
+ # Analyze x-coordinate
+ if len(non_numeric_colums_sort_list) > 0:
+ x_cloumn = non_numeric_colums_sort_list[-1]
+ non_numeric_colums_sort_list.remove(x_cloumn)
+ else:
+ x_cloumn = number_columns[0]
+ numeric_colums_sort_list.remove(x_cloumn)
+
+ # Analyze y-coordinate
+ if len(numeric_colums_sort_list) > 0:
+ y_column = numeric_colums_sort_list[0]
+ numeric_colums_sort_list.remove(y_column)
+ else:
+ raise ValueError("Not enough numeric columns for chart!")
+
+ return x_cloumn, y_column, non_numeric_colums_sort_list, numeric_colums_sort_list
+
+ #
+ # if len(non_numeric_colums) <=0:
+ # sorted_colums_value_map = dict(sorted(numeric_colums_value_map.items(), key=lambda x: x[1]))
+ # numeric_colums_sort_list = list(sorted_colums_value_map.keys())
+ # x_column = number_columns[0]
+ # hue_column = numeric_colums_sort_list[0]
+ # y_column = numeric_colums_sort_list[1]
+ # cols = numeric_colums_sort_list[2:]
+ # elif len(number_columns) <=0:
+ # raise ValueError("Have No numeric Column!")
+ # else:
+ # # 数字和非数字都存在多列,放弃部分数字列
+ # x_column = non_numeric_colums[0]
+ # y_column = number_columns[0]
+ # if len(non_numeric_colums) > 1:
+ # sorted_colums_value_map = dict(sorted(non_numeric_colums_value_map.items(), key=lambda x: x[1]))
+ # non_numeric_colums_sort_list = list(sorted_colums_value_map.keys())
+ # non_numeric_colums_sort_list.remove(non_numeric_colums[0])
+ # hue_column = non_numeric_colums_sort_list[0]
+ # if len(number_columns) > 1:
+ # # try multiple charts
+ # cols = number_columns.remove( number_columns[0])
+ #
+ # else:
+ # sorted_colums_value_map = dict(sorted(numeric_colums_value_map.items(), key=lambda x: x[1]))
+ # numeric_colums_sort_list = list(sorted_colums_value_map.keys())
+ # numeric_colums_sort_list.remove(number_columns[0])
+ # if sorted_colums_value_map[numeric_colums_sort_list[0]].value < 5:
+ # hue_column = numeric_colums_sort_list[0]
+ # if len(number_columns) > 2:
+ # # try multiple charts
+ # cols = numeric_colums_sort_list.remove(numeric_colums_sort_list[0])
+ #
+ # print(x_column, y_column, hue_column, cols)
+ # return x_column, y_column, hue_column
if __name__ == "__main__":
@@ -79,17 +126,47 @@ if __name__ == "__main__":
# 获取系统中的默认中文字体名称
# default_font = fm.fontManager.defaultFontProperties.get_family()
+ # 创建一个示例 DataFrame
+ df = pd.DataFrame(
+ {
+ "A": [1, 2, 3, None, 5],
+ "B": [10, 20, 30, 40, 50],
+ "C": [1.1, 2.2, None, 4.4, 5.5],
+ "D": ["a", "b", "c", "d", "e"],
+ }
+ )
+
+ # 判断列是否为数字列
+ column_name = "A" # 要判断的列名
+ is_numeric = pd.to_numeric(df[column_name], errors="coerce").notna().all()
+
+ if is_numeric:
+ print(
+ f"Column '{column_name}' is a numeric column (ignoring null and NaN values in some elements)."
+ )
+ else:
+ print(
+ f"Column '{column_name}' is not a numeric column (ignoring null and NaN values in some elements)."
+ )
+
#
- excel_reader = ExcelReader("/Users/tuyang.yhj/Downloads/example.xlsx")
+ # excel_reader = ExcelReader("/Users/tuyang.yhj/Downloads/example.xlsx")
+ excel_reader = ExcelReader("/Users/tuyang.yhj/Downloads/yhj-zx.csv")
#
# # colunms, datas = excel_reader.run( "SELECT CONCAT(Year, '-', Quarter) AS QuarterYear, SUM(Sales) AS TotalSales FROM example GROUP BY QuarterYear ORDER BY QuarterYear")
# # colunms, datas = excel_reader.run( """ SELECT Year, SUM(Sales) AS Total_Sales FROM example GROUP BY Year ORDER BY Year; """)
+ # df = excel_reader.get_df_by_sql_ex(""" SELECT Segment, Country, SUM(Sales) AS Total_Sales, SUM(Profit) AS Total_Profit FROM example GROUP BY Segment, Country """)
df = excel_reader.get_df_by_sql_ex(
- """ SELECT Segment, Country, SUM(Sales) AS Total_Sales, SUM(Profit) AS Total_Profit FROM example GROUP BY Segment, Country """
+ """ SELECT `明细`, `费用小计`, `支出小计` FROM yhj-zx limit 10"""
)
- x, y, hue = data_pre_classification(df)
- print(x, y, hue)
+ for column_name in df.columns.tolist():
+ print(column_name + ":" + str(df[column_name].dtypes))
+ print(
+ column_name
+ + ":"
+ + str(pd.api.types.is_numeric_dtype(df[column_name].dtypes))
+ )
columns = df.columns.tolist()
font_names = [
@@ -118,116 +195,66 @@ if __name__ == "__main__":
sns.color_palette("hls", 10)
sns.hls_palette(8, l=0.5, s=0.7)
sns.set(context="notebook", style="ticks", rc=rc)
- # sns.set_palette("Set3") # 设置颜色主题
- # sns.set_style("dark")
- # sns.color_palette("hls", 10)
- # sns.hls_palette(8, l=.5, s=.7)
- # sns.set(context='notebook', style='ticks', rc=rc)
fig, ax = plt.subplots(figsize=(8, 5), dpi=100)
# plt.ticklabel_format(style='plain')
# ax = df.plot(kind='bar', ax=ax)
- # sns.barplot(df, x=x, y=y, hue= "Country", ax=ax)
- sns.catplot(data=df, x=x, y=y, hue="Country", kind="bar")
+ # sns.barplot(df, x=x, y="Total_Sales", hue='Country', ax=ax)
+ # sns.barplot(df, x=x, y="Total_Profit", hue='Country', ax=ax)
+
+ # sns.catplot(data=df, x=x, y=y, hue='Country', kind='bar')
+ x, y, non_num_columns, num_colmns = data_pre_classification(df)
+ print(x, y, str(non_num_columns), str(num_colmns))
+ ## 复杂折线图实现
+ if len(num_colmns) > 0:
+ num_colmns.append(y)
+ df_melted = pd.melt(
+ df, id_vars=x, value_vars=num_colmns, var_name="line", value_name="Value"
+ )
+ sns.lineplot(data=df_melted, x=x, y="Value", hue="line", ax=ax, palette="Set2")
+ else:
+ sns.lineplot(data=df, x=x, y=y, ax=ax, palette="Set2")
+
+ # hue = None
+ # ## 复杂柱状图实现
+ # x,y, non_num_columns, num_colmns =data_pre_classification(df)
+ # if len(non_num_columns) >= 1:
+ # hue = non_num_columns[0]
+
+ # if len(num_colmns)>=1:
+ # if hue:
+ # if len(num_colmns) >= 2:
+ # can_use_columns = num_colmns[:2]
+ # else:
+ # can_use_columns = num_colmns
+ # sns.barplot(data=df, x=x, y=y, hue=hue, palette="Set2", ax=ax)
+ # for sub_y_column in can_use_columns:
+ # sns.barplot(data=df, x=x, y=sub_y_column, hue=hue, palette="Set2", ax=ax)
+ # else:
+ # if len(num_colmns) >= 3:
+ # can_use_columns = num_colmns[:3]
+ # else:
+ # can_use_columns = num_colmns
+ # sns.barplot(data=df, x=x, y=y, hue=can_use_columns[0], palette="Set2", ax=ax)
+ #
+ # for sub_y_column in can_use_columns[1:]:
+ # sns.barplot(data=df, x=x, y=sub_y_column, hue=hue, palette="Set2", ax=ax)
+ # else:
+ # sns.barplot(data=df, x=x, y=y, hue=hue, palette="Set2", ax=ax)
+
+ # # 转换 DataFrame 格式
+ # df_melted = pd.melt(df, id_vars=x, value_vars=['Total_Sales', 'Total_Profit'], var_name='line', value_name='y')
+ #
+ # # 绘制多列柱状图
+ #
+ # sns.barplot(data=df, x=x, y="Total_Sales", hue = "Country", palette="Set2", ax=ax)
+ # sns.barplot(data=df, x=x, y="Total_Profit", hue = "Country", palette="Set1", ax=ax)
+
# 设置 y 轴刻度格式为普通数字格式
ax.yaxis.set_major_formatter(mtick.FuncFormatter(lambda x, _: "{:,.0f}".format(x)))
- # fonts = font_manager.findSystemFonts()
- # font_path = ""
- # for font in fonts:
- # if "Heiti" in font:
- # font_path = font
- # my_font = font_manager.FontProperties(fname=font_path)
- # plt.title("测试", fontproperties=my_font)
- # plt.ylabel(columns[1], fontproperties=my_font)
- # plt.xlabel(columns[0], fontproperties=my_font)
-
chart_name = "bar_" + str(uuid.uuid1()) + ".png"
chart_path = chart_name
plt.savefig(chart_path, bbox_inches="tight", dpi=100)
- # sns.set(context="notebook", style="ticks", color_codes=True)
- # sns.set_palette("Set3") # 设置颜色主题
- #
- # # fig, ax = plt.pie(df[columns[1]], labels=df[columns[0]], autopct='%1.1f%%', startangle=90)
- # fig, ax = plt.subplots(figsize=(8, 5), dpi=100)
- # plt.subplots_adjust(top=0.9)
- # ax = df.plot(kind='pie', y=columns[1], ax=ax, labels=df[columns[0]].values, startangle=90, autopct='%1.1f%%')
- # # 手动设置 labels 的位置和大小
- # ax.legend(loc='center left', bbox_to_anchor=(-1, 0.5, 0,0), labels=None, fontsize=10)
- # plt.axis('equal') # 使饼图为正圆形
- # plt.show()
-
- #
- #
- # def csv_colunm_foramt(val):
- # if str(val).find("$") >= 0:
- # return float(val.replace('$', '').replace(',', ''))
- # if str(val).find("¥") >= 0:
- # return float(val.replace('¥', '').replace(',', ''))
- # return val
- #
- # # 获取当前时间戳,作为代码开始的时间
- # start_time = int(time.time() * 1000)
- #
- # df = pd.read_excel('/Users/tuyang.yhj/Downloads/example.xlsx')
- # # 读取 Excel 文件为 Pandas DataFrame
- # df = pd.read_excel('/Users/tuyang.yhj/Downloads/example.xlsx', converters={i: csv_colunm_foramt for i in range(df.shape[1])})
- #
- # # d = df.values
- # # print(d.shape[0])
- # # for row in d:
- # # print(row[0])
- # # print(len(row))
- # # r = df.iterrows()
- #
- # # 获取当前时间戳,作为代码结束的时间
- # end_time = int(time.time() * 1000)
- #
- # print(f"耗时:{(end_time-start_time)/1000}秒")
- #
- # # 连接 DuckDB 数据库
- # con = duckdb.connect(database=':memory:', read_only=False)
- #
- # # 将 DataFrame 写入 DuckDB 数据库中的一个表
- # con.register('example', df)
- #
- # # 查询 DuckDB 数据库中的表
- # conn = con.cursor()
- # results = con.execute('SELECT Country, SUM(Profit) AS Total_Profit FROM example GROUP BY Country ORDER BY Total_Profit DESC LIMIT 1;')
- # colunms = []
- # for descrip in results.description:
- # colunms.append(descrip[0])
- # print(colunms)
- # for row in results.fetchall():
- # print(row)
- #
- #
- # # 连接 DuckDB 数据库
- # # con = duckdb.connect(':memory:')
- #
- # # # 加载 spatial 扩展
- # # con.execute('install spatial;')
- # # con.execute('load spatial;')
- # #
- # # # 查询 duckdb_internal 系统表,获取扩展列表
- # # result = con.execute("SELECT * FROM duckdb_internal.functions WHERE schema='list_extensions';")
- # #
- # # # 遍历查询结果,输出扩展名称和版本号
- # # for row in result:
- # # print(row['name'], row['return_type'])
- # # duckdb.read_csv('/Users/tuyang.yhj/Downloads/example_csc.csv')
- # # result = duckdb.sql('SELECT * FROM "/Users/tuyang.yhj/Downloads/yhj-zx.csv" ')
- # # result = duckdb.sql('SELECT * FROM "/Users/tuyang.yhj/Downloads/example_csc.csv" limit 20')
- # # for row in result.fetchall():
- # # print(row)
- #
- #
- # # result = con.execute("SELECT * FROM st_read('/Users/tuyang.yhj/Downloads/example.xlsx', layer='Sheet1')")
- # # # 遍历查询结果
- # # for row in result.fetchall():
- # # print(row)
- # print("xx")
- #
- #
#
diff --git a/pilot/scene/chat_data/chat_excel/excel_reader.py b/pilot/scene/chat_data/chat_excel/excel_reader.py
index ff4759fa2..7ac8b0d4e 100644
--- a/pilot/scene/chat_data/chat_excel/excel_reader.py
+++ b/pilot/scene/chat_data/chat_excel/excel_reader.py
@@ -71,7 +71,8 @@ class ExcelReader:
for column_name in df_tmp.columns:
self.columns_map.update({column_name: excel_colunm_format(column_name)})
try:
- self.df[column_name] = self.df[column_name].astype(float)
+ self.df[column_name] = pd.to_numeric(self.df[column_name])
+ self.df[column_name] = self.df[column_name].fillna(0)
except Exception as e:
print("transfor column error!" + column_name)