diff --git a/pilot/commands/disply_type/show_chart_gen.py b/pilot/commands/disply_type/show_chart_gen.py index 27f1639d5..6a3ee3e27 100644 --- a/pilot/commands/disply_type/show_chart_gen.py +++ b/pilot/commands/disply_type/show_chart_gen.py @@ -4,13 +4,13 @@ from pilot.commands.command_mange import command from pilot.configs.config import Config import pandas as pd import uuid -import io import os import matplotlib import seaborn as sns matplotlib.use("Agg") import matplotlib.pyplot as plt +import matplotlib.ticker as mtick from matplotlib.font_manager import FontManager from pilot.configs.model_config import LOGDIR @@ -21,6 +21,54 @@ logger = build_logger("show_chart_gen", LOGDIR + "show_chart_gen.log") static_message_img_path = os.path.join(os.getcwd(), "message/img") +def data_pre_classification(df: DataFrame): + ## Data pre-classification + columns = df.columns.tolist() + + number_columns = [] + non_numeric_colums = [] + + # 收集数据分类小于10个的列 + non_numeric_colums_value_map = {} + numeric_colums_value_map = {} + for column_name in columns: + if pd.api.types.is_numeric_dtype(df[column_name].dtypes): + number_columns.append(column_name) + unique_values = df[column_name].unique() + numeric_colums_value_map.update({column_name: len(unique_values)}) + else: + non_numeric_colums.append(column_name) + unique_values = df[column_name].unique() + non_numeric_colums_value_map.update({column_name: len(unique_values)}) + + sorted_numeric_colums_value_map = dict( + sorted(numeric_colums_value_map.items(), key=lambda x: x[1]) + ) + numeric_colums_sort_list = list(sorted_numeric_colums_value_map.keys()) + + sorted_colums_value_map = dict( + sorted(non_numeric_colums_value_map.items(), key=lambda x: x[1]) + ) + non_numeric_colums_sort_list = list(sorted_colums_value_map.keys()) + + # Analyze x-coordinate + if len(non_numeric_colums_sort_list) > 0: + x_cloumn = non_numeric_colums_sort_list[-1] + non_numeric_colums_sort_list.remove(x_cloumn) + else: + x_cloumn = number_columns[0] + numeric_colums_sort_list.remove(x_cloumn) + + # Analyze y-coordinate + if len(numeric_colums_sort_list) > 0: + y_column = numeric_colums_sort_list[0] + numeric_colums_sort_list.remove(y_column) + else: + raise ValueError("Not enough numeric columns for chart!") + + return x_cloumn, y_column, non_numeric_colums_sort_list, numeric_colums_sort_list + + def zh_font_set(): font_names = [ "Heiti TC", @@ -48,9 +96,6 @@ def zh_font_set(): ) def response_line_chart(speak: str, df: DataFrame) -> str: logger.info(f"response_line_chart:{speak},") - - columns = df.columns.tolist() - if df.size <= 0: raise ValueError("No Data!") @@ -85,7 +130,19 @@ def response_line_chart(speak: str, df: DataFrame) -> str: sns.set(context="notebook", style="ticks", rc=rc) fig, ax = plt.subplots(figsize=(8, 5), dpi=100) - sns.lineplot(df, x=columns[0], y=columns[1], ax=ax) + x, y, non_num_columns, num_colmns = data_pre_classification(df) + # ## 复杂折线图实现 + if len(num_colmns) > 0: + num_colmns.append(y) + df_melted = pd.melt( + df, id_vars=x, value_vars=num_colmns, var_name="line", value_name="Value" + ) + sns.lineplot(data=df_melted, x=x, y="Value", hue="line", ax=ax, palette="Set2") + else: + sns.lineplot(data=df, x=x, y=y, ax=ax, palette="Set2") + + ax.yaxis.set_major_formatter(mtick.FuncFormatter(lambda y, _: "{:,.0f}".format(y))) + ax.yaxis.set_major_formatter(mtick.FuncFormatter(lambda x, _: "{:,.0f}".format(x))) chart_name = "line_" + str(uuid.uuid1()) + ".png" chart_path = static_message_img_path + "/" + chart_name @@ -102,7 +159,6 @@ def response_line_chart(speak: str, df: DataFrame) -> str: ) def response_bar_chart(speak: str, df: DataFrame) -> str: logger.info(f"response_bar_chart:{speak},") - columns = df.columns.tolist() if df.size <= 0: raise ValueError("No Data!") @@ -136,9 +192,44 @@ def response_bar_chart(speak: str, df: DataFrame) -> str: sns.set(context="notebook", style="ticks", rc=rc) fig, ax = plt.subplots(figsize=(8, 5), dpi=100) - sns.barplot(df, x=df[columns[0]], y=df[columns[1]], ax=ax) - chart_name = "pie_" + str(uuid.uuid1()) + ".png" + hue = None + x, y, non_num_columns, num_colmns = data_pre_classification(df) + if len(non_num_columns) >= 1: + hue = non_num_columns[0] + + if len(num_colmns) >= 1: + if hue: + if len(num_colmns) >= 2: + can_use_columns = num_colmns[:2] + else: + can_use_columns = num_colmns + sns.barplot(data=df, x=x, y=y, hue=hue, palette="Set2", ax=ax) + for sub_y_column in can_use_columns: + sns.barplot( + data=df, x=x, y=sub_y_column, hue=hue, palette="Set2", ax=ax + ) + else: + if len(num_colmns) >= 3: + can_use_columns = num_colmns[:3] + else: + can_use_columns = num_colmns + sns.barplot( + data=df, x=x, y=y, hue=can_use_columns[0], palette="Set2", ax=ax + ) + + for sub_y_column in can_use_columns[1:]: + sns.barplot( + data=df, x=x, y=sub_y_column, hue=hue, palette="Set2", ax=ax + ) + else: + sns.barplot(data=df, x=x, y=y, hue=hue, palette="Set2", ax=ax) + + # 设置 y 轴刻度格式为普通数字格式 + ax.yaxis.set_major_formatter(mtick.FuncFormatter(lambda y, _: "{:,.0f}".format(y))) + ax.yaxis.set_major_formatter(mtick.FuncFormatter(lambda x, _: "{:,.0f}".format(x))) + + chart_name = "bar_" + str(uuid.uuid1()) + ".png" chart_path = static_message_img_path + "/" + chart_name plt.savefig(chart_path, bbox_inches="tight", dpi=100) html_img = f"""
{speak}
""" @@ -188,13 +279,6 @@ def response_pie_chart(speak: str, df: DataFrame) -> str: startangle=90, autopct="%1.1f%%", ) - # 手动设置 labels 的位置和大小 - ax.legend( - loc="upper right", - bbox_to_anchor=(0, 0, 1, 1), - labels=df[columns[0]].values, - fontsize=10, - ) plt.axis("equal") # 使饼图为正圆形 # plt.title(columns[0]) diff --git a/pilot/scene/chat_data/chat_excel/excel_learning/bar_1d2c71b4-4637-11ee-86eb-b26789cc3e58.png b/pilot/scene/chat_data/chat_excel/excel_learning/bar_1d2c71b4-4637-11ee-86eb-b26789cc3e58.png deleted file mode 100644 index e9bd360b4..000000000 Binary files a/pilot/scene/chat_data/chat_excel/excel_learning/bar_1d2c71b4-4637-11ee-86eb-b26789cc3e58.png and /dev/null differ diff --git a/pilot/scene/chat_data/chat_excel/excel_learning/bar_6077945c-4638-11ee-9fb4-b26789cc3e58.png b/pilot/scene/chat_data/chat_excel/excel_learning/bar_6077945c-4638-11ee-9fb4-b26789cc3e58.png deleted file mode 100644 index 09ad02b58..000000000 Binary files a/pilot/scene/chat_data/chat_excel/excel_learning/bar_6077945c-4638-11ee-9fb4-b26789cc3e58.png and /dev/null differ diff --git a/pilot/scene/chat_data/chat_excel/excel_learning/test.py b/pilot/scene/chat_data/chat_excel/excel_learning/test.py index 2d90f169d..7dc1a2746 100644 --- a/pilot/scene/chat_data/chat_excel/excel_learning/test.py +++ b/pilot/scene/chat_data/chat_excel/excel_learning/test.py @@ -1,9 +1,11 @@ import os import duckdb import pandas as pd +import numpy as np import matplotlib import seaborn as sns import uuid + from pandas import DataFrame import matplotlib.pyplot as plt @@ -29,8 +31,12 @@ def data_pre_classification(df: DataFrame): # 收集数据分类小于10个的列 non_numeric_colums_value_map = {} numeric_colums_value_map = {} + df_filtered = df.dropna() for column_name in columns: - if pd.to_numeric(df[column_name], errors="coerce").notna().all(): + print(np.issubdtype(df_filtered[column_name].dtype, np.number)) + # if pd.to_numeric(df[column_name], errors='coerce').notna().all(): + # if np.issubdtype(df_filtered[column_name].dtype, np.number): + if pd.api.types.is_numeric_dtype(df[column_name].dtypes): number_columns.append(column_name) unique_values = df[column_name].unique() numeric_colums_value_map.update({column_name: len(unique_values)}) @@ -39,27 +45,68 @@ def data_pre_classification(df: DataFrame): unique_values = df[column_name].unique() non_numeric_colums_value_map.update({column_name: len(unique_values)}) - if len(non_numeric_colums) <= 0: - sorted_colums_value_map = dict( - sorted(numeric_colums_value_map.items(), key=lambda x: x[1]) - ) - numeric_colums_sort_list = list(sorted_colums_value_map.keys()) - x_column = number_columns[0] - hue_column = numeric_colums_sort_list[0] - y_column = numeric_colums_sort_list[1] - elif len(number_columns) <= 0: - raise ValueError("Have No numeric Column!") - else: - # 数字和非数字都存在多列,放弃部分数字列 - y_column = number_columns[0] - x_column = non_numeric_colums[0] - # if len(non_numeric_colums) > 1: - # - # else: + sorted_numeric_colums_value_map = dict( + sorted(numeric_colums_value_map.items(), key=lambda x: x[1]) + ) + numeric_colums_sort_list = list(sorted_numeric_colums_value_map.keys()) - # non_numeric_colums_sort_list.remove(non_numeric_colums[0]) - # hue_column = non_numeric_colums_sort_list - return x_column, y_column, hue_column + sorted_colums_value_map = dict( + sorted(non_numeric_colums_value_map.items(), key=lambda x: x[1]) + ) + non_numeric_colums_sort_list = list(sorted_colums_value_map.keys()) + + # Analyze x-coordinate + if len(non_numeric_colums_sort_list) > 0: + x_cloumn = non_numeric_colums_sort_list[-1] + non_numeric_colums_sort_list.remove(x_cloumn) + else: + x_cloumn = number_columns[0] + numeric_colums_sort_list.remove(x_cloumn) + + # Analyze y-coordinate + if len(numeric_colums_sort_list) > 0: + y_column = numeric_colums_sort_list[0] + numeric_colums_sort_list.remove(y_column) + else: + raise ValueError("Not enough numeric columns for chart!") + + return x_cloumn, y_column, non_numeric_colums_sort_list, numeric_colums_sort_list + + # + # if len(non_numeric_colums) <=0: + # sorted_colums_value_map = dict(sorted(numeric_colums_value_map.items(), key=lambda x: x[1])) + # numeric_colums_sort_list = list(sorted_colums_value_map.keys()) + # x_column = number_columns[0] + # hue_column = numeric_colums_sort_list[0] + # y_column = numeric_colums_sort_list[1] + # cols = numeric_colums_sort_list[2:] + # elif len(number_columns) <=0: + # raise ValueError("Have No numeric Column!") + # else: + # # 数字和非数字都存在多列,放弃部分数字列 + # x_column = non_numeric_colums[0] + # y_column = number_columns[0] + # if len(non_numeric_colums) > 1: + # sorted_colums_value_map = dict(sorted(non_numeric_colums_value_map.items(), key=lambda x: x[1])) + # non_numeric_colums_sort_list = list(sorted_colums_value_map.keys()) + # non_numeric_colums_sort_list.remove(non_numeric_colums[0]) + # hue_column = non_numeric_colums_sort_list[0] + # if len(number_columns) > 1: + # # try multiple charts + # cols = number_columns.remove( number_columns[0]) + # + # else: + # sorted_colums_value_map = dict(sorted(numeric_colums_value_map.items(), key=lambda x: x[1])) + # numeric_colums_sort_list = list(sorted_colums_value_map.keys()) + # numeric_colums_sort_list.remove(number_columns[0]) + # if sorted_colums_value_map[numeric_colums_sort_list[0]].value < 5: + # hue_column = numeric_colums_sort_list[0] + # if len(number_columns) > 2: + # # try multiple charts + # cols = numeric_colums_sort_list.remove(numeric_colums_sort_list[0]) + # + # print(x_column, y_column, hue_column, cols) + # return x_column, y_column, hue_column if __name__ == "__main__": @@ -79,17 +126,47 @@ if __name__ == "__main__": # 获取系统中的默认中文字体名称 # default_font = fm.fontManager.defaultFontProperties.get_family() + # 创建一个示例 DataFrame + df = pd.DataFrame( + { + "A": [1, 2, 3, None, 5], + "B": [10, 20, 30, 40, 50], + "C": [1.1, 2.2, None, 4.4, 5.5], + "D": ["a", "b", "c", "d", "e"], + } + ) + + # 判断列是否为数字列 + column_name = "A" # 要判断的列名 + is_numeric = pd.to_numeric(df[column_name], errors="coerce").notna().all() + + if is_numeric: + print( + f"Column '{column_name}' is a numeric column (ignoring null and NaN values in some elements)." + ) + else: + print( + f"Column '{column_name}' is not a numeric column (ignoring null and NaN values in some elements)." + ) + # - excel_reader = ExcelReader("/Users/tuyang.yhj/Downloads/example.xlsx") + # excel_reader = ExcelReader("/Users/tuyang.yhj/Downloads/example.xlsx") + excel_reader = ExcelReader("/Users/tuyang.yhj/Downloads/yhj-zx.csv") # # # colunms, datas = excel_reader.run( "SELECT CONCAT(Year, '-', Quarter) AS QuarterYear, SUM(Sales) AS TotalSales FROM example GROUP BY QuarterYear ORDER BY QuarterYear") # # colunms, datas = excel_reader.run( """ SELECT Year, SUM(Sales) AS Total_Sales FROM example GROUP BY Year ORDER BY Year; """) + # df = excel_reader.get_df_by_sql_ex(""" SELECT Segment, Country, SUM(Sales) AS Total_Sales, SUM(Profit) AS Total_Profit FROM example GROUP BY Segment, Country """) df = excel_reader.get_df_by_sql_ex( - """ SELECT Segment, Country, SUM(Sales) AS Total_Sales, SUM(Profit) AS Total_Profit FROM example GROUP BY Segment, Country """ + """ SELECT `明细`, `费用小计`, `支出小计` FROM yhj-zx limit 10""" ) - x, y, hue = data_pre_classification(df) - print(x, y, hue) + for column_name in df.columns.tolist(): + print(column_name + ":" + str(df[column_name].dtypes)) + print( + column_name + + ":" + + str(pd.api.types.is_numeric_dtype(df[column_name].dtypes)) + ) columns = df.columns.tolist() font_names = [ @@ -118,116 +195,66 @@ if __name__ == "__main__": sns.color_palette("hls", 10) sns.hls_palette(8, l=0.5, s=0.7) sns.set(context="notebook", style="ticks", rc=rc) - # sns.set_palette("Set3") # 设置颜色主题 - # sns.set_style("dark") - # sns.color_palette("hls", 10) - # sns.hls_palette(8, l=.5, s=.7) - # sns.set(context='notebook', style='ticks', rc=rc) fig, ax = plt.subplots(figsize=(8, 5), dpi=100) # plt.ticklabel_format(style='plain') # ax = df.plot(kind='bar', ax=ax) - # sns.barplot(df, x=x, y=y, hue= "Country", ax=ax) - sns.catplot(data=df, x=x, y=y, hue="Country", kind="bar") + # sns.barplot(df, x=x, y="Total_Sales", hue='Country', ax=ax) + # sns.barplot(df, x=x, y="Total_Profit", hue='Country', ax=ax) + + # sns.catplot(data=df, x=x, y=y, hue='Country', kind='bar') + x, y, non_num_columns, num_colmns = data_pre_classification(df) + print(x, y, str(non_num_columns), str(num_colmns)) + ## 复杂折线图实现 + if len(num_colmns) > 0: + num_colmns.append(y) + df_melted = pd.melt( + df, id_vars=x, value_vars=num_colmns, var_name="line", value_name="Value" + ) + sns.lineplot(data=df_melted, x=x, y="Value", hue="line", ax=ax, palette="Set2") + else: + sns.lineplot(data=df, x=x, y=y, ax=ax, palette="Set2") + + # hue = None + # ## 复杂柱状图实现 + # x,y, non_num_columns, num_colmns =data_pre_classification(df) + # if len(non_num_columns) >= 1: + # hue = non_num_columns[0] + + # if len(num_colmns)>=1: + # if hue: + # if len(num_colmns) >= 2: + # can_use_columns = num_colmns[:2] + # else: + # can_use_columns = num_colmns + # sns.barplot(data=df, x=x, y=y, hue=hue, palette="Set2", ax=ax) + # for sub_y_column in can_use_columns: + # sns.barplot(data=df, x=x, y=sub_y_column, hue=hue, palette="Set2", ax=ax) + # else: + # if len(num_colmns) >= 3: + # can_use_columns = num_colmns[:3] + # else: + # can_use_columns = num_colmns + # sns.barplot(data=df, x=x, y=y, hue=can_use_columns[0], palette="Set2", ax=ax) + # + # for sub_y_column in can_use_columns[1:]: + # sns.barplot(data=df, x=x, y=sub_y_column, hue=hue, palette="Set2", ax=ax) + # else: + # sns.barplot(data=df, x=x, y=y, hue=hue, palette="Set2", ax=ax) + + # # 转换 DataFrame 格式 + # df_melted = pd.melt(df, id_vars=x, value_vars=['Total_Sales', 'Total_Profit'], var_name='line', value_name='y') + # + # # 绘制多列柱状图 + # + # sns.barplot(data=df, x=x, y="Total_Sales", hue = "Country", palette="Set2", ax=ax) + # sns.barplot(data=df, x=x, y="Total_Profit", hue = "Country", palette="Set1", ax=ax) + # 设置 y 轴刻度格式为普通数字格式 ax.yaxis.set_major_formatter(mtick.FuncFormatter(lambda x, _: "{:,.0f}".format(x))) - # fonts = font_manager.findSystemFonts() - # font_path = "" - # for font in fonts: - # if "Heiti" in font: - # font_path = font - # my_font = font_manager.FontProperties(fname=font_path) - # plt.title("测试", fontproperties=my_font) - # plt.ylabel(columns[1], fontproperties=my_font) - # plt.xlabel(columns[0], fontproperties=my_font) - chart_name = "bar_" + str(uuid.uuid1()) + ".png" chart_path = chart_name plt.savefig(chart_path, bbox_inches="tight", dpi=100) - # sns.set(context="notebook", style="ticks", color_codes=True) - # sns.set_palette("Set3") # 设置颜色主题 - # - # # fig, ax = plt.pie(df[columns[1]], labels=df[columns[0]], autopct='%1.1f%%', startangle=90) - # fig, ax = plt.subplots(figsize=(8, 5), dpi=100) - # plt.subplots_adjust(top=0.9) - # ax = df.plot(kind='pie', y=columns[1], ax=ax, labels=df[columns[0]].values, startangle=90, autopct='%1.1f%%') - # # 手动设置 labels 的位置和大小 - # ax.legend(loc='center left', bbox_to_anchor=(-1, 0.5, 0,0), labels=None, fontsize=10) - # plt.axis('equal') # 使饼图为正圆形 - # plt.show() - - # - # - # def csv_colunm_foramt(val): - # if str(val).find("$") >= 0: - # return float(val.replace('$', '').replace(',', '')) - # if str(val).find("¥") >= 0: - # return float(val.replace('¥', '').replace(',', '')) - # return val - # - # # 获取当前时间戳,作为代码开始的时间 - # start_time = int(time.time() * 1000) - # - # df = pd.read_excel('/Users/tuyang.yhj/Downloads/example.xlsx') - # # 读取 Excel 文件为 Pandas DataFrame - # df = pd.read_excel('/Users/tuyang.yhj/Downloads/example.xlsx', converters={i: csv_colunm_foramt for i in range(df.shape[1])}) - # - # # d = df.values - # # print(d.shape[0]) - # # for row in d: - # # print(row[0]) - # # print(len(row)) - # # r = df.iterrows() - # - # # 获取当前时间戳,作为代码结束的时间 - # end_time = int(time.time() * 1000) - # - # print(f"耗时:{(end_time-start_time)/1000}秒") - # - # # 连接 DuckDB 数据库 - # con = duckdb.connect(database=':memory:', read_only=False) - # - # # 将 DataFrame 写入 DuckDB 数据库中的一个表 - # con.register('example', df) - # - # # 查询 DuckDB 数据库中的表 - # conn = con.cursor() - # results = con.execute('SELECT Country, SUM(Profit) AS Total_Profit FROM example GROUP BY Country ORDER BY Total_Profit DESC LIMIT 1;') - # colunms = [] - # for descrip in results.description: - # colunms.append(descrip[0]) - # print(colunms) - # for row in results.fetchall(): - # print(row) - # - # - # # 连接 DuckDB 数据库 - # # con = duckdb.connect(':memory:') - # - # # # 加载 spatial 扩展 - # # con.execute('install spatial;') - # # con.execute('load spatial;') - # # - # # # 查询 duckdb_internal 系统表,获取扩展列表 - # # result = con.execute("SELECT * FROM duckdb_internal.functions WHERE schema='list_extensions';") - # # - # # # 遍历查询结果,输出扩展名称和版本号 - # # for row in result: - # # print(row['name'], row['return_type']) - # # duckdb.read_csv('/Users/tuyang.yhj/Downloads/example_csc.csv') - # # result = duckdb.sql('SELECT * FROM "/Users/tuyang.yhj/Downloads/yhj-zx.csv" ') - # # result = duckdb.sql('SELECT * FROM "/Users/tuyang.yhj/Downloads/example_csc.csv" limit 20') - # # for row in result.fetchall(): - # # print(row) - # - # - # # result = con.execute("SELECT * FROM st_read('/Users/tuyang.yhj/Downloads/example.xlsx', layer='Sheet1')") - # # # 遍历查询结果 - # # for row in result.fetchall(): - # # print(row) - # print("xx") - # - # # diff --git a/pilot/scene/chat_data/chat_excel/excel_reader.py b/pilot/scene/chat_data/chat_excel/excel_reader.py index ff4759fa2..7ac8b0d4e 100644 --- a/pilot/scene/chat_data/chat_excel/excel_reader.py +++ b/pilot/scene/chat_data/chat_excel/excel_reader.py @@ -71,7 +71,8 @@ class ExcelReader: for column_name in df_tmp.columns: self.columns_map.update({column_name: excel_colunm_format(column_name)}) try: - self.df[column_name] = self.df[column_name].astype(float) + self.df[column_name] = pd.to_numeric(self.df[column_name]) + self.df[column_name] = self.df[column_name].fillna(0) except Exception as e: print("transfor column error!" + column_name)