feat(Agent): ChatAgent And AgentHub

1. ChatExcel supports streaming responses
2. ChatExcel supports multiple charts
This commit is contained in:
yhjun1026 2023-10-09 16:26:05 +08:00
parent ed702db9bc
commit f8de26b5ff
10 changed files with 270 additions and 192 deletions

View File

@ -1,16 +1,18 @@
import functools import functools
import importlib import importlib
import inspect import inspect
import time
import json import json
import logging import logging
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from datetime import datetime from datetime import datetime
from typing import Any, Callable, Optional from typing import Any, Callable, Optional, List
from pydantic import BaseModel
from pilot.base_modules.agent.common.schema import Status, ApiTagType from pilot.base_modules.agent.common.schema import Status, ApiTagType
from pilot.base_modules.agent.commands.command import execute_command from pilot.base_modules.agent.commands.command import execute_command
from pilot.base_modules.agent.commands.generator import PluginPromptGenerator from pilot.base_modules.agent.commands.generator import PluginPromptGenerator
from pilot.common.string_utils import extract_content_include, extract_content_open_ending, extract_content, extract_content_include_open_ending from pilot.common.string_utils import extract_content_open_ending, extract_content
# Unique identifier for auto-gpt commands # Unique identifier for auto-gpt commands
AUTO_GPT_COMMAND_IDENTIFIER = "auto_gpt_command" AUTO_GPT_COMMAND_IDENTIFIER = "auto_gpt_command"
@ -165,107 +167,179 @@ def command(
return decorator return decorator
class PluginStatus(BaseModel):
    """Execution state of one ``<api-call>`` block found in the model output.

    Instances are created by ``ApiCall.update_from_context`` and mutated in
    place while the call is executed.
    """

    # Value of the <name> element of the api-call.
    name: str
    # Start offsets, within the scanned text, at which this api-call was seen.
    location: List[int]
    # Arguments collected from the <args> element, keyed by tag name.
    args: dict
    status: Status = Status.TODO.value
    logo_url: Optional[str] = None
    api_result: Optional[str] = None
    err_msg: Optional[str] = None
    # Creation time in milliseconds since the epoch; filled in by __init__.
    start_time: Optional[float] = None
    # Completion time in milliseconds since the epoch (a float, because it is
    # assigned from ``datetime.now().timestamp() * 1000``).
    end_time: Optional[float] = None

    def __init__(self, **data):
        # A class-level ``start_time = datetime.now()...`` default is evaluated
        # once, at import time, and then shared by every instance. Stamp each
        # instance at construction instead, unless the caller supplied a value.
        data.setdefault("start_time", datetime.now().timestamp() * 1000)
        super().__init__(**data)
class ApiCall:
    """Find ``<api-call>`` blocks in LLM output, execute them and render the text.

    Every distinct api-call block is tracked by a ``PluginStatus`` stored in
    ``plugin_status_map``. The map key is the block text, markers included,
    with newlines removed (see ``_status_key``).
    """

    agent_prefix = "<api-call>"
    agent_end = "</api-call>"
    name_prefix = "<name>"
    name_end = "</name>"

    def __init__(self, plugin_generator: Any = None, display_registry: Any = None):
        """Store the collaborators and remember when this call session started.

        Args:
            plugin_generator: forwarded to ``execute_command`` by ``run``.
            display_registry: object providing ``get_command(name)`` and
                ``call(name, **kwargs)``; used by ``run_display_sql``.
        """
        self.plugin_status_map = {}
        self.plugin_generator = plugin_generator
        self.display_registry = display_registry
        # Milliseconds since the epoch; base for the "Waiting..." timers.
        self.start_time = datetime.now().timestamp() * 1000

    def __repr__(self):
        # Per-call name/status/args live on the PluginStatus objects now, so
        # report those instead of attributes this class no longer has.
        return f"ApiCall(plugin_status={list(self.plugin_status_map.values())!r})"

    @staticmethod
    def _status_key(api_context: str) -> str:
        """Return the ``plugin_status_map`` key for a raw api-call block.

        Both literal newlines and the two-character sequence backslash-n are
        removed, so storing and looking up a block always agree on the key.
        """
        return api_context.replace("\\n", "").replace("\n", "")

    def _elapsed_seconds(self, end_time=None) -> str:
        """Format the seconds between ``self.start_time`` and ``end_time``.

        ``end_time`` is in milliseconds; when it is ``None`` the current time
        is used, so a call that has not finished never raises here.
        """
        if end_time is None:
            end_time = datetime.now().timestamp() * 1000
        return "{:.2f}".format((end_time - self.start_time) / 1000)

    def __is_need_wait_plugin_call(self, api_call_context):
        """Return True if the text contains, or ends with the start of, an api-call."""
        if self.agent_prefix in api_call_context:
            return True
        # Detect newly emitted trailing characters: during streaming the text
        # may end in a partial opening marker such as "<api".
        return any(
            api_call_context.endswith(self.agent_prefix[:size])
            for size in range(1, len(self.agent_prefix))
        )

    def __check_last_plugin_call_ready(self, all_context):
        """Return True when every opening marker has a matching closing marker."""
        start_agent_count = all_context.count(self.agent_prefix)
        end_agent_count = all_context.count(self.agent_end)
        return start_agent_count > 0 and start_agent_count == end_agent_count

    def api_view_context(self, all_context: str, display_mode: bool = False):
        """Replace each api-call block in ``all_context`` with its display form.

        Args:
            all_context: the LLM text produced so far.
            display_mode: when True, a tracked block is replaced by its raw
                ``api_result`` (or an error / waiting marker); when False it
                is replaced by the ``<dbgpt-view>`` element from ``to_view_text``.

        Returns:
            The text with every api-call block substituted. Blocks that have
            no status yet are replaced by a "Waiting..." marker.
        """
        call_context_map = extract_content_open_ending(
            all_context, self.agent_prefix, self.agent_end, True
        )
        for api_context in call_context_map.values():
            # Look up with the normalized key; statuses are stored that way.
            api_status = self.plugin_status_map.get(self._status_key(api_context))
            if api_status is None:
                # not ready api call view change
                cost_str = self._elapsed_seconds()
                replacement = f'\nWaiting... * {cost_str}S * \n'
            elif not display_mode:
                replacement = self.to_view_text(api_status)
            elif api_status.api_result:
                replacement = api_status.api_result
            elif api_status.status == Status.FAILED.value:
                replacement = f"""\n<span style=\"color:red\">ERROR!</span>{api_status.err_msg}\n """
            else:
                cost_str = self._elapsed_seconds(api_status.end_time)
                replacement = f'\n<span style=\"color:green\">Waiting...{cost_str}S</span>\n'
            all_context = all_context.replace(api_context, replacement)
        return all_context

    def update_from_context(self, all_context):
        """Parse every complete api-call block and register a status for new ones.

        Raises:
            xml.etree.ElementTree.ParseError: if a block is not well-formed XML.
        """
        logging.info(f"from_context:{all_context}")
        api_context_map = extract_content(
            all_context, self.agent_prefix, self.agent_end, True
        )
        for api_index, api_context in api_context_map.items():
            api_context = self._status_key(api_context)
            api_status = self.plugin_status_map.get(api_context)
            if api_status is not None:
                # This method runs on every streamed chunk; record each
                # offset only once.
                if api_index not in api_status.location:
                    api_status.location.append(api_index)
                continue

            api_call_element = ET.fromstring(api_context)
            api_name = api_call_element.find('name').text
            api_args = {}
            args_elements = api_call_element.find('args')
            if args_elements is not None:
                # NOTE(review): Element.iter() also yields the <args> element
                # itself, so api_args gets an 'args' entry. Kept as-is because
                # consumers may rely on it; confirm against execute_command.
                for child_element in args_elements.iter():
                    api_args[child_element.tag] = child_element.text
            self.plugin_status_map[api_context] = PluginStatus(
                name=api_name, location=[api_index], args=api_args
            )

    def __to_view_param_str(self, api_status):
        """Serialize the displayable fields of ``api_status`` to a JSON string."""
        param = {}
        if api_status.name:
            param['name'] = api_status.name
        param['status'] = api_status.status
        if api_status.logo_url:
            param['logo'] = api_status.logo_url
        if api_status.err_msg:
            param['err_msg'] = api_status.err_msg
        if api_status.api_result:
            param['result'] = api_status.api_result
        return json.dumps(param)

    def to_view_text(self, api_status: "PluginStatus"):
        """Return a ``<dbgpt-view>`` element whose text is the status JSON."""
        api_call_element = ET.Element('dbgpt-view')
        api_call_element.text = self.__to_view_param_str(api_status)
        result = ET.tostring(api_call_element, encoding="utf-8")
        return result.decode("utf-8")

    def _execute_pending(self, log_label: str, executor: Callable) -> None:
        """Run ``executor`` once for every status that is still TODO.

        The executor's return value becomes ``api_result``. Any exception it
        raises is recorded on the status (FAILED plus ``err_msg``) rather than
        propagated, so one failing call does not stop the others.
        """
        for value in self.plugin_status_map.values():
            if value.status != Status.TODO.value:
                continue
            value.status = Status.RUNNING.value
            logging.info(f"{log_label}:{value.name},{value.args}")
            try:
                value.api_result = executor(value)
                value.status = Status.COMPLETED.value
            except Exception as e:
                value.status = Status.FAILED.value
                value.err_msg = str(e)
            value.end_time = datetime.now().timestamp() * 1000

    def run(self, llm_text):
        """Execute pending plugin calls in ``llm_text`` and return the view text."""
        print(f"stream_plugin_call:{llm_text}")
        # wait api call generate complete
        if self.__is_need_wait_plugin_call(llm_text) and self.__check_last_plugin_call_ready(llm_text):
            self.update_from_context(llm_text)
            self._execute_pending(
                "插件执行",
                lambda value: execute_command(value.name, value.args, self.plugin_generator),
            )
        return self.api_view_context(llm_text)

    def run_display_sql(self, llm_text, sql_run_func):
        """Run the SQL of each pending api-call and render it with a display command.

        Args:
            llm_text: the LLM text produced so far.
            sql_run_func: callable that takes the SQL string; its result is
                passed to the display command as ``df``.

        Returns:
            The text with api-call blocks replaced by their rendered results.
        """
        print(f"get_display_sql:{llm_text}")

        def _render(value):
            sql = value.args['sql']
            if not sql:
                # Nothing to run; the call completes without a result.
                return None
            param = {
                "df": sql_run_func(sql),
            }
            if self.display_registry.get_command(value.name):
                return self.display_registry.call(value.name, **param)
            # Unknown display name: fall back to the table display.
            return self.display_registry.call("response_table", **param)

        # wait api call generate complete
        if self.__is_need_wait_plugin_call(llm_text) and self.__check_last_plugin_call_ready(llm_text):
            self.update_from_context(llm_text)
            self._execute_pending("sql展示执行", _render)
        return self.api_view_context(llm_text, True)

View File

@ -92,10 +92,10 @@ def zh_font_set():
@command( @command(
"response_line_chart", "response_line_chart",
"Line chart display, used to display comparative trend analysis data", "Line chart display, used to display comparative trend analysis data",
'"speak": "<speak>", "df":"<data frame>"', '"df":"<data frame>"',
) )
def response_line_chart(speak: str, df: DataFrame) -> str: def response_line_chart( df: DataFrame) -> str:
logger.info(f"response_line_chart:{speak},") logger.info(f"response_line_chart")
if df.size <= 0: if df.size <= 0:
raise ValueError("No Data") raise ValueError("No Data")
@ -148,17 +148,17 @@ def response_line_chart(speak: str, df: DataFrame) -> str:
chart_path = static_message_img_path + "/" + chart_name chart_path = static_message_img_path + "/" + chart_name
plt.savefig(chart_path, bbox_inches="tight", dpi=100) plt.savefig(chart_path, bbox_inches="tight", dpi=100)
html_img = f"""<h5>{speak}</h5><img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />""" html_img = f"""<img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />"""
return html_img return html_img
@command( @command(
"response_bar_chart", "response_bar_chart",
"Histogram, suitable for comparative analysis of multiple target values", "Histogram, suitable for comparative analysis of multiple target values",
'"speak": "<speak>", "df":"<data frame>"', '"df":"<data frame>"',
) )
def response_bar_chart(speak: str, df: DataFrame) -> str: def response_bar_chart( df: DataFrame) -> str:
logger.info(f"response_bar_chart:{speak},") logger.info(f"response_bar_chart")
if df.size <= 0: if df.size <= 0:
raise ValueError("No Data") raise ValueError("No Data")
@ -236,17 +236,17 @@ def response_bar_chart(speak: str, df: DataFrame) -> str:
chart_name = "bar_" + str(uuid.uuid1()) + ".png" chart_name = "bar_" + str(uuid.uuid1()) + ".png"
chart_path = static_message_img_path + "/" + chart_name chart_path = static_message_img_path + "/" + chart_name
plt.savefig(chart_path, bbox_inches="tight", dpi=100) plt.savefig(chart_path, bbox_inches="tight", dpi=100)
html_img = f"""<h5>{speak}</h5><img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />""" html_img = f"""<img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />"""
return html_img return html_img
@command( @command(
"response_pie_chart", "response_pie_chart",
"Pie chart, suitable for scenarios such as proportion and distribution statistics", "Pie chart, suitable for scenarios such as proportion and distribution statistics",
'"speak": "<speak>", "df":"<data frame>"', '"df":"<data frame>"',
) )
def response_pie_chart(speak: str, df: DataFrame) -> str: def response_pie_chart(df: DataFrame) -> str:
logger.info(f"response_pie_chart:{speak},") logger.info(f"response_pie_chart")
columns = df.columns.tolist() columns = df.columns.tolist()
if df.size <= 0: if df.size <= 0:
raise ValueError("No Data") raise ValueError("No Data")
@ -291,6 +291,6 @@ def response_pie_chart(speak: str, df: DataFrame) -> str:
chart_path = static_message_img_path + "/" + chart_name chart_path = static_message_img_path + "/" + chart_name
plt.savefig(chart_path, bbox_inches="tight", dpi=100) plt.savefig(chart_path, bbox_inches="tight", dpi=100)
html_img = f"""<h5>{speak.replace("`", '"')}</h5><img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />""" html_img = f"""<img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />"""
return html_img return html_img

View File

@ -13,12 +13,12 @@ logger = build_logger("show_table_gen", LOGDIR + "show_table_gen.log")
@command(
    "response_table",
    "Table display, suitable for display with many display columns or non-numeric columns",
    '"df":"<data frame>"',
)
def response_table(df: DataFrame) -> str:
    """Render ``df`` as an HTML table wrapped in a scrollable ``<div>``.

    The table markup comes from ``DataFrame.to_html`` (index hidden, cell
    contents not escaped) and has all whitespace removed before wrapping.
    """
    logger.info(f"response_table")
    compact_table = "".join(
        df.to_html(index=False, escape=False, sparsify=False).split()
    )
    wrapped = f"""<div class="w-full overflow-auto">{compact_table}</div>"""
    return wrapped.replace("\n", " ")

View File

@ -12,10 +12,10 @@ logger = build_logger("show_table_gen", LOGDIR + "show_table_gen.log")
@command( @command(
"response_data_text", "response_data_text",
"Text display, the default display method, suitable for single-line or simple content display", "Text display, the default display method, suitable for single-line or simple content display",
'"speak": "<speak>", "df":"<data frame>"', '"df":"<data frame>"',
) )
def response_data_text(speak: str, df: DataFrame) -> str: def response_data_text(df: DataFrame) -> str:
logger.info(f"response_data_text:{speak}") logger.info(f"response_data_text")
data = df.values data = df.values
row_size = data.shape[0] row_size = data.shape[0]
@ -25,7 +25,7 @@ def response_data_text(speak: str, df: DataFrame) -> str:
html_table = df.to_html(index=False, escape=False, sparsify=False) html_table = df.to_html(index=False, escape=False, sparsify=False)
table_str = "".join(html_table.split()) table_str = "".join(html_table.split())
html = f"""<div class="w-full overflow-auto">{table_str}</div>""" html = f"""<div class="w-full overflow-auto">{table_str}</div>"""
text_info = f"##### {str(speak)}" + "\n" + html.replace("\n", " ") text_info = html.replace("\n", " ")
elif row_size == 1: elif row_size == 1:
row = data[0] row = data[0]
for value in row: for value in row:
@ -33,7 +33,7 @@ def response_data_text(speak: str, df: DataFrame) -> str:
value_str = value_str + f", ** {value} **" value_str = value_str + f", ** {value} **"
else: else:
value_str = f" ** {value} **" value_str = f" ** {value} **"
text_info = f"{speak}: {value_str}" text_info = f" {value_str}"
else: else:
text_info = f"##### {speak}: _没有找到可用的数据_" text_info = f"##### _没有找到可用的数据_"
return text_info return text_info

View File

@ -1,49 +1,47 @@
def extract_content(long_string, s1, s2, is_include: bool = False):
    """Collect every span of ``long_string`` delimited by ``s1`` ... ``s2``.

    Args:
        long_string: text to scan.
        s1: opening marker.
        s2: closing marker.
        is_include: when True the returned spans keep both markers; when
            False only the text between them is returned.

    Returns:
        A dict mapping the index of each opening marker to its extracted
        span. Empty spans are omitted, and so is an opening marker that has
        no closing marker after it.
    """
    match_map = {}
    # Inclusive mode starts looking for the closing marker one character
    # later than exclusive mode, so it needs at least one character between
    # the markers. Kept from the original behavior.
    search_offset = len(s1) + 1 if is_include else len(s1)

    start_index = long_string.find(s1)
    while start_index != -1:
        end_index = long_string.find(s2, start_index + search_offset)
        if end_index == -1:
            # No closing marker from here on. Slicing with -1 would return
            # garbage, and later openings cannot be terminated either.
            break
        if is_include:
            extracted_content = long_string[start_index:end_index + len(s2)]
        else:
            extracted_content = long_string[start_index + len(s1):end_index]
        if extracted_content:
            match_map[start_index] = extracted_content
        start_index = long_string.find(s1, start_index + 1)
    return match_map
def extract_content_open_ending(long_string, s1, s2, is_include: bool = False):
    """Like ``extract_content``, but an unterminated span runs to the end of the text.

    Args:
        long_string: text to scan.
        s1: opening marker.
        s2: closing marker.
        is_include: when True the returned spans keep the markers (only the
            opening one for an unterminated span); when False only the text
            after the opening marker is returned.

    Returns:
        A dict mapping the index of each opening marker to its extracted
        span. Empty spans are omitted.
    """
    match_map = {}
    start_index = long_string.find(s1)
    while start_index != -1:
        # Search for the closing marker once, right after the opening marker.
        # ``str.find`` signals "not found" with -1 only, and 0 is a valid index.
        end_index = long_string.find(s2, start_index + len(s1))
        if end_index == -1:
            # Open ending: take everything up to the end of the string.
            end_index = len(long_string)
        if is_include:
            # For an open ending the slice end overshoots the string length,
            # which slicing clamps.
            extracted_content = long_string[start_index:end_index + len(s2)]
        else:
            extracted_content = long_string[start_index + len(s1):end_index]
        if extracted_content:
            match_map[start_index] = extracted_content
        start_index = long_string.find(s1, start_index + 1)
    return match_map
def extract_content_include_open_ending(long_string, s1, s2):
start_index = long_string.find(s1)
if start_index < 0:
return ""
if long_string.find(s2) <=0:
end_index = len(long_string)
else:
end_index = long_string.find(s2, start_index + len(s1) + 1)
extracted_content = long_string[start_index:end_index + len(s2)]
return extracted_content
if __name__ == "__main__":
    # Manual demo: three terminated spans plus one trailing opening marker,
    # printed with the markers included.
    sample = "abcd123efghijkjhhh456xxx123aa456yyy123bb456xx123"
    opening, closing = "123", "456"
    print(extract_content_open_ending(sample, opening, closing, True))

View File

@ -34,7 +34,7 @@ class ChatAgent(BaseChat):
agent_module = CFG.SYSTEM_APP.get_componet(ComponetType.AGENT_HUB, ModuleAgent) agent_module = CFG.SYSTEM_APP.get_componet(ComponetType.AGENT_HUB, ModuleAgent)
self.plugins_prompt_generator = agent_module.load_select_plugin(self.plugins_prompt_generator, self.select_plugins) self.plugins_prompt_generator = agent_module.load_select_plugin(self.plugins_prompt_generator, self.select_plugins)
self.api_call = ApiCall(self.plugins_prompt_generator) self.api_call = ApiCall(plugin_generator=self.plugins_prompt_generator)
def generate_input_values(self): def generate_input_values(self):
input_values = { input_values = {

View File

@ -15,27 +15,6 @@ class PluginAction(NamedTuple):
class PluginChatOutputParser(BaseOutputParser): class PluginChatOutputParser(BaseOutputParser):
def parse_prompt_response(self, model_out_text) -> T:
clean_json_str = super().parse_prompt_response(model_out_text)
print(clean_json_str)
if not clean_json_str:
raise ValueError("model server response not have json!")
try:
response = json.loads(clean_json_str)
except Exception as e:
raise ValueError("model server out not fllow the prompt!")
speak = ""
thoughts = ""
for key in sorted(response):
if key.strip() == "command":
command = response[key]
if key.strip() == "thoughts":
thoughts = response[key]
if key.strip() == "speak":
speak = response[key]
return PluginAction(command, speak, thoughts)
def parse_view_response(self, speak, data) -> str: def parse_view_response(self, speak, data) -> str:
### tool out data to table view ### tool out data to table view
print(f"parse_view_response:{speak},{str(data)}") print(f"parse_view_response:{speak},{str(data)}")

View File

@ -7,13 +7,13 @@ from pilot.scene.base_chat import BaseChat, logger
from pilot.scene.base import ChatScene from pilot.scene.base import ChatScene
from pilot.common.sql_database import Database from pilot.common.sql_database import Database
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.base_modules.agent.commands.command_mange import ApiCall
from pilot.scene.chat_data.chat_excel.excel_analyze.prompt import prompt from pilot.scene.chat_data.chat_excel.excel_analyze.prompt import prompt
from pilot.scene.chat_data.chat_excel.excel_reader import ExcelReader from pilot.scene.chat_data.chat_excel.excel_reader import ExcelReader
from pilot.scene.chat_data.chat_excel.excel_learning.chat import ExcelLearning from pilot.scene.chat_data.chat_excel.excel_learning.chat import ExcelLearning
from pilot.common.path_utils import has_path from pilot.common.path_utils import has_path
from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH
from pilot.base_modules.agent.common.schema import Status
CFG = Config() CFG = Config()
@ -36,9 +36,18 @@ class ChatExcel(BaseChat):
KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode.value(), self.select_param KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode.value(), self.select_param
) )
) )
self.api_call = ApiCall(display_registry = CFG.command_disply)
super().__init__(chat_param=chat_param) super().__init__(chat_param=chat_param)
def _generate_numbered_list(self) -> str:
    """Return the enabled display commands as a 1-based numbered list.

    Each line has the form ``"<n>. <str(command)>"``. The result is an empty
    string when ``CFG.command_disply`` is falsy or has no enabled command.
    """
    registry = CFG.command_disply
    enabled = []
    if registry:
        enabled = [str(cmd) for cmd in registry.commands.values() if cmd.enabled]
    return "\n".join(
        f"{position}. {text}" for position, text in enumerate(enabled, start=1)
    )
def generate_input_values(self): def generate_input_values(self):
input_values = { input_values = {
@ -64,15 +73,25 @@ class ChatExcel(BaseChat):
result = await learn_chat.nostream_call() result = await learn_chat.nostream_call()
return result return result
def stream_plugin_call(self, text):
    """Flatten newlines in ``text`` and hand it to the api-call display runner.

    Returns whatever ``self.api_call.run_display_sql`` returns; SQL found in
    the text is executed through ``self.excel_reader.get_df_by_sql_ex``.
    """
    flattened = text.replace("\n", " ")
    print(f"stream_plugin_call:{flattened}")
    sql_runner = self.excel_reader.get_df_by_sql_ex
    return self.api_call.run_display_sql(flattened, sql_runner)
# colunms, datas = self.excel_reader.run(prompt_response.sql)
param = {
"speak": prompt_response.thoughts, # def do_action(self, prompt_response):
"df": self.excel_reader.get_df_by_sql_ex(prompt_response.sql), # print(f"do_action:{prompt_response}")
} #
if CFG.command_disply.get_command(prompt_response.display): # # colunms, datas = self.excel_reader.run(prompt_response.sql)
return CFG.command_disply.call(prompt_response.display, **param) #
else: #
return CFG.command_disply.call("response_table", **param) # param = {
# "speak": prompt_response.thoughts,
# "df": self.excel_reader.get_df_by_sql_ex(prompt_response.sql),
# }
#
# if CFG.command_disply.get_command(prompt_response.display):
# return CFG.command_disply.call(prompt_response.display, **param)
# else:
# return CFG.command_disply.call("response_table", **param)

View File

@ -9,50 +9,52 @@ from pilot.common.schema import SeparatorStyle
CFG = Config() CFG = Config()
PROMPT_SCENE_DEFINE = "You are a data analysis expert. " _PROMPT_SCENE_DEFINE_EN = "You are a data analysis expert. "
_DEFAULT_TEMPLATE_EN = """ _DEFAULT_TEMPLATE_EN = """
Please use the data structure information of the above historical dialogue, make sure not to use column names that are not in the data structure. Please use the data structure and column information in the above historical dialogue and combine it with data analysis to answer the user's questions while satisfying the constraints.
According to the user goal: {user_input}give the correct duckdb SQL for data analysis.
Use the table name: {table_name}
According to the analysis SQL obtained by the user's goal, select the best one from the following display forms, if it cannot be determined, use Text as the display,Just need to return the type name into the result. Constraint:
Display type: 1.Please output your thinking process and analysis ideas first, and then output the specific data analysis results. The data analysis results are output in the following format:<api-call><name>display type</name><args><sql>Correct duckdb data analysis sql</sql></args></api-call>
{disply_type} 2.For the available display methods of data analysis results, please choose the most appropriate one from the following display methods. If you are not sure, use 'response_data_text' as the display. The available display types are as follows:{disply_type}
3.The table name that needs to be used in SQL is: {table_name}, please make sure not to use column names that are not in the data structure.
Respond in the following json format: 4.Give priority to answering using data analysis. If the user's question does not involve data analysis, you can answer according to your understanding.
{response}
Ensure the response is correct json and can be parsed by Python json.loads
User Questions:
{user_input}
""" """
_PROMPT_SCENE_DEFINE_ZH = """你是一个数据分析专家!"""
_DEFAULT_TEMPLATE_ZH = """ _DEFAULT_TEMPLATE_ZH = """
请使用上述历史对话中的数据结构和列信息根据用户目标{user_input}给出正确的duckdb SQL进行数据分析和问题回答 请使用上述历史对话中的数据结构信息在满足约束条件下结合数据分析回答用户的问题
请确保不要使用不在数据结构中的列名
SQL中需要使用的表名是: {table_name}
根据用户目标得到的分析SQL请从以下显示类型中选择最合适的一种用来展示结果数据如果无法确定则使用'Text'作为显示, 只需要将类型名称返回到结果中 约束条件:
显示类型如下: 1.请先输出你的分析思路内容再输出具体的数据分析结果其中数据分析结果部分确保用如下格式输出:<api-call><name>[数据展示方式]</name><args><sql>[正确的duckdb数据分析sql]</sql></args></api-call>
{disply_type} 2.请确保数据分析结果格式的内容在整个回答中只出现一次,确保上述结构稳定[]部分内容替换为对应的值
3.数据分析结果可用的展示方式请在下面的展示方式中选择最合适的一种,放入数据分析结果的name字段内如果无法确定则使用'Text'作为显示可用显示类型如下: {disply_type}
以以下 json 格式响应: 4.SQL中需要使用的表名是: {table_name},请确保不要使用不在数据结构中的列名
{response} 5.优先使用数据分析的方式回答如果用户问题不涉及数据分析内容你可以按你的理解进行回答
确保响应是正确的json,并且可以被Python的json.loads方法解析. 6.请确保你的输出内容有良好的markdown格式的排版
用户问题{user_input}
请确保你的输出格式如下:
需要告诉用户的分析思路.
<api-call><name>[数据展示方式]</name><args><sql>[正确的duckdb数据分析sql]</sql></args></api-call>
""" """
RESPONSE_FORMAT_SIMPLE = {
"sql": "analysis SQL",
"thoughts": "Current thinking and value of data analysis",
"display": "display type name",
}
_DEFAULT_TEMPLATE = ( _DEFAULT_TEMPLATE = (
_DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH _DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH
) )
_PROMPT_SCENE_DEFINE = (
_PROMPT_SCENE_DEFINE_EN if CFG.LANGUAGE == "en" else _PROMPT_SCENE_DEFINE_ZH
)
PROMPT_SEP = SeparatorStyle.SINGLE.value PROMPT_SEP = SeparatorStyle.SINGLE.value
PROMPT_NEED_NEED_STREAM_OUT = False PROMPT_NEED_NEED_STREAM_OUT = True
# Temperature is a configuration hyperparameter that controls the randomness of language model output. # Temperature is a configuration hyperparameter that controls the randomness of language model output.
# A high temperature produces more unpredictable and creative results, while a low temperature produces more common and conservative output. # A high temperature produces more unpredictable and creative results, while a low temperature produces more common and conservative output.
@ -62,8 +64,7 @@ PROMPT_TEMPERATURE = 0.8
prompt = PromptTemplate( prompt = PromptTemplate(
template_scene=ChatScene.ChatExcel.value(), template_scene=ChatScene.ChatExcel.value(),
input_variables=["user_input", "table_name", "disply_type"], input_variables=["user_input", "table_name", "disply_type"],
response_format=json.dumps(RESPONSE_FORMAT_SIMPLE, ensure_ascii=False, indent=4), template_define=_PROMPT_SCENE_DEFINE,
template_define=PROMPT_SCENE_DEFINE,
template=_DEFAULT_TEMPLATE, template=_DEFAULT_TEMPLATE,
stream_out=PROMPT_NEED_NEED_STREAM_OUT, stream_out=PROMPT_NEED_NEED_STREAM_OUT,
output_parser=ChatExcelOutputParser( output_parser=ChatExcelOutputParser(

View File

@ -9,7 +9,7 @@ from pilot.common.schema import SeparatorStyle
CFG = Config() CFG = Config()
PROMPT_SCENE_DEFINE = "You are a data analysis expert. " _PROMPT_SCENE_DEFINE_EN = "You are a data analysis expert. "
_DEFAULT_TEMPLATE_EN = """ _DEFAULT_TEMPLATE_EN = """
This is an example dataplease learn to understand the structure and content of this data: This is an example dataplease learn to understand the structure and content of this data:
@ -21,6 +21,8 @@ Please return your answer in JSON format, the return format is as follows:
{response} {response}
""" """
_PROMPT_SCENE_DEFINE_ZH = "你是一个数据分析专家. "
_DEFAULT_TEMPLATE_ZH = """ _DEFAULT_TEMPLATE_ZH = """
下面是一份示例数据请学习理解该数据的结构和内容: 下面是一份示例数据请学习理解该数据的结构和内容:
{data_example} {data_example}
@ -37,10 +39,15 @@ RESPONSE_FORMAT_SIMPLE = {
"AnalysisProgram": ["1.分析方案1图表展示方式1", "2.分析方案2图表展示方式2"], "AnalysisProgram": ["1.分析方案1图表展示方式1", "2.分析方案2图表展示方式2"],
} }
_DEFAULT_TEMPLATE = ( _DEFAULT_TEMPLATE = (
_DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH _DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH
) )
PROMPT_SCENE_DEFINE =(
_PROMPT_SCENE_DEFINE_EN if CFG.LANGUAGE == "en" else _PROMPT_SCENE_DEFINE_ZH
)
PROMPT_SEP = SeparatorStyle.SINGLE.value PROMPT_SEP = SeparatorStyle.SINGLE.value