diff --git a/pilot/base_modules/agent/commands/command_mange.py b/pilot/base_modules/agent/commands/command_mange.py
index 10bb5c340..19bcd4462 100644
--- a/pilot/base_modules/agent/commands/command_mange.py
+++ b/pilot/base_modules/agent/commands/command_mange.py
@@ -1,16 +1,18 @@
import functools
import importlib
import inspect
+import time
import json
import logging
import xml.etree.ElementTree as ET
from datetime import datetime
-from typing import Any, Callable, Optional
+from typing import Any, Callable, Optional, List
+from pydantic import BaseModel
from pilot.base_modules.agent.common.schema import Status, ApiTagType
from pilot.base_modules.agent.commands.command import execute_command
from pilot.base_modules.agent.commands.generator import PluginPromptGenerator
-from pilot.common.string_utils import extract_content_include, extract_content_open_ending, extract_content, extract_content_include_open_ending
+from pilot.common.string_utils import extract_content_open_ending, extract_content
# Unique identifier for auto-gpt commands
AUTO_GPT_COMMAND_IDENTIFIER = "auto_gpt_command"
@@ -165,107 +167,179 @@ def command(
return decorator
+class PluginStatus(BaseModel):
+ name: str
+ location: List[int]
+ args: dict
+ status: Status = Status.TODO.value
+ logo_url: Optional[str] = None
+ api_result: Optional[str] = None
+ err_msg: Optional[str] = None
+ start_time: float = datetime.now().timestamp() * 1000
+ end_time: Optional[float] = None
+
+
class ApiCall:
agent_prefix = ""
agent_end = ""
name_prefix = ""
name_end = ""
- def __init__(self, plugin_generator):
- self.name: str = ""
- self.status: Status = Status.TODO.value
- self.logo_url: str = None
- self.args = {}
- self.api_result: str = None
- self.err_msg: str = None
+ def __init__(self, plugin_generator: Any = None, display_registry: Any = None):
+ # self.name: str = ""
+ # self.status: Status = Status.TODO.value
+ # self.logo_url: str = None
+ # self.args = {}
+ # self.api_result: str = None
+ # self.err_msg: str = None
+
+ self.plugin_status_map = {}
+
self.plugin_generator = plugin_generator
+ self.display_registry = display_registry
+ self.start_time = datetime.now().timestamp() * 1000
def __repr__(self):
return f"ApiCall(name={self.name}, status={self.status}, args={self.args})"
-
def __is_need_wait_plugin_call(self, api_call_context):
+ start_agent_count = api_call_context.count(self.agent_prefix)
+ end_agent_count = api_call_context.count(self.agent_end)
- if api_call_context.find(self.agent_prefix) >= 0:
+ if start_agent_count > 0:
return True
- check_len = len(self.agent_prefix)
- last_text = api_call_context[-check_len:]
- for i in range(check_len):
- text_tmp = last_text[-i:]
- prefix_tmp = self.agent_prefix[:i]
- if text_tmp == prefix_tmp:
- return True
- else:
- i += 1
+ else:
+ # 末尾新出字符检测
+ check_len = len(self.agent_prefix)
+ last_text = api_call_context[-check_len:]
+ for i in range(check_len):
+ text_tmp = last_text[-i:]
+ prefix_tmp = self.agent_prefix[:i]
+ if text_tmp == prefix_tmp:
+ return True
+ else:
+ i += 1
return False
- def __get_api_call_context(self, all_context):
- return extract_content_include(all_context, self.agent_prefix, self.agent_end)
+ def __check_last_plugin_call_ready(self, all_context):
+ start_agent_count = all_context.count(self.agent_prefix)
+ end_agent_count = all_context.count(self.agent_end)
- def __check_plugin_call_ready(self, all_context):
- if all_context.find(self.agent_end) > 0:
+ if start_agent_count > 0 and start_agent_count == end_agent_count:
return True
+ return False
- def api_view_context(self, all_context:str):
- if all_context.find(self.agent_prefix) >= 0:
- call_context = extract_content_open_ending(all_context, self.agent_prefix, self.agent_end)
- call_context_all = extract_content_include_open_ending(all_context, self.agent_prefix, self.agent_end)
- if len(call_context) > 0:
- name_context = extract_content(call_context, self.name_prefix, self.name_end)
- if len(name_context) > 0:
- self.name = name_context
- return all_context.replace(call_context_all, self.to_view_text())
- else:
- return all_context
+ def api_view_context(self, all_context: str, display_mode: bool = False):
+ call_context_map = extract_content_open_ending(all_context, self.agent_prefix, self.agent_end, True)
+ for api_index, api_context in call_context_map.items():
+ api_status = self.plugin_status_map.get(api_context)
+ if api_status is not None:
+ if display_mode:
+ if api_status.api_result:
+ all_context = all_context.replace(api_context, api_status.api_result)
+ else:
+ if api_status.status == Status.FAILED.value:
+ all_context = all_context.replace(api_context, f"""\nERROR!{api_status.err_msg}\n """)
+ else:
+ cost = ((api_status.end_time or datetime.now().timestamp() * 1000) - self.start_time) / 1000
+ cost_str = "{:.2f}".format(cost)
+ all_context = all_context.replace(api_context, f'\nWaiting...{cost_str}S\n')
+ else:
+ all_context = all_context.replace(api_context, self.to_view_text(api_status))
+ else:
+ # not ready api call view change
+ now_time = datetime.now().timestamp() * 1000
+ cost = (now_time - self.start_time) / 1000
+ cost_str = "{:.2f}".format(cost)
+ all_context = all_context.replace(api_context, f'\nWaiting... * {cost_str}S * \n')
+
+ return all_context
def update_from_context(self, all_context):
logging.info(f"from_context:{all_context}")
- api_context = extract_content_include(all_context, self.agent_prefix, self.agent_end)
- api_context = api_context.replace("\\n", "").replace("\n", "")
- api_call_element = ET.fromstring(api_context)
- self.name = api_call_element.find('name').text
+ api_context_map = extract_content(all_context, self.agent_prefix, self.agent_end, True)
+ for api_index, api_context in api_context_map.items():
+ api_context = api_context.replace("\\n", "").replace("\n", "")
+ api_call_element = ET.fromstring(api_context)
+ api_name = api_call_element.find('name').text
+ api_args = {}
+ args_elements = api_call_element.find('args')
+ for child_element in args_elements.iter():
+ api_args[child_element.tag] = child_element.text
- args_elements = api_call_element.find('args')
- for child_element in args_elements.iter():
- self.args[child_element.tag] = child_element.text
+ api_status = self.plugin_status_map.get(api_context)
+ if api_status is None:
+ api_status = PluginStatus(name=api_name, location=[api_index], args=api_args)
+ self.plugin_status_map[api_context] = api_status
+ else:
+ api_status.location.append(api_index)
- def __to_view_param_str(self):
+ def __to_view_param_str(self, api_status):
param = {}
- if self.name:
- param['name'] = self.name
- param['status'] = self.status
- if self.logo_url:
- param['logo'] = self.logo_url
+ if api_status.name:
+ param['name'] = api_status.name
+ param['status'] = api_status.status
+ if api_status.logo_url:
+ param['logo'] = api_status.logo_url
- if self.err_msg:
- param['err_msg'] = self.err_msg
-
- if self.api_result:
- param['result'] = self.api_result
+ if api_status.err_msg:
+ param['err_msg'] = api_status.err_msg
+ if api_status.api_result:
+ param['result'] = api_status.api_result
return json.dumps(param)
- def to_view_text(self):
+ def to_view_text(self, api_status: PluginStatus):
api_call_element = ET.Element('dbgpt-view')
- api_call_element.text = self.__to_view_param_str()
+ api_call_element.text = self.__to_view_param_str(api_status)
result = ET.tostring(api_call_element, encoding="utf-8")
return result.decode("utf-8")
-
def run(self, llm_text):
print(f"stream_plugin_call:{llm_text}")
if self.__is_need_wait_plugin_call(llm_text):
# wait api call generate complete
- if self.__check_plugin_call_ready(llm_text):
+ if self.__check_last_plugin_call_ready(llm_text):
self.update_from_context(llm_text)
- if self.status == Status.TODO.value:
- self.status = Status.RUNNING.value
- logging.info(f"插件执行:{self.name},{self.args}")
- try:
- self.api_result = execute_command(self.name, self.args, self.plugin_generator)
- self.status = Status.COMPLETED.value
- except Exception as e:
- self.status = Status.FAILED.value
- self.err_msg = str(e)
- return self.api_view_context(llm_text)
\ No newline at end of file
+ for key, value in self.plugin_status_map.items():
+ if value.status == Status.TODO.value:
+ value.status = Status.RUNNING.value
+ logging.info(f"插件执行:{value.name},{value.args}")
+ try:
+ value.api_result = execute_command(value.name, value.args, self.plugin_generator)
+ value.status = Status.COMPLETED.value
+ except Exception as e:
+ value.status = Status.FAILED.value
+ value.err_msg = str(e)
+ value.end_time = datetime.now().timestamp() * 1000
+ return self.api_view_context(llm_text)
+
+ def run_display_sql(self, llm_text, sql_run_func):
+ print(f"get_display_sql:{llm_text}")
+
+ if self.__is_need_wait_plugin_call(llm_text):
+ # wait api call generate complete
+ if self.__check_last_plugin_call_ready(llm_text):
+ self.update_from_context(llm_text)
+ for key, value in self.plugin_status_map.items():
+ if value.status == Status.TODO.value:
+ value.status = Status.RUNNING.value
+ logging.info(f"sql展示执行:{value.name},{value.args}")
+ try:
+ sql = value.args['sql']
+ if sql:
+ param = {
+ "df": sql_run_func(sql),
+ }
+ if self.display_registry.get_command(value.name):
+ value.api_result = self.display_registry.call(value.name, **param)
+ else:
+ value.api_result = self.display_registry.call("response_table", **param)
+
+ value.status = Status.COMPLETED.value
+ except Exception as e:
+ value.status = Status.FAILED.value
+ value.err_msg = str(e)
+ value.end_time = datetime.now().timestamp() * 1000
+ return self.api_view_context(llm_text, True)
diff --git a/pilot/base_modules/agent/commands/disply_type/show_chart_gen.py b/pilot/base_modules/agent/commands/disply_type/show_chart_gen.py
index 941586328..1db3f844d 100644
--- a/pilot/base_modules/agent/commands/disply_type/show_chart_gen.py
+++ b/pilot/base_modules/agent/commands/disply_type/show_chart_gen.py
@@ -92,10 +92,10 @@ def zh_font_set():
@command(
"response_line_chart",
"Line chart display, used to display comparative trend analysis data",
- '"speak": "", "df":""',
+ '"df":""',
)
-def response_line_chart(speak: str, df: DataFrame) -> str:
- logger.info(f"response_line_chart:{speak},")
+def response_line_chart( df: DataFrame) -> str:
+ logger.info(f"response_line_chart")
if df.size <= 0:
raise ValueError("No Data!")
@@ -148,17 +148,17 @@ def response_line_chart(speak: str, df: DataFrame) -> str:
chart_path = static_message_img_path + "/" + chart_name
plt.savefig(chart_path, bbox_inches="tight", dpi=100)
- html_img = f"""{speak}
"""
+ html_img = f"""
"""
return html_img
@command(
"response_bar_chart",
"Histogram, suitable for comparative analysis of multiple target values",
- '"speak": "", "df":""',
+ '"df":""',
)
-def response_bar_chart(speak: str, df: DataFrame) -> str:
- logger.info(f"response_bar_chart:{speak},")
+def response_bar_chart( df: DataFrame) -> str:
+ logger.info(f"response_bar_chart")
if df.size <= 0:
raise ValueError("No Data!")
@@ -236,17 +236,17 @@ def response_bar_chart(speak: str, df: DataFrame) -> str:
chart_name = "bar_" + str(uuid.uuid1()) + ".png"
chart_path = static_message_img_path + "/" + chart_name
plt.savefig(chart_path, bbox_inches="tight", dpi=100)
- html_img = f"""{speak}
"""
+ html_img = f"""
"""
return html_img
@command(
"response_pie_chart",
"Pie chart, suitable for scenarios such as proportion and distribution statistics",
- '"speak": "", "df":""',
+ '"df":""',
)
-def response_pie_chart(speak: str, df: DataFrame) -> str:
- logger.info(f"response_pie_chart:{speak},")
+def response_pie_chart(df: DataFrame) -> str:
+ logger.info(f"response_pie_chart")
columns = df.columns.tolist()
if df.size <= 0:
raise ValueError("No Data!")
@@ -291,6 +291,6 @@ def response_pie_chart(speak: str, df: DataFrame) -> str:
chart_path = static_message_img_path + "/" + chart_name
plt.savefig(chart_path, bbox_inches="tight", dpi=100)
- html_img = f"""{speak.replace("`", '"')}
"""
+ html_img = f"""
"""
return html_img
diff --git a/pilot/base_modules/agent/commands/disply_type/show_table_gen.py b/pilot/base_modules/agent/commands/disply_type/show_table_gen.py
index 45f9f2f21..e2b57a031 100644
--- a/pilot/base_modules/agent/commands/disply_type/show_table_gen.py
+++ b/pilot/base_modules/agent/commands/disply_type/show_table_gen.py
@@ -13,12 +13,12 @@ logger = build_logger("show_table_gen", LOGDIR + "show_table_gen.log")
@command(
"response_table",
"Table display, suitable for display with many display columns or non-numeric columns",
- '"speak": "", "df":""',
+ '"df":""',
)
-def response_table(speak: str, df: DataFrame) -> str:
- logger.info(f"response_table:{speak}")
+def response_table(df: DataFrame) -> str:
+ logger.info(f"response_table")
html_table = df.to_html(index=False, escape=False, sparsify=False)
table_str = "".join(html_table.split())
html = f"""{table_str}
"""
- view_text = f"##### {str(speak)}" + "\n" + html.replace("\n", " ")
+ view_text = html.replace("\n", " ")
return view_text
diff --git a/pilot/base_modules/agent/commands/disply_type/show_text_gen.py b/pilot/base_modules/agent/commands/disply_type/show_text_gen.py
index f90ef087e..1dbc5d52e 100644
--- a/pilot/base_modules/agent/commands/disply_type/show_text_gen.py
+++ b/pilot/base_modules/agent/commands/disply_type/show_text_gen.py
@@ -12,10 +12,10 @@ logger = build_logger("show_table_gen", LOGDIR + "show_table_gen.log")
@command(
"response_data_text",
"Text display, the default display method, suitable for single-line or simple content display",
- '"speak": "", "df":""',
+ '"df":""',
)
-def response_data_text(speak: str, df: DataFrame) -> str:
- logger.info(f"response_data_text:{speak}")
+def response_data_text(df: DataFrame) -> str:
+ logger.info(f"response_data_text")
data = df.values
row_size = data.shape[0]
@@ -25,7 +25,7 @@ def response_data_text(speak: str, df: DataFrame) -> str:
html_table = df.to_html(index=False, escape=False, sparsify=False)
table_str = "".join(html_table.split())
html = f"""{table_str}
"""
- text_info = f"##### {str(speak)}" + "\n" + html.replace("\n", " ")
+ text_info = html.replace("\n", " ")
elif row_size == 1:
row = data[0]
for value in row:
@@ -33,7 +33,7 @@ def response_data_text(speak: str, df: DataFrame) -> str:
value_str = value_str + f", ** {value} **"
else:
value_str = f" ** {value} **"
- text_info = f"{speak}: {value_str}"
+ text_info = f" {value_str}"
else:
- text_info = f"##### {speak}: _没有找到可用的数据_"
+ text_info = f"##### _没有找到可用的数据_"
return text_info
diff --git a/pilot/common/string_utils.py b/pilot/common/string_utils.py
index b3e6404fa..c81fd5b6b 100644
--- a/pilot/common/string_utils.py
+++ b/pilot/common/string_utils.py
@@ -1,49 +1,47 @@
-def extract_content(long_string, s1, s2):
+def extract_content(long_string, s1, s2, is_include: bool = False):
+ # extract text
+ match_map ={}
start_index = long_string.find(s1)
- if start_index < 0:
- return ""
- end_index = long_string.find(s2, start_index + len(s1))
- extracted_content = long_string[start_index + len(s1):end_index]
- return extracted_content
+ while start_index != -1:
+ if is_include:
+ end_index = long_string.find(s2, start_index + len(s1) + 1)
+ extracted_content = long_string[start_index:end_index + len(s2)]
+ else:
+ end_index = long_string.find(s2, start_index + len(s1))
+ extracted_content = long_string[start_index + len(s1):end_index]
+ if end_index != -1 and extracted_content:
+ match_map[start_index] = extracted_content
+ start_index = long_string.find(s1, start_index + 1)
+ return match_map
-def extract_content_open_ending(long_string, s1, s2):
+def extract_content_open_ending(long_string, s1, s2, is_include: bool = False):
+ # extract text open ending
+ match_map = {}
start_index = long_string.find(s1)
- if start_index < 0:
- return ""
- if long_string.find(s2) <=0:
- end_index = len(long_string)
- else:
- end_index = long_string.find(s2, start_index + len(s1))
- extracted_content = long_string[start_index + len(s1):end_index]
- return extracted_content
-
-def extract_content_include(long_string, s1, s2):
- start_index = long_string.find(s1)
- if start_index < 0:
- return ""
- end_index = long_string.find(s2, start_index + len(s1) + 1)
- extracted_content = long_string[start_index:end_index + len(s2)]
- return extracted_content
-
-def extract_content_include_open_ending(long_string, s1, s2):
-
- start_index = long_string.find(s1)
- if start_index < 0:
- return ""
- if long_string.find(s2) <=0:
- end_index = len(long_string)
- else:
- end_index = long_string.find(s2, start_index + len(s1) + 1)
- extracted_content = long_string[start_index:end_index + len(s2)]
- return extracted_content
+ while start_index != -1:
+ if long_string.find(s2, start_index) == -1:
+ end_index = len(long_string)
+ else:
+ if is_include:
+ end_index = long_string.find(s2, start_index + len(s1) + 1)
+ else:
+ end_index = long_string.find(s2, start_index + len(s1))
+ if is_include:
+ extracted_content = long_string[start_index:end_index + len(s2)]
+ else:
+ extracted_content = long_string[start_index + len(s1):end_index]
+ if extracted_content:
+ match_map[start_index] = extracted_content
+ start_index= long_string.find(s1, start_index + 1)
+ return match_map
if __name__=="__main__":
- s = "abcd123efghijkjhhh456"
+ s = "abcd123efghijkjhhh456xxx123aa456yyy123bb456xx123"
s1 = "123"
s2 = "456"
- print(extract_content_open_ending(s, s1, s2))
\ No newline at end of file
+ print(extract_content_open_ending(s, s1, s2, True))
\ No newline at end of file
diff --git a/pilot/scene/chat_agent/chat.py b/pilot/scene/chat_agent/chat.py
index 8175d3d7c..5ba1c6f4f 100644
--- a/pilot/scene/chat_agent/chat.py
+++ b/pilot/scene/chat_agent/chat.py
@@ -34,7 +34,7 @@ class ChatAgent(BaseChat):
agent_module = CFG.SYSTEM_APP.get_componet(ComponetType.AGENT_HUB, ModuleAgent)
self.plugins_prompt_generator = agent_module.load_select_plugin(self.plugins_prompt_generator, self.select_plugins)
- self.api_call = ApiCall(self.plugins_prompt_generator)
+ self.api_call = ApiCall(plugin_generator=self.plugins_prompt_generator)
def generate_input_values(self):
input_values = {
diff --git a/pilot/scene/chat_agent/out_parser.py b/pilot/scene/chat_agent/out_parser.py
index 2078018d0..ecea328bb 100644
--- a/pilot/scene/chat_agent/out_parser.py
+++ b/pilot/scene/chat_agent/out_parser.py
@@ -15,27 +15,6 @@ class PluginAction(NamedTuple):
class PluginChatOutputParser(BaseOutputParser):
- def parse_prompt_response(self, model_out_text) -> T:
- clean_json_str = super().parse_prompt_response(model_out_text)
- print(clean_json_str)
- if not clean_json_str:
- raise ValueError("model server response not have json!")
- try:
- response = json.loads(clean_json_str)
- except Exception as e:
- raise ValueError("model server out not fllow the prompt!")
-
- speak = ""
- thoughts = ""
- for key in sorted(response):
- if key.strip() == "command":
- command = response[key]
- if key.strip() == "thoughts":
- thoughts = response[key]
- if key.strip() == "speak":
- speak = response[key]
- return PluginAction(command, speak, thoughts)
-
def parse_view_response(self, speak, data) -> str:
### tool out data to table view
print(f"parse_view_response:{speak},{str(data)}")
diff --git a/pilot/scene/chat_data/chat_excel/excel_analyze/chat.py b/pilot/scene/chat_data/chat_excel/excel_analyze/chat.py
index 4f90452c5..259467c42 100644
--- a/pilot/scene/chat_data/chat_excel/excel_analyze/chat.py
+++ b/pilot/scene/chat_data/chat_excel/excel_analyze/chat.py
@@ -7,13 +7,13 @@ from pilot.scene.base_chat import BaseChat, logger
from pilot.scene.base import ChatScene
from pilot.common.sql_database import Database
from pilot.configs.config import Config
-
+from pilot.base_modules.agent.commands.command_mange import ApiCall
from pilot.scene.chat_data.chat_excel.excel_analyze.prompt import prompt
from pilot.scene.chat_data.chat_excel.excel_reader import ExcelReader
from pilot.scene.chat_data.chat_excel.excel_learning.chat import ExcelLearning
from pilot.common.path_utils import has_path
from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH
-
+from pilot.base_modules.agent.common.schema import Status
CFG = Config()
@@ -36,9 +36,18 @@ class ChatExcel(BaseChat):
KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode.value(), self.select_param
)
)
-
+ self.api_call = ApiCall(display_registry = CFG.command_disply)
super().__init__(chat_param=chat_param)
+ def _generate_numbered_list(self) -> str:
+ command_strings = []
+ if CFG.command_disply:
+ command_strings += [
+ str(item)
+ for item in CFG.command_disply.commands.values()
+ if item.enabled
+ ]
+ return "\n".join(f"{i+1}. {item}" for i, item in enumerate(command_strings))
def generate_input_values(self):
input_values = {
@@ -64,15 +73,25 @@ class ChatExcel(BaseChat):
result = await learn_chat.nostream_call()
return result
- def do_action(self, prompt_response):
- print(f"do_action:{prompt_response}")
+ def stream_plugin_call(self, text):
+ text = text.replace("\n", " ")
+ print(f"stream_plugin_call:{text}")
+ return self.api_call.run_display_sql(text, self.excel_reader.get_df_by_sql_ex)
- # colunms, datas = self.excel_reader.run(prompt_response.sql)
- param = {
- "speak": prompt_response.thoughts,
- "df": self.excel_reader.get_df_by_sql_ex(prompt_response.sql),
- }
- if CFG.command_disply.get_command(prompt_response.display):
- return CFG.command_disply.call(prompt_response.display, **param)
- else:
- return CFG.command_disply.call("response_table", **param)
+
+
+ # def do_action(self, prompt_response):
+ # print(f"do_action:{prompt_response}")
+ #
+ # # colunms, datas = self.excel_reader.run(prompt_response.sql)
+ #
+ #
+ # param = {
+ # "speak": prompt_response.thoughts,
+ # "df": self.excel_reader.get_df_by_sql_ex(prompt_response.sql),
+ # }
+ #
+ # if CFG.command_disply.get_command(prompt_response.display):
+ # return CFG.command_disply.call(prompt_response.display, **param)
+ # else:
+ # return CFG.command_disply.call("response_table", **param)
diff --git a/pilot/scene/chat_data/chat_excel/excel_analyze/prompt.py b/pilot/scene/chat_data/chat_excel/excel_analyze/prompt.py
index 9d7b35228..ce2089942 100644
--- a/pilot/scene/chat_data/chat_excel/excel_analyze/prompt.py
+++ b/pilot/scene/chat_data/chat_excel/excel_analyze/prompt.py
@@ -9,50 +9,52 @@ from pilot.common.schema import SeparatorStyle
CFG = Config()
-PROMPT_SCENE_DEFINE = "You are a data analysis expert. "
+_PROMPT_SCENE_DEFINE_EN = "You are a data analysis expert. "
_DEFAULT_TEMPLATE_EN = """
-Please use the data structure information of the above historical dialogue, make sure not to use column names that are not in the data structure.
-According to the user goal: {user_input},give the correct duckdb SQL for data analysis.
-Use the table name: {table_name}
+Please use the data structure and column information in the above historical dialogue and combine it with data analysis to answer the user's questions while satisfying the constraints.
-According to the analysis SQL obtained by the user's goal, select the best one from the following display forms, if it cannot be determined, use Text as the display,Just need to return the type name into the result.
-Display type:
- {disply_type}
-
-Respond in the following json format:
- {response}
-Ensure the response is correct json and can be parsed by Python json.loads
+Constraint:
+ 1.Please output your thinking process and analysis ideas first, and then output the specific data analysis results. The data analysis results are output in the following format:display typeCorrect duckdb data analysis sql
+ 2.For the available display methods of data analysis results, please choose the most appropriate one from the following display methods. If you are not sure, use 'response_data_text' as the display. The available display types are as follows:{disply_type}
+ 3.The table name that needs to be used in SQL is: {table_name}, please make sure not to use column names that are not in the data structure.
+ 4.Give priority to answering using data analysis. If the user's question does not involve data analysis, you can answer according to your understanding.
+User Questions:
+ {user_input}
"""
+_PROMPT_SCENE_DEFINE_ZH = """你是一个数据分析专家!"""
_DEFAULT_TEMPLATE_ZH = """
-请使用上述历史对话中的数据结构和列信息,根据用户目标:{user_input},给出正确的duckdb SQL进行数据分析和问题回答。
-请确保不要使用不在数据结构中的列名。
-SQL中需要使用的表名是: {table_name}
+请使用上述历史对话中的数据结构信息,在满足约束条件下结合数据分析回答用户的问题。
-根据用户目标得到的分析SQL,请从以下显示类型中选择最合适的一种用来展示结果数据,如果无法确定,则使用'Text'作为显示, 只需要将类型名称返回到结果中。
-显示类型如下:
- {disply_type}
-
-以以下 json 格式响应::
- {response}
-确保响应是正确的json,并且可以被Python的json.loads方法解析.
+约束条件:
+ 1.请先输出你的分析思路内容,再输出具体的数据分析结果,其中数据分析结果部分确保用如下格式输出:[数据展示方式][正确的duckdb数据分析sql]
+ 2.请确保数据分析结果格式的内容在整个回答中只出现一次,确保上述结构稳定,把[]部分内容替换为对应的值
+ 3.数据分析结果可用的展示方式请在下面的展示方式中选择最合适的一种,放入数据分析结果的name字段内如果无法确定,则使用'Text'作为显示,可用显示类型如下: {disply_type}
+ 4.SQL中需要使用的表名是: {table_name},请确保不要使用不在数据结构中的列名。
+ 5.优先使用数据分析的方式回答,如果用户问题不涉及数据分析内容,你可以按你的理解进行回答
+ 6.请确保你的输出内容有良好的markdown格式的排版
+
+用户问题:{user_input}
+请确保你的输出格式如下:
+ 需要告诉用户的分析思路.
+ [数据展示方式][正确的duckdb数据分析sql]
"""
-RESPONSE_FORMAT_SIMPLE = {
- "sql": "analysis SQL",
- "thoughts": "Current thinking and value of data analysis",
- "display": "display type name",
-}
+
_DEFAULT_TEMPLATE = (
_DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH
)
+_PROMPT_SCENE_DEFINE = (
+ _PROMPT_SCENE_DEFINE_EN if CFG.LANGUAGE == "en" else _PROMPT_SCENE_DEFINE_ZH
+)
+
PROMPT_SEP = SeparatorStyle.SINGLE.value
-PROMPT_NEED_NEED_STREAM_OUT = False
+PROMPT_NEED_NEED_STREAM_OUT = True
# Temperature is a configuration hyperparameter that controls the randomness of language model output.
# A high temperature produces more unpredictable and creative results, while a low temperature produces more common and conservative output.
@@ -62,8 +64,7 @@ PROMPT_TEMPERATURE = 0.8
prompt = PromptTemplate(
template_scene=ChatScene.ChatExcel.value(),
input_variables=["user_input", "table_name", "disply_type"],
- response_format=json.dumps(RESPONSE_FORMAT_SIMPLE, ensure_ascii=False, indent=4),
- template_define=PROMPT_SCENE_DEFINE,
+ template_define=_PROMPT_SCENE_DEFINE,
template=_DEFAULT_TEMPLATE,
stream_out=PROMPT_NEED_NEED_STREAM_OUT,
output_parser=ChatExcelOutputParser(
diff --git a/pilot/scene/chat_data/chat_excel/excel_learning/prompt.py b/pilot/scene/chat_data/chat_excel/excel_learning/prompt.py
index 23a3696ba..fad30bb56 100644
--- a/pilot/scene/chat_data/chat_excel/excel_learning/prompt.py
+++ b/pilot/scene/chat_data/chat_excel/excel_learning/prompt.py
@@ -9,7 +9,7 @@ from pilot.common.schema import SeparatorStyle
CFG = Config()
-PROMPT_SCENE_DEFINE = "You are a data analysis expert. "
+_PROMPT_SCENE_DEFINE_EN = "You are a data analysis expert. "
_DEFAULT_TEMPLATE_EN = """
This is an example data,please learn to understand the structure and content of this data:
@@ -21,6 +21,8 @@ Please return your answer in JSON format, the return format is as follows:
{response}
"""
+_PROMPT_SCENE_DEFINE_ZH = "你是一个数据分析专家. "
+
_DEFAULT_TEMPLATE_ZH = """
下面是一份示例数据,请学习理解该数据的结构和内容:
{data_example}
@@ -37,10 +39,15 @@ RESPONSE_FORMAT_SIMPLE = {
"AnalysisProgram": ["1.分析方案1,图表展示方式1", "2.分析方案2,图表展示方式2"],
}
+
_DEFAULT_TEMPLATE = (
_DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH
)
+PROMPT_SCENE_DEFINE =(
+ _PROMPT_SCENE_DEFINE_EN if CFG.LANGUAGE == "en" else _PROMPT_SCENE_DEFINE_ZH
+)
+
PROMPT_SEP = SeparatorStyle.SINGLE.value