diff --git a/pilot/commands/disply_type/show_chart_gen.py b/pilot/commands/disply_type/show_chart_gen.py index dd3ddc963..254b9fed1 100644 --- a/pilot/commands/disply_type/show_chart_gen.py +++ b/pilot/commands/disply_type/show_chart_gen.py @@ -84,9 +84,15 @@ def response_pie_chart(speak: str, df: DataFrame) -> str: sns.set_style(rc={'font.sans-serif': "Microsoft Yahei"}) sns.set(context="notebook", style="ticks", color_codes=True, rc=rc) sns.set_palette("Set3") # 设置颜色主题 - fig, ax = plt.pie(columns[1], labels=columns[0], autopct='%1.1f%%', startangle=90) + + # fig, ax = plt.pie(df[columns[1]], labels=df[columns[0]], autopct='%1.1f%%', startangle=90) + fig, ax = plt.subplots(figsize=(8, 5), dpi=100) + ax = df.plot(kind='pie', y=columns[1], ax=ax, labels=df[columns[0]].values, startangle=90, autopct='%1.1f%%') + # 手动设置 labels 的位置和大小 + ax.legend(loc='upper right', bbox_to_anchor=(1, 1, 1, 1), labels=df[columns[0]].values, fontsize=10) + plt.axis('equal') # 使饼图为正圆形 - plt.title(columns[0]) + # plt.title(columns[0]) buf = io.BytesIO() ax.set_facecolor("lightgray") diff --git a/pilot/commands/disply_type/show_text_gen.py b/pilot/commands/disply_type/show_text_gen.py index 2ed43a428..16932ff1c 100644 --- a/pilot/commands/disply_type/show_text_gen.py +++ b/pilot/commands/disply_type/show_text_gen.py @@ -17,7 +17,8 @@ def response_data_text(speak: str, df: DataFrame) -> str: data = df.values row_size = data.shape[0] - value_str, text_info = "" + value_str = "" + text_info = "" if row_size > 1: html_table = df.to_html(index=False, escape=False, sparsify=False) table_str = "".join(html_table.split()) @@ -26,7 +27,10 @@ def response_data_text(speak: str, df: DataFrame) -> str: elif row_size == 1: row = data[0] for value in row: - value_str = value_str + f", ** {value} **" + if value_str: + value_str = value_str + f", ** {value} **" + else: + value_str = f" ** {value} **" text_info = f"{speak}: {value_str}" else: text_info = f"##### {speak}: _没有找到可用的数据_" diff --git a/pilot/openapi/api_v1/api_v1.py b/pilot/openapi/api_v1/api_v1.py index cb7fc1145..828521b00 100644 --- a/pilot/openapi/api_v1/api_v1.py +++ b/pilot/openapi/api_v1/api_v1.py @@ -130,8 +130,10 @@ async def dialogue_list(user_id: str = None): messages = json.loads(item.get("messages")) last_round = max(messages, key=lambda x: x['chat_order']) - select_param = last_round["param_value"] - + if "param_value" in last_round: + select_param = last_round["param_value"] + else: + select_param = "" conv_vo: ConversationVo = ConversationVo( conv_uid=conv_uid, user_input=summary, diff --git a/pilot/scene/chat_data/chat_excel/excel_learning/test.py b/pilot/scene/chat_data/chat_excel/excel_learning/test.py index 373b6fac1..a51407c09 100644 --- a/pilot/scene/chat_data/chat_excel/excel_learning/test.py +++ b/pilot/scene/chat_data/chat_excel/excel_learning/test.py @@ -1,7 +1,10 @@ import os import duckdb import pandas as pd +import matplotlib +import seaborn as sns +import matplotlib.pyplot as plt import time from fsspec import filesystem import spatial @@ -15,9 +18,23 @@ if __name__ == "__main__": excel_reader = ExcelReader("/Users/tuyang.yhj/Downloads/example.xlsx") # colunms, datas = excel_reader.run( "SELECT CONCAT(Year, '-', Quarter) AS QuarterYear, SUM(Sales) AS TotalSales FROM example GROUP BY QuarterYear ORDER BY QuarterYear") - colunms, datas = excel_reader.run( """ SELECT Month_Name, SUM(Sales) AS Total_Sales FROM example WHERE Year = '2019' GROUP BY Month_Name """) + # colunms, datas = excel_reader.run( """ SELECT Country, SUM(Profit) AS Total_Profit FROM example GROUP BY Country; """) + df = excel_reader.get_df_by_sql_ex("SELECT Country, SUM(Profit) AS Total_Profit FROM example GROUP BY Country;") + columns = df.columns.tolist() + plt.rcParams["font.family"] = ["sans-serif"] + rc = {"font.sans-serif": "SimHei", "axes.unicode_minus": False} + sns.set_style(rc={'font.sans-serif': "Microsoft Yahei"}) + sns.set(context="notebook", style="ticks", color_codes=True, rc=rc) + sns.set_palette("Set3") # 设置颜色主题 - print("xx") + # fig, ax = plt.pie(df[columns[1]], labels=df[columns[0]], autopct='%1.1f%%', startangle=90) + fig, ax = plt.subplots(figsize=(8, 5), dpi=100) + plt.subplots_adjust(top=0.9) + ax = df.plot(kind='pie', y=columns[1], ax=ax, labels=df[columns[0]].values, startangle=90, autopct='%1.1f%%') + # 手动设置 labels 的位置和大小 + ax.legend(loc='center left', bbox_to_anchor=(-1, 0.5, 0,0), labels=None, fontsize=10) + plt.axis('equal') # 使饼图为正圆形 + plt.show() # # # def csv_colunm_foramt(val): diff --git a/pilot/scene/chat_data/chat_excel/excel_reader.py b/pilot/scene/chat_data/chat_excel/excel_reader.py index d48eb00f8..e801b9127 100644 --- a/pilot/scene/chat_data/chat_excel/excel_reader.py +++ b/pilot/scene/chat_data/chat_excel/excel_reader.py @@ -12,11 +12,22 @@ def excel_colunm_format(old_name:str)->str: class ExcelReader: def __init__(self, file_path): - # read excel filt - df_tmp = pd.read_excel(file_path) + file_name = os.path.basename(file_path) + file_name_without_extension = os.path.splitext(file_name)[0] + + self.excel_file_name = file_name + self.extension = os.path.splitext(file_name)[1] + # read excel file + if file_path.endswith('.xlsx') or file_path.endswith('.xls'): + df_tmp = pd.read_excel(file_path) + self.df = pd.read_excel(file_path, converters={i: csv_colunm_foramt for i in range(df_tmp.shape[1])}) + elif file_path.endswith('.csv'): + df_tmp = pd.read_csv(file_path) + self.df = pd.read_csv(file_path, converters={i: csv_colunm_foramt for i in range(df_tmp.shape[1])}) + else: + raise ValueError("Unsupported file format.") - self.df = pd.read_excel(file_path, converters={i: csv_colunm_foramt for i in range(df_tmp.shape[1])}) self.columns_map = {} for column_name in df_tmp.columns: self.columns_map.update({column_name: excel_colunm_format(column_name)}) @@ -25,13 +36,8 @@ class ExcelReader: # connect DuckDB self.db = duckdb.connect(database=':memory:', read_only=False) - file_name = os.path.basename(file_path) - file_name_without_extension = os.path.splitext(file_name)[0] - self.excel_file_name = file_name - self.extension = os.path.splitext(file_name)[1] - self.table_name = file_name_without_extension # write data in duckdb self.db.register(self.table_name, self.df) @@ -43,9 +49,6 @@ class ExcelReader: colunms.append(descrip[0]) return colunms, results.fetchall() - def get_df_by_sql(self, sql): - return pd.read_sql(sql, self.db) - def get_df_by_sql_ex(self, sql): colunms, values = self.run(sql) return pd.DataFrame(values, columns=colunms)