diff --git a/pilot/commands/disply_type/show_chart_gen.py b/pilot/commands/disply_type/show_chart_gen.py index 6a3ee3e27..052b9d971 100644 --- a/pilot/commands/disply_type/show_chart_gen.py +++ b/pilot/commands/disply_type/show_chart_gen.py @@ -210,18 +210,22 @@ def response_bar_chart(speak: str, df: DataFrame) -> str: data=df, x=x, y=sub_y_column, hue=hue, palette="Set2", ax=ax ) else: - if len(num_colmns) >= 3: - can_use_columns = num_colmns[:3] + if len(num_colmns) > 5: + can_use_columns = num_colmns[:5] else: can_use_columns = num_colmns - sns.barplot( - data=df, x=x, y=y, hue=can_use_columns[0], palette="Set2", ax=ax - ) + can_use_columns.append(y) - for sub_y_column in can_use_columns[1:]: - sns.barplot( - data=df, x=x, y=sub_y_column, hue=hue, palette="Set2", ax=ax - ) + df_melted = pd.melt( + df, + id_vars=x, + value_vars=can_use_columns, + var_name="line", + value_name="Value", + ) + sns.barplot( + data=df_melted, x=x, y="Value", hue="line", palette="Set2", ax=ax + ) else: sns.barplot(data=df, x=x, y=y, hue=hue, palette="Set2", ax=ax) diff --git a/pilot/scene/chat_data/chat_excel/excel_analyze/chat.py b/pilot/scene/chat_data/chat_excel/excel_analyze/chat.py index 7f6ae856c..9cc0b43ff 100644 --- a/pilot/scene/chat_data/chat_excel/excel_analyze/chat.py +++ b/pilot/scene/chat_data/chat_excel/excel_analyze/chat.py @@ -104,4 +104,7 @@ class ChatExcel(BaseChat): "speak": prompt_response.thoughts, "df": self.excel_reader.get_df_by_sql_ex(prompt_response.sql), } - return CFG.command_disply.call(prompt_response.display, **param) + if CFG.command_disply.get_command(prompt_response.display): + return CFG.command_disply.call(prompt_response.display, **param) + else: + return CFG.command_disply.call("response_table", **param) diff --git a/pilot/scene/chat_data/chat_excel/excel_learning/test.py b/pilot/scene/chat_data/chat_excel/excel_learning/test.py index 7dc1a2746..fdfaf146a 100644 --- a/pilot/scene/chat_data/chat_excel/excel_learning/test.py +++ b/pilot/scene/chat_data/chat_excel/excel_learning/test.py @@ -157,7 +157,7 @@ if __name__ == "__main__": # # colunms, datas = excel_reader.run( """ SELECT Year, SUM(Sales) AS Total_Sales FROM example GROUP BY Year ORDER BY Year; """) # df = excel_reader.get_df_by_sql_ex(""" SELECT Segment, Country, SUM(Sales) AS Total_Sales, SUM(Profit) AS Total_Profit FROM example GROUP BY Segment, Country """) df = excel_reader.get_df_by_sql_ex( - """ SELECT `明细`, `费用小计`, `支出小计` FROM yhj-zx limit 10""" + """ SELECT 大项, AVG(实际) AS 平均实际支出, AVG(已支出) AS 平均已支出 FROM yhj-zx GROUP BY 大项""" ) for column_name in df.columns.tolist(): @@ -203,44 +203,55 @@ if __name__ == "__main__": # sns.barplot(df, x=x, y="Total_Profit", hue='Country', ax=ax) # sns.catplot(data=df, x=x, y=y, hue='Country', kind='bar') - x, y, non_num_columns, num_colmns = data_pre_classification(df) - print(x, y, str(non_num_columns), str(num_colmns)) + # x, y, non_num_columns, num_colmns = data_pre_classification(df) + # print(x, y, str(non_num_columns), str(num_colmns)) ## 复杂折线图实现 - if len(num_colmns) > 0: - num_colmns.append(y) - df_melted = pd.melt( - df, id_vars=x, value_vars=num_colmns, var_name="line", value_name="Value" - ) - sns.lineplot(data=df_melted, x=x, y="Value", hue="line", ax=ax, palette="Set2") - else: - sns.lineplot(data=df, x=x, y=y, ax=ax, palette="Set2") - - # hue = None - # ## 复杂柱状图实现 - # x,y, non_num_columns, num_colmns =data_pre_classification(df) - # if len(non_num_columns) >= 1: - # hue = non_num_columns[0] - - # if len(num_colmns)>=1: - # if hue: - # if len(num_colmns) >= 2: - # can_use_columns = num_colmns[:2] - # else: - # can_use_columns = num_colmns - # sns.barplot(data=df, x=x, y=y, hue=hue, palette="Set2", ax=ax) - # for sub_y_column in can_use_columns: - # sns.barplot(data=df, x=x, y=sub_y_column, hue=hue, palette="Set2", ax=ax) - # else: - # if len(num_colmns) >= 3: - # can_use_columns = num_colmns[:3] - # else: - # can_use_columns = num_colmns - # sns.barplot(data=df, x=x, y=y, hue=can_use_columns[0], palette="Set2", ax=ax) - # - # for sub_y_column in can_use_columns[1:]: - # sns.barplot(data=df, x=x, y=sub_y_column, hue=hue, palette="Set2", ax=ax) + # if len(num_colmns) > 0: + # num_colmns.append(y) + # df_melted = pd.melt( + # df, id_vars=x, value_vars=num_colmns, var_name="line", value_name="Value" + # ) + # sns.lineplot(data=df_melted, x=x, y="Value", hue="line", ax=ax, palette="Set2") # else: - # sns.barplot(data=df, x=x, y=y, hue=hue, palette="Set2", ax=ax) + # sns.lineplot(data=df, x=x, y=y, ax=ax, palette="Set2") + + hue = None + ## 复杂柱状图实现 + x, y, non_num_columns, num_colmns = data_pre_classification(df) + + if len(non_num_columns) >= 1: + hue = non_num_columns[0] + + if len(num_colmns) >= 1: + if hue: + if len(num_colmns) >= 2: + can_use_columns = num_colmns[:2] + else: + can_use_columns = num_colmns + sns.barplot(data=df, x=x, y=y, hue=hue, palette="Set2", ax=ax) + for sub_y_column in can_use_columns: + sns.barplot( + data=df, x=x, y=sub_y_column, hue=hue, palette="Set2", ax=ax + ) + else: + if len(num_colmns) > 5: + can_use_columns = num_colmns[:5] + else: + can_use_columns = num_colmns + can_use_columns.append(y) + + df_melted = pd.melt( + df, + id_vars=x, + value_vars=can_use_columns, + var_name="line", + value_name="Value", + ) + sns.barplot( + data=df_melted, x=x, y="Value", hue="line", palette="Set2", ax=ax + ) + else: + sns.barplot(data=df, x=x, y=y, hue=hue, palette="Set2", ax=ax) # # 转换 DataFrame 格式 # df_melted = pd.melt(df, id_vars=x, value_vars=['Total_Sales', 'Total_Profit'], var_name='line', value_name='y')