diff --git a/pilot/scene/chat_data/chat_excel/excel_learning/out_parser.py b/pilot/scene/chat_data/chat_excel/excel_learning/out_parser.py index 5c8e2a58e..4e69400a3 100644 --- a/pilot/scene/chat_data/chat_excel/excel_learning/out_parser.py +++ b/pilot/scene/chat_data/chat_excel/excel_learning/out_parser.py @@ -39,15 +39,19 @@ class LearningExcelOutputParser(BaseOutputParser): def parse_view_response(self, speak, data) -> str: ### tool out data to table view - html_title = f"### **数据简介:**\n{data.desciption} \n" - html_colunms = f"### **数据结构:**\n" + html_title = f"### **数据简介:**{data.desciption} " + html_colunms = f"
### **数据结构:**" + column_index = 0 for item in data.clounms: + column_index +=1 keys = item.keys() for key in keys: - html_colunms = html_colunms + f"- **{key}**:{item[key]} \n" + html_colunms = html_colunms + f"- **{column_index}.{key}** _{item[key]}_ \n" - html_plans = f"\n ### **分析计划:** \n" + html_plans = f"
### **分析计划:** " + index = 0 for item in data.plans: - html_plans = html_plans + f"- {item} \n" + index +=1 + html_plans = html_plans + f"{index}.{item} \n" html = f"""{html_title}{html_colunms}{html_plans}""" return html diff --git a/pilot/scene/chat_data/chat_excel/excel_learning/test.py b/pilot/scene/chat_data/chat_excel/excel_learning/test.py index a51407c09..4fe66ae06 100644 --- a/pilot/scene/chat_data/chat_excel/excel_learning/test.py +++ b/pilot/scene/chat_data/chat_excel/excel_learning/test.py @@ -18,7 +18,7 @@ if __name__ == "__main__": excel_reader = ExcelReader("/Users/tuyang.yhj/Downloads/example.xlsx") # colunms, datas = excel_reader.run( "SELECT CONCAT(Year, '-', Quarter) AS QuarterYear, SUM(Sales) AS TotalSales FROM example GROUP BY QuarterYear ORDER BY QuarterYear") - # colunms, datas = excel_reader.run( """ SELECT Country, SUM(Profit) AS Total_Profit FROM example GROUP BY Country; """) + colunms, datas = excel_reader.run( """ SELECT Year, SUM(Sales) AS Total_Sales FROM example GROUP BY Year ORDER BY Year; """) df = excel_reader.get_df_by_sql_ex("SELECT Country, SUM(Profit) AS Total_Profit FROM example GROUP BY Country;") columns = df.columns.tolist() plt.rcParams["font.family"] = ["sans-serif"] diff --git a/pilot/scene/chat_data/chat_excel/excel_reader.py b/pilot/scene/chat_data/chat_excel/excel_reader.py index e801b9127..1af45e66b 100644 --- a/pilot/scene/chat_data/chat_excel/excel_reader.py +++ b/pilot/scene/chat_data/chat_excel/excel_reader.py @@ -1,6 +1,9 @@ import duckdb import os +import re +import sqlparse import pandas as pd +import numpy as np from pilot.common.pd_utils import csv_colunm_foramt @@ -9,6 +12,31 @@ def excel_colunm_format(old_name:str)->str: new_column = new_column.replace(" ", "_") return new_column +def add_quotes(sql, column_names=[]): + parsed = sqlparse.parse(sql) + for stmt in parsed: + for token in stmt.tokens: + deep_quotes(token, column_names) + return str(parsed[0]) + +def deep_quotes(token, column_names=[]): + if hasattr(token, "tokens") : + for token_child in token.tokens: + deep_quotes(token_child, column_names) + else: + if token.ttype == sqlparse.tokens.Name: + if len(column_names) >0: + if token.value in column_names: + token.value = f'"{token.value}"' + else: + token.value = f'"{token.value}"' + +def is_chinese(string): + # 使用正则表达式匹配中文字符 + pattern = re.compile(r'[一-龥]') + match = re.search(pattern, string) + return match is not None + class ExcelReader: def __init__(self, file_path): @@ -28,9 +56,14 @@ class ExcelReader: else: raise ValueError("Unsupported file format.") + self.df.replace('', np.nan, inplace=True) self.columns_map = {} for column_name in df_tmp.columns: self.columns_map.update({column_name: excel_colunm_format(column_name)}) + try: + self.df[column_name] = self.df[column_name].astype(float) + except Exception as e: + print("transfor column error!" + column_name) self.df = self.df.rename(columns=lambda x: x.strip().replace(' ', '_')) @@ -43,6 +76,8 @@ class ExcelReader: self.db.register(self.table_name, self.df) def run(self, sql): + sql = sql.replace(self.table_name, f'"{self.table_name}"') + sql = add_quotes(sql, self.columns_map.values()) results = self.db.execute(sql) colunms = [] for descrip in results.description: