# Patch summary (derived from the hunks below):
# - command_mange.py: the two `json.dumps(param)` calls shown in ApiCall now pass
#   `default=serialize`, and `value.api_result` is built by round-tripping
#   `data_df.to_json(orient='records', date_format='iso', date_unit='s')` through
#   `json.loads` instead of a row-wise `to_dict()`.
# - json_utils.py (new file): `serialize(obj)` returns `obj.isoformat()` when `obj`
#   is a `datetime.date` instance.
# - excel_reader.py: replaces a blank import line with `import pandas as pd`, and swaps
#   the SQL demo under `if __name__ == "__main__":` for a DataFrame `to_json` demo.
# - chat.py: `stream_plugin_call` now calls `api_call.display_sql_llmvis`, and the
#   previously commented-out `do_action` is enabled.
# - prompt.py: "Please make sure to respond as following format:" moves to its own
#   paragraph after constraint 4.
#
# NOTE(review): `serialize` has no fallback branch, so for any non-date object it
#   returns None and `json.dumps` writes `null` rather than raising TypeError.
# NOTE(review): the excel_reader.py import hunk adds `import pandas as pd` directly
#   above an existing `import pandas as pd` line (duplicate import).
# NOTE(review): `import json` is unused in the six lines of the new json_utils.py;
#   the `import pandas as pd` added to command_mange.py is not referenced by any
#   other line this patch adds -- confirm it is needed.
# NOTE(review): this copy was reconstructed from a newline-collapsed paste. Blank
#   context lines are placed to satisfy the hunk line counts and Python indentation
#   depth is inferred -- verify against the repository before applying.
# NOTE(review): "response_tableSQL Query to run" in the prompt.py hunk looks like
#   markup was stripped from it during extraction -- TODO confirm against the file.
diff --git a/pilot/base_modules/agent/commands/command_mange.py b/pilot/base_modules/agent/commands/command_mange.py
index a2bf03b09..0932ad48e 100644
--- a/pilot/base_modules/agent/commands/command_mange.py
+++ b/pilot/base_modules/agent/commands/command_mange.py
@@ -5,7 +5,9 @@ import time
 import json
 import logging
 import xml.etree.ElementTree as ET
+import pandas as pd
 
+from pilot.common.json_utils import serialize
 from datetime import datetime
 from typing import Any, Callable, Optional, List
 from pydantic import BaseModel
@@ -357,7 +359,7 @@ class ApiCall:
         if api_status.api_result:
             param["result"] = api_status.api_result
 
-        return json.dumps(param)
+        return json.dumps(param, default=serialize)
 
     def to_view_text(self, api_status: PluginStatus):
         api_call_element = ET.Element("dbgpt-view")
@@ -391,7 +393,7 @@ class ApiCall:
         if api_status.api_result:
             param["data"] = api_status.api_result
 
-        return json.dumps(param)
+        return json.dumps(param, default=serialize)
 
     def run(self, llm_text):
         if self.__is_need_wait_plugin_call(llm_text):
@@ -469,7 +471,7 @@ class ApiCall:
                             if sql is not None and len(sql) > 0:
                                 data_df = sql_run_func(sql)
                                 value.df = data_df
-                                value.api_result = data_df.apply(lambda row: row.to_dict(), axis=1).to_list()
+                                value.api_result = json.loads(data_df.to_json(orient='records', date_format='iso', date_unit='s'))
                                 value.status = Status.COMPLETED.value
                             else:
                                 value.status = Status.FAILED.value
@@ -479,4 +481,6 @@ class ApiCall:
                             value.status = Status.FAILED.value
                             value.err_msg = str(e)
                         value.end_time = datetime.now().timestamp() * 1000
-        return self.api_view_context(llm_text, True)
\ No newline at end of file
+        return self.api_view_context(llm_text, True)
+
+
diff --git a/pilot/common/json_utils.py b/pilot/common/json_utils.py
new file mode 100644
index 000000000..413c2d105
--- /dev/null
+++ b/pilot/common/json_utils.py
@@ -0,0 +1,6 @@
+import json
+from datetime import date
+
+def serialize(obj):
+    if isinstance(obj, date):
+        return obj.isoformat()
\ No newline at end of file
diff --git a/pilot/scene/chat_data/chat_excel/excel_reader.py b/pilot/scene/chat_data/chat_excel/excel_reader.py
index 779a37db3..175b3bf4e 100644
--- a/pilot/scene/chat_data/chat_excel/excel_reader.py
+++ b/pilot/scene/chat_data/chat_excel/excel_reader.py
@@ -4,7 +4,7 @@ import duckdb
 import os
 import re
 import sqlparse
-
+import pandas as pd
 import chardet
 import pandas as pd
 import numpy as np
@@ -240,13 +240,24 @@ if __name__ == "__main__":
     # print(add_quotes_to_chinese_columns(sql_2))
 
     # sql = """ SELECT 省份, 2021年, 2022年 as GDP FROM excel_data """
-    sql = """ SELECT 省份, 2022年, 2021年 FROM excel_data """
-    sql_2 = """ SELECT "省份", "2022年" as "2022年实际GDP增速", "2021年" as "2021年实际GDP增速" FROM excel_data """
-    sql_3 = """ SELECT "省份", ("2022年" / ("2022年" + "2021年")) * 100 as "2022年实际GDP增速占比", ("2021年" / ("2022年" + "2021年")) * 100 as "2021年实际GDP增速占比" FROM excel_data """
+    # sql = """ SELECT 省份, 2022年, 2021年 FROM excel_data """
+    # sql_2 = """ SELECT "省份", "2022年" as "2022年实际GDP增速", "2021年" as "2021年实际GDP增速" FROM excel_data """
+    # sql_3 = """ SELECT "省份", ("2022年" / ("2022年" + "2021年")) * 100 as "2022年实际GDP增速占比", ("2021年" / ("2022年" + "2021年")) * 100 as "2021年实际GDP增速占比" FROM excel_data """
+    #
+    # sql = add_quotes_to_chinese_columns(sql_3)
+    # print(f"excute sql:{sql}")
 
-    sql = add_quotes_to_chinese_columns(sql_3)
-    print(f"excute sql:{sql}")
 
+    data = {
+        'name': ['John', 'Alice', 'Bob'],
+        'age': [30, 25, 35],
+        'timestamp': [pd.Timestamp('2022-01-01'), pd.Timestamp('2022-01-02'), pd.Timestamp('2022-01-03')],
+        'category': pd.Categorical(['A', 'B', 'C'])
+    }
+    df = pd.DataFrame(data)
+
+    json_data = df.to_json(orient='records', date_format='iso', date_unit='s')
+    print(json_data)
 
 class ExcelReader:
     def __init__(self, file_path):
diff --git a/pilot/scene/chat_db/auto_execute/chat.py b/pilot/scene/chat_db/auto_execute/chat.py
index 6cf7dd6a9..6fd3c5d7e 100644
--- a/pilot/scene/chat_db/auto_execute/chat.py
+++ b/pilot/scene/chat_db/auto_execute/chat.py
@@ -74,9 +74,9 @@ class ChatWithDbAutoExecute(BaseChat):
     def stream_plugin_call(self, text):
         text = text.replace("\n", " ")
         print(f"stream_plugin_call:{text}")
-        return self.api_call.run_display_sql(text, self.database.run_to_df)
+        return self.api_call.display_sql_llmvis(text, self.database.run_to_df)
 
-    #
-    # def do_action(self, prompt_response):
-    #     print(f"do_action:{prompt_response}")
-    #     return self.database.run(prompt_response.sql)
+
+    def do_action(self, prompt_response):
+        print(f"do_action:{prompt_response}")
+        return self.database.run(prompt_response.sql)
diff --git a/pilot/scene/chat_db/auto_execute/prompt.py b/pilot/scene/chat_db/auto_execute/prompt.py
index b96899160..585146fbb 100644
--- a/pilot/scene/chat_db/auto_execute/prompt.py
+++ b/pilot/scene/chat_db/auto_execute/prompt.py
@@ -20,7 +20,9 @@ Constraint:
 1. You can only use the table provided in the table structure information to generate sql. If you cannot generate sql based on the provided table structure, please say: "The table structure information provided is not enough to generate sql query." It is prohibited to fabricate information at will.
 2. Do not query columns that do not exist. Pay attention to which column is in which table.
 3. Unless the user specifies in the question a specific number of examples he wishes to obtain, always limit the query to a maximum of {top_k} results.
-4. Please ensure that the output result contains: response_tableSQL Query to run,and replace the generated sql into the parameter sql.Please make sure to respond as following format:
+4. Please ensure that the output result contains: response_tableSQL Query to run,and replace the generated sql into the parameter sql.
+
+Please make sure to respond as following format:
     thoughts summary to say to user.response_tableSQL Query to run
 
 Question: {input}