diff --git a/pilot/scene/chat_data/chat_excel/excel_learning/out_parser.py b/pilot/scene/chat_data/chat_excel/excel_learning/out_parser.py
index 5c8e2a58e..4e69400a3 100644
--- a/pilot/scene/chat_data/chat_excel/excel_learning/out_parser.py
+++ b/pilot/scene/chat_data/chat_excel/excel_learning/out_parser.py
@@ -39,15 +39,19 @@ class LearningExcelOutputParser(BaseOutputParser):
def parse_view_response(self, speak, data) -> str:
    """Render the learned Excel metadata as a Markdown summary.

    Args:
        speak: Not used by this method.
        data: Object exposing ``desciption`` (summary text), ``clounms``
            (an iterable of mappings from column name to column
            description) and ``plans`` (an iterable of analysis-plan
            strings). The attribute spellings are the ones the data
            object already uses.

    Returns:
        Markdown text made of three consecutive sections: the data
        summary, the numbered column structure and the numbered analysis
        plans.
    """
    title_section = f"### **数据简介:**{data.desciption} "

    # Every key of one column mapping shares that mapping's 1-based number.
    column_lines = [
        f"- **{position}.{name}** _{description}_ \n"
        for position, column in enumerate(data.clounms, start=1)
        for name, description in column.items()
    ]
    # The line break in front of the heading is written as "\n": a raw
    # newline inside a single-quoted string literal is a syntax error.
    # NOTE(review): the first bullet follows the heading with no line break
    # in between; the emitted text is kept as-is - confirm this is intended.
    columns_section = "\n### **数据结构:**" + "".join(column_lines)

    plan_lines = [
        f"{position}.{plan} \n"
        for position, plan in enumerate(data.plans, start=1)
    ]
    plans_section = "\n### **分析计划:** " + "".join(plan_lines)

    return f"{title_section}{columns_section}{plans_section}"
diff --git a/pilot/scene/chat_data/chat_excel/excel_learning/test.py b/pilot/scene/chat_data/chat_excel/excel_learning/test.py
index a51407c09..4fe66ae06 100644
--- a/pilot/scene/chat_data/chat_excel/excel_learning/test.py
+++ b/pilot/scene/chat_data/chat_excel/excel_learning/test.py
@@ -18,7 +18,7 @@ if __name__ == "__main__":
excel_reader = ExcelReader("/Users/tuyang.yhj/Downloads/example.xlsx")
# colunms, datas = excel_reader.run( "SELECT CONCAT(Year, '-', Quarter) AS QuarterYear, SUM(Sales) AS TotalSales FROM example GROUP BY QuarterYear ORDER BY QuarterYear")
- # colunms, datas = excel_reader.run( """ SELECT Country, SUM(Profit) AS Total_Profit FROM example GROUP BY Country; """)
+ colunms, datas = excel_reader.run( """ SELECT Year, SUM(Sales) AS Total_Sales FROM example GROUP BY Year ORDER BY Year; """)
df = excel_reader.get_df_by_sql_ex("SELECT Country, SUM(Profit) AS Total_Profit FROM example GROUP BY Country;")
columns = df.columns.tolist()
plt.rcParams["font.family"] = ["sans-serif"]
diff --git a/pilot/scene/chat_data/chat_excel/excel_reader.py b/pilot/scene/chat_data/chat_excel/excel_reader.py
index e801b9127..1af45e66b 100644
--- a/pilot/scene/chat_data/chat_excel/excel_reader.py
+++ b/pilot/scene/chat_data/chat_excel/excel_reader.py
@@ -1,6 +1,9 @@
import duckdb
import os
+import re
+import sqlparse
import pandas as pd
+import numpy as np
from pilot.common.pd_utils import csv_colunm_foramt
@@ -9,6 +12,31 @@ def excel_colunm_format(old_name:str)->str:
new_column = new_column.replace(" ", "_")
return new_column
def add_quotes(sql, column_names=None):
    """Return ``sql`` with its name tokens wrapped in double quotes.

    The SQL text is parsed with ``sqlparse`` and every parsed statement is
    handed to ``deep_quotes``, which rewrites the matching tokens in place.

    Args:
        sql: SQL text; it may hold several statements.
        column_names: Optional iterable of names to quote. When it is
            ``None`` or empty, ``deep_quotes`` quotes every name token.

    Returns:
        The rewritten SQL containing every parsed statement, or an empty
        string when ``sql`` yields no statement.
    """
    # Build the lookup once; callers may pass a dict view or a list.
    names = frozenset(column_names) if column_names is not None else frozenset()

    statements = sqlparse.parse(sql)
    for statement in statements:
        deep_quotes(statement, names)

    # Join all statements instead of returning only the first one, so that
    # multi-statement input is not silently truncated and empty input does
    # not raise IndexError.
    return "".join(str(statement) for statement in statements)
+
def deep_quotes(token, column_names=None):
    """Recursively wrap name tokens below ``token`` in double quotes.

    Tokens are modified in place through their ``value`` attribute; nothing
    is returned.

    Args:
        token: A parsed token. Anything exposing a ``tokens`` attribute is
            treated as a group and its children are visited; anything else
            is treated as a leaf.
        column_names: Optional collection of names to quote. When it is
            ``None`` or empty, every leaf whose type is
            ``sqlparse.tokens.Name`` is quoted; otherwise only leaves whose
            value is in the collection.
    """
    if hasattr(token, "tokens"):
        for child in token.tokens:
            deep_quotes(child, column_names)
        return

    if token.ttype != sqlparse.tokens.Name:
        return

    # No filter given means "quote every name".
    if not column_names or token.value in column_names:
        token.value = f'"{token.value}"'
+
def is_chinese(string):
    """Tell whether ``string`` contains at least one Chinese character.

    A character counts as Chinese when it falls inside the range matched by
    the regular expression below.

    Args:
        string: Text to inspect.

    Returns:
        ``True`` if any character matches, ``False`` otherwise.
    """
    # Use a regular expression to look for a Chinese character.
    return re.search(r'[一-龥]', string) is not None
+
class ExcelReader:
def __init__(self, file_path):
@@ -28,9 +56,14 @@ class ExcelReader:
else:
raise ValueError("Unsupported file format.")
+ self.df.replace('', np.nan, inplace=True)
self.columns_map = {}
for column_name in df_tmp.columns:
self.columns_map.update({column_name: excel_colunm_format(column_name)})
+ try:
+ self.df[column_name] = self.df[column_name].astype(float)
+ except Exception as e:
+ print("transfor column error!" + column_name)
self.df = self.df.rename(columns=lambda x: x.strip().replace(' ', '_'))
@@ -43,6 +76,8 @@ class ExcelReader:
self.db.register(self.table_name, self.df)
def run(self, sql):
+ sql = sql.replace(self.table_name, f'"{self.table_name}"')
+ sql = add_quotes(sql, self.columns_map.values())
results = self.db.execute(sql)
colunms = []
for descrip in results.description: