diff --git a/pilot/out_parser/base.py b/pilot/out_parser/base.py index 78286f3d3..0dfe82d57 100644 --- a/pilot/out_parser/base.py +++ b/pilot/out_parser/base.py @@ -26,6 +26,10 @@ class BaseOutputParser(ABC): def __init__(self, sep: str, is_stream_out: bool = True): self.sep = sep self.is_stream_out = is_stream_out + self.data_schema = None + + def update(self, data_schema): + self.data_schema = data_schema def __post_process_code(self, code): sep = "\n```" diff --git a/pilot/scene/chat_data/chat_excel/excel_learning/chat.py b/pilot/scene/chat_data/chat_excel/excel_learning/chat.py index f05221eba..e5c3517be 100644 --- a/pilot/scene/chat_data/chat_excel/excel_learning/chat.py +++ b/pilot/scene/chat_data/chat_excel/excel_learning/chat.py @@ -44,11 +44,13 @@ class ExcelLearning(BaseChat): if parent_mode: self.current_message.chat_mode = parent_mode.value() + async def generate_input_values(self) -> Dict: # colunms, datas = self.excel_reader.get_sample_data() colunms, datas = await blocking_func_to_async( self._executor, self.excel_reader.get_sample_data ) + self.prompt_template.output_parser.update(colunms) copy_datas = datas.copy() datas.insert(0, colunms) diff --git a/pilot/scene/chat_data/chat_excel/excel_learning/out_parser.py b/pilot/scene/chat_data/chat_excel/excel_learning/out_parser.py index e3bcfb8ea..78ba72659 100644 --- a/pilot/scene/chat_data/chat_excel/excel_learning/out_parser.py +++ b/pilot/scene/chat_data/chat_excel/excel_learning/out_parser.py @@ -19,11 +19,13 @@ logger = logging.getLogger(__name__) class LearningExcelOutputParser(BaseOutputParser): def __init__(self, sep: str, is_stream_out: bool): super().__init__(sep=sep, is_stream_out=is_stream_out) + self.is_downgraded = False + def parse_prompt_response(self, model_out_text): - clean_str = super().parse_prompt_response(model_out_text) - print("clean prompt response:", clean_str) try: + clean_str = super().parse_prompt_response(model_out_text) + logger.info(f"parse_prompt_response:{model_out_text},{model_out_text}") response = json.loads(clean_str) for key in sorted(response): if key.strip() == "DataAnalysis": @@ -34,27 +36,39 @@ class LearningExcelOutputParser(BaseOutputParser): plans = response[key] return ExcelResponse(desciption=desciption, clounms=clounms, plans=plans) except Exception as e: - return model_out_text + logger.error(f"parse_prompt_response Faild!{str(e)}") + self.is_downgraded = True + return ExcelResponse(desciption=model_out_text, clounms=self.data_schema, plans=None) + def parse_view_response(self, speak, data) -> str: if data and not isinstance(data, str): ### tool out data to table view html_title = f"### **Data Summary**\n{data.desciption} " html_colunms = f"### **Data Structure**\n" - column_index = 0 - for item in data.clounms: - column_index += 1 - keys = item.keys() - for key in keys: + if self.is_downgraded: + column_index = 0 + for item in data.clounms: + column_index += 1 html_colunms = ( - html_colunms + f"- **{column_index}.[{key}]** _{item[key]}_\n" + html_colunms + f"- **{column_index}.[{item}]** _未知_\n" ) + else: + column_index = 0 + for item in data.clounms: + column_index += 1 + keys = item.keys() + for key in keys: + html_colunms = ( + html_colunms + f"- **{column_index}.[{key}]** _{item[key]}_\n" + ) html_plans = f"### **Recommended analysis plan**\n" index = 0 - for item in data.plans: - index += 1 - html_plans = html_plans + f"{item} \n" + if data.plans: + for item in data.plans: + index += 1 + html_plans = html_plans + f"{item} \n" html = f"""{html_title}\n{html_colunms}\n{html_plans}""" return html else: diff --git a/pilot/scene/chat_data/chat_excel/excel_learning/test.py b/pilot/scene/chat_data/chat_excel/excel_learning/test.py deleted file mode 100644 index fdfaf146a..000000000 --- a/pilot/scene/chat_data/chat_excel/excel_learning/test.py +++ /dev/null @@ -1,271 +0,0 @@ -import os -import duckdb -import pandas as pd -import numpy as np -import matplotlib -import seaborn as sns -import uuid - -from pandas import DataFrame - -import matplotlib.pyplot as plt -import matplotlib.ticker as mtick -from matplotlib import font_manager -from matplotlib.font_manager import FontManager - -matplotlib.use("Agg") -import time -from fsspec import filesystem -import spatial - -from pilot.scene.chat_data.chat_excel.excel_reader import ExcelReader - - -def data_pre_classification(df: DataFrame): - ## Data pre-classification - columns = df.columns.tolist() - - number_columns = [] - non_numeric_colums = [] - - # 收集数据分类小于10个的列 - non_numeric_colums_value_map = {} - numeric_colums_value_map = {} - df_filtered = df.dropna() - for column_name in columns: - print(np.issubdtype(df_filtered[column_name].dtype, np.number)) - # if pd.to_numeric(df[column_name], errors='coerce').notna().all(): - # if np.issubdtype(df_filtered[column_name].dtype, np.number): - if pd.api.types.is_numeric_dtype(df[column_name].dtypes): - number_columns.append(column_name) - unique_values = df[column_name].unique() - numeric_colums_value_map.update({column_name: len(unique_values)}) - else: - non_numeric_colums.append(column_name) - unique_values = df[column_name].unique() - non_numeric_colums_value_map.update({column_name: len(unique_values)}) - - sorted_numeric_colums_value_map = dict( - sorted(numeric_colums_value_map.items(), key=lambda x: x[1]) - ) - numeric_colums_sort_list = list(sorted_numeric_colums_value_map.keys()) - - sorted_colums_value_map = dict( - sorted(non_numeric_colums_value_map.items(), key=lambda x: x[1]) - ) - non_numeric_colums_sort_list = list(sorted_colums_value_map.keys()) - - # Analyze x-coordinate - if len(non_numeric_colums_sort_list) > 0: - x_cloumn = non_numeric_colums_sort_list[-1] - non_numeric_colums_sort_list.remove(x_cloumn) - else: - x_cloumn = number_columns[0] - numeric_colums_sort_list.remove(x_cloumn) - - # Analyze y-coordinate - if len(numeric_colums_sort_list) > 0: - y_column = numeric_colums_sort_list[0] - numeric_colums_sort_list.remove(y_column) - else: - raise ValueError("Not enough numeric columns for chart!") - - return x_cloumn, y_column, non_numeric_colums_sort_list, numeric_colums_sort_list - - # - # if len(non_numeric_colums) <=0: - # sorted_colums_value_map = dict(sorted(numeric_colums_value_map.items(), key=lambda x: x[1])) - # numeric_colums_sort_list = list(sorted_colums_value_map.keys()) - # x_column = number_columns[0] - # hue_column = numeric_colums_sort_list[0] - # y_column = numeric_colums_sort_list[1] - # cols = numeric_colums_sort_list[2:] - # elif len(number_columns) <=0: - # raise ValueError("Have No numeric Column!") - # else: - # # 数字和非数字都存在多列,放弃部分数字列 - # x_column = non_numeric_colums[0] - # y_column = number_columns[0] - # if len(non_numeric_colums) > 1: - # sorted_colums_value_map = dict(sorted(non_numeric_colums_value_map.items(), key=lambda x: x[1])) - # non_numeric_colums_sort_list = list(sorted_colums_value_map.keys()) - # non_numeric_colums_sort_list.remove(non_numeric_colums[0]) - # hue_column = non_numeric_colums_sort_list[0] - # if len(number_columns) > 1: - # # try multiple charts - # cols = number_columns.remove( number_columns[0]) - # - # else: - # sorted_colums_value_map = dict(sorted(numeric_colums_value_map.items(), key=lambda x: x[1])) - # numeric_colums_sort_list = list(sorted_colums_value_map.keys()) - # numeric_colums_sort_list.remove(number_columns[0]) - # if sorted_colums_value_map[numeric_colums_sort_list[0]].value < 5: - # hue_column = numeric_colums_sort_list[0] - # if len(number_columns) > 2: - # # try multiple charts - # cols = numeric_colums_sort_list.remove(numeric_colums_sort_list[0]) - # - # print(x_column, y_column, hue_column, cols) - # return x_column, y_column, hue_column - - -if __name__ == "__main__": - # connect = duckdb.connect("/Users/tuyang.yhj/Downloads/example.xlsx") - # - - # fonts = fm.findSystemFonts() - # for font in fonts: - # if 'Hei' in font: - # print(font) - - # fm = FontManager() - # mat_fonts = set(f.name for f in fm.ttflist) - # for i in mat_fonts: - # print(i) - # print(len(mat_fonts)) - # 获取系统中的默认中文字体名称 - # default_font = fm.fontManager.defaultFontProperties.get_family() - - # 创建一个示例 DataFrame - df = pd.DataFrame( - { - "A": [1, 2, 3, None, 5], - "B": [10, 20, 30, 40, 50], - "C": [1.1, 2.2, None, 4.4, 5.5], - "D": ["a", "b", "c", "d", "e"], - } - ) - - # 判断列是否为数字列 - column_name = "A" # 要判断的列名 - is_numeric = pd.to_numeric(df[column_name], errors="coerce").notna().all() - - if is_numeric: - print( - f"Column '{column_name}' is a numeric column (ignoring null and NaN values in some elements)." - ) - else: - print( - f"Column '{column_name}' is not a numeric column (ignoring null and NaN values in some elements)." - ) - - # - # excel_reader = ExcelReader("/Users/tuyang.yhj/Downloads/example.xlsx") - excel_reader = ExcelReader("/Users/tuyang.yhj/Downloads/yhj-zx.csv") - # - # # colunms, datas = excel_reader.run( "SELECT CONCAT(Year, '-', Quarter) AS QuarterYear, SUM(Sales) AS TotalSales FROM example GROUP BY QuarterYear ORDER BY QuarterYear") - # # colunms, datas = excel_reader.run( """ SELECT Year, SUM(Sales) AS Total_Sales FROM example GROUP BY Year ORDER BY Year; """) - # df = excel_reader.get_df_by_sql_ex(""" SELECT Segment, Country, SUM(Sales) AS Total_Sales, SUM(Profit) AS Total_Profit FROM example GROUP BY Segment, Country """) - df = excel_reader.get_df_by_sql_ex( - """ SELECT 大项, AVG(实际) AS 平均实际支出, AVG(已支出) AS 平均已支出 FROM yhj-zx GROUP BY 大项""" - ) - - for column_name in df.columns.tolist(): - print(column_name + ":" + str(df[column_name].dtypes)) - print( - column_name - + ":" - + str(pd.api.types.is_numeric_dtype(df[column_name].dtypes)) - ) - - columns = df.columns.tolist() - font_names = [ - "Heiti TC", - "Songti SC", - "STHeiti Light", - "Microsoft YaHei", - "SimSun", - "SimHei", - "KaiTi", - ] - fm = FontManager() - mat_fonts = set(f.name for f in fm.ttflist) - can_use_fonts = [] - for font_name in font_names: - if font_name in mat_fonts: - can_use_fonts.append(font_name) - if len(can_use_fonts) > 0: - plt.rcParams["font.sans-serif"] = can_use_fonts - - rc = {"font.sans-serif": can_use_fonts} - plt.rcParams["axes.unicode_minus"] = False # 解决无法显示符号的问题 - sns.set(font="Heiti TC", font_scale=0.8) # 解决Seaborn中文显示问题 - sns.set_palette("Set3") # 设置颜色主题 - sns.set_style("dark") - sns.color_palette("hls", 10) - sns.hls_palette(8, l=0.5, s=0.7) - sns.set(context="notebook", style="ticks", rc=rc) - - fig, ax = plt.subplots(figsize=(8, 5), dpi=100) - # plt.ticklabel_format(style='plain') - # ax = df.plot(kind='bar', ax=ax) - # sns.barplot(df, x=x, y="Total_Sales", hue='Country', ax=ax) - # sns.barplot(df, x=x, y="Total_Profit", hue='Country', ax=ax) - - # sns.catplot(data=df, x=x, y=y, hue='Country', kind='bar') - # x, y, non_num_columns, num_colmns = data_pre_classification(df) - # print(x, y, str(non_num_columns), str(num_colmns)) - ## 复杂折线图实现 - # if len(num_colmns) > 0: - # num_colmns.append(y) - # df_melted = pd.melt( - # df, id_vars=x, value_vars=num_colmns, var_name="line", value_name="Value" - # ) - # sns.lineplot(data=df_melted, x=x, y="Value", hue="line", ax=ax, palette="Set2") - # else: - # sns.lineplot(data=df, x=x, y=y, ax=ax, palette="Set2") - - hue = None - ## 复杂柱状图实现 - x, y, non_num_columns, num_colmns = data_pre_classification(df) - - if len(non_num_columns) >= 1: - hue = non_num_columns[0] - - if len(num_colmns) >= 1: - if hue: - if len(num_colmns) >= 2: - can_use_columns = num_colmns[:2] - else: - can_use_columns = num_colmns - sns.barplot(data=df, x=x, y=y, hue=hue, palette="Set2", ax=ax) - for sub_y_column in can_use_columns: - sns.barplot( - data=df, x=x, y=sub_y_column, hue=hue, palette="Set2", ax=ax - ) - else: - if len(num_colmns) > 5: - can_use_columns = num_colmns[:5] - else: - can_use_columns = num_colmns - can_use_columns.append(y) - - df_melted = pd.melt( - df, - id_vars=x, - value_vars=can_use_columns, - var_name="line", - value_name="Value", - ) - sns.barplot( - data=df_melted, x=x, y="Value", hue="line", palette="Set2", ax=ax - ) - else: - sns.barplot(data=df, x=x, y=y, hue=hue, palette="Set2", ax=ax) - - # # 转换 DataFrame 格式 - # df_melted = pd.melt(df, id_vars=x, value_vars=['Total_Sales', 'Total_Profit'], var_name='line', value_name='y') - # - # # 绘制多列柱状图 - # - # sns.barplot(data=df, x=x, y="Total_Sales", hue = "Country", palette="Set2", ax=ax) - # sns.barplot(data=df, x=x, y="Total_Profit", hue = "Country", palette="Set1", ax=ax) - - # 设置 y 轴刻度格式为普通数字格式 - ax.yaxis.set_major_formatter(mtick.FuncFormatter(lambda x, _: "{:,.0f}".format(x))) - - chart_name = "bar_" + str(uuid.uuid1()) + ".png" - chart_path = chart_name - plt.savefig(chart_path, bbox_inches="tight", dpi=100) - - #