feat(editor): ChatExcel

🔥ChatExcel Mode Operation Manual
This commit is contained in:
yhjun1026 2023-08-30 14:41:51 +08:00
parent 91225e8b25
commit c03825b2c6
5 changed files with 220 additions and 134 deletions

View File

@ -4,7 +4,6 @@ from pilot.commands.command_mange import command
from pilot.configs.config import Config
import pandas as pd
import uuid
import io
import os
import matplotlib
import seaborn as sns
@ -20,6 +19,53 @@ CFG = Config()
logger = build_logger("show_chart_gen", LOGDIR + "show_chart_gen.log")
static_message_img_path = os.path.join(os.getcwd(), "message/img")
def data_pre_classification(df: DataFrame):
## Data pre-classification
columns = df.columns.tolist()
number_columns = []
non_numeric_colums = []
# 收集数据分类小于10个的列
non_numeric_colums_value_map = {}
numeric_colums_value_map = {}
for column_name in columns:
if pd.api.types.is_numeric_dtype(df[column_name].dtypes):
number_columns.append(column_name)
unique_values = df[column_name].unique()
numeric_colums_value_map.update({column_name: len(unique_values)})
else:
non_numeric_colums.append(column_name)
unique_values = df[column_name].unique()
non_numeric_colums_value_map.update({column_name: len(unique_values)})
sorted_numeric_colums_value_map = dict(sorted(numeric_colums_value_map.items(), key=lambda x: x[1]))
numeric_colums_sort_list = list(sorted_numeric_colums_value_map.keys())
sorted_colums_value_map = dict(sorted(non_numeric_colums_value_map.items(), key=lambda x: x[1]))
non_numeric_colums_sort_list = list(sorted_colums_value_map.keys())
# Analyze x-coordinate
if len(non_numeric_colums_sort_list) > 0:
x_cloumn = non_numeric_colums_sort_list[-1]
non_numeric_colums_sort_list.remove(x_cloumn)
else:
x_cloumn = number_columns[0]
numeric_colums_sort_list.remove(x_cloumn)
# Analyze y-coordinate
if len(numeric_colums_sort_list) > 0:
y_column = numeric_colums_sort_list[0]
numeric_colums_sort_list.remove(y_column)
else:
raise ValueError("Not enough numeric columns for chart")
return x_cloumn, y_column, non_numeric_colums_sort_list, numeric_colums_sort_list
def zh_font_set():
font_names = ['Heiti TC', 'Songti SC', 'STHeiti Light', 'Microsoft YaHei', 'SimSun', 'SimHei', 'KaiTi']
fm = FontManager()
@ -36,9 +82,6 @@ def zh_font_set():
'"speak": "<speak>", "df":"<data frame>"')
def response_line_chart(speak: str, df: DataFrame) -> str:
logger.info(f"response_line_chart:{speak},")
columns = df.columns.tolist()
if df.size <= 0:
raise ValueError("No Data")
@ -65,7 +108,14 @@ def response_line_chart(speak: str, df: DataFrame) -> str:
sns.set(context='notebook', style='ticks', rc=rc)
fig, ax = plt.subplots(figsize=(8, 5), dpi=100)
sns.lineplot(df, x=columns[0], y=columns[1], ax=ax)
x,y, non_num_columns, num_colmns =data_pre_classification(df)
# ## 复杂折线图实现
if len(num_colmns)>0:
num_colmns.append(y)
df_melted = pd.melt(df, id_vars=x, value_vars=num_colmns, var_name='line', value_name='Value')
sns.lineplot(data=df_melted, x=x, y="Value", hue="line", ax=ax, palette="Set2")
else:
sns.lineplot(data=df, x=x, y=y, ax=ax, palette="Set2")
chart_name = "line_" + str(uuid.uuid1()) + ".png"
chart_path = static_message_img_path + "/" + chart_name
@ -79,7 +129,6 @@ def response_line_chart(speak: str, df: DataFrame) -> str:
'"speak": "<speak>", "df":"<data frame>"')
def response_bar_chart(speak: str, df: DataFrame) -> str:
logger.info(f"response_bar_chart:{speak},")
columns = df.columns.tolist()
if df.size <= 0:
raise ValueError("No Data")
@ -105,9 +154,34 @@ def response_bar_chart(speak: str, df: DataFrame) -> str:
sns.set(context='notebook', style='ticks', rc=rc)
fig, ax = plt.subplots(figsize=(8, 5), dpi=100)
sns.barplot(df, x=df[columns[0]], y=df[columns[1]], ax=ax)
chart_name = "pie_" + str(uuid.uuid1()) + ".png"
x,y, non_num_columns, num_colmns =data_pre_classification(df)
if len(non_num_columns) >= 1:
hue = non_num_columns[0]
if len(num_colmns)>=1:
if hue:
if len(num_colmns) >= 2:
can_use_columns = num_colmns[:2]
else:
can_use_columns = num_colmns
sns.barplot(data=df, x=x, y=y, hue=hue, palette="Set2", ax=ax)
for sub_y_column in can_use_columns:
sns.barplot(data=df, x=x, y=sub_y_column, hue=hue, palette="Set2", ax=ax)
else:
if len(num_colmns) >= 3:
can_use_columns = num_colmns[:3]
else:
can_use_columns = num_colmns
sns.barplot(data=df, x=x, y=y, hue=can_use_columns[0], palette="Set2", ax=ax)
for sub_y_column in can_use_columns[1:]:
sns.barplot(data=df, x=x, y=sub_y_column, hue=hue, palette="Set2", ax=ax)
else:
sns.barplot(data=df, x=x, y=y, hue=hue, palette="Set2", ax=ax)
chart_name = "bar_" + str(uuid.uuid1()) + ".png"
chart_path = static_message_img_path + "/" + chart_name
plt.savefig(chart_path, bbox_inches='tight', dpi=100)
html_img = f"""<h5>{speak}</h5><img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />"""

Binary file not shown.

Before

Width:  |  Height:  |  Size: 19 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 18 KiB

View File

@ -1,9 +1,11 @@
import os
import duckdb
import pandas as pd
import numpy as np
import matplotlib
import seaborn as sns
import uuid
from pandas import DataFrame
import matplotlib.pyplot as plt
@ -29,9 +31,12 @@ def data_pre_classification(df: DataFrame):
# 收集数据分类小于10个的列
non_numeric_colums_value_map = {}
numeric_colums_value_map = {}
df_filtered = df.dropna()
for column_name in columns:
if pd.to_numeric(df[column_name], errors='coerce').notna().all():
print(np.issubdtype(df_filtered[column_name].dtype, np.number))
# if pd.to_numeric(df[column_name], errors='coerce').notna().all():
# if np.issubdtype(df_filtered[column_name].dtype, np.number):
if pd.api.types.is_numeric_dtype(df[column_name].dtypes):
number_columns.append(column_name)
unique_values = df[column_name].unique()
numeric_colums_value_map.update({column_name: len(unique_values)})
@ -40,26 +45,65 @@ def data_pre_classification(df: DataFrame):
unique_values = df[column_name].unique()
non_numeric_colums_value_map.update({column_name: len(unique_values)})
sorted_numeric_colums_value_map = dict(sorted(numeric_colums_value_map.items(), key=lambda x: x[1]))
numeric_colums_sort_list = list(sorted_numeric_colums_value_map.keys())
if len(non_numeric_colums) <=0:
sorted_colums_value_map = dict(sorted(numeric_colums_value_map.items(), key=lambda x: x[1]))
numeric_colums_sort_list = list(sorted_colums_value_map.keys())
x_column = number_columns[0]
hue_column = numeric_colums_sort_list[0]
y_column = numeric_colums_sort_list[1]
elif len(number_columns) <=0:
raise ValueError("Have No numeric Column")
sorted_colums_value_map = dict(sorted(non_numeric_colums_value_map.items(), key=lambda x: x[1]))
non_numeric_colums_sort_list = list(sorted_colums_value_map.keys())
# Analyze x-coordinate
if len(non_numeric_colums_sort_list) > 0:
x_cloumn = non_numeric_colums_sort_list[-1]
non_numeric_colums_sort_list.remove(x_cloumn)
else:
# 数字和非数字都存在多列,放弃部分数字列
y_column = number_columns[0]
x_column = non_numeric_colums[0]
x_cloumn = number_columns[0]
numeric_colums_sort_list.remove(x_cloumn)
# Analyze y-coordinate
if len(numeric_colums_sort_list) > 0:
y_column = numeric_colums_sort_list[0]
numeric_colums_sort_list.remove(y_column)
else:
raise ValueError("Not enough numeric columns for chart")
return x_cloumn, y_column, non_numeric_colums_sort_list, numeric_colums_sort_list
#
# if len(non_numeric_colums) <=0:
# sorted_colums_value_map = dict(sorted(numeric_colums_value_map.items(), key=lambda x: x[1]))
# numeric_colums_sort_list = list(sorted_colums_value_map.keys())
# x_column = number_columns[0]
# hue_column = numeric_colums_sort_list[0]
# y_column = numeric_colums_sort_list[1]
# cols = numeric_colums_sort_list[2:]
# elif len(number_columns) <=0:
# raise ValueError("Have No numeric Column")
# else:
# # 数字和非数字都存在多列,放弃部分数字列
# x_column = non_numeric_colums[0]
# y_column = number_columns[0]
# if len(non_numeric_colums) > 1:
# sorted_colums_value_map = dict(sorted(non_numeric_colums_value_map.items(), key=lambda x: x[1]))
# non_numeric_colums_sort_list = list(sorted_colums_value_map.keys())
# non_numeric_colums_sort_list.remove(non_numeric_colums[0])
# hue_column = non_numeric_colums_sort_list[0]
# if len(number_columns) > 1:
# # try multiple charts
# cols = number_columns.remove( number_columns[0])
#
# else:
# non_numeric_colums_sort_list.remove(non_numeric_colums[0])
# hue_column = non_numeric_colums_sort_list
return x_column, y_column, hue_column
# sorted_colums_value_map = dict(sorted(numeric_colums_value_map.items(), key=lambda x: x[1]))
# numeric_colums_sort_list = list(sorted_colums_value_map.keys())
# numeric_colums_sort_list.remove(number_columns[0])
# if sorted_colums_value_map[numeric_colums_sort_list[0]].value < 5:
# hue_column = numeric_colums_sort_list[0]
# if len(number_columns) > 2:
# # try multiple charts
# cols = numeric_colums_sort_list.remove(numeric_colums_sort_list[0])
#
# print(x_column, y_column, hue_column, cols)
# return x_column, y_column, hue_column
if __name__ == "__main__":
# connect = duckdb.connect("/Users/tuyang.yhj/Downloads/example.xlsx")
@ -78,16 +122,35 @@ if __name__ == "__main__":
# 获取系统中的默认中文字体名称
# default_font = fm.fontManager.defaultFontProperties.get_family()
# 创建一个示例 DataFrame
df = pd.DataFrame({
'A': [1, 2, 3, None, 5],
'B': [10, 20, 30, 40, 50],
'C': [1.1, 2.2, None, 4.4, 5.5],
'D': ['a', 'b', 'c', 'd', 'e']
})
# 判断列是否为数字列
column_name = 'A' # 要判断的列名
is_numeric = pd.to_numeric(df[column_name], errors='coerce').notna().all()
if is_numeric:
print(f"Column '{column_name}' is a numeric column (ignoring null and NaN values in some elements).")
else:
print(f"Column '{column_name}' is not a numeric column (ignoring null and NaN values in some elements).")
#
excel_reader = ExcelReader("/Users/tuyang.yhj/Downloads/example.xlsx")
# excel_reader = ExcelReader("/Users/tuyang.yhj/Downloads/example.xlsx")
excel_reader = ExcelReader("/Users/tuyang.yhj/Downloads/yhj-zx.csv")
#
# # colunms, datas = excel_reader.run( "SELECT CONCAT(Year, '-', Quarter) AS QuarterYear, SUM(Sales) AS TotalSales FROM example GROUP BY QuarterYear ORDER BY QuarterYear")
# # colunms, datas = excel_reader.run( """ SELECT Year, SUM(Sales) AS Total_Sales FROM example GROUP BY Year ORDER BY Year; """)
df = excel_reader.get_df_by_sql_ex(""" SELECT Segment, Country, SUM(Sales) AS Total_Sales, SUM(Profit) AS Total_Profit FROM example GROUP BY Segment, Country """)
# df = excel_reader.get_df_by_sql_ex(""" SELECT Segment, Country, SUM(Sales) AS Total_Sales, SUM(Profit) AS Total_Profit FROM example GROUP BY Segment, Country """)
df = excel_reader.get_df_by_sql_ex(""" SELECT `明细`, `费用小计`, `支出小计` FROM yhj-zx limit 10""")
x,y,hue =data_pre_classification(df)
print(x, y, hue)
for column_name in df.columns.tolist():
print(column_name + ":"+ str(df[column_name].dtypes))
print(column_name + ":"+ str(pd.api.types.is_numeric_dtype(df[column_name].dtypes)))
columns = df.columns.tolist()
font_names = ['Heiti TC', 'Songti SC', 'STHeiti Light', 'Microsoft YaHei', 'SimSun', 'SimHei', 'KaiTi']
@ -108,30 +171,63 @@ if __name__ == "__main__":
sns.color_palette("hls", 10)
sns.hls_palette(8, l=.5, s=.7)
sns.set(context='notebook', style='ticks', rc=rc)
# sns.set_palette("Set3") # 设置颜色主题
# sns.set_style("dark")
# sns.color_palette("hls", 10)
# sns.hls_palette(8, l=.5, s=.7)
# sns.set(context='notebook', style='ticks', rc=rc)
fig, ax = plt.subplots(figsize=(8, 5), dpi=100)
# plt.ticklabel_format(style='plain')
# ax = df.plot(kind='bar', ax=ax)
# sns.barplot(df, x=x, y=y, hue= "Country", ax=ax)
sns.catplot(data=df, x=x, y=y, hue='Country', kind='bar')
# sns.barplot(df, x=x, y="Total_Sales", hue='Country', ax=ax)
# sns.barplot(df, x=x, y="Total_Profit", hue='Country', ax=ax)
# sns.catplot(data=df, x=x, y=y, hue='Country', kind='bar')
x,y, non_num_columns, num_colmns =data_pre_classification(df)
print(x, y, str(non_num_columns), str(num_colmns))
## 复杂折线图实现
if len(num_colmns)>0:
num_colmns.append(y)
df_melted = pd.melt(df, id_vars=x, value_vars=num_colmns, var_name='line', value_name='Value')
sns.lineplot(data=df_melted, x=x, y="Value", hue="line", ax=ax, palette="Set2")
else:
sns.lineplot(data=df, x=x, y=y, ax=ax, palette="Set2")
# hue = None
# ## 复杂柱状图实现
# x,y, non_num_columns, num_colmns =data_pre_classification(df)
# if len(non_num_columns) >= 1:
# hue = non_num_columns[0]
# if len(num_colmns)>=1:
# if hue:
# if len(num_colmns) >= 2:
# can_use_columns = num_colmns[:2]
# else:
# can_use_columns = num_colmns
# sns.barplot(data=df, x=x, y=y, hue=hue, palette="Set2", ax=ax)
# for sub_y_column in can_use_columns:
# sns.barplot(data=df, x=x, y=sub_y_column, hue=hue, palette="Set2", ax=ax)
# else:
# if len(num_colmns) >= 3:
# can_use_columns = num_colmns[:3]
# else:
# can_use_columns = num_colmns
# sns.barplot(data=df, x=x, y=y, hue=can_use_columns[0], palette="Set2", ax=ax)
#
# for sub_y_column in can_use_columns[1:]:
# sns.barplot(data=df, x=x, y=sub_y_column, hue=hue, palette="Set2", ax=ax)
# else:
# sns.barplot(data=df, x=x, y=y, hue=hue, palette="Set2", ax=ax)
# # 转换 DataFrame 格式
# df_melted = pd.melt(df, id_vars=x, value_vars=['Total_Sales', 'Total_Profit'], var_name='line', value_name='y')
#
# # 绘制多列柱状图
#
# sns.barplot(data=df, x=x, y="Total_Sales", hue = "Country", palette="Set2", ax=ax)
# sns.barplot(data=df, x=x, y="Total_Profit", hue = "Country", palette="Set1", ax=ax)
# 设置 y 轴刻度格式为普通数字格式
ax.yaxis.set_major_formatter(mtick.FuncFormatter(lambda x, _: '{:,.0f}'.format(x)))
# fonts = font_manager.findSystemFonts()
# font_path = ""
# for font in fonts:
# if "Heiti" in font:
# font_path = font
# my_font = font_manager.FontProperties(fname=font_path)
# plt.title("测试", fontproperties=my_font)
# plt.ylabel(columns[1], fontproperties=my_font)
# plt.xlabel(columns[0], fontproperties=my_font)
chart_name = "bar_" + str(uuid.uuid1()) + ".png"
@ -139,89 +235,4 @@ if __name__ == "__main__":
plt.savefig(chart_path, bbox_inches='tight', dpi=100)
# sns.set(context="notebook", style="ticks", color_codes=True)
# sns.set_palette("Set3") # 设置颜色主题
#
# # fig, ax = plt.pie(df[columns[1]], labels=df[columns[0]], autopct='%1.1f%%', startangle=90)
# fig, ax = plt.subplots(figsize=(8, 5), dpi=100)
# plt.subplots_adjust(top=0.9)
# ax = df.plot(kind='pie', y=columns[1], ax=ax, labels=df[columns[0]].values, startangle=90, autopct='%1.1f%%')
# # 手动设置 labels 的位置和大小
# ax.legend(loc='center left', bbox_to_anchor=(-1, 0.5, 0,0), labels=None, fontsize=10)
# plt.axis('equal') # 使饼图为正圆形
# plt.show()
#
#
# def csv_colunm_foramt(val):
# if str(val).find("$") >= 0:
# return float(val.replace('$', '').replace(',', ''))
# if str(val).find("¥") >= 0:
# return float(val.replace('¥', '').replace(',', ''))
# return val
#
# # 获取当前时间戳,作为代码开始的时间
# start_time = int(time.time() * 1000)
#
# df = pd.read_excel('/Users/tuyang.yhj/Downloads/example.xlsx')
# # 读取 Excel 文件为 Pandas DataFrame
# df = pd.read_excel('/Users/tuyang.yhj/Downloads/example.xlsx', converters={i: csv_colunm_foramt for i in range(df.shape[1])})
#
# # d = df.values
# # print(d.shape[0])
# # for row in d:
# # print(row[0])
# # print(len(row))
# # r = df.iterrows()
#
# # 获取当前时间戳,作为代码结束的时间
# end_time = int(time.time() * 1000)
#
# print(f"耗时:{(end_time-start_time)/1000}秒")
#
# # 连接 DuckDB 数据库
# con = duckdb.connect(database=':memory:', read_only=False)
#
# # 将 DataFrame 写入 DuckDB 数据库中的一个表
# con.register('example', df)
#
# # 查询 DuckDB 数据库中的表
# conn = con.cursor()
# results = con.execute('SELECT Country, SUM(Profit) AS Total_Profit FROM example GROUP BY Country ORDER BY Total_Profit DESC LIMIT 1;')
# colunms = []
# for descrip in results.description:
# colunms.append(descrip[0])
# print(colunms)
# for row in results.fetchall():
# print(row)
#
#
# # 连接 DuckDB 数据库
# # con = duckdb.connect(':memory:')
#
# # # 加载 spatial 扩展
# # con.execute('install spatial;')
# # con.execute('load spatial;')
# #
# # # 查询 duckdb_internal 系统表,获取扩展列表
# # result = con.execute("SELECT * FROM duckdb_internal.functions WHERE schema='list_extensions';")
# #
# # # 遍历查询结果,输出扩展名称和版本号
# # for row in result:
# # print(row['name'], row['return_type'])
# # duckdb.read_csv('/Users/tuyang.yhj/Downloads/example_csc.csv')
# # result = duckdb.sql('SELECT * FROM "/Users/tuyang.yhj/Downloads/yhj-zx.csv" ')
# # result = duckdb.sql('SELECT * FROM "/Users/tuyang.yhj/Downloads/example_csc.csv" limit 20')
# # for row in result.fetchall():
# # print(row)
#
#
# # result = con.execute("SELECT * FROM st_read('/Users/tuyang.yhj/Downloads/example.xlsx', layer='Sheet1')")
# # # 遍历查询结果
# # for row in result.fetchall():
# # print(row)
# print("xx")
#
#
#

View File

@ -62,7 +62,8 @@ class ExcelReader:
for column_name in df_tmp.columns:
self.columns_map.update({column_name: excel_colunm_format(column_name)})
try:
self.df[column_name] = self.df[column_name].astype(float)
self.df[column_name] = pd.to_numeric(self.df[column_name])
self.df[column_name] = self.df[column_name].fillna(0)
except Exception as e:
print("transfor column error" + column_name)