mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-01 16:18:27 +00:00
feat(Agent): ChatAgent And AgentHub
1.ChatExcel support stream response 2.ChatExcel support Multiple charts
This commit is contained in:
parent
ed702db9bc
commit
f8de26b5ff
@ -1,16 +1,18 @@
|
||||
import functools
|
||||
import importlib
|
||||
import inspect
|
||||
import time
|
||||
import json
|
||||
import logging
|
||||
import xml.etree.ElementTree as ET
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, Callable, Optional
|
||||
from typing import Any, Callable, Optional, List
|
||||
from pydantic import BaseModel
|
||||
from pilot.base_modules.agent.common.schema import Status, ApiTagType
|
||||
from pilot.base_modules.agent.commands.command import execute_command
|
||||
from pilot.base_modules.agent.commands.generator import PluginPromptGenerator
|
||||
from pilot.common.string_utils import extract_content_include, extract_content_open_ending, extract_content, extract_content_include_open_ending
|
||||
from pilot.common.string_utils import extract_content_open_ending, extract_content
|
||||
|
||||
# Unique identifier for auto-gpt commands
|
||||
AUTO_GPT_COMMAND_IDENTIFIER = "auto_gpt_command"
|
||||
@ -165,107 +167,179 @@ def command(
|
||||
return decorator
|
||||
|
||||
|
||||
class PluginStatus(BaseModel):
|
||||
name: str
|
||||
location: List[int]
|
||||
args: dict
|
||||
status: Status = Status.TODO.value
|
||||
logo_url: str = None
|
||||
api_result: str = None
|
||||
err_msg: str = None
|
||||
start_time = datetime.now().timestamp() * 1000
|
||||
end_time: int = None
|
||||
|
||||
|
||||
class ApiCall:
|
||||
agent_prefix = "<api-call>"
|
||||
agent_end = "</api-call>"
|
||||
name_prefix = "<name>"
|
||||
name_end = "</name>"
|
||||
|
||||
def __init__(self, plugin_generator):
|
||||
self.name: str = ""
|
||||
self.status: Status = Status.TODO.value
|
||||
self.logo_url: str = None
|
||||
self.args = {}
|
||||
self.api_result: str = None
|
||||
self.err_msg: str = None
|
||||
def __init__(self, plugin_generator: Any = None, display_registry: Any = None):
|
||||
# self.name: str = ""
|
||||
# self.status: Status = Status.TODO.value
|
||||
# self.logo_url: str = None
|
||||
# self.args = {}
|
||||
# self.api_result: str = None
|
||||
# self.err_msg: str = None
|
||||
|
||||
self.plugin_status_map = {}
|
||||
|
||||
self.plugin_generator = plugin_generator
|
||||
self.display_registry = display_registry
|
||||
self.start_time = datetime.now().timestamp() * 1000
|
||||
|
||||
def __repr__(self):
|
||||
return f"ApiCall(name={self.name}, status={self.status}, args={self.args})"
|
||||
|
||||
|
||||
def __is_need_wait_plugin_call(self, api_call_context):
|
||||
start_agent_count = api_call_context.count(self.agent_prefix)
|
||||
end_agent_count = api_call_context.count(self.agent_end)
|
||||
|
||||
if api_call_context.find(self.agent_prefix) >= 0:
|
||||
if start_agent_count > 0:
|
||||
return True
|
||||
check_len = len(self.agent_prefix)
|
||||
last_text = api_call_context[-check_len:]
|
||||
for i in range(check_len):
|
||||
text_tmp = last_text[-i:]
|
||||
prefix_tmp = self.agent_prefix[:i]
|
||||
if text_tmp == prefix_tmp:
|
||||
return True
|
||||
else:
|
||||
i += 1
|
||||
else:
|
||||
# 末尾新出字符检测
|
||||
check_len = len(self.agent_prefix)
|
||||
last_text = api_call_context[-check_len:]
|
||||
for i in range(check_len):
|
||||
text_tmp = last_text[-i:]
|
||||
prefix_tmp = self.agent_prefix[:i]
|
||||
if text_tmp == prefix_tmp:
|
||||
return True
|
||||
else:
|
||||
i += 1
|
||||
return False
|
||||
|
||||
def __get_api_call_context(self, all_context):
|
||||
return extract_content_include(all_context, self.agent_prefix, self.agent_end)
|
||||
def __check_last_plugin_call_ready(self, all_context):
|
||||
start_agent_count = all_context.count(self.agent_prefix)
|
||||
end_agent_count = all_context.count(self.agent_end)
|
||||
|
||||
def __check_plugin_call_ready(self, all_context):
|
||||
if all_context.find(self.agent_end) > 0:
|
||||
if start_agent_count > 0 and start_agent_count == end_agent_count:
|
||||
return True
|
||||
return False
|
||||
|
||||
def api_view_context(self, all_context:str):
|
||||
if all_context.find(self.agent_prefix) >= 0:
|
||||
call_context = extract_content_open_ending(all_context, self.agent_prefix, self.agent_end)
|
||||
call_context_all = extract_content_include_open_ending(all_context, self.agent_prefix, self.agent_end)
|
||||
if len(call_context) > 0:
|
||||
name_context = extract_content(call_context, self.name_prefix, self.name_end)
|
||||
if len(name_context) > 0:
|
||||
self.name = name_context
|
||||
return all_context.replace(call_context_all, self.to_view_text())
|
||||
else:
|
||||
return all_context
|
||||
def api_view_context(self, all_context: str, display_mode: bool = False):
|
||||
call_context_map = extract_content_open_ending(all_context, self.agent_prefix, self.agent_end, True)
|
||||
for api_index, api_context in call_context_map.items():
|
||||
api_status = self.plugin_status_map.get(api_context)
|
||||
if api_status is not None:
|
||||
if display_mode:
|
||||
if api_status.api_result:
|
||||
all_context = all_context.replace(api_context, api_status.api_result)
|
||||
else:
|
||||
if api_status.status == Status.FAILED.value:
|
||||
all_context = all_context.replace(api_context, f"""\n<span style=\"color:red\">ERROR!</span>{api_status.err_msg}\n """)
|
||||
else:
|
||||
cost = (api_status.end_time - self.start_time) / 1000
|
||||
cost_str = "{:.2f}".format(cost)
|
||||
all_context = all_context.replace(api_context, f'\n<span style=\"color:green\">Waiting...{cost_str}S</span>\n')
|
||||
else:
|
||||
all_context = all_context.replace(api_context, self.to_view_text(api_status))
|
||||
else:
|
||||
# not ready api call view change
|
||||
now_time = datetime.now().timestamp() * 1000
|
||||
cost = (now_time - self.start_time) / 1000
|
||||
cost_str = "{:.2f}".format(cost)
|
||||
all_context = all_context.replace(api_context, f'\nWaiting... * {cost_str}S * \n')
|
||||
|
||||
return all_context
|
||||
|
||||
def update_from_context(self, all_context):
|
||||
logging.info(f"from_context:{all_context}")
|
||||
api_context = extract_content_include(all_context, self.agent_prefix, self.agent_end)
|
||||
api_context = api_context.replace("\\n", "").replace("\n", "")
|
||||
|
||||
api_call_element = ET.fromstring(api_context)
|
||||
self.name = api_call_element.find('name').text
|
||||
api_context_map = extract_content(all_context, self.agent_prefix, self.agent_end, True)
|
||||
for api_index, api_context in api_context_map.items():
|
||||
api_context = api_context.replace("\\n", "").replace("\n", "")
|
||||
api_call_element = ET.fromstring(api_context)
|
||||
api_name = api_call_element.find('name').text
|
||||
api_args = {}
|
||||
args_elements = api_call_element.find('args')
|
||||
for child_element in args_elements.iter():
|
||||
api_args[child_element.tag] = child_element.text
|
||||
|
||||
args_elements = api_call_element.find('args')
|
||||
for child_element in args_elements.iter():
|
||||
self.args[child_element.tag] = child_element.text
|
||||
api_status = self.plugin_status_map.get(api_context)
|
||||
if api_status is None:
|
||||
api_status = PluginStatus(name=api_name, location=[api_index], args=api_args)
|
||||
self.plugin_status_map[api_context] = api_status
|
||||
else:
|
||||
api_status.location.append(api_index)
|
||||
|
||||
def __to_view_param_str(self):
|
||||
def __to_view_param_str(self, api_status):
|
||||
param = {}
|
||||
if self.name:
|
||||
param['name'] = self.name
|
||||
param['status'] = self.status
|
||||
if self.logo_url:
|
||||
param['logo'] = self.logo_url
|
||||
if api_status.name:
|
||||
param['name'] = api_status.name
|
||||
param['status'] = api_status.status
|
||||
if api_status.logo_url:
|
||||
param['logo'] = api_status.logo_url
|
||||
|
||||
if self.err_msg:
|
||||
param['err_msg'] = self.err_msg
|
||||
|
||||
if self.api_result:
|
||||
param['result'] = self.api_result
|
||||
if api_status.err_msg:
|
||||
param['err_msg'] = api_status.err_msg
|
||||
|
||||
if api_status.api_result:
|
||||
param['result'] = api_status.api_result
|
||||
return json.dumps(param)
|
||||
|
||||
def to_view_text(self):
|
||||
def to_view_text(self, api_status: PluginStatus):
|
||||
api_call_element = ET.Element('dbgpt-view')
|
||||
api_call_element.text = self.__to_view_param_str()
|
||||
api_call_element.text = self.__to_view_param_str(api_status)
|
||||
result = ET.tostring(api_call_element, encoding="utf-8")
|
||||
return result.decode("utf-8")
|
||||
|
||||
|
||||
def run(self, llm_text):
|
||||
print(f"stream_plugin_call:{llm_text}")
|
||||
if self.__is_need_wait_plugin_call(llm_text):
|
||||
# wait api call generate complete
|
||||
if self.__check_plugin_call_ready(llm_text):
|
||||
if self.__check_last_plugin_call_ready(llm_text):
|
||||
self.update_from_context(llm_text)
|
||||
if self.status == Status.TODO.value:
|
||||
self.status = Status.RUNNING.value
|
||||
logging.info(f"插件执行:{self.name},{self.args}")
|
||||
try:
|
||||
self.api_result = execute_command(self.name, self.args, self.plugin_generator)
|
||||
self.status = Status.COMPLETED.value
|
||||
except Exception as e:
|
||||
self.status = Status.FAILED.value
|
||||
self.err_msg = str(e)
|
||||
return self.api_view_context(llm_text)
|
||||
for key, value in self.plugin_status_map:
|
||||
if value.status == Status.TODO.value:
|
||||
value.status = Status.RUNNING.value
|
||||
logging.info(f"插件执行:{value.name},{value.args}")
|
||||
try:
|
||||
value.api_result = execute_command(value.name, value.args, self.plugin_generator)
|
||||
value.status = Status.COMPLETED.value
|
||||
except Exception as e:
|
||||
value.status = Status.FAILED.value
|
||||
value.err_msg = str(e)
|
||||
value.end_time = datetime.now().timestamp() * 1000
|
||||
return self.api_view_context(llm_text)
|
||||
|
||||
def run_display_sql(self, llm_text, sql_run_func):
|
||||
print(f"get_display_sql:{llm_text}")
|
||||
|
||||
if self.__is_need_wait_plugin_call(llm_text):
|
||||
# wait api call generate complete
|
||||
if self.__check_last_plugin_call_ready(llm_text):
|
||||
self.update_from_context(llm_text)
|
||||
for key, value in self.plugin_status_map.items():
|
||||
if value.status == Status.TODO.value:
|
||||
value.status = Status.RUNNING.value
|
||||
logging.info(f"sql展示执行:{value.name},{value.args}")
|
||||
try:
|
||||
sql = value.args['sql']
|
||||
if sql:
|
||||
param = {
|
||||
"df": sql_run_func(sql),
|
||||
}
|
||||
if self.display_registry.get_command(value.name):
|
||||
value.api_result = self.display_registry.call(value.name, **param)
|
||||
else:
|
||||
value.api_result = self.display_registry.call("response_table", **param)
|
||||
|
||||
value.status = Status.COMPLETED.value
|
||||
except Exception as e:
|
||||
value.status = Status.FAILED.value
|
||||
value.err_msg = str(e)
|
||||
value.end_time = datetime.now().timestamp() * 1000
|
||||
return self.api_view_context(llm_text, True)
|
||||
|
@ -92,10 +92,10 @@ def zh_font_set():
|
||||
@command(
|
||||
"response_line_chart",
|
||||
"Line chart display, used to display comparative trend analysis data",
|
||||
'"speak": "<speak>", "df":"<data frame>"',
|
||||
'"df":"<data frame>"',
|
||||
)
|
||||
def response_line_chart(speak: str, df: DataFrame) -> str:
|
||||
logger.info(f"response_line_chart:{speak},")
|
||||
def response_line_chart( df: DataFrame) -> str:
|
||||
logger.info(f"response_line_chart")
|
||||
if df.size <= 0:
|
||||
raise ValueError("No Data!")
|
||||
|
||||
@ -148,17 +148,17 @@ def response_line_chart(speak: str, df: DataFrame) -> str:
|
||||
chart_path = static_message_img_path + "/" + chart_name
|
||||
plt.savefig(chart_path, bbox_inches="tight", dpi=100)
|
||||
|
||||
html_img = f"""<h5>{speak}</h5><img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />"""
|
||||
html_img = f"""<img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />"""
|
||||
return html_img
|
||||
|
||||
|
||||
@command(
|
||||
"response_bar_chart",
|
||||
"Histogram, suitable for comparative analysis of multiple target values",
|
||||
'"speak": "<speak>", "df":"<data frame>"',
|
||||
'"df":"<data frame>"',
|
||||
)
|
||||
def response_bar_chart(speak: str, df: DataFrame) -> str:
|
||||
logger.info(f"response_bar_chart:{speak},")
|
||||
def response_bar_chart( df: DataFrame) -> str:
|
||||
logger.info(f"response_bar_chart")
|
||||
if df.size <= 0:
|
||||
raise ValueError("No Data!")
|
||||
|
||||
@ -236,17 +236,17 @@ def response_bar_chart(speak: str, df: DataFrame) -> str:
|
||||
chart_name = "bar_" + str(uuid.uuid1()) + ".png"
|
||||
chart_path = static_message_img_path + "/" + chart_name
|
||||
plt.savefig(chart_path, bbox_inches="tight", dpi=100)
|
||||
html_img = f"""<h5>{speak}</h5><img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />"""
|
||||
html_img = f"""<img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />"""
|
||||
return html_img
|
||||
|
||||
|
||||
@command(
|
||||
"response_pie_chart",
|
||||
"Pie chart, suitable for scenarios such as proportion and distribution statistics",
|
||||
'"speak": "<speak>", "df":"<data frame>"',
|
||||
'"df":"<data frame>"',
|
||||
)
|
||||
def response_pie_chart(speak: str, df: DataFrame) -> str:
|
||||
logger.info(f"response_pie_chart:{speak},")
|
||||
def response_pie_chart(df: DataFrame) -> str:
|
||||
logger.info(f"response_pie_chart")
|
||||
columns = df.columns.tolist()
|
||||
if df.size <= 0:
|
||||
raise ValueError("No Data!")
|
||||
@ -291,6 +291,6 @@ def response_pie_chart(speak: str, df: DataFrame) -> str:
|
||||
chart_path = static_message_img_path + "/" + chart_name
|
||||
plt.savefig(chart_path, bbox_inches="tight", dpi=100)
|
||||
|
||||
html_img = f"""<h5>{speak.replace("`", '"')}</h5><img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />"""
|
||||
html_img = f"""<img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />"""
|
||||
|
||||
return html_img
|
||||
|
@ -13,12 +13,12 @@ logger = build_logger("show_table_gen", LOGDIR + "show_table_gen.log")
|
||||
@command(
|
||||
"response_table",
|
||||
"Table display, suitable for display with many display columns or non-numeric columns",
|
||||
'"speak": "<speak>", "df":"<data frame>"',
|
||||
'"df":"<data frame>"',
|
||||
)
|
||||
def response_table(speak: str, df: DataFrame) -> str:
|
||||
logger.info(f"response_table:{speak}")
|
||||
def response_table(df: DataFrame) -> str:
|
||||
logger.info(f"response_table")
|
||||
html_table = df.to_html(index=False, escape=False, sparsify=False)
|
||||
table_str = "".join(html_table.split())
|
||||
html = f"""<div class="w-full overflow-auto">{table_str}</div>"""
|
||||
view_text = f"##### {str(speak)}" + "\n" + html.replace("\n", " ")
|
||||
view_text = html.replace("\n", " ")
|
||||
return view_text
|
||||
|
@ -12,10 +12,10 @@ logger = build_logger("show_table_gen", LOGDIR + "show_table_gen.log")
|
||||
@command(
|
||||
"response_data_text",
|
||||
"Text display, the default display method, suitable for single-line or simple content display",
|
||||
'"speak": "<speak>", "df":"<data frame>"',
|
||||
'"df":"<data frame>"',
|
||||
)
|
||||
def response_data_text(speak: str, df: DataFrame) -> str:
|
||||
logger.info(f"response_data_text:{speak}")
|
||||
def response_data_text(df: DataFrame) -> str:
|
||||
logger.info(f"response_data_text")
|
||||
data = df.values
|
||||
|
||||
row_size = data.shape[0]
|
||||
@ -25,7 +25,7 @@ def response_data_text(speak: str, df: DataFrame) -> str:
|
||||
html_table = df.to_html(index=False, escape=False, sparsify=False)
|
||||
table_str = "".join(html_table.split())
|
||||
html = f"""<div class="w-full overflow-auto">{table_str}</div>"""
|
||||
text_info = f"##### {str(speak)}" + "\n" + html.replace("\n", " ")
|
||||
text_info = html.replace("\n", " ")
|
||||
elif row_size == 1:
|
||||
row = data[0]
|
||||
for value in row:
|
||||
@ -33,7 +33,7 @@ def response_data_text(speak: str, df: DataFrame) -> str:
|
||||
value_str = value_str + f", ** {value} **"
|
||||
else:
|
||||
value_str = f" ** {value} **"
|
||||
text_info = f"{speak}: {value_str}"
|
||||
text_info = f" {value_str}"
|
||||
else:
|
||||
text_info = f"##### {speak}: _没有找到可用的数据_"
|
||||
text_info = f"##### _没有找到可用的数据_"
|
||||
return text_info
|
||||
|
@ -1,49 +1,47 @@
|
||||
|
||||
|
||||
def extract_content(long_string, s1, s2):
|
||||
def extract_content(long_string, s1, s2, is_include: bool = False):
|
||||
# extract text
|
||||
match_map ={}
|
||||
start_index = long_string.find(s1)
|
||||
if start_index < 0:
|
||||
return ""
|
||||
end_index = long_string.find(s2, start_index + len(s1))
|
||||
extracted_content = long_string[start_index + len(s1):end_index]
|
||||
return extracted_content
|
||||
while start_index != -1:
|
||||
if is_include:
|
||||
end_index = long_string.find(s2, start_index + len(s1) + 1)
|
||||
extracted_content = long_string[start_index:end_index + len(s2)]
|
||||
else:
|
||||
end_index = long_string.find(s2, start_index + len(s1))
|
||||
extracted_content = long_string[start_index + len(s1):end_index]
|
||||
if extracted_content:
|
||||
match_map[start_index] = extracted_content
|
||||
start_index = long_string.find(s1, start_index + 1)
|
||||
return match_map
|
||||
|
||||
def extract_content_open_ending(long_string, s1, s2):
|
||||
def extract_content_open_ending(long_string, s1, s2, is_include: bool = False):
|
||||
# extract text open ending
|
||||
match_map = {}
|
||||
start_index = long_string.find(s1)
|
||||
if start_index < 0:
|
||||
return ""
|
||||
if long_string.find(s2) <=0:
|
||||
end_index = len(long_string)
|
||||
else:
|
||||
end_index = long_string.find(s2, start_index + len(s1))
|
||||
extracted_content = long_string[start_index + len(s1):end_index]
|
||||
return extracted_content
|
||||
|
||||
def extract_content_include(long_string, s1, s2):
|
||||
start_index = long_string.find(s1)
|
||||
if start_index < 0:
|
||||
return ""
|
||||
end_index = long_string.find(s2, start_index + len(s1) + 1)
|
||||
extracted_content = long_string[start_index:end_index + len(s2)]
|
||||
return extracted_content
|
||||
|
||||
def extract_content_include_open_ending(long_string, s1, s2):
|
||||
|
||||
start_index = long_string.find(s1)
|
||||
if start_index < 0:
|
||||
return ""
|
||||
if long_string.find(s2) <=0:
|
||||
end_index = len(long_string)
|
||||
else:
|
||||
end_index = long_string.find(s2, start_index + len(s1) + 1)
|
||||
extracted_content = long_string[start_index:end_index + len(s2)]
|
||||
return extracted_content
|
||||
while start_index != -1:
|
||||
if long_string.find(s2, start_index) <=0:
|
||||
end_index = len(long_string)
|
||||
else:
|
||||
if is_include:
|
||||
end_index = long_string.find(s2, start_index + len(s1) + 1)
|
||||
else:
|
||||
end_index = long_string.find(s2, start_index + len(s1))
|
||||
if is_include:
|
||||
extracted_content = long_string[start_index:end_index + len(s2)]
|
||||
else:
|
||||
extracted_content = long_string[start_index + len(s1):end_index]
|
||||
if extracted_content:
|
||||
match_map[start_index] = extracted_content
|
||||
start_index= long_string.find(s1, start_index + 1)
|
||||
return match_map
|
||||
|
||||
|
||||
|
||||
if __name__=="__main__":
|
||||
s = "abcd123efghijkjhhh456"
|
||||
s = "abcd123efghijkjhhh456xxx123aa456yyy123bb456xx123"
|
||||
s1 = "123"
|
||||
s2 = "456"
|
||||
|
||||
print(extract_content_open_ending(s, s1, s2))
|
||||
print(extract_content_open_ending(s, s1, s2, True))
|
@ -34,7 +34,7 @@ class ChatAgent(BaseChat):
|
||||
agent_module = CFG.SYSTEM_APP.get_componet(ComponetType.AGENT_HUB, ModuleAgent)
|
||||
self.plugins_prompt_generator = agent_module.load_select_plugin(self.plugins_prompt_generator, self.select_plugins)
|
||||
|
||||
self.api_call = ApiCall(self.plugins_prompt_generator)
|
||||
self.api_call = ApiCall(plugin_generator=self.plugins_prompt_generator)
|
||||
|
||||
def generate_input_values(self):
|
||||
input_values = {
|
||||
|
@ -15,27 +15,6 @@ class PluginAction(NamedTuple):
|
||||
|
||||
|
||||
class PluginChatOutputParser(BaseOutputParser):
|
||||
def parse_prompt_response(self, model_out_text) -> T:
|
||||
clean_json_str = super().parse_prompt_response(model_out_text)
|
||||
print(clean_json_str)
|
||||
if not clean_json_str:
|
||||
raise ValueError("model server response not have json!")
|
||||
try:
|
||||
response = json.loads(clean_json_str)
|
||||
except Exception as e:
|
||||
raise ValueError("model server out not fllow the prompt!")
|
||||
|
||||
speak = ""
|
||||
thoughts = ""
|
||||
for key in sorted(response):
|
||||
if key.strip() == "command":
|
||||
command = response[key]
|
||||
if key.strip() == "thoughts":
|
||||
thoughts = response[key]
|
||||
if key.strip() == "speak":
|
||||
speak = response[key]
|
||||
return PluginAction(command, speak, thoughts)
|
||||
|
||||
def parse_view_response(self, speak, data) -> str:
|
||||
### tool out data to table view
|
||||
print(f"parse_view_response:{speak},{str(data)}")
|
||||
|
@ -7,13 +7,13 @@ from pilot.scene.base_chat import BaseChat, logger
|
||||
from pilot.scene.base import ChatScene
|
||||
from pilot.common.sql_database import Database
|
||||
from pilot.configs.config import Config
|
||||
|
||||
from pilot.base_modules.agent.commands.command_mange import ApiCall
|
||||
from pilot.scene.chat_data.chat_excel.excel_analyze.prompt import prompt
|
||||
from pilot.scene.chat_data.chat_excel.excel_reader import ExcelReader
|
||||
from pilot.scene.chat_data.chat_excel.excel_learning.chat import ExcelLearning
|
||||
from pilot.common.path_utils import has_path
|
||||
from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH
|
||||
|
||||
from pilot.base_modules.agent.common.schema import Status
|
||||
|
||||
CFG = Config()
|
||||
|
||||
@ -36,9 +36,18 @@ class ChatExcel(BaseChat):
|
||||
KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode.value(), self.select_param
|
||||
)
|
||||
)
|
||||
|
||||
self.api_call = ApiCall(display_registry = CFG.command_disply)
|
||||
super().__init__(chat_param=chat_param)
|
||||
|
||||
def _generate_numbered_list(self) -> str:
|
||||
command_strings = []
|
||||
if CFG.command_disply:
|
||||
command_strings += [
|
||||
str(item)
|
||||
for item in CFG.command_disply.commands.values()
|
||||
if item.enabled
|
||||
]
|
||||
return "\n".join(f"{i+1}. {item}" for i, item in enumerate(command_strings))
|
||||
|
||||
def generate_input_values(self):
|
||||
input_values = {
|
||||
@ -64,15 +73,25 @@ class ChatExcel(BaseChat):
|
||||
result = await learn_chat.nostream_call()
|
||||
return result
|
||||
|
||||
def do_action(self, prompt_response):
|
||||
print(f"do_action:{prompt_response}")
|
||||
def stream_plugin_call(self, text):
|
||||
text = text.replace("\n", " ")
|
||||
print(f"stream_plugin_call:{text}")
|
||||
return self.api_call.run_display_sql(text, self.excel_reader.get_df_by_sql_ex)
|
||||
|
||||
# colunms, datas = self.excel_reader.run(prompt_response.sql)
|
||||
param = {
|
||||
"speak": prompt_response.thoughts,
|
||||
"df": self.excel_reader.get_df_by_sql_ex(prompt_response.sql),
|
||||
}
|
||||
if CFG.command_disply.get_command(prompt_response.display):
|
||||
return CFG.command_disply.call(prompt_response.display, **param)
|
||||
else:
|
||||
return CFG.command_disply.call("response_table", **param)
|
||||
|
||||
|
||||
# def do_action(self, prompt_response):
|
||||
# print(f"do_action:{prompt_response}")
|
||||
#
|
||||
# # colunms, datas = self.excel_reader.run(prompt_response.sql)
|
||||
#
|
||||
#
|
||||
# param = {
|
||||
# "speak": prompt_response.thoughts,
|
||||
# "df": self.excel_reader.get_df_by_sql_ex(prompt_response.sql),
|
||||
# }
|
||||
#
|
||||
# if CFG.command_disply.get_command(prompt_response.display):
|
||||
# return CFG.command_disply.call(prompt_response.display, **param)
|
||||
# else:
|
||||
# return CFG.command_disply.call("response_table", **param)
|
||||
|
@ -9,50 +9,52 @@ from pilot.common.schema import SeparatorStyle
|
||||
|
||||
CFG = Config()
|
||||
|
||||
PROMPT_SCENE_DEFINE = "You are a data analysis expert. "
|
||||
_PROMPT_SCENE_DEFINE_EN = "You are a data analysis expert. "
|
||||
|
||||
_DEFAULT_TEMPLATE_EN = """
|
||||
Please use the data structure information of the above historical dialogue, make sure not to use column names that are not in the data structure.
|
||||
According to the user goal: {user_input},give the correct duckdb SQL for data analysis.
|
||||
Use the table name: {table_name}
|
||||
Please use the data structure and column information in the above historical dialogue and combine it with data analysis to answer the user's questions while satisfying the constraints.
|
||||
|
||||
According to the analysis SQL obtained by the user's goal, select the best one from the following display forms, if it cannot be determined, use Text as the display,Just need to return the type name into the result.
|
||||
Display type:
|
||||
{disply_type}
|
||||
|
||||
Respond in the following json format:
|
||||
{response}
|
||||
Ensure the response is correct json and can be parsed by Python json.loads
|
||||
Constraint:
|
||||
1.Please output your thinking process and analysis ideas first, and then output the specific data analysis results. The data analysis results are output in the following format:<api-call><name>display type</name><args><sql>Correct duckdb data analysis sql</sql></args></api-call>
|
||||
2.For the available display methods of data analysis results, please choose the most appropriate one from the following display methods. If you are not sure, use 'response_data_text' as the display. The available display types are as follows:{disply_type}
|
||||
3.The table name that needs to be used in SQL is: {table_name}, please make sure not to use column names that are not in the data structure.
|
||||
4.Give priority to answering using data analysis. If the user's question does not involve data analysis, you can answer according to your understanding.
|
||||
|
||||
User Questions:
|
||||
{user_input}
|
||||
"""
|
||||
|
||||
_PROMPT_SCENE_DEFINE_ZH = """你是一个数据分析专家!"""
|
||||
_DEFAULT_TEMPLATE_ZH = """
|
||||
请使用上述历史对话中的数据结构和列信息,根据用户目标:{user_input},给出正确的duckdb SQL进行数据分析和问题回答。
|
||||
请确保不要使用不在数据结构中的列名。
|
||||
SQL中需要使用的表名是: {table_name}
|
||||
请使用上述历史对话中的数据结构信息,在满足约束条件下结合数据分析回答用户的问题。
|
||||
|
||||
根据用户目标得到的分析SQL,请从以下显示类型中选择最合适的一种用来展示结果数据,如果无法确定,则使用'Text'作为显示, 只需要将类型名称返回到结果中。
|
||||
显示类型如下:
|
||||
{disply_type}
|
||||
|
||||
以以下 json 格式响应::
|
||||
{response}
|
||||
确保响应是正确的json,并且可以被Python的json.loads方法解析.
|
||||
约束条件:
|
||||
1.请先输出你的分析思路内容,再输出具体的数据分析结果,其中数据分析结果部分确保用如下格式输出:<api-call><name>[数据展示方式]</name><args><sql>[正确的duckdb数据分析sql]</sql></args></api-call>
|
||||
2.请确保数据分析结果格式的内容在整个回答中只出现一次,确保上述结构稳定,把[]部分内容替换为对应的值
|
||||
3.数据分析结果可用的展示方式请在下面的展示方式中选择最合适的一种,放入数据分析结果的name字段内如果无法确定,则使用'Text'作为显示,可用显示类型如下: {disply_type}
|
||||
4.SQL中需要使用的表名是: {table_name},请确保不要使用不在数据结构中的列名。
|
||||
5.优先使用数据分析的方式回答,如果用户问题不涉及数据分析内容,你可以按你的理解进行回答
|
||||
6.请确保你的输出内容有良好的markdown格式的排版
|
||||
|
||||
用户问题:{user_input}
|
||||
请确保你的输出格式如下:
|
||||
需要告诉用户的分析思路.
|
||||
<api-call><name>[数据展示方式]</name><args><sql>[正确的duckdb数据分析sql]</sql></args></api-call>
|
||||
"""
|
||||
|
||||
RESPONSE_FORMAT_SIMPLE = {
|
||||
"sql": "analysis SQL",
|
||||
"thoughts": "Current thinking and value of data analysis",
|
||||
"display": "display type name",
|
||||
}
|
||||
|
||||
|
||||
_DEFAULT_TEMPLATE = (
|
||||
_DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH
|
||||
)
|
||||
|
||||
_PROMPT_SCENE_DEFINE = (
|
||||
_PROMPT_SCENE_DEFINE_EN if CFG.LANGUAGE == "en" else _PROMPT_SCENE_DEFINE_ZH
|
||||
)
|
||||
|
||||
PROMPT_SEP = SeparatorStyle.SINGLE.value
|
||||
|
||||
PROMPT_NEED_NEED_STREAM_OUT = False
|
||||
PROMPT_NEED_NEED_STREAM_OUT = True
|
||||
|
||||
# Temperature is a configuration hyperparameter that controls the randomness of language model output.
|
||||
# A high temperature produces more unpredictable and creative results, while a low temperature produces more common and conservative output.
|
||||
@ -62,8 +64,7 @@ PROMPT_TEMPERATURE = 0.8
|
||||
prompt = PromptTemplate(
|
||||
template_scene=ChatScene.ChatExcel.value(),
|
||||
input_variables=["user_input", "table_name", "disply_type"],
|
||||
response_format=json.dumps(RESPONSE_FORMAT_SIMPLE, ensure_ascii=False, indent=4),
|
||||
template_define=PROMPT_SCENE_DEFINE,
|
||||
template_define=_PROMPT_SCENE_DEFINE,
|
||||
template=_DEFAULT_TEMPLATE,
|
||||
stream_out=PROMPT_NEED_NEED_STREAM_OUT,
|
||||
output_parser=ChatExcelOutputParser(
|
||||
|
@ -9,7 +9,7 @@ from pilot.common.schema import SeparatorStyle
|
||||
|
||||
CFG = Config()
|
||||
|
||||
PROMPT_SCENE_DEFINE = "You are a data analysis expert. "
|
||||
_PROMPT_SCENE_DEFINE_EN = "You are a data analysis expert. "
|
||||
|
||||
_DEFAULT_TEMPLATE_EN = """
|
||||
This is an example data,please learn to understand the structure and content of this data:
|
||||
@ -21,6 +21,8 @@ Please return your answer in JSON format, the return format is as follows:
|
||||
{response}
|
||||
"""
|
||||
|
||||
_PROMPT_SCENE_DEFINE_ZH = "你是一个数据分析专家. "
|
||||
|
||||
_DEFAULT_TEMPLATE_ZH = """
|
||||
下面是一份示例数据,请学习理解该数据的结构和内容:
|
||||
{data_example}
|
||||
@ -37,10 +39,15 @@ RESPONSE_FORMAT_SIMPLE = {
|
||||
"AnalysisProgram": ["1.分析方案1,图表展示方式1", "2.分析方案2,图表展示方式2"],
|
||||
}
|
||||
|
||||
|
||||
_DEFAULT_TEMPLATE = (
|
||||
_DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH
|
||||
)
|
||||
|
||||
PROMPT_SCENE_DEFINE =(
|
||||
_PROMPT_SCENE_DEFINE_EN if CFG.LANGUAGE == "en" else _PROMPT_SCENE_DEFINE_ZH
|
||||
)
|
||||
|
||||
|
||||
PROMPT_SEP = SeparatorStyle.SINGLE.value
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user